def test_TryFinally(self): code = 'try:pass\nfinally:pass' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[Try(body=[Pass()], handlers=[], orelse=[], " "finalbody=[Pass()])], type_ignores=[])") self.assertEqual(gast.dump(tree), norm)
def test_Bytes(self): code = 'b"0012"' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[Expr(value=Constant(value=b'0012', " "kind=None))], type_ignores=[])") self.assertEqual(gast.dump(tree), norm)
def test_walk(self):
    """Check the dump of ``x + 1`` and the number of nodes gast.walk yields."""
    source = 'x + 1'
    expression = gast.parse(source, mode='eval')
    rendered = gast.dump(expression)
    expected = ("Expression(body=BinOp(left=Name(id='x', ctx=Load(), "
                "annotation=None), op=Add(), right=Num(n=1)))")
    self.assertEqual(rendered, expected)
    visited = sum(1 for _ in gast.walk(expression))
    self.assertEqual(visited, 6)
def test_TryExcept(self): code = 'try:pass\nexcept e:pass\nelse:pass' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[Try(body=[Pass()], handlers=[ExceptHandler(" "type=Name(id='e', ctx=Load(), annotation=None, " "type_comment=None), name=None, body=[Pass()])]" ", orelse=[Pass()], finalbody=[])], type_ignores=[])") self.assertEqual(gast.dump(tree), norm)
def test_dump(self):
    """gast.dump renders a lambda expression in the expected form."""
    source = 'lambda x: x'
    expression = gast.parse(source, mode='eval')
    rendered = gast.dump(expression)
    expected = ("Expression(body=Lambda(args=arguments(args=[Name(id='x', "
                "ctx=Param(), annotation=None)], vararg=None, kwonlyargs=[], "
                "kw_defaults=[], kwarg=None, defaults=[]), body=Name(id='x', "
                "ctx=Load(), annotation=None)))")
    self.assertEqual(rendered, expected)
def test_NamedExpr(self): code = '(x := 1) ' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[Expr(value=NamedExpr(target=Name(id='x'," " ctx=Store(), annotation=None, type_comment=None), " "value=Constant(value=1, kind=None)))], type_ignores=" "[])") self.assertEqual(gast.dump(tree), norm)
def test_preprocess_augassign():
    """preprocess_augassign output must dump like the hand-expanded assignment."""
    # Each nested function is a test case handed to parse_ast (never called):
    # its 2nd statement is the input AST, its 3rd the expected output AST.
    def add_augassign_fn():
        x, y = 1, 2
        x += y
        x = x + y

    def sub_augassign_fn():
        x, y = 1, 2
        x -= y
        x = x - y

    def mul_augassign_fn():
        x, y = 1, 2
        x *= y
        x = x * y

    def div_augassign_fn():
        x, y = 1, 2
        x /= y
        x = x / y

    class C:
        x = 1

    def attribute_augassign_fn():
        y = 1, 2
        C.x /= y
        C.x = C.x / y

    cases = (
        add_augassign_fn,
        sub_augassign_fn,
        mul_augassign_fn,
        div_augassign_fn,
        attribute_augassign_fn,
    )
    for case in cases:
        statements = parse_ast(case).body[0].body
        aug_ast, expected_ast = statements[1], statements[2]
        rewritten = preprocess_augassign(aug_ast)
        assert gast.dump(rewritten) == gast.dump(expected_ast)
def test_With(self): code = 'with open("any"): pass' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[With(items=[withitem(context_expr=Call(func=" "Name(id='open', ctx=Load(), annotation=None, " "type_comment=None), args=[Constant(value='any', " "kind=None)], keywords=[]), optional_vars=None)], body=[" "Pass()], type_comment=None)], type_ignores=[])") self.assertEqual(gast.dump(tree), norm)
def test_TypeIgnore(self): code = 'def foo(): pass # type: ignore[excuse]' tree = gast.parse(code, type_comments=True) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[FunctionDef(name='foo', args=arguments(" "args=[], posonlyargs=[], vararg=None, kwonlyargs=[], " "kw_defaults=[], kwarg=None, defaults=[]), body=[" "Pass()], decorator_list=[], returns=None, " "type_comment=None)], type_ignores=" "[TypeIgnore(lineno=1, tag='[excuse]')])") self.assertEqual(gast.dump(tree), norm)
def dump_ast(mod, name):
    """Print a gast dump of ``mod`` and, when astmonkey is available, render it.

    Args:
        mod: the AST to dump. Only a deep copy is passed to the astmonkey
            transformer.
        name: output file stem; the rendered graph is written to
            ``name + '.png'``.
    """
    print(gast.dump(mod))
    if not IMPORT_ASTMONKEY:
        print("\033[93mInstall astmonkey for visualization.\033[0m")
        return

    # Visit a single private copy so the transformer cannot touch the
    # caller's tree; copying twice is redundant work on large trees.
    annotated = transformers.ParentChildNodeTransformer().visit(deepcopy(mod))
    visitor = visitors.GraphNodeVisitor()
    visitor.visit(annotated)
    visitor.graph.write_png(name + '.png')
    print("\033[1;32;40mAST visualization saved as \033[94m%s.png\033[0m" % name)
