Ejemplo n.º 1
0
    def test_return_simple(self):
        """A single ``return`` is rewritten into returned_1/returned_value assignments."""
        source = utils.clip_head("""
        def func(a, b):
            return a + b

        value = func(3, 39)
        """)
        expected = utils.clip_head("""
        def func(a, b):
            returned_value = None
            returned_1 = False

            returned_1 = True
            returned_value = a + b
            
            return returned_value

        value = func(3, 39)
        """)
        source_tree, expected_tree = (
            gast.ast_to_gast(ast.parse(text)) for text in (source, expected))
        actual_tree = self.canonicalizer.visit(source_tree)

        # Executing both snippets must yield the same `value`.
        assert_semantically_equals(source, expected, ['value'])
        # Structural equality is checked with the arguments in both orders.
        assert compare_ast(actual_tree, expected_tree)
        assert compare_ast(expected_tree, actual_tree)
Ejemplo n.º 2
0
    def test_break(self):
        """``break`` becomes a breaked_0 flag, guarded statements and a trailing break."""
        source = utils.clip_head("""
        x = 0
        for i in range(10):
            if i == 5:
                break
            x += i
        """)
        expected = utils.clip_head("""
        x = 0
        for i in range(10):
            breaked_0 = False
            if i == 5:
                breaked_0 = True
            if not breaked_0:
                x += i
            keepgoing = not breaked_0
            if breaked_0:
                break
        """)
        source_tree, expected_tree = (
            gast.ast_to_gast(ast.parse(text)) for text in (source, expected))
        actual_tree = self.canonicalizer.visit(source_tree)

        # Executing both snippets must yield the same `x` and `i`.
        assert_semantically_equals(source, expected, ['x', 'i'])
        # Structural equality is checked with the arguments in both orders.
        assert compare_ast(actual_tree, expected_tree)
        assert compare_ast(expected_tree, actual_tree)
Ejemplo n.º 3
0
    def test_continue_nested(self):
        """Each loop level gets its own continued_N flag guarding the rest of its body."""
        source = utils.clip_head("""
        x = 0
        for i in range(10):
            if i == 5:
                continue
            for j in range(10):
                if j == 5:
                    continue
                x += i * j
        """)
        expected = utils.clip_head("""
        x = 0
        for i in range(10):
            continued_0 = False
            if i == 5:
                continued_0 = True
            if not continued_0:
                for j in range(10):
                    continued_1 = False
                    if j == 5:
                        continued_1 = True
                    if not continued_1:
                        x += i * j
        """)
        source_tree, expected_tree = (
            gast.ast_to_gast(ast.parse(text)) for text in (source, expected))
        actual_tree = self.canonicalizer.visit(source_tree)

        # Executing both snippets must yield the same `x`, `i` and `j`.
        assert_semantically_equals(source, expected, ['x', 'i', 'j'])
        # Structural equality is checked with the arguments in both orders.
        assert compare_ast(actual_tree, expected_tree)
        assert compare_ast(expected_tree, actual_tree)
Ejemplo n.º 4
0
    def __init__(self, classinfo):
        """Parse and canonicalize the ``__init__`` method of ``classinfo``.

        Args:
            classinfo: class whose single ``__init__`` member (looked up via
                ``inspect.getmembers``) is analyzed.

        If the source of ``__init__`` cannot be retrieved, a warning is
        printed and construction stops early; ``lineno``, ``classinfo`` and
        ``ast`` are then not set by this method.
        """
        super().__init__()

        members = inspect.getmembers(classinfo)
        init_func = [m[1] for m in members if m[0] == '__init__']
        assert(len(init_func) == 1)

        func = init_func[0]
        self.inst = func
        self.name = func.__name__
        self.filename = inspect.getfile(func)

        # inspect.getsourcelines never returns None: it raises OSError when the
        # source is unavailable (TypeError for built-ins), so the "failed to
        # parse" path has to be an exception handler.
        try:
            source_lines, self.lineno = inspect.getsourcelines(func)
        except (OSError, TypeError):
            utils.print_warning('Failed to parse {}'.format(classinfo), utils.LineProperty())
            return

        self.classinfo = classinfo

        # Same text inspect.getsource(func) would return, without re-reading it.
        original_code = ''.join(source_lines)
        code = utils.clip_head(original_code)

        self.args.analyze_args(func)

        ast_ = gast.ast_to_gast(ast.parse(code)).body[0]
        self.ast = canonicalizer.Canonicalizer().visit(ast_)
