def composite_tensor_to_variants(value, type_spec=None, name=None):
  """Encodes `value` as a scalar variant tensor.

  Args:
    value: The `ExtensionType` value to encode.
    type_spec: Information about the value's type that should be included in
      the encoding.
    name: Optional name for the operation.

  Returns:
    A Tensor with shape=`()` and dtype=`tf.variant`.

  Raises:
    ValueError: If `type_spec` is not compatible with `value`.
  """
  if not isinstance(value, composite_tensor.CompositeTensor):
    raise TypeError("Expected `value` to be a CompositeTensor.")

  # Fall back to the value's own spec when the caller did not supply one.
  spec = value._type_spec if type_spec is None else type_spec  # pylint: disable=protected-access
  if not spec.is_compatible_with(value):
    raise ValueError("TypeSpec %r is not compatible with value %r" %
                     (spec, value))

  # Serialize the TypeSpec into the metadata proto carried by the variant.
  encoded_spec = nested_structure_coder.StructureCoder().encode_structure(spec)
  variant_metadata = (
      composite_tensor_variant_pb2.CompositeTensorVariantMetadata())
  variant_metadata.type_spec_proto.CopyFrom(encoded_spec.type_spec_value)

  flat_components = nest.flatten(value, expand_composites=True)
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=flat_components,
      metadata=variant_metadata.SerializeToString(),
      name=name)
def _composite_tensor_from_variant_grad(op, *grad):
  """Gradient function: re-encodes component gradients as a variant.

  The forward op decoded a variant into component tensors; its gradient
  packs the incoming component gradients back into a variant, reusing the
  forward output wherever no gradient was provided.

  Args:
    op: The forward op whose gradient is being computed.
    *grad: One (possibly `None`) gradient tensor per forward output.

  Returns:
    The result of `CompositeTensorVariantFromComponents` over the merged
    components, with the forward op's `metadata` attribute.
  """
  assert len(grad) == len(op.outputs)
  # Pair each forward output with its incoming gradient; keep the output
  # itself when no gradient flows for that component.
  components = [
      output if g is None else g for output, g in zip(op.outputs, grad)
  ]
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=components, metadata=op.get_attr("metadata"))