Example #1
0
    def test_ast_to_object(self):
        """Loads a hand-built `def f(a): return a + 1` tree as a module.

        Checks the source text returned by parsing.ast_to_object, the
        behavior of the loaded `f`, and the contents of the module's file.
        """
        param_list = gast.arguments(args=[gast.Name('a', gast.Param(), None)],
                                    vararg=None,
                                    kwonlyargs=[],
                                    kwarg=None,
                                    defaults=[],
                                    kw_defaults=[])
        a_plus_one = gast.BinOp(op=gast.Add(),
                                left=gast.Name('a', gast.Load(), None),
                                right=gast.Num(1))
        node = gast.FunctionDef(
            name='f',
            args=param_list,
            body=[gast.Return(a_plus_one)],
            decorator_list=[],
            returns=None)

        module, source = parsing.ast_to_object(node)

        expected = textwrap.dedent("""
      def f(a):
        return a + 1
    """).strip()
        self.assertEqual(expected, source.strip())
        self.assertEqual(2, module.f(1))
        # The module is backed by a file on disk; it must hold the same text.
        with open(module.__file__, 'r') as temp_output:
            written = temp_output.read()
        self.assertEqual(expected, written.strip())
Example #2
0
 def test_keywords_to_dict(self):
   """keywords_to_dict must produce a dict node that compiles and evaluates."""
   call_expr = parsing.parse_expression('f(a=b, c=1, d=\'e\')')
   dict_node = ast_util.keywords_to_dict(call_expr.keywords)
   # Return the dict from a small function so the generated node is actually
   # compiled and run; `b` is bound to the function's argument.
   func_def = parsing.parse_str('def f(b): pass').body[0]
   func_def.body.append(ast.Return(dict_node))
   loaded, _ = parsing.ast_to_object(func_def)
   self.assertDictEqual(loaded.f(3), {'a': 3, 'c': 1, 'd': 'e'})
Example #3
0
    def test_replace_tuple(self):
        """Replacing `b` with a tuple of names makes test_fn return (a, c)."""
        template = """
      def test_fn(a, c):
        return b,
    """

        replaced_nodes = templates.replace(template, b=('a', 'c'))
        compiled, _ = parsing.ast_to_object(replaced_nodes[0])

        self.assertEqual((2, 3), compiled.test_fn(2, 3))
Example #4
0
    def test_replace_name_with_dict(self):
        """A name placeholder can be replaced by a dict-literal expression."""
        template = """
      def test_fn():
        return foo['bar']
    """

        dict_expr = parsing.parse_expression('{\'bar\': 3}')
        replaced_nodes = templates.replace(template, foo=dict_expr)
        compiled, _ = parsing.ast_to_object(replaced_nodes[0])
        self.assertEqual(3, compiled.test_fn())
Example #5
0
    def test_replace_function_name(self):
        """The placeholder function name `fname` is renamed to test_fn.

        The body computes (2 + 1) * 2 + 1 == 7 for an input of 2.
        """
        template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

        replaced_nodes = templates.replace(template, fname='test_fn')
        compiled, _ = parsing.ast_to_object(replaced_nodes[0])
        self.assertEqual(7, compiled.test_fn(2))
Example #6
0
    def test_replace_variable(self):
        """Replacing the name `a` with 'b' lets `return b` see the result.

        For an input of 2 the body computes (2 + 1) * 2 + 1 == 7.
        """
        template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

        replaced_nodes = templates.replace(template, a='b')
        compiled, _ = parsing.ast_to_object(replaced_nodes[0])
        self.assertEqual(7, compiled.test_fn(2))
Example #7
0
    def test_parsing_compile_idempotent(self):
        """Parsing a function and compiling it back reproduces its source."""
        def test_fn(x):
            a = True
            b = ''
            if a:
                b = x + 1
            return b

        expected_source = textwrap.dedent(inspect.getsource(test_fn))
        parsed_tree = parsing.parse_entity(test_fn)[0]
        # Compile only the function definition, the first statement.
        loaded_module = parsing.ast_to_object(parsed_tree.body[0])[0]
        roundtrip_source = inspect.getsource(loaded_module.test_fn)
        self.assertEqual(expected_source, roundtrip_source)
