示例#1
0
    def test_local_scope_info_stack_checks_integrity(self):
        """Unbalanced local scope enter/exit calls must raise AssertionError."""

        class TestTransformer(transformer.Base):
            # Opens a local scope on every If and never closes it.
            def visit_If(self, node):
                self.enter_local_scope()
                return self.generic_visit(node)

            # Closes a local scope on every For without having opened one.
            def visit_For(self, node):
                visited = self.generic_visit(node)
                self.exit_local_scope()
                return visited

        def no_exit(a):
            if a > 0:
                print(a)
            return None

        def no_entry(a):
            for _ in a:
                print(a)

        tr = TestTransformer(self._simple_source_info())

        # Both malformed functions are visited with the same transformer.
        for unbalanced_fn in (no_exit, no_entry):
            fn_module, _ = parser.parse_entity(unbalanced_fn)
            with self.assertRaises(AssertionError):
                tr.visit(fn_module)
示例#2
0
  def test_local_scope_info_stack_checks_integrity(self):
    """Unbalanced local scope enter/exit calls must raise AssertionError."""

    class TestTransformer(transformer.Base):

      # Opens a local scope on every If and never closes it.
      def visit_If(self, node):
        self.enter_local_scope()
        return self.generic_visit(node)

      # Closes a local scope on every For without having opened one.
      def visit_For(self, node):
        result = self.generic_visit(node)
        self.exit_local_scope()
        return result

    def no_exit(a):
      if a > 0:
        print(a)
      return None

    def no_entry(a):
      for _ in a:
        print(a)

    tr = TestTransformer(self._context_for_testing())

    # Both malformed functions are visited with the same transformer.
    for malformed_fn in (no_exit, no_entry):
      parsed, _ = parser.parse_entity(malformed_fn)
      with self.assertRaises(AssertionError):
        tr.visit(parsed)
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
    """Specialization of `entity_to_graph` for callable functions.

    Parses the source of `f` and records origin information on its function
    definition node. It then runs the conversion pipeline on that node and
    gives the result its compiled name.

    Args:
      f: the Python function to convert.
      program_ctx: program-wide state; provides the autograph module, the
        namer factory and the name map updated here.
      arg_values: forwarded to `transformer.EntityInfo`.
      arg_types: forwarded to `transformer.EntityInfo`.
      owner_type: forwarded to `transformer.EntityInfo`; also used when
        computing the compiled function name.

    Returns:
      A `(node, new_name, namespace)` tuple with these elements:
        - the converted function definition node;
        - the name it was given;
        - the namespace of `f`.

    Raises:
      NotImplementedError: if no rename happened but the converted node's name
        differs from `f.__name__`.
    """
    parsed_module, source = parser.parse_entity(f)
    fn_node = parsed_module.body[0]
    origin_info.resolve(fn_node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = program_ctx.new_namer(namespace)

    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=arg_values,
        arg_types=arg_types,
        owner_type=owner_type)
    entity_ctx = converter.EntityContext(namer, entity_info, program_ctx)
    fn_node = node_to_graph(fn_node, entity_ctx)

    # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
    new_name, did_rename = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if not did_rename:
        # Without a rename, the converted node must still carry the original name.
        if fn_node.name != f.__name__:
            raise NotImplementedError(
                'Strange corner case. Send us offending code!')
        new_name = f.__name__

    fn_node.name = new_name
    program_ctx.update_name_map(namer)
    # TODO(mdan): Use this at compilation.

    return fn_node, new_name, namespace
示例#4
0
  def test_visit_block_postprocessing(self):
    """visit_block's `after_visit` hook can nest later statements of a block.

    `_process_body_item` wraps the `z = y` assignment in `if x:` and returns
    the If's body list as its second value. The assertions expect the
    following `return z` to land inside that If as well. So the returned list
    appears to be where subsequent statements of the block are placed.
    """

    class TestTransformer(transformer.Base):

      def _process_body_item(self, node):
        # Wrap `<target> = y` in `if x: ...`; other nodes pass through. The
        # Name check avoids an AttributeError on non-Name right-hand sides.
        if (isinstance(node, gast.Assign) and
            isinstance(node.value, gast.Name) and node.value.id == 'y'):
          if_node = gast.If(gast.Name('x', gast.Load(), None), [node], [])
          return if_node, if_node.body
        return node, None

      def visit_FunctionDef(self, node):
        node.body = self.visit_block(
            node.body, after_visit=self._process_body_item)
        return node

    def test_function(x, y):
      z = x
      z = y
      return z

    tr = TestTransformer(self._context_for_testing())

    node, _ = parser.parse_entity(test_function)
    node = tr.visit(node)
    node = node.body[0]

    # Expected shape: [z = x, if x: [z = y, return z]]
    self.assertEqual(len(node.body), 2)
    self.assertIsInstance(node.body[0], gast.Assign)
    self.assertIsInstance(node.body[1], gast.If)
    self.assertIsInstance(node.body[1].body[0], gast.Assign)
    self.assertIsInstance(node.body[1].body[1], gast.Return)
示例#5
0
  def test_parse_entity(self):
    """The first node of the parsed result is the definition of the function."""

    def f(x):
      return x + 1

    parsed_module, _ = parser.parse_entity(f)
    fn_def = parsed_module.body[0]
    self.assertEqual('f', fn_def.name)
  def test_resolve(self):
    """origin_info.resolve attaches source-origin annotations to nodes.

    Line numbers and column offsets are checked relative to the source text
    returned by `parser.parse_entity` for `test_fn`. A trailing `#` comment
    is exposed separately through `origin.comment`.
    """

    # NOTE: the exact text of test_fn is test data; do not reformat it.
    def test_fn(x):
      """Docstring."""
      return x  # comment

    node, source = parser.parse_entity(test_fn)
    fn_node = node.body[0]
    origin_info.resolve(fn_node, source)

    # The `def` line: first line, column 0, no comment.
    origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, 1)
    self.assertEqual(origin.loc.col_offset, 0)
    self.assertEqual(origin.source_code_line, 'def test_fn(x):')
    self.assertIsNone(origin.comment)

    # The docstring statement: second line, indented by two spaces.
    origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, 2)
    self.assertEqual(origin.loc.col_offset, 2)
    self.assertEqual(origin.source_code_line, '  """Docstring."""')
    self.assertIsNone(origin.comment)

    # The return statement, whose trailing comment text is captured.
    origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, 3)
    self.assertEqual(origin.loc.col_offset, 2)
    self.assertEqual(origin.source_code_line, '  return x  # comment')
    self.assertEqual(origin.comment, 'comment')
示例#7
0
    def test_robust_error_on_list_visit(self):
        """A TypeError raised inside visit_If is reported against the If node.

        The failure must surface as `transformer.AutographParseError`. Its
        message must blame the If node and not the enclosing FunctionDef or
        the nested Return.
        """
        class BrokenTransformer(transformer.Base):
            def visit_If(self, node):
                # This is broken because visit expects a single node, not a list, and
                # the body of an if is a list.
                # Importantly, the default error handling in visit also expects a single
                # node.  Therefore, mistakes like this need to trigger a type error
                # before the visit called here installs its error handler.
                # That type error can then be caught by the enclosing call to visit,
                # and correctly blame the If node.
                self.visit(node.body)
                return node

        def test_function(x):
            if x > 0:
                return x

        tr = BrokenTransformer(self._simple_source_info())

        node, _ = parser.parse_entity(test_function)
        with self.assertRaises(transformer.AutographParseError) as cm:
            tr.visit(node)
        obtained_message = str(cm.exception)
        expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
        self.assertRegexpMatches(obtained_message, expected_message)
        # The exception should point at the if statement, not any place else.
        # Could also check the stack trace. assertIn/assertNotIn already
        # include the full message in their failure output.
        self.assertIn('Occurred at node:\nIf', obtained_message)
        self.assertNotIn('Occurred at node:\nFunctionDef', obtained_message)
        self.assertNotIn('Occurred at node:\nReturn', obtained_message)
