コード例 #1
0
ファイル: pyt_bertqe.py プロジェクト: cmacdonald/BERT-QE
    def __init__(self,
                 bert_json_path,
                 checkpoint_path,
                 verbose=False,
                 body_attr="text",
                 vocab_file="./bert/vocab.txt",
                 do_lower_case=True):
        """Set up the tokenizer and a TPUEstimator for BERT-QE scoring.

        Args:
            bert_json_path: Path to the BERT config JSON read by
                ``modeling.BertConfig.from_json_file``.
            checkpoint_path: Path to the model checkpoint. Kept on
                ``self.checkpoint_path`` for use by other methods; the
                estimator itself is built with ``init_checkpoint=None``.
            verbose: Stored on ``self.verbose``.
            body_attr: Name of the document text attribute, stored on
                ``self.body_attr``.
            vocab_file: WordPiece vocabulary passed to ``FullTokenizer``
                (defaults to the previously hard-coded ``./bert/vocab.txt``).
            do_lower_case: Lower-casing flag passed to ``FullTokenizer``.
        """
        # Token-length limits; presumably the max query / document / combined
        # sequence lengths used when encoding inputs -- TODO confirm.
        self.max_qlen = 128
        self.max_dlen = 256
        self.max_seq_len = 384
        self.body_attr = body_attr
        self.verbose = verbose
        # This argument used to be accepted and then dropped, which made it
        # unreachable from any other method.
        self.checkpoint_path = checkpoint_path

        # Imported here, so the `bert` package is only needed once an
        # instance is actually created.
        from bert import modeling, tokenization
        bert_config = modeling.BertConfig.from_json_file(bert_json_path)
        self.tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_file, do_lower_case=do_lower_case)

        # No TPU: with use_tpu=False the TPUEstimator runs on CPU/GPU.
        tpu_cluster_resolver = None
        iterations_per_loop = 500
        init_checkpoint = None
        use_tpu = False
        is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2

        run_config = tf.contrib.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            keep_checkpoint_max=1,
            tpu_config=tf.contrib.tpu.TPUConfig(
                iterations_per_loop=iterations_per_loop,
                per_host_input_for_training=is_per_host))

        model_fn = model_fn_builder(bert_config=bert_config,
                                    num_labels=2,
                                    init_checkpoint=init_checkpoint,
                                    use_tpu=use_tpu,
                                    use_one_hot_embeddings=use_tpu)

        # NOTE(review): the "qc_scores" param presumably makes model_fn emit a
        # "qc_scores" prediction alongside "probs" -- confirm in model_fn_builder.
        self.estimator = tf.contrib.tpu.TPUEstimator(
            use_tpu=use_tpu,
            model_fn=model_fn,
            config=run_config,
            eval_batch_size=32,
            predict_batch_size=32,
            params={"qc_scores": "qc_scores"})

        print("BERTQE Ready")
