Beispiel #1
0
def test_invertible_module_wrapper_disabled_versus_enabled(inverted):
    """Compare an enabled InvertibleModuleWrapper with a disabled deep copy.

    Both wrappers hold the same additive coupling (the second one through
    ``copy.deepcopy``) and are fed clones of one random tensor, so their
    outputs must agree within ``torch.allclose`` tolerance.

    Args:
        inverted: When truthy the ``inverse`` method of each wrapper is
            exercised; otherwise the wrappers are called directly.
    """
    set_seeds(42)
    shared_block = SubModule(in_filters=5, out_filters=5)

    # The very same sub-module instance is used as both Fm and Gm.
    coupling_fn = create_coupling(
        Fm=shared_block,
        Gm=shared_block,
        coupling='additive',
        implementation_fwd=-1,
        implementation_bwd=-1,
    )
    wrapper_options = dict(keep_input=False, keep_input_inverse=False)
    enabled = InvertibleModuleWrapper(fn=coupling_fn, **wrapper_options)
    disabled = InvertibleModuleWrapper(fn=copy.deepcopy(coupling_fn), **wrapper_options)
    enabled.eval()
    disabled.eval()
    disabled.disable = True  # only the deep-copied wrapper gets its ``disable`` flag set

    with torch.no_grad():
        dims = (2, 10, 8, 8)
        data = torch.rand(*dims, dtype=torch.float32)
        x_enabled = data.clone().detach().requires_grad_()
        x_disabled = data.clone().detach().requires_grad_()

        if inverted:
            y_enabled = enabled.inverse(x_enabled)
            y_disabled = disabled.inverse(x_disabled)
        else:
            y_enabled = enabled(x_enabled)
            y_disabled = disabled(x_disabled)

        assert torch.allclose(y_enabled, y_disabled)

        # NOTE(review): the second argument presumably states whether the
        # input's memory is expected to have been freed -- confirm against
        # the definition of is_memory_cleared.
        assert is_memory_cleared(x_enabled, True, dims)
        assert is_memory_cleared(x_disabled, False, dims)
Beispiel #2
0
def test_invertible_module_wrapper_simple_inverse(coupling):
    """Check that ``inverse`` undoes ``forward`` on an InvertibleModuleWrapper.

    Ten seeded repetitions are run. Each one sends a random tensor through
    the wrapped coupling and back again, and requires the reconstruction to
    match the original input with ``atol=1e-06``.

    Args:
        coupling: Coupling identifier handed to ``create_coupling``.
    """
    for seed in range(10):
        set_seeds(seed)

        # random input, drawn before the convolution is constructed
        original = torch.rand(2, 4, 5, 5).requires_grad_()

        # wrap an arbitrary reversible function built around one convolution
        wrapped = InvertibleModuleWrapper(
            fn=create_coupling(
                Fm=torch.nn.Conv2d(2, 2, 3, padding=1),
                coupling=coupling,
                implementation_fwd=-1,
                implementation_bwd=-1,
                adapter=AffineAdapterNaive,
            ),
            keep_input=False,
            keep_input_inverse=False,
        )

        # forward on a clone so that ``original`` stays available for comparison
        output = wrapped.forward(original.clone())
        reconstructed = wrapped.inverse(output)

        # the round trip must reproduce the input up to the given tolerance
        assert torch.allclose(reconstructed.detach(), original.detach(), atol=1e-06)
Beispiel #3
0
 def __init__(self, Gm, coupling='additive', depth=10, implementation_fwd=-1, implementation_bwd=-1,
              keep_input=False, adapter=None):
     """Build a stack of ``depth`` invertible wrappers around one shared coupling.

     Args:
         Gm: Module passed to ``create_coupling`` as both ``Fm`` and ``Gm``.
         coupling: Coupling identifier forwarded to ``create_coupling``.
         depth: Number of ``InvertibleModuleWrapper`` entries in ``self.stack``.
         implementation_fwd: Forwarded unchanged to ``create_coupling``.
         implementation_bwd: Forwarded unchanged to ``create_coupling``.
         keep_input: Used for both ``keep_input`` and ``keep_input_inverse``
             of every wrapper.
         adapter: Forwarded unchanged to ``create_coupling``.
     """
     super(SubModuleStack, self).__init__()
     # A single coupling object is created; every wrapper below refers to it.
     shared_coupling = create_coupling(
         Fm=Gm,
         Gm=Gm,
         coupling=coupling,
         implementation_fwd=implementation_fwd,
         implementation_bwd=implementation_bwd,
         adapter=adapter,
     )
     layers = torch.nn.ModuleList()
     for _ in range(depth):
         layers.append(
             InvertibleModuleWrapper(fn=shared_coupling, keep_input=keep_input, keep_input_inverse=keep_input)
         )
     self.stack = layers
