def translate_batch(self, batch, src_vocabs, attn_debug):
    """Pick a decoding strategy for ``batch`` and translate with it.

    ``GreedySearch`` is used when ``self.beam_size`` equals 1, otherwise
    ``BeamSearch``. Everything runs inside ``torch.no_grad()``.

    Args:
        batch: batch to translate; only ``batch.batch_size`` is read here.
        src_vocabs: passed through to ``_translate_batch_with_strategy``.
        attn_debug: when truthy the strategy is asked to return attention
            (``self.replace_unk`` being truthy has the same effect).

    Returns:
        The result of ``self._translate_batch_with_strategy``.

    Raises:
        AssertionError: on the beam-search path if ``self.dump_beam`` is
            truthy.
    """
    with torch.no_grad():
        greedy = self.beam_size == 1
        if not greedy:
            # TODO: support these blacklisted features
            assert not self.dump_beam

        # Keyword options accepted by both strategies.
        common = dict(
            pad=self._tgt_pad_idx,
            bos=self._tgt_bos_idx,
            eos=self._tgt_eos_idx,
            batch_size=batch.batch_size,
            min_length=self.min_length,
            max_length=self.max_length,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            return_attention=attn_debug or self.replace_unk,
        )

        if greedy:
            strategy = GreedySearch(
                sampling_temp=self.random_sampling_temp,
                keep_topk=self.sample_from_topk,
                **common)
        else:
            strategy = BeamSearch(
                self.beam_size,
                n_best=self.n_best,
                global_scorer=self.global_scorer,
                stepwise_penalty=self.stepwise_penalty,
                ratio=self.ratio,
                **common)

        return self._translate_batch_with_strategy(
            batch, src_vocabs, strategy)
def translate_batch(self, batch, src_vocabs, attn_debug):
    """Translate a batch of sentences, optionally printing timings.

    A ``GreedySearch`` strategy is built when ``self.beam_size`` is 1,
    otherwise a ``BeamSearch`` strategy. Strategy construction and the
    call to ``_translate_batch_with_strategy`` are timed separately with
    ``time.perf_counter``; both durations are printed when the name
    ``show_profile_detail`` (not defined in this method, so it must be
    resolvable from the enclosing scope) is truthy.

    Args:
        batch: batch to translate; only ``batch.batch_size`` is read here.
        src_vocabs: passed through to ``_translate_batch_with_strategy``.
        attn_debug: when truthy the strategy is asked to return attention
            (``self.replace_unk`` being truthy has the same effect).

    Returns:
        The result of ``self._translate_batch_with_strategy``.

    Raises:
        AssertionError: on the beam-search path if ``self.dump_beam`` is
            truthy.
    """
    with torch.no_grad():
        # Phase 1: build the decoding strategy. The printed label says
        # "BeamSearch" but this interval covers whichever branch ran.
        tic = time.perf_counter()
        if self.beam_size == 1:
            decode_strategy = GreedySearch(
                pad=self._tgt_pad_idx,
                bos=self._tgt_bos_idx,
                eos=self._tgt_eos_idx,
                batch_size=batch.batch_size,
                min_length=self.min_length,
                max_length=self.max_length,
                block_ngram_repeat=self.block_ngram_repeat,
                exclusion_tokens=self._exclusion_idxs,
                return_attention=attn_debug or self.replace_unk,
                sampling_temp=self.random_sampling_temp,
                keep_topk=self.sample_from_topk)
        else:
            # TODO: support these blacklisted features
            assert not self.dump_beam
            decode_strategy = BeamSearch(
                self.beam_size,
                batch_size=batch.batch_size,
                pad=self._tgt_pad_idx,
                bos=self._tgt_bos_idx,
                eos=self._tgt_eos_idx,
                n_best=self.n_best,
                global_scorer=self.global_scorer,
                min_length=self.min_length,
                max_length=self.max_length,
                return_attention=attn_debug or self.replace_unk,
                block_ngram_repeat=self.block_ngram_repeat,
                exclusion_tokens=self._exclusion_idxs,
                stepwise_penalty=self.stepwise_penalty,
                ratio=self.ratio)
        beam_search_time = time.perf_counter() - tic

        # Phase 2: run the actual decoding.
        tic = time.perf_counter()
        ret = self._translate_batch_with_strategy(
            batch, src_vocabs, decode_strategy)
        translate_batch_with_strategy_time = time.perf_counter() - tic

        if show_profile_detail:
            # Both fields use the same "0.4f" spec; a leading space in the
            # spec would be the "space sign" option and emit a stray blank.
            print(
                f"BeamSearch Time {beam_search_time:0.4f} seconds, "
                f"translate_batch_with_strategy Time "
                f"{translate_batch_with_strategy_time:0.4f} seconds"
            )
        return ret
