def test_top_p_dist_warper(self):
    input_ids = None
    vocab_size = 10
    batch_size = 2

    # create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
    dist = torch.log(
        torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float)
    )

    top_p_warp = TopPLogitsWarper(0.7)
    filtered_dist = torch.exp(top_p_warp(input_ids, dist))

    # dist should be filtered to keep min num values so that sum is >= 0.7
    # exp (-inf) => 0
    EXPECTED_FILTERED_DIST = torch.tensor(
        [[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float
    )
    self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))

    # check edge cases with negative and extreme logits
    ramp_logits = (
        torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
        - (vocab_size // 2)
    )

    # make ramp_logits more extreme
    ramp_logits[1] = ramp_logits[1] * 100.0

    # make sure at least 2 tokens are kept
    top_p_warp = TopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
    filtered_dist = top_p_warp(input_ids, ramp_logits)

    # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
    self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
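
# For intuition only: a minimal sketch of the top-p (nucleus) filtering the test above
# exercises. This is NOT the library's TopPLogitsWarper implementation; the helper name
# and defaults below are assumptions made for illustration. Tokens are sorted by
# probability, the smallest prefix whose mass reaches `top_p` is kept, and all other
# scores are set to `filter_value`, always keeping at least `min_tokens_to_keep` tokens.
def _reference_top_p_filter(scores, top_p, filter_value=-float("inf"), min_tokens_to_keep=1):
    sorted_logits, sorted_indices = torch.sort(scores, descending=True)
    sorted_probs = sorted_logits.softmax(dim=-1)
    cumulative_probs = sorted_probs.cumsum(dim=-1)

    # drop a token if the tokens ranked above it already cover more than `top_p` mass,
    # but never drop any of the first `min_tokens_to_keep` tokens
    sorted_indices_to_remove = cumulative_probs - sorted_probs > top_p
    sorted_indices_to_remove[..., :min_tokens_to_keep] = False

    # map the removal mask back from sorted order to vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)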
def _get_warper_and_kwargs(num_beams):
    warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
    logits_warper = LogitsProcessorList(
        [
            TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
            TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
            TemperatureLogitsWarper(warp_kwargs["temperature"]),
        ]
    )
    return warp_kwargs, logits_warper
def test_processor_list(self):
    batch_size = 4
    sequence_length = 10
    vocab_size = 15
    eos_token_id = 0

    # dummy input_ids and scores
    input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
    input_ids_comp = input_ids.clone()

    scores = self._get_uniform_logits(batch_size, vocab_size)
    scores_comp = scores.clone()

    # instantiate all dist processors
    min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
    temp_dist_warp = TemperatureLogitsWarper(temperature=0.5)
    rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
    top_k_warp = TopKLogitsWarper(3)
    top_p_warp = TopPLogitsWarper(0.8)
    no_repeat_proc = NoRepeatNGramLogitsProcessor(2)
    no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)

    # no processor list
    scores = min_dist_proc(input_ids, scores)
    scores = temp_dist_warp(input_ids, scores)
    scores = rep_penalty_proc(input_ids, scores)
    scores = top_k_warp(input_ids, scores)
    scores = top_p_warp(input_ids, scores)
    scores = no_repeat_proc(input_ids, scores)
    scores = no_bad_words_dist_proc(input_ids, scores)

    # with processor list
    processor = LogitsProcessorList(
        [
            min_dist_proc,
            temp_dist_warp,
            rep_penalty_proc,
            top_k_warp,
            top_p_warp,
            no_repeat_proc,
            no_bad_words_dist_proc,
        ]
    )
    scores_comp = processor(input_ids, scores_comp)

    # scores should be equal
    self.assertTrue(torch.allclose(scores, scores_comp, atol=1e-3))

    # input_ids should never be changed
    self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
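
# For intuition only: the "with processor list" branch of the test above relies on
# LogitsProcessorList behaving like a sequential fold of the scores through each
# processor, which is exactly what the "no processor list" branch does by hand.
# A minimal sketch of that behavior (not the library code; the helper name is made up):
def _apply_processors_sequentially(processors, input_ids, scores):
    # each processor takes (input_ids, scores) and returns the updated scores
    for processor in processors:
        scores = processor(input_ids, scores)
    return scores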
def _get_logits_warper_list(num_beams, temperature, top_k, top_p):
    warpers = LogitsProcessorList()
    if temperature is not None and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
    if top_p is not None and top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
    return warpers
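
# Illustrative usage of the helper above (argument values are arbitrary assumptions,
# and `next_token_logits` is a hypothetical tensor of shape (batch_size, vocab_size)).
# With num_beams > 1 each warper keeps at least two tokens so beam search always has
# more than one candidate continuation to score:
#
#     warpers = _get_logits_warper_list(num_beams=2, temperature=0.7, top_k=50, top_p=0.9)
#     next_token_scores = warpers(input_ids, next_token_logits)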