def _run_power_scaling(trainer: 'pl.Trainer', model: 'pl.LightningModule',
                       new_size: int, batch_arg_name: str,
                       max_trials: int) -> int:
    """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered. """
    for _ in range(max_trials):
        garbage_collection_cuda()
        trainer.train_loop.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.tuner._run(model)
            # Double in size
            new_size, changed = _adjust_batch_size(trainer,
                                                   batch_arg_name,
                                                   factor=2.0,
                                                   desc='succeeded')
        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
                # If we fail in power mode, halve the size and return
                garbage_collection_cuda()
                new_size, _ = _adjust_batch_size(trainer,
                                                 batch_arg_name,
                                                 factor=0.5,
                                                 desc='failed')
                break
            else:
                raise  # some other error not memory related

        if not changed:
            break
    return new_size
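# The doubling strategy above can be illustrated without a Trainer at all. The
# sketch below is a minimal, self-contained simulation: `memory_limit` is a
# made-up stand-in for the point at which a real training attempt would raise a
# CUDA OOM error, and `simulate_power_scaling` is a hypothetical helper, not
# Lightning API.
def simulate_power_scaling(init_size: int = 2, max_trials: int = 25,
                           memory_limit: int = 1000) -> int:
    size = init_size
    for _ in range(max_trials):
        if size > memory_limit:  # this trial would "OOM"
            size //= 2           # mirror the failure branch: halve and stop
            break
        size *= 2                # trial succeeded: double and try again
    return size

print(simulate_power_scaling())  # 512 with the defaults above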
Example 2
def _run_binsearch_scaling(trainer: "pl.Trainer", model: "pl.LightningModule",
                           new_size: int, batch_arg_name: str,
                           max_trials: int) -> int:
    """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is
    encountered.

    Hereafter, the batch size is further refined using a binary search
    """
    low = 1
    high = None
    count = 0
    while True:
        garbage_collection_cuda()
        trainer.fit_loop.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.tuner._run(model)
            count += 1
            if count > max_trials:
                break
            # Double in size
            low = new_size
            if high:
                if high - low <= 1:
                    break
                midval = (high + low) // 2
                new_size, changed = _adjust_batch_size(trainer,
                                                       batch_arg_name,
                                                       value=midval,
                                                       desc="succeeded")
            else:
                new_size, changed = _adjust_batch_size(trainer,
                                                       batch_arg_name,
                                                       factor=2.0,
                                                       desc="succeeded")

            if changed:
                # Force the train dataloader to reset as the batch size has changed
                trainer.reset_train_dataloader(model)
                trainer.reset_val_dataloader(model)
            else:
                break

        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
                # On OOM, bisect between the last successful size and the one that failed
                garbage_collection_cuda()
                high = new_size
                midval = (high + low) // 2
                new_size, _ = _adjust_batch_size(trainer,
                                                 batch_arg_name,
                                                 value=midval,
                                                 desc="failed")
                if high - low <= 1:
                    break
            else:
                raise  # some other error not memory related

    return new_size
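# For comparison, the same kind of standalone sketch for the binsearch mode:
# double until the (simulated) limit is hit, then bisect between the last size
# that fit and the one that failed. Again, `memory_limit` is a hypothetical
# stand-in for a real OOM, not part of the Lightning API.
def simulate_binsearch_scaling(init_size: int = 2, max_trials: int = 25,
                               memory_limit: int = 1000) -> int:
    size, low, high, count = init_size, 1, None, 0
    while True:
        if size <= memory_limit:          # the trial "fits"
            count += 1
            if count > max_trials:
                break
            low = size
            if high is not None:
                if high - low <= 1:       # search interval exhausted
                    break
                size = (high + low) // 2  # refine upwards
            else:
                size *= 2                 # still in the doubling phase
        else:                             # the trial "OOMs"
            high = size
            size = (high + low) // 2      # bisect downwards
            if high - low <= 1:
                break
    return size

print(simulate_binsearch_scaling())  # converges to 1000 with the defaults above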
def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials,
                       **fit_kwargs):
    """ Batch scaling mode where the size is doubled at each iteration until an
        OOM error is encountered. """
    for _ in range(max_trials):
        garbage_collection_cuda()
        trainer.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.fit(model, **fit_kwargs)
            # Double in size
            new_size = _adjust_batch_size(trainer,
                                          batch_arg_name,
                                          factor=2.0,
                                          desc='succeeded')
        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
                # If we fail in power mode, halve the size and return
                garbage_collection_cuda()
                new_size = _adjust_batch_size(trainer,
                                              batch_arg_name,
                                              factor=0.5,
                                              desc='failed')
                break
            else:
                raise  # some other error not memory related
    return new_size
Example 4
    def finetune(self, dataset, validation_split: float = 0.15, epochs: int = 20, batch_size: int = None,
                optimal_batch_size: int = None, early_stopping: bool = True, trainer = None):
        self.batch_size = batch_size or 1

        if not torch.cuda.is_available():
            raise Exception("You need a cuda capable (Nvidia) GPU for finetuning")
        
        len_train = int(len(dataset) * (1 - validation_split))
        len_valid = len(dataset) - len_train
        dataset_train, dataset_valid = torch.utils.data.random_split(dataset, [len_train, len_valid])

        self.dataset_train = dataset_train
        self.dataset_valid = dataset_valid

        if batch_size is None:
            # Find batch size
            temp_trainer = pl.Trainer(auto_scale_batch_size="power", gpus=-1)
            print("Finding the optimal batch size...")
            temp_trainer.tune(self)

            # Ensure that memory gets cleared
            del self.trainer
            del temp_trainer
            garbage_collection_cuda()

        trainer_kwargs = {}
        
        if optimal_batch_size:
            # Don't go over
            batch_size = min(self.batch_size, optimal_batch_size)
            accumulate_grad_batches = max(1, int(optimal_batch_size / batch_size))
            trainer_kwargs["accumulate_grad_batches"] = accumulate_grad_batches
        
        if early_stopping:
            # Stop when val loss stops improving
            early_stopping = EarlyStopping(monitor="val_loss", patience=1)
            trainer_kwargs["callbacks"] = [early_stopping]

        if not trainer:
            trainer = pl.Trainer(gpus=-1, max_epochs=epochs, checkpoint_callback=False,
                logger=False, **trainer_kwargs)

        self.model.train()
        trainer.fit(self)

        del self.dataset_train
        del self.dataset_valid
        del self.trainer

        # For some reason the model can end up on CPU after training
        self.to(self._model_device)
        self.model.eval()
        print("Training finished! Save your model for later with backprop.save or upload it with backprop.upload")
def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name,
                           max_trials, **fit_kwargs):
    """ Batch scaling mode where the size is initially is doubled at each iteration
        until an OOM error is encountered. Hereafter, the batch size is further
        refined using a binary search """
    low = 1  # last known good batch size, so the OOM branch never sees an undefined `low`
    high = None
    count = 0
    while True:
        garbage_collection_cuda()
        trainer.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.fit(model, **fit_kwargs)
            count += 1
            if count > max_trials:
                break
            # Double in size
            low = new_size
            if high:
                if high - low <= 1:
                    break
                midval = (high + low) // 2
                new_size, changed = _adjust_batch_size(trainer,
                                                       batch_arg_name,
                                                       value=midval,
                                                       desc='succeeded')
            else:
                new_size, changed = _adjust_batch_size(trainer,
                                                       batch_arg_name,
                                                       factor=2.0,
                                                       desc='succeeded')

            if not changed:
                break

        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
                # On OOM, bisect between the last successful size and the one that failed
                garbage_collection_cuda()
                high = new_size
                midval = (high + low) // 2
                new_size, _ = _adjust_batch_size(trainer,
                                                 batch_arg_name,
                                                 value=midval,
                                                 desc='failed')
                if high - low <= 1:
                    break
            else:
                raise  # some other error not memory related

    return new_size

def pick_single_gpu_realist_workload(exclude_gpus: list, model: LightningModule) -> int:
    for i in range(torch.cuda.device_count()):
        if i in exclude_gpus:
            continue
        # Try to allocate on device:
        device = torch.device(f"cuda:{i}")
        batch = next(iter(model.train_dataloader()))  # sample one real batch for the trial
        try:
            model_device = model.to(device)
            batch_device = batch.to(device)
            model_device.train()  # training mode, so the trial reflects a realistic workload
            model_device(batch_device)
        except RuntimeError as exception:
            if is_oom_error(exception):
                # Clean up after the failed attempt and move on to the next device
                garbage_collection_cuda()
                continue
            raise  # some other error not memory related
        return i
    raise RuntimeError("No GPUs available.")
    def scale_batch_size(self,
                         model: LightningModule,
                         mode: str = 'power',
                         steps_per_trial: int = 3,
                         init_val: int = 2,
                         max_trials: int = 25,
                         batch_arg_name: str = 'batch_size'):
        r"""
        Will iteratively try to find the largest batch size for a given model
        that does not give an out of memory (OOM) error.

        Args:
            model: Model to fit.

            mode: string setting the search mode. Either `power` or `binsearch`.
                If mode is `power` we keep multiplying the batch size by 2, until
                we get an OOM error. If mode is 'binsearch', we will initially
                also keep multiplying by 2 and after encountering an OOM error
                do a binary search between the last successful batch size and the
                batch size that failed.

            steps_per_trial: number of steps to run with a given batch size.
                Ideally 1 should be enough to test if an OOM error occurs,
                however in practice a few are needed

            init_val: initial batch size to start the search with

            max_trials: max number of batch size increases performed before the
               algorithm is terminated

            batch_arg_name: name of the model attribute that stores the batch size

        """
        if not hasattr(model, batch_arg_name):
            raise MisconfigurationException(
                f'Field {batch_arg_name} not found in `model.hparams`')

        if hasattr(model.train_dataloader, 'patch_loader_code'):
            raise MisconfigurationException(
                'The batch scaling feature cannot be used with dataloaders'
                ' passed directly to `.fit()`. Please disable the feature or'
                ' incorporate the dataloader into the model.')

        # Arguments we adjust during the batch size finder, save for restoring
        self.__scale_batch_dump_params()

        # Set to values that are required by the algorithm
        self.__scale_batch_reset_params(model, steps_per_trial)

        # Save initial model, that is loaded after batch size is found
        save_path = os.path.join(self.default_root_dir, 'temp_model.ckpt')
        self.save_checkpoint(str(save_path))

        if self.progress_bar_callback:
            self.progress_bar_callback.disable()

        # Initially we just double in size until an OOM is encountered
        new_size = _adjust_batch_size(
            self, value=init_val)  # initially set to init_val
        if mode == 'power':
            new_size = _run_power_scaling(self, model, new_size,
                                          batch_arg_name, max_trials)
        elif mode == 'binsearch':
            new_size = _run_binsearch_scaling(self, model, new_size,
                                              batch_arg_name, max_trials)
        else:
            raise ValueError(
                'mode in method `scale_batch_size` can only be `power` or `binsearch`'
            )

        garbage_collection_cuda()
        log.info(
            f'Finished batch size finder, will continue with full run using batch size {new_size}'
        )

        # Restore initial state of model
        self.restore(str(save_path), on_gpu=self.on_gpu)
        os.remove(save_path)

        # Finish by resetting variables so trainer is ready to fit model
        self.__scale_batch_restore_params()
        if self.progress_bar_callback:
            self.progress_bar_callback.enable()

        return new_size
Example 8
def scale_batch_size(
    trainer: "pl.Trainer",
    model: "pl.LightningModule",
    mode: str = "power",
    steps_per_trial: int = 3,
    init_val: int = 2,
    max_trials: int = 25,
    batch_arg_name: str = "batch_size",
) -> Optional[int]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`"""
    if trainer.fast_dev_run:
        rank_zero_warn("Skipping batch size scaler since fast_dev_run is enabled.")
        return

    if not lightning_hasattr(model, batch_arg_name):
        raise MisconfigurationException(f"Field {batch_arg_name} not found in both `model` and `model.hparams`")
    if hasattr(model, batch_arg_name) and hasattr(model, "hparams") and batch_arg_name in model.hparams:
        rank_zero_warn(
            f"Field `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!"
            f" `model.{batch_arg_name}` will be used as the initial batch size for scaling."
            " If this is not the intended behavior, please remove either one."
        )

    if not trainer._data_connector._train_dataloader_source.is_module():
        raise MisconfigurationException(
            "The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`."
            " Please disable the feature or incorporate the dataloader into the model."
        )

    # Save initial model, that is loaded after batch size is found
    ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
    trainer.save_checkpoint(ckpt_path)
    params = __scale_batch_dump_params(trainer)

    # Set to values that are required by the algorithm
    __scale_batch_reset_params(trainer, steps_per_trial)

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Initially we just double in size until an OOM is encountered
    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
    if mode == "power":
        new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials)
    elif mode == "binsearch":
        new_size = _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials)
    else:
        raise ValueError("mode in method `scale_batch_size` could either be `power` or `binsearch`")

    garbage_collection_cuda()
    log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")

    # Restore initial state of model
    trainer._checkpoint_connector.restore(ckpt_path)
    trainer.strategy.remove_checkpoint(ckpt_path)
    __scale_batch_restore_params(trainer, params)

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    return new_size
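# A minimal end-to-end usage sketch for the function above, assuming a
# Lightning 1.x-era release where `auto_scale_batch_size` is still a Trainer
# argument (later versions moved this into a `BatchSizeFinder` callback).
# `TinyModel` is an illustrative module, not part of any library.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size  # attribute the batch size finder adjusts
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(1024, 32), torch.randn(1024, 1))
        return DataLoader(data, batch_size=self.batch_size)


if __name__ == "__main__":
    trainer = pl.Trainer(auto_scale_batch_size="binsearch", max_epochs=1)
    trainer.tune(TinyModel())  # runs the batch size finder before any fitting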