Example #1
import torch
from memcnn import AdditiveCoupling, is_invertible_module  # assumed: these names are memcnn's public API

def test_is_invertible_module():
    X = torch.zeros(1, 10, 10, 10)
    assert not is_invertible_module(torch.nn.Conv2d(10, 10, kernel_size=(1, 1)),
                                    test_input_shape=X.shape)
    fn = AdditiveCoupling(SubModule(), implementation_bwd=-1, implementation_fwd=-1)
    assert is_invertible_module(fn, test_input_shape=X.shape)
    class FakeInverse(torch.nn.Module):
        def forward(self, x):
            return x * 4

        def inverse(self, y):
            return y * 8
    assert not is_invertible_module(FakeInverse(), test_input_shape=X.shape)
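
SubModule is a helper defined elsewhere in the test module. A minimal sketch of what it might look like follows, assuming AdditiveCoupling splits the 10 input channels into two halves of 5 and applies the submodule to one half; the class name matches the snippet, but the layer choice and kernel size are illustrative.

class SubModule(torch.nn.Module):
    """Toy coupling function: a single shape-preserving convolution on 5 channels."""
    def __init__(self, in_filters=5, out_filters=5):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_filters, out_filters, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)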
Example #2
import pytest
import torch
from memcnn import InvertibleModuleWrapper, is_invertible_module  # assumed: these names are memcnn's public API

def test_input_output_invertible_function_share_tensor():
    fn = IdentityInverse()
    rm = InvertibleModuleWrapper(fn=fn, keep_input=True, keep_input_inverse=True)
    X = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_()
    assert not is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
    with pytest.raises(RuntimeError):
        rm.forward(X)
    fn.multiply_forward = True
    rm.forward(X)
    assert not is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
    with pytest.raises(RuntimeError):
        rm.inverse(X)
    fn.multiply_inverse = True
    rm.inverse(X)
    assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
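
IdentityInverse is another test helper that is not shown. A plausible sketch follows: with the multiply_* flags off, forward and inverse return their input tensor unchanged, so input and output share storage, which both the invertibility check and the wrapper reject; multiplying by one allocates a fresh tensor and makes them pass. The flag names come from the snippet; the rest is an assumption.

class IdentityInverse(torch.nn.Module):
    """Identity map whose forward/inverse optionally return a new tensor."""
    def __init__(self, multiply_forward=False, multiply_inverse=False):
        super().__init__()
        self.multiply_forward = multiply_forward
        self.multiply_inverse = multiply_inverse

    def forward(self, x):
        # x * 1 creates a new tensor; returning x directly shares memory with the input.
        return x * 1 if self.multiply_forward else x

    def inverse(self, y):
        return y * 1 if self.multiply_inverse else y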
Example #3
import torch
from memcnn import InvertibleModuleWrapper, is_invertible_module  # assumed: these names are memcnn's public API

# `disable` is presumably supplied via pytest parametrization, e.g. over (True, False).
def test_multi(disable):
    split = InvertibleModuleWrapper(SplitChannels(2), disable=disable)
    concat = InvertibleModuleWrapper(ConcatenateChannels(2), disable=disable)

    assert is_invertible_module(split, test_input_shape=(1, 3, 32, 32))
    assert is_invertible_module(concat, test_input_shape=((1, 2, 32, 32), (1, 1, 32, 32)))

    conv_a = torch.nn.Conv2d(2, 2, 3)
    conv_b = torch.nn.Conv2d(1, 1, 3)

    x = torch.rand(1, 3, 32, 32)
    x.requires_grad = True

    a, b = split(x)
    a, b = conv_a(a), conv_b(b)
    y = concat(a, b)
    loss = torch.sum(y)
    loss.backward()
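
SplitChannels and ConcatenateChannels are also helpers from the test module. The sketch below assumes the constructor argument is the channel index at which to split, matching the Conv2d(2, 2, 3) and Conv2d(1, 1, 3) branches above; the clone() calls keep the outputs from sharing storage with the input.

class SplitChannels(torch.nn.Module):
    """Split a tensor into two parts along the channel dimension."""
    def __init__(self, split_location):
        super().__init__()
        self.split_location = split_location

    def forward(self, x):
        return (x[:, :self.split_location].clone(),
                x[:, self.split_location:].clone())

    def inverse(self, a, b):
        return torch.cat([a, b], dim=1)


class ConcatenateChannels(torch.nn.Module):
    """Concatenate two tensors along the channel dimension."""
    def __init__(self, split_location):
        super().__init__()
        self.split_location = split_location

    def forward(self, a, b):
        return torch.cat([a, b], dim=1)

    def inverse(self, x):
        return (x[:, :self.split_location].clone(),
                x[:, self.split_location:].clone())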
Example #4
import torch
from memcnn import InvertibleModuleWrapper, is_invertible_module  # assumed: these names are memcnn's public API

# `fn`, `bwd`, `keep_input`, and `keep_input_inverse` are presumably supplied via pytest
# parametrization over several invertible modules and flag combinations.
def test_invertible_module_wrapper_fwd_bwd(fn, bwd, keep_input, keep_input_inverse):
    """InvertibleModuleWrapper tests for the memory saving forward and backward passes

    * test inversion Y = RB(X) and X = RB.inverse(Y)
    * test training the block for a single step and compare weights for implementations: 0, 1
    * test automatic discard of input X and its retrieval after the backward pass
    * test usage of BN to identify non-contiguous memory blocks

    """
    for seed in range(10):
        set_seeds(seed)
        dims = (2, 10, 8, 8)
        data = torch.rand(*dims, dtype=torch.float32)
        target_data = torch.rand(*dims, dtype=torch.float32)

        assert is_invertible_module(fn, test_input_shape=data.shape, atol=1e-4)

        # test with zero padded convolution
        with torch.set_grad_enabled(True):
            X = data.clone().requires_grad_()

            Ytarget = target_data.clone()

            Xshape = X.shape

            rb = InvertibleModuleWrapper(fn=fn, keep_input=keep_input, keep_input_inverse=keep_input_inverse)
            # snapshot the parameters before training so they can be compared after the update
            s_grad = [p.detach().clone() for p in rb.parameters()]

            rb.train()
            rb.zero_grad()

            optim = torch.optim.RMSprop(rb.parameters())
            optim.zero_grad()
            if not bwd:
                Xin = X.clone().requires_grad_()
                Y = rb(Xin)
                Yrev = Y.detach().clone().requires_grad_()
                Xinv = rb.inverse(Yrev)
            else:
                Xin = X.clone().requires_grad_()
                Y = rb.inverse(Xin)
                Yrev = Y.detach().clone().requires_grad_()
                Xinv = rb(Yrev)
            loss = torch.nn.MSELoss()(Y, Ytarget)

            # has input been retained/discarded after forward (and backward) passes?

            if not bwd:
                assert is_memory_cleared(Yrev, not keep_input_inverse, Xshape)
                assert is_memory_cleared(Xin, not keep_input, Xshape)
            else:
                assert is_memory_cleared(Xin, not keep_input_inverse, Xshape)
                assert is_memory_cleared(Yrev, not keep_input, Xshape)

            optim.zero_grad()

            loss.backward()
            optim.step()

            assert Y.shape == Xshape
            assert X.detach().shape == data.shape
            assert torch.allclose(X.detach(), data, atol=1e-06)
            assert torch.allclose(X.detach(), Xinv.detach(), atol=1e-05)  # Model is now trained and will differ
            grads = [p.detach().clone() for p in rb.parameters()]

            # the optimizer step must have changed the parameters from their initial snapshot
            assert not torch.allclose(grads[0], s_grad[0])
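
set_seeds and is_memory_cleared are test-suite helpers as well. Plausible sketches are shown below, assuming a discarded input is freed by shrinking its underlying storage to zero elements, which lets the check distinguish kept from cleared tensors; the exact implementation may differ.

import random


def set_seeds(seed):
    """Seed the RNGs involved so each run of the loop is reproducible."""
    random.seed(seed)
    torch.manual_seed(seed)


def is_memory_cleared(var, expect_cleared, shape):
    """Return True if the tensor's storage state matches the expectation."""
    if expect_cleared:
        return var.storage().size() == 0
    return var.storage().size() > 0 and var.shape == shape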