def test_no_repeat_ngram_dist_processor(self):
    """Check which next-token scores become infinite after n-gram blocking.

    Two short sequences are run through ``NoRepeatNGramLogitsProcessor`` with
    n-gram sizes 2 and 3, and the positions of infinite entries in the
    returned scores are compared against hand-computed masks.
    """
    num_tokens = 3
    num_sequences = 2

    history = torch.tensor([[1, 1, 2, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
    uniform_scores = self._get_uniform_logits(num_sequences, num_tokens)

    # Instantiate both processors before running either of them.
    processors = [NoRepeatNGramLogitsProcessor(ngram_size) for ngram_size in (2, 3)]

    # Every processor gets its own copy of the scores.
    filtered_scores = [processor(history, uniform_scores.clone()) for processor in processors]

    expected_masks = [
        # 2-gram: 2nd and 3rd token (1, 2) are forbidden for the 1st sequence,
        # and the 1st token (0) for the 2nd sequence.
        [[False, True, True], [True, False, False]],
        # 3-gram: nothing is forbidden for the 1st sequence,
        # and the 1st token (0) for the 2nd sequence.
        [[False, False, False], [True, False, False]],
    ]

    for filtered, expected in zip(filtered_scores, expected_masks):
        self.assertListEqual(torch.isinf(filtered).tolist(), expected)
def test_processor_list(self):
    """Compare step-by-step processing with a ``LogitsProcessorList``.

    The same processors and warpers are applied once manually, in order, and
    once through a ``LogitsProcessorList``. The resulting scores must agree
    (within ``atol=1e-3``) and ``input_ids`` must be left untouched.
    """
    batch_size = 4
    sequence_length = 10
    vocab_size = 15
    eos_token_id = 0

    # Dummy inputs, plus reference copies taken before any processing.
    input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
    original_input_ids = input_ids.clone()

    manual_scores = self._get_uniform_logits(batch_size, vocab_size)
    list_scores = manual_scores.clone()

    # All processors / warpers, in the order they are applied.
    pipeline = [
        MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id),
        TemperatureLogitsWarper(temperature=0.5),
        RepetitionPenaltyLogitsProcessor(penalty=2.0),
        TopKLogitsWarper(3),
        TopPLogitsWarper(0.8),
        NoRepeatNGramLogitsProcessor(2),
        NoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id),
    ]

    # Without a processor list: feed each step's output into the next one.
    for step in pipeline:
        manual_scores = step(input_ids, manual_scores)

    # With a processor list wrapping the very same instances.
    processor = LogitsProcessorList(pipeline)
    list_scores = processor(input_ids, list_scores)

    # Both paths must yield the same scores.
    self.assertTrue(torch.allclose(manual_scores, list_scores, atol=1e-3))

    # input_ids should never be changed.
    self.assertListEqual(input_ids.tolist(), original_input_ids.tolist())
def _get_logits_processor_and_kwargs(input_length, eos_token_id):
    """Build a kwargs dict and a ``LogitsProcessorList`` from the same settings.

    Args:
        input_length: Length used to derive ``min_length`` (``input_length + 1``).
        eos_token_id: Passed to the processors that take one. When ``None``,
            no ``MinLengthLogitsProcessor`` is added to the list
            (``"min_length"`` is still present in the returned dict).

    Returns:
        A ``(process_kwargs, logits_processor)`` tuple.
    """
    process_kwargs = {
        "min_length": input_length + 1,
        "bad_words_ids": [[1, 0]],
        "no_repeat_ngram_size": 2,
        "repetition_penalty": 1.2,
    }

    processors = []

    # The minimum-length processor is only included when an EOS id is given.
    if eos_token_id is not None:
        processors.append(MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id))

    processors.append(NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id))
    processors.append(NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]))
    processors.append(RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]))

    logits_processor = LogitsProcessorList(processors)
    return process_kwargs, logits_processor