Example 1
def test_chained_invertible_module_wrapper(coupling, adapter):
    set_seeds(42)
    dims = (2, 10, 8, 8)
    data = torch.rand(*dims, dtype=torch.float32)
    target_data = torch.rand(*dims, dtype=torch.float32)
    with torch.set_grad_enabled(True):
        X = data.clone().requires_grad_()
        Ytarget = target_data.clone()

        Gm = SubModule(
            in_filters=5,
            out_filters=5 if coupling == 'additive' or adapter is AffineAdapterNaive else 10)
        rb = SubModuleStack(Gm,
                            coupling=coupling,
                            depth=2,
                            keep_input=False,
                            adapter=adapter,
                            implementation_bwd=-1,
                            implementation_fwd=-1)
        rb.train()
        optim = torch.optim.RMSprop(rb.parameters())

        rb.zero_grad()

        optim.zero_grad()

        Xin = X.clone()
        Y = rb(Xin)

        loss = torch.nn.MSELoss()(Y, Ytarget)

        loss.backward()
        optim.step()

    assert not torch.isnan(loss)
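These snippets come from memcnn's test suite and rely on helpers defined elsewhere in that file (set_seeds, SubModule, SubModuleStack, is_memory_cleared). Below is a minimal sketch of what the first three might look like, inferred purely from how the tests call them; the names are real, but the bodies, defaults, and import paths are assumptions rather than memcnn's actual test code.

import random

import numpy as np
import torch

from memcnn import InvertibleModuleWrapper, create_coupling  # assumed import path


def set_seeds(seed):
    # Make weight initialisation and torch.rand calls reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


class SubModule(torch.nn.Module):
    # Small conv block used as Fm/Gm inside the couplings; the tests feed
    # (N, 10, H, W) tensors, which the couplings split into two 5-channel halves.
    def __init__(self, in_filters=5, out_filters=5):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_filters, out_filters, kernel_size=3, padding=1)
        self.bn = torch.nn.BatchNorm2d(out_filters)

    def forward(self, x):
        return self.bn(self.conv(x))


class SubModuleStack(torch.nn.Module):
    # Chain of `depth` invertible blocks, each wrapping a coupling built from Gm.
    def __init__(self, Gm, coupling='additive', depth=10, keep_input=False,
                 adapter=None, implementation_fwd=-1, implementation_bwd=-1,
                 num_bwd_passes=1):
        super().__init__()
        self.stack = torch.nn.ModuleList([
            InvertibleModuleWrapper(
                fn=create_coupling(Fm=Gm, Gm=Gm, coupling=coupling,
                                   implementation_fwd=implementation_fwd,
                                   implementation_bwd=implementation_bwd,
                                   adapter=adapter),
                keep_input=keep_input,
                keep_input_inverse=keep_input,
                num_bwd_passes=num_bwd_passes)
            for _ in range(depth)])

    def forward(self, x):
        for block in self.stack:
            x = block(x)
        return x

    def inverse(self, y):
        for block in reversed(self.stack):
            y = block.inverse(y)
        return y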
Example 2
def test_invertible_module_wrapper_disabled_versus_enabled(inverted):
    set_seeds(42)
    Gm = SubModule(in_filters=5, out_filters=5)

    coupling_fn = create_coupling(Fm=Gm,
                                  Gm=Gm,
                                  coupling='additive',
                                  implementation_fwd=-1,
                                  implementation_bwd=-1)
    rb = InvertibleModuleWrapper(fn=coupling_fn,
                                 keep_input=False,
                                 keep_input_inverse=False)
    rb2 = InvertibleModuleWrapper(fn=copy.deepcopy(coupling_fn),
                                  keep_input=False,
                                  keep_input_inverse=False)
    rb.eval()
    rb2.eval()
    rb2.disable = True  # the disabled wrapper runs fn directly and keeps its input
    with torch.no_grad():
        dims = (2, 10, 8, 8)
        data = torch.rand(*dims, dtype=torch.float32)
        X = data.clone().detach().requires_grad_()
        X2 = data.clone().detach().requires_grad_()
        if not inverted:
            Y = rb(X)
            Y2 = rb2(X2)
        else:
            Y = rb.inverse(X)
            Y2 = rb2.inverse(X2)

        assert torch.allclose(Y, Y2)

        assert is_memory_cleared(X, True, dims)
        assert is_memory_cleared(X2, False, dims)
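is_memory_cleared is another helper not shown here; judging from the assertions above, it checks whether the wrapper has freed the underlying storage of its input, which happens when keep_input is False and the wrapper is not disabled. A hypothetical sketch, not the real helper:

import numpy as np
import torch


def is_memory_cleared(x, should_be_cleared, dims):
    # With keep_input=False the wrapper frees the input's storage after the
    # forward (or inverse) pass, leaving a tensor whose storage has size 0.
    if should_be_cleared:
        return x.storage().size() == 0
    # Otherwise the full buffer should still be present.
    return x.storage().size() == int(np.prod(dims)) and x.shape == torch.Size(dims)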
Example 3
def test_memory_saving_invertible_model_wrapper(device, coupling, keep_input):
    """Test memory saving of the invertible model wrapper

    * tests fitting a large number of images by creating a deep network requiring large
      intermediate feature maps for training

    * keep_input = False should use less memory than keep_input = True on both GPU and CPU RAM

    * input size: np.prod((2, 10, 10, 10)) * 4 bytes / 1024.0 =  7.8125 kB
      for depth=5 this yields                        7.8125 * 5 = 39.0625 kB

    """

    if device == 'cpu':
        pytest.skip('Unreliable metrics, should be fixed.')

    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('This test requires a GPU to be available')

    gc.disable()
    gc.collect()

    with torch.set_grad_enabled(True):
        dims = [2, 10, 10, 10]
        depth = 5

        xx = torch.rand(*dims, device=device, dtype=torch.float32).requires_grad_()
        ytarget = torch.rand(*dims, device=device, dtype=torch.float32)

        # same convolution test
        network = SubModuleStack(SubModule(in_filters=5, out_filters=5),
                                 depth=depth,
                                 keep_input=keep_input,
                                 coupling=coupling,
                                 implementation_fwd=-1,
                                 implementation_bwd=-1)
        network.to(device)
        network.train()
        network.zero_grad()
        optim = torch.optim.RMSprop(network.parameters())
        optim.zero_grad()
        mem_start = 0 if device != 'cuda' else \
            torch.cuda.memory_allocated() / float(1024 ** 2)

        y = network(xx)
        gc.collect()
        mem_after_forward = torch.cuda.memory_allocated() / float(1024 ** 2)
        loss = torch.nn.MSELoss()(y, ytarget)
        optim.zero_grad()
        loss.backward()
        optim.step()
        gc.collect()
        # mem_after_backward = torch.cuda.memory_allocated() / float(1024 ** 2)
        gc.enable()

        memuse = float(np.prod(dims + [depth, 4, ])) / float(1024 ** 2)

        measured_memuse = mem_after_forward - mem_start
        if keep_input:
            assert measured_memuse >= memuse
        else:
            assert measured_memuse < 1
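Outside the test harness, the memory saving measured above comes from keep_input=False: the wrapper frees the input activation after the forward pass and reconstructs it through the coupling's inverse during backpropagation. A minimal usage sketch in that spirit; the toy Fm module is an assumption:

import torch
from memcnn import AdditiveCoupling, InvertibleModuleWrapper

# Fm sees half of the channels (5 of the 10 below).
fm = torch.nn.Conv2d(5, 5, kernel_size=3, padding=1)
block = InvertibleModuleWrapper(
    fn=AdditiveCoupling(Fm=fm, implementation_fwd=-1, implementation_bwd=-1),
    keep_input=False, keep_input_inverse=False)

x = torch.rand(2, 10, 8, 8).requires_grad_()
y = block(x)        # x's storage is freed here instead of being kept for backprop
loss = y.pow(2).mean()
loss.backward()     # x is reconstructed from y via the inverse, then gradients flow as usual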
Example 4
def test_is_invertible_module():
    X = torch.zeros(1, 10, 10, 10)
    assert not is_invertible_module(
        torch.nn.Conv2d(10, 10, kernel_size=(1, 1)), test_input_shape=X.shape)
    fn = AdditiveCoupling(SubModule(),
                          implementation_bwd=-1,
                          implementation_fwd=-1)
    assert is_invertible_module(fn, test_input_shape=X.shape)

    class FakeInverse(torch.nn.Module):
        def forward(self, x):
            return x * 4

        def inverse(self, y):
            return y * 8

    assert not is_invertible_module(FakeInverse(), test_input_shape=X.shape)
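For intuition, is_invertible_module essentially verifies that forward and inverse are mutual inverses on a random input of the given shape, which is why FakeInverse is rejected (x * 4 is not undone by y * 8). A simplified, hypothetical version of that round-trip check; the real memcnn implementation may perform additional checks:

import torch


def roundtrip_check(module, test_input_shape, atol=1e-6):
    # A module counts as invertible only if inverse(forward(x)) == x
    # and forward(inverse(x)) == x, up to numerical tolerance.
    x = torch.rand(*test_input_shape)
    with torch.no_grad():
        fwd_ok = torch.allclose(module.inverse(module(x)), x, atol=atol)
        bwd_ok = torch.allclose(module(module.inverse(x)), x, atol=atol)
    return fwd_ok and bwd_ok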
Example 5
def test_chained_invertible_module_wrapper_shared_fwd_and_bwd_train_passes():
    set_seeds(42)
    Gm = SubModule(in_filters=5, out_filters=5)
    rb_temp = SubModuleStack(Gm=Gm,
                             coupling='additive',
                             depth=5,
                             keep_input=True,
                             adapter=None,
                             implementation_bwd=-1,
                             implementation_fwd=-1)
    optim = torch.optim.SGD(rb_temp.parameters(), lr=0.01)

    initial_params = [p.detach().clone() for p in rb_temp.parameters()]
    initial_state = copy.deepcopy(rb_temp.state_dict())
    initial_optim_state = copy.deepcopy(optim.state_dict())

    dims = (2, 10, 8, 8)
    data = torch.rand(*dims, dtype=torch.float32)
    target_data = torch.rand(*dims, dtype=torch.float32)

    forward_outputs = []
    inverse_outputs = []
    for i in range(10):

        is_forward_pass = i % 2 == 0
        set_seeds(42)
        rb = SubModuleStack(Gm=Gm,
                            coupling='additive',
                            depth=5,
                            keep_input=True,
                            adapter=None,
                            implementation_bwd=-1,
                            implementation_fwd=-1,
                            num_bwd_passes=2)
        rb.train()
        with torch.no_grad():
            for (name, p), p_initial in zip(rb.named_parameters(),
                                            initial_params):
                p.set_(p_initial)

        rb.load_state_dict(initial_state)
        optim = torch.optim.SGD(rb.parameters(), lr=0.01)  # optimise the freshly built stack, not the rb_temp template
        optim.load_state_dict(initial_optim_state)

        with torch.set_grad_enabled(True):
            X = data.detach().clone().requires_grad_()
            Ytarget = target_data.detach().clone()

            optim.zero_grad()

            if is_forward_pass:
                Y = rb(X)
                Xinv = rb.inverse(Y)
                Xinv2 = rb.inverse(Y)
                Xinv3 = rb.inverse(Y)
            else:
                Y = rb.inverse(X)
                Xinv = rb(Y)
                Xinv2 = rb(Y)
                Xinv3 = rb(Y)

            for item in [Xinv, Xinv2, Xinv3]:
                assert torch.allclose(X, item, atol=1e-04)

            loss = torch.nn.MSELoss()(Xinv, Ytarget)
            assert not torch.isnan(loss)

            assert Xinv2.grad is None
            assert Xinv3.grad is None

            loss.backward()

            assert Y.grad is not None
            assert Xinv.grad is not None
            assert Xinv2.grad is None
            assert Xinv3.grad is None

            loss2 = torch.nn.MSELoss()(Xinv2, Ytarget)
            assert not torch.isnan(loss2)

            loss2.backward()

            assert Xinv2.grad is not None

            optim.step()

            if is_forward_pass:
                forward_outputs.append(Y.detach().clone())
            else:
                inverse_outputs.append(Y.detach().clone())

    for i in range(4):
        assert torch.allclose(forward_outputs[-1],
                              forward_outputs[i],
                              atol=1e-06)
        assert torch.allclose(inverse_outputs[-1],
                              inverse_outputs[i],
                              atol=1e-06)
Example 6
                                adapter=AffineAdapterNaive)
            assert isinstance(f, InvertibleModuleWrapper)
            f.inverse(X)
    with pytest.raises(NotImplementedError):
        with warnings.catch_warnings():
            warnings.simplefilter(action='ignore', category=DeprecationWarning)
            ReversibleBlock(fm,
                            coupling='unknown',
                            implementation_bwd=-2,
                            implementation_fwd=0,
                            adapter=AffineAdapterNaive)


@pytest.mark.parametrize('fn', [
    AdditiveCoupling(
        Fm=SubModule(), implementation_fwd=-1, implementation_bwd=-1),
    AffineCoupling(Fm=SubModule(),
                   implementation_fwd=-1,
                   implementation_bwd=-1,
                   adapter=AffineAdapterNaive),
    AffineCoupling(Fm=SubModule(out_filters=10),
                   implementation_fwd=-1,
                   implementation_bwd=-1,
                   adapter=AffineAdapterSigmoid),
    MultiplicationInverse()
])
@pytest.mark.parametrize('bwd', [False, True])
@pytest.mark.parametrize('keep_input', [False, True])
@pytest.mark.parametrize('keep_input_inverse', [False, True])
def test_invertible_module_wrapper_fwd_bwd(fn, bwd, keep_input,
                                           keep_input_inverse):
Example 7
def test_legacy_affine_coupling():
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore', category=DeprecationWarning)
        AffineBlock(Fm=SubModule())
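AffineBlock, like ReversibleBlock in the earlier example, belongs to memcnn's deprecated legacy API, which is why the DeprecationWarning is suppressed. The non-deprecated route is to build a coupling explicitly and wrap it with InvertibleModuleWrapper, as the other examples do. A minimal sketch; the toy Fm module here is an assumption:

import torch
from memcnn import AffineAdapterNaive, AffineCoupling, InvertibleModuleWrapper

fm = torch.nn.Conv2d(5, 5, kernel_size=3, padding=1)
block = InvertibleModuleWrapper(
    fn=AffineCoupling(Fm=fm, implementation_fwd=-1, implementation_bwd=-1,
                      adapter=AffineAdapterNaive),
    keep_input=True, keep_input_inverse=True)

x = torch.rand(2, 10, 8, 8)
with torch.no_grad():
    # Forward followed by inverse should reproduce the input.
    assert torch.allclose(block.inverse(block(x)), x, atol=1e-4)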