def test_keyword_argument(self): code = 'def foo(**a): pass' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[FunctionDef(name='foo', args=arguments(args=[], " "posonlyargs=[], vararg=None, kwonlyargs=[], kw_defaults=[], " "kwarg=Name(id='a', ctx=Param(), annotation=None, " "type_comment=None), defaults=[]), body=[Pass()], " "decorator_list=[], returns=None, type_comment=None)], " "type_ignores=[])") self.assertEqual(gast.dump(tree), norm)
def test_KeywordOnlyArgument(self): code = 'def foo(*, x=1): pass' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[FunctionDef(name='foo', args=arguments(args=" "[], posonlyargs=[], vararg=None, kwonlyargs=[Name" "(id='x', ctx=Param(), annotation=None, type_comment=None" ")], kw_defaults=[Constant(value=1, kind=None)], kwarg=" "None, defaults=[]), body=[Pass()], decorator_list=[], " "returns=None, type_comment=None)], type_ignores=[])") self.assertEqual(gast.dump(tree), norm)
def test_FormattedValue(self): code = 'e = 1; f"{e}"' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[Assign(targets=[Name(id='e', ctx=Store()" ", annotation=None, type_comment=None" ")], value=Constant(value=1, kind=None)), Expr(value=" "JoinedStr(values=[FormattedValue(value=Name(id='e', " "ctx=Load(), annotation=None, type_comment=None), " "conversion=-1, format_spec=None)]))], " "type_ignores=[])") self.assertEqual(gast.dump(tree), norm)
def eval_ast(nast, env):
    """Evaluate ``nast`` in ``env`` via ``eval_ast_impl`` and return ``_value`` of it.

    Args:
        nast: a gast node, or a list of nodes (lists are not debug-printed).
        env: environment whose ``get_var_dict()`` must not contain any
            ``onnx.ValueInfoProto`` values (checked by assertion).

    Returns:
        ``_value(r)``, where ``r`` is the result of ``eval_ast_impl``.
    """
    global _eval_ast_depth
    for k, v in env.get_var_dict().items():
        assert not isinstance(v, onnx.ValueInfoProto), '%s %s' % (k, v)

    if not isinstance(nast, list):
        # The dash prefix grows with the nesting depth of eval_ast calls.
        dprint('-' * _eval_ast_depth, gast.dump(nast), env.get_var_dict().keys())

    _eval_ast_depth += 1
    try:
        r = eval_ast_impl(nast, env)
    finally:
        # Restore the depth even if evaluation raises, so it cannot drift.
        _eval_ast_depth -= 1
    return _value(r)
