def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
    """Collect the batch outputs that have to be reduced at epoch end.

    Args:
        epoch_output: indexable container holding one appendable bucket per
            optimizer index; it is mutated in place.
        batch_end_outputs: iterable yielding, per optimizer index, the
            sequence of step outputs produced for the current batch.
    """
    for bucket_idx, step_outputs in enumerate(batch_end_outputs):
        # the last step output acts as the representative sample
        last_output = step_outputs[-1]

        # a Result object may ask to be reduced automatically at epoch end
        wants_auto_reduce = isinstance(last_output, Result) and last_output.should_reduce_on_epoch_end

        # the user may also reduce manually through one of the epoch-end hooks
        user_reduces = (
            is_overridden("training_epoch_end", model=self.trainer.get_model())
            or is_overridden("on_train_epoch_end", model=self.trainer.get_model())
        )

        # only track when somebody is going to consume the outputs
        if user_reduces or wants_auto_reduce:
            # with 1 step (no tbptt) don't use a sequence at epoch end
            lone_plain_step = (
                isinstance(step_outputs, list)
                and len(step_outputs) == 1
                and not isinstance(step_outputs[0], Result)
            )
            tracked = step_outputs[0] if lone_plain_step else step_outputs
            epoch_output[bucket_idx].append(tracked)
def __verify_train_loop_configuration(self, model):
    """Check that ``model`` overrides the hooks the training loop needs.

    Raises:
        MisconfigurationException: for the first of ``training_step``,
            ``train_dataloader`` or ``configure_optimizers`` (checked in that
            order) that ``is_overridden`` reports as missing.
    """
    required_hooks = ('training_step', 'train_dataloader', 'configure_optimizers')
    for hook in required_hooks:
        if is_overridden(hook, model):
            continue
        raise MisconfigurationException(
            f'No `{hook}()` method defined. Lightning `Trainer` expects as minimum a'
            ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
        )
def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
    """Run the user's eval epoch-end hook, or auto-reduce, and return a list.

    Args:
        num_dataloaders: number of eval dataloaders; with exactly one, the
            hook receives ``self.outputs[0]`` rather than the full list.
        using_eval_result: whether the step outputs are result objects that
            need gathering (hook overridden) or auto-reduction (hook absent).

    Returns:
        The epoch-end results, wrapped in a list when they are not one already.
    """
    model = self.trainer.get_model()
    outputs = self.outputs

    # with a single dataloader don't pass an array
    eval_results = outputs[0] if num_dataloaders == 1 else outputs

    # the hook that applies depends on whether we are testing or validating
    epoch_end_hook = 'test_epoch_end' if self.testing else 'validation_epoch_end'

    user_reduced = False
    if is_overridden(epoch_end_hook, model=model):
        if using_eval_result:
            eval_results = self.__gather_epoch_end_eval_results(outputs)
        eval_results = getattr(model, epoch_end_hook)(eval_results)
        user_reduced = True

    if using_eval_result and not user_reduced:
        eval_results = self.__auto_reduce_result_objs(outputs)

    return eval_results if isinstance(eval_results, list) else [eval_results]
def reset_val_dataloader(self, model: LightningModule) -> None:
    """Reset the validation dataloader(s) and their batch counts.

    Nothing is changed unless ``model`` overrides both ``val_dataloader`` and
    ``validation_step``.

    Args:
        model: The current `LightningModule`
    """
    loader_defined = is_overridden('val_dataloader', model)
    step_defined = is_overridden('validation_step', model)
    if not (loader_defined and step_defined):
        return

    batches_and_loaders = self._reset_eval_dataloader(model, 'val')
    self.num_val_batches, self.val_dataloaders = batches_and_loaders
def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
    """Run the eval epoch-end hook, cache logged metrics and return a list.

    Args:
        num_dataloaders: number of eval dataloaders; with exactly one, the
            hook receives ``self.outputs[0]`` rather than the full list.
        using_eval_result: whether step outputs are result objects that need
            gathering (hook overridden) or auto-reduction (hook absent).

    Returns:
        The epoch-end results, wrapped in a list when they are not one already.
    """
    model = self.trainer.get_model()

    # with a single dataloader don't pass an array
    outputs = self.outputs
    eval_results = outputs
    if num_dataloaders == 1:
        eval_results = outputs[0]

    # the same name is used for the override check, the call and the warning,
    # so the warning always names the hook that was actually invoked
    hook_name = 'test_epoch_end' if self.testing else 'validation_epoch_end'

    user_reduced = False
    if is_overridden(hook_name, model=model):
        if using_eval_result:
            eval_results = self.__gather_epoch_end_eval_results(outputs)
        model._current_fx_name = hook_name
        eval_results = getattr(model, hook_name)(eval_results)
        user_reduced = True

    # capture logging
    self.trainer.logger_connector.cache_logged_metrics()

    # deprecation warning: the user hook returned something
    if eval_results is not None and user_reduced:
        self.warning_cache.warn(
            f'The {hook_name} should not return anything as of 9.1.'
            ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
        )

    if using_eval_result and not user_reduced:
        eval_results = self.__auto_reduce_result_objs(outputs)

    # fall back to metrics recorded on the model when nothing else was produced
    result = model._results
    if len(result) > 0 and eval_results is None:
        eval_results = result.get_epoch_log_metrics()

    if not isinstance(eval_results, list):
        eval_results = [eval_results]

    # track deprecated metrics
    self.trainer.logger_connector.track_metrics_deprecated(
        eval_results, using_eval_result, self.testing)

    return eval_results
def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
    """Reset the model's results, run the eval epoch-end hook, return a list.

    Args:
        num_dataloaders: number of eval dataloaders; with exactly one, the
            hook receives ``self.outputs[0]`` rather than the full list.
        using_eval_result: whether step outputs are result objects that need
            gathering (hook overridden) or auto-reduction (hook absent).

    Returns:
        The epoch-end results, wrapped in a list when they are not one already.
    """
    model = self.trainer.get_model()

    # reset results
    model._results = Result()

    # with a single dataloader don't pass an array
    outputs = self.outputs
    eval_results = outputs
    if num_dataloaders == 1:
        eval_results = outputs[0]

    # the same name is used for the override check, the call and the warning,
    # so the warning always names the hook that was actually invoked
    hook_name = 'test_epoch_end' if self.testing else 'validation_epoch_end'

    user_reduced = False
    if is_overridden(hook_name, model=model):
        model._current_fx_name = hook_name
        if using_eval_result:
            eval_results = self.__gather_epoch_end_eval_results(outputs)
        eval_results = getattr(model, hook_name)(eval_results)
        user_reduced = True

    # deprecation warning: the user hook returned something
    if eval_results is not None and user_reduced:
        self.warning_cache.warn(
            f'The {hook_name} should not return anything as of 9.1.'
            ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
        )

    if using_eval_result and not user_reduced:
        eval_results = self.__auto_reduce_result_objs(outputs)

    if not isinstance(eval_results, list):
        eval_results = [eval_results]

    return eval_results
