Exemplo n.º 1
0
  def test_local_scope_info_stack_checks_integrity(self):
    """Unbalanced enter/exit_local_scope calls must raise AssertionError."""

    class TestTransformer(transformer.Base):

      def visit_If(self, node):
        # Deliberately enters a local scope and never leaves it.
        self.enter_local_scope()
        return self.generic_visit(node)

      def visit_For(self, node):
        # Deliberately leaves a local scope that it never entered.
        visited = self.generic_visit(node)
        self.exit_local_scope()
        return visited

    def no_exit(a):
      if a > 0:
        print(a)
      return None

    def no_entry(a):
      for _ in a:
        print(a)

    # One transformer instance is reused for both cases, in this order.
    unbalanced_transformer = TestTransformer(self._simple_context())
    for unbalanced_fn in (no_exit, no_entry):
      fn_node, _ = parser.parse_entity(unbalanced_fn, future_features=())
      with self.assertRaises(AssertionError):
        unbalanced_transformer.visit(fn_node)
Exemplo n.º 2
0
  def test_parse_comments(self):
    """parse_entity must raise ValueError for a body with a column-0 comment."""

    def f():
# unindented comment
      pass

    # The column-0 comment inside `f` is intentional: it sits below the
    # function's indentation, and this test asserts that parsing rejects it.
    with self.assertRaises(ValueError):
      parser.parse_entity(f, future_features=())
Exemplo n.º 3
0
  def test_parse_multiline_strings(self):
    """parse_entity must reject a multiline string dedented below `f`."""
    def f():
      print("""
some
multiline
string""")
    # The string's continuation lines start at column 0, left of the function
    # body's indentation; the test asserts parse_entity raises ValueError.
    with self.assertRaises(ValueError):
      parser.parse_entity(f)
Exemplo n.º 4
0
  def test_origin_info_preserved_in_moved_nodes(self):
    """Statements hoisted out of an `if` keep their resolved line numbers."""

    class TestTransformer(transformer.Base):

      def visit_If(self, node):
        # Replace the `if` statement with its body statements.
        return node.body

    def test_fn():
      x = 1
      if x > 0:
        x = 1
        x += 3
      return x

    hoisting_transformer = TestTransformer(self._simple_context())

    fn_node, fn_source = parser.parse_entity(test_fn, future_features=())
    # The expected line numbers below depend on these resolve() arguments.
    origin_info.resolve(fn_node, fn_source, 'test_file', 100, 0)
    fn_node = hoisting_transformer.visit(fn_node)

    # After hoisting, body[1] and body[2] are the former `if` body statements;
    # they must still report the lines they were originally written on.
    expected_lines = ((fn_node.body[1], 103), (fn_node.body[2], 104))
    for moved_stmt, lineno in expected_lines:
      origin = anno.getanno(moved_stmt, anno.Basic.ORIGIN)
      self.assertEqual(origin.loc.lineno, lineno)
Exemplo n.º 5
0
  def test_robust_error_on_ast_corruption(self):
    """Error handling in `visit` blames the original error despite corruption.

    A subclass must not be able to break `transformer.Base` so badly that its
    error handling raises an exception of its own. Otherwise the original
    error location is lost, and a handler further up the call stack reports
    misleading information.
    """

    class NotANode(object):
      pass

    class BrokenTransformer(transformer.Base):

      def visit_If(self, node):
        # Corrupt the tree first so the error handler sees a malformed node.
        node.body = NotANode()
        raise ValueError('I blew up')

    def test_function(x):
      if x > 0:
        return x

    broken = BrokenTransformer(self._simple_context())
    fn_node, _ = parser.parse_entity(test_function, future_features=())

    with self.assertRaises(ValueError) as raised:
      broken.visit(fn_node)

    # The reported message must come from the error actually raised, not from
    # anything that happened inside the exception handler.
    error_text = str(raised.exception)
    self.assertIn('I blew up', error_text, error_text)
Exemplo n.º 6
0
    def test_parse_comments(self):
        """parse_entity accepts a function whose body contains a comment."""
        def f():
            # unindented comment
            pass

        # NOTE(review): despite its text, the comment inside `f` is indented,
        # so this case does not cover a column-0 comment; confirm intent.
        fn_node, _unused_source = parser.parse_entity(f, future_features=())
        self.assertEqual('f', fn_node.name)
