Example #1
# Imports common to all four examples below (project-local helpers such as
# get_record_parser, get_batch_dataset, get_dataset, Model, convert_tokens,
# write_prediction, evaluate, evaluate_batch and evaluate_batch_cand are
# assumed to be in scope).
import json
import os

import numpy as np
import tensorflow as tf
from tqdm import tqdm


def test(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.test_eval_file, "r") as fh:
        eval_file = json.load(fh)
    with open(config.test_meta, "r") as fh:
        meta = json.load(fh)

    total = meta["num_batches"]

    print("Loading model...")
    test_batch = get_batch_dataset(config.test_record_file, get_record_parser(
        config, is_test=True), config, is_test=True).make_one_shot_iterator()

    model = Model(config, test_batch, word_mat, char_mat, trainable=False)

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True

    with tf.Session(config=sess_config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(config.save_dir))
        sess.run(tf.assign(model.is_train, tf.constant(False, dtype=tf.bool)))
        losses = []
        answer_dict = {}
        select_right = []
        for step in tqdm(range(1, total + 1)):
            qa_id, loss, yp1, yp2, y1, y2, is_select_p, is_select = sess.run(
                [model.qa_id, model.loss, model.yp1, model.yp2, model.y1,
                 model.y2, model.is_select_p, model.is_select])
            # Gold and predicted answer-span boundaries.
            y1 = np.argmax(y1, axis=-1)
            y2 = np.argmax(y2, axis=-1)
            # Predicted vs. gold passage selection, offset into the flattened
            # batch (each question contributes config.passage_num passages).
            sp = np.argmax(is_select_p, axis=-1)
            s = np.argmax(is_select, axis=-1)
            sp = [n + i * config.passage_num for i, n in enumerate(sp.tolist())]
            s = [m + i * config.passage_num for i, m in enumerate(s.tolist())]
            select_right.append(len(set(s).intersection(set(sp))))

            answer_dict_, _ = convert_tokens(
                eval_file, [qa_id[n] for n in sp], [yp1[n] for n in sp],
                [yp2[n] for n in sp], [y1[n] for n in sp], [y2[n] for n in sp],
                sp, s)
            answer_dict.update(answer_dict_)
            losses.append(loss)
        loss = np.mean(losses)
        select_accu = sum(select_right) / (
            len(select_right) * (config.batch_size / config.passage_num))
        write_prediction(eval_file, answer_dict, 'answer_for_evl.json', config)
        metrics = evaluate(eval_file, answer_dict, filter=False)
        metrics['Selection Accuracy'] = select_accu
        
        print("Exact Match: {}, F1: {}, selection accuracy: {}".format(
            metrics['exact_match'], metrics['f1'], metrics['Selection Accuracy']))
Example #2
File: main.py Project: txye/QANet
def train(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)

    dev_total = meta["total"]
    print("Building model...")
    parser = get_record_parser(config)
    graph = tf.Graph()
    with graph.as_default() as g:
        train_dataset = get_batch_dataset(config.train_record_file, parser, config)
        dev_dataset = get_dataset(config.dev_record_file, parser, config)
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_dataset.output_types, train_dataset.output_shapes)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_iterator = dev_dataset.make_one_shot_iterator()

        model = Model(config, iterator, word_mat, char_mat, graph=g)

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        loss_save = 100.0
        patience = 0
        best_f1 = 0.
        best_em = 0.

        with tf.Session(config=sess_config) as sess:
            writer = tf.summary.FileWriter(config.log_dir)
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            train_handle = sess.run(train_iterator.string_handle())
            dev_handle = sess.run(dev_iterator.string_handle())
            if os.path.exists(os.path.join(config.save_dir, "checkpoint")):
                saver.restore(sess, tf.train.latest_checkpoint(config.save_dir))
            # Resume from the checkpointed step (at least 1) so tqdm counts
            # only the remaining steps.
            global_step = max(sess.run(model.global_step), 1)

            for _ in tqdm(range(global_step, config.num_steps + 1)):
                global_step = sess.run(model.global_step) + 1
                loss, _ = sess.run([model.loss, model.train_op], feed_dict={
                    handle: train_handle, model.dropout: config.dropout})
                if global_step % config.period == 0:
                    loss_sum = tf.Summary(value=[tf.Summary.Value(
                        tag="model/loss", simple_value=loss), ])
                    writer.add_summary(loss_sum, global_step)
                if global_step % config.checkpoint == 0:
                    _, summ = evaluate_batch(
                        model, config.val_num_batches, train_eval_file, sess, "train", handle, train_handle)
                    for s in summ:
                        writer.add_summary(s, global_step)

                    metrics, summ = evaluate_batch(
                        model, dev_total // config.batch_size + 1, dev_eval_file, sess, "dev", handle, dev_handle)

                    dev_f1 = metrics["f1"]
                    dev_em = metrics["exact_match"]
                    if dev_f1 < best_f1 and dev_em < best_em:
                        patience += 1
                        if patience > config.early_stop:
                            break
                    else:
                        patience = 0
                        best_em = max(best_em, dev_em)
                        best_f1 = max(best_f1, dev_f1)

                    for s in summ:
                        writer.add_summary(s, global_step)
                    writer.flush()
                    filename = os.path.join(
                        config.save_dir, "model_{}.ckpt".format(global_step))
                    saver.save(sess, filename)
Example #3
def train(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)

    dev_total = meta["total"]

    print("Building model...")
    parser = get_record_parser(config)
    train_dataset = get_batch_dataset(config.train_record_file, parser, config)
    dev_dataset = get_dataset(config.dev_record_file, parser, config)
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_dataset.output_types, train_dataset.output_shapes)
    train_iterator = train_dataset.make_one_shot_iterator()
    dev_iterator = dev_dataset.make_one_shot_iterator()

    model = Model(config, iterator, word_mat, char_mat)

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True

    loss_save = 100.0
    patience = 0
    lr = config.init_lr

    with tf.Session(config=sess_config) as sess:
        writer = tf.summary.FileWriter(config.log_dir)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        train_handle = sess.run(train_iterator.string_handle())
        dev_handle = sess.run(dev_iterator.string_handle())
        sess.run(tf.assign(model.is_train, tf.constant(True, dtype=tf.bool)))
        sess.run(tf.assign(model.lr, tf.constant(lr, dtype=tf.float32)))

        for _ in tqdm(range(1, config.num_steps + 1)):
            global_step = sess.run(model.global_step) + 1
            loss, _ = sess.run([model.loss, model.train_op],
                               feed_dict={handle: train_handle})
            if global_step % config.period == 0:
                loss_sum = tf.Summary(value=[tf.Summary.Value(
                    tag="model/loss", simple_value=loss), ])
                writer.add_summary(loss_sum, global_step)
            if global_step % config.checkpoint == 0:
                sess.run(tf.assign(model.is_train,
                                   tf.constant(False, dtype=tf.bool)))
                _, summ = evaluate_batch(
                    model, config.val_num_batches, train_eval_file, sess, "train", handle, train_handle)
                for s in summ:
                    writer.add_summary(s, global_step)

                metrics, summ = evaluate_batch(
                    model, dev_total // config.batch_size + 1, dev_eval_file, sess, "dev", handle, dev_handle)
                sess.run(tf.assign(model.is_train,
                                   tf.constant(True, dtype=tf.bool)))

                # Reset patience whenever dev loss improves; otherwise count
                # the stalled evaluations.
                dev_loss = metrics["loss"]
                if dev_loss < loss_save:
                    loss_save = dev_loss
                    patience = 0
                else:
                    patience += 1
                # Halve the learning rate once dev loss has not improved for
                # config.patience consecutive evaluations.
                if patience >= config.patience:
                    lr /= 2.0
                    loss_save = dev_loss
                    patience = 0
                sess.run(tf.assign(model.lr, tf.constant(lr, dtype=tf.float32)))
                for s in summ:
                    writer.add_summary(s, global_step)
                writer.flush()
                filename = os.path.join(
                    config.save_dir, "model_{}.ckpt".format(global_step))
                saver.save(sess, filename)
Example #4
def train(config):
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    with open(config.dev_meta, "r") as fh:
        meta = json.load(fh)
    '''
    Iterator: represents the state of iterating through a Dataset.

    from_string_handle(): https://www.tensorflow.org/api_docs/python/tf/data/Iterator
        This method lets you define a "feedable" iterator, where you choose
        between concrete iterators by feeding a value in a tf.Session.run
        call. In that case, string_handle would be a tf.placeholder, and you
        would feed it the value of tf.data.Iterator.string_handle at each
        step. (A standalone sketch of this pattern follows Example #4.)

    make_one_shot_iterator(): creates an Iterator for enumerating the
        elements of this dataset. The returned iterator is initialized
        automatically. A "one-shot" iterator does not currently support
        re-initialization.
    '''

    dev_total = meta["total"]

    print("get recorded feature parser...")
    parser = get_record_parser(config)

    graph = tf.Graph()
    with graph.as_default() as g:

        print("get dataset batch iterator...")
        train_dataset = get_batch_dataset(config.train_record_file, parser,
                                          config)
        train_iterator = train_dataset.make_one_shot_iterator()
        dev_dataset = get_dataset(config.dev_record_file, parser, config)
        dev_iterator = dev_dataset.make_one_shot_iterator()

        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_dataset.output_types, train_dataset.output_shapes)

        print("Building model...")
        model = Model(config, iterator, word_mat, char_mat, graph=g)

        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        patience = 0
        best_em = 0.0
        best_f1 = 0.0
        best_acc = 0.0

        with tf.Session(config=sess_config) as sess:

            sess.run(tf.global_variables_initializer())

            train_handle = sess.run(train_iterator.string_handle())
            dev_handle = sess.run(dev_iterator.string_handle())

            saver = tf.train.Saver(max_to_keep=100)
            if os.path.exists(os.path.join(config.save_dir, "checkpoint")):
                saver.restore(sess,
                              tf.train.latest_checkpoint(config.save_dir))

            writer = tf.summary.FileWriter(config.log_dir)

            # Resume from the checkpointed step (at least 1).
            global_step = max(sess.run(model.global_step), 1)
            for _ in tqdm(range(global_step, config.num_steps + 1)):
                global_step = sess.run(model.global_step) + 1

                # train on one batch
                loss, _ = sess.run([model.loss, model.train_op],
                                   feed_dict={
                                       handle: train_handle,
                                       model.dropout: config.dropout
                                   })
                # DEBUG
                # temp = sess.run(model.debug_ops, feed_dict={handle: train_handle, model.dropout: config.dropout})
                # for t in temp:
                #     # print(t)
                #     print(t.shape)
                # sys.exit(0)

                # save batch loss
                if global_step % config.period == 0:
                    loss_sum = tf.Summary(value=[
                        tf.Summary.Value(tag="model/loss", simple_value=loss),
                    ])
                    writer.add_summary(loss_sum, global_step)

                if global_step % config.checkpoint == 0:
                    # evaluate on train
                    print('Evaluating on Train...')
                    num_bt = config.val_num_batches
                    # _, summ = evaluate_batch(model, num_bt, train_eval_file, sess, "train", handle, train_handle)
                    _, summ = evaluate_batch_cand(model, num_bt,
                                                  train_eval_file, sess,
                                                  "train", handle,
                                                  train_handle)
                    # write to summary
                    for s in summ:
                        writer.add_summary(s, global_step)

                    # evaluate on dev
                    print('Evaluating on Dev...')
                    num_bt = dev_total // config.batch_size + 1
                    # metrics, summ = evaluate_batch(model, num_bt, dev_eval_file, sess, "dev", handle, dev_handle)
                    metrics, summ = evaluate_batch_cand(
                        model, num_bt, dev_eval_file, sess, "dev", handle,
                        dev_handle)
                    # write to summary
                    for s in summ:
                        writer.add_summary(s, global_step)
                    writer.flush()
                    # save checkpoint model
                    #filename = os.path.join(config.save_dir, "model_{}.ckpt".format(global_step))
                    #saver.save(sess, filename)

                    # early stop
                    # dev_f1 = metrics["f1"]
                    # dev_em = metrics["exact_match"]
                    dev_acc = metrics["acc"]
                    if dev_acc < best_acc:
                        patience += 1
                        # Note: unlike Example #2, this variant only warns
                        # once patience is exhausted; training continues.
                        if patience > config.early_stop:
                            print('>>>>>> WARNING !!! <<<<<<< '
                                  'Early_stop reached!!!')
                    # save best model
                    else:
                        patience = 0

                        best_acc = dev_acc
                        #filename = os.path.join(config.save_dir, "model_{}.best".format(global_step))
                        filename = os.path.join(
                            config.save_dir,
                            "model_{}.ckpt".format(global_step))
                        saver.save(sess, filename)
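
The docstring in Example #4 describes the "feedable" iterator pattern that all four snippets build on. Below is a minimal, self-contained sketch of just that pattern, using the same TF 1.x APIs; the toy tf.data.Dataset.range datasets are stand-ins for the projects' real record files, not part of any example above.

import tensorflow as tf

# Two toy datasets standing in for the train/dev record files.
train_ds = tf.data.Dataset.range(10).batch(2)
dev_ds = tf.data.Dataset.range(100, 110).batch(2)

# A single string-typed placeholder selects which concrete iterator
# feeds the graph at each session.run call.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_ds.output_types, train_ds.output_shapes)
next_batch = iterator.get_next()

# One-shot iterators are initialized automatically and cannot be
# re-initialized once exhausted.
train_iterator = train_ds.make_one_shot_iterator()
dev_iterator = dev_ds.make_one_shot_iterator()

with tf.Session() as sess:
    train_handle = sess.run(train_iterator.string_handle())
    dev_handle = sess.run(dev_iterator.string_handle())
    # Switching datasets is just a matter of feeding a different handle.
    print(sess.run(next_batch, feed_dict={handle: train_handle}))  # [0 1]
    print(sess.run(next_batch, feed_dict={handle: dev_handle}))    # [100 101]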