def test_if(self):
  """Checks the scope annotations attached to both branches of an if/else."""

  def test_fn(x):
    if x > 0:
      x = -x
      y = 2 * x
      z = -y
    else:
      x = 2 * x
      y = -x
      u = -y
    return z, u

  node = parser.parse_object(test_fn)
  node = access.resolve(node)

  if_node = node.body[0].body[0]
  self.assertScopeIs(
      anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'),
      ('y', 'z'))
  # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
  self.assertScopeIs(
      anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'),
      ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
  self.assertScopeIs(
      anno.getanno(if_node, 'orelse_scope'), ('x', 'y'), ('x', 'y', 'u'),
      ('y', 'u'))
  # The else branch's parent scope is checked separately from the body's; the
  # expected symbol sets are the same as for 'body_parent_scope' above.
  self.assertScopeIs(
      anno.getanno(if_node, 'orelse_parent_scope'), ('x', 'z', 'u'),
      ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
def node_to_graph(node, namer, namespace, value_hints):
  """Convert Python code to equivalent TF graph mode code.

  Args:
    node: A Python AST node representing the code to convert.
    namer: A naming.Namer object.
    namespace: Dict mapping symbol names to their corresponding live objects.
    value_hints: A dict containing value hints for symbols like function
        parameters.

  Returns:
    The converted Python AST node. Only the node is returned; this function
    does not compute or return a set of dependencies.
  """
  # TODO(mdan): Get rid of this.
  node = gradients_function.transform(node)

  # Analysis passes: annotate the AST before the transformations below.
  node = access.resolve(node)
  node = live_values.resolve(node, namespace, config.PYTHON_LITERALS)
  node = type_info.resolve(node, value_hints)

  # TODO(mdan): Factor out common elements.
  # These include:
  #   * keeping track of symbols that have been created
  #   * marking nodes (e.g. py_func wrappers) to suppress further processing
  # The transformation order below is significant; keep it as is.
  node = print_functions.transform(node)
  node = call_trees.transform(node, namer, config.DEFAULT_UNCOMPILED_MODULES)
  node = control_flow.transform(node, namer)
  node = logical_expressions.transform(node)
  node = side_effect_guards.transform(node, namer)
  return node
def node_to_graph(node, namer, namespace, value_hints):
  """Convert Python code to equivalent TF graph mode code.

  Args:
    node: A Python AST node representing the code to convert.
    namer: A naming.Namer object.
    namespace: Dict mapping symbol names to their corresponding live objects.
    value_hints: A dict containing value hints for symbols like function
        parameters.

  Returns:
    The converted Python AST node. Only the node is returned; this function
    does not compute or return a set of dependencies.
  """
  # Analysis passes: annotate the AST before the transformations below.
  node = access.resolve(node)
  node = live_values.resolve(node, namespace, config.PYTHON_LITERALS)
  node = type_info.resolve(node, value_hints)

  # TODO(mdan): Factor out common elements.
  # These include:
  #   * keeping track of symbols that have been created
  #   * marking nodes (e.g. py_func wrappers) to suppress further processing
  # The transformation order below is significant; keep it as is.
  node = print_functions.transform(node)
  node = call_trees.transform(node, namer, config.DEFAULT_UNCOMPILED_MODULES)
  node = control_flow.transform(node, namer)
  node = logical_expressions.transform(node)
  node = side_effect_guards.transform(node, namer)
  return node
def test_print_statement(self):
  """The print's argument scope reads `a` and `b` and writes nothing."""

  def test_fn(a):
    b = 0
    c = 1
    print(a, b)
    return c

  fn_tree = access.resolve(parser.parse_object(test_fn))
  stmt = fn_tree.body[0].body[2]

  if not isinstance(stmt, gast.Print):
    # Python 3: print is a call wrapped in an expression statement, and the
    # call node is the one carrying the annotation.
    assert isinstance(stmt, gast.Expr)
    stmt = stmt.value
  # Python 2: the print statement itself carries the annotation.
  args_scope = anno.getanno(stmt, 'args_scope')

  # Detect which variables are captured by the call arguments.
  self.assertItemsEqual(['a', 'b'], args_scope.used)
  self.assertItemsEqual([], args_scope.modified)
  self.assertItemsEqual([], args_scope.created)
def test_if(self):
  """Checks the scope annotations attached to both branches of an if/else."""

  def test_fn(x):
    if x > 0:
      x = -x
      y = 2 * x
      z = -y
    else:
      x = 2 * x
      y = -x
      u = -y
    return z, u

  node = parser.parse_object(test_fn)
  node = access.resolve(node)

  if_node = node.body[0].body[0]
  self.assertScopeIs(
      anno.getanno(if_node, 'body_scope'), ('x', 'y'), ('x', 'y', 'z'),
      ('y', 'z'))
  # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
  self.assertScopeIs(
      anno.getanno(if_node, 'body_parent_scope'), ('x', 'z', 'u'),
      ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
  self.assertScopeIs(
      anno.getanno(if_node, 'orelse_scope'), ('x', 'y'), ('x', 'y', 'u'),
      ('y', 'u'))
  # The else branch's parent scope is checked separately from the body's; the
  # expected symbol sets are the same as for 'body_parent_scope' above.
  self.assertScopeIs(
      anno.getanno(if_node, 'orelse_parent_scope'), ('x', 'z', 'u'),
      ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
def test_print_statement(self):
  """Only the names passed to print show up in its argument scope."""

  def test_fn(a):
    b = 0
    c = 1
    print(a, b)
    return c

  annotated = access.resolve(parser.parse_object(test_fn))
  target = annotated.body[0].body[2]

  if not isinstance(target, gast.Print):
    # Python 3: unwrap the expression statement; the call is annotated.
    assert isinstance(target, gast.Expr)
    target = target.value
  # (Python 2: the print statement node is annotated directly.)
  captured = anno.getanno(target, 'args_scope')

  # Variables captured by the call arguments: read-only use of a and b.
  self.assertItemsEqual(['a', 'b'], captured.used)
  self.assertItemsEqual([], captured.modified)
  self.assertItemsEqual([], captured.created)
def node_to_graph(node, namer, namespace, value_hints):
  """Convert Python code to equivalent TF graph mode code.

  Args:
    node: A Python AST node representing the code to convert.
    namer: A naming.Namer object.
    namespace: Dict mapping symbol names to their corresponding live objects.
        Note: this dict is modified in place ('len' and 'print' are added).
    value_hints: A dict containing value hints for symbols like function
        parameters.

  Returns:
    The converted Python AST node. Only the node is returned; this function
    does not compute or return a set of dependencies.
  """
  # Analysis passes: annotate the AST before the transformations below.
  node = access.resolve(node)
  node = live_values.resolve(node, namespace, config.PYTHON_LITERALS)
  node = type_info.resolve(node, value_hints)

  # TODO(mdan): Factor out common elements.
  # These include:
  #   * keeping track of symbols that have been created
  #   * marking nodes (e.g. py_func wrappers) to suppress further processing
  node = for_canonicalization.transform(node, namer)
  node = builtin_functions.transform(node)

  # The transformation steps above insert new variables. Although less
  # efficient, it is most robust to re-run the analysis.
  # We also need to ensure the namespace contains any new references that may
  # have been created.
  namespace['len'] = len
  namespace['print'] = print
  node = access.resolve(node)
  node = live_values.resolve(node, namespace, config.PYTHON_LITERALS)
  node = type_info.resolve(node, value_hints)

  # The transformation order below is significant; keep it as is.
  node = print_functions.transform(node)
  node = call_trees.transform(node, namer, config.DEFAULT_UNCOMPILED_MODULES)
  node = control_flow.transform(node, namer)
  node = logical_expressions.transform(node)
  node = side_effect_guards.transform(node, namer)
  return node
def test_parameter_class_members(self):
  """A method call on a parameter without value hints raises ValueError."""

  def test_fn(opt):
    opt.minimize(0)

  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, {'training': training}, {})
  with self.assertRaises(ValueError):
    type_info.resolve(tree, None)
def test_nested_members(self):
  """Resolving a chained member call (foo.bar.baz()) raises ValueError."""

  def test_fn():
    foo = training.GradientDescentOptimizer(0.1)
    foo.bar.baz()

  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, {'training': training}, {})
  with self.assertRaises(ValueError):
    type_info.resolve(tree, None)
def test_literals(self):
  """A name supplied via the literals map resolves to that literal value."""

  def test_fn():
    return Foo  # pylint: disable=undefined-variable

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {}, {'Foo': 'bar'})

  retval_node = node.body[0].body[0].value
  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual('bar', anno.getanno(retval_node, 'live_val'))
def test_parameter_class_members(self):
  """type_info cannot type `opt.minimize` when `opt` has no value hint."""

  def test_fn(opt):
    opt.minimize(0)

  analyzed = parser.parse_object(test_fn)
  analyzed = access.resolve(analyzed)
  analyzed = live_values.resolve(analyzed, {'training': training}, {})

  # The ValueError is the expected outcome; the result is never used.
  with self.assertRaises(ValueError):
    type_info.resolve(analyzed, None)
def test_nested_members(self):
  """A nested attribute call on a constructed object is rejected."""

  def test_fn():
    foo = training.GradientDescentOptimizer(0.1)
    foo.bar.baz()

  analyzed = parser.parse_object(test_fn)
  analyzed = access.resolve(analyzed)
  analyzed = live_values.resolve(analyzed, {'training': training}, {})

  # The ValueError is the expected outcome; the result is never used.
  with self.assertRaises(ValueError):
    type_info.resolve(analyzed, None)
def _parse_and_analyze(self, test_fn):
  """Parses `test_fn` and runs only the access analysis over it.

  Args:
    test_fn: the Python function to parse.

  Returns:
    The AST returned by `access.resolve`.
  """
  tree, src = parser.parse_entity(test_fn)
  entity_ctx = context.EntityContext(
      namer=None,
      source_code=src,
      source_file=None,
      namespace={},
      arg_values=None,
      arg_types=None,
      recursive=True)
  return access.resolve(tree, entity_ctx)
def test_literals(self):
  """A name supplied via the literals map resolves to that literal value."""

  def test_fn():
    return Foo  # pylint: disable=undefined-variable

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {}, {'Foo': 'bar'})

  retval_node = node.body[0].body[0].value
  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual('bar', anno.getanno(retval_node, 'live_val'))