Beispiel #4
0
 def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False):
     """Set up either a reversible block or a plain ``BasicBlockSub``.

     With no ``downsample`` and ``stride == 1`` an additive coupling of two
     ``BasicBlockSub`` modules, each built on half of ``inplanes`` /
     ``planes``, is wrapped in an ``InvertibleModuleWrapper`` and stored as
     ``self.revblock``. In every other case one full-width ``BasicBlockSub``
     is stored as ``self.basicblock_sub``.

     Args:
         inplanes: Input size argument; halved for the coupling branch.
         planes: Output size argument; halved for the coupling branch.
         stride: Passed to ``BasicBlockSub`` and stored on the instance.
         downsample: Stored on the instance; any non-``None`` value selects
             the plain-block branch.
         noactivation: Passed through to ``BasicBlockSub``.
     """
     super(RevBasicBlock, self).__init__()
     if downsample is not None or stride != 1:
         self.basicblock_sub = BasicBlockSub(inplanes, planes, stride, noactivation)
     else:
         half_in, half_out = inplanes // 2, planes // 2
         # Gm is constructed before Fm, matching the original creation order.
         gm = BasicBlockSub(half_in, half_out, stride, noactivation)
         fm = BasicBlockSub(half_in, half_out, stride, noactivation)
         self.revblock = InvertibleModuleWrapper(
             fn=create_coupling(Fm=fm, Gm=gm, coupling='additive'),
             keep_input=False,
         )
     self.downsample = downsample
     self.stride = stride
Beispiel #5
0
def test_normal_vs_invertible_module_wrapper(coupling):
    """Check that a wrapped coupling trains like its hand-written counterpart.

    For each of ten seeds, one SGD step is taken on a plain autograd
    implementation of the coupling and one on an InvertibleModuleWrapper
    built from deep copies of the same two convolutions. Afterwards the
    weights and gradients of both must agree within ``torch.allclose``
    tolerance.

    Args:
        coupling: ``'additive'`` or ``'affine'``; any other value reaches
            ``raise NotImplementedError()`` in the reference implementation.
    """
    for seed in range(10):
        set_seeds(seed)

        X = torch.rand(2, 4, 5, 5)

        # reference convolutions and the copies that go into the wrapper
        conv_f = torch.nn.Conv2d(2, 2, 3, padding=1)
        conv_g = torch.nn.Conv2d(2, 2, 3, padding=1)
        conv_f_copy = copy.deepcopy(conv_f)
        conv_g_copy = copy.deepcopy(conv_g)

        # (reference, copy) parameter pairs, ordered: weights first, then biases
        param_pairs = [
            (getattr(reference, attr), getattr(duplicate, attr))
            for attr in ('weight', 'bias')
            for reference, duplicate in ((conv_f, conv_f_copy), (conv_g, conv_g_copy))
        ]

        # copies start identical to their originals; the two convolutions differ
        for ref_param, dup_param in param_pairs:
            assert torch.equal(ref_param, dup_param)
        assert not torch.equal(conv_f.weight, conv_g.weight)

        # one optimizer per implementation
        reference_optim = torch.optim.SGD(list(conv_f.parameters()) + list(conv_g.parameters()), lr=0.1)
        wrapped_optim = torch.optim.SGD(list(conv_f_copy.parameters()) + list(conv_g_copy.parameters()), lr=0.1)
        for module in (conv_f, conv_g, conv_f_copy, conv_g_copy):
            module.train()

        # graph of the wrapped implementation (uses the copied convolutions)
        wrapped_input = X.clone().requires_grad_()
        wrapped = InvertibleModuleWrapper(
            fn=create_coupling(Fm=conv_f_copy, Gm=conv_g_copy, coupling=coupling, implementation_fwd=-1,
                               implementation_bwd=-1, adapter=AffineAdapterNaive),
            keep_input=False,
            keep_input_inverse=False,
        )
        wrapped_output = wrapped.forward(wrapped_input)
        wrapped_loss = torch.mean(wrapped_output)

        # graph of the reference implementation, written out with plain autograd ops
        reference_input = X.clone().detach().requires_grad_()
        x1, x2 = torch.chunk(reference_input, 2, dim=1)
        if coupling == 'affine':
            f_out = conv_f.forward(x2)
            y1 = (x1 * torch.exp(f_out)) + f_out
            g_out = conv_g.forward(y1)
            y2 = (x2 * torch.exp(g_out)) + g_out
        elif coupling == 'additive':
            y1 = x1 + conv_f.forward(x2)
            y2 = x2 + conv_g.forward(y1)
        else:
            raise NotImplementedError()
        reference_loss = torch.mean(torch.cat([y1, y2], dim=1))

        # gradients obtained directly from autograd; the graph is kept for backward()
        reference_params = (conv_f.weight, conv_g.weight, conv_f.bias, conv_g.bias)
        manual_grads = torch.autograd.grad(
            reference_loss, (reference_input,) + reference_params, grad_outputs=None, retain_graph=True
        )

        # train the reference implementation for one step
        reference_loss.backward()
        reference_optim.step()

        # .backward() must have produced exactly the manually computed gradients
        # (index 0 of manual_grads belongs to the input and is skipped)
        for param, manual in zip(reference_params, manual_grads[1:]):
            assert torch.equal(param.grad, manual)

        # only the reference has been trained so far, so every pair now differs
        for ref_param, dup_param in param_pairs:
            assert not torch.equal(ref_param, dup_param)

        # train the wrapped implementation for one step
        wrapped_loss.backward()
        wrapped_optim.step()

        # input and output of the wrapper are contiguous
        assert wrapped_input.is_contiguous()
        assert wrapped_output.is_contiguous()

        # after one step each, weights and biases approximately agree again ...
        for ref_param, dup_param in param_pairs:
            assert torch.allclose(ref_param.detach(), dup_param.detach())

        # ... and so do their gradients
        for ref_param, dup_param in param_pairs:
            assert torch.allclose(ref_param.grad.detach(), dup_param.grad.detach())