def test_advance_with_all_repeats_gets_blocked(self):
    # Only beam 0 ever receives a finite score, and it always predicts
    # the same token; the remaining beams repeat -inf dummy scores.
    beam_sz = 5
    vocab_size = 100
    repeated_token = 47
    ngram_repeat = 3
    for batch_sz in (1, 3):
        n_rows = batch_sz * beam_sz
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            0, 30, False, ngram_repeat, set(),
            torch.randint(0, 30, (batch_sz, )), False)

        # Scores while the repeat is still tolerated: beam 0 at 0,
        # everything else at -inf.
        not_blocked_yet = torch.tensor(
            [0] + [-float('inf')] * (beam_sz - 1)).repeat(batch_sz, 1)
        # Scores once the repeat has been penalised on every beam.
        everything_blocked = torch.tensor(
            self.BLOCKED_SCORE).repeat(batch_sz, beam_sz)

        for step in range(ngram_repeat + 4):
            # predict the same token over and over again
            word_probs = torch.full((n_rows, vocab_size), -float('inf'))
            word_probs[0::beam_sz, repeated_token] = 0
            attns = torch.randn(1, n_rows, 53)
            beam.advance(word_probs, attns)

            if step > ngram_repeat:
                wanted = everything_blocked
            else:
                wanted = not_blocked_yet
            self.assertTrue(beam.topk_log_probs.equal(wanted))
def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
    # Also checks that repeating tokens is acceptable when
    # block_ngram_repeat is 0.
    beam_sz = 5
    batch_sz = 3
    n_words = 100
    min_length = 5
    eos_idx = 2
    _non_eos_idxs = [47, 51, 13, 88, 99]
    valid_score_dist = torch.log_softmax(
        torch.tensor([6., 5., 4., 3., 2., 1.]), dim=0)
    best_score = valid_score_dist[0]

    beam = BeamSearch(
        beam_sz, batch_sz, 0, 1, 2, 2,
        GlobalScorerStub(), min_length, 30,
        False, 0, set(), False, 0.)
    device_init = torch.zeros(1, 1)
    beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))

    for step in range(min_length + 4):
        # beams we do not care about keep -inf dummy values
        word_probs = torch.full(
            (batch_sz * beam_sz, n_words), -float('inf'))
        if step == 0:
            # the "best" prediction is EOS, which must be blocked for now
            word_probs[0::beam_sz, eos_idx] = best_score
            # offer at least beam_sz non-EOS alternatives above -1e20
            for tok, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                word_probs[0::beam_sz, tok] = score
        else:
            if step > min_length:
                # past min_length beam 0 predicts EOS as well
                word_probs[0::beam_sz, eos_idx] = best_score
            # beam 1 predicts EOS on every step after the first
            word_probs[1::beam_sz, eos_idx] = best_score
            # spread beam_sz other good predictions over the beams
            for rank, (tok, score) in enumerate(
                    zip(_non_eos_idxs, valid_score_dist[1:])):
                beam_idx = min(beam_sz - 1, rank)
                word_probs[beam_idx::beam_sz, tok] = score

        attns = torch.randn(1, batch_sz * beam_sz, 53)
        beam.advance(word_probs, attns)

        if step < min_length:
            self.assertFalse(beam.done)
            continue

        if step == min_length:
            # beam 1 finishes exactly at min_length
            self.assertTrue(beam.is_finished[:, 1].all())
            beam.update_finished()
            self.assertFalse(beam.done)
        else:
            # beam 0 finishes on the step after beam 1 did
            self.assertTrue(beam.is_finished[:, 0].all())
            beam.update_finished()
            self.assertTrue(beam.done)
def test_advance_with_all_repeats_gets_blocked(self):
    # all beams repeat: beam 0 repeats a real token, beams >= 1 repeat
    # -inf dummy scores
    width = 5
    n_words = 100
    repeat_idx = 47
    ngram_repeat = 3
    neg_inf = -float('inf')
    for batch_sz in [1, 3]:
        lengths = torch.randint(0, 30, (batch_sz,))
        beam = BeamSearch(
            width, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            0, 30, False, ngram_repeat, set(),
            lengths, False, 0.)
        for i in range(ngram_repeat + 4):
            # the same token is predicted on every step
            word_probs = torch.full((batch_sz * width, n_words), neg_inf)
            word_probs[0::width, repeat_idx] = 0
            beam.advance(
                word_probs, torch.randn(1, batch_sz * width, 53))

            scores = beam.topk_log_probs
            if i <= ngram_repeat:
                expected_row = [0] + [neg_inf] * (width - 1)
                expected = torch.tensor(expected_row).repeat(batch_sz, 1)
            else:
                expected = torch.tensor(
                    self.BLOCKED_SCORE).repeat(batch_sz, width)
            self.assertTrue(scores.equal(expected))
def test_beam_advance_against_known_reference(self):
    # Drive a stub-scored beam through the four scripted steps; each
    # helper returns the scores that the next helper starts from.
    search = BeamSearch(
        self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
        GlobalScorerStub(),
        0, 30, False, 0, set(), False, 0.)
    search.initialize(
        torch.zeros(1, 1),
        torch.randint(0, 30, (self.BATCH_SZ, )))

    scores = self.init_step(search, 1)
    scores = self.first_step(search, scores, 1)
    scores = self.second_step(search, scores, 1)
    self.third_step(search, scores, 1)
def test_beam_advance_against_known_reference(self):
    # Same scripted walk as the stub-scorer variant, but scored with a
    # GNMT global scorer; the last argument of each helper differs per
    # step (1., 3, 4, 5).
    gnmt_scorer = GNMTGlobalScorer(0.7, 0., "avg", "none")
    search = BeamSearch(
        self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
        gnmt_scorer,
        0, 30, False, 0, set(), False, 0.)
    memory_lengths = torch.randint(0, 30, (self.BATCH_SZ, ))
    search.initialize(torch.zeros(1, 1), memory_lengths)

    running = self.init_step(search, 1.)
    running = self.first_step(search, running, 3)
    running = self.second_step(search, running, 4)
    self.third_step(search, running, 5)
def test_doesnt_predict_eos_if_shorter_than_min_len(self):
    # Beam 0 always predicts EOS; the other beams predict non-EOS
    # tokens. EOS must stay blocked until min_length is reached.
    for batch_sz in [1, 3]:
        beam_sz = 5
        n_words = 100
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(torch.tensor(
            [6., 5., 4., 3., 2., 1.]), dim=0)
        min_length = 5
        eos_idx = 2
        lengths = torch.randint(0, 30, (batch_sz,))
        beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                          torch.device("cpu"), GlobalScorerStub(),
                          min_length, 30, False, 0, set(),
                          lengths, False, 0.)
        for i in range(min_length + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # "best" prediction is eos - that should be blocked
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # include at least beam_sz predictions OTHER than EOS
                # that are greater than -1e20
                for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                    word_probs[0::beam_sz, j] = score
            else:
                # predict eos in beam 0
                word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
                # provide beam_sz other good predictions
                for k, (j, score) in enumerate(
                        zip(_non_eos_idxs, valid_score_dist[1:])):
                    beam_idx = min(beam_sz - 1, k)
                    word_probs[beam_idx::beam_sz, j] = score

            # The attention tensor is only handed to advance(); the
            # previous version also collected it in a list that was
            # never read.
            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)

            if i < min_length:
                # every step adds one more copy of the non-EOS scores
                expected_score_dist = \
                    (i + 1) * valid_score_dist[1:].unsqueeze(0)
                self.assertTrue(
                    beam.topk_log_probs.allclose(expected_score_dist))
            elif i == min_length:
                # now the top beam has ended and no others have
                self.assertTrue(beam.is_finished[:, 0].eq(1).all())
                self.assertTrue(beam.is_finished[:, 1:].eq(0).all())
            # i > min_length: nothing to assert; the remaining steps
            # only check that advance() keeps running, since only
            # beam 0 terminates and n_best = 2.
def test_doesnt_predict_eos_if_shorter_than_min_len(self):
    # Beam 0 proposes EOS on every step while the other beams propose
    # ordinary tokens; EOS may only win once min_length is reached.
    for batch_sz in (1, 3):
        beam_sz = 5
        n_words = 100
        min_length = 5
        eos_idx = 2
        _non_eos_idxs = [47, 51, 13, 88, 99]
        valid_score_dist = torch.log_softmax(
            torch.tensor([6., 5., 4., 3., 2., 1.]), dim=0)
        eos_score = valid_score_dist[0]
        lengths = torch.randint(0, 30, (batch_sz,))
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(),
            min_length, 30, False, 0, set(),
            lengths, False)
        # attention tensors handed to the beam, kept in call order
        all_attns = []
        n_rows = batch_sz * beam_sz
        for step in range(min_length + 4):
            # beams that are not of interest keep -inf dummy values
            word_probs = torch.full((n_rows, n_words), -float('inf'))
            # beam 0 always rates EOS highest (blocked on step 0)
            word_probs[0::beam_sz, eos_idx] = eos_score
            for rank, (tok, score) in enumerate(
                    zip(_non_eos_idxs, valid_score_dist[1:])):
                if step == 0:
                    # first step: all alternatives come from beam 0 and
                    # must be greater than -1e20
                    word_probs[0::beam_sz, tok] = score
                else:
                    # later steps: one good alternative per beam
                    beam_idx = min(beam_sz - 1, rank)
                    word_probs[beam_idx::beam_sz, tok] = score

            attns = torch.randn(1, n_rows, 53)
            all_attns.append(attns)
            beam.advance(word_probs, attns)

            if step < min_length:
                expected_score_dist = \
                    (step + 1) * valid_score_dist[1:].unsqueeze(0)
                self.assertTrue(
                    beam.topk_log_probs.allclose(expected_score_dist))
            elif step == min_length:
                # the top beam has ended and no other beam has
                finished = beam.is_finished
                self.assertTrue(finished[:, 0].eq(1).all())
                self.assertTrue(finished[:, 1:].eq(0).all())
            else:
                # step > min_length: not of interest, we only want the
                # search to keep running since only beam 0 terminates
                # and n_best = 2
                pass
