예제 #1
0
파일: prepare.py 프로젝트: wly-thu/QGforQA
def prepare(config):
    """Preprocess the train/dev files and write features, eval data and meta.

    Each split's raw file is run through ``process_file`` (lower-casing words
    when ``config.lower_word`` is set). The resulting examples are turned into
    record files with ``build_features`` using the word dictionary stored at
    ``config.word_dictionary``. The per-split eval data and feature metadata
    are then persisted with ``save``.

    Args:
        config: Configuration object providing the input/output paths and
            the ``lower_word`` flag as attributes.
    """
    lower = config.lower_word

    # Step 1: tokenize / process the raw files (train first, then dev).
    examples_by_split = {}
    eval_by_split = {}
    for split_name, raw_path in (("train", config.train_file),
                                 ("dev", config.dev_file)):
        examples_by_split[split_name], eval_by_split[split_name] = \
            process_file(raw_path, lower_word=lower)

    # Step 2: load the word -> index vocabulary shared by both splits.
    with open(config.word_dictionary, "r") as vocab_fh:
        word_index = json.load(vocab_fh)

    # Step 3: build the feature records for each split.
    meta_by_split = {}
    for split_name, record_path in (("train", config.train_record_file),
                                    ("dev", config.dev_record_file)):
        meta_by_split[split_name] = build_features(
            config, examples_by_split[split_name], split_name, record_path,
            word_index)

    # Step 4: persist eval data first, then metadata, in train/dev order.
    outputs = (
        (config.train_eval_file, eval_by_split["train"], "train eval"),
        (config.dev_eval_file, eval_by_split["dev"], "dev eval"),
        (config.train_meta, meta_by_split["train"], "train meta"),
        (config.dev_meta, meta_by_split["dev"], "dev meta"),
    )
    for out_path, payload, note in outputs:
        save(out_path, payload, message=note)
