def _run_power_scaling(
    trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int
) -> int:
    """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
    for _ in range(max_trials):
        garbage_collection_cuda()
        trainer.train_loop.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.tuner._run(model)
            # Double in size
            new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
                # If we fail in power mode, halve the size and return
                garbage_collection_cuda()
                new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
                break
            else:
                raise  # some other error not memory related

        if not changed:
            break
    return new_size
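# Illustrative trace of the power mode above (the numbers are an assumed example, not from the source):
# starting from init_val=2, each successful trial doubles the size via _adjust_batch_size(factor=2.0),
# e.g. 2 -> 4 -> 8 -> ... -> 256; if the trial at 256 raises an OOM, the size is halved back to 128
# (factor=0.5) and the loop returns 128 as the largest batch size that fit in memory.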
def _run_binsearch_scaling(trainer: "pl.Trainer", model: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int) -> int: """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. Hereafter, the batch size is further refined using a binary search """ low = 1 high = None count = 0 while True: garbage_collection_cuda() trainer.fit_loop.global_step = 0 # reset after each try try: # Try fit trainer.tuner._run(model) count += 1 if count > max_trials: break # Double in size low = new_size if high: if high - low <= 1: break midval = (high + low) // 2 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="succeeded") else: new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded") if changed: # Force the train dataloader to reset as the batch size has changed trainer.reset_train_dataloader(model) trainer.reset_val_dataloader(model) else: break except RuntimeError as exception: # Only these errors should trigger an adjustment if is_oom_error(exception): # If we fail in power mode, half the size and return garbage_collection_cuda() high = new_size midval = (high + low) // 2 new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="failed") if high - low <= 1: break else: raise # some other error not memory related return new_size
def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
    """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
    for _ in range(max_trials):
        garbage_collection_cuda()
        trainer.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.fit(model, **fit_kwargs)
            # Double in size
            new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
                # If we fail in power mode, halve the size and return
                garbage_collection_cuda()
                new_size = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
                break
            else:
                raise  # some other error not memory related
    return new_size
def finetune(
    self,
    dataset,
    validation_split: float = 0.15,
    epochs: int = 20,
    batch_size: int = None,
    optimal_batch_size: int = None,
    early_stopping: bool = True,
    trainer=None,
):
    self.batch_size = batch_size or 1

    if not torch.cuda.is_available():
        raise Exception("You need a cuda capable (Nvidia) GPU for finetuning")

    len_train = int(len(dataset) * (1 - validation_split))
    len_valid = len(dataset) - len_train
    dataset_train, dataset_valid = torch.utils.data.random_split(dataset, [len_train, len_valid])

    self.dataset_train = dataset_train
    self.dataset_valid = dataset_valid

    if batch_size is None:
        # Find batch size
        temp_trainer = pl.Trainer(auto_scale_batch_size="power", gpus=-1)
        print("Finding the optimal batch size...")
        temp_trainer.tune(self)

        # Ensure that memory gets cleared
        del self.trainer
        del temp_trainer
        garbage_collection_cuda()

    trainer_kwargs = {}

    if optimal_batch_size:
        # Don't go over
        batch_size = min(self.batch_size, optimal_batch_size)
        accumulate_grad_batches = max(1, int(optimal_batch_size / batch_size))
        trainer_kwargs["accumulate_grad_batches"] = accumulate_grad_batches

    if early_stopping:
        # Stop when val loss stops improving
        early_stopping = EarlyStopping(monitor="val_loss", patience=1)
        trainer_kwargs["callbacks"] = [early_stopping]

    if not trainer:
        trainer = pl.Trainer(gpus=-1, max_epochs=epochs, checkpoint_callback=False, logger=False, **trainer_kwargs)

    self.model.train()
    trainer.fit(self)

    del self.dataset_train
    del self.dataset_valid
    del self.trainer

    # For some reason the model can end up on CPU after training
    self.to(self._model_device)
    self.model.eval()
    print("Training finished! Save your model for later with backprop.save or upload it with backprop.upload")
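# Hypothetical usage of the `finetune` method above (the model class and tensors are illustrative only;
# they are not part of the source and only assume the owning class is a LightningModule that exposes
# `finetune` as defined above):
import torch

features = torch.randn(1000, 16)
targets = torch.randint(0, 2, (1000,))
dataset = torch.utils.data.TensorDataset(features, targets)

model = MyFinetunableModel()  # assumed subclass providing training_step/configure_optimizers
model.finetune(
    dataset,
    validation_split=0.1,
    epochs=5,
    batch_size=8,            # passing an explicit size skips the automatic batch size search
    optimal_batch_size=32,    # emulates a batch size of 32 via gradient accumulation (32 / 8 = 4 steps)
)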
def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
    """ Batch scaling mode where the size is initially doubled at each iteration
        until an OOM error is encountered. Hereafter, the batch size is further
        refined using a binary search """
    low = 1  # lower bound for the search; without it, an OOM on the very first trial would hit an undefined name
    high = None
    count = 0
    while True:
        garbage_collection_cuda()
        trainer.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.fit(model, **fit_kwargs)
            count += 1
            if count > max_trials:
                break
            # Double in size
            low = new_size
            if high:
                if high - low <= 1:
                    break
                midval = (high + low) // 2
                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
            else:
                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')

            if not changed:
                break
        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
                # On failure, clean up, move the upper bound down to the failing size and
                # retry at the midpoint between the last success and this failure
                garbage_collection_cuda()
                high = new_size
                midval = (high + low) // 2
                new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='failed')
                if high - low <= 1:
                    break
            else:
                raise  # some other error not memory related
    return new_size
def pick_single_gpu_realist_workload(exclude_gpus: list, model: LightningModule):
    for i in range(torch.cuda.device_count()):
        if i in exclude_gpus:
            continue
        # Try to allocate on device:
        device = torch.device(f"cuda:{i}")
        batch = next(iter(model.train_dataloader()))
        try:
            model_device = model.to(device)
            batch_device = batch.to(device)
            model_device.train()  # training mode so the forward pass resembles a real training step
            model_device(batch_device)
        except RuntimeError as exception:
            if is_oom_error(exception):
                # clean after the failed attempt
                garbage_collection_cuda()
            else:
                raise
            continue
        return i
    raise RuntimeError("No GPUs available.")
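# Hypothetical usage of the probe above (`MyLightningModel` and the `exclude_gpus` value are
# illustrative assumptions): run one forward pass of a real batch on each candidate GPU and
# train on the first device that does not run out of memory.
model = MyLightningModel()
gpu_index = pick_single_gpu_realist_workload(exclude_gpus=[0], model=model)
trainer = pl.Trainer(gpus=[gpu_index])
trainer.fit(model)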
def scale_batch_size(self,
                     model: LightningModule,
                     mode: str = 'power',
                     steps_per_trial: int = 3,
                     init_val: int = 2,
                     max_trials: int = 25,
                     batch_arg_name: str = 'batch_size'):
    r"""
    Will iteratively try to find the largest batch size for a given model
    that does not give an out of memory (OOM) error.

    Args:
        model: Model to fit.

        mode: string setting the search mode. Either `power` or `binsearch`.
            If mode is `power` we keep multiplying the batch size by 2, until
            we get an OOM error. If mode is `binsearch`, we will initially
            also keep multiplying by 2 and after encountering an OOM error
            do a binary search between the last successful batch size and the
            batch size that failed.

        steps_per_trial: number of steps to run with a given batch size.
            Ideally 1 should be enough to test if an OOM error occurs,
            however in practice a few are needed.

        init_val: initial batch size to start the search with

        max_trials: max number of increases in batch size done before
            the algorithm is terminated

        batch_arg_name: name of the attribute that stores the batch size.
    """
    if not hasattr(model, batch_arg_name):
        raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')

    if hasattr(model.train_dataloader, 'patch_loader_code'):
        raise MisconfigurationException(
            'The batch scaling feature cannot be used with dataloaders'
            ' passed directly to `.fit()`. Please disable the feature or'
            ' incorporate the dataloader into the model.')

    # Arguments we adjust during the batch size finder, save for restoring
    self.__scale_batch_dump_params()

    # Set to values that are required by the algorithm
    self.__scale_batch_reset_params(model, steps_per_trial)

    # Save initial model, that is loaded after batch size is found
    save_path = os.path.join(self.default_root_dir, 'temp_model.ckpt')
    self.save_checkpoint(str(save_path))

    if self.progress_bar_callback:
        self.progress_bar_callback.disable()

    # Initially we just double in size until an OOM is encountered
    new_size = _adjust_batch_size(self, value=init_val)  # initially set to init_val
    if mode == 'power':
        new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials)
    elif mode == 'binsearch':
        new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials)
    else:
        raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch`')

    garbage_collection_cuda()
    log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')

    # Restore initial state of model
    self.restore(str(save_path), on_gpu=self.on_gpu)
    os.remove(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    self.__scale_batch_restore_params()
    if self.progress_bar_callback:
        self.progress_bar_callback.enable()

    return new_size
def scale_batch_size(
    trainer: "pl.Trainer",
    model: "pl.LightningModule",
    mode: str = "power",
    steps_per_trial: int = 3,
    init_val: int = 2,
    max_trials: int = 25,
    batch_arg_name: str = "batch_size",
) -> Optional[int]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`"""
    if trainer.fast_dev_run:
        rank_zero_warn("Skipping batch size scaler since fast_dev_run is enabled.")
        return

    if not lightning_hasattr(model, batch_arg_name):
        raise MisconfigurationException(f"Field {batch_arg_name} not found in both `model` and `model.hparams`")

    if hasattr(model, batch_arg_name) and hasattr(model, "hparams") and batch_arg_name in model.hparams:
        rank_zero_warn(
            f"Field `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!"
            f" `model.{batch_arg_name}` will be used as the initial batch size for scaling."
            " If this is not the intended behavior, please remove either one."
        )

    if not trainer._data_connector._train_dataloader_source.is_module():
        raise MisconfigurationException(
            "The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`."
            " Please disable the feature or incorporate the dataloader into the model."
        )

    # Save initial model, that is loaded after batch size is found
    ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
    trainer.save_checkpoint(ckpt_path)

    params = __scale_batch_dump_params(trainer)

    # Set to values that are required by the algorithm
    __scale_batch_reset_params(trainer, steps_per_trial)

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Initially we just double in size until an OOM is encountered
    new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
    if mode == "power":
        new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials)
    elif mode == "binsearch":
        new_size = _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials)
    else:
        raise ValueError("mode in method `scale_batch_size` could either be `power` or `binsearch`")

    garbage_collection_cuda()
    log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")

    # Restore initial state of model
    trainer._checkpoint_connector.restore(ckpt_path)
    trainer.strategy.remove_checkpoint(ckpt_path)

    __scale_batch_restore_params(trainer, params)

    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    return new_size
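# In the Lightning versions these snippets come from, this function is normally not called directly;
# the public entry points are the `auto_scale_batch_size` Trainer flag together with `trainer.tune`,
# or `trainer.tuner.scale_batch_size`. A minimal sketch, assuming a LightningModule that exposes a
# `batch_size` attribute or hparam (`MyLightningModel` is an illustrative name):
import pytorch_lightning as pl

model = MyLightningModel()
trainer = pl.Trainer(auto_scale_batch_size="binsearch", max_epochs=1)
trainer.tune(model)  # runs the batch size finder and writes the result back to `model.batch_size`

# Or invoke the tuner explicitly:
new_batch_size = trainer.tuner.scale_batch_size(model, mode="power", init_val=2, max_trials=25)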