def get_dev_loss(self, model):
    logging.info("Calculating dev loss...")
    tic = time.time()
    loss_per_batch, batch_lengths = [], []
    i = 0
    for batch in get_batch_generator(self.word2id, self.dev_context_path,
                                     self.dev_qn_path, self.dev_ans_path,
                                     config.batch_size,
                                     context_len=config.context_len,
                                     question_len=config.question_len,
                                     discard_long=True):
        loss, _, _ = self.eval_one_batch(batch, model)
        curr_batch_size = batch.batch_size
        # Weight each batch's loss by its size so the final average is per-example
        loss_per_batch.append(loss * curr_batch_size)
        batch_lengths.append(curr_batch_size)
        i += 1
        # Only evaluate the first 10 batches to keep this check cheap
        if i == 10:
            break

    total_num_examples = sum(batch_lengths)
    toc = time.time()
    logging.info("Computed dev loss over %i examples in %.2f seconds" %
                 (total_num_examples, toc - tic))

    dev_loss = sum(loss_per_batch) / float(total_num_examples)
    return dev_loss
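# A minimal sketch of the eval_one_batch helper that get_dev_loss relies on,
# assuming the model maps padded context/question id tensors to start/end
# span logits and that batch carries context_ids, qn_ids and ans_span arrays.
# Those attribute names, and the exact loss form, are assumptions for
# illustration, not confirmed by this repo.
import torch
import torch.nn.functional as F

def eval_one_batch(self, batch, model):
    model.eval()
    with torch.no_grad():
        start_logits, end_logits = model(
            torch.from_numpy(batch.context_ids),
            torch.from_numpy(batch.qn_ids))
        ans_span = torch.from_numpy(batch.ans_span)
        # Sum of start- and end-position cross-entropies, as is usual for span QA
        loss = (F.cross_entropy(start_logits, ans_span[:, 0]) +
                F.cross_entropy(end_logits, ans_span[:, 1]))
    return loss.item(), start_logits, end_logits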
def train(context_path, qn_path, ans_path):
    """ Train the network """
    model = Decoder(emb_matrix, 2)

    # Select the parameters which require grad / backpropagation
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = optim.SGD(params,
                          lr=config.learning_rate,
                          weight_decay=config.l2_norm)

    # Note: the restore path is hard-coded to the epoch-2 / iter-1000 checkpoint
    checkpoint_name = "checkpoint-Embed{}-ep{}-iter{}".format(
        config.embedding_dim, 2, 1000)
    checkpoint_name = os.path.join(config.experiments_root_dir,
                                   checkpoint_name)

    # If the network has a saved model, restore it
    start_epoch = 0
    if os.path.exists(checkpoint_name):
        state = torch.load(checkpoint_name)
        model.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])
        start_epoch = state['epoch']
        i = state['iter']
        current_loss = state['loss']
        print("Model restored from ", checkpoint_name)
        print("Epoch : {}\tIter {}\t\tloss : {}".format(
            start_epoch, i, current_loss))
    else:
        print("Training with fresh parameters")

    # For each epoch (resume from the restored epoch, if any)
    for epoch in range(start_epoch, config.num_epochs):
        # For each batch
        for i, batch in enumerate(
                get_batch_generator(word2index, context_path, qn_path,
                                    ans_path, config.batch_size,
                                    config.context_len, config.question_len,
                                    discard_long=True)):
            # Take step in training
            loss = step(model, optimizer, batch)

            # Displaying results
            if i % config.print_every == 0:
                f1 = evaluate(model, batch)
                print("Epoch : {}\tIter {}\t\tloss : {}\tf1 : {}".format(
                    epoch, i, "%.2f" % loss, "%.2f" % f1))
            # Maybe you want to do random evaluations as well for sanity check

            # Saving the model
            if i % config.save_every == 0:
                state = {
                    'iter': i,
                    'epoch': epoch,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    # Save under 'loss' so it matches the key read on restore
                    'loss': loss
                }
                checkpoint_name = "checkpoint-Embed{}-ep{}-iter{}".format(
                    config.embedding_dim, epoch, i)
                checkpoint_name = os.path.join(config.experiments_root_dir,
                                               checkpoint_name)
                torch.save(state, checkpoint_name)
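# A minimal sketch of the step() helper this train() calls, matching its
# (model, optimizer, batch) signature. The model interface and batch
# attribute names (context_ids, qn_ids, ans_span) are assumptions; the loss
# shown is the usual sum of start- and end-position cross-entropies for
# span prediction, not necessarily what this repo implements.
import torch
import torch.nn.functional as F

def step(model, optimizer, batch):
    model.train()
    optimizer.zero_grad()
    start_logits, end_logits = model(torch.from_numpy(batch.context_ids),
                                     torch.from_numpy(batch.qn_ids))
    ans_span = torch.from_numpy(batch.ans_span)
    loss = (F.cross_entropy(start_logits, ans_span[:, 0]) +
            F.cross_entropy(end_logits, ans_span[:, 1]))
    loss.backward()
    optimizer.step()
    return loss.item()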
def train(self, model_file_path):
    train_dir = os.path.join(config.log_root,
                             'train_%d' % (int(time.time())))
    if not os.path.exists(train_dir):
        os.mkdir(train_dir)
    model_dir = os.path.join(train_dir, 'model')
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    bestmodel_dir = os.path.join(train_dir, 'bestmodel')
    if not os.path.exists(bestmodel_dir):
        os.makedirs(bestmodel_dir)

    summary_writer = tf.summary.FileWriter(train_dir)
    # Save the run configuration alongside the logs
    with open(os.path.join(train_dir, "flags.json"), 'w') as fout:
        json.dump(vars(config), fout)

    model = self.get_model(model_file_path)
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = Adam(params, lr=config.lr, amsgrad=True)

    num_params = sum(p.numel() for p in params)
    logging.info("Number of params: %d" % num_params)

    exp_loss, best_dev_f1, best_dev_em = None, None, None
    epoch = 0
    global_step = 0

    logging.info("Beginning training loop...")
    # num_epochs == 0 means train indefinitely
    while config.num_epochs == 0 or epoch < config.num_epochs:
        epoch += 1
        epoch_tic = time.time()
        for batch in get_batch_generator(self.word2id,
                                         self.train_context_path,
                                         self.train_qn_path,
                                         self.train_ans_path,
                                         config.batch_size,
                                         context_len=config.context_len,
                                         question_len=config.question_len,
                                         discard_long=True):
            global_step += 1
            iter_tic = time.time()

            loss, param_norm, grad_norm = self.train_one_batch(
                batch, model, optimizer, params)
            write_summary(loss, "train/loss", summary_writer, global_step)

            iter_toc = time.time()
            iter_time = iter_toc - iter_tic

            # Exponentially-smoothed training loss ("is None" rather than
            # truthiness, so a legitimate loss of 0.0 isn't treated as unset)
            if exp_loss is None:
                exp_loss = loss
            else:
                exp_loss = 0.99 * exp_loss + 0.01 * loss

            if global_step % config.print_every == 0:
                logging.info(
                    'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, '
                    'grad norm %.5f, param norm %.5f, batch time %.3f' %
                    (epoch, global_step, loss, exp_loss, grad_norm,
                     param_norm, iter_time))

            if global_step % config.save_every == 0:
                logging.info("Saving to %s..." % model_dir)
                self.save_model(model, optimizer, loss, global_step, epoch,
                                model_dir)

            if global_step % config.eval_every == 0:
                dev_loss = self.get_dev_loss(model)
                logging.info("Epoch %d, Iter %d, dev loss: %f" %
                             (epoch, global_step, dev_loss))
                write_summary(dev_loss, "dev/loss", summary_writer,
                              global_step)

                train_f1, train_em = self.check_f1_em(model, "train",
                                                      num_samples=1000)
                logging.info(
                    "Epoch %d, Iter %d, Train F1 score: %f, Train EM score: %f"
                    % (epoch, global_step, train_f1, train_em))
                write_summary(train_f1, "train/F1", summary_writer,
                              global_step)
                write_summary(train_em, "train/EM", summary_writer,
                              global_step)

                dev_f1, dev_em = self.check_f1_em(model, "dev",
                                                  num_samples=0)
                logging.info(
                    "Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f" %
                    (epoch, global_step, dev_f1, dev_em))
                write_summary(dev_f1, "dev/F1", summary_writer, global_step)
                write_summary(dev_em, "dev/EM", summary_writer, global_step)

                if best_dev_f1 is None or dev_f1 > best_dev_f1:
                    best_dev_f1 = dev_f1
                if best_dev_em is None or dev_em > best_dev_em:
                    best_dev_em = dev_em
                    # Keep a separate copy of the best model so far
                    logging.info("Saving to %s..." % bestmodel_dir)
                    self.save_model(model, optimizer, loss, global_step,
                                    epoch, bestmodel_dir)

        epoch_toc = time.time()
        logging.info("End of epoch %i. Time for epoch: %f" %
                     (epoch, epoch_toc - epoch_tic))
        sys.stdout.flush()
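# Sketch of the write_summary helper used throughout the loop above, written
# against the TF1-era summary API (consistent with the tf.summary.FileWriter
# constructed above): it wraps a single scalar in a tf.Summary proto and tags
# it for TensorBoard. Whether this repo's helper is byte-identical is an
# assumption.
def write_summary(value, tag, summary_writer, global_step):
    """Write a single scalar summary value to TensorBoard."""
    summary = tf.Summary()
    summary.value.add(tag=tag, simple_value=value)
    summary_writer.add_summary(summary, global_step)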
def check_f1_em(self, model, dataset, num_samples=100,
                print_to_screen=False):
    logging.info(
        "Calculating F1/EM for %s examples in %s set..." %
        (str(num_samples) if num_samples != 0 else "all", dataset))

    if dataset == "train":
        context_path, qn_path, ans_path = \
            self.train_context_path, self.train_qn_path, self.train_ans_path
    elif dataset == "dev":
        context_path, qn_path, ans_path = \
            self.dev_context_path, self.dev_qn_path, self.dev_ans_path
    else:
        raise ValueError("dataset must be 'train' or 'dev', got %r" % dataset)

    f1_total = 0.
    em_total = 0.
    example_num = 0
    tic = time.time()

    for batch in get_batch_generator(self.word2id, context_path, qn_path,
                                     ans_path, config.batch_size,
                                     context_len=config.context_len,
                                     question_len=config.question_len,
                                     discard_long=False):
        pred_start_pos, pred_end_pos = self.test_one_batch(batch, model)
        pred_start_pos = pred_start_pos.tolist()
        pred_end_pos = pred_end_pos.tolist()

        for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) \
                in enumerate(zip(pred_start_pos, pred_end_pos,
                                 batch.ans_tokens)):
            example_num += 1
            # Slice the predicted span out of the context (end index inclusive)
            pred_ans_tokens = batch.context_tokens[ex_idx][
                pred_ans_start:pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)
            true_answer = " ".join(true_ans_tokens)
            f1 = f1_score(pred_answer, true_answer)
            em = exact_match_score(pred_answer, true_answer)
            f1_total += f1
            em_total += em

            if print_to_screen:
                print_example(self.word2id, batch.context_tokens[ex_idx],
                              batch.qn_tokens[ex_idx],
                              batch.ans_span[ex_idx, 0],
                              batch.ans_span[ex_idx, 1], pred_ans_start,
                              pred_ans_end, true_answer, pred_answer, f1, em)

            # num_samples == 0 means score the whole set
            if num_samples != 0 and example_num >= num_samples:
                break
        if num_samples != 0 and example_num >= num_samples:
            break

    f1_total /= example_num
    em_total /= example_num

    toc = time.time()
    logging.info(
        "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
        (example_num, dataset, toc - tic))

    return f1_total, em_total
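# The f1_score / exact_match_score used above follow the official SQuAD v1.1
# evaluation script; a self-contained version is reproduced below on that
# assumption. Answers are lowercased, punctuation and articles stripped,
# whitespace collapsed, then compared as token bags for F1 or whole strings
# for EM.
import re
import string
from collections import Counter

def normalize_answer(s):
    """Lowercase, strip punctuation and articles, collapse whitespace."""
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())

def f1_score(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    true_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(true_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.
    precision = float(num_same) / len(pred_tokens)
    recall = float(num_same) / len(true_tokens)
    return (2 * precision * recall) / (precision + recall)

def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)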
def train(context_path, qn_path, ans_path):
    """ Train the network """
    model = N.CoattentionNetwork(
        device=config.device,
        hidden_size=config.hidden_size,
        emb_matrix=emb_matrix,
        num_encoder_layers=config.num_encoder_layers,
        num_fusion_bilstm_layers=config.num_fusion_bilstm_layers,
        num_decoder_layers=config.num_decoder_layers,
        batch_size=config.batch_size,
        max_dec_steps=config.max_dec_steps,
        fusion_dropout_rate=config.fusion_dropout_rate,
        encoder_bidirectional=config.encoder_bidirectional,
        decoder_bidirectional=config.decoder_bidirectional)

    # Select the parameters which require grad / backpropagation
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = optim.SGD(params,
                          lr=config.learning_rate,
                          weight_decay=config.l2_norm)

    # Set up directories for this experiment
    if not os.path.exists(config.experiments_root_dir):
        os.makedirs(config.experiments_root_dir)
    serial_number = len(os.listdir(config.experiments_root_dir))
    if config.restore:
        serial_number -= 1  # Reuse the latest experiment directory
    experiment_dir = os.path.join(config.experiments_root_dir,
                                  'experiment_{}'.format(serial_number))
    if not os.path.exists(experiment_dir):
        os.makedirs(experiment_dir)
    model_dir = os.path.join(experiment_dir, 'model')
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Save config as config.json
    with open(os.path.join(experiment_dir, "config.json"), 'w') as fout:
        json.dump(vars(config), fout)

    iteration = 0
    if config.restore:
        saved_models = os.listdir(model_dir)
        if len(saved_models):
            print(saved_models)
            # Checkpoint names end in "-iter-<N>"; restore the latest one
            saved_models = [int(name.split('-')[-1]) for name in saved_models]
            latest_iter = max(saved_models)
            checkpoint_name = "checkpoint-embed{}-iter-{}".format(
                config.embedding_dim, latest_iter)
            checkpoint_name = os.path.join(model_dir, checkpoint_name)
            state = torch.load(checkpoint_name)
            model.load_state_dict(state['model'])
            optimizer.load_state_dict(state['optimizer'])
            iteration = state['iter']
            print("Model restored from ", checkpoint_name)
        else:
            print("Training with fresh parameters")

    for epoch in range(config.num_epochs):
        for batch in get_batch_generator(word2index, context_path, qn_path,
                                         ans_path, config.batch_size,
                                         config.context_len,
                                         config.question_len,
                                         discard_long=True):
            # When the batch is partially filled, ignore it.
            if batch.batch_size < config.batch_size:
                del batch
                continue

            # Take step in training
            loss = step(model, optimizer, batch, params)

            # Displaying results (-1 is a placeholder when F1 isn't computed
            # on this iteration)
            if iteration % config.print_every == 0:
                print("Iter {}\t\tloss : {}\tf1 : {}".format(
                    iteration, "%.5f" % loss, "%.4f" % -1))
            if iteration % config.evaluate_every == 0:
                f1 = evaluate(model, batch)
                print("Iter {}\t\tloss : {}\tf1 : {}".format(
                    iteration, "%.5f" % loss, "%.4f" % f1))
            # Maybe you want to do random evaluations as well for sanity check

            # Saving the model
            if iteration % config.save_every == 0:
                state = {
                    'iter': iteration,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss': loss
                }
                checkpoint_name = "checkpoint-embed{}-iter-{}".format(
                    config.embedding_dim, iteration)
                fname = os.path.join(model_dir, checkpoint_name)
                torch.save(state, fname)

            # Drop the reference to the loss (and its graph) before the next step
            del loss
            iteration += 1
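# A hedged sketch of the evaluate() helper called in the loop above: greedily
# decode start/end positions for one batch and average SQuAD-style F1 against
# the gold spans (f1_score as sketched earlier). Whether CoattentionNetwork
# returns start/end logits in this form, and the batch attribute names
# context_ids / qn_ids, are assumptions for illustration.
def evaluate(model, batch):
    model.eval()
    with torch.no_grad():
        start_logits, end_logits = model(
            torch.from_numpy(batch.context_ids),
            torch.from_numpy(batch.qn_ids))
        pred_start = start_logits.argmax(dim=1).tolist()
        pred_end = end_logits.argmax(dim=1).tolist()
    f1_total = 0.
    for idx, (s, e) in enumerate(zip(pred_start, pred_end)):
        pred_answer = " ".join(batch.context_tokens[idx][s:e + 1])
        true_answer = " ".join(batch.ans_tokens[idx])
        f1_total += f1_score(pred_answer, true_answer)
    model.train()
    return f1_total / batch.batch_size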