Example #1
0
  def visit_For(self, node):
    """Virtualizes a `for` loop by routing its target through the overload.

    The loop is rewritten to iterate over a new symbol, and the original
    target is then bound with `overload.assign` at the top of the body:

        for n_target in iter:
          overload.assign(x, n_target)        # target is a plain name `x`
          # ...or, for a tuple/list target `(a, b, ...)`:
          overload.assign(a, n_target[0])
          overload.assign(b, n_target[1])
          body
        else:
          orelse

    Args:
      node: the gast.For node to convert.

    Returns:
      The list of nodes produced by `templates.replace`.

    Raises:
      ValueError: if the loop target is not a gast.Name, gast.Tuple or
        gast.List.
    """
    node.iter = self.visit(node.iter)
    node.body = self.visit_block(node.body)
    node.orelse = self.visit_block(node.orelse)

    # Element-wise unpacking is decided by the target's node type, not by how
    # many names it holds: `for (x,) in ...` must bind x to item[0], and
    # `for () in ...` binds nothing.
    unpack = isinstance(node.target, (gast.Tuple, gast.List))
    if unpack:
      # NOTE(review): elements are expected to be names (`.id` is read
      # below); nested targets such as `(a, (b, c))` are not supported.
      targets = list(node.target.elts)
    elif isinstance(node.target, gast.Name):
      targets = [node.target]
    else:
      raise ValueError(
          'For target must be gast.Tuple, gast.List, or gast.Name, got {}.'
          .format(type(node.target)))

    # The target names are handed to the namer alongside the base name,
    # presumably so the new symbol cannot collide with them.
    n_target = self.ctx.namer.new_symbol('n_target',
                                         {target.id for target in targets})
    target_assigns = []

    if unpack:
      for i, target in enumerate(targets):
        target_assigns.extend(
            self._make_target_assign(target, n_target, i, self.overload))
    else:
      target_assigns.extend(
          templates.replace(
              'overload.assign(target, n_target)',
              overload=self.overload.symbol_name,
              target=targets[0],
              n_target=n_target))

    template = """
      for n_target in iter:
        target_assigns
        body
      else:
        orelse
    """

    node = templates.replace(
        template,
        n_target=n_target,
        iter=node.iter,
        target_assigns=target_assigns,
        body=node.body,
        orelse=node.orelse)

    return node
Example #2
0
    def test_replace_attribute(self):
        """An attribute name in a template can be replaced by a string."""
        # `imp` is deprecated and was removed in Python 3.12;
        # types.ModuleType is the supported way to build an empty module.
        import types

        template = """
      def test_fn(a):
        return a.foo
    """

        node = templates.replace(template, foo='b')[0]
        result, _ = parsing.ast_to_object(node)
        mod = types.ModuleType('test')
        mod.b = 3
        self.assertEqual(3, result.test_fn(mod))

        # A non-string replacement for an attribute name is rejected.
        with self.assertRaises(ValueError):
            templates.replace(template, foo=1)
Example #3
0
    def visit_While(self, node):
        """Lowers a `while` loop to an `overload.while_stmt` call.

        If the overload module has no `while_stmt`, the loop is returned as
        is (after its children are visited). Otherwise the condition, body
        and else-clause are each wrapped in a generated zero-argument
        function. The names modified in the body or else-clause are passed
        along as a tuple.

        Args:
          node: the gast.While node to convert.

        Returns:
          The visited node, or the list of nodes produced by
          `templates.replace`.
        """
        # Scope annotations are read from the incoming node, before
        # generic_visit runs.
        scopes = (anno.getanno(node, anno.Static.BODY_SCOPE),
                  anno.getanno(node, anno.Static.ORELSE_SCOPE))
        written = scopes[0].modified | scopes[1].modified

        node = self.generic_visit(node)

        if hasattr(self.overload.module, 'while_stmt'):
            template = """
      def test_name():
        return test
      def body_name():
        body
      def orelse_name():
        orelse
      overload.while_stmt(test_name, body_name, orelse_name, (local_writes,))
    """

            # Symbols are requested in the same order as before:
            # test, then body, then orelse.
            return templates.replace(
                template,
                overload=self.overload.symbol_name,
                test_name=self.ctx.namer.new_symbol('while_test', set()),
                test=node.test,
                body_name=self.ctx.namer.new_symbol('while_body', set()),
                body=node.body,
                orelse_name=self.ctx.namer.new_symbol('while_orelse', set()),
                orelse=node.orelse or gast.Pass(),
                local_writes=tuple(written))

        # No while_stmt overload available: keep the native loop.
        return node