Ejemplo n.º 5
0
    def infer_user_defined_function(self, func, ty_args, node):
        """Infer the result type of calling the user-defined callable ``func``.

        Plain functions and methods are analyzed directly. For ``obj.method(...)``
        call syntax, the type of ``obj`` is prepended to ``ty_args``. Any other
        callable is analyzed through ``forward`` (for ``chainer.Chain``) or
        ``__call__``, with the callable's own type prepended.

        The callee's FunctionDef is stored in ``self.subroutine_node[node]``.
        The sub-engine's ``nodetype`` and ``subroutine_node`` tables are merged
        into this engine's tables via ``utils.add_dict``.

        Returns:
            (ty_args, ret): the possibly extended argument types, and the type
            the sub-engine assigned to the callee's FunctionDef node.
        """
        if isinstance(func, (types.FunctionType, types.MethodType)):
            func_body = func
            if isinstance(node.func, gast.Attribute):
                # Receiver of `obj.method(...)` is passed as the first argument.
                receiver_type = self.nodetype[node.func.value]
                ty_args = [receiver_type] + ty_args
        else:
            # Callable object (defined with __call__, or forward for chains).
            func_body = func.forward if isinstance(func, chainer.Chain) else func.__call__
            ty_args = [type_of_value(func)] + ty_args

        source = clip_head(inspect.getsource(func_body))
        # FunctionDef of the called subroutine.
        callee_def = gast.ast_to_gast(ast.parse(source)).body[0]
        self.subroutine_node[node] = callee_def

        sub_engine = InferenceEngine(is_debug=self.is_debug,
                                     module=sys.modules[func.__module__])
        hints = typing.get_type_hints(func_body)
        sub_engine.infer_function(callee_def, ty_args, type_hints=hints)

        # Propagate everything the subroutine learned back into this engine.
        merges = ((self.nodetype, sub_engine.nodetype),
                  (self.subroutine_node, sub_engine.subroutine_node))
        for dst, src in merges:
            utils.add_dict(dst, src)
        return ty_args, sub_engine.nodetype[callee_def]
Ejemplo n.º 6
0
    def __init__(self, ch, env):
        """Parse ``ch.forward`` and temporarily swap ``ch``'s attributes.

        Direct parameters of ``ch`` (and its ``children`` attribute) are
        replaced on the object. For each replacement, a zero-argument callback
        that puts the original value back is appended to ``env.restore_funcs``.

        Args:
            ch: the chain whose ``forward`` is compiled.
            env: object providing a ``restore_funcs`` list.
        """
        src = clip_head(inspect.getsource(ch.forward))
        dprint(src)
        self.ast = gast.ast_to_gast(ast.parse(src)).body[0]

        self.call = User_Defined_Func_In_Link(ch, ch.forward).call

        # The following is for the first call coming from outside.
        # Number of forward() arguments minus the first one (presumably self).
        self.forward_arglen = len(self.ast.args.args) - 1

        # The parameters chainer initialized have to be overwritten here.
        # They must be restored later so the model can run under chainer
        # again, so a restore callback for each one is added to restore_funcs.
        self.inits = []

        for s, v in ch.namedparams():
            s = s[1:]  # drop the leading '/'
            if s.find('/') != -1:
                # Names still containing '/' are skipped (presumably
                # parameters of nested links).
                continue
            t = helper.make_tensor_value_info('/' + s, TensorProto.FLOAT,
                                              list(v.shape))
            self.inits.append(t)
            mv = getattr(ch, s)
            setattr(ch, s, t)
            # Bind s and mv now. A plain closure would see only their final
            # values (the last parameter, and later ch.children, because mv
            # was reassigned below in the original code).
            env.restore_funcs.append(lambda s=s, mv=mv: setattr(ch, s, mv))

        # TODO(satos) This can be removed once Yield can be compiled.
        orig_children = getattr(ch, 'children')
        setattr(ch, 'children', Func(lambda _, __, ___: orig_children()))
        env.restore_funcs.append(
            lambda orig_children=orig_children: setattr(ch, 'children', orig_children))
