def translateBatch(self, batch, dataset):
    beam_size = self.opt.beam_size
    batch_size = batch.batch_size

    # (1) Run the encoder on the src.
    _, src_lengths = batch.src
    src = onmt.IO.make_features(batch, 'src')
    encStates, context = self.model.encoder(src, src_lengths)
    decStates = self.model.decoder.init_decoder_state(
        src, context, encStates)

    # (1b) Initialize for the decoder.
    def var(a):
        return Variable(a, volatile=True)

    def rvar(a):
        return var(a.repeat(1, beam_size, 1))

    # Repeat everything beam_size times.
    context = rvar(context.data)
    src = rvar(src.data)
    srcMap = rvar(batch.src_map.data)
    decStates.repeat_beam_size_times(beam_size)
    scorer = onmt.GNMTGlobalScorer(self.alpha, self.beta)
    beam = [onmt.Beam(beam_size, n_best=self.opt.n_best,
                      cuda=self.opt.cuda,
                      vocab=self.fields["tgt"].vocab,
                      global_scorer=scorer)
            for __ in range(batch_size)]

    # (2) Run the decoder to generate sentences, using beam search.
    def bottle(m):
        return m.view(batch_size * beam_size, -1)

    def unbottle(m):
        return m.view(beam_size, batch_size, -1)

    for i in range(self.opt.max_sent_length):
        if all(b.done() for b in beam):
            break

        # Construct batch x beam_size next words.
        # Get all the pending current beam words and arrange for forward.
        inp = var(torch.stack([b.getCurrentState() for b in beam])
                  .t().contiguous().view(1, -1))

        # Turn any copied words into UNKs (index 0 is unk).
        if self.copy_attn:
            inp = inp.masked_fill(
                inp.gt(len(self.fields["tgt"].vocab) - 1), 0)

        # Temporary kludge to handle the changed dim expectation
        # in the decoder.
        inp = inp.unsqueeze(2)

        # Run one step.
        decOut, decStates, attn = \
            self.model.decoder(inp, context, decStates)
        decOut = decOut.squeeze(0)  # decOut: beam x rnn_size

        # (b) Compute a vector of batch*beam word scores.
        if not self.copy_attn:
            out = self.model.generator.forward(decOut).data
            out = unbottle(out)
            # beam x tgt_vocab
        else:
            out = self.model.generator.forward(decOut,
                                               attn["copy"].squeeze(0),
                                               srcMap)
            # beam x (tgt_vocab + extra_vocab)
            out = dataset.collapse_copy_scores(
                unbottle(out.data), batch, self.fields["tgt"].vocab)
            # beam x tgt_vocab
            out = out.log()

        # (c) Advance each beam.
        for j, b in enumerate(beam):
            b.advance(out[:, j], unbottle(attn["std"]).data[:, j])
            decStates.beam_update(j, b.getCurrentOrigin(), beam_size)

    if "tgt" in batch.__dict__:
        allGold = self._runTarget(batch, dataset)
    else:
        allGold = [0] * batch_size

    # (3) Package everything up.
    allHyps, allScores, allAttn = [], [], []
    for b in beam:
        n_best = self.opt.n_best
        scores, ks = b.sortFinished(minimum=n_best)
        hyps, attn = [], []
        for i, (times, k) in enumerate(ks[:n_best]):
            hyp, att = b.getHyp(times, k)
            hyps.append(hyp)
            attn.append(att)
        allHyps.append(hyps)
        allScores.append(scores)
        allAttn.append(attn)

    return allHyps, allScores, allAttn, allGold
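
# Illustration added for readers (not part of the original translator): the
# bottle()/unbottle() helpers above assume the generator output stacks the
# scores of every beam of every sentence as (beam_size * batch_size) rows,
# beams-major.  The small, self-contained sketch below uses plain torch to
# show the reshape unbottle() performs and why out[:, j] then holds the
# beam_size score rows belonging to sentence j; all sizes are made up.
def _beam_layout_demo(beam_size=3, batch_size=2, vocab_size=5):
    import torch
    # Flat scores as produced by the generator: (beam * batch) x vocab.
    flat = torch.arange(0, beam_size * batch_size * vocab_size) \
                .view(beam_size * batch_size, vocab_size)
    # unbottle: (beam * batch) x vocab  ->  beam x batch x vocab.
    unbottled = flat.view(beam_size, batch_size, -1)
    # Column j now holds the beam_size rows for sentence j of the batch,
    # which is exactly what beam[j].advance(out[:, j], ...) consumes.
    per_sentence = unbottled[:, 0]  # beam_size x vocab for sentence 0
    # bottle() is the inverse reshape back to (beam * batch) x vocab.
    rebottled = unbottled.view(beam_size * batch_size, -1)
    assert torch.equal(flat, rebottled)
    return per_sentence
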
def translateBatch(self, batch):
    beamSize = self.opt.beam_size
    batchSize = batch.batchSize

    # (1) run the encoder on the src
    useMasking = (self._type == "text")
    encStatesL = []
    decStatesL = []
    contextL = []
    src_lengths = batch.lengths.data.view(-1).tolist()
    globalScorer = onmt.GNMTGlobalScorer(self.opt.alpha, self.opt.beta)
    beam = [onmt.Beam(beamSize, self.opt.cuda, globalScorer,
                      alpha=self.opt.alpha, beta=self.opt.beta,
                      tgtDict=self.tgt_dict)
            for i in range(batchSize)]

    for model in self.models:
        encStates, context = model.encoder(batch.src, lengths=batch.lengths)
        encStates = model.init_decoder_state(context, encStates)

        decoder = model.decoder
        attentionLayer = decoder.attn

        # This mask is applied to the attention model inside the decoder
        # so that the attention ignores source padding.
        padMask = batch.words().data.eq(onmt.Constants.PAD).t()
        attentionLayer.applyMask(padMask)

        # (2) if a target is specified, compute the 'goldScore'
        # (i.e. log likelihood) of the target under the model.
        # Kept here as a commented-out sanity check:
        # if batch.tgt is not None:
        #     decStates = encStates
        #     mask(padMask.unsqueeze(0))
        #     decOut, decStates, attn = self.model.decoder(batch.tgt[:-1],
        #                                                  batch.src,
        #                                                  context,
        #                                                  decStates)
        #     for dec_t, tgt_t in zip(decOut, batch.tgt[1:].data):
        #         gen_t = self.model.generator.forward(dec_t)
        #         tgt_t = tgt_t.unsqueeze(1)
        #         scores = gen_t.data.gather(1, tgt_t)
        #         scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0)
        #         goldScores += scores

        # (3) run the decoder to generate sentences, using beam search.
        # Each hypothesis in the beam uses the same context
        # and initial decoder state.
        context = Variable(context.data.repeat(1, beamSize, 1))
        contextL.append(context.clone())
        goldScores = context.data.new(batchSize).zero_()
        decStates = encStates
        decStates.repeatBeam_(beamSize)
        decStatesL.append(decStates)

    batch_src = Variable(batch.src.data.repeat(1, beamSize, 1))
    padMask = batch.src.data[:, :, 0].eq(onmt.Constants.PAD).t() \
                                     .unsqueeze(0) \
                                     .repeat(beamSize, 1, 1)

    # (3b) The main loop
    beam_done = []
    for i in range(self.opt.max_sent_length):
        # (a) Run RNN decoder forward one step.
        input = torch.stack([b.getCurrentState() for b in beam]) \
                     .t().contiguous().view(1, -1)
        input = Variable(input, volatile=True)
        decOutTmp = []
        attnTmp = []
        word_scores = []
        for idx in range(len(self.models)):
            model = self.models[idx]
            model.decoder.attn.applyMask(padMask)
            decOut, decStatesTmp, attn = model.decoder(input,
                                                       batch_src,
                                                       contextL[idx],
                                                       decStatesL[idx])
            decStatesL[idx] = decStatesTmp
            decOutTmp.append(decOut)
            attnTmp.append(attn)
            decOut = decOut.squeeze(0)
            # decOut: (beam*batch) x numWords
            attn["std"] = attn["std"].view(beamSize, batchSize, -1) \
                                     .transpose(0, 1).contiguous()

            # (b) Compute a vector of batch*beam word scores.
            # if not self.copy_attn:
            if True:
                out = model.generator[0].forward(decOut)
                out = nn.Softmax()(out)
            else:
                # Copy Attention Case
                words = batch.words().t()
                words = torch.stack([words[i] for i, b in enumerate(beam)]) \
                             .contiguous()
                attn_copy = attn["copy"].view(beamSize, batchSize, -1) \
                                        .transpose(0, 1).contiguous()
                out, c_attn_t = self.model.generator.forward(
                    decOut, attn_copy.view(-1, batch_src.size(0)))

                for b in range(out.size(0)):
                    for c in range(c_attn_t.size(1)):
                        v = self.align[words[0, c].data[0]]
                        if v != onmt.Constants.PAD:
                            out[b, v] += c_attn_t[b, c]
                out = out.log()

            # out: (beam*batch) x numWords
            word_scores.append(out.clone())

        # Ensemble: average the per-model distributions, then take the log.
        word_score = torch.stack(word_scores).sum(0).squeeze(0) \
                          .div_(len(self.models))
        mean_score = word_score.view(beamSize, batchSize, -1) \
                               .transpose(0, 1).contiguous()
        scores = torch.log(mean_score)

        # (c) Advance each beam.
        active = []
        for b in range(batchSize):
            if b in beam_done:
                continue
            beam[b].advance(scores.data[b], attn["std"].data[b])
            is_done = beam[b].done()
            if not is_done:
                active += [b]
            for dec in decStatesL:
                dec.beamUpdate_(b, beam[b].getCurrentOrigin(), beamSize)
            if is_done:
                beam_done.append(b)

        if len(beam_done) == batchSize:
            break

    # (4) package everything up
    allHyp, allScores, allAttn = [], [], []
    n_best = self.opt.n_best

    for b in range(batchSize):
        scores, ks = beam[b].sortFinished()
        allScores += [scores[:n_best]]
        hyps, attn = [], []
        for i, (times, k) in enumerate(ks[:n_best]):
            hyp, att = beam[b].getHyp(times, k)
            hyps.append(hyp)
            attn.append(att)
        allHyp += [hyps]
        if useMasking:
            valid_attn = batch.src.data[:, b, 0].ne(onmt.Constants.PAD) \
                                                .nonzero().squeeze(1)
            attn = [a.index_select(1, valid_attn) for a in attn]
        allAttn += [attn]

        # For debugging visualization.
        if self.beam_accum:
            self.beam_accum["beam_parent_ids"].append(
                [t.tolist() for t in beam[b].prevKs])
            self.beam_accum["scores"].append([
                ["%4f" % s for s in t.tolist()]
                for t in beam[b].allScores][1:])
            self.beam_accum["predicted_ids"].append(
                [[idx for idx in t.tolist()]
                 for t in beam[b].nextYs][1:])
            self.beam_accum["predicted_labels"].append(
                [[self.tgt_dict.getLabel(idx) for idx in t.tolist()]
                 for t in beam[b].nextYs][1:])
            beam[b].finished.sort(key=lambda a: -a[0])
            self.beam_accum['finished'].append(beam[b].finished)

    return allHyp, allScores, allAttn, goldScores
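
# Illustration added for readers (not part of the original translator): the
# ensemble step above averages the per-model softmax distributions and only
# then takes the log, i.e. it scores a word as log(mean_k p_k(word)) rather
# than averaging log-probabilities.  The self-contained torch sketch below
# reproduces that arithmetic on made-up numbers; the stacked shape
# (models x rows x vocab) mirrors torch.stack(word_scores) in the main loop.
def _ensemble_average_demo():
    import torch
    # Two "models", two rows of already-softmaxed scores over a 3-word vocab.
    p1 = torch.Tensor([[0.7, 0.2, 0.1], [0.3, 0.4, 0.3]])
    p2 = torch.Tensor([[0.5, 0.3, 0.2], [0.2, 0.5, 0.3]])
    word_scores = [p1, p2]
    # Same arithmetic as the decoder loop: sum the distributions, divide by
    # the number of models, then log the averaged probabilities.
    mean_p = torch.stack(word_scores).sum(0).squeeze(0).div_(len(word_scores))
    # The log of the ensemble-averaged distribution is what the beams consume.
    return torch.log(mean_p)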