Example #1
    def test_no_bad_words_dist_processor(self):
        vocab_size = 5
        batch_size = 2
        eos_token_id = 4

        input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]],
                                 device=torch_device,
                                 dtype=torch.long)
        bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
        scores = self._get_uniform_logits(batch_size, vocab_size)

        no_bad_words_dist_proc = NoBadWordsLogitsProcessor(
            bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)

        filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())

        # batch 1: the 1st, 2nd, and 4th tokens (0, 1, 3) are forbidden
        # batch 2: the 1st, 2nd, and 3rd tokens (0, 1, 2) are forbidden
        # the 5th token (4) can never be forbidden, since it is the EOS token
        self.assertListEqual(
            torch.isinf(filtered_scores).tolist(),
            [[True, True, False, True, False],
             [True, True, True, False, False]])

        # edge case: banning only the EOS token must leave the scores unchanged
        no_bad_words_dist_proc = NoBadWordsLogitsProcessor(
            bad_words_ids=[[4]], eos_token_id=eos_token_id)
        filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
        self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3))
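
A minimal standalone sketch of the same processor outside the test harness (the tensor values below are illustrative, not taken from the test suite): banning only token 1 masks that column to -inf and leaves everything else, including the EOS column, untouched.

import torch
from transformers import NoBadWordsLogitsProcessor

# Toy setup: one sequence over a 5-token vocabulary, uniform logits.
input_ids = torch.tensor([[0, 1, 3, 1]], dtype=torch.long)
scores = torch.zeros(1, 5)

# Forbid token 1 as the next token; token 4 (EOS) stays allowed.
proc = NoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=4)
filtered = proc(input_ids, scores)
print(torch.isinf(filtered))  # tensor([[False,  True, False, False, False]])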
Example #2
    def test_processor_list(self):
        batch_size = 4
        sequence_length = 10
        vocab_size = 15
        eos_token_id = 0

        # dummy input_ids and scores
        input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
        input_ids_comp = input_ids.clone()

        scores = self._get_uniform_logits(batch_size, vocab_size)
        scores_comp = scores.clone()

        # instantiate all dist processors
        min_dist_proc = MinLengthLogitsProcessor(min_length=10,
                                                 eos_token_id=eos_token_id)
        temp_dist_warp = TemperatureLogitsWarper(temperature=0.5)
        rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
        top_k_warp = TopKLogitsWarper(3)
        top_p_warp = TopPLogitsWarper(0.8)
        no_repeat_proc = NoRepeatNGramLogitsProcessor(2)
        no_bad_words_dist_proc = NoBadWordsLogitsProcessor(
            bad_words_ids=[[1]], eos_token_id=eos_token_id)

        # apply each processor individually, without a processor list
        scores = min_dist_proc(input_ids, scores)
        scores = temp_dist_warp(input_ids, scores)
        scores = rep_penalty_proc(input_ids, scores)
        scores = top_k_warp(input_ids, scores)
        scores = top_p_warp(input_ids, scores)
        scores = no_repeat_proc(input_ids, scores)
        scores = no_bad_words_dist_proc(input_ids, scores)

        # with processor list
        processor = LogitsProcessorList([
            min_dist_proc,
            temp_dist_warp,
            rep_penalty_proc,
            top_k_warp,
            top_p_warp,
            no_repeat_proc,
            no_bad_words_dist_proc,
        ])
        scores_comp = processor(input_ids, scores_comp)

        # scores should be equal
        self.assertTrue(torch.allclose(scores, scores_comp, atol=1e-3))

        # input_ids should never be changed
        self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
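
As a usage sketch (assuming a transformers version contemporary with these tests; "gpt2" is used purely as a stand-in checkpoint), the same LogitsProcessorList can be passed to generate(), which applies each processor in order at every decoding step.

from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          LogitsProcessorList, MinLengthLogitsProcessor)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Logits processors", return_tensors="pt")
processors = LogitsProcessorList([
    MinLengthLogitsProcessor(min_length=20,
                             eos_token_id=model.config.eos_token_id),
])
# at each step, generate() effectively runs: scores = processors(input_ids, scores)
output = model.generate(**inputs, logits_processor=processors, max_length=30)
print(tokenizer.decode(output[0]))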
Example #3
def _get_logits_processor_and_kwargs(input_length, eos_token_id):
    # kwargs that generate() understands, mirrored by the processor list below
    process_kwargs = {
        "min_length": input_length + 1,
        "bad_words_ids": [[1, 0]],
        "no_repeat_ngram_size": 2,
        "repetition_penalty": 1.2,
    }
    # MinLengthLogitsProcessor requires an EOS token, so it is only added
    # when eos_token_id is available
    logits_processor = LogitsProcessorList(
        ([MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id)]
         if eos_token_id is not None else [])
        + [
            NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"],
                                      eos_token_id),
            NoRepeatNGramLogitsProcessor(
                process_kwargs["no_repeat_ngram_size"]),
            RepetitionPenaltyLogitsProcessor(
                process_kwargs["repetition_penalty"]),
        ]
    )
    return process_kwargs, logits_processor
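
A hypothetical call (argument values chosen only for illustration) shows what the helper returns:

process_kwargs, logits_processor = _get_logits_processor_and_kwargs(
    input_length=5, eos_token_id=0)

print(process_kwargs["min_length"])  # 6
print(len(logits_processor))  # 4: the min-length processor plus the three others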