def get_c2q_attention(self, session, context_path, qn_path, ans_path, dataset, num_samples=0):
    """
    Collect the context-to-question (C2Q) attention distribution of each example.

    Iterates over the provided (train/dev) set in file order (random=False) and
    records one entry per example from self.get_c2q_attention_dist.

    Inputs:
      session: TensorFlow session
      context_path, qn_path, ans_path: paths to {dev/train}.{context/question/answer} data files.
      dataset: string. Either "train" or "dev". Not used by this method; kept
        for interface compatibility with the other sampling methods.
      num_samples: int. Maximum number of examples to collect. If 0, use the
        whole dataset.

    Returns:
      numpy array built from the list of per-example C2Q attention
      distributions, in dataset order.
    """
    total_c2q_attention = []
    for batch in get_batch_generator(
            self.word2id, context_path, qn_path, ans_path,
            self.FLAGS.batch_size,
            context_len=self.FLAGS.context_len,
            question_len=self.FLAGS.question_len,
            discard_long=False, random=False):
        c2q_dists = self.get_c2q_attention_dist(session, batch)
        # .tolist() gives one entry per example in the batch.
        for c2q_dist in c2q_dists.tolist():
            total_c2q_attention.append(c2q_dist)
            if num_samples != 0 and len(total_c2q_attention) >= num_samples:
                return np.asarray(total_c2q_attention)
    return np.asarray(total_c2q_attention)
def evaluate(model, word2id, FLAGS, dev_context_path, dev_qn_path, dev_ans_path):
    """
    Compute average F1 and EM over every example in the dev set.

    Inputs:
      model: object whose predict([context_ids, context_mask, qn_ids, qn_mask])
        returns (prob_start, prob_end) with one row per example.
      word2id: vocabulary mapping passed to get_batch_generator.
      FLAGS: flags object providing batch_size, context_len and question_len.
      dev_context_path, dev_qn_path, dev_ans_path: paths to the dev data files.

    Returns:
      (f1, em): average F1 and EM scores; (0.0, 0.0) if the dev set yields no
      examples.
    """
    logging.info("Calculating F1/EM for all examples in dev set...")
    f1_total = 0.
    em_total = 0.
    example_num = 0
    tic = time.time()
    for batch in get_batch_generator(word2id, dev_context_path, dev_qn_path, dev_ans_path,
                                     FLAGS.batch_size, context_len=FLAGS.context_len,
                                     question_len=FLAGS.question_len, discard_long=False):
        prob_start, prob_end = model.predict([
            batch.context_ids, batch.context_mask, batch.qn_ids, batch.qn_mask
        ])
        # Start and end are chosen independently by argmax over positions.
        pred_start_pos = np.argmax(prob_start, axis=1).tolist()
        pred_end_pos = np.argmax(prob_end, axis=1).tolist()

        for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in enumerate(
                zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
            example_num += 1
            # Important: batch.context_tokens contains the original words (no UNKs)
            # You need to use the original no-UNK version when measuring F1/EM
            pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start:pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)
            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)
            f1_total += f1_score(pred_answer, true_answer)
            em_total += exact_match_score(pred_answer, true_answer)

    toc = time.time()
    # Guard against an empty dev set instead of raising ZeroDivisionError.
    if example_num == 0:
        logging.warning("No examples found in dev set; returning F1/EM of 0.")
        return 0., 0.

    f1_total /= example_num
    em_total /= example_num
    print("Calculating F1/EM for %i examples in %s set took %.2f seconds" % (example_num, "dev", toc - tic))
    return f1_total, em_total
def get_spans(self, session, context_path, qn_path, ans_path, dataset, num_samples=0):
    """
    Collect predicted start/end distributions and per-example F1/EM scores.

    Iterates over the provided (train/dev) set in file order (random=False).

    Inputs:
      session: TensorFlow session
      context_path, qn_path, ans_path: paths to {dev/train}.{context/question/answer} data files.
      dataset: string. Either "train" or "dev". Not used by this method.
      num_samples: int. Maximum number of examples to process. If 0, use the
        whole dataset.

    Returns:
      (start_dists, end_dists, f1_em_scores): numpy arrays with one entry per
      processed example, aligned by index; f1_em_scores[i] is (f1, em).
    """
    total_start_dists = []
    total_end_dists = []
    f1_em_scores = []
    for batch in get_batch_generator(
            self.word2id, context_path, qn_path, ans_path,
            self.FLAGS.batch_size,
            context_len=self.FLAGS.context_len,
            question_len=self.FLAGS.question_len,
            discard_long=False, random=False):
        pred_start_dists, pred_end_dists = self.get_prob_dists(session, batch)
        pred_start_pos, pred_end_pos = self.get_start_end_pos(session, batch)

        # Convert positions and distributions to lists of length batch_size
        pred_start_pos = pred_start_pos.tolist()
        pred_end_pos = pred_end_pos.tolist()
        pred_start_dists = pred_start_dists.tolist()
        pred_end_dists = pred_end_dists.tolist()

        for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in enumerate(
                zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
            # Important: batch.context_tokens contains the original words (no UNKs)
            # You need to use the original no-UNK version when measuring F1/EM
            pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start:pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)
            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)

            f1 = f1_score(pred_answer, true_answer)
            em = exact_match_score(pred_answer, true_answer)
            f1_em_scores.append((f1, em))
            # Record distributions per example (not per batch) so all three
            # outputs stay aligned when num_samples stops mid-batch.
            total_start_dists.append(pred_start_dists[ex_idx])
            total_end_dists.append(pred_end_dists[ex_idx])

            if num_samples != 0 and len(f1_em_scores) >= num_samples:
                return (np.asarray(total_start_dists), np.asarray(total_end_dists),
                        np.asarray(f1_em_scores))

    return np.asarray(total_start_dists), np.asarray(total_end_dists), np.asarray(f1_em_scores)
def get_dev_loss(self, session, dev_context_path, dev_qn_path, dev_ans_path):
    """
    Get the average loss over the entire dev set.

    Inputs:
      session: TensorFlow session
      dev_context_path, dev_qn_path, dev_ans_path: paths to the dev data files.

    Outputs:
      dev_loss: float. Average loss weighted by example count, or NaN (with a
        warning) if no dev examples were produced.
    """
    logging.info("Calculating dev loss...")
    tic = time.time()
    loss_per_batch, batch_lengths = [], []

    # discard_long=True: examples longer than context_len/question_len are skipped.
    for batch in get_batch_generator(self.word2id, dev_context_path, dev_qn_path, dev_ans_path,
                                     self.FLAGS.batch_size, context_len=self.FLAGS.context_len,
                                     question_len=self.FLAGS.question_len, discard_long=True):
        loss = self.get_loss(session, batch)
        curr_batch_size = batch.batch_size
        # Weight each batch by its size so batches of different sizes
        # contribute proportionally to the overall mean.
        loss_per_batch.append(loss * curr_batch_size)
        batch_lengths.append(curr_batch_size)

    total_num_examples = sum(batch_lengths)
    toc = time.time()
    print("Computed dev loss over %i examples in %.2f seconds" % (total_num_examples, toc-tic))

    # Every example may have been discarded; avoid dividing by zero.
    if total_num_examples == 0:
        logging.warning("No dev examples were evaluated; returning NaN dev loss.")
        return float('nan')

    # Overall loss is total loss divided by total number of examples
    return sum(loss_per_batch) / float(total_num_examples)
def get_dev_loss(self, session, dev_context_path, dev_qn_path, dev_ans_path):
    """
    Get loss for entire dev set.

    Inputs:
      session: TensorFlow session
      dev_qn_path, dev_context_path, dev_ans_path: paths to the dev.{context/question/answer} data files

    Outputs:
      dev_loss: float. Average loss across the dev set, weighted by example
        count. NaN (with a warning) if no dev examples were produced.
    """
    logging.info("Calculating dev loss...")
    tic = time.time()
    loss_per_batch, batch_lengths = [], []

    # Iterate over dev set batches
    # Note: here we set discard_long=True, meaning we discard any examples
    # which are longer than our context_len or question_len.
    # We need to do this because if, for example, the true answer is cut
    # off the context, then the loss function is undefined.
    for batch in get_batch_generator(self.word2id, dev_context_path, dev_qn_path, dev_ans_path,
                                     self.FLAGS.batch_size, context_len=self.FLAGS.context_len,
                                     question_len=self.FLAGS.question_len, discard_long=True):
        # Get loss for this batch
        loss = self.get_loss(session, batch)
        curr_batch_size = batch.batch_size
        loss_per_batch.append(loss * curr_batch_size)
        batch_lengths.append(curr_batch_size)

    # Calculate average loss
    total_num_examples = sum(batch_lengths)
    toc = time.time()
    print("Computed dev loss over %i examples in %.2f seconds" % (total_num_examples, toc-tic))

    # All examples may have been discarded as too long; avoid dividing by zero.
    if total_num_examples == 0:
        logging.warning("No dev examples were evaluated; returning NaN dev loss.")
        return float('nan')

    # Overall loss is total loss divided by total number of examples
    dev_loss = sum(loss_per_batch) / float(total_num_examples)
    return dev_loss
