def test_construction_single(self):
    """Build a RecordTrieCollection with one qid-candidate trie and check lookups."""
    entity_dump_dir = "test/data/preprocessing/base/entity_data/entity_mappings"
    entity_symbols = EntitySymbols(load_dir=entity_dump_dir,
                                   alias_cand_map_file="alias2qids.json")
    fmt_types = {"trie1": "qid_cand_with_score"}
    max_values = {"trie1": 3}
    input_dicts = {"trie1": entity_symbols.get_alias2qids()}
    record_trie = RecordTrieCollection(load_dir=None, input_dicts=input_dicts,
                                       vocabulary=entity_symbols.get_qid2eid(),
                                       fmt_types=fmt_types, max_values=max_values)
    # Expected alias -> [[qid, score], ...] records from the test fixture.
    truealias2qids = {
        'alias1': [["Q1", 10], ["Q4", 6]],
        'multi word alias2': [["Q2", 5], ["Q1", 3], ["Q4", 2]],
        'alias3': [["Q1", 30]],
        'alias4': [["Q4", 20], ["Q3", 15], ["Q2", 1]]
    }
    # Full [qid, score] records round-trip through the trie.
    for al in truealias2qids:
        self.assertEqual(truealias2qids[al], record_trie.get_value("trie1", al))
    # `getter` projects out just the qid from each record.
    for al in truealias2qids:
        self.assertEqual([x[0] for x in truealias2qids[al]],
                         record_trie.get_value("trie1", al, getter=lambda x: x[0]))
    # `getter` projects out just the score from each record.
    for al in truealias2qids:
        self.assertEqual([x[1] for x in truealias2qids[al]],
                         record_trie.get_value("trie1", al, getter=lambda x: x[1]))
    # Key set of the trie matches the fixture aliases.
    self.assertEqual(set(truealias2qids.keys()), set(record_trie.get_keys("trie1")))
def main():
    """Shrink the learned entity embedding to the most frequent entities.

    Loads a checkpoint, drops the bottom ``perc_emb_drop`` fraction of entity
    rows, rebuilds the old-eid -> new-eid index vector, and saves both the
    filtered checkpoint and the qid -> top-k eid mapping.
    """
    args = parse_args()
    # BUGFIX: args is an argparse.Namespace, which json.dumps cannot
    # serialize; convert to a dict first (matches the sibling scripts).
    print(json.dumps(vars(args), indent=4))
    assert 0 < args.perc_emb_drop < 1, "perc_emb_drop must be between 0 and 1"
    state_dict, model_state_dict = load_statedict(args.init_checkpoint)
    assert ENTITY_EMB_KEY in model_state_dict
    print(f"Loading entity symbols from {os.path.join(args.entity_dir, args.entity_map_dir)}")
    entity_db = EntitySymbols(os.path.join(args.entity_dir, args.entity_map_dir),
                              args.alias_cand_map_file)
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)
    print("Filtering qids")
    qid2topk_eid, old2new_eid, toes_eid, new_num_topk_entities = filter_qids(
        args.perc_emb_drop, entity_db, qid2count)
    print("Filtering embeddings")
    model_state_dict = filter_embs(new_num_topk_entities, entity_db,
                                   old2new_eid, qid2topk_eid, toes_eid,
                                   model_state_dict)
    # Generate the new old2new_eid weight vector to save in model_state_dict
    oldeid2topkeid = torch.arange(0, entity_db.num_entities_with_pad_and_nocand)
    # The +2 is to account for pads and unks. The -1 is as there are issues
    # with -1 in the indexing for entity embeddings. So we must manually make
    # it the last entry.
    oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1
    for qid, new_eid in qid2topk_eid.items():
        old_eid = entity_db.get_eid(qid)
        oldeid2topkeid[old_eid] = new_eid
    assert oldeid2topkeid[0] == 0, "The 0 eid shouldn't be changed"
    assert oldeid2topkeid[-1] == new_num_topk_entities + 2 - 1, \
        "The -1 eid should still map to the last row"
    model_state_dict[ENTITY_EID_KEY] = oldeid2topkeid
    state_dict["model"] = model_state_dict
    print(f"Saving model at {args.save_checkpoint}")
    torch.save(state_dict, args.save_checkpoint)
    print(f"Saving entity_db at {os.path.join(args.entity_dir, args.entity_map_dir, args.save_qid2topk_file)}")
    utils.dump_json_file(
        os.path.join(args.entity_dir, args.entity_map_dir, args.save_qid2topk_file),
        qid2topk_eid)
