예제 #1
0
  def visit_Assign(self, node):
    """Lowers `target = [elt for ... in ... if ...]` into explicit loops.

    Assignments whose value is not a list comprehension are handed to
    `generic_visit`.

    Args:
      node: the `gast.Assign` node being visited.

    Returns:
      For a list-comprehension assignment, a list of statements: the
      initialization `target = []` followed by the nested `for`/`if`
      statements that append each element to `target`. Otherwise, the result
      of `self.generic_visit(node)`.

    Raises:
      NotImplementedError: if the assignment has more than one target.
    """
    comprehension = node.value
    if not isinstance(comprehension, gast.ListComp):
      return self.generic_visit(node)
    if len(node.targets) > 1:
      raise NotImplementedError('multiple assignments')

    (target,) = node.targets

    init_template = """
      target = []
    """
    append_template = """
      target.append(elt)
    """
    if_template = """
          if test:
            body
        """
    for_template = """
        for target in iter_:
          body
      """

    init_stmts = templates.replace(init_template, target=target)
    loop_stmts = templates.replace(
        append_template, target=target, elt=comprehension.elt)

    # Build the loop nest inside-out: the last generator (and, within each
    # generator, the last condition) is wrapped first and ends up innermost.
    for generator in reversed(comprehension.generators):
      for condition in reversed(generator.ifs):
        loop_stmts = templates.replace(
            if_template, test=condition, body=loop_stmts)
      loop_stmts = templates.replace(
          for_template,
          iter_=generator.iter,
          target=generator.target,
          body=loop_stmts)

    return init_stmts + loop_stmts
예제 #2
0
 def replace_as_expression_restrictions(self):
   """Checks that malformed templates are rejected with ValueError.

   `replace_as_expression` must reject a template holding two statements, and
   `replace` must reject the templates '' and 'a = b'.

   NOTE(review): this name has no `test` prefix, so the default unittest
   loader presumably does not collect it. Confirm whether that is intended.
   """
   two_statements = """
     foo(a)
     bar(b)
   """
   with self.assertRaises(ValueError):
     templates.replace_as_expression(two_statements)
   for rejected in ('', 'a = b'):
     with self.assertRaises(ValueError):
       templates.replace(rejected)
예제 #3
0
  def test_replace_attribute(self):
    """Replacing an attribute placeholder renames the attribute access.

    Also checks that replacing the attribute with a non-name value (an int)
    raises ValueError.
    """
    # `imp` was removed in Python 3.12. `types.ModuleType(name)` is the
    # documented equivalent of `imp.new_module(name)`.
    import types

    template = """
      def test_fn(a):
        return a.foo
    """

    node = templates.replace(template, foo='b')[0]
    result, _ = compiler.ast_to_object(node)
    mod = types.ModuleType('test')
    mod.b = 3
    # `assertEquals` is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(3, result.test_fn(mod))

    with self.assertRaises(ValueError):
      templates.replace(template, foo=1)
  def visit_With(self, node):
    """Lifts a trailing `return` out of a `with` statement.

    If the visited body ends in `return value`, that statement is rewritten
    to `<common_return_name> = value`, and `return <common_return_name>` is
    emitted right after the `with` statement.

    Args:
      node: the `gast.With` node being visited.

    Returns:
      The visited node if its body does not end in a return. Otherwise, the
      list `[with_node, return_node]`; in that case `self.changes_made` is
      set to True.
    """
    # Transform children first (depth-first).
    node = self.generic_visit(node)

    last_stmt = node.body[-1]
    if not isinstance(last_stmt, gast.Return):
      return node

    node.body[-1] = templates.replace(
        'a = b', a=self.common_return_name, b=last_stmt.value)[0]
    lifted_return = templates.replace(
        'return a', a=self.common_return_name)[0]
    # The rewritten node is visited once more before being returned.
    node = self.generic_visit(node)
    self.changes_made = True
    return [node, lifted_return]
예제 #5
0
 def _insert_dynamic_conversion(self, node):
   """Rewrites a call so that its callee is converted at runtime.

   `func(args..., **keywords)` becomes
   `ag__.converted_call(func, True, False, False, {}, args...)`, and the
   original keyword arguments are copied onto the new call.

   Args:
     node: the `gast.Call` node to rewrite.

   Returns:
     The new call expression node.
   """
   # TODO(mdan): Pass information on the statically compiled functions.
   # Knowing which functions were compiled statically could avoid unnecessary
   # compilation. For example, function `a` would be compiled twice here:
   #
   #   def a():
   #     v = b
   #     b()
   #   def b():
   #     a()
   #
   # This really only affects recursive calls, which can currently only be
   # gated by a static condition, and should be rare.
   # TODO(mdan): It probably makes sense to use dynamic conversion every time.
   # Converting every time would first require a reasonable caching mechanism.
   template = """
     ag__.converted_call(func, True, False, False, {}, args)
   """
   statements = templates.replace(template, func=node.func, args=node.args)
   converted = statements[0].value
   # TODO(mdan): Improve the template mechanism to better support this.
   converted.keywords = node.keywords
   return converted
