Exemplo n.º 1
0
def test_serial_checkpoints(device):
    """Check the forward/backward call order of two chained checkpoints.

    Two tensors ``a`` and ``b`` are each passed through ``checkpoint`` around
    a logging autograd function, with ``b`` tied to ``a`` through a
    ``fork``/``join`` pair. After ``backward()``, the recorded timeline must
    show ``b`` being handled before ``a``.

    Args:
        device: Device on which the input tensors are created.
    """
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    # Call log shared with ``Log``; entries look like "<name>:forward".
    timeline = []

    class Log(torch.autograd.Function):
        """Pass-through autograd function that records each call in ``timeline``."""

        @staticmethod
        def forward(ctx, name, x):
            # Keep the label on ctx so backward can report it as well.
            ctx.name = name
            timeline.append(f"{name}:forward")
            return x.detach()

        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f"{name}:backward")
            # One return value per forward argument: no gradient for the
            # ``name`` label, the incoming gradient unchanged for ``x``.
            return None, grad_output

    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)

    # Increase the next function sequence number.
    # NOTE(review): the result is discarded on purpose; presumably this makes
    # the test independent of autograd ordering by sequence number -- confirm
    # against the upstream PR referenced above.
    _ = a + 1 + 2 + 3 + 4 + 5

    a = checkpoint(partial(Log.apply, "a"), a)

    # NOTE(review): fork/join are defined elsewhere; presumably they add a
    # dependency from ``b`` to ``a`` through the phony tensor -- confirm.
    a, phony = fork(a)
    b = join(b, phony)

    b = checkpoint(partial(Log.apply, "b"), b)

    c = torch.cat((a, b))

    out = c.sum()

    #                        +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                        +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()

    # Expected order: both initial forwards, then a second "b:forward"
    # followed by "b:backward", then a second "a:forward" followed by
    # "a:backward". The repeated forwards are presumably the checkpoint
    # recomputation performed during backward.
    assert timeline == [
        "a:forward", "b:forward", "b:forward", "b:backward", "a:forward",
        "a:backward"
    ]
Exemplo n.º 2
0
def test_non_grad_output():
    """Smoke test: backward works when a checkpointed module also returns
    a tensor created without ``requires_grad``.

    There is no explicit assertion; the test passes as long as neither the
    checkpointed call nor ``backward()`` on the first output raises.
    """

    class ForkNonGrad(nn.Module):
        """Return the doubled argument together with a fresh random tensor."""

        def forward(self, tensor):
            doubled = tensor * 2
            unrelated = torch.rand(1)
            return (doubled, unrelated)

    module = ForkNonGrad()
    source = torch.rand(1, requires_grad=True)

    outputs = checkpoint(module, source)

    # Only the first element derives from ``source``.
    first = outputs[0]
    first.backward()
Exemplo n.º 3
0
def test_not_requires_grad_with_parameter():
    """A tensor captured by the checkpointed function must still get a gradient.

    The explicit argument does not require grad, but the function multiplies
    it by a closed-over tensor that does; after ``backward()`` that tensor's
    ``grad`` must be populated.
    """
    plain_input = torch.rand(1, requires_grad=False)
    weight = torch.rand(1, requires_grad=True)

    def scale(tensor):
        # ``weight`` is reached through the closure, not passed as an argument.
        return tensor * weight

    result = checkpoint(scale, plain_input)
    result.backward()

    assert weight.grad is not None
Exemplo n.º 4
0
def test_random_in_checkpoint(device):
    """Dropout under ``checkpoint`` must give the same gradient as plain dropout.

    Both runs start from the same manual seed, so the input tensor and the
    dropout mask are expected to match; the resulting input gradients are
    compared with ``torch.allclose``.

    Args:
        device: Device on which the input tensors are created.
    """
    dropout = nn.Dropout(p=0.5)

    def input_grad(apply):
        # Reseed so every run draws the same input (and the same random
        # numbers afterwards).
        torch.manual_seed(0)
        sample = torch.randn(3, 3, device=device, requires_grad=True)
        result = apply(sample)
        result.norm().backward()
        return sample.grad

    reference_grad = input_grad(dropout)
    checkpointed_grad = input_grad(lambda tensor: checkpoint(dropout, tensor))

    assert torch.allclose(reference_grad, checkpointed_grad)
Exemplo n.º 5
0
def test_detect_checkpointing_recomputing():
    """Verify the flags reported by ``is_checkpointing``/``is_recomputing``.

    The module's forward is expected to run exactly twice over one
    checkpointed call followed by ``backward()``: first with only
    ``is_checkpointing()`` true, then with only ``is_recomputing()`` true.
    """
    observed = []

    class Detect(nn.Module):
        """Identity module that records both flags on every forward call."""

        def forward(self, tensor):
            flags = (is_checkpointing(), is_recomputing())
            observed.append(flags)
            return tensor

    detector = Detect()
    source = torch.rand(1, requires_grad=True)

    result = checkpoint(detector, source)
    result.backward()

    expected = [(True, False), (False, True)]
    assert observed == expected