def main(): args = parse_args() # compute the max alias len in alias2qids with open(args.alias2qids) as f: alias2qids = ujson.load(f) with open(args.qid2title) as f: qid2title = ujson.load(f) max_alias_len = -1 for alias in alias2qids: assert alias.lower( ) == alias, f'bootleg assumes lowercase aliases in alias candidate maps: {alias}' # ensure only max_candidates per alias qids = sorted(alias2qids[alias], key=lambda x: (x[1], x[0]), reverse=True) alias2qids[alias] = qids[:args.max_candidates] # keep track of the maximum number of words in an alias max_alias_len = max(max_alias_len, len(alias.split())) entity_mappings = EntitySymbols( max_candidates=args.max_candidates, max_alias_len=max_alias_len, alias2qids=alias2qids, qid2title=qid2title, alias_cand_map_file=args.alias_cand_map_file) entity_mappings.dump(os.path.join(args.entity_dir, args.entity_map_dir)) print('entity mappings exported.')
def main(): args = parse_args() # compute the max alias len in alias2qids with open(args.alias2qids) as f: alias2qids = ujson.load(f) with open(args.qid2title) as f: qid2title = ujson.load(f) for alias in alias2qids: assert ( alias.lower() == alias ), f"bootleg assumes lowercase aliases in alias candidate maps: {alias}" # ensure only max_candidates per alias qids = sorted(alias2qids[alias], key=lambda x: (x[1], x[0]), reverse=True) alias2qids[alias] = qids[:args.max_candidates] entity_mappings = EntitySymbols( max_candidates=args.max_candidates, alias2qids=alias2qids, qid2title=qid2title, alias_cand_map_file=args.alias_cand_map_file, ) entity_mappings.save(os.path.join(args.entity_dir, args.entity_map_dir)) print("entity mappings exported.")
def test_reidentify_entity(self):
    """Renaming Q1 -> Q7 must update every internal mapping consistently."""
    alias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2], ["Q3", 1]],
        "alias3": [["Q1", 30.0]],
        "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
    }
    qid2title = {
        "Q1": "alias1",
        "Q2": "multi alias2",
        "Q3": "word alias3",
        "Q4": "nonalias4",
    }
    max_candidates = 3
    # edit_mode enables in-place modification of the symbol tables
    entity_symbols = EntitySymbols(
        max_candidates=max_candidates,
        alias2qids=alias2qids,
        qid2title=qid2title,
        edit_mode=True,
    )
    entity_symbols.reidentify_entity("Q1", "Q7")
    # Every occurrence of Q1 should now appear as Q7; eid/ordering unchanged.
    trueqid2aliases = {
        "Q7": {"alias1", "multi word alias2", "alias3"},
        "Q2": {"multi word alias2", "alias4"},
        "Q3": {"alias4"},
        "Q4": {"alias1", "multi word alias2", "alias4"},
    }
    truealias2qids = {
        "alias1": [["Q7", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q7", 3], ["Q4", 2]],
        "alias3": [["Q7", 30.0]],
        "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
    }
    trueqid2eid = {"Q7": 1, "Q2": 2, "Q3": 3, "Q4": 4}
    truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
    trueid2alias = {0: "alias1", 1: "alias3", 2: "alias4", 3: "multi word alias2"}
    truemax_eid = 4
    truenum_entities = 4
    # +2 accounts for the pad and no-candidate rows
    truenum_entities_with_pad_and_nocand = 6
    self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
    self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
    self.assertDictEqual(
        entity_symbols._eid2qid, {v: i for i, v in trueqid2eid.items()}
    )
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    # in edit mode the alias trie is invalidated rather than rebuilt
    self.assertIsNone(entity_symbols._alias_trie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
    self.assertEqual(entity_symbols.max_eid, truemax_eid)
    self.assertEqual(entity_symbols.num_entities, truenum_entities)
    self.assertEqual(
        entity_symbols.num_entities_with_pad_and_nocand,
        truenum_entities_with_pad_and_nocand,
    )
def __init__(self, config_args, device='cuda', max_alias_len=6, cand_map=None,
             threshold=0.0):
    """Set up entity/word symbols, the model, and candidate tables for annotation.

    Args:
        config_args: full bootleg config
        device: torch device string (default 'cuda')
        max_alias_len: maximum number of words in a detected mention
        cand_map: optional path to a custom alias candidate map JSON;
            defaults to the candidate map inside the entity DB
        threshold: minimum prediction probability to return a mention
    """
    self.args = config_args
    self.device = device
    self.entity_db = EntitySymbols(
        os.path.join(self.args.data_config.entity_dir,
                     self.args.data_config.entity_map_dir),
        alias_cand_map_file=self.args.data_config.alias_cand_map)
    self.word_db = data_utils.load_wordsymbols(self.args.data_config,
                                               is_writer=True,
                                               distributed=False)
    self.model = self._load_model()
    self.max_alias_len = max_alias_len
    if cand_map is None:
        alias_map = self.entity_db._alias2qids
    else:
        alias_map = ujson.load(open(cand_map))
    self.all_aliases_trie = get_all_aliases(alias_map,
                                            logger=logging.getLogger())
    self.alias_table = AliasEntityTable(args=self.args,
                                        entity_symbols=self.entity_db)
    # minimum probability of prediction to return mention
    self.threshold = threshold
    # get batch_on_the_fly embeddings _and_ the batch_prep embeddings
    self.batch_on_the_fly_embs = {}
    for i, emb in enumerate(self.args.data_config.ent_embeddings):
        # batch_prep embeddings are converted to batch_on_the_fly since the
        # annotator preps each sentence on demand (mutates the config entry
        # while iterating -- order matters; keys are changed, not the list)
        if 'batch_prep' in emb and emb['batch_prep'] is True:
            self.args.data_config.ent_embeddings[i][
                'batch_on_the_fly'] = True
            del self.args.data_config.ent_embeddings[i]['batch_prep']
        if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
            mod, load_class = import_class("bootleg.embeddings",
                                           emb.load_class)
            try:
                self.batch_on_the_fly_embs[emb.key] = getattr(
                    mod, load_class)(main_args=self.args,
                                     emb_args=emb['args'],
                                     entity_symbols=self.entity_db,
                                     model_device=None,
                                     word_symbols=None)
            except AttributeError as e:
                # embedding class has no usable prep path; skip it
                print(
                    f'No prep method found for {emb.load_class} with error {e}'
                )
            except Exception as e:
                # NOTE(review): broad catch-and-print swallows all errors here;
                # presumably intentional best-effort loading -- confirm
                print("ERROR", e)
def main():
    """Compute dataset statistics (occurrences, optional type stats) over the train file."""
    args = parse_args()
    # BUGFIX: args is an argparse.Namespace, which json.dumps cannot
    # serialize; convert with vars() (matches the corrected sibling script).
    logging.info(json.dumps(vars(args), indent=4))
    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_dir, args.entity_symbols_dir))
    train_file = os.path.join(args.data_dir, args.train_file)
    save_dir = os.path.join(args.save_dir, "stats")
    logging.info(f"Will save data to {save_dir}")
    utils.ensure_dir(save_dir)
    # compute_histograms(save_dir, entity_symbols)
    compute_occurrences(save_dir, train_file, entity_symbols, args.lower,
                        args.strip, num_workers=args.num_workers)
    if not args.no_types:
        # hyena type files are fixed names inside emb_dir
        type_symbols = TypeSymbols(
            entity_symbols=entity_symbols,
            emb_dir=args.emb_dir,
            max_types=args.max_types,
            emb_file="hyena_type_emb.pkl",
            type_vocab_file="hyena_type_graph.vocab.pkl",
            type_file="hyena_types.txt")
        compute_type_occurrences(save_dir, "orig", entity_symbols,
                                 type_symbols.qid2typenames, train_file)
def test_filter_embs(self):
    """filter_embs keeps only the remapped top-k rows (plus pad and -1 rows)."""
    entity_dump_dir = "test/data/preprocessing/base/entity_data/entity_mappings"
    entity_db = EntitySymbols(load_dir=entity_dump_dir,
                              alias_cand_map_file="alias2qids.json")
    num_topk_entities = 2
    old2new_eid = {0: 0, -1: -1, 2: 1, 3: 2}
    qid2topk_eid = {"Q1": 2, "Q2": 1, "Q3": 2, "Q4": 2}
    toes_eid = 2
    state_dict = {}
    # 6 rows: pad row 0, eids 1-4, and the trailing -1 row.
    state_dict["emb_layer.entity_embs.learned.learned_entity_embedding.weight"] = torch.Tensor([
        [1.0, 2, 3, 4, 5],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4],
        [5, 5, 5, 5, 5],
        [0, 0, 0, 0, 0]
    ])
    gold_state_dict = {}
    # Expect row 0 (pad), old eids 2 and 3 (new eids 1 and 2), and the -1 row.
    gold_state_dict["emb_layer.entity_embs.learned.learned_entity_embedding.weight"] = torch.Tensor([
        [1.0, 2, 3, 4, 5],
        [3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4],
        [0, 0, 0, 0, 0]
    ])
    new_state_dict = filter_embs(num_topk_entities, entity_db, old2new_eid,
                                 qid2topk_eid, toes_eid, state_dict)
    for k in gold_state_dict:
        assert k in new_state_dict
        assert torch.equal(new_state_dict[k], gold_state_dict[k])
def setUp(self):
    """Build args, word/entity symbols, slice + train datasets, and a Trainer."""
    self.args = parser_utils.get_full_config("test/run_args/test_model_training.json")
    train_utils.setup_train_heads_and_eval_slices(self.args)
    self.word_symbols = data_utils.load_wordsymbols(self.args.data_config)
    self.entity_symbols = EntitySymbols(os.path.join(
        self.args.data_config.entity_dir,
        self.args.data_config.entity_map_dir),
        alias_cand_map_file=self.args.data_config.alias_cand_map)
    # NOTE(review): the slice dataset's save name is generated with
    # use_weak_label=True while the dataset itself uses use_weak_label=False
    # -- confirm this mismatch is intentional.
    slices = WikiSlices(
        args=self.args,
        use_weak_label=False,
        input_src=os.path.join(self.args.data_config.data_dir, "train.jsonl"),
        dataset_name=os.path.join(self.args.data_config.data_dir,
                                  data_utils.generate_save_data_name(
                                      data_args=self.args.data_config,
                                      use_weak_label=True,
                                      split_name="slice_train")),
        is_writer=True,
        distributed=self.args.run_config.distributed,
        dataset_is_eval=False
    )
    self.data = WikiDataset(
        args=self.args,
        use_weak_label=False,
        input_src=os.path.join(self.args.data_config.data_dir, "train.jsonl"),
        dataset_name=os.path.join(self.args.data_config.data_dir,
                                  data_utils.generate_save_data_name(
                                      data_args=self.args.data_config,
                                      use_weak_label=False,
                                      split_name="train")),
        is_writer=True,
        distributed=self.args.run_config.distributed,
        word_symbols=self.word_symbols,
        entity_symbols=self.entity_symbols,
        slice_dataset=slices,
        dataset_is_eval=False
    )
    self.trainer = Trainer(self.args, self.entity_symbols, self.word_symbols)
def test_load_and_save(self):
    """Dump a RecordTrieCollection to disk and verify the reload matches exactly."""
    entity_dump_dir = "test/data/preprocessing/base/entity_data/entity_mappings"
    entity_symbols = EntitySymbols(load_dir=entity_dump_dir,
                                   alias_cand_map_file="alias2qids.json")
    fmt_types = {"trie1": "qid_cand_with_score"}
    max_values = {"trie1": 3}
    input_dicts = {"trie1": entity_symbols.get_alias2qids()}
    record_trie = RecordTrieCollection(load_dir=None, input_dicts=input_dicts,
                                       vocabulary=entity_symbols.get_qid2eid(),
                                       fmt_types=fmt_types, max_values=max_values)
    record_trie.dump(save_dir=os.path.join(entity_dump_dir, "record_trie"))
    record_trie_loaded = RecordTrieCollection(
        load_dir=os.path.join(entity_dump_dir, "record_trie"))
    # All internal state must round-trip through dump/load.
    self.assertEqual(record_trie._fmt_types, record_trie_loaded._fmt_types)
    self.assertEqual(record_trie._max_values, record_trie_loaded._max_values)
    self.assertEqual(record_trie._stoi, record_trie_loaded._stoi)
    np.testing.assert_array_equal(record_trie._itos, record_trie_loaded._itos)
    self.assertEqual(record_trie._record_tris, record_trie_loaded._record_tris)
def load_from_jsonl(
    cls,
    profile_file,
    max_candidates=30,
    max_types=10,
    max_kg_connections=100,
    edit_mode=False,
):
    """Loads an entity profile from the raw jsonl file.

    Each line is a JSON object with entity metadata.

    Example object::

        {
            "entity_id": "C000",
            "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        }

    Args:
        profile_file: file where jsonl data lives
        max_candidates: maximum entity candidates
        max_types: maximum types per entity
        max_kg_connections: maximum KG connections per entity
        edit_mode: edit mode

    Returns: entity profile object
    """
    # Parse the raw jsonl into the four symbol-table inputs.
    qid2title, alias2qids, type_systems, qid2relations = cls._read_profile_file(
        profile_file
    )
    entity_symbols = EntitySymbols(
        alias2qids=alias2qids,
        qid2title=qid2title,
        max_candidates=max_candidates,
        edit_mode=edit_mode,
    )
    # One TypeSymbols per type system (e.g. "hyena", "wiki").
    all_type_symbols = {
        ty_name: TypeSymbols(
            qid2typenames=type_map, max_types=max_types, edit_mode=edit_mode
        )
        for ty_name, type_map in type_systems.items()
    }
    kg_symbols = KGSymbols(
        qid2relations, max_connections=max_kg_connections, edit_mode=edit_mode
    )
    return cls(entity_symbols, all_type_symbols, kg_symbols, edit_mode)
