# Beam and pad() are project-local helpers defined elsewhere in the repository.
import os

import torch
import torch.nn.functional as F


def beam_search(model, batch, beam_size, start=1, end=2, pad=0, min_length=3,
                min_n_best=5, max_ts=40, block_ngram=0):
    """
    Beam search given the model and Batch.

    This function expects the model to satisfy the following requirements:
    - model.encoder takes input and returns tuple (enc_out, enc_hidden, attn_mask)
    - model.decoder takes decoder params and returns decoder outputs after attn
    - model.output takes decoder outputs and returns a distribution over the dictionary

    Function arguments:
    model : nn.Module, here defined in modules.py
    batch : Batch structure with input and labels
    beam_size : size of each beam during the search
    start : start-of-sequence token
    end : end-of-sequence token
    pad : padding token
    min_length : minimum length of the decoded sequence
    min_n_best : minimum number of completed hypotheses generated from each beam
    max_ts : maximum length of the decoded sequence
    block_ngram : size of n-grams blocked from repeating within a hypothesis
        (0 disables blocking)

    Return:
    beam_preds_scores : list of tuples (prediction, score) for each sample in Batch
    n_best_preds_scores : list of n_best lists of tuples (prediction, score)
        for each sample from Batch
    beams : list of Beam instances defined in the Beam class, can be used for
        any following postprocessing, e.g. dot logging.
    """
    encoder_states = model.encoder(batch.text_vec)
    enc_out = encoder_states[0]
    enc_hidden = encoder_states[1]
    attn_mask = encoder_states[2]
    current_device = encoder_states[0][0].device

    batch_size = len(batch.text_lengths)
    beams = [
        Beam(beam_size, min_length=min_length, padding_token=pad,
             bos_token=start, eos_token=end, min_n_best=min_n_best,
             cuda=current_device, block_ngram=block_ngram)
        for i in range(batch_size)
    ]
    decoder_input = torch.Tensor([start]).detach().expand(
        batch_size, 1).long().to(current_device)
    # repeat encoder outputs, hiddens, attn_mask for each beam
    decoder_input = decoder_input.repeat(1, beam_size).view(
        beam_size * batch_size, -1)
    enc_out = enc_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
        batch_size * beam_size, -1, enc_out.size(-1))
    attn_mask = encoder_states[2].repeat(1, beam_size).view(
        attn_mask.size(0) * beam_size, -1)
    repeated_hiddens = []
    if isinstance(enc_hidden, tuple):  # LSTM
        for i in range(len(enc_hidden)):
            repeated_hiddens.append(
                enc_hidden[i].unsqueeze(2).repeat(1, 1, beam_size, 1))
        num_layers = enc_hidden[0].size(0)
        hidden_size = enc_hidden[0].size(-1)
        enc_hidden = tuple([
            repeated_hiddens[i].view(
                num_layers, batch_size * beam_size, hidden_size)
            for i in range(len(repeated_hiddens))
        ])
    else:  # GRU
        num_layers = enc_hidden.size(0)
        hidden_size = enc_hidden.size(-1)
        enc_hidden = enc_hidden.unsqueeze(2).repeat(1, 1, beam_size, 1).view(
            num_layers, batch_size * beam_size, hidden_size)

    hidden = enc_hidden
    for ts in range(max_ts):
        if all((b.done() for b in beams)):
            break
        output, hidden = model.decoder(decoder_input, hidden,
                                       (enc_out, attn_mask))
        score = model.output(output)
        # score contains scores for batch_size * beam_size samples
        score = score.view(batch_size, beam_size, -1)
        score = F.log_softmax(score, dim=-1)
        for i, b in enumerate(beams):
            b.advance(score[i])
        decoder_input = torch.cat(
            [b.get_output_from_current_step() for b in beams]).unsqueeze(-1)
        permute_hidden_idx = torch.cat([
            beam_size * i + b.get_backtrack_from_current_step()
            for i, b in enumerate(beams)
        ])
        # permute decoder hiddens with respect to the chosen hypotheses
        if isinstance(hidden, tuple):  # LSTM
            for i in range(len(hidden)):
                hidden[i].data.copy_(hidden[i].data.index_select(
                    dim=1, index=permute_hidden_idx))
        else:  # GRU
            hidden.data.copy_(
                hidden.data.index_select(dim=1, index=permute_hidden_idx))

    for b in beams:
        b.check_finished()

    beam_preds_scores = [list(b.get_top_hyp()) for b in beams]
    for pair in beam_preds_scores:
        pair[0] = Beam.get_pretty_hypothesis(pair[0])

    n_best_beams = [
        b.get_rescored_finished(n_best=min_n_best) for b in beams
    ]
    n_best_beam_preds_scores = []
    for i, beamhyp in enumerate(n_best_beams):
        this_beam = []
        for hyp in beamhyp:
            pred = beams[i].get_pretty_hypothesis(
                beams[i].get_hyp_from_finished(hyp))
            score = hyp.score
            this_beam.append((pred, score))
        n_best_beam_preds_scores.append(this_beam)

    return beam_preds_scores, n_best_beam_preds_scores, beams
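# Hypothetical usage sketch for beam_search (illustrative only; the Batch
# namedtuple and the toy tensors below are assumptions, not part of the source
# module). It shows the calling convention described in the docstring: a model
# exposing .encoder/.decoder/.output and a batch carrying text_vec and
# text_lengths.
def _beam_search_usage_example(model):
    from collections import namedtuple

    # model: any nn.Module exposing .encoder/.decoder/.output as described above
    Batch = namedtuple('Batch', ['text_vec', 'text_lengths'])
    text_vec = torch.LongTensor([[4, 5, 6, 2], [7, 8, 2, 0]])  # padded token ids
    batch = Batch(text_vec=text_vec, text_lengths=[4, 3])

    preds_scores, nbest_preds_scores, beams = beam_search(
        model, batch, beam_size=5, min_length=3, max_ts=40)

    best_tokens, best_score = preds_scores[0]  # top hypothesis for sample 0
    return best_tokens, best_score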
def translate_batch(self, src_seq, src_pos):
    ''' Translation work in one batch '''

    def get_inst_idx_to_tensor_position_map(inst_idx_list):
        ''' Indicate the position of an instance in a tensor. '''
        return {
            inst_idx: tensor_position
            for tensor_position, inst_idx in enumerate(inst_idx_list)
        }

    def collect_active_part(beamed_tensor, curr_active_inst_idx,
                            n_prev_active_inst, n_bm):
        ''' Collect tensor parts associated to active instances. '''
        _, *d_hs = beamed_tensor.size()
        n_curr_active_inst = len(curr_active_inst_idx)
        new_shape = (n_curr_active_inst * n_bm, *d_hs)

        beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
        beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
        beamed_tensor = beamed_tensor.view(*new_shape)

        return beamed_tensor

    def collate_active_info(src_seq, src_enc, inst_idx_to_position_map,
                            active_inst_idx_list, device):
        # Sentences which are still active are collected,
        # so the decoder will not run on completed sentences.
        n_prev_active_inst = len(inst_idx_to_position_map)
        active_inst_idx = [
            inst_idx_to_position_map[k] for k in active_inst_idx_list
        ]
        active_inst_idx = torch.LongTensor(active_inst_idx).to(device)

        active_src_seq = collect_active_part(
            src_seq, active_inst_idx, n_prev_active_inst, n_bm)
        active_src_enc = collect_active_part(
            src_enc, active_inst_idx, n_prev_active_inst, n_bm)
        active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
            active_inst_idx_list)

        return active_src_seq, active_src_enc, active_inst_idx_to_position_map

    def beam_decode_step(inst_dec_beams, len_dec_seq, src_seq, enc_output,
                         inst_idx_to_position_map, n_bm, device):
        ''' Decode and update beam status, and then return active beam idx '''

        def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq, device):
            dec_partial_seq = [
                b.get_current_state() for b in inst_dec_beams if not b.done
            ]
            dec_partial_seq = torch.stack(dec_partial_seq).to(device)
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
            return dec_partial_seq

        def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm, device):
            dec_partial_pos = torch.arange(
                1, len_dec_seq + 1, dtype=torch.long, device=device)
            dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(
                n_active_inst * n_bm, 1)
            return dec_partial_pos

        def predict_word(dec_seq, dec_pos, src_seq, enc_output,
                         n_active_inst, n_bm):
            dec_output, *_ = self.decoder(dec_seq, dec_pos, src_seq,
                                          enc_output)
            # Pick the last step: (n_active_inst * n_bm) x d_h
            dec_output = dec_output[:, -1, :]
            word_prob = F.log_softmax(self.tgt_word_prj(dec_output), dim=1)
            word_prob = word_prob.view(n_active_inst, n_bm, -1)
            return word_prob

        def collect_active_inst_idx_list(inst_beams, word_prob,
                                         inst_idx_to_position_map):
            active_inst_idx_list = []
            for inst_idx, inst_position in inst_idx_to_position_map.items():
                is_inst_complete = inst_beams[inst_idx].advance(
                    word_prob[inst_position])
                if not is_inst_complete:
                    active_inst_idx_list += [inst_idx]
            return active_inst_idx_list

        n_active_inst = len(inst_idx_to_position_map)

        dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq, device)
        dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm,
                                       device)
        word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output,
                                 n_active_inst, n_bm)

        # Update the beams with predicted word probabilities and collect
        # incomplete instances
        active_inst_idx_list = collect_active_inst_idx_list(
            inst_dec_beams, word_prob, inst_idx_to_position_map)

        return active_inst_idx_list

    def collect_hypothesis_and_scores(inst_dec_beams, n_best):
        all_hyp, all_scores = [], []
        for inst_idx in range(len(inst_dec_beams)):
            scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
            all_scores += [scores[:n_best]]

            hyps = [
                inst_dec_beams[inst_idx].get_hypothesis(i)
                for i in tail_idxs[:n_best]
            ]
            all_hyp += [hyps]
        return all_hyp, all_scores

    with torch.no_grad():
        #-- Encode
        # src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
        src_enc, *_ = self.encoder(src_seq, src_pos)
        device = src_seq.device.type

        #-- Repeat data for beam search
        n_bm = self.opt['beam_size']
        n_inst, len_s, d_h = src_enc.size()
        src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
        src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)

        #-- Prepare beams
        inst_dec_beams = [
            Beam(n_bm, device=src_seq.device.type) for _ in range(n_inst)
        ]

        #-- Bookkeeping for active or not
        active_inst_idx_list = list(range(n_inst))
        inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
            active_inst_idx_list)

        #-- Decode
        for len_dec_seq in range(1, self.opt['max_token_seq_len'] + 1):

            active_inst_idx_list = beam_decode_step(
                inst_dec_beams, len_dec_seq, src_seq, src_enc,
                inst_idx_to_position_map, n_bm, device)

            if not active_inst_idx_list:
                break  # all instances have finished their path to <EOS>

            src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                src_seq, src_enc, inst_idx_to_position_map,
                active_inst_idx_list, device)

    n_best = 1
    batch_hyp, batch_scores = collect_hypothesis_and_scores(
        inst_dec_beams, n_best)

    return batch_hyp, batch_scores
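# Hypothetical usage sketch for translate_batch (illustrative only; the
# translator, test_loader, and idx2word names are assumptions, not part of the
# source class). Since n_best is fixed to 1 above, each input instance comes
# back with a single best hypothesis (a list of token ids) plus its score.
def _translate_batch_usage_example(translator, test_loader, idx2word):
    # translator: the object owning translate_batch / encoder / decoder /
    # tgt_word_prj, with opt['beam_size'] and opt['max_token_seq_len'] set.
    results = []
    for src_seq, src_pos in test_loader:
        batch_hyp, batch_scores = translator.translate_batch(src_seq, src_pos)
        for inst_hyps in batch_hyp:
            best = inst_hyps[0]  # n_best == 1, so one hypothesis per instance
            results.append(' '.join(idx2word[tok] for tok in best))
    return results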
def forward(self, xs, ys=None, cands=None, valid_cands=None, prev_enc=None,
            rank_during_training=False, beam_size=1, topk=1):
    """Get output predictions from the model.

    Arguments:
    xs -- input to the encoder
    ys -- expected output from the decoder
    cands -- set of candidates to rank, if applicable
    valid_cands -- indices to match candidates with their appropriate xs
    prev_enc -- if you know you'll pass in the same xs multiple times and
        the model is in eval mode, you can pass in the encoder output from
        the last forward pass to skip recalculating the same encoder output
    rank_during_training -- (default False) if set, ranks any available
        cands during training as well
    beam_size -- (default 1) beam size for the search; 1 falls back to
        greedy decoding
    topk -- (default 1) passed to the decoder during greedy decoding
    """
    input_xs = xs
    nbest_beam_preds, nbest_beam_scores = None, None
    bsz = len(xs)
    if ys is not None:
        # keep track of longest label we've ever seen
        # we'll never produce longer ones than that during prediction
        self.longest_label = max(self.longest_label, ys.size(1))

    if prev_enc is not None:
        enc_out, hidden, attn_mask = prev_enc
    else:
        enc_out, hidden = self.encoder(xs)
        attn_mask = xs.ne(0).float() if self.attn_type != 'none' else None
    encoder_states = (enc_out, hidden, attn_mask)
    start = self.START.detach()
    starts = start.expand(bsz, 1)

    predictions = []
    scores = []
    cand_preds, cand_scores = None, None
    if self.rank and cands is not None:
        decode_params = (start, hidden, enc_out, attn_mask)
        if self.training:
            if rank_during_training:
                cand_preds, cand_scores = self.ranker.forward(
                    cands, valid_cands, decode_params=decode_params)
        else:
            cand_preds, cand_scores = self.ranker.forward(
                cands, valid_cands, decode_params=decode_params)

    if ys is not None:
        y_in = ys.narrow(1, 0, ys.size(1) - 1)
        xs = torch.cat([starts, y_in], 1)
        if self.attn_type == 'none':
            preds, score, hidden = self.decoder(xs, hidden, enc_out, attn_mask)
            predictions.append(preds)
            scores.append(score)
        else:
            for i in range(ys.size(1)):
                xi = xs.select(1, i)
                preds, score, hidden = self.decoder(xi, hidden, enc_out,
                                                    attn_mask)
                predictions.append(preds)
                scores.append(score)
    else:
        # here we do search; supported search types: greedy, beam search
        if beam_size == 1:
            done = [False for _ in range(bsz)]
            total_done = 0
            xs = starts

            for _ in range(self.longest_label):
                # generate at most longest_label tokens
                preds, score, hidden = self.decoder(xs, hidden, enc_out,
                                                    attn_mask, topk)
                scores.append(score)
                xs = preds
                predictions.append(preds)

                # check if we've produced the end token
                for b in range(bsz):
                    if not done[b]:
                        # only add more tokens for examples that aren't done
                        if preds.data[b][0] == self.END_IDX:
                            # if we produced END, we're done
                            done[b] = True
                            total_done += 1
                if total_done == bsz:
                    # no need to generate any more
                    break

        elif beam_size > 1:
            # take it from encoder
            enc_out, hidden = encoder_states[0], encoder_states[1]
            enc_out = enc_out.unsqueeze(1).repeat(1, beam_size, 1, 1)
            data_device = enc_out.device
            # create batch size num of beams
            beams = [
                Beam(beam_size, 3, 0, 1, 2, min_n_best=beam_size / 2,
                     cuda=data_device)
                for _ in range(bsz)
            ]
            # init the input with start token
            xs = starts
            # repeat tensors to support batched beam
            xs = xs.repeat(1, beam_size)
            attn_mask = input_xs.ne(0).float()
            attn_mask = attn_mask.unsqueeze(1).repeat(1, beam_size, 1)
            repeated_hidden = []

            if isinstance(hidden, tuple):
                for i in range(len(hidden)):
                    repeated_hidden.append(
                        hidden[i].unsqueeze(2).repeat(1, 1, beam_size, 1))
                hidden = self.unbeamize_hidden(
                    tuple(repeated_hidden), beam_size, bsz)
            else:  # GRU
                repeated_hidden = hidden.unsqueeze(2).repeat(
                    1, 1, beam_size, 1)
                hidden = self.unbeamize_hidden(
                    repeated_hidden, beam_size, bsz)

            enc_out = self.unbeamize_enc_out(enc_out, beam_size, bsz)
            xs = xs.view(bsz * beam_size, -1)

            for step in range(self.longest_label):
                if all((b.done() for b in beams)):
                    break
                out = self.decoder(xs, hidden, enc_out)
                scores = out[1]
                scores = scores.view(bsz, beam_size, -1)  # -1 is the vocab size
                for i, b in enumerate(beams):
                    b.advance(F.log_softmax(scores[i, :], dim=-1))
                xs = torch.cat(
                    [b.get_output_from_current_step() for b in beams]
                ).unsqueeze(-1)
                permute_hidden_idx = torch.cat([
                    beam_size * i + b.get_backtrack_from_current_step()
                    for i, b in enumerate(beams)
                ])
                new_hidden = out[2]
                if isinstance(hidden, tuple):
                    for i in range(len(hidden)):
                        hidden[i].data.copy_(new_hidden[i].data.index_select(
                            dim=1, index=permute_hidden_idx))
                else:  # GRU
                    hidden.data.copy_(new_hidden.data.index_select(
                        dim=1, index=permute_hidden_idx))

            for b in beams:
                b.check_finished()
            beam_pred = [
                b.get_pretty_hypothesis(b.get_top_hyp()[0])[1:] for b in beams
            ]
            # these beam scores are rescored with length penalty!
            beam_scores = torch.stack([b.get_top_hyp()[1] for b in beams])
            pad_length = max([t.size(0) for t in beam_pred])
            beam_pred = torch.stack(
                [pad(t, length=pad_length, dim=0) for t in beam_pred], dim=0)

            # prepare n best list for each beam
            n_best_beam_tails = [
                b.get_rescored_finished(n_best=len(b.finished)) for b in beams
            ]
            nbest_beam_scores = []
            nbest_beam_preds = []
            for i, beamtails in enumerate(n_best_beam_tails):
                perbeam_preds = []
                perbeam_scores = []
                for tail in beamtails:
                    perbeam_preds.append(
                        beams[i].get_pretty_hypothesis(
                            beams[i].get_hyp_from_finished(tail)))
                    perbeam_scores.append(tail.score)
                nbest_beam_scores.append(perbeam_scores)
                nbest_beam_preds.append(perbeam_preds)

            if self.beam_log_freq > 0.0:
                num_dump = round(bsz * self.beam_log_freq)
                for i in range(num_dump):
                    dot_graph = beams[i].get_beam_dot(dictionary=self.dict)
                    dot_graph.write_png(
                        os.path.join(
                            self.beam_dump_path,
                            "{}.png".format(self.beam_dump_filecnt)))
                    self.beam_dump_filecnt += 1

            predictions = beam_pred
            scores = beam_scores

    if isinstance(predictions, list):
        predictions = torch.cat(predictions, 1)
    if isinstance(scores, list):
        scores = torch.cat(scores, 1)

    return (predictions, scores, cand_preds, cand_scores, encoder_states,
            nbest_beam_preds, nbest_beam_scores)
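# Hypothetical usage sketch for forward (illustrative only; the model instance
# and the tensors passed in are assumptions, not part of the source class). It
# contrasts the teacher-forced training call with a label-free beam-search
# call, matching the two branches of the method above.
def _forward_usage_example(model, xs, ys):
    # Training-style call: decode with teacher forcing against the labels ys.
    predictions, scores, *_ = model(xs, ys=ys)

    # Inference-style call: no labels, beam search with beam size 5.
    model.eval()
    with torch.no_grad():
        (predictions, scores, cand_preds, cand_scores,
         encoder_states, nbest_preds, nbest_scores) = model(xs, beam_size=5)
    return predictions, nbest_preds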