Exemplo n.º 7
0
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Specialization of `entity_to_graph` for callable functions.

    Parses `f`, resolves origin information, runs the result through
    `node_to_graph`, and gives the converted definition its final name.

    Args:
        f: the function to convert.
        program_ctx: program-wide conversion context. It supplies the namer
            and the autograph module, and its name map is updated.
        arg_values: forwarded to `transformer.EntityInfo`.
        arg_types: forwarded to `transformer.EntityInfo`.
        owner_type: forwarded to `transformer.EntityInfo` and to the namer.

    Returns:
        A tuple `([node], new_name, namespace)`: the converted function node
        in a one-element list, the name assigned to it, and the namespace
        of `f`.

    Raises:
        NotImplementedError: if the namer did not rename the function but the
            converted node's name no longer equals `f.__name__`.
    """
    module_node, source = parser.parse_entity(f)
    fn_node = module_node.body[0]
    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(fn_node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=arg_values,
        arg_types=arg_types,
        owner_type=owner_type)
    fn_node = node_to_graph(
        fn_node, converter.EntityContext(namer, info, program_ctx))

    # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
    original_name = f.__name__
    new_name, did_rename = namer.compiled_function_name(
        original_name, f, owner_type)
    if not did_rename:
        # Without a rename, the converted node must still carry f's name.
        if fn_node.name != original_name:
            raise NotImplementedError(
                'Strange corner case. Send us offending code!')
        new_name = original_name
    fn_node.name = new_name

    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.

    return [fn_node], new_name, namespace
Exemplo n.º 8
0
  def test_ext_slice_roundtrip(self):
    """Extended slices survive an AST -> source -> AST round trip."""
    def ext_slice(n):
      return n[:, :], n[0, :], n[:, 0]

    parsed, _ = parser.parse_entity(ext_slice, future_features=())
    regenerated = parser.unparse(parsed)
    self.assertAstMatches(parsed, regenerated, expr=False)
Exemplo n.º 9
0
  def test_parse_entity_print_function(self):
    """parse_entity handles `print(...)` with the print_function feature."""

    def f(x):
      print(x)

    features = ('print_function',)
    fn_node, _ = parser.parse_entity(f, future_features=features)
    self.assertEqual('f', fn_node.name)
Exemplo n.º 10
0
  def test_parse_entity(self):
    """parse_entity returns a node named after the parsed function."""

    def f(x):
      return x + 1

    parsed_fn, _source = parser.parse_entity(f, future_features=())
    self.assertEqual('f', parsed_fn.name)
Exemplo n.º 11
0
  def test_parse_lambda_complex_body(self):
    """A multi-line lambda parses to the expected AST and source text."""

    l = lambda x: (  # pylint:disable=g-long-lambda
        x.y(
            [],
            x.z,
            (),
            x[0:2],
        ),
        x.u,
        'abc',
        1,
    )

    node, source = parser.parse_entity(l, future_features=())
    # Structurally, the parsed node must equal the same lambda on one line.
    expected_node_src = "lambda x: (x.y([], x.z, (), x[0:2]), x.u, 'abc', 1)"
    self.assertAstMatches(node, expected_node_src)

    # `base_source` mirrors the lambda's text exactly as written above, line
    # by line and with its indentation; changing the lambda's layout requires
    # updating this string as well.
    base_source = ('lambda x: (  # pylint:disable=g-long-lambda\n'
                   '        x.y(\n'
                   '            [],\n'
                   '            x.z,\n'
                   '            (),\n'
                   '            x[0:2],\n'
                   '        ),\n'
                   '        x.u,\n'
                   '        \'abc\',\n'
                   '        1,')
    # The complete source includes the trailing parenthesis. But that is only
    # detected in runtimes which correctly track end_lineno for ASTs.
    self.assertMatchesWithPotentialGarbage(source, base_source, '\n    )')
Exemplo n.º 12
0
  def test_robust_error_on_ast_corruption(self):
    """`visit` still reports the original error when the AST is corrupted.

    If a broken subclass could make the error handling in `transformer.Base`
    raise on its own, the original error location would be dropped and a
    handler higher up the call stack would give misleading information.
    """

    class NotANode(object):
      pass

    class BrokenTransformer(transformer.Base):

      def visit_If(self, node):
        node.body = NotANode()  # Leaves the `if` node malformed.
        raise ValueError('I blew up')

    def test_function(x):
      if x > 0:
        return x

    corrupting_transformer = BrokenTransformer(self._simple_context())

    parsed, _ = parser.parse_entity(test_function, future_features=())
    with self.assertRaises(ValueError) as ctx_manager:
      corrupting_transformer.visit(parsed)

    message = str(ctx_manager.exception)
    # Must reference the raised exception, not anything from the handler.
    self.assertIn('I blew up', message, message)
Exemplo n.º 13
0
  def test_resolve(self):
    """resolve() records location, source line and comment for each node."""

    def test_fn(x):
      """Docstring."""
      return x  # comment

    module_node, source = parser.parse_entity(test_fn)
    fn_node = module_node.body[0]

    origin_info.resolve(fn_node, source)

    def check_origin(target, lineno, col_offset, line_text, comment):
      origin = anno.getanno(target, anno.Basic.ORIGIN)
      self.assertEqual(origin.loc.lineno, lineno)
      self.assertEqual(origin.loc.col_offset, col_offset)
      self.assertEqual(origin.source_code_line, line_text)
      if comment is None:
        self.assertIsNone(origin.comment)
      else:
        self.assertEqual(origin.comment, comment)

    check_origin(fn_node, 1, 0, 'def test_fn(x):', None)
    check_origin(fn_node.body[0], 2, 2, '  """Docstring."""', None)
    check_origin(fn_node.body[1], 3, 2, '  return x  # comment', 'comment')
Exemplo n.º 14
0
  def test_state_tracking_context_manager(self):
    """Nodes inside an `if` see a different CondState value than outside it."""

    class CondState(object):
      pass

    class TestTransformer(transformer.Base):

      def visit(self, node):
        # Tag every visited node with the CondState value active right now.
        anno.setanno(node, 'cond_state', self.state[CondState].value)
        return super(TestTransformer, self).visit(node)

      def visit_If(self, node):
        with self.state[CondState]:
          return self.generic_visit(node)

    def test_function(a):
      a = 1
      if a > 2:
        _ = 'b'
        if a < 5:
          _ = 'c'
        _ = 'd'

    tagger = TestTransformer(self._simple_context())
    fn_node, _ = parser.parse_entity(test_function, future_features=())
    fn_node = tagger.visit(fn_node)

    top_level = fn_node.body
    outer_if = top_level[1]
    inner_if = outer_if.body[1]

    # Function level vs. outer `if` body: different values.
    self.assertDifferentAnno(top_level[0], outer_if.body[0], 'cond_state')
    # Statements on either side of the inner `if` share a value.
    self.assertSameAnno(outer_if.body[0], outer_if.body[2], 'cond_state')
    # Inner `if` body vs. outer `if` body: different values.
    self.assertDifferentAnno(inner_if.body[0], outer_if.body[0], 'cond_state')
Exemplo n.º 15
0
  def test_origin_info_preserved_in_moved_nodes(self):
    """Statements spliced out of an `if` keep their original line numbers."""

    class TestTransformer(transformer.Base):

      def visit_If(self, node):
        return node.body

    def test_fn():
      x = 1
      if x > 0:
        x = 1
        x += 3
      return x

    splicer = TestTransformer(self._simple_context())
    parsed, source = parser.parse_entity(test_fn, future_features=())
    origin_info.resolve(parsed, source)
    parsed = splicer.visit(parsed)

    def lineno_of(stmt):
      return anno.getanno(stmt, anno.Basic.ORIGIN).loc.lineno

    moved_assign, moved_aug_assign = parsed.body[1], parsed.body[2]
    # Lines 4 and 5 of test_fn's source, counting the `def` line as line 1.
    self.assertEqual(lineno_of(moved_assign), 4)
    self.assertEqual(lineno_of(moved_aug_assign), 5)
Exemplo n.º 16
0
 def assert_body_anfs_as_expected(self, expected_fn, test_fn, config=None):
     """Asserts that ANF-transforming `test_fn` yields `expected_fn`'s body.

     Args:
       expected_fn: function whose body is the expected ANF output.
       test_fn: function whose body is transformed.
       config: optional ANF configuration. It is forwarded to every
         `anf.transform` call, including the idempotence check.
     """
     # Testing the code bodies only.  Wrapping them in functions so the
     # syntax highlights nicely, but Python doesn't try to execute the
     # statements.
     exp_node, _ = parser.parse_entity(expected_fn, future_features=())
     node, _ = parser.parse_entity(test_fn, future_features=())
     node = anf.transform(node, self._simple_context(), config=config)
     exp_name = exp_node.name
     # Ignoring the function names in the result because they can't be
     # the same (because both functions have to exist in the same scope
     # at the same time).
     node.name = exp_name
     self.assert_same_ast(exp_node, node)
     # Check that ANF is idempotent under the same configuration. Re-running
     # with the default config may transform more than `config` allows, so it
     # would not be checking idempotence of the transform under test.
     node_repeated = anf.transform(node, self._simple_context(), config=config)
     self.assert_same_ast(node_repeated, node)
Exemplo n.º 17
0
    def test_resolve(self):
        """resolve() attaches location, source line and comment origins."""
        def test_fn(x):
            """Docstring."""
            return x  # comment

        fn_node, _, fn_source = parser.parse_entity(test_fn, future_imports=())

        origin_info.resolve(fn_node, fn_source)

        # NOTE(review): test_fn's body is indented by 4 spaces here, yet the
        # expectations below assume column 2 and 2-space source lines; confirm
        # whether parse_entity re-indents the source before resolving.
        expected = (
            (fn_node, 1, 0, 'def test_fn(x):', None),
            (fn_node.body[0], 2, 2, '  """Docstring."""', None),
            (fn_node.body[1], 3, 2, '  return x  # comment', 'comment'),
        )
        for target, lineno, col_offset, line_text, comment in expected:
            origin = anno.getanno(target, anno.Basic.ORIGIN)
            self.assertEqual(origin.loc.lineno, lineno)
            self.assertEqual(origin.loc.col_offset, col_offset)
            self.assertEqual(origin.source_code_line, line_text)
            if comment is None:
                self.assertIsNone(origin.comment)
            else:
                self.assertEqual(origin.comment, comment)
Exemplo n.º 18
0
  def _should_compile(self, node, fqn):
    """Decides whether the entity called by `node` should be compiled.

    Args:
      node: call node whose `func` is resolved to find the target entity.
      fqn: fully qualified name of the target as a sequence; its first
        element is treated as the module name.

    Returns:
      False when the target belongs to an uncompiled module, `node` has a
      'graph_ready' annotation, the target is a compile decorator or is listed
      in `strip_decorators`, or one of its decorators resolves to a stripped
      decorator. True otherwise, including when the target cannot be resolved
      or parse_entity raises TypeError on it.
    """
    # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
    module_name = fqn[0]
    uncompiled = self.ctx.program.uncompiled_modules
    if any(module_name.startswith(mod[0] + '.') for mod in uncompiled):
      return False
    if any(fqn[:i] in uncompiled for i in range(1, len(fqn))):
      return False

    # Check for local decorations.
    if anno.hasanno(node, 'graph_ready'):
      return False

    # The decorators themselves are not to be converted.
    # If present, the decorators should appear as static functions.
    target_entity = self._try_resolve_target(node.func)
    if target_entity is None:
      return True

    # This may be reached when "calling" a callable attribute of an object.
    # For example:
    #
    #   self.fc = tf.keras.layers.Dense()
    #   self.fc()
    #
    if any(target_entity.__module__.startswith(mod[0] + '.')
           for mod in uncompiled):
      return False

    # This attribute is set by the decorator itself.
    # TODO(mdan): This may not play nicely with other wrapping decorators.
    if hasattr(target_entity, '__pyct_is_compile_decorator'):
      return False

    stripped = self.ctx.program.options.strip_decorators
    if target_entity in stripped:
      return False

    # Inspect the target function decorators. If any include a @convert
    # or @graph_ready annotation, then they must be called as they are.
    # TODO(mdan): This may be quite heavy.
    # To parse and re-analyze each function for every call site could be quite
    # wasteful. Maybe we could cache the parsed AST?
    try:
      parsed_target, _ = parser.parse_entity(target_entity)
      target_node = parsed_target.body[0]
    except TypeError:
      # Functions whose source we cannot access are compilable (e.g. wrapped
      # to py_func).
      return True

    # Resolved lazily, so resolution stops at the first stripped decorator.
    resolved_decorators = (
        self._resolve_name(dec) for dec in target_node.decorator_list)
    return not any(fn is not None and fn in stripped
                   for fn in resolved_decorators)
Exemplo n.º 19
0
  def test_visit_block_postprocessing(self):
    """An after_visit hook can wrap a statement in a new `if` block."""

    class TestTransformer(transformer.Base):

      def _process_body_item(self, node):
        # Only assignments whose value is the name `y` are wrapped in `if x:`.
        # The returned `if` body is handed back as the second element; the
        # assertions below expect the following statements to land in it.
        if not (isinstance(node, gast.Assign) and node.value.id == 'y'):
          return node, None
        wrapper = gast.If(gast.Name('x', gast.Load(), None), [node], [])
        return wrapper, wrapper.body

      def visit_FunctionDef(self, node):
        node.body = self.visit_block(
            node.body, after_visit=self._process_body_item)
        return node

    def test_function(x, y):
      z = x
      z = y
      return z

    block_transformer = TestTransformer(self._simple_context())

    fn_node, _ = parser.parse_entity(test_function, future_features=())
    fn_node = block_transformer.visit(fn_node)

    stmts = fn_node.body
    self.assertEqual(len(stmts), 2)
    first, wrapper = stmts
    self.assertIsInstance(first, gast.Assign)
    self.assertIsInstance(wrapper, gast.If)
    self.assertIsInstance(wrapper.body[0], gast.Assign)
    self.assertIsInstance(wrapper.body[1], gast.Return)
Exemplo n.º 20
0
  def test_robust_error_on_list_visit(self):
    """Visiting a list instead of a node raises a ValueError naming `list`."""

    class BrokenTransformer(transformer.Base):

      def visit_If(self, node):
        # This is broken because visit expects a single node, not a list, and
        # the body of an if is a list.
        # Importantly, the default error handling in visit also expects a single
        # node.  Therefore, mistakes like this need to trigger a type error
        # before the visit called here installs its error handler.
        # That type error can then be caught by the enclosing call to visit,
        # and correctly blame the If node.
        self.visit(node.body)
        return node

    def test_function(x):
      if x > 0:
        return x

    tr = BrokenTransformer(self._simple_context())

    node, _ = parser.parse_entity(test_function, future_features=())
    with self.assertRaises(ValueError) as cm:
      tr.visit(node)
    obtained_message = str(cm.exception)
    expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
    # assertRegexpMatches is a deprecated alias that was removed in Python
    # 3.12; use assertRegex where it exists and fall back only otherwise.
    assert_regex = (getattr(self, 'assertRegex', None) or
                    self.assertRegexpMatches)
    assert_regex(obtained_message, expected_message)
Exemplo n.º 21
0
    def _should_compile(self, node, fqn):
        """Decides whether the entity called by `node` should be compiled.

        Args:
            node: call node whose `func` is resolved to find the target.
            fqn: fully qualified name of the target as a sequence; its first
                element is treated as the module name.

        Returns:
            False when the target belongs to an uncompiled module, `node` has
            a 'graph_ready' annotation, the target is a compile decorator or
            is listed in `strip_decorators`, or one of its decorators resolves
            to a stripped decorator. True otherwise, including when the target
            cannot be resolved or parse_entity raises TypeError on it.
        """
        # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
        module_name = fqn[0]
        uncompiled = self.ctx.program.uncompiled_modules
        if any(module_name.startswith(mod[0] + '.') for mod in uncompiled):
            return False
        if any(fqn[:i] in uncompiled for i in range(1, len(fqn))):
            return False

        # Check for local decorations.
        if anno.hasanno(node, 'graph_ready'):
            return False

        # The decorators themselves are not to be converted.
        # If present, the decorators should appear as static functions.
        target_entity = self._try_resolve_target(node.func)
        if target_entity is None:
            return True

        # This may be reached when "calling" a callable attribute of an object.
        # For example:
        #
        #   self.fc = tf.keras.layers.Dense()
        #   self.fc()
        #
        if any(target_entity.__module__.startswith(mod[0] + '.')
               for mod in uncompiled):
            return False

        # This attribute is set by the decorator itself.
        # TODO(mdan): This may not play nicely with other wrapping decorators.
        if hasattr(target_entity, '__pyct_is_compile_decorator'):
            return False

        stripped = self.ctx.program.options.strip_decorators
        if target_entity in stripped:
            return False

        # Inspect the target function decorators. If any include a @convert
        # or @graph_ready annotation, then they must be called as they are.
        # TODO(mdan): This may be quite heavy.
        # To parse and re-analyze each function for every call site could be
        # quite wasteful. Maybe we could cache the parsed AST?
        try:
            parsed_target, _ = parser.parse_entity(target_entity)
            target_node = parsed_target.body[0]
        except TypeError:
            # Functions whose source we cannot access are compilable (e.g.
            # wrapped to py_func).
            return True

        # Resolved lazily, so resolution stops at the first stripped decorator.
        resolved_decorators = (
            self._resolve_name(dec) for dec in target_node.decorator_list)
        return not any(fn is not None and fn in stripped
                       for fn in resolved_decorators)
Exemplo n.º 22
0
  def test_parse_entity(self):
    """parse_entity's result has the parsed function as its first body node."""

    def f(x):
      return x + 1

    container, _ = parser.parse_entity(f)
    first_stmt = container.body[0]
    self.assertEqual('f', first_stmt.name)
