def test_pipeline():
    """End-to-end smoke test of the REL pipeline on one document.

    Formats a pre-supplied mention span, runs entity disambiguation, and
    checks the processed result against the expected gold tuple.
    """
    base_url = Path(__file__).parent
    wiki_subfolder = "wiki_test"

    # One document with its text and a single pre-annotated span (start=10, length=3).
    sample = {"test_doc": ["the brown fox jumped over the lazy dog", [[10, 3]]]}

    config = {
        "mode": "eval",
        "model_path": f"{base_url}/{wiki_subfolder}/generated/model",
    }

    md = MentionDetection(base_url, wiki_subfolder)
    model = EntityDisambiguation(base_url, wiki_subfolder, config)
    # NOTE: the original also built `Cmns(base_url, wiki_subfolder, n=5)` but
    # never used it; the unused tagger (and its data load) has been removed.

    mentions_dataset, _ = md.format_spans(sample)
    predictions, _ = model.predict(mentions_dataset)
    results = process_results(
        mentions_dataset, predictions, sample, include_offset=False
    )

    gold_truth = {"test_doc": [(10, 3, "Fox", "fox", -1, "NULL", 0.0)]}
    # Under pytest the return value is ignored, so the original `return`-only
    # form could never fail; assert so the test actually tests.
    assert results == gold_truth
    return results == gold_truth
def test_entity_disambiguation_instantiation():
    """Smoke test: constructing EntityDisambiguation in eval mode succeeds.

    Returns the constructed model (pytest only cares that no exception is
    raised during instantiation).
    """
    base_dir = Path(__file__).parent
    wiki_subfolder = "wiki_test"
    config = {
        "mode": "eval",
        "model_path": base_dir / wiki_subfolder / "generated" / "model",
    }
    return EntityDisambiguation(base_dir, wiki_subfolder, config)
def __init_rel_models(self, rel_wiki_year, rel_base_url, rel_model_path, car_id_to_name_path): """ """ # Will require models and data saved paths wiki_version = "wiki_" + rel_wiki_year self.mention_detection = MentionDetection(rel_base_url, wiki_version) self.tagger_ner = load_flair_ner("ner-fast") config = { "mode": "eval", "model_path": rel_model_path, } self.entity_disambiguation = EntityDisambiguation( rel_base_url, wiki_version, config) self.car_id_to_name_path = car_id_to_name_path
from time import time

# Fix: `flair` and `torch` are referenced below (`flair.device = torch.device(...)`)
# but were never imported in this script.
import flair
import torch
from flair.models import SequenceTagger

from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection

base_url = "C:/Users/mickv/desktop/data_back/"

# NOTE(review): `wiki_version` and `docs` are used below but not defined in
# this chunk — confirm they are set elsewhere before this script runs.

# Pin Flair to the first CUDA device.
flair.device = torch.device('cuda:0')

mention_detection = MentionDetection(base_url, wiki_version)

# Alternatively use Flair NER tagger.
tagger_ner = SequenceTagger.load("ner-fast")

# 2. Mention detection, timed.
start = time()
mentions_dataset, n_mentions = mention_detection.find_mentions(docs, tagger_ner)
print('MD took: {}'.format(time() - start))

# 3. Load model.
config = {
    "mode": "eval",
    "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
}
model = EntityDisambiguation(base_url, wiki_version, config)

# 4. Entity disambiguation.
start = time()
predictions, timing = model.predict(mentions_dataset)
print('ED took: {}'.format(time() - start))
from REL.entity_disambiguation import EntityDisambiguation
from REL.training_datasets import TrainingEvaluationDatasets

# 0. Project data location and wiki snapshot to use.
base_url = "/users/vanhulsm/Desktop/projects/data/"
wiki_version = "wiki_2019"

# 1. Load the training/evaluation datasets for this snapshot.
datasets = TrainingEvaluationDatasets(base_url, wiki_version).load()

# 2. Init model; this user config overrides the default config.
config = {
    "mode": "eval",
    "model_path": f"{base_url}/{wiki_version}/generated/model",
}
model = EntityDisambiguation(base_url, wiki_version, config)