def test_repeating_excluded_index_does_not_die(self):
    # beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
    beam_sz = 5
    n_words = 100
    repeat_idx = 47  # will be repeated and should be blocked
    repeat_idx_ignored = 7  # will be repeated and should not be blocked
    ngram_repeat = 3
    for batch_sz in [1, 3]:
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(), 0, 30,
            False, ngram_repeat, {repeat_idx_ignored},
            torch.randint(0, 30, (batch_sz,)))
        for i in range(ngram_repeat + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                word_probs[0::beam_sz, repeat_idx] = -0.1
                word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
                word_probs[0::beam_sz, repeat_idx_ignored] = -5.0
            else:
                # predict the same thing in beam 0
                word_probs[0::beam_sz, repeat_idx] = 0
                # continue pushing around what beam 1 predicts
                word_probs[1::beam_sz, repeat_idx + i + 1] = 0
                # predict the allowed-repeat again in beam 2
                word_probs[2::beam_sz, repeat_idx_ignored] = 0

            # Attention stub with one row per (batch, beam) hypothesis.
            # This used to be torch.randn(beam_sz), a 1-D tensor whose
            # shape does not match the word_probs rows.
            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)

            if i <= ngram_repeat:
                for beam_idx in range(3):
                    self.assertFalse(
                        beam.topk_log_probs[:, beam_idx].eq(
                            self.BLOCKED_SCORE).any())
            else:
                # now beam 0 dies, beam 1 -> beam 0, beam 2 -> beam 1
                # and the rest die
                self.assertFalse(beam.topk_log_probs[:, 0].eq(
                    self.BLOCKED_SCORE).any())
                # since all preds after i=0 are 0, we can check
                # that the beam is the correct idx by checking that
                # the curr score is the initial score
                self.assertTrue(beam.topk_log_probs[:, 0].eq(-2.3).all())
                self.assertFalse(beam.topk_log_probs[:, 1].eq(
                    self.BLOCKED_SCORE).all())
                self.assertTrue(beam.topk_log_probs[:, 1].eq(-5.0).all())
                self.assertTrue(
                    beam.topk_log_probs[:, 2:].equal(
                        torch.tensor(self.BLOCKED_SCORE)
                        .repeat(batch_sz, beam_sz - 2)))
def test_repeating_excluded_index_does_not_die(self):
    # Beam 0 repeats a normal token (gets blocked), beam 2 repeats a
    # token from the exclusion set (must survive), beams >= 3 repeat
    # dummy scores.
    beam_sz = 5
    n_words = 100
    ngram_repeat = 3
    repeat_idx = 47         # repeated; should be blocked
    repeat_idx_ignored = 7  # repeated; excluded from blocking

    def is_blocked(column):
        # element-wise comparison against the blocking sentinel
        return column.eq(self.BLOCKED_SCORE)

    for batch_sz in (1, 3):
        n_rows = batch_sz * beam_sz
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(), 0, 30,
            False, ngram_repeat, {repeat_idx_ignored},
            torch.randint(0, 30, (batch_sz,)), False, 0.)
        for step in range(ngram_repeat + 4):
            moving_idx = repeat_idx + step + 1
            # beams that do not matter keep -inf dummy values
            word_probs = torch.full((n_rows, n_words), -float('inf'))
            if step != 0:
                # beam 0 keeps predicting the same token
                word_probs[0::beam_sz, repeat_idx] = 0
                # beam 1 predicts a fresh token every step
                word_probs[1::beam_sz, moving_idx] = 0
                # beam 2 predicts the excluded token again
                word_probs[2::beam_sz, repeat_idx_ignored] = 0
            else:
                # seed three hypotheses from beam 0
                word_probs[0::beam_sz, repeat_idx] = -0.1
                word_probs[0::beam_sz, moving_idx] = -2.3
                word_probs[0::beam_sz, repeat_idx_ignored] = -5.0

            attns = torch.randn(1, n_rows, 53)
            beam.advance(word_probs, attns)

            if step <= ngram_repeat:
                self.assertFalse(
                    is_blocked(beam.topk_log_probs[:, 0]).any())
                self.assertFalse(
                    is_blocked(beam.topk_log_probs[:, 1]).any())
                self.assertFalse(
                    is_blocked(beam.topk_log_probs[:, 2]).any())
                continue

            # beam 0 has died: beam 1 -> beam 0, beam 2 -> beam 1,
            # and the rest die
            self.assertFalse(is_blocked(beam.topk_log_probs[:, 0]).any())
            # every prediction after step 0 scores 0, so the running
            # score still equals the step-0 score and identifies which
            # hypothesis we are looking at
            self.assertTrue(beam.topk_log_probs[:, 0].eq(-2.3).all())
            self.assertFalse(is_blocked(beam.topk_log_probs[:, 1]).all())
            self.assertTrue(beam.topk_log_probs[:, 1].eq(-5.0).all())
            dead_tail = torch.tensor(
                self.BLOCKED_SCORE).repeat(batch_sz, beam_sz - 2)
            self.assertTrue(beam.topk_log_probs[:, 2:].equal(dead_tail))
def translate_batch(self, batch, src_vocabs, attn_debug):
    """Translate a batch of sentences.

    Builds a ``GreedySearch`` strategy when ``self.beam_size`` is 1 and
    a ``BeamSearch`` strategy otherwise, then returns the result of
    ``self._translate_batch_with_strategy`` run under ``torch.no_grad``.
    """
    with torch.no_grad():
        # settings both decoding strategies are constructed with
        shared = dict(
            pad=self._tgt_pad_idx,
            bos=self._tgt_bos_idx,
            eos=self._tgt_eos_idx,
            batch_size=batch.batch_size,
            min_length=self.min_length,
            max_length=self.max_length,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            return_attention=attn_debug or self.replace_unk)

        if self.beam_size != 1:
            # TODO: support these blacklisted features
            assert not self.dump_beam
            decode_strategy = BeamSearch(
                self.beam_size,
                n_best=self.n_best,
                global_scorer=self.global_scorer,
                stepwise_penalty=self.stepwise_penalty,
                ratio=self.ratio,
                **shared)
        else:
            decode_strategy = GreedySearch(
                sampling_temp=self.random_sampling_temp,
                keep_topk=self.sample_from_topk,
                **shared)

        return self._translate_batch_with_strategy(
            batch, src_vocabs, decode_strategy)
def test_advance_with_some_repeats_gets_blocked(self):
    # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
    beam_sz = 5
    n_words = 100
    repeat_idx = 47
    ngram_repeat = 3
    for batch_sz in [1, 3]:
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(), 0, 30,
            False, ngram_repeat, set(),
            torch.randint(0, 30, (batch_sz,)))
        for i in range(ngram_repeat + 4):
            # non-interesting beams are going to get dummy values
            word_probs = torch.full(
                (batch_sz * beam_sz, n_words), -float('inf'))
            if i == 0:
                # on initial round, only predicted scores for beam 0
                # matter. Make two predictions. Top one will be repeated
                # in beam zero, second one will live on in beam 1.
                word_probs[0::beam_sz, repeat_idx] = -0.1
                word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
            else:
                # predict the same thing in beam 0
                word_probs[0::beam_sz, repeat_idx] = 0
                # continue pushing around what beam 1 predicts
                word_probs[1::beam_sz, repeat_idx + i + 1] = 0

            # Attention stub with one row per (batch, beam) hypothesis.
            # This used to be torch.randn(beam_sz), a 1-D tensor whose
            # shape does not match the word_probs rows.
            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)

            if i <= ngram_repeat:
                # NOTE(review): these slices step over the first axis
                # of topk_log_probs, while the else-branch indexes
                # beams on the second axis ([:, 0]); presumably
                # [:, 0] / [:, 1] was intended here - confirm.
                self.assertFalse(
                    beam.topk_log_probs[0::beam_sz].eq(
                        self.BLOCKED_SCORE).any())
                self.assertFalse(
                    beam.topk_log_probs[1::beam_sz].eq(
                        self.BLOCKED_SCORE).any())
            else:
                # now beam 0 dies (along with the others), beam 1 -> beam 0
                self.assertFalse(
                    beam.topk_log_probs[:, 0].eq(
                        self.BLOCKED_SCORE).any())
                self.assertTrue(
                    beam.topk_log_probs[:, 1:].equal(
                        torch.tensor(self.BLOCKED_SCORE)
                        .repeat(batch_sz, beam_sz - 1)))
def test_advance_with_some_repeats_gets_blocked(self):
    # Beam 0 repeats a real token and beams >= 2 repeat dummy scores;
    # beam 1 predicts a different token on every step.
    beam_sz = 5
    n_words = 100
    repeat_idx = 47
    ngram_repeat = 3
    for batch_sz in (1, 3):
        n_rows = batch_sz * beam_sz
        memory_lengths = torch.randint(0, 30, (batch_sz,))
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            torch.device("cpu"), GlobalScorerStub(), 0, 30,
            False, ngram_repeat, set(),
            memory_lengths, False, 0.)
        for step in range(ngram_repeat + 4):
            fresh_idx = repeat_idx + step + 1
            # beams that do not matter keep -inf dummy values
            word_probs = torch.full((n_rows, n_words), -float('inf'))
            if step > 0:
                # beam 0 keeps predicting the same token
                word_probs[0::beam_sz, repeat_idx] = 0
                # beam 1 moves on to a new token
                word_probs[1::beam_sz, fresh_idx] = 0
            else:
                # On the first round only beam 0's scores matter. It
                # makes two predictions: the top one is the token beam
                # 0 goes on repeating, the second lives on in beam 1.
                word_probs[0::beam_sz, repeat_idx] = -0.1
                word_probs[0::beam_sz, fresh_idx] = -2.3

            beam.advance(word_probs, torch.randn(1, n_rows, 53))

            scores = beam.topk_log_probs
            if step > ngram_repeat:
                # beam 0 has died (along with the others); the old
                # beam 1 is now beam 0
                self.assertFalse(
                    scores[:, 0].eq(self.BLOCKED_SCORE).any())
                blocked_rest = torch.tensor(
                    self.BLOCKED_SCORE).repeat(batch_sz, beam_sz - 1)
                self.assertTrue(scores[:, 1:].equal(blocked_rest))
            else:
                for offset in (0, 1):
                    self.assertFalse(
                        scores[offset::beam_sz].eq(
                            self.BLOCKED_SCORE).any())
