def replace(template, **replacements):
  """Replace placeholders in a Python template function.

  The template is a function whose arguments name the placeholders. The
  function wrapper itself is stripped: only its (transformed) body is returned.

  Args:
    template: A function to be used as a template. Any placeholder is expected
      to also be a function argument (a positional argument or `*args`).
    **replacements: A mapping from placeholder names to (lists of) AST nodes
      that these placeholders will be replaced by.

  Returns:
    The body of the template function (a list of statement nodes) with the
    replacements made.

  Raises:
    ValueError: If the set of replacement names does not exactly match the set
      of placeholders declared by the template's arguments.
  """
  tree = parser.parse_object(template).body[0]
  placeholders = set(arg.id for arg in tree.args.args)
  tree.args.args = []
  vararg = tree.args.vararg
  if vararg:
    # Depending on the AST flavor, `vararg` may be a plain identifier string or
    # a name node (as the positional args above are). Placeholders are compared
    # against replacement keys by name, so normalize to the identifier.
    placeholders.add(getattr(vararg, 'id', vararg))
    tree.args.vararg = None
  if set(replacements.keys()) != placeholders:
    raise ValueError(
        'too many or few replacements. replacements: %s; placeholders: %s' %
        (replacements.keys(), placeholders))
  # Perform the replacement, stripping the function into which the template
  # was wrapped.
  return ReplaceTransformer(replacements).visit(tree).body
