예제 #1
0
    def test_continue_nested(self):
        """Nested loops that both use ``continue`` get per-loop
        ``continued_N`` flags guarding the remainder of each body."""
        source = utils.clip_head("""
        x = 0
        for i in range(10):
            if i == 5:
                continue
            for j in range(10):
                if j == 5:
                    continue
                x += i * j
        """)
        expected = utils.clip_head("""
        x = 0
        for i in range(10):
            continued_0 = False
            if i == 5:
                continued_0 = True
            if not continued_0:
                for j in range(10):
                    continued_1 = False
                    if j == 5:
                        continued_1 = True
                    if not continued_1:
                        x += i * j
        """)
        source_ast, expected_ast = (
            gast.ast_to_gast(ast.parse(text)) for text in (source, expected))
        canonical_ast = self.canonicalizer.visit(source_ast)

        assert_semantically_equals(source, expected, ['x', 'i', 'j'])
        # compare_ast is checked in both argument orders, in case it is not symmetric
        assert compare_ast(canonical_ast, expected_ast)
        assert compare_ast(expected_ast, canonical_ast)
예제 #2
0
    def test_return_simple(self):
        """A lone ``return`` is rewritten to store into ``returned_value`` and
        set the ``returned_1`` flag, with a single ``return`` at the end."""
        source = utils.clip_head("""
        def func(a, b):
            return a + b

        value = func(3, 39)
        """)
        expected = utils.clip_head("""
        def func(a, b):
            returned_value = None
            returned_1 = False

            returned_1 = True
            returned_value = a + b
            
            return returned_value

        value = func(3, 39)
        """)
        source_ast, expected_ast = (
            gast.ast_to_gast(ast.parse(text)) for text in (source, expected))
        canonical_ast = self.canonicalizer.visit(source_ast)

        assert_semantically_equals(source, expected, ['value'])
        # compare_ast is checked in both argument orders, in case it is not symmetric
        assert compare_ast(canonical_ast, expected_ast)
        assert compare_ast(expected_ast, canonical_ast)
예제 #3
0
    def test_break(self):
        """``break`` becomes a ``breaked_0`` flag guarding the rest of the
        body, plus an explicit conditional ``break`` at the end of the body."""
        source = utils.clip_head("""
        x = 0
        for i in range(10):
            if i == 5:
                break
            x += i
        """)
        expected = utils.clip_head("""
        x = 0
        for i in range(10):
            breaked_0 = False
            if i == 5:
                breaked_0 = True
            if not breaked_0:
                x += i
            keepgoing = not breaked_0
            if breaked_0:
                break
        """)
        source_ast, expected_ast = (
            gast.ast_to_gast(ast.parse(text)) for text in (source, expected))
        canonical_ast = self.canonicalizer.visit(source_ast)

        assert_semantically_equals(source, expected, ['x', 'i'])
        # compare_ast is checked in both argument orders, in case it is not symmetric
        assert compare_ast(canonical_ast, expected_ast)
        assert compare_ast(expected_ast, canonical_ast)
예제 #4
0
    def __init__(self, classinfo):
        """Build the function representation of ``classinfo.__init__``.

        The constructor's source is located with ``inspect``, head-clipped,
        parsed, converted to gast and passed through
        ``canonicalizer.Canonicalizer``; the result is stored on ``self.ast``.

        Args:
            classinfo: the class whose ``__init__`` is analysed.

        If the source code cannot be retrieved, a warning is printed and this
        method returns early without setting ``lineno``, ``classinfo`` or
        ``ast``.
        """
        super().__init__()

        # Same object inspect.getmembers() would report under '__init__',
        # without evaluating every other attribute of the class.
        func = classinfo.__init__
        self.inst = func
        self.name = func.__name__
        self.filename = inspect.getfile(func)

        # inspect.getsourcelines never returns None: it raises OSError when
        # the source cannot be retrieved, so that is what must be handled.
        try:
            sourcelines, lineno = inspect.getsourcelines(func)
        except OSError:
            utils.print_warning('Failed to parse {}'.format(classinfo), utils.LineProperty())
            return

        self.lineno = lineno
        self.classinfo = classinfo

        # ''.join(lines) is exactly what inspect.getsource returns; reuse the
        # lines already fetched instead of reading the source a second time.
        code = utils.clip_head(''.join(sourcelines))

        self.args.analyze_args(func)

        ast_ = gast.ast_to_gast(ast.parse(code)).body[0]
        self.ast = canonicalizer.Canonicalizer().visit(ast_)
