Ejemplo n.º 1
0
    def test_capped_elements_raise_type_error(self, input_data):
        """Checks that `get_capped_elements` raises a TypeError on bad dtypes.

        `input_data` is parameterized elsewhere (presumably non-string data —
        confirm against the parameter list). It is sliced into a dataset and
        batched one element at a time before the call.

        Args:
          input_data: Values passed to `tf.data.Dataset.from_tensor_slices`.
        """
        expected_message = '`dataset.element_spec.dtype` must be `tf.string`.'
        batched = tf.data.Dataset.from_tensor_slices(input_data)
        batched = batched.batch(batch_size=1)

        with self.assertRaisesRegex(TypeError, expected_message):
            data_processing.get_capped_elements(
                batched, max_user_contribution=10, batch_size=1)
Ejemplo n.º 2
0
 def test_capped_elements_raise_params_value_error(self,
                                                   max_user_contribution,
                                                   batch_size,
                                                   max_string_length,
                                                   raises_regex):
     """Checks that invalid capping parameters raise a matching ValueError.

     A fixed five-element string dataset, batched one element at a time, is
     passed to `get_capped_elements` along with the parameterized arguments.

     Args:
       max_user_contribution: Forwarded to `get_capped_elements`.
       batch_size: Forwarded to `get_capped_elements` (the dataset itself is
         always batched with size 1).
       max_string_length: Forwarded to `get_capped_elements`.
       raises_regex: Pattern the ValueError message must match.
     """
     elements = ['a', 'b', 'a', 'b', 'c']
     ds = tf.data.Dataset.from_tensor_slices(elements).batch(batch_size=1)
     call_kwargs = dict(
         max_user_contribution=max_user_contribution,
         batch_size=batch_size,
         max_string_length=max_string_length)
     with self.assertRaisesRegex(ValueError, raises_regex):
         data_processing.get_capped_elements(ds, **call_kwargs)
Ejemplo n.º 3
0
    def test_capped_elements_raise_rank_value_error(self, dataset_rank):
        """Checks that elements whose rank is not 1 raise a ValueError.

        The dataset starts as scalar strings, and `.batch` is applied
        `dataset_rank` times. Each application adds one dimension to the
        elements. The test is presumably parameterized with ranks other
        than 1; confirm this against the parameter list.

        Args:
          dataset_rank: Number of times the dataset is batched.
        """
        ds = tf.data.Dataset.from_tensor_slices(['a', 'b', 'a', 'b', 'c'])
        max_user_contribution = 3
        batch_size = 1
        # `range` keeps the loop bounded even for a negative `dataset_rank`.
        # A countdown `while dataset_rank:` would never terminate in that case.
        for _ in range(dataset_rank):
            ds = ds.batch(batch_size=batch_size)

        with self.assertRaisesRegex(
                ValueError,
                'The shape of elements in `dataset` must be of rank 1.*'):
            data_processing.get_capped_elements(
                ds,
                max_user_contribution=max_user_contribution,
                batch_size=batch_size)
Ejemplo n.º 4
0
    def test_capped_elements_raise_params_value_error(self):
        """Checks that out-of-range `batch_size` / `max_user_contribution` fail.

        Each case runs in its own subTest. It calls `get_capped_elements` on a
        fixed five-element string dataset and expects a ValueError whose
        message matches the given pattern.
        """
        ds = tf.data.Dataset.from_tensor_slices(['a', 'b', 'a', 'b',
                                                 'c']).batch(batch_size=1)
        # (subtest name, expected message regex, keyword arguments)
        cases = (
            ('batch_size_value_error',
             '`batch_size` must be at least 1.',
             dict(max_user_contribution=3, batch_size=0)),
            ('max_user_contribution_value_error',
             '`max_user_contribution` must be at least 1.',
             dict(max_user_contribution=0, batch_size=1)),
        )
        for subtest_name, message_regex, call_kwargs in cases:
            with self.subTest(subtest_name):
                with self.assertRaisesRegex(ValueError, message_regex):
                    data_processing.get_capped_elements(ds, **call_kwargs)
Ejemplo n.º 5
0
 def test_capped_elements_returns_expected_values(self, input_data,
                                                  batch_size,
                                                  max_user_contribution,
                                                  expected_result):
     """Checks `get_capped_elements` output against a parameterized expectation.

     Args:
       input_data: Values sliced into the dataset.
       batch_size: Used both to batch the dataset and as the call argument.
       max_user_contribution: Forwarded to `get_capped_elements`.
       expected_result: Value the returned elements must equal.
     """
     batched_dataset = tf.data.Dataset.from_tensor_slices(input_data)
     batched_dataset = batched_dataset.batch(batch_size)
     self.assertAllEqual(
         data_processing.get_capped_elements(
             batched_dataset,
             max_user_contribution=max_user_contribution,
             batch_size=batch_size),
         expected_result)
Ejemplo n.º 6
0
 def test_capped_elements_with_max_len_returns_expected_values(
         self, input_data, batch_size, max_user_contribution,
         max_string_length, expected_result):
     """Checks capped, length-limited elements after decoding them to `str`.

     The returned tensor is converted with `.numpy()`. Each element is
     decoded as UTF-8 with errors ignored, and the resulting list is compared
     to `expected_result`.

     Args:
       input_data: Values sliced into the dataset.
       batch_size: Used both to batch the dataset and as the call argument.
       max_user_contribution: Forwarded to `get_capped_elements`.
       max_string_length: Forwarded to `get_capped_elements`.
       expected_result: List of decoded strings the output must equal.
     """
     ds = tf.data.Dataset.from_tensor_slices(input_data).batch(batch_size)
     capped = data_processing.get_capped_elements(
         ds,
         batch_size=batch_size,
         max_user_contribution=max_user_contribution,
         max_string_length=max_string_length)
     # 'ignore' drops undecodable bytes. Presumably this is needed because
     # truncating to `max_string_length` can split a multi-byte character,
     # but confirm.
     decoded = [raw.decode('utf-8', 'ignore') for raw in capped.numpy()]
     self.assertEqual(decoded, expected_result)