def __verify_train_loop_configuration(self, model):
    """Validate the training hooks and the optimizer-override configuration.

    Besides validating, this records on the trainer whether ``optimizer_step``
    and ``optimizer_zero_grad`` are overridden by ``model``.

    Raises:
        MisconfigurationException: when ``training_step``, ``train_dataloader``
            or ``configure_optimizers`` is not overridden, or when an
            optimizer hook is overridden while gradient accumulation and
            automatic optimization are both active.
    """
    # the three hooks every training run needs, checked in this order
    for hook in ('training_step', 'train_dataloader', 'configure_optimizers'):
        if not is_overridden(hook, model):
            raise MisconfigurationException(
                f'No `{hook}()` method defined. Lightning `Trainer` expects as minimum a'
                ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
            )

    trainer = self.trainer
    trainer.overriden_optimizer_step = is_overridden('optimizer_step', model)
    trainer.overriden_optimizer_zero_grad = is_overridden('optimizer_zero_grad', model)

    automatic_optimization = trainer.train_loop.automatic_optimization
    accumulates_grads = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()
    overrides_optimizer_hooks = trainer.overriden_optimizer_step or trainer.overriden_optimizer_zero_grad

    conflicting_setup = overrides_optimizer_hooks and accumulates_grads and automatic_optimization
    if not conflicting_setup:
        return

    raise MisconfigurationException(
        'When overriding `LightningModule` optimizer_step or optimizer_zero_grad'
        ' , `accumulate_grad_batches` in `Trainer` should to be 1.'
        ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
    )
def training_epoch_end(self, model, epoch_output, num_optimizers):
    """Invoke the user's ``training_epoch_end`` and return what it logged.

    Returns:
        A fresh ``Result`` when the hook is not overridden; otherwise the
        ``model._results`` object that was reset just before the hook ran.

    Raises:
        MisconfigurationException: if the user hook returns anything but None.
    """
    hook_defined = is_overridden('training_epoch_end', model=model)
    if not hook_defined:
        return Result()

    # refresh the result for custom logging at the epoch level
    model._current_fx_name = 'training_epoch_end'
    model._results = Result()

    hook_inputs = self.__prepare_epoch_end_inputs(epoch_output)

    # hand over the first entry only with one optimizer or manual optimization
    unwrap_first = num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization
    if unwrap_first:
        hook_inputs = hook_inputs[0]

    # lightningmodule hook
    hook_return = model.training_epoch_end(hook_inputs)
    model._current_fx_name = ''

    if hook_return is None:
        # user can ALSO log at the end of an epoch
        return model._results

    raise MisconfigurationException(
        'training_epoch_end expects a return of None. '
        'HINT: remove the return statement in training_epoch_end')
def call_hook(self, hook_name, *args, **kwargs):
    """Dispatch ``hook_name`` to the trainer, then the model or the accelerator.

    The trainer's own hook (if any) runs first. Then the model's hook runs if
    it is overridden; otherwise the accelerator backend's hook runs if it has
    one. Result bookkeeping happens only while the logger connector's current
    stage is ``"train"``.

    Returns:
        The model (or accelerator) hook's return value, or None if neither ran.
    """
    # temporary. Don't modify evaluation behaviour
    if self.logger_connector._current_stage == "train":
        # set hook_name to model + reset Result obj
        self._reset_result_and_set_hook_fx_name(hook_name)

    # always profile hooks
    with self.profiler.profile(hook_name):
        # trainer-level hook goes first
        if hasattr(self, hook_name):
            getattr(self, hook_name)(*args, **kwargs)

        hook_result = None
        lightning_module = self.get_model()
        if is_overridden(hook_name, lightning_module):
            hook_result = getattr(lightning_module, hook_name)(*args, **kwargs)
        elif hasattr(self.accelerator_backend, hook_name):
            # the module doesn't define the hook: fall back to the accelerator,
            # used to auto-reduce things for the user with Results obj
            hook_result = getattr(self.accelerator_backend, hook_name)(*args, **kwargs)

    # temporary. Don't modify evaluation behaviour
    if self.logger_connector._current_stage == "train":
        # capture logging
        self._cache_logged_metrics()

    return hook_result
def test_dm_transfer_batch_to_device(tmpdir):
    """A datamodule's ``transfer_batch_to_device`` is used to move a custom batch.

    The test needs a CUDA device: it builds a ``Trainer(gpus=1)`` and moves the
    batch to ``cuda:0``.
    """

    class CustomBatch:
        """Non-tensor batch container holding two tensors."""

        def __init__(self, data):
            self.samples = data[0]
            self.targets = data[1]

    class CurrentTestDM(LightningDataModule):
        hook_called = False

        def transfer_batch_to_device(self, data, device):
            self.hook_called = True
            if not isinstance(data, CustomBatch):
                return super().transfer_batch_to_device(data, device)
            data.samples = data.samples.to(device)
            data.targets = data.targets.to(device)
            return data

    module = EvalModelTemplate()
    datamodule = CurrentTestDM()
    samples = torch.zeros(5, 28)
    targets = torch.ones(5, 1, dtype=torch.long)
    batch = CustomBatch((samples, targets))

    trainer = Trainer(gpus=1)
    # running .fit() would require us to implement custom data loaders, we mock the model reference instead
    trainer.get_model = MagicMock(return_value=module)
    if is_overridden('transfer_batch_to_device', datamodule):
        module.transfer_batch_to_device = datamodule.transfer_batch_to_device

    trainer.accelerator_backend = GPUBackend(trainer)
    target_device = torch.device('cuda:0')
    moved = trainer.accelerator_backend.batch_to_device(batch, target_device)

    expected = torch.device('cuda', 0)
    assert datamodule.hook_called
    assert moved.samples.device == moved.targets.device == expected
