Example #1
    def __init__(self, args, logger=None) -> None:
        self.args = args
        self.logger = logger
        # Load NER model
        self.ner_model = NER.get_model()
        # load biencoder model
        if logger:
            logger.info("loading biencoder model")
        with open(args.biencoder_config) as json_file:
            biencoder_params = json.load(json_file)
            biencoder_params["path_to_model"] = args.biencoder_model
        biencoder = load_biencoder(biencoder_params)

        crossencoder = None
        crossencoder_params = None
        if not args.fast:
            # load crossencoder model
            if logger:
                logger.info("loading crossencoder model")
            with open(args.crossencoder_config) as json_file:
                crossencoder_params = json.load(json_file)
                crossencoder_params["path_to_model"] = args.crossencoder_model
            crossencoder = load_crossencoder(crossencoder_params)

        # load candidate entities
        if logger:
            logger.info("loading candidate entities")
        (
            candidate_encoding,
            title2id,
            id2title,
            id2text,
            wikipedia_id2local_id,
            faiss_indexer,
        ) = _load_candidates(
            args.entity_catalogue,
            args.entity_encoding,
            faiss_index=args.faiss_index,
            index_path=args.index_path,
            logger=logger,
        )

        self.biencoder = biencoder
        self.biencoder_params = biencoder_params
        self.crossencoder = crossencoder
        self.crossencoder_params = crossencoder_params
        self.candidate_encoding = candidate_encoding
        self.title2id = title2id
        self.id2title = id2title
        self.id2text = id2text
        self.wikipedia_id2local_id = wikipedia_id2local_id
        self.faiss_indexer = faiss_indexer
        self.id2url = {
            v: "https://en.wikipedia.org/wiki?curid=%s" % k
            for k, v in wikipedia_id2local_id.items()
        }
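
Everything this constructor needs arrives through args. Below is a minimal sketch of building such an argument object; the attribute names mirror the code above, while the class name BlinkAnnotator and all file paths are hypothetical placeholders.

from types import SimpleNamespace

# Sketch only: attribute names mirror the __init__ above; the paths and
# the class name BlinkAnnotator are hypothetical placeholders.
args = SimpleNamespace(
    biencoder_config="models/biencoder_wiki_large.json",
    biencoder_model="models/biencoder_wiki_large.bin",
    crossencoder_config="models/crossencoder_wiki_large.json",
    crossencoder_model="models/crossencoder_wiki_large.bin",
    entity_catalogue="models/entity.jsonl",
    entity_encoding="models/all_entities_large.t7",
    faiss_index=None,  # optional FAISS index type; None keeps dense scoring
    index_path=None,
    fast=False,        # True would skip loading the crossencoder
)
annotator = BlinkAnnotator(args)  # hypothetical name for the class above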
Example #2
File: main_dense.py  Project: yyht/BLINK
def run(
    args,
    logger,
    biencoder,
    biencoder_params,
    crossencoder,
    crossencoder_params,
    candidate_encoding,
    title2id,
    id2title,
    id2text,
    wikipedia_id2local_id,
    test_data=None,
):

    if not test_data and not args.test_mentions and not args.interactive:
        msg = (
            "ERROR: either you start BLINK with the "
            "interactive option (-i) or you pass in input test mentions (--test_mentions) "
            "and test entities (--test_entities)"
        )
        raise ValueError(msg)

    id2url = {
        v: "https://en.wikipedia.org/wiki?curid=%s" % k
        for k, v in wikipedia_id2local_id.items()
    }

    # loops forever in interactive mode; a single pass in test-dataset mode
    stopping_condition = False
    while not stopping_condition:

        samples = None

        if args.interactive:
            logger.info("interactive mode")

            # biencoder_params["eval_batch_size"] = 1

            # Load NER model
            ner_model = NER.get_model()

            # Interactive
            text = input("insert text:")

            # Identify mentions
            samples = _annotate(ner_model, [text])

            _print_colorful_text(text, samples)

        else:
            logger.info("test dataset mode")

            if test_data:
                samples = test_data
            else:
                # Load test mentions
                samples = _get_test_samples(
                    args.test_mentions,
                    args.test_entities,
                    title2id,
                    wikipedia_id2local_id,
                    logger,
                )

            stopping_condition = True

        # keep all candidates (skip label-based filtering) when gold labels
        # are unavailable: interactive mode or unlabeled samples
        keep_all = (
            args.interactive
            or samples[0]["label"] == "unknown"
            or samples[0]["label_id"] < 0
        )

        # prepare the data for biencoder
        logger.info("preparing data for biencoder")
        dataloader = _process_biencoder_dataloader(
            samples, biencoder.tokenizer, biencoder_params
        )

        # run biencoder
        logger.info("run biencoder")
        top_k = args.top_k
        labels, nns, scores = _run_biencoder(
            biencoder, dataloader, candidate_encoding, top_k
        )

        if args.interactive:

            print("\nfast (biencoder) predictions:")

            _print_colorful_text(text, samples)

            # print the top biencoder prediction for each mention
            for idx, (entity_list, sample) in enumerate(zip(nns, samples)):
                e_id = entity_list[0]
                e_title = id2title[e_id]
                e_text = id2text[e_id]
                e_url = id2url[e_id]
                _print_colorful_prediction(
                    idx, sample, e_id, e_title, e_text, e_url, args.show_url
                )
            print()

            if args.fast:
                # use only biencoder
                continue

        else:

            biencoder_accuracy = -1
            recall_at = -1
            if not keep_all:
                # get recall@i for i = 1..top_k (range end is inclusive so
                # that y[-1] really is recall@top_k, as printed below)
                top_k = args.top_k
                x = []
                y = []
                for i in range(1, top_k + 1):
                    temp_y = 0.0
                    for label, top in zip(labels, nns):
                        if label in top[:i]:
                            temp_y += 1
                    if len(labels) > 0:
                        temp_y /= len(labels)
                    x.append(i)
                    y.append(temp_y)
                # plt.plot(x, y)
                biencoder_accuracy = y[0]
                recall_at = y[-1]
                print("biencoder accuracy: %.4f" % biencoder_accuracy)
                print("biencoder recall@%d: %.4f" % (top_k, y[-1]))

            if args.fast:

                predictions = []
                for entity_list in nns:
                    sample_prediction = []
                    for e_id in entity_list:
                        e_title = id2title[e_id]
                        sample_prediction.append(e_title)
                    predictions.append(sample_prediction)

                # use only biencoder
                return (
                    biencoder_accuracy,
                    recall_at,
                    -1,
                    -1,
                    len(samples),
                    predictions,
                    scores,
                )

        # prepare crossencoder data
        context_input, candidate_input, label_input = prepare_crossencoder_data(
            crossencoder.tokenizer, samples, labels, nns, id2title, id2text, keep_all,
        )

        # modify() (BLINK's crossencoder data prep) concatenates each mention
        # context with its candidates into single sequences capped at
        # max_seq_length, so the crossencoder can score every pair jointly
        context_input = modify(
            context_input, candidate_input, crossencoder_params["max_seq_length"]
        )

        dataloader = _process_crossencoder_dataloader(
            context_input, label_input, crossencoder_params
        )

        # run crossencoder and get accuracy
        accuracy, index_array, unsorted_scores = _run_crossencoder(
            crossencoder,
            dataloader,
            logger,
            context_len=biencoder_params["max_context_length"],
        )

        if args.interactive:

            print("\naccurate (crossencoder) predictions:")

            _print_colorful_text(text, samples)

            # print the best crossencoder prediction for each mention
            for idx, (entity_list, index_list, sample) in enumerate(
                zip(nns, index_array, samples)
            ):
                e_id = entity_list[index_list[-1]]
                e_title = id2title[e_id]
                e_text = id2text[e_id]
                e_url = id2url[e_id]
                _print_colorful_prediction(
                    idx, sample, e_id, e_title, e_text, e_url, args.show_url
                )
            print()
        else:

            scores = []
            predictions = []
            for entity_list, index_list, scores_list in zip(
                nns, index_array, unsorted_scores
            ):

                # the crossencoder returns an ascending argsort; reverse it so
                # the best-scoring candidate comes first
                index_list = index_list.tolist()
                index_list.reverse()

                sample_prediction = []
                sample_scores = []
                for index in index_list:
                    e_id = entity_list[index]
                    e_title = id2title[e_id]
                    sample_prediction.append(e_title)
                    sample_scores.append(scores_list[index])
                predictions.append(sample_prediction)
                scores.append(sample_scores)

            crossencoder_normalized_accuracy = -1
            overall_unnormalized_accuracy = -1
            if not keep_all:
                crossencoder_normalized_accuracy = accuracy
                print(
                    "crossencoder normalized accuracy: %.4f"
                    % crossencoder_normalized_accuracy
                )

                overall_unnormalized_accuracy = (
                    crossencoder_normalized_accuracy * len(label_input) / len(samples)
                )
                print(
                    "overall unnormalized accuracy: %.4f" % overall_unnormalized_accuracy
                )
            return (
                biencoder_accuracy,
                recall_at,
                crossencoder_normalized_accuracy,
                overall_unnormalized_accuracy,
                len(samples),
                predictions,
                scores,
            )
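
The non-interactive path of run returns a seven-element tuple. A sketch of unpacking it follows; the left-hand names are illustrative only, and the order comes from the two return statements above.

(
    biencoder_accuracy,      # biencoder recall@1, or -1 when labels are unknown
    recall_at,               # biencoder recall@top_k, or -1 when labels are unknown
    crossencoder_accuracy,   # normalized crossencoder accuracy, -1 in fast mode
    overall_accuracy,        # unnormalized overall accuracy, -1 in fast mode
    num_samples,             # number of mentions processed
    predictions,             # per-mention lists of entity titles, best first
    scores,                  # candidate scores matching the predictions
) = run(args, logger, biencoder, biencoder_params, crossencoder,
        crossencoder_params, candidate_encoding, title2id, id2title,
        id2text, wikipedia_id2local_id)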
Example #3
    def load_model(self):

        # self.biencoder, self.biencoder_params, self.candidate_encoding, self.faiss_indexer \
        #     = load_models(logger=self.logger)

        self.ner_model = NER.get_model()
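
With the entity-linking models commented out, load_model keeps only mention detection. A brief sketch of what that enables, under the assumption that ner_model.predict returns the same {"sentences", "mentions"} dict seen in Example #5:

# Sketch; assumes ner_model.predict matches the output shape in Example #5,
# and that obj is an instance of the class this method belongs to.
obj.load_model()
out = obj.ner_model.predict(["BLINK links mentions to Wikipedia entities."])
print(out["mentions"])  # detected mention spans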
Example #4
File: test2.py  Project: helderarr/WS
        entity_catalogue,
        entity_encoding,
        faiss_index=None,
        index_path=None,
        logger=logger,
    )

    return biencoder, biencoder_params, candidate_encoding, faiss_indexer


top_k = 1

logger = utils.get_logger('output')
biencoder, biencoder_params, candidate_encoding, faiss_indexer = load_models(
    logger=logger)
ner_model = NER.get_model()
samples = _annotate(ner_model, [
    'What is throat cancer? Throat cancer is any cancer that forms in the throat. The throat, also called the pharynx, is a 5-inch-long tube that runs from your nose to your neck. The larynx (voice box) and pharynx are the two main places throat cancer forms. Throat cancer is a type of head and neck cancer, which includes cancer of the mouth, tonsils, nose, sinuses, salivary glands and neck lymph nodes.',
    "Juan Carlos is the king os Spain", "Cristiano Ronaldo has 5 Ballon D'Or"
])
print(samples)
dataloader = _process_biencoder_dataloader(samples, biencoder.tokenizer,
                                           biencoder_params)
labels, nns, scores = _run_biencoder(biencoder, dataloader, candidate_encoding,
                                     top_k, faiss_indexer)

print(labels, nns, scores)

logger.info(nns)
logger.info(scores)
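
With top_k = 1, _run_biencoder returns a single candidate per mention. Below is a sketch of pairing each detected mention with its top candidate id and score; the "mention" key is an assumption based on how Example #2 consumes _annotate's output.

# Sketch: one row of nns/scores per detected mention, top_k entries per row.
for sample, entity_list, score_list in zip(samples, nns, scores):
    e_id = entity_list[0]  # highest-scoring candidate id
    print("%s -> entity %s (score %.4f)" % (sample["mention"], e_id, score_list[0]))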
Example #5
def main(parameters):
    print("Parameters:", parameters)
    # Read data
    sentences = utils.read_sentences_from_file(
        parameters["path_to_input_file"],
        one_sentence_per_line=parameters["one_sentence_per_line"],
    )

    # Identify mentions
    ner_model = NER.get_model(parameters)
    ner_output_data = ner_model.predict(sentences)
    sentences = ner_output_data["sentences"]
    mentions = ner_output_data["mentions"]

    output_folder_path = parameters["output_folder_path"]

    if (output_folder_path is not None and os.path.exists(output_folder_path)
            and os.listdir(output_folder_path)):
        print("The given output directory ({}) already exists and is not "
              "empty.".format(output_folder_path))
        answer = input("Would you like to empty the existing directory? [Y/N]\n")

        if answer.strip() == "Y":
            print("Deleting {}...".format(output_folder_path))
            shutil.rmtree(output_folder_path)
        else:
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".format(
                    output_folder_path))

    if output_folder_path is not None:
        utils.write_dicts_as_json_per_line(
            sentences, utils.get_sentences_txt_file_path(output_folder_path))
        utils.write_dicts_as_json_per_line(
            mentions, utils.get_mentions_txt_file_path(output_folder_path))

    # Generate candidates and get the data that describes the candidates
    candidate_generator = CG.get_model(parameters)
    candidate_generator.process_mentions_for_candidate_generator(
        sentences=sentences, mentions=mentions)

    # instantiate the optional data fetcher once instead of once per mention
    data_fetcher = None
    if parameters["consider_additional_datafetcher"]:
        data_fetcher = CDF.get_model(parameters)

    for mention in mentions:
        mention["candidates"] = candidate_generator.get_candidates(mention)
        if data_fetcher is not None:
            for candidate in mention["candidates"]:
                data_fetcher.get_data_for_entity(candidate)

    if output_folder_path is not None:
        utils.write_dicts_as_json_per_line(
            mentions, utils.get_mentions_txt_file_path(output_folder_path))

    # Reranking
    reranking_model = R.get_model(parameters)
    reranking_model.rerank(mentions, sentences)

    if output_folder_path is not None:
        utils.write_dicts_as_json_per_line(
            mentions, utils.get_mentions_txt_file_path(output_folder_path))
        utils.write_end2end_pickle_output(sentences, mentions,
                                          output_folder_path)
        utils.present_annotated_sentences(
            sentences,
            mentions,
            utils.get_end2end_pretty_output_file_path(output_folder_path),
        )

    # Showcase results
    utils.present_annotated_sentences(sentences, mentions)
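
Only four keys of parameters are read directly in main; a minimal sketch of the dict follows. The paths are placeholders, and the NER/CG/CDF/R model factories may consume additional keys that are not visible in this snippet.

# Hypothetical parameters dict; paths are placeholders, and the model
# factories (NER, CG, CDF, R) may require further keys not shown above.
parameters = {
    "path_to_input_file": "data/input.txt",
    "one_sentence_per_line": True,
    "output_folder_path": "output/run1",  # None skips all file output
    "consider_additional_datafetcher": False,
}
main(parameters)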