コード例 #1
0
    def visit_If(self, node):
        """Lifts a `return` that ends both branches of an `if` statement.

        Children are transformed first (depth-first). If the true branch and
        a non-empty `else` branch both end in a `gast.Return`, each of those
        returns is rewritten into an assignment to `self.common_return_name`
        and a single `return` of that name is emitted after the `if`.

        Args:
          node: the `gast.If` node being visited.

        Returns:
          The `If` node when nothing was lifted; otherwise the list
          `[if_node, return_node]`, in which case `self.changes_made` is set
          to True.
        """
        # Depth-first traversal of if statements
        node = self.generic_visit(node)

        # Only the case where BOTH branches end in a return is handled. A
        # mismatch is tolerated on purpose: FoldElse might move a return into
        # a branch after this transform completes, and FoldElse and
        # LiftReturn are alternately run until the code reaches a fixed point.
        true_return = node.body[-1]
        if not isinstance(true_return, gast.Return):
            return node
        if not node.orelse or not isinstance(node.orelse[-1], gast.Return):
            return node
        false_return = node.orelse[-1]

        result_name = self.common_return_name
        node.body[-1] = templates.replace(
            'a = b', a=result_name, b=true_return.value)[0]
        node.orelse[-1] = templates.replace(
            'a = b', a=result_name, b=false_return.value)[0]
        lifted_return = templates.replace('return a', a=result_name)[0]
        self.changes_made = True
        return [node, lifted_return]
コード例 #2
0
 def _create_cond_branch(self, body_name, aliased_orig_names,
                         aliased_new_names, body, returns):
     """Builds the function definition for one branch of a converted `if`.

     The generated function takes no arguments, runs `body` and returns the
     tuple `(returns,)`. When `aliased_orig_names` is non-empty, the function
     first rebinds `aliased_new_names` from `aliased_orig_names`.

     Args:
       body_name: name to give the generated function.
       aliased_orig_names: names to copy from; may be empty.
       aliased_new_names: names to copy into; only substituted when
         `aliased_orig_names` is non-empty.
       body: statements forming the branch body.
       returns: value(s) returned by the generated function.

     Returns:
       The result of `templates.replace` for the selected template.
     """
     if not aliased_orig_names:
         plain_template = """
     def body_name():
       body
       return (returns,)
   """
         return templates.replace(plain_template,
                                  body_name=body_name,
                                  body=body,
                                  returns=returns)

     aliasing_template = """
     def body_name():
       aliased_new_names, = aliased_orig_names,
       body
       return (returns,)
   """
     return templates.replace(aliasing_template,
                              body_name=body_name,
                              body=body,
                              aliased_orig_names=aliased_orig_names,
                              aliased_new_names=aliased_new_names,
                              returns=returns)
コード例 #3
0
 def replace_as_expression_restrictions(self):
   """Checks template inputs that are expected to raise ValueError.

   NOTE(review): the name has no `test` prefix, so the default unittest
   loader will not collect or run this method. Confirm whether that is
   intentional before renaming it.
   """
   multi_statement_template = """
     foo(a)
     bar(b)
   """
   # A two-statement template is expected to be rejected as an expression.
   with self.assertRaises(ValueError):
     templates.replace_as_expression(multi_statement_template)
   # Each source below is checked in its own assertRaises block.
   for rejected_source in ('', 'a = b'):
     with self.assertRaises(ValueError):
       templates.replace(rejected_source)
コード例 #4
0
 def replace_as_expression_restrictions(self):
     """Checks template inputs that are expected to raise ValueError.

     NOTE(review): the name has no `test` prefix, so the default unittest
     loader will not collect or run this method. Confirm whether that is
     intentional before renaming it.
     """
     multi_statement_template = """
   foo(a)
   bar(b)
 """
     # A two-statement template is expected to be rejected as an expression.
     with self.assertRaises(ValueError):
         templates.replace_as_expression(multi_statement_template)
     # Each source below is checked in its own assertRaises block.
     for rejected_source in ('', 'a = b'):
         with self.assertRaises(ValueError):
             templates.replace(rejected_source)
コード例 #5
0
  def test_replace_attribute(self):
    """Checks that an attribute name in a template can be replaced.

    `foo` in `a.foo` is replaced by the string 'b', so the compiled function
    must read `a.b`. Replacing the attribute with a non-string value (an int)
    is expected to raise ValueError.
    """
    import types  # `imp` is deprecated and was removed in Python 3.12.

    template = """
      def test_fn(a):
        return a.foo
    """

    node = templates.replace(template, foo='b')[0]
    result, _ = compiler.ast_to_object(node)
    # types.ModuleType is the replacement for the deprecated imp.new_module.
    mod = types.ModuleType('test')
    mod.b = 3
    # assertEqual: the assertEquals alias is deprecated (removed in 3.12).
    self.assertEqual(3, result.test_fn(mod))

    with self.assertRaises(ValueError):
      templates.replace(template, foo=1)
コード例 #6
0
    def test_replace_attribute(self):
        """Checks that an attribute name in a template can be replaced.

        `foo` in `a.foo` is replaced by the string 'b', so the compiled
        function must read `a.b`. Replacing the attribute with a non-string
        value (an int) is expected to raise ValueError.
        """
        import types  # `imp` is deprecated and was removed in Python 3.12.

        template = """
      def test_fn(a):
        return a.foo
    """

        node = templates.replace(template, foo='b')[0]
        result, _ = compiler.ast_to_object(node)
        # types.ModuleType is the replacement for the deprecated
        # imp.new_module.
        mod = types.ModuleType('test')
        mod.b = 3
        # assertEqual: the assertEquals alias is deprecated (removed in 3.12).
        self.assertEqual(3, result.test_fn(mod))

        with self.assertRaises(ValueError):
            templates.replace(template, foo=1)
コード例 #7
0
  def visit_For(self, node):
    """Lowers a `for` loop into an index-based `while` loop.

    The result is `i = 0; n = len(loop_iter); while i < n: ...`, with the
    loop target assigned from `loop_iter[i]` on each iteration. If the node
    carries an 'extra_cond' annotation, that expression is and-ed into the
    `while` test. `i` and `n` become fresh names generated against the names
    referenced in the loop body scope.
    """
    self.generic_visit(node)
    body_scope = anno.getanno(node, 'body_scope')

    # TODO(mdan): Distinguish between `for i in n` and `for i in range(n)`
    # Or maybe we should replace range with tf.range?

    with_extra_cond = anno.hasanno(node, 'extra_cond')

    if with_extra_cond:

      def template(loop_iter, target, body, i, n, extra_cond):  # pylint:disable=unused-argument
        i = 0
        n = len(loop_iter)  # pylint:disable=undefined-variable
        while i < n and extra_cond:
          # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
          target = loop_iter[i]
          body  # pylint:disable=pointless-statement
          i += 1

    else:

      def template(loop_iter, target, body, i, n):  # pylint:disable=unused-argument
        i = 0
        n = len(loop_iter)  # pylint:disable=undefined-variable
        while i < n:
          # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
          target = loop_iter[i]
          body  # pylint:disable=pointless-statement
          i += 1

    counter_name = self.namer.new_symbol('i', body_scope.referenced)
    length_name = self.namer.new_symbol('n', body_scope.referenced)
    replacements = {
        'loop_iter': node.iter,
        'target': node.target,
        'body': node.body,
        'i': gast.Name(counter_name, None, None),
        'n': gast.Name(length_name, None, None),
    }
    if with_extra_cond:
      replacements['extra_cond'] = anno.getanno(node, 'extra_cond')
    return templates.replace(template, **replacements)