def sample_from_batch(self, batch):
    """Sample outputs for ``batch`` using a ``GreedySearch`` strategy.

    The strategy is configured from the translator's settings, except that
    n-gram repeat blocking is fixed to 0 and attention is never requested.
    The call runs inside ``torch.no_grad()``.

    Args:
        batch: batch to sample from; only ``batch.batch_size`` is read
            here before it is handed to ``_sample_batch_with_strategy``.

    Returns:
        The result of ``self._sample_batch_with_strategy``.
    """
    with torch.no_grad():
        sampler_options = {
            "pad": self._tgt_pad_idx,
            "bos": self._tgt_bos_idx,
            "eos": self._tgt_eos_idx,
            "batch_size": batch.batch_size,
            "min_length": self.min_length,
            "max_length": self.max_length,
            "block_ngram_repeat": 0,
            "exclusion_tokens": self._exclusion_idxs,
            "return_attention": False,
            "sampling_temp": self.random_sampling_temp,
            "keep_topk": self.sample_from_topk,
        }
        sampler = GreedySearch(**sampler_options)
        return self._sample_batch_with_strategy(batch, sampler)
def translate_batch(self, batch, src_vocabs, attn_debug, src=None,
                    enc_states=None, memory_bank=None, src_lengths=None,
                    src_embed=None, tgt2=False, hidden_state=None):
    """Build a decoding strategy for ``batch`` and hand it to ``_return_gold``.

    ``GreedySearch`` is used when ``self.beam_size`` equals 1, otherwise
    ``BeamSearch``.

    Args:
        batch: batch to translate; only ``batch.batch_size`` is read here.
        src_vocabs: forwarded unchanged to ``_return_gold``.
        attn_debug: when truthy the strategy is asked to return attention
            (``self.replace_unk`` being truthy has the same effect).
        src, enc_states, memory_bank, src_lengths, src_embed, tgt2,
        hidden_state: not inspected here; forwarded unchanged to
            ``_return_gold``.

    Returns:
        The result of ``self._return_gold``.

    Raises:
        AssertionError: on the beam-search path if ``self.dump_beam`` is
            truthy.
    """
    beam_mode = self.beam_size != 1
    if beam_mode:
        # TODO: support these blacklisted features
        assert not self.dump_beam

    # Settings that both strategy constructors take by keyword.
    shared = {
        "pad": self._tgt_pad_idx,
        "bos": self._tgt_bos_idx,
        "eos": self._tgt_eos_idx,
        "batch_size": batch.batch_size,
        "min_length": self.min_length,
        "max_length": self.max_length,
        "block_ngram_repeat": self.block_ngram_repeat,
        "exclusion_tokens": self._exclusion_idxs,
        "return_attention": attn_debug or self.replace_unk,
    }

    if beam_mode:
        strategy = BeamSearch(
            self.beam_size,
            n_best=self.n_best,
            global_scorer=self.global_scorer,
            stepwise_penalty=self.stepwise_penalty,
            ratio=self.ratio,
            **shared)
    else:
        strategy = GreedySearch(
            sampling_temp=self.random_sampling_temp,
            keep_topk=self.sample_from_topk,
            **shared)

    return self._return_gold(
        batch, src_vocabs, strategy, src, enc_states, memory_bank,
        src_lengths, src_embed, tgt2, hidden_state=hidden_state)
