Code Example #1
    def __init__(self, args, model_device, entity_symbols, word_symbols):
        super(EmbeddingLayer, self).__init__()
        self.logger = logging_utils.get_logger(args)
        self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand

        # Word Embedding (passed to Sentence and Entity Embedding classes)
        mod, load_class = import_class(
            "bootleg.embeddings.word_embeddings",
            args.data_config.word_embedding.load_class)
        self.word_emb = getattr(mod,
                                load_class)(args.data_config.word_embedding,
                                            args, word_symbols)

        # Sentence Embedding
        mod, load_class = import_class(
            "bootleg.embeddings.word_embeddings",
            args.data_config.word_embedding.sent_class)
        self.sent_emb = getattr(mod,
                                load_class)(args.data_config.word_embedding,
                                            args, self.word_emb.get_dim(),
                                            word_symbols)

        # Entity Embedding
        self.entity_embs = nn.ModuleDict()
        self.logger.info('Loading embeddings...')
        for emb in args.data_config.ent_embeddings:
            try:
                emb_args = emb.args
            except Exception:
                emb_args = None
            assert "load_class" in emb, "You must specify a load_class in the embedding config: {load_class: ..., key: ...}"
            assert "key" in emb, "You must specify a key in the embedding config: {load_class: ..., key: ...}"
            mod, load_class = import_class("bootleg.embeddings",
                                           emb.load_class)
            emb_obj = getattr(mod, load_class)(main_args=args,
                                               emb_args=emb_args,
                                               model_device=model_device,
                                               entity_symbols=entity_symbols,
                                               word_symbols=word_symbols,
                                               word_emb=self.word_emb,
                                               key=emb.key)
            self.entity_embs[emb.key] = emb_obj
        self.logger.info('Finished loading embeddings.')

        # Track the dimensions of different embeddings
        self.emb_sizes = {}
        for emb in self.entity_embs.values():
            key = emb.key
            dim = emb.get_dim()
            assert key not in self.emb_sizes, f"Can't have duplicate embedding keys; {key} is already registered"
            self.emb_sizes[key] = dim
        self.sent_emb_size = self.sent_emb._dim

        self.project_sent = MLP(input_size=self.sent_emb_size,
                                num_hidden_units=0,
                                output_size=args.model_config.hidden_size,
                                num_layers=1,
                                dropout=0,
                                residual=False,
                                activation=None)
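
Every snippet on this page relies on the same dynamic-loading pattern: import_class resolves a module and a class name from the config, and getattr then instantiates that class. The helper itself is not shown on this page; the sketch below is a minimal stand-in, assuming it simply wraps importlib and splits an optional dotted submodule prefix off the class path (the actual bootleg utility may differ in details such as error handling).

import importlib

def import_class(module_path, class_path):
    # Minimal stand-in sketch (assumed behavior, not the bootleg source).
    # "sub.ClassName" -> import module_path.sub, return (module, "ClassName").
    if "." in class_path:
        submodule, class_name = class_path.rsplit(".", 1)
        module = importlib.import_module(module_path + "." + submodule)
    else:
        class_name = class_path
        module = importlib.import_module(module_path)
    # Callers then do: mod, load_class = import_class(...); obj = getattr(mod, load_class)(...)
    return module, class_name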
Code Example #2
File: model.py Project: syyunn/bootleg
 def __init__(self, args, model_device, entity_symbols, word_symbols):
     super(Model, self).__init__()
     self.model_device = model_device
     self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
     self.logger = logging_utils.get_logger(args)
     # embeddings
     self.emb_layer = EmbeddingLayer(args, self.model_device,
                                     entity_symbols, word_symbols)
     self.type_pred = False
     if args.data_config.type_prediction.use_type_pred:
         self.type_pred = True
         # Add 1 for pad type
         self.type_prediction = TypePred(
             args.model_config.hidden_size,
             args.data_config.type_prediction.dim,
             args.data_config.type_prediction.num_types + 1)
     self.emb_combiner = EmbCombinerProj(args, self.emb_layer.emb_sizes,
                                         self.emb_layer.sent_emb_size,
                                         word_symbols, entity_symbols)
     # attention network
     mod, load_class = import_class("bootleg.layers.attn_networks",
                                    args.model_config.attn_load_class)
     self.attn_network = getattr(mod,
                                 load_class)(args, self.emb_layer.emb_sizes,
                                             self.emb_layer.sent_emb_size,
                                             entity_symbols, word_symbols)
     # slice heads
     self.slice_heads = self.get_slice_method(args, entity_symbols)
     self.freeze_components(args)
Code Example #3
File: data_utils.py Project: paper2code/bootleg
def create_dataset(args,
                   data_args,
                   is_writer,
                   word_symbols,
                   entity_symbols,
                   slice_dataset=None,
                   dataset_is_eval=False):
    dataset_name = generate_save_data_name(
        data_args=args.data_config,
        use_weak_label=data_args.use_weak_label,
        split_name=os.path.splitext(data_args.file)[0])
    prep_dir = get_data_prep_dir(args)
    full_dataset_name = os.path.join(prep_dir, dataset_name)
    mod, load_class = import_class("bootleg.dataloaders", data_args.load_class)
    dataset = getattr(mod, load_class)(args=args,
                                       use_weak_label=data_args.use_weak_label,
                                       input_src=os.path.join(
                                           args.data_config.data_dir,
                                           data_args.file),
                                       dataset_name=full_dataset_name,
                                       is_writer=is_writer,
                                       distributed=args.run_config.distributed,
                                       word_symbols=word_symbols,
                                       entity_symbols=entity_symbols,
                                       slice_dataset=slice_dataset,
                                       dataset_is_eval=dataset_is_eval)
    return dataset
Code Example #4
File: data_utils.py Project: syyunn/bootleg
def create_slice_dataset(args, data_args, is_writer, dataset_is_eval):
    # Note that the weak labelling is going to alter our indexing for the slices. Our slices still only score gold==True
    dataset_name = generate_slice_name(
        args,
        args.data_config,
        use_weak_label=data_args.use_weak_label,
        split_name="slice_" + os.path.splitext(data_args.file)[0],
        dataset_is_eval=dataset_is_eval)
    prep_dir = get_data_prep_dir(args)
    full_dataset_name = os.path.join(prep_dir, dataset_name)
    mod, load_class = import_class("bootleg.dataloaders", data_args.slice_class)
    dataset = getattr(mod, load_class)(args=args,
                                       use_weak_label=data_args.use_weak_label,
                                       input_src=os.path.join(
                                           args.data_config.data_dir,
                                           data_args.file),
                                       dataset_name=full_dataset_name,
                                       is_writer=is_writer,
                                       distributed=args.run_config.distributed,
                                       dataset_is_eval=dataset_is_eval)
    return dataset
Code Example #5
    def __init__(self,
                 config_args,
                 device='cuda',
                 max_alias_len=6,
                 cand_map=None,
                 threshold=0.0):
        self.args = config_args
        self.device = device
        self.entity_db = EntitySymbols(
            os.path.join(self.args.data_config.entity_dir,
                         self.args.data_config.entity_map_dir),
            alias_cand_map_file=self.args.data_config.alias_cand_map)
        self.word_db = data_utils.load_wordsymbols(self.args.data_config,
                                                   is_writer=True,
                                                   distributed=False)
        self.model = self._load_model()
        self.max_alias_len = max_alias_len
        if cand_map is None:
            alias_map = self.entity_db._alias2qids
        else:
            alias_map = ujson.load(open(cand_map))
        self.all_aliases_trie = get_all_aliases(alias_map,
                                                logger=logging.getLogger())
        self.alias_table = AliasEntityTable(args=self.args,
                                            entity_symbols=self.entity_db)

        # Minimum prediction probability required to return a mention
        self.threshold = threshold

        # get batch_on_the_fly embeddings _and_ the batch_prep embeddings
        self.batch_on_the_fly_embs = {}
        for i, emb in enumerate(self.args.data_config.ent_embeddings):
            if 'batch_prep' in emb and emb['batch_prep'] is True:
                self.args.data_config.ent_embeddings[i][
                    'batch_on_the_fly'] = True
                del self.args.data_config.ent_embeddings[i]['batch_prep']
            if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
                mod, load_class = import_class("bootleg.embeddings",
                                               emb.load_class)
                try:
                    self.batch_on_the_fly_embs[emb.key] = getattr(
                        mod, load_class)(main_args=self.args,
                                         emb_args=emb['args'],
                                         entity_symbols=self.entity_db,
                                         model_device=None,
                                         word_symbols=None)
                except AttributeError as e:
                    print(
                        f'No prep method found for {emb.load_class} with error {e}'
                    )
                except Exception as e:
                    print("ERROR", e)
Code Example #6
File: model.py Project: syyunn/bootleg
 def __init__(self, args, model_device, entity_symbols, word_symbols):
     super(BaselineModel, self).__init__(args, model_device, entity_symbols,
                                         word_symbols)
     self.model_device = model_device
     self.logger = logging_utils.get_logger(args)
     mod, load_class = import_class("bootleg.layers.attn_networks",
                                    args.model_config.attn_load_class)
     self.emb_layer = EmbeddingLayerNoProj(args, self.model_device,
                                           entity_symbols, word_symbols)
     self.attn_network = getattr(mod,
                                 load_class)(args, self.emb_layer.emb_sizes,
                                             self.emb_layer.sent_emb_size,
                                             entity_symbols, word_symbols)
     self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
     self.freeze_components(args)
Code Example #7
 def __init__(self, args, model_device, entity_symbols, word_symbols):
     super(Model, self).__init__()
     self.model_device = model_device
     self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
     self.logger = logging_utils.get_logger(args)
     # embeddings
     self.emb_layer = EmbeddingLayer(args, self.model_device,
                                     entity_symbols, word_symbols)
     self.emb_combiner = EmbCombinerProj(args, self.emb_layer.emb_sizes,
                                         self.emb_layer.sent_emb_size,
                                         word_symbols, entity_symbols)
     # attention network
     mod, load_class = import_class("bootleg.layers.attn_networks",
                                    args.model_config.attn_load_class)
     self.attn_network = getattr(mod,
                                 load_class)(args, self.emb_layer.emb_sizes,
                                             self.emb_layer.sent_emb_size,
                                             entity_symbols, word_symbols)
     # slice heads
     self.slice_heads = self.get_slice_method(args, entity_symbols)
     self.freeze_components(args)
Code Example #8
File: wiki_dataset.py Project: parakalan/bootleg
    def __init__(self,
                 args,
                 use_weak_label,
                 input_src,
                 dataset_name,
                 is_writer,
                 distributed,
                 word_symbols,
                 entity_symbols,
                 slice_dataset=None,
                 dataset_is_eval=False):
        # Need to save args to reinstantiate logger
        self.args = args
        self.logger = logging_utils.get_logger(args)
        # Number of candidates, including a NIL candidate when the model allows NIL (train_in_candidates is False)
        self.K = entity_symbols.max_candidates + (
            not args.data_config.train_in_candidates)
        self.num_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand
        self.dataset_name = dataset_name
        self.slice_dataset = slice_dataset
        self.dataset_is_eval = dataset_is_eval
        # Slice names used for eval slices and a slicing model
        self.slice_names = train_utils.get_data_slices(args, dataset_is_eval)
        self.storage_type_file = data_utils.get_storage_file(self.dataset_name)
        # Mappings from sent_idx to row_id in dataset
        self.sent_idx_file = os.path.splitext(
            dataset_name)[0] + "_sent_idx.json"
        self.type_pred = False
        if args.data_config.type_prediction.use_type_pred:
            self.type_pred = True
            self.eid2typeid, self.num_types_with_pad = self.load_coarse_type_table(
                args, entity_symbols)
        # Load memory mapped file
        self.logger.info("Loading dataset...")
        self.logger.debug("Seeing if " + dataset_name + " exists")
        if (args.data_config.overwrite_preprocessed_data
                or (not os.path.exists(self.dataset_name))
                or (not os.path.exists(self.sent_idx_file))
                or (not os.path.exists(self.storage_type_file))
                or (not os.path.exists(
                    data_utils.get_batch_prep_config(self.dataset_name)))):
            start = time.time()
            self.logger.debug(f"Building dataset with {input_src}")
            # Only prep data once per node
            if is_writer:
                prep_data(args,
                          use_weak_label=use_weak_label,
                          dataset_is_eval=self.dataset_is_eval,
                          input_src=input_src,
                          dataset_name=dataset_name,
                          prep_dir=data_utils.get_data_prep_dir(args))
            if distributed:
                # Make sure all processes wait for data to be created
                dist.barrier()
            self.logger.debug(
                f"Finished building and saving dataset in {round(time.time() - start, 2)}s."
            )

        start = time.time()

        # Storage type for loading memory mapped file of dataset
        self.storage_type = pickle.load(open(self.storage_type_file, 'rb'))

        self.data = np.memmap(self.dataset_name,
                              dtype=self.storage_type,
                              mode='r')
        self.data_len = len(self.data)

        # Mapping from sentence idx to rows in the dataset (indices).
        # Needed when sampling sentence indices from slices for evaluation.
        sent_idx_to_idx_str = utils.load_json_file(self.sent_idx_file)
        self.sent_idx_to_idx = {
            int(i): val
            for i, val in sent_idx_to_idx_str.items()
        }
        self.logger.info(f"Finished loading dataset.")

        # Stores info about the batch prepped embedding memory mapped files and their shapes and datatypes
        # so we can load them
        self.batch_prep_config = utils.load_json_file(
            data_utils.get_batch_prep_config(self.dataset_name))
        self.batch_prepped_emb_files = {}
        self.batch_prepped_emb_file_names = {}
        for emb in args.data_config.ent_embeddings:
            if 'batch_prep' in emb and emb['batch_prep']:
                assert emb.key in self.batch_prep_config, f'Need to prep {emb.key}. Please call prep instead of run with batch_prep_embeddings set to true.'
                self.batch_prepped_emb_file_names[emb.key] = os.path.join(
                    os.path.dirname(self.dataset_name),
                    os.path.basename(
                        self.batch_prep_config[emb.key]['file_name']))
                self.batch_prepped_emb_files[emb.key] = np.memmap(
                    self.batch_prepped_emb_file_names[emb.key],
                    dtype=self.batch_prep_config[emb.key]['dtype'],
                    shape=tuple(self.batch_prep_config[emb.key]['shape']),
                    mode='r')
                assert len(self.batch_prepped_emb_files[emb.key]) == self.data_len,\
                    f'Preprocessed emb data file {self.batch_prep_config[emb.key]["file_name"]} does not match length of main data file.'

        # Stores embeddings that we compute on the fly; these are embeddings where batch_on_the_fly is set to true.
        self.batch_on_the_fly_embs = {}
        for emb in args.data_config.ent_embeddings:
            if 'batch_on_the_fly' in emb and emb['batch_on_the_fly'] is True:
                mod, load_class = import_class("bootleg.embeddings",
                                               emb.load_class)
                try:
                    self.batch_on_the_fly_embs[emb.key] = getattr(
                        mod, load_class)(main_args=args,
                                         emb_args=emb['args'],
                                         entity_symbols=entity_symbols,
                                         model_device=None,
                                         word_symbols=None,
                                         key=emb.key)
                except AttributeError as e:
                    self.logger.warning(
                        f'No prep method found for {emb.load_class} with error {e}'
                    )
                except Exception as e:
                    print("ERROR", e)
        # The data in this table shouldn't be pickled since we delete it in the class __getstate__
        self.alias2entity_table = AliasEntityTable(
            args=args, entity_symbols=entity_symbols)
        # Random NIL percent
        self.mask_perc = args.train_config.random_nil_perc
        self.random_nil = False
        # Don't want to random mask for eval
        if not dataset_is_eval:
            # Whether to use a random NIL training regime
            self.random_nil = args.train_config.random_nil
            if self.random_nil:
                self.logger.info(
                    f'Using random nils during training with {self.mask_perc} percent'
                )
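
The constructor above only wires up the memory-mapped structures; the example does not show how they are read back. The fragment below is a hypothetical __getitem__ sketch for the same class, assuming each row of self.data is a structured numpy record and that the batch-prepped embedding memmaps are row-aligned with the main data file (the asserts above check the lengths match); the real WikiDataset adds tensor conversion and the batch_on_the_fly prep calls.

    def __getitem__(self, idx):
        # Hypothetical sketch, not the actual bootleg implementation.
        row = self.data[idx]  # one structured record from the memory-mapped dataset
        example = {name: row[name] for name in self.data.dtype.names}
        # Batch-prepped embeddings are stored row-aligned with the main data file,
        # so the same index selects the precomputed vectors for this example.
        for key, emb_file in self.batch_prepped_emb_files.items():
            example[key] = emb_file[idx]
        return example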
Code Example #9
    def __init__(self,
                 args=None,
                 entity_symbols=None,
                 word_symbols=None,
                 total_steps_per_epoch=0,
                 resume_model_file="",
                 eval_slice_names=None,
                 model_eval=False):
        self.model_eval = model_eval  # keep track of mode for model loading
        self.distributed = args.run_config.distributed
        self.args = args
        self.total_steps_per_epoch = total_steps_per_epoch
        self.start_epoch = 0
        self.start_step = 0
        self.use_cuda = not args.run_config.cpu and torch.cuda.is_available()
        self.logger = logging_utils.get_logger(args)
        if not self.use_cuda:
            self.model_device = "cpu"
            self.embedding_device = "cpu"
        else:
            self.model_device = args.run_config.gpu
            self.embedding_device = args.run_config.gpu
        # Load base model
        mod, load_class = import_class("bootleg",
                                       args.model_config.base_model_load_class)
        self.model = getattr(mod, load_class)(args=args,
                                              model_device=self.model_device,
                                              entity_symbols=entity_symbols,
                                              word_symbols=word_symbols)
        self.use_eval_wrapper = False
        if eval_slice_names is not None:
            self.use_eval_wrapper = True
            # Mapping of all output heads to indexes for the buffers
            head_key_to_idx = train_utils.get_head_key_to_idx(args)
            self.eval_wrapper = EvalWrapper(
                args=args,
                head_key_to_idx=head_key_to_idx,
                eval_slice_names=eval_slice_names,
                train_head_names=args.train_config.train_heads)
            self.eval_wrapper.to(self.model_device)
        self.optimizer = SparseDenseAdam(
            list(self.model.parameters()),
            lr=args.train_config.lr,
            weight_decay=args.train_config.weight_decay)

        self.scorer = Scorer(args, self.model_device)

        self.model.to(self.model_device)
        if self.distributed:
            # move everything to GPU
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[self.model_device],
                find_unused_parameters=True)

        # load model into existing model if model_file is provided
        if resume_model_file.endswith(".pt"):
            self.logger.info(f'Loading model from {resume_model_file}...')
            self.load(resume_model_file)

        self.logger.debug("Model device " + str(self.model_device))
        self.logger.debug("Embedding device " + str(self.embedding_device))
        self.logger.debug(
            f"*************************MODEL PARAMS WITH GRAD*************************"
        )
        self.logger.debug(
            f'Number of model parameters with grad: {count_parameters(self.model, True, self.logger)}'
        )
        self.logger.debug(
            f"*************************MODEL PARAMS WITHOUT GRAD*************************"
        )
        self.logger.debug(
            f'Number of model parameters without grad: {count_parameters(self.model, False, self.logger)}'
        )
Code Example #10
File: data.py Project: pombredanne/bootleg
def get_dataloader_embeddings(main_args, entity_symbols):
    """Gets KG embeddings that need to be processed in the __get_item__ method
    of a dataset (e.g., querying a sparce numpy matrix). We save, for each KG
    embedding class that needs this preprocessing, the adjacency matrix (for KG
    connections), the processing function to run in __get_item__, and the file
    to load the adj matrix for dumping/loading.

    Args:
        main_args: main arguments
        entity_symbols: entity symbols

    Returns: Dict of KG metadata for use in the __getitem__ method.
    """
    batch_on_the_fly_kg_adj = {}
    for emb in main_args.data_config.ent_embeddings:
        batch_on_fly = "batch_on_the_fly" in emb and emb[
            "batch_on_the_fly"] is True
        # Find embeddings that have a "batch of the fly" key
        if batch_on_fly:
            log_rank_0_debug(
                logger,
                f"Loading class {emb.load_class} for preprocessing as on the fly or in data prep embeddings",
            )
            (
                cpu,
                dropout1d_perc,
                dropout2d_perc,
                emb_args,
                freeze,
                normalize,
                through_bert,
            ) = embedding_utils.get_embedding_args(emb)
            try:
                # Load the object
                mod, load_class = import_class("bootleg.embeddings",
                                               emb.load_class)
                kg_class = getattr(mod, load_class)(
                    main_args=main_args,
                    emb_args=emb_args,
                    entity_symbols=entity_symbols,
                    key=emb.key,
                    cpu=cpu,
                    normalize=normalize,
                    dropout1d_perc=dropout1d_perc,
                    dropout2d_perc=dropout2d_perc,
                )
                # Extract its kg adj, we'll use this later
                # Extract the kg_adj_process_func (how to process the embeddings in __getitem__ or dataset prep)
                # Extract the prep_file. We use this to load the kg_adj back after
                # saving/loading state using scipy.sparse.load_npz(prep_file)
                assert hasattr(
                    kg_class, "kg_adj"
                ), f"The embedding class {emb.key} does not have a kg_adj attribute and it needs to."
                assert hasattr(
                    kg_class, "kg_adj_process_func"
                ), f"The embedding class {emb.key} does not have a kg_adj_process_func attribute and it needs to."
                assert hasattr(kg_class, "prep_file"), (
                    f"The embedding class {emb.key} does not have a prep_file attribute and it needs to. We will call"
                    f" `scipy.sparse.load_npz(prep_file)` to load the kg_adj matrix."
                )
                batch_on_the_fly_kg_adj[emb.key] = {
                    "kg_adj": kg_class.kg_adj,
                    "kg_adj_process_func": kg_class.kg_adj_process_func,
                    "prep_file": kg_class.prep_file,
                }
            except AttributeError as e:
                logger.warning(
                    f"No prep method found for {emb.load_class} with error {e}"
                )
                raise
            except Exception as e:
                print("ERROR", e)
                raise
    return batch_on_the_fly_kg_adj
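
The returned dict is keyed by embedding key, and each entry carries exactly the three fields asserted above: the adjacency matrix, the per-example processing callable, and the prep file path. A hedged usage sketch follows; the variable names and the process-function call are illustrative assumptions, not the actual call made in bootleg's dataset code.

# Hedged usage sketch -- names and the process-function signature are assumptions.
kg_metadata = get_dataloader_embeddings(main_args, entity_symbols)
for key, meta in kg_metadata.items():
    kg_adj = meta["kg_adj"]                     # sparse adjacency over entities
    process_func = meta["kg_adj_process_func"]  # run per example, e.g. inside __getitem__
    prep_file = meta["prep_file"]               # scipy.sparse.load_npz(prep_file) restores kg_adj
    # Inside a dataset's __getitem__, one might do something like:
    # example[key] = process_func(candidate_entity_ids, kg_adj)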