def test_edge_case(self):
    """Edge case: aliases that are longer than ``max_seq_len``.

    Each alias ends up in its own split. The phrase window is truncated to
    ``max_seq_len`` tokens while the returned span keeps the alias's full
    length (``[0, 5]`` for the five-word second alias).
    """
    max_aliases = 30
    max_seq_len = 3

    # Manual data
    sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
    aliases = ["The big alias1", "multi word alias2 and alias3"]
    aliases_to_predict = [0, 1]
    spans = [[0, 3], [8, 13]]

    # Run function
    args = parser_utils.get_full_config("test/run_args/test_data.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)

    # True data
    true_phrase_arr = ["The big alias1".split(), "multi word alias2".split()]
    true_spans_arr = [[[0, 3]], [[0, 5]]]
    true_alias_to_predict_arr = [[0], [0]]
    true_aliases_arr = [["The big alias1"], ["multi word alias2 and alias3"]]

    # unittest assertions, unlike bare ``assert``, are not stripped by ``python -O``.
    self.assertEqual(len(idxs_arr), 2)
    self.assertEqual(len(aliases_to_predict_arr), 2)
    self.assertEqual(len(spans_arr), 2)
    self.assertEqual(len(phrase_tokens_arr), 2)
    for i, idxs in enumerate(idxs_arr):
        with self.subTest(split=i):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i], true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs], true_aliases_arr[i])
def test_split_sentence_alias_to_predict(self):
    """No splitting, but only some aliases are predicted.

    The window fits the whole sentence, so the single split keeps every alias
    and ``aliases_to_predict`` passes through unchanged.
    """
    max_aliases = 30
    max_seq_len = 24

    # Manually created data
    sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
    aliases = ["The big", "alias3", "alias5"]
    aliases_to_predict = [0, 1]
    spans = [[0, 2], [12, 13], [20, 21]]

    # Run function
    args = parser_utils.get_full_config("test/run_args/test_data.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)

    # Truth data
    true_phrase_arr = [
        "The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5 <pad> <pad> <pad>"
        .split(" ")
    ]
    true_spans_arr = [[[0, 2], [12, 13], [20, 21]]]
    true_alias_to_predict_arr = [[0, 1]]
    true_aliases_arr = [["The big", "alias3", "alias5"]]

    # unittest assertions, unlike bare ``assert``, are not stripped by ``python -O``.
    self.assertEqual(len(idxs_arr), 1)
    self.assertEqual(len(aliases_to_predict_arr), 1)
    self.assertEqual(len(spans_arr), 1)
    self.assertEqual(len(phrase_tokens_arr), 1)
    for i, idxs in enumerate(idxs_arr):
        with self.subTest(split=i):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i], true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs], true_aliases_arr[i])
def test_split_sentence_max_aliases(self):
    """Split a sentence whose alias count exceeds ``max_aliases``.

    With ``max_aliases=2`` and three aliases, the first split holds the first
    two aliases and the second split holds the last one. Both splits reuse the
    full padded phrase.
    """
    max_aliases = 2
    max_seq_len = 24

    # Manually created data
    sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
    aliases = ["The big", "alias3", "alias5"]
    aliases_to_predict = [0, 1, 2]
    spans = [[0, 2], [12, 13], [20, 21]]

    # Run function
    args = parser_utils.get_full_config("test/run_args/test_data.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)

    # True data
    true_phrase_arr = [
        "The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5 <pad> <pad> <pad>"
        .split(" ")
    ] * 2
    true_spans_arr = [[[0, 2], [12, 13]], [[20, 21]]]
    true_alias_to_predict_arr = [[0, 1], [0]]
    true_aliases_arr = [["The big", "alias3"], ["alias5"]]

    # unittest assertions, unlike bare ``assert``, are not stripped by ``python -O``.
    self.assertEqual(len(idxs_arr), 2)
    self.assertEqual(len(aliases_to_predict_arr), 2)
    self.assertEqual(len(spans_arr), 2)
    self.assertEqual(len(phrase_tokens_arr), 2)
    for i, idxs in enumerate(idxs_arr):
        with self.subTest(split=i):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i], true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs], true_aliases_arr[i])
def label_mentions(self, text):
    """Extract mentions from ``text`` and disambiguate them with the model.

    Only sentences that ``sentence_utils.split_sentence`` keeps in a single
    split are supported.

    Args:
        text: raw sentence to annotate.

    Returns:
        Tuple ``(pred_cands, pred_probs, titles)`` of parallel lists. For each
        mention whose top probability is above ``self.threshold`` they hold
        the predicted QID, its probability, and its title. The title is
        ``'NC'`` when the predicted QID is ``'NC'``. All three lists are empty
        when ``split_sentence`` returns no split.

    Raises:
        ValueError: if the sentence overflows into more than one split.
    """
    sample = self.extract_mentions(text)
    idxs_arr, _, spans_arr, phrase_tokens_arr = sentence_utils.split_sentence(
        max_aliases=self.args.data_config.max_aliases,
        phrase=sample['sentence'],
        spans=sample['spans'],
        aliases=sample['aliases'],
        aliases_seen_by_model=list(range(len(sample['aliases']))),
        seq_len=self.args.data_config.max_word_token_len,
        word_symbols=self.word_db)

    if len(idxs_arr) > 1:
        # TODO: support sentences that overflow due to long sequence length or too many mentions
        raise ValueError(
            'Overflowing sentences not currently supported in Annotator')

    pred_cands = []
    pred_probs = []
    titles = []
    if not idxs_arr:
        # Nothing to disambiguate; return empty results instead of falling
        # through with undefined outputs.
        return pred_cands, pred_probs, titles

    max_aliases = self.args.data_config.max_aliases
    # After the overflow check there is exactly one split.
    aliases = [sample['aliases'][idx] for idx in idxs_arr[0]]
    split_spans = spans_arr[0]
    word_indices = self.word_db.convert_tokens_to_ids(phrase_tokens_arr[0])

    # Per-slot model inputs; slots without a mention keep PAD_ID.
    # ``np.int`` was removed in NumPy 1.24; it was an alias of the builtin int.
    example_aliases = np.ones(max_aliases, dtype=int) * PAD_ID
    example_aliases_locs_start = np.ones(max_aliases) * PAD_ID
    example_aliases_locs_end = np.ones(max_aliases) * PAD_ID
    for mention_idx, alias in enumerate(aliases):
        # Index into the alias table.
        example_aliases[mention_idx] = self.entity_db.get_alias_idx(alias)
        # alias_idx_pair: start/end word positions of the mention.
        span_start_idx, span_end_idx = split_spans[mention_idx]
        example_aliases_locs_start[mention_idx] = span_start_idx
        example_aliases_locs_end[mention_idx] = span_end_idx

    # entity indices from alias table (these are the candidates)
    entity_indices = self.alias_table(example_aliases)

    # All CPU embeddings have to be retrieved on the fly.
    batch_on_the_fly_data = {}
    for emb_name, emb in self.batch_on_the_fly_embs.items():
        batch_on_the_fly_data[emb_name] = torch.tensor(
            emb.batch_prep(example_aliases, entity_indices), device=self.device)

    outs, entity_pack, _ = self.model(
        alias_idx_pair_sent=[
            torch.tensor(example_aliases_locs_start, device=self.device).unsqueeze(0),
            torch.tensor(example_aliases_locs_end, device=self.device).unsqueeze(0)
        ],
        word_indices=torch.tensor([word_indices], device=self.device),
        alias_indices=torch.tensor(example_aliases, device=self.device).unsqueeze(0),
        entity_indices=torch.tensor(entity_indices, device=self.device).unsqueeze(0),
        batch_prepped_data={},
        batch_on_the_fly_data=batch_on_the_fly_data)

    entity_cands = eval_utils.map_aliases_to_candidates(
        self.args.data_config.train_in_candidates, self.entity_db, aliases)

    # recover predictions
    probs = torch.exp(
        eval_utils.masked_class_logsoftmax(
            pred=outs[DISAMBIG][FINAL_LOSS], mask=~entity_pack.mask, dim=2))
    max_probs, max_probs_indices = probs.max(2)

    for alias_idx in range(len(aliases)):
        pred_idx = max_probs_indices[0][alias_idx]
        pred_prob = max_probs[0][alias_idx].item()
        pred_qid = entity_cands[alias_idx][pred_idx]
        if pred_prob > self.threshold:
            pred_cands.append(pred_qid)
            pred_probs.append(pred_prob)
            titles.append(
                self.entity_db.get_title(pred_qid) if pred_qid != 'NC' else 'NC')

    return pred_cands, pred_probs, titles
