Example #1
0
    def load_regularization_mapping(cls, data_config, entity_symbols,
                                    reg_file):
        """Reads in a csv file with columns [qid, regularization].

        In the forward pass, the entity id with associated qid will be
        regularized with probability regularization. QIDs in the csv that
        ``entity_symbols`` does not know are skipped; entities absent from the
        csv keep a regularization of 0.0 (no mask).

        The built tensor is saved in the data prep directory and loaded from
        there on later calls unless ``data_config.overwrite_preprocessed_data``
        is set.

        Args:
            data_config: data config
            entity_symbols: entity symbols; provides ``qid_exists``, ``get_eid``
                and ``num_entities_with_pad_and_nocand``
            reg_file: regularization csv file

        Returns: Tensor of length ``entity_symbols.num_entities_with_pad_and_nocand``
            where each value is the regularization value for that EID
        """
        # Replace "/" so the whole path collapses into one file-name component
        # and the cache file sits directly inside the prep directory.
        reg_str = os.path.splitext(os.path.basename(reg_file.replace("/",
                                                                     "_")))[0]
        prep_dir = data_utils.get_data_prep_dir(data_config)
        prep_file = os.path.join(
            prep_dir, f"entity_regularization_mapping_{reg_str}.pt")
        utils.ensure_dir(os.path.dirname(prep_file))
        log_rank_0_debug(logger,
                         f"Looking for regularization mapping in {prep_file}")
        if not data_config.overwrite_preprocessed_data and os.path.exists(
                prep_file):
            log_rank_0_debug(
                logger,
                f"Loading existing entity regularization mapping from {prep_file}",
            )
            start = time.time()
            eid2reg = torch.load(prep_file)
            log_rank_0_debug(
                logger,
                f"Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s",
            )
        else:
            start = time.time()
            log_rank_0_info(
                logger,
                f"Building entity regularization mapping from {reg_file}")
            qid2reg = pd.read_csv(reg_file)
            assert (
                "qid" in qid2reg.columns
                and "regularization" in qid2reg.columns
            ), f"Expected qid and regularization as the column names for {reg_file}"
            # default of no mask
            eid2reg_arr = [0.0] * entity_symbols.num_entities_with_pad_and_nocand
            # Walk the two needed columns directly; DataFrame.iterrows builds a
            # Series per row and is very slow on large regularization files.
            for qid, reg in zip(qid2reg["qid"], qid2reg["regularization"]):
                # Skip QIDs that are not part of the entity dump.
                if entity_symbols.qid_exists(qid):
                    eid2reg_arr[entity_symbols.get_eid(qid)] = reg
            eid2reg = torch.tensor(eid2reg_arr)
            torch.save(eid2reg, prep_file)
            log_rank_0_debug(
                logger,
                f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s.",
            )
        return eid2reg
Example #2
0
 def load_regularization_mapping(cls, main_args, num_types_with_pad_and_unk,
                                 type2row_dict, reg_file, log_func):
     """
     Reads in a csv file with columns [typeid, regularization].
     In the forward pass, the type with associated typeid will be regularized with probability regularization.

     Args:
         main_args: main arguments; ``main_args.data_config`` supplies
             ``overwrite_preprocessed_data`` and ``data_dir``
         num_types_with_pad_and_unk: length of the returned tensor (number of
             type rows, including pad and unk)
         type2row_dict: Dict from typeid to row index in the type embedding matrix
         reg_file: regularization csv file name; joined onto
             ``main_args.data_config.data_dir`` before reading
         log_func: callable that receives progress messages

     Returns: Tensor where each value is the regularization value for the
     corresponding type row (0.0 for types not listed in the csv). The tensor is
     saved in the data prep directory and reloaded on later calls unless
     ``overwrite_preprocessed_data`` is set.
     """
     reg_str = reg_file.split(".csv")[0]
     prep_dir = data_utils.get_data_prep_dir(main_args)
     prep_file = os.path.join(
         prep_dir, f'entity_regularization_mapping_{reg_str}.pt')
     # reg_str keeps any directory components of reg_file, so the parent of
     # prep_file may be a nested directory under prep_dir.
     utils.ensure_dir(os.path.dirname(prep_file))
     log_func(f"Looking for regularization mapping in {prep_file}")
     if (not main_args.data_config.overwrite_preprocessed_data
             and os.path.exists(prep_file)):
         log_func(
             f'Loading existing entity regularization mapping from {prep_file}'
         )
         start = time.time()
         typeid2reg = torch.load(prep_file)
         log_func(
             f'Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s'
         )
     else:
         start = time.time()
         reg_file = os.path.join(main_args.data_config.data_dir, reg_file)
         log_func(f'Building entity regularization mapping from {reg_file}')
         typeid2reg_raw = pd.read_csv(reg_file)
         assert "typeid" in typeid2reg_raw.columns and "regularization" in typeid2reg_raw.columns, f"Expected typeid and regularization as the column names for {reg_file}"
         # default of no mask
         typeid2reg_arr = [0.0] * num_types_with_pad_and_unk
         for _, row in typeid2reg_raw.iterrows():
             typeid_raw = int(row["typeid"])
             # Happens when we filter QIDs not in our entity dump and the max typeid is smaller than the total number
             if typeid_raw not in type2row_dict:
                 continue
             typeid2reg_arr[type2row_dict[typeid_raw]] = row["regularization"]
         typeid2reg = torch.tensor(typeid2reg_arr)
         torch.save(typeid2reg, prep_file)
         log_func(
             f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s."
         )
     return typeid2reg