def check_f1_em(self, model, dataset, num_samples=100, print_to_screen=False):
    """
    Sample from the train or dev set and compute average F1 and EM scores.

    Inputs:
      model: model object passed through to self.test_one_batch.
      dataset: "train" or "dev"; selects which data paths stored on self are used.
      num_samples: int. How many examples to score. If 0, score the whole set.
      print_to_screen: if True, pretty-print each scored example.

    Returns:
      (f1, em): averages over the scored examples; (0.0, 0.0) if none were scored.

    Raises:
      ValueError: if dataset is neither "train" nor "dev".
    """
    logging.info("Calculating F1/EM for %s examples in %s set..." % (str(num_samples) if num_samples != 0 else "all", dataset))

    if dataset == "train":
        context_path, qn_path, ans_path = self.train_context_path, self.train_qn_path, self.train_ans_path
    elif dataset == "dev":
        context_path, qn_path, ans_path = self.dev_context_path, self.dev_qn_path, self.dev_ans_path
    else:
        # Must raise an exception instance: raising a plain str is itself a TypeError.
        raise ValueError('dataset is not defined: %r' % (dataset,))

    f1_total = 0.
    em_total = 0.
    example_num = 0

    tic = time.time()

    for batch in get_batch_generator(self.word2id, context_path, qn_path, ans_path, config.batch_size,
                                     context_len=config.context_len, question_len=config.question_len,
                                     discard_long=False):

        pred_start_pos, pred_end_pos = self.test_one_batch(batch, model)

        pred_start_pos = pred_start_pos.tolist()
        pred_end_pos = pred_end_pos.tolist()

        for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) \
                in enumerate(zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
            example_num += 1
            pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start : pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)

            true_answer = " ".join(true_ans_tokens)

            f1 = f1_score(pred_answer, true_answer)
            em = exact_match_score(pred_answer, true_answer)
            f1_total += f1
            em_total += em

            if print_to_screen:
                print_example(self.word2id, batch.context_tokens[ex_idx], batch.qn_tokens[ex_idx],
                              batch.ans_span[ex_idx, 0], batch.ans_span[ex_idx, 1],
                              pred_ans_start, pred_ans_end, true_answer, pred_answer, f1, em)

            if num_samples != 0 and example_num >= num_samples:
                break

        if num_samples != 0 and example_num >= num_samples:
            break

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    if example_num == 0:
        logging.warning("No examples found in %s set; returning F1/EM of 0." % dataset)
        return 0., 0.

    f1_total /= example_num
    em_total /= example_num

    toc = time.time()
    logging.info("Calculating F1/EM for %i examples in %s set took %.2f seconds" % (example_num, dataset, toc-tic))

    return f1_total, em_total
def check_f1_em(self, context_path, qn_path, ans_path, dataset, num_samples=1000,
                context_len=300, question_len=30):
    """
    Compute average F1 and EM over (a sample of) the given data files.

    Inputs:
      context_path, qn_path, ans_path: paths to the context/question/answer data files.
      dataset: string naming the set. Not used by this method.
      num_samples: int. How many examples to score. If 0, score the whole set.
      context_len, question_len: maximum lengths passed to the batch
        generator (defaults 300 and 30).

    Returns:
      (f1, em): averages over the scored examples; (0.0, 0.0) if none were scored.
    """
    f1_total = 0.
    em_total = 0.
    example_num = 0
    for batch in data_batcher.get_batch_generator(self.word2id, self.id2idf, context_path, qn_path, ans_path,
                                                  self.batch_size, context_len=context_len,
                                                  question_len=question_len, discard_long=False):
        pred_start_pos, pred_end_pos = self.get_predictions(batch)

        # Convert the start and end positions to lists length batch_size
        pred_start_pos = pred_start_pos.tolist()  # list length batch_size
        pred_end_pos = pred_end_pos.tolist()  # list length batch_size

        for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in \
                enumerate(zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
            example_num += 1

            # Important: batch.context_tokens contains the original words (no UNKs)
            # You need to use the original no-UNK version when measuring F1/EM
            pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start:pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)

            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)

            f1_total += f1_score(pred_answer, true_answer)
            em_total += exact_match_score(pred_answer, true_answer)

            if num_samples != 0 and example_num >= num_samples:
                break

        if num_samples != 0 and example_num >= num_samples:
            break

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    if example_num == 0:
        return 0., 0.

    f1_total /= example_num
    em_total /= example_num
    return f1_total, em_total
def check_f1_em(self, session, context_path, qn_path, ans_path, dataset, num_samples=100, print_to_screen=False):
    """
    Sample from the provided (train/dev) set.
    For each sample, calculate F1 and EM score.
    Return average F1 and EM score for all samples.
    Optionally pretty-print examples.

    Inputs:
      session: TensorFlow session
      context_path, qn_path, ans_path: paths to the data files.
      dataset: string used only in log messages.
      num_samples: int. How many examples to score. If 0, score the whole set.
      print_to_screen: if True, pretty-print each scored example.

    Returns:
      (f1, em): averages over the scored examples; (0.0, 0.0) if none were scored.
    """
    logging.info("Calculating F1/EM for %s examples in %s set..." % (str(num_samples) if num_samples != 0 else "all", dataset))

    f1_total = 0.
    em_total = 0.
    example_num = 0

    tic = time.time()

    for batch in get_batch_generator(self.word2id, context_path, qn_path, ans_path, self.FLAGS.batch_size,
                                     context_len=self.FLAGS.context_len, question_len=self.FLAGS.question_len,
                                     discard_long=False):

        pred_start_pos, pred_end_pos = self.get_start_end_pos(session, batch)

        for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in enumerate(zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
            example_num += 1

            # Get the predicted answer
            pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start : pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)

            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)

            # Calc F1/EM
            f1 = f1_score(pred_answer, true_answer)
            em = exact_match_score(pred_answer, true_answer)
            f1_total += f1
            em_total += em

            # Optionally pretty-print
            if print_to_screen:
                print_example(self.word2id, batch.context_tokens[ex_idx], batch.qn_tokens[ex_idx],
                              batch.ans_span[ex_idx, 0], batch.ans_span[ex_idx, 1],
                              pred_ans_start, pred_ans_end, true_answer, pred_answer, f1, em)

            if num_samples != 0 and example_num >= num_samples:
                break

        if num_samples != 0 and example_num >= num_samples:
            break

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    if example_num == 0:
        logging.warning("No examples found in %s set; returning F1/EM of 0." % dataset)
        return 0., 0.

    f1_total /= example_num
    em_total /= example_num

    toc = time.time()
    logging.info("Calculating F1/EM for %i examples in %s set took %.2f seconds" % (example_num, dataset, toc-tic))

    return f1_total, em_total
def get_val_loss(self, session):
    '''
    Get average loss on the entire val set.
    This function is called periodically during training.

    Inputs:
      session: TensorFlow session

    Returns:
      float. Average loss weighted by example count, or NaN (with a warning)
      if the validation set yields no examples.
    '''
    total_loss, num_examples = 0., 0
    tic = time.time()
    # NOTE(review): the generator is driven in 'train' mode here; presumably
    # that mode supplies the target captions needed to compute a loss -- confirm.
    for batch in get_batch_generator(self.word2id, self.img_features_map, self.val_caption_id_2_caption,
                                     self.caption_id_2_img_id, self.FLAGS.batch_size,
                                     self.FLAGS.max_caption_len, 'train', None, self.FLAGS.data_source):
        total_loss += self.get_loss(session, batch) * batch.batch_size
        num_examples += batch.batch_size
    logging.info("Computing validation loss over {} examples took {} seconds".format(num_examples, time.time() - tic))

    # Guard against an empty validation set instead of raising ZeroDivisionError.
    if num_examples == 0:
        logging.warning("No validation examples were evaluated; returning NaN loss.")
        return float('nan')
    return total_loss / num_examples
