def test_beam_lm_update_memory_length_when_finished(self):
    # Positional args: beam_size, batch_size, pad, bos, eos, n_best,
    # global_scorer, min_length, max_length, return_attention,
    # block_ngram_repeat, exclusion_tokens, stepwise_penalty, ratio.
    beam = BeamSearchLM(
        self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
        GlobalScorerStub(),
        0, 30, False, 0, set(),
        False, 0.)
    device_init = torch.zeros(1, 1)
    src_lengths = torch.randint(0, 30, (self.BATCH_SZ,))
    fn_map_state, _, _, _ = beam.initialize(device_init, src_lengths)
    expected_beam_scores = self.init_step(beam, 1)
    self.finish_first_beam_step(beam)

    # Once the first batch entry finishes, the surviving entries' memory
    # lengths should have grown by the number of decoding steps taken.
    n_steps = beam.alive_seq.shape[-1] - 1
    self.assertTrue(
        beam.memory_lengths.equal(
            n_steps + fn_map_state(src_lengths[1:], dim=0)))
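The helpers this test leans on (GlobalScorerStub, init_step, finish_first_beam_step, and the BEAM_SZ/BATCH_SZ/N_BEST constants) live elsewhere in the test module. A minimal sketch of what a no-op scorer stub could look like, assuming neutral length and coverage penalties so beam scores reduce to raw log-probabilities; the attribute names here are assumptions, not the harness's actual code:

import torch


class GlobalScorerStub(object):
    # Hypothetical no-op scorer (a sketch, not the real test stub).
    alpha = 0
    beta = 0

    def __init__(self):
        # Neutral penalties: length penalty of 1.0, zero coverage penalty.
        self.length_penalty = lambda lengths, alpha: 1.0
        self.cov_penalty = lambda cov, beta: torch.zeros(
            (1, cov.shape[-2]), device=cov.device, dtype=torch.float)
        self.has_cov_pen = False
        self.has_len_pen = False

    def update_global_state(self, beam_state):
        pass  # nothing to track without a coverage penalty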
Example #2
def translate_batch(self, batch, src_vocabs, attn_debug):
    """Translate a batch of sentences."""
    with torch.no_grad():
        # A nonzero top-k or top-p setting selects the sampling path;
        # otherwise decode with full beam search.
        if self.sample_from_topk != 0 or self.sample_from_topp != 0:
            decode_strategy = GreedySearchLM(
                pad=self._tgt_pad_idx,
                bos=self._tgt_bos_idx,
                eos=self._tgt_eos_idx,
                unk=self._tgt_unk_idx,
                batch_size=batch.batch_size,
                global_scorer=self.global_scorer,
                min_length=self.min_length,
                max_length=self.max_length,
                block_ngram_repeat=self.block_ngram_repeat,
                exclusion_tokens=self._exclusion_idxs,
                return_attention=attn_debug or self.replace_unk,
                sampling_temp=self.random_sampling_temp,
                keep_topk=self.sample_from_topk,
                keep_topp=self.sample_from_topp,
                beam_size=self.beam_size,
                ban_unk_token=self.ban_unk_token,
            )
        else:
            # TODO: support these blacklisted features
            assert not self.dump_beam
            decode_strategy = BeamSearchLM(
                beam_size=self.beam_size,
                batch_size=batch.batch_size,
                pad=self._tgt_pad_idx,
                bos=self._tgt_bos_idx,
                eos=self._tgt_eos_idx,
                unk=self._tgt_unk_idx,
                n_best=self.n_best,
                global_scorer=self.global_scorer,
                min_length=self.min_length,
                max_length=self.max_length,
                return_attention=attn_debug or self.replace_unk,
                block_ngram_repeat=self.block_ngram_repeat,
                exclusion_tokens=self._exclusion_idxs,
                stepwise_penalty=self.stepwise_penalty,
                ratio=self.ratio,
                ban_unk_token=self.ban_unk_token,
            )
        return self._translate_batch_with_strategy(batch, src_vocabs,
                                                   decode_strategy)
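Both strategies share OpenNMT-py's common decode-strategy interface, which is what lets _translate_batch_with_strategy stay agnostic about which one it drives. A minimal sketch of the selection rule above, factored out on its own; the function name and string return values are illustrative only:

# Sampling is enabled as soon as either top-k or top-p leaves its
# "off" value of 0; beam search handles every other case.
def pick_strategy(sample_from_topk: int, sample_from_topp: float) -> str:
    if sample_from_topk != 0 or sample_from_topp != 0:
        return "GreedySearchLM"  # greedy/random-sampling path
    return "BeamSearchLM"        # n-best beam-search path


assert pick_strategy(0, 0.0) == "BeamSearchLM"
assert pick_strategy(10, 0.0) == "GreedySearchLM"  # top-k sampling
assert pick_strategy(0, 0.9) == "GreedySearchLM"   # nucleus (top-p) sampling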