def test_Index(self): code = 'def foo(a): a[1]' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[FunctionDef(name='foo', args=arguments(args=[" "Name(id='a', ctx=Param(), annotation=None, type_comment=None)" "], posonlyargs=[], vararg=None, kwonlyargs=[], kw_defaults=[]" ", kwarg=None, defaults=[]), body=[Expr(value=Subscript(value=" "Name(id='a', ctx=Load(), annotation=None, type_comment=None)" ", slice=Index(value=Constant(value=1, kind=None)), ctx=Load()" "))], decorator_list=[], returns=None, type_comment=None)]" ", type_ignores=[])") self.assertEqual(gast.dump(tree), norm)
def infer_stmt(self, node):
    """Infer, record and return the type of a single statement node.

    Dispatches on the gast statement class and stores the result in
    ``self.nodetype[node]``.

    Args:
        node: a gast statement node.

    Returns:
        The type stored in ``self.nodetype[node]``.

    Raises:
        AssertionError: if no branch recorded a type for ``node``. This covers
            unhandled statement kinds and ``While``, whose branch is a no-op.
    """
    if self.is_debug:
        debug(gast.dump(node))
    if isinstance(node, gast.FunctionDef):
        self.nodetype[node] = self.infer_FunctionDef(node)
    elif isinstance(node, gast.Return):
        # Return(expr? value)
        if node.value is None:
            self.nodetype[node] = TyNone()
        else:
            self.nodetype[node] = self.infer_expr(node.value)
    elif isinstance(node, gast.Delete):
        # TODO(momohatt): erase from tyenv, etc.
        # TODO(momohatt): support deletion of element from list
        self.nodetype[node] = TyNone()
    elif isinstance(node, gast.Assign):
        self.infer_Assign(node)
        self.nodetype[node] = TyNone()
    elif isinstance(node, gast.AugAssign):
        self.infer_AugAssign(node)
        self.nodetype[node] = TyNone()
    elif isinstance(node, gast.For):
        self.infer_For(node)
        self.nodetype[node] = TyNone()
    elif isinstance(node, gast.While):
        # While(expr test, stmt* body, stmt* orelse)
        # Nothing is recorded here, so the assert at the end fires for While.
        pass
    elif isinstance(node, gast.If):
        self.nodetype[node] = self.infer_If(node)
    elif isinstance(node, gast.Raise):
        # Recorded as a bare TyVar (presumably so it unifies with any type;
        # confirm against TyVar's definition).
        self.nodetype[node] = TyVar()
    elif isinstance(node, gast.Try):
        # TODO(momohatt): What is 'finalbody' ?
        # Only `body` and `orelse` are inferred; `handlers` and `finalbody`
        # are not visited here.
        # NOTE(review): `self` is passed explicitly twice on a bound-method
        # call; confirm this matches infer_2blocks' signature.
        ty_ret = self.infer_2blocks(self, self, node.body, node.orelse)
        self.nodetype[node] = ty_ret
    elif isinstance(node, gast.Assert):
        self.nodetype[node] = TyNone()
    elif isinstance(node, (gast.Import, gast.ImportFrom)):
        self.nodetype[node] = TyNone()
    elif isinstance(node, gast.Expr):
        # Expr(expr value)
        self.infer_expr(node.value)
        self.nodetype[node] = TyNone()
    elif isinstance(node, gast.Pass):
        self.nodetype[node] = TyNone()
    assert node in self.nodetype.keys(), type(node).__name__
    return self.nodetype[node]