Example #3
0
 def load_regularization_mapping(cls, main_args, entity_symbols, reg_file,
                                 log_func):
     """
     Reads in a csv file with columns [qid, regularization].
     In the forward pass, the entity id with associated qid will be regularized with probability regularization.

     QIDs in the csv that entity_symbols does not contain are skipped; entities
     absent from the csv keep a regularization of 0.0 (no mask).

     Args:
         main_args: main arguments; ``main_args.data_config.overwrite_preprocessed_data``
             controls whether a previously saved mapping is reused
         entity_symbols: entity symbols providing ``qid_exists``, ``get_eid``
             and ``num_entities_with_pad_and_nocand``
         reg_file: regularization csv file
         log_func: callable that receives progress messages

     Returns: Tensor of length ``entity_symbols.num_entities_with_pad_and_nocand``
     where each value is the regularization value for that EID.
     """
     reg_str = reg_file.split(".csv")[0]
     prep_dir = data_utils.get_data_prep_dir(main_args)
     prep_file = os.path.join(
         prep_dir, f'entity_regularization_mapping_{reg_str}.pt')
     utils.ensure_dir(os.path.dirname(prep_file))
     log_func(f"Looking for regularization mapping in {prep_file}")
     if (not main_args.data_config.overwrite_preprocessed_data
             and os.path.exists(prep_file)):
         log_func(
             f'Loading existing entity regularization mapping from {prep_file}'
         )
         start = time.time()
         eid2reg = torch.load(prep_file)
         log_func(
             f'Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s'
         )
     else:
         start = time.time()
         log_func(f'Building entity regularization mapping from {reg_file}')
         qid2reg = pd.read_csv(reg_file)
         assert "qid" in qid2reg.columns and "regularization" in qid2reg.columns, f"Expected qid and regularization as the column names for {reg_file}"
         # default of no mask
         eid2reg_arr = [0.0] * entity_symbols.num_entities_with_pad_and_nocand
         # Walk the two needed columns directly; DataFrame.iterrows builds a
         # Series per row and is slow on large files.
         for qid, reg in zip(qid2reg["qid"], qid2reg["regularization"]):
             # The csv can list QIDs that were filtered out of the entity dump;
             # only look up QIDs that entity_symbols actually contains.
             if not entity_symbols.qid_exists(qid):
                 continue
             eid2reg_arr[entity_symbols.get_eid(qid)] = reg
         eid2reg = torch.tensor(eid2reg_arr)
         torch.save(eid2reg, prep_file)
         log_func(
             f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s."
         )
     return eid2reg