def label_mentions(self, text_list, label_func=find_aliases_in_sentence_tag):
    """Extracts mentions and runs disambiguation.

    Args:
        text_list: list of text to disambiguate (or single sentence)
        label_func: mention extraction function (optional)

    Returns: Dict of

        * ``qids``: final predicted QIDs,
        * ``probs``: final predicted probs,
        * ``titles``: final predicted titles,
        * ``cands``: all entity candidates,
        * ``cand_probs``: probabilities of all candidates,
        * ``spans``: final extracted word spans,
        * ``aliases``: final extracted aliases,

        Each value is a list with one entry per input text. Each entry only
        contains the mentions whose top probability is above
        ``self.threshold``.
    """
    if isinstance(text_list, str):
        text_list = [text_list]
    else:
        assert (isinstance(text_list, list) and len(text_list) > 0
                and isinstance(text_list[0], str)
                ), "We only accept inputs of strings and lists of strings"

    ebs = int(self.config.run_config.eval_batch_size)
    self.config.data_config.max_aliases = int(
        self.config.data_config.max_aliases)
    max_aliases = self.config.data_config.max_aliases
    train_in_candidates = self.config.data_config.train_in_candidates

    total_start_exs = 0
    total_final_exs = 0
    dropped_by_thresh = 0

    # NOTE(review): collected per text but never returned or read below.
    final_char_spans = []

    # One entry per (text, split) pair; together they form the model input.
    batch_example_aliases = []
    batch_example_aliases_locs_start = []
    batch_example_aliases_locs_end = []
    batch_example_true_entities = []
    batch_word_indices = []
    batch_spans_arr = []
    batch_aliases_arr = []
    batch_idx_unq = []

    for idx_unq, text in tqdm(
            enumerate(text_list),
            desc="Prepping data",
            total=len(text_list),
            disable=not self.verbose,
    ):
        sample = self.extract_mentions(text, label_func)
        total_start_exs += len(sample["aliases"])
        char_spans = self.get_char_spans(sample["spans"], text)
        final_char_spans.append(char_spans)

        (
            idxs_arr,
            aliases_to_predict_per_split,
            spans_arr,
            phrase_tokens_arr,
            _pos_idxs,
        ) = sentence_utils.split_sentence(
            max_aliases=max_aliases,
            phrase=sample["sentence"],
            spans=sample["spans"],
            aliases=sample["aliases"],
            aliases_seen_by_model=list(range(len(sample["aliases"]))),
            seq_len=self.config.data_config.max_seq_len,
            is_bert=True,
            tokenizer=self.tokenizer,
        )
        aliases_arr = [[sample["aliases"][idx] for idx in idxs]
                       for idxs in idxs_arr]
        old_spans_arr = [[sample["spans"][idx] for idx in idxs]
                         for idxs in idxs_arr]
        qids_arr = [[sample["qids"][idx] for idx in idxs]
                    for idxs in idxs_arr]
        word_indices_arr = [
            self.tokenizer.convert_tokens_to_ids(pt) for pt in phrase_tokens_arr
        ]

        # iterate over each sample in the split
        for sub_idx in range(len(idxs_arr)):
            # ====================================================
            # GENERATE MODEL INPUTS
            # ====================================================
            # A split with no aliases to predict adds no outputs below.
            aliases_to_predict_arr = aliases_to_predict_per_split[sub_idx]
            assert (
                len(aliases_arr[sub_idx]) <= max_aliases
            ), f"{sample} should have no more than {max_aliases} aliases."

            # Per-slot inputs; slots without a mention keep PAD_ID.
            example_aliases = np.ones(max_aliases) * PAD_ID
            example_aliases_locs_start = np.ones(max_aliases) * PAD_ID
            example_aliases_locs_end = np.ones(max_aliases) * PAD_ID
            example_true_entities = np.ones(max_aliases) * PAD_ID

            for mention_idx, alias in enumerate(aliases_arr[sub_idx]):
                span_start_idx, span_end_idx = spans_arr[sub_idx][mention_idx]
                # generate indexes into alias table.
                alias_trie_idx = self.entity_db.get_alias_idx(alias)
                alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                gold_qid = qids_arr[sub_idx][mention_idx]
                if gold_qid not in alias_qids:
                    # Without train_in_candidates, label 0 is the
                    # "not in candidate set" class; otherwise use -2.
                    true_entity_idx = 0 if not train_in_candidates else -2
                else:
                    # Position of the gold QID among the candidates. Shift by
                    # one when label 0 is reserved for "not in candidate set".
                    true_entity_idx = np.nonzero(
                        alias_qids == gold_qid)[0][0] + (not train_in_candidates)
                example_aliases[mention_idx] = alias_trie_idx
                example_aliases_locs_start[mention_idx] = span_start_idx
                # The span_idxs are [start, end). We want [start, end]. So subtract 1 from end idx.
                example_aliases_locs_end[mention_idx] = span_end_idx - 1
                # Only aliases this split should predict get a label. The rest
                # stay PAD_ID and are skipped when reading predictions. This
                # happens when a sentence is split and each split predicts a
                # subset.
                if mention_idx in aliases_to_predict_arr:
                    example_true_entities[mention_idx] = true_entity_idx

            batch_example_aliases.append(example_aliases)
            batch_example_aliases_locs_start.append(example_aliases_locs_start)
            batch_example_aliases_locs_end.append(example_aliases_locs_end)
            batch_example_true_entities.append(example_true_entities)
            batch_word_indices.append(word_indices_arr[sub_idx])
            batch_aliases_arr.append(aliases_arr[sub_idx])
            # Add the original sample spans because spans_arr is w.r.t. BERT subword tokens.
            batch_spans_arr.append(old_spans_arr[sub_idx])
            batch_idx_unq.append(idx_unq)

    batch_example_aliases = torch.tensor(batch_example_aliases).long()
    batch_example_aliases_locs_start = torch.tensor(
        batch_example_aliases_locs_start, device=self.torch_device)
    batch_example_aliases_locs_end = torch.tensor(
        batch_example_aliases_locs_end, device=self.torch_device)
    batch_example_true_entities = torch.tensor(batch_example_true_entities,
                                               device=self.torch_device)
    batch_word_indices = torch.tensor(batch_word_indices,
                                      device=self.torch_device)

    num_texts = len(text_list)
    final_pred_cands = [[] for _ in range(num_texts)]
    final_all_cands = [[] for _ in range(num_texts)]
    final_cand_probs = [[] for _ in range(num_texts)]
    final_pred_probs = [[] for _ in range(num_texts)]
    final_titles = [[] for _ in range(num_texts)]
    final_spans = [[] for _ in range(num_texts)]
    final_aliases = [[] for _ in range(num_texts)]

    # Fetched once instead of per example. Assumes the entity DB is not
    # modified during annotation.
    alias2qids = self.entity_db.get_alias2qids()

    for b_i in tqdm(
            range(0, batch_example_aliases.shape[0], ebs),
            desc="Evaluating model",
            disable=not self.verbose,
    ):
        start_span_idx = batch_example_aliases_locs_start[b_i:b_i + ebs]
        end_span_idx = batch_example_aliases_locs_end[b_i:b_i + ebs]
        word_indices = batch_word_indices[b_i:b_i + ebs]
        alias_indices = batch_example_aliases[b_i:b_i + ebs]
        x_dict = self.get_forward_batch(start_span_idx, end_span_idx,
                                        word_indices, alias_indices)
        # The final batch may hold fewer than ebs examples, so size the guids
        # to the examples actually in this batch.
        x_dict["guid"] = torch.arange(b_i,
                                      b_i + alias_indices.shape[0],
                                      device=self.torch_device)

        (_, _, prob_bdict, _) = self.model(  # type: ignore
            uids=x_dict["guid"],
            X_dict=x_dict,
            Y_dict=None,
            task_to_label_dict=self.task_to_label_dict,
            return_action_outputs=False,
        )
        # ====================================================
        # EVALUATE MODEL OUTPUTS
        # ====================================================
        # recover predictions
        probs = prob_bdict[NED_TASK]
        max_probs = probs.max(2)
        max_probs_indices = probs.argmax(2)
        for ex_i in range(probs.shape[0]):
            example_idx = b_i + ex_i
            idx_unq = batch_idx_unq[example_idx]
            ex_aliases = batch_aliases_arr[example_idx]
            entity_cands = eval_utils.map_aliases_to_candidates(
                train_in_candidates,
                max_aliases,
                alias2qids,
                ex_aliases,
            )
            # This example's probabilities as (max_aliases, num_candidates).
            probs_ex = probs[ex_i].reshape(max_aliases, probs.shape[2])
            for alias_idx, true_entity_pos_idx in enumerate(
                    batch_example_true_entities[example_idx]):
                # PAD_ID marks slots this split does not predict.
                if true_entity_pos_idx == PAD_ID:
                    continue
                pred_idx = max_probs_indices[ex_i][alias_idx]
                pred_prob = max_probs[ex_i][alias_idx].item()
                all_cands = entity_cands[alias_idx]
                pred_qid = all_cands[pred_idx]
                if pred_prob > self.threshold:
                    final_all_cands[idx_unq].append(all_cands)
                    final_cand_probs[idx_unq].append(probs_ex[alias_idx])
                    final_pred_cands[idx_unq].append(pred_qid)
                    final_pred_probs[idx_unq].append(pred_prob)
                    final_aliases[idx_unq].append(ex_aliases[alias_idx])
                    final_spans[idx_unq].append(
                        batch_spans_arr[example_idx][alias_idx])
                    final_titles[idx_unq].append(
                        self.entity_db.get_title(pred_qid)
                        if pred_qid != "NC" else "NC")
                    total_final_exs += 1
                else:
                    dropped_by_thresh += 1

    assert total_final_exs + dropped_by_thresh == total_start_exs, (
        f"Something went wrong and we have predicted fewer mentions than extracted. "
        f"Start {total_start_exs}, Out {total_final_exs}, No cand {dropped_by_thresh}"
    )
    res_dict = {
        "qids": final_pred_cands,
        "probs": final_pred_probs,
        "titles": final_titles,
        "cands": final_all_cands,
        "cand_probs": final_cand_probs,
        "spans": final_spans,
        "aliases": final_aliases,
    }
    return res_dict
