def test_for_variable_defined_in_body(self):
    """Transforming a for loop that returns a body-only variable must fail.

    `s` is bound only inside the loop body, so control_flow.transform is
    expected to raise AutographParseError.
    """

    def bad_for_loop(n):
        for i in range(n):
            s = i
        return s

    fn_node, fn_ctx = self.prepare(bad_for_loop, {})
    self.assertRaises(transformer.AutographParseError,
                      control_flow.transform, fn_node, fn_ctx)
def test_if_imbalanced_outputs(self):
    """An `if` without `else` whose result is bound only in the branch fails.

    `b` is assigned only when `n > 0`, so control_flow.transform is expected
    to raise AutographParseError.
    """

    def test_fn(n):
        if n > 0:
            b = 4
        return b

    fn_node, fn_ctx = self.prepare(test_fn, {})
    self.assertRaises(transformer.AutographParseError,
                      control_flow.transform, fn_node, fn_ctx)
def test_while_variable_defined_in_body(self):
    """Transforming a while loop that returns a body-only variable must fail.

    `s` is bound only inside the loop body, so control_flow.transform is
    expected to raise AutographParseError.
    """

    def bad_while_loop(n):
        while n > 0:
            n -= 1
            s = n
        return s

    fn_node, fn_ctx = self.prepare(bad_while_loop, {})
    self.assertRaises(transformer.AutographParseError,
                      control_flow.transform, fn_node, fn_ctx)
def test_handle_temp_variable(self):
    """Checks `if`/`else` branches that use a branch-local temporary.

    The first converted function returns both `z` and the reassigned `w`.
    The second returns only `z`, so the temporary's results are discarded.
    """

    def test_fn_using_temp(x, y, w):
        if x < y:
            z = x + y
        else:
            w = 2
            tmp = w
            z = x - tmp
        return z, w

    converted = control_flow.transform(
        self.parse_and_analyze(test_fn_using_temp, {}), self.ctx)
    with self.compiled(converted, control_flow_ops.cond,
                       array_ops.ones) as result:
        with self.test_session() as sess:
            # (x, y, w) inputs -> (expected z, expected w)
            using_cases = (((-3, 3, 3), (0, 3)),
                           ((3, -3, 3), (1, 2)))
            for args, (want_z, want_w) in using_cases:
                tensors = [constant_op.constant(a) for a in args]
                z, w = sess.run(result.test_fn_using_temp(*tensors))
                self.assertEqual(want_z, z)
                self.assertEqual(want_w, w)

    def test_fn_ignoring_temp(x, y, w):
        if x < y:
            z = x + y
        else:
            w = 2
            tmp = w
            z = x - tmp
        return z

    converted = control_flow.transform(
        self.parse_and_analyze(test_fn_ignoring_temp, {}), self.ctx)
    with self.compiled(converted, control_flow_ops.cond,
                       array_ops.ones) as result:
        with self.test_session() as sess:
            # (x, y, w) inputs -> expected z
            ignoring_cases = (((-3, 3, 3), 0),
                              ((3, -3, 3), 1))
            for args, want_z in ignoring_cases:
                tensors = [constant_op.constant(a) for a in args]
                z = sess.run(result.test_fn_ignoring_temp(*tensors))
                self.assertEqual(want_z, z)
def test_handle_temp_variable(self):
    """Checks that a temporary bound in one branch is handled correctly.

    Exercised twice: once where the converted function returns `w` (which
    the `else` branch reassigns), and once where only `z` is returned.
    """

    def test_fn_using_temp(x, y, w):
        if x < y:
            z = x + y
        else:
            w = 2
            tmp = w
            z = x - tmp
        return z, w

    analyzed = self.parse_and_analyze(test_fn_using_temp, {})
    transformed = control_flow.transform(analyzed, self.ctx)
    with self.compiled(transformed, control_flow_ops.cond,
                       array_ops.ones) as result:
        with self.test_session() as sess:
            for (x, y, w_in), expected in (((-3, 3, 3), (0, 3)),
                                           ((3, -3, 3), (1, 2))):
                outputs = sess.run(
                    result.test_fn_using_temp(
                        constant_op.constant(x),
                        constant_op.constant(y),
                        constant_op.constant(w_in)))
                z, w = outputs
                self.assertEqual(expected[0], z)
                self.assertEqual(expected[1], w)

    def test_fn_ignoring_temp(x, y, w):
        if x < y:
            z = x + y
        else:
            w = 2
            tmp = w
            z = x - tmp
        return z

    analyzed = self.parse_and_analyze(test_fn_ignoring_temp, {})
    transformed = control_flow.transform(analyzed, self.ctx)
    with self.compiled(transformed, control_flow_ops.cond,
                       array_ops.ones) as result:
        with self.test_session() as sess:
            for (x, y, w_in), expected_z in (((-3, 3, 3), 0),
                                             ((3, -3, 3), 1)):
                z = sess.run(
                    result.test_fn_ignoring_temp(
                        constant_op.constant(x),
                        constant_op.constant(y),
                        constant_op.constant(w_in)))
                self.assertEqual(expected_z, z)
