def __recover_child_process_weights(self, best_path, last_path):
    """Copy results produced by a spawned child process back into this process.

    The trainer's checkpoint callback (when present) gets ``best_path`` as its
    ``best_model_path``. When ``last_path`` is given and the trainer is in the
    FITTING state, the state dict saved at ``last_path`` is loaded into the
    LightningModule.
    """
    trainer = self.lightning_module.trainer
    callback = trainer.checkpoint_callback
    if callback:
        # TODO: also transfer the best score
        callback.best_model_path = best_path

    # Nothing to reload unless a last checkpoint exists and we are fitting.
    if last_path is None or not (trainer.state == TrainerState.FITTING):
        return
    # Map every storage to where it was deserialized (CPU) instead of its saved device.
    state_dict = pl_load(last_path, map_location=lambda storage, loc: storage)
    self.lightning_module.load_state_dict(state_dict)
def __recover_child_process_weights(self, best_path, last_path):
    """Copy results produced by a spawned child process back into this process.

    Sets ``best_model_path`` on the trainer's checkpoint callback (if any) and,
    unless the trainer is testing, loads the state dict stored at ``last_path``
    into the LightningModule.
    """
    # TODO: is there a better way than accessing callback through model -> trainer -> callback?
    trainer = self.lightning_module.trainer
    callback = trainer.checkpoint_callback
    if callback:
        # TODO: also transfer the best score
        callback.best_model_path = best_path

    should_reload = last_path is not None and not trainer.testing
    if should_reload:
        # Deserialize onto CPU rather than the device the tensors were saved from.
        state_dict = pl_load(last_path, map_location=lambda storage, loc: storage)
        self.lightning_module.load_state_dict(state_dict)
def __recover_child_process_weights(self, model, best_path, last_path):
    """Bring a spawned child's results back into the main process.

    Sets ``best_model_path`` on the trainer's checkpoint callback (if any),
    reloads ``model`` from ``last_path`` when one exists and the trainer is not
    testing, and finally stores ``model`` on ``self.trainer.model``.
    """
    trainer = self.trainer
    if trainer.checkpoint_callback:
        # TODO: also transfer the best score
        trainer.checkpoint_callback.best_model_path = best_path

    if last_path is not None and not trainer.testing:
        # Deserialize onto CPU rather than the device the tensors were saved from.
        model.load_state_dict(pl_load(last_path, map_location=lambda storage, loc: storage))

    trainer.model = model
def load_best_model(self):
    """Load the weights of the best checkpoint tracked by the trainer's checkpoint callback.

    Reads ``self.trainer.checkpoint_callback.best_model_path`` and loads that
    checkpoint's ``"state_dict"`` into ``self.model``. If there is no checkpoint
    callback, or the callback has not recorded a best path yet (empty/None),
    an info message is logged and nothing is loaded.
    """
    checkpoint_callback = self.trainer.checkpoint_callback
    ckpt_path = checkpoint_callback.best_model_path if checkpoint_callback is not None else None
    # An empty best_model_path means no checkpoint has been written yet;
    # handing "" to pl_load would fail instead of reporting the situation.
    if not ckpt_path:
        logger.info(
            "No best model available to load. Did you run it more than 1 epoch?..."
        )
        return
    logger.info("Loading the best model...")
    logger.debug(f"Model Checkpoint: {ckpt_path}")
    # Deserialize onto CPU first; the caller decides where the model lives.
    ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
    self.model.load_state_dict(ckpt["state_dict"])
def load_from_checkpoint(cls, checkpoint_path: Union[str, IO], map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, strict: bool = True, **kwargs):
    """
    Build an instance of ``cls`` from the checkpoint at ``checkpoint_path``.

    This is NOT the path the trainer uses when restoring, but it is handy elsewhere.
    ``map_location`` defaults to keeping tensors on CPU. If the checkpoint records a
    ``"version"``, it is validated with ``_check_version`` first. The returned model
    remembers the source path in ``_loaded_checkpoint``.
    """
    location = map_location if map_location is not None else (lambda storage, loc: storage)
    checkpoint = pl_load(checkpoint_path, map_location=location)

    saved_version = checkpoint.get("version", None)
    if saved_version is not None:
        _check_version(saved_version)

    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
    model._loaded_checkpoint = checkpoint_path
    return model
