def test_prefix_lm(self):
    """Check that text split by `prep.prefix_lm` re-joins to the original.

    The same sentence is fed through `prep.prefix_lm` many times. For every
    resulting example, the 'inputs' text (with the literal 'prefix: ' marker
    removed) followed by the 'targets' text must reproduce the sentence; a
    single space is re-inserted between the two parts only when the inputs
    part is non-empty.
    """
    # NOTE(review): another `test_prefix_lm` is visible next to this one. If
    # both live in the same class, the later definition shadows this method
    # and it never runs -- confirm and rename one of them if so.
    num_tries = 100
    original = 'This is a long test with lots of words to see if it works ok.'

    repeated_examples = {'text': [original] * num_tries}
    dataset = prep.prefix_lm(
        tf.data.Dataset.from_tensor_slices(repeated_examples))

    for example in test_utils.dataset_as_text(dataset):
        prefix = example['inputs'].replace('prefix: ', '')
        continuation = ''.join(example['targets'])

        # An empty prefix means the whole sentence is in the targets, so no
        # joining space was dropped at the split point.
        separator = ' ' if prefix else ''
        self.assertEqual(prefix + separator + continuation, original)
def test_prefix_lm(self):
    """Check that `prep.prefix_lm` splits token ids losslessly and variably.

    A single sequence holding the integers 1..100 is repeated 100 times and
    passed through `prep.prefix_lm`. For every output example, 'inputs'
    followed by 'targets' must equal the original sequence, and across all
    examples more than one distinct 'inputs' length must be observed.
    """
    vocab = test_utils.sentencepiece_vocab()
    token_ids = list(range(1, 101))

    source_dataset = tf.data.Dataset.from_tensor_slices(
        {'targets': [token_ids]}).repeat(100)
    feature_spec = {'targets': Feature(vocab)}
    # presumably per-feature sequence lengths -- confirm against the
    # signature of prep.prefix_lm
    length_spec = {'inputs': 100, 'targets': 100}

    split_dataset = prep.prefix_lm(source_dataset, length_spec, feature_spec)

    seen_input_lengths = set()
    for example in split_dataset.as_numpy_iterator():
        input_part = example['inputs']
        target_part = example['targets']
        self.assertListEqual(
            input_part.tolist() + target_part.tolist(), token_ids)
        seen_input_lengths.add(len(input_part))

    # If every example had the same inputs length, the split point would not
    # be varying between examples.
    self.assertGreater(len(seen_input_lengths), 1)