def test_for_variable_defined_in_body(self):
  """Loop-carried variables created only inside a `for` body must be rejected."""

  def bad_for_loop(n):
    # `s` is first assigned inside the loop body, so it is undefined when the
    # loop runs zero iterations — AutoGraph must refuse to convert this.
    for i in range(n):
      s = i
    return s

  node, ctx = self.prepare(bad_for_loop, {})
  with self.assertRaises(transformer.AutographParseError):
    control_flow.transform(node, ctx)
def test_for_variable_defined_in_body(self):
  """Converting a `for` loop whose result var is body-local raises NameError."""

  def bad_for_loop(n):
    # `s` only exists if the loop body ran at least once; the converter is
    # expected to flag this as an undefined symbol.
    for i in range(n):
      s = i
    return s

  node, ctx = self.prepare(bad_for_loop, {})
  with self.assertRaises(NameError):
    control_flow.transform(node, ctx)
def test_if_imbalanced_outputs(self):
  """An `if` that defines a variable in only one branch must raise ValueError."""

  def test_fn(n):
    # `b` is defined only on the true branch, so the converted conditional
    # would have imbalanced outputs.
    if n > 0:
      b = 4
    return b

  node, ctx = self.prepare(test_fn, {})
  with self.assertRaises(ValueError):
    control_flow.transform(node, ctx)
def test_while_variable_defined_in_body(self):
  """A `while` loop whose result var is body-local raises NameError."""

  def bad_while_loop(n):
    # `s` never exists if the condition is false on entry; the converter
    # should reject returning it.
    while n > 0:
      n -= 1
      s = n
    return s

  node, ctx = self.prepare(bad_while_loop, {})
  with self.assertRaises(NameError):
    control_flow.transform(node, ctx)
def test_while_variable_defined_in_body(self):
  """Body-local loop variables in `while` loops trigger AutographParseError."""

  def bad_while_loop(n):
    # `s` is only bound inside the loop body, so it is undefined whenever the
    # loop executes zero times.
    while n > 0:
      n -= 1
      s = n
    return s

  node, ctx = self.prepare(bad_while_loop, {})
  with self.assertRaises(transformer.AutographParseError):
    control_flow.transform(node, ctx)
def test_if_imbalanced_outputs(self):
  """Single-branch variable definitions in `if` raise AutographParseError."""

  def test_fn(n):
    # `b` is assigned on the true branch only; the converted cond would have
    # mismatched outputs between branches.
    if n > 0:
      b = 4
    return b

  node, ctx = self.prepare(test_fn, {})
  with self.assertRaises(transformer.AutographParseError):
    control_flow.transform(node, ctx)
def transform_ast(self, node, ctx):
  """Runs the full AutoGraph conversion pipeline over `node`.

  The pass order below is significant: analysis passes must precede the
  rewriting passes that consume their annotations, and several rewrites
  depend on the canonical forms produced by earlier ones (see inline notes).

  Args:
    node: AST node to transform (presumably a function definition — confirm
      against callers).
    ctx: transformer context carrying user options (`ctx.user.options`).

  Returns:
    The transformed AST node.
  """
  # TODO(mdan): Insert list_comprehensions somewhere.
  unsupported_features_checker.verify(node)
  # Run initial analysis.
  graphs = cfg.build(node)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  node = reaching_definitions.resolve(node, ctx, graphs)
  # Snapshot the reaching-definitions annotations so later passes can still
  # see the pre-transformation definitions.
  anno.dup(
      node,
      {
          anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
      },
  )
  node = functions.transform(node, ctx)
  node = directives.transform(node, ctx)
  node = break_statements.transform(node, ctx)
  # Assertion rewriting is opt-in via converter options.
  if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
    node = asserts.transform(node, ctx)
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = continue_statements.transform(node, ctx)
  node = return_statements.transform(node, ctx)
  # List operation rewriting is also opt-in.
  if ctx.user.options.uses(converter.Feature.LISTS):
    node = lists.transform(node, ctx)
    node = slices.transform(node, ctx)
  node = call_trees.transform(node, ctx)
  node = control_flow.transform(node, ctx)
  node = conditional_expressions.transform(node, ctx)
  node = logical_expressions.transform(node, ctx)
  return node
