def test_is_invertible_module_shared_outputs():
    """Check ``is_invertible_module`` on ``MultiSharedOutputs``.

    The check must report the module as invertible while emitting a
    ``UserWarning``.
    """
    module = MultiSharedOutputs()
    sample = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_()
    # The shape is wrapped in a 1-tuple, i.e. handed over as a sequence of
    # input shapes rather than a single shape.
    shapes = (sample.shape, )
    with pytest.warns(UserWarning):
        assert is_invertible_module(module, test_input_shape=shapes, atol=1e-6)
def test_is_invertible_module():
    """``is_invertible_module`` accepts a real coupling and rejects non-inverses."""

    class FakeInverse(torch.nn.Module):
        # ``inverse`` does not undo ``forward``: (x * 4) * 8 != x.
        def forward(self, x):
            return x * 4

        def inverse(self, y):
            return y * 8

    shape = torch.zeros(1, 10, 10, 10).shape

    # torch.nn.Conv2d defines no ``inverse`` method, so it must be rejected.
    conv = torch.nn.Conv2d(10, 10, kernel_size=(1, 1))
    assert not is_invertible_module(conv, test_input_shape=shape)

    # An additive coupling around a sub-module is expected to pass.
    coupling = AdditiveCoupling(SubModule(), implementation_bwd=-1, implementation_fwd=-1)
    assert is_invertible_module(coupling, test_input_shape=shape)

    # A module whose ``inverse`` is wrong must be rejected.
    assert not is_invertible_module(FakeInverse(), test_input_shape=shape)
def test_is_invertible_module_shared_tensors():
    """Run ``is_invertible_module`` on ``IdentityInverse`` while toggling its flags.

    The check warns until both ``multiply_forward`` and ``multiply_inverse``
    are set. No warning is asserted for the final call. A wrapper around the
    same module is run between the steps.
    """
    identity = IdentityInverse()
    wrapper = InvertibleModuleWrapper(fn=identity, keep_input=True, keep_input_inverse=True)
    sample = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_()

    def check():
        # Re-read the shape on every call, exactly as the inline form did.
        return is_invertible_module(identity, test_input_shape=sample.shape, atol=1e-6)

    # Step 1: constructor-default flags; the check passes but warns.
    with pytest.warns(UserWarning):
        assert check()
    wrapper.forward(sample)

    # Step 2: forward multiplies; the check still warns.
    identity.multiply_forward = True
    wrapper.forward(sample)
    with pytest.warns(UserWarning):
        assert check()
    wrapper.inverse(sample)

    # Step 3: inverse multiplies as well; only the result is asserted.
    identity.multiply_inverse = True
    wrapper.inverse(sample)
    assert check()
def test_is_invertible_module_type_check_input_shapes(input_shape):
    """Each ``input_shape`` given to this test must be rejected with ValueError.

    ``input_shape`` is injected by pytest (presumably through a parametrize
    decorator not visible in this chunk).
    """
    with pytest.raises(ValueError):
        module = IdentityInverse(multiply_forward=True, multiply_inverse=True)
        is_invertible_module(module_in=module, test_input_shape=input_shape)
def test_is_invertible_module_with_invalid_inverse():
    """Zeroing ``factor`` on an inverse-multiplying ``IdentityInverse`` makes the check fail."""
    module = IdentityInverse(multiply_inverse=True)
    # Zero the tensor in place without recording the op in the autograd graph.
    with torch.no_grad():
        module.factor.zero_()
    invertible = is_invertible_module(module, test_input_shape=(12, 12))
    assert not invertible
def test_is_invertible_module_random_seeds(random_seed):
    """The invertibility check must succeed for every ``random_seed`` supplied.

    ``random_seed`` is injected by pytest (presumably via a parametrize
    decorator not visible in this chunk).
    """
    module = IdentityInverse(multiply_forward=True, multiply_inverse=True)
    outcome = is_invertible_module(module, test_input_shape=(1, ), random_seed=random_seed)
    assert outcome
# generate some random input data (batch_size, num_channels, y_elements, x_elements) X = torch.rand(2, 10, 8, 8) # application of the operation(s) the normal way model_normal = ExampleOperation(channels=10) model_normal.eval() Y = model_normal(X) # turn the ExampleOperation invertible using an additive coupling invertible_module = memcnn.AdditiveCoupling( Fm=ExampleOperation(channels=10 // 2), Gm=ExampleOperation(channels=10 // 2)) # test that it is actually a valid invertible module (has a valid inverse method) assert memcnn.is_invertible_module(invertible_module, test_input_shape=X.shape) # wrap our invertible_module using the InvertibleModuleWrapper and benefit from memory savings during training invertible_module_wrapper = memcnn.InvertibleModuleWrapper( fn=invertible_module, keep_input=True, keep_input_inverse=True) # by default the module is set to training, the following sets this to evaluation # note that this is required to pass input tensors to the model with requires_grad=False (inference only) invertible_module_wrapper.eval() # test that the wrapped module is also a valid invertible module assert memcnn.is_invertible_module(invertible_module_wrapper, test_input_shape=X.shape) # compute the forward pass using the wrapper Y2 = invertible_module_wrapper.forward(X)