def test_empty_dataset(self):
  ds = tf.data.Dataset.range(0)
  assert ds.cardinality() == 0
  encoded = data_processing.to_stacked_tensor(ds)
  self.assertAllEqual(encoded, list())
def test_non_scalar_tensor(self):
  ds = tf.data.Dataset.from_tensors([[1, 2], [3, 4]])
  assert len(ds.element_spec.shape) > 1, ds.element_spec.shape
  encoded = data_processing.to_stacked_tensor(ds)
  self.assertAllEqual(encoded, [[[1, 2], [3, 4]]])
def _parse_client_dict(
    dataset: tf.data.Dataset, string_max_length: int
) -> Tuple[tf.Tensor, tf.Tensor]:
  """Parses the dictionary in the input `dataset` to key and value lists.

  Args:
    dataset: A `tf.data.Dataset` that yields `OrderedDict`. In each
      `OrderedDict` there are two key, value pairs:
        `DATASET_KEY`: A `tf.string` representing a string in the dataset.
        `DATASET_VALUE`: A rank 1 `tf.Tensor` with `dtype` `tf.int64`
          representing the value associated with the string.
    string_max_length: The maximum length of the strings. If any string is
      longer than `string_max_length`, a `ValueError` will be raised.

  Returns:
    input_strings: A rank 1 `tf.Tensor` containing the list of strings in
      `dataset`.
    string_values: A rank 2 `tf.Tensor` containing the values of
      `input_strings`.

  Raises:
    ValueError: If any string in `dataset` is longer than `string_max_length`.
  """
  parsed_dict = data_processing.to_stacked_tensor(dataset)
  input_strings = parsed_dict[DATASET_KEY]
  string_values = parsed_dict[DATASET_VALUE]
  tf.debugging.Assert(
      tf.math.logical_not(
          tf.math.reduce_any(
              tf.greater(tf.strings.length(input_strings),
                         string_max_length))),
      data=[input_strings],
      name='CHECK_STRING_LENGTH')
  return input_strings, string_values
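# Illustrative usage sketch (not part of the original suite), assuming
# DATASET_KEY and DATASET_VALUE are the module-level string constants this
# helper indexes with:
#
#   ds = tf.data.Dataset.from_tensor_slices(
#       collections.OrderedDict([
#           (DATASET_KEY, ['a', 'bb']),
#           (DATASET_VALUE, [[1], [2]]),
#       ]))
#   input_strings, string_values = _parse_client_dict(ds, string_max_length=10)
#   # input_strings == ['a', 'bb']; string_values == [[1], [2]]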
def test_basic_encoding(self):
  ds = tf.data.Dataset.range(5)
  encoded = data_processing.to_stacked_tensor(ds)
  self.assertIsInstance(encoded, tf.Tensor)
  self.assertEqual(encoded.shape, [5])
  self.assertAllEqual(encoded, list(range(5)))
def test_roundtrip(self):
  ds = tf.data.Dataset.from_tensor_slices(
      collections.OrderedDict(
          x=[1, 2, 3],
          y=[['a'], ['b'], ['c']],
      ))
  encoded = data_processing.to_stacked_tensor(ds)
  roundtripped = tf.data.Dataset.from_tensor_slices(encoded)
  self.assertAllEqual(list(ds), list(roundtripped))
def test_validates_input(self):
  not_a_dataset = [tf.constant(42)]
  with self.assertRaisesRegex(TypeError, 'ds'):
    data_processing.to_stacked_tensor(not_a_dataset)
def test_batched_with_remainder_unsupported(self):
  ds = tf.data.Dataset.range(6).batch(2, drop_remainder=False)
  with self.assertRaisesRegex(
      ValueError, 'Dataset elements must have fully-defined shapes'):
    data_processing.to_stacked_tensor(ds)
def test_batched_drop_remainder(self):
  ds = tf.data.Dataset.range(6).batch(2, drop_remainder=True)
  encoded = data_processing.to_stacked_tensor(ds)
  self.assertAllEqual(encoded, [[0, 1], [2, 3], [4, 5]])
def test_single_element(self):
  ds = tf.data.Dataset.from_tensors([42])
  encoded = data_processing.to_stacked_tensor(ds)
  self.assertAllEqual(encoded, [[42]])
def test_nested_structure(self):
  ds = tf.data.Dataset.from_tensors(collections.OrderedDict(x=42))
  encoded = data_processing.to_stacked_tensor(ds)
  self.assertAllEqual(encoded, collections.OrderedDict(x=[42]))