def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
  """Parses `test_fn` and runs access, live-value and type-info analysis.

  Args:
    test_fn: the Python function to parse.
    namespace: dict handed to `live_values.resolve` and stored on the context.
    arg_types: optional argument types stored on the context.

  Returns:
    The AST returned by `type_info.resolve`.
  """
  entity_ctx = context.EntityContext(
      namer=None,
      source_code=None,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types)
  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, namespace, {})
  return type_info.resolve(tree, entity_ctx)
def test_class_members(self):
  """A method call on a constructed object is annotated with its type FQN."""

  def test_fn():
    opt = training.GradientDescentOptimizer(0.1)
    opt.minimize(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'training': training}, {})
  node = type_info.resolve(node, None)

  attr_call_node = node.body[0].body[1].value.func
  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                   anno.getanno(attr_call_node, 'type_fqn'))
def test_function_variables(self):
  """Calling a function through a local alias makes type_info raise."""

  def bar():
    pass

  def test_fn():
    foo = bar
    foo()

  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, {'bar': bar}, {})
  with self.assertRaises(ValueError):
    type_info.resolve(tree, None)
def test_attribute_names(self):
  """A module attribute call resolves to its live value and qualified name."""

  def test_fn():
    return constant_op.constant(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'constant_op': constant_op}, {})

  func_node = node.body[0].body[0].value.func
  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
  self.assertEqual((constant_op.__name__, 'constant'),
                   anno.getanno(func_node, 'fqn'))