Ejemplo n.º 7
0
    def test_continue_break_nested(self):
        """Mixed continue/break across two loops yields combined flag guards."""
        source = utils.clip_head("""
        x = 0
        for i in range(10):
            if i == 5:
                continue
            if i == 6:
                break
            for j in range(10):
                if j == 5:
                    break
                x += i * j
        """)
        expected = utils.clip_head("""
        x = 0
        for i in range(10):
            breaked_1 = False
            continued_0 = False
            if i == 5:
                continued_0 = True
            if not continued_0:
                if i == 6:
                    breaked_1 = True
            if not continued_0 and not breaked_1:
                for j in range(10):
                    breaked_2 = False
                    if j == 5:
                        breaked_2 = True
                    if not breaked_2:
                        x += i * j
                    keepgoing = not breaked_2
                    if breaked_2:
                        break
            keepgoing = not breaked_1
            if breaked_1:
                break
        """)
        source_tree, expected_tree = (
            gast.ast_to_gast(ast.parse(text)) for text in (source, expected))
        actual_tree = self.canonicalizer.visit(source_tree)

        # Executing both snippets must yield the same `x`, `i` and `j`.
        assert_semantically_equals(source, expected, ['x', 'i', 'j'])
        # Structural equality is checked with the arguments in both orders.
        assert compare_ast(actual_tree, expected_tree)
        assert compare_ast(expected_tree, actual_tree)
Ejemplo n.º 8
0
def generate_id2type_from_forward(model, args, is_debug=False):
    """Type-infer ``model.forward`` and return the resulting id -> type map.

    Args:
        model: object whose ``forward`` method source is analyzed; it is
            passed as the first argument value ahead of ``args``.
        args: tuple of the remaining forward argument values.
        is_debug: forwarded to ``generate_node2type``.
    """
    forward = model.forward
    source = utils.clip_head(inspect.getsource(forward))
    tree = gast.ast_to_gast(ast.parse(source))
    defining_module = sys.modules[forward.__module__]

    node2type, subroutine_node = generate_node2type(
        tree,
        (model,) + args,
        is_debug=is_debug,
        module=defining_module,
        type_hints=typing.get_type_hints(forward),
    )
    return generate_id2type(node2type, generate_node2id(tree, subroutine_node))
Ejemplo n.º 9
0
    def __init__(self, func):
        """Record metadata for ``func`` and build its gast representation.

        Regular functions are parsed and canonicalized into a FunctionDef.
        Lambdas are handled differently: their body expression is turned into
        a ``return <expr>`` statement and parsed as a module, without
        canonicalization.
        """
        super().__init__()

        self.inst = func
        self.name = func.__name__
        self.filename = inspect.getfile(func)
        _, self.lineno = inspect.getsourcelines(func)
        self.args.analyze_args(func)

        is_lambda = func.__name__ == (lambda: None).__name__
        if not is_lambda:
            source = utils.clip_head(inspect.getsource(func))
            func_def = gast.ast_to_gast(ast.parse(source)).body[0]
            self.ast = canonicalizer.Canonicalizer().visit(func_def)
            return

        lambda_text = utils.lambda_source(func)
        # Everything after the first "lambda ...:" header is the body expression.
        body_start = re.search('lambda.*?:', lambda_text).end()
        self.ast = gast.ast_to_gast(ast.parse('return ' + lambda_text[body_start:]))