def load_ckpt(self, n_epoch, logger=None, save_dir=None, experiment_id=None, run_id=None):
    """Restart training of the model from a saved checkpoint.

    Loads the checkpoint's weights into ``self.model`` and rebuilds ``self.trainer``
    with ``resume_from_checkpoint`` pointing at the same file.

    Args:
        n_epoch: epoch number of the checkpoint (``model-epoch={n}.ckpt``) or ``'last'``
            (uses ``last.ckpt`` if present, else ``self.get_last_step(folder)``).
        logger: logger whose ``save_dir``/``experiment_id``/``run_id`` locate the folder
            when ``run_id`` is not given; defaults to ``self.logger``.
        save_dir, experiment_id, run_id: explicit location of the run folder; used
            when ``run_id`` is not None.
    """
    # Drop cached outputs from a previous run.
    if self.model.anomaly_scores is not None:
        self.model.anomaly_scores = None
    if self.model.recon_x is not None:
        self.model.recon_x = None
    if logger is None:
        logger = self.logger

    if run_id is not None:
        folder_path = self._get_model_path(save_dir, experiment_id, run_id)
    else:
        folder_path = self._get_model_path(logger.save_dir, logger.experiment_id, logger.run_id)

    if n_epoch == 'last':
        if (folder_path / 'last.ckpt').exists():
            checkpoint_file = 'last.ckpt'
        else:
            n_epoch = self.get_last_step(folder_path)
            checkpoint_file = f'model-epoch={n_epoch}.ckpt'
    else:
        checkpoint_file = f'model-epoch={n_epoch}.ckpt'
    # n_epoch can still be the string 'last' here; comparing it to an int raises TypeError.
    if not isinstance(n_epoch, str) and n_epoch > self.config.trainer.max_epochs:
        self.config.trainer.max_epochs = n_epoch + 1

    # TODO: ideally check whether n_epoch actually is the best epoch.
    ckpt_path = str(folder_path / checkpoint_file)
    # When resuming the current run, prefer the best checkpoint tracked by ModelCheckpoint.
    # model_checkpoint may be None (see the callback list below), so guard the access.
    if (
        run_id is None
        and self.model_checkpoint is not None
        and self.model_checkpoint.best_model_path
    ):
        ckpt_path = self.model_checkpoint.best_model_path

    ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
    self.model.load_state_dict(ckpt['state_dict'])

    if torch.cuda.is_available() and self.config.trainer.use_gpu:
        self.config.trainer.args.gpus = 1
    else:
        self.config.trainer.args.gpus = 0

    callbacks = [self.progressbar]
    if self.early_stopping is not None:
        callbacks.append(self.early_stopping)
    if self.model_checkpoint is not None:
        callbacks.append(self.model_checkpoint)
    self.trainer = pl.Trainer(
        resume_from_checkpoint=ckpt_path,
        logger=self.logger,
        callbacks=callbacks,
        **(self.config.trainer.args),
    )
def tune_train_once(config, checkpoint_dir=None, args: argparse.Namespace = None, model_class: type = None, build_method=None, task_info: TaskInfo = None, model_kwargs: dict = None, resume: str = None, **kwargs):
    """Ray Tune trainable: build a model from ``args`` + the sampled ``config`` and fit it once.

    Args:
        config: hyperparameters sampled by Tune; merged into ``args`` in place.
        checkpoint_dir: directory holding a previous ``tune.ckpt`` to resume from, if any.
        args: base argparse namespace for the model and ``Trainer.from_argparse_args``.
        model_class: called as ``model_class(args, **model_kwargs)``.
        build_method: called as ``build_method(model, task_info)`` after construction.
        task_info: provides ``metric_name`` used for the reported Tune metric.
        model_kwargs: extra keyword arguments for ``model_class``; None means none.
        resume: ``'all'`` (default) resumes the full trainer state from the checkpoint;
            ``'model'`` restores only the model state and the epoch counter.
    """
    if resume is None:
        resume = 'all'
    # model_kwargs is unpacked with ** below, so a None default must become {}.
    if model_kwargs is None:
        model_kwargs = {}
    args_vars = vars(args)
    args_vars.update(config)  # vars() is args.__dict__, so this updates args in place

    pl.seed_everything(args.seed)

    logger = [
        loggers.CSVLogger(save_dir=tune.get_trial_dir(), name="", version="."),
        loggers.TensorBoardLogger(save_dir=tune.get_trial_dir(), name="", version=".", default_hp_metric=False),
    ]
    trainer_args = dict(
        logger=logger,
        progress_bar_refresh_rate=0,
        callbacks=[
            TuneReportCheckpointCallback(
                metrics={f'tune_{task_info.metric_name}': f'val_{task_info.metric_name}'},
                filename="tune.ckpt",
                on="validation_end",
            )
        ],
    )
    if checkpoint_dir and resume == 'all':
        trainer_args['resume_from_checkpoint'] = os.path.join(checkpoint_dir, "tune.ckpt")

    # fix slurm trainer -- presumably keeps Lightning from treating the trial as a
    # SLURM-managed job; TODO confirm against the Lightning version in use
    os.environ["SLURM_JOB_NAME"] = "bash"

    model = model_class(args, **model_kwargs)
    build_method(model, task_info)
    trainer: Trainer = Trainer.from_argparse_args(args, **trainer_args)

    if checkpoint_dir and resume == 'model':
        ckpt = pl_load(os.path.join(checkpoint_dir, "tune.ckpt"), map_location=lambda storage, loc: storage)
        # NOTE(review): this replaces the model prepared by build_method with one rebuilt
        # from the checkpoint -- confirm that is intended.
        model = model._load_model_state(ckpt)
        trainer.current_epoch = ckpt["epoch"]

    trainer.fit(model)
def pointpillars_kitti(*args, pretrained: bool = True, **kwargs) -> "tuple[PointPillars, int, Callable]":
    """Build a PointPillars model configured for KITTI.

    Args:
        pretrained: if True, load the released KITTI weights from ``ROOT_URL``.
        *args, **kwargs: accepted for a uniform factory signature; not used here.

    Returns:
        ``(model, 384, collate_fn)`` where ``collate_fn = get_collate_fn(model)``.
        The meaning of 384 is defined by the caller (presumably a batch/sample
        size -- TODO confirm).
    """
    import posixpath  # URLs always use '/', while os.path.join uses '\\' on Windows

    cfg = _ml3d.utils.Config.load_from_file(
        os.path.join(CONFIG_PATH, "pointpillars_kitti.yml"))
    cfg.model.device = "cpu"
    model = PointPillars(**cfg.model)
    if pretrained:
        weight_url = posixpath.join(ROOT_URL, "pointpillars_kitti_202012221652utc.pth")
        model.load_state_dict(
            pl_load(weight_url, map_location="cpu")["model_state_dict"],
        )
    # The collate function below is derived from the model's config, so set the
    # batcher regardless of whether pretrained weights were loaded.
    model.cfg.batcher = "ObjectDetectBatchCollator"
    return model, 384, get_collate_fn(model)
