def test_beam_lm_update_memory_length_when_finished(self):
    # Positional args: beam_size, batch_size, pad, bos, eos, unk, n_best,
    # global_scorer, min_length, max_length, return_attention,
    # block_ngram_repeat, exclusion_tokens, stepwise_penalty, ratio,
    # ban_unk_token.
    beam = BeamSearchLM(
        self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, 3, self.N_BEST,
        GlobalScorerStub(), 0, 30, False, 0, set(), False, 0., False)
    device_init = torch.zeros(1, 1)
    src_lengths = torch.randint(0, 30, (self.BATCH_SZ,))
    fn_map_state, _, _, _ = beam.initialize(device_init, src_lengths)
    expected_beam_scores = self.init_step(beam, 1)
    # Finish every hypothesis of the first batch entry so it drops out of
    # the alive set; memory_lengths must then track only the surviving
    # batch entries, each advanced by the number of decoding steps taken.
    self.finish_first_beam_step(beam)
    n_steps = beam.alive_seq.shape[-1] - 1
    self.assertTrue(
        beam.memory_lengths.equal(
            n_steps + fn_map_state(src_lengths[1:], dim=0)))
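# Illustrative sketch (an assumption, not part of the original test suite):
# the fn_map_state returned by initialize() is expected to tile batch-level
# tensors beam-wise, in the style of onmt.utils.misc.tile, so surviving
# source lengths line up with the flattened (batch x beam) layout of
# memory_lengths. E.g., with beam_size=2:
#
#     >>> from onmt.utils.misc import tile
#     >>> tile(torch.tensor([5, 7]), 2, dim=0)
#     tensor([5, 5, 7, 7])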
def translate_batch(self, batch, src_vocabs, attn_debug):
    """Translate a batch of sentences."""
    with torch.no_grad():
        if self.sample_from_topk != 0 or self.sample_from_topp != 0:
            # Top-k / top-p sampling requested: decode with random sampling.
            decode_strategy = GreedySearchLM(
                pad=self._tgt_pad_idx,
                bos=self._tgt_bos_idx,
                eos=self._tgt_eos_idx,
                unk=self._tgt_unk_idx,
                batch_size=batch.batch_size,
                global_scorer=self.global_scorer,
                min_length=self.min_length,
                max_length=self.max_length,
                block_ngram_repeat=self.block_ngram_repeat,
                exclusion_tokens=self._exclusion_idxs,
                return_attention=attn_debug or self.replace_unk,
                sampling_temp=self.random_sampling_temp,
                keep_topk=self.sample_from_topk,
                keep_topp=self.sample_from_topp,
                beam_size=self.beam_size,
                ban_unk_token=self.ban_unk_token,
            )
        else:
            # Default: beam search decoding.
            # TODO: support these blacklisted features
            assert not self.dump_beam
            decode_strategy = BeamSearchLM(
                self.beam_size,
                batch_size=batch.batch_size,
                pad=self._tgt_pad_idx,
                bos=self._tgt_bos_idx,
                eos=self._tgt_eos_idx,
                unk=self._tgt_unk_idx,
                n_best=self.n_best,
                global_scorer=self.global_scorer,
                min_length=self.min_length,
                max_length=self.max_length,
                return_attention=attn_debug or self.replace_unk,
                block_ngram_repeat=self.block_ngram_repeat,
                exclusion_tokens=self._exclusion_idxs,
                stepwise_penalty=self.stepwise_penalty,
                ratio=self.ratio,
                ban_unk_token=self.ban_unk_token,
            )
        return self._translate_batch_with_strategy(
            batch, src_vocabs, decode_strategy)
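# Illustrative usage sketch (the `translator` and `batch` names are
# placeholders, not defined in this module): a translator instance would
# typically call this once per batch inside its translation loop, picking
# sampling vs. beam search purely from its own configured options.
#
#     results = translator.translate_batch(batch, src_vocabs, attn_debug=False)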