def get_dev_loss(self, session, dev_context_path, dev_qn_path, dev_ans_path, max_batches=11):
    """
    Get the average loss over the leading batches of the dev set.

    Inputs:
      session: TensorFlow session
      dev_qn_path, dev_context_path, dev_ans_path: paths to the dev.{context/question/answer} data files
      max_batches: int or None. Stop after this many batches (default 11);
        None evaluates the entire dev set.
        NOTE(review): a cap of 10 batches may have been intended -- confirm.

    Outputs:
      dev_loss: float. Average loss weighted by example count, or NaN (with a
        warning) if no dev examples were produced.

    Raises:
      ValueError: if max_batches is not None and is less than 1.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError("max_batches must be a positive integer or None, got %r" % (max_batches,))

    logging.info("Calculating dev loss...")
    tic = time.time()
    loss_per_batch, batch_lengths = [], []

    batches = get_batch_generator(
        self.word2id, self.context2id, self.ans2id, dev_context_path,
        dev_qn_path, dev_ans_path, self.FLAGS.batch_size,
        self.graph_vocab_class,
        context_len=self.FLAGS.context_len,
        question_len=self.FLAGS.question_len,
        answer_len=self.FLAGS.answer_len,
        discard_long=False,
        use_raw_graph=self.FLAGS.use_raw_graph,
        show_start_tokens=self.FLAGS.show_start_tokens)
    for batch_idx, batch in enumerate(batches):
        loss = self.get_loss(session, batch)
        curr_batch_size = batch.batch_size
        loss_per_batch.append(loss * curr_batch_size)
        batch_lengths.append(curr_batch_size)
        # Stop right after the last requested batch so no extra batch is built.
        if max_batches is not None and batch_idx + 1 >= max_batches:
            break

    # Calculate average loss
    total_num_examples = sum(batch_lengths)
    toc = time.time()
    print("Computed dev loss over %i examples in %.2f seconds" % (total_num_examples, toc - tic))

    if total_num_examples == 0:
        logging.warning("No dev examples were evaluated; returning NaN dev loss.")
        return float('nan')

    # Overall loss is total loss divided by total number of examples
    dev_loss = sum(loss_per_batch) / float(total_num_examples)
    return dev_loss
def train(self, session, train_context_path, train_qn_path, train_ans_path, dev_qn_path, dev_context_path, dev_ans_path,
          summary_dir="/Users/lam/Desktop/Lam-cs224n/Projects/qa/squad"):
    """
    Run a single training iteration on the first training batch.

    The batch is stored on self.sample_batch before the iteration runs. The
    dev_* paths are accepted for interface compatibility but are not used.

    Inputs:
      session: TensorFlow session
      {train/dev}_{qn/context/ans}_path: paths to the data files.
      summary_dir: directory for the TensorBoard FileWriter. The default is a
        machine-specific absolute path; pass a directory valid on the current
        machine.
    """
    summary_writer = tf.summary.FileWriter(summary_dir, session.graph)
    try:
        batches = get_batch_generator(self.word2id, self.char2id, train_context_path, train_qn_path,
                                      train_ans_path, self.FLAGS.batch_size, self.FLAGS.context_len,
                                      self.FLAGS.question_len, self.FLAGS.max_word_len, discard_long=True)
        # Only the first batch is used; None when the training data is empty.
        batch = next(iter(batches), None)
        if batch is not None:
            self.sample_batch = batch
            self.run_train_iter(session, batch, summary_writer)
    finally:
        # Flush pending events and release the event file.
        summary_writer.close()
def check_metric(self, session, mode='val', num_samples=0):
    '''
    Evaluate generated captions on the validation or test set with the
    official COCO caption evaluation API.

    Inputs:
      session: TensorFlow session
      mode: should be either 'val' or 'test'; selects which caption map to
        generate captions for.
      num_samples: number of images to evaluate on. Generation stops after the
        first batch that brings the count to at least this value. Evaluate on
        all images if 0.

    Returns:
      dict {metric_name: metric_score}, e.g. Bleu_1, Bleu_2, Bleu_3, Bleu_4,
      METEOR, ROUGE_L, CIDEr.

    Raises:
      ValueError: if mode is not 'val' or 'test'.
    '''
    # Explicit check rather than assert, which is stripped under python -O.
    if mode not in ('val', 'test'):
        raise ValueError("mode should be either 'val' or 'test', got %r" % (mode,))

    captions = []  # [{"image_id": image_id, "caption": caption_str}]

    # Generate all the captions and save in list 'captions'
    tic = time.time()
    num_seen = 0  # Record the number of samples predicted so far
    this_caption_map = self.val_caption_id_2_caption if mode == 'val' else self.test_caption_id_2_caption
    for batch in get_batch_generator(self.word2id, self.img_features_map, this_caption_map,
                                     self.caption_id_2_img_id, self.FLAGS.batch_size,
                                     self.FLAGS.max_caption_len, 'eval', None, self.FLAGS.data_source):
        batch_captions = self.get_captions(session, batch)  # {image_id: caption_string}
        captions.extend({"image_id": image_id, "caption": cap}
                        for image_id, cap in batch_captions.items())
        num_seen += batch.batch_size
        if num_samples != 0 and num_seen >= num_samples:
            break
    logging.info("Predicting on {} examples took {} seconds".format(num_seen, time.time() - tic))

    # Dump the generated captions to json file; `with` ensures the file is
    # closed (and flushed) before it is read back by loadRes below.
    with open(self.FLAGS.train_res_dir, 'w') as res_file:
        json.dump(captions, res_file)

    # Evaluate using the official evaluation API (The evaluation takes ~12s for 1000 examples)
    tic = time.time()
    # NOTE(review): gold annotations are loaded from goldAnn_val_dir even when
    # mode == 'test' -- confirm this is intended.
    cocoGold = COCO(self.FLAGS.goldAnn_val_dir)  # Official annotations
    cocoRes = cocoGold.loadRes(self.FLAGS.train_res_dir)  # Prediction
    cocoEval = COCOEvalCap(cocoGold, cocoRes)
    # Evaluate only on the images that have generated captions
    cocoEval.params['image_id'] = cocoRes.getImgIds()
    cocoEval.evaluate()
    logging.info("Evaluating {} predictions took {} seconds".format(num_seen, time.time() - tic))
    scores = cocoEval.eval  # {metric_name: metric_score}
    return scores
def get_dev_loss(self, model, max_batches=10):
    """
    Compute the average loss of `model` over the leading batches of the dev set.

    Inputs:
      model: model passed through to self.eval_one_batch.
      max_batches: int or None. Number of dev batches to evaluate (default 10);
        None evaluates the entire dev set.

    Returns:
      float. Average loss weighted by example count, or NaN (with a warning)
      if no dev examples were produced.

    Raises:
      ValueError: if max_batches is not None and is less than 1.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError("max_batches must be a positive integer or None, got %r" % (max_batches,))

    logging.info("Calculating dev loss...")
    tic = time.time()
    loss_per_batch, batch_lengths = [], []
    # discard_long=True: examples longer than context_len/question_len are skipped.
    batches = get_batch_generator(self.word2id, self.dev_context_path, self.dev_qn_path, self.dev_ans_path,
                                  config.batch_size, context_len=config.context_len,
                                  question_len=config.question_len, discard_long=True)
    for batch_idx, batch in enumerate(batches):
        loss, _, _ = self.eval_one_batch(batch, model)
        curr_batch_size = batch.batch_size
        loss_per_batch.append(loss * curr_batch_size)
        batch_lengths.append(curr_batch_size)
        # Stop right after the last requested batch so no extra batch is built.
        if max_batches is not None and batch_idx + 1 >= max_batches:
            break

    total_num_examples = sum(batch_lengths)
    toc = time.time()
    print("Computed dev loss over %i examples in %.2f seconds" % (total_num_examples, toc-tic))

    if total_num_examples == 0:
        logging.warning("No dev examples were evaluated; returning NaN dev loss.")
        return float('nan')

    dev_loss = sum(loss_per_batch) / float(total_num_examples)
    return dev_loss
def get_dev_loss(self, session, dev_context_path, dev_qn_path, dev_ans_path):
    """
    Get loss for entire dev set.

    Inputs:
      session: TensorFlow session
      dev_qn_path, dev_context_path, dev_ans_path: paths to the dev.{context/question/answer} data files

    Outputs:
      dev_loss: float. Average loss across the dev set, weighted by example
        count; NaN (with a warning) if no dev examples were produced.
    """
    logging.info("Calculating dev loss...")
    tic = time.time()
    loss_per_batch, batch_lengths = [], []

    # Iterate over dev set batches
    # Note: here we set discard_long=True, meaning we discard any examples
    # which are longer than our context_len or question_len.
    # We need to do this because if, for example, the true answer is cut
    # off the context, then the loss function is undefined.
    for batch in get_batch_generator(self.word2id, dev_context_path, dev_qn_path, dev_ans_path,
                                     self.FLAGS.batch_size, context_len=self.FLAGS.context_len,
                                     question_len=self.FLAGS.question_len, discard_long=True):
        # Get loss for this batch
        loss = self.get_loss(session, batch)
        curr_batch_size = batch.batch_size
        loss_per_batch.append(loss * curr_batch_size)
        batch_lengths.append(curr_batch_size)

    # Calculate average loss
    total_num_examples = sum(batch_lengths)
    toc = time.time()
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print("Computed dev loss over %i examples in %.2f seconds" % (total_num_examples, toc-tic))

    # All examples may have been discarded as too long; avoid dividing by zero.
    if total_num_examples == 0:
        logging.warning("No dev examples were evaluated; returning NaN dev loss.")
        return float('nan')

    # Overall loss is total loss divided by total number of examples
    dev_loss = sum(loss_per_batch) / float(total_num_examples)
    return dev_loss
def demo(self, session, context_path, qn_path, ans_path, dataset, num_samples=10, print_to_screen=False,
         write_out=False, file_out=None, shuffle=True):
    """
    Run the model on (a sample of) the provided (train/dev) set, optionally
    pretty-printing each example and/or writing the predictions to a file.

    Inputs:
      session: TensorFlow session
      qn_path, context_path, ans_path: paths to {dev/train}.{question/context/answer} data files.
      dataset: string. Either "train" or "dev". Just for logging purposes.
      num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
      print_to_screen: if True, pretty-prints each example to screen and draws
        its attention map.
      write_out: if True, write one predicted answer per line to file_out.
      file_out: output path; required when write_out is True.
      shuffle: passed through to get_batch_generator.

    Returns:
      None.

    Raises:
      ValueError: if write_out is True but file_out is None (checked before
        any work is done).
    """
    if write_out and file_out is None:
        raise ValueError("file_out must be provided when write_out is True")

    logging.info(
        "Calculating F1/EM for %s examples in %s set..." %
        (str(num_samples) if num_samples != 0 else "all", dataset))

    example_num = 0
    tic = time.time()
    ans_list = []

    for batch in get_batch_generator(
            self.word2id, self.context2id, self.ans2id, context_path,
            qn_path, ans_path, self.FLAGS.batch_size,
            self.graph_vocab_class,
            context_len=self.FLAGS.context_len,
            question_len=self.FLAGS.question_len,
            answer_len=self.FLAGS.answer_len,
            discard_long=False,
            use_raw_graph=self.FLAGS.use_raw_graph,
            shuffle=shuffle,
            show_start_tokens=self.FLAGS.show_start_tokens,
            output_goal=True):

        # The first output is not used here; presumably it is the training-network
        # output (true input fed to the next RNN step) kept for debugging.
        _, pred_ids, dev_final_states, pred_logits = self.get_prob_dists(session, batch)
        start_ids = batch.ans_ids[:, 0].reshape(-1)

        # Only output_route (non-beam path) produces confidence scores.
        confidence_score = None
        if self.FLAGS.pred_method != 'beam':
            pred_ids, confidence_score, _ = output_route(
                start_ids, pred_logits, batch.context_tokens, self.ans2id,
                self.id2ans, self.FLAGS.answer_len)
            pred_ids = pred_ids.tolist()

        # the output of using test network
        # NOTE(review): if create_attention_images_summary builds new graph ops,
        # calling it once per batch grows the graph -- consider building it once.
        dev_attention_map = create_attention_images_summary(dev_final_states)
        print("dev_attention_map %s" % (dev_attention_map.shape,))
        dev_attention_map = dev_attention_map.eval().tolist()

        for ex_idx, (pred_ans_list, true_ans_tokens, attention_map) in enumerate(
                zip(pred_ids, list(batch.ans_tokens), dev_attention_map)):
            example_num += 1

            # Predicted answer: vocabulary tokens up to the first PAD_ID.
            pred_ans_tokens = []
            for token_id in pred_ans_list:
                if token_id == PAD_ID:
                    break
                pred_ans_tokens.append(self.id2ans[token_id])
            pred_answer = " ".join(pred_ans_tokens)

            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)

            # Calculate metrics
            f1, em, edit_dist, _ = compute_all_metrics(pred_ans_tokens, true_ans_tokens)

            ans_list.append(pred_answer)

            if print_to_screen:
                # NOTE(review): the score is None on the beam path; confirm
                # print_example accepts a None score.
                score = confidence_score[ex_idx] if confidence_score is not None else None
                print_example(self.word2id, self.context2id, self.ans2id,
                              batch.context_tokens[ex_idx], batch.qn_tokens[ex_idx],
                              true_answer, pred_answer, f1, em, edit_dist, score)
                # Draw attention map
                # NOTE(review): confirm attention maps should only be drawn when printing.
                draw_attention(batch, ex_idx, attention_map, pred_ans_tokens)

            if num_samples != 0 and example_num >= num_samples:
                break

        if num_samples != 0 and example_num >= num_samples:
            break

    toc = time.time()
    logging.info(
        "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
        (example_num, dataset, toc - tic))

    if write_out:
        logging.info("Writing the prediction to {}".format(file_out))
        with open(file_out, 'w') as f:
            # One predicted answer per line, in processing order.
            for line in ans_list:
                f.write(line + '\n')
        print("Wrote predictions to %s" % file_out)