def test_beam_advance_against_known_reference(self):
    # Walk a stub-scored beam through the scripted steps, threading the
    # expected scores from one helper into the next.
    memory_lengths = torch.randint(0, 30, (self.BATCH_SZ, ))
    search = BeamSearch(
        self.BEAM_SZ, self.BATCH_SZ, 0, 1, 2, self.N_BEST,
        torch.device("cpu"), GlobalScorerStub(),
        0, 30, False, 0, set(),
        memory_lengths)

    after_init = self.init_step(search)
    after_first = self.first_step(search, after_init)
    after_second = self.second_step(search, after_first)
    self.third_step(search, after_second)
def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
    # Doubles as a check that repeating is acceptable when
    # block_ngram_repeat=0.
    beam_sz = 5
    batch_sz = 3
    n_words = 100
    min_length = 5
    eos_idx = 2
    _non_eos_idxs = [47, 51, 13, 88, 99]
    valid_score_dist = torch.log_softmax(
        torch.tensor([6., 5., 4., 3., 2., 1.]), dim=0)
    # (token, score) pairs for the non-EOS candidates, best first
    alternatives = list(zip(_non_eos_idxs, valid_score_dist[1:]))
    n_rows = batch_sz * beam_sz

    beam = BeamSearch(
        beam_sz, batch_sz, 0, 1, 2, 2,
        torch.device("cpu"), GlobalScorerStub(),
        min_length, 30, False, 0, set(),
        torch.randint(0, 30, (batch_sz,)), False, 0.)

    for i in range(min_length + 4):
        # beams that are not of interest keep -inf dummy values
        word_probs = torch.full((n_rows, n_words), -float('inf'))
        if i == 0:
            # the "best" prediction is EOS and should be blocked
            word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
            # at least beam_sz predictions OTHER than EOS that are
            # greater than -1e20
            for tok, score in alternatives:
                word_probs[0::beam_sz, tok] = score
        else:
            # beams that predict EOS on this step: beam 1 up to and
            # including min_length, beams 0 and 1 afterwards
            eos_beams = (1,) if i <= min_length else (0, 1)
            for b in eos_beams:
                word_probs[b::beam_sz, eos_idx] = valid_score_dist[0]
            # beam_sz other good predictions, one per beam
            for k, (tok, score) in enumerate(alternatives):
                beam_idx = min(beam_sz - 1, k)
                word_probs[beam_idx::beam_sz, tok] = score

        attns = torch.randn(1, n_rows, 53)
        beam.advance(word_probs, attns)

        if i < min_length:
            self.assertFalse(beam.done)
        elif i == min_length:
            # beam 1 dies on min_length
            self.assertTrue(beam.is_finished[:, 1].all())
            beam.update_finished()
            self.assertFalse(beam.done)
        else:
            # i > min_length: beam 0 dies on the step after beam 1
            self.assertTrue(beam.is_finished[:, 0].all())
            beam.update_finished()
            self.assertTrue(beam.done)
def translate_batch(self, batch, src_vocabs, attn_debug):
    """Translate a batch of sentences.

    Times the construction of the decoding strategy and the decoding
    itself separately with ``time.perf_counter``; both durations are
    printed when ``show_profile_detail`` is truthy.
    """
    with torch.no_grad():
        return_attention = attn_debug or self.replace_unk

        tic = time.perf_counter()
        if self.beam_size == 1:
            decode_strategy = GreedySearch(
                pad=self._tgt_pad_idx,
                bos=self._tgt_bos_idx,
                eos=self._tgt_eos_idx,
                batch_size=batch.batch_size,
                min_length=self.min_length,
                max_length=self.max_length,
                block_ngram_repeat=self.block_ngram_repeat,
                exclusion_tokens=self._exclusion_idxs,
                return_attention=return_attention,
                sampling_temp=self.random_sampling_temp,
                keep_topk=self.sample_from_topk)
        else:
            # TODO: support these blacklisted features
            assert not self.dump_beam
            decode_strategy = BeamSearch(
                self.beam_size,
                batch_size=batch.batch_size,
                pad=self._tgt_pad_idx,
                bos=self._tgt_bos_idx,
                eos=self._tgt_eos_idx,
                n_best=self.n_best,
                global_scorer=self.global_scorer,
                min_length=self.min_length,
                max_length=self.max_length,
                return_attention=return_attention,
                block_ngram_repeat=self.block_ngram_repeat,
                exclusion_tokens=self._exclusion_idxs,
                stepwise_penalty=self.stepwise_penalty,
                ratio=self.ratio)
        # Despite its name this covers whichever strategy was built.
        beam_search_time = time.perf_counter() - tic

        tic = time.perf_counter()
        ret = self._translate_batch_with_strategy(
            batch, src_vocabs, decode_strategy)
        translate_batch_with_strategy_time = time.perf_counter() - tic

        # show_profile_detail is not defined in this method; it is
        # looked up from the enclosing scope.
        if show_profile_detail:
            print(
                f"BeamSearch Time {beam_search_time:0.4f} seconds, translate_batch_with_strategy Time {translate_batch_with_strategy_time: 0.4f} seconds"
            )
        return ret
def test_advance_with_some_repeats_gets_blocked(self):
    # Beam 0 repeats a real token and beams >= 2 repeat dummy scores;
    # beam 1 predicts a different token on every step.
    beam_sz = 5
    n_words = 100
    repeat_idx = 47
    ngram_repeat = 3
    no_repeat_score = -2.3
    repeat_score = -0.1
    device_init = torch.zeros(1, 1)
    for batch_sz in (1, 3):
        n_rows = batch_sz * beam_sz
        beam = BeamSearch(
            beam_sz, batch_sz, 0, 1, 2, 2,
            GlobalScorerStub(), 0, 30, False,
            ngram_repeat, set(), False, 0.)
        beam.initialize(device_init, torch.randint(0, 30, (batch_sz, )))
        for step in range(ngram_repeat + 4):
            next_idx = repeat_idx + step + 1
            # beams that do not matter keep -inf dummy values
            word_probs = torch.full((n_rows, n_words), -float('inf'))
            if step == 0:
                # On the first round only beam 0's scores matter. It
                # makes two predictions: the top one is the token beam
                # 0 goes on repeating, the second lives on in beam 1.
                word_probs[0::beam_sz, repeat_idx] = repeat_score
                word_probs[0::beam_sz, next_idx] = no_repeat_score
            else:
                # beam 0 predicts the same token again
                word_probs[0::beam_sz, repeat_idx] = 0
                # beam 1 moves on to a new token
                word_probs[1::beam_sz, next_idx] = 0

            attns = torch.randn(1, n_rows, 53)
            beam.advance(word_probs, attns)

            if step < ngram_repeat:
                self.assertFalse(
                    beam.topk_log_probs[0::beam_sz].eq(
                        self.BLOCKED_SCORE).any())
                self.assertFalse(
                    beam.topk_log_probs[1::beam_sz].eq(
                        self.BLOCKED_SCORE).any())
                continue

            # From here on beam 0 has died (along with the others) and
            # the old beam 1 has taken its place.
            self.assertFalse(
                beam.topk_log_probs[:, 0].eq(self.BLOCKED_SCORE).any())
            expected = torch.full([batch_sz, beam_sz], float("-inf"))
            expected[:, 0] = no_repeat_score
            if step == ngram_repeat:
                # exactly one blocked hypothesis on the blocking step
                expected[:, 1] = self.BLOCKED_SCORE
                self.assertTrue(beam.topk_log_probs[:, :].equal(expected))
            else:
                # two blocked hypotheses afterwards, the rest at -inf
                expected[:, 1:3] = self.BLOCKED_SCORE
                expected[:, 3:] = float("-inf")
                self.assertTrue(beam.topk_log_probs.equal(expected))
def forward_dev_beam_search(self, encoder_output: torch.Tensor, pad_mask):
    """Decode ``encoder_output`` with beam search.

    Returns a one-element list holding the best hypothesis of every
    batch entry, padded into a single batch-first tensor.
    """
    width = self.beam_size
    batch_size = encoder_output.size(1)
    self.state["cache"] = None
    memory_lengths = pad_mask.ne(pad_token_index).sum(dim=0)

    # Repeat every per-sentence tensor beam_size times.
    self.map_state(lambda state, dim: tile(state, width, dim=dim))
    encoder_output = tile(encoder_output, width, dim=1)
    pad_mask = tile(pad_mask, width, dim=1)
    memory_lengths = tile(memory_lengths, width, dim=0)

    # TODO:
    #  - fix attn (?)
    #  - use coverage_penalty="summary" or "wu" and beta=0.2 (or not)
    #  - use length_penalty="wu" and alpha=0.2 (or not)
    scorer = GNMTGlobalScorer(
        alpha=0, beta=0, coverage_penalty="none", length_penalty="avg")
    beam = BeamSearch(
        beam_size=width,
        n_best=1,
        batch_size=batch_size,
        mb_device=default_device,
        global_scorer=scorer,
        pad=pad_token_index,
        eos=eos_token_index,
        bos=bos_token_index,
        min_length=1,
        max_length=100,
        return_attention=False,
        stepwise_penalty=False,
        block_ngram_repeat=0,
        exclusion_tokens=set(),
        memory_lengths=memory_lengths,
        ratio=-1)

    for step in range(self.max_seq_out_len):
        prev_tokens = beam.current_predictions.view(1, -1)
        # 1 x batch*beam x hidden
        hidden, attn = self.forward_step(
            src=pad_mask, tgt=prev_tokens,
            memory_bank=encoder_output, step=step)
        logits = self.linear(hidden)            # 1 x batch*beam x vocab_out
        log_probs = log_softmax(logits, dim=2)  # 1 x batch*beam x vocab_out
        log_probs = log_probs.squeeze(0)        # batch*beam x vocab_out
        # TODO: fix attn (?) - it is passed through unchanged.
        beam.advance(log_probs, attn)

        any_beam_is_finished = beam.is_finished.any()
        if any_beam_is_finished:
            beam.update_finished()
            if beam.done:
                break

        select_indices = beam.current_origin

        if any_beam_is_finished:
            # Reorder the encoder-side tensors.
            encoder_output = encoder_output.index_select(1, select_indices)
            pad_mask = pad_mask.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)

        # NOTE(review): assumed to run on every step, not only when a
        # beam has finished - confirm against the original layout.
        self.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    best_per_sentence = [hyps[0] for hyps in beam.predictions]
    padded = pad_sequence(best_per_sentence, batch_first=True)
    return [padded]
