def evaluate_snli_final(esnli_net, criterion_expl, dataset, data, expl_no_unk,
                        word_vec, word_index, batch_size, print_every,
                        current_run_dir):
    """Evaluate the model on an SNLI split and dump its predictions to a CSV.

    For every batch the model is run in "teacher" mode once per reference
    explanation (three times) to accumulate the explanation loss, and once in
    "forloop" mode to obtain the decoded explanations. One CSV row per example
    is written to a timestamped file inside ``current_run_dir``; that file is
    then handed to ``bleu_prediction``.

    Args:
        esnli_net: model, called as
            ``esnli_net((s1, s1_len), (s2, s2_len), expl_batch, mode=...)``
            and expected to return ``(explanations, label_output)``.
        criterion_expl: loss applied to the flattened explanation output and
            the flattened target indices.
        dataset: either ``'snli_dev'`` or ``'snli_test'`` (asserted).
        data: mapping with the keys ``'s1'``, ``'s2'``, ``'expl_1'``,
            ``'expl_2'``, ``'expl_3'``, ``'label'`` and ``'label_expl'``.
        expl_no_unk: mapping with the keys ``'expl_1'``..``'expl_3'``; the
            per-example entries are written to the CSV as-is and the whole
            mapping is passed to ``bleu_prediction``.
        word_vec: forwarded to ``get_batch``.
        word_index: forwarded to ``get_target_expl_batch`` and
            ``get_sentence_from_indices``.
        batch_size: number of examples per batch.
        print_every: an example is printed for every batch whose start index
            is a multiple of this value.
        current_run_dir: directory in which the CSV file is created.

    Returns:
        Tuple ``(eval_acc, bleu_score, eval_ppl, eval_acc_label_expl)``:
        classifier accuracy in percent, 100 * BLEU rounded to 2 decimals,
        ``exp(summed loss / number of target words)`` rounded to 2 decimals,
        and the percentage of examples whose first decoded token equals the
        entry of ``data['label_expl']``.
    """
    assert dataset in ['snli_dev', 'snli_test']
    print(dataset.upper())
    esnli_net.eval()

    correct = 0.
    correct_labels_expl = 0.
    cum_test_ppl = 0
    cum_test_n_words = 0
    valid_labels = ['entailment', 'neutral', 'contradiction']

    headers = [
        "gold_label", "Premise", "Hypothesis", "pred_label", "pred_expl",
        "pred_lbl_decoder", "Expl_1", "Expl_2", "Expl_3"
    ]
    expl_csv = os.path.join(
        current_run_dir,
        time.strftime("%d:%m") + "_" + time.strftime("%H:%M:%S") + "_" +
        dataset + ".csv")
    remove_file(expl_csv)

    s1 = data['s1']
    s2 = data['s2']
    # The three reference explanations in order; indexed directly instead of
    # being looked up by name through eval().
    reference_expls = (data['expl_1'], data['expl_2'], data['expl_3'])
    label = data['label']
    label_expl = data['label_expl']

    # The context manager guarantees the CSV is closed even if a batch fails.
    with open(expl_csv, "a") as expl_f:
        writer = csv.writer(expl_f)
        writer.writerow(headers)

        for i in range(0, len(s1), batch_size):
            show_example = i % print_every == 0

            # prepare batch
            s1_batch, s1_len = get_batch(s1[i:i + batch_size], word_vec)
            s2_batch, s2_len = get_batch(s2[i:i + batch_size], word_vec)
            s1_batch = Variable(s1_batch.cuda())
            s2_batch = Variable(s2_batch.cuda())
            tgt_label_batch = Variable(
                torch.LongTensor(label[i:i + batch_size])).cuda()
            tgt_label_expl_batch = label_expl[i:i + batch_size]

            # print example
            if show_example:
                print("Final SNLI example from " + dataset)
                print("Sentence1: ", ' '.join(s1[i]), " LENGHT: ", s1_len[0])
                print("Sentence2: ", ' '.join(s2[i]), " LENGHT: ", s2_len[0])
                print("Gold label: ",
                      get_key_from_val(label[i], NLI_DIC_LABELS))

            # Slots 0-2: label outputs of the teacher-forced passes,
            # slot 3: label output of the free-running ("forloop") pass.
            out_lbl = [0, 1, 2, 3]
            for index, expl in enumerate(reference_expls, 1):
                input_expl_batch, _ = get_batch(expl[i:i + batch_size],
                                                word_vec)
                # drop the last time step of the decoder input
                input_expl_batch = Variable(input_expl_batch[:-1].cuda())

                if show_example:
                    print("Explanation " + str(index) + " : ",
                          ' '.join(expl[i]))
                    print("Predicted label by decoder " + str(index) + " : ",
                          ' '.join(expl[i][0]))

                tgt_expl_batch, lens_tgt_expl = get_target_expl_batch(
                    expl[i:i + batch_size], word_index)
                assert tgt_expl_batch.dim() == 2, \
                    "tgt_expl_batch.dim()=" + str(tgt_expl_batch.dim())
                tgt_expl_batch = Variable(tgt_expl_batch).cuda()

                if show_example:
                    print("Target expl " + str(index) + " : ",
                          get_sentence_from_indices(word_index,
                                                    tgt_expl_batch[:, 0]),
                          " LENGHT: ", lens_tgt_expl[0])

                # model forward, teacher forcing on the reference explanation
                out_expl, out_lbl[index - 1] = esnli_net(
                    (s1_batch, s1_len), (s2_batch, s2_len),
                    input_expl_batch,
                    mode="teacher")

                # ppl: flatten the first two dimensions of output and target
                loss_expl = criterion_expl(
                    out_expl.view(out_expl.size(0) * out_expl.size(1), -1),
                    tgt_expl_batch.view(
                        tgt_expl_batch.size(0) * tgt_expl_batch.size(1)))
                cum_test_n_words += lens_tgt_expl.sum()
                cum_test_ppl += loss_expl.data[0]
                answer_idx = torch.max(out_expl, 2)[1]

                if show_example:
                    print("Decoded explanation " + str(index) + " : ",
                          get_sentence_from_indices(word_index,
                                                    answer_idx[:, 0]))
                    print("\n")

            # NOTE(review): input_expl_batch is the one built for the third
            # reference explanation; presumably "forloop" mode only needs its
            # first step -- confirm against the model.
            pred_expls, out_lbl[3] = esnli_net((s1_batch, s1_len),
                                               (s2_batch, s2_len),
                                               input_expl_batch,
                                               mode="forloop")
            if show_example:
                print("Fully decoded explanation: ",
                      pred_expls[0].strip().split()[1:-1])
                print("Predicted label from decoder: ",
                      pred_expls[0].strip().split()[0])

            # accuracy of the label emitted as first token by the decoder
            for b, decoded in enumerate(pred_expls):
                assert tgt_label_expl_batch[b] in valid_labels
                if len(decoded) > 0:
                    words = decoded.strip().split()
                    assert words[0] in valid_labels, words[0]
                    if words[0] == tgt_label_expl_batch[b]:
                        correct_labels_expl += 1

            # every pass must yield the same label output
            assert (torch.equal(out_lbl[0], out_lbl[1]))
            assert (torch.equal(out_lbl[1], out_lbl[2]))
            assert (torch.equal(out_lbl[2], out_lbl[3]))

            # accuracy
            pred = out_lbl[0].data.max(1)[1]
            if show_example:
                print("Predicted label from classifier: ",
                      get_key_from_val(pred[0], NLI_DIC_LABELS), "\n\n\n")
            correct += pred.long().eq(tgt_label_batch.data.long()).cpu().sum()

            # write csv row of predictions, in the order given by `headers`
            for j, decoded in enumerate(pred_expls):
                words = decoded.strip().split()
                assert words[0] in [
                    'entailment', 'contradiction', 'neutral'
                ], words[0]
                writer.writerow([
                    get_key_from_val(label[i + j], NLI_DIC_LABELS),
                    ' '.join(s1[i + j][1:-1]),
                    ' '.join(s2[i + j][1:-1]),
                    get_key_from_val(pred[j], NLI_DIC_LABELS),
                    ' '.join(words[1:-1]),
                    words[0],
                    expl_no_unk['expl_1'][i + j],
                    expl_no_unk['expl_2'][i + j],
                    expl_no_unk['expl_3'][i + j],
                ])

    eval_acc = round(100 * correct / len(s1), 2)
    eval_acc_label_expl = round(100 * correct_labels_expl / len(s1), 2)
    eval_ppl = math.exp(cum_test_ppl / cum_test_n_words)

    # The CSV is closed at this point, so bleu_prediction reads complete data.
    bleu_score = 100 * bleu_prediction(expl_csv, expl_no_unk)

    print(dataset.upper() + ' SNLI accuracy: ', eval_acc, 'bleu score: ',
          bleu_score, 'ppl: ', eval_ppl, 'eval_acc_label_expl: ',
          eval_acc_label_expl)
    return eval_acc, round(bleu_score, 2), round(eval_ppl,
                                                 2), eval_acc_label_expl
