예제 #1
0
def pack(value):
    """Returns a copy of `value` whose fields are packed into one Variant.

    If `is_packed(value)` is already true, `value` itself is returned.

    Args:
      value: An `ExtensionType` object.

    Returns:
      An `ExtensionType` object, built by `_create_object_from_type_and_dict`
      from `type(value)`, the packed type spec and the variant encoding of
      `value`.

    Raises:
      ValueError: If encoding `value` as a variant raises
        `nested_structure_coder.NotEncodableError`.
    """
    if is_packed(value):
        return value

    # Derive the packed flavour of the spec before encoding the value itself.
    packed_spec = value._type_spec._tf_extension_type_with_packed(True)  # pylint: disable=protected-access

    try:
        packed_variant = composite_tensor_ops.composite_tensor_to_variants(
            value)
    except nested_structure_coder.NotEncodableError as encode_error:
        # Note: the only time `_TypeSpecCodec.can_encode` returns False is if
        # the named type is not registered.  The default error would only say
        # that no encoder exists for the object, so a more actionable message
        # is raised instead, chained to the original error.
        raise ValueError('ExtensionTypes must have a __name__ field in order '
                         'to be packed.') from encode_error

    packed_fields = {
        '_tf_extension_type_cached_type_spec': packed_spec,
        '_tf_extension_type_packed_variant': packed_variant,
    }
    return _create_object_from_type_and_dict(type(value), packed_fields)
예제 #2
0
 def testRoundTripThroughTensorProto(self):
   """Round-trips a ragged constant through a serialized tensor proto.

   The value is encoded to a variant, serialized, parsed back as a variant
   tensor, decoded with the original type spec, and compared to the input.
   """
   ragged = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])

   # Encode, then push the variant through SerializeTensor / ParseTensor.
   variant = composite_tensor_ops.composite_tensor_to_variants(ragged)
   serialized = parsing_ops.SerializeTensor(tensor=variant)
   reparsed_variant = parsing_ops.ParseTensor(
       serialized=serialized, out_type=dtypes.variant)

   round_tripped = composite_tensor_ops.composite_tensor_from_variant(
       reparsed_variant, ragged._type_spec)
   self.assertAllEqual(ragged, round_tripped)
예제 #3
0
  def testEncodeAndDecode(self, value_factory):
    """Encodes a composite value to a variant and decodes it back.

    Args:
      value_factory: Zero-argument callable returning the value under test.
    """
    original = value_factory()

    # The encoding is asserted to be a rank-0 tensor of dtype variant.
    variant = composite_tensor_ops.composite_tensor_to_variants(original)
    self.assertEqual(variant.dtype, dtypes.variant)
    self.assertEqual(variant.shape.rank, 0)

    restored = composite_tensor_ops.composite_tensor_from_variant(
        variant, original._type_spec)
    specs_compatible = original._type_spec.is_compatible_with(
        restored._type_spec)
    self.assertTrue(specs_compatible)

    # Compare the flattened components pairwise.
    expected_parts = nest.flatten(original, expand_composites=True)
    actual_parts = nest.flatten(restored, expand_composites=True)
    self.assertLen(expected_parts, len(actual_parts))
    for expected, actual in zip(expected_parts, actual_parts):
      self.assertAllEqual(expected, actual)
예제 #4
0
 def testEncodingErrors(self, value, spec, message):
   """Checks that encoding `value()` with `spec` raises a matching ValueError.

   Args:
     value: Zero-argument callable producing the value to encode.
     spec: Type spec passed as the second argument to the encoder.
     message: Regex the raised `ValueError` message must match.
   """

   def encode_with_spec():
     # `value()` is invoked here so that it runs inside the guarded call,
     # just as it would inside an `assertRaisesRegex` context manager.
     return composite_tensor_ops.composite_tensor_to_variants(value(), spec)

   self.assertRaisesRegex(ValueError, message, encode_with_spec)
예제 #5
0
 def testDecodingErrors(self, value, spec, message):
   """Checks that decoding with `spec` raises a matching InvalidArgumentError.

   Args:
     value: Zero-argument callable producing the value to encode.
     spec: Type spec used when decoding the encoded variant.
     message: Regex the raised `errors.InvalidArgumentError` must match.
   """
   # Encoding happens outside the guarded region; only decode + evaluate is
   # expected to raise.
   variant = composite_tensor_ops.composite_tensor_to_variants(value())
   with self.assertRaisesRegex(errors.InvalidArgumentError, message):
     decoded = composite_tensor_ops.composite_tensor_from_variant(
         variant, spec)
     self.evaluate(decoded)
예제 #6
0
 def func(x):
   """Round-trips `x * 2` through a variant, then rescales its values.

   The doubled input is encoded to a variant and decoded with `x`'s type
   spec; the decoded object's `values` are multiplied by
   `math_ops.range(6.0)` and substituted back via `with_values`.
   """
   doubled_variant = composite_tensor_ops.composite_tensor_to_variants(x * 2)
   restored = composite_tensor_ops.composite_tensor_from_variant(
       doubled_variant, x._type_spec)
   scaled_values = restored.values * math_ops.range(6.0)
   return restored.with_values(scaled_values)