Ejemplo n.º 1
0
    def train_step(self, contexts, questions, answers):
        """Run a single optimisation step on one batch.

        Args:
            contexts: list of raw paragraph strings.
            questions: list of raw question strings (may contain loader.PAD /
                loader.EOS markers, which are stripped here).
            answers: list of (start_onehot, end_onehot) pairs, one per example.

        Returns:
            The batch loss value produced by ``self.model.loss``.
        """
        config = tf.app.flags.FLAGS

        # Strip padding/end-of-sequence markers the generator pipeline may
        # have left in the question text before featurisation.
        questions = [q.replace(loader.PAD, "").replace(loader.EOS, "") for q in questions]
        query = list(zip(contexts, questions))

        y1, y2 = zip(*answers)

        # Pad a short final batch with dummy entries so the graph always sees
        # exactly config.batch_size rows.
        # NOTE(review): the answers (y1, y2) are NOT padded to match — this
        # assumes the feed only ever gets full batches here; verify callers.
        if len(query) < config.batch_size:
            query += [["blank", "blank"] for _ in range(config.batch_size - len(query))]
        feats = [convert_to_features(config, q, self.word_dictionary, self.char_dictionary) for q in query]
        c, ch, q, qh = zip(*feats)
        fd = {'context:0': c,
              'question:0': q,
              'context_char:0': ch,
              'question_char:0': qh,
              'answer_index1:0': y1,
              'answer_index2:0': y2,
              'dropout:0': config.dropout}

        _, loss = self.sess.run([self.model.train_op, self.model.loss], feed_dict=fd)
        return loss
Ejemplo n.º 2
0
    def get_ans(self, contexts, questions):
        """Predict an answer span for each (context, question) pair and
        return it as a whitespace-joined token string."""
        config = tf.app.flags.FLAGS

        # Tokenise the contexts the same way the feature pipeline does, so
        # the predicted token indices map back onto these token lists.
        context_tokens = [
            word_tokenize(ctxt.replace("''", '" ').replace("``", '" '))
            for ctxt in contexts
        ]
        cleaned = [
            question.replace(loader.PAD, "").replace(loader.EOS, "")
            for question in questions
        ]
        pairs = list(zip(contexts, cleaned))

        # The graph expects exactly batch_size rows; remember how many are
        # real so the dummy padding rows can be dropped from the output.
        valid_count = min(len(pairs), config.batch_size)
        if len(pairs) < config.batch_size:
            pairs.extend(["blank", "blank"] for _ in range(config.batch_size - len(pairs)))

        features = [
            convert_to_features(config, pair, self.word_dictionary, self.char_dictionary)
            for pair in pairs
        ]
        ctx_ids, ctx_chars, ques_ids, ques_chars = zip(*features)
        feed = {'context:0': ctx_ids,
                'question:0': ques_ids,
                'context_char:0': ctx_chars,
                'question_char:0': ques_chars}

        start_preds, end_preds = self.sess.run(
            [self.model.yp1, self.model.yp2], feed_dict=feed)
        spans = list(zip(start_preds, end_preds))[:valid_count]
        # Span indices are inclusive, hence the +1 on the end slice bound.
        return [
            " ".join(context_tokens[idx][start:end + 1])
            for idx, (start, end) in enumerate(spans)
        ]
