def test_triviaqa_truncate_text(self):
  vocab = test_utils.sentencepiece_vocab()

  def tokenize_and_prepare_dataset(inputs, targets):
    tokenized_inputs = vocab.encode(inputs)
    tokenized_targets = vocab.encode(targets)
    dataset = tf.data.Dataset.from_tensors({
        'inputs': tokenized_inputs,
        'targets': tokenized_targets,
    })
    return dataset, tokenized_targets

  inputs = 'This is a very very long string which must contain the answer.'
  targets = 'long string'
  og_dataset, tokenized_targets = tokenize_and_prepare_dataset(
      inputs, targets)

  for _ in range(0, 10):
    dataset = prep.trivia_qa_truncate_inputs(
        og_dataset, output_features=None, sequence_length={'inputs': 20})
    for data in test_utils.dataset_as_text(dataset):
      self.assertLen(data['inputs'], 20)
      self.assertContainsSubset(tokenized_targets, data['inputs'])

  # Dummy input which exists in the vocab to be able to compare strings after
  # decoding.
  inputs = 'w h d n r t v'
  targets = 'h d'
  og_dataset, _ = tokenize_and_prepare_dataset(inputs, targets)

  for _ in range(0, 5):
    dataset = prep.trivia_qa_truncate_inputs(
        og_dataset, output_features=None, sequence_length={'inputs': 5})
    for data in test_utils.dataset_as_text(dataset):
      self.assertLen(data['inputs'], 5)
      truncated_inputs = vocab.decode(data['inputs'].tolist())
      new_targets = vocab.decode(data['targets'].tolist())
      self.assertRegex(truncated_inputs, '.*' + targets + '.*')
      self.assertEqual(targets, new_targets)
def test_fill_in_the_blank_sized(self):

  def _validate_data(data, valid_bins, og_length=15):
    # Remove the prefix from the start of the input string.
    self.assertTrue(data['inputs'].startswith('fill: '))
    inp = data['inputs'].replace('fill: ', '')
    # Split input into chunks according to blank locations.
    inp_split = inp.split('_')
    # Make sure that there is exactly one blank (could be at beginning/end).
    self.assertLen(inp_split, 3)
    # Make sure reconstruction is accurate.
    reconstructed = ''.join([inp_split[0], data['targets']] + inp_split[2:])
    self.assertEqual(reconstructed, original)
    # Make sure blank size is correctly chosen.
    blank_bin = int(inp_split[1])
    self.assertIn(blank_bin, valid_bins)
    blank_size = len(data['targets'].split())
    self.assertGreaterEqual(blank_size, min(og_length, valid_bins[0]))
    self.assertLessEqual(blank_size, valid_bins[-1])
    return blank_size, blank_bin

  num_tries = 250
  original = 'This is a long test with lots of words to see if it works ok.'
  dataset = tf.data.Dataset.from_tensor_slices(
      {'text': [original] * num_tries})
  dataset = prep.fill_in_the_blank_sized(dataset, [1, 4])
  num_outputs = 0
  for data in test_utils.dataset_as_text(dataset):
    blank_size, blank_bin = _validate_data(data, [1, 4])
    if blank_size <= 2:
      self.assertEqual(blank_bin, 1)
    else:
      self.assertEqual(blank_bin, 4)
    num_outputs += 1
  self.assertEqual(num_tries, num_outputs)

  # Check case where bin size is larger than the text.
  dataset = tf.data.Dataset.from_tensor_slices(
      {'text': [original] * num_tries})
  dataset = prep.fill_in_the_blank_sized(dataset, [1024])
  self.assertEmpty(list(test_utils.dataset_as_text(dataset)))
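
# For reference, a hypothetical pure-Python sketch of a transformation that
# would satisfy the assertions above (an illustrative assumption, not the
# actual implementation of `fill_in_the_blank_sized` in preprocessors.py):
# drop a random span of words, label the gap with the closest bin size
# between underscores, and drop the example entirely when even the smallest
# bin exceeds the text length.
def _fill_in_the_blank_sized_sketch(text, bins):
  import random  # local import to keep the sketch self-contained
  words = text.split(' ')
  if min(bins) > len(words):
    return None  # no bin fits; the example is dropped
  blank_size = random.randint(min(bins), min(len(words), max(bins)))
  # Label the blank with the bin closest to the true blank size.
  size_bin = min(bins, key=lambda b: abs(b - blank_size))
  start = random.randint(0, len(words) - blank_size)
  dropped = words[start:start + blank_size]
  kept = words[:start] + ['_%d_' % size_bin] + words[start + blank_size:]
  return {
      'inputs': 'fill: ' + ' '.join(kept),
      'targets': ' '.join(dropped),
  }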
def test_random_split_text(self):
  num_tries = 10
  original = '%s' % list(range(100))
  dataset = tf.data.Dataset.from_tensor_slices(
      {'text': [original] * num_tries})
  dataset = prep.random_split_text(dataset)
  out = []
  for data in test_utils.dataset_as_text(dataset):
    out.append(data['text'])
  reconstructed = ' '.join(out)
  ref = ' '.join([original] * num_tries)
  self.assertEqual(reconstructed, ref)
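
# A minimal sketch of the invariant this test relies on (illustrative, not
# the actual `random_split_text` implementation; the segment-length
# parameter is an assumption): partition the text at random word boundaries
# so that re-joining all pieces with single spaces reproduces the input
# exactly.
def _random_split_text_sketch(text, max_words_per_segment=10):
  import random  # local import to keep the sketch self-contained
  words = text.split(' ')
  pieces = []
  start = 0
  while start < len(words):
    n = random.randint(1, max_words_per_segment)
    pieces.append(' '.join(words[start:start + n]))
    start += n
  return pieces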
def _verify_split(length, n_expected_outputs):
  # Nested helper for the split-tokens test; `og_dataset` and `original`
  # come from the enclosing test method's scope.
  ds = prep.split_tokens(
      og_dataset, unused_vocabulary=None, max_tokens_per_segment=length)
  outputs = list(test_utils.dataset_as_text(ds))
  self.assertLen(outputs, n_expected_outputs)
  reconstructed = []
  # Every segment except the last must contain exactly `length` tokens.
  for ex in outputs[:-1]:
    t = ex['targets']
    self.assertLen(t, length)
    reconstructed.extend(t)
  # The final segment may be shorter than `length`.
  final_t = outputs[-1]['targets']
  self.assertLessEqual(len(final_t), length)
  reconstructed.extend(final_t)
  # Concatenating all segments must reproduce the original token sequence.
  self.assertEqual(reconstructed, original)
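
# An example of how the enclosing test might drive `_verify_split` (the
# concrete values are illustrative assumptions, not taken from the original
# file): with a 100-token `original`, segments of at most 10 tokens should
# yield exactly 10 outputs, while segments of at most 30 tokens should yield
# 4 (three full segments plus a 10-token remainder):
#
#   _verify_split(10, 10)
#   _verify_split(30, 4)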
def test_prefix_lm(self):
  num_tries = 100
  original = 'This is a long test with lots of words to see if it works ok.'
  dataset = tf.data.Dataset.from_tensor_slices(
      {'text': [original] * num_tries})
  dataset = prep.prefix_lm(dataset)
  for data in test_utils.dataset_as_text(dataset):
    inputs = data['inputs'].replace('prefix: ', '')
    targets = data['targets']
    reconstructed = inputs
    if inputs:
      reconstructed += ' '
    reconstructed += targets
    self.assertEqual(reconstructed, original)
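
# For reference, a hypothetical sketch of a split satisfying the
# reconstruction check above (not the actual `prefix_lm` implementation):
# cut the text at a random word boundary, emit the prefix as 'inputs' and
# the remainder as 'targets', never leaving 'targets' empty.
def _prefix_lm_sketch(text):
  import random  # local import to keep the sketch self-contained
  words = text.split(' ')
  # Choose the boundary so that 'targets' always has at least one word.
  split_point = random.randint(0, len(words) - 1)
  return {
      'inputs': 'prefix: ' + ' '.join(words[:split_point]),
      'targets': ' '.join(words[split_point:]),
  }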
def test_fill_in_the_blank(self):
  num_tries = 1000
  original = 'This is a long test with lots of words to see if it works ok.'
  dataset = tf.data.Dataset.from_tensor_slices(
      {'text': [original] * num_tries})
  dataset = prep.fill_in_the_blank(dataset)
  for data in test_utils.dataset_as_text(dataset):
    # Remove the prefix from the start of the input string.
    inp = data['inputs'].replace('fill: ', '')
    # Split output into chunks according to X locations.
    out_split = data['targets'].split('X')
    # Make sure that there is at least one blank.
    self.assertGreater(len(out_split), 1)
    # Remove leading/trailing whitespace and any empty chunks.
    out_split = [o.strip() for o in out_split if o]
    # Replace 'X' with entries from out_split by popping from the front.
    reconstructed = ''.join(
        [i if i != 'X' else out_split.pop(0) for i in inp])
    self.assertEqual(reconstructed, original)
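
# A hypothetical single-blank transformation consistent with the checks
# above (the real `fill_in_the_blank` may blank multiple spans, and the
# target format here is a guess that merely satisfies the reconstruction
# invariant): replace one random word with the marker 'X' and emit the
# dropped word, delimited by 'X', as the target.
def _fill_in_the_blank_sketch(text):
  import random  # local import to keep the sketch self-contained
  words = text.split(' ')
  idx = random.randrange(len(words))
  dropped = words[idx]
  words[idx] = 'X'
  return {
      'inputs': 'fill: ' + ' '.join(words),
      'targets': 'X %s X' % dropped,
  }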
def test_triviaqa_truncate(self):
  sequence_length = {
      'inputs': 10,
  }

  # Answer starts at the 0th position of the inputs.
  dataset = tf.data.Dataset.from_tensors({
      'inputs': tf.range(0, 30),
      'targets': tf.range(0, 5)
  })
  dataset = prep.trivia_qa_truncate_inputs(
      dataset, output_features=None, sequence_length=sequence_length)
  assert_dataset(dataset, {
      'inputs': tf.range(0, 10),
      'targets': tf.range(0, 5)
  })

  # Answer matches the last n elements of the inputs.
  dataset = tf.data.Dataset.from_tensors({
      'inputs': tf.range(0, 30),
      'targets': tf.range(27, 30)
  })
  dataset = prep.trivia_qa_truncate_inputs(
      dataset, output_features=None, sequence_length=sequence_length)
  assert_dataset(dataset, {
      'inputs': tf.range(20, 30),
      'targets': tf.range(27, 30)
  })

  # Answer is not in the inputs. The example is dropped from the dataset.
  no_overlap_dataset = tf.data.Dataset.from_tensors({
      'inputs': tf.range(0, 30),
      'targets': tf.range(27, 32)
  })
  dataset = prep.trivia_qa_truncate_inputs(
      no_overlap_dataset, output_features=None,
      sequence_length=sequence_length)
  self.assertEmpty(list(test_utils.dataset_as_text(dataset)))

  # Answer is in the middle of the inputs.
  for _ in range(0, 10):
    og_dataset = tf.data.Dataset.from_tensors({
        'inputs': tf.range(0, 30),
        'targets': tf.range(10, 15),
    })
    dataset = prep.trivia_qa_truncate_inputs(
        og_dataset, output_features=None, sequence_length=sequence_length)
    for data in test_utils.dataset_as_text(dataset):
      self.assertContainsSubset(data['targets'], data['inputs'])
      self.assertLen(data['inputs'], 10)
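
# A hypothetical sketch of the windowing behaviour exercised above (an
# illustrative approximation, not the actual `trivia_qa_truncate_inputs`
# implementation): choose a window of `max_input_length` tokens that still
# contains the answer span, and drop the example when the answer does not
# occur in the inputs. Assumes len(inputs) >= max_input_length >=
# len(targets), as in the cases above.
def _truncate_inputs_sketch(inputs, targets, max_input_length):
  import random  # local import to keep the sketch self-contained
  n = len(targets)
  # All positions where the answer occurs in the inputs.
  starts = [i for i in range(len(inputs) - n + 1)
            if list(inputs[i:i + n]) == list(targets)]
  if not starts:
    return None  # answer not found; the example is dropped
  answer_start = random.choice(starts)
  answer_end = answer_start + n  # exclusive
  # Earliest window start that still covers the end of the answer, and
  # latest window start that still covers its beginning.
  lo = max(0, answer_end - max_input_length)
  hi = min(answer_start, len(inputs) - max_input_length)
  window_start = random.randint(lo, hi)
  return {
      'inputs': inputs[window_start:window_start + max_input_length],
      'targets': targets,
  }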