def test_jit_no_extension(self):
    bsz, vocab_size, beam_size, step = 2, 4, 1, 3
    generated_tok = torch.tensor(
        [[2, 2, 2, 2], [3, 3, 3, 3]], dtype=torch.long, device="cuda"
    )
    lprobs = torch.zeros((beam_size * bsz, vocab_size), device="cuda")
    blocker = NGramRepeatBlock(2, use_extension=False)
    base_result = blocker(generated_tok, lprobs.clone(), bsz, beam_size, step)
    scripted_blocker = torch.jit.script(blocker)
    jit_result = scripted_blocker(
        generated_tok, lprobs.clone(), bsz, beam_size, step
    )
    self.assertTensorEqual(base_result, jit_result)
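
# What the test above exercises, conceptually: n-gram repeat blocking bans any
# token that would recreate an n-gram already present in a hypothesis prefix.
# The sketch below is NOT fairseq's implementation (the real logic lives in
# NGramRepeatBlock); it is a minimal illustrative version of the technique,
# and `ban_repeated_ngrams_sketch` plus its argument layout are hypothetical.
import math


def ban_repeated_ngrams_sketch(tokens, lprobs, no_repeat_ngram_size, step):
    """tokens: (num_hyps, step + 1) generated prefixes; lprobs: (num_hyps, vocab)."""
    n = no_repeat_ngram_size
    for hyp in range(tokens.size(0)):
        prefix = tokens[hyp, : step + 1].tolist()
        if len(prefix) < n:
            continue  # no complete n-gram can repeat yet
        # The (n - 1)-gram that the next generated token would extend.
        tail = tuple(prefix[len(prefix) - n + 1 :])
        for i in range(len(prefix) - n + 1):
            # If an earlier (n - 1)-gram matches the tail, ban the token
            # that followed it by setting its log-probability to -inf.
            if tuple(prefix[i : i + n - 1]) == tail:
                lprobs[hyp, prefix[i + n - 1]] = -math.inf
    return lprobs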
def _compare_cuda_ext_to_default_implem(
    self, bsz, beam_size, generated_tok, lprobs, step, block_param
):
    """Assert that the CUDA extension and the default implementation return the same result."""
    blocker = NGramRepeatBlock(block_param)
    assert blocker.use_extension, "Extension not compiled"
    cuda_ext_result = blocker(generated_tok, lprobs.clone(), bsz, beam_size, step)
    blocker.use_extension = False
    baseline_result = blocker(generated_tok, lprobs.clone(), bsz, beam_size, step)
    self.assertTensorEqual(cuda_ext_result, baseline_result)
    blocker.use_extension = True
    return cuda_ext_result, baseline_result
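
# Hedged usage sketch for the helper above: small CUDA inputs with shapes
# matching test_jit_no_extension. The test name is hypothetical;
# block_param=2 means "block repeated bigrams".
def test_cuda_ext_matches_default_sketch(self):
    bsz, vocab_size, beam_size, step = 2, 4, 1, 3
    generated_tok = torch.tensor(
        [[2, 2, 2, 2], [3, 3, 3, 3]], dtype=torch.long, device="cuda"
    )
    lprobs = torch.zeros((beam_size * bsz, vocab_size), device="cuda")
    self._compare_cuda_ext_to_default_implem(
        bsz, beam_size, generated_tok, lprobs, step, block_param=2
    )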
def __init__(
    self,
    models,
    tgt_dict,
    beam_size=1,
    max_len_a=0,
    max_len_b=200,
    min_len=1,
    normalize_scores=True,
    len_penalty=1.0,
    unk_penalty=0.0,
    temperature=1.0,
    match_source_len=False,
    no_repeat_ngram_size=0,
    search_strategy=None,
    eos=None,
    symbols_to_strip_from_output=None,
    lm_model=None,
    lm_weight=1.0,
    **kwargs,
):
    """Generates translations of a given source sentence.

    Args:
        models (List[~fairseq.models.FairseqModel]): ensemble of models,
            currently supports fairseq.models.TransformerModel for scripting
        beam_size (int, optional): beam width (default: 1)
        max_len_a/b (int, optional): generate sequences of maximum length
            ax + b, where x is the source length
        min_len (int, optional): the minimum length of the generated output
            (not including end-of-sentence)
        normalize_scores (bool, optional): normalize scores by the length
            of the output (default: True)
        len_penalty (float, optional): length penalty, where <1.0 favors
            shorter, >1.0 favors longer sentences (default: 1.0)
        unk_penalty (float, optional): unknown word penalty, where <0
            produces more unks, >0 produces fewer (default: 0.0)
        temperature (float, optional): temperature, where values >1.0
            produce more uniform samples and values <1.0 produce sharper
            samples (default: 1.0)
        match_source_len (bool, optional): outputs should match the source
            length (default: False)
    """
    super().__init__()
    if isinstance(models, EnsembleModel):
        self.model = models
    else:
        self.model = EnsembleModel(models)
    self.tgt_dict = tgt_dict
    self.pad = tgt_dict.pad()
    self.unk = tgt_dict.unk()
    self.eos = tgt_dict.eos() if eos is None else eos
    self.symbols_to_strip_from_output = (
        symbols_to_strip_from_output.union({self.eos})
        if symbols_to_strip_from_output is not None
        else {self.eos}
    )
    self.vocab_size = len(tgt_dict)
    # the max beam size is the dictionary size - 1, since we never select pad
    self.beam_size = min(beam_size, self.vocab_size - 1)
    self.max_len_a = max_len_a
    self.max_len_b = max_len_b
    self.min_len = min_len
    self.normalize_scores = normalize_scores
    self.len_penalty = len_penalty
    self.unk_penalty = unk_penalty
    self.temperature = temperature
    self.match_source_len = match_source_len
    self.no_repeat_ngram_size = no_repeat_ngram_size
    self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
    self.eos_factor = kwargs.get("eos_factor", None)

    assert temperature > 0, "--temperature must be greater than 0"
    assert (
        self.eos_factor is None or self.eos_factor >= 1.0
    ), "--eos-factor must be >= 1.0 if set"

    self.search = (
        search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
    )
    # We only need to set src_lengths in LengthConstrainedBeamSearch.
    # As a module attribute, setting it would break in multithread
    # settings when the model is shared.
    self.should_set_src_lengths = (
        hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
    )

    self.model.eval()

    self.lm_model = lm_model
    self.lm_weight = lm_weight
    if self.lm_model is not None:
        self.lm_model.eval()
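
# Hedged construction sketch: how this generator is typically wired up. The
# checkpoint path is a placeholder; load_model_ensemble_and_task is fairseq's
# real checkpoint loader, and the chosen parameter values are illustrative.
#
#     from fairseq import checkpoint_utils
#
#     models, _cfg, task = checkpoint_utils.load_model_ensemble_and_task(
#         ["/path/to/checkpoint.pt"]
#     )
#     generator = SequenceGenerator(
#         models,
#         task.target_dictionary,
#         beam_size=5,
#         no_repeat_ngram_size=3,  # enables NGramRepeatBlock during search
#     )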