Example #1
    def test_triviaqa_truncate_text(self):

        vocab = test_utils.sentencepiece_vocab()

        def tokenize_and_prepare_dataset(inputs, targets):
            tokenized_inputs = vocab.encode(inputs)
            tokenized_targets = vocab.encode(targets)

            dataset = tf.data.Dataset.from_tensors({
                'inputs': tokenized_inputs,
                'targets': tokenized_targets,
            })

            return dataset, tokenized_targets

        inputs = 'This is a very very long string which must contain the answer.'
        targets = 'long string'

        og_dataset, tokenized_targets = tokenize_and_prepare_dataset(
            inputs, targets)

        for _ in range(0, 10):
            dataset = prep.trivia_qa_truncate_inputs(
                og_dataset,
                output_features=None,
                sequence_length={'inputs': 20})

            for data in test_utils.dataset_as_text(dataset):
                self.assertLen(data['inputs'], 20)
                self.assertContainsSubset(tokenized_targets, data['inputs'])

        # Dummy input whose tokens exist in the vocab, so that strings can be
        # compared after decoding.
        inputs = 'w h d n r t v'
        targets = 'h d'

        og_dataset, _ = tokenize_and_prepare_dataset(inputs, targets)

        for _ in range(0, 5):
            dataset = prep.trivia_qa_truncate_inputs(
                og_dataset,
                output_features=None,
                sequence_length={'inputs': 5})

            for data in test_utils.dataset_as_text(dataset):
                self.assertLen(data['inputs'], 5)
                truncated_inputs = vocab.decode(data['inputs'].tolist())
                new_targets = vocab.decode(data['targets'].tolist())
                self.assertRegex(truncated_inputs, '.*' + targets + '.*')
                self.assertEqual(targets, new_targets)
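
A minimal pure-Python sketch of the windowing behavior this test exercises. The real prep.trivia_qa_truncate_inputs operates on tf.data examples; the helper name and logic below are illustrative assumptions, not the library implementation:

import random

def truncate_keeping_answer(inputs, targets, max_len):
    # Hypothetical helper: return a random window of `max_len` tokens from
    # `inputs` that still contains the answer span `targets`, or None when
    # the answer never occurs (such examples are dropped by the preprocessor).
    # Assumes len(targets) <= max_len.
    n, m = len(inputs), len(targets)
    if n <= max_len:
        return list(inputs)
    # Every position where the answer occurs as a contiguous span.
    starts = [i for i in range(n - m + 1) if inputs[i:i + m] == targets]
    if not starts:
        return None
    ans = random.choice(starts)
    # The window must end inside the inputs and still cover the whole answer.
    start = random.randint(max(0, ans + m - max_len), min(ans, n - max_len))
    return inputs[start:start + max_len]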
Example #2
    def test_fill_in_the_blank_sized(self):
        def _validate_data(data, valid_bins, og_length=15):
            # Verify and strip the task prefix from the input string.
            self.assertTrue(data['inputs'].startswith('fill: '))
            inp = data['inputs'].replace('fill: ', '')
            # Split input into chunks according to blank locations.
            inp_split = inp.split('_')
            # Make sure that there is exactly one blank (could be at beginning/end).
            self.assertLen(inp_split, 3)
            # Make sure reconstruction is accurate.
            reconstructed = ''.join([inp_split[0], data['targets']] +
                                    inp_split[2:])
            self.assertEqual(reconstructed, original)
            # Make sure blank size is correctly chosen.
            blank_bin = int(inp_split[1])
            self.assertIn(blank_bin, valid_bins)
            blank_size = len(data['targets'].split())
            self.assertGreaterEqual(blank_size, min(og_length, valid_bins[0]))
            self.assertLessEqual(blank_size, valid_bins[-1])
            return blank_size, blank_bin

        num_tries = 250
        original = 'This is a long test with lots of words to see if it works ok.'
        dataset = tf.data.Dataset.from_tensor_slices(
            {'text': [original] * num_tries})
        dataset = prep.fill_in_the_blank_sized(dataset, [1, 4])
        num_outputs = 0
        for data in test_utils.dataset_as_text(dataset):
            blank_size, blank_bin = _validate_data(data, [1, 4])
            if blank_size <= 2:
                self.assertEqual(blank_bin, 1)
            else:
                self.assertEqual(blank_bin, 4)
            num_outputs += 1
        self.assertEqual(num_tries, num_outputs)

        # Check case where bin size is larger than text.
        dataset = tf.data.Dataset.from_tensor_slices(
            {'text': [original] * num_tries})
        dataset = prep.fill_in_the_blank_sized(dataset, [1024])
        self.assertEmpty(list(test_utils.dataset_as_text(dataset)))
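
One consistent reading of the bin assertions above is nearest-bin rounding of the sampled blank length. A minimal sketch of that rule, inferred from the test rather than taken from the library:

def closest_bin(blank_size, bins):
    # Map an actual blank length to the nearest advertised bin size.
    # With bins [1, 4]: lengths 1-2 map to bin 1 and lengths 3-4 map to
    # bin 4, matching the branch checked in the test above.
    return min(bins, key=lambda b: abs(b - blank_size))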
Example #3
 def test_random_split_text(self):
     num_tries = 10
     original = '%s' % list(range(100))
     dataset = tf.data.Dataset.from_tensor_slices(
         {'text': [original] * num_tries})
     dataset = prep.random_split_text(dataset)
     out = []
     for data in test_utils.dataset_as_text(dataset):
         out.append(data['text'])
     reconstructed = ' '.join(out)
     ref = ' '.join([original] * num_tries)
     self.assertEqual(reconstructed, ref)
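
The invariant checked here is only that the random pieces concatenate back to the original text. A pure-Python sketch of one way to produce such a split (the piece-length limit is an assumption for illustration; prep.random_split_text is the tf.data implementation):

import random

def random_split_text(text, max_words_per_piece=16):
    # Cut `text` at random word boundaries into consecutive pieces whose
    # space-joined concatenation reconstructs the original string.
    words = text.split(' ')
    pieces, i = [], 0
    while i < len(words):
        n = random.randint(1, max_words_per_piece)
        pieces.append(' '.join(words[i:i + n]))
        i += n
    return pieces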
Example #4

  # Helper from the enclosing test: `og_dataset` and `original` (a token
  # list wrapped in the dataset) are assumed to be defined by that test's
  # setup, which this snippet omits.
  def _verify_split(length, n_expected_outputs):
   ds = prep.split_tokens(
       og_dataset, unused_vocabulary=None, max_tokens_per_segment=length)
   outputs = list(test_utils.dataset_as_text(ds))
   self.assertLen(outputs, n_expected_outputs)
   reconstructed = []
   for ex in outputs[:-1]:
     t = ex['targets']
     self.assertLen(t, length)
     reconstructed.extend(t)
   final_t = outputs[-1]['targets']
   self.assertLessEqual(len(final_t), length)
   reconstructed.extend(final_t)
   self.assertEqual(reconstructed, original)
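
The segmentation contract verified above: every segment except the last holds exactly max_tokens_per_segment tokens, and concatenating all segments reconstructs the original token list. A plain-Python stand-in (not the tf.data implementation):

def split_tokens(tokens, max_len):
    # Split a token list into consecutive segments of at most max_len tokens.
    return [tokens[i:i + max_len] for i in range(0, len(tokens), max_len)]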
Example #5
    def test_prefix_lm(self):
        num_tries = 100
        original = 'This is a long test with lots of words to see if it works ok.'
        dataset = tf.data.Dataset.from_tensor_slices(
            {'text': [original] * num_tries})
        dataset = prep.prefix_lm(dataset)
        for data in test_utils.dataset_as_text(dataset):
            inputs = data['inputs'].replace('prefix: ', '')
            targets = data['targets']

            reconstructed = inputs
            if inputs:
                reconstructed += ' '
            reconstructed += targets

            self.assertEqual(reconstructed, original)
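
A sketch of the split this test verifies: cut the text at a random word boundary, mark the left side with the 'prefix: ' tag, and predict the right side. The pivot range is inferred from the reconstruction check above; prep.prefix_lm is the tf.data implementation:

import random

def prefix_lm_split(text):
    # Cut at a random word boundary. The prefix may be empty, but the
    # targets always keep at least one word, so the reconstruction
    # inputs + ' ' + targets (or just targets) equals the original text.
    words = text.split(' ')
    pivot = random.randint(0, len(words) - 1)
    inputs = 'prefix: ' + ' '.join(words[:pivot])
    targets = ' '.join(words[pivot:])
    return inputs, targets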
Example #6

  def test_fill_in_the_blank(self):
   num_tries = 1000
   original = 'This is a long test with lots of words to see if it works ok.'
   dataset = tf.data.Dataset.from_tensor_slices(
       {'text': [original] * num_tries})
   dataset = prep.fill_in_the_blank(dataset)
   for data in test_utils.dataset_as_text(dataset):
     # Remove the prefix from the start of the input string
     inp = data['inputs'].replace('fill: ', '')
     # Split output into chunks according to X locations.
     out_split = data['targets'].split('X')
     # Make sure that there is at least one blank
     self.assertGreater(len(out_split), 1)
     # Remove leading/trailing whitespace and any empty chunks
     out_split = [o.strip() for o in out_split if o]
     # Replace 'X' with entries from out_split by popping from the front
     reconstructed = ''.join(
         [i if i != 'X' else out_split.pop(0) for i in inp])
     self.assertEqual(reconstructed, original)
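
The reconstruction logic above implies a symmetric encoding: each blanked run of words appears as a single 'X' in the inputs, and each kept run appears as a single 'X' in the targets. A pure-Python sketch of producing such a pair (the masking rate and helper name are assumptions for illustration):

import random

def fill_in_the_blank_pair(text, blank_prob=0.3):
    words = text.split(' ')
    mask = [random.random() < blank_prob for _ in words]
    # Guarantee at least one blanked word and at least one visible word.
    if not any(mask):
        mask[random.randrange(len(words))] = True
    if all(mask):
        mask[random.randrange(len(words))] = False

    def collapse(keep_masked):
        # Emit the kept words verbatim; collapse each hidden run into 'X'.
        out, prev_hidden = [], False
        for word, masked in zip(words, mask):
            hidden = masked != keep_masked
            if hidden:
                if not prev_hidden:
                    out.append('X')
            else:
                out.append(word)
            prev_hidden = hidden
        return ' '.join(out)

    return 'fill: ' + collapse(False), collapse(True)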
Example #7
    def test_triviaqa_truncate(self):

        sequence_length = {
            'inputs': 10,
        }

        # Answer starts at position 0 of the inputs. With 30 input tokens,
        # a window of 10, and the answer spanning positions 0-4, the only
        # window that covers the whole answer is [0, 10), so the output is
        # deterministic.
        dataset = tf.data.Dataset.from_tensors({
            'inputs': tf.range(0, 30),
            'targets': tf.range(0, 5)
        })

        dataset = prep.trivia_qa_truncate_inputs(
            dataset, vocabulary=None, sequence_length=sequence_length)

        assert_dataset(dataset, {
            'inputs': tf.range(0, 10),
            'targets': tf.range(0, 5)
        })

        # Answer occupies the last three input positions (27-29). The window
        # must cover them and end by position 30, so it is forced to [20, 30).
        dataset = tf.data.Dataset.from_tensors({
            'inputs': tf.range(0, 30),
            'targets': tf.range(27, 30)
        })

        dataset = prep.trivia_qa_truncate_inputs(
            dataset, vocabulary=None, sequence_length=sequence_length)

        assert_dataset(dataset, {
            'inputs': tf.range(20, 30),
            'targets': tf.range(27, 30)
        })

        # Answer is not in the inputs, so the example is dropped from the
        # dataset.
        no_overlap_dataset = tf.data.Dataset.from_tensors({
            'inputs': tf.range(0, 30),
            'targets': tf.range(27, 32)
        })

        dataset = prep.trivia_qa_truncate_inputs(
            no_overlap_dataset,
            vocabulary=None,
            sequence_length=sequence_length)

        num_examples = 0
        for _ in test_utils.dataset_as_text(dataset):
            num_examples += 1

        self.assertEqual(num_examples, 0)

        # Answer is in the middle of the inputs.
        for _ in range(0, 10):
            og_dataset = tf.data.Dataset.from_tensors({
                'inputs': tf.range(0, 30),
                'targets': tf.range(10, 15),
            })

            dataset = prep.trivia_qa_truncate_inputs(
                og_dataset, vocabulary=None, sequence_length=sequence_length)
            for data in test_utils.dataset_as_text(dataset):
                self.assertContainsSubset(data['targets'], data['inputs'])
                self.assertLen(data['inputs'], 10)