def test_pipeline():
    base_url = Path(__file__).parent
    wiki_subfolder = "wiki_test"
    sample = {"test_doc": ["the brown fox jumped over the lazy dog", [[10, 3]]]}
    config = {
        "mode": "eval",
        "model_path": f"{base_url}/{wiki_subfolder}/generated/model",
    }

    md = MentionDetection(base_url, wiki_subfolder)
    tagger = Cmns(base_url, wiki_subfolder, n=5)
    model = EntityDisambiguation(base_url, wiki_subfolder, config)

    mentions_dataset, total_mentions = md.format_spans(sample)
    predictions, _ = model.predict(mentions_dataset)
    results = process_results(mentions_dataset, predictions, sample, include_offset=False)

    gold_truth = {"test_doc": [(10, 3, "Fox", "fox", -1, "NULL", 0.0)]}
    return results == gold_truth
def __init__(self, *args, **kwargs):
    self.model = model
    self.tagger_ner = tagger_ner

    self.base_url = base_url
    self.wiki_version = wiki_version
    self.custom_ner = not isinstance(tagger_ner, SequenceTagger)

    self.mention_detection = MentionDetection(base_url, wiki_version)

    super().__init__(*args, **kwargs)
def __init__(self, *args, **kwargs):
    self.model = model
    self.tagger_ner = tagger_ner
    self.mode = mode
    self.include_conf = include_conf

    self.base_url = base_url
    self.wiki_subfolder = wiki_subfolder
    self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
    self.use_offset = False if ((mode == "ED") or self.custom_ner) else True

    self.doc_cnt = 0
    self.mention_detection = MentionDetection(base_url, wiki_subfolder)

    super().__init__(*args, **kwargs)
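# Note: the two __init__ fragments above reference names (model, tagger_ner,
# base_url, wiki_version / wiki_subfolder, mode, include_conf) that are not
# parameters, so they only work when the handler class is defined inside an
# enclosing scope that supplies them. A minimal sketch of such a factory is
# given below; the name make_handler and the host/port are assumptions, not
# part of the original code.
from http.server import BaseHTTPRequestHandler, HTTPServer

from flair.models import SequenceTagger
from REL.mention_detection import MentionDetection


def make_handler(model, tagger_ner, base_url, wiki_version):
    """Hypothetical factory: binds the disambiguation model, NER tagger and data
    paths as closure variables so the handler's __init__ can read them."""

    class GetHandler(BaseHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            self.model = model
            self.tagger_ner = tagger_ner
            self.base_url = base_url
            self.wiki_version = wiki_version
            self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
            self.mention_detection = MentionDetection(base_url, wiki_version)
            super().__init__(*args, **kwargs)

    return GetHandler


# Usage sketch (host/port are placeholders):
# HTTPServer(("localhost", 5555), make_handler(model, tagger_ner, base_url, "wiki_2019")).serve_forever()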
def test_md():
    # return standard Flair tagger + mention detection object
    tagger = SequenceTagger.load("ner-fast")
    md = MentionDetection(Path(__file__).parent, "wiki_test")

    # first test case: repeating sentences
    sample1 = {"test_doc": ["Fox, Fox. Fox.", []]}
    resulting_spans1 = {(0, 3), (5, 3), (10, 3)}
    predictions = md.find_mentions(sample1, tagger)
    predicted_spans = {
        (m["pos"], m["end_pos"] - m["pos"]) for m in predictions[0]["test_doc"]
    }
    assert resulting_spans1 == predicted_spans

    # second test case: excessive whitespace between sentences
    # (the padding keeps the expected character offsets at 20 and 43)
    sample2 = {"test_doc": ["Fox." + " " * 16 + "Fox." + " " * 19 + "Fox.", []]}
    resulting_spans2 = {(0, 3), (20, 3), (43, 3)}
    predictions = md.find_mentions(sample2, tagger)
    predicted_spans = {
        (m["pos"], m["end_pos"] - m["pos"]) for m in predictions[0]["test_doc"]
    }
    assert resulting_spans2 == predicted_spans
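# The tests above only inspect "pos" and "end_pos"; the BLINK handler further
# down also reads "ngram" and "context" from the same mention entries. A rough
# sketch of one entry, limited to the fields referenced in this file (other
# fields may exist and the example values are illustrative only):
mention = {
    "pos": 0,                        # character offset where the mention starts
    "end_pos": 3,                    # character offset where the mention ends
    "ngram": "Fox",                  # surface form of the detected mention
    "context": ("", ", Fox. Fox."),  # (left context, right context) strings
}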
def __init_rel_models(self, rel_wiki_year, rel_base_url, rel_model_path, car_id_to_name_path):
    """ Initialise the REL mention detection, NER tagger and entity
    disambiguation models from the given data paths. """
    # Will require models and data saved paths
    wiki_version = "wiki_" + rel_wiki_year

    self.mention_detection = MentionDetection(rel_base_url, wiki_version)
    self.tagger_ner = load_flair_ner("ner-fast")

    config = {
        "mode": "eval",
        "model_path": rel_model_path,
    }
    self.entity_disambiguation = EntityDisambiguation(rel_base_url, wiki_version, config)

    self.car_id_to_name_path = car_id_to_name_path
# ------------- RUN SEPARATELY TO BALANCE LOAD --------------------
if not server:
    import flair
    import torch
    from flair.models import SequenceTagger

    from REL.mention_detection import MentionDetection
    from REL.entity_disambiguation import EntityDisambiguation

    from time import time

    base_url = "C:/Users/mickv/desktop/data_back/"

    flair.device = torch.device('cuda:0')

    mention_detection = MentionDetection(base_url, wiki_version)

    # Alternatively use Flair NER tagger.
    tagger_ner = SequenceTagger.load("ner-fast")

    start = time()
    mentions_dataset, n_mentions = mention_detection.find_mentions(docs, tagger_ner)
    print('MD took: {}'.format(time() - start))

    # 3. Load model.
    config = {
        "mode": "eval",
        "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
    }
    model = EntityDisambiguation(base_url, wiki_version, config)
text = """Obama will visit Germany. And have a meeting with Merkel tomorrow. Obama will visit Germany. And have a meeting with Merkel tomorrow. Go all the way or blah blah Charles Bukowski.""" spans = [] # [(0, 5), (17, 7), (50, 6)] processed = {"test_doc": [text, spans], "test_doc2": [text, spans]} return processed base_url = "/users/vanhulsm/Desktop/projects/data/" wiki_version = "wiki_2014" # 1. Input sentences when using Flair. input_text = example_preprocessing() # For Mention detection two options. # 2. Mention detection, we used the NER tagger, user can also use his/her own mention detection module. mention_detection = MentionDetection(base_url, wiki_version) # 2. Alternatively. if you want to use your own MD system (or ngram detection), # the required input is: {doc_name: [text, spans] ... }. mentions_dataset, n_mentions = mention_detection.format_spans(input_text) # 2. Alternative MD module is using an n-gram tagger. tagger_ner = load_flair_ner("ner-fast") # tagger_ngram = Cmns(base_url, wiki_version, n=5) mentions_dataset, n_mentions = mention_detection.find_mentions( input_text, tagger_ner) # 3. Load model. config = { "mode": "eval",
class GetHandler(BaseHTTPRequestHandler):
    def __init__(self, *args, **kwargs):
        self.model = model
        self.tagger_ner = tagger_ner

        self.base_url = base_url
        self.wiki_version = wiki_version
        self.custom_ner = not isinstance(tagger_ner, SequenceTagger)

        self.mention_detection = MentionDetection(base_url, wiki_version)

        super().__init__(*args, **kwargs)

    def do_GET(self):
        self.send_response(200)
        self.end_headers()
        self.wfile.write(
            bytes(
                json.dumps({
                    "schemaVersion": 1,
                    "label": "status",
                    "message": "up",
                    "color": "green",
                }),
                "utf-8",
            ))
        return

    def do_HEAD(self):
        # send bad request response code
        self.send_response(400)
        self.end_headers()
        self.wfile.write(bytes(json.dumps([]), "utf-8"))
        return

    def do_POST(self):
        """
        Returns response.

        :return:
        """
        try:
            content_length = int(self.headers["Content-Length"])
            post_data = self.rfile.read(content_length)
            self.send_response(200)
            self.end_headers()

            text, spans = self.read_json(post_data)
            response = self.generate_response(text, spans)

            # print('response in server.py code:\n\n {}'.format(response))
            self.wfile.write(bytes(json.dumps(response), "utf-8"))
        except Exception as e:
            print(f"Encountered exception: {repr(e)}")
            self.send_response(400)
            self.end_headers()
            self.wfile.write(bytes(json.dumps([]), "utf-8"))
        return

    def read_json(self, post_data):
        """
        Reads input JSON message.

        :return: document text and spans.
        """
        data = json.loads(post_data.decode("utf-8"))
        text = data["text"]
        text = text.replace("&amp;", "&")

        # GERBIL sends a dictionary, users send a list of lists.
        if "spans" in data:
            try:
                spans = [list(d.values()) for d in data["spans"]]
            except Exception:
                spans = data["spans"]
        else:
            spans = []

        return text, spans

    def generate_response(self, text, spans):
        """
        Generates response for API. Can be either ED only or EL, meaning end-to-end.

        :return: list of tuples for each entity found.
        """
        if len(text) == 0:
            return []

        if len(spans) > 0:
            # ED.
            processed = {API_DOC: [text, spans]}
            mentions_dataset, total_ment = self.mention_detection.format_spans(processed)
        else:
            # EL
            processed = {API_DOC: [text, spans]}
            mentions_dataset, total_ment = self.mention_detection.find_mentions(
                processed, self.tagger_ner)

        # Disambiguation
        predictions, timing = self.model.predict(mentions_dataset)

        # Process result.
        result = process_results(
            mentions_dataset,
            predictions,
            processed,
            include_offset=False if ((len(spans) > 0) or self.custom_ner) else True,
        )

        # Singular document.
        if len(result) > 0:
            return [*result.values()][0]
        return []
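# A minimal client call against the handler above might look like this; the
# host/port are placeholders, and only the "text"/"spans" fields are taken from
# read_json.
import requests  # any HTTP client works; requests is only an example

# Leaving "spans" empty triggers end-to-end linking (EL); supplying
# [start, length] pairs triggers disambiguation only (ED).
payload = {
    "text": "Obama will visit Germany. And have a meeting with Merkel tomorrow.",
    "spans": [],  # e.g. [[0, 5], [17, 7], [50, 6]] for ED
}

response = requests.post("http://localhost:5555", json=payload)
print(response.json())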
class GetHandler(BaseHTTPRequestHandler):
    def __init__(self, *args, **kwargs):
        self.model = models
        self.tagger_ner = tagger_ner
        self.argss = argss
        self.logger = logger

        self.base_url = base_url
        self.wiki_version = wiki_version
        self.custom_ner = not isinstance(tagger_ner, SequenceTagger)

        self.mention_detection = MentionDetection(base_url, wiki_version)

        super().__init__(*args, **kwargs)

    def do_GET(self):
        self.send_response(200)
        self.end_headers()
        self.wfile.write(
            bytes(
                json.dumps({
                    "schemaVersion": 1,
                    "label": "status",
                    "message": "up",
                    "color": "green",
                }),
                "utf-8",
            ))
        return

    def do_POST(self):
        """
        Returns response.

        :return:
        """
        content_length = int(self.headers["Content-Length"])
        print(content_length)
        post_data = self.rfile.read(content_length)
        self.send_response(200)
        self.end_headers()

        text, spans = self.read_json(post_data)
        response = self.generate_response(text, spans)
        print(response)
        print("=========")

        # print('response in server.py code:\n\n {}'.format(response))
        self.wfile.write(bytes(json.dumps(response), "utf-8"))
        return

    def read_json(self, post_data):
        """
        Reads input JSON message.

        :return: document text and spans.
        """
        data = json.loads(post_data.decode("utf-8"))
        text = data["text"]
        text = text.replace("&amp;", "&")

        # GERBIL sends a dictionary, users send a list of lists.
        try:
            spans = [list(d.values()) for d in data["spans"]]
        except Exception:
            spans = data["spans"]

        return text, spans

    def generate_response(self, text, spans):
        """
        Generates response for API. Can be either ED only or EL, meaning end-to-end.

        :return: list of tuples for each entity found.
        """
        if len(text) == 0:
            return []

        if len(spans) > 0:
            # ED.
            processed = {API_DOC: [text, spans]}
            mentions_dataset, total_ment = self.mention_detection.format_spans(processed)
        else:
            # EL
            processed = {API_DOC: [text, spans]}
            mentions_dataset, total_ment = self.mention_detection.find_mentions(
                processed, self.tagger_ner)

        # Create to-be-linked dataset.
        data_to_link = []
        temp_m = mentions_dataset[API_DOC]
        for i, m in enumerate(temp_m):
            # Use the ngram, i.e. the original mention (without the preprocessing done in BLINK's code).
            temp = {
                "id": i,
                "label": "unknown",
                "label_id": -1,
                "context_left": m["context"][0].lower(),
                "mention": m["ngram"].lower(),
                "context_right": m["context"][1].lower(),
            }
            data_to_link.append(temp)

        _, _, _, _, _, predictions, scores, = main_dense.run(
            self.argss, self.logger, *self.model, test_data=data_to_link)

        predictions = {
            API_DOC: [{"prediction": x[0].replace(" ", "_")} for x in predictions]
        }

        # Process result.
        result = process_results(
            mentions_dataset,
            predictions,
            processed,
            include_offset=False if ((len(spans) > 0) or self.custom_ner) else True,
        )

        # Singular document.
        if len(result) > 0:
            return [*result.values()][0]
        return []
def test_mention_detection_instantiation():
    return MentionDetection(Path(__file__).parent, "wiki_test")
class GetHandler(BaseHTTPRequestHandler):
    def __init__(self, *args, **kwargs):
        self.model = model
        self.tagger_ner = tagger_ner
        self.mode = mode
        self.include_conf = include_conf

        self.base_url = base_url
        self.wiki_subfolder = wiki_subfolder
        self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
        self.use_offset = False if ((mode == "ED") or self.custom_ner) else True

        self.doc_cnt = 0
        self.mention_detection = MentionDetection(base_url, wiki_subfolder)

        super().__init__(*args, **kwargs)

    # TODO: lowercase POST
    def do_POST(self):
        """
        Returns response.

        :return:
        """
        content_length = int(self.headers["Content-Length"])
        post_data = self.rfile.read(content_length)
        self.send_response(200)
        self.end_headers()

        text, spans = self.read_json(post_data)
        response = self.generate_response(text, spans)

        # print('response in server.py code:\n\n {}'.format(response))
        self.wfile.write(bytes(json.dumps(response), "utf-8"))
        return

    def read_json(self, post_data):
        """
        Reads input JSON message.

        :return: document text and spans.
        """
        data = json.loads(post_data.decode("utf-8"))
        text = data["text"]
        text = html.unescape(text)
        spans = [(int(j["start"]), int(j["length"])) for j in data["spans"]]
        return text, spans

    def generate_response(self, text, spans):
        """
        Generates response for API. Can be either ED only or EL, meaning end-to-end.

        :return: list of tuples for each entity found.
        """
        n_words = len(text.split())
        if len(text) == 0:
            return []

        start = time.time()
        if (self.mode == "ED") or self.custom_ner:
            if self.custom_ner:
                spans = self.tagger_ner(text)

            # Verify if we have spans.
            if len(spans) == 0:
                print("No spans found while in ED mode..?")
                return []

            processed = {GERBIL: [text, spans]}  # self.split_text(text, spans)
            mentions_dataset, total_ment = self.mention_detection.format_spans(processed)
        elif self.mode == "EL":
            # EL
            processed = {GERBIL: [text, spans]}
            mentions_dataset, total_ment = self.mention_detection.find_mentions(
                processed, self.tagger_ner)
        else:
            raise Exception("Faulty mode, only valid options are: ED or EL")
        time_md = time.time() - start

        # Disambiguation
        start = time.time()
        predictions, timing = self.model.predict(mentions_dataset)
        time_ed = time.time() - start

        # Tuple of efficiency statistics.
        efficiency = [str(n_words), str(total_ment), str(time_md), str(time_ed)]

        # Write to txt file.
        with (self.base_url / self.wiki_subfolder / "generated/efficiency.txt").open(
                'a', encoding='utf-8') as f:
            f.write('\t'.join(efficiency) + '\n')

        # Process result.
        result = process_results(
            mentions_dataset,
            predictions,
            processed,
            include_offset=self.use_offset,
            include_conf=self.include_conf,
        )

        self.doc_cnt += 1
        return result
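# Unlike the earlier handler, this read_json only accepts GERBIL-style span
# objects with "start" and "length" keys. A hypothetical request body
# (host/port are placeholders):
import requests  # any HTTP client works; requests is only an example

payload = {
    "text": "Obama will visit Germany.",
    "spans": [{"start": 0, "length": 5}, {"start": 17, "length": 7}],
}
print(requests.post("http://localhost:1235", json=payload).json())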
# Example preprocessing: the expected input format is {doc_name: [text, spans], ...}.
def example_preprocessing():
    text = "Obama will visit Germany. And have a meeting with Merkel tomorrow."
    spans = [(0, 5), (17, 7), (50, 6)]
    processed = {"test_doc": [text, spans], "test_doc2": [text, spans]}
    return processed


base_url = Path("")
wiki_subfolder = "wiki_2019"

# 1. Input sentences when using Flair.
input_documents = example_preprocessing()

# For mention detection there are two options.
# 2. Mention detection: we use the NER tagger, but users can also plug in their
#    own mention detection module.
mention_detection = MentionDetection(base_url, wiki_subfolder)

# If you want to use your own MD system, the required input is: {doc_name: [text, spans], ...}.
mentions_dataset, n_mentions = mention_detection.format_spans(input_documents)

# Alternatively use Flair NER tagger.
tagger_ner = SequenceTagger.load("ner-fast")
mentions_dataset, n_mentions = mention_detection.find_mentions(input_documents, tagger_ner)

# 3. Load model.
config = {
    "mode": "eval",
    "model_path": base_url / wiki_subfolder / "generated" / "model",
}
model = EntityDisambiguation(base_url, wiki_subfolder, config)
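# The example above stops after loading the model. A sketch of the remaining
# steps, mirroring the predict/process_results calls used in the pipeline test
# and the server handlers earlier in this file (the import path for
# process_results is an assumption):
from REL.utils import process_results

# 4. Entity disambiguation.
predictions, timing = model.predict(mentions_dataset)

# 5. Map predictions back onto the documents. include_offset is True for
#    end-to-end linking with the Flair tagger and False when the spans were
#    supplied by the user (see the handlers above).
result = process_results(mentions_dataset, predictions, input_documents, include_offset=True)
print(result)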
def parse_cbor_to_protobuf(self, read_path, write_path, num_docs, buffer_size=10,
                           print_intervals=100, write_output=True, use_rel=False,
                           rel_base_url=None, rel_wiki_year='2014', rel_model_path=None,
                           car_id_to_name_path=None):
    """ Read a TREC CAR cbor file to create a list of protobuf Document messages
    (protocol_buffers/document.proto:Documents). This list of messages is streamed
    to a binary file using the 'stream' package. """
    # List of Document messages.
    documents = []

    t_start = time.time()

    # Use REL (Radboud Entity Linker): https://github.com/informagi/REL
    if use_rel:
        # Will require models and data saved paths
        wiki_version = "wiki_" + rel_wiki_year

        self.mention_detection = MentionDetection(rel_base_url, wiki_version)
        self.tagger_ner = load_flair_ner("ner-fast")

        config = {
            "mode": "eval",
            "model_path": rel_model_path,
        }
        self.entity_disambiguation = EntityDisambiguation(rel_base_url, wiki_version, config)

        self.car_id_to_name_path = car_id_to_name_path

    with open(read_path, 'rb') as f_read:

        # Loop over Page objects.
        for i, page in enumerate(iter_pages(f_read)):

            # Stop when 'num_docs' pages have been processed.
            if i >= num_docs:
                break

            # Parse page to create a new document.
            self.parse_page_to_protobuf(page=page)

            # Append Document message to document list.
            documents.append(self.document)

            # Print updates every 'print_intervals' pages.
            if ((i + 1) % print_intervals == 0):
                print('----- DOC #{} -----'.format(i))
                print(self.document.doc_id)
                time_delta = time.time() - t_start
                print('time elapsed: {} --> time / page: {}'.format(
                    time_delta, time_delta / (i + 1)))

    if write_output:
        print('STREAMING DATA TO FILE: {}'.format(write_path))
        self.write_documents_to_file(path=write_path, documents=documents,
                                     buffer_size=buffer_size)
        print('FILE WRITTEN')

    time_delta = time.time() - t_start
    print('PROCESSED DATA: {} --> processing time / page: {}'.format(
        time_delta, time_delta / (i + 1)))
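# Hypothetical usage of the parser above; the class name TrecCarParser and all
# paths are placeholders, only the keyword arguments come from the signature.
parser = TrecCarParser()
parser.parse_cbor_to_protobuf(
    read_path="/data/trec-car/pages.cbor",
    write_path="/data/trec-car/documents.bin",
    num_docs=1000,
    print_intervals=100,
    write_output=True,
    use_rel=True,
    rel_base_url="/data/rel/",
    rel_wiki_year="2014",
    rel_model_path="/data/rel/wiki_2014/generated/model",
    car_id_to_name_path="/data/trec-car/car_id_to_name.json",
)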