Example #1
def create_config(model_path, data_path, model_name):
    """Creates Bootleg config.

    Args:
        model_path: model directory
        data_path: data directory
        model_name: model name

    Returns: updated config
    """
    config_file = model_path / model_name / "bootleg_config.yaml"
    config_args = load_yaml_file(config_file)

    # set the model checkpoint path
    config_args["emmental"]["model_path"] = str(model_path / model_name /
                                                "bootleg_wiki.pth")

    # set the path for the entity db and candidate map
    config_args["data_config"]["entity_dir"] = str(data_path / "entity_db")
    config_args["data_config"]["alias_cand_map"] = "alias2qids.json"

    # set the embedding paths
    config_args["data_config"]["emb_dir"] = str(data_path / "entity_db")
    config_args["data_config"]["word_embedding"]["cache_dir"] = str(
        data_path / "pretrained_bert_models")

    # set log path
    config_args["emmental"]["log_path"] = str(data_path / "log_dir")

    config_args = parse_boot_and_emm_args(config_args)
    return config_args
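A minimal usage sketch (not from the source; the cache location is hypothetical, while the "models"/"data" subdirectory names and the "bootleg_uncased" model name mirror the annotator code in Example #13):

from pathlib import Path

# Hypothetical cache layout; point these at wherever the Bootleg model and data were downloaded.
cache_dir = Path.home() / ".cache" / "bootleg"
config = create_config(cache_dir / "models", cache_dir / "data", "bootleg_uncased")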
Example #2
 def setUp(self) -> None:
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_end2end.json")
     # This _MUST_ get passed the args so it gets a random seed set
     emmental.init(log_dir="test/temp_log", config=self.args)
     if not os.path.exists(emmental.Meta.log_path):
         os.makedirs(emmental.Meta.log_path)
Example #3
 def setUp(self) -> None:
     emmental.init(log_dir="test/temp_log")
     if not os.path.exists(emmental.Meta.log_path):
         os.makedirs(emmental.Meta.log_path)
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_embeddings.json"
     )
     self.entity_symbols = EntitySymbolsSubclass()
     self.tokenizer = load_tokenizer()
Example #4
 def setUp(self) -> None:
     emmental.init(log_dir="test/temp_log")
     if not os.path.exists(emmental.Meta.log_path):
         os.makedirs(emmental.Meta.log_path)
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_embeddings.json"
     )
     self.entity_symbols = EntitySymbolsSubclass()
     self.kg_adj = os.path.join(self.args.data_config.emb_dir, "kg_adj.txt")
     self.kg_adj_json = os.path.join(self.args.data_config.emb_dir, "kg_adj.json")
Example #5
 def setUp(self) -> None:
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_embeddings.json"
     )
     emmental.init(log_dir="test/temp_log", config=self.args)
     if not os.path.exists(emmental.Meta.log_path):
         os.makedirs(emmental.Meta.log_path)
     self.args.data_config.ent_embeddings = [
         DottedDict(
             {
                 "key": "learned1",
                 "load_class": "LearnedEntityEmb",
                 "dropout1d": 0.5,
                 "args": {"learned_embedding_size": 5, "tail_init": False},
             }
         ),
         DottedDict(
             {
                 "key": "learned2",
                 "dropout2d": 0.5,
                 "load_class": "LearnedEntityEmb",
                 "args": {"learned_embedding_size": 5, "tail_init": False},
             }
         ),
         DottedDict(
             {
                 "key": "learned3",
                 "load_class": "LearnedEntityEmb",
                 "freeze": True,
                 "args": {"learned_embedding_size": 5, "tail_init": False},
             }
         ),
         DottedDict(
             {
                 "key": "learned4",
                 "load_class": "LearnedEntityEmb",
                 "normalize": False,
                 "args": {"learned_embedding_size": 5, "tail_init": False},
             }
         ),
         DottedDict(
             {
                 "key": "learned5",
                 "load_class": "LearnedEntityEmb",
                 "cpu": True,
                 "args": {"learned_embedding_size": 5, "tail_init": False},
             }
         ),
     ]
     self.tokenizer = load_tokenizer()
     self.entity_symbols = EntitySymbolsSubclass()
Example #6
def parse_cmdline_args():
    """Takes an input config file and parses it into the correct subdictionary
    groups for the model.

    Returns:
        model run mode: train, eval, dump_preds, or dump_embs
        parsed Dict config
        path to the original config file
    """
    # Parse cmdline args to specify config and mode
    cli_parser = argparse.ArgumentParser(
        description="Bootleg CLI Config",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    cli_parser.add_argument(
        "--config_script",
        type=str,
        default="",
        help=
        "Should mimic the config_args found in utils/parser/bootleg_args.py with parameters you want to override."
        "You can also override the parameters from config_script by passing them in directly after config_script. "
        "E.g., --train_config.batch_size 5",
    )
    cli_parser.add_argument(
        "--mode",
        type=str,
        default="train",
        choices=["train", "eval", "dump_preds", "dump_embs"],
    )
    cli_parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help=
        "When using torch.distributed it passes local_rank as command arg. We must capture it here.",
    )

    # You can add other args that will override those in the config_script.
    # parse_known_args returns 'args' (the same as what parse_args() returns)
    # and 'unknown' (args the parser doesn't recognize but that we want to keep).
    # The 'unknown' args are passed on to override args in the second phase of
    # arg parsing from the json file.
    cli_args, unknown = cli_parser.parse_known_args()
    if len(cli_args.config_script) == 0:
        raise ValueError(f"You must pass a config script via --config.")
    config = parse_boot_and_emm_args(cli_args.config_script, unknown)

    #  Modify the local rank param from the cli args
    config.learner_config.local_rank = cli_args.local_rank
    mode = cli_args.mode
    return mode, config, cli_args.config_script
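A short sketch of an assumed entry point (not shown in the source) wiring the parsed mode and config into run_model:

if __name__ == "__main__":
    # mode is one of train/eval/dump_preds/dump_embs; config is the merged Bootleg/Emmental config
    mode, config, config_path = parse_cmdline_args()
    run_model(mode=mode, config=config)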
Example #7
 def setUp(self) -> None:
     emmental.init(log_dir="test/temp_log")
     if not os.path.exists(emmental.Meta.log_path):
         os.makedirs(emmental.Meta.log_path)
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_embeddings.json"
     )
     self.entity_symbols = EntitySymbolsSubclass()
     self.entity_symbols_extra = EntitySymbolsSubclassExtra()
     self.type_file = os.path.join(
         self.args.data_config.emb_dir, "temp_type_file.json"
     )
     self.type_vocab_file = os.path.join(
         self.args.data_config.emb_dir, "temp_type_vocab.json"
     )
     self.regularization_csv = os.path.join(
         self.args.data_config.data_dir, "test_reg.csv"
     )
Example #8
 def setUp(self):
     """ENTITY SYMBOLS
      {
        "multi word alias2":[["Q2",5.0],["Q1",3.0],["Q4",2.0]],
        "alias1":[["Q1",10.0],["Q4",6.0]],
        "alias3":[["Q1",30.0]],
        "alias4":[["Q4",20.0],["Q3",15.0],["Q2",1.0]]
      }
      EMBEDDINGS
      {
          "key": "learned",
          "freeze": false,
          "load_class": "LearnedEntityEmb",
          "args":
          {
            "learned_embedding_size": 10,
          }
      },
      {
         "key": "learned_type",
         "load_class": "LearnedTypeEmb",
         "freeze": false,
         "args": {
             "type_labels": "type_pred_mapping.json",
             "max_types": 1,
             "type_dim": 5,
             "merge_func": "addattn",
             "attn_hidden_size": 5
         }
     }
     """
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_model_training.json")
     self.entity_symbols = EntitySymbols.load_from_cache(
         os.path.join(self.args.data_config.entity_dir,
                      self.args.data_config.entity_map_dir),
         alias_cand_map_file=self.args.data_config.alias_cand_map,
     )
     emmental.init(log_dir="test/temp_log")
     if not os.path.exists(emmental.Meta.log_path):
         os.makedirs(emmental.Meta.log_path)
Example #9
 def setUp(self):
     emmental.init(log_dir="test/temp_log")
     if not os.path.exists(emmental.Meta.log_path):
         os.makedirs(emmental.Meta.log_path)
     self.entity_symbols = EntitySymbolsSubclass()
     self.hidden_size = 30
     self.learned_embedding_size = 50
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_embeddings.json"
     )
     self.regularization_csv = os.path.join(
         self.args.data_config.data_dir, "test_reg.csv"
     )
     self.static_emb = os.path.join(self.args.data_config.data_dir, "static_emb.pt")
     self.qid2topkeid = os.path.join(
         self.args.data_config.data_dir, "test_eid2topk.json"
     )
     self.args.model_config.hidden_size = self.hidden_size
     self.args.data_config.ent_embeddings[0]["args"] = DottedDict(
         {"learned_embedding_size": self.learned_embedding_size}
     )
Example #10
 def setUp(self):
     # tests that the sampling is done correctly on indices
     # load data from directory
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_type_data.json")
     self.tokenizer = BertTokenizer.from_pretrained(
         "bert-base-cased",
         do_lower_case=False,
         cache_dir="test/data/emb_data/pretrained_bert_models",
     )
     self.is_bert = True
     self.entity_symbols = EntitySymbols.load_from_cache(
         os.path.join(self.args.data_config.entity_dir,
                      self.args.data_config.entity_map_dir),
         alias_cand_map_file=self.args.data_config.alias_cand_map,
     )
     self.temp_file_name = "test/data/data_loader/test_data.jsonl"
     self.guid_dtype = lambda max_aliases: np.dtype([
         ("sent_idx", "i8", 1),
         ("subsent_idx", "i8", 1),
         ("alias_orig_list_pos", "i8", max_aliases),
     ])
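A minimal, self-contained illustration (hypothetical values, not from the source) of building one guid record with this structured dtype for max_aliases=2:

import numpy as np

guid_dtype = np.dtype([
    ("sent_idx", "i8", 1),
    ("subsent_idx", "i8", 1),
    ("alias_orig_list_pos", "i8", 2),
])
record = np.zeros(1, dtype=guid_dtype)
record["sent_idx"] = 1
record["subsent_idx"] = 0
record["alias_orig_list_pos"] = [0, 1]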
Example #11
 def create_config(self, overrides):
     config_args = parse_boot_and_emm_args(self.config_path, overrides)
     return config_args
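A hedged usage sketch: overrides are assumed to be CLI-style tokens, mirroring the --train_config.batch_size 5 pattern described in parse_cmdline_args (Example #6), and self.config_path is assumed to point at a Bootleg json/yaml config.

config_args = self.create_config(["--train_config.batch_size", "5"])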
Example #12
 def setUp(self) -> None:
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_embeddings.json"
     )
     self.entity_symbols = EntitySymbolsSubclass()
Example #13
    def __init__(
        self,
        config=None,
        device=None,
        max_alias_len=6,
        cand_map=None,
        threshold=0.0,
        cache_dir=None,
        model_name=None,
        verbose=False,
    ):
        self.max_alias_len = max_alias_len  # maximum alias length (in words) used for mention extraction
        self.verbose = verbose
        self.threshold = threshold  # minimum prediction probability to return a mention

        if not cache_dir:
            self.cache_dir = get_default_cache()
            self.model_path = self.cache_dir / "models"
            self.data_path = self.cache_dir / "data"
        else:
            self.cache_dir = Path(cache_dir)
            self.model_path = self.cache_dir / "models"
            self.data_path = self.cache_dir / "data"

        if not model_name:
            model_name = "bootleg_uncased"

        assert model_name in {
            "bootleg_cased",
            "bootleg_cased_mini",
            "bootleg_uncased",
            "bootleg_uncased_mini",
        }, (f"model_name must be one of [bootleg_cased, bootleg_cased_mini, "
            f"bootleg_uncased_mini, bootleg_uncased]. You have {model_name}.")

        if not config:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self.model_path.mkdir(parents=True, exist_ok=True)
            self.data_path.mkdir(parents=True, exist_ok=True)
            create_sources(self.model_path, self.data_path, model_name)
            self.config = create_config(self.model_path, self.data_path,
                                        model_name)
        else:
            if "emmental" in config:
                config = parse_boot_and_emm_args(config)
            self.config = config
            # Ensure some of the critical annotator args are the correct type
            self.config.data_config.max_aliases = int(
                self.config.data_config.max_aliases)
            self.config.run_config.eval_batch_size = int(
                self.config.run_config.eval_batch_size)
            self.config.data_config.max_seq_len = int(
                self.config.data_config.max_seq_len)
            self.config.data_config.train_in_candidates = bool(
                self.config.data_config.train_in_candidates)

        if not device:
            device = 0 if torch.cuda.is_available() else -1

        if self.verbose:
            self.config.run_config.log_level = "DEBUG"
        else:
            self.config.run_config.log_level = "INFO"

        self.torch_device = (torch.device(device)
                             if device != -1 else torch.device("cpu"))
        self.config.model_config.device = device

        log_level = logging.getLevelName(
            self.config["run_config"]["log_level"].upper())
        emmental.init(
            log_dir=self.config["meta_config"]["log_path"],
            config=self.config,
            use_exact_log_path=self.config["meta_config"]
            ["use_exact_log_path"],
            level=log_level,
        )

        logger.debug("Reading entity database")
        self.entity_db = EntitySymbols.load_from_cache(
            os.path.join(
                self.config.data_config.entity_dir,
                self.config.data_config.entity_map_dir,
            ),
            alias_cand_map_file=self.config.data_config.alias_cand_map,
            alias_idx_file=self.config.data_config.alias_idx_map,
        )
        logger.debug("Reading word tokenizers")
        self.tokenizer = BertTokenizer.from_pretrained(
            self.config.data_config.word_embedding.bert_model,
            do_lower_case=True if "uncased"
            in self.config.data_config.word_embedding.bert_model else False,
            cache_dir=self.config.data_config.word_embedding.cache_dir,
        )

        # Create tasks
        tasks = [NED_TASK]
        if self.config.data_config.type_prediction.use_type_pred is True:
            tasks.append(TYPE_PRED_TASK)
        self.task_to_label_dict = {t: NED_TASK_TO_LABEL[t] for t in tasks}

        # Build the model and add the tasks
        self.model = EmmentalModel(name="Bootleg")
        self.model.add_task(ned_task.create_task(self.config, self.entity_db))
        if TYPE_PRED_TASK in tasks:
            self.model.add_task(
                type_pred_task.create_task(self.config, self.entity_db))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(self.model)

        logger.debug("Loading model")
        # Load the best model from the pretrained model
        assert (
            self.config["model_config"]["model_path"] is not None
        ), f"Must have a model to load in the model_path for the BootlegAnnotator"
        self.model.load(self.config["model_config"]["model_path"])
        self.model.eval()
        if cand_map is None:
            alias_map = self.entity_db.get_alias2qids()
        else:
            logger.debug(f"Loading candidate map")
            alias_map = ujson.load(open(cand_map))

        self.all_aliases_trie = get_all_aliases(alias_map, verbose)

        logger.debug("Reading in alias table")
        self.alias2cands = AliasEntityTable(
            data_config=self.config.data_config, entity_symbols=self.entity_db)

        # get batch_on_the_fly embeddings
        self.batch_on_the_fly_embs = get_dataloader_embeddings(
            self.config, self.entity_db)
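A hedged instantiation sketch; BootlegAnnotator is assumed to be the class this __init__ belongs to (the assertion message above refers to it by that name):

# device=-1 forces CPU; with config=None the default cached model and data are used.
ann = BootlegAnnotator(model_name="bootleg_uncased", device=-1, verbose=False)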
Example #14
    def test_end2end(self):
        # ======================
        # PART 1: TRAIN A SMALL MODEL WITH ONE PROFILE DUMP
        # ======================
        # Generate entity profile data
        data = [
            {
                "entity_id": "Q123",
                "mentions": [["dog", 10.0], ["dogg", 7.0], ["animal", 4.0]],
                "title": "Dog",
                "types": {"hyena": ["animal"], "wiki": ["dog"]},
                "relations": [
                    {"relation": "sibling", "object": "Q345"},
                    {"relation": "sibling", "object": "Q567"},
                ],
            },
            {
                "entity_id": "Q345",
                "mentions": [["cat", 10.0], ["catt", 7.0], ["animal", 3.0]],
                "title": "Cat",
                "types": {"hyena": ["animal"], "wiki": ["cat"]},
                "relations": [{"relation": "sibling", "object": "Q123"}],
            },
            # Missing type system
            {
                "entity_id": "Q567",
                "mentions": [["catt", 6.5], ["animal", 3.3]],
                "title": "Catt",
                "types": {"hyena": ["animal", "animall"]},
                "relations": [{"relation": "sibling", "object": "Q123"}],
            },
            # No KG/Types
            {
                "entity_id": "Q789",
                "mentions": [["animal", 12.2]],
                "title": "Dogg",
            },
        ]
        # Generate train data
        train_data = [
            {
                "sent_idx_unq": 1,
                "sentence": "I love animals and dogs",
                "qids": ["Q567", "Q123"],
                "aliases": ["animal", "dog"],
                "gold": [True, True],
                "spans": [[2, 3], [4, 5]],
            }
        ]
        self.write_data(self.profile_file, data)
        self.write_data(self.train_data, train_data)
        entity_profile = EntityProfile.load_from_jsonl(self.profile_file)
        # Dump profile data in format for model
        entity_profile.save(self.save_dir)

        # Setup model args to read the data/new profile data
        raw_args = {
            "emmental": {
                "n_epochs": 1,
            },
            "run_config": {
                "dataloader_threads": 1,
                "dataset_threads": 1,
            },
            "train_config": {"batch_size": 2},
            "model_config": {"hidden_size": 20, "num_heads": 1},
            "data_config": {
                "entity_dir": str(self.save_dir),
                "max_seq_len": 7,
                "max_aliases": 2,
                "data_dir": str(self.data_dir),
                "emb_dir": str(self.save_dir),
                "word_embedding": {
                    "layers": 1,
                    "freeze": True,
                    "cache_dir": str(self.save_dir / "retrained_bert_model"),
                },
                "ent_embeddings": [
                    {
                        "key": "learned",
                        "freeze": False,
                        "load_class": "LearnedEntityEmb",
                        "args": {"learned_embedding_size": 10},
                    },
                    {
                        "key": "learned_type",
                        "load_class": "LearnedTypeEmb",
                        "freeze": True,
                        "args": {
                            "type_labels": f"{TYPE_SUBFOLDER}/wiki/qid2typeids.json",
                            "type_vocab": f"{TYPE_SUBFOLDER}/wiki/type_vocab.json",
                            "max_types": 2,
                            "type_dim": 10,
                        },
                    },
                    {
                        "key": "kg_adj",
                        "load_class": "KGIndices",
                        "batch_on_the_fly": True,
                        "normalize": False,
                        "args": {"kg_adj": f"{KG_SUBFOLDER}/kg_adj.txt"},
                    },
                ],
                "train_dataset": {"file": "train.jsonl"},
                "dev_dataset": {"file": "train.jsonl"},
                "test_dataset": {"file": "train.jsonl"},
            },
        }
        with open(self.arg_file, "w") as out_f:
            ujson.dump(raw_args, out_f)

        args = parser_utils.parse_boot_and_emm_args(str(self.arg_file))
        # This _MUST_ get passed the args so it gets a random seed set
        emmental.init(log_dir=str(self.dir / "temp_log"), config=args)
        if not os.path.exists(emmental.Meta.log_path):
            os.makedirs(emmental.Meta.log_path)

        scores = run_model(mode="train", config=args)
        saved_model_path1 = f"{emmental.Meta.log_path}/last_model.pth"
        assert type(scores) is dict

        # ======================
        # PART 3: MODIFY PROFILE AND LOAD PRETRAINED MODEL AND TRAIN FOR MORE
        # ======================
        entity_profile = EntityProfile.load_from_jsonl(
            self.profile_file, edit_mode=True
        )
        entity_profile.add_type("Q123", "cat", "wiki")
        entity_profile.remove_type("Q123", "dog", "wiki")
        entity_profile.add_mention("Q123", "cat", 100.0)
        # Dump profile data in format for model
        entity_profile.save(self.save_dir2)

        # Modify arg paths
        args["data_config"]["entity_dir"] = str(self.save_dir2)
        args["data_config"]["emb_dir"] = str(self.save_dir2)

        # Load pretrained model
        args["model_config"]["model_path"] = f"{emmental.Meta.log_path}/last_model.pth"
        emmental.Meta.config["model_config"]["model_path"] = saved_model_path1

        # Init another run
        emmental.init(log_dir=str(self.dir / "temp_log"), config=args)
        if not os.path.exists(emmental.Meta.log_path):
            os.makedirs(emmental.Meta.log_path)
        scores = run_model(mode="train", config=args)
        saved_model_path2 = f"{emmental.Meta.log_path}/last_model.pth"
        assert type(scores) is dict

        # ======================
        # PART 4: VERIFY CHANGES IN THE MODEL WERE AS EXPECTED
        # ======================
        # Check that the type mappings differ in the right way: we removed "dog"
        # from EID 1 and added "cat", so "dog" is no longer a type.
        eid2typeids_table1, type2row_dict1, num_types_with_unk1 = torch.load(
            self.save_dir / "prep" / "type_table_type_mappings_wiki_qid2typeids_2.pt"
        )
        eid2typeids_table2, type2row_dict2, num_types_with_unk2 = torch.load(
            self.save_dir2 / "prep" / "type_table_type_mappings_wiki_qid2typeids_2.pt"
        )
        # Modify mapping 2 to become mapping 1
        # Row 1 is Q123, Col 0 is type (this was "cat")
        eid2typeids_table2[1][0] = entity_profile._type_systems["wiki"]._type_vocab[
            "dog"
        ]
        self.assertEqual(num_types_with_unk1, num_types_with_unk2)
        self.assertDictEqual({1: 1, 2: 2}, type2row_dict1)
        self.assertDictEqual({1: 1}, type2row_dict2)
        assert torch.equal(eid2typeids_table1, eid2typeids_table2)
        # Check that the alias mappings are different
        alias2entity_table1 = torch.from_numpy(
            np.memmap(
                self.save_dir / "prep" / "alias2entity_table_alias2qids_InC1.pt",
                dtype="int64",
                mode="r",
                shape=(5, 30),
            )
        )
        gold_alias2entity_table1 = torch.tensor(
            [
                [4, 1, 3, 2] + [-1] * 26,
                [2] + [-1] * 29,
                [2, 3] + [-1] * 28,
                [1] + [-1] * 29,
                [1] + [-1] * 29,
            ]
        )
        assert torch.equal(alias2entity_table1, gold_alias2entity_table1)
        # The change is the "cat" alias has entity 1 added to the beginning
        # It used to only point to Q345 which is entity 2
        alias2entity_table2 = torch.from_numpy(
            np.memmap(
                self.save_dir2 / "prep" / "alias2entity_table_alias2qids_InC1.pt",
                dtype="int64",
                mode="r",
                shape=(5, 30),
            )
        )
        gold_alias2entity_table2 = torch.tensor(
            [
                [4, 1, 3, 2] + [-1] * 26,
                [1, 2] + [-1] * 28,
                [2, 3] + [-1] * 28,
                [1] + [-1] * 29,
                [1] + [-1] * 29,
            ]
        )
        assert torch.equal(alias2entity_table2, gold_alias2entity_table2)

        # The type embeddings were frozen so they should be the same
        model1 = torch.load(saved_model_path1)
        model2 = torch.load(saved_model_path2)
        assert torch.equal(
            model1["model"]["module_pool"]["learned_type"]["type_emb.weight"],
            model2["model"]["module_pool"]["learned_type"]["type_emb.weight"],
        )
Example #15
config_args = load_yaml_file(config_in_path)

# decrease number of data threads as this is a small file
config_args["run_config"]["dataset_threads"] = 2
config_args["run_config"]["log_level"] = "info"
# set the model checkpoint path
config_args["emmental"]["model_path"] = str(
    root_dir / "models/bootleg_wiki/bootleg_wiki.pth"
)
config_args["emmental"]["log_path"] = str(Path(args.data_dir) / "bootleg_results")
# set the path for the entity db and candidate map
config_args["data_config"]["entity_dir"] = str(root_dir / "data/wiki_entity_data")
config_args["data_config"]["alias_cand_map"] = str(cand_map)

config_args["data_config"]["data_dir"] = args.data_dir
config_args["data_config"]["test_dataset"]["file"] = args.outfile_name

# set the embedding paths
config_args["data_config"]["emb_dir"] = str(root_dir / "data/emb_data")
config_args["data_config"]["word_embedding"]["cache_dir"] = str(
    root_dir / "data/emb_data/pretrained_bert_models"
)


config_args = parse_boot_and_emm_args(
    config_args
)  # or you can pass in the config_out_path

bootleg_label_file, bootleg_emb_file = run_model(mode="dump_embs", config=config_args)
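A follow-up sketch under assumptions (that the label file is JSON lines and the embedding file is a NumPy array, which should be verified against the Bootleg dump_embs documentation):

import numpy as np
import ujson

with open(bootleg_label_file) as in_f:
    bootleg_labels = [ujson.loads(line) for line in in_f]
bootleg_embs = np.load(bootleg_emb_file)
print(f"{len(bootleg_labels)} labeled lines; embedding matrix shape {bootleg_embs.shape}")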