def train(self, session, train_context_path, train_qn_path, train_ans_path, dev_qn_path, dev_context_path, dev_ans_path):
    """
    Main training loop.

    Trains for self.FLAGS.num_epochs epochs (indefinitely if 0). Every
    print_every / save_every / eval_every global steps it logs progress,
    saves the latest checkpoint, or evaluates dev loss and train/dev F1/EM.
    The checkpoint with the best dev F1 so far is saved separately.

    Inputs:
      session: TensorFlow session
      {train/dev}_{qn/context/ans}_path: paths to {train/dev}.{context/question/answer} data files
    """
    # Print number of model parameters
    tic = time.time()
    params = tf.trainable_variables()
    num_params = sum(map(lambda t: np.prod(tf.shape(t.value()).eval()), params))
    toc = time.time()
    logging.info("Number of params: %d (retrieval took %f secs)" % (num_params, toc - tic))

    # Exponentially-smoothed training loss; None until the first iteration.
    exp_loss = None

    # Checkpoint management: one latest checkpoint plus one best checkpoint.
    checkpoint_path = os.path.join(self.FLAGS.train_dir, "qa.ckpt")
    bestmodel_dir = os.path.join(self.FLAGS.train_dir, "best_checkpoint")
    bestmodel_ckpt_path = os.path.join(bestmodel_dir, "qa_best.ckpt")
    best_dev_f1 = None

    # for TensorBoard
    summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir, session.graph)

    epoch = 0

    logging.info("Beginning training loop...")
    while self.FLAGS.num_epochs == 0 or epoch < self.FLAGS.num_epochs:
        epoch += 1
        epoch_tic = time.time()

        for batch in get_batch_generator(self.word2id, train_context_path, train_qn_path, train_ans_path,
                                         self.FLAGS.batch_size, context_len=self.FLAGS.context_len,
                                         question_len=self.FLAGS.question_len, discard_long=True):

            iter_tic = time.time()
            loss, global_step, param_norm, grad_norm = self.run_train_iter(session, batch, summary_writer)
            iter_toc = time.time()
            iter_time = iter_toc - iter_tic

            # Compare with None (not truthiness) so a loss of exactly 0.0 does
            # not reset the running average.
            if exp_loss is None:  # first iter
                exp_loss = loss
            else:
                exp_loss = 0.99 * exp_loss + 0.01 * loss

            if global_step % self.FLAGS.print_every == 0:
                logging.info(
                    'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f' %
                    (epoch, global_step, loss, exp_loss, grad_norm, param_norm, iter_time))

            if global_step % self.FLAGS.save_every == 0:
                logging.info("Saving to %s..." % checkpoint_path)
                self.saver.save(session, checkpoint_path, global_step=global_step)

            if global_step % self.FLAGS.eval_every == 0:

                dev_loss = self.get_dev_loss(session, dev_context_path, dev_qn_path, dev_ans_path)
                logging.info("Epoch %d, Iter %d, dev loss: %f" % (epoch, global_step, dev_loss))
                write_summary(dev_loss, "dev/loss", summary_writer, global_step)

                train_f1, train_em = self.check_f1_em(session, train_context_path, train_qn_path,
                                                      train_ans_path, "train", num_samples=1000)
                logging.info("Epoch %d, Iter %d, Train F1 score: %f, Train EM score: %f" % (epoch, global_step, train_f1, train_em))
                write_summary(train_f1, "train/F1", summary_writer, global_step)
                write_summary(train_em, "train/EM", summary_writer, global_step)

                dev_f1, dev_em = self.check_f1_em(session, dev_context_path, dev_qn_path, dev_ans_path,
                                                  "dev", num_samples=0)
                logging.info("Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f" % (epoch, global_step, dev_f1, dev_em))
                write_summary(dev_f1, "dev/F1", summary_writer, global_step)
                write_summary(dev_em, "dev/EM", summary_writer, global_step)

                # Keep the checkpoint with the best dev F1 so far.
                if best_dev_f1 is None or dev_f1 > best_dev_f1:
                    best_dev_f1 = dev_f1
                    logging.info("Saving to %s..." % bestmodel_ckpt_path)
                    self.bestmodel_saver.save(session, bestmodel_ckpt_path, global_step=global_step)

        epoch_toc = time.time()
        logging.info("End of epoch %i. Time for epoch: %f" % (epoch, epoch_toc-epoch_tic))

    sys.stdout.flush()
def check_f1_em(self, session, context_path, qn_path, ans_path, dataset, num_samples=100, print_to_screen=False):
    """
    Sample from the provided (train/dev) set.
    For each sample, calculate F1 and EM score.
    Return average F1 and EM score for all samples.
    Optionally pretty-print examples.

    Note: This function is not quite the same as the F1/EM numbers you get from "official_eval" mode.
    This function uses the pre-processed version of the e.g. dev set for speed,
    whereas "official_eval" mode uses the original JSON. Therefore:
      1. official_eval takes your max F1/EM score w.r.t. the three reference answers,
        whereas this function compares to just the first answer (which is what's saved in the preprocessed data)
      2. Our preprocessed version of the dev set is missing some examples
        due to tokenization issues (see squad_preprocess.py).
        "official_eval" includes all examples.

    Inputs:
      session: TensorFlow session
      qn_path, context_path, ans_path: paths to {dev/train}.{question/context/answer} data files.
      dataset: string. Either "train" or "dev". Just for logging purposes.
      num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
      print_to_screen: if True, pretty-prints each example to screen

    Returns:
      F1 and EM: Scalars. The average across the sampled examples; both 0.0
      (with a warning) if no examples were found.
    """
    logging.info("Calculating F1/EM for %s examples in %s set..." % (str(num_samples) if num_samples != 0 else "all", dataset))

    f1_total = 0.
    em_total = 0.
    example_num = 0

    tic = time.time()

    # Note here we select discard_long=False because we want to sample from the entire dataset
    # That means we're truncating, rather than discarding, examples with too-long context or questions
    for batch in get_batch_generator(self.word2id, context_path, qn_path, ans_path, self.FLAGS.batch_size,
                                     context_len=self.FLAGS.context_len, question_len=self.FLAGS.question_len,
                                     discard_long=False):

        pred_start_pos, pred_end_pos = self.get_start_end_pos(session, batch)

        # Convert the start and end positions to lists length batch_size
        pred_start_pos = pred_start_pos.tolist()  # list length batch_size
        pred_end_pos = pred_end_pos.tolist()  # list length batch_size

        for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in enumerate(zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
            example_num += 1

            # Get the predicted answer
            # Important: batch.context_tokens contains the original words (no UNKs)
            # You need to use the original no-UNK version when measuring F1/EM
            pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start : pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)

            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)

            # Calc F1/EM
            f1 = f1_score(pred_answer, true_answer)
            em = exact_match_score(pred_answer, true_answer)
            f1_total += f1
            em_total += em

            # Optionally pretty-print
            if print_to_screen:
                print_example(self.word2id, batch.context_tokens[ex_idx], batch.qn_tokens[ex_idx],
                              batch.ans_span[ex_idx, 0], batch.ans_span[ex_idx, 1],
                              pred_ans_start, pred_ans_end, true_answer, pred_answer, f1, em)

            if num_samples != 0 and example_num >= num_samples:
                break

        if num_samples != 0 and example_num >= num_samples:
            break

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    if example_num == 0:
        logging.warning("No examples found in %s set; returning F1/EM of 0." % dataset)
        return 0., 0.

    f1_total /= example_num
    em_total /= example_num

    toc = time.time()
    logging.info("Calculating F1/EM for %i examples in %s set took %.2f seconds" % (example_num, dataset, toc-tic))

    return f1_total, em_total
def train(self, session, train_context_path, train_qn_path, train_ans_path, dev_qn_path, dev_context_path, dev_ans_path):
    """
    Main training loop.

    Trains for self.FLAGS.num_epochs epochs (indefinitely if 0). Every
    print_every / save_every / eval_every global steps it logs progress,
    saves the latest checkpoint, or evaluates dev loss and train/dev F1/EM.
    The checkpoint with the best dev EM so far is saved separately.

    Inputs:
      session: TensorFlow session
      {train/dev}_{qn/context/ans}_path: paths to {train/dev}.{context/question/answer} data files
    """

    # Print number of model parameters
    tic = time.time()
    params = tf.trainable_variables()
    num_params = sum(map(lambda t: np.prod(tf.shape(t.value()).eval()), params))
    toc = time.time()
    logging.info("Number of params: %d (retrieval took %f secs)" % (num_params, toc - tic))

    # We will keep track of exponentially-smoothed loss (None until the first iteration)
    exp_loss = None

    # Checkpoint management.
    # We keep one latest checkpoint, and one best checkpoint (early stopping)
    checkpoint_path = os.path.join(self.FLAGS.train_dir, "qa.ckpt")
    bestmodel_dir = os.path.join(self.FLAGS.train_dir, "best_checkpoint")
    bestmodel_ckpt_path = os.path.join(bestmodel_dir, "qa_best.ckpt")
    best_dev_em = None

    # for TensorBoard
    summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir, session.graph)

    epoch = 0

    logging.info("Beginning training loop...")
    while self.FLAGS.num_epochs == 0 or epoch < self.FLAGS.num_epochs:
        epoch += 1
        epoch_tic = time.time()

        # Loop over batches
        for batch in get_batch_generator(self.word2id, train_context_path, train_qn_path, train_ans_path,
                                         self.FLAGS.batch_size, context_len=self.FLAGS.context_len,
                                         question_len=self.FLAGS.question_len, discard_long=True):

            # Run training iteration
            iter_tic = time.time()
            loss, global_step, param_norm, grad_norm = self.run_train_iter(session, batch, summary_writer)
            iter_toc = time.time()
            iter_time = iter_toc - iter_tic

            # Update exponentially-smoothed loss. Compare with None (not
            # truthiness) so a loss of exactly 0.0 does not reset it.
            if exp_loss is None:  # first iter
                exp_loss = loss
            else:
                exp_loss = 0.99 * exp_loss + 0.01 * loss

            # Sometimes print info to screen
            if global_step % self.FLAGS.print_every == 0:
                logging.info(
                    'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f' %
                    (epoch, global_step, loss, exp_loss, grad_norm, param_norm, iter_time))

            # Sometimes save model
            if global_step % self.FLAGS.save_every == 0:
                logging.info("Saving to %s..." % checkpoint_path)
                self.saver.save(session, checkpoint_path, global_step=global_step)

            # Sometimes evaluate model on dev loss, train F1/EM and dev F1/EM
            if global_step % self.FLAGS.eval_every == 0:

                # Get loss for entire dev set and log to tensorboard
                dev_loss = self.get_dev_loss(session, dev_context_path, dev_qn_path, dev_ans_path)
                logging.info("Epoch %d, Iter %d, dev loss: %f" % (epoch, global_step, dev_loss))
                write_summary(dev_loss, "dev/loss", summary_writer, global_step)

                # Get F1/EM on train set and log to tensorboard
                train_f1, train_em = self.check_f1_em(session, train_context_path, train_qn_path,
                                                      train_ans_path, "train", num_samples=1000)
                logging.info("Epoch %d, Iter %d, Train F1 score: %f, Train EM score: %f" % (epoch, global_step, train_f1, train_em))
                write_summary(train_f1, "train/F1", summary_writer, global_step)
                write_summary(train_em, "train/EM", summary_writer, global_step)

                # Get F1/EM on dev set and log to tensorboard
                dev_f1, dev_em = self.check_f1_em(session, dev_context_path, dev_qn_path, dev_ans_path,
                                                  "dev", num_samples=0)
                logging.info("Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f" % (epoch, global_step, dev_f1, dev_em))
                write_summary(dev_f1, "dev/F1", summary_writer, global_step)
                write_summary(dev_em, "dev/EM", summary_writer, global_step)

                # Early stopping based on dev EM. You could switch this to use F1 instead.
                if best_dev_em is None or dev_em > best_dev_em:
                    best_dev_em = dev_em
                    logging.info("Saving to %s..." % bestmodel_ckpt_path)
                    self.bestmodel_saver.save(session, bestmodel_ckpt_path, global_step=global_step)

        epoch_toc = time.time()
        logging.info("End of epoch %i. Time for epoch: %f" % (epoch, epoch_toc-epoch_tic))

    sys.stdout.flush()