예제 #6
0
  def visit_For(self, node):
    """Replaces `break` inside a `for` loop with a boolean control variable.

    The loop target, iterable, body and else clause are visited. If the body
    used `break`, the loop is preceded by `<break var> = False`. Its else
    clause is passed through `self._guard_if_present`, and the loop is
    annotated with `extra_test` set to the expression `not <break var>`.

    Args:
      node: the `gast.For` node being visited.

    Returns:
      The visited node if no `break` was used. Otherwise, the two-statement
      list `[<break var> = False, for_stmt]`.
    """
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.context.namer.new_symbol('break__', body_scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._track_body(node.body, break_var)
    # A break in the else clause applies to the containing scope, so the else
    # clause is visited without tracking.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
      return node

    # Python's else clause only runs if the loop exited without a break.
    node.orelse = self._guard_if_present(node.orelse, break_var)
    template = """
        var_name = False
        for_stmt
      """
    statements = templates.replace(template, var_name=break_var, for_stmt=node)
    loop_condition = templates.replace_as_expression(
        'not var_name', var_name=break_var)
    anno.setanno(statements[1], 'extra_test', loop_condition)
    return statements
예제 #7
0
  def visit_While(self, node):
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

    node.test = self.visit(node.test)
    node.body, break_used = self._track_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
      # Python's else clause only triggers if the loop exited cleanly (e.g.
      # break did not trigger).
      guarded_orelse = self._guard_if_present(node.orelse, break_var)

      template = """
        var_name = False
        while test and not var_name:
          body
        else:
          orelse
      """
      node = templates.replace(
          template,
          var_name=break_var,
          test=node.test,
          body=node.body,
          orelse=guarded_orelse)

    return node
예제 #8
0
  def test_replace_call_keyword(self):
    """A keyword placeholder in a call can be replaced by real keywords.

    Also checks that each invalid keyword replacement (an empty list and an
    int) raises ValueError.
    """
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    # `assertEquals` is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(9, result.test_fn())

    # Each invalid replacement gets its own assertRaises block. Inside a
    # single block, the second call would never run once the first raised.
    for invalid_kws in ([], 1):
      with self.assertRaises(ValueError):
        templates.replace(template, kws=invalid_kws)
예제 #9
0
  def _rename_compilable_function(self, node):
    """Points a call at the compiled version of its target, when applicable.

    The call's `node.func` must carry the `live_val` and `fqn` annotations.
    Calls annotated `is_constructor` always use the namer's compiled class
    name. Other calls ask the namer whether, and to what, to rename. When a
    bound method is renamed, its receiver (`node.func.value`) is prepended to
    the positional arguments.

    Args:
      node: the `gast.Call` node to process.

    Returns:
      The same node, possibly with `func` and `args` rewritten.
    """
    assert anno.hasanno(node.func, 'live_val')
    assert anno.hasanno(node.func, 'fqn')
    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = anno.getanno(node.func, 'fqn')

    if not self._should_compile(node, target_fqn):
      return node

    if anno.hasanno(node, 'is_constructor'):
      new_name = self.ctx.namer.compiled_class_name(
          target_fqn, live_entity=target_entity)
      do_rename = True
    else:
      if anno.hasanno(node.func, 'parent_type'):
        owner_type = anno.getanno(node.func, 'parent_type')
      else:
        # Fallback - not reliable.
        owner_type = inspect_utils.getmethodclass(target_entity)
      new_name, do_rename = self.ctx.namer.compiled_function_name(
          target_fqn, live_entity=target_entity, owner_type=owner_type)

    if not do_rename:
      return node

    # The renaming process turns a method into a regular function, so the
    # receiver has to be passed explicitly.
    # TODO(mdan): Is this complete? How does it work with nested members?
    if target_entity is not None and tf_inspect.ismethod(target_entity):
      node.args = [node.func.value] + node.args
    node.func = templates.replace('func_name', func_name=new_name)[0]
    return node
예제 #10
0
 def _create_break_trigger(self):
   """Builds the statements that stand in for a `break`.

   Returns:
     A list of statements: `<control var> = True` followed by `continue`.
     The control variable is the second item of the most recently pushed
     entry of `self.break_uses`.
   """
   template = """
     var_name = True
   """
   control_var = self.break_uses[-1][1]
   statements = templates.replace(template, var_name=control_var)
   statements.append(gast.Continue())
   return statements
예제 #11
0
 def visit_Break(self, node):
   """Replaces `break` with `<control var> = True` followed by `continue`.

   Marks the most recently pushed entry of `self.break_uses` as used. Its
   second item names the control variable.
   """
   current_loop = self.break_uses[-1]
   current_loop[0] = True
   template = """
     var_name = True
     continue
   """
   return templates.replace(template, var_name=current_loop[1])
예제 #12
0
  def visit_Assert(self, node):
    """Converts `assert test, msg` into `tf.Assert(test, (msg,))`.

    A missing message becomes the string 'Assertion error'.

    Args:
      node: the `gast.Assert` node being visited.

    Returns:
      The statement list produced from the template.

    Raises:
      NotImplementedError: if the message is not a string literal.
    """
    self.generic_visit(node)

    # Note: The lone tf.Assert call will be wrapped with control_dependencies
    # by side_effect_guards.
    template = """
      tf.Assert(test, (msg,))
    """

    if node.msg is None:
      message = gast.Str('Assertion error')
    elif isinstance(node.msg, gast.Str):
      message = node.msg
    else:
      raise NotImplementedError('can only convert string messages for now.')
    return templates.replace(template, test=node.test, msg=message)
 def visit_Continue(self, node):
   """Replaces `continue` with an assignment that sets the loop control var.

   Records CONTINUE_USED in the local state. The control variable name is
   read from the local CONTROL_VAR_NAME.
   """
   self.set_local(CONTINUE_USED, True)
   control_var = self.get_local(CONTROL_VAR_NAME)
   template = """
     var_name = True
   """
   return templates.replace(template, var_name=control_var)
예제 #14
0
  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Builds `target, pop_var_name = ag__.list_pop(target, element, opts=...)`.

    Args:
      original_call_node: a call of the form `target.pop(...)`. Its first
        argument, if any, is forwarded as `element`; otherwise `None` is used.
      pop_var_name: the name that receives the popped value.

    Returns:
      The statement list produced from the template. The element dtype and
      shape come from the `element_type`/`element_shape` annotations on
      `target`, each defaulting to `None`.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    if original_call_node.args:
      pop_element = original_call_node.args[0]
    else:
      pop_element = parser.parse_expression('None')

    # For a call like "target.pop()", the dtype is attached to `target`,
    # i.e. to func.value.
    list_node = original_call_node.func.value
    dtype = anno.getanno(
        list_node,
        'element_type',
        default=templates.replace_as_expression('None'))
    shape = anno.getanno(
        list_node,
        'element_shape',
        default=templates.replace_as_expression('None'))

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=list_node,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=dtype,
        shape=shape)
예제 #15
0
 def _create_cond_expr(self, results, test, body_name, orelse_name):
   """Builds a call to `ag__.utils.run_cond(test, body_name, orelse_name)`.

   Args:
     results: target(s) to assign the call's result to, or None to emit the
       call as a bare expression statement.
     test: the condition expression.
     body_name: name of the function for the true branch.
     orelse_name: name of the function for the false branch.

   Returns:
     The statement list produced from the template.
   """
   if results is None:
     template = """
       ag__.utils.run_cond(test, body_name, orelse_name)
     """
     return templates.replace(
         template, test=test, body_name=body_name, orelse_name=orelse_name)

   template = """
       results = ag__.utils.run_cond(test, body_name, orelse_name)
     """
   return templates.replace(
       template,
       test=test,
       results=results,
       body_name=body_name,
       orelse_name=orelse_name)
예제 #16
0
  def visit_For(self, node):
    """Converts a `for` loop into a functional `ag__.for_stmt` call.

    Symbols modified but not created in the loop body form the loop state.
    Inside the generated functions they are renamed to fresh symbols derived
    from their `ssf()` names. The body becomes `body_name(loop_vars,
    state_ssf)`. The optional `extra_test` annotation (default: `True`)
    becomes `extra_test_name(state_ssf)`.

    Args:
      node: the `gast.For` node being visited.

    Returns:
      The statement list produced from the template.
    """
    self.generic_visit(node)

    self._validate_no_live_vars_created(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    all_referenced = body_scope.referenced
    loop_state = list(body_scope.modified - body_scope.created)

    loop_state_ssf = [
        self.ctx.namer.new_symbol(sym.ssf(), all_referenced)
        for sym in loop_state
    ]
    # Only symbols whose fresh name differs from the original need renaming.
    ssf_map = {}
    for original, renamed in zip(loop_state, loop_state_ssf):
      if str(original) != renamed:
        ssf_map[original] = renamed

    # A single state variable is used directly rather than as a tuple.
    if len(loop_state) == 1:
      state = loop_state[0]
      state_ssf = loop_state_ssf[0]
      state_ast_tuple = state
    else:
      state = loop_state
      state_ssf = loop_state_ssf
      state_ast_tuple = gast.Tuple([sym.ast() for sym in loop_state], None)

    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    if anno.hasanno(node, 'extra_test'):
      extra_test = ast_util.rename_symbols(
          anno.getanno(node, 'extra_test'), ssf_map)
    else:
      extra_test = parser.parse_expression('True')

    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
        extra_test_expr=extra_test,
        body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
        body=renamed_body)
