Example No. 1
from pathlib import Path

from flair.models import SequenceTagger

from REL.mention_detection import MentionDetection


def test_md():
    # Load the standard Flair tagger and a mention detection object.
    tagger = SequenceTagger.load("ner-fast")
    md = MentionDetection(Path(__file__).parent, "wiki_test")

    # first test case: repeating sentences
    sample1 = {"test_doc": ["Fox, Fox. Fox.", []]}
    resulting_spans1 = {(0, 3), (5, 3), (10, 3)}
    predictions = md.find_mentions(sample1, tagger)
    predicted_spans = {
        (m["pos"], m["end_pos"] - m["pos"]) for m in predictions[0]["test_doc"]
    }
    assert resulting_spans1 == predicted_spans

    # second test case: excessive whitespace
    sample2 = {"test_doc": ["Fox.                Fox.                   Fox.", []]}
    resulting_spans2 = {(0, 3), (20, 3), (43, 3)}
    predictions = md.find_mentions(sample2, tagger)
    predicted_spans = {
        (m["pos"], m["end_pos"] - m["pos"]) for m in predictions[0]["test_doc"]
    }
    assert resulting_spans2 == predicted_spans
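A minimal, self-contained sketch of how the mention dictionaries returned by find_mentions can be inspected, assuming the same "wiki_test" resources as the test above; the "ngram" key is an assumption based on how mentions are used in Example No. 5 below.

from pathlib import Path

from flair.models import SequenceTagger

from REL.mention_detection import MentionDetection

tagger = SequenceTagger.load("ner-fast")
md = MentionDetection(Path(__file__).parent, "wiki_test")

sample = {"test_doc": ["Fox, Fox. Fox.", []]}
mentions_dataset, n_mentions = md.find_mentions(sample, tagger)
for doc_name, mentions in mentions_dataset.items():
    for m in mentions:
        # Each mention carries its character offset, end offset and surface form.
        print(doc_name, m["pos"], m["end_pos"], m["ngram"])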
Example No. 2
    from time import time

    import flair
    import torch
    from flair.models import SequenceTagger

    from REL.entity_disambiguation import EntityDisambiguation
    from REL.mention_detection import MentionDetection

    base_url = "C:/Users/mickv/desktop/data_back/"
    wiki_version = "wiki_2019"  # example value; not defined in the original snippet

    # Run Flair on the GPU.
    flair.device = torch.device('cuda:0')

    mention_detection = MentionDetection(base_url, wiki_version)

    # Load the Flair NER tagger.
    tagger_ner = SequenceTagger.load("ner-fast")

    # docs is not defined in this snippet; it is assumed to be a dict of the
    # form {doc_name: [text, spans], ...} (see Example No. 3 below).
    start = time()
    mentions_dataset, n_mentions = mention_detection.find_mentions(
        docs, tagger_ner)
    print('MD took: {}'.format(time() - start))

    # 3. Load model.
    config = {
        "mode": "eval",
        "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
    }
    model = EntityDisambiguation(base_url, wiki_version, config)

    # 4. Entity disambiguation.
    start = time()
    predictions, timing = model.predict(mentions_dataset)
    print('ED took: {}'.format(time() - start))
Example No. 3
# Imports assumed by this snippet; the example script also defines
# example_preprocessing(), base_url and wiki_version.
from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection
from REL.ner import Cmns, load_flair_ner
from REL.utils import process_results

# 1. Input sentences when using Flair.
input_text = example_preprocessing()

# For mention detection there are two options.
# 2. Mention detection: we use the NER tagger, but users can also plug in their own mention detection module.
mention_detection = MentionDetection(base_url, wiki_version)

# 2. Alternatively, if you want to use your own MD system (or n-gram detection),
# the required input is: {doc_name: [text, spans], ...}.
mentions_dataset, n_mentions = mention_detection.format_spans(input_text)

# 2. An alternative MD module uses an n-gram tagger.
tagger_ner = load_flair_ner("ner-fast")
# tagger_ngram = Cmns(base_url, wiki_version, n=5)

mentions_dataset, n_mentions = mention_detection.find_mentions(
    input_text, tagger_ner)

# 3. Load model.
config = {
    "mode": "eval",
    "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
}
model = EntityDisambiguation(base_url, wiki_version, config)

# 4. Entity disambiguation.
predictions, timing = model.predict(mentions_dataset)

# 5. Optionally use our function to get results in a usable format.
result = process_results(mentions_dataset, predictions, input_text)

print(result)
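For reference, a minimal sketch of what the example_preprocessing() helper used above could return; the {doc_name: [text, spans], ...} shape follows the input format described in the comments, while the document names and sentence are only illustrative.

def example_preprocessing():
    # Documents are keyed by name; each value is [text, spans].
    # Spans are left empty so that mention detection finds the mentions itself.
    text = "Obama will visit Germany. And have a meeting with Merkel tomorrow."
    return {
        "my_doc_1": [text, []],
        "my_doc_2": [text, []],
    }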
Example No. 4
    class GetHandler(BaseHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            self.model = model
            self.tagger_ner = tagger_ner

            self.base_url = base_url
            self.wiki_version = wiki_version

            self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
            self.mention_detection = MentionDetection(base_url, wiki_version)

            super().__init__(*args, **kwargs)

        def do_GET(self):
            self.send_response(200)
            self.end_headers()
            self.wfile.write(
                bytes(
                    json.dumps({
                        "schemaVersion": 1,
                        "label": "status",
                        "message": "up",
                        "color": "green",
                    }),
                    "utf-8",
                ))
            return

        def do_HEAD(self):
            # send bad request response code
            self.send_response(400)
            self.end_headers()
            self.wfile.write(bytes(json.dumps([]), "utf-8"))
            return

        def do_POST(self):
            """
            Returns response.

            :return:
            """
            try:
                content_length = int(self.headers["Content-Length"])
                post_data = self.rfile.read(content_length)
                self.send_response(200)
                self.end_headers()

                text, spans = self.read_json(post_data)
                response = self.generate_response(text, spans)

                # print('response in server.py code:\n\n {}'.format(response))
                self.wfile.write(bytes(json.dumps(response), "utf-8"))
            except Exception as e:
                print(f"Encountered exception: {repr(e)}")
                self.send_response(400)
                self.end_headers()
                self.wfile.write(bytes(json.dumps([]), "utf-8"))
            return

        def read_json(self, post_data):
            """
            Reads input JSON message.

            :return: document text and spans.
            """

            data = json.loads(post_data.decode("utf-8"))
            text = data["text"]
            text = text.replace("&amp;", "&")

            # GERBIL sends dictionary, users send list of lists.
            if "spans" in data:
                try:
                    spans = [list(d.values()) for d in data["spans"]]
                except Exception:
                    spans = data["spans"]
                    pass
            else:
                spans = []

            return text, spans

        def generate_response(self, text, spans):
            """
            Generates response for API. Can be either ED only or EL, meaning end-to-end.

            :return: list of tuples for each entity found.
            """

            if len(text) == 0:
                return []

            if len(spans) > 0:
                # ED.
                processed = {API_DOC: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.format_spans(
                    processed)
            else:
                # EL
                processed = {API_DOC: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.find_mentions(
                    processed, self.tagger_ner)

            # Disambiguation
            predictions, timing = self.model.predict(mentions_dataset)

            # Process result.
            result = process_results(
                mentions_dataset,
                predictions,
                processed,
                include_offset=not (len(spans) > 0 or self.custom_ner),
            )

            # Singular document.
            if len(result) > 0:
                return [*result.values()][0]

            return []
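A minimal client-side sketch of the JSON payload that read_json above accepts, assuming the handler is served at a hypothetical http://localhost:5555 and that the requests package is installed; leave "spans" empty for end-to-end linking, or pass spans to run disambiguation only.

import requests

API_URL = "http://localhost:5555"  # hypothetical address; adjust to where the handler runs

payload = {
    "text": "Obama will visit Germany.",
    "spans": [],  # e.g. [[0, 5]] for ED on a known mention ("list of lists", as in read_json)
}

response = requests.post(API_URL, json=payload)
print(response.json())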
Example No. 5
    class GetHandler(BaseHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            self.model = models
            self.tagger_ner = tagger_ner

            self.argss = argss
            self.logger = logger

            self.base_url = base_url
            self.wiki_version = wiki_version

            self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
            self.mention_detection = MentionDetection(base_url, wiki_version)

            super().__init__(*args, **kwargs)

        def do_GET(self):
            self.send_response(200)
            self.end_headers()
            self.wfile.write(
                bytes(
                    json.dumps({
                        "schemaVersion": 1,
                        "label": "status",
                        "message": "up",
                        "color": "green",
                    }),
                    "utf-8",
                ))
            return

        def do_POST(self):
            """
            Returns response.

            :return:
            """
            content_length = int(self.headers["Content-Length"])
            print(content_length)
            post_data = self.rfile.read(content_length)
            self.send_response(200)
            self.end_headers()

            text, spans = self.read_json(post_data)
            response = self.generate_response(text, spans)

            print(response)
            print("=========")

            # print('response in server.py code:\n\n {}'.format(response))
            self.wfile.write(bytes(json.dumps(response), "utf-8"))
            return

        def read_json(self, post_data):
            """
            Reads input JSON message.

            :return: document text and spans.
            """

            data = json.loads(post_data.decode("utf-8"))
            text = data["text"]
            text = text.replace("&amp;", "&")

            # GERBIL sends dictionary, users send list of lists.
            try:
                spans = [list(d.values()) for d in data["spans"]]
            except Exception:
                spans = data["spans"]
                pass

            return text, spans

        def generate_response(self, text, spans):
            """
            Generates response for API. Can be either ED only or EL, meaning end-to-end.

            :return: list of tuples for each entity found.
            """

            if len(text) == 0:
                return []

            if len(spans) > 0:
                # ED.
                processed = {API_DOC: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.format_spans(
                    processed)
            else:
                # EL
                processed = {API_DOC: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.find_mentions(
                    processed, self.tagger_ner)

            # Create to-be-linked dataset.
            data_to_link = []
            temp_m = mentions_dataset[API_DOC]
            for i, m in enumerate(temp_m):
                # Using ngram, which is basically the original mention (without preprocessing as in BLINK's code).
                temp = {
                    "id": i,
                    "label": "unknown",
                    "label_id": -1,
                    "context_left": m["context"][0].lower(),
                    "mention": m["ngram"].lower(),
                    "context_right": m["context"][1].lower(),
                }
                data_to_link.append(temp)
            _, _, _, _, _, predictions, scores = main_dense.run(
                self.argss, self.logger, *self.model, test_data=data_to_link)

            predictions = {
                API_DOC: [{
                    "prediction": x[0].replace(" ", "_")
                } for x in predictions]
            }
            # Process result.
            result = process_results(
                mentions_dataset,
                predictions,
                processed,
                include_offset=not (len(spans) > 0 or self.custom_ner),
            )

            # Singular document.
            if len(result) > 0:
                return [*result.values()][0]

            return []
Example No. 6
    class GetHandler(BaseHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            self.model = model
            self.tagger_ner = tagger_ner

            self.mode = mode
            self.include_conf = include_conf
            self.base_url = base_url
            self.wiki_subfolder = wiki_subfolder

            self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
            self.use_offset = not (mode == "ED" or self.custom_ner)

            self.doc_cnt = 0
            self.mention_detection = MentionDetection(base_url, wiki_subfolder)

            super().__init__(*args, **kwargs)

        # TODO: lowercase POST
        def do_POST(self):
            """
            Returns response.

            :return:
            """
            content_length = int(self.headers["Content-Length"])
            post_data = self.rfile.read(content_length)
            self.send_response(200)
            self.end_headers()

            text, spans = self.read_json(post_data)
            response = self.generate_response(text, spans)

            # print('response in server.py code:\n\n {}'.format(response))
            self.wfile.write(bytes(json.dumps(response), "utf-8"))
            return

        def read_json(self, post_data):
            """
            Reads input JSON message.

            :return: document text and spans.
            """

            data = json.loads(post_data.decode("utf-8"))
            text = data["text"]
            text = html.unescape(text)
            spans = [(int(j["start"]), int(j["length"]))
                     for j in data["spans"]]
            return text, spans

        def generate_response(self, text, spans):
            """
            Generates response for API. Can be either ED only or EL, meaning end-to-end.

            :return: list of tuples for each entity found.
            """

            n_words = len(text.split())

            if len(text) == 0:
                return []

            start = time.time()
            if (self.mode == "ED") or self.custom_ner:
                if self.custom_ner:
                    spans = self.tagger_ner(text)

                # Verify if we have spans.
                if len(spans) == 0:
                    print("No spans found while in ED mode..?")
                    return []
                processed = {
                    GERBIL: [text, spans]
                }  # self.split_text(text, spans)
                mentions_dataset, total_ment = self.mention_detection.format_spans(
                    processed)
            elif self.mode == "EL":
                # EL
                processed = {GERBIL: [text, spans]}
                mentions_dataset, total_ment = self.mention_detection.find_mentions(
                    processed, self.tagger_ner)
            else:
                raise Exception(
                    "Faulty mode, only valid options are: ED or EL")
            time_md = time.time() - start

            # Disambiguation
            start = time.time()
            predictions, timing = self.model.predict(mentions_dataset)
            time_ed = time.time() - start

            # Efficiency statistics: word count, mention count, MD time, ED time.
            efficiency = [
                str(n_words),
                str(total_ment),
                str(time_md),
                str(time_ed)
            ]

            # write to txt file.
            with (self.base_url / self.wiki_subfolder /
                  "generated/efficiency.txt").open('a', encoding='utf-8') as f:
                f.write('\t'.join(efficiency) + '\n')

            # Process result.
            result = process_results(
                mentions_dataset,
                predictions,
                processed,
                include_offset=self.use_offset,
                include_conf=self.include_conf,
            )

            self.doc_cnt += 1
            return result
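A minimal sketch of a GERBIL-style request matching the read_json above, where every span is an object with a start offset and a length; the endpoint URL is hypothetical.

import requests

API_URL = "http://localhost:1235"  # hypothetical address; adjust to where the handler runs

payload = {
    "text": "Obama will visit Germany.",
    # Spans use the {"start": ..., "length": ...} form parsed by read_json above.
    "spans": [{"start": 0, "length": 5}, {"start": 17, "length": 7}],
}

response = requests.post(API_URL, json=payload)
print(response.json())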
Example No. 7
# Imports assumed by this snippet (the example script also defines
# example_preprocessing(), as sketched under Example No. 3 above).
from pathlib import Path

from flair.models import SequenceTagger

from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection
from REL.utils import process_results

base_url = Path("")  # point this at your downloaded REL data folder
wiki_subfolder = "wiki_2019"

# 1. Input sentences when using Flair.
input_documents = example_preprocessing()

# For mention detection there are two options.
# 2. Mention detection: we use the NER tagger, but users can also plug in their own mention detection module.
mention_detection = MentionDetection(base_url, wiki_subfolder)

# If you want to use your own MD system, the required input is: {doc_name: [text, spans], ...}.
mentions_dataset, n_mentions = mention_detection.format_spans(input_documents)

# Alternatively use Flair NER tagger.
tagger_ner = SequenceTagger.load("ner-fast")
mentions_dataset, n_mentions = mention_detection.find_mentions(
    input_documents, tagger_ner)

# 3. Load model.
config = {
    "mode": "eval",
    "model_path": base_url / wiki_subfolder / "generated" / "model",
}
model = EntityDisambiguation(base_url, wiki_subfolder, config)

# 4. Entity disambiguation.
predictions, timing = model.predict(mentions_dataset)

# 5. Optionally use our function to get results in a usable format.
result = process_results(mentions_dataset,
                         predictions,
                         input_documents)