def test_ExtSlices(self): self.maxDiff = None code = 'def foo(a): a[1,:]' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[FunctionDef(name='foo', args=arguments(args=[" "Name(id='a', ctx=Param(), annotation=None, type_comment=None)" "], posonlyargs=[], vararg=None, kwonlyargs=[], kw_defaults=[]" ", kwarg=None, defaults=[]), body=[Expr(value=Subscript(value=" "Name(id='a', ctx=Load(), annotation=None, type_comment=None)" ", slice=Tuple(elts=[Constant(value=1, kind=" "None), Slice(lower=None, upper=None, step=None)], ctx=Load())" ", ctx=Load()))], decorator_list=[], returns=None, " "type_comment=None)], type_ignores=[])") self.assertEqual(gast.dump(tree), norm)
def test_Call(self): self.maxDiff = None code = 'foo(x, y=1, *args, **kwargs)' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[Expr(value=Call(func=Name(id='foo', ctx=Load" "(), annotation=None, type_comment=None" "), args=[Name(id='x', ctx=Load(), " "annotation=None, type_comment=None), Starred(value=Name(" "id='args', ctx=Load(), annotation=None, type_comment=None)" ", ctx=Load())], keywords=[keyword(" "arg='y', value=Constant(value=1, kind=None)), keyword(arg" "=None, value=Name(id='kwargs', ctx=Load(), annotation=None, " "type_comment=None))]))], type_ignores=[])") self.assertEqual(gast.dump(tree), norm)
def test_FunctionDef(self): code = 'def foo((x, y)): return x, y' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') norm = ("Module(body=[FunctionDef(name='foo', args=arguments(args=" "[Tuple(elts=[Name(id='x', ctx=Store(), annotation=None, " "type_comment=None), Name(id='y', ctx=Store(), " "annotation=None, type_comment=None)], ctx=Store())], " "posonlyargs=[], vararg=None, " "kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), " "body=[Return(value=Tuple(elts=[Name(id='x', ctx=Load(), " "annotation=None, type_comment=None), " "Name(id='y', ctx=Load(), " "annotation=None, type_comment=None" ")], ctx=Load()))], decorator_list=" "[], returns=None, type_comment=None)], type_ignores=[])") self.assertEqual(gast.dump(tree), norm)
def infer_stmt(self, node):
    """Infer, record and return the type of a single statement node.

    While it is being inferred, ``node`` is pushed onto ``self.stack``; it is
    popped again just before returning.

    Args:
        node: a gast statement node.

    Returns:
        The type stored in ``self.nodetype[node]``.

    Raises:
        AssertionError: if no branch recorded a type for ``node``. This covers
            unhandled statement kinds and ``While``, whose branch is a no-op.
    """
    if self.is_debug:
        debug(gast.dump(node))
    # NOTE(review): the matching pop() is not in a `finally`, so if any
    # inference call (or the assert below) raises, `node` stays on the stack.
    self.stack.append(node)
    if isinstance(node, gast.FunctionDef):
        self.nodetype[node] = self.infer_FunctionDef(node)
    elif isinstance(node, gast.Return):
        # Return(expr? value)
        if node.value is None:
            self.nodetype[node] = TyNone()
        else:
            self.nodetype[node] = self.infer_expr(node.value)
    elif isinstance(node, gast.Delete):
        # TODO(momohatt): erase from tyenv, etc.
        # TODO(momohatt): support deletion of element from list
        self.nodetype[node] = TyNone()
    elif isinstance(node, gast.Assign):
        self.infer_Assign(node)
        self.nodetype[node] = TyNone()
    elif isinstance(node, gast.AugAssign):
        self.infer_AugAssign(node)
        self.nodetype[node] = TyNone()
    elif isinstance(node, gast.For):
        self.infer_For(node)
        self.nodetype[node] = TyNone()
    elif isinstance(node, gast.While):
        # While(expr test, stmt* body, stmt* orelse)
        # Nothing is recorded here, so the assert at the end fires for While.
        pass
    elif isinstance(node, gast.If):
        # The If statement itself is typed as TyNone; infer_If's return value
        # is discarded here.
        self.infer_If(node)
        self.nodetype[node] = TyNone()
    elif isinstance(node, gast.Expr):
        # Expr(expr value)
        self.infer_expr(node.value)
        self.nodetype[node] = TyNone()
    elif isinstance(node, gast.Pass):
        self.nodetype[node] = TyNone()
    assert node in self.nodetype.keys(), type(node).__name__
    self.stack.pop()
    return self.nodetype[node]
def test_var_env(self):
    """Check the per-variable types that StaticAnalysisVisitor infers.

    For each function in ``test_funcs``, the types in its single sub scope
    must match the corresponding entry of ``result_var_type``.
    """
    for i, func in enumerate(test_funcs):
        var_type = result_var_type[i]
        test_source_code = inspect.getsource(func)
        ast_root = gast.parse(test_source_code)
        print(gast.dump(ast_root))
        visitor = StaticAnalysisVisitor(ast_root)
        var_env = visitor.get_var_env()
        # There must be 1 sub scope for the test function
        self.assertEqual(1, len(var_env.cur_scope.sub_scopes))
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]

        scope_var_type = var_env.get_scope_var_type()
        print(scope_var_type)
        self.assertEqual(len(scope_var_type), len(var_type))
        for name in scope_var_type:
            print("Test var name %s" % (name))
            # assertIn reports the missing name and the container on failure,
            # unlike assertTrue(name in ...), which only says "False is not true".
            self.assertIn(name, var_type)
            self.assertEqual(scope_var_type[name], var_type[name])