Ejemplo n.º 10
0
    def test_return(self):
        """``return`` inside a loop becomes flag + value + break, with a guarded fallback."""
        source = utils.clip_head("""
        def func(a, b):
            for i in range(a):
                if i == b:
                    return i
            return 0

        value = 0
        for a in range(10):
            for b in range(10):
                value += func(a, b)
        """)
        expected = utils.clip_head("""
        def func(a, b):
            returned_value = None
            returned_1 = False
            for i in range(a):
                if i == b:
                    returned_1 = True
                    returned_value = i
                keepgoing = not returned_1
                if returned_1:
                    break
            if not returned_1:
                returned_1 = True
                returned_value = 0
            return returned_value

        value = 0
        for a in range(10):
            for b in range(10):
                value += func(a, b)
        """)
        source_tree, expected_tree = (
            gast.ast_to_gast(ast.parse(text)) for text in (source, expected))
        actual_tree = self.canonicalizer.visit(source_tree)

        # Executing both snippets must yield the same `value`.
        assert_semantically_equals(source, expected, ['value'])
        # Structural equality is checked with the arguments in both orders.
        assert compare_ast(actual_tree, expected_tree)
        assert compare_ast(expected_tree, actual_tree)
Ejemplo n.º 11
0
def eval_list_comp(nast, env):
    """Evaluate a list-comprehension node by desugaring it into loops.

    ``[elt for t1 in it1 for t2 in it2]`` is rewritten to::

        <tmp> = []
        for t1 in it1:
            for t2 in it2:
                <tmp>.append(elt)

    The rewritten statements are run with ``eval_ast`` in ``env``. The
    temporary variable is then popped from ``env`` and returned.

    Only plain generators are supported: no ``if`` filters and no
    ``async for`` (enforced by assertions).
    """
    # Use a name that cannot collide (comprehensions may sit inside loops).
    vn = "dummy@" + new_tensor().name
    assert len(nast.generators) >= 1
    tast = gast.ast_to_gast(ast.parse("v.append(w)")).body[0]
    tast.value.func.value.id = vn
    tast.value.args[0] = nast.elt

    # In Python the first `for` clause is the OUTERMOST loop. Because each
    # iteration wraps the statement built so far, the generators must be
    # visited innermost-first, i.e. in reverse.
    for gen in reversed(nast.generators):
        # Filters and async generators are not implemented yet.
        assert len(gen.ifs) == 0 and gen.is_async == 0
        tast = gast.For(target=gen.target, iter=gen.iter,
                        body=[tast], orelse=[])

    init = gast.ast_to_gast(ast.parse("v = []")).body[0]
    init.targets[0].id = vn
    tast = [init, tast]

    rv = eval_ast(tast, env)
    assert rv.is_none()
    res = env.pop_var(vn)
    return res
Ejemplo n.º 12
0
    def __init__(self, func):
        """Record metadata for ``func`` and parse its source into a gast FunctionDef."""
        super().__init__()

        self.inst = func
        self.name = func.__name__
        _, self.lineno = inspect.getsourcelines(func)

        source = utils.clip_head(inspect.getsource(func))
        self.analyze_args(func)

        module_node = gast.ast_to_gast(ast.parse(source))
        self.ast = module_node.body[0]
Ejemplo n.º 13
0
def generate_type_inference_results(model, forward_args, is_debug=True):
    """Type-infer ``model.forward`` and return ``(id2type, id2node)``.

    The forward source is parsed as-is. NOTE: it is not passed through the
    Canonicalizer here.

    Args:
        model: object whose ``forward`` method source is analyzed; it is
            passed as the first argument value ahead of ``forward_args``.
        forward_args: tuple of the remaining forward argument values.
        is_debug: forwarded to ``generate_node2type``.
    """
    forward = model.forward
    node = gast.ast_to_gast(ast.parse(utils.clip_head(inspect.getsource(forward))))
    defining_module = sys.modules[forward.__module__]

    node2type, subroutine_node = generate_node2type(
        node,
        (model,) + forward_args,
        is_debug=is_debug,
        module=defining_module,
        type_hints=typing.get_type_hints(forward),
    )
    node2id = generate_node2id(node, subroutine_node)
    return generate_id2type(node2type, node2id), generate_id2node(node2id)
