Example #1
    def _tune(
        self,
        model: 'pl.LightningModule',
        scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
        lr_find_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Optional[Union[int, _LRFinder]]]:
        scale_batch_size_kwargs = scale_batch_size_kwargs or {}
        lr_find_kwargs = lr_find_kwargs or {}
        # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
        result = {}

        # Run auto batch size scaling
        if self.trainer.auto_scale_batch_size:
            if isinstance(self.trainer.auto_scale_batch_size, str):
                scale_batch_size_kwargs.setdefault("mode", self.trainer.auto_scale_batch_size)
            result['scale_batch_size'] = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs)

        # Run learning rate finder:
        if self.trainer.auto_lr_find:
            lr_find_kwargs.setdefault('update_attr', True)
            result['lr_find'] = lr_find(self.trainer, model, **lr_find_kwargs)

        self.trainer.state.status = TrainerStatus.FINISHED

        return result
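A minimal usage sketch of how this internal hook is typically reached, assuming the public PyTorch Lightning 1.x `Trainer` flags `auto_scale_batch_size` and `auto_lr_find` that the checks above read; `MyLightningModule` is a hypothetical model class:

    import pytorch_lightning as pl

    # Hypothetical LightningModule; the tuner reads and overwrites its
    # `batch_size` and `learning_rate` attributes during the search.
    model = MyLightningModule(batch_size=2, learning_rate=1e-3)

    # These flags are exactly what `_tune` checks before running each procedure.
    trainer = pl.Trainer(auto_scale_batch_size="power", auto_lr_find=True)

    # `tune()` forwards to `_tune()` and returns its dict:
    # {"scale_batch_size": Optional[int], "lr_find": Optional[_LRFinder]}
    result = trainer.tune(model)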
Example #2
    def _tune(
        self,
        model: "pl.LightningModule",
        scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
        lr_find_kwargs: Optional[Dict[str, Any]] = None,
    ) -> _TunerResult:
        scale_batch_size_kwargs = scale_batch_size_kwargs or {}
        lr_find_kwargs = lr_find_kwargs or {}
        # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
        result = _TunerResult()

        self.trainer.strategy.connect(model)

        is_tuning = self.trainer.auto_scale_batch_size or self.trainer.auto_lr_find
        if self.trainer._accelerator_connector.is_distributed and is_tuning:
            raise MisconfigurationException(
                "`trainer.tune()` is currently not supported with"
                f" `Trainer(strategy={self.trainer.strategy.strategy_name!r})`."
            )

        # Run auto batch size scaling
        if self.trainer.auto_scale_batch_size:
            if isinstance(self.trainer.auto_scale_batch_size, str):
                scale_batch_size_kwargs.setdefault("mode", self.trainer.auto_scale_batch_size)
            result["scale_batch_size"] = scale_batch_size(self.trainer, model, **scale_batch_size_kwargs)

        # Run learning rate finder:
        if self.trainer.auto_lr_find:
            lr_find_kwargs.setdefault("update_attr", True)
            result["lr_find"] = lr_find(self.trainer, model, **lr_find_kwargs)

        self.trainer.state.status = TrainerStatus.FINISHED

        return result
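A short sketch of consuming the `_TunerResult` returned above, assuming `result` comes from `trainer.tune(model)`; `suggestion()` is the `_LRFinder` helper that picks a learning rate from the sweep (with the default `update_attr=True`, the model's learning rate attribute is already updated in place):

    result = trainer.tune(model)

    # Only populated when `auto_lr_find` was enabled on the Trainer.
    if result["lr_find"] is not None:
        print("suggested learning rate:", result["lr_find"].suggestion())

    # The largest batch size that ran without an OOM error.
    if result["scale_batch_size"] is not None:
        print("largest working batch size:", result["scale_batch_size"])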
    def scale_batch_size(
        self,
        model,
        mode: str = 'power',
        steps_per_trial: int = 3,
        init_val: int = 2,
        max_trials: int = 25,
        batch_arg_name: str = 'batch_size',
        **fit_kwargs
    ):
        r"""
        Will iteratively try to find the largest batch size for a given model
        that does not give an out of memory (OOM) error.

        Args:
            model: Model to fit.

            mode: string setting the search mode. Either `power` or `binsearch`.
                If mode is `power`, we keep multiplying the batch size by 2 until
                we get an OOM error. If mode is `binsearch`, we also initially
                keep multiplying by 2 and, after encountering an OOM error,
                do a binary search between the last successful batch size and the
                batch size that failed.

            steps_per_trial: number of steps to run with a given batch size.
                Ideally, 1 should be enough to test whether an OOM error occurs;
                in practice, a few are needed.

            init_val: initial batch size to start the search with

            max_trials: maximum number of batch size increases performed before
                the algorithm is terminated.

            batch_arg_name: name of the attribute that stores the batch size.
                It is expected that the user has provided a model or datamodule that has a hyperparameter
                with that name. We will look for this attribute name in the following places:

                - ``model``
                - ``model.hparams``
                - ``model.datamodule``
                - ``trainer.datamodule`` (the datamodule passed to the tune method)

            **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
                or datamodule.

        """
        return scale_batch_size(
            self.trainer,
            model,
            mode,
            steps_per_trial,
            init_val,
            max_trials,
            batch_arg_name,
            **fit_kwargs,
        )
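A usage sketch for the method above, assuming the tuner is reachable as `trainer.tuner` and that the hypothetical model exposes a `batch_size` attribute, one of the lookup locations listed for `batch_arg_name`:

    # Hypothetical module and datamodule; `batch_size` matches the default `batch_arg_name`.
    model = MyLightningModule(batch_size=2)
    datamodule = MyDataModule()

    new_size = trainer.tuner.scale_batch_size(
        model,
        mode="binsearch",    # double until OOM, then binary-search the boundary
        steps_per_trial=3,   # a few steps per candidate size, per the docstring
        max_trials=10,
        datamodule=datamodule,  # forwarded through **fit_kwargs
    )

    # The return value is the largest batch size found; the tuner also writes
    # it back to the `batch_size` attribute it searched over.
    print(new_size)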
    def scale_batch_size(
        self,
        model,
        mode: str = 'power',
        steps_per_trial: int = 3,
        init_val: int = 2,
        max_trials: int = 25,
        batch_arg_name: str = 'batch_size',
        **fit_kwargs,
    ):
        return scale_batch_size(
            self.trainer, model, mode, steps_per_trial, init_val, max_trials, batch_arg_name, **fit_kwargs
        )