def test_Raise(self): codes = ( 'raise Exception', 'raise "Exception"', 'raise Exception, "err"', 'raise Exception("err")', 'raise E, V, T', ) norms = ( "Module(body=[Raise(exc=Name(id='Exception', ctx=Load(), " "annotation=None, type_comment=None)," " cause=None)], type_ignores=[])", "Module(body=[Raise(exc=Constant(value='Exception', kind=" "None), cause=None)], type_ignores=[])", "Module(body=[Raise(exc=Call(func=Name(id='Exception', " "ctx=Load(), annotation=None, type_comment=None), " "args=[Constant(value='err', kind=None)], " "keywords=[]), cause=None)], type_ignores=[])", "Module(body=[Raise(exc=Call(func=Name(id='Exception', " "ctx=Load(), annotation=None, type_comment=None), " "args=[Constant(value='err', kind=None)], " "keywords=[]), cause=None)], type_ignores=[])", "Module(body=[Raise(exc=Call(func=Attribute(value=Call(" "func=Name(id='E', ctx=Load(), annotation=None, " "type_comment=None), args=[Name(id='V', ctx=" "Load(), annotation=None, type_comment=None)], keywords=[]), " "attr='with_traceback', ctx=Load" "()), args=[Name(id='T', ctx=Load(), annotation=None, " "type_comment=None)], keywords=[]), " "cause=None)], type_ignores=[])", ) if sys.version_info.major == 3: codes = codes[0], codes[1], codes[3] norms = norms[0], norms[1], norms[3] for code, norm in zip(codes, norms): tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') self.assertEqual(gast.dump(tree), norm)
def get_function_instance(self, node):
    """Resolve the Python object that a call-target expression refers to.

    Args:
        node: a ``gast.Attribute`` (``obj.attr``) or ``gast.Name`` expression.

    Returns:
        A pair ``(func, ty_obj)``:

        - ``func`` is the resolved object, or ``None`` if an attribute lookup
          fails.
        - ``ty_obj`` is the inferred receiver type for list and
          ndarray/torch-tensor attributes; otherwise ``None``.

    Raises:
        AssertionError: if a ``gast.Name`` cannot be resolved, or ``node`` is
            neither an Attribute nor a Name.
    """
    import builtins

    if isinstance(node, gast.Attribute):
        receiver = node.value
        if isinstance(receiver, gast.Name) and \
                hasattr(self.module, receiver.id):
            # function of imported libraries (eg. np, chainer, F, L)
            module = getattr(self.module, receiver.id)
            return getattr(module, node.attr), None

        ty_obj = self.infer_expr(receiver).deref()
        if isinstance(ty_obj, TyList):
            return getattr(list, node.attr, None), ty_obj
        if isinstance(ty_obj, TyTensor):
            if ty_obj.is_ndarray():
                return getattr(np.ndarray, node.attr, None), ty_obj
            if ty_obj.is_torch_tensor():
                return getattr(torch.Tensor, node.attr, None), ty_obj
        if isinstance(ty_obj, TyUserDefinedClass):
            # if there is no such attribute, just return None (undefined)
            return getattr(ty_obj.instance, node.attr, None), None
        return None, None

    if isinstance(node, gast.Name):
        if node.id in self.tyenv.keys():
            ty = self.tyenv[node.id].deref()
            if isinstance(ty, TyUserDefinedClass):
                return ty.instance, None
        # `__builtins__` is a CPython implementation detail: it is a dict in
        # imported modules but the module itself in __main__, where `.keys()`
        # fails. The `builtins` module behaves the same everywhere.
        if hasattr(builtins, node.id):
            return getattr(builtins, node.id), None
        if hasattr(self.module, node.id):
            return getattr(self.module, node.id), None

    # Raise explicitly so the check is not stripped under `python -O`.
    raise AssertionError(gast.dump(node))
def dump(node, indent=2):
    """Return a pretty-printed ``ast.dump`` of ``node``.

    Args:
        node: a standard-library ``ast`` node.
        indent: indentation forwarded to ``ast.dump`` (Python 3.9+); defaults
            to 2, matching the previous hard-coded behavior.

    Returns:
        The multi-line dump string.
    """
    return ast.dump(node, indent=indent)