def test_construction_double(self):
    """Build a RecordTrieCollection with three differently-typed tries."""
    # Expected alias -> [[qid, score], ...] records from the test fixture.
    truealias2qids = {
        'alias1': [["Q1", 10], ["Q4", 6]],
        'multi word alias2': [["Q2", 5], ["Q1", 3], ["Q4", 2]],
        'alias3': [["Q1", 30]],
        'alias4': [["Q4", 20], ["Q3", 15], ["Q2", 1]]
    }
    # qid -> list of type ids.
    truealias2typeids = {
        'Q1': [1, 2, 3],
        'Q2': [8],
        'Q3': [4],
        'Q4': [2, 4]
    }
    # qid -> set of related qids.
    truerelations = {
        'Q1': set(["Q2", "Q3"]),
        'Q2': set(["Q1"]),
        'Q3': set(["Q2"]),
        'Q4': set(["Q1", "Q2", "Q3"])
    }
    entity_dump_dir = "test/data/preprocessing/base/entity_data/entity_mappings"
    entity_symbols = EntitySymbols(load_dir=entity_dump_dir,
                                   alias_cand_map_file="alias2qids.json")
    fmt_types = {"trie1": "qid_cand_with_score",
                 "trie2": "type_ids",
                 "trie3": "kg_relations"}
    max_values = {"trie1": 3, "trie2": 3, "trie3": 3}
    input_dicts = {"trie1": entity_symbols.get_alias2qids(),
                   "trie2": truealias2typeids,
                   "trie3": truerelations}
    record_trie = RecordTrieCollection(load_dir=None, input_dicts=input_dicts,
                                       vocabulary=entity_symbols.get_qid2eid(),
                                       fmt_types=fmt_types, max_values=max_values)
    # Each trie round-trips its own value format.
    for al in truealias2qids:
        self.assertEqual(truealias2qids[al], record_trie.get_value("trie1", al))
    for qid in truealias2typeids:
        self.assertEqual(truealias2typeids[qid], record_trie.get_value("trie2", qid))
    for qid in truerelations:
        self.assertEqual(truerelations[qid], record_trie.get_value("trie3", qid))
def setUp(self):
    """Load the test EntitySymbols fixture and assemble a minimal run config."""
    dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
    self.entity_symbols = EntitySymbols(dump_dir,
                                        alias_cand_map_file="alias2qids.json")
    data_config = {
        'train_in_candidates': False,
        'entity_dir': 'test/data/entity_loader/entity_data',
        'entity_prep_dir': 'prep',
        'alias_cand_map': 'alias2qids.json',
        'max_aliases': 3,
        'data_dir': 'test/data/entity_loader',
        'overwrite_preprocessed_data': True,
    }
    self.config = {
        'data_config': data_config,
        'run_config': {'distributed': False},
    }
def setUp(self):
    """Load the cached test EntitySymbols and assemble a minimal run config."""
    dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
    self.entity_symbols = EntitySymbols.load_from_cache(
        dump_dir, alias_cand_map_file="alias2qids.json"
    )
    data_cfg = {
        "train_in_candidates": False,
        "entity_dir": "test/data/entity_loader/entity_data",
        "entity_prep_dir": "prep",
        "alias_cand_map": "alias2qids.json",
        "max_aliases": 3,
        "data_dir": "test/data/entity_loader",
        "overwrite_preprocessed_data": True,
    }
    self.config = {
        "data_config": data_cfg,
        "run_config": {"distributed": False},
    }
def test_filter_qids(self):
    """Dropping 80% of entities keeps the top-2 by count and remaps eids."""
    entity_dump_dir = "test/data/preprocessing/base/entity_data/entity_mappings"
    entity_db = EntitySymbols(load_dir=entity_dump_dir,
                              alias_cand_map_file="alias2qids.json")
    qid2count = {"Q1": 10, "Q2": 20, "Q3": 2, "Q4": 4}
    perc_emb_drop = 0.8
    result = filter_qids(perc_emb_drop, entity_db, qid2count)
    # Expected (qid2topk_eid, old2new_eid, new_toes_eid, num_topk_entities).
    expected = (
        {"Q1": 2, "Q2": 1, "Q3": 2, "Q4": 2},
        {0: 0, -1: -1, 2: 1, 3: 2},
        2,
        2,
    )
    for want, got in zip(expected, result):
        self.assertEqual(want, got)
def write_data_labels_initializer(merged_entity_emb_file, merged_storage_type,
                                  sent_idx_map_file, train_in_candidates,
                                  dump_embs, data_config):
    """Multiprocessing pool initializer: stash shared read-only state in globals.

    Each worker process calls this once so the per-example workers can access
    the memmapped embeddings and lookup structures without re-pickling them.

    Args:
        merged_entity_emb_file: path to the memmapped merged entity embeddings
        merged_storage_type: numpy dtype of the memmap
        sent_idx_map_file: path to the sentence-index trie file
        train_in_candidates: train-in-candidates flag
        dump_embs: whether embeddings are being dumped
        data_config: data config providing entity dir/map locations
    """
    global filt_emb_data_global
    # mode="r+" opens the memmap read-write in place
    filt_emb_data_global = np.memmap(merged_entity_emb_file,
                                     dtype=merged_storage_type, mode="r+")
    global sent_idx_map_global
    sent_idx_map_global = utils.load_single_item_trie(sent_idx_map_file)
    global train_in_candidates_global
    train_in_candidates_global = train_in_candidates
    global dump_embs_global
    dump_embs_global = dump_embs
    global entity_dump_global
    entity_dump_global = EntitySymbols(
        load_dir=os.path.join(data_config.entity_dir, data_config.entity_map_dir),
        alias_cand_map_file=data_config.alias_cand_map)
def main():
    """Log the parsed args and compute occurrence statistics over the train file."""
    args = parse_args()
    logging.info(json.dumps(vars(args), indent=4))
    symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(args.data_dir, args.entity_symbols_dir))
    train_path = os.path.join(args.data_dir, args.train_file)
    stats_dir = os.path.join(args.save_dir, "stats")
    logging.info(f"Will save data to {stats_dir}")
    utils.ensure_dir(stats_dir)
    # compute_histograms(stats_dir, symbols)
    compute_occurrences(
        stats_dir,
        train_path,
        symbols,
        args.lower,
        args.strip,
        num_workers=args.num_workers,
    )
def main():
    """Average BERT title token embeddings per entity and save as pt or json."""
    args = parse_args()
    print(ujson.dumps(vars(args), indent=4))
    entity_symbols = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, args.entity_map_dir),
        alias_cand_map_file=args.alias_cand_map,
        alias_idx_file=args.alias_idx_map,
    )
    # lowercasing follows the chosen BERT variant's name
    print("DO LOWERCASE IS", "uncased" in args.bert_model)
    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case="uncased" in args.bert_model,
        cache_dir=args.word_model_cache,
    )
    model = BertModel.from_pretrained(
        args.bert_model,
        cache_dir=args.word_model_cache,
        output_attentions=False,
        output_hidden_states=False,
    )
    if not args.cpu:
        model = model.to("cuda")
    model.eval()
    entity2avgtitle = build_title_table(
        args.cpu, args.batch_size, model, tokenizer, entity_symbols
    )
    # make sure the output directory exists before writing
    save_fold = os.path.dirname(args.save_file)
    if len(save_fold) > 0:
        if not os.path.exists(save_fold):
            os.makedirs(save_fold)

    if args.output_method == "pt":
        # save the eid vocabulary alongside the embedding tensor
        save_obj = (entity_symbols.get_qid2eid(), entity2avgtitle)
        torch.save(obj=save_obj, f=args.save_file)
    else:
        # json output: qid -> embedding as a plain list
        res = {}
        for qid in tqdm(entity_symbols.get_all_qids(), desc="Building final json"):
            eid = entity_symbols.get_eid(qid)
            res[qid] = entity2avgtitle[eid].tolist()
        with open(args.save_file, "w") as out_f:
            ujson.dump(res, out_f)
    print(f"Done!")
def main():
    """Count QID occurrences in the train split and build a regularization CSV."""
    args = parse_args()
    print(ujson.dumps(vars(args), indent=4))
    # cap the worker count at 80% of available cores
    num_procs = min(args.processes, int(0.8 * multiprocessing.cpu_count()))
    print("Loading entity symbols")
    symbols = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, args.entity_map_dir),
        alias_cand_map_file=args.alias_cand_map,
        alias_idx_file=args.alias_idx_map,
    )
    train_path = os.path.join(args.data_dir, args.train_file)
    print(f"Getting slice counts from {train_path}")
    counts = get_counts(num_procs, train_path)
    with open(os.path.join(args.data_dir, "qid_cnts_train.json"), "w") as out_f:
        ujson.dump(counts, out_f)
    frame = build_reg_csv(counts, symbols)
    frame.to_csv(args.out_file, index=False)
    print(f"Saved file to {args.out_file}")