def test_simple_for(self):
    """Compares a converted two-accumulator for loop with plain Python.

    Runs on a non-empty int tensor and on an empty int32 tensor of shape (0,).
    """

    def test_fn(l):
        s1 = 0
        s2 = 0
        for e in l:
            s1 += e
            s2 += e * e
        return s1, s2

    converted = control_flow.transform(
        self.parse_and_analyze(test_fn, {}), self.ctx)
    with self.compiled(converted) as result, self.test_session() as sess:
        values = [1, 2, 3]
        self.assertEqual(
            test_fn(values),
            sess.run(result.test_fn(constant_op.constant(values))))

        values = []
        self.assertEqual(
            test_fn(values),
            sess.run(result.test_fn(
                constant_op.constant(values, shape=(0,),
                                     dtype=dtypes.int32))))
def test_while_single_var(self):
    """Checks a while loop whose only loop variable is its argument."""

    def test_fn(n):
        while n > 0:
            n -= 1
        return n

    analyzed = self.parse_and_analyze(test_fn, {})
    converted = control_flow.transform(analyzed, self.ctx)
    with self.compiled(converted, control_flow_ops.while_loop) as result:
        with self.test_session() as sess:
            start = constant_op.constant(5)
            self.assertEqual(0, sess.run(result.test_fn(start)))
def test_if_single_var(self):
    """Checks an `if` without `else` that negates its only variable."""

    def test_fn(n):
        if n > 0:
            n = -n
        return n

    analyzed = self.parse_and_analyze(test_fn, {})
    converted = control_flow.transform(analyzed, self.ctx)
    with self.compiled(converted, control_flow_ops.cond) as result:
        with self.test_session() as sess:
            outcome = sess.run(result.test_fn(constant_op.constant(1)))
            self.assertEqual(-1, outcome)
def test_if_single_var(self):
    """Checks a single-variable `if` (no `else`) taking the true branch."""

    def test_fn(n):
        if n > 0:
            n = -n
        return n

    transformed = control_flow.transform(
        self.parse_and_analyze(test_fn, {}), self.ctx)
    with self.compiled(transformed, control_flow_ops.cond) as result:
        with self.test_session() as sess:
            positive = constant_op.constant(1)
            self.assertEqual(-1, sess.run(result.test_fn(positive)))
def test_while_single_var(self):
    """Checks a single-variable while loop that counts its argument down."""

    def test_fn(n):
        while n > 0:
            n -= 1
        return n

    transformed = control_flow.transform(
        self.parse_and_analyze(test_fn, {}), self.ctx)
    with self.compiled(transformed, control_flow_ops.while_loop) as result:
        with self.test_session() as sess:
            final = sess.run(result.test_fn(constant_op.constant(5)))
            self.assertEqual(0, final)
def test_imbalanced_aliasing(self):
    """Checks an `if` with no `else` that rebinds its argument to a constant.

    Both branches are exercised: a positive input is replaced by 3, and a
    non-positive input passes through unchanged.
    """

    def test_fn(n):
        if n > 0:
            n = 3
        return n

    converted = control_flow.transform(
        self.parse_and_analyze(test_fn, {}), self.ctx)
    with self.compiled(converted, control_flow_ops.cond) as result:
        with self.test_session() as sess:
            for arg, expected in ((2, 3), (-3, -3)):
                self.assertEqual(
                    expected,
                    sess.run(result.test_fn(constant_op.constant(arg))))
