Code example #1
File: test_revop.py  Project: mawright/memcnn
def test_invertible_module_wrapper_simple_inverse(coupling):
    """InvertibleModuleWrapper inverse test"""
    for seed in range(10):
        set_seeds(seed)
        # define some data
        X = torch.rand(2, 4, 5, 5).requires_grad_()

        # define an arbitrary reversible function
        coupling_fn = create_coupling(Fm=torch.nn.Conv2d(2, 2, 3, padding=1),
                                      coupling=coupling,
                                      implementation_fwd=-1,
                                      implementation_bwd=-1,
                                      adapter=AffineAdapterNaive)
        fn = InvertibleModuleWrapper(fn=coupling_fn,
                                     keep_input=False,
                                     keep_input_inverse=False)

        # compute output
        Y = fn.forward(X.clone())

        # compute input from output
        X2 = fn.inverse(Y)

        # check that the inverted output and the original input are approximately equal
        assert torch.allclose(X2.detach(), X.detach(), atol=1e-06)
Code example #2
def test_input_output_invertible_function_share_tensor():
    fn = IdentityInverse()
    rm = InvertibleModuleWrapper(fn=fn, keep_input=True, keep_input_inverse=True)
    X = torch.rand(1, 2, 5, 5, dtype=torch.float32).requires_grad_()
    assert not is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
    with pytest.raises(RuntimeError):
        rm.forward(X)
    fn.multiply_forward = True
    rm.forward(X)
    assert not is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
    with pytest.raises(RuntimeError):
        rm.inverse(X)
    fn.multiply_inverse = True
    rm.inverse(X)
    assert is_invertible_module(fn, test_input_shape=X.shape, atol=1e-6)
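For context, here is a minimal sketch of what an IdentityInverse-style helper could look like. Only the attribute names multiply_forward and multiply_inverse come from the example; the constructor defaults and the multiply-by-one trick are assumptions, not the memcnn source. The idea, suggested by the test name, is that while a flag is False the pass returns the input tensor itself, so input and output share storage, which would explain the RuntimeError the test expects; once the flag is set, a fresh tensor is returned and the call succeeds.

import torch

class IdentityInverse(torch.nn.Module):
    """Hypothetical sketch of the test helper above, not the memcnn source."""

    def __init__(self, multiply_forward=False, multiply_inverse=False):
        super().__init__()
        self.multiply_forward = multiply_forward
        self.multiply_inverse = multiply_inverse

    def forward(self, x):
        # multiplying by 1.0 is numerically the identity but allocates a new
        # tensor; returning x directly would share storage with the input
        return x * 1.0 if self.multiply_forward else x

    def inverse(self, y):
        return y * 1.0 if self.multiply_inverse else y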
Code example #3
def __init__(self, Gm, coupling='additive', depth=10, implementation_fwd=-1,
             implementation_bwd=-1, keep_input=False, adapter=None):
    super(SubModuleStack, self).__init__()
    fn = create_coupling(Fm=Gm, Gm=Gm, coupling=coupling,
                         implementation_fwd=implementation_fwd,
                         implementation_bwd=implementation_bwd,
                         adapter=adapter)
    self.stack = torch.nn.ModuleList(
        [InvertibleModuleWrapper(fn=fn, keep_input=keep_input, keep_input_inverse=keep_input)
         for _ in range(depth)]
    )
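The excerpt shows only the constructor of SubModuleStack. As a sketch (this forward method is an assumption, not part of the excerpt), the wrapped modules would simply be chained:

def forward(self, x):
    # hypothetical forward pass: apply each InvertibleModuleWrapper in order.
    # With keep_input=False, each wrapper may free its input activation and
    # reconstruct it from its output during the backward pass.
    for rev_module in self.stack:
        x = rev_module(x)
    return x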
Code example #4
File: test_multi.py  Project: silvandeleemput/memcnn
def test_multi(disable):
    split = InvertibleModuleWrapper(SplitChannels(2), disable=disable)
    concat = InvertibleModuleWrapper(ConcatenateChannels(2), disable=disable)

    assert is_invertible_module(split, test_input_shape=(1, 3, 32, 32))
    assert is_invertible_module(concat, test_input_shape=((1, 2, 32, 32), (1, 1, 32, 32)))

    conv_a = torch.nn.Conv2d(2, 2, 3)
    conv_b = torch.nn.Conv2d(1, 1, 3)

    x = torch.rand(1, 3, 32, 32)
    x.requires_grad = True

    a, b = split(x)
    a, b = conv_a(a), conv_b(b)
    y = concat(a, b)
    loss = torch.sum(y)
    loss.backward()
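SplitChannels and ConcatenateChannels are invertible because splitting along the channel dimension loses no information. A rough sketch of such a pair is shown below; the class names match the example, but the implementation details are assumptions rather than the memcnn source.

import torch

class SplitChannels(torch.nn.Module):
    # hypothetical sketch: split off the first `split_location` channels;
    # the inverse of a split is a concatenation, so no information is lost
    def __init__(self, split_location):
        super().__init__()
        self.split_location = split_location

    def forward(self, x):
        return (x[:, :self.split_location].clone(),
                x[:, self.split_location:].clone())

    def inverse(self, a, b):
        return torch.cat([a, b], dim=1)

class ConcatenateChannels(torch.nn.Module):
    # hypothetical sketch: the mirror image of SplitChannels
    def __init__(self, split_location):
        super().__init__()
        self.split_location = split_location

    def forward(self, a, b):
        return torch.cat([a, b], dim=1)

    def inverse(self, x):
        return (x[:, :self.split_location].clone(),
                x[:, self.split_location:].clone())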
Code example #5
def __init__(self,
             inplanes,
             planes,
             stride=1,
             downsample=None,
             noactivation=False):
    super(RevBasicBlock, self).__init__()
    if downsample is None and stride == 1:
        gm = BasicBlockSub(inplanes // 2, planes // 2, stride, noactivation)
        fm = BasicBlockSub(inplanes // 2, planes // 2, stride, noactivation)
        coupling = create_coupling(Fm=fm, Gm=gm, coupling='additive')
        self.revblock = InvertibleModuleWrapper(fn=coupling, keep_input=False)
    else:
        self.basicblock_sub = BasicBlockSub(inplanes, planes, stride, noactivation)
    self.downsample = downsample
    self.stride = stride
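Only the constructor of RevBasicBlock is shown. A plausible forward pass is sketched below; it is an assumption, under the usual ResNet convention that a downsample module is supplied whenever the stride is not 1.

def forward(self, x):
    # hypothetical forward: the reversible path needs no explicit residual
    # addition, since the additive coupling already mixes the channel halves
    if self.downsample is None and self.stride == 1:
        return self.revblock(x)
    # non-reversible fallback: a standard residual block with optional downsampling
    residual = x if self.downsample is None else self.downsample(x)
    return self.basicblock_sub(x) + residual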
Code example #6
def test_normal_vs_invertible_module_wrapper(coupling):
    """InvertibleModuleWrapper test if similar gradients and weights results are obtained after similar training"""
    for seed in range(10):
        set_seeds(seed)

        X = torch.rand(2, 4, 5, 5)

        # define models and their copies
        c1 = torch.nn.Conv2d(2, 2, 3, padding=1)
        c2 = torch.nn.Conv2d(2, 2, 3, padding=1)
        c1_2 = copy.deepcopy(c1)
        c2_2 = copy.deepcopy(c2)

        # weights should be identical between each model and its copy, but differ between the two convolutions
        assert torch.equal(c1.weight, c1_2.weight)
        assert torch.equal(c2.weight, c2_2.weight)
        assert torch.equal(c1.bias, c1_2.bias)
        assert torch.equal(c2.bias, c2_2.bias)
        assert not torch.equal(c1.weight, c2.weight)

        # define optimizers
        optim1 = torch.optim.SGD([e for e in c1.parameters()] + [e for e in c2.parameters()], 0.1)
        optim2 = torch.optim.SGD([e for e in c1_2.parameters()] + [e for e in c2_2.parameters()], 0.1)
        for e in [c1, c2, c1_2, c2_2]:
            e.train()

        # define an arbitrary reversible function and build the graph for the memcnn-wrapped model (c1_2, c2_2)
        Xin = X.clone().requires_grad_()
        coupling_fn = create_coupling(Fm=c1_2, Gm=c2_2, coupling=coupling, implementation_fwd=-1,
                                      implementation_bwd=-1, adapter=AffineAdapterNaive)
        fn = InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)

        Y = fn.forward(Xin)
        loss2 = torch.mean(Y)

        # build the same reversible function without memcnn's custom backprop for the reference model (c1, c2)
        XX = X.clone().detach().requires_grad_()
        x1, x2 = torch.chunk(XX, 2, dim=1)
        if coupling == 'additive':
            y1 = x1 + c1.forward(x2)
            y2 = x2 + c2.forward(y1)
        elif coupling == 'affine':
            fmr2 = c1.forward(x2)
            fmr1 = torch.exp(fmr2)
            y1 = (x1 * fmr1) + fmr2
            gmr2 = c2.forward(y1)
            gmr1 = torch.exp(gmr2)
            y2 = (x2 * gmr1) + gmr2
        else:
            raise NotImplementedError()
        YY = torch.cat([y1, y2], dim=1)

        loss = torch.mean(YY)

        # compute gradients manually
        grads = torch.autograd.grad(loss, (XX, c1.weight, c2.weight, c1.bias, c2.bias), None, retain_graph=True)

        # compute gradients and perform an optimization step for the reference model (c1, c2)
        loss.backward()
        optim1.step()

        # gradients computed manually match those of the .backward() pass
        assert torch.equal(c1.weight.grad, grads[1])
        assert torch.equal(c2.weight.grad, grads[2])
        assert torch.equal(c1.bias.grad, grads[3])
        assert torch.equal(c2.bias.grad, grads[4])

        # weights should now differ, since only the reference model has been trained
        assert not torch.equal(c1.weight, c1_2.weight)
        assert not torch.equal(c2.weight, c2_2.weight)
        assert not torch.equal(c1.bias, c1_2.bias)
        assert not torch.equal(c2.bias, c2_2.bias)

        # compute gradients and perform an optimization step for the memcnn-wrapped model (c1_2, c2_2)
        loss2.backward()
        optim2.step()

        # check that the input and output are contiguous
        assert Xin.is_contiguous()
        assert Y.is_contiguous()

        # weights are approximately the same after training both models?
        assert torch.allclose(c1.weight.detach(), c1_2.weight.detach())
        assert torch.allclose(c2.weight.detach(), c2_2.weight.detach())
        assert torch.allclose(c1.bias.detach(), c1_2.bias.detach())
        assert torch.allclose(c2.bias.detach(), c2_2.bias.detach())

        # gradients are approximately the same after training both models?
        assert torch.allclose(c1.weight.grad.detach(), c1_2.weight.grad.detach())
        assert torch.allclose(c2.weight.grad.detach(), c2_2.weight.grad.detach())
        assert torch.allclose(c1.bias.grad.detach(), c1_2.bias.grad.detach())
        assert torch.allclose(c2.bias.grad.detach(), c2_2.bias.grad.detach())
Code example #7
def test_invertible_module_wrapper_disabled_versus_enabled(inverted):
    set_seeds(42)
    Gm = SubModule(in_filters=5, out_filters=5)

    coupling_fn = create_coupling(Fm=Gm, Gm=Gm, coupling='additive', implementation_fwd=-1,
                                  implementation_bwd=-1)
    rb = InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)
    rb2 = InvertibleModuleWrapper(fn=copy.deepcopy(coupling_fn), keep_input=False, keep_input_inverse=False)
    rb.eval()
    rb2.eval()
    rb2.disable = True
    with torch.no_grad():
        dims = (2, 10, 8, 8)
        data = torch.rand(*dims, dtype=torch.float32)
        X, X2 = data.clone().detach().requires_grad_(), data.clone().detach().requires_grad_()
        if not inverted:
            Y = rb(X)
            Y2 = rb2(X2)
        else:
            Y = rb.inverse(X)
            Y2 = rb2.inverse(X2)

        assert torch.allclose(Y, Y2)

        assert is_memory_cleared(X, True, dims)
        assert is_memory_cleared(X2, False, dims)
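is_memory_cleared is a small test helper used in this and the next example. A sketch of what it might check is given below; this is an assumption, not the memcnn source. The point is that when a wrapper discards its input, the tensor object survives but its underlying storage is emptied.

def is_memory_cleared(tensor, should_be_cleared, shape):
    # hypothetical helper: a discarded input keeps its Python object but its
    # storage holds zero elements; a retained input still has data with the
    # original shape
    if should_be_cleared:
        return tensor.storage().size() == 0
    return tensor.shape == shape and tensor.storage().size() > 0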
Code example #8
def test_invertible_module_wrapper_fwd_bwd(fn, bwd, keep_input, keep_input_inverse):
    """InvertibleModuleWrapper tests for the memory saving forward and backward passes

    * test inversion Y = RB(X) and X = RB.inverse(Y)
    * test training the block for a single step and compare weights for implementations: 0, 1
    * test automatic discard of input X and its retrieval after the backward pass
    * test usage of BN to identify non-contiguous memory blocks

    """
    for seed in range(10):
        set_seeds(seed)
        dims = (2, 10, 8, 8)
        data = torch.rand(*dims, dtype=torch.float32)
        target_data = torch.rand(*dims, dtype=torch.float32)

        assert is_invertible_module(fn, test_input_shape=data.shape, atol=1e-4)

        # test with zero padded convolution
        with torch.set_grad_enabled(True):
            X = data.clone().requires_grad_()

            Ytarget = target_data.clone()

            Xshape = X.shape

            rb = InvertibleModuleWrapper(fn=fn, keep_input=keep_input, keep_input_inverse=keep_input_inverse)
            s_grad = [p.detach().clone() for p in rb.parameters()]

            rb.train()
            rb.zero_grad()

            optim = torch.optim.RMSprop(rb.parameters())
            optim.zero_grad()
            if not bwd:
                Xin = X.clone().requires_grad_()
                Y = rb(Xin)
                Yrev = Y.detach().clone().requires_grad_()
                Xinv = rb.inverse(Yrev)
            else:
                Xin = X.clone().requires_grad_()
                Y = rb.inverse(Xin)
                Yrev = Y.detach().clone().requires_grad_()
                Xinv = rb(Yrev)
            loss = torch.nn.MSELoss()(Y, Ytarget)

            # has input been retained/discarded after forward (and backward) passes?

            if not bwd:
                assert is_memory_cleared(Yrev, not keep_input_inverse, Xshape)
                assert is_memory_cleared(Xin, not keep_input, Xshape)
            else:
                assert is_memory_cleared(Xin, not keep_input_inverse, Xshape)
                assert is_memory_cleared(Yrev, not keep_input, Xshape)

            optim.zero_grad()

            loss.backward()
            optim.step()

            assert Y.shape == Xshape
            assert X.detach().shape == data.shape
            assert torch.allclose(X.detach(), data, atol=1e-06)
            assert torch.allclose(X.detach(), Xinv.detach(), atol=1e-05)  # Model is now trained and will differ
            grads = [p.detach().clone() for p in rb.parameters()]

            assert not torch.allclose(grads[0], s_grad[0])