Example #1
    def forward(  # type: ignore
            ctx: Any, dummy_tensor_requires_grad: torch.Tensor,
            run_function: Any, parent_ctx_dict: Dict[str, Any],
            kwarg_keys: Tuple[str, ...], *args: Any, **kwargs: Any) -> Any:
        torch_checkpoint.check_backward_validity(args)

        ctx.run_function = run_function
        ctx.kwarg_keys = kwarg_keys
        ctx.fwd_rng_state = get_rng_state()
        ctx.had_autocast_in_fwd = is_autocast_enabled()

        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
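        # When offloading is requested, move the saved activation tensors to
        # CPU, remembering their original devices and requires_grad flags so
        # that backward() can restore them before recomputation.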
        if parent_ctx_dict["offload"]:
            ctx.fwd_device = tuple(x.device for x in tensor_inputs)
            ctx.grad_requirements = tuple(x.requires_grad
                                          for x in tensor_inputs)
            tensor_inputs = tuple(
                x.to("cpu", non_blocking=True) for x in tensor_inputs)
        else:
            ctx.fwd_device, ctx.grad_requirements = None, None

        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad(), enable_checkpointing():
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
            outputs = run_function(*unpacked_args, **unpacked_kwargs)
            the_module = unpacked_args[0]

        # Because we run the forward under torch.no_grad(), outputs.requires_grad
        # is always False here. Instead, we compute it manually by checking
        # whether the inputs or the module's parameters require grad.
        parameters = list(the_module.parameters())

        # If the module is wrapped by FlattenParamsWrapper, the original
        # parameters have been deleted and we need to use the views into the
        # flattened parameters instead.
        if hasattr(the_module, "_unflattened_param_views"):
            parameters += the_module._unflattened_param_views

        output_requires_grad = any(
            param.requires_grad for param in parameters
        ) or any(x.requires_grad for x in tensor_inputs)
        parent_ctx_dict["output_requires_grad"] = output_requires_grad

        if not isinstance(outputs, torch.Tensor):
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict[
                "packed_non_tensor_outputs"] = packed_non_tensor_outputs

        return outputs
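
A minimal sketch of how a wrapper might drive the Function above. The class name CheckpointFunction and the helper checkpoint_call are assumptions made for illustration; pack_kwargs and unpack_non_tensors are the helpers already used in these examples.

import torch
from typing import Any, Dict


def checkpoint_call(run_function: Any, *args: Any, **kwargs: Any) -> Any:
    # Shared dict passed by reference: "offload" flows in, the non-tensor
    # outputs and "output_requires_grad" flow back out of forward().
    parent_ctx_dict: Dict[str, Any] = {"offload": False}

    # Flatten kwargs into positional args, since autograd.Function.apply
    # only accepts positional arguments.
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)

    # The dummy tensor guarantees at least one input requires grad, so
    # backward() is invoked even when no user input requires grad.
    dummy = torch.empty(0, requires_grad=True)

    outputs = CheckpointFunction.apply(
        dummy, run_function, parent_ctx_dict, kwarg_keys, *flat_args)

    # If forward() split off non-tensor outputs, stitch them back together.
    packed = parent_ctx_dict.get("packed_non_tensor_outputs")
    if packed is not None:
        outputs = unpack_non_tensors(outputs, packed)
    return outputs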
Example #2
    def forward(  # type: ignore
        ctx: Any,
        run_function: Any,
        parent_ctx_dict: Dict[str, Any],
        kwarg_keys: Tuple[str, ...],
        *args: Any,
        **kwargs: Any
    ) -> Any:
        if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
            torch_checkpoint.check_backward_validity(args)

        ctx.run_function = run_function
        ctx.kwarg_keys = kwarg_keys
        ctx.fwd_rng_state = get_rng_state()
        ctx.had_autocast_in_fwd = is_autocast_enabled()

        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
        if parent_ctx_dict["offload"]:
            ctx.fwd_device = tuple(x.device for x in tensor_inputs)
            ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
            tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
        else:
            ctx.fwd_device, ctx.grad_requirements = None, None

        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad():
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
            outputs = run_function(*unpacked_args, **unpacked_kwargs)
            the_module = unpacked_args[0]
            inc_counter(the_module)

        if not isinstance(outputs, torch.Tensor):
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
        return outputs
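
Both forward variants, as well as cast_inputs_to_fp16 below, route keyword arguments through autograd by flattening them with pack_kwargs and restoring them with unpack_kwargs. A minimal sketch of what such helpers could look like, consistent with how they are called in these examples (an assumption, not the library source):

from typing import Any, Dict, Tuple


def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    # e.g. pack_kwargs(1, 2, a=3, b=4) -> (("a", "b"), (1, 2, 3, 4))
    kwarg_keys = tuple(kwargs.keys())
    flat_args = tuple(args) + tuple(kwargs.values())
    return kwarg_keys, flat_args


def unpack_kwargs(
    kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    # Inverse of pack_kwargs: the last len(kwarg_keys) flat args are kwargs.
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[:-len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):]))
    return args, kwargs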
Example #3
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """
    Cast any Tensors in *args or **kwargs to FP16.

    Doesn't currently support Tensors nested inside containers (e.g., dict).
    """
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(flat_args)
    tensor_inputs = tuple(t.half() if torch.is_floating_point(t) else t
                          for t in tensor_inputs)
    flat_args = unpack_non_tensors(tensor_inputs, packed_non_tensor_inputs)
    args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
    return args, kwargs
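
For illustration, a possible call (the variable names are hypothetical; floating-point tensors are cast to half precision, while integer tensors and plain Python values pass through unchanged, assuming pack_kwargs/unpack_kwargs preserve argument order):

import torch

x = torch.randn(4)                        # float32 -> float16
mask = torch.ones(4, dtype=torch.long)    # integer dtype, left untouched
args, kwargs = cast_inputs_to_fp16(x, mask, scale=2.0, bias=torch.zeros(4))
# Expected: args[0].dtype == torch.float16, args[1].dtype == torch.long,
# kwargs["bias"].dtype == torch.float16, kwargs["scale"] == 2.0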
Example #4
    def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )

        tensor_inputs: Tuple = ctx.saved_tensors
        tensor_inputs = torch_checkpoint.detach_variable(tensor_inputs)
        if ctx.fwd_device is not None:
            tensor_inputs = tuple(
                t.to(ctx.fwd_device[i], non_blocking=True)
                for i, t in enumerate(tensor_inputs))
            for i, need_grad in enumerate(ctx.grad_requirements):
                tensor_inputs[i].requires_grad = need_grad
        inputs = unpack_non_tensors(tensor_inputs,
                                    ctx.packed_non_tensor_inputs)

        # Store the current states.
        bwd_rng_state = get_rng_state()

        # Set the states to what they were before the forward pass.
        set_rng_state(ctx.fwd_rng_state)

        with torch.enable_grad(), enable_recomputing(), autocast(
                ctx.had_autocast_in_fwd):
            unpacked_args, unpacked_kwargs = unpack_kwargs(
                ctx.kwarg_keys, inputs)
            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
            tensor_outputs, _ = split_non_tensors(outputs)

        # Set the states back to what they were at the start of this function.
        set_rng_state(bwd_rng_state)

        # Run backward() with only Tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(tensor_outputs)):
            if tensor_outputs[i].requires_grad:
                outputs_with_grad.append(tensor_outputs[i])
                args_with_grad.append(args[i])

        if len(outputs_with_grad) == 0:
            raise RuntimeError("None of the outputs have requires_grad=True, "
                               "this checkpoint() is not necessary")

        torch.autograd.backward(outputs_with_grad, args_with_grad)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in inputs)

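        # One None for each of the four leading forward() inputs (the dummy
        # tensor, run_function, parent_ctx_dict and kwarg_keys in Example #1's
        # signature), followed by the gradients for *args.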
        return (None, None, None, None) + grads
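
The RNG save/restore above is what keeps stochastic layers such as dropout identical between the original forward and the recomputation in backward. A minimal sketch of what get_rng_state and set_rng_state could look like (an assumption; the real helpers may also track per-device CUDA generators):

from typing import Any, Dict
import torch


def get_rng_state() -> Dict[str, Any]:
    # Capture the CPU generator state, plus the current CUDA generator when
    # a GPU is available, so recomputation reproduces the forward exactly.
    state = {"cpu_rng_state": torch.get_rng_state()}
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()
    return state


def set_rng_state(state: Dict[str, Any]) -> None:
    # Restore exactly what get_rng_state() captured.
    torch.set_rng_state(state["cpu_rng_state"])
    if "cuda_rng_state" in state:
        torch.cuda.set_rng_state(state["cuda_rng_state"])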
Example #5
def test_split_unpack():
    """Test split_non_tensors and unpack_non_tensors."""
    x = torch.Tensor([1])
    y = torch.Tensor([2])

    tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
    assert tensors == (x, y)
    assert packed_non_tensors == {
        "is_tensor": [True, True, False, False],
        "objects": [None, 3],
    }
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (x, y, None, 3)

    tensors, packed_non_tensors = split_non_tensors((None, 3, x, y))
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (None, 3, x, y)

    tensors, packed_non_tensors = split_non_tensors((None, 3))
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (None, 3)

    tensors, packed_non_tensors = split_non_tensors((x, y))
    recon = unpack_non_tensors(tensors, packed_non_tensors)
    assert recon == (x, y)

    recon = unpack_non_tensors(tensors, None)
    assert recon == (x, y)

    with pytest.raises(AssertionError):
        # The second arg must be a dict.
        recon = unpack_non_tensors(tensors, set())

    with pytest.raises(AssertionError):
        # The content of the second arg must be consistent with the tensors.
        recon = unpack_non_tensors(tensors, {"is_tensor": [], "objects": []})
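
The test above pins down the contract of the two helpers. A minimal implementation that would satisfy these assertions might look like the following sketch (not the library source):

from typing import Any, Dict, Optional, Sequence, Tuple
import torch


def split_non_tensors(
    mixed: Sequence[Any],
) -> Tuple[Tuple[torch.Tensor, ...], Optional[Dict[str, Any]]]:
    # Separate tensors from everything else, recording positions so the
    # original sequence can be rebuilt later.
    tensors = []
    packed: Dict[str, Any] = {"is_tensor": [], "objects": []}
    for obj in mixed:
        if isinstance(obj, torch.Tensor):
            packed["is_tensor"].append(True)
            tensors.append(obj)
        else:
            packed["is_tensor"].append(False)
            packed["objects"].append(obj)
    return tuple(tensors), packed


def unpack_non_tensors(
    tensors: Tuple[torch.Tensor, ...],
    packed: Optional[Dict[str, Any]],
) -> Tuple[Any, ...]:
    # None means nothing was split off; return the tensors unchanged.
    if packed is None:
        return tensors
    assert isinstance(packed, dict)
    is_tensor_flags = packed["is_tensor"]
    # Every tensor and every stored object must be accounted for.
    assert len(tensors) + len(packed["objects"]) == len(is_tensor_flags)
    tensor_it, obj_it = iter(tensors), iter(packed["objects"])
    return tuple(next(tensor_it) if flag else next(obj_it)
                 for flag in is_tensor_flags)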