def _checkpointed_forward(original_forward: Any, weak_self: Any,
                          offload_to_cpu: bool, *args: Any,
                          **kwargs: Any) -> Any:
    module = weak_self()

    # If gradients are disabled, just use the original `.forward()` method directly.
    # Doing so also ensures the internal fwd counter is not incremented in the forward pass,
    # which would be an issue during eval since there wouldn't be a corresponding backward pass
    # to decrement the fwd counter.
    # See https://github.com/facebookresearch/fairscale/pull/709.
    if not torch.is_grad_enabled():
        return original_forward(module, *args, **kwargs)

    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    args = (module, ) + args
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {
        "offload": offload_to_cpu,
    }
    # Dummy tensor with grad is used to ensure the backward pass is called. This is needed
    # when original_forward's inputs are non-tensor (e.g. a tuple). Using this dummy tensor
    # avoids requiring users to set their input tensors' requires_grad flag. In the case
    # of tuple-type inputs, setting the flag won't even trigger the backward pass.
    output = CheckpointFunction.apply(torch.tensor([], requires_grad=True),
                                      original_forward, parent_ctx_dict,
                                      kwarg_keys, *flat_args)
    if not isinstance(output, torch.Tensor):
        packed_non_tensor_outputs = parent_ctx_dict[
            "packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    return output
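
For context, a minimal sketch of how a wrapper might bind this helper to a module's forward via functools.partial and a weakref. The checkpoint_wrap name and the exact binding below are illustrative assumptions, not the library's actual wrapper code.

import functools
import weakref

import torch.nn as nn


def checkpoint_wrap(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
    # Hypothetical binding: pass the unbound class-level forward plus a weak
    # reference to the module so _checkpointed_forward can recover `self`
    # without creating a reference cycle on the module itself.
    module.forward = functools.partial(
        _checkpointed_forward, type(module).forward, weakref.ref(module), offload_to_cpu
    )
    return module
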
def _checkpointed_forward(original_forward: Any, weak_self: Any,
                          offload_to_cpu: bool, *args: Any,
                          **kwargs: Any) -> Any:
    module = weak_self()

    # If gradients are disabled, just use the original `.forward()` method directly.
    if not torch.is_grad_enabled() or thread_local.is_checkpointing_disabled:
        return original_forward(module, *args, **kwargs)

    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    args = (module, ) + args
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {
        "offload": offload_to_cpu,
    }
    # Dummy tensor with grad is used to ensure the backward pass is called. This is needed
    # when original_forward's inputs are non-tensor (e.g. a tuple). Using this dummy tensor
    # avoids requiring users to set their input tensors' requires_grad flag. In the case
    # of tuple-type inputs, setting the flag won't even trigger the backward pass.
    #
    # One implication of this is that since we always feed in a dummy tensor that requires
    # grad, the output will always require grad, even if it originally wouldn't (e.g. when
    # neither the module nor the original input requires grad). We get around this by saving
    # the desired requires_grad value for the output and detaching the output if needed.
    output = CheckpointFunction.apply(torch.tensor([], requires_grad=True),
                                      original_forward, parent_ctx_dict,
                                      kwarg_keys, *flat_args)
    output_requires_grad = parent_ctx_dict["output_requires_grad"]
    if not isinstance(output, torch.Tensor):
        # If the output should not require grad, detach it; otherwise it would
        # always have requires_grad = True because of the dummy tensor input
        # above, which requires grad.
        output = [
            x.detach() if not output_requires_grad else x for x in output
        ]

        packed_non_tensor_outputs = parent_ctx_dict[
            "packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)

    else:
        # If the output should not require grad, detach it; otherwise it would
        # always have requires_grad = True because of the dummy tensor input
        # above, which requires grad.
        if not output_requires_grad:
            output = output.detach()

    return output
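
A standalone illustration of the requires_grad leakage the comments above describe: the output of a custom autograd Function requires grad whenever any input does, even if the "real" inputs do not, which is why the detach step is needed. This is plain PyTorch behavior, independent of the snippet.

import torch


class Passthrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return None, grad_out * 2


dummy = torch.tensor([], requires_grad=True)
x = torch.ones(2)                      # the "real" input does not require grad
out = Passthrough.apply(dummy, x)
assert out.requires_grad               # grad tracking leaked in via the dummy
assert not out.detach().requires_grad  # detaching restores the desired flag
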
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """
    Cast any Tensors in *args or **kwargs to FP16.

    Doesn't currently support Tensors nested inside containers (e.g., dict).
    """
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(flat_args)
    tensor_inputs = tuple(t.half() if torch.is_floating_point(t) else t
                          for t in tensor_inputs)
    flat_args = unpack_non_tensors(tensor_inputs, packed_non_tensor_inputs)
    args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
    return args, kwargs
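
A quick usage example of the helper above, assuming the container helpers it calls (pack_kwargs, split_non_tensors, unpack_non_tensors, unpack_kwargs) are importable: floating-point tensors are cast to fp16, while integer tensors and non-tensor values pass through unchanged.

import torch

x = torch.ones(2, dtype=torch.float32)
idx = torch.tensor([0, 1])                        # integer tensor, left as-is
args, kwargs = cast_inputs_to_fp16(x, idx, mask=torch.zeros(2), scale=2.0)
assert args[0].dtype == torch.float16
assert args[1].dtype == torch.int64
assert kwargs["mask"].dtype == torch.float16
assert kwargs["scale"] == 2.0                     # non-tensor passes through
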
    @staticmethod
    def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )

        tensor_inputs: Tuple = ctx.saved_tensors
        tensor_inputs = torch_checkpoint.detach_variable(tensor_inputs)
        if ctx.fwd_device is not None:
            tensor_inputs = tuple(
                t.to(ctx.fwd_device[i], non_blocking=True)
                for i, t in enumerate(tensor_inputs))
            for i, need_grad in enumerate(ctx.grad_requirements):
                tensor_inputs[i].requires_grad = need_grad
        inputs = unpack_non_tensors(tensor_inputs,
                                    ctx.packed_non_tensor_inputs)

        # Store the current states.
        bwd_rng_state = get_rng_state()

        # Set the states to what they were before the forward pass.
        set_rng_state(ctx.fwd_rng_state)

        with torch.enable_grad(), enable_recomputing(), autocast(
                ctx.had_autocast_in_fwd):
            unpacked_args, unpacked_kwargs = unpack_kwargs(
                ctx.kwarg_keys, inputs)
            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
            tensor_outputs, _ = split_non_tensors(outputs)

        # Set the states back to what they were at the start of this function.
        set_rng_state(bwd_rng_state)

        # Run backward() with only Tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(tensor_outputs)):
            if tensor_outputs[i].requires_grad:
                outputs_with_grad.append(tensor_outputs[i])
                args_with_grad.append(args[i])

        if len(outputs_with_grad) == 0:
            raise RuntimeError("None of the outputs have requires_grad=True, "
                               "this checkpoint() is not necessary")

        torch.autograd.backward(outputs_with_grad, args_with_grad)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in inputs)

        return (None, None, None, None) + grads
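
The return convention illustrated: the backward of a torch.autograd.Function must return one entry per forward input, in order, with None for inputs that need no gradient. The four leading Nones above line up with the non-differentiable leading arguments passed to apply in the dummy-tensor variant (the dummy tensor, original_forward, parent_ctx_dict, and kwarg_keys). A minimal standalone example of the same rule:

import torch


class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, factor, x):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One entry per forward input: None for `factor`, a gradient for `x`.
        return None, grad_out * ctx.factor


x = torch.ones(3, requires_grad=True)
Scale.apply(2.0, x).sum().backward()
assert torch.equal(x.grad, torch.full((3,), 2.0))
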
Example #5
def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool,
                          *args: Any, **kwargs: Any) -> Any:
    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict: Dict[str, Any] = {"offload": offload_to_cpu}
    output = CheckpointFunction.apply(original_forward, parent_ctx_dict,
                                      kwarg_keys, *flat_args)
    if not isinstance(output, torch.Tensor):
        packed_non_tensor_outputs = parent_ctx_dict[
            "packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
    return output
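
The flattening mentioned in the comment is done by pack_kwargs/unpack_kwargs. Below is a minimal reference sketch consistent with how those helpers are used in these snippets; the actual fairscale implementations may differ in detail.

from typing import Any, Dict, Tuple


def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    # Flatten (args, kwargs) into a single positional tuple plus the kwarg names,
    # so a torch.autograd.Function can receive everything positionally.
    kwarg_keys = tuple(kwargs.keys())
    flat_args = args + tuple(kwargs[k] for k in kwarg_keys)
    return kwarg_keys, flat_args


def unpack_kwargs(
    kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    # Inverse of pack_kwargs: split the flat tuple back into args and kwargs.
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):]))
    return args, kwargs
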
Example #6
def test_split_unpack():
    """Test split_non_tensors and unpack_non_tensors."""
    x = torch.Tensor([1])
    y = torch.Tensor([2])

    tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
    assert tensors == (x, y)
    assert packed_non_tensors == {
        "is_tensor": [True, True, False, False],
        "objects": [None, 3],
    }
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (x, y, None, 3)

    tensors, packed_non_tensors = split_non_tensors((None, 3, x, y))
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (None, 3, x, y)

    tensors, packed_non_tensors = split_non_tensors((None, 3))
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (None, 3)

    tensors, packed_non_tensors = split_non_tensors((x, y))
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (x, y)

    recon = unpack_non_tensors(tensors, None)
    assert recon == (x, y)

    with pytest.raises(AssertionError):
        # the second arg must be a dict.
        recon = unpack_non_tensors(tensors, set())

    with pytest.raises(AssertionError):
        # the content of the second arg must be internally consistent.
        recon = unpack_non_tensors(tensors, {"is_tensor": [], "objects": []})
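
For reference, a minimal implementation of split_non_tensors/unpack_non_tensors that is consistent with the assertions in this test; the real fairscale helpers may add fast paths and extra validation.

from typing import Any, Dict, Optional, Tuple

import torch


def split_non_tensors(
    mixed: Tuple[Any, ...]
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, Any]]]:
    # Separate tensors from everything else, remembering the original layout.
    tensors, packed = [], {"is_tensor": [], "objects": []}
    for obj in mixed:
        if isinstance(obj, torch.Tensor):
            packed["is_tensor"].append(True)
            tensors.append(obj)
        else:
            packed["is_tensor"].append(False)
            packed["objects"].append(obj)
    return tuple(tensors), packed


def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...], packed: Optional[Dict[str, Any]]
) -> Tuple[Any, ...]:
    # Inverse of split_non_tensors; with packed=None the tensors pass through.
    if packed is None:
        return tensors
    assert isinstance(packed, dict)
    assert len(packed["is_tensor"]) == len(tensors) + len(packed["objects"])
    mixed, it, io = [], iter(tensors), iter(packed["objects"])
    for is_tensor in packed["is_tensor"]:
        mixed.append(next(it) if is_tensor else next(io))
    return tuple(mixed)
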