Example 1
  def _resource_capture_helper(self, tensor):
    """Returns the captured resource tensor.

    Resource-type tensors are not accumulated. If a resource tensor exists in
    the loop body, it must either be a loop input or an output of a nested While
    op inside the loop body that captured the external resource.

    Args:
      tensor: the external resource Tensor to be captured.

    Returns:
      Tensor in this graph.
    """
    assert tensor.dtype == dtypes.resource

    # Index of the forward graph input corresponding to the resource tensor.
    index = util.resource_input_index(
        tensor.name, [t.name for t in self._forward_graph.inputs],
        {op.name: op.node_def for op in self._forward_graph.get_operations()},
        self._forward_graph._functions)

    input_placeholder = self._forward_graph.inputs[index]
    tensor_in_outer_graph = self._forward_graph._while.inputs[index]

    assert input_placeholder.dtype == dtypes.resource
    assert tensor_in_outer_graph.dtype == dtypes.resource
    # This must be a loop invariant.
    assert input_placeholder is self._forward_graph.outputs[index], (
        "Resource tensors must be loop invariants %s." % tensor_in_outer_graph)

    self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
        tensor_in_outer_graph, whitelisted=True)
    return self._indirect_captures[ops.tensor_id(tensor)]
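
For context, a minimal sketch of the situation this helper supports, assuming nothing beyond public TF 2.x APIs (the names and values below are illustrative, not from the source): a tf.Variable handle is a resource tensor, the loop body captures it as a loop-invariant input, and computing the gradient requires re-capturing that same handle in the backward graph.

import tensorflow as tf

v = tf.Variable(2.0)

@tf.function
def loop(x):
  # The body reads `v`; its resource handle becomes a loop input that flows
  # through every iteration unchanged, i.e. a loop invariant.
  def cond(i, acc):
    return i < 3
  def body(i, acc):
    return i + 1, acc * v
  _, result = tf.while_loop(cond, body, [tf.constant(0), x])
  return result

with tf.GradientTape() as tape:
  y = loop(tf.constant(1.0))  # y = x * v**3
print(tape.gradient(y, v))    # 3 * v**2 = 12.0
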
Example 2
  def _resource_capture_helper(self, tensor):
    """Returns the captured resource tensor.

    Resource-type tensors are not accumulated. If a resource tensor exists in
    the loop body, it must either be a loop input or an output of a nested While
    op inside the loop body that captured the external resource.

    Args:
      tensor: the external resource Tensor to be captured.

    Returns:
      Tensor in this graph.
    """
    assert tensor.dtype == dtypes.resource

    # Index of the forward graph input corresponding to the resource tensor.
    index = util.resource_input_index(
        tensor.name, [t.name for t in self._forward_graph.inputs],
        {op.name: op.node_def for op in self._forward_graph.get_operations()},
        self._forward_graph._functions)

    input_placeholder = self._forward_graph.inputs[index]
    tensor_in_outer_graph = self._forward_graph._while.inputs[index]

    assert input_placeholder.dtype == dtypes.resource
    assert tensor_in_outer_graph.dtype == dtypes.resource
    # This must be a loop invariant.
    assert input_placeholder == self._forward_graph.outputs[index], (
        "Resource tensors must be loop invariants %s." %
        tensor_in_outer_graph)

    self._indirect_captures[tensor] = self.capture(
        tensor_in_outer_graph, whitelisted=True)
    return self._indirect_captures[tensor]
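
The two versions above differ mainly in how _indirect_captures is keyed: the first uses ops.tensor_id(tensor) (and `is` in the invariant assert), the second uses the Tensor itself (and `==`). A plausible reason, stated here as an assumption rather than a fact from the source, is TF2's tensor-equality change, under which Tensors compare elementwise and are no longer hashable; the public counterpart of an id-based key is Tensor.ref(). A minimal sketch:

import tensorflow as tf

x = tf.constant([1.0, 2.0])
y = tf.constant([1.0, 2.0])
print(x == y)        # elementwise under TF2: [True True], not identity
d = {x.ref(): "a"}   # Tensor.ref() is a stable, hashable key
print(d[x.ref()])    # "a"; keying by `x` itself raises TypeError in TF2
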
Example 3
  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        any(tensor is t for t in self._forward_graph.inputs) or
        any(tensor is t for t in self._forward_graph.outputs)):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if tensor not in self.external_captures:
        self.xla_intermediates.append(tensor)
        self.op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    tensor_id = ops.tensor_id(tensor)
    captured_tensor = self._indirect_captures.get(tensor_id)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the forward
    # graph and capture that.

    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor_id not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              consumer.outputs[0] in self._forward_graph.outputs):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_dataset_ops.optional_from_value([tensor])
          self.op_needs_rewrite = True
        self._wrapped_intermediates[tensor_id] = optional

      optional = self._wrapped_intermediates[tensor_id]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      captured_tensor = gen_dataset_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape])[0]

    self._indirect_captures[tensor_id] = captured_tensor
    return captured_tensor
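
The wrap/unwrap step in the non-resource branch has a public counterpart in tf.experimental.Optional, which wraps the same optional_from_value/optional_get_value ops. A minimal sketch of the round trip (values illustrative):

import tensorflow as tf

value = tf.constant([1.0, 2.0])
opt = tf.experimental.Optional.from_value(value)  # cf. optional_from_value
print(opt.has_value())  # True
print(opt.get_value())  # the raw value back, cf. optional_get_value
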
Example 4
  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        tensor in self._forward_graph.inputs or
        tensor in self._forward_graph.outputs):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if tensor not in self.captures:
        self.xla_intermediates.append(tensor)
        self.if_op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    captured_tensor = self._indirect_captures.get(tensor)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the forward
    # graph and capture that.

    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              consumer.outputs[0] in self._forward_graph.outputs):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_dataset_ops.optional_from_value([tensor])
          self.if_op_needs_rewrite = True
        self._wrapped_intermediates[tensor] = optional

      optional = self._wrapped_intermediates[tensor]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      captured_tensor = gen_dataset_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape])[0]

    self._indirect_captures[tensor] = captured_tensor
    return captured_tensor
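
End to end, this path is exercised when a gradient flows through a traced tf.cond whose branch produces a true intermediate. A minimal sketch, assuming the non-XLA path (function and values illustrative): tf.exp(x) below is neither a branch input nor a branch output, so its gradient-time capture goes through the optional wrapping shown above.

import tensorflow as tf

@tf.function
def f(x):
  # grad(exp(x) * 2) needs exp(x), an intermediate of the true branch.
  return tf.cond(x > 0.0, lambda: tf.exp(x) * 2.0, lambda: -x)

x = tf.constant(1.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = f(x)
print(tape.gradient(y, x))  # 2 * exp(1) ≈ 5.44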