def translate_batch(self, batch, src_vocabs, attn_debug, src=None,
                    enc_states=None, memory_bank=None, src_lengths=None,
                    src_embed=None, tgt2=False, hidden_state=None):
    """Translate a batch of sentences.

    Picks ``GreedySearch`` when ``self.beam_size`` is 1 and
    ``BeamSearch`` otherwise, then hands the strategy and the optional
    encoder arguments to ``self._return_gold`` and returns its result.
    """
    want_attention = attn_debug or self.replace_unk
    use_greedy = self.beam_size == 1

    if not use_greedy:
        # TODO: support these blacklisted features
        assert not self.dump_beam
        decode_strategy = BeamSearch(
            self.beam_size,
            batch_size=batch.batch_size,
            pad=self._tgt_pad_idx,
            bos=self._tgt_bos_idx,
            eos=self._tgt_eos_idx,
            n_best=self.n_best,
            global_scorer=self.global_scorer,
            min_length=self.min_length,
            max_length=self.max_length,
            return_attention=want_attention,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            stepwise_penalty=self.stepwise_penalty,
            ratio=self.ratio)
    else:
        decode_strategy = GreedySearch(
            pad=self._tgt_pad_idx,
            bos=self._tgt_bos_idx,
            eos=self._tgt_eos_idx,
            batch_size=batch.batch_size,
            min_length=self.min_length,
            max_length=self.max_length,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            return_attention=want_attention,
            sampling_temp=self.random_sampling_temp,
            keep_topk=self.sample_from_topk)

    return self._return_gold(
        batch, src_vocabs, decode_strategy, src, enc_states,
        memory_bank, src_lengths, src_embed, tgt2,
        hidden_state=hidden_state)
def test_advance_with_all_repeats_gets_blocked(self):
    # all beams repeat (beam >= 1 repeat dummy scores)
    beam_sz = 5
    n_words = 100
    repeat_idx = 47
    ngram_repeat = 3
    device_init = torch.zeros(1, 1)
    for batch_sz in [1, 3]:
        beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                          GlobalScorerStub(), 0, 30, False,
                          ngram_repeat, set(), False, 0.)
        beam.initialize(device_init, torch.randint(0, 30, (batch_sz, )))
        for i in range(ngram_repeat + 4):
            # predict repeat_idx over and over again
            word_probs = torch.full((batch_sz * beam_sz, n_words),
                                    -float('inf'))
            word_probs[0::beam_sz, repeat_idx] = 0
            attns = torch.randn(1, batch_sz * beam_sz, 53)
            beam.advance(word_probs, attns)

            if i < ngram_repeat:
                # before repeat, scores are either 0 or -inf
                expected_scores = torch.tensor(
                    [0] + [-float('inf')] * (beam_sz - 1))\
                    .repeat(batch_sz, 1)
                self.assertTrue(beam.topk_log_probs.equal(expected_scores))
            elif i % ngram_repeat == 0:
                # on repeat, `repeat_idx` score is BLOCKED_SCORE
                # (but it's still the best score, thus we have
                # [BLOCKED_SCORE, -inf, -inf, -inf, -inf]
                expected_scores = torch.tensor(
                    [0] + [-float('inf')] * (beam_sz - 1))\
                    .repeat(batch_sz, 1)
                expected_scores[:, 0] = self.BLOCKED_SCORE
                self.assertTrue(beam.topk_log_probs.equal(expected_scores))
            else:
                # repetitions keeps maximizing score
                # index 0 has been blocked, so repeating=>+0.0 score
                # other indexes are -inf so repeating=>BLOCKED_SCORE
                # which is higher
                #
                # The previous version first built a [0, -inf, ...]
                # tensor and filled it with BLOCKED_SCORE, only to
                # overwrite it with the tensor below; that dead store
                # is gone.
                # NOTE(review): expected_scores is never compared with
                # beam.topk_log_probs in this branch, so these steps
                # are effectively unchecked. Add the assertion once the
                # expected value is confirmed against BeamSearch.
                expected_scores = torch.tensor(self.BLOCKED_SCORE).repeat(
                    batch_sz, beam_sz)
def _translate_batch(
        self,
        batch,
        src_vocabs,
        max_length,
        min_length=0,
        ratio=0.,
        n_best=1,
        return_attention=False):
    """Beam-search decode one batch.

    Returns a dict with the keys ``"predictions"``, ``"scores"``,
    ``"attention"``, ``"batch"`` and ``"gold_score"``.
    """
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    # (0) Prep the components of the search.
    use_src_map = self.copy_attn
    beam_size = self.beam_size
    batch_size = batch.batch_size

    # (1) Run the encoder on the src.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    # The gold score is computed before anything is tiled.
    gold_score = self._gold_score(
        batch, memory_bank, src_lengths, src_vocabs, use_src_map,
        enc_states, batch_size, src)
    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": gold_score}

    # (2) Repeat src objects `beam_size` times, so that everything
    # works on batch_size x beam_size rows.
    src_map = None
    if use_src_map:
        src_map = tile(batch.src_map, beam_size, dim=1)
    self.model.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))

    bank_is_tuple = isinstance(memory_bank, tuple)
    if bank_is_tuple:
        memory_bank = tuple(
            tile(part, beam_size, dim=1) for part in memory_bank)
        mb_device = memory_bank[0].device
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        mb_device = memory_bank.device
    memory_lengths = tile(src_lengths, beam_size)

    # (0) pt 2, prep the beam object
    beam = BeamSearch(
        beam_size,
        n_best=n_best,
        batch_size=batch_size,
        global_scorer=self.global_scorer,
        pad=self._tgt_pad_idx,
        eos=self._tgt_eos_idx,
        bos=self._tgt_bos_idx,
        min_length=min_length,
        ratio=ratio,
        max_length=max_length,
        mb_device=mb_device,
        return_attention=return_attention,
        stepwise_penalty=self.stepwise_penalty,
        block_ngram_repeat=self.block_ngram_repeat,
        exclusion_tokens=self._exclusion_idxs,
        memory_lengths=memory_lengths)

    for step in range(max_length):
        decoder_input = beam.current_predictions.view(1, -1, 1)
        log_probs, attn = self._decode_and_generate(
            decoder_input,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=beam._batch_offset)
        beam.advance(log_probs, attn)

        any_beam_is_finished = beam.is_finished.any()
        if any_beam_is_finished:
            beam.update_finished()
            if beam.done:
                break

        select_indices = beam.current_origin

        if any_beam_is_finished:
            # Reorder the source-side tensors.
            if bank_is_tuple:
                memory_bank = tuple(
                    part.index_select(1, select_indices)
                    for part in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)
            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)

        # NOTE(review): assumed to run on every step, not only when a
        # beam has finished - confirm against the original layout.
        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = beam.scores
    results["predictions"] = beam.predictions
    results["attention"] = beam.attention
    return results
def _translate_batch(
        self,
        batch,
        src_vocabs,
        max_length,
        min_length=0,
        ratio=0.,
        n_best=1,
        return_attention=False):
    """Beam-search decode one batch, optionally tracking constraint states.

    When ``self.constraint`` is truthy a ``BB_sequence_state`` object is
    created, passed to ``_decode_and_generate`` and kept in sync with the
    beam after every step; otherwise ``None`` is passed in its place.

    Args:
        batch: Batch to translate; ``batch.batch_size`` and (when copy
            attention is enabled) ``batch.src_map`` are read here.
        src_vocabs: Forwarded unchanged to ``_gold_score`` and
            ``_decode_and_generate``.
        max_length: Maximum number of decoding steps.
        min_length: Forwarded to ``BeamSearch``.
        ratio: Forwarded to ``BeamSearch``.
        n_best: Forwarded to ``BeamSearch``.
        return_attention: Forwarded to ``BeamSearch``.

    Returns:
        dict with keys ``"predictions"``, ``"scores"`` and ``"attention"``
        (copied from the beam) plus ``"batch"`` and ``"gold_score"``.
    """
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    # (0) Prep the components of the search.
    use_src_map = self.copy_attn
    beam_size = self.beam_size
    batch_size = batch.batch_size

    # (1) Run the encoder on the src.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(
            batch, memory_bank, src_lengths, src_vocabs, use_src_map,
            enc_states, batch_size, src)}

    # (2) Repeat src objects `beam_size` times.
    # We use batch_size x beam_size
    src_map = (tile(batch.src_map, beam_size, dim=1)
               if use_src_map else None)
    self.model.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))

    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        mb_device = memory_bank[0].device
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        mb_device = memory_bank.device
    memory_lengths = tile(src_lengths, beam_size)

    # (0) pt 2, prep the beam object
    beam = BeamSearch(
        beam_size,
        n_best=n_best,
        batch_size=batch_size,
        global_scorer=self.global_scorer,
        pad=self._tgt_pad_idx,
        eos=self._tgt_eos_idx,
        bos=self._tgt_bos_idx,
        min_length=min_length,
        ratio=ratio,
        max_length=max_length,
        mb_device=mb_device,
        return_attention=return_attention,
        stepwise_penalty=self.stepwise_penalty,
        block_ngram_repeat=self.block_ngram_repeat,
        exclusion_tokens=self._exclusion_idxs,
        memory_lengths=memory_lengths)

    # Constraint bookkeeping is only built when constraints are enabled.
    states = None
    if self.constraint:
        tgt_vocab = self.fields["tgt"].base_field.vocab
        states = BB_sequence_state(
            tgt_vocab.itos, tgt_vocab.stoi, mb_device, batch_size,
            beam_size, eos=self._tgt_eos_idx)

    for step in range(max_length):
        decoder_input = beam.current_predictions.view(1, -1, 1)

        log_probs, attn = self._decode_and_generate(
            decoder_input,
            states,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=beam._batch_offset)

        beam.advance(log_probs, attn)

        if states is not None:
            # Convert to Python lists only when they are consumed: each
            # `.tolist()` copies the tensor to host memory.  The origin
            # is read *before* `update_finished()` on purpose, so the
            # constraint states see the pre-pruning beam layout.
            latest_action = beam.current_predictions.data.tolist()
            latest_score = beam.current_scores.view(-1).data.tolist()
            latest_origin = beam.current_origin.data.tolist()
            states.update_beam(latest_action, latest_origin, latest_score)

        any_beam_is_finished = beam.is_finished.any()
        if any_beam_is_finished:
            beam.update_finished()
            if beam.done:
                break

        # Re-read after `update_finished()`, which may have changed it.
        select_indices = beam.current_origin

        if any_beam_is_finished:
            # Reorder states.
            if states is not None:
                states.index_select(select_indices)

            if isinstance(memory_bank, tuple):
                memory_bank = tuple(x.index_select(1, select_indices)
                                    for x in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)

            memory_lengths = memory_lengths.index_select(0, select_indices)

            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)

        # The decoder state follows the beam on every step.
        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = beam.scores
    results["predictions"] = beam.predictions
    results["attention"] = beam.attention
    return results