def _save_best_checkpoint(model_prefix_name, state_dict_prefix_name, suffix):
    """Save the full model and its state dict, replacing earlier best files.

    Both files are written into ``current_run_dir``. After each save, every
    other file whose path starts with the same prefix is deleted.

    Args:
        model_prefix_name: file-name prefix of the pickled model.
        state_dict_prefix_name: file-name prefix of the state-dict checkpoint.
        suffix: remainder of the file name, shared by both files.

    Returns:
        Path of the state-dict checkpoint that was written.
    """
    # full model, saved with torch.save
    model_prefix = os.path.join(current_run_dir, model_prefix_name)
    model_path = model_prefix + suffix
    torch.save(esnli_net, model_path)
    for stale in glob.glob(model_prefix + '*'):
        if stale != model_path:
            os.remove(stale)

    # model.state_dict() together with the configuration
    state_dict_prefix = os.path.join(current_run_dir, state_dict_prefix_name)
    state_dict_path = state_dict_prefix + suffix
    state = {
        'model_state': esnli_net.state_dict(),
        'config_model': config_nli_model,
        'params': params
    }
    torch.save(state, state_dict_path)
    for stale in glob.glob(state_dict_prefix + '*'):
        if stale != state_dict_path:
            os.remove(stale)

    return state_dict_path


