Beispiel #1
0
    def test_get_top_elements_raise_type_error(self, input_data):
        """`get_top_elements` must raise `TypeError` for non-string datasets.

        Args:
          input_data: Values turned into a dataset with batch size 1.
            Presumably non-string values supplied by a parameterization
            decorator outside this view; TODO confirm.
        """
        unbatched = tf.data.Dataset.from_tensor_slices(input_data)
        batched = unbatched.batch(batch_size=1)
        # The trailing '.' is a regex wildcard, so the pattern also matches
        # messages that continue with another character after `tf.string`.
        expected_pattern = '`dataset.element_spec.dtype` must be `tf.string`.'
        with self.assertRaisesRegex(TypeError, expected_pattern):
            data_processing.get_top_elements(
                batched, max_user_contribution=10)
Beispiel #2
0
    def test_get_top_elements_raise_params_value_error(
            self, max_user_contribution, max_string_length, raises_regex):
        """Invalid limit arguments must make `get_top_elements` raise ValueError.

        Args:
          max_user_contribution: Value forwarded to `get_top_elements`.
          max_string_length: Value forwarded to `get_top_elements`.
          raises_regex: Pattern the raised ValueError message must match.
        """
        tokens = ['a', 'b', 'a', 'b', 'c']
        dataset = tf.data.Dataset.from_tensor_slices(tokens).batch(batch_size=1)
        limits = dict(max_user_contribution=max_user_contribution,
                      max_string_length=max_string_length)

        with self.assertRaisesRegex(ValueError, raises_regex):
            data_processing.get_top_elements(dataset, **limits)
Beispiel #3
0
    def test_get_top_elements_raise_rank_value_error(self, dataset_rank):
        """`get_top_elements` must reject datasets whose elements are not rank 1.

        `from_tensor_slices` on a flat list of strings yields scalar (rank 0)
        elements. Each `batch` call adds one outer dimension, so batching
        `dataset_rank` times yields elements of rank `dataset_rank`.

        Args:
          dataset_rank: How many times to batch the dataset. Presumably a
            non-negative int other than 1, supplied by a parameterization
            decorator outside this view; TODO confirm.
        """
        ds = tf.data.Dataset.from_tensor_slices(['a', 'b', 'a', 'b', 'c'])
        batch_size = 1
        max_user_contribution = 1
        # A bounded loop (rather than `while dataset_rank:` with a decrement)
        # cannot spin forever when `dataset_rank` is negative.
        for _ in range(dataset_rank):
            ds = ds.batch(batch_size=batch_size)

        with self.assertRaisesRegex(
                ValueError,
                'The shape of elements in `dataset` must be of rank 1.*'):
            data_processing.get_top_elements(ds, max_user_contribution)
Beispiel #4
0
    def test_get_top_elements_raise_params_value_error(self):
        """A `max_user_contribution` of 0 must be rejected with ValueError."""
        letters = ['a', 'b', 'a', 'b', 'c']
        dataset = tf.data.Dataset.from_tensor_slices(letters)
        dataset = dataset.batch(batch_size=1)
        message_pattern = '`max_user_contribution` must be at least 1.'

        with self.assertRaisesRegex(ValueError, message_pattern):
            data_processing.get_top_elements(dataset, max_user_contribution=0)
Beispiel #5
0
 def test_get_top_elements_returns_expected_values(
         self, input_data, batch_size, max_user_contribution, expected_result):
     """The set of top elements returned must equal `expected_result`.

     Comparison is set-based, so element order and duplicates are ignored.

     Args:
       input_data: Values turned into a dataset and batched by `batch_size`.
       batch_size: Batch size applied before calling `get_top_elements`.
       max_user_contribution: Forwarded positionally to `get_top_elements`.
       expected_result: Expected elements. Presumably byte strings, since
         `.numpy()` on a string tensor yields bytes; TODO confirm.
     """
     source = tf.data.Dataset.from_tensor_slices(input_data)
     result_tensor = data_processing.get_top_elements(
         source.batch(batch_size), max_user_contribution)
     observed = set(result_tensor.numpy())
     self.assertSetEqual(observed, set(expected_result))
Beispiel #6
0
 def test_get_top_elements_with_max_len_returns_expected_values(
         self, input_data, batch_size, max_user_contribution,
         max_string_length, expected_result):
     """Top elements, decoded as UTF-8, must equal `expected_result` as a set.

     Args:
       input_data: Values turned into a dataset and batched by `batch_size`.
       batch_size: Batch size applied before calling `get_top_elements`.
       max_user_contribution: Forwarded positionally to `get_top_elements`.
       max_string_length: Forwarded as a keyword to `get_top_elements`.
       expected_result: Expected elements as `str` values.
     """
     batched = tf.data.Dataset.from_tensor_slices(input_data).batch(batch_size)
     raw_top = data_processing.get_top_elements(
         batched, max_user_contribution, max_string_length=max_string_length)
     # 'ignore' silently drops undecodable bytes. This is presumably because
     # limiting to `max_string_length` can split a multi-byte UTF-8
     # character; TODO confirm.
     decoded = {item.decode('utf-8', 'ignore') for item in raw_top.numpy()}
     self.assertSetEqual(decoded, set(expected_result))