示例#8
0
 def prepare(self,
             test_fn,
             namespace,
             namer=None,
             arg_types=None,
             owner_type=None,
             recursive=True,
             autograph_decorators=()):
     """Parses `test_fn` and runs `converter.standard_analysis` on its def.

     Args:
       test_fn: function whose source is parsed.
       namespace: namespace recorded in the entity info.
       namer: namer for the entity context; a new `FakeNamer` when None.
       arg_types: forwarded to `transformer.EntityInfo`.
       owner_type: forwarded to `transformer.EntityInfo`.
       recursive: forwarded to `converter.ProgramContext`.
       autograph_decorators: forwarded to `converter.ProgramContext`.

     Returns:
       A `(node, ctx)` tuple: the analyzed function definition node and the
       `converter.EntityContext` used for the analysis.
     """
     parsed_module, source = parser.parse_entity(test_fn)
     fn_node = parsed_module.body[0]
     namer = FakeNamer() if namer is None else namer

     program_ctx = converter.ProgramContext(
         recursive=recursive,
         autograph_decorators=autograph_decorators,
         partial_types=None,
         autograph_module=None,
         uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
     entity_info = transformer.EntityInfo(
         source_code=source,
         source_file='<fragment>',
         namespace=namespace,
         arg_values=None,
         arg_types=arg_types,
         owner_type=owner_type)
     ctx = converter.EntityContext(namer, entity_info, program_ctx)

     analyzed = converter.standard_analysis(fn_node, ctx, is_initial=True)
     return analyzed, ctx
def function_to_graph(f, conversion_map, arg_values, arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the Python function to convert.
    conversion_map: conversion-wide state. It provides the API module, the
      namer factory, the recursion flag and the nocompile decorators. It also
      holds the name map and the additional-imports set updated here.
    arg_values: forwarded to `context.EntityContext`.
    arg_types: forwarded to `context.EntityContext`.
    owner_type: forwarded to `context.EntityContext`; also used to compute
      the compiled function name.

  Returns:
    A `(node, new_name, namespace)` tuple with these elements:
      - the converted function definition node;
      - the name it was given;
      - the namespace of `f`.

  Raises:
    NotImplementedError: if no rename happened but the converted node's name
      differs from `f.__name__`.
  """
  parsed_module, source = parser.parse_entity(f)
  fn_node = parsed_module.body[0]

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, conversion_map.api_module)
  namer = conversion_map.new_namer(namespace)

  entity_ctx = context.EntityContext(
      namer=namer,
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type,
      recursive=conversion_map.recursive,
      type_annotation_func=type_hints.set_element_type)
  fn_node, deps = node_to_graph(
      fn_node, entity_ctx, conversion_map.nocompile_decorators)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
  new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
  if not did_rename:
    # Without a rename, the converted node must still carry the original name.
    if fn_node.name != f.__name__:
      raise NotImplementedError('Strange corner case. Send us offending code!')
    new_name = f.__name__

  fn_node.name = new_name
  conversion_map.update_name_map(namer)
  # TODO(mdan): Use this at compilation.
  conversion_map.additional_imports.update(deps)

  return fn_node, new_name, namespace
示例#10
0
def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the Python function to convert.
    program_ctx: program-wide state; provides the autograph module, the namer
      factory and the name map updated here.
    arg_values: forwarded to `transformer.EntityInfo`.
    arg_types: forwarded to `transformer.EntityInfo`.
    owner_type: forwarded to `transformer.EntityInfo`; also used to compute
      the compiled function name.

  Returns:
    A `(node, new_name, namespace)` tuple with these elements:
      - the converted function definition node;
      - the name it was given;
      - the namespace of `f`.

  Raises:
    NotImplementedError: if no rename happened but the converted node's name
      differs from `f.__name__`.
  """
  parsed_module, source = parser.parse_entity(f)
  fn_node = parsed_module.body[0]

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  fn_node = node_to_graph(
      fn_node, converter.EntityContext(namer, entity_info, program_ctx))

  # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
  new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
  if not did_rename:
    # Without a rename, the converted node must still carry the original name.
    if fn_node.name != f.__name__:
      raise NotImplementedError('Strange corner case. Send us offending code!')
    new_name = f.__name__

  fn_node.name = new_name
  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return fn_node, new_name, namespace
示例#11
0
def function_to_graph(f, conversion_map, arg_values, arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: the Python function to convert.
    conversion_map: conversion-wide state. It provides the API module, the
      namer factory, the recursion flag and the nocompile decorators. It also
      holds the name map and the additional-imports set updated here.
    arg_values: forwarded to `context.EntityContext`.
    arg_types: forwarded to `context.EntityContext`.
    owner_type: forwarded to `context.EntityContext`; also used to compute
      the compiled function name.

  Returns:
    A `(node, new_name)` tuple: the converted function definition node and
    the name it was given.

  Raises:
    NotImplementedError: if no rename happened but the converted node's name
      differs from `f.__name__`.
  """
  parsed_module, source = parser.parse_entity(f)
  fn_node = parsed_module.body[0]

  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, conversion_map.api_module)
  namer = conversion_map.new_namer(namespace)

  entity_ctx = context.EntityContext(
      namer=namer,
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type,
      recursive=conversion_map.recursive,
      type_annotation_func=type_hints.set_element_type)
  fn_node, deps = node_to_graph(
      fn_node, entity_ctx, conversion_map.nocompile_decorators)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
  new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
  if not did_rename:
    # Without a rename, the converted node must still carry the original name.
    if fn_node.name != f.__name__:
      raise NotImplementedError('Strange corner case. Send us offending code!')
    new_name = f.__name__

  fn_node.name = new_name
  conversion_map.update_name_map(namer)
  # TODO(mdan): Use this at compilation.
  conversion_map.additional_imports.update(deps)

  return fn_node, new_name
示例#12
0
 def parse_and_analyze(self,
                       test_fn,
                       namespace,
                       namer=None,
                       arg_types=None,
                       include_type_analysis=True,
                       owner_type=None,
                       recursive=True):
   """Parses `test_fn` and runs the analysis passes used by converter tests.

   Qualified names, activity and live values are always resolved. When
   `include_type_analysis` is true, type info is resolved too, followed by a
   second live-values pass. The `context.EntityContext` built here is stored
   on `self.ctx`.

   Returns:
     The annotated AST node.
   """
   parsed, source = parser.parse_entity(test_fn)
   entity_ctx = context.EntityContext(
       namer=namer or FakeNamer(),
       source_code=source,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       owner_type=owner_type,
       recursive=recursive,
       type_annotation_func=utils.set_element_type)

   parsed = qual_names.resolve(parsed)
   parsed = activity.resolve(parsed, entity_ctx)
   parsed = live_values.resolve(parsed, entity_ctx, {})
   if include_type_analysis:
     parsed = type_info.resolve(parsed, entity_ctx)
     # Live values are resolved again after type info has been attached.
     parsed = live_values.resolve(parsed, entity_ctx, {})

   self.ctx = entity_ctx
   return parsed
    def parse_and_analyze(self,
                          test_fn,
                          namespace,
                          namer=None,
                          arg_types=None,
                          include_type_analysis=True,
                          owner_type=None,
                          recursive=True,
                          autograph_decorators=()):
        """Parses `test_fn` and runs the analysis passes used by converter tests.

        The analyses receive the `transformer.EntityInfo`. The
        `converter.EntityContext` wrapping it is only stored on `self.ctx`.
        When `include_type_analysis` is true, type info is resolved after the
        first live-values pass, and live values are then resolved again.

        Returns:
          The annotated AST node.
        """
        parsed, source = parser.parse_entity(test_fn)

        namer = FakeNamer() if namer is None else namer
        program_ctx = converter.ProgramContext(
            recursive=recursive,
            autograph_decorators=autograph_decorators,
            partial_types=None,
            autograph_module=None,
            uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
        entity_info = transformer.EntityInfo(
            source_code=source,
            source_file='<fragment>',
            namespace=namespace,
            arg_values=None,
            arg_types=arg_types,
            owner_type=owner_type)
        entity_ctx = converter.EntityContext(namer, entity_info, program_ctx)

        parsed = qual_names.resolve(parsed)
        parsed = activity.resolve(parsed, entity_info)
        parsed = live_values.resolve(parsed, entity_info, {})
        if include_type_analysis:
            parsed = type_info.resolve(parsed, entity_info)
            parsed = live_values.resolve(parsed, entity_info, {})

        self.ctx = entity_ctx
        return parsed
示例#14
0
 def parse_and_analyze(self,
                       test_fn,
                       namespace,
                       namer=None,
                       arg_types=None,
                       include_type_analysis=True,
                       owner_type=None,
                       recursive=True):
     """Parses `test_fn` and annotates it with the standard analyses.

     Qualified names, activity and live values are always resolved. With
     `include_type_analysis`, type info and a second live-values pass follow.
     The `context.EntityContext` used is saved as `self.ctx`.

     Returns:
       The annotated AST node.
     """
     tree, source = parser.parse_entity(test_fn)
     entity_ctx = context.EntityContext(
         namer=namer or FakeNamer(),
         source_code=source,
         source_file=None,
         namespace=namespace,
         arg_values=None,
         arg_types=arg_types,
         owner_type=owner_type,
         recursive=recursive,
         type_annotation_func=utils.set_element_type)

     tree = live_values.resolve(
         activity.resolve(qual_names.resolve(tree), entity_ctx), entity_ctx, {})
     if include_type_analysis:
         tree = live_values.resolve(
             type_info.resolve(tree, entity_ctx), entity_ctx, {})

     self.ctx = entity_ctx
     return tree
示例#15
0
 def prepare(self,
             test_fn,
             namespace,
             namer=None,
             arg_types=None,
             owner_type=None,
             recursive=True,
             autograph_decorators=()):
   """Parses `test_fn` and runs `converter.standard_analysis` on the result.

   Unlike the function-definition variant, the whole parsed result is
   analyzed (no `body[0]` selection).

   Returns:
     A `(node, ctx)` tuple: the analyzed node and the
     `converter.EntityContext` used for the analysis.
   """
   parsed, source = parser.parse_entity(test_fn)
   namer = FakeNamer() if namer is None else namer

   program_ctx = converter.ProgramContext(
       recursive=recursive,
       autograph_decorators=autograph_decorators,
       partial_types=None,
       autograph_module=None,
       uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
   entity_info = transformer.EntityInfo(
       source_code=source,
       source_file='<fragment>',
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       owner_type=owner_type)
   ctx = converter.EntityContext(namer, entity_info, program_ctx)

   return converter.standard_analysis(parsed, ctx, is_initial=True), ctx
示例#16
0
    def test_visit_block_postprocessing(self):
        """visit_block's `after_visit` hook can nest later statements of a block.

        `_process_body_item` wraps the `z = y` assignment in `if x:` and
        returns the If's body list as its second value. The assertions expect
        the following `return z` to land inside that If as well. So the
        returned list appears to be where subsequent statements are placed.
        """
        class TestTransformer(transformer.Base):
            def _process_body_item(self, node):
                # Wrap `<target> = y` in `if x: ...`; other nodes pass through.
                # The Name check avoids an AttributeError on non-Name values.
                if (isinstance(node, gast.Assign)
                        and isinstance(node.value, gast.Name)
                        and node.value.id == 'y'):
                    if_node = gast.If(gast.Name('x', gast.Load(), None),
                                      [node], [])
                    return if_node, if_node.body
                return node, None

            def visit_FunctionDef(self, node):
                node.body = self.visit_block(
                    node.body, after_visit=self._process_body_item)
                return node

        def test_function(x, y):
            z = x
            z = y
            return z

        tr = TestTransformer(self._simple_source_info())

        node, _ = parser.parse_entity(test_function)
        node = tr.visit(node)
        node = node.body[0]

        # Expected shape: [z = x, if x: [z = y, return z]]
        self.assertEqual(len(node.body), 2)
        self.assertIsInstance(node.body[0], gast.Assign)
        self.assertIsInstance(node.body[1], gast.If)
        self.assertIsInstance(node.body[1].body[0], gast.Assign)
        self.assertIsInstance(node.body[1].body[1], gast.Return)
示例#17
0
 def test_basic(self):
   """ANF conversion of a trivial function preserves its return value."""
   def test_function():
     a = 0
     return a
   parsed_module, _ = parser.parse_entity(test_function)
   anf_fn = anf.transform(parsed_module.body[0], self._simple_source_info())
   compiled_module, _ = compiler.ast_to_object(anf_fn)
   self.assertEqual(test_function(), compiled_module.test_function())
示例#18
0
    def test_entity_scope_tracking(self):
        """`enclosing_entities` lists the entities that enclose a visited node.

        Assign and BinOp nodes are annotated with the transformer's
        `enclosing_entities` at visit time. The assertions then check the
        expected chains of enclosing function, class, method, inner function
        and lambda nodes.
        """
        class TestTransformer(transformer.Base):

            # The choice of node to assign to is arbitrary. Using Assign because it's
            # easy to find in the tree.
            def visit_Assign(self, node):
                anno.setanno(node, 'enclosing_entities',
                             self.enclosing_entities)
                return self.generic_visit(node)

            # This will show up in the lambda function.
            def visit_BinOp(self, node):
                anno.setanno(node, 'enclosing_entities',
                             self.enclosing_entities)
                return self.generic_visit(node)

        tr = TestTransformer(self._simple_source_info())

        # NOTE: the statement layout of test_function is test data; the body
        # indices used below depend on it.
        def test_function():
            a = 0

            class TestClass(object):
                def test_method(self):
                    b = 0

                    def inner_function(x):
                        c = 0
                        d = lambda y: (x + y)
                        return c, d

                    return b, inner_function

            return a, TestClass

        node, _ = parser.parse_entity(test_function)
        node = tr.visit(node)

        # Navigate to each entity: body[1] of test_function is TestClass,
        # body[1] of test_method is inner_function, and body[1] of
        # inner_function is `d = lambda ...`.
        test_function_node = node.body[0]
        test_class = test_function_node.body[1]
        test_method = test_class.body[0]
        inner_function = test_method.body[1]
        lambda_node = inner_function.body[1].value

        # The annotated nodes: the `a = 0`, `b = 0` and `c = 0` assignments,
        # and the `x + y` BinOp inside the lambda.
        a = test_function_node.body[0]
        b = test_method.body[0]
        c = inner_function.body[0]
        lambda_expr = lambda_node.body

        self.assertEqual((test_function_node, ),
                         anno.getanno(a, 'enclosing_entities'))
        self.assertEqual((test_function_node, test_class, test_method),
                         anno.getanno(b, 'enclosing_entities'))
        self.assertEqual(
            (test_function_node, test_class, test_method, inner_function),
            anno.getanno(c, 'enclosing_entities'))
        self.assertEqual((test_function_node, test_class, test_method,
                          inner_function, lambda_node),
                         anno.getanno(lambda_expr, 'enclosing_entities'))
示例#19
0
    def test_basic(self):
        """ANF conversion of a trivial function preserves its return value."""
        def test_function():
            a = 0
            return a

        parsed_module, _ = parser.parse_entity(test_function)
        converted = anf.transform(parsed_module.body[0],
                                  self._simple_source_info())
        compiled_module, _ = compiler.ast_to_object(converted)
        self.assertEqual(test_function(), compiled_module.test_function())
示例#20
0
 def assert_body_anfs_as_expected(self, expected_fn, test_fn):
   """Asserts that the ANF transform of `test_fn` equals `expected_fn`.

   Only the code bodies are of interest. The snippets are wrapped in functions
   so they are syntax highlighted without being executed. Because both
   functions live in the same scope they need different names, so the
   converted function is renamed to the expected one before comparing. A
   second transform must leave the AST unchanged.
   """
   expected_node, _ = parser.parse_entity(expected_fn)
   actual_node, _ = parser.parse_entity(test_fn)
   actual_node = anf.transform(
       actual_node, self._simple_source_info(), gensym_source=DummyGensym)
   actual_node.body[0].name = expected_node.body[0].name
   self.assert_same_ast(expected_node, actual_node)

   # ANF must be idempotent.
   repeated = anf.transform(
       actual_node, self._simple_source_info(), gensym_source=DummyGensym)
   self.assert_same_ast(repeated, actual_node)
示例#21
0
  def test_entity_scope_tracking(self):
    """`enclosing_entities` lists the entities that enclose a visited node.

    Assign and BinOp nodes are annotated with the transformer's
    `enclosing_entities` at visit time. The assertions then check the expected
    chains of enclosing function, class, method, inner function and lambda
    nodes.
    """

    class TestTransformer(transformer.Base):

      # The choice of node to assign to is arbitrary. Using Assign because it's
      # easy to find in the tree.
      def visit_Assign(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

      # This will show up in the lambda function.
      def visit_BinOp(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

    tr = TestTransformer(self._context_for_testing())

    # NOTE: the statement layout of test_function is test data; the body
    # indices used below depend on it.
    def test_function():
      a = 0

      class TestClass(object):

        def test_method(self):
          b = 0
          def inner_function(x):
            c = 0
            d = lambda y: (x + y)
            return c, d
          return b, inner_function
      return a, TestClass

    node, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    # Navigate to each entity: body[1] of test_function is TestClass, body[1]
    # of test_method is inner_function, and body[1] of inner_function is
    # `d = lambda ...`.
    test_function_node = node.body[0]
    test_class = test_function_node.body[1]
    test_method = test_class.body[0]
    inner_function = test_method.body[1]
    lambda_node = inner_function.body[1].value

    # The annotated nodes: the `a = 0`, `b = 0` and `c = 0` assignments, and
    # the `x + y` BinOp inside the lambda.
    a = test_function_node.body[0]
    b = test_method.body[0]
    c = inner_function.body[0]
    lambda_expr = lambda_node.body

    self.assertEqual(
        (test_function_node,), anno.getanno(a, 'enclosing_entities'))
    self.assertEqual((test_function_node, test_class, test_method),
                     anno.getanno(b, 'enclosing_entities'))
    self.assertEqual(
        (test_function_node, test_class, test_method, inner_function),
        anno.getanno(c, 'enclosing_entities'))
    self.assertEqual((test_function_node, test_class, test_method,
                      inner_function, lambda_node),
                     anno.getanno(lambda_expr, 'enclosing_entities'))
示例#22
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and resolves qualified names on the resulting AST.

   Returns:
     A `(node, entity_info)` tuple: the resolved AST and a
     `transformer.EntityInfo` with an empty namespace and no source file.
   """
   tree, source = parser.parse_entity(test_fn)
   entity_info = transformer.EntityInfo(
       source_code=source, source_file=None, namespace={},
       arg_values=None, arg_types=None, owner_type=None)
   return qual_names.resolve(tree), entity_info
 def _parse_and_analyze(self, test_fn):
     """Parses `test_fn`, then resolves qualified names and activity.

     Returns:
       A `(node, entity_info)` tuple: the annotated AST and the
       `transformer.EntityInfo` passed to the activity analysis.
     """
     tree, source = parser.parse_entity(test_fn)
     entity_info = transformer.EntityInfo(
         source_code=source,
         source_file=None,
         namespace={},
         arg_values=None,
         arg_types=None,
         owner_type=None)
     tree = activity.resolve(qual_names.resolve(tree), entity_info)
     return tree, entity_info
示例#24
0
 def assert_body_anfs_as_expected(self, expected_fn, test_fn):
     """Checks that the ANF transform of `test_fn` matches `expected_fn`.

     The snippets are wrapped in functions only so that they are syntax
     highlighted without being executed. The converted function is renamed to
     the expected one's name before comparing, since both must coexist in the
     same scope. Re-applying the transform must leave the AST unchanged.
     """
     def to_anf(tree):
         return anf.transform(tree,
                              self._simple_source_info(),
                              gensym_source=DummyGensym)

     expected_ast, _ = parser.parse_entity(expected_fn)
     test_ast, _ = parser.parse_entity(test_fn)
     test_ast = to_anf(test_ast)
     # Only bodies are compared, so align the function names first.
     test_ast.body[0].name = expected_ast.body[0].name
     self.assert_same_ast(expected_ast, test_ast)
     # The transform must be idempotent.
     self.assert_same_ast(to_anf(test_ast), test_ast)
示例#25
0
 def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
     """Parses `test_fn` and resolves qualified names.

     Args:
       test_fn: function whose source is parsed.
       namespace: namespace recorded in the entity context.
       arg_types: argument types for the context; an empty dict when falsy.

     Returns:
       A `(node, ctx)` tuple: the resolved AST and its `context.EntityContext`.
     """
     tree, source = parser.parse_entity(test_fn)
     ctx = context.EntityContext(namer=None,
                                 source_code=source,
                                 source_file=None,
                                 namespace=namespace,
                                 arg_values=None,
                                 arg_types=arg_types or {},
                                 owner_type=None,
                                 recursive=True)
     return qual_names.resolve(tree), ctx
示例#26
0
    def test_parser_compile_idempotent(self):
        """Parsing a function and compiling it back reproduces its source.

        The function definition produced by `parser.parse_entity` is compiled
        with `compiler.ast_to_object`. The source of the resulting `test_fn`
        must equal the dedented original source.
        """
        # NOTE: the exact text of test_fn is test data; do not reformat it.
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = x + 1
            return b

        # The first element of ast_to_object's result is the compiled module,
        # which exposes the recompiled `test_fn`.
        self.assertEqual(
            textwrap.dedent(tf_inspect.getsource(test_fn)),
            tf_inspect.getsource(
                compiler.ast_to_object(
                    parser.parse_entity(test_fn)[0].body[0])[0].test_fn))
示例#27
0
    def _should_compile(self, node, fqn):
        """Determines whether an entity should be compiled in the context.

        Args:
          node: the call node under consideration; `node.func` is resolved to
            find the call target.
          fqn: the target's fully qualified name as an indexable sequence of
            name components; `fqn[0]` is used as the module name.

        Returns:
          False in any of these cases:
            - the module name or a prefix of `fqn` matches
              `self.uncompiled_modules`;
            - `node` has a `graph_ready` annotation;
            - the resolved target is marked as a compile decorator, is one of
              `self.nocompile_decorators`, or is decorated with one of them.
          True otherwise, including when the target's source cannot be read.
        """
        # TODO (mdan): Needs cleanup. We should remove the use of fqn altogether. id:489
        # https://github.com/imdone/tensorflow/issues/490
        module_name = fqn[0]
        # NOTE(review): entries of uncompiled_modules look like tuples of name
        # components (matched by first component and as fqn prefixes) - confirm.
        for mod in self.uncompiled_modules:
            if module_name.startswith(mod[0] + '.'):
                return False

        for i in range(1, len(fqn)):
            if fqn[:i] in self.uncompiled_modules:
                return False

        # Check for local decorations
        if anno.hasanno(node, 'graph_ready'):
            return False

        # The decorators themselves are not to be converted.
        # If present, the decorators should appear as static functions.
        target_entity = self._try_resolve_target(node.func)
        if target_entity is not None:
            # This attribute is set by the decorator itself.
            # TODO (mdan): This may not play nicely with other wrapping decorators. id:480
            # https://github.com/imdone/tensorflow/issues/481
            if hasattr(target_entity, '__pyct_is_compile_decorator'):
                return False

            if target_entity in self.nocompile_decorators:
                return False

            # Inspect the target function decorators. If any include a @convert
            # or @graph_ready annotation, then they must be called as they are.
            # TODO (mdan): This may be quite heavy. id:952
            # https://github.com/imdone/tensorflow/issues/953
            # To parse and re-analyze each function for every call site could be quite
            # wasteful. Maybe we could cache the parsed AST?
            try:
                target_node, _ = parser.parse_entity(target_entity)
                target_node = target_node.body[0]
            except (TypeError, IOError):
                # Functions whose source we cannot access are compilable (e.g.
                # wrapped to py_func). inspect-based source lookup raises
                # TypeError for built-ins and IOError (OSError on Python 3)
                # when no source is available, e.g. for exec'd code.
                # NOTE(review): assumes parser.parse_entity lets
                # inspect.getsource errors propagate unchanged - confirm.
                return True

            for dec in target_node.decorator_list:
                decorator_fn = self._resolve_name(dec)
                if (decorator_fn is not None
                        and decorator_fn in self.nocompile_decorators):
                    return False

        return True
示例#28
0
  def test_parser_compile_idempotent(self):
    """Parsing a function and compiling it back reproduces its source.

    The function definition produced by `parser.parse_entity` is compiled with
    `compiler.ast_to_object`. The source of the resulting `test_fn` must equal
    the dedented original source.
    """

    # NOTE: the exact text of test_fn is test data; do not reformat it.
    def test_fn(x):
      a = True
      b = ''
      if a:
        b = x + 1
      return b

    # The first element of ast_to_object's result is the compiled module,
    # which exposes the recompiled `test_fn`.
    self.assertEqual(
        textwrap.dedent(tf_inspect.getsource(test_fn)),
        tf_inspect.getsource(
            compiler.ast_to_object(
                parser.parse_entity(test_fn)[0].body[0])[0].test_fn))
示例#29
0
    def test_invalid_default(self):
        """`_map_args` raises ValueError for a directive with an object() default."""
        def invalid_directive(valid_arg, invalid_default=object()):
            del valid_arg
            del invalid_default
            return

        def call_invalid_directive():
            invalid_directive(1)

        parsed_module, _ = parser.parse_entity(call_invalid_directive)
        # body[0] is the FunctionDef; its body[0] is the expression statement
        # whose value is the call to the invalid directive.
        call_node = parsed_module.body[0].body[0].value
        with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):
            directives_converter._map_args(call_node, invalid_directive)
示例#30
0
 def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
   """Parses `test_fn` and resolves qualified names.

   Args:
     test_fn: function whose source is parsed.
     namespace: namespace recorded in the entity context.
     arg_types: argument types for the context; an empty dict when falsy.

   Returns:
     A `(node, ctx)` tuple: the resolved AST and its `context.EntityContext`.
   """
   tree, source = parser.parse_entity(test_fn)
   ctx = context.EntityContext(
       namer=None,
       source_code=source,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types or {},
       owner_type=None,
       recursive=True)
   return qual_names.resolve(tree), ctx
示例#31
0
  def test_invalid_default(self):
    """`_map_args` raises ValueError for a directive with an object() default."""

    def invalid_directive(valid_arg, invalid_default=object()):
      del valid_arg
      del invalid_default
      return

    def call_invalid_directive():
      invalid_directive(1)

    parsed_module, _ = parser.parse_entity(call_invalid_directive)
    # body[0] is the FunctionDef; its body[0] is the expression statement whose
    # value is the call to the invalid directive.
    call_node = parsed_module.body[0].body[0].value
    with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):
      directives_converter._map_args(call_node, invalid_directive)
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and annotates it with reaching definitions.

   Runs these steps in order:
     - qualified-name resolution;
     - activity analysis;
     - CFG construction;
     - reaching-definitions analysis.

   Returns:
     The annotated AST node.
   """
   tree, source = parser.parse_entity(test_fn)
   entity_info = transformer.EntityInfo(
       source_code=source,
       source_file=None,
       namespace={},
       arg_values=None,
       arg_types=None,
       owner_type=None)
   tree = activity.resolve(qual_names.resolve(tree), entity_info)
   return reaching_definitions.resolve(
       tree, entity_info, cfg.build(tree), reaching_definitions.Definition)
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and runs the analyses up to reaching definitions.

   Qualified names and activity are resolved first. A CFG is then built for
   the AST, and reaching definitions are computed with
   `reaching_definitions.Definition`.

   Returns:
     The annotated AST node.
   """
   analyzed, source = parser.parse_entity(test_fn)
   entity_info = transformer.EntityInfo(
       source_code=source, source_file=None, namespace={},
       arg_values=None, arg_types=None, owner_type=None)
   analyzed = qual_names.resolve(analyzed)
   analyzed = activity.resolve(analyzed, entity_info)
   cfg_graphs = cfg.build(analyzed)
   return reaching_definitions.resolve(analyzed, entity_info, cfg_graphs,
                                       reaching_definitions.Definition)
示例#34
0
  def _should_compile(self, node, fqn):
    """Determines whether an entity should be compiled in the context.

    Args:
      node: the call node under consideration; `node.func` is resolved to find
        the call target.
      fqn: the target's fully qualified name as an indexable sequence of name
        components; `fqn[0]` is used as the module name.

    Returns:
      False in any of these cases:
        - the module name or a prefix of `fqn` matches
          `self.ctx.program.uncompiled_modules`;
        - `node` has a `graph_ready` annotation;
        - the resolved target is marked as a compile decorator, is one of
          `self.ctx.program.autograph_decorators`, or is decorated with one
          of them.
      True otherwise, including when the target's source cannot be read.
    """
    # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
    module_name = fqn[0]
    # NOTE(review): entries of uncompiled_modules look like tuples of name
    # components (matched by first component and as fqn prefixes) - confirm.
    for mod in self.ctx.program.uncompiled_modules:
      if module_name.startswith(mod[0] + '.'):
        return False

    for i in range(1, len(fqn)):
      if fqn[:i] in self.ctx.program.uncompiled_modules:
        return False

    # Check for local decorations
    if anno.hasanno(node, 'graph_ready'):
      return False

    # The decorators themselves are not to be converted.
    # If present, the decorators should appear as static functions.
    target_entity = self._try_resolve_target(node.func)
    if target_entity is not None:
      # This attribute is set by the decorator itself.
      # TODO(mdan): This may not play nicely with other wrapping decorators.
      if hasattr(target_entity, '__pyct_is_compile_decorator'):
        return False

      if target_entity in self.ctx.program.autograph_decorators:
        return False

      # Inspect the target function decorators. If any include a @convert
      # or @graph_ready annotation, then they must be called as they are.
      # TODO(mdan): This may be quite heavy.
      # To parse and re-analyze each function for every call site could be quite
      # wasteful. Maybe we could cache the parsed AST?
      try:
        target_node, _ = parser.parse_entity(target_entity)
        target_node = target_node.body[0]
      except (TypeError, IOError):
        # Functions whose source we cannot access are compilable (e.g. wrapped
        # to py_func). inspect-based source lookup raises TypeError for
        # built-ins and IOError (OSError on Python 3) when no source is
        # available, e.g. for exec'd code.
        # NOTE(review): assumes parser.parse_entity lets inspect.getsource
        # errors propagate unchanged - confirm.
        return True

      for dec in target_node.decorator_list:
        decorator_fn = self._resolve_name(dec)
        if (decorator_fn is not None and
            decorator_fn in self.ctx.program.autograph_decorators):
          return False

    return True
示例#35
0
  def test_local_scope_info_stack(self):
    """Loops open local scopes that collect their own string constants.

    The transformer appends the value of every visited Str node to the
    'string' local value. Each For/While opens a local scope around its
    children and stores the collected string in its 'test' annotation.
    """

    class TestTransformer(transformer.Base):

      def _annotate_result(self, node):
        # Collect strings from this loop's children in a scope of their own,
        # then record the result on the loop node.
        self.enter_local_scope()
        result = self.generic_visit(node)
        anno.setanno(result, 'test', self.get_local('string'))
        self.exit_local_scope()
        return result

      # Extract all string constants from the block.
      def visit_Str(self, node):
        collected_so_far = self.get_local('string', default='')
        self.set_local('string', collected_so_far + node.s)
        return self.generic_visit(node)

      def visit_For(self, node):
        return self._annotate_result(node)

      def visit_While(self, node):
        return self._annotate_result(node)

    scope_tracker = TestTransformer(self._context_for_testing())

    def test_function(a):
      """Docstring."""
      assert a == 'This should not be counted'
      for i in range(3):
        _ = 'a'
        if i > 2:
          return 'b'
        else:
          _ = 'c'
          while True:
            raise '1'
      return 'nor this'

    module_node, _ = parser.parse_entity(test_function)
    module_node = scope_tracker.visit(module_node)

    # Index 2 skips the docstring and the assert statement.
    for_node = module_node.body[0].body[2]
    # The while loop is the second statement of the if's else branch.
    while_node = for_node.body[1].orelse[1]

    self.assertFalse(anno.hasanno(for_node, 'string'))
    self.assertEqual('abc', anno.getanno(for_node, 'test'))
    self.assertFalse(anno.hasanno(while_node, 'string'))
    self.assertEqual('1', anno.getanno(while_node, 'test'))
示例#36
0
    def test_local_scope_info_stack(self):
        """Loops open local scopes that collect their own string constants.

        The transformer appends the value of every visited Str node to the
        'string' local value. Each For/While opens a local scope around its
        children and stores the collected string in its 'test' annotation.
        """
        class TestTransformer(transformer.Base):

            def _annotate_result(self, node):
                # Gather strings from the loop's children in a fresh scope,
                # then attach the collected value to the loop node.
                self.enter_local_scope()
                result = self.generic_visit(node)
                anno.setanno(result, 'test', self.get_local('string'))
                self.exit_local_scope()
                return result

            # Extract all string constants from the block.
            def visit_Str(self, node):
                collected_so_far = self.get_local('string', default='')
                self.set_local('string', collected_so_far + node.s)
                return self.generic_visit(node)

            def visit_For(self, node):
                return self._annotate_result(node)

            def visit_While(self, node):
                return self._annotate_result(node)

        scope_tracker = TestTransformer(self._simple_source_info())

        def test_function(a):
            """Docstring."""
            assert a == 'This should not be counted'
            for i in range(3):
                _ = 'a'
                if i > 2:
                    return 'b'
                else:
                    _ = 'c'
                    while True:
                        raise '1'
            return 'nor this'

        module_node, _ = parser.parse_entity(test_function)
        module_node = scope_tracker.visit(module_node)

        # Index 2 skips the docstring and the assert statement.
        for_node = module_node.body[0].body[2]
        # The while loop is the second statement of the if's else branch.
        while_node = for_node.body[1].orelse[1]

        self.assertFalse(anno.hasanno(for_node, 'string'))
        self.assertEqual('abc', anno.getanno(for_node, 'test'))
        self.assertFalse(anno.hasanno(while_node, 'string'))
        self.assertEqual('1', anno.getanno(while_node, 'test'))
示例#37
0
    def test_robust_error_on_ast_corruption(self):
        """Error handling in `visit` must survive a corrupted AST.

        A child class should not be able to be so broken that it causes the
        error handling in `transformer.Base` to raise an exception. Otherwise
        the original error location is dropped, and an error handler higher up
        in the call stack gives misleading information.

        Here we test that the error handling in `visit` completes, and blames
        the correct original exception, even if the AST gets corrupted.
        """

        class NotANode(object):
            pass

        class BrokenTransformer(transformer.Base):
            def visit_If(self, node):
                node.body = NotANode()
                raise ValueError('I blew up')

        def test_function(x):
            if x > 0:
                return x

        tr = BrokenTransformer(self._simple_source_info())

        node, _ = parser.parse_entity(test_function)
        with self.assertRaises(transformer.AutographParseError) as cm:
            tr.visit(node)
        obtained_message = str(cm.exception)
        # assertIn/assertNotIn report both operands on failure, unlike
        # assertTrue(x in y).
        # The message should reference the exception actually raised, not
        # anything from the exception handler.
        self.assertIn('I blew up', obtained_message)
        # Expect the exception to have failed to parse the corrupted AST.
        self.assertIn('<could not convert AST to source>', obtained_message)
        # The exception should point at the if statement, not any place else.
        # Could also check the stack trace.
        self.assertIn('Occurred at node:\nIf', obtained_message)
        self.assertNotIn('Occurred at node:\nFunctionDef', obtained_message)
        self.assertNotIn('Occurred at node:\nReturn', obtained_message)
  def test_robust_error_on_ast_corruption(self):
    """Error handling in `visit` must survive a corrupted AST.

    A child class should not be able to be so broken that it causes the error
    handling in `transformer.Base` to raise an exception. Otherwise the
    original error location is dropped, and an error handler higher up in the
    call stack gives misleading information.

    Here we test that the error handling in `visit` completes, and blames the
    correct original exception, even if the AST gets corrupted.
    """

    class NotANode(object):
      pass

    class BrokenTransformer(transformer.Base):

      def visit_If(self, node):
        node.body = NotANode()
        raise ValueError('I blew up')

    def test_function(x):
      if x > 0:
        return x

    tr = BrokenTransformer(self._simple_source_info())

    node, _ = parser.parse_entity(test_function)
    with self.assertRaises(transformer.AutographParseError) as cm:
      tr.visit(node)
    obtained_message = str(cm.exception)
    # assertIn/assertNotIn report both operands on failure, unlike
    # assertTrue(x in y).
    # The message should reference the exception actually raised, not anything
    # from the exception handler.
    self.assertIn('I blew up', obtained_message)
    # Expect the exception to have failed to parse the corrupted AST.
    self.assertIn('<could not convert AST to source>', obtained_message)
    # The exception should point at the if statement, not any place else.
    # Could also check the stack trace.
    self.assertIn('Occurred at node:\nIf', obtained_message)
    self.assertNotIn('Occurred at node:\nFunctionDef', obtained_message)
    self.assertNotIn('Occurred at node:\nReturn', obtained_message)
示例#39
0
 def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
   """Parses `test_fn` and runs the live-value and type analyses on it.

   Passes run in this order: qual_names, activity, live_values, type_info,
   live_values again. Each live_values pass gets its own empty literals dict.

   Args:
     test_fn: the Python function to parse.
     namespace: the namespace stored in the EntityInfo.
     arg_types: optional argument types stored in the EntityInfo.

   Returns:
     The AST returned by the final live_values pass.
   """
   tree, src = parser.parse_entity(test_fn)
   info = transformer.EntityInfo(
       source_code=src,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       owner_type=None)
   tree = activity.resolve(qual_names.resolve(tree), info)
   tree = live_values.resolve(tree, info, {})
   tree = type_info.resolve(tree, info)
   return live_values.resolve(tree, info, {})
  def test_source_map(self):
    """Inserted traced lines map back to the origin of the copied node."""

    def test_fn(x):
      if x > 0:
        x += 1
      return x

    module_node, source = parser.parse_entity(test_fn)
    fn_node = module_node.body[0]
    origin_info.resolve(fn_node, source)

    # Insert a traced line: it copies the ORIGIN annotation of the `if`.
    traced_stmt = parser.parse_str('x = abs(x)').body[0]
    anno.copyanno(fn_node.body[0], traced_stmt, anno.Basic.ORIGIN)
    fn_node.body.insert(0, traced_stmt)

    # Insert an untraced line, ahead of the traced one.
    fn_node.body.insert(0, parser.parse_str('x = 0').body[0])

    modified_source = compiler.ast_to_source(fn_node)

    source_map = origin_info.source_map(fn_node, modified_source,
                                        'test_filename', [0])

    def origin_at(lineno):
      return source_map[origin_info.LineLocation('test_filename', lineno)]

    header_origin = origin_at(1)
    self.assertEqual(header_origin.source_code_line, 'def test_fn(x):')
    self.assertEqual(header_origin.loc.lineno, 1)

    # The untraced line, inserted second, has no entry.
    untraced_loc = origin_info.LineLocation('test_filename', 2)
    self.assertFalse(untraced_loc in source_map)

    # The traced line (inserted first) and the original `if` both map to the
    # `if` on line 2 of the original source.
    for lineno in (3, 4):
      if_origin = origin_at(lineno)
      self.assertEqual(if_origin.source_code_line, '  if x > 0:')
      self.assertEqual(if_origin.loc.lineno, 2)
示例#41
0
 def _parse_and_analyze(self, test_fn, namespace, literals=None,
                        arg_types=None):
   """Parses `test_fn` and runs the live-value and type analyses on it.

   Passes run in this order: qual_names, activity, live_values, type_info,
   live_values again; both live_values passes receive the same `literals`.

   Args:
     test_fn: the Python function to parse.
     namespace: the namespace stored in the EntityInfo.
     literals: literals passed to live_values; any falsy value (including
       None) is replaced by a new empty dict.
     arg_types: optional argument types stored in the EntityInfo.

   Returns:
     The AST returned by the final live_values pass.
   """
   if not literals:
     literals = {}
   tree, src = parser.parse_entity(test_fn)
   info = transformer.EntityInfo(
       source_code=src,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       owner_type=None)
   tree = activity.resolve(qual_names.resolve(tree), info)
   tree = live_values.resolve(tree, info, literals)
   tree = type_info.resolve(tree, info)
   return live_values.resolve(tree, info, literals)
  def test_robust_error_on_list_visit(self):
    """Visiting a list instead of a node is blamed on the calling If node."""

    class BrokenTransformer(transformer.Base):

      def visit_If(self, node):
        # This is broken because visit expects a single node, not a list, and
        # the body of an if is a list.
        # Importantly, the default error handling in visit also expects a single
        # node.  Therefore, mistakes like this need to trigger a type error
        # before the visit called here installs its error handler.
        # That type error can then be caught by the enclosing call to visit,
        # and correctly blame the If node.
        self.visit(node.body)
        return node

    def test_function(x):
      if x > 0:
        return x

    tr = BrokenTransformer(self._simple_source_info())

    node, _ = parser.parse_entity(test_function)
    with self.assertRaises(transformer.AutographParseError) as cm:
      tr.visit(node)
    obtained_message = str(cm.exception)
    expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
    # assertRegexpMatches is a deprecated alias that was removed in Python
    # 3.12. Use assertRegex, and fall back to the old name only where
    # assertRegex does not exist (Python 2's unittest).
    assert_regex = (
        getattr(self, 'assertRegex', None) or self.assertRegexpMatches)
    assert_regex(obtained_message, expected_message)
    # The exception should point at the if statement, not any place else.  Could
    # also check the stack trace.
    self.assertIn('Occurred at node:\nIf', obtained_message)
    self.assertNotIn('Occurred at node:\nFunctionDef', obtained_message)
    self.assertNotIn('Occurred at node:\nReturn', obtained_message)
  def parse_and_analyze(self,
                        test_fn,
                        namespace,
                        namer=None,
                        arg_types=None,
                        include_type_analysis=True,
                        owner_type=None,
                        recursive=True,
                        autograph_decorators=()):
    """Parses `test_fn`, runs the static analyses, and records the context.

    An EntityContext for the function is stored on `self.ctx`, but only after
    every analysis pass has completed.

    Args:
      test_fn: the Python function to parse.
      namespace: namespace stored in the EntityInfo.
      namer: namer for the EntityContext; a FakeNamer is created when None.
      arg_types: argument types stored in the EntityInfo.
      include_type_analysis: when True, run type_info followed by a second
        live_values pass.
      owner_type: owner type stored in the EntityInfo.
      recursive: forwarded to converter.ProgramContext.
      autograph_decorators: forwarded to converter.ProgramContext.

    Returns:
      The AST returned by the last analysis pass that ran.
    """
    tree, source = parser.parse_entity(test_fn)

    namer = FakeNamer() if namer is None else namer
    program_ctx = converter.ProgramContext(
        recursive=recursive,
        autograph_decorators=autograph_decorators,
        partial_types=None,
        autograph_module=None,
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        namespace=namespace,
        arg_values=None,
        arg_types=arg_types,
        owner_type=owner_type)
    entity_ctx = converter.EntityContext(namer, entity_info, program_ctx)

    tree = qual_names.resolve(tree)
    tree = activity.resolve(tree, entity_info)
    tree = live_values.resolve(tree, entity_info, {})
    if include_type_analysis:
      # live_values is run again after type_info.
      tree = type_info.resolve(tree, entity_info)
      tree = live_values.resolve(tree, entity_info, {})

    self.ctx = entity_ctx
    return tree
示例#44
0
 def _build_cfg(self, fn):
     """Parses `fn` and returns the result of `cfg.build` on its AST."""
     fn_ast, _unused_source = parser.parse_entity(fn)
     return cfg.build(fn_ast)
示例#45
0
    def test_state_tracking(self):
        """Per-class state entered around While/If is visible to children.

        Every visited node is annotated with the current values of the
        LoopState and CondState entries of `self.state`. The assertions
        compare those annotations between statements at different nesting
        levels.
        """
        class LoopState(object):
            pass

        class CondState(object):
            pass

        class TestTransformer(transformer.Base):

            def _visit_in_state(self, state_cls, node):
                # Keep `state_cls` entered while the node's children are
                # visited.
                self.state[state_cls].enter()
                node = self.generic_visit(node)
                self.state[state_cls].exit()
                return node

            def visit(self, node):
                for anno_key, state_cls in (('loop_state', LoopState),
                                            ('cond_state', CondState)):
                    anno.setanno(node, anno_key, self.state[state_cls].value)
                return super(TestTransformer, self).visit(node)

            def visit_While(self, node):
                return self._visit_in_state(LoopState, node)

            def visit_If(self, node):
                return self._visit_in_state(CondState, node)

        state_tracker = TestTransformer(self._simple_source_info())

        def test_function(a):
            a = 1
            while a:
                _ = 'a'
                if a > 2:
                    _ = 'b'
                    while True:
                        raise '1'
                if a > 3:
                    _ = 'c'
                    while True:
                        raise '1'

        module_node, _ = parser.parse_entity(test_function)
        module_node = state_tracker.visit(module_node)

        top_stmts = module_node.body[0].body
        loop_stmts = top_stmts[1].body
        self.assertSameAnno(top_stmts[0], loop_stmts[0], 'cond_state')
        self.assertDifferentAnno(top_stmts[0], loop_stmts[0], 'loop_state')

        if1_stmts = loop_stmts[1].body
        self.assertDifferentAnno(loop_stmts[0], if1_stmts[0], 'cond_state')
        self.assertSameAnno(loop_stmts[0], if1_stmts[0], 'loop_state')

        inner_loop1_stmts = if1_stmts[1].body
        self.assertSameAnno(if1_stmts[0], inner_loop1_stmts[0], 'cond_state')
        self.assertDifferentAnno(if1_stmts[0], inner_loop1_stmts[0],
                                 'loop_state')

        if2_stmts = loop_stmts[2].body
        self.assertDifferentAnno(if1_stmts[0], if2_stmts[0], 'cond_state')
        self.assertSameAnno(if1_stmts[0], if2_stmts[0], 'loop_state')

        inner_loop2_stmts = if2_stmts[1].body
        self.assertDifferentAnno(inner_loop1_stmts[0], inner_loop2_stmts[0],
                                 'cond_state')
        self.assertDifferentAnno(inner_loop1_stmts[0], inner_loop2_stmts[0],
                                 'loop_state')
示例#46
0
  def test_state_tracking(self):
    """Per-class state entered around While/If is visible to children.

    Every visited node is annotated with the current values of the LoopState
    and CondState entries of `self.state`. The assertions compare those
    annotations between statements at different nesting levels.
    """

    class LoopState(object):
      pass

    class CondState(object):
      pass

    class TestTransformer(transformer.Base):

      def _visit_in_state(self, state_cls, node):
        # Keep `state_cls` entered while the node's children are visited.
        self.state[state_cls].enter()
        node = self.generic_visit(node)
        self.state[state_cls].exit()
        return node

      def visit(self, node):
        for anno_key, state_cls in (('loop_state', LoopState),
                                    ('cond_state', CondState)):
          anno.setanno(node, anno_key, self.state[state_cls].value)
        return super(TestTransformer, self).visit(node)

      def visit_While(self, node):
        return self._visit_in_state(LoopState, node)

      def visit_If(self, node):
        return self._visit_in_state(CondState, node)

    state_tracker = TestTransformer(self._simple_source_info())

    def test_function(a):
      a = 1
      while a:
        _ = 'a'
        if a > 2:
          _ = 'b'
          while True:
            raise '1'
        if a > 3:
          _ = 'c'
          while True:
            raise '1'

    module_node, _ = parser.parse_entity(test_function)
    module_node = state_tracker.visit(module_node)

    top_stmts = module_node.body[0].body
    loop_stmts = top_stmts[1].body
    self.assertSameAnno(top_stmts[0], loop_stmts[0], 'cond_state')
    self.assertDifferentAnno(top_stmts[0], loop_stmts[0], 'loop_state')

    if1_stmts = loop_stmts[1].body
    self.assertDifferentAnno(loop_stmts[0], if1_stmts[0], 'cond_state')
    self.assertSameAnno(loop_stmts[0], if1_stmts[0], 'loop_state')

    inner_loop1_stmts = if1_stmts[1].body
    self.assertSameAnno(if1_stmts[0], inner_loop1_stmts[0], 'cond_state')
    self.assertDifferentAnno(if1_stmts[0], inner_loop1_stmts[0], 'loop_state')

    if2_stmts = loop_stmts[2].body
    self.assertDifferentAnno(if1_stmts[0], if2_stmts[0], 'cond_state')
    self.assertSameAnno(if1_stmts[0], if2_stmts[0], 'loop_state')

    inner_loop2_stmts = if2_stmts[1].body
    self.assertDifferentAnno(inner_loop1_stmts[0], inner_loop2_stmts[0],
                             'cond_state')
    self.assertDifferentAnno(inner_loop1_stmts[0], inner_loop2_stmts[0],
                             'loop_state')
示例#47
0
 def _build_cfg(self, fn):
   """Returns `cfg.build` applied to the parsed AST of `fn`."""
   tree, _unused_source = parser.parse_entity(fn)
   return cfg.build(tree)
示例#48
0
    def test_parse_entity(self):
        def f(x):
            return x + 1

        mod, _ = parser.parse_entity(f)
        self.assertEqual('f', mod.body[0].name)