def _translate_batch(
        self,
        batch,
        src_vocabs,
        max_length,
        min_length=0,
        ratio=0.,
        n_best=1,
        return_attention=False,
        tags=None):
    """Beam-search decode one batch with a language-selected decoder.

    The last entry of ``self.ctags`` names the target language: ``"EN"``
    decodes with ``self.model.decoder2``, ``"DE"`` with
    ``self.model.decoder``.  The remaining entries are forwarded to
    ``_run_encoder``.  Whatever decoder was installed on the model when
    the call started is put back before returning, also on error.

    Args:
        batch: Batch to translate; ``batch.batch_size`` and (when copy
            attention is enabled) ``batch.src_map`` are read here.
        src_vocabs: Forwarded unchanged to ``_gold_score`` and
            ``_decode_and_generate``.
        max_length: Maximum number of decoding steps.
        min_length: Forwarded to ``BeamSearch``.
        ratio: Forwarded to ``BeamSearch``.
        n_best: Forwarded to ``BeamSearch``.
        return_attention: Forwarded to ``BeamSearch``.
        tags: Ignored; the tag list is always read from ``self.ctags``.
            Kept only so existing call sites keep working.

    Returns:
        dict with keys ``"predictions"``, ``"scores"`` and ``"attention"``
        (copied from the beam), ``"batch"``, ``"gold_score"`` and
        ``"maxvecs"`` (always an empty list).

    Raises:
        ValueError: if ``self.ctags`` does not end in ``"EN"`` or
            ``"DE"``, or if the encoder does not return the four values
            ``(src, enc_states, memory_bank, src_lengths)``.
    """
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    # (0) Prep the components of the search.
    use_src_map = self.copy_attn
    beam_size = self.beam_size
    batch_size = batch.batch_size

    tags = self.ctags
    lang = tags[-1] if tags else None
    if lang not in ("EN", "DE"):
        raise ValueError(
            "the last entry of ctags must be 'EN' or 'DE', got %r"
            % (lang,))
    tags = tags[:-1]

    # Temporarily install the decoder for the requested language.
    original_decoder = self.model.decoder
    if lang == "EN":
        self.model.decoder = self.model.decoder2
    try:
        # (1) Run the encoder on the src.
        encoded = self._run_encoder(batch, tags=tags)
        if len(encoded) != 4:
            # src_lengths is needed for gold scoring and for the beam.
            raise ValueError(
                "_run_encoder must return (src, enc_states, memory_bank, "
                "src_lengths); got %d values" % len(encoded))
        src, enc_states, memory_bank, src_lengths = encoded
        self.model.decoder.init_state(src, memory_bank, enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "batch": batch,
            "gold_score": self._gold_score(
                batch, memory_bank, src_lengths, src_vocabs, use_src_map,
                enc_states, batch_size, src)}

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size
        src_map = (tile(batch.src_map, beam_size, dim=1)
                   if use_src_map else None)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))

        if isinstance(memory_bank, tuple):
            memory_bank = tuple(
                tile(x, beam_size, dim=1) for x in memory_bank)
            mb_device = memory_bank[0].device
        else:
            memory_bank = tile(memory_bank, beam_size, dim=1)
            mb_device = memory_bank.device
        memory_lengths = tile(src_lengths, beam_size)

        # (0) pt 2, prep the beam object
        beam = BeamSearch(
            beam_size,
            n_best=n_best,
            batch_size=batch_size,
            global_scorer=self.global_scorer,
            pad=self._tgt_pad_idx,
            eos=self._tgt_eos_idx,
            bos=self._tgt_bos_idx,
            min_length=min_length,
            ratio=ratio,
            max_length=max_length,
            mb_device=mb_device,
            return_attention=return_attention,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=self._exclusion_idxs,
            memory_lengths=memory_lengths,
            i2w=self._tgt_vocab.itos,
            batch=batch)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self._decode_and_generate(
                decoder_input,
                memory_bank,
                batch,
                src_vocabs,
                memory_lengths=memory_lengths,
                src_map=src_map,
                step=step,
                batch_offset=beam._batch_offset)

            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                if isinstance(memory_bank, tuple):
                    memory_bank = tuple(x.index_select(1, select_indices)
                                        for x in memory_bank)
                else:
                    memory_bank = memory_bank.index_select(
                        1, select_indices)

                memory_lengths = memory_lengths.index_select(
                    0, select_indices)

                if src_map is not None:
                    src_map = src_map.index_select(1, select_indices)

            # `self.model.decoder` is already the language-specific
            # decoder, so one call covers both languages.
            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        results["maxvecs"] = []
        return results
    finally:
        # Put the caller's decoder back on the *model*, so a later call
        # does not silently keep using decoder2.
        self.model.decoder = original_decoder
def test_beam_returns_attn_with_correct_length(self):
    """Check the attention returned for finished beams.

    Beam 1 is made to emit EOS at ``min_length`` and beam 0 one step
    later.  Until the top beam finishes no attention is stored; once it
    does, each example holds two attention tensors whose last dimension
    equals the (tiled) source length and whose first dimension equals
    the step at which the corresponding beam finished.
    """
    beam_sz, batch_sz, n_words = 5, 3, 100
    min_length, eos_idx = 5, 2
    _non_eos_idxs = [47, 51, 13, 88, 99]
    valid_score_dist = torch.log_softmax(
        torch.tensor([6., 5., 4., 3., 2., 1.]), dim=0)
    eos_score = valid_score_dist[0]
    other_scores = valid_score_dist[1:]

    inp_lens = torch.randint(1, 30, (batch_sz, ))
    beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 2,
                      GlobalScorerStub(),
                      min_length, 30, True, 0, set(),
                      False, 0.)
    device_init = torch.zeros(1, 1)
    # inp_lens is tiled in initialize, reassign to make attn match
    _, _, inp_lens, _ = beam.initialize(device_init, inp_lens)

    for i in range(min_length + 2):
        # non-interesting beams are going to get dummy values
        word_probs = torch.full(
            (batch_sz * beam_sz, n_words), -float('inf'))
        if i == 0:
            # "best" prediction is eos - that should be blocked
            word_probs[0::beam_sz, eos_idx] = eos_score
            # include at least beam_sz predictions OTHER than EOS
            # that are greater than -1e20
            for j, score in zip(_non_eos_idxs, other_scores):
                word_probs[0::beam_sz, j] = score
        else:
            # up to min_length only beam 1 predicts eos; afterwards
            # beams 0 and 1 both do
            eos_beams = (1,) if i <= min_length else (0, 1)
            for eos_beam in eos_beams:
                word_probs[eos_beam::beam_sz, eos_idx] = eos_score
            # provide beam_sz other good predictions in other beams
            for k, (j, score) in enumerate(
                    zip(_non_eos_idxs, other_scores)):
                word_probs[min(beam_sz - 1, k)::beam_sz, j] = score

        attns = torch.randn(1, batch_sz * beam_sz, 53)
        beam.advance(word_probs, attns)

        if i <= min_length:
            if i == min_length:
                # beam 1 dies on min_length
                self.assertTrue(beam.is_finished[:, 1].all())
                beam.update_finished()
            self.assertFalse(beam.done)
            # no top beams are finished yet
            for b in range(batch_sz):
                self.assertEqual(beam.attention[b], [])
            continue

        # i > min_length
        # beam 0 dies on the step after beam 1 dies
        self.assertTrue(beam.is_finished[:, 0].all())
        beam.update_finished()
        self.assertTrue(beam.done)
        # top beam is finished now so there are attentions
        for b in range(batch_sz):
            finished_attns = beam.attention[b]
            # two beams are finished in each batch
            self.assertEqual(len(finished_attns), 2)
            for k in range(2):
                # second dim is cut down to the non-padded src length
                self.assertEqual(beam.attention[b][k].shape[-1],
                                 inp_lens[b])
            # first dim is equal to the time of death
            # (beam 0 died at current step - adjust for SOS)
            self.assertEqual(beam.attention[b][0].shape[0], i + 1)
            # (beam 1 died at last step - adjust for SOS)
            self.assertEqual(beam.attention[b][1].shape[0], i)
        # behavior gets weird when beam is already done so just stop
        break