def randlanet_s3dis(*args, use_fold_5: bool = True, **kwargs) -> "tuple[RandLANet, int, Callable]":
    """Build a RandLA-Net model for S3DIS with pretrained weights.

    Args:
        use_fold_5: if True, load the weights released for the Area-5 split;
            otherwise load the other released S3DIS weights.
        *args, **kwargs: accepted for a uniform factory signature; not used here.

    Returns:
        ``(model, 32, collate_fn)`` where ``collate_fn = get_collate_fn(model)``.
        The meaning of 32 is defined by the caller (presumably a batch/sample
        size -- TODO confirm).
    """
    import posixpath  # URLs always use '/', while os.path.join uses '\\' on Windows

    cfg = _ml3d.utils.Config.load_from_file(
        os.path.join(CONFIG_PATH, "randlanet_s3dis.yml"))
    model = RandLANet(**cfg.model)
    if use_fold_5:
        weight_url = posixpath.join(ROOT_URL, "randlanet_s3dis_area5_202010091333utc.pth")
    else:
        weight_url = posixpath.join(ROOT_URL, "randlanet_s3dis_202010091238.pth")
    model.load_state_dict(
        pl_load(weight_url, map_location="cpu")["model_state_dict"])
    return model, 32, get_collate_fn(model)
def evaluate_experiment(self, name: str, gpus: int, nodes: int, version=None, output_path=None, evaluate_checkpoint=None):
    """
    Evaluate the experiment.

    :param name: The name of the family of experiments you are conducting.
    :param gpus: The number of gpus used for training.
    :param nodes: The number of nodes used for training.
    :param version: The name for the specific run of the experiment in the family (defaults to a timestamp).
    :param output_path: The path where to store the outputs of the experiment
        (defaults to the current working directory at call time).
    :param evaluate_checkpoint: The checkpoint to load: ``"last"`` to search for the latest one,
        otherwise a file name relative to ``<output_path>/<name>/<version>``. Required.
    :raises RuntimeError: if no checkpoint is given or the resolved checkpoint does not exist.
    """
    # Resolve at call time: an ``os.getcwd()`` default would be frozen when the module is imported.
    if output_path is None:
        output_path = os.getcwd()
    if version is None:
        version = _generate_version()
    if evaluate_checkpoint is None:
        raise RuntimeError(
            "No checkpoint provided for evaluation, you must provide one.")
    self.output_path = os.path.join(output_path, name, version)
    if evaluate_checkpoint == "last":
        checkpoint_path = self._find_checkpoint(name, version, output_path)
    else:
        checkpoint_path = os.path.join(self.output_path, evaluate_checkpoint)
    if checkpoint_path is None or not os.path.exists(checkpoint_path):
        raise RuntimeError(
            f"Checkpoint does not exist: {str(checkpoint_path)}")
    self.testing = True
    trainer = pl.Trainer(
        default_root_dir=output_path,
        max_epochs=getattr(self.hparams, "max_epochs", 1000),
        gpus=gpus,
        num_nodes=nodes,
        logger=TensorBoardLogger(save_dir=output_path, version=version, name=name,
                                 log_graph=hasattr(self, "example_input_array"),
                                 default_hp_metric=False),
        accelerator="ddp" if gpus > 1 else None)
    # Deserialize onto CPU; Lightning moves the module to the configured devices.
    ckpt = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
    self.load_state_dict(ckpt['state_dict'])
    trainer.test(self)
def load_checkpoint(
    self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
) -> Dict[str, Any]:
    """
    Read a checkpoint with ``pl_load``, which wraps :func:`torch.load` and adds
    ``fsspec`` handling for remote files.

    Args:
        path: location of the checkpoint
        map_location: a function, :class:`torch.device`, string or dict describing how
            storages are remapped; the default keeps them on CPU

    Returns:
        The deserialized checkpoint.
    """
    checkpoint = pl_load(path, map_location=map_location)
    return checkpoint
