def test_continue_nested(self):
    """Check canonicalization of ``continue`` inside two nested loops.

    The expected program replaces every ``continue`` with a per-loop
    ``continued_N`` flag and guards the remainder of the loop body with
    ``if not continued_N``.
    """
    source = utils.clip_head("""
    x = 0
    for i in range(10):
        if i == 5:
            continue
        for j in range(10):
            if j == 5:
                continue
            x += i * j
    """)
    expected = utils.clip_head("""
    x = 0
    for i in range(10):
        continued_0 = False
        if i == 5:
            continued_0 = True
        if not continued_0:
            for j in range(10):
                continued_1 = False
                if j == 5:
                    continued_1 = True
                if not continued_1:
                    x += i * j
    """)

    source_tree = gast.ast_to_gast(ast.parse(source))
    expected_tree = gast.ast_to_gast(ast.parse(expected))
    actual_tree = self.canonicalizer.visit(source_tree)

    # The hand-written expectation itself has to agree with the input
    # program on the listed variables (presumably by running both).
    assert_semantically_equals(source, expected, ['x', 'i', 'j'])

    # The comparison is checked in both argument orders.
    assert compare_ast(actual_tree, expected_tree)
    assert compare_ast(expected_tree, actual_tree)
def test_return_simple(self):
    """Check canonicalization of a function with one trailing ``return``.

    The expected program stores the result in ``returned_value``, records
    the return in the ``returned_1`` flag and ends with a single
    ``return returned_value``.
    """
    source = utils.clip_head("""
    def func(a, b):
        return a + b
    value = func(3, 39)
    """)
    expected = utils.clip_head("""
    def func(a, b):
        returned_value = None
        returned_1 = False
        returned_1 = True
        returned_value = a + b
        return returned_value
    value = func(3, 39)
    """)

    source_tree = gast.ast_to_gast(ast.parse(source))
    expected_tree = gast.ast_to_gast(ast.parse(expected))
    actual_tree = self.canonicalizer.visit(source_tree)

    # The hand-written expectation itself has to agree with the input
    # program on ``value`` (presumably by running both).
    assert_semantically_equals(source, expected, ['value'])

    # The comparison is checked in both argument orders.
    assert compare_ast(actual_tree, expected_tree)
    assert compare_ast(expected_tree, actual_tree)
def test_break(self):
    """Check canonicalization of a single ``break`` inside a ``for`` loop.

    The expected program records the break in ``breaked_0``, guards the
    rest of the body with ``if not breaked_0`` and performs the actual
    ``break`` as the last statement of the loop body, right after the
    ``keepgoing`` assignment.
    """
    source = utils.clip_head("""
    x = 0
    for i in range(10):
        if i == 5:
            break
        x += i
    """)
    expected = utils.clip_head("""
    x = 0
    for i in range(10):
        breaked_0 = False
        if i == 5:
            breaked_0 = True
        if not breaked_0:
            x += i
        keepgoing = not breaked_0
        if breaked_0:
            break
    """)

    source_tree = gast.ast_to_gast(ast.parse(source))
    expected_tree = gast.ast_to_gast(ast.parse(expected))
    actual_tree = self.canonicalizer.visit(source_tree)

    # The hand-written expectation itself has to agree with the input
    # program on the listed variables (presumably by running both).
    assert_semantically_equals(source, expected, ['x', 'i'])

    # The comparison is checked in both argument orders.
    assert compare_ast(actual_tree, expected_tree)
    assert compare_ast(expected_tree, actual_tree)
def __init__(self, classinfo):
    """Parse and canonicalize the ``__init__`` method of ``classinfo``.

    On success the instance carries the method object (``inst``), its
    name, defining file, first source line, the owning class
    (``classinfo``) and the canonicalized gast node of the method
    (``ast``).

    If the source of ``__init__`` cannot be located (for example when the
    class only inherits a built-in ``__init__``), a warning is printed and
    construction stops early; ``lineno``, ``classinfo`` and ``ast`` are
    then not assigned by this method.

    Args:
        classinfo: Class whose ``__init__`` is analysed.
    """
    super().__init__()

    # Same object inspect.getmembers() reports under '__init__', obtained
    # without evaluating every other attribute of the class.
    func = classinfo.__init__

    self.inst = func
    self.name = func.__name__

    # inspect signals "no source available" by raising, never by returning
    # None: TypeError for built-in objects, OSError when the source file
    # cannot be read.
    try:
        self.filename = inspect.getfile(func)
        source_lines, first_lineno = inspect.getsourcelines(func)
    except (OSError, TypeError):
        utils.print_warning('Failed to parse {}'.format(classinfo),
                            utils.LineProperty())
        return

    self.lineno = first_lineno
    self.classinfo = classinfo

    # inspect.getsource() is exactly the joined result of getsourcelines(),
    # so the lines fetched above are reused instead of read again.
    code = utils.clip_head(''.join(source_lines))

    self.args.analyze_args(func)

    # The parsed module holds one top-level statement: the method's def.
    func_def = gast.ast_to_gast(ast.parse(code)).body[0]
    self.ast = canonicalizer.Canonicalizer().visit(func_def)
