예제 #1
0
    def test_parse_comments(self):
        def f():
            # unindented comment
            pass

        with self.assertRaises(ValueError):
            parsing.parse_entity(f)
예제 #2
0
    def test_parse_multiline_strings(self):
        def f():
            print("""
some
multiline
string""")

        with self.assertRaises(ValueError):
            parsing.parse_entity(f)
예제 #3
0
  def test_resolve(self):

    def test_fn(x):
      """Docstring."""
      return x  # comment

    node, source = parsing.parse_entity(test_fn)
    fn_node = node.body[0]

    origin_info.resolve(fn_node, source)

    # One row per annotated node:
    # (node, lineno, col_offset, source line, trailing comment or None).
    expectations = (
        (fn_node, 1, 0, 'def test_fn(x):', None),
        (fn_node.body[0], 2, 2, '  """Docstring."""', None),
        (fn_node.body[1], 3, 2, '  return x  # comment', 'comment'),
    )
    for target, lineno, col_offset, line_text, comment in expectations:
      origin = anno.getanno(target, anno.Basic.ORIGIN)
      self.assertEqual(origin.loc.lineno, lineno)
      self.assertEqual(origin.loc.col_offset, col_offset)
      self.assertEqual(origin.source_code_line, line_text)
      if comment is None:
        self.assertIsNone(origin.comment)
      else:
        self.assertEqual(origin.comment, comment)
예제 #4
0
    def test_robust_error_on_ast_corruption(self):
        """The original transformer error survives a corrupted AST.

        A subclass must not be able to break `transformer.Base`'s error
        handling so badly that the handler itself raises. Otherwise the
        original error location is lost, and a handler further up the call
        stack reports misleading information. This checks that `visit` still
        surfaces the exception the subclass actually raised, even after the
        subclass has corrupted the tree.
        """

        class NotANode(object):
            pass

        class BrokenTransformer(transformer.Base):

            def visit_If(self, node):
                # Corrupt the tree before failing: `body` is replaced with an
                # object that is not a list of statements.
                node.body = NotANode()
                raise ValueError('I blew up')

        def test_function(x):
            if x > 0:
                return x

        broken = BrokenTransformer(self._simple_source_info())

        module_node, _ = parsing.parse_entity(test_function)
        with self.assertRaises(ValueError) as raised:
            broken.visit(module_node)

        message = str(raised.exception)
        # The message must name the exception actually raised, not anything
        # coming from the exception handler.
        self.assertIn('I blew up', message, message)
예제 #5
0
def convert(func, overload_module, transformers):
  """Main entry point for converting a function using Pyct.

  Args:
    func: function to be converted
    overload_module: module containing overloaded functionality
    transformers: list of transformers to be applied

  Returns:
    gen_func: converted function
  """
  # parse_entity returns (ast_node, source_text). Keep the two apart:
  # EntityInfo.source_code is meant to hold the source text, while the
  # transformation pipeline below operates on the AST.
  node, source = parsing.parse_entity(func)
  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace={},
      arg_values=None,
      arg_types={},
      owner_type=None)

  namer = naming.Namer(entity_info.namespace)
  ctx = transformer.EntityContext(namer, entity_info)
  # Reserve a fresh, non-colliding name for the overload module reference.
  overload_name = ctx.namer.new_symbol('overload', set())
  overload = config.VirtualizationConfig(overload_module, overload_name)

  node = _transform(node, ctx, overload, transformers)
  gen_func = _wrap_in_generator(func, node, namer, overload)
  gen_func = _attach_closure(func, gen_func)
  return gen_func
예제 #6
0
    def test_robust_error_on_list_visit(self):
        """Visiting a list instead of a node blames the enclosing If node."""
        class BrokenTransformer(transformer.Base):
            def visit_If(self, node):
                # This is broken because visit expects a single node, not a list, and
                # the body of an if is a list.
                # Importantly, the default error handling in visit also expects a single
                # node.  Therefore, mistakes like this need to trigger a type error
                # before the visit called here installs its error handler.
                # That type error can then be caught by the enclosing call to visit,
                # and correctly blame the If node.
                self.visit(node.body)
                return node

        def test_function(x):
            if x > 0:
                return x

        tr = BrokenTransformer(self._simple_source_info())

        node, _ = parsing.parse_entity(test_function)
        with self.assertRaises(ValueError) as cm:
            tr.visit(node)
        obtained_message = str(cm.exception)
        # The list's type repr differs between Python versions ("type" vs
        # "class"), hence the alternation.
        expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
        # assertRegex replaces the deprecated assertRegexpMatches alias, which
        # was removed in Python 3.12.
        self.assertRegex(obtained_message, expected_message)
