def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
    """
    wrap the forward step in a closure so second order methods work

    Runs forward (training_step + training_step_end), post-processes the output,
    scales the loss for gradient accumulation / AMP, and runs the backward pass.

    Args:
        split_batch: the (possibly truncated-BPTT) batch split to run forward on
        batch_idx: index of the current batch
        opt_idx: index of the optimizer currently being stepped
        optimizer: the optimizer instance (used for AMP scaling and Horovod sync)
        hiddens: hidden state carried between truncated-BPTT splits

    Returns:
        AttributeDict with keys: ``loss`` (detached, unscaled copy for logging),
        ``training_step_output``, ``training_step_output_for_epoch_end`` (graph-free),
        and ``hiddens``.
    """
    # ---------------------------
    # FORWARD (TRAINING STEP + TRAIN STEP END)
    # ---------------------------
    with self.profiler.profile('model_forward'):
        args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
        training_step_output = self.accelerator_backend.training_step(args)
        training_step_output = self.call_hook('training_step_end', training_step_output)

    # ----------------------------
    # PROCESS THE RESULT
    # ----------------------------
    # format and reduce outputs accordingly
    # keep a reference to the raw output before it is normalized below,
    # so epoch_end can reduce on the user's original structure
    training_step_output_for_epoch_end = training_step_output
    is_result_obj = isinstance(training_step_output, Result)

    # track batch size for weighted average
    if is_result_obj:
        training_step_output.track_batch_size(len(split_batch))

    # don't allow EvalResult in the training_step
    if isinstance(training_step_output, EvalResult):
        raise MisconfigurationException('training_step cannot return EvalResult, '
                                        'use a dict or TrainResult instead')

    # handle regular dicts
    # non-Result outputs are normalized into a fixed 5-tuple and wrapped in an
    # AttributeDict so downstream code can use attribute access uniformly
    if not is_result_obj:
        training_step_output = self.process_output(training_step_output, train=True)

        training_step_output = AttributeDict(
            batch_loss=training_step_output[0],
            pbar_on_batch_end=training_step_output[1],
            log_metrics=training_step_output[2],
            callback_metrics=training_step_output[3],
            hiddens=training_step_output[4],
        )

    # if the user decides to finally reduce things in epoch_end, save raw output without graphs
    if isinstance(training_step_output_for_epoch_end, torch.Tensor):
        training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
    elif is_result_obj:
        # shallow copy, then detach in place so the live Result keeps its graph
        training_step_output_for_epoch_end = copy(training_step_output)
        training_step_output_for_epoch_end.detach()
    else:
        training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)

    # accumulate loss
    # (if accumulate_grad_batches = 1 no effect)
    # Result objects carry the loss under `.minimize`; dict outputs under `.batch_loss`
    closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
    closure_loss = closure_loss / self.accumulate_grad_batches

    # the loss will get scaled for amp. avoid any modifications to it
    untouched_loss = closure_loss.detach().clone()

    # backward pass
    model_ref = self.get_model()
    with self.profiler.profile('model_backward'):
        # scale loss for 16 bit
        if self.precision == 16 and not self.on_tpu:
            closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx, amp_backend=self.amp_backend)

            # enter amp context
            # NOTE(review): with the apex backend, amp_scale_loss presumably
            # returns apex's scaled_loss context manager, which is entered
            # manually here and exited after backward — confirm against
            # amp_scale_loss's implementation
            if self.amp_backend == AMPType.APEX:
                self.dev_debugger.track_event('AMP', str(AMPType.APEX))
                context = closure_loss
                closure_loss = closure_loss.__enter__()

        # do backward pass
        model_ref.backward(self, closure_loss, optimizer, opt_idx)

        # exit amp context
        if self.precision == 16 and self.amp_backend == AMPType.APEX and not self.on_tpu:
            # NOTE(review): a, b, c are always None here, so rank_zero_warn is
            # called with three Nones before raising — likely vestigial; the
            # raise is what actually surfaces the unscale failure
            a, b, c = None, None, None
            error = context.__exit__(a, b, c)
            if error:
                rank_zero_warn(a, b, c)
                raise Exception('apex unscale error')

        # once backward has been applied, release graph
        closure_loss = closure_loss.detach()

        if is_result_obj:
            training_step_output.detach()
        else:
            training_step_output.batch_loss = training_step_output.batch_loss.detach()

    if self.use_horovod:
        # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
        optimizer.synchronize()

    # insert after step hook
    if self.is_function_implemented('on_after_backward'):
        model_ref = self.get_model()
        with self.profiler.profile('on_after_backward'):
            model_ref.on_after_backward()

    # when in dev debugging track the losses
    self.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())

    result = AttributeDict(
        loss=untouched_loss,
        training_step_output=training_step_output,
        training_step_output_for_epoch_end=training_step_output_for_epoch_end,
        hiddens=training_step_output.hiddens,
    )
    return result