예제 #5
0
    def infer_user_defined_function(self, func, ty_args, node):
        """Infer the result type of a call into user-written Python code.

        The callee's source is fetched with ``inspect``, parsed into gast and
        type-checked by a fresh ``InferenceEngine`` bound to the callee's
        module; that engine's ``nodetype`` and ``subroutine_node`` entries are
        then merged into this engine.

        Args:
            func: the called object -- a function, a method, a
                ``chainer.Chain`` (its ``forward`` is analysed) or another
                callable object (its ``__call__`` is analysed).
            ty_args: list of argument types; it is not mutated, a new list is
                built whenever a receiver type is prepended.
            node: the call-site node (must have a ``func`` attribute,
                presumably a gast ``Call``).

        Returns:
            ``(ty_args, ty_ret)``: the argument types actually used (with the
            receiver type prepended where applicable) and the type inferred
            for the callee's ``FunctionDef`` node.
        """
        if not isinstance(func, (types.FunctionType, types.MethodType)):
            # Callable object: analyse Chain.forward or __call__, with the
            # object's own type as the leading argument type.
            func_body = func.forward if isinstance(func, chainer.Chain) else func.__call__
            ty_args = [type_of_value(func)] + ty_args
        else:
            func_body = func
            # Called as `obj.meth(...)`: the receiver's already-inferred type
            # is the leading argument type.
            if isinstance(node.func, gast.Attribute):
                ty_args = [self.nodetype[node.func.value]] + ty_args

        # FunctionDef of the called subroutine
        func_node = gast.ast_to_gast(
            ast.parse(clip_head(inspect.getsource(func_body)))).body[0]
        self.subroutine_node[node] = func_node

        sub_engine = InferenceEngine(is_debug=self.is_debug,
                                     module=sys.modules[func.__module__])
        sub_engine.infer_function(func_node, ty_args,
                                  type_hints=typing.get_type_hints(func_body))

        # Pull the callee's results up into this engine.
        utils.add_dict(self.nodetype, sub_engine.nodetype)
        utils.add_dict(self.subroutine_node, sub_engine.subroutine_node)
        return ty_args, sub_engine.nodetype[func_node]
예제 #6
0
    def test_continue_break_nested(self):
        """Outer loop mixing ``continue`` and ``break`` around an inner loop
        with its own ``break``: each gets its own flag, and the inner loop
        runs only when neither outer flag is set."""
        source = utils.clip_head("""
        x = 0
        for i in range(10):
            if i == 5:
                continue
            if i == 6:
                break
            for j in range(10):
                if j == 5:
                    break
                x += i * j
        """)
        expected = utils.clip_head("""
        x = 0
        for i in range(10):
            breaked_1 = False
            continued_0 = False
            if i == 5:
                continued_0 = True
            if not continued_0:
                if i == 6:
                    breaked_1 = True
            if not continued_0 and not breaked_1:
                for j in range(10):
                    breaked_2 = False
                    if j == 5:
                        breaked_2 = True
                    if not breaked_2:
                        x += i * j
                    keepgoing = not breaked_2
                    if breaked_2:
                        break
            keepgoing = not breaked_1
            if breaked_1:
                break
        """)
        source_ast, expected_ast = (
            gast.ast_to_gast(ast.parse(text)) for text in (source, expected))
        canonical_ast = self.canonicalizer.visit(source_ast)

        assert_semantically_equals(source, expected, ['x', 'i', 'j'])
        # compare_ast is checked in both argument orders, in case it is not symmetric
        assert compare_ast(canonical_ast, expected_ast)
        assert compare_ast(expected_ast, canonical_ast)
예제 #7
0
def generate_id2type_from_forward(model, args, is_debug=False):
    """Type-infer ``model.forward`` and key the resulting types by node id.

    Args:
        model: object whose ``forward`` method is analysed; it is passed as
            the first (``self``) argument of the inference.
        args: tuple with the remaining arguments of ``forward``.
        is_debug: forwarded to ``generate_node2type``.

    Returns:
        The result of ``generate_id2type`` for the analysed tree.
    """
    forward = model.forward
    tree = gast.ast_to_gast(ast.parse(utils.clip_head(inspect.getsource(forward))))
    module = sys.modules[forward.__module__]
    node2type, subroutine_node = generate_node2type(
        tree,
        (model,) + args,
        is_debug=is_debug,
        module=module,
        type_hints=typing.get_type_hints(forward),
    )
    return generate_id2type(node2type, generate_node2id(tree, subroutine_node))
