def test_return_simple(self):
    """A lone ``return`` is rewritten into the returned_value/returned_1 flag form."""
    source = utils.clip_head("""def func(a, b):
    return a + b

value = func(3, 39)
""")
    expected = utils.clip_head("""def func(a, b):
    returned_value = None
    returned_1 = False
    returned_1 = True
    returned_value = a + b
    return returned_value

value = func(3, 39)
""")
    source_ast = gast.ast_to_gast(ast.parse(source))
    expected_ast = gast.ast_to_gast(ast.parse(expected))
    converted_ast = self.canonicalizer.visit(source_ast)
    assert_semantically_equals(source, expected, ['value'])
    # compare_ast is checked in both directions (it may not be symmetric).
    for lhs, rhs in ((converted_ast, expected_ast),
                     (expected_ast, converted_ast)):
        assert compare_ast(lhs, rhs)
def test_break(self):
    """A ``break`` is rewritten into a breaked_0 flag plus a trailing guarded break."""
    source = utils.clip_head("""x = 0
for i in range(10):
    if i == 5:
        break
    x += i
""")
    expected = utils.clip_head("""x = 0
for i in range(10):
    breaked_0 = False
    if i == 5:
        breaked_0 = True
    if not breaked_0:
        x += i
    keepgoing = not breaked_0
    if breaked_0:
        break
""")
    source_ast = gast.ast_to_gast(ast.parse(source))
    expected_ast = gast.ast_to_gast(ast.parse(expected))
    converted_ast = self.canonicalizer.visit(source_ast)
    assert_semantically_equals(source, expected, ['x', 'i'])
    # compare_ast is checked in both directions (it may not be symmetric).
    for lhs, rhs in ((converted_ast, expected_ast),
                     (expected_ast, converted_ast)):
        assert compare_ast(lhs, rhs)
def test_continue_nested(self):
    """``continue`` in nested loops becomes one continued_N flag per loop level."""
    source = utils.clip_head("""x = 0
for i in range(10):
    if i == 5:
        continue
    for j in range(10):
        if j == 5:
            continue
        x += i * j
""")
    expected = utils.clip_head("""x = 0
for i in range(10):
    continued_0 = False
    if i == 5:
        continued_0 = True
    if not continued_0:
        for j in range(10):
            continued_1 = False
            if j == 5:
                continued_1 = True
            if not continued_1:
                x += i * j
""")
    source_ast = gast.ast_to_gast(ast.parse(source))
    expected_ast = gast.ast_to_gast(ast.parse(expected))
    converted_ast = self.canonicalizer.visit(source_ast)
    assert_semantically_equals(source, expected, ['x', 'i', 'j'])
    # compare_ast is checked in both directions (it may not be symmetric).
    for lhs, rhs in ((converted_ast, expected_ast),
                     (expected_ast, converted_ast)):
        assert compare_ast(lhs, rhs)
def __init__(self, classinfo):
    """Collect location info and the canonicalized AST of ``classinfo.__init__``.

    Args:
        classinfo: The class whose ``__init__`` is analyzed.

    If the source of ``__init__`` cannot be retrieved, a warning is printed
    and construction stops early, leaving ``lineno``, ``classinfo`` and
    ``ast`` unset.
    """
    super().__init__()
    init_funcs = [member for name, member in inspect.getmembers(classinfo)
                  if name == '__init__']
    assert len(init_funcs) == 1
    func = init_funcs[0]
    self.inst = func
    self.name = func.__name__
    self.filename = inspect.getfile(func)
    # inspect.getsourcelines() never returns None or an empty result: it
    # returns a (lines, lineno) pair, or raises OSError when the source
    # cannot be retrieved.  That exception is the failure to handle here.
    try:
        source_lines, lineno = inspect.getsourcelines(func)
    except OSError:
        utils.print_warning('Failed to parse {}'.format(classinfo),
                            utils.LineProperty())
        return
    self.lineno = lineno
    self.classinfo = classinfo
    # ''.join(lines) is exactly what inspect.getsource() returns; reuse the
    # lines already fetched instead of reading the file a second time.
    code = utils.clip_head(''.join(source_lines))
    self.args.analyze_args(func)
    ast_ = gast.ast_to_gast(ast.parse(code)).body[0]
    self.ast = canonicalizer.Canonicalizer().visit(ast_)
