Example #1
def main(_):

    model_type=FLAGS.model_type
    # chkpt_path = FLAGS.model_dir+'saved/qgen-maluuba-crop-glove-smart'
    # chkpt_path = FLAGS.model_dir+'qgen-saved/MALUUBA-CROP-LATENT/1533247183'
    disc_path = FLAGS.model_dir+'saved/discriminator-trained-latent'
    chkpt_path = FLAGS.model_dir+'qgen/'+ model_type+'/'+FLAGS.eval_model_id

    # load dataset
    # train_data = loader.load_squad_triples(FLAGS.data_path, False)
    dev_data = loader.load_squad_triples(FLAGS.data_path, dev=FLAGS.eval_on_dev, test=FLAGS.eval_on_test)

    if len(dev_data) < FLAGS.num_eval_samples:
        exit('***ERROR*** Eval dataset is smaller than the num_eval_samples flag!')
    if len(dev_data) > FLAGS.num_eval_samples:
        print('***WARNING*** Eval dataset is larger than the num_eval_samples flag!')

    # train_contexts_unfilt, _,_,train_a_pos_unfilt = zip(*train_data)
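    # each SQuAD triple is (context, question, answer_text, answer_position)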
    dev_contexts_unfilt, _,_,dev_a_pos_unfilt = zip(*dev_data)

    if FLAGS.filter_window_size_before > -1:
        # train_data = preprocessing.filter_squad(train_data, window_size=FLAGS.filter_window_size, max_tokens=FLAGS.filter_max_tokens)
        dev_data = preprocessing.filter_squad(dev_data, window_size_before=FLAGS.filter_window_size_before, window_size_after=FLAGS.filter_window_size_after, max_tokens=FLAGS.filter_max_tokens)


    # print('Loaded SQuAD with ',len(train_data),' triples')
    print('Loaded SQuAD dev set with ',len(dev_data),' triples')
    # train_contexts, train_qs, train_as,train_a_pos = zip(*train_data)
    dev_contexts, dev_qs, dev_as, dev_a_pos = zip(*dev_data)


    # vocab = loader.get_vocab(train_contexts, tf.app.flags.FLAGS.vocab_size)
    with open(chkpt_path+'/vocab.json') as f:
        vocab = json.load(f)

    with SquadStreamer(vocab, FLAGS.eval_batch_size, 1, shuffle=False) as dev_data_source:

        glove_embeddings = loader.load_glove(FLAGS.data_path)


        # Create model
        if model_type[:7] == "SEQ2SEQ":
            model = Seq2SeqModel(vocab, training_mode=False)
        elif model_type[:2] == "RL":
            # TEMP - no need to spin up the LM or QA model at eval time
            FLAGS.qa_weight = 0
            FLAGS.lm_weight = 0
            model = RLModel(vocab, training_mode=False)
        else:
            exit("Unrecognised model type: "+model_type)

        with model.graph.as_default():
            saver = tf.train.Saver()

        if FLAGS.eval_metrics:
            lm = LstmLmInstance()
            # qa = MpcmQaInstance()
            qa = QANetInstance()

            lm.load_from_chkpt(FLAGS.model_dir+'saved/lmtest')
            # qa.load_from_chkpt(FLAGS.model_dir+'saved/qatest')
            qa.load_from_chkpt(FLAGS.model_dir+'saved/qanet2')

            discriminator = DiscriminatorInstance(trainable=False, path=disc_path)

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_limit)
        with tf.Session(graph=model.graph, config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            if not os.path.exists(chkpt_path):
                exit("Checkpoint path doesn't exist! " + chkpt_path)
            # summary_writer = tf.summary.FileWriter(FLAGS.log_directory+"eval/"+str(int(time.time())), sess.graph)

            saver.restore(sess, tf.train.latest_checkpoint(chkpt_path))
            # print('Loading not implemented yet')
            # else:
            #     sess.run(tf.global_variables_initializer())
            #     sess.run(model.glove_init_ops)

            num_steps = FLAGS.num_eval_samples//FLAGS.eval_batch_size
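            # integer division: any leftover examples beyond num_steps * eval_batch_size are not evaluated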

            # Initialise the dataset

            # np.random.shuffle(dev_data)
            dev_data_source.initialise(dev_data)

            f1s=[]
            bleus=[]
            qa_scores=[]
            qa_scores_gold=[]
            lm_scores=[]
            nlls=[]
            disc_scores=[]
            sowe_similarities=[]
            copy_probs=[]

            qgolds=[]
            qpreds=[]
            qpred_ids=[]
            qgold_ids=[]
            ctxts=[]
            answers=[]
            ans_positions=[]

            metric_individuals=[]
            res=[]
            for e in range(1):
                for i in tqdm(range(num_steps), desc='Epoch '+str(e)):
                    dev_batch, curr_batch_size = dev_data_source.get_batch()
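                    # decode one batch: beam-search predictions plus the gold question, context, answer span, NLL and mean copy probability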
                    pred_batch, pred_beam, pred_beam_lens, pred_ids, pred_lens, gold_batch, gold_lens, gold_ids, \
                        ctxt, ctxt_len, ans, ans_len, nll, copy_prob = sess.run(
                            [model.q_hat_beam_string, model.q_hat_full_beam_str, model.q_hat_full_beam_lens,
                             model.q_hat_beam_ids, model.q_hat_beam_lens, model.question_raw, model.question_length,
                             model.question_ids, model.context_raw, model.context_length, model.answer_locs,
                             model.answer_length, model.nll, model.mean_copy_prob],
                            feed_dict={model.input_batch: dev_batch, model.is_training: False})

                    unfilt_ctxt_batch = [dev_contexts_unfilt[ix] for ix in dev_batch[3]]
                    a_text_batch = ops.byte_token_array_to_str(dev_batch[2][0], dev_batch[2][2], is_array=False)
                    unfilt_apos_batch = [dev_a_pos_unfilt[ix] for ix in dev_batch[3]]

                    # subtract 1 to remove the "end sent token"
                    pred_q_batch = [q.replace(' </Sent>',"").replace(" <PAD>","") for q in ops.byte_token_array_to_str(pred_batch, pred_lens-1)]

                    ctxts.extend(unfilt_ctxt_batch)
                    answers.extend(a_text_batch)
                    ans_positions.extend([dev_a_pos_unfilt[ix] for ix in dev_batch[3]])
                    copy_probs.extend(copy_prob.tolist())



                    # get QA score

                    # gold_str=[]
                    # pred_str=[]


                    gold_ans = ops.byte_token_array_to_str(dev_batch[2][0], dev_batch[2][2], is_array=False)
                    # pred_str = ops.byte_token_array_to_str([dev_batch[0][0][b][qa_pred[b][0]:qa_pred[b][1]] for b in range(curr_batch_size)], is_array=False)
                    nlls.extend(nll.tolist())

                    if FLAGS.eval_metrics:
                        qa_pred = qa.get_ans(unfilt_ctxt_batch, ops.byte_token_array_to_str(pred_batch, pred_lens))
                        gold_qa_pred = qa.get_ans(unfilt_ctxt_batch, ops.byte_token_array_to_str(dev_batch[1][0], dev_batch[1][3]))
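                        # F1 overlap between the QA model's answer and the gold answer, for the generated questions and for the gold questions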

                        qa_score_batch = [metrics.f1(metrics.normalize_answer(gold_ans[b]), metrics.normalize_answer(qa_pred[b])) for b in range(curr_batch_size)]
                        qa_score_gold_batch = [metrics.f1(metrics.normalize_answer(gold_ans[b]), metrics.normalize_answer(gold_qa_pred[b])) for b in range(curr_batch_size)]
                        lm_score_batch = lm.get_seq_perplexity(pred_q_batch).tolist()
                        disc_score_batch = discriminator.get_pred(unfilt_ctxt_batch, pred_q_batch, gold_ans, unfilt_apos_batch).tolist()

                    for b, pred in enumerate(pred_batch):
                        pred_str = pred_q_batch[b].replace(' </Sent>',"").replace(" <PAD>","")
                        gold_str = tokens_to_string(gold_batch[b][:gold_lens[b]-1])
                        f1s.append(metrics.f1(gold_str, pred_str))
                        bleus.append(metrics.bleu(gold_str, pred_str))
                        qgolds.append(gold_str)
                        qpreds.append(pred_str)

                        # calc cosine similarity between sums of word embeddings
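                        # out-of-vocabulary words contribute a zero vector of size FLAGS.embedding_size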
                        pred_sowe = np.sum(np.asarray(
                            [glove_embeddings[w] if w in glove_embeddings else np.zeros((FLAGS.embedding_size,))
                             for w in preprocessing.tokenise(pred_str, asbytes=False)]), axis=0)
                        gold_sowe = np.sum(np.asarray(
                            [glove_embeddings[w] if w in glove_embeddings else np.zeros((FLAGS.embedding_size,))
                             for w in preprocessing.tokenise(gold_str, asbytes=False)]), axis=0)
                        this_similarity = np.inner(pred_sowe, gold_sowe) \
                            / np.linalg.norm(pred_sowe, ord=2) / np.linalg.norm(gold_sowe, ord=2)

                        sowe_similarities.append(this_similarity)



                        this_metric_dict={
                            'f1':f1s[-1],
                            'bleu': bleus[-1],
                            'nll': nlls[-1],
                            'sowe': sowe_similarities[-1]
                            }
                        if FLAGS.eval_metrics:
                            this_metric_dict = {
                                **this_metric_dict,
                                'qa': qa_score_batch[b],
                                'lm': lm_score_batch[b],
                                'disc': disc_score_batch[b]}
                            qa_scores.append(qa_score_batch[b])
                            lm_scores.append(lm_score_batch[b])
                            disc_scores.append(disc_score_batch[b])
                        metric_individuals.append(this_metric_dict)

                        res.append({
                            'c':unfilt_ctxt_batch[b],
                            'q_pred': pred_str,
                            'q_gold': gold_str,
                            'a_pos': unfilt_apos_batch[b],
                            'a_text': a_text_batch[b],
                            'metrics': this_metric_dict,

                            'q_pred_ids': pred_ids.tolist()[b],
                            'q_gold_ids': dev_batch[1][1][b].tolist()

                        })

                    # Quick output
                    if i==0:
                        # print(copy_prob.tolist())
                        # print(copy_probs)
                        pred_str = tokens_to_string(pred_batch[0][:pred_lens[0]-1])
                        gold_str = tokens_to_string(gold_batch[0][:gold_lens[0]-1])
                        # print(pred_str)
                        print(qpreds[0])
                        print(gold_str)


                        title=chkpt_path
                        out_str = output_eval(title,pred_batch,  pred_ids, pred_lens, gold_batch, gold_lens, ctxt, ctxt_len, ans, ans_len)
                        with open(FLAGS.log_directory+'out_eval_'+model_type+'.htm', 'w', encoding='utf-8') as fp:
                            fp.write(out_str)

            # res = list(zip(qpreds,qgolds,ctxts,answers,ans_positions,metric_individuals))
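            # corpus-level BLEU over all gold/predicted pairs; per-sentence BLEU scores were collected in bleus above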
            metric_dict={
                'f1':np.mean(f1s),
                'bleu': metrics.bleu_corpus(qgolds, qpreds),
                'nll':np.mean(nlls),
                'sowe': np.mean(sowe_similarities)
                }
            if FLAGS.eval_metrics:
                metric_dict={**metric_dict,
                'qa':np.mean(qa_scores),
                'lm':np.mean(lm_scores),
                'disc': np.mean(disc_scores)}
            # print(res)
            with open(FLAGS.log_directory + 'out_eval_' + model_type
                      + ("_test" if FLAGS.eval_on_test else "")
                      + ("_train" if (not FLAGS.eval_on_dev and not FLAGS.eval_on_test) else "")
                      + '.json', 'w', encoding='utf-8') as fp:
                json.dump({"metrics":metric_dict, "results": res}, fp)


            print("F1: ", np.mean(f1s))
            print("BLEU: ", metrics.bleu_corpus(qgolds, qpreds))
            print("NLL: ", np.mean(nlls))
            print("SOWE: ", np.mean(sowe_similarities))

            print("Copy prob: ", np.mean(copy_probs))
            if FLAGS.eval_metrics:
                print("QA: ", np.mean(qa_scores))
                print("LM: ", np.mean(lm_scores))
                print("Disc: ", np.mean(disc_scores))

Example #2

# NOTE: the start of this helper is truncated in the original example. The function
# name and the first branch below are hypothetical reconstructions, added only so the
# remaining branches parse; the real signature and earlier NER cases are not shown.
def ner_to_question_word(ner):
    if ner in ["DURATION"]:
        return "how long"
    elif ner in ["ORDINAL"]:
        return "which"
    elif ner in ["MONEY", "PERCENT"]:
        return "how much"
    elif ner in ["NATIONALITY"]:
        return "what nationality"
    elif ner in ["CAUSE_OF_DEATH"]:
        return "how"
    else:
        exit("Unknown ner " + ner)


fallback = baseline_model.BaselineModel()

lm = LstmLmInstance()
qa = MpcmQaInstance()

lm.load_from_chkpt(FLAGS.model_dir + 'saved/lmtest')
print("LM loaded")
qa.load_from_chkpt(FLAGS.model_dir + 'saved/qatest')
print("QA loaded")

lm_vocab = lm.vocab
qa_vocab = qa.vocab

f1s = []
bleus = []
qa_scores = []
lm_scores = []

Example #3

    def modify_seq2seq_model(self):
        print('Modifying Seq2Seq model to incorporate RL rewards')

        if FLAGS.policy_gradient:
            print('Building and loading LM')
            self.lm = LstmLmInstance()
            self.lm.load_from_chkpt(FLAGS.model_dir + 'saved/lmtest')

            print('Building and loading QA model')
            # self.qa = MpcmQaInstance()
            # self.qa.load_from_chkpt(FLAGS.model_dir+'saved/qatest')
            self.qa = QANetInstance()
            self.qa.load_from_chkpt(FLAGS.model_dir + 'saved/qanet2')

        with self.graph.as_default():

            self.lm_score = tf.placeholder(tf.float32, [None], "lm_score")
            self.qa_score = tf.placeholder(tf.float32, [None], "qa_score")
            self.disc_score = tf.placeholder(tf.float32, [None], "disc_score")
            self.bleu_score = tf.placeholder(tf.float32, [None], "bleu_score")
            self.rl_lm_enabled = tf.placeholder_with_default(
                False, (), "rl_lm_enabled")
            self.rl_qa_enabled = tf.placeholder_with_default(
                False, (), "rl_qa_enabled")
            self.rl_disc_enabled = tf.placeholder_with_default(
                False, (), "rl_disc_enabled")
            self.rl_bleu_enabled = tf.placeholder_with_default(
                False, (), "rl_bleu_enabled")
            self.step = tf.placeholder(tf.int32, (), "step")

            with tf.variable_scope('rl_rewards'):
                # NOTE: This isn't obvious! If we feed in the generated Qs as the gold with a reward,
                # we get REINFORCE. If we feed in a reward of 1.0 with an actual gold Q, we get cross entropy.
                # So we can combine both in the same set of ops, but need to construct batches appropriately
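                # i.e. rows holding sampled questions are fed with their lm/qa/disc/bleu rewards (REINFORCE term),
                # while rows holding the true gold questions are fed with a reward of 1.0 (cross-entropy term)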
                mask = tf.one_hot(self.question_ids,
                                  depth=len(self.vocab) + FLAGS.max_copy_size)
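                # one-hot mask over the extended vocabulary (shortlist + copy positions), used to pick out
                # the probability assigned to each target token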

                self.lm_loss = -1.0 * self.lm_score * tf.reduce_sum(
                    tf.reduce_sum(safe_log(self.q_hat) * mask, axis=[2]) *
                    self.target_weights,
                    axis=1) / tf.cast(self.question_length, tf.float32)
                self.qa_loss = -1.0 * self.qa_score * tf.reduce_sum(
                    tf.reduce_sum(safe_log(self.q_hat) * mask, axis=[2]) *
                    self.target_weights,
                    axis=1) / tf.cast(self.question_length, tf.float32)
                self.disc_loss = -1.0 * self.disc_score * tf.reduce_sum(
                    tf.reduce_sum(safe_log(self.q_hat) * mask, axis=[2]) *
                    self.target_weights,
                    axis=1) / tf.cast(self.question_length, tf.float32)
                self.bleu_loss = -1.0 * self.bleu_score * tf.reduce_sum(
                    tf.reduce_sum(safe_log(self.q_hat) * mask, axis=[2]) *
                    self.target_weights,
                    axis=1) / tf.cast(self.question_length, tf.float32)

            pg_loss = tf.cond(self.rl_lm_enabled, lambda: self.lm_loss, lambda: tf.constant([0.0])) + \
                tf.cond(self.rl_qa_enabled, lambda: self.qa_loss, lambda: tf.constant([0.0])) + \
                tf.cond(self.rl_disc_enabled, lambda: self.disc_loss, lambda: tf.constant([0.0])) + \
                tf.cond(self.rl_bleu_enabled, lambda: self.bleu_loss, lambda: tf.constant([0.0]))
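            # each reward term only contributes when its corresponding *_enabled placeholder is fed as True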

            curr_batch_size_pg = tf.shape(self.answer_ids)[0] // 2

            # log the first half of the batch - this is the RL part
            self._train_summaries.append(
                tf.summary.scalar("train_loss/pg_loss_rl",
                                  tf.reduce_mean(
                                      pg_loss[:curr_batch_size_pg])))
            self._train_summaries.append(
                tf.summary.scalar("train_loss/pg_loss_ml",
                                  tf.reduce_mean(
                                      pg_loss[curr_batch_size_pg:])))

            self.pg_loss = tf.reduce_mean(pg_loss, axis=[0])

            self._train_summaries.append(
                tf.summary.scalar("train_loss/pg_loss", self.pg_loss))

            # this needs rebuilding again
            self.train_summary = tf.summary.merge(self._train_summaries)

            # don't bother calculating gradients if not training
            if self.training_mode:
                # these need to be redefined with the correct inputs
                # Calculate and clip gradients
                params = tf.trainable_variables()
                gradients = tf.gradients(self.pg_loss, params)
                clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5)

                # Optimization
                lr = FLAGS.learning_rate if not FLAGS.lr_schedule else tf.minimum(
                    1.0,
                    tf.cast(self.step, tf.float32) *
                    0.001) * FLAGS.learning_rate
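                # optional schedule: tf.minimum(1.0, step * 0.001) ramps the learning rate up linearly over the first 1000 steps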
                self.pg_optimizer = tf.train.AdamOptimizer(lr).apply_gradients(
                    zip(clipped_gradients,
                        params)) if self.training_mode else tf.no_op()

            total_params()
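
A minimal sketch (an assumption for illustration, not code from the repository) of how the policy-gradient ops defined above might be driven from a training loop: sess, batch, the per-example reward arrays and global_step stand in for whatever the surrounding trainer provides, and pg_optimizer only exists when the model was built with training_mode=True.

# Hypothetical usage of the ops built by modify_seq2seq_model()
feed = {
    model.input_batch: batch,            # batches are fed the same way as in the evaluation loop of Example #1
    model.is_training: True,
    model.lm_score: lm_rewards,          # per-example rewards, shape [batch_size]
    model.qa_score: qa_rewards,
    model.disc_score: disc_rewards,
    model.bleu_score: bleu_rewards,
    model.rl_lm_enabled: True,           # switch individual reward terms on or off
    model.rl_qa_enabled: True,
    model.rl_disc_enabled: False,
    model.rl_bleu_enabled: False,
    model.step: global_step,
}
_, pg_loss_value = sess.run([model.pg_optimizer, model.pg_loss], feed_dict=feed)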

Example #4

def main(_):
    model = FileLoaderModel('./models/BASELINE')
    squad = loader.load_squad_triples(FLAGS.data_path, True, as_dict=True)
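    # with as_dict=True, squad maps each example id to (context, question, answer_text, answer_position)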

    disc_path = FLAGS.model_dir + 'saved/discriminator-trained-latent'

    glove_embeddings = loader.load_glove(FLAGS.data_path)

    if FLAGS.eval_metrics:
        lm = LstmLmInstance()
        # qa = MpcmQaInstance()
        qa = QANetInstance()

        lm.load_from_chkpt(FLAGS.model_dir + 'saved/lmtest')
        # qa.load_from_chkpt(FLAGS.model_dir+'saved/qatest')
        qa.load_from_chkpt(FLAGS.model_dir + 'saved/qanet')

        discriminator = DiscriminatorInstance(trainable=False, path=disc_path)

    f1s = []
    bleus = []
    qa_scores = []
    qa_scores_gold = []
    lm_scores = []
    nlls = []
    disc_scores = []
    sowe_similarities = []

    qgolds = []
    qpreds = []
    ctxts = []
    answers = []
    ans_positions = []

    metric_individuals = []
    res = []

    missing = 0

    for id, el in tqdm(squad.items()):

        unfilt_ctxt_batch = [el[0]]
        a_text_batch = [el[2]]
        a_pos_batch = [el[3]]

        ctxts.extend(unfilt_ctxt_batch)
        answers.extend(a_text_batch)
        ans_positions.extend(a_pos_batch)

        pred_str = model.get_q(id)
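        # baseline questions are pre-generated and loaded from ./models/BASELINE; get_q returns None if no question was stored for this id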

        if pred_str is None:
            missing += 1
            continue
        gold_str = el[1]

        if FLAGS.eval_metrics:
            qa_pred = qa.get_ans(unfilt_ctxt_batch, [pred_str])
            gold_qa_pred = qa.get_ans(unfilt_ctxt_batch, [gold_str])

            qa_score = metrics.f1(el[2].lower(), qa_pred[0].lower())
            qa_score_gold = metrics.f1(el[2].lower(), gold_qa_pred[0].lower())
            lm_score = lm.get_seq_perplexity([pred_str]).tolist()[0]
            disc_score = discriminator.get_pred(unfilt_ctxt_batch, [pred_str],
                                                a_text_batch,
                                                a_pos_batch).tolist()[0]

        f1s.append(metrics.f1(gold_str, pred_str))
        bleus.append(metrics.bleu(gold_str, pred_str))
        qgolds.append(gold_str)
        qpreds.append(pred_str)

        # calc cosine similarity between sums of word embeddings
        pred_sowe = np.sum(np.asarray([
            glove_embeddings[w] if w in glove_embeddings.keys() else np.zeros(
                (FLAGS.embedding_size, ))
            for w in preprocessing.tokenise(pred_str, asbytes=False)
        ]),
                           axis=0)
        gold_sowe = np.sum(np.asarray([
            glove_embeddings[w] if w in glove_embeddings.keys() else np.zeros(
                (FLAGS.embedding_size, ))
            for w in preprocessing.tokenise(gold_str, asbytes=False)
        ]),
                           axis=0)
        this_similarity = np.inner(pred_sowe, gold_sowe) / np.linalg.norm(
            pred_sowe, ord=2) / np.linalg.norm(gold_sowe, ord=2)

        sowe_similarities.append(this_similarity)

        this_metric_dict = {
            'f1': f1s[-1],
            'bleu': bleus[-1],
            'nll': 0,
            'sowe': sowe_similarities[-1]
        }
        if FLAGS.eval_metrics:
            this_metric_dict = {
                **this_metric_dict, 'qa': qa_score,
                'lm': lm_score,
                'disc': disc_score
            }
            qa_scores.append(qa_score)
            lm_scores.append(lm_score)
            disc_scores.append(disc_score)
        metric_individuals.append(this_metric_dict)

        res.append({
            'c': el[0],
            'q_pred': pred_str,
            'q_gold': gold_str,
            'a_pos': el[3],
            'a_text': el[2],
            'metrics': this_metric_dict
        })

    metric_dict = {
        'f1': np.mean(f1s),
        'bleu': np.mean(bleus),
        'nll': 0,
        'sowe': np.mean(sowe_similarities)
    }
    if FLAGS.eval_metrics:
        metric_dict = {
            **metric_dict, 'qa': np.mean(qa_scores),
            'lm': np.mean(lm_scores),
            'disc': np.mean(disc_scores)
        }
    # print(res)
    with open(FLAGS.log_dir + 'out_eval_BASELINE' +
              ("_train" if not FLAGS.eval_on_dev else "") + '.json',
              'w',
              encoding='utf-8') as fp:
        json.dump({"metrics": metric_dict, "results": res}, fp)

    print("F1: ", np.mean(f1s))
    print("BLEU: ", np.mean(bleus))
    print("NLL: ", 0)
    print("SOWE: ", np.mean(sowe_similarities))
    if FLAGS.eval_metrics:
        print("QA: ", np.mean(qa_scores))
        print("LM: ", np.mean(lm_scores))
        print("Disc: ", np.mean(disc_scores))

    print(missing, " ids were missing")