def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
    """
    wrap the forward step in a closure so second order methods work

    Runs the forward pass (under native AMP autocast when available), normalizes
    the training_step output, scales the loss for gradient accumulation / AMP,
    and performs backward.

    Args:
        split_batch: the (possibly truncated-BPTT) batch split to run forward on
        batch_idx: index of the current batch
        opt_idx: index of the optimizer currently being stepped
        optimizer: the optimizer instance (used for AMP scaling and Horovod sync)
        hiddens: hidden state carried between truncated-BPTT splits

    Returns:
        AttributeDict with keys: ``loss`` (detached, unscaled copy),
        ``training_step_output``, ``training_step_output_for_epoch_end``
        (graph-free), and ``hiddens``.
    """
    # ---------------------------
    # FORWARD
    # ---------------------------
    with self.profiler.profile('model_forward'):
        # native AMP path: wrap forward in autocast (not supported on TPU)
        if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
            with torch.cuda.amp.autocast():
                training_step_output = self.training_forward(split_batch, batch_idx, opt_idx, hiddens)
        else:
            training_step_output = self.training_forward(split_batch, batch_idx, opt_idx, hiddens)

    # ----------------------------
    # PROCESS THE RESULT
    # ----------------------------
    # format and reduce outputs accordingly
    # keep the raw output so epoch_end can reduce on the user's structure
    training_step_output_for_epoch_end = training_step_output
    training_step_output = self.process_output(training_step_output, train=True)

    # TODO: temporary part of structured results PR
    # process_output returns a fixed 5-tuple; wrap it for attribute access
    training_step_output = AttributeDict(
        batch_loss=training_step_output[0],
        pbar_on_batch_end=training_step_output[1],
        log_metrics=training_step_output[2],
        callback_metrics=training_step_output[3],
        hiddens=training_step_output[4],
    )

    # if the user decides to finally reduce things in epoch_end, save raw output without graphs
    if isinstance(training_step_output_for_epoch_end, torch.Tensor):
        training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
    else:
        training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)

    # accumulate loss
    # (if accumulate_grad_batches = 1 no effect)
    closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches

    # the loss will get scaled for amp. avoid any modifications to it
    untouched_loss = closure_loss.detach().clone()

    # backward pass
    model_ref = self.get_model()
    with self.profiler.profile('model_backward'):
        # scale loss for 16 bit
        if self.precision == 16 and not self.on_tpu:
            closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)

            # enter amp context
            # NOTE(review): on the apex (non-native-AMP) path, amp_scale_loss
            # presumably returns apex's scaled_loss context manager, entered
            # manually here and exited after backward — confirm against
            # amp_scale_loss's implementation
            if not NATIVE_AMP_AVALAIBLE:
                context = closure_loss
                closure_loss = closure_loss.__enter__()

        # do backward pass
        model_ref.backward(self, closure_loss, optimizer, opt_idx)

        # exit amp context
        if self.precision == 16 and not NATIVE_AMP_AVALAIBLE and not self.on_tpu:
            # NOTE(review): a, b, c are always None, so rank_zero_warn is called
            # with three Nones before raising — likely vestigial; the raise is
            # what surfaces the unscale failure
            a, b, c = None, None, None
            error = context.__exit__(a, b, c)
            if error:
                rank_zero_warn(a, b, c)
                raise Exception('apex unscale error')

        # once backward has been applied, release graph
        closure_loss = closure_loss.detach()
        training_step_output.batch_loss = training_step_output.batch_loss.detach()

    if self.use_horovod:
        # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
        optimizer.synchronize()

    # insert after step hook
    if self.is_function_implemented('on_after_backward'):
        model_ref = self.get_model()
        with self.profiler.profile('on_after_backward'):
            model_ref.on_after_backward()

    result = AttributeDict(
        loss=untouched_loss,
        training_step_output=training_step_output,
        training_step_output_for_epoch_end=training_step_output_for_epoch_end,
        hiddens=training_step_output.hiddens,
    )
    return result