コード例 #8
0
 def _insert_dynamic_conversion(self, node):
   """Inlines a dynamic conversion for a dynamic function.

   The call `node` is rebuilt as
   `py2tf_api.converted_call(func, True, False, {}, original_args)` with the
   original callee and positional arguments substituted in. The original
   keyword arguments are then copied onto the new call node.

   Args:
     node: the call node to convert.

   Returns:
     The new call expression node.
   """
   # TODO(mdan): Pass information on the statically compiled functions.
   # Having access to the statically compiled functions can help avoid
   # unnecessary compilation.
   # For example, this would lead to function `a` being compiled twice:
   #
   #   def a():
   #     v = b
   #     b()
   #   def b():
   #     a()
   #
   # This is really a problem with recursive calls, which currently can
   # only be gated by a static condition, and should be rare.
   # TODO(mdan): It probably makes sense to use dynamic conversion every time.
   # Before we could convert all the time though, we'd need a reasonable
   # caching mechanism.
   template = """
     py2tf_api.converted_call(func, True, False, {}, original_args)
   """
   generated = templates.replace(
       template, func=node.func, original_args=node.args)
   converted_call = generated[0].value
   # TODO(mdan): Improve the template mechanism to better support this.
   # The template only substitutes positional arguments, so the keywords are
   # attached to the generated call node directly.
   converted_call.keywords = node.keywords
   return converted_call
コード例 #9
0
 def _gate_symbols(self, guard_statement, guarded_args):
   """Appends `tf.identity` re-bindings of `guarded_args` to a guard block.

   Args:
     guard_statement: a statement with a `body` list; extended in place.
     guarded_args: iterable of symbols to rebind through `tf.identity`.

   Returns:
     `guard_statement`, after its body has been extended.
   """
   template = """
     (args,) = (tf.identity(a) for a in (args,))
   """
   guard_statement.body.extend(
       templates.replace(template, args=tuple(guarded_args)))
   return guard_statement
コード例 #10
0
  def _convert_len(self, node):
    """Converts a call (a `len(...)` call, per the name) into `tf.shape(args)[0]`.

    Args:
      node: the call node; its `args` are substituted into the template.

    Returns:
      The replacement expression node.
    """

    def template(args):
      tf.shape(args)[0]  # pylint:disable=undefined-variable,expression-not-assigned

    return templates.replace(template, args=node.args)[0].value
コード例 #11
0
 def _create_break_trigger(self):
     """Builds the statements that set the current break flag and continue.

     The flag name is taken from the last entry of `self.break_uses`.

     Returns:
       A list of statements: `<flag> = True` followed by `gast.Continue()`.
     """
     template = """
   var_name = True
 """
     flag_name = self.break_uses[-1][1]
     statements = templates.replace(template, var_name=flag_name)
     statements.append(gast.Continue())
     return statements
コード例 #12
0
ファイル: call_trees.py プロジェクト: japrogramer/tensorflow
  def _wrap_to_py_func_no_return(self, node):
    """Wraps a call into a `tf.py_func` call over a generated wrapper.

    A wrapper function is generated that forwards its parameters to the
    original callee and returns 1. It is invoked via
    `tf.py_func(wrapper, [original args], [tf.int64])`. The wrapper and its
    parameter names are fresh symbols derived from the qualified names of
    the callee and arguments. Arguments without a qualified name use 'arg'
    as the base name.

    Returns:
      A `(wrapper_def, call_expr)` pair; `wrapper_def` is annotated with
      `anno.Basic.SKIP_PROCESSING`.
    """
    namer = self.context.namer
    func_qn = anno.getanno(node.func, anno.Basic.QN)
    args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
    taken_names = args_scope.referenced
    wrapper_name = namer.new_symbol(func_qn.ssf(), taken_names)

    def qn_for(arg):
      if anno.hasanno(arg, anno.Basic.QN):
        return anno.getanno(arg, anno.Basic.QN)
      return qual_names.QN('arg')

    wrapper_args = [
        namer.new_symbol(qn_for(arg).ssf(), taken_names) for arg in node.args
    ]
    # TODO(mdan): Properly handle varargs, kwargs, etc.
    # TODO(mdan): This is best handled as a dynamic dispatch.
    # That way we can separate tensors from non-tensor args.
    template = """
      def wrapper(wrapper_args):
        call(wrapper_args)
        return 1
      tf.py_func(wrapper, original_args, [tf.int64])
    """
    wrapper_def, call_expr = templates.replace(
        template,
        call=node.func,
        wrapper=wrapper_name,
        original_args=gast.List(elts=node.args, ctx=None),
        wrapper_args=wrapper_args)
    # Presumably keeps later passes from converting the wrapper again.
    anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
    return wrapper_def, call_expr
コード例 #13
0
 def _create_continuation_init(self):
   """Builds `<flag> = False` for the most recent continuation flag.

   The flag name is taken from the last entry of `self.continuation_uses`.

   Returns:
     The single assignment statement produced by the template.
   """
   template = """
     var_name = False
   """
   flag_name = self.continuation_uses[-1][1]
   (assign,) = templates.replace(template, var_name=flag_name)
   return assign
コード例 #14
0
    def test_replace_call_keyword(self):
        """Checks that a keyword placeholder in a call can be replaced.

        `kws=None` in `f(1, kws=None)` is replaced by the keywords parsed from
        `f(d=3, f=5)`, so `test_fn` must return 1 + 3 + 5 == 9. An empty list
        and a non-keyword value must each raise ValueError.
        """
        template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

        source = parser.parse_expression('f(d=3, f=5)')
        node = templates.replace(template, kws=source.keywords)[0]
        result, _ = compiler.ast_to_object(node)
        # assertEqual: the assertEquals alias is deprecated (removed in 3.12).
        self.assertEqual(9, result.test_fn())

        # Each invalid replacement needs its own assertRaises block: inside a
        # single block, the first call that raises ends the block, so any
        # following call would never run.
        with self.assertRaises(ValueError):
            templates.replace(template, kws=[])
        with self.assertRaises(ValueError):
            templates.replace(template, kws=1)
