Exemplo n.º 1
0
def fit_profiles(args):
    """Refit a trained model's entity embeddings to a new entity profile.

    Loads the train and new entity profiles, matches their entities
    (optionally using a JSON file that maps old QIDs to new QIDs), rebuilds the
    model weights with ``refit_weights`` and saves the updated state dict to
    ``args.save_model_path``. When ``args.model_config`` is given, the config
    is rewritten through ``modify_config``.

    Args:
        args: parsed arguments. This function reads ``model_config``,
            ``save_model_config``, ``train_entity_profile``,
            ``new_entity_profile``, ``oldqid2newqid``, ``model_path``,
            ``init_vec_file`` and ``save_model_path``.

    Raises:
        NotImplementedError: if the topK entity keys can be read from the
            loaded model (fitting a topK mini model is not supported).
        ValueError: if the entity embedding weight cannot be read from the
            loaded model.
    """
    print(json.dumps(vars(args), indent=4))

    if args.model_config is not None:
        assert (
            args.save_model_config is not None
        ), "If you pass in a model config, you must pass in a model save config path"

    print(f"Loading train entity profile from {args.train_entity_profile}")
    train_entity_profile = EntityProfile.load_from_cache(
        load_dir=args.train_entity_profile)
    print(f"Loading new entity profile from {args.new_entity_profile}")
    new_entity_profile = EntityProfile.load_from_cache(
        load_dir=args.new_entity_profile)

    oldqid2newqid = dict()
    newqid2oldqid = dict()
    if args.oldqid2newqid is not None and len(args.oldqid2newqid) > 0:
        with open(args.oldqid2newqid) as in_f:
            oldqid2newqid = ujson.load(in_f)
        newqid2oldqid = {v: k for k, v in oldqid2newqid.items()}
        # Inverting drops entries when two old QIDs map to the same new QID.
        assert len(oldqid2newqid) == len(
            newqid2oldqid
        ), "The dicts of oldqid2newqid and its inverse do not have the same length"

    np_removed_ents, np_same_ents, np_new_ents, oldeid2neweid = match_entities(
        train_entity_profile, new_entity_profile, oldqid2newqid, newqid2oldqid)
    neweid2oldeid = {v: k for k, v in oldeid2neweid.items()}
    assert len(oldeid2neweid) == len(
        neweid2oldeid
    ), "The lengths of oldeid2neweid and neweid2oldeid don't match"
    state_dict, model_state_dict = load_statedict(args.model_path)

    # We do not support modifying a topK model. Only the original model.
    # The raise lives in the `else` branch so that it is not swallowed by the
    # handler that treats a failed lookup as "this is not a topK model".
    try:
        get_nested_item(model_state_dict, ENTITY_TOPK_KEYS)
    except Exception:
        # Lookup failed: the topK keys are not present, so carry on.
        pass
    else:
        raise NotImplementedError(
            "We don't support fitting a topK mini model. Instead, call `fit_to_profile` on the full Bootleg model. "
            "Then call utils.entity_profile.compress_topk_entity_embeddings to create your own mini model."
        )

    try:
        weight_shape = get_nested_item(model_state_dict,
                                       ENTITY_EMB_KEYS).shape[1]
    except Exception as err:
        raise ValueError(
            f"ERROR: All of {ENTITY_EMB_KEYS} are not in model_state_dict"
        ) from err

    if args.init_vec_file is not None and len(args.init_vec_file) > 0:
        vector_for_new_ent = np.load(args.init_vec_file)
    else:
        print("Setting init vector to be all zeros")
        vector_for_new_ent = np.zeros(weight_shape)
    new_model_state_dict = refit_weights(
        np_same_ents,
        neweid2oldeid,
        train_entity_profile,
        new_entity_profile,
        vector_for_new_ent,
        model_state_dict,
    )
    print(new_model_state_dict["module_pool"]["learned"].keys())
    state_dict["model"] = new_model_state_dict
    print(f"Saving model at {args.save_model_path}")
    torch.save(state_dict, args.save_model_path)

    if args.model_config is not None:
        modify_config(
            args.model_config,
            args.save_model_config,
            args.save_model_path,
            args.new_entity_profile,
        )