def test_no_origin_annotation(self):
  """Functions lacking origin info should not get wrapping try/except blocks."""

  def test_fn(x):
    a = 0
    if x:
      a = random_ops.random_normal((2, 3), mean=0.0, dtype=dtypes.int32)
    else:
      a = 0
    return a

  namespace = {'random_ops': random_ops, 'dtypes': dtypes}
  node, ctx = self.prepare(test_fn, namespace)
  # To simulate a function without origin info we use the control flow
  # converter which adds a function that lacks origin info so we will not have
  # a wrapping try/except that reraises the NotImplementedError as a
  # GraphConstructionError.
  node = control_flow.transform(node, ctx)
  node = error_handlers.transform(node, ctx)
  # TODO(b/111562364): remove run_cond from traceback.
  converted_body = node.body[0].body
  branch_bodies = (converted_body[1].body, converted_body[2].body)
  for branch_body in branch_bodies:
    # NOTE(review): this checks the gast.Try *class* for membership in a list
    # of AST nodes — presumably intentional, but verify it matches the
    # original assertion's intent.
    self.assertNotIn(gast.Try, branch_body)
def _apply_py_to_tf_passes(node, ctx):
  """Apply transformations from PyToTF to match tf.function tracing.

  Args:
    node: AST node to transform.
    ctx: transformer context passed through to each pass.

  Returns:
    The transformed AST node.
  """
  # TODO(fengliuai): we don't know which passes are required, thus we evalute
  # each one when the corresponding node is handled.
  # copied from PyToTF.transform_ast
  canonicalized = return_statements.transform(node, ctx, False)
  return control_flow.transform(canonicalized, ctx)
def test_for_iterated_expression(self):
  """The iterated expression of a converted `for` is evaluated exactly once."""
  eval_count = [0]

  def count_evals(x):
    # Side-effecting identity: records how many times the iterable expression
    # was evaluated.
    eval_count[0] += 1
    return x

  def test_fn(n):
    s = 0
    for e in count_evals(range(n)):
      s += e
    return s

  ns = {'count_evals': count_evals}
  node, ctx = self.prepare(test_fn, ns)
  node = control_flow.transform(node, ctx)
  with self.compiled(node, ns) as result:
    # sum(range(5)) == 10, and the iterable must have been built only once.
    self.assertEqual(result.test_fn(5), 10)
    self.assertEqual(eval_count[0], 1)
def test_for_iterated_expression(self):
  """Conversion must not duplicate evaluation of the `for` iterable."""
  eval_count = [0]

  def count_evals(x):
    # Identity function that counts its own invocations.
    eval_count[0] += 1
    return x

  def test_fn(n):
    s = 0
    for e in count_evals(range(n)):
      s += e
    return s

  ns = {'count_evals': count_evals}
  node, ctx = self.prepare(test_fn, ns)
  node = control_flow.transform(node, ctx)
  with self.compiled(node, ns) as result:
    # 0+1+2+3+4 == 10; count_evals must run exactly once.
    self.assertEqual(result.test_fn(5), 10)
    self.assertEqual(eval_count[0], 1)
def transform_ast(self, node, ctx):
  """Runs the full AutoGraph conversion pipeline over `node`.

  The pass order is significant: `initial_analysis` must run before the
  rewriting passes that consume its annotations, and several rewrites depend
  on canonical forms produced by earlier ones (see inline notes).

  Args:
    node: AST node to transform (presumably a function definition — confirm
      against callers).
    ctx: transformer context carrying user options (`ctx.user.options`).

  Returns:
    The transformed AST node.
  """
  unsupported_features_checker.verify(node)
  node = self.initial_analysis(node, ctx)

  node = functions.transform(node, ctx)
  node = directives.transform(node, ctx)
  node = break_statements.transform(node, ctx)
  # Assertion rewriting is opt-in via converter options.
  if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
    node = asserts.transform(node, ctx)
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = continue_statements.transform(node, ctx)
  node = return_statements.transform(node, ctx)
  # List operation rewriting is also opt-in.
  if ctx.user.options.uses(converter.Feature.LISTS):
    node = lists.transform(node, ctx)
    node = slices.transform(node, ctx)
  node = call_trees.transform(node, ctx)
  node = control_flow.transform(node, ctx)
  node = conditional_expressions.transform(node, ctx)
  node = logical_expressions.transform(node, ctx)
  node = variables.transform(node, ctx)
  return node