def __test_using_best_weights(self, ckpt_path, test_dataloaders):
    """Run the test loop, optionally loading weights from ``ckpt_path`` first.

    Args:
        ckpt_path: None to test the current weights, ``'best'`` for the checkpoint
            callback's ``best_model_path``, or an explicit checkpoint path.
        test_dataloaders: optional dataloaders to attach before testing.

    Returns:
        The value returned by ``self.fit(model)`` in testing mode, or ``{}`` when
        the resolved checkpoint path is empty.

    Raises:
        MisconfigurationException: if ``'best'`` is requested but no best path exists.
    """
    model = self.get_model()

    # if user requests the best checkpoint but we don't have it, error
    if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path:
        raise MisconfigurationException(
            'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.'
        )

    # load best weights
    if ckpt_path is not None:
        # ckpt_path is 'best' so load the best model
        if ckpt_path == 'best':
            ckpt_path = self.checkpoint_callback.best_model_path

        if len(ckpt_path) == 0:
            rank_zero_warn(
                f'.test() found no path for the best weights, {ckpt_path}. Please '
                f'specify a path for a checkpoint .test(ckpt_path=PATH)')
            return {}

        if self.accelerator_backend is not None:
            self.accelerator_backend.barrier()

        ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(ckpt['state_dict'])

    # attach dataloaders
    if test_dataloaders is not None:
        self.data_connector.attach_dataloaders(
            model, test_dataloaders=test_dataloaders)

    # run tests
    self.tested_ckpt_path = ckpt_path
    self.testing = True
    os.environ['PL_TESTING_MODE'] = '1'
    self.model = model
    try:
        results = self.fit(model)
    finally:
        # Leave testing mode even if fit() raises, so the trainer flag and the
        # process-wide env var are not left stale for later fit/test calls.
        self.testing = False
        os.environ.pop('PL_TESTING_MODE', None)

    # teardown
    if self.is_function_implemented('teardown'):
        model_ref = self.get_model()
        model_ref.teardown('test')

    return results
def __evaluate_using_weights(
    self,
    model,
    ckpt_path: Optional[str] = None,
    dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
):
    """Run the evaluation loop on ``model``, optionally loading weights from ``ckpt_path``.

    ``ckpt_path='best'`` resolves to the checkpoint callback's ``best_model_path``.
    If the resolved path is empty, a warning is emitted and ``{}`` is returned
    without evaluating. Otherwise the result of ``self.fit(model)`` is returned.
    The evaluated path is recorded in ``validated_ckpt_path`` or ``tested_ckpt_path``
    depending on ``self.validating``.
    """
    use_best = ckpt_path == 'best'
    # if user requests the best checkpoint but we don't have it, error
    if use_best and not self.checkpoint_callback.best_model_path:
        raise MisconfigurationException(
            'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.'
        )

    if ckpt_path is not None:
        if use_best:
            ckpt_path = self.checkpoint_callback.best_model_path
        if len(ckpt_path) == 0:
            rank_zero_warn(
                f'`.test()` found no path for the best weights, {ckpt_path}. Please'
                ' specify a path for a checkpoint `.test(ckpt_path=PATH)`'
            )
            return {}
        # synchronize processes before reading the checkpoint
        self.training_type_plugin.barrier()
        state = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(state['state_dict'])

    if dataloaders is not None:
        self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders)

    # remember which checkpoint was evaluated, per mode
    if self.validating:
        self.validated_ckpt_path = ckpt_path
    else:
        self.tested_ckpt_path = ckpt_path

    results = self.fit(model)

    if self.is_function_implemented('teardown', model=model):
        model.teardown('test')

    return results
def restore(self, checkpoint_path: str, on_gpu: bool):
    """
    Read the checkpoint file at ``checkpoint_path`` and restore state from it.

    In order this calls ``on_load_checkpoint`` on the datamodule (if any) and on the
    model, loads ``checkpoint['state_dict']`` into the model, moves the model to
    ``trainer.root_gpu`` when ``on_gpu`` is set, restores native-AMP scaler or apex
    AMP state when present, and finally restores the remaining training state
    (see ``dump_checkpoint`` for what the checkpoint contains).
    """
    # Always deserialize onto CPU first; the model is moved to GPU below if requested.
    checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

    trainer = self.trainer
    model = trainer.get_model()

    datamodule = trainer.datamodule
    if datamodule is not None:
        datamodule.on_load_checkpoint(checkpoint)
    model.on_load_checkpoint(checkpoint)
    model.load_state_dict(checkpoint['state_dict'])

    if on_gpu:
        model.cuda(trainer.root_gpu)

    backend = trainer.amp_backend
    if backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
        trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
    elif backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
        amp.load_state_dict(checkpoint['amp_scaling_state'])

    # training state (affects the trainer only)
    self.restore_training_state(checkpoint)
def trainWithTune(config, checkpoint_dir=None, datamodule=None, num_epochs=10, num_gpus=0):
    """Ray Tune trainable for ``MMETrainingModule``.

    Learning rate, lr ratio and decay are sampled in log10 space (``config['log_*']``)
    and exponentiated here. When ``checkpoint_dir`` is given, the model and epoch
    counter are restored from ``<checkpoint_dir>/checkpoint``.
    """
    tb_logger = TensorBoardLogger(save_dir=tune.get_trial_dir(), name="", version=".")
    report = TuneReportCheckpointCallback(
        metrics={
            "loss": "val_loss",
            "mean_accuracy": "val_acc",
            "mean_iou": "val_iou",
        },
        filename="checkpoint",
        on="validation_end",
    )
    trainer = Trainer(
        max_epochs=num_epochs,
        gpus=math.ceil(num_gpus),  # fractional GPU requests are rounded up to an int
        logger=tb_logger,
        progress_bar_refresh_rate=0,
        callbacks=[report],
    )

    def module_kwargs():
        # hyperparameters come in log10 space
        return dict(
            lr=10**config['log_lr'],
            lrRatio=10**config['log_lrRatio'],
            decay=10**config['log_decay'],
            num_cls=NUM_CLS,
        )

    if not checkpoint_dir:
        model = MMETrainingModule(**module_kwargs())
    else:
        # load_from_checkpoint(os.path.join(checkpoint, "checkpoint")) currently errors,
        # so rebuild the model from the raw checkpoint dict instead.
        ckpt = pl_load(os.path.join(checkpoint_dir, "checkpoint"),
                       map_location=lambda storage, loc: storage)
        model = MMETrainingModule._load_model_state(ckpt, **module_kwargs())
        trainer.current_epoch = ckpt["epoch"]

    trainer.fit(model, datamodule=datamodule)
