import json
import os

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# Project-local helpers assumed from the surrounding repo: Model,
# get_record_parser, get_dataset, get_batch_dataset, convert_tokens,
# evaluate, evaluate_batch.


def test(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.test_eval_file, "r") as fh:
        eval_file = json.load(fh)
    with open(config.test_meta, "r") as fh:
        meta = json.load(fh)

    total = meta["total"]

    graph = tf.Graph()
    print("Loading model...")
    with graph.as_default() as g:
        test_batch = get_dataset(config.test_record_file,
                                 get_record_parser(config, is_test=True),
                                 config).make_one_shot_iterator()
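        # A one-shot iterator suffices for testing: this graph reads from a
        # single input pipeline, so no feedable string handle is needed.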

        model = Model(config,
                      test_batch,
                      word_mat,
                      char_mat,
                      trainable=False,
                      graph=g)

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        with tf.Session(config=sess_config) as sess:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, tf.train.latest_checkpoint(config.save_dir))
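            # With decay < 1.0 the model keeps exponential-moving-average
            # shadow weights; model.assign_vars (presumably the EMA
            # assignments) swaps them in for evaluation.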
            if config.decay < 1.0:
                sess.run(model.assign_vars)
            losses = []
            answer_dict = {}
            remapped_dict = {}
            for step in tqdm(range(total // config.batch_size + 1)):
                qa_id, loss, yp1, yp2 = sess.run(
                    [model.qa_id, model.loss, model.yp1, model.yp2])
                answer_dict_, remapped_dict_ = convert_tokens(
                    eval_file, qa_id.tolist(), yp1.tolist(), yp2.tolist())
                answer_dict.update(answer_dict_)
                remapped_dict.update(remapped_dict_)
                losses.append(loss)
            loss = np.mean(losses)
            metrics = evaluate(eval_file, answer_dict)
            with open(config.answer_file, "w") as fh:
                json.dump(remapped_dict, fh)
            print("Loss: {}, Exact Match: {}, F1: {}".format(
                loss, metrics['exact_match'], metrics['f1']))


def train(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)

    dev_total = meta["total"]

    print("Building model...")
    parser = get_record_parser(config)
    train_dataset = get_batch_dataset(config.train_record_file, parser, config)
    dev_dataset = get_dataset(config.dev_record_file, parser, config)
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(handle,
                                                   train_dataset.output_types,
                                                   train_dataset.output_shapes)
    train_iterator = train_dataset.make_one_shot_iterator()
    dev_iterator = dev_dataset.make_one_shot_iterator()
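    # The feedable iterator picks its source pipeline at run time: feeding
    # train_handle pulls a train batch, feeding dev_handle a dev batch,
    # through one shared graph.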

    model = Model(config, iterator, word_mat, char_mat)

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True

    loss_save = 100.0
    patience = 0
    lr = config.init_lr

    with tf.Session(config=sess_config) as sess:
        writer = tf.summary.FileWriter(config.log_dir)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        train_handle = sess.run(train_iterator.string_handle())
        dev_handle = sess.run(dev_iterator.string_handle())
        sess.run(tf.assign(model.is_train, tf.constant(True, dtype=tf.bool)))
        sess.run(tf.assign(model.lr, tf.constant(lr, dtype=tf.float32)))
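        # is_train and lr live in the graph as variables here, so they are
        # toggled with tf.assign ops rather than fed through placeholders.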

        for _ in tqdm(range(1, config.num_steps + 1)):
            global_step = sess.run(model.global_step) + 1
            loss, train_op = sess.run([model.loss, model.train_op],
                                      feed_dict={handle: train_handle})
            if global_step % config.period == 0:
                loss_sum = tf.Summary(value=[
                    tf.Summary.Value(tag="model/loss", simple_value=loss),
                ])
                writer.add_summary(loss_sum, global_step)
            if global_step % config.checkpoint == 0:
                sess.run(
                    tf.assign(model.is_train, tf.constant(False,
                                                          dtype=tf.bool)))
                _, summ = evaluate_batch(model, config.val_num_batches,
                                         train_eval_file, sess, "train",
                                         handle, train_handle)
                for s in summ:
                    writer.add_summary(s, global_step)

                metrics, summ = evaluate_batch(
                    model, dev_total // config.batch_size + 1, dev_eval_file,
                    sess, "dev", handle, dev_handle)
                sess.run(
                    tf.assign(model.is_train, tf.constant(True,
                                                          dtype=tf.bool)))

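                # Learning-rate schedule: halve lr after config.patience
                # consecutive evaluations without a new best dev loss.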
                dev_loss = metrics["loss"]
                if dev_loss < loss_save:
                    loss_save = dev_loss
                    patience = 0
                else:
                    patience += 1
                if patience >= config.patience:
                    lr /= 2.0
                    loss_save = dev_loss
                    patience = 0
                sess.run(tf.assign(model.lr, tf.constant(lr,
                                                         dtype=tf.float32)))
                for s in summ:
                    writer.add_summary(s, global_step)
                writer.flush()
                filename = os.path.join(config.save_dir,
                                        "model_{}.ckpt".format(global_step))
                saver.save(sess, filename)


def train(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)

    dev_total = meta["total"]
    print("Building model...")
    parser = get_record_parser(config)
    graph = tf.Graph()
    with graph.as_default() as g:
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config)
        dev_dataset = get_dataset(config.dev_record_file, parser, config)
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_dataset.output_types, train_dataset.output_shapes)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model = Model(config, iterator, word_mat, char_mat, graph=g)

        # Let TF fall back to an available device if the requested one
        # does not exist.
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        loss_save = 100.0
        patience = 0
        best_f1 = 0.
        best_em = 0.

        with tf.Session(config=sess_config) as sess:
            writer = tf.summary.FileWriter(config.log_dir)
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            train_handle = sess.run(train_iterator.string_handle())
            dev_handle = sess.run(dev_iterator.string_handle())
            if os.path.exists(os.path.join(config.save_dir, "checkpoint")):
                saver.restore(sess,
                              tf.train.latest_checkpoint(config.save_dir))
            global_step = max(sess.run(model.global_step), 1)
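            # Resume the step counter from the checkpoint so a restarted run
            # continues toward config.num_steps instead of starting over.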

            for _ in tqdm(range(global_step, config.num_steps + 1)):
                global_step = sess.run(model.global_step) + 1
                loss, train_op = sess.run([model.loss, model.train_op],
                                          feed_dict={
                                              handle: train_handle,
                                              model.dropout: config.dropout
                                          })
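                # model.dropout is fed per step (presumably a
                # tf.placeholder_with_default), so evaluation can run the
                # same graph with dropout disabled.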
                if global_step % config.period == 0:
                    loss_sum = tf.Summary(value=[
                        tf.Summary.Value(tag="model/loss", simple_value=loss),
                    ])
                    writer.add_summary(loss_sum, global_step)
                if global_step % config.checkpoint == 0:
                    _, summ = evaluate_batch(model, config.val_num_batches,
                                             train_eval_file, sess, "train",
                                             handle, train_handle)
                    for s in summ:
                        writer.add_summary(s, global_step)

                    metrics, summ = evaluate_batch(
                        model, dev_total // config.batch_size + 1,
                        dev_eval_file, sess, "dev", handle, dev_handle)

                    dev_f1 = metrics["f1"]
                    dev_em = metrics["exact_match"]
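                    # Early stopping: a checkpoint counts against patience
                    # only when both dev F1 and EM fail to beat their bests.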
                    if dev_f1 < best_f1 and dev_em < best_em:
                        patience += 1
                        if patience > config.early_stop:
                            break
                    else:
                        patience = 0
                        best_em = max(best_em, dev_em)
                        best_f1 = max(best_f1, dev_f1)

                    for s in summ:
                        writer.add_summary(s, global_step)
                    writer.flush()
                    filename = os.path.join(
                        config.save_dir, "model_{}.ckpt".format(global_step))
                    saver.save(sess, filename)


def train(config):
    # word embedding matrix, shape [91589, 300]
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    # char embedding matrix, shape [1427, 64]
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    # train_eval_file maps each example id to an entry like:
    # {'context': 'Architecturally, the school has a Catholic character. Atop
    #   the Main Building\'s gold dome is a golden statue of the Virgin Mary.
    #   ... is a simple, modern stone statue of Mary.',
    #  'spans': [[0, 15], [15, 16], [17, 20], ..., [694, 695]],  # character offsets of each token in context
    #  'answers': ['Saint Bernadette Soubirous'],
    #  'uuid': '5733be284776f41900661182'}
    # 87599 entries in total.
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    # dev_eval_file has the same format as train_eval_file;
    # 10570 entries in total.
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    # dev_meta records the number of dev examples ("total": 10482 here)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)

    dev_total = meta["total"]
    print("BIDAF Building model...")
    parser = get_record_parser(config)
    graph = tf.Graph()
    with graph.as_default() as g:
        # The parser yields context_idxs, ques_idxs, context_char_idxs,
        # ques_char_idxs, y1, y2, qa_id; the batched shapes and dtypes are:
        # <BatchDataset shapes: ((?, 400), (?, 50), (?, 400, 16), (?, 50, 16), (?, 400), (?, 400), (?,)),
        # types: (tf.int32, tf.int32, tf.int32, tf.int32, tf.float32, tf.float32, tf.int64)>
        train_data_set = get_batch_dataset(config.train_record_file, parser, config)
        dev_data_set = get_dataset(config.dev_record_file, parser, config)
        # Feedable iterator (select the input pipeline at run time via a
        # string handle): https://www.bilibili.com/read/cv647026
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_data_set.output_types, train_data_set.output_shapes)
        train_iterator = train_data_set.make_one_shot_iterator()
        dev_iterator = dev_data_set.make_one_shot_iterator()

        model = Model(config, iterator, word_mat, char_mat, graph=g)

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        patience = 0
        best_f1 = 0.
        best_em = 0.

        with tf.Session(config=sess_config) as sess:
            writer = tf.summary.FileWriter(config.log_dir)
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            train_handle = sess.run(train_iterator.string_handle())
            dev_handle = sess.run(dev_iterator.string_handle())
            if os.path.exists(os.path.join(config.save_dir, "checkpoint")):
                saver.restore(sess, tf.train.latest_checkpoint(config.save_dir))
            # Resume from the restored step (at least 1).
            global_step = max(sess.run(model.global_step), 1)

            for _ in tqdm(range(global_step, config.num_steps + 1)):
                global_step = sess.run(model.global_step) + 1
                loss, train_op = sess.run([model.loss, model.train_op], feed_dict={handle: train_handle,
                                                                                   model.dropout: config.dropout})
                if global_step % config.period == 0:
                    loss_sum = tf.Summary(value=[tf.Summary.Value(tag="model/loss", simple_value=loss), ])
                    writer.add_summary(loss_sum, global_step)
                if global_step % config.checkpoint == 0:
                    _, summ = evaluate_batch(
                        model, config.val_num_batches, train_eval_file, sess, "train", handle, train_handle)
                    for s in summ:
                        writer.add_summary(s, global_step)

                    metrics, summ = evaluate_batch(
                        model, dev_total // config.batch_size + 1, dev_eval_file, sess, "dev", handle, dev_handle)
                    dev_f1 = metrics["f1"]
                    dev_em = metrics["exact_match"]
                    if dev_f1 < best_f1 and dev_em < best_em:
                        patience += 1
                        if patience > config.early_stop:
                            break
                    else:
                        patience = 0
                        best_em = max(best_em, dev_em)
                        best_f1 = max(best_f1, dev_f1)

                    for s in summ:
                        writer.add_summary(s, global_step)
                    writer.flush()
                    filename = os.path.join(config.save_dir, "model_{}.ckpt".format(global_step))
                    saver.save(sess, filename)
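

# evaluate_batch is not shown in these examples. Below is a minimal sketch of
# what it plausibly looks like, reconstructed only from how it is called here:
# it draws num_batches batches through the shared handle, decodes predictions
# with convert_tokens (as test() does above), and returns the metrics dict
# plus tf.Summary protos tagged by data_type. Treat it as an assumption, not
# the repo's actual implementation.
def evaluate_batch(model, num_batches, eval_file, sess, data_type,
                   handle, str_handle):
    answer_dict = {}
    losses = []
    for _ in range(1, num_batches + 1):
        qa_id, loss, yp1, yp2 = sess.run(
            [model.qa_id, model.loss, model.yp1, model.yp2],
            feed_dict={handle: str_handle})
        answer_dict_, _ = convert_tokens(
            eval_file, qa_id.tolist(), yp1.tolist(), yp2.tolist())
        answer_dict.update(answer_dict_)
        losses.append(loss)
    metrics = evaluate(eval_file, answer_dict)
    metrics["loss"] = np.mean(losses)
    # One scalar summary per metric, tagged e.g. "dev/f1".
    summ = [
        tf.Summary(value=[tf.Summary.Value(
            tag="{}/{}".format(data_type, name), simple_value=value)])
        for name, value in [("loss", metrics["loss"]),
                            ("f1", metrics["f1"]),
                            ("em", metrics["exact_match"])]
    ]
    return metrics, summ


# A hypothetical driver tying the entry points together; the "mode" flag and
# the FLAGS plumbing are illustrative assumptions, not part of the examples.
def main(_):
    config = tf.app.flags.FLAGS
    if config.mode == "train":
        train(config)
    elif config.mode == "test":
        test(config)


if __name__ == "__main__":
    tf.app.run()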