예제 #2
0
파일: main.py 프로젝트: wly-thu/QGforQA
def train_qqp_qap(config):
    """Fine-tune the QG model with rewards from frozen QPC and QA models.

    Three independent graphs/sessions are built:

    * ``graph_qg``: the trainable question generator (``QGRLModel``) plus
      the train/dev input pipelines;
    * ``graph_qqp``: a ``QPCModel`` restored from the step recorded in
      ``config.best_ckpt_qpc`` and used only for scoring;
    * ``graph_qa``: a ``BidafQA`` model restored from the step recorded in
      ``config.best_ckpt_qa`` and used only for scoring.

    At each step the QG model produces a greedy question (``symbols``, the
    baseline) and a sampled question (``symbols_rl``).

    * On every 4th step the reward is ``model_qqp.pos_prob`` for
      (reference question, sampled question) minus that for
      (reference question, baseline question), and ``lamda`` is fed as 0.99.
    * On all other steps the reward is ``exp(-model_qa.batch_loss)`` on the
      sampled question minus the same quantity on the baseline question, and
      ``lamda`` is fed as 0.97.

    Every ``config.period`` steps the ML loss and mean rewards are logged.
    Every ``config.checkpoint`` steps the QG model is saved and evaluated on
    ``config.val_num_batches`` train batches and on the dev set. When dev BLEU
    improves, the step and score are written to ``config.best_ckpt``.

    Args:
        config: Experiment configuration object (paths, limits and
            hyper-parameters read as attributes).

    Exits the process with status 1 (after printing a message) when the
    best-checkpoint record of the QPC or QA model does not exist. Sessions
    and the summary writer are always closed, so buffered summaries are
    flushed to disk.
    """
    import sys  # sys.exit(): the site-provided exit() is meant for the REPL

    def _load_json(path):
        """Return the JSON document stored at ``path``."""
        with open(path, "r") as fh:
            return json.load(fh)

    def _load_matrix(path):
        """Load a JSON-encoded matrix as a float32 numpy array."""
        return np.array(_load_json(path), dtype=np.float32)

    word_mat = _load_matrix(config.word_emb_file)
    char_mat = _load_matrix(config.char_emb_file)
    pos_mat = _load_matrix(config.pos_emb_file)
    ner_mat = _load_matrix(config.ner_emb_file)
    label_mat = _load_matrix(config.label_emb_file)
    train_eval_file = _load_json(config.train_eval_file)
    dev_eval_file = _load_json(config.dev_eval_file)
    meta = _load_json(config.dev_meta)
    word_dictionary = _load_json(config.word_dictionary)
    char_dictionary = _load_json(config.char_dictionary)
    with h5py.File(config.embedding_file, 'r') as fin:
        embed_weights = fin["embedding"][...]
    # Row 0 stays all-zero (presumably the padding index; TODO confirm).
    # The pretrained embedding rows are shifted down by one.
    elmo_word_mat = np.zeros(
        (embed_weights.shape[0] + 1, embed_weights.shape[1]),
        dtype=np.float32)
    elmo_word_mat[1:, :] = embed_weights

    id2word = {idx: word for word, idx in word_dictionary.items()}
    dev_total = meta["total"]
    best_bleu, best_ckpt = 0., 0
    print("Building model...")
    parser = get_record_parser(config)
    graph_qg = tf.Graph()
    graph_qqp = tf.Graph()
    graph_qa = tf.Graph()
    with graph_qg.as_default():
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config)
        dev_dataset = get_dataset(config.dev_record_file, parser, config)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()
        model_qg = QGRLModel(config, word_mat, elmo_word_mat, label_mat,
                             pos_mat, ner_mat)
        model_qg.build_graph()
        model_qg.add_train_op()
    with graph_qqp.as_default() as g:
        model_qqp = QPCModel(config, dev=True, trainable=False, graph=g)
    with graph_qa.as_default():
        model_qa = BidafQA(config,
                           word_mat,
                           char_mat,
                           dev=True,
                           trainable=False)

    sess_qg = tf.Session(graph=graph_qg)
    sess_qqp = tf.Session(graph=graph_qqp)
    sess_qa = tf.Session(graph=graph_qa)
    writer = tf.summary.FileWriter(config.output_dir)
    try:
        # Restore the QG model (latest checkpoint plus best-BLEU bookkeeping).
        with sess_qg.as_default(), graph_qg.as_default():
            sess_qg.run(tf.global_variables_initializer())
            # Variables whose name contains "word_mat" are excluded from the
            # checkpoint (presumably the fixed embedding matrices, reloaded
            # from file on every run; TODO confirm).
            saver_qg = tf.train.Saver(
                max_to_keep=1000,
                var_list=[p for p in tf.global_variables()
                          if "word_mat" not in p.name])
            if os.path.exists(os.path.join(config.output_dir, "checkpoint")):
                latest_ckpt = tf.train.latest_checkpoint(config.output_dir)
                print(latest_ckpt)
                saver_qg.restore(sess_qg, latest_ckpt)
            if os.path.exists(config.best_ckpt):
                best_qg_ckpt = _load_json(config.best_ckpt)
                best_bleu = float(best_qg_ckpt["best_bleu"])
                best_ckpt = int(best_qg_ckpt["best_ckpt"])

        # Restore the frozen paraphrase classifier used for the QPP reward.
        with sess_qqp.as_default(), graph_qqp.as_default():
            sess_qqp.run(tf.global_variables_initializer())
            saver_qqp = tf.train.Saver(
                max_to_keep=1000,
                var_list=[p for p in tf.global_variables()
                          if "word_mat" not in p.name])
            if not os.path.exists(config.best_ckpt_qpc):
                print("NO the best QPC model to load!")
                sys.exit(1)
            # A dedicated name keeps the QG ``best_ckpt`` loaded above intact.
            qpc_step = int(_load_json(config.best_ckpt_qpc)["best_ckpt"])
            qpc_path = "{}/model_{}.ckpt".format(config.output_dir_qpc,
                                                 qpc_step)
            print(qpc_path)
            saver_qqp.restore(sess_qqp, qpc_path)

        # Build and restore the frozen QA model used for the QAP reward.
        with sess_qa.as_default(), graph_qa.as_default():
            model_qa.build_graph()
            model_qa.add_train_op()
            sess_qa.run(tf.global_variables_initializer())
            saver_qa = tf.train.Saver(
                max_to_keep=1000,
                var_list=[p for p in tf.global_variables()
                          if "word_mat" not in p.name])
            if not os.path.exists(config.best_ckpt_qa):
                print("NO the best QA model to load!")
                sys.exit(1)
            qa_step = int(_load_json(config.best_ckpt_qa)["best_ckpt"])
            qa_path = "{}/model_{}.ckpt".format(config.output_dir_qa, qa_step)
            print(qa_path)
            saver_qa.restore(sess_qa, qa_path)

        # The QPC ``label`` placeholder is fed all zeros. Presumably
        # ``pos_prob`` does not depend on it (TODO confirm). Built once
        # because its value never changes.
        label = np.zeros((config.batch_size, 2), dtype=np.int32)

        def _qpp_prob(reference, candidate, qa_id):
            """Return ``model_qqp.pos_prob`` for (reference, candidate)."""
            return sess_qqp.run(model_qqp.pos_prob,
                                feed_dict={
                                    model_qqp.que1: reference,
                                    model_qqp.que2: candidate,
                                    model_qqp.label: label,
                                    model_qqp.qa_id: qa_id,
                                })

        def _qa_batch_loss(para_unk, para_char, que, que_char, y1, y2,
                           qa_id):
            """Return ``model_qa.batch_loss`` for the given questions."""
            return sess_qa.run(model_qa.batch_loss,
                               feed_dict={
                                   model_qa.para: para_unk,
                                   model_qa.para_char: para_char,
                                   model_qa.que: que,
                                   model_qa.que_char: que_char,
                                   model_qa.y1: y1,
                                   model_qa.y2: y2,
                                   model_qa.qa_id: qa_id,
                               })

        def _add_scalar(tag, value, step):
            """Write one scalar summary to the TensorBoard writer."""
            writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag=tag, simple_value=value),
                ]), step)

        global_step = max(sess_qg.run(model_qg.global_step), 1)
        train_next_element = train_iterator.get_next()
        for _ in tqdm(range(global_step, config.num_steps + 1)):
            global_step = sess_qg.run(model_qg.global_step) + 1
            (para, para_unk, para_char, que, que_unk, que_char, labels,
             pos_tags, ner_tags, que_labels, que_pos_tags, que_ner_tags,
             y1, y2, qa_id) = sess_qg.run(train_next_element)
            # Greedy (baseline) and sampled questions from the QG model.
            symbols, symbols_rl = sess_qg.run(
                [model_qg.symbols, model_qg.symbols_rl],
                feed_dict={
                    model_qg.para: para,
                    model_qg.para_unk: para_unk,
                    model_qg.que: que,
                    model_qg.labels: labels,
                    model_qg.pos_tags: pos_tags,
                    model_qg.ner_tags: ner_tags,
                    model_qg.qa_id: qa_id
                })
            # Convert both generated questions into QA-model inputs.
            (que_base, que_unk_base, que_char_base, que_rl, que_unk_rl,
             que_char_rl) = format_generated_ques_for_qa(
                 train_eval_file, qa_id, symbols, symbols_rl,
                 config.batch_size, config.ques_limit, config.char_limit,
                 id2word, char_dictionary)
            if global_step % 4 == 0:
                # QPP reward: paraphrase probability vs. the reference.
                reward_base = _qpp_prob(que_unk, que_unk_base, qa_id)
                reward_rl = _qpp_prob(que_unk, que_unk_rl, qa_id)
                mixing_ratio = 0.99
            else:
                # QAP reward: exp(-loss) of the QA model on each question.
                base_qa_loss = _qa_batch_loss(para_unk, para_char,
                                              que_unk_base, que_char_base,
                                              y1, y2, qa_id)
                qa_loss = _qa_batch_loss(para_unk, para_char, que_unk_rl,
                                         que_char_rl, y1, y2, qa_id)
                reward_base = [np.exp(-x) for x in base_qa_loss]
                reward_rl = [np.exp(-x) for x in qa_loss]
                mixing_ratio = 0.97
            # Advantage of the sampled question over the greedy baseline.
            reward = [rr - rb for rr, rb in zip(reward_rl, reward_base)]
            # Train with the RL objective (ML/RL weighting fed as ``lamda``).
            loss_ml, _ = sess_qg.run(
                [model_qg.loss_ml, model_qg.train_op],
                feed_dict={
                    model_qg.para: para,
                    model_qg.para_unk: para_unk,
                    model_qg.labels: labels,
                    model_qg.pos_tags: pos_tags,
                    model_qg.ner_tags: ner_tags,
                    model_qg.que: que,
                    model_qg.dropout: config.dropout,
                    model_qg.qa_id: qa_id,
                    model_qg.sampled_que: que_rl,
                    model_qg.reward: reward,
                    model_qg.lamda: mixing_ratio
                })
            if global_step % config.period == 0:
                _add_scalar("model/loss", loss_ml, global_step)
                _add_scalar("model/reward_base", np.mean(reward_base),
                            global_step)
                _add_scalar("model/reward_rl", np.mean(reward_rl),
                            global_step)
            if global_step % config.checkpoint == 0:
                filename = os.path.join(config.output_dir,
                                        "model_{}.ckpt".format(global_step))
                saver_qg.save(sess_qg, filename)

                metrics = evaluate_batch(config,
                                         model_qg,
                                         config.val_num_batches,
                                         train_eval_file,
                                         sess_qg,
                                         train_iterator,
                                         id2word,
                                         evaluate_func=evaluate_simple)
                write_metrics(metrics, writer, global_step, "train")

                metrics = evaluate_batch(config,
                                         model_qg,
                                         dev_total // config.batch_size + 1,
                                         dev_eval_file,
                                         sess_qg,
                                         dev_iterator,
                                         id2word,
                                         evaluate_func=evaluate_simple)
                write_metrics(metrics, writer, global_step, "dev")

                bleu = metrics["bleu"]
                if bleu > best_bleu:
                    best_bleu, best_ckpt = bleu, global_step
                    save(config.best_ckpt, {
                        "best_bleu": str(best_bleu),
                        "best_ckpt": str(best_ckpt)
                    }, config.best_ckpt)
    finally:
        # FileWriter buffers events; closing flushes the final summaries.
        writer.close()
        sess_qa.close()
        sess_qqp.close()
        sess_qg.close()