コード例 #2
0
def main(_):
    """Fine-tune the MaxP BERT ranker and/or score the valid/test splits.

    Reads its configuration from module-level ``FLAGS`` and globals defined
    elsewhere in the script (``do_train``, ``do_eval``, ``use_tpu``,
    ``train_batch_size``, ``eval_batch_size``, ``num_train_steps``, ...).
    For each evaluated split it writes ``{dataset}_{split}_result.tsv``
    (including passage ids) and ``{dataset}_{split}_result.trec`` into
    ``FLAGS.output_path``.

    Raises:
        ValueError: if neither ``do_train`` nor ``do_eval`` is set, or if
            ``FLAGS.max_seq_length`` exceeds the model's
            ``max_position_embeddings``.
    """
    if not tf.gfile.Exists(FLAGS.output_path):
        tf.gfile.MakeDirs(FLAGS.output_path)

    if not do_train and not do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    bert_config = modeling.BertConfig.from_json_file(config_dict[FLAGS.model_size])

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tpu_cluster_resolver = None
    if use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2

    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        save_checkpoints_steps=save_checkpoints_steps,
        model_dir=FLAGS.output_path,
        keep_checkpoint_max=5,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=iterations_per_loop,
            num_shards=num_tpu_cores,
            per_host_input_for_training=is_per_host))

    model_fn = model_fn_builder(
        bert_config=bert_config,
        num_labels=2,
        init_checkpoint=init_checkpoint,
        use_tpu=use_tpu,
        use_one_hot_embeddings=use_tpu,
        learning_rate=learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=train_batch_size,
        eval_batch_size=eval_batch_size,
        predict_batch_size=eval_batch_size,
        params={"train_examples": train_examples,
                "num_train_epochs": num_train_epochs})

    if do_train:
        # Ctrl-C stops training early but still lets evaluation run.
        try:
            tf.logging.info("***** Running training *****")
            tf.logging.info("  Batch size = %d", train_batch_size)
            tf.logging.info("  Num steps = %d", num_train_steps)
            train_input_fn = input_fn_builder(
                dataset_path=os.path.join(FLAGS.data_path, "{}_query_maxp_train.tf".format(FLAGS.dataset)),
                seq_length=FLAGS.max_seq_length,
                is_training=True,
                drop_remainder=True)

            current_step = 0
            steps_per_epoch = train_examples // train_batch_size
            # With fewer examples than one batch, steps_per_epoch is 0. In that
            # case report NaN epochs rather than raise ZeroDivisionError.
            num_epochs = (num_train_steps / steps_per_epoch
                          if steps_per_epoch else float("nan"))
            tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                            ' step %d.',
                            num_train_steps,
                            num_epochs,
                            current_step)

            start_timestamp = time.time()

            estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

            elapsed_time = int(time.time() - start_timestamp)
            tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                            num_train_steps, elapsed_time)
        except KeyboardInterrupt:
            tf.logging.warning("Training interrupted; continuing.")

        tf.logging.info("Done Training!")

    if do_eval:
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Batch size = %d", eval_batch_size)

        for split in ["valid", "test"]:
            # Rows are tab-separated and unpacked below as
            # (qid, _, doc_id, passage_id, rank, _, _). The grouping logic
            # assumes the rows of one query are contiguous and in the same
            # order as the examples in the TFRecord file.
            query_docids_map = []
            with tf.gfile.Open(os.path.join(FLAGS.passage_path,
                                            "{0}_query_passage_{1}_top1.tsv".format(FLAGS.dataset, split))) as ref_file:
                for line in ref_file:
                    query_docids_map.append(line.strip().split("\t"))
            num_rows = len(query_docids_map)

            eval_input_fn = input_fn_builder(
                dataset_path=os.path.join(FLAGS.data_path, "{0}_query_maxp_{1}.tf".format(FLAGS.dataset, split)),
                seq_length=FLAGS.max_seq_length,
                is_training=False,
                drop_remainder=False)

            total_count = 0
            tsv_file_path = os.path.join(FLAGS.output_path, "{0}_{1}_result.tsv".format(FLAGS.dataset, split))
            trec_file_path = os.path.join(FLAGS.output_path, "{0}_{1}_result.trec".format(FLAGS.dataset, split))

            result = estimator.predict(input_fn=eval_input_fn,
                                       yield_single_examples=True)

            start_time = time.time()
            results = []
            with tf.gfile.Open(tsv_file_path, 'w') as tsv_file, tf.gfile.Open(trec_file_path, 'w') as trec_file:
                for item in result:
                    results.append(item["probs"])
                    total_count += 1

                    # Flush the buffered predictions once the next row belongs
                    # to a different query, or when there is no next row.
                    if total_count == num_rows or query_docids_map[total_count][0] != \
                            query_docids_map[total_count - 1][0]:
                        candidate_doc_num = len(results)
                        # Column 1 of the 2-label `probs` output is used as the
                        # score (presumably P(relevant)).
                        scores = np.stack(results)[:, 1]

                        start_idx = total_count - candidate_doc_num
                        query_ids, _, doc_ids, passage_ids, _, _, _ = zip(
                            *query_docids_map[start_idx:total_count])
                        assert len(set(query_ids)) == 1, "Query ids must be all the same."
                        query_id = query_ids[0]

                        # Only the current query is needed, so use a local dict
                        # instead of accumulating every query in memory. A
                        # repeated doc_id keeps its last (pid, score).
                        doc_scores = {}
                        for doc_id, pid, score in zip(doc_ids, passage_ids, scores):
                            doc_scores[doc_id] = (pid, score)

                        ranking_list = sorted(doc_scores.items(), key=lambda x: x[1][1], reverse=True)
                        for rank, (doc_id, (pid, score)) in enumerate(ranking_list):
                            tsv_file.write("\t".join(
                                [query_id, "Q0", doc_id, pid, str(rank + 1), str(score), "maxp_finetune"]) + "\n")
                            trec_file.write(
                                "\t".join([query_id, "Q0", doc_id, str(rank + 1), str(score), "maxp_finetune"]) + "\n")

                        results = []

                    if total_count % 1000 == 0:
                        tf.logging.info("Read {} examples in {} secs".format(
                            total_count, int(time.time() - start_time)))

                if total_count < num_rows:
                    # Previously the trailing, never-flushed rows were dropped silently.
                    tf.logging.warning(
                        "%s: got %d predictions for %d rows; the remainder was not written.",
                        split, total_count, num_rows)

                tf.logging.info("Done Evaluating!")