def restore(self, checkpoint_path: str, on_gpu: bool):
    """
    Restore model and training state from the checkpoint at ``checkpoint_path``.

    Steps: datamodule and model ``on_load_checkpoint`` hooks, the model
    ``state_dict``, an optional move to ``trainer.root_gpu``, AMP scaler state
    (native or apex, when stored), then the trainer's training state
    (epoch, callbacks, schedulers, optimizers, ...).
    """
    # CPU first; the model is moved to the GPU explicitly afterwards.
    checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

    trainer = self.trainer
    model = trainer.get_model()

    # hooks: let the datamodule and the model pull their own entries
    if trainer.datamodule is not None:
        trainer.datamodule.on_load_checkpoint(checkpoint)
    model.on_load_checkpoint(checkpoint)

    model.load_state_dict(checkpoint['state_dict'])
    if on_gpu:
        model.cuda(trainer.root_gpu)

    amp_backend = trainer.amp_backend
    has_native_state = 'native_amp_scaling_state' in checkpoint
    has_apex_state = 'amp_scaling_state' in checkpoint
    if amp_backend == AMPType.NATIVE and has_native_state:
        trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
    elif amp_backend == AMPType.APEX and has_apex_state:
        amp.load_state_dict(checkpoint['amp_scaling_state'])

    self.restore_training_state(checkpoint)
def restore(self, checkpoint_path: str, on_gpu: bool):
    """
    Load model/training state from a PyTorch-Lightning checkpoint file.

    The model and datamodule state is restored via ``restore_model_state``; the model
    is moved to ``trainer.root_gpu`` when ``on_gpu`` is set; the rest of the training
    state is restored via ``restore_training_state``. See ``dump_checkpoint`` for the
    full list of stored entries.
    """
    state = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
    lightning_model = self.trainer.get_model()

    self.restore_model_state(lightning_model, state)
    if on_gpu:
        lightning_model.cuda(self.trainer.root_gpu)

    self.restore_training_state(state)
def train_mnist_tune_checkpoint(config, checkpoint=None):
    """Ray Tune trainable for ``LightningMNISTClassifier`` with checkpoint support.

    Args:
        config: sampled hyperparameters; must contain ``"data_dir"``.
        checkpoint: directory containing a ``checkpoint`` file to resume from, or None.
    """
    trainer = pl.Trainer(
        max_epochs=10,
        progress_bar_refresh_rate=0,
        callbacks=[CheckpointCallback(), TuneReportCallback()])
    if checkpoint:
        # LightningMNISTClassifier.load_from_checkpoint(...) currently errors here,
        # so rebuild the model from the raw checkpoint dict instead.
        ckpt = pl_load(
            os.path.join(checkpoint, "checkpoint"),
            map_location=lambda storage, loc: storage)
        # Pass data_dir like the fresh-model branch does, so a restored trial
        # reads data from the same place as a new one.
        model = LightningMNISTClassifier._load_model_state(
            ckpt, config=config, data_dir=config["data_dir"])
        trainer.current_epoch = ckpt["epoch"]
    else:
        model = LightningMNISTClassifier(
            config=config, data_dir=config["data_dir"])

    trainer.fit(model)
def __load_ckpt_weights(
    self,
    model,
    ckpt_path: Optional[str] = None,
) -> Optional[str]:
    """Load ``state_dict`` from ``ckpt_path`` into ``model`` and return the path used.

    ``None`` keeps the current weights and returns ``None``. ``'best'`` resolves to
    the checkpoint callback's ``best_model_path``.

    Raises:
        MisconfigurationException: if ``'best'`` is requested but no best path was
            recorded (with a dedicated message under ``fast_dev_run``), or if the
            resolved path is empty.
    """
    if ckpt_path is None:
        return None

    # interpolated into error messages as the method name, e.g. ``.test()``
    entry_point = self.state.value

    if ckpt_path == 'best':
        best_path = self.checkpoint_callback.best_model_path
        # if user requests the best checkpoint but we don't have it, error
        if not best_path:
            if self.fast_dev_run:
                raise MisconfigurationException(
                    f'You cannot execute `.{entry_point}()` with `fast_dev_run=True` unless you do'
                    f' `.{entry_point}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.'
                )
            raise MisconfigurationException(
                f'`.{entry_point}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
            )
        ckpt_path = best_path

    if not ckpt_path:
        raise MisconfigurationException(
            f'`.{entry_point}()` found no path for the best weights: "{ckpt_path}". Please'
            f' specify a path for a checkpoint `.{entry_point}(ckpt_path=PATH)`'
        )

    # only one process running at this point for TPUs, as spawn isn't triggered yet
    if self._device_type != DeviceType.TPU:
        self.training_type_plugin.barrier()

    checkpoint = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    return ckpt_path
def hpc_load(self, checkpoint_path: str, on_gpu: bool):
    """
    Restore model/training state from a PyTorch-Lightning HPC checkpoint file.

    ``on_gpu`` is accepted but not read here: the model is moved to
    ``trainer.root_gpu`` whenever that is not None. After restoring, the model's
    ``on_hpc_load`` hook receives the checkpoint. See ``dump_checkpoint`` for the
    stored entries.
    """
    checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
    lightning_module = self.trainer.lightning_module

    self.restore_model_state(lightning_module, checkpoint)

    gpu_index = self.trainer.root_gpu
    if gpu_index is not None:
        lightning_module.cuda(gpu_index)

    self.restore_training_state(checkpoint)
    lightning_module.on_hpc_load(checkpoint)