def setUp(self):
    """Load args/entity symbols and initialize emmental for embedding tests.

    ENTITY SYMBOLS
    {
      "multi word alias2":[["Q2",5.0],["Q1",3.0],["Q4",2.0]],
      "alias1":[["Q1",10.0],["Q4",6.0]],
      "alias3":[["Q1",30.0]],
      "alias4":[["Q4",20.0],["Q3",15.0],["Q2",1.0]]
    }
    EMBEDDINGS
    {
        "key": "learned",
        "freeze": false,
        "load_class": "LearnedEntityEmb",
        "args": {
          "learned_embedding_size": 10,
        }
    },
    {
       "key": "learned_type",
       "load_class": "LearnedTypeEmb",
       "freeze": false,
       "args": {
           "type_labels": "type_pred_mapping.json",
           "max_types": 1,
           "type_dim": 5,
           "merge_func": "addattn",
           "attn_hidden_size": 5
       }
    }
    """
    self.args = parser_utils.parse_boot_and_emm_args(
        "test/run_args/test_model_training.json")
    self.entity_symbols = EntitySymbols.load_from_cache(
        os.path.join(self.args.data_config.entity_dir,
                     self.args.data_config.entity_map_dir),
        alias_cand_map_file=self.args.data_config.alias_cand_map,
    )
    # emmental must be initialized before any task/model construction
    emmental.init(log_dir="test/temp_log")
    if not os.path.exists(emmental.Meta.log_path):
        os.makedirs(emmental.Meta.log_path)
