Example #1
0
    def load_from_cache(cls,
                        load_dir,
                        prefix="",
                        edit_mode=False,
                        verbose=False):
        """Loads type symbols from load_dir.

        Args:
            load_dir: directory to load from
            prefix: prefix to add to the beginning of each file name
            edit_mode: edit mode flag
            verbose: verbose flag

        Returns: TypeSymbols
        """
        config = utils.load_json_file(
            filename=os.path.join(load_dir, "config.json"))
        max_types = config["max_types"]
        qid2typenames: Dict[str, List[str]] = utils.load_json_file(
            filename=os.path.join(load_dir, f"{prefix}qid2typenames.json"))
        qid2typeid: Dict[str, List[int]] = utils.load_json_file(
            filename=os.path.join(load_dir, f"{prefix}qid2typeids.json"))
        type_vocab: Dict[str, int] = utils.load_json_file(
            filename=os.path.join(load_dir, f"{prefix}type_vocab.json"))
        return cls(qid2typenames, qid2typeid, type_vocab, max_types, edit_mode,
                   verbose)
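
A minimal usage sketch for the loader above, assuming load_from_cache is exposed as a classmethod on a TypeSymbols class and that load_dir already contains config.json plus the prefixed qid2typenames, qid2typeids, and type_vocab JSON files (the directory path here is hypothetical):

# Hypothetical call; TypeSymbols and the directory layout are assumed from the snippet above.
type_symbols = TypeSymbols.load_from_cache(
    load_dir="entity_db/type_mappings",  # must contain config.json and the prefixed *.json maps
    prefix="",
    edit_mode=False,
    verbose=True,
)
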
Example #2
0
 def load(self, load_dir):
     """Loads alias and entity mappings from load_dir."""
     config = utils.load_json_file(filename=os.path.join(load_dir, "config.json"))
     self.max_candidates = config["max_candidates"]
     self.max_alias_len = config["max_alias_len"]
     self._alias2qids: Dict[str, list] = utils.load_json_file(filename=os.path.join(load_dir, self.alias_cand_map_file))
     self._qid2title: Dict[str, str] = utils.load_json_file(filename=os.path.join(load_dir, "qid2title.json"))
     self._qid2eid: Dict[str, int] = utils.load_json_file(filename=os.path.join(load_dir, "qid2eid.json"))
     self._sort_alias_cands()
Example #3
0
 def load_types(self, entity_symbols, emb_dir, max_types):
     """Loads all type information"""
     # load type vocab
     if self.type_vocab_file == "":
         print(
             "You did not give a type vocab file (from type name to typeid). We will use identity mapping"
         )
         typeid2typename = {}
     else:
         extension = os.path.splitext(self.type_vocab_file)[-1]
         if extension == ".json":
             type_vocab = utils.load_json_file(
                 os.path.join(emb_dir, self.type_vocab_file))
         else:
             print(
                 f"We only support loading json files for TypeSymbol. You have a file ending in {extension}"
             )
             return {}, {}, {}
         typeid2typename = {i: v for v, i in type_vocab.items()}
     # load mapping of entities to type ids
     qid2typenames = {qid: [] for qid in entity_symbols.get_all_qids()}
     qid2typeid = {qid: [] for qid in entity_symbols.get_all_qids()}
     print(f"Loading types from {self.type_file}")
     qid2typeid, qid2typenames = self.load_type_file(
         type_file=self.type_file,
         max_types=max_types,
         entity_symbols=entity_symbols,
         qid2typeid=qid2typeid,
         qid2typenames=qid2typenames,
         typeid2typename=typeid2typename)
     return qid2typenames, qid2typeid, typeid2typename
Example #4
0
def main():
    args = parse_args()
    print(json.dumps(args, indent=4))
    assert 0 < args.perc_emb_drop < 1, f"perc_emb_drop must be between 0 and 1"
    state_dict, model_state_dict = load_statedict(args.init_checkpoint)
    assert ENTITY_EMB_KEY in model_state_dict
    print(f"Loading entity symbols from {os.path.join(args.entity_dir, args.entity_map_dir)}")
    entity_db = EntitySymbols(os.path.join(args.entity_dir, args.entity_map_dir), args.alias_cand_map_file)
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)
    print(f"Filtering qids")
    qid2topk_eid, old2new_eid, toes_eid, new_num_topk_entities = filter_qids(args.perc_emb_drop, entity_db, qid2count)
    print(f"Filtering embeddings")
    model_state_dict = filter_embs(new_num_topk_entities, entity_db, old2new_eid, qid2topk_eid, toes_eid, model_state_dict)
    # Generate the new old2new_eid weight vector to save in model_state_dict
    oldeid2topkeid = torch.arange(0, entity_db.num_entities_with_pad_and_nocand)
    # The +2 accounts for the pad and unk rows. Because -1 indexing into the entity embeddings is
    # problematic, we manually map the last (-1) eid to the last entry.
    oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1
    for qid, new_eid in qid2topk_eid.items():
        old_eid = entity_db.get_eid(qid)
        oldeid2topkeid[old_eid] = new_eid
    assert oldeid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
    assert oldeid2topkeid[-1] == new_num_topk_entities+2-1, "The -1 eid should still map to the last row"
    model_state_dict[ENTITY_EID_KEY] = oldeid2topkeid
    state_dict["model"] = model_state_dict
    print(f"Saving model at {args.save_checkpoint}")
    torch.save(state_dict, args.save_checkpoint)
    print(f"Saving entity_db at {os.path.join(args.entity_dir, args.entity_map_dir, args.save_qid2topk_file)}")
    utils.dump_json_file(os.path.join(args.entity_dir, args.entity_map_dir, args.save_qid2topk_file), qid2topk_eid)
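
To make the index bookkeeping above concrete, here is a toy walk-through of the oldeid2topkeid remapping with made-up numbers (the exact meaning of new_num_topk_entities and the row layout are assumptions, not taken from filter_qids):

import torch

# Toy numbers: 3 original entities with old eids 1..3; row 0 is reserved and row 4 is the trailing pad row.
old_eids = {"Q1": 1, "Q2": 2, "Q3": 3}
qid2topk_eid = {"Q1": 1, "Q3": 2, "Q2": 3}   # Q2 is collapsed into the shared "toes" row 3
new_num_topk_entities = 3                    # 2 kept entities + 1 toes row (assumed convention)
oldeid2topkeid = torch.arange(0, 5)                  # identity over the 5 original rows
oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1   # pad row maps to the new last row (4)
for qid, new_eid in qid2topk_eid.items():
    oldeid2topkeid[old_eids[qid]] = new_eid
print(oldeid2topkeid)  # tensor([0, 1, 3, 2, 4])
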
Example #5
0
    def load_from_cache(
        cls,
        load_dir,
        alias_cand_map_file="alias2qids.json",
        alias_idx_file="alias2id.json",
        edit_mode=False,
        verbose=False,
    ):
        """Loads entity symbols from load_dir.

        Args:
            load_dir: directory to load from
            alias_cand_map_file: alias2qid file
            alias_idx_file: alias2id file
            edit_mode: edit mode flag
            verbose: verbose flag

        Returns: EntitySymbols
        """
        config = utils.load_json_file(filename=os.path.join(load_dir, "config.json"))
        max_candidates = config["max_candidates"]
        alias2qids: Dict[str, list] = utils.load_json_file(
            filename=os.path.join(load_dir, alias_cand_map_file)
        )
        qid2title: Dict[str, str] = utils.load_json_file(
            filename=os.path.join(load_dir, "qid2title.json")
        )
        qid2eid: Dict[str, int] = utils.load_json_file(
            filename=os.path.join(load_dir, "qid2eid.json")
        )
        alias2id: Dict[str, int] = utils.load_json_file(
            filename=os.path.join(load_dir, alias_idx_file)
        )
        return cls(
            alias2qids,
            qid2title,
            qid2eid,
            alias2id,
            max_candidates,
            alias_cand_map_file,
            alias_idx_file,
            edit_mode,
            verbose,
        )
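
As with the type symbols, a hypothetical call to this loader might look as follows, assuming load_from_cache is a classmethod on an EntitySymbols class and the directory contains the alias and QID JSON maps listed above (the path is illustrative):

# Hypothetical call; the directory path is illustrative.
entity_symbols = EntitySymbols.load_from_cache(
    load_dir="entity_db/entity_mappings",
    alias_cand_map_file="alias2qids.json",
    alias_idx_file="alias2id.json",
    edit_mode=False,
    verbose=False,
)
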
Example #6
0
def get_slice_stats(num_processes, file):
    """Gets true anchor slice counts."""
    pool = multiprocessing.Pool(processes=num_processes)
    final_counts = defaultdict(int)
    final_slice_to_sent = defaultdict(set)
    final_sent_to_slices = defaultdict(lambda: defaultdict(int))
    temp_out_dir = os.path.join(os.path.dirname(file), "_temp")
    os.mkdir(temp_out_dir)

    with open(file) as in_f:
        all_lines = in_f.readlines()
    num_lines = len(all_lines)
    chunk_size = int(np.ceil(num_lines / num_processes))
    line_chunks = [
        all_lines[i:i + chunk_size] for i in range(0, num_lines, chunk_size)
    ]

    input_args = [[i, line_chunks[i], i * chunk_size, temp_out_dir]
                  for i in range(len(line_chunks))]

    for i in tqdm(
            pool.imap_unordered(get_slice_stats_hlp, input_args, chunksize=1),
            total=len(line_chunks),
            desc="Gathering slice counts",
    ):
        cnt_res = utils.load_json_file(
            os.path.join(temp_out_dir, f"{FINAL_COUNTS_PREFIX}_{i}.json"))
        sent_to_slices = utils.load_json_file(
            os.path.join(temp_out_dir,
                         f"{FINAL_SENT_TO_SLICE_PREFIX}_{i}.json"))
        slice_to_sent = utils.load_json_file(
            os.path.join(temp_out_dir,
                         f"{FINAL_SLICE_TO_SENT_PREFIX}_{i}.json"))
        for k in cnt_res:
            final_counts[k] += cnt_res[k]
        for k in slice_to_sent:
            final_slice_to_sent[k].update(slice_to_sent[k])
        for k in sent_to_slices:
            final_sent_to_slices[k].update(sent_to_slices[k])
    pool.close()
    pool.join()
    shutil.rmtree(temp_out_dir)
    return dict(final_counts), dict(final_slice_to_sent), dict(
        final_sent_to_slices)
Example #7
0
 def load(self, load_dir):
     """Loads vocabulary, format metadata, and memory-mapped record tries from load_dir."""
     self._fmt_types = utils.load_json_file(
         filename=os.path.join(load_dir, "fmt_types.json"))
     self._max_values = utils.load_json_file(
         filename=os.path.join(load_dir, "max_values.json"))
     self._stoi = utils.load_json_file(
         filename=os.path.join(load_dir, "vocabulary.json"))
     self._itos = np.load(file=os.path.join(load_dir, "itos.npy"),
                          allow_pickle=True)
     assert self._fmt_types.keys() == self._max_values.keys()
     for tri_name in self._fmt_types:
         assert f'record_trie_{tri_name}.marisa' in os.listdir(
             load_dir
         ), f"Missing record_trie_{tri_name}.marisa in {load_dir}"
     self._record_tris = {}
     for tri_name in self._fmt_types:
         self._record_tris[tri_name] = marisa_trie.RecordTrie(
             self._get_fmt_strings[self._fmt_types[tri_name]](
                 self._max_values[tri_name])).mmap(
                     os.path.join(load_dir,
                                  f'record_trie_{tri_name}.marisa'))
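
A small sketch of how one of the record_trie_*.marisa files read above could be produced, assuming the marisa_trie.RecordTrie format-string convention (keys, values, and the file name here are made up):

import marisa_trie

fmt = "<ll"                                    # two int32 values per key
records = [("Q1", (10, 20)), ("Q2", (30, 40))]
trie = marisa_trie.RecordTrie(fmt, records)
trie.save("record_trie_example.marisa")

# The loader above memory-maps the file back with the same format string.
reloaded = marisa_trie.RecordTrie(fmt).mmap("record_trie_example.marisa")
assert reloaded["Q1"] == [(10, 20)]
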
Example #8
0
 def build_static_embeddings(cls, emb_file, entity_symbols):
     """Builds the table of the embedding associated with each entity."""
     ending = os.path.splitext(emb_file)[1]
     if ending == ".glove":
         embeddings, embedding_size = data_utils.load_glove(emb_file,
                                                            log_func=print)
     elif ending == ".json":
         dct = utils.load_json_file(emb_file)
         val = next(iter(dct.values()))
         if type(val) is int or type(val) is float:
             embedding_size = 1
             convert_func = lambda x: np.array([x])
         elif type(val) is list:
             embedding_size = len(val)
             convert_func = lambda x: np.array([y for y in x])
         else:
             raise ValueError(
                 f"Unrecognized type for the array value of {type(val)}")
         embeddings = {}
         for k in dct:
             embeddings[k] = convert_func(dct[k])
             assert len(embeddings[k]) == embedding_size
     else:
         raise ValueError(
             f"We do not support static embeddings from {ending}. We only support .glove or .json"
         )
     entity2staticemb_table = torch.zeros(
         entity_symbols.num_entities_with_pad_and_nocand, embedding_size)
     found = 0
     for qid in tqdm(entity_symbols.get_all_qids()):
         if qid in embeddings:
             found += 1
             emb = embeddings[qid]
             eid = entity_symbols.get_eid(qid)
             entity2staticemb_table[
                 eid, :embedding_size] = torch.from_numpy(emb)
     print(
         "Found a static embedding for "
         f"{100 * found / len(entity_symbols.get_all_qids()):.2f} percent of all entities"
     )
     return entity2staticemb_table
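
For the .json branch above, the loader expects a flat mapping from QID to either a single number or a fixed-length list. A hypothetical input file could be produced like this (the file name and values are illustrative):

import json

# Hypothetical static features keyed by QID; every list must have the same length.
static_embs = {"Q1": [0.1, 0.2, 0.3], "Q2": [0.4, 0.5, 0.6]}
with open("static_entity_embs.json", "w") as f:
    json.dump(static_embs, f)
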
Example #9
0
    def __init__(
        self,
        main_args,
        dataset,
        use_weak_label,
        entity_symbols,
        dataset_threads,
        split="train",
    ):
        global_start = time.time()
        log_rank_0_info(logger,
                        f"Building slice dataset for {split} from {dataset}.")
        spawn_method = main_args.run_config.spawn_method
        data_config = main_args.data_config
        orig_spawn = multiprocessing.get_start_method()
        multiprocessing.set_start_method(spawn_method, force=True)
        self.slice_names = data_utils.get_eval_slices(data_config.eval_slices)
        self.get_slice_dt = lambda max_a2p: np.dtype([
            ("sent_idx", int),
            ("subslice_idx", int),
            ("alias_slice_incidence", int, (max_a2p, )),
            ("prob_labels", float, (max_a2p, )),
        ])
        self.get_storage = lambda max_a2p: np.dtype(
            [(slice_name, self.get_slice_dt(max_a2p))
             for slice_name in self.slice_names])
        # Folder for all mmap saved files
        save_dataset_folder = data_utils.get_save_data_folder(
            data_config, use_weak_label, dataset)
        utils.ensure_dir(save_dataset_folder)
        # Folder for temporary output files
        temp_output_folder = os.path.join(data_config.data_dir,
                                          data_config.data_prep_dir,
                                          f"prep_{split}_slice_files")
        utils.ensure_dir(temp_output_folder)
        # Input step 1
        create_ex_indir = os.path.join(temp_output_folder,
                                       "create_examples_input")
        utils.ensure_dir(create_ex_indir)
        # Input step 2
        create_ex_outdir = os.path.join(temp_output_folder,
                                        "create_examples_output")
        utils.ensure_dir(create_ex_outdir)
        # Meta data saved files
        meta_file = os.path.join(temp_output_folder, "meta_data.json")
        # File for standard training data
        hash = hashlib.sha1(str(
            self.slice_names).encode("UTF-8")).hexdigest()[:10]
        self.save_dataset_name = os.path.join(save_dataset_folder,
                                              f"ned_slices_{hash}.bin")
        self.save_data_config_name = os.path.join(save_dataset_folder,
                                                  "ned_slices_config.json")

        # =======================================================================================
        # SLICE DATA
        # =======================================================================================
        log_rank_0_debug(logger, "Loading dataset...")
        log_rank_0_debug(logger, f"Seeing if {self.save_dataset_name} exists")
        if data_config.overwrite_preprocessed_data or (not os.path.exists(
                self.save_dataset_name)):
            st_time = time.time()
            try:
                log_rank_0_info(
                    logger,
                    f"Building dataset from scratch. Saving to {save_dataset_folder}",
                )
                create_examples(
                    dataset,
                    create_ex_indir,
                    create_ex_outdir,
                    meta_file,
                    data_config,
                    dataset_threads,
                    self.slice_names,
                    use_weak_label,
                    split,
                )
                max_alias2pred = utils.load_json_file(
                    meta_file)["max_alias2pred"]
                convert_examples_to_features_and_save(
                    meta_file,
                    dataset_threads,
                    self.slice_names,
                    self.save_dataset_name,
                    self.get_storage(max_alias2pred),
                )
                utils.dump_json_file(self.save_data_config_name,
                                     {"max_alias2pred": max_alias2pred})

                log_rank_0_debug(
                    logger,
                    f"Finished prepping data in {time.time() - st_time}")
            except Exception as e:
                tb = traceback.TracebackException.from_exception(e)
                logger.error(e)
                logger.error("\n".join(tb.stack.format()))
                shutil.rmtree(save_dataset_folder, ignore_errors=True)
                raise

        log_rank_0_info(
            logger,
            f"Loading data from {self.save_dataset_name} and {self.save_data_config_name}",
        )
        max_alias2pred = utils.load_json_file(
            self.save_data_config_name)["max_alias2pred"]
        self.data, self.sent_to_row_id_dict = self.build_data_dict(
            self.save_dataset_name, self.get_storage(max_alias2pred))
        assert len(self.data) > 0
        assert len(self.sent_to_row_id_dict) > 0
        log_rank_0_debug(logger, f"Removing temporary output files")
        shutil.rmtree(temp_output_folder, ignore_errors=True)
        # Set spawn back to original/default, which is "fork" or "spawn". This is needed for the Meta.config to
        # be correctly passed in the collate_fn.
        multiprocessing.set_start_method(orig_spawn, force=True)
        log_rank_0_info(
            logger,
            f"Final slice data initialization time from {split} is {time.time() - global_start}s",
        )
Example #10
0
def convert_examples_to_features_and_save(meta_file, dataset_threads,
                                          slice_names, save_dataset_name,
                                          storage):
    """Converts the prepped examples into input features and saves in memmap
    files. These are used in the __getitem__ method.

    Args:
        meta_file: metadata file where input file paths are saved
        dataset_threads: number of threads
        slice_names: list of slice names to evaluate on
        save_dataset_name: data file name to save
        storage: data storage type (for memmap)

    Returns:
    """
    log_rank_0_debug(logger, "Starting to extract subsentences")
    start = time.time()
    num_processes = min(dataset_threads,
                        int(0.8 * multiprocessing.cpu_count()))

    log_rank_0_info(
        logger,
        f"Starting to build and save features with {num_processes} threads")

    log_rank_0_debug(logger, f"Counting lines")
    total_input = utils.load_json_file(meta_file)["num_mentions"]
    max_alias2pred = utils.load_json_file(meta_file)["max_alias2pred"]
    files_and_counts = utils.load_json_file(meta_file)["files_and_counts"]

    # IMPORTANT: for distributed writing to memmap files, you must create them in w+ mode before
    # being opened in r+ mode by workers
    memmap_file = np.memmap(save_dataset_name,
                            dtype=storage,
                            mode="w+",
                            shape=(total_input, ),
                            order="C")
    # Save -1 in sent_idx to check that things are loaded correctly later
    memmap_file[slice_names[0]]["sent_idx"][:] = -1

    input_args = []
    # Tracks where in the memmap file to start writing
    offset = 0
    for i, in_file_name in enumerate(files_and_counts.keys()):
        input_args.append({
            "file_name": in_file_name,
            "in_file_lines": files_and_counts[in_file_name],
            "save_file_offset": offset,
            "ex_print_mod": int(np.ceil(total_input / 20)),
            "slice_names": slice_names,
            "max_alias2pred": max_alias2pred,
        })
        offset += files_and_counts[in_file_name]

    if num_processes == 1:
        assert len(input_args) == 1
        total_output = convert_examples_to_features_and_save_single(
            input_args[0], memmap_file)
    else:
        log_rank_0_debug(
            logger,
            "Initializing pool. This make take a few minutes.",
        )
        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=convert_examples_to_features_and_save_initializer,
            initargs=[save_dataset_name, storage],
        )

        total_output = 0
        for res in pool.imap_unordered(
                convert_examples_to_features_and_save_hlp, input_args,
                chunksize=1):
            total_output += res
        pool.close()
        pool.join()

    # Verify that sentences are unique and saved correctly
    mmap_file = np.memmap(save_dataset_name, dtype=storage, mode="r")
    all_uniq_ids = set()
    for i in tqdm(range(total_input), desc="Checking sentence uniqueness"):
        assert (mmap_file[slice_names[0]]["sent_idx"][i] !=
                -1), f"Index {i} has -1 sent idx"
        uniq_id = str(
            f"{mmap_file[slice_names[0]]['sent_idx'][i]}.{mmap_file[slice_names[0]]['subslice_idx'][i]}"
        )
        assert (uniq_id not in all_uniq_ids
                ), f"Idx {uniq_id} is not unique and already in data"
        all_uniq_ids.add(uniq_id)

    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time() - start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}",
    )
    return
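
The IMPORTANT comment above relies on a general numpy pattern: the parent process allocates the full memmap file once in "w+" mode, and each worker reopens the same file in "r+" mode and writes only its own offset range. A self-contained sketch of that pattern (file name, dtype, and sizes are illustrative):

import numpy as np

storage = np.dtype([("sent_idx", int), ("subslice_idx", int)])
fname = "example_features.bin"

# Parent: create the file once in w+ mode and mark every row as unwritten.
parent = np.memmap(fname, dtype=storage, mode="w+", shape=(100,))
parent["sent_idx"][:] = -1
parent.flush()

# Worker: reopen in r+ mode and fill only rows [offset, offset + n).
offset, n = 10, 5
worker = np.memmap(fname, dtype=storage, mode="r+", shape=(100,))
worker["sent_idx"][offset:offset + n] = np.arange(n)
worker.flush()
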
Example #11
0
    def __init__(self,
                 args,
                 use_weak_label,
                 input_src,
                 dataset_name,
                 is_writer,
                 distributed,
                 word_symbols,
                 entity_symbols,
                 slice_dataset=None,
                 dataset_is_eval=False):
        # Need to save args to reinstantiate logger
        self.args = args
        self.logger = logging_utils.get_logger(args)
        # Number of candidates, including NIL if a NIL model (train_in_candidates is False)
        self.K = entity_symbols.max_candidates + (
            not args.data_config.train_in_candidates)
        self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
        self.dataset_name = dataset_name
        self.slice_dataset = slice_dataset
        self.dataset_is_eval = dataset_is_eval
        # Slice names used for eval slices and a slicing model
        self.slice_names = train_utils.get_data_slices(args, dataset_is_eval)
        self.storage_type_file = data_utils.get_storage_file(self.dataset_name)
        # Mappings from sent_idx to row_id in dataset
        self.sent_idx_file = os.path.splitext(
            dataset_name)[0] + "_sent_idx.json"
        self.type_pred = False
        if args.data_config.type_prediction.use_type_pred:
            self.type_pred = True
            self.eid2typeid, self.num_types_with_pad = self.load_coarse_type_table(
                args, entity_symbols)
        # Load memory mapped file
        self.logger.info("Loading dataset...")
        self.logger.debug("Seeing if " + dataset_name + " exists")
        if (args.data_config.overwrite_preprocessed_data
                or (not os.path.exists(self.dataset_name))
                or (not os.path.exists(self.sent_idx_file))
                or (not os.path.exists(self.storage_type_file))
                or (not os.path.exists(
                    data_utils.get_batch_prep_config(self.dataset_name)))):
            start = time.time()
            self.logger.debug(f"Building dataset with {input_src}")
            # Only prep data once per node
            if is_writer:
                prep_data(args,
                          use_weak_label=use_weak_label,
                          dataset_is_eval=self.dataset_is_eval,
                          input_src=input_src,
                          dataset_name=dataset_name,
                          prep_dir=data_utils.get_data_prep_dir(args))
            if distributed:
                # Make sure all processes wait for data to be created
                dist.barrier()
            self.logger.debug(
                f"Finished building and saving dataset in {round(time.time() - start, 2)}s."
            )

        start = time.time()

        # Storage type for loading memory mapped file of dataset
        self.storage_type = pickle.load(open(self.storage_type_file, 'rb'))

        self.data = np.memmap(self.dataset_name,
                              dtype=self.storage_type,
                              mode='r')
        self.data_len = len(self.data)

        # Mapping from sentence idx to rows in the dataset (indices).
        # Needed when sampling sentence indices from slices for evaluation.
        sent_idx_to_idx_str = utils.load_json_file(self.sent_idx_file)
        self.sent_idx_to_idx = {
            int(i): val
            for i, val in sent_idx_to_idx_str.items()
        }
        self.logger.info(f"Finished loading dataset.")

        # Stores info about the batch prepped embedding memory mapped files and their shapes and datatypes
        # so we can load them
        self.batch_prep_config = utils.load_json_file(
            data_utils.get_batch_prep_config(self.dataset_name))
        self.batch_prepped_emb_files = {}
        self.batch_prepped_emb_file_names = {}
        for emb in args.data_config.ent_embeddings:
            if 'batch_prep' in emb and emb['batch_prep']:
                assert emb.key in self.batch_prep_config, f'Need to prep {emb.key}. Please call prep instead of run with batch_prep_embeddings set to true.'
                self.batch_prepped_emb_file_names[emb.key] = os.path.join(
                    os.path.dirname(self.dataset_name),
                    os.path.basename(
                        self.batch_prep_config[emb.key]['file_name']))
                self.batch_prepped_emb_files[emb.key] = np.memmap(
                    self.batch_prepped_emb_file_names[emb.key],
                    dtype=self.batch_prep_config[emb.key]['dtype'],
                    shape=tuple(self.batch_prep_config[emb.key]['shape']),
                    mode='r')
                assert len(self.batch_prepped_emb_files[emb.key]) == self.data_len,\
                    f'Preprocessed emb data file {self.batch_prep_config[emb.key]["file_name"]} does not match length of main data file.'

        # Stores embeddings that we compute on the fly; these are embeddings where batch_on_the_fly is set to true.
        self.batch_on_the_fly_embs = {}
        for emb in args.data_config.ent_embeddings:
            if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
                mod, load_class = import_class("bootleg.embeddings",
                                               emb.load_class)
                try:
                    self.batch_on_the_fly_embs[emb.key] = getattr(
                        mod, load_class)(main_args=args,
                                         emb_args=emb['args'],
                                         entity_symbols=entity_symbols,
                                         model_device=None,
                                         word_symbols=None,
                                         key=emb.key)
                except AttributeError as e:
                    self.logger.warning(
                        f'No prep method found for {emb.load_class} with error {e}'
                    )
                except Exception as e:
                    print("ERROR", e)
        # The data in this table shouldn't be pickled since we delete it in the class __getstate__
        self.alias2entity_table = AliasEntityTable(
            args=args, entity_symbols=entity_symbols)
        # Random NIL percent
        self.mask_perc = args.train_config.random_nil_perc
        self.random_nil = False
        # Don't want to random mask for eval
        if not dataset_is_eval:
            # Whether to use a random NIL training regime
            self.random_nil = args.train_config.random_nil
            if self.random_nil:
                self.logger.info(
                    f'Using random nils during training with {self.mask_perc} percent'
                )
Example #12
0
    def build_static_embeddings(cls, emb_file, entity_symbols):
        """Builds the table of the embedding associated with each entity.

        Args:
            emb_file: embedding file to load
            entity_symbols: entity symbols

        Returns: numpy embedding matrix where each row is an embedding
        """
        ending = os.path.splitext(emb_file)[1]
        found = 0
        raw_num_ents = 0
        if ending == ".json":
            dct = utils.load_json_file(emb_file)
            val = next(iter(dct.values()))
            if type(val) is int or type(val) is float:
                embedding_size = 1
                convert_func = lambda x: np.array([x])
            elif type(val) is list:
                embedding_size = len(val)
                convert_func = lambda x: np.array([y for y in x])
            else:
                raise ValueError(
                    f"Unrecognized type for the array value of {type(val)}"
                )
            embeddings = {}
            for k in dct:
                embeddings[k] = convert_func(dct[k])
                assert len(embeddings[k]) == embedding_size
            entity2staticemb_table = np.zeros(
                (entity_symbols.num_entities_with_pad_and_nocand, embedding_size)
            )
            raw_num_ents = len(embeddings)
            for qid in tqdm(entity_symbols.get_all_qids()):
                if qid in embeddings:
                    found += 1
                    emb = embeddings[qid]
                    eid = entity_symbols.get_eid(qid)
                    entity2staticemb_table[eid, :embedding_size] = emb
        elif ending == ".pt":
            log_rank_0_debug(
                logger,
                f"We are readining in the embedding file from a .pt. We assume this is already mapped to eids",
            )
            (qid2eid_map, entity2staticemb_table_raw) = torch.load(emb_file)
            entity2staticemb_table_raw = (
                entity2staticemb_table_raw.detach().cpu().numpy()
            )
            raw_num_ents = entity2staticemb_table_raw.shape[0]
            # +2 handles the PAD and UNK entities
            assert entity2staticemb_table_raw.shape[0] == len(qid2eid_map) + 2, (
                f"The saved static embeddings file had mismatched shapes between qid2eid {len(qid2eid_map)} and "
                f"weights {entity2staticemb_table_raw.shape[0]}"
            )
            entity2staticemb_table = np.zeros(
                (
                    entity_symbols.num_entities_with_pad_and_nocand,
                    entity2staticemb_table_raw.shape[1],
                )
            )
            found = 0
            for qid in tqdm(entity_symbols.get_all_qids()):
                if qid in qid2eid_map:
                    found += 1
                    raw_eid = qid2eid_map[qid]
                    emb = entity2staticemb_table_raw[raw_eid]
                    new_eid = entity_symbols.get_eid(qid)
                    entity2staticemb_table[new_eid, :] = emb
        else:
            raise ValueError(
                f"We do not support static embeddings from {ending}. We only support .json and .pt"
            )
        log_rank_0_debug(
            logger,
            f"Found {found} ({found/len(entity_symbols.get_all_qids())} percent) of all entities after "
            f"reading {raw_num_ents} original entities have a static embedding",
        )
        return entity2staticemb_table
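
For the .pt branch above, the file is assumed to hold a (qid2eid_map, weights) tuple whose weight matrix has two extra rows for the PAD and UNK entities. A hypothetical way to write such a file (QIDs, eids, and the file name are made up):

import torch

# Hypothetical QID-to-eid mapping; the real one comes from the entity dump.
qid2eid_map = {"Q1": 0, "Q2": 1, "Q3": 2}
embedding_dim = 4
weights = torch.randn(len(qid2eid_map) + 2, embedding_dim)  # +2 rows for PAD and UNK
torch.save((qid2eid_map, weights), "static_entity_embs.pt")
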
Example #13
0
    def __init__(
        self,
        main_args,
        emb_args,
        entity_symbols,
        key,
        cpu,
        normalize,
        dropout1d_perc,
        dropout2d_perc,
    ):
        super(TopKEntityEmb, self).__init__(
            main_args=main_args,
            emb_args=emb_args,
            entity_symbols=entity_symbols,
            key=key,
            cpu=cpu,
            normalize=normalize,
            dropout1d_perc=dropout1d_perc,
            dropout2d_perc=dropout2d_perc,
        )
        allowable_keys = {
            "learned_embedding_size",
            "perc_emb_drop",
            "qid2topk_eid",
            "regularize_mapping",
            "tail_init",
            "tail_init_zeros",
        }
        correct, bad_key = utils.assert_keys_in_dict(allowable_keys, emb_args)
        if not correct:
            raise ValueError(f"The key {bad_key} is not in {allowable_keys}")
        assert (
            "learned_embedding_size" in emb_args
        ), f"TopKEntityEmb must have learned_embedding_size in args"
        assert "perc_emb_drop" in emb_args, (
            f"To use TopKEntityEmb we need perc_emb_drop to be in the args. This gives the percentage of embeddings"
            f" removed."
        )
        self.learned_embedding_size = emb_args.learned_embedding_size
        # We remove perc_emb_drop percent of the embeddings and add one to represent the new toes embedding
        num_topk_entities_with_pad_and_nocand = (
            entity_symbols.num_entities_with_pad_and_nocand
            - int(emb_args.perc_emb_drop * entity_symbols.num_entities)
            + 1
        )
        # Mapping of entity to the new eid mapping
        eid2topkeid = torch.arange(0, entity_symbols.num_entities_with_pad_and_nocand)
        # There are issues with using -1 index into the embeddings; so we manually set it to be the last value
        eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
        if "qid2topk_eid" not in emb_args:
            assert self.from_pretrained, (
                f"If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, "
                f"you must be loading a model from a checkpoint to build this index mapping"
            )

        self.learned_entity_embedding = nn.Embedding(
            num_topk_entities_with_pad_and_nocand,
            self.learned_embedding_size,
            padding_idx=-1,
            sparse=False,
        )

        self._dim = main_args.model_config.hidden_size

        if "regularize_mapping" in emb_args:
            eid2reg = torch.zeros(num_topk_entities_with_pad_and_nocand)
        else:
            eid2reg = None
        # If tail_init is false, all embeddings are randomly initialized.
        # If tail_init is true, we initialize all embeddings to be the same.
        self.tail_init = True
        self.tail_init_zeros = False
        # None init vec will be random
        init_vec = None
        if not self.from_pretrained:
            qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
            assert (
                len(qid2topk_eid) == entity_symbols.num_entities
            ), f"You must have an item in qid2topk_eid for each qid in entity_symbols"
            for qid in entity_symbols.get_all_qids():
                old_eid = entity_symbols.get_eid(qid)
                new_eid = qid2topk_eid[qid]
                eid2topkeid[old_eid] = new_eid
            assert eid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
            assert (
                eid2topkeid[-1] == num_topk_entities_with_pad_and_nocand - 1
            ), "The -1 eid should still map to -1"

            if "tail_init" in emb_args:
                self.tail_init = emb_args.tail_init
            if "tail_init_zeros" in emb_args:
                self.tail_init_zeros = emb_args.tail_init_zeros
                self.tail_init = False
                init_vec = torch.zeros(1, self.learned_embedding_size)
            assert not (
                self.tail_init and self.tail_init_zeros
            ), f"Can only have one of tail_init or tail_init_zeros set"
            if self.tail_init or self.tail_init_zeros:
                if self.tail_init_zeros:
                    log_rank_0_debug(
                        logger,
                        f"All learned entity embeddings are initialized to zero.",
                    )
                else:
                    log_rank_0_debug(
                        logger,
                        f"All learned entity embeddings are initialized to the same value.",
                    )
                init_vec = model_utils.init_embeddings_to_vec(
                    self.learned_entity_embedding, pad_idx=-1, vec=init_vec
                )
                vec_save_file = os.path.join(
                    emmental.Meta.log_path, "init_vec_entity_embs.npy"
                )
                log_rank_0_debug(logger, f"Saving init vector to {vec_save_file}")
                if (
                    torch.distributed.is_initialized()
                    and torch.distributed.get_rank() == 0
                ):
                    np.save(vec_save_file, init_vec)
            else:
                log_rank_0_debug(
                    logger, f"All learned embeddings are randomly initialized."
                )
            # Regularization mapping goes from eid to 2d dropout percent
            if "regularize_mapping" in emb_args:
                log_rank_0_debug(
                    logger,
                    f"You are using regularization mapping with a topK entity embedding. "
                    f"This means all QIDs that are mapped to the same"
                    f" EID will get the same regularization value.",
                )
                if self.dropout1d_perc > 0 or self.dropout2d_perc > 0:
                    log_rank_0_debug(
                        logger,
                        f"You have 1D or 2D regularization set with a regularize_mapping. Do you mean to do this?",
                    )
                log_rank_0_debug(
                    logger,
                    f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}",
                )
                eid2reg = self.load_regularization_mapping(
                    main_args.data_config,
                    entity_symbols,
                    qid2topk_eid,
                    num_topk_entities_with_pad_and_nocand,
                    emb_args.regularize_mapping,
                )
        # Keep this mapping so a topK model can simply be loaded without needing the new eid mapping
        self.register_buffer("eid2topkeid", eid2topkeid)
        self.register_buffer("eid2reg", eid2reg)
Example #14
0
 def __init__(self, main_args, emb_args, model_device, entity_symbols,
              word_symbols, word_emb, key):
     super(TopKEntityEmb, self).__init__(main_args=main_args,
                                         emb_args=emb_args,
                                         model_device=model_device,
                                         entity_symbols=entity_symbols,
                                         word_symbols=word_symbols,
                                         word_emb=word_emb,
                                         key=key)
     self.logger = logging_utils.get_logger(main_args)
     self.learned_embedding_size = emb_args.learned_embedding_size
     self.normalize = True
     assert "perc_emb_drop" in emb_args, f"To use TopKEntityEmb we need perc_emb_drop to be in the args. This gives the percentage of embeddings" \
                                         f" removed."
     # We remove perc_emb_drop percent of the embeddings and add one to represent the new toes embedding
     num_topk_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand - int(
         emb_args.perc_emb_drop * entity_symbols.num_entities) + 1
     # Mapping of entity to the new eid mapping
     qid2topk_eid = {}
     eid2topkeid = torch.arange(
         0, entity_symbols.num_entities_with_pad_and_nocand)
     # There are issues with using -1 index into the embeddings; so we manually set it to be the last value
     eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
     if "qid2topk_eid" not in emb_args:
         assert len(main_args.run_config.init_checkpoint) > 0, f"If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, " \
                                                               f"you must be loading a model from a checkpoint to build this index mapping"
     else:
         qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
         assert len(
             qid2topk_eid
         ) == entity_symbols.num_entities, f"You must have an item in qid2topk_eid for each qid in entity_symbols"
         for qid in entity_symbols.get_all_qids():
             old_eid = entity_symbols.get_eid(qid)
             new_eid = qid2topk_eid[qid]
             eid2topkeid[old_eid] = new_eid
         assert eid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
         assert eid2topkeid[
             -1] == num_topk_entities_with_pad_and_nocand - 1, "The -1 eid should still map to the last row"
     self.learned_entity_embedding = nn.Embedding(
         num_topk_entities_with_pad_and_nocand,
         self.learned_embedding_size,
         padding_idx=-1,
         sparse=True)
     # Keep this mapping so a topK model can simply be loaded without needing the new eid mapping
     self.register_buffer("eid2topkeid", eid2topkeid)
     # If tail_init is false, all embeddings are randomly initialized.
     # If tail_init is true, we initialize all embeddings to be the same.
     self.tail_init = True
     self.tail_init_zeros = False
     # None init vec will be random
     init_vec = None
     if "tail_init" in emb_args:
         self.tail_init = emb_args.tail_init
     if "tail_init_zeros" in emb_args:
         self.tail_init_zeros = emb_args.tail_init_zeros
         self.tail_init = False
         init_vec = torch.zeros(1, self.learned_embedding_size)
     assert not (self.tail_init and self.tail_init_zeros
                 ), f"Can only have one of tail_init or tail_init_zeros set"
     if self.tail_init or self.tail_init_zeros:
         if self.tail_init_zeros:
             self.logger.debug(
                 f"All learned entity embeddings are initialized to zero.")
         else:
             self.logger.debug(
                 f"All learned entity embeddings are initialized to the same value."
             )
         init_vec = model_utils.init_embeddings_to_vec(
             self.learned_entity_embedding, pad_idx=-1, vec=init_vec)
         vec_save_file = os.path.join(
             train_utils.get_save_folder(main_args.run_config),
             "init_vec_entity_embs.npy")
         self.logger.debug(f"Saving init vector to {vec_save_file}")
         np.save(vec_save_file, init_vec)
     else:
         self.logger.debug(
             f"All learned embeddings are randomly initialized.")
     self._dim = main_args.model_config.hidden_size
     self.eid2reg = None
     # Regularization mapping goes from eid to 2d dropout percent
     if "regularize_mapping" in emb_args:
         self.logger.warning(
             f"You are using regularization mapping with a topK entity embedding. This means all QIDs that are mapped to the same"
             f" EID will get the same regularization value.")
         self.logger.debug(
             f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}"
         )
         self.eid2reg = self.load_regularization_mapping(
             main_args, qid2topk_eid, num_topk_entities_with_pad_and_nocand,
             emb_args.regularize_mapping, self.logger.debug)
         self.eid2reg = self.eid2reg.to(model_device)
Example #15
0
def compress_topk_embeddings(args):
    assert 0 < args.perc_emb_drop < 1, f"perc_emb_drop must be between 0 and 1"
    print(
        f"Loading entity symbols from {os.path.join(args.entity_dir, 'entity_mappings')}"
    )
    entity_db = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, "entity_mappings"))
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)
    print(f"Filtering qids")
    (
        qid2topk_eid,
        old2new_eid,
        toes_eid,
        new_num_topk_entities,
    ) = filter_qids(args.perc_emb_drop, entity_db, qid2count)
    if len(args.model_path) > 0:
        assert (args.save_model_path is not None
                and len(args.save_model_path) > 0
                ), f"If you give a model path, you must give a save checkpoint"
        print(f"Filtering embeddings")
        state_dict, model_state_dict = load_statedict(args.model_path)
        try:
            get_nested_item(model_state_dict, ENTITY_EMB_KEYS)
        except Exception:
            print(
                f"ERROR: Not all of {ENTITY_EMB_KEYS} are in model_state_dict")
            raise
        model_state_dict = filter_embs(
            new_num_topk_entities,
            entity_db,
            old2new_eid,
            qid2topk_eid,
            toes_eid,
            model_state_dict,
        )
        # Generate the new old2new_eid weight vector to save in model_state_dict
        oldeid2topkeid = torch.arange(
            0, entity_db.num_entities_with_pad_and_nocand)
        # The +2 accounts for the pad and unk rows. Because -1 indexing into the entity embeddings
        # is problematic, we manually map the last (-1) eid to the last entry.
        oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1
        for qid, new_eid in tqdm(qid2topk_eid.items(), desc="Setting new ids"):
            old_eid = entity_db.get_eid(qid)
            oldeid2topkeid[old_eid] = new_eid
        assert oldeid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
        assert (oldeid2topkeid[-1] == new_num_topk_entities + 2 -
                1), "The -1 eid should still map to the last row"
        model_state_dict = set_nested_item(model_state_dict, ENTITY_EID_KEYS,
                                           oldeid2topkeid)
        # Remove the eid2reg value as that was with the old entity id mapping
        try:
            model_state_dict = set_nested_item(model_state_dict,
                                               ENTITY_REG_KEYS, None)
        except Exception:
            print(
                f"Could not remove regularization. If your model was trained with regularization mapping on "
                f"the learned entity embedding, this should not happen.")
        print(model_state_dict["module_pool"]["learned"].keys())
        state_dict["model"] = model_state_dict
        print(f"Saving model at {args.save_model_path}")
        torch.save(state_dict, args.save_model_path)
    print(
        f"Saving topk to eid at {os.path.join(args.entity_dir, 'entity_mappings', args.save_qid2topk_file)}"
    )
    utils.dump_json_file(
        os.path.join(args.entity_dir, "entity_mappings",
                     args.save_qid2topk_file),
        qid2topk_eid,
    )

    if args.model_config is not None:
        modify_config(
            args.model_config,
            args.save_model_config,
            args.save_model_path,
            os.path.join(args.entity_dir, "entity_mappings",
                         args.save_qid2topk_file),
            args.perc_emb_drop,
        )