Example #4
0
 def create_assignment(self, target, expression):
     """Builds the statement `target = expression` from a template.

     Args:
       target: replacement for the `target` placeholder.
       expression: replacement for the `expression` placeholder.

     Returns:
       The result of `templates.replace` for the assignment template.
     """
     assignment_source = '\n   target = expression\n '
     replaced = templates.replace(assignment_source,
                                  target=target,
                                  expression=expression)
     return replaced
Example #5
0
    def test_replace_call_keyword(self):
        """Keyword placeholders accept keyword nodes and reject other values."""
        template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

        source = parsing.parse_expression('f(d=3, f=5)')
        node = templates.replace(template, kws=source.keywords)[0]
        result, _ = parsing.ast_to_object(node)
        self.assertEqual(9, result.test_fn())

        # Each invalid replacement gets its own assertRaises block. Inside a
        # shared block, the first raise would skip the second call entirely.
        with self.assertRaises(ValueError):
            templates.replace(template, kws=[])
        with self.assertRaises(ValueError):
            templates.replace(template, kws=1)
Example #6
0
    def test_replace_expression_context(self):
        """Names in an expression substituted as a statement are Load."""
        template = """
      def test_fn():
        foo
    """

        replacement = parsing.parse_expression('a + 2 * b / -c')
        fn_def = templates.replace(template, foo=replacement)[0]
        # Parsed as a + ((2 * b) / (-c)); check `a` and `b`.
        expr = fn_def.body[0]
        self.assertIsInstance(expr.left.ctx, gast.Load)
        self.assertIsInstance(expr.right.left.right.ctx, gast.Load)
Example #7
0
def _wrap_in_generator(func, source, namer, overload):
  """Compiles generated source inside a factory function and calls it.

  The generated code is nested in a new function `gen_fun(<overload name>)`.
  For each free variable of `func`, the factory first declares a same-named
  local set to None, so the nested code sees the same closure variables as
  the original. The factory is then called with `overload.module`, and it
  returns the object named `func.__name__` defined by the generated code.

  Args:
    func: the original function
    source: the generated source code
    namer: naming.Namer, used for naming vars
    overload: config.VirtualizationConfig

  Returns:
    The generated function with a new closure variable.
  """
  # One `<free var> = None` statement per closure variable of `func`.
  closure_stubs = [
      stmt
      for free_name in six.get_function_code(func).co_freevars
      for stmt in templates.replace('var = None', var=free_name)
  ]

  factory_name = namer.new_symbol('gen_fun', set())
  template = """
    def gen_fun(overload):
      nonlocals

      program

      return f_name
  """

  factory_ast = templates.replace(
      template,
      gen_fun=factory_name,
      nonlocals=closure_stubs,
      overload=overload.symbol_name,
      program=source,
      f_name=func.__name__)

  module, _ = parsing.ast_to_object(factory_ast)
  factory = getattr(module, factory_name)
  # The factory's single parameter is named overload.symbol_name; binding it
  # to overload.module makes that name resolve inside the generated code.
  return factory(overload.module)
Example #8
0
    def test_replace_tuple(self):
        """A tuple of names substituted for `b` makes test_fn return (a, c)."""
        template = """
      def test_fn(a, c):
        return b,
    """

        fn_def = templates.replace(template, b=('a', 'c'))[0]
        compiled, _ = parsing.ast_to_object(fn_def)

        expected = (2, 3)
        self.assertEqual(expected, compiled.test_fn(2, 3))