def test_constructor_deta_dependent(self):
  """An object constructed in both branches of an `if` cannot be typed.

  NOTE(review): "deta" is presumably a typo for "data". The name is kept
  because renaming would change the test's identifier.
  """

  def test_fn(x):
    if x > 0:
      opt = training.GradientDescentOptimizer(0.1)
    else:
      opt = training.GradientDescentOptimizer(0.01)
    opt.minimize(0)

  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, {'training': training}, {})
  with self.assertRaises(ValueError):
    type_info.resolve(tree, None)
def test_attribute_names(self):
  """A module attribute call resolves to its live value and qualified name."""

  def test_fn():
    return constant_op.constant(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'constant_op': constant_op}, {})

  func_node = node.body[0].body[0].value.func
  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
  self.assertEqual((constant_op.__name__, 'constant'),
                   anno.getanno(func_node, 'fqn'))
def test_namespace(self):
  """A name found in the namespace dict resolves to that object."""

  def foo():
    return 'bar'

  def test_fn():
    return foo()

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'foo': foo}, {})

  func_node = node.body[0].body[0].value.func
  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
  self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))
def test_class_members(self):
  """A method call on a constructed object is annotated with its type FQN."""

  def test_fn():
    opt = training.GradientDescentOptimizer(0.1)
    opt.minimize(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'training': training}, {})
  node = type_info.resolve(node, None)

  attr_call_node = node.body[0].body[1].value.func
  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                   anno.getanno(attr_call_node, 'type_fqn'))
