class TrainLoop: def __init__(self, trainer): self.trainer = trainer self.early_stopping_accumulator = None self.checkpoint_accumulator = None self.accumulated_loss = None self.warning_cache = WarningCache() self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) self.automatic_optimization = True self._curr_step_result = None self._cur_grad_norm_dict = None def on_trainer_init(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps, automatic_optimization): self.trainer.global_step = 0 self.trainer.current_epoch = 0 self.trainer.interrupted = False self.trainer.should_stop = False self.trainer._state = TrainerState.INITIALIZING self.trainer.total_batch_idx = 0 self.trainer.batch_idx = 0 self.trainer.num_training_batches = 0 self.trainer.train_dataloader = None self.automatic_optimization = automatic_optimization self.trainer.max_epochs = max_epochs self.trainer.min_epochs = min_epochs self.trainer.max_steps = max_steps self.trainer.min_steps = min_steps if num_sanity_val_steps == -1: self.trainer.num_sanity_val_steps = float("inf") else: self.trainer.num_sanity_val_steps = num_sanity_val_steps @property def num_optimizers(self): num_optimizers = len(self.get_optimizers_iterable()) return num_optimizers def should_skip_training(self): if self.trainer.current_epoch >= self.trainer.max_epochs: return True if self.trainer.limit_train_batches == 0: return True return False def on_train_start(self): # clear cache before training if self.trainer.on_gpu and self.trainer.root_gpu is not None: # use context because of: # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898 with torch.cuda.device(f"cuda:{self.trainer.root_gpu}"): torch.cuda.empty_cache() # hook self.trainer.call_hook("on_train_start") def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): # bind logger and other properties self.trainer.model_connector.copy_trainer_model_properties(model) # clean hparams if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) # links data to the trainer self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule) # check that model is configured correctly self.trainer.config_validator.verify_loop_configurations(model) def setup_training(self, model: LightningModule): """Sanity check a few things before starting actual training. Args: model: The model to run sanity test on. """ # -------------------------- # Setup?? # -------------------------- ref_model = model if self.trainer.data_parallel: ref_model = model.module # set the ranks and devices self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank self.trainer.accelerator_backend.dist.device = ref_model.device # give model convenience properties ref_model.trainer = self.trainer # set local properties on the model self.trainer.model_connector.copy_trainer_model_properties(ref_model) # init amp. 
Must be done here instead of __init__ to allow ddp to work if self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16 and not self.trainer.use_tpu: self.trainer.scaler = self.trainer.precision_connector.backend.scaler # log hyper-parameters if self.trainer.logger is not None: # save exp to get started (this is where the first experiment logs are written) self.trainer.logger.log_hyperparams(ref_model.hparams_initial) self.trainer.logger.log_graph(ref_model) self.trainer.logger.save() # wait for all to join if on distributed self.trainer.accelerator_backend.barrier("setup_training") # register auto-resubmit when on SLURM self.trainer.slurm_connector.register_slurm_signal_handlers() # -------------------------- # Pre-train # -------------------------- # on pretrain routine start self.trainer.on_pretrain_routine_start(ref_model) if self.trainer.is_function_implemented("on_pretrain_routine_start"): ref_model.on_pretrain_routine_start() # print model summary if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing: if self.trainer.weights_summary in ModelSummary.MODES: ref_model.summarize(mode=self.trainer.weights_summary) else: raise MisconfigurationException( "weights_summary can be None, " + ", ".join(ModelSummary.MODES)) # track model now. # if cluster resets state, the model will update with the saved weights self.trainer.model = model # restore training and model before hpc is called self.trainer.checkpoint_connector.restore_weights(model) # on pretrain routine end self.trainer.on_pretrain_routine_end(ref_model) if self.trainer.is_function_implemented("on_pretrain_routine_end"): ref_model.on_pretrain_routine_end() def on_train_end(self): if self._teardown_already_run: return self._teardown_already_run = True # trigger checkpoint check. 
need to temporarily decrease the global step to avoid saving duplicates # when a checkpoint was saved at the last step self.trainer.global_step -= 1 self.check_checkpoint_callback(should_save=True, is_last=True) self.trainer.global_step += 1 # hook self.trainer.call_hook("on_train_end") # kill loggers if self.trainer.logger is not None: self.trainer.logger.finalize("success") # summarize profile results if self.trainer.global_rank == 0: self.trainer.profiler.describe() # give accelerators a chance to finish self.trainer.accelerator_backend.on_train_end() # clear mem if self.trainer.on_gpu: model = self.trainer.get_model() model.cpu() torch.cuda.empty_cache() def check_checkpoint_callback(self, should_save, is_last=False): # TODO bake this logic into the checkpoint callback if should_save and self.trainer.checkpoint_connector.has_trained: checkpoint_callbacks = [ c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint) ] if is_last and any(c.save_last for c in checkpoint_callbacks): rank_zero_info("Saving latest checkpoint...") model = self.trainer.get_model() for callback in checkpoint_callbacks: callback.on_validation_end(self.trainer, model) def on_train_epoch_start(self, epoch): # update training progress in trainer self.trainer.current_epoch = epoch model = self.trainer.get_model() # reset train dataloader if self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_train_dataloader(model) # set seed for distributed sampler (enables shuffling for each epoch) try: self.trainer.train_dataloader.sampler.set_epoch(epoch) except Exception: pass # changing gradient according accumulation_scheduler self.trainer.accumulation_scheduler.on_epoch_start( self.trainer, self.trainer.get_model()) # stores accumulated grad fractions per batch self.accumulated_loss = TensorRunningAccum( window_length=self.trainer.accumulate_grad_batches) # structured result accumulators for callbacks self.early_stopping_accumulator = Accumulator() self.checkpoint_accumulator = Accumulator() # hook self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx): # hook self.trainer.call_hook('on_batch_end') self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx) # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs) # reset batch logger internals self.trainer.logger_connector.on_train_batch_end() def reset_train_val_dataloaders(self, model): if not self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_train_dataloader(model) if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_val_dataloader(model) def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs): # track the outputs to reduce at the end of the epoch for opt_idx, opt_outputs in enumerate(epoch_end_outputs): # with 1 step (no tbptt) don't use a sequence at epoch end if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance( opt_outputs[0], Result): opt_outputs = opt_outputs[0] epoch_output[opt_idx].append(opt_outputs) def get_optimizers_iterable(self): """ Generates an iterable with (idx, optimizer) for each optimizer. 
""" if not self.trainer.optimizer_frequencies: # call training_step once per optimizer return list(enumerate(self.trainer.optimizers)) optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) optimizers_loop_length = optimizer_freq_cumsum[-1] current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length # find optimzier index by looking for the first {item > current_place} in the cumsum list opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) return [[opt_idx, self.trainer.optimizers[opt_idx]]] def on_after_backward(self, training_step_output, batch_idx, untouched_loss): is_result_obj = isinstance(training_step_output, Result) if is_result_obj: training_step_output.detach() else: training_step_output.batch_loss = training_step_output.batch_loss.detach( ) # insert after step hook self.trainer.call_hook("on_after_backward") # when in dev debugging track the losses self.trainer.dev_debugger.track_train_loss_history( batch_idx, untouched_loss.detach()) def _check_training_step_output(self, training_step_output): if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization: if training_step_output.grad_fn is None: # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... raise MisconfigurationException( "In manual optimization, `training_step` should not return a Tensor" ) def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging model_ref = self.trainer.get_model() with self.trainer.profiler.profile("model_forward"): args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) # manually capture logged metrics model_ref._current_fx_name = 'training_step' model_ref._results = Result() training_step_output = self.trainer.accelerator_backend.training_step( args) self.trainer.logger_connector.cache_logged_metrics() self._check_training_step_output(training_step_output) training_step_output = self.trainer.call_hook( "training_step_end", training_step_output) training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( training_step_output, split_batch) is_result_obj = isinstance(training_step_output, Result) if training_step_output_for_epoch_end is None: return None # enable empty loss when using manual opt closure_loss = None untouched_loss = None if self.trainer.train_loop.automatic_optimization: # accumulate loss # (if accumulate_grad_batches = 1 no effect) if is_result_obj: closure_loss = training_step_output.minimize else: closure_loss = training_step_output.batch_loss closure_loss = closure_loss / self.trainer.accumulate_grad_batches # the loss will get scaled for amp. 
avoid any modifications to it untouched_loss = closure_loss.detach().clone() # result result = AttributeDict( closure_loss=closure_loss, loss=untouched_loss, training_step_output=training_step_output, training_step_output_for_epoch_end= training_step_output_for_epoch_end, hiddens=training_step_output.hiddens, ) return result def _process_training_step_output(self, training_step_output, split_batch): training_step_output_for_epoch_end = training_step_output # enable validation_step return None if training_step_output_for_epoch_end is None: return None, None # ----------------------------------------- # process result return (DEPRECATE in 1.0) # ----------------------------------------- if isinstance(training_step_output, Result): training_step_output_for_epoch_end = self._process_result( training_step_output, split_batch) return training_step_output_for_epoch_end, training_step_output # ----------------------------------------- # process hybrid (1.0) # ----------------------------------------- # no need for these checks in 1.0.0 # TODO: remove checks in 1.0.0 is_tensor = isinstance(training_step_output_for_epoch_end, torch.Tensor) is_1_0_output = is_tensor or ("log" not in training_step_output and "progress_bar" not in training_step_output) if is_1_0_output: return self._process_training_step_output_1_0( training_step_output, split_batch) # ----------------------------------------- # process old dict (deprecate 1.0) # ----------------------------------------- training_step_output = self.trainer.process_dict_result( training_step_output, train=True) training_step_output = AttributeDict( batch_loss=training_step_output[0], pbar_on_batch_end=training_step_output[1], log_metrics=training_step_output[2], callback_metrics=training_step_output[3], hiddens=training_step_output[4], ) # if the user decides to finally reduce things in epoch_end, save raw output without graphs if isinstance(training_step_output_for_epoch_end, torch.Tensor): training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach( ) else: training_step_output_for_epoch_end = recursive_detach( training_step_output_for_epoch_end) return training_step_output_for_epoch_end, training_step_output def _process_training_step_output_1_0(self, training_step_output, split_batch): result = self.trainer.get_model()._results loss = None hiddens = None # handle dict return if isinstance(training_step_output, dict): loss = training_step_output.pop("loss", None) hiddens = training_step_output.pop("hiddens", None) result["extra"] = training_step_output # handle scalar return elif isinstance(training_step_output, torch.Tensor): loss = training_step_output result["extra"] = {} # map to results under the hood result.minimize = loss result.hiddens = hiddens # track batch for manual reduction with result result.track_batch_size(len(split_batch)) # track metrics without grads for epoch reduction training_step_output_for_epoch_end = copy(result) training_step_output_for_epoch_end.detach() if self.trainer.move_metrics_to_cpu: training_step_output_for_epoch_end.cpu() # what flows back into the system training_step_output = result return training_step_output_for_epoch_end, training_step_output def _process_result(self, training_step_output, split_batch): training_step_output.track_batch_size(len(split_batch)) m = """ TrainResult and EvalResult were deprecated in 0.9.1 and support will drop in 1.0.0. Use self.log and .write from the LightningModule to log metrics and write predictions. 
training_step can now only return a scalar (for the loss) or a dictionary with anything you want. Option 1: return loss Option 2: return {'loss': loss, 'anything_else': ...} Option 3: return {'loss': loss, 'hiddens': hiddens, 'anything_else': ...} """ rank_zero_warn(m) # don't allow EvalResult in the training_step if isinstance(training_step_output, EvalResult): raise MisconfigurationException( "training_step cannot return EvalResult, " "use a dict or TrainResult instead") training_step_output_for_epoch_end = copy(training_step_output) training_step_output_for_epoch_end.detach() return training_step_output_for_epoch_end def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): model_ref = self.trainer.get_model() is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) using_native_amp = self.trainer.amp_backend == AMPType.NATIVE # native amp + lbfgs is a no go right now if using_native_amp and is_lbfgs: raise MisconfigurationException( 'native PyTorch amp and lbfgs are not compatible.' ' To request, please file a Github issue in PyTorch and tag @mcarilli' ) # model hook model_ref.optimizer_step( self.trainer.current_epoch, batch_idx, optimizer, opt_idx, train_step_and_backward_closure, on_tpu=self.trainer.use_tpu and TPU_AVAILABLE, using_native_amp=using_native_amp, using_lbfgs=is_lbfgs, ) def on_before_zero_grad(self, optimizer): self.trainer.call_hook('on_before_zero_grad', optimizer) def track_and_norm_grad(self, optimizer): # track gradient norms grad_norm_dic = self._track_gradient_norm() # clip gradients self.trainer.accelerator_backend.clip_gradients(optimizer) self._cur_grad_norm_dict = grad_norm_dic def _track_gradient_norm(self): grad_norm_dict = {} if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: if float(self.trainer.track_grad_norm) > 0: model = self.trainer.get_model() grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) return grad_norm_dict def process_hiddens(self, opt_closure_result): hiddens = opt_closure_result.hiddens if isinstance(opt_closure_result.training_step_output, Result): opt_closure_result.training_step_output_for_epoch_end.drop_hiddens( ) return hiddens def tbptt_split_batch(self, batch): splits = [batch] if self.trainer.truncated_bptt_steps is not None: model_ref = self.trainer.get_model() with self.trainer.profiler.profile("tbptt_split_batch"): splits = model_ref.tbptt_split_batch( batch, self.trainer.truncated_bptt_steps) return splits def run_training_epoch(self): # get model model = self.trainer.get_model() # modify dataloader if needed (ddp, etc...) 
train_dataloader = self.trainer.accelerator_backend.process_dataloader( self.trainer.train_dataloader) # track epoch output epoch_output = [[] for _ in range(self.num_optimizers)] # enable profiling for the dataloader train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader( train_dataloader) dataloader_idx = 0 should_check_val = False for batch_idx, (batch, is_last_batch) in train_dataloader: self.trainer.batch_idx = batch_idx # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ with self.trainer.profiler.profile("run_training_batch"): batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx) # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: break # only track outputs when user implements training_epoch_end # otherwise we will build up unnecessary memory epoch_end_outputs = self.process_train_step_outputs( batch_output.training_step_output_for_epoch_end, self.early_stopping_accumulator, self.checkpoint_accumulator, ) # hook # TODO: add outputs to batches self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx) # ----------------------------------------- # SAVE METRICS TO LOGGERS # ----------------------------------------- self.trainer.logger_connector.log_train_step_metrics(batch_output) # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- should_check_val = self.should_check_val_fx( batch_idx, is_last_batch) if should_check_val: self.trainer.run_evaluation(test_mode=False) # reset stage to train self.trainer.logger_connector.set_stage("train") # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
# ----------------------------------------- self.save_loggers_on_train_batch_end() # update LR schedulers monitor_metrics = deepcopy( self.trainer.logger_connector.callback_metrics) self.update_train_loop_lr_schedulers( monitor_metrics=monitor_metrics) self.trainer.checkpoint_connector.has_trained = True # max steps reached, end training if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1: accumulation_done = self._accumulated_batches_reached() # Ensure accumulation across batches has completed before breaking loop if accumulation_done: break # end epoch early # stop when the flag is changed or we've gone past the amount # requested in the batches if self.trainer.should_stop: break self.trainer.total_batch_idx += 1 # stop epoch if we limited the number of training batches if (batch_idx + 1) >= self.trainer.num_training_batches: break # progress global step according to grads progress self.increment_accumulated_grad_global_step() # epoch end hook self.run_on_epoch_end_hook(epoch_output) # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics( epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers) # when no val loop is present or fast-dev-run still need to call checkpoints self.check_checkpoint_callback(not ( should_check_val or is_overridden('validation_step', model))) # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() def run_training_batch(self, batch, batch_idx, dataloader_idx): # track grad norms grad_norm_dic = {} # bookkeeping using_results_obj = False self.trainer.hiddens = None # track all outputs across time and num of optimizers batch_outputs = [[] for _ in range(len(self.get_optimizers_iterable()))] if batch is None: return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) # hook response = self.trainer.call_hook("on_batch_start") if response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) # hook response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx) if response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) # lightning module hook splits = self.tbptt_split_batch(batch) for split_idx, split_batch in enumerate(splits): # create an iterable for optimizers and loop over them for opt_idx, optimizer in self.prepare_optimizers(): # toggle model params + set info to logger_connector self.run_train_split_start(split_idx, split_batch, opt_idx, optimizer) if self.should_accumulate(): # For gradient accumulation # ------------------- # calculate loss (train step + train step end) # ------------------- # automatic_optimization=True: perform dpp sync only when performing optimizer_step # automatic_optimization=False: don't block synchronization here with self.block_ddp_sync_behaviour(): self.training_step_and_backward( split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) batch_outputs = self._process_closure_result( batch_outputs=batch_outputs, opt_idx=opt_idx, ) # ------------------------------ # BACKWARD PASS # ------------------------------ # gradient update with accumulated gradients else: if self.automatic_optimization: def train_step_and_backward_closure(): result = self.training_step_and_backward( split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) return None if result is None else result.loss # optimizer step self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) 
else: self._curr_step_result = self.training_step( split_batch, batch_idx, opt_idx, self.trainer.hiddens) if self._curr_step_result is None: # user decided to skip optimization # make sure to zero grad. continue batch_outputs = self._process_closure_result( batch_outputs=batch_outputs, opt_idx=opt_idx, ) # todo: Properly aggregate grad_norm accros opt_idx and split_idx grad_norm_dic = self._cur_grad_norm_dict self._cur_grad_norm_dict = None # update running loss + reset accumulated loss self.update_running_loss() result = AttributeDict( signal=0, grad_norm_dic=grad_norm_dic, training_step_output_for_epoch_end=batch_outputs, ) return result @contextmanager def block_ddp_sync_behaviour(self): """ automatic_optimization = True Blocks ddp sync gradients behaviour on backwards pass. This is useful for skipping sync when accumulating gradients, reducing communication overhead automatic_optimization = False do not block ddp gradient sync when using manual optimization as gradients are needed within the training step Returns: context manager with sync behaviour off """ if self.trainer.accelerator_backend is not None and self.automatic_optimization: yield self.trainer.accelerator_backend.block_ddp_plugin_sync_behaviour( ) else: yield None def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: opt_closure_result = self._curr_step_result if opt_closure_result is not None: # cache metrics self.trainer.logger_connector.cache_training_step_metrics( opt_closure_result) # track hiddens self.trainer.hiddens = self.process_hiddens(opt_closure_result) # check if loss or model weights are nan if self.trainer.terminate_on_nan: self.trainer.detect_nan_tensors(opt_closure_result.loss) # track all the outputs across all steps batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 batch_outputs[batch_opt_idx].append( opt_closure_result.training_step_output_for_epoch_end) if self.automatic_optimization: # track total loss for logging (avoid mem leaks) self.accumulated_loss.append(opt_closure_result.loss) self._curr_step_result = None return batch_outputs def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): """ wrap the forward step in a closure so second order methods work """ with self.trainer.profiler.profile("training_step_and_backward"): # lightning module hook result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) self._curr_step_result = result if result is None: self.warning_cache.warn( "training_step returned None if it was on purpose, ignore this warning..." 
) return None if self.trainer.train_loop.automatic_optimization: # backward pass with self.trainer.profiler.profile("model_backward"): self.backward(result, optimizer, opt_idx) # hook - call this hook only # when gradients have finished to accumulate if not self.should_accumulate(): self.on_after_backward(result.training_step_output, batch_idx, result.loss) # check if loss or model weights are nan if self.trainer.terminate_on_nan: self.trainer.detect_nan_tensors(result.loss) return result def backward(self, result, optimizer, opt_idx, *args, **kwargs): self.trainer.dev_debugger.track_event("backward_call") # backward can be called manually in the training loop if isinstance(result, torch.Tensor): self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, *args, **kwargs) else: result.closure_loss = self.trainer.accelerator_backend.backward( result.closure_loss, optimizer, opt_idx, *args, **kwargs) if not self.should_accumulate(): # track gradients self.track_and_norm_grad(optimizer=optimizer) def update_train_loop_lr_schedulers(self, monitor_metrics=None): num_accumulated_batches_reached = self._accumulated_batches_reached() num_training_batches_reached = self._num_training_batches_reached() if num_accumulated_batches_reached or num_training_batches_reached: # update lr self.trainer.optimizer_connector.update_learning_rates( interval="step", monitor_metrics=monitor_metrics) def run_on_epoch_end_hook(self, epoch_output): # inform logger the batch loop has finished self.trainer.logger_connector.on_train_epoch_end() self.trainer.call_hook('on_epoch_end') self.trainer.call_hook('on_train_epoch_end', epoch_output) def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = self._accumulated_batches_reached() num_training_batches_reached = self._num_training_batches_reached() # progress global step according to grads progress if num_accumulated_batches_reached or num_training_batches_reached: self.trainer.global_step += 1 def _accumulated_batches_reached(self): return (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 def _num_training_batches_reached(self): return (self.trainer.batch_idx + 1) == self.trainer.num_training_batches def should_accumulate(self): # checks if backward or backward + optimizer step (via closure) accumulation_done = self._accumulated_batches_reached() is_final_batch = self._num_training_batches_reached() return not (accumulation_done or is_final_batch) def should_check_val_fx(self, batch_idx, is_last_batch): # decide if we should run validation is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 can_check_val = self.trainer.enable_validation and is_val_check_epoch should_check_val = is_val_check_batch or self.trainer.should_stop is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float( "inf") should_check_val = can_check_val and ( should_check_val or is_last_batch_for_infinite_dataset) return should_check_val def build_train_args(self, batch, batch_idx, opt_idx, hiddens): # enable not needing to add opt_idx to training_step args = [batch, batch_idx] if len(self.trainer.optimizers) > 1: if self.trainer.has_arg("training_step", "optimizer_idx"): args.append(opt_idx) else: num_opts = len(self.trainer.optimizers) raise ValueError( f"Your LightningModule defines {num_opts} optimizers but " f'training_step is missing the "optimizer_idx" argument.') # pass hiddens if 
using tbptt if self.trainer.truncated_bptt_steps is not None: args.append(hiddens) return args def save_loggers_on_train_batch_end(self): # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs if should_flush_logs or self.trainer.fast_dev_run is True: if self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator): """ Figure out what needs to be tracked/logged at the end of the epoch """ # the training step outputs a list per optimizer. The list contains the outputs at each time step # when no TBPTT is used, then the list has 1 item per batch # when TBPTT IS used, then the list has n items (1 per time step) epoch_end_outputs = [] for optimizer_idx_outputs in all_train_step_outputs: # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer if len(optimizer_idx_outputs) == 0: continue sample_output = optimizer_idx_outputs[-1] # pull out callback info if available (ie: Results object) if isinstance(sample_output, dict) and "early_stop_on" in sample_output: early_stopping_accumulator.accumulate( sample_output["early_stop_on"]) if isinstance(sample_output, dict) and "checkpoint_on" in sample_output: checkpoint_accumulator.accumulate( sample_output["checkpoint_on"]) # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance( sample_output, Result) and sample_output.should_reduce_on_epoch_end # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end if is_overridden( "training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result: epoch_end_outputs.append(optimizer_idx_outputs) return epoch_end_outputs def prepare_optimizers(self): # in manual optimization we loop over all optimizers at once optimizers = self.get_optimizers_iterable() if not self.automatic_optimization: optimizers = [optimizers[0]] return optimizers def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer): # set split_idx to trainer for tracking self.trainer.split_idx = split_idx # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. if self.automatic_optimization and len(self.trainer.optimizers) > 1: model = self.trainer.get_model() model.toggle_optimizer(optimizer, opt_idx) # use to track metrics internally self.trainer.logger_connector.on_train_split_start( split_idx, opt_idx, split_batch) def update_running_loss(self): accumulated_loss = self.accumulated_loss.mean() if accumulated_loss is not None: # calculate running loss for display self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) # reset for next set of accumulated grads self.accumulated_loss.reset()
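# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the gradient
# accumulation bookkeeping used by TrainLoop.should_accumulate() and
# increment_accumulated_grad_global_step() above. The helper name and the
# standalone __main__ check below are hypothetical; they only restate the
# (batch_idx + 1) arithmetic from the loop in a self-contained form.
# ---------------------------------------------------------------------------
def _sketch_should_accumulate(batch_idx: int, accumulate_grad_batches: int, num_training_batches: int) -> bool:
    # keep accumulating unless the accumulation window is full
    # or this is the final batch of the epoch
    accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
    is_final_batch = (batch_idx + 1) == num_training_batches
    return not (accumulation_done or is_final_batch)


if __name__ == "__main__":
    # with accumulate_grad_batches=4 and 10 batches per epoch, the optimizer
    # (and the global step) advances only after batches 3, 7 and 9 (0-indexed)
    stepped = [i for i in range(10) if not _sketch_should_accumulate(i, 4, 10)]
    assert stepped == [3, 7, 9]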
class EvaluationLoop(object):
    def __init__(self, trainer):
        self.trainer = trainer
        self.testing = False
        self.outputs = []
        self.step_metrics = []
        self.predictions = None
        self.max_batches = None
        self.warning_cache = WarningCache()
        self.num_dataloaders = None

    def on_trainer_init(self):
        self.trainer.num_val_batches = []
        self.trainer.num_sanity_val_batches = []
        self.trainer.num_test_batches = []
        self.trainer.test_dataloaders = None
        self.trainer.val_dataloaders = None
        self.trainer.running_sanity_check = False
        self.trainer.testing = False

        # when .test() is called, it sets this
        self.trainer.tested_ckpt_path = None

        # when true, prints test results
        self.trainer.verbose_test = True

    def get_evaluation_dataloaders(self, max_batches):
        model = self.trainer.get_model()

        # select dataloaders
        if self.testing:
            self.trainer.reset_test_dataloader(model)
            dataloaders = self.trainer.test_dataloaders
            new_max_batches = self.trainer.num_test_batches
        else:
            # val
            in_sanity_check = self.trainer.running_sanity_check
            should_reload_every_epoch = self.trainer.reload_dataloaders_every_epoch
            if (self.trainer.val_dataloaders is None or should_reload_every_epoch) and not in_sanity_check:
                self.trainer.reset_val_dataloader(model)
            dataloaders = self.trainer.val_dataloaders
            new_max_batches = self.trainer.num_val_batches

        if max_batches is None:
            max_batches = new_max_batches

        return dataloaders, max_batches

    def should_skip_evaluation(self, dataloaders, max_batches):
        # skip when dataloaders aren't defined
        if dataloaders is None:
            return True

        # enable disabling validation step with limit_val_batches = 0
        should_skip = sum(max_batches) == 0
        if should_skip:
            return True

        return False

    def on_evaluation_start(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_start', *args, **kwargs)

    def on_evaluation_model_eval(self, *args, **kwargs):
        model_ref = self.trainer.get_model()
        if self.testing:
            model_ref.on_test_model_eval()
        else:
            model_ref.on_validation_model_eval()

    def on_evaluation_model_train(self, *args, **kwargs):
        model_ref = self.trainer.get_model()
        if self.testing:
            model_ref.on_test_model_train()
        else:
            model_ref.on_validation_model_train()

    def on_evaluation_end(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_end', *args, **kwargs)

        # reset stage to train
        self.trainer.logger_connector.set_stage("train")

    def reload_evaluation_dataloaders(self):
        model = self.trainer.get_model()
        if self.testing:
            self.trainer.reset_test_dataloader(model)
        else:
            self.trainer.reset_val_dataloader(model)

    def is_using_eval_results(self):
        outputs = self.outputs
        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
        return using_eval_result

    def setup(self, model, max_batches, dataloaders):
        # copy properties for forward overrides
        self.trainer.model_connector.copy_trainer_model_properties(model)

        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches
        self.num_dataloaders = self._get_num_dataloaders(dataloaders)

    def on_evaluation_epoch_start(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)

    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
        # make dataloader_idx arg in validation_step optional
        args = [batch, batch_idx]

        multiple_val_loaders = (
            not test_mode and self._get_num_dataloaders(self.trainer.val_dataloaders) > 1)
        multiple_test_loaders = (
            test_mode and self._get_num_dataloaders(self.trainer.test_dataloaders) > 1)

        if multiple_test_loaders or multiple_val_loaders:
            args.append(dataloader_idx)

        return args

    def _get_num_dataloaders(self, dataloaders):
        # case where user does:
        # return dl1, dl2
        length = len(dataloaders)
        if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length

    def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
        # configure args
        args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)

        # run actual test step
        if self.testing:
            output = self.trainer.accelerator_backend.test_step(args)
        else:
            output = self.trainer.accelerator_backend.validation_step(args)

        # track batch size for weighted average
        is_result_obj = isinstance(output, Result)
        if is_result_obj:
            output.track_batch_size(batch)

        # allow only EvalResult when using structured results (from val_step)
        if is_result_obj and not isinstance(output, EvalResult):
            m = 'only EvalResults or dicts are allowed from validation_step'
            raise MisconfigurationException(m)

        return output

    def evaluation_step_end(self, *args, **kwargs):
        if self.testing:
            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
        else:
            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
        return output

    def evaluation_epoch_end(self, num_dataloaders):
        using_eval_result = self.is_using_eval_results()

        # call the model epoch end
        deprecated_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)

        # 1.0
        epoch_logs = self.trainer.get_model()._results

        # enable returning anything
        for i, r in enumerate(deprecated_results):
            if not isinstance(r, (dict, Result, torch.Tensor)):
                deprecated_results[i] = []

        return deprecated_results, epoch_logs

    def log_epoch_metrics(self, deprecated_eval_results, epoch_logs, test_mode):
        using_eval_result = self.is_using_eval_results()
        eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end(
            deprecated_eval_results, epoch_logs, using_eval_result, test_mode)
        return eval_loop_results

    def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
        model = self.trainer.get_model()

        # reset results
        model._results = Result()

        # with a single dataloader don't pass an array
        outputs = self.outputs
        eval_results = outputs
        if num_dataloaders == 1:
            eval_results = outputs[0]

        user_reduced = False

        if self.testing:
            if is_overridden('test_epoch_end', model=model):
                model._current_fx_name = 'test_epoch_end'
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(outputs)
                eval_results = model.test_epoch_end(eval_results)
                user_reduced = True
        else:
            if is_overridden('validation_epoch_end', model=model):
                model._current_fx_name = 'validation_epoch_end'
                if using_eval_result:
                    eval_results = self.__gather_epoch_end_eval_results(outputs)
                eval_results = model.validation_epoch_end(eval_results)
                user_reduced = True

        # deprecation warning
        if eval_results is not None and user_reduced:
            step = 'testing_epoch_end' if self.testing else 'validation_epoch_end'
            self.warning_cache.warn(
                f'The {step} should not return anything as of 0.9.1.'
                ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
            )

        if using_eval_result and not user_reduced:
            eval_results = self.__auto_reduce_result_objs(outputs)

        if not isinstance(eval_results, list):
            eval_results = [eval_results]

        return eval_results

    def __gather_epoch_end_eval_results(self, outputs):
        eval_results = []
        for epoch_output in outputs:
            result = epoch_output[0].__class__.gather(epoch_output)
            if 'checkpoint_on' in result:
                result.checkpoint_on = result.checkpoint_on.mean()
            if 'early_stop_on' in result:
                result.early_stop_on = result.early_stop_on.mean()
            eval_results.append(result)

        # with 1 dataloader don't pass in a list
        if len(eval_results) == 1:
            eval_results = eval_results[0]
        return eval_results

    def __auto_reduce_result_objs(self, outputs):
        # outputs has a list of results per dataloader
        eval_results = []
        for dl_output in outputs:
            result = dl_output[0]
            result = result.__class__.reduce_on_epoch_end(dl_output)
            if 'checkpoint_on' in result:
                result.checkpoint_on = result.checkpoint_on.mean()
            if 'early_stop_on' in result:
                result.early_stop_on = result.early_stop_on.mean()
            eval_results.append(result)

        return eval_results

    def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx):
        # reset the result of the PL module
        model = self.trainer.get_model()
        model._results = Result()
        model._current_fx_name = 'evaluation_step'

        # set dataloader_idx and track batch_size
        self.trainer.logger_connector.on_evaluation_batch_start(
            self.testing, batch, dataloader_idx, self.num_dataloaders)

        if self.testing:
            self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
        else:
            self.trainer.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx)

    def on_evaluation_batch_end(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_batch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)

    def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
        # Add step predictions to prediction collection to write later
        if output is not None:
            do_write_predictions = isinstance(output, Result) and self.testing
            if do_write_predictions:
                self.predictions.add(output.pop('predictions', None))

        # track debug metrics
        self.trainer.dev_debugger.track_eval_loss_history(
            self.testing, batch_idx, dataloader_idx, output)

    def on_evaluation_epoch_end(self, *args, **kwargs):
        # call the callback hook
        if self.testing:
            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

    def log_evaluation_step_metrics(self, batch, batch_idx):
        results = self.trainer.get_model()._results
        if len(results) == 1:
            return None

        results.track_batch_size(batch)
        self.__log_result_step_metrics(results, batch_idx)

        return results

    # TODO: deprecate at 1.0
    def log_evaluation_step_metrics_legacy(self, output, batch_idx):
        if self.trainer.running_sanity_check:
            return

        if isinstance(output, EvalResult):
            self.__log_result_step_metrics(output, batch_idx)

    def __log_result_step_metrics(self, output, batch_idx):
        step_log_metrics = output.get_batch_log_metrics(include_forked_originals=False)
        step_pbar_metrics = output.get_batch_pbar_metrics(include_forked_originals=False)
        cached_batch_log_metrics = \
            self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics()

        if len(step_log_metrics) > 0:
            # make the metrics appear as a different line in the same graph
            metrics_by_epoch = {}
            for k, v in step_log_metrics.items():
                metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v

            self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)

        if len(step_pbar_metrics) > 0:
            self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics)
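# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the argument-building
# rule in EvaluationLoop.build_args above. `dataloader_idx` is only appended
# when more than one val/test dataloader is configured, which is what keeps a
# plain `validation_step(self, batch, batch_idx)` signature working in the
# single-dataloader case. The helper name below is hypothetical.
# ---------------------------------------------------------------------------
def _sketch_build_eval_args(batch, batch_idx, dataloader_idx, num_dataloaders):
    args = [batch, batch_idx]
    if num_dataloaders > 1:
        args.append(dataloader_idx)
    return args


if __name__ == "__main__":
    assert _sketch_build_eval_args("batch", 0, 0, num_dataloaders=1) == ["batch", 0]
    assert _sketch_build_eval_args("batch", 0, 1, num_dataloaders=2) == ["batch", 0, 1]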
class WandbLogger(LightningLoggerBase):
    r"""
    Log using `Weights and Biases <https://www.wandb.com/>`_.

    Install it with pip:

    .. code-block:: bash

        pip install wandb

    Args:
        name: Display name for the run.
        save_dir: Path where data is saved.
        offline: Run offline (data can be streamed later to wandb servers).
        id: Sets the version, mainly used to resume a previous run.
        anonymous: Enables or explicitly disables anonymous logging.
        version: Sets the version, mainly used to resume a previous run.
        project: The name of the project to which this run will belong.
        log_model: Save checkpoints in wandb dir to upload on W&B servers.
        experiment: WandB experiment object.
        prefix: A string to put at the beginning of metric keys.
        \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
            :func:`wandb.init` can be passed as keyword arguments in this logger.

    Example::

        .. code-block:: python

            from pytorch_lightning.loggers import WandbLogger
            from pytorch_lightning import Trainer

            wandb_logger = WandbLogger()
            trainer = Trainer(logger=wandb_logger)

    Note: When logging manually through `wandb.log` or `trainer.logger.experiment.log`,
    make sure to use `commit=False` so the logging step does not increase.

    See Also:
        - `Tutorial <https://app.wandb.ai/cayush/pytorchlightning/reports/
          Use-Pytorch-Lightning-with-Weights-%26-Biases--Vmlldzo2NjQ1Mw>`__
          on how to use W&B with Pytorch Lightning.

    """

    LOGGER_JOIN_CHAR = '-'

    def __init__(self,
                 name: Optional[str] = None,
                 save_dir: Optional[str] = None,
                 offline: bool = False,
                 id: Optional[str] = None,
                 anonymous: bool = False,
                 version: Optional[str] = None,
                 project: Optional[str] = None,
                 log_model: bool = False,
                 experiment=None,
                 prefix: str = '',
                 **kwargs):
        if wandb is None:
            raise ImportError(
                'You want to use `wandb` logger which is not installed yet,'  # pragma: no-cover
                ' install it with `pip install wandb`.')
        super().__init__()
        self._name = name
        self._save_dir = save_dir
        self._anonymous = 'allow' if anonymous else None
        self._id = version or id
        self._project = project
        self._experiment = experiment
        self._offline = offline
        self._log_model = log_model
        self._prefix = prefix
        self._kwargs = kwargs
        # logging multiple Trainer on a single W&B run (k-fold, resuming, etc)
        self._step_offset = 0
        self.warning_cache = WarningCache()

    def __getstate__(self):
        state = self.__dict__.copy()
        # args needed to reload correct experiment
        state['_id'] = self._experiment.id if self._experiment is not None else None
        # cannot be pickled
        state['_experiment'] = None
        return state

    @property
    @rank_zero_experiment
    def experiment(self) -> Run:
        r"""
        Actual wandb object. To use wandb features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

        Example::

            self.logger.experiment.some_wandb_function()

        """
        if self._experiment is None:
            if self._offline:
                os.environ['WANDB_MODE'] = 'dryrun'
            self._experiment = wandb.init(
                name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
                id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run

            # offset logging step when resuming a run
            self._step_offset = self._experiment.step

            # save checkpoints in wandb dir to upload on W&B servers
            if self._log_model:
                self._save_dir = self._experiment.dir
        return self._experiment

    def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
        self.experiment.watch(model, log=log, log_freq=log_freq)

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
        params = self._convert_params(params)
        params = self._flatten_dict(params)
        params = self._sanitize_callable_params(params)
        self.experiment.config.update(params, allow_val_change=True)

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

        metrics = self._add_prefix(metrics)
        if step is not None and step + self._step_offset < self.experiment.step:
            self.warning_cache.warn(
                'Trying to log at a previous step. Use `commit=False` when logging metrics manually.')
        self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)

    @property
    def save_dir(self) -> Optional[str]:
        return self._save_dir

    @property
    def name(self) -> Optional[str]:
        # don't create an experiment if we don't have one
        return self._experiment.project_name() if self._experiment else self._name

    @property
    def version(self) -> Optional[str]:
        # don't create an experiment if we don't have one
        return self._experiment.id if self._experiment else self._id

    @rank_zero_only
    def finalize(self, status: str) -> None:
        # offset future training logged on same W&B run
        if self._experiment is not None:
            self._step_offset = self._experiment.step

        # upload all checkpoints from saving dir
        if self._log_model:
            wandb.save(os.path.join(self.save_dir, "*.ckpt"))
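# ---------------------------------------------------------------------------
# Illustrative usage sketch (kept commented out; it mirrors the WandbLogger
# docstring above rather than adding new behaviour): attach the logger to a
# Trainer, and pass `commit=False` when logging manually through
# `trainer.logger.experiment.log` so the W&B step does not advance on its own.
# `project="my-project"` and `model` are placeholder names.
# ---------------------------------------------------------------------------
# from pytorch_lightning import Trainer
# from pytorch_lightning.loggers import WandbLogger
#
# wandb_logger = WandbLogger(project="my-project", offline=True)
# trainer = Trainer(logger=wandb_logger)
# trainer.fit(model)
#
# # manual logging alongside the automatic metrics:
# wandb_logger.experiment.log({"custom_metric": 0.5}, commit=False)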
class TrainLoop: def __init__(self, trainer): self.trainer = trainer self.early_stopping_accumulator = None self.checkpoint_accumulator = None self.accumulated_loss = None self.warning_cache = WarningCache() self._teardown_already_run = False self.running_loss = TensorRunningAccum(window_length=20) def on_trainer_init(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps): self.trainer.global_step = 0 self.trainer.current_epoch = 0 self.trainer.interrupted = False self.trainer.should_stop = False self.trainer._state = TrainerState.INITIALIZING self.trainer.total_batch_idx = 0 self.trainer.batch_idx = 0 self.trainer.num_training_batches = 0 self.trainer.train_dataloader = None self.trainer.max_epochs = max_epochs self.trainer.min_epochs = min_epochs self.trainer.max_steps = max_steps self.trainer.min_steps = min_steps if num_sanity_val_steps == -1: self.trainer.num_sanity_val_steps = float('inf') else: self.trainer.num_sanity_val_steps = num_sanity_val_steps @property def num_optimizers(self): num_optimizers = len(self.get_optimizers_iterable()) return num_optimizers def on_train_start(self): # clear cache before training if self.trainer.on_gpu and self.trainer.root_gpu is not None: # use context because of: # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898 with torch.cuda.device(f'cuda:{self.trainer.root_gpu}'): torch.cuda.empty_cache() # hook self.trainer.call_hook('on_train_start') def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): # bind logger and other properties self.trainer.model_connector.copy_trainer_model_properties(model) # clean hparams if hasattr(model, 'hparams'): parsing.clean_namespace(model.hparams) # links data to the trainer self.trainer.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule) # check that model is configured correctly self.trainer.config_validator.verify_loop_configurations(model) def setup_training(self, model: LightningModule): """Sanity check a few things before starting actual training. Args: model: The model to run sanity test on. """ # -------------------------- # Setup?? # -------------------------- ref_model = model if self.trainer.data_parallel: ref_model = model.module self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank self.trainer.accelerator_backend.dist.device = ref_model.device # give model convenience properties ref_model.trainer = self.trainer # set local properties on the model self.trainer.model_connector.copy_trainer_model_properties(ref_model) # init amp. 
Must be done here instead of __init__ to allow ddp to work if self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16 and not self.trainer.use_tpu: self.trainer.scaler = torch.cuda.amp.GradScaler() # log hyper-parameters if self.trainer.logger is not None: # save exp to get started self.trainer.logger.log_hyperparams(ref_model.hparams) self.trainer.logger.log_graph(ref_model) self.trainer.logger.save() # wait for all to join if on distributed self.trainer.accelerator_backend.barrier('setup_training') # register auto-resubmit when on SLURM self.trainer.slurm_connector.register_slurm_signal_handlers() # -------------------------- # Pre-train # -------------------------- # on pretrain routine start self.trainer.on_pretrain_routine_start(ref_model) if self.trainer.is_function_implemented('on_pretrain_routine_start'): ref_model.on_pretrain_routine_start() # print model summary if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing: if self.trainer.weights_summary in ModelSummary.MODES: ref_model.summarize(mode=self.trainer.weights_summary) else: raise MisconfigurationException( "weights_summary can be None, " + ", ".join(ModelSummary.MODES)) # track model now. # if cluster resets state, the model will update with the saved weights self.trainer.model = model # restore training and model before hpc is called self.trainer.checkpoint_connector.restore_weights(model) # on pretrain routine end self.trainer.on_pretrain_routine_end(ref_model) if self.trainer.is_function_implemented('on_pretrain_routine_end'): ref_model.on_pretrain_routine_end() def on_train_end(self): if self._teardown_already_run: return self._teardown_already_run = True # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates # when a checkpoint was saved at the last step self.trainer.global_step -= 1 self.check_checkpoint_callback(should_save=True, is_last=True) self.trainer.global_step += 1 # hook self.trainer.call_hook('on_train_end') # kill loggers if self.trainer.logger is not None: self.trainer.logger.finalize("success") # summarize profile results if self.trainer.global_rank == 0: self.trainer.profiler.describe() # give accelerators a chance to finish self.trainer.accelerator_backend.on_train_end() # clear mem if self.trainer.on_gpu: model = self.trainer.get_model() model.cpu() torch.cuda.empty_cache() def check_checkpoint_callback(self, should_save, is_last=False): # TODO bake this logic into the checkpoint callback if should_save: checkpoint_callbacks = [ c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint) ] if is_last and any(c.save_last for c in checkpoint_callbacks): rank_zero_info('Saving latest checkpoint...') model = self.trainer.get_model() [ c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks ] def on_train_epoch_start(self, epoch): model = self.trainer.get_model() # set seed for distributed sampler (enables shuffling for each epoch) try: self.trainer.train_dataloader.sampler.set_epoch(epoch) except Exception: pass # update training progress in trainer self.trainer.current_epoch = epoch # changing gradient according accumulation_scheduler self.trainer.accumulation_scheduler.on_epoch_start( self.trainer, self.trainer.get_model()) # stores accumulated grad fractions per batch self.accumulated_loss = TensorRunningAccum( window_length=self.trainer.accumulate_grad_batches) # structured result accumulators for callbacks self.early_stopping_accumulator = Accumulator() 
self.checkpoint_accumulator = Accumulator() # hook self.trainer.call_hook('on_epoch_start') self.trainer.call_hook('on_train_epoch_start') def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx): # figure out what to track for epoch end self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs) # hook self.trainer.call_hook('on_batch_end') self.trainer.call_hook('on_train_batch_end', batch, batch_idx, dataloader_idx) def reset_train_val_dataloaders(self, model): if not self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_train_dataloader(model) if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_val_dataloader(model) def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs): # track the outputs to reduce at the end of the epoch for opt_idx, opt_outputs in enumerate(epoch_end_outputs): # with 1 step (no tbptt) don't use a sequence at epoch end if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance( opt_outputs[0], Result): opt_outputs = opt_outputs[0] epoch_output[opt_idx].append(opt_outputs) def get_optimizers_iterable(self): """ Generates an iterable with (idx, optimizer) for each optimizer. """ if not self.trainer.optimizer_frequencies: # call training_step once per optimizer return list(enumerate(self.trainer.optimizers)) optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) optimizers_loop_length = optimizer_freq_cumsum[-1] current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length # find optimzier index by looking for the first {item > current_place} in the cumsum list opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) return [(opt_idx, self.trainer.optimizers[opt_idx])] def backward(self, result, optimizer, opt_idx): # backward pass with self.trainer.profiler.profile('model_backward'): result.closure_loss = self.trainer.accelerator_backend.backward( result.closure_loss, optimizer, opt_idx) def on_after_backward(self, training_step_output, batch_idx, untouched_loss): is_result_obj = isinstance(training_step_output, Result) if is_result_obj: training_step_output.detach() else: training_step_output.batch_loss = training_step_output.batch_loss.detach( ) # insert after step hook self.trainer.call_hook('on_after_backward') # when in dev debugging track the losses self.trainer.dev_debugger.track_train_loss_history( batch_idx, untouched_loss.detach()) def training_step(self, split_batch, batch_idx, opt_idx, hiddens): # give the PL module a result for logging model = self.trainer.get_model() model._results = Result() model._current_fx_name = 'training_step' with self.trainer.profiler.profile('model_forward'): args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens) training_step_output = self.trainer.accelerator_backend.training_step( args) training_step_output = self.trainer.call_hook( 'training_step_end', training_step_output) training_step_output_for_epoch_end, training_step_output = self._process_training_step_output( training_step_output, split_batch) is_result_obj = isinstance(training_step_output, Result) if training_step_output_for_epoch_end is None: return None # accumulate loss # (if accumulate_grad_batches = 1 no effect) if is_result_obj: closure_loss = training_step_output.minimize else: closure_loss = training_step_output.batch_loss closure_loss = closure_loss / self.trainer.accumulate_grad_batches # the loss will get scaled for amp. 
avoid any modifications to it untouched_loss = closure_loss.detach().clone() # result result = AttributeDict( closure_loss=closure_loss, loss=untouched_loss, training_step_output=training_step_output, training_step_output_for_epoch_end= training_step_output_for_epoch_end, hiddens=training_step_output.hiddens, ) return result def _process_training_step_output(self, training_step_output, split_batch): training_step_output_for_epoch_end = training_step_output # enable validation_step return None if training_step_output_for_epoch_end is None: return None, None # ----------------------------------------- # process result return (DEPRECATE in 1.0) # ----------------------------------------- if isinstance(training_step_output, Result): training_step_output_for_epoch_end = self._process_result( training_step_output, split_batch) return training_step_output_for_epoch_end, training_step_output # ----------------------------------------- # process hybrid (1.0) # ----------------------------------------- # no need for these checks in 1.0.0 # TODO: remove checks in 1.0.0 is_tensor = isinstance(training_step_output_for_epoch_end, torch.Tensor) is_1_0_output = is_tensor or ('log' not in training_step_output and 'progress_bar' not in training_step_output) if is_1_0_output: return self._process_training_step_output_1_0( training_step_output, split_batch) # ----------------------------------------- # process old dict (deprecate 1.0) # ----------------------------------------- training_step_output = self.trainer.process_dict_result( training_step_output, train=True) training_step_output = AttributeDict( batch_loss=training_step_output[0], pbar_on_batch_end=training_step_output[1], log_metrics=training_step_output[2], callback_metrics=training_step_output[3], hiddens=training_step_output[4], ) # if the user decides to finally reduce things in epoch_end, save raw output without graphs if isinstance(training_step_output_for_epoch_end, torch.Tensor): training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach( ) else: training_step_output_for_epoch_end = recursive_detach( training_step_output_for_epoch_end) return training_step_output_for_epoch_end, training_step_output def _process_training_step_output_1_0(self, training_step_output, split_batch): result = self.trainer.get_model()._results loss = None hiddens = None # handle dict return if isinstance(training_step_output, dict): loss = training_step_output.pop('loss', None) hiddens = training_step_output.pop('hiddens', None) result['extra'] = training_step_output # handle scalar return elif isinstance(training_step_output, torch.Tensor): loss = training_step_output result['extra'] = {} # map to results under the hood result.minimize = loss result.hiddens = hiddens # track batch for manual reduction with result result.track_batch_size(len(split_batch)) # track metrics without grads for epoch reduction training_step_output_for_epoch_end = copy(result) training_step_output_for_epoch_end.detach() # what flows back into the system training_step_output = result return training_step_output_for_epoch_end, training_step_output def _process_result(self, training_step_output, split_batch): training_step_output.track_batch_size(len(split_batch)) m = """ TrainResult and EvalResult were deprecated in 0.9.1 and support will drop in 1.0.0. Use self.log and .write from the LightningModule to log metrics and write predictions. training_step can now only return a scalar (for the loss) or a dictionary with anything you want. 
Option 1: return loss Option 2: return {'loss': loss, 'anything_else': ...} Option 3: return {'loss': loss, 'hiddens': hiddens, 'anything_else': ...} """ rank_zero_warn(m) # don't allow EvalResult in the training_step if isinstance(training_step_output, EvalResult): raise MisconfigurationException( 'training_step cannot return EvalResult, ' 'use a dict or TrainResult instead') training_step_output_for_epoch_end = copy(training_step_output) training_step_output_for_epoch_end.detach() return training_step_output_for_epoch_end def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): with self.trainer.profiler.profile('optimizer_step'): # optimizer step lightningModule hook self.trainer.accelerator_backend.optimizer_step( optimizer, batch_idx, opt_idx, train_step_and_backward_closure) def on_before_zero_grad(self, optimizer): model = self.trainer.get_model() model.on_before_zero_grad(optimizer) def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): self.trainer.accelerator_backend.optimizer_zero_grad( batch_idx, optimizer, opt_idx) def on_before_backward(self, batch_idx, optimizer): # track gradient norms grad_norm_dic = self._track_gradient_norm() # clip gradients self.trainer.accelerator_backend.clip_gradients(optimizer) return grad_norm_dic def _track_gradient_norm(self): grad_norm_dict = {} if (self.trainer.global_step + 1) % self.trainer.row_log_interval == 0: if float(self.trainer.track_grad_norm) > 0: model = self.trainer.get_model() grad_norm_dict = model.grad_norm(self.trainer.track_grad_norm) return grad_norm_dict def log_training_step_metrics(self, opt_closure_result, batch_callback_metrics, batch_log_metrics): # track callback metrics callback_metrics = opt_closure_result.training_step_output.callback_metrics batch_callback_metrics.append(callback_metrics) # decide which metrics to log (results vs dict return) using_results_obj = isinstance(opt_closure_result.training_step_output, Result) if using_results_obj: metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics step_pbar_metrics = opt_closure_result.training_step_output.batch_pbar_metrics else: metrics_to_log = opt_closure_result.training_step_output.log_metrics step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end # track batch log metrics batch_log_metrics.append(metrics_to_log) # track progress bar metrics if len(step_pbar_metrics) > 0: self.trainer.logger_connector.add_progress_bar_metrics( step_pbar_metrics) self.trainer.logger_connector.callback_metrics.update( step_pbar_metrics) def process_hiddens(self, opt_closure_result): hiddens = opt_closure_result.hiddens if isinstance(opt_closure_result.training_step_output, Result): opt_closure_result.training_step_output_for_epoch_end.drop_hiddens( ) return hiddens def tbptt_split_batch(self, batch): splits = [batch] if self.trainer.truncated_bptt_steps is not None: model_ref = self.trainer.get_model() with self.trainer.profiler.profile('tbptt_split_batch'): splits = model_ref.tbptt_split_batch( batch, self.trainer.truncated_bptt_steps) return splits def run_training_epoch(self): # get model model = self.trainer.get_model() # modify dataloader if needed (ddp, etc...) 
train_dataloader = self.trainer.accelerator_backend.process_dataloader( self.trainer.train_dataloader) # track epoch output epoch_output = [[] for _ in range(self.num_optimizers)] # enable profiling for the dataloader train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader( train_dataloader) dataloader_idx = 0 should_check_val = False for batch_idx, (batch, is_last_batch) in train_dataloader: self.trainer.batch_idx = batch_idx # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx) # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: break # only track outputs when user implements training_epoch_end # otherwise we will build up unnecessary memory epoch_end_outputs = self.process_train_step_outputs( batch_output.training_step_output_for_epoch_end, self.early_stopping_accumulator, self.checkpoint_accumulator) # hook # TODO: add outputs to batches self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx) # ----------------------------------------- # SAVE METRICS TO LOGGERS # ----------------------------------------- self.trainer.logger_connector.log_train_step_metrics(batch_output) # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- should_check_val = self.should_check_val_fx( batch_idx, is_last_batch) if should_check_val: self.trainer.run_evaluation(test_mode=False) # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) # ----------------------------------------- self.save_loggers_on_train_batch_end() # update LR schedulers monitor_metrics = deepcopy( self.trainer.logger_connector.callback_metrics) monitor_metrics.update(batch_output.batch_log_metrics) self.update_train_loop_lr_schedulers( monitor_metrics=monitor_metrics) # max steps reached, end training if self.trainer.max_steps is not None and self.trainer.max_steps == self.trainer.global_step + 1: break # end epoch early # stop when the flag is changed or we've gone past the amount # requested in the batches if self.trainer.should_stop: break self.trainer.total_batch_idx += 1 # stop epoch if we limited the number of training batches if batch_idx + 1 >= self.trainer.num_training_batches: break # progress global step according to grads progress self.increment_accumulated_grad_global_step() # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics( epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers) # hook self.trainer.logger_connector.on_train_epoch_end(epoch_output) # when no val loop is present or fast-dev-run still need to call checkpoints self.check_checkpoint_callback(not ( should_check_val or is_overridden('validation_step', model))) # epoch end hook self.run_on_epoch_end_hook() # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() def run_training_batch(self, batch, batch_idx, dataloader_idx): # track grad norms grad_norm_dic = {} # track all metrics for callbacks batch_callback_metrics = [] # track metrics to log batch_log_metrics = [] # bookkeeping using_results_obj = False self.trainer.hiddens = None # track all outputs across time and num of optimizers batch_outputs = [[] for _ in range(len(self.get_optimizers_iterable()))] if batch is None: return 
AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) # hook response = self.trainer.call_hook('on_batch_start') if response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) # hook response = self.trainer.call_hook('on_train_batch_start', batch, batch_idx, dataloader_idx) if response == -1: return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic) # lightning module hook splits = self.tbptt_split_batch(batch) for split_idx, split_batch in enumerate(splits): self.trainer.split_idx = split_idx # loop over optimizers for opt_idx, optimizer in self.get_optimizers_iterable(): # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. if len(self.trainer.optimizers) > 1: for param in self.trainer.get_model().parameters(): param.requires_grad = False for group in optimizer.param_groups: for param in group['params']: param.requires_grad = True # ------------------- # calculate loss (train step + train step end) # ------------------- opt_closure_result = self.training_step_and_backward( split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens) if opt_closure_result is None: continue using_results_obj = isinstance( opt_closure_result.training_step_output, Result) # log metrics self.log_training_step_metrics(opt_closure_result, batch_callback_metrics, batch_log_metrics) # track hiddens self.trainer.hiddens = self.process_hiddens(opt_closure_result) # check if loss or model weights are nan if self.trainer.terminate_on_nan: self.trainer.detect_nan_tensors(opt_closure_result.loss) # track total loss for logging (avoid mem leaks) self.accumulated_loss.append(opt_closure_result.loss) # track all the outputs across all steps batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0 batch_outputs[batch_opt_idx].append( opt_closure_result.training_step_output_for_epoch_end) # ------------------------------ # BACKWARD PASS # ------------------------------ # gradient update with accumulated gradients accumulation_done = ( self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 is_final_batch = (self.trainer.batch_idx + 1) == self.trainer.num_training_batches if accumulation_done or is_final_batch: # hook grad_norm_dic = self.on_before_backward( batch_idx, optimizer) # wrap forward + backward pass in closure for 2nd order optimizers train_step_and_backward_closure = lambda: self.training_step_and_backward( split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens, ).loss # optimizer step self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure) # hook self.on_before_zero_grad(optimizer) # clear gradients self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) # calculate running loss for display self.running_loss.append( self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) # reset for next set of accumulated grads self.accumulated_loss.reset() # collapse all metrics into one dict batch_log_metrics = { k: v for d in batch_log_metrics for k, v in d.items() } # track all metrics for callbacks self.trainer.logger_connector.callback_metrics.update( batch_log_metrics) self.trainer.logger_connector.callback_metrics.update({ k: v for d in batch_callback_metrics for k, v in d.items() if v is not None }) result = AttributeDict( signal=0, grad_norm_dic=grad_norm_dic, batch_log_metrics=batch_log_metrics, training_step_output_for_epoch_end=batch_outputs) return result def training_step_and_backward(self, split_batch, batch_idx, 
opt_idx, optimizer, hiddens): """ Wraps the forward and backward passes in a closure so that second-order optimizers work. """ # lightning module hook result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) if result is None: self.warning_cache.warn( 'training_step returned None. If this was on purpose, ignore this warning...' ) return None # backward pass self.backward(result, optimizer, opt_idx) # hook self.on_after_backward(result.training_step_output, batch_idx, result.loss) return result def update_train_loop_lr_schedulers(self, monitor_metrics=None): num_accumulated_batches_reached = ( self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 num_training_batches_reached = (self.trainer.batch_idx + 1) == self.trainer.num_training_batches if num_accumulated_batches_reached or num_training_batches_reached: # update lr self.trainer.optimizer_connector.update_learning_rates( interval='step', monitor_metrics=monitor_metrics) def run_on_epoch_end_hook(self): self.trainer.call_hook('on_epoch_end') self.trainer.call_hook('on_train_epoch_end') def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = ( self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0 num_training_batches_reached = (self.trainer.batch_idx + 1) == self.trainer.num_training_batches # progress global step according to grads progress if num_accumulated_batches_reached or num_training_batches_reached: self.trainer.global_step += 1 def should_check_val_fx(self, batch_idx, is_last_batch): # decide if we should run validation is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 can_check_val = self.trainer.enable_validation and is_val_check_epoch should_check_val = is_val_check_batch or self.trainer.should_stop is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float( 'inf') should_check_val = can_check_val and ( should_check_val or is_last_batch_for_infinite_dataset) return should_check_val def build_train_args(self, batch, batch_idx, opt_idx, hiddens): # opt_idx only needs to be passed to training_step when there are multiple optimizers args = [batch, batch_idx] if len(self.trainer.optimizers) > 1: if self.trainer.has_arg('training_step', 'optimizer_idx'): args.append(opt_idx) else: num_opts = len(self.trainer.optimizers) raise ValueError( f'Your LightningModule defines {num_opts} optimizers but ' f'training_step is missing the "optimizer_idx" argument.') # pass hiddens if using tbptt if self.trainer.truncated_bptt_steps is not None: args.append(hiddens) return args def save_loggers_on_train_batch_end(self): # when loggers should save to disk should_save_log = ((self.trainer.global_step + 1) % self.trainer.log_save_interval == 0 or self.trainer.should_stop) if should_save_log or self.trainer.fast_dev_run: if self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator): """ Figure out what needs to be tracked/logged at the end of the epoch """ # the training step outputs a list per optimizer.
The list contains the outputs at each time step # when no TBPTT is used, then the list has 1 item per batch # when TBPTT IS used, then the list has n items (1 per time step) epoch_end_outputs = [] for optimizer_idx_outputs in all_train_step_outputs: # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer if len(optimizer_idx_outputs) == 0: continue sample_output = optimizer_idx_outputs[-1] # pull out callback info if available (ie: Results object) if isinstance(sample_output, dict) and 'early_stop_on' in sample_output: early_stopping_accumulator.accumulate( sample_output['early_stop_on']) if isinstance(sample_output, dict) and 'checkpoint_on' in sample_output: checkpoint_accumulator.accumulate( sample_output['checkpoint_on']) # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance( sample_output, Result) and sample_output.should_reduce_on_epoch_end # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end if is_overridden( 'training_epoch_end', model=self.trainer.get_model()) or auto_reduce_tng_result: epoch_end_outputs.append(optimizer_idx_outputs) return epoch_end_outputs
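# ---------------------------------------------------------------------------------
# Minimal illustrative sketch (not part of TrainLoop) of the training_step return
# formats this loop accepts, as handled by _process_training_step_output_1_0 and
# listed in the deprecation message above: a scalar loss tensor, or a dict with
# 'loss', optional 'hiddens' (for TBPTT), and any extra keys. The module name,
# layer sizes, and optimizer settings here are assumptions for the example only.
# ---------------------------------------------------------------------------------
import torch
from torch import nn
import pytorch_lightning as pl


class IllustrativeModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.cross_entropy(self.layer(x), y)
        # Option 1: return the scalar loss tensor directly
        # return loss
        # Option 2/3: return a dict; 'loss' is routed to result.minimize,
        # 'hiddens' (if present) to result.hiddens, everything else to result['extra']
        return {'loss': loss, 'anything_else': loss.detach()}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)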
class WandbLogger(LightningLoggerBase): LOGGER_JOIN_CHAR = '-' NAME_HPARAMS_FILE = 'hparams.yaml' def __init__(self, name: Optional[str] = None, save_dir: Optional[str] = None, offline: bool = False, id: Optional[str] = None, anonymous: bool = False, version: Optional[str] = None, project: Optional[str] = None, log_model: bool = False, log_graph: bool = False, default_hp_metric: bool = True, experiment=None, prefix: str = '', **kwargs): if wandb is None: raise ImportError( 'You want to use `wandb` logger which is not installed yet,' # pragma: no-cover ' install it with `pip install wandb`.') super().__init__() self._name = name self._save_dir = save_dir self._anonymous = 'allow' if anonymous else None self._id = version or id self._project = project self._experiment = experiment self._offline = offline self._log_model = log_model self._prefix = prefix self._log_graph = log_graph self._default_hp_metric = default_hp_metric self._kwargs = kwargs # logging multiple Trainer on a single W&B run (k-fold, resuming, etc) self._step_offset = 0 self.hparams = {} self.warning_cache = WarningCache() def __getstate__(self): state = self.__dict__.copy() # args needed to reload correct experiment state[ '_id'] = self._experiment.wandb_experiment.id if self._experiment is not None else None # cannot be pickled state['_experiment'] = None return state @property @rank_zero_experiment def experiment(self) -> Run: r""" Actual wandb object. To use wandb features in your :class:`~pytorch_lightning.core.lightning.LightningModule` do the following. Example:: self.logger.experiment.some_wandb_function() """ if self._experiment is None: if self._offline: os.environ['WANDB_MODE'] = 'dryrun' wandb_experiment = wandb.init( name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous, id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run # offset logging step when resuming a run self._step_offset = wandb_experiment.step # save checkpoints in wandb dir to upload on W&B servers if self._log_model: self._save_dir = wandb_experiment.dir self._fs = get_filesystem(self.save_dir) tensorboard_experiment = SummaryWriter( log_dir=wandb_experiment.dir, **self._kwargs) self._experiment = ExperimentTuple(wandb_experiment, tensorboard_experiment) self._experiment._wandb_offset = self._step_offset return self._experiment @rank_zero_only def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100): wandb_experiment, tensorboard_experiment = self.experiment wandb_experiment.watch(model, log=log, log_freq=log_freq) @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None) -> None: params = self._convert_params(params) wandb_experiment, tensorboard_experiment = self.experiment # store params to output if OMEGACONF_AVAILABLE and isinstance(params, Container): self.hparams = OmegaConf.merge(self.hparams, params) else: self.hparams.update(params) params = self._flatten_dict(params) params = self._sanitize_callable_params(params) if metrics is None: if self._default_hp_metric: metrics = {"hp_metric": -1} elif not isinstance(metrics, dict): metrics = {"hp_metric": metrics} # TensorBoard if metrics: metrics = self._add_prefix(metrics) for k, v in metrics.items(): if isinstance(v, torch.Tensor): v = v.item() tensorboard_experiment.add_scalar(k, v, 0) exp, ssi, sei = hparams(params, metrics) writer = tensorboard_experiment._get_file_writer() writer.add_summary(exp) writer.add_summary(ssi) 
writer.add_summary(sei) # Wandb wandb_experiment.config.update(params, allow_val_change=True) @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0' metrics = self._add_prefix(metrics) wandb_experiment, tensorboard_experiment = self.experiment # TensorBoard for k, v in metrics.items(): if isinstance(v, torch.Tensor): v = v.item() if isinstance(v, dict): tensorboard_experiment.add_scalars(k, v, step) else: try: tensorboard_experiment.add_scalar(k, v, step) except Exception as e: m = f'\n you tried to log {v} which is not currently supported. Try a dict or a scalar/tensor.' raise type(e)(str(e) + m) from e # Wandb if step is not None and step + self._step_offset < wandb_experiment.step: self.warning_cache.warn( 'Trying to log at a previous step. Use `commit=False` when logging metrics manually.' ) wandb_experiment.log( metrics, step=(step + self._step_offset) if step is not None else None) @rank_zero_only def log_graph(self, model: LightningModule, input_array=None): if self._log_graph: wandb_experiment, tensorboard_experiment = self.experiment if input_array is None: input_array = model.example_input_array if input_array is not None: input_array = model.transfer_batch_to_device( input_array, model.device) tensorboard_experiment.add_graph(model, input_array) else: rank_zero_warn( 'Could not log computational graph since the' ' `model.example_input_array` attribute is not set' ' or `input_array` was not given', UserWarning) @property def save_dir(self) -> Optional[str]: return self._save_dir @property def name(self) -> Optional[str]: # don't create an experiment if we don't have one return self._experiment.wandb_experiment.project_name( ) if self._experiment else self._name @property def version(self) -> Optional[str]: # don't create an experiment if we don't have one return self._experiment.wandb_experiment.id if self._experiment else self._id @rank_zero_only def finalize(self, status: str) -> None: self.experiment.flush() self.save() # offset future training logged on same W&B run if self._experiment is not None: self._step_offset = self._experiment.wandb_experiment.step # upload all checkpoints from saving dir if self._log_model: wandb.save(os.path.join(self.save_dir, "*.ckpt")) wandb.save(os.path.join(self.save_dir, self.NAME_HPARAMS_FILE)) @rank_zero_only def save(self) -> None: # Initialize experiment _ = self.experiment super().save() # prepare the file path hparams_file = os.path.join(self.save_dir, self.NAME_HPARAMS_FILE) # save the metatags file if it doesn't exist if not os.path.isfile(hparams_file): save_hparams_to_yaml(hparams_file, self.hparams)
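# ---------------------------------------------------------------------------------
# Minimal usage sketch for the WandbLogger above (not part of the logger itself).
# Project/run names and sizes are illustrative; `offline=True` avoids needing W&B
# credentials by setting WANDB_MODE=dryrun, and the `wandb` package is assumed to be
# installed. The model reuses the IllustrativeModule sketch defined after TrainLoop.
# ---------------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import Trainer

    wandb_logger = WandbLogger(project='my-project', name='baseline-run', offline=True)
    model = IllustrativeModule()
    # log gradients of the model every 100 steps through the underlying wandb run
    wandb_logger.watch(model, log='gradients', log_freq=100)

    train_data = DataLoader(
        TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))),
        batch_size=8)
    trainer = Trainer(max_epochs=1, logger=wandb_logger)
    trainer.fit(model, train_data)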