def create_coupling(Fm, Gm=None, coupling='additive', implementation_fwd=-1, implementation_bwd=-1, adapter=None):
    """Build a coupling module from its string name.

    Args:
        Fm: first sub-module passed to the coupling.
        Gm: optional second sub-module passed to the coupling.
        coupling: ``'additive'`` selects ``AdditiveCoupling``; ``'affine'``
            selects ``AffineCoupling``.
        implementation_fwd: forwarded unchanged to the coupling constructor.
        implementation_bwd: forwarded unchanged to the coupling constructor.
        adapter: forwarded only to ``AffineCoupling``; ignored for ``'additive'``.

    Returns:
        The constructed coupling module.

    Raises:
        NotImplementedError: if ``coupling`` names neither supported method.
    """
    # Both coupling types share the same implementation-selection kwargs.
    impl_kwargs = dict(implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
    if coupling == 'additive':
        return AdditiveCoupling(Fm, Gm, **impl_kwargs)
    if coupling == 'affine':
        return AffineCoupling(Fm, Gm, adapter=adapter, **impl_kwargs)
    raise NotImplementedError('Unknown coupling method: %s' % coupling)
def test_is_invertible_module():
    """Check that ``is_invertible_module`` separates real inverses from fakes.

    Expected outcomes, all for an input shape of (1, 10, 10, 10):
      * a plain 1x1 ``Conv2d`` is reported as not invertible;
      * an ``AdditiveCoupling`` around ``SubModule()`` is reported as invertible;
      * a module whose ``inverse`` does not undo ``forward`` is reported as not
        invertible.
    """
    input_shape = torch.zeros(1, 10, 10, 10).shape

    conv = torch.nn.Conv2d(10, 10, kernel_size=(1, 1))
    assert not is_invertible_module(conv, test_input_shape=input_shape)

    coupling = AdditiveCoupling(SubModule(), implementation_bwd=-1, implementation_fwd=-1)
    assert is_invertible_module(coupling, test_input_shape=input_shape)

    class FakeInverse(torch.nn.Module):
        # forward scales by 4 but inverse scales by 8, so inverse(forward(x)) != x.
        def forward(self, x):
            return x * 4

        def inverse(self, y):
            return y * 8

    assert not is_invertible_module(FakeInverse(), test_input_shape=input_shape)
X = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_() assert not is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6) with pytest.raises(RuntimeError): rm.forward(X) fn.multiply_forward = True rm.forward(X) assert not is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6) with pytest.raises(RuntimeError): rm.inverse(X) fn.multiply_inverse = True rm.inverse(X) assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6) @pytest.mark.parametrize('fn', [ AdditiveCoupling(Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1), AffineCoupling(Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1, adapter=AffineAdapterNaive), AffineCoupling(Fm=SubModule(out_filters=10), implementation_fwd=-1, implementation_bwd=-1, adapter=AffineAdapterSigmoid), MultiplicationInverse() ]) @pytest.mark.parametrize('bwd', [False, True]) @pytest.mark.parametrize('keep_input', [False, True]) @pytest.mark.parametrize('keep_input_inverse', [False, True]) def test_invertible_module_wrapper_fwd_bwd(fn, bwd, keep_input, keep_input_inverse): """InvertibleModuleWrapper tests for the memory saving forward and backward passes * test inversion Y = RB(X) and X = RB.inverse(Y) * test training the block for a single step and compare weights for implementations: 0, 1 * test automatic discard of input X and its retrieval after the backward pass * test usage of BN to identify non-contiguous memory blocks