def test_doesnt_predict_eos_if_shorter_than_min_len(self):
    # Batch entry 0 always scores EOS highest; every other entry only ever
    # scores a non-EOS token. The test checks that entry 0's EOS is not
    # selected while the step index is below min_length, and that entry 0
    # is marked finished at step index == min_length.
    n_words = 100
    min_length = 5
    eos_idx = 2
    non_eos_idx = 47
    valid_score_dist = torch.log_softmax(torch.tensor([6., 5.]), dim=0)

    for batch_sz in [1, 3]:
        lengths = torch.randint(0, 30, (batch_sz,))
        samp = GreedySearch(
            0, 1, 2, batch_sz, min_length,
            False, set(), False, 30, 1., 1)
        samp.initialize(torch.zeros(1), lengths)

        for i in range(min_length + 4):
            word_probs = torch.full((batch_sz, n_words), float('-inf'))
            # "best" prediction is eos - that should be blocked
            word_probs[0, eos_idx] = valid_score_dist[0]
            # include at least one prediction OTHER than EOS
            # that is greater than -1e20
            word_probs[0, non_eos_idx] = valid_score_dist[1]
            word_probs[1:, non_eos_idx + i] = 0
            attns = torch.randn(1, batch_sz, 53)
            samp.advance(word_probs, attns)

            if i < min_length:
                # EOS was blocked, so entry 0 took the runner-up score.
                self.assertTrue(
                    samp.topk_scores[0].allclose(valid_score_dist[1]))
                self.assertTrue(samp.topk_scores[1:].eq(0).all())
            elif i == min_length:
                # now batch 0 has ended and no others have
                self.assertTrue(samp.is_finished[0, :].eq(1).all())
                # NOTE(review): if is_finished has a single column, the
                # slice [1:, 1:] is empty and this check is vacuous --
                # confirm whether [1:] was intended.
                self.assertTrue(samp.is_finished[1:, 1:].eq(0).all())
            else:  # i > min_length
                break
def test_returns_correct_scores_deterministic(self):
    # Drives GreedySearch (last constructor argument 1) through three
    # steps and checks the score recorded for each entry as it finishes:
    # entry 0 at the first step, (old) entry 8 at the second, and all
    # remaining entries at the third.
    n_words = 100
    eos_idx = 2
    _non_eos_idxs = [47, 51, 13, 88, 99]
    valid_score_dist_1 = torch.log_softmax(
        torch.tensor([6., 5., 4., 3., 2., 1.]), dim=0)
    valid_score_dist_2 = torch.log_softmax(torch.tensor([6., 1.]), dim=0)

    def blank_probs(rows):
        # Score matrix in which every word starts at -inf.
        return torch.full((rows, n_words), float('-inf'))

    for batch_sz in [1, 13]:
        for temp in [1., 3.]:
            lengths = torch.randint(0, 30, (batch_sz,))
            samp = GreedySearch(
                0, 1, 2, batch_sz, 0, False, set(), False, 30, temp, 1)
            samp.initialize(torch.zeros(1), lengths)

            # ---- initial step: batch 0 dies on step 0 ----
            step = 0
            word_probs = blank_probs(batch_sz)
            word_probs[0, eos_idx] = valid_score_dist_1[0]
            # include at least one prediction OTHER than EOS
            # that is greater than -1e20
            word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
            word_probs[1:, _non_eos_idxs[0] + step] = 0
            attns = torch.randn(1, batch_sz, 53)
            samp.advance(word_probs, attns)
            self.assertTrue(samp.is_finished[0].eq(1).all())
            samp.update_finished()
            self.assertEqual(
                samp.scores[0], [valid_score_dist_1[0] / temp])

            if batch_sz == 1:
                # The only entry has finished; nothing left to step.
                self.assertTrue(samp.done)
                continue
            self.assertFalse(samp.done)

            # ---- step 2: (old) batch 8 dies on step 1 ----
            top_two = _non_eos_idxs[:2]
            word_probs = blank_probs(batch_sz - 1)
            word_probs[7, eos_idx] = valid_score_dist_2[0]
            word_probs[:7, top_two] = valid_score_dist_2
            word_probs[8:, top_two] = valid_score_dist_2
            attns = torch.randn(1, batch_sz, 53)
            samp.advance(word_probs, attns)
            self.assertTrue(samp.is_finished[7].eq(1).all())
            samp.update_finished()
            self.assertEqual(
                samp.scores[8], [valid_score_dist_2[0] / temp])

            # ---- step 3: everything dies ----
            word_probs = blank_probs(batch_sz - 2)
            word_probs[:, eos_idx] = 0
            attns = torch.randn(1, batch_sz, 53)
            samp.advance(word_probs, attns)
            self.assertTrue(samp.is_finished.eq(1).all())
            samp.update_finished()
            for b in range(batch_sz):
                if b not in (0, 8):
                    self.assertEqual(samp.scores[b], [0])
            self.assertTrue(samp.done)