def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator):
    """
    Work out which training-step outputs must be kept for the end of the epoch,
    feeding callback values into the accumulators along the way.

    Returns:
        The per-optimizer outputs that need tracking, in their original order.
    """
    # the training step outputs a list per optimizer. The list contains the outputs at each time step
    # when no TBPTT is used, then the list has 1 item per batch
    # when TBPTT IS used, then the list has n items (1 per time step)
    tracked_outputs = []

    for per_optimizer_outputs in all_train_step_outputs:
        # the last time step is the representative sample
        last_output = per_optimizer_outputs[-1]

        # pull out callback info if available
        if isinstance(last_output, dict):
            if 'early_stop_on' in last_output:
                early_stopping_accumulator.accumulate(last_output['early_stop_on'])
            if 'checkpoint_on' in last_output:
                checkpoint_accumulator.accumulate(last_output['checkpoint_on'])

        # a Result object may ask to be reduced automatically at epoch end
        wants_auto_reduce = isinstance(last_output, Result) and last_output.should_reduce_on_epoch_end

        # keep the outputs when the user reduces manually OR auto-reduce is requested
        keep = is_overridden('training_epoch_end', model=self.get_model()) or wants_auto_reduce
        if keep:
            tracked_outputs.append(per_optimizer_outputs)

    return tracked_outputs
def check_checkpoint_callback(self, should_check_val):
    """Trigger the checkpoint callbacks by hand when requested.

    The callbacks fire only when the model does not override
    ``validation_step`` and ``should_check_val`` is falsy.

    Args:
        should_check_val: when truthy, nothing is triggered here.
    """
    # when no val loop is present or fast-dev-run still need to call checkpoints
    # TODO bake this logic into the checkpoint callback
    should_activate = not is_overridden('validation_step', self.get_model()) and not should_check_val
    if not should_activate:
        return

    checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
    # plain loop: the calls are made for their side effects, not to build a list
    for callback in checkpoint_callbacks:
        callback.on_validation_end(self, self.get_model())
def call_hook(self, hook_name, *args, capture=False, **kwargs):
    """Dispatch ``hook_name`` to the trainer, then the model or the accelerator.

    Args:
        hook_name: name of the hook to call.
        capture: when truthy, reset the result object before the hook and
            cache the logged metrics after it.

    Returns:
        The model (or accelerator) hook's return value, or None if neither ran.
    """
    # set hook_name to model + reset Result obj
    if capture:
        self._reset_result_and_set_hook_fx_name(hook_name)

    # always profile hooks
    with self.profiler.profile(hook_name):
        # trainer-level hook goes first
        if hasattr(self, hook_name):
            getattr(self, hook_name)(*args, **kwargs)

        hook_result = None
        lightning_module = self.get_model()
        if is_overridden(hook_name, lightning_module):
            hook_result = getattr(lightning_module, hook_name)(*args, **kwargs)
        elif hasattr(self.accelerator_backend, hook_name):
            # the module doesn't define the hook: fall back to the accelerator,
            # used to auto-reduce things for the user with Results obj
            hook_result = getattr(self.accelerator_backend, hook_name)(*args, **kwargs)

    if capture:
        self._cache_logged_metrics()

    return hook_result
def run_sanity_check(self, ref_model):
    """Run a tiny validation pass to make sure validation will not crash.

    The check is skipped unless the model has a ``val_dataloader``, overrides
    ``validation_step``, and both ``num_sanity_val_steps`` and
    ``limit_val_batches`` are positive.
    """
    using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
    should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
    if not should_sanity_check:
        return

    self.reset_val_dataloader(ref_model)
    # never run more sanity batches than a dataloader actually has
    self.num_sanity_val_batches = [
        min(self.num_sanity_val_steps, available) for available in self.num_val_batches
    ]

    # hook and callback
    self.running_sanity_check = True
    self.on_sanity_check_start()

    # run eval step
    _, sanity_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)

    # allow no returns from eval
    if sanity_results is not None and len(sanity_results) > 0:
        # when we get a list back, use only the last item
        if isinstance(sanity_results, list):
            sanity_results = sanity_results[-1]

        if isinstance(sanity_results, EvalResult):
            callback_metrics = sanity_results.callback_metrics
        else:
            _, _, _, callback_metrics, _ = self.process_dict_result(sanity_results)
        self.logger_connector.callback_metrics = callback_metrics

    self.on_sanity_check_end()
    self.running_sanity_check = False
def validate(self, val_iterator, info):
    """Run one validation pass over ``val_iterator`` and reduce the outputs.

    Args:
        val_iterator: iterable of validation batches.
        info: mapping merged into each batch's ``batch_info``.

    Returns:
        A dict of validation results: whatever ``validation_epoch_end``
        produced (a tensor is wrapped as ``{"val_loss": tensor}``), or the
        ``AverageMeterCollection`` summary when the hook is absent or
        returned None.

    Raises:
        ValueError: if ``validation_epoch_end`` returns a ``Result``.
        TypeError: if ``validation_epoch_end`` returns an unsupported type.
    """
    self.model.zero_grad()
    self.model.eval()

    torch.set_grad_enabled(False)

    model = self.get_model()
    if self.is_function_implemented("on_validation_epoch_start", model):
        model.on_validation_epoch_start()

    val_outputs = []
    for batch_idx, batch in enumerate(val_iterator):
        batch_info = {"batch_idx": batch_idx}
        batch_info.update(info)
        batch_output = self.validate_batch(batch, batch_info)
        if batch_output is not None:
            val_outputs.append(batch_output)

    processed_outputs = None
    if is_overridden("validation_epoch_end", model):
        raw_outputs = [vo["raw_output"] for vo in val_outputs]
        # call the hook that was checked: validation outputs must be reduced
        # by validation_epoch_end, not by training_epoch_end
        processed_outputs = model.validation_epoch_end(raw_outputs)

    if processed_outputs is not None:
        if isinstance(processed_outputs, torch.Tensor):
            return_output = {"val_loss": processed_outputs}
        elif isinstance(processed_outputs, Result):
            raise ValueError("Result objects are not supported. Please "
                             "return a dictionary instead.")
        elif isinstance(processed_outputs, dict):
            return_output = processed_outputs
        else:
            raise TypeError("validation_epoch_end returned an invalid "
                            "type. It must return a Tensor, Result, "
                            "or dict.")
    else:
        # validation_epoch_end is not overridden, or it returned None
        assert isinstance(val_outputs, list)
        # Use AverageMeterCollection util to reduce results.
        meter_collection = AverageMeterCollection()
        for v in val_outputs:
            num_samples = v.pop(NUM_SAMPLES, 1)
            raw_output = v["raw_output"]
            if isinstance(raw_output, dict):
                meter_collection.update(raw_output, num_samples)
            elif isinstance(raw_output, torch.Tensor):
                meter_collection.update({"val_loss": raw_output.item()},
                                        num_samples)
        return_output = meter_collection.summary()

    if self.is_function_implemented("on_validation_epoch_end", model):
        model.on_validation_epoch_end()

    # Set back to True so training will work.
    torch.set_grad_enabled(True)

    return return_output
