def main(_):
    """Evaluate a saved MPCM QA model on the SQuAD evaluation split.

    Loads triples with ``loader.load_squad_triples("./data/", test=True,
    ans_list=True)``, restores the model from
    ``FLAGS.model_dir + 'saved/qamaybe'`` and prints the mean exact-match (EM)
    and F1. For each example the best score over all reference answers is
    kept.

    Args:
        _: Unused positional argument (presumably the argv list passed by
            ``tf.app.run`` -- TODO confirm).
    """
    import helpers.metrics as metrics
    from tqdm import tqdm

    dev_data = loader.load_squad_triples("./data/", test=True, ans_list=True)

    qa = MpcmQaInstance()
    qa.load_from_chkpt(FLAGS.model_dir + 'saved/qamaybe')

    f1s = []
    ems = []
    for example in tqdm(dev_data):
        context, question, gold_answers = example[0], example[1], example[2]
        ans_pred = qa.get_ans([context], [question])[0]

        # Normalise the prediction once rather than once per reference answer.
        pred_norm = metrics.normalize_answer(ans_pred)
        gold_norms = [metrics.normalize_answer(a) for a in gold_answers]

        # Keep the best match over all reference answers.
        ems.append(max(1.0 * (pred_norm == g) for g in gold_norms))
        f1s.append(max(metrics.f1(pred_norm, g) for g in gold_norms))

    print("EM: ", np.mean(ems), " F1: ", np.mean(f1s))
def _sowe_vector(text, glove_embeddings, dim):
    """Return the sum of the GloVe vectors of the tokens in ``text``.

    Tokens absent from ``glove_embeddings`` contribute a zero vector of length
    ``dim``. Text with no tokens yields a zero vector of length ``dim``; a bare
    ``np.sum`` over an empty array would collapse to a 0-d scalar instead.
    """
    vectors = [
        glove_embeddings[w] if w in glove_embeddings else np.zeros((dim,))
        for w in preprocessing.tokenise(text, asbytes=False)
    ]
    if not vectors:
        return np.zeros((dim,))
    return np.sum(np.asarray(vectors), axis=0)


def _sowe_similarity(pred_str, gold_str, glove_embeddings, dim):
    """Return the cosine similarity of the summed word embeddings of two strings.

    The cosine is undefined when either sum has zero norm (empty text, or
    every token out of vocabulary). 0.0 is returned in that case so that a
    single such question cannot turn the corpus-level mean into NaN. The
    result is a plain ``float`` so it can be written with ``json.dump``.
    """
    pred_sowe = _sowe_vector(pred_str, glove_embeddings, dim)
    gold_sowe = _sowe_vector(gold_str, glove_embeddings, dim)
    pred_norm = np.linalg.norm(pred_sowe, ord=2)
    gold_norm = np.linalg.norm(gold_sowe, ord=2)
    if pred_norm == 0 or gold_norm == 0:
        return 0.0
    return float(np.inner(pred_sowe, gold_sowe) / pred_norm / gold_norm)


