def sample_G(self,
                 mbsize,
                 z,
                 c,
                 sample_mode='categorical',
                 temp=1.0,
                 gumbel_temp=1.0,
                 prepend_start_idx=True,
                 prevent_empty=False,
                 min_length=1,
                 beam_size=5,
                 n_best=3):
        """
        This function samples a minibatch of mbsize from the decoder, given a (z,c) input.
        sample_mode determines hard sampling (categorical / greedy / gumbel_max) vs soft (gumbel_soft, gumbel_ST, XX_softmax)
        prepend_start_idx will prepend dummy <start> token, matches dataloader format.
        prevent_empty will modify the probabilities before hard sampling from them.
        min_length will not modify sampling, but just have at least this length output even if it's just all padding.
        """
        sample_mode_soft = sample_mode in [
            'gumbel_soft', 'gumbel_ST', 'greedy_softmax',
            'categorical_softmax', 'none_softmax'
        ]
        assert not (
            sample_mode_soft and prevent_empty
        ), "prevent_empty is not supported with soft sampling: we don't want to modify the softmax in place before feeding it back into the next timestep"
        assert beam_size >= n_best, 'cannot return more hypotheses than beam_size'
        assert mbsize == z.size(0) == c.size(0), \
            'batch sizes do not match: {} {} {}'.format(mbsize, z.size(0), c.size(0))
        assert (
            not self.use_flow) or z.flowed, 'BUG: flow>0 but z.flowed=False'

        # Collecting sampled sequences - Note: does not work for beam search
        seqIx = []
        seqSoftIx = []

        # to mask out after EOS
        finished = torch.zeros(mbsize, dtype=torch.bool).to(self.device)

        if sample_mode == 'beam':

            def unbottle(m):
                return m.view(beam_size, mbsize, -1)

            # Repeat inputs beam_size times
            z = z.repeat(beam_size, 1)
            c = c.repeat(beam_size, 1)
            # Initialize Beams
            beam = [
                Beam(beam_size,
                     n_best=n_best,
                     device=self.device,
                     pad=PAD_IDX,
                     bos=START_IDX,
                     eos=EOS_IDX,
                     min_length=min_length) for ___ in range(mbsize)
            ]
            # Start: first beam BOS, rest PAD.
            sampleIx = torch.stack([b.get_current_state() for b in beam]) \
                .t().contiguous().view(-1)
        else:
            # Start: all BOS.
            sampleIx = torch.LongTensor(mbsize).to(
                self.device).fill_(START_IDX)
        sampleSoftIx = None

        # RNN state
        h = self.decoder.init_hidden(z, c)  # [mbsize x z,c]
        h = h.unsqueeze(0)  # prepend 1 = num_layers * num_directions

        # seqLogProbs = [] # unused for now
        # collecting sampled logprobs would be basis for all policy gradient algos (seqGAN etc)

        # include start_idx in the output
        if prepend_start_idx:
            seqIx.append(sampleIx)
            if sample_mode_soft:
                seqSoftIx.append(onehot_embed(sampleIx, self.n_vocab).detach())

        for i in range(self.MAX_SEQ_LEN):
            ### 1) FORWARD PASS THIS TIMESTEP
            logits, h = self.decoder.forward_sample(sampleSoftIx, sampleIx, z,
                                                    c, h)
            # TODO: replace with forward_decoder()
            if prevent_empty and i == 0:
                # Force the first token to be a real word by masking out the logits for
                # pad/start/eos. Use a large negative value rather than -inf so downstream
                # softmaxes are not thrown off.
                large_neg = -2 * torch.abs(logits.min())
                for maskix in [PAD_IDX, START_IDX, EOS_IDX]:
                    logits[:, maskix] = large_neg

            ### 2) GIVEN LOGITS, SAMPLE -> sampleIx, sampleLogProbs, sampleSoftIx
            if sample_mode == 'categorical':
                sampleIx = torch.distributions.Categorical(logits=logits /
                                                           temp).sample()
            elif sample_mode == 'greedy':
                sampleIx = torch.argmax(logits, 1)
            elif sample_mode == 'gumbel_max':
                # Gumbel-max trick: perturb the logits with Gumbel noise and take the argmax.
                # Hard decision, same distribution as Categorical sampling.
                gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
                sampleIx = torch.argmax(logits / temp + gumbels, 1)
            elif sample_mode == 'beam':
                logits = unbottle(logits)
                # Update the beams
                for j, b in enumerate(beam):
                    if not b.done():
                        logprobs = F.log_softmax(logits[:, j], dim=1)
                        b.advance(logprobs)
                    # Update corresponding hidden states
                    # NOTE if not advanced, the hidden will be reset and sampleIx will remain.
                    self._update_hidden(h, j, b.get_current_origin(),
                                        beam_size)
                # Get the current predictions
                sampleIx = torch.stack([b.get_current_state() for b in beam]) \
                    .t().contiguous().view(-1)
            # ABOVE: HARD SAMPLING, BELOW: SOFT SAMPLING
            elif sample_mode == 'gumbel_soft':
                # Keep the Gumbel-softmax relaxation as seqSoftIx, not straight-through.
                sampleSoftIx = F.gumbel_softmax(logits, tau=gumbel_temp, hard=False)
                sampleIx = torch.argmax(sampleSoftIx, 1)  # hard index, used for EOS tracking
            elif sample_mode == 'gumbel_ST':
                # Straight-through: onehot(argmax(gumbel_softmax)) in the forward pass,
                # biased soft gradients in the backward pass.
                sampleSoftIx = F.gumbel_softmax(logits, tau=gumbel_temp, hard=True)
                sampleIx = torch.argmax(sampleSoftIx, 1)
            # Below: sampleIx comes from none/greedy/categorical, and the softmax is kept as the
            # soft sample. Both seqIx and seqSoftIx are returned. The hard sample mode matters
            # because it determines when EOS is reached and all subsequent softmaxes are masked out.
            elif sample_mode == 'none_softmax':
                sampleSoftIx = F.softmax(logits / temp, dim=1)
            elif sample_mode == 'greedy_softmax':
                sampleIx = torch.argmax(logits, 1)
                sampleSoftIx = F.softmax(logits / temp, dim=1)
            elif sample_mode == 'categorical_softmax':
                sampleIx = torch.distributions.Categorical(logits=logits /
                                                           temp).sample()
                sampleSoftIx = F.softmax(logits / temp, dim=1)
            else:
                raise Exception(
                    'Sample mode {} not implemented.'.format(sample_mode))

            ### 3) FINISHED SENTENCES: MASK OUT sampleIx, sampleLogProbs, sampleSoftIx
            # Not in beam-search: implemented inside of Beam.py
            if not sample_mode == "beam":
                sampleIx.masked_fill_(finished, PAD_IDX)  #(mask, value)
                finished[
                    sampleIx ==
                    EOS_IDX] = True  # new EOS reached, mask out in the future.
                seqIx.append(sampleIx)

                if sample_mode_soft:
                    sampleSoftIx = sampleSoftIx.masked_fill(
                        finished.unsqueeze(1).clone(), 0)
                    # set "one-hots" to 0, will embed to 0 vector. Note not exactly the same as sampleIx=0 which will map to embedweight[0,:]
                    seqSoftIx.append(sampleSoftIx)

            ### 4) UPDATE MASK FOR NEXT ITERATION; BREAK (if all done)
            if finished.sum() == mbsize and len(seqIx) >= min_length:
                break  # everyone is done
            if sample_mode == "beam":
                if all((b.done() for b in beam)):
                    break
        if sample_mode == "beam":
            seqIx = []
            for b in beam:
                scores, ks = b.sort_finished(minimum=n_best)
                hyps = []
                for times, k in ks[:n_best]:
                    hyp = b.get_hyp(times, k)
                    hyps.append(hyp)
                seqIx.append(hyps)
            return seqIx

        # End of loop. Assemble seqIx, seqSoftIx into tensor.
        seqIx = torch.stack(seqIx, dim=1)  # bs x seqlen
        if sample_mode_soft:
            seqSoftIx = torch.stack(
                seqSoftIx, dim=1
            )  # bs x seqlen x vocab. Note seqlen dim is inserted in the middle.
            assert seqIx.size(1) == seqSoftIx.size(1), \
                'length mismatch between seqIx and seqSoftIx; check prepend_start_idx handling'
            return seqIx, seqSoftIx
        else:
            return seqIx  # only hard sampling.
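Note: for reference, a small self-contained sketch (independent of the model above) showing why the 'gumbel_max' mode samples from the same distribution as 'categorical': perturbing the logits with Gumbel noise and taking the argmax is equivalent to sampling from the softmax.

import torch

torch.manual_seed(0)
logits = torch.randn(5)
n = 200000

# Categorical sampling from softmax(logits).
cat = torch.distributions.Categorical(logits=logits).sample((n,))

# Gumbel-max trick: argmax over logits perturbed with Gumbel noise.
u = torch.rand(n, logits.size(0))
gumbels = -torch.log(-torch.log(u + 1e-20) + 1e-20)
gmax = torch.argmax(logits + gumbels, dim=1)

# Both empirical marginals should be close to softmax(logits).
print(torch.softmax(logits, dim=0))
print(torch.bincount(cat, minlength=5).float() / n)
print(torch.bincount(gmax, minlength=5).float() / n)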
Example #2
def translate(model, opt, src_batch, adj):
    ''' Translation work in one batch '''
    tt = torch.cuda if opt.cuda else torch

    # The batch size is in a different location depending on the data.
    src_seq, src_pos = src_batch
    batch_size = src_seq.size(0)
    beam_size = opt.beam_size

    #- Encode
    enc_output, *_ = model.encoder(src_seq, adj, src_pos)

    #--- Repeat data for beam
    src_seq = src_seq.data.repeat(1, beam_size).view(
        src_seq.size(0) * beam_size, src_seq.size(1))

    enc_output = enc_output.data.repeat(1, beam_size, 1).view(
        enc_output.size(0) * beam_size, enc_output.size(1), enc_output.size(2))

    #--- Prepare beams
    beams = [Beam(beam_size, opt.cuda) for _ in range(batch_size)]
    beam_inst_idx_map = {
        beam_idx: inst_idx
        for inst_idx, beam_idx in enumerate(range(batch_size))
    }
    n_remaining_sents = batch_size

    if opt.decoder == 'rnn_m':
        decoder_hidden = enc_output.mean(1)

    #- Decode
    for i in range(opt.max_token_seq_len_d):

        len_dec_seq = i + 1

        # -- Preparing decoded data seq -- #
        # size: batch x beam x seq

        dec_partial_seq = torch.stack(
            [b.get_current_state() for b in beams if not b.done])

        # size: (batch * beam) x seq
        dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)

        # -- Preparing decoded pos seq -- #
        # size: 1 x seq
        # dec_partial_pos = torch.arange(1, len_dec_seq + 1).unsqueeze(0)
        # size: (batch * beam) x seq
        # dec_partial_pos = dec_partial_pos.repeat(n_remaining_sents * beam_size, 1)
        # dec_partial_pos = dec_partial_pos.type(torch.LongTensor)

        if opt.cuda:
            dec_partial_seq = dec_partial_seq.cuda()
            # dec_partial_pos = dec_partial_pos.cuda()

        # -- Decoding -- #
        # print(dec_partial_seq)
        if opt.decoder == 'rnn_m':
            dec_enc_attn_pad_mask = get_attn_padding_mask(dec_partial_seq,
                                                          src_seq,
                                                          unsqueeze=False)
            dec_output, decoder_hidden, _ = model.decoder.forward_step(
                dec_partial_seq[:, -1].unsqueeze(1),
                decoder_hidden.squeeze(),
                enc_output,
                dec_enc_attn_pad_mask=dec_enc_attn_pad_mask)
            dec_output = dec_output[-1, :, :]
        else:
            # dec_output, *_ = model.decoder(dec_partial_seq, dec_partial_pos, src_seq, enc_output)
            dec_output, *_ = model.decoder(dec_partial_seq, src_seq,
                                           enc_output)
            dec_output = dec_output[:, -1, :]  # (batch * beam) * d_model
            dec_output = model.tgt_word_proj(dec_output)

            # dec_output += model.U(enc_output.mean(1))#.unsqueeze(1)

        # Mask previously predicted labels
        for J in range(dec_output.size(0)):
            dec_output.data[J].index_fill_(0, dec_partial_seq.data[J],
                                           -float('inf'))

        out = F.log_softmax(dec_output, dim=1)

        # batch x beam x n_words
        word_lk = out.view(n_remaining_sents, beam_size, -1).contiguous()

        active_beam_idx_list = []
        for beam_idx in range(batch_size):
            if beams[beam_idx].done:
                continue

            inst_idx = beam_inst_idx_map[beam_idx]
            if not beams[beam_idx].advance(word_lk.data[inst_idx]):
                active_beam_idx_list += [beam_idx]

        if not active_beam_idx_list:
            # all instances have finished their path to <EOS>
            break

        # in this section, the sentences that are still active are
        # compacted so that the decoder is not run on completed sentences
        active_inst_idxs = tt.LongTensor(
            [beam_inst_idx_map[k] for k in active_beam_idx_list])

        # update the idx mapping
        beam_inst_idx_map = {
            beam_idx: inst_idx
            for inst_idx, beam_idx in enumerate(active_beam_idx_list)
        }

        def update_active_seq(seq_var, active_inst_idxs):
            ''' Remove the src sequence of finished instances in one batch. '''

            inst_idx_dim_size, *rest_dim_sizes = seq_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(
                active_inst_idxs) // n_remaining_sents
            new_size = (inst_idx_dim_size, *rest_dim_sizes)

            # select the active instances in batch
            original_seq_data = seq_var.data.view(n_remaining_sents, -1)
            active_seq_data = original_seq_data.index_select(
                0, active_inst_idxs)
            active_seq_data = active_seq_data.view(*new_size)

            return active_seq_data

        def update_active_enc_info(enc_info_var, active_inst_idxs):
            ''' Remove the encoder outputs of finished instances in one batch. '''

            inst_idx_dim_size, *rest_dim_sizes = enc_info_var.size()
            inst_idx_dim_size = inst_idx_dim_size * len(
                active_inst_idxs) // n_remaining_sents
            new_size = (inst_idx_dim_size, *rest_dim_sizes)

            # select the active instances in batch
            original_enc_info_data = enc_info_var.data.view(
                n_remaining_sents, -1, opt.d_model)
            active_enc_info_data = original_enc_info_data.index_select(
                0, active_inst_idxs)
            active_enc_info_data = active_enc_info_data.view(*new_size)

            return active_enc_info_data

        src_seq = update_active_seq(src_seq, active_inst_idxs)
        enc_output = update_active_enc_info(enc_output, active_inst_idxs)

        if opt.decoder == 'rnn_m':
            decoder_hidden = update_active_enc_info(
                decoder_hidden.transpose(0, 1), active_inst_idxs)
            decoder_hidden = decoder_hidden.transpose(0, 1)

        #- update the remaining size
        n_remaining_sents = len(active_inst_idxs)

    #- Return useful information
    all_hyp, all_hyp_scores, all_scores = [], [], []
    n_best = opt.n_best
    for beam_idx in range(batch_size):
        scores, tail_idxs = beams[beam_idx].sort_scores()
        all_scores += [scores[:n_best]]

        hyps = [beams[beam_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
        all_hyp += [hyps]

        all_hyp_scores += [[
            torch.exp(i)[0] for i in beams[beam_idx].all_scores
        ]]

    return all_hyp, all_hyp_scores  #,all_scores
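Note: a minimal, self-contained sketch of the "repeat data for beam" layout used above. Each source row is tiled beam_size times so the decoder runs on a (batch * beam) batch, with all beams of the same sentence kept in adjacent rows.

import torch

batch_size, seq_len, beam_size = 2, 4, 3
src_seq = torch.arange(batch_size * seq_len).view(batch_size, seq_len)

# repeat(1, beam_size) tiles along the sequence dimension, then view() regroups it so
# rows 0..beam_size-1 are copies of sentence 0, the next beam_size rows sentence 1, etc.
src_rep = src_seq.repeat(1, beam_size).view(batch_size * beam_size, seq_len)
print(src_rep)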
Example #3
    def decode_beam_search(self,
                           hidden_states,
                           memory,
                           src_mask,
                           vocab,
                           copy_tokens=None,
                           beam_size=5,
                           n_best=1,
                           alpha=0.6,
                           length_pen='avg'):
        """
            Beam search decoding
        """
        results = {"scores": [], "predictions": []}

        # Construct beams; we do not use stepwise coverage penalty or n-gram blocking
        remaining_sents = memory.size(0)
        global_scorer = GNMTGlobalScorer(alpha, length_pen)
        beam = [
            Beam(beam_size,
                 vocab,
                 global_scorer=global_scorer,
                 device=memory.device) for _ in range(remaining_sents)
        ]

        # repeat beam_size times
        memory, src_mask, copy_tokens = tile([memory, src_mask, copy_tokens],
                                             beam_size,
                                             dim=0)
        hidden_states = tile(hidden_states, beam_size, dim=1)
        h_c = type(hidden_states) in [list, tuple]
        batch_idx = list(range(remaining_sents))

        for i in range(MAX_DECODE_LENGTH):
            # (a) construct beamsize * remaining_sents next words
            ys = torch.stack([
                b.get_current_state() for b in beam if not b.done()
            ]).contiguous().view(-1, 1)

            # (b) pass through the decoder network
            out, hidden_states = self.decode_one_step(ys, hidden_states,
                                                      memory, src_mask,
                                                      copy_tokens)
            out = out.contiguous().view(remaining_sents, beam_size, -1)

            # (c) advance each beam
            active, select_indices_array = [], []
            # Loop over the remaining_batch number of beam
            for b in range(remaining_sents):
                idx = batch_idx[b]  # idx is the original position of this beam in the minibatch
                beam[idx].advance(out[b])
                if not beam[idx].done():
                    active.append((idx, b))
                select_indices_array.append(beam[idx].get_current_origin() +
                                            b * beam_size)

            # (d) update hidden_states history
            select_indices_array = torch.cat(select_indices_array, dim=0)
            if h_c:
                hidden_states = (hidden_states[0].index_select(
                    1, select_indices_array), hidden_states[1].index_select(
                        1, select_indices_array))
            else:
                hidden_states = hidden_states.index_select(
                    1, select_indices_array)

            if not active:
                break

            # (e) reserve un-finished batches
            active_idx = torch.tensor(
                [item[1] for item in active],
                dtype=torch.long,
                device=memory.device)  # original order in remaining batch
            batch_idx = {idx: item[0]
                         for idx, item in enumerate(active)
                         }  # order for next remaining batch

            def update_active(t):
                if t is None: return t
                t_reshape = t.contiguous().view(remaining_sents, beam_size, -1)
                new_size = list(t.size())
                new_size[0] = -1
                return t_reshape.index_select(0, active_idx).view(*new_size)

            if h_c:
                hidden_states = (update_active(hidden_states[0].transpose(
                    0, 1)).transpose(0, 1).contiguous(),
                                 update_active(hidden_states[1].transpose(
                                     0, 1)).transpose(0, 1).contiguous())
            else:
                hidden_states = update_active(hidden_states.transpose(
                    0, 1)).transpose(0, 1).contiguous()
            memory = update_active(memory)
            src_mask = update_active(src_mask)
            copy_tokens = update_active(copy_tokens)
            remaining_sents = len(active)

        for b in beam:
            scores, ks = b.sort_finished(minimum=n_best)
            hyps = []
            for times, k in ks[:n_best]:
                hyp = b.get_hyp(times, k)
                hyps.append(hyp.tolist())  # hyp contains </s> but does not contain <s>
            results["predictions"].append(
                hyps)  # batch list of variable_tgt_len
            results["scores"].append(torch.stack(
                scores)[:n_best])  # list of [n_best], torch.FloatTensor
        results["scores"] = torch.stack(results["scores"])
        return results
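Note: GNMTGlobalScorer rescores finished hypotheses with a length penalty. A rough sketch of the two penalties selected by length_pen, assuming the standard GNMT formula for the 'wu' variant and plain per-token averaging for 'avg' (the actual scorer implementation is not shown here and may differ):

def wu_length_penalty(length, alpha=0.6):
    # Wu et al. (2016): ((5 + |Y|) / 6) ** alpha
    return ((5.0 + length) / 6.0) ** alpha

def avg_length_penalty(length, alpha=0.6):
    # 'avg': normalize by the number of generated tokens.
    return float(length)

# A hypothesis is ranked by its cumulative log-probability divided by the penalty.
cum_logprob, length = -4.2, 7
print(cum_logprob / wu_length_penalty(length), cum_logprob / avg_length_penalty(length))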
Example #4
    def translate_batch_ENSEMBLE(self, enc_output, enc_hidden, category):
        ''' Translation work in one batch '''
        def beam_decode_step(inst_dec_beams, enc_output, enc_hidden,
                             inst_idx_to_position_map, n_bm, category):
            ''' Decode and update beam status, and then return active beam idx '''
            def prepare_beam_dec_seq(inst_dec_beams):
                dec_partial_seq = [
                    b.get_lastest_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1)
                #print(dec_partial_seq)
                return dec_partial_seq

            def predict_word(dec_seq, enc_output, enc_hidden, n_active_inst,
                             n_bm, category):
                word_prob = []
                for i in range(len(enc_output)):
                    res = self.model[i].decoder(it=dec_seq,
                                                encoder_outputs=enc_output[i],
                                                category=category,
                                                decoder_hidden=enc_hidden[i])
                    dec_output, enc_hidden[i] = res['dec_outputs'], res[
                        'dec_hidden']

                    tmp = F.log_softmax(self.model[i].tgt_word_prj(dec_output),
                                        dim=1)
                    tmp = tmp.view(n_active_inst, n_bm, -1)
                    word_prob.append(tmp)

                word_prob = torch.stack(word_prob, dim=0).mean(0)
                return word_prob, enc_hidden

            def collect_active_hidden_single(inst_beams,
                                             inst_idx_to_position_map,
                                             enc_hidden, n_bm):
                if isinstance(enc_hidden, tuple):
                    tmp1, tmp2 = enc_hidden
                    _, *d_hs = tmp1.size()
                    n_curr_active_inst = len(inst_idx_to_position_map)
                    new_shape = (n_curr_active_inst * n_bm, *d_hs)
                    tmp1 = tmp1.view(n_curr_active_inst, n_bm, -1)
                    tmp2 = tmp2.view(n_curr_active_inst, n_bm, -1)
                    #print('hidden:', tmp1)

                    for inst_idx, inst_position in inst_idx_to_position_map.items(
                    ):
                        _prev_ks = inst_beams[inst_idx].get_current_origin()
                        tmp1[inst_position] = tmp1[inst_position].index_select(
                            0, _prev_ks)
                        tmp2[inst_position] = tmp2[inst_position].index_select(
                            0, _prev_ks)
                        #print("PREV_KS:", _prev_ks)
                    #print('after h:', tmp1)
                    tmp1 = tmp1.view(*new_shape)
                    tmp2 = tmp2.view(*new_shape)
                    enc_hidden = (tmp1, tmp2)
                else:
                    _, *d_hs = enc_hidden.size()
                    n_curr_active_inst = len(inst_idx_to_position_map)
                    new_shape = (n_curr_active_inst * n_bm, *d_hs)
                    enc_hidden = enc_hidden.view(n_curr_active_inst, n_bm, -1)

                    for inst_idx, inst_position in inst_idx_to_position_map.items(
                    ):
                        _prev_ks = inst_beams[inst_idx].get_current_origin()
                        enc_hidden[inst_position] = enc_hidden[
                            inst_position].index_select(0, _prev_ks)

                    enc_hidden = enc_hidden.view(*new_shape)

                return enc_hidden

            def collect_active_hidden(inst_beams, inst_idx_to_position_map,
                                      enc_hidden, n_bm):
                if enc_hidden is None:
                    return None
                if isinstance(enc_hidden, list):
                    hidden = []
                    for item in enc_hidden:
                        hidden.append(
                            collect_active_hidden_single(
                                inst_beams, inst_idx_to_position_map, item,
                                n_bm))
                else:
                    hidden = collect_active_hidden_single(
                        inst_beams, inst_idx_to_position_map, enc_hidden, n_bm)
                return hidden

            n_active_inst = len(inst_idx_to_position_map)
            dec_seq = prepare_beam_dec_seq(inst_dec_beams)
            #print(dec_seq)
            #print('before:', enc_hidden[0])
            word_prob, enc_hidden = predict_word(dec_seq, enc_output,
                                                 enc_hidden, n_active_inst,
                                                 n_bm, category)
            #print('after:', enc_hidden[0])
            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = self.collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            enc_hidden = [
                collect_active_hidden(inst_dec_beams, inst_idx_to_position_map,
                                      item, n_bm) for item in enc_hidden
            ]
            return active_inst_idx_list, enc_hidden

        with torch.no_grad():
            assert isinstance(enc_output, list)
            assert isinstance(enc_hidden, list)
            assert len(enc_output) == len(self.model)
            assert len(enc_output) == len(enc_hidden)

            for i in range(len(enc_output)):
                if not isinstance(enc_output[i], list):
                    enc_output[i] = [enc_output[i]]

            n_bm = self.opt["beam_size"]
            n_inst, len_s, d_h = enc_output[0][0].size()

            #-- Repeat data for beam search
            category = category.unsqueeze(1).repeat(1, n_bm, 1).view(
                n_inst * n_bm, self.opt['num_category'])
            enc_output = [[
                tmp.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
                for tmp in item
            ] for item in enc_output]
            for i in range(len(enc_hidden)):
                if isinstance(enc_hidden[i], tuple):
                    enc_hidden[i] = (enc_hidden[i][0].unsqueeze(1).repeat(
                        1, n_bm,
                        1).view(n_inst * n_bm,
                                d_h), enc_hidden[i][1].unsqueeze(1).repeat(
                                    1, n_bm, 1).view(n_inst * n_bm, d_h))
                else:
                    enc_hidden[i] = enc_hidden[i].unsqueeze(1).repeat(
                        1, n_bm, 1).view(n_inst * n_bm, d_h)

            #-- initialize hidden state
            for i in range(len(enc_output)):
                enc_hidden[i] = self.model[i].decoder.init_hidden(
                    enc_hidden[i])

            #-- Prepare beams
            inst_dec_beams = [
                Beam(n_bm, self.opt["max_len"], device=self.device)
                for _ in range(n_inst)
            ]

            #-- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = self.get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            #-- Decode
            for t in range(1, self.opt["max_len"]):
                active_inst_idx_list, enc_hidden = beam_decode_step(
                    inst_dec_beams, enc_output, enc_hidden,
                    inst_idx_to_position_map, n_bm, category)

                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                enc_output, enc_hidden, category, inst_idx_to_position_map = self.collate_active_info(
                    enc_output,
                    inst_idx_to_position_map,
                    active_inst_idx_list,
                    category,
                    n_bm,
                    enc_hidden=enc_hidden)

        batch_hyp, batch_scores = self.collect_hypothesis_and_scores(
            inst_dec_beams, self.opt.get("topk", 1))

        return batch_hyp, batch_scores
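Note: in predict_word above, the ensemble averages per-model log-probabilities (torch.stack(word_prob).mean(0)), which amounts to a geometric mean of the members' distributions. A tiny self-contained illustration with hypothetical logits:

import torch
import torch.nn.functional as F

# Hypothetical per-step logits from two ensemble members for one beam entry.
logits_a = torch.tensor([[2.0, 0.5, -1.0]])
logits_b = torch.tensor([[1.0, 1.5, -0.5]])

# Average the log-softmax outputs, as in predict_word.
avg_log_prob = torch.stack([F.log_softmax(logits_a, dim=1),
                            F.log_softmax(logits_b, dim=1)], dim=0).mean(0)

# exp() of the average is the (unnormalized) geometric mean of the two distributions.
print(avg_log_prob.exp())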
Example #5
    def translate_batch_LSTM(self, encoder_outputs, category):
        ''' Translation work in one batch '''
        def beam_decode_step(inst_dec_beams, enc_output, enc_hidden,
                             inst_idx_to_position_map, n_bm, category, tag):
            ''' Decode and update beam status, and then return active beam idx '''
            def prepare_beam_dec_seq(inst_dec_beams):
                dec_partial_seq = [
                    b.get_lastest_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1)
                #print(dec_partial_seq)
                return dec_partial_seq

            def predict_word(dec_seq, enc_output, enc_hidden, n_active_inst,
                             n_bm, category, tag):
                res = self.model.decoder(it=dec_seq,
                                         encoder_outputs=enc_output,
                                         category=category,
                                         decoder_hidden=enc_hidden,
                                         tag=tag)
                dec_output, enc_hidden, tag = res['dec_outputs'], res[
                    'dec_hidden'], res.get('pred_tag', None)
                word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output),
                                          dim=1)
                word_prob = word_prob.view(n_active_inst, n_bm, -1)
                return word_prob, enc_hidden, tag.argmax(
                    1) if tag is not None else None

            def collect_active_hidden_single(inst_beams,
                                             inst_idx_to_position_map,
                                             enc_hidden, n_bm):
                if isinstance(enc_hidden, tuple):
                    tmp1, tmp2 = enc_hidden
                    _, *d_hs = tmp1.size()
                    n_curr_active_inst = len(inst_idx_to_position_map)
                    new_shape = (n_curr_active_inst * n_bm, *d_hs)
                    tmp1 = tmp1.view(n_curr_active_inst, n_bm, -1)
                    tmp2 = tmp2.view(n_curr_active_inst, n_bm, -1)
                    #print('hidden:', tmp1)

                    for inst_idx, inst_position in inst_idx_to_position_map.items(
                    ):
                        _prev_ks = inst_beams[inst_idx].get_current_origin()
                        tmp1[inst_position] = tmp1[inst_position].index_select(
                            0, _prev_ks)
                        tmp2[inst_position] = tmp2[inst_position].index_select(
                            0, _prev_ks)
                        #print("PREV_KS:", _prev_ks)
                    #print('after h:', tmp1)
                    tmp1 = tmp1.view(*new_shape)
                    tmp2 = tmp2.view(*new_shape)
                    enc_hidden = (tmp1, tmp2)
                else:
                    _, *d_hs = enc_hidden.size()
                    n_curr_active_inst = len(inst_idx_to_position_map)
                    new_shape = (n_curr_active_inst * n_bm, *d_hs)
                    enc_hidden = enc_hidden.view(n_curr_active_inst, n_bm, -1)

                    for inst_idx, inst_position in inst_idx_to_position_map.items(
                    ):
                        _prev_ks = inst_beams[inst_idx].get_current_origin()
                        enc_hidden[inst_position] = enc_hidden[
                            inst_position].index_select(0, _prev_ks)

                    enc_hidden = enc_hidden.view(*new_shape)

                return enc_hidden

            def collect_active_hidden(inst_beams, inst_idx_to_position_map,
                                      enc_hidden, n_bm):
                if enc_hidden is None:
                    return None
                if isinstance(enc_hidden, list):
                    hidden = []
                    for item in enc_hidden:
                        hidden.append(
                            collect_active_hidden_single(
                                inst_beams, inst_idx_to_position_map, item,
                                n_bm))
                else:
                    hidden = collect_active_hidden_single(
                        inst_beams, inst_idx_to_position_map, enc_hidden, n_bm)
                return hidden

            n_active_inst = len(inst_idx_to_position_map)
            dec_seq = prepare_beam_dec_seq(inst_dec_beams)
            #print(dec_seq)
            #print('before:', enc_hidden[0])
            word_prob, enc_hidden, tag = predict_word(dec_seq, enc_output,
                                                      enc_hidden,
                                                      n_active_inst, n_bm,
                                                      category, tag)
            #print('after:', enc_hidden[0])
            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = self.collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            enc_hidden = collect_active_hidden(inst_dec_beams,
                                               inst_idx_to_position_map,
                                               enc_hidden, n_bm)
            tag = collect_active_hidden(inst_dec_beams,
                                        inst_idx_to_position_map, tag, n_bm)
            return active_inst_idx_list, enc_hidden, tag

        with torch.no_grad():
            enc_output, enc_hidden = encoder_outputs[
                'enc_output'], encoder_outputs['enc_hidden']
            if not isinstance(enc_output, list):
                enc_output = [enc_output]

            n_bm = self.opt["beam_size"]
            n_inst, len_s, _ = enc_output[0].shape

            #-- Repeat data for beam search
            enc_output = [
                item.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, -1)
                for item in enc_output
            ]
            if isinstance(enc_hidden, tuple):
                n_inst, d_h = enc_hidden[0].size()
                enc_hidden = (enc_hidden[0].unsqueeze(1).repeat(
                    1, n_bm, 1).view(n_inst * n_bm,
                                     d_h), enc_hidden[1].unsqueeze(1).repeat(
                                         1, n_bm, 1).view(n_inst * n_bm, d_h))
            elif isinstance(enc_hidden, list):
                n_inst, d_h = enc_hidden[0].size()
                enc_hidden = [
                    item.unsqueeze(1).repeat(1, n_bm,
                                             1).view(n_inst * n_bm, d_h)
                    for item in enc_hidden
                ]
            else:
                n_inst, d_h = enc_hidden.size()
                enc_hidden = enc_hidden.unsqueeze(1).repeat(1, n_bm, 1).view(
                    n_inst * n_bm, d_h)
            enc_hidden = self.model.decoder.init_hidden(enc_hidden)
            if encoder_outputs.get('obj_emb', None) is not None:
                if self.opt['with_category']:
                    category = torch.cat(
                        [category, encoder_outputs['obj_emb']], dim=1)
                else:
                    category = encoder_outputs['obj_emb']

            category = category.unsqueeze(1).repeat(1, n_bm,
                                                    1).view(n_inst * n_bm, -1)

            #-- Prepare beams
            inst_dec_beams = [
                Beam(n_bm, self.opt["max_len"], device=self.device)
                for _ in range(n_inst)
            ]

            if self.opt['use_tag']:
                tag = category.new(n_inst, n_bm).fill_(Constants.BOS).view(
                    n_inst * n_bm).long()
            else:
                tag = None

            #-- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = self.get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            #-- Decode
            for t in range(1, self.opt["max_len"]):
                active_inst_idx_list, enc_hidden, tag = beam_decode_step(
                    inst_dec_beams, enc_output, enc_hidden,
                    inst_idx_to_position_map, n_bm, category, tag)

                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                enc_output, enc_hidden, category, inst_idx_to_position_map, tag = self.collate_active_info(
                    enc_output,
                    inst_idx_to_position_map,
                    active_inst_idx_list,
                    category,
                    n_bm,
                    enc_hidden=enc_hidden,
                    tag=tag)

        batch_hyp, batch_scores = self.collect_hypothesis_and_scores(
            inst_dec_beams, self.opt.get("topk", 1))

        return batch_hyp, batch_scores
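Note: collect_active_hidden_single above reorders the per-beam hidden states with the beam backpointers (get_current_origin), so each surviving beam continues from the hidden state of the beam it extends. A minimal sketch of that reordering with dummy tensors and hypothetical backpointers:

import torch

n_inst, n_bm, d_h = 2, 3, 4
hidden = torch.arange(n_inst * n_bm * d_h, dtype=torch.float).view(n_inst * n_bm, d_h)

# Hypothetical backpointers for instance 0: new beam k extends previous beam prev_ks[k].
prev_ks = torch.tensor([2, 0, 0])

h = hidden.view(n_inst, n_bm, d_h)
h[0] = h[0].index_select(0, prev_ks)  # reorder instance 0 along the beam dimension
hidden = h.view(n_inst * n_bm, d_h)
print(hidden)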
Example #6
    def translate_batch_ARFormer(self, encoder_outputs, category):
        ''' Translation work in one batch '''
        def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
                             inst_idx_to_position_map, n_bm, category,
                             attribute):
            ''' Decode and update beam status, and then return active beam idx '''
            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
                #print(dec_partial_seq)
                return dec_partial_seq

            def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                             category, attribute):
                dec_output, *_ = self.model.decoder(dec_seq,
                                                    enc_output,
                                                    category,
                                                    tags=attribute)
                if isinstance(dec_output, list):
                    dec_output = dec_output[-1]
                dec_output = dec_output[:, -1, :]  # Pick the last step: (batch * beam) x d_h
                word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output),
                                          dim=1)
                word_prob = word_prob.view(n_active_inst, n_bm, -1)

                return word_prob

            n_active_inst = len(inst_idx_to_position_map)

            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                                     category, attribute)

            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = self.collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            return active_inst_idx_list

        with torch.no_grad():
            enc_output = encoder_outputs['enc_output']
            if isinstance(enc_output, list):
                assert len(enc_output) == 1
                enc_output = enc_output[0]
            #-- Repeat data for beam search
            n_bm = self.opt["beam_size"]
            n_inst, len_s, d_h = enc_output.size()
            enc_output = enc_output.repeat(1, n_bm,
                                           1).view(n_inst * n_bm, len_s, d_h)
            category = category.repeat(1, n_bm).view(n_inst * n_bm, 1)

            e = enc_output.clone()
            c = category.clone()

            attribute = encoder_outputs.get(Constants.mapping['attr'][0], None)
            if attribute is not None:
                attribute = attribute.unsqueeze(1).repeat(1, n_bm, 1).view(
                    n_inst * n_bm, -1)

            #-- Prepare beams
            inst_dec_beams = [
                Beam(n_bm,
                     self.opt["max_len"],
                     device=self.device,
                     specific_nums_of_sents=self.opt.get('topk', 1))
                for _ in range(n_inst)
            ]

            #-- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = self.get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            #-- Decode
            for len_dec_seq in range(1, self.opt["max_len"]):

                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, enc_output,
                    inst_idx_to_position_map, n_bm, category, attribute)

                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                enc_output, category, inst_idx_to_position_map, attribute = self.collate_active_info(
                    enc_output,
                    inst_idx_to_position_map,
                    active_inst_idx_list,
                    category,
                    n_bm,
                    tag=attribute)

        if self.opt.get('use_beam_decoder', False):
            batch_hyp, batch_scores = self.collect_hypothesis_and_scores_bd(
                inst_dec_beams, self.opt.get("topk", 1), e, c)
        else:
            batch_hyp, batch_scores = self.collect_hypothesis_and_scores(
                inst_dec_beams, self.opt.get("topk", 1))

        return batch_hyp, batch_scores
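Note: the main difference between the LSTM and ARFormer beam steps is what is fed to the decoder at each iteration. The LSTM variant feeds only the latest token per beam (the recurrent state carries the prefix), while the autoregressive transformer re-feeds the whole prefix every step. A small sketch of the two input shapes:

import torch

n_inst, n_bm, len_dec_seq = 2, 3, 4
# Hypothetical beam prefixes (token ids) after len_dec_seq steps.
prefixes = torch.randint(0, 10, (n_inst * n_bm, len_dec_seq))

last_tokens = prefixes[:, -1].contiguous().view(-1)  # LSTM step input: (n_inst * n_bm,)
full_prefix = prefixes.view(-1, len_dec_seq)         # ARFormer step input: (n_inst * n_bm, len_dec_seq)
print(last_tokens.shape, full_prefix.shape)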
Example #7
    def decode_beam_search(self,
                           word_seqs,
                           lengths,
                           beam_size,
                           tag2idx,
                           extFeats=None,
                           with_snt_classifier=False,
                           masked_output=None):
        # Equivalently: word_seqs.size(0) if self.encoder.batch_first else word_seqs.size(1)
        minibatch_size = len(lengths)
        # Equivalently: word_seqs.size(1) if self.encoder.batch_first else word_seqs.size(0)
        max_length = max(lengths)
        # encoder
        embeds = self.get_token_embeddings(word_seqs, lengths)
        if extFeats is not None:
            concat_input = torch.cat((embeds, self.extFeats_linear(extFeats)), 2)
        else:
            concat_input = embeds
        concat_input = self.dropout_layer(concat_input)
        packed_word_embeds = rnn_utils.pack_padded_sequence(concat_input,
                                                            lengths,
                                                            batch_first=True)
        packed_word_lstm_out, (enc_h_t, enc_c_t) = self.encoder(
            packed_word_embeds)  # bsize x seqlen x dim
        enc_word_lstm_out, unpacked_len = rnn_utils.pad_packed_sequence(
            packed_word_lstm_out, batch_first=True)

        # decoder
        if self.bidirectional:
            index_slices = [2 * i + 1 for i in range(self.num_layers)
                            ]  # generated from the reversed path
            index_slices = torch.tensor(index_slices,
                                        dtype=torch.long,
                                        device=self.device)
            h_t = torch.index_select(enc_h_t, 0, index_slices)
            c_t = torch.index_select(enc_c_t, 0, index_slices)
        else:
            h_t = enc_h_t
            c_t = enc_c_t

        h_t = h_t.repeat(1, beam_size, 1)
        c_t = c_t.repeat(1, beam_size, 1)
        word_lstm_out = enc_word_lstm_out.repeat(beam_size, 1, 1)

        beam = [
            Beam(beam_size, tag2idx, device=self.device)
            for k in range(minibatch_size)
        ]
        batch_idx = list(range(minibatch_size))
        remaining_sents = minibatch_size

        top_dec_h_t, top_dec_c_t = [0] * minibatch_size, [0] * minibatch_size
        for i in range(max_length):
            last_tags = torch.stack([
                b.get_current_state() for b in beam if not b.done
            ]).t().contiguous().view(-1,
                                     1)  # after t() -> beam_size * batch_size
            last_tags = last_tags.to(self.device)
            tag_embeds = self.dropout_layer(self.tag_embeddings(last_tags))
            decode_inputs = torch.cat(
                (self.dropout_layer(word_lstm_out[:, i:i + 1]), tag_embeds),
                2)  # (batch*beam) x 1 x insize
            tag_lstm_out, (dec_h_t, dec_c_t) = self.decoder(
                decode_inputs,
                (h_t,
                 c_t))  # (batch*beam) x 1 x insize => (batch*beam) x 1 x hsize

            tag_lstm_out_reshape = tag_lstm_out.contiguous().view(
                tag_lstm_out.size(0) * tag_lstm_out.size(1),
                tag_lstm_out.size(2))
            tag_space = self.hidden2tag(
                self.dropout_layer(tag_lstm_out_reshape))
            out = F.log_softmax(tag_space, dim=1)  # (batch*beam) x outsize

            word_lk = out.view(beam_size, remaining_sents,
                               -1).transpose(0, 1).contiguous()

            active = []
            for b in range(minibatch_size):
                if beam[b].done:
                    continue
                if lengths[b] == i + 1:
                    beam[b].done = True
                    top_dec_h_t[b] = dec_h_t[:, b:b + beam_size, :]
                    top_dec_c_t[b] = dec_c_t[:, b:b + beam_size, :]
                idx = batch_idx[b]
                beam[b].advance(word_lk.data[idx])
                if not beam[b].done:
                    active.append(b)
                for dec_state in (dec_h_t, dec_c_t):
                    # (layer*direction) x beam*sent x Hdim
                    sent_states = dec_state.view(-1, beam_size,
                                                 remaining_sents,
                                                 dec_state.size(2))[:, :, idx]
                    sent_states.data.copy_(
                        sent_states.data.index_select(
                            1, beam[b].get_current_origin()))
            if not active:
                break

            active_idx = torch.tensor([batch_idx[k] for k in active],
                                      dtype=torch.long,
                                      device=self.device)
            batch_idx = {beam: idx for idx, beam in enumerate(active)}

            def update_active(t, hidden_dim):
                #t_reshape = t.data.view(-1, remaining_sents, hidden_dim)
                t_reshape = t.contiguous().view(-1, remaining_sents,
                                                hidden_dim)
                new_size = list(t.size())
                new_size[-2] = new_size[-2] * len(
                    active_idx) // remaining_sents  # beam*len(active_idx)
                return t_reshape.index_select(1, active_idx).view(*new_size)

            h_t = update_active(dec_h_t, self.hidden_dim)
            c_t = update_active(dec_c_t, self.hidden_dim)
            word_lstm_out = update_active(
                word_lstm_out.transpose(0, 1),
                self.num_directions * self.hidden_dim).transpose(0, 1)

            remaining_sents = len(active)

        allHyp, allScores = [], []
        n_best = 1
        for b in range(minibatch_size):
            scores, ks = beam[b].sort_best()
            allScores += [scores[:n_best]]
            hyps = zip(*[beam[b].get_hyp(k) for k in ks[:n_best]])
            allHyp += [hyps]
            top_dec_h_t[b] = top_dec_h_t[b].data.index_select(1, ks[:n_best])
            top_dec_c_t[b] = top_dec_c_t[b].data.index_select(1, ks[:n_best])
        top_dec_h_t = torch.cat(top_dec_h_t, 1)
        top_dec_c_t = torch.cat(top_dec_c_t, 1)
        allScores = torch.cat(allScores)

        if with_snt_classifier:
            return allScores, allHyp, ((enc_h_t, enc_c_t), enc_word_lstm_out,
                                       lengths)
        else:
            return allScores, allHyp
    def translate_batch_ARFormer(self, encoder_outputs, category):
        ''' Translation work in one batch '''
        def beam_decode_step(inst_dec_beams, len_dec_seq, inputs_for_decoder,
                             inst_idx_to_position_map, n_bm):
            ''' Decode and update beam status, and then return active beam idx '''
            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
                return dec_partial_seq

            def predict_word(dec_seq, inputs_for_decoder, n_active_inst, n_bm):
                dec_output, *_ = self.model.decoder(dec_seq,
                                                    **inputs_for_decoder)
                if isinstance(dec_output, list):
                    dec_output = dec_output[-1]
                dec_output = dec_output[:, -1, :]  # Pick the last step: (batch * beam) x d_h

                word_prob = self.model.tgt_word_prj(dec_output)
                word_prob = F.log_softmax(word_prob, dim=1)
                #print(word_prob[0, :10])
                word_prob = word_prob.view(n_active_inst, n_bm, -1)

                return word_prob

            n_active_inst = len(inst_idx_to_position_map)

            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            word_prob = predict_word(dec_seq, inputs_for_decoder,
                                     n_active_inst, n_bm)

            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = self.collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            return active_inst_idx_list

        with torch.no_grad():
            inputs_for_decoder = self.model.prepare_inputs_for_decoder(
                encoder_outputs, category)
            #-- Repeat data for beam search
            n_bm = self.opt["beam_size"]
            n_inst = inputs_for_decoder['enc_output'].size(0)

            for key in inputs_for_decoder:
                inputs_for_decoder[key] = auto_enlarge(inputs_for_decoder[key],
                                                       n_bm)

            #-- Prepare beams
            inst_dec_beams = [
                Beam(n_bm,
                     self.opt["max_len"],
                     device=self.device,
                     specific_nums_of_sents=self.opt.get('topk', 1))
                for _ in range(n_inst)
            ]

            #-- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = self.get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

            #-- Decode
            for len_dec_seq in range(1, self.opt["max_len"]):

                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, inputs_for_decoder,
                    inst_idx_to_position_map, n_bm)

                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                inputs_for_decoder, inst_idx_to_position_map = self.collate_active_info(
                    inputs_for_decoder, inst_idx_to_position_map,
                    active_inst_idx_list, n_bm)

        batch_hyp, batch_scores = self.collect_hypothesis_and_scores(
            inst_dec_beams, self.opt.get("topk", 1))

        return batch_hyp, batch_scores
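Note: all of these decoders share the same bookkeeping trick: once some instances finish, the (batch * beam) tensors are compacted so the decoder only runs on active instances. A self-contained sketch of the reshape / index_select pattern behind update_active and collate_active_info:

import torch

n_inst, n_bm, d = 3, 2, 5
enc_output = torch.randn(n_inst * n_bm, d)

# Suppose instance 1 has finished; keep only instances 0 and 2.
active_idx = torch.tensor([0, 2])

compact = enc_output.view(n_inst, n_bm, d) \
                    .index_select(0, active_idx) \
                    .view(len(active_idx) * n_bm, d)
print(compact.shape)  # torch.Size([4, 5])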