Ejemplo n.º 1
0
def train_qqp_qap(config):
    """Fine-tune the question generator with QPP and QAP rewards.

    Loads the embedding matrices, eval files and dictionaries named by
    ``config`` and builds three independent graphs/sessions:

    * the trainable question generator (``QGRLModel``);
    * a non-trainable question-paraphrase classifier (``QPCModel``), restored
      from the checkpoint recorded in ``config.best_ckpt_qpc``;
    * a non-trainable QA model (``BidafQA``), restored from the checkpoint
      recorded in ``config.best_ckpt_qa``.

    It then runs the generator's RL training loop up to ``config.num_steps``.
    On every 4th step the reward is the difference in the classifier's
    ``pos_prob`` between the sampled question (``symbols_rl``) and the
    baseline question (``symbols``), each paired with the reference question
    from the batch. On the other steps it is the difference in
    ``exp(-batch_loss)`` of the QA model. Summaries are written every
    ``config.period`` steps. Every ``config.checkpoint`` steps the generator is
    saved and evaluated on train/dev, and ``config.best_ckpt`` is updated when
    dev BLEU improves.

    Exits the process with status 1 (after printing a message) if the QPC or
    QA best-checkpoint file does not exist.
    """
    import sys

    def _load_json(path):
        # Read and parse one JSON file.
        with open(path, "r") as fh:
            return json.load(fh)

    word_mat = np.array(_load_json(config.word_emb_file), dtype=np.float32)
    char_mat = np.array(_load_json(config.char_emb_file), dtype=np.float32)
    pos_mat = np.array(_load_json(config.pos_emb_file), dtype=np.float32)
    ner_mat = np.array(_load_json(config.ner_emb_file), dtype=np.float32)
    label_mat = np.array(_load_json(config.label_emb_file), dtype=np.float32)
    train_eval_file = _load_json(config.train_eval_file)
    dev_eval_file = _load_json(config.dev_eval_file)
    meta = _load_json(config.dev_meta)
    word_dictionary = _load_json(config.word_dictionary)
    char_dictionary = _load_json(config.char_dictionary)
    with h5py.File(config.embedding_file, 'r') as fin:
        embed_weights = fin["embedding"][...]
    # Prepend an all-zero row 0 to the ELMo word embeddings (file row i becomes
    # row i + 1); presumably index 0 is reserved for padding -- TODO confirm.
    elmo_word_mat = np.zeros(
        (embed_weights.shape[0] + 1, embed_weights.shape[1]),
        dtype=np.float32)
    elmo_word_mat[1:, :] = embed_weights

    id2word = {word_dictionary[w]: w for w in word_dictionary}
    dev_total = meta["total"]
    # Best dev BLEU of the *question generator* and the step it was reached at.
    best_bleu, best_ckpt = 0., 0
    print("Building model...")
    parser = get_record_parser(config)
    graph_qg = tf.Graph()
    graph_qqp = tf.Graph()
    graph_qa = tf.Graph()

    with graph_qg.as_default():
        # NOTE(review): train_rl passes config.batch_size (not config) as the
        # third argument of get_batch_dataset/get_dataset; confirm which call
        # matches their signature.
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config)
        dev_dataset = get_dataset(config.dev_record_file, parser, config)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model_qg = QGRLModel(config, word_mat, elmo_word_mat, label_mat,
                             pos_mat, ner_mat)
        model_qg.build_graph()
        model_qg.add_train_op()
        # Create the batch-fetching op explicitly inside the generator graph.
        train_next_element = train_iterator.get_next()
    with graph_qqp.as_default() as g:
        model_qqp = QPCModel(config, dev=True, trainable=False, graph=g)
    with graph_qa.as_default():
        model_qa = BidafQA(config,
                           word_mat,
                           char_mat,
                           dev=True,
                           trainable=False)

    sess_qg = tf.Session(graph=graph_qg)
    sess_qqp = tf.Session(graph=graph_qqp)
    sess_qa = tf.Session(graph=graph_qa)

    writer = tf.summary.FileWriter(config.output_dir)

    # --- Question generator: init, resume from latest checkpoint, best BLEU.
    with sess_qg.as_default():
        with graph_qg.as_default():
            sess_qg.run(tf.global_variables_initializer())
            saver_qg = tf.train.Saver(max_to_keep=1000,
                                      var_list=[
                                          p for p in tf.global_variables()
                                          if "word_mat" not in p.name
                                      ])
            if os.path.exists(os.path.join(config.output_dir, "checkpoint")):
                print(tf.train.latest_checkpoint(config.output_dir))
                saver_qg.restore(sess_qg,
                                 tf.train.latest_checkpoint(config.output_dir))
            if os.path.exists(config.best_ckpt):
                best_qg_ckpt = _load_json(config.best_ckpt)
                best_bleu = float(best_qg_ckpt["best_bleu"])
                best_ckpt = int(best_qg_ckpt["best_ckpt"])

    # --- Paraphrase classifier: must restore its best checkpoint.
    # Uses its own step variable so the generator's best_ckpt is not clobbered.
    with sess_qqp.as_default():
        with graph_qqp.as_default():
            sess_qqp.run(tf.global_variables_initializer())
            saver_qqp = tf.train.Saver(max_to_keep=1000,
                                       var_list=[
                                           p for p in tf.global_variables()
                                           if "word_mat" not in p.name
                                       ])
            if not os.path.exists(config.best_ckpt_qpc):
                print("NO the best QPC model to load!")
                sys.exit(1)
            best_qpc_step = int(_load_json(config.best_ckpt_qpc)["best_ckpt"])
            qpc_path = "{}/model_{}.ckpt".format(config.output_dir_qpc,
                                                 best_qpc_step)
            print(qpc_path)
            saver_qqp.restore(sess_qqp, qpc_path)

    # --- QA model: build, then must restore its best checkpoint.
    with sess_qa.as_default():
        with graph_qa.as_default():
            model_qa.build_graph()
            model_qa.add_train_op()
            sess_qa.run(tf.global_variables_initializer())
            saver_qa = tf.train.Saver(max_to_keep=1000,
                                      var_list=[
                                          p for p in tf.global_variables()
                                          if "word_mat" not in p.name
                                      ])
            if not os.path.exists(config.best_ckpt_qa):
                print("NO the best QA model to load!")
                sys.exit(1)
            best_qa_step = int(_load_json(config.best_ckpt_qa)["best_ckpt"])
            qa_path = "{}/model_{}.ckpt".format(config.output_dir_qa,
                                                best_qa_step)
            print(qa_path)
            saver_qa.restore(sess_qa, qa_path)

    try:
        global_step = max(sess_qg.run(model_qg.global_step), 1)
        for _ in tqdm(range(global_step, config.num_steps + 1)):
            global_step = sess_qg.run(model_qg.global_step) + 1
            (para, para_unk, para_char, que, que_unk, que_char, labels,
             pos_tags, ner_tags, que_labels, que_pos_tags, que_ner_tags,
             y1, y2, qa_id) = sess_qg.run(train_next_element)
            # Decode the baseline and the sampled questions.
            symbols, symbols_rl = sess_qg.run(
                [model_qg.symbols, model_qg.symbols_rl],
                feed_dict={
                    model_qg.para: para,
                    model_qg.para_unk: para_unk,
                    model_qg.que: que,
                    model_qg.labels: labels,
                    model_qg.pos_tags: pos_tags,
                    model_qg.ner_tags: ner_tags,
                    model_qg.qa_id: qa_id
                })
            # Convert both decoded questions into the QA/QPC input format.
            que_base, que_unk_base, que_char_base, que_rl, que_unk_rl, que_char_rl = \
                format_generated_ques_for_qa(train_eval_file, qa_id, symbols, symbols_rl, config.batch_size,
                                             config.ques_limit, config.char_limit, id2word, char_dictionary)
            if global_step % 4 == 0:
                # QPP reward: paraphrase probability of (reference, generated).
                # All-zero labels are fed to model_qqp.label; presumably they do
                # not affect pos_prob -- TODO confirm.
                label = np.zeros((config.batch_size, 2), dtype=np.int32)
                reward_base = sess_qqp.run(model_qqp.pos_prob,
                                           feed_dict={
                                               model_qqp.que1: que_unk,
                                               model_qqp.que2: que_unk_base,
                                               model_qqp.label: label,
                                               model_qqp.qa_id: qa_id,
                                           })
                reward_rl = sess_qqp.run(model_qqp.pos_prob,
                                         feed_dict={
                                             model_qqp.que1: que_unk,
                                             model_qqp.que2: que_unk_rl,
                                             model_qqp.label: label,
                                             model_qqp.qa_id: qa_id,
                                         })
                mixing_ratio = 0.99
            else:
                # QAP reward: exp(-QA loss) when answering with the generated
                # question.
                base_qa_loss = sess_qa.run(model_qa.batch_loss,
                                           feed_dict={
                                               model_qa.para: para_unk,
                                               model_qa.para_char: para_char,
                                               model_qa.que: que_unk_base,
                                               model_qa.que_char: que_char_base,
                                               model_qa.y1: y1,
                                               model_qa.y2: y2,
                                               model_qa.qa_id: qa_id,
                                           })
                qa_loss = sess_qa.run(model_qa.batch_loss,
                                      feed_dict={
                                          model_qa.para: para_unk,
                                          model_qa.para_char: para_char,
                                          model_qa.que: que_unk_rl,
                                          model_qa.que_char: que_char_rl,
                                          model_qa.y1: y1,
                                          model_qa.y2: y2,
                                          model_qa.qa_id: qa_id,
                                      })
                reward_base = [np.exp(-x) for x in base_qa_loss]
                reward_rl = [np.exp(-x) for x in qa_loss]
                mixing_ratio = 0.97
            # Advantage of the sampled question over the baseline question.
            reward = [rr - rb for rr, rb in zip(reward_rl, reward_base)]
            # Train with RL.
            loss_ml, _ = sess_qg.run(
                [model_qg.loss_ml, model_qg.train_op],
                feed_dict={
                    model_qg.para: para,
                    model_qg.para_unk: para_unk,
                    model_qg.labels: labels,
                    model_qg.pos_tags: pos_tags,
                    model_qg.ner_tags: ner_tags,
                    model_qg.que: que,
                    model_qg.dropout: config.dropout,
                    model_qg.qa_id: qa_id,
                    model_qg.sampled_que: que_rl,
                    model_qg.reward: reward,
                    model_qg.lamda: mixing_ratio
                })
            if global_step % config.period == 0:
                writer.add_summary(tf.Summary(value=[
                    tf.Summary.Value(tag="model/loss", simple_value=loss_ml),
                    tf.Summary.Value(tag="model/reward_base",
                                     simple_value=np.mean(reward_base)),
                    tf.Summary.Value(tag="model/reward_rl",
                                     simple_value=np.mean(reward_rl)),
                ]), global_step)
            if global_step % config.checkpoint == 0:
                filename = os.path.join(config.output_dir,
                                        "model_{}.ckpt".format(global_step))
                saver_qg.save(sess_qg, filename)

                metrics = evaluate_batch(config,
                                         model_qg,
                                         config.val_num_batches,
                                         train_eval_file,
                                         sess_qg,
                                         train_iterator,
                                         id2word,
                                         evaluate_func=evaluate_simple)
                write_metrics(metrics, writer, global_step, "train")

                metrics = evaluate_batch(config,
                                         model_qg,
                                         dev_total // config.batch_size + 1,
                                         dev_eval_file,
                                         sess_qg,
                                         dev_iterator,
                                         id2word,
                                         evaluate_func=evaluate_simple)
                write_metrics(metrics, writer, global_step, "dev")

                bleu = metrics["bleu"]
                if bleu > best_bleu:
                    best_bleu, best_ckpt = bleu, global_step
                    save(config.best_ckpt, {
                        "best_bleu": str(best_bleu),
                        "best_ckpt": str(best_ckpt)
                    }, config.best_ckpt)
    finally:
        # Flush buffered summaries to disk and release session resources,
        # even if training stops early.
        writer.close()
        sess_qa.close()
        sess_qqp.close()
        sess_qg.close()
