コード例 #1
0
    def test_end2end_withtype_singlethread(self):
        """Train with type prediction enabled, then dump embeddings.

        Checks that training returns a non-empty score dict with a low enough
        loss, and that ``dump_embs`` run against the last checkpoint writes a
        result file (one line per sentence) and an embedding file.
        """
        self.args.data_config.type_prediction.use_type_pred = True
        self.args.model_config.hidden_size = 20
        # Just setting this for testing pipelines
        self.args.data_config.eval_accumulation_steps = 2
        # Unfreezing the word embedding helps the type prediction task
        self.args.data_config.word_embedding.freeze = False
        scores = run_model(mode="train", config=self.args)
        assert isinstance(scores, dict)
        assert len(scores) > 0
        # Losses from two tasks contribute to this
        assert scores["model/all/train/loss"] < 0.08

        # Point both the run args and emmental's global config at the
        # checkpoint just written (presumably read by run_model to load weights).
        model_path = f"{emmental.Meta.log_path}/last_model.pth"
        self.args["model_config"]["model_path"] = model_path
        emmental.Meta.config["model_config"]["model_path"] = model_path

        result_file, out_emb_file = run_model(mode="dump_embs", config=self.args)
        assert os.path.exists(result_file)
        # Read inside a context manager so the file handle is always closed.
        with open(result_file) as in_f:
            results = [ujson.loads(li) for li in in_f]
        assert len(results) == 18  # 18 total sentences
        # The dumped contextual-embedding ids must cover exactly 0..50.
        ctx_emb_ids = {f for li in results for f in li["ctx_emb_ids"]}
        assert ctx_emb_ids == set(range(51))
        assert os.path.exists(out_emb_file)
コード例 #2
0
    def test_end2end_bert_long_context(self):
        """Train a BERTNED model with a longer context, then dump embeddings.

        BERTNED keeps only the first (learned) entity embedding, whose size is
        set to the hidden size, and the sentence projection is disabled. After
        training, ``dump_embs`` is run against the last checkpoint and its
        outputs are checked.
        """
        self.args.model_config.attn_class = "BERTNED"
        # Only take the learned entity embeddings for BERTNED
        self.args.data_config.ent_embeddings = self.args.data_config.ent_embeddings[:1]
        # Set the learned embedding to hidden size for BERTNED
        self.args.data_config.ent_embeddings[0].args.learned_embedding_size = 20
        self.args.data_config.word_embedding.use_sent_proj = False
        self.args.data_config.max_seq_len = 100
        self.args.data_config.max_aliases = 10
        scores = run_model(mode="train", config=self.args)
        assert isinstance(scores, dict)
        assert len(scores) > 0
        assert scores["model/all/train/loss"] < 0.5

        # Point both the run args and emmental's global config at the
        # checkpoint just written (presumably read by run_model to load weights).
        model_path = f"{emmental.Meta.log_path}/last_model.pth"
        self.args["model_config"]["model_path"] = model_path
        emmental.Meta.config["model_config"]["model_path"] = model_path

        result_file, out_emb_file = run_model(mode="dump_embs", config=self.args)
        assert os.path.exists(result_file)
        # Read inside a context manager so the file handle is always closed.
        with open(result_file) as in_f:
            results = [ujson.loads(li) for li in in_f]
        assert len(results) == 18  # 18 total sentences
        # The dumped contextual-embedding ids must cover exactly 0..50.
        ctx_emb_ids = {f for li in results for f in li["ctx_emb_ids"]}
        assert ctx_emb_ids == set(range(51))
        assert os.path.exists(out_emb_file)

        shutil.rmtree("test/temp", ignore_errors=True)