Exemplo n.º 2
0
    def test_fit_entities(self):
        """Exercise refit_weights on added entities, an eid2reg tensor, and a pruned profile."""
        embedding_path = [
            "module_pool", "learned", "learned_entity_embedding.weight"
        ]
        reg_path = ["module_pool", "learned", "eid2reg"]

        def build_state(embedding_rows, reg_values=None):
            # Wrap the tensors in the nested layout that refit_weights is given.
            learned = {
                "learned_entity_embedding.weight": torch.tensor(embedding_rows)
            }
            if reg_values is not None:
                learned["eid2reg"] = torch.tensor(reg_values)
            return {"module_pool": {"learned": learned}}

        def train_state(with_reg):
            # 4 entities + 1 UNK + 1 PAD for train data
            rows = [
                [1.0, 2, 3, 4, 5],
                [2, 2, 2, 2, 2],
                [3, 3, 3, 3, 3],
                [4, 4, 4, 4, 4],
                [5, 5, 5, 5, 5],
                [0, 0, 0, 0, 0],
            ]
            reg = [0.0, 0.2, 0.3, 0.4, 0.5, 0.0] if with_reg else None
            return build_state(rows, reg)

        def refit(same_qids, new_to_old, old_profile, updated_profile, weights):
            # Every scenario initialises new entities with the vector [0, 1, 2, 3, 4].
            return refit_weights(
                same_qids,
                new_to_old,
                old_profile,
                updated_profile,
                np.arange(5),
                weights,
            )

        def assert_leaf_equal(actual, expected, key_path):
            # Walk both dicts along key_path and compare the tensor at the leaf.
            for key in key_path:
                assert key in actual
                assert key in expected
                if type(expected[key]) is dict:
                    actual, expected = actual[key], expected[key]
                else:
                    assert torch.equal(actual[key], expected[key])

        # Build the new profile: two extra entities plus a renamed QID.
        editable_profile = EntityProfile.load_from_cache(
            load_dir=self.train_db,
            edit_mode=True,
        )
        cobra = {
            "entity_id": "Q910",
            "mentions": [["cobra", 10.0], ["animal", 3.0]],
            "title": "Cobra",
            "types": {
                "hyena": ["animal"],
                "wiki": ["dog"]
            },
            "relations": [{
                "relation": "sibling",
                "object": "Q123"
            }],
        }
        snake = {
            "entity_id": "Q101",
            "mentions": [["snake", 10.0], ["snakes", 7.0], ["animal", 3.0]],
            "title": "Snake",
            "types": {
                "hyena": ["animal"],
                "wiki": ["dog"]
            },
        }
        editable_profile.add_entity(cobra)
        editable_profile.add_entity(snake)
        editable_profile.reidentify_entity("Q123", "Q321")
        editable_profile.save(self.new_db)

        train_entity_profile = EntityProfile.load_from_cache(
            load_dir=self.train_db)
        new_entity_profile = EntityProfile.load_from_cache(
            load_dir=self.new_db)
        neweid2oldeid = {1: 1, 2: 2, 3: 3, 4: 4}
        np_same_ents = {"Q321", "Q345", "Q567", "Q789"}
        grown_rows = [
            [1.0, 2, 3, 4, 5],
            [2, 2, 2, 2, 2],
            [3, 3, 3, 3, 3],
            [4, 4, 4, 4, 4],
            [5, 5, 5, 5, 5],
            [0, 1, 2, 3, 4],
            [0, 1, 2, 3, 4],
            [0, 0, 0, 0, 0],
        ]

        # Scenario 1: embeddings only.
        refitted = refit(np_same_ents, neweid2oldeid, train_entity_profile,
                         new_entity_profile, train_state(with_reg=False))
        assert_leaf_equal(refitted, build_state(grown_rows), embedding_path)

        # Scenario 2: embeddings together with the eid2reg tensor.
        refitted = refit(np_same_ents, neweid2oldeid, train_entity_profile,
                         new_entity_profile, train_state(with_reg=True))
        expected = build_state(grown_rows,
                               [0.0, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.0])
        assert_leaf_equal(refitted, expected, embedding_path)
        assert_leaf_equal(refitted, expected, reg_path)

        # Scenario 3: prune the new profile, which already has Q910 and Q101.
        prunable_profile = EntityProfile.load_from_cache(
            load_dir=self.new_db, edit_mode=True)
        # These now get eids 1, 2, 3
        prunable_profile.prune_to_entities({"Q321", "Q910", "Q101"})
        # Manually set the eids for the test
        prunable_profile._entity_symbols._qid2eid = {
            "Q321": 3,
            "Q910": 2,
            "Q101": 1
        }
        prunable_profile._entity_symbols._eid2qid = {
            3: "Q321",
            2: "Q910",
            1: "Q101"
        }
        prunable_profile.save(self.new_db)

        train_entity_profile = EntityProfile.load_from_cache(
            load_dir=self.train_db)
        new_entity_profile = EntityProfile.load_from_cache(
            load_dir=self.new_db)
        neweid2oldeid = {3: 1}
        np_same_ents = {"Q321"}
        refitted = refit(np_same_ents, neweid2oldeid, train_entity_profile,
                         new_entity_profile, train_state(with_reg=True))
        expected = build_state(
            [
                [1.0, 2, 3, 4, 5],
                [0, 1, 2, 3, 4],
                [0, 1, 2, 3, 4],
                [2, 2, 2, 2, 2],
                [0, 0, 0, 0, 0],
            ],
            [0.0, 0.5, 0.5, 0.2, 0.0],
        )
        assert_leaf_equal(refitted, expected, embedding_path)
        assert_leaf_equal(refitted, expected, reg_path)
