def lr_find(
    self,
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = 'exponential',
    early_stop_threshold: float = 4.0,
    datamodule: Optional[LightningDataModule] = None,
    update_attr: bool = False,
):
    """Prepare the trainer for ``model`` and delegate to the ``lr_find`` callable in scope.

    ``self.setup_trainer`` is invoked first with the model, both dataloader
    arguments and the datamodule. Every argument of this method is then
    forwarded positionally, in signature order and preceded by
    ``self.trainer``, to ``lr_find``.

    NOTE(review): the inner ``lr_find`` is presumably the learning-rate finder
    imported at file level rather than this method - confirm against the
    file's imports.

    Returns:
        Whatever the delegated ``lr_find`` call returns.
    """
    self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)

    # The callee is invoked positionally, so the order of this tuple is part
    # of the contract and must not be rearranged.
    finder_args = (
        self.trainer,
        model,
        train_dataloader,
        val_dataloaders,
        min_lr,
        max_lr,
        num_training,
        mode,
        early_stop_threshold,
        datamodule,
        update_attr,
    )
    return lr_find(*finder_args)
def _tune(
    self,
    model: "pl.LightningModule",
    scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
    lr_find_kwargs: Optional[Dict[str, Any]] = None,
) -> _TunerResult:
    """Run the tuning procedures enabled on ``self.trainer`` and collect their results.

    The model is connected to the trainer's strategy first. If the accelerator
    connector reports a distributed setup while any tuning flag is enabled, a
    ``MisconfigurationException`` is raised. Otherwise batch-size scaling and
    the learning-rate finder are run, each only when its trainer flag is
    truthy, and the trainer status is set to ``TrainerStatus.FINISHED``.

    Args:
        model: Module handed to the strategy and to each tuning routine.
        scale_batch_size_kwargs: Extra keyword arguments for ``scale_batch_size``.
            When ``trainer.auto_scale_batch_size`` is a string it is used as
            the default for ``"mode"``. The caller's dict is never modified.
        lr_find_kwargs: Extra keyword arguments for ``lr_find``.
            ``"update_attr"`` defaults to ``True``. The caller's dict is never
            modified.

    Returns:
        A ``_TunerResult`` holding ``"scale_batch_size"`` and/or ``"lr_find"``
        entries for the procedures that actually ran.

    Raises:
        MisconfigurationException: If tuning is requested together with a
            distributed strategy.
    """
    # Work on shallow copies: the ``setdefault`` calls below must not leak
    # extra keys into dictionaries owned by the caller.
    scale_batch_size_kwargs = dict(scale_batch_size_kwargs or {})
    lr_find_kwargs = dict(lr_find_kwargs or {})

    # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
    result = _TunerResult()

    self.trainer.strategy.connect(model)

    is_tuning = self.trainer.auto_scale_batch_size or self.trainer.auto_lr_find
    if self.trainer._accelerator_connector.is_distributed and is_tuning:
        raise MisconfigurationException(
            "`trainer.tune()` is currently not supported with"
            f" `Trainer(strategy={self.trainer.strategy.strategy_name!r})`."
        )

    # Run auto batch size scaling
    if self.trainer.auto_scale_batch_size:
        # A string flag doubles as the scaling mode unless the caller set one.
        if isinstance(self.trainer.auto_scale_batch_size, str):
            scale_batch_size_kwargs.setdefault("mode", self.trainer.auto_scale_batch_size)
        result["scale_batch_size"] = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs)

    # Run learning rate finder:
    if self.trainer.auto_lr_find:
        lr_find_kwargs.setdefault("update_attr", True)
        result["lr_find"] = lr_find(self.trainer, model, **lr_find_kwargs)

    self.trainer.state.status = TrainerStatus.FINISHED

    return result
def _tune(
    self,
    model: 'pl.LightningModule',
    scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
    lr_find_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Optional[Union[int, _LRFinder]]]:
    """Run the tuning procedures enabled on ``self.trainer`` and collect their results.

    Batch-size scaling and the learning-rate finder are run, each only when
    its trainer flag is truthy. Afterwards the trainer status is set to
    ``TrainerStatus.FINISHED``.

    Args:
        model: Module handed to each tuning routine.
        scale_batch_size_kwargs: Extra keyword arguments for ``scale_batch_size``.
            When ``trainer.auto_scale_batch_size`` is a string it is used as
            the default for ``"mode"``. The caller's dict is never modified.
        lr_find_kwargs: Extra keyword arguments for ``lr_find``.
            ``"update_attr"`` defaults to ``True``. The caller's dict is never
            modified.

    Returns:
        A dict with a ``"scale_batch_size"`` and/or ``"lr_find"`` entry for
        the procedures that actually ran; empty when neither flag is set.
    """
    # Work on shallow copies: the ``setdefault`` calls below must not leak
    # extra keys into dictionaries owned by the caller.
    scale_batch_size_kwargs = dict(scale_batch_size_kwargs or {})
    lr_find_kwargs = dict(lr_find_kwargs or {})

    # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
    result = {}

    # Run auto batch size scaling
    if self.trainer.auto_scale_batch_size:
        # A string flag doubles as the scaling mode unless the caller set one.
        if isinstance(self.trainer.auto_scale_batch_size, str):
            scale_batch_size_kwargs.setdefault("mode", self.trainer.auto_scale_batch_size)
        result['scale_batch_size'] = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs)

    # Run learning rate finder:
    if self.trainer.auto_lr_find:
        lr_find_kwargs.setdefault('update_attr', True)
        result['lr_find'] = lr_find(self.trainer, model, **lr_find_kwargs)

    self.trainer.state.status = TrainerStatus.FINISHED

    return result
def lr_find(
    self,
    model: LightningModule,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = 'exponential',
    early_stop_threshold: float = 4.0,
):
    """Delegate to the ``lr_find`` callable in scope using this object's trainer.

    Every argument of this method is forwarded positionally, in signature
    order and preceded by ``self.trainer``, to ``lr_find``.

    NOTE(review): the inner ``lr_find`` is presumably the learning-rate finder
    imported at file level rather than this method - confirm against the
    file's imports.

    Returns:
        Whatever the delegated ``lr_find`` call returns.
    """
    # The callee is invoked positionally, so the order of this tuple is part
    # of the contract and must not be rearranged.
    finder_args = (
        self.trainer,
        model,
        train_dataloader,
        val_dataloaders,
        min_lr,
        max_lr,
        num_training,
        mode,
        early_stop_threshold,
    )
    return lr_find(*finder_args)