def can_prepare_data(self):
    """Tell whether this process should run ``prepare_data``.

    With ``prepare_data_per_node`` set, only local rank 0 qualifies;
    otherwise only local rank 0 on node rank 0. A datamodule that overrides
    ``prepare_data`` and has already prepared its data vetoes the call.
    """
    trainer = self.trainer
    datamodule = trainer.datamodule

    dm_still_needs_prepare = True
    if datamodule is not None and is_overridden('prepare_data', datamodule):
        dm_still_needs_prepare = not datamodule.has_prepared_data

    if trainer.prepare_data_per_node:
        return trainer.local_rank == 0 and dm_still_needs_prepare
    return trainer.node_rank == 0 and trainer.local_rank == 0 and dm_still_needs_prepare
def __verify_eval_loop_configuration(self, model, eval_loop_name):
    """Warn when only one half of an eval loop (loader or step) is defined.

    Args:
        model: object inspected with ``is_overridden``.
        eval_loop_name: loop prefix such as ``'validation'``; the step hook is
            ``f'{eval_loop_name}_step'`` and the loader hook is
            ``f'{eval_loop_name}_dataloader'``, except that ``'validation'``
            maps to ``'val_dataloader'``.
    """
    step_name = f'{eval_loop_name}_step'

    # map the dataloader name
    loader_name = 'val_dataloader' if eval_loop_name == 'validation' else f'{eval_loop_name}_dataloader'

    loader_defined = is_overridden(loader_name, model)
    step_defined = is_overridden(step_name, model)

    # the two mismatches are mutually exclusive
    if loader_defined and not step_defined:
        rank_zero_warn(
            f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop'
        )
    elif step_defined and not loader_defined:
        rank_zero_warn(
            f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop'
        )
def attach_datamodule(self, model, datamodule, stage):
    """Wire a datamodule's overridden hooks onto ``model``.

    Args:
        model: receives the dataloader / device-transfer hooks.
        datamodule: explicit datamodule; when falsy, ``model.datamodule`` is
            used if present.
        stage: not used by this method.
    """
    # prefer the datamodule given on .fit or .test, otherwise look on the model
    datamodule = datamodule or getattr(model, 'datamodule', None)

    # nothing to attach
    if not datamodule:
        return

    # loader hooks first, then the dataset-specific to_device logic
    forwarded_hooks = (
        'train_dataloader',
        'val_dataloader',
        'test_dataloader',
        'transfer_batch_to_device',
    )
    for hook in forwarded_hooks:
        if is_overridden(hook, datamodule):
            setattr(model, hook, getattr(datamodule, hook))

    self.trainer.datamodule = datamodule
def validate_batch(self, batch, batch_info):
    """Run the validation step (and its hooks) for one batch.

    Args:
        batch: the batch to validate.
        batch_info: mapping that must contain ``"batch_idx"``.

    Returns:
        ``{"raw_output": output}`` where ``output`` is the validation step's
        result, post-processed by ``validation_step_end`` when the model
        overrides it.

    Raises:
        ValueError: if the validation step returns a ``Result``.
    """
    model = self.get_model()
    batch_idx = batch_info["batch_idx"]
    if is_overridden("on_validation_batch_start", model):
        model.on_validation_batch_start(
            batch=batch, batch_idx=batch_idx, dataloader_idx=0)
    args = [batch, batch_idx]
    with self.timers.record("eval_fwd"):
        if self._is_distributed:
            # Use the DDP wrapped model (self.model).
            output = self.model(*args)
        elif self.use_gpu:
            # Using single GPU.
            device = self.device
            batch = model.transfer_batch_to_device(batch, device=device)
            args[0] = batch
            output = model.validation_step(*args)
        else:
            # Using CPU.
            output = model.validation_step(*args)

    if isinstance(output, Result):
        raise ValueError("EvalResult objects are not supported. Please "
                         "return a dictionary instead.")

    # check the hook that is actually called; the previous check used the
    # name "on_validation_step_end", so validation_step_end never ran
    if is_overridden("validation_step_end", model):
        output = model.validation_step_end(output)

    if self.is_function_implemented("on_validation_batch_end", model):
        model.on_validation_batch_end(
            outputs=output,
            batch=batch,
            batch_idx=batch_idx,
            dataloader_idx=0)
    return {
        "raw_output": output,
        # NUM_SAMPLES: len(batch)
    }
def configure_checkpoint_callback(self, checkpoint_callback):
    """Resolve the ``checkpoint_callback`` argument into a callback or None.

    Args:
        checkpoint_callback: ``True`` to build a default ``ModelCheckpoint``,
            ``False`` to disable checkpointing, or an existing callback.

    Returns:
        The callback to use (with ``save_function`` set to
        ``self.save_checkpoint`` when it is truthy), or None for ``False``.
    """
    if checkpoint_callback is False:
        return None

    if checkpoint_callback is True:
        # monitor 'val_loss' when a val step is defined, otherwise 'loss'
        has_val_step = is_overridden('validation_step', self.get_model())
        monitor_key = 'val_loss' if has_val_step else 'loss'
        checkpoint_callback = ModelCheckpoint(filepath=None, monitor=monitor_key)

    if checkpoint_callback:
        checkpoint_callback.save_function = self.save_checkpoint

    return checkpoint_callback
def __run_legacy_training_epoch_end(
    self, num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics
):
    """Produce epoch-end log / progress-bar / callback metrics.

    Returns:
        A ``(log_metrics, progress_bar_metrics, callback_metrics)`` tuple.
        ``callback_metrics`` is the value passed in unless the user hook
        returned a non-``Result`` value, in which case it comes from
        ``process_dict_result``.
    """
    # --------------------------
    # No user hook: auto reduce structured results, or nothing at all
    # --------------------------
    if not is_overridden('training_epoch_end', model=model):
        if is_result_obj:
            log_metrics, pbar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
            return log_metrics, pbar_metrics, epoch_callback_metrics
        return {}, {}, epoch_callback_metrics

    # --------------------------
    # EPOCH END STEP IF DEFINED
    # --------------------------
    if is_result_obj:
        # with result object gather across time and training steps so each opt idx has a single result obj
        epoch_output = self.__gather_result_across_time_and_optimizers(epoch_output)

    if num_optimizers == 1:
        epoch_output = epoch_output[0]

    # run training_epoch_end
    # a list with a result per optimizer index
    model._current_fx_name = 'training_epoch_end'
    hook_return = model.training_epoch_end(epoch_output)

    # capture logging
    self.trainer.logger_connector.cache_logged_metrics()

    if isinstance(hook_return, Result):
        log_metrics = hook_return.epoch_log_metrics
        pbar_metrics = hook_return.epoch_pbar_metrics
        return log_metrics, pbar_metrics, epoch_callback_metrics

    processed = self.trainer.process_dict_result(hook_return)
    pbar_metrics = processed[1]
    log_metrics = processed[2]
    callback_metrics = processed[3]
    return log_metrics, pbar_metrics, callback_metrics
