def test_invertible_module_wrapper_disabled_versus_enabled(inverted):
    """A disabled InvertibleModuleWrapper should give identical outputs, but keep its input in memory."""
    set_seeds(42)
    Gm = SubModule(in_filters=5, out_filters=5)
    coupling_fn = create_coupling(Fm=Gm, Gm=Gm, coupling='additive',
                                  implementation_fwd=-1, implementation_bwd=-1)
    rb = InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)
    rb2 = InvertibleModuleWrapper(fn=copy.deepcopy(coupling_fn), keep_input=False, keep_input_inverse=False)
    rb.eval()
    rb2.eval()
    rb2.disable = True
    with torch.no_grad():
        dims = (2, 10, 8, 8)
        data = torch.rand(*dims, dtype=torch.float32)
        X, X2 = data.clone().detach().requires_grad_(), data.clone().detach().requires_grad_()
        if not inverted:
            Y = rb(X)
            Y2 = rb2(X2)
        else:
            Y = rb.inverse(X)
            Y2 = rb2.inverse(X2)
        # both wrappers compute the same result ...
        assert torch.allclose(Y, Y2)
        # ... but only the enabled wrapper frees its input storage
        assert is_memory_cleared(X, True, dims)
        assert is_memory_cleared(X2, False, dims)
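# `is_memory_cleared` is a helper defined elsewhere in this test suite. The sketch below is an
# assumption about the property it checks, not the actual helper: with keep_input=False the
# enabled wrapper frees the input's underlying storage after the pass, while the disabled
# wrapper (rb2.disable = True) falls back to regular autograd and keeps it intact.
import numpy as np

def _is_memory_cleared_sketch(var, should_be_clear, shape):
    # a freed tensor has an empty underlying storage; an untouched one still holds prod(shape) elements
    if should_be_clear:
        return var.storage().size() == 0
    return var.storage().size() == int(np.prod(shape))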
def test_invertible_module_wrapper_simple_inverse(coupling):
    """InvertibleModuleWrapper inverse test"""
    for seed in range(10):
        set_seeds(seed)
        # define some data
        X = torch.rand(2, 4, 5, 5).requires_grad_()
        # define an arbitrary reversible function
        coupling_fn = create_coupling(Fm=torch.nn.Conv2d(2, 2, 3, padding=1), coupling=coupling,
                                      implementation_fwd=-1, implementation_bwd=-1,
                                      adapter=AffineAdapterNaive)
        fn = InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)
        # compute output
        Y = fn.forward(X.clone())
        # compute input from output
        X2 = fn.inverse(Y)
        # check that the inverted output and the original input are approximately similar
        assert torch.allclose(X2.detach(), X.detach(), atol=1e-06)
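# The `inverted` and `coupling` arguments above are supplied by pytest; the parametrize
# decorators are not part of this excerpt. The stubs below are a hedged sketch of how they
# could be driven, using only the values the tests actually branch on; they are not the
# original decorators.
import pytest

@pytest.mark.parametrize('inverted', [False, True])
def _disabled_versus_enabled_sketch(inverted):
    ...  # body as in test_invertible_module_wrapper_disabled_versus_enabled above

@pytest.mark.parametrize('coupling', ['additive', 'affine'])
def _simple_inverse_sketch(coupling):
    ...  # body as in test_invertible_module_wrapper_simple_inverse above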
def __init__(self, Gm, coupling='additive', depth=10,
             implementation_fwd=-1, implementation_bwd=-1,
             keep_input=False, adapter=None):
    super(SubModuleStack, self).__init__()
    fn = create_coupling(Fm=Gm, Gm=Gm, coupling=coupling,
                         implementation_fwd=implementation_fwd,
                         implementation_bwd=implementation_bwd, adapter=adapter)
    self.stack = torch.nn.ModuleList(
        [InvertibleModuleWrapper(fn=fn, keep_input=keep_input, keep_input_inverse=keep_input)
         for _ in range(depth)]
    )
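# SubModuleStack.forward is not included in this excerpt. A minimal sketch of how a stack of
# InvertibleModuleWrapper blocks is typically applied in sequence (an assumption, not the
# original method):
def forward(self, x):
    for block in self.stack:
        x = block(x)
    return x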
def __init__(self, inplanes, planes, stride=1, downsample=None, noactivation=False):
    super(RevBasicBlock, self).__init__()
    if downsample is None and stride == 1:
        gm = BasicBlockSub(inplanes // 2, planes // 2, stride, noactivation)
        fm = BasicBlockSub(inplanes // 2, planes // 2, stride, noactivation)
        coupling = create_coupling(Fm=fm, Gm=gm, coupling='additive')
        self.revblock = InvertibleModuleWrapper(fn=coupling, keep_input=False)
    else:
        self.basicblock_sub = BasicBlockSub(inplanes, planes, stride, noactivation)
    self.downsample = downsample
    self.stride = stride
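# RevBasicBlock.forward is likewise not shown here. A hedged sketch of how the two branches set
# up above could be consumed (an assumption, not the original implementation): the invertible
# path runs the wrapped additive coupling, the other path is a standard residual block with an
# optional downsample on the skip connection.
def forward(self, x):
    if self.downsample is None and self.stride == 1:
        return self.revblock(x)
    residual = x if self.downsample is None else self.downsample(x)
    return self.basicblock_sub(x) + residual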
def test_normal_vs_invertible_module_wrapper(coupling):
    """Test that training through an InvertibleModuleWrapper yields approximately the same
    weights and gradients as training the same functions with standard autograd."""
    for seed in range(10):
        set_seeds(seed)

        X = torch.rand(2, 4, 5, 5)

        # define models and their copies
        c1 = torch.nn.Conv2d(2, 2, 3, padding=1)
        c2 = torch.nn.Conv2d(2, 2, 3, padding=1)
        c1_2 = copy.deepcopy(c1)
        c2_2 = copy.deepcopy(c2)

        # are weights between models the same, but do they differ between convolutions?
        assert torch.equal(c1.weight, c1_2.weight)
        assert torch.equal(c2.weight, c2_2.weight)
        assert torch.equal(c1.bias, c1_2.bias)
        assert torch.equal(c2.bias, c2_2.bias)
        assert not torch.equal(c1.weight, c2.weight)

        # define optimizers
        optim1 = torch.optim.SGD([e for e in c1.parameters()] + [e for e in c2.parameters()], 0.1)
        optim2 = torch.optim.SGD([e for e in c1_2.parameters()] + [e for e in c2_2.parameters()], 0.1)
        for e in [c1, c2, c1_2, c2_2]:
            e.train()

        # define an arbitrary reversible function and define graph for model 1
        Xin = X.clone().requires_grad_()
        coupling_fn = create_coupling(Fm=c1_2, Gm=c2_2, coupling=coupling,
                                      implementation_fwd=-1, implementation_bwd=-1,
                                      adapter=AffineAdapterNaive)
        fn = InvertibleModuleWrapper(fn=coupling_fn, keep_input=False, keep_input_inverse=False)
        Y = fn.forward(Xin)
        loss2 = torch.mean(Y)

        # define the reversible function without custom backprop and define graph for model 2
        XX = X.clone().detach().requires_grad_()
        x1, x2 = torch.chunk(XX, 2, dim=1)
        if coupling == 'additive':
            y1 = x1 + c1.forward(x2)
            y2 = x2 + c2.forward(y1)
        elif coupling == 'affine':
            fmr2 = c1.forward(x2)
            fmr1 = torch.exp(fmr2)
            y1 = (x1 * fmr1) + fmr2
            gmr2 = c2.forward(y1)
            gmr1 = torch.exp(gmr2)
            y2 = (x2 * gmr1) + gmr2
        else:
            raise NotImplementedError()
        YY = torch.cat([y1, y2], dim=1)
        loss = torch.mean(YY)

        # compute gradients manually
        grads = torch.autograd.grad(loss, (XX, c1.weight, c2.weight, c1.bias, c2.bias), None,
                                    retain_graph=True)

        # compute gradients and perform optimization for model 2
        loss.backward()
        optim1.step()

        # gradients computed manually match those of the .backward() pass
        assert torch.equal(c1.weight.grad, grads[1])
        assert torch.equal(c2.weight.grad, grads[2])
        assert torch.equal(c1.bias.grad, grads[3])
        assert torch.equal(c2.bias.grad, grads[4])

        # weights differ after training a single model?
        assert not torch.equal(c1.weight, c1_2.weight)
        assert not torch.equal(c2.weight, c2_2.weight)
        assert not torch.equal(c1.bias, c1_2.bias)
        assert not torch.equal(c2.bias, c2_2.bias)

        # compute gradients and perform optimization for model 1
        loss2.backward()
        optim2.step()

        # input is contiguous tests
        assert Xin.is_contiguous()
        assert Y.is_contiguous()

        # weights are approximately the same after training both models?
        assert torch.allclose(c1.weight.detach(), c1_2.weight.detach())
        assert torch.allclose(c2.weight.detach(), c2_2.weight.detach())
        assert torch.allclose(c1.bias.detach(), c1_2.bias.detach())
        assert torch.allclose(c2.bias.detach(), c2_2.bias.detach())

        # gradients are approximately the same after training both models?
        assert torch.allclose(c1.weight.grad.detach(), c1_2.weight.grad.detach())
        assert torch.allclose(c2.weight.grad.detach(), c2_2.weight.grad.detach())
        assert torch.allclose(c1.bias.grad.detach(), c1_2.bias.grad.detach())
        assert torch.allclose(c2.bias.grad.detach(), c2_2.bias.grad.detach())
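# Why the couplings above are invertible: the manual forward graph in this test can be undone
# exactly from its output, which is what lets InvertibleModuleWrapper drop input activations
# (keep_input=False) and recompute them during the backward pass. A minimal sketch of the
# inverses implied by that forward graph (not memcnn's implementation):
import torch

def additive_inverse(y1, y2, Fm, Gm):
    x2 = y2 - Gm(y1)  # undo y2 = x2 + Gm(y1)
    x1 = y1 - Fm(x2)  # undo y1 = x1 + Fm(x2)
    return x1, x2

def affine_naive_inverse(y1, y2, Fm, Gm):
    # mirrors the 'affine' branch above, where scale = exp(m(.)) and shift = m(.)
    gmr2 = Gm(y1)
    x2 = (y2 - gmr2) / torch.exp(gmr2)  # undo y2 = x2 * exp(Gm(y1)) + Gm(y1)
    fmr2 = Fm(x2)
    x1 = (y1 - fmr2) / torch.exp(fmr2)  # undo y1 = x1 * exp(Fm(x2)) + Fm(x2)
    return x1, x2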