def main(_):
    """Evaluate a trained question-generation model on a SQuAD split.

    Restores the checkpoint in
    ``FLAGS.model_dir + 'qgen/<model_type>/<eval_model_id>'`` (model types
    starting with "SEQ2SEQ" or "RL"). It then decodes
    ``FLAGS.num_eval_samples // FLAGS.eval_batch_size`` batches and computes,
    per generated question, F1, BLEU, NLL and the cosine similarity of summed
    GloVe embeddings (SOWE).

    With ``FLAGS.eval_metrics`` it also scores the generated questions with a
    QANet reader (answer F1), an LSTM LM (perplexity) and a discriminator.

    Writes an HTML preview of the first batch and a JSON file with
    corpus-level and per-example metrics to ``FLAGS.log_directory``, then
    prints the corpus-level numbers.

    Args:
        _: Unused (presumably the argv list passed by ``tf.app.run`` -- TODO
            confirm).
    """
    model_type = FLAGS.model_type

    disc_path = FLAGS.model_dir + 'saved/discriminator-trained-latent'
    chkpt_path = FLAGS.model_dir + 'qgen/' + model_type + '/' + FLAGS.eval_model_id

    # Fail fast with a clear message, before reading files from the checkpoint dir.
    if not os.path.exists(chkpt_path):
        exit('Checkpoint path doesnt exist! ' + chkpt_path)

    # load dataset
    dev_data = loader.load_squad_triples(FLAGS.data_path, dev=FLAGS.eval_on_dev, test=FLAGS.eval_on_test)

    if len(dev_data) < FLAGS.num_eval_samples:
        exit('***ERROR*** Eval dataset is smaller than the num_eval_samples flag!')
    if len(dev_data) > FLAGS.num_eval_samples:
        print('***WARNING*** Eval dataset is larger than the num_eval_samples flag!')

    # Keep the uncropped contexts/answer positions; batches index back into
    # them through dev_batch[3].
    dev_contexts_unfilt, _, _, dev_a_pos_unfilt = zip(*dev_data)

    if FLAGS.filter_window_size_before > -1:
        dev_data = preprocessing.filter_squad(
            dev_data,
            window_size_before=FLAGS.filter_window_size_before,
            window_size_after=FLAGS.filter_window_size_after,
            max_tokens=FLAGS.filter_max_tokens)

    print('Loaded SQuAD dev set with ', len(dev_data), ' triples')

    with open(chkpt_path + '/vocab.json', encoding='utf-8') as f:
        vocab = json.load(f)

    with SquadStreamer(vocab, FLAGS.eval_batch_size, 1, shuffle=False) as dev_data_source:

        glove_embeddings = loader.load_glove(FLAGS.data_path)

        # Create model
        if model_type[:7] == "SEQ2SEQ":
            model = Seq2SeqModel(vocab, training_mode=False)
        elif model_type[:2] == "RL":
            # TEMP - no need to spin up the LM or QA model at eval time
            FLAGS.qa_weight = 0
            FLAGS.lm_weight = 0
            model = RLModel(vocab, training_mode=False)
        else:
            exit("Unrecognised model type: " + model_type)

        with model.graph.as_default():
            saver = tf.train.Saver()

        if FLAGS.eval_metrics:
            lm = LstmLmInstance()
            qa = QANetInstance()
            lm.load_from_chkpt(FLAGS.model_dir + 'saved/lmtest')
            qa.load_from_chkpt(FLAGS.model_dir + 'saved/qanet2')
            discriminator = DiscriminatorInstance(trainable=False, path=disc_path)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_limit)
        with tf.Session(graph=model.graph, config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            saver.restore(sess, tf.train.latest_checkpoint(chkpt_path))

            num_steps = FLAGS.num_eval_samples // FLAGS.eval_batch_size

            # Initialise the dataset
            dev_data_source.initialise(dev_data)

            f1s = []
            bleus = []
            qa_scores = []
            qa_scores_gold = []
            lm_scores = []
            nlls = []
            disc_scores = []
            sowe_similarities = []
            copy_probs = []
            qgolds = []
            qpreds = []
            res = []

            for i in tqdm(range(num_steps), desc='Epoch 0'):
                dev_batch, curr_batch_size = dev_data_source.get_batch()
                (pred_batch, _pred_beam, _pred_beam_lens, pred_ids, pred_lens,
                 gold_batch, gold_lens, _gold_ids, ctxt, ctxt_len, ans, ans_len,
                 nll, copy_prob) = sess.run(
                     [model.q_hat_beam_string, model.q_hat_full_beam_str,
                      model.q_hat_full_beam_lens, model.q_hat_beam_ids,
                      model.q_hat_beam_lens, model.question_raw,
                      model.question_length, model.question_ids,
                      model.context_raw, model.context_length,
                      model.answer_locs, model.answer_length, model.nll,
                      model.mean_copy_prob],
                     feed_dict={model.input_batch: dev_batch, model.is_training: False})

                unfilt_ctxt_batch = [dev_contexts_unfilt[ix] for ix in dev_batch[3]]
                unfilt_apos_batch = [dev_a_pos_unfilt[ix] for ix in dev_batch[3]]
                # Gold answer strings; the same list is used as the answer text
                # in the JSON output and as the reference for QA scoring.
                a_text_batch = ops.byte_token_array_to_str(dev_batch[2][0], dev_batch[2][2], is_array=False)
                gold_ans = a_text_batch

                # subtract 1 to remove the "end sent token"
                pred_q_batch = [
                    q.replace(' </Sent>', "").replace(" <PAD>", "")
                    for q in ops.byte_token_array_to_str(pred_batch, pred_lens - 1)
                ]

                copy_probs.extend(copy_prob.tolist())

                # Per-example NLLs for this batch; indexed by b below.
                nll_batch = nll.tolist()
                nlls.extend(nll_batch)

                if FLAGS.eval_metrics:
                    qa_pred = qa.get_ans(unfilt_ctxt_batch, ops.byte_token_array_to_str(pred_batch, pred_lens))
                    gold_qa_pred = qa.get_ans(unfilt_ctxt_batch, ops.byte_token_array_to_str(dev_batch[1][0], dev_batch[1][3]))

                    qa_score_batch = [
                        metrics.f1(metrics.normalize_answer(gold_ans[b]), metrics.normalize_answer(qa_pred[b]))
                        for b in range(curr_batch_size)
                    ]
                    qa_score_gold_batch = [
                        metrics.f1(metrics.normalize_answer(gold_ans[b]), metrics.normalize_answer(gold_qa_pred[b]))
                        for b in range(curr_batch_size)
                    ]
                    lm_score_batch = lm.get_seq_perplexity(pred_q_batch).tolist()
                    disc_score_batch = discriminator.get_pred(
                        unfilt_ctxt_batch, pred_q_batch, gold_ans, unfilt_apos_batch).tolist()

                    # Accumulate once per batch, not once per example.
                    qa_scores.extend(qa_score_batch)
                    qa_scores_gold.extend(qa_score_gold_batch)
                    lm_scores.extend(lm_score_batch)
                    disc_scores.extend(disc_score_batch)

                pred_ids_list = pred_ids.tolist()
                for b in range(len(pred_batch)):
                    pred_str = pred_q_batch[b].replace(' </Sent>', "").replace(" <PAD>", "")
                    gold_str = tokens_to_string(gold_batch[b][:gold_lens[b] - 1])
                    f1s.append(metrics.f1(gold_str, pred_str))
                    bleus.append(metrics.bleu(gold_str, pred_str))
                    qgolds.append(gold_str)
                    qpreds.append(pred_str)

                    # calc cosine similarity between sums of word embeddings
                    sowe_similarities.append(
                        _sowe_similarity(pred_str, gold_str, glove_embeddings, FLAGS.embedding_size))

                    this_metric_dict = {
                        'f1': f1s[-1],
                        'bleu': bleus[-1],
                        'nll': nll_batch[b],
                        'sowe': sowe_similarities[-1]
                    }
                    if FLAGS.eval_metrics:
                        this_metric_dict = {
                            **this_metric_dict,
                            'qa': qa_score_batch[b],
                            'lm': lm_score_batch[b],
                            'disc': disc_score_batch[b]
                        }

                    res.append({
                        'c': unfilt_ctxt_batch[b],
                        'q_pred': pred_str,
                        'q_gold': gold_str,
                        'a_pos': unfilt_apos_batch[b],
                        'a_text': a_text_batch[b],
                        'metrics': this_metric_dict,
                        'q_pred_ids': pred_ids_list[b],
                        'q_gold_ids': dev_batch[1][1][b].tolist()
                    })

                # Quick output
                if i == 0:
                    print(qpreds[0])
                    print(tokens_to_string(gold_batch[0][:gold_lens[0] - 1]))
                    title = chkpt_path
                    out_str = output_eval(title, pred_batch, pred_ids, pred_lens, gold_batch, gold_lens,
                                          ctxt, ctxt_len, ans, ans_len)
                    with open(FLAGS.log_directory + 'out_eval_' + model_type + '.htm', 'w', encoding='utf-8') as fp:
                        fp.write(out_str)

            bleu_corpus = metrics.bleu_corpus(qgolds, qpreds)
            metric_dict = {
                'f1': np.mean(f1s),
                'bleu': bleu_corpus,
                'nll': np.mean(nlls),
                'sowe': np.mean(sowe_similarities)
            }
            if FLAGS.eval_metrics:
                metric_dict = {
                    **metric_dict,
                    'qa': np.mean(qa_scores),
                    'lm': np.mean(lm_scores),
                    'disc': np.mean(disc_scores)
                }

            out_json = (FLAGS.log_directory + 'out_eval_' + model_type
                        + ("_test" if FLAGS.eval_on_test else "")
                        + ("_train" if (not FLAGS.eval_on_dev and not FLAGS.eval_on_test) else "")
                        + '.json')
            with open(out_json, 'w', encoding='utf-8') as fp:
                json.dump({"metrics": metric_dict, "results": res}, fp)

            print("F1: ", np.mean(f1s))
            print("BLEU: ", bleu_corpus)
            print("NLL: ", np.mean(nlls))
            print("SOWE: ", np.mean(sowe_similarities))
            print("Copy prob: ", np.mean(copy_probs))
            if FLAGS.eval_metrics:
                print("QA: ", np.mean(qa_scores))
                print("QA (gold): ", np.mean(qa_scores_gold))
                print("LM: ", np.mean(lm_scores))
                print("Disc: ", np.mean(disc_scores))
def _qanet_filter_by_length(data, para_limit, ques_limit):
    """Keep only triples whose context and question fit the token limits.

    A triple is kept when ``word_tokenize`` yields strictly fewer than
    ``para_limit`` tokens for its first field (context) and strictly fewer
    than ``ques_limit`` tokens for its second field (question).
    """
    return [
        x for x in data
        if len(word_tokenize(x[0])) < para_limit and len(word_tokenize(x[1])) < ques_limit
    ]


def _qanet_evaluate(qa, squad_dev, batch_size, verbose=False):
    """Score ``qa`` on ``squad_dev``; return per-example ``(f1s, ems)`` lists.

    Field 2 of each dev triple is a *list* of reference answers (the dev set
    is loaded with ``ans_list=True``); the best F1/EM over that list is kept
    per example. Only full batches are evaluated, so up to ``batch_size - 1``
    trailing examples are skipped. With ``verbose`` the scores so far and the
    first example/prediction of the first batch are printed.
    """
    from tqdm import tqdm

    qa_f1s = []
    qa_em = []
    for i in tqdm(range(len(squad_dev) // batch_size)):
        this_batch = squad_dev[i * batch_size:(i + 1) * batch_size]
        spans = qa.get_ans([x[0] for x in this_batch], [x[1] for x in this_batch])
        for b, example in enumerate(this_batch):
            pred_norm = metrics.normalize_answer(spans[b])
            gold_norms = [metrics.normalize_answer(a) for a in example[2]]
            qa_f1s.append(max(metrics.f1(g, pred_norm) for g in gold_norms))
            qa_em.append(max(1.0 * (g == pred_norm) for g in gold_norms))
        if verbose and i == 0:
            print(qa_f1s, qa_em)
            print(this_batch[0])
            print(spans[0])
    return qa_f1s, qa_em


def _qanet_train(qa, squad_dev, para_limit, ques_limit):
    """Fine-tune ``qa`` on length-filtered SQuAD training triples.

    Runs ``FLAGS.qa_num_epochs`` epochs, reshuffling at the start of each one.
    Every 1000 steps it evaluates on ``squad_dev`` and checkpoints to
    ``FLAGS.model_dir + 'qanet/<run_id>'`` whenever dev F1 improves.
    """
    from tqdm import tqdm
    FLAGS = tf.app.flags.FLAGS

    squad_train = _qanet_filter_by_length(
        loader.load_squad_triples(path="./data/"), para_limit, ques_limit)
    num_train_steps = len(squad_train) // FLAGS.batch_size

    # One-hot rows for the start/end labels; built once rather than twice per example.
    eye = np.eye(para_limit)

    best_f1 = 0
    run_id = str(int(time.time()))
    chkpt_path = FLAGS.model_dir + 'qanet/' + run_id
    if not os.path.exists(chkpt_path):
        os.makedirs(chkpt_path)
    summary_writer = tf.summary.FileWriter(FLAGS.log_directory + 'qanet/' + run_id, qa.model.graph)

    for i in tqdm(range(FLAGS.qa_num_epochs * num_train_steps)):
        # Index within the current epoch. Slicing with the global step would
        # run past the end of squad_train after the first epoch.
        step_in_epoch = i % num_train_steps
        if step_in_epoch == 0:
            print('Shuffling training set')
            np.random.shuffle(squad_train)
        this_batch = squad_train[step_in_epoch * FLAGS.batch_size:(step_in_epoch + 1) * FLAGS.batch_size]
        batch_contexts, batch_questions, batch_ans_text, batch_ans_charpos = zip(*this_batch)

        batch_answers = []
        for ctxt, ans_text, ans_charpos in zip(batch_contexts, batch_ans_text, batch_ans_charpos):
            start = char_pos_to_word(ctxt.encode(), [t.encode() for t in word_tokenize(ctxt)], ans_charpos)
            end = start + len(word_tokenize(ans_text)) - 1
            batch_answers.append((eye[start], eye[end]))

        this_loss = qa.train_step(batch_contexts, batch_questions, batch_answers)

        if i % 50 == 0:
            losssummary = tf.Summary(value=[
                tf.Summary.Value(tag="train_loss/loss", simple_value=np.mean(this_loss))
            ])
            summary_writer.add_summary(losssummary, global_step=i)

        if i > 0 and i % 1000 == 0:
            qa_f1s, _ = _qanet_evaluate(qa, squad_dev, FLAGS.batch_size)
            mean_f1 = np.mean(qa_f1s)
            f1summary = tf.Summary(value=[
                tf.Summary.Value(tag="dev_perf/f1", simple_value=mean_f1)
            ])
            summary_writer.add_summary(f1summary, global_step=i)

            if mean_f1 > best_f1:
                print("New best F1! ", mean_f1, " Saving...")
                best_f1 = mean_f1
                qa.saver.save(qa.sess, chkpt_path + '/model.checkpoint')


def main(_):
    """Evaluate (and optionally fine-tune) the saved QANet reader on SQuAD dev.

    Restores QANet from ``./models/saved/qanet2/``. Dev triples whose context
    or question exceed ``FLAGS.test_para_limit`` / ``FLAGS.test_ques_limit``
    tokens are dropped. When ``trainable`` is True the model is first
    fine-tuned on the filtered training set. Finally prints the dev EM and F1,
    taking the best score over all reference answers per question.

    Args:
        _: Unused (presumably the argv list passed by ``tf.app.run`` -- TODO
            confirm).
    """
    FLAGS = tf.app.flags.FLAGS

    # Set to True to fine-tune the restored model before the final evaluation.
    trainable = False

    para_limit = FLAGS.test_para_limit
    ques_limit = FLAGS.test_ques_limit

    squad_dev_full = loader.load_squad_triples(path="./data/", dev=True, ans_list=True)

    qa = QANetInstance()
    qa.load_from_chkpt("./models/saved/qanet2/", trainable=trainable)

    squad_dev = _qanet_filter_by_length(squad_dev_full, para_limit, ques_limit)

    if trainable:
        _qanet_train(qa, squad_dev, para_limit, ques_limit)

    qa_f1s, qa_em = _qanet_evaluate(qa, squad_dev, FLAGS.batch_size, verbose=True)
    print('EM: ', np.mean(qa_em))
    print('F1: ', np.mean(qa_f1s))
def main(_):
    """Train a question-generation model on SQuAD.

    Model choice and checkpointing:
    - Builds a ``Seq2SeqModel`` when ``FLAGS.model_type`` starts with
      "SEQ2SEQ", or a ``MaluubaModel`` when it starts with "MALUUBA".
    - Optionally restores weights and ``vocab.json`` from
      ``FLAGS.model_dir + 'qgen/' + FLAGS.restore_path``.
    - Trains for ``FLAGS.num_epochs`` passes over the (optionally
      window-filtered) training set.

    Policy gradient: when the model type starts with "MALUUBA_RL" and
    ``FLAGS.policy_gradient`` is set, each step does two things:
    - Decodes questions and scores them with the LM, QA and discriminator
      rewards.
    - After ``FLAGS.pg_burnin`` steps, runs a combined policy-gradient /
      maximum-likelihood update.

    Every ``FLAGS.eval_freq`` steps it:
    - logs train/dev F1, BLEU and NLL to TensorBoard;
    - writes HTML previews to ``FLAGS.log_dir``;
    - saves a checkpoint under ``FLAGS.model_dir + 'qgen/<model_type>/<run_id>'``
      when dev NLL improves (or regardless, when policy gradient is enabled).

    Args:
        _: Unused (presumably the argv list passed by ``tf.app.run`` -- TODO
            confirm).
    """
    # Shrink the model so the whole train/eval loop can be smoke-tested quickly.
    if FLAGS.testing:
        print('TEST MODE - reducing model size')
        FLAGS.context_encoder_units = 100
        FLAGS.answer_encoder_units = 100
        FLAGS.decoder_units = 100
        FLAGS.batch_size = 8
        FLAGS.eval_batch_size = 8
        # FLAGS.embedding_size=50

    # Each run gets its own checkpoint directory keyed by the start timestamp.
    run_id = str(int(time.time()))
    chkpt_path = FLAGS.model_dir + 'qgen/' + FLAGS.model_type + '/' + run_id
    # The conditional expression binds loosest, so this is
    # (model_dir + 'qgen/' + restore_path) if restore_path is set, else None.
    restore_path = FLAGS.model_dir + 'qgen/' + FLAGS.restore_path if FLAGS.restore_path is not None else None  #'MALUUBA-CROP-LATENT'+'/'+'1534123959'
    # restore_path=FLAGS.model_dir+'saved/qgen-maluuba-crop-glove-smart'
    disc_path = FLAGS.model_dir + 'saved/discriminator-trained-latent'

    print("Run ID is ", run_id)
    print("Model type is ", FLAGS.model_type)

    if not os.path.exists(chkpt_path):
        os.makedirs(chkpt_path)

    # load dataset
    train_data = loader.load_squad_triples(FLAGS.data_path, False)
    dev_data = loader.load_squad_triples(FLAGS.data_path, True)

    # Unfiltered copies: batches carry an index (batch[3]) back into these so
    # the QA reward and discriminator can see the uncropped context.
    train_contexts_unfilt, _, ans_text_unfilt, ans_pos_unfilt = zip(
        *train_data)
    dev_contexts_unfilt, _, dev_ans_text_unfilt, dev_ans_pos_unfilt = zip(
        *dev_data)

    if FLAGS.testing:
        train_data = train_data[:1000]
        num_dev_samples = 100
    else:
        num_dev_samples = FLAGS.num_dev_samples

    if FLAGS.filter_window_size_before > -1:
        train_data = preprocessing.filter_squad(
            train_data,
            window_size_before=FLAGS.filter_window_size_before,
            window_size_after=FLAGS.filter_window_size_after,
            max_tokens=FLAGS.filter_max_tokens)
        dev_data = preprocessing.filter_squad(
            dev_data,
            window_size_before=FLAGS.filter_window_size_before,
            window_size_after=FLAGS.filter_window_size_after,
            max_tokens=FLAGS.filter_max_tokens)

    print('Loaded SQuAD with ', len(train_data), ' triples')
    train_contexts, train_qs, train_as, train_a_pos = zip(*train_data)

    # Vocab: reuse the restored run's vocab, or build one (GloVe-based or from
    # the training contexts + questions) and save it next to the checkpoints.
    if FLAGS.restore:
        if restore_path is None:
            exit('You need to specify a restore path!')
        with open(restore_path + '/vocab.json', encoding="utf-8") as f:
            vocab = json.load(f)
    elif FLAGS.glove_vocab:
        vocab = loader.get_glove_vocab(FLAGS.data_path,
                                       size=FLAGS.vocab_size,
                                       d=FLAGS.embedding_size)
        with open(chkpt_path + '/vocab.json', 'w',
                  encoding="utf-8") as outfile:
            json.dump(vocab, outfile)
    else:
        vocab = loader.get_vocab(train_contexts + train_qs, FLAGS.vocab_size)
        with open(chkpt_path + '/vocab.json', 'w',
                  encoding="utf-8") as outfile:
            json.dump(vocab, outfile)

    # Create model
    if FLAGS.model_type[:7] == "SEQ2SEQ":
        model = Seq2SeqModel(vocab,
                             training_mode=True,
                             use_embedding_loss=FLAGS.embedding_loss)
    elif FLAGS.model_type[:7] == "MALUUBA":
        # TEMP
        # Without policy gradient the LM/QA rewards are unused, so zero their weights.
        if not FLAGS.policy_gradient:
            FLAGS.qa_weight = 0
            FLAGS.lm_weight = 0
        model = MaluubaModel(vocab,
                             training_mode=True,
                             use_embedding_loss=FLAGS.embedding_loss)
        # if FLAGS.model_type[:10] == "MALUUBA_RL":
        #     qa_vocab=model.qa.vocab
        #     lm_vocab=model.lm.vocab
        if FLAGS.policy_gradient:
            discriminator = DiscriminatorInstance(trainable=FLAGS.disc_train,
                                                  path=disc_path)
    else:
        exit("Unrecognised model type: " + FLAGS.model_type)

    # create data streamer
    with SquadStreamer(vocab, FLAGS.batch_size, FLAGS.num_epochs,
                       shuffle=True) as train_data_source, SquadStreamer(
                           vocab, FLAGS.eval_batch_size, 1,
                           shuffle=True) as dev_data_source:

        with model.graph.as_default():
            saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)

        # change visible devices if using RL models
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=mem_limit,
            visible_device_list='0',
            allow_growth=True)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                              allow_soft_placement=False),
                        graph=model.graph) as sess:

            summary_writer = tf.summary.FileWriter(
                FLAGS.log_dir + 'qgen/' + FLAGS.model_type + '/' + run_id,
                sess.graph)

            train_data_source.initialise(train_data)

            num_steps_train = len(train_data) // FLAGS.batch_size
            num_steps_dev = num_dev_samples // FLAGS.eval_batch_size

            if FLAGS.restore:
                saver.restore(sess, tf.train.latest_checkpoint(restore_path))
                # NOTE(review): start_e is not read by the step-based loop below.
                start_e = 15  #FLAGS.num_epochs
                print('Loaded model')
            else:
                start_e = 0
                sess.run(tf.global_variables_initializer())
                # sess.run(model.glove_init_ops)

                # Seed the dev curves with a zero point at step 0.
                f1summary = tf.Summary(value=[
                    tf.Summary.Value(tag="dev_perf/f1", simple_value=0.0)
                ])
                bleusummary = tf.Summary(value=[
                    tf.Summary.Value(tag="dev_perf/bleu", simple_value=0.0)
                ])

                summary_writer.add_summary(f1summary, global_step=0)
                summary_writer.add_summary(bleusummary, global_step=0)

            # Initialise the dataset
            # sess.run(model.iterator.initializer, feed_dict={model.context_ph: train_contexts,
            #                   model.qs_ph: train_qs, model.as_ph: train_as, model.a_pos_ph: train_a_pos})

            best_oos_nll = 1e6

            # Running mean/variance of each reward, used to whiten rewards for PG.
            lm_score_moments = online_moments.OnlineMoment()
            qa_score_moments = online_moments.OnlineMoment()
            disc_score_moments = online_moments.OnlineMoment()

            # for e in range(start_e,start_e+FLAGS.num_epochs):
            # Train for one epoch
            for i in tqdm(range(num_steps_train * FLAGS.num_epochs),
                          desc='Training'):
                # Get a batch
                train_batch, curr_batch_size = train_data_source.get_batch()

                # Are we doing policy gradient? Do a forward pass first, then build the PG batch and do an update step
                if FLAGS.model_type[:
                                    10] == "MALUUBA_RL" and FLAGS.policy_gradient:

                    # do a fwd pass first, get the score, then do another pass and optimize
                    qhat_str, qhat_ids, qhat_lens = sess.run(
                        [
                            model.q_hat_beam_string, model.q_hat_beam_ids,
                            model.q_hat_beam_lens
                        ],
                        feed_dict={
                            model.input_batch: train_batch,
                            model.is_training: FLAGS.pg_dropout,
                            model.hide_answer_in_copy: True
                        })

                    # The output is as long as the max allowed len - remove the pointless extra padding
                    qhat_ids = qhat_ids[:, :np.max(qhat_lens)]
                    qhat_str = qhat_str[:, :np.max(qhat_lens)]

                    # qhat_lens - 1 drops the final (end-of-sentence) token.
                    pred_str = byte_token_array_to_str(qhat_str,
                                                       qhat_lens - 1)
                    gold_q_str = byte_token_array_to_str(
                        train_batch[1][0], train_batch[1][3])

                    # Get reward values
                    lm_score = (-1 * model.lm.get_seq_perplexity(pred_str)
                                ).tolist()  # lower perplexity is better

                    # retrieve the uncropped context for QA evaluation
                    unfilt_ctxt_batch = [
                        train_contexts_unfilt[ix] for ix in train_batch[3]
                    ]
                    ans_text_batch = [
                        ans_text_unfilt[ix] for ix in train_batch[3]
                    ]
                    ans_pos_batch = [
                        ans_pos_unfilt[ix] for ix in train_batch[3]
                    ]

                    qa_pred = model.qa.get_ans(unfilt_ctxt_batch, pred_str)
                    # NOTE(review): qa_pred_gold is computed but not used below.
                    qa_pred_gold = model.qa.get_ans(unfilt_ctxt_batch,
                                                    gold_q_str)

                    # gold_str=[]
                    # pred_str=[]
                    # QA reward: F1 between the gold answer and the answer the
                    # QA model extracts when asked the generated question.
                    qa_f1s = []
                    gold_ans_str = byte_token_array_to_str(train_batch[2][0],
                                                           train_batch[2][2],
                                                           is_array=False)

                    qa_f1s.extend([
                        metrics.f1(metrics.normalize_answer(gold_ans_str[b]),
                                   metrics.normalize_answer(qa_pred[b]))
                        for b in range(curr_batch_size)
                    ])

                    disc_scores = discriminator.get_pred(
                        unfilt_ctxt_batch, pred_str, ans_text_batch,
                        ans_pos_batch)

                    # Start collecting reward statistics halfway through the
                    # burn-in, so the moments are populated before PG starts.
                    if i > FLAGS.pg_burnin // 2:
                        lm_score_moments.push(lm_score)
                        qa_score_moments.push(qa_f1s)
                        disc_score_moments.push(disc_scores)

                    # print(disc_scores)
                    # print((e-start_e)*num_steps_train+i, flags.pg_burnin)

                    if i > FLAGS.pg_burnin:
                        # A variant of popart
                        # Whiten each reward with its running mean/variance
                        # (1e-6 guards against division by zero).
                        qa_score_whitened = (
                            qa_f1s - qa_score_moments.mean
                        ) / np.sqrt(qa_score_moments.variance + 1e-6)
                        lm_score_whitened = (
                            lm_score - lm_score_moments.mean
                        ) / np.sqrt(lm_score_moments.variance + 1e-6)
                        disc_score_whitened = (
                            disc_scores - disc_score_moments.mean
                        ) / np.sqrt(disc_score_moments.variance + 1e-6)

                        lm_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/lm",
                                             simple_value=np.mean(lm_score))
                        ])
                        summary_writer.add_summary(lm_summary, global_step=(i))
                        qa_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/qa",
                                             simple_value=np.mean(qa_f1s))
                        ])
                        summary_writer.add_summary(qa_summary, global_step=(i))
                        disc_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/disc",
                                             simple_value=np.mean(disc_scores))
                        ])
                        summary_writer.add_summary(disc_summary,
                                                   global_step=(i))

                        lm_white_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/lm_white",
                                             simple_value=np.mean(
                                                 lm_score_whitened))
                        ])
                        summary_writer.add_summary(lm_white_summary,
                                                   global_step=(i))
                        qa_white_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/qa_white",
                                             simple_value=np.mean(
                                                 qa_score_whitened))
                        ])
                        summary_writer.add_summary(qa_white_summary,
                                                   global_step=(i))
                        disc_white_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/disc_white",
                                             simple_value=np.mean(
                                                 disc_score_whitened))
                        ])
                        summary_writer.add_summary(disc_white_summary,
                                                   global_step=(i))

                        # Build a combined batch - half ground truth for MLE, half generated for PG
                        train_batch_ext = duplicate_batch_and_inject(
                            train_batch, qhat_ids, qhat_str, qhat_lens)

                        # print(qhat_ids)
                        # print(qhat_lens)
                        # print(train_batch_ext[2][2])

                        # Per-example rewards for the doubled batch: the first
                        # curr_batch_size entries get the weighted whitened
                        # rewards; the second half gets pg_ml_weight on the LM
                        # slot and 0 for QA/disc.
                        rl_dict = {
                            model.lm_score:
                            np.asarray((lm_score_whitened *
                                        FLAGS.lm_weight).tolist() + [
                                            FLAGS.pg_ml_weight
                                            for b in range(curr_batch_size)
                                        ]),
                            model.qa_score:
                            np.asarray((qa_score_whitened *
                                        FLAGS.qa_weight).tolist() +
                                       [0 for b in range(curr_batch_size)]),
                            model.disc_score:
                            np.asarray((disc_score_whitened *
                                        FLAGS.disc_weight).tolist() +
                                       [0 for b in range(curr_batch_size)]),
                            model.rl_lm_enabled: True,
                            model.rl_qa_enabled: True,
                            model.rl_disc_enabled: FLAGS.disc_weight > 0,
                            model.step: i - FLAGS.pg_burnin,
                            model.hide_answer_in_copy: True
                        }

                        # perform a policy gradient step, but combine with a XE step by using appropriate rewards
                        # NOTE: the local name `ops` is a list of fetches here,
                        # not a module.
                        ops = [
                            model.pg_optimizer, model.train_summary,
                            model.q_hat_string
                        ]
                        # On eval steps 5 extra tensors are fetched, so the loss
                        # tensors appended afterwards shift by res_offset.
                        if i % FLAGS.eval_freq == 0:
                            ops.extend([
                                model.q_hat_ids, model.question_ids,
                                model.copy_prob, model.question_raw,
                                model.question_length
                            ])
                            res_offset = 5
                        else:
                            res_offset = 0
                        ops.extend([model.lm_loss, model.qa_loss])

                        res = sess.run(ops,
                                       feed_dict={
                                           model.input_batch: train_batch_ext,
                                           model.is_training: False,
                                           **rl_dict
                                       })
                        summary_writer.add_summary(res[1], global_step=(i))

                        # Log only the first half of the PG related losses
                        lm_loss_summary = tf.Summary(value=[
                            tf.Summary.Value(
                                tag="train_loss/lm",
                                simple_value=np.mean(res[3 + res_offset]
                                                     [:curr_batch_size]))
                        ])
                        summary_writer.add_summary(lm_loss_summary,
                                                   global_step=(i))
                        qa_loss_summary = tf.Summary(value=[
                            tf.Summary.Value(
                                tag="train_loss/qa",
                                simple_value=np.mean(res[4 + res_offset]
                                                     [:curr_batch_size]))
                        ])
                        summary_writer.add_summary(qa_loss_summary,
                                                   global_step=(i))

                    # TODO: more principled scheduling here than alternating steps
                    # NOTE(review): confirm whether discriminator training is
                    # meant to also run during the PG burn-in.
                    if FLAGS.disc_train:
                        # Randomly choose, per example, the generated question
                        # (ixs < 0.5) or the gold one; ixs is also passed to
                        # train_step (presumably as the real/fake label).
                        ixs = np.round(
                            np.random.binomial(1, 0.5, curr_batch_size))
                        qbatch = [
                            pred_str[ix].replace(" </Sent>", "").replace(
                                " <PAD>", "") if ixs[ix] < 0.5 else
                            gold_q_str[ix].replace(" </Sent>", "").replace(
                                " <PAD>", "") for ix in range(curr_batch_size)
                        ]

                        loss = discriminator.train_step(unfilt_ctxt_batch,
                                                        qbatch,
                                                        ans_text_batch,
                                                        ans_pos_batch,
                                                        ixs,
                                                        step=(i))

                else:
                    # Normal single pass update step. If model has PG capability, fill in the placeholders with empty values
                    if FLAGS.model_type[:
                                        7] == "MALUUBA" and not FLAGS.policy_gradient:
                        rl_dict = {
                            model.lm_score:
                            [0 for b in range(curr_batch_size)],
                            model.qa_score:
                            [0 for b in range(curr_batch_size)],
                            model.disc_score:
                            [0 for b in range(curr_batch_size)],
                            model.rl_lm_enabled: False,
                            model.rl_qa_enabled: False,
                            model.rl_disc_enabled: False,
                            model.hide_answer_in_copy: False
                        }
                    else:
                        rl_dict = {}

                    # Perform a normal optimizer step
                    ops = [
                        model.optimizer, model.train_summary,
                        model.q_hat_string
                    ]
                    # Same extra eval fetches (indices 3-7) as the PG branch.
                    if i % FLAGS.eval_freq == 0:
                        ops.extend([
                            model.q_hat_ids, model.question_ids,
                            model.copy_prob, model.question_raw,
                            model.question_length
                        ])
                    res = sess.run(ops,
                                   feed_dict={
                                       model.input_batch: train_batch,
                                       model.is_training: True,
                                       **rl_dict
                                   })
                    summary_writer.add_summary(res[1], global_step=(i))

                # Dump some output periodically
                # (with policy gradient, `res` only exists once past the burn-in)
                if i > 0 and i % FLAGS.eval_freq == 0 and (
                        i > FLAGS.pg_burnin or not FLAGS.policy_gradient):
                    # res[2..7]: q_hat_string, q_hat_ids, question_ids,
                    # copy_prob, question_raw, question_length.
                    with open(FLAGS.log_dir + 'out.htm', 'w',
                              encoding='utf-8') as fp:
                        fp.write(
                            output_pretty(res[2].tolist(), res[3], res[4],
                                          res[5], 0, i))
                    gold_batch = res[6]
                    gold_lens = res[7]
                    f1s = []
                    bleus = []
                    # NOTE(review): predictions are cropped with the gold length,
                    # presumably because q_hat_string is teacher-forced to the
                    # gold length during training -- confirm. In PG mode `res`
                    # comes from the doubled train_batch_ext, so both halves
                    # are included here.
                    for b, pred in enumerate(res[2]):
                        pred_str = tokens_to_string(pred[:gold_lens[b] - 1])
                        gold_str = tokens_to_string(
                            gold_batch[b][:gold_lens[b] - 1])
                        f1s.append(metrics.f1(gold_str, pred_str))
                        bleus.append(metrics.bleu(gold_str, pred_str))

                    f1summary = tf.Summary(value=[
                        tf.Summary.Value(tag="train_perf/f1",
                                         simple_value=sum(f1s) / len(f1s))
                    ])
                    bleusummary = tf.Summary(value=[
                        tf.Summary.Value(tag="train_perf/bleu",
                                         simple_value=sum(bleus) / len(bleus))
                    ])

                    summary_writer.add_summary(f1summary, global_step=(i))
                    summary_writer.add_summary(bleusummary, global_step=(i))

                    # Evaluate against dev set
                    f1s = []
                    bleus = []
                    nlls = []

                    # Evaluate on a fresh random subset of the dev set each time.
                    np.random.shuffle(dev_data)
                    dev_subset = dev_data[:num_dev_samples]
                    dev_data_source.initialise(dev_subset)
                    for j in tqdm(range(num_steps_dev), desc='Eval ' + str(i)):
                        dev_batch, curr_batch_size = dev_data_source.get_batch(
                        )
                        pred_batch, pred_ids, pred_lens, gold_batch, gold_lens, ctxt, ctxt_len, ans, ans_len, nll = sess.run(
                            [
                                model.q_hat_beam_string, model.q_hat_beam_ids,
                                model.q_hat_beam_lens, model.question_raw,
                                model.question_length, model.context_raw,
                                model.context_length, model.answer_locs,
                                model.answer_length, model.nll
                            ],
                            feed_dict={
                                model.input_batch: dev_batch,
                                model.is_training: False
                            })

                        nlls.extend(nll.tolist())
                        # out_str="<h1>"+str(e)+' - '+str(datetime.datetime.now())+'</h1>'
                        for b, pred in enumerate(pred_batch):
                            pred_str = tokens_to_string(
                                pred[:pred_lens[b] - 1]).replace(
                                    ' </Sent>', "").replace(" <PAD>", "")
                            gold_str = tokens_to_string(
                                gold_batch[b][:gold_lens[b] - 1])
                            f1s.append(metrics.f1(gold_str, pred_str))
                            bleus.append(metrics.bleu(gold_str, pred_str))
                            # out_str+=pred_str.replace('>','&gt;').replace('<','&lt;')+"<br/>"+gold_str.replace('>','&gt;').replace('<','&lt;')+"<hr/>"
                        # HTML preview of the first dev batch only.
                        if j == 0:
                            title = chkpt_path
                            out_str = output_eval(title, pred_batch, pred_ids,
                                                  pred_lens, gold_batch,
                                                  gold_lens, ctxt, ctxt_len,
                                                  ans, ans_len)
                            with open(FLAGS.log_dir + 'out_eval_' +
                                      FLAGS.model_type + '.htm',
                                      'w',
                                      encoding='utf-8') as fp:
                                fp.write(out_str)

                    f1summary = tf.Summary(value=[
                        tf.Summary.Value(tag="dev_perf/f1",
                                         simple_value=sum(f1s) / len(f1s))
                    ])
                    bleusummary = tf.Summary(value=[
                        tf.Summary.Value(tag="dev_perf/bleu",
                                         simple_value=sum(bleus) / len(bleus))
                    ])
                    nllsummary = tf.Summary(value=[
                        tf.Summary.Value(tag="dev_perf/nll",
                                         simple_value=sum(nlls) / len(nlls))
                    ])

                    summary_writer.add_summary(f1summary, global_step=i)
                    summary_writer.add_summary(bleusummary, global_step=i)
                    summary_writer.add_summary(nllsummary, global_step=i)

                    # Checkpoint on best dev NLL; with policy gradient, always save.
                    mean_nll = sum(nlls) / len(nlls)
                    if mean_nll < best_oos_nll:
                        print("New best NLL! ", mean_nll, " Saving...")
                        best_oos_nll = mean_nll
                        saver.save(sess,
                                   chkpt_path + '/model.checkpoint',
                                   global_step=i)
                    else:
                        print("NLL not improved ", mean_nll)
                        if FLAGS.policy_gradient:
                            print("Saving anyway")
                            saver.save(sess,
                                       chkpt_path + '/model.checkpoint',
                                       global_step=i)

                    if FLAGS.disc_train:
                        print("Saving disc")
                        discriminator.save_to_chkpt(FLAGS.model_dir, i)
