Example No. 1
    def _gen_beam(self, keyword_ids, normal_vector, **kwargs):
        dtype = normal_vector.dtype
        device = normal_vector.device
        batch_size, latent_dim = normal_vector.size()
        max_seq_len = self.opts.gen_max_seq_len
        beam_width = kwargs["beam_width"]
        length_norm = kwargs["length_norm"]
        n_best = kwargs["n_best"]

        keyword_embs = None
        if keyword_ids is not None:
            keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        latent_out = self.latent_module.forward_gen_path(keyword_embs, normal_vector,
                                                         head_dims=[], batch_size=batch_size,
                                                         dtype=dtype, device=device)[1].squeeze(0)

        input = torch.full((1, batch_size), fill_value=self.SOS_token, dtype=torch.long, device=device)
        output_step = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
        back_pointers = torch.zeros(batch_size * beam_width, dtype=torch.long, device=device)
        batch_beams = [Beam(beam_width, length_norm, self.EOS_token, n_best) for _ in range(batch_size)]

        # first step
        logits_step = self._gen_forward_step(input, keyword_ids, keyword_embs, latent_out, use_cache=False)[-1]
        step_batch_beams(batch_beams, logits_step, output_step, func="init_beams")
        if keyword_ids is not None:
            keyword_ids = keyword_ids.repeat_interleave(beam_width, dim=0)
        if "ktoken" in self.keyword_approaches:
            mask = output_step == self.KEYWORD_token
            output_step[mask] = keyword_ids[mask]

        # remain steps
        input = input.repeat_interleave(beam_width, dim=1)
        input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
        latent_out = latent_out.repeat_interleave(beam_width, dim=0)
        if keyword_embs is not None:
            keyword_embs = keyword_embs.repeat_interleave(beam_width, dim=0)
        for _ in range(1, max_seq_len):
            logits = self._gen_forward_step(input, keyword_ids, keyword_embs, latent_out, use_cache=False)
            logits_step = logits[-1].view(batch_size, beam_width, -1)
            step_batch_beams(batch_beams, logits_step, output_step, back_pointers, func="update_beams")
            if all(b.done for b in batch_beams):
                break
            if "ktoken" in self.keyword_approaches:
                mask = output_step == self.KEYWORD_token
                output_step[mask] = keyword_ids[mask]
            input = input.index_select(dim=1, index=back_pointers)
            input = torch.cat([input, output_step.unsqueeze(0)], dim=0)

        output = list(chain(*(beam.get_best_results()[0] for beam in batch_beams)))
        output = bidirectional_padding(output, self.PAD_token, 0, device=device)[0]

        return output
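The Beam class and the step_batch_beams helper are not shown in this excerpt; the sketch below covers only the tiling and back-pointer bookkeeping the loop above relies on, with invented sizes and indices.

import torch

# Minimal sketch of the repeat_interleave / index_select bookkeeping, assuming
# batch_size=2 and beam_width=3. back_pointers and output_step are invented here;
# in the code above they are filled in by step_batch_beams.
batch_size, beam_width = 2, 3
input = torch.tensor([[101, 102]])                       # [1, batch_size], e.g. <SOS> ids
input = input.repeat_interleave(beam_width, dim=1)       # [1, batch_size * beam_width]

back_pointers = torch.tensor([0, 0, 2, 3, 5, 5])         # parent column of each surviving beam
output_step = torch.tensor([7, 8, 9, 10, 11, 12])        # token chosen for each surviving beam

input = input.index_select(dim=1, index=back_pointers)   # reorder histories to match survivors
input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
print(input.shape)                                       # torch.Size([2, 6])
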
Example No. 2
    def test_profile(self):
        self.assertEqual(Profile.UNIFORM, Beam().profile)
        self.assertEqual(Profile.GAUSSIAN,
                         Beam(profile=Profile.GAUSSIAN).profile)
        self.assertEqual(Profile.UNIFORM, Beam(profile="unIfoRm").profile)
        self.assertEqual(Profile.GAUSSIAN, Beam(profile="gaUssIan").profile)

        self.assertEqual(Profile.UNIFORM, Beam(profile="gaussiann").profile)

        self.assertEqual(Profile.UNIFORM, Beam(profile=None).profile)
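The Profile enum and the Beam constructor are not part of this excerpt; a minimal sketch of parsing behaviour that would satisfy the assertions above (the enum value strings and the fallback logic are assumptions) could look like this:

from enum import Enum

class Profile(Enum):
    UNIFORM = "uniform"
    GAUSSIAN = "gaussian"

class Beam:
    """Hypothetical stand-in; the real class is not shown in this excerpt."""

    def __init__(self, profile=Profile.UNIFORM):
        if isinstance(profile, Profile):
            self.profile = profile
        else:
            try:
                # Accept case-insensitive strings such as "unIfoRm" or "gaUssIan".
                self.profile = Profile(str(profile).lower())
            except ValueError:
                # None and unrecognised strings ("gaussiann") fall back to the default.
                self.profile = Profile.UNIFORM
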
Example No. 3
def beamsearch(memory,
               model,
               beam_size=4,
               candidates=1,
               max_seq_length=128,
               sos_token=1,
               eos_token=2):
    # memory: Tx1xE
    model.eval()
    device = memory.device

    beam = Beam(beam_size=beam_size,
                min_length=0,
                n_top=candidates,
                ranker=None,
                start_token_id=sos_token,
                end_token_id=eos_token)

    with torch.no_grad():
        memory = memory.repeat(1, beam_size, 1)  # TxNxE

        for _ in range(max_seq_length):

            tgt_inp = beam.get_current_state().transpose(0,
                                                         1).to(device)  # TxN
            decoder_outputs = model.transformer.forward_decoder(
                tgt_inp, memory)

            log_prob = log_softmax(decoder_outputs[:, -1, :].squeeze(0),
                                   dim=-1)
            beam.advance(log_prob.cpu())

            if beam.done():
                break

        scores, ks = beam.sort_finished(minimum=1)

        hypothesises = []
        for i, (times, k) in enumerate(ks[:candidates]):
            hypothesis = beam.get_hypothesis(times, k)
            hypothesises.append(hypothesis)

    return [sos_token] + [int(i) for i in hypothesises[0][:-1]]
Example No. 4
    def beam_decode(self,
                    dec_hidden,
                    enc_outputs,
                    enc_length):
        '''
        Args:
            dec_hidden : [num_layers, batch_size, hidden_size] (optional)
        Return:
            prediction: [batch_size, beam, max_len]
        '''
        batch_size, beam_size = self.config.batch_size, self.config.beam_size

        # [1, batch_size x beam_size]
        input = torch.ones(1, batch_size * beam_size,
                               dtype=torch.long,
                               device=self.device) * SOS_ID

        # [num_layers, batch_size * beam_size, hidden_size]
        dec_hidden = dec_hidden.repeat(1, beam_size, 1)

        # [1, batch_size * beam_size, hidden_size]
        enc_outputs = enc_outputs.repeat(1, beam_size, 1)

        # [batch_size * beam_size]
        enc_length = enc_length.repeat(beam_size)

        # [batch_size] [0, beam_size * 1, ..., beam_size * (batch_size - 1)]
        batch_position = torch.arange(
            0, batch_size, dtype=torch.long, device=self.device) * beam_size

        score = torch.ones(batch_size * beam_size,
                           device=self.device) * -float('inf')

        score.index_fill_(0, torch.arange(
            0, batch_size, dtype=torch.long, device=self.device) * beam_size, 0.0)

        # Initialize Beam that stores decisions for backtracking
        beam = Beam(
            batch_size,
            beam_size,
            self.config.r_max_len,
            batch_position,
            EOS_ID
        )

        for i in range(self.config.r_max_len):
            output, dec_hidden, _ = self.rnn_decoder(
                input.view(1, -1),
                dec_hidden,
                enc_outputs,
                enc_length,
            )

            # output: [1, batch_size * beam_size, vocab_size]
            # -> [batch_size * beam_size, vocab_size]
            log_prob = output.squeeze(0)
            #  print('log_prob: ', log_prob.shape)

            # score: [batch_size * beam_size, vocab_size]
            score = score.view(-1, 1) + log_prob

            # score [batch_size, beam_size]
            score, top_k_idx = score.view(
                batch_size, -1).topk(beam_size, dim=1)

            # input: [batch_size x beam_size]
            input = (top_k_idx % self.config.vocab_size).view(-1)

            # beam_idx: [batch_size, beam_size]
            beam_idx = top_k_idx // self.config.vocab_size

            # top_k_pointer: [batch_size * beam_size]
            top_k_pointer = (beam_idx + batch_position.unsqueeze(1)).view(-1)

            # [num_layers, batch_size * beam_size, hidden_size]
            dec_hidden = dec_hidden.index_select(1, top_k_pointer)

            # Update sequence scores at beam
            beam.update(score.clone(), top_k_pointer, input)

            # Erase scores for EOS so that they are not expanded
            # [batch_size, beam_size]
            eos_idx = input.data.eq(EOS_ID).view(
                batch_size, beam_size)

            if eos_idx.any():
                score.data.masked_fill_(eos_idx, -float('inf'))

        prediction, final_score, length = beam.backtrack()

        return prediction, final_score, length
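The score accumulation and back-pointer arithmetic in the loop above are self-contained enough to demonstrate in isolation; the following sketch uses made-up sizes and random logits:

import torch

# Minimal sketch of one beam step: accumulate log-probabilities, keep the top
# beam_size continuations per sample, and derive the token and parent-beam indices.
batch_size, beam_size, vocab_size = 2, 3, 7
score = torch.full((batch_size * beam_size,), float('-inf'))
score[::beam_size] = 0.0                                  # only beam 0 of each sample is live at step 0
log_prob = torch.randn(batch_size * beam_size, vocab_size).log_softmax(dim=-1)

score = score.view(-1, 1) + log_prob                      # accumulated scores per candidate token
score, top_k_idx = score.view(batch_size, -1).topk(beam_size, dim=1)

next_input = (top_k_idx % vocab_size).view(-1)            # token id chosen for each surviving beam
beam_idx = top_k_idx // vocab_size                        # parent beam within each sample
batch_position = torch.arange(batch_size) * beam_size
top_k_pointer = (beam_idx + batch_position.unsqueeze(1)).view(-1)
print(next_input.shape, top_k_pointer.shape)              # torch.Size([6]) torch.Size([6])
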
Example No. 5
def infer_wd(model, sp_results, enc_memory_bank, enc_memory_lengths, batch_size, opt, eos_symbols):
    """
    Args:
        tensor_data_dict:
            "ph_bank_tensor": [batch_size x max_ph_size x max_ph_len] tensor of phrase word ids
            "ph_bank_word_mask_tensor": [batch_size x max_ph_size x max_ph_len] tensor of phrase word mask
        sp_results: a list of planner decoding results, each consists of 5 fields:
            1) "sent_num": number of sentences in this sample
            2) "stype_id": a list of tensors, each is a LongTensor indicating sentence type
            3) "stype_onehot": a list of tensors, each is a onehot decoding indicating sentence type
            4) "dec_outs": a list of tensors, each is of dimension [1, 512], indicating planner's hidden states
            5) "content_selection_preds": a list of tensors, each is a binary vector encoding phrase selection
        enc_memory_bank: [(batch_size * beam_size) x max_src_len x 512] size of encoder hidden states, tiled
        enc_memory_lengths: [(batch_size * beam_size)] size of source sequence, tiled
        batch_size
        opt
        eos_symbols: a list of word id for symbols that end sentences, used to change sentence id
    Returns:
        wd_results: a list of per-sample dicts with keys "word_preds", "sent_ids",
            "sent_type", "scores", and "attention"
    """

    # do tiling on states
    # ph_bank_vec: [batch_size x max_ph_size x 300]
    model.wd_dec.map_state(lambda state, dim: utils.tile(state, opt.beam_size, dim=dim))

    scorer = GNMTGlobalScorer(alpha=0., beta=0., length_penalty='none',
                              cov_penalty="none")

    beam = [Beam(size=opt.beam_size, pad=utils.PAD_id, bos=utils.SOS_id,
                 eos_lst=eos_symbols,
                 eos=utils.EOS_id,
                 n_best=1, cuda=True,
                 global_scorer=scorer,
                 min_length=opt.min_target_words,
                 max_sent_num=item["sent_num"],
                 stepwise_penalty=False,
                 block_ngram_repeat=opt.block_ngram_repeat,
                 exclusion_tokens=set()) for item in sp_results]

    sent_ids = torch.zeros(batch_size * opt.beam_size, dtype=torch.long).cuda()

    wd_results = [{"word_preds": [], "sent_ids": [], "sent_type": [], "scores": [], "attention": []}
                  for _ in range(batch_size)]

    # sp_dec_outs: [batch_size x max_sent_num x 512]
    sp_dec_outs_tmp = [torch.cat(s["dec_outs"]) for s in sp_results]
    sp_dec_outs = torch.cat([k.unsqueeze(0) for k in sp_dec_outs_tmp])
    sp_dec_outs = utils.tile(sp_dec_outs, opt.beam_size, dim=0)

    # stype_preds: [batch_size x max_sent_num x 4]
    stype_preds_tmp = [torch.cat(k["stype_onehot"]) for k in sp_results]
    stype_preds = torch.cat([k.unsqueeze(0) for k in stype_preds_tmp])
    stype_preds = utils.tile(stype_preds, opt.beam_size, dim=0)

    def pick_sentence_states(sent_ids):
        # cur_stype_onehot: [(batch_size * beam_size) x 1 x 4]
        sent_id_expanded_stype = sent_ids.clone().cuda().unsqueeze(-1).unsqueeze(-1).expand(-1, 1, 4)
        cur_stype_onehot = torch.gather(stype_preds, 1, sent_id_expanded_stype)

        # cur_dec_outs: [(batch_size * beam_size) x 1 x 512]
        sent_id_expanded_dec_outs = sent_ids.clone().cuda().unsqueeze(-1).unsqueeze(-1).expand(-1, 1, 512)
        cur_dec_outs = torch.gather(sp_dec_outs, 1, sent_id_expanded_dec_outs)
        return cur_stype_onehot, cur_dec_outs

    cur_stype_onehot, cur_dec_outs = pick_sentence_states(sent_ids)
    steps_executed = 0

    for word_t in range(opt.max_tgt_words):
        # print("word_t=%d" % word_t)
        if all(b.done() for b in beam):
            break
        steps_executed += 1

        # word_input: [(batch_size * beam_size) x 1 x 1]
        # word_input_emb: [(batch_size * beam_size) x 1 x 300]
        word_input = torch.stack([b.get_current_state() for b in beam])
        word_input = word_input.view(-1, 1)
        word_input_emb = model.word_emb(word_input)

        enc_attn, wd_logits = model.wd_dec.forward_onestep(word_inputs_emb=word_input_emb,
                                                           sent_planner_output=cur_dec_outs,
                                                           enc_memory_bank=enc_memory_bank,
                                                           enc_memory_len=enc_memory_lengths,
                                                           stype_one_hot=cur_stype_onehot)

        # wd_probs: [(batch_size * beam_size) x vocab_size]
        # beam_attn:[(batch_size * beam_size) x max_src_len]
        wd_probs = model.wd_dec.softmax(wd_logits).view(batch_size, opt.beam_size, -1)
        beam_attn = enc_attn.view(batch_size, opt.beam_size, -1)

        select_indices_array = []
        sid_changed = []
        for sample_id, b in enumerate(beam):
            cur_sid_changed = b.advance(wd_probs[sample_id, :],
                                        beam_attn.data[sample_id, :, :enc_memory_lengths[sample_id]])
            select_indices_array.append(b.get_current_origin() + sample_id * opt.beam_size)
            sid_changed.extend(cur_sid_changed)

        select_indices = torch.cat(select_indices_array)

        model.wd_dec.map_state(lambda state, dim: state.index_select(dim, select_indices))
        if sum(sid_changed) > 0 or word_t == 0:
            new_sent_ids_array = []
            for sample_id, b in enumerate(beam):
                new_sent_ids_array.append(b.get_current_sent_id())
            new_sent_ids = torch.cat(new_sent_ids_array)
            cur_stype_onehot, cur_dec_outs = pick_sentence_states(new_sent_ids)

    for sample_id, b in enumerate(beam):
        scores, ks = b.sort_finished(minimum=1)
        hyps, attn, sent_ids = [], [], []
        for i, (times, k) in enumerate(ks[:1]):
            hyp, att, sent_id = b.get_hyp(times, k)
            hyps.append(hyp)
            attn.append(att)
            sent_ids.append(sent_id)
        wd_results[sample_id]["word_preds"] = [wid.cpu().tolist() for wid in hyps[0]]
        wd_results[sample_id]["scores"] = scores
        wd_results[sample_id]["attention"] = attn
        wd_results[sample_id]["sent_ids"] = [sid.cpu().tolist() for sid in sent_ids[0]]
    print("decoding finished for batch, steps executed=%d" % steps_executed)
    return wd_results
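The gather inside pick_sentence_states can be illustrated on its own; the sketch below uses invented sizes and sentence indices:

import torch

# Minimal sketch of selecting one sentence state per beam, assuming
# rows = batch_size * beam_size = 4, max_sent_num = 5, hidden size = 512.
rows, max_sent_num, hidden = 4, 5, 512
sp_dec_outs = torch.randn(rows, max_sent_num, hidden)
sent_ids = torch.tensor([0, 1, 1, 3])                                # current sentence index per beam
index = sent_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, 1, hidden)   # [rows, 1, hidden]
cur_dec_outs = torch.gather(sp_dec_outs, 1, index)                   # [rows, 1, hidden]
print(cur_dec_outs.shape)                                            # torch.Size([4, 1, 512])
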
Example No. 6
def eval_epoch(args,
               model,
               test_dataloader,
               tokenizer,
               device,
               n_gpu,
               nlgEvalObj=None,
               test_set=None):

    if hasattr(model, 'module'):
        model = model.module.to(device)

    if model._stage_one:
        return 0.

    all_result_lists = []
    all_caption_lists = []
    model.eval()
    for batch in test_dataloader:
        batch = tuple(t.to(device, non_blocking=True) for t in batch)

        input_ids, input_mask, segment_ids, video, video_mask, \
        pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
        pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids = batch

        with torch.no_grad():
            sequence_output, visual_output = model.get_sequence_visual_output(
                input_ids, segment_ids, input_mask, video, video_mask)
            # -- Repeat data for beam search
            n_bm = 5  # beam_size
            device = sequence_output.device
            n_inst, len_s, d_h = sequence_output.size()
            _, len_v, v_h = visual_output.size()

            decoder = model.decoder_caption

            # Note: the inputs are reshaped first, so the decoder needs the parameter shaped=True
            input_ids = input_ids.view(-1, input_ids.shape[-1])
            input_mask = input_mask.view(-1, input_mask.shape[-1])
            video_mask = video_mask.view(-1, video_mask.shape[-1])

            sequence_output_rpt = sequence_output.repeat(1, n_bm, 1).view(
                n_inst * n_bm, len_s, d_h)
            visual_output_rpt = visual_output.repeat(1, n_bm, 1).view(
                n_inst * n_bm, len_v, v_h)
            input_ids_rpt = input_ids.repeat(1,
                                             n_bm).view(n_inst * n_bm, len_s)
            input_mask_rpt = input_mask.repeat(1, n_bm).view(
                n_inst * n_bm, len_s)
            video_mask_rpt = video_mask.repeat(1, n_bm).view(
                n_inst * n_bm, len_v)

            # -- Prepare beams
            inst_dec_beams = [
                Beam(n_bm, device=device, tokenizer=tokenizer)
                for _ in range(n_inst)
            ]
            # -- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            # -- Decode
            for len_dec_seq in range(1, args.max_words + 1):
                active_inst_idx_list = beam_decode_step(
                    decoder, inst_dec_beams, len_dec_seq,
                    inst_idx_to_position_map, n_bm, device,
                    (sequence_output_rpt, visual_output_rpt, input_ids_rpt,
                     input_mask_rpt, video_mask_rpt))

                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                (sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt), \
                inst_idx_to_position_map = collate_active_info((sequence_output_rpt, visual_output_rpt, input_ids_rpt, input_mask_rpt, video_mask_rpt),
                                                               inst_idx_to_position_map, active_inst_idx_list, n_bm, device)

            batch_hyp, batch_scores = collect_hypothesis_and_scores(
                inst_dec_beams, 1)
            result_list = [batch_hyp[i][0] for i in range(n_inst)]

            pairs_output_caption_ids = pairs_output_caption_ids.view(
                -1, pairs_output_caption_ids.shape[-1])
            caption_list = pairs_output_caption_ids.cpu().detach().numpy()

            for re_idx, re_list in enumerate(result_list):
                decode_text_list = tokenizer.convert_ids_to_tokens(re_list)
                if "[SEP]" in decode_text_list:
                    SEP_index = decode_text_list.index("[SEP]")
                    decode_text_list = decode_text_list[:SEP_index]
                if "[PAD]" in decode_text_list:
                    PAD_index = decode_text_list.index("[PAD]")
                    decode_text_list = decode_text_list[:PAD_index]
                decode_text = ' '.join(decode_text_list)
                decode_text = decode_text.replace(" ##",
                                                  "").strip("##").strip()
                all_result_lists.append(decode_text)

            for re_idx, re_list in enumerate(caption_list):
                decode_text_list = tokenizer.convert_ids_to_tokens(re_list)
                if "[SEP]" in decode_text_list:
                    SEP_index = decode_text_list.index("[SEP]")
                    decode_text_list = decode_text_list[:SEP_index]
                if "[PAD]" in decode_text_list:
                    PAD_index = decode_text_list.index("[PAD]")
                    decode_text_list = decode_text_list[:PAD_index]
                decode_text = ' '.join(decode_text_list)
                decode_text = decode_text.replace(" ##",
                                                  "").strip("##").strip()
                all_caption_lists.append(decode_text)

    # Save full results
    if test_set is not None and hasattr(test_set, 'iter2video_pairs_dict'):
        hyp_path = os.path.join(args.output_dir, "hyp_complete_results.txt")
        with open(hyp_path, "w", encoding='utf-8') as writer:
            writer.write("{}\t{}\t{}\n".format("video_id", "start_time",
                                               "caption"))
            for idx, pre_txt in enumerate(all_result_lists):
                video_id, sub_id = test_set.iter2video_pairs_dict[idx]
                start_time = test_set.data_dict[video_id]['start'][sub_id]
                writer.write("{}\t{}\t{}\n".format(video_id, start_time,
                                                   pre_txt))
        logger.info("File of complete results is saved in {}".format(hyp_path))

    # Save pure results
    hyp_path = os.path.join(args.output_dir, "hyp.txt")
    with open(hyp_path, "w", encoding='utf-8') as writer:
        for pre_txt in all_result_lists:
            writer.write(pre_txt + "\n")

    ref_path = os.path.join(args.output_dir, "ref.txt")
    with open(ref_path, "w", encoding='utf-8') as writer:
        for ground_txt in all_caption_lists:
            writer.write(ground_txt + "\n")

    if args.datatype == "msrvtt":
        all_caption_lists = []
        sentences_dict = test_dataloader.dataset.sentences_dict
        video_sentences_dict = test_dataloader.dataset.video_sentences_dict
        for idx in range(len(sentences_dict)):
            video_id, _ = sentences_dict[idx]
            sentences = video_sentences_dict[video_id]
            all_caption_lists.append(sentences)
        all_caption_lists = [list(itms) for itms in zip(*all_caption_lists)]
    else:
        all_caption_lists = [all_caption_lists]

    # Evaluate
    metrics_nlg = nlgEvalObj.compute_metrics(ref_list=all_caption_lists,
                                             hyp_list=all_result_lists)
    logger.info(
        ">>>  BLEU_1: {:.4f}, BLEU_2: {:.4f}, BLEU_3: {:.4f}, BLEU_4: {:.4f}".
        format(metrics_nlg["Bleu_1"], metrics_nlg["Bleu_2"],
               metrics_nlg["Bleu_3"], metrics_nlg["Bleu_4"]))
    logger.info(">>>  METEOR: {:.4f}, ROUGE_L: {:.4f}, CIDEr: {:.4f}".format(
        metrics_nlg["METEOR"], metrics_nlg["ROUGE_L"], metrics_nlg["CIDEr"]))

    Bleu_4 = metrics_nlg["Bleu_4"]
    return Bleu_4
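The [SEP]/[PAD] truncation and WordPiece joining applied to each hypothesis above can be run standalone; the token list below is invented:

# Minimal sketch of the post-processing used in the decode loops above.
decode_text_list = ["a", "man", "is", "play", "##ing", "foot", "##ball", "[SEP]", "[PAD]"]
for stop_token in ("[SEP]", "[PAD]"):
    if stop_token in decode_text_list:
        decode_text_list = decode_text_list[:decode_text_list.index(stop_token)]
decode_text = ' '.join(decode_text_list).replace(" ##", "").strip("##").strip()
print(decode_text)  # "a man is playing football"
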
Example No. 7
    def test_get_mcerd_params(self):
        beam = Beam(ion=Element.from_string("4He"), energy=14)
        self.assertEqual(["Beam ion: 4He", "Beam energy: 14 MeV"],
                         beam.get_mcerd_params())
Example No. 8
def get_beam() -> Beam:
    """Returns a default Beam object.
    """
    return Beam()
Example No. 9
    def generate(self, batch):
        """Run greedy decoding for sentence decoding (sp_dec), and beam search
        for token decoding (wd_dec)."""

        batch_size = len(batch["id"])
        enc_outs, enc_final = self.model.enc(batch["enc_src"],
                                             batch["enc_src_len"])

        memory_bank = tile(enc_outs, self.beam_size, dim=0)
        memory_lengths = tile(batch["enc_src_len"], self.beam_size)

        self.model.sp_dec.init_state(encoder_final=enc_final)
        self.model.wd_dec.init_state(encoder_final=enc_final)
        self.model.wd_dec.map_state(
            lambda state, dim: tile(state, self.beam_size, dim=dim))

        if self.use_goldstandard_plan:
            sp_outputs = self.sentence_decoding_goldstandard(
                ph_bank=batch['ph_bank_tensor'],
                ph_bank_len=batch['ph_bank_len_tensor'],
                ph_sel_tensor=batch['ph_sel_tensor'],
                stype=batch['sent_types'])
        else:
            sp_outputs = self.sentence_decoding_greedy(
                ph_bank=batch["ph_bank_tensor"],
                ph_bank_len=batch["ph_bank_len_tensor"])
        ph_sel_results = sp_outputs[0]
        stype_results = sp_outputs[1]
        sp_hidden_states = sp_outputs[2]
        _, max_pred_sent_num, sp_hidden_dim = sp_hidden_states.shape

        if not self.quiet:
            # print phrase selection and sentence type results
            self.print_sp_results(ph_sel_results, stype_results)

        softmax = nn.Softmax(dim=-1)
        beams = [
            Beam(self.beam_size,
                 vocab=self.vocab,
                 min_length=10,
                 block_ngram_repeat=self.block_ngram_repeat)
            for _ in range(batch_size)
        ]

        # first sentence has only <SOS>, therefore should only output one token
        max_token_in_sents = [1] + [
            self.max_token_per_sentence for _ in range(max_pred_sent_num - 1)
        ]

        for sent_id in range(max_pred_sent_num):
            cur_sp_dec_outs = sp_hidden_states[:, sent_id, :].unsqueeze(1)
            cur_sp_dec_outs_tile = tile(cur_sp_dec_outs, self.beam_size, 0)

            for step in range(max_token_in_sents[sent_id]):

                if all(b.done() for b in beams):
                    break

                word_dec_input = torch.stack(
                    [b.get_current_state() for b in beams])
                word_dec_input = word_dec_input.view(-1, 1, 1).cuda()

                wd_outputs = self.model.wd_dec.forward_onestep(
                    dec_inputs=word_dec_input,
                    enc_memory_bank=memory_bank,
                    enc_memory_len=memory_lengths,
                    sp_hidden_state=cur_sp_dec_outs_tile)
                dec_logits = wd_outputs[0]
                dec_probs = softmax(dec_logits).view(batch_size,
                                                     self.beam_size, -1)

                select_indices_array = []
                for ix, beam in enumerate(beams):
                    beam.advance(probs=dec_probs[ix, :])
                    select_indices_array.append(beam.get_current_origin() +
                                                ix * self.beam_size)
                select_indices = torch.cat(select_indices_array)
                self.model.wd_dec.map_state(
                    lambda state, dim: state.index_select(dim, select_indices))

                if not self.quiet:
                    # print results in the top beam in all instances
                    to_print = f"sentence {sent_id:<2d} step-{step:<2d} | "

                    for ix, beam in enumerate(beams):
                        cur_words = beam.get_current_state()
                        top_beam = cur_words[0].item()
                        top_beam_word = self.vocab.get_word(idx=top_beam)
                        to_print += f" {top_beam_word:<10s} | "
                    print(to_print)

        results = []
        for b in beams:
            scores, ks = b.sort_finished(minimum=1)
            hyps = []
            for i, (times, k) in enumerate(ks[:1]):
                hyp = b.get_hyp(times, k)
                hyps.append([tok_id.item() for tok_id in hyp])
            results.append(hyps)

        stype_results_str = []
        for b in range(batch_size):
            cur_stype = [stype_step[b].item() for stype_step in stype_results]
            end = cur_stype.index(2) if 2 in cur_stype else len(cur_stype)
            cur_stype = cur_stype[:end]
            stype_results_str.append(
                [self.stype_map[item] for item in cur_stype])

        ph_sel_results_str = []
        for b in range(batch_size):
            cur_ph_sel = [ph_step[b] for ph_step in ph_sel_results[1:]]
            cur_ph_sel_str = []
            for sent in cur_ph_sel:
                sent = sent[sent.sum(-1) > 0]
                sent_str = [' '.join(self.vocab.decode(item)) for item in sent]
                cur_ph_sel_str.append(sent_str)
                if sent_str == ['EOS']:
                    break
            ph_sel_results_str.append(cur_ph_sel_str)

        return results, stype_results_str, ph_sel_results_str
Example No. 10
    def _gen_beam(self, keyword_ids, normal_vector, **kwargs):
        device = normal_vector.device
        _, batch_size, latent_dim = normal_vector.size()
        max_seq_len = self.opts.gen_max_seq_len
        beam_width = kwargs["beam_width"]
        length_norm = kwargs["length_norm"]
        n_best = kwargs["n_best"]

        input = torch.full((1, batch_size),
                           fill_value=self.SOS_token,
                           dtype=torch.long,
                           device=device)
        keyword_embs = self.embedding.forward_word_emb(keyword_ids)
        hidden = self._init_states(keyword_embs, "fwd")
        output_step = torch.zeros(batch_size * beam_width,
                                  dtype=torch.long,
                                  device=device)
        back_pointers = torch.zeros(batch_size * beam_width,
                                    dtype=torch.long,
                                    device=device)
        batch_beams = [
            Beam(beam_width, length_norm, self.EOS_token, n_best)
            for _ in range(batch_size)
        ]

        # first step
        logits_step, hidden = self._gen_forward_step(input,
                                                     hidden,
                                                     normal_vector[0],
                                                     use_cache=False)
        step_batch_beams(batch_beams,
                         logits_step,
                         output_step,
                         func="init_beams")

        # remain steps
        input = input.repeat_interleave(beam_width, dim=1)
        normal_vector = normal_vector.repeat_interleave(beam_width, dim=1)
        input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
        hidden = hidden.repeat_interleave(beam_width, dim=1)
        for step in range(1, max_seq_len):
            logits_step, hidden = self._gen_forward_step(input,
                                                         hidden,
                                                         normal_vector[step],
                                                         use_cache=False)
            logits_step = logits_step.view(batch_size, beam_width, -1)
            step_batch_beams(batch_beams,
                             logits_step,
                             output_step,
                             back_pointers,
                             func="update_beams")
            if all(b.done for b in batch_beams):
                break
            input = input.index_select(dim=1, index=back_pointers)
            input = torch.cat([input, output_step.unsqueeze(0)], dim=0)
            hidden = hidden.index_select(dim=1, index=back_pointers)

        output = list(
            chain(*(beam.get_best_results()[0] for beam in batch_beams)))
        output = bidirectional_padding(output,
                                       self.PAD_token,
                                       0,
                                       device=device)[0]

        return output