Example #1
 def __init__(self, main_args, emb_args, model_device, entity_symbols,
              word_symbols, word_emb, key):
     super(LearnedEntityEmb, self).__init__(main_args=main_args,
                                            emb_args=emb_args,
                                            model_device=model_device,
                                            entity_symbols=entity_symbols,
                                            word_symbols=word_symbols,
                                            word_emb=word_emb,
                                            key=key)
     self.logger = logging_utils.get_logger(main_args)
     self.learned_embedding_size = emb_args.learned_embedding_size
     self.normalize = True
     self.learned_entity_embedding = nn.Embedding(
         entity_symbols.num_entities_with_pad_and_nocand,
         self.learned_embedding_size,
         padding_idx=-1,
         sparse=True)
     # If tail_init is false, all embeddings are randomly initialized.
     # If tail_init is true, we initialize all embeddings to the same value.
     self.tail_init = True
     self.tail_init_zeros = False
     # A None init_vec means random initialization
     init_vec = None
     if "tail_init" in emb_args:
         self.tail_init = emb_args.tail_init
     if "tail_init_zeros" in emb_args:
         self.tail_init_zeros = emb_args.tail_init_zeros
         self.tail_init = False
         init_vec = torch.zeros(1, self.learned_embedding_size)
     assert not (self.tail_init and self.tail_init_zeros
                 ), f"Can only have one of tail_init or tail_init_zeros set"
     if self.tail_init or self.tail_init_zeros:
         if self.tail_init_zeros:
             self.logger.debug(
                 f"All learned entity embeddings are initialized to zero.")
         else:
             self.logger.debug(
                 f"All learned entity embeddings are initialized to the same value."
             )
         init_vec = model_utils.init_embeddings_to_vec(
             self.learned_entity_embedding, pad_idx=-1, vec=init_vec)
         vec_save_file = os.path.join(
             train_utils.get_save_folder(main_args.run_config),
             "init_vec_entity_embs.npy")
         self.logger.debug(f"Saving init vector to {vec_save_file}")
         np.save(vec_save_file, init_vec)
     else:
         self.logger.debug(
             f"All learned embeddings are randomly initialized.")
     self._dim = main_args.model_config.hidden_size
     self.eid2reg = None
     # Regularization mapping goes from eid to 2d dropout percent
     if "regularize_mapping" in emb_args:
         self.logger.debug(
             f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}"
         )
         self.eid2reg = self.load_regularization_mapping(
             main_args, entity_symbols, emb_args.regularize_mapping,
             self.logger.debug)
         self.eid2reg = self.eid2reg.to(model_device)
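The constructor above delegates the actual weight copy to model_utils.init_embeddings_to_vec. Below is a minimal sketch of what such a helper plausibly does, assuming it fills every row of the table with one shared vector and returns that vector so the caller can save it; the name, signature, and body here are illustrative, not Bootleg's actual implementation.

import torch
import torch.nn as nn

def init_embeddings_to_vec_sketch(emb: nn.Embedding, pad_idx: int, vec=None):
    # Assumed behavior: share one vector across all rows, keep the pad row at zero,
    # and return the vector so the caller can np.save it for later reuse.
    if vec is None:
        vec = torch.randn(1, emb.embedding_dim)
    with torch.no_grad():
        emb.weight.copy_(vec.expand_as(emb.weight))
        emb.weight[pad_idx].fill_(0)
    return vec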
Example #2
 def setup_loss_file(self, args):
     save_folder = train_utils.get_save_folder(args.run_config)
     loss_file = "loss_results"
     loss_file += train_utils.get_file_suffix(args)
     loss_file += '.jsonl'
     loss_file = os.path.join(save_folder, loss_file)
     open(loss_file, 'w').close()
     return loss_file
Example #3
 def setup_dev_files(self, args):
     dev_files = {}
     save_folder = train_utils.get_save_folder(args.run_config)
     dev_file_tag = args.data_config.dev_dataset.file.split('.jsonl')[0]
     dev_file = dev_file_tag + "_dev_results"
     dev_file += train_utils.get_file_suffix(args)
     dev_file += '.jsonl'
     dev_file = os.path.join(save_folder, dev_file)
     # Clear old file
     open(dev_file, 'w').close()
     dev_files[args.data_config.dev_dataset.file] = dev_file
     return dev_files
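The split('.jsonl')[0] call simply strips the extension so the results file keeps the dev file's base name. Illustrative only, with a made-up file name and suffix, since train_utils.get_file_suffix depends on the run configuration:

dev_file_tag = "dev.jsonl".split('.jsonl')[0]   # -> "dev"
dev_file = dev_file_tag + "_dev_results" + "_run1" + ".jsonl"
print(dev_file)  # dev_dev_results_run1.jsonl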
Example #4
 def setup_test_files(self, args):
     test_files = {}
     save_folder = train_utils.get_save_folder(args.run_config)
     test_file_tag = args.data_config.test_dataset.file.split('.jsonl')[0]
     test_file = test_file_tag + "_test_results"
     test_file += train_utils.get_file_suffix(args)
     test_file += '.jsonl'
     test_file = os.path.join(save_folder, test_file)
     # Clear old file
     open(test_file, 'w').close()
     test_files[args.data_config.test_dataset.file] = test_file
     return test_files
Example #5
def main(args, mode):
    multiprocessing.set_start_method("forkserver", force=True)
    # =================================
    # ARGUMENTS CHECK
    # =================================
    # distributed training
    assert (args.run_config.ngpus_per_node <= torch.cuda.device_count()) or (
        not torch.cuda.is_available()), 'Not enough GPUs per node.'
    world_size = args.run_config.ngpus_per_node * args.run_config.nodes
    if world_size > 1:
        args.run_config.distributed = True
    assert (args.run_config.distributed and world_size > 1) or (world_size
                                                                == 1)

    train_utils.setup_run_folders(args, mode)

    # check slice method
    assert args.train_config.slice_method in SLICE_METHODS, f"Your slice_method {args.train_config.slice_method} is not in {SLICE_METHODS}."
    train_utils.setup_train_heads_and_eval_slices(args)

    # check save step
    assert args.run_config.save_every_k_eval > 0, "You must set save_every_k_eval to be > 0"

    # Since these modes run evaluation, make sure the resume model file is set and exists
    if mode == "eval" or mode == "dump_preds" or mode == "dump_embs":
        assert args.run_config.init_checkpoint != "", \
            f"You must specify a model checkpoint in run_config to run {mode}"
        assert os.path.exists(args.run_config.init_checkpoint),\
            f"The resume model file of {args.run_config.init_checkpoint} doesn't exist"

    if mode == "dump_preds" or mode == "dump_embs":
        assert args.run_config.perc_eval == 1.0, f"If you are running dump_preds or dump_embs, run_config.perc_eval must be 1.0. You have {args.run_config.perc_eval}"
        assert args.data_config.test_dataset.use_weak_label is True, "We do not support dumping when the test dataset use_weak_label is set to false. You can filter the dataset and run with the filtered data."

    utils.dump_json_file(filename=os.path.join(
        train_utils.get_save_folder(args.run_config), f"config_{mode}.json"),
                         contents=args)
    if args.run_config.distributed:
        mp.spawn(main_worker,
                 nprocs=args.run_config.ngpus_per_node,
                 args=(args, mode, world_size))
    else:
        main_worker(gpu=args.run_config.gpu,
                    args=args,
                    mode=mode,
                    world_size=world_size)
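mp.spawn invokes the target as fn(i, *args), where i is the process index, so main_worker has to take the GPU/process index as its first positional parameter; the non-distributed branch matches this by passing gpu= explicitly. A sketch of a compatible signature follows; the body is a placeholder, not Bootleg's actual worker.

def main_worker(gpu, args, mode, world_size):
    # gpu is the process index on this node when launched via mp.spawn
    rank = gpu  # single-node case; multi-node setups would offset by the node's rank
    ...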
Example #6
File: run.py  Project: paper2code/bootleg
def train(args, is_writer, logger, world_size, rank):
    # These are set up in main, but call them again in case train is called directly
    train_utils.setup_train_heads_and_eval_slices(args)
    train_utils.setup_run_folders(args, "train")

    # Load word symbols (like tokenizers) and entity symbols (aka entity profiles)
    word_symbols = data_utils.load_wordsymbols(
        args.data_config, is_writer, distributed=args.run_config.distributed)
    logger.info(f"Loading entity_symbols...")
    entity_symbols = EntitySymbols(
        load_dir=os.path.join(args.data_config.entity_dir,
                              args.data_config.entity_map_dir),
        alias_cand_map_file=args.data_config.alias_cand_map)
    logger.info(
        f"Loaded entity_symbols with {entity_symbols.num_entities} entities.")
    # Get train dataset
    train_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.train_dataset, is_writer, dataset_is_eval=False)
    train_dataset = data_utils.create_dataset(
        args,
        args.data_config.train_dataset,
        is_writer,
        word_symbols,
        entity_symbols,
        slice_dataset=train_slice_dataset,
        dataset_is_eval=False)
    train_dataloader, train_sampler = data_utils.create_dataloader(
        args,
        train_dataset,
        eval_slice_dataset=None,
        world_size=world_size,
        rank=rank,
        batch_size=args.train_config.batch_size)

    # Repeat for dev
    dev_dataset_collection = {}
    dev_slice_dataset = data_utils.create_slice_dataset(
        args, args.data_config.dev_dataset, is_writer, dataset_is_eval=True)
    dev_dataset = data_utils.create_dataset(args,
                                            args.data_config.dev_dataset,
                                            is_writer,
                                            word_symbols,
                                            entity_symbols,
                                            slice_dataset=dev_slice_dataset,
                                            dataset_is_eval=True)
    dev_dataloader, dev_sampler = data_utils.create_dataloader(
        args,
        dev_dataset,
        eval_slice_dataset=dev_slice_dataset,
        batch_size=args.run_config.eval_batch_size)
    dataset_collection = DatasetCollection(args.data_config.dev_dataset,
                                           args.data_config.dev_dataset.file,
                                           dev_dataset, dev_dataloader,
                                           dev_slice_dataset, dev_sampler)
    dev_dataset_collection[
        args.data_config.dev_dataset.file] = dataset_collection

    eval_slice_names = args.run_config.eval_slices

    total_steps_per_epoch = len(train_dataloader)
    # Create trainer---model, optimizer, and scorer
    trainer = Trainer(args,
                      entity_symbols,
                      word_symbols,
                      total_steps_per_epoch=total_steps_per_epoch,
                      eval_slice_names=eval_slice_names,
                      resume_model_file=args.run_config.init_checkpoint)

    # Set up epochs and intervals for saving and evaluating
    max_epochs = int(args.run_config.max_epochs)
    eval_steps = int(args.run_config.eval_steps)
    log_steps = int(args.run_config.log_steps)
    save_steps = max(int(args.run_config.save_every_k_eval * eval_steps), 1)
    logger.info(
        f"Eval steps {eval_steps}, Log steps {log_steps}, Save steps {save_steps}, Total training examples per epoch {len(train_dataset)}"
    )
    status_reporter = StatusReporter(args,
                                     logger,
                                     is_writer,
                                     max_epochs,
                                     total_steps_per_epoch,
                                     is_eval=False)
    global_step = 0
    for epoch in range(trainer.start_epoch, trainer.start_epoch + max_epochs):
        # Re-seed each epoch so we do not have to save/restore the RNG state for checkpointing
        torch.manual_seed(args.train_config.seed + epoch)
        np.random.seed(args.train_config.seed + epoch)
        if args.run_config.distributed:
            # for determinism across runs https://github.com/pytorch/examples/issues/501
            train_sampler.set_epoch(epoch)

        start_time_load = time.time()
        for i, batch in enumerate(train_dataloader):
            load_time = time.time() - start_time_load
            start_time = time.time()
            _, loss_pack, _, _ = trainer.update(batch)
            # Log progress
            if (global_step + 1) % log_steps == 0:
                duration = time.time() - start_time
                status_reporter.step_status(epoch=epoch,
                                            step=global_step,
                                            loss_pack=loss_pack,
                                            time=duration,
                                            load_time=load_time,
                                            lr=trainer.get_lr())
            # Save model
            if (global_step + 1) % save_steps == 0 and is_writer:
                logger.info("Saving model...")
                trainer.save(save_dir=train_utils.get_save_folder(
                    args.run_config),
                             epoch=epoch,
                             step=global_step,
                             step_in_batch=i,
                             suffix=args.run_config.model_suffix)
            # Run evaluation
            if (global_step + 1) % eval_steps == 0:
                eval_utils.run_eval_all_dev_sets(args, global_step,
                                                 dev_dataset_collection,
                                                 logger, status_reporter,
                                                 trainer)
            if args.run_config.distributed:
                dist.barrier()
            global_step += 1
            # Time loading new batch
            start_time_load = time.time()
        ######### END OF EPOCH
        if is_writer:
            logger.info(f"Saving model end of epoch {epoch}...")
            trainer.save(save_dir=train_utils.get_save_folder(args.run_config),
                         epoch=epoch,
                         step=global_step,
                         step_in_batch=i,
                         end_of_epoch=True,
                         suffix=args.run_config.model_suffix)
        # Always run eval after saving at the end of the epoch -- if this step coincided with eval_steps, we don't need to rerun eval
        if (global_step + 1) % eval_steps != 0:
            eval_utils.run_eval_all_dev_sets(args, global_step,
                                             dev_dataset_collection, logger,
                                             status_reporter, trainer)
    if is_writer:
        logger.info("Saving model...")
        trainer.save(save_dir=train_utils.get_save_folder(args.run_config),
                     epoch=epoch,
                     step=global_step,
                     step_in_batch=i,
                     end_of_epoch=True,
                     last_epoch=True,
                     suffix=args.run_config.model_suffix)
    if args.run_config.distributed:
        dist.barrier()
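The train_sampler.set_epoch(epoch) call is what makes the shuffling both different per epoch and reproducible across runs: DistributedSampler seeds its shuffle with (seed + epoch). A self-contained sketch of that behavior with a toy dataset; the explicit num_replicas/rank values are for illustration only.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))
# Passing num_replicas/rank explicitly avoids needing an initialized process group here.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
for epoch in range(2):
    sampler.set_epoch(epoch)  # shuffle order now depends on (seed + epoch)
    for (batch,) in loader:
        pass  # without set_epoch, every epoch replays the same ordering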
Example #7
def get_log_name(args, mode):
    log_name = os.path.join(train_utils.get_save_folder(args.run_config),
                            f"log_{mode}")
    log_name += train_utils.get_file_suffix(args)
    log_name += f'_gpu{args.run_config.gpu}'
    return log_name
Example #8
 def __init__(self, main_args, emb_args, model_device, entity_symbols,
              word_symbols, word_emb, key):
     super(TopKEntityEmb, self).__init__(main_args=main_args,
                                         emb_args=emb_args,
                                         model_device=model_device,
                                         entity_symbols=entity_symbols,
                                         word_symbols=word_symbols,
                                         word_emb=word_emb,
                                         key=key)
     self.logger = logging_utils.get_logger(main_args)
     self.learned_embedding_size = emb_args.learned_embedding_size
     self.normalize = True
     assert "perc_emb_drop" in emb_args, f"To use TopKEntityEmb we need perc_emb_drop to be in the args. This gives the percentage of embeddings" \
                                         f" removed."
     # We remove perc_emb_drop percent of the embeddings and add one to represent the new toes embedding
     num_topk_entities_with_pad_and_nocand = entity_symbols.num_entities_with_pad_and_nocand - int(
         emb_args.perc_emb_drop * entity_symbols.num_entities) + 1
     # Mapping from entity QID to its new (top-K) eid
     qid2topk_eid = {}
     eid2topkeid = torch.arange(
         0, entity_symbols.num_entities_with_pad_and_nocand)
     # There are issues with using -1 to index into the embeddings, so we manually set it to the last value
     eid2topkeid[-1] = num_topk_entities_with_pad_and_nocand - 1
     if "qid2topk_eid" not in emb_args:
         assert len(main_args.run_config.init_checkpoint) > 0, f"If you don't provide the qid2topk_eid mapping as an argument to TopKEntityEmb, " \
                                                               f"you must be loading a model from a checkpoint to build this index mapping"
     else:
         qid2topk_eid = utils.load_json_file(emb_args.qid2topk_eid)
         assert len(
             qid2topk_eid
         ) == entity_symbols.num_entities, f"You must have an item in qid2topk_eid for each qid in entity_symbols"
         for qid in entity_symbols.get_all_qids():
             old_eid = entity_symbols.get_eid(qid)
             new_eid = qid2topk_eid[qid]
             eid2topkeid[old_eid] = new_eid
         assert eid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
         assert eid2topkeid[
             -1] == num_topk_entities_with_pad_and_nocand - 1, "The -1 eid should still map to the last index"
     self.learned_entity_embedding = nn.Embedding(
         num_topk_entities_with_pad_and_nocand,
         self.learned_embedding_size,
         padding_idx=-1,
         sparse=True)
     # Keep this mapping so a topK model can simply be loaded without needing the new eid mapping
     self.register_buffer("eid2topkeid", eid2topkeid)
     # If tail_init is false, all embeddings are randomly initialized.
     # If tail_init is true, we initialize all embeddings to the same value.
     self.tail_init = True
     self.tail_init_zeros = False
     # A None init_vec means random initialization
     init_vec = None
     if "tail_init" in emb_args:
         self.tail_init = emb_args.tail_init
     if "tail_init_zeros" in emb_args:
         self.tail_init_zeros = emb_args.tail_init_zeros
         self.tail_init = False
         init_vec = torch.zeros(1, self.learned_embedding_size)
     assert not (self.tail_init and self.tail_init_zeros
                 ), f"Can only have one of tail_init or tail_init_zeros set"
     if self.tail_init or self.tail_init_zeros:
         if self.tail_init_zeros:
             self.logger.debug(
                 f"All learned entity embeddings are initialized to zero.")
         else:
             self.logger.debug(
                 f"All learned entity embeddings are initialized to the same value."
             )
         init_vec = model_utils.init_embeddings_to_vec(
             self.learned_entity_embedding, pad_idx=-1, vec=init_vec)
         vec_save_file = os.path.join(
             train_utils.get_save_folder(main_args.run_config),
             "init_vec_entity_embs.npy")
         self.logger.debug(f"Saving init vector to {vec_save_file}")
         np.save(vec_save_file, init_vec)
     else:
         self.logger.debug(
             f"All learned embeddings are randomly initialized.")
     self._dim = main_args.model_config.hidden_size
     self.eid2reg = None
     # Regularization mapping goes from eid to 2d dropout percent
     if "regularize_mapping" in emb_args:
         self.logger.warning(
             f"You are using regularization mapping with a topK entity embedding. This means all QIDs that are mapped to the same"
             f" EID will get the same regularization value.")
         self.logger.debug(
             f"Using regularization mapping in enity embedding from {emb_args.regularize_mapping}"
         )
         self.eid2reg = self.load_regularization_mapping(
             main_args, qid2topk_eid, num_topk_entities_with_pad_and_nocand,
             emb_args.regularize_mapping, self.logger.debug)
         self.eid2reg = self.eid2reg.to(model_device)
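The registered eid2topkeid buffer lets the model translate original entity ids into the compressed top-K ids before the embedding lookup. A hedged sketch of that lookup follows; TopKEntityEmb's real forward pass is not shown in this example, and the numbers are toy values.

import torch
import torch.nn as nn

num_topk, dim = 6, 4
# Toy remapping: most tail eids collapse onto a single shared id (3); pad maps to the last id.
eid2topkeid = torch.tensor([0, 1, 2, 3, 3, 3, 3, 3, 3, 5])
emb = nn.Embedding(num_topk, dim, padding_idx=-1)

eids = torch.tensor([[1, 7, 9]])   # original entity ids for one example
topk_eids = eid2topkeid[eids]      # -> tensor([[1, 3, 5]])
vectors = emb(topk_eids)           # shape (1, 3, dim)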
Example #9
 def setup_tensorboard(self, args):
     save_folder = os.path.join(
         train_utils.get_save_folder(args.run_config), "tensorboard")
     return SummaryWriter(log_dir=save_folder)
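The writer returned here is presumably torch.utils.tensorboard's SummaryWriter (or the API-compatible tensorboardX one), so downstream code can log scalars against the global step. Minimal usage sketch; the tag and values are made up.

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/example")
writer.add_scalar("train/loss", 0.42, global_step=100)
writer.close()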