def get_dev_loss(self, session, dev_context_path, dev_qn_path, dev_ans_path):
    """
    Get loss for entire dev set.

    Inputs:
      session: TensorFlow session
      dev_qn_path, dev_context_path, dev_ans_path: paths to the
        dev.{context/question/answer} data files

    Outputs:
      dev_loss: float. Average loss across the dev set. If the batch
        generator yields no examples (for instance because every example
        was discarded as too long) the average is undefined, so NaN is
        returned and a warning is logged instead of raising
        ZeroDivisionError.
    """
    logging.info("Calculating dev loss...")
    tic = time.time()
    loss_per_batch, batch_lengths = [], []

    # Iterate over dev set batches
    # Note: here we set discard_long=True, meaning we discard any examples
    # which are longer than our context_len or question_len.
    # We need to do this because if, for example, the true answer is cut
    # off the context, then the loss function is undefined.
    for batch in get_batch_generator(self.word2id, dev_context_path,
                                     dev_qn_path, dev_ans_path,
                                     self.FLAGS.batch_size,
                                     context_len=self.FLAGS.context_len,
                                     question_len=self.FLAGS.question_len,
                                     discard_long=True):

        # Get loss for this batch
        loss = self.get_loss(session, batch)
        curr_batch_size = batch.batch_size
        # Weight each batch's loss by its size so that a smaller batch does
        # not count as much as a full one in the overall average (this
        # presumes get_loss returns a per-example mean -- verify against
        # get_loss).
        loss_per_batch.append(loss * curr_batch_size)
        batch_lengths.append(curr_batch_size)

    # Calculate average loss
    total_num_examples = sum(batch_lengths)
    toc = time.time()
    # Single parenthesised argument: prints the same text on Python 2 and 3.
    print("Computed dev loss over %i examples in %.2f seconds" % (
        total_num_examples, toc - tic))

    if total_num_examples == 0:
        # Nothing to average over; report "undefined" rather than crashing
        # the surrounding training loop with a ZeroDivisionError.
        logging.warning(
            "No dev examples were available; dev loss is undefined (NaN).")
        return float("nan")

    # Overall loss is total loss divided by total number of examples
    dev_loss = sum(loss_per_batch) / float(total_num_examples)

    return dev_loss
