Example #1
  def from_value(value):
    """Returns an `Optional` that wraps the given value.

    Args:
      value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

    Returns:
      An `Optional` that wraps `value`.
    """
    # TODO(b/110122868): Consolidate this destructuring logic with the
    # similar code in `Dataset.from_tensors()`.
    with ops.name_scope("optional") as scope:
      with ops.name_scope("value"):
        value = nest.pack_sequence_as(value, [
            sparse_tensor_lib.SparseTensor.from_value(t)
            if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
                t, name="component_%d" % i)
            for i, t in enumerate(nest.flatten(value))
        ])

      encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
      output_classes = sparse.get_classes(value)
      output_shapes = nest.pack_sequence_as(
          value, [t.get_shape() for t in nest.flatten(value)])
      output_types = nest.pack_sequence_as(
          value, [t.dtype for t in nest.flatten(value)])

    return _OptionalImpl(
        gen_dataset_ops.optional_from_value(encoded_value, name=scope),
        output_shapes, output_types, output_classes)
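For context, a minimal usage sketch of the API this snippet implements. The `tf.data.experimental.Optional` entry point is an assumption here (the exact public alias varies across TensorFlow versions):

import tensorflow as tf

# Wrap a nested structure of tensors in an Optional, then query it.
opt = tf.data.experimental.Optional.from_value(
    {"dense": tf.constant([1, 2]),
     "sparse": tf.sparse.SparseTensor(
         indices=[[0, 0]], values=[7], dense_shape=[2, 2])})
print(opt.has_value())            # tf.Tensor(True, shape=(), dtype=bool)
print(opt.get_value()["dense"])   # tf.Tensor([1 2], shape=(2,), dtype=int32)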
Example #2
  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        any(tensor is t for t in self._forward_graph.inputs) or
        any(tensor is t for t in self._forward_graph.outputs)):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if tensor not in self.external_captures:
        self.xla_intermediates.append(tensor)
        self.op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    tensor_id = ops.tensor_id(tensor)
    captured_tensor = self._indirect_captures.get(tensor_id)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the forward
    # graph and capture that.

    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor_id not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              consumer.outputs[0] in self._forward_graph.outputs):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_dataset_ops.optional_from_value([tensor])
          self.op_needs_rewrite = True
        self._wrapped_intermediates[tensor_id] = optional

      optional = self._wrapped_intermediates[tensor_id]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      captured_tensor = gen_dataset_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape])[0]

    self._indirect_captures[tensor_id] = captured_tensor
    return captured_tensor
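The wrap-then-unwrap pattern in the non-resource branch can be exercised in isolation. A minimal sketch, assuming the internal `gen_dataset_ops` module is importable from `tensorflow.python.ops` (it is private API, so this is illustrative only):

import tensorflow as tf
from tensorflow.python.ops import gen_dataset_ops

@tf.function
def round_trip(x):
  # Wrap the tensor in an optional, as the forward graph does for
  # intermediates, then unwrap it the way the gradient graph does.
  optional = gen_dataset_ops.optional_from_value([x])
  return gen_dataset_ops.optional_get_value(
      optional, [x.dtype], [x.shape])[0]

print(round_trip(tf.constant(3.0)))  # tf.Tensor(3.0, shape=(), dtype=float32)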
Example #3
  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        tensor in self._forward_graph.inputs or
        tensor in self._forward_graph.outputs):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if tensor not in self.captures:
        self.xla_intermediates.append(tensor)
        self.if_op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    captured_tensor = self._indirect_captures.get(tensor)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the forward
    # graph and capture that.

    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              consumer.outputs[0] in self._forward_graph.outputs):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_dataset_ops.optional_from_value([tensor])
          self.if_op_needs_rewrite = True
        self._wrapped_intermediates[tensor] = optional

      optional = self._wrapped_intermediates[tensor]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      captured_tensor = gen_dataset_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape])[0]

    self._indirect_captures[tensor] = captured_tensor
    return captured_tensor
Example #4
    def from_value(value):
        """Returns an `Optional` that wraps the given value.

        Args:
          value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

        Returns:
          An `Optional` that wraps `value`.
        """
        with ops.name_scope("optional") as scope:
            with ops.name_scope("value"):
                value_structure = type_spec.type_spec_from_value(value)
                encoded_value = value_structure._to_tensor_list(value)  # pylint: disable=protected-access

        return _OptionalImpl(
            gen_dataset_ops.optional_from_value(encoded_value, name=scope),
            value_structure)
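The `type_spec_from_value` step has a public counterpart in TF 2.x; a short sketch, assuming `tf.type_spec_from_value` is available (the encoding helper `_to_tensor_list` itself stays private):

import tensorflow as tf

value = (tf.constant([1, 2]),
         tf.sparse.SparseTensor(indices=[[0]], values=[7], dense_shape=[3]))
spec = tf.type_spec_from_value(value)  # nested TypeSpec describing `value`
print(spec)
# (TensorSpec(shape=(2,), dtype=tf.int32, ...), SparseTensorSpec(...))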
Example #5
  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        tensor in self._forward_graph.inputs or
        tensor in self._forward_graph.outputs):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    if control_flow_util.InXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if tensor not in self.captures:
        self.xla_intermediates.append(tensor)
        self.if_op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    captured_tensor = self._indirect_captures.get(tensor)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph. We wrap it in
    # an optional in the forward graph and capture the optional normally. We
    # then unwrap the captured optional value in the gradient graph to get the
    # raw intermediate value.

    if tensor not in self._wrapped_intermediates:
      # If the gradient has already been computed for this If op, 'tensor' may
      # already be wrapped.
      for consumer in tensor.consumers():
        if (consumer.type == "OptionalFromValue"
            and consumer.outputs[0] in self._forward_graph.outputs):
          optional = consumer.outputs[0]
          break
      else:
        # 'tensor' hasn't been wrapped, do it now.
        with self._forward_graph.as_default():
          optional = gen_dataset_ops.optional_from_value([tensor])
        self.if_op_needs_rewrite = True

      self._wrapped_intermediates[tensor] = optional

    optional = self._wrapped_intermediates[tensor]
    captured_optional = super(_CondGradFuncGraph, self)._capture_helper(
        optional, name)
    captured_tensor = gen_dataset_ops.optional_get_value(
        captured_optional, [tensor.dtype], [tensor.shape])[0]
    self._indirect_captures[tensor] = captured_tensor
    return captured_tensor
Example #6
  def from_value(value):
    """Returns an `Optional` that wraps the given value.

    Args:
      value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

    Returns:
      An `Optional` that wraps `value`.
    """
    with ops.name_scope("optional") as scope:
      with ops.name_scope("value"):
        value_structure = structure.Structure.from_value(value)
        encoded_value = value_structure._to_tensor_list(value)  # pylint: disable=protected-access

    return _OptionalImpl(
        gen_dataset_ops.optional_from_value(encoded_value, name=scope),
        value_structure)
Example #7
  def from_value(value):
    """Returns an `Optional` that wraps the given value.

    Args:
      value: A value to wrap. The value must be convertible to `Tensor` or
        `CompositeTensor`.

    Returns:
      An `Optional` that wraps `value`.
    """
    with ops.name_scope("optional") as scope:
      with ops.name_scope("value"):
        value_structure = structure.type_spec_from_value(value)
        encoded_value = structure.to_tensor_list(value_structure, value)

    return _OptionalImpl(
        gen_dataset_ops.optional_from_value(encoded_value, name=scope),
        value_structure)
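The `to_tensor_list` encoding used here has a matching decoder; a hedged round-trip sketch, assuming the internal module path `tensorflow.python.data.util.structure` and its `from_tensor_list` helper:

import tensorflow as tf
from tensorflow.python.data.util import structure

value = tf.sparse.SparseTensor(indices=[[0]], values=[7], dense_shape=[3])
spec = structure.type_spec_from_value(value)
flat = structure.to_tensor_list(spec, value)       # encode to a flat list
restored = structure.from_tensor_list(spec, flat)  # decode back to the value
print(tf.sparse.to_dense(restored))                # tf.Tensor([7 0 0], ...)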
Example #8
    def from_value(value):
        """Returns a `tf.experimental.Optional` that wraps the given value.

        >>> optional = tf.experimental.Optional.from_value(42)
        >>> print(optional.has_value())
        tf.Tensor(True, shape=(), dtype=bool)
        >>> print(optional.get_value())
        tf.Tensor(42, shape=(), dtype=int32)

        Args:
          value: A value to wrap. The value must be convertible to `tf.Tensor` or
            `tf.CompositeTensor`.

        Returns:
          A `tf.experimental.Optional` that wraps `value`.
        """
        with ops.name_scope("optional") as scope:
            with ops.name_scope("value"):
                element_spec = structure.type_spec_from_value(value)
                encoded_value = structure.to_tensor_list(element_spec, value)

        return _OptionalImpl(
            gen_dataset_ops.optional_from_value(encoded_value, name=scope),
            element_spec)
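In practice a `tf.experimental.Optional` is most often obtained from an iterator rather than built directly; a short usage sketch of that pattern:

import tensorflow as tf

it = iter(tf.data.Dataset.range(1))
opt = it.get_next_as_optional()  # wraps the next element, if any
print(opt.has_value())           # tf.Tensor(True, shape=(), dtype=bool)
print(opt.get_value())           # tf.Tensor(0, shape=(), dtype=int64)
opt = it.get_next_as_optional()  # iterator exhausted
print(opt.has_value())           # tf.Tensor(False, shape=(), dtype=bool)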
Example #9
def _OptionalGetValueGrad(unused_op, *grads):
  return gen_dataset_ops.optional_from_value(grads)
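In the TensorFlow source this function is presumably registered as the gradient of the `OptionalGetValue` op via `ops.RegisterGradient`, paired with a mirror-image gradient for `OptionalFromValue`. A hedged sketch of that counterpart:

from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops

@ops.RegisterGradient("OptionalFromValue")
def _OptionalFromValueGrad(op, grad):
  # Mirror image of _OptionalGetValueGrad above: unwrap the optional
  # gradient back into per-component gradients, using the dtypes and
  # shapes of the op's original inputs.
  return gen_dataset_ops.optional_get_value(
      grad, [t.dtype for t in op.inputs], [t.shape for t in op.inputs])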
Example #10
def _wrap_intermediates(func_graph, intermediates):
  with func_graph.as_default():
    return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]