コード例 #3
0
    def test_end2end_withtitle_accstep(self):
        """Train with an extra title embedding and eval accumulation, then dump.

        Appends a ``TitleEmb`` entity embedding (sent through BERT, projected
        to 6 dims), uses eval accumulation steps and two dataset threads,
        trains, then runs ``dump_embs`` against the last checkpoint and checks
        its outputs.
        """
        self.args.data_config.ent_embeddings.append(
            DottedDict(
                {
                    "key": "title1",
                    "load_class": "TitleEmb",
                    "send_through_bert": True,
                    "args": {"proj": 6},
                }
            )
        )
        # Just setting this for testing pipelines
        self.args.data_config.eval_accumulation_steps = 2
        self.args.run_config.dataset_threads = 2
        scores = run_model(mode="train", config=self.args)
        assert isinstance(scores, dict)
        assert len(scores) > 0
        assert scores["model/all/train/loss"] < 0.08

        # Point both the run args and emmental's global config at the
        # checkpoint just written (presumably read by run_model to load weights).
        model_path = f"{emmental.Meta.log_path}/last_model.pth"
        self.args["model_config"]["model_path"] = model_path
        emmental.Meta.config["model_config"]["model_path"] = model_path

        result_file, out_emb_file = run_model(mode="dump_embs", config=self.args)
        assert os.path.exists(result_file)
        # Read inside a context manager so the file handle is always closed.
        with open(result_file) as in_f:
            results = [ujson.loads(li) for li in in_f]
        assert len(results) == 18  # 18 total sentences
        # The dumped contextual-embedding ids must cover exactly 0..50.
        ctx_emb_ids = {f for li in results for f in li["ctx_emb_ids"]}
        assert ctx_emb_ids == set(range(51))
        assert os.path.exists(out_emb_file)
コード例 #4
0
    def test_end2end_withkg(self):
        """Train with the default (KG-including) embeddings, then dump embeddings.

        Forces the ``fork`` multiprocessing start method first, then trains and
        runs ``dump_embs`` against the last checkpoint, checking that both
        output files exist and the result file has one line per sentence.
        """
        # For the collate and dataloaders to play nicely, the multiprocessing
        # start method must be "fork" (this is set in run.py)
        torch.multiprocessing.set_start_method("fork", force=True)

        scores = run_model(mode="train", config=self.args)

        assert isinstance(scores, dict)
        assert len(scores) > 0
        assert scores["model/all/train/loss"] < 0.08

        # Point both the run args and emmental's global config at the
        # checkpoint just written (presumably read by run_model to load weights).
        model_path = f"{emmental.Meta.log_path}/last_model.pth"
        self.args["model_config"]["model_path"] = model_path
        emmental.Meta.config["model_config"]["model_path"] = model_path

        result_file, out_emb_file = run_model(mode="dump_embs", config=self.args)
        assert os.path.exists(result_file)
        # Read inside a context manager so the file handle is always closed.
        with open(result_file) as in_f:
            results = [ujson.loads(li) for li in in_f]
        assert len(results) == 18  # 18 total sentences
        assert os.path.exists(out_emb_file)
コード例 #5
0
    def test_end2end_withreg_evalbatch(self):
        """Train with a per-entity regularization file and small eval batches.

        Writes a qid -> regularization CSV under ``test/temp``, wires it into
        the first entity embedding's ``regularize_mapping``, trains, then runs
        ``dump_embs`` against the last checkpoint and checks its outputs.
        The ``test/temp`` scratch directory is removed even if an assertion
        fails, so a failing run does not leave stale files behind.
        """
        reg_file = "test/temp/reg_file.csv"
        utils.ensure_dir("test/temp")
        try:
            reg_data = [
                ["qid", "regularization"],
                ["Q1", "0.5"],
                ["Q2", "0.3"],
                ["Q3", "0.2"],
                ["Q4", "0.9"],
            ]
            self.args.data_config.eval_accumulation_steps = 2
            self.args.run_config.dataset_threads = 2
            self.args.run_config.eval_batch_size = 2
            with open(reg_file, "w") as out_f:
                out_f.writelines(",".join(row) + "\n" for row in reg_data)

            self.args.data_config.ent_embeddings[0]["args"][
                "regularize_mapping"
            ] = reg_file
            scores = run_model(mode="train", config=self.args)
            assert isinstance(scores, dict)
            assert len(scores) > 0
            assert scores["model/all/train/loss"] < 0.05

            # Point both the run args and emmental's global config at the
            # checkpoint just written (presumably read by run_model).
            model_path = f"{emmental.Meta.log_path}/last_model.pth"
            self.args["model_config"]["model_path"] = model_path
            emmental.Meta.config["model_config"]["model_path"] = model_path

            result_file, out_emb_file = run_model(
                mode="dump_embs", config=self.args
            )
            assert os.path.exists(result_file)
            # Read inside a context manager so the file handle is always closed.
            with open(result_file) as in_f:
                results = [ujson.loads(li) for li in in_f]
            assert len(results) == 18  # 18 total sentences
            # The dumped contextual-embedding ids must cover exactly 0..50.
            ctx_emb_ids = {f for li in results for f in li["ctx_emb_ids"]}
            assert ctx_emb_ids == set(range(51))
            assert os.path.exists(out_emb_file)
        finally:
            # Always remove the scratch directory this test created.
            shutil.rmtree("test/temp", ignore_errors=True)