예제 #17
0
 def visit_Break(self, node):
   """Replaces `break` with `<control var> = True` followed by `continue`.

   Marks `self.state[_Break]` as used and reads its `control_var_name`.
   """
   break_state = self.state[_Break]
   break_state.used = True
   # TODO(mdan): This will fail when expanded inside a top-level else block.
   template = """
     var_name = True
     continue
   """
   return templates.replace(template, var_name=break_state.control_var_name)
예제 #18
0
 def visit_Break(self, node):
   """Replaces `break` with `<control var> = True` followed by `continue`.

   Records BREAK_USED in the local state. The control variable name is read
   from the local CONTROL_VAR_NAME.
   """
   self.set_local(BREAK_USED, True)
   # TODO(mdan): This will fail when expanded inside a top-level else block.
   template = """
     var_name = True
     continue
   """
   return templates.replace(
       template, var_name=self.get_local(CONTROL_VAR_NAME))
예제 #19
0
  def _process_single_assignment(self, target, value):
    """Rewrites `x[key] = item` as `x = ag__.set_item(x, key, item)`.

    Args:
      target: the assignment target.
      value: the assigned value.

    Returns:
      The replacement statement list, or None if `target` is not a subscript.
    """
    if not isinstance(target, gast.Subscript):
      return None

    container = target.value
    key = target.slice
    template = """
      target = ag__.set_item(target, key, item)
    """
    return templates.replace(template, target=container, key=key, item=value)
예제 #20
0
  def replace_as_expression(self):
    """`replace_as_expression` yields the template's single expression.

    The original version called `templates.replace`, which returns a list of
    statements. It then compared with `is gast.Call`, which checks identity
    against the class object and so is always False. It also read
    `node.func.args` instead of the call's own `node.args`.

    NOTE(review): this name has no `test` prefix, so the default unittest
    loader presumably does not collect it. Consider renaming it to
    `test_replace_as_expression`.
    """
    template = """
      foo(a)
    """

    node = templates.replace_as_expression(template, foo='bar', a='baz')
    self.assertIsInstance(node, gast.Call)
    self.assertEqual(node.func.id, 'bar')
    self.assertEqual(node.args[0].id, 'baz')
예제 #21
0
 def _convert_builtin(self, f, args, as_expression):
   """Rewrites a call to builtin `f` as `ag__.<overload name>(args)`.

   Args:
     f: the builtin being called. `py_builtins.overload_of(f).__name__`
       supplies the replacement function name.
     args: the call arguments.
     as_expression: if true, build the result with
       `templates.replace_as_expression`; otherwise use `templates.replace`.

   Returns:
     Whatever the selected `templates` function returns.
   """
   template = """
     ag__.func(args)
   """
   if as_expression:
     build = templates.replace_as_expression
   else:
     build = templates.replace
   return build(template, func=py_builtins.overload_of(f).__name__, args=args)
예제 #22
0
  def canonicalize_listcomp(self, result_node, list_comp_node):
    """Lowers a list comprehension into explicit list construction and loops.

    Args:
      result_node: the target that receives the list.
      list_comp_node: the `gast.ListComp` to lower.

    Returns:
      A statement list: `result = <new list>` followed by nested `for`/`if`
      statements whose innermost statement updates the list with the
      comprehension's element.
    """
    init_stmts = templates.replace(
        'list_ = create_list',
        list_=result_node,
        create_list=self.instantiate_list_node())
    nested = self.make_update_list_node(result_node, list_comp_node.elt)

    # Wrap from the innermost generator/condition outwards.
    for comp in reversed(list_comp_node.generators):
      for cond in reversed(comp.ifs):
        nested = templates.replace(
            'if test: loop_body', test=cond, loop_body=nested)
      nested = templates.replace(
          'for target in iter_: loop_body',
          iter_=comp.iter,
          target=comp.target,
          loop_body=nested)

    return init_stmts + nested
예제 #23
0
  def test_replace_tuple(self):
    """A name placeholder can be replaced by a tuple of names."""
    template = """
      def test_fn(a, c):
        return b,
    """

    node = templates.replace(template, b=('a', 'c'))[0]
    result, _ = compiler.ast_to_object(node)

    # `assertEquals` is a deprecated alias that was removed in Python 3.12.
    self.assertEqual((2, 3), result.test_fn(2, 3))
예제 #24
0
  def test_replace_name_with_dict(self):
    """A name placeholder can be replaced by a dict literal expression."""
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # `assertEquals` is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(3, result.test_fn())
예제 #25
0
 def _replace_append_call(self, node):
   """Rewrites `target.append(element)` as
   `target = ag__.list_append(target, element)`.

   Args:
     node: a `gast.Call` whose func is an attribute and which has exactly one
       positional argument.

   Returns:
     The statement list produced from the template.
   """
   assert len(node.args) == 1
   assert isinstance(node.func, gast.Attribute)
   list_node = node.func.value
   appended = node.args[0]
   template = """
     target = ag__.list_append(target, element)
   """
   return templates.replace(template, target=list_node, element=appended)
 def _create_branch(self, expr, name_stem):
   """Wraps `expr` in a new zero-argument function that returns `(expr,)`.

   The function definition is appended to `self.state[_FunctionDefs].nodes`.
   Its name is a fresh symbol based on `name_stem` that avoids the names
   referenced in the current statement's scope.

   Returns:
     The generated function name.
   """
   referenced = self.state[_Statement].scope.referenced
   fn_name = self.ctx.namer.new_symbol(name_stem, referenced)
   template = """
     def name():
       return expr,
   """
   fn_def = templates.replace(template, name=fn_name, expr=expr)
   self.state[_FunctionDefs].nodes.append(fn_def)
   return fn_name
예제 #27
0
 def _wrap_to_py_func_no_return(self, node):
   """Wraps a call in `ag__.utils.wrap_py_func(func, None, (args,), kwargs, True)`.

   Positional arguments are passed as a tuple. Keyword arguments are
   converted by `ast_util.keywords_to_dict`.

   Returns:
     The statement list produced from the template.
   """
   # TODO(mdan): Properly handle varargs, etc.
   template = """
     ag__.utils.wrap_py_func(func, None, (args,), kwargs, True)
   """
   kwargs_node = ast_util.keywords_to_dict(node.keywords)
   return templates.replace(
       template, func=node.func, args=node.args, kwargs=kwargs_node)
