def _gen_beam(self, keyword_ids, normal_vector, **kwargs):
    dtype = normal_vector.dtype
    device = normal_vector.device
    batch_size, latent_dim = normal_vector.size()
    max_seq_len = self.opts.gen_max_seq_len
    beam_width = kwargs["beam_width"]
    length_norm = kwargs["length_norm"]
    n_best = kwargs["n_best"]

    keyword_embs = None
    if keyword_ids is not None:
        keyword_embs = self.embedding.forward_word_emb(keyword_ids)
    latent_out = self.latent_module.forward_gen_path(
        keyword_embs, normal_vector, head_dims=[],
        batch_size=batch_size, dtype=dtype, device=device)[1].squeeze(0)

    input = torch.full((1, batch_size), fill_value=self.SOS_token,
                       dtype=torch.long, device=device)
    output_step = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
    back_pointers = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
    batch_beams = [Beam(beam_width, length_norm, self.EOS_token, n_best)
                   for _ in range(batch_size)]

    # First step: score the <SOS> prefix once, then fan out to beam_width beams.
    logits_step = self._gen_forward_step(input, keyword_ids, keyword_embs,
                                         latent_out, use_cache=False)[-1]
    step_batch_beams(batch_beams, logits_step, output_step, func="init_beams")
    if keyword_ids is not None:
        keyword_ids = keyword_ids.repeat_interleave(beam_width, dim=0)
    if "ktoken" in self.keyword_approaches:
        # Replace the keyword placeholder token with the actual keyword id.
        mask = output_step == self.KEYWORD_token
        output_step[mask] = keyword_ids[mask]

    # Remaining steps: tile states across beams, then decode step by step.
    input = input.repeat_interleave(beam_width, dim=1)
    input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
    latent_out = latent_out.repeat_interleave(beam_width, dim=0)
    if keyword_embs is not None:
        keyword_embs = keyword_embs.repeat_interleave(beam_width, dim=0)
    for _ in range(1, max_seq_len):
        logits = self._gen_forward_step(input, keyword_ids, keyword_embs,
                                        latent_out, use_cache=False)
        logits_step = logits[-1].view(batch_size, beam_width, -1)
        step_batch_beams(batch_beams, logits_step, output_step, back_pointers,
                         func="update_beams")
        if all(b.done for b in batch_beams):
            break
        if "ktoken" in self.keyword_approaches:
            mask = output_step == self.KEYWORD_token
            output_step[mask] = keyword_ids[mask]
        # Reorder the growing prefixes to follow the surviving beams.
        input = input.index_select(dim=1, index=back_pointers)
        input = torch.cat([input, output_step.unsqueeze(0)], dim=0)

    output = list(chain(*(beam.get_best_results()[0] for beam in batch_beams)))
    output = bidirectional_padding(output, self.PAD_token, 0, device=device)[0]
    return output
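# Standalone sketch (toy tensors, not part of the model) of the back-pointer
# bookkeeping _gen_beam relies on: after each step, the growing prefixes of
# shape (seq_len, batch * beam_width) are reordered so that column j
# continues the beam that actually produced token j.
import torch

prefixes = torch.tensor([[10, 11, 12]])     # one decoded step, three beams
back_pointers = torch.tensor([2, 2, 0])     # survivors came from beams 2, 2, 0
next_tokens = torch.tensor([7, 8, 9])
prefixes = prefixes.index_select(dim=1, index=back_pointers)
prefixes = torch.cat([prefixes, next_tokens.unsqueeze(0)], dim=0)
assert prefixes.tolist() == [[12, 12, 10], [7, 8, 9]]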
def test_profile(self):
    # Default profile is UNIFORM; string values are matched case-insensitively.
    self.assertEqual(Profile.UNIFORM, Beam().profile)
    self.assertEqual(Profile.GAUSSIAN, Beam(profile=Profile.GAUSSIAN).profile)
    self.assertEqual(Profile.UNIFORM, Beam(profile="unIfoRm").profile)
    self.assertEqual(Profile.GAUSSIAN, Beam(profile="gaUssIan").profile)
    # Unrecognized strings and None fall back to the UNIFORM default.
    self.assertEqual(Profile.UNIFORM, Beam(profile="gaussiann").profile)
    self.assertEqual(Profile.UNIFORM, Beam(profile=None).profile)
def beamsearch(memory, model, beam_size=4, candidates=1, max_seq_length=128,
               sos_token=1, eos_token=2):
    # memory: T x 1 x E
    model.eval()
    device = memory.device
    beam = Beam(beam_size=beam_size, min_length=0, n_top=candidates, ranker=None,
                start_token_id=sos_token, end_token_id=eos_token)

    with torch.no_grad():
        # Tile the encoder memory across the beam: T x N x E, N = beam_size.
        memory = memory.repeat(1, beam_size, 1)

        for _ in range(max_seq_length):
            tgt_inp = beam.get_current_state().transpose(0, 1).to(device)  # T x N
            decoder_outputs = model.transformer.forward_decoder(tgt_inp, memory)
            log_prob = log_softmax(decoder_outputs[:, -1, :].squeeze(0), dim=-1)
            beam.advance(log_prob.cpu())
            if beam.done():
                break

        scores, ks = beam.sort_finished(minimum=1)
        hypotheses = []
        for times, k in ks[:candidates]:
            hypotheses.append(beam.get_hypothesis(times, k))

    # Re-prepend the SOS token and drop the trailing EOS from the best hypothesis.
    return [sos_token] + [int(i) for i in hypotheses[0][:-1]]
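# Standalone sketch of the last-step scoring used in beamsearch above: take
# the newest position's logits and convert them to log-probabilities for the
# beam to rank (shapes are toy values, not from a real model).
import torch
from torch.nn.functional import log_softmax

decoder_outputs = torch.randn(4, 5, 7)                     # N x T x V
log_prob = log_softmax(decoder_outputs[:, -1, :], dim=-1)  # N x V
assert torch.allclose(log_prob.exp().sum(dim=-1), torch.ones(4))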
def beam_decode(self, dec_hidden, enc_outputs, enc_length):
    '''
    Args:
        dec_hidden : [num_layers, batch_size, hidden_size] (optional)
    Return:
        prediction: [batch_size, beam, max_len]
    '''
    batch_size, beam_size = self.config.batch_size, self.config.beam_size

    # [1, batch_size * beam_size], initialized with SOS
    input = torch.ones(1, batch_size * beam_size,
                       dtype=torch.long, device=self.device) * SOS_ID

    # [num_layers, batch_size * beam_size, hidden_size]
    dec_hidden = dec_hidden.repeat(1, beam_size, 1)
    # [max_src_len, batch_size * beam_size, hidden_size]
    enc_outputs = enc_outputs.repeat(1, beam_size, 1)
    # [batch_size * beam_size]
    enc_length = enc_length.repeat(beam_size)

    # [batch_size]: [0, beam_size, ..., beam_size * (batch_size - 1)]
    batch_position = torch.arange(
        0, batch_size, dtype=torch.long, device=self.device) * beam_size

    # Only the first beam of each sample starts at score 0; the rest start at
    # -inf so the initial top-k does not pick duplicate beams.
    score = torch.ones(batch_size * beam_size, device=self.device) * -float('inf')
    score.index_fill_(0, torch.arange(
        0, batch_size, dtype=torch.long, device=self.device) * beam_size, 0.0)

    # Initialize Beam that stores decisions for backtracking
    beam = Beam(batch_size, beam_size, self.config.r_max_len,
                batch_position, EOS_ID)

    for i in range(self.config.r_max_len):
        output, dec_hidden, _ = self.rnn_decoder(
            input.view(1, -1),
            dec_hidden,
            enc_outputs,
            enc_length,
        )

        # output: [1, batch_size * beam_size, vocab_size]
        #      -> [batch_size * beam_size, vocab_size]
        log_prob = output.squeeze(0)

        # score: [batch_size * beam_size, vocab_size]
        score = score.view(-1, 1) + log_prob

        # score, top_k_idx: [batch_size, beam_size]
        score, top_k_idx = score.view(batch_size, -1).topk(beam_size, dim=1)

        # input: [batch_size * beam_size]
        input = (top_k_idx % self.config.vocab_size).view(-1)

        # beam_idx: [batch_size, beam_size]; floor division, since the flat
        # index encodes (beam, token) as beam * vocab_size + token.
        beam_idx = top_k_idx // self.config.vocab_size

        # top_k_pointer: [batch_size * beam_size]
        top_k_pointer = (beam_idx + batch_position.unsqueeze(1)).view(-1)

        # [num_layers, batch_size * beam_size, hidden_size]
        dec_hidden = dec_hidden.index_select(1, top_k_pointer)

        # Update sequence scores at beam
        beam.update(score.clone(), top_k_pointer, input)

        # Erase scores for EOS so that finished beams are not expanded
        # eos_idx: [batch_size, beam_size]
        eos_idx = input.data.eq(EOS_ID).view(batch_size, beam_size)
        if eos_idx.any():
            score.data.masked_fill_(eos_idx, -float('inf'))

    prediction, final_score, length = beam.backtrack()
    return prediction, final_score, length
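# Worked, standalone example of the index arithmetic used in beam_decode:
# when per-beam scores are flattened to [batch, beam * vocab], a flat top-k
# index decomposes into (source beam, emitted token) via floor division and
# modulo (toy sizes below, not the model's).
import torch

vocab_size = 10
score = torch.randn(2, 3 * vocab_size)       # [batch, beam * vocab]
_, top_k_idx = score.topk(3, dim=1)
beam_idx = top_k_idx // vocab_size           # which beam each candidate extends
token_idx = top_k_idx % vocab_size           # which token it emits
assert torch.equal(beam_idx * vocab_size + token_idx, top_k_idx)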
def infer_wd(model, sp_results, enc_memory_bank, enc_memory_lengths,
             batch_size, opt, eos_symbols):
    """
    Args:
        sp_results: a list of planner decoding results, each with 5 fields:
            1) "sent_num": number of sentences in this sample
            2) "stype_id": a list of tensors, each a LongTensor indicating sentence type
            3) "stype_onehot": a list of tensors, each a one-hot encoding of sentence type
            4) "dec_outs": a list of tensors, each of dimension [1, 512], the planner's hidden states
            5) "content_selection_preds": a list of tensors, each a binary vector encoding phrase selection
        enc_memory_bank: [(batch_size * beam_size) x max_src_len x 512] encoder hidden states, tiled
        enc_memory_lengths: [(batch_size * beam_size)] source sequence lengths, tiled
        batch_size
        opt
        eos_symbols: a list of word ids for sentence-ending symbols, used to advance the sentence id
    Returns:
        wd_results: per-sample dicts with word predictions, sentence ids,
            sentence types, scores, and attention
    """
    # Tile decoder states across beams.
    model.wd_dec.map_state(lambda state, dim: utils.tile(state, opt.beam_size, dim=dim))

    scorer = GNMTGlobalScorer(alpha=0., beta=0., length_penalty='none', cov_penalty="none")
    beam = [Beam(size=opt.beam_size, pad=utils.PAD_id, bos=utils.SOS_id,
                 eos_lst=eos_symbols, eos=utils.EOS_id, n_best=1, cuda=True,
                 global_scorer=scorer, min_length=opt.min_target_words,
                 max_sent_num=item["sent_num"], stepwise_penalty=False,
                 block_ngram_repeat=opt.block_ngram_repeat,
                 exclusion_tokens=set())
            for item in sp_results]

    sent_ids = torch.zeros(batch_size * opt.beam_size, dtype=torch.long).cuda()
    wd_results = [{"word_preds": [], "sent_ids": [], "sent_type": [],
                   "scores": [], "attention": []} for _ in range(batch_size)]

    # sp_dec_outs: [batch_size x max_sent_num x 512]
    sp_dec_outs_tmp = [torch.cat(s["dec_outs"]) for s in sp_results]
    sp_dec_outs = torch.cat([k.unsqueeze(0) for k in sp_dec_outs_tmp])
    sp_dec_outs = utils.tile(sp_dec_outs, opt.beam_size, dim=0)

    # stype_preds: [batch_size x max_sent_num x 4]
    stype_preds_tmp = [torch.cat(k["stype_onehot"]) for k in sp_results]
    stype_preds = torch.cat([k.unsqueeze(0) for k in stype_preds_tmp])
    stype_preds = utils.tile(stype_preds, opt.beam_size, dim=0)

    def pick_sentence_states(sent_ids):
        # cur_stype_onehot: [(batch_size * beam_size) x 1 x 4]
        sent_id_expanded_stype = sent_ids.clone().cuda().unsqueeze(-1).unsqueeze(-1).expand(-1, 1, 4)
        cur_stype_onehot = torch.gather(stype_preds, 1, sent_id_expanded_stype)
        # cur_dec_outs: [(batch_size * beam_size) x 1 x 512]
        sent_id_expanded_dec_outs = sent_ids.clone().cuda().unsqueeze(-1).unsqueeze(-1).expand(-1, 1, 512)
        cur_dec_outs = torch.gather(sp_dec_outs, 1, sent_id_expanded_dec_outs)
        return cur_stype_onehot, cur_dec_outs

    cur_stype_onehot, cur_dec_outs = pick_sentence_states(sent_ids)

    steps_executed = 0
    for word_t in range(opt.max_tgt_words):
        if all(b.done() for b in beam):
            break
        steps_executed += 1

        # word_input: [(batch_size * beam_size) x 1]
        # word_input_emb: [(batch_size * beam_size) x 1 x 300]
        word_input = torch.stack([b.get_current_state() for b in beam])
        word_input = word_input.view(-1, 1)
        word_input_emb = model.word_emb(word_input)

        enc_attn, wd_logits = model.wd_dec.forward_onestep(
            word_inputs_emb=word_input_emb,
            sent_planner_output=cur_dec_outs,
            enc_memory_bank=enc_memory_bank,
            enc_memory_len=enc_memory_lengths,
            stype_one_hot=cur_stype_onehot)

        # wd_probs: [batch_size x beam_size x vocab_size]
        # beam_attn: [batch_size x beam_size x max_src_len]
        wd_probs = model.wd_dec.softmax(wd_logits).view(batch_size, opt.beam_size, -1)
        beam_attn = enc_attn.view(batch_size, opt.beam_size, -1)

        select_indices_array = []
        sid_changed = []
        for sample_id, b in enumerate(beam):
            cur_sid_changed = b.advance(
                wd_probs[sample_id, :],
                beam_attn.data[sample_id, :, :enc_memory_lengths[sample_id]])
            select_indices_array.append(b.get_current_origin() + sample_id * opt.beam_size)
            sid_changed.extend(cur_sid_changed)
        select_indices = torch.cat(select_indices_array)
        model.wd_dec.map_state(lambda state, dim: state.index_select(dim, select_indices))

        # Re-fetch planner states whenever any beam moved to a new sentence.
        if sum(sid_changed) > 0 or word_t == 0:
            new_sent_ids_array = []
            for sample_id, b in enumerate(beam):
                new_sent_ids_array.append(b.get_current_sent_id())
            new_sent_ids = torch.cat(new_sent_ids_array)
            cur_stype_onehot, cur_dec_outs = pick_sentence_states(new_sent_ids)

    for sample_id, b in enumerate(beam):
        scores, ks = b.sort_finished(minimum=1)
        hyps, attn, sent_ids = [], [], []
        for i, (times, k) in enumerate(ks[:1]):
            hyp, att, sent_id = b.get_hyp(times, k)
            hyps.append(hyp)
            attn.append(att)
            sent_ids.append(sent_id)
        wd_results[sample_id]["word_preds"] = [wid.cpu().tolist() for wid in hyps[0]]
        wd_results[sample_id]["scores"] = scores
        wd_results[sample_id]["attention"] = attn
        wd_results[sample_id]["sent_ids"] = [sid.cpu().tolist() for sid in sent_ids[0]]

    print("decoding finished for batch, steps executed=%d" % steps_executed)
    return wd_results
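# Standalone sketch of the gather in pick_sentence_states: for each
# (batch * beam) row, select the planner state of its current sentence id
# (dimensions shrunk to toy sizes).
import torch

states = torch.arange(24, dtype=torch.float).view(2, 3, 4)  # [rows, sents, dim]
sent_ids = torch.tensor([0, 2])                             # current sentence per row
idx = sent_ids.view(-1, 1, 1).expand(-1, 1, 4)              # [rows, 1, dim]
cur = torch.gather(states, 1, idx)                          # [rows, 1, dim]
assert torch.equal(cur[0, 0], states[0, 0])
assert torch.equal(cur[1, 0], states[1, 2])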
def eval_epoch(args, model, test_dataloader, tokenizer, device, n_gpu,
               nlgEvalObj=None, test_set=None):
    if hasattr(model, 'module'):
        model = model.module.to(device)
    if model._stage_one:
        return 0.

    all_result_lists = []
    all_caption_lists = []
    model.eval()
    for batch in test_dataloader:
        batch = tuple(t.to(device, non_blocking=True) for t in batch)
        input_ids, input_mask, segment_ids, video, video_mask, \
            pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
            pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids = batch

        with torch.no_grad():
            sequence_output, visual_output = model.get_sequence_visual_output(
                input_ids, segment_ids, input_mask, video, video_mask)

            # -- Repeat data for beam search
            n_bm = 5  # beam_size
            device = sequence_output.device
            n_inst, len_s, d_h = sequence_output.size()
            _, len_v, v_h = visual_output.size()

            decoder = model.decoder_caption

            # Note: shaped first, then the decoder needs the parameter shaped=True
            input_ids = input_ids.view(-1, input_ids.shape[-1])
            input_mask = input_mask.view(-1, input_mask.shape[-1])
            video_mask = video_mask.view(-1, video_mask.shape[-1])

            sequence_output_rpt = sequence_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
            visual_output_rpt = visual_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_v, v_h)
            input_ids_rpt = input_ids.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            input_mask_rpt = input_mask.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            video_mask_rpt = video_mask.repeat(1, n_bm).view(n_inst * n_bm, len_v)

            # -- Prepare beams
            inst_dec_beams = [Beam(n_bm, device=device, tokenizer=tokenizer)
                              for _ in range(n_inst)]

            # -- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

            # -- Decode
            for len_dec_seq in range(1, args.max_words + 1):
                active_inst_idx_list = beam_decode_step(
                    decoder, inst_dec_beams, len_dec_seq, inst_idx_to_position_map,
                    n_bm, device,
                    (sequence_output_rpt, visual_output_rpt, input_ids_rpt,
                     input_mask_rpt, video_mask_rpt))
                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>
                (sequence_output_rpt, visual_output_rpt, input_ids_rpt,
                 input_mask_rpt, video_mask_rpt), inst_idx_to_position_map = \
                    collate_active_info(
                        (sequence_output_rpt, visual_output_rpt, input_ids_rpt,
                         input_mask_rpt, video_mask_rpt),
                        inst_idx_to_position_map, active_inst_idx_list, n_bm, device)

            batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, 1)
            result_list = [batch_hyp[i][0] for i in range(n_inst)]

            pairs_output_caption_ids = pairs_output_caption_ids.view(
                -1, pairs_output_caption_ids.shape[-1])
            caption_list = pairs_output_caption_ids.cpu().detach().numpy()

            for re_idx, re_list in enumerate(result_list):
                decode_text_list = tokenizer.convert_ids_to_tokens(re_list)
                if "[SEP]" in decode_text_list:
                    SEP_index = decode_text_list.index("[SEP]")
                    decode_text_list = decode_text_list[:SEP_index]
                if "[PAD]" in decode_text_list:
                    PAD_index = decode_text_list.index("[PAD]")
                    decode_text_list = decode_text_list[:PAD_index]
                decode_text = ' '.join(decode_text_list)
                decode_text = decode_text.replace(" ##", "").strip("##").strip()
                all_result_lists.append(decode_text)

            for re_idx, re_list in enumerate(caption_list):
                decode_text_list = tokenizer.convert_ids_to_tokens(re_list)
                if "[SEP]" in decode_text_list:
                    SEP_index = decode_text_list.index("[SEP]")
                    decode_text_list = decode_text_list[:SEP_index]
                if "[PAD]" in decode_text_list:
                    PAD_index = decode_text_list.index("[PAD]")
                    decode_text_list = decode_text_list[:PAD_index]
                decode_text = ' '.join(decode_text_list)
                decode_text = decode_text.replace(" ##", "").strip("##").strip()
                all_caption_lists.append(decode_text)

    # Save full results
    if test_set is not None and hasattr(test_set, 'iter2video_pairs_dict'):
        hyp_path = os.path.join(args.output_dir, "hyp_complete_results.txt")
        with open(hyp_path, "w", encoding='utf-8') as writer:
            writer.write("{}\t{}\t{}\n".format("video_id", "start_time", "caption"))
            for idx, pre_txt in enumerate(all_result_lists):
                video_id, sub_id = test_set.iter2video_pairs_dict[idx]
                start_time = test_set.data_dict[video_id]['start'][sub_id]
                writer.write("{}\t{}\t{}\n".format(video_id, start_time, pre_txt))
        logger.info("File of complete results is saved in {}".format(hyp_path))

    # Save pure results
    hyp_path = os.path.join(args.output_dir, "hyp.txt")
    with open(hyp_path, "w", encoding='utf-8') as writer:
        for pre_txt in all_result_lists:
            writer.write(pre_txt + "\n")

    ref_path = os.path.join(args.output_dir, "ref.txt")
    with open(ref_path, "w", encoding='utf-8') as writer:
        for ground_txt in all_caption_lists:
            writer.write(ground_txt + "\n")

    if args.datatype == "msrvtt":
        # MSR-VTT provides multiple reference captions per video.
        all_caption_lists = []
        sentences_dict = test_dataloader.dataset.sentences_dict
        video_sentences_dict = test_dataloader.dataset.video_sentences_dict
        for idx in range(len(sentences_dict)):
            video_id, _ = sentences_dict[idx]
            sentences = video_sentences_dict[video_id]
            all_caption_lists.append(sentences)
        all_caption_lists = [list(itms) for itms in zip(*all_caption_lists)]
    else:
        all_caption_lists = [all_caption_lists]

    # Evaluate
    metrics_nlg = nlgEvalObj.compute_metrics(ref_list=all_caption_lists,
                                             hyp_list=all_result_lists)
    logger.info(">>> BLEU_1: {:.4f}, BLEU_2: {:.4f}, BLEU_3: {:.4f}, BLEU_4: {:.4f}".
                format(metrics_nlg["Bleu_1"], metrics_nlg["Bleu_2"],
                       metrics_nlg["Bleu_3"], metrics_nlg["Bleu_4"]))
    logger.info(">>> METEOR: {:.4f}, ROUGE_L: {:.4f}, CIDEr: {:.4f}".format(
        metrics_nlg["METEOR"], metrics_nlg["ROUGE_L"], metrics_nlg["CIDEr"]))
    Bleu_4 = metrics_nlg["Bleu_4"]
    return Bleu_4
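# Standalone sketch of the WordPiece detokenization used in eval_epoch:
# subword continuations are marked with "##" and get glued back on.
tokens = ["a", "cat", "sit", "##ting", "on", "the", "mat"]
text = " ".join(tokens).replace(" ##", "").strip("##").strip()
assert text == "a cat sitting on the mat"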
def test_get_mcerd_params(self):
    beam = Beam(ion=Element.from_string("4He"), energy=14)
    self.assertEqual(["Beam ion: 4He", "Beam energy: 14 MeV"],
                     beam.get_mcerd_params())
def get_beam() -> Beam:
    """Returns a default Beam object."""
    return Beam()
def generate(self, batch):
    """Run greedy decoding for sentence decoding (sp_dec), and beam search
    for token decoding (wd_dec)."""
    batch_size = len(batch["id"])
    enc_outs, enc_final = self.model.enc(batch["enc_src"], batch["enc_src_len"])
    memory_bank = tile(enc_outs, self.beam_size, dim=0)
    memory_lengths = tile(batch["enc_src_len"], self.beam_size)
    self.model.sp_dec.init_state(encoder_final=enc_final)
    self.model.wd_dec.init_state(encoder_final=enc_final)
    self.model.wd_dec.map_state(
        lambda state, dim: tile(state, self.beam_size, dim=dim))

    if self.use_goldstandard_plan:
        sp_outputs = self.sentence_decoding_goldstandard(
            ph_bank=batch['ph_bank_tensor'],
            ph_bank_len=batch['ph_bank_len_tensor'],
            ph_sel_tensor=batch['ph_sel_tensor'],
            stype=batch['sent_types'])
    else:
        sp_outputs = self.sentence_decoding_greedy(
            ph_bank=batch["ph_bank_tensor"],
            ph_bank_len=batch["ph_bank_len_tensor"])

    ph_sel_results = sp_outputs[0]
    stype_results = sp_outputs[1]
    sp_hidden_states = sp_outputs[2]
    _, max_pred_sent_num, sp_hidden_dim = sp_hidden_states.shape

    if not self.quiet:
        # print phrase selection and sentence type results
        self.print_sp_results(ph_sel_results, stype_results)

    softmax = nn.Softmax(dim=-1)
    beams = [Beam(self.beam_size, vocab=self.vocab, min_length=10,
                  block_ngram_repeat=self.block_ngram_repeat)
             for _ in range(batch_size)]

    # first sentence has only <SOS>, therefore should only output one token
    max_token_in_sents = [1] + [self.max_token_per_sentence
                                for _ in range(max_pred_sent_num - 1)]

    for sent_id in range(max_pred_sent_num):
        cur_sp_dec_outs = sp_hidden_states[:, sent_id, :].unsqueeze(1)
        cur_sp_dec_outs_tile = tile(cur_sp_dec_outs, self.beam_size, 0)
        for step in range(max_token_in_sents[sent_id]):
            if all((b.done() for b in beams)):
                break
            word_dec_input = torch.stack([b.get_current_state() for b in beams])
            word_dec_input = word_dec_input.view(-1, 1, 1).cuda()
            wd_outputs = self.model.wd_dec.forward_onestep(
                dec_inputs=word_dec_input,
                enc_memory_bank=memory_bank,
                enc_memory_len=memory_lengths,
                sp_hidden_state=cur_sp_dec_outs_tile)
            dec_logits = wd_outputs[0]
            dec_probs = softmax(dec_logits).view(batch_size, self.beam_size, -1)

            select_indices_array = []
            for ix, beam in enumerate(beams):
                beam.advance(probs=dec_probs[ix, :])
                select_indices_array.append(beam.get_current_origin() + ix * self.beam_size)
            select_indices = torch.cat(select_indices_array)
            self.model.wd_dec.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

            if not self.quiet:
                # print results in the top beam in all instances
                to_print = f"sentence {sent_id:<2d} step-{step:<2d} | "
                for ix, beam in enumerate(beams):
                    cur_words = beam.get_current_state()
                    top_beam = cur_words[0].item()
                    top_beam_word = self.vocab.get_word(idx=top_beam)
                    to_print += f" {top_beam_word:<10s} | "
                print(to_print)

    results = []
    for b in beams:
        scores, ks = b.sort_finished(minimum=1)
        hyps = []
        for i, (times, k) in enumerate(ks[:1]):
            hyp = b.get_hyp(times, k)
            hyps.append([tok_id.item() for tok_id in hyp])
        results.append(hyps)

    stype_results_str = []
    for b in range(batch_size):
        cur_stype = [stype_step[b].item() for stype_step in stype_results]
        end = cur_stype.index(2) if 2 in cur_stype else len(cur_stype)
        cur_stype = cur_stype[:end]
        stype_results_str.append([self.stype_map[item] for item in cur_stype])

    ph_sel_results_str = []
    for b in range(batch_size):
        cur_ph_sel = [ph_step[b] for ph_step in ph_sel_results[1:]]
        cur_ph_sel_str = []
        for sent in cur_ph_sel:
            sent = sent[sent.sum(-1) > 0]
            sent_str = [' '.join(self.vocab.decode(item)) for item in sent]
            cur_ph_sel_str.append(sent_str)
            if sent_str == ['EOS']:
                break
        ph_sel_results_str.append(cur_ph_sel_str)

    return results, stype_results_str, ph_sel_results_str
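# Standalone sketch of the offset arithmetic in generate: every beam reports
# origins in [0, beam_size), but all samples share one flat (batch * beam)
# state tensor, so each sample's origins are shifted by ix * beam_size
# before the shared index_select.
import torch

beam_size = 3
origins = [torch.tensor([2, 0, 1]), torch.tensor([1, 1, 0])]  # per sample
select = torch.cat([o + ix * beam_size for ix, o in enumerate(origins)])
assert select.tolist() == [2, 0, 1, 4, 4, 3]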
def _gen_beam(self, keyword_ids, normal_vector, **kwargs):
    device = normal_vector.device
    _, batch_size, latent_dim = normal_vector.size()
    max_seq_len = self.opts.gen_max_seq_len
    beam_width = kwargs["beam_width"]
    length_norm = kwargs["length_norm"]
    n_best = kwargs["n_best"]

    input = torch.full((1, batch_size), fill_value=self.SOS_token,
                       dtype=torch.long, device=device)
    keyword_embs = self.embedding.forward_word_emb(keyword_ids)
    hidden = self._init_states(keyword_embs, "fwd")
    output_step = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
    back_pointers = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
    batch_beams = [Beam(beam_width, length_norm, self.EOS_token, n_best)
                   for _ in range(batch_size)]

    # First step: score the <SOS> prefix, then fan out to beam_width beams.
    logits_step, hidden = self._gen_forward_step(input, hidden, normal_vector[0],
                                                 use_cache=False)
    step_batch_beams(batch_beams, logits_step, output_step, func="init_beams")

    # Remaining steps: tile inputs, states, and noise across the beam dimension.
    input = input.repeat_interleave(beam_width, dim=1)
    normal_vector = normal_vector.repeat_interleave(beam_width, dim=1)
    input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
    hidden = hidden.repeat_interleave(beam_width, dim=1)
    for step in range(1, max_seq_len):
        logits_step, hidden = self._gen_forward_step(input, hidden,
                                                     normal_vector[step],
                                                     use_cache=False)
        logits_step = logits_step.view(batch_size, beam_width, -1)
        step_batch_beams(batch_beams, logits_step, output_step, back_pointers,
                         func="update_beams")
        if all(b.done for b in batch_beams):
            break
        # Reorder prefixes and recurrent states to follow the surviving beams.
        input = input.index_select(dim=1, index=back_pointers)
        input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
        hidden = hidden.index_select(dim=1, index=back_pointers)

    output = list(chain(*(beam.get_best_results()[0] for beam in batch_beams)))
    output = bidirectional_padding(output, self.PAD_token, 0, device=device)[0]
    return output
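# Standalone sketch (toy tensors, not part of the model) of why both
# _gen_beam variants tile with repeat_interleave rather than repeat:
# repeat_interleave keeps each sample's beam_width copies contiguous, which
# is what view(batch_size, beam_width, -1) above assumes; repeat would
# cycle through the whole batch instead.
import torch

x = torch.tensor([[1], [2]])                     # batch of 2 samples
print(x.repeat_interleave(3, dim=0).squeeze(1))  # tensor([1, 1, 1, 2, 2, 2])
print(x.repeat(3, 1).squeeze(1))                 # tensor([1, 2, 1, 2, 1, 2])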