def infer_user_defined_function(self, func, ty_args, node):
    """Run type inference on the source of a user-defined callable.

    A separate ``InferenceEngine`` is created for the callee; its
    ``nodetype`` and ``subroutine_node`` tables are merged into this
    engine's tables afterwards.

    Args:
        func: The called object: a plain function, a bound method, or an
            object invoked through ``forward`` / ``__call__``.
        ty_args: List of argument types at the call site.
        node: The call node in the caller's tree.

    Returns:
        A pair ``(ty_args, ty)`` where ``ty_args`` may have the receiver's
        type prepended and ``ty`` is the entry the callee's engine holds
        for the callee's ``FunctionDef`` node.
    """
    if not isinstance(func, (types.FunctionType, types.MethodType)):
        # Callable object: analyse ``forward`` for chainer Chains and
        # ``__call__`` for anything else.
        func_body = func.forward if isinstance(func, chainer.Chain) \
            else func.__call__
        ty_args = [type_of_value(func)] + ty_args
    else:
        func_body = func
        if isinstance(node.func, gast.Attribute):
            # Attribute call: the receiver's type becomes the first argument.
            ty_args = [self.nodetype[node.func.value]] + ty_args

    source = clip_head(inspect.getsource(func_body))

    # FunctionDef node of the called subroutine
    callee_def = gast.ast_to_gast(ast.parse(source)).body[0]
    self.subroutine_node[node] = callee_def

    callee_module = sys.modules[func.__module__]
    engine = InferenceEngine(is_debug=self.is_debug, module=callee_module)
    engine.infer_function(
        callee_def, ty_args,
        type_hints=typing.get_type_hints(func_body))

    # Merge what the callee's engine collected into this engine.
    utils.add_dict(self.nodetype, engine.nodetype)
    utils.add_dict(self.subroutine_node, engine.subroutine_node)

    return ty_args, engine.nodetype[callee_def]
def test_continue_break_nested(self):
    """Check canonicalization when ``continue`` and ``break`` are mixed.

    The outer loop contains both a ``continue`` and a ``break``; the inner
    loop contains a ``break``. The expected program uses one flag per
    statement (``continued_0``, ``breaked_1``, ``breaked_2``) and combines
    the outer flags in a single guard around the inner loop.
    """
    source = utils.clip_head("""
    x = 0
    for i in range(10):
        if i == 5:
            continue
        if i == 6:
            break
        for j in range(10):
            if j == 5:
                break
            x += i * j
    """)
    expected = utils.clip_head("""
    x = 0
    for i in range(10):
        breaked_1 = False
        continued_0 = False
        if i == 5:
            continued_0 = True
        if not continued_0:
            if i == 6:
                breaked_1 = True
        if not continued_0 and not breaked_1:
            for j in range(10):
                breaked_2 = False
                if j == 5:
                    breaked_2 = True
                if not breaked_2:
                    x += i * j
                keepgoing = not breaked_2
                if breaked_2:
                    break
        keepgoing = not breaked_1
        if breaked_1:
            break
    """)

    source_tree = gast.ast_to_gast(ast.parse(source))
    expected_tree = gast.ast_to_gast(ast.parse(expected))
    actual_tree = self.canonicalizer.visit(source_tree)

    # The hand-written expectation itself has to agree with the input
    # program on the listed variables (presumably by running both).
    assert_semantically_equals(source, expected, ['x', 'i', 'j'])

    # The comparison is checked in both argument orders.
    assert compare_ast(actual_tree, expected_tree)
    assert compare_ast(expected_tree, actual_tree)
def generate_id2type_from_forward(model, args, is_debug=False):
    """Build the id-to-type table for ``model.forward``.

    The source of ``model.forward`` is parsed into a gast tree, types are
    inferred with ``generate_node2type`` (the model itself is passed as the
    first argument, followed by ``args``), and the per-node result is
    re-keyed through ``generate_node2id``.

    Args:
        model: Object exposing a ``forward`` method.
        args: Tuple of arguments for ``forward`` (without the model).
        is_debug: Forwarded to ``generate_node2type``.

    Returns:
        The mapping produced by ``generate_id2type``.
    """
    forward = model.forward

    source = utils.clip_head(inspect.getsource(forward))
    tree = gast.ast_to_gast(ast.parse(source))
    defining_module = sys.modules[forward.__module__]

    node2type, subroutine_node = generate_node2type(
        tree,
        (model,) + args,
        is_debug=is_debug,
        module=defining_module,
        type_hints=typing.get_type_hints(forward),
    )

    node2id = generate_node2id(tree, subroutine_node)
    return generate_id2type(node2type, node2id)