def test_constructor_deta_dependent(self):
  """Branch-dependent construction of `opt` makes type_info raise.

  NOTE(review): "deta" is presumably a typo for "data". The name is kept
  because renaming would change the test's identifier.
  """

  def test_fn(x):
    if x > 0:
      opt = training.GradientDescentOptimizer(0.1)
    else:
      opt = training.GradientDescentOptimizer(0.01)
    opt.minimize(0)

  analyzed = parser.parse_object(test_fn)
  analyzed = access.resolve(analyzed)
  analyzed = live_values.resolve(analyzed, {'training': training}, {})
  with self.assertRaises(ValueError):
    type_info.resolve(analyzed, {})
def test_function_variables(self):
  """A local variable bound to a function cannot be typed by type_info."""

  def bar():
    pass

  def test_fn():
    foo = bar
    foo()

  analyzed = parser.parse_object(test_fn)
  analyzed = access.resolve(analyzed)
  analyzed = live_values.resolve(analyzed, {'bar': bar}, {})

  # The ValueError is the expected outcome; the result is never used.
  with self.assertRaises(ValueError):
    type_info.resolve(analyzed, None)
def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
  """Parses `test_fn` and runs the context-based analysis passes over it.

  Args:
    test_fn: the Python function to parse.
    namespace: dict stored on the EntityContext.
    arg_types: optional argument types stored on the EntityContext.

  Returns:
    The AST after access, live-value, type-info and a second live-value pass.
  """
  tree, src = parser.parse_entity(test_fn)
  entity_ctx = context.EntityContext(
      namer=None,
      source_code=src,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      recursive=True)
  tree = access.resolve(tree, entity_ctx)
  tree = live_values.resolve(tree, entity_ctx, {})
  tree = type_info.resolve(tree, entity_ctx)
  # Live values are resolved again after type_info, presumably so they can
  # use the type annotations it adds.
  return live_values.resolve(tree, entity_ctx, {})
def test_namespace(self):
  """A name found in the namespace dict resolves to that object."""

  def foo():
    return 'bar'

  def test_fn():
    return foo()

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'foo': foo}, {})

  func_node = node.body[0].body[0].value.func
  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
  self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))