def test_buildable(self, template):
    """Test that each template can be built when given acceptable arguments.

    Also checks the built value's node type, that the template cost equals the
    AST node count, and that filling with the same seed is deterministic.
    """
    pncf = python_numbers_control_flow
    hole_type = pncf.ASTHoleType

    # Construct a hole that this template can always fill.
    metadata = pncf.ASTHoleMetadata(
        names_in_scope=frozenset({"a"}),
        inside_function=True,
        inside_loop=True,
        op_depth=0)
    hole = top_down_refinement.Hole(template.fills_type, metadata)
    self.assertTrue(template.can_fill(hole))

    # Make sure we can build this object with no errors.
    filler = template.fill(hole, np.random.RandomState(1234))
    dummy_values = {
        hole_type.NUMBER: lambda: gast.Constant(value=1, kind=None),
        hole_type.BOOL: lambda: gast.Constant(value=True, kind=None),
        hole_type.STMT: gast.Pass,
        hole_type.STMTS: lambda: [],
        hole_type.STMTS_NONEMPTY: lambda: [gast.Pass()],
        hole_type.BLOCK: lambda: [gast.Pass()],
    }
    hole_values = [dummy_values[h.hole_type]() for h in filler.holes]
    value = filler.build(*hole_values)

    # Check the type of the value that was built.
    fills_type = template.fills_type
    if fills_type in (hole_type.STMTS_NONEMPTY, hole_type.BLOCK):
        self.assertTrue(value)
        for stmt in value:
            self.assertIsInstance(stmt, gast.stmt)
    elif fills_type == hole_type.STMTS:
        for stmt in value:
            self.assertIsInstance(stmt, gast.stmt)
    elif fills_type == hole_type.STMT:
        self.assertIsInstance(value, gast.stmt)
    elif fills_type in (hole_type.NUMBER, hole_type.BOOL):
        self.assertIsInstance(value, gast.expr)
    else:
        raise NotImplementedError(f"Unexpected fill type {template.fills_type}; "
                                  "please update this test.")

    # Check that cost reflects number of AST nodes.
    roots = [value] if isinstance(value, gast.AST) else value
    total_cost = sum(1 for root in roots for _ in gast.walk(root))
    self.assertEqual(template.required_cost, total_cost)
    hole_cost = sum(pncf.ALL_COSTS[h.hole_type] for h in filler.holes)
    self.assertEqual(filler.cost, total_cost - hole_cost)

    # Check determinism: the same seed must rebuild an identical value.
    for _ in range(20):
        rebuilt = template.fill(hole, np.random.RandomState(1234)).build(*hole_values)
        if isinstance(value, list):
            self.assertEqual([gast.dump(v) for v in value],
                             [gast.dump(v) for v in rebuilt])
        else:
            self.assertEqual(gast.dump(value), gast.dump(rebuilt))
def test_analyze_block():
    """analyze_block marks code blocks and links each child node back to its block."""
    # The nested functions below are passed to parse_ast (never called); only
    # the first statement of each is inspected.
    def ternary_fn():
        x = 1 if True else False

    def list_comp_fn():
        x = [i for i in [1, 2, 3]]

    def dict_comp_fn():
        x = {i: j for i, j in [[1, 2], [2, 3], [3, 4]]}

    def lambda_fn():
        x = lambda y: y + 2

    def func_fn():
        def fn(y):
            return y + 2

    def ifelse_fn():
        if True:
            x = 1
        else:
            x = 3

    def for_fn():
        for i in [1, 2, 3]:
            x = i
        else:
            y = i

    def while_fn():
        while True:
            x = 2

    def with_fn():
        with 1 as x:
            y = x

    def try_fn():
        try:
            x = 1
        finally:
            z = x == 1

    # (test case, whether its first statement is expected to be a code block)
    cases = [
        (ternary_fn, False),
        (list_comp_fn, False),
        (dict_comp_fn, False),
        (lambda_fn, False),
        (func_fn, True),
        (ifelse_fn, True),
        (for_fn, True),
        (while_fn, True),
        (with_fn, True),
        (try_fn, True),
    ]
    for case, expect_block in cases:
        block_ast = analyze_block(parse_ast(case)).body[0].body[0]
        # check code blocks are labeled correctly
        assert block_ast.is_block == expect_block
        if not expect_block:
            continue
        # check back edges to the code block are created correctly on child
        # nodes; walk() also yields the block node itself, which is skipped.
        for child in gast.walk(block_ast):
            if child == block_ast:
                continue
            if child.block != block_ast:
                show = __import__("pprint").pprint
                show(gast.dump(child.block))
                show(gast.dump(block_ast))
            assert child.block == block_ast
