def predict(self, enc_hidden, context, context_lengths, batch, beam_size,
            max_code_length, generator, replace_unk, vis_params):
    # This decoder does not have input feeding. Parent state replaces that.
    decState = DecoderState(
        enc_hidden,  # encoder hidden
        Variable(torch.zeros(1, 1, self.opt.rnn_size).cuda(),
                 requires_grad=False)  # parent state
    )

    # Repeat everything beam_size times.
    def rvar(a, beam_size):
        return Variable(a.repeat(beam_size, 1, 1), volatile=True)

    context = rvar(context.data, beam_size)
    context_lengths = context_lengths.repeat(beam_size)
    decState.repeat_beam_size_times(beam_size)  # TODO: get back to this

    # Use only one beam
    beam = TreeBeam(beam_size, True, self.vocabs, self.opt.rnn_size)

    for count in range(0, max_code_length):
        # We will break when we have the required number of terminals,
        # to be consistent with seq2seq.
        if beam.done():  # TODO: fix b.done
            break

        # Construct batch x beam_size next words.
        # Get all the pending current beam words and arrange for forward.
        # Uses the start symbol in the beginning.
        inp = beam.getCurrentState()  # Should return a batch of the frontier

        # Run one step; decState gets automatically updated.
        output, attn, copy_attn = self.forward(inp, context, context_lengths,
                                               decState)
        scores = generator(bottle(output), bottle(copy_attn),
                           batch['src_map'], inp)  # generator needs the non-terminals
        out = generator.collapseCopyScores(
            unbottle(scores.data.clone(), beam_size),
            batch)  # needs seq2seq from batch
        out = out.log()

        # beam x tgt_vocab
        beam.advance(out[:, 0], attn.data[:, 0], output)
        decState.beam_update(beam.getCurrentOrigin(), beam_size)

    score, times, k = beam.getFinal()  # times is the length of the prediction
    # hyp, att = beam.getHyp(times, k)
    goldNl = self.vocabs['seq2seq'].addStartOrEnd(batch['raw_seq2seq'][0])  # because batch = 1
    goldCode = self.vocabs['code'].addStartOrEnd(batch['raw_code'][0])
    # goldProd = self.vocabs['next_rules'].addStartOrEnd(batch['raw_next_rules'][0])

    predictions = []
    for score, times, k in beam.finished:
        hyp, att = beam.getHyp(times, k)
        predSent = self.buildTargetTokens(
            hyp, self.vocabs, goldNl, att, batch['seq2seq_vocab'][0],
            replace_unk)
        predSent = ProdDecoder.rulesToCode(predSent)
        predictions.append(Prediction(goldNl, goldCode, predSent, att, score))

    return predictions
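# Illustration only (not from the original code): a minimal sketch of the beam
# tiling that rvar performs above, assuming a batch-first context tensor of
# shape (batch=1, src_len, rnn_size). Each beam hypothesis then attends over
# its own copy of the encoder context.
import torch

beam_size_toy = 3
context_toy = torch.randn(1, 7, 4)                    # (batch=1, src_len, rnn_size)
tiled = context_toy.repeat(beam_size_toy, 1, 1)       # (beam_size, src_len, rnn_size)
assert tiled.shape == (beam_size_toy, 7, 4)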
def computeLoss(self, scores, batch):
    batch_size = batch['seq2seq'].size(0)
    target = Variable(batch['next_rules'].contiguous().cuda().view(-1),
                      requires_grad=False)
    if self.opt.decoder_type == "prod":
        align = Variable(
            batch['next_rules_in_src_nums'].contiguous().cuda().view(-1),
            requires_grad=False)
        align_unk = batch['seq2seq_vocab'][0].stoi['<unk>']
    elif self.opt.decoder_type in ["concode"]:
        align = Variable(
            batch['concode_next_rules_in_src_nums'].contiguous().cuda().view(-1),
            requires_grad=False)
        align_unk = batch['concode_vocab'][0].stoi['<unk>']

    offset = len(self.vocabs['next_rules'])
    out = scores.gather(1, align.view(-1, 1) + offset).view(-1).mul(
        align.ne(align_unk).float())  # all where copy is not unk
    tmp = scores.gather(1, target.view(-1, 1)).view(-1)

    unk_mask = target.data.ne(self.tgt_unks[0])
    for unk in self.tgt_unks:
        unk_mask = unk_mask & target.data.ne(unk)
    unk_mask_var = Variable(unk_mask, requires_grad=False)
    inv_unk_mask_var = Variable(~unk_mask, requires_grad=False)

    out = out + 1e-20 + tmp.mul(unk_mask_var.float()) + \
        tmp.mul(align.eq(align_unk).float()).mul(inv_unk_mask_var.float())  # copy and target are unks

    # Drop padding.
    loss = -out.log().mul(target.ne(self.tgt_pad).float()).sum()

    scores_data = scores.data.clone()
    target_data = target.data.clone()  # computeLoss populates this

    scores_data = self.collapseCopyScores(
        unbottle(scores_data, batch_size), batch)
    scores_data = bottle(scores_data)

    # Correct target copy token instead of <unk>:
    #   tgt[i] = align[i] + len(tgt_vocab)
    #   for i such that tgt[i] == 0 and align[i] != 0.
    # When the target is <unk> but can be copied, make sure we get the copy index right.
    correct_mask = inv_unk_mask_var.data * align.data.ne(align_unk)
    correct_copy = (align.data + offset) * correct_mask.long()
    target_data = (target_data * (~correct_mask).long()) + correct_copy

    pred = scores_data.max(1)[1]
    non_padding = target_data.ne(self.tgt_pad)
    num_correct = pred.eq(target_data).masked_select(non_padding).sum()
    return loss, non_padding.sum(), num_correct  # , stats
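# Illustration only (toy sizes, not from the original code): how the two gather
# calls above split the extended softmax into a "generate" score (an index into
# the target vocabulary) and a "copy" score (offset + align, i.e. a source
# position appended after the vocabulary). Index 0 plays the role of <unk> here.
import torch

offset_toy = 4                                # stand-in for len(self.vocabs['next_rules'])
scores_toy = torch.rand(2, offset_toy + 3)    # (steps, tgt_vocab + src_len)
target_toy = torch.tensor([2, 0])             # gold ids; 0 is <unk> in this sketch
align_toy = torch.tensor([1, 0])              # source position each step could copy from
align_unk_toy = 0

copy_score = scores_toy.gather(1, align_toy.view(-1, 1) + offset_toy).view(-1)
copy_score = copy_score * align_toy.ne(align_unk_toy).float()  # zero where nothing to copy
gen_score = scores_toy.gather(1, target_toy.view(-1, 1)).view(-1)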
def forward(self, batch):
    # Initial parent states for the Prod decoder.
    batch_size = batch['seq2seq'].size(0)
    if self.opt.decoder_type == "concode":
        batch['parent_states'] = {}
        for j in range(0, batch_size):
            batch['parent_states'][j] = {}
            if self.opt.decoder_type in ["prod", "concode"]:
                batch['parent_states'][j][0] = Variable(
                    torch.zeros(1, 1, self.opt.decoder_rnn_size).cuda(),
                    requires_grad=False)

    context, context_lengths, enc_hidden = self.encoder(batch)
    decInitState = DecoderState(
        enc_hidden,
        Variable(torch.zeros(batch_size, 1, self.opt.decoder_rnn_size).cuda(),
                 requires_grad=False))
    output, attn, copy_attn = self.decoder(batch, context, context_lengths,
                                           decInitState)

    if self.opt.decoder_type == "concode":
        del batch['parent_states']

    # Other generators will not use the extra parameters.
    # Let the generator put the src_map in cuda if it uses it.
    # TODO: Make src_map a Variable again in the generator
    src_map = torch.zeros(0, 0)
    if self.opt.decoder_type == "concode":
        src_map = torch.cat((batch['concode_src_map_vars'],
                             batch['concode_src_map_methods']), 1)

    scores = self.generator(
        bottle(output), bottle(copy_attn),
        src_map if self.opt.encoder_type in ["concode"] else batch['src_map'],
        batch)
    loss, total, correct = self.generator.computeLoss(scores, batch)

    return loss, Statistics(loss.data.item(), total.item(), correct.item(),
                            self.encoder.n_src_words)
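# Hypothetical helpers (an assumption, not taken from this repository): bottle
# and unbottle are used throughout but not defined in these excerpts. In
# OpenNMT-style code they merge and split the (batch, time) dimensions so the
# generator can operate on a flat (batch * time, dim) matrix.
import torch

def bottle_sketch(v):
    # (batch, time, dim) -> (batch * time, dim)
    return v.view(-1, v.size(2))

def unbottle_sketch(v, batch_size):
    # (batch * time, dim) -> (batch, time, dim)
    return v.view(batch_size, -1, v.size(1))

x = torch.randn(2, 3, 5)
assert unbottle_sketch(bottle_sketch(x), 2).shape == x.shape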
def computeLoss(self, scores, batch):
    """
    Args:
        batch: the current batch.
        target: the validation target to compare the output with.
        align: the alignment info.
    """
    batch_size = batch['seq2seq'].size(0)
    self.target = Variable(
        shiftLeft(batch['code'].cuda(), self.tgt_padding_idx).view(-1),
        requires_grad=False)
    align = Variable(
        shiftLeft(batch['code_in_src_nums'].cuda(),
                  self.vocabs['seq2seq'].stoi['<blank>']).view(-1),
        requires_grad=False)
    # All individual vocabs have the same unk index.
    align_unk = batch['seq2seq_vocab'][0].stoi['<unk>']
    loss = self.criterion(scores, self.target, align, align_unk)

    scores_data = scores.data.clone()
    target_data = self.target.data.clone()  # computeLoss populates this
    if self.opt.copy_attn:
        scores_data = self.collapseCopyScores(
            unbottle(scores_data, batch_size), batch)
        scores_data = bottle(scores_data)

        # Correct target copy token instead of <unk>:
        #   tgt[i] = align[i] + len(tgt_vocab)
        #   for i such that tgt[i] == 0 and align[i] != 0.
        # When the target is <unk> but can be copied, make sure we get the copy index right.
        correct_mask = target_data.eq(
            self.tgt_unk_idx) * align.data.ne(align_unk)
        correct_copy = (align.data + self.tgt_dict_size) * correct_mask.long()
        target_data = (target_data * (1 - correct_mask).long()) + correct_copy

    pred = scores_data.max(1)[1]
    non_padding = target_data.ne(self.tgt_padding_idx)
    num_correct = pred.eq(target_data).masked_select(non_padding).sum()
    return loss, non_padding.sum(), num_correct  # , stats
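# Illustration only (toy indices): the target-correction step above rewrites a
# gold <unk> token to its extended-vocabulary copy index, so accuracy is
# measured against the copy the model could make. Here <unk> is index 0 in both
# the target vocabulary and the alignment, and the target vocabulary has 4 entries.
import torch

tgt_unk_idx_toy, tgt_dict_size_toy, align_unk_toy = 0, 4, 0
target_toy = torch.tensor([2, 0, 0])          # second and third tokens are <unk>
align_toy = torch.tensor([0, 3, 0])           # only the second one is copyable

correct_mask = (target_toy.eq(tgt_unk_idx_toy) & align_toy.ne(align_unk_toy)).long()
correct_copy = (align_toy + tgt_dict_size_toy) * correct_mask
target_toy = target_toy * (1 - correct_mask) + correct_copy
assert target_toy.tolist() == [2, 7, 0]       # 7 = 3 + tgt_dict_size_toy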
def predict(self, enc_hidden, context, context_lengths, batch, beam_size,
            max_code_length, generator, replace_unk, vis_params):
    # This decoder does not have input feeding. Parent state replaces that.
    decState = DecoderState(
        enc_hidden,  # encoder hidden
        Variable(torch.zeros(1, 1, self.opt.decoder_rnn_size).cuda(),
                 requires_grad=False)  # parent state
    )

    # Repeat everything beam_size times.
    def rvar(a, beam_size):
        return Variable(a.repeat(beam_size, 1, 1), volatile=True)

    context = tuple(
        rvar(context[i].data, beam_size) for i in range(0, len(context)))
    context_lengths = tuple(
        context_lengths[i].repeat(beam_size, 1)
        for i in range(0, len(context_lengths)))
    decState.repeat_beam_size_times(beam_size)

    # Use only one beam
    beam = TreeBeam(beam_size, True, self.vocabs, self.opt.decoder_rnn_size)

    for count in range(0, max_code_length):
        # We will break when we have the required number of terminals,
        # to be consistent with seq2seq.
        if beam.done():
            break

        # Construct batch x beam_size next words.
        # Get all the pending current beam words and arrange for forward.
        # Uses the start symbol in the beginning.
        inp = beam.getCurrentState()  # Should return a batch of the frontier

        # Run one step; decState gets automatically updated.
        output, attn, copy_attn = self.forward(inp, context, context_lengths,
                                               decState)

        src_map = torch.zeros(0, 0)
        if self.opt.var_names:
            src_map = torch.cat((src_map, batch['concode_src_map_vars']), 1)
        if self.opt.method_names:
            src_map = torch.cat((src_map, batch['concode_src_map_methods']), 1)

        scores = generator(bottle(output), bottle(copy_attn), src_map,
                           inp)  # generator needs the non-terminals
        out = generator.collapseCopyScores(
            unbottle(scores.data.clone(), beam_size),
            batch)  # needs seq2seq from batch
        out = out.log()

        # beam x tgt_vocab
        beam.advance(out[:, 0], attn.data[:, 0], output)
        decState.beam_update(beam.getCurrentOrigin(), beam_size)

    pred_score_total = 0
    pred_words_total = 0

    score, times, k = beam.getFinal()  # times is the length of the prediction
    hyp, att = beam.getHyp(times, k)

    goldNl = []
    if self.opt.var_names:
        goldNl += batch['concode_var'][0]  # because batch = 1
    if self.opt.method_names:
        goldNl += batch['concode_method'][0]  # because batch = 1
    goldCode = self.vocabs['code'].addStartOrEnd(batch['raw_code'][0])

    predSent, copied_tokens, replaced_tokens = self.buildTargetTokens(
        hyp, self.vocabs, goldNl, att, batch['concode_vocab'][0], replace_unk)
    predSent = ConcodeDecoder.rulesToCode(predSent)
    pred_score_total += score
    pred_words_total += len(predSent)

    return Prediction(goldNl, goldCode, predSent, att)
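# Illustration only (assumed shapes): when both variable names and method names
# are enabled, the two copy maps are concatenated along the source-token
# dimension, so copy indices for method tokens follow those for variable tokens.
# The (batch, src_len, copy_vocab) layout is an assumption for this sketch.
import torch

src_map_vars_toy = torch.zeros(1, 4, 6)       # (batch, n_var_tokens, copy_vocab)
src_map_methods_toy = torch.zeros(1, 3, 6)    # (batch, n_method_tokens, copy_vocab)
combined = torch.cat((src_map_vars_toy, src_map_methods_toy), 1)
assert combined.shape == (1, 7, 6)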