Example #8
0
    def test_replace_attribute(self):
        """An attribute placeholder can be replaced by a string name.

        Replacing `foo` with 'b' turns `a.foo` into `a.b`. Replacing it with
        a non-name value such as the int 1 must raise ValueError.
        """
        # The `imp` module is deprecated and was removed in Python 3.12;
        # types.ModuleType creates the same kind of empty module object.
        import types

        template = """
      def test_fn(a):
        return a.foo
    """

        node = templates.replace(template, foo='b')[0]
        result, _ = parsing.ast_to_object(node)
        mod = types.ModuleType('test')
        mod.b = 3
        self.assertEqual(3, result.test_fn(mod))

        with self.assertRaises(ValueError):
            templates.replace(template, foo=1)
Example #9
0
    def test_replace_name_with_call(self):
        """A name placeholder can be replaced by a call expression.

        `foo` becomes `f()(b)`: f returns g, and g(5) == 15.
        """
        template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

        call_expr = parsing.parse_expression('f()(b)')
        replaced_nodes = templates.replace(template, foo=call_expr)
        compiled, _ = parsing.ast_to_object(replaced_nodes[0])
        self.assertEqual(15, compiled.test_fn())
Example #10
0
    def test_replace_code_block(self):
        """A statement placeholder can be replaced by a list of statements.

        `block` expands to two copies of `a = a + 1`, so test_fn(1) == 3.
        """
        template = """
      def test_fn(a):
        block
        return a
    """

        def make_increment():
            # Build a fresh node for every statement. `[node] * 2` would put
            # the same AST object at two positions in the tree, so an in-place
            # rewrite of one occurrence would silently affect the other.
            return gast.Assign([gast.Name('a', None, None)],
                               gast.BinOp(gast.Name('a', None, None),
                                          gast.Add(), gast.Num(1)))

        node = templates.replace(
            template, block=[make_increment() for _ in range(2)])[0]
        result, _ = parsing.ast_to_object(node)
        self.assertEqual(3, result.test_fn(1))
Example #11
0
    def test_replace_call_keyword(self):
        """A keyword placeholder can be replaced by a list of keyword nodes.

        `kws=None` becomes `d=3, f=5`, so test_fn() returns 1 + 3 + 5 == 9.
        Invalid replacements (an empty list, a non-keyword value) must each
        raise ValueError.
        """
        template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

        source = parsing.parse_expression('f(d=3, f=5)')
        node = templates.replace(template, kws=source.keywords)[0]
        result, _ = parsing.ast_to_object(node)
        self.assertEqual(9, result.test_fn())

        # Give each invalid case its own context. Inside a single
        # `with assertRaises` block the first raise exits the block, and the
        # second call would never run.
        with self.assertRaises(ValueError):
            templates.replace(template, kws=[])
        with self.assertRaises(ValueError):
            templates.replace(template, kws=1)
Example #12
0
def _wrap_in_generator(func, source, namer, overload):
  """Embeds the converted program in a generated outer function and calls it.

  The generated outer function first assigns `None` to a local named after
  each free variable of `func`. It then runs the embedded program and
  returns the object bound to `func.__name__`. Its single parameter is named
  `overload.symbol_name`, and it is invoked with `overload.module`.

  Args:
    func: the original function; its free-variable names and `__name__` are
      read.
    source: the converted program spliced into the outer function's body
      (presumably AST node(s), since it is passed to templates.replace;
      confirm against callers).
    namer: naming.Namer, used to pick a fresh name for the outer function.
    overload: config.VirtualizationConfig; its `symbol_name` and `module`
      are read.

  Returns:
    The object the generated outer function returns, i.e. whatever is bound
    to `func.__name__` after the embedded program runs.
  """
  # Dummy assignments, one per free variable, so the generated function has
  # the same closure as the original function.
  dummy_assignments = [
      stmt
      for free_name in six.get_function_code(func).co_freevars
      for stmt in templates.replace('var = None', var=free_name)
  ]

  outer_name = namer.new_symbol('gen_fun', set())
  outer_template = """
    def gen_fun(overload):
      nonlocals

      program

      return f_name
  """

  outer_def = templates.replace(
      outer_template,
      gen_fun=outer_name,
      nonlocals=dummy_assignments,
      overload=overload.symbol_name,
      program=source,
      f_name=func.__name__)

  loaded_module, _ = parsing.ast_to_object(outer_def)
  outer_fn = getattr(loaded_module, outer_name)
  return outer_fn(overload.module)