コード例 #6
0
    def test_end2end_withoutkg(self):
        """Train without the KG embedding (and one alias per example), then dump.

        Drops the last entity embedding (the KG one), limits ``max_aliases``
        to 1, trains, then runs ``dump_embs`` against the last checkpoint and
        checks its outputs.
        """
        # KG IS LAST EMBEDDING SO WE REMOVE IT
        self.args.data_config.ent_embeddings = self.args.data_config.ent_embeddings[:-1]
        # Just setting this for testing pipelines
        self.args.data_config.max_aliases = 1
        scores = run_model(mode="train", config=self.args)
        assert isinstance(scores, dict)
        assert len(scores) > 0
        assert scores["model/all/train/loss"] < 0.05

        # Point both the run args and emmental's global config at the
        # checkpoint just written (presumably read by run_model to load weights).
        model_path = f"{emmental.Meta.log_path}/last_model.pth"
        self.args["model_config"]["model_path"] = model_path
        emmental.Meta.config["model_config"]["model_path"] = model_path

        result_file, out_emb_file = run_model(mode="dump_embs", config=self.args)
        assert os.path.exists(result_file)
        # Read inside a context manager so the file handle is always closed.
        with open(result_file) as in_f:
            results = [ujson.loads(li) for li in in_f]
        assert len(results) == 18  # 18 total sentences
        # The dumped contextual-embedding ids must cover exactly 0..50.
        ctx_emb_ids = {f for li in results for f in li["ctx_emb_ids"]}
        assert ctx_emb_ids == set(range(51))
        assert os.path.exists(out_emb_file)
コード例 #7
0
ファイル: bootleg.py プロジェクト: stanford-oval/genienlp
 def disambiguate_mentions(self, config_args):
     """Run Bootleg over ``config_args`` in the configured dump mode.

     The mode is taken from ``self.args.bootleg_dump_mode``. Whatever
     ``run_model`` returns is discarded; this method returns ``None``.
     """
     dump_mode = self.args.bootleg_dump_mode
     run_model(dump_mode, config_args)
