def pack(value):
  """Packs the fields of an `ExtensionType` value into a single Variant.

  If `value` is already packed it is returned unchanged. Otherwise a new
  object of the same type is built whose cached type spec is the packed
  variant of the original spec, and whose fields are stored in one Variant
  tensor.

  Args:
    value: An `ExtensionType` object.

  Returns:
    An `ExtensionType` object.

  Raises:
    ValueError: If the value's fields cannot be encoded into a Variant.
  """
  # Fast path: nothing to do for an already-packed value.
  if is_packed(value):
    return value

  packed_spec = value._type_spec._tf_extension_type_with_packed(True)  # pylint: disable=protected-access

  try:
    packed_variant = composite_tensor_ops.composite_tensor_to_variants(value)
  except nested_structure_coder.NotEncodableError as e:
    # Per the original author's note, encoding only fails here when the
    # named type has not been registered. The default error would just say
    # "no encoder", so give the user a more actionable message.
    raise ValueError('ExtensionTypes must have a __name__ field in order '
                     'to be packed.') from e

  packed_fields = {
      '_tf_extension_type_cached_type_spec': packed_spec,
      '_tf_extension_type_packed_variant': packed_variant,
  }
  return _create_object_from_type_and_dict(type(value), packed_fields)
def testRoundTripThroughTensorProto(self):
  """A ragged tensor survives variant encoding plus proto (de)serialization."""
  ragged = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
  as_variant = composite_tensor_ops.composite_tensor_to_variants(ragged)

  # Serialize the variant to a TensorProto string and parse it back.
  serialized = parsing_ops.SerializeTensor(tensor=as_variant)
  reparsed = parsing_ops.ParseTensor(
      serialized=serialized, out_type=dtypes.variant)

  restored = composite_tensor_ops.composite_tensor_from_variant(
      reparsed, ragged._type_spec)  # pylint: disable=protected-access
  self.assertAllEqual(ragged, restored)
def testEncodeAndDecode(self, value_factory):
  """Encoding to a scalar variant and decoding back preserves components.

  Args:
    value_factory: Zero-argument callable producing the composite value.
  """
  original = value_factory()
  variant = composite_tensor_ops.composite_tensor_to_variants(original)

  # The encoded form is expected to be a scalar (rank-0) variant tensor.
  self.assertEqual(variant.dtype, dtypes.variant)
  self.assertEqual(variant.shape.rank, 0)

  spec = original._type_spec  # pylint: disable=protected-access
  roundtripped = composite_tensor_ops.composite_tensor_from_variant(
      variant, spec)
  self.assertTrue(spec.is_compatible_with(roundtripped._type_spec))  # pylint: disable=protected-access

  # Compare the flattened component tensors pairwise.
  flat_original = nest.flatten(original, expand_composites=True)
  flat_roundtripped = nest.flatten(roundtripped, expand_composites=True)
  self.assertLen(flat_original, len(flat_roundtripped))
  for expected, actual in zip(flat_original, flat_roundtripped):
    self.assertAllEqual(expected, actual)
def testEncodingErrors(self, value, spec, message):
  """Encoding `value()` against `spec` is expected to raise ValueError.

  Args:
    value: Zero-argument callable producing the value to encode.
    spec: Type spec passed to the encoder.
    message: Regex the ValueError message must match.
  """
  # `value()` is deliberately called inside the context manager, so a
  # ValueError raised while building the value is also captured.
  with self.assertRaisesRegex(ValueError, message):
    composite_tensor_ops.composite_tensor_to_variants(value(), spec)
def testDecodingErrors(self, value, spec, message):
  """Decoding a variant with `spec` is expected to fail with InvalidArgumentError.

  Args:
    value: Zero-argument callable producing the value to encode.
    spec: Type spec used to decode the encoded variant.
    message: Regex the error message must match.
  """
  variant = composite_tensor_ops.composite_tensor_to_variants(value())
  with self.assertRaisesRegex(errors.InvalidArgumentError, message):
    decoded = composite_tensor_ops.composite_tensor_from_variant(variant, spec)
    self.evaluate(decoded)
def func(x):
  """Round-trips `x * 2` through a variant, then scales its values by range(6)."""
  doubled_variant = composite_tensor_ops.composite_tensor_to_variants(x * 2)
  restored = composite_tensor_ops.composite_tensor_from_variant(
      doubled_variant, x._type_spec)  # pylint: disable=protected-access
  scaled_values = restored.values * math_ops.range(6.0)
  return restored.with_values(scaled_values)