def train(self, model_file_path):
    """
    Main training loop (PyTorch).

    Creates a fresh run directory under config.log_root containing 'model'
    and 'bestmodel' sub-directories and a flags.json dump of the config. It
    then trains for config.num_epochs epochs (indefinitely when it is 0).
    Every print_every / save_every / eval_every steps it logs progress,
    saves the latest model, and evaluates dev loss plus train/dev F1/EM. The
    model is saved to 'bestmodel' whenever dev EM improves.

    Inputs:
      model_file_path: forwarded to self.get_model (presumably an optional
        checkpoint to resume from -- TODO confirm)
    """
    train_dir = os.path.join(config.log_root, 'train_%d' % (int(time.time())))
    model_dir = os.path.join(train_dir, 'model')
    bestmodel_dir = os.path.join(train_dir, 'bestmodel')
    # makedirs (unlike mkdir) also creates train_dir and a missing log_root.
    for directory in (model_dir, bestmodel_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    summary_writer = SummaryWriter(train_dir)

    with open(os.path.join(train_dir, "flags.json"), 'w') as fout:
        json.dump(vars(config), fout)

    model = self.get_model(model_file_path)
    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = Adam(params, lr=config.lr, weight_decay=config.reg_lambda, amsgrad=True)

    num_params = sum(p.numel() for p in params)
    logging.info("Number of params: %d" % num_params)

    # exp_loss stays None until the first iteration.
    exp_loss, best_dev_f1, best_dev_em = None, None, None

    epoch = 0
    global_step = 0

    logging.info("Beginning training loop...")
    try:
        while config.num_epochs == 0 or epoch < config.num_epochs:
            epoch += 1
            epoch_tic = time.time()
            for batch in get_batch_generator(self.word2id, self.train_context_path,
                                             self.train_qn_path, self.train_ans_path,
                                             config.batch_size, context_len=config.context_len,
                                             question_len=config.question_len, discard_long=True):
                global_step += 1
                iter_tic = time.time()

                loss, param_norm, grad_norm = self.train_one_batch(batch, model, optimizer, params)
                summary_writer.add_scalar("train/loss", loss, global_step)

                iter_toc = time.time()
                iter_time = iter_toc - iter_tic

                # Test for None explicitly: a loss of exactly 0.0 must not
                # restart the smoothed average.
                if exp_loss is None:
                    exp_loss = loss
                else:
                    exp_loss = 0.99 * exp_loss + 0.01 * loss

                if global_step % config.print_every == 0:
                    logging.info(
                        'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f' %
                        (epoch, global_step, loss, exp_loss, grad_norm, param_norm, iter_time))

                if global_step % config.save_every == 0:
                    logging.info("Saving to %s..." % model_dir)
                    self.save_model(model, optimizer, loss, global_step, epoch, model_dir)

                if global_step % config.eval_every == 0:
                    dev_loss = self.get_dev_loss(model)
                    logging.info("Epoch %d, Iter %d, dev loss: %f" % (epoch, global_step, dev_loss))
                    summary_writer.add_scalar("dev/loss", dev_loss, global_step)

                    train_f1, train_em = self.check_f1_em(model, "train", num_samples=1000)
                    logging.info("Epoch %d, Iter %d, Train F1 score: %f, Train EM score: %f" % (
                        epoch, global_step, train_f1, train_em))
                    summary_writer.add_scalar("train/F1", train_f1, global_step)
                    summary_writer.add_scalar("train/EM", train_em, global_step)

                    dev_f1, dev_em = self.check_f1_em(model, "dev", num_samples=0)
                    logging.info(
                        "Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f" % (epoch, global_step, dev_f1, dev_em))
                    summary_writer.add_scalar("dev/F1", dev_f1, global_step)
                    summary_writer.add_scalar("dev/EM", dev_em, global_step)

                    if best_dev_f1 is None or dev_f1 > best_dev_f1:
                        best_dev_f1 = dev_f1

                    # The best checkpoint is selected on dev EM.
                    if best_dev_em is None or dev_em > best_dev_em:
                        best_dev_em = dev_em
                        logging.info("Saving to %s..." % bestmodel_dir)
                        self.save_model(model, optimizer, loss, global_step, epoch, bestmodel_dir)

            epoch_toc = time.time()
            logging.info("End of epoch %i. Time for epoch: %f" % (epoch, epoch_toc - epoch_tic))
    finally:
        # Flush and release the TensorBoard event file, even on interruption.
        summary_writer.close()

    sys.stdout.flush()
def train(self, session):
    """
    Main training loop.

    Trains on the 'train' split for self.FLAGS.num_epochs epochs
    (indefinitely when it is 0). Every print_every / save_every / eval_every
    global steps it logs the training loss, saves the latest checkpoint, and
    evaluates val loss plus every metric returned by self.check_metric. The
    checkpoint is also saved to FLAGS.bestmodel_dir whenever the metric named
    by FLAGS.primary_metric improves (higher is treated as better).

    Inputs:
      session: TensorFlow session
    """
    # Print number of model parameters
    tic = time.time()
    params = tf.trainable_variables()
    num_params = sum(map(lambda t: np.prod(tf.shape(t.value()).eval()), params))
    toc = time.time()
    logging.info("Number of params: %d (retrieval took %f secs)" % (num_params, toc - tic))
    # Debug-level output rather than an unconditional print to stdout.
    logging.debug("Trainable params: %s", params)

    # Exponentially-smoothed loss; None until the first iteration.
    exp_loss = None

    # Checkpoint management. We keep one latest checkpoint, and one best checkpoint (early stopping)
    checkpoint_path = os.path.join(self.FLAGS.train_dir, "latest.ckpt")
    bestmodel_dir = self.FLAGS.bestmodel_dir
    bestmodel_ckpt_path = os.path.join(bestmodel_dir, "best.ckpt")
    best_val_metric = None

    # For TensorBoard
    summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir, session.graph)

    epoch = 0
    logging.info("Beginning training loop...")
    try:
        while self.FLAGS.num_epochs == 0 or epoch < self.FLAGS.num_epochs:
            epoch += 1
            epoch_tic = time.time()

            # Loop over batches
            for batch in get_batch_generator(self.word2id, self.img_features_map,
                                             self.train_caption_id_2_caption,
                                             self.caption_id_2_img_id, self.FLAGS.batch_size,
                                             self.FLAGS.max_caption_len, 'train', None,
                                             self.FLAGS.data_source):
                # Run training iteration
                iter_tic = time.time()
                loss, global_step, param_norm, grad_norm = self.run_train_iter(session, batch, summary_writer)
                iter_toc = time.time()
                iter_time = iter_toc - iter_tic

                # Test for None explicitly: a loss of exactly 0.0 must not
                # restart the smoothed average.
                if exp_loss is None:  # first iter
                    exp_loss = loss
                else:
                    exp_loss = 0.99 * exp_loss + 0.01 * loss

                # Sometimes print info to screen
                if global_step % self.FLAGS.print_every == 0:
                    logging.info(
                        'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f' %
                        (epoch, global_step, loss, exp_loss, grad_norm, param_norm, iter_time))
                    write_summary(loss, "train/loss", summary_writer, global_step)

                # Sometimes save model
                if global_step % self.FLAGS.save_every == 0:
                    logging.info("Saving to %s..." % checkpoint_path)
                    self.saver.save(session, checkpoint_path, global_step=global_step)

                # Sometimes evaluate the model
                if global_step % self.FLAGS.eval_every == 0:
                    # Get loss for entire val set and log to tensorboard
                    val_loss = self.get_val_loss(session)
                    logging.info("Epoch %d, Iter %d, Val loss: %f" % (epoch, global_step, val_loss))
                    write_summary(val_loss, "val/loss", summary_writer, global_step)

                    # Evaluate on val set and log all the metrics to tensorboard
                    val_scores = self.check_metric(session, mode='val', num_samples=0)
                    val_metric = val_scores[self.FLAGS.primary_metric]
                    for metric_name, metric_score in val_scores.items():
                        logging.info("Epoch {}, Iter {}, Val {}: {}".format(epoch, global_step, metric_name, metric_score))
                        write_summary(metric_score, "val/" + metric_name, summary_writer, global_step)

                    # Early stopping based on val evaluation
                    if best_val_metric is None or val_metric > best_val_metric:
                        best_val_metric = val_metric
                        logging.info("Saving to %s..." % bestmodel_ckpt_path)
                        self.bestmodel_saver.save(session, bestmodel_ckpt_path, global_step=global_step)

            epoch_toc = time.time()
            logging.info("End of epoch %i. Time for epoch: %f" % (epoch, epoch_toc - epoch_tic))
    finally:
        # Flush buffered TensorBoard events and release the event file.
        summary_writer.close()

    sys.stdout.flush()
