Example #1
0
def create_coupling(Fm, Gm=None, coupling='additive', implementation_fwd=-1, implementation_bwd=-1, adapter=None):
    """Build a coupling module of the requested type.

    Args:
        Fm: Sub-module passed as the first argument to the coupling constructor.
        Gm: Optional second sub-module passed as the second argument to the
            coupling constructor (default ``None``; how ``None`` is handled is
            up to the coupling class — not visible here).
        coupling: Which coupling to build: ``'additive'`` selects
            ``AdditiveCoupling`` and ``'affine'`` selects ``AffineCoupling``.
        implementation_fwd: Forwarded unchanged to the coupling constructor.
        implementation_bwd: Forwarded unchanged to the coupling constructor.
        adapter: Forwarded to ``AffineCoupling`` only; ignored when
            ``coupling == 'additive'``.

    Returns:
        The newly constructed coupling instance.

    Raises:
        NotImplementedError: If ``coupling`` is neither ``'additive'`` nor
            ``'affine'``.
    """
    if coupling == 'additive':
        return AdditiveCoupling(Fm, Gm,
                                implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
    if coupling == 'affine':
        return AffineCoupling(Fm, Gm, adapter=adapter,
                              implementation_fwd=implementation_fwd, implementation_bwd=implementation_bwd)
    # Wrap in a 1-tuple: with a bare ``% coupling`` a tuple-valued argument is
    # unpacked by the ``%`` operator and raises TypeError instead of the
    # intended NotImplementedError.
    raise NotImplementedError('Unknown coupling method: %s' % (coupling,))
Example #2
0
def test_is_invertible_module():
    """Exercise ``is_invertible_module`` on three kinds of module.

    * a plain 1x1 ``torch.nn.Conv2d`` must be reported as NOT invertible;
    * an ``AdditiveCoupling`` wrapping ``SubModule()`` must be reported as
      invertible;
    * ``FakeInverse``, whose ``inverse`` does not undo ``forward``, must be
      reported as NOT invertible.

    Every check uses the shape of a (1, 10, 10, 10) tensor as
    ``test_input_shape``.
    """

    class FakeInverse(torch.nn.Module):
        """Module exposing an ``inverse`` that is wrong on purpose.

        ``forward`` scales by 4 but ``inverse`` scales by 8, so
        ``inverse(forward(x))`` equals ``32 * x`` rather than ``x``.
        """

        def forward(self, x):
            return x * 4

        def inverse(self, y):
            return y * 8

    probe_shape = torch.zeros(1, 10, 10, 10).shape

    # Ordinary layer without an ``inverse`` method: must be rejected.
    conv = torch.nn.Conv2d(10, 10, kernel_size=(1, 1))
    assert not is_invertible_module(conv, test_input_shape=probe_shape)

    # Genuinely invertible coupling: must be accepted.
    coupling = AdditiveCoupling(SubModule(), implementation_bwd=-1, implementation_fwd=-1)
    assert is_invertible_module(coupling, test_input_shape=probe_shape)

    # Has an ``inverse`` method, but it is the wrong one: must be rejected.
    assert not is_invertible_module(FakeInverse(), test_input_shape=probe_shape)
Example #3
0
    X = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_()
    assert not is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
    with pytest.raises(RuntimeError):
        rm.forward(X)
    fn.multiply_forward = True
    rm.forward(X)
    assert not is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
    with pytest.raises(RuntimeError):
        rm.inverse(X)
    fn.multiply_inverse = True
    rm.inverse(X)
    assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)


@pytest.mark.parametrize('fn', [
    AdditiveCoupling(Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1),
    AffineCoupling(Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1, adapter=AffineAdapterNaive),
    AffineCoupling(Fm=SubModule(out_filters=10), implementation_fwd=-1, implementation_bwd=-1, adapter=AffineAdapterSigmoid),
    MultiplicationInverse()
])
@pytest.mark.parametrize('bwd', [False, True])
@pytest.mark.parametrize('keep_input', [False, True])
@pytest.mark.parametrize('keep_input_inverse', [False, True])
def test_invertible_module_wrapper_fwd_bwd(fn, bwd, keep_input, keep_input_inverse):
    """InvertibleModuleWrapper tests for the memory saving forward and backward passes

    * test inversion Y = RB(X) and X = RB.inverse(Y)
    * test training the block for a single step and compare weights for implementations: 0, 1
    * test automatic discard of input X and its retrieval after the backward pass
    * test usage of BN to identify non-contiguous memory blocks