def test_call(self):
  """The call's argument scope reads only the names passed as arguments."""

  def test_fn(a):
    b = 0
    c = 1
    foo(a, b)  # pylint:disable=undefined-variable
    return c

  tree = access.resolve(parser.parse_object(test_fn))
  call = tree.body[0].body[2].value

  # Detect which variables are captured by the call arguments.
  args_scope = anno.getanno(call, 'args_scope')
  self.assertScopeIs(args_scope, ('a', 'b'), (), ())
def test_constructor_detection(self):
  """A constructor call is annotated with the constructed type and its FQN."""

  def test_fn():
    opt = training.GradientDescentOptimizer(0.1)
    return opt

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'training': training}, {})
  node = type_info.resolve(node, None)

  call_node = node.body[0].body[0].value
  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(training.GradientDescentOptimizer,
                   anno.getanno(call_node, 'type'))
  self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                   anno.getanno(call_node, 'type_fqn'))
def test_call(self):
  """Only `a` and `b` are read by the call; nothing is written."""

  def test_fn(a):
    b = 0
    c = 1
    foo(a, b)  # pylint:disable=undefined-variable
    return c

  analyzed = parser.parse_object(test_fn)
  analyzed = access.resolve(analyzed)
  foo_call = analyzed.body[0].body[2].value

  # Variables captured by the call arguments.
  captured = anno.getanno(foo_call, 'args_scope')
  self.assertScopeIs(captured, ('a', 'b'), (), ())
def test_constructor_detection(self):
  """A constructor call is annotated with the constructed type and its FQN."""

  def test_fn():
    opt = training.GradientDescentOptimizer(0.1)
    return opt

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'training': training}, {})
  node = type_info.resolve(node, None)

  call_node = node.body[0].body[0].value
  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(training.GradientDescentOptimizer,
                   anno.getanno(call_node, 'type'))
  self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                   anno.getanno(call_node, 'type_fqn'))
def test_while(self):
  """Checks the body and parent scope annotations of a while loop."""

  def test_fn(a):
    b = a
    while b > 0:
      c = b
      b -= 1
    return b, c

  tree = access.resolve(parser.parse_object(test_fn))
  loop = tree.body[0].body[1]

  body_scope = anno.getanno(loop, 'body_scope')
  self.assertScopeIs(body_scope, ('b',), ('b', 'c'), ('c',))

  parent_scope = anno.getanno(loop, 'body_parent_scope')
  self.assertScopeIs(parent_scope, ('a', 'b', 'c'), ('a', 'b', 'c'),
                     ('a', 'b', 'c'))
def test_parameter_class_members_with_value_hints(self):
  """A value hint for a parameter lets its method call be typed."""

  def test_fn(opt):
    opt.minimize(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'training': training}, {})
  node = type_info.resolve(
      node, {
          'opt': (('%s.GradientDescentOptimizer' % training.__name__),
                  training.GradientDescentOptimizer(0.1))
      })

  attr_call_node = node.body[0].body[0].value.func
  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(
      training.__name__.split('.') + ['GradientDescentOptimizer'],
      anno.getanno(attr_call_node, 'type_fqn'))
def test_local_markers(self):
  """Checks the 'is_local' annotation on name reads around a loop."""

  def test_fn(a):  # pylint:disable=unused-argument
    b = c  # pylint:disable=undefined-variable
    while b > 0:
      b -= 1
    return b

  stmts = access.resolve(parser.parse_object(test_fn)).body[0].body
  # `c` in `b = c` is not local.
  self.assertFalse(anno.getanno(stmts[0].value, 'is_local'))
  # `b` in `b > 0` is local.
  self.assertTrue(anno.getanno(stmts[1].test.left, 'is_local'))
  # `b` in `return b` is local.
  self.assertTrue(anno.getanno(stmts[2].value, 'is_local'))