def train(self, session, train_context_path, train_qn_path, train_ans_path, dev_qn_path, dev_context_path, dev_ans_path):
    """
    Main training loop.

    Runs epochs over the training data until FLAGS.num_epochs is reached
    (or forever when FLAGS.num_epochs == 0). Every FLAGS.print_every /
    FLAGS.save_every / FLAGS.eval_every global steps it respectively logs
    progress, saves the latest checkpoint, and evaluates dev loss plus
    train/dev F1 and EM. The checkpoint with the best dev EM seen so far is
    saved separately via self.bestmodel_saver.

    Inputs:
      session: TensorFlow session
      {train/dev}_{qn/context/ans}_path: paths to
        {train/dev}.{context/question/answer} data files
    """

    # Print number of model parameters
    tic = time.time()
    params = tf.trainable_variables()
    # Evaluate the shapes in the session we were handed instead of relying
    # on a default session happening to be installed by the caller.
    num_params = sum(
        np.prod(tf.shape(var.value()).eval(session=session))
        for var in params)
    toc = time.time()
    logging.info("Number of params: %d (retrieval took %f secs)",
                 num_params, toc - tic)

    # We will keep track of exponentially-smoothed loss
    exp_loss = None

    # Checkpoint management.
    # We keep one latest checkpoint, and one best checkpoint (early stopping)
    checkpoint_path = os.path.join(self.FLAGS.train_dir, "qa.ckpt")
    bestmodel_dir = os.path.join(self.FLAGS.train_dir, "best_checkpoint")
    bestmodel_ckpt_path = os.path.join(bestmodel_dir, "qa_best.ckpt")
    best_dev_em = None

    # for TensorBoard
    summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir, session.graph)

    epoch = 0

    logging.info("Beginning training loop...")
    try:
        while self.FLAGS.num_epochs == 0 or epoch < self.FLAGS.num_epochs:
            epoch += 1
            epoch_tic = time.time()

            # Loop over batches
            for batch in get_batch_generator(
                    self.word2id, train_context_path, train_qn_path,
                    train_ans_path, self.FLAGS.batch_size,
                    context_len=self.FLAGS.context_len,
                    question_len=self.FLAGS.question_len,
                    discard_long=True):

                # Run training iteration
                iter_tic = time.time()
                loss, global_step, param_norm, grad_norm = self.run_train_iter(
                    session, batch, summary_writer)
                iter_toc = time.time()
                iter_time = iter_toc - iter_tic

                # Update exponentially-smoothed loss.
                # Compare against None explicitly: a smoothed loss of exactly
                # 0.0 is a valid value and must not restart the smoothing.
                if exp_loss is None:  # first iter
                    exp_loss = loss
                else:
                    exp_loss = 0.99 * exp_loss + 0.01 * loss

                # Sometimes print info to screen
                if global_step % self.FLAGS.print_every == 0:
                    logging.info(
                        'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f',
                        epoch, global_step, loss, exp_loss, grad_norm,
                        param_norm, iter_time)

                # Sometimes save model
                if global_step % self.FLAGS.save_every == 0:
                    logging.info("Saving to %s...", checkpoint_path)
                    self.saver.save(session, checkpoint_path,
                                    global_step=global_step)

                # Sometimes evaluate model on dev loss, train F1/EM and dev F1/EM
                if global_step % self.FLAGS.eval_every == 0:

                    # Get loss for entire dev set and log to tensorboard
                    dev_loss = self.get_dev_loss(session, dev_context_path,
                                                 dev_qn_path, dev_ans_path)
                    logging.info("Epoch %d, Iter %d, dev loss: %f",
                                 epoch, global_step, dev_loss)
                    write_summary(dev_loss, "dev/loss", summary_writer,
                                  global_step)

                    # Get F1/EM on train set and log to tensorboard
                    train_f1, train_em = self.check_f1_em(
                        session, train_context_path, train_qn_path,
                        train_ans_path, "train", num_samples=1000)
                    logging.info(
                        "Epoch %d, Iter %d, Train F1 score: %f, Train EM score: %f",
                        epoch, global_step, train_f1, train_em)
                    write_summary(train_f1, "train/F1", summary_writer,
                                  global_step)
                    write_summary(train_em, "train/EM", summary_writer,
                                  global_step)

                    # Get F1/EM on dev set and log to tensorboard
                    dev_f1, dev_em = self.check_f1_em(
                        session, dev_context_path, dev_qn_path, dev_ans_path,
                        "dev", num_samples=0)
                    logging.info(
                        "Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f",
                        epoch, global_step, dev_f1, dev_em)
                    write_summary(dev_f1, "dev/F1", summary_writer,
                                  global_step)
                    write_summary(dev_em, "dev/EM", summary_writer,
                                  global_step)

                    # Early stopping based on dev EM. You could switch this to use F1 instead.
                    if best_dev_em is None or dev_em > best_dev_em:
                        best_dev_em = dev_em
                        logging.info("Saving to %s...", bestmodel_ckpt_path)
                        self.bestmodel_saver.save(session, bestmodel_ckpt_path,
                                                  global_step=global_step)

            epoch_toc = time.time()
            logging.info("End of epoch %i. Time for epoch: %f",
                         epoch, epoch_toc - epoch_tic)
    finally:
        # Flush and release the TensorBoard event file even when training is
        # interrupted (e.g. Ctrl-C during an unbounded num_epochs == 0 run).
        summary_writer.close()

    sys.stdout.flush()
