Example #1
def main(_):

    model_type=FLAGS.model_type
    # chkpt_path = FLAGS.model_dir+'saved/qgen-maluuba-crop-glove-smart'
    # chkpt_path = FLAGS.model_dir+'qgen-saved/MALUUBA-CROP-LATENT/1533247183'
    disc_path = FLAGS.model_dir+'saved/discriminator-trained-latent'
    chkpt_path = FLAGS.model_dir+'qgen/'+ model_type+'/'+FLAGS.eval_model_id

    # load dataset
    # train_data = loader.load_squad_triples(FLAGS.data_path, False)
    dev_data = loader.load_squad_triples(FLAGS.data_path, dev=FLAGS.eval_on_dev, test=FLAGS.eval_on_test)

    # train_contexts_unfilt, _,_,train_a_pos_unfilt = zip(*train_data)
    dev_contexts_unfilt, _,_,dev_a_pos_unfilt = zip(*dev_data)

    if FLAGS.filter_window_size > -1:
        # train_data = preprocessing.filter_squad(train_data, window_size=FLAGS.filter_window_size, max_tokens=FLAGS.filter_max_tokens)
        dev_data = preprocessing.filter_squad(dev_data, window_size=FLAGS.filter_window_size, max_tokens=FLAGS.filter_max_tokens)


    # print('Loaded SQuAD with ',len(train_data),' triples')
    print('Loaded SQuAD dev set with ',len(dev_data),' triples')
    # train_contexts, train_qs, train_as,train_a_pos = zip(*train_data)
    dev_contexts, dev_qs, dev_as, dev_a_pos = zip(*dev_data)


    # vocab = loader.get_vocab(train_contexts, tf.app.flags.FLAGS.vocab_size)
    with open(chkpt_path+'/vocab.json') as f:
        vocab = json.load(f)

    dev_data_source = SquadStreamer(vocab, FLAGS.eval_batch_size, 1, shuffle=False)

    glove_embeddings = loader.load_glove(FLAGS.data_path)


    # Create model
    if model_type[:7] == "SEQ2SEQ":
        model = Seq2SeqModel(vocab, training_mode=False)
    elif model_type[:7] == "MALUUBA":
        # TEMP - no need to spin up the LM or QA model at eval time
        FLAGS.qa_weight = 0
        FLAGS.lm_weight = 0
        model = MaluubaModel(vocab, training_mode=False)
    else:
        exit("Unrecognised model type: "+model_type)

    with model.graph.as_default():
        saver = tf.train.Saver()

    if FLAGS.eval_metrics:
        lm = LstmLmInstance()
        # qa = MpcmQaInstance()
        qa = QANetInstance()

        lm.load_from_chkpt(FLAGS.model_dir+'saved/lmtest')
        # qa.load_from_chkpt(FLAGS.model_dir+'saved/qatest')
        qa.load_from_chkpt(FLAGS.model_dir+'saved/qanet2')

        discriminator = DiscriminatorInstance(trainable=False, path=disc_path)

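    # NB: mem_limit is assumed to be a module-level GPU memory fraction in (0, 1] defined elsewhere in this file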
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_limit)
    with tf.Session(graph=model.graph, config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        if not os.path.exists(chkpt_path):
            exit("Checkpoint path doesn't exist! "+chkpt_path)
        # summary_writer = tf.summary.FileWriter(FLAGS.log_dir+"eval/"+str(int(time.time())), sess.graph)

        saver.restore(sess, tf.train.latest_checkpoint(chkpt_path))
        # print('Loading not implemented yet')
        # else:
        #     sess.run(tf.global_variables_initializer())
        #     sess.run(model.glove_init_ops)

        num_steps = FLAGS.num_eval_samples//FLAGS.eval_batch_size

        # Initialise the dataset

        # np.random.shuffle(dev_data)
        dev_data_source.initialise(dev_data)

        f1s=[]
        bleus=[]
        qa_scores=[]
        qa_scores_gold=[]
        lm_scores=[]
        nlls=[]
        disc_scores=[]
        sowe_similarities=[]
        copy_probs=[]

        qgolds=[]
        qpreds=[]
        ctxts=[]
        answers=[]
        ans_positions=[]

        metric_individuals=[]
        res=[]
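        # Single evaluation pass over the dev/test set: collect per-example metrics, then aggregate below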
        for e in range(1):
            for i in tqdm(range(num_steps), desc='Epoch '+str(e)):
                dev_batch, curr_batch_size = dev_data_source.get_batch()
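                # One beam-search decode per batch: fetch the predicted question (plus the full beam),
                # the gold question/context/answer tensors, and per-example NLL in a single sess.run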
                (pred_batch, pred_beam, pred_beam_lens, pred_ids, pred_lens,
                 gold_batch, gold_lens, gold_ids, ctxt, ctxt_len,
                 ans, ans_len, nll, copy_prob) = sess.run(
                    [model.q_hat_beam_string, model.q_hat_full_beam_str, model.q_hat_full_beam_lens,
                     model.q_hat_beam_ids, model.q_hat_beam_lens,
                     model.question_raw, model.question_length, model.question_ids,
                     model.context_raw, model.context_length,
                     model.answer_locs, model.answer_length,
                     model.nll, model.mean_copy_prob],
                    feed_dict={model.input_batch: dev_batch, model.is_training: False})

                unfilt_ctxt_batch = [dev_contexts_unfilt[ix] for ix in dev_batch[3]]
                a_text_batch = ops.byte_token_array_to_str(dev_batch[2][0], dev_batch[2][2], is_array=False)
                unfilt_apos_batch = [dev_a_pos_unfilt[ix] for ix in dev_batch[3]]

                # subtract 1 from the lengths to drop the </Sent> end-of-sequence token
                pred_q_batch = [q.replace(' </Sent>',"").replace(" <PAD>","") for q in ops.byte_token_array_to_str(pred_batch, pred_lens-1)]

                ctxts.extend(unfilt_ctxt_batch)
                answers.extend(a_text_batch)
                ans_positions.extend([dev_a_pos_unfilt[ix] for ix in dev_batch[3]])
                copy_probs.extend(copy_prob.tolist())

                # get QA score

                # gold_str=[]
                # pred_str=[]


                gold_ans = ops.byte_token_array_to_str(dev_batch[2][0], dev_batch[2][2], is_array=False)
                # pred_str = ops.byte_token_array_to_str([dev_batch[0][0][b][qa_pred[b][0]:qa_pred[b][1]] for b in range(curr_batch_size)], is_array=False)
                nlls.extend(nll.tolist())

                if FLAGS.eval_metrics:
                    qa_pred = qa.get_ans(unfilt_ctxt_batch, ops.byte_token_array_to_str(pred_batch, pred_lens))
                    gold_qa_pred = qa.get_ans(unfilt_ctxt_batch, ops.byte_token_array_to_str(dev_batch[1][0], dev_batch[1][3]))

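                    # QA reward: token-level F1 between the gold answer and the span QANet predicts when asked
                    # the generated question (and, for reference, the gold question) against the unfiltered context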
                    qa_score_batch = [metrics.f1(metrics.normalize_answer(gold_ans[b]), metrics.normalize_answer(qa_pred[b])) for b in range(curr_batch_size)]
                    qa_score_gold_batch = [metrics.f1(metrics.normalize_answer(gold_ans[b]), metrics.normalize_answer(gold_qa_pred[b])) for b in range(curr_batch_size)]
                    lm_score_batch = lm.get_seq_perplexity(pred_q_batch).tolist()
                    disc_score_batch = discriminator.get_pred(unfilt_ctxt_batch, pred_q_batch, gold_ans, unfilt_apos_batch).tolist()

                for b, pred in enumerate(pred_batch):
                    pred_str = pred_q_batch[b].replace(' </Sent>',"").replace(" <PAD>","")
                    gold_str = tokens_to_string(gold_batch[b][:gold_lens[b]-1])
                    f1s.append(metrics.f1(gold_str, pred_str))
                    bleus.append(metrics.bleu(gold_str, pred_str))
                    qgolds.append(gold_str)
                    qpreds.append(pred_str)

                    # calc cosine similarity between sums of word embeddings
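                    # OOV words fall back to zero vectors; this assumes the GloVe dimensionality matches FLAGS.embedding_size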
                    pred_tokens = preprocessing.tokenise(pred_str, asbytes=False)
                    gold_tokens = preprocessing.tokenise(gold_str, asbytes=False)
                    pred_sowe = np.sum(np.asarray([glove_embeddings[w] if w in glove_embeddings else np.zeros((FLAGS.embedding_size,)) for w in pred_tokens]), axis=0)
                    gold_sowe = np.sum(np.asarray([glove_embeddings[w] if w in glove_embeddings else np.zeros((FLAGS.embedding_size,)) for w in gold_tokens]), axis=0)
                    this_similarity = np.inner(pred_sowe, gold_sowe) / (np.linalg.norm(pred_sowe, ord=2) * np.linalg.norm(gold_sowe, ord=2))

                    sowe_similarities.append(this_similarity)



                    this_metric_dict={
                        'f1':f1s[-1],
                        'bleu': bleus[-1],
                        'nll': nlls[-1],
                        'sowe': sowe_similarities[-1]
                        }
                    if FLAGS.eval_metrics:
                        this_metric_dict = {
                            **this_metric_dict,
                            'qa': qa_score_batch[b],
                            'lm': lm_score_batch[b],
                            'disc': disc_score_batch[b]}
                        # append per example rather than extending with the whole batch each pass,
                        # which previously duplicated every batch's scores once per example
                        qa_scores.append(qa_score_batch[b])
                        lm_scores.append(lm_score_batch[b])
                        disc_scores.append(disc_score_batch[b])
                    metric_individuals.append(this_metric_dict)

                    res.append({
                        'c':unfilt_ctxt_batch[b],
                        'q_pred': pred_str,
                        'q_gold': gold_str,
                        'a_pos': unfilt_apos_batch[b],
                        'a_text': a_text_batch[b],
                        'metrics': this_metric_dict
                    })

                # Quick output
                if i==0:
                    gold_str = tokens_to_string(gold_batch[0][:gold_lens[0]-1])
                    print(qpreds[0])
                    print(gold_str)


                    title=chkpt_path
                    out_str = output_eval(title,pred_batch,  pred_ids, pred_lens, gold_batch, gold_lens, ctxt, ctxt_len, ans, ans_len)
                    with open(FLAGS.log_dir+'out_eval_'+model_type+'.htm', 'w', encoding='utf-8') as fp:
                        fp.write(out_str)

        # res = list(zip(qpreds,qgolds,ctxts,answers,ans_positions,metric_individuals))
        metric_dict={
            'f1':np.mean(f1s),
            'bleu': metrics.bleu_corpus(qgolds, qpreds),
            'nll':np.mean(nlls),
            'sowe': np.mean(sowe_similarities)
            }
        if FLAGS.eval_metrics:
            metric_dict = {**metric_dict,
                'qa': np.mean(qa_scores),
                'lm': np.mean(lm_scores),
                'disc': np.mean(disc_scores)}
        # print(res)
        split_suffix = "_test" if FLAGS.eval_on_test else ("_train" if not FLAGS.eval_on_dev else "")
        with open(FLAGS.log_dir+'out_eval_'+model_type+split_suffix+'.json', 'w', encoding='utf-8') as fp:
            json.dump({"metrics":metric_dict, "results": res}, fp)


        print("F1: ", np.mean(f1s))
        print("BLEU: ", metrics.bleu_corpus(qgolds, qpreds))
        print("NLL: ", np.mean(nlls))
        print("SOWE: ", np.mean(sowe_similarities))

        print("Copy prob: ", np.mean(copy_probs))
        if FLAGS.eval_metrics:
            print("QA: ", np.mean(qa_scores))
            print("LM: ", np.mean(lm_scores))
            print("Disc: ", np.mean(disc_scores))
Example #2
def main(_):
    if FLAGS.testing:
        print('TEST MODE - reducing model size')
        FLAGS.context_encoder_units =100
        FLAGS.answer_encoder_units=100
        FLAGS.decoder_units=100
        FLAGS.batch_size =8
        FLAGS.eval_batch_size=8
        # FLAGS.embedding_size=50

    run_id = str(int(time.time()))
    chkpt_path = FLAGS.model_dir+'qgen/'+FLAGS.model_type+'/'+run_id
    restore_path = FLAGS.model_dir+'qgen/'+FLAGS.restore_path if FLAGS.restore_path is not None else None  # e.g. 'MALUUBA-CROP-LATENT'+'/'+'1534123959'
    # restore_path=FLAGS.model_dir+'saved/qgen-maluuba-crop-glove-smart'
    disc_path = FLAGS.model_dir+'saved/discriminator-trained-latent'

    print("Run ID is ", run_id)
    print("Model type is ", FLAGS.model_type)

    if not os.path.exists(chkpt_path):
        os.makedirs(chkpt_path)

    # load dataset
    train_data = loader.load_squad_triples(FLAGS.data_path, False)
    dev_data = loader.load_squad_triples(FLAGS.data_path, True)

    train_contexts_unfilt, _,ans_text_unfilt,ans_pos_unfilt = zip(*train_data)
    dev_contexts_unfilt, _,dev_ans_text_unfilt,dev_ans_pos_unfilt = zip(*dev_data)

    if FLAGS.testing:
        train_data=train_data[:1000]
        num_dev_samples=100
    else:
        num_dev_samples=FLAGS.num_dev_samples

    if FLAGS.filter_window_size > -1:
        train_data = preprocessing.filter_squad(train_data, window_size=FLAGS.filter_window_size, max_tokens=FLAGS.filter_max_tokens)
        dev_data = preprocessing.filter_squad(dev_data, window_size=FLAGS.filter_window_size, max_tokens=FLAGS.filter_max_tokens)



    print('Loaded SQuAD with ',len(train_data),' triples')
    train_contexts, train_qs, train_as,train_a_pos = zip(*train_data)

    if FLAGS.restore:
        if restore_path is None:
            exit('You need to specify a restore path!')
        with open(restore_path+'/vocab.json', encoding="utf-8") as f:
            vocab = json.load(f)
    elif FLAGS.glove_vocab:
        vocab = loader.get_glove_vocab(FLAGS.data_path, size=FLAGS.vocab_size, d=FLAGS.embedding_size)
        with open(chkpt_path+'/vocab.json', 'w', encoding="utf-8") as outfile:
            json.dump(vocab, outfile)
    else:
        vocab = loader.get_vocab(train_contexts+train_qs, FLAGS.vocab_size)
        with open(chkpt_path+'/vocab.json', 'w', encoding="utf-8") as outfile:
            json.dump(vocab, outfile)



    # Create model
    if FLAGS.model_type[:7] == "SEQ2SEQ":
        model = Seq2SeqModel(vocab, training_mode=True, use_embedding_loss=FLAGS.embedding_loss)
    elif FLAGS.model_type[:7] == "MALUUBA":
        # TEMP
        if not FLAGS.policy_gradient:
            FLAGS.qa_weight = 0
            FLAGS.lm_weight = 0
        model = MaluubaModel(vocab, training_mode=True, use_embedding_loss=FLAGS.embedding_loss)
        # if FLAGS.model_type[:10] == "MALUUBA_RL":
        #     qa_vocab=model.qa.vocab
        #     lm_vocab=model.lm.vocab
        if FLAGS.policy_gradient:
            discriminator = DiscriminatorInstance(trainable=FLAGS.disc_train, path=disc_path)
    else:
        exit("Unrecognised model type: "+FLAGS.model_type)

    # create data streamer
    train_data_source = SquadStreamer(vocab, FLAGS.batch_size, FLAGS.num_epochs, shuffle=True)
    dev_data_source = SquadStreamer(vocab, FLAGS.eval_batch_size, 1, shuffle=True)

    with model.graph.as_default():
        saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)


    # change visible devices if using RL models
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_limit, visible_device_list='0',allow_growth = True)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=False), graph=model.graph) as sess:

        summary_writer = tf.summary.FileWriter(FLAGS.log_dir+'qgen/'+FLAGS.model_type+'/'+run_id, sess.graph)

        train_data_source.initialise(train_data)

        num_steps_train = len(train_data)//FLAGS.batch_size
        num_steps_dev = num_dev_samples//FLAGS.eval_batch_size

        if FLAGS.restore:
            saver.restore(sess, tf.train.latest_checkpoint(restore_path))
            start_e = 15  # hard-coded resume epoch (was FLAGS.num_epochs)
            print('Loaded model')
        else:
            start_e=0
            sess.run(tf.global_variables_initializer())
            # sess.run(model.glove_init_ops)

            f1summary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/f1",
                                             simple_value=0.0)])
            bleusummary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/bleu",
                                      simple_value=0.0)])

            summary_writer.add_summary(f1summary, global_step=0)
            summary_writer.add_summary(bleusummary, global_step=0)


        # Initialise the dataset
        # sess.run(model.iterator.initializer, feed_dict={model.context_ph: train_contexts,
        #                                   model.qs_ph: train_qs, model.as_ph: train_as, model.a_pos_ph: train_a_pos})



        best_oos_nll=1e6

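        # Running mean/variance trackers used to whiten each RL reward stream (PopArt-style baseline)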
        lm_score_moments = online_moments.OnlineMoment()
        qa_score_moments = online_moments.OnlineMoment()
        disc_score_moments = online_moments.OnlineMoment()

        # for e in range(start_e, start_e+FLAGS.num_epochs):
        # The epoch loop is commented out; instead, train for all epochs' worth of steps in one flat loop
        for i in tqdm(range(num_steps_train*FLAGS.num_epochs), desc='Training'):
            # Get a batch
            train_batch, curr_batch_size = train_data_source.get_batch()

            # Are we doing policy gradient? Do a forward pass first, then build the PG batch and do an update step
            if FLAGS.model_type[:10] == "MALUUBA_RL" and FLAGS.policy_gradient:

                # do a fwd pass first, get the score, then do another pass and optimize
                qhat_str,qhat_ids, qhat_lens= sess.run([model.q_hat_beam_string, model.q_hat_beam_ids, model.q_hat_beam_lens],
                    feed_dict={model.input_batch: train_batch,
                    model.is_training: FLAGS.pg_dropout,
                    model.hide_answer_in_copy: True})

                # The output is as long as the max allowed len - remove the pointless extra padding
                qhat_ids = qhat_ids[:,:np.max(qhat_lens)]
                qhat_str = qhat_str[:,:np.max(qhat_lens)]

                pred_str = byte_token_array_to_str(qhat_str, qhat_lens-1)
                gold_q_str = byte_token_array_to_str(train_batch[1][0], train_batch[1][3])

                # Get reward values
                lm_score = (-1*model.lm.get_seq_perplexity(pred_str)).tolist() # lower perplexity is better




                # retrieve the uncropped context for QA evaluation
                unfilt_ctxt_batch = [train_contexts_unfilt[ix] for ix in train_batch[3]]
                ans_text_batch = [ans_text_unfilt[ix] for ix in train_batch[3]]
                ans_pos_batch = [ans_pos_unfilt[ix] for ix in train_batch[3]]

                qa_pred = model.qa.get_ans(unfilt_ctxt_batch, pred_str)
                qa_pred_gold = model.qa.get_ans(unfilt_ctxt_batch, gold_q_str)

                # gold_str=[]
                # pred_str=[]
                qa_f1s = []
                gold_ans_str = byte_token_array_to_str(train_batch[2][0], train_batch[2][2], is_array=False)


                qa_f1s.extend([metrics.f1(metrics.normalize_answer(gold_ans_str[b]), metrics.normalize_answer(qa_pred[b])) for b in range(curr_batch_size)])

                disc_scores = discriminator.get_pred(unfilt_ctxt_batch, pred_str, ans_text_batch, ans_pos_batch )

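                # Accumulate reward statistics from halfway through the burn-in, so the whitening
                # stats are warm before PG updates begin at step pg_burnin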
                if i > FLAGS.pg_burnin//2:
                    lm_score_moments.push(lm_score)
                    qa_score_moments.push(qa_f1s)
                    disc_score_moments.push(disc_scores)



                if i > FLAGS.pg_burnin:
                    # A variant of popart
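                    # Whiten each reward: subtract the running mean and divide by the running std (eps avoids division by zero)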
                    qa_score_whitened = (qa_f1s-qa_score_moments.mean)/np.sqrt(qa_score_moments.variance+1e-6)
                    lm_score_whitened = (lm_score-lm_score_moments.mean)/np.sqrt(lm_score_moments.variance+1e-6)
                    disc_score_whitened = (disc_scores-disc_score_moments.mean)/np.sqrt(disc_score_moments.variance+1e-6)

                    lm_summary = tf.Summary(value=[tf.Summary.Value(tag="rl_rewards/lm",
                                                     simple_value=np.mean(lm_score))])
                    summary_writer.add_summary(lm_summary, global_step=(i))
                    qa_summary = tf.Summary(value=[tf.Summary.Value(tag="rl_rewards/qa",
                                                     simple_value=np.mean(qa_f1s))])
                    summary_writer.add_summary(qa_summary, global_step=(i))
                    disc_summary = tf.Summary(value=[tf.Summary.Value(tag="rl_rewards/disc",
                                                     simple_value=np.mean(disc_scores))])
                    summary_writer.add_summary(disc_summary, global_step=(i))

                    lm_white_summary = tf.Summary(value=[tf.Summary.Value(tag="rl_rewards/lm_white",
                                                     simple_value=np.mean(lm_score_whitened))])
                    summary_writer.add_summary(lm_white_summary, global_step=(i))
                    qa_white_summary = tf.Summary(value=[tf.Summary.Value(tag="rl_rewards/qa_white",
                                                     simple_value=np.mean(qa_score_whitened))])
                    summary_writer.add_summary(qa_white_summary, global_step=(i))
                    disc_white_summary = tf.Summary(value=[tf.Summary.Value(tag="rl_rewards/disc_white",
                                                     simple_value=np.mean(disc_score_whitened))])
                    summary_writer.add_summary(disc_white_summary, global_step=(i))

                    # Build a combined batch - half ground truth for MLE, half generated for PG
                    train_batch_ext = duplicate_batch_and_inject(train_batch, qhat_ids, qhat_str, qhat_lens)

                    # print(qhat_ids)
                    # print(qhat_lens)
                    # print(train_batch_ext[2][2])

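                    # The first half of the batch (generated questions) carries whitened PG rewards; the duplicated
                    # gold half gets a constant pg_ml_weight so the same op applies an MLE-style update to it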
                    rl_dict={model.lm_score: np.asarray((lm_score_whitened*FLAGS.lm_weight).tolist()+[FLAGS.pg_ml_weight for b in range(curr_batch_size)]),
                        model.qa_score: np.asarray((qa_score_whitened*FLAGS.qa_weight).tolist()+[0 for b in range(curr_batch_size)]),
                        model.disc_score: np.asarray((disc_score_whitened*FLAGS.disc_weight).tolist()+[0 for b in range(curr_batch_size)]),
                        model.rl_lm_enabled: True,
                        model.rl_qa_enabled: True,
                        model.rl_disc_enabled: FLAGS.disc_weight > 0,
                        model.step: i-FLAGS.pg_burnin,
                        model.hide_answer_in_copy: True}

                    # perform a policy gradient step, but combine with a XE step by using appropriate rewards
                    ops = [model.pg_optimizer, model.train_summary,model.q_hat_string]
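                    # On eval steps, five extra tensors are fetched, shifting the lm/qa loss indices by res_offset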
                    if i%FLAGS.eval_freq==0:
                        ops.extend([ model.q_hat_ids, model.question_ids, model.copy_prob, model.question_raw, model.question_length])
                        res_offset = 5
                    else:
                        res_offset=0
                    ops.extend([model.lm_loss, model.qa_loss])
                    res= sess.run(ops, feed_dict={model.input_batch: train_batch_ext,
                        model.is_training:False,
                        **rl_dict})
                    summary_writer.add_summary(res[1], global_step=(i))

                    # Log only the first half of the PG related losses
                    lm_loss_summary = tf.Summary(value=[tf.Summary.Value(tag="train_loss/lm",
                                                     simple_value=np.mean(res[3+res_offset][:curr_batch_size]))])
                    summary_writer.add_summary(lm_loss_summary, global_step=(i))
                    qa_loss_summary = tf.Summary(value=[tf.Summary.Value(tag="train_loss/qa",
                                                     simple_value=np.mean(res[4+res_offset][:curr_batch_size]))])
                    summary_writer.add_summary(qa_loss_summary, global_step=(i))

                # TODO: more principled scheduling here than alternating steps
                if FLAGS.disc_train:
                    ixs = np.round(np.random.binomial(1,0.5,curr_batch_size))
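                    # ixs ~ Bernoulli(0.5): 1 selects the gold question, 0 the generated one;
                    # ixs also serves as the discriminator's target labels for this batch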
                    qbatch = [pred_str[ix].replace(" </Sent>","").replace(" <PAD>","") if ixs[ix] < 0.5
                              else gold_q_str[ix].replace(" </Sent>","").replace(" <PAD>","")
                              for ix in range(curr_batch_size)]

                    loss = discriminator.train_step(unfilt_ctxt_batch, qbatch, ans_text_batch, ans_pos_batch, ixs, step=(i) )

            else:
                # Normal single pass update step. If model has PG capability, fill in the placeholders with empty values
                if FLAGS.model_type[:7] == "MALUUBA" and not FLAGS.policy_gradient:
                    rl_dict={model.lm_score: [0 for b in range(curr_batch_size)],
                        model.qa_score: [0 for b in range(curr_batch_size)],
                        model.disc_score: [0 for b in range(curr_batch_size)],
                        model.rl_lm_enabled: False,
                        model.rl_qa_enabled: False,
                        model.rl_disc_enabled: False,
                        model.hide_answer_in_copy: False}
                else:
                    rl_dict={}

                # Perform a normal optimizer step
                ops = [model.optimizer, model.train_summary,model.q_hat_string]
                if i%FLAGS.eval_freq==0:
                    ops.extend([ model.q_hat_ids, model.question_ids, model.copy_prob, model.question_raw, model.question_length])
                res= sess.run(ops, feed_dict={model.input_batch: train_batch,
                    model.is_training:True,
                    **rl_dict})
                summary_writer.add_summary(res[1], global_step=(i))



            # Dump some output periodically
            if i>0 and i%FLAGS.eval_freq==0 and (i > FLAGS.pg_burnin or not FLAGS.policy_gradient):
                with open(FLAGS.log_dir+'out.htm', 'w', encoding='utf-8') as fp:
                    fp.write(output_pretty(res[2].tolist(), res[3], res[4], res[5], 0, i))
                gold_batch = res[6]
                gold_lens = res[7]
                f1s=[]
                bleus=[]
                for b, pred in enumerate(res[2]):
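                    # q_hat_string here comes from the teacher-forced training pass, so (presumably)
                    # the gold length is used to crop the prediction as well; no predicted length is fetched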
                    pred_str = tokens_to_string(pred[:gold_lens[b]-1])
                    gold_str = tokens_to_string(gold_batch[b][:gold_lens[b]-1])
                    f1s.append(metrics.f1(gold_str, pred_str))
                    bleus.append(metrics.bleu(gold_str, pred_str))


                f1summary = tf.Summary(value=[tf.Summary.Value(tag="train_perf/f1",
                                                 simple_value=sum(f1s)/len(f1s))])
                bleusummary = tf.Summary(value=[tf.Summary.Value(tag="train_perf/bleu",
                                          simple_value=sum(bleus)/len(bleus))])

                summary_writer.add_summary(f1summary, global_step=(i))
                summary_writer.add_summary(bleusummary, global_step=(i))

                # Evaluate against dev set
                f1s=[]
                bleus=[]
                nlls=[]

                np.random.shuffle(dev_data)
                dev_subset = dev_data[:num_dev_samples]
                dev_data_source.initialise(dev_subset)
                for j in tqdm(range(num_steps_dev), desc='Eval '+str(i)):
                    dev_batch, curr_batch_size = dev_data_source.get_batch()
                    (pred_batch, pred_ids, pred_lens, gold_batch, gold_lens,
                     ctxt, ctxt_len, ans, ans_len, nll) = sess.run(
                        [model.q_hat_beam_string, model.q_hat_beam_ids, model.q_hat_beam_lens,
                         model.question_raw, model.question_length,
                         model.context_raw, model.context_length,
                         model.answer_locs, model.answer_length, model.nll],
                        feed_dict={model.input_batch: dev_batch, model.is_training: False})

                    nlls.extend(nll.tolist())
                    # out_str="<h1>"+str(e)+' - '+str(datetime.datetime.now())+'</h1>'
                    for b, pred in enumerate(pred_batch):
                        pred_str = tokens_to_string(pred[:pred_lens[b]-1]).replace(' </Sent>',"").replace(" <PAD>","")
                        gold_str = tokens_to_string(gold_batch[b][:gold_lens[b]-1])
                        f1s.append(metrics.f1(gold_str, pred_str))
                        bleus.append(metrics.bleu(gold_str, pred_str))
                        # out_str+=pred_str.replace('>','&gt;').replace('<','&lt;')+"<br/>"+gold_str.replace('>','&gt;').replace('<','&lt;')+"<hr/>"
                    if j==0:
                        title=chkpt_path
                        out_str = output_eval(title,pred_batch,  pred_ids, pred_lens, gold_batch, gold_lens, ctxt, ctxt_len, ans, ans_len)
                        with open(FLAGS.log_dir+'out_eval_'+FLAGS.model_type+'.htm', 'w', encoding='utf-8') as fp:
                            fp.write(out_str)

                f1summary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/f1",
                                                 simple_value=sum(f1s)/len(f1s))])
                bleusummary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/bleu",
                                          simple_value=sum(bleus)/len(bleus))])
                nllsummary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/nll",
                                   simple_value=sum(nlls)/len(nlls))])

                summary_writer.add_summary(f1summary, global_step=i)
                summary_writer.add_summary(bleusummary, global_step=i)
                summary_writer.add_summary(nllsummary, global_step=i)

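                # Early-stopping-style checkpointing: only keep the model when dev NLL improves
                # (except under PG or disc training, which save regardless, as below)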
                mean_nll=sum(nlls)/len(nlls)
                if mean_nll < best_oos_nll:
                    print("New best NLL! ", mean_nll, " Saving...")
                    best_oos_nll = mean_nll
                    saver.save(sess, chkpt_path+'/model.checkpoint', global_step=i)
                else:
                    print("NLL not improved ", mean_nll)
                    if FLAGS.policy_gradient:
                        print("Saving anyway")
                        saver.save(sess, chkpt_path+'/model.checkpoint', global_step=i)
                    if FLAGS.disc_train:
                        print("Saving disc")
                        discriminator.save_to_chkpt(FLAGS.model_dir, i)