Example #1
def test_pipeline():
    base_url = Path(__file__).parent
    wiki_subfolder = "wiki_test"
    sample = {
        "test_doc": ["the brown fox jumped over the lazy dog", [[10, 3]]]
    }
    config = {
        "mode": "eval",
        "model_path": f"{base_url}/{wiki_subfolder}/generated/model",
    }

    md = MentionDetection(base_url, wiki_subfolder)
    tagger = Cmns(base_url, wiki_subfolder, n=5)
    model = EntityDisambiguation(base_url, wiki_subfolder, config)

    mentions_dataset, total_mentions = md.format_spans(sample)

    predictions, _ = model.predict(mentions_dataset)
    results = process_results(mentions_dataset,
                              predictions,
                              sample,
                              include_offset=False)

    gold_truth = {"test_doc": [(10, 3, "Fox", "fox", -1, "NULL", 0.0)]}

    return results == gold_truth
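
Note that the Cmns n-gram tagger above is instantiated but never used: the test relies on the gold spans passed through format_spans. A minimal sketch of driving the same pipeline with the n-gram tagger instead, reusing the objects defined in the test above and mirroring the find_mentions calls in the later examples, could look like this:

# Sketch only: detect mentions with the n-gram tagger instead of pre-supplied spans,
# then disambiguate and post-process exactly as in the test above.
mentions_dataset, total_mentions = md.find_mentions(sample, tagger)
predictions, _ = model.predict(mentions_dataset)
results = process_results(mentions_dataset, predictions, sample, include_offset=False)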
Example #2
        def __init__(self, *args, **kwargs):
            self.model = model
            self.tagger_ner = tagger_ner

            self.base_url = base_url
            self.wiki_version = wiki_version

            self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
            self.mention_detection = MentionDetection(base_url, wiki_version)

            super().__init__(*args, **kwargs)
Example #3
        def __init__(self, *args, **kwargs):
            self.model = model
            self.tagger_ner = tagger_ner

            self.mode = mode
            self.include_conf = include_conf
            self.base_url = base_url
            self.wiki_subfolder = wiki_subfolder

            self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
            self.use_offset = False if ((mode == "ED")
                                        or self.custom_ner) else True

            self.doc_cnt = 0
            self.mention_detection = MentionDetection(base_url, wiki_subfolder)

            super().__init__(*args, **kwargs)
Example #4
def test_md():
    # return standard Flair tagger + mention detection object
    tagger = SequenceTagger.load("ner-fast")
    md = MentionDetection(Path(__file__).parent, "wiki_test")

    # first test case: repeating sentences
    sample1 = {"test_doc": ["Fox, Fox. Fox.", []]}
    resulting_spans1 = {(0, 3), (5, 3), (10, 3)}
    predictions = md.find_mentions(sample1, tagger)
    predicted_spans = {
        (m["pos"], m["end_pos"] - m["pos"]) for m in predictions[0]["test_doc"]
    }
    assert resulting_spans1 == predicted_spans

    # second test case: excessive whitespace
    sample2 = {"test_doc": ["Fox.                Fox.                   Fox.", []]}
    resulting_spans2 = {(0, 3), (20, 3), (43, 3)}
    predictions = md.find_mentions(sample2, tagger)
    predicted_spans = {
        (m["pos"], m["end_pos"] - m["pos"]) for m in predictions[0]["test_doc"]
    }
    assert resulting_spans2 == predicted_spans
Example #5
    def __init_rel_models(self, rel_wiki_year, rel_base_url, rel_model_path,
                          car_id_to_name_path):
        """Initialise the REL mention detection, NER tagging and entity disambiguation models."""
        # Requires paths to saved models and data.
        wiki_version = "wiki_" + rel_wiki_year
        self.mention_detection = MentionDetection(rel_base_url, wiki_version)
        self.tagger_ner = load_flair_ner("ner-fast")
        config = {
            "mode": "eval",
            "model_path": rel_model_path,
        }
        self.entity_disambiguation = EntityDisambiguation(
            rel_base_url, wiki_version, config)
        self.car_id_to_name_path = car_id_to_name_path
Example #6
# ------------- RUN SEPARATELY TO BALANCE LOAD --------------------
if not server:
    import flair
    import torch

    from flair.models import SequenceTagger

    from REL.mention_detection import MentionDetection
    from REL.entity_disambiguation import EntityDisambiguation
    from time import time

    base_url = "C:/Users/mickv/desktop/data_back/"

    flair.device = torch.device('cuda:0')

    mention_detection = MentionDetection(base_url, wiki_version)

    # Alternatively use Flair NER tagger.
    tagger_ner = SequenceTagger.load("ner-fast")

    start = time()
    mentions_dataset, n_mentions = mention_detection.find_mentions(
        docs, tagger_ner)
    print('MD took: {}'.format(time() - start))

    # 3. Load model.
    config = {
        "mode": "eval",
        "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
    }
    model = EntityDisambiguation(base_url, wiki_version, config)
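
This snippet stops after loading the ED model. A brief sketch of the remaining disambiguation step, continuing from the variables above and mirroring the model.predict calls in the server examples below, could be:

# 4. Entity disambiguation (sketch only): predict entities for the detected mentions.
start = time()
predictions, timing = model.predict(mentions_dataset)
print('ED took: {}'.format(time() - start))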
Example #7
    text = """Obama will visit Germany. And have a meeting with Merkel tomorrow.
    Obama will visit Germany. And have a meeting with Merkel tomorrow. Go all the way or blah blah Charles Bukowski."""
    spans = []  # [(0, 5), (17, 7), (50, 6)]
    processed = {"test_doc": [text, spans], "test_doc2": [text, spans]}
    return processed


base_url = "/users/vanhulsm/Desktop/projects/data/"
wiki_version = "wiki_2014"

# 1. Input sentences when using Flair.
input_text = example_preprocessing()

# For mention detection there are two options.
# 2. Mention detection: we use the NER tagger here, but users can also plug in their own mention detection module.
mention_detection = MentionDetection(base_url, wiki_version)

# 2. Alternatively, if you want to use your own MD system (or n-gram detection),
# the required input is: {doc_name: [text, spans] ... }.
mentions_dataset, n_mentions = mention_detection.format_spans(input_text)

# 2. An alternative MD module uses an n-gram tagger.
tagger_ner = load_flair_ner("ner-fast")
# tagger_ngram = Cmns(base_url, wiki_version, n=5)

mentions_dataset, n_mentions = mention_detection.find_mentions(
    input_text, tagger_ner)

# 3. Load model.
config = {
    "mode": "eval",
    "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
}
model = EntityDisambiguation(base_url, wiki_version, config)
Example #8
    class GetHandler(BaseHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            self.model = model
            self.tagger_ner = tagger_ner

            self.base_url = base_url
            self.wiki_version = wiki_version

            self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
            self.mention_detection = MentionDetection(base_url, wiki_version)

            super().__init__(*args, **kwargs)

        def do_GET(self):
            self.send_response(200)
            self.end_headers()
            self.wfile.write(
                bytes(
                    json.dumps({
                        "schemaVersion": 1,
                        "label": "status",
                        "message": "up",
                        "color": "green",
                    }),
                    "utf-8",
                ))
            return

        def do_HEAD(self):
            # send bad request response code
            self.send_response(400)
            self.end_headers()
            self.wfile.write(bytes(json.dumps([]), "utf-8"))
            return

        def do_POST(self):
            """
            Returns response.

            :return:
            """
            try:
                content_length = int(self.headers["Content-Length"])
                post_data = self.rfile.read(content_length)
                self.send_response(200)
                self.end_headers()

                text, spans = self.read_json(post_data)
                response = self.generate_response(text, spans)

                # print('response in server.py code:\n\n {}'.format(response))
                self.wfile.write(bytes(json.dumps(response), "utf-8"))
            except Exception as e:
                print(f"Encountered exception: {repr(e)}")
                self.send_response(400)
                self.end_headers()
                self.wfile.write(bytes(json.dumps([]), "utf-8"))
            return

        def read_json(self, post_data):
            """
            Reads input JSON message.

            :return: document text and spans.
            """

            data = json.loads(post_data.decode("utf-8"))
            text = data["text"]
            text = text.replace("&amp;", "&")

            # GERBIL sends dictionary, users send list of lists.
            if "spans" in data:
                try:
                    spans = [list(d.values()) for d in data["spans"]]
                except Exception:
                    spans = data["spans"]
                    pass
            else:
                spans = []

            return text, spans

        def generate_response(self, text, spans):
            """
            Generates response for API. Can be either ED only or EL, meaning end-to-end.

            :return: list of tuples for each entity found.
            """

            if len(text) == 0:
                return []

            if len(spans) > 0:
                # ED.
                processed = {API_DOC: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.format_spans(
                    processed)
            else:
                # EL
                processed = {API_DOC: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.find_mentions(
                    processed, self.tagger_ner)

            # Disambiguation
            predictions, timing = self.model.predict(mentions_dataset)

            # Process result.
            result = process_results(
                mentions_dataset,
                predictions,
                processed,
                include_offset=False if
                ((len(spans) > 0) or self.custom_ner) else True,
            )

            # Singular document.
            if len(result) > 0:
                return [*result.values()][0]

            return []
Example #9
    class GetHandler(BaseHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            self.model = models
            self.tagger_ner = tagger_ner

            self.argss = argss
            self.logger = logger

            self.base_url = base_url
            self.wiki_version = wiki_version

            self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
            self.mention_detection = MentionDetection(base_url, wiki_version)

            super().__init__(*args, **kwargs)

        def do_GET(self):
            self.send_response(200)
            self.end_headers()
            self.wfile.write(
                bytes(
                    json.dumps({
                        "schemaVersion": 1,
                        "label": "status",
                        "message": "up",
                        "color": "green",
                    }),
                    "utf-8",
                ))
            return

        def do_POST(self):
            """
            Returns response.

            :return:
            """
            content_length = int(self.headers["Content-Length"])
            print(content_length)
            post_data = self.rfile.read(content_length)
            self.send_response(200)
            self.end_headers()

            text, spans = self.read_json(post_data)
            response = self.generate_response(text, spans)

            print(response)
            print("=========")

            # print('response in server.py code:\n\n {}'.format(response))
            self.wfile.write(bytes(json.dumps(response), "utf-8"))
            return

        def read_json(self, post_data):
            """
            Reads input JSON message.

            :return: document text and spans.
            """

            data = json.loads(post_data.decode("utf-8"))
            text = data["text"]
            text = text.replace("&amp;", "&")

            # GERBIL sends dictionary, users send list of lists.
            try:
                spans = [list(d.values()) for d in data["spans"]]
            except Exception:
                spans = data["spans"]
                pass

            return text, spans

        def generate_response(self, text, spans):
            """
            Generates response for API. Can be either ED only or EL, meaning end-to-end.

            :return: list of tuples for each entity found.
            """

            if len(text) == 0:
                return []

            if len(spans) > 0:
                # ED.
                processed = {API_DOC: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.format_spans(
                    processed)
            else:
                # EL
                processed = {API_DOC: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.find_mentions(
                    processed, self.tagger_ner)

            # Create to-be-linked dataset.
            data_to_link = []
            temp_m = mentions_dataset[API_DOC]
            for i, m in enumerate(temp_m):
                # Using ngram, which is basically the original mention (without preprocessing as in BLINK's code).
                temp = {
                    "id": i,
                    "label": "unknown",
                    "label_id": -1,
                    "context_left": m["context"][0].lower(),
                    "mention": m["ngram"].lower(),
                    "context_right": m["context"][1].lower(),
                }
                data_to_link.append(temp)
            _, _, _, _, _, predictions, scores = main_dense.run(
                self.argss, self.logger, *self.model, test_data=data_to_link)

            predictions = {
                API_DOC: [{
                    "prediction": x[0].replace(" ", "_")
                } for x in predictions]
            }
            # Process result.
            result = process_results(
                mentions_dataset,
                predictions,
                processed,
                include_offset=False if
                ((len(spans) > 0) or self.custom_ner) else True,
            )

            # Singular document.
            if len(result) > 0:
                return [*result.values()][0]

            return []
Example #10
def test_mention_detection_instantiation():
    return MentionDetection(Path(__file__).parent, "wiki_test")
Example #11
    class GetHandler(BaseHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            self.model = model
            self.tagger_ner = tagger_ner

            self.mode = mode
            self.include_conf = include_conf
            self.base_url = base_url
            self.wiki_subfolder = wiki_subfolder

            self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
            self.use_offset = False if ((mode == "ED")
                                        or self.custom_ner) else True

            self.doc_cnt = 0
            self.mention_detection = MentionDetection(base_url, wiki_subfolder)

            super().__init__(*args, **kwargs)

        # TODO: lowercase POST
        def do_POST(self):
            """
            Returns response.

            :return:
            """
            content_length = int(self.headers["Content-Length"])
            post_data = self.rfile.read(content_length)
            self.send_response(200)
            self.end_headers()

            text, spans = self.read_json(post_data)
            response = self.generate_response(text, spans)

            # print('response in server.py code:\n\n {}'.format(response))
            self.wfile.write(bytes(json.dumps(response), "utf-8"))
            return

        def read_json(self, post_data):
            """
            Reads input JSON message.

            :return: document text and spans.
            """

            data = json.loads(post_data.decode("utf-8"))
            text = data["text"]
            text = html.unescape(text)
            spans = [(int(j["start"]), int(j["length"]))
                     for j in data["spans"]]
            return text, spans

        def generate_response(self, text, spans):
            """
            Generates response for API. Can be either ED only or EL, meaning end-to-end.

            :return: list of tuples for each entity found.
            """

            n_words = len(text.split())

            if len(text) == 0:
                return []

            start = time.time()
            if (self.mode == "ED") or self.custom_ner:
                if self.custom_ner:
                    spans = self.tagger_ner(text)

                # Verify if we have spans.
                if len(spans) == 0:
                    print("No spans found while in ED mode..?")
                    return []
                processed = {
                    GERBIL: [text, spans]
                }  # self.split_text(text, spans)
                mentions_dataset, total_ment = self.mention_detection.format_spans(
                    processed)
            elif self.mode == "EL":
                # EL
                processed = {GERBIL: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.find_mentions(
                    processed, self.tagger_ner)
            else:
                raise Exception(
                    "Faulty mode, only valid options are: ED or EL")
            time_md = time.time() - start

            # Disambiguation
            start = time.time()
            predictions, timing = self.model.predict(mentions_dataset)
            time_ed = time.time() - start

            # Efficiency record: word count, mention count, MD time, ED time.
            efficiency = [
                str(n_words),
                str(total_ment),
                str(time_md),
                str(time_ed)
            ]

            # write to txt file.
            with (self.base_url / self.wiki_subfolder /
                  "generated/efficiency.txt").open('a', encoding='utf-8') as f:
                f.write('\t'.join(efficiency) + '\n')

            # Process result.
            result = process_results(
                mentions_dataset,
                predictions,
                processed,
                include_offset=self.use_offset,
                include_conf=self.include_conf,
            )

            self.doc_cnt += 1
            return result
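
These server examples only define GetHandler; they do not start a server. A minimal sketch of serving it with Python's standard http.server module (host and port below are placeholders, and the handler picks up model, tagger_ner and the other settings from its enclosing scope) might be:

from http.server import HTTPServer

# Sketch only: host and port are placeholders.
server_address = ("localhost", 5555)
httpd = HTTPServer(server_address, GetHandler)
httpd.serve_forever()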
Example #12
def example_preprocessing():
    # Example input; the expected format is {doc_name: [text, spans], ...}.
    text = "Obama will visit Germany. And have a meeting with Merkel tomorrow."
    spans = [(0, 5), (17, 7), (50, 6)]
    processed = {"test_doc": [text, spans], "test_doc2": [text, spans]}
    return processed


base_url = Path("")
wiki_subfolder = "wiki_2019"

# 1. Input sentences when using Flair.
input_documents = example_preprocessing()

# For mention detection there are two options.
# 2. Mention detection: we use the NER tagger here, but users can also plug in their own mention detection module.
mention_detection = MentionDetection(base_url, wiki_subfolder)

# If you want to use your own MD system, the required input is: {doc_name: [text, spans] ... }.
mentions_dataset, n_mentions = mention_detection.format_spans(input_documents)

# Alternatively use Flair NER tagger.
tagger_ner = SequenceTagger.load("ner-fast")
mentions_dataset, n_mentions = mention_detection.find_mentions(
    input_documents, tagger_ner)

# 3. Load model.
config = {
    "mode": "eval",
    "model_path": base_url / wiki_subfolder / "generated" / "model",
}
model = EntityDisambiguation(base_url, wiki_subfolder, config)
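
As in Example #1, the loaded model can then be applied to the detected mentions. A short sketch of that final step, continuing from the variables above and using the same process_results helper as the other examples (its import is omitted here, as it is there), could be:

# 4. Disambiguation (sketch only): predict entities for the detected mentions, then post-process.
predictions, timing = model.predict(mentions_dataset)
results = process_results(mentions_dataset, predictions, input_documents,
                          include_offset=True)  # True here because the mentions came from the Flair NER tagger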
Example #13
    def parse_cbor_to_protobuf(self,
                               read_path,
                               write_path,
                               num_docs,
                               buffer_size=10,
                               print_intervals=100,
                               write_output=True,
                               use_rel=False,
                               rel_base_url=None,
                               rel_wiki_year='2014',
                               rel_model_path=None,
                               car_id_to_name_path=None):
        """ Read TREC CAR cbor file to create a list of protobuffer Document messages
        (protocol_buffers/document.proto:Documents).  This list of messages are streammed to binary file using 'stream'
        package. """

        # list of Document messages.
        documents = []
        t_start = time.time()

        # Use REL (Radboud Entity Linker): https://github.com/informagi/REL
        if use_rel:
            # Requires paths to saved models and data.
            wiki_version = "wiki_" + rel_wiki_year
            self.mention_detection = MentionDetection(rel_base_url,
                                                      wiki_version)
            self.tagger_ner = load_flair_ner("ner-fast")
            config = {
                "mode": "eval",
                "model_path": rel_model_path,
            }
            self.entity_disambiguation = EntityDisambiguation(
                rel_base_url, wiki_version, config)
            self.car_id_to_name_path = car_id_to_name_path

        with open(read_path, 'rb') as f_read:

            # Loop over Page objects
            for i, page in enumerate(iter_pages(f_read)):

                # Stop once 'num_docs' pages have been processed.
                if i >= num_docs:
                    break

                # parse page to create new document.
                self.parse_page_to_protobuf(page=page)

                # Append Document message to document list.
                documents.append(self.document)

                # Print an update every 'print_intervals' pages.
                if (i + 1) % print_intervals == 0:
                    print('----- DOC #{} -----'.format(i))
                    print(self.document.doc_id)
                    time_delta = time.time() - t_start
                    print('time elapse: {} --> time / page: {}'.format(
                        time_delta, time_delta / (i + 1)))

        if write_output:
            print('STREAMING DATA TO FILE: {}'.format(write_path))
            self.write_documents_to_file(path=write_path,
                                         documents=documents,
                                         buffer_size=buffer_size)
            print('FILE WRITTEN')

        time_delta = time.time() - t_start
        print('PROCESSED DATA: {} --> processing time / page: {}'.format(
            time_delta, time_delta / (i + 1)))
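
A hedged usage sketch of this method with REL enabled follows; parser stands for an instance of the class this method belongs to, and every path argument is a placeholder rather than a real location.

# Sketch only: all paths below are placeholders.
parser.parse_cbor_to_protobuf(
    read_path="corpus.cbor",
    write_path="documents.bin",
    num_docs=1000,
    use_rel=True,
    rel_base_url="/path/to/rel_data/",
    rel_wiki_year="2014",
    rel_model_path="/path/to/rel_data/wiki_2014/generated/model",
    car_id_to_name_path="/path/to/car_id_to_name.json",
)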