Beispiel #1
0
 def test_jit_no_extension(self):
     """Check that a TorchScript-compiled blocker matches the eager one.

     The blocker is built with ``use_extension=False`` and called once
     eagerly and once after ``torch.jit.script``; both outputs must be
     equal. All tensors are allocated on ``"cuda"``, so a CUDA device is
     required to run this test.
     """
     device = "cuda"
     bsz = 2
     vocab_size = 4
     beam_size = 1
     step = 3

     generated_tok = torch.tensor(
         [[2, 2, 2, 2], [3, 3, 3, 3]],
         dtype=torch.long,
         device=device,
     )
     lprobs = torch.zeros((beam_size * bsz, vocab_size), device=device)

     eager_blocker = NGramRepeatBlock(2, use_extension=False)
     # Each call receives its own copy of lprobs so the two runs start
     # from identical inputs.
     eager_out = eager_blocker(
         generated_tok, lprobs.clone(), bsz, beam_size, step)

     scripted_blocker = torch.jit.script(eager_blocker)
     scripted_out = scripted_blocker(
         generated_tok, lprobs.clone(), bsz, beam_size, step)

     self.assertTensorEqual(eager_out, scripted_out)
Beispiel #2
0
 def _compare_cuda_ext_to_default_implem(self, bsz, beam_size,
                                         generated_tok, lprobs, step,
                                         block_param):
     """Run the blocker with and without the extension and compare outputs.

     A single ``NGramRepeatBlock(block_param)`` is called twice on the same
     ``generated_tok`` and on separate clones of ``lprobs``: first with
     ``use_extension`` as constructed (asserted to be truthy), then with
     ``use_extension`` set to ``False``. The two outputs are asserted equal
     via ``self.assertTensorEqual``.

     Returns:
         tuple: ``(extension_output, baseline_output)``.

     Raises:
         AssertionError: if the freshly built blocker does not report
             ``use_extension``.
     """
     blocker = NGramRepeatBlock(block_param)
     assert blocker.use_extension, "Extension not compiled"

     def invoke():
         # A fresh clone per call keeps both runs on identical inputs.
         return blocker(generated_tok, lprobs.clone(), bsz, beam_size, step)

     ext_out = invoke()

     # Flip the flag on the same instance to take the non-extension path.
     blocker.use_extension = False
     ref_out = invoke()

     self.assertTensorEqual(ext_out, ref_out)

     # Put the flag back the way it was before returning.
     blocker.use_extension = True
     return ext_out, ref_out
Beispiel #3
0
    def __init__(
        self,
        models,
        tgt_dict,
        beam_size=1,
        max_len_a=0,
        max_len_b=200,
        min_len=1,
        normalize_scores=True,
        len_penalty=1.0,
        unk_penalty=0.0,
        temperature=1.0,
        match_source_len=False,
        no_repeat_ngram_size=0,
        search_strategy=None,
        eos=None,
        symbols_to_strip_from_output=None,
        lm_model=None,
        lm_weight=1.0,
        **kwargs,
    ):
        """Generates translations of a given source sentence.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models,
                currently support fairseq.models.TransformerModel for
                scripting. An existing ``EnsembleModel`` is used as-is;
                anything else is wrapped in ``EnsembleModel``.
            tgt_dict: target dictionary; must provide ``pad()``, ``unk()``,
                ``eos()`` and ``len()``.
            beam_size (int, optional): beam width (default: 1). Capped at
                ``len(tgt_dict) - 1``.
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            min_len (int, optional): the minimum length of the generated output
                (not including end-of-sentence)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks, >0 produces fewer (default: 0.0)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples (default: 1.0). Must be > 0.
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
            no_repeat_ngram_size (int, optional): stored and passed to
                ``NGramRepeatBlock`` (default: 0)
            search_strategy (optional): search object to use; when ``None``
                a ``search.BeamSearch(tgt_dict)`` is created.
            eos (int, optional): end-of-sentence index; defaults to
                ``tgt_dict.eos()`` when ``None``.
            symbols_to_strip_from_output (set, optional): symbols to strip
                from the output; the eos index is always added to this set.
            lm_model (optional): language model stored on the instance and
                switched to eval mode when given.
            lm_weight (float, optional): weight stored alongside ``lm_model``
                (default: 1.0)
            **kwargs: only ``eos_factor`` is read here; when set it must be
                >= 1.0. NOTE(review): its effect on decoding is defined
                outside this constructor.

        Raises:
            AssertionError: if ``temperature`` is not > 0, or ``eos_factor``
                is set and < 1.0.
        """
        super().__init__()

        # Validate scalar options up front so a bad configuration fails
        # before the model ensemble and the n-gram blocker are built.
        eos_factor = kwargs.get("eos_factor", None)
        assert temperature > 0, "--temperature must be greater than 0"
        assert eos_factor is None or eos_factor >= 1.0, "--eos-factor must be >= 1.0 if set"

        if isinstance(models, EnsembleModel):
            self.model = models
        else:
            self.model = EnsembleModel(models)
        self.tgt_dict = tgt_dict
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos() if eos is None else eos
        # eos is always stripped, whether or not the caller supplied a set.
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None else {self.eos})
        self.vocab_size = len(tgt_dict)
        # the max beam size is the dictionary size - 1, since we never select pad
        self.beam_size = min(beam_size, self.vocab_size - 1)
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.min_len = min_len

        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.temperature = temperature
        self.match_source_len = match_source_len

        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)

        self.eos_factor = eos_factor

        self.search = (search.BeamSearch(tgt_dict)
                       if search_strategy is None else search_strategy)
        # We only need to set src_lengths in LengthConstrainedBeamSearch.
        # As a module attribute, setting it would break in multithread
        # settings when the model is shared.
        self.should_set_src_lengths = (hasattr(self.search,
                                               "needs_src_lengths")
                                       and self.search.needs_src_lengths)

        self.model.eval()

        self.lm_model = lm_model
        self.lm_weight = lm_weight
        if self.lm_model is not None:
            self.lm_model.eval()