def test_repetition_penalty_dist_process(self, use_xla):
    """Check `TFRepetitionPenaltyLogitsProcessor` on a hand-built 2 x 10 score matrix.

    Tokens present in a row of `input_ids` are expected to be penalised (negative scores multiplied by the
    penalty, positive scores divided by it), while tokens absent from that row are expected to keep their
    original score.

    Args:
        use_xla: if truthy, the processor is wrapped in `tf.function(..., jit_compile=True)` before being
            called.
    """
    vocab_size = 10
    cur_len = 2

    input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)
    self.assertEqual(cur_len, input_ids.shape[1])

    # the assertions below expect every entry to start at 1 / vocab_size
    scores = self._get_uniform_logits(batch_size=2, length=vocab_size)

    # give token 0 of row 0 a negative score so the negative branch of the penalty is exercised
    mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool)
    scores = tf.where(mask, -1 / vocab_size, scores)
    # give token 5 of row 1 a larger positive score
    mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
    scores = tf.where(mask, 4 / vocab_size, scores)

    rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
    if use_xla:
        rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)

    scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)

    # check that values were correctly changed: negative scores of used tokens are multiplied by the
    # penalty, positive scores of used tokens are divided by it
    self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
    self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2)
    self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size))  # unused tokens should see no change

    self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
    self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
    # row 1 only contains tokens 5 and 0, so token 2 is unused there; check row 1 rather than repeating
    # the row-0 assertion
    self.assertAlmostEqual(scores[1, 2].numpy(), (1 / vocab_size))  # unused tokens should see no change
def test_processor_list(self):
    """Applying each processor by hand must give the same scores as one `TFLogitsProcessorList` call.

    Also verifies that none of the processors modifies `input_ids`.
    """
    batch_size, cur_len, vocab_size = 4, 10, 15
    eos_token_id = 0

    def _replace_inf(tensor):
        # swap infinite entries for a large finite value so the two results can be compared numerically
        return tf.where(tf.math.is_inf(tensor), -1e9, tensor)

    # dummy input_ids and scores, each with a copy reserved for the processor-list run
    input_ids = ids_tensor((batch_size, cur_len), vocab_size)
    input_ids_comp = tf.identity(input_ids)

    scores = self._get_uniform_logits(batch_size, vocab_size)
    scores_comp = tf.identity(scores)

    # every processor/warper, paired with the extra positional arguments used when it is called on its own
    # (the warpers are called without `cur_len`)
    pipeline = [
        (TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id), (cur_len,)),
        (TFTemperatureLogitsWarper(temperature=0.5), ()),
        (TFRepetitionPenaltyLogitsProcessor(penalty=2.0), (cur_len,)),
        (TFTopKLogitsWarper(3), ()),
        (TFTopPLogitsWarper(0.8), ()),
        (TFNoRepeatNGramLogitsProcessor(2), (cur_len,)),
        (TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id), (cur_len,)),
    ]

    # no processor list: chain the calls manually
    for step, extra_args in pipeline:
        scores = step(input_ids, scores, *extra_args)

    # with processor list: same objects, same order, a single call
    processor = TFLogitsProcessorList([step for step, _ in pipeline])
    scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)

    # remove inf
    scores = _replace_inf(scores)
    scores_comp = _replace_inf(scores_comp)

    # scores should be equal
    tf.debugging.assert_near(scores, scores_comp, atol=1e-3)

    # input_ids should never be changed
    self.assertListEqual(input_ids.numpy().tolist(), input_ids_comp.numpy().tolist())
def test_processor_list(self):
    """Chaining processors manually must give the same scores as one `TFLogitsProcessorList` call.

    Also verifies that none of the processors modifies `input_ids`.
    """
    batch_size, sequence_length, vocab_size = 4, 10, 15
    eos_token_id = 0

    # dummy input_ids and scores, each with a copy reserved for the processor-list run
    input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
    input_ids_comp = tf.identity(input_ids)

    scores = self._get_uniform_logits(batch_size, vocab_size)
    scores_comp = tf.identity(scores)

    # instantiate all dist processors, in the order they are applied
    steps = [
        TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id),
        TFRepetitionPenaltyLogitsProcessor(penalty=2.0),
        TFNoRepeatNGramLogitsProcessor(2),
        TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id),
    ]

    # no processor list: feed the output of each step into the next one
    for step in steps:
        scores = step(input_ids, scores)

    # with processor list
    processor = TFLogitsProcessorList(steps)
    scores_comp = processor(input_ids, scores_comp)

    # remove inf
    manual_inf_mask = tf.math.is_inf(scores)
    scores = set_tensor_by_indices_to_value(scores, manual_inf_mask, -1e9)
    list_inf_mask = tf.math.is_inf(scores_comp)
    scores_comp = set_tensor_by_indices_to_value(scores_comp, list_inf_mask, -1e9)

    # scores should be equal
    tf.debugging.assert_near(scores, scores_comp, atol=1e-3)

    # input_ids should never be changed
    self.assertListEqual(input_ids.numpy().tolist(), input_ids_comp.numpy().tolist())
def test_repetition_penalty_dist_process(self):
    """Check `TFRepetitionPenaltyLogitsProcessor(penalty=2.0)` on a hand-built 2 x 10 score matrix.

    For tokens present in `input_ids`, the assertions expect a negative score to be multiplied by the
    penalty and a positive score to be divided by it.
    """
    vocab_size = 10
    input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)

    # the assertions below expect every entry to start at 1 / vocab_size
    scores = self._get_uniform_logits(batch_size=2, length=vocab_size)

    # token 0 of row 0 gets a negative score
    negative_entry = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool)
    scores = set_tensor_by_indices_to_value(scores, negative_entry, -1 / vocab_size)

    # token 5 of row 1 gets a larger positive score
    boosted_entry = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
    scores = set_tensor_by_indices_to_value(scores, boosted_entry, 4 / vocab_size)

    rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
    scores = rep_penalty_proc(input_ids, tf.identity(scores))

    # check that values were correctly changed
    expected_by_position = (
        ((0, 0), -(1 / vocab_size) * 2),
        ((0, 1), (1 / vocab_size) / 2),
        ((1, 0), (1 / vocab_size) / 2),
        ((1, 5), (4 / vocab_size) / 2),
    )
    for (row, token), expected in expected_by_position:
        self.assertAlmostEqual(scores[row, token].numpy(), expected)