Example #1
    def grad_fn(*args):
        """Computes the gradient of the wrapped function."""
        tape.push_new_tape()
        try:
            end_node = f(*args)
            if end_node is None:
                raise ValueError(
                    "Cannot differentiate a function that returns None; "
                    "did you forget to return a value from {}?".format(
                        f.__name__))
        finally:
            popped_tape = tape.pop_tape()
        # Sorting variables by id, which is monotonically increasing in construction
        # order. This ensures unique order across executions.
        variables = list(
            sorted(popped_tape.watched_variables(),
                   key=lambda v: v.handle._id))  # pylint: disable=protected-access
        sources = [x.handle for x in variables]

        if not sources:
            raise ValueError("No trainable variables were accessed while the "
                             "function was being computed.")
        grad = imperative_grad.imperative_grad(_default_vspace, popped_tape,
                                               nest.flatten(end_node), sources)
        return end_node, list(zip(grad, variables))
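For context, grad_fn above is internal TensorFlow machinery; a minimal sketch of the same record-then-differentiate pattern using only the public tf.GradientTape API may help (the variable v and loss_fn below are illustrative assumptions, not part of the snippet):

import tensorflow as tf

v = tf.Variable(3.0)

def loss_fn():
    # Any trainable variable touched here is recorded by the active tape,
    # just as f(*args) is recorded between push_new_tape and pop_tape above.
    return v * v

with tf.GradientTape() as g:
    loss = loss_fn()

variables = g.watched_variables()   # analogous to popped_tape.watched_variables()
grads = g.gradient(loss, variables)
print(list(zip(grads, variables)))  # gradient is 6.0, i.e. 2 * v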
Example #2
 def decorated(*args, **kwds):
   """Computes the value and gradient of the decorated function."""
   parameter_positions = _get_arg_spec(f, params, args)
   assert not kwds, "The gradient function can't take keyword arguments."
   tape.push_new_tape()
   try:
     sources = []
     args = [
         ops.convert_to_tensor(args[i])
         if i in parameter_positions else args[i]
         for i in range(len(args))
     ]
     args = _ensure_unique_tensor_objects(parameter_positions, args)
     for i in parameter_positions:
       sources.append(args[i])
       tape.watch(args[i])
     result = f(*args)
     if result is None:
       raise ValueError("Cannot differentiate a function that returns None; "
                        "did you forget to return a value from {}?".format(
                            f.__name__))
     flat_result = nest.flatten(result)
     flat_result = [gen_array_ops.identity(x) for x in flat_result]
     result = nest.pack_sequence_as(result, flat_result)
   finally:
     t = tape.pop_tape()
   def vjp(dy=None):
     if dy is not None:
       dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
     return imperative_grad.imperative_grad(
         _default_vspace, t, nest.flatten(result), sources,
         output_gradients=dy)
   return result, vjp
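The decorated/vjp pair above returns a value together with a vector-Jacobian product closure. A rough public-API analogue, assuming a persistent tf.GradientTape so the closure can be invoked more than once (value_and_vjp and f are illustrative names):

import tensorflow as tf

def value_and_vjp(f, x):
    # Sketch only: mirrors convert_to_tensor + tape.watch + the inner vjp closure.
    x = tf.convert_to_tensor(x)
    with tf.GradientTape(persistent=True) as g:
        g.watch(x)
        result = f(x)

    def vjp(dy=None):
        # output_gradients corresponds to the dy argument threaded through above.
        return g.gradient(result, x, output_gradients=dy)

    return result, vjp

y, vjp = value_and_vjp(lambda t: t * t, 2.0)
print(y.numpy(), vjp().numpy())        # 4.0 4.0
print(vjp(tf.constant(10.0)).numpy())  # 40.0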
Example #3
  def grad_fn(*args):
    """Computes the gradient of the wrapped function."""
    tape.push_new_tape()
    try:
      end_node = f(*args)
      if end_node is None:
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             f.__name__))
    finally:
      popped_tape = tape.pop_tape()
    # Sorting variables by id, which is monotonically increasing in construction
    # order. This ensures unique order across executions.
    variables = list(sorted(popped_tape.watched_variables(),
                            key=lambda v: v.handle._id))  # pylint: disable=protected-access
    sources = [x.handle for x in variables]

    if not sources:
      raise ValueError("No trainable variables were accessed while the "
                       "function was being computed.")
    grad = imperative_grad.imperative_grad(_default_vspace,
                                           popped_tape,
                                           nest.flatten(end_node),
                                           sources)
    return end_node, list(zip(grad, variables))
Example #4
 def decorated(*args, **kwds):
   """Computes the value and gradient of the decorated function."""
   parameter_positions = _get_arg_spec(f, params, args)
   assert not kwds, "The gradient function can't take keyword arguments."
   tape.push_new_tape()
   try:
     sources = []
     args = [
         ops.convert_to_tensor(args[i])
         if i in parameter_positions else args[i]
         for i in range(len(args))
     ]
     args = _ensure_unique_tensor_objects(parameter_positions, args)
     for i in parameter_positions:
       sources.append(args[i])
       tape.watch(args[i])
     result = f(*args)
     if result is None:
       raise ValueError("Cannot differentiate a function that returns None; "
                        "did you forget to return a value from {}?".format(
                            f.__name__))
     flat_result = nest.flatten(result)
     flat_result = [gen_array_ops.identity(x) for x in flat_result]
     result = nest.pack_sequence_as(result, flat_result)
   finally:
     t = tape.pop_tape()
   def vjp(dy=None):
     if dy is not None:
       dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
     return imperative_grad.imperative_grad(
         _default_vspace, t, nest.flatten(result), sources,
         output_gradients=dy)
   return result, vjp
Example #5
    def decorated(*args, **kwds):
        """Computes the value and gradient of the decorated function."""
        assert not kwds, "The gradient function can't take keyword arguments."
        tape.push_new_tape()
        sources = []
        args = [
            ops.convert_to_tensor(args[i])
            if i in parameter_positions else args[i] for i in range(len(args))
        ]
        args = _ensure_unique_tensor_objects(parameter_positions, args)
        for i in parameter_positions:
            sources.append(args[i])
            tape.watch(args[i])
        result = f(*args)
        t = tape.pop_tape()

        def vjp(dy=None):
            return imperative_grad.imperative_grad(
                _default_vspace,
                t,
                nest.flatten(result),
                sources,
                output_gradients=nest.flatten(dy) if dy is not None else None)

        return result, vjp
Example #6
 def grad_fn(*args):
   """Computes the gradient of the wrapped function."""
   tape.push_new_tape()
   end_node = f(*args)
   variables = tape.top_tape_watched_variables()
   sources = [x.handle for x in variables]
   grad = imperative_grad(end_node, sources)
   return end_node, list(zip(grad, variables))
Example #7
 def grad_fn(*args):
     """Computes the gradient of the wrapped function."""
     tape.push_new_tape()
     end_node = f(*args)
     variables = tape.top_tape_watched_variables()
     sources = [x.handle for x in variables]
     grad = imperative_grad(end_node, sources)
     return end_node, list(zip(grad, variables))
Example #8
def _defun_internal(name, func, args, kwds):
    """Defines and returns graph-mode version of func."""
    container_prefix = ops.get_default_graph()._container_prefix  # pylint: disable=protected-access
    with context.graph_mode():
        captures = {}
        tmp_graph = CapturingGraph(captures)
        # Inherit the container prefix, since this is used for error checking when
        # isolating eager execution (the container prefix at creation must match the
        # container prefix when used, and variables accessed in the defun will be
        # used in the outside context).
        tmp_graph._container_prefix = container_prefix  # pylint: disable=protected-access
        # Copy the graph collections to ensure summaries and other things work. This
        # lets the function access (but not mutate) collections of the containing
        # graph, such as the global step and the summary writer collections.
        curr_graph = ops.get_default_graph()
        for collection in curr_graph.collections:
            tmp_graph.get_collection_ref(
                collection)[:] = curr_graph.get_collection(collection)
        with tmp_graph.as_default():
            func_inputs = _get_defun_inputs(args)

            with capture_tensors(captures):
                tape.push_new_tape()
                try:
                    func_outputs = func(*func_inputs, **kwds)
                finally:
                    variables = tape.pop_tape().watched_variables()
            ids = list(sorted(captures.keys()))
            if ids:
                extra_inputs, extra_placeholders = zip(
                    *[captures[x] for x in ids])
            else:
                extra_inputs = []
                extra_placeholders = []
            outputs_list = nest.flatten(func_outputs)
            output_shapes = tuple(x.shape for x in outputs_list
                                  if x is not None)

    flat_inputs = [
        x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
    ]
    all_inputs = flat_inputs + list(extra_placeholders)
    all_ignored_ops = frozenset(x.op for x in all_inputs)
    func_def_outputs = [x for x in outputs_list if x is not None]
    fname = _inference_name(name)
    operations = tuple(x for x in tmp_graph.get_operations()
                       if x not in all_ignored_ops)
    # Register any other functions defined in the graph
    # TODO(ashankar): Oh lord, forgive me for this lint travesty.
    for f in tmp_graph._functions.values():  # pylint: disable=protected-access
        # TODO(ashankar): What about the gradient registry?
        _register(f._c_func)  # pylint: disable=protected-access
    return GraphModeFunction(fname, all_inputs, extra_inputs, tmp_graph,
                             operations, func_def_outputs, func_outputs,
                             output_shapes, variables)
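_defun_internal traces func into a temporary graph while a tape records which variables it touched. Today the public tf.function API performs a comparable trace; a small sketch for comparison (step and v are illustrative, not taken from the snippet):

import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def step(x):
    # The variable read is recorded during tracing, much like the tape pushed
    # around func(*func_inputs, **kwds) above.
    return v * x

concrete = step.get_concrete_function(tf.TensorSpec([], tf.float32))
# The traced ops live in a FuncGraph, analogous to tmp_graph.get_operations() above.
print(len(concrete.graph.get_operations()))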
Example #9
 def decorated(*args, **kwds):
     """Computes the value and gradient of the decorated function."""
     dy = kwds.pop("dy", None)
     assert not kwds, "The gradient function can't take keyword arguments."
     tape.push_new_tape()
     sources = []
     args = list(args)
     for i in parameter_positions:
         sources.append(args[i])
         tape.watch(args[i])
     result = f(*args)
     return result, imperative_grad(result, sources, output_gradients=dy)
Example #10
    def grad_fn(*args):
        """Computes the gradient of the wrapped function."""
        tape.push_new_tape()
        end_node = f(*args)
        variables = tape.top_tape_watched_variables()
        sources = [x.handle for x in variables]

        if not sources:
            raise ValueError("no trainable variables were accessed while the "
                             "function was being computed.")
        grad = imperative_grad.imperative_grad(_default_vspace,
                                               tape.pop_tape(),
                                               nest.flatten(end_node), sources)
        return end_node, list(zip(grad, variables))
Example #11
  def grad_fn(*args):
    """Computes the gradient of the wrapped function."""
    tape.push_new_tape()
    end_node = f(*args)
    variables = tape.top_tape_watched_variables()
    sources = [x.handle for x in variables]

    if not sources:
      raise ValueError("no trainable variables were accessed while the "
                       "function was being computed.")
    grad = imperative_grad.imperative_grad(_default_vspace,
                                           tape.pop_tape(),
                                           nest.flatten(end_node),
                                           sources)
    return end_node, list(zip(grad, variables))
Example #12
def _defun_internal(name, func, args, kwds):
  """Defines and returns graph-mode version of func."""
  graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  with context.graph_mode():
    captures = {}
    tmp_graph = CapturingGraph(captures)
    # Inherit the graph key, since this is used for matching variables in
    # optimizers.
    tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
    # Copy the graph collections to ensure summaries and other things work. This
    # lets the function access (but not mutate) collections of the containing
    # graph, such as the global step and the summary writer collections.
    curr_graph = ops.get_default_graph()
    for collection in curr_graph.collections:
      tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
          collection)
    with tmp_graph.as_default():
      func_inputs = _get_defun_inputs(args)

      with capture_tensors(captures):
        this_tape = tape.push_new_tape()
        try:
          func_outputs = func(*func_inputs, **kwds)
        finally:
          tape.pop_tape(this_tape)
        variables = this_tape.watched_variables()

        # Returning a closed-over tensor as an output does not trigger a
        # call to convert_to_tensor, so we manually capture all such tensors.
        outputs_list = _flatten(func_outputs)
        func_def_outputs = [
            _convert_to_graph_tensor(x) for x in outputs_list if x is not None
        ]

      ids = list(sorted(captures.keys()))
      if ids:
        extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
      else:
        extra_inputs = []
        extra_placeholders = []
      output_shapes = tuple(
          x.shape if isinstance(x, ops.Tensor) else None
          for x in outputs_list)

  flat_inputs = [x for x in nest.flatten(func_inputs)
                 if isinstance(x, ops.Tensor)]
  all_inputs = flat_inputs + list(extra_placeholders)
  all_ignored_ops = frozenset(x.op for x in all_inputs)
  fname = _inference_name(name)
  operations = tuple(x for x in tmp_graph.get_operations()
                     if x not in all_ignored_ops)
  # Register any other functions defined in the graph
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  if context.in_eager_mode():
    for f in tmp_graph._functions.values():  # pylint: disable=protected-access
      # TODO(ashankar): What about the gradient registry?
      _register(f._c_func)  # pylint: disable=protected-access
  return GraphModeFunction(
      fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
      func_outputs, output_shapes, variables)
Example #13
  def grad_fn(*args, **kwds):
    """Computes the gradient of the wrapped function."""
    this_tape = tape.push_new_tape()
    try:
      end_node = f(*args, **kwds)
      if end_node is None:
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             f.__name__))
    finally:
      tape.pop_tape(this_tape)
    # Note: variables are returned in construction order. This ensures unique
    # order across executions.
    variables = this_tape.watched_variables()
    if not variables:
      raise ValueError("No trainable variables were accessed while the "
                       "function was being computed.")

    sources = [v.handle for v in variables]
    for s in sources:
      if getattr(s, "is_packed", False):
        raise ValueError(
            "GradientTape.gradient is not supported on packed EagerTensors yet."
        )
    grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
                                           sources)
    return end_node, list(zip(grad, variables))
Example #14
 def decorated(*args, **kwds):
   """Computes the value and gradient of the decorated function."""
   dy = kwds.pop("dy", None)
   if dy is not None:
     dy = ops.convert_to_tensor(dy)
   assert not kwds, "The gradient function can't take keyword arguments."
   tape.push_new_tape()
   sources = []
   args = [ops.convert_to_tensor(x) for x in args]
   for i in parameter_positions:
     sources.append(args[i])
     tape.watch(args[i])
   result = f(*args)
   return result, imperative_grad(
       result,
       sources,
       output_gradients=dy)
Example #15
 def grad_fn(*args, **kwds):
     """Computes the gradient of the wrapped function."""
     tape.push_new_tape()
     end_node = f(*args)
     start_node = tape.pop_tape()
     ag_core.active_progenitors.remove(start_node)
     if not ag_core.isnode(end_node):
         raise ValueError(
             "Target not part of a computation being traced. %s" % end_node)
     if start_node not in end_node.progenitors:
         raise ValueError("Target not derived from source. %s %s" %
                          (end_node.progenitors, repr(start_node)))
     output_gradients = kwds.get("output_gradients", None)
     if output_gradients is None:
         output_gradients = _ones(end_node.shape, end_node.dtype)
     grad = ag_core.backward_pass(output_gradients, end_node, start_node)
     return end_node.value, _aggregate_grads(grad.gradients)
Example #16
 def grad_fn(*args, **kwds):
   """Computes the gradient of the wrapped function."""
   tape.push_new_tape()
   end_node = f(*args)
   start_node = tape.pop_tape()
   ag_core.active_progenitors.remove(start_node)
   if not ag_core.isnode(end_node):
     raise ValueError(
         "Target not part of a computation being traced. %s" % end_node)
   if start_node not in end_node.progenitors:
     raise ValueError("Target not derived from source. %s %s" %
                      (end_node.progenitors, repr(start_node)))
   output_gradients = kwds.get("output_gradients", None)
   if output_gradients is None:
     output_gradients = _ones(end_node.shape, end_node.dtype)
   grad = ag_core.backward_pass(output_gradients, end_node, start_node)
   return end_node.value, _aggregate_grads(grad.gradients)
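In this older autograd-based version, a missing output_gradients argument is replaced by _ones(end_node.shape, end_node.dtype), i.e. the backward pass is seeded with ones. The same default is observable through the public API; a brief sketch:

import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape() as g:
    g.watch(x)
    y = x * x
# With no output_gradients, the backward pass is seeded with ones of y's shape.
print(g.gradient(y, x).numpy())  # [2. 4. 6.]

with tf.GradientTape() as g:
    g.watch(x)
    y = x * x
# Passing ones explicitly gives the same result.
print(g.gradient(y, x, output_gradients=tf.ones_like(y)).numpy())  # [2. 4. 6.]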
Example #17
 def decorated(*args, **kwds):
     """Computes the value and gradient of the decorated function."""
     dy = kwds.pop("dy", None)
     if dy is not None:
         dy = ops.convert_to_tensor(dy)
     assert not kwds, "The gradient function can't take keyword arguments."
     tape.push_new_tape()
     sources = []
     args = [
         ops.convert_to_tensor(args[i])
         if i in parameter_positions else args[i] for i in range(len(args))
     ]
     for i in parameter_positions:
         sources.append(args[i])
         tape.watch(args[i])
     result = f(*args)
     return result, imperative_grad(result, sources, output_gradients=dy)
Example #18
  def testTapeGC(self):
    # TODO(apassos) figure out how to test this without using tape internal
    # APIs.
    tape.push_new_tape()

    def f():
      x = constant_op.constant(1.0)
      tape.watch(x)
      x = gradient_is_constant(x)
      x = gradient_is_constant(x)
      x = gradient_is_constant(x)

    f()
    t = tape.pop_tape()
    tensor_tape, op_tape = t.export()
    self.assertEqual(len(tensor_tape), 1)  # The watched tensor will remain on
                                           # the tape
    self.assertEqual(len(op_tape), 0)  # No operations should remain on the tape
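The test relies on ops whose gradient is a constant leaving no entries on the tape. A public-API way to observe a similar effect, using tf.stop_gradient as a stand-in for the test's gradient_is_constant helper:

import tensorflow as tf

x = tf.constant(1.0)
with tf.GradientTape() as g:
    g.watch(x)               # like tape.watch(x) inside f() above
    y = tf.stop_gradient(x)  # stand-in for gradient_is_constant(x)

# No differentiable op connects y to x, so the gradient is None.
print(g.gradient(y, x))  # None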
Example #19
 def _push_tape(self):
     if self._recording:
         raise ValueError("Tape is already recording.")
     if self._tape is None:
         self._tape = tape.push_new_tape(
             persistent=self._persistent,
             watch_accessed_variables=self._watch_accessed_variables)
     else:
         tape.push_tape(self._tape)
     self._recording = True
Example #20
 def _push_tape(self):
   if self._recording:
     raise ValueError("Tape is already recording.")
   if self._tape is None:
     self._tape = tape.push_new_tape(
         persistent=self._persistent,
         watch_accessed_variables=self._watch_accessed_variables)
   else:
     tape.push_tape(self._tape)
   self._recording = True
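_push_tape runs when a GradientTape context is entered, and tapes stack: pushing a second tape on top of the first is what makes higher-order gradients work. A short sketch of that stacking through nested public tapes:

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as outer:      # first tape pushed
    outer.watch(x)
    with tf.GradientTape() as inner:  # second tape pushed on top of it
        inner.watch(x)
        y = x * x * x
    dy_dx = inner.gradient(y, x)      # 3 * x**2 = 27
d2y_dx2 = outer.gradient(dy_dx, x)    # 6 * x = 18
print(dy_dx.numpy(), d2y_dx2.numpy())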
Example #21
 def _push_tape(self, existing_tape=False):
   if self._recording:
     raise ValueError("Tape is already recording.")
   if existing_tape:
     if self._tape is None:
       raise ValueError("There is no existing tape.")
     tape.push_tape(self._tape)
   else:
     self._tape = tape.push_new_tape(persistent=self._persistent)
   self._recording = True
Example #22
    def testTapeGC(self):
        # TODO(apassos) figure out how to test this without using tape internal
        # APIs.
        tape.push_new_tape()

        def f():
            x = constant_op.constant(1.0)
            tape.watch(x)
            x = gradient_is_constant(x)
            x = gradient_is_constant(x)
            x = gradient_is_constant(x)

        f()
        t = tape.pop_tape()
        tensor_tape, op_tape = t.export()
        # The watched tensor will remain on the tape.
        self.assertEqual(len(tensor_tape), 1)
        # No operations should remain on the tape.
        self.assertEqual(len(op_tape), 0)
Example #23
 def decorated(*args, **kwds):
   """Computes the value and gradient of the decorated function."""
   dy = kwds.pop("dy", None)
   if dy is not None:
     dy = ops.convert_to_tensor(dy)
   assert not kwds, "The gradient function can't take keyword arguments."
   tape.push_new_tape()
   sources = []
   args = [
       ops.convert_to_tensor(args[i]) if i in parameter_positions else args[i]
       for i in range(len(args))
   ]
   args = _ensure_unique_tensor_objects(parameter_positions, args)
   for i in parameter_positions:
     sources.append(args[i])
     tape.watch(args[i])
   result = f(*args)
   return result, imperative_grad.imperative_grad(
       _default_vspace, nest.flatten(result), sources,
       output_gradients=nest.flatten(dy) if dy is not None else None)
Example #24
 def _push_tape(self):
   """Pushes a new tape onto the tape stack."""
   if self._recording:
     raise ValueError("Tape is still recording, This can happen if you try to "
                      "re-enter an already-active tape.")
   if self._tape is None:
     self._tape = tape.push_new_tape(
         persistent=self._persistent,
         watch_accessed_variables=self._watch_accessed_variables)
   else:
     tape.push_tape(self._tape)
   self._recording = True
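The "Tape is still recording" check guards against entering a tape that is already active. A minimal sketch of the failure mode it protects against (the exact error message may differ between versions):

import tensorflow as tf

g = tf.GradientTape()
with g:
    try:
        with g:  # attempting to re-enter the tape while it is recording
            pass
    except ValueError:
        print("re-entering an active GradientTape raises ValueError")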
Example #25
    def grad_fn(*args):
        """Computes the gradient of the wrapped function."""
        tape.push_new_tape()
        try:
            end_node = f(*args)
            if end_node is None:
                raise ValueError(
                    "Cannot differentiate a function that returns None; "
                    "did you forget to return a value from {}?".format(
                        f.__name__))
        finally:
            popped_tape = tape.pop_tape()
            variables = popped_tape.watched_variables()
        sources = [x.handle for x in variables]

        if not sources:
            raise ValueError("No trainable variables were accessed while the "
                             "function was being computed.")
        grad = imperative_grad.imperative_grad(_default_vspace, popped_tape,
                                               nest.flatten(end_node), sources)
        return end_node, list(zip(grad, variables))
Example #26
  def grad_fn(*args):
    """Computes the gradient of the wrapped function."""
    tape.push_new_tape()
    try:
      end_node = f(*args)
      if end_node is None:
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             f.__name__))
    finally:
      popped_tape = tape.pop_tape()
      variables = popped_tape.watched_variables()
    sources = [x.handle for x in variables]

    if not sources:
      raise ValueError("No trainable variables were accessed while the "
                       "function was being computed.")
    grad = imperative_grad.imperative_grad(_default_vspace,
                                           popped_tape,
                                           nest.flatten(end_node),
                                           sources)
    return end_node, list(zip(grad, variables))
Example #27
    def decorated(*args, **kwds):
        """Computes the value and gradient of the decorated function."""
        parameter_positions = _get_arg_spec(f, params, args)
        assert not kwds, "The gradient function can't take keyword arguments."
        this_tape = tape.push_new_tape(persistent=persistent)
        try:
            sources = []
            args = [
                ops.convert_to_tensor(arg) if i in parameter_positions else arg
                for i, arg in enumerate(args)
            ]
            args = _ensure_unique_tensor_objects(parameter_positions, args)
            for i in parameter_positions:
                if getattr(args[i], "is_packed", False):
                    raise ValueError(
                        "GradientTape.gradient is not supported on packed EagerTensors"
                        "yet.")
                sources.append(args[i])
                tape.watch(this_tape, args[i])
            result = f(*args)
            if result is None:
                raise ValueError(
                    "Cannot differentiate a function that returns None; "
                    "did you forget to return a value from {}?".format(
                        f.__name__))
            flat_result = nest.flatten(result)
            flat_result = [gen_array_ops.identity(x) for x in flat_result]
            result = nest.pack_sequence_as(result, flat_result)
        finally:
            tape.pop_tape(this_tape)

        def vjp(dy=None):
            if dy is not None:
                dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
            return imperative_grad.imperative_grad(this_tape,
                                                   nest.flatten(result),
                                                   sources,
                                                   output_gradients=dy)

        return result, vjp
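_ensure_unique_tensor_objects presumably exists so that, when the caller passes the same tensor for several differentiable parameters, each position still receives its own gradient. A sketch of that idea with public ops, where tf.identity makes the per-position copies explicit:

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as g:
    # Give each parameter position its own tensor object, as
    # _ensure_unique_tensor_objects does for the watched arguments.
    a = tf.identity(x)
    b = tf.identity(x)
    g.watch([a, b])
    y = a * b

da, db = g.gradient(y, [a, b])
print(da.numpy(), db.numpy())  # 3.0 3.0 -- one partial per position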
Example #28
  def grad_fn(*args, **kwds):
    """Computes the gradient of the wrapped function."""
    this_tape = tape.push_new_tape()
    try:
      end_node = f(*args, **kwds)
      if end_node is None:
        raise ValueError("Cannot differentiate a function that returns None; "
                         "did you forget to return a value from {}?".format(
                             f.__name__))
    finally:
      tape.pop_tape(this_tape)
    # Note: variables are returned in construction order. This ensures unique
    # order across executions.
    variables = this_tape.watched_variables()
    if not variables:
      raise ValueError("No trainable variables were accessed while the "
                       "function was being computed.")

    sources = [v.handle for v in variables]
    grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
                                           sources)
    return end_node, list(zip(grad, variables))
Example #29
def _defun_internal(name, func, args, kwds):
  """Defines and returns graph-mode version of func."""
  container_prefix = ops.get_default_graph()._container_prefix  # pylint: disable=protected-access
  with context.graph_mode():
    captures = {}
    tmp_graph = CapturingGraph(captures)
    # Inherit the container prefix, since this is used for error checking when
    # isolating eager execution (the container prefix at creation must match the
    # container prefix when used, and variables accessed in the defun will be
    # used in the outside context).
    tmp_graph._container_prefix = container_prefix  # pylint: disable=protected-access
    # Copy the graph collections to ensure summaries and other things work. This
    # lets the function access (but not mutate) collections of the containing
    # graph, such as the global step and the summary writer collections.
    curr_graph = ops.get_default_graph()
    for collection in curr_graph.collections:
      tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
          collection)
    with tmp_graph.as_default():
      func_inputs = _get_defun_inputs(args)

      with capture_tensors(captures):
        tape.push_new_tape()
        try:
          func_outputs = func(*func_inputs, **kwds)
        finally:
          variables = tape.pop_tape().watched_variables()
      ids = list(sorted(captures.keys()))
      if ids:
        extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
      else:
        extra_inputs = []
        extra_placeholders = []
      outputs_list = nest.flatten(func_outputs)
      output_shapes = [x.shape for x in outputs_list if x is not None]

  flat_inputs = [
      x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
  ]
  all_inputs = flat_inputs + list(extra_placeholders)

  func_def_outputs = [x for x in outputs_list if x is not None]
  inference_function_def = make_function_def(
      tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
  # Register any other functions defined in the graph
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  for f in tmp_graph._functions.values():  # pylint: disable=protected-access
    # TODO(ashankar): What about the gradient registry?
    _register_with_name(f.name, f.definition)
  _register_with_name(_inference_name(name), inference_function_def)

  return GraphModeFunction(
      all_inputs,
      extra_inputs,
      inference_function_def,
      tmp_graph,
      tmp_graph.get_operations(),
      func_outputs,
      _map_sequence_obj_to_idx(func_def_outputs),
      output_shapes,
      variables=variables)
Example #30
def _defun_internal(name, func, args, kwds):
    """Defines and returns graph-mode version of func."""
    graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with context.graph_mode():
        captures = {}
        tmp_graph = CapturingGraph(captures)
        # Inherit the graph key, since this is used for matching variables in
        # optimizers.
        tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
        # Copy the graph collections to ensure summaries and other things work. This
        # lets the function access (but not mutate) collections of the containing
        # graph, such as the global step and the summary writer collections.
        curr_graph = ops.get_default_graph()
        for collection in curr_graph.collections:
            tmp_graph.get_collection_ref(
                collection)[:] = curr_graph.get_collection(collection)
        with tmp_graph.as_default(), AutomaticControlDependencies() as a:
            func_inputs = _get_defun_inputs(args)

            def convert(x):
                if x is None:
                    return None
                x = ops.convert_to_tensor_or_indexed_slices(x)
                x = a.mark_as_return(x)
                return x

            with capture_tensors(captures):
                this_tape = tape.push_new_tape()
                try:
                    func_outputs = func(*func_inputs, **kwds)
                    func_outputs = nest.map_structure(convert, func_outputs)
                finally:
                    tape.pop_tape(this_tape)
                variables = this_tape.watched_variables()

                # Returning a closed-over tensor as an output does not trigger a
                # call to convert_to_tensor, so we manually capture all such tensors.
                outputs_list = _flatten(func_outputs)
                func_def_outputs = [
                    _convert_to_graph_tensor(x) for x in outputs_list
                    if x is not None
                ]

            ids = list(sorted(captures.keys()))
            if ids:
                extra_inputs, extra_placeholders = zip(
                    *[captures[x] for x in ids])
            else:
                extra_inputs = []
                extra_placeholders = []
            output_shapes = tuple(
                x.shape if isinstance(x, ops.Tensor) else None
                for x in outputs_list)

    flat_inputs = [
        x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
    ]
    all_inputs = flat_inputs + list(extra_placeholders)
    all_ignored_ops = frozenset(x.op for x in all_inputs)
    fname = _inference_name(name)
    operations = tuple(x for x in tmp_graph.get_operations()
                       if x not in all_ignored_ops)
    # Register any other functions defined in the graph
    # TODO(ashankar): Oh lord, forgive me for this lint travesty.
    if context.executing_eagerly():
        for f in tmp_graph._functions.values():  # pylint: disable=protected-access
            # TODO(ashankar): What about the gradient registry?
            _register(f._c_func.func)  # pylint: disable=protected-access
    return GraphModeFunction(fname, all_inputs, extra_inputs, tmp_graph,
                             operations, func_def_outputs, func_outputs,
                             output_shapes, variables)
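The AutomaticControlDependencies context in this version keeps stateful ops in program order inside the traced function. The effect is visible through tf.function today; a minimal sketch (v and update are illustrative):

import tensorflow as tf

v = tf.Variable(0.0)

@tf.function
def update():
    v.assign_add(1.0)
    v.assign_add(2.0)
    # Automatic control dependencies ensure both assigns execute, in order,
    # before the returned value is read.
    return v.read_value()

print(update().numpy())  # 3.0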
Example #31
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            autograph_options=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None,
                            capture_by_value=None):
    """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph else a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    autograph_options: additional knobs to control when `autograph=True`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start
      with. If not specified (None), the FuncGraph will read (but not write to)
      the outer graph's collections that are not whitelisted, and both
      read and write to the outer graph's collections that are whitelisted.
      The current whitelisted collections are the global variables, the
      local variables, and the trainable variables.
      Defaults to None.
    capture_by_value: An optional boolean. If True, the func graph will capture
      Variables by value instead of reference. By default inherit from outer
      graphs, and failing that will default to False.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
  """
    if op_return_value is not None:
        assert isinstance(op_return_value, ops.Tensor), op_return_value
    if func_graph is None:
        func_graph = FuncGraph(name,
                               collections=collections,
                               capture_by_value=capture_by_value)
    assert isinstance(func_graph, FuncGraph)
    if add_control_dependencies:
        control_manager = AutomaticControlDependencies()
    else:
        control_manager = ops.NullContextmanager()
    with func_graph.as_default(), control_manager as a:
        current_scope = variable_scope.get_variable_scope()
        default_use_recource = current_scope.use_resource
        current_scope.set_use_resource(True)

        if signature is not None:
            args = signature
            kwargs = {}

        # Creates and names placeholders for all arguments.
        func_args = _get_defun_inputs_from_args(args, arg_names)
        func_kwargs = _get_defun_inputs_from_kwargs(kwargs)

        # Convert all Tensors into TensorSpecs before saving the structured inputs.
        # If storing pure concrete functions that are not called through polymorphic
        # functions, we don't have access to FunctionSpec, so we need to call the
        # TensorSpecs by their `arg_names` for later binding.
        func_graph.structured_input_signature = (
            convert_structure_to_signature(func_args, arg_names),
            convert_structure_to_signature(func_kwargs))

        flat_func_args = nest.flatten(func_args)
        flat_func_kwargs = nest.flatten(func_kwargs)
        # Temporarily set inputs to allow graph building code to inspect
        # them. Reassigned below.
        func_graph.inputs = [
            arg for arg in flat_func_args + flat_func_kwargs
            if isinstance(arg, ops.Tensor)
        ]

        # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
        # Variables to help check whether mutation happens in calling the function
        # Copy the recursive list, tuple and map structure, but not base objects
        func_args_before = nest.pack_sequence_as(func_args, flat_func_args)
        func_kwargs_before = nest.pack_sequence_as(func_kwargs,
                                                   flat_func_kwargs)

        def convert(x):
            """Converts a function output to a Tensor."""
            if x is None:
                return None
            if op_return_value is not None and isinstance(x, ops.Operation):
                # TODO(b/79881896): we currently can't capture external control deps, so
                # this won't work if x needs to be captured (i.e. if python_func returns
                # captured Operations).
                with ops.control_dependencies([x]):
                    x = array_ops.identity(op_return_value)
            elif not isinstance(x, tensor_array_ops.TensorArray):
                try:
                    x = ops.convert_to_tensor_or_composite(x)
                except (ValueError, TypeError):
                    raise TypeError(
                        "To be compatible with tf.contrib.eager.defun, Python functions "
                        "must return zero or more Tensors; in compilation of %s, found "
                        "return value of type %s, which is not a Tensor." %
                        (str(python_func), type(x)))
            if add_control_dependencies:
                x = a.mark_as_return(x)
            return x

        this_tape = tape.push_new_tape()
        try:
            if autograph:
                from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
                _, original_func = tf_decorator.unwrap(python_func)

                def wrapper(*args, **kwargs):
                    # Note: functions annotated with @tf.function should always be
                    # converted even though they would meet autograph's whitelisting
                    # criteria.
                    # If this assumption is ever broken, converted_call will need to
                    # handle the possibility of original_func still being a shim, e.g.
                    # bound to WeakrefSelf.
                    return autograph.converted_call(
                        original_func, None,
                        autograph.ConversionOptions(
                            verbose=autograph.Verbosity.BRIEF,
                            recursive=True,
                            strip_decorators=(def_function.function, ),
                            optional_features=autograph_options,
                            force_conversion=True,
                        ), args, kwargs)

                # Wrapping around a decorator allows checks like tf_inspect.getargspec
                # to be accurate.
                converted_func = tf_decorator.make_decorator(
                    original_func, wrapper)
                tf_decorator.rewrap(python_func, original_func, converted_func)

            func_outputs = python_func(*func_args, **func_kwargs)

            # invariant: `func_outputs` contains only Tensors, IndexedSlices,
            # SparseTensors, TensorArrays and `None`s.
            func_outputs = nest.map_structure(convert, func_outputs)

            check_mutation(func_args_before, func_args)
            check_mutation(func_kwargs_before, func_kwargs)
        finally:
            tape.pop_tape(this_tape)
            current_scope.set_use_resource(default_use_recource)

        # Variables in `func_args`, `func_kwargs` should be explicit inputs
        # to the function, not captured inputs.
        tape_variables = this_tape.watched_variables()
        arg_variables = set()
        inputs = []
        for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
            if isinstance(arg, resource_variable_ops.ResourceVariable):
                # Even if an argument variable was not used in the function, we've
                # already manually captured the resource Tensor when creating argument
                # placeholders.
                resource_placeholder = func_graph.captures.pop(
                    arg.handle, None)
                if resource_placeholder is None:
                    continue
                arg_variables.add(arg)
                inputs.append(resource_placeholder)
            elif isinstance(arg, ops.Tensor):
                inputs.append(arg)
        variables = [v for v in tape_variables if v not in arg_variables]
        func_graph.inputs = inputs + list(func_graph.captures.values())

        func_graph.structured_outputs = func_outputs
        # Returning a closed-over tensor does not trigger convert_to_tensor.
        func_graph.outputs.extend(
            func_graph.capture(x)
            for x in flatten(func_graph.structured_outputs) if x is not None)

        func_graph.variables = variables

    if add_control_dependencies:
        func_graph.control_outputs.extend(control_manager.ops_which_must_run)

    # Register any other functions defined in the graph.
    with ops.init_scope():
        if context.executing_eagerly():
            for f in func_graph._functions.values():  # pylint: disable=protected-access
                # TODO(ashankar): What about the gradient registry?
                context.add_function(f._c_func.func)  # pylint: disable=protected-access

    return func_graph
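When autograph is enabled, converted_call rewrites Python control flow into graph ops while the tape-wrapped trace runs. A small sketch of the observable behaviour through tf.function (sum_to is an illustrative name):

import tensorflow as tf

@tf.function
def sum_to(n):
    total = tf.constant(0)
    # AutoGraph turns this Python loop into a graph while-loop during tracing.
    for i in tf.range(n):
        total += i
    return total

print(sum_to(tf.constant(5)).numpy())  # 0 + 1 + 2 + 3 + 4 = 10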
Example #32
 def _push_tape(self):
   if self._recording:
     raise ValueError("Tape is already recording.")
   self._tape = tape.push_new_tape(persistent=self._persistent)
   self._recording = True
Example #33
 def __enter__(self):
   self._tape = tape.push_new_tape(persistent=self._persistent)
   return self
Example #34
 def __enter__(self):
   tape.push_new_tape()
   return self
Example #35
 def __enter__(self):
     tape.push_new_tape(persistent=self._persistent)
     return self
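The persistent flag forwarded to push_new_tape is what lets a tape be differentiated more than once. A brief public-API sketch:

import tensorflow as tf

x = tf.constant(4.0)
with tf.GradientTape(persistent=True) as g:
    g.watch(x)
    y = x * x
    z = y * y

print(g.gradient(y, x).numpy())  # 8.0
print(g.gradient(z, x).numpy())  # 256.0, since z = x**4
del g  # release the persistent tape's resources when finished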
Example #36
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            experimental_autograph=False,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None):
    """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph else a new one is built and returned.
    experimental_autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
  """
    if op_return_value is not None:
        assert isinstance(op_return_value, ops.Tensor), op_return_value
    if func_graph is None:
        func_graph = FuncGraph(name)
    assert isinstance(func_graph, FuncGraph)
    if add_control_dependencies:
        control_manager = AutomaticControlDependencies
    else:
        control_manager = ops.NullContextmanager
    with func_graph.as_default(), control_manager() as a:
        current_scope = variable_scope.get_variable_scope()
        default_use_recource = current_scope.use_resource
        current_scope.set_use_resource(True)

        if signature is not None:
            args = signature
            kwargs = {}

        func_args = _get_defun_inputs_from_args(args, arg_names)
        func_kwargs = _get_defun_inputs_from_kwargs(kwargs)

        # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
        # Variables to help check whether mutation happens in calling the function
        # Copy the recursive list, tuple and map structure, but not base objects
        func_args_before = nest.pack_sequence_as(func_args,
                                                 nest.flatten(func_args))
        func_kwargs_before = nest.pack_sequence_as(func_kwargs,
                                                   nest.flatten(func_kwargs))

        def convert(x):
            """Converts a function output to a Tensor."""
            if x is None:
                return None
            if op_return_value is not None and isinstance(x, ops.Operation):
                # TODO(b/79881896): we currently can't capture external control deps, so
                # this won't work if x needs to be captured (i.e. if python_func returns
                # captured Operations).
                with ops.control_dependencies([x]):
                    x = array_ops.identity(op_return_value)
            else:
                try:
                    x = ops.convert_to_tensor_or_indexed_slices(x)
                except (ValueError, TypeError):
                    raise TypeError(
                        "To be compatible with tf.contrib.eager.defun, Python functions "
                        "must return zero or more Tensors; in compilation of %s, found "
                        "return value of type %s, which is not a Tensor." %
                        (str(python_func), type(x)))
            if add_control_dependencies:
                x = a.mark_as_return(x)
            return x

        this_tape = tape.push_new_tape()
        try:
            if experimental_autograph:
                func_outputs = autograph.converted_call(
                    python_func, None,
                    autograph.ConversionOptions(
                        verbose=True,
                        recursive=True,
                        strip_decorators=(function.defun, ),
                        optional_features=(),
                    ), *func_args, **func_kwargs)
            else:
                func_outputs = python_func(*func_args, **func_kwargs)
            # invariant: `func_outputs` contains only Tensors and `None`s.
            func_outputs = nest.map_structure(convert, func_outputs)

            check_mutation(func_args_before, func_args)
            check_mutation(func_kwargs_before, func_kwargs)
        finally:
            tape.pop_tape(this_tape)
            current_scope.set_use_resource(default_use_recource)

        # Variables in `func_args`, `func_kwargs` should be explicit inputs
        # to the function, not captured inputs.
        tape_variables = this_tape.watched_variables()
        arg_variables = set()
        inputs = []
        for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
            if isinstance(arg, resource_variable_ops.ResourceVariable):
                try:
                    resource_placeholder = func_graph.captures.pop(arg.handle)
                    arg_variables.add(arg)
                except KeyError:
                    # This case occurs if a Variable among the inputs is not actually
                    # used by the function; we still add an explicit input for it
                    # because the user should presumably pass the Variable as an input
                    # to the corresponding graph function.
                    resource_placeholder = _create_substitute_placeholder(
                        arg.handle)
                inputs.append(resource_placeholder)
            elif isinstance(arg, ops.Tensor):
                inputs.append(arg)
        variables = [v for v in tape_variables if v not in arg_variables]
        func_graph.inputs = inputs + list(func_graph.captures.values())

        func_graph.structured_outputs = func_outputs
        # Returning a closed-over tensor does not trigger convert_to_tensor.
        func_graph.outputs.extend(
            func_graph.capture(x)
            for x in flatten(func_graph.structured_outputs) if x is not None)

        func_graph.variables = variables

    # Register any other functions defined in the graph.
    with ops.init_scope():
        if context.executing_eagerly():
            for f in func_graph._functions.values():  # pylint: disable=protected-access
                # TODO(ashankar): What about the gradient registry?
                context.add_function(f._c_func.func)  # pylint: disable=protected-access

    return func_graph
Example #37
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None):
  """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph else a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start
      with. If not specified (None), the FuncGraph will read (but not write to)
      the outer graph's collections that are not whitelisted, and both
      read and write to the outer graph's collections that are whitelisted.
      The current whitelisted collections are the global variables, the
      local variables, and the trainable variables.
      Defaults to None.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
  """
  if op_return_value is not None:
    assert isinstance(op_return_value, ops.Tensor), op_return_value
  if func_graph is None:
    func_graph = FuncGraph(name, collections=collections)
  assert isinstance(func_graph, FuncGraph)
  if add_control_dependencies:
    control_manager = AutomaticControlDependencies
  else:
    control_manager = ops.NullContextmanager
  with func_graph.as_default(), control_manager() as a:
    current_scope = variable_scope.get_variable_scope()
    default_use_recource = current_scope.use_resource
    current_scope.set_use_resource(True)

    if signature is not None:
      args = signature
      kwargs = {}

    # Creates and names placeholders for all arguments.
    func_args = _get_defun_inputs_from_args(args, arg_names)
    func_kwargs = _get_defun_inputs_from_kwargs(kwargs)

    # Convert all Tensors into TensorSpecs before saving the structured inputs.
    # If storing pure concrete functions that are not called through polymorphic
    # functions, we don't have access to FunctionSpec, so we need to call the
    # TensorSpecs by their `arg_names` for later binding.
    func_graph.structured_input_signature = (
        convert_structure_to_signature(func_args, arg_names),
        convert_structure_to_signature(func_kwargs))

    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
    # Variables to help check whether mutation happens in calling the function
    # Copy the recursive list, tuple and map structure, but not base objects
    func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
    func_kwargs_before = nest.pack_sequence_as(
        func_kwargs, nest.flatten(func_kwargs))

    def convert(x):
      """Converts a function output to a Tensor."""
      if x is None:
        return None
      if op_return_value is not None and isinstance(x, ops.Operation):
        # TODO(b/79881896): we currently can't capture external control deps, so
        # this won't work if x needs to be captured (i.e. if python_func returns
        # captured Operations).
        with ops.control_dependencies([x]):
          x = array_ops.identity(op_return_value)
      elif not isinstance(x, tensor_array_ops.TensorArray):
        try:
          x = ops.convert_to_tensor_or_composite(x)
        except (ValueError, TypeError):
          raise TypeError(
              "To be compatible with tf.contrib.eager.defun, Python functions "
              "must return zero or more Tensors; in compilation of %s, found "
              "return value of type %s, which is not a Tensor." %
              (str(python_func), type(x)))
      if add_control_dependencies:
        x = a.mark_as_return(x)
      return x

    this_tape = tape.push_new_tape()
    try:
      if autograph:
        from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
        _, original_func = tf_decorator.unwrap(python_func)

        def wrapper(*args, **kwargs):
          # Note: functions annotated with @tf.function should always be
          # converted even though they would meet autograph's whitelisting
          # criteria.
          # If this assumption is ever broken, converted_call will need to
          # handle the possibility of original_func still being a shim, e.g.
          # bound to WeakrefSelf.
          return autograph.converted_call(
              original_func, None,
              autograph.ConversionOptions(
                  verbose=autograph.Verbosity.BRIEF,
                  recursive=True,
                  strip_decorators=(def_function.function,),
                  optional_features=(),
                  force_conversion=True,
              ), *args, **kwargs)

        # Wrapping around a decorator allows checks like tf_inspect.getargspec
        # to be accurate.
        converted_func = tf_decorator.make_decorator(original_func, wrapper)
        tf_decorator.rewrap(python_func, original_func, converted_func)

      func_outputs = python_func(*func_args, **func_kwargs)

      # invariant: `func_outputs` contains only Tensors, IndexedSlices,
      # SparseTensors, TensorArrays and `None`s.
      func_outputs = nest.map_structure(convert, func_outputs)

      check_mutation(func_args_before, func_args)
      check_mutation(func_kwargs_before, func_kwargs)
    finally:
      tape.pop_tape(this_tape)
      current_scope.set_use_resource(default_use_recource)

    # Variables in `func_args`, `func_kwargs` should be explicit inputs
    # to the function, not captured inputs.
    tape_variables = this_tape.watched_variables()
    arg_variables = set()
    inputs = []
    for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
      if isinstance(arg, resource_variable_ops.ResourceVariable):
        # Even if an argument variable was not used in the function, we've
        # already manually captured the resource Tensor when creating argument
        # placeholders.
        resource_placeholder = func_graph.captures.pop(arg.handle)
        arg_variables.add(arg)
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    variables = [v for v in tape_variables if v not in arg_variables]
    func_graph.inputs = inputs + list(func_graph.captures.values())

    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = variables

  # Register any other functions defined in the graph.
  with ops.init_scope():
    if context.executing_eagerly():
      for f in func_graph._functions.values():  # pylint: disable=protected-access
        # TODO(ashankar): What about the gradient registry?
        context.add_function(f._c_func.func)  # pylint: disable=protected-access

  return func_graph
Example #38
 def __enter__(self):
   tape.push_new_tape()
   return self