def run(
    n_iterations=7,
    beam_sz=5,
    batch_sz=1,
    n_words=100,
    repeat_idx=47,
    ngram_repeat=-1,
    repeat_logprob=-0.2,
    no_repeat_logprob=-0.1,
    base_logprob=float("-inf"),
    verbose=True,
):
    """Drive a ``BeamSearch`` with hand-made scores and report its state.

    At each timestep ``i`` the scores of the first beam of every example
    are set so that:
        - token ``repeat_idx`` gets a logprob of ``repeat_logprob``
        - token ``repeat_idx + i + 1`` gets ``no_repeat_logprob``
        - all other tokens get ``base_logprob``
    The rows of the remaining beams are filled with ``base_logprob``.

    When ``verbose`` is true the beam's log-probs, alive sequences and
    top-k ids are printed after every step, plus the forbidden tokens
    when ``ngram_repeat > 0``.

    Returns:
        tuple ``(beam.topk_log_probs, beam.alive_seq)`` after the last
        step.
    """
    def emit(*args, **kwargs):
        if not verbose:
            return
        print(*args, **kwargs)

    device_init = torch.zeros(1, 1)

    beam = BeamSearch(beam_sz, batch_sz, 0, 1, 2, 4,
                      GlobalScorerStub(),
                      0, 30, False, ngram_repeat, set(),
                      False, 0.)
    beam.initialize(device_init, torch.randint(0, 30, (batch_sz, )))

    n_rows = batch_sz * beam_sz
    for i in range(n_iterations):
        # non-interesting beams are going to get dummy values
        word_logprobs = torch.full((n_rows, n_words), base_logprob)
        # Only beam 0 of each example gets real scores: the repeated
        # token, and a second token whose index moves with the step.
        word_logprobs[0::beam_sz, repeat_idx] = repeat_logprob
        word_logprobs[0::beam_sz, repeat_idx + i + 1] = no_repeat_logprob

        attns = torch.randn(1, n_rows, 0)
        beam.advance(word_logprobs, attns)

        # NOTE: the expected-score checks for ngram_repeat > 0 are
        # intentionally not performed for now.

        emit("Iteration (%d): logprobs %s" % (i, str(beam.topk_log_probs)))
        emit("Iteration (%d): seq %s" % (i, str(beam.alive_seq)))
        emit("Iteration (%d): indices %s" % (i, str(beam.topk_ids)))
        if ngram_repeat > 0:
            emit("Iteration (%d): blocked %s"
                 % (i, str(beam.forbidden_tokens)))
    return beam.topk_log_probs, beam.alive_seq
def _translate_batch(
        self,
        batch,
        src_vocabs,
        max_length,
        min_length=0,
        ratio=0.,
        n_best=1,
        return_attention=False,
        xlation_builder=None,
):
    """Beam-search decode one batch that has several source inputs.

    The encoder returns one entry per source (sources, encoder states,
    memory banks, lengths).  Every memory bank is tiled ``beam_size``
    times in place; the beam itself is given the per-example sum of all
    source lengths.  When ``self.reranker`` is set, the log-probs of
    every step are passed through ``rerank_step_beam_batch`` before the
    beam advances.

    Args:
        batch: Batch to translate; ``batch.batch_size`` and the
            ``src_map.<src_type>`` attributes are read here.
        src_vocabs: Forwarded unchanged to ``_gold_score`` and
            ``_decode_and_generate``.
        max_length: Maximum number of decoding steps.
        min_length: Forwarded to ``BeamSearch``.
        ratio: Forwarded to ``BeamSearch``.
        n_best: Forwarded to ``BeamSearch``.
        return_attention: Forwarded to ``BeamSearch``.
        xlation_builder: Forwarded unchanged to the reranker.

    Returns:
        dict with keys ``"predictions"``, ``"scores"`` and ``"attention"``
        (copied from the beam) plus ``"batch"`` and ``"gold_score"``.
    """
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    # (0) Prep the components of the search.
    use_src_map = self.copy_attn
    beam_size = self.beam_size
    batch_size = batch.batch_size

    # (1) Run the encoder on the src.
    (src_list, enc_states_list,
     memory_bank_list, src_lengths_list) = self._run_encoder(batch)
    self.model.decoder.init_state(
        src_list, memory_bank_list, enc_states_list)

    gold_score = self._gold_score(
        batch, memory_bank_list, src_lengths_list, src_vocabs,
        use_src_map, enc_states_list, batch_size, src_list)
    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": gold_score}

    # (2) Repeat src objects `beam_size` times.
    # We use batch_size x beam_size
    src_map_list = [
        (tile(getattr(batch, f"src_map.{src_type}"), beam_size, dim=1)
         if use_src_map else None)
        for src_type in self.src_types]
    self.model.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))

    memory_lengths_list = []
    untiled_lengths = []
    for src_i, bank in enumerate(memory_bank_list):
        # The list is updated in place, entry by entry.
        if isinstance(bank, tuple):
            tiled_bank = tuple(tile(x, beam_size, dim=1) for x in bank)
            mb_device = tiled_bank[0].device
        else:
            tiled_bank = tile(bank, beam_size, dim=1)
            mb_device = tiled_bank.device
        memory_bank_list[src_i] = tiled_bank
        memory_lengths_list.append(
            tile(src_lengths_list[src_i], beam_size))
        untiled_lengths.append(src_lengths_list[src_i])

    # The beam sees the summed length of all sources per example.
    summed_lengths = torch.stack(untiled_lengths, dim=0).sum(dim=0)
    memory_lengths = tile(summed_lengths, beam_size)

    # Example index of every beam row, handed to the reranker.
    indexes = tile(
        torch.tensor(list(range(batch_size)), device=self._dev),
        beam_size)

    # (0) pt 2, prep the beam object
    beam = BeamSearch(
        beam_size,
        n_best=n_best,
        batch_size=batch_size,
        global_scorer=self.global_scorer,
        pad=self._tgt_pad_idx,
        eos=self._tgt_eos_idx,
        bos=self._tgt_bos_idx,
        min_length=min_length,
        ratio=ratio,
        max_length=max_length,
        mb_device=mb_device,
        return_attention=return_attention,
        stepwise_penalty=self.stepwise_penalty,
        block_ngram_repeat=self.block_ngram_repeat,
        exclusion_tokens=self._exclusion_idxs,
        memory_lengths=memory_lengths)

    for step in range(max_length):
        decoder_input = beam.current_predictions.view(1, -1, 1)

        log_probs, attn = self._decode_and_generate(
            decoder_input,
            memory_bank_list,
            batch,
            src_vocabs,
            memory_lengths_list=memory_lengths_list,
            src_map_list=src_map_list,
            step=step,
            batch_offset=beam._batch_offset)

        if self.reranker is not None:
            log_probs = self.reranker.rerank_step_beam_batch(
                batch,
                beam,
                self.beam_size,
                indexes,
                log_probs,
                attn,
                self.fields["tgt"].base_field.vocab,
                xlation_builder,
            )

        beam.advance(log_probs, attn)
        any_beam_is_finished = beam.is_finished.any()
        if any_beam_is_finished:
            beam.update_finished()
            if beam.done:
                break

        select_indices = beam.current_origin

        if any_beam_is_finished:
            # Reorder states.
            for src_i, bank in enumerate(memory_bank_list):
                if isinstance(bank, tuple):
                    memory_bank_list[src_i] = tuple(
                        x.index_select(1, select_indices) for x in bank)
                else:
                    memory_bank_list[src_i] = bank.index_select(
                        1, select_indices)
                memory_lengths_list[src_i] = (
                    memory_lengths_list[src_i].index_select(
                        0, select_indices))

            if use_src_map and src_map_list[0] is not None:
                for src_i, src_map in enumerate(src_map_list):
                    src_map_list[src_i] = src_map.index_select(
                        1, select_indices)

            indexes = indexes.index_select(0, select_indices)

        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = beam.scores
    results["predictions"] = beam.predictions
    results["attention"] = beam.attention
    return results
# Fragment of a beam-search decode routine (the enclosing definition is
# not part of this excerpt).  It tiles the decoder state and the encoder
# memory `beam_size` times, builds the beam, and runs the decoder and
# generator once per step.

# Repeat every decoder state tensor `beam_size` times along its own dim.
model.decoder.map_state(lambda state, dim: tile(state, beam_size, dim=dim))
# The memory bank is either a single tensor or a tuple of tensors; tile
# along dim 1 in both cases and remember which device it lives on.
if isinstance(memory_bank, tuple):
    memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
    mb_device = memory_bank[0].device
else:
    memory_bank = tile(memory_bank, beam_size, dim=1)
    mb_device = memory_bank.device
memory_lengths = tile(src_lengths, beam_size)
beam = BeamSearch(beam_size, n_best=n_best, batch_size=batch_size,
                  global_scorer=global_scorer, pad=tgt_pad_idx,
                  eos=tgt_eos_idx, bos=tgt_bos_idx, min_length=min_length,
                  ratio=ratio, max_length=max_length, mb_device=mb_device,
                  return_attention=return_attention,
                  stepwise_penalty=stepwise_penalty,
                  block_ngram_repeat=block_ngram_repeat,
                  exclusion_tokens=exclusion_idxs,
                  memory_lengths=memory_lengths)
for step in range(max_length):
    # Feed the latest prediction of every beam, viewed as (1, -1, 1).
    decoder_input = beam.current_predictions.view(1, -1, 1)
    dec_out, dec_attn = model.decoder(decoder_input, memory_bank,
                                      memory_lengths=memory_lengths,
                                      step=step)
    # Drop the leading dim of the decoder output before scoring it.
    log_probs = model.generator(dec_out.squeeze(0))
    # NOTE(review): 'std' is presumably the standard attention head --
    # confirm against the decoder implementation.
    attn = dec_attn['std']
def __init__(self, onmt_summarizer, batch_size, beam_size=None, n_best=None,
             diverse=None, distractor=None, scramble_idxs=None, logger=None):
    """Build the beam used for decoding from a summarizer's translator.

    Args:
        onmt_summarizer: Object providing ``translator``,
            ``memory_lengths`` and ``mb_device``.
        batch_size: Actual batch size handed to the beam.
        beam_size: Beam width; defaults to the translator's ``beam_size``.
        n_best: Number of hypotheses to keep; defaults to the
            translator's ``n_best``.
        diverse: ``None`` for a plain ``BeamSearch`` or ``'rank'`` for a
            ``RankDiverseBeam``.
        distractor: When given, its ``d_factor`` together with
            ``scramble_idxs`` selects which rows of the summarizer's
            memory lengths are handed to the beam.
        scramble_idxs: Passed to ``scramble2tgt``; only read when
            ``distractor`` is given.
        logger: Forwarded to the parent initialiser.

    Raises:
        ValueError: if ``diverse`` is neither ``None`` nor ``'rank'``.
    """
    super().__init__(logger)
    if diverse is None:
        beam_cls = BeamSearch
    elif diverse == 'rank':
        beam_cls = RankDiverseBeam
    else:
        raise ValueError('Invalid diverse beam type!')

    s = onmt_summarizer
    T = onmt_summarizer.translator

    # default values are given by T
    beam_size = T.beam_size if beam_size is None else beam_size
    n_best = T.n_best if n_best is None else n_best

    if distractor is None:
        memory_lengths = s.memory_lengths
    else:
        # Keep only the rows selected by the scramble mapping.
        # NOTE(review): assumes s.memory_lengths is laid out as
        # (examples, beam_size) when flattened -- confirm with the
        # summarizer.
        tgt_idx = scramble2tgt(scramble_idxs, distractor.d_factor)
        memory_lengths = s.memory_lengths.view(-1, beam_size) \
            .index_select(0, tgt_idx).view(-1)

    # Both beam types take exactly the same arguments.  The selected
    # `memory_lengths` is passed on (not the summarizer's full vector).
    self.beam = beam_cls(
        beam_size,
        n_best=n_best,
        batch_size=batch_size,  # actual batch size
        global_scorer=T.global_scorer,
        pad=T._tgt_pad_idx,
        eos=T._tgt_eos_idx,
        bos=T._tgt_bos_idx,
        min_length=T.min_length,
        ratio=T.ratio,
        max_length=T.max_length,
        mb_device=s.mb_device,
        return_attention=T.replace_unk,
        stepwise_penalty=T.stepwise_penalty,
        block_ngram_repeat=T.block_ngram_repeat,
        exclusion_tokens=T._exclusion_idxs,
        memory_lengths=memory_lengths)
