def test_is_invertible_module():
    X = torch.zeros(1, 10, 10, 10)
    # a plain convolution has no inverse method, so it should be rejected
    assert not is_invertible_module(torch.nn.Conv2d(10, 10, kernel_size=(1, 1)),
                                    test_input_shape=X.shape)
    # an additive coupling is invertible by construction
    fn = AdditiveCoupling(SubModule(), implementation_bwd=-1, implementation_fwd=-1)
    assert is_invertible_module(fn, test_input_shape=X.shape)

    # a module whose inverse does not undo its forward must fail the check
    class FakeInverse(torch.nn.Module):
        def forward(self, x):
            return x * 4

        def inverse(self, y):
            return y * 8

    assert not is_invertible_module(FakeInverse(), test_input_shape=X.shape)
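

# Hedged sketch, not part of the original suite: `is_invertible_module` accepts any
# torch.nn.Module that exposes a matching `inverse`, i.e. inverse(forward(x)) must
# reproduce x within `atol`. `ScaleByTwo` and the example function below are
# hypothetical illustrations, not memcnn classes.
class ScaleByTwo(torch.nn.Module):
    def forward(self, x):
        return x * 2

    def inverse(self, y):
        return y / 2


def example_scale_by_two_is_invertible():
    # inverse(forward(x)) == (x * 2) / 2 == x, so the check is expected to pass
    assert is_invertible_module(ScaleByTwo(), test_input_shape=(1, 10, 10, 10))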


def test_input_output_invertible_function_share_tensor():
    fn = IdentityInverse()
    rm = InvertibleModuleWrapper(fn=fn, keep_input=True, keep_input_inverse=True)
    X = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_()
    assert not is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
    # forward returns the input tensor itself, so output and input share storage
    with pytest.raises(RuntimeError):
        rm.forward(X)
    fn.multiply_forward = True
    rm.forward(X)
    assert not is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
    # the inverse still aliases its input until multiply_inverse is enabled
    with pytest.raises(RuntimeError):
        rm.inverse(X)
    fn.multiply_inverse = True
    rm.inverse(X)
    assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
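

# Hedged sketch: the RuntimeErrors above stem from forward/inverse handing back the
# very tensor they received, so the output aliases the input and the wrapper cannot
# safely free the input's storage. `PassThrough` and the example below are
# hypothetical illustrations of that failure mode, assuming this interpretation of
# the check.
class PassThrough(torch.nn.Module):
    def forward(self, x):
        return x        # aliases the input -> expected to be rejected

    def inverse(self, y):
        return y * 1    # allocates a fresh tensor -> expected to be accepted


def example_output_must_not_alias_input():
    wrapped = InvertibleModuleWrapper(fn=PassThrough(), keep_input=True, keep_input_inverse=True)
    x = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_()
    with pytest.raises(RuntimeError):
        wrapped.forward(x)
    wrapped.inverse(x)  # no error expected: the inverse returns new storage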


def test_multi(disable):
    split = InvertibleModuleWrapper(SplitChannels(2), disable=disable)
    concat = InvertibleModuleWrapper(ConcatenateChannels(2), disable=disable)
    assert is_invertible_module(split, test_input_shape=(1, 3, 32, 32))
    assert is_invertible_module(concat, test_input_shape=((1, 2, 32, 32), (1, 1, 32, 32)))
    # split a 3-channel input into a 2-channel and a 1-channel branch,
    # process each branch, then concatenate and backpropagate
    conv_a = torch.nn.Conv2d(2, 2, 3)
    conv_b = torch.nn.Conv2d(1, 1, 3)
    x = torch.rand(1, 3, 32, 32)
    x.requires_grad = True
    a, b = split(x)
    a, b = conv_a(a), conv_b(b)
    y = concat(a, b)
    loss = torch.sum(y)
    loss.backward()
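

# Hedged sketch, illustrative only: splitting at channel index 2 and concatenating at
# the same index should reconstruct the original tensor, which is what makes the
# branched graph above invertible end to end. A detached clone is kept for the
# comparison because the wrapper may free the input's storage when memory saving is
# active.
def example_split_concat_roundtrip():
    split = InvertibleModuleWrapper(SplitChannels(2))
    concat = InvertibleModuleWrapper(ConcatenateChannels(2))
    x = torch.rand(1, 3, 32, 32, requires_grad=True)
    x_ref = x.detach().clone()
    a, b = split(x)
    assert a.shape[1] == 2 and b.shape[1] == 1
    y = concat(a, b)
    assert torch.allclose(y.detach(), x_ref)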


def test_invertible_module_wrapper_fwd_bwd(fn, bwd, keep_input, keep_input_inverse):
    """InvertibleModuleWrapper tests for the memory saving forward and backward passes

    * test inversion Y = RB(X) and X = RB.inverse(Y)
    * test training the block for a single step and compare weights for implementations: 0, 1
    * test automatic discard of input X and its retrieval after the backward pass
    * test usage of BN to identify non-contiguous memory blocks

    """
    for seed in range(10):
        set_seeds(seed)

        dims = (2, 10, 8, 8)
        data = torch.rand(*dims, dtype=torch.float32)
        target_data = torch.rand(*dims, dtype=torch.float32)

        assert is_invertible_module(fn, test_input_shape=data.shape, atol=1e-4)

        # test with zero padded convolution
        with torch.set_grad_enabled(True):
            X = data.clone().requires_grad_()
            Ytarget = target_data.clone()
            Xshape = X.shape
            rb = InvertibleModuleWrapper(fn=fn, keep_input=keep_input,
                                         keep_input_inverse=keep_input_inverse)
            s_grad = [p.detach().clone() for p in rb.parameters()]
            rb.train()
            rb.zero_grad()

            optim = torch.optim.RMSprop(rb.parameters())
            optim.zero_grad()

            if not bwd:
                Xin = X.clone().requires_grad_()
                Y = rb(Xin)
                Yrev = Y.detach().clone().requires_grad_()
                Xinv = rb.inverse(Yrev)
            else:
                Xin = X.clone().requires_grad_()
                Y = rb.inverse(Xin)
                Yrev = Y.detach().clone().requires_grad_()
                Xinv = rb(Yrev)
            loss = torch.nn.MSELoss()(Y, Ytarget)

            # has input been retained/discarded after forward (and backward) passes?
            if not bwd:
                assert is_memory_cleared(Yrev, not keep_input_inverse, Xshape)
                assert is_memory_cleared(Xin, not keep_input, Xshape)
            else:
                assert is_memory_cleared(Xin, not keep_input_inverse, Xshape)
                assert is_memory_cleared(Yrev, not keep_input, Xshape)

            optim.zero_grad()

            loss.backward()
            optim.step()

            assert Y.shape == Xshape
            assert X.detach().shape == data.shape
            assert torch.allclose(X.detach(), data, atol=1e-06)
            assert torch.allclose(X.detach(), Xinv.detach(), atol=1e-05)

            # Model is now trained and will differ
            grads = [p.detach().clone() for p in rb.parameters()]
            assert not torch.allclose(grads[0], s_grad[0])
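

# Hedged sketch: the pattern exercised above written as a plain training step. With
# keep_input=False the wrapper may free the input's storage after the forward pass
# and restore it from the output via the module's `inverse` during backward, trading
# extra compute for memory. The coupling uses the same SubModule helper as the tests
# above; the example function itself is hypothetical and not part of memcnn.
def example_memory_saving_training_step():
    coupling = AdditiveCoupling(SubModule(), implementation_bwd=-1, implementation_fwd=-1)
    block = InvertibleModuleWrapper(fn=coupling, keep_input=False, keep_input_inverse=False)
    optim = torch.optim.RMSprop(block.parameters())

    x = torch.rand(2, 10, 8, 8, requires_grad=True)
    target = torch.rand(2, 10, 8, 8)

    y = block(x)                       # x's storage may be freed here
    loss = torch.nn.MSELoss()(y, target)
    optim.zero_grad()
    loss.backward()                    # x is reconstructed from y via the inverse
    optim.step()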