def make_checkpoint(function: Function, input: TensorOrTensors, index: int) -> TensorOrTensors:
    """Run *function* through one checkpoint/recompute cycle and return its output.

    A convenience wrapper around :class:`Checkpointing` whose call shape
    mirrors :func:`torch.utils.checkpoint.checkpoint`. It exists only so that
    tests and debugging sessions can exercise :class:`Checkpoint` and
    :class:`Recompute` without repeating the setup boilerplate.

    Args:
        function: the callable to checkpoint.
        input: the tensor(s) fed to *function*, wrapped into a ``Batch``.
        index: forwarded to ``Batch`` together with *input*.

    Returns:
        The ``tensor_or_tensors`` of the batch produced by the checkpoint,
        after ``recompute`` has been invoked on it.
    """
    checkpointing = Checkpointing(function, Batch(input, index))
    output = checkpointing.checkpoint()
    checkpointing.recompute(output)
    return output.tensor_or_tensors
def test_not_requires_grad():
    """A checkpointed output requires grad even if its input does not."""
    def double(tensor):
        return tensor * 2

    source = Batch(torch.rand(1, requires_grad=False))
    assert not source[0].requires_grad

    checkpointing = Checkpointing(double, source)
    output = checkpointing.checkpoint()
    # The checkpoint output must participate in autograd although the
    # input tensor was created with requires_grad=False.
    assert output[0].requires_grad

    # Recomputation must keep that property intact.
    checkpointing.recompute(output)
    assert output[0].requires_grad

    # Backward through the checkpointed graph should run without error.
    output.tensor.backward()