예제 #28
0
  def visit_For(self, node):
    """Converts a `for` loop into a functional `__ops.for_loop` call.

    Symbols modified but not created in the loop body form the loop state.
    Inside the generated functions they are renamed to fresh symbols derived
    from their `ssf()` names. The body becomes `body_name(iterate,
    state_ssf)`. The optional `extra_cond` annotation (default: `True`)
    becomes `extra_cond_name(state_ssf)`.

    Args:
      node: the `gast.For` node being visited.

    Returns:
      The statement list produced from the template.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    all_referenced = body_scope.referenced
    loop_state = list(body_scope.modified - body_scope.created)

    loop_state_ssf = [
        self.context.namer.new_symbol(sym.ssf(), all_referenced)
        for sym in loop_state
    ]
    # Only symbols whose fresh name differs from the original need renaming.
    ssf_map = {}
    for original, renamed in zip(loop_state, loop_state_ssf):
      if str(original) != renamed:
        ssf_map[original] = renamed

    # A single state variable is used directly rather than as a tuple.
    if len(loop_state) == 1:
      state = loop_state[0]
      state_ssf = loop_state_ssf[0]
      state_ast_tuple = state
    else:
      state = loop_state
      state_ssf = loop_state_ssf
      state_ast_tuple = gast.Tuple([sym.ast() for sym in loop_state], None)

    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    if anno.hasanno(node, 'extra_cond'):
      extra_cond = ast_util.rename_symbols(
          anno.getanno(node, 'extra_cond'), ssf_map)
    else:
      extra_cond = parser.parse_expression('True')

    template = """
      def extra_cond_name(state_ssf):
        return extra_cond_expr
      def body_name(iterate, state_ssf):
        body
        return state_ssf,
      state_ast_tuple = __ops.for_loop(
          iterated, extra_cond_name, body_name, (state,))
    """
    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iterated=node.iter,
        iterate=node.target,
        extra_cond_name=self.context.namer.new_symbol('extra_cond',
                                                      all_referenced),
        extra_cond_expr=extra_cond,
        body_name=self.context.namer.new_symbol('loop_body', all_referenced),
        body=renamed_body)
예제 #29
0
  def test_replace_function_name(self):
    """The name in a function definition can be replaced."""
    template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

    node = templates.replace(template, fname='test_fn')[0]
    result, _ = compiler.ast_to_object(node)
    # `assertEquals` is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(7, result.test_fn(2))
예제 #30
0
 def visit_Print(self, node):
   """Converts a `Print` node into a `print(...)` function call.

   NOTE(review): `Print` is presumably the Python 2 print statement.

   Returns:
     The result of visiting the generated call.
   """
   self.generic_visit(node)
   printed = node.values
   # `print(a, b)` arrives as a single tuple value; unpack its elements.
   if len(printed) == 1 and isinstance(printed[0], gast.Tuple):
     printed = printed[0].elts
   template = """
     fname(args)
   """
   call = templates.replace(template, fname='print', args=printed)[0]
   return self.visit(call)
예제 #31
0
 def _ensure_node_is_trivial(self, node):
   """Returns a trivial stand-in for `node`, hoisting complex expressions.

   Handling by kind:
   - None and instances of `self._trivial_nodes` are returned as-is.
   - Lists are processed element-wise.
   - Keywords have their value made trivial in place.
   - Starred, withitem and slice nodes delegate to `_ensure_fields_trivial`.
   - Any other expression is assigned to a fresh temporary, via a pending
     statement, and a reference to that temporary is returned.

   Raises:
     ValueError: for any other kind of node.
   """
   if node is None or isinstance(node, self._trivial_nodes):
     return node
   if isinstance(node, list):
     # If something's field was actually a list, e.g., variadic arguments.
     return [self._ensure_node_is_trivial(item) for item in node]
   if isinstance(node, gast.keyword):
     node.value = self._ensure_node_is_trivial(node.value)
     return node
   if isinstance(node, (gast.Starred, gast.withitem, gast.slice)):
     return self._ensure_fields_trivial(node)
   if not isinstance(node, gast.expr):
     raise ValueError('Do not know how to treat {}'.format(node))

   temp_name = self._gensym.new_name()
   hoisted = templates.replace(
       'temp_name = expr', temp_name=temp_name, expr=node)[0]
   self._add_pending_statement(hoisted)
   return templates.replace('temp_name', temp_name=temp_name)[0]
예제 #32
0
    def test_replace_complex_context(self):
        """Replacing an assignment target sets Store context only on the target.

        The attribute that becomes the target gets `gast.Store`. The names
        nested inside its call argument keep `gast.Load`.
        """
        template = """
      def test_fn(foo):
        foo = 0
    """

        replacement = parser.parse_expression('bar(([a, b],)).baz')
        node = templates.replace(template, foo=replacement)[0]

        assign_target = node.body[0].targets[0]
        self.assertIsInstance(assign_target.ctx, gast.Store)

        tuple_arg = assign_target.value.args[0]
        inner_list = tuple_arg.elts[0]
        self.assertIsInstance(inner_list.ctx, gast.Load)
        self.assertIsInstance(inner_list.elts[0].ctx, gast.Load)
        self.assertIsInstance(inner_list.elts[1].ctx, gast.Load)
예제 #33
0
    def _process_single_assignment(self, target, value):
        """Rewrites `x[key] = item` as `x = ag__.set_item(x, key, item)`.

        Only plain index subscripts (a `gast.Index` slice) are rewritten.

        Returns:
            The replacement statement list, or None when `target` is not an
            index subscript.
        """
        is_index_subscript = (isinstance(target, gast.Subscript)
                              and isinstance(target.slice, gast.Index))
        if not is_index_subscript:
            return None

        template = """
      target = ag__.set_item(target, key, item)
    """
        container = target.value
        key = target.slice.value
        return templates.replace(template,
                                 target=container,
                                 key=key,
                                 item=value)
예제 #34
0
 def _guard_if_present(self, block, var_name):
   """Wraps `block` in `if not var_name:` so it is skipped once var_name is set.

   Args:
     block: the statements to guard.
     var_name: the control variable checked by the guard.

   Returns:
     The (falsy) block unchanged if it is empty. Otherwise, the statement list
     produced from the template.
   """
   if not block:
     return block
   template = """
       if not var_name:
         block
     """
   guarded = templates.replace(template, var_name=var_name, block=block)
   return guarded
예제 #35
0
  def visit_If(self, node):
    """Lifts a return out of an `if` whose branches both end in `return`.

    Each branch's final `return value` becomes
    `<common_return_name> = value`, and `return <common_return_name>` is
    emitted after the `if`.

    Args:
      node: the `gast.If` node being visited.

    Returns:
      The visited node, or `[if_node, return_node]` when the return was
      lifted. In the latter case `self.changes_made` is set to True.
    """
    # Depth-first traversal of if statements.
    node = self.generic_visit(node)

    # Lift the return only if both branches return. The transform does not
    # require that the branches either both return or both do not, because
    # FoldElse may move a return into a branch after this transform
    # completes. FoldElse and LiftReturn are alternately run until the code
    # reaches a fixed point.
    body_returns = isinstance(node.body[-1], gast.Return)
    orelse_returns = bool(node.orelse) and isinstance(
        node.orelse[-1], gast.Return)
    if not (body_returns and orelse_returns):
      return node

    for branch in (node.body, node.orelse):
      branch[-1] = templates.replace(
          'a = b', a=self.common_return_name, b=branch[-1].value)[0]
    lifted_return = templates.replace(
        'return a', a=self.common_return_name)[0]
    self.changes_made = True
    return [node, lifted_return]
예제 #36
0
 def visit_FunctionDef(self, node):
     """Wraps a function's body in `with tf.name_scope(<scope name>):`.

     Nested functions are visited first. `self._function_level` tracks the
     nesting depth. For the outermost function, when
     `self.context.owner_type` is set (presumably when converting a method),
     the scope name is prefixed with the owner type's name.

     A leading docstring stays the first statement of the function, outside
     the `with` block. Wrapping it would turn it into an ordinary expression
     inside the `with`, and the converted function would lose its `__doc__`.
     A function whose body is only a docstring is left unwrapped.

     Args:
       node: the `gast.FunctionDef` node being visited.

     Returns:
       The same node, with its body rewritten.
     """
     self._function_level += 1
     try:
         self.generic_visit(node)
     finally:
         self._function_level -= 1

     scope_name = node.name
     if self._function_level == 0 and self.context.owner_type is not None:
         scope_name = '{}/{}'.format(self.context.owner_type.__name__,
                                     scope_name)

     docstring = []
     scoped_body = node.body
     if scoped_body:
         first = scoped_body[0]
         if isinstance(first, gast.Expr) and isinstance(first.value, gast.Str):
             docstring = scoped_body[:1]
             scoped_body = scoped_body[1:]
     if not scoped_body:
         # Nothing but (at most) a docstring: there is no code to scope, and
         # an empty `with` body would not be valid.
         return node

     node.body = docstring + templates.replace(
         'with tf.name_scope(scope_name): body',
         scope_name=gast.Str(scope_name),
         body=scoped_body)
     return node
예제 #37
0
    def test_replace_name_with_call(self):
        """A name placeholder can be replaced by a nested call expression."""
        template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

        source = parser.parse_expression('f()(b)')
        node = templates.replace(template, foo=source)[0]
        result, _ = compiler.ast_to_object(node)
        # `assertEquals` is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(15, result.test_fn())
예제 #38
0
 def _create_cond_branch(self, body_name, aliased_orig_names,
                         aliased_new_names, body, returns):
   """Builds a zero-argument function for one branch of a conditional.

   The function runs `body` and returns `(returns,)`. When
   `aliased_orig_names` is non-empty, the function first binds
   `aliased_new_names` to those original names.

   Returns:
     The statement list produced from the template.
   """
   if not aliased_orig_names:
     template = """
       def body_name():
         body
         return (returns,)
     """
     return templates.replace(
         template, body_name=body_name, body=body, returns=returns)

   template = """
       def body_name():
         aliased_new_names, = aliased_orig_names,
         body
         return (returns,)
     """
   return templates.replace(
       template,
       body_name=body_name,
       body=body,
       aliased_orig_names=aliased_orig_names,
       aliased_new_names=aliased_new_names,
       returns=returns)
    def _guard_if_present(self, block, var_name):
        """Wraps `block` in `if not var_name:`.

        An empty block is replaced by a single `pass`. The guard is therefore
        always emitted, which keeps the break variable in use in case it
        becomes useful in later transformation stages. Without this, the
        break_in_inner_loop test broke.

        Returns:
            The statement list produced from the template.
        """
        guarded_block = block or [gast.Pass()]
        template = """
        if not var_name:
          block
      """
        return templates.replace(template, var_name=var_name,
                                 block=guarded_block)
예제 #40
0
    def test_replace_code_block(self):
        """A statement placeholder can be replaced by a list of statements.

        Two copies of `a = a + 1` are spliced in, so test_fn(1) returns 3.
        """
        template = """
      def test_fn(a):
        block
        return a
    """

        node = templates.replace(
            template,
            block=[
                gast.Assign([gast.Name('a', None, None)],
                            gast.BinOp(gast.Name('a', None, None), gast.Add(),
                                       gast.Num(1))),
            ] * 2)[0]
        result, _ = compiler.ast_to_object(node)
        # `assertEquals` is a deprecated alias that was removed in Python 3.12.
        self.assertEqual(3, result.test_fn(1))
  def _visit_loop_body(self, node, nodes):
    """Visits a loop body inside a fresh local scope, tracking `continue`.

    A new control variable (based on 'continue_') is stored as the local
    CONTROL_VAR_NAME. If any statement used `continue`, the body is prefixed
    with `<control var> = tf.constant(False)`.

    Args:
      node: the loop node, which carries the BODY_SCOPE annotation.
      nodes: the loop body statements.

    Returns:
      The visited (and possibly prefixed) body statements.
    """
    self.enter_local_scope()
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    continue_var = self.ctx.namer.new_symbol('continue_', body_scope.referenced)
    self.set_local(CONTROL_VAR_NAME, continue_var)

    visited = self.visit_block(nodes, after_visit=self._postprocess_statement)

    if self.get_local(CONTINUE_USED, False):
      template = """
        var_name = tf.constant(False)
      """
      visited = templates.replace(template, var_name=continue_var) + visited

    self.exit_local_scope()
    return visited
예제 #42
0
    def visit_Expr(self, node):
        """Rewrites `x.append(e)` as `x = autograph_utils.dynamic_list_append(x, e)`.

        Only expression statements that are calls to a function annotated with
        a qualified name ending in 'append', with exactly one positional
        argument, are rewritten.

        Returns:
            The visited node, or the replacement statement list.
        """
        node = self.generic_visit(node)
        call_node = node.value
        if not isinstance(call_node, gast.Call):
            return node
        if not anno.hasanno(call_node.func, anno.Basic.QN):
            return node

        qn = anno.getanno(call_node.func, anno.Basic.QN)
        if qn.qn[-1] != 'append' or len(call_node.args) != 1:
            return node

        template = """
          target = autograph_utils.dynamic_list_append(target, element)
        """
        return templates.replace(template,
                                 target=qn.parent.ast(),
                                 element=call_node.args[0])
예제 #43
0
    def visit_For(self, node):
        """Replaces `break` in a `for` loop with a `break_requested` flag.

        A `[used, break_var]` entry is pushed on `self.break_uses` while the
        loop is visited. If a break was used, the loop is preceded by
        `<break var> = False` and annotated with `extra_cond` set to
        `not <break var>`.

        Returns:
            The visited node, or the statement list `[assign, for_loop]`.
        """
        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        break_var = self.context.namer.new_symbol('break_requested',
                                                  body_scope.referenced)

        self.break_uses.append([False, break_var])
        node = self.generic_visit(node)
        break_used = self.break_uses[-1][0]
        if break_used:
            template = """
        var_name = False
        original_for
      """
            node = templates.replace(template,
                                     var_name=break_var,
                                     original_for=node)
            loop_cond = templates.replace_as_expression('not var_name',
                                                        var_name=break_var)
            anno.setanno(node[1], 'extra_cond', loop_cond)
        self.break_uses.pop()

        return node
예제 #44
0
  def visit_FunctionDef(self, node):
    """Wraps a function's body in `with tf.name_scope(<current scope name>):`.

    A leading docstring (a string-literal expression statement) stays outside
    the `with` block.

    Returns:
      The visited node with its body rewritten.
    """
    node = self.generic_visit(node)

    body = node.body
    has_docstring = (
        bool(body) and isinstance(body[0], gast.Expr) and
        isinstance(body[0].value, gast.Str))
    # Skip any docstring.
    split_at = 1 if has_docstring else 0

    template = """
      with tf.name_scope(scope_name):
        body
    """
    scoped = templates.replace(
        template,
        scope_name=gast.Str(self._name_for_current_scope()),
        body=body[split_at:])
    node.body = body[:split_at] + scoped
    return node
    def _postprocess_statement(self, node):
        """Guards the statements that follow a `continue`.

        Returns a pair `(node, new_body)`. When a guard is created, `node` is
        the new `if not <control var>:` statement and `new_body` is its body
        list. Presumably the caller places subsequent statements there;
        confirm against `visit_block`. Otherwise `new_body` is None.
        """
        # Example of how the state machine below works:
        #
        #   1| stmt           # State: CONTINUE_USED = False
        #    |                # Action: none
        #   2| if cond:
        #   3|   continue     # State: CONTINUE_USED = True,
        #    |                #        GUARD_CREATED = False,
        #    |                #        CREATE_GUARD_NEXT = False
        #    |                # Action: set CREATE_GUARD_NEXT = True
        #   4| stmt           # State: CONTINUE_USED = True,
        #    |                #        GUARD_CREATED = False,
        #    |                #        CREATE_GUARD_NEXT = True
        #    |                # Action: create `if not continue_used`,
        #    |                #         set GUARD_CREATED = True
        #   5| stmt           # State: CONTINUE_USED = True, GUARD_CREATED = True
        #    |                # Action: none (will be wrapped under previously
        #    |                #         created if node)
        if not self.get_local(CONTINUE_USED, False):
            return node, None
        if self.get_local(GUARD_CREATED, False):
            return node, None
        if not self.get_local(CREATE_GUARD_NEXT, False):
            self.set_local(CREATE_GUARD_NEXT, True)
            return node, None

        self.set_local(GUARD_CREATED, True)
        template = """
          if not var_name:
            original_node
        """
        (guard,) = templates.replace(
            template,
            var_name=self.get_local(CONTROL_VAR_NAME),
            original_node=node)
        return guard, guard.body
예제 #46
0
    def visit_For(self, node):
        """Replaces `break` inside a `for` loop with a boolean control variable.

        If the body used `break`, the loop is rebuilt after
        `<break var> = tf.constant(False)`. Its else clause is passed through
        `self._guard_if_present`, and it is annotated with `extra_test` set to
        `not <break var>`.

        Returns:
            The visited node if no `break` was used, otherwise the statement
            list produced from the template.
        """
        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

        node.target = self.visit(node.target)
        node.iter = self.visit(node.iter)
        node.body, break_used = self._process_body(node.body, break_var)
        # A break in the else clause applies to the containing scope.
        node.orelse = self.visit_block(node.orelse)

        if not break_used:
            return node

        # Python's else clause only runs if the loop exited without a break.
        orelse_guarded = self._guard_if_present(node.orelse, break_var)
        loop_test = templates.replace_as_expression('not var_name',
                                                    var_name=break_var)

        # The extra test is hidden in an annotation, which would confuse the
        # static analysis. The no-op statement `(var_name,)` in the template
        # keeps the control variable visibly used.
        # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
        template = """
        var_name = tf.constant(False)
        for target in iter_:
          (var_name,)
          body
        else:
          orelse
      """
        statements = templates.replace(template,
                                       var_name=break_var,
                                       iter_=node.iter,
                                       target=node.target,
                                       body=node.body,
                                       orelse=orelse_guarded)

        anno.setanno(statements[1], 'extra_test', loop_test)
        return statements
예제 #47
0
    def visit_While(self, node):
        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        break_var = self.context.namer.new_symbol('break_requested',
                                                  scope.referenced)

        self.break_uses.append([False, break_var])
        node = self.generic_visit(node)
        if self.break_uses[-1][0]:
            template = """
        var_name = False
        while original_test and not var_name:
          original_body
        else:
          original_orelse
      """
            node = templates.replace(template,
                                     var_name=break_var,
                                     original_test=node.test,
                                     original_body=node.body,
                                     original_orelse=node.orelse)
        self.break_uses.pop()

        return node
예제 #48
0
    def _generate_pop_operation(self, original_call_node, pop_var_name):
        """Builds `target, pop_var_name = ag__.list_pop(target, element, opts=...)`.

        The element dtype and shape are read from the `set_element_type`
        directive on the list's definition, each defaulting to `None`.

        Args:
            original_call_node: a call of the form `target.pop(...)`. Its first
                argument, if any, is forwarded as `element`; otherwise `None`.
            pop_var_name: the name that receives the popped value.

        Returns:
            The statement list produced from the template.
        """
        assert isinstance(original_call_node.func, gast.Attribute)

        if original_call_node.args:
            pop_element = original_call_node.args[0]
        else:
            pop_element = parser.parse_expression('None')

        # For a call like "target.pop()", the dtype is attached to `target`,
        # i.e. to func.value.
        # TODO(mdan): For lists of lists, this won't work.
        # It is unclear how to annotate the list as a "list of lists with a
        # certain element type" when using operations like `l.pop().pop()`.
        list_node = original_call_node.func.value

        def element_directive(arg_name):
            return self.get_definition_directive(
                list_node,
                directives.set_element_type,
                arg_name,
                default=templates.replace_as_expression('None'))

        dtype = element_directive('dtype')
        shape = element_directive('shape')

        template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
        return templates.replace(template,
                                 target=list_node,
                                 pop_var_name=pop_var_name,
                                 element=pop_element,
                                 dtype=dtype,
                                 shape=shape)
예제 #49
0
  def visit_While(self, node):
    """Converts a `while` loop into a functional `ag__.while_loop` call.

    Symbols modified but not created in the loop body form the loop state.
    Inside the generated test/body functions they are renamed to fresh
    symbols derived from their `ssf()` names. Roots of the symbols referenced
    by the condition that are not created in the body are passed as extra
    dependencies.

    Args:
      node: the `gast.While` node being visited.

    Returns:
      The statement list produced from the template.

    Raises:
      ValueError: if the loop body modifies no pre-existing symbols.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    all_referenced = body_scope.referenced
    loop_state = list(body_scope.modified - body_scope.created)

    cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE)
    cond_closure = {
        root
        for sym in cond_scope.referenced
        for root in sym.support_set
        if root not in body_scope.created
    }

    if not loop_state:
      # TODO(mdan): Implement this properly.
      # To complete this statement, we need to check whether any variable
      # created inside the body scope is used before being modified outside
      # the scope. This should be done during activity analysis, and in
      # general should cover the case where variables may not be initialized.
      raise ValueError('cannot convert while loop: no outputs')

    loop_state_ssf = [
        self.context.namer.new_symbol(sym.ssf(), all_referenced)
        for sym in loop_state
    ]
    # Only symbols whose fresh name differs from the original need renaming.
    ssf_map = {}
    for original, renamed in zip(loop_state, loop_state_ssf):
      if str(original) != renamed:
        ssf_map[original] = renamed

    # A single state variable is used directly rather than as a tuple.
    if len(loop_state) == 1:
      state = loop_state[0]
      state_ssf = loop_state_ssf[0]
      state_ast_tuple = state
    else:
      state = loop_state
      state_ssf = loop_state_ssf
      state_ast_tuple = gast.Tuple([sym.ast() for sym in loop_state], None)

    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    renamed_test = ast_util.rename_symbols(node.test, ssf_map)

    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = ag__.while_loop(
          test_name, body_name, (state,), (extra_deps,))
    """
    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=self.context.namer.new_symbol('loop_test',
                                                body_scope.referenced),
        test=renamed_test,
        body_name=self.context.namer.new_symbol('loop_body',
                                                body_scope.referenced),
        body=renamed_body,
        extra_deps=tuple(sym.ast() for sym in cond_closure),
    )
예제 #50
0
 def make_update_list_node(self, list_, elt):
     """Builds the AST for `list_.append(elt)` (first node from the template)."""
     generated = templates.replace('list_.append(elt)', list_=list_, elt=elt)
     return generated[0]
 def generate_Print(self):
     """Builds the AST for `print(x)`, with x a freshly generated expression."""
     printed = self.generate_expression()
     return templates.replace('print(x)', x=printed)[0]
예제 #52
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

    The function will also recursively compile all the entities that `o`
    references, updating `dependency_cache`.

    This function is reentrant, and relies on dependency_cache to avoid
    generating duplicate code.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.

    Returns:
      A tuple (ast, new_name, namespace):
          * ast: A list of AST nodes representing an entity with interface
              equivalent to `o`, but which when executed creates a TF graph.
              An `<new_name>.autograph_info__ = {}` statement is appended to it.
          * new_name: The symbol name under which the new entity can be found.
          * namespace: A dict mapping all symbols visible to the converted entity,
              keyed by their symbol name.

    Raises:
      NotImplementedError: if `o` is a lambda, or an object that is not a
          class, function or method.
      ValueError: if the entity type is not supported.
    """
    if tf_inspect.isclass(o):
        node, name, ns = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o):
        # TODO(mdan): This is not a reliable mechanism.
        # The most reliable way is to check the source code, the AST will contain
        # a Lambda node instead of a FunctionDef
        if o.__name__ == '<lambda>':
            raise NotImplementedError(
                'lambda functions are not yet supported; declare the function'
                ' using def instead: %s' % o)
        node, name, ns = function_to_graph(o, program_ctx, arg_values,
                                           arg_types)
    elif tf_inspect.ismethod(o):
        node, name, ns = function_to_graph(o, program_ctx, arg_values,
                                           arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    # NOTE(review): every Python object defines `__class__`, so this branch
    # catches everything left over and the final `else` appears unreachable.
    elif hasattr(o, '__class__'):
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'contrib/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    # TODO(mdan): This is temporary. it should be created using a converter.
    # TODO(mdan): The attribute should be added with a helper, not directly.
    # The helper can ensure there are no collisions.
    template = '''
      entity.autograph_info__ = {}
  '''
    node.extend(templates.replace(template, entity=name))

    program_ctx.add_to_cache(o, node)

    if program_ctx.recursive:
        # Candidates passed over below are never converted by this loop, so
        # nothing here adds them to dependency_cache. Remember them locally;
        # otherwise the search would return the same candidate on every
        # iteration and the loop would never terminate.
        skipped = set()
        while True:
            candidate = next(
                (obj for obj in program_ctx.name_map.keys()
                 if obj not in program_ctx.dependency_cache and
                 obj not in skipped),
                None)
            if candidate is None:
                break
            if (hasattr(candidate, 'im_class') and
                    candidate.im_class not in program_ctx.partial_types):
                # Class members are converted with their objects, unless they're
                # only converted partially.
                skipped.add(candidate)
                continue
            entity_to_graph(candidate, program_ctx, {}, {})

    return node, name, ns
예제 #53
0
 def _create_break_check(self):
     """Returns the expression node `(not <break variable>)`.

     The variable name is the second item of the last entry in
     `self.break_uses`.
     """
     break_var = self.break_uses[-1][1]
     template = """
   (not var_name)
 """
     (stmt,) = templates.replace(template, var_name=break_var)
     # The template yields exactly one statement; return its wrapped expression.
     return stmt.value
예제 #54
0
 def _create_break_init(self):
     """Returns the statement node `<break variable> = False`.

     The variable name is the second item of the last entry in
     `self.break_uses`.
     """
     break_var = self.break_uses[-1][1]
     template = """
   var_name = False
 """
     (init_stmt,) = templates.replace(template, var_name=break_var)
     return init_stmt
예제 #55
0
    def visit_If(self, node):
        """Converts an `if` statement into branch functions plus a cond call.

        Each branch is emitted as a function definition (named from `if_true`
        and `if_false`) via `_create_cond_branch`. The statement itself is
        replaced by the expression built by `_create_cond_expr`.

        Args:
          node: The `if` node. It must carry the BODY_SCOPE, ORELSE_SCOPE,
              'live_out' and 'defined_in' annotations.

        Returns:
          The true-branch definition, the false-branch definition and the
          conditional expression, concatenated.

        Raises:
          ValueError: if a live symbol is created in only one branch, or if a
              symbol that must be returned is neither aliased, created in its
              branch, nor defined before the conditional.
        """
        self.generic_visit(node)

        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
        body_defs = body_scope.created | body_scope.modified
        orelse_defs = orelse_scope.created | orelse_scope.modified
        live = anno.getanno(node, 'live_out')

        # We'll need to check if we're closing over variables that are defined
        # elsewhere in the function
        # NOTE: we can only detect syntactic closure in the scope
        # of the code passed in. If the AutoGraph'd function itself closes
        # over other variables, this analysis won't take that into account.
        defined = anno.getanno(node, 'defined_in')

        # We only need to return variables that are
        # - modified by one or both branches
        # - live (or has a live parent) at the end of the conditional
        modified = []
        for def_ in body_defs | orelse_defs:
            def_with_parents = set((def_, )) | def_.support_set
            if live & def_with_parents:
                modified.append(def_)

        # We need to check if live created variables are balanced
        # in both branches
        created = live & (body_scope.created | orelse_scope.created)

        # The if statement is illegal if there are variables that are created,
        # that are also live, but both branches don't create them.
        if created:
            if created != (body_scope.created & live):
                raise ValueError(
                    'The main branch does not create all live symbols that the else '
                    'branch does.')
            if created != (orelse_scope.created & live):
                raise ValueError(
                    'The else branch does not create all live symbols that the main '
                    'branch does.')

        # Alias the closure variables inside the conditional functions
        # to avoid errors caused by the local variables created in the branch
        # functions.
        # We will alias variables independently for body and orelse scope,
        # because different branches might write different variables.
        aliased_body_orig_names = tuple(body_scope.modified -
                                        body_scope.created)
        aliased_orelse_orig_names = tuple(orelse_scope.modified -
                                          orelse_scope.created)
        aliased_body_new_names = tuple(
            self.context.namer.new_symbol(s.ssf(), body_scope.referenced)
            for s in aliased_body_orig_names)
        aliased_orelse_new_names = tuple(
            self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced)
            for s in aliased_orelse_orig_names)

        alias_body_map = dict(
            zip(aliased_body_orig_names, aliased_body_new_names))
        alias_orelse_map = dict(
            zip(aliased_orelse_orig_names, aliased_orelse_new_names))

        node_body = ast_util.rename_symbols(node.body, alias_body_map)
        node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

        if not modified:
            # When the cond would return no value, we leave the cond called without
            # results. That in turn should trigger the side effect guards. The
            # branch functions will return a dummy value that ensures cond
            # actually has some return value as well.
            results = None
        elif len(modified) == 1:
            results = modified[0]
        else:
            results = gast.Tuple([s.ast() for s in modified], None)

        body_name = self.context.namer.new_symbol('if_true',
                                                  body_scope.referenced)
        orelse_name = self.context.namer.new_symbol('if_false',
                                                    orelse_scope.referenced)
        if modified:

            def build_returns(aliased_names, alias_map, scope, branch_name):
                """Builds list of return variables for a branch of a conditional.

                `branch_name` ('true' or 'false') is only used in the error
                message, so that failures name the offending branch.
                """
                returns = []
                for s in modified:
                    if s in aliased_names:
                        returns.append(alias_map[s])
                    elif s not in scope.created | defined:
                        raise ValueError(
                            'Attempting to return variable "%s" from the %s branch of '
                            'a conditional, but it was not closed over, or created in '
                            'this branch.' % (str(s), branch_name))
                    else:
                        returns.append(s)
                return tuple(returns)

            body_returns = build_returns(aliased_body_orig_names,
                                         alias_body_map, body_scope, 'true')
            orelse_returns = build_returns(aliased_orelse_orig_names,
                                           alias_orelse_map, orelse_scope,
                                           'false')

        else:
            body_returns = orelse_returns = templates.replace(
                'tf.ones(())')[0].value

        body_def = self._create_cond_branch(
            body_name,
            aliased_orig_names=aliased_body_orig_names,
            aliased_new_names=aliased_body_new_names,
            body=node_body,
            returns=body_returns)
        orelse_def = self._create_cond_branch(
            orelse_name,
            aliased_orig_names=aliased_orelse_orig_names,
            aliased_new_names=aliased_orelse_new_names,
            body=node_orelse,
            returns=orelse_returns)
        cond_expr = self._create_cond_expr(results, node.test, body_name,
                                           orelse_name)

        return body_def + orelse_def + cond_expr
 def _convert_builtin(self, node):
     """Rewrites a call into `autograph_utils.dynamic_builtin(func, args)`.

     Args:
       node: The call node; its `func` and `args` are substituted into the
         template.

     Returns:
       The expression held by the first statement the template produces.
     """
     template = """
   autograph_utils.dynamic_builtin(func, args)
 """
     call_stmts = templates.replace(template, func=node.func, args=node.args)
     return call_stmts[0].value
 def _convert_print(self, node):
     """Rewrites a print call into `autograph_utils.dynamic_print(args)`.

     Args:
       node: The call node; its `args` are substituted into the template.

     Returns:
       The expression held by the first statement the template produces.
     """
     template = """
   autograph_utils.dynamic_print(args)
 """
     call_stmts = templates.replace(template, args=node.args)
     return call_stmts[0].value