Exemplo n.º 23
0
    def disabled_test_resolve_with_future_imports(self):
        """Origins for a function that also contains a print() call."""
        def test_fn(x):
            """Docstring."""
            print(x)
            return x  # comment

        # NOTE(review): the result is unpacked as (node, source, _) here;
        # confirm which tuple position holds the source before re-enabling.
        # The expected line numbers assume one line precedes the `def`
        # (presumably a future-import preamble); also confirm that.
        fn_node, fn_source, _ = parser.parse_entity(test_fn)

        origin_info.resolve(fn_node, fn_source)

        def check(target, lineno, col_offset, line_text, comment):
            origin = anno.getanno(target, anno.Basic.ORIGIN)
            self.assertEqual(origin.loc.lineno, lineno)
            self.assertEqual(origin.loc.col_offset, col_offset)
            self.assertEqual(origin.source_code_line, line_text)
            if comment is None:
                self.assertIsNone(origin.comment)
            else:
                self.assertEqual(origin.comment, comment)

        check(fn_node, 2, 0, 'def test_fn(x):', None)
        check(fn_node.body[0], 3, 2, '  """Docstring."""', None)
        check(fn_node.body[2], 5, 2, '  return x  # comment', 'comment')
Exemplo n.º 24
0
  def test_visit_block_postprocessing(self):
    """An after_visit hook can wrap a statement in a new `if` block."""

    class TestTransformer(transformer.Base):

      def _process_body_item(self, node):
        # Only assignments whose value is the name `y` get wrapped; all other
        # statements pass through with no replacement block.
        if not (isinstance(node, gast.Assign) and node.value.id == 'y'):
          return node, None
        condition = gast.Name(
            'x', ctx=gast.Load(), annotation=None, type_comment=None)
        wrapper = gast.If(condition, [node], [])
        return wrapper, wrapper.body

      def visit_FunctionDef(self, node):
        node.body = self.visit_block(
            node.body, after_visit=self._process_body_item)
        return node

    def test_function(x, y):
      z = x
      z = y
      return z

    block_transformer = TestTransformer(self._simple_context())

    fn_node, _ = parser.parse_entity(test_function, future_features=())
    fn_node = block_transformer.visit(fn_node)

    stmts = fn_node.body
    self.assertEqual(len(stmts), 2)
    first, wrapper = stmts
    self.assertIsInstance(first, gast.Assign)
    self.assertIsInstance(wrapper, gast.If)
    wrapped_body = wrapper.body
    self.assertIsInstance(wrapped_body[0], gast.Assign)
    self.assertIsInstance(wrapped_body[1], gast.Return)