def evaluate_dev(epoch):
    """Evaluate on the SNLI dev set and update checkpoints / training state.

    Works on module-level state: reads ``esnli_net``, ``snli_dev``,
    ``params``, ``word_vec``, ``word_index``, ``criterion_expl``,
    ``optimizer`` and ``current_run_dir``; appends to ``dev_ppl``; and
    rebinds the globals ``val_acc_best``, ``val_ppl_best``, ``stop_training``
    and ``last_improvement_epoch``.

    A checkpoint is saved when ``params.alpha > 0.01`` and the accuracy
    improved, or when ``params.alpha < 0.01`` and the dev perplexity value
    decreased. Otherwise the learning rate is divided by ``params.lrshrink``
    (only if ``'sgd'`` is in ``params.optimizer``) and the early-stopping
    conditions are checked.

    Args:
        epoch: number of the epoch that was just trained; used in log
            messages, checkpoint file names and early stopping.

    Returns:
        Tuple ``(eval_acc, state_dict_path)``: dev accuracy in percent
        (rounded to 2 decimals) and the path of the state-dict checkpoint
        written during this call, or ``None`` if nothing was saved.
    """
    global val_acc_best, val_ppl_best, stop_training, last_improvement_epoch
    esnli_net.eval()

    correct = 0.
    cum_dev_ppl = 0
    cum_dev_n_words = 0

    # Leading newline separates the block from previous output
    # (the former '\D' was an invalid escape sequence).
    print('\nDEV : Epoch {0}'.format(epoch))

    # eSNLI
    s1 = snli_dev['s1']
    s2 = snli_dev['s2']
    # The three reference explanations in order; indexed directly instead of
    # being looked up by name through eval().
    reference_expls = (snli_dev['expl_1'], snli_dev['expl_2'],
                       snli_dev['expl_3'])
    label = snli_dev['label']
    batch_size = params.eval_batch_size

    for i in range(0, len(s1), batch_size):
        show_example = i % params.print_every == 0

        # prepare batch
        s1_batch, s1_len = get_batch(s1[i:i + batch_size], word_vec)
        s2_batch, s2_len = get_batch(s2[i:i + batch_size], word_vec)
        s1_batch = Variable(s1_batch.cuda())
        s2_batch = Variable(s2_batch.cuda())
        tgt_label_batch = Variable(
            torch.LongTensor(label[i:i + batch_size])).cuda()

        # print example
        # (single-argument print calls keep the text produced by the former
        # print statements, including the doubled spaces)
        if show_example:
            print("{0} \n".format(current_run_dir))
            print("SNLI DEV example")
            print("Sentence1:  {0}  LENGTH:  {1}".format(
                ' '.join(s1[i]), s1_len[0]))
            print("Sentence2:  {0}  LENGTH:  {1}".format(
                ' '.join(s2[i]), s2_len[0]))
            print("Gold label:  {0}".format(
                get_key_from_val(label[i], NLI_DIC_LABELS)))

        # label outputs of the three teacher-forced passes
        out_lbl = [0, 1, 2]
        for index, expl in enumerate(reference_expls, 1):
            input_expl_batch, _ = get_batch(expl[i:i + batch_size], word_vec)
            # drop the last time step of the decoder input
            input_expl_batch = Variable(input_expl_batch[:-1].cuda())

            if show_example:
                print("Explanation {0} :  {1}".format(index,
                                                      ' '.join(expl[i])))

            tgt_expl_batch, lens_tgt_expl = get_target_expl_batch(
                expl[i:i + batch_size], word_index)
            assert tgt_expl_batch.dim() == 2, \
                "tgt_expl_batch.dim()=" + str(tgt_expl_batch.dim())
            tgt_expl_batch = Variable(tgt_expl_batch).cuda()

            if show_example:
                print("Target expl {0} :  {1}  LENGHT:  {2}".format(
                    index,
                    get_sentence_from_indices(word_index,
                                              tgt_expl_batch[:, 0]),
                    lens_tgt_expl[0]))

            # model forward with teacher forcing
            out_expl, out_lbl[index - 1] = esnli_net(
                (s1_batch, s1_len), (s2_batch, s2_len), input_expl_batch,
                'teacher')

            # ppl: flatten the first two dimensions of output and target
            loss_expl = criterion_expl(
                out_expl.view(out_expl.size(0) * out_expl.size(1), -1),
                tgt_expl_batch.view(
                    tgt_expl_batch.size(0) * tgt_expl_batch.size(1)))
            cum_dev_n_words += lens_tgt_expl.sum()
            cum_dev_ppl += loss_expl.data[0]
            answer_idx = torch.max(out_expl, 2)[1]

            if show_example:
                print("Decoded explanation {0} :  {1}".format(
                    index,
                    get_sentence_from_indices(word_index, answer_idx[:, 0])))
                print("\n")

        # every pass must yield the same label output
        assert torch.equal(out_lbl[0], out_lbl[1]), "out_lbl[0]: " + str(
            out_lbl[0]) + " while " + "out_lbl[1]: " + str(out_lbl[1])
        assert torch.equal(out_lbl[1], out_lbl[2]), "out_lbl[1]: " + str(
            out_lbl[1]) + " while " + "out_lbl[2]: " + str(out_lbl[2])

        # accuracy
        pred = out_lbl[0].data.max(1)[1]
        if show_example:
            print("Predicted label:  {0} \n\n\n".format(
                get_key_from_val(pred[0], NLI_DIC_LABELS)))
        correct += pred.long().eq(tgt_label_batch.data.long()).cpu().sum()

    total_dev_points = len(s1)

    # accuracy
    eval_acc = round(100 * correct / total_dev_points, 2)
    print('togrep : results : epoch {0} ; mean accuracy {1} '.format(
        epoch, eval_acc))

    # ppl
    dev_ppl.append(math.exp(cum_dev_ppl / cum_dev_n_words))

    current_best_model_state_dict_path = None

    if eval_acc > val_acc_best or dev_ppl[-1] < val_ppl_best:
        last_improvement_epoch = epoch

    # NOTE(review): the chain below is a sibling of the check above, not
    # nested inside it -- confirm this matches the intended schedule.
    if params.alpha > 0.01 and eval_acc > val_acc_best:
        # if alpha > 0 we only save the model if increase in ACC
        print('saving model at epoch {0}'.format(epoch))
        suffix = '_devACC{0:.3f}_devppl{1:.3f}__epoch_{2}_model.pt'.format(
            eval_acc, dev_ppl[-1], epoch)
        current_best_model_state_dict_path = _save_best_checkpoint(
            'best_devacc_', 'state_dict_best_devacc_', suffix)
        val_acc_best = eval_acc
        if dev_ppl[-1] < val_ppl_best:
            val_ppl_best = dev_ppl[-1]
    elif params.alpha < 0.01 and dev_ppl[-1] < val_ppl_best:
        # if alpha = 0 (EXPL_ONLY) we only save the model if decrease in PPL
        print('saving model at epoch {0}'.format(epoch))
        suffix = '_devPPL{0:.3f}__epoch_{1}_model.pt'.format(
            dev_ppl[-1], epoch)
        current_best_model_state_dict_path = _save_best_checkpoint(
            'best_devppl_', 'state_dict_best_devppl_', suffix)
        val_ppl_best = dev_ppl[-1]
    else:
        # no improvement at all, regardless whether it's in PPL or ACC
        if 'sgd' in params.optimizer:
            optimizer.param_groups[0]['lr'] = (
                optimizer.param_groups[0]['lr'] / params.lrshrink)
            print('Shrinking lr by : {0}. New lr = {1}'.format(
                params.lrshrink, optimizer.param_groups[0]['lr']))
            if optimizer.param_groups[0]['lr'] < params.minlr:
                stop_training = True
                print("Stopping training because LR <  {0}".format(
                    params.minlr))

        # for any optimizer early stopping
        if epoch - last_improvement_epoch > params.early_stopping_epochs:
            stop_training = True
            print("Stopping training because no more improvement done in "
                  "the last  {0}  epochs".format(
                      params.early_stopping_epochs))

    return eval_acc, current_best_model_state_dict_path