def infer_user_defined_function(self, func, ty_args, node):
    """Run type inference on the Python source of a user-defined callee.

    Args:
        func: The called object: a plain function/method, or a callable
            object (``chainer.Chain`` -> ``forward``, otherwise ``__call__``).
        ty_args: Types of the explicit call arguments.
        node: The gast call node being inferred.

    Returns:
        ``(ty_args, ty)``. ``ty_args`` is the argument type list actually
        used, with the receiver's type prepended for attribute-style calls
        and callable objects. ``ty`` is the type recorded for the callee's
        FunctionDef by the nested InferenceEngine.
    """
    if not isinstance(func, (types.FunctionType, types.MethodType)):
        # A callable object: its code lives in forward (Chain) or __call__,
        # and the object itself is passed as the implicit first argument.
        if isinstance(func, chainer.Chain):
            func_body = func.forward
        else:
            func_body = func.__call__
        ty_args = [type_of_value(func)] + ty_args
    else:
        func_body = func
        if isinstance(node.func, gast.Attribute):
            # obj.method(...): the receiver's inferred type comes first.
            ty_args = [self.nodetype[node.func.value]] + ty_args

    source = clip_head(inspect.getsource(func_body))
    callee_def = gast.ast_to_gast(ast.parse(source)).body[0]
    self.subroutine_node[node] = callee_def

    sub_engine = InferenceEngine(is_debug=self.is_debug,
                                 module=sys.modules[func.__module__])
    sub_engine.infer_function(callee_def, ty_args,
                              type_hints=typing.get_type_hints(func_body))

    # Merge the nested engine's results into this engine.
    utils.add_dict(self.nodetype, sub_engine.nodetype)
    utils.add_dict(self.subroutine_node, sub_engine.subroutine_node)
    return ty_args, sub_engine.nodetype[callee_def]
def __init__(self, ch, env):
    """Parse ``ch.forward`` and temporarily replace the link's own parameters.

    Every parameter from ``ch.namedparams()`` whose stripped name has no
    further '/' gets a FLOAT tensor value-info built with
    ``helper.make_tensor_value_info``. That value-info is appended to
    ``self.inits`` and stored on ``ch`` in place of the parameter.
    ``ch.children`` is also swapped for a ``Func`` wrapper. For each
    replacement, a zero-argument callback that puts the original back is
    appended to ``env.restore_funcs``.
    """
    src = clip_head(inspect.getsource(ch.forward))
    dprint(src)
    self.ast = gast.ast_to_gast(ast.parse(src)).body[0]
    self.call = User_Defined_Func_In_Link(ch, ch.forward).call

    # The following is for the initial call coming from outside.
    # Number of forward() arguments, not counting ``self``.
    self.forward_arglen = len(self.ast.args.args) - 1

    # The initialized parameters must be overwritten here. They have to be
    # restored afterwards so the link can still be run by Chainer, so the
    # things to restore are registered in ``restore_funcs``.
    self.inits = []
    for s, v in ch.namedparams():
        s = s[1:]  # drop the leading '/'
        if '/' in s:
            # Presumably a parameter of a nested child link; skipped here.
            continue
        t = helper.make_tensor_value_info('/' + s, TensorProto.FLOAT,
                                          list(v.shape))
        self.inits.append(t)
        mv = getattr(ch, s)
        setattr(ch, s, t)
        # Bind s and mv NOW via default arguments. A plain closure would see
        # only the last loop iteration's values (and mv is rebound below),
        # so every callback would restore the wrong attribute/value.
        env.restore_funcs.append(lambda s=s, mv=mv: setattr(ch, s, mv))

    # TODO(satos): this can be removed once Yield can be compiled.
    mv = getattr(ch, 'children')
    setattr(ch, 'children', Func(lambda _, __, ___: mv()))
    env.restore_funcs.append(lambda mv=mv: setattr(ch, 'children', mv))
def test_continue_break_nested(self):
    """Mixed ``continue``/``break`` in nested loops become per-level flag variables."""
    source = utils.clip_head("""x = 0
for i in range(10):
    if i == 5:
        continue
    if i == 6:
        break
    for j in range(10):
        if j == 5:
            break
        x += i * j
""")
    expected = utils.clip_head("""x = 0
for i in range(10):
    breaked_1 = False
    continued_0 = False
    if i == 5:
        continued_0 = True
    if not continued_0:
        if i == 6:
            breaked_1 = True
    if not continued_0 and not breaked_1:
        for j in range(10):
            breaked_2 = False
            if j == 5:
                breaked_2 = True
            if not breaked_2:
                x += i * j
            keepgoing = not breaked_2
            if breaked_2:
                break
    keepgoing = not breaked_1
    if breaked_1:
        break
""")
    source_ast = gast.ast_to_gast(ast.parse(source))
    expected_ast = gast.ast_to_gast(ast.parse(expected))
    converted_ast = self.canonicalizer.visit(source_ast)
    assert_semantically_equals(source, expected, ['x', 'i', 'j'])
    # compare_ast is checked in both directions (it may not be symmetric).
    for lhs, rhs in ((converted_ast, expected_ast),
                     (expected_ast, converted_ast)):
        assert compare_ast(lhs, rhs)
