Example 1
def test_not_requires_grad_with_parameter():
    x = torch.rand(1, requires_grad=False)
    a = torch.rand(1, requires_grad=True)

    def f(x):
        return x * a

    y = checkpoint(f, x)
    y.backward()

    assert a.grad is not None
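
Note on context: these snippets read like tests for a torchgpipe-style checkpoint API, with the surrounding imports stripped out. A plausible preamble is sketched below; the torchgpipe module paths and the pytest device fixture are assumptions, and Example 4 uses a different checkpoint variant (from the linked PyTorch PR) that also returns a recompute handle and a first() helper, which this preamble does not cover.

from functools import partial

import torch
from torch import nn

# Assumed import paths (torchgpipe-style); adjust to the package actually under test.
from torchgpipe.checkpoint import checkpoint, is_checkpointing, is_recomputing
from torchgpipe.dependency import fork, join

# device in the examples below is presumably a pytest fixture yielding 'cpu' or 'cuda'.
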
Example 2
def test_serial_checkpoints(device):
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    timeline = []

    class Log(torch.autograd.Function):
        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            timeline.append(f'{name}:forward')
            return x.detach()

        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f'{name}:backward')
            return None, grad_output

    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)

    # Increase the next function sequence number.
    _ = a + 1 + 2 + 3 + 4 + 5

    a = checkpoint(partial(Log.apply, 'a'), a)

    a, phony = fork(a)
    b = join(b, phony)
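    # join() threads the phony tensor from a's checkpoint into b's branch, so
    # autograd backpropagates through b's checkpoint before recomputing a,
    # which is the ordering asserted in the timeline below.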

    b = checkpoint(partial(Log.apply, 'b'), b)

    c = torch.cat((a, b))

    out = c.sum()

    #                        +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                        +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()

    assert timeline == \
        ['a:forward', 'b:forward', 'b:forward', 'b:backward', 'a:forward', 'a:backward']
Example 3
def test_random_in_checkpoint(device):
    dropout = nn.Dropout(p=0.5)

    torch.manual_seed(0)
    x = torch.randn(3, 3, device=device, requires_grad=True)
    y = dropout(x)
    y.norm().backward()

    torch.manual_seed(0)
    chk_x = torch.randn(3, 3, device=device, requires_grad=True)
    chk_y = checkpoint(dropout, chk_x)
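    # The recompute inside checkpoint() replays the saved RNG state, so dropout
    # draws the same mask as the first forward and the gradients match below.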
    chk_y.norm().backward()

    assert torch.allclose(x.grad, chk_x.grad)
Example 4
def test_linear_checkpoints(device):
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    timeline = []

    class Log(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, name):
            ctx.name = name
            timeline.append('%s:forward' % name)
            return x.detach()

        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append('%s:backward' % name)
            return grad_output, None

    a = torch.rand(1, requires_grad=True, device=device)
    b = torch.rand(1, requires_grad=True, device=device)

    # Increase the next function sequence number.
    _ = a + 1 + 2 + 3 + 4 + 5

    a, recompute_a = checkpoint(lambda a: Log.apply(a, 'a'), a)
    b = first(b, recompute_a)  # This line fixes this case.
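    # first() ties b's branch to a's recomputation handle, so b's checkpoint is
    # backpropagated before a is recomputed (same ordering as fork/join in Example 2).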
    b, _ = checkpoint(lambda b: Log.apply(b, 'b'), b)
    c = torch.cat((a, b))

    out = c.sum()

    #                        +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                        +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()

    assert timeline == \
        ['a:forward', 'b:forward', 'b:forward', 'b:backward', 'a:forward', 'a:backward']
Example 5
def test_detect_checkpointing_recomputing():
    logs = []

    class Detect(nn.Module):
        def forward(self, input):
            logs.append((is_checkpointing(), is_recomputing()))
            return input

    model = Detect()
    input = torch.rand(1, requires_grad=True)

    output = checkpoint(model, input)
    output.backward()

    assert logs == [(True, False), (False, True)]
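
For context, is_checkpointing() and is_recomputing() let a module distinguish the first (checkpointed) forward pass from the recompute that runs during backward, which is exactly what the assert above records. Below is a minimal sketch of using that distinction to avoid double-counting side effects; the Counter module and the torchgpipe import path are illustrative assumptions, not part of the original tests.

import torch
from torch import nn
from torchgpipe.checkpoint import checkpoint, is_recomputing  # assumed module path

class Counter(nn.Module):
    """Counts real forward passes, ignoring the checkpoint recompute."""

    def __init__(self):
        super().__init__()
        self.calls = 0

    def forward(self, input):
        # The recompute replays forward() during backward; skip it so each
        # checkpointed pass is counted exactly once.
        if not is_recomputing():
            self.calls += 1
        return input

counter = Counter()
x = torch.rand(1, requires_grad=True)
checkpoint(counter, x).backward()
assert counter.calls == 1  # forward ran twice, but only the first pass was counted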