def test_pack_unpack():
    """Test pack_kwargs and unpack_kwargs."""
    kwarg_keys, flat_args = pack_kwargs(1, 2, 3, 4)
    assert kwarg_keys == tuple()
    assert flat_args == (1, 2, 3, 4)

    kwarg_keys, flat_args = pack_kwargs(a=1, b={2: "2"}, c={3}, d=[4], e=(5,))
    assert kwarg_keys == ("a", "b", "c", "d", "e")
    assert flat_args == (1, {2: "2"}, {3}, [4], (5,))

    kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
    assert kwarg_keys == ("a", "b")
    assert flat_args == (1, 2, 3, 4)

    args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
    assert args == (1, 2)
    assert kwargs == {"a": 3, "b": 4}

    args, kwargs = unpack_kwargs([], flat_args)
    assert kwargs == {}
    assert args == (1, 2, 3, 4)

    args, kwargs = unpack_kwargs(["a", "b", "c", "d"], flat_args)
    assert kwargs == {"a": 1, "b": 2, "c": 3, "d": 4}
    assert args == tuple()

    with pytest.raises(AssertionError):
        # Too many keys should assert.
        args, kwargs = unpack_kwargs(["a", "b", "c", "d", "e"], flat_args)
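# A minimal sketch of the pack_kwargs/unpack_kwargs helpers exercised above, written
# to match the behaviour the test asserts; this is an illustration, not necessarily
# the library's exact implementation. Keyword arguments are flattened into extra
# positional args, with their keys recorded so the original (args, kwargs) split can
# be reconstructed later.
from typing import Any, Dict, Tuple


def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """Flatten (args, kwargs) into (kwarg_keys, flat_args)."""
    kwarg_keys = tuple(kwargs.keys())
    flat_args = args + tuple(kwargs.values())
    return kwarg_keys, flat_args


def unpack_kwargs(kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Invert pack_kwargs: the trailing len(kwarg_keys) entries become kwargs again."""
    assert len(kwarg_keys) <= len(flat_args), "too many keyword keys for the flattened args"
    if not kwarg_keys:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):]))
    return args, kwargs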
def _checkpointed_forward(
    original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
) -> Any:
    module = weak_self()

    # If gradients are disabled, just use the original `.forward()` method directly.
    # Doing so also ensures the internal fwd counter is not incremented in the forward pass,
    # which would be an issue during eval since there wouldn't be a corresponding backward pass
    # to decrement the fwd counter.
    # See https://github.com/facebookresearch/fairscale/pull/709.
    if not torch.is_grad_enabled():
        return original_forward(module, *args, **kwargs)

    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    args = (module,) + args
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {
        "offload": offload_to_cpu,
    }
    # A dummy tensor with grad is used to ensure the backward pass is called. This is needed
    # when original_forward's inputs are non-tensor (i.e. a tuple). Using this dummy tensor
    # avoids requiring users to set their input tensors' requires_grad flag. In the case
    # of tuple-type inputs, setting the flag won't even trigger the backward pass.
    output = CheckpointFunction.apply(
        torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, kwarg_keys, *flat_args
    )
    if not isinstance(output, torch.Tensor):
        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    return output
def _checkpointed_forward(
    original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
) -> Any:
    module = weak_self()

    # If gradients are disabled, just use the original `.forward()` method directly.
    if not torch.is_grad_enabled() or thread_local.is_checkpointing_disabled:
        return original_forward(module, *args, **kwargs)

    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    args = (module,) + args
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {
        "offload": offload_to_cpu,
    }
    # A dummy tensor with grad is used to ensure the backward pass is called. This is needed
    # when original_forward's inputs are non-tensor (i.e. a tuple). Using this dummy tensor
    # avoids requiring users to set their input tensors' requires_grad flag. In the case
    # of tuple-type inputs, setting the flag won't even trigger the backward pass.
    #
    # One implication of this is that since we always feed in a dummy tensor that needs
    # grad, the output will always require grad, even if it originally wouldn't, such as
    # when the module and the original input both do not require grad. We get around this
    # by saving the desired requires_grad value in output and detaching the output if needed.
    output = CheckpointFunction.apply(
        torch.tensor([], requires_grad=True), original_forward, parent_ctx_dict, kwarg_keys, *flat_args
    )
    output_requires_grad = parent_ctx_dict["output_requires_grad"]
    if not isinstance(output, torch.Tensor):
        # If the output should not require grad, detach it, since otherwise it will
        # always have requires_grad = True due to the dummy tensor input above.
        output = [x.detach() if not output_requires_grad else x for x in output]
        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    else:
        # If the output should not require grad, detach it, since otherwise it will
        # always have requires_grad = True due to the dummy tensor input above.
        if not output_requires_grad:
            output = output.detach()
    return output
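# A hedged sketch of how a wrapper like _checkpointed_forward can be attached to a
# module. `checkpoint_wrapper_sketch` is a hypothetical name used for illustration,
# not the library's public API; the key idea is binding the class-level forward plus
# a weak reference to the instance, so the patched forward does not create a
# reference cycle with the module it wraps.
import functools
import weakref

import torch.nn as nn


def checkpoint_wrapper_sketch(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
    # After this, calling module(x) routes through _checkpointed_forward, which
    # re-resolves `self` via the weakref and then runs the original forward
    # inside CheckpointFunction.
    module.forward = functools.partial(  # type: ignore[assignment]
        _checkpointed_forward, type(module).forward, weakref.ref(module), offload_to_cpu
    )
    return module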
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """
    Cast any Tensors in *args or **kwargs to FP16.

    Doesn't currently support Tensors nested inside containers (e.g., dict).
    """
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(flat_args)
    tensor_inputs = tuple(t.half() if torch.is_floating_point(t) else t for t in tensor_inputs)
    flat_args = unpack_non_tensors(tensor_inputs, packed_non_tensor_inputs)
    args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
    return args, kwargs
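# Usage sketch for cast_inputs_to_fp16 (assumes the pack/unpack and
# split_non_tensors/unpack_non_tensors helpers it relies on are importable from the
# surrounding codebase): floating-point tensors are cast to float16, while
# non-floating tensors and plain Python values pass through untouched.
import torch

x = torch.randn(2, 2)                      # float32 -> cast to float16
mask = torch.ones(2, 2, dtype=torch.bool)  # non-floating tensor -> left unchanged
args, kwargs = cast_inputs_to_fp16(x, mask, scale=0.5)
assert args[0].dtype == torch.float16
assert args[1].dtype == torch.bool
assert kwargs == {"scale": 0.5}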
def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any) -> Any:
    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {"offload": offload_to_cpu}
    output = CheckpointFunction.apply(original_forward, parent_ctx_dict, kwarg_keys, *flat_args)
    if not isinstance(output, torch.Tensor):
        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    return output
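# The snippets above lean on split_non_tensors/unpack_non_tensors to separate the
# tensor arguments (which autograd can track) from arbitrary Python objects and to
# merge them back in their original order. A minimal sketch of that pair, matching
# the interface used above but not necessarily the library's exact implementation:
from typing import Any, Dict, Optional, Tuple

import torch


def split_non_tensors(
    mixed: Tuple[Any, ...]
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, Any]]]:
    """Split a tuple into its tensor elements plus a record of the non-tensor ones."""
    tensors = []
    packed: Dict[str, Any] = {"is_tensor": [], "objects": []}
    for obj in mixed:
        if isinstance(obj, torch.Tensor):
            packed["is_tensor"].append(True)
            tensors.append(obj)
        else:
            packed["is_tensor"].append(False)
            packed["objects"].append(obj)
    return tuple(tensors), packed


def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...], packed: Optional[Dict[str, Any]]
) -> Tuple[Any, ...]:
    """Invert split_non_tensors, restoring the original ordering."""
    if packed is None:
        return tensors
    mixed = []
    tensor_iter, obj_iter = iter(tensors), iter(packed["objects"])
    for is_tensor in packed["is_tensor"]:
        mixed.append(next(tensor_iter) if is_tensor else next(obj_iter))
    return tuple(mixed)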