def test_real_cases_bert(self):
    """Split a real sentence tokenized into BERT word pieces.

    With 12 aliases, ``max_aliases=10`` and ``max_seq_len=100``, the sentence
    splits into three windows that share aliases. Each window has
    ``max_seq_len + 2`` tokens, including ``[CLS]`` and ``[SEP]``.
    """
    # Example 1
    max_aliases = 10
    max_seq_len = 100

    # Manual data
    sentence = "The guest roster for O'Brien 's final show on January 22\u2014 Tom Hanks , Steve Carell and original first guest Will Ferrell \u2014was regarded by O'Brien as a `` dream lineup '' ; in addition , Neil Young performed his song `` Long May You Run `` and , as the show closed , was joined by Beck , Ferrell ( dressed as Ronnie Van Zant ) , Billy Gibbons , Ben Harper , O'Brien , Viveca Paulin , and The Tonight Show Band to perform the Lynyrd Skynyrd song `` Free Bird `` ."
    aliases = [
        "tom hanks", "steve carell", "will ferrell", "neil young",
        "long may you run", "beck", "ronnie van zant", "billy gibbons",
        "ben harper", "viveca paulin", "lynyrd skynyrd", "free bird"
    ]
    spans = [[11, 13], [14, 16], [20, 22], [36, 38], [42, 46], [57, 58],
             [63, 66], [68, 70], [71, 73], [76, 78], [87, 89], [91, 93]]
    aliases_to_predict = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

    # Truth
    # The first two windows cover exactly the same word pieces.
    first_window = [
        '[CLS]', 'The', 'guest', 'roster', 'for', 'O', "'", 'Brien', "'",
        's', 'final', 'show', 'on', 'January', '22', '—', 'Tom', 'Hank',
        '##s', ',', 'Steve', 'Care', '##ll', 'and', 'original', 'first',
        'guest', 'Will', 'Fe', '##rrell', '—', 'was', 'regarded', 'by', 'O',
        "'", 'Brien', 'as', 'a', '`', '`', 'dream', 'lineup', "'", "'", ';',
        'in', 'addition', ',', 'Neil', 'Young', 'performed', 'his', 'song',
        '`', '`', 'Long', 'May', 'You', 'Run', '`', '`', 'and', ',', 'as',
        'the', 'show', 'closed', ',', 'was', 'joined', 'by', 'Beck', ',',
        'Fe', '##rrell', '(', 'dressed', 'as', 'Ronnie', 'Van', 'Z', '##ant',
        ')', ',', 'Billy', 'Gibbons', ',', 'Ben', 'Harper', ',', 'O', "'",
        'Brien', ',', 'V', '##ive', '##ca', 'Paul', '##in', ',', '[SEP]'
    ]
    third_window = [
        '[CLS]', 'original', 'first', 'guest', 'Will', 'Fe', '##rrell', '—',
        'was', 'regarded', 'by', 'O', "'", 'Brien', 'as', 'a', '`', '`',
        'dream', 'lineup', "'", "'", ';', 'in', 'addition', ',', 'Neil',
        'Young', 'performed', 'his', 'song', '`', '`', 'Long', 'May', 'You',
        'Run', '`', '`', 'and', ',', 'as', 'the', 'show', 'closed', ',',
        'was', 'joined', 'by', 'Beck', ',', 'Fe', '##rrell', '(', 'dressed',
        'as', 'Ronnie', 'Van', 'Z', '##ant', ')', ',', 'Billy', 'Gibbons',
        ',', 'Ben', 'Harper', ',', 'O', "'", 'Brien', ',', 'V', '##ive',
        '##ca', 'Paul', '##in', ',', 'and', 'The', 'Tonight', 'Show', 'Band',
        'to', 'perform', 'the', 'L', '##yn', '##yr', '##d', 'Sky', '##ny',
        '##rd', 'song', '`', '`', 'Free', 'Bird', '`', '`', '.', '[SEP]'
    ]
    true_phrase_arr = [first_window, first_window, third_window]
    true_spans_arr = [[[12, 14], [15, 17], [23, 25], [40, 42], [46, 50],
                       [61, 62], [67, 70], [72, 74]],
                      [[46, 50], [61, 62], [67, 70], [72, 74], [76, 78],
                       [81, 84], [93, 95], [100, 102]],
                      [[17, 19], [23, 27], [38, 39], [44, 47], [49, 51],
                       [53, 55], [58, 61], [70, 72], [77, 79]]]
    true_alias_to_predict_arr = [[0, 1, 2, 3, 4, 5], [2, 3, 4, 5], [7, 8]]
    true_aliases_arr = [[
        "tom hanks", "steve carell", "will ferrell", "neil young",
        "long may you run", "beck", "ronnie van zant", "billy gibbons"
    ], [
        "long may you run", "beck", "ronnie van zant", "billy gibbons",
        "ben harper", "viveca paulin", "lynyrd skynyrd", "free bird"
    ], [
        "neil young", "long may you run", "beck", "ronnie van zant",
        "billy gibbons", "ben harper", "viveca paulin", "lynyrd skynyrd",
        "free bird"
    ]]

    # Run function
    args = parser_utils.get_full_config("test/run_args/test_data_bert.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)
    idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
        max_aliases, sentence, spans, aliases, aliases_to_predict,
        max_seq_len, word_symbols)

    # unittest assertions, unlike bare ``assert``, are not stripped by ``python -O``.
    self.assertEqual(len(idxs_arr), 3)
    self.assertEqual(len(aliases_to_predict_arr), 3)
    self.assertEqual(len(spans_arr), 3)
    self.assertEqual(len(phrase_tokens_arr), 3)
    for i, idxs in enumerate(idxs_arr):
        with self.subTest(split=i):
            # BERT windows add [CLS] and [SEP] on top of max_seq_len.
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i], true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs], true_aliases_arr[i])
def test_split_sentence_bert(self):
    """Split sentences into BERT word-piece windows of ``max_seq_len + 2`` tokens."""
    # Loaded once: every example below uses the same BERT word symbols.
    args = parser_utils.get_full_config("test/run_args/test_data_bert.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)

    def run_and_check(max_aliases, max_seq_len, sentence, aliases, spans,
                      aliases_to_predict, true_phrase_arr, true_spans_arr,
                      true_alias_to_predict_arr, true_aliases_arr):
        """Run split_sentence and compare every split against the truth."""
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)
        num_splits = len(true_aliases_arr)
        self.assertEqual(len(idxs_arr), num_splits)
        self.assertEqual(len(aliases_to_predict_arr), num_splits)
        self.assertEqual(len(spans_arr), num_splits)
        self.assertEqual(len(phrase_tokens_arr), num_splits)
        for i, idxs in enumerate(idxs_arr):
            with self.subTest(max_seq_len=max_seq_len, split=i):
                # BERT windows add [CLS] and [SEP] on top of max_seq_len.
                self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
                self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
                self.assertEqual(spans_arr[i], true_spans_arr[i])
                self.assertEqual(aliases_to_predict_arr[i], true_alias_to_predict_arr[i])
                self.assertEqual([aliases[idx] for idx in idxs], true_aliases_arr[i])

    # Example 1: the whole sentence fits in one window.
    bert_tokenized = [
        'Kit', '##tens', 'love', 'purple', '##ish', 'pu', '##pp', '##pet',
        '##eers', 'because', 'alias', '##2', 'and', 'spanning', 'the', 'br',
        '##rea', '##ches', 'alias', '##5'
    ]
    run_and_check(
        max_aliases=30,
        max_seq_len=20,
        sentence='Kittens love purpleish pupppeteers because alias2 and spanning the brreaches alias5',
        aliases=["Kittens love", "alias2", "alias5"],
        spans=[[0, 2], [5, 6], [10, 11]],
        aliases_to_predict=[0, 1, 2],
        true_phrase_arr=[['[CLS]'] + bert_tokenized + ['[SEP]']],
        true_spans_arr=[[[1, 4], [11, 13], [19, 21]]],
        true_alias_to_predict_arr=[[0, 1, 2]],
        true_aliases_arr=[["Kittens love", "alias2", "alias5"]])

    # Example 2: a short window puts each alias in its own split.
    run_and_check(
        max_aliases=30,
        max_seq_len=7,
        sentence='Kittens love purpleish pupppeteers because alias2 and spanning the brreaches alias5',
        aliases=["Kittens love", "alias2", "alias5"],
        spans=[[0, 2], [5, 6], [10, 11]],
        aliases_to_predict=[0, 1, 2],
        true_phrase_arr=[[
            '[CLS]', 'Kit', '##tens', 'love', 'purple', '##ish', 'pu', '##pp',
            '[SEP]'
        ], [
            '[CLS]', '##eers', 'because', 'alias', '##2', 'and', 'spanning',
            'the', '[SEP]'
        ], [
            '[CLS]', 'spanning', 'the', 'br', '##rea', '##ches', 'alias', '##5',
            '[SEP]'
        ]],
        true_spans_arr=[[[1, 4]], [[3, 5]], [[6, 8]]],
        true_alias_to_predict_arr=[[0], [0], [0]],
        true_aliases_arr=[["Kittens love"], ["alias2"], ["alias5"]])

    # Example 3: the algorithm is greedy. It packs the first two aliases into
    # the first split, so the last alias goes to a second split. The second
    # alias also appears in the second split but is only predicted in the first.
    run_and_check(
        max_aliases=30,
        max_seq_len=18,
        sentence='Kittens Kittens Kittens Kittens love purpleish pupppeteers because alias2 and spanning the brreaches alias5',
        aliases=["Kittens love", "alias2", "alias5"],
        spans=[[3, 5], [8, 9], [13, 14]],
        aliases_to_predict=[0, 1, 2],
        true_phrase_arr=[[
            '[CLS]', '##tens', 'Kit', '##tens', 'Kit', '##tens', 'love',
            'purple', '##ish', 'pu', '##pp', '##pet', '##eers', 'because',
            'alias', '##2', 'and', 'spanning', 'the', '[SEP]'
        ], [
            '[CLS]', 'love', 'purple', '##ish', 'pu', '##pp', '##pet',
            '##eers', 'because', 'alias', '##2', 'and', 'spanning', 'the',
            'br', '##rea', '##ches', 'alias', '##5', '[SEP]'
        ]],
        true_spans_arr=[[[4, 7], [14, 16]], [[9, 11], [17, 19]]],
        true_alias_to_predict_arr=[[0, 1], [1]],
        true_aliases_arr=[["Kittens love", "alias2"], ["alias2", "alias5"]])