Exemplo n.º 25
0
    def test_robust_error_on_list_visit(self):
        """Visiting a list instead of a node is blamed on the enclosing If."""
        class BrokenTransformer(transformer.Base):
            def visit_If(self, node):
                # This is broken because visit expects a single node, not a list, and
                # the body of an if is a list.
                # Importantly, the default error handling in visit also expects a single
                # node.  Therefore, mistakes like this need to trigger a type error
                # before the visit called here installs its error handler.
                # That type error can then be caught by the enclosing call to visit,
                # and correctly blame the If node.
                self.visit(node.body)
                return node

        def test_function(x):
            if x > 0:
                return x

        tr = BrokenTransformer(self._simple_context())

        node, _ = parser.parse_entity(test_function, future_features=())
        with self.assertRaises(ValueError) as cm:
            tr.visit(node)
        obtained_message = str(cm.exception)
        expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
        # assertRegexpMatches is a deprecated alias that was removed in Python
        # 3.12; use assertRegex where it exists and fall back only otherwise.
        assert_regex = (getattr(self, 'assertRegex', None) or
                        self.assertRegexpMatches)
        assert_regex(obtained_message, expected_message)
        # The exception should point at the if statement, not any place else.
        # Could also check the stack trace.
        self.assertIn('Occurred at node:\nIf', obtained_message,
                      obtained_message)
        for other_node in ('FunctionDef', 'Return'):
            self.assertNotIn('Occurred at node:\n' + other_node,
                             obtained_message, obtained_message)
Exemplo n.º 26
0
  def test_parse_entity_print_function(self):
    """A function calling print() parses with the print_function feature."""

    def f(x):
      print(x)

    parsed_fn, _source = parser.parse_entity(
        f, future_features=('print_function',))
    self.assertEqual('f', parsed_fn.name)