def test_local_markers(self):
  """Names bound in the function are marked local; free names are not."""

  def test_fn(a):  # pylint:disable=unused-argument
    b = c  # pylint:disable=undefined-variable
    while b > 0:
      b -= 1
    return b

  analyzed = parser.parse_object(test_fn)
  analyzed = access.resolve(analyzed)
  fn_body = analyzed.body[0].body

  assign_rhs = fn_body[0].value  # c in b = c
  loop_test_lhs = fn_body[1].test.left  # b in b > 0
  returned = fn_body[2].value  # b in return b

  self.assertFalse(anno.getanno(assign_rhs, 'is_local'))
  self.assertTrue(anno.getanno(loop_test_lhs, 'is_local'))
  self.assertTrue(anno.getanno(returned, 'is_local'))
def parse_and_analyze(self,
                      test_fn,
                      namespace,
                      arg_types=None,
                      include_type_analysis=True):
  """Parses `test_fn` and runs the static analysis passes over it.

  Args:
    test_fn: the Python function to parse.
    namespace: dict handed to `live_values.resolve` and stored on the context.
    arg_types: optional argument types stored on the EntityContext.
    include_type_analysis: when False, `type_info.resolve` is skipped.

  Returns:
    The annotated AST.
  """
  entity_ctx = context.EntityContext(
      namer=None,
      source_code=None,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types)
  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, namespace, {})
  if not include_type_analysis:
    return tree
  return type_info.resolve(tree, entity_ctx)
def test_for(self):
  """Checks the body and parent scope annotations of a for loop."""

  def test_fn(a):
    b = a
    for _ in a:
      c = b
      b -= 1
    return b, c

  tree = access.resolve(parser.parse_object(test_fn))
  loop = tree.body[0].body[1]

  body_scope = anno.getanno(loop, 'body_scope')
  self.assertScopeIs(body_scope, ('b',), ('b', 'c'), ('c',))

  # The loop target `_` is written in the enclosing scope.
  parent_scope = anno.getanno(loop, 'body_parent_scope')
  self.assertScopeIs(parent_scope, ('a', 'b', 'c'), ('a', 'b', 'c', '_'),
                     ('a', 'b', 'c', '_'))
def test_call(self):
  """The call's argument scope reads `a` and `b` and writes nothing."""

  def test_fn(a):
    b = 0
    c = 1
    foo(a, b)  # pylint:disable=undefined-variable
    return c

  tree = access.resolve(parser.parse_object(test_fn))
  args_scope = anno.getanno(tree.body[0].body[2].value, 'args_scope')

  # Detect which variables are captured by the call arguments.
  expected = (('used', ['a', 'b']), ('modified', []), ('created', []))
  for attr, names in expected:
    self.assertItemsEqual(names, getattr(args_scope, attr))
def test_parameter_class_members_with_value_hints(self):
  """A value hint for a parameter lets its method call be typed."""

  def test_fn(opt):
    opt.minimize(0)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'training': training}, {})
  node = type_info.resolve(
      node, {
          'opt': (('%s.GradientDescentOptimizer' % training.__name__),
                  training.GradientDescentOptimizer(0.1))
      })

  attr_call_node = node.body[0].body[0].value.func
  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  self.assertEqual(
      training.__name__.split('.') + ['GradientDescentOptimizer'],
      anno.getanno(attr_call_node, 'type_fqn'))
def test_call(self):
  """Only the arguments of `foo(a, b)` are captured by its argument scope."""

  def test_fn(a):
    b = 0
    c = 1
    foo(a, b)  # pylint:disable=undefined-variable
    return c

  analyzed = parser.parse_object(test_fn)
  analyzed = access.resolve(analyzed)
  foo_call = analyzed.body[0].body[2].value
  captured = anno.getanno(foo_call, 'args_scope')

  # The arguments are read but never modified or created.
  self.assertItemsEqual(['a', 'b'], captured.used)
  self.assertItemsEqual([], captured.modified)
  self.assertItemsEqual([], captured.created)
