Пример #1
0
  def test_for_variable_defined_in_body(self):
    """A for loop that is the sole definer of a returned name is rejected.

    `s` is only assigned inside the loop body, so it would be undefined after a
    zero-iteration loop; control_flow.transform is expected to raise
    transformer.AutographParseError.
    """
    def bad_for_loop(n):
      for i in range(n):
        s = i
      return s

    node, ctx = self.prepare(bad_for_loop, {})
    self.assertRaises(
        transformer.AutographParseError, control_flow.transform, node, ctx)
Пример #2
0
    def test_for_variable_defined_in_body(self):
        """A for loop that is the sole definer of a returned name is rejected.

        `s` is only assigned inside the loop body, so it would be undefined
        after a zero-iteration loop; control_flow.transform is expected to
        raise NameError.
        """
        def bad_for_loop(n):
            for i in range(n):
                s = i
            return s

        node, ctx = self.prepare(bad_for_loop, {})
        self.assertRaises(NameError, control_flow.transform, node, ctx)
Пример #3
0
    def test_if_imbalanced_outputs(self):
        """An if statement defining a returned name on one branch is rejected.

        `b` is assigned only when `n > 0` but returned unconditionally;
        control_flow.transform is expected to raise ValueError.
        """
        def test_fn(n):
            if n > 0:
                b = 4
            return b

        node, ctx = self.prepare(test_fn, {})
        self.assertRaises(ValueError, control_flow.transform, node, ctx)
Пример #4
0
    def test_while_variable_defined_in_body(self):
        """A while loop that is the sole definer of a returned name is rejected.

        `s` is only assigned inside the loop body, so it would be undefined
        if the loop never runs; control_flow.transform is expected to raise
        NameError.
        """
        def bad_while_loop(n):
            while n > 0:
                n -= 1
                s = n
            return s

        node, ctx = self.prepare(bad_while_loop, {})
        self.assertRaises(NameError, control_flow.transform, node, ctx)
Пример #5
0
  def test_while_variable_defined_in_body(self):
    """A while loop that is the sole definer of a returned name is rejected.

    `s` is only assigned inside the loop body, so it would be undefined if the
    loop never runs; control_flow.transform is expected to raise
    transformer.AutographParseError.
    """
    def bad_while_loop(n):
      while n > 0:
        n -= 1
        s = n
      return s

    node, ctx = self.prepare(bad_while_loop, {})
    self.assertRaises(
        transformer.AutographParseError, control_flow.transform, node, ctx)
Пример #6
0
  def test_if_imbalanced_outputs(self):
    """An if statement defining a returned name on one branch is rejected.

    `b` is assigned only when `n > 0` but returned unconditionally;
    control_flow.transform is expected to raise
    transformer.AutographParseError.
    """
    def test_fn(n):
      if n > 0:
        b = 4
      return b

    node, ctx = self.prepare(test_fn, {})
    self.assertRaises(
        transformer.AutographParseError, control_flow.transform, node, ctx)
Пример #7
0
    def transform_ast(self, node, ctx):
        """Runs the initial static analyses, then the ordered converter chain.

        Args:
          node: the AST to convert.
          ctx: the conversion context; `ctx.user.options` decides whether the
            optional assert and list/slice converters run.

        Returns:
          The converted AST.
        """
        # TODO(mdan): Insert list_comprehensions somewhere.
        unsupported_features_checker.verify(node)

        # Initial analysis: the CFG is built up front and handed to the
        # reaching-definitions pass.
        cfg_graphs = cfg.build(node)
        node = qual_names.resolve(node)
        node = activity.resolve(node, ctx, None)
        node = reaching_definitions.resolve(node, ctx, cfg_graphs)
        # Mirror DEFINITIONS annotations under ORIG_DEFINITIONS (presumably so
        # later passes can still consult the pre-conversion definitions).
        anno.dup(node, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})

        # Converter order matters; each group below runs in sequence, and the
        # optional stages are checked at the same point in the chain.
        for stage in (functions, directives, break_statements):
            node = stage.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
            node = asserts.transform(node, ctx)

        # Note: sequencing continue canonicalization before for loop one avoids
        # dealing with the extra loop increment operation that the for
        # canonicalization creates.
        for stage in (continue_statements, return_statements):
            node = stage.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.LISTS):
            for stage in (lists, slices):
                node = stage.transform(node, ctx)

        for stage in (call_trees, control_flow, conditional_expressions,
                      logical_expressions):
            node = stage.transform(node, ctx)
        return node