Example #4
0
    def __init__(self,
                 args,
                 use_weak_label,
                 input_src,
                 dataset_name,
                 is_writer,
                 distributed,
                 word_symbols,
                 entity_symbols,
                 slice_dataset=None,
                 dataset_is_eval=False):
        """Load the memory-mapped dataset, building it first when needed.

        The data is (re)built with ``prep_data`` when
        ``args.data_config.overwrite_preprocessed_data`` is set or any prepped
        file (dataset, sent_idx json, storage type, batch prep config) is
        missing. Only the writer process builds; in the distributed case all
        processes then wait on ``dist.barrier()``. The dataset and any
        batch-prepped embedding files are then opened read-only with
        ``np.memmap``.

        Args:
            args: full config; ``data_config`` and ``train_config`` are read
            use_weak_label: forwarded to ``prep_data``
            input_src: input source passed to ``prep_data`` when building
            dataset_name: path of the memory-mapped dataset file
            is_writer: whether this process builds the data
            distributed: whether to synchronize processes with a barrier
            word_symbols: not used in this method
            entity_symbols: entity symbols (candidate counts, embeddings, alias table)
            slice_dataset: stored on the instance as-is
            dataset_is_eval: eval dataset flag; disables random NIL masking
        """
        # Need to save args to reinstantiate logger
        self.args = args
        self.logger = logging_utils.get_logger(args)
        # Number of candidates, including NIL if a NIL model (train_in_candidates is False)
        self.K = entity_symbols.max_candidates + (
            not args.data_config.train_in_candidates)
        self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
        self.dataset_name = dataset_name
        self.slice_dataset = slice_dataset
        self.dataset_is_eval = dataset_is_eval
        # Slice names used for eval slices and a slicing model
        self.slice_names = train_utils.get_data_slices(args, dataset_is_eval)
        self.storage_type_file = data_utils.get_storage_file(self.dataset_name)
        # Mappings from sent_idx to row_id in dataset
        self.sent_idx_file = os.path.splitext(
            dataset_name)[0] + "_sent_idx.json"
        self.type_pred = False
        if args.data_config.type_prediction.use_type_pred:
            self.type_pred = True
            self.eid2typeid, self.num_types_with_pad = self.load_coarse_type_table(
                args, entity_symbols)
        # Load memory mapped file
        self.logger.info("Loading dataset...")
        self.logger.debug("Seeing if " + dataset_name + " exists")
        if (args.data_config.overwrite_preprocessed_data
                or (not os.path.exists(self.dataset_name))
                or (not os.path.exists(self.sent_idx_file))
                or (not os.path.exists(self.storage_type_file))
                or (not os.path.exists(
                    data_utils.get_batch_prep_config(self.dataset_name)))):
            start = time.time()
            self.logger.debug(f"Building dataset with {input_src}")
            # Only prep data once per node
            if is_writer:
                prep_data(args,
                          use_weak_label=use_weak_label,
                          dataset_is_eval=self.dataset_is_eval,
                          input_src=input_src,
                          dataset_name=dataset_name,
                          prep_dir=data_utils.get_data_prep_dir(args))
            if distributed:
                # Make sure all processes wait for data to be created
                dist.barrier()
            self.logger.debug(
                f"Finished building and saving dataset in {round(time.time() - start, 2)}s."
            )

        start = time.time()

        # Storage type for loading memory mapped file of dataset.
        # The file is produced by the local prep step; pickle must never be
        # pointed at untrusted input. The `with` block closes the handle.
        with open(self.storage_type_file, 'rb') as storage_type_f:
            self.storage_type = pickle.load(storage_type_f)

        self.data = np.memmap(self.dataset_name,
                              dtype=self.storage_type,
                              mode='r')
        self.data_len = len(self.data)

        # Mapping from sentence idx to rows in the dataset (indices).
        # Needed when sampling sentence indices from slices for evaluation.
        # JSON keys are strings, so convert them back to ints.
        sent_idx_to_idx_str = utils.load_json_file(self.sent_idx_file)
        self.sent_idx_to_idx = {
            int(i): val
            for i, val in sent_idx_to_idx_str.items()
        }
        self.logger.info("Finished loading dataset.")

        # Stores info about the batch prepped embedding memory mapped files and their shapes and datatypes
        # so we can load them
        self.batch_prep_config = utils.load_json_file(
            data_utils.get_batch_prep_config(self.dataset_name))
        self.batch_prepped_emb_files = {}
        self.batch_prepped_emb_file_names = {}
        for emb in args.data_config.ent_embeddings:
            if 'batch_prep' in emb and emb['batch_prep']:
                assert emb.key in self.batch_prep_config, f'Need to prep {emb.key}. Please call prep instead of run with batch_prep_embeddings set to true.'
                # The prepped file is looked up next to the dataset file, using
                # only the base name recorded in the config.
                self.batch_prepped_emb_file_names[emb.key] = os.path.join(
                    os.path.dirname(self.dataset_name),
                    os.path.basename(
                        self.batch_prep_config[emb.key]['file_name']))
                self.batch_prepped_emb_files[emb.key] = np.memmap(
                    self.batch_prepped_emb_file_names[emb.key],
                    dtype=self.batch_prep_config[emb.key]['dtype'],
                    shape=tuple(self.batch_prep_config[emb.key]['shape']),
                    mode='r')
                assert len(self.batch_prepped_emb_files[emb.key]) == self.data_len,\
                    f'Preprocessed emb data file {self.batch_prep_config[emb.key]["file_name"]} does not match length of main data file.'

        # Stores embeddings that we compute on the fly; these are embeddings where batch_on_the_fly is set to true.
        self.batch_on_the_fly_embs = {}
        for emb in args.data_config.ent_embeddings:
            if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
                mod, load_class = import_class("bootleg.embeddings",
                                               emb.load_class)
                try:
                    self.batch_on_the_fly_embs[emb.key] = getattr(
                        mod, load_class)(main_args=args,
                                         emb_args=emb['args'],
                                         entity_symbols=entity_symbols,
                                         model_device=None,
                                         word_symbols=None,
                                         key=emb.key)
                except AttributeError as e:
                    self.logger.warning(
                        f'No prep method found for {emb.load_class} with error {e}'
                    )
                except Exception as e:
                    # Best-effort: keep going without this embedding, but report
                    # through the logger instead of a bare print to stdout.
                    self.logger.error(
                        f'Failed to build batch_on_the_fly embedding {emb.key} ({emb.load_class}) with error {e}'
                    )
        # The data in this table shouldn't be pickled since we delete it in the class __getstate__
        self.alias2entity_table = AliasEntityTable(
            args=args, entity_symbols=entity_symbols)
        # Random NIL percent
        self.mask_perc = args.train_config.random_nil_perc
        self.random_nil = False
        # Don't want to random mask for eval
        if not dataset_is_eval:
            # Whether to use a random NIL training regime
            self.random_nil = args.train_config.random_nil
            if self.random_nil:
                self.logger.info(
                    f'Using random nils during training with {self.mask_perc} percent'
                )
