Code Example #1
  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        any(tensor is t for t in self._forward_graph.inputs) or
        any(tensor is t for t in self._forward_graph.outputs)):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if tensor not in self.external_captures:
        self.xla_intermediates.append(tensor)
        self.op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    tensor_id = ops.tensor_id(tensor)
    captured_tensor = self._indirect_captures.get(tensor_id)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the
    # forward graph and capture that.

    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor_id not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              consumer.outputs[0] in self._forward_graph.outputs):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_dataset_ops.optional_from_value([tensor])
          self.op_needs_rewrite = True
        self._wrapped_intermediates[tensor_id] = optional

      optional = self._wrapped_intermediates[tensor_id]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      captured_tensor = gen_dataset_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape])[0]

    self._indirect_captures[tensor_id] = captured_tensor
    return captured_tensor
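For reference, the same wrap/unwrap round trip is exposed through the public tf.data API. Below is a minimal sketch (assuming TF 2.x eager mode, where tf.experimental.Optional.from_value and Optional.get_value sit on top of the optional_from_value/optional_get_value ops used above):

import tensorflow as tf

# Box a tensor into an optional variant, then unbox it to recover the value.
opt = tf.experimental.Optional.from_value(tf.constant([1.0, 2.0]))
print(opt.has_value())  # tf.Tensor(True, shape=(), dtype=bool)
print(opt.get_value())  # tf.Tensor([1. 2.], shape=(2,), dtype=float32)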
Code Example #2
File: cond_v2.py Project: aritratony/tensorflow
  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        tensor in self._forward_graph.inputs or
        tensor in self._forward_graph.outputs):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if tensor not in self.captures:
        self.xla_intermediates.append(tensor)
        self.if_op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    captured_tensor = self._indirect_captures.get(tensor)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the
    # forward graph and capture that.

    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              consumer.outputs[0] in self._forward_graph.outputs):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_dataset_ops.optional_from_value([tensor])
          self.if_op_needs_rewrite = True
        self._wrapped_intermediates[tensor] = optional

      optional = self._wrapped_intermediates[tensor]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      captured_tensor = gen_dataset_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape])[0]

    self._indirect_captures[tensor] = captured_tensor
    return captured_tensor
Code Example #3
 def get_value(self, name=None):
     # TODO(b/110122868): Consolidate the restructuring logic with similar logic
     # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
     with ops.name_scope(name, "OptionalGetValue",
                         [self._variant_tensor]) as scope:
         # pylint: disable=protected-access
         return self._value_structure._from_tensor_list(
             gen_dataset_ops.optional_get_value(
                 self._variant_tensor,
                 name=scope,
                 output_types=self._value_structure._flat_types,
                 output_shapes=self._value_structure._flat_shapes))
Code Example #4
 def get_value(self, name=None):
   # TODO(b/110122868): Consolidate the restructuring logic with similar logic
   # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
   with ops.name_scope(name, "OptionalGetValue",
                       [self._variant_tensor]) as scope:
     # pylint: disable=protected-access
     return self._value_structure._from_tensor_list(
         gen_dataset_ops.optional_get_value(
             self._variant_tensor,
             name=scope,
             output_types=self._value_structure._flat_types,
             output_shapes=self._value_structure._flat_shapes))
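In practice, get_value is usually reached through iterator helpers rather than called directly. A minimal usage sketch (assuming TF 2.x eager mode, where Iterator.get_next_as_optional returns an Optional instead of raising OutOfRangeError at the end of the dataset):

import tensorflow as tf

iterator = iter(tf.data.Dataset.range(3))
opt = iterator.get_next_as_optional()
while opt.has_value():
  print(opt.get_value().numpy())  # prints 0, 1, 2
  opt = iterator.get_next_as_optional()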
Code Example #5
 def get_value(self, name=None):
     # TODO(b/110122868): Consolidate the restructuring logic with similar logic
     # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
     with ops.name_scope(name, "OptionalGetValue",
                         [self._variant_tensor]) as scope:
         return structure.from_tensor_list(
             self._value_structure,
             gen_dataset_ops.optional_get_value(
                 self._variant_tensor,
                 name=scope,
                 output_types=structure.get_flat_tensor_types(
                     self._value_structure),
                 output_shapes=structure.get_flat_tensor_shapes(
                     self._value_structure)))
Code Example #6
 def get_value(self, name=None):
   # TODO(b/110122868): Consolidate the restructuring logic with similar logic
   # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
   with ops.name_scope(name, "OptionalGetValue",
                       [self._variant_tensor]) as scope:
     with ops.colocate_with(self._variant_tensor):
       result = gen_dataset_ops.optional_get_value(
           self._variant_tensor,
           name=scope,
           output_types=structure.get_flat_tensor_types(self._element_spec),
           output_shapes=structure.get_flat_tensor_shapes(self._element_spec))
     # NOTE: We do not colocate the deserialization of composite tensors
     # because not all ops are guaranteed to have non-GPU kernels.
     return structure.from_tensor_list(self._element_spec, result)
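This later revision reads the flat types and shapes from element_spec and colocates only the raw optional_get_value op with the variant tensor, leaving the reassembly of composite tensors uncolocated. A hedged sketch of the structured round trip from the public API (assuming tf.experimental.Optional handles nested values, which from_tensor_list reassembles):

import tensorflow as tf

# Structured values are flattened to a tensor list on the way in and
# reassembled from that list on the way out.
opt = tf.experimental.Optional.from_value({"x": tf.constant(1)})
print(opt.element_spec)  # {'x': TensorSpec(shape=(), dtype=tf.int32)}
print(opt.get_value())   # {'x': <tf.Tensor: ... numpy=1>}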
Code Example #7
    def _capture_helper(self, tensor, name):
        if (tensor.graph is not self._forward_graph
                or tensor in self._forward_graph.inputs
                or tensor in self._forward_graph.outputs):
            return super(_CondGradFuncGraph,
                         self)._capture_helper(tensor, name)

        if control_flow_util.InXlaContext(ops.get_default_graph()):
            # XLA does not yet support optionals, so capture intermediates directly.
            # TODO(skyewm,jpienaar): can XLA support optionals?
            if tensor not in self.captures:
                self.xla_intermediates.append(tensor)
                self.if_op_needs_rewrite = True
            return super(_CondGradFuncGraph,
                         self)._capture_helper(tensor, name)

        captured_tensor = self._indirect_captures.get(tensor)
        if captured_tensor is not None:
            return captured_tensor

        # 'tensor' is an uncaptured intermediate in the forward graph. We wrap it in
        # an optional in the forward graph and capture the optional normally. We
        # then unwrap the captured optional value in the gradient graph to get the
        # raw intermediate value.

        if tensor not in self._wrapped_intermediates:
            # If the gradient has already been computed for this If op, 'tensor' may
            # already be wrapped.
            for consumer in tensor.consumers():
                if (consumer.type == "OptionalFromValue" and
                        consumer.outputs[0] in self._forward_graph.outputs):
                    optional = consumer.outputs[0]
                    break
            else:
                # 'tensor' hasn't been wrapped, do it now.
                with self._forward_graph.as_default():
                    optional = gen_dataset_ops.optional_from_value([tensor])
                self.if_op_needs_rewrite = True

            self._wrapped_intermediates[tensor] = optional

        optional = self._wrapped_intermediates[tensor]
        captured_optional = super(_CondGradFuncGraph,
                                  self)._capture_helper(optional, name)
        captured_tensor = gen_dataset_ops.optional_get_value(
            captured_optional, [tensor.dtype], [tensor.shape])[0]
        self._indirect_captures[tensor] = captured_tensor
        return captured_tensor
Code Example #8
File: cond_v2.py Project: terrytangyuan/tensorflow
  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        tensor in self._forward_graph.inputs or
        tensor in self._forward_graph.outputs):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    if control_flow_util.InXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if tensor not in self.captures:
        self.xla_intermediates.append(tensor)
        self.if_op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    captured_tensor = self._indirect_captures.get(tensor)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph. We wrap it in
    # an optional in the forward graph and capture the optional normally. We
    # then unwrap the captured optional value in the gradient graph to get the
    # raw intermediate value.

    if tensor not in self._wrapped_intermediates:
      # If the gradient has already been computed for this If op, 'tensor' may
      # already be wrapped.
      for consumer in tensor.consumers():
        if (consumer.type == "OptionalFromValue"
            and consumer.outputs[0] in self._forward_graph.outputs):
          optional = consumer.outputs[0]
          break
      else:
        # 'tensor' hasn't been wrapped, do it now.
        with self._forward_graph.as_default():
          optional = gen_dataset_ops.optional_from_value([tensor])
        self.if_op_needs_rewrite = True

      self._wrapped_intermediates[tensor] = optional

    optional = self._wrapped_intermediates[tensor]
    captured_optional = super(_CondGradFuncGraph, self)._capture_helper(
        optional, name)
    captured_tensor = gen_dataset_ops.optional_get_value(
        captured_optional, [tensor.dtype], [tensor.shape])[0]
    self._indirect_captures[tensor] = captured_tensor
    return captured_tensor
Code Example #9
File: optional_ops.py Project: zpdcqu/tensorflow
 def get_value(self, name=None):
   # TODO(b/110122868): Consolidate the restructuring logic with similar logic
   # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
   with ops.name_scope(name, "OptionalGetValue",
                       [self._variant_tensor]) as scope:
     return sparse.deserialize_sparse_tensors(
         nest.pack_sequence_as(
             self._output_types,
             gen_dataset_ops.optional_get_value(
                 self._variant_tensor,
                 name=scope,
                 output_types=nest.flatten(
                     sparse.as_dense_types(self._output_types,
                                           self._output_classes)),
                 output_shapes=nest.flatten(
                     sparse.as_dense_shapes(self._output_shapes,
                                            self._output_classes)))),
         self._output_types, self._output_shapes, self._output_classes)
Code Example #10
File: optional_ops.py Project: AnishShah/tensorflow
 def get_value(self, name=None):
   # TODO(b/110122868): Consolidate the restructuring logic with similar logic
   # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
   with ops.name_scope(name, "OptionalGetValue",
                       [self._variant_tensor]) as scope:
     return sparse.deserialize_sparse_tensors(
         nest.pack_sequence_as(
             self._output_types,
             gen_dataset_ops.optional_get_value(
                 self._variant_tensor,
                 name=scope,
                 output_types=nest.flatten(
                     sparse.as_dense_types(self._output_types,
                                           self._output_classes)),
                 output_shapes=nest.flatten(
                     sparse.as_dense_shapes(self._output_shapes,
                                            self._output_classes)))),
         self._output_types, self._output_shapes, self._output_classes)
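These two older revisions predate the structure module and instead deserialize sparse tensors explicitly after unpacking the flat tensor list. A hedged sketch of the behavior they implement, using the current public API (assuming Optional.from_value accepts composite values such as SparseTensor):

import tensorflow as tf

st = tf.sparse.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 2])
opt = tf.experimental.Optional.from_value(st)
print(opt.get_value())  # a SparseTensor again, rebuilt from its serialized form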
Code Example #11
File: cond_v2.py Project: aeverall/tensorflow
  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        tensor in self._forward_graph.inputs or
        tensor in self._forward_graph.outputs):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    # 'tensor' is an intermediate in the forward graph. We find the
    # corresponding optional tensor, which is output from the If op, and
    # capture it as normal. We then unwrap the captured optional value to get
    # the raw intermediate value.
    for consumer in tensor.consumers():
      if (consumer.type == "OptionalFromValue"
          and consumer.outputs[0] in self._forward_graph.outputs):
        optional = consumer.outputs[0]
        captured_optional = super(_CondGradFuncGraph, self)._capture_helper(
            optional, name)
        return gen_dataset_ops.optional_get_value(
            captured_optional, [tensor.dtype], [tensor.shape])[0]
    raise ValueError(
        "Couldn't find OptionalFromValue consumer for tensor '%s'.\n"
        "This is an internal bug, please report at "
        "https://github.com/tensorflow/tensorflow/issues." % tensor.name)
Code Example #12
    def _capture_helper(self, tensor, name):
        if (tensor.graph is not self._forward_graph
                or tensor in self._forward_graph.inputs
                or tensor in self._forward_graph.outputs):
            return super(_CondGradFuncGraph,
                         self)._capture_helper(tensor, name)

        # 'tensor' is an intermediate in the forward graph. We find the
        # corresponding optional tensor, which is output from the If op, and
        # capture it as normal. We then unwrap the captured optional value to
        # get the raw intermediate value.
        for consumer in tensor.consumers():
            if (consumer.type == "OptionalFromValue"
                    and consumer.outputs[0] in self._forward_graph.outputs):
                optional = consumer.outputs[0]
                captured_optional = super(_CondGradFuncGraph,
                                          self)._capture_helper(
                                              optional, name)
                return gen_dataset_ops.optional_get_value(
                    captured_optional, [tensor.dtype], [tensor.shape])[0]
        raise ValueError(
            "Couldn't find OptionalFromValue consumer for tensor '%s'.\n"
            "This is an internal bug, please report at "
            "https://github.com/tensorflow/tensorflow/issues." % tensor.name)
Code Example #13
def _OptionalFromValueGrad(op, grad):
    return gen_dataset_ops.optional_get_value(grad,
                                              [t.dtype for t in op.inputs],
                                              [t.shape for t in op.inputs])
Code Example #14
File: optional_grad.py Project: Wajih-O/tensorflow
def _OptionalFromValueGrad(op, grad):
  return gen_dataset_ops.optional_get_value(
      grad, [t.dtype for t in op.inputs], [t.shape for t in op.inputs])
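This registration is what makes the wrap/unwrap pair differentiable: the gradient of OptionalFromValue is just OptionalGetValue applied to the incoming optional-valued gradient, with dtypes and shapes taken from the op's inputs. A hedged end-to-end sketch (assuming your TF build also registers a gradient for OptionalGetValue, so GradientTape can trace through the round trip):

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  # Route x through an optional and back, then apply a differentiable op.
  y = tf.experimental.Optional.from_value(x).get_value() * 2.0
print(tape.gradient(y, x))  # expected: tf.Tensor(2.0, shape=(), dtype=float32)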