def test_serial_checkpoints(device):
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    # Verifies the ORDER in which checkpointed segments run their forward and
    # recompute passes when two checkpoints are serialized via fork/join:
    # 'b' must be recomputed (and backpropagated) before 'a' is recomputed.
    timeline = []

    # Autograd function that records each forward/backward call into
    # `timeline`, tagged with a name carried on the ctx.
    class Log(torch.autograd.Function):
        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            timeline.append(f"{name}:forward")
            # detach() so the recorded op is the Log node itself, not x's history.
            return x.detach()

        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f"{name}:backward")
            # None for the `name` argument, pass-through grad for `x`.
            return None, grad_output

    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)

    # Increase the next function sequence number.
    # NOTE(review): presumably this shifts autograd's internal sequence
    # numbering so the assertion is not an accident of node creation order —
    # kept verbatim from the upstream PR.
    _ = a + 1 + 2 + 3 + 4 + 5

    a = checkpoint(partial(Log.apply, "a"), a)

    # fork/join insert a phony dependency so that `b`'s branch depends on
    # `a`'s branch in the autograd graph, serializing the two checkpoints.
    a, phony = fork(a)
    b = join(b, phony)

    b = checkpoint(partial(Log.apply, "b"), b)

    c = torch.cat((a, b))
    out = c.sum()

    #                       +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                       +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()

    # Forward runs a then b; backward recomputes/backprops b first (it is
    # upstream of the join), then recomputes/backprops a.
    assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
def test_non_grad_output():
    """A checkpointed module may return a tuple that mixes a grad-requiring
    tensor with a freshly created non-grad tensor; backward through the
    grad-requiring element must still work."""

    class MixedOutput(nn.Module):
        # First element carries autograd history; second is a brand-new
        # tensor with no grad requirement.
        def forward(self, x):
            doubled = x * 2
            return (doubled, torch.rand(1))

    net = MixedOutput()
    x = torch.rand(1, requires_grad=True)
    grad_part, _non_grad_part = checkpoint(net, x)
    grad_part.backward()
def test_not_requires_grad_with_parameter():
    """Even when the checkpointed INPUT does not require grad, a captured
    grad-requiring tensor used inside the function must still receive a
    gradient after backward."""
    frozen_input = torch.rand(1, requires_grad=False)
    weight = torch.rand(1, requires_grad=True)

    def scale(value):
        # `weight` is captured from the enclosing scope, not passed in.
        return value * weight

    result = checkpoint(scale, frozen_input)
    result.backward()

    assert weight.grad is not None
def test_random_in_checkpoint(device):
    """Checkpoint must restore RNG state on recompute so that stochastic ops
    (dropout) produce the same mask in forward and recomputation, yielding
    gradients identical to an un-checkpointed run."""
    dropout = nn.Dropout(p=0.5)

    # Reference run: plain dropout, no checkpointing.
    torch.manual_seed(0)
    ref_input = torch.randn(3, 3, device=device, requires_grad=True)
    dropout(ref_input).norm().backward()

    # Same seed and computation, but routed through checkpoint.
    torch.manual_seed(0)
    chk_input = torch.randn(3, 3, device=device, requires_grad=True)
    checkpoint(dropout, chk_input).norm().backward()

    assert torch.allclose(ref_input.grad, chk_input.grad)
def test_detect_checkpointing_recomputing():
    """is_checkpointing()/is_recomputing() must report (True, False) during
    the checkpointed forward pass and (False, True) during the recompute
    triggered by backward."""
    observed = []

    class Probe(nn.Module):
        # Identity module that snapshots both flags each time it runs.
        def forward(self, x):
            observed.append((is_checkpointing(), is_recomputing()))
            return x

    probe_input = torch.rand(1, requires_grad=True)
    result = checkpoint(Probe(), probe_input)
    result.backward()

    # One entry from the forward pass, one from the recompute in backward.
    assert observed == [(True, False), (False, True)]