Ejemplo n.º 14
0
 def __call__(self, codeobj):
     """Return the cached gast node for ``codeobj``, parsing its file on a miss.

     On a miss, the whole file ``key[0]`` is parsed once. The module node is
     cached under ``(fname, 0)`` and every ``FunctionDef`` under
     ``(fname, lineno)``. The final lookup may raise ``KeyError`` if no node
     matches ``key`` (assuming ``self.cache`` is a plain dict; TODO confirm).
     """
     store = self.cache
     key = self.get_file_info(codeobj)
     hit = store.get(key)
     if hit is not None:
         return hit

     fname = key[0]
     module_node = gast.ast_to_gast(self.parse_file(fname))
     store[(fname, 0)] = module_node
     store.update(
         ((fname, n.lineno), n)
         for n in gast.walk(module_node)
         if isinstance(n, gast.FunctionDef))
     return store[key]
Ejemplo n.º 15
0
    def __init__(self, func):
        """Record metadata for ``func`` and build its canonicalized gast FunctionDef."""
        super().__init__()

        self.inst = func
        self.name = func.__name__
        self.filename = inspect.getfile(func)
        _, self.lineno = inspect.getsourcelines(func)

        source = utils.clip_head(inspect.getsource(func))
        self.args.analyze_args(func)

        func_def = gast.ast_to_gast(ast.parse(source)).body[0]
        self.ast = canonicalizer.Canonicalizer().visit(func_def)
Ejemplo n.º 16
0
    def __init__(self, classinfo):
        """Record metadata for ``classinfo.__init__`` and parse it into a gast FunctionDef."""
        super().__init__()

        candidates = [member for attr, member in inspect.getmembers(classinfo)
                      if attr == '__init__']
        assert(len(candidates) == 1)
        init_method = candidates[0]

        self.inst = init_method
        self.name = init_method.__name__
        _, self.lineno = inspect.getsourcelines(init_method)
        self.classinfo = classinfo

        source = utils.clip_head(inspect.getsource(init_method))
        self.analyze_args(init_method)

        module_node = gast.ast_to_gast(ast.parse(source))
        self.ast = module_node.body[0]
Ejemplo n.º 17
0
 def test_usub(self):
     """Unary minus on a literal is folded into a single negative constant."""
     orig_ast = gast.ast_to_gast(ast.parse("-3"))
     # Newer gast versions represent literals as Constant (Num is no longer
     # provided), and Module carries a type_ignores field. This matches the
     # updated form of this test elsewhere in the suite.
     target_ast = gast.Module(
         body=[gast.Expr(value=gast.Constant(value=-3, kind=None))],
         type_ignores=[])
     assert compare_ast(self.canonicalizer.visit(orig_ast), target_ast)
Ejemplo n.º 18
0
 def __init__(self, func):
     """Keep ``func`` and parse its source into a gast FunctionDef node."""
     self.func = func
     source = clip_head(inspect.getsource(func))
     dprint(source)
     module_node = gast.ast_to_gast(ast.parse(source))
     self.ast = module_node.body[0]
     assert isinstance(self.ast, gast.gast.FunctionDef)