コード例 #3
0
ファイル: select_pieces.py プロジェクト: cmacdonald/BERT-QE
def main(_):
    """Score query-passage or query-chunk pairs with BERT and keep the best pieces.

    Configuration comes from module-level ``FLAGS`` and globals defined
    elsewhere (``use_tpu``, ``iterations_per_loop``, ``num_tpu_cores``,
    ``init_checkpoint``, ``load_qid_from_cv``, ...).

    For each split, every prediction is written as a tab-separated line
    (qid, doc_id, pid, score). Here doc_id is the part of pid before the
    first ``_``. After that:

    * ``FLAGS.task == "passage"``: for folds 1-5 and train/valid/test, the
      scores are filtered to that fold's queries. A per-query document
      ranking by each document's best-scoring passage is also written.
    * otherwise: the top ``FLAGS.kc`` chunks per query are written.

    Raises:
        ValueError: if ``FLAGS.max_seq_length`` exceeds the model's
            ``max_position_embeddings``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    bert_config = modeling.BertConfig.from_json_file(
        config_dict[FLAGS.model_size])

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tpu_cluster_resolver = None
    if use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=FLAGS.output_path,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=iterations_per_loop,
            num_shards=num_tpu_cores,
            per_host_input_for_training=is_per_host))

    model_fn = model_fn_builder(bert_config=bert_config,
                                num_labels=2,
                                init_checkpoint=init_checkpoint,
                                use_tpu=use_tpu,
                                use_one_hot_embeddings=use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.batch_size,
        predict_batch_size=FLAGS.batch_size)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Batch size = %d", FLAGS.batch_size)

    if FLAGS.task == "passage":
        path = FLAGS.output_path
    else:
        path = os.path.join(FLAGS.output_path, "fold-" + str(FLAGS.fold))

    if not tf.gfile.Exists(path):
        tf.gfile.MakeDirs(path)

    split_list = ["valid", "test"] if FLAGS.task == "chunk" else [""]
    for split in split_list:
        # Compare by value. The former `split is not ""` tested object
        # identity, which is implementation-dependent (SyntaxWarning on 3.8+).
        suffix = "_" + split if split != "" else ""

        predictions_path = os.path.join(
            path, "{}_query_{}_score{}.tsv".format(FLAGS.dataset, FLAGS.task,
                                                   suffix))
        ids_file_path = os.path.join(
            path, "query_{}_ids{}.txt".format(FLAGS.task, suffix))
        dataset_path = os.path.join(path,
                                    "query_{}{}.tf".format(FLAGS.task, suffix))

        # Each line is "qid<TAB>pid". The i-th line is assumed to match the
        # i-th example in `dataset_path`.
        query_chunks_ids = []
        with tf.gfile.Open(ids_file_path) as ids_file:
            for line in ids_file:
                qid, pid = line.strip().split("\t")
                query_chunks_ids.append([qid, pid])

        predict_input_fn = input_fn_builder(dataset_path=dataset_path,
                                            seq_length=FLAGS.max_seq_length,
                                            is_training=False,
                                            drop_remainder=False)

        tf.logging.set_verbosity(tf.logging.WARN)

        result = estimator.predict(input_fn=predict_input_fn,
                                   yield_single_examples=True)

        start_time = time.time()
        cnt = 0
        with tf.gfile.Open(predictions_path, "w") as predictions_file:
            for item in result:

                qid = query_chunks_ids[cnt][0]
                pid = query_chunks_ids[cnt][1]
                doc_id = pid.split("_")[0]

                probs = item["probs"]
                scores = probs[1]

                predictions_file.write("\t".join((qid, doc_id, pid,
                                                  str(float(scores)))) + "\n")
                cnt += 1
                if cnt % 10000 == 0:
                    print("process {} pairs  in {} secs.".format(
                        cnt, int(time.time() - start_time)))

            print("Done Evaluating!\nTotal examples:{}".format(cnt))

        if FLAGS.task == "passage":
            for fold in range(1, 6):
                # Derive each fold directory from the base path. Reassigning
                # `path` here used to nest fold-2 inside fold-1, and so on.
                fold_path = os.path.join(path, "fold-" + str(fold))
                if not tf.gfile.Exists(fold_path):
                    tf.gfile.MakeDirs(fold_path)
                for cv_split in ["train", "valid", "test"]:
                    # Use a set: this membership test runs once per prediction line.
                    qid_set = set(load_qid_from_cv(FLAGS.dataset, fold, cv_split))

                    with tf.gfile.Open(predictions_path, 'r') as ref_file, \
                            tf.gfile.Open(
                                os.path.join(fold_path, "{}_query_passage_score_{}.tsv".format(FLAGS.dataset, cv_split)),
                                'w') as out_file, \
                            tf.gfile.Open(
                                os.path.join(fold_path, "{}_query_passage_{}_top1.tsv".format(FLAGS.dataset, cv_split)),
                                'w') as top_file:
                        # qid -> doc_id -> best-scoring passage of that doc.
                        top_res = collections.OrderedDict()
                        for line in ref_file:
                            qid, doc_id, pid, score = line.strip().split()
                            score = float(score)
                            if qid in qid_set:
                                out_file.write(line)

                                if qid not in top_res:
                                    top_res[qid] = dict()
                                if doc_id not in top_res[qid]:
                                    top_res[qid][doc_id] = {
                                        "pid": pid,
                                        "score": score
                                    }
                                elif score > top_res[qid][doc_id]["score"]:
                                    top_res[qid][doc_id]["pid"] = pid
                                    top_res[qid][doc_id]["score"] = score

                        for qid, docs in top_res.items():
                            sorted_docs = sorted(docs.items(),
                                                 key=lambda x: x[1]["score"],
                                                 reverse=True)
                            for rank, (doc_id,
                                       pid_score) in enumerate(sorted_docs):
                                top_file.write("\t".join([
                                    qid, "Q0", doc_id, pid_score["pid"],
                                    str(rank + 1),
                                    str(pid_score["score"]),
                                    "BERT_top1_passage"
                                ]) + "\n")
        else:
            with tf.gfile.Open(predictions_path, 'r') as ref_file, \
                    tf.gfile.Open(
                        os.path.join(path, "{}_query_chunk_{}_kc-{}.tsv".format(FLAGS.dataset, split, FLAGS.kc)),
                        'w') as top_file:
                # qid -> pid -> score; a repeated pid keeps its last score.
                top_res = collections.OrderedDict()
                for line in ref_file:
                    qid, doc_id, pid, score = line.strip().split()
                    score = float(score)

                    if qid not in top_res:
                        top_res[qid] = dict()
                    top_res[qid][pid] = score

                for qid, pid_scores in top_res.items():
                    sorted_pids = sorted(pid_scores.items(),
                                         key=lambda x: x[1],
                                         reverse=True)
                    top_pids = sorted_pids[:FLAGS.kc]
                    for rank, (pid, score) in enumerate(top_pids):
                        top_file.write("\t".join([
                            qid, pid,
                            str(rank + 1),
                            str(score), "BERT_top{0}_chunk".format(FLAGS.kc)
                        ]) + "\n")
コード例 #4
0
def main(_):
    """Re-rank the top ``FLAGS.rerank_num`` MaxP documents using chunk evidence.

    For the valid and test splits, the chunk-passage pairs stored under
    ``FLAGS.output_path/rerank-{rerank_num}_kc-{kc}/data`` are scored with the
    latest checkpoint in ``FLAGS.third_model_path``. Each candidate document's
    score is the sum of its chunk-passage probabilities, weighted by a softmax
    over the matching ``qc_scores``.

    Documents of the first-stage run (``FLAGS.first_model_path``) beyond the
    top ``rerank_num`` are appended below the re-ranked ones, in the order
    ``load_run`` returns them (presumably rank order). Output goes to
    ``.../result/{dataset}_{split}_result.trec``.

    Relies on module-level ``FLAGS`` and globals defined elsewhere
    (``use_tpu``, ``iterations_per_loop``, ``num_tpu_cores``,
    ``init_checkpoint``, ``load_run``, ``softmax``, ...).

    Raises:
        ValueError: if ``FLAGS.max_seq_length`` exceeds the model's
            ``max_position_embeddings``.
    """
    bert_config = modeling.BertConfig.from_json_file(config_dict[FLAGS.model_size])

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tpu_cluster_resolver = None
    if use_tpu:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2

    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        keep_checkpoint_max=1,
        model_dir=FLAGS.output_path,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=iterations_per_loop,
            num_shards=num_tpu_cores,
            per_host_input_for_training=is_per_host))

    model_fn = model_fn_builder(
        bert_config=bert_config,
        num_labels=2,
        init_checkpoint=init_checkpoint,
        use_tpu=use_tpu,
        use_one_hot_embeddings=use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.batch_size,
        eval_batch_size=FLAGS.batch_size,
        predict_batch_size=FLAGS.batch_size,
        params={"qc_scores": "qc_scores"})

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Batch size = %d", FLAGS.batch_size)

    # These paths do not depend on the split, so compute them once.
    run_dir = os.path.join(FLAGS.output_path, "rerank-{0}_kc-{1}".format(FLAGS.rerank_num, FLAGS.kc))
    data_path = os.path.join(run_dir, "data")
    result_path = os.path.join(run_dir, "result")
    if not tf.gfile.Exists(result_path):
        tf.gfile.MakeDirs(result_path)

    # Resolve the checkpoint once, so valid and test are scored with the same
    # weights even if a new checkpoint appears while this runs.
    ckpt = tf.train.latest_checkpoint(checkpoint_dir=FLAGS.third_model_path)
    print("use latest ckpt: {0}".format(ckpt))

    for split in ["valid", "test"]:

        maxp_run = load_run(os.path.join(FLAGS.first_model_path, "{}_{}_result.trec".format(FLAGS.dataset, split)))

        # Rows are tab-separated and unpacked below as
        # (qid, chunk_id, passage_id, label, qc_score). The rows of one query
        # are assumed to be contiguous and in TFRecord example order.
        query_docids_map = []
        with tf.gfile.Open(os.path.join(data_path, "chunk_passage_ids_{0}.txt".format(split))) as ref_file:
            for line in ref_file:
                query_docids_map.append(line.strip().split("\t"))

        predict_input_fn = input_fn_builder(
            dataset_path=os.path.join(data_path, "chunk_passage_{0}.tf".format(split)),
            is_training=False,
            seq_length=FLAGS.max_seq_length,
            drop_remainder=False)

        result = estimator.predict(input_fn=predict_input_fn,
                                   yield_single_examples=True,
                                   checkpoint_path=ckpt)

        total_count = 0
        start_time = time.time()
        results = []
        # `with` ensures the file is closed even if prediction fails midway.
        with tf.gfile.Open(os.path.join(result_path, "{0}_{1}_result.trec".format(FLAGS.dataset, split)),
                           'w') as result_file:
            for item in result:

                results.append((item["qc_scores"], item["probs"]))
                total_count += 1

                # Flush once the next row belongs to a different query, or
                # when there is no next row.
                if total_count == len(query_docids_map) or query_docids_map[total_count][0] != \
                        query_docids_map[total_count - 1][0]:

                    chunk_num = len(results) // FLAGS.rerank_num
                    assert chunk_num <= FLAGS.kc

                    qc_list, probs_list = zip(*results)
                    qc_scores = np.reshape(np.stack(qc_list), [FLAGS.rerank_num, chunk_num])
                    cp_scores = np.reshape(np.stack(probs_list)[:, 1], [FLAGS.rerank_num, chunk_num])

                    # Softmax-normalise qc_scores along each row, then use the
                    # weights to combine that row's chunk-passage scores.
                    qc_weights = softmax(qc_scores, axis=-1)
                    scores = np.sum(qc_weights * cp_scores, axis=-1)

                    start_idx = total_count - FLAGS.rerank_num * chunk_num
                    query_ids, _, passage_ids, _, _ = zip(*query_docids_map[start_idx:total_count])
                    assert len(set(query_ids)) == 1, "Query ids must be all the same."
                    query_id = query_ids[0]

                    # Candidate docs are taken in first-appearance order (doc
                    # id = pid prefix before "_"). Row i of `scores` is assigned
                    # to candidate i.
                    candidate_docs = list(collections.OrderedDict.fromkeys(
                        pid.split("_")[0] for pid in passage_ids))

                    doc_scores = {}
                    for i, doc in enumerate(candidate_docs):
                        doc_scores[doc] = scores[i]

                    # Append the remainder of the first-stage run strictly
                    # below the lowest re-ranked score, 0.01 apart, in run order.
                    last_score = min(doc_scores.values())
                    for doc in maxp_run[query_id][FLAGS.rerank_num:]:
                        last_score = last_score - 0.01
                        doc_scores[doc] = last_score

                    ranking_list = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)

                    for rank, (doc_id, score) in enumerate(ranking_list):
                        result_file.write(
                            "\t".join([query_id, "Q0", doc_id, str(rank + 1), str(score), "chunk_passage_PRF"]) + "\n")

                    results = []

                if total_count % 1000 == 0:
                    tf.logging.warn("Read {} examples in {} secs".format(
                        total_count, int(time.time() - start_time)))

        tf.logging.info("Done Evaluating!")