Example #1
def evaluate(args):
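    """Restore a trained CommentModel from args.eval_model and report the
    average dev-set loss and its perplexity (exp of the mean loss)."""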
    with open(os.path.join(args.records_dir, 'dev_meta.json'),
              'r',
              encoding='utf8') as p:
        dev_total = json.load(p)['size']

    dev_records_file = os.path.join(args.records_dir, 'dev.tfrecords')
    parser = get_record_parser(args)
    dev_dataset = get_batch_dataset(dev_records_file, parser, args)
    dev_iterator = dev_dataset.make_one_shot_iterator()

    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config)
    model = CommentModel(args, dev_iterator)
    saver = tf.train.Saver(max_to_keep=10000)
    saver.restore(sess, args.eval_model)

    total_loss = 0
    batches_num = int(np.ceil(dev_total / args.batch_size))
    for i in tqdm(range(batches_num), desc='eval'):
        loss = sess.run(model.loss)
        total_loss += loss
    dev_loss = total_loss / batches_num
    dev_ppl = np.exp(dev_loss)
    print(
        f'{time.asctime()} - Evaluation result -> dev_loss:{dev_loss:.3f}  dev_ppl:{dev_ppl:.3f}'
    )
Example #2
def train_qqp_qap(config):
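    """Fine-tune the question-generation (QG) model with policy-gradient
    rewards from two frozen scorers held in separate graphs/sessions:
    a question-paraphrase classifier (QPP reward, every 4th step) and a
    QA model (QAP reward, all other steps)."""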
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.pos_emb_file, "r") as fh:
        pos_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.ner_emb_file, "r") as fh:
        ner_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.label_emb_file, "r") as fh:
        label_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)
    with open(config.word_dictionary, "r") as fh:
        word_dictionary = json.load(fh)
    with open(config.char_dictionary, "r") as fh:
        char_dictionary = json.load(fh)
    with h5py.File(config.embedding_file, 'r') as fin:
        embed_weights = fin["embedding"][...]
        elmo_word_mat = np.zeros(
            (embed_weights.shape[0] + 1, embed_weights.shape[1]),
            dtype=np.float32)
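    # Row 0 is left all-zero (presumably the padding id); the pretrained
    # ELMo weights fill the remaining rows.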
    elmo_word_mat[1:, :] = embed_weights

    id2word = {word_dictionary[w]: w for w in word_dictionary}
    dev_total = meta["total"]
    best_bleu, best_ckpt = 0., 0
    print("Building model...")
    parser = get_record_parser(config)
    graph_qg = tf.Graph()
    graph_qqp = tf.Graph()
    graph_qa = tf.Graph()
    with graph_qg.as_default():
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config)
        dev_dataset = get_dataset(config.dev_record_file, parser, config)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

    with graph_qg.as_default() as g:
        model_qg = QGRLModel(config, word_mat, elmo_word_mat, label_mat,
                             pos_mat, ner_mat)
        model_qg.build_graph()
        model_qg.add_train_op()
    with graph_qqp.as_default() as g:
        model_qqp = QPCModel(config, dev=True, trainable=False, graph=g)
    with graph_qa.as_default() as g:
        model_qa = BidafQA(config,
                           word_mat,
                           char_mat,
                           dev=True,
                           trainable=False)

    sess_qg = tf.Session(graph=graph_qg)
    sess_qqp = tf.Session(graph=graph_qqp)
    sess_qa = tf.Session(graph=graph_qa)

    writer = tf.summary.FileWriter(config.output_dir)
    with sess_qg.as_default():
        with graph_qg.as_default():
            sess_qg.run(tf.global_variables_initializer())
            saver_qg = tf.train.Saver(max_to_keep=1000,
                                      var_list=[
                                          p for p in tf.global_variables()
                                          if "word_mat" not in p.name
                                      ])
            if os.path.exists(os.path.join(config.output_dir, "checkpoint")):
                print(tf.train.latest_checkpoint(config.output_dir))
                saver_qg.restore(sess_qg,
                                 tf.train.latest_checkpoint(config.output_dir))
            if os.path.exists(config.best_ckpt):
                with open(config.best_ckpt, "r") as fh:
                    best_qg_ckpt = json.load(fh)
                    best_bleu, best_ckpt = float(
                        best_qg_ckpt["best_bleu"]), int(
                            best_qg_ckpt["best_ckpt"])
    with sess_qqp.as_default():
        with graph_qqp.as_default():
            sess_qqp.run(tf.global_variables_initializer())
            saver_qqp = tf.train.Saver(max_to_keep=1000,
                                       var_list=[
                                           p for p in tf.global_variables()
                                           if "word_mat" not in p.name
                                       ])
            if os.path.exists(config.best_ckpt_qpc):
                with open(config.best_ckpt_qpc, "r") as fh:
                    best_qpc_ckpt = json.load(fh)
                    best_ckpt = int(best_qpc_ckpt["best_ckpt"])
                print("{}/model_{}.ckpt".format(config.output_dir_qpc,
                                                best_ckpt))
                saver_qqp.restore(
                    sess_qqp,
                    "{}/model_{}.ckpt".format(config.output_dir_qpc,
                                              best_ckpt))
            else:
                print("NO the best QPC model to load!")
                exit()
    with sess_qa.as_default():
        with graph_qa.as_default():
            model_qa.build_graph()
            model_qa.add_train_op()
            sess_qa.run(tf.global_variables_initializer())
            saver_qa = tf.train.Saver(max_to_keep=1000,
                                      var_list=[
                                          p for p in tf.global_variables()
                                          if "word_mat" not in p.name
                                      ])
            if os.path.exists(config.best_ckpt_qa):
                with open(config.best_ckpt_qa, "r") as fh:
                    best_qa_ckpt = json.load(fh)
                    best_ckpt = int(best_qa_ckpt["best_ckpt"])
                print("{}/model_{}.ckpt".format(config.output_dir_qa,
                                                best_ckpt))
                saver_qa.restore(
                    sess_qa, "{}/model_{}.ckpt".format(config.output_dir_qa,
                                                       best_ckpt))
            else:
                print("NO the best QA model to load!")
                exit()

    global_step = max(sess_qg.run(model_qg.global_step), 1)
    train_next_element = train_iterator.get_next()
    for _ in tqdm(range(global_step, config.num_steps + 1)):
        global_step = sess_qg.run(model_qg.global_step) + 1
        para, para_unk, para_char, que, que_unk, que_char, labels, pos_tags, ner_tags, \
            que_labels, que_pos_tags, que_ner_tags, y1, y2, qa_id = sess_qg.run(train_next_element)
        symbols, symbols_rl = sess_qg.run(
            [model_qg.symbols, model_qg.symbols_rl],
            feed_dict={
                model_qg.para: para,
                model_qg.para_unk: para_unk,
                model_qg.que: que,
                model_qg.labels: labels,
                model_qg.pos_tags: pos_tags,
                model_qg.ner_tags: ner_tags,
                model_qg.qa_id: qa_id
            })
        # format sample for QA
        que_base, que_unk_base, que_char_base, que_rl, que_unk_rl, que_char_rl = \
            format_generated_ques_for_qa(train_eval_file, qa_id, symbols, symbols_rl, config.batch_size,
                                         config.ques_limit, config.char_limit, id2word, char_dictionary)
        label = np.zeros((config.batch_size, 2), dtype=np.int32)
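        # Alternate reward sources: every 4th step scores the generated
        # question against the reference with the paraphrase classifier (QPP);
        # the other steps use the QA model's answer likelihood (QAP).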
        if global_step % 4 == 0:
            # QPP reward
            reward_base = sess_qqp.run(model_qqp.pos_prob,
                                       feed_dict={
                                           model_qqp.que1: que_unk,
                                           model_qqp.que2: que_unk_base,
                                           model_qqp.label: label,
                                           model_qqp.qa_id: qa_id,
                                       })
            reward_rl = sess_qqp.run(model_qqp.pos_prob,
                                     feed_dict={
                                         model_qqp.que1: que_unk,
                                         model_qqp.que2: que_unk_rl,
                                         model_qqp.label: label,
                                         model_qqp.qa_id: qa_id,
                                     })
            reward = [rr - rb for rr, rb in zip(reward_rl, reward_base)]
            mixing_ratio = 0.99
        else:
            # QAP reward
            base_qa_loss = sess_qa.run(model_qa.batch_loss,
                                       feed_dict={
                                           model_qa.para: para_unk,
                                           model_qa.para_char: para_char,
                                           model_qa.que: que_unk_base,
                                           model_qa.que_char: que_char_base,
                                           model_qa.y1: y1,
                                           model_qa.y2: y2,
                                           model_qa.qa_id: qa_id,
                                       })
            qa_loss = sess_qa.run(model_qa.batch_loss,
                                  feed_dict={
                                      model_qa.para: para_unk,
                                      model_qa.para_char: para_char,
                                      model_qa.que: que_unk_rl,
                                      model_qa.que_char: que_char_rl,
                                      model_qa.y1: y1,
                                      model_qa.y2: y2,
                                      model_qa.qa_id: qa_id,
                                  })
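            # exp(-loss) converts each example's cross-entropy into an answer
            # probability; the reward is the sampled-minus-greedy gap, with
            # the greedy decoding acting as a self-critical baseline.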
            reward_base = list(map(lambda x: np.exp(-x), list(base_qa_loss)))
            reward_rl = list(map(lambda x: np.exp(-x), list(qa_loss)))
            reward = list(
                map(lambda x: x[0] - x[1], zip(reward_rl, reward_base)))
            mixing_ratio = 0.97
        # train with rl
        loss_ml, _ = sess_qg.run(
            [model_qg.loss_ml, model_qg.train_op],
            feed_dict={
                model_qg.para: para,
                model_qg.para_unk: para_unk,
                model_qg.labels: labels,
                model_qg.pos_tags: pos_tags,
                model_qg.ner_tags: ner_tags,
                model_qg.que: que,
                model_qg.dropout: config.dropout,
                model_qg.qa_id: qa_id,
                model_qg.sampled_que: que_rl,
                model_qg.reward: reward,
                model_qg.lamda: mixing_ratio
            })
        if global_step % config.period == 0:
            loss_sum = tf.Summary(value=[
                tf.Summary.Value(tag="model/loss", simple_value=loss_ml),
            ])
            writer.add_summary(loss_sum, global_step)
            reward_base_sum = tf.Summary(value=[
                tf.Summary.Value(tag="model/reward_base",
                                 simple_value=np.mean(reward_base)),
            ])
            writer.add_summary(reward_base_sum, global_step)
            reward_rl_sum = tf.Summary(value=[
                tf.Summary.Value(tag="model/reward_rl",
                                 simple_value=np.mean(reward_rl)),
            ])
            writer.add_summary(reward_rl_sum, global_step)
        if global_step % config.checkpoint == 0:
            filename = os.path.join(config.output_dir,
                                    "model_{}.ckpt".format(global_step))
            saver_qg.save(sess_qg, filename)

            metrics = evaluate_batch(config,
                                     model_qg,
                                     config.val_num_batches,
                                     train_eval_file,
                                     sess_qg,
                                     train_iterator,
                                     id2word,
                                     evaluate_func=evaluate_simple)
            write_metrics(metrics, writer, global_step, "train")

            metrics = evaluate_batch(config,
                                     model_qg,
                                     dev_total // config.batch_size + 1,
                                     dev_eval_file,
                                     sess_qg,
                                     dev_iterator,
                                     id2word,
                                     evaluate_func=evaluate_simple)
            write_metrics(metrics, writer, global_step, "dev")

            bleu = metrics["bleu"]
            if bleu > best_bleu:
                best_bleu, best_ckpt = bleu, global_step
                save(config.best_ckpt, {
                    "best_bleu": str(best_bleu),
                    "best_ckpt": str(best_ckpt)
                }, config.best_ckpt)
Example #3
def train_rl(config):
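    """Train the QG model with REINFORCE: greedy decodings serve as the
    baseline, sampled decodings are scored against the reference by a
    text-similarity metric (config.rl_metric), and the resulting reward
    drives the policy-gradient update."""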
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.pos_emb_file, "r") as fh:
        pos_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.ner_emb_file, "r") as fh:
        ner_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.label_emb_file, "r") as fh:
        label_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)
    with open(config.word_dictionary, "r") as fh:
        word_dictionary = json.load(fh)
    with h5py.File(config.embedding_file, 'r') as fin:
        embed_weights = fin["embedding"][...]
        elmo_word_mat = np.zeros(
            (embed_weights.shape[0] + 1, embed_weights.shape[1]),
            dtype=np.float32)
    elmo_word_mat[1:, :] = embed_weights

    id2word = {word_dictionary[w]: w for w in word_dictionary}
    best_bleu, best_ckpt = 0., 0
    dev_total = meta["total"]
    print("Building model...")
    parser = get_record_parser(config)
    graph_qg = tf.Graph()
    with graph_qg.as_default():
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config.batch_size)
        dev_dataset = get_dataset(config.dev_record_file, parser,
                                  config.batch_size)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model_qg = QGRLModel(config, word_mat, elmo_word_mat, label_mat,
                             pos_mat, ner_mat)
        model_qg.build_graph()
        model_qg.add_train_op()

    sess_qg = tf.Session(graph=graph_qg)

    writer = tf.summary.FileWriter(config.output_dir)
    with sess_qg.as_default():
        with graph_qg.as_default():
            sess_qg.run(tf.global_variables_initializer())
            saver_qg = tf.train.Saver(max_to_keep=1000,
                                      var_list=[
                                          p for p in tf.global_variables()
                                          if "word_mat" not in p.name
                                      ])
            if os.path.exists(os.path.join(config.output_dir, "checkpoint")):
                saver_qg.restore(sess_qg,
                                 tf.train.latest_checkpoint(config.output_dir))
            if os.path.exists(config.best_ckpt):
                with open(config.best_ckpt, "r") as fh:
                    best_qg_ckpt = json.load(fh)
                    best_bleu, best_ckpt = float(
                        best_qg_ckpt["best_bleu"]), int(
                            best_qg_ckpt["best_ckpt"])

    global_step = max(sess_qg.run(model_qg.global_step), 1)
    train_next_element = train_iterator.get_next()
    for _ in tqdm(range(global_step, config.num_steps + 1)):
        global_step = sess_qg.run(model_qg.global_step) + 1
        para, para_unk, para_char, que, que_unk, que_char, labels, pos_tags, ner_tags, \
            que_labels, que_pos_tags, que_ner_tags, y1, y2, qa_id = sess_qg.run(train_next_element)
        # get greedy search questions as baseline and sampled questions
        symbols, symbols_rl = sess_qg.run(
            [model_qg.symbols, model_qg.symbols_rl],
            feed_dict={
                model_qg.para: para,
                model_qg.para_unk: para_unk,
                model_qg.que: que,
                model_qg.labels: labels,
                model_qg.pos_tags: pos_tags,
                model_qg.ner_tags: ner_tags,
                model_qg.qa_id: qa_id
            })
        # get rewards and format sampled questions
        reward, reward_rl, reward_base, que_rl = evaluate_rl(
            train_eval_file,
            qa_id,
            symbols,
            symbols_rl,
            id2word,
            metric=config.rl_metric)
        # update model with policy gradient
        loss_ml, _ = sess_qg.run(
            [model_qg.loss_ml, model_qg.train_op],
            feed_dict={
                model_qg.para: para,
                model_qg.para_unk: para_unk,
                model_qg.que: que,
                model_qg.labels: labels,
                model_qg.pos_tags: pos_tags,
                model_qg.ner_tags: ner_tags,
                model_qg.dropout: config.dropout,
                model_qg.qa_id: qa_id,
                model_qg.sampled_que: que_rl,
                model_qg.reward: reward
            })
        if global_step % config.period == 0:
            loss_sum = tf.Summary(value=[
                tf.Summary.Value(tag="model/loss", simple_value=loss_ml),
            ])
            writer.add_summary(loss_sum, global_step)
            reward_base_sum = tf.Summary(value=[
                tf.Summary.Value(tag="model/reward_base",
                                 simple_value=np.mean(reward_base)),
            ])
            writer.add_summary(reward_base_sum, global_step)
            reward_rl_sum = tf.Summary(value=[
                tf.Summary.Value(tag="model/reward_rl",
                                 simple_value=np.mean(reward_rl)),
            ])
            writer.add_summary(reward_rl_sum, global_step)
        if global_step % config.checkpoint == 0:
            filename = os.path.join(config.output_dir,
                                    "model_{}.ckpt".format(global_step))
            saver_qg.save(sess_qg, filename)

            metrics = evaluate_batch(config,
                                     model_qg,
                                     config.val_num_batches,
                                     train_eval_file,
                                     sess_qg,
                                     train_iterator,
                                     id2word,
                                     evaluate_func=evaluate_simple)
            write_metrics(metrics, writer, global_step, "train")

            metrics = evaluate_batch(config,
                                     model_qg,
                                     dev_total // config.batch_size + 1,
                                     dev_eval_file,
                                     sess_qg,
                                     dev_iterator,
                                     id2word,
                                     evaluate_func=evaluate_simple)
            write_metrics(metrics, writer, global_step, "dev")

            bleu = metrics["bleu"]
            if bleu > best_bleu:
                best_bleu, best_ckpt = bleu, global_step
                save(config.best_ckpt, {
                    "best_bleu": str(best_bleu),
                    "best_ckpt": str(best_ckpt)
                }, config.best_ckpt)
Example #4
def test(config):
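    """Decode the test set with the best checkpoint recorded in
    config.best_ckpt, choosing beam search, diverse beam search, or
    sampling from the config flags, and log the resulting metrics."""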
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.pos_emb_file, "r") as fh:
        pos_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.ner_emb_file, "r") as fh:
        ner_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.label_emb_file, "r") as fh:
        label_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.test_eval_file, "r") as fh:
        eval_file = json.load(fh)
    with open(config.test_meta, "r") as fh:
        meta = json.load(fh)
    with open(config.word_dictionary, "r") as fh:
        word_dictionary = json.load(fh)
    with h5py.File(config.embedding_file, 'r') as fin:
        embed_weights = fin["embedding"][...]
        elmo_word_mat = np.zeros(
            (embed_weights.shape[0] + 1, embed_weights.shape[1]),
            dtype=np.float32)
    elmo_word_mat[1:, :] = embed_weights

    id2word = {word_dictionary[w]: w for w in word_dictionary}
    total = meta["total"]
    print("test total: {}".format(total))

    graph = tf.Graph()
    print("Loading model...")
    with graph.as_default() as g:
        test_iterator = get_dataset(
            config.test_record_file, get_record_parser(config, is_test=True),
            config.test_batch_size).make_one_shot_iterator()
        model = QGModel(config,
                        word_mat,
                        elmo_word_mat,
                        label_mat,
                        pos_mat,
                        ner_mat,
                        trainable=False)
        model.build_graph()
        model.add_train_op()
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        if os.path.exists(config.best_ckpt):
            with open(config.best_ckpt, "r") as fh:
                best_ckpt = json.load(fh)
                checkpoint_to_test = int(best_ckpt["best_ckpt"])
        else:
            print("No Best!")
            exit()

        with tf.Session(config=sess_config) as sess:
            if config.diverse_beam:
                filename = "{}/diverse{}_beam{}".format(
                    config.output_dir, config.diverse_rate, config.beam_size)
            elif config.sample:
                filename = "{}/temperature{}_sample{}".format(
                    config.output_dir, config.temperature, config.sample_size)
            else:
                filename = "{}/beam{}".format(config.output_dir,
                                              config.beam_size)
            writer = tf.summary.FileWriter(filename)
            checkpoint = "{}/model_{}.ckpt".format(config.output_dir,
                                                   checkpoint_to_test)
            print(checkpoint)
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(var_list=[
                p for p in tf.global_variables() if "word_mat" not in p.name
            ])
            saver.restore(sess, checkpoint)
            global_step = sess.run(model.global_step)
            metrics = evaluate_batch(config, model,
                                     total // config.test_batch_size + 1,
                                     eval_file, sess, test_iterator, id2word)
            print(metrics)
            write_metrics(metrics, writer, global_step, "test")
Example #5
def train(config):
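    """Supervised (maximum-likelihood) training of the QG model, with
    periodic checkpoints and best-model tracking by dev BLEU."""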
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.pos_emb_file, "r") as fh:
        pos_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.ner_emb_file, "r") as fh:
        ner_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.label_emb_file, "r") as fh:
        label_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)
    with open(config.word_dictionary, "r") as fh:
        word_dictionary = json.load(fh)
    with h5py.File(config.embedding_file, 'r') as fin:
        embed_weights = fin["embedding"][...]
        elmo_word_mat = np.zeros(
            (embed_weights.shape[0] + 1, embed_weights.shape[1]),
            dtype=np.float32)
    elmo_word_mat[1:, :] = embed_weights

    id2word = {word_dictionary[w]: w for w in word_dictionary}
    dev_total = meta["total"]
    print("Building model...")
    parser = get_record_parser(config)
    graph = tf.Graph()
    best_bleu, best_ckpt = 0., 0
    with graph.as_default() as g:
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config.batch_size)
        dev_dataset = get_dataset(config.dev_record_file, parser,
                                  config.batch_size)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model = QGModel(config, word_mat, elmo_word_mat, label_mat, pos_mat,
                        ner_mat)
        model.build_graph()
        model.add_train_op()

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        with tf.Session(config=sess_config) as sess:
            writer = tf.summary.FileWriter(config.output_dir)
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(max_to_keep=1000,
                                   var_list=[
                                       p for p in tf.global_variables()
                                       if "word_mat" not in p.name
                                   ])
            if os.path.exists(os.path.join(config.output_dir, "checkpoint")):
                saver.restore(sess,
                              tf.train.latest_checkpoint(config.output_dir))
            if os.path.exists(config.best_ckpt):
                with open(config.best_ckpt, "r") as fh:
                    best_qg_ckpt = json.load(fh)
                    best_bleu, best_ckpt = float(
                        best_qg_ckpt["best_bleu"]), int(
                            best_qg_ckpt["best_ckpt"])

            global_step = max(sess.run(model.global_step), 1)
            train_next_element = train_iterator.get_next()
            for _ in tqdm(range(global_step, config.num_steps + 1)):
                global_step = sess.run(model.global_step) + 1
                para, para_unk, para_char, que, que_unk, que_char, labels, pos_tags, ner_tags, \
                    que_labels, que_pos_tags, que_ner_tags, y1, y2, qa_id = sess.run(train_next_element)
                loss, _ = sess.run(
                    [model.loss, model.train_op],
                    feed_dict={
                        model.para: para,
                        model.para_unk: para_unk,
                        model.que: que,
                        model.labels: labels,
                        model.pos_tags: pos_tags,
                        model.ner_tags: ner_tags,
                        model.dropout: config.dropout,
                        model.qa_id: qa_id,
                    })
                if global_step % config.period == 0:
                    loss_sum = tf.Summary(value=[
                        tf.Summary.Value(tag="model/loss", simple_value=loss),
                    ])
                    writer.add_summary(loss_sum, global_step)
                if global_step % config.checkpoint == 0:
                    filename = os.path.join(
                        config.output_dir, "model_{}.ckpt".format(global_step))
                    saver.save(sess, filename)

                    metrics = evaluate_batch(config,
                                             model,
                                             config.val_num_batches,
                                             train_eval_file,
                                             sess,
                                             train_iterator,
                                             id2word,
                                             evaluate_func=evaluate_simple)
                    write_metrics(metrics, writer, global_step, "train")

                    metrics = evaluate_batch(config,
                                             model,
                                             dev_total // config.batch_size +
                                             1,
                                             dev_eval_file,
                                             sess,
                                             dev_iterator,
                                             id2word,
                                             evaluate_func=evaluate_simple)
                    write_metrics(metrics, writer, global_step, "dev")

                    bleu = metrics["bleu"]
                    if bleu > best_bleu:
                        best_bleu, best_ckpt = bleu, global_step
                        save(
                            config.best_ckpt, {
                                "best_bleu": str(best_bleu),
                                "best_ckpt": str(best_ckpt)
                            }, config.best_ckpt)
Example #6
def train(args):
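    """Training loop for CommentModel: periodic checkpointing, dev-set
    evaluation by perplexity, best-model saving, and patience-based
    learning-rate halving."""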
    output_dir = args.output
    log_dir = args.log_dir if args.log_dir else os.path.join(output_dir, 'log')
    model_dir = os.path.join(output_dir, 'model')
    records_dir = args.records_dir if not args.data_dir else os.path.join(
        args.data_dir, args.records_dir)
    result_dir = os.path.join(output_dir, 'result')
    for d in [output_dir, log_dir, model_dir, result_dir]:
        if not os.path.exists(d):
            os.makedirs(d)

    # save the args info to the output dir.
    with open(os.path.join(output_dir, 'args.json'), 'w') as p:
        json.dump(vars(args), p, indent=2)

    # load meta info
    with open(os.path.join(records_dir, 'train_meta.json'),
              'r',
              encoding='utf8') as p:
        train_total = json.load(p)['size']
        batch_num_per_epoch = int(np.ceil(train_total / args.batch_size))
        print(f'{time.asctime()} - batch num per epoch: {batch_num_per_epoch}')

    with open(os.path.join(records_dir, 'dev_meta.json'), 'r',
              encoding='utf8') as p:
        dev_total = json.load(p)['size']

    train_records_file = os.path.join(records_dir, 'train.tfrecords')
    dev_records_file = os.path.join(records_dir, 'dev.tfrecords')

    with tf.Graph().as_default() as graph, tf.device('/gpu:0'):

        parser = get_record_parser(args)
        train_dataset = get_batch_dataset(train_records_file, parser, args)
        dev_dataset = get_batch_dataset(dev_records_file, parser, args)

        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_dataset.output_types, train_dataset.output_shapes)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model = CommentModel(args, iterator)

        session_config = tf.ConfigProto(allow_soft_placement=True)
        session_config.gpu_options.allow_growth = True
        sess = tf.Session(config=session_config)

        # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)

        writer = tf.summary.FileWriter(log_dir)
        best_ppl = tf.Variable(300,
                               trainable=False,
                               name='best_ppl',
                               dtype=tf.float32)

        saver = tf.train.Saver(max_to_keep=10000)
        if args.restore:
            model_file = args.restore_model or tf.train.latest_checkpoint(
                model_dir)
            print(f'{time.asctime()} - Restore model from {model_file}..')
            var_list = [
                _[0] for _ in checkpoint_utils.list_variables(model_file)
            ]
            saved_vars = [
                _ for _ in tf.global_variables()
                if _.name.split(':')[0] in var_list
            ]
            res_saver = tf.train.Saver(saved_vars)
            res_saver.restore(sess, model_file)

            left_vars = [
                _ for _ in tf.global_variables()
                if _.name.split(':')[0] not in var_list
            ]
            sess.run(tf.variables_initializer(left_vars))
            print(
                f'{time.asctime()} - Restore {len(var_list)} vars and initialize {len(left_vars)} vars.'
            )
            print(left_vars)
        else:
            print(f'{time.asctime()} - Initialize model..')
            sess.run(tf.global_variables_initializer())
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)

        train_handle = sess.run(train_iterator.string_handle())
        dev_handle = sess.run(dev_iterator.string_handle())

        sess.run(tf.assign(model.is_train, tf.constant(True,
                                                       dtype=tf.bool)))  #tmp

        patience = 0

        lr = sess.run(model.lr)
        b_ppl = sess.run(best_ppl)
        print(f'{time.asctime()} - lr: {lr:.3f}  best_ppl:{b_ppl:.3f}')

        t0 = datetime.now()

        while True:
            global_step = sess.run(model.global_step) + 1
            epoch = int(np.ceil(global_step / batch_num_per_epoch))

            loss, loss_gen, ppl, train_op, merge_sum, target, check_1 = sess.run(
                [
                    model.loss, model.loss_gen, model.ppl, model._train_op,
                    model._summaries, model.target, model.check_dec_outputs
                ],
                feed_dict={handle: train_handle})

            ela_time = str(datetime.now() - t0).split('.')[0]

            print(
                (f'{time.asctime()} - step/epoch:{global_step}/{epoch:<3d}   '
                 f'gen_loss:{loss_gen:<3.3f}  '
                 f'ppl:{ppl:<4.3f}  '
                 f'elapsed:{ela_time}\r'),
                end='')

            if global_step % args.period == 0:
                writer.add_summary(merge_sum, global_step)
                writer.flush()

            if global_step % args.checkpoint == 0:
                model_file = os.path.join(model_dir, 'model')
                saver.save(sess, model_file, global_step=global_step)

            # if  global_step % batch_num_per_epoch== 0:
            if global_step % args.checkpoint == 0 and not args.no_eval:
                sess.run(
                    tf.assign(model.is_train, tf.constant(False,
                                                          dtype=tf.bool)))
                metrics, summ = evaluate_batch(model,
                                               dev_total // args.batch_size,
                                               sess, handle, dev_handle,
                                               iterator)
                sess.run(
                    tf.assign(model.is_train, tf.constant(True,
                                                          dtype=tf.bool)))

                for s in summ:
                    writer.add_summary(s, global_step)

                dev_ppl = metrics['ppl']
                dev_gen_loss = metrics['gen_loss']

                tqdm.write(
                    f'{time.asctime()} - Evaluate after steps:{global_step}, '
                    f' gen_loss:{dev_gen_loss:.4f},  ppl:{dev_ppl:.3f}')

                if dev_ppl < b_ppl:
                    b_ppl = dev_ppl
                    sess.run(tf.assign(best_ppl, dev_ppl))
                    saver.save(sess, save_path=os.path.join(model_dir, 'best'))
                    tqdm.write(
                        f'{time.asctime()} - the ppl is lower than the current best, so the model was saved.'
                    )
                    patience = 0
                else:
                    patience += 1

                if patience >= args.patience:
                    lr = lr / 2
                    sess.run(
                        tf.assign(model.lr, tf.constant(lr, dtype=tf.float32)))
                    patience = 0
                    tqdm.write(
                        f'{time.asctime()} - The lr is decayed from {lr*2} to {lr}.'
                    )
Example #7
def debug(args):
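    """Inspect the raw batches fed to CommentModel and dump any batch
    whose loss looks abnormal."""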
    from utils import get_record_parser, get_batch_dataset
    import tensorflow as tf
    # parser = get_record_parser(args)
    # dataset = get_batch_dataset('data/records/dev.tfrecords', parser, args)
    # iterator = dataset.make_one_shot_iterator()
    # sess = tf.Session()
    # while True:
    #     print(sess.run(iterator.get_next()))
    # break
    # vocab = Vocab(args.vocab, args.vocab_size)
    # test_file = os.path.join(args.data, 'news_test.data')
    # batcher = TestBatcher(args, vocab, test_file).batcher()
    # for b in batcher:
    #     pass

    parser = get_record_parser(args)
    dataset = get_batch_dataset('data/records/dev.tfrecords', parser, args)

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(handle,
                                                   dataset.output_types,
                                                   dataset.output_shapes)
    train_iterator = dataset.make_one_shot_iterator()

    model = CommentModel(args, iterator)

    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.allow_growth = True

    sess = tf.Session(config=session_config)
    sess.run(tf.global_variables_initializer())

    train_handle = sess.run(train_iterator.string_handle())

    sess.run(tf.assign(model.is_train, tf.constant(True, dtype=tf.bool)))
    get_results = {
        # 'loss' is fetched so the abnormal-batch check below can use it
        'loss': model.loss,
        'description': model.description,
        'description_sen': model.description_sen,
        'description_off': model.description_off,
        'description_len': model.description_len,
        'description_mask': model.description_mask,
        'query': model.query,
        'query_len': model.query_len,
        'query_mask': model.query_mask,
        'response': model.response,
        'response_len': model.response_len,
        'response_mask': model.response_mask,
        'target': model.target,
        'target_len': model.target_len,
        'target_mask': model.target_mask,
        # 'x': model.x,
        # 'y': model.y,
        # 'spans': model.span_seq,
        # 'span_num': model.span_num,
        # 'span_mask': model.span_mask
    }

    from pprint import pprint
    while True:
        results = sess.run(get_results, feed_dict={handle: train_handle})
        results = {k: v.tolist() for k, v in results.items()}
        if results['loss'] > 100000:
            pprint(results['loss'], width=1000)
            pprint(results['target'], width=1000)
            pprint(results['target_mask'], width=1000)
Example #8
def test(config):
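    """Two-graph test pipeline: a frozen BERT graph produces contextual
    paragraph embeddings that are fed to the QG graph, which is restored
    from the best recorded checkpoint before decoding the test set."""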
    with open(config.word_emb_file, "rb") as fh:
        word_mat = np.array(pkl.load(fh), dtype=np.float32)
    with open(config.pos_emb_file, "r") as fh:
        pos_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.ner_emb_file, "r") as fh:
        ner_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.label_emb_file, "r") as fh:
        label_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.word_dictionary, "r") as fh:
        word_dictionary = json.load(fh)
    with open(config.test_eval_file, "r") as fh:
        test_eval_file = json.load(fh)
    with open(config.test_meta, "r") as fh:
        meta = json.load(fh)
    with open(config.map_to_orig, "rb") as fh:
        map_to_orig = pkl.load(fh)

    id2word = {word_dictionary[w]: w for w in word_dictionary}
    test_total = meta["total"]
    best_bleu, best_ckpt = 0., 0

    print("Building model...")
    parser = get_record_parser(config)
    graph_para = tf.Graph()
    graph_qg = tf.Graph()
    with graph_para.as_default() as g:
        test_dataset = get_dataset(config.test_record_file, parser,
                                   config.batch_size)
        test_iterator = test_dataset.make_one_shot_iterator()
        model_para = BertEmb(config, config.para_limit + 2, graph=g)
    with graph_qg.as_default() as g:
        model_qg = QGModel(config,
                           word_mat,
                           label_mat,
                           pos_mat,
                           ner_mat,
                           trainable=False)
        model_qg.build_graph()
        model_qg.add_train_op()

    sess_para = tf.Session(graph=graph_para)
    sess_qg = tf.Session(graph=graph_qg)

    with sess_para.as_default():
        with graph_para.as_default():
            print("init from pretrained bert..")
            tvars = tf.trainable_variables()
            (assignment_map, initialized_variable_names) = \
                modeling.get_assignment_map_from_checkpoint(tvars, config.init_checkpoint)
            tf.train.init_from_checkpoint(config.init_checkpoint,
                                          assignment_map)
            sess_para.run(tf.global_variables_initializer())
    with sess_qg.as_default():
        with graph_qg.as_default():
            sess_qg.run(tf.global_variables_initializer())
            saver_qg = tf.train.Saver(max_to_keep=1000,
                                      var_list=[
                                          p for p in tf.global_variables()
                                          if "word_mat" not in p.name
                                      ])
            if os.path.exists(config.best_ckpt):
                with open(config.best_ckpt, "r") as fh:
                    best_qg_ckpt = json.load(fh)
                    best_bleu, best_ckpt = float(
                        best_qg_ckpt["best_bleu"]), int(
                            best_qg_ckpt["best_ckpt"])
                    print("best_bleu:{}, best_ckpt:{}".format(
                        best_bleu, best_ckpt))
            else:
                print("No best checkpoint!")
                exit()
            checkpoint = "{}/model_{}.ckpt".format(config.output_dir,
                                                   best_ckpt)
            print(checkpoint)
            saver_qg.restore(sess_qg, checkpoint)
    writer = tf.summary.FileWriter(config.output_dir)
    metrics = evaluate_batch(config,
                             model_para,
                             model_qg,
                             sess_para,
                             sess_qg,
                             test_total // config.batch_size + 1,
                             test_eval_file,
                             test_iterator,
                             id2word,
                             map_to_orig,
                             evaluate_func=evaluate)
    print(metrics)
    write_metrics(metrics, writer, best_ckpt, "test")
Example #9
def train(config):
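    """Train the QG model on BERT paragraph embeddings computed on the fly
    by a frozen BERT model living in its own graph and session."""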
    with open(config.word_emb_file, "rb") as fh:
        word_mat = np.array(pkl.load(fh), dtype=np.float32)
    with open(config.pos_emb_file, "r") as fh:
        pos_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.ner_emb_file, "r") as fh:
        ner_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.label_emb_file, "r") as fh:
        label_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.word_dictionary, "r") as fh:
        word_dictionary = json.load(fh)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)
    with open(config.map_to_orig, "rb") as fh:
        map_to_orig = pkl.load(fh)

    id2word = {word_dictionary[w]: w for w in word_dictionary}
    dev_total = meta["total"]
    best_bleu, best_ckpt = 0., 0
    print("Building model...")
    parser = get_record_parser(config)
    graph_para = tf.Graph()
    graph_qg = tf.Graph()
    with graph_para.as_default() as g:
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config.batch_size)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_dataset = get_dataset(config.dev_record_file, parser,
                                  config.batch_size)
        dev_iterator = dev_dataset.make_one_shot_iterator()
        model_para = BertEmb(config, config.para_limit + 2, graph=g)
    with graph_qg.as_default() as g:
        model_qg = QGModel(config, word_mat, label_mat, pos_mat, ner_mat)
        model_qg.build_graph()
        model_qg.add_train_op()

    sess_para = tf.Session(graph=graph_para)
    sess_qg = tf.Session(graph=graph_qg)

    with sess_para.as_default():
        with graph_para.as_default():
            print("init from pretrained bert..")
            tvars = tf.trainable_variables()
            (assignment_map, initialized_variable_names) = \
                modeling.get_assignment_map_from_checkpoint(tvars, config.init_checkpoint)
            tf.train.init_from_checkpoint(config.init_checkpoint,
                                          assignment_map)
            sess_para.run(tf.global_variables_initializer())
    with sess_qg.as_default():
        with graph_qg.as_default():
            sess_qg.run(tf.global_variables_initializer())
            saver_qg = tf.train.Saver(max_to_keep=1000,
                                      var_list=[
                                          p for p in tf.global_variables()
                                          if "word_mat" not in p.name
                                      ])
            if os.path.exists(os.path.join(config.output_dir, "checkpoint")):
                print(tf.train.latest_checkpoint(config.output_dir))
                saver_qg.restore(sess_qg,
                                 tf.train.latest_checkpoint(config.output_dir))
            if os.path.exists(config.best_ckpt):
                with open(config.best_ckpt, "r") as fh:
                    best_qg_ckpt = json.load(fh)
                    best_bleu, best_ckpt = float(
                        best_qg_ckpt["best_bleu"]), int(
                            best_qg_ckpt["best_ckpt"])
                    print("best_bleu:{}, best_ckpt:{}".format(
                        best_bleu, best_ckpt))

    writer = tf.summary.FileWriter(config.output_dir)
    global_step = max(sess_qg.run(model_qg.global_step), 1)
    train_next_element = train_iterator.get_next()
    for _ in tqdm(range(global_step, config.num_steps + 1)):
        global_step = sess_qg.run(model_qg.global_step) + 1
        para, para_unk, ques, labels, pos_tags, ner_tags, qa_id = sess_para.run(
            train_next_element)
        para_emb = sess_para.run(model_para.bert_emb,
                                 feed_dict={model_para.input_ids: para_unk})
        loss, _ = sess_qg.run(
            [model_qg.loss, model_qg.train_op],
            feed_dict={
                model_qg.para: para,
                model_qg.bert_para: para_emb,
                model_qg.que: ques,
                model_qg.labels: labels,
                model_qg.pos_tags: pos_tags,
                model_qg.ner_tags: ner_tags,
                model_qg.dropout: config.dropout,
                model_qg.qa_id: qa_id
            })
        if global_step % config.period == 0:
            loss_sum = tf.Summary(value=[
                tf.Summary.Value(tag="model/loss", simple_value=loss),
            ])
            writer.add_summary(loss_sum, global_step)
        if global_step % config.checkpoint == 0:
            filename = os.path.join(config.output_dir,
                                    "model_{}.ckpt".format(global_step))
            saver_qg.save(sess_qg, filename)

            metrics = evaluate_batch(config,
                                     model_para,
                                     model_qg,
                                     sess_para,
                                     sess_qg,
                                     config.val_num_batches,
                                     train_eval_file,
                                     train_iterator,
                                     id2word,
                                     map_to_orig,
                                     evaluate_func=evaluate_simple)
            write_metrics(metrics, writer, global_step, "train")

            metrics = evaluate_batch(config,
                                     model_para,
                                     model_qg,
                                     sess_para,
                                     sess_qg,
                                     dev_total // config.batch_size + 1,
                                     dev_eval_file,
                                     dev_iterator,
                                     id2word,
                                     map_to_orig,
                                     evaluate_func=evaluate_simple)
            write_metrics(metrics, writer, global_step, "dev")
            bleu = metrics["bleu"]
            if bleu > best_bleu:
                best_bleu, best_ckpt = bleu, global_step
                save(config.best_ckpt, {
                    "best_bleu": str(best_bleu),
                    "best_ckpt": str(best_ckpt)
                }, config.best_ckpt)