def train(self): self.run_sanity_check(self.get_model()) # enable train mode model = self.get_model() model.train() torch.set_grad_enabled(True) # reload data when needed self.train_loop.reset_train_val_dataloaders(model) # hook self.train_loop.on_train_start() try: # run all epochs for epoch in range(self.current_epoch, self.max_epochs): # reset train dataloader if self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) # hook self.train_loop.on_train_epoch_start(epoch) # run train epoch self.train_loop.run_training_epoch() if self.max_steps and self.max_steps <= self.global_step: # hook self.train_loop.on_train_end() return # update LR schedulers self.optimizer_connector.update_learning_rates( interval='epoch') # early stopping met_min_epochs = epoch >= self.min_epochs - 1 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True if self.should_stop: if (met_min_epochs and met_min_steps): self.train_loop.on_train_end() return else: log.info( 'Trainer was signaled to stop but required minimum epochs' f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' ' not been met. Training will continue...') # hook self.train_loop.on_train_end() except KeyboardInterrupt: rank_zero_warn( 'Detected KeyboardInterrupt, attempting graceful shutdown...') # user could press ctrl+c many times... only shutdown once if not self.interrupted: self.interrupted = True self._state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() # hook self.train_loop.on_train_end()
def run_train(self):
    """Run the full optimization routine.

    Performs the sanity check, switches the running stage to TRAINING, then
    iterates epochs (unbounded via ``itertools.count`` when ``max_epochs`` is
    falsy). Training ends when ``max_steps`` is reached or when ``should_stop``
    is set and the minimum epoch/step requirements are met.

    Fix: the original called ``self.train_loop.on_train_end()`` both at the end
    of the ``try`` body (normal completion) and again in ``finally``, invoking
    the hook twice on a normal exit. The ``finally`` block alone already covers
    every exit path (returns, normal completion, and KeyboardInterrupt), so the
    redundant call inside ``try`` is removed.
    """
    self._pre_training_routine()

    # only rank 0 shows a progress bar
    if not self.is_global_zero and self.progress_bar_callback is not None:
        self.progress_bar_callback.disable()

    self.run_sanity_check(self.lightning_module)

    # set stage for logging
    self._running_stage = RunningStage.TRAINING

    self.checkpoint_connector.has_trained = False

    # enable train mode
    model = self.lightning_module
    model.train()
    torch.set_grad_enabled(True)

    # reload data when needed
    self.train_loop.reset_train_val_dataloaders(model)

    # hook
    self.train_loop.on_train_start()

    try:
        if self.train_loop.should_skip_training():
            return

        # run all epochs; with no max_epochs, iterate indefinitely
        epochs = range(self.current_epoch, self.max_epochs) if self.max_epochs else count(
            self.current_epoch)
        for epoch in epochs:

            # hook
            self.train_loop.on_train_epoch_start(epoch)

            with self.profiler.profile("run_training_epoch"):
                # run train epoch
                self.train_loop.run_training_epoch()

            # max_steps takes priority over remaining epochs: stop immediately
            if self.max_steps and self.max_steps <= self.global_step:
                return

            # early stopping: both the epoch and step minimums must be satisfied
            met_min_epochs = (
                epoch >= self.min_epochs - 1) if self.min_epochs else True
            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

            if self.should_stop:
                if met_min_epochs and met_min_steps:
                    return
                else:
                    # stop was requested too early; keep training until minimums are met
                    log.info(
                        'Trainer was signaled to stop but required minimum epochs'
                        f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                        ' not been met. Training will continue...')
    except KeyboardInterrupt:
        rank_zero_warn(
            'Detected KeyboardInterrupt, attempting graceful shutdown...')
        # user could press ctrl+c many times... only shutdown once
        if not self.interrupted:
            self.interrupted = True
            self._state = TrainerState.INTERRUPTED
            self.on_keyboard_interrupt()
    finally:
        # hook: runs on every exit path (returns, completion, interrupt)
        self.train_loop.on_train_end()
def kth_best_model(self):
    """Deprecated alias for ``kth_best_model_path``.

    Emits a ``DeprecationWarning`` and forwards to the renamed attribute.
    """
    deprecation_message = (
        "Attribute `kth_best_model` has been renamed to `kth_best_model_path` since v0.8.0"
        " and will be removed in v0.10.0"
    )
    rank_zero_warn(deprecation_message, DeprecationWarning)
    return self.kth_best_model_path
def run_evaluation(self, on_epoch=False):
    """Run a full validation/test pass over all evaluation dataloaders.

    Args:
        on_epoch: When True, epoch-interval LR schedulers are stepped after the
            evaluation epoch ends.

    Returns:
        The epoch-level results gathered by the logger connector, or ``([], [])``
        when the evaluation is skipped (e.g. zero batches requested).
    """
    # guard against being called in the wrong stage; recover by forcing validation
    if not (self.evaluating or self.sanity_checking):
        rank_zero_warn(
            f"`trainer.run_evaluation()` was called but the running stage is set to {self._running_stage}."
            " This should not happen normally. Setting it to `RunningStage.VALIDATING`", RuntimeWarning)
        self.validating = True

    # prepare dataloaders
    dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(
    )

    # check if we want to skip this evaluation
    if self.evaluation_loop.should_skip_evaluation(max_batches):
        return [], []

    # enable eval mode + no grads
    self.evaluation_loop.on_evaluation_model_eval()
    # ref model
    model = self.lightning_module
    model.zero_grad()
    torch.set_grad_enabled(False)

    # hook
    self.evaluation_loop.on_evaluation_start()

    # set up the eval loop
    self.evaluation_loop.setup(max_batches, dataloaders)

    # hook
    self.evaluation_loop.on_evaluation_epoch_start()

    # run validation/testing
    for dataloader_idx, dataloader in enumerate(dataloaders):
        # bookkeeping: per-dataloader batch outputs
        dl_outputs = []
        dataloader = self.accelerator.process_dataloader(dataloader)
        dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

        for batch_idx, batch in enumerate(dataloader):
            # skip empty batches yielded by the loader
            if batch is None:
                continue

            # stop short when running on limited batches
            if batch_idx >= dl_max_batches:
                break

            # hook
            self.evaluation_loop.on_evaluation_batch_start(
                batch, batch_idx, dataloader_idx)

            # lightning module methods
            with self.profiler.profile("evaluation_step_and_end"):
                output = self.evaluation_loop.evaluation_step(
                    batch, batch_idx, dataloader_idx)
                output = self.evaluation_loop.evaluation_step_end(output)

            # hook + store predictions
            self.evaluation_loop.on_evaluation_batch_end(
                output, batch, batch_idx, dataloader_idx)

            # log batch metrics
            self.evaluation_loop.log_evaluation_step_metrics(batch_idx)

            # track epoch level outputs
            dl_outputs = self.track_output_for_epoch_end(
                dl_outputs, output)

        # store batch level output per dataloader
        self.evaluation_loop.outputs.append(dl_outputs)

    outputs = self.evaluation_loop.outputs

    # reset outputs so the next evaluation starts clean
    self.evaluation_loop.outputs = []

    # with a single dataloader don't pass a 2D list
    if self.evaluation_loop.num_dataloaders == 1:
        outputs = outputs[0]

    # lightning module method
    self.evaluation_loop.evaluation_epoch_end(outputs)

    # hook
    self.evaluation_loop.on_evaluation_epoch_end(outputs)

    # update epoch-level lr_schedulers
    if on_epoch:
        self.optimizer_connector.update_learning_rates(interval='epoch')

    # hook
    self.evaluation_loop.on_evaluation_end()

    # log epoch metrics
    eval_loop_results = self.logger_connector.get_evaluate_epoch_results()

    # save predictions to disk
    self.evaluation_loop.predictions.to_disk()

    # enable train mode again
    self.evaluation_loop.on_evaluation_model_train()

    # reset cached results
    self.logger_connector.reset()

    # re-enable grads disabled at the top of this method
    torch.set_grad_enabled(True)

    return eval_loop_results
