def test_roundtrip_sequence_of_scalars(self):
  """A dataset of scalar tensors survives a serialize/deserialize roundtrip."""
  ds = tf.data.Dataset.range(5).map(lambda v: v * 2)
  serialized = tensorflow_serialization.serialize_dataset(ds)
  restored = tensorflow_serialization.deserialize_dataset(serialized)
  # Both the element structure and the concrete values must be preserved.
  self.assertEqual(ds.element_spec, restored.element_spec)
  self.assertAllEqual(list(restored), [i * 2 for i in range(5)])
def test_roundtrip_sequence_of_scalars(self):
  """Roundtrip preserves both the dataset structure and its values."""
  ds = tf.data.Dataset.range(5).map(lambda v: v * 2)
  serialized = tensorflow_serialization.serialize_dataset(ds)
  restored = tensorflow_serialization.deserialize_dataset(serialized)
  # Compare structures via the experimental API rather than `element_spec`.
  self.assertEqual(
      tf.data.experimental.get_structure(ds),
      tf.data.experimental.get_structure(restored))
  self.assertAllEqual([element for element in restored],
                      [i * 2 for i in range(5)])
def test_roundtrip_sequence_of_singleton_tuples(self):
  """One-element tuples are not unwrapped by serialize/deserialize."""
  ds = tf.data.Dataset.range(5).map(lambda v: (v,))
  serialized = tensorflow_serialization.serialize_dataset(ds)
  restored = tensorflow_serialization.deserialize_dataset(serialized)
  self.assertEqual(ds.element_spec, restored.element_spec)
  expected = [(i,) for i in range(5)]
  actual = self.evaluate(list(restored))
  self.assertAllEqual(expected, actual)
def test_roundtrip_sequence_of_tuples(self):
  """Tuples mixing int64, int32 and float32 elements roundtrip intact."""
  ds = tf.data.Dataset.range(5).map(
      lambda v: (v * 2, tf.cast(v, tf.int32), tf.cast(v - 1, tf.float32)))
  serialized = tensorflow_serialization.serialize_dataset(ds)
  restored = tensorflow_serialization.deserialize_dataset(serialized)
  self.assertEqual(ds.element_spec, restored.element_spec)
  actual = self.evaluate([element for element in restored])
  self.assertAllEqual(actual, [(i * 2, i, i - 1.) for i in range(5)])
def test_roundtrip_sequence_of_nested_structures(self):
  """Nested OrderedDict/tuple/namedtuple elements roundtrip correctly.

  Fix: the serialized-bytes local was previously misspelled
  `serialzied_bytes`; renamed to `serialized_bytes` to match the sibling
  tests in this file.
  """
  test_tuple_type = collections.namedtuple('TestTuple', ['u', 'v'])

  def _make_nested_tf_structure(x):
    # Keys are deliberately out of sorted order ('b' before 'a') so the
    # roundtrip must rely on the stored type spec to restore ordering.
    return collections.OrderedDict([
        ('b', tf.cast(x, tf.int32)),
        ('a', tuple([
            x,
            test_tuple_type(x * 2, x * 3),
            collections.OrderedDict([('x', x**2), ('y', x**3)]),
        ])),
    ])

  x = tf.data.Dataset.range(5).map(_make_nested_tf_structure)
  serialized_bytes = tensorflow_serialization.serialize_dataset(x)
  y = tensorflow_serialization.deserialize_dataset(serialized_bytes)

  # NOTE: TF loses the `OrderedDict` during serialization, so the expectation
  # here is for a `dict` in the result.
  self.assertEqual(
      tf.data.experimental.get_structure(y), {
          'b': tf.TensorSpec([], tf.int32),
          'a': tuple([
              tf.TensorSpec([], tf.int64),
              test_tuple_type(
                  tf.TensorSpec([], tf.int64),
                  tf.TensorSpec([], tf.int64),
              ),
              {
                  'x': tf.TensorSpec([], tf.int64),
                  'y': tf.TensorSpec([], tf.int64),
              },
          ]),
      })

  def _build_expected_structure(x):
    # Mirrors `_make_nested_tf_structure`, but with plain `dict`s since
    # TF drops the `OrderedDict` subclass (see NOTE above).
    return {
        'b': x,
        'a': tuple([
            x,
            test_tuple_type(x * 2, x * 3),
            {'x': x**2, 'y': x**3},
        ]),
    }

  actual_values = self.evaluate([y_val for y_val in y])
  expected_values = [_build_expected_structure(x) for x in range(5)]
  for actual, expected in zip(actual_values, expected_values):
    self.assertAllClose(actual, expected)