def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens):
    """
    wrap the forward step in a closure so second order methods work

    Runs the forward pass (under native AMP autocast when enabled), normalizes
    the training_step output, scales the loss for gradient accumulation / AMP,
    and performs backward.

    Args:
        split_batch: the (possibly truncated-BPTT) batch split to run forward on
        batch_idx: index of the current batch
        opt_idx: index of the optimizer currently being stepped
        optimizer: the optimizer instance (used for AMP scaling and Horovod sync)
        hiddens: hidden state carried between truncated-BPTT splits

    Returns:
        AttributeDict with keys: ``loss`` (the detached closure loss),
        ``training_step_output``, ``training_step_output_for_epoch_end``
        (graph-free), and ``hiddens``.
    """
    # ---------------------------
    # FORWARD
    # ---------------------------
    with self.profiler.profile('model_forward'):
        # native AMP path: wrap forward in autocast
        if self.use_amp and self.use_native_amp:
            with torch.cuda.amp.autocast():
                training_step_output = self.training_forward(split_batch, batch_idx, opt_idx, hiddens)
        else:
            training_step_output = self.training_forward(split_batch, batch_idx, opt_idx, hiddens)

    # ----------------------------
    # PROCESS THE RESULT
    # ----------------------------
    # format and reduce outputs accordingly
    training_step_output = self.process_output(training_step_output, train=True)

    # TODO: temporary part of structured results PR
    # process_output returns a fixed 5-tuple; wrap it for attribute access
    training_step_output = AttributeDict(
        batch_loss=training_step_output[0],
        pbar_on_batch_end=training_step_output[1],
        log_metrics=training_step_output[2],
        callback_metrics=training_step_output[3],
        hiddens=training_step_output[4],
    )

    # if the user decides to finally reduce things in epoch_end, save raw output without graphs
    training_step_output_for_epoch_end = recursive_detach(training_step_output)

    # accumulate loss
    # (if accumulate_grad_batches = 1 no effect)
    closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches

    # backward pass
    model_ref = self.get_model()
    with self.profiler.profile('model_backward'):
        # scale loss for 16 bit
        if self.precision == 16 and not self.on_tpu:
            closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)

        # do backward pass
        model_ref.backward(self, closure_loss, optimizer, opt_idx)

        # once backward has been applied, release graph
        closure_loss = closure_loss.detach()
        training_step_output.batch_loss = training_step_output.batch_loss.detach()

    if self.use_horovod:
        # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
        optimizer.synchronize()

    # insert after step hook
    if self.is_function_implemented('on_after_backward'):
        model_ref = self.get_model()
        with self.profiler.profile('on_after_backward'):
            model_ref.on_after_backward()

    # NOTE(review): unlike the untouched_loss pattern used elsewhere, the
    # returned loss here is the (possibly AMP-scaled, accumulation-divided)
    # detached closure_loss — confirm callers expect the scaled value
    result = AttributeDict(
        loss=closure_loss,
        training_step_output=training_step_output,
        training_step_output_for_epoch_end=training_step_output_for_epoch_end,
        hiddens=training_step_output.hiddens,
    )
    return result