def test_roundtrip_sequence_of_scalars(self):
  """Checks serialize/deserialize round-trip of a dataset of scalars."""
  dataset = tf.data.Dataset.range(5).map(lambda value: value * 2)
  serialized = value_serialization._serialize_dataset(dataset)
  roundtripped = value_serialization._deserialize_dataset_from_graph_def(
      serialized,
      element_type=computation_types.to_type(dataset.element_spec))
  # Both the element spec and the concrete values must survive the trip.
  self.assertEqual(dataset.element_spec, roundtripped.element_spec)
  self.assertAllEqual(list(roundtripped), [i * 2 for i in range(5)])
def test_roundtrip_sequence_of_tuples(self):
  """Checks serialize/deserialize round-trip of a dataset of 3-tuples."""
  dataset = tf.data.Dataset.range(5).map(
      lambda value: (value * 2, tf.cast(value, tf.int32),
                     tf.cast(value - 1, tf.float32)))
  serialized = value_serialization._serialize_dataset(dataset)
  roundtripped = value_serialization._deserialize_dataset_from_graph_def(
      serialized,
      element_type=computation_types.to_type(dataset.element_spec))
  self.assertEqual(dataset.element_spec, roundtripped.element_spec)
  # Tuple structure and per-position dtypes must be preserved.
  self.assertAllEqual(
      list(roundtripped), [(i * 2, i, i - 1.) for i in range(5)])
def test_roundtrip_sequence_of_singleton_tuples(self):
  """Checks that 1-element tuples are not unwrapped by the round-trip."""
  dataset = tf.data.Dataset.range(5).map(lambda value: (value,))
  serialized = value_serialization._serialize_dataset(dataset)
  roundtripped = value_serialization._deserialize_dataset_from_graph_def(
      serialized,
      element_type=computation_types.to_type(dataset.element_spec))
  self.assertEqual(dataset.element_spec, roundtripped.element_spec)
  self.assertAllEqual([(i,) for i in range(5)], list(roundtripped))
def test_roundtrip_sequence_of_namedtuples(self):
  """Checks serialize/deserialize round-trip of namedtuple elements."""
  test_tuple_type = collections.namedtuple('TestTuple', ['a', 'b', 'c'])

  def make_test_tuple(value):
    # Mix dtypes so the round-trip must preserve each field's type.
    return test_tuple_type(
        a=value * 2,
        b=tf.cast(value, tf.int32),
        c=tf.cast(value - 1, tf.float32))

  dataset = tf.data.Dataset.range(5).map(make_test_tuple)
  serialized = value_serialization._serialize_dataset(dataset)
  roundtripped = value_serialization._deserialize_dataset_from_graph_def(
      serialized,
      element_type=computation_types.to_type(dataset.element_spec))
  self.assertEqual(dataset.element_spec, roundtripped.element_spec)
  self.assertAllEqual(
      list(roundtripped),
      [test_tuple_type(a=i * 2, b=i, c=i - 1.) for i in range(5)])
def test_roundtrip_sequence_of_nested_structures(self):
  """Checks serialize/deserialize round-trip of nested dict/tuple elements.

  The element structure nests an `OrderedDict` containing a namedtuple and
  another `OrderedDict`, exercising structure preservation across the
  GraphDef serialization boundary.
  """
  test_tuple_type = collections.namedtuple('TestTuple', ['u', 'v'])

  def _make_nested_tf_structure(x):
    # Keys are deliberately out of alphabetical order ('b' before 'a') so
    # the test would catch any key reordering during serialization.
    return collections.OrderedDict(
        b=tf.cast(x, tf.int32),
        a=(
            x,
            test_tuple_type(x * 2, x * 3),
            collections.OrderedDict(x=x**2, y=x**3),
        ))

  x = tf.data.Dataset.range(5).map(_make_nested_tf_structure)
  # Fix: local variable was misspelled `serialzied_bytes`.
  serialized_bytes = value_serialization._serialize_dataset(x)
  y = value_serialization._deserialize_dataset_from_graph_def(
      serialized_bytes,
      element_type=computation_types.to_type(x.element_spec))
  # The deserialized dataset must reproduce the full nested element spec,
  # including key order. (A previous comment claimed the OrderedDict was
  # lost and became a plain `dict`; the assertion below shows otherwise.)
  expected_element_spec = collections.OrderedDict(
      b=tf.TensorSpec([], tf.int32),
      a=(tf.TensorSpec([], tf.int64),
         test_tuple_type(
             tf.TensorSpec([], tf.int64), tf.TensorSpec([], tf.int64)),
         collections.OrderedDict(
             x=tf.TensorSpec([], tf.int64), y=tf.TensorSpec([], tf.int64))))
  self.assertEqual(y.element_spec, expected_element_spec)

  def _build_expected_structure(x):
    # Mirrors `_make_nested_tf_structure` with plain Python values.
    return collections.OrderedDict(
        b=x,
        a=(
            x,
            test_tuple_type(x * 2, x * 3),
            collections.OrderedDict(x=x**2, y=x**3),
        ))

  expected_values = (_build_expected_structure(x) for x in range(5))
  for actual, expected in zip(y, expected_values):
    self.assertAllClose(actual, expected)
def test_serialize_sequence_bytes_too_large(self):
  """A zero-byte size budget must make serialization raise `ValueError`."""
  dataset = tf.data.Dataset.range(5)
  with self.assertRaisesRegex(
      ValueError, r'Serialized size .* exceeds maximum allowed'):
    value_serialization._serialize_dataset(
        dataset, max_serialized_size_bytes=0)
def test_serialize_sequence_not_a_dataset(self):
  """Passing a non-dataset value must raise a descriptive `TypeError`."""
  with self.assertRaisesRegex(TypeError, r'Expected .*Dataset.* found int'):
    value_serialization._serialize_dataset(5)