def test_roundtrip_sequence_of_scalars(self):
  """Checks that a dataset of scalars survives a serialization roundtrip."""
  original = tf.data.Dataset.range(5).map(lambda value: value * 2)
  payload = tensorflow_serialization.serialize_dataset(original)
  restored = tensorflow_serialization.deserialize_dataset(payload)

  # The roundtrip must leave the element spec untouched.
  self.assertEqual(original.element_spec, restored.element_spec)

  # Iterate the restored dataset and compare against the doubled range.
  actual_values = [element for element in restored]
  expected_values = [index * 2 for index in range(5)]
  self.assertAllEqual(actual_values, expected_values)
def deserialize_sequence_value(sequence_value_proto):
  """Deserializes a `tf.data.Dataset` from a `Sequence` protocol buffer.

  Args:
    sequence_value_proto: `Sequence` protocol buffer message.

  Returns:
    A tuple of `(tf.data.Dataset, tff.Type)`: the deserialized dataset, with
    its elements coerced to the deserialized element type, and a
    `SequenceType` whose element is that type.

  Raises:
    NotImplementedError: If the `value` oneof of the message is populated
      with anything other than `zipped_saved_model`.
  """
  py_typecheck.check_type(sequence_value_proto, executor_pb2.Value.Sequence)

  # Only the zipped SavedModel encoding is supported; reject anything else
  # before touching the payload.
  which_value = sequence_value_proto.WhichOneof('value')
  if which_value != 'zipped_saved_model':
    raise NotImplementedError(
        'Deserializing Sequences encoded as {!s} has not been implemented'
        .format(which_value))

  ds = tensorflow_serialization.deserialize_dataset(
      sequence_value_proto.zipped_saved_model)
  element_type = type_serialization.deserialize_type(
      sequence_value_proto.element_type)

  # If a serialized dataset had elements of nested structures of tensors (e.g.
  # `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
  # `tuple`, or `namedtuple` (loses `collections.OrderedDict` in a conversion).
  #
  # Since the dataset will only be used inside TFF, we wrap the dictionary
  # coming from TF in an `OrderedDict` when necessary (a type that both TF and
  # TFF understand), using the field order stored in the TFF type stored during
  # serialization.
  ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      ds, element_type)
  return ds, computation_types.SequenceType(element=element_type)
def test_roundtrip_sequence_of_scalars(self):
  """Checks that a dataset of scalars survives a serialization roundtrip."""
  original = tf.data.Dataset.range(5).map(lambda value: value * 2)
  payload = tensorflow_serialization.serialize_dataset(original)
  restored = tensorflow_serialization.deserialize_dataset(payload)

  # Compare the element structures reported by TF before and after.
  original_structure = tf.data.experimental.get_structure(original)
  restored_structure = tf.data.experimental.get_structure(restored)
  self.assertEqual(original_structure, restored_structure)

  # Iterate the restored dataset and compare against the doubled range.
  actual_values = [element for element in restored]
  expected_values = [index * 2 for index in range(5)]
  self.assertAllEqual(actual_values, expected_values)
def test_roundtrip_sequence_of_singleton_tuples(self):
  """Checks that a dataset of 1-tuples survives a serialization roundtrip."""
  original = tf.data.Dataset.range(5).map(lambda value: (value,))
  payload = tensorflow_serialization.serialize_dataset(original)
  restored = tensorflow_serialization.deserialize_dataset(payload)

  # The singleton-tuple structure must be preserved in the element spec.
  self.assertEqual(original.element_spec, restored.element_spec)

  expected_values = [(index,) for index in range(5)]
  restored_elements = [element for element in restored]
  actual_values = self.evaluate(restored_elements)
  self.assertAllEqual(expected_values, actual_values)
def test_roundtrip_sequence_of_tuples(self):
  """Checks that a dataset of mixed-dtype tuples survives a roundtrip."""

  def _make_tuple(value):
    # One element per dtype: int64 (doubled), int32, and float32.
    doubled = value * 2
    as_int32 = tf.cast(value, tf.int32)
    as_float32 = tf.cast(value - 1, tf.float32)
    return (doubled, as_int32, as_float32)

  original = tf.data.Dataset.range(5).map(_make_tuple)
  payload = tensorflow_serialization.serialize_dataset(original)
  restored = tensorflow_serialization.deserialize_dataset(payload)
  self.assertEqual(original.element_spec, restored.element_spec)

  actual_values = self.evaluate([element for element in restored])
  expected_values = [(index * 2, index, index - 1.) for index in range(5)]
  self.assertAllEqual(actual_values, expected_values)
def test_roundtrip_sequence_of_nested_structures(self):
  """Checks that a dataset of nested containers survives a roundtrip."""
  test_tuple_type = collections.namedtuple('TestTuple', ['u', 'v'])

  def _make_nested_tf_structure(value):
    # Nest a tuple, a namedtuple and an `OrderedDict` inside an `OrderedDict`.
    inner_dict = collections.OrderedDict([('x', value**2), ('y', value**3)])
    inner_tuple = (value, test_tuple_type(value * 2, value * 3), inner_dict)
    return collections.OrderedDict([
        ('b', tf.cast(value, tf.int32)),
        ('a', inner_tuple),
    ])

  original = tf.data.Dataset.range(5).map(_make_nested_tf_structure)
  payload = tensorflow_serialization.serialize_dataset(original)
  restored = tensorflow_serialization.deserialize_dataset(payload)

  # NOTE: TF loses the `OrderedDict` during serialization, so the expectation
  # here is for a `dict` in the result.
  int64_scalar = tf.TensorSpec([], tf.int64)
  expected_spec = {
      'b': tf.TensorSpec([], tf.int32),
      'a': (
          int64_scalar,
          test_tuple_type(int64_scalar, int64_scalar),
          {
              'x': int64_scalar,
              'y': int64_scalar,
          },
      ),
  }
  self.assertEqual(tf.data.experimental.get_structure(restored), expected_spec)

  def _build_expected_structure(index):
    # Plain-Python mirror of `_make_nested_tf_structure`, using `dict`s.
    squares_and_cubes = {'x': index**2, 'y': index**3}
    return {
        'b': index,
        'a': (index, test_tuple_type(index * 2, index * 3), squares_and_cubes),
    }

  actual_values = self.evaluate([element for element in restored])
  expected_values = [_build_expected_structure(index) for index in range(5)]
  for actual, expected in zip(actual_values, expected_values):
    self.assertAllClose(actual, expected)
def test_roundtrip_sequence_of_namedtuples(self):
  """Checks that a dataset of namedtuples survives a serialization roundtrip."""
  test_tuple_type = collections.namedtuple('TestTuple', ['a', 'b', 'c'])

  def make_test_tuple(value):
    # Fields carry three different dtypes: int64, int32 and float32.
    doubled = value * 2
    as_int32 = tf.cast(value, tf.int32)
    as_float32 = tf.cast(value - 1, tf.float32)
    return test_tuple_type(a=doubled, b=as_int32, c=as_float32)

  original = tf.data.Dataset.range(5).map(make_test_tuple)
  payload = tensorflow_serialization.serialize_dataset(original)
  restored = tensorflow_serialization.deserialize_dataset(payload)
  self.assertEqual(original.element_spec, restored.element_spec)

  actual_values = self.evaluate([element for element in restored])
  expected_values = [
      test_tuple_type(a=index * 2, b=index, c=index - 1.)
      for index in range(5)
  ]
  self.assertAllEqual(actual_values, expected_values)