Ejemplo n.º 2
0
def train_rl(config):
    """Fine-tune the question generator with a policy-gradient reward.

    Loads embeddings, eval files and the word dictionary named by ``config``,
    builds the ``QGRLModel`` graph, and resumes from the latest checkpoint in
    ``config.output_dir`` if one exists (also restoring the best BLEU record
    from ``config.best_ckpt``). It then trains up to ``config.num_steps``. Each
    step does three things:

    * decodes baseline (``symbols``) and sampled (``symbols_rl``) questions;
    * scores them with ``evaluate_rl`` using ``config.rl_metric``;
    * runs ``model_qg.train_op`` with the resulting reward.

    Summaries are written every ``config.period`` steps. Every
    ``config.checkpoint`` steps the model is saved and evaluated on train/dev,
    and ``config.best_ckpt`` is updated when dev BLEU improves.
    """
    def _load_json(path):
        # Read and parse one JSON file.
        with open(path, "r") as fh:
            return json.load(fh)

    word_mat = np.array(_load_json(config.word_emb_file), dtype=np.float32)
    pos_mat = np.array(_load_json(config.pos_emb_file), dtype=np.float32)
    ner_mat = np.array(_load_json(config.ner_emb_file), dtype=np.float32)
    label_mat = np.array(_load_json(config.label_emb_file), dtype=np.float32)
    train_eval_file = _load_json(config.train_eval_file)
    dev_eval_file = _load_json(config.dev_eval_file)
    meta = _load_json(config.dev_meta)
    word_dictionary = _load_json(config.word_dictionary)
    with h5py.File(config.embedding_file, 'r') as fin:
        embed_weights = fin["embedding"][...]
    # Prepend an all-zero row 0 to the ELMo word embeddings (file row i becomes
    # row i + 1); presumably index 0 is reserved for padding -- TODO confirm.
    elmo_word_mat = np.zeros(
        (embed_weights.shape[0] + 1, embed_weights.shape[1]),
        dtype=np.float32)
    elmo_word_mat[1:, :] = embed_weights

    id2word = {word_dictionary[w]: w for w in word_dictionary}
    best_bleu, best_ckpt = 0., 0
    dev_total = meta["total"]
    print("Building model...")
    parser = get_record_parser(config)
    graph_qg = tf.Graph()
    with graph_qg.as_default():
        # NOTE(review): train_qqp_qap passes config (not config.batch_size) as
        # the third argument of get_batch_dataset/get_dataset; confirm which
        # call matches their signature.
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config.batch_size)
        dev_dataset = get_dataset(config.dev_record_file, parser,
                                  config.batch_size)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model_qg = QGRLModel(config, word_mat, elmo_word_mat, label_mat,
                             pos_mat, ner_mat)
        model_qg.build_graph()
        model_qg.add_train_op()
        # Create the batch-fetching op explicitly inside this graph.
        train_next_element = train_iterator.get_next()

    sess_qg = tf.Session(graph=graph_qg)

    writer = tf.summary.FileWriter(config.output_dir)
    with sess_qg.as_default():
        with graph_qg.as_default():
            sess_qg.run(tf.global_variables_initializer())
            saver_qg = tf.train.Saver(max_to_keep=1000,
                                      var_list=[
                                          p for p in tf.global_variables()
                                          if "word_mat" not in p.name
                                      ])
            if os.path.exists(os.path.join(config.output_dir, "checkpoint")):
                saver_qg.restore(sess_qg,
                                 tf.train.latest_checkpoint(config.output_dir))
            if os.path.exists(config.best_ckpt):
                best_qg_ckpt = _load_json(config.best_ckpt)
                best_bleu = float(best_qg_ckpt["best_bleu"])
                best_ckpt = int(best_qg_ckpt["best_ckpt"])

    try:
        global_step = max(sess_qg.run(model_qg.global_step), 1)
        for _ in tqdm(range(global_step, config.num_steps + 1)):
            global_step = sess_qg.run(model_qg.global_step) + 1
            (para, para_unk, para_char, que, que_unk, que_char, labels,
             pos_tags, ner_tags, que_labels, que_pos_tags, que_ner_tags,
             y1, y2, qa_id) = sess_qg.run(train_next_element)
            # Get greedy-search questions as baseline, plus sampled questions.
            symbols, symbols_rl = sess_qg.run(
                [model_qg.symbols, model_qg.symbols_rl],
                feed_dict={
                    model_qg.para: para,
                    model_qg.para_unk: para_unk,
                    model_qg.que: que,
                    model_qg.labels: labels,
                    model_qg.pos_tags: pos_tags,
                    model_qg.ner_tags: ner_tags,
                    model_qg.qa_id: qa_id
                })
            # Get rewards and format the sampled questions.
            reward, reward_rl, reward_base, que_rl = evaluate_rl(
                train_eval_file,
                qa_id,
                symbols,
                symbols_rl,
                id2word,
                metric=config.rl_metric)
            # Update the model with policy gradient.
            # NOTE(review): unlike train_qqp_qap, model_qg.lamda is not fed
            # here; presumably it has a default value -- confirm in QGRLModel.
            loss_ml, _ = sess_qg.run(
                [model_qg.loss_ml, model_qg.train_op],
                feed_dict={
                    model_qg.para: para,
                    model_qg.para_unk: para_unk,
                    model_qg.que: que,
                    model_qg.labels: labels,
                    model_qg.pos_tags: pos_tags,
                    model_qg.ner_tags: ner_tags,
                    model_qg.dropout: config.dropout,
                    model_qg.qa_id: qa_id,
                    model_qg.sampled_que: que_rl,
                    model_qg.reward: reward
                })
            if global_step % config.period == 0:
                writer.add_summary(tf.Summary(value=[
                    tf.Summary.Value(tag="model/loss", simple_value=loss_ml),
                    tf.Summary.Value(tag="model/reward_base",
                                     simple_value=np.mean(reward_base)),
                    tf.Summary.Value(tag="model/reward_rl",
                                     simple_value=np.mean(reward_rl)),
                ]), global_step)
            if global_step % config.checkpoint == 0:
                filename = os.path.join(config.output_dir,
                                        "model_{}.ckpt".format(global_step))
                saver_qg.save(sess_qg, filename)

                metrics = evaluate_batch(config,
                                         model_qg,
                                         config.val_num_batches,
                                         train_eval_file,
                                         sess_qg,
                                         train_iterator,
                                         id2word,
                                         evaluate_func=evaluate_simple)
                write_metrics(metrics, writer, global_step, "train")

                metrics = evaluate_batch(config,
                                         model_qg,
                                         dev_total // config.batch_size + 1,
                                         dev_eval_file,
                                         sess_qg,
                                         dev_iterator,
                                         id2word,
                                         evaluate_func=evaluate_simple)
                write_metrics(metrics, writer, global_step, "dev")

                bleu = metrics["bleu"]
                if bleu > best_bleu:
                    best_bleu, best_ckpt = bleu, global_step
                    save(config.best_ckpt, {
                        "best_bleu": str(best_bleu),
                        "best_ckpt": str(best_ckpt)
                    }, config.best_ckpt)
    finally:
        # Flush buffered summaries to disk and release session resources,
        # even if training stops early.
        writer.close()
        sess_qg.close()