예제 #3
0
파일: main.py 프로젝트: wly-thu/QGforQA
def train_rl(config):
    """Fine-tune the QG model (``QGRLModel``) with metric-based RL rewards.

    For each training batch the model produces a greedy question
    (``symbols``, used as the baseline) and a sampled question
    (``symbols_rl``). ``evaluate_rl`` is called with ``train_eval_file`` and
    ``metric=config.rl_metric``. It returns the reward fed to
    ``model_qg.reward``, the sampled and baseline rewards that are logged, and
    the sampled question fed to ``model_qg.sampled_que``.

    Every ``config.period`` steps the ML loss and the mean sampled/baseline
    rewards are logged to TensorBoard. Every ``config.checkpoint`` steps the
    model is saved and evaluated on ``config.val_num_batches`` train batches
    and on the dev set. ``config.best_ckpt`` is updated when dev BLEU
    improves.

    Args:
        config: Experiment configuration object (paths, limits and
            hyper-parameters read as attributes).

    The session and the summary writer are always closed on exit, so
    buffered summaries are flushed to disk.
    """
    def _load_json(path):
        """Return the JSON document stored at ``path``."""
        with open(path, "r") as fh:
            return json.load(fh)

    def _load_matrix(path):
        """Load a JSON-encoded matrix as a float32 numpy array."""
        return np.array(_load_json(path), dtype=np.float32)

    word_mat = _load_matrix(config.word_emb_file)
    pos_mat = _load_matrix(config.pos_emb_file)
    ner_mat = _load_matrix(config.ner_emb_file)
    label_mat = _load_matrix(config.label_emb_file)
    train_eval_file = _load_json(config.train_eval_file)
    dev_eval_file = _load_json(config.dev_eval_file)
    meta = _load_json(config.dev_meta)
    word_dictionary = _load_json(config.word_dictionary)
    with h5py.File(config.embedding_file, 'r') as fin:
        embed_weights = fin["embedding"][...]
    # Row 0 stays all-zero (presumably the padding index; TODO confirm).
    # The pretrained embedding rows are shifted down by one.
    elmo_word_mat = np.zeros(
        (embed_weights.shape[0] + 1, embed_weights.shape[1]),
        dtype=np.float32)
    elmo_word_mat[1:, :] = embed_weights

    id2word = {idx: word for word, idx in word_dictionary.items()}
    best_bleu, best_ckpt = 0., 0
    dev_total = meta["total"]
    print("Building model...")
    parser = get_record_parser(config)
    graph_qg = tf.Graph()
    with graph_qg.as_default():
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config.batch_size)
        dev_dataset = get_dataset(config.dev_record_file, parser,
                                  config.batch_size)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model_qg = QGRLModel(config, word_mat, elmo_word_mat, label_mat,
                             pos_mat, ner_mat)
        model_qg.build_graph()
        model_qg.add_train_op()

    sess_qg = tf.Session(graph=graph_qg)
    writer = tf.summary.FileWriter(config.output_dir)
    try:
        with sess_qg.as_default(), graph_qg.as_default():
            sess_qg.run(tf.global_variables_initializer())
            # Variables whose name contains "word_mat" are excluded from the
            # checkpoint (presumably the fixed embedding matrices; TODO confirm).
            saver_qg = tf.train.Saver(
                max_to_keep=1000,
                var_list=[p for p in tf.global_variables()
                          if "word_mat" not in p.name])
            if os.path.exists(os.path.join(config.output_dir, "checkpoint")):
                saver_qg.restore(sess_qg,
                                 tf.train.latest_checkpoint(config.output_dir))
            if os.path.exists(config.best_ckpt):
                best_qg_ckpt = _load_json(config.best_ckpt)
                best_bleu = float(best_qg_ckpt["best_bleu"])
                best_ckpt = int(best_qg_ckpt["best_ckpt"])

        def _add_scalar(tag, value, step):
            """Write one scalar summary to the TensorBoard writer."""
            writer.add_summary(
                tf.Summary(value=[
                    tf.Summary.Value(tag=tag, simple_value=value),
                ]), step)

        global_step = max(sess_qg.run(model_qg.global_step), 1)
        train_next_element = train_iterator.get_next()
        for _ in tqdm(range(global_step, config.num_steps + 1)):
            global_step = sess_qg.run(model_qg.global_step) + 1
            (para, para_unk, para_char, que, que_unk, que_char, labels,
             pos_tags, ner_tags, que_labels, que_pos_tags, que_ner_tags,
             y1, y2, qa_id) = sess_qg.run(train_next_element)
            # Greedy-search questions (baseline) and sampled questions.
            symbols, symbols_rl = sess_qg.run(
                [model_qg.symbols, model_qg.symbols_rl],
                feed_dict={
                    model_qg.para: para,
                    model_qg.para_unk: para_unk,
                    model_qg.que: que,
                    model_qg.labels: labels,
                    model_qg.pos_tags: pos_tags,
                    model_qg.ner_tags: ner_tags,
                    model_qg.qa_id: qa_id
                })
            # Rewards and the formatted sampled questions.
            reward, reward_rl, reward_base, que_rl = evaluate_rl(
                train_eval_file,
                qa_id,
                symbols,
                symbols_rl,
                id2word,
                metric=config.rl_metric)
            # Policy-gradient update of the QG model.
            loss_ml, _ = sess_qg.run(
                [model_qg.loss_ml, model_qg.train_op],
                feed_dict={
                    model_qg.para: para,
                    model_qg.para_unk: para_unk,
                    model_qg.que: que,
                    model_qg.labels: labels,
                    model_qg.pos_tags: pos_tags,
                    model_qg.ner_tags: ner_tags,
                    model_qg.dropout: config.dropout,
                    model_qg.qa_id: qa_id,
                    model_qg.sampled_que: que_rl,
                    model_qg.reward: reward
                })
            if global_step % config.period == 0:
                _add_scalar("model/loss", loss_ml, global_step)
                _add_scalar("model/reward_base", np.mean(reward_base),
                            global_step)
                _add_scalar("model/reward_rl", np.mean(reward_rl),
                            global_step)
            if global_step % config.checkpoint == 0:
                filename = os.path.join(config.output_dir,
                                        "model_{}.ckpt".format(global_step))
                saver_qg.save(sess_qg, filename)

                metrics = evaluate_batch(config,
                                         model_qg,
                                         config.val_num_batches,
                                         train_eval_file,
                                         sess_qg,
                                         train_iterator,
                                         id2word,
                                         evaluate_func=evaluate_simple)
                write_metrics(metrics, writer, global_step, "train")

                metrics = evaluate_batch(config,
                                         model_qg,
                                         dev_total // config.batch_size + 1,
                                         dev_eval_file,
                                         sess_qg,
                                         dev_iterator,
                                         id2word,
                                         evaluate_func=evaluate_simple)
                write_metrics(metrics, writer, global_step, "dev")

                bleu = metrics["bleu"]
                if bleu > best_bleu:
                    best_bleu, best_ckpt = bleu, global_step
                    save(config.best_ckpt, {
                        "best_bleu": str(best_bleu),
                        "best_ckpt": str(best_ckpt)
                    }, config.best_ckpt)
    finally:
        # FileWriter buffers events; closing flushes the final summaries.
        writer.close()
        sess_qg.close()