# 3. Fit the logistic-regression confidence model and store it under the
#    generated-artifacts folder.
model_path_lr = f"{base_url}/{wiki_version}/generated/"
model.train_LR(datasets, model_path_lr)
from REL.entity_disambiguation import EntityDisambiguation
from REL.training_datasets import TrainingEvaluationDatasets

# 0. Project data location and wiki snapshot to use.
base_url = "/users/vanhulsm/Desktop/projects/data/"
wiki_version = "wiki_2014"

# 1. Load the training/evaluation datasets for this snapshot.
datasets = TrainingEvaluationDatasets(base_url, wiki_version).load()

# 2. Init model; this user config overrides the default config.
config = {
    "mode": "eval",
    "model_path": f"{base_url}/{wiki_version}/generated/model",
}
model = EntityDisambiguation(base_url, wiki_version, config)

# 3. Either train on AIDA-train (validating on every other split), or
#    evaluate an existing model on all non-training splits.
if config["mode"] == "train":
    train_split = datasets["aida_train"]
    dev_splits = {name: data for name, data in datasets.items() if name != "aida_train"}
    model.train(train_split, dev_splits)
else:
    eval_splits = {name: data for name, data in datasets.items() if "train" not in name}
    model.evaluate(eval_splits)
from REL.entity_disambiguation import EntityDisambiguation
from REL.ner import load_flair_ner
from REL.server import make_handler

# NOTE(review): `HTTPServer` is used below but never imported in this chunk —
# `from http.server import HTTPServer` appears to be missing; confirm it is
# not provided elsewhere in the file.

# 0. Set your project url, which is used as a reference for your datasets etc.
base_url = ""
wiki_version = "wiki_2019"

# 1. Init model, where user can set his/her own config that will overwrite the default config.
# If mode is equal to 'eval', then the model_path should point to an existing model.
config = {
    "mode": "eval",
    "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
}
model = EntityDisambiguation(base_url, wiki_version, config)

# 2. Create NER-tagger.
tagger_ner = load_flair_ner("ner-fast")  # or another tagger

# 3. Init server: bind locally on port 5555 and serve the linking handler.
server_address = ("127.0.0.1", 5555)
server = HTTPServer(
    server_address,
    make_handler(base_url, wiki_version, model, tagger_ner),
)
try:
    print("Ready for listening.")
    server.serve_forever()
except KeyboardInterrupt:
    # NOTE(review): the chunk ends here — the KeyboardInterrupt handler body
    # falls outside this view (the analogous script elsewhere calls exit(0)).
# Command-line front-end for the REL entity-linking HTTP server.
import argparse  # Fix: argparse was used below but never imported.
from http.server import HTTPServer

from REL.entity_disambiguation import EntityDisambiguation
from REL.ner import load_flair_ner
from REL.server import make_handler  # Fix: make_handler was used but never imported.

p = argparse.ArgumentParser()
p.add_argument("base_url")
p.add_argument("wiki_version")
p.add_argument("--ed-model", default="ed-wiki-2019")
p.add_argument("--ner-model", default="ner-fast")
p.add_argument("--bind", "-b", metavar="ADDRESS", default="0.0.0.0")
p.add_argument("--port", "-p", default=5555, type=int)
args = p.parse_args()

# Load NER and ED models up front so the server can answer immediately.
ner_model = load_flair_ner(args.ner_model)
ed_model = EntityDisambiguation(args.base_url, args.wiki_version, {
    "mode": "eval",
    "model_path": args.ed_model
})

server_address = (args.bind, args.port)
server = HTTPServer(
    server_address,
    make_handler(args.base_url, args.wiki_version, ed_model, ner_model),
)
try:
    print("Ready for listening.")
    server.serve_forever()
except KeyboardInterrupt:
    # Clean shutdown on Ctrl-C.
    exit(0)
