from typing import Any, ContextManager

from apex import amp
from torch import Tensor
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import AMPType


def backward(
    self,
    model: LightningModule,
    closure_loss: Tensor,
    optimizer: Optimizer,
    opt_idx: int,
    should_accumulate: bool,
    *args: Any,
    **kwargs: Any,
) -> Tensor:
    """Performs the actual backpropagation.

    Args:
        model: the model to be optimized
        closure_loss: the loss value obtained from the closure
        optimizer: the optimizer to perform the step later on
        opt_idx: the optimizer index
        should_accumulate: whether to accumulate gradients or not
    """
    # fall back to all trainer optimizers when no single optimizer is given;
    # apex's scale_loss accepts either one optimizer or a list of them
    opt = model.trainer.optimizers if optimizer is None else optimizer
    scaled_loss: ContextManager[Tensor] = amp.scale_loss(closure_loss, opt)

    # enter apex context
    closure_loss = scaled_loss.__enter__()

    # do backward pass
    # TODO: not entirely sure why we need this
    if model is not None and isinstance(model, LightningModule):
        model.backward(closure_loss, optimizer, opt_idx, **kwargs)
        # TODO: avoid dev_debugger and track these calls with mock
        model.trainer.dev_debugger.track_event('AMP', str(AMPType.APEX))
    else:
        closure_loss.backward(*args, **kwargs)

    # exit amp context so apex unscales the gradients
    error = scaled_loss.__exit__(None, None, None)
    if error:
        raise Exception("apex unscale error")

    # once backward has been applied, release the graph
    closure_loss = closure_loss.detach()
    return closure_loss
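# For context: the manual __enter__/__exit__ calls above reproduce apex's
# documented context-manager idiom. A minimal sketch of that idiom follows;
# the model, data, and hyperparameters are illustrative placeholders, and
# apex itself requires a CUDA device.
import torch
from apex import amp

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# opt_level "O1" patches eligible ops to run in float16 where it is safe
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(8, 4, device="cuda")
y = torch.randn(8, 1, device="cuda")

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)

# scale_loss yields the loss multiplied by the current loss scale; the
# gradients are unscaled when the context exits, which is exactly what
# scaled_loss.__exit__(None, None, None) triggers in the method above
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()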
def backward(self, loss, optimizer, optimizer_idx):
    # delegate to the default LightningModule implementation
    return LightningModule.backward(self, loss, optimizer, optimizer_idx)
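# A minimal sketch of how the same hook is typically overridden from user
# code: subclass LightningModule and defer to the default implementation,
# which by default just calls loss.backward(). LitModel is an illustrative
# name, not part of Lightning.
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def backward(self, loss, optimizer, optimizer_idx):
        # custom logic (e.g. logging the raw loss) could run here before
        # deferring to the default backward
        super().backward(loss, optimizer, optimizer_idx)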