def test_advance_with_all_repeats_gets_blocked(self):
    # all beams repeat (beam >= 1 repeat dummy scores)
    beam_sz = 5
    n_words = 100
    repeat_idx = 47
    ngram_repeat = 3
    beam = Beam(beam_sz, 0, 1, 2, n_best=2,
                exclusion_tokens=set(),
                global_scorer=GlobalScorerStub(),
                block_ngram_repeat=ngram_repeat)
    for i in range(ngram_repeat + 4):
        # predict repeat_idx over and over again
        word_probs = torch.full((beam_sz, n_words), -float('inf'))
        word_probs[0, repeat_idx] = 0
        attns = torch.randn(beam_sz)
        beam.advance(word_probs, attns)
        if i <= ngram_repeat:
            self.assertTrue(
                beam.scores.equal(
                    torch.tensor(
                        [0] + [-float('inf')] * (beam_sz - 1))))
        else:
            self.assertTrue(
                beam.scores.equal(torch.tensor(
                    [self.BLOCKED_SCORE] * beam_sz)))

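# The tests in this file refer to GlobalScorerStub and self.BLOCKED_SCORE,
# neither of which appears in this excerpt. A minimal sketch of what that
# scaffolding might look like: the stub's method set mirrors the scorer
# calls Beam makes when a hypothesis finishes, and the sentinel is a large
# negative constant that ngram blocking assigns to a blocked beam (the
# exact value and method set are assumptions, not taken from the source).


class GlobalScorerStub(object):
    """No-op stand-in for GNMTGlobalScorer: applies no penalty."""

    def update_global_state(self, beam):
        # a real scorer would track coverage state here; the stub does nothing
        pass

    def score(self, beam, logprobs):
        # return the raw cumulative log-probabilities unchanged
        return logprobs


# and, on the test case class itself, something like:
#     BLOCKED_SCORE = -10e20
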
def test_advance_with_some_repeats_gets_blocked(self):
    # beam 0 and beam >= 2 will repeat (beam >= 2 repeat dummy scores)
    beam_sz = 5
    n_words = 100
    repeat_idx = 47
    ngram_repeat = 3
    beam = Beam(beam_sz, 0, 1, 2, n_best=2,
                exclusion_tokens=set(),
                global_scorer=GlobalScorerStub(),
                block_ngram_repeat=ngram_repeat)
    for i in range(ngram_repeat + 4):
        # non-interesting beams are going to get dummy values
        word_probs = torch.full((beam_sz, n_words), -float('inf'))
        if i == 0:
            # on initial round, only predicted scores for beam 0
            # matter. Make two predictions. Top one will be repeated
            # in beam zero, second one will live on in beam 1.
            word_probs[0, repeat_idx] = -0.1
            word_probs[0, repeat_idx + i + 1] = -2.3
        else:
            # predict the same thing in beam 0
            word_probs[0, repeat_idx] = 0
            # continue pushing around what beam 1 predicts
            word_probs[1, repeat_idx + i + 1] = 0
        attns = torch.randn(beam_sz)
        beam.advance(word_probs, attns)
        if i <= ngram_repeat:
            self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
            self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
        else:
            # now beam 0 dies (along with the others), beam 1 -> beam 0
            self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
            self.assertTrue(beam.scores[1:].equal(torch.tensor(
                [self.BLOCKED_SCORE] * (beam_sz - 1))))

def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
    # this is also a test that when block_ngram_repeat=0,
    # repeating is acceptable
    beam_sz = 5
    n_words = 100
    _non_eos_idxs = [47, 51, 13, 88, 99]
    valid_score_dist = torch.log_softmax(torch.tensor(
        [6., 5., 4., 3., 2., 1.]), dim=0)
    min_length = 5
    eos_idx = 2
    beam = Beam(beam_sz, 0, 1, eos_idx, n_best=2,
                exclusion_tokens=set(), min_length=min_length,
                global_scorer=GlobalScorerStub(),
                block_ngram_repeat=0)
    for i in range(min_length + 4):
        # non-interesting beams are going to get dummy values
        word_probs = torch.full((beam_sz, n_words), -float('inf'))
        if i == 0:
            # "best" prediction is eos - that should be blocked
            word_probs[0, eos_idx] = valid_score_dist[0]
            # include at least beam_sz predictions OTHER than EOS
            # that are greater than -1e20
            for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                word_probs[0, j] = score
        elif i <= min_length:
            # predict eos in beam 1
            word_probs[1, eos_idx] = valid_score_dist[0]
            # provide beam_sz other good predictions in other beams
            for k, (j, score) in enumerate(
                    zip(_non_eos_idxs, valid_score_dist[1:])):
                beam_idx = min(beam_sz - 1, k)
                word_probs[beam_idx, j] = score
        else:
            word_probs[0, eos_idx] = valid_score_dist[0]
            word_probs[1, eos_idx] = valid_score_dist[0]
            # provide beam_sz other good predictions in other beams
            for k, (j, score) in enumerate(
                    zip(_non_eos_idxs, valid_score_dist[1:])):
                beam_idx = min(beam_sz - 1, k)
                word_probs[beam_idx, j] = score
        attns = torch.randn(beam_sz)
        beam.advance(word_probs, attns)
        if i < min_length:
            self.assertFalse(beam.done)
        elif i == min_length:
            # beam 1 dies on min_length
            self.assertEqual(beam.finished[0][1], beam.min_length + 1)
            self.assertEqual(beam.finished[0][2], 1)
            self.assertFalse(beam.done)
        else:  # i > min_length
            # beam 0 dies on the step after beam 1 dies
            self.assertEqual(beam.finished[1][1], beam.min_length + 2)
            self.assertEqual(beam.finished[1][2], 0)
            self.assertTrue(beam.done)

def test_doesnt_predict_eos_if_shorter_than_min_len(self):
    # beam 0 will always predict EOS. The other beams will predict
    # non-eos scores.
    # this is also a test that when block_ngram_repeat=0,
    # repeating is acceptable
    beam_sz = 5
    n_words = 100
    _non_eos_idxs = [47, 51, 13, 88, 99]
    valid_score_dist = torch.log_softmax(torch.tensor(
        [6., 5., 4., 3., 2., 1.]), dim=0)
    min_length = 5
    eos_idx = 2
    beam = Beam(beam_sz, 0, 1, eos_idx, n_best=2,
                exclusion_tokens=set(), min_length=min_length,
                global_scorer=GlobalScorerStub(),
                block_ngram_repeat=0)
    for i in range(min_length + 4):
        # non-interesting beams are going to get dummy values
        word_probs = torch.full((beam_sz, n_words), -float('inf'))
        if i == 0:
            # "best" prediction is eos - that should be blocked
            word_probs[0, eos_idx] = valid_score_dist[0]
            # include at least beam_sz predictions OTHER than EOS
            # that are greater than -1e20
            for j, score in zip(_non_eos_idxs, valid_score_dist[1:]):
                word_probs[0, j] = score
        else:
            # predict eos in beam 0
            word_probs[0, eos_idx] = valid_score_dist[0]
            # provide beam_sz other good predictions
            for k, (j, score) in enumerate(
                    zip(_non_eos_idxs, valid_score_dist[1:])):
                beam_idx = min(beam_sz - 1, k)
                word_probs[beam_idx, j] = score
        attns = torch.randn(beam_sz)
        beam.advance(word_probs, attns)
        if i < min_length:
            # with the stub scorer, each beam's score is just the sum
            # of its per-step log-probs
            expected_score_dist = (i + 1) * valid_score_dist[1:]
            self.assertTrue(beam.scores.allclose(expected_score_dist))
        elif i == min_length:
            # now the top beam has ended and no others have
            # first beam finished had length beam.min_length
            self.assertEqual(beam.finished[0][1], beam.min_length + 1)
            # first beam finished was 0
            self.assertEqual(beam.finished[0][2], 0)
        else:  # i > min_length
            # not of interest, but want to make sure it keeps running
            # since only beam 0 terminates and n_best = 2
            pass

def test_repeating_excluded_index_does_not_die(self):
    # beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
    beam_sz = 5
    n_words = 100
    repeat_idx = 47  # will be repeated and should be blocked
    repeat_idx_ignored = 7  # will be repeated and should not be blocked
    ngram_repeat = 3
    beam = Beam(beam_sz, 0, 1, 2, n_best=2,
                exclusion_tokens=set([repeat_idx_ignored]),
                global_scorer=GlobalScorerStub(),
                block_ngram_repeat=ngram_repeat)
    for i in range(ngram_repeat + 4):
        # non-interesting beams are going to get dummy values
        word_probs = torch.full((beam_sz, n_words), -float('inf'))
        if i == 0:
            word_probs[0, repeat_idx] = -0.1
            word_probs[0, repeat_idx + i + 1] = -2.3
            word_probs[0, repeat_idx_ignored] = -5.0
        else:
            # predict the same thing in beam 0
            word_probs[0, repeat_idx] = 0
            # continue pushing around what beam 1 predicts
            word_probs[1, repeat_idx + i + 1] = 0
            # predict the allowed-repeat again in beam 2
            word_probs[2, repeat_idx_ignored] = 0
        attns = torch.randn(beam_sz)
        beam.advance(word_probs, attns)
        if i <= ngram_repeat:
            self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
            self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
            self.assertFalse(beam.scores[2].eq(self.BLOCKED_SCORE))
        else:
            # now beam 0 dies, beam 1 -> beam 0, beam 2 -> beam 1,
            # and the rest die
            self.assertFalse(beam.scores[0].eq(self.BLOCKED_SCORE))
            # since all preds after i=0 are 0, we can check
            # that the beam is the correct idx by checking that
            # the curr score is the initial score
            self.assertTrue(beam.scores[0].eq(-2.3))
            self.assertFalse(beam.scores[1].eq(self.BLOCKED_SCORE))
            self.assertTrue(beam.scores[1].eq(-5.0))
            self.assertTrue(beam.scores[2:].equal(
                torch.tensor([self.BLOCKED_SCORE] * (beam_sz - 2))))

def test_beam_advance_against_known_reference(self):
    scorer = GNMTGlobalScorer(0.7, 0., "avg", "none")
    beam = Beam(self.BEAM_SZ, 0, 1, self.EOS_IDX, n_best=self.N_BEST,
                exclusion_tokens=set(), min_length=0,
                global_scorer=scorer,
                block_ngram_repeat=0)
    expected_beam_scores = self.init_step(beam)
    expected_beam_scores = self.first_step(beam, expected_beam_scores, 3)
    expected_beam_scores = self.second_step(beam, expected_beam_scores, 4)
    self.third_step(beam, expected_beam_scores, 5)

# Note: this test shares its name with the one above; in the original
# suite the two presumably live on different test case classes (this one
# using the no-op stub scorer, the other a GNMTGlobalScorer with a length
# penalty), otherwise the later definition would shadow the earlier one.
def test_beam_advance_against_known_reference(self):
    beam = Beam(self.BEAM_SZ, 0, 1, self.EOS_IDX, n_best=self.N_BEST,
                exclusion_tokens=set(), min_length=0,
                global_scorer=GlobalScorerStub(),
                block_ngram_repeat=0)
    expected_beam_scores = self.init_step(beam)
    expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
    expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
    self.third_step(beam, expected_beam_scores, 1)

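# init_step/first_step/second_step/third_step are helpers defined on the
# test class and are not shown in this excerpt. A minimal sketch of the
# pattern such a helper might follow (the vocabulary size, token indices,
# and score values here are hypothetical, chosen only so the ranking is
# deterministic): feed one hand-built log-probability matrix, advance the
# beam, and return the scores the next step checks against.


def init_step(self, beam):
    # on the very first step every beam holds BOS, so Beam.advance only
    # reads row 0 of the score matrix; fill all rows identically anyway
    n_words = 100  # hypothetical vocabulary size
    word_probs = torch.full((self.BEAM_SZ, n_words), -float('inf'))
    for rank in range(self.BEAM_SZ):
        # token ids 10, 11, ... get strictly decreasing scores, so the
        # beam's top-k choice on this step is fully determined
        word_probs[:, 10 + rank] = -float(rank)
    beam.advance(word_probs, torch.randn(self.BEAM_SZ))
    # the scores after this step are what the next helper asserts on
    return beam.scores.clone()
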
def translate_batch(self, batch):
    beam_size = self.beam_size
    tgt_field = self.fields['tgt'][0][1]
    vocab = tgt_field.vocab
    pad = vocab.stoi[tgt_field.pad_token]
    eos = vocab.stoi[tgt_field.eos_token]
    bos = vocab.stoi[tgt_field.init_token]
    b = Beam(beam_size, n_best=self.n_best, cuda=self.cuda,
             pad=pad, eos=eos, bos=bos)

    src, src_lengths = batch.src
    # why doesn't this contain inflection source lengths when ensembling?
    side_info = side_information(batch)
    encoder_out = self.model.encode(src, lengths=src_lengths, **side_info)
    enc_states = encoder_out["enc_state"]
    memory_bank = encoder_out["memory_bank"]
    infl_memory_bank = encoder_out.get("inflection_memory_bank", None)
    self.model.init_decoder_state(enc_states)

    results = dict()
    if "tgt" in batch.__dict__:
        results["gold_score"] = self._score_target(
            batch, memory_bank, src_lengths,
            inflection_memory_bank=infl_memory_bank,
            **side_info)
        self.model.init_decoder_state(enc_states)
    else:
        results["gold_score"] = 0

    # (2) Repeat src objects `beam_size` times.
    self.model.map_decoder_state(
        lambda state, dim: tile(state, beam_size, dim=dim))
    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
    memory_lengths = tile(src_lengths, beam_size)
    if infl_memory_bank is not None:
        if isinstance(infl_memory_bank, tuple):
            infl_memory_bank = tuple(
                tile(x, beam_size, dim=1) for x in infl_memory_bank)
        else:
            infl_memory_bank = tile(infl_memory_bank, beam_size, dim=1)
        tiled_infl_len = tile(side_info["inflection_lengths"], beam_size)
        side_info["inflection_lengths"] = tiled_infl_len
    if "language" in side_info:
        side_info["language"] = tile(side_info["language"], beam_size)

    for i in range(self.max_length):
        if b.done():
            break
        # the decoder expects an input of tgt_len x batch
        inp = b.current_state.unsqueeze(0)
        dec_out, dec_attn = self.model.decode(
            inp, memory_bank,
            memory_lengths=memory_lengths,
            inflection_memory_bank=infl_memory_bank,
            **side_info)
        attn = dec_attn["lemma"].squeeze(0)
        out = self.model.generator(dec_out.squeeze(0),
                                   transform=True, **side_info)
        # b.advance will take attn (beam size x src length)
        b.advance(out, dec_attn)
        select_indices = b.current_origin
        self.model.map_decoder_state(
            lambda state, dim: state.index_select(dim, select_indices))

    scores, ks = b.sort_finished()
    hyps, attn, out_probs = [], [], []
    for times, k in ks[:self.n_best]:
        hyp, att, out_p = b.get_hyp(times, k)
        hyps.append(hyp)
        attn.append(att)
        out_probs.append(out_p)

    results["preds"] = hyps
    results["scores"] = scores
    results["attn"] = attn

    if self.beam_accum is not None:
        parent_ids = [t.tolist() for t in b.prev_ks]
        self.beam_accum["beam_parent_ids"].append(parent_ids)
        scores = [["%4f" % s for s in t.tolist()]
                  for t in b.all_scores][1:]
        self.beam_accum["scores"].append(scores)
        pred_ids = [[vocab.itos[i] for i in t.tolist()]
                    for t in b.next_ys][1:]
        self.beam_accum["predicted_ids"].append(pred_ids)

    if self.attn_path is not None:
        save_attn = {k: v.cpu() for k, v in attn[0].items()}
        src_seq = self.itos(src, "src")
        pred_seq = self.itos(hyps[0], "tgt")
        attn_dict = {"src": src_seq, "pred": pred_seq, "attn": save_attn}
        if "inflection" in save_attn:
            inflection_seq = self.itos(batch.inflection[0], "inflection")
            attn_dict["inflection"] = inflection_seq
        self.attns.append(attn_dict)

    if self.probs_path is not None:
        save_probs = out_probs[0].cpu()
        self.probs.append(save_probs)

    return results

def _translate_batch(self, batch, data):
    # (0) Prep each of the components of the search.
    # And helper method for reducing verbosity.
    beam_size = self.beam_size
    batch_size = batch.batch_size
    data_type = data.data_type
    vocab = self.fields["tgt"].vocab

    # Define a list of tokens to exclude from ngram-blocking
    # exclusion_list = ["<t>", "</t>", "."]
    exclusion_tokens = {vocab.stoi[t] for t in self.ignore_when_blocking}

    beam_batch = [
        Beam(
            beam_size,
            n_best=self.n_best,
            cuda=self.cuda,
            global_scorer=self.global_scorer,
            pad=vocab.stoi[inputters.PAD_WORD],
            eos=vocab.stoi[inputters.EOS_WORD],
            bos=vocab.stoi[inputters.BOS_WORD],
            min_length=self.min_length,
            stepwise_penalty=self.stepwise_penalty,
            block_ngram_repeat=self.block_ngram_repeat,
            exclusion_tokens=exclusion_tokens,
        )
        for __ in range(batch_size)
    ]

    # (1) Run the encoder on the src.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(
        batch, data_type)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    results = {}
    results["predictions"] = []
    results["scores"] = []
    results["attention"] = []
    results["batch"] = batch
    if "tgt" in batch.__dict__:
        results["gold_score"] = self._score_target(
            batch, memory_bank, src_lengths, data,
            batch.src_map
            if data_type == "text" and self.copy_attn else None,
        )
        self.model.decoder.init_state(src, memory_bank, enc_states)
    else:
        results["gold_score"] = [0] * batch_size

    # (2) Repeat src objects `beam_size` times.
    # We use now batch_size x beam_size (same as fast mode)
    src_map = (tile(batch.src_map, beam_size, dim=1)
               if data.data_type == "text" and self.copy_attn else None)
    self.model.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))
    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
    memory_lengths = tile(src_lengths, beam_size)

    # (3) run the decoder to generate sentences, using beam search.
    for i in range(self.max_length):
        if all(b.done() for b in beam_batch):
            break

        # (a) Construct batch x beam_size next words.
        # Get all the pending current beam words and arrange for forward.
        inp = torch.stack([b.current_state for b in beam_batch])
        inp = inp.view(1, -1, 1)

        # (b) Decode and forward
        out, beam_attn = self._decode_and_generate(
            inp, memory_bank, batch, data,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=i,
        )
        out = out.view(batch_size, beam_size, -1)
        beam_attn = beam_attn.view(batch_size, beam_size, -1)

        # (c) Advance each beam.
        select_indices_array = []
        # Loop over the batch_size number of beams
        for j, b in enumerate(beam_batch):
            b.advance(out[j, :], beam_attn[j, :, :memory_lengths[j]])
            select_indices_array.append(b.current_origin + j * beam_size)
        select_indices = torch.cat(select_indices_array)
        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    # (4) Extract sentences from beam.
    for b in beam_batch:
        n_best = self.n_best
        scores, ks = b.sort_finished(minimum=n_best)
        hyps, attn = [], []
        for times, k in ks[:n_best]:
            hyp, att = b.get_hyp(times, k)
            hyps.append(hyp)
            attn.append(att)
        results["predictions"].append(hyps)
        results["scores"].append(scores)
        results["attention"].append(attn)
        if self.beam_accum is not None:
            parent_ids = [t.tolist() for t in b.prev_ks]
            self.beam_accum["beam_parent_ids"].append(parent_ids)
            scores = [["%4f" % s for s in t.tolist()]
                      for t in b.all_scores][1:]
            self.beam_accum["scores"].append(scores)
            pred_ids = [[vocab.itos[i] for i in t.tolist()]
                        for t in b.next_ys][1:]
            self.beam_accum["predicted_ids"].append(pred_ids)
        if self.beam_scores is not None:
            self.beam_scores.append(torch.stack(b.all_scores).cpu())

    return results

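# Both translate methods above lean on a `tile` helper to repeat encoder
# states, memory banks, and lengths once per beam. A minimal sketch of a
# compatible implementation, assuming (from the call sites here) that the
# intended semantics are "repeat each slice `count` times, keeping the
# copies of the same slice adjacent along `dim`":


def tile(x, count, dim=0):
    """Tile x along `dim`, turning size n into n * count, e.g.
    [a, b] -> [a, a, b, b] for count=2 (copies stay adjacent, so
    beam k of sentence j lives at index j * count + k)."""
    return torch.repeat_interleave(x, count, dim=dim)
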
def generate(self, src, lengths, dec_idx,
             max_length=20, beam_size=5, n_best=1):
    assert dec_idx == 0 or dec_idx == 1
    batch_size = src.size(1)

    def var(a):
        return torch.tensor(a, requires_grad=False)

    def rvar(a):
        return var(a.repeat(1, beam_size, 1))

    def bottle(m):
        return m.view(batch_size * beam_size, -1)

    def unbottle(m):
        return m.view(beam_size, batch_size, -1)

    def from_beam(beam):
        ret = {"predictions": [], "scores": [], "attention": []}
        for b in beam:
            scores, ks = b.sort_finished(minimum=n_best)
            hyps, attn = [], []
            for times, k in ks[:n_best]:
                hyp, att = b.get_hyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            ret["predictions"].append(hyps)
            ret["scores"].append(scores)
            ret["attention"].append(attn)
        return ret

    scorer = GNMTGlobalScorer(0, 0, "none", "none")
    # note: the Beam needs a boolean device flag, and bos/eos must map to
    # the start and end symbols respectively
    beam = [Beam(beam_size, n_best=n_best,
                 cuda=next(self.parameters()).is_cuda,
                 global_scorer=scorer,
                 pad=PAD_IDX, eos=EOS_IDX, bos=SOS_IDX,
                 min_length=0, stepwise_penalty=False,
                 block_ngram_repeat=0)
            for __ in range(batch_size)]

    enc_final, memory_bank = self.encoder(src, lengths)
    dec_state = self.choose_decoder(dec_idx).init_decoder_state(
        src, memory_bank, enc_final)
    memory_bank = rvar(memory_bank.data)
    memory_lengths = lengths.repeat(beam_size)
    dec_state.repeat_beam_size_times(beam_size)

    # unroll
    for i in range(max_length):
        if all(b.done() for b in beam):
            break
        # gather each beam's frontier token: (1, batch * beam_size, 1)
        inp = var(torch.stack([b.get_current_state() for b in beam])
                  .t().contiguous().view(1, -1))
        inp = inp.unsqueeze(2)
        decoder_output, dec_state, attn = self.choose_decoder(dec_idx)(
            inp, memory_bank, dec_state,
            memory_lengths=memory_lengths, step=i)
        decoder_output = decoder_output.squeeze(0)
        out = self.generator(decoder_output).data
        out = unbottle(out)  # beam x batch x tgt_vocab
        beam_attn = unbottle(attn["std"])
        for j, b in enumerate(beam):
            b.advance(out[:, j],
                      beam_attn.data[:, j, :memory_lengths[j]])
            dec_state.beam_update(j, b.get_current_origin(), beam_size)

    ret = from_beam(beam)
    # ret["src"] = src.transpose(1, 0)
    return ret
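
# A hypothetical call for illustration (the `model` variable, shapes, and
# index constants are assumptions, not part of the source): src is a
# (src_len, batch, 1) LongTensor and lengths a (batch,) LongTensor, as
# the encoder call above expects.
#
#     out = model.generate(src, lengths, dec_idx=0,
#                          max_length=20, beam_size=5, n_best=1)
#     best = out["predictions"][0][0]   # best hypothesis, first sentence
#     score = out["scores"][0][0]       # its cumulative log-probability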