def _load_bolts_unet(_, num_classes: int, **kwargs) -> nn.Module:
    """Build a bolts ``UNet`` head; the backbone argument is intentionally unused.

    Args:
        _: Ignored backbone placeholder (UNet is backbone-free).
        num_classes: Number of output classes for the UNet.
        **kwargs: Forwarded to the ``UNet`` constructor.

    Returns:
        The constructed ``UNet`` module.
    """
    message = "The UNet model does not require a backbone, so the backbone will be ignored."
    rank_zero_warn(message, UserWarning)
    return UNet(num_classes, **kwargs)
def use_ddp2(self, val: bool) -> None:
    """Deprecated setter for the internal DDP2 flag.

    Always warns; only a truthy ``val`` mutates the accelerator connector.
    """
    rank_zero_warn(
        "Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.",
        DeprecationWarning,
    )
    if not val:
        return
    self.accelerator_connector._distrib_type = DistributedType.DDP2
def _reset_eval_dataloader(
        self,
        model: LightningModule,
        mode: str) -> Tuple[List[Union[int, float]], List[DataLoader]]:
    """Generic method to reset a dataloader for evaluation.

    Args:
        model: The current `LightningModule`
        mode: Either `'val'` or `'test'`

    Returns:
        Tuple (num_batches, dataloaders) where each entry of ``num_batches`` is
        an int, or ``float('inf')`` for loaders without a usable ``__len__``.

    Raises:
        MisconfigurationException: if batch limits are incompatible with an
            infinite dataloader, or resolve to zero batches.
    """
    # use the training loader as val and test when overfitting
    if self.overfit_batches > 0:
        dataloaders = self.request_dataloader(
            getattr(model, 'train_dataloader'))
    else:
        dataloaders = self.request_dataloader(
            getattr(model, f'{mode}_dataloader'))

    # normalize to a list so single- and multi-loader cases share one code path
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for loader_i in range(len(dataloaders)):
        loader = dataloaders[loader_i]

        # shuffling in val and test set is bad practice
        if mode in ('val', 'test') and hasattr(
                loader, 'sampler') and isinstance(loader.sampler, RandomSampler):

            # when overfitting, the dataloader should not have sampler
            if self.overfit_batches > 0:
                rank_zero_warn(
                    'You requested to overfit but enabled training dataloader shuffling.'
                    ' We are turning it off for you.')
                # replace the random sampler with a deterministic sequential one
                dataloaders[loader_i] = self.replace_sampler(
                    loader, SequentialSampler(loader.dataset))
            else:
                rank_zero_warn(
                    f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
                    ' this off for validation and test dataloaders.')

    if any([dl is None for dl in dataloaders]):
        rank_zero_warn(
            "One of given dataloaders is None and it will be skipped.")

    # add samplers (None loaders are dropped here)
    dataloaders = [
        self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None
    ]

    loader_num_batches = []

    # determine number of batches
    # datasets could be none, 1 or 2+
    if len(dataloaders) != 0:
        for i, dataloader in enumerate(dataloaders):
            # infinite sentinel when the loader has no usable __len__
            num_batches = len(dataloader) if _has_len(
                dataloader) else float('inf')
            self._worker_check(dataloader, f'{mode} dataloader {i}')

            # percent or num_steps
            limit_eval_batches = getattr(self, f'limit_{mode}_batches')

            if num_batches != float('inf'):
                self._check_batch_limits(f'limit_{mode}_batches')

                # limit num batches either as a percent or num steps
                if isinstance(limit_eval_batches, float):
                    num_batches = int(num_batches * limit_eval_batches)
                else:
                    num_batches = min(len(dataloader), limit_eval_batches)

            elif limit_eval_batches not in (0.0, 1.0):
                # an infinite loader only supports "all" (1.0) or "none" (0.0)
                raise MisconfigurationException(
                    'When using an infinite DataLoader (e.g. with an IterableDataset'
                    f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
                    f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.'
                )

            # a fractional limit that truncates to zero batches is a user error
            if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(
                    limit_eval_batches, float):
                min_pct = 1.0 / len(dataloader)
                raise MisconfigurationException(
                    f'you requested to check {limit_eval_batches} of the {mode} dataloader but'
                    f' {limit_eval_batches}*{num_batches} = 0. Please increase the limit_{mode}_batches.'
                    f' Try at least limit_{mode}_batches={min_pct}')

            loader_num_batches.append(num_batches)

    return loader_num_batches, dataloaders
def restore_training_state(self, checkpoint):
    """Restore trainer state (AMP scaling, callbacks, step/epoch counters,
    optimizers and LR schedulers) from a checkpoint dict.

    :param checkpoint: a checkpoint dictionary containing at minimum
        'optimizer_states', 'lr_schedulers', 'global_step' and 'epoch'.
    :return: None
    :raises KeyError: if the checkpoint holds only model weights.
    :raises ValueError: if the checkpoint uses a deprecated schema.
    :raises MisconfigurationException: if the restored epoch exceeds max_epochs.
    """
    # validation: weights-only checkpoints cannot restore training state
    if 'optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint:
        raise KeyError(
            'Trying to restore training state but checkpoint contains only the model.'
            ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
        )

    if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
        raise ValueError(
            "The checkpoint you're attempting to load follows an"
            " outdated schema. You can upgrade to the current schema by running"
            " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
            " where `model.ckpt` is your checkpoint file.")

    # restore amp scaling (only the state matching the active AMP backend)
    if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
        self.trainer.scaler.load_state_dict(
            checkpoint['native_amp_scaling_state'])
    elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
        amp.load_state_dict(checkpoint['amp_scaling_state'])

    # restore callback states
    self.trainer.on_load_checkpoint(checkpoint)

    self.trainer.global_step = checkpoint['global_step']
    self.trainer.current_epoch = checkpoint['epoch']

    # crash if max_epochs is lower then the current epoch from the checkpoint
    if self.trainer.current_epoch > self.trainer.max_epochs:
        m = f"""
        you restored a checkpoint with current_epoch={self.trainer.current_epoch}
        but the Trainer(max_epochs={self.trainer.max_epochs})
        """
        raise MisconfigurationException(m)

    # Division deals with global step stepping once per accumulated batch
    # Inequality deals with different global step for odd vs even num_training_batches
    n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
    expected_steps = self.trainer.num_training_batches / n_accum
    if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1:
        rank_zero_warn(
            "You're resuming from a checkpoint that ended mid-epoch."
            " Training will start from the beginning of the next epoch."
            " This can cause unreliable results if further training is done,"
            " consider using an end of epoch checkpoint.")

    # restore the optimizers
    optimizer_states = checkpoint['optimizer_states']
    for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
        optimizer.load_state_dict(opt_state)

        # move optimizer to GPU 1 weight at a time
        # avoids OOM
        if self.trainer.root_gpu is not None:
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda(self.trainer.root_gpu)

    # restore the lr schedulers
    lr_schedulers = checkpoint['lr_schedulers']
    for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
        scheduler['scheduler'].load_state_dict(lrs_state)
