Code example #1
File: tuning.py  Project: lujun59/nlp_learning
    def lr_find(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        min_lr: float = 1e-8,
        max_lr: float = 1,
        num_training: int = 100,
        mode: str = 'exponential',
        early_stop_threshold: float = 4.0,
        datamodule: Optional[LightningDataModule] = None,
        update_attr: bool = False,
    ):
        # Attach the model, dataloaders and optional datamodule to the trainer,
        # then delegate to the functional lr_find.
        self.setup_trainer(model, train_dataloader, val_dataloaders, datamodule)
        return lr_find(
            self.trainer,
            model,
            train_dataloader,
            val_dataloaders,
            min_lr,
            max_lr,
            num_training,
            mode,
            early_stop_threshold,
            datamodule,
            update_attr,
        )
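
Code example #1 is a thin wrapper: it first attaches the model, the dataloaders and an optional datamodule to the trainer via setup_trainer, then delegates to the functional lr_find. Below is a minimal usage sketch, assuming an older 1.x PyTorch Lightning release in which this Tuner.lr_find is reachable as trainer.tuner.lr_find with the signature shown above; TinyRegressor and the random dataset are placeholders, not part of the project above.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyRegressor(pl.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr                          # picked up when update_attr=True is passed
        self.layer = torch.nn.Linear(16, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


train_loader = DataLoader(
    TensorDataset(torch.randn(256, 16), torch.randn(256, 1)), batch_size=32
)

trainer = pl.Trainer(max_epochs=1)
# Exponential sweep from min_lr to max_lr over num_training batches,
# mirroring the defaults of the wrapper above.
lr_finder = trainer.tuner.lr_find(
    TinyRegressor(),
    train_dataloader=train_loader,
    min_lr=1e-6,
    max_lr=1e-1,
    num_training=100,
)
print(lr_finder.suggestion())             # learning rate at the steepest loss descent
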
Code example #2
    def _tune(
        self,
        model: "pl.LightningModule",
        scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
        lr_find_kwargs: Optional[Dict[str, Any]] = None,
    ) -> _TunerResult:
        scale_batch_size_kwargs = scale_batch_size_kwargs or {}
        lr_find_kwargs = lr_find_kwargs or {}
        # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
        result = _TunerResult()

        self.trainer.strategy.connect(model)

        is_tuning = self.trainer.auto_scale_batch_size or self.trainer.auto_lr_find
        if self.trainer._accelerator_connector.is_distributed and is_tuning:
            raise MisconfigurationException(
                "`trainer.tune()` is currently not supported with"
                f" `Trainer(strategy={self.trainer.strategy.strategy_name!r})`."
            )

        # Run auto batch size scaling
        if self.trainer.auto_scale_batch_size:
            if isinstance(self.trainer.auto_scale_batch_size, str):
                scale_batch_size_kwargs.setdefault("mode", self.trainer.auto_scale_batch_size)
            result["scale_batch_size"] = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs)

        # Run learning rate finder:
        if self.trainer.auto_lr_find:
            lr_find_kwargs.setdefault("update_attr", True)
            result["lr_find"] = lr_find(self.trainer, model, **lr_find_kwargs)

        self.trainer.state.status = TrainerStatus.FINISHED

        return result
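
Code example #2 is a later revision of the tuner's internal entry point: trainer.tune() calls this _tune(), which runs batch-size scaling and/or the learning-rate finder depending on the auto_scale_batch_size and auto_lr_find flags on the Trainer, and collects the results in a dict-like _TunerResult. A hedged sketch of driving it through the public API, assuming a 1.x release with these Trainer flags and the tune(..., lr_find_kwargs=...) keyword; TuneDemoModel is a placeholder.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TuneDemoModel(pl.LightningModule):
    """Placeholder module exposing `batch_size` and `lr` so both tuners can update it."""

    def __init__(self, batch_size=32, lr=1e-3):
        super().__init__()
        self.batch_size = batch_size          # read and rewritten by scale_batch_size
        self.lr = lr                          # rewritten by lr_find (update_attr defaults to True above)
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(512, 32), torch.randn(512, 1))
        return DataLoader(data, batch_size=self.batch_size)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


# Both flags route through the _tune() above when trainer.tune() is called.
trainer = pl.Trainer(max_epochs=1, auto_scale_batch_size="binsearch", auto_lr_find=True)
result = trainer.tune(TuneDemoModel(), lr_find_kwargs={"min_lr": 1e-6, "max_lr": 1e-1})
print(result["scale_batch_size"])       # suggested batch size (int)
print(result["lr_find"].suggestion())   # suggested learning rate (float)
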
Code example #3
    def _tune(
        self,
        model: 'pl.LightningModule',
        scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
        lr_find_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Optional[Union[int, _LRFinder]]]:
        scale_batch_size_kwargs = scale_batch_size_kwargs or {}
        lr_find_kwargs = lr_find_kwargs or {}
        # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
        result = {}

        # Run auto batch size scaling
        if self.trainer.auto_scale_batch_size:
            if isinstance(self.trainer.auto_scale_batch_size, str):
                scale_batch_size_kwargs.setdefault("mode", self.trainer.auto_scale_batch_size)
            result['scale_batch_size'] = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs)

        # Run learning rate finder:
        if self.trainer.auto_lr_find:
            lr_find_kwargs.setdefault('update_attr', True)
            result['lr_find'] = lr_find(self.trainer, model, **lr_find_kwargs)

        self.trainer.state.status = TrainerStatus.FINISHED

        return result
Code example #4
    def lr_find(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        min_lr: float = 1e-8,
        max_lr: float = 1,
        num_training: int = 100,
        mode: str = 'exponential',
        early_stop_threshold: float = 4.0,
    ):
        # Delegate to the functional lr_find using the trainer owned by this tuner.
        return lr_find(self.trainer, model, train_dataloader, val_dataloaders,
                       min_lr, max_lr, num_training, mode,
                       early_stop_threshold)
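
Code example #4 is an earlier revision of the same wrapper: there is no datamodule argument and no update_attr flag, so the suggested learning rate has to be applied to the model by hand. A short sketch of that manual step, again assuming the wrapper is reachable as trainer.tuner.lr_find and using placeholder names; lr_finder.plot() additionally requires matplotlib.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyRegressor(pl.LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(16, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


model = TinyRegressor()
loader = DataLoader(TensorDataset(torch.randn(256, 16), torch.randn(256, 1)), batch_size=32)
trainer = pl.Trainer(max_epochs=1)

lr_finder = trainer.tuner.lr_find(model, loader)   # positional call matching the signature above
fig = lr_finder.plot(suggest=True)                 # loss-vs-lr curve with the suggested point marked
model.lr = lr_finder.suggestion()                  # no update_attr here, so apply the suggestion manually
trainer.fit(model, loader)
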