def generate_id2type_from_forward(model, args, is_debug=False):
    """Type-infer ``model.forward`` and return the result keyed by node id.

    Args:
        model: Object whose ``forward`` method is analyzed; it is passed as
            the first argument (``self``).
        args: Tuple of remaining forward arguments.
        is_debug: Forwarded to the inference engine.

    Returns:
        The mapping built by ``generate_id2type``.
    """
    forward = model.forward
    source = utils.clip_head(inspect.getsource(forward))
    tree = gast.ast_to_gast(ast.parse(source))
    defining_module = sys.modules[forward.__module__]
    node2type, subroutine_node = generate_node2type(
        tree,
        (model,) + args,
        is_debug=is_debug,
        module=defining_module,
        type_hints=typing.get_type_hints(forward),
    )
    return generate_id2type(node2type, generate_node2id(tree, subroutine_node))
def __init__(self, func):
    """Record metadata for ``func`` and build its AST in ``self.ast``.

    For a regular function, ``self.ast`` is the canonicalized FunctionDef.
    For a lambda, it is a (non-canonicalized) gast Module holding
    ``return <lambda body>``.
    """
    super().__init__()
    self.inst = func
    self.name = func.__name__
    self.filename = inspect.getfile(func)
    _, self.lineno = inspect.getsourcelines(func)
    self.args.analyze_args(func)

    lambda_name = (lambda: None).__name__
    if func.__name__ != lambda_name:
        source = utils.clip_head(inspect.getsource(func))
        func_def = gast.ast_to_gast(ast.parse(source)).body[0]
        self.ast = canonicalizer.Canonicalizer().visit(func_def)
        return

    # Lambda: keep only the expression after the first "lambda ...:" header
    # and turn it into a return statement.
    lambda_src = utils.lambda_source(func)
    header = re.search('lambda.*?:', lambda_src)
    body_expr = lambda_src[header.end():]
    self.ast = gast.ast_to_gast(ast.parse('return ' + body_expr))
def test_return(self):
    """A ``return`` inside a loop becomes a returned_1 flag plus a guarded break."""
    source = utils.clip_head("""def func(a, b):
    for i in range(a):
        if i == b:
            return i
    return 0

value = 0
for a in range(10):
    for b in range(10):
        value += func(a, b)
""")
    expected = utils.clip_head("""def func(a, b):
    returned_value = None
    returned_1 = False
    for i in range(a):
        if i == b:
            returned_1 = True
            returned_value = i
        keepgoing = not returned_1
        if returned_1:
            break
    if not returned_1:
        returned_1 = True
        returned_value = 0
    return returned_value

value = 0
for a in range(10):
    for b in range(10):
        value += func(a, b)
""")
    source_ast = gast.ast_to_gast(ast.parse(source))
    expected_ast = gast.ast_to_gast(ast.parse(expected))
    converted_ast = self.canonicalizer.visit(source_ast)
    assert_semantically_equals(source, expected, ['value'])
    # compare_ast is checked in both directions (it may not be symmetric).
    for lhs, rhs in ((converted_ast, expected_ast),
                     (expected_ast, converted_ast)):
        assert compare_ast(lhs, rhs)
def eval_list_comp(nast, env):
    """Evaluate a list comprehension by lowering it to explicit loops.

    ``[elt for t1 in it1 for t2 in it2]`` is rewritten into::

        v = []
        for t1 in it1:
            for t2 in it2:
                v.append(elt)

    where ``v`` is a fresh temporary name. The statements are evaluated in
    ``env``, and the temporary is then popped from ``env`` and returned.
    ``if`` clauses and async comprehensions are not supported yet (asserted).
    """
    # Use a name that cannot collide (comprehensions may sit inside loops).
    vn = "dummy@" + new_tensor().name
    assert len(nast.generators) >= 1
    tast = gast.ast_to_gast(ast.parse("v.append(w)")).body[0]
    tast.value.func.value.id = vn
    tast.value.args[0] = nast.elt
    # Wrap loops from the innermost generator outwards. In Python's
    # semantics, the FIRST generator is the OUTERMOST loop, and later
    # generators may depend on earlier targets (e.g. [y for x in xs for y in x]).
    for gen in reversed(nast.generators):
        # Not implemented yet for now.
        assert len(gen.ifs) == 0 and gen.is_async == 0
        tast = gast.For(target=gen.target, iter=gen.iter, body=[tast], orelse=[])
    init = gast.ast_to_gast(ast.parse("v = []")).body[0]
    init.targets[0].id = vn
    tast = [init, tast]
    rv = eval_ast(tast, env)
    assert rv.is_none()
    res = env.pop_var(vn)
    return res
