def forward(self, expl, s1_embed, s2_embed, mode, classif_lbl):
    # expl: Variable(seqlen x bsize x worddim)
    # s1/2_embed: Variable(bsize x sent_dim)
    assert mode in ['forloop', 'teacher'], mode
    batch_size = expl.size(1)
    assert_sizes(s1_embed, 2, [batch_size, self.sent_dim])
    assert_sizes(s2_embed, 2, [batch_size, self.sent_dim])
    assert_sizes(expl, 3, [expl.size(0), batch_size, self.word_emb_dim])

    # build the fixed context fed at every decoding step: [u, v] by default,
    # optionally extended with |u - v| and u * v, or reduced to only those two
    context = torch.cat([s1_embed, s2_embed], 1).unsqueeze(0)
    if self.use_diff_prod_sent_embed:
        context = torch.cat([
            s1_embed, s2_embed,
            torch.abs(s1_embed - s2_embed), s1_embed * s2_embed
        ], 1).unsqueeze(0)
    if self.only_diff_prod:
        context = torch.cat(
            [torch.abs(s1_embed - s2_embed), s1_embed * s2_embed],
            1).unsqueeze(0)
    assert_sizes(
        context, 3,
        [1, batch_size, self.context_mutiply_coef * self.sent_dim])

    # init decoder state from the concatenated sentence embeddings
    context_init = torch.cat([s1_embed, s2_embed], 1).unsqueeze(0)
    if self.use_init:
        if 2 * self.sent_dim != self.dec_rnn_dim:
            init_0 = self.proj_init(
                context_init.expand(self.n_layers_dec, batch_size,
                                    2 * self.sent_dim))
        else:
            init_0 = context_init
    else:
        init_0 = Variable(
            torch.zeros(self.n_layers_dec, batch_size,
                        self.dec_rnn_dim)).cuda()

    init_state = init_0
    if self.decoder_type == 'lstm':
        init_state = (init_0, init_0)
    self.decoder_rnn.flatten_parameters()

    if mode == "teacher":
        # teacher forcing: feed the gold explanation at every step
        input_dec = torch.cat([
            expl,
            context.expand(expl.size(0), batch_size,
                           self.context_mutiply_coef * self.sent_dim)
        ], 2)
        input_dec = self.proj_inp_dec(
            nn.Dropout(self.dpout_dec)(input_dec))
        out, _ = self.decoder_rnn(input_dec, init_state)
        dp_out = nn.Dropout(self.dpout_dec)(out)
        if not self.use_vocab_proj:
            return self.vocab_layer(dp_out)
        return self.vocab_layer(self.vocab_proj(dp_out))
    else:
        # greedy decoding: feed back the argmax word at each step
        assert classif_lbl is not None
        assert_sizes(classif_lbl, 1, [batch_size])
        pred_expls = []
        finished = []
        for i in range(batch_size):
            pred_expls.append("")
            finished.append(False)

        dec_inp_t = torch.cat([expl[0, :, :].unsqueeze(0), context], 2)
        dec_inp_t = self.proj_inp_dec(dec_inp_t)
        ht = init_state
        t = 0
        while t < self.max_T_decoder and not array_all_true(finished):
            t += 1
            word_embed = torch.zeros(1, batch_size, self.word_emb_dim)
            assert_sizes(dec_inp_t, 3, [1, batch_size, self.inp_dec_dim])
            dec_out_t, ht = self.decoder_rnn(dec_inp_t, ht)
            assert_sizes(dec_out_t, 3, [1, batch_size, self.dec_rnn_dim])
            if self.use_vocab_proj:
                out_t_proj = self.vocab_proj(dec_out_t)
                out_t = self.vocab_layer(out_t_proj).data
            else:
                # TODO: Use torch.stack with variables instead
                out_t = self.vocab_layer(dec_out_t).data
            assert_sizes(out_t, 3, [1, batch_size, self.n_vocab])
            i_t = torch.max(out_t, 2)[1]
            assert_sizes(i_t, 2, [1, batch_size])
            # array of batch_size words at the current timestep
            pred_words = get_keys_from_vals(i_t, self.word_index)
            assert len(pred_words) == batch_size, "pred_words " + str(
                len(pred_words)) + " batch_size " + str(batch_size)
            for i in range(batch_size):
                if pred_words[i] == '</s>':
                    finished[i] = True
                if not finished[i]:
                    pred_expls[i] += " " + pred_words[i]
                if t > 1:
                    word_embed[0, i] = torch.from_numpy(
                        self.word_vec[pred_words[i]])
                else:
                    # at the first step, feed the label predicted by the
                    # classifier instead of the decoder's own prediction
                    classif_label = get_key_from_val(
                        classif_lbl[i], NLI_DIC_LABELS)
                    assert classif_label in [
                        'entailment', 'contradiction', 'neutral'
                    ], classif_label
                    word_embed[0, i] = torch.from_numpy(
                        self.word_vec[classif_label])
            word_embed = Variable(word_embed.cuda())
            assert_sizes(word_embed, 3, [1, batch_size, self.word_emb_dim])
            dec_inp_t = self.proj_inp_dec(
                torch.cat([word_embed, context], 2))
        return pred_expls
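
# The decoders in this file lean on a few small helpers (assert_sizes,
# array_all_true, get_key_from_val, get_keys_from_vals) that live elsewhere
# in the repo. The definitions below are NOT the repo's code: they are a
# minimal sketch inferred from the call sites above, shown only so the
# decoder logic can be read and exercised in isolation.

def assert_sizes(t, n_dims, sizes):
    # assumed behavior: `t` must have exactly `n_dims` dimensions with the
    # sizes listed in `sizes`
    assert t.dim() == n_dims, str(t.dim()) + " != " + str(n_dims)
    for d in range(n_dims):
        assert t.size(d) == sizes[d], "dim " + str(d) + ": " + str(
            t.size(d)) + " != " + str(sizes[d])

def array_all_true(arr):
    # assumed behavior: True iff every element of the list is truthy
    return all(arr)

def get_key_from_val(val, dic):
    # assumed behavior: inverse lookup in a {word: index}-style dict
    for k, v in dic.items():
        if v == val:
            return k
    return None

def get_keys_from_vals(vals, dic):
    # assumed behavior: map a 1 x bsize tensor of vocabulary indices to the
    # corresponding words; a precomputed reverse dict would be faster
    return [get_key_from_val(int(v), dic) for v in vals.view(-1)]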
def forward(self, expl, enc_out_s1, enc_out_s2, s1_embed, s2_embed, mode,
            visualize):
    # expl: Variable(seqlen x bsize x worddim)
    # s1/2_embed: Variable(bsize x sent_dim)
    assert mode in ['forloop', 'teacher'], mode
    current_T_dec = expl.size(0)
    batch_size = expl.size(1)
    assert_sizes(s1_embed, 2, [batch_size, self.sent_dim])
    assert_sizes(s2_embed, 2, [batch_size, self.sent_dim])
    assert_sizes(expl, 3, [current_T_dec, batch_size, self.word_emb_dim])
    assert_sizes(enc_out_s1, 3,
                 [self.max_T_encoder, batch_size, 2 * self.enc_rnn_dim])
    assert_sizes(enc_out_s2, 3,
                 [self.max_T_encoder, batch_size, 2 * self.enc_rnn_dim])

    context = torch.cat([
        s1_embed, s2_embed,
        torch.abs(s1_embed - s2_embed), s1_embed * s2_embed
    ], 1).unsqueeze(0)
    assert_sizes(context, 3, [1, batch_size, 4 * self.sent_dim])

    # init decoder
    if self.use_init:
        init_0 = self.context_proj(context).expand(
            self.n_layers_dec, batch_size, self.dec_rnn_dim)
    else:
        init_0 = Variable(
            torch.zeros(self.n_layers_dec, batch_size,
                        self.dec_rnn_dim)).cuda()

    init_state = init_0
    if self.decoder_type == 'lstm':
        init_state = (init_0, init_0)
    self.decoder_rnn.flatten_parameters()

    out_expl = None
    state_t = init_state
    # the attention query starts from the projected sentence features and is
    # replaced by the decoder hidden state after each step
    context = self.context_proj(context)

    if mode == "teacher":
        for t_dec in range(current_T_dec):
            # attention over premise
            context1 = self.att_context_proj1(context).permute(1, 0, 2)
            assert_sizes(context1, 3, [batch_size, 1, self.att_hid_dim])
            inp_att_1 = self.att_ht_proj1(enc_out_s1).transpose(
                1, 0).transpose(2, 1)
            assert_sizes(
                inp_att_1, 3,
                [batch_size, self.att_hid_dim, self.max_T_encoder])
            dot_prod_att_1 = torch.bmm(context1, inp_att_1)
            assert_sizes(dot_prod_att_1, 3,
                         [batch_size, 1, self.max_T_encoder])
            att_weights_1 = self.softmax_att(dot_prod_att_1)
            assert_sizes(att_weights_1, 3,
                         [batch_size, 1, self.max_T_encoder])
            att_applied_1 = torch.bmm(
                att_weights_1,
                self.att_ht_before_weighting_proj1(enc_out_s1).permute(
                    1, 0, 2))
            assert_sizes(att_applied_1, 3,
                         [batch_size, 1, self.att_hid_dim])
            att_applied_perm_1 = att_applied_1.permute(1, 0, 2)
            assert_sizes(att_applied_perm_1, 3,
                         [1, batch_size, self.att_hid_dim])

            # attention over hypothesis
            context2 = self.att_context_proj2(context).permute(1, 0, 2)
            assert_sizes(context2, 3, [batch_size, 1, self.att_hid_dim])
            inp_att_2 = self.att_ht_proj2(enc_out_s2).transpose(
                1, 0).transpose(2, 1)
            assert_sizes(
                inp_att_2, 3,
                [batch_size, self.att_hid_dim, self.max_T_encoder])
            dot_prod_att_2 = torch.bmm(context2, inp_att_2)
            assert_sizes(dot_prod_att_2, 3,
                         [batch_size, 1, self.max_T_encoder])
            att_weights_2 = self.softmax_att(dot_prod_att_2)
            assert_sizes(att_weights_2, 3,
                         [batch_size, 1, self.max_T_encoder])
            att_applied_2 = torch.bmm(
                att_weights_2,
                self.att_ht_before_weighting_proj2(enc_out_s2).permute(
                    1, 0, 2))
            assert_sizes(att_applied_2, 3,
                         [batch_size, 1, self.att_hid_dim])
            att_applied_perm_2 = att_applied_2.permute(1, 0, 2)
            assert_sizes(att_applied_perm_2, 3,
                         [1, batch_size, self.att_hid_dim])

            input_dec = torch.cat([
                expl[t_dec].unsqueeze(0), att_applied_perm_1,
                att_applied_perm_2
            ], 2)
            input_dec = nn.Dropout(self.dpout_dec)(
                self.proj_inp_dec(input_dec))
            out_dec, state_t = self.decoder_rnn(input_dec, state_t)
            assert_sizes(out_dec, 3, [1, batch_size, self.dec_rnn_dim])

            # the new hidden state becomes the attention query
            if self.decoder_type == 'lstm':
                context = state_t[0]
            else:
                context = state_t

            if out_expl is None:
                out_expl = out_dec
            else:
                out_expl = torch.cat([out_expl, out_dec], 0)

        out_expl = self.vocab_layer(out_expl)
        assert_sizes(out_expl, 3,
                     [current_T_dec, batch_size, self.n_vocab])
        return out_expl
    else:
        pred_expls = []
        finished = []
        for i in range(batch_size):
            pred_expls.append("")
            finished.append(False)

        t_dec = 0
        word_t = expl[0].unsqueeze(0)
        while t_dec < self.max_T_decoder and not array_all_true(finished):
            assert_sizes(word_t, 3, [1, batch_size, self.word_emb_dim])
            word_embed = torch.zeros(1, batch_size, self.word_emb_dim)

            # attention over premise
            context1 = self.att_context_proj1(context).permute(1, 0, 2)
            assert_sizes(context1, 3, [batch_size, 1, self.att_hid_dim])
            inp_att_1 = self.att_ht_proj1(enc_out_s1).transpose(
                1, 0).transpose(2, 1)
            assert_sizes(
                inp_att_1, 3,
                [batch_size, self.att_hid_dim, self.max_T_encoder])
            dot_prod_att_1 = torch.bmm(context1, inp_att_1)
            assert_sizes(dot_prod_att_1, 3,
                         [batch_size, 1, self.max_T_encoder])
            att_weights_1 = self.softmax_att(dot_prod_att_1)
            assert_sizes(att_weights_1, 3,
                         [batch_size, 1, self.max_T_encoder])
            att_applied_1 = torch.bmm(
                att_weights_1,
                self.att_ht_before_weighting_proj1(enc_out_s1).permute(
                    1, 0, 2))
            assert_sizes(att_applied_1, 3,
                         [batch_size, 1, self.att_hid_dim])
            att_applied_perm_1 = att_applied_1.permute(1, 0, 2)
            assert_sizes(att_applied_perm_1, 3,
                         [1, batch_size, self.att_hid_dim])

            # attention over hypothesis
            context2 = self.att_context_proj2(context).permute(1, 0, 2)
            assert_sizes(context2, 3, [batch_size, 1, self.att_hid_dim])
            inp_att_2 = self.att_ht_proj2(enc_out_s2).transpose(
                1, 0).transpose(2, 1)
            assert_sizes(
                inp_att_2, 3,
                [batch_size, self.att_hid_dim, self.max_T_encoder])
            dot_prod_att_2 = torch.bmm(context2, inp_att_2)
            assert_sizes(dot_prod_att_2, 3,
                         [batch_size, 1, self.max_T_encoder])
            att_weights_2 = self.softmax_att(dot_prod_att_2)
            assert_sizes(att_weights_2, 3,
                         [batch_size, 1, self.max_T_encoder])
            att_applied_2 = torch.bmm(
                att_weights_2,
                self.att_ht_before_weighting_proj2(enc_out_s2).permute(
                    1, 0, 2))
            assert_sizes(att_applied_2, 3,
                         [batch_size, 1, self.att_hid_dim])
            att_applied_perm_2 = att_applied_2.permute(1, 0, 2)
            assert_sizes(att_applied_perm_2, 3,
                         [1, batch_size, self.att_hid_dim])

            input_dec = torch.cat(
                [word_t, att_applied_perm_1, att_applied_perm_2], 2)
            input_dec = self.proj_inp_dec(input_dec)

            # get one visualization from the current batch (example 0)
            if visualize:
                if t_dec == 0:
                    weights_1 = att_weights_1[0]
                    weights_2 = att_weights_2[0]
                else:
                    weights_1 = torch.cat([weights_1, att_weights_1[0]], 0)
                    weights_2 = torch.cat([weights_2, att_weights_2[0]], 0)

            # sanity check: attention weights sum to 1 for every example
            for ii in range(batch_size):
                assert abs(att_weights_1[ii].data.sum() - 1) < 1e-5, str(
                    att_weights_1[ii].data.sum())
                assert abs(att_weights_2[ii].data.sum() - 1) < 1e-5, str(
                    att_weights_2[ii].data.sum())

            out_t, state_t = self.decoder_rnn(input_dec, state_t)
            assert_sizes(out_t, 3, [1, batch_size, self.dec_rnn_dim])
            out_t = self.vocab_layer(out_t)
            if self.decoder_type == 'lstm':
                context = state_t[0]
            else:
                context = state_t
            i_t = torch.max(out_t, 2)[1].data
            assert_sizes(i_t, 2, [1, batch_size])
            # array of batch_size words at the current timestep
            pred_words = get_keys_from_vals(i_t, self.word_index)
            assert len(pred_words) == batch_size, "pred_words " + str(
                len(pred_words)) + " batch_size " + str(batch_size)
            for i in range(batch_size):
                if pred_words[i] == '</s>':
                    finished[i] = True
                if not finished[i]:
                    pred_expls[i] += " " + pred_words[i]
                word_embed[0, i] = torch.from_numpy(
                    self.word_vec[pred_words[i]])
            word_t = Variable(word_embed.cuda())
            t_dec += 1

        if visualize:
            assert weights_1.dim() == 2
            assert weights_1.size(1) == self.max_T_encoder
            assert weights_2.dim() == 2
            assert weights_2.size(1) == self.max_T_encoder
            pred_expls = [pred_expls, weights_1, weights_2]
        return pred_expls
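
# Both decoding modes above repeat the same dot-product attention, once over
# the premise and once over the hypothesis. The function below is NOT the
# repo's code: it is a minimal standalone sketch of one such attention step,
# with fresh nn.Linear layers standing in for att_context_proj*,
# att_ht_proj*, att_ht_before_weighting_proj* and softmax_att, so the shape
# bookkeeping can be checked in isolation.
def _attention_step_sketch():
    import torch
    import torch.nn as nn

    bsize, T_enc, enc_dim, att_hid, ctx_dim = 4, 7, 64, 32, 48
    ctx = torch.randn(1, bsize, ctx_dim)          # decoder query, 1 x bsize x ctx_dim
    enc_out = torch.randn(T_enc, bsize, enc_dim)  # encoder states, T_enc x bsize x enc_dim

    att_context_proj = nn.Linear(ctx_dim, att_hid)  # query projection
    att_ht_proj = nn.Linear(enc_dim, att_hid)       # key projection
    att_val_proj = nn.Linear(enc_dim, att_hid)      # value projection
    softmax_att = nn.Softmax(dim=2)                 # normalize over T_enc

    q = att_context_proj(ctx).permute(1, 0, 2)                # bsize x 1 x att_hid
    k = att_ht_proj(enc_out).transpose(1, 0).transpose(2, 1)  # bsize x att_hid x T_enc
    scores = torch.bmm(q, k)                                  # bsize x 1 x T_enc
    weights = softmax_att(scores)                             # each row sums to 1
    vals = att_val_proj(enc_out).permute(1, 0, 2)             # bsize x T_enc x att_hid
    att_vec = torch.bmm(weights, vals).permute(1, 0, 2)       # 1 x bsize x att_hid
    return att_vec, weights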
def eval_datasets_without_expl(esnli_net, which_set, data, word_vec,
                               word_emb_dim, batch_size, print_every,
                               current_run_dir):
    dict_labels = NLI_DIC_LABELS
    esnli_net.eval()
    correct = 0.
    correct_labels_expl = 0.

    s1 = data['s1']
    s2 = data['s2']
    label = data['label']
    label_expl = data['label_expl']

    headers = [
        "gold_label", "Premise", "Hypothesis", "pred_label", "pred_expl",
        "pred_lbl_decoder"
    ]
    expl_csv = os.path.join(
        current_run_dir, time.strftime("%d:%m") + "_" +
        time.strftime("%H:%M:%S") + "_" + which_set + ".csv")
    remove_file(expl_csv)
    expl_f = open(expl_csv, "a")
    writer = csv.writer(expl_f)
    writer.writerow(headers)

    for i in range(0, len(s1), batch_size):
        # prepare batch
        s1_batch, s1_len = get_batch(s1[i:i + batch_size], word_vec)
        s2_batch, s2_len = get_batch(s2[i:i + batch_size], word_vec)
        current_bs = s1_batch.size(1)
        assert_sizes(s1_batch, 3,
                     [s1_batch.size(0), current_bs, word_emb_dim])
        assert_sizes(s2_batch, 3,
                     [s2_batch.size(0), current_bs, word_emb_dim])
        s1_batch, s2_batch = Variable(s1_batch.cuda()), Variable(
            s2_batch.cuda())
        tgt_label_batch = Variable(
            torch.LongTensor(label[i:i + batch_size])).cuda()
        tgt_label_expl_batch = label_expl[i:i + batch_size]

        # <s> embedding as the first decoder input
        expl_t0 = Variable(
            torch.from_numpy(word_vec['<s>']).float().unsqueeze(0).expand(
                current_bs, word_emb_dim).unsqueeze(0)).cuda()
        assert_sizes(expl_t0, 3, [1, current_bs, word_emb_dim])

        # model forward
        pred_expls, out_lbl = esnli_net((s1_batch, s1_len),
                                        (s2_batch, s2_len),
                                        expl_t0,
                                        mode="forloop")
        assert len(pred_expls) == current_bs, "pred_expls: " + str(
            len(pred_expls)) + " current_bs: " + str(current_bs)

        # accuracy of the label emitted as the first explanation token
        for b in range(len(pred_expls)):
            assert tgt_label_expl_batch[b] in [
                'entailment', 'neutral', 'contradiction'
            ]
            if len(pred_expls[b]) > 0:
                words = pred_expls[b].strip().split(" ")
                if words[0] == tgt_label_expl_batch[b]:
                    correct_labels_expl += 1

        # accuracy of the classifier head
        pred = out_lbl.data.max(1)[1]
        correct += pred.long().eq(tgt_label_batch.data.long()).cpu().sum()

        # write one csv row of predictions per example; see `headers` for
        # the column order
        for j in range(len(pred_expls)):
            # the decoder emits the label as the first token, so split it
            # off from the rest of the explanation
            pred_tokens = pred_expls[j].strip().split(" ")
            row = []
            row.append(get_key_from_val(label[i + j], dict_labels))
            row.append(' '.join(s1[i + j][1:-1]))  # strip <s> and </s>
            row.append(' '.join(s2[i + j][1:-1]))
            row.append(get_key_from_val(pred[j], dict_labels))
            row.append(' '.join(pred_tokens[1:]))
            row.append(pred_tokens[0])
            writer.writerow(row)

        # print example
        if i % print_every == 0:
            print(which_set.upper() + " example: ")
            print("Premise: ", ' '.join(s1[i]), " LENGTH: ", s1_len[0])
            print("Hypothesis: ", ' '.join(s2[i]), " LENGTH: ", s2_len[0])
            print("Gold label: ", get_key_from_val(label[i], dict_labels))
            print("Predicted label: ",
                  get_key_from_val(pred[0], dict_labels))
            print("Predicted explanation: ", pred_expls[0], "\n\n\n")

    eval_acc = round(100 * correct / len(s1), 2)
    eval_acc_label_expl = round(100 * correct_labels_expl / len(s1), 2)
    print(which_set.upper() + " no train ", eval_acc, '\n\n\n')
    expl_f.close()
    return eval_acc, eval_acc_label_expl
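
# eval_datasets_without_expl scores the decoder's label separately from the
# classifier head: the first word of each predicted explanation is compared
# against the gold label word. A tiny self-contained illustration of that
# scoring rule, with made-up predictions (not repo data):
def _label_from_expl_sketch():
    pred_expls_demo = [" entailment a dog is an animal",
                       " neutral the man might be tired"]
    gold_lbls_demo = ['entailment', 'contradiction']
    correct_lbl = 0.
    for pred_str, gold in zip(pred_expls_demo, gold_lbls_demo):
        words = pred_str.strip().split(" ")  # first word is the label
        if words[0] == gold:
            correct_lbl += 1
    return round(100 * correct_lbl / len(gold_lbls_demo), 2)  # 50.0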