コード例 #8
0
    def test_end2end(self):
        """Train on an entity profile, edit the profile, and keep training.

        Part 1 builds a tiny entity profile and training set, saves the
        profile, and trains a small model. Part 2 edits the profile: Q123's
        "wiki" types gain "cat" and lose "dog", and Q123 gets a "cat" mention
        with score 100.0. The edited profile is saved to a second directory
        and training continues from the part-1 checkpoint. Part 3 checks that
        the prepped type and alias tables changed exactly as expected. It also
        checks that the frozen type embeddings were not updated.
        """
        # ======================
        # PART 1: TRAIN A SMALL MODEL WITH ONE PROFILE DUMP
        # ======================
        # Generate entity profile data
        data = [
            {
                "entity_id": "Q123",
                "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
                "title": "Dog",
                "types": {"hyena": ["animal"], "wiki": ["dog"]},
                "relations": [
                    {"relation": "sibling", "object": "Q345"},
                    {"relation": "sibling", "object": "Q567"},
                ],
            },
            {
                "entity_id": "Q345",
                "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
                "title": "Cat",
                "types": {"hyena": ["animal"], "wiki": ["cat"]},
                "relations": [{"relation": "sibling", "object": "Q123"}],
            },
            # Missing type system
            {
                "entity_id": "Q567",
                "mentions": [["catt", 6.5], ["animal", 3.3]],
                "title": "Catt",
                "types": {"hyena": ["animal", "animall"]},
                "relations": [{"relation": "sibling", "object": "Q123"}],
            },
            # No KG/Types
            {
                "entity_id": "Q789",
                "mentions": [["animal", 12.2]],
                "title": "Dogg",
            },
        ]
        # Generate train data
        train_data = [
            {
                "sent_idx_unq": 1,
                "sentence": "I love animals and dogs",
                "qids": ["Q567", "Q123"],
                "aliases": ["animal", "dog"],
                "gold": [True, True],
                "spans": [[2, 3], [4, 5]],
            }
        ]
        self.write_data(self.profile_file, data)
        self.write_data(self.train_data, train_data)
        entity_profile = EntityProfile.load_from_jsonl(self.profile_file)
        # Dump profile data in format for model
        entity_profile.save(self.save_dir)

        # Setup model args to read the data/new profile data
        raw_args = {
            "emmental": {
                "n_epochs": 1,
            },
            "run_config": {
                "dataloader_threads": 1,
                "dataset_threads": 1,
            },
            "train_config": {"batch_size": 2},
            "model_config": {"hidden_size": 20, "num_heads": 1},
            "data_config": {
                "entity_dir": str(self.save_dir),
                "max_seq_len": 7,
                "max_aliases": 2,
                "data_dir": str(self.data_dir),
                "emb_dir": str(self.save_dir),
                "word_embedding": {
                    "layers": 1,
                    "freeze": True,
                    "cache_dir": str(self.save_dir / "retrained_bert_model"),
                },
                "ent_embeddings": [
                    {
                        "key": "learned",
                        "freeze": False,
                        "load_class": "LearnedEntityEmb",
                        "args": {"learned_embedding_size": 10},
                    },
                    {
                        "key": "learned_type",
                        "load_class": "LearnedTypeEmb",
                        "freeze": True,
                        "args": {
                            "type_labels": f"{TYPE_SUBFOLDER}/wiki/qid2typeids.json",
                            "type_vocab": f"{TYPE_SUBFOLDER}/wiki/type_vocab.json",
                            "max_types": 2,
                            "type_dim": 10,
                        },
                    },
                    {
                        "key": "kg_adj",
                        "load_class": "KGIndices",
                        "batch_on_the_fly": True,
                        "normalize": False,
                        "args": {"kg_adj": f"{KG_SUBFOLDER}/kg_adj.txt"},
                    },
                ],
                "train_dataset": {"file": "train.jsonl"},
                "dev_dataset": {"file": "train.jsonl"},
                "test_dataset": {"file": "train.jsonl"},
            },
        }
        with open(self.arg_file, "w") as out_f:
            ujson.dump(raw_args, out_f)

        args = parser_utils.parse_boot_and_emm_args(str(self.arg_file))
        # This _MUST_ get passed the args so it gets a random seed set
        emmental.init(log_dir=str(self.dir / "temp_log"), config=args)
        if not os.path.exists(emmental.Meta.log_path):
            os.makedirs(emmental.Meta.log_path)

        scores = run_model(mode="train", config=args)
        saved_model_path1 = f"{emmental.Meta.log_path}/last_model.pth"
        assert isinstance(scores, dict)

        # ======================
        # PART 2: MODIFY PROFILE AND LOAD PRETRAINED MODEL AND TRAIN FOR MORE
        # ======================
        entity_profile = EntityProfile.load_from_jsonl(
            self.profile_file, edit_mode=True
        )
        entity_profile.add_type("Q123", "cat", "wiki")
        entity_profile.remove_type("Q123", "dog", "wiki")
        entity_profile.add_mention("Q123", "cat", 100.0)
        # Dump profile data in format for model
        entity_profile.save(self.save_dir2)

        # Modify arg paths
        args["data_config"]["entity_dir"] = str(self.save_dir2)
        args["data_config"]["emb_dir"] = str(self.save_dir2)

        # Load the checkpoint saved at the end of part 1 (same path in both configs)
        args["model_config"]["model_path"] = saved_model_path1
        emmental.Meta.config["model_config"]["model_path"] = saved_model_path1

        # Init another run
        emmental.init(log_dir=str(self.dir / "temp_log"), config=args)
        if not os.path.exists(emmental.Meta.log_path):
            os.makedirs(emmental.Meta.log_path)
        scores = run_model(mode="train", config=args)
        saved_model_path2 = f"{emmental.Meta.log_path}/last_model.pth"
        assert isinstance(scores, dict)

        # ======================
        # PART 3: VERIFY CHANGES IN THE MODEL WERE AS EXPECTED
        # ======================
        # Check that type mappings are different in the right way... we removed
        # "dog" from EID 1 and added "cat". "dog" is no longer a type.
        type_table_name = "type_table_type_mappings_wiki_qid2typeids_2.pt"
        eid2typeids_table1, type2row_dict1, num_types_with_unk1 = torch.load(
            self.save_dir / "prep" / type_table_name
        )
        eid2typeids_table2, type2row_dict2, num_types_with_unk2 = torch.load(
            self.save_dir2 / "prep" / type_table_name
        )
        # Modify mapping 2 to become mapping 1
        # Row 1 is Q123, Col 0 is type (this was "cat")
        eid2typeids_table2[1][0] = entity_profile._type_systems["wiki"]._type_vocab[
            "dog"
        ]
        self.assertEqual(num_types_with_unk1, num_types_with_unk2)
        self.assertDictEqual({1: 1, 2: 2}, type2row_dict1)
        self.assertDictEqual({1: 1}, type2row_dict2)
        assert torch.equal(eid2typeids_table1, eid2typeids_table2)

        # Check that the alias mappings are different. The prepped alias tables
        # are int64 memmaps of shape (num_aliases, max_cands); unused candidate
        # slots hold -1.
        num_aliases, max_cands = 5, 30
        alias_table_name = "alias2entity_table_alias2qids_InC1.pt"

        def load_alias_table(entity_dir):
            """Read the prepped alias -> entity-id table saved under entity_dir."""
            return torch.from_numpy(
                np.memmap(
                    entity_dir / "prep" / alias_table_name,
                    dtype="int64",
                    mode="r",
                    shape=(num_aliases, max_cands),
                )
            )

        def padded_table(rows):
            """Build an expected alias table, right-padding each row with -1."""
            return torch.tensor([row + [-1] * (max_cands - len(row)) for row in rows])

        alias2entity_table1 = load_alias_table(self.save_dir)
        gold_alias2entity_table1 = padded_table(
            [[4, 1, 3, 2], [2], [2, 3], [1], [1]]
        )
        assert torch.equal(alias2entity_table1, gold_alias2entity_table1)
        # The change is the "cat" alias has entity 1 added to the beginning
        # It used to only point to Q345 which is entity 2
        alias2entity_table2 = load_alias_table(self.save_dir2)
        gold_alias2entity_table2 = padded_table(
            [[4, 1, 3, 2], [1, 2], [2, 3], [1], [1]]
        )
        assert torch.equal(alias2entity_table2, gold_alias2entity_table2)

        # The type embeddings were frozen so they should be the same
        model1 = torch.load(saved_model_path1)
        model2 = torch.load(saved_model_path2)
        assert torch.equal(
            model1["model"]["module_pool"]["learned_type"]["type_emb.weight"],
            model2["model"]["module_pool"]["learned_type"]["type_emb.weight"],
        )
