Example #1
    def translateBatch(self, batch, dataset):
        beam_size = self.opt.beam_size
        batch_size = batch.batch_size

        # (1) Run the encoder on the src.
        _, src_lengths = batch.src
        src = onmt.IO.make_features(batch, 'src')
        encStates, context = self.model.encoder(src, src_lengths)
        decStates = self.model.decoder.init_decoder_state(
                                        src, context, encStates)

        #  (1b) Initialize for the decoder.
        def var(a): return Variable(a, volatile=True)

        def rvar(a): return var(a.repeat(1, beam_size, 1))

        # Repeat everything beam_size times.
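        # rvar tiles dim 1 (the batch dim) beam_size times, so flat
        # index k * batch_size + b holds beam k of batch example b;
        # bottle/unbottle below assume this beam-major layout.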
        context = rvar(context.data)
        src = rvar(src.data)
        srcMap = rvar(batch.src_map.data)
        decStates.repeat_beam_size_times(beam_size)
        scorer = onmt.GNMTGlobalScorer(self.alpha, self.beta)
        beam = [onmt.Beam(beam_size, n_best=self.opt.n_best,
                          cuda=self.opt.cuda,
                          vocab=self.fields["tgt"].vocab,
                          global_scorer=scorer)
                for __ in range(batch_size)]

        # (2) run the decoder to generate sentences, using beam search.

        def bottle(m):
            return m.view(batch_size * beam_size, -1)

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        for i in range(self.opt.max_sent_length):

            if all(b.done() for b in beam):
                break

            # Construct batch x beam_size next words.
            # Get all the pending current beam words and arrange for forward.
            inp = var(torch.stack([b.getCurrentState() for b in beam])
                      .t().contiguous().view(1, -1))
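            # stack() yields (batch, beam); the transpose and view flatten
            # it beam-major, matching the rvar-repeated context above.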

            # Turn any copied words to UNKs
            # 0 is unk
            if self.copy_attn:
                inp = inp.masked_fill(
                    inp.gt(len(self.fields["tgt"].vocab) - 1), 0)

            # Temporary kludge solution to handle changed dim expectation
            # in the decoder
            inp = inp.unsqueeze(2)
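            # inp: (1, beam*batch) -> (1, beam*batch, 1); the trailing
            # dimension is the feature axis the decoder now expects.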

            # Run one step.
            decOut, decStates, attn = \
                self.model.decoder(inp, context, decStates)
            decOut = decOut.squeeze(0)
            # decOut: (beam*batch) x rnn_size

            # (b) Compute a vector of batch*beam word scores.
            if not self.copy_attn:
                out = self.model.generator.forward(decOut).data
                out = unbottle(out)
                # beam x tgt_vocab
            else:
                out = self.model.generator.forward(decOut,
                                                   attn["copy"].squeeze(0),
                                                   srcMap)
                # beam x (tgt_vocab + extra_vocab)
                out = dataset.collapse_copy_scores(
                    unbottle(out.data),
                    batch, self.fields["tgt"].vocab)
                # beam x tgt_vocab
                out = out.log()

            # (c) Advance each beam.
            for j, b in enumerate(beam):
                b.advance(out[:, j], unbottle(attn["std"]).data[:, j])
                decStates.beam_update(j, b.getCurrentOrigin(), beam_size)

        if "tgt" in batch.__dict__:
            allGold = self._runTarget(batch, dataset)
        else:
            allGold = [0] * batch_size

        # (3) Package everything up.
        allHyps, allScores, allAttn = [], [], []
        for b in beam:
            n_best = self.opt.n_best
            scores, ks = b.sortFinished(minimum=n_best)
            hyps, attn = [], []
            for times, k in ks[:n_best]:
                hyp, att = b.getHyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            allHyps.append(hyps)
            allScores.append(scores)
            allAttn.append(attn)

        return allHyps, allScores, allAttn, allGold
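
Both helpers hinge on the beam-major flattening set up by `rvar`. A minimal, self-contained sketch (plain PyTorch; the toy sizes and tag values are ours, only the `bottle`/`unbottle` names are reused from the code above) showing that the layout round-trips losslessly:

    import torch

    beam_size, batch_size, hidden = 3, 2, 4

    def bottle(m):
        return m.view(batch_size * beam_size, -1)

    def unbottle(m):
        return m.view(beam_size, batch_size, -1)

    # Tag each (beam, example) slot so the flat layout is visible.
    x = torch.zeros(beam_size, batch_size, hidden)
    for k in range(beam_size):
        for b in range(batch_size):
            x[k, b] = 10 * k + b

    flat = bottle(x)                           # (beam*batch, hidden)
    assert flat[1 * batch_size + 0, 0] == 10   # beam 1, example 0
    assert torch.equal(unbottle(flat), x)      # round trip is exact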
Example #2
    def translateBatch(self, batch):
        beamSize = self.opt.beam_size
        batchSize = batch.batchSize

        #  (1) run the encoder on the src
        useMasking = (self._type == "text")
        decStatesL = []
        contextL = []
        globalScorer = onmt.GNMTGlobalScorer(self.opt.alpha, self.opt.beta)
        beam = [onmt.Beam(beamSize, self.opt.cuda, globalScorer,
                          alpha=self.opt.alpha, beta=self.opt.beta,
                          tgtDict=self.tgt_dict)
                for _ in range(batchSize)]
        for model in self.models:
            encStates, context = model.encoder(batch.src, lengths=batch.lengths)
            encStates = model.init_decoder_state(context, encStates)

            decoder = model.decoder
            attentionLayer = decoder.attn

            #  This mask is applied to the attention model inside the decoder
            #  so that the attention ignores source padding
            padMask = batch.words().data.eq(onmt.Constants.PAD).t()
            attentionLayer.applyMask(padMask)
            #  (2) if a target is specified, compute the 'goldScore'
            #  (i.e. log likelihood) of the target under the model
            
            # --- sanity check (disabled) ---
            #if batch.tgt is not None:
            #    decStates = encStates
            #    mask(padMask.unsqueeze(0))
            #    decOut, decStates, attn = self.model.decoder(batch.tgt[:-1],
            #                                                 batch.src,
            #                                                 context,
            #                                                 decStates)
            #    for dec_t, tgt_t in zip(decOut, batch.tgt[1:].data):
            #        gen_t = self.model.generator.forward(dec_t)
            #        tgt_t = tgt_t.unsqueeze(1)
            #        scores = gen_t.data.gather(1, tgt_t)
            #        scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
            #        goldScores += scores
            # --- end sanity check ---

            #  (3) run the decoder to generate sentences, using beam search
            # Each hypothesis in the beam uses the same context
            # and initial decoder state
            context = Variable(context.data.repeat(1, beamSize, 1))
            contextL.append(context.clone())
            goldScores = context.data.new(batchSize).zero_()
            # goldScores stays zero: the gold-score pass above is disabled.
            decStates = encStates
            decStates.repeatBeam_(beamSize)
            decStatesL.append(decStates)
        batch_src = Variable(batch.src.data.repeat(1, beamSize, 1))
        padMask = batch.src.data[:, :, 0].eq(onmt.Constants.PAD).t() \
                                   .unsqueeze(0) \
                                   .repeat(beamSize, 1, 1)
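        # padMask: (beam, batch, src_len), assuming src is laid out as
        # (src_len, batch, n_feats); it is re-applied every decoder step.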

        #  (3b) The main loop
        beam_done = []
        for i in range(self.opt.max_sent_length):
            # (a) Run RNN decoder forward one step.

            inp = torch.stack([b.getCurrentState() for b in beam])\
                       .t().contiguous().view(1, -1)
            inp = Variable(inp, volatile=True)
            decOutTmp = []
            attnTmp = []
            word_scores = []
            for idx, model in enumerate(self.models):
                model.decoder.attn.applyMask(padMask)
                decOut, decStatesTmp, attn = model.decoder(
                    inp, batch_src, contextL[idx], decStatesL[idx])
                decStatesL[idx] = decStatesTmp
                decOutTmp.append(decOut)
                attnTmp.append(attn)
                decOut = decOut.squeeze(0)
                # decOut: (beam*batch) x rnn_size
                attn["std"] = attn["std"].view(beamSize, batchSize, -1) \
                                     .transpose(0, 1).contiguous()

                # (b) Compute a vector of batch*beam word scores.
                # Copy attention is disabled here (the original test was
                # `if not self.copy_attn`), so only the softmax path runs.
                if True:
                    out = model.generator[0].forward(decOut)
                    out = nn.Softmax()(out)
                else:
                    # Copy Attention Case
                    words = batch.words().t()
                    words = torch.stack([words[i] for i, b in enumerate(beam)])\
                                 .contiguous()
                    attn_copy = attn["copy"].view(beamSize, batchSize, -1) \
                                            .transpose(0, 1).contiguous()

                    out, c_attn_t \
                        = model.generator.forward(
                            decOut, attn_copy.view(-1, batch_src.size(0)))

                    for b in range(out.size(0)):
                        for c in range(c_attn_t.size(1)):
                            v = self.align[words[0, c].data[0]]
                            if v != onmt.Constants.PAD:
                                out[b, v] += c_attn_t[b, c]
                    out = out.log()

                word_scores.append(out.clone())

            # Arithmetic mean of the per-model softmax distributions,
            # then log, gives the ensemble word scores.
            word_score = torch.stack(word_scores).sum(0).squeeze(0) \
                              .div_(len(self.models))
            mean_score = word_score.view(beamSize, batchSize, -1) \
                                   .transpose(0, 1).contiguous()
            # scores: batch x beam x numWords
            scores = torch.log(mean_score)

            # (c) Advance each beam.
            for b in range(batchSize):
                if b in beam_done:
                    continue
                beam[b].advance(scores.data[b], attn["std"].data[b])
                for dec in decStatesL:
                    dec.beamUpdate_(b, beam[b].getCurrentOrigin(), beamSize)
                if beam[b].done():
                    beam_done.append(b)
            if len(beam_done) == batchSize:
                break

        #  (4) package everything up
        allHyp, allScores, allAttn = [], [], []
        n_best = self.opt.n_best

        for b in range(batchSize):
            scores, ks = beam[b].sortFinished()

            allScores += [scores[:n_best]]
            hyps, attn = [], []
            for times, k in ks[:n_best]:
                hyp, att = beam[b].getHyp(times, k)
                hyps.append(hyp)
                attn.append(att)
            allHyp += [hyps]
            if useMasking:
                valid_attn = batch.src.data[:, b, 0].ne(onmt.Constants.PAD) \
                                                .nonzero().squeeze(1)
                attn = [a.index_select(1, valid_attn) for a in attn]
            allAttn += [attn]

            # For debugging visualization.
            if self.beam_accum:
                self.beam_accum["beam_parent_ids"].append(
                    [t.tolist()
                     for t in beam[b].prevKs])
                self.beam_accum["scores"].append([
                    ["%4f" % s for s in t.tolist()]
                    for t in beam[b].allScores][1:])
                self.beam_accum["predicted_ids"].append(
                    [[idx for idx in t.tolist()]
                     for t in beam[b].nextYs][1:])
                self.beam_accum["predicted_labels"].append(
                    [[self.tgt_dict.getLabel(idx)
                      for idx in t.tolist()]
                     for t in beam[b].nextYs][1:])
                beam[b].finished.sort(key=lambda a: -a[0])
                self.beam_accum['finished'].append(beam[b].finished)

        return allHyp, allScores, allAttn, goldScores
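
The ensemble combination above is an arithmetic mean of per-model softmax distributions followed by a log. A minimal sketch of just that reduction in current PyTorch (the random logits are stand-ins for each model's generator output; sizes are arbitrary):

    import torch
    import torch.nn.functional as F

    n_models, beam_times_batch, vocab = 2, 6, 10

    # Stand-ins for each model's pre-softmax generator output.
    logits = [torch.randn(beam_times_batch, vocab) for _ in range(n_models)]

    # Per-model probabilities, as in `out = nn.Softmax()(out)`.
    probs = [F.softmax(l, dim=-1) for l in logits]

    # Mean over models, then log -- the same quantity as
    # torch.log(torch.stack(word_scores).sum(0) / len(models)).
    scores = torch.log(torch.stack(probs).mean(0))

    assert scores.shape == (beam_times_batch, vocab)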