コード例 #15
0
    def visit_Expr(self, node):
        """Adds control dependencies on the result of a bare call statement.

        For an expression statement that is a single call (for example
        `opt.minimize(loss)` or `tf.py_func(...)`), a
        `with py2tf_utils.control_dependency_on_returns(tf, call):` block is
        generated.
        - If symbols used by the call's arguments are modified or returned in
          the parent scope, they are re-bound inside that block through
          `self._gate_symbols`.
        - Otherwise the empty block is stored in `self.next_indent_owner`
          and `self.indent_next` is set, so it can be filled in later.

        Returns:
          The node itself for non-call expressions, otherwise a tuple of
          statements.
        """
        self.generic_visit(node)
        if not isinstance(node.value, gast.Call):
            return node

        template = """
        with py2tf_utils.control_dependency_on_returns(tf, call):
          # TODO(mdan): Also insert ops to re-fetch if variables are involved?
          pass  # Will be removed below.
      """
        # TODO(mdan): This is brittle. Reorganize the mechanism.
        statements = templates.replace(template, call=node.value)
        leading, control_deps_guard = statements[:-1], statements[-1]
        control_deps_guard.body = []

        # First, attempt to gate future evaluation of args. If that's not
        # possible, gate all remaining statements (and that may fail too, see
        # _visit_and_reindent).
        args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
        parent_scope = args_scope.parent
        guarded_args = tuple(
            args_scope.used & (parent_scope.modified | parent_scope.returned))
        if not guarded_args:
            # The mechanism will insert the guard statement later.
            self.indent_next = True
            self.next_indent_owner = control_deps_guard
            return tuple(leading)
        return tuple(leading) + (
            self._gate_symbols(control_deps_guard, guarded_args),)