Example #5
0
    def load_regularization_mapping(cls, data_config,
                                    num_types_with_pad_and_unk, type2row_dict,
                                    reg_file):
        """Reads in a csv file with columns [typeid, regularization].

        In the forward pass, the type with associated typeid will be
        regularized with probability regularization. Typeids missing from
        ``type2row_dict`` are skipped; type rows absent from the csv keep a
        regularization of 0.0 (no mask).

        The built tensor is saved in the data prep directory and loaded from
        there on later calls unless ``data_config.overwrite_preprocessed_data``
        is set.

        Args:
            data_config: data config
            num_types_with_pad_and_unk: number of types including pad and unk
                (length of the returned tensor)
            type2row_dict: Dict from typeID to row id in the type embedding matrix
            reg_file: regularization csv file

        Returns: Tensor where each value is the regularization value for the
            corresponding type row
        """
        # Replace "/" so the whole path collapses into one file-name component
        # and the cache file sits directly inside the prep directory.
        reg_str = os.path.splitext(os.path.basename(reg_file.replace("/",
                                                                     "_")))[0]
        prep_dir = data_utils.get_data_prep_dir(data_config)
        prep_file = os.path.join(prep_dir,
                                 f"type_regularization_mapping_{reg_str}.pt")
        utils.ensure_dir(os.path.dirname(prep_file))
        log_rank_0_debug(logger,
                         f"Looking for regularization mapping in {prep_file}")
        if not data_config.overwrite_preprocessed_data and os.path.exists(
                prep_file):
            log_rank_0_debug(
                logger,
                f"Loading existing entity regularization mapping from {prep_file}",
            )
            start = time.time()
            typeid2reg = torch.load(prep_file)
            log_rank_0_debug(
                logger,
                f"Loaded existing entity regularization mapping in {round(time.time() - start, 2)}s",
            )
        else:
            start = time.time()
            log_rank_0_debug(
                logger,
                f"Building entity regularization mapping from {reg_file}")
            typeid2reg_raw = pd.read_csv(reg_file)
            assert (
                "typeid" in typeid2reg_raw.columns
                and "regularization" in typeid2reg_raw.columns
            ), f"Expected typeid and regularization as the column names for {reg_file}"
            # default of no mask
            typeid2reg_arr = [0.0] * num_types_with_pad_and_unk
            # Walk the two needed columns directly; DataFrame.iterrows builds a
            # Series per row and is slow on large files.
            for typeid_raw, reg in zip(typeid2reg_raw["typeid"],
                                       typeid2reg_raw["regularization"]):
                typeid_raw = int(typeid_raw)
                # Happens when we filter QIDs not in our entity db and the max typeid is smaller than the total number
                if typeid_raw not in type2row_dict:
                    continue
                typeid2reg_arr[type2row_dict[typeid_raw]] = reg
            # torch.Tensor always produces the default float dtype, regardless
            # of the Python/numpy scalar types in the list.
            typeid2reg = torch.Tensor(typeid2reg_arr)
            torch.save(typeid2reg, prep_file)
            log_rank_0_debug(
                logger,
                f"Finished building and saving entity regularization mapping in {round(time.time() - start, 2)}s.",
            )
        return typeid2reg