def function_to_graph(f, conversion_map, arg_values, arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: The Python function to convert.
    conversion_map: The conversion state; provides the namer, the nocompile
      decorators and the name map updated by this call.
    arg_values: Forwarded to the `EntityContext` as `arg_values`.
    arg_types: Forwarded to the `EntityContext` as `arg_types`.
    owner_type: Optional type owning `f`; when given it is passed on to the
      namer when computing the compiled function name.

  Returns:
    A tuple `(node, name)`: the converted function's AST node and the name
    recorded for `f` in `conversion_map.name_map`.
  """
  node = parser.parse_object(f).body[0]
  # Work on a copy of the globals: the closure entries added below must not
  # leak into the module namespace that `f` actually runs in.
  namespace = dict(six.get_function_globals(f))

  # This is needed for non-global functions: callables captured in the closure
  # are exposed under the free-variable name the function body uses for them
  # (which need not equal the callable's `__name__`, and some callables, such
  # as `functools.partial` objects, have no `__name__` at all).
  closure = six.get_function_closure(f)
  if closure:
    freevars = six.get_function_code(f).co_freevars
    for name, cell in zip(freevars, closure):
      try:
        value = cell.cell_contents
      except ValueError:
        # Empty cell: the captured variable has not been bound yet.
        continue
      if callable(value):
        namespace[name] = value

  namer = conversion_map.new_namer(namespace)
  ctx = context.EntityContext(
      namer=namer,
      source_code=tf_inspect.getsource(f),
      source_file=tf_inspect.getfile(f),
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types)
  node = node_to_graph(node, ctx, conversion_map.nocompile_decorators)

  # Simulate a rename to ensure the top level is in the name map. This is needed
  # for top level functions, and it also helps the consistency verification made
  # by update_name_map.
  if owner_type is not None:
    new_name = namer.compiled_function_name(f.__name__, f, owner_type)
  else:
    new_name = namer.compiled_function_name(f.__name__, f)
  node.name = new_name
  conversion_map.update_name_map(namer)
  return node, conversion_map.name_map[f]
def test_if(self):
  """Checks the scopes recorded for both branches of an `if` statement."""

  def test_fn(x):
    if x > 0:
      x = -x
      y = 2 * x
      z = -y
    else:
      x = 2 * x
      y = -x
      u = -y
    return z, u

  node = parser.parse_object(test_fn)
  node = access.resolve(node)

  if_node = node.body[0].body[0]
  self.assertScopeIs(
      anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'),
      ('y', 'z'))
  # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
  self.assertScopeIs(
      anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'),
      ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
  self.assertScopeIs(
      anno.getanno(if_node, 'orelse_scope'), ('x', 'y'), ('x', 'y', 'u'),
      ('y', 'u'))
  # The else branch's parent scope is the same enclosing function scope, so
  # the expectations match `body_parent_scope`. (The original repeated the
  # `body_parent_scope` check here, leaving the else side untested.)
  self.assertScopeIs(
      anno.getanno(if_node, 'orelse_parent_scope'), ('x', 'z', 'u'),
      ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
def function_to_graph(f, conversion_map, param_value_hints, owner_type=None):
  """Specialization of `object_to_graph` for callable functions.

  Args:
    f: The Python function to convert.
    conversion_map: The conversion state; provides the namer and the name map
      updated by this call.
    param_value_hints: Forwarded to `node_to_graph`.
    owner_type: Optional type owning `f`; when given it is passed on to the
      namer when computing the compiled function name.

  Returns:
    A tuple `(node, name)`: the converted function's AST node and the name
    recorded for `f` in `conversion_map.name_map`.
  """
  node = parser.parse_object(f).body[0]
  # Work on a copy of the globals: the closure entries added below must not
  # leak into the module namespace that `f` actually runs in.
  node_globals = dict(six.get_function_globals(f))

  # This is needed for non-global functions: callables captured in the closure
  # are exposed under the free-variable name the function body uses for them
  # (which need not equal the callable's `__name__`, and some callables, such
  # as `functools.partial` objects, have no `__name__` at all).
  closure = six.get_function_closure(f)
  if closure:
    freevars = six.get_function_code(f).co_freevars
    for name, cell in zip(freevars, closure):
      try:
        value = cell.cell_contents
      except ValueError:
        # Empty cell: the captured variable has not been bound yet.
        continue
      if callable(value):
        node_globals[name] = value

  namer = conversion_map.new_namer(node_globals)
  node = node_to_graph(node, namer, node_globals, param_value_hints)

  # Simulate a rename to ensure the top level is in the name map. This is needed
  # for top level functions, and it also helps the consistency verification made
  # by update_name_map.
  if owner_type is not None:
    new_name = namer.compiled_function_name(f.__name__, f, owner_type)
  else:
    new_name = namer.compiled_function_name(f.__name__, f)
  node.name = new_name
  conversion_map.update_name_map(namer)
  return node, conversion_map.name_map[f]
def test_print_statement(self):
  """Checks which variables the arguments of a `print` call capture."""

  def test_fn(a):
    b = 0
    c = 1
    print(a, b)
    return c

  tree = access.resolve(parser.parse_object(test_fn))
  stmt = tree.body[0].body[2]

  if isinstance(stmt, gast.Print):
    # Python 2: `print` is a statement and carries the annotation itself.
    annotated = stmt
  else:
    # Python 3: `print` is a call wrapped in an expression statement, and the
    # call node is the one being annotated.
    assert isinstance(stmt, gast.Expr)
    annotated = stmt.value
  args_scope = anno.getanno(annotated, 'args_scope')

  # Only the variables captured by the call arguments should be reported.
  self.assertItemsEqual(['a', 'b'], args_scope.used)
  self.assertItemsEqual([], args_scope.modified)
  self.assertItemsEqual([], args_scope.created)
def function_to_graph(f, conversion_map, param_value_hints, owner_type=None):
  """Specialization of `object_to_graph` for callable functions.

  Args:
    f: The Python function to convert.
    conversion_map: The conversion state; provides the namer and the name map
      updated by this call.
    param_value_hints: Forwarded to `node_to_graph`.
    owner_type: Optional type owning `f`; when given it is passed on to the
      namer when computing the compiled function name.

  Returns:
    A tuple `(node, name)`: the converted function's AST node and the name
    recorded for `f` in `conversion_map.name_map`.
  """
  node = parser.parse_object(f).body[0]
  # Work on a copy of the globals: the closure entries added below must not
  # leak into the module namespace that `f` actually runs in.
  node_globals = dict(six.get_function_globals(f))

  # This is needed for non-global functions: callables captured in the closure
  # are exposed under the free-variable name the function body uses for them
  # (which need not equal the callable's `__name__`, and some callables, such
  # as `functools.partial` objects, have no `__name__` at all).
  closure = six.get_function_closure(f)
  if closure:
    freevars = six.get_function_code(f).co_freevars
    for name, cell in zip(freevars, closure):
      try:
        value = cell.cell_contents
      except ValueError:
        # Empty cell: the captured variable has not been bound yet.
        continue
      if callable(value):
        node_globals[name] = value

  namer = conversion_map.new_namer(node_globals)
  node = node_to_graph(node, namer, node_globals, param_value_hints)

  # Simulate a rename to ensure the top level is in the name map. This is needed
  # for top level functions, and it also helps the consistency verification made
  # by update_name_map.
  if owner_type is not None:
    new_name = namer.compiled_function_name(f.__name__, f, owner_type)
  else:
    new_name = namer.compiled_function_name(f.__name__, f)
  node.name = new_name
  conversion_map.update_name_map(namer)
  return node, conversion_map.name_map[f]
def function_to_graph(f, conversion_map, param_value_hints):
  """Specialization of `object_to_graph` for callable functions.

  After converting `f`, any objects registered in `conversion_map.name_map`
  that are not yet in `conversion_map.dependency_cache` are converted too.

  Args:
    f: The Python function to convert.
    conversion_map: The conversion state; provides the namer, the name map and
      the dependency cache, all of which this call updates.
    param_value_hints: Forwarded to `node_to_graph`.

  Returns:
    A tuple `(node, name)`: the converted function's AST node and the name
    recorded for `f` in `conversion_map.name_map`.
  """
  node = parser.parse_object(f).body[0]
  # Work on a copy of the globals: the closure entries added below must not
  # leak into the module namespace that `f` actually runs in.
  node_globals = dict(six.get_function_globals(f))

  # This is needed for non-global functions: callables captured in the closure
  # are exposed under the free-variable name the function body uses for them
  # (which need not equal the callable's `__name__`, and some callables, such
  # as `functools.partial` objects, have no `__name__` at all).
  closure = six.get_function_closure(f)
  if closure:
    freevars = six.get_function_code(f).co_freevars
    for name, cell in zip(freevars, closure):
      try:
        value = cell.cell_contents
      except ValueError:
        # Empty cell: the captured variable has not been bound yet.
        continue
      if callable(value):
        node_globals[name] = value

  namer = conversion_map.new_namer(node_globals)
  node = node_to_graph(node, namer, node_globals, param_value_hints)

  # Simulate a rename to ensure the top level is in the name map. This is needed
  # for top level functions, and it also helps the consistency verification made
  # by update_name_map.
  namer.compiled_function_name(f.__name__, f)

  conversion_map.add_to_cache(f, node)
  conversion_map.update_name_map(namer)

  # Recursively convert any remaining dependencies. Iterate over a snapshot of
  # the keys (as Python 2's `keys()` already did): object_to_graph may register
  # new entries in name_map, and mutating a dict while iterating its live keys
  # view raises RuntimeError on Python 3.
  for obj in list(conversion_map.name_map.keys()):
    # Re-check membership: a recursive call may already have converted `obj`.
    if obj not in conversion_map.dependency_cache:
      object_to_graph(obj, conversion_map, None)

  return node, conversion_map.name_map[f]
def function_to_graph(f, conversion_map, param_value_hints):
  """Specialization of `object_to_graph` for callable functions.

  After converting `f`, any objects registered in `conversion_map.name_map`
  that are not yet in `conversion_map.dependency_cache` are converted too.

  Args:
    f: The Python function to convert.
    conversion_map: The conversion state; provides the namer, the name map and
      the dependency cache, all of which this call updates.
    param_value_hints: Forwarded to `node_to_graph`.

  Returns:
    A tuple `(node, name)`: the converted function's AST node and the name
    recorded for `f` in `conversion_map.name_map`.
  """
  node = parser.parse_object(f).body[0]
  # Work on a copy of the globals: the closure entries added below must not
  # leak into the module namespace that `f` actually runs in.
  node_globals = dict(six.get_function_globals(f))

  # This is needed for non-global functions: callables captured in the closure
  # are exposed under the free-variable name the function body uses for them
  # (which need not equal the callable's `__name__`, and some callables, such
  # as `functools.partial` objects, have no `__name__` at all).
  closure = six.get_function_closure(f)
  if closure:
    freevars = six.get_function_code(f).co_freevars
    for name, cell in zip(freevars, closure):
      try:
        value = cell.cell_contents
      except ValueError:
        # Empty cell: the captured variable has not been bound yet.
        continue
      if callable(value):
        node_globals[name] = value

  namer = conversion_map.new_namer(node_globals)
  node = node_to_graph(node, namer, node_globals, param_value_hints)

  # Simulate a rename to ensure the top level is in the name map. This is needed
  # for top level functions, and it also helps the consistency verification made
  # by update_name_map.
  namer.compiled_function_name(f.__name__, f)

  conversion_map.add_to_cache(f, node)
  conversion_map.update_name_map(namer)

  # Recursively convert any remaining dependencies. Iterate over a snapshot of
  # the keys (as Python 2's `keys()` already did): object_to_graph may register
  # new entries in name_map, and mutating a dict while iterating its live keys
  # view raises RuntimeError on Python 3.
  for obj in list(conversion_map.name_map.keys()):
    # Re-check membership: a recursive call may already have converted `obj`.
    if obj not in conversion_map.dependency_cache:
      object_to_graph(obj, conversion_map, None)

  return node, conversion_map.name_map[f]
def replace(template, **replacements):
  """Replace placeholders in a Python template function.

  The template is a function whose arguments name the placeholders. The
  function wrapper itself is stripped: only its (transformed) body is returned.

  Args:
    template: A function to be used as a template. Any placeholder is expected
      to also be a function argument (a positional argument or `*args`).
    **replacements: A mapping from placeholder names to (lists of) AST nodes
      that these placeholders will be replaced by.

  Returns:
    The body of the template function (a list of statement nodes) with the
    replacements made.

  Raises:
    ValueError: If the set of replacement names does not exactly match the set
      of placeholders declared by the template's arguments.
  """
  tree = parser.parse_object(template).body[0]
  placeholders = set(arg.id for arg in tree.args.args)
  tree.args.args = []
  vararg = tree.args.vararg
  if vararg:
    # Depending on the AST flavor, `vararg` may be a plain identifier string or
    # a name node (as the positional args above are). Placeholders are compared
    # against replacement keys by name, so normalize to the identifier.
    placeholders.add(getattr(vararg, 'id', vararg))
    tree.args.vararg = None
  if set(replacements.keys()) != placeholders:
    raise ValueError(
        'too many or few replacements. replacements: %s; placeholders: %s' %
        (replacements.keys(), placeholders))
  # Perform the replacement, stripping the function into which the template
  # was wrapped.
  return ReplaceTransformer(replacements).visit(tree).body
def test_print_statement(self):
  """Checks which variables the arguments of a `print` call capture."""

  def test_fn(a):
    b = 0
    c = 1
    print(a, b)
    return c

  tree = access.resolve(parser.parse_object(test_fn))
  stmt = tree.body[0].body[2]

  if isinstance(stmt, gast.Print):
    # Python 2: `print` is a statement and carries the annotation itself.
    annotated = stmt
  else:
    # Python 3: `print` is a call wrapped in an expression statement, and the
    # call node is the one being annotated.
    assert isinstance(stmt, gast.Expr)
    annotated = stmt.value
  args_scope = anno.getanno(annotated, 'args_scope')

  # Only the variables captured by the call arguments should be reported.
  self.assertItemsEqual(['a', 'b'], args_scope.used)
  self.assertItemsEqual([], args_scope.modified)
  self.assertItemsEqual([], args_scope.created)
def test_if(self):
  """Checks the scopes recorded for both branches of an `if` statement."""

  def test_fn(x):
    if x > 0:
      x = -x
      y = 2 * x
      z = -y
    else:
      x = 2 * x
      y = -x
      u = -y
    return z, u

  node = parser.parse_object(test_fn)
  node = access.resolve(node)

  if_node = node.body[0].body[0]
  self.assertScopeIs(
      anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'),
      ('y', 'z'))
  # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
  self.assertScopeIs(
      anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'),
      ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
  self.assertScopeIs(
      anno.getanno(if_node, 'orelse_scope'), ('x', 'y'), ('x', 'y', 'u'),
      ('y', 'u'))
  # The else branch's parent scope is the same enclosing function scope, so
  # the expectations match `body_parent_scope`. (The original repeated the
  # `body_parent_scope` check here, leaving the else side untested.)
  self.assertScopeIs(
      anno.getanno(if_node, 'orelse_parent_scope'), ('x', 'z', 'u'),
      ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
def test_parameter_class_members(self):
  """Type analysis rejects a member call on an untyped parameter."""

  def test_fn(opt):
    opt.minimize(0)

  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, {'training': training}, {})

  # No value hints are given for `opt`, and resolution is expected to fail.
  with self.assertRaises(ValueError):
    type_info.resolve(tree, None)
def test_nested_members(self):
  """Type analysis rejects a call through nested attribute access."""

  def test_fn():
    foo = training.GradientDescentOptimizer(0.1)
    foo.bar.baz()

  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, {'training': training}, {})

  # `foo.bar.baz()` goes two attributes deep; resolution is expected to fail.
  with self.assertRaises(ValueError):
    type_info.resolve(tree, None)
def test_parameter_class_members(self):
  """Type analysis rejects a member call on an untyped parameter."""

  def test_fn(opt):
    opt.minimize(0)

  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, {'training': training}, {})

  # No value hints are given for `opt`, and resolution is expected to fail.
  with self.assertRaises(ValueError):
    type_info.resolve(tree, None)
def test_literals(self):
  """A name supplied through the literals mapping resolves to its value."""

  def test_fn():
    return Foo  # pylint: disable=undefined-variable

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {}, {'Foo': 'bar'})

  retval_node = node.body[0].body[0].value
  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  self.assertEqual('bar', anno.getanno(retval_node, 'live_val'))
def test_literals(self):
  """A name supplied through the literals mapping resolves to its value."""

  def test_fn():
    return Foo  # pylint: disable=undefined-variable

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {}, {'Foo': 'bar'})

  retval_node = node.body[0].body[0].value
  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  self.assertEqual('bar', anno.getanno(retval_node, 'live_val'))
def test_transform(self):
  """Converted `and`/`or` expressions evaluate correctly in a session."""

  def test_fn(a, b, c):
    return (a or b) and (a or b or c)

  converted = logical_expressions.transform(parser.parse_object(test_fn))
  module = compiler.ast_to_object(converted)
  # The generated code looks its ops up under `tf`; bind that to math_ops.
  module.tf = math_ops

  with self.test_session() as sess:
    outcome = sess.run(module.test_fn(True, False, True))
    self.assertTrue(outcome)
def test_nested_members(self):
  """Type analysis rejects a call through nested attribute access."""

  def test_fn():
    foo = training.GradientDescentOptimizer(0.1)
    foo.bar.baz()

  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, {'training': training}, {})

  # `foo.bar.baz()` goes two attributes deep; resolution is expected to fail.
  with self.assertRaises(ValueError):
    type_info.resolve(tree, None)
def test_transform(self):
  """Converted `and`/`or` expressions evaluate correctly in a session."""

  def test_fn(a, b, c):
    return (a or b) and (a or b or c)

  converted = logical_expressions.transform(parser.parse_object(test_fn))
  module = compiler.ast_to_object(converted)
  # The generated code looks its ops up under `tf`; bind that to math_ops.
  module.tf = math_ops

  with self.test_session() as sess:
    outcome = sess.run(module.test_fn(True, False, True))
    self.assertTrue(outcome)
def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
  """Parses `test_fn` and runs access, live-value and type analysis on it.

  Args:
    test_fn: The function to parse.
    namespace: Names visible to the function; used both for the entity context
      and for live-value resolution.
    arg_types: Optional argument types recorded in the entity context.

  Returns:
    The fully annotated AST.
  """
  entity_ctx = context.EntityContext(
      namer=None,
      source_code=None,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types)
  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, namespace, {})
  return type_info.resolve(tree, entity_ctx)
def test_attribute_names(self):
  """An attribute call resolves to its live value and fully qualified name."""

  def test_fn():
    return constant_op.constant(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'constant_op': constant_op}, {})

  func_node = node.body[0].body[0].value.func
  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
  self.assertEqual((constant_op.__name__, 'constant'),
                   anno.getanno(func_node, 'fqn'))
def test_class_members(self):
  """A method call on a constructed object is annotated with its class fqn."""

  def test_fn():
    opt = training.GradientDescentOptimizer(0.1)
    opt.minimize(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'training': training}, {})
  node = type_info.resolve(node, None)

  attr_call_node = node.body[0].body[1].value.func
  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                   anno.getanno(attr_call_node, 'type_fqn'))
def test_equals(self):
  """A converted `==` comparison evaluates correctly in a session."""

  def test_fn(a, b):
    return a == b

  converted = logical_expressions.transform(parser.parse_object(test_fn))
  module = compiler.ast_to_object(converted)
  # The generated code looks its ops up under `tf`; bind that to math_ops.
  module.tf = math_ops

  with self.test_session() as sess:
    self.assertTrue(sess.run(module.test_fn(1, 1)))
    self.assertFalse(sess.run(module.test_fn(1, 2)))
def test_attribute_names(self):
  """An attribute call resolves to its live value and fully qualified name."""

  def test_fn():
    return constant_op.constant(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'constant_op': constant_op}, {})

  func_node = node.body[0].body[0].value.func
  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
  self.assertEqual((constant_op.__name__, 'constant'),
                   anno.getanno(func_node, 'fqn'))
def test_function_variables(self):
  """Type analysis rejects calling a function through a local alias."""

  def bar():
    pass

  def test_fn():
    foo = bar
    foo()

  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, {'bar': bar}, {})

  with self.assertRaises(ValueError):
    type_info.resolve(tree, None)
def test_constructor_deta_dependent(self):
  """Type analysis rejects an object constructed in a data-dependent branch.

  ("deta" is presumably a typo for "data"; the method name is kept so that
  existing test filters keep matching.)
  """

  def test_fn(x):
    if x > 0:
      opt = training.GradientDescentOptimizer(0.1)
    else:
      opt = training.GradientDescentOptimizer(0.01)
    opt.minimize(0)

  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, {'training': training}, {})

  # `opt` is bound on both branches of a conditional on `x`.
  with self.assertRaises(ValueError):
    type_info.resolve(tree, None)
def test_constructor_deta_dependent(self):
  """Type analysis rejects an object constructed in a data-dependent branch.

  ("deta" is presumably a typo for "data"; the method name is kept so that
  existing test filters keep matching.)
  """

  def test_fn(x):
    if x > 0:
      opt = training.GradientDescentOptimizer(0.1)
    else:
      opt = training.GradientDescentOptimizer(0.01)
    opt.minimize(0)

  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, {'training': training}, {})

  # `opt` is bound on both branches of a conditional on `x`.
  with self.assertRaises(ValueError):
    type_info.resolve(tree, {})
def test_class_members(self):
  """A method call on a constructed object is annotated with its class fqn."""

  def test_fn():
    opt = training.GradientDescentOptimizer(0.1)
    opt.minimize(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'training': training}, {})
  node = type_info.resolve(node, None)

  attr_call_node = node.body[0].body[1].value.func
  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                   anno.getanno(attr_call_node, 'type_fqn'))
def test_namespace(self):
  """A call to a namespace function resolves to that function and its name."""

  def foo():
    return 'bar'

  def test_fn():
    return foo()

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'foo': foo}, {})

  func_node = node.body[0].body[0].value.func
  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
  self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))
def test_function_variables(self):
  """Type analysis rejects calling a function through a local alias."""

  def bar():
    pass

  def test_fn():
    foo = bar
    foo()

  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, {'bar': bar}, {})

  with self.assertRaises(ValueError):
    type_info.resolve(tree, None)
def test_constructor_detection(self):
  """A constructor call is annotated with the constructed type and its fqn."""

  def test_fn():
    opt = training.GradientDescentOptimizer(0.1)
    return opt

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'training': training}, {})
  node = type_info.resolve(node, None)

  call_node = node.body[0].body[0].value
  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  self.assertEqual(training.GradientDescentOptimizer,
                   anno.getanno(call_node, 'type'))
  self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                   anno.getanno(call_node, 'type_fqn'))
def test_call(self):
  """The argument scope of a call records only the variables it captures."""

  def test_fn(a):
    b = 0
    c = 1
    foo(a, b)  # pylint:disable=undefined-variable
    return c

  tree = access.resolve(parser.parse_object(test_fn))
  call = tree.body[0].body[2].value

  # Only `a` and `b` appear in the call arguments; `c` must not be recorded.
  args_scope = anno.getanno(call, 'args_scope')
  self.assertScopeIs(args_scope, ('a', 'b'), (), ())
def test_namespace(self):
  """A call to a namespace function resolves to that function and its name."""

  def foo():
    return 'bar'

  def test_fn():
    return foo()

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'foo': foo}, {})

  func_node = node.body[0].body[0].value.func
  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
  self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))
def test_call(self):
  """The argument scope of a call records only the variables it captures."""

  def test_fn(a):
    b = 0
    c = 1
    foo(a, b)  # pylint:disable=undefined-variable
    return c

  tree = access.resolve(parser.parse_object(test_fn))
  call = tree.body[0].body[2].value

  # Only `a` and `b` appear in the call arguments; `c` must not be recorded.
  args_scope = anno.getanno(call, 'args_scope')
  self.assertScopeIs(args_scope, ('a', 'b'), (), ())
def test_constructor_detection(self):
  """A constructor call is annotated with the constructed type and its fqn."""

  def test_fn():
    opt = training.GradientDescentOptimizer(0.1)
    return opt

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'training': training}, {})
  node = type_info.resolve(node, None)

  call_node = node.body[0].body[0].value
  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  self.assertEqual(training.GradientDescentOptimizer,
                   anno.getanno(call_node, 'type'))
  self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                   anno.getanno(call_node, 'type_fqn'))
def test_while(self):
  """Checks the body and parent scopes recorded for a `while` loop."""

  def test_fn(a):
    b = a
    while b > 0:
      c = b
      b -= 1
    return b, c

  tree = access.resolve(parser.parse_object(test_fn))
  loop = tree.body[0].body[1]

  body_scope = anno.getanno(loop, 'body_scope')
  self.assertScopeIs(body_scope, ('b',), ('b', 'c'), ('c',))

  parent_scope = anno.getanno(loop, 'body_parent_scope')
  everything = ('a', 'b', 'c')
  self.assertScopeIs(parent_scope, everything, everything, everything)
def test_parameter_class_members_with_value_hints(self):
  """With a value hint for `opt`, its member call gets the hinted type fqn."""

  def test_fn(opt):
    opt.minimize(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'training': training}, {})
  node = type_info.resolve(
      node, {
          'opt': (('%s.GradientDescentOptimizer' % training.__name__),
                  training.GradientDescentOptimizer(0.1))
      })

  attr_call_node = node.body[0].body[0].value.func
  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  self.assertEqual(
      training.__name__.split('.') + ['GradientDescentOptimizer'],
      anno.getanno(attr_call_node, 'type_fqn'))
def test_local_markers(self):
  """Name reads are marked local or non-local according to their binding."""

  def test_fn(a):  # pylint:disable=unused-argument
    b = c  # pylint:disable=undefined-variable
    while b > 0:
      b -= 1
    return b

  tree = access.resolve(parser.parse_object(test_fn))
  stmts = tree.body[0].body

  c_read = stmts[0].value  # c in b = c
  loop_b_read = stmts[1].test.left  # b in b > 0
  return_b_read = stmts[2].value  # b in return b

  self.assertFalse(anno.getanno(c_read, 'is_local'))
  self.assertTrue(anno.getanno(loop_b_read, 'is_local'))
  self.assertTrue(anno.getanno(return_b_read, 'is_local'))
def test_local_markers(self):
  """Name reads are marked local or non-local according to their binding."""

  def test_fn(a):  # pylint:disable=unused-argument
    b = c  # pylint:disable=undefined-variable
    while b > 0:
      b -= 1
    return b

  tree = access.resolve(parser.parse_object(test_fn))
  stmts = tree.body[0].body

  c_read = stmts[0].value  # c in b = c
  loop_b_read = stmts[1].test.left  # b in b > 0
  return_b_read = stmts[2].value  # b in return b

  self.assertFalse(anno.getanno(c_read, 'is_local'))
  self.assertTrue(anno.getanno(loop_b_read, 'is_local'))
  self.assertTrue(anno.getanno(return_b_read, 'is_local'))
def test_for(self):
  """Checks the body and parent scopes recorded for a `for` loop."""

  def test_fn(a):
    b = a
    for _ in a:
      c = b
      b -= 1
    return b, c

  tree = access.resolve(parser.parse_object(test_fn))
  loop = tree.body[0].body[1]

  body_scope = anno.getanno(loop, 'body_scope')
  self.assertScopeIs(body_scope, ('b',), ('b', 'c'), ('c',))

  # The loop target `_` shows up in the parent scope's second and third groups.
  parent_scope = anno.getanno(loop, 'body_parent_scope')
  with_target = ('a', 'b', 'c', '_')
  self.assertScopeIs(parent_scope, ('a', 'b', 'c'), with_target, with_target)
def test_parameter_class_members_with_value_hints(self):
  """With a value hint for `opt`, its member call gets the hinted type fqn."""

  def test_fn(opt):
    opt.minimize(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'training': training}, {})
  node = type_info.resolve(
      node, {
          'opt': (('%s.GradientDescentOptimizer' % training.__name__),
                  training.GradientDescentOptimizer(0.1))
      })

  attr_call_node = node.body[0].body[0].value.func
  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  self.assertEqual(
      training.__name__.split('.') + ['GradientDescentOptimizer'],
      anno.getanno(attr_call_node, 'type_fqn'))
def parse_and_analyze(self,
                      test_fn,
                      namespace,
                      arg_types=None,
                      include_type_analysis=True):
  """Parses `test_fn` and runs the static analyses on it.

  Args:
    test_fn: The function to parse.
    namespace: Names visible to the function; used both for the entity context
      and for live-value resolution.
    arg_types: Optional argument types recorded in the entity context.
    include_type_analysis: When False, stop after live-value resolution.

  Returns:
    The annotated AST.
  """
  entity_ctx = context.EntityContext(
      namer=None,
      source_code=None,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types)
  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, namespace, {})
  if not include_type_analysis:
    return tree
  return type_info.resolve(tree, entity_ctx)
def test_call(self):
  """The argument scope of a call records only the variables it captures."""

  def test_fn(a):
    b = 0
    c = 1
    foo(a, b)  # pylint:disable=undefined-variable
    return c

  tree = access.resolve(parser.parse_object(test_fn))
  call = tree.body[0].body[2].value
  args_scope = anno.getanno(call, 'args_scope')

  # The arguments read `a` and `b` and neither modify nor create anything.
  self.assertItemsEqual(['a', 'b'], args_scope.used)
  self.assertItemsEqual([], args_scope.modified)
  self.assertItemsEqual([], args_scope.created)
def test_transform(self):
  """A converted `value_and_gradients_function` call computes loss and grad."""

  def loss(x, w):
    return x * w

  def test_fn(x, w):
    l, (dw,) = tfe.value_and_gradients_function(loss, [1])(x, w)  # pylint:disable=undefined-variable
    return l, dw

  converted = gradients_function.transform(parser.parse_object(test_fn))
  module = compiler.ast_to_object(converted)
  # Bind the names the generated code refers to.
  module.tf = gradients_impl
  module.loss = loss

  with self.test_session() as sess:
    outputs = module.test_fn(constant_op.constant(3), constant_op.constant(4))
    self.assertEqual((12, 3), sess.run(outputs))
def test_call(self):
  """The argument scope of a call records only the variables it captures."""

  def test_fn(a):
    b = 0
    c = 1
    foo(a, b)  # pylint:disable=undefined-variable
    return c

  tree = access.resolve(parser.parse_object(test_fn))
  call = tree.body[0].body[2].value
  args_scope = anno.getanno(call, 'args_scope')

  # The arguments read `a` and `b` and neither modify nor create anything.
  self.assertItemsEqual(['a', 'b'], args_scope.used)
  self.assertItemsEqual([], args_scope.modified)
  self.assertItemsEqual([], args_scope.created)
def test_class_members_in_with_stmt(self):
  """Objects bound by `with ... as` keep their type for later member calls."""

  def test_fn(x):
    with session.Session() as sess:
      sess.run(x)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'session': session}, {})
  node = type_info.resolve(node, {})

  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  constructor_call = node.body[0].body[0].items[0].context_expr
  self.assertEqual(session.Session, anno.getanno(constructor_call, 'type'))
  self.assertEqual((session.__name__, 'Session'),
                   anno.getanno(constructor_call, 'type_fqn'))

  member_call = node.body[0].body[0].body[0].value.func
  self.assertEqual((session.__name__, 'Session'),
                   anno.getanno(member_call, 'type_fqn'))
def test_class_members_in_with_stmt(self):
  """Objects bound by `with ... as` keep their type for later member calls."""

  def test_fn(x):
    with session.Session() as sess:
      sess.run(x)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'session': session}, {})
  node = type_info.resolve(node, None)

  # assertEqual: `assertEquals` is a deprecated alias removed in Python 3.12.
  constructor_call = node.body[0].body[0].items[0].context_expr
  self.assertEqual(session.Session, anno.getanno(constructor_call, 'type'))
  self.assertEqual((session.__name__, 'Session'),
                   anno.getanno(constructor_call, 'type_fqn'))

  member_call = node.body[0].body[0].body[0].value.func
  self.assertEqual((session.__name__, 'Session'),
                   anno.getanno(member_call, 'type_fqn'))
def _should_compile(self, node, fqn): for i in range(1, len(fqn)): if fqn[:i] in self.uncompiled_modules: return False # Check for local decorations if anno.hasanno(node, 'graph_ready'): return False # The decorators themselves are not to be converted. # If present, the decorators should appear as static functions. target_obj = self._try_resolve_target(node.func) if target_obj is not None: # This attribute is set by the decorator itself. # TODO(mdan): This may not play nicely with other wrapping decorators. if hasattr(target_obj, '__pyct_is_compile_decorator'): return False if target_obj in self.nocompile_decorators: return False # Inspect the target function decorators. If any include a @convert # or @graph_ready annotation, then they must be called as they are. # TODO(mdan): This may be quite heavy. # To parse and re-analize each function for every call site could be quite # wasteful. Maybe we could cache the parsed AST? try: target_node = parser.parse_object(target_obj).body[0] except TypeError: # Functions whose source we cannot access are compilable (e.g. wrapped # to py_func). return True for dec in target_node.decorator_list: decorator_fn = self._resolve_name(dec) if (decorator_fn is not None and decorator_fn in self.nocompile_decorators): return False return True
def test_while(self):
  """Checks the body and parent scopes recorded for a `while` loop."""

  def test_fn(a):
    b = a
    while b > 0:
      c = b
      b -= 1
    return b, c

  tree = access.resolve(parser.parse_object(test_fn))
  loop = tree.body[0].body[1]
  body_scope = anno.getanno(loop, 'body_scope')
  parent_scope = anno.getanno(loop, 'parent_scope')

  for expected, actual in (
      (['b'], body_scope.used),
      (['b', 'c'], body_scope.modified),
      (['c'], body_scope.created),
      (['a', 'b', 'c'], parent_scope.used),
      (['a', 'b', 'c'], parent_scope.modified),
      (['a', 'b', 'c'], parent_scope.created),
  ):
    self.assertItemsEqual(expected, actual)
def test_while(self):
  """Checks the body and parent scopes recorded for a `while` loop."""

  def test_fn(a):
    b = a
    while b > 0:
      c = b
      b -= 1
    return b, c

  tree = access.resolve(parser.parse_object(test_fn))
  loop = tree.body[0].body[1]
  body_scope = anno.getanno(loop, 'body_scope')
  parent_scope = anno.getanno(loop, 'parent_scope')

  for expected, actual in (
      (['b'], body_scope.used),
      (['b', 'c'], body_scope.modified),
      (['c'], body_scope.created),
      (['a', 'b', 'c'], parent_scope.used),
      (['a', 'b', 'c'], parent_scope.modified),
      (['a', 'b', 'c'], parent_scope.created),
  ):
    self.assertItemsEqual(expected, actual)
def _parse_and_analyze(self, test_fn, namespace):
  """Parses `test_fn` and runs access, live-value and type analysis on it.

  Args:
    test_fn: The function to parse.
    namespace: Names visible to the function, used for live-value resolution.

  Returns:
    The annotated AST.
  """
  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, namespace, {})
  return type_info.resolve(tree, {})
def _parse_and_analyze(self, test_fn, namespace):
  """Parses `test_fn` and runs access analysis on it.

  Args:
    test_fn: The function to parse.
    namespace: Not used by this helper.

  Returns:
    The access-annotated AST.
  """
  return access.resolve(parser.parse_object(test_fn))
def test_parse_object(self):
  """Parsing `f` yields a module whose first statement defines `f`."""
  module_node = parser.parse_object(f)
  first_def = module_node.body[0]
  self.assertEqual('f', first_def.name)
def _parse_and_analyze(self, test_fn, namespace):
  """Parses `test_fn` and runs access analysis on it.

  Args:
    test_fn: The function to parse.
    namespace: Not used by this helper.

  Returns:
    The access-annotated AST.
  """
  return access.resolve(parser.parse_object(test_fn))
def _parse_and_analyze(self, test_fn, namespace):
  """Parses `test_fn` and runs access, live-value and type analysis on it.

  Args:
    test_fn: The function to parse.
    namespace: Names visible to the function, used for live-value resolution.

  Returns:
    The annotated AST.
  """
  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, namespace, {})
  return type_info.resolve(tree, None)