def test_class_members_in_with_stmt(self):
  """Types propagate from a `with` context expression to its bound name."""

  def test_fn(x):
    with session.Session() as sess:
      sess.run(x)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'session': session}, {})
  node = type_info.resolve(node, {})

  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  constructor_call = node.body[0].body[0].items[0].context_expr
  self.assertEqual(session.Session, anno.getanno(constructor_call, 'type'))
  self.assertEqual((session.__name__, 'Session'),
                   anno.getanno(constructor_call, 'type_fqn'))

  member_call = node.body[0].body[0].body[0].value.func
  self.assertEqual((session.__name__, 'Session'),
                   anno.getanno(member_call, 'type_fqn'))
def test_class_members_in_with_stmt(self):
  """Types propagate from a `with` context expression to its bound name."""

  def test_fn(x):
    with session.Session() as sess:
      sess.run(x)

  node = parser.parse_object(test_fn)
  node = access.resolve(node)
  node = live_values.resolve(node, {'session': session}, {})
  node = type_info.resolve(node, None)

  # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
  constructor_call = node.body[0].body[0].items[0].context_expr
  self.assertEqual(session.Session, anno.getanno(constructor_call, 'type'))
  self.assertEqual((session.__name__, 'Session'),
                   anno.getanno(constructor_call, 'type_fqn'))

  member_call = node.body[0].body[0].body[0].value.func
  self.assertEqual((session.__name__, 'Session'),
                   anno.getanno(member_call, 'type_fqn'))
def _parse_and_analyze(self,
                       test_fn,
                       namespace,
                       literals=None,
                       arg_types=None):
  """Parses `test_fn` and runs the context-based analysis passes over it.

  Args:
    test_fn: the Python function to parse.
    namespace: dict stored on the EntityContext.
    literals: optional literal mapping passed to `live_values.resolve`;
        falsy values are replaced by an empty dict.
    arg_types: optional argument types; falsy values become an empty dict.

  Returns:
    The AST after access, live-value, type-info and a second live-value pass.
  """
  if not literals:
    literals = {}
  if not arg_types:
    arg_types = {}
  tree, src = parser.parse_entity(test_fn)
  entity_ctx = context.EntityContext(
      namer=None,
      source_code=src,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      recursive=True)
  tree = access.resolve(tree, entity_ctx)
  tree = live_values.resolve(tree, entity_ctx, literals)
  tree = type_info.resolve(tree, entity_ctx)
  return live_values.resolve(tree, entity_ctx, literals)
def test_while(self):
  """Checks the body and parent scope annotations of a while loop."""

  def test_fn(a):
    b = a
    while b > 0:
      c = b
      b -= 1
    return b, c

  tree = access.resolve(parser.parse_object(test_fn))
  loop = tree.body[0].body[1]
  body_scope = anno.getanno(loop, 'body_scope')
  parent_scope = anno.getanno(loop, 'parent_scope')

  expectations = (
      (body_scope, 'used', ['b']),
      (body_scope, 'modified', ['b', 'c']),
      (body_scope, 'created', ['c']),
      (parent_scope, 'used', ['a', 'b', 'c']),
      (parent_scope, 'modified', ['a', 'b', 'c']),
      (parent_scope, 'created', ['a', 'b', 'c']),
  )
  for scope, attr, names in expectations:
    self.assertItemsEqual(names, getattr(scope, attr))
def parse_and_analyze(self,
                      test_fn,
                      namespace,
                      namer=None,
                      arg_types=None,
                      include_type_analysis=True,
                      recursive=True):
  """Parses `test_fn`, runs the analysis passes, and records the context.

  Args:
    test_fn: the Python function to parse.
    namespace: dict stored on the EntityContext.
    namer: optional namer stored on the EntityContext.
    arg_types: optional argument types stored on the EntityContext.
    include_type_analysis: when True, runs `type_info.resolve` followed by a
        second `live_values.resolve`.
    recursive: stored on the EntityContext.

  Returns:
    The annotated AST. The context used is also saved as `self.ctx`.
  """
  tree, src = parser.parse_entity(test_fn)
  entity_ctx = context.EntityContext(
      namer=namer,
      source_code=src,
      source_file=None,
      namespace=namespace,
      arg_values=None,
      arg_types=arg_types,
      recursive=recursive)
  tree = live_values.resolve(access.resolve(tree, entity_ctx), entity_ctx, {})
  if include_type_analysis:
    tree = type_info.resolve(tree, entity_ctx)
    tree = live_values.resolve(tree, entity_ctx, {})
  self.ctx = entity_ctx
  return tree
