def test_checks(self):
    data = [
        {
            "entity_id": "Q123",
            "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        },
        {
            "entity_id": "Q345",
            "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
            "title": "Cat",
            "types": {"hyena": ["animal"], "wiki": ["cat"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        },
    ]
    self.write_data(self.profile_file, data)
    entity_profile = EntityProfile.load_from_jsonl(
        self.profile_file, max_candidates=5
    )
    with self.assertRaises(AttributeError) as context:
        entity_profile.add_relation("Q345", "sibling", "Q123")
    assert type(context.exception) is AttributeError

    entity_profile = EntityProfile.load_from_jsonl(
        self.profile_file, max_candidates=5, edit_mode=True
    )
    with self.assertRaises(ValueError) as context:
        entity_profile.add_relation("Q789", "sibling", "Q123")
    assert type(context.exception) is ValueError
    assert "is not in our dump" in str(context.exception)

    with self.assertRaises(ValueError) as context:
        entity_profile.add_relation(qid="Q789", relation="sibling", qid2="Q123")
    assert type(context.exception) is ValueError
    assert "is not in our dump" in str(context.exception)

    with self.assertRaises(ValueError) as context:
        entity_profile.add_type(qid="Q345", type="sibling", type_system="blah")
    assert type(context.exception) is ValueError
    assert "type system blah is not one" in str(context.exception)

    with self.assertRaises(ValueError) as context:
        entity_profile.get_types(qid="Q345", type_system="blah")
    assert type(context.exception) is ValueError
    assert "type system blah is not one" in str(context.exception)
def test_profile_load_jsonl_errors(self):
    data = [
        {
            # entity_id is intentionally an int (not a string) to trigger
            # schema validation
            "entity_id": 123,
            "mentions": [["dog"], ["dogg"], ["animal"]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        },
    ]
    self.write_data(self.profile_file, data)
    with self.assertRaises(ValidationError) as context:
        EntityProfile._read_profile_file(self.profile_file)
    assert type(context.exception) is ValidationError
def fit_profiles(args):
    print(json.dumps(vars(args), indent=4))

    if args.model_config is not None:
        assert (
            args.save_model_config is not None
        ), "If you pass in a model config, you must pass in a model save config path"

    print(f"Loading train entity profile from {args.train_entity_profile}")
    train_entity_profile = EntityProfile.load_from_cache(
        load_dir=args.train_entity_profile
    )
    print(f"Loading new entity profile from {args.new_entity_profile}")
    new_entity_profile = EntityProfile.load_from_cache(
        load_dir=args.new_entity_profile
    )

    oldqid2newqid = dict()
    newqid2oldqid = dict()
    if args.oldqid2newqid is not None and len(args.oldqid2newqid) > 0:
        with open(args.oldqid2newqid) as in_f:
            oldqid2newqid = ujson.load(in_f)
            newqid2oldqid = {v: k for k, v in oldqid2newqid.items()}
            assert len(oldqid2newqid) == len(
                newqid2oldqid
            ), "The dicts of oldqid2newqid and its inverse do not have the same length"

    np_removed_ents, np_same_ents, np_new_ents, oldeid2neweid = match_entities(
        train_entity_profile, new_entity_profile, oldqid2newqid, newqid2oldqid
    )
    neweid2oldeid = {v: k for k, v in oldeid2neweid.items()}
    assert len(oldeid2neweid) == len(
        neweid2oldeid
    ), "The lengths of oldeid2neweid and neweid2oldeid don't match"

    state_dict, model_state_dict = load_statedict(args.model_path)

    # We do not support modifying a topK model, only the original model.
    # Probe for the topK keys first; raising inside the same try block would
    # let the except clause swallow the NotImplementedError.
    has_topk_emb = True
    try:
        get_nested_item(model_state_dict, ENTITY_TOPK_KEYS)
    except Exception:
        has_topk_emb = False
    if has_topk_emb:
        raise NotImplementedError(
            "We don't support fitting a topK mini model. Instead, call `fit_to_profile` on the full Bootleg model. "
            "Then call utils.entity_profile.compress_topk_entity_embeddings to create your own mini model."
        )

    try:
        weight_shape = get_nested_item(model_state_dict, ENTITY_EMB_KEYS).shape[1]
    except Exception:
        raise ValueError(f"ERROR: All of {ENTITY_EMB_KEYS} are not in model_state_dict")

    if args.init_vec_file is not None and len(args.init_vec_file) > 0:
        vector_for_new_ent = np.load(args.init_vec_file)
    else:
        print("Setting init vector to be all zeros")
        vector_for_new_ent = np.zeros(weight_shape)
    new_model_state_dict = refit_weights(
        np_same_ents,
        neweid2oldeid,
        train_entity_profile,
        new_entity_profile,
        vector_for_new_ent,
        model_state_dict,
    )
    # Debug: show the refit submodule keys
    print(new_model_state_dict["module_pool"]["learned"].keys())
    state_dict["model"] = new_model_state_dict
    print(f"Saving model at {args.save_model_path}")
    torch.save(state_dict, args.save_model_path)

    if args.model_config is not None:
        modify_config(
            args.model_config,
            args.save_model_config,
            args.save_model_path,
            args.new_entity_profile,
        )
def test_match_entities(self):
    # TEST ADD
    train_entity_profile = EntityProfile.load_from_cache(
        load_dir=self.train_db,
        edit_mode=True,
    )
    # Modify train profile
    new_entity1 = {
        "entity_id": "Q910",
        "mentions": [["cobra", 10.0], ["animal", 3.0]],
        "title": "Cobra",
        "types": {"hyena": ["animal"], "wiki": ["dog"]},
        "relations": [{"relation": "sibling", "object": "Q123"}],
    }
    new_entity2 = {
        "entity_id": "Q101",
        "mentions": [["snake", 10.0], ["snakes", 7.0], ["animal", 3.0]],
        "title": "Snake",
        "types": {"hyena": ["animal"], "wiki": ["dog"]},
    }
    train_entity_profile.add_entity(new_entity1)
    train_entity_profile.add_entity(new_entity2)
    train_entity_profile.reidentify_entity("Q123", "Q321")
    # Save new profile
    train_entity_profile.save(self.new_db)

    train_entity_profile = EntityProfile.load_from_cache(load_dir=self.train_db)
    new_entity_profile = EntityProfile.load_from_cache(load_dir=self.new_db)
    oldqid2newqid = {"Q123": "Q321"}
    newqid2oldqid = {v: k for k, v in oldqid2newqid.items()}
    np_removed_ents, np_same_ents, np_new_ents, oldeid2neweid = match_entities(
        train_entity_profile, new_entity_profile, oldqid2newqid, newqid2oldqid
    )
    gold_removed_ents = set()
    gold_same_ents = {"Q321", "Q345", "Q567", "Q789"}
    gold_new_ents = {"Q910", "Q101"}
    gold_oldeid2neweid = {1: 1, 2: 2, 3: 3, 4: 4}
    self.assertSetEqual(gold_removed_ents, np_removed_ents)
    self.assertSetEqual(gold_same_ents, np_same_ents)
    self.assertSetEqual(gold_new_ents, np_new_ents)
    self.assertDictEqual(gold_oldeid2neweid, oldeid2neweid)

    # TEST PRUNE - this profile already has Q910 and Q101
    new_entity_profile = EntityProfile.load_from_cache(
        load_dir=self.new_db, edit_mode=True
    )
    new_entity_profile.prune_to_entities(
        {"Q321", "Q910", "Q101"}
    )  # These now get eids 1, 2, 3
    # Manually set the eids for the test
    new_entity_profile._entity_symbols._qid2eid = {"Q321": 3, "Q910": 2, "Q101": 1}
    new_entity_profile._entity_symbols._eid2qid = {3: "Q321", 2: "Q910", 1: "Q101"}
    # Save new profile
    new_entity_profile.save(self.new_db)

    train_entity_profile = EntityProfile.load_from_cache(load_dir=self.train_db)
    new_entity_profile = EntityProfile.load_from_cache(load_dir=self.new_db)
    oldqid2newqid = {"Q123": "Q321"}
    newqid2oldqid = {v: k for k, v in oldqid2newqid.items()}
    np_removed_ents, np_same_ents, np_new_ents, oldeid2neweid = match_entities(
        train_entity_profile, new_entity_profile, oldqid2newqid, newqid2oldqid
    )
    gold_removed_ents = {"Q345", "Q567", "Q789"}
    gold_same_ents = {"Q321"}
    gold_new_ents = {"Q910", "Q101"}
    gold_oldeid2neweid = {1: 3}
    self.assertSetEqual(gold_removed_ents, np_removed_ents)
    self.assertSetEqual(gold_same_ents, np_same_ents)
    self.assertSetEqual(gold_new_ents, np_new_ents)
    self.assertDictEqual(gold_oldeid2neweid, oldeid2neweid)
def test_fit_entities(self):
    train_entity_profile = EntityProfile.load_from_cache(
        load_dir=self.train_db,
        edit_mode=True,
    )
    # Modify train profile
    new_entity1 = {
        "entity_id": "Q910",
        "mentions": [["cobra", 10.0], ["animal", 3.0]],
        "title": "Cobra",
        "types": {"hyena": ["animal"], "wiki": ["dog"]},
        "relations": [{"relation": "sibling", "object": "Q123"}],
    }
    new_entity2 = {
        "entity_id": "Q101",
        "mentions": [["snake", 10.0], ["snakes", 7.0], ["animal", 3.0]],
        "title": "Snake",
        "types": {"hyena": ["animal"], "wiki": ["dog"]},
    }
    train_entity_profile.add_entity(new_entity1)
    train_entity_profile.add_entity(new_entity2)
    train_entity_profile.reidentify_entity("Q123", "Q321")
    # Save new profile
    train_entity_profile.save(self.new_db)

    train_entity_profile = EntityProfile.load_from_cache(load_dir=self.train_db)
    new_entity_profile = EntityProfile.load_from_cache(load_dir=self.new_db)
    neweid2oldeid = {1: 1, 2: 2, 3: 3, 4: 4}
    np_same_ents = {"Q321", "Q345", "Q567", "Q789"}
    state_dict = {
        "module_pool": {
            "learned": {
                # 4 entities + 1 UNK + 1 PAD for train data
                "learned_entity_embedding.weight": torch.tensor(
                    [
                        [1.0, 2, 3, 4, 5],
                        [2, 2, 2, 2, 2],
                        [3, 3, 3, 3, 3],
                        [4, 4, 4, 4, 4],
                        [5, 5, 5, 5, 5],
                        [0, 0, 0, 0, 0],
                    ]
                )
            }
        }
    }
    vector_for_new_ent = np.arange(5)
    new_state_dict = refit_weights(
        np_same_ents,
        neweid2oldeid,
        train_entity_profile,
        new_entity_profile,
        vector_for_new_ent,
        state_dict,
    )
    gold_state_dict = {
        "module_pool": {
            "learned": {
                "learned_entity_embedding.weight": torch.tensor(
                    [
                        [1.0, 2, 3, 4, 5],
                        [2, 2, 2, 2, 2],
                        [3, 3, 3, 3, 3],
                        [4, 4, 4, 4, 4],
                        [5, 5, 5, 5, 5],
                        [0, 1, 2, 3, 4],
                        [0, 1, 2, 3, 4],
                        [0, 0, 0, 0, 0],
                    ]
                )
            }
        }
    }
    gld = gold_state_dict
    nsd = new_state_dict
    keys_to_check = ["module_pool", "learned", "learned_entity_embedding.weight"]
    for k in keys_to_check:
        assert k in nsd
        assert k in gld
        if type(gld[k]) is dict:
            gld = gld[k]
            nsd = nsd[k]
            continue
        else:
            assert torch.equal(nsd[k], gld[k])

    # TEST WITH EIDREG
    state_dict = {
        "module_pool": {
            "learned": {
                # 4 entities + 1 UNK + 1 PAD for train data
                "learned_entity_embedding.weight": torch.tensor(
                    [
                        [1.0, 2, 3, 4, 5],
                        [2, 2, 2, 2, 2],
                        [3, 3, 3, 3, 3],
                        [4, 4, 4, 4, 4],
                        [5, 5, 5, 5, 5],
                        [0, 0, 0, 0, 0],
                    ]
                ),
                "eid2reg": torch.tensor([0.0, 0.2, 0.3, 0.4, 0.5, 0.0]),
            }
        }
    }
    vector_for_new_ent = np.arange(5)
    new_state_dict = refit_weights(
        np_same_ents,
        neweid2oldeid,
        train_entity_profile,
        new_entity_profile,
        vector_for_new_ent,
        state_dict,
    )
    gold_state_dict = {
        "module_pool": {
            "learned": {
                "learned_entity_embedding.weight": torch.tensor(
                    [
                        [1.0, 2, 3, 4, 5],
                        [2, 2, 2, 2, 2],
                        [3, 3, 3, 3, 3],
                        [4, 4, 4, 4, 4],
                        [5, 5, 5, 5, 5],
                        [0, 1, 2, 3, 4],
                        [0, 1, 2, 3, 4],
                        [0, 0, 0, 0, 0],
                    ]
                ),
                "eid2reg": torch.tensor([0.0, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.0]),
            }
        }
    }
    gld = gold_state_dict
    nsd = new_state_dict
    keys_to_check = ["module_pool", "learned", "learned_entity_embedding.weight"]
    for k in keys_to_check:
        assert k in nsd
        assert k in gld
        if type(gld[k]) is dict:
            gld = gld[k]
            nsd = nsd[k]
            continue
        else:
            assert torch.equal(nsd[k], gld[k])

    gld = gold_state_dict
    nsd = new_state_dict
    keys_to_check = ["module_pool", "learned", "eid2reg"]
    for k in keys_to_check:
        assert k in nsd
        assert k in gld
        if type(gld[k]) is dict:
            gld = gld[k]
            nsd = nsd[k]
            continue
        else:
            assert torch.equal(nsd[k], gld[k])

    # TEST WITH PRUNE - this profile already has Q910 and Q101
    new_entity_profile = EntityProfile.load_from_cache(
        load_dir=self.new_db, edit_mode=True
    )
    new_entity_profile.prune_to_entities(
        {"Q321", "Q910", "Q101"}
    )  # These now get eids 1, 2, 3
    # Manually set the eids for the test
    new_entity_profile._entity_symbols._qid2eid = {"Q321": 3, "Q910": 2, "Q101": 1}
    new_entity_profile._entity_symbols._eid2qid = {3: "Q321", 2: "Q910", 1: "Q101"}
    new_entity_profile.save(self.new_db)

    train_entity_profile = EntityProfile.load_from_cache(load_dir=self.train_db)
    new_entity_profile = EntityProfile.load_from_cache(load_dir=self.new_db)
    neweid2oldeid = {3: 1}
    np_same_ents = {"Q321"}
    state_dict = {
        "module_pool": {
            "learned": {
                # 4 entities + 1 UNK + 1 PAD for train data
                "learned_entity_embedding.weight": torch.tensor(
                    [
                        [1.0, 2, 3, 4, 5],
                        [2, 2, 2, 2, 2],
                        [3, 3, 3, 3, 3],
                        [4, 4, 4, 4, 4],
                        [5, 5, 5, 5, 5],
                        [0, 0, 0, 0, 0],
                    ]
                ),
                "eid2reg": torch.tensor([0.0, 0.2, 0.3, 0.4, 0.5, 0.0]),
            }
        }
    }
    vector_for_new_ent = np.arange(5)
    new_state_dict = refit_weights(
        np_same_ents,
        neweid2oldeid,
        train_entity_profile,
        new_entity_profile,
        vector_for_new_ent,
        state_dict,
    )
    gold_state_dict = {
        "module_pool": {
            "learned": {
                "learned_entity_embedding.weight": torch.tensor(
                    [
                        [1.0, 2, 3, 4, 5],
                        [0, 1, 2, 3, 4],
                        [0, 1, 2, 3, 4],
                        [2, 2, 2, 2, 2],
                        [0, 0, 0, 0, 0],
                    ]
                ),
                "eid2reg": torch.tensor([0.0, 0.5, 0.5, 0.2, 0.0]),
            }
        }
    }
    gld = gold_state_dict
    nsd = new_state_dict
    keys_to_check = ["module_pool", "learned", "learned_entity_embedding.weight"]
    for k in keys_to_check:
        assert k in nsd
        assert k in gld
        if type(gld[k]) is dict:
            gld = gld[k]
            nsd = nsd[k]
            continue
        else:
            assert torch.equal(nsd[k], gld[k])

    gld = gold_state_dict
    nsd = new_state_dict
    keys_to_check = ["module_pool", "learned", "eid2reg"]
    for k in keys_to_check:
        assert k in nsd
        assert k in gld
        if type(gld[k]) is dict:
            gld = gld[k]
            nsd = nsd[k]
            continue
        else:
            assert torch.equal(nsd[k], gld[k])
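# NOTE: the gld/nsd key-walking loop above is repeated for every checked leaf.
# If these tests are ever refactored, it could be pulled into a helper along
# these lines (a sketch, not part of the original test file):
def assert_nested_tensor_equal(actual, gold, keys):
    """Walk `actual` and `gold` along `keys` and compare the tensor leaf."""
    for k in keys:
        assert k in actual
        assert k in gold
        if isinstance(gold[k], dict):
            actual, gold = actual[k], gold[k]
        else:
            assert torch.equal(actual[k], gold[k])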
def setUp(self) -> None:
    self.dir = Path("test/data/fit_to_profile_test")
    self.train_db = Path(self.dir / "train_entity_db")
    self.train_db.mkdir(exist_ok=True, parents=True)
    self.new_db = Path(self.dir / "entity_db_save2")
    self.new_db.mkdir(exist_ok=True, parents=True)
    self.profile_file = Path(self.dir / "raw_data/entity_profile.jsonl")
    self.profile_file.parent.mkdir(exist_ok=True, parents=True)
    # Dump train profile data
    data = [
        {
            "entity_id": "Q123",
            "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        },
        {
            "entity_id": "Q345",
            "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
            "title": "Cat",
            "types": {"hyena": ["animal"], "wiki": ["cat"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        },
        # Missing type system
        {
            "entity_id": "Q567",
            "mentions": [["catt", 6.5], ["animal", 3.3]],
            "title": "Catt",
            "types": {"hyena": ["animal", "animall"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        },
        # No KG/Types
        {
            "entity_id": "Q789",
            "mentions": [["animal", 12.2]],
            "title": "Dogg",
        },
    ]
    self.write_data(self.profile_file, data)
    ep = EntityProfile.load_from_jsonl(self.profile_file)
    ep.save(self.train_db)
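# NOTE: the tests in this section call a `write_data` helper that is not shown
# here. A minimal sketch of what it is assumed to do (serialize one JSON object
# per line, i.e. jsonl); the body below is an assumption, not the original
# helper:
def write_data(self, file, data):
    """Write a list of dicts to `file` as one JSON object per line."""
    with open(file, "w") as out_f:
        for obj in data:
            out_f.write(ujson.dumps(obj) + "\n")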
def test_end2end(self):
    # ======================
    # PART 1: TRAIN A SMALL MODEL WITH ONE PROFILE DUMP
    # ======================
    # Generate entity profile data
    data = [
        {
            "entity_id": "Q123",
            "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        },
        {
            "entity_id": "Q345",
            "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
            "title": "Cat",
            "types": {"hyena": ["animal"], "wiki": ["cat"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        },
        # Missing type system
        {
            "entity_id": "Q567",
            "mentions": [["catt", 6.5], ["animal", 3.3]],
            "title": "Catt",
            "types": {"hyena": ["animal", "animall"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        },
        # No KG/Types
        {
            "entity_id": "Q789",
            "mentions": [["animal", 12.2]],
            "title": "Dogg",
        },
    ]
    # Generate train data
    train_data = [
        {
            "sent_idx_unq": 1,
            "sentence": "I love animals and dogs",
            "qids": ["Q567", "Q123"],
            "aliases": ["animal", "dog"],
            "gold": [True, True],
            "spans": [[2, 3], [4, 5]],
        }
    ]
    self.write_data(self.profile_file, data)
    self.write_data(self.train_data, train_data)
    entity_profile = EntityProfile.load_from_jsonl(self.profile_file)
    # Dump profile data in format for model
    entity_profile.save(self.save_dir)

    # Setup model args to read the data/new profile data
    raw_args = {
        "emmental": {
            "n_epochs": 1,
        },
        "run_config": {
            "dataloader_threads": 1,
            "dataset_threads": 1,
        },
        "train_config": {"batch_size": 2},
        "model_config": {"hidden_size": 20, "num_heads": 1},
        "data_config": {
            "entity_dir": str(self.save_dir),
            "max_seq_len": 7,
            "max_aliases": 2,
            "data_dir": str(self.data_dir),
            "emb_dir": str(self.save_dir),
            "word_embedding": {
                "layers": 1,
                "freeze": True,
                "cache_dir": str(self.save_dir / "retrained_bert_model"),
            },
            "ent_embeddings": [
                {
                    "key": "learned",
                    "freeze": False,
                    "load_class": "LearnedEntityEmb",
                    "args": {"learned_embedding_size": 10},
                },
                {
                    "key": "learned_type",
                    "load_class": "LearnedTypeEmb",
                    "freeze": True,
                    "args": {
                        "type_labels": f"{TYPE_SUBFOLDER}/wiki/qid2typeids.json",
                        "type_vocab": f"{TYPE_SUBFOLDER}/wiki/type_vocab.json",
                        "max_types": 2,
                        "type_dim": 10,
                    },
                },
                {
                    "key": "kg_adj",
                    "load_class": "KGIndices",
                    "batch_on_the_fly": True,
                    "normalize": False,
                    "args": {"kg_adj": f"{KG_SUBFOLDER}/kg_adj.txt"},
                },
            ],
            "train_dataset": {"file": "train.jsonl"},
            "dev_dataset": {"file": "train.jsonl"},
            "test_dataset": {"file": "train.jsonl"},
        },
    }
    with open(self.arg_file, "w") as out_f:
        ujson.dump(raw_args, out_f)

    args = parser_utils.parse_boot_and_emm_args(str(self.arg_file))
    # This _MUST_ get passed the args so it gets a random seed set
    emmental.init(log_dir=str(self.dir / "temp_log"), config=args)
    if not os.path.exists(emmental.Meta.log_path):
        os.makedirs(emmental.Meta.log_path)

    scores = run_model(mode="train", config=args)
    saved_model_path1 = f"{emmental.Meta.log_path}/last_model.pth"
    assert type(scores) is dict

    # ======================
    # PART 3: MODIFY PROFILE AND LOAD PRETRAINED MODEL AND TRAIN FOR MORE
    # ======================
    entity_profile = EntityProfile.load_from_jsonl(
        self.profile_file, edit_mode=True
    )
    entity_profile.add_type("Q123", "cat", "wiki")
    entity_profile.remove_type("Q123", "dog", "wiki")
    entity_profile.add_mention("Q123", "cat", 100.0)
    # Dump profile data in format for model
    entity_profile.save(self.save_dir2)

    # Modify arg paths
    args["data_config"]["entity_dir"] = str(self.save_dir2)
    args["data_config"]["emb_dir"] = str(self.save_dir2)

    # Load pretrained model
    args["model_config"]["model_path"] = f"{emmental.Meta.log_path}/last_model.pth"
    emmental.Meta.config["model_config"]["model_path"] = saved_model_path1

    # Init another run
    emmental.init(log_dir=str(self.dir / "temp_log"), config=args)
    if not os.path.exists(emmental.Meta.log_path):
        os.makedirs(emmental.Meta.log_path)
    scores = run_model(mode="train", config=args)
    saved_model_path2 = f"{emmental.Meta.log_path}/last_model.pth"
    assert type(scores) is dict

    # ======================
    # PART 4: VERIFY CHANGES IN THE MODEL WERE AS EXPECTED
    # ======================
    # Check that the type mappings differ in the right way: we removed "dog"
    # from EID 1 and added "cat", so "dog" is no longer a type.
    eid2typeids_table1, type2row_dict1, num_types_with_unk1 = torch.load(
        self.save_dir / "prep" / "type_table_type_mappings_wiki_qid2typeids_2.pt"
    )
    eid2typeids_table2, type2row_dict2, num_types_with_unk2 = torch.load(
        self.save_dir2 / "prep" / "type_table_type_mappings_wiki_qid2typeids_2.pt"
    )
    # Modify mapping 2 to become mapping 1
    # Row 1 is Q123, Col 0 is type (this was "cat")
    eid2typeids_table2[1][0] = entity_profile._type_systems["wiki"]._type_vocab[
        "dog"
    ]
    self.assertEqual(num_types_with_unk1, num_types_with_unk2)
    self.assertDictEqual({1: 1, 2: 2}, type2row_dict1)
    self.assertDictEqual({1: 1}, type2row_dict2)
    assert torch.equal(eid2typeids_table1, eid2typeids_table2)

    # Check that the alias mappings are different. Each row lists the
    # candidate EIDs for one alias, padded with -1 to 30 slots.
    alias2entity_table1 = torch.from_numpy(
        np.memmap(
            self.save_dir / "prep" / "alias2entity_table_alias2qids_InC1.pt",
            dtype="int64",
            mode="r",
            shape=(5, 30),
        )
    )
    gold_alias2entity_table1 = torch.tensor(
        [
            [4, 1, 3, 2] + [-1] * 26,
            [2] + [-1] * 29,
            [2, 3] + [-1] * 28,
            [1] + [-1] * 29,
            [1] + [-1] * 29,
        ]
    )
    assert torch.equal(alias2entity_table1, gold_alias2entity_table1)

    # The change is the "cat" alias has entity 1 added to the beginning.
    # It used to only point to Q345, which is entity 2.
    alias2entity_table2 = torch.from_numpy(
        np.memmap(
            self.save_dir2 / "prep" / "alias2entity_table_alias2qids_InC1.pt",
            dtype="int64",
            mode="r",
            shape=(5, 30),
        )
    )
    gold_alias2entity_table2 = torch.tensor(
        [
            [4, 1, 3, 2] + [-1] * 26,
            [1, 2] + [-1] * 28,
            [2, 3] + [-1] * 28,
            [1] + [-1] * 29,
            [1] + [-1] * 29,
        ]
    )
    assert torch.equal(alias2entity_table2, gold_alias2entity_table2)

    # The type embeddings were frozen so they should be the same
    model1 = torch.load(saved_model_path1)
    model2 = torch.load(saved_model_path2)
    assert torch.equal(
        model1["model"]["module_pool"]["learned_type"]["type_emb.weight"],
        model2["model"]["module_pool"]["learned_type"]["type_emb.weight"],
    )
def test_prune_to_entities(self):
    data = [
        {
            "entity_id": "Q123",
            "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        },
        {
            "entity_id": "Q345",
            "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
            "title": "Cat",
            "types": {"hyena": ["animal"], "wiki": ["cat"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        },
    ]
    self.write_data(self.profile_file, data)
    entity_profile = EntityProfile.load_from_jsonl(
        self.profile_file, max_candidates=5, edit_mode=True
    )
    entity_profile.save(self.save_dir2)

    with self.assertRaises(ValueError) as context:
        entity_profile.prune_to_entities({"Q123", "Q567"})
    assert type(context.exception) is ValueError
    assert "The entity Q567 does not exist" in str(context.exception)

    entity_profile.prune_to_entities({"Q123"})
    self.assertTrue(entity_profile.qid_exists("Q123"))
    self.assertFalse(entity_profile.qid_exists("Q345"))
    self.assertListEqual(
        entity_profile.get_mentions_with_scores("Q123"),
        [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
    )
    self.assertListEqual(entity_profile.get_types("Q123", "hyena"), ["animal"])
    self.assertListEqual(entity_profile.get_types("Q123", "wiki"), ["dog"])
    self.assertListEqual(
        entity_profile.get_connections_by_relation("Q123", "sibling"),
        [],
    )

    # Check that no_kg still works with load_from_cache
    entity_profile2 = EntityProfile.load_from_cache(
        self.save_dir2, no_kg=True, edit_mode=True
    )
    entity_profile2.prune_to_entities({"Q123"})
    self.assertTrue(entity_profile2.qid_exists("Q123"))
    self.assertFalse(entity_profile2.qid_exists("Q345"))
    self.assertListEqual(
        entity_profile2.get_mentions_with_scores("Q123"),
        [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
    )
    self.assertListEqual(entity_profile2.get_types("Q123", "hyena"), ["animal"])
    self.assertListEqual(entity_profile2.get_types("Q123", "wiki"), ["dog"])
    self.assertListEqual(
        entity_profile2.get_connections_by_relation("Q123", "sibling"),
        [],
    )
def test_profile_load_simple(self):
    data = [
        {
            "entity_id": "Q123",
            "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        },
        {
            "entity_id": "Q345",
            "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
            "title": "Cat",
            "types": {"hyena": ["animal"], "wiki": ["cat"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        },
        # Missing type system
        {
            "entity_id": "Q567",
            "mentions": [["catt", 6.5], ["animal", 3.3]],
            "title": "Catt",
            "types": {"hyena": ["animal", "animall"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        },
        # No KG/Types
        {
            "entity_id": "Q789",
            "mentions": [["animal", 12.2]],
            "title": "Dogg",
        },
    ]
    self.write_data(self.profile_file, data)
    gold_qid2title = {"Q123": "Dog", "Q345": "Cat", "Q567": "Catt", "Q789": "Dogg"}
    gold_alias2qids = {
        "dog": [["Q123", 10.0]],
        "dogg": [["Q123", 7.0]],
        "cat": [["Q345", 10.0]],
        "catt": [["Q345", 7.0], ["Q567", 6.5]],
        "animal": [["Q789", 12.2], ["Q123", 4.0], ["Q567", 3.3], ["Q345", 3.0]],
    }
    gold_type_systems = {
        "hyena": {
            "Q123": ["animal"],
            "Q345": ["animal"],
            "Q567": ["animal", "animall"],
            "Q789": [],
        },
        "wiki": {"Q123": ["dog"], "Q345": ["cat"], "Q567": [], "Q789": []},
    }
    gold_qid2relations = {
        "Q123": {"sibling": ["Q345", "Q567"]},
        "Q345": {"sibling": ["Q123"]},
        "Q567": {"sibling": ["Q123"]},
        "Q789": {},
    }
    (
        qid2title,
        alias2qids,
        type_systems,
        qid2relations,
    ) = EntityProfile._read_profile_file(self.profile_file)
    self.assertDictEqual(gold_qid2title, qid2title)
    self.assertDictEqual(gold_alias2qids, alias2qids)
    self.assertDictEqual(gold_type_systems, type_systems)
    self.assertDictEqual(gold_qid2relations, qid2relations)

    # Test loading/saving from jsonl
    ep = EntityProfile.load_from_jsonl(self.profile_file, edit_mode=True)
    ep.save_to_jsonl(self.profile_file)
    read_in_data = [ujson.loads(li) for li in open(self.profile_file)]
    assert len(read_in_data) == len(data)
    for qid_obj in data:
        found_other_obj = None
        for possible_match in read_in_data:
            if qid_obj["entity_id"] == possible_match["entity_id"]:
                found_other_obj = possible_match
                break
        assert found_other_obj is not None
        self.assertDictEqual(qid_obj, found_other_obj)
def test_add_entity(self):
    data = [
        {
            "entity_id": "Q123",
            "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        },
        {
            "entity_id": "Q345",
            "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
            "title": "Cat",
            "types": {"hyena": ["animal"], "wiki": ["cat"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        },
    ]
    self.write_data(self.profile_file, data)
    entity_profile = EntityProfile.load_from_jsonl(
        self.profile_file, max_candidates=3, edit_mode=True
    )
    entity_profile.save(self.save_dir2)

    # Test bad format
    with self.assertRaises(ValueError) as context:
        entity_profile.add_entity(["bad format"])
    assert type(context.exception) is ValueError
    assert "The input to update_entity needs to be a dictionary" in str(
        context.exception
    )

    new_entity = {
        "entity_id": "Q345",
        "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
        "title": "Cat",
        "types": {"hyena": ["animal"], "wiki": ["cat"]},
        "relations": [{"relation": "sibling", "object": "Q123"}],
    }
    # Test already existing entity
    with self.assertRaises(ValueError) as context:
        entity_profile.add_entity(new_entity)
    assert type(context.exception) is ValueError
    assert "The entity Q345 already exists" in str(context.exception)

    new_entity = {
        "entity_id": "Q789",
        "mentions": [["snake", 10.0], ["animal", 3.0]],
        "title": "Snake",
        "types": {"hyena": ["animal"], "new_sys": ["snakey"]},
        "relations": [{"relation": "sibling", "object": "Q123"}],
    }
    # Test new type system
    with self.assertRaises(ValueError) as context:
        entity_profile.add_entity(new_entity)
    assert type(context.exception) is ValueError
    assert "When adding a new entity, you must use the same type system" in str(
        context.exception
    )

    new_entity = {
        "entity_id": "Q789",
        "mentions": [["snake", 10.0], ["animal", 3.0]],
        "title": "Snake",
        "types": {"hyena": ["animal"]},
        "relations": [{"relatiion": "sibbbling", "object": "Q123"}],
    }
    # Test new bad relation format
    with self.assertRaises(ValueError) as context:
        entity_profile.add_entity(new_entity)
    assert type(context.exception) is ValueError
    assert (
        "For each value in relations, it must be a JSON with keys relation and object"
        in str(context.exception)
    )

    new_entity = {
        "entity_id": "Q789",
        "mentions": [["snake", 10.0], ["animal", 3.0]],
        "title": "Snake",
        "types": {"hyena": ["animal"]},
        "relations": [{"relation": "sibbbling", "object": "Q123"}],
    }
    # Test new relation
    with self.assertRaises(ValueError) as context:
        entity_profile.add_entity(new_entity)
    assert type(context.exception) is ValueError
    assert (
        "When adding a new entity, you must use the same set of relations."
        in str(context.exception)
    )

    new_entity = {
        "entity_id": "Q789",
        "mentions": [["snake", 10.0], ["animal", 3.0]],
        "title": "Snake",
        "types": {"hyena": ["animal"]},
        "relations": [{"relation": "sibling", "object": "Q123"}],
    }
    # Assert it is added
    entity_profile.add_entity(new_entity)
    self.assertTrue(entity_profile.qid_exists("Q789"))
    self.assertEqual(entity_profile.get_title("Q789"), "Snake")
    self.assertListEqual(
        entity_profile.get_mentions_with_scores("Q789"),
        [["snake", 10.0], ["animal", 3.0]],
    )
    self.assertListEqual(entity_profile.get_types("Q789", "hyena"), ["animal"])
    self.assertListEqual(entity_profile.get_types("Q789", "wiki"), [])
    self.assertListEqual(
        entity_profile.get_connections_by_relation("Q789", "sibling"), ["Q123"]
    )

    # Check that no_kg still works with load_from_cache
    entity_profile2 = EntityProfile.load_from_cache(
        self.save_dir2, no_kg=True, edit_mode=True
    )
    entity_profile2.add_entity(new_entity)
    self.assertTrue(entity_profile2.qid_exists("Q789"))
    self.assertEqual(entity_profile2.get_title("Q789"), "Snake")
    self.assertListEqual(
        entity_profile2.get_mentions_with_scores("Q789"),
        [["snake", 10.0], ["animal", 3.0]],
    )
    self.assertListEqual(entity_profile2.get_types("Q789", "hyena"), ["animal"])
    self.assertListEqual(entity_profile2.get_types("Q789", "wiki"), [])
    self.assertListEqual(
        entity_profile2.get_connections_by_relation("Q789", "sibling"), []
    )
def test_profile_dump_load(self):
    data = [
        {
            "entity_id": "Q123",
            "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        },
        {
            "entity_id": "Q345",
            "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
            "title": "Cat",
            "types": {"hyena": ["animal"], "wiki": ["cat"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        },
    ]
    self.write_data(self.profile_file, data)
    entity_profile = EntityProfile.load_from_jsonl(
        self.profile_file, max_candidates=5, edit_mode=True
    )
    entity_profile.save(self.save_dir2)

    # Test load correctly
    entity_profile2 = EntityProfile.load_from_cache(self.save_dir2)
    self.assertSetEqual(
        set(entity_profile.get_all_qids()), set(entity_profile2.get_all_qids())
    )
    self.assertSetEqual(
        set(entity_profile.get_all_typesystems()),
        set(entity_profile2.get_all_typesystems()),
    )
    for type_sys in entity_profile.get_all_typesystems():
        self.assertSetEqual(
            set(entity_profile.get_all_types(type_sys)),
            set(entity_profile2.get_all_types(type_sys)),
        )
    for qid in entity_profile.get_all_qids():
        self.assertSetEqual(
            set(entity_profile.get_all_connections(qid)),
            set(entity_profile2.get_all_connections(qid)),
        )

    # Test load with no types or kgs
    entity_profile2 = EntityProfile.load_from_cache(
        self.save_dir2, no_type=True, no_kg=True
    )
    self.assertSetEqual(
        set(entity_profile.get_all_qids()), set(entity_profile2.get_all_qids())
    )
    assert len(entity_profile2.get_all_typesystems()) == 0
    self.assertIsNone(entity_profile2._kg_symbols)
    # Testing that the functions still work despite not loading them
    assert len(entity_profile2.get_all_connections("Q123")) == 0

    # Test load with no kg and only the wiki type system
    entity_profile2 = EntityProfile.load_from_cache(
        self.save_dir2, no_kg=True, type_systems_to_load=["wiki"]
    )
    self.assertSetEqual(
        set(entity_profile.get_all_qids()), set(entity_profile2.get_all_qids())
    )
    assert entity_profile2.get_all_typesystems() == ["wiki"]
    self.assertSetEqual(
        set(entity_profile.get_all_types("wiki")),
        set(entity_profile2.get_all_types("wiki")),
    )
    self.assertIsNone(entity_profile2._kg_symbols)

    # Assert error loading type system that is not there
    with self.assertRaises(ValueError) as context:
        entity_profile2.get_all_types("hyena")
    assert type(context.exception) is ValueError
    assert "type system hyena is not one" in str(context.exception)
def fit_profiles(args):
    print(json.dumps(vars(args), indent=4))

    if args.model_config is not None:
        assert (
            args.save_model_config is not None
        ), "If you pass in a model config, you must pass in a model save config path"

    print(f"Loading train entity profile from {args.train_entity_profile}")
    train_entity_profile = EntityProfile.load_from_cache(
        load_dir=args.train_entity_profile, no_type=True, no_kg=True
    )
    print(f"Loading new entity profile from {args.new_entity_profile}")
    new_entity_profile = EntityProfile.load_from_cache(
        load_dir=args.new_entity_profile, no_type=True, no_kg=True
    )

    oldqid2newqid = dict()
    newqid2oldqid = dict()
    if args.oldqid2newqid is not None and len(args.oldqid2newqid) > 0:
        with open(args.oldqid2newqid) as in_f:
            oldqid2newqid = ujson.load(in_f)
            newqid2oldqid = {v: k for k, v in oldqid2newqid.items()}
            assert len(oldqid2newqid) == len(
                newqid2oldqid
            ), "The dicts of oldqid2newqid and its inverse do not have the same length"

    np_removed_ents, np_same_ents, np_new_ents, oldeid2neweid = match_entities(
        train_entity_profile, new_entity_profile, oldqid2newqid, newqid2oldqid
    )
    neweid2oldeid = {v: k for k, v in oldeid2neweid.items()}
    assert len(oldeid2neweid) == len(
        neweid2oldeid
    ), "The lengths of oldeid2neweid and neweid2oldeid don't match"

    state_dict, model_state_dict = load_statedict(args.model_path)

    # We do not support modifying a topK model, only the original model.
    # Probe for the topK keys first; raising inside the same try block would
    # let the except clause swallow the NotImplementedError.
    has_topk_emb = True
    try:
        get_nested_item(model_state_dict, ENTITY_TOPK_KEYS)
    except Exception:
        has_topk_emb = False
    if has_topk_emb:
        raise NotImplementedError(
            "We don't support fitting a topK mini model. Instead, call `fit_to_profile` on the full Bootleg model. "
            "Then call utils.entity_profile.compress_topk_entity_embeddings to create your own mini model."
        )

    try:
        weight_shape = get_nested_item(model_state_dict, ENTITY_EMB_KEYS).shape[1]
    except Exception:
        raise ValueError(f"ERROR: All of {ENTITY_EMB_KEYS} are not in model_state_dict")

    # Refit weights
    if args.init_vec_file is not None and len(args.init_vec_file) > 0:
        vector_for_new_ent = np.load(args.init_vec_file)
    else:
        print("Setting init vector to be all zeros")
        vector_for_new_ent = np.zeros(weight_shape)
    new_model_state_dict = refit_weights(
        np_same_ents,
        neweid2oldeid,
        train_entity_profile,
        new_entity_profile,
        vector_for_new_ent,
        model_state_dict,
    )
    state_dict["model"] = new_model_state_dict
    print(f"Saving model at {args.save_model_path}")
    torch.save(state_dict, args.save_model_path)
    del new_model_state_dict
    del state_dict

    # Refit titles
    # We keep track of all embeddings to adjust. If given a config, we only
    # adjust the one from the config. Otherwise, we adjust everything that is
    # a "static_table_" array of length BERT_DIM.
    if not args.no_title_emb:
        title_embeddings = []
        prepped_title_emb_files = []
        title_emb_file = None
        prep_subdir = "prep"
        # First try to read entity_prep_dir from config
        if args.model_config is not None:
            with open(args.model_config) as file:
                config = yaml.load(file, Loader=yaml.FullLoader)
            prep_subdir = config["data_config"].get("entity_prep_dir", "prep")
            for ent in config["data_config"]["ent_embeddings"]:
                if ent["load_class"] == "StaticEmb" and ent["key"] == "title_static":
                    assert (
                        "emb_file" in ent["args"]
                    ), "emb_file needs to be in title_static config"
                    title_emb_file = ent["args"]["emb_file"]

        prep_dir = Path(args.train_entity_profile) / prep_subdir
        out_prep_dir = Path(args.new_entity_profile) / prep_subdir
        print(f"Looking for title embedding in {prep_dir}")
        # Try to find a saved title prep file
        for file in prep_dir.iterdir():
            if file.is_file() and file.name.startswith("static_table_"):
                # If we know the title embedding file from the config,
                # use it to find the right prepped file
                if (
                    title_emb_file is not None
                    and file.name
                    != f"static_table_{os.path.splitext(os.path.basename(title_emb_file))[0]}.npy"
                ):
                    continue
                possible_titles = np.load(file, mmap_mode="r")
                if list(possible_titles.shape) == [
                    train_entity_profile.num_entities_with_pad_and_nocand,
                    BERT_DIM,
                ]:
                    title_embeddings.append(possible_titles)
                    prepped_title_emb_files.append(file.name)
        if len(title_embeddings) == 0:
            print(
                "We were unable to adjust titles. If your model does not use title embeddings, ignore this. "
                "If your model does (all Bootleg models by default do), please call "
                "```python -m bootleg.utils.preprocessing.build_static_embeddings --help``` to extract manually. "
                "The saved file from this method should be added to the ```emb_file``` config param for the "
                "title embedding."
            )
        else:
            for title_embed, prepped_title_emb_file in zip(
                title_embeddings, prepped_title_emb_files
            ):
                print(f"Attempting to refit title {prepped_title_emb_file}")
                out_prep_dir.mkdir(parents=True, exist_ok=True)
                save_file = out_prep_dir / prepped_title_emb_file
                print(f"Saving to {save_file}")
                # Returns memmap array
                _ = refit_titles(
                    np_same_ents,
                    np_new_ents,
                    neweid2oldeid,
                    train_entity_profile,
                    new_entity_profile,
                    title_embed,
                    str(save_file),
                    args.bert_model,
                    args.bert_model_cache,
                    args.cpu,
                )

    if args.model_config is not None:
        modify_config(
            args.model_config,
            args.save_model_config,
            args.save_model_path,
            args.new_entity_profile,
        )
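# For reference, a sketch of driving `fit_profiles` directly. The real entry
# point builds `args` with argparse; every path and flag value below is
# hypothetical and only illustrates which attributes the function reads:
from argparse import Namespace

example_args = Namespace(
    train_entity_profile="data/train_entity_db",    # hypothetical path
    new_entity_profile="data/new_entity_db",        # hypothetical path
    oldqid2newqid=None,        # optional JSON file mapping renamed QIDs
    model_path="models/bootleg_model.pth",          # hypothetical path
    save_model_path="models/bootleg_model_fit.pth",
    init_vec_file=None,        # None falls back to an all-zeros init vector
    no_title_emb=False,        # also refit the static title embeddings
    bert_model="bert-base-uncased",
    bert_model_cache="models/bert_cache",
    cpu=True,
    model_config=None,         # if set, save_model_config must be set too
    save_model_config=None,
)
fit_profiles(example_args)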
def test_fit_titles(self):
    train_entity_profile = EntityProfile.load_from_cache(
        load_dir=self.train_db,
        edit_mode=True,
    )
    # Modify train profile
    new_entity1 = {
        "entity_id": "Q910",
        "mentions": [["cobra", 10.0], ["animal", 3.0]],
        "title": "Cobra",
        "types": {"hyena": ["animal"], "wiki": ["dog"]},
        "relations": [{"relation": "sibling", "object": "Q123"}],
    }
    new_entity2 = {
        "entity_id": "Q101",
        "mentions": [["snake", 10.0], ["snakes", 7.0], ["animal", 3.0]],
        "title": "Snake",
        "types": {"hyena": ["animal"], "wiki": ["dog"]},
    }
    train_entity_profile.add_entity(new_entity1)
    train_entity_profile.add_entity(new_entity2)
    train_entity_profile.reidentify_entity("Q123", "Q321")
    # Save new profile
    train_entity_profile.save(self.new_db)

    # Create old title embs: row i is all i's
    title_embs = np.zeros((6, 768))
    for i in range(len(title_embs)):
        title_embs[i] = np.ones(768) * i

    train_entity_profile = EntityProfile.load_from_cache(load_dir=self.train_db)
    new_entity_profile = EntityProfile.load_from_cache(load_dir=self.new_db)
    neweid2oldeid = {1: 1, 2: 2, 3: 3, 4: 4}
    np_same_ents = {"Q321", "Q345", "Q567", "Q789"}
    np_new_ents = {"Q910", "Q101"}
    new_title_embs = refit_titles(
        np_same_ents,
        np_new_ents,
        neweid2oldeid,
        train_entity_profile,
        new_entity_profile,
        title_embs,
        str(self.dir / "temp_title.npy"),
        "bert-base-uncased",
        str(self.dir / "temp_bert_models"),
        True,
    )
    # Compute gold BERT titles for new entities
    tokenizer = BertTokenizer.from_pretrained(
        "bert-base-uncased",
        do_lower_case=True,
        cache_dir=str(self.dir / "temp_bert_models"),
    )
    model = BertModel.from_pretrained(
        "bert-base-uncased",
        cache_dir=str(self.dir / "temp_bert_models"),
        output_attentions=False,
        output_hidden_states=False,
    )
    model.eval()
    input_ids = tokenizer(
        ["Cobra", "Snake"], padding=True, truncation=True, return_tensors="pt"
    )
    inputs = input_ids["input_ids"]
    attention_mask = input_ids["attention_mask"]
    outputs = model(inputs, attention_mask=attention_mask)[0]
    outputs[inputs == 0] = 0
    avgtitle = average_titles(inputs, outputs).to("cpu").detach().numpy()

    gold_title_embs = np.zeros((8, 768))
    gold_title_embs[0] = np.ones(768) * 0
    gold_title_embs[1] = np.ones(768) * 1
    gold_title_embs[2] = np.ones(768) * 2
    gold_title_embs[3] = np.ones(768) * 3
    gold_title_embs[4] = np.ones(768) * 4
    gold_title_embs[5] = avgtitle[0]
    gold_title_embs[6] = avgtitle[1]
    gold_title_embs[7] = np.ones(768) * 0  # Last row is set to zero in fit method
    np.testing.assert_array_almost_equal(new_title_embs, gold_title_embs)
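# NOTE: `average_titles` is imported from the fitting utilities and is not
# defined in this section. Based on how the test zeroes padded positions
# first, it is assumed to mean-pool token embeddings over non-[PAD] tokens,
# roughly like this sketch:
def average_titles(input_ids, embeddings):
    """Mean-pool token embeddings, ignoring [PAD] (token id 0) positions."""
    # input_ids: [batch, seq_len]; embeddings: [batch, seq_len, hidden]
    num_tokens = (input_ids != 0).sum(dim=1, keepdim=True)
    return embeddings.sum(dim=1) / num_tokens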