def lr_find(
    trainer,
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = 'exponential',
    early_stop_threshold: float = 4.0,
    datamodule: Optional[LightningDataModule] = None,
    update_attr: bool = False,
):
    r"""
    ``lr_find`` enables the user to do a range test of good initial learning rates,
    to reduce the amount of guesswork in picking a good starting learning rate.

    Args:
        model: Model to do range testing for

        train_dataloader: A PyTorch ``DataLoader`` with training samples. If the model has
            a predefined train_dataloader method, this will be skipped.

        val_dataloaders: A PyTorch ``DataLoader`` or a list of them with validation samples.

        min_lr: minimum learning rate to investigate

        max_lr: maximum learning rate to investigate

        num_training: number of learning rates to test

        mode: Search strategy to update learning rate after each batch:

            - ``'exponential'`` (default): Will increase the learning rate exponentially.
            - ``'linear'``: Will increase the learning rate linearly.

        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than ``early_stop_threshold * best_loss``
            then the search is stopped. To disable, set to ``None``.

        datamodule: An optional ``LightningDataModule`` which holds the training
            and validation dataloader(s). Note that the ``train_dataloader`` and
            ``val_dataloaders`` parameters cannot be used at the same time as
            this parameter, or a ``MisconfigurationException`` will be raised.

        update_attr: Whether to update the learning rate attribute or not.

    Raises:
        MisconfigurationException:
            If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden
            when ``auto_lr_find=True``, or if you are using more than one optimizer
            with the learning rate finder.

    Example::

        # Setup model and trainer
        model = MyModelClass(hparams)
        trainer = pl.Trainer()

        # Run lr finder
        lr_finder = trainer.tuner.lr_find(model, ...)

        # Inspect results
        fig = lr_finder.plot(); fig.show()
        suggested_lr = lr_finder.suggestion()

        # Overwrite lr and create new model
        hparams.lr = suggested_lr
        model = MyModelClass(hparams)

        # Ready to train with new learning rate
        trainer.fit(model)
    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
        return

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.fit(
        model,
        train_dataloader=train_dataloader,
        val_dataloaders=val_dataloaders,
        datamodule=datamodule,
    )

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info('LR finder stopped early due to diverging loss.')

    # Transfer results from callback to lr finder object
    lr_finder.results.update({'lr': trainer.callbacks[0].lrs, 'loss': trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

    # Reset model state
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f'Learning rate set to {lr}')

    return lr_finder
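# For illustration, the range test wired up above boils down to the following
# standalone sweep: anneal the lr from ``min_lr`` to ``max_lr`` over
# ``num_training`` batches (log-spaced for 'exponential', evenly spaced for
# 'linear'), record the loss per step, and stop once it diverges. This is a
# minimal sketch assuming a plain ``torch`` model/criterion/optimizer/loader;
# ``lr_range_test`` and its arguments are hypothetical names, not the internal
# ``_LRFinder``/``_LRCallback`` implementation.
def lr_range_test(model, criterion, optimizer, loader,
                  min_lr=1e-8, max_lr=1.0, num_training=100,
                  mode='exponential', early_stop_threshold=4.0):
    lrs, losses = [], []
    best_loss = float('inf')
    for step, (x, y) in zip(range(num_training), loader):
        # Interpolate this step's lr between min_lr and max_lr
        t = step / max(1, num_training - 1)
        if mode == 'exponential':
            lr = min_lr * (max_lr / min_lr) ** t
        else:  # 'linear'
            lr = min_lr + (max_lr - min_lr) * t
        for group in optimizer.param_groups:
            group['lr'] = lr

        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

        lrs.append(lr)
        losses.append(loss.item())
        best_loss = min(best_loss, losses[-1])
        # Same early-stop rule as documented above: abort once the loss
        # exceeds early_stop_threshold * best_loss
        if early_stop_threshold is not None and losses[-1] > early_stop_threshold * best_loss:
            break
    return lrs, losses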
def lr_find(
    trainer: "pl.Trainer",
    model: "pl.LightningModule",
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = "exponential",
    early_stop_threshold: float = 4.0,
    update_attr: bool = False,
) -> Optional[_LRFinder]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
    if trainer.fast_dev_run:
        rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.")
        return None

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    # Save the initial model, to be restored after the learning rate is found
    ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
    trainer.save_checkpoint(ckpt_path)

    params = __lr_finder_dump_params(trainer)

    # Set to values that are required by the algorithm
    __lr_finder_reset_params(trainer, num_training, early_stop_threshold)

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Configure optimizer and scheduler
    trainer.strategy.setup_optimizers = lr_finder._exchange_scheduler(trainer, model)  # type: ignore[assignment]

    # Fit, lr & loss logged in callback
    trainer.tuner._run(model)

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")

    # Transfer results from callback to lr finder object
    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose

    __lr_finder_restore_params(trainer, params)

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f"Learning rate set to {lr}")

    # Restore initial state of model
    trainer._checkpoint_connector.restore(ckpt_path)
    trainer.strategy.remove_checkpoint(ckpt_path)

    return lr_finder
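# End-to-end usage of the version above through the trainer's tuner. A sketch
# assuming a pytorch_lightning 1.x release where ``Trainer`` exposes ``.tuner``
# (the exact entry point varies by version); ``TinyModel`` is a made-up module
# for the example, and ``lr_finder.plot()`` additionally requires matplotlib.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class TinyModel(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3):
        super().__init__()
        self.learning_rate = learning_rate  # attribute that update_attr=True would overwrite
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

train_loader = DataLoader(TensorDataset(torch.randn(256, 32), torch.randn(256, 2)), batch_size=32)

model = TinyModel()
trainer = pl.Trainer(max_epochs=1)

# Run the range test, apply its suggestion, then do the real fit
lr_finder = trainer.tuner.lr_find(model, train_dataloaders=train_loader)
model.learning_rate = lr_finder.suggestion()
trainer.fit(model, train_loader)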
def lr_find(
    trainer: "pl.Trainer",
    model: "pl.LightningModule",
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = "exponential",
    early_stop_threshold: float = 4.0,
    update_attr: bool = False,
) -> Optional[_LRFinder]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
    if trainer.fast_dev_run:
        rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.", UserWarning)
        return None

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    save_path = os.path.join(trainer.default_root_dir, "lr_find_temp_model.ckpt")

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.fit_loop.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.lr_schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.tuner._run(model)

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")

    # Transfer results from callback to lr finder object
    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose

    # Reset model state
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path))
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f"Learning rate set to {lr}")

    return lr_finder
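# The ``lr_finder.suggestion()`` call used by ``update_attr`` above picks the
# point of steepest descent on the recorded loss curve. A sketch of that
# heuristic, approximating (not reproducing) ``_LRFinder.suggestion``;
# ``suggest_lr``, ``skip_begin`` and ``skip_end`` are illustrative names, the
# latter two trimming the noisy ends of the curve.
import numpy as np

def suggest_lr(lrs, losses, skip_begin: int = 10, skip_end: int = 1):
    loss = np.asarray(losses[skip_begin:-skip_end])
    lr = np.asarray(lrs[skip_begin:-skip_end])
    # Index with the most negative finite-difference gradient of the loss
    idx = int(np.gradient(loss).argmin())
    return float(lr[idx])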