def training_epoch_end(self, model, epoch_output, num_optimizers):
    """Invoke the user's ``training_epoch_end`` hook and cache what it logged.

    Does nothing when the hook is not overridden.

    Raises:
        MisconfigurationException: if the user hook returns anything but None.
    """
    if not is_overridden('training_epoch_end', model=model):
        return

    # refresh the result for custom logging at the epoch level
    model._current_fx_name = 'training_epoch_end'
    hook_inputs = self.__prepare_epoch_end_inputs(epoch_output)

    # hand over the first entry only with one optimizer or manual optimization
    unwrap_first = num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization
    if unwrap_first:
        hook_inputs = hook_inputs[0]

    # lightningmodule hook
    hook_return = model.training_epoch_end(hook_inputs)
    if hook_return is not None:
        raise MisconfigurationException('training_epoch_end expects a return of None. '
                                        'HINT: remove the return statement in training_epoch_end')

    # capture logging
    self.trainer.logger_connector.cache_logged_metrics()
def run_training_epoch(self):
    """Run one training epoch: iterate the train dataloader, then epoch-end work.

    For every batch this runs the training batch, tracks outputs for the
    epoch end, logs step metrics, optionally runs evaluation, saves loggers
    and updates LR schedulers. The loop can end early on a ``-1`` signal,
    when ``max_steps`` is reached, when ``should_stop`` is set, or when the
    configured number of training batches has been processed.
    """
    # get model
    model = self.trainer.get_model()

    # modify dataloader if needed (ddp, etc...)
    train_dataloader = self.trainer.accelerator_backend.process_dataloader(
        self.trainer.train_dataloader)

    # track epoch output: one bucket per optimizer
    epoch_output = [[] for _ in range(self.num_optimizers)]

    # enable profiling for the dataloader
    # NOTE(review): the wrapped loader is expected to yield
    # (batch_idx, (batch, is_last_batch)) tuples, as unpacked below
    train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(
        train_dataloader)
    dataloader_idx = 0
    should_check_val = False
    for batch_idx, (batch, is_last_batch) in train_dataloader:

        self.trainer.batch_idx = batch_idx

        # ------------------------------------
        # TRAINING_STEP + TRAINING_STEP_END
        # ------------------------------------
        with self.trainer.profiler.profile("run_training_batch"):
            batch_output = self.run_training_batch(batch, batch_idx,
                                                   dataloader_idx)

        # when returning -1 from train_step, we end epoch early
        if batch_output.signal == -1:
            break

        # only track outputs when user implements training_epoch_end
        # otherwise we will build up unnecessary memory
        epoch_end_outputs = self.process_train_step_outputs(
            batch_output.training_step_output_for_epoch_end,
            self.early_stopping_accumulator,
            self.checkpoint_accumulator,
        )

        # hook
        # TODO: add outputs to batches
        self.on_train_batch_end(epoch_output, epoch_end_outputs, batch,
                                batch_idx, dataloader_idx)

        # -----------------------------------------
        # SAVE METRICS TO LOGGERS
        # -----------------------------------------
        self.trainer.logger_connector.log_train_step_metrics(batch_output)

        # -----------------------------------------
        # VALIDATE IF NEEDED + CHECKPOINT CALLBACK
        # -----------------------------------------
        should_check_val = self.should_check_val_fx(
            batch_idx, is_last_batch)
        if should_check_val:
            self.trainer.run_evaluation(test_mode=False)
            # reset stage to train
            self.trainer.logger_connector.set_stage("train")

        # -----------------------------------------
        # SAVE LOGGERS (ie: Tensorboard, etc...)
        # -----------------------------------------
        self.save_loggers_on_train_batch_end()

        # update LR schedulers
        # a deep copy is passed so the schedulers work on a snapshot of the
        # callback metrics
        monitor_metrics = deepcopy(
            self.trainer.logger_connector.callback_metrics)
        self.update_train_loop_lr_schedulers(
            monitor_metrics=monitor_metrics)
        self.trainer.checkpoint_connector.has_trained = True

        # max steps reached, end training
        if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1:
            accumulation_done = self._accumulated_batches_reached()
            # Ensure accumulation across batches has completed before breaking loop
            if accumulation_done:
                break

        # end epoch early
        # stop when the flag is changed or we've gone past the amount
        # requested in the batches
        if self.trainer.should_stop:
            break

        self.trainer.total_batch_idx += 1

        # stop epoch if we limited the number of training batches
        if (batch_idx + 1) >= self.trainer.num_training_batches:
            break

        # progress global step according to grads progress
        # (not reached on the iteration that breaks out of the loop)
        self.increment_accumulated_grad_global_step()

    # epoch end hook
    self.run_on_epoch_end_hook(epoch_output)

    # log epoch metrics
    self.trainer.logger_connector.log_train_epoch_end_metrics(
        epoch_output, self.checkpoint_accumulator,
        self.early_stopping_accumulator, self.num_optimizers)

    # when no val loop is present or fast-dev-run still need to call checkpoints
    self.check_checkpoint_callback(not (
        should_check_val or is_overridden('validation_step', model)))

    # increment the global step once
    # progress global step according to grads progress
    self.increment_accumulated_grad_global_step()
