def test_triviaqa_truncate_text(self):
    """Checks that random truncation of text inputs preserves the answer span."""
    vocab = test_utils.sentencepiece_vocab()

    def encode_example(input_text, target_text):
        # Tokenize both fields and wrap them in a single-example dataset.
        encoded_inputs = vocab.encode(input_text)
        encoded_targets = vocab.encode(target_text)
        ds = tf.data.Dataset.from_tensors({
            'inputs': encoded_inputs,
            'targets': encoded_targets,
        })
        return ds, encoded_targets

    inputs = 'This is a very very long string which must contain the answer.'
    targets = 'long string'
    og_dataset, tokenized_targets = encode_example(inputs, targets)

    # Truncation picks a random window, so repeat to cover several positions.
    for _ in range(0, 10):
        dataset = prep.trivia_qa_truncate_inputs(
            og_dataset, output_features=None, sequence_length={'inputs': 20})
        for data in test_utils.dataset_as_text(dataset):
            # Window is exactly the requested length and still contains every
            # target token.
            self.assertLen(data['inputs'], 20)
            self.assertContainsSubset(tokenized_targets, data['inputs'])

    # Dummy input which exists in the vocab to be able to compare strings after
    # decoding.
    inputs = 'w h d n r t v'
    targets = 'h d'
    og_dataset, _ = encode_example(inputs, targets)

    for _ in range(0, 5):
        dataset = prep.trivia_qa_truncate_inputs(
            og_dataset, output_features=None, sequence_length={'inputs': 5})
        for data in test_utils.dataset_as_text(dataset):
            self.assertLen(data['inputs'], 5)
            # Round-trip through the vocab so the comparison is on strings.
            truncated_inputs = vocab.decode(data['inputs'].tolist())
            new_targets = vocab.decode(data['targets'].tolist())
            self.assertRegex(truncated_inputs, '.*' + targets + '.*')
            self.assertEqual(targets, new_targets)
def test_triviaqa_truncate(self):
    """Tests truncation when the answer is at the start, end, middle, or absent.

    Uses integer-range token ids so the expected truncation window can be
    asserted exactly.
    """
    sequence_length = {
        'inputs': 10,
    }

    # Answer starts from the 0th position of the inputs.
    dataset = tf.data.Dataset.from_tensors({
        'inputs': tf.range(0, 30),
        'targets': tf.range(0, 5)
    })
    # NOTE(review): kwarg changed from `vocabulary=None` to `output_features=None`
    # for consistency with the other call sites of this preprocessor in this file.
    dataset = prep.trivia_qa_truncate_inputs(
        dataset, output_features=None, sequence_length=sequence_length)
    assert_dataset(dataset, {
        'inputs': tf.range(0, 10),
        'targets': tf.range(0, 5)
    })

    # Answer is in the last n elements of the targets.
    dataset = tf.data.Dataset.from_tensors({
        'inputs': tf.range(0, 30),
        'targets': tf.range(27, 30)
    })
    dataset = prep.trivia_qa_truncate_inputs(
        dataset, output_features=None, sequence_length=sequence_length)
    assert_dataset(dataset, {
        'inputs': tf.range(20, 30),
        'targets': tf.range(27, 30)
    })

    # Answer is not in inputs. Example is dropped from the dataset.
    no_overlap_dataset = tf.data.Dataset.from_tensors({
        'inputs': tf.range(0, 30),
        'targets': tf.range(27, 32)
    })
    dataset = prep.trivia_qa_truncate_inputs(
        no_overlap_dataset, output_features=None,
        sequence_length=sequence_length)
    # The preprocessor should have filtered out the only example.
    self.assertEmpty(list(test_utils.dataset_as_text(dataset)))

    # Answer is in the middle of the inputs. The window start is random, so
    # repeat to cover several draws; every window must contain the answer.
    for _ in range(0, 10):
        og_dataset = tf.data.Dataset.from_tensors({
            'inputs': tf.range(0, 30),
            'targets': tf.range(10, 15),
        })
        dataset = prep.trivia_qa_truncate_inputs(
            og_dataset, output_features=None, sequence_length=sequence_length)
        for data in test_utils.dataset_as_text(dataset):
            self.assertContainsSubset(data['targets'], data['inputs'])
            self.assertLen(data['inputs'], 10)