import sys
import time
import pickle

import numpy as np
import torch

# Helpers such as EVPI, InfoNCELoss, batch_iter, pad_sequence, get_vocab and load,
# and the globals args, train_ids, val_ids, post_content and qa_dict are defined
# elsewhere in this script.


def get_val_loss(vocab, args, model):
    """Compute the average per-example utility (InfoNCE) loss on the validation set."""
    total_util_loss = 0
    util_examples = 0
    model.eval()
    # criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    criterion = InfoNCELoss(size_average=False).cuda()

    # No parameter updates happen here, so gradient tracking is disabled.
    with torch.no_grad():
        for ids, posts, questions, answers, labels in batch_iter(
                val_ids, post_content, qa_dict, vocab, args.batch_size, shuffle=False):
            util_examples += len(ids)

            padded_posts, post_pad_idx = pad_sequence(args.device, posts)
            padded_questions, question_pad_idx = pad_sequence(args.device, questions)
            padded_answers, answer_pad_idx = pad_sequence(args.device, answers)

            pqa_probs = model(ids,
                              (padded_posts, post_pad_idx),
                              (padded_questions, question_pad_idx),
                              (padded_answers, answer_pad_idx))
            labels = torch.tensor(labels).to(device=args.device)

            util_loss = criterion(pqa_probs, labels)
            total_util_loss += util_loss.item()

    total_loss = total_util_loss / util_examples
    model.train()
    return total_loss
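# Both training and validation score each (post, question, answer) batch with an
# InfoNCE criterion that is defined elsewhere in this codebase. A minimal sketch of
# what such a criterion might look like is given below; the [batch, num_candidates]
# score shape, the label semantics and the temperature are assumptions inferred from
# how criterion(pqa_probs, labels) is called above, not from the actual class.
import torch
import torch.nn as nn
import torch.nn.functional as F


class InfoNCELoss(nn.Module):
    """Sketch of an InfoNCE-style criterion (assumed interface)."""

    def __init__(self, temperature=0.07, size_average=True):
        super().__init__()
        self.temperature = temperature
        self.size_average = size_average  # False -> summed loss, matching the validation call

    def forward(self, scores, labels):
        # Cross-entropy over the candidate set: -log softmax of the positive candidate,
        # with scores scaled by a temperature.
        log_probs = F.log_softmax(scores / self.temperature, dim=-1)
        nll = -log_probs.gather(1, labels.view(-1, 1)).squeeze(1)
        return nll.mean() if self.size_average else nll.sum()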
def train():
    device = args.device
    log_every = args.log_every
    valid_iter = args.valid_iter
    train_iter = 0
    cum_loss = 0
    avg_loss = 0
    valid_num = 0
    patience = 0
    num_trial = 0
    hist_valid_scores = []
    begin_time = time.time()

    vocab = get_vocab(args.vocab_file)
    model = EVPI(args, vocab)
    if args.use_embed == 1:
        model.load_vector(args, vocab)

    print("Placing model on ", args.device)
    if args.device == 'cuda':
        model = model.cuda()

    lr = args.lr
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    # The loss function
    # criterion = torch.nn.CrossEntropyLoss().to(device=device)
    # criterion = FocalLoss()
    criterion = InfoNCELoss().cuda()

    print("Beginning Training")
    model.train()
    model_counter = 0

    for ep in range(args.max_epochs):
        for ids, posts, questions, answers, labels in batch_iter(
                train_ids, post_content, qa_dict, vocab, args.batch_size, shuffle=False):
            train_iter += 1
            optim.zero_grad()

            padded_posts, post_pad_idx = pad_sequence(args.device, posts)
            padded_questions, question_pad_idx = pad_sequence(args.device, questions)
            padded_answers, answer_pad_idx = pad_sequence(args.device, answers)

            pqa_probs = model(ids,
                              (padded_posts, post_pad_idx),
                              (padded_questions, question_pad_idx),
                              (padded_answers, answer_pad_idx))
            labels = torch.tensor(labels).to(device=args.device)

            total_loss = criterion(pqa_probs, labels)
            avg_loss += total_loss.item()
            cum_loss += total_loss.item()

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            optim.step()

            if train_iter % log_every == 0:
                print('epoch %d, iter %d, avg. loss %.6f, time elapsed %.2f'
                      % (ep + 1, train_iter, avg_loss / log_every,
                         time.time() - begin_time), file=sys.stderr)
                begin_time = time.time()
                avg_loss = 0

            if train_iter % valid_iter == 0:
                print('epoch %d, iter %d, cum. loss %.2f, time elapsed %.2f'
                      % (ep + 1, train_iter, cum_loss / valid_iter,
                         time.time() - begin_time), file=sys.stderr)
                cum_loss = 0
                valid_num += 1

                print("Begin Validation ", file=sys.stderr)
                model.eval()
                val_loss = get_val_loss(vocab, args, model)
                model.train()
                print('validation: iter %d, loss %f' % (train_iter, val_loss), file=sys.stderr)

                is_better = (len(hist_valid_scores) == 0) or (val_loss < min(hist_valid_scores))
                hist_valid_scores.append(val_loss)

                if is_better:
                    patience = 0
                    print("Save the current model and optimiser state")
                    torch.save(model, args.model_save_path)
                    # torch.save(model, args.model_save_path + '.' + str(val_loss) + '-' + str(model_counter))
                    # model_counter += 1
                    torch.save(optim.state_dict(), args.model_save_path + '.optim')
                elif patience < args.patience:
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == args.patience:
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == args.max_num_trials:
                            print('early stop!', file=sys.stderr)
                            return

                        # Decay the learning rate and restore the best checkpoint.
                        lr = lr * args.lr_decay
                        print('load previously best model and decay learning rate to %f' % lr,
                              file=sys.stderr)
                        model = load(args.model_save_path)
                        model.train()

                        print('restore parameters of the optimizers', file=sys.stderr)
                        optim = torch.optim.Adam(model.parameters(), lr=lr)
                        optim.load_state_dict(torch.load(args.model_save_path + '.optim'))
                        # Move any optimiser state tensors back onto the training device.
                        for state in optim.state.values():
                            for k, v in state.items():
                                if isinstance(v, torch.Tensor):
                                    state[k] = v.to(args.device)
                        for group in optim.param_groups:
                            group['lr'] = lr
                        patience = 0

    print("Training Finished", file=sys.stderr)
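# The early-stopping branch above restores the best checkpoint with a load() helper
# that is defined elsewhere. Since torch.save(model, path) pickles the whole module
# (not just a state_dict), a minimal sketch of that helper could be a single
# torch.load call; the function name and signature follow the call site above, and
# the device handling is an assumption.
import torch


def load(model_path, device='cuda'):
    """Restore a full pickled nn.Module saved with torch.save(model, path)."""
    # map_location keeps CPU-only runs working; on newer PyTorch releases,
    # weights_only=False may be needed to unpickle a full module.
    return torch.load(model_path, map_location=device)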
def get_val_loss(vocab, args, model, ep):
    """Compute the average utility (InfoNCE) loss on the validation set,
    reading pre-computed Sentence-BERT embeddings from the per-batch caches."""
    total_util_loss = 0
    util_examples = 0
    model.eval()
    # criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    criterion = InfoNCELoss().to(device='cuda')
    valid_epoch_iter = 0

    with torch.no_grad():
        for ids, posts, questions, answers, labels in batch_iter(
                val_ids, post_content, qa_dict, vocab, args.batch_size, shuffle=False):
            print("Validation Iteration {}".format(valid_epoch_iter))
            util_examples += len(ids)

            # On the first epoch the caches can be filled by encoding each batch
            # with Sentence-BERT, then pickled once they are complete:
            # if ep == 0:
            #     posts_embeddings = np.asarray(sentence_bert_model.encode(posts))
            #     questions_embeddings = np.asarray(sentence_bert_model.encode(questions))
            #     answers_embeddings = np.asarray(sentence_bert_model.encode(answers))
            #     valid_p_embeddings.append(posts_embeddings)
            #     valid_q_embeddings.append(questions_embeddings)
            #     valid_a_embeddings.append(answers_embeddings)
            #     print("Embeddings Cached for Validation Iteration {}".format(valid_epoch_iter))
            # if ep == 1:
            #     with open('valid_p_embeddings.pickle', 'wb') as b:
            #         pickle.dump(valid_p_embeddings, b)
            #     with open('valid_q_embeddings.pickle', 'wb') as b:
            #         pickle.dump(valid_q_embeddings, b)
            #     with open('valid_a_embeddings.pickle', 'wb') as b:
            #         pickle.dump(valid_a_embeddings, b)

            # Read the cached embeddings for this batch.
            posts_embeddings = valid_p_embeddings[valid_epoch_iter]
            questions_embeddings = valid_q_embeddings[valid_epoch_iter]
            answers_embeddings = valid_a_embeddings[valid_epoch_iter]
            valid_epoch_iter += 1

            posts_embeddings = torch.from_numpy(posts_embeddings).float().to(args.device)
            questions_embeddings = torch.from_numpy(questions_embeddings).float().to(args.device)
            answers_embeddings = torch.from_numpy(answers_embeddings).float().to(args.device)

            pqa_probs = model(posts_embeddings, questions_embeddings, answers_embeddings)
            labels = torch.tensor(labels).to(device=args.device)

            util_loss = criterion(pqa_probs, labels)
            total_util_loss += util_loss.item()

    total_loss = total_util_loss / util_examples
    model.train()
    return total_loss
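# The cached-embedding path above relies on a Sentence-BERT encoder and three
# module-level cache lists that are set up elsewhere in the script. A minimal
# sketch of that setup, assuming the sentence-transformers package; the actual
# checkpoint name is not shown in this code and is an assumption.
import numpy as np
from sentence_transformers import SentenceTransformer

sentence_bert_model = SentenceTransformer('bert-large-nli-mean-tokens')  # assumed checkpoint

# One numpy array of shape [batch_size, embed_dim] per validation batch, appended
# in batch order on the first pass and indexed by valid_epoch_iter afterwards.
valid_p_embeddings = []
valid_q_embeddings = []
valid_a_embeddings = []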
def train():
    device = args.device
    log_every = args.log_every
    valid_iter = args.valid_iter
    train_iter = 0
    cum_loss = 0
    avg_loss = 0
    valid_num = 0
    patience = 0
    num_trial = 0
    hist_valid_scores = []
    begin_time = time.time()

    vocab = get_vocab(args.vocab_file)
    model = EVPI(args, vocab)
    if args.use_embed == 1:
        model.load_vector(args, vocab)

    print("Placing model on ", args.device)
    if args.device == 'cuda':
        model = model.cuda()

    lr = args.lr
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    # The loss function
    # criterion = torch.nn.CrossEntropyLoss().to(device=device)
    criterion = InfoNCELoss().to(device=device)

    print("Beginning Training")
    model.train()
    model_counter = 0

    # Load the pre-computed Sentence-BERT embeddings, one array per training batch.
    p_embeddings = pickle.load(open("p_large_embeddings.pickle", "rb"))
    print("P embeddings loaded", p_embeddings[0].shape)
    q_embeddings = pickle.load(open("q_large_embeddings.pickle", "rb"))
    print("Q embeddings loaded", q_embeddings[0].shape)
    a_embeddings = pickle.load(open("a_large_embeddings.pickle", "rb"))
    print("A embeddings loaded", a_embeddings[0].shape)

    for ep in range(args.max_epochs):
        epoch_iter = 0
        for ids, posts, questions, answers, labels in batch_iter(
                train_ids, post_content, qa_dict, vocab, args.batch_size, shuffle=False):
            train_iter += 1
            optim.zero_grad()

            # Alternatively, the caches can be built here on the first epoch and
            # pickled afterwards:
            # if ep == 0:
            #     posts_embeddings = np.asarray(sentence_bert_model.encode(posts))
            #     questions_embeddings = np.asarray(sentence_bert_model.encode(questions))
            #     answers_embeddings = np.asarray(sentence_bert_model.encode(answers))
            #     p_embeddings.append(posts_embeddings)
            #     q_embeddings.append(questions_embeddings)
            #     a_embeddings.append(answers_embeddings)
            #     print("Embeddings Cached for Iteration {}".format(epoch_iter))
            # if ep == 1:
            #     with open('p_large_embeddings.pickle', 'wb') as b:
            #         pickle.dump(p_embeddings, b)
            #     with open('q_large_embeddings.pickle', 'wb') as b:
            #         pickle.dump(q_embeddings, b)
            #     with open('a_large_embeddings.pickle', 'wb') as b:
            #         pickle.dump(a_embeddings, b)

            # Read the cached embeddings for this batch (batch_iter is not shuffled,
            # so batch order matches the cached order).
            posts_embeddings = p_embeddings[epoch_iter]
            questions_embeddings = q_embeddings[epoch_iter]
            answers_embeddings = a_embeddings[epoch_iter]
            epoch_iter += 1

            posts_embeddings = torch.from_numpy(posts_embeddings).float().to(args.device)
            questions_embeddings = torch.from_numpy(questions_embeddings).float().to(args.device)
            answers_embeddings = torch.from_numpy(answers_embeddings).float().to(args.device)

            pqa_probs = model(posts_embeddings, questions_embeddings, answers_embeddings)
            labels = torch.tensor(labels).to(device=args.device)

            total_loss = criterion(pqa_probs, labels)
            avg_loss += total_loss.item()
            cum_loss += total_loss.item()

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            optim.step()

            if train_iter % log_every == 0:
                print('epoch %d, iter %d, avg. loss %.6f, time elapsed %.2f'
                      % (ep + 1, train_iter, avg_loss / log_every,
                         time.time() - begin_time), file=sys.stderr)
                begin_time = time.time()
                avg_loss = 0

            if train_iter % valid_iter == 0:
                print('epoch %d, iter %d, cum. loss %.2f, time elapsed %.2f'
                      % (ep + 1, train_iter, cum_loss / valid_iter,
                         time.time() - begin_time), file=sys.stderr)
                cum_loss = 0
                valid_num += 1

                print("Begin Validation ", file=sys.stderr)
                model.eval()
                val_loss = get_val_loss(vocab, args, model, ep)
                model.train()
                print('validation: iter %d, loss %f' % (train_iter, val_loss), file=sys.stderr)

                is_better = (len(hist_valid_scores) == 0) or (val_loss < min(hist_valid_scores))
                hist_valid_scores.append(val_loss)

                if is_better:
                    patience = 0
                    print("Save the current model and optimiser state")
                    torch.save(model, args.model_save_path)
                    # torch.save(model, args.model_save_path + '.' + str(val_loss) + '-' + str(model_counter))
                    # model_counter += 1
                    torch.save(optim.state_dict(), args.model_save_path + '.optim')
                elif patience < args.patience:
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == args.patience:
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == args.max_num_trials:
                            print('early stop!', file=sys.stderr)
                            return

                        # Decay the learning rate and restore the best checkpoint.
                        lr = lr * args.lr_decay
                        print('load previously best model and decay learning rate to %f' % lr,
                              file=sys.stderr)
                        model = load(args.model_save_path)
                        model.train()

                        print('restore parameters of the optimizers', file=sys.stderr)
                        optim = torch.optim.Adam(model.parameters(), lr=lr)
                        optim.load_state_dict(torch.load(args.model_save_path + '.optim'))
                        # Move any optimiser state tensors back onto the training device.
                        for state in optim.state.values():
                            for k, v in state.items():
                                if isinstance(v, torch.Tensor):
                                    state[k] = v.to(args.device)
                        for group in optim.param_groups:
                            group['lr'] = lr
                        patience = 0

    print("Training Finished", file=sys.stderr)
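# The pickles loaded at the top of train() have to be produced once, offline. A
# sketch of that precompute step under the same assumptions as above: iterate over
# the training batches in the same (unshuffled) order, encode each batch with
# Sentence-BERT, and pickle one list of per-batch arrays per field. The file names
# match the loads in train(); the encoder checkpoint is an assumption.
import pickle

import numpy as np
from sentence_transformers import SentenceTransformer


def precompute_train_embeddings(vocab, args):
    """Encode every training batch once and pickle the per-batch embedding arrays."""
    encoder = SentenceTransformer('bert-large-nli-mean-tokens')  # assumed checkpoint
    p_embeddings, q_embeddings, a_embeddings = [], [], []

    for ids, posts, questions, answers, labels in batch_iter(
            train_ids, post_content, qa_dict, vocab, args.batch_size, shuffle=False):
        p_embeddings.append(np.asarray(encoder.encode(posts)))
        q_embeddings.append(np.asarray(encoder.encode(questions)))
        a_embeddings.append(np.asarray(encoder.encode(answers)))

    for name, data in [('p_large_embeddings.pickle', p_embeddings),
                       ('q_large_embeddings.pickle', q_embeddings),
                       ('a_large_embeddings.pickle', a_embeddings)]:
        with open(name, 'wb') as f:
            pickle.dump(data, f)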