예제 #7
0
    def test_visit_block_postprocessing(self):

        class TestTransformer(transformer.Base):

            def _nest_after_assign_from_y(self, stmt):
                # Returns (replacement, new_destination). An assignment whose
                # value is the name `y` is wrapped in `if x:`, and the if's
                # body is returned as the destination for the statements that
                # follow (the assertions below expect `return z` to land there).
                if not (isinstance(stmt, gast.Assign) and stmt.value.id == 'y'):
                    return stmt, None
                wrapper = gast.If(gast.Name('x', gast.Load(), None), [stmt], [])
                return wrapper, wrapper.body

            def visit_FunctionDef(self, node):
                node.body = self.visit_block(
                    node.body, after_visit=self._nest_after_assign_from_y)
                return node

        def test_function(x, y):
            z = x
            z = y
            return z

        tr = TestTransformer(self._simple_source_info())

        module_node, _ = parsing.parse_entity(test_function)
        fn_node = tr.visit(module_node).body[0]

        # Expected shape: `z = x`, then `if x:` holding `z = y` and `return z`.
        body = fn_node.body
        self.assertLen(body, 2)
        self.assertIsInstance(body[0], gast.Assign)
        self.assertIsInstance(body[1], gast.If)
        nested = body[1].body
        self.assertIsInstance(nested[0], gast.Assign)
        self.assertIsInstance(nested[1], gast.Return)
예제 #8
0
    def test_entity_scope_tracking(self):

        class TestTransformer(transformer.Base):

            def _record_enclosing_entities(self, node):
                # Snapshot the entities enclosing `node`, then keep descending.
                anno.setanno(node, 'enclosing_entities',
                             self.enclosing_entities)
                return self.generic_visit(node)

            # Tagging Assign nodes is an arbitrary choice; they are simply
            # easy to locate in the resulting tree.
            def visit_Assign(self, node):
                return self._record_enclosing_entities(node)

            # A BinOp forms the body of the lambda below.
            def visit_BinOp(self, node):
                return self._record_enclosing_entities(node)

        tr = TestTransformer(self._simple_source_info())

        def test_function():
            a = 0

            class TestClass(object):
                def test_method(self):
                    b = 0

                    def inner_function(x):
                        c = 0
                        d = lambda y: (x + y)
                        return c, d

                    return b, inner_function

            return a, TestClass

        module_node, _ = parsing.parse_entity(test_function)
        module_node = tr.visit(module_node)

        fn_node = module_node.body[0]
        class_node = fn_node.body[1]
        method_node = class_node.body[0]
        inner_node = method_node.body[1]
        lambda_node = inner_node.body[1].value

        # Each tagged node paired with the chain of entities enclosing it,
        # from outermost to innermost.
        expected_chains = (
            (fn_node.body[0], (fn_node, )),
            (method_node.body[0], (fn_node, class_node, method_node)),
            (inner_node.body[0],
             (fn_node, class_node, method_node, inner_node)),
            (lambda_node.body,
             (fn_node, class_node, method_node, inner_node, lambda_node)),
        )
        for tagged, chain in expected_chains:
            self.assertEqual(chain,
                             anno.getanno(tagged, 'enclosing_entities'))
예제 #9
0
  def test_source_map_no_origin(self):

    def test_fn(x):
      return x + 1

    module_node, _ = parsing.parse_entity(test_fn)
    fn_node = module_node.body[0]

    # No ORIGIN annotation is attached to any node, so the map stays empty.
    source_map = origin_info.create_source_map(
        fn_node, parsing.ast_to_source(fn_node), 'test_filename', [0])

    self.assertEmpty(source_map)