def eval_datasets_without_expl(esnli_net, which_set, data, word_vec,
                               word_emb_dim, batch_size, print_every,
                               current_run_dir):
    """Evaluate the model on a dataset that has no reference explanations.

    The model is run in "forloop" mode, starting from the ``'<s>'`` embedding,
    to obtain a decoded explanation and a label output for every example.
    One CSV row per example is written to a timestamped file inside
    ``current_run_dir``.

    Args:
        esnli_net: model, called as
            ``esnli_net((s1, s1_len), (s2, s2_len), expl_t0, mode="forloop")``
            and expected to return ``(decoded_strings, label_output)``.
        which_set: name of the dataset; used in the CSV file name and in log
            messages.
        data: mapping with the keys ``'s1'``, ``'s2'``, ``'label'`` and
            ``'label_expl'``.
        word_vec: forwarded to ``get_batch``; must contain the key ``'<s>'``.
        word_emb_dim: embedding size, used for shape checks and to build the
            initial decoder input.
        batch_size: number of examples per batch.
        print_every: an example is printed for every batch whose start index
            is a multiple of this value.
        current_run_dir: directory in which the CSV file is created.

    Returns:
        Tuple ``(eval_acc, eval_acc_label_expl)``: classifier accuracy in
        percent, and the percentage of examples whose first decoded token
        equals the entry of ``data['label_expl']`` (both rounded to 2
        decimals).
    """
    dict_labels = NLI_DIC_LABELS
    esnli_net.eval()

    correct = 0.
    correct_labels_expl = 0.

    s1 = data['s1']
    s2 = data['s2']
    label = data['label']
    label_expl = data['label_expl']

    headers = [
        "gold_label", "Premise", "Hypothesis", "pred_label", "pred_expl",
        "pred_lbl_decoder"
    ]
    expl_csv = os.path.join(
        current_run_dir,
        time.strftime("%d:%m") + "_" + time.strftime("%H:%M:%S") + "_" +
        which_set + ".csv")
    remove_file(expl_csv)

    # The context manager guarantees the CSV is closed even if a batch fails.
    with open(expl_csv, "a") as expl_f:
        writer = csv.writer(expl_f)
        writer.writerow(headers)

        for i in range(0, len(s1), batch_size):
            # prepare batch
            s1_batch, s1_len = get_batch(s1[i:i + batch_size], word_vec)
            s2_batch, s2_len = get_batch(s2[i:i + batch_size], word_vec)
            current_bs = s1_batch.size(1)
            assert_sizes(s1_batch, 3,
                         [s1_batch.size(0), current_bs, word_emb_dim])
            assert_sizes(s2_batch, 3,
                         [s2_batch.size(0), current_bs, word_emb_dim])

            s1_batch = Variable(s1_batch.cuda())
            s2_batch = Variable(s2_batch.cuda())
            tgt_label_batch = Variable(
                torch.LongTensor(label[i:i + batch_size])).cuda()
            tgt_label_expl_batch = label_expl[i:i + batch_size]

            # first decoder input: the '<s>' embedding repeated for the batch
            expl_t0 = Variable(
                torch.from_numpy(word_vec['<s>']).float().unsqueeze(0).expand(
                    current_bs, word_emb_dim).unsqueeze(0)).cuda()
            assert_sizes(expl_t0, 3, [1, current_bs, word_emb_dim])

            # model forward
            pred_expls, out_lbl = esnli_net((s1_batch, s1_len),
                                            (s2_batch, s2_len),
                                            expl_t0,
                                            mode="forloop")
            assert len(pred_expls) == current_bs, "pred_expls: " + str(
                len(pred_expls)) + " current_bs: " + str(current_bs)

            # accuracy of the label emitted as first token by the decoder
            for b, decoded in enumerate(pred_expls):
                assert tgt_label_expl_batch[b] in [
                    'entailment', 'neutral', 'contradiction'
                ]
                if len(decoded) > 0:
                    words = decoded.strip().split(" ")
                    if words[0] == tgt_label_expl_batch[b]:
                        correct_labels_expl += 1

            # accuracy
            pred = out_lbl.data.max(1)[1]
            correct += pred.long().eq(tgt_label_batch.data.long()).cpu().sum()

            # write csv row of predictions, in the order given by `headers`
            for j, decoded in enumerate(pred_expls):
                # The decoded prediction is a whitespace-separated string, so
                # it is split into tokens: indexing the string directly would
                # pick single characters instead of the label / explanation.
                tokens = decoded.strip().split()
                writer.writerow([
                    get_key_from_val(label[i + j], dict_labels),
                    ' '.join(s1[i + j][1:-1]),
                    ' '.join(s2[i + j][1:-1]),
                    get_key_from_val(pred[j], dict_labels),
                    ' '.join(tokens[1:-1]),
                    # an empty prediction yields an empty label cell
                    tokens[0] if tokens else '',
                ])

            # print example
            if i % print_every == 0:
                print(which_set.upper() + " example: ")
                print("Premise: ", ' '.join(s1[i]), " LENGHT: ", s1_len[0])
                print("Hypothesis: ", ' '.join(s2[i]), " LENGHT: ", s2_len[0])
                print("Gold label: ", get_key_from_val(label[i], dict_labels))
                print("Predicted label: ",
                      get_key_from_val(pred[0], dict_labels))
                print("Predicted explanation: ", pred_expls[0], "\n\n\n")

    eval_acc = round(100 * correct / len(s1), 2)
    eval_acc_label_expl = round(100 * correct_labels_expl / len(s1), 2)
    print(which_set.upper() + " no train ", eval_acc, '\n\n\n')
    return eval_acc, eval_acc_label_expl
