def testRoundTripThroughTensorProto(self):
  """Encoded variant survives SerializeTensor -> ParseTensor and decodes back."""
  original = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
  variant = composite_tensor_ops.composite_tensor_to_variants(original)
  serialized = parsing_ops.SerializeTensor(tensor=variant)
  # ParseTensor must be told the element dtype; the encoding is a variant.
  restored_variant = parsing_ops.ParseTensor(
      serialized=serialized, out_type=dtypes.variant)
  restored = composite_tensor_ops.composite_tensor_from_variant(
      restored_variant, original._type_spec)  # pylint: disable=protected-access
  self.assertAllEqual(original, restored)
def testEncodeAndDecode(self, value_factory):
  """A value built by `value_factory` round-trips through a scalar variant.

  Checks the encoding's dtype and rank, that the decoded value's type spec
  is compatible with the original's, and that every flattened component
  (composites expanded) matches.
  """
  original = value_factory()
  variant = composite_tensor_ops.composite_tensor_to_variants(original)

  # The encoding is expected to be a single scalar tensor of dtype variant.
  self.assertEqual(variant.dtype, dtypes.variant)
  self.assertEqual(variant.shape.rank, 0)

  # pylint: disable=protected-access
  restored = composite_tensor_ops.composite_tensor_from_variant(
      variant, original._type_spec)
  self.assertTrue(
      original._type_spec.is_compatible_with(restored._type_spec))
  # pylint: enable=protected-access

  expected_parts = nest.flatten(original, expand_composites=True)
  actual_parts = nest.flatten(restored, expand_composites=True)
  self.assertLen(expected_parts, len(actual_parts))
  for expected, actual in zip(expected_parts, actual_parts):
    self.assertAllEqual(expected, actual)
def unpack(value):
  """Returns `value` with its individual fields stored in `__dict__`.

  When `is_packed(value)` is true, a new object is rebuilt with
  `composite_tensor_ops.composite_tensor_from_variant` from the value's
  packed variant and its cached type spec switched to the unpacked form.
  When `value` is not packed, it is returned as-is (the same object, not a
  copy).

  Args:
    value: An `ExtensionType` object.

  Returns:
    An `ExtensionType` object.
  """
  if is_packed(value):
    # pylint: disable=protected-access
    packed_variant = value._tf_extension_type_packed_variant
    cached_spec = value._tf_extension_type_cached_type_spec
    unpacked_spec = cached_spec._tf_extension_type_with_packed(False)
    return composite_tensor_ops.composite_tensor_from_variant(
        packed_variant, unpacked_spec)
  return value
def testDecodingErrors(self, value, spec, message):
  """Decoding the encoding of `value()` with `spec` must fail.

  Expects an `InvalidArgumentError` whose message matches the regex
  `message`. Encoding happens outside the `assertRaisesRegex` context, so
  only the decode/evaluate step is allowed to raise.
  """
  variant = composite_tensor_ops.composite_tensor_to_variants(value())
  with self.assertRaisesRegex(errors.InvalidArgumentError, message):
    decoded = composite_tensor_ops.composite_tensor_from_variant(variant, spec)
    self.evaluate(decoded)
def func(x):
  """Round-trips `x * 2` through a variant, then rescales its `values`.

  The decoded result's `values` are multiplied elementwise by
  `math_ops.range(6.0)`. NOTE(review): presumably `x` is a RaggedTensor
  whose `values` has 6 elements; confirm against the caller.
  """
  doubled = x * 2
  encoded = composite_tensor_ops.composite_tensor_to_variants(doubled)
  decoded = composite_tensor_ops.composite_tensor_from_variant(
      encoded, x._type_spec)  # pylint: disable=protected-access
  return decoded.with_values(decoded.values * math_ops.range(6.0))