def _run_power_scaling(trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int) -> int:
    """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.

    Args:
        trainer: the Trainer whose batch-size attribute is adjusted in place
        model: the model fitted on each trial
        new_size: the batch size to start from
        batch_arg_name: name of the attribute holding the batch size
        max_trials: maximum number of doubling attempts

    Returns:
        The largest batch size found that fits in memory.
    """
    for _trial in range(max_trials):
        garbage_collection_cuda()
        # start every attempt from a clean step counter
        trainer.fit_loop.global_step = 0
        try:
            # run a fitting trial at the current size
            trainer.tuner._run(model)
            # the trial fit in memory -> double the size for the next round
            new_size, grew = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
        except RuntimeError as err:
            if not is_oom_error(err):
                raise  # some other error not memory related
            # out of memory: free what we can, halve the size and stop searching
            garbage_collection_cuda()
            new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
            break
        if not grew:
            # the size could not be increased any further -> search is done
            break
        # batch size changed, so the train dataloader must be rebuilt
        trainer.reset_train_dataloader(model)
    return new_size
def _run_binsearch_scaling(trainer: "pl.Trainer", model: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int) -> int:
    """Batch scaling mode where the size is initially doubled at each iteration until an
    OOM error is encountered. Hereafter, the batch size is further refined using a
    binary search.

    Args:
        trainer: the Trainer whose batch-size attribute is adjusted in place
        model: the model fitted on each trial
        new_size: the batch size to start the search from
        batch_arg_name: name of the attribute holding the batch size
        max_trials: maximum number of successful trials before stopping

    Returns:
        The largest batch size found that fits in memory.
    """
    low = 1          # largest size known to fit (lower bound of the search window)
    high = None      # smallest size known to OOM; None until the first failure
    count = 0        # number of successful trials so far
    while True:
        garbage_collection_cuda()
        trainer.fit_loop.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.tuner._run(model)
            count += 1
            if count > max_trials:
                break
            # Double in size
            low = new_size
            if high:
                # An OOM was seen before: bisect between the known bounds.
                if high - low <= 1:
                    break
                midval = (high + low) // 2
                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="succeeded")
            else:
                # No OOM yet: keep doubling.
                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
            if changed:
                # Force the train dataloader to reset as the batch size has changed
                trainer.reset_train_dataloader(model)
                trainer.reset_val_dataloader(model)
            else:
                break
        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
                # If we fail in power mode, half the size and return
                garbage_collection_cuda()
                high = new_size  # the current size becomes the new upper bound
                midval = (high + low) // 2
                new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="failed")
                if high - low <= 1:
                    break
            else:
                raise  # some other error not memory related
    return new_size
def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
    """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.

    Args:
        trainer: the Trainer whose batch-size attribute is adjusted in place
        model: the model fitted on each trial
        new_size: the batch size to start from
        batch_arg_name: name of the attribute holding the batch size
        max_trials: maximum number of doubling attempts
        **fit_kwargs: extra arguments forwarded to ``trainer.fit``

    Returns:
        The largest batch size found that fits in memory.
    """
    for _trial in range(max_trials):
        garbage_collection_cuda()
        # begin each attempt from a fresh step counter
        trainer.global_step = 0
        try:
            # run a fitting trial at the current size
            trainer.fit(model, **fit_kwargs)
            # it fit in memory -> try twice the size on the next round
            new_size, grew = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
        except RuntimeError as err:
            if not is_oom_error(err):
                raise  # propagate anything that is not memory related
            # ran out of memory: free it, halve the size and stop
            garbage_collection_cuda()
            new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
            break
        if not grew:
            # the size did not change -> nothing further to try
            break
    return new_size
def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
    """Batch scaling mode where the size is initially doubled at each iteration until an
    OOM error is encountered. Hereafter, the batch size is further refined using a
    binary search.

    Args:
        trainer: the Trainer whose batch-size attribute is adjusted in place
        model: the model fitted on each trial
        new_size: the batch size to start the search from
        batch_arg_name: name of the attribute holding the batch size
        max_trials: maximum number of successful trials before stopping
        **fit_kwargs: extra arguments forwarded to ``trainer.fit``

    Returns:
        The largest batch size found that fits in memory.
    """
    # BUGFIX: ``low`` must be initialized before the loop. Without it, an OOM on
    # the very first trial reaches ``(high + low) // 2`` with ``low`` unbound and
    # raises UnboundLocalError (the newer fit_loop variant already does this).
    low = 1          # largest size known to fit (lower bound of the search window)
    high = None      # smallest size known to OOM; None until the first failure
    count = 0        # number of successful trials so far
    while True:
        garbage_collection_cuda()
        trainer.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.fit(model, **fit_kwargs)
            count += 1
            if count > max_trials:
                break
            # Double in size
            low = new_size
            if high:
                # An OOM was seen before: bisect between the known bounds.
                if high - low <= 1:
                    break
                midval = (high + low) // 2
                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
            else:
                # No OOM yet: keep doubling.
                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
            if not changed:
                break
        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_oom_error(exception):
                # If we fail in power mode, half the size and return
                garbage_collection_cuda()
                high = new_size  # the current size becomes the new upper bound
                midval = (high + low) // 2
                new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='failed')
                if high - low <= 1:
                    break
            else:
                raise  # some other error not memory related
    return new_size
def pick_single_gpu_realist_workload(exclude_gpus: list, model: "LightningModule") -> int:
    """Return the index of the first non-excluded GPU that can hold a realistic
    workload: the model, one real training batch, and a forward pass in train mode.

    Args:
        exclude_gpus: GPU indices that must not be considered
        model: the LightningModule providing the model and its train dataloader

    Returns:
        The index of the first GPU on which the workload fits.

    Raises:
        RuntimeError: if no (non-excluded) GPU can fit the workload.
    """
    # Fetched lazily so that setups with no eligible GPU never touch the dataloader.
    batch = None
    for i in range(torch.cuda.device_count()):
        if i in exclude_gpus:
            continue
        if batch is None:
            # BUGFIX: ``train_dataloader`` is a method on LightningModule and must be
            # called — the original iterated the bound method object itself, which
            # raises TypeError. Also hoisted out of the per-GPU retry (loop-invariant).
            batch = next(iter(model.train_dataloader()))
        # Try to allocate on device:
        device = torch.device(f"cuda:{i}")
        try:
            model_device = model.to(device)
            # NOTE(review): assumes the batch object implements ``.to(device)``
            # (e.g. a Tensor) — confirm against the project's collate output.
            batch_device = batch.to(device)
            model_device.train()  # record grads
            model_device(batch_device)
        except RuntimeError as exception:
            if is_oom_error(exception):
                # clean after the failed attempt
                garbage_collection_cuda()
            else:
                raise
            continue
        return i
    raise RuntimeError("No GPUs available.")