def train(self):
    """Run the main training loop (legacy version).

    Loads dataloaders, fires train-start events, then runs one training epoch
    per iteration until ``max_epochs``/``max_steps`` is reached or the early
    stop callback signals a stop. ``run_training_teardown`` is invoked on every
    exit path, including KeyboardInterrupt.
    """
    rank_zero_warn(
        'Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
        ' but will start from "0" in v0.8.0.', RuntimeWarning)

    # get model
    model = self.get_model()

    # load data
    # if reload_dataloaders_every_epoch, this is moved to the epoch loop
    if not self.reload_dataloaders_every_epoch:
        self.reset_train_dataloader(model)
    self.reset_val_dataloader(model)

    # Train start events
    with self.profiler.profile('on_train_start'):
        # callbacks
        self.on_train_start()
        # initialize early stop callback
        if self.early_stop_callback is not None:
            self.early_stop_callback.on_train_start(self, self.get_model())
        # model hooks
        model.on_train_start()

    try:
        # run all epochs
        for epoch in range(self.current_epoch, self.max_epochs):
            # reset train dataloader
            if self.reload_dataloaders_every_epoch:
                self.reset_train_dataloader(model)
            # set seed for distributed sampler (enables shuffling for each epoch)
            if self.use_ddp \
                    and hasattr(self.train_dataloader.sampler, 'set_epoch'):
                self.train_dataloader.sampler.set_epoch(epoch)

            # update training progress in trainer and model
            model.current_epoch = epoch
            self.current_epoch = epoch

            total_val_batches = 0
            is_val_epoch = False
            if not self.disable_validation and self.num_training_batches != float(
                    'inf'):
                # val can be checked multiple times in epoch
                is_val_epoch = (
                    self.current_epoch + 1) % self.check_val_every_n_epoch == 0
                val_checks_per_epoch = self.num_training_batches // self.val_check_batch
                val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
                total_val_batches = self.num_val_batches * val_checks_per_epoch

            # total batches includes multiple val checks
            self.total_batches = self.num_training_batches + total_val_batches

            # changing gradient according accumulation_scheduler
            self.accumulation_scheduler.on_epoch_start(
                self, self.get_model())

            # stores accumulated grad fractions per batch
            self.batch_loss_value = TensorRunningAccum(
                window_length=self.accumulate_grad_batches)

            if self.fast_dev_run:
                # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
                num_iterations = 2
            elif self.total_batches == float('inf'):
                # for infinite train or val loader, the progress bar never ends
                num_iterations = None
            else:
                num_iterations = self.total_batches

            # reset progress bar
            # .reset() doesn't work on disabled progress bar so we should check
            if not self.main_progress_bar.disable:
                self.main_progress_bar.reset(num_iterations)
            desc = f'Epoch {epoch + 1}'
            self.main_progress_bar.set_description(desc)

            # -----------------
            # RUN TNG EPOCH
            # -----------------
            self.run_training_epoch()

            # update LR schedulers
            self.update_learning_rates(interval='epoch')

            # max_steps takes priority over remaining epochs: stop immediately
            if self.max_steps and self.max_steps == self.global_step:
                self.run_training_teardown()
                return

            # early stopping
            met_min_epochs = epoch >= self.min_epochs - 1
            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

            # TODO wrap this logic into the callback
            if self.enable_early_stop:
                if (met_min_epochs and met_min_steps) or self.fast_dev_run:
                    should_stop = self.early_stop_callback.on_epoch_end(
                        self, self.get_model())
                    # stop training
                    stop = should_stop and met_min_epochs
                    if stop:
                        self.run_training_teardown()
                        return

        self.run_training_teardown()

    except KeyboardInterrupt:
        if self.proc_rank == 0:
            log.info(
                'Detected KeyboardInterrupt, attempting graceful shutdown...'
            )
        self.interrupted = True
        self.run_training_teardown()
def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
    """Resets the train dataloader and initialises required variables (number of batches, when to validate,
    etc.).

    Args:
        model: The `LightningModule` if calling this outside of the trainer scope.

    Raises:
        MisconfigurationException: if `limit_train_batches` or `val_check_interval`
            is incompatible with an IterableDataset (no usable length).
        ValueError: if `val_check_interval` exceeds the number of training batches.
    """
    self.train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)

    if self.overfit_batches > 0:
        self.train_dataloader = self._resolve_overfit_batches(self.train_dataloader)

    # automatically add samplers
    self.train_dataloader = apply_to_collection(
        self.train_dataloader, DataLoader, self.prepare_dataloader, shuffle=True, mode=RunningStage.TRAINING
    )

    # check the workers recursively
    apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader")

    # add worker_init_fn for correct seeding in worker processes
    apply_to_collection(self.train_dataloader, DataLoader, self._auto_add_worker_init_fn, rank=self.global_rank)

    # add collate_fn to collect metadata for fault tolerant training
    if _fault_tolerant_training():
        apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

    # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
    self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)

    module = model or self.lightning_module or self.datamodule
    # infinite sentinel when any rank's loader has no usable __len__
    self.num_training_batches = (
        len(self.train_dataloader)
        if has_len_all_ranks(self.train_dataloader, self.training_type_plugin, module)
        else float("inf")
    )

    # apply limit_train_batches: int caps the count, float scales it
    if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
        self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
    elif self.num_training_batches != float("inf"):
        self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
    elif self.limit_train_batches != 1.0:
        # a fractional limit other than 1.0 cannot be applied to an unsized loader
        raise MisconfigurationException(
            "When using an IterableDataset for `limit_train_batches`,"
            " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
            " `num_training_batches` to use."
        )

    # determine when to check validation
    # if int passed in, val checks that often
    # otherwise, it checks in [0, 1.0] % range of a training epoch
    if isinstance(self.val_check_interval, int):
        self.val_check_batch = self.val_check_interval
        if self.val_check_batch > self.num_training_batches:
            raise ValueError(
                f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
                f"to the number of the training batches ({self.num_training_batches}). "
                "If you want to disable validation set `limit_val_batches` to 0.0 instead."
            )
    else:
        if not has_len_all_ranks(self.train_dataloader, self.training_type_plugin, module):
            if self.val_check_interval == 1.0:
                # unsized loader: only end-of-epoch validation is possible
                self.val_check_batch = float("inf")
            else:
                raise MisconfigurationException(
                    "When using an IterableDataset for `train_dataloader`,"
                    " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
                    " checking validation every k training batches."
                )
        else:
            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
            # validate at least once per epoch
            self.val_check_batch = max(1, self.val_check_batch)

    if self.logger and self.num_training_batches < self.log_every_n_steps:
        rank_zero_warn(
            f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
            f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
            " you want to see logs for the training epoch."
        )
