Exemplo n.º 1
0
def run(
    args,
    logger,
    biencoder,
    biencoder_params,
    crossencoder,
    crossencoder_params,
    candidate_encoding,
    title2id,
    id2title,
    id2text,
    wikipedia_id2local_id,
    test_data=None,
):
    """Link mentions with the biencoder and, unless ``args.fast``, the crossencoder.

    Two modes are supported:

    * Interactive (``args.interactive``): text is read from stdin, mentions
      are detected with the NER model, and predictions are printed. The loop
      has no exit condition in this mode, so the function does not return.
    * Dataset: samples come from ``test_data`` if truthy, otherwise they are
      loaded from ``args.test_mentions`` / ``args.test_entities``. The
      pipeline runs once and the results are returned.

    Args:
        args: Options object; reads ``interactive``, ``fast``, ``top_k``,
            ``show_url``, ``test_mentions`` and ``test_entities``.
        logger: Logger used for progress messages.
        biencoder: Biencoder model; its ``tokenizer`` attribute is used.
        biencoder_params: Biencoder settings; ``"max_context_length"`` is read.
        crossencoder: Crossencoder model; its ``tokenizer`` attribute is used.
        crossencoder_params: Crossencoder settings; ``"max_seq_length"`` is read.
        candidate_encoding: Candidate encodings passed to the biencoder search.
        title2id: Mapping passed through to the test-sample loader.
        id2title: Mapping from local entity id to entity title.
        id2text: Mapping from local entity id to entity text.
        wikipedia_id2local_id: Mapping from Wikipedia id to local entity id.
        test_data: Optional pre-built samples used instead of loading them.

    Returns:
        In dataset mode, the tuple ``(biencoder_accuracy, recall_at,
        crossencoder_normalized_accuracy, overall_unormalized_accuracy,
        num_samples, predictions, scores)``. Metrics that were not computed
        (labels unknown, or crossencoder skipped by ``args.fast``) are ``-1``.

    Raises:
        ValueError: If neither interactive mode, ``test_data`` nor
            ``args.test_mentions`` is provided.
    """
    if not test_data and not args.test_mentions and not args.interactive:
        msg = (
            "ERROR: either you start BLINK with the "
            "interactive option (-i) or you pass in input test mentions (--test_mentions) "
            "and test entities (--test_entities)"
        )
        raise ValueError(msg)

    # Local entity id -> Wikipedia URL built from the Wikipedia id.
    id2url = {
        v: "https://en.wikipedia.org/wiki?curid=%s" % k
        for k, v in wikipedia_id2local_id.items()
    }

    # Loaded lazily on the first interactive iteration and then reused, so
    # the model is not reloaded for every piece of text.
    ner_model = None

    stopping_condition = False
    while not stopping_condition:

        samples = None

        if args.interactive:
            logger.info("interactive mode")

            # Load NER model (once)
            if ner_model is None:
                ner_model = NER.get_model()

            # Interactive
            text = input("insert text:")

            # Identify mentions
            samples = _annotate(ner_model, [text])

            _print_colorful_text(text, samples)

        else:
            logger.info("test dataset mode")

            if test_data:
                samples = test_data
            else:
                # Load test mentions
                samples = _get_test_samples(
                    args.test_mentions,
                    args.test_entities,
                    title2id,
                    wikipedia_id2local_id,
                    logger,
                )

            # Dataset mode processes the samples exactly once.
            stopping_condition = True

        # don't look at labels (interactive input has none; for datasets the
        # first sample decides whether gold labels are available)
        keep_all = (
            args.interactive
            or samples[0]["label"] == "unknown"
            or samples[0]["label_id"] < 0
        )

        # prepare the data for biencoder
        logger.info("preparing data for biencoder")
        dataloader = _process_biencoder_dataloader(
            samples, biencoder.tokenizer, biencoder_params
        )

        # run biencoder
        logger.info("run biencoder")
        top_k = args.top_k
        labels, nns, scores = _run_biencoder(
            biencoder, dataloader, candidate_encoding, top_k
        )

        if args.interactive:

            print("\nfast (biencoder) predictions:")

            _print_colorful_text(text, samples)

            # print biencoder prediction (first candidate of each mention)
            for idx, (entity_list, sample) in enumerate(zip(nns, samples)):
                e_id = entity_list[0]
                e_title = id2title[e_id]
                e_text = id2text[e_id]
                e_url = id2url[e_id]
                _print_colorful_prediction(
                    idx, sample, e_id, e_title, e_text, e_url, args.show_url
                )
            print()

            if args.fast:
                # use only biencoder
                continue

        else:

            biencoder_accuracy = -1
            recall_at = -1
            if not keep_all:
                # Recall at every cutoff 1..top_k (inclusive), so the last
                # entry really is recall@top_k as reported below.
                recalls = _recall_at_each_k(labels, nns, top_k)
                if recalls:
                    biencoder_accuracy = recalls[0]
                    recall_at = recalls[-1]
                    print("biencoder accuracy: %.4f" % biencoder_accuracy)
                    print("biencoder recall@%d: %.4f" % (top_k, recall_at))

            if args.fast:

                predictions = [
                    [id2title[e_id] for e_id in entity_list] for entity_list in nns
                ]

                # use only biencoder
                return (
                    biencoder_accuracy,
                    recall_at,
                    -1,
                    -1,
                    len(samples),
                    predictions,
                    scores,
                )

        # prepare crossencoder data
        context_input, candidate_input, label_input = prepare_crossencoder_data(
            crossencoder.tokenizer, samples, labels, nns, id2title, id2text, keep_all,
        )

        context_input = modify(
            context_input, candidate_input, crossencoder_params["max_seq_length"]
        )

        dataloader = _process_crossencoder_dataloader(
            context_input, label_input, crossencoder_params
        )

        # run crossencoder and get accuracy
        accuracy, index_array, unsorted_scores = _run_crossencoder(
            crossencoder,
            dataloader,
            logger,
            context_len=biencoder_params["max_context_length"],
        )

        if args.interactive:

            print("\naccurate (crossencoder) predictions:")

            _print_colorful_text(text, samples)

            # print crossencoder prediction; the last index of each row is
            # taken as the best candidate (rows appear to be ordered
            # ascending -- the dataset branch reverses them for descending).
            for idx, (entity_list, index_list, sample) in enumerate(
                zip(nns, index_array, samples)
            ):
                e_id = entity_list[index_list[-1]]
                e_title = id2title[e_id]
                e_text = id2text[e_id]
                e_url = id2url[e_id]
                _print_colorful_prediction(
                    idx, sample, e_id, e_title, e_text, e_url, args.show_url
                )
            print()
        else:

            # From here on `scores` holds crossencoder scores, replacing the
            # biencoder scores returned by the fast path.
            scores = []
            predictions = []
            for entity_list, index_list, scores_list in zip(
                nns, index_array, unsorted_scores
            ):

                index_list = index_list.tolist()

                # descending order
                index_list.reverse()

                sample_prediction = []
                sample_scores = []
                for index in index_list:
                    e_id = entity_list[index]
                    e_title = id2title[e_id]
                    sample_prediction.append(e_title)
                    sample_scores.append(scores_list[index])
                predictions.append(sample_prediction)
                scores.append(sample_scores)

            crossencoder_normalized_accuracy = -1
            overall_unormalized_accuracy = -1
            if not keep_all:
                crossencoder_normalized_accuracy = accuracy
                print(
                    "crossencoder normalized accuracy: %.4f"
                    % crossencoder_normalized_accuracy
                )

                # Rescale by the share of samples that reached the
                # crossencoder (len(label_input)) out of all samples.
                overall_unormalized_accuracy = (
                    crossencoder_normalized_accuracy * len(label_input) / len(samples)
                )
                print(
                    "overall unnormalized accuracy: %.4f" % overall_unormalized_accuracy
                )
            return (
                biencoder_accuracy,
                recall_at,
                crossencoder_normalized_accuracy,
                overall_unormalized_accuracy,
                len(samples),
                predictions,
                scores,
            )


