示例#1
0
def load_weights(checkpoint_path: str, model):
    """
    Restore only the model parameters stored in a checkpoint file.

    The checkpoint is read with ``load_state``. If it contains a ``"model"``
    entry, that entry is handed to ``model.load_state_dict``; otherwise a
    warning is logged and the model is left untouched.

    :param checkpoint_path: Path of the file holding the saved checkpoint.
    :param model: Object whose ``load_state_dict`` receives the stored weights.
    """
    state = load_state(checkpoint_path)
    if "model" not in state:
        logging.warn("Could not find model_state in checkpoint.")
        return
    if logging.DEBUG_VERBOSITY:
        logging.info("Load Model...")
    model.load_state_dict(state["model"])
示例#2
0
 def _save(self, name: str) -> None:
     """
     Save the current training state as a checkpoint inside the log folder.

     The stored dict holds the current epoch plus the ``state_dict()`` of the
     model, optimizer and loss. It is written with ``save_state`` to
     ``<log_path>/checkpoints/<name>`` using ``self.file_format``.

     :param name: File name of the checkpoint inside the ``checkpoints`` folder.
     :raises RuntimeError: If no log path has been configured.
     """
     checkpoint_sub_path = os.path.join("checkpoints", name)
     log_path = get_log_path()
     # Explicit check instead of `assert`: asserts are stripped under `python -O`,
     # which would turn this into an obscure TypeError inside os.path.join.
     if log_path is None:
         raise RuntimeError(
             "You must setup logger before saving a checkpoint.")
     checkpoint_path = os.path.join(log_path, checkpoint_sub_path)
     save_state(
         {
             "epoch": self.epoch,
             "model": self.model.state_dict(),
             "optimizer": self.optimizer.state_dict(),
             "loss": self.loss.state_dict(),
         },
         checkpoint_path,
         file_format=self.file_format)
     info("Saved Checkpoint: {}".format(checkpoint_sub_path))
示例#3
0
 def init_caching(self, cache_dir: str):
     """
     Enable caching so that items only need to be computed once.

     The cache stores the results of ``__getitem__`` calls, including the
     application of the regular transformers. If the cache directory already
     exists, its contents are read; otherwise the directory is created here, and
     each item is stored the first time ``__getitem__`` is called for it.

     :param cache_dir: Directory where the cache should be stored.
     """
     info("Init caching: {}".format(cache_dir))
     self._caching = True
     self._cache_dir = cache_dir
     # exist_ok=True replaces the separate exists-check, which raced with anything
     # else creating the directory between the check and the makedirs call.
     os.makedirs(self._cache_dir, exist_ok=True)
示例#4
0
def init_model(model, dataloader):
    """
    Initialize the model, so that the checkpoints can be loaded.

    This is done by running the forward pass once with the first sample of
    ``dataloader``. Success is recorded in ``model.initialized_model``, so
    later calls on an already initialized model return immediately.

    :param model: The model that should be initialized.
    :param dataloader: An iterable whose first element is a ``(features, _)``
        pair; ``features`` is unpacked as positional arguments to the model.
    :raises ValueError: If ``dataloader`` yields no samples.
    """
    if getattr(model, "initialized_model", False):
        return
    if logging.DEBUG_VERBOSITY:
        logging.info("Build Model")
    try:
        features, _ = next(iter(dataloader))
    except StopIteration:
        raise ValueError(
            "Cannot initialize model: the dataloader yielded no samples.") from None
    # The flag is still set before the forward pass (as originally), in case the
    # model consults it during the call -- NOTE(review): confirm whether it does.
    # It is reset if the pass fails, so a later call retries instead of
    # silently treating a half-built model as initialized.
    model.initialized_model = True
    try:
        model(*features)
    except BaseException:
        model.initialized_model = False
        raise