Example #9
0
    def test_replace_name_with_dict(self):
        """A name can be replaced by a dict display that is then indexed."""
        template = """
      def test_fn():
        return foo['bar']
    """

        dict_expr = parsing.parse_expression("{'bar': 3}")
        fn_def = templates.replace(template, foo=dict_expr)[0]
        compiled, _ = parsing.ast_to_object(fn_def)
        self.assertEqual(3, compiled.test_fn())
Example #10
0
    def test_replace_function_name(self):
        """A string replacement renames the templated function."""
        template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

        fn_def = templates.replace(template, fname='test_fn')[0]
        compiled, _ = parsing.ast_to_object(fn_def)
        # (2 + 1) * 2 + 1 == 7
        self.assertEqual(7, compiled.test_fn(2))
Example #11
0
    def test_replace_variable(self):
        """Renaming `a` to `b` covers parameters, stores and loads alike."""
        template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

        fn_def = templates.replace(template, a='b')[0]
        compiled, _ = parsing.ast_to_object(fn_def)
        # With a -> b: (2 + 1) * 2 + 1 == 7
        self.assertEqual(7, compiled.test_fn(2))
Example #12
0
    def test_replace_tuple_context(self):
        """A tuple substituted into a store position gets Store contexts."""
        template = """
      def test_fn(foo):
        foo = 0
    """

        tuple_expr = parsing.parse_expression('(a, b)')
        fn_def = templates.replace(template, foo=tuple_expr)[0]
        store_target = fn_def.body[0].targets[0]
        self.assertIsInstance(store_target.ctx, gast.Store)
        first, second = store_target.elts[0], store_target.elts[1]
        self.assertIsInstance(first.ctx, gast.Store)
        self.assertIsInstance(second.ctx, gast.Store)
Example #13
0
    def test_replace_index(self):
        """Subexpressions inside a substituted store target remain Load."""
        template = """
      def test_fn():
        foo = 0
    """

        target_expr = parsing.parse_expression('foo(a[b]).bar')
        fn_def = templates.replace(template, foo=target_expr)[0]
        call = fn_def.body[0].targets[0].value       # foo(a[b])
        subscript = call.args[0]                     # a[b]
        self.assertIsInstance(subscript.ctx, gast.Load)
        self.assertIsInstance(subscript.slice.value.ctx, gast.Load)
Example #14
0
    def test_replace_attribute_context(self):
        """Only the outermost attribute of a substituted target is Store."""
        template = """
      def test_fn(foo):
        foo = 0
    """

        attr_expr = parsing.parse_expression('a.b.c')
        fn_def = templates.replace(template, foo=attr_expr)[0]
        outer = fn_def.body[0].targets[0]                       # a.b.c
        self.assertIsInstance(outer.ctx, gast.Store)
        self.assertIsInstance(outer.value.ctx, gast.Load)        # a.b
        self.assertIsInstance(outer.value.value.ctx, gast.Load)  # a
Example #15
0
    def test_replace_complex_context(self):
        """Containers nested in a substituted store target remain Load."""
        template = """
      def test_fn():
        foo = 0
    """

        target_expr = parsing.parse_expression('bar(([a, b],)).baz')
        fn_def = templates.replace(template, foo=target_expr)[0]
        store_target = fn_def.body[0].targets[0]
        self.assertIsInstance(store_target.ctx, gast.Store)
        outer_tuple = store_target.value.args[0]   # ([a, b],)
        inner_list = outer_tuple.elts[0]           # [a, b]
        self.assertIsInstance(inner_list.ctx, gast.Load)
        self.assertIsInstance(inner_list.elts[0].ctx, gast.Load)
        self.assertIsInstance(inner_list.elts[1].ctx, gast.Load)