def _recall_at_each_k(labels, nns, top_k):
    """Return ``[recall@1, ..., recall@top_k]`` of ``labels`` within ``nns``.

    ``nns`` holds one ranked candidate sequence per label. Recall@k is the
    fraction of labels found among the first ``k`` candidates of their
    sequence; it is ``0.0`` for every ``k`` when ``labels`` is empty. The
    result is empty when ``top_k`` is less than 1.
    """
    total = len(labels)
    recalls = []
    for k in range(1, top_k + 1):
        hits = sum(1 for label, top in zip(labels, nns) if label in top[:k])
        recalls.append(float(hits) / total if total > 0 else 0.0)
    return recalls
Exemplo n.º 2
0
    def link_text(self, text):
        """Detect mentions in ``text`` and link each one to an entity.

        Mentions are found with ``self.ner_model`` and ranked by the
        biencoder. When ``args.fast`` is set, the top biencoder candidate of
        every mention is returned; otherwise the candidates are re-ranked by
        the crossencoder, the best one per mention is printed and returned.

        Returns:
            A dict with ``"samples"`` (the detected mentions) and
            ``"linked_entities"`` (one dict per mention with the keys
            ``idx``, ``sample``, ``entity_id``, ``entity_title``,
            ``entity_text``, ``url`` and ``crossencoder``).
        """
        # Mention detection, echoed back with highlighting.
        mentions = _annotate(self.ner_model, [text])
        _print_colorful_text(text, mentions)

        # Free text carries no gold labels, so every mention is kept.
        ignore_labels = True

        # Biencoder input
        logger.info("preparing data for biencoder")
        biencoder_loader = _process_biencoder_dataloader(
            mentions, self.biencoder.tokenizer, self.biencoder_params
        )

        # Biencoder retrieval
        logger.info("run biencoder")
        candidate_count = args.top_k
        labels, nns, _biencoder_scores = _run_biencoder(
            self.biencoder,
            biencoder_loader,
            self.candidate_encoding,
            candidate_count,
            self.faiss_indexer,
        )

        # Collect the top biencoder candidate of every mention.
        biencoder_links = []
        for position, (candidates, mention) in enumerate(zip(nns, mentions)):
            best_id = candidates[0]
            best_title = self.id2title[best_id]
            best_text = self.id2text[best_id]
            best_url = self.id2url[best_id]
            biencoder_links.append(
                {
                    "idx": position,
                    "sample": mention,
                    "entity_id": best_id.item(),
                    "entity_title": best_title,
                    "entity_text": best_text,
                    "url": best_url,
                    "crossencoder": False,
                }
            )

        if args.fast:
            # Biencoder-only answer.
            return {"samples": mentions, "linked_entities": biencoder_links}

        # Crossencoder input
        context_input, candidate_input, label_input = prepare_crossencoder_data(
            self.crossencoder.tokenizer,
            mentions,
            labels,
            nns,
            self.id2title,
            self.id2text,
            ignore_labels,
        )
        context_input = modify(
            context_input,
            candidate_input,
            self.crossencoder_params["max_seq_length"],
        )
        crossencoder_loader = _process_crossencoder_dataloader(
            context_input, label_input, self.crossencoder_params
        )

        # Crossencoder re-ranking; only the index ordering is used below.
        _accuracy, index_array, _raw_scores = _run_crossencoder(
            self.crossencoder,
            crossencoder_loader,
            logger,
            context_len=self.biencoder_params["max_context_length"],
        )

        # Print and collect the crossencoder's choice for every mention
        # (the candidate referenced by the last index of each row).
        crossencoder_links = []
        for position, (candidates, order, mention) in enumerate(
            zip(nns, index_array, mentions)
        ):
            best_id = candidates[order[-1]]
            best_title = self.id2title[best_id]
            best_text = self.id2text[best_id]
            best_url = self.id2url[best_id]
            _print_colorful_prediction(
                position, mention, best_id, best_title, best_text, best_url,
                args.show_url,
            )
            crossencoder_links.append(
                {
                    "idx": position,
                    "sample": mention,
                    "entity_id": best_id.item(),
                    "entity_title": best_title,
                    "entity_text": best_text,
                    "url": best_url,
                    "crossencoder": True,
                }
            )
        return {"samples": mentions, "linked_entities": crossencoder_links}