コード例 #1
0
    def get_predictions(self, batch):
        """Predict start/end positions for ``batch`` by majority vote.

        Each (checkpoint, flags) pair in ``zip(self.ckpts, self.flags)`` is
        rebuilt as a fresh ``QAModel`` (``is_training=False``) in the default
        TensorFlow graph. The model is restored from its checkpoint and asked
        for start/end positions via ``QAModel.get_start_end_pos``. The
        predictions of all models are then combined element-wise with
        ``scipy.stats.mode`` along the model axis. SciPy resolves ties by
        returning the smallest value.

        Args:
            batch: batch object forwarded unchanged to
                ``QAModel.get_start_end_pos`` (type defined elsewhere).

        Returns:
            Tuple ``(starts, ends)`` of integer numpy arrays, each shaped
            like a single model's prediction (presumably one position per
            example in the batch -- confirm against ``get_start_end_pos``).
        """

        def _majority(predictions):
            # Stack to (n_models, ...) and vote along the model axis.
            stacked = np.array(predictions)
            voted, _ = stats.mode(stacked, axis=0)
            # SciPy < 1.11 keeps the reduced axis (shape (1, ...)), newer
            # SciPy drops it; reshape so both give one vote per position.
            return np.asarray(voted).reshape(stacked.shape[1:]).astype(int)

        starts, ends = [], []
        for ckpt, model_flags in zip(self.ckpts, self.flags):
            qa_model = None
            try:
                qa_model = QAModel(model_flags,
                                   self.id2word,
                                   self.word2id,
                                   self.emb_matrix,
                                   self.id2idf,
                                   is_training=False)
                with tf.Session(config=self.tf_config) as session:
                    qa_model.initialize_from_checkpoint(session, ckpt, True)
                    pred_start_pos, pred_end_pos = qa_model.get_start_end_pos(
                        session, batch)
                    starts.append(pred_start_pos)
                    ends.append(pred_end_pos)
            finally:
                # Every model builds its ops in the default graph; drop the
                # model and wipe the graph so the next checkpoint (or a later
                # call) starts clean even if this one failed.
                del qa_model
                tf.reset_default_graph()

        return (_majority(starts), _majority(ends))