Ejemplo n.º 3
0
def main(_):
    """Optionally fine-tune a QANet checkpoint on SQuAD, then report
    exact-match (EM) and F1 on the dev set.

    Toggle ``trainable`` below to enable the training loop; evaluation
    always runs at the end against the best/loaded checkpoint.
    """
    from tqdm import tqdm
    FLAGS = tf.app.flags.FLAGS

    trainable = False

    squad_train_full = loader.load_squad_triples(path="./data/")
    squad_dev_full = loader.load_squad_triples(path="./data/",
                                               dev=True,
                                               ans_list=True)

    para_limit = FLAGS.test_para_limit
    ques_limit = FLAGS.test_ques_limit

    def _fits_limits(triple):
        # triple[0] is the context, triple[1] the question; keep only
        # examples whose token counts fit inside the model's input limits.
        return (len(word_tokenize(triple[0])) < para_limit and
                len(word_tokenize(triple[1])) < ques_limit)

    qa = QANetInstance()
    qa.load_from_chkpt("./models/saved/qanet2/", trainable=trainable)

    squad_train = [x for x in squad_train_full if _fits_limits(x)]
    squad_dev = [x for x in squad_dev_full if _fits_limits(x)]

    num_train_steps = len(squad_train) // FLAGS.batch_size
    num_eval_steps = len(squad_dev) // FLAGS.batch_size

    best_f1 = 0
    if trainable:
        run_id = str(int(time.time()))
        chkpt_path = FLAGS.model_dir + 'qanet/' + run_id
        if not os.path.exists(chkpt_path):
            os.makedirs(chkpt_path)

        summary_writer = tf.summary.FileWriter(
            FLAGS.log_directory + 'qanet/' + run_id, qa.model.graph)
        for i in tqdm(range(FLAGS.qa_num_epochs * num_train_steps)):
            # Position within the current epoch. BUG FIX: the original code
            # sliced with the global step i, so every epoch after the first
            # indexed past the end of squad_train and produced empty batches
            # (crashing the 4-way unpack below).
            step = i % num_train_steps
            if step == 0:
                print('Shuffling training set')
                np.random.shuffle(squad_train)

            this_batch = squad_train[step * FLAGS.batch_size:(step + 1) *
                                     FLAGS.batch_size]
            batch_contexts, batch_questions, batch_ans_text, batch_ans_charpos = zip(
                *this_batch)

            batch_answers = []
            for j, ctxt in enumerate(batch_contexts):
                # Map the answer's character offset to a token index, then
                # one-hot the start/end token positions over the paragraph axis.
                ans_start = char_pos_to_word(
                    ctxt.encode(), [t.encode() for t in word_tokenize(ctxt)],
                    batch_ans_charpos[j])
                ans_span = (np.eye(FLAGS.test_para_limit)[ans_start],
                            np.eye(FLAGS.test_para_limit)
                            [ans_start + len(word_tokenize(batch_ans_text[j])) -
                             1])
                batch_answers.append(ans_span)
            this_loss = qa.train_step(batch_contexts, batch_questions,
                                      batch_answers)

            if i % 50 == 0:
                losssummary = tf.Summary(value=[
                    tf.Summary.Value(tag="train_loss/loss",
                                     simple_value=np.mean(this_loss))
                ])
                summary_writer.add_summary(losssummary, global_step=i)

            # Periodic dev evaluation; checkpoint whenever dev F1 improves.
            if i > 0 and i % 1000 == 0:
                qa_f1s = []
                qa_em = []

                for j in tqdm(range(num_eval_steps)):
                    dev_batch = squad_dev[j * FLAGS.batch_size:(j + 1) *
                                          FLAGS.batch_size]

                    spans = qa.get_ans([x[0] for x in dev_batch],
                                       [x[1] for x in dev_batch])

                    for b in range(len(dev_batch)):
                        # NOTE(review): squad_dev was loaded with
                        # ans_list=True, yet [2] is normalised as a single
                        # string here (the final eval below treats it as a
                        # list) — confirm against loader.load_squad_triples.
                        gold = metrics.normalize_answer(dev_batch[b][2])
                        pred = metrics.normalize_answer(spans[b])
                        qa_f1s.append(metrics.f1(gold, pred))
                        qa_em.append(1.0 * (gold == pred))

                f1summary = tf.Summary(value=[
                    tf.Summary.Value(tag="dev_perf/f1",
                                     simple_value=np.mean(qa_f1s))
                ])

                summary_writer.add_summary(f1summary, global_step=i)
                if np.mean(qa_f1s) > best_f1:
                    print("New best F1! ", np.mean(qa_f1s), " Saving...")
                    best_f1 = np.mean(qa_f1s)
                    qa.saver.save(qa.sess, chkpt_path + '/model.checkpoint')

    # Final evaluation: score each prediction against every reference
    # answer and keep the best per-example F1/EM (standard SQuAD protocol).
    qa_f1s = []
    qa_em = []

    for i in tqdm(range(num_eval_steps)):
        this_batch = squad_dev[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]

        spans = qa.get_ans([x[0] for x in this_batch],
                           [x[1] for x in this_batch])

        for b in range(len(this_batch)):
            pred = metrics.normalize_answer(spans[b])
            this_f1s = [
                metrics.f1(metrics.normalize_answer(ans), pred)
                for ans in this_batch[b][2]
            ]
            this_em = [
                1.0 * (metrics.normalize_answer(ans) == pred)
                for ans in this_batch[b][2]
            ]
            qa_em.append(max(this_em))
            qa_f1s.append(max(this_f1s))

        if i == 0:
            # One-off debug dump of the first batch's scores and example.
            print(qa_f1s, qa_em)
            print(this_batch[0])
            print(spans[0])

    print('EM: ', np.mean(qa_em))
    print('F1: ', np.mean(qa_f1s))