def backward(
    self,
    model: LightningModule,
    closure_loss: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    opt_idx: int,
    should_accumulate: bool,
    *args: Any,
    **kwargs: Any,
) -> torch.Tensor:
    """Backpropagate ``closure_loss`` and hand back a detached copy of it.

    If ``model.automatic_optimization`` is truthy, the backward call is
    delegated to ``model.backward(closure_loss, optimizer, opt_idx)``; note
    that ``args``/``kwargs`` are NOT forwarded on that path. Otherwise
    ``closure_loss.backward(*args, **kwargs)`` is called directly.

    Args:
        model: the model to be optimized
        closure_loss: the loss value obtained from the closure
        optimizer: the optimizer that will perform the step later on
        opt_idx: the optimizer's index
        should_accumulate: whether to accumulate gradients or not
            (not read by this implementation)
        *args: positional arguments for ``closure_loss.backward`` when
            automatic optimization is disabled
        **kwargs: keyword arguments for ``closure_loss.backward`` when
            automatic optimization is disabled

    Returns:
        ``closure_loss.detach()``, taken after backward has run.
    """
    if not model.automatic_optimization:
        # Automatic optimization is off: call backward on the tensor itself.
        closure_loss.backward(*args, **kwargs)
    else:
        model.backward(closure_loss, optimizer, opt_idx)

    # Backward has been applied, so return a tensor cut off from the graph.
    return closure_loss.detach()
def backward(
    self,
    model: LightningModule,
    closure_loss: torch.Tensor,
    optimizer: Optimizer,
    opt_idx: int,
    should_accumulate: bool,
    *args: Any,
    **kwargs: Any,
) -> torch.Tensor:
    """Backpropagate the Apex-scaled loss and return it detached.

    The loss is wrapped in ``amp.scale_loss``. Backward runs on the scaled
    loss inside that context, and the context is exited before the scaled
    loss is detached and returned.

    Args:
        model: the model to be optimized
        closure_loss: the loss value obtained from the closure
        optimizer: the optimizer that will perform the step later on; when
            ``None``, all of ``model.trainer.optimizers`` are handed to Apex
        opt_idx: the optimizer's index
        should_accumulate: whether to accumulate gradients or not
            (not read by this implementation)
        *args: positional arguments for ``scaled_loss.backward`` when
            ``model`` is not a ``LightningModule``
        **kwargs: keyword arguments forwarded to ``model.backward`` (or to
            ``scaled_loss.backward`` when ``model`` is not a
            ``LightningModule``)

    Returns:
        The scaled loss yielded by ``amp.scale_loss``, detached from the
        autograd graph.

    Raises:
        Any exception raised during backward propagates unchanged; the
        Apex context is exited with that exception first.
    """
    # Without an explicit optimizer, let Apex see every trainer optimizer.
    optimizers = model.trainer.optimizers if optimizer is None else optimizer

    # ``amp.scale_loss`` is a context manager; per the Apex docs, leaving it
    # is where gradients are checked/unscaled. ``with`` guarantees the context
    # is exited even if backward raises.
    with amp.scale_loss(closure_loss, optimizers) as scaled_loss:
        # TODO: not entirely sure, why we need this
        if isinstance(model, LightningModule):
            model.backward(scaled_loss, optimizer, opt_idx, **kwargs)

            # TODO: avoid dev_debugger and track these calls with mock
            model.trainer.dev_debugger.track_event('AMP', str(AMPType.APEX))
        else:
            scaled_loss.backward(*args, **kwargs)

    # once backward has been applied, release graph
    return scaled_loss.detach()