예제 #58
0
  def visit_If(self, node):
    """Replaces an `if` statement with two branch functions and a cond call.

    Both branch functions share one alias map and return the same
    (possibly aliased) modified symbols.

    Args:
      node: The `if` node; it must carry BODY_SCOPE and ORELSE_SCOPE
        annotations.

    Returns:
      The true-branch definition, the false-branch definition and the
      conditional expression, concatenated.

    Raises:
      ValueError: if one branch creates symbols that the other does not.
    """
    self.generic_visit(node)

    true_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    false_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)

    # Both branches must create exactly the same set of new symbols.
    if true_scope.created - false_scope.created:
      raise ValueError(
          'The if branch creates new symbols that the else branch does not.')
    if false_scope.created - true_scope.created:
      raise ValueError(
          'The else branch creates new symbols that the if branch does not.')

    modified_set = true_scope.modified | false_scope.modified
    modified = tuple(modified_set)
    referenced = true_scope.referenced | false_scope.referenced

    # Closure variables written inside the branches get fresh local aliases,
    # so the branch functions do not accidentally turn them into locals.
    orig_names = tuple(
        modified_set - (true_scope.created | false_scope.created))
    new_names = tuple(
        self.context.namer.new_symbol(sym.ssf(), referenced)
        for sym in orig_names)
    renames = dict(zip(orig_names, new_names))
    renamed_body = ast_util.rename_symbols(node.body, renames)
    renamed_orelse = ast_util.rename_symbols(node.orelse, renames)

    if len(modified) > 1:
      results = gast.Tuple([sym.ast() for sym in modified], None)
    elif modified:
      (results,) = modified
    else:
      # Nothing escapes the conditional: cond is called without results,
      # which should trigger the side effect guards. Each branch returns a
      # dummy value so that cond still has an output.
      results = None

    true_fn_name = self.context.namer.new_symbol('if_true', referenced)
    false_fn_name = self.context.namer.new_symbol('if_false', referenced)

    if modified:
      branch_returns = tuple(renames.get(sym, sym) for sym in modified)
    else:
      branch_returns = templates.replace('tf.ones(())')[0].value

    true_def = self._create_cond_branch(
        true_fn_name,
        aliased_orig_names=orig_names,
        aliased_new_names=new_names,
        body=renamed_body,
        returns=branch_returns)
    false_def = self._create_cond_branch(
        false_fn_name,
        aliased_orig_names=orig_names,
        aliased_new_names=new_names,
        body=renamed_orelse,
        returns=branch_returns)
    cond_call = self._create_cond_expr(results, node.test, true_fn_name,
                                       false_fn_name)

    return true_def + false_def + cond_call