示例#5
0
    def on_fit_start(self, model, train_dataloader, dev_dataloader, loss,
                     optimizer, start_epoch: int, epochs: int) -> int:
        """
        Prepare logging and checkpoint output when a fit begins.

        The parent hook runs first and its returned epoch is kept. Then a
        configured log path is required, ``create_checkpoint_structure`` is
        called, the start is logged, ``self.start_time`` is recorded and a
        "warmup" progress entry is reported.

        :returns: The start epoch returned by the parent ``on_fit_start``.
        :raises RuntimeError: If ``get_log_path()`` returns None.
        """
        parent_args = (model, train_dataloader, dev_dataloader, loss,
                       optimizer, start_epoch, epochs)
        first_epoch = super().on_fit_start(*parent_args)

        logger_ready = get_log_path() is not None
        if not logger_ready:
            raise RuntimeError(
                "You must setup logger before calling the fit method. See babilim.core.logging.set_logger"
            )
        create_checkpoint_structure()

        info("Started fit.")
        self.start_time = time.time()
        log_progress(goal="warmup", progress=0, score=0)
        return first_epoch
示例#6
0
def _train(load_checkpoint,
           load_model,
           epochs=1,
           name=None,
           create_dataset=None,
           create_model=None,
           create_loss=None,
           create_optimizer=None,
           create_trainer=None,
           training_callbacks=DEFAULT_TRAINING_CALLBACKS):
    """
    The main training loop.

    Sets up logging for the run, builds the train and val datasets, the model,
    the loss, the optimizer and the trainer from the given factory callables,
    optionally loads model weights and/or a full checkpoint, then fits.

    :param load_checkpoint: Path of a checkpoint passed to ``trainer.restore``, or None.
    :param load_model: Path of model weights passed to ``load_weights``, or None.
    :param epochs: Number of epochs passed to ``trainer.fit``.
    :param name: Run name passed to ``_setup_logging``.
    :param create_dataset: Factory called with ``split=`` for the train and val splits.
    :param create_model: Factory called without arguments to build the model.
    :param create_loss: Factory called with ``model=``.
    :param create_optimizer: Factory called with ``model=`` and ``loss=``.
    :param create_trainer: Factory building the trainer from model, loss, optimizer,
        callbacks and data.
    :param training_callbacks: Callbacks handed to ``create_trainer``.
    """
    _setup_logging(name)

    def _build_split(split):
        # Datasets exposing `to_pytorch` are converted before use.
        data = create_dataset(split=split)
        return data.to_pytorch() if hasattr(data, "to_pytorch") else data

    train_data = _build_split(SPLIT_TRAIN)
    val_data = _build_split(SPLIT_VAL)

    model = create_model()
    init_model(model, train_data)
    if load_model is not None:
        logging.info("Loading model: {}".format(load_model))
        load_weights(load_model, model)

    loss = create_loss(model=model)
    optim = create_optimizer(model=model, loss=loss)
    trainer = create_trainer(
        model=model,
        loss=loss,
        optim=optim,
        callbacks=training_callbacks,
        train_data=train_data,
        val_data=val_data,
    )
    if load_checkpoint is not None:
        trainer.restore(load_checkpoint)
    trainer.fit(epochs=epochs)