Пример #8
0
    def test_no_origin_annotation(self):
        """Branch functions emitted without origin info get no try/except.

        control_flow.transform generates the cond branch functions without
        origin info, so error_handlers.transform should not wrap their bodies
        in a try/except that re-raises as GraphConstructionError.
        """
        def test_fn(x):
            a = 0
            if x:
                a = random_ops.random_normal((2, 3),
                                             mean=0.0,
                                             dtype=dtypes.int32)
            else:
                a = 0
            return a

        node, ctx = self.prepare(test_fn, {
            'random_ops': random_ops,
            'dtypes': dtypes
        })
        # To simulate a function without origin info we use the control flow
        # converter which adds a function that lacks origin info so we will not have
        # a wrapping try/except that reraises the NotImplementedError as a
        # GraphConstructionError.
        node = control_flow.transform(node, ctx)
        node = error_handlers.transform(node, ctx)
        # TODO(b/111562364): remove run_cond from traceback.
        test_fn_try_body = node.body[0].body
        true_fn_body = test_fn_try_body[1].body
        false_fn_body = test_fn_try_body[2].body
        # The bodies hold gast node *instances*; a membership test against the
        # gast.Try class itself could never fail, so check each node's type.
        self.assertFalse(
            any(isinstance(stmt, gast.Try) for stmt in true_fn_body))
        self.assertFalse(
            any(isinstance(stmt, gast.Try) for stmt in false_fn_body))
Пример #9
0
def _apply_py_to_tf_passes(node, ctx):
  """Apply transformations from PyToTF to match tf.function tracing.

  Runs the return-statement converter followed by the control-flow converter.

  Args:
    node: the AST to transform.
    ctx: the conversion context, passed through to each converter.

  Returns:
    The transformed AST.
  """
  # TODO(fengliuai): we don't know which passes are required, thus we evaluate
  # each one when the corresponding node is handled.
  # A subset of the passes in PyToTF.transform_ast.
  # NOTE(review): the positional `False` is presumably the converter's
  # "default to null return" switch -- confirm against the signature of
  # return_statements.transform.
  node = return_statements.transform(node, ctx, False)
  node = control_flow.transform(node, ctx)
  return node
Пример #10
0
    def test_for_iterated_expression(self):
        """The iterable expression of a converted for loop is evaluated once."""
        # Every call to count_evals is logged; the log's length is the number
        # of times the loop's iterable expression was evaluated.
        eval_log = []

        def count_evals(x):
            eval_log.append(x)
            return x

        def test_fn(n):
            s = 0
            for e in count_evals(range(n)):
                s += e
            return s

        namespace = {'count_evals': count_evals}
        node, ctx = self.prepare(test_fn, namespace)
        node = control_flow.transform(node, ctx)

        with self.compiled(node, namespace) as result:
            self.assertEqual(result.test_fn(5), 10)
            self.assertEqual(len(eval_log), 1)
Пример #11
0
  def test_for_iterated_expression(self):
    """The iterable expression of a converted for loop is evaluated once."""
    # Every call to count_evals is logged; the log's length is the number of
    # times the loop's iterable expression was evaluated.
    eval_log = []

    def count_evals(x):
      eval_log.append(x)
      return x

    def test_fn(n):
      s = 0
      for e in count_evals(range(n)):
        s += e
      return s

    namespace = {'count_evals': count_evals}
    node, ctx = self.prepare(test_fn, namespace)
    node = control_flow.transform(node, ctx)

    with self.compiled(node, namespace) as result:
      self.assertEqual(result.test_fn(5), 10)
      self.assertEqual(len(eval_log), 1)
Пример #12
0
    def transform_ast(self, node, ctx):
        """Runs initial analysis, then the ordered chain of converters.

        Args:
          node: the AST to convert.
          ctx: the conversion context; `ctx.user.options` decides whether the
            optional assert and list/slice converters run.

        Returns:
          The converted AST.
        """
        unsupported_features_checker.verify(node)
        node = self.initial_analysis(node, ctx)

        # Converter order matters; each group below runs in sequence, and the
        # optional stages are checked at the same point in the chain.
        for stage in (functions, directives, break_statements):
            node = stage.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
            node = asserts.transform(node, ctx)

        # Note: sequencing continue canonicalization before for loop one avoids
        # dealing with the extra loop increment operation that the for
        # canonicalization creates.
        for stage in (continue_statements, return_statements):
            node = stage.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.LISTS):
            for stage in (lists, slices):
                node = stage.transform(node, ctx)

        for stage in (call_trees, control_flow, conditional_expressions,
                      logical_expressions, variables):
            node = stage.transform(node, ctx)
        return node