class EntitySymbolTest(unittest.TestCase):
    """Tests for loading EntitySymbols from disk and its accessor methods."""

    def setUp(self):
        # shared fixture: small entity dump with 4 entities and 4 aliases
        entity_dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
        self.entity_symbols = EntitySymbols(load_dir=entity_dump_dir,
                                            alias_cand_map_file="alias2qids.json")

    def test_load_entites_keep_noncandidate(self):
        """Loaded maps match the fixture; eids start at 1 (0 is reserved)."""
        truealias2qids = {
            'alias1': [["Q1", 10.0], ["Q4", 6]],
            'multi word alias2': [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            'alias3': [["Q1", 30.0]],
            'alias4': [["Q4", 20], ["Q3", 15.0], ["Q2", 1]]
        }
        trueqid2title = {
            'Q1': "alias1",
            'Q2': "multi alias2",
            'Q3': "word alias3",
            'Q4': "nonalias4"
        }
        # the non-candidate class is included in entity_dump
        trueqid2eid = {
            'Q1': 1,
            'Q2': 2,
            'Q3': 3,
            'Q4': 4
        }
        self.assertEqual(self.entity_symbols.max_candidates, 3)
        self.assertEqual(self.entity_symbols.max_alias_len, 3)
        self.assertDictEqual(self.entity_symbols._alias2qids, truealias2qids)
        self.assertDictEqual(self.entity_symbols._qid2title, trueqid2title)
        self.assertDictEqual(self.entity_symbols._qid2eid, trueqid2eid)

    def test_getters(self):
        """Accessor methods return the fixture values, with '-1'/-1 padding."""
        self.assertEqual(self.entity_symbols.get_qid(1), 'Q1')
        self.assertSetEqual(set(self.entity_symbols.get_all_aliases()),
                            {'alias1', 'multi word alias2', 'alias3', 'alias4'})
        self.assertEqual(self.entity_symbols.get_eid('Q3'), 3)
        self.assertListEqual(self.entity_symbols.get_qid_cands('alias1'),
                             ['Q1', 'Q4'])
        # max_cand_pad pads candidate lists out to max_candidates
        self.assertListEqual(self.entity_symbols.get_qid_cands('alias1',
                                                               max_cand_pad=True),
                             ['Q1', 'Q4', '-1'])
        self.assertListEqual(self.entity_symbols.get_eid_cands('alias1',
                                                               max_cand_pad=True),
                             [1, 4, -1])
        self.assertEqual(self.entity_symbols.get_title('Q1'), 'alias1')
def setUp(self):
    """Load args, tokenizer, and entity symbols for data-loader index tests."""
    # tests that the sampling is done correctly on indices
    # load data from directory
    self.args = parser_utils.parse_boot_and_emm_args(
        "test/run_args/test_type_data.json")
    self.tokenizer = BertTokenizer.from_pretrained(
        "bert-base-cased",
        do_lower_case=False,
        cache_dir="test/data/emb_data/pretrained_bert_models",
    )
    self.is_bert = True
    self.entity_symbols = EntitySymbols.load_from_cache(
        os.path.join(self.args.data_config.entity_dir,
                     self.args.data_config.entity_map_dir),
        alias_cand_map_file=self.args.data_config.alias_cand_map,
    )
    self.temp_file_name = "test/data/data_loader/test_data.jsonl"
    # structured dtype for example guids: sentence idx, sub-sentence idx, and
    # the original alias list positions (one slot per alias)
    self.guid_dtype = lambda max_aliases: np.dtype([
        ("sent_idx", "i8", 1),
        ("subsent_idx", "i8", 1),
        ("alias_orig_list_pos", "i8", max_aliases),
    ])
class Annotator():
    """
    Annotator class: convenient wrapper of preprocessing and model eval to
    allow for annotating single sentences at a time for quick experimentation,
    e.g. in notebooks.
    """

    def __init__(self, config_args, device='cuda', max_alias_len=6,
                 cand_map=None, threshold=0.0):
        """Set up entity/word symbols, the model, and candidate tables.

        Args:
            config_args: full bootleg config
            device: torch device string (default 'cuda')
            max_alias_len: maximum number of words in a detected mention
            cand_map: optional path to a custom alias candidate map JSON;
                defaults to the candidate map inside the entity DB
            threshold: minimum prediction probability to return a mention
        """
        self.args = config_args
        self.device = device
        self.entity_db = EntitySymbols(
            os.path.join(self.args.data_config.entity_dir,
                         self.args.data_config.entity_map_dir),
            alias_cand_map_file=self.args.data_config.alias_cand_map)
        self.word_db = data_utils.load_wordsymbols(self.args.data_config,
                                                   is_writer=True,
                                                   distributed=False)
        self.model = self._load_model()
        self.max_alias_len = max_alias_len
        if cand_map is None:
            alias_map = self.entity_db._alias2qids
        else:
            alias_map = ujson.load(open(cand_map))
        self.all_aliases_trie = get_all_aliases(alias_map,
                                                logger=logging.getLogger())
        self.alias_table = AliasEntityTable(args=self.args,
                                            entity_symbols=self.entity_db)
        # minimum probability of prediction to return mention
        self.threshold = threshold
        # get batch_on_the_fly embeddings _and_ the batch_prep embeddings
        self.batch_on_the_fly_embs = {}
        for i, emb in enumerate(self.args.data_config.ent_embeddings):
            # convert batch_prep embeddings to batch_on_the_fly since the
            # annotator preps each sentence on demand
            if 'batch_prep' in emb and emb['batch_prep'] is True:
                self.args.data_config.ent_embeddings[i][
                    'batch_on_the_fly'] = True
                del self.args.data_config.ent_embeddings[i]['batch_prep']
            if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
                mod, load_class = import_class("bootleg.embeddings",
                                               emb.load_class)
                try:
                    self.batch_on_the_fly_embs[emb.key] = getattr(
                        mod, load_class)(main_args=self.args,
                                         emb_args=emb['args'],
                                         entity_symbols=self.entity_db,
                                         model_device=None,
                                         word_symbols=None)
                except AttributeError as e:
                    print(
                        f'No prep method found for {emb.load_class} with error {e}'
                    )
                except Exception as e:
                    print("ERROR", e)

    def _load_model(self):
        """Load checkpoint weights (stripping distributed `module.` prefixes)
        and return the model in eval mode on self.device."""
        model_state_dict = torch.load(
            self.args.run_config.init_checkpoint,
            map_location=lambda storage, loc: storage)['model']
        model = Model(args=self.args,
                      model_device=self.device,
                      entity_symbols=self.entity_db,
                      word_symbols=self.word_db).to(self.device)
        # remove distributed naming if it exists
        if not self.args.run_config.distributed:
            new_state_dict = OrderedDict()
            for k, v in model_state_dict.items():
                if 'module.' in k and k[:len('module.')] == 'module.':
                    name = k[len('module.'):]
                    new_state_dict[name] = v
            # we renamed all layers due to distributed naming
            if len(new_state_dict) == len(model_state_dict):
                model_state_dict = new_state_dict
            else:
                assert len(new_state_dict) == 0
        model.load_state_dict(model_state_dict, strict=True)
        # must set model in eval mode
        model.eval()
        return model

    def extract_mentions(self, text):
        """Find candidate mentions in `text` via the alias trie.

        Returns a sample dict with sentence, aliases, spans, placeholder
        qids, and gold flags."""
        found_aliases, found_spans = find_aliases_in_sentence_tag(
            text, self.all_aliases_trie, self.max_alias_len)
        return {
            'sentence': text,
            'aliases': found_aliases,
            'spans': found_spans,
            # we don't know the true QID
            'qids': ['Q-1' for i in range(len(found_aliases))],
            'gold': [True for i in range(len(found_aliases))]
        }

    def set_threshold(self, value):
        """Set the minimum prediction probability for returning a mention."""
        self.threshold = value

    def label_mentions(self, text):
        """Extract mentions from `text` and disambiguate them with the model.

        Returns:
            (pred_cands, pred_probs, titles) for mentions whose predicted
            probability exceeds self.threshold.

        Raises:
            ValueError: if the sentence overflows into multiple sub-samples.
        """
        sample = self.extract_mentions(text)
        idxs_arr, aliases_to_predict_per_split, spans_arr, phrase_tokens_arr = \
            sentence_utils.split_sentence(
                max_aliases=self.args.data_config.max_aliases,
                phrase=sample['sentence'],
                spans=sample['spans'],
                aliases=sample['aliases'],
                aliases_seen_by_model=[i for i in range(len(sample['aliases']))],
                seq_len=self.args.data_config.max_word_token_len,
                word_symbols=self.word_db)
        aliases_arr = [[sample['aliases'][idx] for idx in idxs]
                       for idxs in idxs_arr]
        qids_arr = [[sample['qids'][idx] for idx in idxs] for idxs in idxs_arr]
        word_indices_arr = [
            self.word_db.convert_tokens_to_ids(pt) for pt in phrase_tokens_arr
        ]
        if len(idxs_arr) > 1:
            # TODO: support sentences that overflow due to long sequence
            # length or too many mentions
            raise ValueError(
                'Overflowing sentences not currently supported in Annotator')
        # iterate over each sample in the split (guaranteed exactly one here)
        for sub_idx in range(len(idxs_arr)):
            # BUGFIX: np.int was removed in NumPy 1.24; it was an alias for
            # the builtin int, so use that directly.
            example_aliases = np.ones(self.args.data_config.max_aliases,
                                      dtype=int) * PAD_ID
            example_true_entities = np.ones(
                self.args.data_config.max_aliases) * PAD_ID
            example_aliases_locs_start = np.ones(
                self.args.data_config.max_aliases) * PAD_ID
            example_aliases_locs_end = np.ones(
                self.args.data_config.max_aliases) * PAD_ID
            aliases = aliases_arr[sub_idx]
            for mention_idx, alias in enumerate(aliases):
                # get aliases
                alias_trie_idx = self.entity_db.get_alias_idx(alias)
                alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                example_aliases[mention_idx] = alias_trie_idx
                # alias_idx_pair
                span_idx = spans_arr[sub_idx][mention_idx]
                span_start_idx, span_end_idx = span_idx
                example_aliases_locs_start[mention_idx] = span_start_idx
                example_aliases_locs_end[mention_idx] = span_end_idx
            # get word indices
            word_indices = word_indices_arr[sub_idx]
            # entity indices from alias table (these are the candidates)
            entity_indices = self.alias_table(example_aliases)
            # all CPU embs have to retrieved on the fly
            batch_on_the_fly_data = {}
            for emb_name, emb in self.batch_on_the_fly_embs.items():
                batch_on_the_fly_data[emb_name] = torch.tensor(
                    emb.batch_prep(example_aliases, entity_indices),
                    device=self.device)
            outs, entity_pack, _ = self.model(
                alias_idx_pair_sent=[
                    torch.tensor(example_aliases_locs_start,
                                 device=self.device).unsqueeze(0),
                    torch.tensor(example_aliases_locs_end,
                                 device=self.device).unsqueeze(0)
                ],
                word_indices=torch.tensor([word_indices], device=self.device),
                alias_indices=torch.tensor(example_aliases,
                                           device=self.device).unsqueeze(0),
                entity_indices=torch.tensor(entity_indices,
                                            device=self.device).unsqueeze(0),
                batch_prepped_data={},
                batch_on_the_fly_data=batch_on_the_fly_data)
            entity_cands = eval_utils.map_aliases_to_candidates(
                self.args.data_config.train_in_candidates,
                self.entity_db,
                aliases)
            # recover predictions
            probs = torch.exp(
                eval_utils.masked_class_logsoftmax(
                    pred=outs[DISAMBIG][FINAL_LOSS],
                    mask=~entity_pack.mask,
                    dim=2))
            max_probs, max_probs_indices = probs.max(2)
            pred_cands = []
            pred_probs = []
            titles = []
            for alias_idx in range(len(aliases)):
                pred_idx = max_probs_indices[0][alias_idx]
                pred_prob = max_probs[0][alias_idx].item()
                pred_qid = entity_cands[alias_idx][pred_idx]
                if pred_prob > self.threshold:
                    pred_cands.append(pred_qid)
                    pred_probs.append(pred_prob)
                    titles.append(
                        self.entity_db.get_title(pred_qid)
                        if pred_qid != 'NC' else 'NC')
            return pred_cands, pred_probs, titles
def create_task(args, entity_symbols=None, slice_datasets=None):
    """Creates a type prediction task.

    Args:
        args: args
        entity_symbols: entity symbols (default None; loaded from cache)
        slice_datasets: slice datasets used in scorer (default None)

    Returns: EmmentalTask for type prediction
    """
    if entity_symbols is None:
        entity_symbols = EntitySymbols.load_from_cache(
            load_dir=os.path.join(
                args.data_config.entity_dir, args.data_config.entity_map_dir
            ),
            alias_cand_map_file=args.data_config.alias_cand_map,
            alias_idx_file=args.data_config.alias_idx_map,
        )

    # Create sentence encoder
    bert_model = BertEncoder(
        args.data_config.word_embedding, output_size=args.model_config.hidden_size
    )

    # Create type prediction module
    # Add 1 for pad type
    type_prediction = TypePred(
        args.model_config.hidden_size,
        args.data_config.type_prediction.dim,
        args.data_config.type_prediction.num_types + 1,
        embedding_utils.get_max_candidates(entity_symbols, args.data_config),
    )

    # Create scorer
    sliced_scorer = BootlegSlicedScorer(
        args.data_config.train_in_candidates, slice_datasets
    )

    # Create module pool
    # BERT model will be shared across tasks as long as the name matches
    module_pool = nn.ModuleDict(
        {BERT_MODEL_NAME: bert_model, "type_prediction": type_prediction}
    )

    # Create task flow
    task_flow = [
        {
            "name": BERT_MODEL_NAME,
            "module": BERT_MODEL_NAME,
            "inputs": [
                ("_input_", "entity_cand_eid"),
                ("_input_", "token_ids"),
            ],  # We pass the entity_cand_eids to BERT in case of embeddings that require word information
        },
        {
            "name": "type_prediction",
            "module": "type_prediction",  # output: embedding_dict, batch_type_pred
            "inputs": [
                (BERT_MODEL_NAME, 0),  # sentence embedding
                ("_input_", "start_span_idx"),
            ],
        },
    ]

    return EmmentalTask(
        name=TYPE_PRED_TASK,
        module_pool=module_pool,
        task_flow=task_flow,
        loss_func=partial(type_loss, "type_prediction"),
        output_func=partial(type_output, "type_prediction"),
        require_prob_for_eval=False,
        require_pred_for_eval=True,
        # custom scorer keyed by task name; scores against slice datasets
        scorer=Scorer(
            customize_metric_funcs={
                f"{TYPE_PRED_TASK}_scorer": sliced_scorer.type_pred_score
            }
        ),
    )
def load_from_cache(
    cls,
    load_dir,
    edit_mode=False,
    verbose=False,
    no_kg=False,
    no_type=False,
    type_systems_to_load=None,
):
    """Loaded a pre-saved profile.

    Args:
        load_dir: load directory
        edit_mode: edit mode flag, default False
        verbose: verbose flag, default False
        no_kg: load kg or not flag, default False
        no_type: load types or not flag, default False. If True, this will ignore type_systems_to_load.
        type_systems_to_load: list of type systems to load, default is None which means all types systems

    Returns: entity profile object
    """
    # Check type system input
    load_dir = Path(load_dir)
    type_subfolder = load_dir / TYPE_SUBFOLDER
    # Validate that every requested type system has a matching subfolder on disk.
    # NOTE(review): this iterates type_subfolder even when no_type=True; a missing
    # type folder would raise here -- confirm that is intended.
    if type_systems_to_load is not None:
        if not isinstance(type_systems_to_load, list):
            raise ValueError(
                f"`type_systems` must be a list of subfolders in {type_subfolder}"
            )
        for sys in type_systems_to_load:
            if sys not in list([p.name for p in type_subfolder.iterdir()]):
                raise ValueError(
                    f"`type_systems` must be a list of subfolders in {type_subfolder}. {sys} is not one."
                )
    if verbose:
        print("Loading Entity Symbols")
    # Entity symbols always load; types and KG are optional below.
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir / ENTITY_SUBFOLDER,
        edit_mode=edit_mode,
        verbose=verbose,
    )
    if no_type:
        print(
            f"Not loading type information. We will act as if there is no types associated with any entity "
            f"and will not modify the types in any way, even if calling `add`."
        )
    type_sys_dict = {}
    # Load each type-system subfolder, unless types are disabled or the folder
    # was not requested via type_systems_to_load.
    for fold in type_subfolder.iterdir():
        if (
            (not no_type)
            and (type_systems_to_load is None or fold.name in type_systems_to_load)
            and (fold.is_dir())
        ):
            if verbose:
                print(f"Loading Type Symbols from {fold}")
            type_sys_dict[fold.name] = TypeSymbols.load_from_cache(
                type_subfolder / fold.name,
                edit_mode=edit_mode,
                verbose=verbose,
            )
    if verbose:
        print(f"Loading KG Symbols")
    if no_kg:
        print(
            f"Not loading KG information. We will act as if there is not KG connections between entities. "
            f"We will not modify the KG information in any way, even if calling `add`."
        )
    # kg_symbols stays None when KG loading is disabled.
    kg_symbols = None
    if not no_kg:
        kg_symbols = KGSymbols.load_from_cache(
            load_dir / KG_SUBFOLDER,
            edit_mode=edit_mode,
            verbose=verbose,
        )
    return cls(entity_symbols, type_sys_dict, kg_symbols, edit_mode, verbose)