def _translate_batch(
        self,
        batch,
        src_vocabs,
        max_length,
        min_length=0,
        ratio=0.,
        n_best=1,
        return_attention=False):
    """Beam-search decode one batch.

    The source is encoded once, every per-example tensor is repeated
    ``beam_size`` times, and the decoder is then run for at most
    ``max_length`` steps while a ``BeamSearch`` object tracks the
    hypotheses.

    The shapes quoted in the comments below are examples recorded by the
    original author for batch_size=30 and beam_size=5; they are not
    enforced by the code.

    Args:
        batch: Batch to translate; ``batch.batch_size`` and (when copy
            attention is enabled) ``batch.src_map`` are read here.
        src_vocabs: Forwarded unchanged to ``_gold_score`` and
            ``_decode_and_generate``.
        max_length: Maximum number of decoding steps.
        min_length: Forwarded to ``BeamSearch``.
        ratio: Forwarded to ``BeamSearch``.
        n_best: Forwarded to ``BeamSearch``.
        return_attention: Forwarded to ``BeamSearch``.

    Returns:
        dict with keys ``"predictions"``, ``"scores"`` and ``"attention"``
        (copied from the beam) plus ``"batch"`` and ``"gold_score"``.
    """
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    # (0) Prep the components of the search.
    use_src_map = self.copy_attn
    beam_size = self.beam_size
    batch_size = batch.batch_size

    # (1) Run the encoder on the src.
    # Example shapes:
    #   src            [59, 30, 1]    (src_len, batch_size, 1)
    #   enc_states[i]  [2, 30, 500]
    #   memory_bank    [59, 30, 500]
    #   src_lengths    [30]
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    gold_score = self._gold_score(
        batch, memory_bank, src_lengths, src_vocabs, use_src_map,
        enc_states, batch_size, src)
    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": gold_score}

    # (2) Repeat src objects `beam_size` times.
    # We use batch_size x beam_size
    src_map = None
    if use_src_map:
        src_map = tile(batch.src_map, beam_size, dim=1)
    self.model.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))

    # Repeat each memory tensor beam_size times along dim=1, which is
    # the batch dimension in the example shapes above.
    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        first_bank = memory_bank[0]
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        first_bank = memory_bank
    mb_device = first_bank.device
    memory_lengths = tile(src_lengths, beam_size)

    # (0) pt 2, prep the beam object
    beam = BeamSearch(
        beam_size,
        n_best=n_best,
        batch_size=batch_size,
        global_scorer=self.global_scorer,
        pad=self._tgt_pad_idx,
        eos=self._tgt_eos_idx,
        bos=self._tgt_bos_idx,
        min_length=min_length,
        ratio=ratio,
        max_length=max_length,  # Maximum prediction length.
        mb_device=mb_device,
        return_attention=return_attention,
        stepwise_penalty=self.stepwise_penalty,
        block_ngram_repeat=self.block_ngram_repeat,
        exclusion_tokens=self._exclusion_idxs,
        memory_lengths=memory_lengths)

    # Run at most max_length steps; every step feeds one token per live
    # hypothesis (batch_size * beam_size of them at the start).
    for step in range(max_length):
        # Example: decoder_input is [1, 150, 1], 150 = 30 * 5.
        # `current_predictions` is the last column of `alive_seq`.
        decoder_input = beam.current_predictions.view(1, -1, 1)

        # Example shapes: memory_bank [59, 150, 500];
        # log_probs [150, 50004] -- 50004 is presumably the vocabulary
        # size (50k words plus <s> </s> <unk> <pad>);
        # attn [1, 150, 59] -- 59 is presumably the longest source
        # length in the batch.
        log_probs, attn = self._decode_and_generate(
            decoder_input,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=beam._batch_offset)

        # NOTE(review): the original author guessed that advance()
        # reduces the 150 candidates back to 30 -- unconfirmed.
        beam.advance(log_probs, attn)

        if beam.is_finished.any():
            beam.update_finished()
            if beam.done:
                break
            select_indices = beam.current_origin

            # Reorder states.
            if isinstance(memory_bank, tuple):
                memory_bank = tuple(x.index_select(1, select_indices)
                                    for x in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)
            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)
        else:
            select_indices = beam.current_origin

        # The decoder state follows the beam on every step.
        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = beam.scores
    results["predictions"] = beam.predictions
    results["attention"] = beam.attention
    return results
def _translate_batch(
        self,
        src,
        src_lengths,
        batch_size,
        min_length=0,
        ratio=0.,
        n_best=1,
        return_attention=False,
        beam_size=3):
    """Beam-search decode ``src`` with this model's encoder and decoder.

    Args:
        src: Source input, passed to ``self.encoder`` and to the
            decoder's ``init_state``.
        src_lengths: Source lengths, passed to ``self.encoder``; the
            lengths the encoder returns are used from then on.
        batch_size: Number of examples, forwarded to ``BeamSearch``.
        min_length: Forwarded to ``BeamSearch``.
        ratio: Forwarded to ``BeamSearch``.
        n_best: Forwarded to ``BeamSearch``.
        return_attention: Forwarded to ``BeamSearch``.
        beam_size: Beam width (previously fixed at 3).

    Returns:
        dict with keys ``"predictions"``, ``"scores"`` and
        ``"attention"`` copied from the beam.
    """
    max_length = self.config.max_sequence_length + 1  # to account for EOS

    # Encoder forward.
    enc_states, memory_bank, src_lengths = self.encoder(src, src_lengths)
    self.decoder.init_state(src, memory_bank, enc_states)

    results = {
        "predictions": None,
        "scores": None,
        "attention": None
    }

    # (2) Repeat src objects `beam_size` times.
    # We use batch_size x beam_size
    self.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))

    # The memory bank is tiled as a single tensor here; a tuple memory
    # bank is not supported by this method.
    memory_bank = tile(memory_bank, beam_size, dim=1)
    mb_device = memory_bank.device
    memory_lengths = tile(src_lengths, beam_size)

    # (0) pt 2, prep the beam object
    beam = BeamSearch(
        beam_size,
        n_best=n_best,
        batch_size=batch_size,
        global_scorer=self.scorer,
        pad=self.config.tgt_padding,
        eos=self.config.tgt_eos,
        bos=self.config.tgt_bos,
        min_length=min_length,
        ratio=ratio,
        max_length=max_length,
        mb_device=mb_device,
        return_attention=return_attention,
        stepwise_penalty=None,
        block_ngram_repeat=0,  # n-gram blocking is disabled
        exclusion_tokens=set(),  # an empty *set*; `{}` would be a dict
        memory_lengths=memory_lengths)

    for step in range(max_length):
        decoder_input = beam.current_predictions.view(1, -1, 1)

        log_probs, attn = self._decode_and_generate(
            decoder_input, memory_bank, memory_lengths, step,
            pretraining=True)

        beam.advance(log_probs, attn)
        any_beam_is_finished = beam.is_finished.any()
        if any_beam_is_finished:
            beam.update_finished()
            if beam.done:
                break

        select_indices = beam.current_origin

        if any_beam_is_finished:
            # Reorder states.
            memory_bank = memory_bank.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)

        # The decoder state follows the beam on every step.
        self.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = beam.scores
    results["predictions"] = beam.predictions
    results["attention"] = beam.attention
    return results
