Example #1
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )
        inputs = detach_variable(ctx.saved_tensors)

        # Store the current states.
        bwd_rng_state = utils.get_rng_state()

        # Set the states to what it used to be before the forward pass.
        utils.set_rng_state(ctx.fwd_rng_state)

        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)

        # Set the states back to what it was at the start of this function.
        utils.set_rng_state(bwd_rng_state)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

        torch.autograd.backward(outputs, args)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in inputs)
        return (None, ) + grads

    def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args):
        # grad may be disabled, e.g., during validation
        if torch.is_grad_enabled():
            checkpoint.check_backward_validity(args)

        ctx.run_function = run_function
        ctx.kwarg_keys = kwarg_keys
        ctx.fwd_rng_state = utils.get_rng_state()

        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad():
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
            outputs = run_function(*unpacked_args, **unpacked_kwargs)

        if isinstance(outputs, torch.Tensor):
            return outputs
        else:
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
            return outputs
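
The kwarg_keys / unpack_kwargs pair used above exists because keyword arguments do not pass cleanly through an autograd Function's apply(), so they are flattened into extra positional arguments plus a tuple of keys before the call and restored inside it. Below is a minimal, hypothetical sketch of that convention; the real helper functions behind these examples may differ in detail.

def pack_kwargs(*args, **kwargs):
    # Flatten keyword arguments into a flat positional tuple plus their keys,
    # e.g. pack_kwargs(1, 2, a=3, b=4) -> ("a", "b"), (1, 2, 3, 4).
    kwarg_keys = tuple(kwargs.keys())
    flat_args = args + tuple(kwargs.values())
    return kwarg_keys, flat_args

def unpack_kwargs(kwarg_keys, flat_args):
    # Inverse of pack_kwargs: the last len(kwarg_keys) entries are the kwargs.
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):]))
    return args, kwargs

# Round trip: keys ("scale",) and flat args (x, y, 2.0) unpack back to
# ((x, y), {"scale": 2.0}).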
Example #3
    def forward(ctx, run_function, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        # Capture the RNG state so the recomputation in backward() replays the
        # same randomness (e.g., dropout masks) as the original forward pass.
        ctx.fwd_rng_state = utils.get_rng_state()
        ctx.save_for_backward(*args)
        # Run the forward pass without building a graph; activations will be
        # recomputed during backward() instead of being stored.
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs
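
Example #3 shows only the forward half. For context, here is a minimal, self-contained sketch of the matching backward, using plain PyTorch and CPU RNG state only: it restores the forward-time RNG state, recomputes the forward under enable_grad, runs autograd on the recomputed outputs, and returns the input gradients. The class name SimpleCheckpoint is hypothetical and deliberately simpler than the examples above.

import torch

class SimpleCheckpoint(torch.autograd.Function):
    # Hypothetical, simplified checkpoint: store inputs plus the CPU RNG state
    # in forward, then recompute the forward during backward so intermediate
    # activations never need to be kept.

    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function
        ctx.fwd_rng_state = torch.get_rng_state()  # CPU RNG only in this sketch
        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Detach the saved inputs but keep their requires_grad flags.
        inputs = tuple(t.detach().requires_grad_(t.requires_grad)
                       for t in ctx.saved_tensors)

        bwd_rng_state = torch.get_rng_state()   # remember where the RNG is now
        torch.set_rng_state(ctx.fwd_rng_state)  # replay the forward's randomness
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)
        torch.set_rng_state(bwd_rng_state)      # restore the backward-time RNG

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, grad_outputs)
        return (None,) + tuple(inp.grad for inp in inputs)

# Usage: dropout is replayed with the same mask during recomputation.
x = torch.randn(4, 8, requires_grad=True)
loss = SimpleCheckpoint.apply(lambda t: torch.dropout(t, 0.5, True).sum(), x)
loss.backward()
print(x.grad.shape)  # torch.Size([4, 8])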
Example #4
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )

        tensor_inputs: Tuple = ctx.saved_tensors
        tensor_inputs = checkpoint.detach_variable(tensor_inputs)
        if ctx.fwd_device is not None:
            tensor_inputs = [
                t.to(ctx.fwd_device[i], non_blocking=True)
                for i, t in enumerate(tensor_inputs)
            ]
            for i, need_grad in enumerate(ctx.grad_requirements):
                tensor_inputs[i].requires_grad = need_grad
        inputs = unpack_non_tensors(tensor_inputs,
                                    ctx.packed_non_tensor_inputs)

        # Store the current states.
        bwd_rng_state = utils.get_rng_state()

        # Set the states to what it used to be before the forward pass.
        utils.set_rng_state(ctx.fwd_rng_state)

        with torch.enable_grad():
            unpacked_args, unpacked_kwargs = unpack_kwargs(
                ctx.kwarg_keys, inputs)
            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
            tensor_outputs, _ = split_non_tensors(outputs)
        # Set the states back to what it was at the start of this function.
        utils.set_rng_state(bwd_rng_state)

        # Run backward() with only Tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(tensor_outputs)):
            if tensor_outputs[i].requires_grad:
                outputs_with_grad.append(tensor_outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError("None of the outputs have requires_grad=True, "
                               "this checkpoint() is not necessary")

        torch.autograd.backward(outputs_with_grad, args_with_grad)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in inputs)
        return (None, None, None) + grads
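
Examples #1 and #4 lean on split_non_tensors / unpack_non_tensors because save_for_backward() and autograd Function outputs only handle tensors, so non-tensor values have to be set aside and re-interleaved later. A minimal, hypothetical version of that pair could look like the sketch below; the real helpers used by these examples may carry more bookkeeping.

import torch

def split_non_tensors(mixed):
    # Separate a tuple of mixed values into (tensors, packed): `packed` records
    # each position's kind and holds the non-tensor objects so the original
    # ordering can be restored later.
    if isinstance(mixed, torch.Tensor):
        return (mixed,), None
    tensors = []
    packed = {"is_tensor": [], "objects": []}
    for item in mixed:
        if isinstance(item, torch.Tensor):
            packed["is_tensor"].append(True)
            tensors.append(item)
        else:
            packed["is_tensor"].append(False)
            packed["objects"].append(item)
    return tuple(tensors), packed

def unpack_non_tensors(tensors, packed):
    # Inverse of split_non_tensors: re-interleave tensors and objects.
    if packed is None:
        return tensors
    mixed = []
    tensor_it, object_it = iter(tensors), iter(packed["objects"])
    for is_tensor in packed["is_tensor"]:
        mixed.append(next(tensor_it) if is_tensor else next(object_it))
    return tuple(mixed)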