def test_real_cases(self):
    """Regression tests on real examples the splitter previously got wrong."""
    # Loaded once: every example below uses the same word symbols.
    args = parser_utils.get_full_config("test/run_args/test_data.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)

    def run_and_check(max_aliases, max_seq_len, sentence, aliases, spans,
                      aliases_to_predict, true_phrase_arr, true_spans_arr,
                      true_alias_to_predict_arr, true_aliases_arr):
        """Run split_sentence and compare every split against the truth."""
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)
        num_splits = len(true_aliases_arr)
        self.assertEqual(len(idxs_arr), num_splits)
        self.assertEqual(len(aliases_to_predict_arr), num_splits)
        self.assertEqual(len(spans_arr), num_splits)
        self.assertEqual(len(phrase_tokens_arr), num_splits)
        for i, idxs in enumerate(idxs_arr):
            with self.subTest(sentence=sentence[:30], split=i):
                self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
                self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
                self.assertEqual(spans_arr[i], true_spans_arr[i])
                self.assertEqual(aliases_to_predict_arr[i], true_alias_to_predict_arr[i])
                self.assertEqual([aliases[idx] for idx in idxs], true_aliases_arr[i])

    # EXAMPLE 1
    # Source row: 3114|0~*~1~*~2~*~3~*~4~*~5|mexico~*~panama~*~ecuador~*~peru~*~bolivia~*~colombia|3966054~*~22997~*~9334~*~170691~*~3462~*~5222|19:20~*~36:37~*~39:40~*~44:45~*~48:49~*~70:71|<sentence below>
    run_and_check(
        max_aliases=30,
        max_seq_len=50,
        sentence='The animal is called paca in most of its range but tepezcuintle original Aztec language name in most of Mexico and Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media del Río Lebrija en Rionegro Santander Humboldt Institute or guartinaja in Colombia',
        aliases=["mexico", "panama", "ecuador", "peru", "bolivia", "colombia"],
        spans=[[19, 20], [36, 37], [39, 40], [44, 45], [48, 49], [70, 71]],
        aliases_to_predict=[0, 1, 2, 3, 4, 5],
        true_phrase_arr=[
            'range but tepezcuintle original Aztec language name in most of Mexico and Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media'
            .split(),
            'Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media del Río Lebrija en Rionegro Santander Humboldt Institute or guartinaja in Colombia'
            .split()
        ],
        true_spans_arr=[[[10, 11], [27, 28], [30, 31], [35, 36], [39, 40]],
                        [[15, 16], [18, 19], [23, 24], [27, 28], [49, 50]]],
        true_alias_to_predict_arr=[[0, 1, 2, 3, 4], [4]],
        true_aliases_arr=[["mexico", "panama", "ecuador", "peru", "bolivia"],
                          ["panama", "ecuador", "peru", "bolivia", "colombia"]])

    # EXAMPLE 2
    # Source row: 20|0~*~1~*~2~*~3~*~4~*~5~*~6~*~7~*~8~*~9~*~10~*~11~*~12~*~13~*~14~*~15~*~16~*~17~*~18~*~19~*~20|coolock~*~swords~*~darndale~*~santry~*~donnycarney~*~baldoyle~*~sutton~*~donaghmede~*~artane~*~whitehall~*~kilbarrack~*~raheny~*~clontarf~*~fairview~*~malahide~*~howth~*~marino~*~ballybough~*~north strand~*~sheriff street~*~east wall|1037463~*~182210~*~8554720~*~2432965~*~7890942~*~1223621~*~1008011~*~3698049~*~1469895~*~2144656~*~3628425~*~1108214~*~1564212~*~1438118~*~944694~*~1037467~*~5745962~*~2436385~*~5310245~*~12170199~*~2814197|12:13~*~14:15~*~15:16~*~17:18~*~18:19~*~19:20~*~20:21~*~21:22~*~22:23~*~23:24~*~24:25~*~25:26~*~26:27~*~27:28~*~28:29~*~29:30~*~30:31~*~38:39~*~39:41~*~41:43~*~43:45|<sentence below>
    east_edition_phrase = (
        "East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall <pad> <pad> <pad> <pad> <pad>"
        .split())
    run_and_check(
        max_aliases=10,
        max_seq_len=50,
        sentence="East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall",
        aliases=[
            "coolock", "swords", "darndale", "santry", "donnycarney",
            "baldoyle", "sutton", "donaghmede", "artane", "whitehall",
            "kilbarrack", "raheny", "clontarf", "fairview", "malahide",
            "howth", "marino", "ballybough", "north strand", "sheriff street",
            "east wall"
        ],
        spans=[[12, 13], [14, 15], [15, 16], [17, 18], [18, 19], [19, 20],
               [20, 21], [21, 22], [22, 23], [23, 24], [24, 25], [25, 26],
               [26, 27], [27, 28], [28, 29], [29, 30], [30, 31], [38, 39],
               [39, 41], [41, 43], [43, 45]],
        # NOTE(review): "171, 8" looks like a typo for "17, 18" (the source row
        # above predicts 0..20). As written, ballybough (17) and north strand
        # (18) are never predicted, and the expected third split [1, 2, 3, 6, 7]
        # below was built around that. Fixing the input means re-deriving the
        # expected splits.
        aliases_to_predict=[
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 171, 8,
            19, 20
        ],
        true_phrase_arr=[east_edition_phrase] * 3,
        true_spans_arr=[[[12, 13], [14, 15], [15, 16], [17, 18], [18, 19],
                         [19, 20], [20, 21], [21, 22]],
                        [[20, 21], [21, 22], [22, 23], [23, 24], [24, 25],
                         [25, 26], [26, 27], [27, 28], [28, 29]],
                        [[27, 28], [28, 29], [29, 30], [30, 31], [38, 39],
                         [39, 41], [41, 43], [43, 45]]],
        true_alias_to_predict_arr=[[0, 1, 2, 3, 4, 5, 6],
                                   [1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 6, 7]],
        true_aliases_arr=[[
            "coolock", "swords", "darndale", "santry", "donnycarney",
            "baldoyle", "sutton", "donaghmede"
        ], [
            "sutton", "donaghmede", "artane", "whitehall", "kilbarrack",
            "raheny", "clontarf", "fairview", "malahide"
        ], [
            "fairview", "malahide", "howth", "marino", "ballybough",
            "north strand", "sheriff street", "east wall"
        ]])

    # EXAMPLE 3
    # Source row: 84|0~*~1|kentucky~*~green|621151~*~478999|8:9~*~9:10|<sentence below>
    kentucky_max_seq_len = 100
    kentucky_sentence = "The Assembly also reserved tolls collected on the Kentucky Green and Barren rivers for education and passed a two percent property tax to fund the state s schools"
    kentucky_words = kentucky_sentence.split()
    run_and_check(
        max_aliases=10,
        max_seq_len=kentucky_max_seq_len,
        sentence=kentucky_sentence,
        aliases=["kentucky", "green"],
        spans=[[8, 9], [9, 10]],
        aliases_to_predict=[0, 1],
        # The short sentence is right-padded with "<pad>" up to max_seq_len.
        true_phrase_arr=[
            kentucky_words + ["<pad>"] *
            (kentucky_max_seq_len - len(kentucky_words))
        ],
        true_spans_arr=[[[8, 9], [9, 10]]],
        true_alias_to_predict_arr=[[0, 1]],
        true_aliases_arr=[["kentucky", "green"]])
