def _checkpointed_forward(
    original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
) -> Any:
    module = weak_self()

    # If gradients are disabled, just use the original `.forward()` method directly.
    # Doing so also ensures the internal fwd counter is not incremented in the forward pass,
    # which would be an issue during eval since there wouldn't be a corresponding backward pass
    # to decrement the fwd counter.
    # See https://github.com/facebookresearch/fairscale/pull/709.
    if not torch.is_grad_enabled():
        return original_forward(module, *args, **kwargs)

    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    args = (module,) + args
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {
        "offload": offload_to_cpu,
    }
    # A dummy tensor with grad is used to ensure the backward pass is called. This is needed
    # when original_forward's inputs are non-tensor (e.g. a tuple). Using this dummy tensor
    # avoids requiring users to set their input tensors' requires_grad flag. In the case
    # of tuple-type inputs, setting the flag won't even trigger the backward pass.
    output = CheckpointFunction.apply(
        torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, kwarg_keys, *flat_args
    )
    if not isinstance(output, torch.Tensor):
        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    return output
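# A minimal sketch of the kwarg flattening used above. The real pack_kwargs /
# unpack_kwargs are fairscale utilities; this illustrative version (the
# *_sketch names are hypothetical) only shows the idea: record the keyword
# names, append the keyword values after the positional args, and reverse the
# process on the other side of the autograd Function.
from typing import Any, Dict, Tuple


def pack_kwargs_sketch(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    kwarg_keys = tuple(kwargs.keys())
    flat_args = args + tuple(kwargs[k] for k in kwarg_keys)
    return kwarg_keys, flat_args


def unpack_kwargs_sketch(
    kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    split = len(flat_args) - len(kwarg_keys)
    return flat_args[:split], dict(zip(kwarg_keys, flat_args[split:]))


# Round trip: positional (1, 2) plus a=3 flattens to (1, 2, 3) and back.
keys, flat = pack_kwargs_sketch(1, 2, a=3)
assert unpack_kwargs_sketch(keys, flat) == ((1, 2), {"a": 3})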
def _checkpointed_forward(
    original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
) -> Any:
    module = weak_self()

    # If gradients are disabled, just use the original `.forward()` method directly.
    if not torch.is_grad_enabled() or thread_local.is_checkpointing_disabled:
        return original_forward(module, *args, **kwargs)

    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    args = (module,) + args
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {
        "offload": offload_to_cpu,
    }

    # A dummy tensor with grad is used to ensure the backward pass is called. This is needed
    # when original_forward's inputs are non-tensor (e.g. a tuple). Using this dummy tensor
    # avoids requiring users to set their input tensors' requires_grad flag. In the case
    # of tuple-type inputs, setting the flag won't even trigger the backward pass.
    #
    # One implication of this is that since we always feed in a dummy tensor
    # needing grad, the output will always require grad, even if it originally
    # would not, such as when the module and the original inputs both do not require grad.
    # We get around this by saving the desired requires_grad value in output and
    # detaching the output if needed.
    output = CheckpointFunction.apply(
        torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, kwarg_keys, *flat_args
    )
    output_requires_grad = parent_ctx_dict["output_requires_grad"]
    if not isinstance(output, torch.Tensor):
        # If the output should not require grad, detach it, since otherwise it will
        # always have requires_grad = True due to the grad-requiring dummy tensor
        # input above.
        output = [x.detach() if not output_requires_grad else x for x in output]
        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    else:
        # If the output should not require grad, detach it (same reasoning as above).
        if not output_requires_grad:
            output = output.detach()
    return output
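# A standalone illustration (plain torch, no wrapper involved) of why the
# detach above is needed: once a grad-requiring dummy tensor enters the graph,
# every downstream output requires grad, even if no real input does.
import torch

x = torch.ones(2)  # user input, requires_grad=False
dummy = torch.tensor([], requires_grad=True)

out = x * 2 + dummy.sum()  # touching `dummy` taints the output
assert out.requires_grad

out = out.detach()  # what the wrapper does when nothing actually needed grad
assert not out.requires_grad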
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """
    Cast any Tensors in *args or **kwargs to FP16.

    Doesn't currently support Tensors nested inside containers (e.g., dict).
    """
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(flat_args)
    tensor_inputs = tuple(t.half() if torch.is_floating_point(t) else t for t in tensor_inputs)
    flat_args = unpack_non_tensors(tensor_inputs, packed_non_tensor_inputs)
    args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
    return args, kwargs
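# Usage sketch for cast_inputs_to_fp16, relying on the split/unpack round-trip
# behavior exercised in test_split_unpack below: floating-point tensors become
# half precision, while integer tensors and non-tensors pass through untouched.
import torch

feats = torch.randn(4)      # float32 -> cast to float16
idx = torch.tensor([0, 1])  # int64, not floating point -> unchanged
args, kwargs = cast_inputs_to_fp16(feats, idx, mask=None, scale=2.0)
assert args[0].dtype == torch.float16
assert args[1].dtype == torch.int64
assert kwargs == {"mask": None, "scale": 2.0}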
def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError(
            "Checkpointing is not compatible with .grad(), please use .backward() if possible"
        )
    tensor_inputs: Tuple = ctx.saved_tensors
    tensor_inputs = torch_checkpoint.detach_variable(tensor_inputs)
    if ctx.fwd_device is not None:
        tensor_inputs = tuple(
            t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs)
        )
        for i, need_grad in enumerate(ctx.grad_requirements):
            tensor_inputs[i].requires_grad = need_grad
    inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)

    # Store the current states.
    bwd_rng_state = get_rng_state()

    # Set the states to what they were before the forward pass.
    set_rng_state(ctx.fwd_rng_state)

    with torch.enable_grad(), enable_recomputing(), autocast(ctx.had_autocast_in_fwd):
        unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
        outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
        tensor_outputs, _ = split_non_tensors(outputs)

    # Set the states back to what they were at the start of this function.
    set_rng_state(bwd_rng_state)

    # Run backward() with only Tensors that require grad.
    outputs_with_grad = []
    args_with_grad = []
    for i in range(len(tensor_outputs)):
        if tensor_outputs[i].requires_grad:
            outputs_with_grad.append(tensor_outputs[i])
            args_with_grad.append(args[i])
    if len(outputs_with_grad) == 0:
        raise RuntimeError(
            "None of the outputs have requires_grad=True, this checkpoint() is not necessary"
        )

    torch.autograd.backward(outputs_with_grad, args_with_grad)

    grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
    # The four leading Nones match the non-gradient inputs to forward():
    # the dummy tensor, run_function, parent_ctx_dict, and kwarg_keys.
    return (None, None, None, None) + grads
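# The RNG handling above is the classic stash-and-restore pattern; a minimal
# standalone illustration using only the CPU generator (the real get_rng_state /
# set_rng_state helpers also capture CUDA state):
import torch

fwd_rng_state = torch.get_rng_state()  # captured when forward originally ran
_ = torch.rand(3)                      # intervening work advances the RNG

bwd_rng_state = torch.get_rng_state()  # stash the current state
torch.set_rng_state(fwd_rng_state)     # replay forward's randomness exactly
recomputed = torch.rand(3)
torch.set_rng_state(bwd_rng_state)     # restore, as if nothing happened

torch.set_rng_state(fwd_rng_state)
assert torch.equal(recomputed, torch.rand(3))  # recompute saw identical draws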
def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any) -> Any:
    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {"offload": offload_to_cpu}
    output = CheckpointFunction.apply(original_forward, parent_ctx_dict, kwarg_keys, *flat_args)
    if not isinstance(output, torch.Tensor):
        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    return output
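# How this plumbing is reached in practice: the public entry point is
# checkpoint_wrapper, which patches a module's forward to route through
# _checkpointed_forward. A usage sketch (import path as documented by
# fairscale; verify against the installed version):
import torch
from fairscale.nn import checkpoint_wrapper

block = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
block = checkpoint_wrapper(block, offload_to_cpu=True)  # activations stashed on CPU

out = block(torch.randn(2, 8, requires_grad=True))
out.sum().backward()  # forward is recomputed here, then grads flow as usual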
def test_split_unpack():
    """Test split_non_tensors and unpack_non_tensors."""
    x = torch.Tensor([1])
    y = torch.Tensor([2])

    tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
    assert tensors == (x, y)
    assert packed_non_tensors == {
        "is_tensor": [True, True, False, False],
        "objects": [None, 3],
    }
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (x, y, None, 3)

    tensors, packed_non_tensors = split_non_tensors((None, 3, x, y))
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (None, 3, x, y)

    tensors, packed_non_tensors = split_non_tensors((None, 3))
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (None, 3)

    tensors, packed_non_tensors = split_non_tensors((x, y))
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (x, y)

    recon = unpack_non_tensors(tensors, None)
    assert recon == (x, y)

    with pytest.raises(AssertionError):
        # assert that the second arg must be a dict.
        recon = unpack_non_tensors(tensors, set())

    with pytest.raises(AssertionError):
        # assert that the content of the second arg must be sane.
        recon = unpack_non_tensors(tensors, {"is_tensor": [], "objects": []})
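# From the expectations encoded in the test above, a reference sketch of what
# split_non_tensors computes (illustrative *_sketch name, not the library code):
# tensors travel through the autograd Function, everything else is stashed in a
# small metadata dict so the original tuple can be rebuilt afterwards.
import torch
from typing import Any, Dict, Tuple


def split_non_tensors_sketch(mixed: Tuple[Any, ...]) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, list]]:
    tensors = []
    packed: Dict[str, list] = {"is_tensor": [], "objects": []}
    for obj in mixed:
        if isinstance(obj, torch.Tensor):
            packed["is_tensor"].append(True)
            tensors.append(obj)
        else:
            packed["is_tensor"].append(False)
            packed["objects"].append(obj)
    return tuple(tensors), packed


x, y = torch.Tensor([1]), torch.Tensor([2])
tensors, packed = split_non_tensors_sketch((x, y, None, 3))
assert tensors == (x, y)
assert packed == {"is_tensor": [True, True, False, False], "objects": [None, 3]}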