예제 #8
0
    def test_return(self):
        """A ``return`` inside a loop sets ``returned_1``, leaves the loop via
        a conditional ``break``, and the trailing ``return`` is guarded."""
        source = utils.clip_head("""
        def func(a, b):
            for i in range(a):
                if i == b:
                    return i
            return 0

        value = 0
        for a in range(10):
            for b in range(10):
                value += func(a, b)
        """)
        expected = utils.clip_head("""
        def func(a, b):
            returned_value = None
            returned_1 = False
            for i in range(a):
                if i == b:
                    returned_1 = True
                    returned_value = i
                keepgoing = not returned_1
                if returned_1:
                    break
            if not returned_1:
                returned_1 = True
                returned_value = 0
            return returned_value

        value = 0
        for a in range(10):
            for b in range(10):
                value += func(a, b)
        """)
        source_ast, expected_ast = (
            gast.ast_to_gast(ast.parse(text)) for text in (source, expected))
        canonical_ast = self.canonicalizer.visit(source_ast)

        assert_semantically_equals(source, expected, ['value'])
        # compare_ast is checked in both argument orders, in case it is not symmetric
        assert compare_ast(canonical_ast, expected_ast)
        assert compare_ast(expected_ast, canonical_ast)
예제 #9
0
def generate_type_inference_results(model, forward_args, is_debug=True):
    """Type-infer ``model.forward`` and build id-keyed lookup tables.

    The parsed tree is used as-is (it is not canonicalised first).

    Args:
        model: object whose ``forward`` method is analysed; it is passed as
            the first (``self``) argument of the inference.
        forward_args: tuple with the remaining arguments of ``forward``.
        is_debug: forwarded to ``generate_node2type``.

    Returns:
        ``(id2type, id2node)`` as produced by ``generate_id2type`` and
        ``generate_id2node``.
    """
    forward = model.forward
    source = utils.clip_head(inspect.getsource(forward))
    tree = gast.ast_to_gast(ast.parse(source))
    module = sys.modules[forward.__module__]
    node2type, subroutine_node = generate_node2type(
        tree, (model,) + forward_args,
        is_debug=is_debug, module=module,
        type_hints=typing.get_type_hints(forward))
    node2id = generate_node2id(tree, subroutine_node)
    return generate_id2type(node2type, node2id), generate_id2node(node2id)
예제 #10
0
    def __init__(self, func):
        """Parse and canonicalise the source of the function ``func``.

        Records ``func`` with its name, file and first source line, lets
        ``self.args`` analyse its signature, and stores the canonicalised
        gast ``FunctionDef`` on ``self.ast``.
        """
        super().__init__()

        self.inst = func
        self.name = func.__name__
        self.filename = inspect.getfile(func)
        source_lines, self.lineno = inspect.getsourcelines(func)

        # ''.join(lines) is exactly what inspect.getsource would return.
        code = utils.clip_head(''.join(source_lines))
        self.args.analyze_args(func)

        func_def = gast.ast_to_gast(ast.parse(code)).body[0]
        self.ast = canonicalizer.Canonicalizer().visit(func_def)
예제 #11
0
    def __init__(self, func):
        """Parse ``func`` into a gast tree stored on ``self.ast``.

        Ordinary functions are parsed from their source and canonicalised;
        ``self.ast`` is then the ``FunctionDef``. Lambdas are special-cased:
        the text after the ``lambda ...:`` header is wrapped as
        ``return <expr>`` and parsed without canonicalisation, so ``self.ast``
        is the whole parsed module in that case.
        """
        super().__init__()

        self.inst = func
        self.name = func.__name__
        self.filename = inspect.getfile(func)
        source_lines, self.lineno = inspect.getsourcelines(func)
        self.args.analyze_args(func)

        is_lambda = func.__name__ == (lambda: None).__name__
        if not is_lambda:
            # ''.join(lines) is exactly what inspect.getsource would return.
            code = utils.clip_head(''.join(source_lines))
            self.ast = canonicalizer.Canonicalizer().visit(
                gast.ast_to_gast(ast.parse(code)).body[0])
            return

        lambda_text = utils.lambda_source(func)
        # Non-greedy match ends at the first ':' after 'lambda'; everything
        # after it is taken as the body expression.
        # NOTE(review): a ':' inside a default value (e.g. a slice) would end
        # the match early, and no match at all raises AttributeError below.
        header = re.search('lambda.*?:', lambda_text)
        code = 'return ' + lambda_text[header.end():]
        self.ast = gast.ast_to_gast(ast.parse(code))