def load_checkpoint(
    self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
) -> Dict[str, Any]:
    """Read a checkpoint with ``pl_load`` (:func:`torch.load` plus ``fsspec`` remote support).

    Args:
        path: location of the checkpoint
        map_location: a function, :class:`torch.device`, string or dict describing how
            storages are remapped; the default keeps them on CPU

    Returns:
        The deserialized checkpoint.

    Raises:
        FileNotFoundError: if the ``fsspec`` filesystem for ``path`` reports it missing
    """
    filesystem = get_filesystem(path)
    if filesystem.exists(path):
        return pl_load(path, map_location=map_location)
    raise FileNotFoundError(f"Checkpoint at {path} not found. Aborting training.")
def hpc_load(self, folderpath, on_gpu):
    """Restore from the highest-numbered ``hpc_ckpt_<n>.ckpt`` file in ``folderpath``.

    ``on_gpu`` is accepted but not read; the model is moved to ``trainer.root_gpu``
    whenever that is not None. Calls the model's ``on_hpc_load`` hook and logs the
    restored file path.
    """
    latest_index = self.max_ckpt_in_folder(folderpath)
    filepath = f'{folderpath}/hpc_ckpt_{latest_index}.ckpt'

    # CPU first; moved to the GPU below when one is configured
    checkpoint = pl_load(filepath, map_location=lambda storage, loc: storage)

    model = self.trainer.get_model()
    self.restore_model_state(model, checkpoint)

    gpu_index = self.trainer.root_gpu
    if gpu_index is not None:
        model.cuda(gpu_index)

    # training state (affects the trainer only)
    self.restore_training_state(checkpoint)

    model.on_hpc_load(checkpoint)
    log.info(f'restored hpc model from: {filepath}')
def train_mnist_tune_checkpoint(config, checkpoint_dir=None, data_dir=None, num_epochs=10, num_gpus=0):
    """Ray Tune trainable for ``LightningMNISTClassifier``.

    Reports validation loss/accuracy to Tune and writes a ``checkpoint`` file at the
    end of each validation. When ``checkpoint_dir`` is given, the model and epoch
    counter are restored from ``<checkpoint_dir>/checkpoint``.
    """
    tb_logger = TensorBoardLogger(
        save_dir=tune.get_trial_dir(), name="", version=".")
    tune_callback = TuneReportCheckpointCallback(
        metrics={
            "loss": "ptl/val_loss",
            "mean_accuracy": "ptl/val_accuracy"
        },
        filename="checkpoint",
        on="validation_end")
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=math.ceil(num_gpus),  # fractional GPU requests are rounded up to an int
        logger=tb_logger,
        progress_bar_refresh_rate=0,
        callbacks=[tune_callback])

    if not checkpoint_dir:
        model = LightningMNISTClassifier(config=config, data_dir=data_dir)
    else:
        # load_from_checkpoint(...) currently errors here, so rebuild from the raw dict.
        ckpt = pl_load(
            os.path.join(checkpoint_dir, "checkpoint"),
            map_location=lambda storage, loc: storage)
        model = LightningMNISTClassifier._load_model_state(
            ckpt, config=config, data_dir=data_dir)
        trainer.current_epoch = ckpt["epoch"]

    trainer.fit(model)
from shared import MODELS_FOLDER
from tqdm.auto import tqdm

if __name__ == "__main__":
    SRC_LANG = "en"
    TGT_LANG = "de"
    ENSEMBLE = 10
    LAST_CKPT = 63

    # NOTE(review): ``os`` and ``pl_load`` are not imported in the visible lines;
    # presumably they are imported above -- confirm.
    model_dir = os.path.join(MODELS_FOLDER, 'transformer')
    prefix = f'trafo_{SRC_LANG}_{TGT_LANG}'

    print('Loading...')
    # state dicts of the ENSEMBLE epochs preceding LAST_CKPT (LAST_CKPT itself excluded)
    trafo_states = [
        pl_load(os.path.join(model_dir, f'{prefix}_{epoch}.pt'),
                map_location=lambda storage, loc: storage)['state_dict']
        for epoch in tqdm(range(LAST_CKPT - ENSEMBLE, LAST_CKPT))
    ]

    # Average models into model: element-wise mean of every entry across the ensemble
    print("Averaging...")
    avg_state = {
        key: sum(state[key] for state in trafo_states) / len(trafo_states)
        for key in trafo_states[-1]
    }

    print('saving...')
    # the LAST_CKPT checkpoint is used as a template whose weights are replaced
    avg_ckpt = pl_load(os.path.join(model_dir, f'{prefix}_{LAST_CKPT}.pt'),
                       map_location=lambda storage, loc: storage)
    avg_ckpt['state_dict'] = avg_state
    # NOTE(review): nothing visible here writes avg_ckpt to disk -- confirm the save step exists.
