Example #1
    def translate_batch(self, src_seq, src_pos):
        ''' Translation work in one batch '''
        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            ''' Indicate the position of an instance in a tensor. '''
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):
            ''' Collect tensor parts associated to active instances. '''

            _, *d_hs = beamed_tensor.size()
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, *d_hs)

            beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
            beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
            beamed_tensor = beamed_tensor.view(*new_shape)

            return beamed_tensor

        def collate_active_info(src_seq, src_enc, inst_idx_to_position_map,
                                active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

            active_src_seq = collect_active_part(src_seq, active_inst_idx,
                                                 n_prev_active_inst, n_bm)
            active_src_enc = collect_active_part(src_enc, active_inst_idx,
                                                 n_prev_active_inst, n_bm)
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            return active_src_seq, active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(inst_dec_beams, len_dec_seq, src_seq, enc_output,
                             inst_idx_to_position_map, n_bm):
            ''' Decode and update beam status, and then return active beam idx '''
            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
                return dec_partial_seq

            def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
                dec_partial_pos = torch.arange(1,
                                               len_dec_seq + 1,
                                               dtype=torch.long,
                                               device=self.device)
                dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(
                    n_active_inst * n_bm, 1)
                return dec_partial_pos

            def predict_word(dec_seq, dec_pos, src_seq, enc_output,
                             n_active_inst, n_bm):
                dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq,
                                                    enc_output)
                dec_output = dec_output[:, -1, :]  # Pick the last step: (n_active_inst * n_bm) x d_h
                word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output),
                                          dim=1)
                word_prob = word_prob.view(n_active_inst, n_bm, -1)

                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob,
                                             inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(
                        word_prob[inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list += [inst_idx]

                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)

            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
            word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output,
                                     n_active_inst, n_bm)

            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            all_hyp, all_scores = [], []
            for inst_idx in range(len(inst_dec_beams)):
                scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
                all_scores += [scores[:n_best]]

                hyps = [
                    inst_dec_beams[inst_idx].get_hypothesis(i)
                    for i in tail_idxs[:n_best]
                ]
                all_hyp += [hyps]
            return all_hyp, all_scores

        with torch.no_grad():
            #-- Encode
            src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
            src_enc, *_ = self.model.encoder(src_seq, src_pos)

            #-- Repeat data for beam search
            n_bm = self.opt.beam_size
            n_inst, len_s, d_h = src_enc.size()
            src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s,
                                                      d_h)

            #-- Prepare beams
            inst_dec_beams = [
                Beam(n_bm, device=self.device) for _ in range(n_inst)
            ]

            #-- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            #-- Decode
            for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1):

                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_seq, src_enc,
                    inst_idx_to_position_map, n_bm)

                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                    src_seq, src_enc, inst_idx_to_position_map,
                    active_inst_idx_list)

        batch_hyp, batch_scores = collect_hypothesis_and_scores(
            inst_dec_beams, self.opt.n_best)

        return batch_hyp, batch_scores
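
Every example on this page drives a `Beam` helper through the same small surface: a `done` flag plus `get_current_state()`, `advance(word_prob)`, `sort_scores()` and `get_hypothesis(k)`. The snippet below is a minimal sketch of that interface, reconstructed from the call sites above; the real class in each project differs in detail, and the `<BOS>`/`<EOS>` token ids here are assumptions.

import torch

class Beam:
    ''' Minimal sketch of the beam bookkeeping the examples rely on.
        Method names mirror the call sites; internals are simplified,
        and the bos/eos ids are assumptions. '''

    def __init__(self, size, device=None, bos=2, eos=3):
        self.size, self.eos, self.done = size, eos, False
        self.scores = torch.zeros(size, device=device)
        self.prev_ks = []  # backpointers into the previous step's beam
        self.next_ys = [torch.full((size,), bos, dtype=torch.long, device=device)]

    def get_current_state(self):
        ''' size x current_len tensor of the partial hypotheses (incl. <BOS>). '''
        seqs = [[self.next_ys[0][k].item()] + self.get_hypothesis(k)
                for k in range(self.size)]
        return torch.tensor(seqs, dtype=torch.long, device=self.scores.device)

    def advance(self, word_prob):
        ''' word_prob: size x n_words log-probs. Returns True once done. '''
        n_words = word_prob.size(1)
        if self.prev_ks:  # accumulate scores along each live hypothesis
            scores = word_prob + self.scores.unsqueeze(1)
        else:             # first step: all rows are identical, use one
            scores = word_prob[0]
        self.scores, best = scores.view(-1).topk(self.size)
        self.prev_ks.append(torch.div(best, n_words, rounding_mode='floor'))
        self.next_ys.append(best % n_words)
        self.done = self.next_ys[-1][0].item() == self.eos
        return self.done

    def sort_scores(self):
        return self.scores.sort(descending=True)

    def get_hypothesis(self, k):
        ''' Walk the backpointers to rebuild hypothesis k (without <BOS>). '''
        hyp = []
        for step in range(len(self.prev_ks) - 1, -1, -1):
            hyp.append(self.next_ys[step + 1][k].item())
            k = self.prev_ks[step][k].item()
        return hyp[::-1]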
Example #2
    def translate_batch(self, src_batch):
        ''' Translation work in one batch '''

        # The batch size location depends on the data format.
        src_seq, src_pos = src_batch
        batch_size = src_seq.size(0)
        beam_size = self.opt.beam_size

        #- Encode
        enc_output, *_ = self.model.encoder(src_seq, src_pos)

        #--- Repeat data for beam
        src_seq = Variable(
            src_seq.data.repeat(1, beam_size).view(
                src_seq.size(0) * beam_size, src_seq.size(1)))

        enc_output = Variable(
            enc_output.data.repeat(1, beam_size, 1).view(
                enc_output.size(0) * beam_size, enc_output.size(1), enc_output.size(2)))

        #--- Prepare beams
        beams = [Beam(beam_size, self.opt.cuda) for _ in range(batch_size)]
        beam_inst_idx_map = {
            beam_idx: inst_idx for inst_idx, beam_idx in enumerate(range(batch_size))}
        n_remaining_sents = batch_size

        #- Decode
        for i in range(self.model_opt.max_token_seq_len):

            len_dec_seq = i + 1

            # -- Preparing decoded data seq -- #
            # size: batch x beam x seq
            dec_partial_seq = torch.stack([
                b.get_current_state() for b in beams if not b.done])
            # size: (batch * beam) x seq
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
            # wrap into a Variable
            dec_partial_seq = Variable(dec_partial_seq, volatile=True)

            # -- Preparing decoded pos seq -- #
            # size: 1 x seq
            dec_partial_pos = torch.arange(1, len_dec_seq + 1).unsqueeze(0)
            # size: (batch * beam) x seq
            dec_partial_pos = dec_partial_pos.repeat(n_remaining_sents * beam_size, 1)
            # wrap into a Variable
            dec_partial_pos = Variable(dec_partial_pos.type(torch.LongTensor), volatile=True)

            if self.opt.cuda:
                dec_partial_seq = dec_partial_seq.cuda()
                dec_partial_pos = dec_partial_pos.cuda()

            # -- Decoding -- #
            dec_output, *_ = self.model.decoder(
                dec_partial_seq, dec_partial_pos, src_seq, enc_output)
            dec_output = dec_output[:, -1, :] # (batch * beam) * d_model
            dec_output = self.model.tgt_word_proj(dec_output)
            out = self.model.prob_projection(dec_output)

            # batch x beam x n_words
            word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

            active_beam_idx_list = []
            for beam_idx in range(batch_size):
                if beams[beam_idx].done:
                    continue

                inst_idx = beam_inst_idx_map[beam_idx]
                if not beams[beam_idx].advance(word_lk.data[inst_idx]):
                    active_beam_idx_list += [beam_idx]

            if not active_beam_idx_list:
                # all instances have finished their path to <EOS>
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            active_inst_idxs = self.tt.LongTensor(
                [beam_inst_idx_map[k] for k in active_beam_idx_list])

            # update the idx mapping
            beam_inst_idx_map = {
                beam_idx: inst_idx for inst_idx, beam_idx in enumerate(active_beam_idx_list)}

            def update_active_seq(seq_var, active_inst_idxs):
                ''' Remove the src sequence of finished instances in one batch. '''

                inst_idx_dim_size, *rest_dim_sizes = seq_var.size()
                inst_idx_dim_size = inst_idx_dim_size * len(active_inst_idxs) // n_remaining_sents
                new_size = (inst_idx_dim_size, *rest_dim_sizes)

                # select the active instances in batch
                original_seq_data = seq_var.data.view(n_remaining_sents, -1)
                active_seq_data = original_seq_data.index_select(0, active_inst_idxs)
                active_seq_data = active_seq_data.view(*new_size)

                return Variable(active_seq_data, volatile=True)

            def update_active_enc_info(enc_info_var, active_inst_idxs):
                ''' Remove the encoder outputs of finished instances in one batch. '''

                inst_idx_dim_size, *rest_dim_sizes = enc_info_var.size()
                inst_idx_dim_size = inst_idx_dim_size * len(active_inst_idxs) // n_remaining_sents
                new_size = (inst_idx_dim_size, *rest_dim_sizes)

                # select the active instances in batch
                original_enc_info_data = enc_info_var.data.view(
                    n_remaining_sents, -1, self.model_opt.d_model)
                active_enc_info_data = original_enc_info_data.index_select(0, active_inst_idxs)
                active_enc_info_data = active_enc_info_data.view(*new_size)

                return Variable(active_enc_info_data, volatile=True)

            src_seq = update_active_seq(src_seq, active_inst_idxs)
            enc_output = update_active_enc_info(enc_output, active_inst_idxs)

            #- update the remaining size
            n_remaining_sents = len(active_inst_idxs)

        #- Return useful information
        all_hyp, all_scores = [], []
        n_best = self.opt.n_best

        for beam_idx in range(batch_size):
            scores, tail_idxs = beams[beam_idx].sort_scores()
            all_scores += [scores[:n_best]]

            hyps = [beams[beam_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
            all_hyp += [hyps]

        return all_hyp, all_scores
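
A note on Example #2: `Variable(..., volatile=True)` is the pre-0.4 PyTorch idiom for inference without autograd. On PyTorch 0.4+ the `Variable` wrapper is a no-op and `volatile` is gone; the same intent is expressed by running the decode loop under `torch.no_grad()`, as Example #1 does. A minimal sketch of the modern equivalent of the position-sequence preparation (sizes are illustrative):

import torch

len_dec_seq, n_remaining_sents, beam_size = 3, 2, 5  # illustrative sizes

with torch.no_grad():  # replaces Variable(..., volatile=True)
    # size: 1 x seq, then (batch * beam) x seq, as in the loop above
    dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long).unsqueeze(0)
    dec_partial_pos = dec_partial_pos.repeat(n_remaining_sents * beam_size, 1)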
Example #3
    def translate_batch(self, raw_src_seq, raw_src_pos, block_list=None):
        ''' Translation work in one batch '''
        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            ''' Indicate the position of an instance in a tensor. '''
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):
            ''' Collect tensor parts associated to active instances. '''

            _, *d_hs = beamed_tensor.size()
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, *d_hs)

            beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
            beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
            beamed_tensor = beamed_tensor.view(*new_shape)

            return beamed_tensor

        def collate_active_info(src_seq, src_enc, inst_idx_to_position_map,
                                active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

            active_src_seq = collect_active_part(src_seq, active_inst_idx,
                                                 n_prev_active_inst, n_bm)
            active_src_enc = collect_active_part(src_enc, active_inst_idx,
                                                 n_prev_active_inst, n_bm)
            if hasattr(self.model.encoder, 'ntm'):
                self.model.encoder.ntm.previous_state = list(
                    self.model.encoder.ntm.previous_state)
                self.model.encoder.ntm.previous_state[1] = list(
                    self.model.encoder.ntm.previous_state[1])
                memory = self.model.encoder.ntm.memory
                self.model.encoder.ntm.memory.memory = collect_active_part(
                    memory.memory.view(n_prev_active_inst * n_bm,
                                       -1), active_inst_idx,
                    n_prev_active_inst, n_bm).view(-1, memory.N, memory.M)
                self.model.encoder.ntm.memory.batch_size = \
                    self.model.encoder.ntm.memory.memory.shape[0]
                for i in range(len(self.model.encoder.ntm.previous_state)):
                    for j, tensor in enumerate(
                            self.model.encoder.ntm.previous_state[i]):
                        squeezed = False
                        if len(tensor.shape) == 3:
                            dim0, dim1, dim2 = tensor.shape  # dim1 = n_prev_active_inst*n_bm
                            tensor = torch.transpose(tensor, 0,
                                                     1).contiguous().view(
                                                         dim1, -1)
                            # tensor = tensor.squeeze(0)
                            squeezed = True
                        new_tensor = collect_active_part(
                            tensor, active_inst_idx, n_prev_active_inst, n_bm)
                        if squeezed:
                            new_tensor = torch.transpose(
                                new_tensor.contiguous().view(-1, dim0, dim2),
                                0, 1).contiguous()

                        self.model.encoder.ntm.previous_state[i][j] = new_tensor

            active_src_enc[torch.isnan(active_src_enc)] = 0
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            return active_src_seq, active_src_enc, active_inst_idx_to_position_map

        def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
            dec_partial_seq = [
                b.get_current_state() for b in inst_dec_beams if not b.done
            ]
            dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
            return dec_partial_seq

        def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
            dec_partial_pos = torch.arange(1,
                                           len_dec_seq + 1,
                                           dtype=torch.long,
                                           device=self.device)
            dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(
                n_active_inst * n_bm, 1)
            return dec_partial_pos

        def predict_word(decoder, tgt_word_prj, dec_seq, dec_pos, src_seq,
                         enc_output, n_active_inst, n_bm):
            """decoder is added as an argument, compared to the original version"""
            # Note: sometimes the output is only [0], the pad token

            dec_output, *_ = decoder(dec_seq, dec_pos, src_seq, enc_output)
            dec_output_last = dec_output[:, -1, :]  # Pick the last step: (n_active_inst * n_bm) x d_h
            # dec_output[torch.isnan(dec_output)] = 0

            # gcl decoder
            # k = enc_output.clone()
            # k = F.max_pool1d(k.permute(0, 2, 1), k.shape[-2]).squeeze()
            # x = dec_output.clone()
            # x = F.max_pool1d(x.permute(0, 2, 1), x.shape[-2]).squeeze()
            #
            # gcl_output = self.model.gcl(k.unsqueeze(0).detach(), x.unsqueeze(0), bidirectional=False, save_attn=False)
            #
            # dec_output_last = torch.cat([dec_output_last, gcl_output.squeeze()], -1)

            word_prob = F.log_softmax(tgt_word_prj(dec_output_last), dim=1)
            word_prob = word_prob.view(n_active_inst, n_bm, -1)
            if block_list:
                for block_tok in block_list:
                    word_prob[:, :, block_tok] = -1000.

            return word_prob

        def collect_active_inst_idx_list(inst_beams, word_prob,
                                         inst_idx_to_position_map):
            """get indexes of instances that have not been fully translated yet"""
            active_inst_idx_list = []
            for inst_idx, inst_position in inst_idx_to_position_map.items():
                is_inst_complete = inst_beams[inst_idx].advance(
                    word_prob[inst_position])
                if not is_inst_complete:
                    active_inst_idx_list += [inst_idx]

            return active_inst_idx_list

        def beam_decode_step(decoder, tgt_word_prj, inst_dec_beams,
                             len_dec_seq, src_seq, enc_output,
                             inst_idx_to_position_map, n_bm):
            ''' Decode and update beam status, and then return active beam idx
                decoder is added as an argument, compared to the original version
            '''

            n_active_inst = len(inst_idx_to_position_map)

            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
            word_prob = predict_word(decoder, tgt_word_prj, dec_seq, dec_pos,
                                     src_seq, enc_output, n_active_inst, n_bm)

            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            all_hyp, all_scores = [], []
            for inst_idx in range(len(inst_dec_beams)):
                scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
                all_scores += [scores[:n_best]]

                hyps = [
                    inst_dec_beams[inst_idx].get_hypothesis(i)
                    for i in tail_idxs[:n_best]
                ]
                all_hyp += [hyps]
            return all_hyp, all_scores

        if self.opt.bi:
            decoders = [self.model.decoder_lr, self.model.decoder_rl]
            tgt_word_prjs = [
                self.model.tgt_word_prj_lr, self.model.tgt_word_prj_rl
            ]
        else:
            decoders = [self.model.decoder]
            tgt_word_prjs = [self.model.tgt_word_prj]
        batch_hyp_list = []  # list of results from each decoder
        batch_scores_list = []

        n_bm = self.opt.beam_size

        # -- Decode
        for decoder, tgt_word_prj in zip(
                decoders,
                tgt_word_prjs):  # two decoders for bidirectional model
            src_seq = copy.copy(raw_src_seq)
            src_pos = copy.copy(raw_src_pos)

            # -- Encode
            src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
            if hasattr(self.model.encoder, 'ntm'):
                n_inst, len_s = src_seq.size()
                src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            src_enc, *_ = self.model.encoder(src_seq, src_pos)

            #-- Repeat data for beam search
            if len(src_enc.size()) == 3:
                n_inst, len_s, d_h = src_enc.size()
                src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
                src_enc = src_enc.unsqueeze(1).expand(-1, n_bm, -1,
                                                      -1).contiguous().view(
                                                          n_inst * n_bm, len_s,
                                                          d_h)

            #-- Prepare beams
            inst_dec_beams = [
                Beam(n_bm, device=self.device) for _ in range(n_inst)
            ]

            # -- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1):

                active_inst_idx_list = beam_decode_step(
                    decoder, tgt_word_prj, inst_dec_beams, len_dec_seq,
                    src_seq, src_enc, inst_idx_to_position_map, n_bm)

                # src_enc[torch.isnan(src_enc)] = 0
                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                    src_seq, src_enc, inst_idx_to_position_map,
                    active_inst_idx_list)

            # batch_hyp is a nested list of [batches [n_best seqs] ]
            batch_hyp, batch_scores = collect_hypothesis_and_scores(
                inst_dec_beams, self.opt.n_best)
            batch_hyp_list.append(batch_hyp)
            batch_scores_list.append(batch_scores)

        return batch_hyp_list, batch_scores_list
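
The `block_list` handling in `predict_word` above works by clamping the log-probability of each blocked token to a large negative constant, so `Beam.advance` can never select it. A small self-contained illustration of the same masking (the token ids and shapes are made up):

import torch
import torch.nn.functional as F

word_prob = F.log_softmax(torch.randn(2, 5, 100), dim=-1)  # inst x beam x vocab
block_list = [7, 42]  # hypothetical token ids to forbid
for block_tok in block_list:
    word_prob[:, :, block_tok] = -1000.
assert not (word_prob.argmax(dim=-1) == 7).any()
assert not (word_prob.argmax(dim=-1) == 42).any()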
Example #4
    def translate_batch(self, src_batch):
        ''' Translation work in one batch '''

        # The batch size location depends on the data format.
        src_seq, src_pos = src_batch
        batch_size = src_seq.size(0)
        beam_size = self.opt.beam_size

        #- Encode
        enc_outputs, enc_slf_attns = self.model.encoder(src_seq, src_pos)

        #--- Repeat data for beam
        src_seq = Variable(src_seq.data.repeat(beam_size, 1))  # ERROR: interleaves beams, see note below
        enc_outputs = [
            Variable(enc_output.data.repeat(beam_size, 1, 1))
            for enc_output in enc_outputs
        ]

        #--- Prepare beams
        beam = [Beam(beam_size, self.opt.cuda) for k in range(batch_size)]
        batch_idx = list(range(batch_size))
        n_remaining_sents = batch_size

        #- Decode
        for i in range(self.model_opt.max_token_seq_len):

            len_dec_seq = i + 1

            # -- Preparing decode data seq -- #
            input_data = torch.stack([
                b.get_current_state() for b in beam if not b.done
            ])  # size: mb x bm x sq

            input_data = input_data.view(-1, len_dec_seq)  # size: (mb*bm) x sq
            input_data = Variable(input_data, volatile=True)

            # -- Preparing decode pos seq -- #
            # size: 1 x seq
            input_pos = torch.arange(1, len_dec_seq + 1).unsqueeze(0)
            # size: (batch * beam) x seq
            input_pos = input_pos.repeat(n_remaining_sents * beam_size, 1)
            input_pos = Variable(input_pos.type(torch.LongTensor),
                                 volatile=True)

            if self.opt.cuda:
                input_pos = input_pos.cuda()
                input_data = input_data.cuda()

            # -- Decoding -- #
            dec_outputs, dec_slf_attns, dec_enc_attns = self.model.decoder(
                input_data, input_pos, src_seq, enc_outputs)
            dec_output = dec_outputs[-1][:, -1, :]  # (batch * beam) * d_model
            dec_output = self.model.tgt_word_proj(dec_output)
            out = self.model.prob_projection(dec_output)

            # batch x beam x n_words
            word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

            active = []
            for b in range(batch_size):
                if beam[b].done:
                    continue

                idx = batch_idx[b]
                if not beam[b].advance(word_lk.data[idx]):
                    active += [b]

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            active_idx = self.tt.LongTensor([batch_idx[k] for k in active])
            batch_idx = {beam: idx for idx, beam in enumerate(active)}

            def update_active_enc_info(tensor_var, active_idx):
                ''' Remove the encoder outputs of finished instances in one batch. '''
                tensor_data = tensor_var.data.view(n_remaining_sents, -1,
                                                   self.model_opt.d_model)

                new_size = list(tensor_var.size())
                new_size[0] = new_size[0] * len(
                    active_idx) // n_remaining_sents

                # select the active index in batch
                return Variable(tensor_data.index_select(
                    0, active_idx).view(*new_size),
                                volatile=True)

            def update_active_seq(seq, active_idx):
                ''' Remove the src sequence of finished instances in one batch. '''
                view = seq.data.view(n_remaining_sents, -1)
                new_size = list(seq.size())
                new_size[0] = new_size[0] * len(
                    active_idx) // n_remaining_sents  # trim on batch dim

                # select the active index in batch
                return Variable(view.index_select(0,
                                                  active_idx).view(*new_size),
                                volatile=True)

            src_seq = update_active_seq(src_seq, active_idx)
            enc_outputs = [
                update_active_enc_info(enc_output, active_idx)
                for enc_output in enc_outputs
            ]
            n_remaining_sents = len(active)

        #- Return useful information
        all_hyp, all_scores = [], []
        n_best = self.opt.n_best

        for b in range(batch_size):
            scores, ks = beam[b].sort_scores()
            all_scores += [scores[:n_best]]
            hyps = [beam[b].get_hypothesis(k) for k in ks[:n_best]]
            all_hyp += [hyps]

        return all_hyp, all_scores
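
The line flagged `# ERROR` above is a repeat-ordering pitfall: `repeat(beam_size, 1)` tiles the whole batch, producing rows interleaved as [a, b, a, b, ...], while the `repeat(1, beam_size).view(...)` idiom from Example #2 groups each instance's beam copies together as [a, a, ..., b, b, ...]. The later `out.view(n_remaining_sents, beam_size, -1)` assumes the grouped layout. A quick demonstration:

import torch

src = torch.tensor([[1, 1], [2, 2]])  # a batch of two instances
beam_size = 3

tiled = src.repeat(beam_size, 1)                          # whole batch tiled
grouped = src.repeat(1, beam_size).view(-1, src.size(1))  # per-instance copies
print(tiled[:, 0].tolist())    # [1, 2, 1, 2, 1, 2] -- interleaved
print(grouped[:, 0].tolist())  # [1, 1, 1, 2, 2, 2] -- grouped per instance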
Example #5
    def translate_batch(self, src_batch):
        ''' Translation work in one batch '''

        # The batch size location depends on the data format.
        if self.model_opt.use_ctx:
            (src_seq, src_pos), (ctx_seq, ctx_pos) = src_batch
        else:
            src_seq, src_pos = src_batch
        batch_size = src_seq.size(0)
        beam_size = self.opt.beam_size

        #- Encode
        enc_outputs, enc_slf_attns = self.model.encoder(src_seq, src_pos)
        enc_output = enc_outputs[-1]

        #--- Repeat data for beam
        src_seq = Variable(src_seq.data.repeat(beam_size, 1))
        enc_output = Variable(enc_output.data.repeat(beam_size, 1, 1))

        if self.model_opt.use_ctx:
            #- Encode
            ctx_outputs, ctx_slf_attns = self.model.encoder_ctx(
                ctx_seq, ctx_pos)
            ctx_output = ctx_outputs[-1]

            #--- Repeat data for beam
            ctx_seq = Variable(ctx_seq.data.repeat(beam_size, 1))
            ctx_output = Variable(ctx_output.data.repeat(beam_size, 1, 1))

        #--- Prepare beams
        beam = [Beam(beam_size, self.opt.cuda) for k in range(batch_size)]
        batch_idx = list(range(batch_size))
        n_remaining_sents = batch_size

        #- Decode
        for i in range(self.model_opt.max_token_seq_len):

            len_dec_seq = i + 1

            # -- Preparing decode data seq -- #
            input_data = torch.stack([
                b.get_current_state() for b in beam if not b.done
            ])  # size: mb x bm x sq

            input_data = input_data.permute(1, 0, 2).contiguous()
            input_data = input_data.view(-1, len_dec_seq)  # size: (mb*bm) x sq
            input_data = Variable(input_data, volatile=True)

            # -- Preparing decode pos seq -- #
            # size: 1 x seq
            input_pos = torch.arange(1, len_dec_seq + 1).unsqueeze(0)
            # size: (batch * beam) x seq
            input_pos = input_pos.repeat(n_remaining_sents * beam_size, 1)
            input_pos = Variable(input_pos.type(torch.LongTensor),
                                 volatile=True)

            if self.opt.cuda:
                input_pos = input_pos.cuda()
                input_data = input_data.cuda()

            # -- Decoding -- #
            if self.model_opt.use_ctx:
                dec_outputs, dec_slf_attns, dec_enc_attns, dec_ctx_attns = self.model.decoder(
                    input_data, input_pos, src_seq, enc_output, ctx_seq,
                    ctx_output)
            else:
                dec_outputs, dec_slf_attns, dec_enc_attns = self.model.decoder(
                    input_data, input_pos, src_seq, enc_output)
            dec_output = dec_outputs[-1][:, -1, :]  # (batch * beam) * d_model
            dec_output = self.model.tgt_word_proj(dec_output)
            out = self.model.prob_projection(dec_output)

            # beam x batch x n_words
            word_lk = out.view(beam_size, n_remaining_sents, -1).contiguous()

            active = []
            for b in range(batch_size):
                if beam[b].done:
                    continue

                idx = batch_idx[b]
                if not beam[b].advance(word_lk.data[:, idx]):
                    active += [b]

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            active_idx = self.tt.LongTensor([batch_idx[k] for k in active])
            batch_idx = {beam: idx for idx, beam in enumerate(active)}

            def update_active_enc_info(tensor_var, active_idx):
                ''' Remove the encoder outputs of finished instances in one batch. '''
                batch = tensor_var.data[:n_remaining_sents]
                selected = batch.index_select(0, active_idx)
                data = selected.repeat(beam_size, 1, 1)
                return Variable(data, volatile=True)

            def update_active_seq(seq, active_idx):
                ''' Remove the src sequence of finished instances in one batch. '''
                batch = seq.data[:n_remaining_sents]
                selected = batch.index_select(0, active_idx)
                data = selected.repeat(beam_size, 1)
                return Variable(data, volatile=True)

            src_seq = update_active_seq(src_seq, active_idx)
            enc_output = update_active_enc_info(enc_output, active_idx)

            if self.model_opt.use_ctx:
                ctx_seq = update_active_seq(ctx_seq, active_idx)
                ctx_output = update_active_enc_info(ctx_output, active_idx)

            n_remaining_sents = len(active)

        #- Return useful information
        all_hyp, all_scores = [], []
        n_best = self.opt.n_best

        for b in range(batch_size):
            scores = self.tt.FloatTensor(
                beam_size + len(beam[b].finish_early_scores)).zero_()
            scores[:beam_size] = beam[b].scores
            for i in range(beam_size,
                           beam_size + len(beam[b].finish_early_scores)):
                scores[i] = beam[b].finish_early_scores[i - beam_size][2]
            beam[b].scores = scores
            scores, ks = beam[b].sort_scores()
            all_scores += [scores[:n_best]]
            hyps = [
                beam[b].get_hypothesis(k)
                if k < beam_size else beam[b].get_early_hypothesis(
                    beam[b].finish_early_scores[k - beam_size][0],
                    beam[b].finish_early_scores[k - beam_size][1])
                for k in ks[:n_best]
            ]
            all_hyp += [hyps]
        return all_hyp, all_scores
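
Example #5's final loop merges two populations before ranking: the live beam scores and the scores of hypotheses that hit `<EOS>` early (stored in `beam[b].finish_early_scores`). A hedged sketch of that merging step; the field layout is an assumption read off the code above:

import torch

beam_scores = torch.tensor([-1.2, -2.5, -3.0])  # live beams
early_scores = torch.tensor([-0.9, -2.8])       # early-<EOS> hypotheses
merged = torch.cat([beam_scores, early_scores])
scores, ks = merged.sort(descending=True)
# indices k < beam_size refer to live beams; k >= beam_size select an
# early-finished hypothesis, hence the get_early_hypothesis branch above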
Example #6
    def generate_question_batch(self, src1_seq, src1_pos, src1_emo, src1_bio,
                                src1_bio_pos):
        ''' Generate questions batch by batch. '''
        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):

            _, *d_hs = beamed_tensor.size()
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, *d_hs)

            beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
            beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
            beamed_tensor = beamed_tensor.view(*new_shape)

            return beamed_tensor

        def collate_active_info(src_seq, src_enc, inst_idx_to_position_map,
                                active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.

            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

            active_src_seq = collect_active_part(src_seq, active_inst_idx,
                                                 n_prev_active_inst, n_bm)
            active_src_enc = collect_active_part(src_enc, active_inst_idx,
                                                 n_prev_active_inst, n_bm)
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            return active_src_seq, active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(inst_dec_beams, len_dec_seq, src_seq, enc_output,
                             inst_idx_to_position_map, n_bm):
            ''' Decode and update beam status, and then return active beam idx '''
            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
                return dec_partial_seq

            def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
                dec_partial_pos = torch.arange(1,
                                               len_dec_seq + 1,
                                               dtype=torch.long,
                                               device=self.device)
                dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(
                    n_active_inst * n_bm, 1)
                return dec_partial_pos

            def predict_word(dec_seq, dec_pos, src_seq, enc_output,
                             n_active_inst, n_bm):

                dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq,
                                                    enc_output)

                dec_output = dec_output[:, -1, :]  # Pick the last step: (n_active_inst * n_bm) x d_h
                word_prob = F.log_softmax(self.model.tgt_word_prj(
                    self.model.fc(dec_output)),
                                          dim=1)

                word_prob = word_prob.view(n_active_inst, n_bm, -1)

                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob,
                                             inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(
                        word_prob[inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list += [inst_idx]

                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)

            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)

            word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output,
                                     n_active_inst, n_bm)

            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            all_hyp, all_scores = [], []
            for inst_idx in range(len(inst_dec_beams)):
                scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
                all_scores += [scores[:n_best]]

                hyps = [
                    inst_dec_beams[inst_idx].get_hypothesis(i)
                    for i in tail_idxs[:n_best]
                ]
                all_hyp += [hyps]
            return all_hyp, all_scores

        with torch.no_grad():

            src1_seq = src1_seq.to(self.device)
            src1_bio = src1_bio.to(self.device)

            src1_emo = src1_emo.to(self.device)

            # src1_enc, *_ = self.model.encoder1(src1_seq, src1_pos, src1_emo)
            # src2_enc, *_ = self.model.encoder2(src2_seq, src2_pos, src2_emo, src1_enc)
            # src3_enc, *_ = self.model.encoder3(src3_seq, src3_pos, src3_emo, src2_enc)

            src_enc1 = self.model.gcn(src1_seq, src1_bio,
                                      src1_emo)  # (batch, 20, 300)
            src_enc2, _ = self.model.encoder1(src1_seq, src1_bio)
            # print("src_enc2.sahpe:",src_enc2.shape)
            # print("_.shape:",_.shape)
            # src_enc2 = self.model.layer3(src_enc2)
            src_enc = torch.cat((src_enc1, src_enc2), 2)
            src_enc = self.model.layer1(src_enc)
            # src_enc = (0.5*src_enc2 + 0.5*src_enc1)

            n_bm = self.opt.beam_size  # 5

            n_inst, len_s, d_h = src_enc.size()  # (batch, 20, 300)
            src_seq = src1_seq.repeat(1, n_bm).view(n_inst * n_bm,
                                                    len_s)  # (batch * 5, 20)

            # (batch, 20*5, 300) --> (batch * 5, 20, 300)
            src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s,
                                                      d_h)

            inst_dec_beams = [
                Beam(n_bm, device=self.device) for _ in range(n_inst)
            ]

            active_inst_idx_list = list(
                range(n_inst))  # [0, 1, 2, ..., batch-1]
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            for len_dec_seq in range(1, 30):

                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_seq, src_enc,
                    inst_idx_to_position_map, n_bm)

                if not active_inst_idx_list:
                    break

                src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                    src_seq, src_enc, inst_idx_to_position_map,
                    active_inst_idx_list)

        batch_hyp, batch_scores = collect_hypothesis_and_scores(
            inst_dec_beams, self.opt.n_best)

        return batch_hyp, batch_scores
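
Example #6 fuses two encoders before beam search: GCN features and self-attention encoder outputs are concatenated along the feature dimension and projected back down by `self.model.layer1`. A hedged sketch of that fusion; the dimensions are the ones noted in the comments above, and `nn.Linear` stands in for whatever `layer1` really is:

import torch
import torch.nn as nn

batch, length, d = 4, 20, 300
src_enc1 = torch.randn(batch, length, d)  # e.g. GCN features
src_enc2 = torch.randn(batch, length, d)  # e.g. self-attention encoder output
layer1 = nn.Linear(2 * d, d)              # stands in for self.model.layer1
src_enc = layer1(torch.cat((src_enc1, src_enc2), dim=2))  # (batch, len, d)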
Example #7
def translate_batch(model, batch, opt, model_options):
    model.eval()
    # prepare data
    #key = [triples[0] for triples in batch]
    src = [triples[1] for triples in batch]
    tgt = [triples[2] for triples in batch]
    src_seq, src_pad_mask = instances_handler.pad_to_longest(src)
    tgt_seq, tgt_pad_mask = instances_handler.pad_to_longest(tgt)

    src_seq = Variable(torch.FloatTensor(
        src_seq))  # batch x max length in batch x padded feature dim
    src_pad_mask = Variable(torch.LongTensor(
        src_pad_mask))  # batch x max length in batch x bool mask dim
    tgt_seq = Variable(torch.LongTensor(
        tgt_seq))  # batch x max length in batch x padded index dim
    tgt_pad_mask = Variable(torch.LongTensor(
        tgt_pad_mask))  # batch x max length in batch x bool mask dim

    if opt.use_gpu:
        src_seq = src_seq.cuda()
        src_pad_mask = src_pad_mask.cuda()
        tgt_seq = tgt_seq.cuda()
        tgt_pad_mask = tgt_pad_mask.cuda()

    goal = tgt_seq[:, 1:]
    tgt_seq = tgt_seq[:, :-1]
    tgt_pad_mask = tgt_pad_mask[:, :-1]

    beam_size = opt.beam_size
    batch_size = src_seq.size(0)
    #---------------------------------------------------------------------------------------
    #- Encode
    enc_output, *_ = model.encoder(src_seq, src_pad_mask)

    #--- Repeat data for beam
    src_seq = Variable(
        src_seq.data.repeat(1, beam_size, 1).view(
            src_seq.size(0) * beam_size, src_seq.size(1), src_seq.size(2)))

    src_pad_mask = Variable(
        src_pad_mask.data.repeat(1, beam_size).view(
            src_pad_mask.size(0) * beam_size, src_pad_mask.size(1)))

    enc_output = Variable(
        enc_output.data.repeat(1, beam_size, 1).view(
            enc_output.size(0) * beam_size, enc_output.size(1),
            enc_output.size(2)))

    #--- Prepare beams
    beams = [Beam(beam_size, opt.use_gpu) for _ in range(batch_size)]
    beam_inst_idx_map = {
        beam_idx: inst_idx
        for inst_idx, beam_idx in enumerate(range(batch_size))
    }
    n_remaining_sents = batch_size

    #- Decode
    for i in range(opt.max_token_seq_len):
        len_dec_seq = i + 1
        # -- Preparing decoded data seq -- #
        # size: batch x beam x seq
        dec_partial_seq = torch.stack(
            [b.get_current_state() for b in beams if not b.done])
        dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq).cpu().numpy()
        dec_partial_seq, dec_partial_seq_mask = instances_handler.pad_to_longest(
            dec_partial_seq)

        # size: (batch * beam) x seq
        dec_partial_seq = Variable(torch.LongTensor(dec_partial_seq))
        dec_partial_seq_mask = Variable(torch.LongTensor(dec_partial_seq_mask))

        if opt.use_gpu:
            dec_partial_seq = dec_partial_seq.cuda()
            dec_partial_seq_mask = dec_partial_seq_mask.cuda()

        # -- Decoding -- #
        dec_output, *_ = model.decoder(dec_partial_seq, dec_partial_seq_mask,
                                       src_pad_mask, enc_output)
        dec_output = dec_output[:, -1, :]  # (batch * beam) * d_model
        out = model.prob_projection(dec_output)

        # batch x beam x n_words
        word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

        active_beam_idx_list = []
        for beam_idx in range(batch_size):
            if beams[beam_idx].done:
                continue

            inst_idx = beam_inst_idx_map[beam_idx]
            if not beams[beam_idx].advance(word_lk.data[inst_idx]):
                active_beam_idx_list += [beam_idx]

        if not active_beam_idx_list:
            # all instances have finished their path to <EOS>
            break

        # in this section, the sentences that are still active are
        # compacted so that the decoder is not run on completed sentences
        active_inst_idxs = torch.LongTensor(
            [beam_inst_idx_map[k] for k in active_beam_idx_list])
        if opt.use_gpu:
            active_inst_idxs = active_inst_idxs.cuda()
        # update the idx mapping
        beam_inst_idx_map = {
            beam_idx: inst_idx
            for inst_idx, beam_idx in enumerate(active_beam_idx_list)
        }

        def update_active_seq(seq_var, active_inst_idxs):
            ''' Remove the src sequence of finished instances in one batch. '''

            inst_idx_dim_size, *rest_dim_sizes = seq_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(
                active_inst_idxs) // n_remaining_sents
            new_size = (inst_idx_dim_size, *rest_dim_sizes)

            # select the active instances in batch
            original_seq_data = seq_var.data.view(n_remaining_sents, -1)
            active_seq_data = original_seq_data.index_select(
                0, active_inst_idxs)
            active_seq_data = active_seq_data.view(*new_size)

            return Variable(active_seq_data, volatile=True)

        def update_active_enc_info(enc_info_var, active_inst_idxs):
            ''' Remove the encoder outputs of finished instances in one batch. '''

            inst_idx_dim_size, *rest_dim_sizes = enc_info_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(
                active_inst_idxs) // n_remaining_sents
            new_size = (inst_idx_dim_size, *rest_dim_sizes)

            # select the active instances in batch
            original_enc_info_data = enc_info_var.data.view(
                n_remaining_sents, -1, model_options.d_model)
            active_enc_info_data = original_enc_info_data.index_select(
                0, active_inst_idxs)
            active_enc_info_data = active_enc_info_data.view(*new_size)

            return Variable(active_enc_info_data, volatile=True)

        src_pad_mask = update_active_seq(src_pad_mask, active_inst_idxs)
        enc_output = update_active_enc_info(enc_output, active_inst_idxs)

        #- update the remaining size
        n_remaining_sents = len(active_inst_idxs)

    #- Return useful information
    all_hyp, all_scores = [], []

    for beam_idx in range(batch_size):
        scores, tail_idxs = beams[beam_idx].sort_scores()
        all_scores += [scores[:opt.n_best]]

        hyps = [
            beams[beam_idx].get_hypothesis(i) for i in tail_idxs[:opt.n_best]
        ]
        all_hyp += [hyps]

    return all_hyp, all_scores
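
Example #7 re-pads the decoded prefixes on every step through `instances_handler.pad_to_longest`, which returns both the padded batch and a mask. The helper itself is not shown on this page; the sketch below is an assumption about what it does (right-pad to a rectangle and emit a 0/1 mask), not the project's actual implementation:

def pad_to_longest(insts, pad_id=0):
    ''' Hypothetical stand-in for instances_handler.pad_to_longest. '''
    max_len = max(len(inst) for inst in insts)
    seq = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
    mask = [[1] * len(inst) + [0] * (max_len - len(inst)) for inst in insts]
    return seq, mask

seq, mask = pad_to_longest([[5, 6, 7], [8]])
# seq  -> [[5, 6, 7], [8, 0, 0]]
# mask -> [[1, 1, 1], [1, 0, 0]]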
Example #8
    def translate_batch(self, src_seq, src_pos, src_sen_pos):
        ''' Translation work in one batch '''

        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            ''' Indicate the position of an instance in a tensor. '''
            return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}

        def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
            ''' Collect tensor parts associated to active instances. '''

            _, *d_hs = beamed_tensor.size()
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, *d_hs)

            beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
            beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
            beamed_tensor = beamed_tensor.view(*new_shape)

            return beamed_tensor

        def collate_active_info(
                src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
            active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

            active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
            active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

            return active_src_seq, active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(
                inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm):
            ''' Decode and update beam status, and then return active beam idx '''

            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [b.get_current_seq_state() for b in inst_dec_beams if not b.done]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
                return dec_partial_seq

            def prepare_beam_dec_pos(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [b.get_current_pos_state() for b in inst_dec_beams if not b.done]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
                return dec_partial_seq

            def prepare_beam_dec_sen_pos(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [b.get_current_sen_pos_state() for b in inst_dec_beams if not b.done]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
                return dec_partial_seq

            def predict_word(dec_seq, dec_pos, dec_sen_pos, src_seq, enc_output, n_active_inst, n_bm):
                dec_output, *_ = self.model.decoder(dec_seq, dec_pos, dec_sen_pos, src_seq, enc_output)
                dec_output = dec_output[:, -1, :]  # Pick the last step: (n_active_inst * n_bm) x d_h
                logits = self.model.tgt_word_prj(dec_output)
                logits = F.log_softmax(logits, dim=1)  # log-probs over the vocabulary
                # Mask out <UNK>
                logits[:, Constants.UNK] = -1e19
                rm_set = set(Constants.BOSs + [19])  # 19 => "."
                for i, (ins, pos, sen_pos) in enumerate(zip(dec_seq, dec_pos, dec_sen_pos)):
                    current_sen_pos = sen_pos[-1]
                    for token, s_pos in zip(ins.flip(0), sen_pos.flip(0)):
                        length_norm = len(ins)
                        if token.item() not in rm_set and s_pos == current_sen_pos:
                            logits[i, token] -= 5+1e-19
                        if s_pos != current_sen_pos:
                            logits[i, token] -= 20/length_norm - 1e-19
                            #break

                word_prob = logits.view(n_active_inst, n_bm, -1)

                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list += [inst_idx]

                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)

            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            dec_pos = prepare_beam_dec_pos(inst_dec_beams, len_dec_seq)
            dec_sen_pos = prepare_beam_dec_sen_pos(inst_dec_beams, len_dec_seq)
            #dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
            #dec_pos, dec_sen_pos = prepare_beam_dec_pos(dec_seq)
            word_prob = predict_word(dec_seq, dec_pos, dec_sen_pos, src_seq, enc_output, n_active_inst, n_bm)

            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            all_hyp, all_scores = [], []
            for inst_idx in range(len(inst_dec_beams)):
                scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
                all_scores += [scores[:n_best]]

                hyps = [inst_dec_beams[inst_idx].get_seq_hypothesis(i) for i in tail_idxs[:n_best]]
                all_hyp += [hyps]
            return all_hyp, all_scores

        with torch.no_grad():
            #-- Encode
            src_seq, src_pos, src_sen_pos = src_seq.to(self.device), src_pos.to(self.device), src_sen_pos.to(self.device)
            src_enc, *_ = self.model.encoder(src_seq, src_pos, src_sen_pos)

            #-- Repeat data for beam search
            n_bm = self.opt.beam_size
            n_inst, len_s, d_h = src_enc.size()

            src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
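            # e.g. n_inst=2, n_bm=3: beam copies are contiguous per instance,
            # i.e. rows ordered [s0, s0, s0, s1, s1, s1]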

            #-- Prepare beams
            inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]

            #-- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

            #-- Decode
            for len_dec_seq in range(1, 200):
                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)

                # Stop when every instance reached <EOS> or the length cap is hit
                if not active_inst_idx_list or len_dec_seq > 50:
                    break

                src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                    src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list)

        batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best)

        return batch_hyp, batch_scores
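
A minimal sketch of how a translate_batch like the one above is typically driven at inference time. `translator`, `test_loader`, and `idx2word` are hypothetical stand-ins, not names from this code; the exact batch layout depends on the variant above, which also consumes a sentence-position tensor.

for batch in test_loader:
    # batch carries the source tensors expected by the signature above
    batch_hyp, batch_scores = translator.translate_batch(*batch)
    for inst_hyps in batch_hyp:
        best = inst_hyps[0]  # hypotheses come back sorted by score
        print(' '.join(idx2word[idx] for idx in best))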
Example #9
    def sample(self, src_seq, enc_outputs):
        """Samples captions for given image features (Greedy search)."""
        beam_size = 1
        batch_size = src_seq.size(0)

        self.softmax = nn.LogSoftmax(dim=1)  # explicit dim; the implicit default is deprecated
        self.tt = torch.cuda if torch.cuda.is_available() else torch

        # Repeat Data
        src_seq = Variable(
            src_seq.data.repeat(1, beam_size).view(
                src_seq.size(0) * beam_size, src_seq.size(1)))

        for i in range(len(enc_outputs)):
            enc_output = enc_outputs[i]
            enc_outputs[i] = Variable(
                enc_output.data.repeat(1, beam_size, 1).view(
                    enc_output.size(0) * beam_size, enc_output.size(1),
                    enc_output.size(2)))

        # --- Prepare beams
        beams = [
            Beam(beam_size, torch.cuda.is_available())
            for _ in range(batch_size)
        ]
        beam_inst_idx_map = {
            beam_idx: inst_idx
            for inst_idx, beam_idx in enumerate(range(batch_size))
        }
        n_remaining_sents = batch_size

        # - Decode
        for i in range(20):  # decode at most 20 steps
            len_dec_seq = i + 1

            # -- Preparing decoded data seq -- #
            # size: batch x beam x seq
            dec_partial_seq = torch.stack(
                [b.get_current_state() for b in beams if not b.done])
            # size: (batch * beam) x seq
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
            # wrap into a Variable
            dec_partial_seq = Variable(dec_partial_seq, volatile=True)

            if torch.cuda.is_available():
                dec_partial_seq = dec_partial_seq.cuda()

            # -- Decoding -- #
            dec_output = self(src_seq, dec_partial_seq, enc_outputs,
                              [len_dec_seq] * n_remaining_sents * beam_size,
                              False)
            dec_output = dec_output[:, -1, :]  # (batch * beam) * d_model
            out = self.softmax(dec_output)

            # batch x beam x n_words
            word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

            active_beam_idx_list = []

            for beam_idx in range(batch_size):
                if beams[beam_idx].done:
                    continue

                inst_idx = beam_inst_idx_map[beam_idx]
                if not beams[beam_idx].advance(word_lk.data[inst_idx]):
                    active_beam_idx_list += [beam_idx]

            if not active_beam_idx_list:
                # all instances have finished their path to <EOS>
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            active_inst_idxs = self.tt.LongTensor(
                [beam_inst_idx_map[k] for k in active_beam_idx_list])

            # update the idx mapping
            beam_inst_idx_map = {
                beam_idx: inst_idx
                for inst_idx, beam_idx in enumerate(active_beam_idx_list)
            }
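            # e.g. if only beams 1 and 3 remain active: {1: 0, 3: 1}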

            def update_active_seq(seq_var, active_inst_idxs):
                ''' Remove the src sequence of finished instances in one batch. '''

                inst_idx_dim_size, b = seq_var.size()
                inst_idx_dim_size = inst_idx_dim_size * len(
                    active_inst_idxs) // n_remaining_sents
                new_size = inst_idx_dim_size, b

                # select the active instances in batch
                original_seq_data = seq_var.data.view(n_remaining_sents, -1)
                active_seq_data = original_seq_data.index_select(
                    0, active_inst_idxs)
                active_seq_data = active_seq_data.view(*new_size)

                return Variable(active_seq_data, volatile=True)

            def update_active_enc_info(enc_info_var, active_inst_idxs):
                ''' Remove the encoder outputs of finished instances in one batch. '''

                inst_idx_dim_size, b, c = enc_info_var.size()
                inst_idx_dim_size = inst_idx_dim_size * len(
                    active_inst_idxs) // n_remaining_sents
                new_size = inst_idx_dim_size, b, c

                # select the active instances in batch
                original_enc_info_data = enc_info_var.data.view(
                    n_remaining_sents, -1, self.d_model)
                active_enc_info_data = original_enc_info_data.index_select(
                    0, active_inst_idxs)
                active_enc_info_data = active_enc_info_data.view(*new_size)

                return Variable(active_enc_info_data, volatile=True)

            src_seq = update_active_seq(src_seq, active_inst_idxs)

            for j in range(len(enc_outputs)):
                enc_outputs[j] = update_active_enc_info(
                    enc_outputs[j], active_inst_idxs)

            # - update the remaining size
            n_remaining_sents = len(active_inst_idxs)

        # - Return useful information
        all_hyp, all_scores = [], []
        n_best = 1

        for beam_idx in range(batch_size):
            scores, tail_idxs = beams[beam_idx].sort_scores()
            all_scores += [scores[:n_best]]

            hyps = [
                beams[beam_idx].get_hypothesis(i) for i in tail_idxs[:n_best]
            ]
            all_hyp += [hyps]

        return all_hyp
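
Example #9 uses the pre-0.4 PyTorch autograd API: `Variable` and `volatile=True` were removed in later releases, where inference is wrapped in `torch.no_grad()` instead. A minimal sketch of the modern equivalent of the volatile wrapping above (`decode_step` and `model` are placeholders, not names from this code):

import torch

def decode_step(model, dec_partial_seq):
    # torch.no_grad() replaces volatile=True: no autograd graph is built,
    # and plain tensors replace Variable wrappers.
    with torch.no_grad():
        return model(dec_partial_seq)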
Example #10
    def decode_batch(self, src_batch):
        ''' Translation work in one batch '''

        # Batch size is in different location depending on data.
        src_seq = src_batch
        batch_size = src_seq.size(0)
        beam_size = self.opt.beam_size

        # - Encode
        enc_output, src_mask = self.model.encoder(src_seq)
        # enc_output: (batch, length, d_model), src_mask: (batch, 1, length)

        # (batch * beam_size, length, d_model)
        enc_output = Variable(
            enc_output.data.repeat(1, beam_size, 1).view(
                enc_output.size(0) * beam_size, enc_output.size(1),
                enc_output.size(2)))

        # (batch * beam_size, 1, d_model)
        src_mask = src_mask.repeat(1, beam_size, 1).view(
            src_mask.size(0) * beam_size, src_mask.size(1), src_mask.size(2))

        # --- Prepare beams
        beams = [Beam(beam_size, self.device) for _ in range(batch_size)]
        beam_inst_idx_map = {
            beam_idx: inst_idx
            for inst_idx, beam_idx in enumerate(range(batch_size))
        }
        n_remaining_sents = batch_size

        # - Decode
        for i in range(self.model_opt.label_max_len):

            len_dec_seq = i + 1

            # -- Preparing decoded data seq -- #
            # size: (batch, beam, len_dec_seq)
            dec_partial_seq = torch.stack(
                [b.get_current_state() for b in beams if not b.done])

            # size: (batch * beam, len_dec_seq)
            dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)

            dec_partial_seq = dec_partial_seq.to(self.device)

            # -- Decoding -- #

            # (batch * beam, len_dec_seq, d_model)
            dec_output = self.model.decoder(dec_partial_seq, enc_output,
                                            src_mask)

            # (batch * beam, d_model)
            dec_output = dec_output[:, -1, :]

            # (batch * beam, vocab_size)
            dec_output = self.model.final_proj(dec_output)

            # (batch * beam, vocab_size) logSoftmax
            out = self.model.log_softmax(dec_output)

            # (batch , beam , vocab_size)
            word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

            active_beam_idx_list = []

            for beam_idx in range(batch_size):
                # Skip beams that have already reached <EOS>.
                if beams[beam_idx].done:
                    continue

                inst_idx = beam_inst_idx_map[beam_idx]

                # word_lk.data[inst_idx]: (beam_size, vocab_size) scores for this instance
                if not beams[beam_idx].advance(word_lk.data[inst_idx]):
                    active_beam_idx_list += [beam_idx]

            if not active_beam_idx_list:
                # all instances have finished their path to <EOS>
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            active_inst_idxs = torch.LongTensor(
                [beam_inst_idx_map[k] for k in active_beam_idx_list])

            # update the idx mapping
            beam_inst_idx_map = {
                beam_idx: inst_idx
                for inst_idx, beam_idx in enumerate(active_beam_idx_list)
            }

            def update_active_seq(seq_var, active_inst_idxs):
                ''' Remove the src sequence of finished instances in one batch. '''

                inst_idx_dim_size, *rest_dim_sizes = seq_var.size()
                inst_idx_dim_size = inst_idx_dim_size * \
                                    len(active_inst_idxs) // n_remaining_sents
                new_size = (inst_idx_dim_size, *rest_dim_sizes)

                # select the active instances in batch
                original_seq_data = seq_var.data.view(n_remaining_sents, -1)
                active_seq_data = original_seq_data.index_select(
                    0, active_inst_idxs)
                active_seq_data = active_seq_data.view(*new_size)
                # seq_var.data is already detached, so a plain return suffices
                return active_seq_data

            def update_active_enc_info(enc_info_var, active_inst_idxs):
                ''' Remove the encoder outputs of finished instances in one batch. '''

                # (batch * beam, length, d_model)
                inst_idx_dim_size, *rest_dim_sizes = enc_info_var.size()

                inst_idx_dim_size = inst_idx_dim_size * len(
                    active_inst_idxs) // n_remaining_sents
                new_size = (inst_idx_dim_size, *rest_dim_sizes)

                # select the active instances in batch
                # (batch, beam, d_model)
                original_enc_info_data = enc_info_var.data.view(
                    n_remaining_sents, -1, enc_info_var.size(2))

                # select instance of batch (new_batch, beam, d_model)
                active_enc_info_data = original_enc_info_data.index_select(
                    0, active_inst_idxs)
                active_enc_info_data = active_enc_info_data.view(*new_size)
                # enc_info_var.data is already detached, so a plain return suffices
                return active_enc_info_data

            enc_output = update_active_enc_info(
                enc_output, active_inst_idxs.to(self.device))
            src_mask = update_active_enc_info(src_mask,
                                              active_inst_idxs.to(self.device))

            # - update the remaining size
            n_remaining_sents = len(active_inst_idxs)

        # - Return useful information
        all_hyp, all_scores = [], []
        n_best = self.opt.n_best

        for beam_idx in range(batch_size):
            scores, tail_idxs = beams[beam_idx].sort_scores()
            all_scores += [scores[:n_best]]

            # Drop the leading start token from each of the n_best paths
            hyps = torch.LongTensor(beams[beam_idx].bestpath)[:n_best, 1:]
            all_hyp += [hyps]

        return all_hyp, all_scores
    def translate_batch(self, src_seq, src_pos):
        """ Translation work in one batch """
        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            """ Indicate the position of an instance in a tensor. """
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }
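        # e.g. get_inst_idx_to_tensor_position_map([3, 5, 9]) -> {3: 0, 5: 1, 9: 2}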

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):
            """ Collect tensor parts associated to active instances. """

            _, *d_hs = beamed_tensor.size()
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, *d_hs)

            beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
            beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
            beamed_tensor = beamed_tensor.view(*new_shape)

            return beamed_tensor
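            # Shape walk-through, e.g. with d_hs = (len_s, d_h):
            # (n_prev * n_bm, len_s, d_h) -> (n_prev, n_bm * len_s * d_h)
            # -> keep active rows -> (n_curr * n_bm, len_s, d_h)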

        def collate_active_info(src_seq, src_enc, inst_idx_to_position_map,
                                active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

            active_src_seq = collect_active_part(src_seq, active_inst_idx,
                                                 n_prev_active_inst, n_bm)
            active_src_enc = collect_active_part(src_enc, active_inst_idx,
                                                 n_prev_active_inst, n_bm)
            active_inst_idx_to_position_map = \
                get_inst_idx_to_tensor_position_map(active_inst_idx_list)

            return active_src_seq, active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(inst_dec_beams, len_dec_seq, src_seq, enc_output,
                             inst_idx_to_position_map, n_bm):
            """ Decode and update beam status, and then return active beam idx
            """
            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
                return dec_partial_seq

            def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
                dec_partial_pos = torch.arange(1,
                                               len_dec_seq + 1,
                                               dtype=torch.long,
                                               device=self.device)
                dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(
                    n_active_inst * n_bm, 1)
                return dec_partial_pos
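                # e.g. len_dec_seq=3 with 2 active instances and n_bm=4
                # -> a (8, 3) tensor whose every row is [1, 2, 3]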

            def predict_word(dec_seq, dec_pos, src_seq, enc_output,
                             n_active_inst, n_bm):
                ##############################################################
                # Build the copy mask: 1.0 at vocabulary positions of tokens
                # that appear in the source sequence, 0.0 elsewhere.
                uniques = torch.unique(src_seq, dim=1).cpu().detach().numpy()

                batch_size = src_seq.shape[0]
                p_gen_mask = torch.zeros((batch_size, self.n_voca),
                                         dtype=torch.float)
                p_gen_mask[np.arange(batch_size)[:, None], uniques] = 1
                p_gen_mask = p_gen_mask.to(self.device)
                ##############################################################

                dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq,
                                                    enc_output)
                p_gen, *_ = self.model.p_generator(dec_seq, dec_pos, src_seq,
                                                   enc_output)
                p_gen = p_gen[:, -1, :]  # keep only the last decoding step

                # Sigmoid gate in [0, 1]: generate from vocab vs. copy from source
                p_gen = self.model.p_gen_linear(p_gen)
                p_gen = self.model.p_gen_sig(p_gen)

                # Pick the last step: (n_active_inst * n_bm) x d_h
                dec_output = dec_output[:, -1, :]

                seq_logit = self.model.tgt_word_prj(dec_output)

                # Zero out logits of tokens absent from the source. Only one
                # step is decoded at a time, so the (batch, vocab) mask is
                # applied directly.
                masked_seq_logit = seq_logit * p_gen_mask
                ###########################################################
                # Pointer-generator mixture of the two distributions
                softmax = torch.nn.Softmax(dim=1)
                prb_gen = softmax(seq_logit)        # generate from the full vocabulary
                prb_cp = softmax(masked_seq_logit)  # copy from source tokens only

                # Hard gate: choose copy or generate exclusively per instance
                exclusive_copy_or_gen = True
                if exclusive_copy_or_gen:
                    p_gen = (p_gen > 0.5).to(torch.float)

                prb = prb_gen * p_gen + prb_cp * (1 - p_gen)
                word_prob = prb.log()

                word_prob = word_prob.view(n_active_inst, n_bm, -1)

                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob,
                                             inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(
                        word_prob[inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list += [inst_idx]

                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)  # number of still-active instances

            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
            word_prob = predict_word(
                dec_seq,
                dec_pos,
                src_seq,
                enc_output,
                n_active_inst,  # number of still-active source instances
                n_bm)  # beam width

            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            all_hyp, all_scores = [], []
            for inst_idx in range(len(inst_dec_beams)):
                scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
                all_scores += [scores[:n_best]]

                hyps = [
                    inst_dec_beams[inst_idx].get_hypothesis(i)
                    for i in tail_idxs[:n_best]
                ]
                all_hyp += [hyps]
            return all_hyp, all_scores

        with torch.no_grad():
            # -- Encode
            src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
            src_enc, *_ = self.model.encoder(src_seq, src_pos)

            # -- Repeat data for beam search
            n_bm = self.opt.beam_size
            n_inst, len_s, d_h = src_enc.size()
            src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s,
                                                      d_h)

            # -- Prepare beams
            inst_dec_beams = [
                Beam(n_bm, device=self.device) for _ in range(n_inst)
            ]

            # -- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            # -- Decode
            for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1):

                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_seq, src_enc,
                    inst_idx_to_position_map, n_bm)
                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                src_seq, src_enc, inst_idx_to_position_map = \
                    collate_active_info(src_seq,
                                        src_enc,
                                        inst_idx_to_position_map,
                                        active_inst_idx_list)

        batch_hyp, batch_scores = collect_hypothesis_and_scores(
            inst_dec_beams, self.opt.n_best)

        return batch_hyp, batch_scores
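
The predict_word in this last translate_batch implements a pointer-generator style mixture: a sigmoid gate p_gen interpolates between the full-vocabulary distribution and one restricted to tokens that occur in the source. A self-contained sketch of just that mixing step, under the same masking convention as the code above (all names here are illustrative, not from this codebase):

import torch

def mix_generate_and_copy(seq_logit, src_token_mask, p_gen, hard_gate=True):
    # seq_logit: (batch, vocab) raw scores from the projection layer
    # src_token_mask: (batch, vocab), 1.0 where the token occurs in the source
    # p_gen: (batch, 1) gate in [0, 1]
    prb_gen = torch.softmax(seq_logit, dim=1)
    # Same convention as above: masked logits are zeroed, not set to -inf
    prb_cp = torch.softmax(seq_logit * src_token_mask, dim=1)
    if hard_gate:
        # Exclusive copy-or-generate, as in the example above
        p_gen = (p_gen > 0.5).to(torch.float)
    prb = p_gen * prb_gen + (1 - p_gen) * prb_cp
    return prb.log()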