def node_to_graph(node, ctx, nocompile_decorators):
  """Convert Python code to equivalent TF graph mode code.

  Runs a fixed sequence of AST transformation passes over `node`. The order
  of the passes matters: some passes are sequenced specifically around what
  earlier ones produce. Static analysis is also re-run after passes that may
  introduce new symbols or change scopes.

  Side effects on `ctx`:
    * `ctx.source_code` is set to None once the first analysis pass is done.
    * `ctx.namespace['len']` is bound to the builtin `len`.

  Args:
    node: A Python AST node representing the code to convert.
    ctx: An EntityContext object. It is mutated as described above.
    nocompile_decorators: A tuple containing decorators to be stripped from
      functions during conversion. It is passed to both
      `decorators.transform` and `call_trees.transform`.

  Returns:
    A tuple (node, deps):
        * node: A Python ast node, representing the converted code.
        * deps: A set of strings, the fully qualified names of entity
            dependencies that this node has. This is the value produced by
            `decorators.transform` and is returned unchanged.
  """
  # TODO(mdan): Verify arguments for correctness.
  # TODO(mdan): Factor out common elements.
  # These include:
  #   * code move between blocks
  #   * visiting blocks in transformers

  # Certain steps, especially canonicalization, insert new symbols into the
  # tree, which must be accounted for. Although less efficient, it is most
  # robust to re-run the analysis.
  node = _static_analysis_pass(node, ctx)

  # Past this point, line numbers are no longer accurate so we ignore the
  # source.
  # TODO(mdan): Is it feasible to reconstruct intermediate source code?
  ctx.source_code = None
  # `deps` is captured here and handed back to the caller untouched at the end.
  node, deps = decorators.transform(node, nocompile_decorators)
  node = break_statements.transform(node, ctx)
  node = asserts.transform(node, ctx)
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = continue_statements.transform(node, ctx)
  # Make the builtin `len` resolvable from the converted code's namespace.
  # NOTE(review): presumably required by code that for_loops generates —
  # confirm against for_loops.transform.
  ctx.namespace['len'] = len

  node = _static_analysis_pass(node, ctx)
  node = for_loops.transform(node, ctx)
  # for_loops may insert new global references.
  node = builtin_functions.transform(node, ctx)

  node = _static_analysis_pass(node, ctx)
  node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES,
                              nocompile_decorators)
  node = control_flow.transform(node, ctx)

  # control_flow may create new symbols and change scopes.
  node = _static_analysis_pass(node, ctx)
  node = logical_expressions.transform(node)
  node = side_effect_guards.transform(node, ctx)
  node = name_scopes.transform(node, ctx)

  return node, deps
def test_basic_name(self):
  """Ops built by a converted function carry that function's name scope."""

  def test_fn(l):
    a = 5
    l += a
    return l

  converted = name_scopes.transform(
      self.parse_and_analyze(test_fn, {}), self.ctx)
  with self.compiled(converted, ops.name_scope) as module:
    output = module.test_fn(constant_op.constant([1, 2, 3]))
    self.assertIn('test_fn/', output.op.name)
def test_nested_name(self):
  """A nested function's ops are scoped under both outer and inner names."""

  def test_fn(l):

    def body(i):
      return i**2

    l += [4]
    return body(l)

  converted = name_scopes.transform(
      self.parse_and_analyze(test_fn, {}), self.ctx)
  with self.compiled(converted, ops.name_scope) as module:
    output = module.test_fn(constant_op.constant([1, 2, 3]))
    op_inputs = output.op.inputs
    # Input 0 is expected to live in test_fn's scope but outside `body`;
    # input 1 is expected inside the nested test_fn/body/ scope (presumably
    # the updated `l` and the exponent created in `body`, respectively).
    outer_input_name = op_inputs[0].name
    inner_input_name = op_inputs[1].name
    self.assertIn('test_fn/', outer_input_name)
    self.assertNotIn('body/', outer_input_name)
    self.assertIn('test_fn/body/', inner_input_name)
def test_class_name(self):
  """Ops built inside a method are scoped under the class and method names.

  This test is defined exactly once. A second `def test_class_name` in the
  same class body would rebind the name and silently shadow this one.
  """

  class TestClass(object):

    def test_fn(self, l):

      def body(i):
        return i**2

      l += [4]
      return body(l)

  # Note that 'TestClass' was needed in the namespace here.
  node = self.parse_and_analyze(
      TestClass, {'TestClass': TestClass}, owner_type=TestClass)
  node = name_scopes.transform(node, self.ctx)
  with self.compiled(node, ops.name_scope) as result:
    result_op = result.TestClass().test_fn(constant_op.constant([1, 2, 3]))
    first_result_input_name = result_op.op.inputs[0].name
    second_result_input_name = result_op.op.inputs[1].name
    self.assertIn('TestClass/test_fn/', first_result_input_name)
    self.assertNotIn('body/', first_result_input_name)
    self.assertIn('TestClass/test_fn/body/', second_result_input_name)