def visualize_results(self, session, context_path, qn_path, ans_path, dataset,
                      num_samples=100, print_to_screen=False):
    """
    Compute F1/EM on a (train/dev) set and plot attention for inspection.

    For every batch, the first example's context-to-question attention and
    its start/end logits are plotted and saved to 'c2q_attn.png'. The file is
    overwritten for each batch, so it ends up showing the last batch
    processed. Optionally each example is pretty-printed together with the
    context words selected by the argmax of the q2c attention.

    Inputs:
      session: TensorFlow session
      context_path, qn_path, ans_path: paths to {dev/train}.{context/question/answer} data files.
      dataset: string. Either "train" or "dev". Just for logging purposes.
      num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
      print_to_screen: if True, pretty-prints each example to screen

    Returns:
      F1 and EM: Scalars. The average across the sampled examples
      (0.0 when no examples were processed).
    """
    logging.info(
        "Calculating F1/EM for %s examples in %s set..." %
        (str(num_samples) if num_samples != 0 else "all", dataset))

    f1_total = 0.
    em_total = 0.
    example_num = 0

    tic = time.time()

    # Note here we select discard_long=False because we want to sample from the entire dataset
    # That means we're truncating, rather than discarding, examples with too-long context or questions
    for batch in get_batch_generator(self.word2id, context_path, qn_path, ans_path,
                                     self.FLAGS.batch_size, context_len=self.FLAGS.context_len,
                                     question_len=self.FLAGS.question_len, discard_long=False):

        (pred_start_pos, pred_end_pos, c2q_attn, q2c_attn,
         strt_logts, end_logts) = self.get_results_vis(session, batch)

        # Convert the start and end positions to lists length batch_size
        pred_start_pos = pred_start_pos.tolist()  # list length batch_size
        pred_end_pos = pred_end_pos.tolist()  # list length batch_size
        # NOTE(review): assumes q2c_attn is (batch, context_len, question_len),
        # so argmax over axis 1 yields one context index per question token -- confirm.
        q2c_attn = np.argmax(q2c_attn, axis=1).tolist()

        # Plot the first example of the batch.
        plot_context_len = len(batch.context_tokens[0])
        plot_qn_len = len(batch.qn_tokens[0])
        fig = plt.figure()
        try:
            gs = grd.GridSpec(3, 1, height_ratios=[1, 3, 1])

            ax = plt.subplot(gs[1])
            c2q_attn_plt = c2q_attn[0, :plot_context_len, :plot_qn_len]
            ax.imshow(np.transpose(c2q_attn_plt), interpolation='nearest', aspect='auto')
            plt.ylabel('c2q attn')
            plt.xlim(0, plot_context_len)

            ax2 = plt.subplot(gs[0])
            ax2.plot(strt_logts[0, :plot_context_len])
            plt.ylabel('start logits')
            plt.xlim(0, plot_context_len)

            ax3 = plt.subplot(gs[2])
            ax3.plot(end_logts[0, :plot_context_len])
            plt.ylabel('end logits')
            plt.xlim(0, plot_context_len)

            plt.savefig('c2q_attn.png')
        finally:
            # Close (not merely clear) the figure: otherwise a new figure per
            # batch accumulates in pyplot's figure registry and leaks memory.
            plt.close(fig)

        for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens, q2c_attn_idx) in enumerate(
                zip(pred_start_pos, pred_end_pos, batch.ans_tokens, q2c_attn)):
            example_num += 1

            # Get the predicted answer
            # Important: batch.context_tokens contains the original words (no UNKs)
            # You need to use the original no-UNK version when measuring F1/EM
            pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start:pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)

            qn_attn_words = [
                batch.context_tokens[ex_idx][i]
                for i in q2c_attn_idx[:len(batch.qn_tokens[ex_idx])]
            ]
            qn_attn = " ".join(qn_attn_words)

            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)

            # Calc F1/EM
            f1 = f1_score(pred_answer, true_answer)
            em = exact_match_score(pred_answer, true_answer)
            f1_total += f1
            em_total += em

            # Optionally pretty-print
            if print_to_screen:
                print_example_attn(self.word2id, batch.context_tokens[ex_idx],
                                   batch.qn_tokens[ex_idx], qn_attn,
                                   batch.ans_span[ex_idx, 0], batch.ans_span[ex_idx, 1],
                                   pred_ans_start, pred_ans_end, true_answer,
                                   pred_answer, f1, em)

            if num_samples != 0 and example_num >= num_samples:
                break

        if num_samples != 0 and example_num >= num_samples:
            break

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    if example_num > 0:
        f1_total /= example_num
        em_total /= example_num

    toc = time.time()
    logging.info(
        "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
        (example_num, dataset, toc - tic))

    return f1_total, em_total
def check_f1_em(self, session, context_path, qn_path, ans_path, dataset,
                num_samples=100, print_to_screen=False):
    """
    Sample from the provided (train/dev) set.
    For each sample, calculate F1 and EM score.
    Return average F1 and EM score for all samples.
    Optionally pretty-print examples.

    Note: This function is not quite the same as the F1/EM numbers you get from "official_eval" mode.
    This function uses the pre-processed version of the e.g. dev set for speed,
    whereas "official_eval" mode uses the original JSON. Therefore:
      1. official_eval takes your max F1/EM score w.r.t. the three reference answers,
        whereas this function compares to just the first answer (which is what's saved in the preprocessed data)
      2. Our preprocessed version of the dev set is missing some examples
        due to tokenization issues (see squad_preprocess.py).
        "official_eval" includes all examples.

    Inputs:
      session: TensorFlow session
      qn_path, context_path, ans_path: paths to {dev/train}.{question/context/answer} data files.
      dataset: string. Either "train" or "dev". Just for logging purposes.
      num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
      print_to_screen: if True, pretty-prints each example to screen

    Returns:
      F1 and EM: Scalars. The average across the sampled examples
      (0.0 when no examples were processed).
    """
    logging.info(
        "Calculating F1/EM for %s examples in %s set..." %
        (str(num_samples) if num_samples != 0 else "all", dataset))

    f1_total = 0.
    em_total = 0.
    example_num = 0

    tic = time.time()

    # Note here we select discard_long=False because we want to sample from the entire dataset
    # That means we're truncating, rather than discarding, examples with too-long context or questions
    for batch in get_batch_generator(self.word2id, context_path, qn_path, ans_path,
                                     self.FLAGS.batch_size, context_len=self.FLAGS.context_len,
                                     question_len=self.FLAGS.question_len, discard_long=False):

        pred_start_pos, pred_end_pos = self.get_start_end_pos(session, batch)

        # Convert the start and end positions to lists length batch_size
        pred_start_pos = pred_start_pos.tolist()  # list length batch_size
        pred_end_pos = pred_end_pos.tolist()  # list length batch_size

        for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in enumerate(
                zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
            example_num += 1

            # Get the predicted answer
            # Important: batch.context_tokens contains the original words (no UNKs)
            # You need to use the original no-UNK version when measuring F1/EM
            pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start:pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)

            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)

            # Calc F1/EM
            f1 = f1_score(pred_answer, true_answer)
            em = exact_match_score(pred_answer, true_answer)
            f1_total += f1
            em_total += em

            # Optionally pretty-print
            if print_to_screen:
                print_example(self.word2id, batch.context_tokens[ex_idx],
                              batch.qn_tokens[ex_idx], batch.ans_span[ex_idx, 0],
                              batch.ans_span[ex_idx, 1], pred_ans_start, pred_ans_end,
                              true_answer, pred_answer, f1, em)

            if num_samples != 0 and example_num >= num_samples:
                break

        if num_samples != 0 and example_num >= num_samples:
            break

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    if example_num > 0:
        f1_total /= example_num
        em_total /= example_num

    toc = time.time()
    logging.info(
        "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
        (example_num, dataset, toc - tic))

    return f1_total, em_total
def check_f1_em(self, session, context_path, qn_path, ans_path, dataset,
                num_samples=100, print_to_screen=False):
    """
    Sample from the provided (train/dev) set and compute average F1 / EM.

    Each example's predicted span (from self.get_start_end_pos) is cut out of
    the original, no-UNK context tokens and compared with the true answer
    tokens. Optionally pretty-prints each example.

    Inputs:
      session: TensorFlow session
      context_path, qn_path, ans_path: paths to {dev/train}.{context/question/answer} data files.
      dataset: string. Either "train" or "dev". Just for logging purposes.
      num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
      print_to_screen: if True, pretty-prints each example to screen

    Returns:
      F1 and EM: Scalars. The average across the sampled examples
      (0.0 when no examples were processed).
    """
    logging.info(
        "Calculating F1/EM for %s examples in %s set..." %
        (str(num_samples) if num_samples != 0 else "all", dataset))

    f1_total = 0.
    em_total = 0.
    example_num = 0

    tic = time.time()

    # Note here we select discard_long=False because we want to sample from the entire dataset
    # That means we're truncating, rather than discarding, examples with too-long context or questions
    for batch in get_batch_generator(self.word2id, context_path, qn_path, ans_path,
                                     self.FLAGS.batch_size, context_len=self.FLAGS.context_len,
                                     question_len=self.FLAGS.question_len, discard_long=False):

        pred_start_pos, pred_end_pos = self.get_start_end_pos(session, batch)

        # Convert the start and end positions to lists length batch_size
        pred_start_pos = pred_start_pos.tolist()  # list length batch_size
        pred_end_pos = pred_end_pos.tolist()  # list length batch_size

        for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in enumerate(
                zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
            example_num += 1

            # Get the predicted answer
            # Important: batch.context_tokens contains the original words (no UNKs)
            # You need to use the original no-UNK version when measuring F1/EM
            pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start:pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)

            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)

            # Calc F1/EM
            f1 = f1_score(pred_answer, true_answer)
            em = exact_match_score(pred_answer, true_answer)
            f1_total += f1
            em_total += em

            # Optionally pretty-print
            if print_to_screen:
                print_example(self.word2id, batch.context_tokens[ex_idx],
                              batch.qn_tokens[ex_idx], batch.ans_span[ex_idx, 0],
                              batch.ans_span[ex_idx, 1], pred_ans_start, pred_ans_end,
                              true_answer, pred_answer, f1, em)

            if num_samples != 0 and example_num >= num_samples:
                break

        if num_samples != 0 and example_num >= num_samples:
            break

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    if example_num > 0:
        f1_total /= example_num
        em_total /= example_num

    toc = time.time()
    logging.info(
        "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
        (example_num, dataset, toc - tic))

    return f1_total, em_total
