Example #1
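The snippets below are methods of a TensorFlow test case and rely on module-level fixtures that the excerpts do not show. A minimal sketch of the assumed preamble, inferred from the names used (tf, prep, test_utils, assert_dataset); the exact module paths and the test base class are assumptions and may differ between T5 versions:

import tensorflow.compat.v2 as tf

from t5.data import preprocessors as prep  # provides trivia_qa_truncate_inputs (assumed path)
from t5.data import test_utils             # provides sentencepiece_vocab, dataset_as_text (assumed path)

# Convenience alias used in Example #2 (assumed to come from test_utils).
assert_dataset = test_utils.assert_dataset


class PreprocessorsTest(tf.test.TestCase):
    # The test methods shown in the examples belong here.
    pass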
    def test_triviaqa_truncate_text(self):

        vocab = test_utils.sentencepiece_vocab()

        def tokenize_and_prepare_dataset(inputs, targets):
            tokenized_inputs = vocab.encode(inputs)
            tokenized_targets = vocab.encode(targets)

            dataset = tf.data.Dataset.from_tensors({
                'inputs': tokenized_inputs,
                'targets': tokenized_targets,
            })

            return dataset, tokenized_targets

        inputs = 'This is a very very long string which must contain the answer.'
        targets = 'long string'

        og_dataset, tokenized_targets = tokenize_and_prepare_dataset(
            inputs, targets)

        for _ in range(10):
            dataset = prep.trivia_qa_truncate_inputs(
                og_dataset,
                output_features=None,
                sequence_length={'inputs': 20})

            for data in test_utils.dataset_as_text(dataset):
                self.assertLen(data['inputs'], 20)
                self.assertContainsSubset(tokenized_targets, data['inputs'])

        # Dummy input made of tokens that exist in the vocab so the strings can
        # be compared after decoding.
        inputs = 'w h d n r t v'
        targets = 'h d'

        og_dataset, _ = tokenize_and_prepare_dataset(inputs, targets)

        for _ in range(5):
            dataset = prep.trivia_qa_truncate_inputs(
                og_dataset,
                output_features=None,
                sequence_length={'inputs': 5})

            for data in test_utils.dataset_as_text(dataset):
                self.assertLen(data['inputs'], 5)
                truncated_inputs = vocab.decode(data['inputs'].tolist())
                new_targets = vocab.decode(data['targets'].tolist())
                self.assertRegex(truncated_inputs, '.*' + targets + '.*')
                self.assertEqual(targets, new_targets)
Example #2
    def test_triviaqa_truncate(self):

        sequence_length = {
            'inputs': 10,
        }

        # Answer starts from the 0th position of the inputs.
        dataset = tf.data.Dataset.from_tensors({
            'inputs': tf.range(0, 30),
            'targets': tf.range(0, 5)
        })

        dataset = prep.trivia_qa_truncate_inputs(
            dataset, vocabulary=None, sequence_length=sequence_length)

        assert_dataset(dataset, {
            'inputs': tf.range(0, 10),
            'targets': tf.range(0, 5)
        })

        # Answer is in the last n elements of the inputs.
        dataset = tf.data.Dataset.from_tensors({
            'inputs': tf.range(0, 30),
            'targets': tf.range(27, 30)
        })

        dataset = prep.trivia_qa_truncate_inputs(
            dataset, vocabulary=None, sequence_length=sequence_length)

        assert_dataset(dataset, {
            'inputs': tf.range(20, 30),
            'targets': tf.range(27, 30)
        })

        # Answer is not in inputs, so the example is dropped from the dataset.
        no_overlap_dataset = tf.data.Dataset.from_tensors({
            'inputs': tf.range(0, 30),
            'targets': tf.range(27, 32)
        })

        dataset = prep.trivia_qa_truncate_inputs(
            no_overlap_dataset,
            vocabulary=None,
            sequence_length=sequence_length)

        num_examples = 0
        for _ in test_utils.dataset_as_text(dataset):
            num_examples += 1

        self.assertEqual(num_examples, 0)

        # Answer is in the middle of the inputs.
        for _ in range(10):
            og_dataset = tf.data.Dataset.from_tensors({
                'inputs': tf.range(0, 30),
                'targets': tf.range(10, 15),
            })

            dataset = prep.trivia_qa_truncate_inputs(
                og_dataset, vocabulary=None, sequence_length=sequence_length)
            for data in test_utils.dataset_as_text(dataset):
                self.assertContainsSubset(data['targets'], data['inputs'])
                self.assertLen(data['inputs'], 10)
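Outside of the test harness, the preprocessor is just a dataset-to-dataset transform. A minimal sketch of applying it to a toy tokenized example, assuming the signature used in Example #1 (output_features plus sequence_length; Example #2 shows a variant that takes vocabulary instead), with an assumed module path:

import tensorflow.compat.v2 as tf
from t5.data import preprocessors as prep  # assumed module path

# Toy "tokenized" example: a 100-token context whose answer span is tokens 40..44.
ds = tf.data.Dataset.from_tensors({
    'inputs': tf.range(0, 100),
    'targets': tf.range(40, 45),
})

# Randomly select a 20-token window of the inputs that still contains the
# answer tokens; examples whose inputs do not contain the targets are dropped.
ds = prep.trivia_qa_truncate_inputs(
    ds, output_features=None, sequence_length={'inputs': 20})

for ex in ds.as_numpy_iterator():
    print(len(ex['inputs']), ex['targets'])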