def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor, reduce_lr_on_plateau):
    """
    Test that when a model checkpoint is saved, it saves with
    the correct score appended to ckpt_path and checkpoint data.

    Trains for ``max_epochs`` with ``save_top_k=-1`` (one checkpoint per epoch) and
    checks each file's name, stored epoch/global_step, ModelCheckpoint callback state,
    LR-scheduler state and the debugger's recorded scheduler updates.
    """
    max_epochs = 3
    limit_train_batches = 5
    limit_val_batches = 7
    lr = 1e-1

    class CustomBoringModel(BoringModel):

        def __init__(self):
            super().__init__()
            # Pre-generated random values logged per (epoch, batch); the expected
            # score for an epoch is the mean of one row.
            self.train_log_epochs = torch.randn(max_epochs, limit_train_batches)
            self.val_logs = torch.randn(max_epochs, limit_val_batches)

        def training_step(self, batch, batch_idx):
            log_value = self.train_log_epochs[self.current_epoch, batch_idx]
            self.log('train_log', log_value, on_epoch=True)
            return super().training_step(batch, batch_idx)

        def validation_step(self, batch, batch_idx):
            log_value = self.val_logs[self.current_epoch, batch_idx]
            self.log('val_log', log_value)
            self.log('epoch', self.current_epoch, on_epoch=True)
            return super().validation_step(batch, batch_idx)

        def configure_optimizers(self):
            optimizer = optim.SGD(self.parameters(), lr=lr)

            if reduce_lr_on_plateau:
                lr_scheduler = {
                    'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
                    'monitor': monitor,
                    'strict': True,
                }
            else:
                lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

            return [optimizer], [lr_scheduler]

    # Produces e.g. '{val_log:.4f}-{epoch}' -> files like 'val_log=0.1234-epoch=0.ckpt'
    filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
    checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)

    model = CustomBoringModel()

    if validation_step is None:
        model.validation_step = None
    if val_dataloaders is None:
        # NOTE(review): this sets ``val_dataloaders`` (plural), an attribute the hooks
        # shown here do not read; ``val_dataloader`` may have been intended -- confirm.
        model.val_dataloaders = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint],
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        max_epochs=max_epochs,
        progress_bar_refresh_rate=0,
    )
    results = trainer.fit(model)
    assert results
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
    # monitored values as recorded by the dev debugger, one per epoch
    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
    assert len(ckpt_files) == len(scores) == max_epochs
    assert len(lr_scheduler_debug) == max_epochs

    for epoch in range(max_epochs):
        score = scores[epoch]
        # monitor + 's' names the tensor holding that metric's values
        # ('val_log' -> val_logs, 'train_log_epoch' -> train_log_epochs)
        expected_score = getattr(model, f'{monitor}s')[epoch].mean().item()
        expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
        assert math.isclose(score, expected_score, rel_tol=1e-4)

        chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
        assert chk['epoch'] == epoch + 1
        assert chk['global_step'] == limit_train_batches * (epoch + 1)

        mc_specific_data = chk['callbacks'][type(checkpoint)]
        assert mc_specific_data['dirpath'] == checkpoint.dirpath
        assert mc_specific_data['monitor'] == monitor
        assert mc_specific_data['current_score'] == score

        if not reduce_lr_on_plateau:
            lr_scheduler_specific_data = chk['lr_schedulers'][0]
            assert lr_scheduler_specific_data['_step_count'] == epoch + 2
            # NOTE(review): ``lr * lr**(epoch + 1)`` only matches StepLR's decay because
            # its default gamma (0.1) happens to equal lr here -- confirm.
            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))

        # monitor value/key are only recorded for the ReduceLROnPlateau scheduler
        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
    """Deserialize the checkpoint at ``checkpoint_path`` with every storage kept on CPU."""

    def _keep_storage(storage, loc):
        # ignore the saved device location and keep the freshly loaded (CPU) storage
        return storage

    return pl_load(checkpoint_path, map_location=_keep_storage)