def _reset_eval_dataloader(
    self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
    """Generic method to reset a dataloader for evaluation.

    Args:
        mode: The running stage of the ``Trainer``
        model: The ``LightningModule`` if calling this outside of the trainer scope.

    Returns:
        Tuple (num_batches, dataloaders) where each entry of ``num_batches`` is
        an int, or ``float("inf")`` for loaders without a usable ``__len__``.

    Raises:
        MisconfigurationException: if batch limits are incompatible with an
            IterableDataset, or resolve to zero batches.

    Fixes over the previous revision:
    - the shuffle warning was missing a space between its two concatenated
      string literals ("...`shuffle=True`,it is strongly recommended...")
    - the IterableDataset error interpolated ``{mode}`` (the raw ``RunningStage``
      value, e.g. ``limit_validate_batches``) instead of the actual flag name
      ``limit_{mode.dataloader_prefix}_batches`` (e.g. ``limit_val_batches``)
    """
    assert mode.evaluating or mode == RunningStage.PREDICTING

    # always get the loaders first so we can count how many there are
    dataloaders = self.request_dataloader(mode, model=model)

    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    # when overfitting, use the training loader as val and test
    # duplicate it the numb of times needed to match the train loaders
    if self.overfit_batches > 0:
        train_dataloader = self.request_dataloader(RunningStage.TRAINING, model=model)
        dataloaders = [deepcopy(train_dataloader) for _ in range(len(dataloaders))]

    for loader_i in range(len(dataloaders)):
        loader = dataloaders[loader_i]

        # shuffling in evaluation is almost always unintended
        if hasattr(loader, "sampler") and not isinstance(loader.sampler, SequentialSampler):
            # when overfitting, the dataloader should not have sampler
            if self.overfit_batches > 0 and mode.evaluating:
                rank_zero_warn(
                    "You requested to overfit but enabled val/test dataloader shuffling."
                    " We are turning it off for you."
                )
                dataloaders[loader_i] = self._update_dataloader(
                    loader, SequentialSampler(loader.dataset), mode=mode
                )
            else:
                rank_zero_warn(
                    f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
                    " it is strongly recommended that you turn this off for val/test/predict dataloaders."
                )

    if any(dl is None for dl in dataloaders):
        rank_zero_warn("One of given dataloaders is None and it will be skipped.")

    # add samplers (None loaders are dropped here)
    dataloaders = [self.prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None]

    # add worker_init_fn for correct seeding in worker processes
    apply_to_collection(
        dataloaders, dtype=DataLoader, function=self._auto_add_worker_init_fn, rank=self.global_rank
    )

    loader_num_batches = []

    # determine number of batches
    # datasets could be none, 1 or 2+
    module = model or self.lightning_module or self.datamodule
    if len(dataloaders) != 0:
        for i, dataloader in enumerate(dataloaders):
            # infinite sentinel when any rank's loader has no usable __len__
            num_batches = (
                len(dataloader)
                if has_len_all_ranks(dataloader, self.training_type_plugin, module)
                else float("inf")
            )
            self._worker_check(dataloader, f"{mode.dataloader_prefix}_dataloader {i}")

            # percent or num_steps
            limit_eval_batches = getattr(self, f"limit_{mode.dataloader_prefix}_batches")

            # limit num batches either as a percent or num steps
            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
                num_batches = min(num_batches, int(limit_eval_batches))
            elif num_batches != float("inf"):
                num_batches = int(num_batches * limit_eval_batches)
            elif limit_eval_batches != 1.0:
                # a fractional limit other than 1.0 cannot be applied to an unsized loader
                raise MisconfigurationException(
                    f"When using an IterableDataset for `limit_{mode.dataloader_prefix}_batches`,"
                    f" `Trainer(limit_{mode.dataloader_prefix}_batches)` must be `0.0`, `1.0` or an int. An int k"
                    f" specifies `num_{mode.dataloader_prefix}_batches` to use."
                )

            # a fractional limit that truncates to zero batches is a user error
            if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                min_pct = 1.0 / len(dataloader)
                raise MisconfigurationException(
                    f"you requested to check {limit_eval_batches} of the `{mode.dataloader_prefix}_dataloader` but"
                    f" {limit_eval_batches}*{num_batches} < 1. Please increase the"
                    f" `limit_{mode.dataloader_prefix}_batches` flag. Try at least"
                    f" `limit_{mode.dataloader_prefix}_batches={min_pct}`"
                )

            loader_num_batches.append(num_batches)

    return loader_num_batches, dataloaders
def run_train(self) -> None:
    """Run the full optimization routine.

    Performs the sanity check, then iterates epochs (unbounded via
    ``itertools.count`` when ``max_epochs`` is falsy), running one training
    epoch per iteration. Training ends when ``max_steps`` is reached or when
    ``should_stop`` is set and the minimum epoch/step requirements are met.
    """
    self._pre_training_routine()

    # only rank 0 shows a progress bar
    if not self.is_global_zero and self.progress_bar_callback is not None:
        self.progress_bar_callback.disable()

    self.run_sanity_check(self.lightning_module)

    self.checkpoint_connector.has_trained = False

    # enable train mode
    model = self.lightning_module
    model.train()
    torch.set_grad_enabled(True)

    # reload data when needed
    self.train_loop.reset_train_val_dataloaders(model)

    # hook
    self.train_loop.on_train_start()

    try:
        if self.train_loop.should_skip_training():
            return
        # run all epochs; with no max_epochs, iterate indefinitely
        epochs = range(self.current_epoch, self.max_epochs) if self.max_epochs else count(self.current_epoch)
        for epoch in epochs:

            # hook
            self.train_loop.on_train_epoch_start(epoch)

            with self.profiler.profile("run_training_epoch"):
                # run train epoch
                self.train_loop.run_training_epoch()

            # max_steps takes priority over remaining epochs: stop immediately
            if self.max_steps and self.max_steps <= self.global_step:
                return

            # early stopping: both the epoch and step minimums must be satisfied
            met_min_epochs = (epoch >= self.min_epochs - 1) if self.min_epochs else True
            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

            if self.should_stop:
                if met_min_epochs and met_min_steps:
                    return
                else:
                    log.info(
                        'Trainer was signaled to stop but required minimum epochs'
                        f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                        ' not been met. Training will continue...'
                    )

        # hook
        # NOTE(review): on normal completion this fires here AND in `finally`
        # below — presumably `on_train_end` is idempotent; confirm in TrainLoop.
        self.train_loop.on_train_end()
    except KeyboardInterrupt:
        rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')
        # user could press Ctrl+c many times... only shutdown once
        if not self.interrupted:
            # NOTE(review): `interrupted` appears to be derived from `state`
            # elsewhere, since only `state` is assigned here — confirm.
            self.state = TrainerState.INTERRUPTED
            self.on_keyboard_interrupt()
    except (RuntimeError, AssertionError):
        # if an exception is raised, the finally block is executed and can hide the actual exception
        # that was initially raised if `on_train_end` also raises an exception. we want to avoid that
        # for assertions and other runtime errors so we aren't misled while debugging
        # NOTE(review): this prints the traceback but swallows the exception —
        # callers will not see the failure; deliberate per the comment above.
        print_exc()
    finally:
        # hook: runs on every exit path (returns, completion, interrupt, error)
        self.train_loop.on_train_end()
def init_optimizers(
        self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]:
    """Normalize the output of ``LightningModule.configure_optimizers``.

    Accepts every supported return shape (single optimizer, two lists, a dict,
    a list of dicts, a list of optimizers) and converts it into three parallel
    lists: optimizers, scheduler config dicts, and optimizer frequencies.

    Args:
        model: module whose ``configure_optimizers`` hook is called when the
            trainer has no attached ``lightning_module``.

    Returns:
        Tuple ``(optimizers, lr_schedulers, optimizer_frequencies)``.

    Raises:
        ValueError: if frequencies are given for only some optimizers.
        MisconfigurationException: if the hook's return shape is unsupported.
    """
    pl_module = self.lightning_module or model
    self._lightning_optimizers = None
    optim_conf = self.call_hook("configure_optimizers", pl_module=pl_module)
    if optim_conf is None:
        # fall back to a no-op optimizer so the fit loop still runs
        rank_zero_warn(
            "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
            UserWarning,
        )
        optim_conf = _MockOptimizer()

    optimizers, lr_schedulers, optimizer_frequencies = [], [], []
    monitor = None

    # single output, single optimizer
    if isinstance(optim_conf, Optimizer):
        optimizers = [optim_conf]
    # two lists, optimizer + lr schedulers
    elif (isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2
          and isinstance(optim_conf[0], list)
          and all(isinstance(opt, Optimizer) for opt in optim_conf[0])):
        opt, sch = optim_conf
        optimizers = opt
        lr_schedulers = sch if isinstance(sch, list) else [sch]
    # single dictionary
    elif isinstance(optim_conf, dict):
        optimizers = [optim_conf["optimizer"]]
        monitor = optim_conf.get("monitor", None)
        lr_schedulers = [optim_conf["lr_scheduler"]
                         ] if "lr_scheduler" in optim_conf else []
    # multiple dictionaries
    elif isinstance(optim_conf, (list, tuple)) and all(
            isinstance(d, dict) for d in optim_conf):
        optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
        # normalize each scheduler entry to a dict carrying its optimizer index
        scheduler_dict = (
            lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
            if isinstance(scheduler, dict) else {
                "scheduler": scheduler,
                "opt_idx": opt_idx
            })

        lr_schedulers = [
            scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
            for opt_idx, opt_dict in enumerate(optim_conf)
            if "lr_scheduler" in opt_dict
        ]
        optimizer_frequencies = [
            opt_dict["frequency"] for opt_dict in optim_conf
            if opt_dict.get("frequency", None) is not None
        ]
        # assert that if frequencies are present, they are given for all optimizers
        if optimizer_frequencies and len(optimizer_frequencies) != len(
                optimizers):
            raise ValueError(
                "A frequency must be given to each optimizer.")
    # single list or tuple, multiple optimizer
    elif isinstance(optim_conf, (list, tuple)) and all(
            isinstance(opt, Optimizer) for opt in optim_conf):
        optimizers = list(optim_conf)
    # unknown configuration
    else:
        raise MisconfigurationException(
            "Unknown configuration for model optimizers."
            " Output from `model.configure_optimizers()` should either be:\n"
            " * `torch.optim.Optimizer`\n"
            " * [`torch.optim.Optimizer`]\n"
            " * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n"
            ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
            ' * A list of the previously described dict format, with an optional "frequency" key (int)'
        )

    is_manual_optimization = not pl_module.automatic_optimization
    lr_schedulers = self.configure_schedulers(lr_schedulers, monitor,
                                              is_manual_optimization)
    _validate_scheduler_optimizer(optimizers, lr_schedulers)

    return optimizers, lr_schedulers, optimizer_frequencies
def configure_schedulers(
        self, schedulers: list, monitor: Optional[str],
        is_manual_optimization: bool) -> List[Dict[str, Any]]:
    """Convert each scheduler into dict structure with relevant information.

    Args:
        schedulers: scheduler instances and/or scheduler config dicts.
        monitor: metric name to monitor, used for ``ReduceLROnPlateau``
            schedulers that do not carry their own ``monitor`` key.
        is_manual_optimization: when True, keys that only apply to automatic
            optimization are stripped (with a warning) since the user steps
            schedulers manually.

    Returns:
        A list of scheduler dicts, each merged over the default config.

    Raises:
        MisconfigurationException: on missing/invalid keys or a missing
            monitor for ``ReduceLROnPlateau``.
        ValueError: if an entry is not a recognized scheduler type.
    """
    lr_schedulers = []
    default_config = _get_default_scheduler_config()
    for scheduler in schedulers:
        if is_manual_optimization:
            if isinstance(scheduler, dict):
                # these keys are only meaningful for automatic optimization
                invalid_keys = {
                    "interval", "frequency", "reduce_on_plateau", "monitor", "strict"
                }
                keys_to_warn = [
                    k for k in scheduler.keys() if k in invalid_keys
                ]

                if keys_to_warn:
                    rank_zero_warn(
                        f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
                        " You need to call `lr_scheduler.step()` manually in manual optimization.",
                        RuntimeWarning,
                    )

                scheduler = {
                    key: scheduler[key]
                    for key in scheduler if key not in invalid_keys
                }
                lr_schedulers.append({**default_config, **scheduler})
            else:
                lr_schedulers.append({
                    **default_config, "scheduler": scheduler
                })
        else:
            if isinstance(scheduler, dict):
                # check provided keys
                extra_keys = [
                    k for k in scheduler.keys()
                    if k not in default_config.keys()
                ]
                if extra_keys:
                    rank_zero_warn(
                        f"Found unsupported keys in the lr scheduler dict: {extra_keys}",
                        RuntimeWarning)
                if "scheduler" not in scheduler:
                    raise MisconfigurationException(
                        'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                    )
                if "interval" in scheduler and scheduler[
                        "interval"] not in ("step", "epoch"):
                    raise MisconfigurationException(
                        'The "interval" key in lr scheduler dict must be "step" or "epoch"'
                        f' but is "{scheduler["interval"]}"')

                # plateau schedulers need a monitored metric to step on
                scheduler["reduce_on_plateau"] = isinstance(
                    scheduler["scheduler"],
                    optim.lr_scheduler.ReduceLROnPlateau)

                if scheduler["reduce_on_plateau"] and scheduler.get(
                        "monitor", None) is None:
                    raise MisconfigurationException(
                        "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
                        ' For example: {"optimizer": optimizer, "lr_scheduler":'
                        ' {"scheduler": scheduler, "monitor": "your_loss"}}'
                    )

                lr_schedulers.append({**default_config, **scheduler})
            elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                if monitor is None:
                    raise MisconfigurationException(
                        "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
                        " scheduler is used. For example:"
                        ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
                    )
                lr_schedulers.append({
                    **default_config, "scheduler": scheduler,
                    "reduce_on_plateau": True,
                    "monitor": monitor
                })
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                lr_schedulers.append({
                    **default_config, "scheduler": scheduler
                })
            else:
                raise ValueError(
                    f'The provided lr scheduler "{scheduler}" is invalid')
    return lr_schedulers
def on_gpu(self, val: bool) -> None:
    """Deprecated setter for the internal GPU flag.

    Always warns; only a truthy ``val`` mutates the accelerator connector.
    """
    rank_zero_warn(
        "Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.",
        DeprecationWarning,
    )
    if not val:
        return
    self.accelerator_connector._device_type = DeviceType.GPU
def training_forward(self, batch, batch_idx, opt_idx, hiddens): """ Handle forward for each training case (distributed, single gpu, etc...) :param batch: :param batch_idx: :return: """ # --------------- # FORWARD # --------------- # enable not needing to add opt_idx to training_step args = [batch, batch_idx] if len(self.optimizers) > 1: if self.has_arg('training_step', 'optimizer_idx'): args.append(opt_idx) else: num_opts = len(self.optimizers) raise ValueError( f'Your LightningModule defines {num_opts} optimizers but ' f'training_step is missing the "optimizer_idx" argument.') # pass hiddens if using tbptt if self.truncated_bptt_steps is not None: args.append(hiddens) # distributed forward if self.use_ddp or self.use_ddp2 or self.use_dp: output = self.model(*args) # single GPU forward elif self.single_gpu: gpu_id = 0 if isinstance(self.data_parallel_device_ids, list): gpu_id = self.data_parallel_device_ids[0] batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id) args[0] = batch output = self.model.training_step(*args) # TPU support elif self.use_tpu: batch = self.transfer_batch_to_tpu(copy.copy(batch)) args[0] = batch output = self.model.training_step(*args) # CPU forward else: output = self.model.training_step(*args) # allow any mode to define training_step_end # do something will all the dp outputs (like softmax) if self.is_overriden('training_step_end'): model_ref = self.get_model() with self.profiler.profile('training_step_end'): output = model_ref.training_step_end(output) # allow any mode to define training_end # TODO: remove in 1.0.0 if self.is_overriden('training_end'): model_ref = self.get_model() with self.profiler.profile('training_end'): output = model_ref.training_end(output) rank_zero_warn( '`training_end` was deprecated in 0.7.0 and will be removed 1.0.0.' ' Use training_epoch_end instead', DeprecationWarning) return output
def use_ddp2(self) -> bool:
    """Deprecated internal flag: ``True`` when the DDP2 distributed strategy is active."""
    rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
    distrib_type = self.accelerator_connector._distrib_type
    return distrib_type == DistributedType.DDP2
def transform(
    self,
    sample: Any,
) -> Union[Classification, Classifications, Dict[str, Any]]:
    """Convert a raw model prediction into FiftyOne classification label(s).

    The prediction tensor is decoded either as multi-label (per-class sigmoid
    against ``self.threshold``) or single-label (argmax over softmax). Class
    indices are mapped to label strings via ``self._labels`` when available,
    otherwise the integer index is stringified (with a warning).

    Args:
        sample: either the raw prediction, or a dict containing it under
            ``DataKeys.PREDS`` (and ``DataKeys.METADATA`` when
            ``self.return_filepath`` is set).

    Returns:
        A :class:`fol.Classification` (or ``None`` when the single-label
        confidence falls below the threshold), a :class:`fol.Classifications`
        in the multi-label case, or a ``{"filepath", "predictions"}`` dict
        when ``self.return_filepath`` is set.
    """
    pred = sample[DataKeys.PREDS] if isinstance(sample, Dict) else sample
    pred = torch.tensor(pred)

    logits = None
    if self.store_logits:
        logits = pred.tolist()

    if self.multi_label:
        # independent per-class decisions: every class above threshold is kept
        one_hot = (pred.sigmoid() > self.threshold).int().tolist()
        classes = [index for index, value in enumerate(one_hot) if value == 1]
        probabilities = torch.sigmoid(pred).tolist()
    else:
        # mutually-exclusive classes: single argmax over the softmax distribution
        classes = torch.argmax(pred, -1).tolist()
        probabilities = torch.softmax(pred, -1).tolist()

    # Resolve how a class index becomes a label string. This replaces two
    # previously duplicated branches that differed only in this lookup.
    if self._labels is not None:
        def _label(idx):
            return self._labels[idx]
    else:
        rank_zero_warn("No labels were provided, int targets will be used as label strings.", category=UserWarning)
        _label = str

    if self.multi_label:
        classifications = [
            fol.Classification(
                label=_label(idx),
                confidence=probabilities[idx],
            )
            for idx in classes
        ]
        fo_predictions = fol.Classifications(
            classifications=classifications,
            logits=logits,
        )
    else:
        confidence = max(probabilities)

        if self.threshold is not None and confidence < self.threshold:
            # below-threshold single-label predictions are dropped entirely
            fo_predictions = None
        else:
            fo_predictions = fol.Classification(
                label=_label(classes),
                confidence=confidence,
                logits=logits,
            )

    if self.return_filepath:
        filepath = sample[DataKeys.METADATA]["filepath"]
        return {"filepath": filepath, "predictions": fo_predictions}
    return fo_predictions
def __init__(
    self,
    logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
    checkpoint_callback: Union[ModelCheckpoint, bool] = True,
    early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
    callbacks: Optional[List[Callback]] = None,
    default_root_dir: Optional[str] = None,
    gradient_clip_val: float = 0,
    process_position: int = 0,
    num_nodes: int = 1,
    num_processes: int = 1,
    gpus: Optional[Union[List[int], str, int]] = None,
    auto_select_gpus: bool = False,
    tpu_cores: Optional[Union[List[int], int]] = None,
    log_gpu_memory: Optional[str] = None,
    progress_bar_refresh_rate: int = 1,
    overfit_pct: float = 0.0,
    track_grad_norm: Union[int, float, str] = -1,
    check_val_every_n_epoch: int = 1,
    fast_dev_run: bool = False,
    accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
    max_epochs: int = 1000,
    min_epochs: int = 1,
    max_steps: Optional[int] = None,
    min_steps: Optional[int] = None,
    train_percent_check: float = 1.0,
    val_percent_check: float = 1.0,
    test_percent_check: float = 1.0,
    val_check_interval: float = 1.0,
    log_save_interval: int = 100,
    row_log_interval: int = 50,
    distributed_backend: Optional[str] = None,
    precision: int = 32,
    print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
    weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
    weights_save_path: Optional[str] = None,
    num_sanity_val_steps: int = 2,
    truncated_bptt_steps: Optional[int] = None,
    resume_from_checkpoint: Optional[str] = None,
    profiler: Optional[Union[BaseProfiler, bool]] = None,
    benchmark: bool = False,
    deterministic: bool = False,
    reload_dataloaders_every_epoch: bool = False,
    auto_lr_find: Union[bool, str] = False,
    replace_sampler_ddp: bool = True,
    terminate_on_nan: bool = False,
    auto_scale_batch_size: Union[str, bool] = False,
    prepare_data_per_node: bool = True,
    amp_level: str = 'O1',  # backward compatible, todo: remove in v1.0.0
    num_tpu_cores: Optional[int] = None,  # backward compatible, todo: remove in v0.9.0
    use_amp=None,  # backward compatible, todo: remove in v0.9.0
    show_progress_bar=None,  # backward compatible, todo: remove in v0.9.0
):
    r"""
    Customize every aspect of training via flags

    Args:
        logger: Logger (or iterable collection of loggers) for experiment tracking.
        checkpoint_callback: Callback for checkpointing.
        early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`):
        callbacks: Add a list of callbacks.
        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed
        gradient_clip_val: 0 means don't clip.
        process_position: orders the progress bar when running multiple models on same machine.
        num_nodes: number of GPU nodes for distributed training.
        num_processes: number of processes to run; only used with ``distributed_backend="ddp_cpu"``.
        gpus: Which GPUs to train on.
        auto_select_gpus: If enabled and `gpus` is an integer, pick available gpus automatically.
            This is especially useful when GPUs are configured to be in "exclusive mode",
            such that only one process at a time can access them.
        tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]
        num_tpu_cores: How many TPU cores to train on (1 or 8)
            .. warning:: .. deprecated:: 0.7.6. Use `tpu_cores` instead. Will remove 0.9.0.
        log_gpu_memory: None, 'min_max', 'all'. Might slow performance
        show_progress_bar:
            .. warning:: .. deprecated:: 0.7.2
                Set `progress_bar_refresh_rate` to positive integer to enable. Will remove 0.9.0.
        progress_bar_refresh_rate: How often to refresh progress bar (in steps).
            Value ``0`` disables progress bar. Ignored when a custom callback is passed
            to :paramref:`~Trainer.callbacks`.
        overfit_pct: How much of training-, validation-, and test dataset to check.
        track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf'
            infinity-norm.
        check_val_every_n_epoch: Check val every n train epochs.
        fast_dev_run: runs 1 batch of train, test and val to find any bugs
            (ie: a sort of unit test).
        accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
        max_epochs: Stop training once this number of epochs is reached.
        min_epochs: Force training for at least these many epochs
        max_steps: Stop training after this number of steps. Disabled by default (None).
        min_steps: Force training for at least these number of steps. Disabled by default (None).
        train_percent_check: How much of training dataset to check.
        val_percent_check: How much of validation dataset to check.
        test_percent_check: How much of test dataset to check.
        val_check_interval: How often within one training epoch to check the validation set
        log_save_interval: Writes logs to disk this often
        row_log_interval: How often to add logging rows (does not write to disk)
        distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn)
        use_amp:
            .. warning:: .. deprecated:: 0.7.0
                Use `precision` instead. Will remove 0.9.0.
        precision: Full precision (32), half precision (16).
        print_nan_grads:
            .. warning:: .. deprecated:: 0.7.2
                Has no effect. When detected, NaN grads will be printed automatically.
                Will remove 0.9.0.
        weights_summary: Prints a summary of the weights when training begins.
        weights_save_path: Where to save weights if specified. Will override default_root_dir
            for checkpoints only. Use this if for whatever reason you need the checkpoints
            stored in a different place than the logs written in `default_root_dir`.
        amp_level: The optimization level to use (O1, O2, etc...).
        num_sanity_val_steps: Sanity check runs n batches of val before starting
            the training routine.
        truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of
        resume_from_checkpoint: To resume training from a specific checkpoint pass in the path
            here. This can be a URL.
        profiler: To profile individual steps during training and assist in
        reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch
        auto_lr_find: If set to True, will `initially` run a learning rate finder,
            trying to optimize initial learning for faster convergence. Sets learning
            rate in self.lr or self.learning_rate in the LightningModule.
            To use a different key, set a string instead of True with the key name.
        replace_sampler_ddp: Explicitly enables or disables sampler replacement.
            If not specified this will toggled automatically ddp is used
        benchmark: If true enables cudnn.benchmark.
        deterministic: If true enables cudnn.deterministic
        terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`)
            at the end of each training batch, if any of the parameters or the loss are
            NaN or +/-inf.
        auto_scale_batch_size: If set to True, will `initially` run a batch size finder trying
            to find the largest batch size that fits into memory. The result will be stored
            in self.batch_size in the LightningModule. Additionally, can be set to either
            `power` that estimates the batch size through a power search or `binsearch`
            that estimates the batch size through a binary search.
        prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
            Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
    """
    super().__init__()

    self.deterministic = deterministic
    torch.backends.cudnn.deterministic = self.deterministic
    if self.deterministic:
        # fixing non-deterministic part of horovod
        # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
        os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

    # Init callbacks
    self.prepare_data_per_node = prepare_data_per_node
    self.callbacks = callbacks or []
    self.on_init_start()

    # benchmarking
    self.benchmark = benchmark
    torch.backends.cudnn.benchmark = self.benchmark

    # Transfer params
    self.num_nodes = num_nodes
    self.log_gpu_memory = log_gpu_memory

    self.gradient_clip_val = gradient_clip_val
    self.check_val_every_n_epoch = check_val_every_n_epoch

    if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
        raise MisconfigurationException(
            "track_grad_norm can be an int, a float or 'inf' (infinity norm).")
    self.track_grad_norm = float(track_grad_norm)

    self.on_gpu = True if (gpus and torch.cuda.is_available()) else False

    # tpu config
    if num_tpu_cores is not None:
        rank_zero_warn("Argument `num_tpu_cores` is now set by `tpu_cores` since v0.7.6"
                       " and this argument will be removed in v0.9.0", DeprecationWarning)

    if tpu_cores is None:
        # fall back to the deprecated argument for backward compatibility
        tpu_cores = num_tpu_cores
    self.on_tpu = tpu_cores is not None
    self.tpu_cores = tpu_cores
    assert self.tpu_cores in (1, 8, None) or (
        isinstance(self.tpu_cores, (list, tuple, set)) and len(self.tpu_cores) == 1
    ), '`tpu_cores` can only be 1, 8 or [<1-8>]'

    self.tpu_id = tpu_cores[0] if isinstance(tpu_cores, list) else None

    if num_processes != 1 and distributed_backend != "ddp_cpu":
        rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
    self.num_processes = num_processes

    self.weights_summary = weights_summary

    self.max_epochs = max_epochs
    self.min_epochs = min_epochs
    self.max_steps = max_steps
    self.min_steps = min_steps

    self.num_sanity_val_steps = num_sanity_val_steps

    # Backward compatibility, TODO: remove in v0.9.0
    if print_nan_grads:
        rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
                       " NaN grads will be printed automatically when detected.", DeprecationWarning)

    self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

    self.auto_lr_find = auto_lr_find
    self.auto_scale_batch_size = auto_scale_batch_size
    self._is_data_prepared = False
    self.replace_sampler_ddp = replace_sampler_ddp

    self.truncated_bptt_steps = truncated_bptt_steps
    self.resume_from_checkpoint = resume_from_checkpoint
    self.terminate_on_nan = terminate_on_nan
    self.shown_warnings = set()

    self.fast_dev_run = fast_dev_run
    if self.fast_dev_run:
        # fast_dev_run overrides sanity checks and epoch count for a 1-batch smoke test
        self.num_sanity_val_steps = 0
        self.max_epochs = 1
        rank_zero_info('Running in fast_dev_run mode: will run a full train,'
                       ' val and test loop using a single batch')

    # set default save path if user didn't provide one
    self.default_root_dir = default_root_dir
    if self.default_root_dir is None:
        self.default_root_dir = os.getcwd()

    # training bookeeping
    self.total_batch_idx = 0
    self.running_loss = TensorRunningAccum(window_length=20)
    self.batch_idx = 0
    self.progress_bar_metrics = {}
    self.callback_metrics = {}
    self.num_val_batches = 0
    self.num_training_batches = 0
    self.num_test_batches = 0
    self.train_dataloader = None
    self.test_dataloaders = None
    self.val_dataloaders = None

    # training state
    self.model = None
    self.testing = False
    self.disable_validation = False
    self.lr_schedulers = []
    self.optimizers = None
    self.optimizer_frequencies = []
    self.global_step = 0
    self.current_epoch = 0
    self.interrupted = False

    # configure logger
    self.configure_logger(logger)

    # configure profiler
    if profiler is True:
        profiler = SimpleProfiler()
    self.profiler = profiler or PassThroughProfiler()

    # configure early stop callback
    # creates a default one if none passed in
    self.configure_early_stopping(early_stop_callback)

    # configure checkpoint callback
    self.checkpoint_callback = checkpoint_callback
    self.weights_save_path = weights_save_path

    # accumulated grads
    self.accumulate_grad_batches = accumulate_grad_batches
    self.configure_accumulated_gradients(accumulate_grad_batches)

    # for gpus allow int, string and gpu list
    if auto_select_gpus and isinstance(gpus, int):
        self.gpus = pick_multiple_gpus(gpus)
    else:
        self.gpus = gpus

    self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
    self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
    self.root_device = torch.device("cpu")

    # tpu state flags
    self.use_tpu = False
    self.tpu_local_core_rank = None
    self.tpu_global_core_rank = None

    # distributed backend choice
    self.distributed_backend = distributed_backend
    self.set_distributed_mode(distributed_backend)

    # override dist backend when using tpus
    if self.on_tpu:
        self.init_tpu()

    # init flags for SLURM+ddp to work
    self.world_size = 1
    self.interactive_ddp_procs = []
    self.configure_slurm_ddp(self.num_nodes)
    self.node_rank = self.determine_ddp_node_rank()
    self.local_rank = self.determine_local_rank()
    self.global_rank = 0

    # nvidia setup
    self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

    # backward compatibility
    if show_progress_bar is not None:
        self.show_progress_bar = show_progress_bar

    self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

    # logging
    self.log_save_interval = log_save_interval
    self.val_check_interval = val_check_interval
    self.row_log_interval = row_log_interval

    # how much of the data to use
    self.overfit_pct = overfit_pct
    self.determine_data_use_amount(train_percent_check, val_percent_check,
                                   test_percent_check, overfit_pct)

    # AMP init
    # These are the only lines needed after v0.8.0
    # we wrap the user's forward with autocast and give it back at the end of fit
    self.autocast_original_forward = None
    self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
    self.precision = precision
    self.scaler = None

    self.amp_level = amp_level
    self.init_amp(use_amp)

    self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

    # Callback system
    self.on_init_end()
def accelerator_backend(self) -> Accelerator:
    """Deprecated alias for :attr:`Trainer.accelerator`; warns on every access."""
    message = (
        "The `Trainer.accelerator_backend` attribute is deprecated in favor of `Trainer.accelerator`"
        " since 1.2 and will be removed in v1.4."
    )
    rank_zero_warn(message, DeprecationWarning)
    return self.accelerator
def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
    """Emit a rank-zero warning when checkpoints will land in a non-empty directory."""
    if self.save_top_k == 0:
        # save_top_k == 0 means no checkpoints will be written, so nothing to warn about
        return
    if self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
        rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
def get_model(self) -> LightningModule:
    """Deprecated accessor for the attached module; use ``Trainer.lightning_module``."""
    message = (
        "The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`"
        " and will be removed in v1.4."
    )
    rank_zero_warn(message, DeprecationWarning)
    return self.lightning_module
def lr_find(
        trainer,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        min_lr: float = 1e-8,
        max_lr: float = 1,
        num_training: int = 100,
        mode: str = 'exponential',
        early_stop_threshold: float = 4.0,
        datamodule: Optional[LightningDataModule] = None,
):
    r"""
    `lr_find` enables the user to do a range test of good initial learning rates,
    to reduce the amount of guesswork in picking a good starting learning rate.

    Args:
        model: Model to do range testing for

        train_dataloader: A PyTorch
            `DataLoader` with training samples. If the model has
            a predefined train_dataloader method, this will be skipped.

        val_dataloaders: Optional validation dataloader(s) forwarded to ``trainer.fit``.

        min_lr: minimum learning rate to investigate

        max_lr: maximum learning rate to investigate

        num_training: number of learning rates to test

        mode: search strategy, either 'linear' or 'exponential'. If set to
            'linear' the learning rate will be searched by linearly increasing
            after each batch. If set to 'exponential', will increase learning
            rate exponentially.

        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss
            then the search is stopped. To disable, set to None.

        datamodule: An optional `LightningDataModule` which holds the training
            and validation dataloader(s). Note that the `train_dataloader` and
            `val_dataloaders` parameters cannot be used at the same time as
            this parameter, or a `MisconfigurationException` will be raised.

    Example::

        # Setup model and trainer
        model = MyModelClass(hparams)
        trainer = pl.Trainer()

        # Run lr finder
        lr_finder = trainer.tuner.lr_find(model, ...)

        # Inspect results
        fig = lr_finder.plot(); fig.show()
        suggested_lr = lr_finder.suggestion()

        # Overwrite lr and create new model
        hparams.lr = suggested_lr
        model = MyModelClass(hparams)

        # Ready to train with new learning rate
        trainer.fit(model)
    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
        return

    # temporary checkpoint so the model/trainer can be restored after the sweep
    save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')

    # stash the trainer attributes this routine is about to overwrite
    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.schedulers = [], [],
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler: swap in a scheduler that sweeps the lr range
    model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.fit(model,
                train_dataloader=train_dataloader,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule)

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info('LR finder stopped early due to diverging loss.')

    # Transfer results from callback to lr finder object
    lr_finder.results.update({'lr': trainer.callbacks[0].lrs,
                              'loss': trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

    # Reset model state
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    return lr_finder
def use_tpu(self) -> bool:
    """Deprecated internal flag mirroring ``on_tpu``."""
    rank_zero_warn("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
    on_tpu = self.on_tpu
    return on_tpu
def _evaluate(
        self,
        model: LightningModule,
        dataloaders: List[DataLoader],
        max_batches: Union[int, List[int]],
        test_mode: bool = False
):
    """Run evaluation code.

    Args:
        model: The model to evaluate.
        dataloaders: A list of PyTorch dataloaders.
        max_batches: An integer or list of integers with length of the number of dataloaders. Each
            entry is the number of batches to process in the corresponding dataloader.
        test_mode: when ``True`` runs the test hooks, otherwise the validation hooks.

    Returns:
        The result of the (deprecated) ``test_end``/``validation_end`` or the
        ``*_epoch_end`` hook, or an empty dict when none is overridden.
    """
    # enable eval mode
    model.zero_grad()
    model.eval()

    # copy properties for forward overrides
    self.copy_trainer_model_properties(model)

    # disable gradients to save memory
    torch.set_grad_enabled(False)

    # bookkeeping
    outputs = []

    # convert max_batches to list so it can be indexed per dataloader
    if isinstance(max_batches, int):
        max_batches = [max_batches] * len(dataloaders)

    # run validation
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []

        # on TPU we have to wrap it under the ParallelLoader
        if self.use_tpu:
            device = xm.xla_device(self.tpu_id)
            dataloader = xla_pl.ParallelLoader(dataloader, [device])
            dataloader = dataloader.per_device_loader(device)

        # each dataloader has a max num batches
        dl_max_batches = max_batches[dataloader_idx]

        for batch_idx, batch in enumerate(dataloader):
            if batch is None:
                continue

            # stop short when on fast_dev_run (sets max_batch=1)
            if batch_idx >= dl_max_batches:
                break

            # callbacks
            if test_mode:
                self.on_test_batch_start()
            else:
                self.on_validation_batch_start()

            # -----------------
            # RUN EVALUATION STEP
            # -----------------
            if self.use_amp and NATIVE_AMP_AVALAIBLE:
                with torch.cuda.amp.autocast():
                    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
            else:
                output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

            # on dp / ddp2 might still want to do something with the batch parts
            if test_mode:
                if self.is_overridden('test_step_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('test_step_end'):
                        output = model_ref.test_step_end(output)
                self.on_test_batch_end()
            else:
                if self.is_overridden('validation_step_end'):
                    model_ref = self.get_model()
                    with self.profiler.profile('validation_step_end'):
                        output = model_ref.validation_step_end(output)
                self.on_validation_batch_end()

            # track outputs for collation
            dl_outputs.append(output)

        outputs.append(dl_outputs)

    eval_results = {}

    # with a single dataloader don't pass an array
    if len(dataloaders) == 1:
        outputs = outputs[0]

    # give model a chance to do something with the outputs (and method defined)
    if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
        # unwrap the DP/DDP wrapper so hook overrides are found on the user module
        model = model.module

    if test_mode:
        if self.is_overridden('test_end', model=model):
            # TODO: remove in v1.0.0
            eval_results = model.test_end(outputs)
            rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed v1.0.'
                           ' Use `test_epoch_end` instead.', DeprecationWarning)

        elif self.is_overridden('test_epoch_end', model=model):
            eval_results = model.test_epoch_end(outputs)

    else:
        if self.is_overridden('validation_end', model=model):
            # TODO: remove in v1.0.0
            eval_results = model.validation_end(outputs)
            rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed v1.0.'
                           ' Use `validation_epoch_end` instead.', DeprecationWarning)

        elif self.is_overridden('validation_epoch_end', model=model):
            eval_results = model.validation_epoch_end(outputs)

    # aggregate ddp stats across processes
    has_content = eval_results is not None and len(eval_results) > 0
    if has_content and (self.use_ddp or self.use_ddp2):
        self.reduce_eval_ddp(eval_results)

    # enable train mode again
    model.train()

    # re-enable gradients (disabled above to save memory during eval)
    torch.set_grad_enabled(True)

    return eval_results
def use_tpu(self, val: bool) -> None:
    """Deprecated internal setter mirroring ``on_tpu``."""
    rank_zero_warn("Internal: `use_tpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
    # kept for backward compatibility: simply forwards to the new attribute
    self.on_tpu = val
def best(self):
    """Deprecated alias for ``best_model_score``."""
    message = (
        "Attribute `best` has been renamed to `best_model_score` since v0.8.0"
        " and will be removed in v0.10.0"
    )
    rank_zero_warn(message, DeprecationWarning)
    return self.best_model_score
def on_gpu(self) -> bool:
    """Deprecated internal flag: ``True`` when the trainer's device type is GPU."""
    rank_zero_warn("Internal: `on_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
    device_type = self.accelerator_connector._device_type
    return device_type == DeviceType.GPU
def on_validation_end(self, trainer, pl_module):
    """Save checkpoint(s) after validation, honoring period, top-k and save_last settings."""
    # only run on main process
    if trainer.global_rank != 0:
        return

    metrics = trainer.callback_metrics
    epoch = trainer.current_epoch

    # support structured results: a `checkpoint_on` metric overrides the monitor
    if metrics.get('checkpoint_on') is not None:
        self.monitor = 'checkpoint_on'

    if self.save_top_k == 0:
        # no models are saved
        return
    if self.epoch_last_check is not None and (epoch - self.epoch_last_check) < self.period:
        # skipping in this term
        return

    self.epoch_last_check = epoch

    if self.save_last:
        filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt')
        self._save_model(filepath, trainer, pl_module)

    filepath = self.format_checkpoint_name(epoch, metrics)
    version_cnt = 0
    while os.path.isfile(filepath):
        filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
        # this epoch called before
        version_cnt += 1

    if self.save_top_k != -1:
        current = metrics.get(self.monitor)

        if not isinstance(current, torch.Tensor):
            rank_zero_warn(
                f'The metric you returned {current} must be a `torch.Tensor` instance, checkpoint not saved'
                f' HINT: what is the value of {self.monitor} in validation_epoch_end()?', RuntimeWarning)
            if current is not None:
                # best-effort conversion so top-k comparison can still proceed
                current = torch.tensor(current)

        if current is None:
            rank_zero_warn(
                f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning)
        elif self.check_monitor_top_k(current):
            self._do_check_save(filepath, current, epoch, trainer, pl_module)
        elif self.verbose > 0:
            log.info(
                f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')

    else:
        if self.verbose > 0:
            log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')

        assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
        self._save_model(filepath, trainer, pl_module)
def _reset_eval_dataloader(
        self,
        model: LightningModule,
        mode: str) -> Tuple[Union[int, float], List[DataLoader]]:
    """Generic method to reset a dataloader for evaluation.

    Args:
        model: The current `LightningModule`
        mode: Either `'val'` or `'test'`

    Returns:
        Tuple (num_batches, dataloaders); ``num_batches`` is ``float('inf')``
        when any dataloader has no ``__len__``.
    """
    dataloaders = self.request_dataloader(getattr(model, f'{mode}_dataloader'))

    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    # shuffling in val and test set is bad practice
    for loader in dataloaders:
        if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
            rank_zero_warn(
                f'Your {mode}_dataloader has shuffle=True, it is best practice to turn'
                ' this off for validation and test dataloaders.')

    if any([dl is None for dl in dataloaders]):
        rank_zero_warn("One of given dataloaders is None and it will be skipped.")

    # add samplers (None loaders are dropped here)
    dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl is not None]

    num_batches = 0

    # determine number of batches
    # datasets could be none, 1 or 2+
    if len(dataloaders) != 0:
        for i, dataloader in enumerate(dataloaders):
            self._worker_check(dataloader, f'{mode} dataloader {i}')
            if not _has_len(dataloader):
                # any length-less (iterable-style) dataloader makes the total unbounded
                num_batches = float('inf')

        percent_check = getattr(self, f'{mode}_percent_check')

        if num_batches != float('inf'):
            self._percent_range_check(f'{mode}_percent_check')

            num_batches = sum(len(dataloader) for dataloader in dataloaders)
            num_batches = int(num_batches * percent_check)
        elif percent_check not in (0.0, 1.0):
            raise MisconfigurationException(
                'When using an infinite DataLoader (e.g. with an IterableDataset'
                f' or when DataLoader does not implement `__len__`) for `{mode}_dataloader`,'
                f' `Trainer({mode}_percent_check)` must be `0.0` or `1.0`.')

    return num_batches, dataloaders