Example #1
def main(_):
    train_data = loader.load_squad_triples("./data/", False)
    dev_data = loader.load_squad_triples("./data/", test=True)

    from tqdm import tqdm

    print('Loaded SQuAD with ', len(train_data), ' triples')
    train_contexts, train_qs, train_as, train_a_pos = zip(*train_data)
    _, dev_qs, _, _ = zip(*dev_data)

    lm = LstmLmInstance()
    lm.load_from_chkpt(FLAGS.model_dir + 'saved/lmtest')

    vocab = lm.vocab

    # random words, basic q, common words, real q, real context
    seq_batch = [
        "what played a chance to defend their title from super bowl xlix ?",
        "who were the defending super bowl champions ?",
        "what was the name of the company that tesla the public ? </Sent>",
        "what was the boat called ?",
        "Which NFL team represented the AFC at Super Bowl 50?",
        "which NFL team represented the <OOV> at <OOV> <OOV> <OOV> ?"
    ]
    # seq_batch=dev_qs[:5]

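    # lower perplexity should be assigned to the well-formed questions in the probe
    # batch above than to the corrupted ones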
    perps = lm.get_seq_perplexity(seq_batch)
    print(perps)
    print(seq_batch)

    perps = []
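    # score the whole dev set in fixed batches of 128 and report the mean perplexity
    # (any remainder after the last full batch is dropped)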
    num_steps = len(dev_qs) // 128
    for i in tqdm(range(num_steps)):
        perps.extend(lm.get_seq_perplexity(dev_qs[i * 128:(i + 1) * 128]))
    print(np.mean(perps))
Example #2
def main(_):
    import helpers.metrics as metrics
    from tqdm import tqdm

    # train_data = loader.load_squad_triples("./data/", False)
    dev_data = loader.load_squad_triples("./data/", test=True, ans_list=True)

    # print('Loaded SQuAD with ',len(train_data),' triples')
    # train_contexts, train_qs, train_as,train_a_pos = zip(*train_data)

    qa = MpcmQaInstance()
    qa.load_from_chkpt(FLAGS.model_dir + 'saved/qamaybe')
    vocab = qa.vocab

    questions = [
        "What colour is the car?", "When was the car made?",
        "Where was the date?", "What was the dog called?",
        "Who was the oldest cat?"
    ]
    contexts = [
        "The car is green, and was built in 1985. This sentence should make it less likely to return the date, when asked about a cat. The oldest cat was called creme puff and lived for many years!"
        for i in range(len(questions))
    ]

    # print(contexts[0])

    f1s = []
    ems = []
    for x in tqdm(dev_data):
        ans_pred = qa.get_ans([x[0]], [x[1]])[0]

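        # SQuAD provides multiple reference answers; score the prediction against
        # each and keep the best match, as in the official EM/F1 evaluation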
        this_f1s = []
        this_ems = []
        for a in range(len(x[2])):
            this_ems.append(1.0 * (metrics.normalize_answer(ans_pred)
                                   == metrics.normalize_answer(x[2][a])))
            this_f1s.append(
                metrics.f1(metrics.normalize_answer(ans_pred),
                           metrics.normalize_answer(x[2][a])))
        ems.append(max(this_ems))
        f1s.append(max(this_f1s))
    print("EM: ", np.mean(ems), " F1: ", np.mean(f1s))
Example #3
def main(_):

    train_data = loader.load_squad_triples("./data/", False)

    print('Loaded SQuAD with ', len(train_data), ' triples')
    train_contexts, train_qs, train_as, train_a_pos = zip(*train_data)

    qa = MpcmQaInstance()
    qa.load_from_chkpt(FLAGS.model_dir + 'saved/qatest')
    vocab = qa.vocab

    questions = [
        "What colour is the car?", "When was the car made?",
        "Where was the date?", "What was the dog called?",
        "Who was the oldest cat?"
    ]
    contexts = [
        "The car is green, and was built in 1985. This sentence should make it less likely to return the date, when asked about a cat. The oldest cat was called creme puff and lived for many years!"
        for i in range(len(questions))
    ]

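    # get_ans here returns (start, end) token indices into each context, which are
    # sliced out of the tokenised context below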
    spans = qa.get_ans(contexts, questions)
    print(contexts[0])
    for i, q in enumerate(questions):
        toks = tokenise(contexts[i], asbytes=False)
        print(q, "->", toks[spans[i,0]:spans[i,1]])
Example #4
def main(_):

    model_type=FLAGS.model_type
    # chkpt_path = FLAGS.model_dir+'saved/qgen-maluuba-crop-glove-smart'
    # chkpt_path = FLAGS.model_dir+'qgen-saved/MALUUBA-CROP-LATENT/1533247183'
    disc_path = FLAGS.model_dir+'saved/discriminator-trained-latent'
    chkpt_path = FLAGS.model_dir+'qgen/'+ model_type+'/'+FLAGS.eval_model_id

    # load dataset
    # train_data = loader.load_squad_triples(FLAGS.data_path, False)
    dev_data = loader.load_squad_triples(FLAGS.data_path, dev=FLAGS.eval_on_dev, test=FLAGS.eval_on_test)

    if len(dev_data) < FLAGS.num_eval_samples:
        exit('***ERROR*** Eval dataset is smaller than the num_eval_samples flag!')
    if len(dev_data) > FLAGS.num_eval_samples:
        print('***WARNING*** Eval dataset is larger than the num_eval_samples flag!')

    # train_contexts_unfilt, _,_,train_a_pos_unfilt = zip(*train_data)
    dev_contexts_unfilt, _,_,dev_a_pos_unfilt = zip(*dev_data)

    if FLAGS.filter_window_size_before >-1:
        # train_data = preprocessing.filter_squad(train_data, window_size=FLAGS.filter_window_size, max_tokens=FLAGS.filter_max_tokens)
        dev_data = preprocessing.filter_squad(dev_data, window_size_before=FLAGS.filter_window_size_before, window_size_after=FLAGS.filter_window_size_after, max_tokens=FLAGS.filter_max_tokens)


    # print('Loaded SQuAD with ',len(train_data),' triples')
    print('Loaded SQuAD dev set with ',len(dev_data),' triples')
    # train_contexts, train_qs, train_as,train_a_pos = zip(*train_data)
    dev_contexts, dev_qs, dev_as, dev_a_pos = zip(*dev_data)


    # vocab = loader.get_vocab(train_contexts, tf.app.flags.FLAGS.vocab_size)
    with open(chkpt_path+'/vocab.json') as f:
        vocab = json.load(f)

    with SquadStreamer(vocab, FLAGS.eval_batch_size, 1, shuffle=False) as dev_data_source:

        glove_embeddings = loader.load_glove(FLAGS.data_path)


        # Create model
        if model_type[:7] == "SEQ2SEQ":
            model = Seq2SeqModel(vocab, training_mode=False)
        elif model_type[:2] == "RL":
            # TEMP - no need to spin up the LM or QA model at eval time
            FLAGS.qa_weight = 0
            FLAGS.lm_weight = 0
            model = RLModel(vocab, training_mode=False)
        else:
            exit("Unrecognised model type: "+model_type)

        with model.graph.as_default():
            saver = tf.train.Saver()

        if FLAGS.eval_metrics:
            lm = LstmLmInstance()
            # qa = MpcmQaInstance()
            qa = QANetInstance()

            lm.load_from_chkpt(FLAGS.model_dir+'saved/lmtest')
            # qa.load_from_chkpt(FLAGS.model_dir+'saved/qatest')
            qa.load_from_chkpt(FLAGS.model_dir+'saved/qanet2')

            discriminator = DiscriminatorInstance(trainable=False, path=disc_path)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_limit)
        with tf.Session(graph=model.graph, config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            if not os.path.exists(chkpt_path):
                exit("Checkpoint path doesn't exist! " + chkpt_path)
            # summary_writer = tf.summary.FileWriter(FLAGS.log_directory+"eval/"+str(int(time.time())), sess.graph)

            saver.restore(sess, tf.train.latest_checkpoint(chkpt_path))
            # print('Loading not implemented yet')
            # else:
            #     sess.run(tf.global_variables_initializer())
            #     sess.run(model.glove_init_ops)

            num_steps = FLAGS.num_eval_samples//FLAGS.eval_batch_size

            # Initialise the dataset

            # np.random.shuffle(dev_data)
            dev_data_source.initialise(dev_data)

            f1s=[]
            bleus=[]
            qa_scores=[]
            qa_scores_gold=[]
            lm_scores=[]
            nlls=[]
            disc_scores=[]
            sowe_similarities=[]
            copy_probs=[]

            qgolds=[]
            qpreds=[]
            qpred_ids=[]
            qgold_ids=[]
            ctxts=[]
            answers=[]
            ans_positions=[]

            metric_individuals=[]
            res=[]
            for e in range(1):
                for i in tqdm(range(num_steps), desc='Epoch '+str(e)):
                    dev_batch, curr_batch_size = dev_data_source.get_batch()
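                    # one decode pass over this dev batch: fetch the beam-search question
                    # (strings, ids, lengths) plus the gold question, context, answer and NLL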
                    (pred_batch, pred_beam, pred_beam_lens, pred_ids, pred_lens,
                     gold_batch, gold_lens, gold_ids, ctxt, ctxt_len, ans,
                     ans_len, nll, copy_prob) = sess.run(
                         [
                             model.q_hat_beam_string, model.q_hat_full_beam_str,
                             model.q_hat_full_beam_lens, model.q_hat_beam_ids,
                             model.q_hat_beam_lens, model.question_raw,
                             model.question_length, model.question_ids,
                             model.context_raw, model.context_length,
                             model.answer_locs, model.answer_length, model.nll,
                             model.mean_copy_prob
                         ],
                         feed_dict={
                             model.input_batch: dev_batch,
                             model.is_training: False
                         })

                    unfilt_ctxt_batch = [dev_contexts_unfilt[ix] for ix in dev_batch[3]]
                    a_text_batch = ops.byte_token_array_to_str(dev_batch[2][0], dev_batch[2][2], is_array=False)
                    unfilt_apos_batch = [dev_a_pos_unfilt[ix] for ix in dev_batch[3]]

                    # subtract 1 to remove the "end sent token"
                    pred_q_batch = [q.replace(' </Sent>',"").replace(" <PAD>","") for q in ops.byte_token_array_to_str(pred_batch, pred_lens-1)]

                    ctxts.extend(unfilt_ctxt_batch)
                    answers.extend(a_text_batch)
                    ans_positions.extend([dev_a_pos_unfilt[ix] for ix in dev_batch[3]])
                    copy_probs.extend(copy_prob.tolist())



                    # get QA score

                    # gold_str=[]
                    # pred_str=[]


                    gold_ans = ops.byte_token_array_to_str(dev_batch[2][0], dev_batch[2][2], is_array=False)
                    # pred_str = ops.byte_token_array_to_str([dev_batch[0][0][b][qa_pred[b][0]:qa_pred[b][1]] for b in range(curr_batch_size)], is_array=False)
                    nlls.extend(nll.tolist())

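                    # optional extrinsic metrics: QA F1 of a QANet model answering the
                    # generated (and gold) question, LM perplexity and discriminator score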
                    if FLAGS.eval_metrics:
                        qa_pred = qa.get_ans(unfilt_ctxt_batch, ops.byte_token_array_to_str(pred_batch, pred_lens))
                        gold_qa_pred = qa.get_ans(unfilt_ctxt_batch, ops.byte_token_array_to_str(dev_batch[1][0], dev_batch[1][3]))

                        qa_score_batch = [metrics.f1(metrics.normalize_answer(gold_ans[b]), metrics.normalize_answer(qa_pred[b])) for b in range(curr_batch_size)]
                        qa_score_gold_batch = [metrics.f1(metrics.normalize_answer(gold_ans[b]), metrics.normalize_answer(gold_qa_pred[b])) for b in range(curr_batch_size)]
                        lm_score_batch = lm.get_seq_perplexity(pred_q_batch).tolist()
                        disc_score_batch = discriminator.get_pred(unfilt_ctxt_batch, pred_q_batch, gold_ans, unfilt_apos_batch).tolist()

                    for b, pred in enumerate(pred_batch):
                        pred_str = pred_q_batch[b].replace(' </Sent>',"").replace(" <PAD>","")
                        gold_str = tokens_to_string(gold_batch[b][:gold_lens[b]-1])
                        f1s.append(metrics.f1(gold_str, pred_str))
                        bleus.append(metrics.bleu(gold_str, pred_str))
                        qgolds.append(gold_str)
                        qpreds.append(pred_str)

                        # calc cosine similarity between sums of word embeddings
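                        # words missing from the GloVe vocabulary contribute a zero vector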
                        pred_sowe = np.sum(np.asarray([glove_embeddings[w] if w in glove_embeddings.keys() else np.zeros((FLAGS.embedding_size,)) for w in preprocessing.tokenise(pred_str ,asbytes=False)]) ,axis=0)
                        gold_sowe = np.sum(np.asarray([glove_embeddings[w] if w in glove_embeddings.keys() else np.zeros((FLAGS.embedding_size,)) for w in preprocessing.tokenise(gold_str ,asbytes=False)]) ,axis=0)
                        this_similarity = np.inner(pred_sowe, gold_sowe)/np.linalg.norm(pred_sowe, ord=2)/np.linalg.norm(gold_sowe, ord=2)

                        sowe_similarities.append(this_similarity)



                        this_metric_dict = {
                            'f1': f1s[-1],
                            'bleu': bleus[-1],
                            'nll': nlls[-1],
                            'sowe': sowe_similarities[-1]
                        }
                        if FLAGS.eval_metrics:
                            this_metric_dict = {
                                **this_metric_dict,
                                'qa': qa_score_batch[b],
                                'lm': lm_score_batch[b],
                                'disc': disc_score_batch[b]
                            }
                            qa_scores.extend(qa_score_batch)
                            lm_scores.extend(lm_score_batch)
                            disc_scores.extend(disc_score_batch)
                        metric_individuals.append(this_metric_dict)

                        res.append({
                            'c':unfilt_ctxt_batch[b],
                            'q_pred': pred_str,
                            'q_gold': gold_str,
                            'a_pos': unfilt_apos_batch[b],
                            'a_text': a_text_batch[b],
                            'metrics': this_metric_dict,

                            'q_pred_ids': pred_ids.tolist()[b],
                            'q_gold_ids': dev_batch[1][1][b].tolist()

                        })

                    # Quick output
                    if i==0:
                        # print(copy_prob.tolist())
                        # print(copy_probs)
                        pred_str = tokens_to_string(pred_batch[0][:pred_lens[0]-1])
                        gold_str = tokens_to_string(gold_batch[0][:gold_lens[0]-1])
                        # print(pred_str)
                        print(qpreds[0])
                        print(gold_str)


                        title=chkpt_path
                        out_str = output_eval(title,pred_batch,  pred_ids, pred_lens, gold_batch, gold_lens, ctxt, ctxt_len, ans, ans_len)
                        with open(FLAGS.log_directory+'out_eval_'+model_type+'.htm', 'w', encoding='utf-8') as fp:
                            fp.write(out_str)

            # res = list(zip(qpreds,qgolds,ctxts,answers,ans_positions,metric_individuals))
            metric_dict = {
                'f1': np.mean(f1s),
                'bleu': metrics.bleu_corpus(qgolds, qpreds),
                'nll': np.mean(nlls),
                'sowe': np.mean(sowe_similarities)
            }
            if FLAGS.eval_metrics:
                metric_dict = {
                    **metric_dict,
                    'qa': np.mean(qa_scores),
                    'lm': np.mean(lm_scores),
                    'disc': np.mean(disc_scores)
                }
            # print(res)
            with open(FLAGS.log_directory+'out_eval_'+model_type+("_test" if FLAGS.eval_on_test else "")+("_train" if (not FLAGS.eval_on_dev and not FLAGS.eval_on_test) else "")+'.json', 'w', encoding='utf-8') as fp:
                json.dump({"metrics":metric_dict, "results": res}, fp)


            print("F1: ", np.mean(f1s))
            print("BLEU: ", metrics.bleu_corpus(qgolds, qpreds))
            print("NLL: ", np.mean(nlls))
            print("SOWE: ", np.mean(sowe_similarities))

            print("Copy prob: ", np.mean(copy_probs))
            if FLAGS.eval_metrics:
                print("QA: ", np.mean(qa_scores))
                print("LM: ", np.mean(lm_scores))
                print("Disc: ", np.mean(disc_scores))
Example #5
def main(_):
    from tqdm import tqdm
    FLAGS = tf.app.flags.FLAGS

    # questions = ["What colour is the car?","When was the car made?","Where was the date?", "What was the dog called?","Who was the oldest cat?"]
    # contexts=["The car is green, and was built in 1985. This sentence should make it less likely to return the date, when asked about a cat. The oldest cat was called creme puff and lived for many years!" for i in range(len(questions))]

    trainable = False

    squad_train_full = loader.load_squad_triples(path="./data/")
    squad_dev_full = loader.load_squad_triples(path="./data/",
                                               dev=True,
                                               ans_list=True)

    para_limit = FLAGS.test_para_limit
    ques_limit = FLAGS.test_ques_limit
    char_limit = FLAGS.char_limit

    def filter_func(example, is_test=False):
        return len(example["context_tokens"]) > para_limit or \
               len(example["ques_tokens"]) > ques_limit or \
               (example["y2s"][0] - example["y1s"][0]) > ans_limit
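    # NB: filter_func is not used below, and it refers to ans_limit, which is not
    # defined in this snippet; it would presumably need to come from FLAGS before use.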

    qa = QANetInstance()
    qa.load_from_chkpt("./models/saved/qanet2/", trainable=trainable)

    squad_train = []
    for x in squad_train_full:
        c_toks = word_tokenize(x[0])
        q_toks = word_tokenize(x[1])
        if len(c_toks) < para_limit and len(q_toks) < ques_limit:
            squad_train.append(x)

    squad_dev = []
    for x in squad_dev_full:
        c_toks = word_tokenize(x[0])
        q_toks = word_tokenize(x[1])
        if len(c_toks) < para_limit and len(q_toks) < ques_limit:
            squad_dev.append(x)

    num_train_steps = len(squad_train) // FLAGS.batch_size
    num_eval_steps = len(squad_dev) // FLAGS.batch_size

    best_f1 = 0
    if trainable:
        run_id = str(int(time.time()))
        chkpt_path = FLAGS.model_dir + 'qanet/' + run_id
        if not os.path.exists(chkpt_path):
            os.makedirs(chkpt_path)

        summary_writer = tf.summary.FileWriter(
            FLAGS.log_directory + 'qanet/' + run_id, qa.model.graph)
        for i in tqdm(range(FLAGS.qa_num_epochs * num_train_steps)):
            if i % num_train_steps == 0:
                print('Shuffling training set')
                np.random.shuffle(squad_train)

            # index the batch by the within-epoch step, otherwise the slice runs
            # off the end of the training set after the first epoch
            step = i % num_train_steps
            this_batch = squad_train[step * FLAGS.batch_size:(step + 1) *
                                     FLAGS.batch_size]
            batch_contexts, batch_questions, batch_ans_text, batch_ans_charpos = zip(
                *this_batch)

            batch_answers = []
            for j, ctxt in enumerate(batch_contexts):
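                # convert the character-level answer position into a token span, then
                # one-hot encode the start and end indices up to the QANet paragraph limit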
                ans_span = char_pos_to_word(
                    ctxt.encode(), [t.encode() for t in word_tokenize(ctxt)],
                    batch_ans_charpos[j])
                ans_span = (np.eye(FLAGS.test_para_limit)[ans_span],
                            np.eye(FLAGS.test_para_limit)
                            [ans_span + len(word_tokenize(batch_ans_text[j])) -
                             1])
                batch_answers.append(ans_span)
            this_loss = qa.train_step(batch_contexts, batch_questions,
                                      batch_answers)

            if i % 50 == 0:
                losssummary = tf.Summary(value=[
                    tf.Summary.Value(tag="train_loss/loss",
                                     simple_value=np.mean(this_loss))
                ])

                summary_writer.add_summary(losssummary, global_step=i)

            if i > 0 and i % 1000 == 0:
                qa_f1s = []
                qa_em = []

                for j in tqdm(range(num_eval_steps)):
                    this_batch = squad_dev[j * FLAGS.batch_size:(j + 1) *
                                           FLAGS.batch_size]

                    spans = qa.get_ans([x[0] for x in this_batch],
                                       [x[1] for x in this_batch])

                    for b in range(len(this_batch)):
                        qa_f1s.append(
                            metrics.f1(
                                metrics.normalize_answer(this_batch[b][2]),
                                metrics.normalize_answer(spans[b])))
                        qa_em.append(
                            1.0 * (metrics.normalize_answer(this_batch[b][2])
                                   == metrics.normalize_answer(spans[b])))

                f1summary = tf.Summary(value=[
                    tf.Summary.Value(tag="dev_perf/f1",
                                     simple_value=np.mean(qa_f1s))
                ])

                summary_writer.add_summary(f1summary, global_step=i)
                if np.mean(qa_f1s) > best_f1:
                    print("New best F1! ", np.mean(qa_f1s), " Saving...")
                    best_f1 = np.mean(qa_f1s)
                    qa.saver.save(qa.sess, chkpt_path + '/model.checkpoint')

    qa_f1s = []
    qa_em = []

    for i in tqdm(range(num_eval_steps)):
        this_batch = squad_dev[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]

        spans = qa.get_ans([x[0] for x in this_batch],
                           [x[1] for x in this_batch])

        for b in range(len(this_batch)):
            this_f1s = []
            this_em = []
            for a in range(len(this_batch[b][2])):
                this_f1s.append(
                    metrics.f1(metrics.normalize_answer(this_batch[b][2][a]),
                               metrics.normalize_answer(spans[b])))
                this_em.append(1.0 *
                               (metrics.normalize_answer(this_batch[b][2][a])
                                == metrics.normalize_answer(spans[b])))
            qa_em.append(max(this_em))
            qa_f1s.append(max(this_f1s))

        if i == 0:
            print(qa_f1s, qa_em)
            print(this_batch[0])
            print(spans[0])

    print('EM: ', np.mean(qa_em))
    print('F1: ', np.mean(qa_f1s))
Example #6
import helpers.metrics as metrics

from langmodel.lm import LstmLmInstance
from qa.mpcm import MpcmQaInstance

import flags
import tensorflow as tf
import numpy as np
from tqdm import tqdm
FLAGS = tf.app.flags.FLAGS

import baseline_model

url = 'http://localhost:9000/?properties={"annotators":"openie,ner","outputFormat":"json","openie.affinity_probability_cap":0.01}'

train_data = loader.load_squad_triples('./data/', True)[:1500]


def get_q_word(ner):
    if ner in ["MISC", "UNK", "IDEOLOGY", "RELIGION"]:
        return "what"
    elif ner in ["PERSON", "ORGANIZATION", "TITLE"]:
        return "who"
    elif ner in ["NUMBER", "MONEY"]:
        return "how many"
    elif ner in ["DATE", "TIME"]:
        return "when"
    elif ner in ["STATE_OR_PROVINCE", "COUNTRY", "CITY", "LOCATION"]:
        return "where"
    elif ner in ["DURATION"]:
        return "how long"
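    # NER tags not matched above fall through and return None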
Example #7
def main(_):
    if FLAGS.testing:
        print('TEST MODE - reducing model size')
        FLAGS.context_encoder_units = 100
        FLAGS.answer_encoder_units = 100
        FLAGS.decoder_units = 100
        FLAGS.batch_size = 8
        FLAGS.eval_batch_size = 8
        # FLAGS.embedding_size=50

    run_id = str(int(time.time()))
    chkpt_path = FLAGS.model_dir + 'qgen/' + FLAGS.model_type + '/' + run_id
    restore_path = FLAGS.model_dir + 'qgen/' + FLAGS.restore_path if FLAGS.restore_path is not None else None  #'MALUUBA-CROP-LATENT'+'/'+'1534123959'
    # restore_path=FLAGS.model_dir+'saved/qgen-maluuba-crop-glove-smart'
    disc_path = FLAGS.model_dir + 'saved/discriminator-trained-latent'

    print("Run ID is ", run_id)
    print("Model type is ", FLAGS.model_type)

    if not os.path.exists(chkpt_path):
        os.makedirs(chkpt_path)

    # load dataset
    train_data = loader.load_squad_triples(FLAGS.data_path, False)
    dev_data = loader.load_squad_triples(FLAGS.data_path, True)

    train_contexts_unfilt, _, ans_text_unfilt, ans_pos_unfilt = zip(
        *train_data)
    dev_contexts_unfilt, _, dev_ans_text_unfilt, dev_ans_pos_unfilt = zip(
        *dev_data)

    if FLAGS.testing:
        train_data = train_data[:1000]
        num_dev_samples = 100
    else:
        num_dev_samples = FLAGS.num_dev_samples

    if FLAGS.filter_window_size_before > -1:
        train_data = preprocessing.filter_squad(
            train_data,
            window_size_before=FLAGS.filter_window_size_before,
            window_size_after=FLAGS.filter_window_size_after,
            max_tokens=FLAGS.filter_max_tokens)
        dev_data = preprocessing.filter_squad(
            dev_data,
            window_size_before=FLAGS.filter_window_size_before,
            window_size_after=FLAGS.filter_window_size_after,
            max_tokens=FLAGS.filter_max_tokens)

    print('Loaded SQuAD with ', len(train_data), ' triples')
    train_contexts, train_qs, train_as, train_a_pos = zip(*train_data)

    if FLAGS.restore:
        if restore_path is None:
            exit('You need to specify a restore path!')
        with open(restore_path + '/vocab.json', encoding="utf-8") as f:
            vocab = json.load(f)
    elif FLAGS.glove_vocab:
        vocab = loader.get_glove_vocab(FLAGS.data_path,
                                       size=FLAGS.vocab_size,
                                       d=FLAGS.embedding_size)
        with open(chkpt_path + '/vocab.json', 'w',
                  encoding="utf-8") as outfile:
            json.dump(vocab, outfile)
    else:
        vocab = loader.get_vocab(train_contexts + train_qs, FLAGS.vocab_size)
        with open(chkpt_path + '/vocab.json', 'w',
                  encoding="utf-8") as outfile:
            json.dump(vocab, outfile)

    # Create model
    if FLAGS.model_type[:7] == "SEQ2SEQ":
        model = Seq2SeqModel(vocab,
                             training_mode=True,
                             use_embedding_loss=FLAGS.embedding_loss)
    elif FLAGS.model_type[:7] == "MALUUBA":
        # TEMP
        if not FLAGS.policy_gradient:
            FLAGS.qa_weight = 0
            FLAGS.lm_weight = 0
        model = MaluubaModel(vocab,
                             training_mode=True,
                             use_embedding_loss=FLAGS.embedding_loss)
        # if FLAGS.model_type[:10] == "MALUUBA_RL":
        #     qa_vocab=model.qa.vocab
        #     lm_vocab=model.lm.vocab
        if FLAGS.policy_gradient:
            discriminator = DiscriminatorInstance(trainable=FLAGS.disc_train,
                                                  path=disc_path)
    else:
        exit("Unrecognised model type: " + FLAGS.model_type)

    # create data streamer
    with SquadStreamer(vocab, FLAGS.batch_size, FLAGS.num_epochs,
                       shuffle=True) as train_data_source, SquadStreamer(
                           vocab, FLAGS.eval_batch_size, 1,
                           shuffle=True) as dev_data_source:

        with model.graph.as_default():
            saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)

        # change visible devices if using RL models
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_limit,
                                    visible_device_list='0',
                                    allow_growth=True)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,
                                              allow_soft_placement=False),
                        graph=model.graph) as sess:

            summary_writer = tf.summary.FileWriter(
                FLAGS.log_dir + 'qgen/' + FLAGS.model_type + '/' + run_id,
                sess.graph)

            train_data_source.initialise(train_data)

            num_steps_train = len(train_data) // FLAGS.batch_size
            num_steps_dev = num_dev_samples // FLAGS.eval_batch_size

            if FLAGS.restore:
                saver.restore(sess, tf.train.latest_checkpoint(restore_path))
                start_e = 15  #FLAGS.num_epochs
                print('Loaded model')
            else:
                start_e = 0
                sess.run(tf.global_variables_initializer())
                # sess.run(model.glove_init_ops)

                f1summary = tf.Summary(value=[
                    tf.Summary.Value(tag="dev_perf/f1", simple_value=0.0)
                ])
                bleusummary = tf.Summary(value=[
                    tf.Summary.Value(tag="dev_perf/bleu", simple_value=0.0)
                ])

                summary_writer.add_summary(f1summary, global_step=0)
                summary_writer.add_summary(bleusummary, global_step=0)

            # Initialise the dataset
            # sess.run(model.iterator.initializer, feed_dict={model.context_ph: train_contexts,
            #                                   model.qs_ph: train_qs, model.as_ph: train_as, model.a_pos_ph: train_a_pos})

            best_oos_nll = 1e6

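            # running mean/variance trackers, used later to whiten the RL rewards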
            lm_score_moments = online_moments.OnlineMoment()
            qa_score_moments = online_moments.OnlineMoment()
            disc_score_moments = online_moments.OnlineMoment()

            # for e in range(start_e,start_e+FLAGS.num_epochs):
            # Train for one epoch
            for i in tqdm(range(num_steps_train * FLAGS.num_epochs),
                          desc='Training'):
                # Get a batch
                train_batch, curr_batch_size = train_data_source.get_batch()

                # Are we doing policy gradient? Do a forward pass first, then build the PG batch and do an update step
                if FLAGS.model_type[:10] == "MALUUBA_RL" and FLAGS.policy_gradient:

                    # do a fwd pass first, get the score, then do another pass and optimize
                    qhat_str, qhat_ids, qhat_lens = sess.run(
                        [
                            model.q_hat_beam_string, model.q_hat_beam_ids,
                            model.q_hat_beam_lens
                        ],
                        feed_dict={
                            model.input_batch: train_batch,
                            model.is_training: FLAGS.pg_dropout,
                            model.hide_answer_in_copy: True
                        })

                    # The output is as long as the max allowed len - remove the pointless extra padding
                    qhat_ids = qhat_ids[:, :np.max(qhat_lens)]
                    qhat_str = qhat_str[:, :np.max(qhat_lens)]

                    pred_str = byte_token_array_to_str(qhat_str, qhat_lens - 1)
                    gold_q_str = byte_token_array_to_str(
                        train_batch[1][0], train_batch[1][3])

                    # Get reward values
                    lm_score = (-1 * model.lm.get_seq_perplexity(pred_str)
                                ).tolist()  # lower perplexity is better

                    # retrieve the uncropped context for QA evaluation
                    unfilt_ctxt_batch = [
                        train_contexts_unfilt[ix] for ix in train_batch[3]
                    ]
                    ans_text_batch = [
                        ans_text_unfilt[ix] for ix in train_batch[3]
                    ]
                    ans_pos_batch = [
                        ans_pos_unfilt[ix] for ix in train_batch[3]
                    ]

                    qa_pred = model.qa.get_ans(unfilt_ctxt_batch, pred_str)
                    qa_pred_gold = model.qa.get_ans(unfilt_ctxt_batch,
                                                    gold_q_str)

                    # gold_str=[]
                    # pred_str=[]
                    qa_f1s = []
                    gold_ans_str = byte_token_array_to_str(train_batch[2][0],
                                                           train_batch[2][2],
                                                           is_array=False)

                    qa_f1s.extend([
                        metrics.f1(metrics.normalize_answer(gold_ans_str[b]),
                                   metrics.normalize_answer(qa_pred[b]))
                        for b in range(curr_batch_size)
                    ])

                    disc_scores = discriminator.get_pred(
                        unfilt_ctxt_batch, pred_str, ans_text_batch,
                        ans_pos_batch)

                    if i > FLAGS.pg_burnin // 2:
                        lm_score_moments.push(lm_score)
                        qa_score_moments.push(qa_f1s)
                        disc_score_moments.push(disc_scores)

                    # print(disc_scores)
                    # print((e-start_e)*num_steps_train+i, flags.pg_burnin)

                    if i > FLAGS.pg_burnin:
                        # A variant of popart
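                        # whiten each reward stream with its running mean and variance so
                        # the LM, QA and discriminator rewards are on a comparable scale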
                        qa_score_whitened = (
                            qa_f1s - qa_score_moments.mean
                        ) / np.sqrt(qa_score_moments.variance + 1e-6)
                        lm_score_whitened = (
                            lm_score - lm_score_moments.mean
                        ) / np.sqrt(lm_score_moments.variance + 1e-6)
                        disc_score_whitened = (
                            disc_scores - disc_score_moments.mean
                        ) / np.sqrt(disc_score_moments.variance + 1e-6)

                        lm_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/lm",
                                             simple_value=np.mean(lm_score))
                        ])
                        summary_writer.add_summary(lm_summary, global_step=(i))
                        qa_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/qa",
                                             simple_value=np.mean(qa_f1s))
                        ])
                        summary_writer.add_summary(qa_summary, global_step=(i))
                        disc_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/disc",
                                             simple_value=np.mean(disc_scores))
                        ])
                        summary_writer.add_summary(disc_summary,
                                                   global_step=(i))

                        lm_white_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/lm_white",
                                             simple_value=np.mean(
                                                 lm_score_whitened))
                        ])
                        summary_writer.add_summary(lm_white_summary,
                                                   global_step=(i))
                        qa_white_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/qa_white",
                                             simple_value=np.mean(
                                                 qa_score_whitened))
                        ])
                        summary_writer.add_summary(qa_white_summary,
                                                   global_step=(i))
                        disc_white_summary = tf.Summary(value=[
                            tf.Summary.Value(tag="rl_rewards/disc_white",
                                             simple_value=np.mean(
                                                 disc_score_whitened))
                        ])
                        summary_writer.add_summary(disc_white_summary,
                                                   global_step=(i))

                        # Build a combined batch - half ground truth for MLE, half generated for PG
                        train_batch_ext = duplicate_batch_and_inject(
                            train_batch, qhat_ids, qhat_str, qhat_lens)

                        # print(qhat_ids)
                        # print(qhat_lens)
                        # print(train_batch_ext[2][2])

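                        # first half of each reward vector scores the generated questions
                        # (PG term); the appended constants cover the ground-truth half
                        # (the ML weight for lm_score, zero for the others)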
                        rl_dict = {
                            model.lm_score:
                            np.asarray((lm_score_whitened *
                                        FLAGS.lm_weight).tolist() + [
                                            FLAGS.pg_ml_weight
                                            for b in range(curr_batch_size)
                                        ]),
                            model.qa_score:
                            np.asarray((qa_score_whitened *
                                        FLAGS.qa_weight).tolist() +
                                       [0 for b in range(curr_batch_size)]),
                            model.disc_score:
                            np.asarray((disc_score_whitened *
                                        FLAGS.disc_weight).tolist() +
                                       [0 for b in range(curr_batch_size)]),
                            model.rl_lm_enabled:
                            True,
                            model.rl_qa_enabled:
                            True,
                            model.rl_disc_enabled:
                            FLAGS.disc_weight > 0,
                            model.step:
                            i - FLAGS.pg_burnin,
                            model.hide_answer_in_copy:
                            True
                        }

                        # perform a policy gradient step, but combine with a XE step by using appropriate rewards
                        ops = [
                            model.pg_optimizer, model.train_summary,
                            model.q_hat_string
                        ]
                        if i % FLAGS.eval_freq == 0:
                            ops.extend([
                                model.q_hat_ids, model.question_ids,
                                model.copy_prob, model.question_raw,
                                model.question_length
                            ])
                            res_offset = 5
                        else:
                            res_offset = 0
                        ops.extend([model.lm_loss, model.qa_loss])
                        res = sess.run(ops,
                                       feed_dict={
                                           model.input_batch: train_batch_ext,
                                           model.is_training: False,
                                           **rl_dict
                                       })
                        summary_writer.add_summary(res[1], global_step=(i))

                        # Log only the first half of the PG related losses
                        lm_loss_summary = tf.Summary(value=[
                            tf.Summary.Value(
                                tag="train_loss/lm",
                                simple_value=np.mean(res[3 + res_offset]
                                                     [:curr_batch_size]))
                        ])
                        summary_writer.add_summary(lm_loss_summary,
                                                   global_step=(i))
                        qa_loss_summary = tf.Summary(value=[
                            tf.Summary.Value(
                                tag="train_loss/qa",
                                simple_value=np.mean(res[4 + res_offset]
                                                     [:curr_batch_size]))
                        ])
                        summary_writer.add_summary(qa_loss_summary,
                                                   global_step=(i))

                    # TODO: more principled scheduling here than alternating steps
                    if FLAGS.disc_train:
                        ixs = np.round(
                            np.random.binomial(1, 0.5, curr_batch_size))
                        qbatch = [
                            pred_str[ix].replace(" </Sent>", "").replace(
                                " <PAD>", "")
                            if ixs[ix] < 0.5 else gold_q_str[ix].replace(
                                " </Sent>", "").replace(" <PAD>", "")
                            for ix in range(curr_batch_size)
                        ]

                        loss = discriminator.train_step(unfilt_ctxt_batch,
                                                        qbatch,
                                                        ans_text_batch,
                                                        ans_pos_batch,
                                                        ixs,
                                                        step=(i))

                else:
                    # Normal single pass update step. If model has PG capability, fill in the placeholders with empty values
                    if FLAGS.model_type[:7] == "MALUUBA" and not FLAGS.policy_gradient:
                        rl_dict = {
                            model.lm_score:
                            [0 for b in range(curr_batch_size)],
                            model.qa_score:
                            [0 for b in range(curr_batch_size)],
                            model.disc_score:
                            [0 for b in range(curr_batch_size)],
                            model.rl_lm_enabled: False,
                            model.rl_qa_enabled: False,
                            model.rl_disc_enabled: False,
                            model.hide_answer_in_copy: False
                        }
                    else:
                        rl_dict = {}

                    # Perform a normal optimizer step
                    ops = [
                        model.optimizer, model.train_summary,
                        model.q_hat_string
                    ]
                    if i % FLAGS.eval_freq == 0:
                        ops.extend([
                            model.q_hat_ids, model.question_ids,
                            model.copy_prob, model.question_raw,
                            model.question_length
                        ])
                    res = sess.run(ops,
                                   feed_dict={
                                       model.input_batch: train_batch,
                                       model.is_training: True,
                                       **rl_dict
                                   })
                    summary_writer.add_summary(res[1], global_step=(i))

                # Dump some output periodically
                if i > 0 and i % FLAGS.eval_freq == 0 and (
                        i > FLAGS.pg_burnin or not FLAGS.policy_gradient):
                    with open(FLAGS.log_dir + 'out.htm', 'w',
                              encoding='utf-8') as fp:
                        fp.write(
                            output_pretty(res[2].tolist(), res[3], res[4],
                                          res[5], 0, i))
                    gold_batch = res[6]
                    gold_lens = res[7]
                    f1s = []
                    bleus = []
                    for b, pred in enumerate(res[2]):
                        pred_str = tokens_to_string(pred[:gold_lens[b] - 1])
                        gold_str = tokens_to_string(
                            gold_batch[b][:gold_lens[b] - 1])
                        f1s.append(metrics.f1(gold_str, pred_str))
                        bleus.append(metrics.bleu(gold_str, pred_str))

                    f1summary = tf.Summary(value=[
                        tf.Summary.Value(tag="train_perf/f1",
                                         simple_value=sum(f1s) / len(f1s))
                    ])
                    bleusummary = tf.Summary(value=[
                        tf.Summary.Value(tag="train_perf/bleu",
                                         simple_value=sum(bleus) / len(bleus))
                    ])

                    summary_writer.add_summary(f1summary, global_step=(i))
                    summary_writer.add_summary(bleusummary, global_step=(i))

                    # Evaluate against dev set
                    f1s = []
                    bleus = []
                    nlls = []

                    np.random.shuffle(dev_data)
                    dev_subset = dev_data[:num_dev_samples]
                    dev_data_source.initialise(dev_subset)
                    for j in tqdm(range(num_steps_dev), desc='Eval ' + str(i)):
                        dev_batch, curr_batch_size = dev_data_source.get_batch(
                        )
                        pred_batch, pred_ids, pred_lens, gold_batch, gold_lens, ctxt, ctxt_len, ans, ans_len, nll = sess.run(
                            [
                                model.q_hat_beam_string, model.q_hat_beam_ids,
                                model.q_hat_beam_lens, model.question_raw,
                                model.question_length, model.context_raw,
                                model.context_length, model.answer_locs,
                                model.answer_length, model.nll
                            ],
                            feed_dict={
                                model.input_batch: dev_batch,
                                model.is_training: False
                            })

                        nlls.extend(nll.tolist())
                        # out_str="<h1>"+str(e)+' - '+str(datetime.datetime.now())+'</h1>'
                        for b, pred in enumerate(pred_batch):
                            pred_str = tokens_to_string(
                                pred[:pred_lens[b] - 1]).replace(
                                    ' </Sent>', "").replace(" <PAD>", "")
                            gold_str = tokens_to_string(
                                gold_batch[b][:gold_lens[b] - 1])
                            f1s.append(metrics.f1(gold_str, pred_str))
                            bleus.append(metrics.bleu(gold_str, pred_str))
                            # out_str+=pred_str.replace('>','&gt;').replace('<','&lt;')+"<br/>"+gold_str.replace('>','&gt;').replace('<','&lt;')+"<hr/>"
                        if j == 0:
                            title = chkpt_path
                            out_str = output_eval(title, pred_batch, pred_ids,
                                                  pred_lens, gold_batch,
                                                  gold_lens, ctxt, ctxt_len,
                                                  ans, ans_len)
                            with open(FLAGS.log_dir + 'out_eval_' +
                                      FLAGS.model_type + '.htm',
                                      'w',
                                      encoding='utf-8') as fp:
                                fp.write(out_str)

                    f1summary = tf.Summary(value=[
                        tf.Summary.Value(tag="dev_perf/f1",
                                         simple_value=sum(f1s) / len(f1s))
                    ])
                    bleusummary = tf.Summary(value=[
                        tf.Summary.Value(tag="dev_perf/bleu",
                                         simple_value=sum(bleus) / len(bleus))
                    ])
                    nllsummary = tf.Summary(value=[
                        tf.Summary.Value(tag="dev_perf/nll",
                                         simple_value=sum(nlls) / len(nlls))
                    ])

                    summary_writer.add_summary(f1summary, global_step=i)
                    summary_writer.add_summary(bleusummary, global_step=i)
                    summary_writer.add_summary(nllsummary, global_step=i)

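                    # keep the checkpoint with the best held-out NLL; policy-gradient runs
                    # also save at every eval regardless, and the discriminator separately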
                    mean_nll = sum(nlls) / len(nlls)
                    if mean_nll < best_oos_nll:
                        print("New best NLL! ", mean_nll, " Saving...")
                        best_oos_nll = mean_nll
                        saver.save(sess,
                                   chkpt_path + '/model.checkpoint',
                                   global_step=i)
                    else:
                        print("NLL not improved ", mean_nll)
                        if FLAGS.policy_gradient:
                            print("Saving anyway")
                            saver.save(sess,
                                       chkpt_path + '/model.checkpoint',
                                       global_step=i)
                        if FLAGS.disc_train:
                            print("Saving disc")
                            discriminator.save_to_chkpt(FLAGS.model_dir, i)
Example #8
def main(_):
    train_data = loader.load_squad_triples(FLAGS.data_path, False)
    dev_data = loader.load_squad_triples(FLAGS.data_path, True)[:500]

    chkpt_path = FLAGS.model_dir + 'saved/qatest'
    # chkpt_path = FLAGS.model_dir+'qa/1528885583'

    print('Loaded SQuAD with ', len(train_data), ' triples')
    train_contexts, train_qs, train_as, train_a_pos = zip(*train_data)
    dev_contexts, dev_qs, dev_as, dev_a_pos = zip(*dev_data)

    # vocab = loader.get_vocab(train_contexts, tf.app.flags.FLAGS.qa_vocab_size)
    with open(chkpt_path + '/vocab.json') as f:
        vocab = json.load(f)

    model = MpcmQa(vocab, training_mode=False)
    with model.graph.as_default():
        saver = tf.train.Saver()

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_limit)
    with tf.Session(graph=model.graph,
                    config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        if not os.path.exists(chkpt_path):
            os.makedirs(chkpt_path)
        summary_writer = tf.summary.FileWriter(
            FLAGS.log_dir + 'qa/' + str(int(time.time())), sess.graph)

        saver.restore(sess, chkpt_path + '/model.checkpoint')

        num_steps = len(dev_data) // FLAGS.batch_size

        f1s = []
        exactmatches = []
        for e in range(1):
            np.random.shuffle(train_data)
            train_contexts, train_qs, train_as, train_a_pos = zip(*train_data)
            for i in tqdm(range(num_steps), desc='Epoch ' + str(e)):
                # TODO: this keeps coming up - refactor it
                batch_contexts = dev_contexts[i * FLAGS.batch_size:(i + 1) *
                                              FLAGS.batch_size]
                batch_questions = dev_qs[i * FLAGS.batch_size:(i + 1) *
                                         FLAGS.batch_size]
                batch_ans_text = dev_as[i * FLAGS.batch_size:(i + 1) *
                                        FLAGS.batch_size]
                batch_answer_charpos = dev_a_pos[i * FLAGS.batch_size:(i + 1) *
                                                 FLAGS.batch_size]

                batch_answers = []
                for j, ctxt in enumerate(batch_contexts):
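                    # map the character-level answer start to a word index; the span end
                    # is the start plus the answer length in tokens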
                    ans_span = char_pos_to_word(
                        ctxt.encode(),
                        [t.encode() for t in tokenise(ctxt, asbytes=False)],
                        batch_answer_charpos[j])
                    ans_span = (
                        ans_span, ans_span +
                        len(tokenise(batch_ans_text[j], asbytes=False)))
                    batch_answers.append(ans_span)

                # print(batch_answers[:3])
                # exit()

                summ, pred = sess.run(
                    [model.eval_summary, model.pred_span],
                    feed_dict={
                        model.context_in:
                        get_padded_batch(batch_contexts, vocab),
                        model.question_in:
                        get_padded_batch(batch_questions, vocab),
                        model.answer_spans_in: batch_answers,
                        model.is_training: False
                    })

                summary_writer.add_summary(summ,
                                           global_step=(e * num_steps + i))

                gold_str = []
                pred_str = []
                for b in range(FLAGS.batch_size):
                    gold_str.append(" ".join(
                        tokenise(batch_contexts[b], asbytes=False)
                        [batch_answers[b][0]:batch_answers[b][1]]))
                    pred_str.append(" ".join(
                        tokenise(batch_contexts[b],
                                 asbytes=False)[pred[b][0]:pred[b][1]]))

                f1s.extend([
                    f1(gold_str[b], pred_str[b])
                    for b in range(FLAGS.batch_size)
                ])
                exactmatches.extend([
                    np.product(pred[b] == batch_answers[b]) * 1.0
                    for b in range(FLAGS.batch_size)
                ])

                if i % FLAGS.eval_freq == 0:
                    out_str = "<h1>" + "Eval - Dev set" + "</h1>"
                    for b in range(FLAGS.batch_size):
                        out_str += batch_contexts[b] + '<br/>'
                        out_str += batch_questions[b] + '<br/>'
                        out_str += str(batch_answers[b]) + str(
                            tokenise(batch_contexts[b], asbytes=False)
                            [batch_answers[b][0]:batch_answers[b][1]]
                        ) + '<br/>'
                        out_str += str(pred[b]) + str(
                            tokenise(batch_contexts[b], asbytes=False)
                            [pred[b][0]:pred[b][1]]) + '<br/>'
                        out_str += batch_ans_text[b] + '<br/>'
                        out_str += pred_str[b] + '<br/>'
                        out_str += "F1: " + str(f1(gold_str[b],
                                                   pred_str[b])) + '<br/>'
                        out_str += "EM: " + str(
                            np.product(pred[b] == batch_answers[b]) * 1.0)
                        out_str += "<hr/>"
                    with open(FLAGS.log_dir + 'out_qa_eval.htm', 'w') as fp:
                        fp.write(out_str)
        print("F1: ", np.mean(f1s))
        print("EM: ", np.mean(exactmatches))
Example #9
import sys
from time import time
sys.path.insert(0, "/Users/tom/Dropbox/msc-ml/project/src/")

from collections import Counter

from helpers import loader, preprocessing

# from qa.qanet.prepro import word_tokenize

import string
import matplotlib.pyplot as plt
import numpy as np

squad = loader.load_squad_triples('./data/', False, v2=False)  # [9654:9655]
squad_dev = loader.load_squad_triples('./data/', True, v2=False)  # [9654:9655]

# glove_vocab = set(loader.get_glove_vocab('./data/', size=1e12, d=200).keys())
# glove_short = list(loader.get_glove_vocab('./data/', size=2000, d=200).keys())[4:]

squad_vocab = set()
squad_count = Counter()

start = time()
max_context_len = 0
max_pos = None
debugstr = ""

c_lens = []
q_lens = []
Example #10
def main(_):
    model = FileLoaderModel('./models/BASELINE')
    squad = loader.load_squad_triples(FLAGS.data_path, True, as_dict=True)

    disc_path = FLAGS.model_dir + 'saved/discriminator-trained-latent'

    glove_embeddings = loader.load_glove(FLAGS.data_path)

    if FLAGS.eval_metrics:
        lm = LstmLmInstance()
        # qa = MpcmQaInstance()
        qa = QANetInstance()

        lm.load_from_chkpt(FLAGS.model_dir + 'saved/lmtest')
        # qa.load_from_chkpt(FLAGS.model_dir+'saved/qatest')
        qa.load_from_chkpt(FLAGS.model_dir + 'saved/qanet')

        discriminator = DiscriminatorInstance(trainable=False, path=disc_path)

    f1s = []
    bleus = []
    qa_scores = []
    qa_scores_gold = []
    lm_scores = []
    nlls = []
    disc_scores = []
    sowe_similarities = []

    qgolds = []
    qpreds = []
    ctxts = []
    answers = []
    ans_positions = []

    metric_individuals = []
    res = []

    missing = 0

    for id, el in tqdm(squad.items()):

        unfilt_ctxt_batch = [el[0]]
        a_text_batch = [el[2]]
        a_pos_batch = [el[3]]

        ctxts.extend(unfilt_ctxt_batch)
        answers.extend(a_text_batch)
        ans_positions.extend(a_pos_batch)

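        # look up the baseline's pre-generated question for this SQuAD id;
        # None means the id is missing and the example is skipped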
        pred_str = model.get_q(id)

        if pred_str is None:
            missing += 1
            continue
        gold_str = el[1]

        if FLAGS.eval_metrics:
            qa_pred = qa.get_ans(unfilt_ctxt_batch, [pred_str])
            gold_qa_pred = qa.get_ans(unfilt_ctxt_batch, [gold_str])

            qa_score = metrics.f1(el[2].lower(), qa_pred[0].lower())
            qa_score_gold = metrics.f1(el[2].lower(), gold_qa_pred[0].lower())
            lm_score = lm.get_seq_perplexity([pred_str]).tolist()[0]
            disc_score = discriminator.get_pred(unfilt_ctxt_batch, [pred_str],
                                                a_text_batch,
                                                a_pos_batch).tolist()[0]

        f1s.append(metrics.f1(gold_str, pred_str))
        bleus.append(metrics.bleu(gold_str, pred_str))
        qgolds.append(gold_str)
        qpreds.append(pred_str)

        # calc cosine similarity between sums of word embeddings
        pred_sowe = np.sum(np.asarray([
            glove_embeddings[w] if w in glove_embeddings.keys() else np.zeros(
                (FLAGS.embedding_size, ))
            for w in preprocessing.tokenise(pred_str, asbytes=False)
        ]),
                           axis=0)
        gold_sowe = np.sum(np.asarray([
            glove_embeddings[w] if w in glove_embeddings.keys() else np.zeros(
                (FLAGS.embedding_size, ))
            for w in preprocessing.tokenise(gold_str, asbytes=False)
        ]),
                           axis=0)
        this_similarity = np.inner(pred_sowe, gold_sowe) / np.linalg.norm(
            pred_sowe, ord=2) / np.linalg.norm(gold_sowe, ord=2)

        sowe_similarities.append(this_similarity)

        this_metric_dict = {
            'f1': f1s[-1],
            'bleu': bleus[-1],
            'nll': 0,
            'sowe': sowe_similarities[-1]
        }
        if FLAGS.eval_metrics:
            this_metric_dict = {
                **this_metric_dict, 'qa': qa_score,
                'lm': lm_score,
                'disc': disc_score
            }
            qa_scores.append(qa_score)
            lm_scores.append(lm_score)
            disc_scores.append(disc_score)
        metric_individuals.append(this_metric_dict)

        res.append({
            'c': el[0],
            'q_pred': pred_str,
            'q_gold': gold_str,
            'a_pos': el[3],
            'a_text': el[2],
            'metrics': this_metric_dict
        })

    metric_dict = {
        'f1': np.mean(f1s),
        'bleu': np.mean(bleus),
        'nll': 0,
        'sowe': np.mean(sowe_similarities)
    }
    if FLAGS.eval_metrics:
        metric_dict = {
            **metric_dict, 'qa': np.mean(qa_scores),
            'lm': np.mean(lm_scores),
            'disc': np.mean(disc_scores)
        }
    # print(res)
    with open(FLAGS.log_dir + 'out_eval_BASELINE' +
              ("_train" if not FLAGS.eval_on_dev else "") + '.json',
              'w',
              encoding='utf-8') as fp:
        json.dump({"metrics": metric_dict, "results": res}, fp)

    print("F1: ", np.mean(f1s))
    print("BLEU: ", np.mean(bleus))
    print("NLL: ", 0)
    print("SOWE: ", np.mean(sowe_similarities))
    if FLAGS.eval_metrics:
        print("QA: ", np.mean(qa_scores))
        print("LM: ", np.mean(lm_scores))
        print("Disc: ", np.mean(disc_scores))

    print(missing, " ids were missing")
def main(_):

    from tqdm import tqdm
    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.metrics import confusion_matrix
    import itertools
    from sklearn.metrics import roc_curve, auc

    squad = loader.load_squad_triples(path="./data/",
                                      dev=True,
                                      v2=True,
                                      as_dict=True)
    # with open("./data/squad2_dev_official_output_fixed.json") as dataset_file:
    #     ans_preds = json.load(dataset_file)
    with open("./results/out_eval_MALUUBA-CROP-LATENT.json") as dataset_file:
        results = json.load(dataset_file)['results']

    disc = DiscriminatorInstance(
        path="./models/saved/discriminator-trained-latent")
    # disc = DiscriminatorInstance(path="./models/disc/1533307366-SQUAD-QANETINIT")

    # output={}
    # for id,candidates in tqdm(ans_preds.items()):
    #     ctxt, q, ans_gold, ans_gold_pos, label_gold = squad[id]
    #
    #     scores=[]
    #     for candidate in candidates:
    #         scores.append( disc.get_pred([ctxt], [q], [candidate['text']], [candidate['answer_start']]).tolist()[0] )
    #     cand_ix = np.argmax(scores)
    #
    #     pred_ans = candidates[cand_ix]['text']
    #     pred_score = scores[cand_ix]
    #     output[id] = pred_ans if pred_score > 0.5 else ""
    #
    # with open("./logs/squad2_dev_filtered.json","w") as fh:
    #     json.dump(output, fh)

    gold_labels = []
    pred_labels = []
    scores = []
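    # Score the first 1000 result pairs with the trained discriminator: gold questions
    # are treated as the positive class (label 1) and generated questions as the
    # negative class (label 0). Rounded scores give hard predictions; the raw scores
    # are kept for the ROC curve below.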

    for res in tqdm(results[:1000]):
        # print(res['q_gold'], res['q_pred'])
        gold_score = disc.get_pred([res['c']], [res['q_gold']],
                                   [res['a_text']], [res['a_pos']])
        pred_score = disc.get_pred([res['c']], [res['q_pred']],
                                   [res['a_text']], [res['a_pos']])

        gold_labels.append(1)
        gold_labels.append(0)
        pred_labels.append(np.round(gold_score[0]))
        pred_labels.append(np.round(pred_score[0]))
        scores.append(gold_score[0])
        scores.append(pred_score[0])

    # oh_labels =np.eye(2)[gold_labels]
    ### disc conf mat
    # cm = confusion_matrix(gold_labels, pred_labels)
    # mat = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    # print(mat)
    # plt.imshow(mat, cmap=plt.cm.Blues)
    # plt.colorbar()
    # tick_marks = np.arange(2)
    # plt.xticks(tick_marks, [0,1], rotation=45)
    # plt.yticks(tick_marks, [0,1])
    # fmt = '.2f'
    # thresh = mat.max() / 2.
    # for i, j in itertools.product(range(mat.shape[0]), range(mat.shape[1])):
    #     plt.text(j, i, format(mat[i, j], fmt),
    #              horizontalalignment="center",
    #              color="white" if mat[i, j] > thresh else "black")
    #
    # # plt.tight_layout()
    # plt.ylabel('Actual Source')
    # plt.xlabel('Predicted source')
    # # plt.savefig("/users/Tom/Dropbox/Apps/Overleaf/Question Generation/figures/confusion_maluuba_crop_smart_set.pdf", format="pdf")
    # plt.show()
    # # exit()

    ### disc Roc curves
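    # roc_curve sweeps the decision threshold over the raw scores; auc() condenses the
    # resulting curve into a single area-under-curve number.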
    fpr, tpr, _ = roc_curve(gold_labels, scores)
    roc_auc = auc(fpr, tpr)
    plt.figure()
    lw = 2
    plt.plot(fpr,
             tpr,
             color='darkorange',
             lw=lw,
             label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Discriminator ROC: gold vs. generated questions')
    plt.legend(loc="lower right")
    plt.show()
    exit()
Exemple #12
0
def main(_):
    run_id = str(int(time.time()))
    chkpt_path = FLAGS.model_dir + 'lm/' + run_id

    if not os.path.exists(chkpt_path):
        os.makedirs(chkpt_path)

    train_data = loader.load_squad_triples(FLAGS.data_path, False)
    dev_data = loader.load_squad_triples(FLAGS.data_path, True)

    np.random.shuffle(train_data)

    print('Loaded SQuAD with ', len(train_data), ' triples')
    train_contexts, train_qs, train_as, train_a_pos = zip(*train_data)
    _, dev_qs, _, _ = zip(*dev_data)
    vocab = loader.get_vocab(train_qs, tf.app.flags.FLAGS.lm_vocab_size)
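    # The vocabulary is built from the training questions only (the lm_vocab_size most
    # frequent tokens) and written next to the checkpoint so inference can reload it.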

    with open(chkpt_path + '/vocab.json', 'w') as outfile:
        json.dump(vocab, outfile)

    unique_sents = list(set(train_qs))
    print(len(unique_sents), " unique sentences")

    # Create model

    model = LstmLm(vocab, num_units=FLAGS.lm_units)
    with model.graph.as_default():
        saver = tf.train.Saver()

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_limit)
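    # NOTE: mem_limit is assumed to be defined at module level (the fraction of GPU
    # memory this process is allowed to claim); it is not set inside this function.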
    with tf.Session(graph=model.graph,
                    config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        if not os.path.exists(chkpt_path):
            os.makedirs(chkpt_path)
        summary_writer = tf.summary.FileWriter(
            FLAGS.log_directory + 'lm/' + run_id, sess.graph)

        if FLAGS.restore:
            # saver.restore(sess, chkpt_path+ '/model.checkpoint')
            print('Loading not implemented yet')
        else:
            print("Building graph, loading glove")
            sess.run(tf.global_variables_initializer())

        num_steps = len(unique_sents) // FLAGS.batch_size

        best_perp = 1e6
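        # Track the best dev perplexity seen so far; the model is checkpointed at the
        # end of an epoch only when the dev perplexity improves on it.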

        for e in range(FLAGS.lm_num_epochs):
            np.random.shuffle(unique_sents)
            for i in tqdm(range(num_steps), desc='Epoch ' + str(e)):
                seq_batch = unique_sents[i * FLAGS.batch_size:(i + 1) *
                                         FLAGS.batch_size]

                seq_batch_ids = [[vocab[loader.SOS]] + [
                    vocab[tok if tok in vocab.keys() else loader.OOV]
                    for tok in tokenise(sent, asbytes=False)
                ] + [vocab[loader.EOS]] for sent in seq_batch]
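                # Pad every id sequence in the batch with PAD tokens up to the length
                # of the longest one, so the batch can be fed as a single dense matrix.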
                max_seq_len = max([len(seq) for seq in seq_batch_ids])
                padded_batch = np.asarray([
                    seq +
                    [vocab[loader.PAD] for i in range(max_seq_len - len(seq))]
                    for seq in seq_batch_ids
                ])

                summ, _, pred, gold, seq = sess.run(
                    [
                        model.train_summary, model.optimise, model.preds,
                        model.tgt_output, model.input_seqs
                    ],
                    feed_dict={model.input_seqs: padded_batch})
                summary_writer.add_summary(summ,
                                           global_step=(e * num_steps + i))

                # print(pred, gold, seq)
                # exit()

                # if i%FLAGS.eval_freq==0:
                #     saver.save(sess, chkpt_path+'/model.checkpoint')
                # print(pred, gold, seq)

            perps = []
            num_steps_dev = len(dev_qs) // FLAGS.batch_size
            for i in tqdm(range(num_steps_dev), desc="Eval"):
                seq_batch = dev_qs[i * FLAGS.batch_size:(i + 1) *
                                   FLAGS.batch_size]
                seq_batch_ids = [[vocab[loader.SOS]] + [
                    vocab[tok if tok in vocab.keys() else loader.OOV]
                    for tok in tokenise(sent, asbytes=False)
                ] + [vocab[loader.EOS]] for sent in seq_batch]
                max_seq_len = max([len(seq) for seq in seq_batch_ids])
                padded_batch = np.asarray([
                    seq +
                    [vocab[loader.PAD] for i in range(max_seq_len - len(seq))]
                    for seq in seq_batch_ids
                ])

                perp = sess.run(model.perplexity,
                                feed_dict={model.input_seqs: padded_batch})
                perps.extend(perp)

            perpsummary = tf.Summary(value=[
                tf.Summary.Value(tag="dev_perf/perplexity",
                                 simple_value=sum(perps) / len(perps))
            ])

            summary_writer.add_summary(perpsummary,
                                       global_step=((e + 1) * num_steps))

            if np.mean(perps) < best_perp:
                print(np.mean(perps), " Saving!")
                saver.save(sess, chkpt_path + '/model.checkpoint')
                best_perp = np.mean(perps)
def main(_):

    FLAGS = tf.app.flags.FLAGS

    # results=results[:32]

    # dev_ctxts, dev_qs,dev_ans,dev_ans_pos, dev_correct = zip(*squad_dev)

    positive_data = []
    negative_data = []

    if FLAGS.disc_trainongenerated is True:
        with open(FLAGS.log_dir + 'out_eval_' + FLAGS.disc_modelslug +
                  '.json') as f:
            results = json.load(f)
        # for res in results:
        #     qpred,qgold,ctxt,ans_text,ans_pos =res
        for res in results['results']:
            positive_data.append(
                (res['c'], res['q_gold'], res['a_text'], res['a_pos']))
            negative_data.append(
                (res['c'], res['q_pred'], res['a_text'], res['a_pos']))

    if FLAGS.disc_trainonsquad is True:
        squad_v2 = loader.load_squad_triples(FLAGS.data_path,
                                             FLAGS.disc_dev_set,
                                             v2=True)
        for res in squad_v2:
            ctxt, q, ans_text, ans_pos, label = res
            if label is False:  # label is "is_unanswerable"
                positive_data.append(
                    (ctxt.lower(), q.lower(), ans_text.lower(), ans_pos))
            else:
                negative_data.append(
                    (ctxt.lower(), q.lower(), ans_text.lower(), ans_pos))

    num_instances = min(len(negative_data), len(positive_data))

    disc = DiscriminatorInstance(
        path=(FLAGS.model_dir +
              'saved/qanet2/' if FLAGS.disc_init_qanet is True else None),
        trainable=True,
        log_slug=FLAGS.disc_modelslug +
        ("_SQUAD" if FLAGS.disc_trainonsquad else "") +
        ("_QAINIT" if FLAGS.disc_init_qanet else ""),
        force_init=FLAGS.disc_init_qanet)

    # disc.load_from_chkpt() # this loads the embeddings etc

    train_samples = math.floor(0.8 * num_instances)
    dev_samples = math.floor(0.2 * num_instances)
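    # 80/20 train/dev split, applied with the same cut-off to both the positive and
    # negative pools so the two stay roughly balanced.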

    positive_data_train = positive_data[:train_samples]
    negative_data_train = negative_data[:train_samples]
    positive_data_dev = positive_data[train_samples:]
    negative_data_dev = negative_data[train_samples:]

    num_steps_train = train_samples // FLAGS.batch_size
    num_steps_dev = dev_samples // FLAGS.batch_size
    num_steps_squad = num_steps_dev

    best_oos_nll = 1e6

    for i in tqdm(range(num_steps_train * FLAGS.disc_num_epochs),
                  desc='Training'):
        if i % num_steps_train == 0:
            np.random.shuffle(positive_data_train)
            np.random.shuffle(negative_data_train)
        ixs = np.round(np.random.binomial(1, 0.5, FLAGS.batch_size))
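        # ixs is a per-slot coin flip: 1 selects a positive (real) example and 0 a
        # negative (generated / unanswerable) one, so each batch is a ~50/50 mix and
        # ixs doubles as the label vector passed to the discriminator.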
        # batch = train_data[i*FLAGS.batch_size:(i+1)*FLAGS.batch_size]
        batch = [
            negative_data_train[(i % num_steps_train) * FLAGS.batch_size +
                                j] if ix < 0.5 else
            positive_data_train[(i % num_steps_train) * FLAGS.batch_size + j]
            for j, ix in enumerate(ixs.tolist())
        ]
        ctxt, qbatch, ans_text, ans_pos = zip(*batch)

        # print(ixs)
        # print(qbatch)
        # print(ans_text)
        # print(ans_pos)
        # print(ctxt)
        # exit()

        # +qpred[ix].replace("</Sent>","").replace("<PAD>","")
        qbatch = [
            q.replace(" </Sent>", "").replace(" <PAD>", "") for q in qbatch
        ]
        # qbatch = ["fake " if ixs[ix] < 0.5 else "real " for ix in range(FLAGS.batch_size)]
        # print(qbatch, ixs)
        loss = disc.train_step(ctxt, qbatch, ans_text, ans_pos, ixs, (i))

        if i % 1000 == 0 and i > 0:
            dev_acc = []
            dev_nll = []
            for dev_i in tqdm(range(num_steps_dev),
                              desc='Step ' + str(i) + " dev"):

                ixs = np.round(np.random.binomial(1, 0.5, FLAGS.batch_size))
                batch = [
                    negative_data_dev[dev_i * FLAGS.batch_size + j] if ix < 0.5
                    else positive_data_dev[dev_i * FLAGS.batch_size + j]
                    for j, ix in enumerate(ixs.tolist())
                ]
                ctxt, qbatch, ans_text, ans_pos = zip(*batch)

                qbatch = [
                    q.replace(" </Sent>", "").replace(" <PAD>", "")
                    for q in qbatch
                ]

                pred = disc.get_pred(ctxt, qbatch, ans_text, ans_pos)
                nll = disc.get_nll(ctxt, qbatch, ans_text, ans_pos, ixs)
                acc = 1.0 * np.equal(np.round(pred), ixs)
                dev_acc.extend(acc.tolist())
                dev_nll.extend(nll.tolist())

            accsummary = tf.Summary(value=[
                tf.Summary.Value(tag="dev_perf/acc",
                                 simple_value=np.mean(dev_acc))
            ])
            nllsummary = tf.Summary(value=[
                tf.Summary.Value(tag="dev_perf/nll",
                                 simple_value=np.mean(dev_nll))
            ])

            disc.summary_writer.add_summary(accsummary, global_step=i)
            disc.summary_writer.add_summary(nllsummary, global_step=i)

            print(np.mean(dev_acc))
            if np.mean(dev_nll) < best_oos_nll:
                best_oos_nll = np.mean(dev_nll)
                disc.save_to_chkpt(FLAGS.model_dir, i)
                print("New best NLL, saving")
def main(_):
    if FLAGS.testing:
        print('TEST MODE - reducing model size')
        FLAGS.qa_encoder_units = 32
        FLAGS.qa_match_units = 32
        FLAGS.qa_batch_size = 16
        FLAGS.embedding_size = 50

    run_id = str(int(time.time()))

    chkpt_path = FLAGS.model_dir + 'qa/' + run_id
    restore_path = FLAGS.model_dir + 'qa/1529056867'

    if not os.path.exists(chkpt_path):
        os.makedirs(chkpt_path)

    train_data = loader.load_squad_triples(FLAGS.data_path, False)
    dev_data = loader.load_squad_triples(FLAGS.data_path, dev=True, ans_list=True)

    train_data = filter_squad(train_data, window_size=FLAGS.filter_window_size, max_tokens=FLAGS.filter_max_tokens)
    # dev_data = filter_squad(dev_data, window_size=FLAGS.filter_window_size, max_tokens=FLAGS.filter_max_tokens)

    if FLAGS.testing:
        train_data = train_data[:1000]
        num_dev_samples = 100
    else:
        num_dev_samples = 3000

    print('Loaded SQuAD with ', len(train_data), ' triples')
    train_contexts, train_qs, train_as, train_a_pos = zip(*train_data)
    dev_contexts, dev_qs, dev_as, dev_a_pos = zip(*dev_data)

    if FLAGS.restore:
        with open(restore_path + '/vocab.json') as f:
            vocab = json.load(f)
    else:
        vocab = loader.get_vocab(train_contexts + train_qs, tf.app.flags.FLAGS.qa_vocab_size)
        with open(chkpt_path + '/vocab.json', 'w') as outfile:
            json.dump(vocab, outfile)



    model = MpcmQa(vocab)
    with model.graph.as_default():
        saver = tf.train.Saver()



    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_limit, allow_growth = True)
    with tf.Session(graph=model.graph, config=tf.ConfigProto(gpu_options=gpu_options)) as sess:

        summary_writer = tf.summary.FileWriter(FLAGS.log_directory+'qa/'+run_id, sess.graph)

        if FLAGS.restore:
            saver.restore(sess, restore_path+ '/model.checkpoint')
            start_e = 40  # FLAGS.qa_num_epochs
            print('Loaded model')
        else:
            print("Building graph, loading glove")
            start_e = 0
            sess.run(tf.global_variables_initializer())

        num_steps_train = len(train_data)//FLAGS.qa_batch_size
        num_steps_dev = num_dev_samples//FLAGS.qa_batch_size
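        # Write zero-valued dev F1/EM summaries at the starting step so the
        # dev-performance curves in TensorBoard begin from a defined point.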

        f1summary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/f1",
                                         simple_value=0.0)])
        emsummary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/em",
                                  simple_value=0.0)])

        summary_writer.add_summary(f1summary, global_step=start_e*num_steps_train)
        summary_writer.add_summary(emsummary, global_step=start_e*num_steps_train)

        best_oos_nll = 1e6

        for e in range(start_e, start_e + FLAGS.qa_num_epochs):
            np.random.shuffle(train_data)
            train_contexts, train_qs, train_as, train_a_pos = zip(*train_data)

            for i in tqdm(range(num_steps_train), desc='Epoch '+str(e)):
                # TODO: this keeps coming up - refactor it
                batch_contexts = train_contexts[i*FLAGS.qa_batch_size:(i+1)*FLAGS.qa_batch_size]
                batch_questions = train_qs[i*FLAGS.qa_batch_size:(i+1)*FLAGS.qa_batch_size]
                batch_ans_text = train_as[i*FLAGS.qa_batch_size:(i+1)*FLAGS.qa_batch_size]
                batch_answer_charpos = train_a_pos[i*FLAGS.qa_batch_size:(i+1)*FLAGS.qa_batch_size]

                batch_answers=[]
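                # char_pos_to_word maps the answer's character offset in the context to
                # a token index; the span end is that start plus the answer length in
                # tokens minus one.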
                for j, ctxt in enumerate(batch_contexts):
                    ans_span=char_pos_to_word(ctxt.encode(), [t.encode() for t in tokenise(ctxt, asbytes=False)], batch_answer_charpos[j])
                    ans_span=(ans_span, ans_span+len(tokenise(batch_ans_text[j],asbytes=False))-1)
                    batch_answers.append(ans_span)

                # print(batch_answers[:3])
                # exit()
                # run_metadata = tf.RunMetadata()
                # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                _,summ, pred = sess.run([model.optimizer, model.train_summary, model.pred_span],
                        feed_dict={model.context_in: get_padded_batch(batch_contexts,vocab),
                                model.question_in: get_padded_batch(batch_questions,vocab),
                                model.answer_spans_in: batch_answers,
                                model.is_training: True})
                                # ,run_metadata=run_metadata, options=run_options)

                summary_writer.add_summary(summ, global_step=(e*num_steps_train+i))
                # summary_writer.add_run_metadata(run_metadata, tag="step "+str(i), global_step=(e*num_steps_train+i))

                if i%FLAGS.eval_freq==0:
                    gold_str=[]
                    pred_str=[]
                    f1s = []
                    exactmatches= []
                    for b in range(FLAGS.qa_batch_size):
                        gold_str.append(" ".join(tokenise(batch_contexts[b],asbytes=False)[batch_answers[b][0]:batch_answers[b][1]+1]))
                        pred_str.append( " ".join(tokenise(batch_contexts[b],asbytes=False)[pred[b][0]:pred[b][1]+1]) )

                    f1s.extend([f1(gold_str[b], pred_str[b]) for b in range(FLAGS.qa_batch_size)])
                    exactmatches.extend([ np.product(pred[b] == batch_answers[b])*1.0 for b in range(FLAGS.qa_batch_size) ])

                    f1summary = tf.Summary(value=[tf.Summary.Value(tag="train_perf/f1",
                                                     simple_value=sum(f1s)/len(f1s))])
                    emsummary = tf.Summary(value=[tf.Summary.Value(tag="train_perf/em",
                                              simple_value=sum(exactmatches)/len(exactmatches))])

                    summary_writer.add_summary(f1summary, global_step=(e*num_steps_train+i))
                    summary_writer.add_summary(emsummary, global_step=(e*num_steps_train+i))


                    # saver.save(sess, chkpt_path+'/model.checkpoint')


            f1s=[]
            exactmatches=[]
            nlls=[]

            np.random.shuffle(dev_data)
            dev_subset = dev_data[:num_dev_samples]
            for i in tqdm(range(num_steps_dev), desc='Eval '+str(e)):
                dev_contexts,dev_qs,dev_as,dev_a_pos = zip(*dev_subset)
                batch_contexts = dev_contexts[i*FLAGS.qa_batch_size:(i+1)*FLAGS.qa_batch_size]
                batch_questions = dev_qs[i*FLAGS.qa_batch_size:(i+1)*FLAGS.qa_batch_size]
                batch_ans_text = dev_as[i*FLAGS.qa_batch_size:(i+1)*FLAGS.qa_batch_size]
                batch_answer_charpos = dev_a_pos[i*FLAGS.qa_batch_size:(i+1)*FLAGS.qa_batch_size]

                batch_answers=[]
                for j, ctxt in enumerate(batch_contexts):
                    ans_span=char_pos_to_word(ctxt.encode(), [t.encode() for t in tokenise(ctxt, asbytes=False)], batch_answer_charpos[j][0])
                    ans_span=(ans_span, ans_span+len(tokenise(batch_ans_text[j][0],asbytes=False))-1)
                    batch_answers.append(ans_span)


                pred,nll = sess.run([model.pred_span, model.nll],
                        feed_dict={model.context_in: get_padded_batch(batch_contexts,vocab),
                                model.question_in: get_padded_batch(batch_questions,vocab),
                                model.answer_spans_in: batch_answers,
                                model.is_training: False})

                for b in range(FLAGS.qa_batch_size):
                    pred_str = " ".join(tokenise(batch_contexts[b],asbytes=False)[pred[b][0]:pred[b][1]+1])
                    this_f1=[]
                    this_em=[]
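                    # SQuAD dev questions can have several reference answers; score the
                    # prediction against each and keep the best F1 / EM.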
                    for a in range(len(batch_ans_text[b])):
                        this_f1.append(f1(normalize_answer(batch_ans_text[b][a]), normalize_answer(pred_str)))
                        this_em.append(1.0*(normalize_answer(batch_ans_text[b][a]) == normalize_answer(pred_str)))
                    f1s.append(max(this_f1))
                    exactmatches.append(max(this_em))
                nlls.extend(nll.tolist())
            f1summary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/f1",
                                             simple_value=sum(f1s)/len(f1s))])
            emsummary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/em",
                                      simple_value=sum(exactmatches)/len(exactmatches))])
            nllsummary = tf.Summary(value=[tf.Summary.Value(tag="dev_perf/nll",
                                      simple_value=np.mean(nlls))])

            summary_writer.add_summary(f1summary, global_step=((e+1)*num_steps_train))
            summary_writer.add_summary(emsummary, global_step=((e+1)*num_steps_train))
            summary_writer.add_summary(nllsummary, global_step=((e+1)*num_steps_train))

            mean_nll=np.mean(nlls)
            if mean_nll < best_oos_nll:
                print("New best NLL! ", mean_nll, " Saving... F1: ", np.mean(f1s))
                best_oos_nll = mean_nll
                saver.save(sess, chkpt_path+'/model.checkpoint')
            else:
                print("NLL not improved ", mean_nll)