    def sample_beam(self, fc_feats, att_feats, opt={}):

        beam_size = opt.get('beam_size', 10)
        batch_size = att_feats.size(0)

        # fc_feats: batch_size * model_size
        # att_feats: batch_size * att_size * model_size
        fc_feats, att_feats = self.embed_feats(fc_feats, att_feats)

        # fc_feats: (batch_size * beam_size) * model_size
        new_fc_feats_size = (fc_feats.size(0) * beam_size, fc_feats.size(1))
        fc_feats = Variable(fc_feats.data.repeat(
            1, beam_size).view(new_fc_feats_size),
                            volatile=True)

        # att_feats: (batch_size * beam_size) * att_size * model_size
        new_output_enc_size = (att_feats.size(0) * beam_size,
                               att_feats.size(1), att_feats.size(2))
        output_enc = Variable(att_feats.data.repeat(
            1, beam_size, 1).view(new_output_enc_size),
                              volatile=True)
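
        # A quick sanity sketch of the expansion above (comments only, not executed):
        # with batch_size=2 and beam_size=3, repeat(1, 3) turns the (2, model_size)
        # fc_feats into (2, 3*model_size), and the view to (6, model_size) lays the rows
        # out as [x0, x0, x0, x1, x1, x1], so row r belongs to instance r // beam_size.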

        # Prepare beams
        beams = [Beam(beam_size) for _ in range(batch_size)]
        beam_inst_idx_map = {
            beam_idx: inst_idx
            for inst_idx, beam_idx in enumerate(range(batch_size))
        }
        n_remaining_sents = batch_size

        # Decode
        for i in range(self.seq_length + 1):

            len_dec_seq = i + 1

            # (n_remaining_sents*beam_size) * len_dec_seq
            masks = torch.FloatTensor(n_remaining_sents * beam_size,
                                      len_dec_seq).fill_(1)
            masks = Variable(masks.cuda())

            # n_remaining_sents * beam_size * len_dec_seq
            dec_partial_seq = torch.stack(
                [b.get_current_state() for b in beams if not b.done])
            # (n_remaining_sents * beam_size) * len_dec_seq
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
            dec_partial_seq = Variable(dec_partial_seq, volatile=True).cuda()

            # size: 1 * len_dec_seq
            dec_partial_pos = torch.arange(1, len_dec_seq + 1).unsqueeze(0)
            # size: (n_remaining_sents * beam_size) * len_dec_seq
            dec_partial_pos = dec_partial_pos.repeat(
                n_remaining_sents * beam_size, 1)
            dec_partial_pos = Variable(dec_partial_pos.type(torch.LongTensor),
                                       volatile=True).cuda()
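
            # e.g. with len_dec_seq = 3, dec_partial_pos is the row [1, 2, 3] repeated
            # once per (instance, beam) pair, i.e. n_remaining_sents * beam_size rows.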

            # dec_partial_seq: (n_remaining_sents * beam_size) * len_dec_seq
            # dec_partial_pos: (n_remaining_sents * beam_size) * len_dec_seq
            # output_enc: (n_remaining_sents * beam_size) * att_size * model_size
            # masks: (n_remaining_sents * beam_size) * len_dec_seq
            # output_dec: (n_remaining_sents * beam_size) * len_dec_seq * model_size
            output_dec = self.decoder(dec_partial_seq, dec_partial_pos,
                                      fc_feats, output_enc, masks)

            # (n_remaining_sents * beam_size) * model_size
            output_dec = output_dec[:, -1, :]

            # (n_remaining_sents * beam_size) * (vocab_size+1)
            output = F.log_softmax(self.proj(output_dec), -1)

            # n_remaining_sents * beam_size * (vocab_size+1)
            word_lk = output.view(n_remaining_sents, beam_size,
                                  -1).contiguous()

            active_beam_idx_list = []
            for beam_idx in range(batch_size):
                if beams[beam_idx].done:
                    continue

                inst_idx = beam_inst_idx_map[beam_idx]
                if not beams[beam_idx].advance(word_lk.data[inst_idx]):
                    active_beam_idx_list += [beam_idx]

            if not active_beam_idx_list:
                # all instances have finished their path to <EOS>
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            active_inst_idxs = [
                beam_inst_idx_map[k] for k in active_beam_idx_list
            ]
            active_inst_idxs = torch.LongTensor(active_inst_idxs).cuda()

            # update the idx mapping
            beam_inst_idx_map = {
                beam_idx: inst_idx
                for inst_idx, beam_idx in enumerate(active_beam_idx_list)
            }
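
            # A worked example (comments only): with batch_size = 3, suppose instance 1
            # just finished, so active_beam_idx_list == [0, 2]. active_inst_idxs then
            # selects the old rows {0, 2}, and the new map becomes {0: 0, 2: 1}: beam 2's
            # data now lives at row 1 of the compacted tensors.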

            # enc_info_var: (n_remaining_sents * beam_size) * att_size * model_size
            def update_active_enc_info(enc_info_var, active_inst_idxs):
                ''' Remove the encoder outputs of finished instances in one batch. '''

                inst_idx_dim_size, rest_dim_size1, rest_dim_size2 = enc_info_var.size()
                inst_idx_dim_size = inst_idx_dim_size * len(
                    active_inst_idxs) // n_remaining_sents
                new_size = (inst_idx_dim_size, rest_dim_size1, rest_dim_size2)

                # select the active instances in batch
                # original_enc_info_data: n_remaining_sents * (beam_size * att_size) * model_size
                original_enc_info_data = enc_info_var.data.view(
                    n_remaining_sents, -1, self.model_size)
                active_enc_info_data = original_enc_info_data.index_select(
                    0, active_inst_idxs)
                active_enc_info_data = active_enc_info_data.view(*new_size)

                # active_enc_info_data: (n_active_insts * beam_size) * att_size * model_size
                return Variable(active_enc_info_data, volatile=True)

            # enc_info_var: (n_remaining_sents * beam_size) * model_size
            def update_active_fc_feats(enc_info_var, active_inst_idxs):
                ''' Remove the fc features of finished instances in one batch. '''

                inst_idx_dim_size, rest_dim_size1 = enc_info_var.size()
                inst_idx_dim_size = inst_idx_dim_size * len(
                    active_inst_idxs) // n_remaining_sents
                new_size = (inst_idx_dim_size, rest_dim_size1)

                # select the active instances in batch
                # original_enc_info_data: n_remaining_sents * beam_size * model_size
                original_enc_info_data = enc_info_var.data.view(
                    n_remaining_sents, -1, self.model_size)
                active_enc_info_data = original_enc_info_data.index_select(
                    0, active_inst_idxs)
                active_enc_info_data = active_enc_info_data.view(*new_size)

                # active_enc_info_data: (n_active_insts * beam_size) * model_size
                return Variable(active_enc_info_data, volatile=True)

            # fc_feats: (n_active_insts * beam_size) * model_size
            fc_feats = update_active_fc_feats(fc_feats, active_inst_idxs)

            # output_enc: (n_active_insts * beam_size) * att_size * model_size
            output_enc = update_active_enc_info(output_enc, active_inst_idxs)

            # - update the remaining size
            n_remaining_sents = len(active_inst_idxs)

        # - Return useful information
        # all_hyp: batch_size lists, each holding the n_best hypotheses
        all_hyp, all_scores = [], []
        n_best = self.n_best

        for beam_idx in range(batch_size):
            scores, tail_idxs = beams[beam_idx].sort_scores()
            all_scores += [scores[:n_best]]

            hyps = [
                beams[beam_idx].get_hypothesis(i) for i in tail_idxs[:n_best]
            ]
            all_hyp += [hyps]

        seq = torch.LongTensor(batch_size, self.seq_length + 1).zero_()
        for i in range(batch_size):
            for j in range(len(all_hyp[i][0])):
                seq[i, j] = all_hyp[i][0][j]

        # all_scores: batch_size lists, each holding the n_best cumulative log-probs
        seqLogprobs = all_scores

        return seq, seqLogprobs
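
    # A minimal usage sketch for sample_beam (not part of the original source; the
    # model/loader names are assumptions about the surrounding codebase):
    #
    #     model.eval()
    #     fc_feats, att_feats = loader.get_batch('val')  # hypothetical data loader
    #     seq, seq_logprobs = model.sample_beam(fc_feats, att_feats, opt={'beam_size': 5})
    #     # seq: batch_size x (seq_length + 1) LongTensor of word ids, zero-padded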
Example #2

    def translate_batch(self, src_batch, beam):
        self.model.eval()

        # Beam size (beam width) is a parameter of the beam search algorithm that
        # determines how many of the best partial hypotheses to keep at each step. In an
        # LSTM model of melody generation, for example, beam size limits the number of
        # candidates fed back into the decoder. A beam size of 1 reduces to greedy
        # search: only the single most probable candidate is extended. A beam size of k
        # decodes and scores the top k candidates, a more extensive search in which not
        # only the single best candidate is evaluated.
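
        # A minimal sketch of one beam step (comments only; this mirrors what an
        # OpenNMT-style Beam.advance typically does with log-probabilities):
        #   flat = (prev_scores.unsqueeze(1) + log_probs).view(-1)  # k*V candidate scores
        #   best_scores, best_ids = flat.topk(beam_size)
        #   prev_beam = best_ids // vocab_size   # which beam each survivor extends
        #   next_word = best_ids % vocab_size    # the word that extends it
        # The beams[b].advance call below performs this bookkeeping per instance.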

        # Batch size is in a different location depending on the data.
        src_seq, src_pos = src_batch
        batch_size = src_seq.size(0)
        beam_size = beam
        # beam_size = self.trans_opt.beam_size
        print("beam_size")
        print(beam_size)

        # Encode
        enc_outputs, enc_slf_attns = self.model.encoder(src_seq, src_pos)

        # Repeat data for beam
        src_seq = Variable(src_seq.data.repeat(beam_size, 1))
        enc_outputs = [
            Variable(enc_output.data.repeat(beam_size, 1, 1))
            for enc_output in enc_outputs
        ]

        # Prepare beams
        beams = [Beam.Beam(beam_size, self.cuda) for _ in range(batch_size)]
        batch_idx = list(range(batch_size))
        n_remaining_sents = batch_size

        # A larger beam generally means a more accurate prediction at the expense of
        # memory and time. Beam search is a heuristic that expands the search tree
        # breadth-first and prunes candidates to bound those memory and time costs.

        # Decode
        for i in range(self.trans_opt.max_trans_length):

            len_dec_seq = i + 1

            # -- Preparing decode data seq -- #
            input_data = torch.stack([
                b.get_current_state() for b in beams if not b.done
            ])  # size: n_remaining_sents x beam_size x len_dec_seq
            input_data = input_data.view(-1, len_dec_seq)  # (n_remaining_sents * beam_size) x len_dec_seq
            input_data = Variable(input_data, volatile=True)

            # -- Preparing decode pos seq -- #
            # size: 1 x seq
            input_pos = torch.arange(1, len_dec_seq + 1).unsqueeze(0)
            # size: (batch * beam) x seq
            input_pos = input_pos.repeat(n_remaining_sents * beam_size, 1)
            input_pos = Variable(input_pos.type(torch.LongTensor),
                                 volatile=True)

            if self.cuda:
                input_pos = input_pos.cuda()
                input_data = input_data.cuda()

            # -- Decoding -- #
            dec_outputs, dec_slf_attns, dec_enc_attns = self.model.decoder(
                input_data, input_pos, src_seq, enc_outputs)
            dec_output = dec_outputs[-1][:, -1, :]  # (batch * beam) * d_model
            dec_output = self.model.tgt_word_proj(dec_output)
            out = self.model.prob_projection(dec_output)

            # batch x beam x n_words
            word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

            active = []
            for b in range(batch_size):
                if beams[b].done:
                    continue

                idx = batch_idx[b]
                if not beams[b].advance(word_lk.data[idx]):
                    active += [b]

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            active_idx = self.tt.LongTensor([batch_idx[k] for k in active])
            batch_idx = {beam_idx: inst_idx for inst_idx, beam_idx in enumerate(active)}
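
            # A worked example (comments only): with batch_size = 3 and beams 0 and 2
            # still active, active == [0, 2]; active_idx selects the surviving rows, and
            # batch_idx becomes {0: 0, 2: 1}, so beam 2's tensors move to compacted row 1.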

            def update_active_enc_info(tensor_var, active_idx):
                ''' Remove the encoder outputs of finished instances in one batch. '''
                tensor_data = tensor_var.data.view(n_remaining_sents, -1,
                                                   self.model_opt.d_model)

                new_size = list(tensor_var.size())
                new_size[0] = new_size[0] * len(
                    active_idx) // n_remaining_sents

                # select the active index in batch
                return Variable(tensor_data.index_select(
                    0, active_idx).view(*new_size),
                                volatile=True)

            def update_active_seq(seq, active_idx):
                ''' Remove the src sequence of finished instances in one batch. '''
                view = seq.data.view(n_remaining_sents, -1)
                new_size = list(seq.size())
                new_size[0] = new_size[0] * len(
                    active_idx) // n_remaining_sents  # trim on batch dim

                # select the active index in batch
                return Variable(view.index_select(0,
                                                  active_idx).view(*new_size),
                                volatile=True)

            src_seq = update_active_seq(src_seq, active_idx)
            enc_outputs = [
                update_active_enc_info(enc_output, active_idx)
                for enc_output in enc_outputs
            ]
            n_remaining_sents = len(active)

        # Return useful information
        all_hyp, all_scores = [], []
        n_best = self.trans_opt.n_best

        for b in range(batch_size):
            scores, ks = beams[b].sort_scores()
            all_scores += [scores[:n_best]]
            hyps = [beams[b].get_hypothesis(k) for k in ks[:n_best]]
            all_hyp += [hyps]

        # Decode the best hypothesis of each instance back to symbols.
        decoded = [self.trans_opt.ctable.decode(hyps[0]) for hyps in all_hyp]

        return decoded, all_hyp, all_scores, enc_outputs, dec_outputs, enc_slf_attns, dec_slf_attns, dec_enc_attns
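
    # A minimal usage sketch for translate_batch (not part of the original source; the
    # Translator wrapper name and the test-data iterator are assumptions):
    #
    #     translator = Translator(model_opt, trans_opt)  # hypothetical wrapper holding the model
    #     for src_batch in test_data:                    # yields (src_seq, src_pos) pairs
    #         decoded, all_hyp, all_scores, *rest = translator.translate_batch(src_batch, beam=5)
    #         print(decoded[0])                          # best decoded sequence, first instance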