Code Example #1
File: data_utils.py Project: syyunn/bootleg
def get_eval_slice_subset_indices(args, eval_slice_dataset, dataset):
    logger = logging_utils.get_logger(args)
    logger.debug('Starting to sample indices')
    eval_slices = args.run_config.eval_slices
    # Get unique sentence indexes from samples for all slices
    # Will take union of all the data rows that map from these sentence indexes to eval
    sent_indices = set()
    for slice_name in eval_slices:
        # IF THE SEED CHANGES WHAT IS SAMPLED FOR DEV WILL CHANGE
        # randomly sample indices from slice
        slice_indexes = eval_slice_dataset.get_non_empty_sent_idxs(slice_name)
        perc_eval_examples = int(args.run_config.perc_eval * len(slice_indexes))
        data_len = max(perc_eval_examples, args.run_config.min_eval_size, 1)
        if data_len >= len(slice_indexes):
            # if requested sample is larger than the actual data, just use the whole slice
            random_indices = range(len(slice_indexes))
        else:
            random_indices = np.random.choice(len(slice_indexes), data_len, replace=False)
        for idx in random_indices:
            sent_indices.add(slice_indexes[idx])

    logger.debug('Starting to gather indices')
    # Get corresponding indices for key set
    indices = []
    for sent_idx in sent_indices:
        if sent_idx in dataset.sent_idx_to_idx:
            samples = dataset.sent_idx_to_idx[sent_idx]
            for data_idx in samples:
                indices.append(data_idx)
    logger.info(f'Sampled {len(indices)} indices from dataset (dev/test) for evaluation.')
    return indices
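For reference, a minimal standalone sketch of the per-slice sampling step above. The function name sample_slice_indices and the use of numpy.random.default_rng are illustrative and not part of bootleg (the original relies on the global np.random.choice, which is why the seed controls what is sampled for dev):

import numpy as np

def sample_slice_indices(slice_indexes, perc_eval, min_eval_size, seed=None):
    # Keep the whole slice if the requested sample is at least as large as the
    # slice; otherwise draw a random subset without replacement.
    rng = np.random.default_rng(seed)
    data_len = max(int(perc_eval * len(slice_indexes)), min_eval_size, 1)
    if data_len >= len(slice_indexes):
        return list(slice_indexes)
    chosen = rng.choice(len(slice_indexes), size=data_len, replace=False)
    return [slice_indexes[i] for i in chosen]

# Keep roughly 10% of a 50-element slice, but never fewer than 8 samples.
print(sample_slice_indices(list(range(100, 150)), perc_eval=0.1, min_eval_size=8, seed=0))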
Code Example #2
File: model.py Project: parakalan/bootleg
 def __init__(self, args, model_device, entity_symbols, word_symbols):
     super(Model, self).__init__()
     self.model_device = model_device
     self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
     self.logger = logging_utils.get_logger(args)
     # embeddings
     self.emb_layer = EmbeddingLayer(args, self.model_device,
                                     entity_symbols, word_symbols)
     self.type_pred = False
     if args.data_config.type_prediction.use_type_pred:
         self.type_pred = True
         # Add 1 for pad type
         self.type_prediction = TypePred(
             args.model_config.hidden_size,
             args.data_config.type_prediction.dim,
             args.data_config.type_prediction.num_types + 1)
     self.emb_combiner = EmbCombinerProj(args, self.emb_layer.emb_sizes,
                                         self.emb_layer.sent_emb_size,
                                         word_symbols, entity_symbols)
     # attention network
     mod, load_class = import_class("bootleg.layers.attn_networks",
                                    args.model_config.attn_load_class)
     self.attn_network = getattr(mod,
                                 load_class)(args, self.emb_layer.emb_sizes,
                                             self.emb_layer.sent_emb_size,
                                             entity_symbols, word_symbols)
     # slice heads
     self.slice_heads = self.get_slice_method(args, entity_symbols)
     self.freeze_components(args)
Code Example #3
 def __init__(self, main_args, emb_args, model_device, entity_symbols,
              word_symbols, word_emb, key):
     super(AvgTitleEmb, self).__init__(main_args=main_args,
                                       emb_args=emb_args,
                                       model_device=model_device,
                                       entity_symbols=entity_symbols,
                                       word_symbols=word_symbols,
                                       word_emb=word_emb,
                                       key=key)
     # self.word_emb = word_emb
     self.logger = logging_utils.get_logger(main_args)
     self.model_device = model_device
     self.word_emb = word_emb
     self.word_symbols = word_symbols
     self.orig_dim = word_emb.get_dim()
     self.merge_func = self.average_titles
     self.normalize = True
     self._dim = main_args.model_config.hidden_size
     self.requires_grad_title = word_emb.requires_grad
     if "freeze_word_emb_for_titles" in emb_args:
         self.requires_grad_title = not emb_args.freeze_word_emb_for_titles
     assert not self.requires_grad_title or word_emb.requires_grad,\
         "Inconsistent Args: You have to not freeze word embeddings for titles but freeze word embeddings"
     self.entity2titleid_table = self.prep(main_args=main_args,
                                           word_symbols=word_symbols,
                                           entity_symbols=entity_symbols,
                                           log_func=self.logger.debug)
     self.entity2titleid_table = self.entity2titleid_table.to(model_device)
Code Example #4
 def __setstate__(self, state):
     self.__dict__.update(state)
     self.alias2entity_table = torch.tensor(
         np.memmap(self.prep_file,
                   dtype='int64',
                   mode='r',
                   shape=(self.num_aliases_with_pad, self.K)))
     self.logger = logging_utils.get_logger(self.args)
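The __setstate__ hooks in these examples recreate the logger after unpickling (for example when a dataset object is shipped to DataLoader worker processes), presumably because a matching __getstate__ (not shown here) removes it. A minimal sketch of that pattern with the standard logging module; the class and attribute names are hypothetical:

import logging
import pickle

class Worker:
    def __init__(self, name):
        self.name = name
        self.logger = logging.getLogger(name)

    def __getstate__(self):
        # Drop the logger (its handlers may hold open streams that cannot be
        # pickled) and keep only the state needed to rebuild it.
        state = self.__dict__.copy()
        state.pop("logger", None)
        return state

    def __setstate__(self, state):
        # Restore attributes, then recreate the logger from the saved name.
        self.__dict__.update(state)
        self.logger = logging.getLogger(self.name)

w = pickle.loads(pickle.dumps(Worker("demo")))
w.logger.warning("logger was rebuilt in __setstate__")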
Code Example #5
File: wiki_slices.py Project: syyunn/bootleg
 def __init__(self, args, use_weak_label, input_src, dataset_name,
              is_writer, distributed, dataset_is_eval):
     self.logger = logging_utils.get_logger(args)
     self.args = args
     self.dataset_is_eval = dataset_is_eval
     self.storage_type = None
     self.dataset_name = dataset_name
     self.config_dataset_name = data_utils.get_slice_storage_file(
         dataset_name)
     self.sent_idx_to_idx_file = data_utils.get_sent_idx_file(dataset_name)
     # load memory mapped files
     self.logger.info(f"Loading slices...")
     self.logger.debug("Seeing if " + dataset_name + " exists")
     start = time.time()
     if (not args.data_config.overwrite_preprocessed_data
             and os.path.exists(self.dataset_name)
             and os.path.exists(self.config_dataset_name)
             and os.path.exists(self.sent_idx_to_idx_file)):
         self.logger.debug(f"Will load existing dataset {dataset_name}")
     else:
         self.logger.debug(f"Building dataset with {input_src}")
         # only prep data once per node
         if is_writer:
             self.storage_type = prep_slice(
                 args=args,
                 file=os.path.basename(input_src),
                 use_weak_label=use_weak_label,
                 dataset_is_eval=dataset_is_eval,
                 dataset_name=self.dataset_name,
                 sent_idx_file=self.sent_idx_to_idx_file,
                 storage_config=self.config_dataset_name,
                 logger=self.logger)
             np.save(self.config_dataset_name,
                     self.storage_type,
                     allow_pickle=True)
         if distributed:
             # Make sure all processes wait for data to be created
             dist.barrier()
         self.logger.debug(
             f"Finished building and saving dataset in {round(time.time() - start, 2)}s."
         )
     self.storage_type = np.load(self.config_dataset_name,
                                 allow_pickle=True).item()
     self.sent_idx_arr = np.memmap(self.sent_idx_to_idx_file,
                                   dtype=np.int64,
                                   mode='r')
     st = time.time()
     # Load and reformat it to be the proper recarray shape of # rows x 1
     self.data = np.expand_dims(np.memmap(self.dataset_name,
                                          dtype=self.storage_type,
                                          mode='r').view(np.recarray),
                                axis=1)
     assert len(self.data) > 0
     assert len(self.sent_idx_arr) > 0
     self.logger.info(f"Finished loading slices.")
     self.data_len = len(self.data)
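A small self-contained illustration of the storage-type round trip used above: the structured dtype is saved with np.save(..., allow_pickle=True), reloaded with .item(), and then used to memory-map the raw slice file as a rows x 1 recarray. The file names and fields below are made up:

import numpy as np

# Save the structured dtype the way the code above does (np.save wraps it in a
# 0-d object array, so allow_pickle is required on both save and load).
storage_type = np.dtype([("sent_idx", "i8"), ("alias_idx", "i8")])
np.save("storage_type.npy", storage_type, allow_pickle=True)

# Write a tiny raw data file with that dtype.
rows = np.array([(0, 3), (1, 5)], dtype=storage_type)
rows.tofile("slice_data.bin")

# Reload the dtype and memory-map the file, reshaping to rows x 1 as above.
loaded_type = np.load("storage_type.npy", allow_pickle=True).item()
data = np.expand_dims(
    np.memmap("slice_data.bin", dtype=loaded_type, mode='r').view(np.recarray),
    axis=1)
print(data.shape, data[0, 0].sent_idx)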
Code Example #6
File: slice_heads.py Project: syyunn/bootleg
    def __init__(self, args, entity_symbols):
        super(SliceHeadsSBL, self).__init__()
        self.logger = logging_utils.get_logger(args)
        self.dropout = args.train_config.dropout
        if "use_ind_attn" in args.model_config.custom_args and args.model_config.custom_args.use_ind_attn:
            self.logger.info('Using attention only over indicator confidences')
            self.use_ind_attn = args.model_config.custom_args.use_ind_attn
        else:
            self.use_ind_attn = False

        # Debugging parameter to see what happens when we remove the "final_loss" merging head
        if "remove_final_loss" in args.model_config.custom_args:
            self.remove_final_loss = args.model_config.custom_args.remove_final_loss
        else:
            self.remove_final_loss = False
        self.logger.info(
            f"Remove final loss: {self.remove_final_loss} and sen's trick {self.use_ind_attn}"
        )

        # Softmax temperature
        self.temperature = args.train_config.softmax_temp
        self.hidden_size = args.model_config.hidden_size
        self.K = entity_symbols.max_candidates + (
            not args.data_config.train_in_candidates)
        self.M = args.data_config.max_aliases
        self.train_heads = args.train_config.train_heads

        # Predicts whether or not an example is in the slice
        self.indicator_heads = nn.ModuleDict()
        self.ind_alias_mha = nn.ModuleDict()

        # Creates slice expert representations
        self.transform_modules = nn.ModuleDict()

        for slice_head in self.train_heads:
            # Generates a BxMxH representation that gets fed to the linear layer for predictions
            # This does an attention over the alias word and sentence (with added alias-slice learned embedding) and
            # then does an attention over the alias candidates (with added entity-slice learned embedding)
            self.ind_alias_mha[slice_head] = AliasMHA(args)
            # Binary prediction of in the slice or not
            self.indicator_heads[slice_head] = nn.Linear(self.hidden_size, 2)
            # transform layer for each slice
            self.transform_modules[slice_head] = nn.Linear(
                self.hidden_size, self.hidden_size)

        # Shared prediction layer to get confidences
        self.shared_slice_pred_head = nn.Linear(self.hidden_size, 1)

        # Embedding for each slice head for the indicator (added to queries in the AliasMHA heads)
        self.slice_emb_ind_alias = nn.Embedding(len(self.train_heads),
                                                self.hidden_size)
        self.slice_emb_ind_ent = nn.Embedding(len(self.train_heads),
                                              self.hidden_size)

        # Final prediction layer
        self.final_pred_head = nn.Linear(self.hidden_size, 1)
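A compressed sketch of the head layout above: per-slice indicator and transform heads held in nn.ModuleDict plus a shared prediction layer. The AliasMHA attention, slice embeddings, and softmax temperature are omitted, and all names and sizes are illustrative:

import torch
import torch.nn as nn

class TinySliceHeads(nn.Module):
    def __init__(self, hidden_size, slice_names):
        super().__init__()
        # One binary indicator head and one transform per slice.
        self.indicator_heads = nn.ModuleDict(
            {s: nn.Linear(hidden_size, 2) for s in slice_names})
        self.transform_modules = nn.ModuleDict(
            {s: nn.Linear(hidden_size, hidden_size) for s in slice_names})
        # Shared prediction layer producing a confidence per alias.
        self.shared_pred_head = nn.Linear(hidden_size, 1)

    def forward(self, alias_reprs):
        # alias_reprs: B x M x H alias representations
        out = {}
        for s in self.indicator_heads:
            ind_logits = self.indicator_heads[s](alias_reprs)     # B x M x 2
            slice_repr = self.transform_modules[s](alias_reprs)   # B x M x H
            out[s] = (ind_logits, self.shared_pred_head(slice_repr))
        return out

heads = TinySliceHeads(hidden_size=8, slice_names=["base", "hard_slice"])
print({k: tuple(v[0].shape) for k, v in heads(torch.randn(2, 3, 8)).items()})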
Code Example #7
 def __init__(self, main_args, emb_args, model_device, entity_symbols,
              word_symbols, word_emb, key):
     super(TypeEmb, self).__init__(main_args=main_args,
                                   emb_args=emb_args,
                                   model_device=model_device,
                                   entity_symbols=entity_symbols,
                                   word_symbols=word_symbols,
                                   word_emb=word_emb,
                                   key=key)
     self.logger = logging_utils.get_logger(main_args)
     self.merge_func = self.average_types
     self.orig_dim = emb_args.type_dim
     # Function for merging multiple types
     if "merge_func" in emb_args:
         if emb_args.merge_func not in ["average", "softattn", "addattn"]:
             self.logger.warning(
                 f"{key}: You have set the type merge_func to be {emb_args.merge_func} but that is not in the allowable list of [average, sofftattn]"
             )
         elif emb_args.merge_func == "softattn":
             if "attn_hidden_size" in emb_args:
                 attn_hidden_size = emb_args.attn_hidden_size
             else:
                 attn_hidden_size = 100
             self.logger.debug(
                 f"{key}: Setting merge_func to be soft_attn in type emb with context size {attn_hidden_size}"
             )
             # Softmax of types using the sentence context
             self.soft_attn = SoftAttn(
                 emb_dim=self.orig_dim,
                 context_dim=main_args.model_config.hidden_size,
                 size=attn_hidden_size)
             self.merge_func = self.soft_attn_merge
         elif emb_args.merge_func == "addattn":
             if "attn_hidden_size" in emb_args:
                 attn_hidden_size = emb_args.attn_hidden_size
             else:
                 attn_hidden_size = 100
             self.logger.debug(
                 f"{key}: Setting merge_func to be add_attn in type emb with context size {attn_hidden_size}"
             )
             # Additive attention over types using the sentence context
             self.add_attn = PositionAwareAttention(
                 input_size=self.orig_dim,
                 attn_size=attn_hidden_size,
                 feature_size=0)
             self.merge_func = self.add_attn_merge
     self.normalize = True
     self.entity_symbols = entity_symbols
     self.max_types = emb_args.max_types
     self.eid2typeids_table, self.type2rowid, num_types_with_unk, self.prep_file = self.prep(
         main_args=main_args,
         emb_args=emb_args,
         entity_symbols=entity_symbols,
         log_func=self.logger.debug)
     self.num_types_with_pad_and_unk = num_types_with_unk + 1
     self.eid2typeids_table = self.eid2typeids_table.to(model_device)
Code Example #8
File: wiki_slices.py Project: syyunn/bootleg
 def __setstate__(self, state):
     self.__dict__.update(state)
     self.data = np.expand_dims(np.memmap(self.dataset_name,
                                          dtype=self.storage_type,
                                          mode='r').view(np.recarray),
                                axis=1)
     self.sent_idx_arr = np.memmap(self.sent_idx_to_idx_file,
                                   dtype=np.int64,
                                   mode='r')
     self.logger = logging_utils.get_logger(self.args)
Code Example #9
File: wiki_dataset.py Project: parakalan/bootleg
 def __setstate__(self, state):
     self.__dict__.update(state)
     self.data = np.memmap(self.dataset_name,
                           dtype=self.storage_type,
                           mode='r')
     self.batch_prepped_emb_files = {}
     for emb_name, file_name in self.batch_prepped_emb_file_names.items():
         self.batch_prepped_emb_files[emb_name] = np.memmap(
             self.batch_prepped_emb_file_names[emb_name],
             dtype=self.batch_prep_config[emb_name]['dtype'],
             shape=tuple(self.batch_prep_config[emb_name]['shape']),
             mode='r')
     self.logger = logging_utils.get_logger(self.args)
Code Example #10
File: model.py Project: syyunn/bootleg
 def __init__(self, args, model_device, entity_symbols, word_symbols):
     super(BaselineModel, self).__init__(args, model_device, entity_symbols,
                                         word_symbols)
     self.model_device = model_device
     self.logger = logging_utils.get_logger(args)
     mod, load_class = import_class("bootleg.layers.attn_networks",
                                    args.model_config.attn_load_class)
     self.emb_layer = EmbeddingLayerNoProj(args, self.model_device,
                                           entity_symbols, word_symbols)
     self.attn_network = getattr(mod,
                                 load_class)(args, self.emb_layer.emb_sizes,
                                             self.emb_layer.sent_emb_size,
                                             entity_symbols, word_symbols)
     self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
     self.freeze_components(args)
Code Example #11
File: attn_networks.py Project: syyunn/bootleg
 def __init__(self, args, embedding_sizes, sent_emb_size, entity_symbols,
              word_symbols):
     super(AttnNetwork, self).__init__()
     self.logger = logging_utils.get_logger(args)
     self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
     # Number of candidates
     self.K = entity_symbols.max_candidates + (
         not args.data_config.train_in_candidates)
     # Number of aliases
     self.M = args.data_config.max_aliases
     self.sent_emb_size = sent_emb_size
     self.sent_len = args.data_config.max_word_token_len + 2 * word_symbols.is_bert
     self.hidden_size = args.model_config.hidden_size
     self.num_heads = args.model_config.num_heads
     self.num_model_stages = args.model_config.num_model_stages
     assert self.num_model_stages > 0, f"You must have > 0 model stages. You have {self.num_model_stages}"
     self.num_fc_layers = args.model_config.num_fc_layers
     self.ff_inner_size = args.model_config.ff_inner_size
Code Example #12
 def __init__(self, args, entity_symbols):
     super(AliasEntityTable, self).__init__()
     self.args = args
     self.logger = logging_utils.get_logger(args)
     self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
     self.num_aliases_with_pad = len(entity_symbols.get_all_aliases()) + 1
     self.M = args.data_config.max_aliases
     self.K = entity_symbols.max_candidates + (
         not args.data_config.train_in_candidates)
     self.alias2entity_table, self.prep_file = self.prep(
         args,
         entity_symbols,
         num_aliases_with_pad=self.num_aliases_with_pad,
         num_cands_K=self.K,
         log_func=self.logger.debug)
     # Small sanity check that loading was done correctly. This isn't a catch-all, but it will catch cases where the wrong table was loaded or something else went wrong.
     assert np.all(
         np.array(self.alias2entity_table[-1]) == np.ones(self.K) * -1
     ), f"The last row of the alias table isn't all -1; something wasn't loaded correctly."
Code Example #13
File: data_utils.py Project: syyunn/bootleg
def create_dataloader(args, dataset, batch_size, eval_slice_dataset=None, world_size=None, rank=None):
    logger = logging_utils.get_logger(args)
    if eval_slice_dataset is not None and not args.run_config.distributed:
        indices = get_eval_slice_subset_indices(args, eval_slice_dataset=eval_slice_dataset, dataset=dataset)
        # Form a sampler for the indices drawn from eval_slice_dataset
        sampler = SubsetRandomSampler(indices)
    elif args.run_config.distributed:
        # Wrap the dataset so the sampled indices can be used with the distributed sampler
        if eval_slice_dataset is not None:
            indices = get_eval_slice_subset_indices(args, eval_slice_dataset=eval_slice_dataset, dataset=dataset)
            dataset = DistributedIndicesWrapper(dataset, torch.tensor(indices))
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    else:
        sampler = None
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            shuffle=(sampler is None),
                            sampler=sampler,
                            num_workers=args.run_config.dataloader_threads,
                            pin_memory=False)
    return dataloader, sampler
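A non-distributed sketch of the sampler logic above on a toy dataset: when a subset of indices is supplied, a SubsetRandomSampler drives the DataLoader and shuffle must stay False; the distributed branch would swap in DistributedSampler instead:

import torch
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

dataset = TensorDataset(torch.arange(10))

indices = [0, 2, 4, 6, 8]                   # e.g. rows drawn from eval slices
sampler = SubsetRandomSampler(indices)
loader = DataLoader(dataset, batch_size=2,
                    shuffle=(sampler is None),  # False whenever a sampler is set
                    sampler=sampler)
print([batch[0].tolist() for batch in loader])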
Code Example #14
File: base_emb.py Project: paper2code/bootleg
 def __init__(self, main_args, emb_args, model_device, entity_symbols,
              word_symbols, word_emb, key):
     super(EntityEmb, self).__init__()
     # Metaclasses are types that create classes
     # https://stackoverflow.com/questions/100003/what-are-metaclasses-in-python
     # We use this to enforce that a self.normalize attribute is instantiated in subclasses
     __metaclass__ = RequiredAttributes("normalize")
     self.logger = logging_utils.get_logger(main_args)
     self.entity_symbols = entity_symbols
     self.key = key
     self.dropout_perc = 0
     self.mask_perc = 0
     # Used for 2d dropout
     if MASK_PERC in emb_args:
         self.mask_perc = emb_args[MASK_PERC]
         self.logger.debug(
             f'Setting {self.key} mask perc to {self.mask_perc}')
     # Used for 1d dropout
     if "dropout" in emb_args:
         self.dropout_perc = emb_args.dropout
         self.logger.debug(f'Setting {self.key} dropout to {self.dropout_perc}')
Code Example #15
 def __init__(self, args, model_device, entity_symbols, word_symbols):
     super(Model, self).__init__()
     self.model_device = model_device
     self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
     self.logger = logging_utils.get_logger(args)
     # embeddings
     self.emb_layer = EmbeddingLayer(args, self.model_device,
                                     entity_symbols, word_symbols)
     self.emb_combiner = EmbCombinerProj(args, self.emb_layer.emb_sizes,
                                         self.emb_layer.sent_emb_size,
                                         word_symbols, entity_symbols)
     # attention network
     mod, load_class = import_class("bootleg.layers.attn_networks",
                                    args.model_config.attn_load_class)
     self.attn_network = getattr(mod,
                                 load_class)(args, self.emb_layer.emb_sizes,
                                             self.emb_layer.sent_emb_size,
                                             entity_symbols, word_symbols)
     # slice heads
     self.slice_heads = self.get_slice_method(args, entity_symbols)
     self.freeze_components(args)
Code Example #16
File: entity_embs.py Project: paper2code/bootleg
 def __init__(self, main_args, emb_args, model_device, entity_symbols,
              word_symbols, word_emb, key):
     super(LearnedEntityEmb, self).__init__(main_args=main_args,
                                            emb_args=emb_args,
                                            model_device=model_device,
                                            entity_symbols=entity_symbols,
                                            word_symbols=word_symbols,
                                            word_emb=word_emb,
                                            key=key)
     self.logger = logging_utils.get_logger(main_args)
     self.learned_embedding_size = emb_args.learned_embedding_size
     self.normalize = True
     self.learned_entity_embedding = nn.Embedding(
         entity_symbols.num_entities_with_pad_and_nocand,
         self.learned_embedding_size,
         padding_idx=-1,
         sparse=True)
     # If tail_init is false, all embeddings are randomly initialized.
     # If tail_init is true, we initialize all embeddings to be the same.
     tail_init = True
     if "tail_init" in emb_args:
         tail_init = emb_args.tail_init
     if tail_init:
         qid_count_dict = {}
         self.logger.debug(
             f"All learned entity embeddings are initialized to the same value."
         )
         init_vec = model_utils.init_tail_embeddings(
             self.learned_entity_embedding,
             qid_count_dict,
             entity_symbols,
             pad_idx=-1)
     else:
         self.logger.debug(
             f"All learned embeddings are randomly initialized.")
     self._dim = main_args.model_config.hidden_size
     self.dropout = nn.Dropout(self.dropout_perc)
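A standalone sketch of the tail_init idea: every non-pad embedding row starts from one shared vector. bootleg's model_utils.init_tail_embeddings is assumed to do something along these lines; this toy version is not its implementation:

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=6, embedding_dim=4, padding_idx=-1)
init_vec = torch.randn(1, emb.embedding_dim)
with torch.no_grad():
    # Copy the shared vector into every row, then zero the pad row.
    emb.weight.copy_(init_vec.expand_as(emb.weight))
    emb.weight[emb.padding_idx].zero_()
print(torch.allclose(emb.weight[0], emb.weight[1]))  # True: shared init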
Code Example #17
 def __init__(self,
              main_args,
              emb_args,
              model_device,
              entity_symbols,
              word_symbols=None,
              word_emb=None,
              key=""):
     super(KGWeightedAdjEmb, self).__init__(main_args=main_args,
                                            emb_args=emb_args,
                                            model_device=model_device,
                                            entity_symbols=entity_symbols,
                                            word_symbols=word_symbols,
                                            word_emb=word_emb,
                                            key=key)
     # needed to recreate logger
     self.main_args = main_args
     self.logger = logging_utils.get_logger(main_args)
     self._dim = 1
     self.model_device = model_device
     self.kg_adj, self.prep_file = self.prep(main_args=main_args,
                                             emb_args=emb_args,
                                             entity_symbols=entity_symbols,
                                             log_func=self.logger.debug)
Code Example #18
 def __setstate__(self, state):
     self.__dict__.update(state)
     self.logger = logging_utils.get_logger(self.main_args)
     # we can assume the adjacency matrix has already been built and saved
     self.kg_adj = scipy.sparse.load_npz(self.prep_file)
Code Example #19
File: base_sent_emb.py Project: syyunn/bootleg
 def __init__(self, emb_args, main_args, word_emb_dim, word_symbols):
     super(BaseSentEmbedding, self).__init__()
     self.logger = logging_utils.get_logger(main_args)
     self._key = "sentence"
     self._dim = word_emb_dim
Code Example #20
 def __init__(self,
              main_args,
              emb_args,
              model_device,
              entity_symbols,
              word_symbols=None,
              word_emb=None,
              key=""):
     super(KGRelEmb, self).__init__(main_args=main_args,
                                    emb_args=emb_args,
                                    model_device=model_device,
                                    entity_symbols=entity_symbols,
                                    word_symbols=word_symbols,
                                    word_emb=word_emb,
                                    key=key)
     # needed to recreate logger
     self.main_args = main_args
     self.logger = logging_utils.get_logger(main_args)
     self.model_device = model_device
     self.mask_candidates = True
     self.kg_adj, self.kg_relations, self.rel2rowid, self.prep_file_adj = self.prep(
         main_args=main_args,
         emb_args=emb_args,
         entity_symbols=entity_symbols,
         log_func=self.logger.debug)
     # initialize learned relation embedding
     self.num_relations_with_pad = len(self.kg_relations) + 1
     self._dim = emb_args.rel_dim
     # Sparse cannot be true for the relation embedding or we get distributed Gloo errors depending on the batch (Gloo EnforceNotMet...)
     self.relation_emb = torch.nn.Embedding(self.num_relations_with_pad,
                                            self._dim,
                                            padding_idx=0,
                                            sparse=False).to(model_device)
     self.merge_func = self.average_rels
     if "merge_func" in emb_args:
         if emb_args.merge_func not in ["average", "softattn", "addattn"]:
             self.logger.warning(
                 f"{key}: You have set the type merge_func to be {emb_args.merge_func} but that is not in the allowable list of [average, sofftattn]"
             )
         elif emb_args.merge_func == "softattn":
             if "attn_hidden_size" in emb_args:
                 attn_hidden_size = emb_args.attn_hidden_size
             else:
                 attn_hidden_size = 100
             self.sub_chunk = 3000
             self.logger.debug(
                 f"{key}: Setting merge_func to be soft_attn in relation emb with context size {attn_hidden_size} and sub_chunk of {self.sub_chunk}"
             )
             # Softmax of relations using the sentence context
             self.soft_attn = SoftAttn(
                 emb_dim=self._dim,
                 context_dim=main_args.model_config.hidden_size,
                 size=attn_hidden_size)
             self.merge_func = self.soft_attn_merge
         elif emb_args.merge_func == "addattn":
             if "attn_hidden_size" in emb_args:
                 attn_hidden_size = emb_args.attn_hidden_size
             else:
                 attn_hidden_size = 100
             self.logger.info(
                 f"{key}: Setting merge_func to be add_attn in relation emb with context size {attn_hidden_size}"
             )
             # Additive attention over relations using the sentence context
             self.add_attn = PositionAwareAttention(
                 input_size=self._dim,
                 attn_size=attn_hidden_size,
                 feature_size=0)
             self.merge_func = self.add_attn_merge
     self.normalize = True
     self.logger.debug(
         f'{key}: Using {self._dim}-dim relation embedding for {len(self.kg_relations)} relations with 1 pad relation. Normalize is {self.normalize}'
     )
Code Example #21
File: type_embs.py Project: syyunn/bootleg
 def __init__(self, main_args, emb_args, model_device, entity_symbols,
              word_symbols, word_emb, key):
     super(TypeEmb, self).__init__(main_args=main_args,
                                   emb_args=emb_args,
                                   model_device=model_device,
                                   entity_symbols=entity_symbols,
                                   word_symbols=word_symbols,
                                   word_emb=word_emb,
                                   key=key)
     self.logger = logging_utils.get_logger(main_args)
     self.merge_func = self.average_types
     self.orig_dim = emb_args.type_dim
     # Function for merging multiple types
     if "merge_func" in emb_args:
         if emb_args.merge_func not in ["average", "softattn", "addattn"]:
             self.logger.warning(
                 f"{key}: You have set the type merge_func to be {emb_args.merge_func} but that is not in the allowable list of [average, sofftattn]"
             )
         elif emb_args.merge_func == "softattn":
             if "attn_hidden_size" in emb_args:
                 attn_hidden_size = emb_args.attn_hidden_size
             else:
                 attn_hidden_size = 100
             # Softmax of types using the sentence context
             self.soft_attn = SoftAttn(
                 emb_dim=self.orig_dim,
                 context_dim=main_args.model_config.hidden_size,
                 size=attn_hidden_size)
             self.merge_func = self.soft_attn_merge
         elif emb_args.merge_func == "addattn":
             if "attn_hidden_size" in emb_args:
                 attn_hidden_size = emb_args.attn_hidden_size
             else:
                 attn_hidden_size = 100
             # Additive attention over types using the sentence context
             self.add_attn = PositionAwareAttention(
                 input_size=self.orig_dim,
                 attn_size=attn_hidden_size,
                 feature_size=0)
             self.merge_func = self.add_attn_merge
     self.normalize = True
     self.entity_symbols = entity_symbols
     self.max_types = emb_args.max_types
     self.eid2typeids_table, self.type2row_dict, num_types_with_unk, self.prep_file = self.prep(
         main_args=main_args,
         emb_args=emb_args,
         entity_symbols=entity_symbols,
         log_func=self.logger.debug)
     self.num_types_with_pad_and_unk = num_types_with_unk + 1
     self.eid2typeids_table = self.eid2typeids_table.to(model_device)
     # Regularization mapping goes from typeid to 2d dropout percent
     self.typeid2reg = None
     if "regularize_mapping" in emb_args:
         self.logger.debug(
             f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}"
         )
         self.typeid2reg = self.load_regularization_mapping(
             main_args, self.num_types_with_pad_and_unk, self.type2row_dict,
             emb_args.regularize_mapping, self.logger.debug)
         self.typeid2reg = self.typeid2reg.to(model_device)
     assert self.eid2typeids_table.shape[1] == emb_args.max_types, f"Something went wrong with loading type file." \
                                                          f" The given max types {emb_args.max_types} does not match that of type table {self.eid2typeids_table.shape[1]}"
     self.logger.debug(
         f"{key}: Type embedding with {self.max_types} types with dim {self.orig_dim}. Setting merge_func to be {self.merge_func.__name__} in type emb."
     )
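A toy sketch of the default average_types merge: each entity's (padded) type ids are looked up in a table and the corresponding type embeddings are averaged, ignoring the pad type. Table contents and sizes below are made up:

import torch
import torch.nn as nn

num_types_with_pad, type_dim = 5, 4
type_emb = nn.Embedding(num_types_with_pad, type_dim, padding_idx=0)
eid2typeids = torch.tensor([[1, 2, 0],    # entity 0 has two types
                            [3, 0, 0]])   # entity 1 has one type

embs = type_emb(eid2typeids)                       # num_entities x max_types x dim
mask = (eid2typeids != 0).unsqueeze(-1).float()    # ignore the pad type
avg = (embs * mask).sum(1) / mask.sum(1).clamp(min=1)
print(avg.shape)  # torch.Size([2, 4])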
Code Example #22
File: wiki_dataset.py Project: parakalan/bootleg
    def __init__(self,
                 args,
                 use_weak_label,
                 input_src,
                 dataset_name,
                 is_writer,
                 distributed,
                 word_symbols,
                 entity_symbols,
                 slice_dataset=None,
                 dataset_is_eval=False):
        # Need to save args to reinstantiate logger
        self.args = args
        self.logger = logging_utils.get_logger(args)
        # Number of candidates, including NIL if a NIL model (train_in_candidates is False)
        self.K = entity_symbols.max_candidates + (
            not args.data_config.train_in_candidates)
        self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
        self.dataset_name = dataset_name
        self.slice_dataset = slice_dataset
        self.dataset_is_eval = dataset_is_eval
        # Slice names used for eval slices and a slicing model
        self.slice_names = train_utils.get_data_slices(args, dataset_is_eval)
        self.storage_type_file = data_utils.get_storage_file(self.dataset_name)
        # Mappings from sent_idx to row_id in dataset
        self.sent_idx_file = os.path.splitext(
            dataset_name)[0] + "_sent_idx.json"
        self.type_pred = False
        if args.data_config.type_prediction.use_type_pred:
            self.type_pred = True
            self.eid2typeid, self.num_types_with_pad = self.load_coarse_type_table(
                args, entity_symbols)
        # Load memory mapped file
        self.logger.info("Loading dataset...")
        self.logger.debug("Seeing if " + dataset_name + " exists")
        if (args.data_config.overwrite_preprocessed_data
                or (not os.path.exists(self.dataset_name))
                or (not os.path.exists(self.sent_idx_file))
                or (not os.path.exists(self.storage_type_file))
                or (not os.path.exists(
                    data_utils.get_batch_prep_config(self.dataset_name)))):
            start = time.time()
            self.logger.debug(f"Building dataset with {input_src}")
            # Only prep data once per node
            if is_writer:
                prep_data(args,
                          use_weak_label=use_weak_label,
                          dataset_is_eval=self.dataset_is_eval,
                          input_src=input_src,
                          dataset_name=dataset_name,
                          prep_dir=data_utils.get_data_prep_dir(args))
            if distributed:
                # Make sure all processes wait for data to be created
                dist.barrier()
            self.logger.debug(
                f"Finished building and saving dataset in {round(time.time() - start, 2)}s."
            )

        start = time.time()

        # Storage type for loading memory mapped file of dataset
        self.storage_type = pickle.load(open(self.storage_type_file, 'rb'))

        self.data = np.memmap(self.dataset_name,
                              dtype=self.storage_type,
                              mode='r')
        self.data_len = len(self.data)

        # Mapping from sentence idx to rows in the dataset (indices).
        # Needed when sampling sentence indices from slices for evaluation.
        sent_idx_to_idx_str = utils.load_json_file(self.sent_idx_file)
        self.sent_idx_to_idx = {
            int(i): val
            for i, val in sent_idx_to_idx_str.items()
        }
        self.logger.info(f"Finished loading dataset.")

        # Stores info about the batch prepped embedding memory mapped files and their shapes and datatypes
        # so we can load them
        self.batch_prep_config = utils.load_json_file(
            data_utils.get_batch_prep_config(self.dataset_name))
        self.batch_prepped_emb_files = {}
        self.batch_prepped_emb_file_names = {}
        for emb in args.data_config.ent_embeddings:
            if 'batch_prep' in emb and emb['batch_prep']:
                assert emb.key in self.batch_prep_config, f'Need to prep {emb.key}. Please call prep instead of run with batch_prep_embeddings set to true.'
                self.batch_prepped_emb_file_names[emb.key] = os.path.join(
                    os.path.dirname(self.dataset_name),
                    os.path.basename(
                        self.batch_prep_config[emb.key]['file_name']))
                self.batch_prepped_emb_files[emb.key] = np.memmap(
                    self.batch_prepped_emb_file_names[emb.key],
                    dtype=self.batch_prep_config[emb.key]['dtype'],
                    shape=tuple(self.batch_prep_config[emb.key]['shape']),
                    mode='r')
                assert len(self.batch_prepped_emb_files[emb.key]) == self.data_len,\
                    f'Preprocessed emb data file {self.batch_prep_config[emb.key]["file_name"]} does not match length of main data file.'

        # Stores embeddings that we compute on the fly; these are embeddings where batch_on_the_fly is set to true.
        self.batch_on_the_fly_embs = {}
        for emb in args.data_config.ent_embeddings:
            if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
                mod, load_class = import_class("bootleg.embeddings",
                                               emb.load_class)
                try:
                    self.batch_on_the_fly_embs[emb.key] = getattr(
                        mod, load_class)(main_args=args,
                                         emb_args=emb['args'],
                                         entity_symbols=entity_symbols,
                                         model_device=None,
                                         word_symbols=None,
                                         key=emb.key)
                except AttributeError as e:
                    self.logger.warning(
                        f'No prep method found for {emb.load_class} with error {e}'
                    )
                except Exception as e:
                    print("ERROR", e)
        # The data in this table shouldn't be pickled since we delete it in the class __getstate__
        self.alias2entity_table = AliasEntityTable(
            args=args, entity_symbols=entity_symbols)
        # Random NIL percent
        self.mask_perc = args.train_config.random_nil_perc
        self.random_nil = False
        # Don't want to random mask for eval
        if not dataset_is_eval:
            # Whether to use a random NIL training regime
            self.random_nil = args.train_config.random_nil
            if self.random_nil:
                self.logger.info(
                    f'Using random nils during training with {self.mask_perc} percent'
                )
Code Example #23
File: scorer.py Project: syyunn/bootleg
 def __init__(self, args=None, model_device=None):
     self.model_device = model_device
     self.logger = logging_utils.get_logger(args)
     self.crit_pred = nn.NLLLoss(ignore_index=-1)
     self.type_pred = nn.CrossEntropyLoss(ignore_index=-1)
     self.weights = {train_utils.get_slice_head_ind_name(slice_name):None for slice_name in args.train_config.train_heads}
Code Example #24
File: entity_embs.py Project: syyunn/bootleg
 def __init__(self, main_args, emb_args, model_device, entity_symbols,
              word_symbols, word_emb, key):
     super(TopKEntityEmb, self).__init__(main_args=main_args,
                                         emb_args=emb_args,
                                         model_device=model_device,
                                         entity_symbols=entity_symbols,
                                         word_symbols=word_symbols,
                                         word_emb=word_emb,
                                         key=key)
     self.logger = logging_utils.get_logger(main_args)
     self.learned_embedding_size = emb_args.learned_embedding_size
     self.normalize = True
     assert "perc_emb_drop" in emb_args, f"To use TopKEntityEmb we need perc_emb_drop to be in the args. This gives the percentage of embeddings" \
                                         f" removed."
     # We remove perc_emb_drop percent of the embeddings and add one row that represents the shared embedding for all dropped (tail) entities
     num_topk_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand - int(
         emb_args.perc_emb_drop * entity_symbols.num_entities) + 1
     # Mapping from each entity QID to its new (top-K) eid
     qid2topk_eid = {}
     eid2topkeid = torch.arange(
         0, entity_symbols.num_entities_with_pad_and_nocand)
     # There are issues with using a -1 index into the embeddings, so we manually set it to the last row
     eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
     if "qid2topk_eid" not in emb_args:
         assert len(main_args.run_config.init_checkpoint) > 0, f"If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, " \
                                                               f"you must be loading a model from a checkpoint to build this index mapping"
     else:
         qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
         assert len(
             qid2topk_eid
         ) == entity_symbols.num_entities, f"You must have an item in qid2topk_eid for each qid in entity_symbols"
         for qid in entity_symbols.get_all_qids():
             old_eid = entity_symbols.get_eid(qid)
             new_eid = qid2topk_eid[qid]
             eid2topkeid[old_eid] = new_eid
         assert eid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
         assert eid2topkeid[
         -1] == num_topk_entities_with_pad_and_nocand - 1, "The -1 (last) eid should still map to the last top-K row"
     self.learned_entity_embedding = nn.Embedding(
         num_topk_entities_with_pad_and_nocand,
         self.learned_embedding_size,
         padding_idx=-1,
         sparse=True)
     # Keep this mapping so a topK model can simply be loaded without needing the new eid mapping
     self.register_buffer("eid2topkeid", eid2topkeid)
     # If tail_init is false, all embeddings are randomly initialized.
     # If tail_init is true, we initialize all embeddings to be the same.
     self.tail_init = True
     self.tail_init_zeros = False
     # None init vec will be random
     init_vec = None
     if "tail_init" in emb_args:
         self.tail_init = emb_args.tail_init
     if "tail_init_zeros" in emb_args:
         self.tail_init_zeros = emb_args.tail_init_zeros
         self.tail_init = False
         init_vec = torch.zeros(1, self.learned_embedding_size)
     assert not (self.tail_init and self.tail_init_zeros
                 ), f"Can only have one of tail_init or tail_init_zeros set"
     if self.tail_init or self.tail_init_zeros:
         if self.tail_init_zeros:
             self.logger.debug(
                 f"All learned entity embeddings are initialized to zero.")
         else:
             self.logger.debug(
                 f"All learned entity embeddings are initialized to the same value."
             )
         init_vec = model_utils.init_embeddings_to_vec(
             self.learned_entity_embedding, pad_idx=-1, vec=init_vec)
         vec_save_file = os.path.join(
             train_utils.get_save_folder(main_args.run_config),
             "init_vec_entity_embs.npy")
         self.logger.debug(f"Saving init vector to {vec_save_file}")
         np.save(vec_save_file, init_vec)
     else:
         self.logger.debug(
             f"All learned embeddings are randomly initialized.")
     self._dim = main_args.model_config.hidden_size
     self.eid2reg = None
     # Regularization mapping goes from eid to 2d dropout percent
     if "regularize_mapping" in emb_args:
         self.logger.warning(
             f"You are using regularization mapping with a topK entity embedding. This means all QIDs that are mapped to the same"
             f" EID will get the same regularization value.")
         self.logger.debug(
             f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}"
         )
         self.eid2reg = self.load_regularization_mapping(
             main_args, qid2topk_eid, num_topk_entities_with_pad_and_nocand,
             emb_args.regularize_mapping, self.logger.debug)
         self.eid2reg = self.eid2reg.to(model_device)
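A sketch of the eid remapping pattern above: the index mapping is registered as a buffer so it is saved in the state dict and moved with .to(device) without being trained. The real mapping comes from qid2topk_eid; here a toy rule just collapses every eid above a cutoff onto one shared row:

import torch
import torch.nn as nn

class EidRemap(nn.Module):
    def __init__(self, num_eids, num_topk):
        super().__init__()
        eid2topkeid = torch.arange(num_eids)
        eid2topkeid[eid2topkeid >= num_topk - 1] = num_topk - 1  # collapse the tail
        self.register_buffer("eid2topkeid", eid2topkeid)
        self.emb = nn.Embedding(num_topk, 4)

    def forward(self, eids):
        # Remap full eids to the smaller top-K embedding table before lookup.
        return self.emb(self.eid2topkeid[eids])

m = EidRemap(num_eids=10, num_topk=4)
print(m(torch.tensor([0, 3, 9])).shape)   # torch.Size([3, 4])
print("eid2topkeid" in m.state_dict())    # True: buffers are persisted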
Code Example #25
 def __init__(self, args, main_args, word_symbols):
     super(BaseWordEmbedding, self).__init__()
     self.logger = logging_utils.get_logger(main_args)
     self._key = "word"
     self.pad_id = word_symbols.pad_id
Code Example #26
 def __setstate__(self, state):
     self.__dict__.update(state)
     self.logger = logging_utils.get_logger(self.main_args)
Code Example #27
    def __init__(self,
                 args=None,
                 entity_symbols=None,
                 word_symbols=None,
                 total_steps_per_epoch=0,
                 resume_model_file="",
                 eval_slice_names=None,
                 model_eval=False):
        self.model_eval = model_eval  # keep track of mode for model loading
        self.distributed = args.run_config.distributed
        self.args = args
        self.total_steps_per_epoch = total_steps_per_epoch
        self.start_epoch = 0
        self.start_step = 0
        self.use_cuda = not args.run_config.cpu and torch.cuda.is_available()
        self.logger = logging_utils.get_logger(args)
        if not self.use_cuda:
            self.model_device = "cpu"
            self.embedding_device = "cpu"
        else:
            self.model_device = args.run_config.gpu
            self.embedding_device = args.run_config.gpu
        # Load base model
        mod, load_class = import_class("bootleg",
                                       args.model_config.base_model_load_class)
        self.model = getattr(mod, load_class)(args=args,
                                              model_device=self.model_device,
                                              entity_symbols=entity_symbols,
                                              word_symbols=word_symbols)
        self.use_eval_wrapper = False
        if eval_slice_names is not None:
            self.use_eval_wrapper = True
            # Mapping of all output heads to indexes for the buffers
            head_key_to_idx = train_utils.get_head_key_to_idx(args)
            self.eval_wrapper = EvalWrapper(
                args=args,
                head_key_to_idx=head_key_to_idx,
                eval_slice_names=eval_slice_names,
                train_head_names=args.train_config.train_heads)
            self.eval_wrapper.to(self.model_device)
        self.optimizer = SparseDenseAdam(
            list(self.model.parameters()),
            lr=args.train_config.lr,
            weight_decay=args.train_config.weight_decay)

        self.scorer = Scorer(args, self.model_device)

        self.model.to(self.model_device)
        if self.distributed:
            # Wrap the model for multi-GPU distributed training
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.model_device],
                find_unused_parameters=True)

        # Load weights into the existing model if a resume model file is provided
        if resume_model_file.endswith(".pt"):
            self.logger.info(f'Loading model from {resume_model_file}...')
            self.load(resume_model_file)

        self.logger.debug("Model device " + str(self.model_device))
        self.logger.debug("Embedding device " + str(self.embedding_device))
        self.logger.debug(
            f"*************************MODEL PARAMS WITH GRAD*************************"
        )
        self.logger.debug(
            f'Number of model parameters with grad: {count_parameters(self.model, True, self.logger)}'
        )
        self.logger.debug(
            f"*************************MODEL PARAMS WITHOUT GRAD*************************"
        )
        self.logger.debug(
            f'Number of model parameters without grad: {count_parameters(self.model, False, self.logger)}'
        )
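The count_parameters helper used above is not shown in these snippets; a plausible minimal version (a hedged guess, the real helper also receives the logger and may log a per-parameter breakdown) is:

import torch.nn as nn

def count_parameters(model, requires_grad=True):
    # Count parameters that do (or do not) require gradients.
    return sum(p.numel() for p in model.parameters()
               if p.requires_grad == requires_grad)

model = nn.Linear(8, 2)
print(count_parameters(model, True), count_parameters(model, False))  # 18 0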