def train_batch(self, batch, batch_info):
    """Run forward, backward and the optimizer step for a single batch.

    Only the first optimizer (``self.optimizers[0]``) is used.

    Args:
        batch: the training batch.
        batch_info: mapping that must contain ``"batch_idx"`` and
            ``"epoch_idx"``.

    Returns:
        ``{"signal": -1}`` when ``on_train_batch_start`` returns -1; otherwise
        a dict with ``"signal"`` (0), ``"training_loss"`` (a Python number
        taken from the detached loss) and ``"raw_output"`` (the detached
        training-step output).

    Raises:
        ValueError: if the training step returns a ``Result``.
        RuntimeError: if no loss can be extracted from the output.
        TypeError: if the output is neither a tensor nor a dict.
    """
    # Get the original PTL module.
    model = self.get_model()
    optimizer = self.optimizers[0]
    batch_idx = batch_info["batch_idx"]
    epoch_idx = batch_info["epoch_idx"]

    if self.is_function_implemented("on_train_batch_start", model):
        response = model.on_train_batch_start(
            batch=batch, batch_idx=batch_idx, dataloader_idx=0)
        # Skip remainder of epoch if response is -1.
        if response == -1:
            return {"signal": -1}

    args = [batch, batch_idx]
    # with several optimizers, pass optimizer index 0 if training_step
    # accepts an ``optimizer_idx`` argument
    if len(self.optimizers) > 1:
        if self.has_arg("training_step", "optimizer_idx"):
            args.append(0)

    with self.timers.record("fwd"):
        if self._is_distributed:
            # Use the DDP wrapped model (self.model).
            output = self.model(*args)
        elif self.use_gpu:
            # Using single GPU.
            # Don't copy the batch since there is a single gpu that
            # the batch could be referenced from and if there are
            # multiple optimizers the batch will wind up copying it to
            # the same device repeatedly.
            device = self.device
            batch = model.transfer_batch_to_device(batch, device=device)
            args[0] = batch
            output = model.training_step(*args)
        else:
            # Using CPU.
            output = model.training_step(*args)

    if isinstance(output, Result):
        raise ValueError("TrainResult objects are not supported. Please "
                         "return a dictionary instead.")

    # allow any mode to define training_step_end
    # do something with all the dp outputs (like softmax)
    if is_overridden("training_step_end", model):
        output = model.training_step_end(output)

    # Extract loss from output if dictionary.
    # Any failure of the ``["loss"]`` lookup falls back to treating a bare
    # tensor output as the loss itself.
    try:
        loss = output["loss"]
    except Exception:
        if isinstance(output, torch.Tensor):
            loss = output
        else:
            raise RuntimeError(
                "No `loss` value in the dictionary returned from "
                "`model.training_step()`.")

    # If output contains tensors, detach them all.
    if isinstance(output, torch.Tensor):
        output = output.detach()
    elif isinstance(output, dict):
        output = recursive_detach(output)
    else:
        raise TypeError("training_step returned invalid type. It must "
                        "return either a Tensor, Result, or dict.")

    # detached copy of the loss taken before backward; reported to the caller
    untouched_loss = loss.detach().clone()

    with self.timers.record("grad"):
        if self.use_fp16:
            with self._amp.scale_loss(loss, optimizer) as scaled_loss:
                model.backward(scaled_loss, optimizer, optimizer_idx=0)
        else:
            model.backward(loss, optimizer, optimizer_idx=0)

    if self.is_function_implemented("on_after_backward", model):
        model.on_after_backward()

    with self.timers.record("apply"):
        optimizer.step()

    model.on_before_zero_grad(optimizer)
    model.optimizer_zero_grad(
        epoch=epoch_idx,
        batch_idx=batch_idx,
        optimizer=optimizer,
        optimizer_idx=0)

    if self.is_function_implemented("on_train_batch_end", model):
        model.on_train_batch_end(
            outputs=output,
            batch=batch,
            batch_idx=batch_idx,
            dataloader_idx=0)

    return {
        "signal": 0,
        "training_loss": untouched_loss.item(),
        "raw_output": output,
        # NUM_SAMPLES: len(batch)
    }
def train_epoch(self, iterator, info):
    """Train for one epoch over ``iterator`` and reduce the batch outputs.

    Args:
        iterator: iterable of training batches.
        info: mapping merged into each batch's ``batch_info``; ``NUM_STEPS``
            and the optional ``"epoch_idx"`` / ``"num_epochs"`` entries are
            read when the progress bar is enabled.

    Returns:
        A dict of training results: whatever ``training_epoch_end`` produced
        (a tensor is wrapped as ``{"train_loss": tensor}``), or the
        ``AverageMeterCollection`` summary when the hook is absent or
        returned None.

    Raises:
        ValueError: if ``training_epoch_end`` returns a ``Result``.
        TypeError: if ``training_epoch_end`` returns an unsupported type.
    """
    model = self.get_model()

    # Enable train mode.
    self.model.train()

    # Enable gradients.
    torch.set_grad_enabled(True)

    if self.is_function_implemented("on_train_epoch_start", model):
        model.on_train_epoch_start()

    if self.use_tqdm and self.world_rank == 0:
        desc = ""
        if info is not None and "epoch_idx" in info:
            if "num_epochs" in info:
                desc = f"{info['epoch_idx'] + 1}/{info['num_epochs']}e"
            else:
                desc = f"{info['epoch_idx'] + 1}e"

        # TODO: Implement len for Dataset?
        total = info[NUM_STEPS]
        if total is None:
            if hasattr(iterator, "__len__"):
                total = len(iterator)

        _progress_bar = tqdm(
            total=total, desc=desc, unit="batch", leave=False)

    # Output for each batch.
    epoch_outputs = []

    for batch_idx, batch in enumerate(iterator):
        batch_info = {
            "batch_idx": batch_idx,
            "global_step": self.global_step
        }
        batch_info.update(info)
        batch_output = self.train_batch(batch, batch_info=batch_info)
        # batch output for each optimizer.
        epoch_outputs.append(batch_output)

        should_stop = batch_output["signal"] == -1

        if self.use_tqdm and self.world_rank == 0:
            _progress_bar.n = batch_idx + 1
            postfix = {}
            if "training_loss" in batch_output:
                postfix.update(loss=batch_output["training_loss"])
            _progress_bar.set_postfix(postfix)

        for s_dict, scheduler in zip(self.scheduler_dicts,
                                     self.schedulers):
            if s_dict["interval"] == SCHEDULER_STEP_BATCH:
                scheduler.step()

        self.global_step += 1

        if should_stop:
            break

    processed_outputs = None
    if is_overridden("training_epoch_end", model):
        # a batch that asked to stop early carries only a "signal" key, so
        # entries without "raw_output" are skipped instead of raising KeyError
        raw_outputs = [
            eo["raw_output"] for eo in epoch_outputs if "raw_output" in eo
        ]
        processed_outputs = model.training_epoch_end(raw_outputs)

    if processed_outputs is not None:
        if isinstance(processed_outputs, torch.Tensor):
            return_output = {"train_loss": processed_outputs}
        elif isinstance(processed_outputs, Result):
            raise ValueError("Result objects are not supported. Please "
                             "return a dictionary instead.")
        elif isinstance(processed_outputs, dict):
            return_output = processed_outputs
        else:
            raise TypeError("training_epoch_end returned an invalid "
                            "type. It must return a Tensor, Result, "
                            "or dict.")
    else:
        # training_epoch_end is not overridden, or it returned None
        assert isinstance(epoch_outputs, list)
        # Use AverageMeterCollection util to reduce results.
        meter_collection = AverageMeterCollection()
        for o in epoch_outputs:
            # signal-only entries (early stop) have nothing to average
            if "raw_output" not in o:
                continue
            num_samples = o.pop(NUM_SAMPLES, 1)
            raw_output = o["raw_output"]
            if isinstance(raw_output, dict):
                meter_collection.update(raw_output, num_samples)
            elif isinstance(raw_output, torch.Tensor):
                meter_collection.update({
                    "train_loss": o["training_loss"]
                }, num_samples)
        return_output = meter_collection.summary()

    if self.is_function_implemented("on_train_epoch_end", model):
        model.on_train_epoch_end(
            [eo.get("raw_output") for eo in epoch_outputs])

    for s_dict, scheduler in zip(self.scheduler_dicts, self.schedulers):
        if s_dict["interval"] == SCHEDULER_STEP_EPOCH:
            scheduler.step()

    return return_output