コード例 #16
0
ファイル: control_flow.py プロジェクト: ClowJ/tensorflow
  def visit_While(self, node):
    """Converts a `while` loop into a `tf.while_loop` call.

    The loop state is the set of symbols modified, but not created, in the
    loop body. The test and the body become generated functions over that
    state, and the state is re-assigned from the result of `tf.while_loop`.

    Returns:
      The statements produced by `templates.replace`.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, 'body_scope')
    loop_vars = tuple(body_scope.modified - body_scope.created)

    if len(loop_vars) != 1:
      state = loop_vars
      state_ast_tuple = gast.Tuple(
          tuple(gast.Name(name, None, None) for name in state), None)
    else:
      (state,) = loop_vars
      state_ast_tuple = state

    template = """
      def test_name(state):
        return test
      def body_name(state):
        body
        return state,
      state_ast_tuple = tf.while_loop(test_name, body_name, [state])
    """
    test_fn_name = self.namer.new_symbol('loop_test', body_scope.referenced)
    body_fn_name = self.namer.new_symbol('loop_body', body_scope.referenced)
    return templates.replace(
        template,
        state=state,
        state_ast_tuple=state_ast_tuple,
        test_name=test_fn_name,
        test=node.test,
        body_name=body_fn_name,
        body=node.body)
コード例 #17
0
    def _wrap_to_py_func_no_return(self, node):
        """Wraps a call into a `tf.py_func` call over a generated wrapper.

        The wrapper forwards its arguments to `node.func` and returns 1. It is
        invoked as `tf.py_func(wrapper, [args], [tf.int64])`, where the args
        are the names listed in the node's 'args_scope' annotation.

        Returns:
          A `(wrapper_def, call_expr)` pair. `call_expr.value` receives the
          'args_scope' annotation; `wrapper_def` is annotated
          'skip_processing'.
        """
        args_scope = anno.getanno(node, 'args_scope')
        # TODO(mdan): Properly handle varargs, kwargs, etc.
        arg_nodes = tuple(
            gast.Name(name, gast.Load(), None) for name in args_scope.used)

        # pylint:disable=undefined-variable,unused-argument,function-redefined

        def template(call, wrapper, args):
            def wrapper(args):
                call(args)
                return 1

            tf.py_func(wrapper, [args], [tf.int64])

        # pylint:enable=undefined-variable,unused-argument,function-redefined

        wrapper_ref = gast.Name(
            self.namer.compiled_function_name(node.func.id), gast.Load(), None)
        wrapper_def, call_expr = templates.replace(
            template, call=node.func, wrapper=wrapper_ref, args=arg_nodes)
        anno.setanno(call_expr.value, 'args_scope', args_scope)
        anno.setanno(wrapper_def, 'skip_processing', True)
        return wrapper_def, call_expr
コード例 #18
0
 def _create_continuation_trigger(self):
   """Builds `<flag> = True` for the most recent continuation flag.

   The flag name is taken from the last entry of `self.continuation_uses`.

   Returns:
     The single assignment statement produced by the template.
   """
   template = """
     var_name = True
   """
   flag_name = self.continuation_uses[-1][1]
   (assign,) = templates.replace(template, var_name=flag_name)
   return assign
コード例 #19
0
  def visit_Assert(self, node):
    """Converts an `assert` statement into a `tf.Assert` call.

    The lone tf.Assert call will be wrapped with control_dependencies by
    side_effect_guards. A missing message is replaced by 'Assertion error'.

    Raises:
      NotImplementedError: if the message is present but is not a string
        literal.
    """
    self.generic_visit(node)

    template = """
      tf.Assert(test, [msg])
    """

    message = node.msg
    if message is None:
      message = gast.Str('Assertion error')
    elif not isinstance(message, gast.Str):
      raise NotImplementedError('Can only convert string messages for now.')
    return templates.replace(template, test=node.test, msg=message)
コード例 #20
0
 def _gate_symbols(self, guard_statement, guarded_args):
     """Appends `tf.identity` re-bindings of `guarded_args` to a guard block.

     Args:
       guard_statement: a statement with a `body` list; extended in place.
       guarded_args: iterable of symbols to rebind through `tf.identity`.

     Returns:
       `guard_statement`, after its body has been extended.
     """
     template = """
   (args,) = (tf.identity(a) for a in (args,))
 """
     guard_statement.body.extend(
         templates.replace(template, args=tuple(guarded_args)))
     return guard_statement
コード例 #21
0
  def test_replace_call_keyword(self):
    """Checks that a keyword placeholder in a call can be replaced.

    `kws=None` in `f(1, kws=None)` is replaced by the keywords parsed from
    `f(d=3, f=5)`, so `test_fn` must return 1 + 3 + 5 == 9. An empty list and
    a non-keyword value must each raise ValueError.
    """
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEqual: the assertEquals alias is deprecated (removed in 3.12).
    self.assertEqual(9, result.test_fn())

    # Each invalid replacement needs its own assertRaises block: inside a
    # single block, the first call that raises ends the block, so any
    # following call would never run.
    with self.assertRaises(ValueError):
      templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
      templates.replace(template, kws=1)
コード例 #22
0
ファイル: call_trees.py プロジェクト: budye/tensorflow-1
    def _wrap_to_py_func_no_return(self, node):
        """Wraps a call into a `tf.py_func` call over a generated wrapper.

        A wrapper function is generated that forwards its parameters to the
        original callee and returns 1. It is invoked via
        `tf.py_func(wrapper, [original args], [tf.int64])`. The wrapper and
        its parameter names are fresh symbols derived from the qualified names
        of the callee and arguments. Arguments without a qualified name use
        'arg' as the base name.

        Returns:
          A `(wrapper_def, call_expr)` pair; `wrapper_def` is annotated with
          `anno.Basic.SKIP_PROCESSING`.
        """
        namer = self.context.namer
        func_qn = anno.getanno(node.func, anno.Basic.QN)
        args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
        taken_names = args_scope.referenced
        wrapper_name = namer.new_symbol(func_qn.ssf(), taken_names)

        def qn_for(arg):
            if anno.hasanno(arg, anno.Basic.QN):
                return anno.getanno(arg, anno.Basic.QN)
            return qual_names.QN('arg')

        wrapper_args = [
            namer.new_symbol(qn_for(arg).ssf(), taken_names)
            for arg in node.args
        ]
        # TODO(mdan): Properly handle varargs, kwargs, etc.
        # TODO(mdan): This is best handled as a dynamic dispatch.
        # That way we can separate tensors from non-tensor args.
        template = """
      def wrapper(wrapper_args):
        call(wrapper_args)
        return 1
      tf.py_func(wrapper, original_args, [tf.int64])
    """
        wrapper_def, call_expr = templates.replace(
            template,
            call=node.func,
            wrapper=wrapper_name,
            original_args=gast.List(elts=node.args, ctx=None),
            wrapper_args=wrapper_args)
        # Presumably keeps later passes from converting the wrapper again.
        anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True)
        return wrapper_def, call_expr
コード例 #23
0
  def visit_For(self, node):
    """Lowers a `for` loop into an index-based `while` loop.

    The result is `i = 0; n = len(loop_iter); while i < n: ...`, with the
    loop target assigned from `loop_iter[i]` on each iteration. `i` and `n`
    become fresh names generated against the names referenced in the loop
    body scope.
    """
    self.generic_visit(node)
    body_scope = anno.getanno(node, 'body_scope')

    # TODO(mdan): Distinguish between `for i in n` and `for i in range(n)`
    # Or maybe we should replace range with tf.range?

    def template(loop_iter, target, body, i, n):  # pylint:disable=unused-argument
      i = 0
      n = len(loop_iter)  # pylint:disable=undefined-variable
      while i < n:
        # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
        target = loop_iter[i]
        body  # pylint:disable=pointless-statement
        i += 1

    taken_names = body_scope.referenced
    counter = gast.Name(self.namer.new_symbol('i', taken_names), None, None)
    length = gast.Name(self.namer.new_symbol('n', taken_names), None, None)
    return templates.replace(
        template,
        loop_iter=node.iter,
        target=node.target,
        body=node.body,
        i=counter,
        n=length)
コード例 #24
0
 def _create_continuation_trigger(self):
     """Builds `<flag> = True` for the most recent continuation flag.

     The flag name is taken from the last entry of `self.continuation_uses`.

     Returns:
       The single assignment statement produced by the template.
     """
     template = """
   var_name = True
 """
     flag_name = self.continuation_uses[-1][1]
     (assign,) = templates.replace(template, var_name=flag_name)
     return assign
コード例 #25
0
ファイル: asserts.py プロジェクト: ChengYuXiang/tensorflow
  def visit_Assert(self, node):
    """Converts an `assert` statement into a `tf.Assert` call.

    The message is passed to `tf.Assert` wrapped in `tf.constant`. The lone
    tf.Assert call will be wrapped with control_dependencies by
    side_effect_guards. A missing message is replaced by 'Assertion error'.

    Raises:
      NotImplementedError: if the message is present but is not a string
        literal.
    """
    self.generic_visit(node)

    template = """
      tf.Assert(test, [tf.constant(msg)])
    """

    message = node.msg
    if message is None:
      message = gast.Str('Assertion error')
    elif not isinstance(message, gast.Str):
      raise NotImplementedError('Can only convert string messages for now.')
    return templates.replace(template, test=node.test, msg=message)
コード例 #26
0
 def _create_continuation_init(self):
     """Builds `<flag> = False` for the most recent continuation flag.

     The flag name is taken from the last entry of `self.continuation_uses`.

     Returns:
       The single assignment statement produced by the template.
     """
     template = """
   var_name = False
 """
     flag_name = self.continuation_uses[-1][1]
     (assign,) = templates.replace(template, var_name=flag_name)
     return assign
コード例 #27
0
ファイル: call_trees.py プロジェクト: Youed/tensorflow-1
    def _rename_compilable_function(self, node):
        """Points a call at the compiled version of its target, when needed.

        The callee must carry 'live_val' and 'fqn' annotations. Calls that
        `self._should_compile` rejects are returned untouched.
        - Constructor calls (nodes annotated 'is_constructor') are always
          renamed to the compiled class name.
        - Other callees are renamed only when the namer reports that a
          rename is needed.
        When the renamed target is a method, `node.func.value` is prepended
        to the positional arguments.

        Returns:
          The call node, possibly modified in place.
        """
        assert anno.hasanno(node.func, 'live_val')
        assert anno.hasanno(node.func, 'fqn')
        target_entity = anno.getanno(node.func, 'live_val')
        target_fqn = anno.getanno(node.func, 'fqn')

        if not self._should_compile(node, target_fqn):
            return node

        namer = self.context.namer
        if anno.hasanno(node, 'is_constructor'):
            new_name = namer.compiled_class_name(
                target_fqn, live_entity=target_entity)
            do_rename = True
        else:
            if not anno.hasanno(node.func, 'parent_type'):
                # Fallback - not reliable.
                owner_type = inspect_utils.getmethodclass(target_entity)
            else:
                owner_type = anno.getanno(node.func, 'parent_type')
            new_name, do_rename = namer.compiled_function_name(
                target_fqn, live_entity=target_entity, owner_type=owner_type)

        if not do_rename:
            return node

        if target_entity is not None and tf_inspect.ismethod(target_entity):
            # The renaming process will transform it into a regular function.
            # TODO(mdan): Is this complete? How does it work with nested members?
            node.args = [node.func.value] + node.args
        node.func = templates.replace('func_name', func_name=new_name)[0]
        return node
コード例 #28
0
  def visit_Expr(self, node):
    """Adds control dependencies on the result of a bare call statement.

    For an expression statement that is a single call (for example
    `opt.minimize(loss)` or `tf.py_func(...)`), a
    `with py2tf_utils.control_dependency_on_returns(tf, call):` block is
    generated.
    - If symbols used by the call's arguments are modified or returned in the
      parent scope (per the 'args_scope' annotation), they are re-bound inside
      that block through `self._gate_symbols`.
    - Otherwise the empty block is stored in `self.next_indent_owner` and
      `self.indent_next` is set, so it can be filled in later.

    Returns:
      The node itself for non-call expressions, otherwise a tuple of
      statements.
    """
    self.generic_visit(node)
    if not isinstance(node.value, gast.Call):
      return node

    template = """
        with py2tf_utils.control_dependency_on_returns(tf, call):
          # TODO(mdan): Also insert ops to re-fetch if variables are involved?
          pass  # Will be removed below.
      """
    # TODO(mdan): This is brittle. Reorganize the mechanism.
    statements = templates.replace(template, call=node.value)
    leading, control_deps_guard = statements[:-1], statements[-1]
    control_deps_guard.body = []

    # First, attempt to gate future evaluation of args. If that's not
    # possible, gate all remaining statements (and that may fail too, see
    # _visit_and_reindent).
    args_scope = anno.getanno(node.value, 'args_scope')
    parent_scope = args_scope.parent
    guarded_args = tuple(
        args_scope.used & (parent_scope.modified | parent_scope.returned))
    if not guarded_args:
      # The mechanism will insert the guard statement later.
      self.indent_next = True
      self.next_indent_owner = control_deps_guard
      return tuple(leading)
    return tuple(leading) + (
        self._gate_symbols(control_deps_guard, guarded_args),)
コード例 #29
0
ファイル: control_flow.py プロジェクト: b2220333/tensorflow-1
    def visit_While(self, node):
        """Converts a `while` loop into a `tf.while_loop` call.

        The loop state is the set of symbols modified, but not created, in the
        loop body. The test and the body become generated functions over that
        state, and the state is re-assigned from the result of
        `tf.while_loop`.

        Returns:
          The statements produced by `templates.replace`.
        """
        self.generic_visit(node)

        body_scope = anno.getanno(node, 'body_scope')
        loop_vars = tuple(body_scope.modified - body_scope.created)

        if len(loop_vars) != 1:
            state = loop_vars
            state_ast_tuple = gast.Tuple(
                tuple(gast.Name(name, None, None) for name in state), None)
        else:
            (state,) = loop_vars
            state_ast_tuple = state

        template = """
      def test_name(state):
        return test
      def body_name(state):
        body
        return state,
      state_ast_tuple = tf.while_loop(test_name, body_name, [state])
    """
        taken_names = body_scope.referenced
        test_fn_name = self.namer.new_symbol('loop_test', taken_names)
        body_fn_name = self.namer.new_symbol('loop_body', taken_names)
        return templates.replace(
            template,
            state=state,
            state_ast_tuple=state_ast_tuple,
            test_name=test_fn_name,
            test=node.test,
            body_name=body_fn_name,
            body=node.body)
コード例 #30
0
  def _rename_compilable_function(self, node):
    """Points a call at the compiled version of its target, when needed.

    The callee must carry 'live_val' and 'fqn' annotations. Calls that
    `self._should_compile` rejects are returned untouched.
    - Constructor calls (nodes annotated 'is_constructor') are always renamed
      to the compiled class name.
    - Other callees are renamed only when the namer reports that a rename is
      needed; the owner type comes from `self._determine_function_owner`.
    When the renamed target is a method, `node.func.value` is prepended to the
    positional arguments.

    Returns:
      The call node, possibly modified in place.
    """
    assert anno.hasanno(node.func, 'live_val')
    assert anno.hasanno(node.func, 'fqn')
    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = anno.getanno(node.func, 'fqn')

    if not self._should_compile(node, target_fqn):
      return node

    if anno.hasanno(node, 'is_constructor'):
      new_name = self.context.namer.compiled_class_name(
          target_fqn, live_entity=target_entity)
      do_rename = True
    else:
      owner_type = self._determine_function_owner(target_entity)
      new_name, do_rename = self.context.namer.compiled_function_name(
          target_fqn, live_entity=target_entity, owner_type=owner_type)

    if not do_rename:
      return node

    if target_entity is not None and tf_inspect.ismethod(target_entity):
      # The renaming process will transform it into a regular function.
      # TODO(mdan): Is this complete? How does it work with nested members?
      node.args = [node.func.value] + node.args
    node.func = templates.replace('func_name', func_name=new_name)[0]
    return node
コード例 #31
0
ファイル: call_trees.py プロジェクト: Youed/tensorflow-1
 def _insert_dynamic_conversion(self, node):
     """Inlines a dynamic conversion for a dynamic function.

     The call `node` is rebuilt as
     `py2tf_api.converted_call(func, True, False, {}, original_args)` with the
     original callee and positional arguments substituted in. The original
     keyword arguments are then copied onto the new call node.

     Args:
       node: the call node to convert.

     Returns:
       The new call expression node.
     """
     # TODO(mdan): Pass information on the statically compiled functions.
     # Having access to the statically compiled functions can help avoid
     # unnecessary compilation.
     # For example, this would lead to function `a` being compiled twice:
     #
     #   def a():
     #     v = b
     #     b()
     #   def b():
     #     a()
     #
     # This is really a problem with recursive calls, which currently can
     # only be gated by a static condition, and should be rare.
     # TODO(mdan): It probably makes sense to use dynamic conversion every time.
     # Before we could convert all the time though, we'd need a reasonable
     # caching mechanism.
     template = """
   py2tf_api.converted_call(func, True, False, {}, original_args)
 """
     generated = templates.replace(template,
                                   func=node.func,
                                   original_args=node.args)
     converted_call = generated[0].value
     # TODO(mdan): Improve the template mechanism to better support this.
     # The template only substitutes positional arguments, so the keywords
     # are attached to the generated call node directly.
     converted_call.keywords = node.keywords
     return converted_call
コード例 #32
0
 def _create_break_trigger(self):
   """Builds the statements that set the current break flag and continue.

   The flag name is taken from the last entry of `self.break_uses`.

   Returns:
     A list of statements: `<flag> = True` followed by `gast.Continue()`.
   """
   template = """
     var_name = True
   """
   flag_name = self.break_uses[-1][1]
   statements = templates.replace(template, var_name=flag_name)
   statements.append(gast.Continue())
   return statements
コード例 #33
0
    def visit_With(self, node):
        """Lifts a trailing `return` out of a `with` statement.

        Children are visited first. If the last statement of the body is a
        `gast.Return`, it is replaced by an assignment to
        `self.common_return_name`. A `return` of that name is then emitted
        right after the `with` statement.

        Returns:
          The `With` node when nothing was lifted; otherwise the list
          `[with_node, return_node]`, in which case `self.changes_made` is set
          to True.
        """
        # Depth-first traversal of syntax
        node = self.generic_visit(node)

        last_statement = node.body[-1]
        if not isinstance(last_statement, gast.Return):
            return node

        result_name = self.common_return_name
        node.body[-1] = templates.replace(
            'a = b', a=result_name, b=last_statement.value)[0]
        lifted_return = templates.replace('return a', a=result_name)[0]
        # NOTE(review): the children are visited a second time after the
        # rewrite. This looks redundant but is kept to preserve behavior.
        node = self.generic_visit(node)
        self.changes_made = True
        return [node, lifted_return]
コード例 #34
0
ファイル: call_trees.py プロジェクト: andrewharp/tensorflow
  def _wrap_to_py_func_no_return(self, node):
    """Wraps a call into a `tf.py_func` call over a generated wrapper.

    The wrapper forwards its arguments to `node.func` and returns 1. It is
    invoked as `tf.py_func(wrapper, [args], [tf.int64])`, where the args are
    the names listed in the node's 'args_scope' annotation.

    Returns:
      A `(wrapper_def, call_expr)` pair. `call_expr.value` receives the
      'args_scope' annotation; `wrapper_def` is annotated 'skip_processing'.
    """
    args_scope = anno.getanno(node, 'args_scope')
    # TODO(mdan): Properly handle varargs, kwargs, etc.
    arg_nodes = tuple(
        gast.Name(name, gast.Load(), None) for name in args_scope.used)

    # pylint:disable=undefined-variable,unused-argument,function-redefined

    def template(call, wrapper, args):

      def wrapper(args):
        call(args)
        return 1

      tf.py_func(wrapper, [args], [tf.int64])

    # pylint:enable=undefined-variable,unused-argument,function-redefined

    wrapper_ref = gast.Name(
        self.namer.compiled_function_name(node.func.id), gast.Load(), None)
    wrapper_def, call_expr = templates.replace(
        template, call=node.func, wrapper=wrapper_ref, args=arg_nodes)
    anno.setanno(call_expr.value, 'args_scope', args_scope)
    anno.setanno(wrapper_def, 'skip_processing', True)
    return wrapper_def, call_expr
コード例 #35
0
 def _create_cond_expr(self, results, test, body_name, orelse_name):
   """Builds the `py2tf_utils.run_cond` call for a converted `if`.

   Args:
     results: assignment target for the call's result, or None to emit the
       call as a bare expression statement.
     test: the condition expression.
     body_name: name of the function for the true branch.
     orelse_name: name of the function for the false branch.

   Returns:
     The statements produced by `templates.replace`.
   """
   if results is None:
     template = """
       py2tf_utils.run_cond(test, body_name, orelse_name)
     """
     return templates.replace(
         template, test=test, body_name=body_name, orelse_name=orelse_name)

   template = """
       results = py2tf_utils.run_cond(test, body_name, orelse_name)
     """
   return templates.replace(
       template,
       test=test,
       results=results,
       body_name=body_name,
       orelse_name=orelse_name)
コード例 #36
0
 def _create_continuation_check(self):
   """Builds an `if not <flag>:` statement with an empty body.

   The flag name is taken from the last entry of `self.continuation_uses`.
   The template's `pass` placeholder is removed, leaving `body` empty
   (presumably for the caller to populate).

   Returns:
     The `if` statement node.
   """
   template = """
     if not var_name:
       pass
   """
   flag_name = self.continuation_uses[-1][1]
   (cond,) = templates.replace(template, var_name=flag_name)
   cond.body = []
   return cond
コード例 #37
0
    def _create_break_check(self):
        """Builds the expression `not <flag>` for the current break flag.

        The flag name is taken from the last entry of `self.break_uses`.

        Returns:
          The expression node: the `.value` of the single generated
          expression statement.
        """
        def template(var_name):
            (not var_name)  # pylint:disable=pointless-statement

        flag = gast.Name(self.break_uses[-1][1], None, None)
        (statement,) = templates.replace(template, var_name=flag)
        return statement.value
コード例 #38
0
    def _create_break_init(self):
        """Builds `<flag> = False` for the current break flag.

        The flag name is taken from the last entry of `self.break_uses`.

        Returns:
          The single assignment statement produced by the template.
        """
        def template(var_name):  # pylint:disable=unused-argument
            var_name = False

        flag = gast.Name(self.break_uses[-1][1], None, None)
        (assign,) = templates.replace(template, var_name=flag)
        return assign
コード例 #39
0
 def visit_For(self, node):
   """Lowers a `for` loop into an index-based `while` loop.

   The iterable is bound once to a fresh name and its length is taken once.
   The loop target is assigned from the bound iterable at the current index
   on each iteration. If the node carries an 'extra_cond' annotation, that
   expression is and-ed into the `while` test.

   Returns:
     The statements produced by `templates.replace`.
   """
   self.generic_visit(node)
   body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
   taken_names = body_scope.referenced
   namer = self.context.namer
   i_var = namer.new_symbol('i', taken_names)
   n_var = namer.new_symbol('n', taken_names)
   iterated_var = namer.new_symbol('iterated', taken_names)
   replacements = {
       'loop_iter': node.iter,
       'target': node.target,
       'body': node.body,
       'i': i_var,
       'n': n_var,
       'iterated': iterated_var,
   }
   # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
   if anno.hasanno(node, 'extra_cond'):
     template = """
       i = 0
       iterated = loop_iter
       n = len(iterated)
       while i < n and extra_cond:
         target = iterated[i]
         body
         i += 1
     """
     replacements['extra_cond'] = anno.getanno(node, 'extra_cond')
   else:
     template = """
       i = 0
       iterated = loop_iter
       n = len(iterated)
       while i < n:
         target = iterated[i]
         body
         i += 1
     """
   return templates.replace(template, **replacements)
コード例 #40
0
  def _create_continuation_trigger(self):
    """Builds `<flag> = True` for the most recent continuation flag.

    The flag name is taken from the last entry of `self.continuation_uses`.

    Returns:
      The single assignment statement produced by the template.
    """

    def template(var_name):  # pylint:disable=unused-argument
      var_name = True

    flag = gast.Name(self.continuation_uses[-1][1], None, None)
    (assign,) = templates.replace(template, var_name=flag)
    return assign
コード例 #41
0
 def _gate_symbols(self, guard_statement, guarded_args):
   """Appends `tf.identity` re-bindings of `guarded_args` to a guard block.

   Args:
     guard_statement: a statement with a `body` list; extended in place.
     guarded_args: iterable of symbols to rebind through `tf.identity`.

   Returns:
     `guard_statement`, after its body has been extended.
   """
   # TODO(mdan): This won't work for variables.
   template = """
     (args,) = (tf.identity(a) for a in (args,))
   """
   guard_statement.body.extend(
       templates.replace(template, args=tuple(guarded_args)))
   return guard_statement
コード例 #42
0
  def _create_break_check(self):
    """Builds the expression `not <flag>` for the current break flag.

    The flag name is taken from the last entry of `self.break_uses`.

    Returns:
      The expression node: the `.value` of the single generated expression
      statement.
    """

    def template(var_name):
      (not var_name)  # pylint:disable=pointless-statement

    flag = gast.Name(self.break_uses[-1][1], None, None)
    (statement,) = templates.replace(template, var_name=flag)
    return statement.value
コード例 #43
0
ファイル: for_loops.py プロジェクト: AndrewTwinz/tensorflow
 def visit_For(self, node):
   """Lowers a `for` loop into a `while` loop driven by py2tf_utils helpers.

   The iterable is wrapped with `py2tf_utils.dynamic_dataset`.
   `py2tf_utils.dynamic_for_cond(i, ...)` supplies both the continuation flag
   and the next target value, before the first iteration and after each one.
   If the node carries an 'extra_cond' annotation, that expression is and-ed
   into the `while` test.

   Returns:
     The statements produced by `templates.replace`.
   """
   self.generic_visit(node)
   body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
   taken_names = body_scope.referenced
   namer = self.context.namer
   i_var = namer.new_symbol('i', taken_names)
   smart_loop_iter_var = namer.new_symbol('smart_loop_iter', taken_names)
   cont_var = namer.new_symbol('cont', taken_names)
   replacements = {
       'loop_iter': node.iter,
       'target': node.target,
       'body': node.body,
       'i': i_var,
       'smart_loop_iter': smart_loop_iter_var,
       'cont': cont_var,
   }
   # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
   if anno.hasanno(node, 'extra_cond'):
     template = """
       i = 0
       smart_loop_iter = py2tf_utils.dynamic_dataset(loop_iter)
       cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
       while cont and extra_cond:
         body
         i += 1
         cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
     """
     replacements['extra_cond'] = anno.getanno(node, 'extra_cond')
   else:
     template = """
       i = 0
       smart_loop_iter = py2tf_utils.dynamic_dataset(loop_iter)
       cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
       while cont:
         body
         i += 1
         cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
     """
   return templates.replace(template, **replacements)
コード例 #44
0
  def _create_break_init(self):
    """Returns the statement `<flag> = False` for the innermost break flag.

    The flag name is the second element of the last `self.break_uses` entry.
    """

    def template(var_name):  # pylint:disable=unused-argument
      var_name = False

    flag_name = self.break_uses[-1][1]
    (init_stmt,) = templates.replace(
        template, var_name=gast.Name(flag_name, None, None))
    return init_stmt
コード例 #45
0
ファイル: call_trees.py プロジェクト: Youed/tensorflow-1
 def _wrap_to_py_func_no_return(self, node):
     """Rewrites a call as a `py2tf_utils.wrap_py_func(...)` statement.

     Only the callee and the positional arguments of `node` are forwarded.
     NOTE(review): the literal `None` / `True` arguments presumably encode
     "no return value" for `wrap_py_func`; confirm against its definition.

     Args:
       node: a call node; only its `func` and `args` attributes are read.

     Returns:
       The statements produced by `templates.replace`.
     """
     # TODO(mdan): Properly handle varargs, kwargs, etc.
     callee, positional = node.func, node.args
     template = """
   py2tf_utils.wrap_py_func(func, None, (original_args,), True)
 """
     return templates.replace(
         template, func=callee, original_args=positional)
コード例 #46
0
 def _gate_symbols(self, guard_statement, guarded_args):
     """Appends `tf.identity` re-bindings of `guarded_args` to a statement body.

     Args:
       guard_statement: node whose `body` list is extended in place.
       guarded_args: symbols to re-bind; substituted for the `args`
         placeholder as a tuple.

     Returns:
       `guard_statement`, with the generated statements appended to its body.
     """
     # TODO(mdan): This won't work for variables.
     gated = tuple(guarded_args)
     template = """
   (args,) = (tf.identity(a) for a in (args,))
 """
     identity_stmts = templates.replace(template, args=gated)
     guard_statement.body.extend(identity_stmts)
     return guard_statement
コード例 #47
0
 def visit_For(self, node):
     """Lowers a `for` loop into a `while` loop over `py2tf_utils` helpers.

     The generated code wraps the iterable with `py2tf_utils.dynamic_dataset`,
     then repeatedly calls `py2tf_utils.dynamic_for_cond(i, ...)` to obtain
     the continuation flag and the next loop target, incrementing a fresh
     counter after each iteration. If the node carries an `'extra_cond'`
     annotation, it is and-ed into the `while` condition.

     Args:
       node: the `For` node being visited; its children are visited first.

     Returns:
       The statements produced by `templates.replace`.
     """
     self.generic_visit(node)
     reserved = anno.getanno(node, NodeAnno.BODY_SCOPE).referenced
     namer = self.context.namer
     # Fresh names are requested in a fixed order: i, smart_loop_iter, cont.
     replacements = dict(
         loop_iter=node.iter,
         target=node.target,
         body=node.body,
         i=namer.new_symbol('i', reserved),
         smart_loop_iter=namer.new_symbol('smart_loop_iter', reserved),
         cont=namer.new_symbol('cont', reserved))
     # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
     if not anno.hasanno(node, 'extra_cond'):
         template = """
     i = 0
     smart_loop_iter = py2tf_utils.dynamic_dataset(loop_iter)
     cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
     while cont:
       body
       i += 1
       cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
   """
     else:
         template = """
     i = 0
     smart_loop_iter = py2tf_utils.dynamic_dataset(loop_iter)
     cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
     while cont and extra_cond:
       body
       i += 1
       cont, target = py2tf_utils.dynamic_for_cond(i, smart_loop_iter)
   """
         replacements['extra_cond'] = anno.getanno(node, 'extra_cond')
     return templates.replace(template, **replacements)
コード例 #48
0
  def replace_as_expression(self):
    """Checks that a one-expression template yields the substituted call.

    NOTE(review): without a `test_` prefix this method is not collected by
    unittest's default loader; consider renaming it once it is known to pass.
    """
    template = """
      foo(a)
    """

    statements = templates.replace(template, foo='bar', a='baz')
    # The template is a single expression statement; unwrap it to reach the
    # call. (The old `node is gast.Call` compared an instance against the
    # class itself and could never succeed.)
    self.assertEqual(len(statements), 1)
    node = statements[0].value
    self.assertIsInstance(node, gast.Call)
    self.assertEqual(node.func.id, 'bar')
    # Positional arguments are stored on the Call node, not on its `func`.
    self.assertEqual(node.args[0].id, 'baz')
コード例 #49
0
    def test_replace_tuple(self):
        """Checks that a tuple replacement expands into a tuple expression."""
        template = """
      def test_fn(a, c):
        return b,
    """

        node = templates.replace(template, b=('a', 'c'))[0]
        result = compiler.ast_to_object(node)
        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual((2, 3), result.test_fn(2, 3))
コード例 #50
0
    def canonicalize_listcomp(self, result_node, list_comp_node):
        """Expands a list comprehension into explicit loop statements.

        Produces an assignment of `self.instantiate_list_node()` to
        `result_node`, followed by nested `for` / `if` statements mirroring the
        comprehension's generators and conditions, with the statement built by
        `self.make_update_list_node` at the innermost level.

        Args:
          result_node: target that receives the list.
          list_comp_node: the list comprehension node to expand.

        Returns:
          The list-initialisation statements concatenated with the loop nest.
        """
        init_stmts = templates.replace('list_ = create_list',
                                       list_=result_node,
                                       create_list=self.instantiate_list_node())
        inner = self.make_update_list_node(result_node, list_comp_node.elt)

        # Build the nest inside-out: the last generator is innermost, and the
        # first condition of a generator must end up as its outermost `if`.
        for comp in list_comp_node.generators[::-1]:
            for condition in comp.ifs[::-1]:
                inner = templates.replace('if test: loop_body',
                                          test=condition,
                                          loop_body=inner)
            inner = templates.replace('for target in iter_: loop_body',
                                      iter_=comp.iter,
                                      target=comp.target,
                                      loop_body=inner)

        return init_stmts + inner
コード例 #51
0
  def _create_break_trigger(self):
    """Returns statements that set the innermost break flag and `continue`.

    The flag name is the second element of the last `self.break_uses` entry.

    Returns:
      The statements from the `<flag> = True` template, with a
      `gast.Continue()` appended.
    """

    def template(var_name):  # pylint:disable=unused-argument
      var_name = True

    flag = gast.Name(self.break_uses[-1][1], None, None)
    statements = templates.replace(template, var_name=flag)
    statements.append(gast.Continue())
    return statements
コード例 #52
0
 def _inline_tf_op(self, op_name, args):
   """Builds the expression `tf.<op_name>(<args>)` as a safe boolean operand.

   Args:
     op_name: substituted for the `op_name` attribute of `tf`.
     args: substituted for the call's `args` placeholder.

   Returns:
     The call expression, annotated with `SAFE_BOOLEAN_OPERAND` set to True.
   """
   template = """
     tf.op_name(args)
   """
   stmts = templates.replace(template, op_name=op_name, args=args)
   # The template is a body holding one expression statement; keep only the
   # expression itself.
   call = stmts[0].value
   anno.setanno(call, SAFE_BOOLEAN_OPERAND, True)
   return call
コード例 #53
0
 def _create_cond_expr(self, results, test, body_name, orelse_name):
     """Builds a `py2tf_utils.run_cond(test, body_name, orelse_name)` call.

     Args:
       results: assignment target for the call's result, or None to emit the
         call as a bare expression statement.
       test: substituted for the condition argument.
       body_name: substituted for the second argument (presumably the name of
         the generated true-branch function).
       orelse_name: substituted for the third argument (presumably the name of
         the generated false-branch function).

     Returns:
       The statements produced by `templates.replace`.
     """
     replacements = dict(
         test=test, body_name=body_name, orelse_name=orelse_name)
     if results is None:
         template = """
     py2tf_utils.run_cond(test, body_name, orelse_name)
   """
     else:
         template = """
     results = py2tf_utils.run_cond(test, body_name, orelse_name)
   """
         replacements['results'] = results
     return templates.replace(template, **replacements)
コード例 #54
0
  def _gate_symbols(self, guard_statement, guarded_args):
    """Appends `tf.identity` re-bindings of `guarded_args` to a statement body.

    Each entry of `guarded_args` is wrapped in a `gast.Name` before being
    substituted for the template's `args` placeholder.

    Args:
      guard_statement: node whose `body` list is extended in place.
      guarded_args: iterable of symbol names to re-bind.

    Returns:
      `guard_statement`, with the generated statements appended to its body.
    """

    def template(args):  # pylint:disable=unused-argument
      (args,) = (tf.identity(a) for a in (args,))  # pylint:disable=undefined-variable

    name_nodes = tuple(gast.Name(a, None, None) for a in guarded_args)
    identity_stmts = templates.replace(template, args=name_nodes)
    guard_statement.body.extend(identity_stmts)
    return guard_statement
コード例 #55
0
    def _gate_symbols(self, guard_statement, guarded_args):
        """Appends `tf.identity` re-bindings of `guarded_args` to a body.

        Each entry of `guarded_args` is wrapped in a `gast.Name` before being
        substituted for the template's `args` placeholder.

        Args:
          guard_statement: node whose `body` list is extended in place.
          guarded_args: iterable of symbol names to re-bind.

        Returns:
          `guard_statement`, with the generated statements appended.
        """
        def template(args):  # pylint:disable=unused-argument
            (args, ) = (tf.identity(a) for a in (args, ))  # pylint:disable=undefined-variable

        name_nodes = tuple(gast.Name(a, None, None) for a in guarded_args)
        identity_stmts = templates.replace(template, args=name_nodes)
        guard_statement.body.extend(identity_stmts)
        return guard_statement
コード例 #56
0
 def _create_continuation_check(self):
     """Returns an `if not <flag>:` statement whose body is empty.

     The flag name is the second element of the last
     `self.continuation_uses` entry. The template needs a `pass` to parse;
     the body is cleared after substitution (presumably so the caller can
     populate it).
     """
     flag_name = self.continuation_uses[-1][1]
     template = """
   if not var_name:
     pass
 """
     (if_stmt,) = templates.replace(template, var_name=flag_name)
     if_stmt.body = []
     return if_stmt
コード例 #57
0
  def test_replace_tuple(self):
    """Checks that a tuple replacement expands into a tuple expression."""
    template = """
      def test_fn(a, c):
        return b,
    """

    node = templates.replace(template, b=('a', 'c'))[0]
    result = compiler.ast_to_object(node)
    # assertEqual, not the deprecated assertEquals alias (removed in
    # Python 3.12).
    self.assertEqual((2, 3), result.test_fn(2, 3))
コード例 #58
0
    def replace_as_expression(self):
        """Checks that a one-expression template yields the substituted call.

        NOTE(review): without a `test_` prefix this method is not collected by
        unittest's default loader; consider renaming it once it passes.
        """
        template = """
      foo(a)
    """

        statements = templates.replace(template, foo='bar', a='baz')
        # The template is a single expression statement; unwrap it to reach
        # the call. (The old `node is gast.Call` compared an instance against
        # the class itself and could never succeed.)
        self.assertEqual(len(statements), 1)
        node = statements[0].value
        self.assertIsInstance(node, gast.Call)
        self.assertEqual(node.func.id, 'bar')
        # Positional arguments are stored on the Call node, not on `func`.
        self.assertEqual(node.args[0].id, 'baz')
コード例 #59
0
  def visit_While(self, node):
    """Outlines a `while` loop into test/body functions run by `run_while`.

    The loop state is the set of symbols the body modifies but does not
    create. Inside the generated `loop_test` / `loop_body` functions each
    state symbol is renamed to a fresh name derived from its `ssf()` form, and
    the result of `py2tf_utils.run_while` is assigned back to the original
    state symbols.

    Args:
      node: the `While` node being visited; its children are visited first.

    Returns:
      The statements produced by `templates.replace`.

    Raises:
      ValueError: if the body modifies no symbols other than those it creates.
    """
    self.generic_visit(node)

    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    reserved = scope.referenced
    loop_vars = list(scope.modified - scope.created)
    if not loop_vars:
      # TODO(mdan): Implement this properly.
      # To complete this statement, we need to check whether any variable
      # created inside the body scope is used before being modified outside the
      # scope. This should be done during activity analysis, and in general
      # should cover the case where variables may not be initialized.
      raise ValueError('cannot convert while loop: no outputs')

    namer = self.context.namer
    fresh_names = [namer.new_symbol(v.ssf(), reserved) for v in loop_vars]
    # Only symbols whose fresh name actually differs need to be renamed.
    renames = {}
    for old, new in zip(loop_vars, fresh_names):
      if str(old) != new:
        renames[old] = new

    if len(loop_vars) == 1:
      (state,) = loop_vars
      (state_ssf,) = fresh_names
      state_ast_tuple = state
    else:
      state, state_ssf = loop_vars, fresh_names
      state_ast_tuple = gast.Tuple([v.ast() for v in loop_vars], None)

    renamed_body = ast_util.rename_symbols(node.body, renames)
    renamed_test = ast_util.rename_symbols(node.test, renames)

    template = """
      def test_name(state_ssf):
        return test
      def body_name(state_ssf):
        body
        return state_ssf,
      state_ast_tuple = py2tf_utils.run_while(test_name, body_name, [state])
    """
    # Requested after the state symbols, test name before body name.
    test_fn_name = namer.new_symbol('loop_test', reserved)
    body_fn_name = namer.new_symbol('loop_body', reserved)
    return templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        test_name=test_fn_name,
        test=renamed_test,
        body_name=body_fn_name,
        body=renamed_body)
コード例 #60
0
  def _create_continuation_check(self):
    """Returns an `if not <flag>:` statement whose body is empty.

    The flag name is the second element of the last `self.continuation_uses`
    entry. The template needs a `pass` to be valid Python; the body is
    cleared after substitution.
    """

    def template(var_name):
      if not var_name:
        pass

    flag = gast.Name(self.continuation_uses[-1][1], None, None)
    (if_stmt,) = templates.replace(template, var_name=flag)
    if_stmt.body = []
    return if_stmt