Exemplo n.º 3
0
    def test_match_entities(self):
        """Check match_entities after adding entities and again after pruning the new profile."""

        def match_from_disk():
            # Reload both saved profiles and match them using the Q123 -> Q321 rename.
            old_profile = EntityProfile.load_from_cache(load_dir=self.train_db)
            updated_profile = EntityProfile.load_from_cache(
                load_dir=self.new_db)
            renames = {"Q123": "Q321"}
            inverse_renames = {new: old for old, new in renames.items()}
            return match_entities(old_profile, updated_profile, renames,
                                  inverse_renames)

        def assert_match(result, removed, same, added, eid_map):
            got_removed, got_same, got_added, got_eid_map = result
            self.assertSetEqual(removed, got_removed)
            self.assertSetEqual(same, got_same)
            self.assertSetEqual(added, got_added)
            self.assertDictEqual(eid_map, got_eid_map)

        # Case 1: the new profile is the train profile plus two entities and a rename.
        editable_profile = EntityProfile.load_from_cache(
            load_dir=self.train_db,
            edit_mode=True,
        )
        cobra = {
            "entity_id": "Q910",
            "mentions": [["cobra", 10.0], ["animal", 3.0]],
            "title": "Cobra",
            "types": {
                "hyena": ["animal"],
                "wiki": ["dog"]
            },
            "relations": [{
                "relation": "sibling",
                "object": "Q123"
            }],
        }
        snake = {
            "entity_id": "Q101",
            "mentions": [["snake", 10.0], ["snakes", 7.0], ["animal", 3.0]],
            "title": "Snake",
            "types": {
                "hyena": ["animal"],
                "wiki": ["dog"]
            },
        }
        editable_profile.add_entity(cobra)
        editable_profile.add_entity(snake)
        editable_profile.reidentify_entity("Q123", "Q321")
        editable_profile.save(self.new_db)

        assert_match(
            match_from_disk(),
            removed=set(),
            same={"Q321", "Q345", "Q567", "Q789"},
            added={"Q910", "Q101"},
            eid_map={1: 1, 2: 2, 3: 3, 4: 4},
        )

        # Case 2: prune the new profile, which already has Q910 and Q101.
        prunable_profile = EntityProfile.load_from_cache(
            load_dir=self.new_db, edit_mode=True)
        # These now get eids 1, 2, 3
        prunable_profile.prune_to_entities({"Q321", "Q910", "Q101"})
        # Manually set the eids for the test
        prunable_profile._entity_symbols._qid2eid = {
            "Q321": 3,
            "Q910": 2,
            "Q101": 1
        }
        prunable_profile._entity_symbols._eid2qid = {
            3: "Q321",
            2: "Q910",
            1: "Q101"
        }
        prunable_profile.save(self.new_db)

        assert_match(
            match_from_disk(),
            removed={"Q345", "Q567", "Q789"},
            same={"Q321"},
            added={"Q910", "Q101"},
            eid_map={1: 3},
        )
