Пример #1
0
def composite_tensor_to_variants(value, type_spec=None, name=None):
    """Packs a composite tensor into a single scalar variant tensor.

    The flattened components of `value` (composites expanded) are passed to
    the `CompositeTensorVariantFromComponents` op, together with a serialized
    `CompositeTensorVariantMetadata` proto that records the encoded
    `type_spec`.

    Args:
      value: The `CompositeTensor` value to encode.
      type_spec: Information about the value's type that should be included
        in the encoding. When `None`, the spec carried by `value` is used.
      name: Optional name for the operation.

    Returns:
      A Tensor with shape=`()` and dtype=`tf.variant`.

    Raises:
      TypeError: If `value` is not a `CompositeTensor`.
      ValueError: If `type_spec` is not compatible with `value`.
    """
    if not isinstance(value, composite_tensor.CompositeTensor):
        raise TypeError("Expected `value` to be a CompositeTensor.")

    # Fall back to the spec the value itself reports when none was supplied.
    spec = (
        value._type_spec  # pylint: disable=protected-access
        if type_spec is None
        else type_spec
    )
    if not spec.is_compatible_with(value):
        raise ValueError(
            "TypeSpec %r is not compatible with value %r" % (spec, value))

    structure_coder = nested_structure_coder.StructureCoder()
    variant_metadata = (
        composite_tensor_variant_pb2.CompositeTensorVariantMetadata())
    # Only the type-spec portion of the encoded structure goes in the metadata.
    encoded_spec = structure_coder.encode_structure(spec)
    variant_metadata.type_spec_proto.CopyFrom(encoded_spec.type_spec_value)

    flat_components = nest.flatten(value, expand_composites=True)
    return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
        components=flat_components,
        metadata=variant_metadata.SerializeToString(),
        name=name)
Пример #2
0
def _composite_tensor_from_variant_grad(op, *grad):
    """Re-encodes per-output gradients of `op` as a composite-tensor variant.

    Args:
      op: The operation whose outputs the gradients correspond to. Its
        `metadata` attribute is forwarded unchanged to the encoding op.
      *grad: One gradient per output of `op`; an entry may be `None`.

    Returns:
      The result of `CompositeTensorVariantFromComponents` applied to the
      gradients, with every `None` entry replaced by the matching `op` output.
    """
    assert len(grad) == len(op.outputs)
    # Keep the forward output wherever no gradient was supplied, so the
    # component list stays the same length as `op.outputs`.
    merged_components = [
        op.outputs[position] if upstream is None else upstream
        for position, upstream in enumerate(grad)
    ]
    encoding_metadata = op.get_attr("metadata")
    return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
        components=merged_components, metadata=encoding_metadata)