def test_return(self):
    """Check canonicalization of a ``return`` inside a loop.

    The expected program turns the early ``return`` into assignments to
    ``returned_1`` / ``returned_value`` followed by a ``break`` at the end
    of the loop body, guards the trailing ``return 0`` with
    ``if not returned_1`` and ends with a single ``return returned_value``.
    """
    source = utils.clip_head("""
    def func(a, b):
        for i in range(a):
            if i == b:
                return i
        return 0
    value = 0
    for a in range(10):
        for b in range(10):
            value += func(a, b)
    """)
    expected = utils.clip_head("""
    def func(a, b):
        returned_value = None
        returned_1 = False
        for i in range(a):
            if i == b:
                returned_1 = True
                returned_value = i
            keepgoing = not returned_1
            if returned_1:
                break
        if not returned_1:
            returned_1 = True
            returned_value = 0
        return returned_value
    value = 0
    for a in range(10):
        for b in range(10):
            value += func(a, b)
    """)

    source_tree = gast.ast_to_gast(ast.parse(source))
    expected_tree = gast.ast_to_gast(ast.parse(expected))
    actual_tree = self.canonicalizer.visit(source_tree)

    # The hand-written expectation itself has to agree with the input
    # program on ``value`` (presumably by running both).
    assert_semantically_equals(source, expected, ['value'])

    # The comparison is checked in both argument orders.
    assert compare_ast(actual_tree, expected_tree)
    assert compare_ast(expected_tree, actual_tree)
def generate_type_inference_results(model, forward_args, is_debug=True):
    """Infer types for ``model.forward`` and return two id-keyed tables.

    The source of ``model.forward`` is parsed into a gast tree, types are
    inferred with ``generate_node2type`` (the model itself is passed as the
    first argument, followed by ``forward_args``), and the results are
    re-keyed through ``generate_node2id``.

    Args:
        model: Object exposing a ``forward`` method.
        forward_args: Tuple of arguments for ``forward`` (without the model).
        is_debug: Forwarded to ``generate_node2type``.

    Returns:
        A pair ``(id2type, id2node)`` as produced by ``generate_id2type``
        and ``generate_id2node``.
    """
    forward = model.forward

    source = utils.clip_head(inspect.getsource(forward))
    # NOTE: the tree is used exactly as parsed; a Canonicalizer pass that
    # once followed this line is disabled.
    tree = gast.ast_to_gast(ast.parse(source))
    defining_module = sys.modules[forward.__module__]

    node2type, subroutine_node = generate_node2type(
        tree,
        (model,) + forward_args,
        is_debug=is_debug,
        module=defining_module,
        type_hints=typing.get_type_hints(forward),
    )

    node2id = generate_node2id(tree, subroutine_node)
    return generate_id2type(node2type, node2id), generate_id2node(node2id)
def __init__(self, func):
    """Parse and canonicalize the source of ``func``.

    Stores the function object (``inst``), its name, defining file, first
    source line and the canonicalized gast node of its definition
    (``ast``).

    Args:
        func: Function whose source is available to ``inspect``.
    """
    super().__init__()

    self.inst = func
    self.name = func.__name__
    self.filename = inspect.getfile(func)

    # getsourcelines() yields (lines, first line number); joining the lines
    # gives the same text inspect.getsource() returns.
    source_lines, first_lineno = inspect.getsourcelines(func)
    self.lineno = first_lineno
    code = utils.clip_head(''.join(source_lines))

    self.args.analyze_args(func)

    # The parsed module holds one top-level statement: the function's def.
    module_tree = gast.ast_to_gast(ast.parse(code))
    self.ast = canonicalizer.Canonicalizer().visit(module_tree.body[0])