def __init__(self, func):
    """Record metadata for ``func`` and store its (non-canonicalized) FunctionDef."""
    super().__init__()
    self.inst = func
    self.name = func.__name__
    _, self.lineno = inspect.getsourcelines(func)
    source = utils.clip_head(inspect.getsource(func))
    self.analyze_args(func)
    module_node = gast.ast_to_gast(ast.parse(source))
    self.ast = module_node.body[0]
def generate_type_inference_results(model, forward_args, is_debug=True):
    """Type-infer ``model.forward`` and return ``(id2type, id2node)``.

    Args:
        model: Object whose ``forward`` is analyzed; it is passed as ``self``.
        forward_args: Tuple of remaining forward arguments.
        is_debug: Forwarded to the inference engine.
    """
    forward = model.forward
    source = utils.clip_head(inspect.getsource(forward))
    node = gast.ast_to_gast(ast.parse(source))
    # NOTE: Canonicalizer().visit(node) is not applied here.
    defining_module = sys.modules[forward.__module__]
    node2type, subroutine_node = generate_node2type(
        node,
        (model,) + forward_args,
        is_debug=is_debug,
        module=defining_module,
        type_hints=typing.get_type_hints(forward),
    )
    node2id = generate_node2id(node, subroutine_node)
    return generate_id2type(node2type, node2id), generate_id2node(node2id)
def __call__(self, codeobj):
    """Return the gast node cached for ``codeobj``, parsing its file on a miss.

    The cache key is ``self.get_file_info(codeobj)``, whose first element is
    the file name. On a miss, the whole file is parsed. The module node is
    stored under ``(fname, 0)`` and every FunctionDef under
    ``(fname, lineno)``. The key is then looked up again, which raises
    KeyError if it is still absent (assuming ``self.cache`` is a dict).
    """
    cache = self.cache
    key = self.get_file_info(codeobj)
    hit = cache.get(key)
    if hit is not None:
        return hit

    fname = key[0]
    module_node = gast.ast_to_gast(self.parse_file(fname))
    cache[(fname, 0)] = module_node
    func_defs = (n for n in gast.walk(module_node)
                 if isinstance(n, gast.FunctionDef))
    for func_def in func_defs:
        cache[(fname, func_def.lineno)] = func_def
    return cache[key]
def __init__(self, func):
    """Record metadata for ``func`` and store its canonicalized FunctionDef."""
    super().__init__()
    self.inst = func
    self.name = func.__name__
    self.filename = inspect.getfile(func)
    _, self.lineno = inspect.getsourcelines(func)
    source = utils.clip_head(inspect.getsource(func))
    self.args.analyze_args(func)
    func_def = gast.ast_to_gast(ast.parse(source)).body[0]
    self.ast = canonicalizer.Canonicalizer().visit(func_def)
def __init__(self, classinfo):
    """Record metadata for ``classinfo.__init__`` and store its FunctionDef.

    Args:
        classinfo: The class whose ``__init__`` is analyzed.
    """
    super().__init__()
    init_funcs = [member for name, member in inspect.getmembers(classinfo)
                  if name == '__init__']
    assert(len(init_funcs) == 1)
    func = init_funcs[0]
    self.inst = func
    self.name = func.__name__
    _, self.lineno = inspect.getsourcelines(func)
    self.classinfo = classinfo
    source = utils.clip_head(inspect.getsource(func))
    self.analyze_args(func)
    module_node = gast.ast_to_gast(ast.parse(source))
    self.ast = module_node.body[0]
def test_usub(self):
    """Unary minus on a literal is folded into a single negative ``Num``."""
    source_ast = gast.ast_to_gast(ast.parse("-3"))
    folded = gast.Num(n=-3)
    expected_ast = gast.Module(body=[gast.Expr(value=folded)])
    converted_ast = self.canonicalizer.visit(source_ast)
    assert compare_ast(converted_ast, expected_ast)
def __init__(self, func):
    """Parse ``func``'s source and keep its gast FunctionDef in ``self.ast``."""
    self.func = func
    source = clip_head(inspect.getsource(func))
    dprint(source)  # presumably a debug-only print
    module_node = gast.ast_to_gast(ast.parse(source))
    self.ast = module_node.body[0]
    assert isinstance(self.ast, gast.FunctionDef)