Exemplo n.º 4
0
    def test_prune_to_entities(self):
        """Check prune_to_entities rejects unknown QIDs and keeps only the requested entity."""
        dog = {
            "entity_id": "Q123",
            "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        }
        cat = {
            "entity_id": "Q345",
            "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
            "title": "Cat",
            "types": {"hyena": ["animal"], "wiki": ["cat"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        }
        self.write_data(self.profile_file, [dog, cat])
        profile = EntityProfile.load_from_jsonl(
            self.profile_file, max_candidates=5, edit_mode=True
        )
        profile.save(self.save_dir2)

        def assert_only_dog_left(pruned):
            # After pruning to Q123: Q345 is gone and Q123 has no sibling connections.
            self.assertTrue(pruned.qid_exists("Q123"))
            self.assertFalse(pruned.qid_exists("Q345"))
            self.assertListEqual(
                pruned.get_mentions_with_scores("Q123"),
                [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            )
            self.assertListEqual(pruned.get_types("Q123", "hyena"), ["animal"])
            self.assertListEqual(pruned.get_types("Q123", "wiki"), ["dog"])
            self.assertListEqual(
                pruned.get_connections_by_relation("Q123", "sibling"),
                [],
            )

        # Asking to keep a QID that is not in the profile must fail.
        with self.assertRaises(ValueError) as caught:
            profile.prune_to_entities({"Q123", "Q567"})
        assert type(caught.exception) is ValueError
        assert "The entity Q567 does not exist" in str(caught.exception)

        profile.prune_to_entities({"Q123"})
        assert_only_dog_left(profile)

        # Check that no_kg still works with load_from_cache
        profile_without_kg = EntityProfile.load_from_cache(
            self.save_dir2, no_kg=True, edit_mode=True
        )
        profile_without_kg.prune_to_entities({"Q123"})
        assert_only_dog_left(profile_without_kg)
Exemplo n.º 5
0
    def test_add_entity(self):
        """Check add_entity input validation and that a valid entity is stored, with and without KG."""
        dog = {
            "entity_id": "Q123",
            "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        }
        cat = {
            "entity_id": "Q345",
            "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
            "title": "Cat",
            "types": {"hyena": ["animal"], "wiki": ["cat"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        }
        self.write_data(self.profile_file, [dog, cat])
        entity_profile = EntityProfile.load_from_jsonl(
            self.profile_file, max_candidates=3, edit_mode=True
        )
        entity_profile.save(self.save_dir2)

        def assert_rejected(candidate, message_part):
            # add_entity must raise exactly ValueError with the given text in its message.
            with self.assertRaises(ValueError) as caught:
                entity_profile.add_entity(candidate)
            assert type(caught.exception) is ValueError
            assert message_part in str(caught.exception)

        def snake_entity(types, relation):
            # Q789 candidate; only the types and the single relation vary between cases.
            return {
                "entity_id": "Q789",
                "mentions": [["snake", 10.0], ["animal", 3.0]],
                "title": "Snake",
                "types": types,
                "relations": [relation],
            }

        def assert_snake_stored(profile, expected_siblings):
            self.assertTrue(profile.qid_exists("Q789"))
            self.assertEqual(profile.get_title("Q789"), "Snake")
            self.assertListEqual(
                profile.get_mentions_with_scores("Q789"),
                [["snake", 10.0], ["animal", 3.0]],
            )
            self.assertListEqual(profile.get_types("Q789", "hyena"), ["animal"])
            self.assertListEqual(profile.get_types("Q789", "wiki"), [])
            self.assertListEqual(
                profile.get_connections_by_relation("Q789", "sibling"),
                expected_siblings,
            )

        # Input that is not a dictionary
        assert_rejected(
            ["bad format"],
            "The input to update_entity needs to be a dictionary",
        )

        # Entity that is already in the profile
        duplicate_cat = {
            "entity_id": "Q345",
            "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
            "title": "Cat",
            "types": {"hyena": ["animal"], "wiki": ["cat"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        }
        assert_rejected(duplicate_cat, "The entity Q345 already exists")

        # Type system that the profile does not have
        assert_rejected(
            snake_entity(
                {"hyena": ["animal"], "new_sys": ["snakey"]},
                {"relation": "sibling", "object": "Q123"},
            ),
            "When adding a new entity, you must use the same type system",
        )

        # Relation entry with a misspelled key
        assert_rejected(
            snake_entity(
                {"hyena": ["animal"]},
                {"relatiion": "sibbbling", "object": "Q123"},
            ),
            "For each value in relations, it must be a JSON with keys relation and object",
        )

        # Relation name that the profile does not have
        assert_rejected(
            snake_entity(
                {"hyena": ["animal"]},
                {"relation": "sibbbling", "object": "Q123"},
            ),
            "When adding a new entity, you must use the same set of relations.",
        )

        # A well-formed entity is added
        new_entity = snake_entity(
            {"hyena": ["animal"]},
            {"relation": "sibling", "object": "Q123"},
        )
        entity_profile.add_entity(new_entity)
        assert_snake_stored(entity_profile, ["Q123"])

        # Check that no_kg still works with load_from_cache
        entity_profile2 = EntityProfile.load_from_cache(
            self.save_dir2, no_kg=True, edit_mode=True
        )
        entity_profile2.add_entity(new_entity)
        assert_snake_stored(entity_profile2, [])
Exemplo n.º 6
0
    def test_profile_dump_load(self):
        """Check a saved profile reloads with the same content, with and without types/KG."""
        dog = {
            "entity_id": "Q123",
            "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        }
        cat = {
            "entity_id": "Q345",
            "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
            "title": "Cat",
            "types": {"hyena": ["animal"], "wiki": ["cat"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        }
        self.write_data(self.profile_file, [dog, cat])
        saved = EntityProfile.load_from_jsonl(
            self.profile_file, max_candidates=5, edit_mode=True
        )
        saved.save(self.save_dir2)

        # Full reload: QIDs, type systems, types and connections all match.
        reloaded = EntityProfile.load_from_cache(self.save_dir2)
        self.assertSetEqual(
            set(saved.get_all_qids()), set(reloaded.get_all_qids())
        )
        self.assertSetEqual(
            set(saved.get_all_typesystems()),
            set(reloaded.get_all_typesystems()),
        )
        for type_system in saved.get_all_typesystems():
            self.assertSetEqual(
                set(saved.get_all_types(type_system)),
                set(reloaded.get_all_types(type_system)),
            )
        for qid in saved.get_all_qids():
            self.assertSetEqual(
                set(saved.get_all_connections(qid)),
                set(reloaded.get_all_connections(qid)),
            )

        # Reload with neither types nor KG.
        reloaded = EntityProfile.load_from_cache(
            self.save_dir2, no_type=True, no_kg=True
        )
        self.assertSetEqual(
            set(saved.get_all_qids()), set(reloaded.get_all_qids())
        )
        assert len(reloaded.get_all_typesystems()) == 0
        self.assertIsNone(reloaded._kg_symbols)
        # The connection accessor still works even though no KG was loaded.
        assert len(reloaded.get_all_connections("Q123")) == 0

        # Reload with no KG and only the "wiki" type system.
        reloaded = EntityProfile.load_from_cache(
            self.save_dir2, no_kg=True, type_systems_to_load=["wiki"]
        )
        self.assertSetEqual(
            set(saved.get_all_qids()), set(reloaded.get_all_qids())
        )
        assert reloaded.get_all_typesystems() == ["wiki"]
        self.assertSetEqual(
            set(saved.get_all_types("wiki")),
            set(reloaded.get_all_types("wiki")),
        )
        self.assertIsNone(reloaded._kg_symbols)

        # Querying a type system that was not loaded raises ValueError.
        with self.assertRaises(ValueError) as caught:
            reloaded.get_all_types("hyena")
        assert type(caught.exception) is ValueError
        assert "type system hyena is not one" in str(caught.exception)
Exemplo n.º 7
0
def fit_profiles(args):
    """Refit a saved model and its prepped title embeddings to a new entity profile.

    The steps, in order:

    1. Load the train and new entity profiles (types and KG data are not loaded).
    2. Match entities between the two profiles with ``match_entities``, optionally
       using an old-QID -> new-QID mapping read from ``args.oldqid2newqid``.
    3. Refit the entity embedding weights with ``refit_weights`` and save the
       resulting checkpoint to ``args.save_model_path``.
    4. Unless ``args.no_title_emb`` is set, refit every matching ``static_table_*``
       array found in the train profile's prep directory with ``refit_titles`` and
       write it to the same file name under the new profile's prep directory.
    5. If ``args.model_config`` is given, call ``modify_config`` to write the
       updated config to ``args.save_model_config``.

    Args:
        args: namespace of parsed arguments. Attributes read here:
            ``model_config``, ``save_model_config``, ``train_entity_profile``,
            ``new_entity_profile``, ``oldqid2newqid``, ``model_path``,
            ``init_vec_file``, ``save_model_path``, ``no_title_emb``,
            ``bert_model``, ``bert_model_cache`` and ``cpu``.

    Raises:
        AssertionError: if a model config is given without a save path, if the
            QID/EID mappings are not one-to-one, or if the title embedding entry
            of the config has no ``emb_file``.
        NotImplementedError: if the checkpoint contains the topK entity embedding.
        ValueError: if the entity embedding cannot be found in the checkpoint.
    """
    print(json.dumps(vars(args), indent=4))

    if args.model_config is not None:
        assert (
            args.save_model_config is not None
        ), "If you pass in a model config, you must pass in a model save config path"

    print(f"Loading train entity profile from {args.train_entity_profile}")
    train_entity_profile = EntityProfile.load_from_cache(
        load_dir=args.train_entity_profile, no_type=True, no_kg=True
    )
    print(f"Loading new entity profile from {args.new_entity_profile}")
    new_entity_profile = EntityProfile.load_from_cache(
        load_dir=args.new_entity_profile, no_type=True, no_kg=True
    )

    # Optional mapping for entities whose QID changed between the two profiles.
    oldqid2newqid = {}
    newqid2oldqid = {}
    if args.oldqid2newqid:
        with open(args.oldqid2newqid) as in_f:
            oldqid2newqid = ujson.load(in_f)
        newqid2oldqid = {v: k for k, v in oldqid2newqid.items()}
        # Equal lengths means no two old QIDs map to the same new QID.
        assert len(oldqid2newqid) == len(
            newqid2oldqid
        ), "The dicts of oldqid2newqid and its inverse do not have the same length"

    np_removed_ents, np_same_ents, np_new_ents, oldeid2neweid = match_entities(
        train_entity_profile, new_entity_profile, oldqid2newqid, newqid2oldqid
    )
    neweid2oldeid = {v: k for k, v in oldeid2neweid.items()}
    assert len(oldeid2neweid) == len(
        neweid2oldeid
    ), "The lengths of oldeid2neweid and neweid2oldeid don't match"
    state_dict, model_state_dict = load_statedict(args.model_path)

    # We do not support modifying a topK model. Only the original model.
    # The raise must live in the ``else`` branch: raising inside the ``try`` would be
    # swallowed by the handler and the check would never fire.
    try:
        get_nested_item(model_state_dict, ENTITY_TOPK_KEYS)
    except Exception:
        # Lookup failed, so the checkpoint has no topK entity embedding.
        pass
    else:
        raise NotImplementedError(
            "We don't support fitting a topK mini model. Instead, call `fit_to_profile` on the full Bootleg model. "
            "Then call utils.entity_profile.compress_topk_entity_embeddings to create your own mini model."
        )

    try:
        weight_shape = get_nested_item(model_state_dict, ENTITY_EMB_KEYS).shape[1]
    except Exception as err:
        raise ValueError(
            f"ERROR: All of {ENTITY_EMB_KEYS} are not in model_state_dict"
        ) from err

    # Refit weights
    if args.init_vec_file:
        vector_for_new_ent = np.load(args.init_vec_file)
    else:
        print("Setting init vector to be all zeros")
        vector_for_new_ent = np.zeros(weight_shape)
    new_model_state_dict = refit_weights(
        np_same_ents,
        neweid2oldeid,
        train_entity_profile,
        new_entity_profile,
        vector_for_new_ent,
        model_state_dict,
    )
    state_dict["model"] = new_model_state_dict
    print(f"Saving model at {args.save_model_path}")
    torch.save(state_dict, args.save_model_path)
    # Drop the references to the saved checkpoint before the title refit.
    del new_model_state_dict
    del state_dict

    # Refit titles
    # Will keep track of all embeddings to adjust. If given a config, we will only adjust
    # the one from the config. Otherwise, we adjust all that are "static_table_" arrays of
    # length BERT_DIM
    if not args.no_title_emb:
        title_embeddings = []
        prepped_title_emb_files = []
        title_emb_file = None
        prep_subdir = "prep"
        # First try to read entity_prep_dir from config
        if args.model_config is not None:
            # NOTE(review): yaml.FullLoader should only be used on trusted config files.
            with open(args.model_config) as file:
                config = yaml.load(file, Loader=yaml.FullLoader)
            prep_subdir = config["data_config"].get("entity_prep_dir", "prep")
            for ent in config["data_config"]["ent_embeddings"]:
                if ent["load_class"] == "StaticEmb" and ent["key"] == "title_static":
                    assert (
                        "emb_file" in ent["args"]
                    ), "emb_file needs to be in title_static config"
                    title_emb_file = ent["args"]["emb_file"]

        prep_dir = Path(args.train_entity_profile) / prep_subdir
        out_prep_dir = Path(args.new_entity_profile) / prep_subdir

        # If we know the title embedding file from the config, only the prepped file
        # derived from its base name is a candidate.
        expected_prep_name = None
        if title_emb_file is not None:
            emb_stem = os.path.splitext(os.path.basename(title_emb_file))[0]
            expected_prep_name = f"static_table_{emb_stem}.npy"

        print(f"Looking for title embedding in {prep_dir}")
        # Try to find a saved title prep file
        expected_shape = [
            train_entity_profile.num_entities_with_pad_and_nocand,
            BERT_DIM,
        ]
        for file in prep_dir.iterdir():
            if not (file.is_file() and file.name.startswith("static_table_")):
                continue
            if expected_prep_name is not None and file.name != expected_prep_name:
                continue
            possible_titles = np.load(file, mmap_mode="r")
            # Only arrays with one BERT_DIM row per train entity are kept.
            if list(possible_titles.shape) == expected_shape:
                title_embeddings.append(possible_titles)
                prepped_title_emb_files.append(file.name)

        if not title_embeddings:
            print(
                "We were unable to adjust titles. If your model does not use title embeddings, ignore this. If your "
                "model does (all Bootleg models by default do), please call "
                "```python -m bootleg.utils.preprocessing.build_static_embeddings --help``` to extract manually. "
                "The saved file from this method should be added to ```emb_file``` config param for the title "
                "embedding."
            )
        else:
            out_prep_dir.mkdir(parents=True, exist_ok=True)
            for title_embed, prepped_title_emb_file in zip(
                title_embeddings, prepped_title_emb_files
            ):
                print(f"Attempting to refit title {prepped_title_emb_file}")
                save_file = out_prep_dir / prepped_title_emb_file
                print(f"Saving to {save_file}")
                # The returned memmap array is not needed here; the refit result is
                # written to save_file.
                refit_titles(
                    np_same_ents,
                    np_new_ents,
                    neweid2oldeid,
                    train_entity_profile,
                    new_entity_profile,
                    title_embed,
                    str(save_file),
                    args.bert_model,
                    args.bert_model_cache,
                    args.cpu,
                )

    if args.model_config is not None:
        modify_config(
            args.model_config,
            args.save_model_config,
            args.save_model_path,
            args.new_entity_profile,
        )
Exemplo n.º 8
0
    def test_fit_titles(self):
        """Check ``refit_titles`` against hand-built gold title embeddings.

        Two entities are added to the train profile and one QID is renamed; the
        result is saved as the new profile. After refitting, rows of kept
        entities must equal their old rows, rows of the two new entities must
        equal the averaged BERT title embedding, and the last row must be zero.
        """
        bert_name = "bert-base-uncased"
        bert_cache_dir = str(self.dir / "temp_bert_models")
        title_save_path = str(self.dir / "temp_title.npy")
        emb_dim = 768

        # Build the new profile by editing a copy of the train profile
        editable_profile = EntityProfile.load_from_cache(
            load_dir=self.train_db,
            edit_mode=True,
        )
        cobra_entity = {
            "entity_id": "Q910",
            "mentions": [["cobra", 10.0], ["animal", 3.0]],
            "title": "Cobra",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        }
        snake_entity = {
            "entity_id": "Q101",
            "mentions": [["snake", 10.0], ["snakes", 7.0], ["animal", 3.0]],
            "title": "Snake",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
        }
        editable_profile.add_entity(cobra_entity)
        editable_profile.add_entity(snake_entity)
        editable_profile.reidentify_entity("Q123", "Q321")
        editable_profile.save(self.new_db)

        # Old title embeddings: every entry of row i holds the value i
        old_title_embs = np.ones((6, emb_dim)) * np.arange(6).reshape(-1, 1)

        train_entity_profile = EntityProfile.load_from_cache(load_dir=self.train_db)
        new_entity_profile = EntityProfile.load_from_cache(load_dir=self.new_db)
        neweid2oldeid = {eid: eid for eid in range(1, 5)}
        np_same_ents = {"Q321", "Q345", "Q567", "Q789"}
        np_new_ents = {"Q910", "Q101"}

        refit_title_embs = refit_titles(
            np_same_ents,
            np_new_ents,
            neweid2oldeid,
            train_entity_profile,
            new_entity_profile,
            old_title_embs,
            title_save_path,
            bert_name,
            bert_cache_dir,
            True,
        )

        # Gold BERT title embeddings for the two new entities
        tokenizer = BertTokenizer.from_pretrained(
            bert_name,
            do_lower_case=True,
            cache_dir=bert_cache_dir,
        )
        bert = BertModel.from_pretrained(
            bert_name,
            cache_dir=bert_cache_dir,
            output_attentions=False,
            output_hidden_states=False,
        )
        bert.eval()
        encoded = tokenizer(
            ["Cobra", "Snake"],
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"]
        hidden = bert(token_ids, attention_mask=encoded["attention_mask"])[0]
        # Zero the outputs at positions whose token id is 0 before averaging
        hidden[token_ids == 0] = 0
        avg_title_vecs = average_titles(token_ids, hidden).to("cpu").detach().numpy()

        # Rows 0-4 keep their old values, rows 5-6 are the new titles, and
        # the last row is set to zero in the fit method
        expected_title_embs = np.zeros((8, emb_dim))
        expected_title_embs[:5] = np.arange(5).reshape(-1, 1)
        expected_title_embs[5] = avg_title_vecs[0]
        expected_title_embs[6] = avg_title_vecs[1]
        np.testing.assert_array_almost_equal(refit_title_embs, expected_title_embs)