from pathlib import Path

# REL imports; exact module paths may differ slightly between REL versions.
from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection
from REL.ner import Cmns
from REL.utils import process_results


def test_pipeline():
    base_url = Path(__file__).parent
    wiki_subfolder = "wiki_test"

    # One document with a single gold span: offset 10, length 3 ("fox").
    sample = {"test_doc": ["the brown fox jumped over the lazy dog", [[10, 3]]]}

    config = {
        "mode": "eval",
        "model_path": f"{base_url}/{wiki_subfolder}/generated/model",
    }

    md = MentionDetection(base_url, wiki_subfolder)
    tagger = Cmns(base_url, wiki_subfolder, n=5)
    model = EntityDisambiguation(base_url, wiki_subfolder, config)

    # Format the provided spans, disambiguate them and post-process the predictions.
    mentions_dataset, total_mentions = md.format_spans(sample)
    predictions, _ = model.predict(mentions_dataset)
    results = process_results(mentions_dataset, predictions, sample, include_offset=False)

    gold_truth = {"test_doc": [(10, 3, "Fox", "fox", -1, "NULL", 0.0)]}

    return results == gold_truth
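# Minimal usage sketch (not part of the original test): run the check directly, assuming
# the local wiki_test resources and the trained ED model referenced above are in place.
if __name__ == "__main__":
    assert test_pipeline(), "pipeline results did not match the gold truth"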
    return processed


base_url = "/users/vanhulsm/Desktop/projects/data/"
wiki_version = "wiki_2014"

# 1. Input sentences when using Flair.
input_text = example_preprocessing()

# 2. Mention detection. Two options: we use the Flair NER tagger, but users can also plug
# in their own mention detection module.
mention_detection = MentionDetection(base_url, wiki_version)

# 2a. If you want to use your own MD system (or n-gram detection),
# the required input is: {doc_name: [text, spans], ...}.
mentions_dataset, n_mentions = mention_detection.format_spans(input_text)

# 2b. Alternatively, use the Flair NER tagger (or the n-gram tagger Cmns).
tagger_ner = load_flair_ner("ner-fast")
# tagger_ngram = Cmns(base_url, wiki_version, n=5)
mentions_dataset, n_mentions = mention_detection.find_mentions(input_text, tagger_ner)

# 3. Load model.
config = {
    "mode": "eval",
    "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
}
model = EntityDisambiguation(base_url, wiki_version, config)
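# 4. Entity disambiguation (a sketch; this step is not shown above, but it mirrors the
# wiki_2019 example further down). predict() disambiguates the detected mentions and
# process_results() maps the predictions back onto each document as result tuples.
# include_offset=True is assumed here because the mentions come from the Flair tagger
# rather than from user-supplied spans.
predictions, timing = model.predict(mentions_dataset)
results = process_results(mentions_dataset, predictions, input_text, include_offset=True)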
class GetHandler(BaseHTTPRequestHandler):
    def __init__(self, *args, **kwargs):
        self.model = model
        self.tagger_ner = tagger_ner

        self.base_url = base_url
        self.wiki_version = wiki_version
        self.custom_ner = not isinstance(tagger_ner, SequenceTagger)

        self.mention_detection = MentionDetection(base_url, wiki_version)

        super().__init__(*args, **kwargs)

    def do_GET(self):
        self.send_response(200)
        self.end_headers()
        self.wfile.write(
            bytes(
                json.dumps(
                    {
                        "schemaVersion": 1,
                        "label": "status",
                        "message": "up",
                        "color": "green",
                    }
                ),
                "utf-8",
            )
        )
        return

    def do_HEAD(self):
        # Send bad request response code.
        self.send_response(400)
        self.end_headers()
        self.wfile.write(bytes(json.dumps([]), "utf-8"))
        return

    def do_POST(self):
        """
        Returns response.

        :return:
        """
        try:
            content_length = int(self.headers["Content-Length"])
            post_data = self.rfile.read(content_length)
            self.send_response(200)
            self.end_headers()

            text, spans = self.read_json(post_data)
            response = self.generate_response(text, spans)

            # print('response in server.py code:\n\n {}'.format(response))
            self.wfile.write(bytes(json.dumps(response), "utf-8"))
        except Exception as e:
            print(f"Encountered exception: {repr(e)}")
            self.send_response(400)
            self.end_headers()
            self.wfile.write(bytes(json.dumps([]), "utf-8"))
        return

    def read_json(self, post_data):
        """
        Reads input JSON message.

        :return: document text and spans.
        """
        data = json.loads(post_data.decode("utf-8"))
        text = data["text"]
        text = text.replace("&amp;", "&")

        # GERBIL sends a dictionary per span, users send a list of lists.
        if "spans" in data:
            try:
                spans = [list(d.values()) for d in data["spans"]]
            except Exception:
                spans = data["spans"]
        else:
            spans = []

        return text, spans

    def generate_response(self, text, spans):
        """
        Generates response for API. Can be either ED only or EL, meaning end-to-end.

        :return: list of tuples for each entity found.
        """
        if len(text) == 0:
            return []

        if len(spans) > 0:
            # ED.
            processed = {API_DOC: [text, spans]}
            mentions_dataset, total_ment = self.mention_detection.format_spans(processed)
        else:
            # EL.
            processed = {API_DOC: [text, spans]}
            mentions_dataset, total_ment = self.mention_detection.find_mentions(
                processed, self.tagger_ner
            )

        # Disambiguation.
        predictions, timing = self.model.predict(mentions_dataset)

        # Process result.
        result = process_results(
            mentions_dataset,
            predictions,
            processed,
            include_offset=False if ((len(spans) > 0) or self.custom_ner) else True,
        )

        # Singular document.
        if len(result) > 0:
            return [*result.values()][0]

        return []
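# Minimal usage sketch for the handler above (not part of the original file). It assumes the
# module-level `model` and `tagger_ner` referenced in __init__ already exist; the host and
# port below are illustrative placeholders.
from http.server import HTTPServer

server = HTTPServer(("localhost", 1235), GetHandler)
print("Serving REL API, send POST requests with a JSON body such as:")
#   {"text": "Obama will visit Germany.", "spans": []}         -> end-to-end EL
#   {"text": "Obama will visit Germany.", "spans": [[0, 5]]}   -> ED on the given spans
server.serve_forever()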
class GetHandler(BaseHTTPRequestHandler):
    def __init__(self, *args, **kwargs):
        self.model = models
        self.tagger_ner = tagger_ner

        self.argss = argss
        self.logger = logger

        self.base_url = base_url
        self.wiki_version = wiki_version
        self.custom_ner = not isinstance(tagger_ner, SequenceTagger)

        self.mention_detection = MentionDetection(base_url, wiki_version)

        super().__init__(*args, **kwargs)

    def do_GET(self):
        self.send_response(200)
        self.end_headers()
        self.wfile.write(
            bytes(
                json.dumps(
                    {
                        "schemaVersion": 1,
                        "label": "status",
                        "message": "up",
                        "color": "green",
                    }
                ),
                "utf-8",
            )
        )
        return

    def do_POST(self):
        """
        Returns response.

        :return:
        """
        content_length = int(self.headers["Content-Length"])
        print(content_length)
        post_data = self.rfile.read(content_length)
        self.send_response(200)
        self.end_headers()

        text, spans = self.read_json(post_data)
        response = self.generate_response(text, spans)
        print(response)
        print("=========")

        # print('response in server.py code:\n\n {}'.format(response))
        self.wfile.write(bytes(json.dumps(response), "utf-8"))
        return

    def read_json(self, post_data):
        """
        Reads input JSON message.

        :return: document text and spans.
        """
        data = json.loads(post_data.decode("utf-8"))
        text = data["text"]
        text = text.replace("&amp;", "&")

        # GERBIL sends a dictionary per span, users send a list of lists.
        try:
            spans = [list(d.values()) for d in data["spans"]]
        except Exception:
            spans = data["spans"]

        return text, spans

    def generate_response(self, text, spans):
        """
        Generates response for API. Can be either ED only or EL, meaning end-to-end.

        :return: list of tuples for each entity found.
        """
        if len(text) == 0:
            return []

        if len(spans) > 0:
            # ED.
            processed = {API_DOC: [text, spans]}
            mentions_dataset, total_ment = self.mention_detection.format_spans(processed)
        else:
            # EL.
            processed = {API_DOC: [text, spans]}
            mentions_dataset, total_ment = self.mention_detection.find_mentions(
                processed, self.tagger_ner
            )

        # Create the to-be-linked dataset for BLINK.
        data_to_link = []
        temp_m = mentions_dataset[API_DOC]
        for i, m in enumerate(temp_m):
            # Use the ngram, i.e. the original mention (without the preprocessing done in BLINK's code).
            temp = {
                "id": i,
                "label": "unknown",
                "label_id": -1,
                "context_left": m["context"][0].lower(),
                "mention": m["ngram"].lower(),
                "context_right": m["context"][1].lower(),
            }
            data_to_link.append(temp)

        _, _, _, _, _, predictions, scores = main_dense.run(
            self.argss, self.logger, *self.model, test_data=data_to_link
        )

        predictions = {
            API_DOC: [{"prediction": x[0].replace(" ", "_")} for x in predictions]
        }

        # Process result.
        result = process_results(
            mentions_dataset,
            predictions,
            processed,
            include_offset=False if ((len(spans) > 0) or self.custom_ner) else True,
        )

        # Singular document.
        if len(result) > 0:
            return [*result.values()][0]

        return []
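# Illustration (assumed values) of a single record in `data_to_link`, the input format the
# handler above builds for BLINK: lowercased left/right context around the detected mention.
# For the sentence "Obama will visit Germany." with the mention "Germany" this would look
# roughly like:
example_record = {
    "id": 0,
    "label": "unknown",
    "label_id": -1,
    "context_left": "obama will visit",
    "mention": "germany",
    "context_right": ".",
}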
class GetHandler(BaseHTTPRequestHandler):
    def __init__(self, *args, **kwargs):
        self.model = model
        self.tagger_ner = tagger_ner

        self.mode = mode
        self.include_conf = include_conf

        self.base_url = base_url
        self.wiki_subfolder = wiki_subfolder
        self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
        self.use_offset = False if ((mode == "ED") or self.custom_ner) else True

        self.doc_cnt = 0
        self.mention_detection = MentionDetection(base_url, wiki_subfolder)

        super().__init__(*args, **kwargs)

    # TODO: lowercase POST
    def do_POST(self):
        """
        Returns response.

        :return:
        """
        content_length = int(self.headers["Content-Length"])
        post_data = self.rfile.read(content_length)
        self.send_response(200)
        self.end_headers()

        text, spans = self.read_json(post_data)
        response = self.generate_response(text, spans)

        # print('response in server.py code:\n\n {}'.format(response))
        self.wfile.write(bytes(json.dumps(response), "utf-8"))
        return

    def read_json(self, post_data):
        """
        Reads input JSON message.

        :return: document text and spans.
        """
        data = json.loads(post_data.decode("utf-8"))
        text = data["text"]
        text = html.unescape(text)
        spans = [(int(j["start"]), int(j["length"])) for j in data["spans"]]
        return text, spans

    def generate_response(self, text, spans):
        """
        Generates response for API. Can be either ED only or EL, meaning end-to-end.

        :return: list of tuples for each entity found.
        """
        n_words = len(text.split())
        if len(text) == 0:
            return []

        start = time.time()
        if (self.mode == "ED") or self.custom_ner:
            if self.custom_ner:
                spans = self.tagger_ner(text)

            # Verify that we have spans.
            if len(spans) == 0:
                print("No spans found while in ED mode..?")
                return []
            processed = {GERBIL: [text, spans]}  # self.split_text(text, spans)
            mentions_dataset, total_ment = self.mention_detection.format_spans(processed)
        elif self.mode == "EL":
            # EL.
            processed = {GERBIL: [text, spans]}
            mentions_dataset, total_ment = self.mention_detection.find_mentions(
                processed, self.tagger_ner
            )
        else:
            raise Exception("Faulty mode, only valid options are: ED or EL")
        time_md = time.time() - start

        # Disambiguation.
        start = time.time()
        predictions, timing = self.model.predict(mentions_dataset)
        time_ed = time.time() - start

        # Per-document efficiency statistics.
        efficiency = [str(n_words), str(total_ment), str(time_md), str(time_ed)]

        # Append to the efficiency log.
        with (self.base_url / self.wiki_subfolder / "generated/efficiency.txt").open(
            "a", encoding="utf-8"
        ) as f:
            f.write("\t".join(efficiency) + "\n")

        # Process result.
        result = process_results(
            mentions_dataset,
            predictions,
            processed,
            include_offset=self.use_offset,
            include_conf=self.include_conf,
        )

        self.doc_cnt += 1
        return result
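# Example (assumed values) of the GERBIL-style payload this handler parses in read_json:
# spans arrive as dictionaries with "start" and "length" keys, and the text may contain
# HTML entities, which html.unescape() resolves.
example_payload = {
    "text": "Schumacher won the race in Belgium &amp; celebrated.",
    "spans": [{"start": 0, "length": 10}, {"start": 27, "length": 7}],
}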
processed = {"test_doc": [text, spans], "test_doc2": [text, spans]} return processed base_url = Path("") wiki_subfolder = "wiki_2019" # 1. Input sentences when using Flair. input_documents = example_preprocessing() # For Mention detection two options. # 2. Mention detection, we used the NER tagger, user can also use his/her own mention detection module. mention_detection = MentionDetection(base_url, wiki_subfolder) # If you want to use your own MD system, the required input is: {doc_name: [text, spans] ... }. mentions_dataset, n_mentions = mention_detection.format_spans(input_documents) # Alternatively use Flair NER tagger. tagger_ner = SequenceTagger.load("ner-fast") mentions_dataset, n_mentions = mention_detection.find_mentions( input_documents, tagger_ner) # 3. Load model. config = { "mode": "eval", "model_path": base_url / wiki_subfolder / "generated" / "model", } model = EntityDisambiguation(base_url, wiki_subfolder, config) # 4. Entity disambiguation. predictions, timing = model.predict(mentions_dataset)