def test_seq_length(self):
    """Splitting a sentence longer than ``max_seq_len`` yields fixed-length windows.

    For each call to ``split_sentence`` this checks the number of windows, the
    window tokens, the spans re-based into each window, the per-window
    aliases-to-predict, and which aliases land in each window. It is run first
    predicting every alias, then predicting only a subset (which drops the
    window containing no alias to predict).
    """
    # Test maximum sequence length
    max_aliases = 30
    max_seq_len = 12

    # Manual data
    sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
    aliases = ["The big", "alias3", "alias5"]
    spans = [[0, 2], [12, 13], [20, 21]]

    # The config and word symbols do not depend on the test inputs, so load them once.
    args = parser_utils.get_full_config("test/run_args/test_data.json")
    word_symbols = data_utils.load_wordsymbols(args.data_config)

    def check_split(aliases_to_predict, true_phrase_arr, true_spans_arr, true_alias_to_predict_arr,
                    true_aliases_arr):
        """Run split_sentence and compare every per-window output with the expected values."""
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict, max_seq_len, word_symbols)
        num_splits = len(true_phrase_arr)
        self.assertEqual(len(idxs_arr), num_splits)
        self.assertEqual(len(aliases_to_predict_arr), num_splits)
        self.assertEqual(len(spans_arr), num_splits)
        self.assertEqual(len(phrase_tokens_arr), num_splits)
        for i, idxs in enumerate(idxs_arr):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i], true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs], true_aliases_arr[i])

    # Predict every alias: the sentence needs two windows.
    check_split(
        aliases_to_predict=[0, 1, 2],
        true_phrase_arr=[
            "The big alias1 ran away from dogs and multi word alias2 and".split(),
            "word alias2 and alias3 because we want our cat and our alias5".split(),
        ],
        true_spans_arr=[[[0, 2]], [[3, 4], [11, 12]]],
        true_alias_to_predict_arr=[[0], [0, 1]],
        true_aliases_arr=[["The big"], ["alias3", "alias5"]],
    )

    # Now test with modified aliases to predict: only the second window is needed.
    check_split(
        aliases_to_predict=[1, 2],
        true_phrase_arr=[
            "word alias2 and alias3 because we want our cat and our alias5".split(),
        ],
        true_spans_arr=[[[3, 4], [11, 12]]],
        true_alias_to_predict_arr=[[0, 1]],
        true_aliases_arr=[["alias3", "alias5"]],
    )