예제 #4
0
파일: main.py 프로젝트: wly-thu/QGforQA
def train(config):
    """Train the QG model (``QGModel``) with its supervised loss.

    Loads the embedding matrices, eval files and vocabulary, builds the
    train/dev input pipelines and the model in a single graph, then runs
    ``config.num_steps`` optimisation steps. Training resumes from the latest
    checkpoint in ``config.output_dir`` when one exists.

    Every ``config.period`` steps the loss is logged to TensorBoard. Every
    ``config.checkpoint`` steps the model is saved and evaluated on
    ``config.val_num_batches`` train batches and on the dev set.
    ``config.best_ckpt`` is updated when dev BLEU improves.

    Args:
        config: Experiment configuration object (paths, limits and
            hyper-parameters read as attributes).

    The summary writer is always closed on exit, so buffered summaries are
    flushed to disk.
    """
    def _load_json(path):
        """Return the JSON document stored at ``path``."""
        with open(path, "r") as fh:
            return json.load(fh)

    def _load_matrix(path):
        """Load a JSON-encoded matrix as a float32 numpy array."""
        return np.array(_load_json(path), dtype=np.float32)

    word_mat = _load_matrix(config.word_emb_file)
    pos_mat = _load_matrix(config.pos_emb_file)
    ner_mat = _load_matrix(config.ner_emb_file)
    label_mat = _load_matrix(config.label_emb_file)
    train_eval_file = _load_json(config.train_eval_file)
    dev_eval_file = _load_json(config.dev_eval_file)
    meta = _load_json(config.dev_meta)
    word_dictionary = _load_json(config.word_dictionary)
    with h5py.File(config.embedding_file, 'r') as fin:
        embed_weights = fin["embedding"][...]
    # Row 0 stays all-zero (presumably the padding index; TODO confirm).
    # The pretrained embedding rows are shifted down by one.
    elmo_word_mat = np.zeros(
        (embed_weights.shape[0] + 1, embed_weights.shape[1]),
        dtype=np.float32)
    elmo_word_mat[1:, :] = embed_weights

    id2word = {idx: word for word, idx in word_dictionary.items()}
    dev_total = meta["total"]
    print("Building model...")
    parser = get_record_parser(config)
    graph = tf.Graph()
    best_bleu, best_ckpt = 0., 0
    with graph.as_default():
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config.batch_size)
        dev_dataset = get_dataset(config.dev_record_file, parser,
                                  config.batch_size)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model = QGModel(config, word_mat, elmo_word_mat, label_mat, pos_mat,
                        ner_mat)
        model.build_graph()
        model.add_train_op()

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        with tf.Session(config=sess_config) as sess:
            writer = tf.summary.FileWriter(config.output_dir)
            try:
                sess.run(tf.global_variables_initializer())
                # Variables whose name contains "word_mat" are excluded from
                # the checkpoint (presumably the fixed embedding matrices;
                # TODO confirm).
                saver = tf.train.Saver(
                    max_to_keep=1000,
                    var_list=[p for p in tf.global_variables()
                              if "word_mat" not in p.name])
                if os.path.exists(
                        os.path.join(config.output_dir, "checkpoint")):
                    saver.restore(
                        sess, tf.train.latest_checkpoint(config.output_dir))
                if os.path.exists(config.best_ckpt):
                    best_qg_ckpt = _load_json(config.best_ckpt)
                    best_bleu = float(best_qg_ckpt["best_bleu"])
                    best_ckpt = int(best_qg_ckpt["best_ckpt"])

                global_step = max(sess.run(model.global_step), 1)
                train_next_element = train_iterator.get_next()
                for _ in tqdm(range(global_step, config.num_steps + 1)):
                    global_step = sess.run(model.global_step) + 1
                    (para, para_unk, para_char, que, que_unk, que_char,
                     labels, pos_tags, ner_tags, que_labels, que_pos_tags,
                     que_ner_tags, y1, y2, qa_id) = sess.run(
                         train_next_element)
                    loss, _ = sess.run(
                        [model.loss, model.train_op],
                        feed_dict={
                            model.para: para,
                            model.para_unk: para_unk,
                            model.que: que,
                            model.labels: labels,
                            model.pos_tags: pos_tags,
                            model.ner_tags: ner_tags,
                            model.dropout: config.dropout,
                            model.qa_id: qa_id,
                        })
                    if global_step % config.period == 0:
                        loss_sum = tf.Summary(value=[
                            tf.Summary.Value(tag="model/loss",
                                             simple_value=loss),
                        ])
                        writer.add_summary(loss_sum, global_step)
                    if global_step % config.checkpoint == 0:
                        filename = os.path.join(
                            config.output_dir,
                            "model_{}.ckpt".format(global_step))
                        saver.save(sess, filename)

                        metrics = evaluate_batch(
                            config,
                            model,
                            config.val_num_batches,
                            train_eval_file,
                            sess,
                            train_iterator,
                            id2word,
                            evaluate_func=evaluate_simple)
                        write_metrics(metrics, writer, global_step, "train")

                        metrics = evaluate_batch(
                            config,
                            model,
                            dev_total // config.batch_size + 1,
                            dev_eval_file,
                            sess,
                            dev_iterator,
                            id2word,
                            evaluate_func=evaluate_simple)
                        write_metrics(metrics, writer, global_step, "dev")

                        bleu = metrics["bleu"]
                        if bleu > best_bleu:
                            best_bleu, best_ckpt = bleu, global_step
                            save(
                                config.best_ckpt, {
                                    "best_bleu": str(best_bleu),
                                    "best_ckpt": str(best_ckpt)
                                }, config.best_ckpt)
            finally:
                # FileWriter buffers events; closing flushes the final
                # summaries before the session is torn down.
                writer.close()