Exemplo n.º 27
0
    def prepare(self,
                test_fn,
                namespace,
                namer=None,
                arg_types=None,
                owner_type=None,
                recursive=True,
                strip_decorators=()):
        """Parses `test_fn` and runs standard analysis on it.

        Args:
            test_fn: the function to parse.
            namespace: dict of names visible to `test_fn`. A
                'ConversionOptions' entry is added to it in place.
            namer: naming helper; a FakeNamer is created when None.
            arg_types: forwarded to `transformer.EntityInfo`.
            owner_type: forwarded to `transformer.EntityInfo`.
            recursive: forwarded to `converter.ConversionOptions`.
            strip_decorators: forwarded to `converter.ConversionOptions`.

        Returns:
            A `(node, ctx)` tuple: the analyzed function node and the
            EntityContext built for it.
        """
        namespace['ConversionOptions'] = converter.ConversionOptions

        module_node, source = parser.parse_entity(test_fn)
        fn_node = module_node.body[0]
        namer = FakeNamer() if namer is None else namer

        options = converter.ConversionOptions(
            recursive=recursive,
            strip_decorators=strip_decorators,
            verbose=True)
        program_ctx = converter.ProgramContext(
            options=options,
            partial_types=None,
            autograph_module=None,
            uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
        entity_info = transformer.EntityInfo(
            source_code=source,
            source_file='<fragment>',
            namespace=namespace,
            arg_values=None,
            arg_types=arg_types,
            owner_type=owner_type)
        ctx = converter.EntityContext(namer, entity_info, program_ctx)
        origin_info.resolve(fn_node, source, test_fn)
        analyzed = converter.standard_analysis(fn_node, ctx, is_initial=True)
        return analyzed, ctx
Exemplo n.º 28
0
  def prepare(self,
              test_fn,
              namespace,
              namer=None,
              arg_types=None,
              owner_type=None,
              recursive=True,
              strip_decorators=()):
    """Parses `test_fn` and runs standard analysis on it.

    Args:
      test_fn: the function to parse.
      namespace: dict of names visible to `test_fn`. A 'ConversionOptions'
        entry is added to it in place.
      namer: naming helper; a FakeNamer is created when None.
      arg_types: forwarded to `transformer.EntityInfo`.
      owner_type: forwarded to `transformer.EntityInfo`.
      recursive: forwarded to `converter.ConversionOptions`.
      strip_decorators: forwarded to `converter.ConversionOptions`.

    Returns:
      A `(node, ctx)` tuple: the analyzed function node and the EntityContext
      built for it.
    """
    namespace['ConversionOptions'] = converter.ConversionOptions

    module_node, source = parser.parse_entity(test_fn)
    fn_node = module_node.body[0]
    namer = FakeNamer() if namer is None else namer

    options = converter.ConversionOptions(
        recursive=recursive, strip_decorators=strip_decorators, verbose=True)
    program_ctx = converter.ProgramContext(
        options=options,
        partial_types=None,
        autograph_module=None,
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=None,
        arg_types=arg_types,
        owner_type=owner_type)
    ctx = converter.EntityContext(namer, entity_info, program_ctx)
    origin_info.resolve(fn_node, source, test_fn)
    analyzed = converter.standard_analysis(fn_node, ctx, is_initial=True)
    return analyzed, ctx
Exemplo n.º 29
0
  def test_resolve_entity(self):
    """resolve_entity reports line numbers relative to the defining file."""
    test_fn = basic_definitions.simple_function
    future_features = inspect_utils.getfutureimports(test_fn)
    node, source = parser.parse_entity(test_fn, future_features)
    origin_info.resolve_entity(node, source, test_fn)

    # The line numbers below should match those in basic_definitions.py
    fn_start = inspect.getsourcelines(test_fn)[1]

    expected = (
        (node, fn_start, 0, 'def simple_function(x):', None),
        (node.body[0], fn_start + 1, 2, '  """Docstring."""', None),
        (node.body[1], fn_start + 2, 2, '  return x  # comment', 'comment'),
    )
    for target, lineno, col_offset, line_text, comment in expected:
      origin = anno.getanno(target, anno.Basic.ORIGIN)
      self.assertEqual(origin.loc.lineno, lineno)
      self.assertEqual(origin.loc.col_offset, col_offset)
      self.assertEqual(origin.source_code_line, line_text)
      if comment is None:
        self.assertIsNone(origin.comment)
      else:
        self.assertEqual(origin.comment, comment)
Exemplo n.º 30
0
  def test_parse_entity(self):
    """The node returned by parse_entity carries the function's name."""

    def f(x):
      return x + 1

    fn_node, _unused = parser.parse_entity(f, future_features=())
    self.assertEqual('f', fn_node.name)
Exemplo n.º 31
0
def convert_func_to_ast(f, program_ctx, do_rename=True):
    """Converts a callable (function or lambda) into a transformed AST.

    Specialization of `convert_entity_to_ast` for callable functions.

    Args:
        f: the function or lambda to convert.
        program_ctx: the program-wide conversion context. Its
            `autograph_module` is registered in the namespace of `f`.
        do_rename: if true, the converted function receives a freshly
            generated name; otherwise it keeps `f.__name__`.

    Returns:
        A tuple `(nodes, new_name, entity_info)`. `nodes` is a one-element
        tuple holding the converted node; for lambdas this is a
        `gast.Assign` that binds the lambda to `new_name`. `new_name` is the
        name of the converted entity. `entity_info` is the
        `transformer.EntityInfo` used for the conversion.

    Raises:
        ValueError: if `f` is a lambda and its source line does not contain
            exactly one lambda with a matching signature.
    """
    future_features = inspect_utils.getfutureimports(f)
    node, source = parser.parse_entity(f, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    # Parsed AST should contain future imports and one function def node.

    # inspect.getsource is imprecise for lambdas: it locates them by regex
    # matching around the line number CPython recorded and returns the whole
    # containing line, which may hold several lambdas, e.g.
    #   x, y = lambda: 1, lambda: 2
    # so the matching definition must be picked out of the parsed line.
    if f.__name__ == '<lambda>':
        candidates = ast_util.find_matching_definitions(node, f)
        if len(candidates) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        node = candidates[0]

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve_entity(node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    # Decide the name the converted entity will be bound to.
    if not isinstance(node, gast.Lambda):
        new_name = namer.function_name(f.__name__) if do_rename else f.__name__
    else:
        new_name = namer.new_symbol('tf__lambda', ())

    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    context = converter.EntityContext(
        namer, entity_info, program_ctx, new_name)
    node = node_to_graph(node, context)

    if isinstance(node, gast.Lambda):
        # Lambdas are anonymous, so bind the converted one to `new_name`.
        name_target = gast.Name(
            new_name, ctx=gast.Store(), annotation=None, type_comment=None)
        node = gast.Assign(targets=[name_target], value=node)
    elif do_rename:
        node.name = new_name
    else:
        assert node.name == new_name

    return (node,), new_name, entity_info
Exemplo n.º 32
0
  def test_parse_lambda_prefix_cleanup(self):
    """The parsed lambda matches both its returned source and the bare lambda."""

    lambda_lam = lambda x: x + 1

    node, source = parser.parse_entity(lambda_lam, future_features=())
    for expected_src in (source, 'lambda x: (x + 1)'):
      self.assertAstMatches(node, expected_src)
Exemplo n.º 33
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn`, then runs qualified-name and activity analysis.

   Returns:
     A `(node, entity_info)` pair: the annotated AST and the EntityInfo built
     from the function's source.
   """
   node, source = parser.parse_entity(test_fn, future_features=())
   info = transformer.EntityInfo(
       source_code=source,
       source_file=None,
       future_features=(),
       namespace={})
   qualified = qual_names.resolve(node)
   analyzed = activity.resolve(qualified, transformer.Context(info))
   return analyzed, info
Exemplo n.º 34
0
 def test_basic(self):
   """ANF conversion preserves the result of a trivial function."""
   def test_function():
     a = 0
     return a
   module_node, _ = parser.parse_entity(test_function)
   fn_node = anf.transform(module_node.body[0], self._simple_source_info())
   compiled, _ = compiler.ast_to_object(fn_node)
   self.assertEqual(test_function(), compiled.test_function())
Exemplo n.º 35
0
 def test_basic(self):
   """ANF conversion preserves the result of a trivial function."""
   def test_function():
     a = 0
     return a
   parsed, _, _ = parser.parse_entity(test_function, future_imports=())
   transformed = anf.transform(parsed, self._simple_context())
   compiled, _ = compiler.ast_to_object(transformed)
   self.assertEqual(test_function(), compiled.test_function())
Exemplo n.º 36
0
    def test_parse_multiline_strings(self):
        def f():
            print("""
multiline
string""")

        node, _ = parser.parse_entity(f, future_features=())
        self.assertEqual('f', node.name)
Exemplo n.º 37
0
        def do_parse_and_test(lam, **unused_kwargs):
            # Parse the lambda, then check its unparsed form and raw source.
            node, source = parser.parse_entity(lam, future_features=())
            unparsed = parser.unparse(node, include_encoding_marker=False)
            self.assertEqual(unparsed, '(lambda x: x)')
            self.assertMatchesWithPotentialGarbage(
                source, 'lambda x: x', ', named_arg=1)')
Exemplo n.º 38
0
    def test_parse_lambda_prefix_cleanup(self):
        """Only the lambda expression remains; the assignment prefix is dropped."""

        lambda_lam = lambda x: x + 1

        node, source = parser.parse_entity(lambda_lam, future_features=())
        unparsed = parser.unparse(node, include_encoding_marker=False)
        self.assertEqual(unparsed, '(lambda x: (x + 1))')
        self.assertEqual(source, 'lambda x: x + 1')
Exemplo n.º 39
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn`, then runs qualified-name and activity analysis.

   Returns:
     A `(node, entity_info)` pair: the annotated AST and the EntityInfo built
     from the function's source.
   """
   node, source = parser.parse_entity(test_fn, future_features=())
   info = transformer.EntityInfo(
       source_code=source,
       source_file=None,
       future_features=(),
       namespace={})
   qualified = qual_names.resolve(node)
   analyzed = activity.resolve(qualified, transformer.Context(info))
   return analyzed, info
Exemplo n.º 40
0
    def test_entity_scope_tracking(self):
        """Visited nodes see the chain of entities that enclose them."""
        class TestTransformer(transformer.Base):
            """Records `enclosing_entities` on every Assign and BinOp node."""

            def _store_enclosing(self, node):
                anno.setanno(node, 'enclosing_entities',
                             self.enclosing_entities)
                return self.generic_visit(node)

            # The choice of node to annotate is arbitrary; Assign is easy to
            # find in the tree.
            def visit_Assign(self, node):
                return self._store_enclosing(node)

            # BinOp is annotated because it is the body of the lambda below.
            def visit_BinOp(self, node):
                return self._store_enclosing(node)

        tr = TestTransformer(self._simple_context())

        def test_function():
            a = 0

            class TestClass(object):
                def test_method(self):
                    b = 0

                    def inner_function(x):
                        c = 0
                        d = lambda y: (x + y)
                        return c, d

                    return b, inner_function

            return a, TestClass

        node, _ = parser.parse_entity(test_function, future_features=())
        fn_node = tr.visit(node)

        class_node = fn_node.body[1]
        method_node = class_node.body[0]
        inner_fn_node = method_node.body[1]
        lambda_node = inner_fn_node.body[1].value

        # (annotated node, expected enclosing entities from outermost inwards)
        cases = (
            (fn_node.body[0], (fn_node, )),
            (method_node.body[0], (fn_node, class_node, method_node)),
            (inner_fn_node.body[0],
             (fn_node, class_node, method_node, inner_fn_node)),
            (lambda_node.body,
             (fn_node, class_node, method_node, inner_fn_node, lambda_node)),
        )
        for annotated, expected in cases:
            self.assertEqual(expected,
                             anno.getanno(annotated, 'enclosing_entities'))
Exemplo n.º 41
0
  def test_basic(self):
    """ANF conversion preserves the result of a trivial function."""
    def test_function():
      a = 0
      return a

    parsed, _ = parser.parse_entity(test_function, future_features=())
    transformed = anf.transform(parsed, self._simple_context())
    loaded_module, _, _ = loader.load_ast(transformed)
    self.assertEqual(test_function(), loaded_module.test_function())
Exemplo n.º 42
0
    def test_basic(self):
        """ANF conversion preserves the result of a trivial function."""
        def test_function():
            a = 0
            return a

        module_node, _ = parser.parse_entity(test_function)
        fn_node = anf.transform(module_node.body[0],
                                self._simple_source_info())
        compiled, _ = compiler.ast_to_object(fn_node)
        self.assertEqual(test_function(), compiled.test_function())
Exemplo n.º 43
0
    def test_parse_lambda_resolution_ambiguous(self):
        """Nested lambdas on one line cannot be disambiguated and are rejected."""

        l = lambda x: lambda x: 2 * x

        # The error message is expected to list both candidate definitions.
        pattern = re.compile(
            r'found multiple definitions'
            r'.+'
            r'\(lambda x: \(lambda x'
            r'.+'
            r'\(lambda x: \(2', re.DOTALL)

        # Both the outer lambda and the inner one it returns are ambiguous.
        for entity in (l, l(0)):
            with self.assertRaisesRegex(errors.UnsupportedLanguageElementError,
                                        pattern):
                parser.parse_entity(entity, future_features=())
Exemplo n.º 44
0
 def assert_body_anfs_as_expected(self, expected_fn, test_fn):
   """Asserts that ANF-transforming `test_fn` yields the AST of `expected_fn`.

   Also asserts that applying the ANF transform a second time leaves the
   result unchanged (idempotence).

   Args:
     expected_fn: function whose body is the expected ANF output.
     test_fn: function whose body is transformed.
   """
   import copy  # Local import keeps this helper self-contained.

   # Testing the code bodies only.  Wrapping them in functions so the
   # syntax highlights nicely, but Python doesn't try to execute the
   # statements.
   exp_node, _ = parser.parse_entity(expected_fn)
   node, _ = parser.parse_entity(test_fn)
   node = anf.transform(
       node, self._simple_source_info(), gensym_source=DummyGensym)
   exp_name = exp_node.body[0].name
   # Ignoring the function names in the result because they can't be
   # the same (because both functions have to exist in the same scope
   # at the same time).
   node.body[0].name = exp_name
   self.assert_same_ast(exp_node, node)
   # Check that ANF is idempotent. Transform a deep copy: if the transformer
   # modifies and returns its input, re-transforming `node` itself would make
   # the comparison below check an object against itself and pass trivially.
   node_repeated = anf.transform(
       copy.deepcopy(node), self._simple_source_info(),
       gensym_source=DummyGensym)
   self.assert_same_ast(node_repeated, node)
Exemplo n.º 45
0
  def _should_compile(self, node, fqn):
    """Determines whether an entity should be compiled in the context.

    Args:
      node: the call node under consideration; its `func` is resolved to the
        target entity.
      fqn: the qualified name of the call target; `fqn[0]` is the module name
        and `fqn[:i]` are checked against the uncompiled modules.

    Returns:
      False when the target is in an uncompiled module, is marked with
      `__ag_compiled`, or is decorated with one of
      `options.strip_decorators`; True otherwise. Lambdas and targets whose
      source parse raises TypeError are treated as compilable.
    """
    # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
    module_name = fqn[0]
    for mod in self.ctx.program.uncompiled_modules:
      if module_name.startswith(mod[0] + '.'):
        return False

    for i in range(1, len(fqn)):
      if fqn[:i] in self.ctx.program.uncompiled_modules:
        return False

    target_entity = self._try_resolve_target(node.func)

    if target_entity is not None:

      # Currently, lambdas are always converted.
      # TODO(mdan): Allow markers of the kind f = ag.do_not_convert(lambda: ...)
      if inspect_utils.islambda(target_entity):
        return True

      # This may be reached when "calling" a callable attribute of an object.
      # For example:
      #
      #   self.fc = tf.keras.layers.Dense()
      #   self.fc()
      #
      # `__module__` can be None (e.g. bound methods of builtin types such as
      # `[].append`) or missing entirely; treat such targets as not belonging
      # to any uncompiled module instead of crashing on `None.startswith`.
      target_module = getattr(target_entity, '__module__', None) or ''
      for mod in self.ctx.program.uncompiled_modules:
        if target_module.startswith(mod[0] + '.'):
          return False

      # Inspect the target function decorators. If any include a @convert
      # or @do_not_convert annotation, then they must be called as they are.
      # TODO(mdan): This may be quite heavy. Perhaps always dynamically convert?
      # To parse and re-analyze each function for every call site could be quite
      # wasteful. Maybe we could cache the parsed AST?
      try:
        target_node, _ = parser.parse_entity(target_entity)
        target_node = target_node.body[0]
      except TypeError:
        # Functions whose source we cannot access are compilable (e.g. wrapped
        # to py_func).
        return True

      # This attribute is set when the decorator was applied before the
      # function was parsed. See api.py.
      if hasattr(target_entity, '__ag_compiled'):
        return False

      for dec in target_node.decorator_list:
        decorator_fn = self._resolve_decorator_name(dec)
        if (decorator_fn is not None and
            decorator_fn in self.ctx.program.options.strip_decorators):
          return False

    return True
Exemplo n.º 46
0
def function_to_graph(f,
                      program_ctx,
                      arg_values,
                      arg_types,
                      owner_type=None):
  """Converts a function or lambda into its graph-mode AST.

  Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the function or lambda to convert.
    program_ctx: the program-wide conversion context. It supplies the namer
      and the autograph module, and its name map is updated afterwards.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo` and to
      `namer.compiled_function_name`.

  Returns:
    A triple `([node], new_name, namespace)`. The list holds the converted
    node; for lambdas this is a `gast.Assign` that binds the lambda to
    `new_name`. `new_name` is the converted entity's name and `namespace` is
    the namespace of `f`.

  Raises:
    ValueError: if `f` is a lambda and its source line does not contain
      exactly one matching lambda definition.
  """
  module_node, source = parser.parse_entity(f)
  node = module_node.body[0]

  # TODO(mdan): Can we convert everything and scoop the lambda afterwards?
  if f.__name__ == '<lambda>':
    matches = ast_util.find_matching_lambda_definitions(node, f)
    if len(matches) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which contains multiple lambdas with'
          ' identical argument names. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node = matches[0]

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  node = node_to_graph(node, converter.EntityContext(namer, info, program_ctx))

  if not isinstance(node, gast.Lambda):
    # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if not did_rename:
      new_name = f.__name__
      assert node.name == new_name
    else:
      node.name = new_name
  else:
    # Lambdas are anonymous, so bind the converted one to a fresh name.
    new_name = namer.new_symbol('tf__lambda', ())
    binding_target = gast.Name(new_name, gast.Store(), None)
    node = gast.Assign(targets=[binding_target], value=node)

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [node], new_name, namespace
Exemplo n.º 47
0
def convert_func_to_ast(f, program_ctx, do_rename=True):
  """Converts a callable (function or lambda) into a transformed AST.

  Specialization of `convert_entity_to_ast` for callable functions.

  Args:
    f: the function or lambda to convert.
    program_ctx: the program-wide conversion context. Its `autograph_module`
      is registered in the namespace of `f`.
    do_rename: if true, the converted function receives a freshly generated
      name; otherwise it keeps `f.__name__`.

  Returns:
    A tuple `(nodes, new_name, entity_info)`. `nodes` is a one-element tuple
    with the converted node; for lambdas this is a `gast.Assign` binding the
    lambda to `new_name`. `new_name` is the converted entity's name and
    `entity_info` is the `transformer.EntityInfo` used for the conversion.

  Raises:
    ValueError: if `f` is a lambda and its source line does not contain
      exactly one lambda with a matching signature.
    errors.InternalError: if `node_to_graph` raises ValueError,
      AttributeError, KeyError or NotImplementedError.
  """
  future_features = inspect_utils.getfutureimports(f)
  node, source = parser.parse_entity(f, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
  # Parsed AST should contain future imports and one function def node.

  # inspect.getsource is imprecise for lambdas: it locates them by regex
  # matching around the line number CPython recorded and returns the whole
  # containing line, which may hold several lambdas, e.g.
  #   x, y = lambda: 1, lambda: 2
  # so the matching definition must be picked out of the parsed line.
  if f.__name__ == '<lambda>':
    candidates = ast_util.find_matching_definitions(node, f)
    if len(candidates) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node = candidates[0]

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  # TODO(mdan): Catch and rethrow syntax errors.
  try:
    node = node_to_graph(node, context)
  except (ValueError, AttributeError, KeyError, NotImplementedError) as err:
    logging.error(1, 'Error converting %s', f, exc_info=True)
    raise errors.InternalError('conversion', err)

  if not isinstance(node, gast.Lambda):
    if do_rename:
      new_name = namer.function_name(f.__name__)
      node.name = new_name
    else:
      new_name = f.__name__
      assert node.name == new_name
    return (node,), new_name, entity_info

  # Lambdas are anonymous, so bind the converted one to a fresh name.
  new_name = namer.new_symbol('tf__lambda', ())
  binding = gast.Assign(
      targets=[gast.Name(new_name, gast.Store(), None)], value=node)
  return (binding,), new_name, entity_info
Exemplo n.º 48
0
  def test_entity_scope_tracking(self):
    """Visited nodes see the chain of entities that enclose them."""

    class TestTransformer(transformer.Base):
      """Records `enclosing_entities` on every Assign and BinOp node."""

      def _store_enclosing(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

      # The choice of node to annotate is arbitrary; Assign is easy to find
      # in the tree.
      def visit_Assign(self, node):
        return self._store_enclosing(node)

      # BinOp is annotated because it is the body of the lambda below.
      def visit_BinOp(self, node):
        return self._store_enclosing(node)

    tr = TestTransformer(self._simple_context())

    def test_function():
      a = 0

      class TestClass(object):

        def test_method(self):
          b = 0
          def inner_function(x):
            c = 0
            d = lambda y: (x + y)
            return c, d
          return b, inner_function
      return a, TestClass

    node, _ = parser.parse_entity(test_function, future_features=())
    fn_node = tr.visit(node)

    class_node = fn_node.body[1]
    method_node = class_node.body[0]
    inner_fn_node = method_node.body[1]
    lambda_node = inner_fn_node.body[1].value

    # (annotated node, expected enclosing entities from outermost inwards)
    cases = (
        (fn_node.body[0], (fn_node,)),
        (method_node.body[0], (fn_node, class_node, method_node)),
        (inner_fn_node.body[0],
         (fn_node, class_node, method_node, inner_fn_node)),
        (lambda_node.body,
         (fn_node, class_node, method_node, inner_fn_node, lambda_node)),
    )
    for annotated, expected in cases:
      self.assertEqual(expected,
                       anno.getanno(annotated, 'enclosing_entities'))
Exemplo n.º 49
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and annotates it with reaching definitions.

   Qualified names, activity analysis and the CFG are computed first; the
   context and CFGs are passed to `reaching_definitions.resolve`.
   """
   node, source = parser.parse_entity(test_fn, future_features=())
   info = transformer.EntityInfo(
       source_code=source,
       source_file=None,
       future_features=(),
       namespace={})
   node = qual_names.resolve(node)
   context = transformer.Context(info)
   node = activity.resolve(node, context)
   return reaching_definitions.resolve(
       node, context, cfg.build(node), reaching_definitions.Definition)
Exemplo n.º 50
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn`, then runs qualified-name and activity analysis.

   Returns:
     A `(node, entity_info)` pair: the annotated AST and the EntityInfo built
     from the function's source.
   """
   node, source = parser.parse_entity(test_fn)
   info = transformer.EntityInfo(
       source_code=source,
       source_file=None,
       namespace={},
       arg_values=None,
       arg_types=None,
       owner_type=None)
   analyzed = activity.resolve(qual_names.resolve(node), info)
   return analyzed, info
Exemplo n.º 51
0
  def test_source_map_no_origin(self):
    """Nodes without ORIGIN annotations produce an empty source map."""

    def test_fn(x):
      return x + 1

    module_node, _ = parser.parse_entity(test_fn)
    fn_node = module_node.body[0]
    generated_source = compiler.ast_to_source(fn_node)

    mapping = origin_info.create_source_map(
        fn_node, generated_source, 'test_filename', [0])

    self.assertEqual(0, len(mapping))
Exemplo n.º 52
0
  def test_parser_compile_idempotent(self):
    """Parsing then recompiling a function reproduces its source text."""

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = x + 1
      return b

    expected_source = textwrap.dedent(tf_inspect.getsource(test_fn))
    fn_node = parser.parse_entity(test_fn)[0].body[0]
    recompiled_module = compiler.ast_to_object(fn_node)[0]
    self.assertEqual(expected_source,
                     tf_inspect.getsource(recompiled_module.test_fn))
Exemplo n.º 53
0
  def test_parser_compile_identity(self):
    """Parsing then recompiling a function reproduces its source text."""

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = x + 1
      return b

    parsed, _ = parser.parse_entity(test_fn, future_features=())
    compiled_module, _, _ = compiler.ast_to_object(parsed)

    expected_source = textwrap.dedent(tf_inspect.getsource(test_fn))
    self.assertEqual(expected_source,
                     tf_inspect.getsource(compiled_module.test_fn))
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and annotates it with reaching definitions.

   Qualified names, activity analysis and the CFG are computed first; the
   entity info and CFGs are passed to `reaching_definitions.resolve`.
   """
   node, source = parser.parse_entity(test_fn)
   info = transformer.EntityInfo(
       source_code=source,
       source_file=None,
       namespace={},
       arg_values=None,
       arg_types=None,
       owner_type=None)
   node = activity.resolve(qual_names.resolve(node), info)
   return reaching_definitions.resolve(
       node, info, cfg.build(node), reaching_definitions.Definition)
Exemplo n.º 55
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and runs liveness analysis over its CFG.

   Returns:
     The AST after qualified-name and activity analysis, on which
     `liveness.resolve` has also been run.
   """
   node, source = parser.parse_entity(test_fn)
   info = transformer.EntityInfo(
       source_code=source,
       source_file=None,
       namespace={},
       arg_values=None,
       arg_types=None,
       owner_type=None)
   node = qual_names.resolve(node)
   context = transformer.Context(info)
   node = activity.resolve(node, context)
   # The return value of liveness.resolve is intentionally not used.
   liveness.resolve(node, context, cfg.build(node))
   return node
Exemplo n.º 56
0
  def test_invalid_default(self):
    """_map_args raises ValueError for a directive with an `object()` default."""

    def invalid_directive(valid_arg, invalid_default=object()):
      del valid_arg
      del invalid_default
      return

    def call_invalid_directive():
      invalid_directive(1)

    node, _ = parser.parse_entity(call_invalid_directive, ())
    # Find the call to the invalid directive: body[0] is the expression
    # statement `invalid_directive(1)` and `.value` is its Call node.
    node = node.body[0].value
    # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'Unexpected keyword.*'):
      directives_converter._map_args(node, invalid_directive)
Exemplo n.º 57
0
  def test_local_scope_info_stack(self):
    """Strings collected inside a nested loop scope do not leak outward."""

    class TestTransformer(transformer.Base):

      # Concatenates every string constant seen in the current local scope.
      def visit_Str(self, node):
        collected = self.get_local('string', default='')
        self.set_local('string', collected + node.s)
        return self.generic_visit(node)

      # Each loop gets its own local scope; the strings collected inside it
      # are stored as the loop node's 'test' annotation.
      def _annotate_result(self, node):
        self.enter_local_scope()
        visited = self.generic_visit(node)
        anno.setanno(visited, 'test', self.get_local('string'))
        self.exit_local_scope()
        return visited

      def visit_While(self, node):
        return self._annotate_result(node)

      def visit_For(self, node):
        return self._annotate_result(node)

    tr = TestTransformer(self._simple_context())

    def test_function(a):
      """Docstring."""
      assert a == 'This should not be counted'
      for i in range(3):
        _ = 'a'
        if i > 2:
          return 'b'
        else:
          _ = 'c'
          while True:
            raise '1'
      return 'nor this'

    node, _ = parser.parse_entity(test_function, future_features=())
    node = tr.visit(node)

    for_node = node.body[2]
    while_node = for_node.body[1].orelse[1]

    for loop_node, expected in ((for_node, 'abc'), (while_node, '1')):
      self.assertFalse(anno.hasanno(loop_node, 'string'))
      self.assertEqual(expected, anno.getanno(loop_node, 'test'))
Exemplo n.º 58
0
  def prepare(self, test_fn, namespace, recursive=True):
    """Parses `test_fn` and runs the converter's standard analysis on it.

    Args:
      test_fn: the function to parse and analyze.
      namespace: namespace dict for the entity; `ConversionOptions` is added
        to it in place.
      recursive: passed to `converter.ConversionOptions`.

    Returns:
      A `(node, ctx)` pair: the analyzed AST and the
      `converter.EntityContext` built for it.
    """
    namespace['ConversionOptions'] = converter.ConversionOptions

    future_features = ('print_function', 'division')
    node, source = parser.parse_entity(test_fn, future_features=future_features)
    namer = naming.Namer(namespace)
    options = converter.ConversionOptions(recursive=recursive)
    program_ctx = converter.ProgramContext(
        options=options, autograph_module=None)
    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    ctx = converter.EntityContext(namer, entity_info, program_ctx)
    origin_info.resolve(node, source, test_fn)
    analyzed = converter.standard_analysis(node, ctx, is_initial=True)
    return analyzed, ctx
Exemplo n.º 59
0
  def test_create_source_map(self):
    """Checks that an ORIGIN annotation is picked up by create_source_map."""

    def test_fn(x):
      return x + 1

    node, _ = parser.parse_entity(test_fn)
    # An arbitrary origin object; the assertions below check that this exact
    # object ends up in the source map.
    fake_origin = origin_info.OriginInfo(
        loc=origin_info.Location('fake_filename', 3, 7),
        function_name='fake_function_name',
        source_code_line='fake source line',
        comment=None)
    fn_node = node.body[0]
    # Attach the fake origin to the first statement of the function body.
    anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, fake_origin)
    converted_code = compiler.ast_to_source(fn_node)

    source_map = origin_info.create_source_map(
        fn_node, converted_code, 'test_filename', [0])

    # NOTE(review): line 2 presumably is the annotated `return` statement in
    # the generated code (line 1 being the `def`) — confirm against
    # compiler.ast_to_source output.
    loc = origin_info.LineLocation('test_filename', 2)
    self.assertIn(loc, source_map)
    self.assertIs(source_map[loc], fake_origin)