Example #1
  def _capture_helper(self, tensor, name):
    if tensor.graph is not self._forward_graph:
      return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

    while tensor.op.type == "Identity":
      # We do not accumulate the output of identity nodes, so we try to capture
      # the input of the Identity node instead.
      tensor = tensor.op.inputs[0]

    captured_tensor = self._indirect_captures.get(tensor)
    if captured_tensor is not None:
      return captured_tensor

    # Resource tensors are not accumulated; they are handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      #
      # Note: We clear the control dependencies to avoid a cycle in case a
      # control tensor has an input path to an output of the forward While.
      #
      # E.g.:
      # x = tf.while_loop(...)
      # y = f(x)
      # with tf.control_dependencies([y]):
      #   tf.gradients(y, x)
      #
      # Since the EmptyTensorList is fed back into the forward While, not
      # removing the control edge would cause a cycle.
      with self._forward_graph.outer_graph.as_default():
        with util.clear_control_inputs():
          tensor_list = list_ops.empty_tensor_list(
              element_dtype=tensor.dtype,
              element_shape=tensor.shape,
              max_num_elements=self._maximum_iterations,
              name=_build_accumulator_name(tensor))
      self.empty_tensor_lists.append(tensor_list)

      # Push the intermediate tensor to the tensor list. This captures
      # `tensor_list`.
      with self._forward_graph.as_default():
        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will
      # contain all the accumulated values.
      self._forward_graph.outputs.append(accumulator)

      # Capture in the cond graph as well so the forward cond and body inputs
      # match.
      with self._forward_cond_graph.as_default():
        self._forward_cond_graph.capture(tensor_list)

    # Capture the accumulator tensor list in the gradient graph directly from
    # the forward graph -- we'll later modify this to capture the final list
    # output by the forward While op instead.
    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
        accumulator, name)

    # Pop the intermediate value from the tensor list in the gradient graph.
    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
        captured_accumulator, element_dtype=tensor.dtype)

    self._indirect_captures[tensor] = captured_tensor
    self.popped_tensor_lists[captured_accumulator] = new_tensor_list
    return captured_tensor
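
A minimal usage sketch of the scenario the helper above services, not part of the original listing: it assumes `tensorflow.compat.v1` graph mode with control flow v2 enabled, so that `tf.while_loop` lowers to a functional While op and its gradient is built by `_WhileBodyGradFuncGraph`. The gradient of `v * v` needs the intermediate `v` from every forward iteration, which is exactly what the TensorList accumulator created in `_capture_helper` supplies.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.enable_control_flow_v2()  # use the functional While op (while_v2)

x = tf.constant(2.0)
# d(v * v)/dv is 2 * v, so backprop needs the intermediate v from each
# forward iteration; _capture_helper accumulates it into a TensorList in the
# forward body and pops it back out in the gradient body.
_, y = tf.while_loop(
    cond=lambda i, v: i < 3,
    body=lambda i, v: (i + 1, v * v),
    loop_vars=(tf.constant(0), x))
grad = tf.gradients(y, x)  # builds the While gradient graph

with tf.Session() as sess:
  print(sess.run([y, grad]))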
Example #2
  def _capture_helper(self, tensor, name):
    if tensor.graph is not self._forward_graph:
      return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

    while tensor.op.type == "Identity":
      # We do not accumulate the output of identity nodes, so we try to capture
      # the input of the Identity node instead.
      tensor = tensor.op.inputs[0]

    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
    if captured_tensor is not None:
      return captured_tensor

    # Do not accumulate loop invariants.
    if (any(tensor is t for t in self._forward_graph.inputs) and
        any(tensor is t for t in self._forward_graph.outputs)):
      captured_tensor = super(_WhileBodyGradFuncGraph,
                              self)._capture_helper(tensor, name)
      # Add to `popped_tensor_lists` so that this gets added to the list of
      # outputs.
      # TODO(srbs): Rename popped_tensor_lists.
      self.popped_tensor_lists[ops.tensor_id(captured_tensor)] = captured_tensor
      self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    # Do not accumulate Const nodes. Instead, copy them directly into the
    # backward graph.
    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
    # graph compile time consts in general.
    # TODO(srbs): Consider making this a loop input.
    if constant_op.is_constant(tensor):
      real_value = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      self._indirect_captures[ops.tensor_id(tensor)] = real_value
      return real_value

    # Resource tensors are not accumulated; they are handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      #
      # Note: We clear the control dependencies to avoid a cycle in case a
      # control tensor has an input path to an output of the forward While.
      #
      # E.g.:
      # x = tf.while_loop(...)
      # y = f(x)
      # with tf.control_dependencies([y]):
      #   tf.gradients(y, x)
      #
      # Since the EmptyTensorList is fed back into the forward While, not
      # removing the control edge would cause a cycle.
      with self._forward_graph.outer_graph.as_default():
        with util.clear_control_inputs():
          tensor_list = list_ops.empty_tensor_list(
              element_dtype=tensor.dtype,
              element_shape=tensor.shape,
              max_num_elements=self._maximum_iterations,
              name=_build_accumulator_name(tensor))
      self.empty_tensor_lists.append(tensor_list)

      # Push the intermediate tensor to the tensor list. This captures
      # `tensor_list`.
      with self._forward_graph.as_default():
        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will
      # contain all the accumulated values.
      self._forward_graph.outputs.append(accumulator)

      # Capture in the cond graph as well so the forward cond and body inputs
      # match.
      with self._forward_cond_graph.as_default():
        self._forward_cond_graph.capture(tensor_list)

    # Capture the accumulator tensor list in the gradient graph directly from
    # the forward graph -- we'll later modify this to capture the final list
    # output by the forward While op instead.
    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
        accumulator, name)

    # Pop the intermediate value from the tensor list in the gradient graph.
    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
        captured_accumulator, element_dtype=tensor.dtype)

    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
    self.popped_tensor_lists[ops.tensor_id(
        captured_accumulator)] = new_tensor_list
    return captured_tensor
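
A hedged sketch of the loop-invariant fast path added in Example #2, under the same `tensorflow.compat.v1` assumptions as the sketch above; the loop variable `c_` below is illustrative. Because `c_` is threaded through the body unchanged, it is both a body input and a body output, so the gradient graph captures it directly instead of building a TensorList accumulator for it.

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
tf.enable_control_flow_v2()

c = tf.constant(3.0)
x = tf.constant(2.0)
_, _, y = tf.while_loop(
    cond=lambda i, c_, v: i < 3,
    body=lambda i, c_, v: (i + 1, c_, v * c_),  # c_ passes through unchanged
    loop_vars=(tf.constant(0), c, x))
# The backward pass needs c_ to compute d(v * c_)/dv; since c_ is
# loop-invariant, no accumulator is created for it.
print(tf.gradients(y, [x, c]))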