예제 #5
0
def prepare(config):
    """Preprocess train/dev/test paragraph+question files into features.

    Each split's paragraph and question files are processed with
    ``process_file`` (truncated to ``config.para_limit``). The word, label,
    POS and NER vocabularies are loaded and their sizes printed. Feature
    records are built with ``build_features`` for each split, and the
    per-split eval data and metadata are written with ``save``.

    Args:
        config: Configuration object providing file paths and the
            ``para_limit``, ``ques_limit`` and ``max_input_length`` limits as
            attributes.
    """
    split_inputs = (
        ("train", config.train_para_file, config.train_question_file),
        ("dev", config.dev_para_file, config.dev_question_file),
        ("test", config.test_para_file, config.test_question_file),
    )
    examples_of = {}
    eval_of = {}
    for split_name, para_path, question_path in split_inputs:
        examples_of[split_name], eval_of[split_name] = process_file(
            para_path, question_path, para_limit=config.para_limit)

    def _load_vocab(vocab_path, size_template):
        """Load a JSON vocabulary and report how many entries it has."""
        with open(vocab_path, "r") as vocab_fh:
            vocab = json.load(vocab_fh)
            print(size_template.format(len(vocab)))
        return vocab

    word_vocab = _load_vocab(config.word_dictionary, "num of words {}")
    label_vocab = _load_vocab(config.label_dictionary, "num of labels {}")
    pos_vocab = _load_vocab(config.pos_dictionary, "num of pos tags {}")
    ner_vocab = _load_vocab(config.ner_dictionary, "num of ner tags {}")

    split_records = (
        ("train", config.train_record_file),
        ("dev", config.dev_record_file),
        ("test", config.test_record_file),
    )
    meta_of = {}
    for split_name, record_path in split_records:
        meta_of[split_name] = build_features(
            config, examples_of[split_name], split_name, record_path,
            word_vocab, pos_vocab, ner_vocab, label_vocab,
            config.para_limit, config.ques_limit, config.max_input_length)

    # Eval data for every split is written before any metadata.
    outputs = (
        (config.train_eval_file, eval_of["train"], "train eval"),
        (config.dev_eval_file, eval_of["dev"], "dev eval"),
        (config.test_eval_file, eval_of["test"], "test eval"),
        (config.train_meta, meta_of["train"], "train meta"),
        (config.dev_meta, meta_of["dev"], "dev meta"),
        (config.test_meta, meta_of["test"], "test meta"),
    )
    for out_path, payload, note in outputs:
        save(out_path, payload, message=note)
