def translate_batch(self, batch):
    torch.set_grad_enabled(False)

    # Batch size is in different location depending on data.
    beam_size = self.opt.beam_size
    batch_size = batch.size

    gold_scores = batch.get('source').data.new(batch_size).float().zero_()
    gold_words = 0
    allgold_scores = []

    if batch.has_target:
        # Use the first model to decode
        model_ = self.models[0]
        gold_words, gold_scores, allgold_scores = model_.decode(batch)

    # (3) Start decoding
    # time x batch * beam

    # initialize the beam
    beam = [onmt.Beam(beam_size, self.bos_id, self.opt.cuda, self.opt.sampling)
            for k in range(batch_size)]

    batch_idx = list(range(batch_size))
    remaining_sents = batch_size

    decoder_states = dict()
    for i in range(self.n_models):
        decoder_states[i] = self.models[i].create_decoder_state(batch, beam_size)

    if self.opt.lm:
        lm_decoder_states = self.lm_model.create_decoder_state(batch, beam_size)

    for i in range(self.opt.max_sent_length):
        # Prepare decoder input.
        # input size: 1 x ( batch * beam )
        input = torch.stack([b.getCurrentState() for b in beam
                             if not b.done]).t().contiguous().view(1, -1)
        decoder_input = input

        # require batch first for everything
        outs = dict()
        attns = dict()

        for k in range(self.n_models):
            # decoder_hidden, coverage = self.models[k].decoder.step(decoder_input.clone(), decoder_states[k])
            # run decoding on the model
            decoder_output = self.models[k].step(decoder_input.clone(),
                                                 decoder_states[k])

            # extract the required tensors from the output (a dictionary)
            outs[k] = decoder_output['log_prob']
            attns[k] = decoder_output['coverage']

        # for ensembling models
        out = self._combine_outputs(outs)
        attn = self._combine_attention(attns)

        # for lm fusion
        if self.opt.lm:
            lm_decoder_output = self.lm_model.step(decoder_input.clone(),
                                                   lm_decoder_states)
            # fusion
            lm_out = lm_decoder_output['log_prob']
            # out = out + 0.3 * lm_out
            out = lm_out

        word_lk = out.view(beam_size, remaining_sents, -1) \
                     .transpose(0, 1).contiguous()
        attn = attn.view(beam_size, remaining_sents, -1) \
                   .transpose(0, 1).contiguous()

        active = []

        for b in range(batch_size):
            if beam[b].done:
                continue

            idx = batch_idx[b]
            if not beam[b].advance(word_lk.data[idx], attn.data[idx]):
                active += [b]

            for j in range(self.n_models):
                decoder_states[j].update_beam(beam, b, remaining_sents, idx)

            if self.opt.lm:
                lm_decoder_states.update_beam(beam, b, remaining_sents, idx)

        if not active:
            break

        # in this section, the sentences that are still active are
        # compacted so that the decoder is not run on completed sentences
        active_idx = self.tt.LongTensor([batch_idx[k] for k in active])
        batch_idx = {beam: idx for idx, beam in enumerate(active)}

        for j in range(self.n_models):
            decoder_states[j].prune_complete_beam(active_idx, remaining_sents)

        if self.opt.lm:
            lm_decoder_states.prune_complete_beam(active_idx, remaining_sents)

        remaining_sents = len(active)

    # (4) package everything up
    all_hyp, all_scores, all_attn = [], [], []
    n_best = self.opt.n_best
    all_lengths = []

    for b in range(batch_size):
        scores, ks = beam[b].sortBest()

        all_scores += [scores[:n_best]]
        hyps, attn, length = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
        all_hyp += [hyps]
        all_lengths += [length]

        # if(src_data.data.dim() == 3):
        if self.opt.encoder_type == 'audio':
            valid_attn = decoder_states[0].original_src.narrow(2, 0, 1).squeeze(2)[:, b].ne(onmt.Constants.PAD) \
                                          .nonzero().squeeze(1)
        else:
            valid_attn = decoder_states[0].original_src[:, b].ne(onmt.Constants.PAD) \
                                          .nonzero().squeeze(1)
        attn = [a.index_select(1, valid_attn) for a in attn]
        all_attn += [attn]

        if self.beam_accum:
            self.beam_accum["beam_parent_ids"].append(
                [t.tolist() for t in beam[b].prevKs])
            self.beam_accum["scores"].append([
                ["%4f" % s for s in t.tolist()]
                for t in beam[b].all_scores][1:])
            self.beam_accum["predicted_ids"].append(
                [[self.tgt_dict.getLabel(id) for id in t.tolist()]
                 for t in beam[b].nextYs][1:])

    torch.set_grad_enabled(True)

    return all_hyp, all_scores, all_attn, all_lengths, gold_scores, gold_words, allgold_scores
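# NOTE: a minimal sketch, not from the original source. The ensemble loop in
# translate_batch above gathers per-model log-probs in `outs` and coverage in
# `attns`, then merges them via self._combine_outputs / self._combine_attention,
# which are not shown in this section. One plausible (assumed) implementation
# averages in log space across models and averages the attention weights
# (torch is assumed imported, as in the surrounding code):

def _combine_outputs(self, outs):
    # outs: dict {model_index: log-prob tensor of shape (batch*beam, vocab)}
    if len(outs) == 1:
        return outs[0]
    stacked = torch.stack([outs[k] for k in sorted(outs)], dim=0)
    return stacked.mean(dim=0)  # mean of log-probs = geometric mean of probs

def _combine_attention(self, attns):
    # attns: dict {model_index: attention tensor of shape (batch*beam, src_len)}
    if len(attns) == 1:
        return attns[0]
    stacked = torch.stack([attns[k] for k in sorted(attns)], dim=0)
    return stacked.mean(dim=0)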
def translateBatch(self, srcBatch, tgtBatch): batchSize = srcBatch[0].size(1) beamSize = self.opt.beam_size knntime = 0.0 # (1) run the encoder on the src encStates, context = self.model.encoder(srcBatch) srcBatch = srcBatch[0] # drop the lengths needed for encoder rnnSize = context.size(2) encStates = (self.model._fix_enc_hidden(encStates[0]), self.model._fix_enc_hidden(encStates[1])) # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding padMask = srcBatch.data.eq(onmt.Constants.PAD).t() def applyContextMask(m): if isinstance(m, onmt.modules.GlobalAttention): m.applyMask(padMask) # (3) run the decoder to generate sentences, using beam search # Expand tensors for each beam. # decStates = encStates if self.opt.use_lm: lm_hidden = self.LangModel.initialize_hidden(1, batchSize) context = Variable(context.data.repeat(1, beamSize, 1)) decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)), Variable(encStates[1].data.repeat(1, beamSize, 1))) beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)] decOut = self.model.make_init_decoder_output(context) padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat( beamSize, 1, 1) batchIdx = list(range(batchSize)) remainingSents = batchSize for i in range(self.opt.max_sent_length): self.model.decoder.apply(applyContextMask) # Prepare decoder input. input_ = torch.stack([ b.getCurrentState() for b in beam if not b.done() ]).t().contiguous().view(1, -1) input_var = Variable(input_, volatile=True) decOut, decStates, attn = self.model.decoder( input_var, decStates, context, decOut) # decOut: 1 x (beam*batch) x numWords decOut = decOut.squeeze(0) out = self.model.generator.forward(decOut) if self.opt.use_lm: lm_output, lm_hidden, _, _ = self.LangModel( input_var, lm_hidden) lm_output = torch.log(lm_output + 1e-12) beg = time.time() scores = self._get_scores(out, self.target_embeddings) diff = time.time() - beg knntime += diff if self.opt.use_lm: scores += 0.2 * lm_output.squeeze(0) wordLk = scores.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous() attn = attn.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous() active = [] for b in range(batchSize): if beam[b].done(): continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn.data[idx]): active += [b] for decState in decStates: # iterate over h, c # layers x beam*sent x dim sentStates = decState.view(-1, beamSize, remainingSents, decState.size(2))[:, :, idx] sentStates.data.copy_( sentStates.data.index_select( 1, beam[b].getCurrentOrigin())) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} def updateActive(t): # select only the remaining active sentences view = t.data.view(-1, remainingSents, rnnSize) newSize = list(t.size()) newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents return Variable(view.index_select(1, activeIdx) \ .view(*newSize), volatile=True) decStates = (updateActive(decStates[0]), updateActive(decStates[1])) decOut = updateActive(decOut) context = updateActive(context) padMask = padMask.index_select(1, activeIdx) remainingSents = len(active) # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.opt.n_best for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] valid_attn = 
srcBatch.data[:, b].ne( onmt.Constants.PAD).nonzero().squeeze(1) hyps, attn = zip( *[beam[b].getHyp(times, k) for (times, k) in ks[:n_best]]) attn = [a.index_select(1, valid_attn) for a in attn] allHyp += [hyps] allAttn += [attn] return allHyp, allScores, allAttn, knntime
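# NOTE: a minimal sketch, not from the original source. The translateBatch
# variant above applies shallow fusion with a hard-coded weight
# (scores += 0.2 * lm_output.squeeze(0)), while the ensemble translate_batch
# further up currently replaces the translation scores with the LM scores
# outright (out = lm_out). A small assumed helper that keeps the weight
# configurable and everything in log space:

def shallow_fusion(tm_log_probs, lm_log_probs, lm_weight=0.2):
    # tm_log_probs, lm_log_probs: tensors of shape (batch*beam, vocab)
    # lm_weight: interpolation weight for the language model (0 disables fusion)
    return tm_log_probs + lm_weight * lm_log_probs

# Hypothetical usage inside the decoding loop:
#   scores = self._get_scores(out, self.target_embeddings)
#   if self.opt.use_lm:
#       scores = shallow_fusion(scores, lm_output.squeeze(0), lm_weight=0.2)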
def translate_batch(self, batch, length_batch=None): torch.set_grad_enabled(False) # Batch size is in different location depending on data. beam_size = self.opt.beam_size batch_size = batch.size gold_scores = batch.get('source').data.new(batch_size).float().zero_() gold_words = 0 allgold_scores = [] prefix = None if batch.has_target: # Use the first model to decode model_ = self.models[0] gold_words, gold_scores, allgold_scores = model_.decode(batch) # batch.tensors['target_output'] = # remove EOS prefix = batch.tensors['target_output'][:-1] print('PREFIX', self.build_target_tokens(batch.tensors['target_output'])) # (3) Start decoding # time x batch * beam # initialize the beam beam = [ onmt.Beam(beam_size, self.opt.cuda, prefix=prefix, prefix_score=allgold_scores) for k in range(batch_size) ] batch_idx = list(range(batch_size)) remaining_sents = batch_size decoder_states = dict() for i in range(self.n_models): decoder_states[i] = self.models[i].create_decoder_state( batch, beam_size, length_batch) if batch.has_target: prefix_states = [] for state in beam[i].get_all_states(): prefix_states.append( torch.stack([state]).t().contiguous().view(1, -1)) for p in prefix_states: decoder_output = self.models[i].step( p.clone(), decoder_states[i]) # print('prefix', p) # can clear prefices from beam beam = [onmt.Beam(beam_size, self.opt.cuda) for k in range(batch_size)] if self.opt.lm: lm_decoder_states = self.lm_model.create_decoder_state( batch, beam_size) max_len = self.opt.max_sent_length if batch.has_target: max_len -= len(prefix) # print(max_len, len(prefix), len(prefix_states)) for current_depth in range(max_len): # EOS here? # Prepare decoder input. # print(current_depth, max_len) # input size: 1 x ( batch * beam ) input = torch.stack([ b.getCurrentState() for b in beam if not b.done ]).t().contiguous().view(1, -1) decoder_input = input # require batch first for everything outs = dict() attns = dict() for k in range(self.n_models): # decoder_hidden, coverage = self.models[k].decoder.step(decoder_input.clone(), decoder_states[k]) # run decoding on the model if not (current_depth == 0 and batch.has_target): # print('decoding ', self.tgt_dict.convertToLabels(decoder_input.data[0], 10)) decoder_output = self.models[k].step( decoder_input.clone(), decoder_states[k], current_depth + (len(prefix) if prefix is not None else 0)) # print('new input', decoder_input) # else: # print('skipped last of prefix') # extract the required tensors from the output (a dictionary) outs[k] = decoder_output['log_prob'] # print('outs when decoding ', outs[k]) attns[k] = decoder_output['coverage'] # for ensembling models out = self._combine_outputs(outs) attn = self._combine_attention(attns) # for lm fusion if self.opt.lm: lm_decoder_output = self.lm_model.step(decoder_input.clone(), lm_decoder_states) # fusion lm_out = lm_decoder_output['log_prob'] # out = out + 0.3 * lm_out out = lm_out word_lk = out.view(beam_size, remaining_sents, -1) \ .transpose(0, 1).contiguous() attn = attn.view(beam_size, remaining_sents, -1) \ .transpose(0, 1).contiguous() active = [] for seq_idx in range(batch_size): if beam[seq_idx].done: continue idx = batch_idx[seq_idx] # Added two conditions for constrained decoding if self.force_target_length and length_batch and length_batch[ seq_idx] == current_depth: # TODO: offset by prefix len # finish hyp b since it has desired length beam[seq_idx].advanceEOS(word_lk.data[idx], attn.data[idx]) elif self.force_target_length and length_batch: # ignore EOS since we are not at the end 
word_lk[idx].select(1, onmt.Constants.EOS).zero_().add_(-1000) if not beam[seq_idx].advance(word_lk.data[idx], attn.data[idx]): active += [seq_idx] elif not beam[seq_idx].advance(word_lk.data[idx], attn.data[idx], start_from_prefix=current_depth == 0): active += [seq_idx] for j in range(self.n_models): decoder_states[j].update_beam(beam, seq_idx, remaining_sents, idx) if self.opt.lm: lm_decoder_states.update_beam(beam, seq_idx, remaining_sents, idx) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences active_idx = self.tt.LongTensor([batch_idx[k] for k in active]) batch_idx = {beam: idx for idx, beam in enumerate(active)} for j in range(self.n_models): decoder_states[j].prune_complete_beam(active_idx, remaining_sents) if self.opt.lm: lm_decoder_states.prune_complete_beam(active_idx, remaining_sents) remaining_sents = len(active) # if commit_depth == 0: # for seq_idx in range(batch_size): # beam[seq_idx].commit(buffer=decoding_buffer_depths) # # elif commit_depth > 0: # raise NotImplementedError # (4) package everything up all_hyp, all_scores, all_attn, all_lk = [], [], [], [] n_best = self.opt.n_best all_lengths = [] for seq_idx in range(batch_size): scores, ks = beam[seq_idx].sortBest() all_scores += [scores[:n_best]] hyps, attn, length = zip(*[ beam[seq_idx].getHyp(k, return_att=False) for k in ks[:n_best] ]) # append given prefix to beginning of output if prefix is not None: prefix_ = [p_[seq_idx] for p_ in prefix.tolist()] hyps = [prefix_ + hyp for hyp in hyps] all_hyp += [hyps] all_lengths += [length] # if(src_data.data.dim() == 3): if self.opt.encoder_type == 'audio': valid_attn = decoder_states[0].original_src.narrow(2, 0, 1).squeeze(2)[:, seq_idx].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) else: valid_attn = decoder_states[0].original_src[:, seq_idx].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) # attn = [a.index_select(1, valid_attn) for a in attn] # all_attn += [attn] if self.beam_accum: self.beam_accum["beam_parent_ids"].append( [t.tolist() for t in beam[seq_idx].prevKs]) self.beam_accum["scores"].append( [["%4f" % s for s in t.tolist()] for t in beam[seq_idx].all_scores][1:]) self.beam_accum["predicted_ids"].append( [[self.tgt_dict.getLabel(id) for id in t.tolist()] for t in beam[seq_idx].nextYs][1:]) all_scores_ = [beam[seq_idx].allScores[-1]] # take last my_indices = range(beam[seq_idx].size) for j in range(len(beam[seq_idx].prevKs) - 1, -1, -1): my_indices = beam[seq_idx].prevKs[j][my_indices] all_scores_.append(beam[seq_idx].allScores[j][my_indices]) # print(all_scores_[-1]) all_lk.append(all_scores_[::-1]) torch.set_grad_enabled(True) return all_hyp, all_scores, all_attn, all_lengths, gold_scores, gold_words, allgold_scores, all_lk
def translateBatch(self, srcBatch, tgtBatch): # Batch size is in different location depending on data. beamSize = self.opt.beam_size # (1) run the encoder on the src encStates, context, emb = self.model.encoder(srcBatch) # Drop the lengths needed for encoder. srcBatch = srcBatch[0] batchSize = self._getBatchSize(srcBatch) rnnSize = context.size(2) decoder = self.model.decoder attentionLayer = decoder.attn if hasattr(decoder, 'attn') else None if isinstance(self.model.encoder, Encoder): if isinstance(encStates, tuple): encStates = tuple(self.model.brnn_merge_concat(encStates[i]) for i in range(len(encStates))) else: encStates = self.model.brnn_merge_concat(encStates) if encStates.size(0) < decoder.layers: encStates = encStates.repeat(decoder.layers, 1, 1) else: encStates = Variable(encStates.data.new(*encStates.size()).zero_(), requires_grad=False) # encStates = encStates.unsqueeze(0).repeat(decoder.layers, 1, 1) useMasking = not isinstance(decoder, SGUDecoder) #self._type.endswith("text") # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding padMask = None if useMasking: padMask = srcBatch.data.eq(onmt.Constants.PAD).t() def mask(padMask): if useMasking: attentionLayer.applyMask(padMask) # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model goldScores = context.data.new(batchSize).zero_() if tgtBatch is not None: decStates = encStates mask(padMask) initOutput = self.model.make_init_decoder_output(context) decOut, decStates, attn = self.model.decoder( tgtBatch[:-1], decStates, context, initOutput) for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data): gen_t = self.model.generator.forward(dec_t) tgt_t = tgt_t.unsqueeze(1) scores = gen_t.data.gather(1, tgt_t) scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) goldScores += scores # (3) run the decoder to generate sentences, using beam search # Expand tensors for each beam. context = Variable(context.data.repeat(1, beamSize, 1)) if isinstance(emb, PackedSequence): emb = Variable(unpack(emb)[0].data.repeat(1, beamSize, 1)) else: emb = Variable(emb.data.repeat(1, beamSize, 1)) if isinstance(encStates, tuple): decStates = tuple(Variable(encStates[i].data.repeat(1, beamSize, 1)) for i in range(len(encStates))) else: decStates = Variable(encStates.data.repeat(1, beamSize, 1)) beam = [onmt.Beam(beamSize, self.opt.cuda) for _ in range(batchSize)] decOut = self.model.make_init_decoder_output(context) if useMasking: padMask = srcBatch.data.eq( onmt.Constants.PAD).t() \ .unsqueeze(0) \ .repeat(beamSize, 1, 1) batchIdx = list(range(batchSize)) remainingSents = batchSize activs = [] for i in range(self.opt.max_sent_length): mask(padMask) # Prepare decoder input. 
input = torch.stack([b.getCurrentState() for b in beam if not b.done]).t().contiguous().view(1, -1) #if self.model.decoder.log: # decOut, decStates, attn, activ = self.model.decoder( # Variable(input, volatile=True), decStates, context, decOut, emb) # activs.append(activ) #else: decOut, decStates, attn = self.model.decoder( Variable(input, volatile=True), decStates, context, decOut) # decOut: 1 x (beam*batch) x numWords decOut = decOut.squeeze(0) out = self.model.generator.forward(decOut) # batch x beam x numWords wordLk = out.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() attn = attn.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() active = [] for b in range(batchSize): if beam[b].done: continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn.data[idx]): active += [b] #print(decStates) if not isinstance(decStates, tuple): decStates = tuple(decStates.unsqueeze(0)) #print(decStates) for decState in decStates: # iterate over h, c # layers x beam*sent x dim sentStates = decState.view(-1, beamSize, remainingSents, decState.size(2))[:, :, idx] sentStates.data.copy_( sentStates.data.index_select( 1, beam[b].getCurrentOrigin())) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} def updateActive(t, lastSize=rnnSize): # select only the remaining active sentences view = t.data.view(-1, remainingSents, lastSize) newSize = list(t.size()) newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents return Variable(view.index_select(1, activeIdx) .view(*newSize), volatile=True) decStates = tuple(updateActive(decStates[i]) for i in range(len(decStates))) if len(decStates) == 1: # The GRU needs only one matrix as hidden state decStates = decStates[0] decOut = updateActive(decOut) context = updateActive(context) emb = updateActive(emb, emb.size(2)) if useMasking: padMask = padMask.index_select(1, activeIdx) remainingSents = len(active) # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.opt.n_best if activs: new_activs = torch.zeros((2, activs[0].size(1), len(activs))) for i, activ in enumerate(activs): new_activs[:, :activ.size(1), i] = activ.data activs = new_activs sys.stderr.write("r=\n") for i in range(activs.size(1)): for j in range(activs.size(2)): sys.stderr.write(str(activs[0][i][j]) + " ") sys.stderr.write("\n") sys.stderr.write("z=\n") for i in range(activs.size(1)): for j in range(activs.size(2)): sys.stderr.write(str(activs[1][i][j]) + " ") sys.stderr.write("\n") for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]]) allHyp += [hyps] if useMasking: valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) attn = [a.index_select(1, valid_attn) for a in attn] allAttn += [attn] if self.beam_accum: self.beam_accum["beam_parent_ids"].append( [t.tolist() for t in beam[b].prevKs]) self.beam_accum["scores"].append([ ["%4f" % s for s in t.tolist()] for t in beam[b].allScores][1:]) self.beam_accum["predicted_ids"].append( [[self.tgt_dict.getLabel(id) for id in t.tolist()] for t in beam[b].nextYs][1:]) return allHyp, allScores, allAttn, goldScores
def translateBatch(self, srcBatch): # Batch size is in different location depending on data. beamSize = self.beam_size # (1) run the encoders on the src states, context = self.model.encoder(srcBatch) # reshape the states encStates = (self.model._fix_enc_hidden(states[0]), self.model._fix_enc_hidden(states[1])) # Drop the lengths needed for encoder. srcBatch = srcBatch[0] batchSize = self._getBatchSize(srcBatch) rnnSize = context.size(2) #~ decoder = self.model.decoder #~ attentionLayer = decoder.attn.current() useMasking = (batchSize > 1) # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding padMask = None if useMasking: padMask = srcBatch.data.eq(onmt.Constants.PAD).t() def mask(padMask): if useMasking: #~ attentionLayer.applyMask(padMask) self.model.decoder.attn.current().applyMask(padMask) # (2) run the decoder to generate sentences, using beam search # Expand tensors for each beam. context = Variable(context.data.repeat(1, beamSize, 1)) decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)), Variable(encStates[1].data.repeat(1, beamSize, 1))) # Initialize the beams # Each beam is an object containing the translation status for each sentence in the batch beam = [onmt.Beam(beamSize, self.cuda) for k in range(batchSize)] # Here we prepare the decoder output (zeroes) # For input feeding decOuts = self.model.make_init_decoder_output(context) if useMasking: padMask = srcBatch.data.eq( onmt.Constants.PAD).t() \ .unsqueeze(0) \ .repeat(beamSize, 1, 1) batchIdx = list(range(batchSize)) remainingSents = batchSize #~ if self.model.copy_pointer: src = Variable(srcBatch.data.repeat(1, beamSize)) # time x batch * beam for i in range(self.max_sent_length): mask(padMask) # Prepare decoder input. 
input = torch.stack([ b.getCurrentState() for b in beam if not b.done ]).t().contiguous().view(1, -1) # compute new decoder output (distribution) decOuts, decStates, attn = self.model.decoder( Variable(input, volatile=True), decStates, context, decOuts) # decOut: 1 x (beam*batch) x numWords decOuts = decOuts.squeeze(0) attn_ = attn attn = attn.squeeze(0) if self.model.copy_pointer: out = self.model.generator.forward(decOuts, attn_, src) else: out = self.model.generator.forward(decOuts) # batch x beam x numWords wordLk = out.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() attn = attn.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() active = [] for b in range(batchSize): if beam[b].done: continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn.data[idx]): active += [b] for decState in decStates: # iterate over h, c # layers x beam*sent x dim sentStates = decState.view(-1, beamSize, remainingSents, decState.size(2))[:, :, idx] sentStates.data.copy_( sentStates.data.index_select( 1, beam[b].getCurrentOrigin())) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} def updateActive(t, size): # select only the remaining active sentences view = t.data.view(-1, remainingSents, size) newSize = list(t.size()) newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents return Variable(view.index_select(1, activeIdx).view(*newSize), volatile=True) decStates = (updateActive(decStates[0], rnnSize), updateActive(decStates[1], rnnSize)) decOuts = updateActive(decOuts, rnnSize) context = updateActive(context, rnnSize) # src size: time x batch * beam src_data = src.data.view(-1, remainingSents) newSize = list(src.size()) newSize[-1] = newSize[-1] * len(activeIdx) // remainingSents src = Variable(src_data.index_select(1, activeIdx).view(*newSize), volatile=True) #~ srcBatch = Variable(srcBatch.data.repeat(1, beamSize)) if useMasking: padMask = padMask.index_select(1, activeIdx) remainingSents = len(active) # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.n_best for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]]) allHyp += [hyps] if useMasking: valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) attn = [a.index_select(1, valid_attn) for a in attn] allAttn += [attn] if useMasking: self.model.decoder.attn.current().applyMask(None) return allHyp, allScores, allAttn
def translateBatch(self, batch): beamSize = self.opt.beam_size batchSize = batch.batchSize # (1) run the encoder on the src encStates, context = self.model.encoder(batch.src) encStates = self.model.init_decoder_state(context, encStates) decoder = self.model.decoder attentionLayer = decoder.attn useMasking = (self._type == "text") # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding padMask = None if useMasking: padMask = batch.words().data.eq(onmt.Constants.PAD).t() def mask(padMask): if useMasking: attentionLayer.applyMask(padMask) # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model goldScores = context.data.new(batchSize).zero_() if batch.tgt is not None: decStates = encStates mask(padMask) decOut, decStates, attn = decoder(batch.tgt[:-1], context, decStates) for dec_t, tgt_t in zip(decOut, batch.tgt[1:].data): gen_t = self.model.generator.forward(dec_t) tgt_t = tgt_t.unsqueeze(1) scores = gen_t.data.gather(1, tgt_t) scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) goldScores += scores # (3) run the decoder to generate sentences, using beam search # Each hypothesis in the beam uses the same context # and initial decoder state context = Variable(context.data.repeat(1, beamSize, 1)) batch_src = Variable(batch.src.data.repeat(1, beamSize, 1)) decStates = encStates decStates.repeatBeam_(beamSize) beam = [onmt.Beam(beamSize, self.opt.cuda) for _ in range(batchSize)] if useMasking: padMask = batch.src.data[:, :, 0].eq( onmt.Constants.PAD).t() \ .unsqueeze(0) \ .repeat(beamSize, 1, 1) # (3b) The main loop for i in range(self.opt.max_sent_length): # (a) Run RNN decoder forward one step. mask(padMask) input = torch.stack([b.getCurrentState() for b in beam])\ .t().contiguous().view(1, -1) input = Variable(input, volatile=True) decOut, decStates, attn = self.model.decoder( input, batch_src, context, decStates) decOut = decOut.squeeze(0) # decOut: (beam*batch) x numWords attn["std"] = attn["std"].view(beamSize, batchSize, -1) \ .transpose(0, 1).contiguous() # (b) Compute a vector of batch*beam word scores. if not self.copy_attn: out = self.model.generator.forward(decOut) else: # Copy Attention Case words = batch.words().t() words = torch.stack([words[i] for i, b in enumerate(beam)])\ .contiguous() attn_copy = attn["copy"].view(beamSize, batchSize, -1) \ .transpose(0, 1).contiguous() out, c_attn_t \ = self.model.generator.forward( decOut, attn_copy.view(-1, batch_src.size(0))) for b in range(out.size(0)): for c in range(c_attn_t.size(1)): v = self.align[words[0, c].data[0]] if v != onmt.Constants.PAD: out[b, v] += c_attn_t[b, c] out = out.log() word_scores = out.view(beamSize, batchSize, -1) \ .transpose(0, 1).contiguous() # batch x beam x numWords # (c) Advance each beam. 
active = [] for b in range(batchSize): is_done = beam[b].advance(word_scores.data[b], attn["std"].data[b]) if not is_done: active += [b] decStates.beamUpdate_(b, beam[b].getCurrentOrigin(), beamSize) if not active: break # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.opt.n_best for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps, attn = [], [] for k in ks[:n_best]: hyp, att = beam[b].getHyp(k) hyps.append(hyp) attn.append(att) allHyp += [hyps] if useMasking: valid_attn = batch.src.data[:, b, 0].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) attn = [a.index_select(1, valid_attn) for a in attn] allAttn += [attn] # For debugging visualization. if self.beam_accum: self.beam_accum["beam_parent_ids"].append( [t.tolist() for t in beam[b].prevKs]) self.beam_accum["scores"].append( [["%4f" % s for s in t.tolist()] for t in beam[b].allScores][1:]) self.beam_accum["predicted_ids"].append( [[self.tgt_dict.getLabel(id) for id in t.tolist()] for t in beam[b].nextYs][1:]) return allHyp, allScores, allAttn, goldScores
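# NOTE: a minimal sketch, not from the original source. Callers of these
# translateBatch variants typically convert the id sequences in allHyp back to
# token strings with the target dictionary. Assuming the old OpenNMT-py
# onmt.Dict interface (convertToLabels(ids, stop), getLabel(id)) used elsewhere
# in this section, and that onmt.Constants.EOS_WORD holds the EOS string, a
# helper could look like:

def build_target_tokens(tgt_dict, hyp):
    # hyp: list of target word ids for one hypothesis (BOS stripped by the beam)
    tokens = tgt_dict.convertToLabels(hyp, onmt.Constants.EOS)
    if tokens and tokens[-1] == onmt.Constants.EOS_WORD:
        tokens = tokens[:-1]  # drop the trailing end-of-sentence marker
    return tokens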
def translate_batch_external(batch, beamSize, model, cuda, rb_init_token, rb_init_tgt, max_sent_length, n_best): srcBatch, tgtBatch, src_rb, tgt_rb = batch batchSize = srcBatch.size(0) # (1) run the encoder on the src # padding is dealt with by variable-length cudnn.RNN encStates, context = model.encoder(srcBatch, src_rb) # # have to execute the encoder manually to deal with padding # encStates = None # context = [] # for srcBatch_t in srcBatch.chunk(srcBatch.size(1), dim=1): # encStates, context_t = self.model.encoder(srcBatch_t, hidden=encStates) # batchPadIdx = srcBatch_t.data.squeeze(1).eq(onmt.Constants.PAD).nonzero() # if batchPadIdx.nelement() > 0: # batchPadIdx = batchPadIdx.squeeze(1) # encStates[0].data.index_fill_(1, batchPadIdx, 0) # encStates[1].data.index_fill_(1, batchPadIdx, 0) # context += [context_t] # context = torch.cat(context) rnnSize = context.size(2) encStates = (_fix_enc_hidden(encStates[0], model.encoder.num_directions), _fix_enc_hidden(encStates[1], model.encoder.num_directions)) # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding padMask = srcBatch.data.eq(onmt.Constants.PAD) rb_token_mask = torch.zeros(padMask.size(0), 1).byte() if cuda: rb_token_mask = rb_token_mask.cuda() if rb_init_token: padMask = torch.cat([rb_token_mask, padMask], 1) def applyContextMask(m): if isinstance(m, onmt.modules.GlobalAttention): m.applyMask(padMask) # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model goldScores = context.data.new(batchSize).zero_() re_padMask = 1 - padMask re_padMask = re_padMask.float() context_t = context.transpose(0, 1).data masked_context = context_t * re_padMask.unsqueeze(2).expand( re_padMask.size(0), re_padMask.size(1), context_t.size(2)) sent_len = torch.sum(re_padMask, 1).squeeze(1) representation = torch.div( torch.sum(masked_context, 1).squeeze(1), sent_len.unsqueeze(1).expand(sent_len.size(0), context.size(2))) if tgtBatch is not None: if rb_init_tgt: new_tgt_batch = tgtBatch[:, 1:] tgt_rb_token = tgt_rb.unsqueeze(1) + model.decoder.dict_size tgtBatch = torch.cat([tgt_rb_token, new_tgt_batch], 1) decStates = encStates decOut = model.make_init_decoder_output(context) model.decoder.apply(applyContextMask) initOutput = model.make_init_decoder_output(context) decOut, decStates, attn = model.decoder(tgtBatch[:, :-1], tgt_rb, decStates, context, initOutput) for dec_t, tgt_t in zip(decOut.transpose(0, 1), tgtBatch.transpose(0, 1)[1:].data): gen_t = model.generator.forward(dec_t) tgt_t = tgt_t.unsqueeze(1) scores = gen_t.data.gather(1, tgt_t) scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) goldScores += scores # (3) run the decoder to generate sentences, using beam search # Expand tensors for each beam. 
context = Variable(context.data.repeat(1, beamSize, 1)) decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)), Variable(encStates[1].data.repeat(1, beamSize, 1))) if rb_init_tgt: beam = [ onmt.Beam(beamSize, cuda, tgt_rb[k].data[0]) for k in range(batchSize) ] else: beam = [onmt.Beam(beamSize, cuda) for k in range(batchSize)] decOut = model.make_init_decoder_output(context) padMask = srcBatch.data.eq(onmt.Constants.PAD).unsqueeze(0).repeat( beamSize, 1, 1) rb_token_mask = torch.zeros(padMask.size(0), padMask.size(1), 1).byte() if cuda: rb_token_mask = rb_token_mask.cuda() if rb_init_token: padMask = torch.cat([rb_token_mask, padMask], 2) batchIdx = list(range(batchSize)) remainingSents = batchSize for i in range(max_sent_length): model.decoder.apply(applyContextMask) # Prepare decoder input. input = torch.stack([b.getCurrentState() for b in beam if not b.done]).t().contiguous().view(1, -1) new_tgt_rb = torch.stack([ tgt_rb[i].expand(beamSize) for i, b in enumerate(beam) if not b.done ]).contiguous().view(-1) '''some_done = False data = [] for i, b in enumerate(beam): if b.done: some_done = True else: data.append(tgt_rb.data[i]) if some_done: print data print new_tgt_rb''' decOut, decStates, attn = model.decoder( Variable(input).transpose(0, 1), new_tgt_rb, decStates, context, decOut) # decOut: 1 x (beam*batch) x numWords decOut = decOut.transpose(0, 1).squeeze(0) out = model.generator.forward(decOut) # batch x beam x numWords wordLk = out.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous() attn = attn.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous() active = [] for b in range(batchSize): if beam[b].done: continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn.data[idx]): active += [b] for decState in decStates: # iterate over h, c # layers x beam*sent x dim sentStates = decState.view(-1, beamSize, remainingSents, decState.size(2))[:, :, idx] sentStates.data.copy_( sentStates.data.index_select(1, beam[b].getCurrentOrigin())) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences tt = torch.cuda if cuda else torch activeIdx = tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} def updateActive(t): # select only the remaining active sentences view = t.data.view(-1, remainingSents, rnnSize) newSize = list(t.size()) newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents return Variable(view.index_select(1, activeIdx) \ .view(*newSize)) decStates = (updateActive(decStates[0]), updateActive(decStates[1])) decOut = updateActive(decOut) context = updateActive(context) padMask = padMask.index_select(1, activeIdx) remainingSents = len(active) # (4) package everything up allHyp, allScores, allAttn = [], [], [] for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] valid_attn = srcBatch.transpose(0, 1).data[:, b].ne( onmt.Constants.PAD).nonzero().squeeze(1) hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]]) attn = [a.index_select(1, valid_attn) for a in attn] allHyp += [hyps] allAttn += [attn] padMask = None model.decoder.apply(applyContextMask) return allHyp, allScores, allAttn, goldScores, representation
def translateBatch(self, batch):
    beamSize = 15
    batchSize = batch.batchSize

    # (1) run the encoder on the src
    encStates, context, fertility_vals = self.encoder(batch.src)
    encStates = self.init_decoder_state(context, encStates)

    def mask(padMask):
        self.decoder.attn.applyMask(padMask)

    # (2) if a target is specified, compute the 'goldScore'
    # (i.e. log likelihood) of the target under the model
    goldScores = context.data.new(batchSize).zero_()

    # (3) run the decoder to generate sentences, using beam search
    # Each hypothesis in the beam uses the same context and initial decoder state
    context = Variable(context.data.repeat(1, beamSize, 1))
    batch_src = Variable(batch.src.data.repeat(1, beamSize, 1))
    decStates = encStates
    decStates.repeatBeam_(beamSize)
    beam = [onmt.Beam(beamSize, True) for _ in range(batchSize)]
    padMask = batch.src.data[:, :, 0].eq(
        onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1)

    # (3b) The main loop
    upper_bounds = None
    for i in range(100):
        # (a) Run RNN decoder forward one step.
        mask(padMask)
        input = torch.stack([b.getCurrentState()
                             for b in beam]).t().contiguous().view(1, -1)
        input = Variable(input, volatile=True)
        decOut, decStates, attn, upper_bounds = self.decoder(
            input, batch_src, context, decStates,
            upper_bounds=decStates.attn_upper_bounds, test=True)
        # import pdb; pdb.set_trace()
        decOut = decOut.squeeze(0)
        # decOut: (beam*batch) x numWords
        attn["std"] = attn["std"].view(beamSize, batchSize, -1).transpose(0, 1).contiguous()

        # (b) Compute a vector of batch*beam word scores.
        out = self.generator.forward(decOut)
        word_scores = out.view(beamSize, batchSize, -1).transpose(0, 1).contiguous()
        # batch x beam x numWords

        # (c) Advance each beam.
        active = []
        for b in range(batchSize):
            is_done = beam[b].advance(word_scores.data[b], attn["std"].data[b])
            if not is_done:
                active += [b]
            decStates.beamUpdate_(b, beam[b].getCurrentOrigin(), beamSize)
        if not active:
            break

    # (4) package everything up
    allHyp, allScores, allAttn = [], [], []
    self.n_best = 1

    for b in range(batchSize):
        scores, ks = beam[b].sortBest()
        allScores += [scores[:self.n_best]]
        hyps, attn = [], []
        for k in ks[:self.n_best]:
            hyp, att = beam[b].getHyp(k)
            hyps.append(hyp)
            attn.append(att)
        allHyp += [hyps]
        valid_attn = batch.src.data[:, b, 0].ne(
            onmt.Constants.PAD).nonzero().squeeze(1)
        attn = [a.index_select(1, valid_attn) for a in attn]
        allAttn += [attn]

    self.decoder.attn.applyMaskNone()
    # print allAttn[0][0].sum(0)
    return allHyp, allScores, allAttn, goldScores
def translateBatch(self, batch): beamSize = self.opt.beam_size batchSize = batch.batchSize # (1) run the encoder on the src encStates, context = self.model.encoder(batch.src) rnnSize = context.size(2) encStates = self.model.setup_decoder(encStates) decoder = self.model.decoder attentionLayer = decoder.attn useMasking = (self._type == "text") # This mask is applied to the attention model inside the decoder # so that the attention ignores source (padding padMask = None if useMasking: padMask = batch.words().data.eq(onmt.Constants.PAD).t() def mask(padMask): if useMasking: attentionLayer.applyMask(padMask) # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model goldScores = context.data.new(batchSize).zero_() if batch.tgt is not None: decStates = encStates decOut = self.model.make_init_decoder_output(context) mask(padMask) initOutput = self.model.make_init_decoder_output(context) decOut, decStates, attn = self.model.decoder( batch.tgt[:-1], decStates, context, initOutput) for dec_t, tgt_t in zip(decOut, batch.tgt[1:].data): gen_t = self.model.generator.forward(dec_t) tgt_t = tgt_t.unsqueeze(1) scores = gen_t.data.gather(1, tgt_t) scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) goldScores += scores # (3) run the decoder to generate sentences, using beam search # Each hypothesis in the beam uses the same context # and initial decoder state context = Variable(context.data.repeat(1, beamSize, 1)) decStates = tuple([Variable(e.data.repeat(1, beamSize, 1)) for e in encStates]) \ if encStates else None beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)] decOut = self.model.make_init_decoder_output(context) if useMasking: padMask = batch.src.data[:, :, 0].eq( onmt.Constants.PAD).t() \ .unsqueeze(0) \ .repeat(beamSize, 1, 1) batchIdx = list(range(batchSize)) remainingSents = batchSize for i in range(self.opt.max_sent_length): mask(padMask) # Prepare decoder input. 
input = torch.stack([ b.getCurrentState() for b in beam if not b.done ]).t().contiguous().view(1, -1) decOut, decStates, attn = self.model.decoder( Variable(input, volatile=True), batch.src, decStates, context, decOut) # decOut: 1 x (beam*batch) x numWords decOut = decOut.squeeze(0) attn["std"] = attn["std"].view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() if not self.copy_attn or self.copy_attn == "std": out = self.model.generator.forward(decOut) else: words = batch.words().t() words = torch.stack([ words[i] for i, b in enumerate(beam) if not b.done ]).contiguous() attn_copy = attn["copy"].view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() out, c_attn_t \ = self.model.generator.forward( decOut, words, attn_copy.view(-1, batch.src.size(0))) for b in range(out.size(0)): for c in range(c_attn_t.size(1)): v = self.align[words[0, c].data[0]] if v != onmt.Constants.PAD: out[b, v] += c_attn_t[b, c] out = out.log() # batch x beam x numWords wordLk = out.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() active = [] for b in range(batchSize): if beam[b].done: continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn["std"].data[idx]): active += [b] for decState in decStates: # iterate over h, c # layers x beam*sent x dim sentStates = decState.view(-1, beamSize, remainingSents, decState.size(2))[:, :, idx] sentStates.data.copy_( sentStates.data.index_select( 1, beam[b].getCurrentOrigin())) if not active: break # In this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} def updateActive(t, size=rnnSize, batchPos=-2): # Select only the remaining active sentences view = t.data.view(-1, remainingSents, t.size(-1)) newSize = list(t.size()) newSize[batchPos] = newSize[batchPos] * \ len(activeIdx) // remainingSents return Variable(view.index_select(1, activeIdx).view(*newSize), volatile=True) decStates = tuple([updateActive(d) for d in decStates]) decOut = updateActive(decOut) context = updateActive(context) if useMasking: padMask = padMask.index_select(1, activeIdx) remainingSents = len(active) # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.opt.n_best for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]]) allHyp += [hyps] if useMasking: valid_attn = batch.src.data[:, b, 0].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) attn = [a.index_select(1, valid_attn) for a in attn] allAttn += [attn] if self.beam_accum: self.beam_accum["beam_parent_ids"].append( [t.tolist() for t in beam[b].prevKs]) self.beam_accum["scores"].append( [["%4f" % s for s in t.tolist()] for t in beam[b].allScores][1:]) self.beam_accum["predicted_ids"].append( [[self.tgt_dict.getLabel(id) for id in t.tolist()] for t in beam[b].nextYs][1:]) return allHyp, allScores, allAttn, goldScores
def translateBatch(self, batch, data): beamSize = self.opt.beam_size batchSize = batch.batch_size _, src_lengths = batch.src src = make_features(batch, self.fields) # (1) run the encoder on the src encStates, context = self.model.encoder(src, lengths=src_lengths) encStates = self.model.init_decoder_state(context, encStates) useMasking = (self._type == "text") # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding padMask = None tgt_pad = self.fields["tgt"].vocab.stoi[onmt.IO.PAD_WORD] if useMasking: pad = self.fields["src"].vocab.stoi[onmt.IO.PAD_WORD] padMask = src[:, :, 0].data.eq(pad).t() def mask(padMask): if useMasking: self.model.decoder.attn.applyMask(padMask) # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model goldScores = context.data.new(batchSize).zero_() if "tgt" in batch.__dict__: decStates = encStates mask(padMask.unsqueeze(0)) decOut, decStates, attn = self.model.decoder( batch.tgt[:-1], batch.src, context, decStates) for dec_t, tgt_t in zip(decOut, batch.tgt[1:].data): gen_t = self.model.generator.forward(dec_t) tgt_t = tgt_t.unsqueeze(1) scores = gen_t.data.gather(1, tgt_t) scores.masked_fill_(tgt_t.eq(tgt_pad), 0) goldScores += scores # (3) run the decoder to generate sentences, using beam search # Each hypothesis in the beam uses the same context # and initial decoder state context = Variable(context.data.repeat(1, beamSize, 1), volatile=True) batch_src = Variable(src.data.repeat(1, beamSize, 1), volatile=True) batch_src_map = Variable(batch.src_map.data.repeat(1, beamSize, 1), volatile=True) decStates = encStates decStates.repeatBeam_(beamSize) beam = [ onmt.Beam(beamSize, cuda=self.opt.cuda, vocab=self.fields["tgt"].vocab) for __ in range(batchSize) ] if useMasking: padMask = src.data[:, :, 0].eq(pad).t() \ .unsqueeze(0) \ .repeat(beamSize, 1, 1) # (3b) The main loop for i in range(self.opt.max_sent_length): # (a) Run RNN decoder forward one step. mask(padMask) input = torch.stack([b.getCurrentState() for b in beam])\ .t().contiguous().view(1, -1) input.masked_fill_(input.gt(len(self.fields["tgt"].vocab) - 1), 0) input = Variable(input, volatile=True) decOut, decStates, attn = self.model.decoder( input, batch_src, context, decStates) # print(decStates.all[0][:, 0, 0]) decOut = decOut.squeeze(0) # decOut: (beam*batch) x numWords attn["std"] = attn["std"].view(beamSize, batchSize, -1) \ .transpose(0, 1).contiguous() # (b) Compute a vector of batch*beam word scores. if not self.copy_attn: out = self.model.generator.forward(decOut).data else: # print(attn["copy"].size()) attn_copy = attn["copy"].view(beamSize, batchSize, -1) \ .transpose(0, 1).contiguous() out = self.model.generator.forward( decOut, attn_copy.view(-1, batch_src.size(0)), batch_src_map) out = data.collapseCopyScores( out.data.view(batchSize, beamSize, -1).transpose(0, 1), batch, self.fields["tgt"].vocab) out = out.log().transpose(0, 1).contiguous()\ .view(beamSize * batchSize, -1) word_scores = out.view(beamSize, batchSize, -1) \ .transpose(0, 1).contiguous() # batch x beam x numWords # (c) Advance each beam. 
active = [] for b in range(batchSize): is_done = beam[b].advance(word_scores[b], attn["std"].data[b]) if not is_done: active += [b] decStates.beamUpdate_(b, beam[b].getCurrentOrigin(), beamSize) if not active: break # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.opt.n_best for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps, attn = [], [] for k in ks[:n_best]: hyp, att = beam[b].getHyp(k) hyps.append(hyp) attn.append(att) allHyp += [hyps] if useMasking: valid_attn = src.data[:, b, 0].ne(pad) \ .nonzero().squeeze(1) attn = [a.index_select(1, valid_attn) for a in attn] allAttn += [attn] # For debugging visualization. if self.beam_accum: self.beam_accum["beam_parent_ids"].append( [t.tolist() for t in beam[b].prevKs]) self.beam_accum["scores"].append( [["%4f" % s for s in t.tolist()] for t in beam[b].allScores][1:]) self.beam_accum["predicted_ids"].append( [[self.tgt_dict.getLabel(id) for id in t.tolist()] for t in beam[b].nextYs][1:]) return allHyp, allScores, allAttn, goldScores
def beam_conf_once(self, srcBatch, tgtBatch, confidence_method, conf_n_best): beamSize = self.opt.beam_size confidence_method_split = confidence_method.split(':') # (1) run the encoder on the src encStates, context, rnnSize = self.conf_encode(srcBatch) # (3) run the decoder to generate sentences, using beam search # Expand tensors for each beam. batchSize = self._getBatchSize(srcBatch[0]) context = Variable(context.data.repeat(1, beamSize, 1)) decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)), Variable(encStates[1].data.repeat(1, beamSize, 1))) beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)] decOut = self.model.make_init_decoder_output(context) padMask = srcBatch[0].data.eq( onmt.Constants.PAD).t() \ .unsqueeze(0) \ .repeat(beamSize, 1, 1) batchIdx = list(range(batchSize)) remainingSents = batchSize for i in range(self.opt.max_sent_length): self.model.decoder.attn.applyMask(padMask) # Prepare decoder input. input = torch.stack([ b.getCurrentState() for b in beam if not b.done ]).t().contiguous().view(1, -1) decOut, decStates, attn = self.model.decoder( Variable(input, volatile=True), decStates, context, decOut) # decOut: 1 x (beam*batch) x numWords decOut = decOut.squeeze(0) out = self.model.generator.forward(decOut) # batch x beam x numWords wordLk = out.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() attn = attn.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() active = [] for b in range(batchSize): if beam[b].done: continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn.data[idx]): active += [b] for decState in decStates: # iterate over h, c # layers x beam*sent x dim sentStates = decState.view(-1, beamSize, remainingSents, decState.size(2))[:, :, idx] sentStates.data.copy_( sentStates.data.index_select( 1, beam[b].getCurrentOrigin())) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} def updateActive(t): # select only the remaining active sentences view = t.data.view(-1, remainingSents, rnnSize) newSize = list(t.size()) newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents return Variable(view.index_select(1, activeIdx).view(*newSize), volatile=True) decStates = (updateActive(decStates[0]), updateActive(decStates[1])) decOut = updateActive(decOut) context = updateActive(context) padMask = padMask.index_select(1, activeIdx) remainingSents = len(active) # (4) package everything up allScores = [] for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:conf_n_best]] # -> (batch_size, conf_n_best) p_infer = torch.stack(allScores) return p_infer
def translateBatch(self, srcBatch, tgtBatch): batchSize = srcBatch.size(1) beamSize = self.opt.beam_size decoder = self.model.decoder attentionLayer = decoder.attn useMasking = self.opt.mem == 'lstm_lstm' and self.model.decoder.use_attn def lstm_encoder(src): emb_in = self.model.word_lut(src) init_h = self.model.make_init_hidden( emb_in[0], src.size(1), self.model.decoder.hidden_size, 2) hidden = (torch.stack(init_h[0]), torch.stack(init_h[1])) context, hidden = self.model.encoder(emb_in, hidden) return context, hidden def lstm_decoder(tgt, hidden, context, decOut): if useMasking: padMask = srcBatch.data.eq(onmt.Constants.PAD).t() attentionLayer.applyMask(padMask) out, dec_hidden, _attn = self.model.decoder( tgt, hidden, context, decOut) return out, dec_hidden, _attn def dnc_encoder(src): batch_size = src.size(1) hidden = self.model.encoder.make_init_hidden( src[0], *self.model.encoder.rnn_sz) M = self.model.encoder.make_init_M(batch_size) emb_in = self.model.word_lut(src) return self.model.encoder(emb_in, hidden, M) # (1) run the encoder on the src if self.opt.mem == 'lstm_lstm': context, encStates = lstm_encoder(srcBatch) elif self.opt.mem == 'dnc_dnc': context, encStates, M = dnc_encoder(srcBatch) rnnSize = encStates[0][0].size(1) # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model goldScores = encStates[0][0].data.new(batchSize).zero_() if tgtBatch is not None: decStates = encStates decM = M if self.opt.mem == 'lstm_lstm': init_output = self.model.make_init_decoder_output(context[0]) decOut, decStates, attn = lstm_decoder( tgtBatch[:-1], decStates, context, init_output) elif self.opt.mem == 'dnc_dnc': emb_out = self.model.word_lut(tgtBatch[:-1]) decOut, decStates, decM = self.model.decoder( emb_out, decStates, decM) for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data): gen_t = self.model.generator.forward(dec_t) tgt_t = tgt_t.unsqueeze(1) scores = gen_t.data.gather(1, tgt_t) scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) goldScores += scores print(' == got gold ==') # (3) run the decoder to generate sentences, using beam search # Expand tensors for each beam. if self.opt.mem == 'lstm_lstm': context = Variable(context.data.repeat(1, beamSize, 1)) elif self.opt.mem == 'dnc_dnc': decM = {} for k in M.keys(): print(k) dims = M[k].dim() if dims == 3: decM[k] = Variable(M[k].data.repeat(beamSize, 1, 1)) elif dims == 2: decM[k] = Variable(M[k].data.repeat(beamSize, 1)) print(' -- M:') [print(k, M[k].size()) for k in M.keys()] print(' -- decM:') [print(k, decM[k].size()) for k in decM.keys()] decStates = ((Variable(encStates[0][0].data.repeat(beamSize, 1)), Variable(encStates[0][1].data.repeat(beamSize, 1))), (Variable(encStates[1][0].data.repeat(beamSize, 1)), Variable(encStates[1][1].data.repeat(beamSize, 1)))) beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)] decOut = self.model.make_init_decoder_output( decStates[0][0]) # .squeeze(0)) if useMasking: padMask = srcBatch.data.eq( onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1) batchIdx = list(range(batchSize)) remainingSents = batchSize for i in range(self.opt.max_sent_length): if useMasking: attentionLayer.applyMask(padMask) # Prepare decoder input. 
input = torch.stack([b.getCurrentState() for b in beam if not b.done]).t().contiguous().view(1, -1) if self.opt.mem == 'lstm_lstm': decOut, decStates, attn = self.model.decoder( Variable(input, volatile=True), decStates, context, decOut) elif self.opt.mem == 'dnc_dnc': inp = self.model.word_lut(Variable(input, volatile=True)) decOut, decStates, decM = self.model.decoder( inp, decStates, decM) # decOut: 1 x (beam*batch) x numWords decOut = decOut.squeeze(0) out = self.model.generator.forward(decOut) # batch x beam x numWords wordLk = out.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() if self.opt.mem == 'lstm_lstm': attn = attn.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() else: attn = None active = [] for b in range(batchSize): if beam[b].done: continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn.data[idx]): active += [b] for decState in decStates: # iterate over h, c # layers x beam*sent x dim sentStates = decState.view(-1, beamSize, remainingSents, decState.size(2))[:, :, idx] sentStates.data.copy_( sentStates.data.index_select( 1, beam[b].getCurrentOrigin())) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} def updateActive(t): # select only the remaining active sentences view = t.data.view(-1, remainingSents, rnnSize) newSize = list(t.size()) newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents return Variable(view.index_select(1, activeIdx) .view(*newSize), volatile=True) decStates = (updateActive(decStates[0]), updateActive(decStates[1])) decOut = updateActive(decOut) context = updateActive(context) if useMasking: padMask = padMask.index_select(1, activeIdx) remainingSents = len(active) # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.opt.n_best for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]]) allHyp += [hyps] if useMasking: valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) attn = [a.index_select(1, valid_attn) for a in attn] allAttn += [attn] if self.beam_accum: self.beam_accum["beam_parent_ids"].append( [t.tolist() for t in beam[b].prevKs]) self.beam_accum["scores"].append([ ["%4f" % s for s in t.tolist()] for t in beam[b].allScores][1:]) self.beam_accum["predicted_ids"].append( [[self.tgt_dict.getLabel(id) for id in t.tolist()] for t in beam[b].nextYs][1:]) return allHyp, allScores, allAttn, goldScores
def beam_decode(self, encStates):
    batchSize = encStates.size(0)
    beamSize = self.opt.beam_size
    rnnSize = self.model.decoder.hidden_size

    beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)]

    decStates = self.model.latent_to_decoder(encStates)
    if self.model.prelu:
        decStates = self.model.prelu_dec(decStates)
    decStates = decStates.view(self.model.layers, decStates.size(0), -1)
    decStates = torch.split(decStates, decStates.size(-1) // 2, 2)
    decStates = (Variable(decStates[0].data.repeat(1, beamSize, 1)),
                 Variable(decStates[1].data.repeat(1, beamSize, 1)))

    context = Variable(encStates.data.repeat(beamSize, 1))
    decOut = self.model.make_init_decoder_output(context)

    batchIdx = list(range(batchSize))
    remainingSents = batchSize

    for i in range(self.opt.max_sent_length):
        # Prepare decoder input.
        input = torch.stack([
            b.getCurrentState() for b in beam if not b.done
        ]).t().contiguous().view(1, -1)

        decOut, decStates = self.model.decoder(
            Variable(input, volatile=True), decStates, decOut)
        # decOut: 1 x (beam*batch) x numWords
        decOut = decOut.squeeze(0)
        out = decOut

        # batch x beam x numWords
        wordLk = out.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous()

        active = []
        for b in range(batchSize):
            if beam[b].done:
                continue

            idx = batchIdx[b]
            if not beam[b].advance(wordLk.data[idx]):
                active += [b]

            for decState in decStates:  # iterate over h, c
                # layers x beam*sent x dim
                sentStates = decState.view(-1, beamSize, remainingSents,
                                           decState.size(2))[:, :, idx]
                sentStates.data.copy_(
                    sentStates.data.index_select(
                        1, beam[b].getCurrentOrigin()))

        if not active:
            break

        # in this section, the sentences that are still active are
        # compacted so that the decoder is not run on completed sentences
        activeIdx = self.tt.LongTensor([batchIdx[k] for k in active])
        batchIdx = {beam: idx for idx, beam in enumerate(active)}

        def updateActive(t):
            # select only the remaining active sentences
            view = t.data.view(-1, remainingSents, rnnSize)
            newSize = list(t.size())
            newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents
            return Variable(view.index_select(1, activeIdx) \
                                .view(*newSize), volatile=True)

        decStates = (updateActive(decStates[0]),
                     updateActive(decStates[1]))
        # decOut = updateActive(decOut)
        remainingSents = len(active)

    # (4) package everything up
    allHyp, allScores = [], []
    n_best = self.opt.n_best

    for b in range(batchSize):
        scores, ks = beam[b].sortBest()
        allScores += [scores[:n_best]]
        hyps = [beam[b].getHyp(k) for k in ks[:n_best]]
        allHyp += [hyps]

    return allHyp, allScores
def translateBatch(self, srcBatch, tgtBatch): batchSize = srcBatch[0].size(1) beamSize = self.opt.beam_size # (1) run the encoder on the src encStates, _ = self.model.encode(srcBatch) srcBatch = srcBatch[0] # drop the lengths needed for encoder rnnSize = self.model.decoder.hidden_size # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model goldScores = encStates.data.new(batchSize).zero_() if tgtBatch is not None: decStates = encStates decOut = self.model.decode(decStates, tgtBatch[:-1]) for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data): gen_t = dec_t tgt_t = tgt_t.unsqueeze(1) scores = gen_t.data.gather(1, tgt_t) scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) goldScores += scores # (3) run the decoder to generate sentences, using beam search # Expand tensors for each beam. beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)] decStates = self.model.latent_to_decoder(encStates) if self.model.prelu: decStates = self.model.prelu_dec(decStates) decStates = decStates.view(self.model.layers, decStates.size(0), -1) decStates = torch.chunk(decStates, 2, 2) decStates = (Variable(decStates[0].data.repeat(1, beamSize, 1)), Variable(decStates[1].data.repeat(1, beamSize, 1))) context = Variable(encStates.data.repeat(beamSize, 1)) decOut = self.model.make_init_decoder_output(context) padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat( beamSize, 1, 1) batchIdx = list(range(batchSize)) remainingSents = batchSize for i in range(self.opt.max_sent_length): # Prepare decoder input. input = torch.stack([ b.getCurrentState() for b in beam if not b.done ]).t().contiguous().view(1, -1) decOut, decStates = self.model.decoder( Variable(input, volatile=True), decStates, decOut) # decOut: 1 x (beam*batch) x numWords decOut = decOut.squeeze(0) out = decOut # batch x beam x numWords wordLk = out.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous() active = [] for b in range(batchSize): if beam[b].done: continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx]): active += [b] for decState in decStates: # iterate over h, c # layers x beam*sent x dim sentStates = decState.view(-1, beamSize, remainingSents, decState.size(2))[:, :, idx] sentStates.data.copy_( sentStates.data.index_select( 1, beam[b].getCurrentOrigin())) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} def updateActive(t): # select only the remaining active sentences view = t.data.view(-1, remainingSents, rnnSize) newSize = list(t.size()) newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents return Variable(view.index_select(1, activeIdx) \ .view(*newSize), volatile=True) decStates = (updateActive(decStates[0]), updateActive(decStates[1])) #decOut = updateActive(decOut) padMask = padMask.index_select(1, activeIdx) remainingSents = len(active) # (4) package everything up allHyp, allScores = [], [] n_best = self.opt.n_best for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps = [beam[b].getHyp(k) for k in ks[:n_best]] allHyp += [hyps] return allHyp, allScores, goldScores
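# --- Standalone sketch of the "goldScore" computation in step (2) above ---
# Gather the log-probability of each reference token from the model's output
# distribution, zero out padding positions, and accumulate over time. Shapes and the
# PAD index below are illustrative assumptions, not the project's constants.
import torch

PAD = 0
tgt_len, batch, vocab = 5, 3, 11
log_probs = torch.log_softmax(torch.randn(tgt_len, batch, vocab), dim=-1)
tgt = torch.randint(1, vocab, (tgt_len, batch))
tgt[3:, 1] = PAD                                     # pretend sentence 1 is shorter

gold_scores = torch.zeros(batch)
gold_words = 0
for gen_t, tgt_t in zip(log_probs, tgt):             # iterate over time steps
    tgt_t = tgt_t.unsqueeze(1)                       # (batch, 1)
    scores = gen_t.gather(1, tgt_t)                  # log P(reference token)
    scores.masked_fill_(tgt_t.eq(PAD), 0)            # padding contributes nothing
    gold_scores += scores.squeeze(1)
    gold_words += tgt_t.ne(PAD).sum().item()
print(gold_scores, gold_words)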
def translateBatch(self, srcBatch, tgtBatch): torch.set_grad_enabled(False) # Batch size is in different location depending on data. beamSize = self.opt.beam_size batchSize = self._getBatchSize(srcBatch) vocab_size = self.tgt_dict.size() allHyp, allScores, allAttn, allLengths = [], [], [], [] # srcBatch should have size len x batch # tgtBatch should have size len x batch contexts = dict() src = srcBatch.transpose(0, 1) # (1) run the encoders on the src for i in range(self.n_models): contexts[i], src_mask = self.models[i].encoder(src) goldScores = contexts[0].data.new(batchSize).zero_() goldWords = 0 if tgtBatch is not None: # Use the first model to decode model_ = self.models[0] tgtBatchInput = tgtBatch[:-1] tgtBatchOutput = tgtBatch[1:] tgtBatchInput = tgtBatchInput.transpose(0, 1) output, coverage = model_.decoder(tgtBatchInput, contexts[0], src) # output should have size time x batch x dim # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model for dec_t, tgt_t in zip(output, tgtBatchOutput.data): gen_t = model_.generator(dec_t) tgt_t = tgt_t.unsqueeze(1) scores = gen_t.data.gather(1, tgt_t) scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) goldScores += scores.squeeze(1).type_as(goldScores) goldWords += tgt_t.ne(onmt.Constants.PAD).sum().item() # (3) Start decoding # time x batch * beam src = srcBatch # this is time first again (before transposing) # initialize the beam beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)] batchIdx = list(range(batchSize)) remainingSents = batchSize decoder_states = dict() decoder_hiddens = dict() for i in range(self.n_models): decoder_states[i] = self.models[i].create_decoder_state( src, contexts[i], src_mask, beamSize, type='old') for i in range(self.opt.max_sent_length): # Prepare decoder input. 
# input size: 1 x ( batch * beam ) input = torch.stack([ b.getCurrentState() for b in beam if not b.done ]).t().contiguous().view(1, -1) """ Inefficient decoding implementation We re-compute all states for every time step A better buffering algorithm will be implemented """ decoder_input = input # require batch first for everything outs = dict() attns = dict() for i in range(self.n_models): decoder_hidden, coverage = self.models[i].decoder.step( decoder_input.clone(), decoder_states[i]) # take the last decoder state decoder_hidden = decoder_hidden.squeeze(1) attns[i] = coverage[:, -1, :].squeeze(1) # batch * beam x src_len # batch * beam x vocab_size outs[i] = self.models[i].generator(decoder_hidden) out = self._combineOutputs(outs) attn = self._combineAttention(attns) wordLk = out.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() attn = attn.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() active = [] for b in range(batchSize): if beam[b].done: continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn.data[idx]): active += [b] for i in range(self.n_models): decoder_states[i]._update_beam(beam, b, remainingSents, idx) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} for i in range(self.n_models): decoder_states[i]._prune_complete_beam(activeIdx, remainingSents) remainingSents = len(active) # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.opt.n_best allLengths = [] for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps, attn, length = zip(*[beam[b].getHyp(k) for k in ks[:n_best]]) allHyp += [hyps] allLengths += [length] valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) attn = [a.index_select(1, valid_attn) for a in attn] allAttn += [attn] if self.beam_accum: self.beam_accum["beam_parent_ids"].append( [t.tolist() for t in beam[b].prevKs]) self.beam_accum["scores"].append( [["%4f" % s for s in t.tolist()] for t in beam[b].allScores][1:]) self.beam_accum["predicted_ids"].append( [[self.tgt_dict.getLabel(id) for id in t.tolist()] for t in beam[b].nextYs][1:]) torch.set_grad_enabled(True) return allHyp, allScores, allAttn, allLengths, goldScores, goldWords
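# --- Sketch of one possible ensemble combination behind _combineOutputs / _combineAttention ---
# Their implementations are not shown in this excerpt; a common choice, assumed here,
# is to average the probabilities (not the log-probabilities) across models and to
# average the attention weights directly.
import torch

def combine_outputs(outs):
    """outs: dict of model_id -> (batch*beam, vocab) log-probabilities."""
    stacked = torch.stack(list(outs.values()))       # (n_models, batch*beam, vocab)
    n_models = torch.tensor(float(stacked.size(0)))
    return torch.logsumexp(stacked, dim=0) - torch.log(n_models)  # log of mean prob

def combine_attention(attns):
    """attns: dict of model_id -> (batch*beam, src_len) attention weights."""
    return torch.stack(list(attns.values())).mean(dim=0)

outs = {i: torch.log_softmax(torch.randn(6, 10), dim=-1) for i in range(2)}
attns = {i: torch.softmax(torch.randn(6, 7), dim=-1) for i in range(2)}
print(combine_outputs(outs).shape, combine_attention(attns).shape)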
def translateBatch(self, batch): beamSize = self.opt.beam_size batchSize = batch.batchSize # (1) run the encoder on the src useMasking = (self._type == "text") encStatesL = [] decStatesL = [] contextL = [] src_lengths = batch.lengths.data.view(-1).tolist() globalScorer = onmt.GNMTGlobalScorer(self.opt.alpha, self.opt.beta) beam = [onmt.Beam(beamSize, self.opt.cuda, globalScorer, alpha=self.opt.alpha, beta=self.opt.beta, tgtDict=self.tgt_dict) for i in range(batchSize)] for model in self.models: encStates, context = model.encoder(batch.src, lengths=batch.lengths) encStates = model.init_decoder_state(context, encStates) decoder = model.decoder attentionLayer = decoder.attn # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding padMask = batch.words().data.eq(onmt.Constants.PAD).t() attentionLayer.applyMask(padMask) # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model ## for sanity check #if batch.tgt is not None: # decStates = encStates # mask(padMask.unsqueeze(0)) # decOut, decStates, attn = self.model.decoder(batch.tgt[:-1], # batch.src, # context, # decStates) # for dec_t, tgt_t in zip(decOut, batch.tgt[1:].data): # gen_t = self.model.generator.forward(dec_t) # tgt_t = tgt_t.unsqueeze(1) # scores = gen_t.data.gather(1, tgt_t) # scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) # goldScores += scores # for sanity check # (3) run the decoder to generate sentences, using beam search # Each hypothesis in the beam uses the same context # and initial decoder state context = Variable(context.data.repeat(1, beamSize, 1)) contextL.append(context.clone()) goldScores = context.data.new(batchSize).zero_() decStates = encStates decStates.repeatBeam_(beamSize) decStatesL.append(decStates) batch_src = Variable(batch.src.data.repeat(1, beamSize, 1)) padMask = batch.src.data[:, :, 0].eq(onmt.Constants.PAD).t() \ .unsqueeze(0) \ .repeat(beamSize, 1, 1) # (3b) The main loop beam_done = [] for i in range(self.opt.max_sent_length): # (a) Run RNN decoder forward one step. #mask(padMask) input = torch.stack([b.getCurrentState() for b in beam])\ .t().contiguous().view(1, -1) input = Variable(input, volatile=True) decOutTmp = [] attnTmp = [] word_scores = [] for idx in range(len(self.models)): model = self.models[idx] model.decoder.attn.applyMask(padMask) decOut, decStatesTmp, attn = model.decoder(input, batch_src, contextL[idx], decStatesL[idx]) decStatesL[idx] = decStatesTmp decOutTmp.append(decOut) attnTmp.append(attn) decOut = decOut.squeeze(0) # decOut: (beam*batch) x numWords attn["std"] = attn["std"].view(beamSize, batchSize, -1) \ .transpose(0, 1).contiguous() # (b) Compute a vector of batch*beam word scores. 
#if not self.copy_attn: if True: out = model.generator[0].forward(decOut) out = nn.Softmax()(out) else: # Copy Attention Case words = batch.words().t() words = torch.stack([words[i] for i, b in enumerate(beam)])\ .contiguous() attn_copy = attn["copy"].view(beamSize, batchSize, -1) \ .transpose(0, 1).contiguous() out, c_attn_t \ = self.model.generator.forward( decOut, attn_copy.view(-1, batch_src.size(0))) for b in range(out.size(0)): for c in range(c_attn_t.size(1)): v = self.align[words[0, c].data[0]] if v != onmt.Constants.PAD: out[b, v] += c_attn_t[b, c] out = out.log() #score = out.view(beamSize, batchSize, -1).transpose(0, 1).contiguous() # batch x beam x numWords word_scores.append(out.clone()) word_score = torch.stack(word_scores).sum(0).squeeze(0).div_(len(self.models)) mean_score = word_score.view(beamSize, batchSize, -1).transpose(0, 1).contiguous() scores = torch.log(mean_score) #scores = self.models[0].generator[1].forward(mean_score) # (c) Advance each beam. active = [] for b in range(batchSize): if b in beam_done: continue beam[b].advance(scores.data[b], attn["std"].data[b]) is_done = beam[b].done() if not is_done: active += [b] for dec in decStatesL: dec.beamUpdate_(b, beam[b].getCurrentOrigin(), beamSize) if is_done: beam_done.append(b) #if not active: #break if len(beam_done) == batchSize: break # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.opt.n_best for b in range(batchSize): scores, ks = beam[b].sortFinished() #scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps, attn = [], [] for i, (times, k) in enumerate(ks[:n_best]): hyp, att = beam[b].getHyp(times, k) hyps.append(hyp) attn.append(att) allHyp += [hyps] if useMasking: valid_attn = batch.src.data[:, b, 0].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) attn = [a.index_select(1, valid_attn) for a in attn] allAttn += [attn] # For debugging visualization. if self.beam_accum: self.beam_accum["beam_parent_ids"].append( [t.tolist() for t in beam[b].prevKs]) self.beam_accum["scores"].append([ ["%4f" % s for s in t.tolist()] for t in beam[b].allScores][1:]) self.beam_accum["predicted_ids"].append( [[idx for idx in t.tolist()] for t in beam[b].nextYs][1:]) self.beam_accum["predicted_labels"].append( [[self.tgt_dict.getLabel(idx) for idx in t.tolist()] for t in beam[b].nextYs][1:]) beam[b].finished.sort(key=lambda a:-a[0]) self.beam_accum['finished'].append(beam[b].finished) return allHyp, allScores, allAttn, goldScores
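# --- Sketch of the copy-attention merge in the (disabled) branch above ---
# For every source position, the copy attention mass is added to the target-vocabulary
# entry that the source word aligns to, before taking the log. The alignment table and
# PAD handling here are simplified assumptions for illustration.
import torch

PAD = 0
batch_beam, vocab, src_len = 4, 12, 6
gen_probs = torch.softmax(torch.randn(batch_beam, vocab), dim=-1)
copy_attn = torch.softmax(torch.randn(batch_beam, src_len), dim=-1)
align = torch.randint(1, vocab, (src_len,))          # src position -> tgt vocab id
align[2] = PAD                                       # unalignable position

out = gen_probs.clone()
for b in range(batch_beam):
    for c in range(src_len):
        v = int(align[c])
        if v != PAD:
            out[b, v] += copy_attn[b, c]             # move copy mass onto the vocab entry
scores = out.log()
print(scores.shape)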
def translateBatch(self, srcBatch, tgtBatch): # Batch size is in different location depending on data. beamSize = self.opt.beam_size # (1) run the encoder on the src encStates, context = self.model.encoder(srcBatch) # Drop the lengths needed for encoder. srcBatch = srcBatch[0] batchSize = self._getBatchSize(srcBatch) rnnSize = context.size(2) encStates = (self.model._fix_enc_hidden(encStates[0]), self.model._fix_enc_hidden(encStates[1])) decoder = self.model.decoder attentionLayer = decoder.attn useMasking = self._type == "text" # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding padMask = None if useMasking: padMask = srcBatch.data.eq(onmt.Constants.PAD).t() def mask(padMask): if useMasking: attentionLayer.applyMask(padMask) # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model goldScores = context.data.new(batchSize).zero_() if tgtBatch is not None: decStates = encStates decOut = self.model.make_init_decoder_output(context) mask(padMask) initOutput = self.model.make_init_decoder_output(context) decOut, decStates, attn = self.model.decoder( tgtBatch[:-1], decStates, context, initOutput) for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data): gen_t = self.model.generator.forward(dec_t) tgt_t = tgt_t.unsqueeze(1) scores = gen_t.data.gather(1, tgt_t) scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) goldScores += scores # (3) run the decoder to generate sentences, using beam search # Expand tensors for each beam. context = Variable(context.data.repeat(1, beamSize, 1)) decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)), Variable(encStates[1].data.repeat(1, beamSize, 1))) beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)] decOut = self.model.make_init_decoder_output(context) if useMasking: padMask = srcBatch.data.eq( onmt.Constants.PAD).t() \ .unsqueeze(0) \ .repeat(beamSize, 1, 1) batchIdx = list(range(batchSize)) remainingSents = batchSize for i in range(self.opt.max_sent_length): mask(padMask) # Prepare decoder input. 
input = torch.stack([ b.getCurrentState() for b in beam if not b.done ]).t().contiguous().view(1, -1) decOut, decStates, attn = self.model.decoder( Variable(input, volatile=True), decStates, context, decOut) # decOut: 1 x (beam*batch) x numWords decOut = decOut.squeeze(0) out = self.model.generator.forward(decOut) # batch x beam x numWords wordLk = out.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() attn = attn.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() active = [] for b in range(batchSize): if beam[b].done: continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn.data[idx]): active += [b] for decState in decStates: # iterate over h, c # layers x beam*sent x dim sentStates = decState.view(-1, beamSize, remainingSents, decState.size(2))[:, :, idx] sentStates.data.copy_( sentStates.data.index_select( 1, beam[b].getCurrentOrigin())) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} def updateActive(t): # select only the remaining active sentences view = t.data.view(-1, remainingSents, rnnSize) newSize = list(t.size()) newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents return Variable(view.index_select(1, activeIdx).view(*newSize), volatile=True) decStates = (updateActive(decStates[0]), updateActive(decStates[1])) decOut = updateActive(decOut) context = updateActive(context) if useMasking: padMask = padMask.index_select(1, activeIdx) remainingSents = len(active) # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.opt.n_best for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]]) allHyp += [hyps] if useMasking: valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) attn = [a.index_select(1, valid_attn) for a in attn] allAttn += [attn] if self.beam_accum: self.beam_accum["beam_parent_ids"].append( [t.tolist() for t in beam[b].prevKs]) self.beam_accum["scores"].append( [["%4f" % s for s in t.tolist()] for t in beam[b].allScores][1:]) self.beam_accum["predicted_ids"].append( [[self.tgt_dict.getLabel(id) for id in t.tolist()] for t in beam[b].nextYs][1:]) return allHyp, allScores, allAttn, goldScores
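# --- Standalone sketch of the padding mask applied to attention (applyMask above) ---
# The attention layer's applyMask is not defined in this excerpt; its usual effect,
# illustrated here, is that positions where the source is PAD receive -inf before the
# softmax, so they get exactly zero attention weight.
import torch
import torch.nn.functional as F

PAD = 0
src_len, batch = 7, 3
src = torch.ones(src_len, batch, dtype=torch.long)
src[5:, 1] = PAD                                     # sentence 1 has two padded positions
pad_mask = src.eq(PAD).t()                           # (batch, src_len), True at padding

attn_scores = torch.randn(batch, src_len)
attn_scores = attn_scores.masked_fill(pad_mask, float("-inf"))
attn_weights = F.softmax(attn_scores, dim=-1)
print(attn_weights.masked_select(pad_mask))          # all zeros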
def translateBatch(self, srcBatch, tgtBatch): torch.set_grad_enabled(False) # Batch size is in different location depending on data. beamSize = self.opt.beam_size batchSize = self._getBatchSize(srcBatch) if self.model_type == 'recurrent': # (1) run the encoder on the src encStates, context = self.model.encoder(srcBatch) rnnSize = context.size(2) decoder = self.model.decoder attentionLayer = decoder.attn useMasking = (self._type == "text" and batchSize > 1) # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding attn_mask = srcBatch.eq(onmt.Constants.PAD).t() # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model goldScores = context.data.new(batchSize).zero_() goldWords = 0 if tgtBatch is not None: decStates = encStates decOut = self.model.make_init_decoder_output(context) initOutput = self.model.make_init_decoder_output(context) decOut, decStates, attn = self.model.decoder( tgtBatch[:-1], decStates, context, initOutput, attn_mask=attn_mask) for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data): gen_t = self.model.generator.forward(dec_t) tgt_t = tgt_t.unsqueeze(1) scores = gen_t.data.gather(1, tgt_t) scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) goldScores += scores.squeeze(1) goldWords += tgt_t.ne(onmt.Constants.PAD).sum() # (3) run the decoder to generate sentences, using beam search # Expand tensors for each beam. context = Variable(context.data.repeat(1, beamSize, 1)) decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)), Variable(encStates[1].data.repeat(1, beamSize, 1))) beam = [ onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize) ] decOut = self.model.make_init_decoder_output(context) attn_mask = srcBatch.eq( onmt.Constants.PAD).t() \ .unsqueeze(0) \ .repeat(beamSize, 1, 1) batchIdx = list(range(batchSize)) remainingSents = batchSize for i in range(self.opt.max_sent_length): # Prepare decoder input. 
input = torch.stack([ b.getCurrentState() for b in beam if not b.done ]).t().contiguous().view(1, -1) decOut, decStates, attn = self.model.decoder( Variable(input), decStates, context, decOut, attn_mask=attn_mask) # decOut: 1 x (beam*batch) x numWords decOut = decOut.squeeze(0) out = self.model.generator.forward(decOut) # batch x beam x numWords wordLk = out.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() attn = attn.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() active = [] for b in range(batchSize): if beam[b].done: continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn.data[idx]): active += [b] for decState in decStates: # iterate over h, c # layers x beam*sent x dim sentStates = decState.view(-1, beamSize, remainingSents, decState.size(2))[:, :, idx] sentStates.data.copy_( sentStates.data.index_select( 1, beam[b].getCurrentOrigin())) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} def updateActive(t): # select only the remaining active sentences view = t.data.view(-1, remainingSents, rnnSize) newSize = list(t.size()) newSize[-2] = newSize[-2] * len( activeIdx) // remainingSents return Variable( view.index_select(1, activeIdx).view(*newSize)) decStates = (updateActive(decStates[0]), updateActive(decStates[1])) decOut = updateActive(decOut) context = updateActive(context) attn_mask_data = attn_mask.data.index_select(1, activeIdx) attn_mask = Variable(attn_mask_data) remainingSents = len(active) # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.opt.n_best allLengths = [] for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps, attn, length = zip( *[beam[b].getHyp(k) for k in ks[:n_best]]) allHyp += [hyps] allLengths += [length] valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) attn = [a.index_select(1, valid_attn) for a in attn] allAttn += [attn] if self.beam_accum: self.beam_accum["beam_parent_ids"].append( [t.tolist() for t in beam[b].prevKs]) self.beam_accum["scores"].append( [["%4f" % s for s in t.tolist()] for t in beam[b].allScores][1:]) self.beam_accum["predicted_ids"].append( [[self.tgt_dict.getLabel(id) for id in t.tolist()] for t in beam[b].nextYs][1:]) torch.set_grad_enabled(True) return allHyp, allScores, allAttn, allLengths, goldScores, goldWords elif self.model_type in [ 'transformer', 'ptransformer', 'fctransformer' ]: vocab_size = self.tgt_dict.size() allHyp, allScores, allAttn, allLengths = [], [], [], [] # srcBatch should have size len x batch # tgtBatch should have size len x batch src = srcBatch.transpose(0, 1) context, src_mask = self.model.encoder(src) goldScores = context.data.new(batchSize).zero_() goldWords = 0 if tgtBatch is not None: tgtBatchInput = tgtBatch[:-1] tgtBatchOutput = tgtBatch[1:] tgtBatchInput = tgtBatchInput.transpose(0, 1) output, coverage = self.model.decoder(tgtBatchInput, context, src) output = output.transpose( 0, 1) # transpose to have time first, like RNN models # (2) if a target is specified, compute the 'goldScore' # (i.e. 
log likelihood) of the target under the model for dec_t, tgt_t in zip(output, tgtBatchOutput.data): gen_t = self.model.generator(dec_t) tgt_t = tgt_t.unsqueeze(1) scores = gen_t.data.gather(1, tgt_t) scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) goldScores += scores.squeeze(1) goldWords += tgt_t.ne(onmt.Constants.PAD).sum() # (3) Start decoding # time x batch * beam src = Variable(srcBatch.data.repeat(1, beamSize)) # context size : time x batch*beam x hidden context = self._replicate_context(context) # initialize the beam beam = [ onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize) ] batchIdx = list(range(batchSize)) remainingSents = batchSize #~ input_seq = None #~ #~ buffer = None #~ ## Create a new decoding state ## I use a method from the model because I don't want to directly access to the decoder state object ## Currently it doesn't share anything with the main model though decoder_state = self.model.create_decoder_state( src, context, beamSize) for i in range(self.opt.max_sent_length): # Prepare decoder input. # input size: 1 x ( batch * beam ) input = torch.stack([ b.getCurrentState() for b in beam if not b.done ]).t().contiguous().view(1, -1) """ Inefficient decoding implementation We re-compute all states for every time step A better buffering algorithm will be implemented """ #~ input_seq = decoder_state.input_seq #~ if input_seq is None: #~ input_seq = input #~ else: #~ # concatenate the last input to the previous input sequence #~ input_seq = torch.cat([input_seq, input], 0) #~ decoder_state.input_seq = input_seq # require batch first for everything decoder_input = Variable(input) #~ if context.dim() == 4: #~ context_ = context.transpose(1, 2) #~ else: #~ context_ = context.transpose(0, 1) #~ decoder_hidden, coverage, buffer = self.model.decoder.step(decoder_input.transpose(0,1) , context_, src.transpose(0, 1), buffer=buffer) decoder_hidden, coverage = self.model.decoder.step( decoder_input, decoder_state) # take the last decoder state decoder_hidden = decoder_hidden.squeeze(1) attn = coverage[:, -1, :].squeeze(1) # batch * beam x src_len # batch * beam x vocab_size out = self.model.generator(decoder_hidden) wordLk = out.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() attn = attn.view(beamSize, remainingSents, -1) \ .transpose(0, 1).contiguous() active = [] for b in range(batchSize): if beam[b].done: continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn.data[idx]): active += [b] decoder_state._update_beam(beam, b, remainingSents, idx) # update the decoding states #~ for tensor in [src, input_seq] : #~ #~ t_, br = tensor.size() #~ sent_states = tensor.view(t_, beamSize, remainingSents)[:, :, idx] #~ #~ if isinstance(tensor, Variable): #~ sent_states.data.copy_(sent_states.data.index_select( #~ 1, beam[b].getCurrentOrigin())) #~ else: #~ sent_states.copy_(sent_states.index_select( #~ 1, beam[b].getCurrentOrigin())) #~ #~ nl, br_, t_, d_ = buffer.size() #~ #~ sent_states = buffer.view(nl, beamSize, remainingSents, t_, d_)[:, :, idx, :, :] #~ #~ sent_states.data.copy_(sent_states.data.index_select( #~ 1, beam[b].getCurrentOrigin())) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} #~ model_size = context.size(-1) decoder_state._prune_complete_beam(activeIdx, remainingSents) #~ def updateActive(t): #~ # select only 
the remaining active sentences #~ view = t.data.view(-1, remainingSents, model_size) #~ newSize = list(t.size()) #~ newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents #~ return Variable(view.index_select(1, activeIdx) #~ .view(*newSize)) #~ #~ def updateActive4D(t): #~ # select only the remaining active sentences #~ nl, br_, t_, d_ = t.size() #~ view = t.data.view(nl, -1, remainingSents, t_, model_size) #~ newSize = list(t.size()) #~ newSize[1] = newSize[1] * len(activeIdx) // remainingSents #~ return Variable(view.index_select(2, activeIdx) #~ .view(*newSize)) #~ #~ def updateActive4D_time_first(t): #~ # select only the remaining active sentences #~ nl, t_, br_, d_ = t.size() #~ view = t.data.view(nl, t_, -1, remainingSents, model_size) #~ newSize = list(t.size()) #~ newSize[2] = newSize[2] * len(activeIdx) // remainingSents #~ return Variable(view.index_select(3, activeIdx) #~ .view(*newSize)) #~ #~ def updateActive2D(t): #~ if isinstance(t, Variable): #~ # select only the remaining active sentences #~ view = t.data.view(-1, remainingSents) #~ newSize = list(t.size()) #~ newSize[-1] = newSize[-1] * len(activeIdx) // remainingSents #~ return Variable(view.index_select(1, activeIdx) #~ .view(*newSize)) #~ else: #~ view = t.view(-1, remainingSents) #~ newSize = list(t.size()) #~ newSize[-1] = newSize[-1] * len(activeIdx) // remainingSents #~ new_t = view.index_select(1, activeIdx).view(*newSize) #~ #~ return new_t #~ #~ if context.dim() == 3 : #~ context = updateActive(context) #~ elif context.dim() == 4: #~ context = updateActive4D_time_first(context) #~ #~ src = updateActive2D(src) #~ #~ input_seq = updateActive2D(input_seq) #~ #~ buffer = updateActive4D(buffer) #~ remainingSents = len(active) # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.opt.n_best allLengths = [] for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps, attn, length = zip( *[beam[b].getHyp(k) for k in ks[:n_best]]) allHyp += [hyps] allLengths += [length] valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) attn = [a.index_select(1, valid_attn) for a in attn] allAttn += [attn] if self.beam_accum: self.beam_accum["beam_parent_ids"].append( [t.tolist() for t in beam[b].prevKs]) self.beam_accum["scores"].append( [["%4f" % s for s in t.tolist()] for t in beam[b].allScores][1:]) self.beam_accum["predicted_ids"].append( [[self.tgt_dict.getLabel(id) for id in t.tolist()] for t in beam[b].nextYs][1:]) torch.set_grad_enabled(True) return allHyp, allScores, allAttn, allLengths, goldScores, goldWords else: print("Model type %s is not supported" % self.model_type) raise NotImplementedError
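# --- Hedged sketch of the "better buffering algorithm" the comment above alludes to ---
# Instead of re-running the decoder over the full prefix at every step, keep a per-layer
# cache of self-attention keys/values and feed only the newest token. All names here
# (IncrementalSelfAttention, the cache layout) are illustrative, not the project's API.
import torch
import torch.nn as nn
import torch.nn.functional as F

class IncrementalSelfAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)

    def forward(self, x_t, cache):
        # x_t: (batch, 1, dim) -- only the newest decoder position
        k_t, v_t = self.k(x_t), self.v(x_t)
        if cache is not None:
            k_t = torch.cat([cache["k"], k_t], dim=1)    # (batch, t, dim)
            v_t = torch.cat([cache["v"], v_t], dim=1)
        cache = {"k": k_t, "v": v_t}
        scores = self.q(x_t) @ k_t.transpose(1, 2) / k_t.size(-1) ** 0.5
        out = F.softmax(scores, dim=-1) @ v_t            # (batch, 1, dim)
        return out, cache

layer, cache = IncrementalSelfAttention(16), None
for step in range(3):                                    # decode 3 steps, one token at a time
    x_t = torch.randn(2, 1, 16)
    out, cache = layer(x_t, cache)
print(out.shape, cache["k"].shape)                       # (2, 1, 16) and (2, 3, 16)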
def translateBatch(self, batch, dataset): beam_size = self.opt.beam_size batch_size = batch.batch_size # (1) Run the encoder on the src. _, src_lengths = batch.src src = onmt.IO.make_features(batch, 'src') encStates, context = self.model.encoder(src, src_lengths) decStates = self.model.decoder.init_decoder_state( src, context, encStates) # (1b) Initialize for the decoder. def var(a): return Variable(a, volatile=True) def rvar(a): return var(a.repeat(1, beam_size, 1)) # Repeat everything beam_size times. context = rvar(context.data) src = rvar(src.data) srcMap = rvar(batch.src_map.data) decStates.repeat_beam_size_times(beam_size) scorer = onmt.GNMTGlobalScorer(self.alpha, self.beta) beam = [onmt.Beam(beam_size, n_best=self.opt.n_best, cuda=self.opt.cuda, vocab=self.fields["tgt"].vocab, global_scorer=scorer) for __ in range(batch_size)] # (2) run the decoder to generate sentences, using beam search. def bottle(m): return m.view(batch_size * beam_size, -1) def unbottle(m): return m.view(beam_size, batch_size, -1) for i in range(self.opt.max_sent_length): if all((b.done() for b in beam)): break # Construct batch x beam_size nxt words. # Get all the pending current beam words and arrange for forward. inp = var(torch.stack([b.getCurrentState() for b in beam]) .t().contiguous().view(1, -1)) # Turn any copied words to UNKs # 0 is unk if self.copy_attn: inp = inp.masked_fill( inp.gt(len(self.fields["tgt"].vocab) - 1), 0) # Temporary kludge solution to handle changed dim expectation # in the decoder inp = inp.unsqueeze(2) # Run one step. decOut, decStates, attn = \ self.model.decoder(inp, context, decStates) decOut = decOut.squeeze(0) # decOut: beam x rnn_size # (b) Compute a vector of batch*beam word scores. if not self.copy_attn: out = self.model.generator.forward(decOut).data out = unbottle(out) # beam x tgt_vocab else: out = self.model.generator.forward(decOut, attn["copy"].squeeze(0), srcMap) # beam x (tgt_vocab + extra_vocab) out = dataset.collapse_copy_scores( unbottle(out.data), batch, self.fields["tgt"].vocab) # beam x tgt_vocab out = out.log() # (c) Advance each beam. for j, b in enumerate(beam): b.advance(out[:, j], unbottle(attn["std"]).data[:, j]) decStates.beam_update(j, b.getCurrentOrigin(), beam_size) if "tgt" in batch.__dict__: allGold = self._runTarget(batch, dataset) else: allGold = [0] * batch_size # (3) Package everything up. allHyps, allScores, allAttn = [], [], [] for b in beam: n_best = self.opt.n_best scores, ks = b.sortFinished(minimum=n_best) hyps, attn = [], [] for i, (times, k) in enumerate(ks[:n_best]): hyp, att = b.getHyp(times, k) hyps.append(hyp) attn.append(att) allHyps.append(hyps) allScores.append(scores) allAttn.append(attn) return allHyps, allScores, allAttn, allGold
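# --- Sketch of the GNMT-style penalties behind GNMTGlobalScorer(alpha, beta) above ---
# The scorer's internals are not shown in this excerpt; the functions below reproduce
# the usual length and coverage penalties from Wu et al. (2016) that such a scorer
# typically applies when re-ranking finished hypotheses. Treat them as an assumption.
import torch

def length_penalty(length, alpha):
    return ((5.0 + length) / 6.0) ** alpha

def coverage_penalty(attn, beta):
    # attn: (tgt_len, src_len) accumulated attention of one hypothesis
    coverage = attn.sum(dim=0).clamp(max=1.0)
    return beta * torch.log(coverage.clamp(min=1e-12)).sum()

log_prob = torch.tensor(-7.5)                        # sum of token log-probs
attn = torch.softmax(torch.randn(9, 6), dim=-1)      # toy attention history
score = log_prob / length_penalty(9, alpha=0.6) + coverage_penalty(attn, beta=0.2)
print(float(score))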
def translateBatch(self, srcBatch, tgtBatch, alignBatch): batchSize = srcBatch[0].size(1) beamSize = self.opt.beam_size knntime = 0.0 # (1) run the encoder on the src encStates, context, fert = self.model.encoder(srcBatch, is_fert=True) init_fert = deepcopy(fert) init_fert = torch.max(init_fert, dim=-1)[1].float() # cov = torch.max(cov, dim=-1)[1].float() srcBatch = srcBatch[0] # drop the lengths needed for encoder rnnSize = context.size(2) encStates = (self.model._fix_enc_hidden(encStates[0]), self.model._fix_enc_hidden(encStates[1])) # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding padMask = srcBatch.data.eq(onmt.Constants.PAD).t() def applyContextMask(m): if isinstance(m, onmt.modules.GlobalAttention): m.applyMask(padMask) elif isinstance(m, onmt.modules.GlobalAttentionOriginal): m.applyMask(padMask) # (3) run the decoder to generate sentences, using beam search # Expand tensors for each beam. # decStates = encStates if self.opt.use_lm: lm_hidden = self.LangModel.initialize_hidden(1, batchSize) context = Variable(context.data.repeat(1, beamSize, 1)) # cov = Variable(torch.zeros((context.size(1),context.size(0))), requires_grad=True) # cov = cov.cuda() decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1)), Variable(encStates[1].data.repeat(1, beamSize, 1))) beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)] decOut = self.model.make_init_decoder_output(context) padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1) batchIdx = list(range(batchSize)) previous_batchIdx = list(range(batchSize)) remainingSents = batchSize for i in range(self.opt.max_sent_length): self.model.decoder.apply(applyContextMask) # Prepare decoder input. inputs = [] for i_b, b in enumerate(beam): if not b.done(): c = b.getCurrentState() if self.opt.replace_unk and i != 0 and (int(c)==onmt.Constants.UNK or int(c)== onmt.Constants.UNK + self.tgt_dict.size_uni()): tok = self.tgt_dict.getLabel(int(c)) if tok == onmt.Constants.UNK_WORD: b_i = previous_batchIdx[i_b] _src = self.src_dict.convertToLabels(srcBatch.data.t()[i_b], onmt.Constants.PAD_WORD) tok = self.replace_unk(i, tok, attn[b_i][0].data, _src) if tok in self.tgt_dict.labelToIdx.keys(): c = torch.LongTensor([self.tgt_dict.labelToIdx[tok]]).cuda() inputs.append(c) input_ = torch.stack(inputs).t().contiguous().view(1, -1) # input_ = torch.stack([b.getCurrentState() for b in beam # if not b.done()]).t().contiguous().view(1, -1) input_var = Variable(input_, volatile=True) # decOut, decStates, attn, alignBatch = self.model.decoder( # input_var, decStates, context, decOut, alignBatch) decOut, decStates, attn = self.model.decoder( input_var, decStates, context, decOut) # decOut: 1 x (beam*batch) x numWords decOut = decOut.squeeze(0) out = self.model.generator.forward(decOut) if self.opt.use_lm: lm_output, lm_hidden, _, _ = self.LangModel(input_var, lm_hidden) lm_output = torch.log(lm_output+1e-12) beg = time.time() original_scores = self._get_scores(out, self.target_embeddings) uni_scores, ngram_scores = self._get_scores(out) # print(fert) # print(attn.shape) fert_combined = torch.bmm(attn.unsqueeze(1)[:,:,:-1], fert[:,:-1]).squeeze(1) # print(fert_combined) if self.fert_dim > 2: uni_prob = torch.sum(fert_combined[:,:2], dim=1).unsqueeze(1) ngram_prob = torch.sum(fert_combined[:,2:], dim=1).unsqueeze(1) else: uni_prob = fert_combined[:,0].unsqueeze(1) ngram_prob = fert_combined[:,1].unsqueeze(1) scores = torch.cat((uni_scores*uni_prob, ngram_scores 
* ngram_prob),dim=1) # print(uni_prob.shape) # print(ngram_prob.shape) # print(uni_scores.shape) # print(ngram_scores.shape) # print(uni_prob.shape) # print(ngram_prob.shape) # print(scores.shape) # print(self.target_embeddings.weight.shape) # print(self.target_uni_embeddings.weight.shape) # print(self.target_ngram_embeddings.weight.shape) topk = scores.topk(beamSize, 1, True, True)[1].cpu().data.numpy() topk_uni = uni_scores.topk(beamSize, 1, True, True)[1].cpu().data.numpy() topk_ngram = ngram_scores.topk(beamSize, 1, True, True)[1].cpu().data.numpy() topk_original = original_scores.topk(beamSize, 1, True, True)[1].cpu().data.numpy() # print(topk) # print(topk_original) # print(topk_uni) # print(topk_ngram) # print(scores[0][:10]) # raise # print(uni_prob[:20]) # print(ngram_prob[:20]) # print(scores) # raise diff = time.time()-beg knntime += diff if self.opt.use_lm: scores += 0.2*lm_output.squeeze(0) wordLk = scores.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous() attn = attn.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous() out = out.view(remainingSents, beamSize, -1).contiguous() fert_beam = fert_combined.view(remainingSents, beamSize, -1).contiguous() init_fert_beam = init_fert.view(remainingSents, beamSize, -1).contiguous() active = [] for b in range(batchSize): if beam[b].done(): continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn.data[idx], out.data[idx], fert_beam.data[idx], init_fert_beam.data[idx]): active += [b] for decState in decStates: # iterate over h, c # layers x beam*sent x dim sentStates = decState.view( -1, beamSize, remainingSents, decState.size(2))[:, :, idx] sentStates.data.copy_( sentStates.data.index_select(1, beam[b].getCurrentOrigin())) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) previous_batchIdx = batchIdx batchIdx = {beam: idx for idx, beam in enumerate(active)} def updateActive(t): # select only the remaining active sentences view = t.data.view(-1, remainingSents, rnnSize) newSize = list(t.size()) newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents return Variable(view.index_select(1, activeIdx) \ .view(*newSize), volatile=True) decStates = (updateActive(decStates[0]), updateActive(decStates[1])) decOut = updateActive(decOut) context = updateActive(context) padMask = padMask.index_select(1, activeIdx) remainingSents = len(active) # (4) package everything up allHyp, allScores, allAttn, allOut, allOutScores, allFert, allInitFert = [], [], [], [], [], [], [] n_best = self.opt.n_best for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1) hyps, attn, out, score, fert, init_fert = zip(*[beam[b].getHyp(times, k) for (times, k) in ks[:n_best]]) attn = [a.index_select(1, valid_attn) for a in attn] allHyp += [hyps] allAttn += [attn] allOut += [out] allOutScores += [score] allFert += [fert] allInitFert += [init_fert] return allHyp, allScores, allAttn, allOut, allOutScores, allFert, allInitFert, knntime
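# --- Sketch of the heuristic behind the replace_unk call used above ---
# replace_unk itself is not defined in this excerpt; the standard technique, assumed
# here, is that when the decoder emits <unk>, the source token with the highest
# attention weight at that time step is copied into the output.
import torch

UNK_WORD = "<unk>"

def replace_unk(hyp_tokens, attn, src_tokens):
    """hyp_tokens: list of predicted strings; attn: (tgt_len, src_len) weights."""
    out = []
    for t, tok in enumerate(hyp_tokens):
        if tok == UNK_WORD:
            _, max_idx = attn[t].max(dim=0)
            tok = src_tokens[int(max_idx)]           # copy the most-attended source word
        out.append(tok)
    return out

src = ["der", "Zaun", "ist", "blau"]
hyp = ["the", UNK_WORD, "is", "blue"]
attn = torch.eye(4)                                  # toy alignment: position t -> source t
print(replace_unk(hyp, attn, src))                   # ['the', 'Zaun', 'is', 'blue']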
def translateBatch(self, batch): srcBatch, tgtBatch = batch batchSize = srcBatch.size(1) beamSize = self.opt.beam_size # (1) run the encoder on the src encStates, context = None, None if self.model.encoder.num_directions == 2: # bidirectional encoder is negatively impacted by padding # run with batch size 1 for improved translations # This will be resolved when variable length LSTMs are used instead encStates, context = self.model.encoder(srcBatch, hidden=encStates) else: # have to execute the encoder manually to deal with padding context = [] for srcBatch_t in srcBatch.split(1): encStates, context_t = self.model.encoder(srcBatch_t, hidden=encStates) batchPadIdx = srcBatch_t.data.squeeze(0).eq(onmt.Constants.PAD).nonzero() if batchPadIdx.nelement() > 0: batchPadIdx = batchPadIdx.squeeze(1) encStates[0].data.index_fill_(1, batchPadIdx, 0) encStates[1].data.index_fill_(1, batchPadIdx, 0) context += [context_t] context = torch.cat(context) rnnSize = context.size(2) encStates = (self.model._fix_enc_hidden(encStates[0]), self.model._fix_enc_hidden(encStates[1])) # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding padMask = srcBatch.data.eq(onmt.Constants.PAD).t() def applyContextMask(m): if isinstance(m, onmt.modules.GlobalAttention): m.applyMask(padMask) # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model goldScores = context.data.new(batchSize).zero_() if tgtBatch is not None: decStates = encStates decOut = self.model.make_init_decoder_output(context) self.model.decoder.apply(applyContextMask) initOutput = self.model.make_init_decoder_output(context) decOut, decStates, attn = self.model.decoder( tgtBatch[:-1], decStates, context, initOutput) for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data): gen_t = self.model.generator.forward(dec_t) tgt_t = tgt_t.unsqueeze(1) scores = gen_t.data.gather(1, tgt_t) scores.masked_fill_(tgt_t.eq(onmt.Constants.PAD), 0) goldScores += scores # (3) run the decoder to generate sentences, using beam search # Expand tensors for each beam. context = Variable(context.data.repeat(1, beamSize, 1), volatile=True) decStates = (Variable(encStates[0].data.repeat(1, beamSize, 1), volatile=True), Variable(encStates[1].data.repeat(1, beamSize, 1), volatile=True)) beam = [onmt.Beam(beamSize, self.opt.cuda) for k in range(batchSize)] decOut = self.model.make_init_decoder_output(context) padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1) batchIdx = list(range(batchSize)) remainingSents = batchSize for i in range(self.opt.max_sent_length): self.model.decoder.apply(applyContextMask) # Prepare decoder input. 
input = torch.stack([b.getCurrentState() for b in beam if not b.done]).t().contiguous().view(1, -1) decOut, decStates, attn = self.model.decoder( Variable(input, volatile=True), decStates, context, decOut) # decOut: 1 x (beam*batch) x numWords decOut = decOut.squeeze(0) out = self.model.generator.forward(decOut) # batch x beam x numWords wordLk = out.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous() attn = attn.view(beamSize, remainingSents, -1).transpose(0, 1).contiguous() active = [] for b in range(batchSize): if beam[b].done: continue idx = batchIdx[b] if not beam[b].advance(wordLk.data[idx], attn.data[idx]): active += [b] for decState in decStates: # iterate over h, c # layers x beam*sent x dim sentStates = decState.view( -1, beamSize, remainingSents, decState.size(2))[:, :, idx] sentStates.data.copy_( sentStates.data.index_select(1, beam[b].getCurrentOrigin())) if not active: break # in this section, the sentences that are still active are # compacted so that the decoder is not run on completed sentences activeIdx = self.tt.LongTensor([batchIdx[k] for k in active]) batchIdx = {beam: idx for idx, beam in enumerate(active)} def updateActive(t): # select only the remaining active sentences view = t.data.view(-1, remainingSents, rnnSize) newSize = list(t.size()) newSize[-2] = newSize[-2] * len(activeIdx) // remainingSents return Variable(view.index_select(1, activeIdx) \ .view(*newSize), volatile=True) decStates = (updateActive(decStates[0]), updateActive(decStates[1])) decOut = updateActive(decOut) context = updateActive(context) padMask = padMask.index_select(1, activeIdx) remainingSents = len(active) # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = self.opt.n_best for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1) hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]]) attn = [a.index_select(1, valid_attn) for a in attn] allHyp += [hyps] allAttn += [attn] return allHyp, allScores, allAttn, goldScores
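# --- Sketch of a typical _fix_enc_hidden, referenced above but not defined here ---
# A common implementation (assumed, not taken from this codebase) folds the two
# directions of a bidirectional encoder into the hidden dimension so the state matches
# a unidirectional decoder: (layers * 2, batch, dim) -> (layers, batch, 2 * dim).
import torch

def fix_enc_hidden(h):
    layers2, batch, dim = h.size()
    h = h.view(layers2 // 2, 2, batch, dim)          # split off the direction axis
    return torch.cat([h[:, 0], h[:, 1]], dim=2)      # concatenate fwd/bwd states

h = torch.randn(4, 3, 8)                             # 2 layers x 2 directions
print(fix_enc_hidden(h).shape)                       # torch.Size([2, 3, 16])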
def translateBatch(opt, model, batch, src_dict, tgt_dict, beam_accum): beamSize = opt.beam_size batchSize = batch.batchSize # (1) run the encoder on the src encStates, context, fertility_vals = model.encoder(batch.src) encStates = model.init_decoder_state(context, encStates) if fertility_vals is not None: fertility_vals = fertility_vals.repeat(beamSize * batchSize, 1) decoder = model.decoder attentionLayer = decoder.attn useMasking = True # This mask is applied to the attention model inside the decoder # so that the attention ignores source padding padMask = None if useMasking: padMask = batch.words().data.eq(onmt.Constants.PAD).t() def mask(padMask): if useMasking: attentionLayer.applyMask(padMask) # (2) if a target is specified, compute the 'goldScore' # (i.e. log likelihood) of the target under the model goldScores = context.data.new(batchSize).zero_() # (3) run the decoder to generate sentences, using beam search # Each hypothesis in the beam uses the same context # and initial decoder state context = Variable(context.data.repeat(1, beamSize, 1)) batch_src = Variable(batch.src.data.repeat(1, beamSize, 1)) decStates = encStates decStates.repeatBeam_(beamSize) beam = [onmt.Beam(beamSize, True) for _ in range(batchSize)] if useMasking: padMask = batch.src.data[:, :, 0].eq( onmt.Constants.PAD).t() \ .unsqueeze(0) \ .repeat(beamSize, 1, 1) # (3b) The main loop upper_bounds = None max_sent_length = 100 for i in range(max_sent_length): # (a) Run RNN decoder forward one step. mask(padMask) input = torch.stack([b.getCurrentState() for b in beam]) \ .t().contiguous().view(1, -1) input = Variable(input, volatile=True) decOut, decStates, attn, upper_bounds = model.decoder( input, batch_src, context, decStates, fertility_vals=fertility_vals, fert_dict=None, upper_bounds=decStates.attn_upper_bounds, test=True) # import pdb; pdb.set_trace() decOut = decOut.squeeze(0) # decOut: (beam*batch) x numWords attn["std"] = attn["std"].view(beamSize, batchSize, -1).transpose(0, 1).contiguous() # (b) Compute a vector of batch*beam word scores. out = model.generator.forward(decOut) word_scores = out.view(beamSize, batchSize, -1).transpose(0, 1).contiguous() # batch x beam x numWords # (c) Advance each beam. active = [] for b in range(batchSize): is_done = beam[b].advance(word_scores.data[b], attn["std"].data[b]) if not is_done: active += [b] decStates.beamUpdate_(b, beam[b].getCurrentOrigin(), beamSize) if not active: break # (4) package everything up allHyp, allScores, allAttn = [], [], [] n_best = 1 # If verbose is set, will output the n_best decoded sentences for b in range(batchSize): scores, ks = beam[b].sortBest() allScores += [scores[:n_best]] hyps, attn = [], [] for k in ks[:n_best]: hyp, att = beam[b].getHyp(k) hyps.append(hyp) attn.append(att) allHyp += [hyps] if useMasking: valid_attn = batch.src.data[:, b, 0].ne(onmt.Constants.PAD) \ .nonzero().squeeze(1) attn = [a.index_select(1, valid_attn) for a in attn] allAttn += [attn] # For debugging visualization. 
        if beam_accum:
            beam_accum["beam_parent_ids"].append(
                [t.tolist() for t in beam[b].prevKs])
            beam_accum["scores"].append(
                [["%4f" % s for s in t.tolist()]
                 for t in beam[b].allScores][1:])
            beam_accum["predicted_ids"].append(
                [[tgt_dict.getLabel(id) for id in t.tolist()]
                 for t in beam[b].nextYs][1:])

    # import pdb; pdb.set_trace()
    if fertility_vals is not None:
        cum_attn = allAttn[0][0].sum(0).squeeze(0).cpu().numpy()
        fert = fertility_vals.data[0, :].cpu().numpy()
        for c, f in zip(cum_attn, fert):
            print('%f (%f)' % (c, f))
        # print allAttn[0][0].sum(0)

    return allHyp, allScores, allAttn, goldScores