def test_prefix_lm_last_input_batch_exists_but_no_output(self):
    """A trailing chunk with input tokens but no output tokens is not emitted.

    With input_length=2 and output_length=3, a seven-token sequence has
    enough tokens for one full (input, output) pair plus two leftovers.
    The test expects exactly one example to be produced.
    """
    token_stream = [[1, 2, 3, 4, 5, 6, 7]]
    expected_pair = ([1, 2], [3, 4, 5])

    make_examples = data.PrefixLM(input_length=2, output_length=3)
    produced = list(make_examples(token_stream))

    # The first (and only) example is the full 2-input / 3-output split.
    self.assertSequenceEqual(expected_pair, produced[0])
    # The leftover tokens 6 and 7 must not produce a second example.
    self.assertLen(produced, 1)
def test_prefix_lm_last_output_batch_is_short(self):
    """A final example may carry fewer output tokens than output_length.

    With input_length=2 and output_length=3, an eight-token sequence gives
    one full (input, output) pair followed by three leftover tokens. The
    test expects a second example whose output holds only the single
    remaining token.
    """
    token_stream = [[1, 2, 3, 4, 5, 6, 7, 8]]
    expected_full_pair = ([1, 2], [3, 4, 5])
    expected_short_pair = ([6, 7], [8])

    make_examples = data.PrefixLM(input_length=2, output_length=3)
    produced = list(make_examples(token_stream))

    # First example: complete 2-input / 3-output split.
    self.assertSequenceEqual(expected_full_pair, produced[0])
    # Second example: full input, truncated output.
    self.assertSequenceEqual(expected_short_pair, produced[1])
    self.assertLen(produced, 2)