def _tune(
    self,
    model: 'pl.LightningModule',
    scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
    lr_find_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Optional[Union[int, _LRFinder]]]:
    """Run every tuning procedure that is enabled on the trainer.

    Args:
        model: Module handed to each tuning routine.
        scale_batch_size_kwargs: Extra keyword arguments forwarded to the
            module-level ``scale_batch_size``. The mapping is copied and is
            never modified.
        lr_find_kwargs: Extra keyword arguments forwarded to ``lr_find``.
            The mapping is copied and is never modified.

    Returns:
        A dict with a ``'scale_batch_size'`` entry and/or an ``'lr_find'``
        entry, one per procedure that actually ran. It is empty when neither
        ``trainer.auto_scale_batch_size`` nor ``trainer.auto_lr_find`` is set.
    """
    # Work on private copies: the ``setdefault`` calls below would otherwise
    # write the injected defaults back into dicts owned by the caller.
    scale_batch_size_kwargs = dict(scale_batch_size_kwargs or {})
    lr_find_kwargs = dict(lr_find_kwargs or {})

    # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
    result = {}

    # Run auto batch size scaling
    if self.trainer.auto_scale_batch_size:
        # A string-valued flag doubles as the search mode, unless the caller
        # already supplied an explicit "mode".
        if isinstance(self.trainer.auto_scale_batch_size, str):
            scale_batch_size_kwargs.setdefault("mode", self.trainer.auto_scale_batch_size)
        result['scale_batch_size'] = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs)

    # Run learning rate finder:
    if self.trainer.auto_lr_find:
        # ``update_attr`` defaults to True on this path; an explicit value wins.
        lr_find_kwargs.setdefault('update_attr', True)
        result['lr_find'] = lr_find(self.trainer, model, **lr_find_kwargs)

    # Reached whether or not any procedure ran.
    self.trainer.state.status = TrainerStatus.FINISHED

    return result
def _tune(
    self,
    model: "pl.LightningModule",
    scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
    lr_find_kwargs: Optional[Dict[str, Any]] = None,
) -> _TunerResult:
    """Run every tuning procedure that is enabled on the trainer.

    Args:
        model: Module connected to the trainer's strategy and then handed to
            each tuning routine.
        scale_batch_size_kwargs: Extra keyword arguments forwarded to the
            module-level ``scale_batch_size``. The mapping is copied and is
            never modified.
        lr_find_kwargs: Extra keyword arguments forwarded to ``lr_find``.
            The mapping is copied and is never modified.

    Returns:
        A ``_TunerResult`` with a ``"scale_batch_size"`` entry and/or an
        ``"lr_find"`` entry, one per procedure that actually ran.

    Raises:
        MisconfigurationException: If either tuning flag is set while the
            trainer's accelerator connector reports a distributed setup.
    """
    # Work on private copies: the ``setdefault`` calls below would otherwise
    # write the injected defaults back into dicts owned by the caller.
    scale_batch_size_kwargs = dict(scale_batch_size_kwargs or {})
    lr_find_kwargs = dict(lr_find_kwargs or {})

    # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
    result = _TunerResult()

    # The model is connected to the strategy unconditionally, i.e. even when
    # no tuning procedure is enabled.
    self.trainer.strategy.connect(model)

    is_tuning = self.trainer.auto_scale_batch_size or self.trainer.auto_lr_find
    if self.trainer._accelerator_connector.is_distributed and is_tuning:
        raise MisconfigurationException(
            "`trainer.tune()` is currently not supported with"
            f" `Trainer(strategy={self.trainer.strategy.strategy_name!r})`."
        )

    # Run auto batch size scaling
    if self.trainer.auto_scale_batch_size:
        # A string-valued flag doubles as the search mode, unless the caller
        # already supplied an explicit "mode".
        if isinstance(self.trainer.auto_scale_batch_size, str):
            scale_batch_size_kwargs.setdefault("mode", self.trainer.auto_scale_batch_size)
        result["scale_batch_size"] = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs)

    # Run learning rate finder:
    if self.trainer.auto_lr_find:
        # ``update_attr`` defaults to True on this path; an explicit value wins.
        lr_find_kwargs.setdefault("update_attr", True)
        result["lr_find"] = lr_find(self.trainer, model, **lr_find_kwargs)

    # Reached only when no exception was raised above.
    self.trainer.state.status = TrainerStatus.FINISHED

    return result
def scale_batch_size(
    self,
    model,
    mode: str = 'power',
    steps_per_trial: int = 3,
    init_val: int = 2,
    max_trials: int = 25,
    batch_arg_name: str = 'batch_size',
    **fit_kwargs
):
    r"""
    Iteratively search for the largest batch size the given model can be
    run with before an out of memory (OOM) error occurs.

    Args:
        model: Model to fit.

        mode: Search strategy, either ``'power'`` or ``'binsearch'``.
            With ``'power'`` the batch size is doubled until an OOM error
            is hit. With ``'binsearch'`` the batch size is likewise doubled
            at first; once an OOM error occurs, a binary search is performed
            between the last batch size that succeeded and the one that
            failed.

        steps_per_trial: Number of steps to run for each candidate batch
            size. Ideally 1 would be enough to detect an OOM error, but in
            practice a few are needed.

        init_val: Batch size the search starts from.

        max_trials: Maximum number of batch size increases before the
            algorithm terminates.

        batch_arg_name: Name of the attribute that stores the batch size.
            The user is expected to provide a model or datamodule that has
            a hyperparameter with this name. The attribute is looked up in
            the following places:

            - ``model``
            - ``model.hparams``
            - ``model.datamodule``
            - ``trainer.datamodule`` (the datamodule passed to the tune method)

        **fit_kwargs: Remaining arguments to be passed to .fit(), e.g.,
            dataloader or datamodule.

    Returns:
        Whatever the module-level ``scale_batch_size`` function returns.
    """
    # NOTE: assuming this def lives in a class body, the bare name
    # ``scale_batch_size`` below resolves to the module-level function of the
    # same name, not to this method, so this is delegation and not recursion.
    search_options = (mode, steps_per_trial, init_val, max_trials, batch_arg_name)
    return scale_batch_size(self.trainer, model, *search_options, **fit_kwargs)
def scale_batch_size(self,
                     model,
                     mode: str = 'power',
                     steps_per_trial: int = 3,
                     init_val: int = 2,
                     max_trials: int = 25,
                     batch_arg_name: str = 'batch_size',
                     **fit_kwargs):
    """Delegate a batch size search to the module-level ``scale_batch_size``.

    This method's trainer (``self.trainer``) is passed as the first argument,
    followed by ``model`` and the search settings in their declared order.

    Args:
        model: Model forwarded to the search routine.
        mode: Search mode forwarded unchanged (default ``'power'``).
        steps_per_trial: Forwarded unchanged (default 3).
        init_val: Forwarded unchanged (default 2).
        max_trials: Forwarded unchanged (default 25).
        batch_arg_name: Forwarded unchanged (default ``'batch_size'``).
        **fit_kwargs: Any further keyword arguments, forwarded as keywords.

    Returns:
        Whatever the module-level ``scale_batch_size`` function returns.
    """
    # NOTE: assuming this def lives in a class body, the bare name
    # ``scale_batch_size`` below resolves to the module-level function of the
    # same name, not to this method, so this is delegation and not recursion.
    settings = (mode, steps_per_trial, init_val, max_trials, batch_arg_name)
    outcome = scale_batch_size(self.trainer, model, *settings, **fit_kwargs)
    return outcome