def test_roundtrip_sequence_of_namedtuples(self):
  """Namedtuple elements keep their field names and values after roundtrip."""
  test_tuple_type = collections.namedtuple('TestTuple', ['a', 'b', 'c'])

  def make_test_tuple(x):
    return test_tuple_type(
        a=x * 2, b=tf.cast(x, tf.int32), c=tf.cast(x - 1, tf.float32))

  ds = tf.data.Dataset.range(5).map(make_test_tuple)
  serialized = tensorflow_serialization.serialize_dataset(ds)
  restored = tensorflow_serialization.deserialize_dataset(serialized)
  self.assertEqual(ds.element_spec, restored.element_spec)
  self.assertAllEqual(
      self.evaluate([element for element in restored]),
      [test_tuple_type(a=i * 2, b=i, c=i - 1.) for i in range(5)])
def serialize_sequence_value(value):
  """Serializes a `tf.data.Dataset` value into `executor_pb2.Value`.

  Fix: the previous docstring claimed a `(value_proto, type_spec)` tuple was
  returned, but the function returns only the value proto; the Returns
  section now matches the code.

  Args:
    value: A `tf.data.Dataset`, or equivalent.

  Returns:
    An instance of `executor_pb2.Value` with the serialized content of
    `value`.
  """
  py_typecheck.check_type(value, type_utils.TF_DATASET_REPRESENTATION_TYPES)
  # TFF must store the type spec here because TF will lose the ordering of the
  # names for `tf.data.Dataset` that return elements of `collections.Mapping`
  # type. This allows TFF to preserve and restore the key ordering upon
  # deserialization.
  element_type = computation_types.to_type(value.element_spec)
  return executor_pb2.Value(sequence=executor_pb2.Value.Sequence(
      zipped_saved_model=tensorflow_serialization.serialize_dataset(value),
      element_type=type_serialization.serialize_type(element_type)))
def _serialize_sequence_value(
    value: Union[type_conversions.TF_DATASET_REPRESENTATION_TYPES],
    type_spec: computation_types.SequenceType) -> _SerializeReturnType:
  """Serializes a `tf.data.Dataset` value into `executor_pb2.Value`.

  Args:
    value: A `tf.data.Dataset`, or equivalent.
    type_spec: A `computation_types.Type` specifying the TFF sequence type of
      `value`.

  Returns:
    A tuple `(value_proto, type_spec)` in which `value_proto` is an instance
    of `executor_pb2.Value` with the serialized content of `value`, and
    `type_spec` is the type of the serialized value.

  Raises:
    TypeError: If `value` is not a dataset, or if its element type is not
      assignable to `type_spec`.
  """
  if not isinstance(value, type_conversions.TF_DATASET_REPRESENTATION_TYPES):
    raise TypeError(
        'Cannot serialize Python type {!s} as TFF type {!s}.'.format(
            py_typecheck.type_string(type(value)),
            type_spec if type_spec is not None else 'unknown'))
  # TFF must store the type spec here because TF will lose the ordering of the
  # names for `tf.data.Dataset` that return elements of `collections.Mapping`
  # type. This allows TFF to preserve and restore the key ordering upon
  # deserialization.
  # Compute the element type once and reuse it (previously this was derived
  # twice from `value.element_spec`).
  element_type = computation_types.to_type(value.element_spec)
  value_type = computation_types.SequenceType(element_type)
  if not type_spec.is_assignable_from(value_type):
    raise TypeError(
        'Cannot serialize dataset with elements of type {!s} as TFF type {!s}.'
        .format(value_type, type_spec if type_spec is not None else 'unknown'))
  return executor_pb2.Value(
      sequence=executor_pb2.Value.Sequence(
          zipped_saved_model=tensorflow_serialization.serialize_dataset(value),
          element_type=type_serialization.serialize_type(
              element_type))), type_spec
def test_serialize_sequence_bytes_too_large(self):
  """A zero-byte size budget must raise the size-limit ValueError."""
  dataset = tf.data.Dataset.range(5)
  with self.assertRaisesRegex(
      ValueError, r'Serialized size .* exceeds maximum allowed'):
    tensorflow_serialization.serialize_dataset(
        dataset, max_serialized_size_bytes=0)
def test_serialize_sequence_not_a_dataset(self):
  """Passing a plain int instead of a dataset must raise a TypeError."""
  with self.assertRaisesRegex(TypeError, r'Expected .*Dataset.* found int'):
    tensorflow_serialization.serialize_dataset(5)