Example #1
    def save(self, save_dir):
        """Dumps the entity symbols.

        Args:
            save_dir: directory string to save

        Returns:
        """
        self._sort_alias_cands()
        utils.ensure_dir(save_dir)
        utils.dump_json_file(
            filename=os.path.join(save_dir, "config.json"),
            contents={
                "max_candidates": self.max_candidates,
                "datetime": str(datetime.now()),
            },
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, self.alias_cand_map_file),
            contents=self._alias2qids,
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, "qid2title.json"), contents=self._qid2title
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, "qid2eid.json"), contents=self._qid2eid
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, self.alias_idx_file),
            contents=self._alias2id,
        )
Example #2
    def test_build_type_table_too_many_types(self):
        type_data = {"Q1": [1, 2, 3], "Q2": [4, 5, 6], "Q3": [], "Q4": [7, 8, 9]}
        type_vocab = {
            "T1": 1,
            "T2": 2,
            "T3": 3,
            "T4": 4,
            "T5": 5,
            "T6": 6,
            "T7": 7,
            "T8": 8,
            "T9": 9,
        }
        utils.dump_json_file(self.type_file, type_data)
        utils.dump_json_file(self.type_vocab_file, type_vocab)

        true_type_table = torch.tensor([[0], [1], [4], [0], [7], [0]]).long()
        true_type2row = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9}
        pred_type_table, type2row, max_labels = TypeEmb.build_type_table(
            self.type_file,
            self.type_vocab_file,
            max_types=1,
            entity_symbols=self.entity_symbols,
        )
        assert torch.equal(pred_type_table, true_type_table)
        self.assertDictEqual(true_type2row, type2row)
        # there are 9 real types, so including the unk and pad types we expect type indices up to 10
        assert max_labels == 10
Example #3
def compute_histograms(save_dir, entity_symbols):
    al_counts = Counter()
    for al in entity_symbols.get_all_aliases():
        num_entities = len(entity_symbols.get_qid_cands(al))
        al_counts.update([num_entities])
    utils.dump_json_file(filename=os.path.join(save_dir, "candidate_counts.json"), contents=al_counts)
    return
Example #4
def main():
    args = parse_args()
    print(json.dumps(vars(args), indent=4))
    assert 0 < args.perc_emb_drop < 1, f"perc_emb_drop must be between 0 and 1"
    state_dict, model_state_dict = load_statedict(args.init_checkpoint)
    assert ENTITY_EMB_KEY in model_state_dict
    print(f"Loading entity symbols from {os.path.join(args.entity_dir, args.entity_map_dir)}")
    entity_db = EntitySymbols(os.path.join(args.entity_dir, args.entity_map_dir), args.alias_cand_map_file)
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)
    print(f"Filtering qids")
    qid2topk_eid, old2new_eid, toes_eid, new_num_topk_entities = filter_qids(args.perc_emb_drop, entity_db, qid2count)
    print(f"Filtering embeddings")
    model_state_dict = filter_embs(new_num_topk_entities, entity_db, old2new_eid, qid2topk_eid, toes_eid, model_state_dict)
    # Generate the new old2new_eid weight vector to save in model_state_dict
    oldeid2topkeid = torch.arange(0, entity_db.num_entities_with_pad_and_nocand)
    # The +2 accounts for the pad and unk rows. Because -1 causes issues when indexing the entity embeddings, we manually map it to the last entry.
    oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1
    for qid, new_eid in qid2topk_eid.items():
        old_eid = entity_db.get_eid(qid)
        oldeid2topkeid[old_eid] = new_eid
    assert oldeid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
    assert oldeid2topkeid[-1] == new_num_topk_entities+2-1, "The -1 eid should still map to the last row"
    model_state_dict[ENTITY_EID_KEY] = oldeid2topkeid
    state_dict["model"] = model_state_dict
    print(f"Saving model at {args.save_checkpoint}")
    torch.save(state_dict, args.save_checkpoint)
    print(f"Saving entity_db at {os.path.join(args.entity_dir, args.entity_map_dir, args.save_qid2topk_file)}")
    utils.dump_json_file(os.path.join(args.entity_dir, args.entity_map_dir, args.save_qid2topk_file), qid2topk_eid)
Example #5
def main():
    args = parse_args()
    print(ujson.dumps(vars(args), indent=4))
    num_processes = int(0.8 * multiprocessing.cpu_count())
    print(f"Getting slice counts from {args.train_file}")
    qid_cnts = get_counts(num_processes, args.train_file)
    utils.dump_json_file(args.out_file, qid_cnts)
Example #6
def compute_occurrences(save_dir,
                        data_file,
                        entity_dump,
                        lower,
                        strip,
                        num_workers=8):
    global all_aliases
    all_aliases = get_all_aliases(entity_dump._alias2qids)

    # divide up data into chunks
    num_lines = get_num_lines(data_file)
    num_processes = min(num_workers, int(multiprocessing.cpu_count()))
    logging.info(f'Using {num_processes} workers...')
    chunk_size = int(np.ceil(num_lines / (num_processes)))
    chunk_file_path = os.path.join(save_dir, 'tmp')
    utils.ensure_dir(chunk_file_path)
    chunk_infiles = [
        os.path.join(f'{chunk_file_path}', f'data_chunk_{chunk_id}_in.jsonl')
        for chunk_id in range(num_processes)
    ]
    chunk_text_data(data_file, chunk_infiles, chunk_size, num_lines)

    pool = multiprocessing.Pool(processes=num_processes)
    subprocess_args = [[chunk_infiles[i], lower, strip]
                       for i in range(num_processes)]
    results = pool.map(compute_occurrences_single, subprocess_args)
    pool.close()
    pool.join()
    logging.info('Finished collecting counts')
    logging.info('Merging counts....')
    # merge counters together
    ent_occurrences = Counter()
    # alias histogram
    alias_occurrences = Counter()
    # alias text occurrences
    alias_text_occurrences = Counter()
    # number of aliases per sentence
    alias_pair_occurrences = Counter()
    # alias|entity histogram
    alias_entity_pair = Counter()
    for result_set in results:
        ent_occurrences += result_set['ent_occurrences']
        alias_occurrences += result_set['alias_occurrences']
        alias_text_occurrences += result_set['alias_text_occurrences']
        alias_pair_occurrences += result_set['alias_pair_occurrences']
        alias_entity_pair += result_set['alias_entity_pair']
    # save counters
    utils.dump_json_file(filename=os.path.join(save_dir, "entity_count.json"),
                         contents=ent_occurrences)
    utils.dump_json_file(filename=os.path.join(save_dir, "alias_counts.json"),
                         contents=alias_occurrences)
    utils.dump_json_file(filename=os.path.join(save_dir,
                                               "alias_text_counts.json"),
                         contents=alias_text_occurrences)
    utils.dump_json_file(filename=os.path.join(save_dir,
                                               "alias_pair_occurrences.json"),
                         contents=alias_pair_occurrences)
    utils.dump_json_file(filename=os.path.join(save_dir,
                                               "alias_entity_counts.json"),
                         contents=alias_entity_pair)
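
The examples above dump collections.Counter objects directly. Counter is a dict subclass, so it serializes as a plain JSON object (non-string keys, such as the integer candidate counts, are converted to strings), but it is read back as a regular dict. A minimal standalone sketch with the standard json module, assuming the project's helpers behave the same way:

import json
from collections import Counter

alias_occurrences = Counter({"lincoln": 12, "abraham lincoln": 3})

# A Counter serializes like any other dict ...
with open("alias_counts.json", "w") as out_f:
    json.dump(alias_occurrences, out_f)

# ... but round-trips as a plain dict, so rebuild the Counter before merging further counts.
with open("alias_counts.json") as in_f:
    alias_occurrences = Counter(json.load(in_f))
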
Example #7
    def test_load_kg_adj_indices_json(self):
        kg_data = {"Q1": {"Q2": 100}, "Q3": {"Q2": 11}}
        utils.dump_json_file(self.kg_adj_json, kg_data)

        adj_out = KGIndices.build_kg_adj(
            kg_adj_file=self.kg_adj_json,
            entity_symbols=self.entity_symbols,
            threshold=10,
            log_weight=True,
        )
        adj_out_gold = nx.adjacency_matrix(
            nx.Graph(
                np.array(
                    [
                        [0, 0, 0, 0, 0, 0],
                        [0, 0, np.log(100), 0, 0, 0],
                        [0, np.log(100), 0, np.log(11), 0, 0],
                        [0, 0, np.log(11), 0, 0, 0],
                        [0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0],
                    ]
                )
            )
        )
        np.testing.assert_allclose(adj_out.toarray(), adj_out_gold.toarray())

        # We filter out weights 10 or lower
        kg_data = {"Q1": {"Q2": 100}, "Q3": {"Q2": 10}}
        utils.dump_json_file(self.kg_adj_json, kg_data)

        adj_out = KGIndices.build_kg_adj(
            kg_adj_file=self.kg_adj_json,
            entity_symbols=self.entity_symbols,
            threshold=10,
            log_weight=True,
        )
        adj_out_gold = nx.adjacency_matrix(
            nx.Graph(
                np.array(
                    [
                        [0, 0, 0, 0, 0, 0],
                        [0, 0, np.log(100), 0, 0, 0],
                        [0, np.log(100), 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0],
                        [0, 0, 0, 0, 0, 0],
                    ]
                )
            )
        )
        np.testing.assert_allclose(adj_out.toarray(), adj_out_gold.toarray())
Example #8
    def save(self, save_dir, prefix=""):
        """Dumps the type symbols.

        Args:
            save_dir: directory string to save
            prefix: prefix to add to the beginning of the file names

        Returns:
        """
        utils.ensure_dir(str(save_dir))
        utils.dump_json_file(
            filename=os.path.join(save_dir, "config.json"),
            contents={
                "max_types": self.max_types,
            },
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, f"{prefix}qid2typenames.json"),
            contents=self._qid2typenames,
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, f"{prefix}qid2typeids.json"),
            contents=self._qid2typeid,
        )
        utils.dump_json_file(
            filename=os.path.join(save_dir, f"{prefix}type_vocab.json"),
            contents=self._type_vocab,
        )
Example #9
def main(args, mode):
    multiprocessing.set_start_method("forkserver", force=True)
    # =================================
    # ARGUMENTS CHECK
    # =================================
    # distributed training
    assert (args.run_config.ngpus_per_node <= torch.cuda.device_count()) or (
        not torch.cuda.is_available()), 'Not enough GPUs per node.'
    world_size = args.run_config.ngpus_per_node * args.run_config.nodes
    if world_size > 1:
        args.run_config.distributed = True
    assert (args.run_config.distributed and world_size > 1) or (world_size
                                                                == 1)

    train_utils.setup_run_folders(args, mode)

    # check slice method
    assert args.train_config.slice_method in SLICE_METHODS, f"Your slice_method {args.train_config.slice_method} is not in {SLICE_METHODS}."
    train_utils.setup_train_heads_and_eval_slices(args)

    # check save step
    assert args.run_config.save_every_k_eval > 0, "save_every_k_eval must be set to a value > 0"

    # since eval, make sure resume model file is set and exists
    if mode == "eval" or mode == "dump_preds" or mode == "dump_embs":
        assert args.run_config.init_checkpoint != "", \
            f"You must specify a model checkpoint in run_config to run {mode}"
        assert os.path.exists(args.run_config.init_checkpoint),\
            f"The resume model file of {args.run_config.init_checkpoint} doesn't exist"

    if mode == "dump_preds" or mode == "dump_embs":
        assert args.run_config.perc_eval == 1.0, f"If you are running dump_preds or dump_embs, run_config.perc_eval must be 1.0. You have {args.run_config.perc_eval}"
        assert args.data_config.test_dataset.use_weak_label is True, f"We do not support dumping when the test dataset gold is set to false. You can filter the dataset and run with filtered data."

    utils.dump_json_file(filename=os.path.join(
        train_utils.get_save_folder(args.run_config), f"config_{mode}.json"),
                         contents=args)
    if args.run_config.distributed:
        mp.spawn(main_worker,
                 nprocs=args.run_config.ngpus_per_node,
                 args=(args, mode, world_size))
    else:
        main_worker(gpu=args.run_config.gpu,
                    args=args,
                    mode=mode,
                    world_size=world_size)
Example #10
def compute_type_occurrences(save_dir, prefix, entity_symbols, qid2typenames,
                             data_file):
    # type histogram
    type_occurances = defaultdict(int)
    # type intersection histogram (frequency of type co-occurring together)
    type_pair_occurances = defaultdict(int)
    # number of aliases in a sentence paired with the number of types shared among those aliases; for a single alias this is just its number of types
    num_al_num_match_type = defaultdict(int)
    type_pairs = defaultdict(int)
    with open(data_file, "r") as in_file:
        for line in in_file:
            line = json.loads(line.strip())
            aliases = line['aliases']
            qids = line['qids']
            all_ex_types = set()
            i = 0
            # for all pairs of qid, types, get the intersection of types
            # if the intersection > 1, write type pairs and increment
            for qid, alias in zip(qids, aliases):
                types = qid2typenames.get(qid, [])
                for ty in types:
                    type_occurances[ty] += 1
                if i == 0:
                    all_ex_types = set(types)
                else:
                    all_ex_types = all_ex_types.intersection(set(types))
                i += 1
            if len(aliases) > 1:
                num_al_num_match_type[
                    f"{len(aliases)}|{len(set(all_ex_types))}"] += 1
                type_pair_occurances[tuple(sorted(list(all_ex_types)))] += 1

            alias_subsets = list(combinations(qids, 2))
            for qid1, qid2 in alias_subsets:
                types1 = qid2typenames.get(qid1, [])
                types2 = qid2typenames.get(qid2, [])
                overlap_types = set(types1).intersection(set(types2))
                if len(overlap_types) > 0:
                    type_pairs[tuple(overlap_types)] += 1
    logging.info(f"Saving type data...")
    utils.dump_json_file(filename=os.path.join(
        save_dir, f"{prefix}_type_occurances.json"),
                         contents=type_occurances)
    utils.dump_json_file(filename=os.path.join(
        save_dir, f"{prefix}_type_pair_occurances.json"),
                         contents=type_pair_occurances)
    utils.dump_json_file(filename=os.path.join(
        save_dir, f"{prefix}_num_al_num_match_type.json"),
                         contents=num_al_num_match_type)
    utils.dump_json_file(filename=os.path.join(
        save_dir, f"{prefix}_num_type_pairs.json"),
                         contents=type_pairs)
Example #11
def get_slice_stats_hlp(args):
    i, lines, offset, temp_out_dir = args

    res = defaultdict(int)  # slice -> cnt
    slice_to_sent = defaultdict(set)  # slice -> sent_idx (for sampling)
    sent_to_slices = defaultdict(
        lambda: defaultdict(int))  # sent_idx -> slice -> cnt (for sampling)
    for line in tqdm(lines, total=len(lines),
                     desc=f"Processing lines for {i}"):
        line = ujson.loads(line)
        slices = line.get("slices", {})
        anchors = line["gold"]
        for slice_name in slices:
            for al_str in slices[slice_name]:
                if anchors[int(
                        al_str)] is True and slices[slice_name][al_str] > 0.5:
                    res[slice_name] += 1
                    slice_to_sent[slice_name].add(int(line["sent_idx_unq"]))
        sent_to_slices[int(line["sent_idx_unq"])].update(res)

    utils.dump_json_file(
        os.path.join(temp_out_dir, f"{FINAL_COUNTS_PREFIX}_{i}.json"), res)
    utils.dump_json_file(
        os.path.join(temp_out_dir, f"{FINAL_SENT_TO_SLICE_PREFIX}_{i}.json"),
        sent_to_slices,
    )
    utils.dump_json_file(
        os.path.join(temp_out_dir, f"{FINAL_SLICE_TO_SENT_PREFIX}_{i}.json"),
        slice_to_sent,
    )
    return i
Example #12
 def test_topk_embedding(self):
     topkqid2eid = {"Q1": 1, "Q2": 3, "Q3": 3, "Q4": 2}
     utils.dump_json_file(self.qid2topkeid, topkqid2eid)
     self.args.data_config.ent_embeddings[0]["args"]["perc_emb_drop"] = 0.5
     self.args.data_config.ent_embeddings[0]["args"][
         "qid2topk_eid"
     ] = self.qid2topkeid
     learned_emb = TopKEntityEmb(
         main_args=self.args,
         emb_args=self.args.data_config.ent_embeddings[0]["args"],
         entity_symbols=self.entity_symbols,
         key="learned",
         cpu=True,
         normalize=False,
         dropout1d_perc=0.0,
         dropout2d_perc=0.0,
     )
     num_new_eids_padunk = 5
     eid2topkeid_gold = torch.tensor([0, 1, 3, 3, 2, num_new_eids_padunk - 1])
     assert torch.equal(eid2topkeid_gold, learned_emb.eid2topkeid)
     assert list(learned_emb.learned_entity_embedding.weight.shape) == [
         num_new_eids_padunk,
         self.learned_embedding_size,
     ]
Example #13
 def dump(self, save_dir):
     # memmapped files behave badly if you try to overwrite them in memory, which is what we'd be doing if load_dir == save_dir
     if self._loaded_from_dir is None or self._loaded_from_dir != save_dir:
         utils.ensure_dir(save_dir)
         utils.dump_json_file(filename=os.path.join(save_dir,
                                                    "fmt_types.json"),
                              contents=self._fmt_types)
         utils.dump_json_file(filename=os.path.join(save_dir,
                                                    "max_values.json"),
                              contents=self._max_values)
         utils.dump_json_file(filename=os.path.join(save_dir,
                                                    "vocabulary.json"),
                              contents=self._stoi)
         np.save(file=os.path.join(save_dir, "itos.npy"),
                 arr=self._itos,
                 allow_pickle=True)
         for tri_name in self._fmt_types:
             self._record_tris[tri_name].save(
                 os.path.join(save_dir, f'record_trie_{tri_name}.marisa'))
Example #14
    def __init__(
        self,
        main_args,
        dataset,
        use_weak_label,
        entity_symbols,
        dataset_threads,
        split="train",
    ):
        global_start = time.time()
        log_rank_0_info(logger,
                        f"Building slice dataset for {split} from {dataset}.")
        spawn_method = main_args.run_config.spawn_method
        data_config = main_args.data_config
        orig_spawn = multiprocessing.get_start_method()
        multiprocessing.set_start_method(spawn_method, force=True)
        self.slice_names = data_utils.get_eval_slices(data_config.eval_slices)
        self.get_slice_dt = lambda max_a2p: np.dtype([
            ("sent_idx", int),
            ("subslice_idx", int),
            ("alias_slice_incidence", int, (max_a2p, )),
            ("prob_labels", float, (max_a2p, )),
        ])
        self.get_storage = lambda max_a2p: np.dtype(
            [(slice_name, self.get_slice_dt(max_a2p))
             for slice_name in self.slice_names])
        # Folder for all mmap saved files
        save_dataset_folder = data_utils.get_save_data_folder(
            data_config, use_weak_label, dataset)
        utils.ensure_dir(save_dataset_folder)
        # Folder for temporary output files
        temp_output_folder = os.path.join(data_config.data_dir,
                                          data_config.data_prep_dir,
                                          f"prep_{split}_slice_files")
        utils.ensure_dir(temp_output_folder)
        # Input step 1
        create_ex_indir = os.path.join(temp_output_folder,
                                       "create_examples_input")
        utils.ensure_dir(create_ex_indir)
        # Input step 2
        create_ex_outdir = os.path.join(temp_output_folder,
                                        "create_examples_output")
        utils.ensure_dir(create_ex_outdir)
        # Meta data saved files
        meta_file = os.path.join(temp_output_folder, "meta_data.json")
        # File for standard training data
        hash = hashlib.sha1(str(
            self.slice_names).encode("UTF-8")).hexdigest()[:10]
        self.save_dataset_name = os.path.join(save_dataset_folder,
                                              f"ned_slices_{hash}.bin")
        self.save_data_config_name = os.path.join(save_dataset_folder,
                                                  "ned_slices_config.json")

        # =======================================================================================
        # SLICE DATA
        # =======================================================================================
        log_rank_0_debug(logger, "Loading dataset...")
        log_rank_0_debug(logger, f"Seeing if {self.save_dataset_name} exists")
        if data_config.overwrite_preprocessed_data or (not os.path.exists(
                self.save_dataset_name)):
            st_time = time.time()
            try:
                log_rank_0_info(
                    logger,
                    f"Building dataset from scratch. Saving to {save_dataset_folder}",
                )
                create_examples(
                    dataset,
                    create_ex_indir,
                    create_ex_outdir,
                    meta_file,
                    data_config,
                    dataset_threads,
                    self.slice_names,
                    use_weak_label,
                    split,
                )
                max_alias2pred = utils.load_json_file(
                    meta_file)["max_alias2pred"]
                convert_examples_to_features_and_save(
                    meta_file,
                    dataset_threads,
                    self.slice_names,
                    self.save_dataset_name,
                    self.get_storage(max_alias2pred),
                )
                utils.dump_json_file(self.save_data_config_name,
                                     {"max_alias2pred": max_alias2pred})

                log_rank_0_debug(
                    logger,
                    f"Finished prepping data in {time.time() - st_time}")
            except Exception as e:
                tb = traceback.TracebackException.from_exception(e)
                logger.error(e)
                logger.error("\n".join(tb.stack.format()))
                shutil.rmtree(save_dataset_folder, ignore_errors=True)
                raise

        log_rank_0_info(
            logger,
            f"Loading data from {self.save_dataset_name} and {self.save_data_config_name}",
        )
        max_alias2pred = utils.load_json_file(
            self.save_data_config_name)["max_alias2pred"]
        self.data, self.sent_to_row_id_dict = self.build_data_dict(
            self.save_dataset_name, self.get_storage(max_alias2pred))
        assert len(self.data) > 0
        assert len(self.sent_to_row_id_dict) > 0
        log_rank_0_debug(logger, f"Removing temporary output files")
        shutil.rmtree(temp_output_folder, ignore_errors=True)
        # Set spawn back to original/default, which is "fork" or "spawn". This is needed for the Meta.config to
        # be correctly passed in the collate_fn.
        multiprocessing.set_start_method(orig_spawn, force=True)
        log_rank_0_info(
            logger,
            f"Final slice data initialization time from {split} is {time.time() - global_start}s",
        )
Example #15
 def dump(self, save_dir, stats=None, args=None):
     if stats is None:
         stats = {}
     self._sort_alias_cands()
     utils.ensure_dir(save_dir)
     utils.dump_json_file(filename=os.path.join(save_dir, "config.json"), contents={"max_candidates":self.max_candidates,
                                                                                    "max_alias_len":self.max_alias_len,
                                                                                    "datetime": str(datetime.now())})
     utils.dump_json_file(filename=os.path.join(save_dir, self.alias_cand_map_file), contents=self._alias2qids)
     utils.dump_json_file(filename=os.path.join(save_dir, "qid2title.json"), contents=self._qid2title)
     utils.dump_json_file(filename=os.path.join(save_dir, "qid2eid.json"), contents=self._qid2eid)
     utils.dump_json_file(filename=os.path.join(save_dir, "filter_stats.json"), contents=stats)
     if args is not None:
         utils.dump_json_file(filename=os.path.join(save_dir, "args.json"), contents=vars(args))
Example #16
def create_examples(
    dataset,
    create_ex_indir,
    create_ex_outdir,
    meta_file,
    data_config,
    dataset_threads,
    slice_names,
    use_weak_label,
    split,
):
    """Creates examples from the raw input data.

    Args:
        dataset: dataset file
        create_ex_indir: temporary directory where input files are stored
        create_ex_outdir: temporary directory to store output files from method
        meta_file: metadata file to save the file names/paths for the next step in prep pipeline
        data_config: data config
        dataset_threads: number of threads
        slice_names: list of slices to evaluate on
        use_weak_label: whether to use weak labeling or not
        split: data split

    Returns:
    """
    log_rank_0_debug(logger, "Starting to extract subsentences")
    start = time.time()
    num_processes = min(dataset_threads,
                        int(0.8 * multiprocessing.cpu_count()))

    log_rank_0_debug(logger, f"Counting lines")
    total_input = sum(1 for _ in open(dataset))
    if num_processes == 1:
        out_file_name = os.path.join(create_ex_outdir,
                                     os.path.basename(dataset))
        constants_dict = {
            "slice_names": slice_names,
            "use_weak_label": use_weak_label,
            "max_aliases": data_config.max_aliases,
            "split": split,
            "train_in_candidates": data_config.train_in_candidates,
        }
        files_and_counts = {}
        res = create_examples_single(dataset, total_input, out_file_name,
                                     constants_dict)
        total_output = res["total_lines"]
        max_alias2pred = res["max_alias2pred"]
        files_and_counts[res["output_filename"]] = res["total_lines"]
    else:
        log_rank_0_info(
            logger,
            f"Strating to extract examples with {num_processes} threads")
        log_rank_0_debug(
            logger, "Parallelizing with " + str(num_processes) + " threads.")
        chunk_input = int(np.ceil(total_input / num_processes))
        log_rank_0_debug(
            logger,
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines",
        )
        total_input_from_chunks, input_files_dict = utils.chunk_file(
            dataset, create_ex_indir, chunk_input)

        input_files = list(input_files_dict.keys())
        input_file_lines = [input_files_dict[k] for k in input_files]
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        assert (
            total_input == total_input_from_chunks
        ), f"Lengths of files {total_input} doesn't mathc {total_input_from_chunks}"
        log_rank_0_debug(logger, f"Done chunking files")

        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=create_examples_initializer,
            initargs=[
                data_config,
                slice_names,
                use_weak_label,
                data_config.max_aliases,
                split,
                data_config.train_in_candidates,
            ],
        )
        total_output = 0
        max_alias2pred = 0
        input_args = list(zip(input_files, input_file_lines, output_files))
        # Store output files and counts for saving in next step
        files_and_counts = {}
        for res in pool.imap_unordered(create_examples_hlp,
                                       input_args,
                                       chunksize=1):
            total_output += res["total_lines"]
            max_alias2pred = max(max_alias2pred, res["max_alias2pred"])
            files_and_counts[res["output_filename"]] = res["total_lines"]
        pool.close()
    utils.dump_json_file(
        meta_file,
        {
            "num_mentions": total_output,
            "files_and_counts": files_and_counts,
            "max_alias2pred": max_alias2pred,
        },
    )
    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time()-start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}.",
    )
    return
Example #17
def compress_topk_embeddings(args):
    assert 0 < args.perc_emb_drop < 1, f"perc_emb_drop must be between 0 and 1"
    print(
        f"Loading entity symbols from {os.path.join(args.entity_dir, 'entity_mappings')}"
    )
    entity_db = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, "entity_mappings"))
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)
    print(f"Filtering qids")
    (
        qid2topk_eid,
        old2new_eid,
        toes_eid,
        new_num_topk_entities,
    ) = filter_qids(args.perc_emb_drop, entity_db, qid2count)
    if len(args.model_path) > 0:
        assert (args.save_model_path is not None
                and len(args.save_model_path) > 0
                ), f"If you give a model path, you must give a save checkpoint"
        print(f"Filtering embeddings")
        state_dict, model_state_dict = load_statedict(args.model_path)
        try:
            get_nested_item(model_state_dict, ENTITY_EMB_KEYS)
        except Exception:
            print(
                f"ERROR: Not all of {ENTITY_EMB_KEYS} are in model_state_dict")
            raise
        model_state_dict = filter_embs(
            new_num_topk_entities,
            entity_db,
            old2new_eid,
            qid2topk_eid,
            toes_eid,
            model_state_dict,
        )
        # Generate the new old2new_eid weight vector to save in model_state_dict
        oldeid2topkeid = torch.arange(
            0, entity_db.num_entities_with_pad_and_nocand)
        # The +2 accounts for the pad and unk rows. Because -1 causes issues when
        # indexing the entity embeddings, we manually map it to the last entry.
        oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1
        for qid, new_eid in tqdm(qid2topk_eid.items(), desc="Setting new ids"):
            old_eid = entity_db.get_eid(qid)
            oldeid2topkeid[old_eid] = new_eid
        assert oldeid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
        assert (oldeid2topkeid[-1] == new_num_topk_entities + 2 -
                1), "The -1 eid should still map to the last row"
        model_state_dict = set_nested_item(model_state_dict, ENTITY_EID_KEYS,
                                           oldeid2topkeid)
        # Remove the eid2reg value as that was with the old entity id mapping
        try:
            model_state_dict = set_nested_item(model_state_dict,
                                               ENTITY_REG_KEYS, None)
        except Exception:
            print(
                f"Could not remove regularization. If your model was trained with regularization mapping on "
                f"the learned entity embedding, this should not happen.")
        print(model_state_dict["module_pool"]["learned"].keys())
        state_dict["model"] = model_state_dict
        print(f"Saving model at {args.save_model_path}")
        torch.save(state_dict, args.save_model_path)
    print(
        f"Saving topk to eid at {os.path.join(args.entity_dir, 'entity_mappings', args.save_qid2topk_file)}"
    )
    utils.dump_json_file(
        os.path.join(args.entity_dir, "entity_mappings",
                     args.save_qid2topk_file),
        qid2topk_eid,
    )

    if args.model_config is not None:
        modify_config(
            args.model_config,
            args.save_model_config,
            args.save_model_path,
            os.path.join(args.entity_dir, "entity_mappings",
                         args.save_qid2topk_file),
            args.perc_emb_drop,
        )