コード例 #1
0
    def test_checks(self):
        """Exercise the error paths of EntityProfile editing and query methods.

        Covers: mutating a profile loaded without ``edit_mode``; adding a
        relation whose subject QID is not in the profile (positional and
        keyword call forms); and naming an unknown type system in
        ``add_type`` / ``get_types``.
        """
        dog = {
            "entity_id": "Q123",
            "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            "title": "Dog",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [
                {"relation": "sibling", "object": "Q345"},
                {"relation": "sibling", "object": "Q567"},
            ],
        }
        cat = {
            "entity_id": "Q345",
            "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
            "title": "Cat",
            "types": {"hyena": ["animal"], "wiki": ["cat"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        }
        self.write_data(self.profile_file, [dog, cat])

        # Loaded without edit_mode: any mutation is expected to be rejected.
        read_only_profile = EntityProfile.load_from_jsonl(
            self.profile_file, max_candidates=5
        )
        with self.assertRaises(AttributeError) as ctx:
            read_only_profile.add_relation("Q345", "sibling", "Q123")
        # assertRaises also accepts subclasses; pin the exact type.
        assert type(ctx.exception) is AttributeError

        editable_profile = EntityProfile.load_from_jsonl(
            self.profile_file, max_candidates=5, edit_mode=True
        )

        # Q789 is not in the written profile: positional, then keyword form.
        unknown_subject_calls = (
            lambda: editable_profile.add_relation("Q789", "sibling", "Q123"),
            lambda: editable_profile.add_relation(
                qid="Q789", relation="sibling", qid2="Q123"
            ),
        )
        for call in unknown_subject_calls:
            with self.assertRaises(ValueError) as ctx:
                call()
            assert type(ctx.exception) is ValueError
            assert "is not in our dump" in str(ctx.exception)

        # "blah" is not one of the profile's type systems (hyena / wiki).
        unknown_type_system_calls = (
            lambda: editable_profile.add_type(
                qid="Q345", type="sibling", type_system="blah"
            ),
            lambda: editable_profile.get_types(qid="Q345", type_system="blah"),
        )
        for call in unknown_type_system_calls:
            with self.assertRaises(ValueError) as ctx:
                call()
            assert type(ctx.exception) is ValueError
            assert "type system blah is not one" in str(ctx.exception)
コード例 #2
0
 def test_profile_load_jsonl_errors(self):
     """Reading a malformed profile file must raise ``ValidationError``.

     The single row below differs from the well-formed rows used elsewhere
     in this suite: ``entity_id`` is an int and the mentions carry no score.
     """
     malformed_row = {
         "entity_id": 123,
         "mentions": [["dog"], ["dogg"], ["animal"]],
         "title": "Dog",
         "types": {"hyena": ["animal"], "wiki": ["dog"]},
         "relations": [
             {"relation": "sibling", "object": "Q345"},
             {"relation": "sibling", "object": "Q567"},
         ],
     }
     self.write_data(self.profile_file, [malformed_row])

     with self.assertRaises(ValidationError) as ctx:
         EntityProfile._read_profile_file(self.profile_file)
     raised = ctx.exception
     # assertRaises also accepts subclasses; pin the exact type.
     assert type(raised) is ValidationError
コード例 #3
0
def fit_profiles(args):
    """Refit a trained model's entity embeddings to a new entity profile.

    Loads the entity profile the model was trained with and a new profile,
    matches their entities (optionally through an old->new QID remapping
    file), rebuilds the model's entity weights for the new profile with
    ``refit_weights`` and saves the resulting checkpoint. When a model config
    is supplied, a modified copy of it is written via ``modify_config``.

    Args:
        args: parsed arguments. Attributes read here: ``model_config``,
            ``save_model_config``, ``train_entity_profile``,
            ``new_entity_profile``, ``oldqid2newqid``, ``model_path``,
            ``init_vec_file`` and ``save_model_path``.

    Raises:
        AssertionError: if ``model_config`` is given without
            ``save_model_config``, or if the QID / EID mappings are not
            one-to-one.
        NotImplementedError: if the checkpoint is a topK (compressed) model.
        ValueError: if the entity-embedding weights cannot be found in the
            model state dict.
    """
    print(json.dumps(vars(args), indent=4))

    if args.model_config is not None:
        assert (
            args.save_model_config is not None
        ), "If you pass in a model config, you must pass in a model save config path"

    print(f"Loading train entity profile from {args.train_entity_profile}")
    train_entity_profile = EntityProfile.load_from_cache(
        load_dir=args.train_entity_profile)
    print(f"Loading new entity profile from {args.new_entity_profile}")
    new_entity_profile = EntityProfile.load_from_cache(
        load_dir=args.new_entity_profile)

    # Optional JSON file mapping old QIDs to new QIDs (e.g. for entities that
    # were re-identified between the two profiles).
    oldqid2newqid = dict()
    newqid2oldqid = dict()
    if args.oldqid2newqid:
        with open(args.oldqid2newqid) as in_f:
            oldqid2newqid = ujson.load(in_f)
        newqid2oldqid = {v: k for k, v in oldqid2newqid.items()}
        # The inverse shrinks if two old QIDs map onto the same new QID.
        assert len(oldqid2newqid) == len(
            newqid2oldqid
        ), "The dicts of oldqid2newqid and its inverse do not have the same length"

    np_removed_ents, np_same_ents, np_new_ents, oldeid2neweid = match_entities(
        train_entity_profile, new_entity_profile, oldqid2newqid, newqid2oldqid)
    neweid2oldeid = {v: k for k, v in oldeid2neweid.items()}
    assert len(oldeid2neweid) == len(
        neweid2oldeid
    ), "The lengths of oldeid2neweid and neweid2oldeid don't match"
    state_dict, model_state_dict = load_statedict(args.model_path)

    # We do not support modifying a topK model. Only the original model.
    # The raise lives in the ``else`` clause so the lookup's failure handler
    # cannot swallow it (a raise inside the ``try`` body would be caught).
    try:
        get_nested_item(model_state_dict, ENTITY_TOPK_KEYS)
    except Exception:
        # Lookup failed, so the topK keys are absent: this is a full model.
        # NOTE(review): the exact exception get_nested_item raises on a
        # missing key is not visible here; narrow this once confirmed.
        pass
    else:
        raise NotImplementedError(
            "We don't support fitting a topK mini model. Instead, call `fit_to_profile` on the full Bootleg model. "
            "Then call utils.entity_profile.compress_topk_entity_embeddings to create your own mini model."
        )

    try:
        # Column count of the entity-embedding matrix = embedding dimension.
        weight_shape = get_nested_item(model_state_dict,
                                       ENTITY_EMB_KEYS).shape[1]
    except Exception as err:
        raise ValueError(
            f"ERROR: All of {ENTITY_EMB_KEYS} are not in model_state_dict"
        ) from err

    if args.init_vec_file:
        vector_for_new_ent = np.load(args.init_vec_file)
    else:
        print("Setting init vector to be all zeros")
        vector_for_new_ent = np.zeros(weight_shape)
    new_model_state_dict = refit_weights(
        np_same_ents,
        neweid2oldeid,
        train_entity_profile,
        new_entity_profile,
        vector_for_new_ent,
        model_state_dict,
    )
    print(new_model_state_dict["module_pool"]["learned"].keys())
    state_dict["model"] = new_model_state_dict
    print(f"Saving model at {args.save_model_path}")
    torch.save(state_dict, args.save_model_path)

    if args.model_config is not None:
        modify_config(
            args.model_config,
            args.save_model_config,
            args.save_model_path,
            args.new_entity_profile,
        )
コード例 #4
0
    def test_match_entities(self):
        """Check ``match_entities`` after growing and after pruning a profile.

        Part 1 adds Q910/Q101 and re-identifies Q123 as Q321, so every
        training entity is kept. Part 2 prunes the new profile to
        {Q321, Q910, Q101} with hand-assigned EIDs, so three training
        entities are removed.
        """

        def check_match(gold_removed, gold_same, gold_new, gold_oldeid2neweid):
            # Reload both profiles read-only from disk, then compare the match.
            train_profile = EntityProfile.load_from_cache(load_dir=self.train_db)
            new_profile = EntityProfile.load_from_cache(load_dir=self.new_db)
            oldqid2newqid = {"Q123": "Q321"}
            newqid2oldqid = {new: old for old, new in oldqid2newqid.items()}
            removed, same, added, oldeid2neweid = match_entities(
                train_profile, new_profile, oldqid2newqid, newqid2oldqid)
            self.assertSetEqual(gold_removed, removed)
            self.assertSetEqual(gold_same, same)
            self.assertSetEqual(gold_new, added)
            self.assertDictEqual(gold_oldeid2neweid, oldeid2neweid)

        # TEST ADD
        cobra = {
            "entity_id": "Q910",
            "mentions": [["cobra", 10.0], ["animal", 3.0]],
            "title": "Cobra",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        }
        snake = {
            "entity_id": "Q101",
            "mentions": [["snake", 10.0], ["snakes", 7.0], ["animal", 3.0]],
            "title": "Snake",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
        }
        editable = EntityProfile.load_from_cache(
            load_dir=self.train_db, edit_mode=True)
        for entity in (cobra, snake):
            editable.add_entity(entity)
        editable.reidentify_entity("Q123", "Q321")
        # Save the modified train profile as the "new" profile
        editable.save(self.new_db)
        check_match(
            gold_removed=set(),
            gold_same={"Q321", "Q345", "Q567", "Q789"},
            gold_new={"Q910", "Q101"},
            gold_oldeid2neweid={1: 1, 2: 2, 3: 3, 4: 4},
        )

        # TEST PRUNE - this profile already has 910 and 101
        pruned = EntityProfile.load_from_cache(
            load_dir=self.new_db, edit_mode=True)
        pruned.prune_to_entities({"Q321", "Q910", "Q101"})  # These now get eids 1, 2, 3
        # Manually set the eids for the test
        forced_qid2eid = {"Q321": 3, "Q910": 2, "Q101": 1}
        pruned._entity_symbols._qid2eid = forced_qid2eid
        pruned._entity_symbols._eid2qid = {
            eid: qid for qid, eid in forced_qid2eid.items()
        }
        pruned.save(self.new_db)
        check_match(
            gold_removed={"Q345", "Q567", "Q789"},
            gold_same={"Q321"},
            gold_new={"Q910", "Q101"},
            gold_oldeid2neweid={1: 3},
        )
コード例 #5
0
    def test_fit_entities(self):
        """Check ``refit_weights`` output across three profile scenarios.

        1. Q910/Q101 are added and Q123 is re-identified as Q321; the new
           entity rows are filled from ``vector_for_new_ent``.
        2. Same profiles, but the state dict also carries ``eid2reg``.
        3. The new profile is pruned to {Q321, Q910, Q101} with hand-set EIDs.
        """
        weight_path = [
            "module_pool", "learned", "learned_entity_embedding.weight"
        ]
        eid2reg_path = ["module_pool", "learned", "eid2reg"]

        def make_train_state_dict(with_eid2reg):
            """Build a fresh training state dict, optionally with ``eid2reg``."""
            learned = {
                # 4 entities + 1 UNK + 1 PAD for train data
                "learned_entity_embedding.weight":
                torch.tensor([
                    [1.0, 2, 3, 4, 5],
                    [2, 2, 2, 2, 2],
                    [3, 3, 3, 3, 3],
                    [4, 4, 4, 4, 4],
                    [5, 5, 5, 5, 5],
                    [0, 0, 0, 0, 0],
                ]),
            }
            if with_eid2reg:
                learned["eid2reg"] = torch.tensor(
                    [0.0, 0.2, 0.3, 0.4, 0.5, 0.0])
            return {"module_pool": {"learned": learned}}

        def assert_nested_equal(actual, expected, key_path):
            """Walk ``key_path`` through both dicts; compare leaf tensors exactly."""
            for key in key_path:
                assert key in actual
                assert key in expected
                if isinstance(expected[key], dict):
                    actual = actual[key]
                    expected = expected[key]
                else:
                    assert torch.equal(actual[key], expected[key])

        train_entity_profile = EntityProfile.load_from_cache(
            load_dir=self.train_db,
            edit_mode=True,
        )
        # Modify train profile: add two entities and re-identify Q123 as Q321.
        new_entity1 = {
            "entity_id": "Q910",
            "mentions": [["cobra", 10.0], ["animal", 3.0]],
            "title": "Cobra",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        }
        new_entity2 = {
            "entity_id": "Q101",
            "mentions": [["snake", 10.0], ["snakes", 7.0], ["animal", 3.0]],
            "title": "Snake",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
        }
        train_entity_profile.add_entity(new_entity1)
        train_entity_profile.add_entity(new_entity2)
        train_entity_profile.reidentify_entity("Q123", "Q321")
        # Save new profile
        train_entity_profile.save(self.new_db)
        train_entity_profile = EntityProfile.load_from_cache(
            load_dir=self.train_db)
        new_entity_profile = EntityProfile.load_from_cache(
            load_dir=self.new_db)
        neweid2oldeid = {1: 1, 2: 2, 3: 3, 4: 4}
        np_same_ents = {"Q321", "Q345", "Q567", "Q789"}

        # Shared by scenarios 1 and 2: the four kept rows stay in place, the
        # two added entities get np.arange(5), and a zero row ends the matrix.
        expected_weight_after_add = torch.tensor([
            [1.0, 2, 3, 4, 5],
            [2, 2, 2, 2, 2],
            [3, 3, 3, 3, 3],
            [4, 4, 4, 4, 4],
            [5, 5, 5, 5, 5],
            [0, 1, 2, 3, 4],
            [0, 1, 2, 3, 4],
            [0, 0, 0, 0, 0],
        ])

        # Scenario 1: no eid2reg
        new_state_dict = refit_weights(
            np_same_ents,
            neweid2oldeid,
            train_entity_profile,
            new_entity_profile,
            np.arange(5),
            make_train_state_dict(with_eid2reg=False),
        )
        gold_state_dict = {
            "module_pool": {
                "learned": {
                    "learned_entity_embedding.weight": expected_weight_after_add,
                }
            }
        }
        assert_nested_equal(new_state_dict, gold_state_dict, weight_path)

        # Scenario 2: TEST WITH EIDREG
        new_state_dict = refit_weights(
            np_same_ents,
            neweid2oldeid,
            train_entity_profile,
            new_entity_profile,
            np.arange(5),
            make_train_state_dict(with_eid2reg=True),
        )
        gold_state_dict = {
            "module_pool": {
                "learned": {
                    "learned_entity_embedding.weight": expected_weight_after_add,
                    "eid2reg":
                    torch.tensor([0.0, 0.2, 0.3, 0.4, 0.5, 0.5, 0.5, 0.0]),
                }
            }
        }
        assert_nested_equal(new_state_dict, gold_state_dict, weight_path)
        assert_nested_equal(new_state_dict, gold_state_dict, eid2reg_path)

        # Scenario 3: TEST WITH PRUNE - this profile already has 910 and 101
        new_entity_profile = EntityProfile.load_from_cache(
            load_dir=self.new_db, edit_mode=True)
        new_entity_profile.prune_to_entities({"Q321", "Q910", "Q101"
                                              })  # These now get eids 1, 2, 3
        # Manually set the eids for the test
        new_entity_profile._entity_symbols._qid2eid = {
            "Q321": 3,
            "Q910": 2,
            "Q101": 1
        }
        new_entity_profile._entity_symbols._eid2qid = {
            3: "Q321",
            2: "Q910",
            1: "Q101"
        }
        new_entity_profile.save(self.new_db)
        train_entity_profile = EntityProfile.load_from_cache(
            load_dir=self.train_db)
        new_entity_profile = EntityProfile.load_from_cache(
            load_dir=self.new_db)
        new_state_dict = refit_weights(
            {"Q321"},
            {3: 1},
            train_entity_profile,
            new_entity_profile,
            np.arange(5),
            make_train_state_dict(with_eid2reg=True),
        )
        gold_state_dict = {
            "module_pool": {
                "learned": {
                    "learned_entity_embedding.weight":
                    torch.tensor([
                        [1.0, 2, 3, 4, 5],
                        [0, 1, 2, 3, 4],
                        [0, 1, 2, 3, 4],
                        [2, 2, 2, 2, 2],
                        [0, 0, 0, 0, 0],
                    ]),
                    "eid2reg":
                    torch.tensor([0.0, 0.5, 0.5, 0.2, 0.0]),
                }
            }
        }
        assert_nested_equal(new_state_dict, gold_state_dict, weight_path)
        assert_nested_equal(new_state_dict, gold_state_dict, eid2reg_path)
コード例 #6
0
 def setUp(self) -> None:
     """Create scratch directories and cache a four-entity training profile.

     The entities are written as JSONL to ``self.profile_file``, loaded with
     ``EntityProfile.load_from_jsonl`` and saved to ``self.train_db`` so that
     tests can reload them with ``EntityProfile.load_from_cache``.
     """
     self.dir = Path("test/data/fit_to_profile_test")
     self.train_db = Path(self.dir / "train_entity_db")
     self.new_db = Path(self.dir / "entity_db_save2")
     self.profile_file = Path(self.dir / "raw_data/entity_profile.jsonl")
     for directory in (self.train_db, self.new_db, self.profile_file.parent):
         directory.mkdir(exist_ok=True, parents=True)

     # Dump train profile data
     dog = {
         "entity_id": "Q123",
         "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
         "title": "Dog",
         "types": {"hyena": ["animal"], "wiki": ["dog"]},
         "relations": [
             {"relation": "sibling", "object": "Q345"},
             {"relation": "sibling", "object": "Q567"},
         ],
     }
     cat = {
         "entity_id": "Q345",
         "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
         "title": "Cat",
         "types": {"hyena": ["animal"], "wiki": ["cat"]},
         "relations": [{"relation": "sibling", "object": "Q123"}],
     }
     # Missing type system (no "wiki" entry)
     catt = {
         "entity_id": "Q567",
         "mentions": [["catt", 6.5], ["animal", 3.3]],
         "title": "Catt",
         "types": {"hyena": ["animal", "animall"]},
         "relations": [{"relation": "sibling", "object": "Q123"}],
     }
     # No KG/Types
     dogg = {
         "entity_id": "Q789",
         "mentions": [["animal", 12.2]],
         "title": "Dogg",
     }
     self.write_data(self.profile_file, [dog, cat, catt, dogg])
     EntityProfile.load_from_jsonl(self.profile_file).save(self.train_db)
コード例 #7
0
    def test_end2end(self):
        # ======================
        # PART 1: TRAIN A SMALL MODEL WITH ONE PROFILE DUMP
        # ======================
        # Generate entity profile data
        data = [
            {
                "entity_id": "Q123",
                "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
                "title": "Dog",
                "types": {"hyena": ["animal"], "wiki": ["dog"]},
                "relations": [
                    {"relation": "sibling", "object": "Q345"},
                    {"relation": "sibling", "object": "Q567"},
                ],
            },
            {
                "entity_id": "Q345",
                "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
                "title": "Cat",
                "types": {"hyena": ["animal"], "wiki": ["cat"]},
                "relations": [{"relation": "sibling", "object": "Q123"}],
            },
            # Missing type system
            {
                "entity_id": "Q567",
                "mentions": [["catt", 6.5], ["animal", 3.3]],
                "title": "Catt",
                "types": {"hyena": ["animal", "animall"]},
                "relations": [{"relation": "sibling", "object": "Q123"}],
            },
            # No KG/Types
            {
                "entity_id": "Q789",
                "mentions": [["animal", 12.2]],
                "title": "Dogg",
            },
        ]
        # Generate train data
        train_data = [
            {
                "sent_idx_unq": 1,
                "sentence": "I love animals and dogs",
                "qids": ["Q567", "Q123"],
                "aliases": ["animal", "dog"],
                "gold": [True, True],
                "spans": [[2, 3], [4, 5]],
            }
        ]
        self.write_data(self.profile_file, data)
        self.write_data(self.train_data, train_data)
        entity_profile = EntityProfile.load_from_jsonl(self.profile_file)
        # Dump profile data in format for model
        entity_profile.save(self.save_dir)

        # Setup model args to read the data/new profile data
        raw_args = {
            "emmental": {
                "n_epochs": 1,
            },
            "run_config": {
                "dataloader_threads": 1,
                "dataset_threads": 1,
            },
            "train_config": {"batch_size": 2},
            "model_config": {"hidden_size": 20, "num_heads": 1},
            "data_config": {
                "entity_dir": str(self.save_dir),
                "max_seq_len": 7,
                "max_aliases": 2,
                "data_dir": str(self.data_dir),
                "emb_dir": str(self.save_dir),
                "word_embedding": {
                    "layers": 1,
                    "freeze": True,
                    "cache_dir": str(self.save_dir / "retrained_bert_model"),
                },
                "ent_embeddings": [
                    {
                        "key": "learned",
                        "freeze": False,
                        "load_class": "LearnedEntityEmb",
                        "args": {"learned_embedding_size": 10},
                    },
                    {
                        "key": "learned_type",
                        "load_class": "LearnedTypeEmb",
                        "freeze": True,
                        "args": {
                            "type_labels": f"{TYPE_SUBFOLDER}/wiki/qid2typeids.json",
                            "type_vocab": f"{TYPE_SUBFOLDER}/wiki/type_vocab.json",
                            "max_types": 2,
                            "type_dim": 10,
                        },
                    },
                    {
                        "key": "kg_adj",
                        "load_class": "KGIndices",
                        "batch_on_the_fly": True,
                        "normalize": False,
                        "args": {"kg_adj": f"{KG_SUBFOLDER}/kg_adj.txt"},
                    },
                ],
                "train_dataset": {"file": "train.jsonl"},
                "dev_dataset": {"file": "train.jsonl"},
                "test_dataset": {"file": "train.jsonl"},
            },
        }
        with open(self.arg_file, "w") as out_f:
            ujson.dump(raw_args, out_f)

        args = parser_utils.parse_boot_and_emm_args(str(self.arg_file))
        # This _MUST_ get passed the args so it gets a random seed set
        emmental.init(log_dir=str(self.dir / "temp_log"), config=args)
        if not os.path.exists(emmental.Meta.log_path):
            os.makedirs(emmental.Meta.log_path)

        scores = run_model(mode="train", config=args)
        saved_model_path1 = f"{emmental.Meta.log_path}/last_model.pth"
        assert type(scores) is dict

        # ======================
        # PART 3: MODIFY PROFILE AND LOAD PRETRAINED MODEL AND TRAIN FOR MORE
        # ======================
        entity_profile = EntityProfile.load_from_jsonl(
            self.profile_file, edit_mode=True
        )
        entity_profile.add_type("Q123", "cat", "wiki")
        entity_profile.remove_type("Q123", "dog", "wiki")
        entity_profile.add_mention("Q123", "cat", 100.0)
        # Dump profile data in format for model
        entity_profile.save(self.save_dir2)

        # Modify arg paths
        args["data_config"]["entity_dir"] = str(self.save_dir2)
        args["data_config"]["emb_dir"] = str(self.save_dir2)

        # Load pretrained model
        args["model_config"]["model_path"] = f"{emmental.Meta.log_path}/last_model.pth"
        emmental.Meta.config["model_config"]["model_path"] = saved_model_path1

        # Init another run
        emmental.init(log_dir=str(self.dir / "temp_log"), config=args)
        if not os.path.exists(emmental.Meta.log_path):
            os.makedirs(emmental.Meta.log_path)
        scores = run_model(mode="train", config=args)
        saved_model_path2 = f"{emmental.Meta.log_path}/last_model.pth"
        assert type(scores) is dict

        # ======================
        # PART 4: VERIFY CHANGES IN THE MODEL WERE AS EXPECTED
        # ======================
        # Check that type mappings are different in the right way...we remove "dog"
        # from EID 1 and added "cat". "dog" is not longer a type.
        eid2typeids_table1, type2row_dict1, num_types_with_unk1 = torch.load(
            self.save_dir / "prep" / "type_table_type_mappings_wiki_qid2typeids_2.pt"
        )
        eid2typeids_table2, type2row_dict2, num_types_with_unk2 = torch.load(
            self.save_dir2 / "prep" / "type_table_type_mappings_wiki_qid2typeids_2.pt"
        )
        # Modify mapping 2 to become mapping 1
        # Row 1 is Q123, Col 0 is type (this was "cat")
        eid2typeids_table2[1][0] = entity_profile._type_systems["wiki"]._type_vocab[
            "dog"
        ]
        self.assertEqual(num_types_with_unk1, num_types_with_unk2)
        self.assertDictEqual({1: 1, 2: 2}, type2row_dict1)
        self.assertDictEqual({1: 1}, type2row_dict2)
        assert torch.equal(eid2typeids_table1, eid2typeids_table2)
        # Check that the alias mappings are different
        alias2entity_table1 = torch.from_numpy(
            np.memmap(
                self.save_dir / "prep" / "alias2entity_table_alias2qids_InC1.pt",
                dtype="int64",
                mode="r",
                shape=(5, 30),
            )
        )
        gold_alias2entity_table1 = torch.tensor(
            [
                [
                    4,
                    1,
                    3,
                    2,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                ],
                [
                    2,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                ],
                [
                    2,
                    3,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                ],
                [
                    1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                ],
                [
                    1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                ],
            ]
        )
        assert torch.equal(alias2entity_table1, gold_alias2entity_table1)
        # The change is the "cat" alias has entity 1 added to the beginning
        # It used to only point to Q345 which is entity 2
        alias2entity_table2 = torch.from_numpy(
            np.memmap(
                self.save_dir2 / "prep" / "alias2entity_table_alias2qids_InC1.pt",
                dtype="int64",
                mode="r",
                shape=(5, 30),
            )
        )
        gold_alias2entity_table2 = torch.tensor(
            [
                [
                    4,
                    1,
                    3,
                    2,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                ],
                [
                    1,
                    2,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                ],
                [
                    2,
                    3,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                ],
                [
                    1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                ],
                [
                    1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                    -1,
                ],
            ]
        )
        assert torch.equal(alias2entity_table2, gold_alias2entity_table2)

        # The type embeddings were frozen so they should be the same
        model1 = torch.load(saved_model_path1)
        model2 = torch.load(saved_model_path2)
        assert torch.equal(
            model1["model"]["module_pool"]["learned_type"]["type_emb.weight"],
            model2["model"]["module_pool"]["learned_type"]["type_emb.weight"],
        )
コード例 #8
0
    def test_prune_to_entities(self):
        """Pruning keeps only the requested QIDs and drops relations to removed ones.

        Pruning to a set containing an unknown QID must raise ValueError. The
        same successful pruning is then repeated on a profile reloaded from the
        saved cache with ``no_kg=True``.
        """
        profile_rows = [
            {
                "entity_id": "Q123",
                "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
                "title": "Dog",
                "types": {"hyena": ["animal"], "wiki": ["dog"]},
                "relations": [
                    {"relation": "sibling", "object": "Q345"},
                    {"relation": "sibling", "object": "Q567"},
                ],
            },
            {
                "entity_id": "Q345",
                "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
                "title": "Cat",
                "types": {"hyena": ["animal"], "wiki": ["cat"]},
                "relations": [{"relation": "sibling", "object": "Q123"}],
            },
        ]
        self.write_data(self.profile_file, profile_rows)
        profile = EntityProfile.load_from_jsonl(
            self.profile_file, max_candidates=5, edit_mode=True
        )
        profile.save(self.save_dir2)

        with self.assertRaises(ValueError) as ctx:
            profile.prune_to_entities({"Q123", "Q567"})
        assert type(ctx.exception) is ValueError
        assert "The entity Q567 does not exist" in str(ctx.exception)

        def prune_and_check_only_dog(ep):
            # Prune to Q123 alone, then verify its data survived and Q345 is gone.
            ep.prune_to_entities({"Q123"})
            self.assertTrue(ep.qid_exists("Q123"))
            self.assertFalse(ep.qid_exists("Q345"))
            self.assertListEqual(
                ep.get_mentions_with_scores("Q123"),
                [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
            )
            self.assertListEqual(ep.get_types("Q123", "hyena"), ["animal"])
            self.assertListEqual(ep.get_types("Q123", "wiki"), ["dog"])
            self.assertListEqual(
                ep.get_connections_by_relation("Q123", "sibling"),
                [],
            )

        prune_and_check_only_dog(profile)

        # Check that no_kg still works with load_from_cache
        prune_and_check_only_dog(
            EntityProfile.load_from_cache(self.save_dir2, no_kg=True, edit_mode=True)
        )
コード例 #9
0
    def test_profile_load_simple(self):
        """Parse a profile jsonl file and round-trip it through ``save_to_jsonl``.

        The input includes an entity missing one type system (Q567) and one
        with neither types nor relations (Q789). ``_read_profile_file`` is
        expected to fill those gaps with empty lists/dicts, and saving the
        loaded profile back to jsonl must reproduce every input record.
        """
        data = [
            {
                "entity_id": "Q123",
                "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
                "title": "Dog",
                "types": {"hyena": ["animal"], "wiki": ["dog"]},
                "relations": [
                    {"relation": "sibling", "object": "Q345"},
                    {"relation": "sibling", "object": "Q567"},
                ],
            },
            {
                "entity_id": "Q345",
                "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
                "title": "Cat",
                "types": {"hyena": ["animal"], "wiki": ["cat"]},
                "relations": [{"relation": "sibling", "object": "Q123"}],
            },
            # Missing type system
            {
                "entity_id": "Q567",
                "mentions": [["catt", 6.5], ["animal", 3.3]],
                "title": "Catt",
                "types": {"hyena": ["animal", "animall"]},
                "relations": [{"relation": "sibling", "object": "Q123"}],
            },
            # No KG/Types
            {
                "entity_id": "Q789",
                "mentions": [["animal", 12.2]],
                "title": "Dogg",
            },
        ]
        self.write_data(self.profile_file, data)
        gold_qid2title = {"Q123": "Dog", "Q345": "Cat", "Q567": "Catt", "Q789": "Dogg"}
        gold_alias2qids = {
            "dog": [["Q123", 10.0]],
            "dogg": [["Q123", 7.0]],
            "cat": [["Q345", 10.0]],
            "catt": [["Q345", 7.0], ["Q567", 6.5]],
            "animal": [["Q789", 12.2], ["Q123", 4.0], ["Q567", 3.3], ["Q345", 3.0]],
        }
        gold_type_systems = {
            "hyena": {
                "Q123": ["animal"],
                "Q345": ["animal"],
                "Q567": ["animal", "animall"],
                "Q789": [],
            },
            "wiki": {"Q123": ["dog"], "Q345": ["cat"], "Q567": [], "Q789": []},
        }
        gold_qid2relations = {
            "Q123": {"sibling": ["Q345", "Q567"]},
            "Q345": {"sibling": ["Q123"]},
            "Q567": {"sibling": ["Q123"]},
            "Q789": {},
        }
        (
            qid2title,
            alias2qids,
            type_systems,
            qid2relations,
        ) = EntityProfile._read_profile_file(self.profile_file)

        self.assertDictEqual(gold_qid2title, qid2title)
        self.assertDictEqual(gold_alias2qids, alias2qids)
        self.assertDictEqual(gold_type_systems, type_systems)
        self.assertDictEqual(gold_qid2relations, qid2relations)

        # Test loading/saving from jsonl
        ep = EntityProfile.load_from_jsonl(self.profile_file, edit_mode=True)
        ep.save_to_jsonl(self.profile_file)
        # Use a context manager so the re-read file handle is closed deterministically.
        with open(self.profile_file) as in_f:
            read_in_data = [ujson.loads(li) for li in in_f]

        self.assertEqual(len(read_in_data), len(data))

        # Index the re-read records by QID. A duplicated QID would collapse here,
        # so also check that indexing did not lose any record.
        read_in_by_qid = {obj["entity_id"]: obj for obj in read_in_data}
        self.assertEqual(len(read_in_by_qid), len(read_in_data))

        for qid_obj in data:
            self.assertIn(qid_obj["entity_id"], read_in_by_qid)
            self.assertDictEqual(qid_obj, read_in_by_qid[qid_obj["entity_id"]])
コード例 #10
0
    def test_add_entity(self):
        """Exercise ``EntityProfile.add_entity`` validation and the success path.

        Invalid inputs are checked in order: a non-dict, an existing QID, an
        unknown type system, a malformed relation entry and an unknown
        relation name. Each must raise ValueError with a specific message. A
        valid entity is then added both to the jsonl-loaded profile and to a
        profile reloaded from cache with ``no_kg=True``. The latter stores no
        relations, so it reports no connections.
        """
        seed_rows = [
            {
                "entity_id": "Q123",
                "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
                "title": "Dog",
                "types": {"hyena": ["animal"], "wiki": ["dog"]},
                "relations": [
                    {"relation": "sibling", "object": "Q345"},
                    {"relation": "sibling", "object": "Q567"},
                ],
            },
            {
                "entity_id": "Q345",
                "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
                "title": "Cat",
                "types": {"hyena": ["animal"], "wiki": ["cat"]},
                "relations": [{"relation": "sibling", "object": "Q123"}],
            },
        ]
        self.write_data(self.profile_file, seed_rows)
        entity_profile = EntityProfile.load_from_jsonl(
            self.profile_file, max_candidates=3, edit_mode=True
        )
        entity_profile.save(self.save_dir2)

        def expect_add_failure(candidate, expected_message):
            # add_entity must reject ``candidate`` with exactly a ValueError.
            with self.assertRaises(ValueError) as ctx:
                entity_profile.add_entity(candidate)
            assert type(ctx.exception) is ValueError
            assert expected_message in str(ctx.exception)

        def snake(types, relations):
            # Build a fresh Q789 "Snake" record each call so no nested list is shared.
            return {
                "entity_id": "Q789",
                "mentions": [["snake", 10.0], ["animal", 3.0]],
                "title": "Snake",
                "types": types,
                "relations": relations,
            }

        # Test bad format
        expect_add_failure(
            ["bad format"], "The input to update_entity needs to be a dictionary"
        )

        # Test already existing entity
        expect_add_failure(
            {
                "entity_id": "Q345",
                "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
                "title": "Cat",
                "types": {"hyena": ["animal"], "wiki": ["cat"]},
                "relations": [{"relation": "sibling", "object": "Q123"}],
            },
            "The entity Q345 already exists",
        )

        # Test new type system
        expect_add_failure(
            snake(
                {"hyena": ["animal"], "new_sys": ["snakey"]},
                [{"relation": "sibling", "object": "Q123"}],
            ),
            "When adding a new entity, you must use the same type system",
        )

        # Test new bad relation format
        expect_add_failure(
            snake(
                {"hyena": ["animal"]},
                [{"relatiion": "sibbbling", "object": "Q123"}],
            ),
            "For each value in relations, it must be a JSON with keys relation and object",
        )

        # Test new relation
        expect_add_failure(
            snake(
                {"hyena": ["animal"]},
                [{"relation": "sibbbling", "object": "Q123"}],
            ),
            "When adding a new entity, you must use the same set of relations.",
        )

        new_entity = snake(
            {"hyena": ["animal"]}, [{"relation": "sibling", "object": "Q123"}]
        )

        def add_and_check_snake(profile, expected_siblings):
            # Add the very same ``new_entity`` object and verify what was stored.
            profile.add_entity(new_entity)
            self.assertTrue(profile.qid_exists("Q789"))
            self.assertEqual(profile.get_title("Q789"), "Snake")
            self.assertListEqual(
                profile.get_mentions_with_scores("Q789"),
                [["snake", 10.0], ["animal", 3.0]],
            )
            self.assertListEqual(profile.get_types("Q789", "hyena"), ["animal"])
            self.assertListEqual(profile.get_types("Q789", "wiki"), [])
            self.assertListEqual(
                profile.get_connections_by_relation("Q789", "sibling"),
                expected_siblings,
            )

        # Assert it is added
        add_and_check_snake(entity_profile, ["Q123"])

        # Check that no_kg still works with load_from_cache
        add_and_check_snake(
            EntityProfile.load_from_cache(self.save_dir2, no_kg=True, edit_mode=True),
            [],
        )
コード例 #11
0
    def test_profile_dump_load(self):
        """Round-trip an entity profile through ``save`` / ``load_from_cache``.

        Checks three reloads of the saved profile:
          * a full reload, which must match QIDs, type systems, types and
            connections;
          * a reload with ``no_type=True, no_kg=True``, which must have no type
            systems and no KG, while connection lookups still work and return
            nothing;
          * a reload with ``no_kg=True`` restricted to the "wiki" type system.
            Asking it for the unloaded "hyena" system must raise ValueError.
        """
        data = [
            {
                "entity_id": "Q123",
                "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
                "title": "Dog",
                "types": {"hyena": ["animal"], "wiki": ["dog"]},
                "relations": [
                    {"relation": "sibling", "object": "Q345"},
                    {"relation": "sibling", "object": "Q567"},
                ],
            },
            {
                "entity_id": "Q345",
                "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
                "title": "Cat",
                "types": {"hyena": ["animal"], "wiki": ["cat"]},
                "relations": [{"relation": "sibling", "object": "Q123"}],
            },
        ]
        self.write_data(self.profile_file, data)
        entity_profile = EntityProfile.load_from_jsonl(
            self.profile_file, max_candidates=5, edit_mode=True
        )
        entity_profile.save(self.save_dir2)

        # Test load correctly
        entity_profile2 = EntityProfile.load_from_cache(self.save_dir2)

        self.assertSetEqual(
            set(entity_profile.get_all_qids()), set(entity_profile2.get_all_qids())
        )
        self.assertSetEqual(
            set(entity_profile.get_all_typesystems()),
            set(entity_profile2.get_all_typesystems()),
        )
        for type_sys in entity_profile.get_all_typesystems():
            self.assertSetEqual(
                set(entity_profile.get_all_types(type_sys)),
                set(entity_profile2.get_all_types(type_sys)),
            )
        for qid in entity_profile.get_all_qids():
            self.assertSetEqual(
                set(entity_profile.get_all_connections(qid)),
                set(entity_profile2.get_all_connections(qid)),
            )

        # Test load with no types or kgs
        entity_profile2 = EntityProfile.load_from_cache(
            self.save_dir2, no_type=True, no_kg=True
        )

        self.assertSetEqual(
            set(entity_profile.get_all_qids()), set(entity_profile2.get_all_qids())
        )
        self.assertEqual(len(entity_profile2.get_all_typesystems()), 0)
        self.assertIsNone(entity_profile2._kg_symbols)

        # Testing that the functions still work despite not loading them
        self.assertEqual(len(entity_profile2.get_all_connections("Q123")), 0)

        # Test load with only the "wiki" type system and no kgs
        entity_profile2 = EntityProfile.load_from_cache(
            self.save_dir2, no_kg=True, type_systems_to_load=["wiki"]
        )

        self.assertSetEqual(
            set(entity_profile.get_all_qids()), set(entity_profile2.get_all_qids())
        )
        self.assertEqual(entity_profile2.get_all_typesystems(), ["wiki"])
        self.assertSetEqual(
            set(entity_profile.get_all_types("wiki")),
            set(entity_profile2.get_all_types("wiki")),
        )
        self.assertIsNone(entity_profile2._kg_symbols)

        # Assert error loading type system that is not there
        with self.assertRaises(ValueError) as context:
            entity_profile2.get_all_types("hyena")
        self.assertIs(type(context.exception), ValueError)
        self.assertIn("type system hyena is not one", str(context.exception))
コード例 #12
0
ファイル: fit_to_profile.py プロジェクト: lorr1/bootleg
def fit_profiles(args):
    """Refit a trained Bootleg model to a new entity profile.

    Steps performed:
      1. Load the training entity profile and the new entity profile, without
         types or KG.
      2. Match entities between the two profiles. An optional JSON mapping of
         old QID -> new QID can be supplied via ``args.oldqid2newqid``.
      3. Refit the entity embedding weights of the model at
         ``args.model_path`` and save the result to ``args.save_model_path``.
      4. Unless ``args.no_title_emb`` is set, refit every prepped
         ``static_table_*`` title embedding of the expected shape found under
         the training profile's prep directory. Each refit file is written to
         the new profile's prep directory.
      5. If ``args.model_config`` is given, write an updated config to
         ``args.save_model_config``.

    Args:
        args: parsed command-line namespace. Fields read here are
            model_config, save_model_config, train_entity_profile,
            new_entity_profile, oldqid2newqid, model_path, init_vec_file,
            save_model_path, no_title_emb, bert_model, bert_model_cache
            and cpu.

    Raises:
        NotImplementedError: if the model already contains topK entity
            embeddings (``ENTITY_TOPK_KEYS``). Only full models can be refit.
        ValueError: if the entity embedding (``ENTITY_EMB_KEYS``) cannot be
            read from the model state dict.
        AssertionError: on inconsistent arguments or a non-invertible QID map.
    """
    print(json.dumps(vars(args), indent=4))

    if args.model_config is not None:
        assert (
            args.save_model_config is not None
        ), "If you pass in a model config, you must pass in a model save config path"

    print(f"Loading train entity profile from {args.train_entity_profile}")
    train_entity_profile = EntityProfile.load_from_cache(
        load_dir=args.train_entity_profile, no_type=True, no_kg=True
    )
    print(f"Loading new entity profile from {args.new_entity_profile}")
    new_entity_profile = EntityProfile.load_from_cache(
        load_dir=args.new_entity_profile, no_type=True, no_kg=True
    )
    oldqid2newqid = dict()
    newqid2oldqid = dict()
    if args.oldqid2newqid is not None and len(args.oldqid2newqid) > 0:
        with open(args.oldqid2newqid) as in_f:
            oldqid2newqid = ujson.load(in_f)
        newqid2oldqid = {v: k for k, v in oldqid2newqid.items()}
        # If two old QIDs map to the same new QID, the inverse silently loses one.
        assert len(oldqid2newqid) == len(
            newqid2oldqid
        ), "The dicts of oldqid2newqid and its inverse do not have the same length"

    np_removed_ents, np_same_ents, np_new_ents, oldeid2neweid = match_entities(
        train_entity_profile, new_entity_profile, oldqid2newqid, newqid2oldqid
    )
    neweid2oldeid = {v: k for k, v in oldeid2neweid.items()}
    assert len(oldeid2neweid) == len(
        neweid2oldeid
    ), "The lengths of oldeid2neweid and neweid2oldeid don't match"
    state_dict, model_state_dict = load_statedict(args.model_path)

    # We do not support modifying a topK model. Only the original model.
    # The raise must sit outside the ``try``. Inside it, a catch-all handler
    # would swallow the NotImplementedError and the guard would never fire.
    try:
        topk_emb = get_nested_item(model_state_dict, ENTITY_TOPK_KEYS)
    except Exception:
        # Lookup failed: the topK keys are absent, so this is a full model.
        topk_emb = None
    if topk_emb is not None:
        raise NotImplementedError(
            "We don't support fitting a topK mini model. Instead, call `fit_to_profile` on the full Bootleg model. "
            "Then call utils.entity_profile.compress_topk_entity_embeddings to create your own mini model."
        )

    try:
        weight_shape = get_nested_item(model_state_dict, ENTITY_EMB_KEYS).shape[1]
    except Exception as err:
        raise ValueError(
            f"ERROR: All of {ENTITY_EMB_KEYS} are not in model_state_dict"
        ) from err

    # Refit weights
    if args.init_vec_file is not None and len(args.init_vec_file) > 0:
        vector_for_new_ent = np.load(args.init_vec_file)
    else:
        print("Setting init vector to be all zeros")
        vector_for_new_ent = np.zeros(weight_shape)
    new_model_state_dict = refit_weights(
        np_same_ents,
        neweid2oldeid,
        train_entity_profile,
        new_entity_profile,
        vector_for_new_ent,
        model_state_dict,
    )
    state_dict["model"] = new_model_state_dict
    print(f"Saving model at {args.save_model_path}")
    torch.save(state_dict, args.save_model_path)
    # Drop the (potentially large) weights before the title refit allocates more.
    del new_model_state_dict
    del state_dict

    # Refit titles
    # Collect every title embedding to adjust. If given a config, only the one
    # named there is adjusted. Otherwise, adjust all "static_table_" arrays of
    # shape (num_entities_with_pad_and_nocand, BERT_DIM).
    if not args.no_title_emb:
        title_embeddings = []
        prepped_title_emb_files = []
        title_emb_file = None
        prep_subdir = "prep"
        # First try to read entity_prep_dir from config
        if args.model_config is not None:
            with open(args.model_config) as config_f:
                config = yaml.load(config_f, Loader=yaml.FullLoader)
            prep_subdir = config["data_config"].get("entity_prep_dir", "prep")
            for ent in config["data_config"]["ent_embeddings"]:
                if ent["load_class"] == "StaticEmb" and ent["key"] == "title_static":
                    assert (
                        "emb_file" in ent["args"]
                    ), "emb_file needs to be in title_static config"
                    title_emb_file = ent["args"]["emb_file"]

        prep_dir = Path(args.train_entity_profile) / prep_subdir
        out_prep_dir = Path(args.new_entity_profile) / prep_subdir

        # If the config names the title embedding file, only its prepped
        # counterpart is accepted.
        expected_prep_name = None
        if title_emb_file is not None:
            emb_stem = os.path.splitext(os.path.basename(title_emb_file))[0]
            expected_prep_name = f"static_table_{emb_stem}.npy"
        expected_shape = [
            train_entity_profile.num_entities_with_pad_and_nocand,
            BERT_DIM,
        ]

        print(f"Looking for title embedding in {prep_dir}")
        # Try to find a saved title prep file. A missing prep directory simply
        # means there is nothing to refit; that case is reported below.
        candidate_files = prep_dir.iterdir() if prep_dir.is_dir() else []
        for prep_file in candidate_files:
            if not (prep_file.is_file() and prep_file.name.startswith("static_table_")):
                continue
            if expected_prep_name is not None and prep_file.name != expected_prep_name:
                continue
            possible_titles = np.load(prep_file, mmap_mode="r")
            if list(possible_titles.shape) == expected_shape:
                title_embeddings.append(possible_titles)
                prepped_title_emb_files.append(prep_file.name)
        if len(title_embeddings) == 0:
            print(
                "We were unable to adjust titles. If your model does not use title embeddings, ignore this. If your "
                "model does (all Bootleg models by default do), please call "
                "```python -m bootleg.utils.preprocessing.build_static_embeddings --help``` to extract manually. "
                "The saved file from this method should be added to ```emb_file``` config param for the title "
                "embedding."
            )
        else:
            for title_embed, prepped_title_emb_file in zip(
                title_embeddings, prepped_title_emb_files
            ):
                print(f"Attempting to refit title {prepped_title_emb_file}")
                out_prep_dir.mkdir(parents=True, exist_ok=True)
                save_file = out_prep_dir / prepped_title_emb_file
                print(f"Saving to {save_file}")
                # refit_titles returns the refit array; its result is not needed here.
                _ = refit_titles(
                    np_same_ents,
                    np_new_ents,
                    neweid2oldeid,
                    train_entity_profile,
                    new_entity_profile,
                    title_embed,
                    str(save_file),
                    args.bert_model,
                    args.bert_model_cache,
                    args.cpu,
                )

    if args.model_config is not None:
        modify_config(
            args.model_config,
            args.save_model_config,
            args.save_model_path,
            args.new_entity_profile,
        )
コード例 #13
0
ファイル: test_fit_to_profile.py プロジェクト: lorr1/bootleg
    def test_fit_titles(self):
        """Refitting titles copies rows of kept entities and BERT-encodes new ones.

        The training profile is extended with two new entities (Cobra, Snake),
        and Q123 is re-identified as Q321; the result is saved as the new
        profile. Old title rows 1-4 map one-to-one onto new rows 1-4, the new
        entities get the averaged BERT encoding of their titles, and the final
        row is expected to be zero.
        """
        bert_cache_dir = str(self.dir / "temp_bert_models")
        emb_dim = 768

        profile_to_edit = EntityProfile.load_from_cache(
            load_dir=self.train_db,
            edit_mode=True,
        )
        # Modify the train profile: add two entities and re-identify one QID
        cobra = {
            "entity_id": "Q910",
            "mentions": [["cobra", 10.0], ["animal", 3.0]],
            "title": "Cobra",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
            "relations": [{"relation": "sibling", "object": "Q123"}],
        }
        snake = {
            "entity_id": "Q101",
            "mentions": [["snake", 10.0], ["snakes", 7.0], ["animal", 3.0]],
            "title": "Snake",
            "types": {"hyena": ["animal"], "wiki": ["dog"]},
        }
        for extra_entity in (cobra, snake):
            profile_to_edit.add_entity(extra_entity)
        profile_to_edit.reidentify_entity("Q123", "Q321")
        # Save new profile
        profile_to_edit.save(self.new_db)

        # Old title embeddings: row i is filled with the value i
        title_embs = np.outer(np.arange(6, dtype=float), np.ones(emb_dim))

        train_entity_profile = EntityProfile.load_from_cache(load_dir=self.train_db)
        new_entity_profile = EntityProfile.load_from_cache(load_dir=self.new_db)

        new_title_embs = refit_titles(
            {"Q321", "Q345", "Q567", "Q789"},
            {"Q910", "Q101"},
            {1: 1, 2: 2, 3: 3, 4: 4},
            train_entity_profile,
            new_entity_profile,
            title_embs,
            str(self.dir / "temp_title.npy"),
            "bert-base-uncased",
            bert_cache_dir,
            True,
        )

        # Compute gold BERT titles for new entities
        tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True,
            cache_dir=bert_cache_dir,
        )
        bert = BertModel.from_pretrained(
            "bert-base-uncased",
            cache_dir=bert_cache_dir,
            output_attentions=False,
            output_hidden_states=False,
        )
        bert.eval()
        encoded = tokenizer(
            ["Cobra", "Snake"], padding=True, truncation=True, return_tensors="pt"
        )
        token_ids = encoded["input_ids"]
        hidden = bert(token_ids, attention_mask=encoded["attention_mask"])[0]
        # Zero out padding positions before averaging
        hidden[token_ids == 0] = 0
        avgtitle = average_titles(token_ids, hidden).to("cpu").detach().numpy()

        gold_title_embs = np.zeros((8, emb_dim))
        for row in range(5):
            gold_title_embs[row] = row
        gold_title_embs[5] = avgtitle[0]
        gold_title_embs[6] = avgtitle[1]
        # Row 7 stays zero: the fit method sets the last row to zero
        np.testing.assert_array_almost_equal(new_title_embs, gold_title_embs)