예제 #6
0
파일: main.py 프로젝트: wly-thu/QGforQA
def train(config):
    """Train the question-generation model on BERT paragraph embeddings.

    Two separate TF graphs and sessions are used:

    * ``graph_para`` holds the train/dev input pipelines and a ``BertEmb``
      model whose variables are mapped onto ``config.init_checkpoint``;
    * ``graph_qg`` holds the ``QGModel``; its ``train_op`` is the only
      training op run in this function.

    Each step pulls a batch from the training iterator, computes
    ``model_para.bert_emb`` for ``para_unk`` and feeds it, together with the
    other batch fields, into the QG model. Every ``config.checkpoint`` steps
    the QG model is saved to ``config.output_dir`` and evaluated on
    ``config.val_num_batches`` training batches and on the dev set; the best
    dev BLEU and its step are persisted to ``config.best_ckpt``. Training
    resumes from the latest checkpoint in ``config.output_dir`` and from
    ``config.best_ckpt`` when those exist.

    Args:
        config: run configuration (embedding/dictionary/record file paths,
            sequence limits, batch size, step counts, output directory).
    """
    with open(config.word_emb_file, "rb") as fh:
        word_mat = np.array(pkl.load(fh), dtype=np.float32)
    with open(config.pos_emb_file, "r") as fh:
        pos_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.ner_emb_file, "r") as fh:
        ner_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.label_emb_file, "r") as fh:
        label_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.word_dictionary, "r") as fh:
        word_dictionary = json.load(fh)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)
    with open(config.map_to_orig, "rb") as fh:
        map_to_orig = pkl.load(fh)

    # Inverse vocabulary, used by the evaluator to turn ids back into words.
    id2word = {word_dictionary[w]: w for w in word_dictionary}
    dev_total = meta["total"]
    best_bleu, best_ckpt = 0., 0
    print("Building model...")
    parser = get_record_parser(config)
    graph_para = tf.Graph()
    graph_qg = tf.Graph()
    with graph_para.as_default() as g:
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config.batch_size)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_dataset = get_dataset(config.dev_record_file, parser,
                                  config.batch_size)
        dev_iterator = dev_dataset.make_one_shot_iterator()
        model_para = BertEmb(config, config.para_limit + 2, graph=g)
    with graph_qg.as_default() as g:
        model_qg = QGModel(config, word_mat, label_mat, pos_mat, ner_mat)
        model_qg.build_graph()
        model_qg.add_train_op()

    sess_para = tf.Session(graph=graph_para)
    sess_qg = tf.Session(graph=graph_qg)
    writer = None
    # Both sessions and the summary writer hold resources; release them (and
    # flush pending summaries) on every exit path, including exceptions.
    try:
        with sess_para.as_default():
            with graph_para.as_default():
                print("init from pretrained bert..")
                tvars = tf.trainable_variables()
                # Map trainable variables onto the BERT checkpoint before the
                # initializer below is run.
                (assignment_map, _) = \
                    modeling.get_assignment_map_from_checkpoint(
                        tvars, config.init_checkpoint)
                tf.train.init_from_checkpoint(config.init_checkpoint,
                                              assignment_map)
                sess_para.run(tf.global_variables_initializer())
        with sess_qg.as_default():
            with graph_qg.as_default():
                sess_qg.run(tf.global_variables_initializer())
                # Variables whose name contains "word_mat" are neither saved
                # nor restored by this saver.
                saver_qg = tf.train.Saver(max_to_keep=1000,
                                          var_list=[
                                              p for p in tf.global_variables()
                                              if "word_mat" not in p.name
                                          ])
                if os.path.exists(os.path.join(config.output_dir,
                                               "checkpoint")):
                    print(tf.train.latest_checkpoint(config.output_dir))
                    saver_qg.restore(
                        sess_qg, tf.train.latest_checkpoint(config.output_dir))
                if os.path.exists(config.best_ckpt):
                    with open(config.best_ckpt, "r") as fh:
                        best_qg_ckpt = json.load(fh)
                        best_bleu, best_ckpt = float(
                            best_qg_ckpt["best_bleu"]), int(
                                best_qg_ckpt["best_ckpt"])
                        print("best_bleu:{}, best_ckpt:{}".format(
                            best_bleu, best_ckpt))

        writer = tf.summary.FileWriter(config.output_dir)
        global_step = max(sess_qg.run(model_qg.global_step), 1)
        train_next_element = train_iterator.get_next()
        for _ in tqdm(range(global_step, config.num_steps + 1)):
            global_step = sess_qg.run(model_qg.global_step) + 1
            para, para_unk, ques, labels, pos_tags, ner_tags, qa_id = \
                sess_para.run(train_next_element)
            # BERT features are computed in the paragraph session and handed
            # to the QG graph through feed_dict.
            para_emb = sess_para.run(model_para.bert_emb,
                                     feed_dict={model_para.input_ids: para_unk})
            loss, _ = sess_qg.run(
                [model_qg.loss, model_qg.train_op],
                feed_dict={
                    model_qg.para: para,
                    model_qg.bert_para: para_emb,
                    model_qg.que: ques,
                    model_qg.labels: labels,
                    model_qg.pos_tags: pos_tags,
                    model_qg.ner_tags: ner_tags,
                    model_qg.dropout: config.dropout,
                    model_qg.qa_id: qa_id
                })
            if global_step % config.period == 0:
                loss_sum = tf.Summary(value=[
                    tf.Summary.Value(tag="model/loss", simple_value=loss),
                ])
                writer.add_summary(loss_sum, global_step)
            if global_step % config.checkpoint == 0:
                filename = os.path.join(config.output_dir,
                                        "model_{}.ckpt".format(global_step))
                saver_qg.save(sess_qg, filename)

                metrics = evaluate_batch(config,
                                         model_para,
                                         model_qg,
                                         sess_para,
                                         sess_qg,
                                         config.val_num_batches,
                                         train_eval_file,
                                         train_iterator,
                                         id2word,
                                         map_to_orig,
                                         evaluate_func=evaluate_simple)
                write_metrics(metrics, writer, global_step, "train")

                metrics = evaluate_batch(config,
                                         model_para,
                                         model_qg,
                                         sess_para,
                                         sess_qg,
                                         dev_total // config.batch_size + 1,
                                         dev_eval_file,
                                         dev_iterator,
                                         id2word,
                                         map_to_orig,
                                         evaluate_func=evaluate_simple)
                write_metrics(metrics, writer, global_step, "dev")
                bleu = metrics["bleu"]
                if bleu > best_bleu:
                    best_bleu, best_ckpt = bleu, global_step
                    save(config.best_ckpt, {
                        "best_bleu": str(best_bleu),
                        "best_ckpt": str(best_ckpt)
                    }, config.best_ckpt)
    finally:
        if writer is not None:
            writer.close()
        sess_qg.close()
        sess_para.close()