def _translate_batch(self,
                     batch,
                     src_vocabs,
                     max_length,
                     min_length=0,
                     ratio=0.,
                     n_best=1,
                     return_attention=False):
    """Beam-search decode one batch.

    The source is encoded once, every per-example tensor is repeated
    ``beam_size`` times, and the decoder is then run for at most
    ``max_length`` steps while a ``BeamSearch`` object tracks the
    hypotheses.  ``_decode_and_generate`` and ``beam.advance`` are
    called with ``verbose=True`` on step 10 only.

    Args:
        batch: Batch to translate; ``batch.batch_size`` and (when copy
            attention is enabled) ``batch.src_map`` are read here.
        src_vocabs: Forwarded unchanged to ``_gold_score`` and
            ``_decode_and_generate``.
        max_length: Maximum number of decoding steps.
        min_length: Forwarded to ``BeamSearch``.
        ratio: Forwarded to ``BeamSearch``.
        n_best: Forwarded to ``BeamSearch``.
        return_attention: Forwarded to ``BeamSearch``.

    Returns:
        dict with keys ``"predictions"``, ``"scores"`` and ``"attention"``
        (copied from the beam) plus ``"batch"`` and ``"gold_score"``.
    """
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    # (0) Prep the components of the search.
    use_src_map = self.copy_attn
    beam_size = self.beam_size
    batch_size = batch.batch_size

    #### TODO: Augment batch with distractors

    # (1) Run the encoder on the src.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(batch, memory_bank, src_lengths,
                                       src_vocabs, use_src_map,
                                       enc_states, batch_size, src)
    }

    # (2) Repeat src objects `beam_size` times.
    # We use batch_size x beam_size
    src_map = (tile(batch.src_map, beam_size, dim=1)
               if use_src_map else None)
    self.model.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))

    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        mb_device = memory_bank[0].device
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        mb_device = memory_bank.device
    memory_lengths = tile(src_lengths, beam_size)

    # (0) pt 2, prep the beam object
    beam = BeamSearch(beam_size,
                      n_best=n_best,
                      batch_size=batch_size,
                      global_scorer=self.global_scorer,
                      pad=self._tgt_pad_idx,
                      eos=self._tgt_eos_idx,
                      bos=self._tgt_bos_idx,
                      min_length=min_length,
                      ratio=ratio,
                      max_length=max_length,
                      mb_device=mb_device,
                      return_attention=return_attention,
                      stepwise_penalty=self.stepwise_penalty,
                      block_ngram_repeat=self.block_ngram_repeat,
                      exclusion_tokens=self._exclusion_idxs,
                      memory_lengths=memory_lengths)

    for step in range(max_length):
        decoder_input = beam.current_predictions.view(1, -1, 1)

        # Ask the helpers for verbose diagnostics on step 10 only.
        verbose = step == 10

        # log_probs and attn are the probs for the next word given that
        # the current word is that in decoder_input.  They are consumed
        # immediately; nothing is kept across steps.
        log_probs, attn = self._decode_and_generate(
            decoder_input,
            memory_bank,
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=beam._batch_offset,
            verbose=verbose)

        beam.advance(log_probs, attn, verbose=verbose)
        any_beam_is_finished = beam.is_finished.any()
        if any_beam_is_finished:
            beam.update_finished()
            if beam.done:
                break

        select_indices = beam.current_origin

        if any_beam_is_finished:
            # Reorder states.
            if isinstance(memory_bank, tuple):
                memory_bank = tuple(
                    x.index_select(1, select_indices) for x in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)

            memory_lengths = memory_lengths.index_select(0, select_indices)

            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)

        # The decoder state follows the beam on every step.
        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = beam.scores
    results["predictions"] = beam.predictions
    results["attention"] = beam.attention
    return results
def test_beam_returns_attn_with_correct_length(self):
    """Check the attention returned for finished beams.

    Beam 1 is made to emit EOS at ``min_length`` and beam 0 one step
    later.  Until the top beam finishes no attention is stored; once it
    does, each example holds two attention tensors whose last dimension
    equals that example's source length and whose first dimension
    equals the step at which the corresponding beam finished.
    """
    beam_sz = 5
    batch_sz = 3
    n_words = 100
    min_length = 5
    eos_idx = 2
    _non_eos_idxs = [47, 51, 13, 88, 99]
    valid_score_dist = torch.log_softmax(
        torch.tensor([6., 5., 4., 3., 2., 1.]), dim=0)
    inp_lens = torch.randint(1, 30, (batch_sz,))
    beam = BeamSearch(
        beam_sz, batch_sz, 0, 1, 2, 2,
        torch.device("cpu"), GlobalScorerStub(),
        min_length, 30, True, 0, set(), inp_lens,
        False, 0.)

    def spread_non_eos(probs):
        # provide beam_sz other good predictions in other beams
        for k, (j, score) in enumerate(
                zip(_non_eos_idxs, valid_score_dist[1:])):
            beam_idx = min(beam_sz - 1, k)
            probs[beam_idx::beam_sz, j] = score

    def assert_no_attention_yet():
        # no top beams are finished yet
        for b in range(batch_sz):
            self.assertEqual(beam.attention[b], [])

    for i in range(min_length + 2):
        # non-interesting beams are going to get dummy values
        word_probs = torch.full(
            (batch_sz * beam_sz, n_words), -float('inf'))
        if i == 0:
            # "best" prediction is eos - that should be blocked
            word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
            # include at least beam_sz predictions OTHER than EOS
            # that are greater than -1e20
            for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                word_probs[0::beam_sz, j] = score
        elif i <= min_length:
            # predict eos in beam 1
            word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
            spread_non_eos(word_probs)
        else:
            # predict eos in beams 0 and 1
            word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
            word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
            spread_non_eos(word_probs)

        attns = torch.randn(1, batch_sz * beam_sz, 53)
        beam.advance(word_probs, attns)

        if i < min_length:
            self.assertFalse(beam.done)
            assert_no_attention_yet()
        elif i == min_length:
            # beam 1 dies on min_length
            self.assertTrue(beam.is_finished[:, 1].all())
            beam.update_finished()
            self.assertFalse(beam.done)
            assert_no_attention_yet()
        else:  # i > min_length
            # beam 0 dies on the step after beam 1 dies
            self.assertTrue(beam.is_finished[:, 0].all())
            beam.update_finished()
            self.assertTrue(beam.done)
            # top beam is finished now so there are attentions
            for b in range(batch_sz):
                # two beams are finished in each batch
                self.assertEqual(len(beam.attention[b]), 2)
                for k in range(2):
                    # second dim is cut down to the non-padded src length
                    self.assertEqual(beam.attention[b][k].shape[-1],
                                     inp_lens[b])
                # first dim is equal to the time of death
                # (beam 0 died at current step - adjust for SOS)
                self.assertEqual(beam.attention[b][0].shape[0], i + 1)
                # (beam 1 died at last step - adjust for SOS)
                self.assertEqual(beam.attention[b][1].shape[0], i)
            # behavior gets weird when beam is already done so just stop
            break
def batch_beam_search_trs(model, inputs, batch_size, device="cpu",
                          max_len=128, beam_size=20, n_best=1, alpha=1.):
    """Beam search with batch input for a Transformer model.

    Arguments:
        model {torch.nn.Module} -- must implement ``encode()`` and
            ``decode()`` and expose its output projection as either
            ``proj`` or ``gen``
        inputs {sequence} -- eight tensors; items 0, 2, 5 and 7 are used
            as source ids, source positions, source key-padding mask and
            memory key-padding mask, and item 6 is passed through to
            ``decode()`` as its third argument
        batch_size {int} -- number of examples the beam is built for; a
            smaller input batch is padded up to it by repeating the
            first example

    Keyword Arguments:
        device {str} -- device to eval model (default: {"cpu"})
        max_len {int} -- maximum decoding length (default: {128})
        beam_size {int} -- beam width (default: {20})
        n_best {int} -- hypotheses kept per example (default: {1})
        alpha {float} -- first argument of ``GNMTGlobalScorer``
            (default: {1.})

    Returns:
        result -- 2D list (B, N-best), each element is an (seq, score)
            pair

    Raises:
        ValueError -- if the model has neither ``proj`` nor ``gen``
    """
    beam = BeamSearch(
        beam_size, batch_size,
        pad=C.PAD, bos=C.BOS, eos=C.EOS,
        n_best=n_best,
        mb_device=device,
        global_scorer=GNMTGlobalScorer(alpha, 0.1, "avg", "none"),
        min_length=0,
        max_length=max_len,
        ratio=0.0,
        memory_lengths=None,
        block_ngram_repeat=False,
        exclusion_tokens=None,
        stepwise_penalty=True,
        return_attention=False,
    )
    model.eval()
    is_finished = [False] * beam.batch_size

    with torch.no_grad():
        moved = [x.to(device) for x in inputs]
        (src_ids, _, src_pos, _, _, src_key_padding_mask,
         decode_arg, original_memory_key_padding_mask) = moved

        if src_ids.shape[1] != batch_size:
            # Pad a short batch by repeating its first example.
            diff = batch_size - src_ids.shape[1]
            src_ids = torch.cat(
                [src_ids] + [src_ids[:, :1]] * diff, dim=1)
            src_pos = torch.cat(
                [src_pos] + [src_pos[:, :1]] * diff, dim=1)
            src_key_padding_mask = torch.cat(
                [src_key_padding_mask]
                + [src_key_padding_mask[:1]] * diff, dim=0)
            original_memory_key_padding_mask = torch.cat(
                [original_memory_key_padding_mask]
                + [original_memory_key_padding_mask[:1]] * diff, dim=0)

        model.to(device)
        original_memory = model.encode(
            src_ids, src_pos, src_key_padding_mask=src_key_padding_mask)

        memory = original_memory
        memory_key_padding_mask = original_memory_key_padding_mask
        while not beam.done:
            n_rows, seq_len = beam.alive_seq.shape[0], beam.alive_seq.shape[1]
            dec_pos = torch.arange(1, seq_len + 1) \
                .repeat(n_rows, 1).permute(1, 0).to(device)

            # Repeat memory and its padding mask beam_size times so they
            # line up with the (batch * beam) rows of the alive sequences.
            repeated_memory = memory.repeat(1, 1, beam.beam_size).reshape(
                memory.shape[0], -1, memory.shape[-1])
            repeated_memory_key_padding_mask = \
                memory_key_padding_mask.repeat(1, beam.beam_size).reshape(
                    -1, memory_key_padding_mask.shape[1])

            all_outputs = model.decode(
                beam.alive_seq.permute(1, 0), dec_pos, decode_arg,
                repeated_memory,
                memory_key_padding_mask=repeated_memory_key_padding_mask)
            decoder_outputs = all_outputs[-1]

            if hasattr(model, "proj"):
                logits = model.proj(decoder_outputs)
            elif hasattr(model, "gen"):
                logits = model.gen(decoder_outputs)
            else:
                raise ValueError("Unknown generator!")

            log_probs = torch.nn.functional.log_softmax(logits, dim=1)
            beam.advance(log_probs, None)

            if not beam.is_finished.any():
                continue
            beam.update_finished()

            # select data for the still-alive index
            for i, preds in enumerate(beam.predictions):
                if not is_finished[i] and len(preds) == beam.n_best:
                    is_finished[i] = True
            alive_example_idx = [
                i for i, done in enumerate(is_finished) if not done]
            if alive_example_idx:
                memory = original_memory[:, alive_example_idx, :]
                memory_key_padding_mask = \
                    original_memory_key_padding_mask[alive_example_idx]

    # packing data for easy accessing
    results = []
    for batch_preds, batch_scores in zip(beam.predictions, beam.scores):
        n_best_result = []
        for n_best_pred, n_best_score in zip(batch_preds, batch_scores):
            assert isinstance(n_best_pred, torch.Tensor)
            assert isinstance(n_best_score, torch.Tensor)
            n_best_result.append(
                (n_best_pred.tolist(), n_best_score.item()))
        results.append(n_best_result)
    return results