def test_while(self):
  """The while body and its parent scope record the expected symbols."""

  def test_fn(a):
    b = a
    while b > 0:
      c = b
      b -= 1
    return b, c

  analyzed = parser.parse_object(test_fn)
  analyzed = access.resolve(analyzed)
  while_stmt = analyzed.body[0].body[1]

  inner = anno.getanno(while_stmt, 'body_scope')
  outer = anno.getanno(while_stmt, 'parent_scope')

  self.assertItemsEqual(['b'], inner.used)
  self.assertItemsEqual(['b', 'c'], inner.modified)
  self.assertItemsEqual(['c'], inner.created)

  self.assertItemsEqual(['a', 'b', 'c'], outer.used)
  self.assertItemsEqual(['a', 'b', 'c'], outer.modified)
  self.assertItemsEqual(['a', 'b', 'c'], outer.created)
def _parse_and_analyze(self, test_fn, namespace):
  """Parses `test_fn` and runs access, live-value and type-info analysis.

  Args:
    test_fn: the Python function to parse.
    namespace: dict passed to `live_values.resolve`.

  Returns:
    The AST returned by `type_info.resolve` (called with empty hints).
  """
  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, namespace, {})
  return type_info.resolve(tree, {})
def _parse_and_analyze(self, test_fn, namespace):
  """Parses `test_fn` and runs only the access analysis over it.

  Args:
    test_fn: the Python function to parse.
    namespace: accepted for signature compatibility; not used here.

  Returns:
    The AST returned by `access.resolve`.
  """
  del namespace  # Unused.
  return access.resolve(parser.parse_object(test_fn))
def _parse_and_analyze(self, test_fn, namespace):
  """Parses `test_fn` and annotates it with access information.

  Args:
    test_fn: the Python function to parse.
    namespace: accepted for signature compatibility; not used here.

  Returns:
    The AST returned by `access.resolve`.
  """
  del namespace  # Unused.
  parsed = parser.parse_object(test_fn)
  return access.resolve(parsed)
def _static_analysis_pass(node, source, f, namespace, value_hints):
  """Runs access, live-value and type-info analysis over `node`, in order.

  Args:
    node: the AST to analyze.
    source: forwarded to `type_info.resolve`.
    f: forwarded to `type_info.resolve`.
    namespace: dict passed to `live_values.resolve`.
    value_hints: forwarded to `type_info.resolve`.

  Returns:
    The annotated AST.
  """
  analyzed = access.resolve(node)
  analyzed = live_values.resolve(analyzed, namespace, config.PYTHON_LITERALS)
  return type_info.resolve(analyzed, source, f, value_hints)
def _parse_and_analyze(self, test_fn, namespace):
  """Parses `test_fn` and runs access, live-value and type-info analysis.

  Args:
    test_fn: the Python function to parse.
    namespace: dict passed to `live_values.resolve`.

  Returns:
    The AST returned by `type_info.resolve` (called with None).
  """
  tree = access.resolve(parser.parse_object(test_fn))
  tree = live_values.resolve(tree, namespace, {})
  return type_info.resolve(tree, None)
def _static_analysis_pass(node, ctx):
  """Runs access, live-value and type-info analysis over `node`, in order.

  Args:
    node: the AST to analyze.
    ctx: context object; its `namespace` feeds `live_values.resolve` and the
        whole object is passed to `type_info.resolve`.

  Returns:
    The annotated AST.
  """
  analyzed = access.resolve(node)
  analyzed = live_values.resolve(
      analyzed, ctx.namespace, config.PYTHON_LITERALS)
  return type_info.resolve(analyzed, ctx)