def test_returns_correct_scores_non_deterministic(self):
    # Same scenario as the deterministic score test, but GreedySearch is
    # built with 2 as its last constructor argument, so each phase is
    # retried in a bounded loop until the expected entry finishes. The
    # test fails if the bound is exhausted.
    n_words = 100
    eos_idx = 2
    _non_eos_idxs = [47, 51, 13, 88, 99]
    valid_score_dist_1 = torch.log_softmax(
        torch.tensor([6., 5., 4., 3., 2., 1.]), dim=0)
    valid_score_dist_2 = torch.log_softmax(torch.tensor([6., 1.]), dim=0)
    neg_inf = float('-inf')

    for batch_sz in [1, 13]:
        for temp in [1., 3.]:
            lengths = torch.randint(0, 30, (batch_sz,))
            samp = GreedySearch(
                0, 1, 2, batch_sz, 0, False, set(), False, 30, temp, 2)
            samp.initialize(torch.zeros(1), lengths)

            # ---- initial step: wait for batch 0 to die ----
            step = 0
            first_ended = False
            for _ in range(100):
                word_probs = torch.full((batch_sz, n_words), neg_inf)
                # batch 0 dies on step 0
                word_probs[0, eos_idx] = valid_score_dist_1[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[0, _non_eos_idxs] = valid_score_dist_1[1:]
                word_probs[1:, _non_eos_idxs[0] + step] = 0
                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                if samp.is_finished[0].eq(1).all():
                    first_ended = True
                    break
            if not first_ended:
                self.fail("Batch 0 never ended (very unlikely but maybe "
                          "due to stochasticisty. If so, please increase "
                          "the range of the for-loop.")
            samp.update_finished()
            self.assertEqual(
                samp.scores[0], [valid_score_dist_1[0] / temp])

            if batch_sz == 1:
                # The only entry has finished; nothing left to step.
                self.assertTrue(samp.done)
                continue
            self.assertFalse(samp.done)

            # ---- step 2: wait for (old) batch 8 to die ----
            top_two = _non_eos_idxs[:2]
            eighth_ended = False
            for _ in range(100):
                word_probs = torch.full((batch_sz - 1, n_words), neg_inf)
                # (old) batch 8 dies on step 1
                word_probs[7, eos_idx] = valid_score_dist_2[0]
                word_probs[:7, top_two] = valid_score_dist_2
                word_probs[8:, top_two] = valid_score_dist_2
                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                if samp.is_finished[7].eq(1).all():
                    eighth_ended = True
                    break
            if not eighth_ended:
                self.fail("Batch 8 never ended (very unlikely but maybe "
                          "due to stochasticisty. If so, please increase "
                          "the range of the for-loop.")
            samp.update_finished()
            self.assertEqual(
                samp.scores[8], [valid_score_dist_2[0] / temp])

            # ---- step 3: everything dies ----
            all_ended = False
            for _ in range(250):
                # Size the scores to however many sequences are alive.
                alive = samp.alive_seq.shape[0]
                word_probs = torch.full((alive, n_words), neg_inf)
                word_probs[:, eos_idx] = 0
                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                if samp.is_finished.any():
                    samp.update_finished()
                if samp.is_finished.eq(1).all():
                    all_ended = True
                    break
            if not all_ended:
                self.fail("All batches never ended (very unlikely but "
                          "maybe due to stochasticisty. If so, please "
                          "increase the range of the for-loop.")

            for b in range(batch_sz):
                if b not in (0, 8):
                    self.assertEqual(samp.scores[b], [0])
            self.assertTrue(samp.done)
def test_returns_correct_scores_non_deterministic_beams(self):
    # Variant of the non-deterministic score test in which the score
    # matrix has batch_sz * beam_size rows. Each phase is retried in a
    # bounded loop until the targeted row finishes; the test fails if the
    # bound is exhausted.
    #
    # NOTE(review): GreedySearch is called here with 16 positional
    # arguments (including a GlobalScorerStub and beam_size), whereas the
    # other tests in this file pass 11 -- presumably this targets a
    # different constructor signature; confirm.
    beam_size = 10
    n_words = 100
    eos_idx = 2
    _non_eos_idxs = [47, 51, 13, 88, 99]
    valid_score_dist_1 = torch.log_softmax(
        torch.tensor([6., 5., 4., 3., 2., 1.]), dim=0)
    valid_score_dist_2 = torch.log_softmax(torch.tensor([6., 1.]), dim=0)
    neg_inf = float('-inf')
    # Row that is made to finish during the first phase.
    first_row = beam_size - 2

    for batch_sz in [1, 13]:
        for temp in [1., 3.]:
            lengths = torch.randint(0, 30, (batch_sz, ))
            samp = GreedySearch(0, 1, 2, 3, batch_sz, GlobalScorerStub(),
                                0, False, set(), False, 30, temp, 50, 0,
                                beam_size, False)
            samp.initialize(torch.zeros((1, 1)), lengths)

            # ---- initial step: finish one beam ----
            step = 0
            filler_col = _non_eos_idxs[0] + step
            for _ in range(100):
                word_probs = torch.full(
                    (batch_sz * beam_size, n_words), neg_inf)
                word_probs[first_row, eos_idx] = valid_score_dist_1[0]
                # include at least one prediction OTHER than EOS
                # that is greater than -1e20
                word_probs[first_row, _non_eos_idxs] = \
                    valid_score_dist_1[1:]
                word_probs[first_row + 1:, filler_col] = 0
                word_probs[:first_row, filler_col] = 0
                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                if samp.is_finished[first_row].eq(1).all():
                    # Only the targeted row may have finished.
                    self.assertFalse(
                        samp.is_finished[:first_row].eq(1).any())
                    self.assertFalse(
                        samp.is_finished[first_row + 1].eq(1).any())
                    break
            else:
                self.fail("Batch 0 never ended (very unlikely but maybe "
                          "due to stochasticisty. If so, please increase "
                          "the range of the for-loop.")
            samp.update_finished()
            self.assertEqual([samp.topk_scores[first_row]],
                             [valid_score_dist_1[0] / temp])

            # ---- step 2: finish example in last batch ----
            top_two = _non_eos_idxs[:2]
            last_row = (batch_sz - 1) * beam_size + 7
            for _ in range(100):
                word_probs = torch.full(
                    (batch_sz * beam_size - 1, n_words), neg_inf)
                # (old) batch 8 dies on step 1
                word_probs[last_row, eos_idx] = valid_score_dist_2[0]
                word_probs[:last_row, top_two] = valid_score_dist_2
                word_probs[last_row + 1:, top_two] = valid_score_dist_2
                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                if samp.is_finished[last_row].eq(1).all():
                    break
            else:
                self.fail("Batch 8 never ended (very unlikely but maybe "
                          "due to stochasticisty. If so, please increase "
                          "the range of the for-loop.")
            samp.update_finished()
            newest = samp.hypotheses[batch_sz - 1][-1:]
            self.assertEqual([score for score, _, _ in newest],
                             [valid_score_dist_2[0] / temp])

            # ---- step 3: everything dies ----
            for _ in range(250):
                # Size the scores to however many sequences are alive.
                alive = samp.alive_seq.shape[0]
                word_probs = torch.full((alive, n_words), neg_inf)
                word_probs[:, eos_idx] = 0
                attns = torch.randn(1, batch_sz, 53)
                samp.advance(word_probs, attns)
                if samp.is_finished.any():
                    samp.update_finished()
                if samp.is_finished.eq(1).all():
                    break
            else:
                self.fail("All batches never ended (very unlikely but "
                          "maybe due to stochasticisty. If so, please "
                          "increase the range of the for-loop.")

            self.assertTrue(samp.done)