Code example #1
File: test_ed_pipeline.py  Project: zhy5186612/REL
from pathlib import Path

from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection
from REL.ner import Cmns
from REL.utils import process_results


def test_pipeline():
    base_url = Path(__file__).parent
    wiki_subfolder = "wiki_test"
    sample = {
        # One test document; the span [10, 3] ([start, length]) covers "fox".
        "test_doc": ["the brown fox jumped over the lazy dog", [[10, 3]]]
    }
    config = {
        "mode": "eval",
        "model_path": f"{base_url}/{wiki_subfolder}/generated/model",
    }

    md = MentionDetection(base_url, wiki_subfolder)
    # n-gram tagger (unused below, since spans are supplied directly).
    tagger = Cmns(base_url, wiki_subfolder, n=5)
    model = EntityDisambiguation(base_url, wiki_subfolder, config)

    mentions_dataset, total_mentions = md.format_spans(sample)

    predictions, _ = model.predict(mentions_dataset)
    results = process_results(mentions_dataset,
                              predictions,
                              sample,
                              include_offset=False)

    gold_truth = {"test_doc": [(10, 3, "Fox", "fox", -1, "NULL", 0.0)]}

    assert results == gold_truth
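
The spans in the sample above are [start, length] pairs over the raw text. A minimal standalone sketch of how such a span selects its mention (extract_mention is a hypothetical helper, not part of REL):

# Hypothetical helper illustrating the [start, length] span convention.
def extract_mention(text, span):
    start, length = span
    return text[start:start + length]

text = "the brown fox jumped over the lazy dog"
assert extract_mention(text, [10, 3]) == "fox"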
Code example #2
File: server.py  Project: zxlzr/REL
        def generate_response(self, text, spans):
            """
            Generates response for API. Can be either ED only or EL, meaning end-to-end.

            :return: list of tuples for each entity found.
            """

            if len(text) == 0:
                return []

            if len(spans) > 0:
                # Spans were supplied, so we only do ED.
                processed = {API_DOC: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.format_spans(
                    processed)
            elif self.custom_ner:
                # Custom MD: run the user-supplied tagger first, then
                # verify that it actually produced spans.
                spans = self.tagger_ner(text)
                if len(spans) == 0:
                    print("No spans found for custom MD.")
                    return []

                processed = {API_DOC: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.format_spans(
                    processed)
            else:
                # Full EL: detect mentions first, then disambiguate.
                processed = {API_DOC: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.find_mentions(
                    processed, self.tagger_ner)

            # Disambiguation
            predictions, timing = self.model.predict(mentions_dataset)

            # Process result.
            result = process_results(
                mentions_dataset,
                predictions,
                processed,
                include_offset=not (len(spans) > 0 or self.custom_ner),
            )

            # Singular document: return its list of entities.
            if len(result) > 0:
                return next(iter(result.values()))

            return []
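
For context, a hypothetical invocation of this handler (the enclosing server class and its construction are not shown in the snippet, so the handler name is assumed):

# Hypothetical usage of generate_response; `handler` is an instance of
# the enclosing server class, whose setup is outside this snippet.
# With spans supplied, only ED runs:
entities = handler.generate_response(
    "the brown fox jumped over the lazy dog", [[10, 3]])
# With no spans (and no custom NER), full EL runs:
entities = handler.generate_response(
    "the brown fox jumped over the lazy dog", [])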
Code example #3
File: predict_EL.py  Project: ruanchaves/REL-1
# Imports for this snippet (base_url, wiki_version and input_text are
# assumed to be defined earlier in the script).
from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection
from REL.ner import Cmns, load_flair_ner
from REL.utils import process_results

# 2. Mention detection: two options. We use the Flair NER tagger here;
# users can also supply their own mention detection module.
mention_detection = MentionDetection(base_url, wiki_version)

# (a) If you want to use your own MD system (or n-gram detection),
# the required input is: {doc_name: [text, spans] ... }.
mentions_dataset, n_mentions = mention_detection.format_spans(input_text)

# (b) Alternatively, use the Flair NER tagger (or the Cmns n-gram tagger).
tagger_ner = load_flair_ner("ner-fast")
# tagger_ngram = Cmns(base_url, wiki_version, n=5)

mentions_dataset, n_mentions = mention_detection.find_mentions(
    input_text, tagger_ner)

# 3. Load model.
config = {
    "mode": "eval",
    "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
}
model = EntityDisambiguation(base_url, wiki_version, config)

# 4. Entity disambiguation.
predictions, timing = model.predict(mentions_dataset)

# 5. Optionally use our function to get results in a usable format.
result = process_results(mentions_dataset, predictions, input_text)

print(result)
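
The input_text dictionary follows the {doc_name: [text, spans]} convention described in the comments above; a minimal illustrative instance (document name and text are made up):

# Illustrative input: one document and no precomputed spans, so
# find_mentions will detect the mentions itself.
input_text = {
    "doc_1": ["Obama will visit Washington next week.", []],
}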
Code example #4
    def __add_rel_entity_links(self):
        """ Add REL entity linker links to all document contents. """

        # Build batch REL text input - sharding document contents into sentences.
        processed_document_contents = {}
        for document_content in self.document.document_contents:
            content_id = document_content.content_id
            text = document_content.text
            for i, sentence in enumerate(text.split(".")):
                sentence_id = str(content_id) + str(i)
                processed_document_contents[sentence_id] = [str(sentence), []]

        # Run REL model with document contents.
        mentions_dataset, n_mentions = self.mention_detection.find_mentions(
            processed_document_contents, self.tagger_ner)
        predictions, timing = self.entity_disambiguation.predict(
            mentions_dataset)
        entity_links_dict = process_results(mentions_dataset, predictions,
                                            processed_document_contents)

        # Connect to the LMDB mapping {pickle(car_id): pickle(car_name)}.
        env = lmdb.open(self.car_id_to_name_path, map_size=int(2e10))
        with env.begin(write=False) as txn:

            for document_content in self.document.document_contents:

                content_id = document_content.content_id
                text = document_content.text

                i_sentence_start = 0
                for i, sentence in enumerate(text.split(".")):
                    sentence_id = str(content_id) + str(i)
                    if sentence_id in entity_links_dict:
                        for entity_link in entity_links_dict[sentence_id]:
                            # Only keep links at or above the confidence
                            # threshold (0.0 currently accepts everything).
                            if float(entity_link[4]) >= 0.0:
                                i_start = i_sentence_start + entity_link[0] + 1
                                i_end = i_start + entity_link[1]
                                span_text = text[i_start:i_end]

                                entity_id = self.__rel_id_to_car_id(
                                    rel_id=entity_link[3])
                                entity_name_pickle = txn.get(
                                    pickle.dumps(entity_id))

                                if entity_name_pickle is not None:
                                    self.valid_counter += 1
                                    entity_name = pickle.loads(entity_name_pickle)

                                    if entity_link[2] == span_text:
                                        anchor_text_location = document_pb2.EntityLink.AnchorTextLocation()
                                        anchor_text_location.start = i_start
                                        anchor_text_location.end = i_end

                                        # Create a new EntityLink message.
                                        rel_entity_link = document_pb2.EntityLink()
                                        rel_entity_link.anchor_text = entity_link[2]
                                        rel_entity_link.entity_id = entity_id
                                        rel_entity_link.entity_name = entity_name
                                        rel_entity_link.anchor_text_location.MergeFrom(anchor_text_location)

                                        document_content.rel_entity_links.append(rel_entity_link)

                                    else:
                                        # The predicted offsets do not line up with
                                        # the text, so locate the mention text in
                                        # the sentence instead.
                                        regex = re.escape(entity_link[2])
                                        for match in re.finditer(regex, sentence):
                                            i_start = i_sentence_start + match.start()
                                            i_end = i_sentence_start + match.end()
                                            span_text = text[i_start:i_end]

                                            assert entity_link[2] == span_text, \
                                                "word_text: '{}' , text[start_i:end_i]: '{}'".format(entity_link[2], span_text)

                                            anchor_text_location = document_pb2.EntityLink.AnchorTextLocation()
                                            anchor_text_location.start = i_start
                                            anchor_text_location.end = i_end

                                            # Create a new EntityLink message.
                                            rel_entity_link = document_pb2.EntityLink()
                                            rel_entity_link.anchor_text = entity_link[2]
                                            rel_entity_link.entity_id = entity_id
                                            rel_entity_link.entity_name = entity_name
                                            rel_entity_link.anchor_text_location.MergeFrom(anchor_text_location)

                                            document_content.rel_entity_links.append(rel_entity_link)

                    i_sentence_start += len(sentence) + 1
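
The subtle part above is the offset bookkeeping: i_sentence_start advances by len(sentence) + 1 so that absolute offsets account for each "." dropped by split. A self-contained sketch of that invariant:

# Standalone sketch of the sentence-offset bookkeeping used above: after
# splitting on ".", each sentence starts at the running total of previous
# sentence lengths plus one per dropped ".".
text = "First sentence. Second sentence."
i_sentence_start = 0
for sentence in text.split("."):
    assert text[i_sentence_start:i_sentence_start + len(sentence)] == sentence
    i_sentence_start += len(sentence) + 1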
Code example #5
        def generate_response(self, text, spans):
            """
            Generates response for API. Can be either ED only or EL, meaning end-to-end.

            :return: list of tuples for each entity found.
            """

            n_words = len(text.split())

            if len(text) == 0:
                return []

            start = time.time()
            if (self.mode == "ED") or self.custom_ner:
                if self.custom_ner:
                    spans = self.tagger_ner(text)

                # Verify that we have spans.
                if len(spans) == 0:
                    print("No spans found while in ED mode.")
                    return []
                processed = {GERBIL: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.format_spans(
                    processed)
            elif self.mode == "EL":
                # EL
                processed = {GERBIL: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.find_mentions(
                    processed, self.tagger_ner)
            else:
                raise ValueError(
                    "Invalid mode; valid options are: ED or EL")
            time_md = time.time() - start

            # Disambiguation
            start = time.time()
            predictions, timing = self.model.predict(mentions_dataset)
            time_ed = time.time() - start

            # Efficiency statistics: word count, mention count, MD time, ED time.
            efficiency = [
                str(n_words),
                str(total_ment),
                str(time_md),
                str(time_ed)
            ]

            # Append the efficiency statistics to a tab-separated log file.
            with (self.base_url / self.wiki_subfolder /
                  "generated/efficiency.txt").open('a', encoding='utf-8') as f:
                f.write('\t'.join(efficiency) + '\n')

            # Process result.
            result = process_results(
                mentions_dataset,
                predictions,
                processed,
                include_offset=self.use_offset,
                include_conf=self.include_conf,
            )

            self.doc_cnt += 1
            return result
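
Each request appends one tab-separated line (word count, mention count, MD time, ED time) to efficiency.txt. A minimal sketch of reading that log back (the path below is illustrative; the real one is built from base_url and wiki_subfolder):

# Minimal sketch: parse the tab-separated efficiency log written above.
from pathlib import Path

log_path = Path("wiki_test") / "generated" / "efficiency.txt"  # illustrative
for line in log_path.read_text(encoding="utf-8").splitlines():
    n_words, n_mentions, time_md, time_ed = line.split("\t")
    print(f"{n_words} words, {n_mentions} mentions, "
          f"MD {float(time_md):.3f}s, ED {float(time_ed):.3f}s")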
Code example #6
# Imports for this snippet (base_url, wiki_subfolder and input_documents
# are assumed to be defined earlier in the script).
from flair.models import SequenceTagger

from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection
from REL.utils import process_results

# 2. Mention detection: two options. We use the Flair NER tagger here;
# users can also supply their own mention detection module.
mention_detection = MentionDetection(base_url, wiki_subfolder)

# (a) If you want to use your own MD system, the required input is:
# {doc_name: [text, spans] ... }.
mentions_dataset, n_mentions = mention_detection.format_spans(input_documents)

# (b) Alternatively, use the Flair NER tagger.
tagger_ner = SequenceTagger.load("ner-fast")
mentions_dataset, n_mentions = mention_detection.find_mentions(
    input_documents, tagger_ner)

# 3. Load model.
config = {
    "mode": "eval",
    "model_path": base_url / wiki_subfolder / "generated" / "model",
}
model = EntityDisambiguation(base_url, wiki_subfolder, config)

# 4. Entity disambiguation.
predictions, timing = model.predict(mentions_dataset)

# 5. Optionally use our function to get results in a usable format.
result = process_results(mentions_dataset,
                         predictions,
                         input_documents,
                         include_conf=True)

print(result)
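
process_results returns a dict keyed by document name. A sketch of iterating over it; only the first two tuple fields (start, length) are assumed from the gold_truth shape in example #1, the rest are printed as-is:

# Sketch: walk the per-document results. Each entity tuple starts with
# (start, length); remaining fields (mention, entity, confidence, ...)
# are passed through unchanged.
for doc_name, entities in result.items():
    for entity in entities:
        start, length = entity[0], entity[1]
        print(doc_name, start, length, entity[2:])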