예제 #12
0
    def dump_ast(mod, name):
        """Render the AST ``mod`` to ``<name>.png`` using astmonkey.

        When astmonkey is unavailable (``IMPORT_ASTMONKEY`` is false), only an
        installation hint is printed.

        Args:
            mod: the AST to visualise; it is deep-copied and never modified.
            name: output file name without the ``.png`` extension.
        """
        if not IMPORT_ASTMONKEY:
            print("\033[93mInstall astmonkey for visualization.\033[0m")
            return

        # One deep copy is enough to keep whatever the transformer does to
        # the tree away from the caller's AST.
        annotated = transformers.ParentChildNodeTransformer().visit(deepcopy(mod))
        visitor = visitors.GraphNodeVisitor()
        visitor.visit(annotated)
        visitor.graph.write_png(name + '.png')
        print(
            "\033[1;32;40mAST visualization saved as \033[94m%s.png\033[0m"
            % name)

    # Two invocation forms: "<module> <function>" analyses one function
    # imported from a module; a single path argument analyses a whole file.
    if len(sys.argv) == 3:
        module = importlib.import_module(sys.argv[1])
        func = getattr(module, sys.argv[2])
        code = clip_head(inspect.getsource(func))
    else:
        import tokenize
        module = None
        # tokenize.open decodes using the file's PEP 263 coding cookie / BOM
        # (UTF-8 by default) instead of the locale encoding, and the
        # with-block closes the handle instead of leaking it.
        with tokenize.open(sys.argv[1]) as source_file:
            code = source_file.read()
    orig_ast = gast.ast_to_gast(ast.parse(code))
    dump_ast(orig_ast, 'original')

    tc = InferenceEngine(is_debug=True, module=module)
    try:
        nodetype = tc.infer(orig_ast)
    except UnifyError as e:
        # Print the full traceback of the unification failure.
        print(traceback.format_exc(), end="")
예제 #13
0
    def test_return_continue(self):
        """``return`` inside two nested loops that also use ``continue``:
        ``continued_N`` and ``returned_1`` flags are combined, both loops are
        left via conditional ``break``, and the trailing return is guarded."""
        source = utils.clip_head("""
        def func(a, b, c):
            x = 0
            for i in range(a):
                if i == b:
                    continue
                for j in range(a):
                    if j == c:
                        continue
                    if j == b:
                        return x
                    x += i * j
            return x

        value = 0
        for a in range(10):
            for b in range(10):
                for c in range(10):
                    value += func(a, b, c)
        """)
        expected = utils.clip_head("""
        def func(a, b, c):
            returned_value = None
            returned_1 = False
            x = 0
            for i in range(a):
                continued_1 = False
                if i == b:
                    continued_1 = True
                if not continued_1:
                    for j in range(a):
                        continued_2 = False
                        if j == c:
                            continued_2 = True
                        if not continued_2:
                            if j == b:
                                if not continued_2:
                                    returned_1 = True
                                    returned_value = x
                        if not continued_2 and not returned_1:
                            x += i * j
                        keepgoing = not returned_1
                        if returned_1:
                            break
                keepgoing = not returned_1
                if returned_1:
                    break
            if not returned_1:
                returned_1 = True
                returned_value = x
            return returned_value

        value = 0
        for a in range(10):
            for b in range(10):
                for c in range(10):
                    value += func(a, b, c)
        """)
        source_ast, expected_ast = (
            gast.ast_to_gast(ast.parse(text)) for text in (source, expected))
        canonical_ast = self.canonicalizer.visit(source_ast)

        assert_semantically_equals(source, expected, ['value'])
        # compare_ast is checked in both argument orders, in case it is not symmetric
        assert compare_ast(canonical_ast, expected_ast)
        assert compare_ast(expected_ast, canonical_ast)