import pytest
import torch

from memcnn import (AdditiveCoupling, InvertibleModuleWrapper,
                    is_invertible_module)


def test_is_invertible_module_shared_outputs():
    fnb = MultiSharedOutputs()
    X = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_()
    # shared output tensors trigger a UserWarning, but the check still passes
    with pytest.warns(UserWarning):
        assert is_invertible_module(fnb,
                                    test_input_shape=(X.shape, ),
                                    atol=1e-6)
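

# A plausible sketch (assumed, not memcnn's verbatim test helper) of the
# module used above: forward returns the same tensor twice, which is what
# triggers the shared-output UserWarning from is_invertible_module.
class MultiSharedOutputs(torch.nn.Module):
    def forward(self, x):
        y = x * 2
        return y, y  # both outputs share a single tensor

    def inverse(self, y1, y2):
        return y1 / 2
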
def test_is_invertible_module():
    X = torch.zeros(1, 10, 10, 10)
    # a plain Conv2d offers no inverse method, so the check must fail
    assert not is_invertible_module(
        torch.nn.Conv2d(10, 10, kernel_size=(1, 1)), test_input_shape=X.shape)
    fn = AdditiveCoupling(SubModule(),
                          implementation_bwd=-1,
                          implementation_fwd=-1)
    assert is_invertible_module(fn, test_input_shape=X.shape)

    class FakeInverse(torch.nn.Module):
        def forward(self, x):
            return x * 4

        def inverse(self, y):
            # deliberately wrong: multiplying by 8 does not undo x * 4
            return y * 8

    assert not is_invertible_module(FakeInverse(), test_input_shape=X.shape)
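

# Plausible sketches (assumed) of two more helpers from memcnn's test suite.
# SubModule: a small conv block usable as the coupling function above.
class SubModule(torch.nn.Module):
    def __init__(self, in_filters=5, out_filters=5):
        super(SubModule, self).__init__()
        self.conv = torch.nn.Conv2d(in_filters, out_filters, (3, 3), padding=1)

    def forward(self, x):
        return self.conv(x)


# IdentityInverse: a learnable factor plus flags that decide whether
# forward/inverse return a fresh tensor or the input tensor itself.
class IdentityInverse(torch.nn.Module):
    def __init__(self, multiply_forward=False, multiply_inverse=False):
        super(IdentityInverse, self).__init__()
        self.factor = torch.nn.Parameter(torch.ones(1))
        self.multiply_forward = multiply_forward
        self.multiply_inverse = multiply_inverse

    def forward(self, x):
        return x * self.factor if self.multiply_forward else x

    def inverse(self, y):
        return y * self.factor if self.multiply_inverse else y
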
def test_is_invertible_module_shared_tensors():
    fn = IdentityInverse()
    rm = InvertibleModuleWrapper(fn=fn,
                                 keep_input=True,
                                 keep_input_inverse=True)
    X = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_()
    # both directions return the input tensor itself, so the check warns
    # about outputs sharing memory with inputs
    with pytest.warns(UserWarning):
        assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
    rm.forward(X)
    fn.multiply_forward = True
    rm.forward(X)
    # the inverse still hands back its input tensor, so it warns again
    with pytest.warns(UserWarning):
        assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
    rm.inverse(X)
    fn.multiply_inverse = True
    rm.inverse(X)
    # with fresh tensors in both directions, no warning is raised
    assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)


# parametrize values below are illustrative (the original fixture is not
# shown); any malformed test_input_shape should raise a ValueError
@pytest.mark.parametrize('input_shape', [7, 7.5, 'text'])
def test_is_invertible_module_type_check_input_shapes(input_shape):
    with pytest.raises(ValueError):
        is_invertible_module(module_in=IdentityInverse(multiply_forward=True,
                                                       multiply_inverse=True),
                             test_input_shape=input_shape)


def test_is_invertible_module_with_invalid_inverse():
    fn = IdentityInverse(multiply_inverse=True)
    # zeroing the factor makes inverse() map everything to zero, so it can
    # no longer undo forward()
    with torch.no_grad():
        fn.factor.zero_()
    assert not is_invertible_module(fn, test_input_shape=(12, 12))


# seed values below are illustrative; the check should pass regardless of
# the seed used to draw the random test input
@pytest.mark.parametrize('random_seed', [1, 42])
def test_is_invertible_module_random_seeds(random_seed):
    fn = IdentityInverse(multiply_forward=True, multiply_inverse=True)
    assert is_invertible_module(fn,
                                test_input_shape=(1, ),
                                random_seed=random_seed)

Example #7
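# The snippet below follows the memcnn README and assumes an ExampleOperation
# module defined earlier in that README; a minimal sketch consistent with its
# usage here (same number of input and output channels):
import torch
import torch.nn as nn
import memcnn


class ExampleOperation(nn.Module):
    def __init__(self, channels):
        super(ExampleOperation, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(num_features=channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.seq(x)
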
# generate some random input data (batch_size, num_channels, y_elements, x_elements)
X = torch.rand(2, 10, 8, 8)

# application of the operation(s) the normal way
model_normal = ExampleOperation(channels=10)
model_normal.eval()

Y = model_normal(X)

# turn the ExampleOperation invertible using an additive coupling
invertible_module = memcnn.AdditiveCoupling(
    Fm=ExampleOperation(channels=10 // 2),
    Gm=ExampleOperation(channels=10 // 2))

# test that it is actually a valid invertible module (has a valid inverse method)
assert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape)

# wrap our invertible_module with the InvertibleModuleWrapper; during training
# it can discard inputs and recompute them from outputs to save memory
# (keep_input=True / keep_input_inverse=True retain the inputs here, which
# disables those savings but keeps X available after the calls)
invertible_module_wrapper = memcnn.InvertibleModuleWrapper(
    fn=invertible_module, keep_input=True, keep_input_inverse=True)

# modules start in training mode; switch to evaluation mode here, which is
# required to pass input tensors with requires_grad=False (inference only)
invertible_module_wrapper.eval()

# test that the wrapped module is also a valid invertible module
assert memcnn.is_invertible_module(invertible_module_wrapper,
                                   test_input_shape=X.shape)

# compute the forward pass using the wrapper
Y2 = invertible_module_wrapper.forward(X)
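
# a natural follow-up (not part of the original snippet): recover the input
# from the output via the wrapper's inverse and verify the reconstruction
# (atol chosen for illustration)
X2 = invertible_module_wrapper.inverse(Y2)
assert torch.allclose(X2, X, atol=1e-6)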