def label_mentions(self, text_list):
    """Extract mentions from raw text and disambiguate them with the model.

    Args:
        text_list: a single string, or a non-empty list of strings.

    Returns:
        A 7-tuple ``(pred_qids, pred_probs, titles, all_cands, cand_probs, spans, aliases)``.
        Each element is a list with one inner list per input text. Only mentions whose
        top candidate probability is strictly greater than ``self.threshold`` are kept.
        The rest are counted as dropped.

    Raises:
        AssertionError: if the input is not a string or a non-empty list of strings,
            or if the kept + dropped mention count does not match the extracted count.
    """
    if isinstance(text_list, str):
        text_list = [text_list]
    else:
        assert isinstance(text_list, list) and len(text_list) > 0 and isinstance(
            text_list[0], str), f"We only accept inputs of strings and lists of strings"

    data_config = self.args.data_config
    max_aliases = data_config.max_aliases
    ebs = self.args.run_config.eval_batch_size
    total_start_exs = 0
    total_final_exs = 0
    dropped_by_thresh = 0

    # NOTE(review): character spans are computed but not returned by this method.
    final_char_spans = []

    # One entry per (text, sub-split) pair; the lists stay aligned by position.
    batch_example_aliases = []
    batch_example_aliases_locs_start = []
    batch_example_aliases_locs_end = []
    batch_example_true_entities = []
    batch_word_indices = []
    batch_spans_arr = []
    batch_aliases_arr = []
    batch_idx_unq = []

    # ====================================================
    # GENERATE MODEL INPUTS
    # ====================================================
    for idx_unq, text in tqdm(enumerate(text_list), desc="Prepping data", total=len(text_list)):
        sample = self.extract_mentions(text)
        total_start_exs += len(sample['aliases'])
        char_spans = self.get_char_spans(sample['spans'], text)
        final_char_spans.append(char_spans)

        # Long sentences / many aliases are split into several model-sized examples.
        # idxs_arr[k] holds the original alias indices that appear in split k.
        idxs_arr, aliases_to_predict_per_split, spans_arr, phrase_tokens_arr, _ = sentence_utils.split_sentence(
            max_aliases=max_aliases,
            phrase=sample['sentence'],
            spans=sample['spans'],
            aliases=sample['aliases'],
            aliases_seen_by_model=list(range(len(sample['aliases']))),
            seq_len=data_config.max_word_token_len,
            word_symbols=self.word_db)
        aliases_arr = [[sample['aliases'][idx] for idx in idxs] for idxs in idxs_arr]
        old_spans_arr = [[sample['spans'][idx] for idx in idxs] for idxs in idxs_arr]
        qids_arr = [[sample['qids'][idx] for idx in idxs] for idxs in idxs_arr]
        word_indices_arr = [self.word_db.convert_tokens_to_ids(pt) for pt in phrase_tokens_arr]

        # iterate over each sample in the split
        for sub_idx in range(len(idxs_arr)):
            aliases_to_predict_arr = aliases_to_predict_per_split[sub_idx]
            assert len(aliases_arr[
                sub_idx]) <= max_aliases, f'Each example should have no more that {self.args.data_config.max_aliases} max aliases. {sample} does.'
            # Unused slots stay at PAD_ID.
            example_aliases = np.ones(max_aliases) * PAD_ID
            example_aliases_locs_start = np.ones(max_aliases) * PAD_ID
            example_aliases_locs_end = np.ones(max_aliases) * PAD_ID
            example_true_entities = np.ones(max_aliases) * PAD_ID

            for mention_idx, alias in enumerate(aliases_arr[sub_idx]):
                span_start_idx, span_end_idx = spans_arr[sub_idx][mention_idx]
                # generate indexes into alias table.
                alias_trie_idx = self.entity_db.get_alias_idx(alias)
                alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                gold_qid = qids_arr[sub_idx][mention_idx]
                if gold_qid not in alias_qids:
                    # 0 is the "not in candidate set" class; -2 is used when training only in candidates.
                    true_entity_idx = 0 if not data_config.train_in_candidates else -2
                else:
                    # Candidate classes are shifted by one when label 0 is reserved for
                    # "not in candidate set" (i.e. when not training only in candidates).
                    true_entity_idx = np.nonzero(alias_qids == gold_qid)[0][0] + (
                        not data_config.train_in_candidates)
                example_aliases[mention_idx] = alias_trie_idx
                example_aliases_locs_start[mention_idx] = span_start_idx
                # The span_idxs are [start, end). We want [start, end]. So subtract 1 from end idx.
                example_aliases_locs_end[mention_idx] = span_end_idx - 1
                # leave as PAD_ID if it's not an alias we want to predict; we get these if we
                # split a sentence and need to only predict subsets
                if mention_idx in aliases_to_predict_arr:
                    example_true_entities[mention_idx] = true_entity_idx

            batch_example_aliases.append(example_aliases)
            batch_example_aliases_locs_start.append(example_aliases_locs_start)
            batch_example_aliases_locs_end.append(example_aliases_locs_end)
            batch_example_true_entities.append(example_true_entities)
            batch_word_indices.append(word_indices_arr[sub_idx])
            batch_aliases_arr.append(aliases_arr[sub_idx])
            # Add the original sample spans because spans_arr is w.r.t. the split's tokens
            batch_spans_arr.append(old_spans_arr[sub_idx])
            batch_idx_unq.append(idx_unq)

    # Stack the per-example numpy rows first; building a tensor from a list of ndarrays is slow.
    batch_example_aliases = torch.tensor(np.array(batch_example_aliases)).long()
    batch_example_aliases_locs_start = torch.tensor(np.array(batch_example_aliases_locs_start), device=self.device)
    batch_example_aliases_locs_end = torch.tensor(np.array(batch_example_aliases_locs_end), device=self.device)
    batch_example_true_entities = torch.tensor(np.array(batch_example_true_entities), device=self.device)
    batch_word_indices = torch.tensor(batch_word_indices, device=self.device)

    num_texts = len(text_list)
    final_pred_cands = [[] for _ in range(num_texts)]
    final_all_cands = [[] for _ in range(num_texts)]
    final_cand_probs = [[] for _ in range(num_texts)]
    final_pred_probs = [[] for _ in range(num_texts)]
    final_titles = [[] for _ in range(num_texts)]
    final_spans = [[] for _ in range(num_texts)]
    final_aliases = [[] for _ in range(num_texts)]

    num_splits = batch_example_aliases.shape[0]
    # Pure inference: disable autograd so no computation graph is built or kept alive.
    with torch.no_grad():
        for b_i in tqdm(range(0, num_splits, ebs), desc="Evaluating model"):
            b_end = min(b_i + ebs, num_splits)
            # entity indices from alias table (these are the candidates)
            batch_entity_indices = self.alias_table(batch_example_aliases[b_i:b_end])
            # all CPU embs have to be retrieved on the fly
            batch_on_the_fly_data = {}
            for emb_name, emb in self.batch_on_the_fly_embs.items():
                batch_prep = [
                    emb.batch_prep(batch_example_aliases[j], batch_entity_indices[j - b_i])
                    for j in range(b_i, b_end)
                ]
                batch_on_the_fly_data[emb_name] = torch.tensor(batch_prep, device=self.device)
            alias_idx_pair_sent = [batch_example_aliases_locs_start[b_i:b_end],
                                   batch_example_aliases_locs_end[b_i:b_end]]
            word_indices = batch_word_indices[b_i:b_end]
            alias_indices = batch_example_aliases[b_i:b_end]
            entity_indices = torch.tensor(batch_entity_indices, device=self.device)
            outs, entity_pack, _ = self.model(
                alias_idx_pair_sent=alias_idx_pair_sent,
                word_indices=word_indices,
                alias_indices=alias_indices,
                entity_indices=entity_indices,
                batch_prepped_data={},
                batch_on_the_fly_data=batch_on_the_fly_data)

            # ====================================================
            # EVALUATE MODEL OUTPUTS
            # ====================================================
            final_loss_vals = outs[DISAMBIG][FINAL_LOSS]
            # recover predictions
            probs = torch.exp(eval_utils.masked_class_logsoftmax(pred=final_loss_vals,
                                                                 mask=~entity_pack.mask, dim=2))
            max_probs, max_probs_indices = probs.max(2)
            for ex_i in range(final_loss_vals.shape[0]):
                row = b_i + ex_i
                idx_unq = batch_idx_unq[row]
                entity_cands = eval_utils.map_aliases_to_candidates(data_config.train_in_candidates,
                                                                    self.entity_db, batch_aliases_arr[row])
                probs_ex = probs[ex_i].detach().cpu().numpy().reshape(max_aliases, probs.shape[2])
                for alias_idx, true_entity_pos_idx in enumerate(batch_example_true_entities[row]):
                    # Slots left at PAD_ID are padding or aliases this split should not predict.
                    if true_entity_pos_idx != PAD_ID:
                        pred_idx = max_probs_indices[ex_i][alias_idx]
                        pred_prob = max_probs[ex_i][alias_idx].item()
                        all_cands = entity_cands[alias_idx]
                        pred_qid = all_cands[pred_idx]
                        if pred_prob > self.threshold:
                            final_all_cands[idx_unq].append(all_cands)
                            final_cand_probs[idx_unq].append(probs_ex[alias_idx])
                            final_pred_cands[idx_unq].append(pred_qid)
                            final_pred_probs[idx_unq].append(pred_prob)
                            final_aliases[idx_unq].append(batch_aliases_arr[row][alias_idx])
                            final_spans[idx_unq].append(batch_spans_arr[row][alias_idx])
                            final_titles[idx_unq].append(
                                self.entity_db.get_title(pred_qid) if pred_qid != 'NC' else 'NC')
                            total_final_exs += 1
                        else:
                            dropped_by_thresh += 1
    assert total_final_exs + dropped_by_thresh == total_start_exs, f"Something went wrong and we have predicted fewer mentions than extracted. Start {total_start_exs}, Out {total_final_exs}, No cand {dropped_by_thresh}"
    return final_pred_cands, final_pred_probs, final_titles, final_all_cands, final_cand_probs, final_spans, final_aliases