예제 #7
0
def get_vocab(config):
    """Build vocabularies and embeddings from the train and dev splits.

    For each split, the paragraph-side files ``<para>.tok``, ``<para>.pos``,
    ``<para>.ner``, ``<para>.label`` and the question-side files
    ``<que>.tok``, ``<que>.pos``, ``<que>.ner`` are read one line at a time
    in lockstep. Reading stops as soon as the paragraph or question token
    file is exhausted. Paragraph and question tokens (lower-cased when
    ``config.lower_word`` is set), their characters, POS tags, NER tags and
    answer labels are counted.

    From those counts the word/char/pos/ner/label embedding matrices and
    index dictionaries are built and saved to the paths in ``config``.
    Finally, ELMo token embeddings are dumped for ``config.vocab_file``.

    Args:
        config: run configuration holding input file prefixes, embedding
            sizes/dimensions and output paths.
    """
    print("Get the vocabulary...")
    word_counter, char_counter = Counter(), Counter()
    pos_counter, ner_counter, label_counter = Counter(), Counter(), Counter()
    splits = (
        (config.train_para_file, config.train_question_file),
        (config.dev_para_file, config.dev_question_file),
    )
    for para_prefix, que_prefix in splits:
        with open("{}.tok".format(para_prefix), 'r') as para_tok, \
                open("{}.tok".format(que_prefix), 'r') as que_tok, \
                open("{}.pos".format(para_prefix), 'r') as para_pos, \
                open("{}.pos".format(que_prefix), 'r') as que_pos, \
                open("{}.ner".format(para_prefix), 'r') as para_ner, \
                open("{}.ner".format(que_prefix), 'r') as que_ner, \
                open("{}.label".format(para_prefix), 'r') as para_label:
            while True:
                # Pull one line from every file so they stay aligned.
                para_line, que_line = para_tok.readline(), que_tok.readline()
                pos_line, que_pos_line = para_pos.readline(), que_pos.readline()
                ner_line, que_ner_line = para_ner.readline(), que_ner.readline()
                label_line = para_label.readline()
                # readline() yields "" only at EOF; blank lines still hold "\n".
                if not (que_line and para_line):
                    break
                if config.lower_word:
                    para_line, que_line = para_line.lower(), que_line.lower()
                tokens = (para_line.strip().split(' ') +
                          que_line.strip().split(' '))
                word_counter.update(tokens)
                char_counter.update(ch for tok in tokens for ch in tok)
                pos_counter.update(pos_line.strip().split(' ') +
                                   que_pos_line.strip().split(' '))
                ner_counter.update(ner_line.strip().split(' ') +
                                   que_ner_line.strip().split(' '))
                label_counter.update(label_line.strip().split(' '))

    word_emb_mat, word2idx_dict, unk_num = get_word_embedding(
        word_counter,
        emb_file=config.glove_word_file,
        emb_size=config.glove_word_size,
        vocab_size=config.vocab_size_limit,
        vec_size=config.glove_dim,
        vocab_file=config.vocab_file)
    char_emb_mat, char2idx_dict = get_tag_embedding(
        char_counter, "char", vec_size=config.char_dim)
    pos_emb_mat, pos2idx_dict = get_tag_embedding(
        pos_counter, "pos", vec_size=config.pos_dim)
    ner_emb_mat, ner2idx_dict = get_tag_embedding(
        ner_counter, "ner", vec_size=config.ner_dim)
    label_emb_mat, label2idx_dict = get_tag_embedding(
        label_counter, "label", vec_size=config.label_dim)

    n_chars = char_emb_mat.shape[0]
    print("{} out of {} are not in glove".format(unk_num, len(word2idx_dict)))
    print("{} chars".format(n_chars))
    print("{} pos tags, {} ner tags, {} answer labels, {} chars".format(
        pos_emb_mat.shape[0], ner_emb_mat.shape[0], label_emb_mat.shape[0],
        n_chars))

    outputs = (
        (config.word_emb_file, word_emb_mat, "word embedding"),
        (config.char_emb_file, char_emb_mat, "char embedding"),
        (config.pos_emb_file, pos_emb_mat, "pos embedding"),
        (config.ner_emb_file, ner_emb_mat, "ner embedding"),
        (config.label_emb_file, label_emb_mat, "label embedding"),
        (config.word_dictionary, word2idx_dict, "word dictionary"),
        (config.char_dictionary, char2idx_dict, "char dictionary"),
        (config.pos_dictionary, pos2idx_dict, "pos dictionary"),
        (config.ner_dictionary, ner2idx_dict, "ner dictionary"),
        (config.label_dictionary, label2idx_dict, "label dictionary"),
    )
    for out_path, payload, what in outputs:
        save(out_path, payload, message=what)

    print("Dump elmo word embedding...")
    dump_token_embeddings(config.vocab_file, config.elmo_options_file,
                          config.elmo_weight_file, config.embedding_file)
예제 #8
0
파일: prepare.py 프로젝트: wly-thu/QGforQA
def prepare(config):
    """Convert the raw train/dev/test files into feature records and metadata.

    Each split's paragraph, question and answer files are parsed with
    ``process_file``. The word, char, label, pos and ner index dictionaries
    are loaded from JSON, and ``build_features`` writes the split's record
    file; the test split is built with ``is_test=True``. The eval data and
    the metadata returned for every split are then saved to the paths given
    in ``config``.

    Args:
        config: run configuration holding input, dictionary, record and
            output paths plus the ``lower_word`` flag.
    """
    def _read_json(path):
        # Load a single JSON document from ``path``.
        with open(path, "r") as fh:
            return json.load(fh)

    # process files
    train_examples, train_eval = process_file(
        config.train_para_file, config.train_question_file,
        config.train_answer_file, lower_word=config.lower_word)
    dev_examples, dev_eval = process_file(
        config.dev_para_file, config.dev_question_file,
        config.dev_answer_file, lower_word=config.lower_word)
    test_examples, test_eval = process_file(
        config.test_para_file, config.test_question_file,
        config.test_answer_file, lower_word=config.lower_word)

    word2idx_dict = _read_json(config.word_dictionary)
    char2idx_dict = _read_json(config.char_dictionary)
    label2idx_dict = _read_json(config.label_dictionary)
    pos2idx_dict = _read_json(config.pos_dictionary)
    ner2idx_dict = _read_json(config.ner_dictionary)
    # build_features receives the vocabularies positionally in this order.
    vocabs = (word2idx_dict, pos2idx_dict, ner2idx_dict, label2idx_dict,
              char2idx_dict)

    train_meta = build_features(config, train_examples, "train",
                                config.train_record_file, *vocabs)
    dev_meta = build_features(config, dev_examples, "dev",
                              config.dev_record_file, *vocabs)
    test_meta = build_features(config, test_examples, "test",
                               config.test_record_file, *vocabs,
                               is_test=True)

    artifacts = (
        (config.train_eval_file, train_eval, "train eval"),
        (config.dev_eval_file, dev_eval, "dev eval"),
        (config.test_eval_file, test_eval, "test eval"),
        (config.train_meta, train_meta, "train meta"),
        (config.dev_meta, dev_meta, "dev meta"),
        (config.test_meta, test_meta, "test meta"),
    )
    for out_path, payload, what in artifacts:
        save(out_path, payload, message=what)