def test_keyword_argument(self): code = 'def foo(**a): pass' tree = gast.parse(code) compile(gast.gast_to_ast(tree), '<test>', 'exec') gast.dump(tree, include_attributes=True)
def test_symbol_resolution():
    """resolve_symbol links each symbol reference to its definition(s)."""
    # The nested functions below are passed to parse_ast (never called); the
    # last statement of each is the symbol reference under test.
    def simple_fn():
        simple = 2
        simple

    def multi_assign():
        multi_a, multi_b = True, False
        multi_a, multi_b

    def repeated_assign():
        repeated = "first"
        repeated = "second"
        repeated

    def scoped_assign():
        scoped = False

        def fn():
            scoped = True

        # should reference the first definition of 'scoped' as the second definition
        # is scoped to only within function
        scoped

    def aug_assign():
        aug = 0
        aug = aug + 1
        aug

    class Qualified:
        a = 1

    def qualified_assign():
        Qualified.a = 1
        Qualified.a

    # test case functions, the line no. wrt. the function where the variable last defined
    # and finally list of all line no. where variable is defined.
    # if line no. is None, the symbols is defined global symbol
    cases = [
        (simple_fn, 0, [0]),
        (multi_assign, 0, [0]),
        (repeated_assign, 1, [0, 1]),
        (scoped_assign, 0, [0]),
        (aug_assign, 1, [0, 1]),
        (qualified_assign, 0, [0]),
    ]
    analyzers = (analyze_symbol, analyze_assign)
    for case, latest_line, def_lines in cases:
        tree = parse_ast(case)
        for analyzer in analyzers:
            tree = analyzer(tree)
        fn_ast = resolve_symbol(tree).body[0]

        # check latest symbol definition labeled as 'definition'
        latest_stmt = fn_ast.body[latest_line]
        ref_expr = fn_ast.body[-1].value
        latest_defs = latest_stmt.values
        refs = ref_expr.elts if isinstance(ref_expr, Tuple) else [ref_expr]
        for latest_def, ref in zip(latest_defs, refs):
            if ref.definition != latest_def:
                print(gast.dump(ref))
                print(gast.dump(ref.definition))
            assert ref.definition == latest_def

        # check all symbol definitions labeled as 'definitions'
        def_stmts = [fn_ast.body[line] for line in def_lines]
        for def_stmt in def_stmts:
            for sym_def, ref in zip(def_stmt.values, refs):
                if ref.definitions.count(sym_def) != 1:
                    print(gast.dump(ref))
                    print([gast.dump(d) for d in ref.definitions])
                assert ref.definitions.count(sym_def) == 1
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

    The function will also recursively compile all the entities that `o`
    references, updating `dependency_cache`.

    This function is reentrant, and relies on dependency_cache to avoid
    generating duplicate code.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.

    Returns:
      A tuple (ast, new_name, namespace):
          * ast: An AST representing an entity with interface equivalent to `o`,
              but which, when executed, creates a TF graph.
          * new_name: The symbol name under which the new entity can be found.
          * namespace: A dict mapping all symbols visible to the converted entity,
              keyed by their symbol name.

    Raises:
      NotImplementedError: if `o` is not a class, function or method (object
          conversion is not yet supported).
      ValueError: if the entity type is not supported. Because every Python
          object has `__class__`, this branch is not reached in practice; the
          NotImplementedError above is raised instead.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        node, name, ns = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o):
        node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
    elif tf_inspect.ismethod(o):
        node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    elif hasattr(o, '__class__'):
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'contrib/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    # TODO(mdan): This is temporary. It should be created using a converter.
    # TODO(mdan): The attribute should be added with a helper, not directly.
    # The helper can ensure there are no collisions.
    template = '''
        entity.autograph_info__ = {}
    '''
    node.extend(templates.replace(template, entity=name))

    program_ctx.add_to_cache(o, node)

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    compiler.ast_to_source(node))
    if logging.has_verbosity(4):
        for n in node:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n', o, gast.dump(n))

    if program_ctx.options.recursive:
        # Candidates deliberately left for their owning class are never added
        # to dependency_cache. Remembering them here stops the loop from
        # re-selecting the same entry forever.
        skipped = set()
        while True:
            candidate = None
            for obj in program_ctx.name_map.keys():
                if obj not in program_ctx.dependency_cache and obj not in skipped:
                    candidate = obj
                    break
            if candidate is None:
                break
            if (hasattr(candidate, 'im_class') and
                    getattr(candidate, 'im_class') not in program_ctx.partial_types):
                # Class members are converted with their objects, unless they're
                # only converted partially.
                skipped.add(candidate)
                continue
            entity_to_graph(candidate, program_ctx, {}, {})

    return node, name, ns