Example #16
0
    def visit_For(self, node):
        """Lowers a `for` loop to an `overload.for_stmt` call.

        If the overload module has no `for_stmt`, the loop is returned as is
        (after its children are visited). Otherwise:
        - each name in the loop target is initialized via `_make_target_init`;
        - the body and else-clause are wrapped in generated zero-argument
          functions;
        - the names modified in the body or else-clause are passed along as
          a tuple.

        Args:
          node: the gast.For node to convert.

        Returns:
          The visited node, or the list of nodes produced by
          `templates.replace`.

        Raises:
          ValueError: if the loop target is not a gast.Name, gast.Tuple or
            gast.List.
        """
        # Scope annotations are read before generic_visit runs.
        body_scope = anno.getanno(node, anno.Static.BODY_SCOPE)
        orelse_scope = anno.getanno(node, anno.Static.ORELSE_SCOPE)
        modified_in_cond = body_scope.modified | orelse_scope.modified

        node = self.generic_visit(node)

        if not hasattr(self.overload.module, 'for_stmt'):
            return node

        # TODO(jmd1011): Handle extra_test

        if isinstance(node.target, (gast.Tuple, gast.List)):
            targets = list(node.target.elts)
        elif isinstance(node.target, gast.Name):
            targets = [node.target]
        else:
            raise ValueError(
                'For target must be gast.Tuple, gast.List, or gast.Name, got {}.'
                .format(type(node.target)))

        # _make_target_init returns the list produced by templates.replace.
        # Flatten those lists into one statement list, as the other
        # converters do with `.extend`, instead of handing the template a
        # list of lists.
        target_inits = []
        for target in targets:
            target_inits.extend(self._make_target_init(target, self.overload))

        template = """
      target_inits
      def body_name():
        body
      def orelse_name():
        orelse
      overload.for_stmt(target, iter_, body_name, orelse_name, (local_writes,))
    """

        node = templates.replace(
            template,
            target_inits=target_inits,
            target=node.target,
            body_name=self.ctx.namer.new_symbol('for_body', set()),
            body=node.body,
            orelse_name=self.ctx.namer.new_symbol('for_orelse', set()),
            orelse=node.orelse if node.orelse else gast.Pass(),
            overload=self.overload.symbol_name,
            iter_=node.iter,
            local_writes=tuple(modified_in_cond))

        return node
Example #17
0
    def test_replace_name_with_call(self):
        """A name can be replaced by a chained call expression."""
        template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

        call_expr = parsing.parse_expression('f()(b)')
        fn_def = templates.replace(template, foo=call_expr)[0]
        compiled, _ = parsing.ast_to_object(fn_def)
        # f() returns g, and g(5) == 15.
        self.assertEqual(15, compiled.test_fn())
Example #18
0
    def test_replace_code_block(self):
        """A list of statements replaces a single placeholder statement."""
        template = """
      def test_fn(a):
        block
        return a
    """

        # `a = a + 1`; the same node object appears twice in the block.
        increment = gast.Assign(
            [gast.Name('a', None, None)],
            gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1)))
        fn_def = templates.replace(template, block=[increment, increment])[0]
        compiled, _ = parsing.ast_to_object(fn_def)
        self.assertEqual(3, compiled.test_fn(1))
Example #19
0
  def visit_Assign(self, node):
    """Rewrites `name = value` into `overload.assign(name, value)`.

    The right-hand side is always visited. The statement is rewritten only
    when its first target is a plain name that the current scope virtualizes.
    Otherwise the (visited) node is returned unchanged.

    Args:
      node: the gast.Assign node to convert.

    Returns:
      The node itself, or the nodes produced by `templates.replace`.
    """
    # TODO(b/123943188): Handle multiple assignment
    # (only node.targets[0] is considered below).
    node.value = self.visit(node.value)

    target = node.targets[0]
    if isinstance(target, (gast.Attribute, gast.Subscript)):
      # `obj.attr = v` and `obj[i] = v` bind no variable, so there is nothing
      # to virtualize; reading `.id` from them would raise AttributeError.
      return node

    lhs = target.id
    rhs = node.value

    if not self.scope.should_virtualize(lhs):
      return node

    node = templates.replace(
        'overload.assign(lhs, rhs)',
        lhs=lhs,
        rhs=rhs,
        overload=self.overload.symbol_name)

    return node
