def translate_batch(self, batch):

    def get_inst_idx_to_tensor_position_map(inst_idx_list):
        ''' Indicate the position of an instance in a tensor. '''
        return {inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)}

    def collect_active_part(beamed_tensor, curr_active_inst_idx,
                            n_prev_active_inst, n_bm):
        ''' Collect tensor parts associated with active instances. '''
        _, *d_hs = beamed_tensor.size()
        n_curr_active_inst = len(curr_active_inst_idx)
        new_shape = (n_curr_active_inst * n_bm, *d_hs)

        beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
        beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
        beamed_tensor = beamed_tensor.view(*new_shape)

        return beamed_tensor

    def beam_decode_step(inst_dec_beams, len_dec_seq,
                         inst_idx_to_position_map, n_bm):
        ''' Decode and update beam status, and then return active beam idx. '''
        # len_dec_seq: i (starting from 0)

        def prepare_beam_dec_seq(inst_dec_beams):
            dec_seq = [b.get_last_target_word() for b in inst_dec_beams if not b.done]
            # dec_seq: [(beam_size)] * batch_size
            dec_seq = torch.stack(dec_seq).to(self.device)
            # dec_seq: (batch_size, beam_size)
            dec_seq = dec_seq.view(1, -1)
            # dec_seq: (1, batch_size * beam_size)
            return dec_seq

        def predict_word(dec_seq, n_active_inst, n_bm, len_dec_seq):
            # dec_seq: (1, batch_size * beam_size)
            dec_output, *_ = self.model.decoder(dec_seq, step=len_dec_seq)
            # dec_output: (1, batch_size * beam_size, hid_size)
            word_prob = self.model.generator(dec_output.squeeze(0))
            # word_prob: (batch_size * beam_size, vocab_size)
            word_prob = word_prob.view(n_active_inst, n_bm, -1)
            # word_prob: (batch_size, beam_size, vocab_size)
            return word_prob

        def collect_active_inst_idx_list(inst_beams, word_prob,
                                         inst_idx_to_position_map):
            active_inst_idx_list = []
            select_indices_array = []
            for inst_idx, inst_position in inst_idx_to_position_map.items():
                # Beam.advance returns True once this instance has finished.
                is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
                if not is_inst_complete:
                    active_inst_idx_list += [inst_idx]
                    # Offset the per-beam origin by the instance's flat position.
                    select_indices_array.append(
                        inst_beams[inst_idx].get_current_origin() + inst_position * n_bm)
            if len(select_indices_array) > 0:
                select_indices = torch.cat(select_indices_array)
            else:
                select_indices = None
            return active_inst_idx_list, select_indices

        n_active_inst = len(inst_idx_to_position_map)

        dec_seq = prepare_beam_dec_seq(inst_dec_beams)
        # dec_seq: (1, batch_size * beam_size)
        word_prob = predict_word(dec_seq, n_active_inst, n_bm, len_dec_seq)

        # Update the beams with predicted word probabilities and collect
        # the instances that are still incomplete.
        active_inst_idx_list, select_indices = collect_active_inst_idx_list(
            inst_dec_beams, word_prob, inst_idx_to_position_map)
        if select_indices is not None:
            assert len(active_inst_idx_list) > 0
            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        return active_inst_idx_list

    def collate_active_info(src_seq, src_enc,
                            inst_idx_to_position_map, active_inst_idx_list):
        # Sentences which are still active are collected,
        # so the decoder will not run on completed sentences.
        n_prev_active_inst = len(inst_idx_to_position_map)
        active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
        active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

        active_src_seq = collect_active_part(
            src_seq, active_inst_idx, n_prev_active_inst, n_bm)
        active_src_enc = collect_active_part(
            src_enc, active_inst_idx, n_prev_active_inst, n_bm)
        active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
            active_inst_idx_list)

        return active_src_seq, active_src_enc, active_inst_idx_to_position_map

    def collect_best_hypothesis_and_score(inst_dec_beams):
        hyps, scores = [], []
        for inst_idx in range(len(inst_dec_beams)):
            hyp, score = inst_dec_beams[inst_idx].get_best_hypothesis()
            hyps.append(hyp)
            scores.append(score)
        return hyps, scores

    with torch.no_grad():
        #-- Encode
        src_seq = make_features(batch, 'src')
        # src_seq: (seq_len_src, batch_size)
        src_emb, src_enc, _ = self.model.encoder(src_seq)
        # src_emb: (seq_len_src, batch_size, emb_size)
        # src_enc: (seq_len_src, batch_size, hid_size)
        self.model.decoder.init_state(src_seq, src_enc)
        src_len = src_seq.size(0)

        #-- Repeat data for beam search
        n_bm = self.beam_size
        n_inst = src_seq.size(1)
        self.model.decoder.map_state(lambda state, dim: tile(state, n_bm, dim=dim))
        # src_enc: (seq_len_src, batch_size * beam_size, hid_size)

        #-- Prepare beams
        decode_length = src_len + self.decode_extra_length
        decode_min_length = 0
        if self.decode_min_length >= 0:
            decode_min_length = src_len - self.decode_min_length
        if self.task_type == 'task':
            inst_dec_beams = [
                Beam(n_bm, decode_length=decode_length,
                     minimal_length=decode_min_length,
                     minimal_relative_prob=self.minimal_relative_prob,
                     bos_id=self.tgt_bos_id, eos_id=self.tgt_eos_id,
                     device=self.device)
                for _ in range(n_inst)]
        else:
            inst_dec_beams = [
                Beam(n_bm, decode_length=decode_length,
                     minimal_length=decode_min_length,
                     minimal_relative_prob=self.minimal_relative_prob,
                     bos_id=self.tgt2_bos_id, eos_id=self.tgt2_eos_id,
                     device=self.device)
                for _ in range(n_inst)]

        #-- Bookkeeping for active or not
        active_inst_idx_list = list(range(n_inst))
        inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

        #-- Decode
        for len_dec_seq in range(0, decode_length):
            active_inst_idx_list = beam_decode_step(
                inst_dec_beams, len_dec_seq, inst_idx_to_position_map, n_bm)
            if not active_inst_idx_list:
                break  # all instances have finished their path to <EOS>
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

        batch_hyps, batch_scores = collect_best_hypothesis_and_score(inst_dec_beams)
        return batch_hyps, batch_scores
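
# Both decode loops in this file call a `tile` helper to repeat encoder-side
# state `beam_size` times along the batch dimension, but its definition is not
# shown here. A minimal stand-in with the semantics the code appears to rely
# on (all beams of one instance laid out adjacently) is sketched below; this
# is an assumption, not the original helper.
import torch

def tile(x, count, dim=0):
    # Repeat every slice along `dim` `count` times, keeping copies of the
    # same batch element adjacent: [a, b] -> [a, a, b, b] for count=2.
    return torch.repeat_interleave(x, count, dim=dim)

# Usage sketch: tiling an encoder output of shape (seq_len, batch, hid).
enc = torch.randn(5, 3, 8)
tiled = tile(enc, 2, dim=1)                    # (5, batch * beam = 6, 8)
assert torch.equal(tiled[:, 0], tiled[:, 1])   # beams of instance 0 share state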
def translate_batch(self, src, src_len, tgt, max_length, min_length=0,
                    ratio=0., n_best=1, return_attention=False):
    with torch.no_grad():
        # (0) Prep the components of the search.
        beam_size = self.beam_size
        batch_size = src.size(1)

        # (1) pt 1, Run the encoder on the src.
        enc_states, memory_bank, src_lengths = self.model.encoder(src, src_len)
        # (1) pt 2, Convert encoder state to decoder state.
        enc_states = self.model.bridge(enc_states)
        # (1) pt 3, Make size of memory bank same as that of decoder.
        memory_bank = self.model.W(memory_bank)
        self.model.decoder.init_state(enc_states)

        results = {
            "predictions": None,
            "scores": None,
            "attention": None,
            "gold_score": (self._gold_score(tgt, enc_states, memory_bank, src_lengths)
                           if tgt is not None else None)
        }

        # (2) Repeat src objects `beam_size` times.
        # We use batch_size x beam_size.
        self.model.decoder.map_state(
            lambda state, dim: tile(state, beam_size, dim=dim))
        memory_bank = tile(memory_bank, beam_size, dim=1)
        memory_lengths = tile(src_lengths, beam_size)

        # (0) pt 2, prep the beam object.
        beam = BeamSearch(beam_size=beam_size,
                          batch_size=batch_size,
                          pad=self._tgt_pad_idx,
                          bos=self._tgt_bos_idx,
                          eos=self._tgt_eos_idx,
                          n_best=n_best,
                          device=self._dev,
                          min_length=min_length,
                          max_length=max_length,
                          return_attention=return_attention,
                          block_ngram_repeat=self.block_ngram_repeat,
                          exclusion_tokens=self._exclusion_idxs,
                          memory_lengths=memory_lengths,
                          ratio=ratio)

        for step in range(max_length):
            decoder_input = beam.current_predictions.view(1, -1, 1)

            log_probs, attn = self.model.decoder(
                decoder_input, memory_bank, memory_lengths=memory_lengths)
            log_probs = log_probs.squeeze(0)

            beam.advance(log_probs, attn)
            any_beam_is_finished = beam.is_finished.any()
            if any_beam_is_finished:
                beam.update_finished()
                if beam.done:
                    break

            select_indices = beam.current_origin

            if any_beam_is_finished:
                # Reorder states.
                memory_bank = memory_bank.index_select(1, select_indices)
                memory_lengths = memory_lengths.index_select(0, select_indices)

            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = beam.scores
        results["predictions"] = beam.predictions
        results["attention"] = beam.attention
        return results
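
# For illustration only: the state-reordering step shared by both loops above.
# After each beam advance, every surviving hypothesis records the flat row
# (instance * beam_size + beam) it was expanded from, and all stateful tensors
# must be gathered with those same indices. The shapes and index values below
# are hypothetical, chosen just to make the gather concrete.
import torch

seq_len, batch_size, beam_size, hid = 7, 2, 3, 4
memory_bank = torch.randn(seq_len, batch_size * beam_size, hid)

# e.g. instance 0 kept beams (1, 0, 2) and instance 1 kept beams (5, 3, 4)
select_indices = torch.tensor([1, 0, 2, 5, 3, 4])

memory_bank = memory_bank.index_select(1, select_indices)  # beam dim is 1 here
# Decoder caches are reordered the same way via map_state:
#   self.model.decoder.map_state(
#       lambda state, dim: state.index_select(dim, select_indices))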