Example #1
def test_pack_unpack():
    """Test pack_kwargs and unpack_kwargs."""
    kwarg_keys, flat_args = pack_kwargs(1, 2, 3, 4)
    assert kwarg_keys == tuple()
    assert flat_args == (1, 2, 3, 4)

    kwarg_keys, flat_args = pack_kwargs(a=1, b={2: "2"}, c={3}, d=[4], e=(5, ))
    assert kwarg_keys == ("a", "b", "c", "d", "e")
    assert flat_args == (1, {2: "2"}, {3}, [4], (5, ))

    kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
    assert kwarg_keys == ("a", "b")
    assert flat_args == (1, 2, 3, 4)

    args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
    assert args == (1, 2)
    assert kwargs == {"a": 3, "b": 4}

    args, kwargs = unpack_kwargs([], flat_args)
    assert kwargs == {}
    assert args == (1, 2, 3, 4)

    args, kwargs = unpack_kwargs(["a", "b", "c", "d"], flat_args)
    assert kwargs == {"a": 1, "b": 2, "c": 3, "d": 4}
    assert args == tuple()

    with pytest.raises(AssertionError):
        # too many keys should assert.
        args, kwargs = unpack_kwargs(["a", "b", "c", "d", "e"], flat_args)
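Based on the behaviour asserted in this test, a minimal sketch of what pack_kwargs / unpack_kwargs might look like is shown below; the real fairscale helpers may differ in detail and live in a fairscale utility module not shown here.

from typing import Any, Dict, Sequence, Tuple


def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """Flatten positional and keyword arguments into one tuple, remembering the keyword keys."""
    kwarg_keys = tuple(kwargs.keys())
    flat_args = args + tuple(kwargs.values())
    return kwarg_keys, flat_args


def unpack_kwargs(kwarg_keys: Sequence[str], flat_args: Tuple[Any, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """Inverse of pack_kwargs: split the flat tuple back into args and kwargs."""
    assert len(kwarg_keys) <= len(flat_args), "too many keyword keys for the flat args"
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):]))
    return args, kwargs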
Example #2
def _checkpointed_forward(original_forward: Any, weak_self: Any,
                          offload_to_cpu: bool, *args: Any,
                          **kwargs: Any) -> Any:
    module = weak_self()

    # If gradients are disabled, just use original `.forward()` method directly.
    # Doing so also ensures the internal fwd counter is not incremented in the forward pass,
    # which would be an issue during eval since there wouldn't be a corresponding backward pass
    # to decrement the fwd counter.
    # See https://github.com/facebookresearch/fairscale/pull/709.
    if not torch.is_grad_enabled():
        return original_forward(module, *args, **kwargs)

    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    args = (module, ) + args
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {
        "offload": offload_to_cpu,
    }
    # Dummy tensor with grad is used to ensure the backward pass is called. This is needed
    # when original_forward's inputs are non-tensor (i.e. a tuple). Using this dummy tensor
    # avoids requiring users to set their input tensors' requires_grad flag. In the case
    # of tuple type inputs, setting the flag won't even trigger the backward pass.
    output = CheckpointFunction.apply(torch.tensor([], requires_grad=True),
                                      original_forward, parent_ctx_dict,
                                      kwarg_keys, *flat_args)
    if not isinstance(output, torch.Tensor):
        packed_non_tensor_outputs = parent_ctx_dict[
            "packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    return output
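The dummy-tensor trick described in the comments above can be shown in isolation: a custom autograd Function only attaches a grad_fn to its output (and therefore only gets a backward call) if at least one input tensor requires grad, so an empty tensor with requires_grad=True keeps the graph alive even when every real input is a non-tensor. A small standalone demo, not part of fairscale:

import torch


class DummyGradDemo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy: torch.Tensor, value: float) -> torch.Tensor:
        # `value` is a plain float (non-tensor); only `dummy` carries autograd state.
        return torch.full((1,), value)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One gradient (or None) per forward input.
        return grad_output.new_zeros(0), None


out = DummyGradDemo.apply(torch.tensor([], requires_grad=True), 3.0)
print(out.requires_grad)  # True: the dummy tensor keeps the output in the graph

out = DummyGradDemo.apply(torch.tensor([]), 3.0)
print(out.requires_grad)  # False: nothing requires grad, so backward would never run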
Example #3
def _checkpointed_forward(original_forward: Any, weak_self: Any,
                          offload_to_cpu: bool, *args: Any,
                          **kwargs: Any) -> Any:
    module = weak_self()

    # If gradients are disabled, just use original `.forward()` method directly.
    if not torch.is_grad_enabled() or thread_local.is_checkpointing_disabled:
        return original_forward(module, *args, **kwargs)

    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    args = (module, ) + args
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {
        "offload": offload_to_cpu,
    }
    # Dummy tensor with grad is used to ensure the backward pass is called. This is needed
    # when original_forward's inputs are non-tensor (i.e. a tuple). Using this dummy tensor
    # avoids requiring users to set their input tensors' requires_grad flag. In the case
    # of tuple type inputs, setting the flag won't even trigger the backward pass.
    #
    # One implication of this is that since we always feed in a dummy tensor that
    # requires grad, the output will always require grad, even if it originally
    # wouldn't, such as when the module and the original input both do not require grad.
    # We get around this by saving the desired requires_grad value in parent_ctx_dict
    # and detaching the output if needed.
    output = CheckpointFunction.apply(torch.tensor([], requires_grad=True),
                                      original_forward, parent_ctx_dict,
                                      kwarg_keys, *flat_args)
    output_requires_grad = parent_ctx_dict["output_requires_grad"]
    if not isinstance(output, torch.Tensor):
        # If the output should not require grad, detach it, since otherwise it will
        # always have requires_grad = True due to the dummy tensor input above that
        # requires grad.
        output = [
            x.detach() if not output_requires_grad else x for x in output
        ]

        packed_non_tensor_outputs = parent_ctx_dict[
            "packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)

    else:
        # If the output should not require grad, detach it, since otherwise it will
        # always have requires_grad = True due to the dummy tensor input above that
        # requires grad.
        if not output_requires_grad:
            output = output.detach()

    return output
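Assuming these snippets come from fairscale's checkpoint_wrapper, where _checkpointed_forward is bound as the wrapped module's forward, a typical call site looks roughly like the sketch below (the import path and exact behaviour may vary between fairscale versions):

import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper  # assumed import path

layer = checkpoint_wrapper(nn.Linear(4, 4), offload_to_cpu=False)

x = torch.randn(2, 4, requires_grad=True)
y = layer(x)            # goes through _checkpointed_forward / CheckpointFunction
y.sum().backward()      # activations are recomputed during this backward

with torch.no_grad():
    layer(x)            # grad disabled: falls back to original_forward directly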
Example #4
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """
    Cast any Tensors in *args or **kwargs to FP16.

    Doesn't currently support Tensors nested inside containers (e.g., dict).
    """
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(flat_args)
    tensor_inputs = tuple(t.half() if torch.is_floating_point(t) else t
                          for t in tensor_inputs)
    flat_args = unpack_non_tensors(tensor_inputs, packed_non_tensor_inputs)
    args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
    return args, kwargs
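cast_inputs_to_fp16 leans on split_non_tensors / unpack_non_tensors to pull the tensors out of the flat argument tuple and weave them back in afterwards. Below is a minimal sketch consistent with that usage; the actual fairscale helpers may store the bookkeeping differently.

from typing import Any, Dict, Optional, Tuple

import torch


def split_non_tensors(
    mixed: Tuple[Any, ...]
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, Any]]]:
    """Split a flat tuple into its tensors plus metadata needed to restore the rest."""
    tensors = []
    packed: Dict[str, Any] = {"is_tensor": [], "objects": []}
    for item in mixed:
        if isinstance(item, torch.Tensor):
            packed["is_tensor"].append(True)
            tensors.append(item)
        else:
            packed["is_tensor"].append(False)
            packed["objects"].append(item)
    return tuple(tensors), packed


def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...], packed: Optional[Dict[str, Any]]
) -> Tuple[Any, ...]:
    """Inverse of split_non_tensors: re-interleave tensors and non-tensor objects."""
    if packed is None:
        return tensors
    mixed = []
    tensor_it, object_it = iter(tensors), iter(packed["objects"])
    for is_tensor in packed["is_tensor"]:
        mixed.append(next(tensor_it) if is_tensor else next(object_it))
    return tuple(mixed)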
Example #5
def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool,
                          *args: Any, **kwargs: Any) -> Any:
    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {"offload": offload_to_cpu}
    output = CheckpointFunction.apply(original_forward, parent_ctx_dict,
                                      kwarg_keys, *flat_args)
    if not isinstance(output, torch.Tensor):
        packed_non_tensor_outputs = parent_ctx_dict[
            "packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    return output