Exemplo n.º 1
0
 def testRoundTripThroughTensorProto(self):
   """Checks a ragged value survives variant encoding plus proto serialization.

   The ragged value is encoded to a variant tensor, serialized with
   `SerializeTensor`, parsed back with `ParseTensor` as a variant, and then
   decoded with the original value's type spec. The result must equal the
   original.
   """
   original = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
   variant_tensor = composite_tensor_ops.composite_tensor_to_variants(original)
   serialized = parsing_ops.SerializeTensor(tensor=variant_tensor)
   round_tripped = parsing_ops.ParseTensor(
       serialized=serialized, out_type=dtypes.variant)
   restored = composite_tensor_ops.composite_tensor_from_variant(
       round_tripped, original._type_spec)  # pylint: disable=protected-access
   self.assertAllEqual(original, restored)
Exemplo n.º 2
0
  def testEncodeAndDecode(self, value_factory):
    """Encodes a composite value to a variant and checks decoding restores it.

    Args:
      value_factory: Zero-argument callable that builds the composite value
        under test (presumably supplied by a parameterization decorator
        outside this view; TODO confirm).
    """
    original = value_factory()

    variant = composite_tensor_ops.composite_tensor_to_variants(original)
    # The encoding is expected to be a single scalar variant tensor.
    self.assertEqual(variant.dtype, dtypes.variant)
    self.assertEqual(variant.shape.rank, 0)

    # pylint: disable=protected-access
    restored = composite_tensor_ops.composite_tensor_from_variant(
        variant, original._type_spec)
    spec_matches = original._type_spec.is_compatible_with(restored._type_spec)
    self.assertTrue(spec_matches)

    # Compare the flattened component tensors pairwise.
    expected_parts = nest.flatten(original, expand_composites=True)
    actual_parts = nest.flatten(restored, expand_composites=True)
    self.assertLen(expected_parts, len(actual_parts))
    for expected, actual in zip(expected_parts, actual_parts):
      self.assertAllEqual(expected, actual)
Exemplo n.º 3
0
def unpack(value):
    """Returns a copy of `value` with individual fields stored in __dict__.

    If `value` is not packed (as reported by `is_packed`), it is returned
    unchanged. Otherwise the packed variant tensor held by `value` is decoded
    with the unpacked form of its cached type spec.

    Args:
      value: An `ExtensionType` object.

    Returns:
      An `ExtensionType` object: `value` itself when it is not packed, or the
      value decoded from its packed variant otherwise.
    """
    if is_packed(value):
        # pylint: disable=protected-access
        packed_variant = value._tf_extension_type_packed_variant
        cached_spec = value._tf_extension_type_cached_type_spec
        # Decode against a spec whose "packed" flag is off so the result
        # keeps its fields individually.
        unpacked_spec = cached_spec._tf_extension_type_with_packed(False)
        return composite_tensor_ops.composite_tensor_from_variant(
            packed_variant, unpacked_spec)
    return value
Exemplo n.º 4
0
 def testDecodingErrors(self, value, spec, message):
   """Checks that decoding with a mismatched spec raises InvalidArgumentError.

   Args:
     value: Zero-argument callable that builds the composite value to encode.
     spec: Type spec to decode with (presumably incompatible with `value()`;
       TODO confirm against the parameterization).
     message: Regex the error message must match.
   """
   variants = composite_tensor_ops.composite_tensor_to_variants(value())
   with self.assertRaisesRegex(errors.InvalidArgumentError, message):
     # The decode call stays inside the context: in eager mode it can raise
     # immediately, while in graph mode the error surfaces on evaluation.
     decoded = composite_tensor_ops.composite_tensor_from_variant(
         variants, spec)
     self.evaluate(decoded)
Exemplo n.º 5
0
 def func(x):
   """Round-trips `x * 2` through a variant encoding and rescales its values.

   Args:
     x: A composite tensor exposing `_type_spec` (presumably a RaggedTensor;
       TODO confirm against the caller).

   Returns:
     The decoded value with its flat `values` multiplied elementwise by
     `math_ops.range(6.0)`. This presumably requires six flat values or a
     broadcast-compatible shape (TODO confirm).
   """
   doubled = x * 2
   encoded = composite_tensor_ops.composite_tensor_to_variants(doubled)
   # Decoding reuses the spec of the original `x`.
   decoded = composite_tensor_ops.composite_tensor_from_variant(
       encoded, x._type_spec)  # pylint: disable=protected-access
   scaled_values = decoded.values * math_ops.range(6.0)
   return decoded.with_values(scaled_values)