Example No. 1
    def translate_batch(self, batch):

        torch.set_grad_enabled(False)
        # Batch size is in different location depending on data.

        beam_size = self.opt.beam_size
        batch_size = batch.size

        gold_scores = batch.get('source').data.new(batch_size).float().zero_()
        gold_words = 0
        allgold_scores = []

        if batch.has_target:
            # Use the first model to decode
            model_ = self.models[0]

            gold_words, gold_scores, allgold_scores = model_.decode(batch)

        #  (3) Start decoding

        # time x batch * beam

        # initialize the beam
        beam = [onmt.Beam(beam_size, self.bos_id, self.opt.cuda, self.opt.sampling) for k in range(batch_size)]

        batch_idx = list(range(batch_size))
        remaining_sents = batch_size

        decoder_states = dict()

        for i in range(self.n_models):
            decoder_states[i] = self.models[i].create_decoder_state(batch, beam_size)

        if self.opt.lm:
            lm_decoder_states = self.lm_model.create_decoder_state(batch, beam_size)

        for i in range(self.opt.max_sent_length):
            # Prepare decoder input.

            # input size: 1 x ( batch * beam )
            input = torch.stack([b.getCurrentState() for b in beam
                                 if not b.done]).t().contiguous().view(1, -1)

            decoder_input = input

            # require batch first for everything
            outs = dict()
            attns = dict()

            for k in range(self.n_models):
                # decoder_hidden, coverage = self.models[k].decoder.step(decoder_input.clone(), decoder_states[k])

                # run decoding on the model
                decoder_output = self.models[k].step(decoder_input.clone(), decoder_states[k])

                # extract the required tensors from the output (a dictionary)
                outs[k] = decoder_output['log_prob']
                attns[k] = decoder_output['coverage']

            # for ensembling models
            out = self._combine_outputs(outs)
            attn = self._combine_attention(attns)

            # for lm fusion
            if self.opt.lm:
                lm_decoder_output = self.lm_model.step(decoder_input.clone(), lm_decoder_states)

                # fusion: here the LM log-probabilities replace the combined model
                # output entirely (a weighted interpolation such as
                # out = out + 0.3 * lm_out is left commented out)
                lm_out = lm_decoder_output['log_prob']
                # out = out + 0.3 * lm_out

                out = lm_out
            word_lk = out.view(beam_size, remaining_sents, -1) \
                .transpose(0, 1).contiguous()
            attn = attn.view(beam_size, remaining_sents, -1) \
                .transpose(0, 1).contiguous()

            active = []

            for b in range(batch_size):
                if beam[b].done:
                    continue

                idx = batch_idx[b]
                if not beam[b].advance(word_lk.data[idx], attn.data[idx]):
                    active += [b]

                for j in range(self.n_models):
                    decoder_states[j].update_beam(beam, b, remaining_sents, idx)

                if self.opt.lm:
                    lm_decoder_states.update_beam(beam, b, remaining_sents, idx)

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            active_idx = self.tt.LongTensor([batch_idx[k] for k in active])
            batch_idx = {beam: idx for idx, beam in enumerate(active)}

            for j in range(self.n_models):
                decoder_states[j].prune_complete_beam(active_idx, remaining_sents)

            if self.opt.lm:
                lm_decoder_states.prune_complete_beam(active_idx, remaining_sents)

            remaining_sents = len(active)

        #  (4) package everything up
        all_hyp, all_scores, all_attn = [], [], []
        n_best = self.opt.n_best
        all_lengths = []

        for b in range(batch_size):
            scores, ks = beam[b].sortBest()

            all_scores += [scores[:n_best]]
            hyps, attn, length = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
            all_hyp += [hyps]
            all_lengths += [length]
            # if(src_data.data.dim() == 3):
            if self.opt.encoder_type == 'audio':
                valid_attn = decoder_states[0].original_src.narrow(2, 0, 1).squeeze(2)[:, b].ne(onmt.Constants.PAD) \
                    .nonzero().squeeze(1)
            else:
                valid_attn = decoder_states[0].original_src[:, b].ne(onmt.Constants.PAD) \
                    .nonzero().squeeze(1)
            attn = [a.index_select(1, valid_attn) for a in attn]
            all_attn += [attn]

            if self.beam_accum:
                self.beam_accum["beam_parent_ids"].append(
                    [t.tolist()
                     for t in beam[b].prevKs])
                self.beam_accum["scores"].append([
                                                     ["%4f" % s for s in t.tolist()]
                                                     for t in beam[b].all_scores][1:])
                self.beam_accum["predicted_ids"].append(
                    [[self.tgt_dict.getLabel(id)
                      for id in t.tolist()]
                     for t in beam[b].nextYs][1:])

        torch.set_grad_enabled(True)

        return all_hyp, all_scores, all_attn, all_lengths, gold_scores, gold_words, allgold_scores
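
A minimal sketch of what the ensemble and fusion helpers used above could look like, assuming (as the 'log_prob' key suggests) that each model already returns log-probabilities; combine_log_probs and shallow_fusion are hypothetical names, not part of the original code:

import math
import torch

def combine_log_probs(outs):
    # Hypothetical ensemble combiner in the spirit of self._combine_outputs:
    # average the per-model distributions in probability space via log-sum-exp.
    stacked = torch.stack([outs[k] for k in sorted(outs)], dim=0)  # n_models x (beam*batch) x vocab
    return torch.logsumexp(stacked, dim=0) - math.log(stacked.size(0))

def shallow_fusion(out, lm_out, lm_weight=0.3):
    # Shallow fusion as hinted at by the commented-out line `out = out + 0.3 * lm_out`.
    return out + lm_weight * lm_out
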
Example No. 2
    def translateBatch(self, srcBatch, tgtBatch):

        batchSize = srcBatch[0].size(1)
        beamSize = self.opt.beam_size
        knntime = 0.0
        #  (1) run the encoder on the src
        encStates, context = self.model.encoder(srcBatch)
        srcBatch = srcBatch[0]  # drop the lengths needed for encoder

        rnnSize = context.size(2)
        encStates = (self.model._fix_enc_hidden(encStates[0]),
                     self.model._fix_enc_hidden(encStates[1]))

        #  This mask is applied to the attention model inside the decoder
        #  so that the attention ignores source padding
        padMask = srcBatch.data.eq(onmt.Constants.PAD).t()

        def applyContextMask(m):
            if isinstance(m, onmt.modules.GlobalAttention):
                m.applyMask(padMask)

        #  (3) run the decoder to generate sentences, using beam search

        # Expand tensors for each beam.
        # decStates = encStates
        if self.opt.use_lm:
            lm_hidden = self.LangModel.initialize_hidden(1, batchSize)

        context = Variable(context.data.repeat(1, beamSize, 1))
        decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)),
                     Variable(encStates[1].data.repeat(1, beamSize, 1)))

        beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]

        decOut = self.model.make_init_decoder_output(context)

        padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(
            beamSize, 1, 1)

        batchIdx = list(range(batchSize))
        remainingSents = batchSize
        for i in range(self.opt.max_sent_length):

            self.model.decoder.apply(applyContextMask)

            # Prepare decoder input.
            input_ = torch.stack([
                b.getCurrentState() for b in beam if not b.done()
            ]).t().contiguous().view(1, -1)
            input_var = Variable(input_, volatile=True)
            decOut, decStates, attn = self.model.decoder(
                input_var, decStates, context, decOut)

            # decOut: 1 x (beam*batch) x numWords
            decOut = decOut.squeeze(0)
            out = self.model.generator.forward(decOut)

            if self.opt.use_lm:
                lm_output, lm_hidden, _, _ = self.LangModel(
                    input_var, lm_hidden)
                lm_output = torch.log(lm_output + 1e-12)

            beg = time.time()
            scores = self._get_scores(out, self.target_embeddings)
            diff = time.time() - beg
            knntime += diff

            if self.opt.use_lm:
                scores += 0.2 * lm_output.squeeze(0)

            wordLk = scores.view(beamSize, remainingSents,
                                 -1).transpose(0, 1).contiguous()
            attn = attn.view(beamSize, remainingSents,
                             -1).transpose(0, 1).contiguous()

            active = []
            for b in range(batchSize):
                if beam[b].done():
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                    active += [b]

                for decState in decStates:  # iterate over h, c
                    # layers x beam*sent x dim
                    sentStates = decState.view(-1, beamSize, remainingSents,
                                               decState.size(2))[:, :, idx]
                    sentStates.data.copy_(
                        sentStates.data.index_select(
                            1, beam[b].getCurrentOrigin()))

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {beam: idx for idx, beam in enumerate(active)}

            def updateActive(t):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, rnnSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return Variable(view.index_select(1, activeIdx) \
                                    .view(*newSize), volatile=True)

            decStates = (updateActive(decStates[0]),
                         updateActive(decStates[1]))
            decOut = updateActive(decOut)
            context = updateActive(context)
            padMask = padMask.index_select(1, activeIdx)

            remainingSents = len(active)

        #  (4) package everything up

        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            valid_attn = srcBatch.data[:, b].ne(
                onmt.Constants.PAD).nonzero().squeeze(1)

            hyps, attn = zip(
                *[beam[b].getHyp(times, k) for (times, k) in ks[:n_best]])
            attn = [a.index_select(1, valid_attn) for a in attn]
            allHyp += [hyps]
            allAttn += [attn]

        return allHyp, allScores, allAttn, knntime
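
The per-sentence reordering of decStates after each beam advance follows the same pattern in most of these examples; a standalone sketch of that step, with shapes assumed as in the comments above (layers x beam*sent x dim), is:

import torch

def reorder_beam_state(dec_state, origins, beam_size, remaining_sents, sent_idx):
    # origins is a LongTensor of length beam_size holding each beam's parent index,
    # as returned by Beam.getCurrentOrigin(). The view exposes the beam and sentence
    # dimensions; copying the parent rows back in place rewires the chosen sentence.
    sent_state = dec_state.view(-1, beam_size, remaining_sents,
                                dec_state.size(2))[:, :, sent_idx]
    sent_state.copy_(sent_state.index_select(1, origins))
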
Example No. 3
    def translate_batch(self, batch, length_batch=None):

        torch.set_grad_enabled(False)
        # Batch size is in different location depending on data.

        beam_size = self.opt.beam_size
        batch_size = batch.size

        gold_scores = batch.get('source').data.new(batch_size).float().zero_()
        gold_words = 0
        allgold_scores = []

        prefix = None
        if batch.has_target:
            # Use the first model to decode
            model_ = self.models[0]
            gold_words, gold_scores, allgold_scores = model_.decode(batch)

            # use the target without its final EOS token as the forced decoding prefix
            prefix = batch.tensors['target_output'][:-1]
            print('PREFIX',
                  self.build_target_tokens(batch.tensors['target_output']))

        #  (3) Start decoding

        # time x batch * beam

        # initialize the beam
        beam = [
            onmt.Beam(beam_size,
                      self.opt.cuda,
                      prefix=prefix,
                      prefix_score=allgold_scores) for k in range(batch_size)
        ]

        batch_idx = list(range(batch_size))
        remaining_sents = batch_size

        decoder_states = dict()

        for i in range(self.n_models):
            decoder_states[i] = self.models[i].create_decoder_state(
                batch, beam_size, length_batch)

            if batch.has_target:
                prefix_states = []
                for state in beam[i].get_all_states():
                    prefix_states.append(
                        torch.stack([state]).t().contiguous().view(1, -1))

                for p in prefix_states:
                    decoder_output = self.models[i].step(
                        p.clone(), decoder_states[i])
                    # print('prefix', p)
        # the prefixes can now be cleared from the beam
        beam = [onmt.Beam(beam_size, self.opt.cuda) for k in range(batch_size)]

        if self.opt.lm:
            lm_decoder_states = self.lm_model.create_decoder_state(
                batch, beam_size)

        max_len = self.opt.max_sent_length
        if batch.has_target:
            max_len -= len(prefix)
            # print(max_len, len(prefix), len(prefix_states))

        for current_depth in range(max_len):  # EOS here?
            # Prepare decoder input.
            # print(current_depth, max_len)
            # input size: 1 x ( batch * beam )
            input = torch.stack([
                b.getCurrentState() for b in beam if not b.done
            ]).t().contiguous().view(1, -1)

            decoder_input = input

            # require batch first for everything
            outs = dict()
            attns = dict()

            for k in range(self.n_models):
                # decoder_hidden, coverage = self.models[k].decoder.step(decoder_input.clone(), decoder_states[k])

                # run decoding on the model
                if not (current_depth == 0 and batch.has_target):
                    # print('decoding ', self.tgt_dict.convertToLabels(decoder_input.data[0], 10))
                    decoder_output = self.models[k].step(
                        decoder_input.clone(), decoder_states[k],
                        current_depth +
                        (len(prefix) if prefix is not None else 0))
                    # print('new input', decoder_input)
                # else:
                #     print('skipped last of prefix')

                # extract the required tensors from the output (a dictionary)
                outs[k] = decoder_output['log_prob']
                # print('outs when decoding ', outs[k])
                attns[k] = decoder_output['coverage']

            # for ensembling models
            out = self._combine_outputs(outs)
            attn = self._combine_attention(attns)

            # for lm fusion
            if self.opt.lm:
                lm_decoder_output = self.lm_model.step(decoder_input.clone(),
                                                       lm_decoder_states)

                # fusion: here the LM log-probabilities replace the combined model
                # output entirely (a weighted interpolation such as
                # out = out + 0.3 * lm_out is left commented out)
                lm_out = lm_decoder_output['log_prob']
                # out = out + 0.3 * lm_out

                out = lm_out

            word_lk = out.view(beam_size, remaining_sents, -1) \
                .transpose(0, 1).contiguous()
            attn = attn.view(beam_size, remaining_sents, -1) \
                .transpose(0, 1).contiguous()

            active = []

            for seq_idx in range(batch_size):
                if beam[seq_idx].done:
                    continue

                idx = batch_idx[seq_idx]

                # Added two conditions for constrained decoding
                if self.force_target_length and length_batch and length_batch[
                        seq_idx] == current_depth:  # TODO: offset by prefix len
                    # finish hyp b since it has desired length
                    beam[seq_idx].advanceEOS(word_lk.data[idx], attn.data[idx])
                elif self.force_target_length and length_batch:
                    # ignore EOS since we are not at the end
                    word_lk[idx].select(1,
                                        onmt.Constants.EOS).zero_().add_(-1000)
                    if not beam[seq_idx].advance(word_lk.data[idx],
                                                 attn.data[idx]):
                        active += [seq_idx]
                elif not beam[seq_idx].advance(word_lk.data[idx],
                                               attn.data[idx],
                                               start_from_prefix=current_depth
                                               == 0):
                    active += [seq_idx]

                for j in range(self.n_models):
                    decoder_states[j].update_beam(beam, seq_idx,
                                                  remaining_sents, idx)

                if self.opt.lm:
                    lm_decoder_states.update_beam(beam, seq_idx,
                                                  remaining_sents, idx)

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            active_idx = self.tt.LongTensor([batch_idx[k] for k in active])
            batch_idx = {beam: idx for idx, beam in enumerate(active)}

            for j in range(self.n_models):
                decoder_states[j].prune_complete_beam(active_idx,
                                                      remaining_sents)

            if self.opt.lm:
                lm_decoder_states.prune_complete_beam(active_idx,
                                                      remaining_sents)

            remaining_sents = len(active)

            # if commit_depth == 0:
            #     for seq_idx in range(batch_size):
            #         beam[seq_idx].commit(buffer=decoding_buffer_depths)
            #
            # elif commit_depth > 0:
            #     raise NotImplementedError

        #  (4) package everything up
        all_hyp, all_scores, all_attn, all_lk = [], [], [], []
        n_best = self.opt.n_best
        all_lengths = []

        for seq_idx in range(batch_size):
            scores, ks = beam[seq_idx].sortBest()

            all_scores += [scores[:n_best]]
            hyps, attn, length = zip(*[
                beam[seq_idx].getHyp(k, return_att=False) for k in ks[:n_best]
            ])
            # append given prefix to beginning of output
            if prefix is not None:
                prefix_ = [p_[seq_idx] for p_ in prefix.tolist()]
                hyps = [prefix_ + hyp for hyp in hyps]
            all_hyp += [hyps]
            all_lengths += [length]
            # if(src_data.data.dim() == 3):
            if self.opt.encoder_type == 'audio':
                valid_attn = decoder_states[0].original_src.narrow(2, 0, 1).squeeze(2)[:, seq_idx].ne(onmt.Constants.PAD) \
                    .nonzero().squeeze(1)
            else:
                valid_attn = decoder_states[0].original_src[:, seq_idx].ne(onmt.Constants.PAD) \
                    .nonzero().squeeze(1)
            # attn = [a.index_select(1, valid_attn) for a in attn]
            # all_attn += [attn]

            if self.beam_accum:
                self.beam_accum["beam_parent_ids"].append(
                    [t.tolist() for t in beam[seq_idx].prevKs])
                self.beam_accum["scores"].append(
                    [["%4f" % s for s in t.tolist()]
                     for t in beam[seq_idx].all_scores][1:])
                self.beam_accum["predicted_ids"].append(
                    [[self.tgt_dict.getLabel(id) for id in t.tolist()]
                     for t in beam[seq_idx].nextYs][1:])

            all_scores_ = [beam[seq_idx].allScores[-1]]  # take last
            my_indices = range(beam[seq_idx].size)
            for j in range(len(beam[seq_idx].prevKs) - 1, -1, -1):
                my_indices = beam[seq_idx].prevKs[j][my_indices]
                all_scores_.append(beam[seq_idx].allScores[j][my_indices])
                # print(all_scores_[-1])
            all_lk.append(all_scores_[::-1])

        torch.set_grad_enabled(True)

        return all_hyp, all_scores, all_attn, all_lengths, gold_scores, gold_words, allgold_scores, all_lk
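
The all_lk block above traces scores backwards through the beam's parent pointers; the same backtracking is what getHyp performs to recover a hypothesis. A minimal sketch, assuming the usual OpenNMT beam layout in which nextYs holds one more entry than prevKs:

def backtrack_hypothesis(prev_ks, next_ys, k):
    # Walk parent pointers from the last step to the first, collecting the token
    # chosen at each step for the k-th entry of the final beam.
    hyp = []
    for step in range(len(prev_ks) - 1, -1, -1):
        hyp.append(next_ys[step + 1][k])
        k = prev_ks[step][k]
    return hyp[::-1]
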
Example No. 4
    def translateBatch(self, srcBatch, tgtBatch):
        # Batch size is in different location depending on data.

        beamSize = self.opt.beam_size

        #  (1) run the encoder on the src
        encStates, context, emb = self.model.encoder(srcBatch)

        # Drop the lengths needed for encoder.
        srcBatch = srcBatch[0]
        batchSize = self._getBatchSize(srcBatch)

        rnnSize = context.size(2)
        decoder = self.model.decoder
        attentionLayer = decoder.attn if hasattr(decoder, 'attn') else None

        if isinstance(self.model.encoder, Encoder):
            if isinstance(encStates, tuple):
                encStates = tuple(self.model.brnn_merge_concat(encStates[i])
                                  for i in range(len(encStates)))
            else:
                encStates = self.model.brnn_merge_concat(encStates)
                if encStates.size(0) < decoder.layers:
                    encStates = encStates.repeat(decoder.layers, 1, 1)
        else:
            encStates = Variable(encStates.data.new(*encStates.size()).zero_(), requires_grad=False)
        #    encStates = encStates.unsqueeze(0).repeat(decoder.layers, 1, 1)


        useMasking = not isinstance(decoder, SGUDecoder)  # self._type.endswith("text")

        #  This mask is applied to the attention model inside the decoder
        #  so that the attention ignores source padding
        padMask = None
        if useMasking:
            padMask = srcBatch.data.eq(onmt.Constants.PAD).t()

        def mask(padMask):
            if useMasking:
                attentionLayer.applyMask(padMask)

        #  (2) if a target is specified, compute the 'goldScore'
        #  (i.e. log likelihood) of the target under the model
        goldScores = context.data.new(batchSize).zero_()

        if tgtBatch is not None:

            decStates = encStates

            mask(padMask)
            initOutput = self.model.make_init_decoder_output(context)

            decOut, decStates, attn = self.model.decoder(
                tgtBatch[:-1], decStates, context, initOutput)
            for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
                gen_t = self.model.generator.forward(dec_t)
                tgt_t = tgt_t.unsqueeze(1)
                scores = gen_t.data.gather(1, tgt_t)
                scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
                goldScores += scores

        #  (3) run the decoder to generate sentences, using beam search

        # Expand tensors for each beam.
        context = Variable(context.data.repeat(1, beamSize, 1))
        if isinstance(emb, PackedSequence):
            emb = Variable(unpack(emb)[0].data.repeat(1, beamSize, 1))
        else:
            emb = Variable(emb.data.repeat(1, beamSize, 1))

        if isinstance(encStates, tuple):
            decStates = tuple(Variable(encStates[i].data.repeat(1, beamSize, 1))
                         for i in range(len(encStates)))
        else:
            decStates = Variable(encStates.data.repeat(1, beamSize, 1))

        beam = [onmt.Beam(beamSize, self.opt.cuda) for _ in range(batchSize)]

        decOut = self.model.make_init_decoder_output(context)

        if useMasking:
            padMask = srcBatch.data.eq(
                onmt.Constants.PAD).t() \
                                   .unsqueeze(0) \
                                   .repeat(beamSize, 1, 1)

        batchIdx = list(range(batchSize))
        remainingSents = batchSize

        activs = []
        for i in range(self.opt.max_sent_length):
            mask(padMask)
            # Prepare decoder input.
            input = torch.stack([b.getCurrentState() for b in beam
                                 if not b.done]).t().contiguous().view(1, -1)

            #if self.model.decoder.log:
            #    decOut, decStates, attn, activ = self.model.decoder(
            #        Variable(input, volatile=True), decStates, context, decOut, emb)
            #    activs.append(activ)
            #else:
            decOut, decStates, attn = self.model.decoder(
                Variable(input, volatile=True), decStates, context, decOut)

            # decOut: 1 x (beam*batch) x numWords
            decOut = decOut.squeeze(0)
            out = self.model.generator.forward(decOut)

            # batch x beam x numWords
            wordLk = out.view(beamSize, remainingSents, -1) \
                        .transpose(0, 1).contiguous()
            attn = attn.view(beamSize, remainingSents, -1) \
                       .transpose(0, 1).contiguous()

            active = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                    active += [b]
                #print(decStates)
                if not isinstance(decStates, tuple):
                    decStates = tuple(decStates.unsqueeze(0))
                #print(decStates)
                for decState in decStates:  # iterate over h, c
                    # layers x beam*sent x dim
                    sentStates = decState.view(-1, beamSize,
                                               remainingSents,
                                               decState.size(2))[:, :, idx]
                    sentStates.data.copy_(
                        sentStates.data.index_select(
                            1, beam[b].getCurrentOrigin()))

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {beam: idx for idx, beam in enumerate(active)}

            def updateActive(t, lastSize=rnnSize):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, lastSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return Variable(view.index_select(1, activeIdx)
                                .view(*newSize), volatile=True)

            decStates = tuple(updateActive(decStates[i])
                         for i in range(len(decStates)))

            if len(decStates) == 1:
                # The GRU needs only one matrix as hidden state
                decStates = decStates[0]

            decOut = updateActive(decOut)
            context = updateActive(context)
            emb = updateActive(emb, emb.size(2))

            if useMasking:
                padMask = padMask.index_select(1, activeIdx)

            remainingSents = len(active)

        #  (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best

        if activs:
            new_activs = torch.zeros((2, activs[0].size(1), len(activs)))
            for i, activ in enumerate(activs):
                new_activs[:, :activ.size(1), i] = activ.data
            activs = new_activs
            sys.stderr.write("r=\n")
            for i in range(activs.size(1)):
                for j in range(activs.size(2)):
                    sys.stderr.write(str(activs[0][i][j]) + " ")
                sys.stderr.write("\n")
            sys.stderr.write("z=\n")
            for i in range(activs.size(1)):
                for j in range(activs.size(2)):
                    sys.stderr.write(str(activs[1][i][j]) + " ")
                sys.stderr.write("\n")

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
            allHyp += [hyps]
            if useMasking:
                valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \
                                                .nonzero().squeeze(1)
                attn = [a.index_select(1, valid_attn) for a in attn]
            allAttn += [attn]

            if self.beam_accum:
                self.beam_accum["beam_parent_ids"].append(
                    [t.tolist()
                     for t in beam[b].prevKs])
                self.beam_accum["scores"].append([
                    ["%4f" % s for s in t.tolist()]
                    for t in beam[b].allScores][1:])
                self.beam_accum["predicted_ids"].append(
                    [[self.tgt_dict.getLabel(id)
                      for id in t.tolist()]
                     for t in beam[b].nextYs][1:])

        return allHyp, allScores, allAttn, goldScores
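
The goldScores loop above scores the reference translation token by token under the model; the same computation over a full time-major tensor can be written compactly. A sketch, assuming log_probs already holds log-softmax outputs:

import torch

def gold_score(log_probs, gold_tokens, pad_idx):
    # log_probs:   time x batch x vocab (log-softmax outputs of the generator)
    # gold_tokens: time x batch (reference tokens, i.e. tgt shifted by one)
    scores = log_probs.gather(2, gold_tokens.unsqueeze(2)).squeeze(2)  # time x batch
    scores = scores.masked_fill(gold_tokens.eq(pad_idx), 0.0)          # ignore padding
    return scores.sum(0)                                               # per-sentence log-likelihood
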
Example No. 5
    def translateBatch(self, srcBatch):
        # Batch size is in different location depending on data.
        beamSize = self.beam_size

        #  (1) run the encoders on the src

        states, context = self.model.encoder(srcBatch)

        # reshape the states
        encStates = (self.model._fix_enc_hidden(states[0]),
                     self.model._fix_enc_hidden(states[1]))

        # Drop the lengths needed for encoder.
        srcBatch = srcBatch[0]
        batchSize = self._getBatchSize(srcBatch)

        rnnSize = context.size(2)

        #~ decoder = self.model.decoder
        #~ attentionLayer = decoder.attn.current()
        useMasking = (batchSize > 1)

        #  This mask is applied to the attention model inside the decoder
        #  so that the attention ignores source padding
        padMask = None
        if useMasking:
            padMask = srcBatch.data.eq(onmt.Constants.PAD).t()

        def mask(padMask):
            if useMasking:
                #~ attentionLayer.applyMask(padMask)
                self.model.decoder.attn.current().applyMask(padMask)

        #  (2) run the decoder to generate sentences, using beam search

        # Expand tensors for each beam.

        context = Variable(context.data.repeat(1, beamSize, 1))

        decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)),
                     Variable(encStates[1].data.repeat(1, beamSize, 1)))

        # Initialize the beams
        # Each beam is an object containing the translation status for each sentence in the batch
        beam = [onmt.Beam(beamSize, self.cuda) for k in range(batchSize)]

        # Here we prepare the decoder output (zeroes)
        # For input feeding
        decOuts = self.model.make_init_decoder_output(context)

        if useMasking:
            padMask = srcBatch.data.eq(
                onmt.Constants.PAD).t() \
                                   .unsqueeze(0) \
                                   .repeat(beamSize, 1, 1)

        batchIdx = list(range(batchSize))
        remainingSents = batchSize

        #~ if self.model.copy_pointer:
        src = Variable(srcBatch.data.repeat(1,
                                            beamSize))  # time x batch * beam

        for i in range(self.max_sent_length):
            mask(padMask)
            # Prepare decoder input.
            input = torch.stack([
                b.getCurrentState() for b in beam if not b.done
            ]).t().contiguous().view(1, -1)

            # compute new decoder output (distribution)
            decOuts, decStates, attn = self.model.decoder(
                Variable(input, volatile=True), decStates, context, decOuts)

            # decOut: 1 x (beam*batch) x numWords
            decOuts = decOuts.squeeze(0)
            attn_ = attn
            attn = attn.squeeze(0)

            if self.model.copy_pointer:
                out = self.model.generator.forward(decOuts, attn_, src)
            else:
                out = self.model.generator.forward(decOuts)

            # batch x beam x numWords
            wordLk = out.view(beamSize, remainingSents, -1) \
                        .transpose(0, 1).contiguous()
            attn = attn.view(beamSize, remainingSents, -1) \
                       .transpose(0, 1).contiguous()

            active = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                    active += [b]

                for decState in decStates:  # iterate over h, c
                    # layers x beam*sent x dim
                    sentStates = decState.view(-1, beamSize, remainingSents,
                                               decState.size(2))[:, :, idx]
                    sentStates.data.copy_(
                        sentStates.data.index_select(
                            1, beam[b].getCurrentOrigin()))

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {beam: idx for idx, beam in enumerate(active)}

            def updateActive(t, size):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, size)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return Variable(view.index_select(1, activeIdx).view(*newSize),
                                volatile=True)

            decStates = (updateActive(decStates[0], rnnSize),
                         updateActive(decStates[1], rnnSize))
            decOuts = updateActive(decOuts, rnnSize)
            context = updateActive(context, rnnSize)

            # src size: time x batch * beam
            src_data = src.data.view(-1, remainingSents)
            newSize = list(src.size())
            newSize[-1] = newSize[-1] * len(activeIdx) // remainingSents
            src = Variable(src_data.index_select(1, activeIdx).view(*newSize),
                           volatile=True)
            #~ srcBatch = Variable(srcBatch.data.repeat(1, beamSize))

            if useMasking:
                padMask = padMask.index_select(1, activeIdx)

            remainingSents = len(active)

        #  (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        n_best = self.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
            allHyp += [hyps]
            if useMasking:
                valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \
                                                .nonzero().squeeze(1)
                attn = [a.index_select(1, valid_attn) for a in attn]
            allAttn += [attn]

        if useMasking:
            self.model.decoder.attn.current().applyMask(None)

        return allHyp, allScores, allAttn
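
updateActive appears in several of these examples with slightly different signatures; its job is to drop finished sentences from every beam-expanded tensor. A standalone sketch of the same logic, assuming the time x (beam*batch) x dim layout used above:

import torch

def update_active(t, active_idx, remaining_sents, feat_size):
    # Expose the sentence dimension, keep only the still-active sentences,
    # then restore the original layout with the beam*batch dimension shrunk.
    view = t.view(-1, remaining_sents, feat_size)
    new_size = list(t.size())
    new_size[-2] = new_size[-2] * active_idx.size(0) // remaining_sents
    return view.index_select(1, active_idx).view(*new_size)
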
Example No. 6
    def translateBatch(self, batch):
        beamSize = self.opt.beam_size
        batchSize = batch.batchSize

        #  (1) run the encoder on the src
        encStates, context = self.model.encoder(batch.src)
        encStates = self.model.init_decoder_state(context, encStates)

        decoder = self.model.decoder
        attentionLayer = decoder.attn
        useMasking = (self._type == "text")

        #  This mask is applied to the attention model inside the decoder
        #  so that the attention ignores source padding
        padMask = None
        if useMasking:
            padMask = batch.words().data.eq(onmt.Constants.PAD).t()

        def mask(padMask):
            if useMasking:
                attentionLayer.applyMask(padMask)

        #  (2) if a target is specified, compute the 'goldScore'
        #  (i.e. log likelihood) of the target under the model
        goldScores = context.data.new(batchSize).zero_()
        if batch.tgt is not None:
            decStates = encStates
            mask(padMask)
            decOut, decStates, attn = decoder(batch.tgt[:-1], context,
                                              decStates)
            for dec_t, tgt_t in zip(decOut, batch.tgt[1:].data):
                gen_t = self.model.generator.forward(dec_t)
                tgt_t = tgt_t.unsqueeze(1)
                scores = gen_t.data.gather(1, tgt_t)
                scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
                goldScores += scores

        #  (3) run the decoder to generate sentences, using beam search
        # Each hypothesis in the beam uses the same context
        # and initial decoder state
        context = Variable(context.data.repeat(1, beamSize, 1))
        batch_src = Variable(batch.src.data.repeat(1, beamSize, 1))
        decStates = encStates
        decStates.repeatBeam_(beamSize)
        beam = [onmt.Beam(beamSize, self.opt.cuda) for _ in range(batchSize)]
        if useMasking:
            padMask = batch.src.data[:, :, 0].eq(
                onmt.Constants.PAD).t() \
                                   .unsqueeze(0) \
                                   .repeat(beamSize, 1, 1)

        #  (3b) The main loop
        for i in range(self.opt.max_sent_length):
            # (a) Run RNN decoder forward one step.
            mask(padMask)
            input = torch.stack([b.getCurrentState() for b in beam])\
                         .t().contiguous().view(1, -1)
            input = Variable(input, volatile=True)
            decOut, decStates, attn = self.model.decoder(
                input, batch_src, context, decStates)
            decOut = decOut.squeeze(0)
            # decOut: (beam*batch) x numWords
            attn["std"] = attn["std"].view(beamSize, batchSize, -1) \
                                     .transpose(0, 1).contiguous()

            # (b) Compute a vector of batch*beam word scores.
            if not self.copy_attn:
                out = self.model.generator.forward(decOut)
            else:
                # Copy Attention Case
                words = batch.words().t()
                words = torch.stack([words[i] for i, b in enumerate(beam)])\
                             .contiguous()
                attn_copy = attn["copy"].view(beamSize, batchSize, -1) \
                                        .transpose(0, 1).contiguous()

                out, c_attn_t \
                    = self.model.generator.forward(
                        decOut, attn_copy.view(-1, batch_src.size(0)))

                for b in range(out.size(0)):
                    for c in range(c_attn_t.size(1)):
                        v = self.align[words[0, c].data[0]]
                        if v != onmt.Constants.PAD:
                            out[b, v] += c_attn_t[b, c]
                out = out.log()

            word_scores = out.view(beamSize, batchSize, -1) \
                .transpose(0, 1).contiguous()
            # batch x beam x numWords

            # (c) Advance each beam.
            active = []
            for b in range(batchSize):
                is_done = beam[b].advance(word_scores.data[b],
                                          attn["std"].data[b])
                if not is_done:
                    active += [b]
                decStates.beamUpdate_(b, beam[b].getCurrentOrigin(), beamSize)
            if not active:
                break

        #  (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            hyps, attn = [], []
            for k in ks[:n_best]:
                hyp, att = beam[b].getHyp(k)
                hyps.append(hyp)
                attn.append(att)
            allHyp += [hyps]
            if useMasking:
                valid_attn = batch.src.data[:, b, 0].ne(onmt.Constants.PAD) \
                                                .nonzero().squeeze(1)
                attn = [a.index_select(1, valid_attn) for a in attn]
            allAttn += [attn]

            # For debugging visualization.
            if self.beam_accum:
                self.beam_accum["beam_parent_ids"].append(
                    [t.tolist() for t in beam[b].prevKs])
                self.beam_accum["scores"].append(
                    [["%4f" % s for s in t.tolist()]
                     for t in beam[b].allScores][1:])
                self.beam_accum["predicted_ids"].append(
                    [[self.tgt_dict.getLabel(id) for id in t.tolist()]
                     for t in beam[b].nextYs][1:])

        return allHyp, allScores, allAttn, goldScores
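
The attention pad mask built before the main loop is essentially the same in every example: mark PAD positions in the source, make the mask batch-first, and replicate it once per beam. A minimal sketch, assuming a time-major src tensor of token ids:

import torch

def make_pad_mask(src, pad_idx, beam_size):
    # src: src_len x batch; returns a beam x batch x src_len boolean mask,
    # True wherever the source token is padding.
    pad_mask = src.eq(pad_idx).t()                        # batch x src_len
    return pad_mask.unsqueeze(0).repeat(beam_size, 1, 1)
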
Example No. 7
def translate_batch_external(batch, beamSize, model, cuda, rb_init_token,
                             rb_init_tgt, max_sent_length, n_best):
    srcBatch, tgtBatch, src_rb, tgt_rb = batch
    batchSize = srcBatch.size(0)

    #  (1) run the encoder on the src

    # padding is dealt with by variable-length cudnn.RNN
    encStates, context = model.encoder(srcBatch, src_rb)
    # # have to execute the encoder manually to deal with padding
    # encStates = None
    # context = []
    # for srcBatch_t in srcBatch.chunk(srcBatch.size(1), dim=1):
    #     encStates, context_t = self.model.encoder(srcBatch_t, hidden=encStates)
    #     batchPadIdx = srcBatch_t.data.squeeze(1).eq(onmt.Constants.PAD).nonzero()
    #     if batchPadIdx.nelement() > 0:
    #         batchPadIdx = batchPadIdx.squeeze(1)
    #         encStates[0].data.index_fill_(1, batchPadIdx, 0)
    #         encStates[1].data.index_fill_(1, batchPadIdx, 0)
    #     context += [context_t]

    # context = torch.cat(context)

    rnnSize = context.size(2)

    encStates = (_fix_enc_hidden(encStates[0], model.encoder.num_directions),
                 _fix_enc_hidden(encStates[1], model.encoder.num_directions))

    #  This mask is applied to the attention model inside the decoder
    #  so that the attention ignores source padding
    padMask = srcBatch.data.eq(onmt.Constants.PAD)
    rb_token_mask = torch.zeros(padMask.size(0), 1).byte()
    if cuda:
        rb_token_mask = rb_token_mask.cuda()
    if rb_init_token:
        padMask = torch.cat([rb_token_mask, padMask], 1)

    def applyContextMask(m):
        if isinstance(m, onmt.modules.GlobalAttention):
            m.applyMask(padMask)

    #  (2) if a target is specified, compute the 'goldScore'
    #  (i.e. log likelihood) of the target under the model
    goldScores = context.data.new(batchSize).zero_()
    re_padMask = 1 - padMask
    re_padMask = re_padMask.float()
    context_t = context.transpose(0, 1).data
    masked_context = context_t * re_padMask.unsqueeze(2).expand(
        re_padMask.size(0), re_padMask.size(1), context_t.size(2))
    sent_len = torch.sum(re_padMask, 1).squeeze(1)
    representation = torch.div(
        torch.sum(masked_context, 1).squeeze(1),
        sent_len.unsqueeze(1).expand(sent_len.size(0), context.size(2)))
    if tgtBatch is not None:
        if rb_init_tgt:
            new_tgt_batch = tgtBatch[:, 1:]
            tgt_rb_token = tgt_rb.unsqueeze(1) + model.decoder.dict_size
            tgtBatch = torch.cat([tgt_rb_token, new_tgt_batch], 1)
        decStates = encStates
        decOut = model.make_init_decoder_output(context)
        model.decoder.apply(applyContextMask)
        initOutput = model.make_init_decoder_output(context)
        decOut, decStates, attn = model.decoder(tgtBatch[:, :-1], tgt_rb,
                                                decStates, context, initOutput)
        for dec_t, tgt_t in zip(decOut.transpose(0, 1),
                                tgtBatch.transpose(0, 1)[1:].data):
            gen_t = model.generator.forward(dec_t)
            tgt_t = tgt_t.unsqueeze(1)
            scores = gen_t.data.gather(1, tgt_t)
            scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
            goldScores += scores

    #  (3) run the decoder to generate sentences, using beam search

    # Expand tensors for each beam.
    context = Variable(context.data.repeat(1, beamSize, 1))
    decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)),
                 Variable(encStates[1].data.repeat(1, beamSize, 1)))
    if rb_init_tgt:
        beam = [
            onmt.Beam(beamSize, cuda, tgt_rb[k].data[0])
            for k in range(batchSize)
        ]
    else:
        beam = [onmt.Beam(beamSize, cuda) for k in range(batchSize)]

    decOut = model.make_init_decoder_output(context)

    padMask = srcBatch.data.eq(onmt.Constants.PAD).unsqueeze(0).repeat(
        beamSize, 1, 1)
    rb_token_mask = torch.zeros(padMask.size(0), padMask.size(1), 1).byte()
    if cuda:
        rb_token_mask = rb_token_mask.cuda()
    if rb_init_token:
        padMask = torch.cat([rb_token_mask, padMask], 2)
    batchIdx = list(range(batchSize))
    remainingSents = batchSize
    for i in range(max_sent_length):
        model.decoder.apply(applyContextMask)

        # Prepare decoder input.
        input = torch.stack([b.getCurrentState() for b in beam
                             if not b.done]).t().contiguous().view(1, -1)
        new_tgt_rb = torch.stack([
            tgt_rb[i].expand(beamSize) for i, b in enumerate(beam)
            if not b.done
        ]).contiguous().view(-1)
        '''some_done = False
        data = []
        for i, b in enumerate(beam):
            if b.done:
                some_done = True
            else:
                data.append(tgt_rb.data[i])
        if some_done:
            print data
            print new_tgt_rb'''
        decOut, decStates, attn = model.decoder(
            Variable(input).transpose(0, 1), new_tgt_rb, decStates, context,
            decOut)
        # decOut: 1 x (beam*batch) x numWords
        decOut = decOut.transpose(0, 1).squeeze(0)
        out = model.generator.forward(decOut)

        # batch x beam x numWords
        wordLk = out.view(beamSize, remainingSents,
                          -1).transpose(0, 1).contiguous()
        attn = attn.view(beamSize, remainingSents,
                         -1).transpose(0, 1).contiguous()

        active = []
        for b in range(batchSize):
            if beam[b].done:
                continue

            idx = batchIdx[b]
            if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                active += [b]

            for decState in decStates:  # iterate over h, c
                # layers x beam*sent x dim
                sentStates = decState.view(-1, beamSize, remainingSents,
                                           decState.size(2))[:, :, idx]
                sentStates.data.copy_(
                    sentStates.data.index_select(1,
                                                 beam[b].getCurrentOrigin()))

        if not active:
            break

        # in this section, the sentences that are still active are
        # compacted so that the decoder is not run on completed sentences
        tt = torch.cuda if cuda else torch
        activeIdx = tt.LongTensor([batchIdx[k] for k in active])
        batchIdx = {beam: idx for idx, beam in enumerate(active)}

        def updateActive(t):
            # select only the remaining active sentences
            view = t.data.view(-1, remainingSents, rnnSize)
            newSize = list(t.size())
            newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
            return Variable(view.index_select(1, activeIdx) \
                            .view(*newSize))

        decStates = (updateActive(decStates[0]), updateActive(decStates[1]))
        decOut = updateActive(decOut)
        context = updateActive(context)
        padMask = padMask.index_select(1, activeIdx)

        remainingSents = len(active)
    #  (4) package everything up

    allHyp, allScores, allAttn = [], [], []

    for b in range(batchSize):
        scores, ks = beam[b].sortBest()

        allScores += [scores[:n_best]]
        valid_attn = srcBatch.transpose(0, 1).data[:, b].ne(
            onmt.Constants.PAD).nonzero().squeeze(1)
        hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
        attn = [a.index_select(1, valid_attn) for a in attn]
        allHyp += [hyps]
        allAttn += [attn]
    padMask = None
    model.decoder.apply(applyContextMask)
    return allHyp, allScores, allAttn, goldScores, representation
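
The representation returned by this function is a padding-aware mean of the encoder states; a compact sketch of the same pooling with a boolean mask and a batch-first context is:

import torch

def masked_mean(context, pad_mask):
    # context:  batch x src_len x dim (encoder states)
    # pad_mask: batch x src_len, True at PAD positions
    keep = (~pad_mask).float().unsqueeze(2)       # batch x src_len x 1
    summed = (context * keep).sum(1)              # batch x dim
    lengths = keep.sum(1).clamp(min=1.0)          # batch x 1, avoids division by zero
    return summed / lengths
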
Example No. 8
    def translateBatch(self, batch):
        beamSize = 15
        batchSize = batch.batchSize

        #  (1) run the encoder on the src
        encStates, context, fertility_vals = self.encoder(batch.src)
        encStates = self.init_decoder_state(context, encStates)

        def mask(padMask):
            self.decoder.attn.applyMask(padMask)

        #  (2) if a target is specified, compute the 'goldScore' (i.e. log likelihood) of the target under the model
        goldScores = context.data.new(batchSize).zero_()

        #  (3) run the decoder to generate sentences, using beam search
        # Each hypothesis in the beam uses the same context and initial decoder state
        context = Variable(context.data.repeat(1, beamSize, 1))
        batch_src = Variable(batch.src.data.repeat(1, beamSize, 1))
        decStates = encStates
        decStates.repeatBeam_(beamSize)
        beam = [onmt.Beam(beamSize, True) for _ in range(batchSize)]
        padMask = batch.src.data[:, :, 0].eq(
            onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1)

        #  (3b) The main loop

        upper_bounds = None
        for i in range(100):
            # (a) Run RNN decoder forward one step.
            mask(padMask)
            input = torch.stack([b.getCurrentState()
                                 for b in beam]).t().contiguous().view(1, -1)
            input = Variable(input, volatile=True)
            decOut, decStates, attn, upper_bounds = self.decoder(
                input,
                batch_src,
                context,
                decStates,
                upper_bounds=decStates.attn_upper_bounds,
                test=True)

            #import pdb; pdb.set_trace()
            decOut = decOut.squeeze(0)
            # decOut: (beam*batch) x numWords
            attn["std"] = attn["std"].view(beamSize, batchSize,
                                           -1).transpose(0, 1).contiguous()

            # (b) Compute a vector of batch*beam word scores.
            out = self.generator.forward(decOut)
            word_scores = out.view(beamSize, batchSize,
                                   -1).transpose(0, 1).contiguous()
            # batch x beam x numWords

            # (c) Advance each beam.
            active = []
            for b in range(batchSize):
                is_done = beam[b].advance(word_scores.data[b],
                                          attn["std"].data[b])
                if not is_done:
                    active += [b]
                decStates.beamUpdate_(b, beam[b].getCurrentOrigin(), beamSize)
            if not active:
                break

        #  (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        self.n_best = 1

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:self.n_best]]
            hyps, attn = [], []
            for k in ks[:self.n_best]:
                hyp, att = beam[b].getHyp(k)
                hyps.append(hyp)
                attn.append(att)
            allHyp += [hyps]
            valid_attn = batch.src.data[:, b, 0].ne(
                onmt.Constants.PAD).nonzero().squeeze(1)
            attn = [a.index_select(1, valid_attn) for a in attn]
            allAttn += [attn]

            self.decoder.attn.applyMaskNone()
        #print allAttn[0][0].sum(0)
        return allHyp, allScores, allAttn, goldScores
Example No. 9
    def translateBatch(self, batch):
        beamSize = self.opt.beam_size
        batchSize = batch.batchSize

        #  (1) run the encoder on the src
        encStates, context = self.model.encoder(batch.src)

        rnnSize = context.size(2)
        encStates = self.model.setup_decoder(encStates)

        decoder = self.model.decoder
        attentionLayer = decoder.attn
        useMasking = (self._type == "text")

        #  This mask is applied to the attention model inside the decoder
        #  so that the attention ignores source padding
        padMask = None
        if useMasking:
            padMask = batch.words().data.eq(onmt.Constants.PAD).t()

        def mask(padMask):
            if useMasking:
                attentionLayer.applyMask(padMask)

        #  (2) if a target is specified, compute the 'goldScore'
        #  (i.e. log likelihood) of the target under the model
        goldScores = context.data.new(batchSize).zero_()
        if batch.tgt is not None:
            decStates = encStates
            decOut = self.model.make_init_decoder_output(context)
            mask(padMask)
            initOutput = self.model.make_init_decoder_output(context)
            decOut, decStates, attn = self.model.decoder(
                batch.tgt[:-1], decStates, context, initOutput)
            for dec_t, tgt_t in zip(decOut, batch.tgt[1:].data):
                gen_t = self.model.generator.forward(dec_t)
                tgt_t = tgt_t.unsqueeze(1)
                scores = gen_t.data.gather(1, tgt_t)
                scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
                goldScores += scores

        #  (3) run the decoder to generate sentences, using beam search

        # Each hypothesis in the beam uses the same context
        # and initial decoder state
        context = Variable(context.data.repeat(1, beamSize, 1))
        decStates = tuple([Variable(e.data.repeat(1, beamSize, 1))
                           for e in encStates]) \
            if encStates else None

        beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]
        decOut = self.model.make_init_decoder_output(context)
        if useMasking:
            padMask = batch.src.data[:, :, 0].eq(
                onmt.Constants.PAD).t() \
                                   .unsqueeze(0) \
                                   .repeat(beamSize, 1, 1)

        batchIdx = list(range(batchSize))
        remainingSents = batchSize
        for i in range(self.opt.max_sent_length):
            mask(padMask)
            # Prepare decoder input.
            input = torch.stack([
                b.getCurrentState() for b in beam if not b.done
            ]).t().contiguous().view(1, -1)
            decOut, decStates, attn = self.model.decoder(
                Variable(input, volatile=True), batch.src, decStates, context,
                decOut)

            # decOut: 1 x (beam*batch) x numWords
            decOut = decOut.squeeze(0)

            attn["std"] = attn["std"].view(beamSize, remainingSents, -1) \
                                     .transpose(0, 1).contiguous()
            if not self.copy_attn or self.copy_attn == "std":
                out = self.model.generator.forward(decOut)
            else:
                words = batch.words().t()
                words = torch.stack([
                    words[i] for i, b in enumerate(beam) if not b.done
                ]).contiguous()

                attn_copy = attn["copy"].view(beamSize, remainingSents, -1) \
                                        .transpose(0, 1).contiguous()

                out, c_attn_t \
                    = self.model.generator.forward(
                        decOut, words,
                        attn_copy.view(-1, batch.src.size(0)))

                for b in range(out.size(0)):
                    for c in range(c_attn_t.size(1)):
                        v = self.align[words[0, c].data[0]]
                        if v != onmt.Constants.PAD:
                            out[b, v] += c_attn_t[b, c]
                out = out.log()

            # batch x beam x numWords
            wordLk = out.view(beamSize, remainingSents, -1) \
                        .transpose(0, 1).contiguous()
            active = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx],
                                       attn["std"].data[idx]):
                    active += [b]

                for decState in decStates:  # iterate over h, c
                    # layers x beam*sent x dim
                    sentStates = decState.view(-1, beamSize, remainingSents,
                                               decState.size(2))[:, :, idx]
                    sentStates.data.copy_(
                        sentStates.data.index_select(
                            1, beam[b].getCurrentOrigin()))

            if not active:
                break

            # In this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {beam: idx for idx, beam in enumerate(active)}
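            # maps each still-active original batch position to its new index
            # in the compacted (active-only) tensors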

            def updateActive(t, size=rnnSize, batchPos=-2):
                # Select only the remaining active sentences
                view = t.data.view(-1, remainingSents, t.size(-1))
                newSize = list(t.size())

                newSize[batchPos] = newSize[batchPos] * \
                    len(activeIdx) // remainingSents
                return Variable(view.index_select(1, activeIdx).view(*newSize),
                                volatile=True)

            decStates = tuple([updateActive(d) for d in decStates])
            decOut = updateActive(decOut)
            context = updateActive(context)
            if useMasking:
                padMask = padMask.index_select(1, activeIdx)

            remainingSents = len(active)

        #  (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
            allHyp += [hyps]
            if useMasking:
                valid_attn = batch.src.data[:, b, 0].ne(onmt.Constants.PAD) \
                                                .nonzero().squeeze(1)
                attn = [a.index_select(1, valid_attn) for a in attn]
            allAttn += [attn]

            if self.beam_accum:
                self.beam_accum["beam_parent_ids"].append(
                    [t.tolist() for t in beam[b].prevKs])
                self.beam_accum["scores"].append(
                    [["%4f" % s for s in t.tolist()]
                     for t in beam[b].allScores][1:])
                self.beam_accum["predicted_ids"].append(
                    [[self.tgt_dict.getLabel(id) for id in t.tolist()]
                     for t in beam[b].nextYs][1:])

        return allHyp, allScores, allAttn, goldScores
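
A minimal stand-alone sketch of the compaction step used above (the updateActive helper): once some sentences have produced a final hypothesis, every decoder tensor laid out as (layers, beam * n_sents, dim) is reshaped so the finished sentences can be dropped with a single index_select. The shapes below are toy values, not taken from the example.

import torch

layers, beam_size, n_sents, dim = 2, 3, 4, 5
state = torch.randn(layers, beam_size * n_sents, dim)

# pretend only sentences 0 and 2 are still unfinished
active_idx = torch.tensor([0, 2])

view = state.view(-1, n_sents, state.size(-1))               # (layers * beam, n_sents, dim)
compact = view.index_select(1, active_idx)                   # keep only the active sentences
new_state = compact.view(layers, beam_size * len(active_idx), dim)

print(state.shape, new_state.shape)                          # (2, 12, 5) -> (2, 6, 5)
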
Example #10
    def translateBatch(self, batch, data):
        beamSize = self.opt.beam_size
        batchSize = batch.batch_size
        _, src_lengths = batch.src
        src = make_features(batch, self.fields)

        #  (1) run the encoder on the src
        encStates, context = self.model.encoder(src, lengths=src_lengths)
        encStates = self.model.init_decoder_state(context, encStates)

        useMasking = (self._type == "text")
        #  This mask is applied to the attention model inside the decoder
        #  so that the attention ignores source padding
        padMask = None
        tgt_pad = self.fields["tgt"].vocab.stoi[onmt.IO.PAD_WORD]
        if useMasking:
            pad = self.fields["src"].vocab.stoi[onmt.IO.PAD_WORD]
            padMask = src[:, :, 0].data.eq(pad).t()

        def mask(padMask):
            if useMasking:
                self.model.decoder.attn.applyMask(padMask)

        #  (2) if a target is specified, compute the 'goldScore'
        #  (i.e. log likelihood) of the target under the model
        goldScores = context.data.new(batchSize).zero_()
        if "tgt" in batch.__dict__:
            decStates = encStates
            mask(padMask.unsqueeze(0))
            decOut, decStates, attn = self.model.decoder(
                batch.tgt[:-1], batch.src, context, decStates)
            for dec_t, tgt_t in zip(decOut, batch.tgt[1:].data):
                gen_t = self.model.generator.forward(dec_t)
                tgt_t = tgt_t.unsqueeze(1)
                scores = gen_t.data.gather(1, tgt_t)
                scores.masked_fill_(tgt_t.eq(tgt_pad), 0)
                goldScores += scores

        #  (3) run the decoder to generate sentences, using beam search
        # Each hypothesis in the beam uses the same context
        # and initial decoder state
        context = Variable(context.data.repeat(1, beamSize, 1), volatile=True)
        batch_src = Variable(src.data.repeat(1, beamSize, 1), volatile=True)
        batch_src_map = Variable(batch.src_map.data.repeat(1, beamSize, 1),
                                 volatile=True)
        decStates = encStates
        decStates.repeatBeam_(beamSize)
        beam = [
            onmt.Beam(beamSize,
                      cuda=self.opt.cuda,
                      vocab=self.fields["tgt"].vocab)
            for __ in range(batchSize)
        ]
        if useMasking:
            padMask = src.data[:, :, 0].eq(pad).t() \
                                               .unsqueeze(0) \
                                               .repeat(beamSize, 1, 1)

        #  (3b) The main loop
        for i in range(self.opt.max_sent_length):
            # (a) Run RNN decoder forward one step.
            mask(padMask)
            input = torch.stack([b.getCurrentState() for b in beam])\
                         .t().contiguous().view(1, -1)
            input.masked_fill_(input.gt(len(self.fields["tgt"].vocab) - 1), 0)
            input = Variable(input, volatile=True)
            decOut, decStates, attn = self.model.decoder(
                input, batch_src, context, decStates)
            # print(decStates.all[0][:, 0, 0])
            decOut = decOut.squeeze(0)
            # decOut: (beam*batch) x numWords
            attn["std"] = attn["std"].view(beamSize, batchSize, -1) \
                                     .transpose(0, 1).contiguous()

            # (b) Compute a vector of batch*beam word scores.
            if not self.copy_attn:
                out = self.model.generator.forward(decOut).data
            else:
                # print(attn["copy"].size())
                attn_copy = attn["copy"].view(beamSize, batchSize, -1) \
                                        .transpose(0, 1).contiguous()
                out = self.model.generator.forward(
                    decOut, attn_copy.view(-1, batch_src.size(0)),
                    batch_src_map)
                out = data.collapseCopyScores(
                    out.data.view(batchSize, beamSize, -1).transpose(0, 1),
                    batch, self.fields["tgt"].vocab)
                out = out.log().transpose(0, 1).contiguous()\
                                               .view(beamSize * batchSize, -1)

            word_scores = out.view(beamSize, batchSize, -1) \
                .transpose(0, 1).contiguous()
            # batch x beam x numWords

            # (c) Advance each beam.
            active = []
            for b in range(batchSize):
                is_done = beam[b].advance(word_scores[b], attn["std"].data[b])
                if not is_done:
                    active += [b]
                decStates.beamUpdate_(b, beam[b].getCurrentOrigin(), beamSize)
            if not active:
                break

        #  (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            hyps, attn = [], []
            for k in ks[:n_best]:
                hyp, att = beam[b].getHyp(k)
                hyps.append(hyp)
                attn.append(att)
            allHyp += [hyps]
            if useMasking:
                valid_attn = src.data[:, b, 0].ne(pad) \
                                              .nonzero().squeeze(1)
                attn = [a.index_select(1, valid_attn) for a in attn]
            allAttn += [attn]

            # For debugging visualization.
            if self.beam_accum:
                self.beam_accum["beam_parent_ids"].append(
                    [t.tolist() for t in beam[b].prevKs])
                self.beam_accum["scores"].append(
                    [["%4f" % s for s in t.tolist()]
                     for t in beam[b].allScores][1:])
                self.beam_accum["predicted_ids"].append(
                    [[self.tgt_dict.getLabel(id) for id in t.tolist()]
                     for t in beam[b].nextYs][1:])

        return allHyp, allScores, allAttn, goldScores
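
For reference, a self-contained toy version of the gold-score loop these examples share: gather the log-probability assigned to each reference token and zero out padding before accumulating per sentence. The tensors here are invented, not the model's.

import torch

PAD = 0
log_probs = torch.log_softmax(torch.randn(3, 2, 6), dim=-1)   # (tgt_len, batch, vocab)
tgt = torch.tensor([[4, 1], [2, PAD], [5, PAD]])              # (tgt_len, batch)

gold_scores = torch.zeros(2)
for step_lp, step_tgt in zip(log_probs, tgt):
    scores = step_lp.gather(1, step_tgt.unsqueeze(1)).squeeze(1)
    scores.masked_fill_(step_tgt.eq(PAD), 0)                  # padded positions add nothing
    gold_scores += scores

print(gold_scores)                                            # one log-likelihood per sentence
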
Example #11
    def beam_conf_once(self, srcBatch, tgtBatch, confidence_method,
                       conf_n_best):
        beamSize = self.opt.beam_size
        confidence_method_split = confidence_method.split(':')
        #  (1) run the encoder on the src
        encStates, context, rnnSize = self.conf_encode(srcBatch)

        #  (3) run the decoder to generate sentences, using beam search
        # Expand tensors for each beam.
        batchSize = self._getBatchSize(srcBatch[0])
        context = Variable(context.data.repeat(1, beamSize, 1))

        decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)),
                     Variable(encStates[1].data.repeat(1, beamSize, 1)))

        beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]

        decOut = self.model.make_init_decoder_output(context)

        padMask = srcBatch[0].data.eq(
            onmt.Constants.PAD).t() \
                               .unsqueeze(0) \
                               .repeat(beamSize, 1, 1)

        batchIdx = list(range(batchSize))
        remainingSents = batchSize
        for i in range(self.opt.max_sent_length):
            self.model.decoder.attn.applyMask(padMask)
            # Prepare decoder input.
            input = torch.stack([
                b.getCurrentState() for b in beam if not b.done
            ]).t().contiguous().view(1, -1)
            decOut, decStates, attn = self.model.decoder(
                Variable(input, volatile=True), decStates, context, decOut)
            # decOut: 1 x (beam*batch) x numWords
            decOut = decOut.squeeze(0)
            out = self.model.generator.forward(decOut)

            # batch x beam x numWords
            wordLk = out.view(beamSize, remainingSents, -1) \
                        .transpose(0, 1).contiguous()
            attn = attn.view(beamSize, remainingSents, -1) \
                       .transpose(0, 1).contiguous()

            active = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                    active += [b]

                for decState in decStates:  # iterate over h, c
                    # layers x beam*sent x dim
                    sentStates = decState.view(-1, beamSize, remainingSents,
                                               decState.size(2))[:, :, idx]
                    sentStates.data.copy_(
                        sentStates.data.index_select(
                            1, beam[b].getCurrentOrigin()))

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {beam: idx for idx, beam in enumerate(active)}

            def updateActive(t):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, rnnSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return Variable(view.index_select(1, activeIdx).view(*newSize),
                                volatile=True)

            decStates = (updateActive(decStates[0]),
                         updateActive(decStates[1]))
            decOut = updateActive(decOut)
            context = updateActive(context)
            padMask = padMask.index_select(1, activeIdx)

            remainingSents = len(active)

        #  (4) package everything up
        allScores = []
        for b in range(batchSize):
            scores, ks = beam[b].sortBest()
            allScores += [scores[:conf_n_best]]

        # -> (batch_size, conf_n_best)
        p_infer = torch.stack(allScores)

        return p_infer
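
The padMask construction that recurs in these snippets is easy to check in isolation: a (src_len, batch) source is turned into a (batch, src_len) boolean mask and tiled once per beam. A small sketch with made-up token ids:

import torch

PAD, beam_size = 0, 3
src = torch.tensor([[5, 7],
                    [6, PAD],
                    [PAD, PAD]])                       # (src_len, batch)

pad_mask = src.eq(PAD).t()                             # (batch, src_len)
beam_mask = pad_mask.unsqueeze(0).repeat(beam_size, 1, 1)
print(beam_mask.shape)                                 # torch.Size([3, 2, 3])
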
Example #12
    def translateBatch(self, srcBatch, tgtBatch):

        batchSize = srcBatch.size(1)
        beamSize = self.opt.beam_size

        decoder = self.model.decoder
        attentionLayer = decoder.attn
        useMasking = self.opt.mem == 'lstm_lstm' and self.model.decoder.use_attn

        def lstm_encoder(src):
            emb_in = self.model.word_lut(src)

            init_h = self.model.make_init_hidden(
                emb_in[0], src.size(1), self.model.decoder.hidden_size, 2)
            hidden = (torch.stack(init_h[0]), torch.stack(init_h[1]))

            context, hidden = self.model.encoder(emb_in, hidden)

            return context, hidden

        def lstm_decoder(tgt, hidden, context, decOut):

            if useMasking:
                padMask = srcBatch.data.eq(onmt.Constants.PAD).t()
                attentionLayer.applyMask(padMask)

            out, dec_hidden, _attn = self.model.decoder(
                tgt, hidden, context, decOut)

            return out, dec_hidden, _attn

        def dnc_encoder(src):
            batch_size = src.size(1)
            hidden = self.model.encoder.make_init_hidden(
                src[0], *self.model.encoder.rnn_sz)

            M = self.model.encoder.make_init_M(batch_size)

            emb_in = self.model.word_lut(src)
            return self.model.encoder(emb_in, hidden, M)

        #  (1) run the encoder on the src
        if self.opt.mem == 'lstm_lstm':
            context, encStates = lstm_encoder(srcBatch)
        elif self.opt.mem == 'dnc_dnc':
            context, encStates, M = dnc_encoder(srcBatch)

        rnnSize = encStates[0][0].size(1)

        #  (2) if a target is specified, compute the 'goldScore'
        #  (i.e. log likelihood) of the target under the model
        goldScores = encStates[0][0].data.new(batchSize).zero_()
        if tgtBatch is not None:
            decStates = encStates
            if self.opt.mem == 'lstm_lstm':
                init_output = self.model.make_init_decoder_output(context[0])
                decOut, decStates, attn = lstm_decoder(
                    tgtBatch[:-1], decStates, context, init_output)
            elif self.opt.mem == 'dnc_dnc':
                # M only exists when the DNC encoder was used, so bind decM here
                decM = M
                emb_out = self.model.word_lut(tgtBatch[:-1])
                decOut, decStates, decM = self.model.decoder(
                    emb_out, decStates, decM)

            for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
                gen_t = self.model.generator.forward(dec_t)
                tgt_t = tgt_t.unsqueeze(1)
                scores = gen_t.data.gather(1, tgt_t)
                scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
                goldScores += scores

        print(' == got gold ==')
        #  (3) run the decoder to generate sentences, using beam search

        # Expand tensors for each beam.
        if self.opt.mem == 'lstm_lstm':
            context = Variable(context.data.repeat(1, beamSize, 1))
        elif self.opt.mem == 'dnc_dnc':
            decM = {}
            for k in M.keys():
                print(k)
                dims = M[k].dim()
                if dims == 3:
                    decM[k] = Variable(M[k].data.repeat(beamSize, 1, 1))
                elif dims == 2:
                    decM[k] = Variable(M[k].data.repeat(beamSize, 1))
            print(' -- M:')
            [print(k, M[k].size()) for k in M.keys()]
            print(' -- decM:')
            [print(k, decM[k].size()) for k in decM.keys()]

        decStates = ((Variable(encStates[0][0].data.repeat(beamSize, 1)),
                      Variable(encStates[0][1].data.repeat(beamSize, 1))),
                     (Variable(encStates[1][0].data.repeat(beamSize, 1)),
                      Variable(encStates[1][1].data.repeat(beamSize, 1))))

        beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]

        decOut = self.model.make_init_decoder_output(
            decStates[0][0])  # .squeeze(0))

        if useMasking:
            padMask = srcBatch.data.eq(
                onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1)

        batchIdx = list(range(batchSize))
        remainingSents = batchSize
        for i in range(self.opt.max_sent_length):
            if useMasking:
                attentionLayer.applyMask(padMask)
                # Prepare decoder input.
            input = torch.stack([b.getCurrentState() for b in beam
                                 if not b.done]).t().contiguous().view(1, -1)
            if self.opt.mem == 'lstm_lstm':
                decOut, decStates, attn = self.model.decoder(
                    Variable(input, volatile=True), decStates, context, decOut)
            elif self.opt.mem == 'dnc_dnc':
                inp = self.model.word_lut(Variable(input, volatile=True))
                decOut, decStates, decM = self.model.decoder(
                    inp, decStates, decM)
            # decOut: 1 x (beam*batch) x numWords
            decOut = decOut.squeeze(0)
            out = self.model.generator.forward(decOut)

            # batch x beam x numWords
            wordLk = out.view(beamSize, remainingSents, -1) \
                .transpose(0, 1).contiguous()
            if self.opt.mem == 'lstm_lstm':
                attn = attn.view(beamSize, remainingSents, -1) \
                           .transpose(0, 1).contiguous()
            else:
                attn = None

            active = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                    active += [b]

                for decState in decStates:  # iterate over h, c
                    # layers x beam*sent x dim
                    sentStates = decState.view(-1, beamSize,
                                               remainingSents,
                                               decState.size(2))[:, :, idx]
                    sentStates.data.copy_(
                        sentStates.data.index_select(
                            1, beam[b].getCurrentOrigin()))

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {beam: idx for idx, beam in enumerate(active)}

            def updateActive(t):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, rnnSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return Variable(view.index_select(1, activeIdx)
                                .view(*newSize), volatile=True)

            decStates = (updateActive(decStates[0]),
                         updateActive(decStates[1]))
            decOut = updateActive(decOut)
            context = updateActive(context)
            if useMasking:
                padMask = padMask.index_select(1, activeIdx)

            remainingSents = len(active)

        #  (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
            allHyp += [hyps]
            if useMasking:
                valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \
                                                .nonzero().squeeze(1)
                attn = [a.index_select(1, valid_attn) for a in attn]
            allAttn += [attn]

            if self.beam_accum:
                self.beam_accum["beam_parent_ids"].append(
                    [t.tolist()
                     for t in beam[b].prevKs])
                self.beam_accum["scores"].append([
                    ["%4f" % s for s in t.tolist()]
                    for t in beam[b].allScores][1:])
                self.beam_accum["predicted_ids"].append(
                    [[self.tgt_dict.getLabel(id)
                      for id in t.tolist()]
                     for t in beam[b].nextYs][1:])

        return allHyp, allScores, allAttn, goldScores
    def beam_decode(self, encStates):
        batchSize = encStates.size(0)
        beamSize = self.opt.beam_size
        rnnSize = self.model.decoder.hidden_size
        beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]
        decStates = self.model.latent_to_decoder(encStates)
        if self.model.prelu:
            decStates = self.model.prelu_dec(decStates)
        decStates = decStates.view(self.model.layers, decStates.size(0), -1)
        decStates = torch.split(decStates, decStates.size(-1) // 2, 2)

        decStates = (Variable(decStates[0].data.repeat(1, beamSize, 1)),
                     Variable(decStates[1].data.repeat(1, beamSize, 1)))
        context = Variable(encStates.data.repeat(beamSize, 1))
        decOut = self.model.make_init_decoder_output(context)

        batchIdx = list(range(batchSize))
        remainingSents = batchSize
        for i in range(self.opt.max_sent_length):
            # Prepare decoder input.
            input = torch.stack([
                b.getCurrentState() for b in beam if not b.done
            ]).t().contiguous().view(1, -1)

            decOut, decStates = self.model.decoder(
                Variable(input, volatile=True), decStates, decOut)
            # decOut: 1 x (beam*batch) x numWords
            decOut = decOut.squeeze(0)
            out = decOut

            # batch x beam x numWords
            wordLk = out.view(beamSize, remainingSents,
                              -1).transpose(0, 1).contiguous()

            active = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx]):
                    active += [b]

                for decState in decStates:  # iterate over h, c
                    # layers x beam*sent x dim
                    sentStates = decState.view(-1, beamSize, remainingSents,
                                               decState.size(2))[:, :, idx]
                    sentStates.data.copy_(
                        sentStates.data.index_select(
                            1, beam[b].getCurrentOrigin()))

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {beam: idx for idx, beam in enumerate(active)}

            def updateActive(t):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, rnnSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return Variable(view.index_select(1, activeIdx) \
                                    .view(*newSize), volatile=True)

            decStates = (updateActive(decStates[0]),
                         updateActive(decStates[1]))
            #decOut = updateActive(decOut)

            remainingSents = len(active)

        #  (4) package everything up

        allHyp, allScores = [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            hyps = [beam[b].getHyp(k) for k in ks[:n_best]]
            allHyp += [hyps]
        return allHyp, allScores
    def translateBatch(self, srcBatch, tgtBatch):
        batchSize = srcBatch[0].size(1)
        beamSize = self.opt.beam_size

        #  (1) run the encoder on the src
        encStates, _ = self.model.encode(srcBatch)
        srcBatch = srcBatch[0]  # drop the lengths needed for encoder

        rnnSize = self.model.decoder.hidden_size

        #  This mask is applied to the attention model inside the decoder
        #  so that the attention ignores source padding

        #  (2) if a target is specified, compute the 'goldScore'
        #  (i.e. log likelihood) of the target under the model
        goldScores = encStates.data.new(batchSize).zero_()
        if tgtBatch is not None:
            decStates = encStates

            decOut = self.model.decode(decStates, tgtBatch[:-1])
            for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
                gen_t = dec_t
                tgt_t = tgt_t.unsqueeze(1)
                scores = gen_t.data.gather(1, tgt_t)
                scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
                goldScores += scores

        #  (3) run the decoder to generate sentences, using beam search

        # Expand tensors for each beam.

        beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]
        decStates = self.model.latent_to_decoder(encStates)
        if self.model.prelu:
            decStates = self.model.prelu_dec(decStates)
        decStates = decStates.view(self.model.layers, decStates.size(0), -1)
        decStates = torch.chunk(decStates, 2, 2)
        decStates = (Variable(decStates[0].data.repeat(1, beamSize, 1)),
                     Variable(decStates[1].data.repeat(1, beamSize, 1)))
        context = Variable(encStates.data.repeat(beamSize, 1))
        decOut = self.model.make_init_decoder_output(context)

        padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(
            beamSize, 1, 1)

        batchIdx = list(range(batchSize))
        remainingSents = batchSize
        for i in range(self.opt.max_sent_length):

            # Prepare decoder input.
            input = torch.stack([
                b.getCurrentState() for b in beam if not b.done
            ]).t().contiguous().view(1, -1)

            decOut, decStates = self.model.decoder(
                Variable(input, volatile=True), decStates, decOut)
            # decOut: 1 x (beam*batch) x numWords
            decOut = decOut.squeeze(0)
            out = decOut

            # batch x beam x numWords
            wordLk = out.view(beamSize, remainingSents,
                              -1).transpose(0, 1).contiguous()

            active = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx]):
                    active += [b]

                for decState in decStates:  # iterate over h, c
                    # layers x beam*sent x dim
                    sentStates = decState.view(-1, beamSize, remainingSents,
                                               decState.size(2))[:, :, idx]
                    sentStates.data.copy_(
                        sentStates.data.index_select(
                            1, beam[b].getCurrentOrigin()))

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {beam: idx for idx, beam in enumerate(active)}

            def updateActive(t):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, rnnSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return Variable(view.index_select(1, activeIdx) \
                                    .view(*newSize), volatile=True)

            decStates = (updateActive(decStates[0]),
                         updateActive(decStates[1]))
            #decOut = updateActive(decOut)
            padMask = padMask.index_select(1, activeIdx)

            remainingSents = len(active)

        #  (4) package everything up

        allHyp, allScores = [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            hyps = [beam[b].getHyp(k) for k in ks[:n_best]]
            allHyp += [hyps]

        return allHyp, allScores, goldScores
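
The latent-to-decoder reshaping above can be tried on its own. The layout below, one flat vector per sentence split into an (h, c) pair with view and chunk, is an assumption for illustration and not necessarily the model's exact memory layout.

import torch

layers, batch, hidden = 2, 4, 6
flat = torch.randn(batch, layers * 2 * hidden)         # one flat state vector per sentence

states = flat.view(layers, batch, -1)                  # (layers, batch, 2 * hidden), assumed layout
h, c = torch.chunk(states, 2, dim=2)                   # two (layers, batch, hidden) tensors
print(h.shape, c.shape)
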
Example #15
    def translateBatch(self, srcBatch, tgtBatch):

        torch.set_grad_enabled(False)
        # Batch size is in different location depending on data.

        beamSize = self.opt.beam_size
        batchSize = self._getBatchSize(srcBatch)

        vocab_size = self.tgt_dict.size()
        allHyp, allScores, allAttn, allLengths = [], [], [], []

        # srcBatch should have size len x batch
        # tgtBatch should have size len x batch

        contexts = dict()

        src = srcBatch.transpose(0, 1)

        #  (1) run the encoders on the src
        for i in range(self.n_models):
            contexts[i], src_mask = self.models[i].encoder(src)

        goldScores = contexts[0].data.new(batchSize).zero_()
        goldWords = 0

        if tgtBatch is not None:
            # Use the first model to decode
            model_ = self.models[0]

            tgtBatchInput = tgtBatch[:-1]
            tgtBatchOutput = tgtBatch[1:]
            tgtBatchInput = tgtBatchInput.transpose(0, 1)

            output, coverage = model_.decoder(tgtBatchInput, contexts[0], src)
            # output should have size time x batch x dim

            #  (2) if a target is specified, compute the 'goldScore'
            #  (i.e. log likelihood) of the target under the model
            for dec_t, tgt_t in zip(output, tgtBatchOutput.data):
                gen_t = model_.generator(dec_t)
                tgt_t = tgt_t.unsqueeze(1)
                scores = gen_t.data.gather(1, tgt_t)
                scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
                goldScores += scores.squeeze(1).type_as(goldScores)
                goldWords += tgt_t.ne(onmt.Constants.PAD).sum().item()

        #  (3) Start decoding

        # time x batch * beam
        src = srcBatch  # this is time first again (before transposing)

        # initialize the beam
        beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]

        batchIdx = list(range(batchSize))
        remainingSents = batchSize

        decoder_states = dict()

        decoder_hiddens = dict()

        for i in range(self.n_models):
            decoder_states[i] = self.models[i].create_decoder_state(
                src, contexts[i], src_mask, beamSize, type='old')

        for i in range(self.opt.max_sent_length):
            # Prepare decoder input.

            # input size: 1 x ( batch * beam )
            input = torch.stack([
                b.getCurrentState() for b in beam if not b.done
            ]).t().contiguous().view(1, -1)
            """  
                Inefficient decoding implementation
                We re-compute all states for every time step
                A better buffering algorithm will be implemented
            """

            decoder_input = input

            # require batch first for everything
            outs = dict()
            attns = dict()

            for i in range(self.n_models):
                decoder_hidden, coverage = self.models[i].decoder.step(
                    decoder_input.clone(), decoder_states[i])

                # take the last decoder state
                decoder_hidden = decoder_hidden.squeeze(1)
                attns[i] = coverage[:, -1, :].squeeze(1)  # batch * beam x src_len

                # batch * beam x vocab_size
                outs[i] = self.models[i].generator(decoder_hidden)

            out = self._combineOutputs(outs)
            attn = self._combineAttention(attns)

            wordLk = out.view(beamSize, remainingSents, -1) \
                        .transpose(0, 1).contiguous()
            attn = attn.view(beamSize, remainingSents, -1) \
                       .transpose(0, 1).contiguous()

            active = []

            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]

                if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                    active += [b]

                for i in range(self.n_models):
                    decoder_states[i]._update_beam(beam, b, remainingSents,
                                                   idx)

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {beam: idx for idx, beam in enumerate(active)}

            for i in range(self.n_models):
                decoder_states[i]._prune_complete_beam(activeIdx,
                                                       remainingSents)

            remainingSents = len(active)

        #  (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best
        allLengths = []

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            hyps, attn, length = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
            allHyp += [hyps]
            allLengths += [length]
            valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \
                                            .nonzero().squeeze(1)
            attn = [a.index_select(1, valid_attn) for a in attn]
            allAttn += [attn]

            if self.beam_accum:
                self.beam_accum["beam_parent_ids"].append(
                    [t.tolist() for t in beam[b].prevKs])
                self.beam_accum["scores"].append(
                    [["%4f" % s for s in t.tolist()]
                     for t in beam[b].allScores][1:])
                self.beam_accum["predicted_ids"].append(
                    [[self.tgt_dict.getLabel(id) for id in t.tolist()]
                     for t in beam[b].nextYs][1:])

        torch.set_grad_enabled(True)

        return allHyp, allScores, allAttn, allLengths, goldScores, goldWords
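
The _combineOutputs and _combineAttention helpers are not shown in this excerpt. One common ensembling choice, used here purely as an illustrative assumption, is to average the models' log-probabilities:

import torch

def combine_log_probs(outs):
    """Average an ensemble's log-probability tensors, each of shape (n, vocab)."""
    return torch.stack(list(outs.values())).mean(dim=0)

outs = {0: torch.log_softmax(torch.randn(5, 10), dim=-1),
        1: torch.log_softmax(torch.randn(5, 10), dim=-1)}
print(combine_log_probs(outs).shape)                   # torch.Size([5, 10])
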
Example #16
    def translateBatch(self, batch):
        beamSize = self.opt.beam_size
        batchSize = batch.batchSize

        #  (1) run the encoder on the src
        useMasking = (self._type == "text")
        encStatesL = []
        decStatesL = []
        contextL = []
        src_lengths = batch.lengths.data.view(-1).tolist()
        globalScorer = onmt.GNMTGlobalScorer(self.opt.alpha, self.opt.beta)
        beam = [onmt.Beam(beamSize, self.opt.cuda, globalScorer,
                          alpha=self.opt.alpha, beta=self.opt.beta,
                          tgtDict=self.tgt_dict)
                for i in range(batchSize)]
        for model in self.models:
            encStates, context = model.encoder(batch.src, lengths=batch.lengths)
            encStates = model.init_decoder_state(context, encStates)

            decoder = model.decoder
            attentionLayer = decoder.attn

            #  This mask is applied to the attention model inside the decoder
            #  so that the attention ignores source padding
            padMask = batch.words().data.eq(onmt.Constants.PAD).t()
            attentionLayer.applyMask(padMask)
            #  (2) if a target is specified, compute the 'goldScore'
            #  (i.e. log likelihood) of the target under the model
            
            ## for sanity check
            #if batch.tgt is not None:
            #    decStates = encStates
            #    mask(padMask.unsqueeze(0))
            #    decOut, decStates, attn = self.model.decoder(batch.tgt[:-1],
            #                                                 batch.src,
            #                                                 context,
            #                                                 decStates)
            #    for dec_t, tgt_t in zip(decOut, batch.tgt[1:].data):
            #        gen_t = self.model.generator.forward(dec_t)
            #        tgt_t = tgt_t.unsqueeze(1)
            #        scores = gen_t.data.gather(1, tgt_t)
            #        scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
            #        goldScores += scores
            # for sanity check

            #  (3) run the decoder to generate sentences, using beam search
            # Each hypothesis in the beam uses the same context
            # and initial decoder state
            context = Variable(context.data.repeat(1, beamSize, 1))
            contextL.append(context.clone())
            goldScores = context.data.new(batchSize).zero_()
            decStates = encStates
            decStates.repeatBeam_(beamSize)
            decStatesL.append(decStates)
        batch_src = Variable(batch.src.data.repeat(1, beamSize, 1))
        padMask = batch.src.data[:, :, 0].eq(onmt.Constants.PAD).t() \
                                   .unsqueeze(0) \
                                   .repeat(beamSize, 1, 1)

        #  (3b) The main loop
        beam_done = []
        for i in range(self.opt.max_sent_length):
            # (a) Run RNN decoder forward one step.
            #mask(padMask)

            input = torch.stack([b.getCurrentState() for b in beam])\
                         .t().contiguous().view(1, -1)
            input = Variable(input, volatile=True)
            decOutTmp = []
            attnTmp = []
            word_scores = []
            for idx in range(len(self.models)):
                model = self.models[idx]
                model.decoder.attn.applyMask(padMask)
                decOut, decStatesTmp, attn = model.decoder(input, batch_src, contextL[idx], decStatesL[idx])
                decStatesL[idx] = decStatesTmp
                decOutTmp.append(decOut)
                attnTmp.append(attn)
                decOut = decOut.squeeze(0)
                # decOut: (beam*batch) x numWords
                attn["std"] = attn["std"].view(beamSize, batchSize, -1) \
                                     .transpose(0, 1).contiguous()

                # (b) Compute a vector of batch*beam word scores.
                #if not self.copy_attn:
                if True:
                    out = model.generator[0].forward(decOut)
                    out = nn.Softmax()(out)
                else:
                    # Copy Attention Case
                    words = batch.words().t()
                    words = torch.stack([words[i] for i, b in enumerate(beam)])\
                                 .contiguous()
                    attn_copy = attn["copy"].view(beamSize, batchSize, -1) \
                                            .transpose(0, 1).contiguous()

                    out, c_attn_t \
                        = self.model.generator.forward(
                            decOut, attn_copy.view(-1, batch_src.size(0)))

                    for b in range(out.size(0)):
                        for c in range(c_attn_t.size(1)):
                            v = self.align[words[0, c].data[0]]
                            if v != onmt.Constants.PAD:
                                out[b, v] += c_attn_t[b, c]
                    out = out.log()

                #score = out.view(beamSize, batchSize, -1).transpose(0, 1).contiguous()
                # batch x beam x numWords
                word_scores.append(out.clone())
            word_score = torch.stack(word_scores).sum(0).squeeze(0).div_(len(self.models))
            mean_score = word_score.view(beamSize, batchSize, -1).transpose(0, 1).contiguous()

            scores = torch.log(mean_score) 
            #scores = self.models[0].generator[1].forward(mean_score)

            # (c) Advance each beam.
            active = []

            for b in range(batchSize):
                if b in beam_done:
                    continue
                beam[b].advance(scores.data[b], attn["std"].data[b])
                is_done = beam[b].done()
                if not is_done:
                    active += [b]
                for dec in decStatesL: 
                    dec.beamUpdate_(b, beam[b].getCurrentOrigin(), beamSize)
                if is_done:
                    beam_done.append(b)
            #if not active:
                #break
            if len(beam_done) == batchSize:
                break

        #  (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortFinished()
            #scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = beam[b].getHyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            allHyp += [hyps]
            if useMasking:
                valid_attn = batch.src.data[:, b, 0].ne(onmt.Constants.PAD) \
                                                .nonzero().squeeze(1)
                attn = [a.index_select(1, valid_attn) for a in attn]
            allAttn += [attn]

            # For debugging visualization.
            if self.beam_accum:
                self.beam_accum["beam_parent_ids"].append(
                    [t.tolist()
                     for t in beam[b].prevKs])
                self.beam_accum["scores"].append([
                    ["%4f" % s for s in t.tolist()]
                    for t in beam[b].allScores][1:])
                self.beam_accum["predicted_ids"].append(
                    [[idx for idx in t.tolist()]
                     for t in beam[b].nextYs][1:])
                self.beam_accum["predicted_labels"].append(
                    [[self.tgt_dict.getLabel(idx)
                      for idx in t.tolist()]
                     for t in beam[b].nextYs][1:])
                beam[b].finished.sort(key=lambda a:-a[0])
                self.beam_accum['finished'].append(beam[b].finished)

        return allHyp, allScores, allAttn, goldScores
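
Example #16 averages ordinary softmax probabilities across the models and only takes the log afterwards. A tiny numeric check of that order of operations, with toy tensors:

import torch

p1 = torch.softmax(torch.randn(4, 8), dim=-1)
p2 = torch.softmax(torch.randn(4, 8), dim=-1)

mean_p = torch.stack([p1, p2]).sum(0).div_(2)          # arithmetic mean of probabilities
scores = torch.log(mean_p)                             # log taken only after averaging
print(scores.shape)                                    # torch.Size([4, 8])
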
Example #17
    def translateBatch(self, srcBatch, tgtBatch):
        # Batch size is in different location depending on data.

        beamSize = self.opt.beam_size

        #  (1) run the encoder on the src
        encStates, context = self.model.encoder(srcBatch)

        # Drop the lengths needed for encoder.
        srcBatch = srcBatch[0]
        batchSize = self._getBatchSize(srcBatch)

        rnnSize = context.size(2)
        encStates = (self.model._fix_enc_hidden(encStates[0]),
                     self.model._fix_enc_hidden(encStates[1]))

        decoder = self.model.decoder
        attentionLayer = decoder.attn
        useMasking = self._type == "text"

        #  This mask is applied to the attention model inside the decoder
        #  so that the attention ignores source padding
        padMask = None
        if useMasking:
            padMask = srcBatch.data.eq(onmt.Constants.PAD).t()

        def mask(padMask):
            if useMasking:
                attentionLayer.applyMask(padMask)

        #  (2) if a target is specified, compute the 'goldScore'
        #  (i.e. log likelihood) of the target under the model
        goldScores = context.data.new(batchSize).zero_()
        if tgtBatch is not None:
            decStates = encStates
            decOut = self.model.make_init_decoder_output(context)
            mask(padMask)
            initOutput = self.model.make_init_decoder_output(context)
            decOut, decStates, attn = self.model.decoder(
                tgtBatch[:-1], decStates, context, initOutput)
            for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
                gen_t = self.model.generator.forward(dec_t)
                tgt_t = tgt_t.unsqueeze(1)
                scores = gen_t.data.gather(1, tgt_t)
                scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
                goldScores += scores

        #  (3) run the decoder to generate sentences, using beam search

        # Expand tensors for each beam.
        context = Variable(context.data.repeat(1, beamSize, 1))

        decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)),
                     Variable(encStates[1].data.repeat(1, beamSize, 1)))

        beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]

        decOut = self.model.make_init_decoder_output(context)

        if useMasking:
            padMask = srcBatch.data.eq(
                onmt.Constants.PAD).t() \
                                   .unsqueeze(0) \
                                   .repeat(beamSize, 1, 1)

        batchIdx = list(range(batchSize))
        remainingSents = batchSize
        for i in range(self.opt.max_sent_length):
            mask(padMask)
            # Prepare decoder input.
            input = torch.stack([
                b.getCurrentState() for b in beam if not b.done
            ]).t().contiguous().view(1, -1)
            decOut, decStates, attn = self.model.decoder(
                Variable(input, volatile=True), decStates, context, decOut)
            # decOut: 1 x (beam*batch) x numWords
            decOut = decOut.squeeze(0)
            out = self.model.generator.forward(decOut)

            # batch x beam x numWords
            wordLk = out.view(beamSize, remainingSents, -1) \
                        .transpose(0, 1).contiguous()
            attn = attn.view(beamSize, remainingSents, -1) \
                       .transpose(0, 1).contiguous()

            active = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                    active += [b]

                for decState in decStates:  # iterate over h, c
                    # layers x beam*sent x dim
                    sentStates = decState.view(-1, beamSize, remainingSents,
                                               decState.size(2))[:, :, idx]
                    sentStates.data.copy_(
                        sentStates.data.index_select(
                            1, beam[b].getCurrentOrigin()))

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {beam: idx for idx, beam in enumerate(active)}

            def updateActive(t):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, rnnSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return Variable(view.index_select(1, activeIdx).view(*newSize),
                                volatile=True)

            decStates = (updateActive(decStates[0]),
                         updateActive(decStates[1]))
            decOut = updateActive(decOut)
            context = updateActive(context)
            if useMasking:
                padMask = padMask.index_select(1, activeIdx)

            remainingSents = len(active)

        #  (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
            allHyp += [hyps]
            if useMasking:
                valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \
                                                .nonzero().squeeze(1)
                attn = [a.index_select(1, valid_attn) for a in attn]
            allAttn += [attn]

            if self.beam_accum:
                self.beam_accum["beam_parent_ids"].append(
                    [t.tolist() for t in beam[b].prevKs])
                self.beam_accum["scores"].append(
                    [["%4f" % s for s in t.tolist()]
                     for t in beam[b].allScores][1:])
                self.beam_accum["predicted_ids"].append(
                    [[self.tgt_dict.getLabel(id) for id in t.tolist()]
                     for t in beam[b].nextYs][1:])

        return allHyp, allScores, allAttn, goldScores
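
The per-sentence beam reordering inside the loop (the copy_ / index_select on sentStates) can be pictured with a single sentence's states. The permutation below stands in for whatever getCurrentOrigin() returns and is purely illustrative:

import torch

layers, beam_size, dim = 2, 4, 3
sent_states = torch.randn(layers, beam_size, dim)      # decoder state of one sentence
origins = torch.tensor([2, 0, 0, 1])                   # back-pointer per surviving hypothesis

# index_select builds a new tensor, so copying it back in place is safe
sent_states.copy_(sent_states.index_select(1, origins))
print(sent_states.shape)
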
Example #18
    def translateBatch(self, srcBatch, tgtBatch):

        torch.set_grad_enabled(False)
        # Batch size is in different location depending on data.

        beamSize = self.opt.beam_size
        batchSize = self._getBatchSize(srcBatch)

        if self.model_type == 'recurrent':

            #  (1) run the encoder on the src
            encStates, context = self.model.encoder(srcBatch)

            rnnSize = context.size(2)

            decoder = self.model.decoder
            attentionLayer = decoder.attn
            useMasking = (self._type == "text" and batchSize > 1)

            #  This mask is applied to the attention model inside the decoder
            #  so that the attention ignores source padding
            attn_mask = srcBatch.eq(onmt.Constants.PAD).t()

            #  (2) if a target is specified, compute the 'goldScore'
            #  (i.e. log likelihood) of the target under the model
            goldScores = context.data.new(batchSize).zero_()
            goldWords = 0
            if tgtBatch is not None:
                decStates = encStates
                decOut = self.model.make_init_decoder_output(context)
                initOutput = self.model.make_init_decoder_output(context)
                decOut, decStates, attn = self.model.decoder(
                    tgtBatch[:-1],
                    decStates,
                    context,
                    initOutput,
                    attn_mask=attn_mask)
                for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
                    gen_t = self.model.generator.forward(dec_t)
                    tgt_t = tgt_t.unsqueeze(1)
                    scores = gen_t.data.gather(1, tgt_t)
                    scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
                    goldScores += scores.squeeze(1)
                    goldWords += tgt_t.ne(onmt.Constants.PAD).sum()

            #  (3) run the decoder to generate sentences, using beam search

            # Expand tensors for each beam.
            context = Variable(context.data.repeat(1, beamSize, 1))

            decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)),
                         Variable(encStates[1].data.repeat(1, beamSize, 1)))

            beam = [
                onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)
            ]

            decOut = self.model.make_init_decoder_output(context)


            attn_mask = srcBatch.eq(
                onmt.Constants.PAD).t() \
                                   .unsqueeze(0) \
                                   .repeat(beamSize, 1, 1)

            batchIdx = list(range(batchSize))
            remainingSents = batchSize
            for i in range(self.opt.max_sent_length):
                # Prepare decoder input.
                input = torch.stack([
                    b.getCurrentState() for b in beam if not b.done
                ]).t().contiguous().view(1, -1)
                decOut, decStates, attn = self.model.decoder(
                    Variable(input),
                    decStates,
                    context,
                    decOut,
                    attn_mask=attn_mask)
                # decOut: 1 x (beam*batch) x numWords
                decOut = decOut.squeeze(0)
                out = self.model.generator.forward(decOut)

                # batch x beam x numWords
                wordLk = out.view(beamSize, remainingSents, -1) \
                            .transpose(0, 1).contiguous()
                attn = attn.view(beamSize, remainingSents, -1) \
                           .transpose(0, 1).contiguous()

                active = []
                for b in range(batchSize):
                    if beam[b].done:
                        continue

                    idx = batchIdx[b]
                    if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                        active += [b]

                    for decState in decStates:  # iterate over h, c
                        # layers x beam*sent x dim
                        sentStates = decState.view(-1, beamSize,
                                                   remainingSents,
                                                   decState.size(2))[:, :, idx]
                        sentStates.data.copy_(
                            sentStates.data.index_select(
                                1, beam[b].getCurrentOrigin()))

                if not active:
                    break

                # in this section, the sentences that are still active are
                # compacted so that the decoder is not run on completed sentences
                activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
                batchIdx = {beam: idx for idx, beam in enumerate(active)}

                def updateActive(t):
                    # select only the remaining active sentences
                    view = t.data.view(-1, remainingSents, rnnSize)
                    newSize = list(t.size())
                    newSize[-2] = newSize[-2] * len(
                        activeIdx) // remainingSents
                    return Variable(
                        view.index_select(1, activeIdx).view(*newSize))

                decStates = (updateActive(decStates[0]),
                             updateActive(decStates[1]))
                decOut = updateActive(decOut)
                context = updateActive(context)

                attn_mask_data = attn_mask.data.index_select(1, activeIdx)
                attn_mask = Variable(attn_mask_data)

                remainingSents = len(active)

            #  (4) package everything up
            allHyp, allScores, allAttn = [], [], []
            n_best = self.opt.n_best
            allLengths = []

            for b in range(batchSize):
                scores, ks = beam[b].sortBest()

                allScores += [scores[:n_best]]
                hyps, attn, length = zip(
                    *[beam[b].getHyp(k) for k in ks[:n_best]])
                allHyp += [hyps]
                allLengths += [length]
                valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \
                                                .nonzero().squeeze(1)
                attn = [a.index_select(1, valid_attn) for a in attn]
                allAttn += [attn]

                if self.beam_accum:
                    self.beam_accum["beam_parent_ids"].append(
                        [t.tolist() for t in beam[b].prevKs])
                    self.beam_accum["scores"].append(
                        [["%4f" % s for s in t.tolist()]
                         for t in beam[b].allScores][1:])
                    self.beam_accum["predicted_ids"].append(
                        [[self.tgt_dict.getLabel(id) for id in t.tolist()]
                         for t in beam[b].nextYs][1:])

            torch.set_grad_enabled(True)

            return allHyp, allScores, allAttn, allLengths, goldScores, goldWords
        elif self.model_type in [
                'transformer', 'ptransformer', 'fctransformer'
        ]:

            vocab_size = self.tgt_dict.size()
            allHyp, allScores, allAttn, allLengths = [], [], [], []

            # srcBatch should have size len x batch
            # tgtBatch should have size len x batch

            src = srcBatch.transpose(0, 1)
            context, src_mask = self.model.encoder(src)

            goldScores = context.data.new(batchSize).zero_()
            goldWords = 0

            if tgtBatch is not None:

                tgtBatchInput = tgtBatch[:-1]
                tgtBatchOutput = tgtBatch[1:]
                tgtBatchInput = tgtBatchInput.transpose(0, 1)

                output, coverage = self.model.decoder(tgtBatchInput, context,
                                                      src)
                output = output.transpose(
                    0, 1)  # transpose to have time first, like RNN models

                #  (2) if a target is specified, compute the 'goldScore'
                #  (i.e. log likelihood) of the target under the model
                for dec_t, tgt_t in zip(output, tgtBatchOutput.data):
                    gen_t = self.model.generator(dec_t)
                    tgt_t = tgt_t.unsqueeze(1)
                    scores = gen_t.data.gather(1, tgt_t)
                    scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
                    goldScores += scores.squeeze(1)
                    goldWords += tgt_t.ne(onmt.Constants.PAD).sum()

            #  (3) Start decoding

            # time x batch * beam
            src = Variable(srcBatch.data.repeat(1, beamSize))

            # context size : time x batch*beam x hidden
            context = self._replicate_context(context)

            # initialize the beam
            beam = [
                onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)
            ]

            batchIdx = list(range(batchSize))
            remainingSents = batchSize


            # Create a new decoding state.
            # I use a method from the model because I don't want to access the decoder state object directly.
            # Currently it doesn't share anything with the main model, though.
            decoder_state = self.model.create_decoder_state(
                src, context, beamSize)

            for i in range(self.opt.max_sent_length):
                # Prepare decoder input.

                # input size: 1 x ( batch * beam )
                input = torch.stack([
                    b.getCurrentState() for b in beam if not b.done
                ]).t().contiguous().view(1, -1)
                """  
                    Inefficient decoding implementation
                    We re-compute all states for every time step
                    A better buffering algorithm will be implemented
                """

                # require batch first for everything
                decoder_input = Variable(input)
                decoder_hidden, coverage = self.model.decoder.step(
                    decoder_input, decoder_state)

                # take the last decoder state
                decoder_hidden = decoder_hidden.squeeze(1)
                attn = coverage[:, -1, :].squeeze(1)  # batch * beam x src_len

                # batch * beam x vocab_size
                out = self.model.generator(decoder_hidden)

                wordLk = out.view(beamSize, remainingSents, -1) \
                            .transpose(0, 1).contiguous()
                attn = attn.view(beamSize, remainingSents, -1) \
                           .transpose(0, 1).contiguous()
                active = []

                for b in range(batchSize):
                    if beam[b].done:
                        continue

                    idx = batchIdx[b]
                    if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                        active += [b]

                    # update the decoding state for this sentence's surviving beams
                    decoder_state._update_beam(beam, b, remainingSents, idx)

                if not active:
                    break

                # in this section, the sentences that are still active are
                # compacted so that the decoder is not run on completed sentences
                activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
                batchIdx = {beam: idx for idx, beam in enumerate(active)}

                decoder_state._prune_complete_beam(activeIdx, remainingSents)

                remainingSents = len(active)

            #  (4) package everything up
            allHyp, allScores, allAttn = [], [], []
            n_best = self.opt.n_best
            allLengths = []

            for b in range(batchSize):
                scores, ks = beam[b].sortBest()

                allScores += [scores[:n_best]]
                hyps, attn, length = zip(
                    *[beam[b].getHyp(k) for k in ks[:n_best]])
                allHyp += [hyps]
                allLengths += [length]
                valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \
                                                .nonzero().squeeze(1)
                attn = [a.index_select(1, valid_attn) for a in attn]
                allAttn += [attn]

                if self.beam_accum:
                    self.beam_accum["beam_parent_ids"].append(
                        [t.tolist() for t in beam[b].prevKs])
                    self.beam_accum["scores"].append(
                        [["%4f" % s for s in t.tolist()]
                         for t in beam[b].allScores][1:])
                    self.beam_accum["predicted_ids"].append(
                        [[self.tgt_dict.getLabel(id) for id in t.tolist()]
                         for t in beam[b].nextYs][1:])

            torch.set_grad_enabled(True)

            return allHyp, allScores, allAttn, allLengths, goldScores, goldWords

        else:
            print("Model type %s is not supported" % self.model_type)
            raise NotImplementedError
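
The "compact the active sentences" step that recurs in the snippets above can be hard to follow from the in-place view/index_select calls alone. Below is a minimal, self-contained sketch of the same idea with made-up shapes (5 time steps, beam size 4, 3 sentences of which 2 remain active); the name update_active and the use of modern PyTorch tensors instead of Variable are assumptions for illustration, not the snippet's actual API.

import torch

def update_active(t, active_idx, remaining_sents, rnn_size):
    # t has shape (time, beam * remaining_sents, rnn_size); keep only the
    # sentence slots listed in active_idx along the sentence dimension.
    view = t.view(-1, remaining_sents, rnn_size)
    new_size = list(t.size())
    new_size[-2] = new_size[-2] * len(active_idx) // remaining_sents
    return view.index_select(1, active_idx).view(*new_size)

# toy example: 3 sentences, of which sentences 0 and 2 are still active
dec_out = torch.randn(5, 4 * 3, 8)          # (time, beam * sentences, hidden)
active_idx = torch.tensor([0, 2])
compacted = update_active(dec_out, active_idx, remaining_sents=3, rnn_size=8)
print(compacted.shape)                       # torch.Size([5, 8, 8]) -> 4 beams * 2 sentences
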
Ejemplo n.º 19
0
    def translateBatch(self, batch, dataset):
        beam_size = self.opt.beam_size
        batch_size = batch.batch_size

        # (1) Run the encoder on the src.
        _, src_lengths = batch.src
        src = onmt.IO.make_features(batch, 'src')
        encStates, context = self.model.encoder(src, src_lengths)
        decStates = self.model.decoder.init_decoder_state(
                                        src, context, encStates)

        #  (1b) Initialize for the decoder.
        def var(a): return Variable(a, volatile=True)

        def rvar(a): return var(a.repeat(1, beam_size, 1))

        # Repeat everything beam_size times.
        context = rvar(context.data)
        src = rvar(src.data)
        srcMap = rvar(batch.src_map.data)
        decStates.repeat_beam_size_times(beam_size)
        scorer = onmt.GNMTGlobalScorer(self.alpha, self.beta)
        beam = [onmt.Beam(beam_size, n_best=self.opt.n_best,
                          cuda=self.opt.cuda,
                          vocab=self.fields["tgt"].vocab,
                          global_scorer=scorer)
                for __ in range(batch_size)]

        # (2) run the decoder to generate sentences, using beam search.

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        for i in range(self.opt.max_sent_length):

            if all((b.done() for b in beam)):
                break

            # Construct batch x beam_size nxt words.
            # Get all the pending current beam words and arrange for forward.
            inp = var(torch.stack([b.getCurrentState() for b in beam])
                      .t().contiguous().view(1, -1))

            # Turn any copied words to UNKs
            # 0 is unk
            if self.copy_attn:
                inp = inp.masked_fill(
                    inp.gt(len(self.fields["tgt"].vocab) - 1), 0)

            # Temporary kludge solution to handle changed dim expectation
            # in the decoder
            inp = inp.unsqueeze(2)

            # Run one step.
            decOut, decStates, attn = \
                self.model.decoder(inp, context, decStates)
            decOut = decOut.squeeze(0)
            # decOut: beam x rnn_size

            # (b) Compute a vector of batch*beam word scores.
            if not self.copy_attn:
                out = self.model.generator.forward(decOut).data
                out = unbottle(out)
                # beam x tgt_vocab
            else:
                out = self.model.generator.forward(decOut,
                                                   attn["copy"].squeeze(0),
                                                   srcMap)
                # beam x (tgt_vocab + extra_vocab)
                out = dataset.collapse_copy_scores(
                    unbottle(out.data),
                    batch, self.fields["tgt"].vocab)
                # beam x tgt_vocab
                out = out.log()

            # (c) Advance each beam.
            for j, b in enumerate(beam):
                b.advance(out[:, j],  unbottle(attn["std"]).data[:, j])
                decStates.beam_update(j, b.getCurrentOrigin(), beam_size)

        if "tgt" in batch.__dict__:
            allGold = self._runTarget(batch, dataset)
        else:
            allGold = [0] * batch_size

        # (3) Package everything up.
        allHyps, allScores, allAttn = [], [], []
        for b in beam:
            n_best = self.opt.n_best
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn = [], []
            for i, (times, k) in enumerate(ks[:n_best]):
                hyp, att = b.getHyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            allHyps.append(hyps)
            allScores.append(scores)
            allAttn.append(attn)

        return allHyps, allScores, allAttn, allGold
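
The bottle/unbottle helpers above rely on a specific flat layout of the beam x batch dimension: getCurrentState() is stacked batch-first and then transposed, so the flattened vector is beam-major and view(beam_size, batch_size, -1) recovers it. A tiny sketch with made-up sizes (3 beams, 2 sentences) illustrating that correspondence:

import torch

beam_size, batch_size = 3, 2

# one "current state" (last predicted token) per beam, for each sentence:
# sentence 0 holds tokens 10..12, sentence 1 holds tokens 20..22
states = [torch.arange(beam_size) + 10 * (b + 1) for b in range(batch_size)]

# stack along the batch dim, then transpose so the flat layout is beam-major,
# which is what unbottle(m) = m.view(beam_size, batch_size, -1) expects
inp = torch.stack(states).t().contiguous().view(1, -1)
print(inp)                              # tensor([[10, 20, 11, 21, 12, 22]])
print(inp.view(beam_size, batch_size))  # rows are beams, columns are sentences
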
Ejemplo n.º 20
0
    def translateBatch(self, srcBatch, tgtBatch, alignBatch):

        batchSize = srcBatch[0].size(1)
        beamSize = self.opt.beam_size
        knntime = 0.0
        #  (1) run the encoder on the src
        encStates, context, fert = self.model.encoder(srcBatch, is_fert=True)
        init_fert = deepcopy(fert)
        init_fert = torch.max(init_fert, dim=-1)[1].float()
        # cov = torch.max(cov, dim=-1)[1].float()
        srcBatch = srcBatch[0]  # drop the lengths; they were only needed by the encoder

        rnnSize = context.size(2)
        encStates = (self.model._fix_enc_hidden(encStates[0]),
                      self.model._fix_enc_hidden(encStates[1]))

        #  This mask is applied to the attention model inside the decoder
        #  so that the attention ignores source padding
        padMask = srcBatch.data.eq(onmt.Constants.PAD).t()
        def applyContextMask(m):
            if isinstance(m, onmt.modules.GlobalAttention):
                m.applyMask(padMask)
            elif isinstance(m, onmt.modules.GlobalAttentionOriginal):
                m.applyMask(padMask)


        #  (3) run the decoder to generate sentences, using beam search

        # Expand tensors for each beam.
        # decStates = encStates
        if self.opt.use_lm:
            lm_hidden = self.LangModel.initialize_hidden(1, batchSize)

        context = Variable(context.data.repeat(1, beamSize, 1))
        # cov = Variable(torch.zeros((context.size(1),context.size(0))), requires_grad=True)
        # cov = cov.cuda()
        decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)),
                     Variable(encStates[1].data.repeat(1, beamSize, 1)))

        beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]

        decOut = self.model.make_init_decoder_output(context)

        padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1)

        batchIdx = list(range(batchSize))
        previous_batchIdx = list(range(batchSize))
        remainingSents = batchSize
        for i in range(self.opt.max_sent_length):

            self.model.decoder.apply(applyContextMask)

            # Prepare decoder input.
            inputs = []
            for i_b, b in enumerate(beam):
                if not b.done():
                    c = b.getCurrentState()
                    if self.opt.replace_unk and i != 0 and \
                            (int(c) == onmt.Constants.UNK or
                             int(c) == onmt.Constants.UNK + self.tgt_dict.size_uni()):
                        tok = self.tgt_dict.getLabel(int(c))
                        if tok == onmt.Constants.UNK_WORD:
                            b_i = previous_batchIdx[i_b]
                            _src = self.src_dict.convertToLabels(srcBatch.data.t()[i_b], onmt.Constants.PAD_WORD)
                            tok = self.replace_unk(i, tok, attn[b_i][0].data, _src)
                            if tok in self.tgt_dict.labelToIdx.keys():
                                c = torch.LongTensor([self.tgt_dict.labelToIdx[tok]]).cuda()
                    inputs.append(c)
            input_ = torch.stack(inputs).t().contiguous().view(1, -1)
            # input_ = torch.stack([b.getCurrentState() for b in beam
            #                    if not b.done()]).t().contiguous().view(1, -1)
            input_var = Variable(input_, volatile=True)
            # decOut, decStates, attn, alignBatch = self.model.decoder(
            #     input_var, decStates, context, decOut, alignBatch)
            decOut, decStates, attn = self.model.decoder(
                input_var, decStates, context, decOut)

            # decOut: 1 x (beam*batch) x numWords
            decOut = decOut.squeeze(0)
            out = self.model.generator.forward(decOut)

            if self.opt.use_lm:
                lm_output, lm_hidden, _, _ = self.LangModel(input_var, lm_hidden)
                lm_output = torch.log(lm_output+1e-12)

            beg = time.time()
            original_scores = self._get_scores(out, self.target_embeddings)
            uni_scores, ngram_scores = self._get_scores(out)
            fert_combined = torch.bmm(attn.unsqueeze(1)[:,:,:-1], fert[:,:-1]).squeeze(1)

            if self.fert_dim > 2:
                uni_prob = torch.sum(fert_combined[:,:2], dim=1).unsqueeze(1)
                ngram_prob = torch.sum(fert_combined[:,2:], dim=1).unsqueeze(1)
            else:
                uni_prob = fert_combined[:,0].unsqueeze(1)
                ngram_prob = fert_combined[:,1].unsqueeze(1)
            scores = torch.cat((uni_scores*uni_prob, ngram_scores * ngram_prob),dim=1)
            topk = scores.topk(beamSize, 1, True, True)[1].cpu().data.numpy()
            topk_uni = uni_scores.topk(beamSize, 1, True, True)[1].cpu().data.numpy()
            topk_ngram = ngram_scores.topk(beamSize, 1, True, True)[1].cpu().data.numpy()
            topk_original = original_scores.topk(beamSize, 1, True, True)[1].cpu().data.numpy()

            diff = time.time()-beg
            knntime += diff

            if self.opt.use_lm:
                scores += 0.2*lm_output.squeeze(0)

            wordLk = scores.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()
            attn = attn.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()
            out = out.view(remainingSents, beamSize, -1).contiguous()
            fert_beam = fert_combined.view(remainingSents, beamSize, -1).contiguous()
            init_fert_beam = init_fert.view(remainingSents, beamSize, -1).contiguous()

            active = []
            for b in range(batchSize):
                if beam[b].done():
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx], attn.data[idx], out.data[idx], fert_beam.data[idx], init_fert_beam.data[idx]):
                    active += [b]

                for decState in decStates:  # iterate over h, c
                    # layers x beam*sent x dim
                    sentStates = decState.view(
                        -1, beamSize, remainingSents, decState.size(2))[:, :, idx]
                    sentStates.data.copy_(
                        sentStates.data.index_select(1, beam[b].getCurrentOrigin()))

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            previous_batchIdx = batchIdx
            batchIdx = {beam: idx for idx, beam in enumerate(active)}

            def updateActive(t):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, rnnSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return Variable(view.index_select(1, activeIdx) \
                                    .view(*newSize), volatile=True)

            decStates = (updateActive(decStates[0]), updateActive(decStates[1]))
            decOut = updateActive(decOut)
            context = updateActive(context)
            padMask = padMask.index_select(1, activeIdx)

            remainingSents = len(active)

        #  (4) package everything up

        allHyp, allScores, allAttn, allOut, allOutScores, allFert, allInitFert = [], [], [], [], [], [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1)

            hyps, attn, out, score, fert, init_fert = zip(*[beam[b].getHyp(times, k) for (times, k) in ks[:n_best]])
            attn = [a.index_select(1, valid_attn) for a in attn]

            allHyp += [hyps]
            allAttn += [attn]
            allOut += [out]
            allOutScores += [score]
            allFert += [fert]
            allInitFert += [init_fert]

        return allHyp, allScores, allAttn, allOut, allOutScores, allFert, allInitFert, knntime
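
This snippet mixes a language model into the translation scores with a fixed weight (`scores += 0.2*lm_output`). Below is a minimal sketch of that shallow-fusion idea, assuming both models emit unnormalized logits over the same vocabulary; the weight 0.2 simply mirrors the line above and is not a tuned value, and all names and shapes are illustrative.

import torch
import torch.nn.functional as F

def shallow_fusion(tm_logits, lm_logits, lm_weight=0.2):
    # combine translation-model and language-model predictions in log space;
    # the fused scores can be fed to Beam.advance just like the plain TM scores
    return F.log_softmax(tm_logits, dim=-1) + lm_weight * F.log_softmax(lm_logits, dim=-1)

tm_logits = torch.randn(6, 100)   # (beam * batch, vocab), made-up sizes
lm_logits = torch.randn(6, 100)
fused = shallow_fusion(tm_logits, lm_logits)
print(fused.topk(5, dim=-1).indices)  # top candidate tokens per beam entry
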
Ejemplo n.º 21
0
    def translateBatch(self, batch):
        srcBatch, tgtBatch = batch
        batchSize = srcBatch.size(1)
        beamSize = self.opt.beam_size

        #  (1) run the encoder on the src

        encStates, context = None, None

        if self.model.encoder.num_directions == 2:
            # bidirectional encoder is negatively impacted by padding
            # run with batch size 1 for improved translations
            # This will be resolved when variable length LSTMs are used instead
            encStates, context = self.model.encoder(srcBatch, hidden=encStates)
        else:
            # have to execute the encoder manually to deal with padding
            context = []
            for srcBatch_t in srcBatch.split(1):
                encStates, context_t = self.model.encoder(srcBatch_t, hidden=encStates)
                batchPadIdx = srcBatch_t.data.squeeze(0).eq(onmt.Constants.PAD).nonzero()
                if batchPadIdx.nelement() > 0:
                    batchPadIdx = batchPadIdx.squeeze(1)
                    encStates[0].data.index_fill_(1, batchPadIdx, 0)
                    encStates[1].data.index_fill_(1, batchPadIdx, 0)
                context += [context_t]
            context = torch.cat(context)

        rnnSize = context.size(2)
        encStates = (self.model._fix_enc_hidden(encStates[0]),
                      self.model._fix_enc_hidden(encStates[1]))

        #  This mask is applied to the attention model inside the decoder
        #  so that the attention ignores source padding
        padMask = srcBatch.data.eq(onmt.Constants.PAD).t()
        def applyContextMask(m):
            if isinstance(m, onmt.modules.GlobalAttention):
                m.applyMask(padMask)

        #  (2) if a target is specified, compute the 'goldScore'
        #  (i.e. log likelihood) of the target under the model
        goldScores = context.data.new(batchSize).zero_()
        if tgtBatch is not None:
            decStates = encStates
            decOut = self.model.make_init_decoder_output(context)
            self.model.decoder.apply(applyContextMask)
            initOutput = self.model.make_init_decoder_output(context)

            decOut, decStates, attn = self.model.decoder(
                tgtBatch[:-1], decStates, context, initOutput)
            for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
                gen_t = self.model.generator.forward(dec_t)
                tgt_t = tgt_t.unsqueeze(1)
                scores = gen_t.data.gather(1, tgt_t)
                scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
                goldScores += scores.squeeze(1)  # (batch, 1) -> (batch,) to match goldScores

        #  (3) run the decoder to generate sentences, using beam search

        # Expand tensors for each beam.
        context = Variable(context.data.repeat(1, beamSize, 1), volatile=True)
        decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1), volatile=True),
                     Variable(encStates[1].data.repeat(1, beamSize, 1), volatile=True))

        beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]

        decOut = self.model.make_init_decoder_output(context)

        padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1)

        batchIdx = list(range(batchSize))
        remainingSents = batchSize
        for i in range(self.opt.max_sent_length):

            self.model.decoder.apply(applyContextMask)

            # Prepare decoder input.
            input = torch.stack([b.getCurrentState() for b in beam
                               if not b.done]).t().contiguous().view(1, -1)

            decOut, decStates, attn = self.model.decoder(
                Variable(input, volatile=True), decStates, context, decOut)
            # decOut: 1 x (beam*batch) x numWords
            decOut = decOut.squeeze(0)
            out = self.model.generator.forward(decOut)

            # batch x beam x numWords
            wordLk = out.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()
            attn = attn.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()

            active = []
            for b in range(batchSize):
                if beam[b].done:
                    continue

                idx = batchIdx[b]
                if not beam[b].advance(wordLk.data[idx], attn.data[idx]):
                    active += [b]

                for decState in decStates:  # iterate over h, c
                    # layers x beam*sent x dim
                    sentStates = decState.view(
                        -1, beamSize, remainingSents, decState.size(2))[:, :, idx]
                    sentStates.data.copy_(
                        sentStates.data.index_select(1, beam[b].getCurrentOrigin()))

            if not active:
                break

            # in this section, the sentences that are still active are
            # compacted so that the decoder is not run on completed sentences
            activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
            batchIdx = {beam: idx for idx, beam in enumerate(active)}

            def updateActive(t):
                # select only the remaining active sentences
                view = t.data.view(-1, remainingSents, rnnSize)
                newSize = list(t.size())
                newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
                return Variable(view.index_select(1, activeIdx) \
                                    .view(*newSize), volatile=True)

            decStates = (updateActive(decStates[0]), updateActive(decStates[1]))
            decOut = updateActive(decOut)
            context = updateActive(context)
            padMask = padMask.index_select(1, activeIdx)

            remainingSents = len(active)

        #  (4) package everything up

        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortBest()

            allScores += [scores[:n_best]]
            valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1)
            hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
            attn = [a.index_select(1, valid_attn) for a in attn]
            allHyp += [hyps]
            allAttn += [attn]

        return allHyp, allScores, allAttn, goldScores
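
The gold-score block in step (2) above scores the reference translation under the model by gathering the log-probability of each target token and zeroing out padding positions. A stand-alone sketch of the same computation, with assumed shapes and an assumed PAD index of 0 standing in for onmt.Constants.PAD:

import torch

PAD = 0  # assumed padding index, mirroring onmt.Constants.PAD

def gold_score(log_probs, target):
    # log_probs: (time, batch, vocab) log-softmax outputs under teacher forcing
    # target:    (time, batch) reference token ids, shifted by one position
    scores = log_probs.gather(2, target.unsqueeze(2)).squeeze(2)  # (time, batch)
    scores = scores.masked_fill(target.eq(PAD), 0.0)              # ignore padding
    return scores.sum(0)                                          # one log-likelihood per sentence

log_probs = torch.log_softmax(torch.randn(7, 2, 50), dim=-1)
target = torch.randint(1, 50, (7, 2))
target[5:, 1] = PAD   # second sentence is shorter
print(gold_score(log_probs, target))
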
Ejemplo n.º 22
0
def translateBatch(opt, model, batch, src_dict, tgt_dict, beam_accum):
    beamSize = opt.beam_size
    batchSize = batch.batchSize

    #  (1) run the encoder on the src
    encStates, context, fertility_vals = model.encoder(batch.src)
    encStates = model.init_decoder_state(context, encStates)
    if fertility_vals is not None:
        fertility_vals = fertility_vals.repeat(beamSize * batchSize, 1)

    decoder = model.decoder
    attentionLayer = decoder.attn
    useMasking = True

    #  This mask is applied to the attention model inside the decoder
    #  so that the attention ignores source padding
    padMask = None
    if useMasking:
        padMask = batch.words().data.eq(onmt.Constants.PAD).t()

    def mask(padMask):
        if useMasking:
            attentionLayer.applyMask(padMask)

    # (2) if a target is specified, compute the 'goldScore'
    #  (i.e. log likelihood) of the target under the model
    goldScores = context.data.new(batchSize).zero_()

    # (3) run the decoder to generate sentences, using beam search
    # Each hypothesis in the beam uses the same context
    # and initial decoder state
    context = Variable(context.data.repeat(1, beamSize, 1))
    batch_src = Variable(batch.src.data.repeat(1, beamSize, 1))
    decStates = encStates
    decStates.repeatBeam_(beamSize)
    beam = [onmt.Beam(beamSize, True) for _ in range(batchSize)]
    if useMasking:
        padMask = batch.src.data[:, :, 0].eq(
            onmt.Constants.PAD).t() \
            .unsqueeze(0) \
            .repeat(beamSize, 1, 1)

    # (3b) The main loop
    upper_bounds = None
    max_sent_length = 100
    for i in range(max_sent_length):
        # (a) Run RNN decoder forward one step.
        mask(padMask)
        input = torch.stack([b.getCurrentState() for b in beam]) \
            .t().contiguous().view(1, -1)
        input = Variable(input, volatile=True)
        decOut, decStates, attn, upper_bounds = model.decoder(
            input,
            batch_src,
            context,
            decStates,
            fertility_vals=fertility_vals,
            fert_dict=None,
            upper_bounds=decStates.attn_upper_bounds,
            test=True)

        decOut = decOut.squeeze(0)
        # decOut: (beam*batch) x numWords
        attn["std"] = attn["std"].view(beamSize, batchSize,
                                       -1).transpose(0, 1).contiguous()

        # (b) Compute a vector of batch*beam word scores.
        out = model.generator.forward(decOut)

        word_scores = out.view(beamSize, batchSize,
                               -1).transpose(0, 1).contiguous()
        # batch x beam x numWords

        # (c) Advance each beam.
        active = []
        for b in range(batchSize):
            is_done = beam[b].advance(word_scores.data[b], attn["std"].data[b])
            if not is_done:
                active += [b]
            decStates.beamUpdate_(b, beam[b].getCurrentOrigin(), beamSize)
        if not active:
            break

    # (4) package everything up
    allHyp, allScores, allAttn = [], [], []
    n_best = 1  # If verbose is set, will output the n_best decoded sentences

    for b in range(batchSize):
        scores, ks = beam[b].sortBest()

        allScores += [scores[:n_best]]
        hyps, attn = [], []
        for k in ks[:n_best]:
            hyp, att = beam[b].getHyp(k)
            hyps.append(hyp)
            attn.append(att)
        allHyp += [hyps]
        if useMasking:
            valid_attn = batch.src.data[:, b, 0].ne(onmt.Constants.PAD) \
                .nonzero().squeeze(1)
            attn = [a.index_select(1, valid_attn) for a in attn]
        allAttn += [attn]

        # For debugging visualization.
        if beam_accum:
            beam_accum["beam_parent_ids"].append(
                [t.tolist() for t in beam[b].prevKs])
            beam_accum["scores"].append([["%4f" % s for s in t.tolist()]
                                         for t in beam[b].allScores][1:])
            beam_accum["predicted_ids"].append(
                [[tgt_dict.getLabel(id) for id in t.tolist()]
                 for t in beam[b].nextYs][1:])
    if fertility_vals is not None:
        cum_attn = allAttn[0][0].sum(0).squeeze(0).cpu().numpy()
        fert = fertility_vals.data[0, :].cpu().numpy()
        for c, f in zip(cum_attn, fert):
            print('%f (%f)' % (c, f))
    return allHyp, allScores, allAttn, goldScores
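
The final loop above compares, for the first sentence, the total attention mass each source token received against its predicted fertility value. A tiny self-contained version of that diagnostic, with random attention standing in for the model's (all shapes and values are made up for illustration):

import torch

# attn: (tgt_len, src_len) attention for the 1-best hypothesis of one sentence,
# fert: (src_len,) predicted fertility per source token
attn = torch.rand(6, 4)
attn = attn / attn.sum(dim=1, keepdim=True)   # normalize each target step
fert = torch.tensor([1.0, 2.0, 0.5, 1.5])

cum_attn = attn.sum(dim=0)   # total attention mass each source token received
for c, f in zip(cum_attn.tolist(), fert.tolist()):
    print('%f (%f)' % (c, f))  # same diagnostic format as the snippet above
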