def initialize(self, src, src_lengths, src_map=None, device=None,
               target_prefix=None):
    """Initialize for decoding.

    Repeat src objects `beam_size` times.
    """

    def fn_map_state(state, dim):
        return tile(state, self.beam_size, dim=dim)

    src = fn_map_state(src, dim=1)
    if src_map is not None:
        src_map = tile(src_map, self.beam_size, dim=1)
    if device is None:
        device = src.device

    self.memory_lengths = tile(src_lengths, self.beam_size)
    if target_prefix is not None:
        target_prefix = tile(target_prefix, self.beam_size, dim=1)

    super(BeamSearchLM, self).initialize_(
        None, self.memory_lengths, src_map=src_map, device=device,
        target_prefix=target_prefix)
    return fn_map_state, src, self.memory_lengths, src_map
def initialize(self, memory_bank, src_lengths, src_map=None, device=None):
    """Initialize for decoding."""
    # fn_map_state = None

    def fn_map_state(state, dim):
        return tile(state, self.sample_size, dim=dim)

    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, self.sample_size, dim=1)
                            for x in memory_bank)
        mb_device = memory_bank[0].device
    else:
        memory_bank = tile(memory_bank, self.sample_size, dim=1)
        mb_device = memory_bank.device
    if src_map is not None:
        src_map = tile(src_map, self.sample_size, dim=1)
    if device is None:
        device = mb_device

    self.memory_lengths = tile(src_lengths, self.sample_size)
    super(PriorSampling, self).initialize(
        memory_bank, self.memory_lengths, src_map, device)
    self.select_indices = torch.arange(
        self.batch_size * self.sample_size, dtype=torch.long, device=device)
    self.original_batch_idx = tile(
        torch.arange(self.batch_size, dtype=torch.long, device=device),
        self.sample_size)
    return fn_map_state, memory_bank, self.memory_lengths, src_map
def initialize(self, memory_bank, src_lengths, src_map=None, device=None,
               target_prefix=None):
    """Initialize for decoding.

    Repeat src objects `beam_size` times.
    """

    def fn_map_state(state, dim):
        return tile(state, self.beam_size, dim=dim)

    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, self.beam_size, dim=1)
                            for x in memory_bank)
        mb_device = memory_bank[0].device
    else:
        memory_bank = tile(memory_bank, self.beam_size, dim=1)
        mb_device = memory_bank.device
    if src_map is not None:
        src_map = tile(src_map, self.beam_size, dim=1)
    if device is None:
        device = mb_device

    self.memory_lengths = tile(src_lengths, self.beam_size)
    if target_prefix is not None:
        target_prefix = tile(target_prefix, self.beam_size, dim=1)
    super(BeamSearch, self).initialize_(
        memory_bank, self.memory_lengths, src_map, device, target_prefix)
    return fn_map_state, memory_bank, self.memory_lengths, src_map
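# Every snippet in this section leans on OpenNMT-py's `tile(x, count, dim=...)`
# helper, which repeats each batch entry `count` times along `dim` so that the
# hypotheses belonging to one example stay grouped together. A minimal
# stand-in with the same observable behaviour (a sketch, not the library's
# actual implementation) is:

import torch


def tile(x, count, dim=0):
    """Repeat each slice of ``x`` ``count`` times along ``dim``:
    [a, b] -> [a, a, b, b] for count=2, keeping beams grouped per example."""
    return x.repeat_interleave(count, dim=dim)


# e.g. tiling a (src_len, batch, hidden) memory bank for a beam of 5:
#   memory_bank = tile(memory_bank, 5, dim=1)  # -> (src_len, batch * 5, hidden)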
def forward_dev_beam_search(self, encoder_output: torch.Tensor, pad_mask):
    batch_size = encoder_output.size(1)
    self.state["cache"] = None
    memory_lengths = pad_mask.ne(pad_token_index).sum(dim=0)

    self.map_state(lambda state, dim: tile(state, self.beam_size, dim=dim))
    encoder_output = tile(encoder_output, self.beam_size, dim=1)
    pad_mask = tile(pad_mask, self.beam_size, dim=1)
    memory_lengths = tile(memory_lengths, self.beam_size, dim=0)

    # TODO:
    # - fix attn (?)
    # - use coverage_penalty="summary" or "wu" and beta=0.2 (or not)
    # - use length_penalty="wu" and alpha=0.2 (or not)
    beam = BeamSearch(beam_size=self.beam_size, n_best=1,
                      batch_size=batch_size, mb_device=default_device,
                      global_scorer=GNMTGlobalScorer(alpha=0, beta=0,
                                                     coverage_penalty="none",
                                                     length_penalty="avg"),
                      pad=pad_token_index, eos=eos_token_index,
                      bos=bos_token_index, min_length=1, max_length=100,
                      return_attention=False, stepwise_penalty=False,
                      block_ngram_repeat=0, exclusion_tokens=set(),
                      memory_lengths=memory_lengths, ratio=-1)

    for i in range(self.max_seq_out_len):
        inp = beam.current_predictions.view(1, -1)
        out, attn = self.forward_step(src=pad_mask, tgt=inp,
                                      memory_bank=encoder_output,
                                      step=i)   # 1 x batch*beam x hidden
        out = self.linear(out)                  # 1 x batch*beam x vocab_out
        out = log_softmax(out, dim=2)           # 1 x batch*beam x vocab_out
        out = out.squeeze(0)                    # batch*beam x vocab_out
        # attn = attn.squeeze(0)                # batch*beam x vocab_out
        # out = out.view(batch_size, self.beam_size, -1)   # batch x beam x vocab_out
        # attn = attn.view(batch_size, self.beam_size, -1)

        # TODO: fix attn (?)
        beam.advance(out, attn)
        any_beam_is_finished = beam.is_finished.any()
        if any_beam_is_finished:
            beam.update_finished()
            if beam.done:
                break

        select_indices = beam.current_origin
        if any_beam_is_finished:
            # Reorder states.
            encoder_output = encoder_output.index_select(1, select_indices)
            pad_mask = pad_mask.index_select(1, select_indices)
            memory_lengths = memory_lengths.index_select(0, select_indices)
        self.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    outputs = beam.predictions
    outputs = [x[0] for x in outputs]
    outputs = pad_sequence(outputs, batch_first=True)
    return [outputs]
def initialize(self, memory_bank, src_lengths, src_map=None, device=None,
               target_prefix=None):
    """Initialize for decoding.

    Repeat src objects `beam_size` times.
    """

    def fn_map_state(state, dim):
        return tile(state, self.beam_size, dim=dim)

    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, self.beam_size, dim=1)
                            for x in memory_bank)
        mb_device = memory_bank[0].device
    else:
        memory_bank = tile(memory_bank, self.beam_size, dim=1)
        mb_device = memory_bank.device
    if src_map is not None:
        src_map = tile(src_map, self.beam_size, dim=1)
    if device is None:
        device = mb_device

    self.memory_lengths = tile(src_lengths, self.beam_size)
    if target_prefix is not None:
        target_prefix = tile(target_prefix, self.beam_size, dim=1)
    super(BeamSearch, self).initialize(
        memory_bank, self.memory_lengths, src_map, device, target_prefix)
    self.best_scores = torch.full(
        [self.batch_size], -1e10, dtype=torch.float, device=device)
    self._beam_offset = torch.arange(
        0, self.batch_size * self.beam_size, step=self.beam_size,
        dtype=torch.long, device=device)
    self.topk_log_probs = torch.tensor(
        [0.0] + [float("-inf")] * (self.beam_size - 1),
        device=device).repeat(self.batch_size)
    # buffers for the topk scores and 'backpointer'
    self.topk_scores = torch.empty((self.batch_size, self.beam_size),
                                   dtype=torch.float, device=device)
    self.topk_ids = torch.empty((self.batch_size, self.beam_size),
                                dtype=torch.long, device=device)
    self._batch_index = torch.empty([self.batch_size, self.beam_size],
                                    dtype=torch.long, device=device)
    return fn_map_state, memory_bank, self.memory_lengths, src_map
def _align_forward(self, batch, predictions):
    """
    For a batch of inputs and their predictions, return the predicted
    src-index alignments as a list of Tensors of size ``(batch, n_best,)``.
    """
    # (0) add BOS and padding to tgt prediction
    if hasattr(batch, 'tgt'):
        batch_tgt_idxs = batch.tgt.transpose(1, 2).transpose(0, 2)
    else:
        batch_tgt_idxs = self._align_pad_prediction(
            predictions, bos=self._tgt_bos_idx, pad=self._tgt_pad_idx)
    tgt_mask = (batch_tgt_idxs.eq(self._tgt_pad_idx) |
                batch_tgt_idxs.eq(self._tgt_eos_idx) |
                batch_tgt_idxs.eq(self._tgt_bos_idx))

    n_best = batch_tgt_idxs.size(1)
    # (1) Encoder forward.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)

    # (2) Repeat src objects `n_best` times.
    # We use batch_size x n_best, get ``(src_len, batch * n_best, nfeat)``
    src = tile(src, n_best, dim=1)
    enc_states = tile(enc_states, n_best, dim=1)
    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, n_best, dim=1) for x in memory_bank)
    else:
        memory_bank = tile(memory_bank, n_best, dim=1)
    src_lengths = tile(src_lengths, n_best)  # ``(batch * n_best,)``

    # (3) Init decoder with n_best src,
    self.model.decoder.init_state(src, memory_bank, enc_states)
    self.turbo_decoder.init_state(src, memory_bank, enc_states)
    # reshape tgt to ``(len, batch * n_best, nfeat)``
    tgt = batch_tgt_idxs.view(-1, batch_tgt_idxs.size(-1)).T.unsqueeze(-1)
    dec_in = tgt[:-1]  # exclude last target from inputs
    _, attns = self.model.decoder(
        dec_in, memory_bank, memory_lengths=src_lengths, with_align=True)

    alignment_attn = attns["align"]  # ``(B, tgt_len-1, src_len)``
    # masked_select
    align_tgt_mask = tgt_mask.view(-1, tgt_mask.size(-1))
    prediction_mask = align_tgt_mask[:, 1:]  # exclude bos to match pred
    # get aligned src id for each prediction's valid tgt tokens
    alignment = extract_alignment(
        alignment_attn, prediction_mask, src_lengths, n_best)
    return alignment
def forward(self, batch_size, beam_size, max_seq_len, memory,
            memory_seq_lens):
    extended_memory = tile(memory, beam_size)
    extended_memory_seq_lens = tile(memory_seq_lens, beam_size)

    output_ids, parent_ids, out_seq_lens = self.decoding.forward(
        batch_size, beam_size, max_seq_len, extended_memory,
        extended_memory_seq_lens)
    parent_ids = parent_ids % beam_size

    beams, lengths = finalize(beam_size, output_ids, parent_ids,
                              out_seq_lens, self.end_id, max_seq_len,
                              args=self.args)
    return beams, lengths
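# The `finalize` call above (FasterTransformer-style decoding) is not shown in
# this section; it walks the beam backpointers to rebuild full hypotheses. A
# rough, illustrative sketch of that backpointer walk -- assuming `output_ids`
# and `parent_ids` are shaped (steps, batch * beam) and `parent_ids` already
# holds per-batch beam indices (as after the `% beam_size` above) -- could be:

import torch


def finalize_sketch(beam_size, output_ids, parent_ids, out_seq_lens):
    steps, batchxbeam = output_ids.shape
    batch_size = batchxbeam // beam_size
    beams = torch.empty_like(output_ids)
    # Start from the last step and follow parent pointers backwards.
    beam_idx = torch.arange(beam_size).repeat(batch_size)
    batch_base = torch.arange(batch_size).repeat_interleave(beam_size) * beam_size
    for t in reversed(range(steps)):
        flat = (batch_base + beam_idx).long()
        beams[t] = output_ids[t].index_select(0, flat)
        beam_idx = parent_ids[t].index_select(0, flat)
    # (steps, batch, beam) token grid plus per-hypothesis lengths.
    return (beams.view(steps, batch_size, beam_size),
            out_seq_lens.view(batch_size, beam_size))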
def initialize_tile(self, memory_bank, src_lengths, src_map=None,
                    target_prefix=None):
    def fn_map_state(state, dim):
        return tile(state, self.beam_size, dim=dim)

    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, self.beam_size, dim=1)
                            for x in memory_bank)
    elif memory_bank is not None:
        memory_bank = tile(memory_bank, self.beam_size, dim=1)

    if src_map is not None:
        src_map = tile(src_map, self.beam_size, dim=1)

    self.memory_lengths = tile(src_lengths, self.beam_size)
    if target_prefix is not None:
        target_prefix = tile(target_prefix, self.beam_size, dim=1)

    return fn_map_state, memory_bank, src_map, target_prefix
def forward(self, batch_size, beam_size, max_seq_len, memory, memory_seq_lens): extended_memory = tile(memory, beam_size) batchxbeam = extended_memory.size(0) extended_memory = extended_memory.transpose(0, 1).contiguous() extended_memory_seq_lens = tile(memory_seq_lens, beam_size) start_ids = extended_memory_seq_lens.new_full((batchxbeam, ), self.start_id, dtype=torch.int64) initial_log_probs = extended_memory.new_full((beam_size, ), -float("inf"), dtype=torch.float32) initial_log_probs[0] = 0. initial_log_probs = initial_log_probs.repeat(batch_size) sequence_lengths = extended_memory_seq_lens.new_full((batchxbeam, ), 0) finished = extended_memory_seq_lens.new_full((batchxbeam, ), 0, dtype=torch.bool) dtype_info = torch.finfo(extended_memory.dtype) eos_max_prob = extended_memory.new_full((batchxbeam, self.vocab_size), dtype_info.min) eos_max_prob[:, self.end_id] = dtype_info.max self.decoder.init_state(extended_memory, extended_memory, None) word_ids = start_ids cum_log_probs = initial_log_probs for step in range(max_seq_len): if not torch.bitwise_not(finished).any(): break word_ids = word_ids.view(1, -1, 1) dec_out, dec_attn = self.decoder( word_ids, extended_memory, memory_lengths=extended_memory_seq_lens, step=step) logits = self.generator(dec_out.squeeze(0)) logits = torch.where(finished.view(-1, 1), eos_max_prob, logits).to(torch.float32) log_probs = self.logsoftmax(logits.to(torch.float32)) total_probs = log_probs + torch.unsqueeze(cum_log_probs, 1) total_probs = total_probs.view(-1, beam_size * self.vocab_size) # beamsearch # _, sample_ids = torch.topk(total_probs, beam_size) # sample_ids = sample_ids.view(-1) #diversesiblingsearch sibling_score = torch.arange(1, beam_size + 1).to( total_probs.dtype).to(extended_memory.device ) * self.diversity_rate # [beam_size] scores, ids = torch.topk( total_probs.view(-1, beam_size, self.vocab_size), beam_size) # [batch size, beam width, beam width] scores = scores + sibling_score # [batch size, beam width, beam width] scores = scores.view(-1, beam_size * beam_size) ids = ids + torch.unsqueeze( torch.unsqueeze( torch.arange(0, beam_size).to(extended_memory.device) * self.vocab_size, 0), -1) ids = ids.view(-1, beam_size * beam_size) _, final_ids = torch.topk(scores, beam_size) # [batch size, beam size] final_ids = final_ids.view(-1, 1) batch_index = torch.arange(0, batch_size).to( extended_memory.device).view(-1, 1).repeat(1, beam_size).view(-1, 1) index = torch.cat([batch_index, final_ids], 1) sample_ids = gather_nd(ids, index) word_ids = sample_ids % self.vocab_size # [batch_size * beam_size] beam_ids = sample_ids // self.vocab_size # [batch_size * beam_size] beam_indices = (torch.arange(batchxbeam).to(extended_memory.device) // beam_size) * beam_size + beam_ids sequence_lengths = torch.where(finished, sequence_lengths, sequence_lengths + 1) batch_pos = torch.arange(batchxbeam).to( extended_memory.device) // beam_size next_cum_log_probs = gather_nd( total_probs, torch.stack([batch_pos, sample_ids], -1)) # [batch_size * beam_size] finished = finished.index_select(0, beam_indices) sequence_lengths = sequence_lengths.index_select(0, beam_indices) self.decoder.map_state( lambda state, dim: state.index_select(dim, beam_indices)) if step == 0: parent_ids = beam_ids.view(1, -1) output_ids = word_ids.view(1, -1) else: parent_ids = torch.cat((parent_ids, beam_ids.view(1, -1))) output_ids = torch.cat((output_ids, word_ids.view(1, -1))) cum_log_probs = torch.where(finished, cum_log_probs, next_cum_log_probs) finished = torch.bitwise_or(finished, 
torch.eq(word_ids, self.end_id)) beams, lengths = finalize(beam_size, output_ids, parent_ids, sequence_lengths, self.end_id, args=self.args) return beams, lengths
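# The `gather_nd` helper used in the diverse sibling search above is not a
# core PyTorch op; a minimal sketch for the 2-D case used here (each index row
# picks one element of a 2-D tensor, TensorFlow-style) might be:

import torch


def gather_nd(params, indices):
    """indices is (N, 2); row i selects params[indices[i, 0], indices[i, 1]]."""
    return params[indices[:, 0], indices[:, 1]]


# Example: picking one score per (batch, beam) slot.
scores = torch.arange(12.).view(3, 4)
idx = torch.tensor([[0, 1], [1, 3], [2, 0]])
print(gather_nd(scores, idx))  # tensor([1., 7., 8.])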
def _translate_batch(self, batch, data): # (0) Prep each of the components of the search. # And helper method for reducing verbosity. beam_size = self.beam_size batch_size = batch.batch_size tgt_field = self.fields['tgt'][0][1].base_field vocab = tgt_field.vocab # Define a set of tokens to exclude from ngram-blocking exclusion_tokens = {vocab.stoi[t] for t in self.ignore_when_blocking} pad = vocab.stoi[tgt_field.pad_token] eos = vocab.stoi[tgt_field.eos_token] bos = vocab.stoi[tgt_field.init_token] beam = [ onmt.translate.Beam(beam_size, n_best=self.n_best, cuda=self.cuda, global_scorer=self.global_scorer, pad=pad, eos=eos, bos=bos, min_length=self.min_length, stepwise_penalty=self.stepwise_penalty, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=exclusion_tokens) for __ in range(batch_size) ] # (1) Run the encoder on the src. src, enc_states_ques, memory_bank_ques, ques_lengths, ans, enc_states_ans, memory_bank_ans, ans_lengths = self._run_encoder( batch) #################33 Modified ####################### enc_states_final = tuple( torch.add(enc_q, enc_ans) for enc_q, enc_ans in zip(enc_states_ques, enc_states_ans)) memory_bank_final = torch.cat([memory_bank_ques, memory_bank_ans], 0) memory_lengths_final = torch.add(ques_lengths, ans_lengths) self.model.decoder.init_state(src, memory_bank_final, enc_states_final) results = {} results["predictions"] = [] results["scores"] = [] results["attention"] = [] results["batch"] = batch if "tgt" in batch.__dict__: results["gold_score"] = self._score_target( batch, memory_bank_final, memory_lengths_final, data, batch.src_map if self.copy_attn else None) self.model.decoder.init_state(src, memory_bank_final, enc_states_final) else: results["gold_score"] = [0] * batch_size # (2) Repeat src objects `beam_size` times. # We use now batch_size x beam_size (same as fast mode) src_map = (tile(batch.src_map, beam_size, dim=1) if self.copy_attn else None) self.model.decoder.map_state( lambda state, dim: tile(state, beam_size, dim=dim)) if isinstance(memory_bank_final, tuple): memory_bank_final = tuple( tile(x, beam_size, dim=1) for x in memory_bank_final) else: memory_bank_final = tile(memory_bank_final, beam_size, dim=1) memory_lengths_final = tile(memory_lengths_final, beam_size) # (3) run the decoder to generate sentences, using beam search. for i in range(self.max_length): if all((b.done() for b in beam)): break # (a) Construct batch x beam_size nxt words. # Get all the pending current beam words and arrange for forward. inp = torch.stack([b.get_current_state() for b in beam]) inp = inp.view(1, -1, 1) # (b) Decode and forward out, beam_attn = self._decode_and_generate( inp, memory_bank_final, batch, data, memory_lengths=memory_lengths_final, src_map=src_map, step=i) out = out.view(batch_size, beam_size, -1) beam_attn = beam_attn.view(batch_size, beam_size, -1) # (c) Advance each beam. select_indices_array = [] # Loop over the batch_size number of beam for j, b in enumerate(beam): b.advance(out[j, :], beam_attn.data[j, :, :memory_lengths_final[j]]) select_indices_array.append(b.get_current_origin() + j * beam_size) select_indices = torch.cat(select_indices_array) self.model.decoder.map_state( lambda state, dim: state.index_select(dim, select_indices)) # (4) Extract sentences from beam. 
for b in beam: scores, ks = b.sort_finished(minimum=self.n_best) hyps, attn = [], [] for times, k in ks[:self.n_best]: hyp, att = b.get_hyp(times, k) hyps.append(hyp) attn.append(att) results["predictions"].append(hyps) results["scores"].append(scores) results["attention"].append(attn) return results
def _translate_batch( self, batch, src_vocabs, max_length, min_length=0, ratio=0., n_best=1, return_attention=False): # TODO: support these blacklisted features. assert not self.dump_beam # (0) Prep the components of the search. use_src_map = self.copy_attn beam_size = self.beam_size batch_size = batch.batch_size # (1) Run the encoder on the src. src, enc_states, memory_bank, src_lengths = self._run_encoder(batch) self.model.decoder.init_state(src, memory_bank, enc_states) results = { "predictions": None, "scores": None, "attention": None, "batch": batch, "gold_score": self._gold_score( batch, memory_bank, src_lengths, src_vocabs, use_src_map, enc_states, batch_size, src)} # (2) Repeat src objects `beam_size` times. # We use batch_size x beam_size src_map = (tile(batch.src_map, beam_size, dim=1) if use_src_map else None) self.model.decoder.map_state( lambda state, dim: tile(state, beam_size, dim=dim)) if isinstance(memory_bank, tuple): memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank) mb_device = memory_bank[0].device else: memory_bank = tile(memory_bank, beam_size, dim=1) mb_device = memory_bank.device memory_lengths = tile(src_lengths, beam_size) # (0) pt 2, prep the beam object beam = BeamSearch( beam_size, n_best=n_best, batch_size=batch_size, global_scorer=self.global_scorer, pad=self._tgt_pad_idx, eos=self._tgt_eos_idx, bos=self._tgt_bos_idx, min_length=min_length, ratio=ratio, max_length=max_length, mb_device=mb_device, return_attention=return_attention, stepwise_penalty=self.stepwise_penalty, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, memory_lengths=memory_lengths) for step in range(max_length): decoder_input = beam.current_predictions.view(1, -1, 1) log_probs, attn = self._decode_and_generate( decoder_input, memory_bank, batch, src_vocabs, memory_lengths=memory_lengths, src_map=src_map, step=step, batch_offset=beam._batch_offset) beam.advance(log_probs, attn) any_beam_is_finished = beam.is_finished.any() if any_beam_is_finished: beam.update_finished() if beam.done: break select_indices = beam.current_origin if any_beam_is_finished: # Reorder states. if isinstance(memory_bank, tuple): memory_bank = tuple(x.index_select(1, select_indices) for x in memory_bank) else: memory_bank = memory_bank.index_select(1, select_indices) memory_lengths = memory_lengths.index_select(0, select_indices) if src_map is not None: src_map = src_map.index_select(1, select_indices) self.model.decoder.map_state( lambda state, dim: state.index_select(dim, select_indices)) results["scores"] = beam.scores results["predictions"] = beam.predictions results["attention"] = beam.attention return results
def _fast_translate_batch(self, batch, data, max_length, min_length=0, n_best=1, return_attention=False): # TODO: faster code path for beam_size == 1. # TODO: support these blacklisted features. assert data.data_type == 'text' assert not self.copy_attn assert not self.dump_beam assert not self.use_filter_pred assert self.block_ngram_repeat == 0 assert self.global_scorer.beta == 0 beam_size = self.beam_size batch_size = batch.batch_size vocab = self.fields["tgt"].vocab start_token = vocab.stoi[inputters.BOS_WORD] end_token = vocab.stoi[inputters.EOS_WORD] # Encoder forward. src = inputters.make_features(batch, 'src', data.data_type) _, src_lengths = batch.src enc_states, memory_bank, src_lengths \ = self.model.encoder(src, src_lengths) dec_states = self.model.decoder.init_decoder_state( src, memory_bank, enc_states, with_cache=True) # Tile states and memory beam_size times. dec_states.map_batch_fn( lambda state, dim: tile(state, beam_size, dim=dim)) if type(memory_bank) == tuple: device = memory_bank[0].device memory_bank = tuple(tile(m, beam_size, dim=1) for m in memory_bank) else: memory_bank = tile(memory_bank, beam_size, dim=1) device = memory_bank.device memory_lengths = tile(src_lengths, beam_size) top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8) batch_offset = torch.arange(batch_size, dtype=torch.long) beam_offset = torch.arange( 0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device) alive_seq = torch.full( [batch_size * beam_size, 1], start_token, dtype=torch.long, device=device) alive_attn = None # Give full probability to the first beam on the first step. topk_log_probs = ( torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=device).repeat(batch_size)) # Structure that holds finished hypotheses. hypotheses = [[] for _ in range(batch_size)] # noqa: F812 results = {} results["predictions"] = [[] for _ in range(batch_size)] # noqa: F812 results["scores"] = [[] for _ in range(batch_size)] # noqa: F812 results["attention"] = [[] for _ in range(batch_size)] # noqa: F812 results["gold_score"] = [0] * batch_size results["batch"] = batch if self.mask is not None: mask = self.mask.get_log_probs_masking_tensor(src.squeeze(2), beam_size).to(memory_bank.device) for step in range(max_length): decoder_input = alive_seq[:, -1].view(1, -1, 1) # Decoder forward. dec_out, dec_states, attn = self.model.decoder( decoder_input, memory_bank, dec_states, memory_lengths=memory_lengths, step=step) # Generator forward. log_probs = self.model.generator.forward(dec_out.squeeze(0)) vocab_size = log_probs.size(-1) if step < min_length: log_probs[:, end_token] = -1e20 if self.mask is not None: log_probs = log_probs * mask # Multiply probs by the beam probability. log_probs += topk_log_probs.view(-1).unsqueeze(1) alpha = self.global_scorer.alpha length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha # Flatten probs into a list of possibilities. curr_scores = log_probs / length_penalty curr_scores = curr_scores.reshape(-1, beam_size * vocab_size) topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1) # Recover log probs. topk_log_probs = topk_scores * length_penalty # Resolve beam origin and true word ids. topk_beam_index = topk_ids.div(vocab_size) topk_ids = topk_ids.fmod(vocab_size) # Map beam_index to batch_index in the flat representation. batch_index = ( topk_beam_index + beam_offset[:topk_beam_index.size(0)].unsqueeze(1)) select_indices = batch_index.view(-1) # Append last prediction. 
alive_seq = torch.cat( [alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1) if return_attention: current_attn = attn["std"].index_select(1, select_indices) if alive_attn is None: alive_attn = current_attn else: alive_attn = alive_attn.index_select(1, select_indices) alive_attn = torch.cat([alive_attn, current_attn], 0) is_finished = topk_ids.eq(end_token) if step + 1 == max_length: is_finished.fill_(1) # Save finished hypotheses. if is_finished.any(): # Penalize beams that finished. topk_log_probs.masked_fill_(is_finished, -1e10) is_finished = is_finished.to('cpu') top_beam_finished |= is_finished[:, 0].eq(1) predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1)) attention = ( alive_attn.view( alive_attn.size(0), -1, beam_size, alive_attn.size(-1)) if alive_attn is not None else None) non_finished_batch = [] for i in range(is_finished.size(0)): b = batch_offset[i] finished_hyp = is_finished[i].nonzero().view(-1) # Store finished hypotheses for this batch. for j in finished_hyp: # if (predictions[i, j, 1:] == end_token).sum() <= 1: hypotheses[b].append(( topk_scores[i, j], predictions[i, j, 1:], # Ignore start_token. attention[:, i, j, :memory_lengths[i]] if attention is not None else None)) # End condition is the top beam finished and we can return # n_best hypotheses. if top_beam_finished[i] and len(hypotheses[b]) >= n_best: best_hyp = sorted( hypotheses[b], key=lambda x: x[0], reverse=True) for n, (score, pred, attn) in enumerate(best_hyp): if n >= n_best: break results["scores"][b].append(score) results["predictions"][b].append(pred) results["attention"][b].append( attn if attn is not None else []) else: non_finished_batch.append(i) non_finished = torch.tensor(non_finished_batch) # If all sentences are translated, no need to go further. if len(non_finished) == 0: break # Remove finished batches for the next step. top_beam_finished = top_beam_finished.index_select( 0, non_finished) batch_offset = batch_offset.index_select(0, non_finished) non_finished = non_finished.to(topk_ids.device) topk_log_probs = topk_log_probs.index_select(0, non_finished) batch_index = batch_index.index_select(0, non_finished) select_indices = batch_index.view(-1) alive_seq = predictions.index_select(0, non_finished) \ .view(-1, alive_seq.size(-1)) if alive_attn is not None: alive_attn = attention.index_select(1, non_finished) \ .view(alive_attn.size(0), -1, alive_attn.size(-1)) # Reorder states. if type(memory_bank) == tuple: memory_bank = tuple(m.index_select(1, select_indices) for m in memory_bank) else: memory_bank = memory_bank.index_select(1, select_indices) memory_lengths = memory_lengths.index_select(0, select_indices) dec_states.map_batch_fn( lambda state, dim: state.index_select(dim, select_indices)) if self.mask is not None: mask = mask.index_select(0, select_indices) return results
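# The `((5.0 + (step + 1)) / 6.0) ** alpha` factor in the snippet above is the
# Wu et al. (GNMT) length penalty, lp(Y) = ((5 + |Y|) / 6) ** alpha: candidate
# scores are divided by it before the top-k and multiplied back afterwards to
# recover raw log-probabilities. A small standalone check (values illustrative):

def gnmt_length_penalty(length, alpha):
    return ((5.0 + length) / 6.0) ** alpha


cum_log_prob = -4.2                            # cumulative hypothesis log-prob
lp = gnmt_length_penalty(length=8, alpha=0.6)
score = cum_log_prob / lp                      # what topk actually ranks
assert abs(score * lp - cum_log_prob) < 1e-12  # the inverse step after topk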
def fn_map_state(state, dim):
    return tile(state, self.beam_size, dim=dim)
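# This closure is what the `initialize` methods above return as
# `fn_map_state`; callers then hand it to the decoder's `map_state` so every
# cached state gets tiled the same way as the memory bank. A toy illustration
# (the decoder class below is a stand-in, not OpenNMT's; it reuses the `tile`
# sketch given earlier):

import torch


class TinyDecoder:
    def __init__(self, batch_size, hidden):
        self.cache = torch.zeros(batch_size, hidden)

    def map_state(self, fn):
        # OpenNMT-style decoders call fn(state, dim) on every cached tensor,
        # where dim is that state's batch dimension.
        self.cache = fn(self.cache, 0)


beam_size = 5
decoder = TinyDecoder(batch_size=2, hidden=4)
decoder.map_state(lambda state, dim: tile(state, beam_size, dim=dim))
print(decoder.cache.shape)  # torch.Size([10, 4])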
def _translate_batch( self, batch, src_vocabs, max_length, min_length=0, ratio=0., n_best=1, return_attention=False, xlation_builder=None, ): # TODO: support these blacklisted features. assert not self.dump_beam # (0) Prep the components of the search. use_src_map = self.copy_attn beam_size = self.beam_size batch_size = batch.batch_size # (1) Run the encoder on the src. src_list, enc_states_list, memory_bank_list, src_lengths_list = self._run_encoder(batch) self.model.decoder.init_state(src_list, memory_bank_list, enc_states_list) results = { "predictions": None, "scores": None, "attention": None, "batch": batch, "gold_score": self._gold_score( batch, memory_bank_list, src_lengths_list, src_vocabs, use_src_map, enc_states_list, batch_size, src_list)} # (2) Repeat src objects `beam_size` times. # We use batch_size x beam_size src_map_list = list() for src_type in self.src_types: src_map_list.append((tile(getattr(batch, f"src_map.{src_type}"), beam_size, dim=1) if use_src_map else None)) # end for self.model.decoder.map_state( lambda state, dim: tile(state, beam_size, dim=dim)) memory_lengths_list = list() memory_lengths = list() for src_i in range(len(memory_bank_list)): if isinstance(memory_bank_list[src_i], tuple): memory_bank_list[src_i] = tuple(tile(x, beam_size, dim=1) for x in memory_bank_list[src_i]) mb_device = memory_bank_list[src_i][0].device else: memory_bank_list[src_i] = tile(memory_bank_list[src_i], beam_size, dim=1) mb_device = memory_bank_list[src_i].device # end if memory_lengths_list.append(tile(src_lengths_list[src_i], beam_size)) memory_lengths.append(src_lengths_list[src_i]) # end for memory_lengths = tile(torch.stack(memory_lengths, dim=0).sum(dim=0), beam_size) indexes = tile(torch.tensor(list(range(batch_size)), device=self._dev), beam_size) # (0) pt 2, prep the beam object beam = BeamSearch( beam_size, n_best=n_best, batch_size=batch_size, global_scorer=self.global_scorer, pad=self._tgt_pad_idx, eos=self._tgt_eos_idx, bos=self._tgt_bos_idx, min_length=min_length, ratio=ratio, max_length=max_length, mb_device=mb_device, return_attention=return_attention, stepwise_penalty=self.stepwise_penalty, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, memory_lengths=memory_lengths) for step in range(max_length): decoder_input = beam.current_predictions.view(1, -1, 1) log_probs, attn = self._decode_and_generate( decoder_input, memory_bank_list, batch, src_vocabs, memory_lengths_list=memory_lengths_list, src_map_list=src_map_list, step=step, batch_offset=beam._batch_offset) if self.reranker is not None: log_probs = self.reranker.rerank_step_beam_batch( batch, beam, self.beam_size, indexes, log_probs, attn, self.fields["tgt"].base_field.vocab, xlation_builder, ) # end if non_finished = None beam.advance(log_probs, attn) any_beam_is_finished = beam.is_finished.any() if any_beam_is_finished: non_finished = beam.update_finished() if beam.done: break select_indices = beam.current_origin if any_beam_is_finished: # Reorder states. 
for src_i in range(len(memory_bank_list)): if isinstance(memory_bank_list[src_i], tuple): memory_bank_list[src_i] = tuple(x.index_select(1, select_indices) for x in memory_bank_list[src_i]) else: memory_bank_list[src_i] = memory_bank_list[src_i].index_select(1, select_indices) # end if memory_lengths_list[src_i] = memory_lengths_list[src_i].index_select(0, select_indices) # end for if use_src_map and src_map_list[0] is not None: for src_i in range(len(src_map_list)): src_map_list[src_i] = src_map_list[src_i].index_select(1, select_indices) # end for # end if indexes = indexes.index_select(0, select_indices) self.model.decoder.map_state( lambda state, dim: state.index_select(dim, select_indices)) results["scores"] = beam.scores results["predictions"] = beam.predictions results["attention"] = beam.attention return results
def _fast_translate_batch(self, batch, data, min_length=0, node_type=None, atc=None): # TODO: faster code path for beam_size == 1. # TODO: support these blacklisted features. assert data.data_type == 'text' assert not self.copy_attn assert not self.dump_beam assert not self.use_filter_pred assert self.block_ngram_repeat == 0 assert self.global_scorer.beta == 0 beam_size = self.beam_size batch_size = batch.batch_size vocab = self.fields["tgt"].vocab node_type_vocab = self.fields['tgt_feat_0'].vocab start_token = vocab.stoi[inputters.BOS_WORD] end_token = vocab.stoi[inputters.EOS_WORD] unk_token = vocab.stoi[inputters.UNK] token_masks, allowed_token_indices, not_allowed_token_indices\ = generate_token_mask(atc, node_type_vocab, vocab) # debug(token_masks.keys()) assert batch_size == 1, "Only 1 example decoding at a time supported" assert (node_type is not None) and isinstance(node_type, str), \ "Node type string must be provided to translate tokens" node_types = [ var(node_type_vocab.stoi[n_type.strip()]) for n_type in node_type.split() ] node_types.append(var(node_type_vocab.stoi['-1'])) if self.cuda: node_types = [n_type.cuda() for n_type in node_types] # debug(node_types) max_length = len(node_types) # Encoder forward. src = inputters.make_features(batch, 'src', data.data_type) _, src_lengths = batch.src enc_states, memory_bank = self.model.encoder(src, src_lengths) dec_states = self.model.decoder.init_decoder_state(src, memory_bank, enc_states, with_cache=True) # Tile states and memory beam_size times. dec_states.map_batch_fn( lambda state, dim: tile(state, beam_size, dim=dim)) memory_bank = tile(memory_bank, beam_size, dim=1) memory_lengths = tile(src_lengths, beam_size) batch_offset = torch.arange(batch_size, dtype=torch.long, device=memory_bank.device) beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=memory_bank.device) alive_seq = torch.full([batch_size * beam_size, 1], start_token, dtype=torch.long, device=memory_bank.device) alive_attn = None # Give full probability to the first beam on the first step. topk_log_probs = (torch.tensor( [0.0] + [float("-inf")] * (beam_size - 1), device=memory_bank.device).repeat(batch_size)) results = dict() results["predictions"] = [[] for _ in range(batch_size)] # noqa: F812 results["scores"] = [[] for _ in range(batch_size)] # noqa: F812 results["attention"] = [[] for _ in range(batch_size)] # noqa: F812 results["gold_score"] = [0] * batch_size results["batch"] = batch save_attention = [[] for _ in range(batch_size)] # max_length += 1 for step in range(max_length): decoder_input = alive_seq[:, -1].view(1, -1, 1) node_type = node_types[step] node_type_str = str(node_type_vocab.itos[node_type.item()]) not_allowed_indices = not_allowed_token_indices[node_type_str] extra_input = torch.stack( [node_types[step] for _ in range(decoder_input.shape[1])]) extra_input = extra_input.view(1, -1, 1) final_input = torch.cat((decoder_input, extra_input), dim=-1) if self.cuda: final_input = final_input.cuda() # Decoder forward. dec_out, dec_states, attn = self.model.decoder( final_input, memory_bank, dec_states, memory_lengths=memory_lengths, step=step) # Generator forward. 
log_probs = self.model.generator.forward(dec_out.squeeze(0)) vocab_size = log_probs.size(-1) # debug(vocab_size, len(vocab)) if step < min_length: log_probs[:, end_token] = -1.1e20 # debug(len(not_allowed_indices)) log_probs[:, not_allowed_indices] = -1e20 log_probs += topk_log_probs.view(-1).unsqueeze(1) # debug('Source Shape :\t', memory_bank.size()) # debug('Probab Shape :\t', log_probs.size()) # attn_probs = attn['std'].squeeze() # (beam_size, source_length) # debug('Inside Attn:\t', attn['std'].size()) alpha = self.global_scorer.alpha length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha # Flatten probs into a list of possibilities. curr_scores = log_probs / length_penalty curr_scores = curr_scores.reshape(-1, beam_size * vocab_size) topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1) # Recover log probs. topk_log_probs = topk_scores * length_penalty # Resolve beam origin and true word ids. topk_beam_index = topk_ids.div(vocab_size) topk_ids = topk_ids.fmod(vocab_size) beam_indices = topk_beam_index.squeeze().cpu().numpy().tolist() if len(attn_probs.shape) == 1: attn_to_save = attn_probs[:] else: attn_to_save = attn_probs[beam_indices, :] save_attention[0].append(attn_to_save) # Map beam_index to batch_index in the flat representation. batch_index = (topk_beam_index + beam_offset[:topk_beam_index.size(0)].unsqueeze(1)) # Select and reorder alive batches. select_indices = batch_index.view(-1) alive_seq = alive_seq.index_select(0, select_indices) memory_bank = memory_bank.index_select(1, select_indices) memory_lengths = memory_lengths.index_select(0, select_indices) dec_states.map_batch_fn( lambda state, dim: state.index_select(dim, select_indices)) alive_seq = torch.cat([alive_seq, topk_ids.view(-1, 1)], -1) # # End condition is the top beam reached end_token. # end_condition = topk_ids[: , 0].eq(end_token) # if step + 1 == max_length: # end_condition.fill_(1) # finished = end_condition.nonzero().view(-1) # # # Save result of finished sentences. # if len(finished) > 0: predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1)) scores = topk_scores.view(-1, beam_size) attention = None if alive_attn is not None: attention = alive_attn.view(alive_attn.size(0), -1, beam_size, alive_attn.size(-1)) # debug(predictions.size()) # debug(end_token) # debug(start_token) for i in range(len(predictions)): b = batch_offset[i] for n in range(self.n_best): # debug(unk_token in predictions[i, n, 1:].cpu().numpy().tolist()) results["predictions"][b].append(predictions[i, n, 1:]) results["scores"][b].append(scores[i, n]) if attention is None: results["attention"][b].append([]) else: results["attention"][b].append( attention[:, i, n, :memory_lengths[i]]) results["save_attention"] = save_attention # results["save_attention"] = # non_finished = end_condition.eq(0).nonzero().view(-1) # # If all sentences are translated, no need to go further. # if len(non_finished) == 0: # break # # Remove finished batches for the next step. # topk_log_probs = topk_log_probs.index_select( # 0 , non_finished.to(topk_log_probs.device)) # topk_ids = topk_ids.index_select(0 , non_finished) # batch_index = batch_index.index_select(0 , non_finished) # batch_offset = batch_offset.index_select(0 , non_finished) results["gold_score"] = [0] * batch_size if "tgt" in batch.__dict__: results["gold_score"] = self._run_target(batch, data) return results
def initialize(self, memory_bank, src_lengths, src_map=None, device=None): """Initialize for decoding. Repeat src objects `beam_size` times. """ def fn_map_state(state, dim): return tile(state, self.beam_size, dim=dim) if isinstance(memory_bank, tuple): memory_bank = tuple(tile(x, self.beam_size, dim=1) for x in memory_bank) mb_device = memory_bank[0].device else: memory_bank = tile(memory_bank, self.beam_size, dim=1) mb_device = memory_bank.device if src_map is not None: src_map = tile(src_map, self.beam_size, dim=1) if device is None: device = mb_device self.memory_lengths = tile(src_lengths, self.beam_size) super(BeamSearch, self).initialize( memory_bank, self.memory_lengths, src_map, device) ######### base scripts ########### # if device is None: # device = torch.device('cpu') # self.alive_seq = torch.full( # [self.batch_size * self.parallel_paths, 1], self.bos, # dtype=torch.long, device=device) # self.is_finished = torch.zeros( # [self.batch_size, self.parallel_paths], # dtype=torch.uint8, device=device) # return None, memory_bank, src_lengths, src_map ######### base scripts ########### self.best_scores = torch.full( [self.batch_size], -1e10, dtype=torch.float, device=device) self._beam_offset = torch.arange( 0, self.batch_size * self.beam_size, step=self.beam_size, dtype=torch.long, device=device) self.topk_log_probs = torch.tensor( [0.0] + [float("-inf")] * (self.beam_size - 1), device=device ).repeat(self.batch_size) # buffers for the topk scores and 'backpointer' self.topk_scores = torch.empty((self.batch_size, self.beam_size), dtype=torch.float, device=device) self.topk_ids = torch.empty((self.batch_size, self.beam_size), dtype=torch.long, device=device) self._batch_index = torch.empty([self.batch_size, self.beam_size], dtype=torch.long, device=device) return fn_map_state, memory_bank, self.memory_lengths, src_map def __len__(self): return self.alive_seq.shape[1] def ensure_min_length(self, log_probs): if len(self) <= self.min_length: log_probs[:, self.eos] = -1e20 def ensure_max_length(self): # add one to account for BOS. Don't account for EOS because hitting # this implies it hasn't been found. 
if len(self) == self.max_length + 1: self.is_finished.fill_(1) def block_ngram_repeats(self, log_probs): cur_len = len(self) if self.block_ngram_repeat > 0 and cur_len > 1: for path_idx in range(self.alive_seq.shape[0]): # skip BOS hyp = self.alive_seq[path_idx, 1:] ngrams = set() fail = False gram = [] for i in range(cur_len - 1): # Last n tokens, n = block_ngram_repeat gram = (gram + [hyp[i].item()])[-self.block_ngram_repeat:] # skip the blocking if any token in gram is excluded if set(gram) & self.exclusion_tokens: continue if tuple(gram) in ngrams: fail = True ngrams.add(tuple(gram)) if fail: log_probs[path_idx] = -10e20 @property def current_predictions(self): return self.alive_seq[:, -1] @property def current_backptr(self): # for testing return self.select_indices.view(self.batch_size, self.beam_size)\ .fmod(self.beam_size) @property def batch_offset(self): return self._batch_offset def advance(self, log_probs, attn): vocab_size = log_probs.size(-1) # using integer division to get an integer _B without casting _B = log_probs.shape[0] // self.beam_size if self._stepwise_cov_pen and self._prev_penalty is not None: self.topk_log_probs += self._prev_penalty self.topk_log_probs -= self.global_scorer.cov_penalty( self._coverage + attn, self.global_scorer.beta).view( _B, self.beam_size) # force the output to be longer than self.min_length step = len(self) self.ensure_min_length(log_probs) # Multiply probs by the beam probability. log_probs += self.topk_log_probs.view(_B * self.beam_size, 1) self.block_ngram_repeats(log_probs) # if the sequence ends now, then the penalty is the current # length + 1, to include the EOS token length_penalty = self.global_scorer.length_penalty( step + 1, alpha=self.global_scorer.alpha) # Flatten probs into a list of possibilities. curr_scores = log_probs / length_penalty curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size) torch.topk(curr_scores, self.beam_size, dim=-1, out=(self.topk_scores, self.topk_ids)) # Recover log probs. # Length penalty is just a scalar. It doesn't matter if it's applied # before or after the topk. torch.mul(self.topk_scores, length_penalty, out=self.topk_log_probs) # Resolve beam origin and map to batch index flat representation. torch.div(self.topk_ids, vocab_size, out=self._batch_index) self._batch_index += self._beam_offset[:_B].unsqueeze(1) self.select_indices = self._batch_index.view(_B * self.beam_size) self.topk_ids.fmod_(vocab_size) # resolve true word ids # Append last prediction. 
self.alive_seq = torch.cat( [self.alive_seq.index_select(0, self.select_indices), self.topk_ids.view(_B * self.beam_size, 1)], -1) if self.return_attention or self._cov_pen: current_attn = attn.index_select(1, self.select_indices) if step == 1: self.alive_attn = current_attn # update global state (step == 1) if self._cov_pen: # coverage penalty self._prev_penalty = torch.zeros_like(self.topk_log_probs) self._coverage = current_attn else: self.alive_attn = self.alive_attn.index_select( 1, self.select_indices) self.alive_attn = torch.cat([self.alive_attn, current_attn], 0) # update global state (step > 1) if self._cov_pen: self._coverage = self._coverage.index_select( 1, self.select_indices) self._coverage += current_attn self._prev_penalty = self.global_scorer.cov_penalty( self._coverage, beta=self.global_scorer.beta).view( _B, self.beam_size) if self._vanilla_cov_pen: # shape: (batch_size x beam_size, 1) cov_penalty = self.global_scorer.cov_penalty( self._coverage, beta=self.global_scorer.beta) self.topk_scores -= cov_penalty.view(_B, self.beam_size).float() self.is_finished = self.topk_ids.eq(self.eos) self.ensure_max_length() def update_finished(self): # Penalize beams that finished. _B_old = self.topk_log_probs.shape[0] step = self.alive_seq.shape[-1] # 1 greater than the step in advance self.topk_log_probs.masked_fill_(self.is_finished, -1e10) # on real data (newstest2017) with the pretrained transformer, # it's faster to not move this back to the original device self.is_finished = self.is_finished.to('cpu') self.top_beam_finished |= self.is_finished[:, 0].eq(1) predictions = self.alive_seq.view(_B_old, self.beam_size, step) attention = ( self.alive_attn.view( step - 1, _B_old, self.beam_size, self.alive_attn.size(-1)) if self.alive_attn is not None else None) non_finished_batch = [] for i in range(self.is_finished.size(0)): # Batch level b = self._batch_offset[i] finished_hyp = self.is_finished[i].nonzero().view(-1) # Store finished hypotheses for this batch. for j in finished_hyp: # Beam level: finished beam j in batch i if self.ratio > 0: s = self.topk_scores[i, j] / (step + 1) if self.best_scores[b] < s: self.best_scores[b] = s self.hypotheses[b].append(( self.topk_scores[i, j], predictions[i, j, 1:], # Ignore start_token. attention[:, i, j, :self.memory_lengths[i]] if attention is not None else None)) # End condition is the top beam finished and we can return # n_best hypotheses. if self.ratio > 0: pred_len = self.memory_lengths[i] * self.ratio finish_flag = ((self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]) or \ self.is_finished[i].all() else: finish_flag = self.top_beam_finished[i] != 0 if finish_flag and len(self.hypotheses[b]) >= self.n_best: best_hyp = sorted( self.hypotheses[b], key=lambda x: x[0], reverse=True) for n, (score, pred, attn) in enumerate(best_hyp): if n >= self.n_best: break self.scores[b].append(score) self.predictions[b].append(pred) # ``(batch, n_best,)`` self.attention[b].append( attn if attn is not None else []) else: non_finished_batch.append(i) non_finished = torch.tensor(non_finished_batch) # If all sentences are translated, no need to go further. if len(non_finished) == 0: self.done = True return _B_new = non_finished.shape[0] # Remove finished batches for the next step. 
self.top_beam_finished = self.top_beam_finished.index_select( 0, non_finished) self._batch_offset = self._batch_offset.index_select(0, non_finished) non_finished = non_finished.to(self.topk_ids.device) self.topk_log_probs = self.topk_log_probs.index_select(0, non_finished) self._batch_index = self._batch_index.index_select(0, non_finished) self.select_indices = self._batch_index.view(_B_new * self.beam_size) self.alive_seq = predictions.index_select(0, non_finished) \ .view(-1, self.alive_seq.size(-1)) self.topk_scores = self.topk_scores.index_select(0, non_finished) self.topk_ids = self.topk_ids.index_select(0, non_finished) if self.alive_attn is not None: inp_seq_len = self.alive_attn.size(-1) self.alive_attn = attention.index_select(1, non_finished) \ .view(step - 1, _B_new * self.beam_size, inp_seq_len) if self._cov_pen: self._coverage = self._coverage \ .view(1, _B_old, self.beam_size, inp_seq_len) \ .index_select(1, non_finished) \ .view(1, _B_new * self.beam_size, inp_seq_len) if self._stepwise_cov_pen: self._prev_penalty = self._prev_penalty.index_select( 0, non_finished)
def _translate_batch(self, batch, data, builder): # (0) Prep each of the components of the search. # And helper method for reducing verbosity. beam_size = self.beam_size batch_size = batch.batch_size data_type = data.data_type tgt_field = self.fields['tgt'][0][1] vocab = tgt_field.vocab # Define a set of tokens to exclude from ngram-blocking exclusion_tokens = {vocab.stoi[t] for t in self.ignore_when_blocking} pad = vocab.stoi[tgt_field.pad_token] eos = vocab.stoi[tgt_field.eos_token] bos = vocab.stoi[tgt_field.init_token] beam = [ onmt.translate.Beam(beam_size, n_best=self.n_best, cuda=self.cuda, global_scorer=self.global_scorer, pad=pad, eos=eos, bos=bos, vocab=vocab, min_length=self.min_length, stepwise_penalty=self.stepwise_penalty, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=exclusion_tokens, num_clusters=self.num_clusters, embeddings=self.cluster_embeddings, prev_hyps=self.prev_hyps, hamming_dist=self.hamming_dist, k_per_cand=self.k_per_cand, hamming_penalty=self.hamming_penalty) for __ in range(batch_size) ] # (1) Run the encoder on the src. src, enc_states, memory_bank, src_lengths = self._run_encoder( batch, data_type) self.model.decoder.init_state(src, memory_bank, enc_states) results = {} results["predictions"] = [] results["scores"] = [] results["attention"] = [] results["batch"] = batch if "tgt" in batch.__dict__: results["gold_score"] = self._score_target( batch, memory_bank, src_lengths, data, batch.src_map if data_type == 'text' and self.copy_attn else None) self.model.decoder.init_state(src, memory_bank, enc_states) else: results["gold_score"] = [0] * batch_size # (2) Repeat src objects `beam_size` times. # We use now batch_size x beam_size (same as fast mode) src_map = (tile(batch.src_map, beam_size, dim=1) if data.data_type == 'text' and self.copy_attn else None) self.model.decoder.map_state( lambda state, dim: tile(state, beam_size, dim=dim)) if isinstance(memory_bank, tuple): memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank) else: memory_bank = tile(memory_bank, beam_size, dim=1) memory_lengths = tile(src_lengths, beam_size) # Saves new hypotheses new_hyps = [] # (3) run the decoder to generate sentences, using beam search. for i in range(self.max_length): if all((b.done() for b in beam)): break # (a) Construct batch x beam_size nxt words. # Get all the pending current beam words and arrange for forward. inp = torch.stack([b.get_current_state() for b in beam]) inp = inp.view(1, -1, 1) # (b) Decode and forward out, beam_attn = self._decode_and_generate( inp, memory_bank, batch, data, memory_lengths=memory_lengths, src_map=src_map, step=i) out = out.view(batch_size, beam_size, -1) beam_attn = beam_attn.view(batch_size, beam_size, -1) # (c) Advance each beam. 
select_indices_array = [] # Loop over the batch_size number of beam for j, b in enumerate(beam): ## Gets previous beam current_beam = [] if i > 0: ret2, fins = self._from_current_beam(beam) ret2["gold_score"] = [0] * batch_size if "tgt" in batch.__dict__: ret2["gold_score"] = self._run_target(batch, data) ret2["batch"] = batch current_beam = self.debug_translation(ret2, builder, fins)[0] new_hyps += current_beam b.advance(out[j, :], beam_attn.data[j, :, :memory_lengths[j]], current_beam, i) select_indices_array.append(b.get_current_origin() + j * beam_size) select_indices = torch.cat(select_indices_array) self.model.decoder.map_state( lambda state, dim: state.index_select(dim, select_indices)) # Adds all partial hypotheses to prev_hyps for the next iteration of beam search self.prev_hyps += new_hyps # (4) Extract sentences from beam. for b in beam: scores, ks = b.sort_finished(minimum=self.n_best) hyps, attn = [], [] for i, (times, k) in enumerate(ks[:self.n_best]): hyp, att = b.get_hyp(times, k) hyps.append(hyp) attn.append(att) results["predictions"].append(hyps) results["scores"].append(scores) results["attention"].append(attn) return results
def _translate_batch(self, batch, src_vocabs, max_length, min_length=0, ratio=0., n_best=1, return_attention=False): # TODO: support these blacklisted features. assert not self.dump_beam # (0) Prep the components of the search. use_src_map = self.copy_attn beam_size = self.beam_size batch_size = batch.batch_size #### TODO: Augment batch with distractors # (1) Run the encoder on the src. src, enc_states, memory_bank, src_lengths = self._run_encoder(batch) # src has shape [1311, 2, 1] # enc_states has shape [1311, 2, 512], # Memory_bank has shape [1311, 2, 512] self.model.decoder.init_state(src, memory_bank, enc_states) results = { "predictions": None, "scores": None, "attention": None, "batch": batch, "gold_score": self._gold_score(batch, memory_bank, src_lengths, src_vocabs, use_src_map, enc_states, batch_size, src) } # (2) Repeat src objects `beam_size` times. # We use batch_size x beam_size src_map = (tile(batch.src_map, beam_size, dim=1) if use_src_map else None) self.model.decoder.map_state( lambda state, dim: tile(state, beam_size, dim=dim)) if isinstance(memory_bank, tuple): memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank) mb_device = memory_bank[0].device else: memory_bank = tile(memory_bank, beam_size, dim=1) mb_device = memory_bank.device memory_lengths = tile(src_lengths, beam_size) print('memory_bank size after tile:', memory_bank.shape) #[1311, 20, 512] # (0) pt 2, prep the beam object beam = BeamSearch(beam_size, n_best=n_best, batch_size=batch_size, global_scorer=self.global_scorer, pad=self._tgt_pad_idx, eos=self._tgt_eos_idx, bos=self._tgt_bos_idx, min_length=min_length, ratio=ratio, max_length=max_length, mb_device=mb_device, return_attention=return_attention, stepwise_penalty=self.stepwise_penalty, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs, memory_lengths=memory_lengths) all_log_probs = [] all_attn = [] for step in range(max_length): decoder_input = beam.current_predictions.view(1, -1, 1) # decoder_input has shape[1,20,1] # decoder_input gives top 10 predictions for each batch element verbose = True if step == 10 else False log_probs, attn = self._decode_and_generate( decoder_input, memory_bank, batch, src_vocabs, memory_lengths=memory_lengths, src_map=src_map, step=step, batch_offset=beam._batch_offset, verbose=verbose) # log_probs and attn are the probs for next word given that the # current word is that in decoder_input all_log_probs.append(log_probs) all_attn.append(attn) beam.advance(log_probs, attn, verbose=verbose) any_beam_is_finished = beam.is_finished.any() if any_beam_is_finished: beam.update_finished() if beam.done: break select_indices = beam.current_origin if any_beam_is_finished: # Reorder states. if isinstance(memory_bank, tuple): memory_bank = tuple( x.index_select(1, select_indices) for x in memory_bank) else: memory_bank = memory_bank.index_select(1, select_indices) memory_lengths = memory_lengths.index_select(0, select_indices) if src_map is not None: src_map = src_map.index_select(1, select_indices) self.model.decoder.map_state( lambda state, dim: state.index_select(dim, select_indices)) print('batch_size:', batch_size) print('max_length:', max_length) print('all_log_probs has len', len(all_log_probs)) print('all_log_probs[0].shape', all_log_probs[0].shape) print('comparing log_probs[0]', all_log_probs[2][:, 0]) results["scores"] = beam.scores results["predictions"] = beam.predictions results["attention"] = beam.attention return results
def _translate_batch( self, src, src_lengths, batch_size, min_length=0, ratio=0., n_best=1, return_attention=False): max_length = self.config.max_sequence_length + 1 # to account for EOS beam_size = 3 # Encoder forward. enc_states, memory_bank, src_lengths = self.encoder(src, src_lengths) self.decoder.init_state(src, memory_bank, enc_states) results = { "predictions": None, "scores": None, "attention": None } # (2) Repeat src objects `beam_size` times. # We use batch_size x beam_size self.decoder.map_state(lambda state, dim: tile(state, beam_size, dim=dim)) #if isinstance(memory_bank, tuple): # memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank) # mb_device = memory_bank[0].device #else: memory_bank = tile(memory_bank, beam_size, dim=1) mb_device = memory_bank.device memory_lengths = tile(src_lengths, beam_size) mb_device = memory_bank[0].device if isinstance(memory_bank, tuple) else memory_bank.device block_ngram_repeat = 0 _exclusion_idxs = {} # (0) pt 2, prep the beam object beam = BeamSearch( beam_size, n_best=n_best, batch_size=batch_size, global_scorer=self.scorer, pad=self.config.tgt_padding, eos=self.config.tgt_eos, bos=self.config.tgt_bos, min_length=min_length, ratio=ratio, max_length=max_length, mb_device=mb_device, return_attention=return_attention, stepwise_penalty=None, block_ngram_repeat=block_ngram_repeat, exclusion_tokens=_exclusion_idxs, memory_lengths=memory_lengths) for step in range(max_length): decoder_input = beam.current_predictions.view(1, -1, 1) log_probs, attn = self._decode_and_generate(decoder_input, memory_bank, memory_lengths, step, pretraining = True) beam.advance(log_probs, attn) any_beam_is_finished = beam.is_finished.any() if any_beam_is_finished: beam.update_finished() if beam.done: break select_indices = beam.current_origin if any_beam_is_finished: # Reorder states. if isinstance(memory_bank, tuple): memory_bank = tuple(x.index_select(1, select_indices) for x in memory_bank) else: memory_bank = memory_bank.index_select(1, select_indices) memory_lengths = memory_lengths.index_select(0, select_indices) self.decoder.map_state( lambda state, dim: state.index_select(dim, select_indices)) results["scores"] = beam.scores results["predictions"] = beam.predictions results["attention"] = beam.attention return results
def _translate_batch(self, batch, data): # (0) Prep each of the components of the search. # And helper method for reducing verbosity. beam_size = self.beam_size batch_size = batch.batch_size data_type = data.data_type vocab = self.fields["tgt"].vocab # Define a set of tokens to exclude from ngram-blocking exclusion_tokens = {vocab.stoi[t] for t in self.ignore_when_blocking} pad = vocab.stoi[self.fields['tgt'].pad_token] eos = vocab.stoi[self.fields['tgt'].eos_token] bos = vocab.stoi[self.fields['tgt'].init_token] beam = [ onmt.translate.Beam(beam_size, n_best=self.n_best, cuda=self.cuda, global_scorer=self.global_scorer, pad=pad, eos=eos, bos=bos, min_length=self.min_length, stepwise_penalty=self.stepwise_penalty, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=exclusion_tokens) for __ in range(batch_size) ] # (1) Run the encoder on the src. src, knl, enc_states, his_memory_bank, src_memory_bank, knl_memory_bank, src_lengths = self._run_encoder( batch, data_type) self.model.decoder.init_state(src[100:, :, :], src[100:, :, :], src_memory_bank, enc_states) first_dec_words = torch.zeros( (self.max_length, batch_size, 1)).fill_(self.model.encoder.embeddings.word_padding_idx).long() results = {} results["predictions"] = [] results["scores"] = [] results["attention"] = [] results["batch"] = batch if "tgt" in batch.__dict__: results["gold_score"] = self._score_target( batch, his_memory_bank, src_memory_bank, knl_memory_bank, knl, src_lengths, data, batch.src_map if data_type == 'text' and self.copy_attn else None, ) self.model.decoder.init_state(src[100:, :, :], src[100:, :, :], src_memory_bank, enc_states) else: results["gold_score"] = [0] * batch_size # (2) Repeat src objects `beam_size` times. # We use now batch_size x beam_size (same as fast mode) src_map = (tile(batch.src_map, beam_size, dim=1) if data.data_type == 'text' and self.copy_attn else None) self.model.decoder.map_state( lambda state, dim: tile(state, beam_size, dim=dim)) if isinstance(src_memory_bank, tuple): src_memory_bank = tuple( tile(x, beam_size, dim=1) for x in src_memory_bank) mb_device = src_memory_bank[0].device else: src_memory_bank = tile(src_memory_bank, beam_size, dim=1) mb_device = src_memory_bank.device if isinstance(knl_memory_bank, tuple): knl_memory_bank = tuple( tile(x, beam_size, dim=1) for x in knl_memory_bank) mb_device = knl_memory_bank[0].device else: knl_memory_bank = tile(knl_memory_bank, beam_size, dim=1) mb_device = knl_memory_bank.device if isinstance(his_memory_bank, tuple): his_memory_bank = tuple( tile(x, beam_size, dim=1) for x in his_memory_bank) mb_device = his_memory_bank[0].device else: his_memory_bank = tile(his_memory_bank, beam_size, dim=1) mb_device = his_memory_bank.device memory_lengths = tile(src_lengths, beam_size) # (3) run the first decoder to generate sentences, using beam search. for i in range(self.max_length): if all((b.done() for b in beam)): break # (a) Construct batch x beam_size nxt words. # Get all the pending current beam words and arrange for forward. inp = torch.stack([b.get_current_state() for b in beam]) inp = inp.view(1, -1, 1) # (b) Decode and forward dec_out, dec_attn = self.model.decoder( inp, src_memory_bank, his_memory_bank, memory_lengths=memory_lengths, step=i) beam_attn = dec_attn["std"] out = self.model.generator(dec_out.squeeze(0)) out = out.view(batch_size, beam_size, -1) beam_attn = beam_attn.view(batch_size, beam_size, -1) # (c) Advance each beam. 
select_indices_array = [] # Loop over the batch_size number of beam for j, b in enumerate(beam): b.advance(out[j, :], beam_attn.data[j, :, :memory_lengths[j]]) select_indices_array.append(b.get_current_origin() + j * beam_size) select_indices = torch.cat(select_indices_array) self.model.decoder.map_state( lambda state, dim: state.index_select(dim, select_indices)) # Extract the first decoder sentences from beam. for j, b in enumerate(beam): scores, ks = b.sort_finished(minimum=self.n_best) hyps = [] for i, (times, k) in enumerate(ks[:self.n_best]): hyp, _ = b.get_hyp(times, k) for h in range(len(hyp)): first_dec_words[h, j, 0] = hyp[h] first_dec_words = first_dec_words.cuda() emb, decode1_memory_bank, decode1_mask = self.model.encoder.histransformer( first_dec_words[:50, :, :], None) self.model.decoder2.init_state(first_dec_words[:50, :, :], knl[600:, :, :], None, None) if isinstance(decode1_memory_bank, tuple): decode1_memory_bank = tuple( tile(x, beam_size, dim=1) for x in decode1_memory_bank) mb_device = decode1_memory_bank[0].device else: decode1_memory_bank = tile(decode1_memory_bank, beam_size, dim=1) mb_device = decode1_memory_bank.device self.model.decoder2.map_state( lambda state, dim: tile(state, beam_size, dim=dim)) # (4) run the second decoder to generate sentences, using beam search. for i in range(self.max_length): if all((b.done() for b in beam)): break # (a) Construct batch x beam_size nxt words. # Get all the pending current beam words and arrange for forward. inp = torch.stack([b.get_current_state() for b in beam]) inp = inp.view(1, -1, 1) # (b) Decode and forward dec_out, dec_attn = self.model.decoder2( inp, decode1_memory_bank, knl_memory_bank, memory_lengths=memory_lengths, step=i) beam_attn = dec_attn["std"] out = self.model.generator(dec_out.squeeze(0)) out = out.view(batch_size, beam_size, -1) beam_attn = beam_attn.view(batch_size, beam_size, -1) # (c) Advance each beam. select_indices_array = [] # Loop over the batch_size number of beam for j, b in enumerate(beam): b.advance(out[j, :], beam_attn.data[j, :, :memory_lengths[j]]) select_indices_array.append(b.get_current_origin() + j * beam_size) select_indices = torch.cat(select_indices_array) self.model.decoder2.map_state( lambda state, dim: state.index_select(dim, select_indices)) # Extract the second decoder sentences from beam. for b in beam: scores, ks = b.sort_finished(minimum=self.n_best) hyps, attn = [], [] for i, (times, k) in enumerate(ks[:self.n_best]): hyp, att = b.get_hyp(times, k) hyps.append(hyp) attn.append(att) results["predictions"].append(hyps) results["scores"].append(scores) results["attention"].append(attn) return results
def _translate_batch(
        self, batch, src_vocabs, max_length, min_length=0,
        ratio=0., n_best=1, return_attention=False):
    # TODO: support these blacklisted features.
    assert not self.dump_beam

    # (0) Prep the components of the search.
    use_src_map = self.copy_attn
    beam_size = self.beam_size     # default 5
    batch_size = batch.batch_size  # default 30

    # (1) Run the encoder on the src.
    src, enc_states, memory_bank, src_lengths = self._run_encoder(batch)
    # src.size() = torch.Size([59, 30, 1]) = [src_len, batch_size, 1]
    # enc_states[0/1].size() = [2, 30, 500]
    # memory_bank.size() = [59, 30, 500]
    # src_lengths.size() = [30]
    self.model.decoder.init_state(src, memory_bank, enc_states)

    results = {
        "predictions": None,
        "scores": None,
        "attention": None,
        "batch": batch,
        "gold_score": self._gold_score(
            batch, memory_bank, src_lengths, src_vocabs, use_src_map,
            enc_states, batch_size, src)}

    # (2) Repeat src objects `beam_size` times.
    # We use batch_size x beam_size
    src_map = (tile(batch.src_map, beam_size, dim=1)
               if use_src_map else None)
    self.model.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))

    if isinstance(memory_bank, tuple):
        # Repeat each tensor beam_size times along dim=1 (the batch dim).
        memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
        mb_device = memory_bank[0].device
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
        mb_device = memory_bank.device
    memory_lengths = tile(src_lengths, beam_size)

    # (0) pt 2, prep the beam object
    beam = BeamSearch(
        beam_size,
        n_best=n_best,
        batch_size=batch_size,
        global_scorer=self.global_scorer,
        pad=self._tgt_pad_idx,
        eos=self._tgt_eos_idx,
        bos=self._tgt_bos_idx,
        min_length=min_length,
        ratio=ratio,
        max_length=max_length,  # maximum prediction length
        mb_device=mb_device,
        return_attention=return_attention,
        stepwise_penalty=self.stepwise_penalty,
        block_ngram_repeat=self.block_ngram_repeat,
        exclusion_tokens=self._exclusion_idxs,
        memory_lengths=memory_lengths)

    # Run up to max_length steps; each step advances all
    # batch_size * beam_size hypotheses at once.
    for step in range(max_length):
        decoder_input = beam.current_predictions.view(1, -1, 1)
        # decoder_input.size() = torch.Size([1, 150, 1]);
        # 150 = 30 * 5 = batch_size * beam_size
        # @property
        # def current_predictions(self):
        #     return self.alive_seq[:, -1]

        log_probs, attn = self._decode_and_generate(
            decoder_input,
            memory_bank,  # torch.Size([59, 150, 500])
            batch,
            src_vocabs,
            memory_lengths=memory_lengths,
            src_map=src_map,
            step=step,
            batch_offset=beam._batch_offset)
        # log_probs: [150, 50004]; 50004 should be the vocabulary size,
        # i.e. 50k words plus the 4 specials <s> </s> <unk> <pad>.
        # attn: torch.Size([1, 150, 59]); 59 should be the longest src length.

        # advance() presumably folds the 150 flattened candidates back onto
        # the 30 batch entries.
        beam.advance(log_probs, attn)
        any_beam_is_finished = beam.is_finished.any()
        if any_beam_is_finished:
            beam.update_finished()
            if beam.done:
                break

        select_indices = beam.current_origin

        if any_beam_is_finished:
            # Reorder states.
            if isinstance(memory_bank, tuple):
                memory_bank = tuple(x.index_select(1, select_indices)
                                    for x in memory_bank)
            else:
                memory_bank = memory_bank.index_select(1, select_indices)

            memory_lengths = memory_lengths.index_select(0, select_indices)

            if src_map is not None:
                src_map = src_map.index_select(1, select_indices)

        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    results["scores"] = beam.scores
    results["predictions"] = beam.predictions
    results["attention"] = beam.attention
    return results
def _fast_translate_batch(self, batch, data): # TODO: faster code path for beam_size == 1. # TODO: support these blacklisted features. assert data.data_type == 'text' assert self.n_best == 1 assert self.min_length == 0 assert not self.copy_attn assert not self.replace_unk assert not self.dump_beam assert not self.use_filter_pred assert self.block_ngram_repeat == 0 assert self.global_scorer.beta == 0 beam_size = self.beam_size batch_size = batch.batch_size vocab = self.fields["tgt"].vocab start_token = vocab.stoi[inputters.BOS_WORD] end_token = vocab.stoi[inputters.EOS_WORD] # Encoder forward. src = inputters.make_features(batch, 'src', data.data_type) _, src_lengths = batch.src enc_states, memory_bank = self.model.encoder(src, src_lengths) dec_states = self.model.decoder.init_decoder_state( src, memory_bank, enc_states) # Tile states and memory beam_size times. dec_states.map_batch_fn( lambda state, dim: tile(state, beam_size, dim=dim)) memory_bank = tile(memory_bank, beam_size, dim=1) memory_lengths = tile(src_lengths, beam_size) batch_offset = torch.arange(batch_size, dtype=torch.long, device=memory_bank.device) beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=memory_bank.device) alive_seq = torch.full([batch_size * beam_size, 1], start_token, dtype=torch.long, device=memory_bank.device) # Give full probability to the first beam on the first step. topk_log_probs = (torch.tensor( [0.0] + [float("-inf")] * (beam_size - 1), device=memory_bank.device).repeat(batch_size)) results = {} results["predictions"] = [[] for _ in range(batch_size)] # noqa: F812 results["scores"] = [[] for _ in range(batch_size)] # noqa: F812 results["attention"] = [[[]] for _ in range(batch_size)] # noqa: F812 results["gold_score"] = [0] * batch_size results["batch"] = batch for step in range(self.max_length): decoder_input = alive_seq[:, -1].view(1, -1, 1) # Decoder forward. dec_out, dec_states, attn = self.model.decoder( decoder_input, memory_bank, dec_states, memory_lengths=memory_lengths, step=step) # Generator forward. log_probs = self.model.generator.forward(dec_out.squeeze(0)) vocab_size = log_probs.size(-1) # Multiply probs by the beam probability. log_probs += topk_log_probs.view(-1).unsqueeze(1) alpha = self.global_scorer.alpha length_penalty = ((5.0 + (step + 1)) / 6.0)**alpha # Flatten probs into a list of possibilities. curr_scores = log_probs / length_penalty curr_scores = curr_scores.reshape(-1, beam_size * vocab_size) topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1) # Recover log probs. topk_log_probs = topk_scores * length_penalty # Resolve beam origin and true word ids. topk_beam_index = topk_ids.div(vocab_size) topk_ids = topk_ids.fmod(vocab_size) # Map beam_index to batch_index in the flat representation. batch_index = (topk_beam_index + beam_offset[:topk_beam_index.size(0)].unsqueeze(1)) # End condition is the top beam reached end_token. finished = topk_ids[:, 0].eq(end_token) finished_count = finished.sum() # Save result of finished sentences. if finished_count > 0 or step + 1 == self.max_length: predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1)) scores = topk_scores.view(-1, beam_size) for i, is_finished in enumerate(finished.tolist()): if step + 1 != self.max_length and is_finished == 0: continue # TODO: if we get there because of max_length, the last # predicted token is currently discarded. 
b = batch_offset[i] results["predictions"][b].append(predictions[i, 0, 1:]) results["scores"][b].append(scores[i, 0]) non_finished = finished.eq(0).nonzero().view(-1) # If all sentences are translated, no need to go further. if non_finished.nelement() == 0: break # Remove finished batches for the next step. if non_finished.nelement() < finished.nelement(): topk_log_probs = topk_log_probs.index_select( 0, non_finished.to(topk_log_probs.device)) topk_ids = topk_ids.index_select(0, non_finished) batch_index = batch_index.index_select(0, non_finished) batch_offset = batch_offset.index_select(0, non_finished) # Select and reorder alive batches. select_indices = batch_index.view(-1) alive_seq = alive_seq.index_select(0, select_indices) memory_bank = memory_bank.index_select(1, select_indices) memory_lengths = memory_lengths.index_select(0, select_indices) dec_states.map_batch_fn( lambda state, dim: state.index_select(dim, select_indices)) # Append last prediction. alive_seq = torch.cat([alive_seq, topk_ids.view(-1, 1)], -1) return results
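The score normalisation in this fast path is the GNMT-style length penalty `((5 + t) / 6) ** alpha` at step `t` (counted from 1): running scores are divided by it before the top-k and multiplied back afterwards, so `topk_log_probs` stays an unnormalised sum of log-probabilities. A small self-contained check of that round trip, with made-up sizes and a hypothetical `alpha`:

import torch

alpha = 0.6  # hypothetical value for self.global_scorer.alpha
beam_size, vocab_size, batch_size = 5, 11, 2
step = 3  # zero-based decoding step

# Running beam scores and per-step log-probs for batch_size * beam_size rows.
topk_log_probs = torch.zeros(batch_size * beam_size, 1)
log_probs = torch.randn(batch_size * beam_size, vocab_size).log_softmax(-1)

scores = log_probs + topk_log_probs
length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

curr_scores = (scores / length_penalty).reshape(-1, beam_size * vocab_size)
topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

# Multiplying the penalty back recovers the raw log-probability sums.
recovered = topk_scores * length_penalty
flat_scores = scores.reshape(-1, beam_size * vocab_size)
assert torch.allclose(recovered, flat_scores.gather(1, topk_ids))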
def _translate_batch(self, batch, data):
    # (0) Prep each of the components of the search.
    # And helper method for reducing verbosity.
    beam_size = self.beam_size
    batch_size = batch.batch_size
    data_type = data.data_type
    vocab = self.fields["tgt"].vocab

    # Define a set of tokens to exclude from ngram-blocking
    # exclusion_list = ["<t>", "</t>", "."]
    exclusion_tokens = set([vocab.stoi[t]
                            for t in self.ignore_when_blocking])

    beam = [onmt.translate.Beam(beam_size, n_best=self.n_best,
                                cuda=self.cuda,
                                global_scorer=self.global_scorer,
                                pad=vocab.stoi[inputters.PAD_WORD],
                                eos=vocab.stoi[inputters.EOS_WORD],
                                bos=vocab.stoi[inputters.BOS_WORD],
                                min_length=self.min_length,
                                stepwise_penalty=self.stepwise_penalty,
                                block_ngram_repeat=self.block_ngram_repeat,
                                exclusion_tokens=exclusion_tokens)
            for __ in range(batch_size)]

    # (1) Run the encoder on the src.
    src, enc_states, memory_bank, src_lengths, _ = self._run_encoder(
        batch, data_type)
    self.model.decoder.init_state(src, memory_bank, enc_states)

    results = {}
    results["predictions"] = []
    results["scores"] = []
    results["attention"] = []
    results["batch"] = batch

    if "tgt" in batch.__dict__:
        results["gold_score"] = self._score_target(
            batch, memory_bank, src_lengths, data,
            batch.src_map
            if data_type == 'text' and self.copy_attn else None)
        self.model.decoder.init_state(src, memory_bank, enc_states)
    else:
        results["gold_score"] = [0] * batch_size

    # (2) Repeat src objects `beam_size` times.
    # We now use batch_size x beam_size (same as fast mode)
    src_map = (tile(batch.src_map, beam_size, dim=1)
               if data.data_type == 'text' and self.copy_attn else None)
    self.model.decoder.map_state(
        lambda state, dim: tile(state, beam_size, dim=dim))

    if isinstance(memory_bank, tuple):
        memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank)
    else:
        memory_bank = tile(memory_bank, beam_size, dim=1)
    memory_lengths = tile(src_lengths, beam_size)

    # (3) run the decoder to generate sentences, using beam search.
    for i in range(self.max_length):
        if all((b.done() for b in beam)):
            break

        # (a) Construct batch x beam_size next words.
        # Get all the pending current beam words and arrange for forward.
        inp = torch.stack([b.get_current_state() for b in beam])
        inp = inp.view(1, -1, 1)

        # (b) Decode and forward
        out, beam_attn = self._decode_and_generate(
            inp, memory_bank, batch, data,
            memory_lengths=memory_lengths, src_map=src_map, step=i)
        out = out.view(batch_size, beam_size, -1)
        beam_attn = beam_attn.view(batch_size, beam_size, -1)

        # (c) Advance each beam.
        select_indices_array = []
        # Loop over the batch_size beams
        for j, b in enumerate(beam):
            b.advance(out[j, :],
                      beam_attn.data[j, :, :memory_lengths[j]])
            select_indices_array.append(
                b.get_current_origin() + j * beam_size)
        select_indices = torch.cat(select_indices_array)
        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))

    # (4) Extract sentences from beam.
    for b in beam:
        n_best = self.n_best
        scores, ks = b.sort_finished(minimum=n_best)
        # Collect the n_best hypotheses and attentions for this sentence.
        hyps, attn = [], []
        for i, (times, k) in enumerate(ks[:n_best]):
            hyp, att = b.get_hyp(times, k)
            hyps.append(hyp)
            attn.append(att)
        results["predictions"].append(hyps)
        results["scores"].append(scores)
        results["attention"].append(attn)

    return results
def translate_batch(self, batch): beam_size = self.beam_size tgt_field = self.fields['tgt'][0][1] vocab = tgt_field.vocab pad = vocab.stoi[tgt_field.pad_token] eos = vocab.stoi[tgt_field.eos_token] bos = vocab.stoi[tgt_field.init_token] b = Beam(beam_size, n_best=self.n_best, cuda=self.cuda, pad=pad, eos=eos, bos=bos) src, src_lengths = batch.src # why doesn't this contain inflection source lengths when ensembling? side_info = side_information(batch) encoder_out = self.model.encode(src, lengths=src_lengths, **side_info) enc_states = encoder_out["enc_state"] memory_bank = encoder_out["memory_bank"] infl_memory_bank = encoder_out.get("inflection_memory_bank", None) self.model.init_decoder_state(enc_states) results = dict() if "tgt" in batch.__dict__: results["gold_score"] = self._score_target( batch, memory_bank, src_lengths, inflection_memory_bank=infl_memory_bank, **side_info) self.model.init_decoder_state(enc_states) else: results["gold_score"] = 0 # (2) Repeat src objects `beam_size` times. self.model.map_decoder_state( lambda state, dim: tile(state, beam_size, dim=dim)) if isinstance(memory_bank, tuple): memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank) else: memory_bank = tile(memory_bank, beam_size, dim=1) memory_lengths = tile(src_lengths, beam_size) if infl_memory_bank is not None: if isinstance(infl_memory_bank, tuple): infl_memory_bank = tuple( tile(x, beam_size, dim=1) for x in infl_memory_bank) else: infl_memory_bank = tile(infl_memory_bank, beam_size, dim=1) tiled_infl_len = tile(side_info["inflection_lengths"], beam_size) side_info["inflection_lengths"] = tiled_infl_len if "language" in side_info: side_info["language"] = tile(side_info["language"], beam_size) for i in range(self.max_length): if b.done(): break inp = b.current_state.unsqueeze(0) # the decoder expects an input of tgt_len x batch dec_out, dec_attn = self.model.decode( inp, memory_bank, memory_lengths=memory_lengths, inflection_memory_bank=infl_memory_bank, **side_info) attn = dec_attn["lemma"].squeeze(0) out = self.model.generator(dec_out.squeeze(0), transform=True, **side_info) # b.advance will take attn (beam size x src length) b.advance(out, dec_attn) select_indices = b.current_origin self.model.map_decoder_state( lambda state, dim: state.index_select(dim, select_indices)) scores, ks = b.sort_finished() hyps, attn, out_probs = [], [], [] for i, (times, k) in enumerate(ks[:self.n_best]): hyp, att, out_p = b.get_hyp(times, k) hyps.append(hyp) attn.append(att) out_probs.append(out_p) results["preds"] = hyps results["scores"] = scores results["attn"] = attn if self.beam_accum is not None: parent_ids = [t.tolist() for t in b.prev_ks] self.beam_accum["beam_parent_ids"].append(parent_ids) scores = [["%4f" % s for s in t.tolist()] for t in b.all_scores][1:] self.beam_accum["scores"].append(scores) pred_ids = [[vocab.itos[i] for i in t.tolist()] for t in b.next_ys][1:] self.beam_accum["predicted_ids"].append(pred_ids) if self.attn_path is not None: save_attn = {k: v.cpu() for k, v in attn[0].items()} src_seq = self.itos(src, "src") pred_seq = self.itos(hyps[0], "tgt") attn_dict = {"src": src_seq, "pred": pred_seq, "attn": save_attn} if "inflection" in save_attn: inflection_seq = self.itos(batch.inflection[0], "inflection") attn_dict["inflection"] = inflection_seq self.attns.append(attn_dict) if self.probs_path is not None: save_probs = out_probs[0].cpu() self.probs.append(save_probs) return results
def _fast_translate_batch(self, batch, data, max_length, min_length=0, n_best=1, return_attention=False): # TODO: faster code path for beam_size == 1. # TODO: support these blacklisted features. assert not self.dump_beam assert not self.use_filter_pred assert self.block_ngram_repeat == 0 assert self.global_scorer.beta == 0 beam_size = self.beam_size batch_size = batch.batch_size vocab = self.fields["tgt"].vocab start_token = vocab.stoi[inputters.BOS_WORD] end_token = vocab.stoi[inputters.EOS_WORD] # Encoder forward. src, enc_states, memory_bank, src_lengths = self._run_encoder( batch, data.data_type) self.model.decoder.init_state(src, memory_bank, enc_states, with_cache=True) if self.refer: ref_list, ref_states_list, ref_bank_list, ref_lengths_list, ref_prs_list = [], [], [], [], [] for i in range(self.refer): ref, ref_states, ref_bank, ref_lengths, ref_prs = self._run_refer(batch, data.data_type, k=i) ref_list.append(ref) ref_states_list.append(ref_states) ref_bank_list.append(ref_bank) ref_lengths_list.append(ref_lengths) ref_prs_list.append(ref_prs) self.extra_decoders[i].init_state(ref, ref_bank, ref_states, with_cache=True) else: ref_list, ref_states_list, ref_bank_list, ref_lengths_list, ref_prs_list = tuple([None]*5) results = dict() results["predictions"] = [[] for _ in range(batch_size)] # noqa: F812 results["scores"] = [[] for _ in range(batch_size)] # noqa: F812 results["attention"] = [[] for _ in range(batch_size)] # noqa: F812 results["batch"] = batch if "tgt" in batch.__dict__: results["gold_score"] = self._score_target( batch, memory_bank, src_lengths, data, batch.src_map if data.data_type == 'text' and self.copy_attn else None) self.model.decoder.init_state( src, memory_bank, enc_states, with_cache=True) else: results["gold_score"] = [0] * batch_size # Tile states and memory beam_size times. self.model.decoder.map_state( lambda state, dim: tile(state, beam_size, dim=dim)) if isinstance(memory_bank, tuple): memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank) mb_device = memory_bank[0].device else: memory_bank = tile(memory_bank, beam_size, dim=1) mb_device = memory_bank.device memory_lengths = tile(src_lengths, beam_size) src_map = (tile(batch.src_map, beam_size, dim=1) if data.data_type == 'text' and self.copy_attn else None) if self.refer: for i in range(self.refer): self.extra_decoders[i].map_state( lambda state, dim: tile(state, beam_size, dim=dim)) if isinstance(ref_bank_list[i], tuple): ref_bank_list[i] = tuple(tile(x, beam_size, dim=1) for x in ref_bank_list[i]) else: ref_bank_list[i] = tile(ref_bank_list[i], beam_size, dim=1) ref_lengths_list[i] = tile(ref_lengths_list[i], beam_size) ref_prs_list[i] = tile(ref_prs_list[i], beam_size).view(-1, 1) # ref_prs = torch.rand_like(ref_prs) else: ref_bank_list, ref_lengths_list = None, None top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8) batch_offset = torch.arange(batch_size, dtype=torch.long) beam_offset = torch.arange( 0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=mb_device) alive_seq = torch.full( [batch_size * beam_size, 1], start_token, dtype=torch.long, device=mb_device) alive_attn = None # Give full probability to the first beam on the first step. topk_log_probs = ( torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=mb_device).repeat(batch_size)) # Structure that holds finished hypotheses. 
hypotheses = [[] for _ in range(batch_size)] # noqa: F812 for step in range(max_length): decoder_input = alive_seq[:, -1].view(1, -1, 1) log_probs, attn = \ self._decode_and_generate(decoder_input, memory_bank, batch, data, memory_lengths=memory_lengths, src_map=src_map, step=step, batch_offset=batch_offset, ref_bank=ref_bank_list, ref_lengths=ref_lengths_list, ref_prs=ref_prs_list) vocab_size = log_probs.size(-1) if self.guide: log_probs = self.guide_by_tp(alive_seq, log_probs) if step < min_length: log_probs[:, end_token] = -1e20 # Multiply probs by the beam probability. log_probs += topk_log_probs.view(-1).unsqueeze(1) alpha = self.global_scorer.alpha length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha # Flatten probs into a list of possibilities. curr_scores = log_probs / length_penalty curr_scores = curr_scores.reshape(-1, beam_size * vocab_size) topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1) # Recover log probs. topk_log_probs = topk_scores * length_penalty # Resolve beam origin and true word ids. topk_beam_index = topk_ids.div(vocab_size) topk_ids = topk_ids.fmod(vocab_size) # topk_ids = torch.cat([batch.tgt[step + 1][batch_offset].view(-1, 1), topk_ids], -1)[:, :self.beam_size] # Map beam_index to batch_index in the flat representation. batch_index = ( topk_beam_index + beam_offset[:topk_beam_index.size(0)].unsqueeze(1)) select_indices = batch_index.view(-1) # Append last prediction. alive_seq = torch.cat( [alive_seq.index_select(0, select_indices), topk_ids.contiguous().view(-1, 1)], -1) if return_attention: current_attn = attn.index_select(1, select_indices) if alive_attn is None: alive_attn = current_attn else: alive_attn = alive_attn.index_select(1, select_indices) alive_attn = torch.cat([alive_attn, current_attn], 0) is_finished = topk_ids.eq(end_token) if step + 1 == max_length: is_finished.fill_(1) # Save finished hypotheses. if is_finished.any(): # Penalize beams that finished. topk_log_probs.masked_fill_(is_finished, -1e10) is_finished = is_finished.to('cpu') top_beam_finished |= is_finished[:, 0].eq(1) predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1)) attention = ( alive_attn.view( alive_attn.size(0), -1, beam_size, alive_attn.size(-1)) if alive_attn is not None else None) non_finished_batch = [] for i in range(is_finished.size(0)): b = batch_offset[i] finished_hyp = is_finished[i].nonzero().view(-1) # Store finished hypotheses for this batch. for j in finished_hyp: hypotheses[b].append(( topk_scores[i, j], predictions[i, j, 1:], # Ignore start_token. attention[:, i, j, :memory_lengths[i]] if attention is not None else None)) # End condition is the top beam finished and we can return # n_best hypotheses. if top_beam_finished[i] and len(hypotheses[b]) >= n_best: best_hyp = sorted( hypotheses[b], key=lambda x: x[0], reverse=True) for n, (score, pred, attn) in enumerate(best_hyp): if n >= n_best: break results["scores"][b].append(score) results["predictions"][b].append(pred) results["attention"][b].append( attn if attn is not None else []) else: non_finished_batch.append(i) non_finished = torch.tensor(non_finished_batch) # If all sentences are translated, no need to go further. if len(non_finished) == 0: break # Remove finished batches for the next step. 
top_beam_finished = top_beam_finished.index_select( 0, non_finished) batch_offset = batch_offset.index_select(0, non_finished) non_finished = non_finished.to(topk_ids.device) topk_log_probs = topk_log_probs.index_select(0, non_finished) batch_index = batch_index.index_select(0, non_finished) select_indices = batch_index.view(-1) alive_seq = predictions.index_select(0, non_finished) \ .view(-1, alive_seq.size(-1)) if alive_attn is not None: alive_attn = attention.index_select(1, non_finished) \ .view(alive_attn.size(0), -1, alive_attn.size(-1)) # Reorder states. if isinstance(memory_bank, tuple): memory_bank = tuple(x.index_select(1, select_indices) for x in memory_bank) else: memory_bank = memory_bank.index_select(1, select_indices) memory_lengths = memory_lengths.index_select(0, select_indices) if self.refer: for i in range(self.refer): if isinstance(ref_bank_list[i], tuple): ref_bank_list[i] = tuple(x.index_select(1, select_indices)for x in ref_bank_list[i]) else: ref_bank_list[i] = ref_bank_list[i].index_select(1, select_indices) ref_lengths_list[i] = ref_lengths_list[i].index_select(0, select_indices) ref_prs_list[i] = ref_prs_list[i].index_select(0, select_indices) self.extra_decoders[i].map_state( lambda state, dim: state.index_select(dim, select_indices)) self.model.decoder.map_state( lambda state, dim: state.index_select(dim, select_indices)) if src_map is not None: src_map = src_map.index_select(1, select_indices) if self.guide: print(self.batch_num) self.batch_num += 1 return results
def forward(self, input, attn_mask=None):
    """
    Inputs of forward function
        input: [target length, batch size, embed dim]
        attn_mask: [(batch size), sequence_length, sequence_length]
    Outputs of forward function
        attn_output: [target length, batch size, embed dim]
        attn_output_weights: [batch size, target length, sequence length]
    """
    seq_len, bsz, embed_dim = input.size()
    assert embed_dim == self.embed_dim

    # self-attention
    q, k, v = F.linear(
        input, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
    q *= self.scaling

    # Split q, k, v into num_heads parts
    q = q.contiguous().view(
        seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    k = k.contiguous().view(
        -1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    v = v.contiguous().view(
        -1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    # Gated Linear Unit
    if self._use_glu:
        q = self.query_glu(q)
        k = self.key_glu(k)

    # Batch matrix multiply query against key.
    # attn_output_weights is [bsz * num_heads, seq_len, seq_len]
    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [
        bsz * self.num_heads, seq_len, seq_len
    ]

    if attn_mask is not None:
        if attn_mask.dim() == 2:
            # We use the same mask for each item in the batch
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            # Each item in the batch has its own mask.
            # We need to inflate the mask to cover all heads
            attn_mask = tile(attn_mask, count=self.num_heads, dim=0)
        else:
            # Any other shape is unsupported.
            raise RuntimeError(f'Wrong mask dim: {attn_mask.dim()}')
        # The mask should be either 0 or -inf to go with softmax
        attn_output_weights += attn_mask

    attn_output_weights = F.softmax(
        attn_output_weights.float(), dim=-1,
        dtype=torch.float32 if attn_output_weights.dtype == torch.float16
        else attn_output_weights.dtype)
    attn_output_weights = F.dropout(attn_output_weights,
                                    p=self.dropout,
                                    training=self.training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [
        bsz * self.num_heads, seq_len, self.head_dim
    ]
    attn_output = attn_output.transpose(0, 1).contiguous().view(
        seq_len, bsz, embed_dim)
    attn_output = self.out_proj(attn_output)

    # average attention weights over heads
    attn_output_weights = attn_output_weights.view(
        bsz, self.num_heads, seq_len, seq_len)
    attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads

    return attn_output, attn_output_weights
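The mask handling above accepts either a single `[seq_len, seq_len]` mask shared by the whole batch or a per-item `[batch, seq_len, seq_len]` mask, and in the second case repeats each item's mask once per head so it lines up with the `[bsz * num_heads, seq_len, seq_len]` score tensor before the additive 0/-inf masking. A standalone sketch of just that expansion, assuming `tile` repeats items contiguously like `torch.repeat_interleave`; the sizes are arbitrary:

import torch

def expand_attn_mask(attn_mask, num_heads):
    """Bring an additive attention mask (0 keeps, -inf drops) to the
    [bsz * num_heads, seq_len, seq_len] layout used by the bmm scores."""
    if attn_mask.dim() == 2:      # one mask shared by the whole batch
        return attn_mask.unsqueeze(0)                        # broadcasts
    elif attn_mask.dim() == 3:    # one mask per batch item
        return torch.repeat_interleave(attn_mask, num_heads, dim=0)
    raise RuntimeError(f"Wrong mask dim: {attn_mask.dim()}")

bsz, num_heads, seq_len = 2, 4, 3
scores = torch.zeros(bsz * num_heads, seq_len, seq_len)
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
per_item = causal.expand(bsz, seq_len, seq_len)

print((scores + expand_attn_mask(causal, num_heads)).shape)    # [8, 3, 3]
print((scores + expand_attn_mask(per_item, num_heads)).shape)  # [8, 3, 3]

The per-item branch repeats each mask contiguously (item 0 for all heads, then item 1, and so on), which matches the batch-major, head-minor ordering produced by the `view(seq_len, bsz * num_heads, head_dim)` reshape in the forward pass.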
def _translate_batch_with_strategy(self, batch, src_vocabs, decode_strategy, only_gold_score): """Translate a batch of sentences step by step using cache. Args: batch: a batch of sentences, yield by data iterator. src_vocabs (list): list of torchtext.data.Vocab if can_copy. decode_strategy (DecodeStrategy): A decode strategy to use for generate translation step by step. Returns: results (dict): The translation results. """ # (0) Prep the components of the search. use_src_map = self.copy_attn parallel_paths = decode_strategy.parallel_paths # beam_size batch_size = batch.batch_size initial_encoding_token = None initial_decoding_token = self.expert_index(self.expert_id) # initial_token = self.expert_index(self.expert_id) # (1) Run the encoder on the src. src, enc_states, memory_bank, src_lengths = \ self._run_encoder(batch, initial_encoding_token) self.model.decoder.init_state(src, memory_bank, enc_states) results = { "predictions": None, "scores": None, "attention": None, "batch": batch, "gold_score": self._gold_score(batch, memory_bank, src_lengths, src_vocabs, use_src_map, enc_states, batch_size, src, initial_decoding_token) } if only_gold_score: results["scores"] = [[x] for x in results["gold_score"]] results["predictions"] = [[None] for _ in decode_strategy.predictions] results["attention"] = [[None] for _ in range(batch_size)] results["alignment"] = [[] for _ in range(batch_size)] else: # (2) prep decode_strategy. Possibly repeat src objects. src_map = batch.src_map if use_src_map else None fn_map_state, memory_bank, memory_lengths, src_map = \ decode_strategy.initialize(memory_bank, src_lengths, src_map, initial_token=initial_decoding_token) if fn_map_state is not None: self.model.decoder.map_state(fn_map_state) # (3) Begin decoding step by step: for step in range(decode_strategy.max_length): decoder_input = decode_strategy.current_predictions.view( 1, -1, 1) log_probs, attn = self._decode_and_generate( decoder_input, memory_bank, batch, src_vocabs, memory_lengths=memory_lengths, src_map=src_map, step=step, batch_offset=decode_strategy.batch_offset) if step == 0 and self.model.prior is not None: lprob_z = self.model.prior(memory_bank, memory_lengths) log_probs += tile(lprob_z[:, self.expert_id].unsqueeze(-1), self._tgt_vocab_len, dim=1) decode_strategy.advance(log_probs, attn) any_finished = decode_strategy.is_finished.any() if any_finished: decode_strategy.update_finished() if decode_strategy.done: break select_indices = decode_strategy.select_indices if any_finished: # Reorder states. if isinstance(memory_bank, tuple): memory_bank = tuple( x.index_select(1, select_indices) for x in memory_bank) else: memory_bank = memory_bank.index_select( 1, select_indices) memory_lengths = memory_lengths.index_select( 0, select_indices) if src_map is not None: src_map = src_map.index_select(1, select_indices) if parallel_paths > 1 or any_finished: self.model.decoder.map_state( lambda state, dim: state.index_select( dim, select_indices)) results["scores"] = decode_strategy.scores results["predictions"] = decode_strategy.predictions results["attention"] = decode_strategy.attention if self.report_align: results["alignment"] = self._align_forward( batch, decode_strategy.predictions) else: results["alignment"] = [[] for _ in range(batch_size)] return results
def _fast_translate_batch(self, batch, data, max_length, min_length=0, n_best=1, return_attention=False): # TODO: faster code path for beam_size == 1. # TODO: support these blacklisted features. assert data.data_type == 'text' assert not self.copy_attn assert not self.dump_beam assert not self.use_filter_pred assert self.block_ngram_repeat == 0 assert self.global_scorer.beta == 0 beam_size = self.beam_size batch_size = batch.batch_size vocab = self.fields["tgt"].vocab start_token = vocab.stoi[inputters.BOS_WORD] end_token = vocab.stoi[inputters.EOS_WORD] # Encoder forward. src = inputters.make_features(batch, 'src', data.data_type) _, src_lengths = batch.src enc_states, memory_bank = self.model.encoder(src, src_lengths) dec_states = self.model.decoder.init_decoder_state( src, memory_bank, enc_states, with_cache=True) # Tile states and memory beam_size times. dec_states.map_batch_fn( lambda state, dim: tile(state, beam_size, dim=dim)) memory_bank = tile(memory_bank, beam_size, dim=1) memory_lengths = tile(src_lengths, beam_size) batch_offset = torch.arange( batch_size, dtype=torch.long, device=memory_bank.device) beam_offset = torch.arange( 0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=memory_bank.device) alive_seq = torch.full( [batch_size * beam_size, 1], start_token, dtype=torch.long, device=memory_bank.device) alive_attn = None # Give full probability to the first beam on the first step. topk_log_probs = ( torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=memory_bank.device).repeat(batch_size)) results = {} results["predictions"] = [[] for _ in range(batch_size)] # noqa: F812 results["scores"] = [[] for _ in range(batch_size)] # noqa: F812 results["attention"] = [[] for _ in range(batch_size)] # noqa: F812 results["gold_score"] = [0] * batch_size results["batch"] = batch max_length += 1 for step in range(max_length): decoder_input = alive_seq[:, -1].view(1, -1, 1) # Decoder forward. dec_out, dec_states, attn = self.model.decoder( decoder_input, memory_bank, dec_states, memory_lengths=memory_lengths, step=step) # Generator forward. log_probs = self.model.generator.forward(dec_out.squeeze(0)) vocab_size = log_probs.size(-1) if step < min_length: log_probs[:, end_token] = -1e20 # Multiply probs by the beam probability. log_probs += topk_log_probs.view(-1).unsqueeze(1) alpha = self.global_scorer.alpha length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha # Flatten probs into a list of possibilities. curr_scores = log_probs / length_penalty curr_scores = curr_scores.reshape(-1, beam_size * vocab_size) topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1) # Recover log probs. topk_log_probs = topk_scores * length_penalty # Resolve beam origin and true word ids. topk_beam_index = topk_ids.div(vocab_size) topk_ids = topk_ids.fmod(vocab_size) # Map beam_index to batch_index in the flat representation. batch_index = ( topk_beam_index + beam_offset[:topk_beam_index.size(0)].unsqueeze(1)) # End condition is the top beam reached end_token. end_condition = topk_ids[:, 0].eq(end_token) if step + 1 == max_length: end_condition.fill_(1) finished = end_condition.nonzero().view(-1) # Save result of finished sentences. 
if len(finished) > 0: predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1)) scores = topk_scores.view(-1, beam_size) attention = None if alive_attn is not None: attention = alive_attn.view( alive_attn.size(0), -1, beam_size, alive_attn.size(-1)) for i in finished: b = batch_offset[i] for n in range(n_best): results["predictions"][b].append(predictions[i, n, 1:]) results["scores"][b].append(scores[i, n]) if attention is None: results["attention"][b].append([]) else: results["attention"][b].append( attention[:, i, n, :memory_lengths[i]]) non_finished = end_condition.eq(0).nonzero().view(-1) # If all sentences are translated, no need to go further. if len(non_finished) == 0: break # Remove finished batches for the next step. topk_log_probs = topk_log_probs.index_select( 0, non_finished.to(topk_log_probs.device)) topk_ids = topk_ids.index_select(0, non_finished) batch_index = batch_index.index_select(0, non_finished) batch_offset = batch_offset.index_select(0, non_finished) # Select and reorder alive batches. select_indices = batch_index.view(-1) alive_seq = alive_seq.index_select(0, select_indices) memory_bank = memory_bank.index_select(1, select_indices) memory_lengths = memory_lengths.index_select(0, select_indices) dec_states.map_batch_fn( lambda state, dim: state.index_select(dim, select_indices)) # Append last prediction. alive_seq = torch.cat([alive_seq, topk_ids.view(-1, 1)], -1) if return_attention: current_attn = attn["std"].index_select(1, select_indices) if alive_attn is None: alive_attn = current_attn else: alive_attn = alive_attn.index_select(1, select_indices) alive_attn = torch.cat([alive_attn, current_attn], 0) return results
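The fast paths above recover which beam each top-k candidate came from by integer-dividing the flat index (into the `beam_size * vocab_size` score row) by the vocabulary size, and take the remainder as the word id. Note that on recent PyTorch versions `Tensor.div` on integer tensors performs true division, so an explicit floor division is the safer spelling; a small illustration of the index arithmetic with made-up numbers:

import torch

beam_size, vocab_size = 5, 50004

# Flat indices into a [batch, beam_size * vocab_size] score matrix,
# as returned by curr_scores.topk(beam_size, dim=-1).
topk_ids = torch.tensor([[123,
                          vocab_size + 123,
                          2 * vocab_size + 7,
                          3 * vocab_size + 42,
                          4 * vocab_size + 3]])

topk_beam_index = torch.div(topk_ids, vocab_size, rounding_mode="floor")
word_ids = topk_ids.fmod(vocab_size)

print(topk_beam_index)  # tensor([[0, 1, 2, 3, 4]])
print(word_ids)         # tensor([[123, 123, 7, 42, 3]])

Adding `beam_offset[:k].unsqueeze(1)` to `topk_beam_index`, as the loops above do, then turns these per-sentence beam indices into row indices of the flattened `batch * beam` dimension.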
min_length = opt.min_length ratio = 0. max_length = opt.max_length mb_device = 0 stepwise_penalty = None block_ngram_repeat = 0 global_scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt) return_attention = False for batch in data_iter: print() src, src_lengths = batch.src memory_lengths = src_lengths enc_states, memory_bank, src_lengths = model.encoder(src, src_lengths) model.decoder.init_state(src, memory_bank, enc_states) model.decoder.map_state(lambda state, dim: tile(state, beam_size, dim=dim)) if isinstance(memory_bank, tuple): memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank) mb_device = memory_bank[0].device else: memory_bank = tile(memory_bank, beam_size, dim=1) mb_device = memory_bank.device memory_lengths = tile(src_lengths, beam_size) beam = BeamSearch(beam_size, n_best=n_best, batch_size=batch_size, global_scorer=global_scorer, pad=tgt_pad_idx, eos=tgt_eos_idx, bos=tgt_bos_idx, min_length=min_length,
def _translate_batch_deprecated(self, batch, src_vocabs): # (0) Prep each of the components of the search. # And helper method for reducing verbosity. use_src_map = self.copy_attn beam_size = self.beam_size batch_size = batch.batch_size beam = [onmt.translate.Beam( beam_size, n_best=self.n_best, cuda=self.cuda, global_scorer=self.global_scorer, pad=self._tgt_pad_idx, eos=self._tgt_eos_idx, bos=self._tgt_bos_idx, min_length=self.min_length, stepwise_penalty=self.stepwise_penalty, block_ngram_repeat=self.block_ngram_repeat, exclusion_tokens=self._exclusion_idxs) for __ in range(batch_size)] # (1) Run the encoder on the src. src, enc_states, memory_bank, src_lengths = self._run_encoder(batch) self.model.decoder.init_state(src, memory_bank, enc_states) results = { "predictions": [], "scores": [], "attention": [], "batch": batch, "gold_score": self._gold_score( batch, memory_bank, src_lengths, src_vocabs, use_src_map, enc_states, batch_size, src)} # (2) Repeat src objects `beam_size` times. # We use now batch_size x beam_size (same as fast mode) src_map = (tile(batch.src_map, beam_size, dim=1) if use_src_map else None) self.model.decoder.map_state( lambda state, dim: tile(state, beam_size, dim=dim)) if isinstance(memory_bank, tuple): memory_bank = tuple(tile(x, beam_size, dim=1) for x in memory_bank) else: memory_bank = tile(memory_bank, beam_size, dim=1) memory_lengths = tile(src_lengths, beam_size) # (3) run the decoder to generate sentences, using beam search. for i in range(self.max_length): if all((b.done for b in beam)): break # (a) Construct batch x beam_size nxt words. # Get all the pending current beam words and arrange for forward. inp = torch.stack([b.current_predictions for b in beam]) inp = inp.view(1, -1, 1) # (b) Decode and forward out, beam_attn = self._decode_and_generate( inp, memory_bank, batch, src_vocabs, memory_lengths=memory_lengths, src_map=src_map, step=i ) out = out.view(batch_size, beam_size, -1) beam_attn = beam_attn.view(batch_size, beam_size, -1) # (c) Advance each beam. select_indices_array = [] # Loop over the batch_size number of beam for j, b in enumerate(beam): if not b.done: b.advance(out[j, :], beam_attn.data[j, :, :memory_lengths[j]]) select_indices_array.append( b.current_origin + j * beam_size) select_indices = torch.cat(select_indices_array) self.model.decoder.map_state( lambda state, dim: state.index_select(dim, select_indices)) # (4) Extract sentences from beam. for b in beam: scores, ks = b.sort_finished(minimum=self.n_best) hyps, attn = [], [] for times, k in ks[:self.n_best]: hyp, att = b.get_hyp(times, k) hyps.append(hyp) attn.append(att) results["predictions"].append(hyps) results["scores"].append(scores) results["attention"].append(attn) return results
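In the list-of-`Beam` variants, each per-sentence beam reports back-pointers in its own local `0..beam_size-1` range, so adding `j * beam_size` shifts them into the flat `batch * beam` index space expected by `index_select` on the decoder state. A tiny numeric illustration with invented back-pointers:

import torch

beam_size = 3
# Local back-pointers reported by each sentence's beam at one step:
local_origins = [torch.tensor([0, 2, 2]),   # sentence 0
                 torch.tensor([1, 0, 2])]   # sentence 1

select_indices = torch.cat(
    [origin + j * beam_size for j, origin in enumerate(local_origins)])
print(select_indices)  # tensor([0, 2, 2, 4, 3, 5])
# Flat row j * beam_size + k of the decoder state is re-sourced from the
# flat row given by select_indices[j * beam_size + k].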