def model_eval(args, mode, is_writer, logger, world_size=1, rank=0):
    """Run evaluation of a trained model over the test dataset.

    Args:
        args: config args
        mode: 'eval' for batched metrics only, or 'dump_preds'/'dump_embs'
            to write predictions (and optionally contextual entity embeddings)
        is_writer: whether this process writes output
        logger: logger
        world_size: number of processes (default 1)
        rank: process rank (default 0)

    Returns:
        (pred_file, emb_file) for the dump modes; None for 'eval'.
    """
    assert args.run_config.init_checkpoint != "", "You can't have an empty model file to do eval"
    # this is in main but call again in case eval is called directly
    train_utils.setup_train_heads_and_eval_slices(args)
    train_utils.setup_run_folders(args, mode)
    word_symbols = data_utils.load_wordsymbols(
        args.data_config, is_writer, distributed=args.run_config.distributed)
    logger.info(f"Loading entity_symbols...")
    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_config.entity_dir,
                              args.data_config.entity_map_dir),
        alias_cand_map_file=args.data_config.alias_cand_map)
    logger.info(
        f"Loaded entity_symbols with {entity_symbols.num_entities} entities.")
    eval_slice_names = args.run_config.eval_slices
    # Build the test dataset, slice dataset, and dataloader, bundled into a
    # collection keyed by the test file name.
    test_dataset_collection = {}
    test_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.test_dataset, is_writer, dataset_is_eval=True)
    test_dataset = data_utils.create_dataset(args,
                                             args.data_config.test_dataset,
                                             is_writer,
                                             word_symbols,
                                             entity_symbols,
                                             slice_dataset=test_slice_dataset,
                                             dataset_is_eval=True)
    test_dataloader, test_sampler = data_utils.create_dataloader(
        args,
        test_dataset,
        eval_slice_dataset=test_slice_dataset,
        batch_size=args.run_config.eval_batch_size)
    dataset_collection = DatasetCollection(args.data_config.test_dataset,
                                           args.data_config.test_dataset.file,
                                           test_dataset, test_dataloader,
                                           test_slice_dataset, test_sampler)
    test_dataset_collection[
        args.data_config.test_dataset.file] = dataset_collection
    # model_eval=True restores the checkpoint and puts the model in eval state.
    trainer = Trainer(args,
                      entity_symbols,
                      word_symbols,
                      resume_model_file=args.run_config.init_checkpoint,
                      eval_slice_names=eval_slice_names,
                      model_eval=True)
    # Run evaluation numbers without dumping predictions (quick, batched)
    if mode == 'eval':
        status_reporter = StatusReporter(args,
                                         logger,
                                         is_writer,
                                         max_epochs=None,
                                         total_steps_per_epoch=None,
                                         is_eval=True)
        # results are written to json file
        for test_data_file in test_dataset_collection:
            logger.info(
                f"************************RUNNING EVAL {test_data_file}************************"
            )
            test_dataloader = test_dataset_collection[
                test_data_file].data_loader
            # True is for if the batch is test or not, None is for the global step
            eval_utils.run_batched_eval(args=args,
                                        is_test=True,
                                        global_step=None,
                                        logger=logger,
                                        trainer=trainer,
                                        dataloader=test_dataloader,
                                        status_reporter=status_reporter,
                                        file=test_data_file)
    elif mode == 'dump_preds' or mode == 'dump_embs':
        # get predictions and optionally dump the corresponding contextual entity embeddings
        # TODO: support dumping ids for other embeddings as well (static entity embeddings, type embeddings, relation embeddings)
        # TODO: remove collection abstraction
        for test_data_file in test_dataset_collection:
            logger.info(
                f"************************DUMPING PREDICTIONS FOR {test_data_file}************************"
            )
            test_dataloader = test_dataset_collection[
                test_data_file].data_loader
            pred_file, emb_file = eval_utils.run_dump_preds(
                args=args,
                entity_symbols=entity_symbols,
                test_data_file=test_data_file,
                logger=logger,
                trainer=trainer,
                dataloader=test_dataloader,
                dump_embs=(mode == 'dump_embs'))
        # NOTE(review): only the paths from the LAST file in the collection are
        # returned; the collection currently holds exactly one test file, so
        # this is fine today but fragile if more files are ever added.
        return pred_file, emb_file
    # 'eval' mode returns nothing.
    return
def train(args, is_writer, logger, world_size, rank):
    """Train a model, periodically saving checkpoints and running dev eval.

    Args:
        args: config args
        is_writer: whether this process saves checkpoints / writes output
        logger: logger
        world_size: number of distributed processes
        rank: process rank
    """
    # This is main but call again in case train is called directly
    train_utils.setup_train_heads_and_eval_slices(args)
    train_utils.setup_run_folders(args, "train")
    # Load word symbols (like tokenizers) and entity symbols (aka entity profiles)
    word_symbols = data_utils.load_wordsymbols(
        args.data_config, is_writer, distributed=args.run_config.distributed)
    logger.info(f"Loading entity_symbols...")
    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_config.entity_dir,
                              args.data_config.entity_map_dir),
        alias_cand_map_file=args.data_config.alias_cand_map)
    logger.info(
        f"Loaded entity_symbols with {entity_symbols.num_entities} entities.")
    # Get train dataset
    train_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.train_dataset, is_writer, dataset_is_eval=False)
    train_dataset = data_utils.create_dataset(
        args,
        args.data_config.train_dataset,
        is_writer,
        word_symbols,
        entity_symbols,
        slice_dataset=train_slice_dataset,
        dataset_is_eval=False)
    train_dataloader, train_sampler = data_utils.create_dataloader(
        args,
        train_dataset,
        eval_slice_dataset=None,
        world_size=world_size,
        rank=rank,
        batch_size=args.train_config.batch_size)
    # Repeat for dev
    dev_dataset_collection = {}
    dev_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.dev_dataset, is_writer, dataset_is_eval=True)
    dev_dataset = data_utils.create_dataset(args,
                                            args.data_config.dev_dataset,
                                            is_writer,
                                            word_symbols,
                                            entity_symbols,
                                            slice_dataset=dev_slice_dataset,
                                            dataset_is_eval=True)
    dev_dataloader, dev_sampler = data_utils.create_dataloader(
        args,
        dev_dataset,
        eval_slice_dataset=dev_slice_dataset,
        batch_size=args.run_config.eval_batch_size)
    dataset_collection = DatasetCollection(args.data_config.dev_dataset,
                                           args.data_config.dev_dataset.file,
                                           dev_dataset, dev_dataloader,
                                           dev_slice_dataset, dev_sampler)
    dev_dataset_collection[
        args.data_config.dev_dataset.file] = dataset_collection
    eval_slice_names = args.run_config.eval_slices
    total_steps_per_epoch = len(train_dataloader)
    # Create trainer---model, optimizer, and scorer
    trainer = Trainer(args,
                      entity_symbols,
                      word_symbols,
                      total_steps_per_epoch=total_steps_per_epoch,
                      eval_slice_names=eval_slice_names,
                      resume_model_file=args.run_config.init_checkpoint)
    # Set up epochs and intervals for saving and evaluating
    max_epochs = int(args.run_config.max_epochs)
    eval_steps = int(args.run_config.eval_steps)
    log_steps = int(args.run_config.log_steps)
    # Save every k-th eval interval, but at least every step.
    save_steps = max(int(args.run_config.save_every_k_eval * eval_steps), 1)
    logger.info(
        f"Eval steps {eval_steps}, Log steps {log_steps}, Save steps {save_steps}, Total training examples per epoch {len(train_dataset)}"
    )
    status_reporter = StatusReporter(args,
                                     logger,
                                     is_writer,
                                     max_epochs,
                                     total_steps_per_epoch,
                                     is_eval=False)
    # global_step counts batches across all epochs and drives log/save/eval cadence.
    global_step = 0
    for epoch in range(trainer.start_epoch, trainer.start_epoch + max_epochs):
        # this is to fix having to save/restore the RNG state for checkpointing
        torch.manual_seed(args.train_config.seed + epoch)
        np.random.seed(args.train_config.seed + epoch)
        if args.run_config.distributed:
            # for determinism across runs https://github.com/pytorch/examples/issues/501
            train_sampler.set_epoch(epoch)
        start_time_load = time.time()
        for i, batch in enumerate(train_dataloader):
            # load_time measures dataloader latency between steps.
            load_time = time.time() - start_time_load
            start_time = time.time()
            _, loss_pack, _, _ = trainer.update(batch)
            # Log progress
            if (global_step + 1) % log_steps == 0:
                duration = time.time() - start_time
                status_reporter.step_status(epoch=epoch,
                                            step=global_step,
                                            loss_pack=loss_pack,
                                            time=duration,
                                            load_time=load_time,
                                            lr=trainer.get_lr())
            # Save model
            if (global_step + 1) % save_steps == 0 and is_writer:
                logger.info("Saving model...")
                trainer.save(save_dir=train_utils.get_save_folder(
                    args.run_config),
                             epoch=epoch,
                             step=global_step,
                             step_in_batch=i,
                             suffix=args.run_config.model_suffix)
            # Run evaluation
            if (global_step + 1) % eval_steps == 0:
                eval_utils.run_eval_all_dev_sets(args, global_step,
                                                 dev_dataset_collection,
                                                 logger, status_reporter,
                                                 trainer)
                if args.run_config.distributed:
                    dist.barrier()
            global_step += 1
            # Time loading new batch
            start_time_load = time.time()
        ######### END OF EPOCH
        if is_writer:
            logger.info(f"Saving model end of epoch {epoch}...")
            trainer.save(save_dir=train_utils.get_save_folder(args.run_config),
                         epoch=epoch,
                         step=global_step,
                         step_in_batch=i,
                         end_of_epoch=True,
                         suffix=args.run_config.model_suffix)
        # Always run eval when saving -- if this coincided with eval_step, then don't need to rerun eval
        if (global_step + 1) % eval_steps != 0:
            eval_utils.run_eval_all_dev_sets(args, global_step,
                                             dev_dataset_collection, logger,
                                             status_reporter, trainer)
    # Final checkpoint after all epochs complete.
    # NOTE(review): `i` and `epoch` here are leftovers from the loops above and
    # would be unbound if max_epochs == 0 or the dataloader is empty -- confirm
    # those cases cannot occur.
    if is_writer:
        logger.info("Saving model...")
        trainer.save(save_dir=train_utils.get_save_folder(args.run_config),
                     epoch=epoch,
                     step=global_step,
                     step_in_batch=i,
                     end_of_epoch=True,
                     last_epoch=True,
                     suffix=args.run_config.model_suffix)
    if args.run_config.distributed:
        dist.barrier()