def test_return_continue(self):
    """``return`` nested under ``continue``s combines continued_N and returned_1 flags."""
    source = utils.clip_head("""def func(a, b, c):
    x = 0
    for i in range(a):
        if i == b:
            continue
        for j in range(a):
            if j == c:
                continue
            if j == b:
                return x
            x += i * j
    return x

value = 0
for a in range(10):
    for b in range(10):
        for c in range(10):
            value += func(a, b, c)
""")
    expected = utils.clip_head("""def func(a, b, c):
    returned_value = None
    returned_1 = False
    x = 0
    for i in range(a):
        continued_1 = False
        if i == b:
            continued_1 = True
        if not continued_1:
            for j in range(a):
                continued_2 = False
                if j == c:
                    continued_2 = True
                if not continued_2:
                    if j == b:
                        if not continued_2:
                            returned_1 = True
                            returned_value = x
                if not continued_2 and not returned_1:
                    x += i * j
                keepgoing = not returned_1
                if returned_1:
                    break
        keepgoing = not returned_1
        if returned_1:
            break
    if not returned_1:
        returned_1 = True
        returned_value = x
    return returned_value

value = 0
for a in range(10):
    for b in range(10):
        for c in range(10):
            value += func(a, b, c)
""")
    source_ast = gast.ast_to_gast(ast.parse(source))
    expected_ast = gast.ast_to_gast(ast.parse(expected))
    converted_ast = self.canonicalizer.visit(source_ast)
    assert_semantically_equals(source, expected, ['value'])
    # compare_ast is checked in both directions (it may not be symmetric).
    for lhs, rhs in ((converted_ast, expected_ast),
                     (expected_ast, converted_ast)):
        assert compare_ast(lhs, rhs)
def test_usub(self):
    """Unary minus on a literal is folded into a single negative ``Constant``."""
    source_ast = gast.ast_to_gast(ast.parse("-3"))
    folded = gast.Constant(value=-3, kind=None)
    expected_ast = gast.Module(body=[gast.Expr(value=folded)],
                               type_ignores=[])
    converted_ast = self.canonicalizer.visit(source_ast)
    assert compare_ast(converted_ast, expected_ast)
def parse(file):
    """Read the Python source at path ``file`` and return it as a gast tree.

    Args:
        file: Path of the source file to parse.

    Returns:
        The gast Module produced by ``gast.ast_to_gast(ast.parse(...))``.
    """
    # Use a context manager so the file handle is closed deterministically
    # rather than whenever the garbage collector gets to it.
    with open(file) as source_file:
        code = source_file.read()
    return gast.ast_to_gast(ast.parse(code))
import ast
import sys

import gast

# Print ast.dump() of every top-level statement of the file named by the
# first command-line argument, after conversion to gast.
with open(sys.argv[1]) as source_file:  # closed as soon as it is read
    code = source_file.read()
tree = ast.parse(code)
gtree = gast.ast_to_gast(tree)
for node in gtree.body:
    print(ast.dump(node))
from copy import deepcopy

try:
    from astmonkey import transformers, visitors
    IMPORT_ASTMONKEY = True
except ImportError:
    IMPORT_ASTMONKEY = False


def dump_ast(mod, name):
    """Print ``gast.dump(mod)`` and, if astmonkey is installed, render it.

    Args:
        mod: The (g)ast tree to show. It is not modified.
        name: Base name of the PNG written as ``<name>.png``.
    """
    print(gast.dump(mod))
    if IMPORT_ASTMONKEY:
        # Visualize a deep copy so nothing the transformer does to the nodes
        # can affect the caller's tree. A single copy is sufficient.
        mod = transformers.ParentChildNodeTransformer().visit(deepcopy(mod))
        visitor = visitors.GraphNodeVisitor()
        visitor.visit(mod)
        visitor.graph.write_png(name + '.png')
        print(
            "\033[1;32;40mAST visualization saved as \033[94m%s.png\033[0m" %
            name)
    else:
        print("\033[93mInstall astmonkey for visualization.\033[0m")


# Dump the AST of the file given on the command line before and after
# canonicalization.
with open(sys.argv[1]) as source_file:  # closed as soon as it is read
    code = source_file.read()
orig_ast = gast.ast_to_gast(ast.parse(code))
print('=== Original AST ===')
dump_ast(orig_ast, 'original')
print('=== Canonicalized AST ===')
canon_ast = Canonicalizer().visit(orig_ast)
dump_ast(canon_ast, 'canonicalized')