def label_mentions(
    self,
    text_list=None,
    label_func=find_aliases_in_sentence_tag,
    extracted_examples=None,
):
    """Extracts mentions and runs disambiguation.

    If user provides extracted_examples, we will ignore text_list. Dicts in
    extracted_examples are not modified; a shallow copy is used internally.

    Args:
        text_list: list of text to disambiguate (or single string) (can be None if extracted_examples is not None)
        label_func: mention extraction function (optional)
        extracted_examples: List of Dicts of keys "sentence", "aliases", "spans", "cands" (QIDs) (optional)

    Returns: Dict of

        * ``qids``: final predicted QIDs,
        * ``probs``: final predicted probs,
        * ``titles``: final predicted titles,
        * ``cands``: all entity candidates,
        * ``cand_probs``: probabilities of all candidates,
        * ``spans``: final extracted word spans,
        * ``aliases``: final extracted aliases,
        * ``embs``: final entity contextualized embeddings (if return_embs is True)
        * ``cand_embs``: final candidate entity contextualized embeddings (if return_embs is True)

    Every list has one inner list per input example. Only mentions whose top
    probability is strictly greater than ``self.threshold`` are kept.
    """
    # Check inputs are sane
    do_extract_mentions = extracted_examples is None
    if extracted_examples is not None:
        assert isinstance(extracted_examples, list), f"Must provide a list of Dics for extracted_examples"
        # An empty list simply produces empty results; only validate keys when there is an example.
        if len(extracted_examples) > 0:
            check_ex = extracted_examples[0]
            assert {"sentence", "aliases", "spans", "cands"}.issubset(check_ex.keys()), (
                f"You must have keys of sentence, aliases, spans, and cands for extracted_examples. You have "
                f"{list(check_ex.keys())}")
    else:
        assert (
            text_list is not None
        ), f"If you do not provide extracted_examples you must provide text_list"

    if text_list is None:
        assert extracted_examples is not None, (
            f"If you do not provide text_list "
            f"you must provide extracted_exampels")
    elif isinstance(text_list, str):
        text_list = [text_list]
    else:
        assert (
            isinstance(text_list, list) and len(text_list) > 0
            and isinstance(text_list[0], str)
        ), f"We only accept inputs of strings and lists of strings"

    # Get number of examples
    num_exs = len(extracted_examples) if extracted_examples is not None else len(text_list)

    ebs = int(self.config.run_config.eval_batch_size)
    self.config.data_config.max_aliases = int(self.config.data_config.max_aliases)
    data_config = self.config.data_config
    max_aliases = data_config.max_aliases
    # Candidate slots per alias; constant for the whole call, so compute it once.
    max_cands = get_max_candidates(self.entity_db, data_config)
    # Slot 0 is reserved for the "not in candidates" (NC) class unless we train only in candidates,
    # in which case real candidates start at slot 0 and overwrite the NC placeholder.
    nc_offset = int(not data_config.train_in_candidates)

    total_start_exs = 0
    total_final_exs = 0
    dropped_by_thresh = 0

    # NOTE(review): character spans are computed but not returned by this method.
    final_char_spans = []

    # One entry per (example, sub-split) pair; the lists stay aligned by position.
    batch_example_qid_cands = []
    batch_example_eid_cands = []
    batch_example_aliases_locs_start = []
    batch_example_aliases_locs_end = []
    batch_example_true_entities = []
    batch_word_indices = []
    batch_spans_arr = []
    batch_example_aliases = []
    batch_idx_unq = []

    # ====================================================
    # GENERATE MODEL INPUTS
    # ====================================================
    for idx_unq in tqdm(
            range(num_exs),
            desc="Prepping data",
            total=num_exs,
            disable=not self.verbose,
    ):
        if do_extract_mentions:
            sample = self.extract_mentions(text_list[idx_unq], label_func)
        else:
            # Shallow copy so the dummy "qids"/"gold" keys below do not clobber the caller's dicts.
            sample = dict(extracted_examples[idx_unq])
        # Add the unk qids and gold values
        sample["qids"] = ["Q-1" for _ in range(len(sample["aliases"]))]
        sample["gold"] = [True for _ in range(len(sample["aliases"]))]
        total_start_exs += len(sample["aliases"])
        char_spans = self.get_char_spans(sample["spans"], sample["sentence"])
        final_char_spans.append(char_spans)

        # Long sentences / many aliases are split into several model-sized examples.
        # idxs_arr[k] holds the original alias indices that appear in split k.
        (
            idxs_arr,
            aliases_to_predict_per_split,
            spans_arr,
            phrase_tokens_arr,
            _,
        ) = sentence_utils.split_sentence(
            max_aliases=max_aliases,
            phrase=sample["sentence"],
            spans=sample["spans"],
            aliases=sample["aliases"],
            aliases_seen_by_model=list(range(len(sample["aliases"]))),
            seq_len=data_config.max_seq_len,
            is_bert=True,
            tokenizer=self.tokenizer,
        )
        aliases_arr = [[sample["aliases"][idx] for idx in idxs] for idxs in idxs_arr]
        old_spans_arr = [[sample["spans"][idx] for idx in idxs] for idxs in idxs_arr]
        qids_arr = [[sample["qids"][idx] for idx in idxs] for idxs in idxs_arr]
        # Candidates must be gathered through idxs_arr like aliases/spans/qids: inside a split,
        # mention_idx is a position in that split, not an index into the original sample.
        cands_arr = [[sample["cands"][idx] for idx in idxs] for idxs in idxs_arr]
        word_indices_arr = [
            self.tokenizer.convert_tokens_to_ids(pt) for pt in phrase_tokens_arr
        ]

        # iterate over each sample in the split
        for sub_idx in range(len(idxs_arr)):
            aliases_to_predict_arr = aliases_to_predict_per_split[sub_idx]
            assert (
                len(aliases_arr[sub_idx]) <= max_aliases
            ), f"{sample} should have no more than {self.config.data_config.max_aliases} aliases."
            # Unused slots stay at PAD_ID / "-1" / -1.
            example_aliases_locs_start = np.ones(max_aliases) * PAD_ID
            example_aliases_locs_end = np.ones(max_aliases) * PAD_ID
            example_true_entities = np.ones(max_aliases) * PAD_ID
            example_qid_cands = [["-1"] * max_cands for _ in range(max_aliases)]
            example_eid_cands = [[-1] * max_cands for _ in range(max_aliases)]

            for mention_idx, alias in enumerate(aliases_arr[sub_idx]):
                span_start_idx, span_end_idx = spans_arr[sub_idx][mention_idx]
                mention_cands = cands_arr[sub_idx][mention_idx]
                alias_qids = np.array(mention_cands)
                # first entry is the non candidate class (NC and eid 0) - used when train in cands is false
                # if we train in candidates, this gets overwritten
                example_qid_cands[mention_idx][0] = "NC"
                example_qid_cands[mention_idx][nc_offset:len(alias_qids) + nc_offset] = mention_cands
                example_eid_cands[mention_idx][0] = 0
                example_eid_cands[mention_idx][nc_offset:len(alias_qids) + nc_offset] = [
                    self.entity_db.get_eid(q) for q in mention_cands
                ]
                gold_qid = qids_arr[sub_idx][mention_idx]
                if gold_qid not in alias_qids:
                    # 0 is the "not in candidate set" class; -2 is used when training only in candidates.
                    true_entity_idx = 0 if not data_config.train_in_candidates else -2
                else:
                    # Candidate classes are shifted by nc_offset so label 0 stays reserved
                    # for "not in candidate set" when not training only in candidates.
                    true_entity_idx = np.nonzero(alias_qids == gold_qid)[0][0] + nc_offset
                example_aliases_locs_start[mention_idx] = span_start_idx
                # The span_idxs are [start, end). We want [start, end]. So subtract 1 from end idx.
                example_aliases_locs_end[mention_idx] = span_end_idx - 1
                # leave as PAD_ID if it's not an alias we want to predict; we get these if we split a
                # sentence and need to only predict subsets
                if mention_idx in aliases_to_predict_arr:
                    example_true_entities[mention_idx] = true_entity_idx

            batch_example_qid_cands.append(example_qid_cands)
            batch_example_eid_cands.append(example_eid_cands)
            batch_example_aliases_locs_start.append(example_aliases_locs_start)
            batch_example_aliases_locs_end.append(example_aliases_locs_end)
            batch_example_true_entities.append(example_true_entities)
            batch_word_indices.append(word_indices_arr[sub_idx])
            batch_example_aliases.append(aliases_arr[sub_idx])
            # Add the original sample spans because spans_arr is w.r.t BERT subword token
            batch_spans_arr.append(old_spans_arr[sub_idx])
            batch_idx_unq.append(idx_unq)

    batch_example_eid_cands = torch.tensor(batch_example_eid_cands).long()
    # Stack the per-example numpy rows first; building a tensor from a list of ndarrays is slow.
    batch_example_aliases_locs_start = torch.tensor(np.array(batch_example_aliases_locs_start))
    batch_example_aliases_locs_end = torch.tensor(np.array(batch_example_aliases_locs_end))
    batch_example_true_entities = torch.tensor(np.array(batch_example_true_entities))
    batch_word_indices = torch.tensor(batch_word_indices)

    final_pred_cands = [[] for _ in range(num_exs)]
    final_all_cands = [[] for _ in range(num_exs)]
    final_cand_probs = [[] for _ in range(num_exs)]
    final_pred_probs = [[] for _ in range(num_exs)]
    final_entity_embs = [[] for _ in range(num_exs)]
    final_entity_cand_embs = [[] for _ in range(num_exs)]
    final_titles = [[] for _ in range(num_exs)]
    final_spans = [[] for _ in range(num_exs)]
    final_aliases = [[] for _ in range(num_exs)]

    num_splits = batch_word_indices.shape[0]
    for b_i in tqdm(
            range(0, num_splits, ebs),
            desc="Evaluating model",
            disable=not self.verbose,
    ):
        b_end = min(b_i + ebs, num_splits)
        start_span_idx = batch_example_aliases_locs_start[b_i:b_end]
        end_span_idx = batch_example_aliases_locs_end[b_i:b_end]
        word_indices = batch_word_indices[b_i:b_end]
        eid_cands = batch_example_eid_cands[b_i:b_end]
        x_dict = self.get_forward_batch(start_span_idx, end_span_idx, word_indices, eid_cands)
        # One guid per example actually present; the final batch may be smaller than ebs.
        x_dict["guid"] = torch.arange(b_i, b_end, device=self.torch_device)

        with torch.no_grad():
            res = self.model(  # type: ignore
                uids=x_dict["guid"],
                X_dict=x_dict,
                Y_dict=None,
                task_to_label_dict=self.task_to_label_dict,
                return_action_outputs=self.return_embs,
            )
        del x_dict
        if self.return_embs:
            (_, _, prob_bdict, _, out_bdict) = res
            output_embs = out_bdict[NED_TASK][f"{PRED_LAYER}_ent_embs"]
        else:
            output_embs = None
            (_, _, prob_bdict, _) = res

        # ====================================================
        # EVALUATE MODEL OUTPUTS
        # ====================================================
        # recover predictions
        probs = prob_bdict[NED_TASK]
        max_probs = probs.max(2)
        max_probs_indices = probs.argmax(2)
        for ex_i in range(probs.shape[0]):
            row = b_i + ex_i
            idx_unq = batch_idx_unq[row]
            entity_cands = batch_example_qid_cands[row]
            probs_ex = probs[ex_i].reshape(max_aliases, probs.shape[2])
            for alias_idx, true_entity_pos_idx in enumerate(batch_example_true_entities[row]):
                # Slots left at PAD_ID are padding or aliases this split should not predict.
                if true_entity_pos_idx != PAD_ID:
                    pred_idx = max_probs_indices[ex_i][alias_idx]
                    pred_prob = max_probs[ex_i][alias_idx].item()
                    all_cands = entity_cands[alias_idx]
                    pred_qid = all_cands[pred_idx]
                    if pred_prob > self.threshold:
                        final_all_cands[idx_unq].append(all_cands)
                        final_cand_probs[idx_unq].append(probs_ex[alias_idx])
                        final_pred_cands[idx_unq].append(pred_qid)
                        final_pred_probs[idx_unq].append(pred_prob)
                        if self.return_embs:
                            final_entity_embs[idx_unq].append(output_embs[ex_i][alias_idx][pred_idx])
                            final_entity_cand_embs[idx_unq].append(output_embs[ex_i][alias_idx])
                        final_aliases[idx_unq].append(batch_example_aliases[row][alias_idx])
                        final_spans[idx_unq].append(batch_spans_arr[row][alias_idx])
                        final_titles[idx_unq].append(
                            self.entity_db.get_title(pred_qid) if pred_qid != "NC" else "NC")
                        total_final_exs += 1
                    else:
                        dropped_by_thresh += 1
    assert total_final_exs + dropped_by_thresh == total_start_exs, (
        f"Something went wrong and we have predicted fewer mentions than extracted. "
        f"Start {total_start_exs}, Out {total_final_exs}, No cand {dropped_by_thresh}"
    )
    res_dict = {
        "qids": final_pred_cands,
        "probs": final_pred_probs,
        "titles": final_titles,
        "cands": final_all_cands,
        "cand_probs": final_cand_probs,
        "spans": final_spans,
        "aliases": final_aliases,
    }
    if self.return_embs:
        res_dict["embs"] = final_entity_embs
        res_dict["cand_embs"] = final_entity_cand_embs
    return res_dict