def parse_cbor_to_protobuf(self, read_path, write_path, num_docs, buffer_size=10,
                           print_intervals=100, write_output=True, use_rel=False,
                           rel_base_url=None, rel_wiki_year='2014',
                           rel_model_path=None, car_id_to_name_path=None):
    """
    Read a TREC CAR cbor file to create a list of protobuffer Document messages
    (protocol_buffers/document.proto:Documents). This list of messages is
    streamed to a binary file using the 'stream' package.

    :param read_path: path of the cbor file to read.
    :param write_path: path of the binary output file (used when write_output).
    :param num_docs: maximum number of pages to process.
    :param buffer_size: number of messages buffered per write.
    :param print_intervals: progress is printed every this many pages.
    :param write_output: when True, stream the documents to write_path.
    :param use_rel: when True, initialise the REL entity-linking models.
    :param rel_base_url: REL data/models root (required when use_rel).
    :param rel_wiki_year: wiki snapshot year for REL, e.g. '2014'.
    :param rel_model_path: trained REL ED model path (required when use_rel).
    :param car_id_to_name_path: CAR id->name mapping path (stored on self).
    """
    # list of Document messages.
    documents = []
    t_start = time.time()

    # Use REL (Radboud entity linker): https://github.com/informagi/REL
    if use_rel:
        # Will require models and data saved paths
        wiki_version = "wiki_" + rel_wiki_year
        self.mention_detection = MentionDetection(rel_base_url, wiki_version)
        self.tagger_ner = load_flair_ner("ner-fast")
        config = {
            "mode": "eval",
            "model_path": rel_model_path,
        }
        self.entity_disambiguation = EntityDisambiguation(
            rel_base_url, wiki_version, config)
        self.car_id_to_name_path = car_id_to_name_path

    # Pages actually processed. Fix: the original referenced the loop variable
    # `i` after the loop, raising NameError (and risking division by zero)
    # when the cbor file is empty or num_docs == 0.
    pages_done = 0
    with open(read_path, 'rb') as f_read:
        # Loop over Page objects
        for i, page in enumerate(iter_pages(f_read)):
            # Stops when 'num_docs' have been processed.
            if i >= num_docs:
                break
            # parse page to create new document.
            self.parse_page_to_protobuf(page=page)
            # Append Document message to document list.
            documents.append(self.document)
            pages_done = i + 1
            # Prints updates at 'print_intervals' intervals.
            if pages_done % print_intervals == 0:
                print('----- DOC #{} -----'.format(i))
                print(self.document.doc_id)
                time_delta = time.time() - t_start
                print('time elapse: {} --> time / page: {}'.format(
                    time_delta, time_delta / pages_done))

    if write_output:
        # Typo fix: message previously read 'STEAMING'.
        print('STREAMING DATA TO FILE: {}'.format(write_path))
        self.write_documents_to_file(path=write_path, documents=documents,
                                     buffer_size=buffer_size)
        print('FILE WRITTEN')

    time_delta = time.time() - t_start
    # max(..., 1) guards the per-page average when nothing was processed.
    print('PROCESSED DATA: {} --> processing time / page: {}'.format(
        time_delta, time_delta / max(pages_done, 1)))
    # NOTE(review): the two statements below are the tail of a user-supplied
    # span function whose `def` line falls outside this chunk; the indentation
    # shown here is an assumption.
    spans = [(0, 5), (17, 7), (50, 6)]
    return spans


# 0. Set your project url, which is used as a reference for your datasets etc.
base_url = Path("")
wiki_subfolder = "wiki_2019"

# 1. Init model, where user can set his/her own config that will overwrite the default config.
# If mode is equal to 'eval', then the model_path should point to an existing model.
config = {
    "mode": "eval",
    "model_path": base_url / wiki_subfolder / "generated" / "model",
}
model = EntityDisambiguation(base_url, wiki_subfolder, config)

# 2. Create NER-tagger.
tagger_ner = SequenceTagger.load("ner-fast")

# 2.1. Alternatively, one can create his/her own NER-tagger that given a text,
# returns a list with spans (start_pos, length).
# tagger_ner = user_func

# 3. Init server.
mode = "EL"
server_address = ("localhost", 5555)
server = HTTPServer(
    server_address,
    make_handler(
        base_url, wiki_subfolder, model, tagger_ner, mode=mode, include_conf=True
        # NOTE(review): the chunk is truncated mid-call — the closing
        # parentheses and anything after fall outside this view.