def backward(ctx, *args):
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError(
            "Checkpointing is not compatible with .grad(), please use .backward() if possible"
        )

    # Detach the saved inputs so the recomputed forward builds a fresh, local graph.
    inputs = detach_variable(ctx.saved_tensors)

    # Store the current states.
    bwd_rng_state = utils.get_rng_state()

    # Set the states to what they were before the forward pass.
    utils.set_rng_state(ctx.fwd_rng_state)

    # Recompute the forward pass with gradients enabled.
    with torch.enable_grad():
        outputs = ctx.run_function(*inputs)

    # Set the states back to what they were at the start of this function.
    utils.set_rng_state(bwd_rng_state)

    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)
    torch.autograd.backward(outputs, args)

    grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs)
    # The leading None is the placeholder gradient for the non-tensor
    # run_function argument of forward().
    return (None,) + grads
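For reference, below is a minimal sketch of the forward pass this backward assumes. It is not the library's actual implementation; it only illustrates where ctx.run_function, ctx.fwd_rng_state, and the saved tensors come from, and why the return value begins with a single None.

def forward(ctx, run_function, *inputs):
    # Sketch only (assumed, not the real source): stash everything backward()
    # needs in order to recompute the activations.
    ctx.run_function = run_function
    # Snapshot the RNG state so the recompute sees the same random numbers
    # (e.g. for dropout) as the original forward pass.
    ctx.fwd_rng_state = utils.get_rng_state()
    ctx.save_for_backward(*inputs)
    with torch.no_grad():
        # Run without building an autograd graph; intermediate activations are
        # discarded here and recomputed inside backward().
        outputs = run_function(*inputs)
    return outputs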
def backward(ctx, *args):
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError(
            "Checkpointing is not compatible with .grad(), please use .backward() if possible"
        )

    # Restore the saved tensor inputs: detach them, move them back to the
    # devices recorded during the forward pass (they may have been offloaded),
    # and re-apply their requires_grad flags.
    tensor_inputs: Tuple = ctx.saved_tensors
    tensor_inputs = checkpoint.detach_variable(tensor_inputs)
    if ctx.fwd_device is not None:
        tensor_inputs = [
            t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs)
        ]
        for i, need_grad in enumerate(ctx.grad_requirements):
            tensor_inputs[i].requires_grad = need_grad
    # Re-interleave the tensors with the non-tensor inputs to rebuild the
    # original argument tuple.
    inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)

    # Store the current states.
    bwd_rng_state = utils.get_rng_state()

    # Set the states to what they were before the forward pass.
    utils.set_rng_state(ctx.fwd_rng_state)

    # Recompute the forward pass with gradients enabled.
    with torch.enable_grad():
        unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
        outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
        tensor_outputs, _ = split_non_tensors(outputs)

    # Set the states back to what they were at the start of this function.
    utils.set_rng_state(bwd_rng_state)

    # Run backward() with only the tensors that require grad.
    outputs_with_grad = []
    args_with_grad = []
    for i in range(len(tensor_outputs)):
        if tensor_outputs[i].requires_grad:
            outputs_with_grad.append(tensor_outputs[i])
            args_with_grad.append(args[i])
    if len(outputs_with_grad) == 0:
        raise RuntimeError(
            "None of the outputs have requires_grad=True, this checkpoint() is not necessary"
        )
    torch.autograd.backward(outputs_with_grad, args_with_grad)

    grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
    # The three leading Nones are placeholder gradients for the three
    # non-tensor arguments passed to forward().
    return (None, None, None) + grads
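Similarly, this richer backward implies a forward pass that records per-tensor devices and requires_grad flags and packs away non-tensor inputs. The sketch below is an assumption inferred from the ctx attributes the backward reads (fwd_device, grad_requirements, kwarg_keys, packed_non_tensor_inputs), not FairScale's actual code; the name extra_state is a placeholder for the third non-tensor argument implied by the three Nones that backward() returns.

def forward(ctx, run_function, extra_state, kwarg_keys, *args):
    # Sketch only: three non-tensor leading arguments, matching the three
    # Nones returned by backward() above. `extra_state` is a placeholder name.
    ctx.run_function = run_function
    ctx.kwarg_keys = kwarg_keys
    ctx.fwd_rng_state = utils.get_rng_state()

    # Keep only tensors for save_for_backward(); pack the rest so backward()
    # can rebuild the full argument tuple with unpack_non_tensors().
    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
    ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

    # Record each tensor's device and requires_grad flag so backward() can
    # restore them after the tensors are (optionally) offloaded elsewhere.
    ctx.fwd_device = tuple(t.device for t in tensor_inputs)
    ctx.grad_requirements = tuple(t.requires_grad for t in tensor_inputs)
    ctx.save_for_backward(*tensor_inputs)

    with torch.no_grad():
        unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
        outputs = run_function(*unpacked_args, **unpacked_kwargs)
    return outputs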