예제 #10
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn`, then runs qualified-name and activity analysis.

   Returns:
     A (node, entity_info) pair: the analyzed AST and the EntityInfo built
     from the parsed source.
   """
   node, source = parsing.parse_entity(test_fn)
   info = transformer.EntityInfo(
       source_code=source,
       source_file=None,
       namespace={},
       arg_values=None,
       arg_types=None,
       owner_type=None)
   node = qual_names.resolve(node)
   node = activity.resolve(node, transformer.Context(info))
   return node, info
예제 #11
0
    def test_parsing_compile_idempotent(self):
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = x + 1
            return b

        # Round trip: source -> AST -> compiled module -> source again.
        expected_source = textwrap.dedent(inspect.getsource(test_fn))
        fn_node = parsing.parse_entity(test_fn)[0].body[0]
        compiled_module = parsing.ast_to_object(fn_node)[0]
        self.assertEqual(expected_source,
                         inspect.getsource(compiled_module.test_fn))
예제 #12
0
    def get_scopes(self, func):
        """Runs ScopeTransformer over `func` and returns the scopes it built."""
        # parse_entity returns (ast_node, source_text). The AST is what gets
        # visited; EntityInfo.source_code is meant to hold the source text.
        node, source = parsing.parse_entity(func)
        entity_info = transformer.EntityInfo(source_code=source,
                                             source_file='<fragment>',
                                             namespace={},
                                             arg_values=None,
                                             arg_types={},
                                             owner_type=None)

        namer = naming.Namer(entity_info.namespace)
        ctx = transformer.EntityContext(namer, entity_info)
        scope_transformer = scoping.ScopeTransformer(ctx)
        scope_transformer.visit(node)
        return scope_transformer.scopes
예제 #13
0
  def test_create_source_map(self):

    def test_fn(x):
      return x + 1

    module_node, _ = parsing.parse_entity(test_fn)
    fake_origin = origin_info.OriginInfo(
        loc=origin_info.Location('fake_filename', 3, 7),
        function_name='fake_function_name',
        source_code_line='fake source line',
        comment=None)
    fn_node = module_node.body[0]
    # Only the `return` statement carries an origin annotation.
    return_node = fn_node.body[0]
    anno.setanno(return_node, anno.Basic.ORIGIN, fake_origin)

    source_map = origin_info.create_source_map(
        fn_node, parsing.ast_to_source(fn_node), 'test_filename', [0])

    # In the regenerated code the `return` sits on line 2.
    expected_loc = origin_info.LineLocation('test_filename', 2)
    self.assertIn(expected_loc, source_map)
    self.assertIs(source_map[expected_loc], fake_origin)
예제 #14
0
    def test_parse_entity(self):
        def f(x):
            return x + 1

        module_node, _source = parsing.parse_entity(f)
        # The function definition is the first statement of the parsed module.
        fn_def = module_node.body[0]
        self.assertEqual('f', fn_def.name)
예제 #15
0
파일: cfg_test.py 프로젝트: google/pyctr
 def _build_cfg(self, fn):
   """Parses `fn` and returns whatever `cfg.build` produces for its AST."""
   module_node, _ = parsing.parse_entity(fn)
   return cfg.build(module_node)
예제 #16
0
    def test_state_tracking(self):

        class LoopState(object):
            pass

        class CondState(object):
            pass

        class TestTransformer(transformer.Base):
            """Tags every node with the current loop and conditional state."""

            def _visit_children_within(self, node, state_key):
                # Enter a new state of the given kind for the children only.
                self.state[state_key].enter()
                node = self.generic_visit(node)
                self.state[state_key].exit()
                return node

            def visit(self, node):
                anno.setanno(node, 'loop_state', self.state[LoopState].value)
                anno.setanno(node, 'cond_state', self.state[CondState].value)
                return super(TestTransformer, self).visit(node)

            def visit_While(self, node):
                return self._visit_children_within(node, LoopState)

            def visit_If(self, node):
                return self._visit_children_within(node, CondState)

        tr = TestTransformer(self._simple_source_info())

        def test_function(a):
            a = 1
            while a:
                _ = 'a'
                if a > 2:
                    _ = 'b'
                    while True:
                        raise '1'
                if a > 3:
                    _ = 'c'
                    while True:
                        raise '1'

        module_node, _ = parsing.parse_entity(test_function)
        module_node = tr.visit(module_node)

        fn_body = module_node.body[0].body
        outer_loop = fn_body[1].body
        # Entering the outer `while` changes only the loop state.
        self.assertSameAnno(fn_body[0], outer_loop[0], 'cond_state')
        self.assertDifferentAnno(fn_body[0], outer_loop[0], 'loop_state')

        if_a = outer_loop[1].body
        # Entering an `if` changes only the conditional state.
        self.assertDifferentAnno(outer_loop[0], if_a[0], 'cond_state')
        self.assertSameAnno(outer_loop[0], if_a[0], 'loop_state')

        loop_a = if_a[1].body
        self.assertSameAnno(if_a[0], loop_a[0], 'cond_state')
        self.assertDifferentAnno(if_a[0], loop_a[0], 'loop_state')

        if_b = outer_loop[2].body
        # Sibling `if` blocks get distinct conditional states but share the
        # enclosing loop state.
        self.assertDifferentAnno(if_a[0], if_b[0], 'cond_state')
        self.assertSameAnno(if_a[0], if_b[0], 'loop_state')

        loop_b = if_b[1].body
        # Inner loops under different `if` blocks differ in both states.
        self.assertDifferentAnno(loop_a[0], loop_b[0], 'cond_state')
        self.assertDifferentAnno(loop_a[0], loop_b[0], 'loop_state')