def setup(self, config):
    """Instantiate the Lightning module and register it with this operator.

    The steps run in this order: build the module (passing ``config`` only
    if its ``__init__`` accepts a ``config`` parameter), call its
    ``on_fit_start`` / ``prepare_data`` / ``setup("fit")`` hooks, initialise
    optimizers and schedulers, register model/optimizers/schedulers, call
    ``on_pretrain_routine_start``, and finally register the train and
    validation data loaders.

    Args:
        config: Configuration object forwarded to the module constructor
            when it declares a ``config`` parameter.

    Raises:
        TypeError: If the configured module class is not a subclass of
            ``pytorch_lightning.LightningModule``.
        MisconfigurationException: If the module does not override
            ``configure_optimizers``.
    """
    # Pass in config if ptl_module accepts it.
    ptl_class = self.__class__._lightning_module_cls
    if not issubclass(ptl_class, ptl.LightningModule):
        raise TypeError("Argument must be subclass of "
                        "pytorch_lightning.LightningModule. Got class {} "
                        "instead.".format(ptl_class))
    if "config" in inspect.signature(ptl_class.__init__).parameters:
        ptl_module = ptl_class(config=config)
    else:
        ptl_module = ptl_class()

    # This is needed for LightningDistributedDataParallel.
    ptl_module.testing = False

    # Call on_fit_start on instantiation.
    if self.is_function_implemented("on_fit_start", ptl_module):
        ptl_module.on_fit_start()

    # Only run data preparation once per node.
    if self.local_rank == 0 and self.is_function_implemented(
            "prepare_data", ptl_module):
        ptl_module.prepare_data()

    # Call model.setup.
    ptl_module.setup("fit")

    if not is_overridden("configure_optimizers", ptl_module):
        raise MisconfigurationException(
            "No `configure_optimizers()` method defined.")

    optimizers, self._scheduler_dicts, optimizer_frequencies = \
        self.init_optimizers(model=ptl_module)

    if len(optimizer_frequencies) > 0:
        logger.warning("Optimizer frequencies will be ignored. When "
                       "passing in multiple optimizers, you should "
                       "implement your own custom training loop.")

    # Unwrap scheduler dicts into plain scheduler objects; the dict-only
    # options are not honoured here, so tell the user about it.
    lr_schedulers = []
    for scheduler in self.scheduler_dicts:
        if isinstance(scheduler, dict):
            # A scheduler dictionary is passed in.
            if "reduce_on_plateau" in scheduler and "monitor" in \
                    scheduler and scheduler["reduce_on_plateau"] is True:
                logger.info(
                    "reduce_on_plateau and monitor will be "
                    "ignored "
                    "from the scheduler dict {}. To update a "
                    "ReduceLROnPlateau scheduler, you should use "
                    "TorchTrainer.update_schedulers.".format(scheduler))
            if "frequency" in scheduler and scheduler["frequency"] > 1:
                logger.info("frequency will be ignored from the "
                            "scheduler dict {}.".format(scheduler))
            lr_schedulers.append(scheduler["scheduler"])
        else:
            lr_schedulers.append(scheduler)

    # Set this so register doesn't complain.
    self._scheduler_step_freq = "ptl"
    ddp_model, self._optimizers, self._schedulers = self.register(
        models=[ptl_module],
        optimizers=optimizers,
        schedulers=lr_schedulers)

    assert len(ddp_model) == 1
    self._model = ddp_model[0]

    model = self.get_model()
    if self.is_function_implemented("on_pretrain_routine_start", model):
        model.on_pretrain_routine_start()

    # Prefer the class-level loaders; fall back to the module's own hooks.
    # Compare against None instead of truth-testing: truth-testing an
    # object that defines ``__len__`` calls it, which raises for loaders
    # over unsized iterable datasets and silently discards zero-length
    # loaders.
    train_data_loader = self.__class__._train_dataloader
    if train_data_loader is None and self.is_function_implemented(
            "train_dataloader", model):
        train_data_loader = model.train_dataloader()

    val_data_loader = self.__class__._val_dataloader
    if val_data_loader is None and self.is_function_implemented(
            "val_dataloader", model):
        val_data_loader = model.val_dataloader()

    self.register_data(
        train_loader=train_data_loader, validation_loader=val_data_loader)
