def test_pack_unpack():
    """Test pack_kwargs and unpack_kwargs."""
    kwarg_keys, flat_args = pack_kwargs(1, 2, 3, 4)
    assert kwarg_keys == tuple()
    assert flat_args == (1, 2, 3, 4)

    kwarg_keys, flat_args = pack_kwargs(a=1, b={2: "2"}, c={3}, d=[4], e=(5,))
    assert kwarg_keys == ("a", "b", "c", "d", "e")
    assert flat_args == (1, {2: "2"}, {3}, [4], (5,))

    kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
    assert kwarg_keys == ("a", "b")
    assert flat_args == (1, 2, 3, 4)

    args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
    assert args == (1, 2)
    assert kwargs == {"a": 3, "b": 4}

    args, kwargs = unpack_kwargs([], flat_args)
    assert kwargs == {}
    assert args == (1, 2, 3, 4)

    args, kwargs = unpack_kwargs(["a", "b", "c", "d"], flat_args)
    assert kwargs == {"a": 1, "b": 2, "c": 3, "d": 4}
    assert args == tuple()

    with pytest.raises(AssertionError):
        # too many keys should assert.
        args, kwargs = unpack_kwargs(["a", "b", "c", "d", "e"], flat_args)
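# pack_kwargs / unpack_kwargs themselves are not shown in this section. The
# following is a minimal sketch that satisfies every assertion in
# test_pack_unpack above; the library's actual implementation may differ in
# details, so treat this as illustrative only.
from typing import Any, Dict, List, Tuple


def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """Flatten args and kwargs into (kwarg_keys, flat_args).

    Positional args come first in flat_args, followed by the kwarg values in
    the order of kwarg_keys.
    """
    kwarg_keys = tuple(kwargs.keys())
    flat_args = tuple(args) + tuple(kwargs[k] for k in kwarg_keys)
    return kwarg_keys, flat_args


def unpack_kwargs(
    kwarg_keys: List[str], flat_args: Tuple[Any, ...]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Inverse of pack_kwargs: the last len(kwarg_keys) entries become kwargs."""
    assert len(kwarg_keys) <= len(flat_args), "too many keys"
    if len(kwarg_keys) == 0:
        return tuple(flat_args), {}
    args = tuple(flat_args[: -len(kwarg_keys)])
    kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys):])}
    return args, kwargs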
def forward(  # type: ignore
    ctx: Any,
    dummy_tensor_requires_grad: torch.Tensor,
    run_function: Any,
    parent_ctx_dict: Dict[str, Any],
    kwarg_keys: Tuple[str, ...],
    *args: Any,
    **kwargs: Any,
) -> Any:
    torch_checkpoint.check_backward_validity(args)

    ctx.run_function = run_function
    ctx.kwarg_keys = kwarg_keys
    ctx.fwd_rng_state = get_rng_state()
    ctx.had_autocast_in_fwd = is_autocast_enabled()

    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
    if parent_ctx_dict["offload"]:
        ctx.fwd_device = tuple(x.device for x in tensor_inputs)
        ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
        tensor_inputs = tuple(x.to("cpu", non_blocking=True) for x in tensor_inputs)
    else:
        ctx.fwd_device, ctx.grad_requirements = None, None

    ctx.save_for_backward(*tensor_inputs)
    ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

    with torch.no_grad(), enable_checkpointing():
        unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
        outputs = run_function(*unpacked_args, **unpacked_kwargs)
        the_module = unpacked_args[0]

        # Because we run with torch.no_grad(), we can't actually access
        # outputs.requires_grad. Instead, we manually compute it by
        # checking if either the input or the module needs grads
        parameters = list(the_module.parameters())

        # If the module is wrapped by FlattenParamsWrapper, then the
        # parameters would have been deleted. If so, we need to access
        # the views into the flattened parameters.
        if hasattr(the_module, "_unflattened_param_views"):
            parameters += the_module._unflattened_param_views

        output_requires_grad = any(param.requires_grad for param in parameters) or any(
            x.requires_grad for x in tensor_inputs
        )
        parent_ctx_dict["output_requires_grad"] = output_requires_grad

    if not isinstance(outputs, torch.Tensor):
        # Autograd Functions don't like non-Tensor outputs. We can split the
        # non-Tensor and Tensor outputs, returning the former by reference
        # through *parent_ctx_dict* and returning the latter directly.
        outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
        parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs

    return outputs
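# split_non_tensors / unpack_non_tensors are used throughout this section but
# not defined in it. The sketch below only captures the round-trip contract the
# forward/backward passes rely on (split a mixed tuple into a tensor-only tuple
# plus a packed record, then reconstruct the original ordering). The packed
# record's exact format here is an assumption for illustration, not the
# library's internal representation.
from typing import Any, Dict, Optional, Tuple

import torch


def split_non_tensors(
    mixed: Any,
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, Any]]]:
    """Split a tuple of mixed values into (tensors, packed_non_tensors)."""
    if isinstance(mixed, torch.Tensor):
        # A lone tensor needs no packing record.
        return (mixed,), None
    tensors, packed = [], {"is_tensor": [], "objects": []}
    for obj in mixed:
        if isinstance(obj, torch.Tensor):
            packed["is_tensor"].append(True)
            tensors.append(obj)
        else:
            packed["is_tensor"].append(False)
            packed["objects"].append(obj)
    return tuple(tensors), packed


def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...], packed: Optional[Dict[str, Any]]
) -> Tuple[Any, ...]:
    """Inverse of split_non_tensors: re-interleave tensors and non-tensors."""
    if packed is None:
        return tensors
    mixed, tensor_it, obj_it = [], iter(tensors), iter(packed["objects"])
    for is_tensor in packed["is_tensor"]:
        mixed.append(next(tensor_it) if is_tensor else next(obj_it))
    return tuple(mixed)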
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """
    Cast any Tensors in *args or **kwargs to FP16.

    Doesn't currently support Tensors nested inside containers (e.g., dict).
    """
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(flat_args)
    tensor_inputs = tuple(t.half() if torch.is_floating_point(t) else t for t in tensor_inputs)
    flat_args = unpack_non_tensors(tensor_inputs, packed_non_tensor_inputs)
    args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
    return args, kwargs
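# Illustrative usage of cast_inputs_to_fp16: floating-point tensors are cast to
# half precision, while integer/bool tensors and plain Python values pass
# through unchanged. The inputs below are made up for the example.
import torch

x = torch.randn(4, 4)                   # float32 -> will become float16
mask = torch.ones(4, dtype=torch.bool)  # non-floating tensor -> unchanged
args, kwargs = cast_inputs_to_fp16(x, mask=mask, scale=2.0)
assert args[0].dtype == torch.float16
assert kwargs["mask"].dtype == torch.bool
assert kwargs["scale"] == 2.0           # non-tensor value passes through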
def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError(
            "Checkpointing is not compatible with .grad(), please use .backward() if possible"
        )

    tensor_inputs: Tuple = ctx.saved_tensors
    tensor_inputs = torch_checkpoint.detach_variable(tensor_inputs)
    if ctx.fwd_device is not None:
        tensor_inputs = tuple(
            t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs)
        )
        for i, need_grad in enumerate(ctx.grad_requirements):
            tensor_inputs[i].requires_grad = need_grad
    inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)

    # Store the current states.
    bwd_rng_state = get_rng_state()

    # Set the states to what they were before the forward pass.
    set_rng_state(ctx.fwd_rng_state)

    with torch.enable_grad(), enable_recomputing(), autocast(ctx.had_autocast_in_fwd):
        unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
        outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
        tensor_outputs, _ = split_non_tensors(outputs)

    # Set the states back to what they were at the start of this function.
    set_rng_state(bwd_rng_state)

    # Run backward() with only Tensors that require grad.
    outputs_with_grad = []
    args_with_grad = []
    for i in range(len(tensor_outputs)):
        if tensor_outputs[i].requires_grad:
            outputs_with_grad.append(tensor_outputs[i])
            args_with_grad.append(args[i])
    if len(outputs_with_grad) == 0:
        raise RuntimeError(
            "None of the outputs have requires_grad=True, this checkpoint() is not necessary"
        )

    torch.autograd.backward(outputs_with_grad, args_with_grad)

    grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs)
    return (None, None, None, None) + grads
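# get_rng_state / set_rng_state are referenced above but not shown. A plausible
# minimal version (an assumption, not necessarily the library's exact code)
# saves and restores both the CPU RNG state and, when CUDA is available, the
# current device's CUDA RNG state, so the recomputed forward in backward() sees
# the same dropout masks as the original forward.
from typing import Any, Dict

import torch


def get_rng_state() -> Dict[str, Any]:
    state = {"torch_rng_state": torch.get_rng_state()}
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()
    return state


def set_rng_state(state: Dict[str, Any]) -> None:
    torch.set_rng_state(state["torch_rng_state"])
    if "cuda_rng_state" in state:
        torch.cuda.set_rng_state(state["cuda_rng_state"])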
def forward(  # type: ignore
    ctx: Any,
    run_function: Any,
    parent_ctx_dict: Dict[str, Any],
    kwarg_keys: Tuple[str, ...],
    *args: Any,
    **kwargs: Any,
) -> Any:
    if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
        torch_checkpoint.check_backward_validity(args)

    ctx.run_function = run_function
    ctx.kwarg_keys = kwarg_keys
    ctx.fwd_rng_state = get_rng_state()
    ctx.had_autocast_in_fwd = is_autocast_enabled()

    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
    if parent_ctx_dict["offload"]:
        ctx.fwd_device = tuple(x.device for x in tensor_inputs)
        ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
        tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
    else:
        ctx.fwd_device, ctx.grad_requirements = None, None

    ctx.save_for_backward(*tensor_inputs)
    ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

    with torch.no_grad():
        unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
        outputs = run_function(*unpacked_args, **unpacked_kwargs)
        the_module = unpacked_args[0]
        inc_counter(the_module)

    if not isinstance(outputs, torch.Tensor):
        # Autograd Functions don't like non-Tensor outputs. We can split the
        # non-Tensor and Tensor outputs, returning the former by reference
        # through *parent_ctx_dict* and returning the latter directly.
        outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
        parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs

    return outputs
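# Putting the pieces together: a caller would typically flatten the module
# call's kwargs with pack_kwargs, pass everything through the autograd Function
# above, and then restore any non-tensor outputs from parent_ctx_dict. This is
# only a sketch; the helper name (checkpointed_call), the Function class name
# (CheckpointFunction), and the way run_function wraps the module are all
# assumptions made for illustration.
from typing import Any, Dict

import torch


def checkpointed_call(
    module: torch.nn.Module, offload_to_cpu: bool, *args: Any, **kwargs: Any
) -> Any:
    def run_function(the_module: torch.nn.Module, *inner_args: Any, **inner_kwargs: Any) -> Any:
        # The real wrapper would call the module's original (unwrapped) forward;
        # calling the module directly here is a simplification.
        return the_module(*inner_args, **inner_kwargs)

    # Shared dict through which forward() returns values "by reference".
    parent_ctx_dict: Dict[str, Any] = {"offload": offload_to_cpu}

    # Flatten kwargs so the autograd Function only ever sees positional args;
    # the module is prepended so forward() can recover it via unpacked_args[0].
    kwarg_keys, flat_args = pack_kwargs(module, *args, **kwargs)
    output = CheckpointFunction.apply(run_function, parent_ctx_dict, kwarg_keys, *flat_args)

    if not isinstance(output, torch.Tensor):
        # Re-attach non-tensor outputs that forward() stashed in parent_ctx_dict.
        packed_non_tensor_outputs = parent_ctx_dict.get("packed_non_tensor_outputs")
        if packed_non_tensor_outputs is not None:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    return output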