def _qa_answer_spans(contexts, answer_charpos, answer_texts):
    """Convert character-offset answers to inclusive ``(start, end)`` token spans.

    For each context, ``char_pos_to_word`` maps the answer's character offset
    to a token index over ``tokenise(ctxt, asbytes=False)``. The end index is
    the start plus the number of tokens in the answer text, minus one.
    """
    spans = []
    for ctxt, charpos, text in zip(contexts, answer_charpos, answer_texts):
        start = char_pos_to_word(ctxt.encode(), [t.encode() for t in tokenise(ctxt, asbytes=False)], charpos)
        spans.append((start, start + len(tokenise(text, asbytes=False)) - 1))
    return spans


def main(_):
    """Train an MPCM extractive QA model on SQuAD.

    Trains ``MpcmQa`` on the window-filtered training triples for
    ``FLAGS.qa_num_epochs`` epochs, optionally resuming from the hard-coded
    ``restore_path``. After each epoch it evaluates on a random
    ``num_dev_samples`` subset of the dev set, taking the best F1/EM over all
    reference answers. Results are logged to TensorBoard under
    ``FLAGS.log_directory + 'qa/<run_id>'``, and a checkpoint is saved
    whenever the mean dev NLL improves.

    Args:
        _: Unused (presumably the argv list passed by ``tf.app.run`` -- TODO
            confirm).
    """
    if FLAGS.testing:
        print('TEST MODE - reducing model size')
        FLAGS.qa_encoder_units = 32
        FLAGS.qa_match_units = 32
        FLAGS.qa_batch_size = 16
        FLAGS.embedding_size = 50

    # Read after the test-mode override above.
    batch_size = FLAGS.qa_batch_size

    run_id = str(int(time.time()))
    chkpt_path = FLAGS.model_dir + 'qa/' + run_id
    restore_path = FLAGS.model_dir + 'qa/1529056867'

    if not os.path.exists(chkpt_path):
        os.makedirs(chkpt_path)

    train_data = loader.load_squad_triples(FLAGS.data_path, False)
    # Dev triples carry lists of reference answers / answer positions.
    dev_data = loader.load_squad_triples(FLAGS.data_path, dev=True, ans_list=True)

    train_data = filter_squad(train_data, window_size=FLAGS.filter_window_size, max_tokens=FLAGS.filter_max_tokens)

    if FLAGS.testing:
        train_data = train_data[:1000]
        num_dev_samples = 100
    else:
        num_dev_samples = 3000

    print('Loaded SQuAD with ', len(train_data), ' triples')
    train_contexts, train_qs, train_as, train_a_pos = zip(*train_data)

    if FLAGS.restore:
        with open(restore_path + '/vocab.json', encoding='utf-8') as f:
            vocab = json.load(f)
    else:
        vocab = loader.get_vocab(train_contexts + train_qs, tf.app.flags.FLAGS.qa_vocab_size)
        with open(chkpt_path + '/vocab.json', 'w', encoding='utf-8') as outfile:
            json.dump(vocab, outfile)

    model = MpcmQa(vocab)
    with model.graph.as_default():
        saver = tf.train.Saver()

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_limit, allow_growth=True)
    with tf.Session(graph=model.graph, config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        summary_writer = tf.summary.FileWriter(FLAGS.log_directory + 'qa/' + run_id, sess.graph)

        if FLAGS.restore:
            saver.restore(sess, restore_path + '/model.checkpoint')
            start_e = 40  # FLAGS.qa_num_epochs
            print('Loaded model')
        else:
            print("Building graph, loading glove")
            start_e = 0
            sess.run(tf.global_variables_initializer())

        num_steps_train = len(train_data) // batch_size
        num_steps_dev = num_dev_samples // batch_size

        # Seed the dev curves at the starting step.
        f1summary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/f1", simple_value=0.0)])
        emsummary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/em", simple_value=0.0)])
        summary_writer.add_summary(f1summary, global_step=start_e * num_steps_train)
        summary_writer.add_summary(emsummary, global_step=start_e * num_steps_train)

        best_oos_nll = 1e6
        for e in range(start_e, start_e + FLAGS.qa_num_epochs):
            np.random.shuffle(train_data)
            train_contexts, train_qs, train_as, train_a_pos = zip(*train_data)

            for i in tqdm(range(num_steps_train), desc='Epoch ' + str(e)):
                lo, hi = i * batch_size, (i + 1) * batch_size
                batch_contexts = train_contexts[lo:hi]
                batch_questions = train_qs[lo:hi]
                batch_ans_text = train_as[lo:hi]
                batch_answer_charpos = train_a_pos[lo:hi]
                batch_answers = _qa_answer_spans(batch_contexts, batch_answer_charpos, batch_ans_text)

                _, summ, pred = sess.run(
                    [model.optimizer, model.train_summary, model.pred_span],
                    feed_dict={model.context_in: get_padded_batch(batch_contexts, vocab),
                               model.question_in: get_padded_batch(batch_questions, vocab),
                               model.answer_spans_in: batch_answers,
                               model.is_training: True})
                summary_writer.add_summary(summ, global_step=(e * num_steps_train + i))

                if i % FLAGS.eval_freq == 0:
                    gold_str = []
                    pred_str = []
                    for b in range(batch_size):
                        # Tokenise each context once for both the gold and predicted span.
                        ctx_toks = tokenise(batch_contexts[b], asbytes=False)
                        gold_str.append(" ".join(ctx_toks[batch_answers[b][0]:batch_answers[b][1] + 1]))
                        pred_str.append(" ".join(ctx_toks[pred[b][0]:pred[b][1] + 1]))

                    f1s = [f1(gold_str[b], pred_str[b]) for b in range(batch_size)]
                    # Exact match: both predicted span endpoints equal the gold ones.
                    exactmatches = [np.prod(pred[b] == batch_answers[b]) * 1.0 for b in range(batch_size)]

                    f1summary = tf.Summary(value=[tf.Summary.Value(tag="train_perf/f1", simple_value=sum(f1s) / len(f1s))])
                    emsummary = tf.Summary(value=[tf.Summary.Value(tag="train_perf/em", simple_value=sum(exactmatches) / len(exactmatches))])
                    summary_writer.add_summary(f1summary, global_step=(e * num_steps_train + i))
                    summary_writer.add_summary(emsummary, global_step=(e * num_steps_train + i))

            f1s = []
            exactmatches = []
            nlls = []

            np.random.shuffle(dev_data)
            dev_subset = dev_data[:num_dev_samples]
            # Unzip once per evaluation pass rather than once per batch.
            dev_contexts, dev_qs, dev_as, dev_a_pos = zip(*dev_subset)
            for i in tqdm(range(num_steps_dev), desc='Eval ' + str(e)):
                lo, hi = i * batch_size, (i + 1) * batch_size
                batch_contexts = dev_contexts[lo:hi]
                batch_questions = dev_qs[lo:hi]
                batch_ans_text = dev_as[lo:hi]
                batch_answer_charpos = dev_a_pos[lo:hi]
                # The first reference answer defines the span fed to the model (used for the NLL).
                batch_answers = _qa_answer_spans(batch_contexts,
                                                 [p[0] for p in batch_answer_charpos],
                                                 [t[0] for t in batch_ans_text])

                pred, nll = sess.run(
                    [model.pred_span, model.nll],
                    feed_dict={model.context_in: get_padded_batch(batch_contexts, vocab),
                               model.question_in: get_padded_batch(batch_questions, vocab),
                               model.answer_spans_in: batch_answers,
                               model.is_training: False})

                for b in range(batch_size):
                    pred_str = " ".join(tokenise(batch_contexts[b], asbytes=False)[pred[b][0]:pred[b][1] + 1])
                    pred_norm = normalize_answer(pred_str)
                    gold_norms = [normalize_answer(a) for a in batch_ans_text[b]]
                    # Best score over all reference answers.
                    f1s.append(max(f1(g, pred_norm) for g in gold_norms))
                    exactmatches.append(max(1.0 * (g == pred_norm) for g in gold_norms))

                nlls.extend(nll.tolist())

            eval_step = (e + 1) * num_steps_train
            f1summary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/f1", simple_value=sum(f1s) / len(f1s))])
            emsummary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/em", simple_value=sum(exactmatches) / len(exactmatches))])
            nllsummary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/nll", simple_value=np.mean(nlls))])
            summary_writer.add_summary(f1summary, global_step=eval_step)
            summary_writer.add_summary(emsummary, global_step=eval_step)
            summary_writer.add_summary(nllsummary, global_step=eval_step)

            mean_nll = np.mean(nlls)
            if mean_nll < best_oos_nll:
                print("New best NLL! ", mean_nll, " Saving... F1: ", np.mean(f1s))
                best_oos_nll = mean_nll
                saver.save(sess, chkpt_path + '/model.checkpoint')
            else:
                print("NLL not improved ", mean_nll)