def test_ignore_unread_variable(self):
    """A variable written in an `if` but never read must not affect output.

    `b` is assigned but never read; the converted function should simply
    return its argument for both positive and negative inputs.
    """

    def test_fn(n):
        b = 3  # pylint: disable=unused-variable
        if n > 0:
            b = 4
        return n

    converted = control_flow.transform(
        self.parse_and_analyze(test_fn, {}), self.ctx)
    with self.compiled(converted, control_flow_ops.cond,
                       array_ops.ones) as result:
        with self.test_session() as sess:
            for arg, expected in ((3, 3), (-3, -3)):
                self.assertEqual(
                    expected,
                    sess.run(result.test_fn(constant_op.constant(arg))))
def test_simple_while(self):
    """Checks a while loop carrying three values (s, i, n)."""

    def test_fn(n):
        i = 0
        s = 0
        while i < n:
            s += i
            i += 1
        return s, i, n

    analyzed = self.parse_and_analyze(test_fn, {})
    converted = control_flow.transform(analyzed, self.ctx)
    with self.compiled(converted, control_flow_ops.while_loop) as result:
        with self.test_session() as sess:
            # For n=5: s = 0+1+2+3+4, i ends at n, n is unchanged.
            expected = (10, 5, 5)
            actual = sess.run(result.test_fn(constant_op.constant(5)))
            self.assertEqual(expected, actual)
def test_simple_while(self):
    """Checks a converted summing while loop with multiple loop variables."""

    def test_fn(n):
        i = 0
        s = 0
        while i < n:
            s += i
            i += 1
        return s, i, n

    transformed = control_flow.transform(
        self.parse_and_analyze(test_fn, {}), self.ctx)
    with self.compiled(transformed, control_flow_ops.while_loop) as result:
        with self.test_session() as sess:
            limit = constant_op.constant(5)
            self.assertEqual((10, 5, 5), sess.run(result.test_fn(limit)))
def test_simple_if(self):
    """Checks an `if`/`else` where each branch writes a different variable."""

    def test_fn(n):
        a = 0
        b = 0
        if n > 0:
            a = -n
        else:
            b = 2 * n
        return a, b

    converted = control_flow.transform(
        self.parse_and_analyze(test_fn, {}), self.ctx)
    with self.compiled(converted, control_flow_ops.cond) as result:
        with self.test_session() as sess:
            # input -> expected (a, b)
            for arg, expected in ((1, (-1, 0)), (-1, (0, -2))):
                self.assertEqual(
                    expected,
                    sess.run(result.test_fn(constant_op.constant(arg))))
def test_simple_if(self):
    """Exercises both branches of a converted two-variable `if`/`else`."""

    def test_fn(n):
        a = 0
        b = 0
        if n > 0:
            a = -n
        else:
            b = 2 * n
        return a, b

    analyzed = self.parse_and_analyze(test_fn, {})
    transformed = control_flow.transform(analyzed, self.ctx)
    with self.compiled(transformed, control_flow_ops.cond) as result:
        with self.test_session() as sess:
            true_branch = sess.run(result.test_fn(constant_op.constant(1)))
            self.assertEqual((-1, 0), true_branch)
            false_branch = sess.run(result.test_fn(constant_op.constant(-1)))
            self.assertEqual((0, -2), false_branch)
def test_for_iterated_expression(self):
    """Checks that a for loop's iterated expression is evaluated only once.

    `count_evals` records each call in `eval_count`; after running the
    converted function a single time, it must have been called exactly once.
    """
    eval_count = [0]

    def count_evals(x):
        eval_count[0] += 1
        return x

    def test_fn(n):
        s = 0
        for e in count_evals(range(n)):
            s += e
        return s

    ns = {'count_evals': count_evals}
    node, ctx = self.prepare(test_fn, ns)
    node = control_flow.transform(node, ctx)

    with self.compiled(node, ns) as result:
        # Expected value first, matching the rest of this suite, so failure
        # messages label "expected" and "actual" correctly.
        self.assertEqual(10, result.test_fn(5))
        self.assertEqual(1, eval_count[0])
def test_for_with_iterated_expression(self):
    """Checks the converted for loop matches Python and evaluates once.

    The original `test_fn` and the converted `result.test_fn` are each run
    once, so `count_evals` is expected to have been called twice in total.
    """
    eval_count = [0]

    def count_evals(x):
        eval_count[0] += 1
        return x

    def test_fn(n):
        s = 0
        for e in count_evals(range(n)):
            s += e
        return s

    node = self.parse_and_analyze(test_fn, {'count_evals': count_evals})
    node = control_flow.transform(node, self.ctx)

    with self.compiled(node) as result:
        # The compiled module needs count_evals bound in its namespace.
        result.count_evals = count_evals
        self.assertEqual(test_fn(5), result.test_fn(5))
        # count_evals ran twice, once for test_fn and another for
        # result.test_fn. Expected value first, matching the rest of the suite.
        self.assertEqual(2, eval_count[0])