示例#7
0
def run(config_class=None):
    """
    Run the cli interface.

    Parses the command line arguments (also provides a --help parameter), logs
    them for reproducibility, optionally restricts the visible CUDA devices,
    builds the config and hands everything to ``run_manual``.

    :param config_class: (Optional[Class]) A pointer to a class definition of a config.
        If provided there is no config parameter for the command line and the class is
        instantiated as ``config_class(name, input, output)``.
        Else the config specified in the command line will be loaded and instantiated.
    """
    parser = argparse.ArgumentParser(
        description='The main entry point for the script.')
    parser.add_argument('--mode',
                        type=str,
                        required=False,
                        default="train",
                        help='What mode should be run, train or test.')
    parser.add_argument('--input',
                        type=str,
                        required=True,
                        help='Folder where the dataset can be found.')
    # $RESULTS_PATH, when set, makes --output optional by providing its default.
    if "RESULTS_PATH" in os.environ:
        parser.add_argument(
            '--output',
            type=str,
            required=False,
            default=os.environ["RESULTS_PATH"],
            help='Folder where to save the results defaults to $RESULTS_PATH.')
    else:
        parser.add_argument(
            '--output',
            type=str,
            required=True,
            help=
            'Folder where to save the results, you can set $RESULTS_PATH as a default.'
        )
    if config_class is None:
        parser.add_argument('--config',
                            type=str,
                            required=True,
                            help='Configuration to use.')
    parser.add_argument(
        '--load_checkpoint',
        type=str,
        required=False,
        help=
        'Path to the checkpoint (model, loss, optimizer, trainer state) to load.'
    )
    parser.add_argument('--load_model',
                        type=str,
                        required=False,
                        help='Path to the model weights to load.')
    parser.add_argument('--name',
                        type=str,
                        default="",
                        required=False,
                        help='Name to give the run.')
    parser.add_argument(
        '--no_time_prefix_name',
        action='store_true',
        help='This flag will disable the time prefix for the name.')
    parser.add_argument('--device',
                        type=str,
                        default=None,
                        required=False,
                        help='CUDA device id')
    parser.add_argument(
        '--debug',
        action='store_true',
        help='This flag will make deeptech print debug messages.')
    args = parser.parse_args()

    if args.debug:
        logging.DEBUG_VERBOSITY = args.debug
        logging.debug(f"Set DEBUG_VERBOSITY={logging.DEBUG_VERBOSITY}")

    # Log args for reproducibility (every flag that influences the run).
    logging.info(f"Arg: --mode {args.mode}")
    logging.info(f"Arg: --input {args.input}")
    logging.info(f"Arg: --output {args.output}")
    if config_class is None:
        logging.info(f"Arg: --config {args.config}")
    logging.info(f"Arg: --load_checkpoint {args.load_checkpoint}")
    logging.info(f"Arg: --load_model {args.load_model}")
    logging.info(f"Arg: --name {args.name}")
    logging.info(f"Arg: --no_time_prefix_name {args.no_time_prefix_name}")
    logging.info(f"Arg: --device {args.device}")
    logging.info(f"Arg: --debug {args.debug}")

    # Set before any config/model code runs so the device restriction applies
    # from the start of the run.
    if args.device:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.device

    if config_class is None:
        config = import_config(args.config, args.name, args.input, args.output)
    else:
        config = config_class(args.name, args.input, args.output)
    if args.no_time_prefix_name:
        config.training_name_prefix_time = False
    run_manual(args.mode, config, args.load_checkpoint, args.load_model)
示例#8
0
    def restore(self, state_dict_path):
        """
        Restore the training state from a checkpoint file.

        ``self.epoch`` is set to the stored epoch plus one. For each of the
        model, optimizer and loss, the entry with the same key is passed to that
        object's ``load_state_dict`` when present. Missing entries only produce
        a warning; a missing ``"epoch"`` entry raises ``KeyError``.

        :param state_dict_path: Path of the checkpoint file read with ``load_state``.
        """
        # Load Checkpoint
        logging.info("Loading checkpoint: {}".format(state_dict_path))
        checkpoint = load_state(state_dict_path)
        # Continue with the epoch after the one stored in the checkpoint.
        self.epoch = checkpoint["epoch"] + 1

        # (checkpoint key == attribute name on self, debug message, warning message)
        restore_steps = (
            ("model", "Load Model...",
             "Could not find model_state in checkpoint."),
            ("optimizer", "Load Optimizer...",
             "Could not find optimizer_state in checkpoint."),
            ("loss", "Load Loss...",
             "Could not find loss_state in checkpoint."),
        )
        for key, load_msg, missing_msg in restore_steps:
            if key not in checkpoint:
                logging.warn(missing_msg)
                continue
            if logging.DEBUG_VERBOSITY:
                logging.info(load_msg)
            # Attribute looked up lazily, only when its entry is present.
            getattr(self, key).load_state_dict(checkpoint[key])

        if logging.DEBUG_VERBOSITY:
            logging.info("Trainable Variables:")
            # TODO
            logging.info("Untrainable Variables:")