コード例 #2
0
def main(unused_argv):
    """Load embeddings, build the model(s) and run the action for FLAGS.mode.

    Modes:
      * ``train``         -- train a single QAModel, checkpointing into
                             ``FLAGS.train_dir``.
      * ``show_examples`` -- restore the best checkpoint and print 10 dev
                             examples with F1/EM.
      * ``eval``          -- report F1/EM on train and dev, using either a
                             single model (``FLAGS.ckpt_load_dir``) or the
                             ensemble (``FLAGS.ensumble``).
      * ``official_eval`` -- answer the questions in ``FLAGS.json_in_path``
                             and write a uuid -> answer JSON file to
                             ``FLAGS.json_out_path``.

    Args:
        unused_argv: leftover command-line arguments; must have exactly one
            element (presumably the program name), otherwise a flag was
            entered incorrectly.

    Raises:
        Exception: on malformed flags, when not running under Python 2, when
            a flag required by the selected mode is missing, when
            ``--ensumble`` is combined with a mode it does not support, or
            when ``FLAGS.mode`` is unknown.
    """
    # Any leftover positional argument means a flag was entered incorrectly
    if len(unused_argv) != 1:
        raise Exception("There is a problem with how you entered flags: %s" %
                        unused_argv)

    # Only Python 2 is supported
    if sys.version_info[0] != 2:
        raise Exception(
            "ERROR: You must use Python 2 but you are running Python %i" %
            sys.version_info[0])

    # Default the GloVe file to the one matching the embedding size
    FLAGS.glove_path = FLAGS.glove_path or os.path.join(
        DEFAULT_DATA_DIR, "glove.6B.{}d.txt".format(FLAGS.embedding_size))

    # Load embedding matrix, vocab mappings and per-word IDF values
    timer.start("glove_getter")
    emb_matrix, word2id, id2word = get_glove(FLAGS.glove_path,
                                             FLAGS.embedding_size)
    id2idf = get_idf(os.path.abspath(FLAGS.idf_path), word2id)
    logger.warning("Get glove embedding of size {} takes {:.2f} s".format(
        FLAGS.embedding_size, timer.stop("glove_getter")))

    ensumble = FLAGS.ensumble
    logger.info("Ensemble: {}".format(ensumble))
    if (not ensumble and not FLAGS.attn_layer and not FLAGS.train_dir
            and FLAGS.mode != "official_eval"):
        raise Exception(
            "You need to specify either --attn_layer or --train_dir")
    # train/show_examples use the single qa_model, which is never built when
    # an ensemble is requested; fail clearly instead of with a NameError.
    if ensumble and FLAGS.mode in ("train", "show_examples"):
        raise Exception(
            "--ensumble is only supported in eval and official_eval modes")

    # Derive the experiment name / train_dir from the hyper-parameters
    if not FLAGS.experiment_name:
        FLAGS.experiment_name = "A_{}_E_{}_D_{}".format(
            FLAGS.attn_layer, FLAGS.embedding_size, FLAGS.dropout)

    checkptr_name = FLAGS.experiment_name + "/glove{}".format(
        FLAGS.embedding_size)
    FLAGS.train_dir = FLAGS.train_dir or os.path.join(EXPERIMENTS_DIR,
                                                      checkptr_name)

    # Directory holding the best checkpoint so far
    bestmodel_dir = os.path.join(FLAGS.train_dir, "best_checkpoint")

    # Tokenized contexts, questions and answer spans for train/dev
    train_context_path = os.path.join(FLAGS.data_dir, "train.context")
    train_qn_path = os.path.join(FLAGS.data_dir, "train.question")
    train_ans_path = os.path.join(FLAGS.data_dir, "train.span")
    dev_context_path = os.path.join(FLAGS.data_dir, "dev.context")
    dev_qn_path = os.path.join(FLAGS.data_dir, "dev.question")
    dev_ans_path = os.path.join(FLAGS.data_dir, "dev.span")

    # Grow GPU memory on demand instead of grabbing it all up front
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    is_training = (FLAGS.mode == "train")
    if not ensumble:
        qa_model = QAModel(FLAGS, id2word, word2id, emb_matrix, id2idf,
                           is_training)
    else:
        ensumbler = Ensumbler(ensumble, config, id2word, word2id, emb_matrix,
                              id2idf)

    if FLAGS.mode == "train":
        # Set up train dir and a log file inside it
        if not os.path.exists(FLAGS.train_dir):
            os.makedirs(FLAGS.train_dir)
        file_handler = logging.FileHandler(
            os.path.join(FLAGS.train_dir, "log.txt"))
        logging.getLogger().addHandler(file_handler)

        # Save a record of the flags as JSON. TF <= 1.4 keeps plain values in
        # FLAGS.__flags; absl-based flags (TF >= 1.5) keep Flag objects there,
        # which json cannot serialize, but expose flag_values_dict().
        if hasattr(FLAGS, "flag_values_dict"):
            flag_values = FLAGS.flag_values_dict()
        else:
            flag_values = FLAGS.__flags
        with open(os.path.join(FLAGS.train_dir, "flags.json"), 'w') as fout:
            json.dump(flag_values, fout)

        if not os.path.exists(bestmodel_dir):
            os.makedirs(bestmodel_dir)

        with tf.Session(config=config) as sess:
            # Resume from the most recent checkpoint if there is one
            qa_model.initialize_from_checkpoint(sess,
                                                FLAGS.train_dir,
                                                expect_exists=False)
            qa_model.train(sess, train_context_path, train_qn_path,
                           train_ans_path, dev_qn_path, dev_context_path,
                           dev_ans_path)

    elif FLAGS.mode == "show_examples":
        with tf.Session(config=config) as sess:
            qa_model.initialize_from_checkpoint(sess,
                                                bestmodel_dir,
                                                expect_exists=True)

            # Print 10 dev examples with their F1/EM scores
            f1, em = qa_model.check_f1_em(sess,
                                          dev_context_path,
                                          dev_qn_path,
                                          dev_ans_path,
                                          "dev",
                                          num_samples=10,
                                          print_to_screen=True)
            logger.info("Dev: F1 = {0:.3}, EM = {1:.3}".format(f1, em))

    elif FLAGS.mode == "eval":
        if ensumble:
            train_f1, train_em = ensumbler.check_f1_em(train_context_path,
                                                       train_qn_path,
                                                       train_ans_path, "train",
                                                       FLAGS.n_eval)
            dev_f1, dev_em = ensumbler.check_f1_em(dev_context_path,
                                                   dev_qn_path, dev_ans_path,
                                                   "dev", FLAGS.n_eval)
        else:
            with tf.Session(config=config) as sess:
                qa_model.initialize_from_checkpoint(sess,
                                                    FLAGS.ckpt_load_dir,
                                                    expect_exists=True)
                logger.info("Model initialzed from checkpoint")

                train_f1, train_em = qa_model.check_f1_em(
                    sess,
                    train_context_path,
                    train_qn_path,
                    train_ans_path,
                    "train",
                    num_samples=10,
                    print_to_screen=False)
                dev_f1, dev_em = qa_model.check_f1_em(sess,
                                                      dev_context_path,
                                                      dev_qn_path,
                                                      dev_ans_path,
                                                      "dev",
                                                      num_samples=10,
                                                      print_to_screen=False)
        logger.error("Train: F1 = {:.3}, EM = {:.3}".format(
            train_f1, train_em))
        logger.error("Dev:   F1 = {:.3}, EM = {:.3}".format(dev_f1, dev_em))

    elif FLAGS.mode == "official_eval":
        # Both the single model and the ensemble read questions from here
        if FLAGS.json_in_path == "":
            raise Exception(
                "For official_eval mode, you need to specify --json_in_path"
            )
        if not ensumble and FLAGS.ckpt_load_dir == "":
            raise Exception(
                "For official_eval mode, you need to specify --ckpt_load_dir"
            )

        qn_uuid_data, context_token_data, qn_token_data = get_json_data(
            FLAGS.json_in_path)

        if ensumble:
            answers_dict = ensumbler.generate_answers(qn_uuid_data,
                                                      context_token_data,
                                                      qn_token_data)
        else:
            with tf.Session(config=config) as sess:
                qa_model.initialize_from_checkpoint(sess,
                                                    FLAGS.ckpt_load_dir,
                                                    expect_exists=True)
                # Mapping from question uuid to predicted answer text
                answers_dict = generate_answers(sess, qa_model, word2id,
                                                id2idf, qn_uuid_data,
                                                context_token_data,
                                                qn_token_data)

        # Write the uuid -> answer mapping to a JSON file.
        # io.open in text mode needs unicode under Python 2, hence unicode().
        print("Writing predictions to %s..." % FLAGS.json_out_path)
        with io.open(FLAGS.json_out_path, 'w', encoding='utf-8') as f:
            f.write(unicode(json.dumps(answers_dict, ensure_ascii=False)))
        print("Wrote predictions to %s" % FLAGS.json_out_path)

    else:
        raise Exception("Unexpected value of FLAGS.mode: %s" % FLAGS.mode)