Example #1
    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )
        inputs = detach_variable(ctx.saved_tensors)

        # Store the current states.
        bwd_rng_state = utils.get_rng_state()

        # Set the states to what they were before the forward pass.
        utils.set_rng_state(ctx.fwd_rng_state)

        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)

        # Set the states back to what they were at the start of this function.
        utils.set_rng_state(bwd_rng_state)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

        torch.autograd.backward(outputs, args)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in inputs)
        return (None, ) + grads
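Below is a minimal usage sketch of how a reentrant checkpoint Function like this one is normally driven from user code via torch.utils.checkpoint.checkpoint. It assumes a PyTorch version that accepts the use_reentrant flag; the module and tensor shapes are illustrative, not part of the snippet above.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Checkpoint one segment: its activations are dropped after the forward pass
# and recomputed inside backward(), as in the Function above.
segment = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

out = checkpoint(segment, x, use_reentrant=True)  # reentrant path, backed by a Function like the one above
loss = out.sum()
loss.backward()      # re-runs `segment` under enable_grad(), then backpropagates through it
print(x.grad.shape)  # gradients reach the checkpointed inputs
# Calling torch.autograd.grad(loss, (x,)) instead would hit the _is_checkpoint_valid() guard above.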
Example #2
    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )

        tensor_inputs: Tuple = ctx.saved_tensors
        tensor_inputs = checkpoint.detach_variable(tensor_inputs)
        if ctx.fwd_device is not None:
            tensor_inputs = [
                t.to(ctx.fwd_device[i], non_blocking=True)
                for i, t in enumerate(tensor_inputs)
            ]
            for i, need_grad in enumerate(ctx.grad_requirements):
                tensor_inputs[i].requires_grad = need_grad
        inputs = unpack_non_tensors(tensor_inputs,
                                    ctx.packed_non_tensor_inputs)

        # Store the current states.
        bwd_rng_state = utils.get_rng_state()

        # Set the states to what they were before the forward pass.
        utils.set_rng_state(ctx.fwd_rng_state)

        with torch.enable_grad():
            unpacked_args, unpacked_kwargs = unpack_kwargs(
                ctx.kwarg_keys, inputs)
            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
            tensor_outputs, _ = split_non_tensors(outputs)
        # Set the states back to what they were at the start of this function.
        utils.set_rng_state(bwd_rng_state)

        # Run backward() with only Tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(tensor_outputs)):
            if tensor_outputs[i].requires_grad:
                outputs_with_grad.append(tensor_outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError("None of the outputs have requires_grad=True, "
                               "this checkpoint() is not necessary")

        torch.autograd.backward(outputs_with_grad, args_with_grad)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in inputs)
        return (None, None, None) + grads
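For context, a short sketch of how this variant is typically reached through an activation-checkpoint wrapper such as fairscale's checkpoint_wrapper. The import path and the offload_to_cpu flag are assumptions for illustration, as are the module and shapes; they are not taken from the snippet above.

import torch
import torch.nn as nn

from fairscale.nn import checkpoint_wrapper  # assumed import path

# Wrap a submodule; with offload_to_cpu=True the inputs saved for backward are
# kept on CPU and copied back to ctx.fwd_device when backward() recomputes.
block = checkpoint_wrapper(
    nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)),
    offload_to_cpu=True,
)

x = torch.randn(4, 16, requires_grad=True)
loss = block(x).sum()
loss.backward()      # recomputes the wrapped module under the saved RNG state, then backpropagates
print(x.grad.shape)  # tensor inputs receive gradients; non-tensor inputs get None, as in the grads tuple above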