示例#1
0
    def test_empty_dataset(self):
        """Stacking an empty dataset yields an empty result."""
        # range(-1) produces no elements at all.
        empty_ds = tf.data.Dataset.range(-1)
        assert empty_ds.cardinality() == 0

        result = data_processing.to_stacked_tensor(empty_ds)

        self.assertAllEqual(result, [])
示例#2
0
    def test_non_scalar_tensor(self):
        """A dataset holding one rank-2 element stacks into a rank-3 value."""
        matrix_ds = tf.data.Dataset.from_tensors([[1, 2], [3, 4]])
        element_shape = matrix_ds.element_spec.shape
        # Sanity-check the fixture: elements really are non-scalar.
        assert len(element_shape) > 1, element_shape

        stacked = data_processing.to_stacked_tensor(matrix_ds)

        self.assertAllEqual(stacked, [[[1, 2], [3, 4]]])
示例#3
0
def _parse_client_dict(dataset: tf.data.Dataset,
                       string_max_length: int) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses the dictionary in the input `dataset` to key and value lists.

  Args:
    dataset: A `tf.data.Dataset` that yields `OrderedDict`. In each
      `OrderedDict` there are two key, value pairs:
        `DATASET_KEY`: A `tf.string` representing a string in the dataset.
        `DATASET_VALUE`: A rank 1 `tf.Tensor` with `dtype` `tf.int64`
          representing the value associated with the string.
    string_max_length: The maximum length of the strings. If any string is
      longer than `string_max_length`, an error will be raised.

  Returns:
    input_strings: A rank 1 `tf.Tensor` containing the list of strings in
      `dataset`.
    string_values: A rank 2 `tf.Tensor` containing the values of
      `input_strings`.
  Raises:
    ValueError: If any string in `dataset` is longer than string_max_length.
      NOTE(review): at runtime `tf.debugging.Assert` raises
      `tf.errors.InvalidArgumentError`, not `ValueError` — confirm which
      exception callers actually observe.
  """
    # Collapse the per-example dicts into one dict of stacked tensors, then
    # pull out the parallel key and value tensors.
    parsed_dict = data_processing.to_stacked_tensor(dataset)
    input_strings = parsed_dict[DATASET_KEY]
    string_values = parsed_dict[DATASET_VALUE]
    # Fail if any input string exceeds `string_max_length`
    # (`tf.strings.length` counts bytes by default); the offending strings
    # are attached to the error via `data=`.
    tf.debugging.Assert(tf.math.logical_not(
        tf.math.reduce_any(
            tf.greater(tf.strings.length(input_strings), string_max_length))),
                        data=[input_strings],
                        name='CHECK_STRING_LENGTH')
    return input_strings, string_values
示例#4
0
    def test_basic_encoding(self):
        """Five scalar elements stack into one tensor of shape [5]."""
        source = tf.data.Dataset.range(5)

        stacked = data_processing.to_stacked_tensor(source)

        self.assertIsInstance(stacked, tf.Tensor)
        self.assertEqual(stacked.shape, [5])
        self.assertAllEqual(stacked, [0, 1, 2, 3, 4])
示例#5
0
    def test_roundtrip(self):
        """Re-slicing a stacked tensor recovers the original elements."""
        original = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(
                x=[1, 2, 3],
                y=[['a'], ['b'], ['c']],
            ))

        stacked = data_processing.to_stacked_tensor(original)
        restored = tf.data.Dataset.from_tensor_slices(stacked)

        self.assertAllEqual(list(original), list(restored))
示例#6
0
    def test_validates_input(self):
        """A non-Dataset argument is rejected with a TypeError."""
        bad_input = [tf.constant(42)]

        with self.assertRaisesRegex(TypeError, 'ds'):
            data_processing.to_stacked_tensor(bad_input)
示例#7
0
    def test_batched_with_remainder_unsupported(self):
        """Batching without drop_remainder leaves partial shapes, which fail."""
        ragged_batches = tf.data.Dataset.range(6).batch(2, drop_remainder=False)

        with self.assertRaisesRegex(
                ValueError, 'Dataset elements must have fully-defined shapes'):
            data_processing.to_stacked_tensor(ragged_batches)
示例#8
0
    def test_batched_drop_remainder(self):
        """Fixed-shape batches stack along a new leading axis."""
        batched = tf.data.Dataset.range(6).batch(2, drop_remainder=True)

        stacked = data_processing.to_stacked_tensor(batched)

        self.assertAllEqual(stacked, [[0, 1], [2, 3], [4, 5]])
示例#9
0
    def test_single_element(self):
        """A one-element dataset gains a leading dimension of size 1."""
        single = tf.data.Dataset.from_tensors([42])

        stacked = data_processing.to_stacked_tensor(single)

        self.assertAllEqual(stacked, [[42]])
示例#10
0
    def test_nested_structure(self):
        """Stacking preserves the dataset's nested (dict) structure."""
        dict_ds = tf.data.Dataset.from_tensors(collections.OrderedDict(x=42))

        stacked = data_processing.to_stacked_tensor(dict_ds)

        self.assertAllEqual(stacked, collections.OrderedDict(x=[42]))