def load_from_checkpoint(
    cls,
    checkpoint_path: str,
    *args,
    map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
    hparams_file: Optional[str] = None,
    strict: bool = True,
    **kwargs,
):
    r"""
    Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
    it stores the arguments passed to `__init__` in the checkpoint under `module_arguments`

    Any arguments specified through \*args and \*\*kwargs will override args stored in `hparams`.

    Args:
        checkpoint_path: Path to checkpoint. This can also be a URL.
        args: Any positional args needed to init the model.
        map_location:
            If your checkpoint saved a GPU model and you now load on CPUs
            or a different number of GPUs, use this to map to the new setup.
            The behaviour is the same as in :func:`torch.load`. Defaults to
            keeping all tensors on CPU.
        hparams_file: Optional path to a .csv, .yml or .yaml file; a .yaml file may have
            hierarchical structure as in this example::

                drop_prob: 0.2
                dataloader:
                    batch_size: 32

            You most likely won't need this since Lightning will always save the hyperparameters
            to the checkpoint.
            However, if your checkpoint weights don't have the hyperparameters saved,
            use this method to pass in a file with the hparams you'd like to use.
            These will be converted into a :class:`~dict` and passed into your
            :class:`LightningModule` for use. They replace any hparams stored in the
            checkpoint.

            If your model's `hparams` argument is :class:`~argparse.Namespace`
            and .yaml file has hierarchical structure, you need to refactor your model to treat
            `hparams` as :class:`~dict`.
        strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
            returned by this module's state dict. Default: `True`.
        kwargs: Any keyword args needed to init the model. They also override the
            stored hparams.

    Return:
        :class:`LightningModule` with loaded weights and hyperparameters (if available).

    Raises:
        ValueError: if ``hparams_file`` does not end in .csv, .yml or .yaml.

    Example:
        .. code-block:: python

            # load weights without mapping ...
            MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

            # or load weights mapping all weights from GPU 1 to GPU 0 ...
            map_location = {'cuda:1':'cuda:0'}
            MyLightningModule.load_from_checkpoint(
                'path/to/checkpoint.ckpt',
                map_location=map_location
            )

            # or load weights and hyperparameters from separate files.
            MyLightningModule.load_from_checkpoint(
                'path/to/checkpoint.ckpt',
                hparams_file='/path/to/hparams_file.yaml'
            )

            # override some of the params with new values
            pretrained_model = MyLightningModule.load_from_checkpoint(
                PATH,
                num_layers=128,
                pretrained_ckpt_path=NEW_PATH,
            )

            # predict
            pretrained_model.eval()
            pretrained_model.freeze()
            y_hat = pretrained_model(x)
    """
    # Default: keep every storage on CPU instead of the device it was saved from.
    location = map_location if map_location is not None else (lambda storage, loc: storage)
    checkpoint = pl_load(checkpoint_path, map_location=location)

    if hparams_file is not None:
        extension = hparams_file.split('.')[-1].lower()
        # Compare exactly: ``extension in ('csv')`` is a substring test on the str 'csv',
        # which wrongly accepted extensions such as 'c', 'sv' or ''.
        if extension == 'csv':
            hparams = load_hparams_from_tags_csv(hparams_file)
        elif extension in ('yml', 'yaml'):
            hparams = load_hparams_from_yaml(hparams_file)
        else:
            raise ValueError('.csv, .yml or .yaml is required for `hparams_file`')

        hparams['on_gpu'] = False

        # overwrite hparams by the given file
        checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

    # for past checkpoint need to add the new key
    if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
        checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
    # override the hparams with values that were passed in
    checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)

    model = cls._load_model_state(checkpoint, *args, strict=strict, **kwargs)
    return model
def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Verify resuming from checkpoint runs the right number of epochs.

    Trains once for 2 epochs saving every checkpoint, then resumes a fresh model from
    each saved checkpoint (optionally served over HTTP when ``url_ckpt`` is set) and
    checks that batches already done plus batches run after resuming cover exactly
    ``max_epochs`` worth of training, and that ``on_load_checkpoint`` fired once.
    """
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv('TORCH_HOME', tmpdir)

    hparams = EvalModelTemplate.get_default_hparams()

    def _new_model():
        # Create a model that tracks epochs and batches seen
        model = EvalModelTemplate(**hparams)
        model.num_epochs_seen = 0
        model.num_batches_seen = 0
        model.num_on_load_checkpoint_called = 0

        def increment_epoch(self):
            self.num_epochs_seen += 1

        def increment_batch(self, _):
            self.num_batches_seen += 1

        def increment_on_load_checkpoint(self, _):
            self.num_on_load_checkpoint_called += 1

        # Bind methods to keep track of epoch numbers, batch numbers it has seen
        # as well as number of times it has called on_load_checkpoint()
        model.on_epoch_end = types.MethodType(increment_epoch, model)
        model.on_batch_start = types.MethodType(increment_batch, model)
        model.on_load_checkpoint = types.MethodType(increment_on_load_checkpoint, model)
        return model

    model = _new_model()

    trainer_options = dict(
        progress_bar_refresh_rate=0,
        max_epochs=2,
        limit_train_batches=0.65,
        limit_val_batches=1,
        # save_top_k=-1 keeps a checkpoint for every epoch
        # NOTE(review): this same ModelCheckpoint instance is reused by every resumed
        # Trainer below -- confirm that sharing is intended.
        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
        default_root_dir=tmpdir,
        early_stop_callback=False,
        val_check_interval=1.,
    )

    trainer = Trainer(**trainer_options)
    # fit model
    trainer.fit(model)

    training_batches = trainer.num_training_batches

    assert model.num_epochs_seen == 2
    assert model.num_batches_seen == training_batches * 2
    assert model.num_on_load_checkpoint_called == 0

    # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
    # (NOTE(review): no commented-out checkpoints remain; this comment looks stale)
    checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))
    if url_ckpt:
        # transform local paths into url checkpoints
        ip, port = tmpdir_server
        checkpoints = [f'http://{ip}:{port}/' + os.path.basename(check) for check in checkpoints]

    for check in checkpoints:
        next_model = _new_model()
        state = pl_load(check)

        # Resume training
        trainer_options['max_epochs'] = 2
        new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check)
        new_trainer.fit(next_model)
        # steps recorded in the checkpoint + steps run after resuming == a full run
        assert state['global_step'] + next_model.num_batches_seen == training_batches * trainer_options['max_epochs']
        assert next_model.num_on_load_checkpoint_called == 1
def load_checkpoint_hparams(checkpoint_path):
    """Return the ``'hyper_parameters'`` entry stored in the checkpoint at ``checkpoint_path``."""
    return pl_load(checkpoint_path)['hyper_parameters']
def assert_checkpoint_content(ckpt_dir):
    """Assert the newest checkpoint in ``ckpt_dir`` records ``epochs`` epochs and global step 4.

    ``epochs`` is not a parameter; it is read from an enclosing/module scope.
    """
    latest = pl_load(get_last_checkpoint(ckpt_dir))
    assert latest["epoch"] == epochs
    assert latest["global_step"] == 4