def get_error_stats(self, session, context_path, qn_path, ans_path, dataset,
                    num_samples=10, print_to_screen=False):
    """
    Break down F1/EM errors by the first token of each question.

    Samples from the provided (train/dev) set and computes F1 and EM for each
    example. Per-example results are grouped by the question's first token.
    For every first token, in ascending order of frequency, it prints:
      - the mean F1,
      - the number of exact-match misses,
      - the token's share of all sampled questions,
      - the exact-match miss rate.

    Note: like check_f1_em, this uses the pre-processed data files, so each
    prediction is compared with a single stored answer.

    Inputs:
      session: TensorFlow session
      context_path, qn_path, ans_path: paths to {dev/train}.{context/question/answer} data files.
      dataset: string. Either "train" or "dev". Just for logging purposes.
      num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
      print_to_screen: if True, pretty-prints each example to screen

    Returns:
      F1 and EM: Scalars. The average across the sampled examples
      (0.0 when no examples were processed).
    """
    logging.info(
        "Calculating Error stats for %s examples in %s set..." %
        (str(num_samples) if num_samples != 0 else "all", dataset))

    f1_total = 0.
    em_total = 0.
    example_num = 0

    tic = time.time()

    # Per-first-token accumulators: EM misses, number of questions, summed F1.
    first_token_qn_dict_wrong = defaultdict(float)
    first_token_qn_dict_total = defaultdict(float)
    first_token_qn_dict_f1 = defaultdict(float)

    # Note here we select discard_long=False because we want to sample from the entire dataset
    # That means we're truncating, rather than discarding, examples with too-long context or questions
    for batch in get_batch_generator(self.word2id, context_path, qn_path, ans_path,
                                     self.FLAGS.batch_size, context_len=self.FLAGS.context_len,
                                     question_len=self.FLAGS.question_len, discard_long=False):

        pred_start_pos, pred_end_pos = self.get_start_end_pos(session, batch)

        # Convert the start and end positions to lists length batch_size
        pred_start_pos = pred_start_pos.tolist()  # list length batch_size
        pred_end_pos = pred_end_pos.tolist()  # list length batch_size

        for ex_idx, (pred_ans_start, pred_ans_end, true_ans_tokens) in enumerate(
                zip(pred_start_pos, pred_end_pos, batch.ans_tokens)):
            example_num += 1

            # Get the predicted answer
            # Important: batch.context_tokens contains the original words (no UNKs)
            # You need to use the original no-UNK version when measuring F1/EM
            pred_ans_tokens = batch.context_tokens[ex_idx][pred_ans_start:pred_ans_end + 1]
            pred_answer = " ".join(pred_ans_tokens)

            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)

            # Calc F1/EM
            f1 = f1_score(pred_answer, true_answer)
            em = exact_match_score(pred_answer, true_answer)

            first_token_qn = batch.qn_tokens[ex_idx][0]
            first_token_qn_dict_total[first_token_qn] += 1
            if not em:
                # Exact-match miss for this question's first token.
                first_token_qn_dict_wrong[first_token_qn] += 1
            first_token_qn_dict_f1[first_token_qn] += f1

            f1_total += f1
            em_total += em

            # Optionally pretty-print
            if print_to_screen:
                print_example(self.word2id, batch.context_tokens[ex_idx],
                              batch.qn_tokens[ex_idx], batch.ans_span[ex_idx, 0],
                              batch.ans_span[ex_idx, 1], pred_ans_start, pred_ans_end,
                              true_answer, pred_answer, f1, em)

            if num_samples != 0 and example_num >= num_samples:
                break

        if num_samples != 0 and example_num >= num_samples:
            break

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    if example_num > 0:
        f1_total /= example_num
        em_total /= example_num

    # Computed once; the report loop below divides by it for every token.
    total_qns = sum(first_token_qn_dict_total.values())
    print('total words: ', total_qns)

    toc = time.time()
    logging.info(
        "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
        (example_num, dataset, toc - tic))

    # Report tokens in ascending order of frequency (ties broken by token).
    for token, count in sorted(first_token_qn_dict_total.items(),
                               key=lambda item: (item[1], item[0])):
        wrong = first_token_qn_dict_wrong[token]
        token_f1 = first_token_qn_dict_f1[token] / count
        print("When first token is: [", token, "] f1:", token_f1,
              "We got : ", wrong, " wrong exact match out of ", count,
              " percentage of 1st tokens that are this token: ", count / total_qns,
              " precentage of this token WRONG: ", wrong / count)

    print('em_total:', em_total)
    print('f1_total:', f1_total)
    return f1_total, em_total
def check_f1_em(self, session, context_path, qn_path, ans_path, dataset,
                num_samples=10, print_to_screen=False, write_out=False,
                file_out=None, shuffle=True):
    """
    Sample from the provided (train/dev) set and score the predicted answers.

    For each sample, compute F1, exact match (EM), edit distance and goal
    match via compute_all_metrics, and return their averages. Optionally
    pretty-print examples and write predictions plus per-example scores to
    disk.

    Inputs:
      session: TensorFlow session
      context_path, qn_path, ans_path: paths to {dev/train}.{context/question/answer} data files.
      dataset: string. Either "train" or "dev". Just for logging purposes.
      num_samples: int. How many samples to use. If num_samples=0 then do whole dataset.
      print_to_screen: if True, pretty-prints each example to screen
      write_out: if True, write one prediction per line to file_out. Also write
        the per-example EM / edit-distance / goal-match scores to sibling files
        in the same directory, named with an "em_" / "ed_" / "gm_" prefix.
      file_out: output path. Required when write_out is True.
      shuffle: forwarded to get_batch_generator.

    Returns:
      (F1, EM, edit distance, goal match): averages across the sampled
      examples (0.0 when no examples were processed).

    Raises:
      ValueError: if write_out is True but file_out is None.
    """
    import os  # local import keeps this block self-contained

    if write_out and file_out is None:
        raise ValueError("file_out must be provided when write_out=True")

    logging.info(
        "Calculating F1/EM for %s examples in %s set..." %
        (str(num_samples) if num_samples != 0 else "all", dataset))

    f1_total = 0.
    em_total = 0.
    ed_total = 0.
    rough_em_total = 0.
    example_num = 0

    tic = time.time()

    ans_list = []
    graph_route_info = []
    # Per-example scores for the files written below. They are accumulated
    # across ALL batches (not reset per batch) so they line up with ans_list.
    em_scores, ed_scores, gm_scores = [], [], []

    # Note here we select discard_long=False because we want to sample from the entire dataset
    # That means we're truncating, rather than discarding, examples with too-long context or questions
    for batch in get_batch_generator(
            self.word2id, self.context2id, self.ans2id, context_path,
            qn_path, ans_path, self.FLAGS.batch_size, self.graph_vocab_class,
            context_len=self.FLAGS.context_len,
            question_len=self.FLAGS.question_len,
            answer_len=self.FLAGS.answer_len, discard_long=False,
            use_raw_graph=self.FLAGS.use_raw_graph, shuffle=shuffle,
            show_start_tokens=self.FLAGS.show_start_tokens,
            output_goal=True):

        _, pred_ids, _, pred_logits = self.get_prob_dists(session, batch)
        graph_length = np.sum(batch.context_mask, axis=1)

        # Only verify_route produces confidence scores. Reset per batch so that
        # beam search never reads an undefined or stale value below.
        confidence_score = None
        if self.FLAGS.pred_method != 'beam':
            start_ids = batch.ans_ids[:, 0].reshape(-1)
            pred_ids, confidence_score, _ = verify_route(
                start_ids, pred_logits, batch.context_tokens, self.ans2id,
                self.id2ans, self.FLAGS.answer_len)

        pred_ids = pred_ids.tolist()  # the output of using test network
        for ex_idx, (pred_ans_list, true_ans_tokens) in enumerate(
                zip(pred_ids, list(batch.ans_tokens))):
            example_num += 1

            # Decode predicted ids up to (not including) the first PAD_ID.
            pred_ans_tokens = []
            for tok_id in pred_ans_list:
                if tok_id == PAD_ID:
                    break
                pred_ans_tokens.append(self.id2ans[tok_id])
            pred_answer = " ".join(pred_ans_tokens)

            # Get true answer (no UNKs)
            true_answer = " ".join(true_ans_tokens)

            # Calculate metrics
            f1, em, edit_dist, goal_match = compute_all_metrics(
                pred_ans_tokens, true_ans_tokens)

            em_scores.append(em)
            ed_scores.append(edit_dist)
            gm_scores.append(goal_match)
            f1_total += f1
            em_total += em
            ed_total += edit_dist
            rough_em_total += goal_match

            ans_list.append(pred_answer)
            graph_route_info.append(
                (str(int(graph_length[ex_idx])), str(len(true_ans_tokens[1:-1])),
                 str(int(em))))

            # Optionally pretty-print
            if print_to_screen:
                # NOTE(review): None is passed when no confidence is available
                # (beam search); confirm print_example accepts it.
                ex_confidence = (confidence_score[ex_idx]
                                 if confidence_score is not None else None)
                print_example(self.word2id, self.context2id, self.ans2id,
                              batch.context_tokens[ex_idx],
                              batch.qn_tokens[ex_idx], true_answer,
                              pred_answer, f1, em, edit_dist, ex_confidence)

            if num_samples != 0 and example_num >= num_samples:
                break

        if num_samples != 0 and example_num >= num_samples:
            break

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    if example_num > 0:
        f1_total /= example_num
        em_total /= example_num
        ed_total /= example_num
        rough_em_total /= example_num

    toc = time.time()
    logging.info(
        "Calculating F1/EM for %i examples in %s set took %.2f seconds" %
        (example_num, dataset, toc - tic))

    if write_out:
        logging.info("Writing the prediction to {}".format(file_out))
        with open(file_out, 'w') as f:
            for line, extra_info in zip(ans_list, graph_route_info):
                f.write(line + " " + " ".join(extra_info) + '\n')
        print("Wrote predictions to %s" % file_out)

        # Prefix the file NAME, not the whole path, so that e.g.
        # "out/pred.txt" becomes "out/em_pred.txt" rather than "em_out/pred.txt".
        out_dir, out_name = os.path.split(str(file_out))
        for prefix, label, scores in (("em_", "EM", em_scores),
                                      ("ed_", "ED", ed_scores),
                                      ("gm_", "GM", gm_scores)):
            score_file = os.path.join(out_dir, prefix + out_name)
            logging.info("Writing {} scores to {}".format(label, score_file))
            with open(score_file, 'w') as f:
                for score in scores:
                    f.write(str(score) + '\n')
            print("Wrote {} Scores to {}".format(label, score_file))

    return f1_total, em_total, ed_total, rough_em_total
