def forward(  # type: ignore
    ctx: Any,
    dummy_tensor_requires_grad: torch.Tensor,
    run_function: Any,
    parent_ctx_dict: Dict[str, Any],
    kwarg_keys: Tuple[str, ...],
    *args: Any,
    **kwargs: Any,
) -> Any:
    torch_checkpoint.check_backward_validity(args)

    ctx.run_function = run_function
    ctx.kwarg_keys = kwarg_keys
    ctx.fwd_rng_state = get_rng_state()
    ctx.had_autocast_in_fwd = is_autocast_enabled()

    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
    if parent_ctx_dict["offload"]:
        ctx.fwd_device = tuple(x.device for x in tensor_inputs)
        ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
        tensor_inputs = tuple(x.to("cpu", non_blocking=True) for x in tensor_inputs)
    else:
        ctx.fwd_device, ctx.grad_requirements = None, None

    ctx.save_for_backward(*tensor_inputs)
    ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

    with torch.no_grad(), enable_checkpointing():
        unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
        outputs = run_function(*unpacked_args, **unpacked_kwargs)
        the_module = unpacked_args[0]

        # Because we run with torch.no_grad(), we can't actually access
        # outputs.requires_grad. Instead, we manually compute it by
        # checking if either the input or the module needs grads.
        parameters = list(the_module.parameters())

        # If the module is wrapped by FlattenParamsWrapper, then the
        # parameters would have been deleted. If so, we need to access
        # the views into the flattened parameters.
        if hasattr(the_module, "_unflattened_param_views"):
            parameters += the_module._unflattened_param_views

        output_requires_grad = any(param.requires_grad for param in parameters) or any(
            x.requires_grad for x in tensor_inputs
        )
        parent_ctx_dict["output_requires_grad"] = output_requires_grad

        if not isinstance(outputs, torch.Tensor):
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs

    return outputs

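# The sketch below shows how a wrapper might invoke this forward via
# torch.autograd.Function.apply. It is a minimal, hypothetical caller: the
# class name CheckpointFunction, the checkpointed_call helper, and the
# `offload` flag are assumptions for illustration; only pack_kwargs,
# unpack_non_tensors, and the argument order mirror the code above.
def checkpointed_call(
    module: torch.nn.Module, *args: Any, offload: bool = False, **kwargs: Any
) -> Any:
    parent_ctx_dict: Dict[str, Any] = {"offload": offload}
    # Flatten kwargs into positional args; the module rides along as the first
    # positional arg so that unpacked_args[0] is the module inside forward().
    kwarg_keys, flat_args = pack_kwargs(module, *args, **kwargs)
    # The dummy tensor with requires_grad=True guarantees autograd records a
    # grad_fn for the output even when no real input requires grad.
    dummy = torch.empty(0, requires_grad=True)
    outputs = CheckpointFunction.apply(
        dummy, type(module).forward, parent_ctx_dict, kwarg_keys, *flat_args
    )
    # Non-tensor outputs, if any, come back by reference via parent_ctx_dict.
    packed = parent_ctx_dict.get("packed_non_tensor_outputs")
    if packed is not None:
        outputs = unpack_non_tensors(outputs, packed)
    return outputs
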
def forward(  # type: ignore
    ctx: Any,
    run_function: Any,
    parent_ctx_dict: Dict[str, Any],
    kwarg_keys: Tuple[str, ...],
    *args: Any,
    **kwargs: Any,
) -> Any:
    if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
        torch_checkpoint.check_backward_validity(args)

    ctx.run_function = run_function
    ctx.kwarg_keys = kwarg_keys
    ctx.fwd_rng_state = get_rng_state()
    ctx.had_autocast_in_fwd = is_autocast_enabled()

    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
    if parent_ctx_dict["offload"]:
        ctx.fwd_device = tuple(x.device for x in tensor_inputs)
        ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
        tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
    else:
        ctx.fwd_device, ctx.grad_requirements = None, None

    ctx.save_for_backward(*tensor_inputs)
    ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

    with torch.no_grad():
        unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
        outputs = run_function(*unpacked_args, **unpacked_kwargs)
        the_module = unpacked_args[0]
        inc_counter(the_module)

    if not isinstance(outputs, torch.Tensor):
        # Autograd Functions don't like non-Tensor outputs. We can split the
        # non-Tensor and Tensor outputs, returning the former by reference
        # through *parent_ctx_dict* and returning the latter directly.
        outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
        parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs

    return outputs

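# Both forward variants rely on a pack_kwargs / unpack_kwargs pair, because
# autograd.Function.apply only accepts positional arguments. A minimal sketch
# of that contract (the real helpers in this codebase may differ in details):
def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """Flatten kwargs into trailing positional args, remembering their keys."""
    kwarg_keys = tuple(kwargs.keys())
    flat_args = args + tuple(kwargs[k] for k in kwarg_keys)
    return kwarg_keys, flat_args


def unpack_kwargs(
    kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Inverse of pack_kwargs: split the flat tuple back into (args, kwargs)."""
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):]))
    return args, kwargs
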
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """
    Cast any Tensors in *args or **kwargs to FP16.

    Doesn't currently support Tensors nested inside containers (e.g., dict).
    """
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(flat_args)
    tensor_inputs = tuple(t.half() if torch.is_floating_point(t) else t for t in tensor_inputs)
    flat_args = unpack_non_tensors(tensor_inputs, packed_non_tensor_inputs)
    args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
    return args, kwargs

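# Illustrative usage (assumes the helpers above are importable): floating-point
# tensors are cast to FP16, while integer/bool tensors and plain Python values
# pass through unchanged.
def _example_cast() -> None:
    x = torch.randn(2, 2)       # float32 -> float16
    idx = torch.tensor([0, 1])  # int64, left untouched
    args, kwargs = cast_inputs_to_fp16(x, idx, scale=2.0, mask=torch.ones(2).bool())
    assert args[0].dtype == torch.float16
    assert args[1].dtype == torch.int64
    assert kwargs["scale"] == 2.0
    assert kwargs["mask"].dtype == torch.bool
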
def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError(
            "Checkpointing is not compatible with .grad(), please use .backward() if possible"
        )

    tensor_inputs: Tuple = ctx.saved_tensors
    tensor_inputs = torch_checkpoint.detach_variable(tensor_inputs)
    if ctx.fwd_device is not None:
        tensor_inputs = tuple(
            t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs)
        )
        for i, need_grad in enumerate(ctx.grad_requirements):
            tensor_inputs[i].requires_grad = need_grad
    inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)

    # Store the current RNG states.
    bwd_rng_state = get_rng_state()

    # Set the states to what they were before the forward pass.
    set_rng_state(ctx.fwd_rng_state)

    with torch.enable_grad(), enable_recomputing(), autocast(ctx.had_autocast_in_fwd):
        unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
        outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
        tensor_outputs, _ = split_non_tensors(outputs)

    # Set the states back to what they were at the start of this function.
    set_rng_state(bwd_rng_state)

    # Run backward() with only the Tensors that require grad.
    outputs_with_grad = []
    args_with_grad = []
    for i in range(len(tensor_outputs)):
        if tensor_outputs[i].requires_grad:
            outputs_with_grad.append(tensor_outputs[i])
            args_with_grad.append(args[i])
    if len(outputs_with_grad) == 0:
        raise RuntimeError(
            "None of the outputs have requires_grad=True, this checkpoint() is not necessary"
        )

    torch.autograd.backward(outputs_with_grad, args_with_grad)

    grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
    return (None, None, None, None) + grads

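# The RNG bookkeeping above exists so that recomputation sees the same random
# state as the original forward, making stochastic ops such as dropout
# reproduce identical masks. A standalone illustration using torch's CPU RNG
# (the get_rng_state/set_rng_state helpers above additionally cover CUDA):
def _example_rng_rewind() -> None:
    drop = torch.nn.Dropout(p=0.5)
    x = torch.ones(8)

    saved_state = torch.get_rng_state()  # analogous to ctx.fwd_rng_state
    out_fwd = drop(x)                    # original forward pass

    _ = drop(torch.ones(8))              # RNG advances in between

    current = torch.get_rng_state()      # analogous to bwd_rng_state
    torch.set_rng_state(saved_state)     # rewind before recomputing
    out_recomputed = drop(x)
    torch.set_rng_state(current)         # restore, as backward() does

    assert torch.equal(out_fwd, out_recomputed)
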
def test_split_unpack():
    """Test split_non_tensors and unpack_non_tensors."""
    x = torch.Tensor([1])
    y = torch.Tensor([2])

    tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
    assert tensors == (x, y)
    assert packed_non_tensors == {
        "is_tensor": [True, True, False, False],
        "objects": [None, 3],
    }
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (x, y, None, 3)

    tensors, packed_non_tensors = split_non_tensors((None, 3, x, y))
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (None, 3, x, y)

    tensors, packed_non_tensors = split_non_tensors((None, 3))
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (None, 3)

    tensors, packed_non_tensors = split_non_tensors((x, y))
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (x, y)

    recon = unpack_non_tensors(tensors, None)
    assert recon == (x, y)

    with pytest.raises(AssertionError):
        # The second arg must be a dict.
        recon = unpack_non_tensors(tensors, set())
    with pytest.raises(AssertionError):
        # The content of the second arg must be consistent with the tensors.
        recon = unpack_non_tensors(tensors, {"is_tensor": [], "objects": []})

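# A minimal sketch of the two helpers exercised above, reconstructed from the
# assertions in the test; the real implementations may differ in details
# (e.g., handling of a bare Tensor input). The underscore-prefixed names mark
# these as illustrative stand-ins, not the library's own functions.
def _split_non_tensors_sketch(
    mixed: Tuple[Any, ...]
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
    tensors: list = []
    packed: Dict[str, Any] = {"is_tensor": [], "objects": []}
    for obj in mixed:
        if isinstance(obj, torch.Tensor):
            packed["is_tensor"].append(True)
            tensors.append(obj)
        else:
            packed["is_tensor"].append(False)
            packed["objects"].append(obj)
    return tuple(tensors), packed


def _unpack_non_tensors_sketch(
    tensors: Tuple[torch.Tensor, ...], packed: Optional[Dict[str, Any]]
) -> Tuple[Any, ...]:
    if packed is None:
        return tensors
    assert isinstance(packed, dict)
    assert len(tensors) + len(packed["objects"]) == len(packed["is_tensor"])
    mixed, t_it, o_it = [], iter(tensors), iter(packed["objects"])
    for is_tensor in packed["is_tensor"]:
        mixed.append(next(t_it) if is_tensor else next(o_it))
    return tuple(mixed)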