def check_f1_em(self, session, context_path, qn_path, ans_path, dataset, num_samples=100, print_to_screen=False):
    """
    Sample from the provided (train/dev) set.
    For each sample, calculate F1 and EM score.
    Return average F1 and EM score for all samples.
    Optionally pretty-print examples.

    Note: This function is not quite the same as the F1/EM numbers you get from "official_eval" mode.
    This function uses the pre-processed version of the e.g. dev set for speed,
    whereas "official_eval" mode uses the original JSON. Therefore:
      1. official_eval takes your max F1/EM score w.r.t. the three reference answers,
        whereas this function compares to just the first answer (which is what's saved in the preprocessed data)
      2. Our preprocessed version of the dev set is missing some examples
        due to tokenization issues (see squad_preprocess.py).
        "official_eval" includes all examples.

    Inputs:
      session: TensorFlow session
      qn_path, context_path, ans_path: paths to {dev/train}.{question/context/answer} data files.
      dataset: string. Either "train" or "dev". Just for logging purposes.
      num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
      print_to_screen: if True, pretty-prints each example to screen

    Returns:
      F1 and EM: Scalars. The average across the sampled examples.
        If the data yields no examples at all, both are 0.0 and a warning
        is logged (instead of raising ZeroDivisionError).
    """
    logging.info("Calculating F1/EM for %s examples in %s set...",
                 str(num_samples) if num_samples != 0 else "all", dataset)

    f1_total = 0.
    em_total = 0.
    example_num = 0

    tic = time.time()

    # Note here we select discard_long=False because we want to sample from the entire dataset
    # That means we're truncating, rather than discarding, examples with too-long context or questions
    for batch in get_batch_generator(self.word2id, context_path, qn_path,
                                     ans_path, self.FLAGS.batch_size,
                                     context_len=self.FLAGS.context_len,
                                     question_len=self.FLAGS.question_len,
                                     discard_long=False):

        pred_start_pos, pred_end_pos = self.get_start_end_pos(session, batch)

        # Convert the start and end positions to lists length batch_size
        pred_start_pos = pred_start_pos.tolist()  # list length batch_size
        pred_end_pos = pred_end_pos.tolist()  # list length batch_size

        for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in enumerate(
                zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
            example_num += 1

            # Get the predicted answer
            # Important: batch.context_tokens contains the original words (no UNKs)
            # You need to use the original no-UNK version when measuring F1/EM
            pred_ans_tokens = batch.context_tokens[ex_idx][
                pred_ans_start:pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)

            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)

            # Calc F1/EM
            f1 = f1_score(pred_answer, true_answer)
            em = exact_match_score(pred_answer, true_answer)
            f1_total += f1
            em_total += em

            # Optionally pretty-print
            if print_to_screen:
                print_example(self.word2id, batch.context_tokens[ex_idx],
                              batch.qn_tokens[ex_idx],
                              batch.ans_span[ex_idx, 0],
                              batch.ans_span[ex_idx, 1],
                              pred_ans_start, pred_ans_end,
                              true_answer, pred_answer, f1, em)

            # Stop scanning this batch once the sample budget is used up
            if num_samples != 0 and example_num >= num_samples:
                break

        # ... and stop pulling further batches as well
        if num_samples != 0 and example_num >= num_samples:
            break

    if example_num > 0:
        f1_total /= example_num
        em_total /= example_num
    else:
        # No examples were produced, so there is nothing to average. The
        # totals are still 0.0; return them as-is rather than dividing by
        # zero. (0.0 rather than NaN so that callers comparing scores with
        # ">" keep working on later evaluations.)
        logging.warning(
            "No examples found in %s set; reporting F1/EM as 0.0", dataset)

    toc = time.time()
    logging.info(
        "Calculating F1/EM for %i examples in %s set took %.2f seconds",
        example_num, dataset, toc - tic)

    return f1_total, em_total
# NOTE(review): this fragment starts mid-script. `params`, `clipped_gradients`,
# `FLAGS`, `word2id`, the train_* paths and the placeholders used as feed keys
# below are all defined outside the lines visible here.

# Global norm computed over `params`
param_norm = tf.global_norm(params)

# Non-trainable step counter, handed to apply_gradients below as its
# `global_step` argument
global_step = tf.Variable(0, name="global_step", trainable=False)
opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) # you can try other optimizers
updates = opt.apply_gradients(zip(clipped_gradients, params), global_step=global_step)

# Define savers (for checkpointing) and summaries (for tensorboard)
# `saver` keeps up to FLAGS.keep checkpoints; `bestmodel_saver` keeps only one.
saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.keep)
bestmodel_saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
summaries = tf.summary.merge_all()

# In[]

# Walk over training batches and build the feed dict for each one.
# NOTE(review): within the visible lines nothing is run on the session; the
# loop only builds `input_feed` and prints debugging output. The loop body may
# continue beyond this fragment -- confirm before relying on this.
for batch in get_batch_generator(word2id, train_context_path, train_qn_path, train_ans_path, FLAGS.batch_size, context_len=FLAGS.context_len, question_len=FLAGS.question_len, discard_long=True):
    # Disabled call to the packaged training step:
    #loss, global_step, param_norm, grad_norm = run_train_iter(session, batch, summary_writer)

    # Match up our input data with the placeholders
    input_feed = {}
    print('1')  # debugging marker
    input_feed[context_ids] = batch.context_ids
    print(batch.context_ids.shape)  # debugging: shape of the context id array
    input_feed[context_mask] = batch.context_mask
    input_feed[qn_ids] = batch.qn_ids
    input_feed[qn_mask] = batch.qn_mask
    input_feed[ans_span] = batch.ans_span
    input_feed[keep_prob] = 1.0 - FLAGS.dropout # apply dropout
    input_feed[question_length] = batch.qn_length
    input_feed[document_length] = batch.context_length

    # output_feed contains the things we want to fetch.
    # (currently disabled)
    #output_feed = [updates, summaries, loss, global_step, param_norm, gradient_norm]