def test_add_entity(self):
    """Adding a new entity updates eids, alias maps, and counts consistently."""
    alias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2], ["Q3", 1]],
        "alias3": [["Q1", 30.0]],
        "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
    }
    qid2title = {
        "Q1": "alias1",
        "Q2": "multi alias2",
        "Q3": "word alias3",
        "Q4": "nonalias4",
    }
    max_candidates = 3
    # edit_mode=True keeps the mutable dict representation (no trie).
    entity_symbols = EntitySymbols(
        max_candidates=max_candidates,
        alias2qids=alias2qids,
        qid2title=qid2title,
        edit_mode=True,
    )
    trueqid2aliases = {
        "Q1": {"alias1", "multi word alias2", "alias3"},
        "Q2": {"multi word alias2", "alias4"},
        "Q3": {"alias4"},
        "Q4": {"alias1", "multi word alias2", "alias4"},
    }
    # Candidate lists are truncated to max_candidates (Q3 drops off alias2).
    truealias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
        "alias3": [["Q1", 30.0]],
        "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
    }
    trueqid2eid = {"Q1": 1, "Q2": 2, "Q3": 3, "Q4": 4}
    truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
    trueid2alias = {0: "alias1", 1: "alias3", 2: "alias4", 3: "multi word alias2"}
    truemax_eid = 4
    truemax_alid = 3
    truenum_entities = 4
    truenum_entities_with_pad_and_nocand = 6
    self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
    self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
    self.assertDictEqual(
        entity_symbols._eid2qid, {v: i for i, v in trueqid2eid.items()}
    )
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertIsNone(entity_symbols._alias_trie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
    self.assertEqual(entity_symbols.max_eid, truemax_eid)
    self.assertEqual(entity_symbols.max_alid, truemax_alid)
    self.assertEqual(entity_symbols.num_entities, truenum_entities)
    self.assertEqual(
        entity_symbols.num_entities_with_pad_and_nocand,
        truenum_entities_with_pad_and_nocand,
    )
    # Add entity
    entity_symbols.add_entity(
        "Q5", [["multi word alias2", 1.5], ["alias5", 20.0]], "Snake"
    )
    trueqid2aliases = {
        "Q1": {"alias1", "multi word alias2", "alias3"},
        "Q2": {"multi word alias2", "alias4"},
        "Q3": {"alias4"},
        "Q4": {"alias1", "alias4"},
        "Q5": {"multi word alias2", "alias5"},
    }
    truealias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [
            ["Q2", 5.0],
            ["Q1", 3],
            ["Q5", 1.5],
        ],  # adding new entity-mention pair - we override scores to add it. Hence Q4 is removed
        "alias3": [["Q1", 30.0]],
        "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        "alias5": [["Q5", 20]],
    }
    # Q5 takes the next eid and alias5 the next alias id.
    trueqid2eid = {"Q1": 1, "Q2": 2, "Q3": 3, "Q4": 4, "Q5": 5}
    truealias2id = {
        "alias1": 0,
        "alias3": 1,
        "alias4": 2,
        "multi word alias2": 3,
        "alias5": 4,
    }
    trueid2alias = {
        0: "alias1",
        1: "alias3",
        2: "alias4",
        3: "multi word alias2",
        4: "alias5",
    }
    truemax_eid = 5
    truemax_alid = 4
    truenum_entities = 5
    truenum_entities_with_pad_and_nocand = 7
    self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
    self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
    self.assertDictEqual(
        entity_symbols._eid2qid, {v: i for i, v in trueqid2eid.items()}
    )
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertIsNone(entity_symbols._alias_trie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
    self.assertEqual(entity_symbols.max_eid, truemax_eid)
    self.assertEqual(entity_symbols.max_alid, truemax_alid)
    self.assertEqual(entity_symbols.num_entities, truenum_entities)
    self.assertEqual(
        entity_symbols.num_entities_with_pad_and_nocand,
        truenum_entities_with_pad_and_nocand,
    )
def test_add_remove_mention(self):
    """Exercise add_mention/remove_mention/set_score through a sequence of edits."""
    alias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2], ["Q3", 1]],
        "alias3": [["Q1", 30.0]],
        "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
    }
    qid2title = {
        "Q1": "alias1",
        "Q2": "multi alias2",
        "Q3": "word alias3",
        "Q4": "nonalias4",
    }
    max_candidates = 3
    # the non-candidate class is included in entity_dump
    trueqid2eid = {"Q1": 1, "Q2": 2, "Q3": 3, "Q4": 4}
    truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
    truealiastrie = {"multi word alias2": 0, "alias1": 1, "alias3": 2, "alias4": 3}
    truealias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
        "alias3": [["Q1", 30.0]],
        "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
    }
    # First construct WITHOUT edit_mode: the alias trie exists and edits must fail.
    entity_symbols = EntitySymbols(
        max_candidates=max_candidates,
        alias2qids=alias2qids,
        qid2title=qid2title,
    )
    tri_as_dict = {}
    for k in entity_symbols._alias_trie:
        tri_as_dict[k] = entity_symbols._alias_trie[k]
    self.assertEqual(entity_symbols.max_candidates, 3)
    self.assertEqual(entity_symbols.max_eid, 4)
    self.assertEqual(entity_symbols.max_alid, 3)
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertDictEqual(entity_symbols._qid2title, qid2title)
    self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
    self.assertDictEqual(tri_as_dict, truealiastrie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertIsNone(entity_symbols._qid2aliases)
    # Check if fails with edit_mode = False
    with self.assertRaises(AttributeError) as context:
        entity_symbols.add_mention("Q2", "alias3", 31.0)
    assert type(context.exception) is AttributeError
    # Rebuild in edit mode for the mutation sequence below.
    entity_symbols = EntitySymbols(
        max_candidates=max_candidates,
        alias2qids=alias2qids,
        qid2title=qid2title,
        edit_mode=True,
    )
    # Check nothing changes if pair doesn't exist
    entity_symbols.remove_mention("Q3", "alias1")
    trueqid2aliases = {
        "Q1": {"alias1", "multi word alias2", "alias3"},
        "Q2": {"multi word alias2", "alias4"},
        "Q3": {"alias4"},
        "Q4": {"alias1", "multi word alias2", "alias4"},
    }
    self.assertEqual(entity_symbols.max_candidates, 3)
    self.assertEqual(entity_symbols.max_eid, 4)
    self.assertEqual(entity_symbols.max_alid, 3)
    self.assertDictEqual(entity_symbols._qid2title, qid2title)
    self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
    self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertIsNone(entity_symbols._alias_trie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    # ADD Q2 ALIAS 3
    entity_symbols.add_mention("Q2", "alias3", 31.0)
    trueqid2aliases = {
        "Q1": {"alias1", "multi word alias2", "alias3"},
        "Q2": {"multi word alias2", "alias4", "alias3"},
        "Q3": {"alias4"},
        "Q4": {"alias1", "multi word alias2", "alias4"},
    }
    # Q2 is inserted into alias3's candidates ahead of Q1 by its higher score.
    truealias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
        "alias3": [["Q2", 31.0], ["Q1", 30.0]],
        "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
    }
    truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
    trueid2alias = {0: "alias1", 1: "alias3", 2: "alias4", 3: "multi word alias2"}
    self.assertEqual(entity_symbols.max_eid, 4)
    self.assertEqual(entity_symbols.max_alid, 3)
    self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertIsNone(entity_symbols._alias_trie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
    # ADD Q1 ALIAS 4
    entity_symbols.add_mention("Q1", "alias4", 31.0)
    # Q1 displaces Q2 (lowest score) from alias4 at max_candidates=3.
    trueqid2aliases = {
        "Q1": {"alias1", "multi word alias2", "alias3", "alias4"},
        "Q2": {"multi word alias2", "alias3"},
        "Q3": {"alias4"},
        "Q4": {"alias1", "multi word alias2", "alias4"},
    }
    truealias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
        "alias3": [["Q2", 31.0], ["Q1", 30.0]],
        "alias4": [["Q1", 31.0], ["Q4", 20], ["Q3", 15.0]],
    }
    truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
    trueid2alias = {0: "alias1", 1: "alias3", 2: "alias4", 3: "multi word alias2"}
    self.assertEqual(entity_symbols.max_eid, 4)
    self.assertEqual(entity_symbols.max_alid, 3)
    self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertIsNone(entity_symbols._alias_trie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
    # REMOVE Q3 ALIAS 4
    entity_symbols.remove_mention("Q3", "alias4")
    # Q3 now has no aliases but remains an entity (eids unchanged).
    trueqid2aliases = {
        "Q1": {"alias1", "multi word alias2", "alias3", "alias4"},
        "Q2": {"multi word alias2", "alias3"},
        "Q3": set(),
        "Q4": {"alias1", "multi word alias2", "alias4"},
    }
    truealias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
        "alias3": [["Q2", 31.0], ["Q1", 30.0]],
        "alias4": [["Q1", 31.0], ["Q4", 20]],
    }
    truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
    trueid2alias = {0: "alias1", 1: "alias3", 2: "alias4", 3: "multi word alias2"}
    self.assertEqual(entity_symbols.max_eid, 4)
    self.assertEqual(entity_symbols.max_alid, 3)
    self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertIsNone(entity_symbols._alias_trie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
    # REMOVE Q4 ALIAS 4
    entity_symbols.remove_mention("Q4", "alias4")
    trueqid2aliases = {
        "Q1": {"alias1", "multi word alias2", "alias3", "alias4"},
        "Q2": {"multi word alias2", "alias3"},
        "Q3": set(),
        "Q4": {"alias1", "multi word alias2"},
    }
    truealias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
        "alias3": [["Q2", 31.0], ["Q1", 30.0]],
        "alias4": [["Q1", 31.0]],
    }
    truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
    trueid2alias = {0: "alias1", 1: "alias3", 2: "alias4", 3: "multi word alias2"}
    self.assertEqual(entity_symbols.max_eid, 4)
    self.assertEqual(entity_symbols.max_alid, 3)
    self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertIsNone(entity_symbols._alias_trie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
    # REMOVE Q1 ALIAS 4
    entity_symbols.remove_mention("Q1", "alias4")
    # alias4 now has no candidates: it disappears from alias2qids and the id maps.
    trueqid2aliases = {
        "Q1": {"alias1", "multi word alias2", "alias3"},
        "Q2": {"multi word alias2", "alias3"},
        "Q3": set(),
        "Q4": {"alias1", "multi word alias2"},
    }
    truealias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
        "alias3": [["Q2", 31.0], ["Q1", 30.0]],
    }
    truealias2id = {"alias1": 0, "alias3": 1, "multi word alias2": 3}
    trueid2alias = {0: "alias1", 1: "alias3", 3: "multi word alias2"}
    self.assertEqual(entity_symbols.max_eid, 4)
    self.assertEqual(entity_symbols.max_alid, 3)
    self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertIsNone(entity_symbols._alias_trie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
    # ADD Q1 BLIAS 0
    entity_symbols.add_mention("Q1", "blias0", 11)
    # New alias gets a fresh id (4) even though id 2 was freed; max_alid grows.
    trueqid2aliases = {
        "Q1": {"alias1", "multi word alias2", "alias3", "blias0"},
        "Q2": {"multi word alias2", "alias3"},
        "Q3": set(),
        "Q4": {"alias1", "multi word alias2"},
    }
    truealias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
        "alias3": [["Q2", 31.0], ["Q1", 30.0]],
        "blias0": [["Q1", 11.0]],
    }
    truealias2id = {"alias1": 0, "alias3": 1, "multi word alias2": 3, "blias0": 4}
    trueid2alias = {0: "alias1", 1: "alias3", 3: "multi word alias2", 4: "blias0"}
    self.assertEqual(entity_symbols.max_eid, 4)
    self.assertEqual(entity_symbols.max_alid, 4)
    self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertIsNone(entity_symbols._alias_trie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
    # SET SCORE Q2 ALIAS3
    # Check if fails not a pair
    with self.assertRaises(ValueError) as context:
        entity_symbols.set_score("Q2", "alias1", 2)
    assert type(context.exception) is ValueError
    entity_symbols.set_score("Q2", "alias3", 2)
    # Lowering Q2's score re-sorts alias3's candidate list below Q1.
    trueqid2aliases = {
        "Q1": {"alias1", "multi word alias2", "alias3", "blias0"},
        "Q2": {"multi word alias2", "alias3"},
        "Q3": set(),
        "Q4": {"alias1", "multi word alias2"},
    }
    truealias2qids = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
        "alias3": [["Q1", 30.0], ["Q2", 2]],
        "blias0": [["Q1", 11.0]],
    }
    truealias2id = {"alias1": 0, "alias3": 1, "multi word alias2": 3, "blias0": 4}
    trueid2alias = {0: "alias1", 1: "alias3", 3: "multi word alias2", 4: "blias0"}
    self.assertEqual(entity_symbols.max_eid, 4)
    self.assertEqual(entity_symbols.max_alid, 4)
    self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
    self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
    self.assertIsNone(entity_symbols._alias_trie)
    self.assertDictEqual(entity_symbols._alias2id, truealias2id)
    self.assertDictEqual(entity_symbols._id2alias, trueid2alias)
def test_getters(self):
    """Exercise the read-only accessors of EntitySymbols."""
    cand_map = {
        "alias1": [["Q1", 10.0], ["Q4", 6]],
        "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
        "alias3": [["Q1", 30.0]],
        "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
    }
    titles = {
        "Q1": "alias1",
        "Q2": "multi alias2",
        "Q3": "word alias3",
        "Q4": "nonalias4",
    }
    entity_symbols = EntitySymbols(
        max_candidates=3,
        alias2qids=cand_map,
        qid2title=titles,
    )
    # eid <-> qid lookups (eids start at 1 per the construction above)
    self.assertEqual(entity_symbols.get_qid(1), "Q1")
    self.assertSetEqual(
        set(entity_symbols.get_all_aliases()),
        {"alias1", "multi word alias2", "alias3", "alias4"},
    )
    self.assertEqual(entity_symbols.get_eid("Q3"), 3)
    # Candidate lists, plain and padded out to max_candidates with "-1"/-1
    self.assertListEqual(entity_symbols.get_qid_cands("alias1"), ["Q1", "Q4"])
    self.assertListEqual(
        entity_symbols.get_qid_cands("alias1", max_cand_pad=True),
        ["Q1", "Q4", "-1"],
    )
    self.assertListEqual(
        entity_symbols.get_eid_cands("alias1", max_cand_pad=True), [1, 4, -1]
    )
    # Title, alias-index, and existence lookups
    self.assertEqual(entity_symbols.get_title("Q1"), "alias1")
    self.assertEqual(entity_symbols.get_alias_idx("alias1"), 0)
    self.assertEqual(entity_symbols.get_alias_from_idx(1), "alias3")
    self.assertEqual(entity_symbols.alias_exists("alias3"), True)
    self.assertEqual(entity_symbols.alias_exists("alias5"), False)