コード例 #9
0
ファイル: test_annotator.py プロジェクト: lorr1/bootleg
    def test_annotator(self):
        """Train a small model, then exercise BootlegAnnotator end to end.

        Covers a single text, a batch of long multi-sentence texts, returning
        embeddings, and caller-supplied mentions with custom candidate lists.
        """
        torch.multiprocessing.set_start_method("fork", force=True)
        # Just to make it go faster
        self.args["learner_config"]["n_epochs"] = 5
        # First train some model so we have it stored
        run_model(mode="train", config=self.args)

        # Point both the run args and emmental's global config at the
        # checkpoint just written so the annotator is built from it.
        model_path = f"{emmental.Meta.log_path}/last_model.pth"
        self.args["model_config"]["model_path"] = model_path
        emmental.Meta.config["model_config"]["model_path"] = model_path

        ann = BootlegAnnotator(config=self.args, verbose=True)
        sentence = "alias1 alias2 multi word alias3 I have no idea"

        # TEST SINGLE TEXT
        # Res should have alias1
        res = ann.label_mentions(sentence)
        gold_ans = {
            "qids": [["Q1"]],
            "titles": [["alias1"]],
            "cands": [[["Q1", "Q4", "-1"]]],
            "spans": [[[0, 1]]],
            "aliases": [["alias1"]],
        }
        for k in gold_ans:
            self.assertListEqual(gold_ans[k], res[k])

        # TEST LONG TEXT
        # Three identical texts, each made of `num_copies` copies of the
        # sentence joined by ". ". The gold spans expect alias1 every 9 tokens.
        num_copies = 8
        long_text = f"{sentence}. " * (num_copies - 1) + sentence
        res = ann.label_mentions([long_text] * 3)
        long_spans = [[9 * i, 9 * i + 1] for i in range(num_copies)]
        gold_ans = {
            "qids": [["Q1"] * num_copies] * 3,
            "titles": [["alias1"] * num_copies] * 3,
            "cands": [[["Q1", "Q4", "-1"]] * num_copies] * 3,
            "spans": [long_spans] * 3,
            "aliases": [["alias1"] * num_copies] * 3,
        }
        for k in gold_ans:
            self.assertListEqual(gold_ans[k], res[k])

        # TEST RETURN EMBS
        ann.return_embs = True
        res = ann.label_mentions(sentence)
        assert "embs" in res
        assert res["embs"][0][0].shape[0] == 20
        assert list(res["cand_embs"][0][0].shape) == [3, 20]

        # TEST CUSTOM CANDS
        ann.return_embs = False
        extracted_exs = [
            {
                "sentence": sentence,
                "aliases": ["alias3"],
                "spans": [[0, 1]],
                "cands": [["Q3"]],
            },
            {
                "sentence": f"{sentence}. " * 2,
                "aliases": ["alias1", "alias3", "alias1"],
                "spans": [[0, 1], [1, 2], [9, 10]],
                "cands": [["Q2"], ["Q3"], ["Q2"]],
            },
        ]
        res = ann.label_mentions(extracted_examples=extracted_exs)
        gold_ans = {
            "qids": [["Q3"], ["Q2", "Q3", "Q2"]],
            "titles": [
                ["word alias3"],
                ["multi alias2", "word alias3", "multi alias2"],
            ],
            "cands": [
                [["Q3", "-1", "-1"]],
                [["Q2", "-1", "-1"], ["Q3", "-1", "-1"], ["Q2", "-1", "-1"]],
            ],
            "spans": [[[0, 1]], [[0, 1], [1, 2], [9, 10]]],
            "aliases": [["alias3"], ["alias1", "alias3", "alias1"]],
        }
        for k in gold_ans:
            self.assertListEqual(gold_ans[k], res[k])
コード例 #10
0
config_args = load_yaml_file(config_in_path)

# Run settings: use fewer dataset threads since the input file is small,
# and log at info level.
config_args["run_config"].update(dataset_threads=2, log_level="info")

# Emmental settings: the model checkpoint path and where results are logged.
config_args["emmental"].update(
    model_path=str(root_dir / "models/bootleg_wiki/bootleg_wiki.pth"),
    log_path=str(Path(args.data_dir) / "bootleg_results"),
)

# Data settings: entity db, candidate map, input data dir, and embedding dir.
config_args["data_config"].update(
    entity_dir=str(root_dir / "data/wiki_entity_data"),
    alias_cand_map=str(cand_map),
    data_dir=args.data_dir,
    emb_dir=str(root_dir / "data/emb_data"),
)
config_args["data_config"]["test_dataset"]["file"] = args.outfile_name
config_args["data_config"]["word_embedding"]["cache_dir"] = str(
    root_dir / "data/emb_data/pretrained_bert_models"
)

# Parse the config dict into Bootleg/Emmental args
# (a config file path such as config_out_path can be passed instead).
config_args = parse_boot_and_emm_args(config_args)

bootleg_label_file, bootleg_emb_file = run_model(mode="dump_embs", config=config_args)