def __init__(self, base_url, wiki_subfolder):
    self.cnt_exact = 0
    self.cnt_partial = 0
    self.cnt_total = 0
    self.wiki_db = GenericLookup(
        "entity_word_embedding",
        "{}/{}/generated/".format(base_url, wiki_subfolder),
    )
def __init__(self, base_url, wiki_version, user_config, reset_embeddings=False):
    self.base_url = base_url
    self.wiki_version = wiki_version
    self.embeddings = {}
    self.config = self.__get_config(user_config)
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.prerank_model = None
    self.model = None
    self.reset_embeddings = reset_embeddings
    self.emb = GenericLookup(
        "entity_word_embedding", os.path.join(base_url, wiki_version, "generated")
    )
    self.g_emb = GenericLookup("common_drawl", os.path.join(base_url, "generic"))
    test = self.g_emb.emb(["in"], "embeddings")[0]
    assert (
        test is not None
    ), "GloVe embeddings in wrong folder? Test embedding not found."
    self.__load_embeddings()
    self.coref = TrainingEvaluationDatasets(base_url, wiki_version)
    self.prerank_model = PreRank(self.config).to(self.device)
    self.__max_conf = None

    # Load LR model for confidence.
    if os.path.exists(Path(self.config["model_path"]).parent / "lr_model.pkl"):
        with open(
            Path(self.config["model_path"]).parent / "lr_model.pkl",
            "rb",
        ) as f:
            self.model_lr = pkl.load(f)
    else:
        print("No LR model found; ED confidence scores will be set to zero.")
        self.model_lr = None

    if self.config["mode"] == "eval":
        print("Loading model from given path: {}".format(self.config["model_path"]))
        self.model = self.__load(self.config["model_path"])
    else:
        if reset_embeddings:
            raise Exception("You cannot train a model and reset the embeddings.")
        self.model = MulRelRanker(self.config, self.device).to(self.device)
def __init__(self, base_url, wiki_subfolder):
    if isinstance(base_url, str):
        base_url = Path(base_url)
    self.cnt_exact = 0
    self.cnt_partial = 0
    self.cnt_total = 0
    self.wiki_db = GenericLookup(
        "entity_word_embedding",
        base_url / wiki_subfolder / "generated",
    )
def store(self):
    """
    Stores results in a sqlite3 database.

    :return: -
    """
    print("Please take a break, this will take a while :).")

    wiki_db = GenericLookup(
        "entity_word_embedding",
        "{}/{}/generated/".format(self.base_url, self.wiki_version),
        table_name="wiki",
        columns={"p_e_m": "blob", "lower": "text", "freq": "INTEGER"},
    )
    wiki_db.load_wiki(self.p_e_m, self.mention_freq, batch_size=50000, reset=True)
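# A minimal sketch of querying the table written by `store`, assuming a
# placeholder local path; the columns ("p_e_m" via "wiki", "freq", "lower")
# follow the schema defined above, and the lookup calls mirror those used
# elsewhere in this file.
from REL.db.generic import GenericLookup

wiki_db = GenericLookup(
    "entity_word_embedding",
    "/path/to/wiki_2019/generated",  # placeholder path; adjust to your setup
)

candidates = wiki_db.wiki("Berlin", "wiki")          # p(e|m) candidate list
frequency = wiki_db.wiki("Berlin", "wiki", "freq")   # mention frequency
lower_hit = wiki_db.wiki("berlin", "wiki", "lower")  # case-insensitive index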
def __init__(self, base_url, wiki_version, user_config, reset_embeddings=False):
    if isinstance(base_url, str):
        base_url = Path(base_url)
    self.base_url = base_url
    self.wiki_version = wiki_version
    self.embeddings = {}
    self.config = self.__get_config(user_config)
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.prerank_model = None
    self.model = None
    self.reset_embeddings = reset_embeddings
    self.emb = GenericLookup(
        "entity_word_embedding",
        base_url / wiki_version / "generated",
    )
    test = self.emb.emb(["in"], "embeddings")[0]
    assert (
        test is not None
    ), "Wikipedia embeddings in wrong folder? Test embedding not found."
    self.g_emb = GenericLookup("common_drawl", base_url / "generic")
    test = self.g_emb.emb(["in"], "embeddings")[0]
    assert (
        test is not None
    ), "GloVe embeddings in wrong folder? Test embedding not found."
    self.__load_embeddings()
    self.coref = TrainingEvaluationDatasets(base_url, wiki_version)
    self.prerank_model = PreRank(self.config).to(self.device)
    self.__max_conf = None

    if self.config["mode"] == "eval":
        log.info("Loading model from given path: {}".format(self.config["model_path"]))
        self.model = self.__load(self.config["model_path"])
    else:
        if reset_embeddings:
            raise Exception("You cannot train a model and reset the embeddings.")
        self.model = MulRelRanker(self.config, self.device).to(self.device)
def __init__(self, base_url, wiki_version, wikipedia):
    self.wned_path = "{}/generic/test_datasets/wned-datasets/".format(base_url)
    self.aida_path = "{}/generic/test_datasets/AIDA/".format(base_url)
    self.wikipedia = wikipedia
    self.base_url = base_url
    self.wiki_version = wiki_version
    self.wiki_db = GenericLookup(
        "entity_word_embedding",
        "{}/{}/generated/".format(base_url, wiki_version),
    )

    super().__init__(base_url, wiki_version)
def __init__(self, base_url, wiki_version, wikipedia):
    if isinstance(base_url, str):
        base_url = Path(base_url)
    self.wned_path = base_url / "generic/test_datasets/wned-datasets"
    self.aida_path = base_url / "generic/test_datasets/AIDA"
    self.wikipedia = wikipedia
    self.base_url = base_url
    self.wiki_version = wiki_version
    self.wiki_db = GenericLookup(
        "entity_word_embedding",
        base_url / wiki_version / "generated",
    )

    super().__init__(base_url, wiki_version)
class EntityDisambiguation:
    def __init__(self, base_url, wiki_version, user_config, reset_embeddings=False):
        self.base_url = base_url
        self.wiki_version = wiki_version
        self.embeddings = {}
        self.config = self.__get_config(user_config)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.prerank_model = None
        self.model = None
        self.reset_embeddings = reset_embeddings
        self.emb = GenericLookup(
            "entity_word_embedding", os.path.join(base_url, wiki_version, "generated")
        )
        self.g_emb = GenericLookup("common_drawl", os.path.join(base_url, "generic"))
        test = self.g_emb.emb(["in"], "embeddings")[0]
        assert (
            test is not None
        ), "GloVe embeddings in wrong folder? Test embedding not found."
        self.__load_embeddings()
        self.coref = TrainingEvaluationDatasets(base_url, wiki_version)
        self.prerank_model = PreRank(self.config).to(self.device)
        self.__max_conf = None

        # Load LR model for confidence.
        if os.path.exists(Path(self.config["model_path"]).parent / "lr_model.pkl"):
            with open(
                Path(self.config["model_path"]).parent / "lr_model.pkl",
                "rb",
            ) as f:
                self.model_lr = pkl.load(f)
        else:
            print("No LR model found; ED confidence scores will be set to zero.")
            self.model_lr = None

        if self.config["mode"] == "eval":
            print("Loading model from given path: {}".format(self.config["model_path"]))
            self.model = self.__load(self.config["model_path"])
        else:
            if reset_embeddings:
                raise Exception("You cannot train a model and reset the embeddings.")
            self.model = MulRelRanker(self.config, self.device).to(self.device)

    def __get_config(self, user_config):
        """
        User configuration that may overwrite default settings.

        :return: configuration used for ED.
        """
        default_config: Dict[str, Any] = {
            "mode": "train",
            "model_path": "./",
            "prerank_ctx_window": 50,
            "keep_p_e_m": 4,
            "keep_ctx_ent": 3,
            "ctx_window": 100,
            "tok_top_n": 25,
            "mulrel_type": "ment-norm",
            "n_rels": 3,
            "hid_dims": 100,
            "emb_dims": 300,
            "snd_local_ctx_window": 6,
            "dropout_rate": 0.3,
            "n_epochs": 1000,
            "dev_f1_change_lr": 0.915,
            "n_not_inc": 10,
            "eval_after_n_epochs": 5,
            "learning_rate": 1e-4,
            "margin": 0.01,
            "df": 0.5,
            "n_loops": 10,
            # "freeze_embs": True,
            "n_cands_before_rank": 30,
            "first_head_uniforn": False,  # (sic) key name kept for model compatibility
            "use_pad_ent": True,
            "use_local": True,
            "use_local_only": False,
            "oracle": False,
        }
        default_config.update(user_config)
        config = default_config

        model_dict = json.loads(
            pkg_resources.resource_string("REL.models", "models.json")
        )
        model_path: str = config["model_path"]
        # Load the aliased url if it exists, else keep the original string.
        config["model_path"] = model_dict.get(model_path, model_path)

        if urlparse(str(config["model_path"])).scheme in ("http", "https"):
            model_path = utils.fetch_model(
                config["model_path"],
                cache_dir=Path("~/.rel_cache").expanduser(),
            )
            assert tarfile.is_tarfile(model_path), "Only tar-files are supported!"
            # Make a directory named after the tarfile (minus extension),
            # extract the files in the archive to that directory, and
            # assign config["model_path"] accordingly.
            with tarfile.open(model_path) as f:
                f.extractall(Path("~/.rel_cache").expanduser())
            # NOTE: use the double stem to deal with e.g. *.tar.gz;
            # this also handles *.tar correctly.
            stem = Path(Path(model_path).stem).stem
            # NOTE: the model file(s) must be named "model.state_dict"
            # and "model.config" if supplied; other names won't work.
            config["model_path"] = Path("~/.rel_cache").expanduser() / stem / "model"
        return config

    def __load_embeddings(self):
        """
        Initialises the embedding dictionaries and creates the #UNK# token for
        the respective embeddings.

        :return: -
        """
        self.__batch_embs = {}

        for name in ["snd", "entity", "word"]:
            # Init entity embeddings.
            self.embeddings["{}_seen".format(name)] = set()
            self.embeddings["{}_voca".format(name)] = Vocabulary()
            self.embeddings["{}_embeddings".format(name)] = None

            if name in ["word", "entity"]:
                # Add #UNK# token.
                self.embeddings["{}_voca".format(name)].add_to_vocab("#UNK#")
                e = self.emb.emb(["#{}/UNK#".format(name.upper())], "embeddings")[0]
                assert e is not None, "#UNK# token not found for {} in db".format(name)
                self.__batch_embs[name] = []
                self.__batch_embs[name].append(torch.tensor(e))
            else:
                # For GloVe the #UNK# token was randomly initialised; we added it to
                # our generated database for reproducibility. The author also reports
                # no significant difference between using the mean of the vectors and
                # a randomly initialised vector for the GloVe embeddings.
                # https://github.com/lephong/mulrel-nel/issues/21
                self.embeddings["{}_voca".format(name)].add_to_vocab("#UNK#")
                e = self.g_emb.emb(["#SND/UNK#"], "embeddings")[0]
                assert e is not None, "#UNK# token not found for {} in db".format(name)
                self.__batch_embs[name] = []
                self.__batch_embs[name].append(torch.tensor(e))

    def train(self, org_train_dataset, org_dev_datasets):
        """
        Responsible for training the ED model.

        :return: -
        """
        train_dataset = self.get_data_items(org_train_dataset, "train", predict=False)
        dev_datasets = []
        for dname, data in org_dev_datasets.items():
            dev_datasets.append((dname, self.get_data_items(data, dname, predict=True)))

        print("Creating optimizer")
        optimizer = optim.Adam(
            [p for p in self.model.parameters() if p.requires_grad],
            lr=self.config["learning_rate"],
        )
        best_f1 = -1
        not_better_count = 0

        eval_after_n_epochs = self.config["eval_after_n_epochs"]

        for e in range(self.config["n_epochs"]):
            shuffle(train_dataset)

            total_loss = 0
            for dc, batch in enumerate(train_dataset):  # each document is a minibatch
                self.model.train()
                optimizer.zero_grad()

                # Convert data items to pytorch inputs.
                token_ids = [
                    m["context"][0] + m["context"][1]
                    if len(m["context"][0]) + len(m["context"][1]) > 0
                    else [self.embeddings["word_voca"].unk_id]
                    for m in batch
                ]
                s_ltoken_ids = [m["snd_ctx"][0] for m in batch]
                s_rtoken_ids = [m["snd_ctx"][1] for m in batch]
                s_mtoken_ids = [m["snd_ment"] for m in batch]

                entity_ids = Variable(
                    torch.LongTensor(
                        [m["selected_cands"]["cands"] for m in batch]
                    ).to(self.device)
                )
                true_pos = Variable(
                    torch.LongTensor(
                        [m["selected_cands"]["true_pos"] for m in batch]
                    ).to(self.device)
                )
                p_e_m = Variable(
                    torch.FloatTensor(
                        [m["selected_cands"]["p_e_m"] for m in batch]
                    ).to(self.device)
                )
                entity_mask = Variable(
                    torch.FloatTensor(
                        [m["selected_cands"]["mask"] for m in batch]
                    ).to(self.device)
                )

                token_ids, token_mask = utils.make_equal_len(
                    token_ids, self.embeddings["word_voca"].unk_id
                )
                s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
                    s_ltoken_ids, self.embeddings["snd_voca"].unk_id, to_right=False
                )
                s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
                    s_rtoken_ids, self.embeddings["snd_voca"].unk_id
                )
                s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
                s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
                s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
                    s_mtoken_ids, self.embeddings["snd_voca"].unk_id
                )

                token_ids = Variable(torch.LongTensor(token_ids).to(self.device))
                token_mask = Variable(torch.FloatTensor(token_mask).to(self.device))

                # Too ugly, but too lazy to fix it.
                self.model.s_ltoken_ids = Variable(
                    torch.LongTensor(s_ltoken_ids).to(self.device)
                )
                self.model.s_ltoken_mask = Variable(
                    torch.FloatTensor(s_ltoken_mask).to(self.device)
                )
                self.model.s_rtoken_ids = Variable(
                    torch.LongTensor(s_rtoken_ids).to(self.device)
                )
                self.model.s_rtoken_mask = Variable(
                    torch.FloatTensor(s_rtoken_mask).to(self.device)
                )
                self.model.s_mtoken_ids = Variable(
                    torch.LongTensor(s_mtoken_ids).to(self.device)
                )
                self.model.s_mtoken_mask = Variable(
                    torch.FloatTensor(s_mtoken_mask).to(self.device)
                )

                scores, ent_scores = self.model.forward(
                    token_ids,
                    token_mask,
                    entity_ids,
                    entity_mask,
                    p_e_m,
                    self.embeddings,
                    gold=true_pos.view(-1, 1),
                )
                loss = self.model.loss(scores, true_pos)
                # loss = self.model.prob_loss(scores, true_pos)

                loss.backward()
                optimizer.step()
                self.model.regularize(max_norm=100)

                loss = loss.cpu().data.numpy()
                total_loss += loss
                print(
                    "epoch",
                    e,
                    "%0.2f%%" % (dc / len(train_dataset) * 100),
                    loss,
                    end="\r",
                )

            print("epoch", e, "total loss", total_loss, total_loss / len(train_dataset))

            if (e + 1) % eval_after_n_epochs == 0:
                dev_f1 = 0
                for dname, data in dev_datasets:
                    predictions = self.__predict(data)
                    f1, recall, precision, _ = self.__eval(
                        org_dev_datasets[dname], predictions
                    )

                    print(
                        dname,
                        utils.tokgreen(
                            "Micro F1: {}, Recall: {}, Precision: {}".format(
                                f1, recall, precision
                            )
                        ),
                    )

                    if dname == "aida_testA":
                        dev_f1 = f1

                if (
                    self.config["learning_rate"] == 1e-4
                    and dev_f1 >= self.config["dev_f1_change_lr"]
                ):
                    eval_after_n_epochs = 2
                    best_f1 = dev_f1
                    not_better_count = 0

                    self.config["learning_rate"] = 1e-5
                    print("change learning rate to", self.config["learning_rate"])
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = self.config["learning_rate"]

                if dev_f1 < best_f1:
                    not_better_count += 1
                    print("Not improving", not_better_count)
                else:
                    not_better_count = 0
                    best_f1 = dev_f1
                    print("save model to", self.config["model_path"])
                    self.__save(self.config["model_path"])

                if not_better_count == self.config["n_not_inc"]:
                    break

    def evaluate(self, datasets):
        """
        Parent function responsible for evaluating the ED model during the ED
        step. Note that this is different from predict, as this requires
        ground-truth entities to be present.

        :return: -
        """
        dev_datasets = []
        for dname, data in list(datasets.items()):
            start = time.time()
            dev_datasets.append((dname, self.get_data_items(data, dname, predict=True)))

        for dname, data in dev_datasets:
            predictions = self.__predict(data)
            f1, recall, precision, total_nil = self.__eval(datasets[dname], predictions)
            print(
                dname,
                utils.tokgreen(
                    "Micro F1: {}, Recall: {}, Precision: {}".format(
                        f1, recall, precision
                    )
                ),
            )
            print("Total NIL: {}".format(total_nil))
            print("----------------------------------")

    def __create_dataset_LR(self, datasets, predictions, dname):
        X = []
        y = []
        meta = []
        for doc, preds in predictions.items():
            gt_doc = [c["gold"][0] for c in datasets[dname][doc]]
            for pred, gt in zip(preds, gt_doc):
                scores = [float(x) for x in pred["scores"]]
                cands = pred["candidates"]

                # Build classes.
                for i, c in enumerate(cands):
                    if c == "#UNK#":
                        continue
                    X.append([scores[i]])
                    meta.append([doc, gt, c])
                    if gt == c:
                        y.append(1.0)
                    else:
                        y.append(0.0)
        return np.array(X), np.array(y), np.array(meta)

    def train_LR(self, datasets, model_path_lr, store_offline=True, threshold=0.3):
        """
        Applies logistic regression in an attempt to obtain confidence scores.
        Recall should be high, because if it were low we would have ignored a
        correct entity.

        :return: -
        """
        train_dataset = self.get_data_items(
            datasets["aida_train"], "train", predict=False
        )
        dev_datasets = []
        for dname, data in list(datasets.items()):
            if dname == "aida_train":
                continue
            dev_datasets.append((dname, self.get_data_items(data, dname, predict=True)))

        model = LogisticRegression()

        predictions = self.__predict(train_dataset, eval_raw=True)
        X, y, meta = self.__create_dataset_LR(datasets, predictions, "aida_train")
        model.fit(X, y)

        for dname, data in dev_datasets:
            predictions = self.__predict(data, eval_raw=True)
            X, y, meta = self.__create_dataset_LR(datasets, predictions, dname)
            preds = model.predict_proba(X)
            preds = np.array([x[1] for x in preds])
            decisions = (preds >= threshold).astype(int)

            print(
                utils.tokgreen("{}, F1-score: {}".format(dname, f1_score(y, decisions)))
            )

        if store_offline:
            # NOTE: a leading slash in the second argument would make
            # os.path.join discard model_path_lr, so it is joined without one.
            path = os.path.join(model_path_lr, "lr_model.pkl")
            with open(path, "wb") as handle:
                pkl.dump(model, handle, protocol=pkl.HIGHEST_PROTOCOL)

    def predict(self, data):
        """
        Parent function responsible for predicting on any raw text as input.
        This does not require ground-truth entities to be present.

        :return: predictions and time taken for the ED step.
        """
        self.coref.with_coref(data)
        data = self.get_data_items(data, "raw", predict=True)
        predictions, timing = self.__predict(data, include_timing=True, eval_raw=True)

        return predictions, timing

    def __compute_confidence_legacy(self, scores, preds):
        """
        LEGACY

        :return: -
        """
        confidence_scores = []

        for score, pred in zip(scores, preds):
            loss = 0
            for j in range(len(score)):
                if j == pred:
                    continue
                loss += max(
                    0, score[j].item() - score[pred].item() + self.config["margin"]
                )
            if not self.__max_conf:
                self.__max_conf = (
                    self.config["keep_ctx_ent"] + self.config["keep_p_e_m"] - 1
                ) * self.config["margin"]
            conf = 1 - (loss / self.__max_conf)
            confidence_scores.append(conf)

        return confidence_scores

    def __compute_confidence(self, scores, preds):
        """
        Uses LR to find confidence scores for given ED outputs.

        :return: -
        """
        X = np.array([[score[pred]] for score, pred in zip(scores, preds)])
        if self.model_lr:
            preds = self.model_lr.predict_proba(X)
            confidence_scores = [x[1] for x in preds]
        else:
            confidence_scores = [0.0 for _ in scores]
        return confidence_scores

    def __predict(self, data, include_timing=False, eval_raw=False):
        """
        Uses the trained model to make predictions on individual batches
        (i.e. documents).

        :return: predictions and time taken for the ED step.
        """
        predictions = {items[0]["doc_name"]: [] for items in data}
        self.model.eval()

        timing = []

        for batch in data:  # each document is a minibatch
            start = time.time()

            token_ids = [
                m["context"][0] + m["context"][1]
                if len(m["context"][0]) + len(m["context"][1]) > 0
                else [self.embeddings["word_voca"].unk_id]
                for m in batch
            ]
            s_ltoken_ids = [m["snd_ctx"][0] for m in batch]
            s_rtoken_ids = [m["snd_ctx"][1] for m in batch]
            s_mtoken_ids = [m["snd_ment"] for m in batch]

            entity_ids = Variable(
                torch.LongTensor(
                    [m["selected_cands"]["cands"] for m in batch]
                ).to(self.device)
            )
            p_e_m = Variable(
                torch.FloatTensor(
                    [m["selected_cands"]["p_e_m"] for m in batch]
                ).to(self.device)
            )
            entity_mask = Variable(
                torch.FloatTensor(
                    [m["selected_cands"]["mask"] for m in batch]
                ).to(self.device)
            )
            true_pos = Variable(
                torch.LongTensor(
                    [m["selected_cands"]["true_pos"] for m in batch]
                ).to(self.device)
            )

            token_ids, token_mask = utils.make_equal_len(
                token_ids, self.embeddings["word_voca"].unk_id
            )
            s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
                s_ltoken_ids, self.embeddings["snd_voca"].unk_id, to_right=False
            )
            s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
                s_rtoken_ids, self.embeddings["snd_voca"].unk_id
            )
            s_rtoken_ids = [l[::-1] for l in s_rtoken_ids]
            s_rtoken_mask = [l[::-1] for l in s_rtoken_mask]
            s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
                s_mtoken_ids, self.embeddings["snd_voca"].unk_id
            )

            token_ids = Variable(torch.LongTensor(token_ids).to(self.device))
            token_mask = Variable(torch.FloatTensor(token_mask).to(self.device))

            self.model.s_ltoken_ids = Variable(
                torch.LongTensor(s_ltoken_ids).to(self.device)
            )
            self.model.s_ltoken_mask = Variable(
                torch.FloatTensor(s_ltoken_mask).to(self.device)
            )
            self.model.s_rtoken_ids = Variable(
                torch.LongTensor(s_rtoken_ids).to(self.device)
            )
            self.model.s_rtoken_mask = Variable(
                torch.FloatTensor(s_rtoken_mask).to(self.device)
            )
            self.model.s_mtoken_ids = Variable(
                torch.LongTensor(s_mtoken_ids).to(self.device)
            )
            self.model.s_mtoken_mask = Variable(
                torch.FloatTensor(s_mtoken_mask).to(self.device)
            )

            scores, ent_scores = self.model.forward(
                token_ids,
                token_mask,
                entity_ids,
                entity_mask,
                p_e_m,
                self.embeddings,
                gold=true_pos.view(-1, 1),
            )
            pred_ids = torch.argmax(scores, dim=1)
            scores = scores.cpu().data.numpy()
            confidence_scores = self.__compute_confidence(scores, pred_ids)
            pred_ids = np.argmax(scores, axis=1)

            if not eval_raw:
                pred_entities = [
                    m["selected_cands"]["named_cands"][i]
                    if m["selected_cands"]["mask"][i] == 1
                    else (
                        m["selected_cands"]["named_cands"][0]
                        if m["selected_cands"]["mask"][0] == 1
                        else "NIL"
                    )
                    for (i, m) in zip(pred_ids, batch)
                ]
                doc_names = [m["doc_name"] for m in batch]

                for dname, entity in zip(doc_names, pred_entities):
                    predictions[dname].append({"pred": (entity, 0.0)})
            else:
                pred_entities = [
                    [
                        m["selected_cands"]["named_cands"][i],
                        m["raw"]["mention"],
                        m["selected_cands"]["named_cands"],
                        s,
                        cs,
                        m["selected_cands"]["mask"],
                    ]
                    if m["selected_cands"]["mask"][i] == 1
                    else (
                        [
                            m["selected_cands"]["named_cands"][0],
                            m["raw"]["mention"],
                            m["selected_cands"]["named_cands"],
                            s,
                            cs,
                            m["selected_cands"]["mask"],
                        ]
                        if m["selected_cands"]["mask"][0] == 1
                        else [
                            "NIL",
                            m["raw"]["mention"],
                            m["selected_cands"]["named_cands"],
                            s,
                            cs,
                            m["selected_cands"]["mask"],
                        ]
                    )
                    for (i, m, s, cs) in zip(pred_ids, batch, scores, confidence_scores)
                ]
                doc_names = [m["doc_name"] for m in batch]

                for dname, entity in zip(doc_names, pred_entities):
                    if entity[0] != "NIL":
                        predictions[dname].append(
                            {
                                "mention": entity[1],
                                "prediction": entity[0],
                                "candidates": entity[2],
                                "conf_ed": entity[4],
                                "scores": [str(x) for x in entity[3]],
                            }
                        )
                    else:
                        predictions[dname].append(
                            {
                                "mention": entity[1],
                                "prediction": entity[0],
                                "candidates": entity[2],
                                "scores": [],
                            }
                        )

            timing.append(time.time() - start)

        if include_timing:
            return predictions, timing
        else:
            return predictions

    def prerank(self, dataset, dname, predict=False):
        """
        Responsible for preranking the set of possible candidates using both
        context and p(e|m) scores.

        :return: dataset with, by default, at most 3 + 4 candidates per mention.
        """
        new_dataset = []
        has_gold = 0
        total = 0

        for content in dataset:
            items = []

            if self.config["keep_ctx_ent"] > 0:
                # Rank the candidates by ntee scores.
                lctx_ids = [
                    m["context"][0][
                        max(
                            len(m["context"][0])
                            - self.config["prerank_ctx_window"] // 2,
                            0,
                        ) :
                    ]
                    for m in content
                ]
                rctx_ids = [
                    m["context"][1][
                        : min(
                            len(m["context"][1]),
                            self.config["prerank_ctx_window"] // 2,
                        )
                    ]
                    for m in content
                ]
                ment_ids = [[] for m in content]
                token_ids = [
                    l + m + r
                    if len(l) + len(r) > 0
                    else [self.embeddings["word_voca"].unk_id]
                    for l, m, r in zip(lctx_ids, ment_ids, rctx_ids)
                ]

                entity_ids = [m["cands"] for m in content]
                entity_ids = Variable(torch.LongTensor(entity_ids).to(self.device))

                entity_mask = [m["mask"] for m in content]
                entity_mask = Variable(torch.FloatTensor(entity_mask).to(self.device))

                token_ids, token_offsets = utils.flatten_list_of_lists(token_ids)
                token_offsets = Variable(
                    torch.LongTensor(token_offsets).to(self.device)
                )
                token_ids = Variable(torch.LongTensor(token_ids).to(self.device))

                entity_names = [m["named_cands"] for m in content]  # named_cands

                log_probs = self.prerank_model.forward(
                    token_ids, token_offsets, entity_ids, self.embeddings, self.emb
                )

                # The entity mask makes sure that the UNK entities are zero.
                log_probs = (log_probs * entity_mask).add_(
                    (entity_mask - 1).mul_(1e10)
                )
                _, top_pos = torch.topk(
                    log_probs, dim=1, k=self.config["keep_ctx_ent"]
                )
                top_pos = top_pos.data.cpu().numpy()
            else:
                top_pos = [[]] * len(content)

            # Select candidates: mix the keep_ctx_ent best candidates (ntee
            # scores) with the keep_p_e_m best candidates (p_e_m scores).
            for i, m in enumerate(content):
                sm = {
                    "cands": [],
                    "named_cands": [],
                    "p_e_m": [],
                    "mask": [],
                    "true_pos": -1,
                }
                m["selected_cands"] = sm

                selected = set(top_pos[i])

                idx = 0
                while (
                    len(selected)
                    < self.config["keep_ctx_ent"] + self.config["keep_p_e_m"]
                ):
                    if idx not in selected:
                        selected.add(idx)
                    idx += 1

                selected = sorted(list(selected))

                for idx in selected:
                    sm["cands"].append(m["cands"][idx])
                    sm["named_cands"].append(m["named_cands"][idx])
                    sm["p_e_m"].append(m["p_e_m"][idx])
                    sm["mask"].append(m["mask"][idx])
                    if idx == m["true_pos"]:
                        sm["true_pos"] = len(sm["cands"]) - 1

                if not predict:
                    if sm["true_pos"] == -1:
                        continue

                items.append(m)
                if sm["true_pos"] >= 0:
                    has_gold += 1
                total += 1

                if predict:
                    # Only for the oracle model, not used for eval.
                    if sm["true_pos"] == -1:
                        # A fake gold; happens for only ~2% of mentions, but
                        # avoids having no gold at all.
                        sm["true_pos"] = 0

            if len(items) > 0:
                new_dataset.append(items)

        # if total > 0
        if dname != "raw":
            print("Recall for {}: {}".format(dname, has_gold / total))
            print("-----------------------------------------------")

        return new_dataset

    def __update_embeddings(self, emb_name, embs):
        """
        Responsible for updating the dictionaries with their respective word,
        entity and snd (GloVe) embeddings.

        :return: -
        """
        embs = embs.to(self.device)

        if self.embeddings["{}_embeddings".format(emb_name)]:
            new_weights = torch.cat(
                (self.embeddings["{}_embeddings".format(emb_name)].weight, embs)
            )
        else:
            new_weights = embs

        # Weights are now updated, so we create a new Embedding layer. Note that
        # the freeze flag must be set on the weight parameter, not on the module,
        # for it to have an effect.
        layer = torch.nn.Embedding(
            self.embeddings["{}_voca".format(emb_name)].size(),
            self.config["emb_dims"],
        )
        layer.weight = torch.nn.Parameter(new_weights)
        layer.weight.requires_grad = False

        self.embeddings["{}_embeddings".format(emb_name)] = layer

        if emb_name == "word":
            layer = torch.nn.EmbeddingBag(
                self.embeddings["{}_voca".format(emb_name)].size(),
                self.config["emb_dims"],
            )
            layer.weight = torch.nn.Parameter(new_weights)
            layer.weight.requires_grad = False

            self.embeddings["{}_embeddings_bag".format(emb_name)] = layer

        del new_weights

    def __embed_words(self, words_filt, name, table_name="embeddings"):
        """
        Responsible for retrieving embeddings using the given sqlite3 database.

        :return: -
        """
        # Returns None if not in db.
        if table_name == "glove":
            embs = self.g_emb.emb(words_filt, "embeddings")
        else:
            embs = self.emb.emb(words_filt, table_name)

        # Now we go over the embs and see which ones are None. Order is preserved.
        for e, c in zip(embs, words_filt):
            if name == "entity":
                c = c.replace("ENTITY/", "")
            self.embeddings["{}_seen".format(name)].add(c)

            if e is not None:
                # Embedding exists, so we add it.
                self.embeddings["{}_voca".format(name)].add_to_vocab(c)
                self.__batch_embs[name].append(torch.tensor(e))

    def get_data_items(self, dataset, dname, predict=False):
        """
        Responsible for formatting the dataset. Triggers the preranking function.

        :return: output of the preranking function.
        """
        data = []

        # If the user wants to reset the embeddings, it happens here, right
        # before loading a new dataset.
        if self.reset_embeddings:
            self.__load_embeddings()

        for doc_name, content in dataset.items():
            items = []
            if len(content) == 0:
                continue
            conll_doc = content[0].get("conll_doc", None)

            for m in content:
                named_cands = [c[0] for c in m["candidates"]]
                p_e_m = [min(1.0, max(1e-3, c[1])) for c in m["candidates"]]

                try:
                    true_pos = named_cands.index(m["gold"][0])
                    p = p_e_m[true_pos]
                except ValueError:
                    true_pos = -1

                # Get all words and check for embeddings.
                named_cands = named_cands[
                    : min(self.config["n_cands_before_rank"], len(named_cands))
                ]

                # Candidate list per mention.
                named_cands_filt = set(
                    [
                        "ENTITY/" + item
                        for item in named_cands
                        if item not in self.embeddings["entity_seen"]
                    ]
                )
                self.__embed_words(named_cands_filt, "entity", "embeddings")

                # Use re.split() to make sure that special characters are considered.
                lctx = [
                    x for x in re.split(r"(\W)", m["context"][0].strip()) if x != " "
                ]
                rctx = [
                    x for x in re.split(r"(\W)", m["context"][1].strip()) if x != " "
                ]

                words_filt = set(
                    [
                        item
                        for item in lctx + rctx
                        if item not in self.embeddings["word_seen"]
                    ]
                )
                self.__embed_words(words_filt, "word", "embeddings")

                snd_lctx = m["sentence"][: m["pos"]].strip().split()
                snd_lctx = [
                    t for t in snd_lctx[-self.config["snd_local_ctx_window"] // 2 :]
                ]

                snd_rctx = m["sentence"][m["end_pos"] :].strip().split()
                snd_rctx = [
                    t for t in snd_rctx[: self.config["snd_local_ctx_window"] // 2]
                ]

                snd_ment = m["ngram"].strip().split()

                words_filt = set(
                    [
                        item
                        for item in snd_lctx + snd_rctx + snd_ment
                        if item not in self.embeddings["snd_seen"]
                    ]
                )
                self.__embed_words(words_filt, "snd", "glove")

                p_e_m = p_e_m[: min(self.config["n_cands_before_rank"], len(p_e_m))]

                if true_pos >= len(named_cands):
                    if not predict:
                        true_pos = len(named_cands) - 1
                        p_e_m[-1] = p
                        named_cands[-1] = m["gold"][0]
                    else:
                        true_pos = -1

                cands = [self.embeddings["entity_voca"].get_id(c) for c in named_cands]
                mask = [1.0] * len(cands)

                if len(cands) == 0 and not predict:
                    continue
                elif len(cands) < self.config["n_cands_before_rank"]:
                    cands += [self.embeddings["entity_voca"].unk_id] * (
                        self.config["n_cands_before_rank"] - len(cands)
                    )
                    named_cands += [Vocabulary.unk_token] * (
                        self.config["n_cands_before_rank"] - len(named_cands)
                    )
                    p_e_m += [1e-8] * (self.config["n_cands_before_rank"] - len(p_e_m))
                    mask += [0.0] * (self.config["n_cands_before_rank"] - len(mask))

                lctx_ids = [
                    self.embeddings["word_voca"].get_id(t)
                    for t in lctx
                    if utils.is_important_word(t)
                ]
                lctx_ids = [
                    tid for tid in lctx_ids
                    if tid != self.embeddings["word_voca"].unk_id
                ]
                lctx_ids = lctx_ids[
                    max(0, len(lctx_ids) - self.config["ctx_window"] // 2) :
                ]

                rctx_ids = [
                    self.embeddings["word_voca"].get_id(t)
                    for t in rctx
                    if utils.is_important_word(t)
                ]
                rctx_ids = [
                    tid for tid in rctx_ids
                    if tid != self.embeddings["word_voca"].unk_id
                ]
                rctx_ids = rctx_ids[
                    : min(len(rctx_ids), self.config["ctx_window"] // 2)
                ]

                ment = m["mention"].strip().split()
                ment_ids = [
                    self.embeddings["word_voca"].get_id(t)
                    for t in ment
                    if utils.is_important_word(t)
                ]
                ment_ids = [
                    tid for tid in ment_ids
                    if tid != self.embeddings["word_voca"].unk_id
                ]

                m["sent"] = " ".join(lctx + rctx)

                # Secondary local context.
                snd_lctx = [self.embeddings["snd_voca"].get_id(t) for t in snd_lctx]
                snd_rctx = [self.embeddings["snd_voca"].get_id(t) for t in snd_rctx]
                snd_ment = [self.embeddings["snd_voca"].get_id(t) for t in snd_ment]

                # This is only used for the original embeddings; now they are
                # never empty.
                if len(snd_lctx) == 0:
                    snd_lctx = [self.embeddings["snd_voca"].unk_id]
                if len(snd_rctx) == 0:
                    snd_rctx = [self.embeddings["snd_voca"].unk_id]
                if len(snd_ment) == 0:
                    snd_ment = [self.embeddings["snd_voca"].unk_id]

                items.append(
                    {
                        "context": (lctx_ids, rctx_ids),
                        "snd_ctx": (snd_lctx, snd_rctx),
                        "ment_ids": ment_ids,
                        "snd_ment": snd_ment,
                        "cands": cands,
                        "named_cands": named_cands,
                        "p_e_m": p_e_m,
                        "mask": mask,
                        "true_pos": true_pos,
                        "doc_name": doc_name,
                        "raw": m,
                    }
                )

            if len(items) > 0:
                # Note: this shouldn't affect the order of predictions, because
                # we use doc_name to add predicted entities and we don't shuffle
                # the data for prediction.
                if len(items) > 100:
                    for k in range(0, len(items), 100):
                        data.append(items[k : min(len(items), k + 100)])
                else:
                    data.append(items)

        # Update batch.
        for n in ["word", "entity", "snd"]:
            if self.__batch_embs[n]:
                self.__batch_embs[n] = torch.stack(self.__batch_embs[n])
                self.__update_embeddings(n, self.__batch_embs[n])
                self.__batch_embs[n] = []

        return self.prerank(data, dname, predict)

    def __eval(self, testset, system_pred):
        """
        Responsible for evaluating data points, which is solely used for the
        local ED step.

        :return: F1, Recall, Precision and the number of mentions for which we
            have no valid candidate.
        """
        gold = []
        pred = []

        for doc_name, content in testset.items():
            if len(content) == 0:
                continue
            gold += [c["gold"][0] for c in content]
            pred += [c["pred"][0] for c in system_pred[doc_name]]

        true_pos = 0
        total_nil = 0
        for g, p in zip(gold, pred):
            if p == "NIL":
                total_nil += 1
            if g == p and p != "NIL":
                true_pos += 1

        precision = true_pos / len([p for p in pred if p != "NIL"])
        recall = true_pos / len(gold)
        f1 = 2 * precision * recall / (precision + recall)

        return f1, recall, precision, total_nil

    def __save(self, path):
        """
        Responsible for storing the trained model during optimisation.

        :return: -
        """
        torch.save(self.model.state_dict(), "{}.state_dict".format(path))
        with open("{}.config".format(path), "w") as f:
            json.dump(self.config, f)

    def __load(self, path):
        """
        Responsible for loading a trained model and its respective config. Note
        that this config cannot be overwritten; if required, this behavior may
        be modified in future releases.

        :return: model
        """
        if os.path.exists("{}.config".format(path)):
            with open("{}.config".format(path), "r") as f:
                temp = self.config["model_path"]
                self.config = json.load(f)
                self.config["model_path"] = temp
        else:
            print(
                "No configuration file found at {}, default settings will be used.".format(
                    "{}.config".format(path)
                )
            )

        model = MulRelRanker(self.config, self.device).to(self.device)

        if not torch.cuda.is_available():
            model.load_state_dict(
                torch.load(
                    "{}{}".format(self.config["model_path"], ".state_dict"),
                    map_location=torch.device("cpu"),
                )
            )
        else:
            model.load_state_dict(
                torch.load("{}{}".format(self.config["model_path"], ".state_dict"))
            )

        return model
from constant import invalid_relations_set
from REL.db.generic import GenericLookup

sqlite_path = "../../Documents/wiki_2020/generated"
emb = GenericLookup(
    "entity_word_embedding", save_dir=sqlite_path, table_name="embeddings"
)


def Map(head, relations, tail, top_first=True, best_scores=True):
    if head is None or tail is None or relations is None:
        return {}

    head_p_e_m = emb.wiki(str(head), "wiki")
    if head_p_e_m is None:
        return {}

    tail_p_e_m = emb.wiki(str(tail), "wiki")
    if tail_p_e_m is None:
        return {}

    tail_p_e_m = tail_p_e_m[0]
    head_p_e_m = head_p_e_m[0]

    valid_relations = [
        r
        for r in relations
        if r not in invalid_relations_set and r.isalpha() and len(r) > 1
    ]
    if len(valid_relations) == 0:
        return {}

    return {
        'h': head_p_e_m[0],
        't': tail_p_e_m[0],
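# A hedged example of calling Map: the triple below is hypothetical, and the
# exact keys of the returned dict depend on the (truncated) return statement
# above; 'h' and 't' hold the top p(e|m) candidates for head and tail.
triple = Map("Berlin", ["capital"], "Germany")
print(triple)  # e.g. {'h': 'Berlin', 't': 'Germany', ...}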
class MentionDetection:
    def __init__(self, base_url, wiki_subfolder):
        self.cnt_exact = 0
        self.cnt_partial = 0
        self.cnt_total = 0

        self.wiki_db = GenericLookup(
            "entity_word_embedding",
            "{}/{}/generated/".format(base_url, wiki_subfolder),
        )

    def _get_ctxt(self, start, end, idx_sent, sentence):
        """
        Retrieves the context surrounding a given mention, up to 100 words on
        either side.

        :return: left and right context
        """
        # Iteratively add words until we have 100.
        left_ctxt = split_in_words(sentence[:start])
        if idx_sent > 0:
            i = idx_sent - 1
            while (i >= 0) and (len(left_ctxt) <= 100):
                left_ctxt = split_in_words(self.sentences_doc[i]) + left_ctxt
                i -= 1
        left_ctxt = left_ctxt[-100:]
        left_ctxt = " ".join(left_ctxt)

        right_ctxt = split_in_words(sentence[end:])
        if idx_sent < len(self.sentences_doc):
            i = idx_sent + 1
            while (i < len(self.sentences_doc)) and (len(right_ctxt) <= 100):
                right_ctxt = right_ctxt + split_in_words(self.sentences_doc[i])
                i += 1
        right_ctxt = right_ctxt[:100]
        right_ctxt = " ".join(right_ctxt)

        return left_ctxt, right_ctxt

    def _get_candidates(self, mention):
        """
        Retrieves a maximum of 100 candidates from the sqlite3 database for a
        given mention.

        :return: list of candidates
        """
        # Performs an extra check for ED.
        cands = self.wiki_db.wiki(mention, "wiki")
        if cands:
            return cands[:100]
        else:
            return []

    def format_spans(self, dataset):
        """
        Responsible for formatting given spans into a dataset for the ED step.
        More specifically, it returns the mention, its left/right context and a
        set of candidates.

        :return: Dictionary with mentions per document.
        """
        dataset, _, _ = self.split_text(dataset)
        results = {}
        total_ment = 0

        for doc in dataset:
            contents = dataset[doc]
            self.sentences_doc = [v[0] for v in contents.values()]
            results_doc = []

            for idx_sent, (sentence, spans) in contents.items():
                for ngram, start_pos, end_pos in spans:
                    total_ment += 1

                    mention = preprocess_mention(ngram, self.wiki_db)
                    left_ctxt, right_ctxt = self._get_ctxt(
                        start_pos, end_pos, idx_sent, sentence
                    )
                    chosen_cands = self._get_candidates(mention)
                    res = {
                        "mention": mention,
                        "context": (left_ctxt, right_ctxt),
                        "candidates": chosen_cands,
                        "gold": ["NONE"],
                        "pos": start_pos,
                        "sent_idx": idx_sent,
                        "ngram": ngram,
                        "end_pos": end_pos,
                        "sentence": sentence,
                    }
                    results_doc.append(res)

            results[doc] = results_doc

        return results, total_ment

    def split_text(self, dataset):
        """
        Splits text into sentences. This behavior is required for the default
        NER-tagger, which performed better this way in our experiments.

        :return: dictionary with sentences and optional given spans per sentence.
        """
        res = {}
        splits = [0]
        processed_sentences = []
        for doc in dataset:
            text, spans = dataset[doc]
            sentences = split_single(text)
            res[doc] = {}

            i = 0
            for sent in sentences:
                if len(sent.strip()) == 0:
                    continue
                # Match gt to sentence.
                pos_start = text.find(sent)
                pos_end = pos_start + len(sent)

                # ngram, start_pos, end_pos
                spans_sent = [
                    [text[x[0] : x[0] + x[1]], x[0], x[0] + x[1]]
                    for x in spans
                    if pos_start <= x[0] < pos_end
                ]
                res[doc][i] = [sent, spans_sent]
                if len(spans) == 0:
                    processed_sentences.append(Sentence(sent, use_tokenizer=True))
                i += 1
            splits.append(splits[-1] + i)
        return res, processed_sentences, splits

    def find_mentions(self, dataset, tagger_ner=None):
        """
        Responsible for finding mentions, given a set of documents, in a
        batch-wise manner. More specifically, it returns the mention, its
        left/right context and a set of candidates.

        :return: Dictionary with mentions per document.
        """
        if tagger_ner is None:
            raise Exception(
                "No NER tagger is set, but you are attempting to perform Mention Detection."
            )
        dataset, processed_sentences, splits = self.split_text(dataset)
        results = {}
        total_ment = 0

        tagger_ner.predict(processed_sentences, mini_batch_size=32)

        for i, doc in enumerate(dataset):
            contents = dataset[doc]
            self.sentences_doc = [v[0] for v in contents.values()]
            sentences = processed_sentences[splits[i] : splits[i + 1]]
            result_doc = []

            for (idx_sent, (sentence, ground_truth_sentence)), snt in zip(
                contents.items(), sentences
            ):
                illegal = []
                for entity in snt.get_spans("ner"):
                    text, start_pos, end_pos, conf = (
                        entity.text,
                        entity.start_pos,
                        entity.end_pos,
                        entity.score,
                    )
                    total_ment += 1

                    m = preprocess_mention(text, self.wiki_db)
                    cands = self._get_candidates(m)

                    if len(cands) == 0:
                        continue

                    ngram = sentence[start_pos:end_pos]
                    illegal.extend(range(start_pos, end_pos))

                    left_ctxt, right_ctxt = self._get_ctxt(
                        start_pos, end_pos, idx_sent, sentence
                    )

                    res = {
                        "mention": m,
                        "context": (left_ctxt, right_ctxt),
                        "candidates": cands,
                        "gold": ["NONE"],
                        "pos": start_pos,
                        "sent_idx": idx_sent,
                        "ngram": ngram,
                        "end_pos": end_pos,
                        "sentence": sentence,
                        "conf_md": conf,
                        "tag": entity.tag,
                    }

                    result_doc.append(res)

            results[doc] = result_doc

        return results, total_ment
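# A minimal sketch of the batched mention-detection flow, assuming placeholder
# paths and flair's standard "ner-fast" tagger; any flair SequenceTagger works.
from flair.models import SequenceTagger

md = MentionDetection("/path/to/rel_data", "wiki_2019")
tagger_ner = SequenceTagger.load("ner-fast")

# One raw document per key: (text, spans); empty spans let the tagger decide.
docs = {"doc_1": ("Sergio Ramos plays for Real Madrid.", [])}

mentions, n_mentions = md.find_mentions(docs, tagger_ner=tagger_ner)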
def __init__(self, base_url, wiki_version):
    self.wiki_db = GenericLookup(
        "entity_word_embedding", os.path.join(base_url, wiki_version, "generated")
    )
class MentionDetectionBase:
    def __init__(self, base_url, wiki_version):
        self.wiki_db = GenericLookup(
            "entity_word_embedding", os.path.join(base_url, wiki_version, "generated")
        )

    def get_ctxt(self, start, end, idx_sent, sentence, sentences_doc):
        """
        Retrieves the context surrounding a given mention, up to 100 words on
        either side.

        :return: left and right context
        """
        # Iteratively add words until we have 100.
        left_ctxt = split_in_words(sentence[:start])
        if idx_sent > 0:
            i = idx_sent - 1
            while (i >= 0) and (len(left_ctxt) <= 100):
                left_ctxt = split_in_words(sentences_doc[i]) + left_ctxt
                i -= 1
        left_ctxt = left_ctxt[-100:]
        left_ctxt = " ".join(left_ctxt)

        right_ctxt = split_in_words(sentence[end:])
        if idx_sent < len(sentences_doc):
            i = idx_sent + 1
            while (i < len(sentences_doc)) and (len(right_ctxt) <= 100):
                right_ctxt = right_ctxt + split_in_words(sentences_doc[i])
                i += 1
        right_ctxt = right_ctxt[:100]
        right_ctxt = " ".join(right_ctxt)

        return left_ctxt, right_ctxt

    def get_candidates(self, mention):
        """
        Retrieves a maximum of 100 candidates from the sqlite3 database for a
        given mention.

        :return: list of candidates
        """
        # Performs an extra check for ED.
        cands = self.wiki_db.wiki(mention, "wiki")
        if cands:
            return cands[:100]
        else:
            return []

    def preprocess_mention(self, m):
        """
        Responsible for preprocessing a mention and making sure we find a set
        of matching candidates in our database.

        :return: mention
        """
        # TODO: This can be optimised (fewer db calls required).
        cur_m = modify_uppercase_phrase(m)
        freq_lookup_cur_m = self.wiki_db.wiki(cur_m, "wiki", "freq")

        if not freq_lookup_cur_m:
            cur_m = m

        freq_lookup_m = self.wiki_db.wiki(m, "wiki", "freq")
        freq_lookup_cur_m = self.wiki_db.wiki(cur_m, "wiki", "freq")

        if freq_lookup_m and (freq_lookup_m > freq_lookup_cur_m):
            # Cases like 'U.S.' are handled badly by modify_uppercase_phrase.
            cur_m = m
            freq_lookup_cur_m = self.wiki_db.wiki(cur_m, "wiki", "freq")

        # If we cannot find the exact mention in our index, we try our luck in
        # a case-insensitive index.
        if not freq_lookup_cur_m:
            # Neither cur_m nor m was found; check if a lower-cased version exists.
            find_lower = self.wiki_db.wiki(m.lower(), "wiki", "lower")

            if find_lower:
                cur_m = find_lower
                freq_lookup_cur_m = self.wiki_db.wiki(cur_m, "wiki", "freq")

        # Try to remove the first or last characters (e.g. 'Washington,' to
        # 'Washington'). To avoid errors, we only try this if no match was
        # found so far, as it might otherwise get in the way of 'U.S.'
        # converting to 'US'. This could be done recursively; interesting to
        # explore in future work.
        if not freq_lookup_cur_m:
            temp = re.sub(r"[\(.|,|!|')]", "", m).strip()
            simple_lookup = self.wiki_db.wiki(temp, "wiki", "freq")

            if simple_lookup:
                cur_m = temp

        return cur_m
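# A small sketch of the fallback chain in preprocess_mention, with hypothetical
# surface forms; the outcomes depend entirely on the generated sqlite database.
md_base = MentionDetectionBase("/path/to/rel_data", "wiki_2019")  # placeholder paths

# Uppercase phrases are normalised first, the raw form wins the frequency
# tie-break, and the case-insensitive and punctuation-stripped lookups act as
# last resorts (e.g. 'Washington,' -> 'Washington').
print(md_base.preprocess_mention("REAL MADRID"))
print(md_base.preprocess_mention("Washington,"))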
def __init__(self, base_url, wiki_version):
    self.wiki_db = GenericLookup(
        "entity_word_embedding",
        "{}/{}/generated/".format(base_url, wiki_version),
    )
class MentionDetection:
    """
    Class responsible for mention detection.
    """

    def __init__(self, base_url, wiki_subfolder):
        if isinstance(base_url, str):
            base_url = Path(base_url)
        self.cnt_exact = 0
        self.cnt_partial = 0
        self.cnt_total = 0

        self.wiki_db = GenericLookup(
            "entity_word_embedding",
            base_url / wiki_subfolder / "generated",
        )

    def split_text(self, dataset):
        """
        Splits text into sentences. This behavior is required for the default
        NER-tagger, which performed better this way in our experiments.

        :return: dictionary with sentences and optional given spans per sentence.
        """
        res = {}
        for doc in dataset:
            text, spans = dataset[doc]
            sentences = split_single(text)
            res[doc] = {}

            i = 0
            for sent in sentences:
                if len(sent.strip()) == 0:
                    continue
                # Match gt to sentence.
                pos_start = text.find(sent)
                pos_end = pos_start + len(sent)

                # ngram, start_pos, end_pos
                spans_sent = [
                    [text[x[0] : x[0] + x[1]], x[0], x[0] + x[1]]
                    for x in spans
                    if pos_start <= x[0] < pos_end
                ]
                res[doc][i] = [sent, spans_sent]
                i += 1
        return res

    def _get_ctxt(self, start, end, idx_sent, sentence):
        """
        Retrieves the context surrounding a given mention, up to 100 words on
        either side.

        :return: left and right context
        """
        # Iteratively add words until we have 100.
        left_ctxt = split_in_words(sentence[:start])
        if idx_sent > 0:
            i = idx_sent - 1
            while (i >= 0) and (len(left_ctxt) <= 100):
                left_ctxt = split_in_words(self.sentences_doc[i]) + left_ctxt
                i -= 1
        left_ctxt = left_ctxt[-100:]
        left_ctxt = " ".join(left_ctxt)

        right_ctxt = split_in_words(sentence[end:])
        if idx_sent < len(self.sentences_doc):
            i = idx_sent + 1
            while (i < len(self.sentences_doc)) and (len(right_ctxt) <= 100):
                right_ctxt = right_ctxt + split_in_words(self.sentences_doc[i])
                i += 1
        right_ctxt = right_ctxt[:100]
        right_ctxt = " ".join(right_ctxt)

        return left_ctxt, right_ctxt

    def _get_candidates(self, mention, top_n=100):
        """
        Retrieves a maximum of top_n candidates from the sqlite3 database for a
        given mention.

        :param top_n: number of candidates to return
        :return: list of candidates
        """
        # Performs an extra check for ED.
        # TODO: add `LIMIT n` to the SQL query for better performance.
        candidates = self.wiki_db.wiki(mention, "wiki")
        if candidates:
            return candidates[:top_n]
        else:
            return []

    def format_spans(self, dataset):
        """
        Responsible for formatting given spans into a dataset for the ED step.
        More specifically, it returns the mention, its left/right context and a
        set of candidates.

        :return: Dictionary with mentions per document.
        """
        dataset = self.split_text(dataset)
        results = {}
        total_ment = 0

        for doc in dataset:
            contents = dataset[doc]
            self.sentences_doc = [v[0] for v in contents.values()]
            results_doc = []

            for idx_sent, (sentence, spans) in contents.items():
                for ngram, start_pos, end_pos in spans:
                    total_ment += 1

                    mention = preprocess_mention(ngram, self.wiki_db)
                    left_ctxt, right_ctxt = self._get_ctxt(
                        start_pos, end_pos, idx_sent, sentence
                    )
                    chosen_cands = self._get_candidates(mention)
                    res = {
                        "mention": mention,
                        "context": (left_ctxt, right_ctxt),
                        "candidates": chosen_cands,
                        "gold": ["NONE"],
                        "pos": start_pos,
                        "sent_idx": idx_sent,
                        "ngram": ngram,
                        "end_pos": end_pos,
                        "sentence": sentence,
                    }
                    results_doc.append(res)

            results[doc] = results_doc

        return results, total_ment

    def find_mentions(self, dataset, tagger_ner=None):
        """
        Responsible for finding mentions given a set of documents. More
        specifically, it returns the mention, its left/right context and a set
        of candidates.

        :return: Dictionary with mentions per document.
        """
        if tagger_ner is None:
            raise Exception(
                "No NER tagger is set, but you are attempting to perform Mention Detection."
            )
        dataset = self.split_text(dataset)
        results = {}
        total_ment = 0

        for doc in dataset:
            contents = dataset[doc]
            self.sentences_doc = [v[0] for v in contents.values()]
            result_doc = []

            sentences = [
                Sentence(v[0], use_tokenizer=True) for k, v in contents.items()
            ]
            tagger_ner.predict(sentences)

            for (idx_sent, (sentence, ground_truth_sentence)), snt in zip(
                contents.items(), sentences
            ):
                illegal = []
                for entity in snt.get_spans("ner"):
                    text, start_pos, end_pos, conf = (
                        entity.text,
                        entity.start_pos,
                        entity.end_pos,
                        entity.score,
                    )
                    total_ment += 1

                    m = preprocess_mention(text, self.wiki_db)
                    cands = self._get_candidates(m)

                    if len(cands) == 0:
                        continue

                    ngram = sentence[start_pos:end_pos]
                    illegal.extend(range(start_pos, end_pos))

                    left_ctxt, right_ctxt = self._get_ctxt(
                        start_pos, end_pos, idx_sent, sentence
                    )

                    res = {
                        "mention": m,
                        "context": (left_ctxt, right_ctxt),
                        "candidates": cands,
                        "gold": ["NONE"],
                        "pos": start_pos,
                        "sent_idx": idx_sent,
                        "ngram": ngram,
                        "end_pos": end_pos,
                        "sentence": sentence,
                        "conf_md": conf,
                    }

                    result_doc.append(res)

            results[doc] = result_doc

        return results, total_ment