def test_for_single_var(self):
    """Compares a converted single-accumulator for loop with plain Python.

    Runs on a non-empty int tensor and on an empty int32 tensor of shape (0,).
    """

    def test_fn(l):
        s = 0
        for e in l:
            s += e
        return s

    converted = control_flow.transform(
        self.parse_and_analyze(test_fn, {}), self.ctx)
    with self.compiled(converted) as result, self.test_session() as sess:
        populated = [1, 2, 3]
        populated_out = sess.run(
            result.test_fn(constant_op.constant(populated)))
        self.assertEqual(test_fn(populated), populated_out)

        hollow = []
        hollow_out = sess.run(
            result.test_fn(
                constant_op.constant(hollow, shape=(0,), dtype=dtypes.int32)))
        self.assertEqual(test_fn(hollow), hollow_out)
def node_to_graph(node, ctx, nocompile_decorators):
    """Convert Python code to equivalent TF graph mode code.

    Applies a fixed sequence of AST transformation passes to `node`. Static
    analysis is re-run between passes that may change the tree. The order of
    the passes matters and should not be changed casually.

    Args:
        node: A Python AST node representing the code to convert.
        ctx: An EntityContext object. This function mutates it: it sets
            `ctx.source_code` to None and adds `len` to `ctx.namespace`.
        nocompile_decorators: A tuple containing decorators to be stripped from
            functions during conversion. It is also forwarded to
            `call_trees.transform`.

    Returns:
        A tuple (node, deps):
            * node: A Python ast node, representing the converted code.
            * deps: A set of strings, the fully qualified names of entity
                dependencies that this node has. It is produced by
                `decorators.transform`.
    """
    # TODO(mdan): Verify arguments for correctness.
    # TODO(mdan): Factor out common elements.
    # These include:
    #   * code move between blocks
    #   * visiting blocks in transformers

    # Certain steps, especially canonicalization, insert new symbols into the
    # tree, which must be accounted. Although less efficient, it is most robust
    # to re-run the analysis.
    node = _static_analysis_pass(node, ctx)

    # TODO(mdan): Clean this up.
    # Some intermediate analyses are not required, and some comments got
    # orphaned.

    # Past this point, line numbers are no longer accurate so we ignore the
    # source.
    # TODO(mdan): Is it feasible to reconstruct intermediate source code?
    ctx.source_code = None
    node = ifexp.transform(node, ctx)
    # Unlike the other passes, this one takes the decorator tuple instead of
    # ctx, and it is the source of the `deps` value returned below.
    node, deps = decorators.transform(node, nocompile_decorators)
    node = break_statements.transform(node, ctx)
    node = asserts.transform(node, ctx)

    # Note: sequencing continue canonicalization before for loop one avoids
    # dealing with the extra loop increment operation that the for
    # canonicalization creates.
    # NOTE(review): no separate for-loop canonicalization pass appears in this
    # function; this note may be stale (for loops may now be handled inside
    # control_flow.transform) — confirm.
    node = continue_statements.transform(node, ctx)
    # NOTE(review): makes the builtin `len` resolvable from the converted
    # code's namespace; presumably some generated code calls it — confirm
    # which pass depends on this.
    ctx.namespace['len'] = len

    # Re-analyze: the preceding passes may have inserted symbols.
    node = _static_analysis_pass(node, ctx)
    node = single_return.transform(node, ctx)

    # Re-analyze after single_return rewrote the function's return structure.
    node = _static_analysis_pass(node, ctx)
    node = lists.transform(node, ctx)
    node = builtin_functions.transform(node, ctx)

    # Re-analyze before converting calls and control flow.
    node = _static_analysis_pass(node, ctx)
    node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES,
                                nocompile_decorators)
    node = control_flow.transform(node, ctx)

    # control_flow may create new symbols and change scopes.
    node = _static_analysis_pass(node, ctx)
    node = logical_expressions.transform(node, ctx)
    node = side_effect_guards.transform(node, ctx)
    node = name_scopes.transform(node, ctx)
    return node, deps