def load_checkpoint(
    self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
) -> Dict[str, Any]:
    """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.

    Args:
        path: Path to checkpoint
        map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations.

    Returns: The loaded checkpoint.

    Raises:
        FileNotFoundError: If ``path`` is not found by the ``fsspec`` filesystem
    """
    # Try to read the checkpoint at `path`. If not exist, do not restore checkpoint.
    fs = get_filesystem(path)
    if not fs.exists(path):
        raise FileNotFoundError(f"Checkpoint at {path} not found. Aborting training.")
    return pl_load(path, map_location=map_location)

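# A hedged usage sketch of the pattern above (not part of the original class): `get_filesystem`
# resolves an fsspec filesystem for the path, so the same existence check and `torch.load`
# call work for local and remote checkpoints alike. The checkpoint path below is illustrative only.
import fsspec
import torch

ckpt_path = "checkpoints/last.ckpt"  # hypothetical path
fs = fsspec.filesystem("file")  # local filesystem; a remote protocol would work the same way
if fs.exists(ckpt_path):
    with fs.open(ckpt_path, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
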
def __init_ckpt_dir(self, filepath, save_top_k):
    self._fs = get_filesystem(filepath if filepath is not None else "")

    if (
        save_top_k is not None
        and save_top_k > 0
        and filepath is not None
        and self._fs.isdir(filepath)
        and len(self._fs.ls(filepath)) > 0
    ):
        rank_zero_warn(
            f"Checkpoint directory {filepath} exists and is not empty with save_top_k={save_top_k}."
            " All files in this directory will be deleted when a checkpoint is saved!"
        )

    if not filepath:  # will be determined by trainer at runtime
        self.dirpath, self.filename = None, None
    else:
        if self._fs.isdir(filepath):
            self.dirpath, self.filename = filepath, None
        else:
            if self._fs.protocol == "file":  # don't normalize remote paths
                filepath = os.path.realpath(filepath)
            self.dirpath, self.filename = os.path.split(filepath)

def __init__(self, output_filename: Optional[str] = None, line_count_restriction: float = 1.0):
    """
    Args:
        output_filename: optionally save profile results to file instead of printing
            to std out when training is finished.
        line_count_restriction: this can be used to limit the number of functions
            reported for each action. either an integer (to select a count of lines),
            or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
    """
    self.profiled_actions = {}
    self.line_count_restriction = line_count_restriction

    self.output_fname = output_filename
    self.output_file = None
    if self.output_fname:
        fs = get_filesystem(self.output_fname)
        self.output_file = fs.open(self.output_fname, "w")

    streaming_out = [self.output_file.write] if self.output_file else [log.info]
    super().__init__(output_streams=streaming_out)

def hpc_save(self, folderpath: str, logger):
    # make sure the checkpoint folder exists
    folderpath = str(folderpath)  # because the tests pass a path object
    fs = get_filesystem(folderpath)
    fs.makedirs(folderpath, exist_ok=True)

    # save logger to make sure we get all the metrics
    logger.save()

    max_suffix = self.max_ckpt_in_folder(folderpath)
    ckpt_number = (max_suffix if max_suffix is not None else 0) + 1

    fs.makedirs(folderpath, exist_ok=True)

    filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

    # give model a chance to do something on hpc_save
    model = self.trainer.get_model()
    checkpoint = self.dump_checkpoint()

    model.on_hpc_save(checkpoint)

    if self.trainer.accelerator_backend:
        checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)

    # do the actual save
    # TODO: fix for anything with multiprocess DP, DDP, DDP2
    try:
        atomic_save(checkpoint, filepath)
    except AttributeError as err:
        if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
            del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
        rank_zero_warn(
            'warning, `hyper_parameters` dropped from checkpoint.'
            f' An attribute is not picklable {err}'
        )
        atomic_save(checkpoint, filepath)

    return filepath

def hpc_save(self, folderpath: str, logger: Optional[LightningLoggerBase]) -> str:
    # make sure the checkpoint folder exists
    folderpath = str(folderpath)  # because the tests pass a path object
    fs = get_filesystem(folderpath)
    fs.makedirs(folderpath, exist_ok=True)

    # save logger to make sure we get all the metrics
    if logger:
        logger.finalize("finished")

    max_suffix = self.max_ckpt_version_in_folder(folderpath)
    ckpt_number = (max_suffix if max_suffix is not None else 0) + 1

    fs.makedirs(folderpath, exist_ok=True)
    filepath = os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt")

    # give model a chance to do something on hpc_save
    model = self.trainer.lightning_module
    checkpoint = self.dump_checkpoint()

    # TODO: remove this in v1.8.
    model.on_hpc_save(checkpoint)

    # do the actual save
    # TODO: fix for anything with multiprocess DP, DDP, DDP2
    try:
        atomic_save(checkpoint, filepath)
    except AttributeError as err:
        if pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
            del checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
        rank_zero_warn(
            f"warning, `hyper_parameters` dropped from checkpoint. An attribute is not picklable {err}"
        )
        atomic_save(checkpoint, filepath)
    return filepath

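# The `max_ckpt_version_in_folder` helper is not shown above. A minimal sketch of the
# numbering scheme it presumably implements, using only `fsspec` and `re`; the helper name
# `_next_hpc_ckpt_number` is hypothetical and not part of Lightning.
import re

import fsspec


def _next_hpc_ckpt_number(folderpath: str) -> int:
    fs = fsspec.filesystem("file")
    suffixes = []
    for entry in fs.ls(folderpath):
        match = re.search(r"hpc_ckpt_(\d+)\.ckpt$", entry)
        if match:
            suffixes.append(int(match.group(1)))
    # mirrors `(max_suffix if max_suffix is not None else 0) + 1`
    return (max(suffixes) if suffixes else 0) + 1
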
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
    # save the config in `setup` because (1) we want it to save regardless of the trainer function run
    # and (2) we want to save before processes are spawned
    log_dir = trainer.log_dir  # this broadcasts the directory
    assert log_dir is not None
    config_path = os.path.join(log_dir, self.config_filename)
    fs = get_filesystem(log_dir)

    if not self.overwrite:
        # check if the file exists on rank 0
        file_exists = fs.isfile(config_path) if trainer.is_global_zero else False
        # broadcast whether to fail to all ranks
        file_exists = trainer.strategy.broadcast(file_exists)
        if file_exists:
            raise RuntimeError(
                f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
                " results of a previous run. You can delete the previous config file,"
                " set `LightningCLI(save_config_callback=None)` to disable config saving,"
                " or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file."
            )

    # save the file on rank 0
    if trainer.is_global_zero:
        # save only on rank zero to avoid race conditions on DDP.
        # the `log_dir` needs to be created as we rely on the logger to do it usually
        # but it hasn't logged anything at this point
        fs.makedirs(log_dir, exist_ok=True)
        self.parser.save(
            self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
        )

def to_disk(self) -> None:
    """Write predictions to file(s)."""
    for filepath, predictions in self.predictions.items():
        fs = get_filesystem(filepath)
        # normalize local filepaths only
        if fs.protocol == "file":
            filepath = os.path.realpath(filepath)
        if self.world_size > 1:
            stem, extension = os.path.splitext(filepath)
            filepath = f"{stem}_rank_{self.global_rank}{extension}"
        dirpath = os.path.split(filepath)[0]
        fs.mkdirs(dirpath, exist_ok=True)

        # Convert any tensor values to list
        predictions = {k: v if not isinstance(v, Tensor) else v.tolist() for k, v in predictions.items()}

        # Check if all features for this file add up to same length
        feature_lens = {k: len(v) for k, v in predictions.items()}
        if len(set(feature_lens.values())) != 1:
            raise ValueError("Mismatching feature column lengths found in stored EvalResult predictions.")

        # Switch predictions so each entry has its own dict
        outputs = []
        for values in zip(*predictions.values()):
            output_element = dict(zip(predictions.keys(), values))
            outputs.append(output_element)

        # Write predictions for current file to disk
        with fs.open(filepath, "wb") as fp:
            torch.save(outputs, fp)

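# A minimal, standalone sketch of the two transformations above: the per-rank filename suffix
# and the pivot from column-oriented predictions to one dict per sample. The `global_rank`
# value and the predictions dict are illustrative only.
import os

predictions = {"id": [0, 1], "score": [0.9, 0.1]}
global_rank = 1

stem, extension = os.path.splitext("preds/output.pt")
per_rank_path = f"{stem}_rank_{global_rank}{extension}"  # "preds/output_rank_1.pt"

# pivot column-oriented predictions into one dict per sample
outputs = [dict(zip(predictions.keys(), values)) for values in zip(*predictions.values())]
# outputs == [{"id": 0, "score": 0.9}, {"id": 1, "score": 0.1}]
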
def __init__(
    self,
    save_dir: str,
    name: Optional[str] = "default",
    version: Optional[Union[int, str]] = None,
    log_graph: bool = False,
    default_hp_metric: bool = True,
    prefix: str = "",
    sub_dir: Optional[str] = None,
    **kwargs,
):
    super().__init__()
    self._save_dir = save_dir
    self._name = name or ""
    self._version = version
    self._sub_dir = sub_dir
    self._log_graph = log_graph
    self._default_hp_metric = default_hp_metric
    self._prefix = prefix
    self._fs = get_filesystem(save_dir)

    self._experiment = None
    self.hparams = {}
    self._kwargs = kwargs

def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
    """
    Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
    All restored states are listed in the return value description of `dump_checkpoint`.
    """
    # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
    fs = get_filesystem(checkpoint_path)
    if not fs.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint at {checkpoint_path} not found. Aborting training.")

    checkpoint, load_optimizer_states = self.trainer.training_type_plugin.restore_model_state_from_ckpt_path(
        checkpoint_path, map_location=lambda storage, loc: storage
    )

    model = self.trainer.lightning_module

    if on_gpu:
        model.cuda(self.trainer.root_gpu)

    # restore training state
    self.restore_training_state(checkpoint, load_optimizer_states)

    rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
    return True

def test_torchscript_save_load_custom_filesystem(tmpdir, modelclass):
    """Test that scripted LightningModule is correctly saved and can be loaded with custom filesystems."""
    _DUMMY_PREFIX = "dummy"
    _PREFIX_SEPARATOR = "://"

    class DummyFileSystem(LocalFileSystem):
        ...

    fsspec.register_implementation(_DUMMY_PREFIX, DummyFileSystem, clobber=True)

    model = modelclass()
    # build a `dummy://<tmpdir>/model.pt` URL so the registered protocol is actually exercised
    output_file = os.path.join(_DUMMY_PREFIX + _PREFIX_SEPARATOR + str(tmpdir), "model.pt")
    script = model.to_torchscript(file_path=output_file)

    fs = get_filesystem(output_file)
    with fs.open(output_file, "rb") as f:
        loaded_script = torch.jit.load(f)

    assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))

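# A self-contained sketch of the fsspec registration pattern used in the test above, run
# against the built-in in-memory filesystem instead of a LightningModule. The "mem_demo"
# protocol name is made up for illustration.
import fsspec
from fsspec.implementations.memory import MemoryFileSystem

fsspec.register_implementation("mem_demo", MemoryFileSystem, clobber=True)

fs = fsspec.filesystem("mem_demo")
assert isinstance(fs, MemoryFileSystem)

with fs.open("/example.txt", "wb") as f:
    f.write(b"hello")
with fs.open("/example.txt", "rb") as f:
    assert f.read() == b"hello"
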
def _split_train_test_data(data: Dict, multi_label: bool = False) -> List[Dict]:
    file_path = data.get("export_json", None)

    if not file_path:
        raise MisconfigurationException("The key `export_json` should be provided as a string.")

    fs = get_filesystem(file_path)
    with fs.open(file_path) as f:
        raw_data = np.asarray(json.load(f))

    train_raw_data = []
    test_raw_data = []
    for task in raw_data:
        for annotation in task["annotations"]:
            if annotation["ground_truth"]:
                test_raw_data.append(task)
            elif not annotation["ground_truth"]:
                train_raw_data.append(task)
            break

    assert len(raw_data) == len(train_raw_data) + len(test_raw_data)
    dirname = os.path.dirname(file_path)
    basename = os.path.basename(file_path)

    results = []
    for stage, raw_data in [("train", train_raw_data), ("test", test_raw_data)]:
        filename = basename if stage in basename else f"{stage}_{basename}"
        export_path = os.path.join(dirname, filename)
        LabelStudioInput._export_data_to_json(export_path, raw_data)
        output_data = deepcopy(data)
        output_data["export_json"] = export_path
        results.append(output_data)
    return results

def test_get_filesystem_local_filesystem():
    assert isinstance(get_filesystem("tmpdir/tmp_file"), LocalFileSystem)

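# A hedged companion check (not part of the original test suite): assuming `get_filesystem`
# dispatches on the URL protocol via fsspec, a "memory://" path should resolve to the
# built-in in-memory filesystem.
from fsspec.implementations.memory import MemoryFileSystem
from pytorch_lightning.utilities.cloud_io import get_filesystem  # import path may differ across versions


def test_get_filesystem_memory_filesystem():
    assert isinstance(get_filesystem("memory://tmp_file"), MemoryFileSystem)
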
def scale_batch_size(
    trainer: 'pl.Trainer',
    model: 'pl.LightningModule',
    mode: str = 'power',
    steps_per_trial: int = 3,
    init_val: int = 2,
    max_trials: int = 25,
    batch_arg_name: str = 'batch_size',
) -> Optional[int]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`"""
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping batch size scaler since fast_dev_run is enabled.', UserWarning)
        return

    if not lightning_hasattr(model, batch_arg_name):
        raise MisconfigurationException(f'Field {batch_arg_name} not found in both `model` and `model.hparams`')

    if hasattr(model, batch_arg_name) and hasattr(model, "hparams") and batch_arg_name in model.hparams:
        rank_zero_warn(
            f'Field `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!'
            f' `model.{batch_arg_name}` will be used as the initial batch size for scaling.'
            f' If this is not the intended behavior, please remove either one.'
        )

    if hasattr(model.train_dataloader, 'patch_loader_code'):
        raise MisconfigurationException(
            'The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`.'
            ' Please disable the feature or incorporate the dataloader into the model.'
        )

    # Arguments we adjust during the batch size finder, save for restoring
    __scale_batch_dump_params(trainer)

    # Set to values that are required by the algorithm
    __scale_batch_reset_params(trainer, model, steps_per_trial)

    # Save initial model, that is loaded after batch size is found
    save_path = os.path.join(trainer.default_root_dir, 'scale_batch_size_temp_model.ckpt')
    trainer.save_checkpoint(str(save_path))

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Initially we just double in size until an OOM is encountered
    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
    if mode == 'power':
        new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials)
    elif mode == 'binsearch':
        new_size = _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials)
    else:
        raise ValueError('mode in method `scale_batch_size` could either be `power` or `binsearch`')

    garbage_collection_cuda()
    log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')

    # Restore initial state of model
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path))
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __scale_batch_restore_params(trainer)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    return new_size

def lr_find(
    trainer,
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = 'exponential',
    early_stop_threshold: float = 4.0,
    datamodule: Optional[LightningDataModule] = None,
    update_attr: bool = False,
):
    r"""
    ``lr_find`` enables the user to do a range test of good initial learning rates,
    to reduce the amount of guesswork in picking a good starting learning rate.

    Args:
        model: Model to do range testing for

        train_dataloader: A PyTorch ``DataLoader`` with training samples. If the model has
            a predefined train_dataloader method, this will be skipped.

        min_lr: minimum learning rate to investigate

        max_lr: maximum learning rate to investigate

        num_training: number of learning rates to test

        mode: Search strategy to update learning rate after each batch:

            - ``'exponential'`` (default): Will increase the learning rate exponentially.
            - ``'linear'``: Will increase the learning rate linearly.

        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.

        datamodule: An optional ``LightningDataModule`` which holds the training
            and validation dataloader(s). Note that the ``train_dataloader`` and
            ``val_dataloaders`` parameters cannot be used at the same time as
            this parameter, or a ``MisconfigurationException`` will be raised.

        update_attr: Whether to update the learning rate attribute or not.

    Raises:
        MisconfigurationException:
            If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``,
            or if you are using more than one optimizer with learning rate finder.

    Example::

        # Setup model and trainer
        model = MyModelClass(hparams)
        trainer = pl.Trainer()

        # Run lr finder
        lr_finder = trainer.tuner.lr_find(model, ...)

        # Inspect results
        fig = lr_finder.plot(); fig.show()
        suggested_lr = lr_finder.suggestion()

        # Overwrite lr and create new model
        hparams.lr = suggested_lr
        model = MyModelClass(hparams)

        # Ready to train with new learning rate
        trainer.fit(model)
    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
        return

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule)

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info('LR finder stopped early due to diverging loss.')

    # Transfer results from callback to lr finder object
    lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

    # Reset model state
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f'Learning rate set to {lr}')

    return lr_finder

def __init__(
    self,
    filepath: Optional[str] = None,
    monitor: Optional[str] = "checkpoint_on",
    verbose: bool = False,
    save_last: bool = False,
    save_top_k: int = 1,
    save_weights_only: bool = False,
    mode: str = "auto",
    period: int = 1,
    prefix: str = "",
):
    super().__init__()
    self._fs = get_filesystem(filepath if filepath is not None else "")
    if (
        save_top_k > 0
        and filepath is not None
        and self._fs.isdir(filepath)
        and len(self._fs.ls(filepath)) > 0
    ):
        rank_zero_warn(
            f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
            " All files in this directory will be deleted when a checkpoint is saved!"
        )
    self.monitor = monitor
    self.verbose = verbose

    if not filepath:  # will be determined by trainer at runtime
        self.dirpath, self.filename = None, None
    else:
        if self._fs.isdir(filepath):
            self.dirpath, self.filename = filepath, None
        else:
            if self._fs.protocol == "file":  # don't normalize remote paths
                filepath = os.path.realpath(filepath)
            self.dirpath, self.filename = os.path.split(filepath)
        self._fs.makedirs(self.dirpath, exist_ok=True)

    self.save_last = save_last
    self.save_top_k = save_top_k
    self.save_weights_only = save_weights_only
    self.period = period
    self.epoch_last_check = None
    self.prefix = prefix
    self.best_k_models = {}  # {filename: monitor}
    self.kth_best_model_path = ""
    self.best_model_score = 0
    self.best_model_path = ""
    self.last_model_path = ""
    self.save_function = None
    self.warned_result_obj = False

    torch_inf = torch.tensor(np.Inf)
    mode_dict = {
        "min": (torch_inf, "min"),
        "max": (-torch_inf, "max"),
        "auto": (-torch_inf, "max")
        if monitor is not None and ("acc" in monitor or monitor.startswith("fmeasure"))
        else (torch_inf, "min"),
    }

    if mode not in mode_dict:
        rank_zero_warn(
            f"ModelCheckpoint mode {mode} is unknown, fallback to auto mode.",
            RuntimeWarning,
        )
        mode = "auto"

    self.kth_value, self.mode = mode_dict[mode]

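# A standalone sketch of the "auto" mode resolution used above: monitors containing "acc"
# or starting with "fmeasure" are maximized, everything else is minimized. `resolve_mode`
# is a hypothetical helper, not part of ModelCheckpoint.
import numpy as np
import torch


def resolve_mode(monitor, mode="auto"):
    torch_inf = torch.tensor(np.inf)
    mode_dict = {
        "min": (torch_inf, "min"),
        "max": (-torch_inf, "max"),
        "auto": (-torch_inf, "max")
        if monitor is not None and ("acc" in monitor or monitor.startswith("fmeasure"))
        else (torch_inf, "min"),
    }
    # unknown modes fall back to "auto", mirroring the warning branch above
    return mode_dict.get(mode, mode_dict["auto"])


assert resolve_mode("val_acc")[1] == "max"
assert resolve_mode("val_loss")[1] == "min"
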
def scale_batch_size(trainer,
                     model: LightningModule,
                     mode: str = 'power',
                     steps_per_trial: int = 3,
                     init_val: int = 2,
                     max_trials: int = 25,
                     batch_arg_name: str = 'batch_size',
                     **fit_kwargs):
    r"""
    Will iteratively try to find the largest batch size for a given model
    that does not give an out of memory (OOM) error.

    Args:
        trainer: The Trainer

        model: Model to fit.

        mode: string setting the search mode. Either `power` or `binsearch`.
            If mode is `power` we keep multiplying the batch size by 2, until
            we get an OOM error. If mode is `binsearch`, we will initially
            also keep multiplying by 2 and after encountering an OOM error
            do a binary search between the last successful batch size and the
            batch size that failed.

        steps_per_trial: number of steps to run with a given batch size.
            Ideally 1 should be enough to test if an OOM error occurs,
            however in practice a few are needed

        init_val: initial batch size to start the search with

        max_trials: max number of increases in batch size done before
            algorithm is terminated

        batch_arg_name: name of the attribute that stores the batch size.
            It is expected that the user has provided a model or datamodule that has a hyperparameter
            with that name. We will look for this attribute name in the following places

            - `model`
            - `model.hparams`
            - `model.datamodule`
            - `trainer.datamodule` (the datamodule passed to the tune method)

        **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
            or datamodule.
    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning)
        return

    if not lightning_hasattr(model, batch_arg_name):
        raise MisconfigurationException(f'Field {batch_arg_name} not found in both `model` and `model.hparams`')

    if hasattr(model, batch_arg_name) and hasattr(model, "hparams") and batch_arg_name in model.hparams:
        rank_zero_warn(
            f'Field `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!'
            f' `model.{batch_arg_name}` will be used as the initial batch size for scaling.'
            f' If this is not the intended behavior, please remove either one.'
        )

    if hasattr(model.train_dataloader, 'patch_loader_code'):
        raise MisconfigurationException(
            'The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`.'
            ' Please disable the feature or incorporate the dataloader into the model.'
        )

    # Arguments we adjust during the batch size finder, save for restoring
    __scale_batch_dump_params(trainer)

    # Set to values that are required by the algorithm
    __scale_batch_reset_params(trainer, model, steps_per_trial)

    # Save initial model, that is loaded after batch size is found
    save_path = os.path.join(trainer.default_root_dir, 'scale_batch_size_temp_model.ckpt')
    trainer.save_checkpoint(str(save_path))

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Initially we just double in size until an OOM is encountered
    new_size = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
    if mode == 'power':
        new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
    elif mode == 'binsearch':
        new_size = _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
    else:
        raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch`')

    garbage_collection_cuda()
    log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')

    # Restore initial state of model
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __scale_batch_restore_params(trainer)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    return new_size

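# A simplified, standalone sketch of the `power` search loop described in the docstring above
# (not Lightning's actual `_run_power_scaling`): double the batch size until the trial raises a
# CUDA out-of-memory error, then keep the last size that fit. `run_trial` is a hypothetical
# callable standing in for a short `.fit()` trial at a given batch size.
def power_scale(run_trial, init_val: int = 2, max_trials: int = 25) -> int:
    new_size = init_val
    last_successful = init_val
    for _ in range(max_trials):
        try:
            run_trial(new_size)  # e.g. a few training steps at this batch size
            last_successful = new_size
            new_size *= 2  # keep doubling while it fits in memory
        except RuntimeError as err:
            if "out of memory" in str(err).lower():
                break  # OOM: stop and keep the last working size
            raise
    return last_successful
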
def _export_data_to_json(export_path: str, raw_data: List[Dict]) -> None:
    fs = get_filesystem(export_path)
    if fs.exists(export_path):
        fs.delete(export_path)
    with fs.open(export_path, mode="w") as f:
        json.dump(raw_data, f)

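# A quick round-trip sketch of the helper above, run against fsspec's built-in in-memory
# filesystem so it needs no real files; the "memory://" path and data are illustrative only.
import json

import fsspec

fs = fsspec.filesystem("memory")
export_path = "memory://train_export.json"
raw_data = [{"id": 0, "annotations": []}]

if fs.exists(export_path):
    fs.delete(export_path)
with fs.open(export_path, mode="w") as f:
    json.dump(raw_data, f)
with fs.open(export_path, mode="r") as f:
    assert json.load(f) == raw_data
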
def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
             save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
             mode: str = 'auto', period: int = 1, prefix: str = ''):
    super().__init__()
    if filepath:
        self._fs = get_filesystem(filepath)
    else:
        self._fs = get_filesystem("")  # will give local filesystem
    if save_top_k > 0 and filepath is not None and self._fs.isdir(filepath) and len(self._fs.ls(filepath)) > 0:
        rank_zero_warn(
            f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
            " All files in this directory will be deleted when a checkpoint is saved!"
        )
    self._rank = 0

    self.monitor = monitor
    self.verbose = verbose
    if filepath is None:  # will be determined by trainer at runtime
        self.dirpath, self.filename = None, None
    else:
        if self._fs.isdir(filepath):
            self.dirpath, self.filename = filepath, "{epoch}"
        else:
            if self._fs.protocol == "file":  # don't normalize remote paths
                filepath = os.path.realpath(filepath)
            self.dirpath, self.filename = os.path.split(filepath)
        self._fs.makedirs(self.dirpath, exist_ok=True)
    self.save_last = save_last
    self.save_top_k = save_top_k
    self.save_weights_only = save_weights_only
    self.period = period
    self.epoch_last_check = None
    self.prefix = prefix
    self.best_k_models = {}  # {filename: monitor}
    self.kth_best_model_path = ''
    self.best_model_score = 0
    self.best_model_path = ''
    self.save_function = None
    self.warned_result_obj = False

    torch_inf = torch.tensor(np.Inf)
    mode_dict = {
        'min': (torch_inf, 'min'),
        'max': (-torch_inf, 'max'),
        'auto': (-torch_inf, 'max')
        if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
        else (torch_inf, 'min'),
    }

    if mode not in mode_dict:
        rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, fallback to auto mode.', RuntimeWarning)
        mode = 'auto'

    self.kth_value, self.mode = mode_dict[mode]

def lr_find(
    trainer: 'pl.Trainer',
    model: 'pl.LightningModule',
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = 'exponential',
    early_stop_threshold: float = 4.0,
    update_attr: bool = False,
) -> Optional[_LRFinder]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
        return

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.tuner._run(model)

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info(f'LR finder stopped early after {trainer.global_step} steps due to diverging loss.')

    # Transfer results from callback to lr finder object
    lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

    # Reset model state
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f'Learning rate set to {lr}')

    return lr_finder