def main():
    """
    Train a QAModel eagerly with tf.GradientTape and report dev F1/EM.

    Loads the GloVe embeddings and builds train/dev file paths under
    FLAGS.data_dir. It then trains for FLAGS.num_epochs epochs (indefinitely
    when it is 0), calling evaluate() on the dev set every
    FLAGS.eval_every steps.
    """
    print("Your TensorFlow version: %s" % tf.__version__)

    # Define train_dir
    if not FLAGS.experiment_name and not FLAGS.train_dir and FLAGS.mode != "official_eval":
        raise Exception(
            "You need to specify either --experiment_name or --train_dir")
    FLAGS.train_dir = FLAGS.train_dir or os.path.join(EXPERIMENTS_DIR, FLAGS.experiment_name)

    # Define path for glove vecs
    FLAGS.glove_path = FLAGS.glove_path or os.path.join(
        DEFAULT_DATA_DIR + "/glove.6B/",
        "glove.6B.{}d.txt".format(FLAGS.embedding_size))

    # Load embedding matrix and vocab mappings
    emb_matrix, word2id, id2word = get_glove(FLAGS.glove_path, FLAGS.embedding_size)

    # Get filepaths to train/dev datafiles for tokenized queries, contexts and answers
    train_context_path = os.path.join(FLAGS.data_dir, "train.context")
    train_qn_path = os.path.join(FLAGS.data_dir, "train.question")
    train_ans_path = os.path.join(FLAGS.data_dir, "train.span")
    dev_context_path = os.path.join(FLAGS.data_dir, "dev.context")
    dev_qn_path = os.path.join(FLAGS.data_dir, "dev.question")
    dev_ans_path = os.path.join(FLAGS.data_dir, "dev.span")

    global_step = 1
    epoch = 0
    print("Beginning training loop...")

    # Initialize model
    model = QAModel(FLAGS, id2word, word2id, emb_matrix)
    optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)

    while FLAGS.num_epochs == 0 or epoch < FLAGS.num_epochs:
        epoch += 1
        epoch_tic = time.time()
        for batch in get_batch_generator(
                word2id, train_context_path, train_qn_path,
                train_ans_path, FLAGS.batch_size, context_len=FLAGS.context_len,
                question_len=FLAGS.question_len, discard_long=True):
            with tf.GradientTape() as tape:
                # NOTE(review): these outputs are named prob_*, but they are fed
                # to a *_with_logits loss, which applies its own softmax.
                # Confirm the model returns unnormalized logits.
                prob_start, prob_end = model([
                    batch.context_ids, batch.context_mask, batch.qn_ids,
                    batch.qn_mask
                ])
                loss_start = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=prob_start, labels=batch.ans_span[:, 0]))
                loss_end = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        logits=prob_end, labels=batch.ans_span[:, 1]))
                loss = loss_start + loss_end

            # Differentiate w.r.t. the trainable variables only. Non-trainable
            # variables never receive gradients (tape.gradient returns None for
            # them), which apply_gradients would otherwise have to filter out.
            variables = model.trainable_variables
            grads = tape.gradient(loss, variables)
            optimizer.apply_gradients(grads_and_vars=zip(grads, variables))

            if global_step % FLAGS.eval_every == 0:
                print("==== start evaluating ==== ")
                dev_f1, dev_em = evaluate(model, word2id, FLAGS, dev_context_path,
                                          dev_qn_path, dev_ans_path)
                print("Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f" %
                      (epoch, global_step, dev_f1, dev_em))
                print("==========================")
            global_step += 1

        epoch_toc = time.time()
        print("End of epoch %i. Time for epoch: %f" % (epoch, epoch_toc - epoch_tic))

    sys.stdout.flush()
def train(self, session, train_context_path, train_qn_path, train_ans_path,
          dev_qn_path, dev_context_path, dev_ans_path):
    """
    Main training loop.

    Trains for self.FLAGS.num_epochs epochs (indefinitely when it is 0).
    Every print_every / save_every / eval_every global steps it logs
    progress, saves the latest checkpoint, and evaluates dev loss plus
    train/dev F1/EM. Whenever dev EM improves, the model is also saved to the
    best-checkpoint path (early stopping on dev EM).

    Inputs:
      session: TensorFlow session
      {train/dev}_{qn/context/ans}_path: paths to {train/dev}.{context/question/answer} data files
        (note the dev paths are passed in the order qn, context, ans)
    """
    # Print number of model parameters
    tic = time.time()
    params = tf.trainable_variables()
    num_params = sum(
        map(lambda t: np.prod(tf.shape(t.value()).eval()), params))
    toc = time.time()
    logging.info("Number of params: %d (retrieval took %f secs)" %
                 (num_params, toc - tic))

    # Exponentially-smoothed loss; None until the first iteration.
    exp_loss = None

    # Checkpoint management.
    # We keep one latest checkpoint, and one best checkpoint (early stopping)
    checkpoint_path = os.path.join(self.FLAGS.train_dir, "qa.ckpt")
    bestmodel_dir = os.path.join(self.FLAGS.train_dir, "best_checkpoint")
    bestmodel_ckpt_path = os.path.join(bestmodel_dir, "qa_best.ckpt")
    best_dev_em = None

    # for TensorBoard
    summary_writer = tf.summary.FileWriter(self.FLAGS.train_dir, session.graph)

    epoch = 0

    logging.info("Beginning training loop...")
    try:
        while self.FLAGS.num_epochs == 0 or epoch < self.FLAGS.num_epochs:
            epoch += 1
            epoch_tic = time.time()

            # Loop over batches
            for batch in get_batch_generator(
                    self.word2id, train_context_path, train_qn_path, train_ans_path,
                    self.FLAGS.batch_size, context_len=self.FLAGS.context_len,
                    question_len=self.FLAGS.question_len, discard_long=True):

                # Run training iteration
                iter_tic = time.time()
                loss, global_step, param_norm, grad_norm = self.run_train_iter(
                    session, batch, summary_writer)
                iter_toc = time.time()
                iter_time = iter_toc - iter_tic

                # Update exponentially-smoothed loss. Test for None explicitly:
                # a loss of exactly 0.0 must not restart the average.
                if exp_loss is None:  # first iter
                    exp_loss = loss
                else:
                    exp_loss = 0.99 * exp_loss + 0.01 * loss

                # Sometimes print info to screen
                if global_step % self.FLAGS.print_every == 0:
                    logging.info(
                        'epoch %d, iter %d, loss %.5f, smoothed loss %.5f, grad norm %.5f, param norm %.5f, batch time %.3f' %
                        (epoch, global_step, loss, exp_loss, grad_norm, param_norm, iter_time))

                # Sometimes save model
                if global_step % self.FLAGS.save_every == 0:
                    logging.info("Saving to %s..." % checkpoint_path)
                    self.saver.save(session, checkpoint_path, global_step=global_step)

                # Sometimes evaluate model on dev loss, train F1/EM and dev F1/EM
                if global_step % self.FLAGS.eval_every == 0:

                    # Get loss for entire dev set and log to tensorboard
                    dev_loss = self.get_dev_loss(session, dev_context_path,
                                                 dev_qn_path, dev_ans_path)
                    logging.info("Epoch %d, Iter %d, dev loss: %f" %
                                 (epoch, global_step, dev_loss))
                    write_summary(dev_loss, "dev/loss", summary_writer, global_step)

                    # Get F1/EM on train set and log to tensorboard
                    train_f1, train_em = self.check_f1_em(session, train_context_path,
                                                          train_qn_path, train_ans_path,
                                                          "train", num_samples=1000)
                    logging.info(
                        "Epoch %d, Iter %d, Train F1 score: %f, Train EM score: %f" %
                        (epoch, global_step, train_f1, train_em))
                    write_summary(train_f1, "train/F1", summary_writer, global_step)
                    write_summary(train_em, "train/EM", summary_writer, global_step)

                    # Get F1/EM on dev set and log to tensorboard
                    dev_f1, dev_em = self.check_f1_em(session, dev_context_path,
                                                      dev_qn_path, dev_ans_path,
                                                      "dev", num_samples=0)
                    logging.info(
                        "Epoch %d, Iter %d, Dev F1 score: %f, Dev EM score: %f" %
                        (epoch, global_step, dev_f1, dev_em))
                    write_summary(dev_f1, "dev/F1", summary_writer, global_step)
                    write_summary(dev_em, "dev/EM", summary_writer, global_step)

                    # Early stopping based on dev EM. You could switch this to use F1 instead.
                    if best_dev_em is None or dev_em > best_dev_em:
                        best_dev_em = dev_em
                        logging.info("Saving to %s..." % bestmodel_ckpt_path)
                        self.bestmodel_saver.save(session, bestmodel_ckpt_path,
                                                  global_step=global_step)

            epoch_toc = time.time()
            logging.info("End of epoch %i. Time for epoch: %f" %
                         (epoch, epoch_toc - epoch_tic))
    finally:
        # Flush buffered TensorBoard events and release the event file, even
        # when training is interrupted.
        summary_writer.close()

    sys.stdout.flush()
if __name__ == "__main__": logits_start, probdist_start, logits_end, probdist_end = build_graph() # run the program with tf.Session() as sess: # It is necessary to initialize variables once before running inference. sess.run(tf.global_variables_initializer()) #sess = tf_debug.LocalCLIDebugWrapperSession(sess) for batch in get_batch_generator(word2id, char2id, train_context_path, train_qn_path, train_ans_path, batch_size, context_len, question_len, max_word_len, discard_long=True): # Create batches of data. input_feed = {} input_feed[context_elmo] = batcher.batch_sentences( batch.context_tokens) input_feed[question_elmo] = batcher.batch_sentences( batch.qn_tokens) input_feed[context_ids] = batch.context_ids input_feed[context_mask] = batch.context_mask input_feed[qn_ids] = batch.qn_ids input_feed[qn_mask] = batch.qn_mask