def trainepoch(epoch):
    """Train the model for one epoch on the shuffled training set.

    Works on module-level state: reads ``esnli_net``, ``train``, ``params``,
    ``optimizer``, ``word_vec``, ``word_index``, ``criterion_labels``,
    ``criterion_expl`` and ``current_run_dir``; appends to ``total_norms``,
    ``enc_norms``, ``train_all_losses``, ``train_expl_costs``,
    ``train_label_costs`` and ``train_ppl``; and may increase
    ``params.alpha`` (annealing) and change the optimizer learning rate.

    Only the first reference explanation (``train['expl_1']``) is used as
    decoder input and target.

    Args:
        epoch: number of the current epoch; controls alpha annealing and
            learning-rate decay (both only applied when ``epoch > 1``) and is
            used in log messages.

    Returns:
        Training accuracy of the label classifier in percent, rounded to
        2 decimals.
    """
    print('\nTRAINING : Epoch ' + str(epoch))
    esnli_net.train()

    # anneal alpha as long as it stays within params.annealing_max
    if (epoch > 1) and (params.annealing_alpha) and (
            params.alpha + params.annealing_rate <= params.annealing_max):
        params.alpha += params.annealing_rate
        print("alpha:  " + str(params.alpha))

    label_costs = []
    expl_costs = []
    all_losses = []
    cum_n_words = 0
    cum_ppl = 0
    correct = 0.

    # shuffle the data
    permutation = np.random.permutation(len(train['s1']))
    s1 = train['s1'][permutation]
    s2 = train['s2'][permutation]
    expl_1 = train['expl_1'][permutation]
    label = train['label'][permutation]
    # NOTE(review): shuffled alongside the other fields but not read below.
    label_expl = permute(train['label_expl'], permutation)

    # learning-rate decay, only for sgd and only after the first epoch
    if epoch > 1 and 'sgd' in params.optimizer:
        optimizer.param_groups[0]['lr'] = (
            optimizer.param_groups[0]['lr'] * params.decay)
    print('Learning rate : {0}'.format(optimizer.param_groups[0]['lr']))

    for stidx in range(0, len(s1), params.batch_size):
        # prepare batch
        s1_batch, s1_len = get_batch(s1[stidx:stidx + params.batch_size],
                                     word_vec)
        s2_batch, s2_len = get_batch(s2[stidx:stidx + params.batch_size],
                                     word_vec)
        input_expl_batch, _ = get_batch(
            expl_1[stidx:stidx + params.batch_size], word_vec)
        # eliminate last input to explanation because we wouldn't need to
        # input </s> and we need same number of input and output
        input_expl_batch = input_expl_batch[:-1]

        s1_batch = Variable(s1_batch.cuda())
        s2_batch = Variable(s2_batch.cuda())
        input_expl_batch = Variable(input_expl_batch.cuda())
        tgt_label_batch = Variable(
            torch.LongTensor(label[stidx:stidx + params.batch_size])).cuda()

        tgt_expl_batch, lens_tgt_expl = get_target_expl_batch(
            expl_1[stidx:stidx + params.batch_size], word_index)
        assert tgt_expl_batch.dim() == 2, \
            "tgt_expl_batch.dim()=" + str(tgt_expl_batch.dim())
        tgt_expl_batch = Variable(tgt_expl_batch).cuda()

        # model forward train
        out_expl, out_lbl = esnli_net((s1_batch, s1_len), (s2_batch, s2_len),
                                      input_expl_batch, 'teacher')

        pred = out_lbl.data.max(1)[1]
        current_bs = len(pred)
        correct += pred.long().eq(tgt_label_batch.data.long()).cpu().sum()
        assert len(pred) == len(
            s1[stidx:stidx + params.batch_size]), "len(pred)=" + str(
                len(pred)
            ) + " while len(s1[stidx:stidx + params.batch_size])=" + str(
                len(s1[stidx:stidx + params.batch_size]))

        answer_idx = torch.max(out_expl, 2)[1]

        # print example
        # (single-argument print calls keep the text produced by the former
        # print statements, including the doubled spaces)
        if stidx % params.print_every == 0:
            print("{0} \n".format(current_run_dir))
            print("epoch:  {0}".format(epoch))
            print("Sentence1:  {0}  LENGTH:  {1}".format(
                ' '.join(s1[stidx]), s1_len[0]))
            print("Sentence2:  {0}  LENGTH:  {1}".format(
                ' '.join(s2[stidx]), s2_len[0]))
            print("Gold label:  {0}".format(
                get_key_from_val(label[stidx], NLI_DIC_LABELS)))
            print("Predicted label:  {0}".format(
                get_key_from_val(pred[0], NLI_DIC_LABELS)))
            print("Explanation:  {0}".format(' '.join(expl_1[stidx])))
            print("Target expl:  {0}  LENGTH:  {1}".format(
                get_sentence_from_indices(word_index, tgt_expl_batch[:, 0]),
                lens_tgt_expl[0]))
            print("Decoded explanation:  {0} \n\n\n".format(
                get_sentence_from_indices(word_index, answer_idx[:, 0])))

        # loss labels
        loss_labels = criterion_labels(out_lbl, tgt_label_batch)
        label_costs.append(loss_labels.data[0])

        # loss expl; out_expl is T x bs x vocab_sizes, tgt_expl_batch is
        # T x bs
        loss_expl = criterion_expl(
            out_expl.view(out_expl.size(0) * out_expl.size(1), -1),
            tgt_expl_batch.view(
                tgt_expl_batch.size(0) * tgt_expl_batch.size(1)))
        expl_costs.append(loss_expl.data[0])
        cum_n_words += lens_tgt_expl.sum()
        cum_ppl += loss_expl.data[0]

        # backward
        loss = params.lmbda * (params.alpha * loss_labels +
                               (1 - params.alpha) * loss_expl)
        all_losses.append(loss.data[0])
        optimizer.zero_grad()
        loss.backward()

        # infersent version of gradient clipping
        shrink_factor = 1

        # total grads norm (gradients are first divided by the batch size)
        total_norm = 0
        for p in esnli_net.parameters():
            if p.requires_grad:
                p.grad.data.div_(current_bs)
                total_norm += p.grad.data.norm()**2
        total_norm = np.sqrt(total_norm)
        total_norms.append(total_norm)

        # encoder grads norm
        enc_norm = 0
        for p in esnli_net.encoder.parameters():
            if p.requires_grad:
                enc_norm += p.grad.data.norm()**2
        enc_norm = np.sqrt(enc_norm)
        enc_norms.append(enc_norm)

        if total_norm > params.max_norm:
            shrink_factor = params.max_norm / total_norm
        # current lr (no external "lr", for adam)
        current_lr = optimizer.param_groups[0]['lr']
        # Clipping is emulated by scaling the lr for this single step ...
        optimizer.param_groups[0]['lr'] = current_lr * shrink_factor

        # optimizer step
        optimizer.step()
        # ... and restoring it right afterwards.
        optimizer.param_groups[0]['lr'] = current_lr

        # print and reset losses
        if len(all_losses) == params.avg_every:
            train_all_losses.append(np.mean(all_losses))
            train_expl_costs.append(params.lmbda * (1 - params.alpha) *
                                    np.mean(expl_costs))
            train_label_costs.append(params.lmbda * params.alpha *
                                     np.mean(label_costs))
            train_ppl.append(math.exp(cum_ppl / cum_n_words))
            print(
                '{0} ; epoch: {1}, total loss : {2} ; lmbda * alpha * (lbl loss) : {3}; lmbda * (1-alpha) * (expl loss) : {4} ; train ppl : {5}; accuracy train esnli : {6}'
                .format(stidx, epoch, round(train_all_losses[-1], 2),
                        round(train_label_costs[-1], 2),
                        round(train_expl_costs[-1], 2),
                        round(train_ppl[-1], 2),
                        round(100. * correct / (stidx + s1_batch.size(1)),
                              2)))
            label_costs = []
            expl_costs = []
            all_losses = []
            cum_n_words = 0
            cum_ppl = 0

    train_acc = round(100 * correct / len(s1), 2)
    print('results : epoch {0} ; mean accuracy train esnli : {1}'.format(
        epoch, train_acc))
    return train_acc