Example #20
0
 def _make_target_init(self, target, overload):
     """Builds `target = overload.init("<target name>")` for one loop target.

     Args:
       target: a node with an `id` attribute (e.g. gast.Name); its id is
         used as the quoted name passed to `init`.
       overload: the overload config whose `symbol_name` is emitted. Falls
         back to `self.overload` when None.

     Returns:
       The nodes produced by `templates.replace`.
     """
     # Honor the argument instead of silently ignoring it.
     if overload is None:
         overload = self.overload
     return templates.replace('target = overload.init(target_name)',
                              target=target,
                              target_name='"{}"'.format(target.id),
                              overload=overload.symbol_name)
Example #21
0
  def visit_FunctionDef(self, node):
    """Virtualizes a function definition.

    The function is re-emitted with each positional argument renamed to a
    new symbol from the namer. The new body consists of:
    - an `overload.init("<name>")` assignment for every local of the scope;
    - an `overload.assign(<arg>, <renamed arg>)` for every argument;
    - the converted original body.

    If the function's name is a local of the parent scope, the definition is
    emitted under a new name and then bound to the original name with
    `overload.assign`. The current scope is switched to this function's
    scope while converting, and set to its parent on return.

    Args:
      node: the gast.FunctionDef node to convert; its scope must be
        registered in `self.scopes` under `id(node)`.

    Returns:
      The list of nodes produced by `templates.replace`.
    """
    assert id(node) in self.scopes
    self.scope = self.scopes[id(node)]

    params = [arg.id for arg in node.args.args]
    # A fresh reserved-name set is built for every call.
    renamed = [self.ctx.namer.new_symbol(p, set(params)) for p in params]

    inits = []
    for local_name in self.scope.locals:
      inits += templates.replace(
          'lhs = overload.init(lhs_name)',
          lhs=local_name,
          lhs_name='"{}"'.format(local_name),
          overload=self.overload.symbol_name)

    rebinds = []
    for param, fresh in zip(params, renamed):
      rebinds += templates.replace(
          'overload.assign(lhs, rhs)',
          lhs=param,
          overload=self.overload.symbol_name,
          rhs=fresh)

    node.body = self.visit_block(node.body)

    parent = self.scope.parent
    if not (parent and parent.is_local(node.name)):
      template = """
        def fun_name(args):
          inits
          arg_nodes
          body
      """

      result = templates.replace(
          template,
          fun_name=node.name,
          args=renamed,
          arg_nodes=rebinds,
          inits=inits,
          body=node.body,
      )
    else:
      # The function name is itself a local of the enclosing scope: define
      # under a new name, then bind it through overload.assign.
      template = """
        def new_fun_name(args):
          inits
          arg_nodes
          body
        overload.assign(fun_name, new_fun_name)
      """

      result = templates.replace(
          template,
          new_fun_name=self.ctx.namer.new_symbol(node.name, {node.name}),
          args=renamed,
          arg_nodes=rebinds,
          inits=inits,
          body=node.body,
          overload=self.overload.symbol_name,
          fun_name=node.name,
      )

    self.scope = parent
    return result
Example #22
0
 def _make_target_assign(self, target, n_target, i, overload):
   """Builds `overload.assign(target, n_target[i])` for one unpacked target.

   Args:
     target: replacement for the `target` placeholder (one element of the
       loop target).
     n_target: replacement for the `n_target` placeholder (the symbol that
       holds the current loop item).
     i: index of `target` within the loop item.
     overload: the overload config whose `symbol_name` is emitted. Falls
       back to `self.overload` when None.

   Returns:
     The nodes produced by `templates.replace`.
   """
   # Honor the argument instead of silently ignoring it.
   if overload is None:
     overload = self.overload
   return templates.replace(
       'overload.assign(target, n_target[{}])'.format(i),
       target=target,
       n_target=n_target,
       overload=overload.symbol_name)