def __init__(self, func):
    """Parse the source of ``func``, which may be a lambda.

    Stores the function object (``inst``), its name, defining file and
    first source line. For a regular function ``ast`` is the canonicalized
    gast node of its definition. For a lambda, ``ast`` is the gast module
    obtained from ``return <lambda body>`` and is not canonicalized.

    Args:
        func: Function or lambda whose source is available to ``inspect``.
    """
    super().__init__()

    self.inst = func
    self.name = func.__name__
    self.filename = inspect.getfile(func)
    self.lineno = inspect.getsourcelines(func)[1]

    self.args.analyze_args(func)

    # Lambdas are recognised by sharing the name every lambda carries.
    is_lambda = func.__name__ == (lambda: None).__name__

    if not is_lambda:
        code = utils.clip_head(inspect.getsource(func))
        func_def = gast.ast_to_gast(ast.parse(code)).body[0]
        self.ast = canonicalizer.Canonicalizer().visit(func_def)
        return

    lambda_text = utils.lambda_source(func)
    # Drop everything up to and including the first ':' after 'lambda',
    # then turn the remaining expression into a return statement.
    body_start = re.search('lambda.*?:', lambda_text).end()
    code = 'return ' + lambda_text[body_start:]
    self.ast = gast.ast_to_gast(ast.parse(code))
def dump_ast(mod, name):
    """Write a PNG rendering of the tree ``mod`` to ``<name>.png``.

    Rendering needs astmonkey; when ``IMPORT_ASTMONKEY`` is false only a
    hint is printed and no file is written.

    Args:
        mod: Root node of the tree to draw.
        name: Output file name without the ``.png`` extension.
    """
    if IMPORT_ASTMONKEY:
        # The transformer is run on a copy so the caller's tree is left
        # as it was; a single deep copy is sufficient for that.
        annotated = transformers.ParentChildNodeTransformer().visit(
            deepcopy(mod))
        visitor = visitors.GraphNodeVisitor()
        visitor.visit(annotated)
        visitor.graph.write_png(name + '.png')
        print(
            "\033[1;32;40mAST visualization saved as \033[94m%s.png\033[0m"
            % name)
    else:
        print("\033[93mInstall astmonkey for visualization.\033[0m")


# Command line driver. Two accepted forms:
#   <module> <function>  -- analyse the source of one function of a module
#   <file>               -- analyse a whole source file
if len(sys.argv) == 3:
    module = importlib.import_module(sys.argv[1])
    func = getattr(module, sys.argv[2])
    code = clip_head(inspect.getsource(func))
else:
    module = None
    # Close the file deterministically instead of leaking the handle.
    with open(sys.argv[1]) as source_file:
        code = source_file.read()

orig_ast = gast.ast_to_gast(ast.parse(code))
dump_ast(orig_ast, 'original')

tc = InferenceEngine(is_debug=True, module=module)
try:
    nodetype = tc.infer(orig_ast)
except UnifyError:
    # Report the failure without aborting the script.
    print(traceback.format_exc(), end="")
def test_return_continue(self):
    """Check canonicalization when ``return`` and ``continue`` are mixed.

    The function under test has a ``continue`` in each of two nested loops
    and a ``return`` inside the inner loop. The expected program uses
    ``continued_1`` / ``continued_2`` for the ``continue`` statements,
    ``returned_1`` / ``returned_value`` for the returns, breaks out of both
    loops once ``returned_1`` is set, and ends with a single
    ``return returned_value``.
    """
    source = utils.clip_head("""
    def func(a, b, c):
        x = 0
        for i in range(a):
            if i == b:
                continue
            for j in range(a):
                if j == c:
                    continue
                if j == b:
                    return x
                x += i * j
        return x
    value = 0
    for a in range(10):
        for b in range(10):
            for c in range(10):
                value += func(a, b, c)
    """)
    expected = utils.clip_head("""
    def func(a, b, c):
        returned_value = None
        returned_1 = False
        x = 0
        for i in range(a):
            continued_1 = False
            if i == b:
                continued_1 = True
            if not continued_1:
                for j in range(a):
                    continued_2 = False
                    if j == c:
                        continued_2 = True
                    if not continued_2:
                        if j == b:
                            if not continued_2:
                                returned_1 = True
                                returned_value = x
                    if not continued_2 and not returned_1:
                        x += i * j
                    keepgoing = not returned_1
                    if returned_1:
                        break
            keepgoing = not returned_1
            if returned_1:
                break
        if not returned_1:
            returned_1 = True
            returned_value = x
        return returned_value
    value = 0
    for a in range(10):
        for b in range(10):
            for c in range(10):
                value += func(a, b, c)
    """)

    source_tree = gast.ast_to_gast(ast.parse(source))
    expected_tree = gast.ast_to_gast(ast.parse(expected))
    actual_tree = self.canonicalizer.visit(source_tree)

    # The hand-written expectation itself has to agree with the input
    # program on ``value`` (presumably by running both).
    assert_semantically_equals(source, expected, ['value'])

    # The comparison is checked in both argument orders.
    assert compare_ast(actual_tree, expected_tree)
    assert compare_ast(expected_tree, actual_tree)