예제 #59
0
  def visit_Expr(self, node):
    """Wraps a bare call statement in a control-dependency `with` block.

    Statements such as `opt.minimize(loss)` or `tf.py_func(...)` become
    `with ag__.utils.control_dependency_on_returns(call):`. When the call
    uses simple (non-composite) names other than `self` and `tf`, the `with`
    body re-binds them through `ag__.utils.alias_tensors`. The resulting node
    is annotated with INDENT_BLOCK_REMAINDER = (its body, alias map).

    Args:
      node: An expression-statement node.

    Returns:
      The new `with` node when `node.value` is a call; otherwise `node`.
    """
    self.generic_visit(node)
    if not isinstance(node.value, gast.Call):
      return node

    # First, attempt to gate future evaluation of args. If that's not
    # possible, gate all remaining statements (and that may fail too, see
    # _visit_and_reindent).
    args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
    # NOTE: We can't guard object attributes because they may not be writable.
    # In addition, avoid renaming well-known names.
    # TODO(mdan): Move these names into config.
    unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
    guarded_args = tuple(
        arg for arg in args_scope.used
        if not arg.is_composite() and arg not in unguarded_names)

    # TODO(mdan): Include all arguments which depended on guarded_args too.
    # For example, the following will still cause a race:
    #   tf.assign(a, a + 1)
    #   b = a + 1
    #   tf.assign(a, a + 1)  # Control deps here should include `b`
    #   c = b + 1
    # Or maybe we should just raise an "unsafe assign" error?

    if guarded_args:
      # The aliases may need new names to avoid incorrectly making them local.
      # TODO(mdan): This is brutal. It will even rename modules - any fix?
      parent_scope = args_scope.parent
      to_alias = [arg for arg in guarded_args
                  if arg not in parent_scope.modified]
      alias_map = {
          arg: qual_names.QN(
              self.context.namer.new_symbol(arg.ssf(),
                                            parent_scope.referenced))
          for arg in to_alias
      }
      if len(guarded_args) == 1:
        (only_arg,) = guarded_args
        aliased_guarded_args = alias_map.get(only_arg, only_arg)
      else:
        aliased_guarded_args = gast.Tuple(
            [alias_map.get(arg, arg).ast() for arg in guarded_args], None)

      template = """
          with ag__.utils.control_dependency_on_returns(call):
            aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
        """
      guard = templates.replace(
          template,
          call=node.value,
          aliased_guarded_args=aliased_guarded_args,
          guarded_args=guarded_args)[-1]
    else:
      alias_map = {}

      template = """
          with ag__.utils.control_dependency_on_returns(call):
            pass
        """
      guard = templates.replace(template, call=node.value)[-1]
      # The `pass` only exists so the template parses; start with an empty body.
      guard.body = []

    anno.setanno(guard, anno.Basic.INDENT_BLOCK_REMAINDER,
                 (guard.body, alias_map))
    return guard