Ejemplo n.º 19
0
    def test_return_continue(self):
        """``return`` nested under ``continue``-bearing loops is flattened into flags."""
        source = utils.clip_head("""
        def func(a, b, c):
            x = 0
            for i in range(a):
                if i == b:
                    continue
                for j in range(a):
                    if j == c:
                        continue
                    if j == b:
                        return x
                    x += i * j
            return x

        value = 0
        for a in range(10):
            for b in range(10):
                for c in range(10):
                    value += func(a, b, c)
        """)
        expected = utils.clip_head("""
        def func(a, b, c):
            returned_value = None
            returned_1 = False
            x = 0
            for i in range(a):
                continued_1 = False
                if i == b:
                    continued_1 = True
                if not continued_1:
                    for j in range(a):
                        continued_2 = False
                        if j == c:
                            continued_2 = True
                        if not continued_2:
                            if j == b:
                                if not continued_2:
                                    returned_1 = True
                                    returned_value = x
                        if not continued_2 and not returned_1:
                            x += i * j
                        keepgoing = not returned_1
                        if returned_1:
                            break
                keepgoing = not returned_1
                if returned_1:
                    break
            if not returned_1:
                returned_1 = True
                returned_value = x
            return returned_value

        value = 0
        for a in range(10):
            for b in range(10):
                for c in range(10):
                    value += func(a, b, c)
        """)
        source_tree, expected_tree = (
            gast.ast_to_gast(ast.parse(text)) for text in (source, expected))
        actual_tree = self.canonicalizer.visit(source_tree)

        # Executing both snippets must yield the same `value`.
        assert_semantically_equals(source, expected, ['value'])
        # Structural equality is checked with the arguments in both orders.
        assert compare_ast(actual_tree, expected_tree)
        assert compare_ast(expected_tree, actual_tree)
 def test_usub(self):
     """``-3`` canonicalizes to one expression statement holding Constant(-3)."""
     parsed = gast.ast_to_gast(ast.parse("-3"))
     expected = gast.Module(
         body=[gast.Expr(value=gast.Constant(value=-3, kind=None))],
         type_ignores=[],
     )
     canonical = self.canonicalizer.visit(parsed)
     assert compare_ast(canonical, expected)
Ejemplo n.º 21
0
def parse(file):
    """Parse the Python source file at path ``file`` into a gast tree.

    The file is read inside a ``with`` block so the handle is closed
    deterministically, instead of being left for the garbage collector.
    """
    with open(file) as source_file:
        code = source_file.read()
    tree = ast.parse(code)
    return gast.ast_to_gast(tree)
Ejemplo n.º 22
0
import ast
import gast
import sys

# Parse the file named on the command line, convert it to gast and print the
# ast.dump() of every top-level statement.
with open(sys.argv[1]) as source_file:  # closed deterministically
    code = source_file.read()
tree = ast.parse(code)
gtree = gast.ast_to_gast(tree)

for node in gtree.body:
    print(ast.dump(node))
Ejemplo n.º 23
0
    from copy import deepcopy
    try:
        from astmonkey import transformers, visitors
        IMPORT_ASTMONKEY = True
    except ImportError:
        IMPORT_ASTMONKEY = False

    def dump_ast(mod, name):
        """Print ``mod`` via ``gast.dump`` and optionally render it to ``<name>.png``.

        Rendering requires astmonkey (``IMPORT_ASTMONKEY``). It operates on a
        deep copy, so the caller's tree object is not handed to the astmonkey
        transformer.
        """
        print(gast.dump(mod))
        if IMPORT_ASTMONKEY:
            # One deep copy is enough to isolate the caller's tree.
            annotated = transformers.ParentChildNodeTransformer().visit(
                deepcopy(mod))
            visitor = visitors.GraphNodeVisitor()
            visitor.visit(annotated)
            visitor.graph.write_png(name + '.png')
            print(
                "\033[1;32;40mAST visualization saved as \033[94m%s.png\033[0m"
                % name)
        else:
            print("\033[93mInstall astmonkey for visualization.\033[0m")

    # Dump the gast tree of the file given on the command line, before and
    # after canonicalization. The file handle is closed by the `with` block.
    with open(sys.argv[1]) as source_file:
        code = source_file.read()
    orig_ast = gast.ast_to_gast(ast.parse(code))
    print('=== Original AST ===')
    dump_ast(orig_ast, 'original')

    print('=== Canonicalized AST ===')
    canon_ast = Canonicalizer().visit(orig_ast)
    dump_ast(canon_ast, 'canonicalized')