def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
    """Reduce the outputs of a finished epoch and publish the metrics.

    Args:
        epoch_output: Nested list indexed as
            ``epoch_output[optimizer_idx][training_step_idx][tbptt_index]``.
            Not using truncated backprop is equivalent to truncated
            backprop of length 1.
        checkpoint_accumulator: Accumulator whose mean becomes the
            ``'checkpoint_on'`` callback metric when it holds any values.
        early_stopping_accumulator: Accumulator whose mean becomes the
            ``'early_stop_on'`` callback metric when it holds any values.
        num_optimizers: Number of optimizers; with exactly one, the
            optimizer dimension is stripped before ``training_epoch_end``
            is called.
    """
    model = self.get_model()

    epoch_log_metrics = {}
    epoch_callback_metrics = {}
    epoch_progress_bar_metrics = {}

    # -----------------------
    # Calculate epoch callback values if given
    # -----------------------
    if checkpoint_accumulator.num_values > 0:
        epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()

    if early_stopping_accumulator.num_values > 0:
        epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()

    # ------------------------
    # determine if using a result obj
    # ------------------------
    # [optimizer_idx][training_step_idx][tbptt_index]
    # Every index access lives inside the try so that an empty
    # ``epoch_output`` is treated as "no result object" as well; the
    # previous emptiness check ran only after ``epoch_output[0]`` had
    # already been evaluated outside the handler.
    try:
        opt_idx_outputs = epoch_output[0]
        first_step_output = opt_idx_outputs[0]
        if isinstance(first_step_output, list):
            sample_obj = first_step_output[0]
        else:
            sample_obj = first_step_output
        is_result_obj = isinstance(sample_obj, Result)
    except IndexError:
        is_result_obj = False

    # --------------------------
    # EPOCH END STEP IF DEFINED
    # --------------------------
    if is_overridden('training_epoch_end', model=model):
        self.global_step += 1

        if is_result_obj:
            # with result object gather across time and training steps so each opt idx has a single result obj
            epoch_output = self.__gather_result_across_time_and_optimizers(epoch_output)

        if num_optimizers == 1:
            epoch_output = epoch_output[0]

        # run training_epoch_end
        # a list with a result per optimizer index
        epoch_output = model.training_epoch_end(epoch_output)

        if isinstance(epoch_output, Result):
            epoch_log_metrics = epoch_output.epoch_log_metrics
            epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
        else:
            _processed_outputs = self.process_output(epoch_output)
            epoch_progress_bar_metrics = _processed_outputs[1]
            epoch_log_metrics = _processed_outputs[2]
            # NOTE(review): this replaces the accumulator-derived callback
            # metrics computed above rather than merging with them -
            # confirm that is intended.
            epoch_callback_metrics = _processed_outputs[3]

    # --------------------------
    # Structured Result (auto epoch end)
    # --------------------------
    elif is_result_obj:
        epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)

    # --------------------------
    # track results
    # --------------------------
    # add the metrics to the loggers
    if epoch_log_metrics:
        self.log_metrics(epoch_log_metrics, {})

    # add metrics to callbacks
    self.callback_metrics.update(epoch_callback_metrics)

    # add metrics to progress_bar
    if epoch_progress_bar_metrics:
        self.add_progress_bar_metrics(epoch_progress_bar_metrics)
def __verify_train_loop_configuration(self, model):
    """Validate that ``model`` defines what the training loop needs.

    First checks, in order, that ``training_step``, ``train_dataloader``
    and ``configure_optimizers`` are overridden. Then records on the
    trainer whether ``optimizer_step`` / ``optimizer_zero_grad`` are
    overridden and rejects (or warns about) combinations of those
    overrides with the trainer's optimization settings.

    Args:
        model: The module whose hooks are inspected with ``is_overridden``.

    Raises:
        MisconfigurationException: If a required hook is missing, if an
            optimizer hook is overridden while gradients are going to be
            accumulated under automatic optimization, or if
            ``optimizer_zero_grad`` is overridden with the PL optimizer
            enabled and automatic optimization disabled.
    """
    # -----------------------------------
    # verify model has a training step, a train dataloader and optimizers
    # -----------------------------------
    minimum_hooks_hint = (
        ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
    )
    for required_hook in ('training_step', 'train_dataloader', 'configure_optimizers'):
        if is_overridden(required_hook, model):
            continue
        raise MisconfigurationException(
            'No `{}()` method defined. Lightning `Trainer` expects as minimum a'.format(required_hook)
            + minimum_hooks_hint
        )

    # -----------------------------------
    # record which optimizer hooks the user replaced
    # -----------------------------------
    trainer = self.trainer
    step_overridden = is_overridden('optimizer_step', model)
    trainer.overriden_optimizer_step = step_overridden
    zero_grad_overridden = is_overridden('optimizer_zero_grad', model)
    trainer.overriden_optimizer_zero_grad = zero_grad_overridden

    uses_pl_optimizer = trainer._enable_pl_optimizer
    is_automatic = trainer.train_loop.automatic_optimization

    if step_overridden and is_automatic and not uses_pl_optimizer:
        rank_zero_warn(
            "When overriding `LightningModule` optimizer_step with"
            " `Trainer(..., enable_pl_optimizer=False, automatic_optimization=True, ...)`,"
            " we won't be calling `.zero_grad` we can't assume when you call your `optimizer.step()`."
            " For Lightning to take care of it, please use `Trainer(enable_pl_optimizer=True)`."
        )

    # Queried unconditionally, before the override flags are combined.
    will_accumulate = trainer.accumulation_scheduler.going_to_accumulate_grad_batches()
    any_optimizer_hook_overridden = step_overridden or zero_grad_overridden

    if any_optimizer_hook_overridden and will_accumulate and is_automatic:
        raise MisconfigurationException(
            'When overriding `LightningModule` optimizer_step or optimizer_zero_grad with '
            '`Trainer(automatic_optimization=True, ...)`, `accumulate_grad_batches` should to be 1.'
            ' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
        )

    if uses_pl_optimizer and zero_grad_overridden and not is_automatic:
        raise MisconfigurationException(
            'When overriding `LightningModule` optimizer_zero_grad with '
            '`Trainer(automatic_optimization=False, enable_pl_optimizer=True, ...) is not supported'
        )
def enable_validation(self) -> bool:
    """Check if we should run validation during training.

    Returns:
        ``True`` when the model overrides ``validation_step`` and
        ``limit_val_batches`` is greater than 0, or when ``fast_dev_run``
        is truthy; ``False`` otherwise.
    """
    model_ref = self.model_connector.get_model()
    val_loop_enabled = is_overridden(
        'validation_step', model_ref) and self.limit_val_batches > 0
    # ``or`` would otherwise hand back ``fast_dev_run`` itself; coerce so
    # the declared ``bool`` return type always holds.
    return bool(val_loop_enabled or self.fast_dev_run)
def attach_step_and_epoch_functions(model, datamodule):
    """Copy step/epoch-end attributes from ``datamodule`` onto ``model``.

    ``datamodule.forward`` is first pointed at ``model.forward``. Then every
    attribute of ``datamodule`` whose name contains ``"_step"`` or
    ``"_epoch_end"`` is set on ``model``, unless ``is_overridden`` reports
    that ``model`` already overrides it.

    Args:
        model: Object that receives the attributes; mutated in place.
        datamodule: Object providing the attributes; its ``forward`` is
            replaced.
    """
    datamodule.forward = model.forward

    hook_markers = ("_step", "_epoch_end")
    for attr in dir(datamodule):
        if not any(marker in attr for marker in hook_markers):
            continue
        # The model's own implementation wins over the datamodule's.
        if not is_overridden(attr, model):
            setattr(model, attr, getattr(datamodule, attr))