예제 #9
0
def train(config):
    """Train the question-pair classifier (``QPCModel``).

    Batches are read from ``config.train_record_file``. Every
    ``config.checkpoint`` steps the model is saved to ``config.output_dir``
    and evaluated on ``config.val_num_batches`` training batches and on the
    dev set. The best dev accuracy and its step are written to
    ``config.best_ckpt``. Training resumes from the latest checkpoint in
    ``config.output_dir`` and from ``config.best_ckpt`` when those exist.

    Args:
        config: run configuration (record/eval/meta file paths, batch size,
            step counts, dropout and output directory).
    """
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)

    dev_total = meta["total"]
    print("Building model...")
    best_acc, best_ckpt = 0., 0
    parser = get_record_parser_qqp(config)
    graph = tf.Graph()
    with graph.as_default() as g:
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config.batch_size)
        dev_dataset = get_dataset(config.dev_record_file, parser,
                                  config.batch_size)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model = QPCModel(config, graph=g)

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        with tf.Session(config=sess_config) as sess:
            writer = tf.summary.FileWriter(config.output_dir)
            # The session is closed by its ``with``; the writer must be closed
            # explicitly so buffered summaries are flushed on every exit path.
            try:
                sess.run(tf.global_variables_initializer())
                saver = tf.train.Saver(max_to_keep=1000)
                if os.path.exists(os.path.join(config.output_dir,
                                               "checkpoint")):
                    saver.restore(
                        sess, tf.train.latest_checkpoint(config.output_dir))
                if os.path.exists(config.best_ckpt):
                    with open(config.best_ckpt, "r") as fh:
                        best_qqp_ckpt = json.load(fh)
                        best_acc, best_ckpt = float(
                            best_qqp_ckpt["best_acc"]), int(
                                best_qqp_ckpt["best_ckpt"])

                global_step = max(sess.run(model.global_step), 1)
                train_next_element = train_iterator.get_next()
                for _ in tqdm(range(global_step, config.num_steps + 1)):
                    global_step = sess.run(model.global_step) + 1
                    que1, que2, label, qa_id = sess.run(train_next_element)
                    # Predictions are not used here, so only the loss and the
                    # train op are fetched.
                    loss, _ = sess.run(
                        [model.loss, model.train_op],
                        feed_dict={
                            model.que1: que1,
                            model.que2: que2,
                            model.label: label,
                            model.dropout: config.dropout,
                            model.qa_id: qa_id,
                        })
                    if global_step % config.period == 0:
                        loss_sum = tf.Summary(value=[
                            tf.Summary.Value(tag="model/loss",
                                             simple_value=loss),
                        ])
                        writer.add_summary(loss_sum, global_step)
                    if global_step % config.checkpoint == 0:
                        filename = os.path.join(
                            config.output_dir,
                            "model_{}.ckpt".format(global_step))
                        saver.save(sess, filename)

                        metrics = evaluate_batch(config, model,
                                                 config.val_num_batches,
                                                 train_eval_file, sess,
                                                 train_iterator)
                        write_metrics(metrics, writer, global_step, "train")

                        metrics = evaluate_batch(
                            config, model, dev_total // config.batch_size + 1,
                            dev_eval_file, sess, dev_iterator)
                        write_metrics(metrics, writer, global_step, "dev")

                        acc = metrics["accuracy"]
                        if acc > best_acc:
                            best_acc, best_ckpt = acc, global_step
                            save(config.best_ckpt, {
                                "best_acc": str(acc),
                                "best_ckpt": str(best_ckpt)
                            }, config.best_ckpt)
            finally:
                writer.close()
예제 #10
0
파일: main.py 프로젝트: wly-thu/QGforQA
def train_for_qg(config):
    """Train the ``BidafQA`` question-answering model on the QG record files.

    Word and char embedding matrices are loaded from JSON, and batches are
    read from ``config.train_record_file``. Every ``config.checkpoint`` steps
    the model is saved to ``config.output_dir`` and evaluated with
    ``evaluate_batch_qa_for_qg`` on ``config.val_num_batches`` training
    batches and on the dev set. The best dev EM and its step are written to
    ``config.best_ckpt``. Training resumes from the latest checkpoint in
    ``config.output_dir`` and from ``config.best_ckpt`` when those exist.

    Args:
        config: run configuration (embedding/record/eval/meta file paths,
            batch size, step counts, dropout and output directory).
    """
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)

    dev_total = meta["total"]
    best_em, best_ckpt = 0., 0
    print("Building model...")
    parser = get_record_parser_qg(config)
    graph = tf.Graph()
    with graph.as_default():
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config.batch_size)
        dev_dataset = get_dataset(config.dev_record_file, parser,
                                  config.batch_size)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model = BidafQA(config, word_mat, char_mat)

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        with tf.Session(config=sess_config) as sess:
            # build graph
            model.build_graph()
            # add training operation
            model.add_train_op()
            writer = tf.summary.FileWriter(config.output_dir)
            # The session is closed by its ``with``; the writer must be closed
            # explicitly so buffered summaries are flushed on every exit path.
            try:
                sess.run(tf.global_variables_initializer())
                # Variables whose name contains "word_mat" are neither saved
                # nor restored by this saver.
                saver = tf.train.Saver(
                    max_to_keep=1000,
                    var_list=[p for p in tf.global_variables()
                              if "word_mat" not in p.name])
                if os.path.exists(os.path.join(config.output_dir,
                                               "checkpoint")):
                    saver.restore(
                        sess, tf.train.latest_checkpoint(config.output_dir))
                if os.path.exists(config.best_ckpt):
                    with open(config.best_ckpt, "r") as fh:
                        best_qa_ckpt = json.load(fh)
                        best_em, best_ckpt = float(best_qa_ckpt["best_em"]), \
                            int(best_qa_ckpt["best_ckpt"])
                global_step = max(sess.run(model.global_step), 1)
                train_next_element = train_iterator.get_next()
                for _ in tqdm(range(global_step, config.num_steps + 1)):
                    global_step = sess.run(model.global_step) + 1
                    # The record yields 15 fields; only paragraph/question
                    # word and char ids, the answer span and qa_id are fed.
                    (_, para, para_char, _, que, que_char, _, _, _, _, _, _,
                     y1, y2, qa_id) = sess.run(train_next_element)
                    loss, _ = sess.run([model.loss, model.train_op], feed_dict={
                        model.para: para, model.que: que,
                        model.para_char: para_char, model.que_char: que_char,
                        model.y1: y1, model.y2: y2,
                        model.dropout: config.dropout, model.qa_id: qa_id,
                    })
                    if global_step % config.period == 0:
                        loss_sum = tf.Summary(value=[tf.Summary.Value(
                            tag="model/loss", simple_value=loss), ])
                        writer.add_summary(loss_sum, global_step)
                    if global_step % config.checkpoint == 0:
                        filename = os.path.join(
                            config.output_dir,
                            "model_{}.ckpt".format(global_step))
                        saver.save(sess, filename)

                        metrics = evaluate_batch_qa_for_qg(
                            model, config.val_num_batches, train_eval_file,
                            sess, train_iterator)
                        write_metrics(metrics, writer, global_step, "train")

                        metrics = evaluate_batch_qa_for_qg(
                            model, dev_total // config.batch_size + 1,
                            dev_eval_file, sess, dev_iterator)
                        write_metrics(metrics, writer, global_step, "dev")

                        em = metrics["em"]
                        if em > best_em:
                            best_em, best_ckpt = em, global_step
                            save(config.best_ckpt,
                                 {"best_em": str(em),
                                  "best_ckpt": str(best_ckpt)},
                                 config.best_ckpt)
            finally:
                writer.close()