def test_gradients_inv(biort, qshift, size, J):
    """ Gradient of inverse function should be forward function with filters
    swapped """
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, device=dev)
    ifm = DTCWTInverse(biort=biort, qshift=qshift).to(dev)
    h0o, g0o, h1o, g1o = _biort(biort)
    h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
    ifm_grad = DTCWTForward(J=J, biort=(g0o[::-1], g1o[::-1]),
                            qshift=(g0a[::-1], g0b[::-1],
                                    g1a[::-1], g1b[::-1])).to(dev)
    yl, yh = ifm_grad(imt)
    g = torch.randn(*imt.shape, device=dev)
    ylv = torch.randn(*yl.shape, requires_grad=True, device=dev)
    yhv = [torch.randn(*h.shape, requires_grad=True, device=dev) for h in yh]
    Y = ifm((ylv, yhv))
    Y.backward(g)

    # Check the lowpass gradient is the same
    ref_lp, ref_bp = ifm_grad(g)
    np.testing.assert_array_almost_equal(ylv.grad.detach().cpu(), ref_lp.cpu())
    # Check the bandpasses are the same
    for y, ref in zip(yhv, ref_bp):
        np.testing.assert_array_almost_equal(y.grad.detach().cpu(), ref.cpu())
def test_inv(J, o_before_c):
    Yl = 100 * np.random.randn(3, 5, 64, 64)
    Yhr = [np.random.randn(3, 5, 6, 2**j, 2**j) for j in range(4 + J, 4, -1)]
    Yhi = [np.random.randn(3, 5, 6, 2**j, 2**j) for j in range(4 + J, 4, -1)]
    Yh1 = [yhr + 1j * yhi for yhr, yhi in zip(Yhr, Yhi)]
    if o_before_c:
        Yh2 = [torch.tensor(np.stack((yhr, yhi), axis=-1), dtype=torch.float32,
                            device=dev).transpose(1, 2)
               for yhr, yhi in zip(Yhr, Yhi)]
    else:
        Yh2 = [torch.tensor(np.stack((yhr, yhi), axis=-1), dtype=torch.float32,
                            device=dev)
               for yhr, yhi in zip(Yhr, Yhi)]

    ifm = DTCWTInverse(J=J, o_before_c=o_before_c).to(dev)
    X = ifm((torch.tensor(Yl, dtype=torch.float32, device=dev), Yh2))
    f1 = Transform2d_np()
    x = f1.inverse(Yl, Yh1)

    np.testing.assert_array_almost_equal(X.cpu(), x, decimal=PRECISION_FLOAT)
def test_gradients_fwd(biort, qshift, size, J):
    """ Gradient of forward function should be inverse function with filters
    swapped """
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True, device=dev)
    xfm = DTCWTForward(biort=biort, qshift=qshift, J=J).to(dev)
    h0o, g0o, h1o, g1o = _biort(biort)
    h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
    xfm_grad = DTCWTInverse(biort=(h0o[::-1], h1o[::-1]),
                            qshift=(h0a[::-1], h0b[::-1],
                                    h1a[::-1], h1b[::-1])).to(dev)
    Yl, Yh = xfm(imt)

    # Check the lowpass gradient is the same
    Ylg = torch.randn(*Yl.shape, device=dev)
    Yl.backward(Ylg, retain_graph=True)
    ref = xfm_grad((Ylg, [None, ] * J))
    np.testing.assert_array_almost_equal(imt.grad.detach().cpu(), ref.cpu())

    # Check the bandpass gradients, one scale at a time
    for j, y in enumerate(Yh):
        imt.grad.zero_()
        g = torch.randn(*y.shape, device=dev)
        y.backward(g, retain_graph=True)
        hps = [None, ] * J
        hps[j] = g
        ref = xfm_grad((torch.zeros_like(Yl), hps))
        np.testing.assert_array_almost_equal(imt.grad.detach().cpu(), ref.cpu())
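# A hedged, standalone sketch of the identity the two gradient tests above rely
# on: the DTCWT is linear, y = W x, and autograd's backward pass applies the
# adjoint W^T, so <Wx, g> should equal <x, W^T g>. The input shape, J, the
# default filters and the helper name adjoint_dot_product_check below are
# illustrative assumptions, not taken from the tests.
import torch
from pytorch_wavelets import DTCWTForward


def adjoint_dot_product_check():
    xfm = DTCWTForward(J=2)
    x = torch.randn(1, 1, 32, 32, requires_grad=True)
    yl, yh = xfm(x)
    gl = torch.randn_like(yl)
    gh = [torch.randn_like(h) for h in yh]
    # <Wx, g>, accumulated over the lowpass and all bandpass tensors
    lhs = (yl * gl).sum() + sum((h * g).sum() for h, g in zip(yh, gh))
    # W^T g, obtained from autograd
    wt_g, = torch.autograd.grad(lhs, x, retain_graph=True)
    # <x, W^T g> should match <Wx, g> up to float32 rounding
    rhs = (x * wt_g).sum()
    assert torch.allclose(lhs, rhs, rtol=1e-3, atol=1e-3)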
def test_inv_skip_hps(J, o_dim):
    hps = np.random.binomial(size=J, n=1, p=0.5).astype('bool')
    Yl = 100 * np.random.randn(3, 5, 64, 64)
    Yhr = [[np.random.randn(3, 5, 2**j, 2**j) for l in range(6)]
           for j in range(4 + J, 4, -1)]
    Yhi = [[np.random.randn(3, 5, 2**j, 2**j) for l in range(6)]
           for j in range(4 + J, 4, -1)]
    Yh1 = [np.stack(r, axis=2) + 1j * np.stack(i, axis=2)
           for r, i in zip(Yhr, Yhi)]
    Yh2 = [np.stack((np.stack(r, axis=o_dim), np.stack(i, axis=o_dim)), axis=-1)
           for r, i in zip(Yhr, Yhi)]
    Yh2 = [torch.tensor(yh, dtype=torch.float32, device=dev) for yh in Yh2]
    for j in range(J):
        if hps[j]:
            Yh2[j] = torch.tensor([])
            Yh1[j] = np.zeros_like(Yh1[j])

    ifm = DTCWTInverse(J=J, o_dim=o_dim).to(dev)
    X = ifm((torch.tensor(Yl, dtype=torch.float32, requires_grad=True,
                          device=dev), Yh2))

    # Also test giving None instead of an empty tensor
    for j in range(J):
        if hps[j]:
            Yh2[j] = None
    X2 = ifm((torch.tensor(Yl, dtype=torch.float32, device=dev), Yh2))

    f1 = Transform2d_np()
    x = f1.inverse(Yl, Yh1)

    np.testing.assert_array_almost_equal(
        X.detach().cpu(), x, decimal=PRECISION_FLOAT)
    np.testing.assert_array_almost_equal(
        X2.cpu(), x, decimal=PRECISION_FLOAT)

    # Test gradients are ok
    X.backward(torch.ones_like(X))
def end_to_end(size, no_grad, J, no_hp=False, dev='cuda'):
    x = torch.randn(*size, requires_grad=(not no_grad)).to(dev)
    xfm = DTCWTForward(J=J, skip_hps=no_hp).to(dev)
    ifm = DTCWTInverse(J=J).to(dev)
    Yl, Yh = xfm(x)
    Y = ifm((Yl, Yh))
    if not no_grad:
        Y.backward(torch.ones_like(Y))
    return Y
def inverse(size, no_grad, J, no_hp=False, dev='cuda'):
    # Lowpass is at 1/2**(J-1) of the input resolution; the level-j bandpass
    # is at 1/2**j, with 6 orientations and a trailing real/imag axis
    yl = torch.randn(size[0], size[1], size[2] >> (J - 1), size[3] >> (J - 1),
                     requires_grad=(not no_grad)).to(dev)
    yh = [torch.randn(size[0], size[1], 6, size[2] >> j, size[3] >> j, 2,
                      requires_grad=(not no_grad)).to(dev)
          for j in range(1, J + 1)]
    ifm = DTCWTInverse(J=J).to(dev)
    Y = ifm((yl, yh))
    if not no_grad:
        Y.backward(torch.ones_like(Y))
    return Y
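# Hedged usage sketch for the two benchmark helpers above. The size, iteration
# count, J and the helper name time_end_to_end are illustrative assumptions;
# torch.cuda.synchronize() is used so the measured wall-clock time includes the
# asynchronous GPU work (this assumes a CUDA device, matching the helpers'
# dev='cuda' default).
import time
import torch


def time_end_to_end(size=(8, 16, 128, 128), J=3, iters=10):
    end_to_end(size, no_grad=True, J=J)      # warm-up pass
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        end_to_end(size, no_grad=True, J=J)
    torch.cuda.synchronize()
    print('end_to_end: {:.1f} ms/iter'.format(
        1000 * (time.time() - start) / iters))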
def __init__(self, C, F, k=4, stride=1, J=1, wd=0, wd1=None, right=True):
    super().__init__()
    self.wd = wd
    if wd1 is None:
        self.wd1 = wd
    else:
        self.wd1 = wd1
    self.k = k
    self.C = C
    self.F = F
    self.J = J
    self.right = right
    self.ifm = DTCWTInverse()
    self._init()
def __init__(self, C, F, lp_size=3, bp_sizes=(1,), biort='near_sym_a',
             qshift='qshift_a', xfm=True, ifm=True, wd=0, wd1=None,
             lp_nl=None, bp_nl=None, lp_nl_kwargs={}, bp_nl_kwargs={}):
    super().__init__()
    self.C = C
    self.F = F
    # If any of the mixing for a scale is 0, don't calculate the dtcwt at
    # that scale
    skip_hps = [True if s == 0 else False for s in bp_sizes]
    self.J = len(bp_sizes)
    self.wd = wd
    self.wd1 = wd1

    if xfm:
        self.XFM = DTCWTForward(biort=biort, qshift=qshift, J=self.J,
                                skip_hps=skip_hps, o_dim=2, ri_dim=-1)
    else:
        self.XFM = lambda x: x

    self.GainLayer = WaveGainLayer(C, F, lp_size, bp_sizes, wd=wd, wd1=wd1)

    if not isinstance(bp_nl, (list, tuple)):
        bp_nl = [bp_nl, ] * self.J
    self.NL = WaveNonLinearity(F, lp_nl, bp_nl, lp_nl_kwargs, bp_nl_kwargs)

    if ifm:
        self.IFM = DTCWTInverse(biort=biort, qshift=qshift, o_dim=2, ri_dim=-1)
    else:
        self.IFM = lambda x: x
def test_inv(J, o_dim):
    Yl = 100 * np.random.randn(3, 5, 64, 64)
    Yhr = [[np.random.randn(3, 5, 2**j, 2**j) for l in range(6)]
           for j in range(4 + J, 4, -1)]
    Yhi = [[np.random.randn(3, 5, 2**j, 2**j) for l in range(6)]
           for j in range(4 + J, 4, -1)]
    Yh1 = [np.stack(r, axis=2) + 1j * np.stack(i, axis=2)
           for r, i in zip(Yhr, Yhi)]
    Yh2 = [np.stack((np.stack(r, axis=o_dim), np.stack(i, axis=o_dim)), axis=-1)
           for r, i in zip(Yhr, Yhi)]
    Yh2 = [torch.tensor(yh, dtype=torch.float32, device=dev) for yh in Yh2]

    ifm = DTCWTInverse(J=J, o_dim=o_dim).to(dev)
    X = ifm((torch.tensor(Yl, dtype=torch.float32, device=dev), Yh2))
    f1 = Transform2d_np()
    x = f1.inverse(Yl, Yh1)

    np.testing.assert_array_almost_equal(X.cpu(), x, decimal=PRECISION_FLOAT)
def test_end2end(biort, qshift, size, J):
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True, device=dev)
    xfm = DTCWTForward(J=J, biort=biort, qshift=qshift).to(dev)
    Yl, Yh = xfm(imt)
    ifm = DTCWTInverse(J=J, biort=biort, qshift=qshift).to(dev)
    y = ifm((Yl, Yh))

    # Compare with numpy results
    f_np = Transform2d_np(biort=biort, qshift=qshift)
    yl, yh = f_np.forward(im, nlevels=J)
    y2 = f_np.inverse(yl, yh)
    np.testing.assert_array_almost_equal(y.detach().cpu(), y2,
                                         decimal=PRECISION_FLOAT)

    # Test gradients are ok
    y.backward(torch.ones_like(y))
def __init__(self, wt_type='DTCWT', biort='near_sym_b', qshift='qshift_b',
             J=5, wave='db3', mode='zero', device='cuda', requires_grad=True):
    super().__init__()
    if wt_type == 'DTCWT':
        self.xfm = DTCWTForward(biort=biort, qshift=qshift, J=J).to(device)
        self.ifm = DTCWTInverse(biort=biort, qshift=qshift).to(device)
    elif wt_type == 'DWT':
        self.xfm = DWTForward(wave=wave, J=J, mode=mode).to(device)
        self.ifm = DWTInverse(wave=wave, mode=mode).to(device)
    else:
        raise ValueError('no such type of wavelet transform is supported')

    self.J = J
    self.wt_type = wt_type
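# Hedged sketch contrasting the two branches selected above. For an
# (N, C, H, W) input, DTCWTForward returns a real lowpass plus J bandpass
# tensors with 6 complex orientations stacked on a trailing real/imag axis,
# while DWTForward returns 3 real orientation subbands per scale. The input
# shape and J below are illustrative assumptions.
import torch
from pytorch_wavelets import DTCWTForward, DWTForward

x = torch.randn(1, 3, 64, 64)
yl_c, yh_c = DTCWTForward(J=2)(x)                         # yh_c[j]: (1, 3, 6, H_j, W_j, 2)
yl_r, yh_r = DWTForward(J=2, wave='db3', mode='zero')(x)  # yh_r[j]: (1, 3, 3, H_j, W_j)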
def test_inv_ri_dim(ri_dim):
    Yl = 100 * np.random.randn(3, 5, 64, 64)
    J = 3
    Yhr = [np.random.randn(3, 5, 6, 2**j, 2**j) for j in range(4 + J, 4, -1)]
    Yhi = [np.random.randn(3, 5, 6, 2**j, 2**j) for j in range(4 + J, 4, -1)]
    Yh1 = [yhr + 1j * yhi for yhr, yhi in zip(Yhr, Yhi)]
    Yh2 = [torch.tensor(np.stack((yhr, yhi), axis=ri_dim), dtype=torch.float32,
                        device=dev)
           for yhr, yhi in zip(Yhr, Yhi)]

    # If the real/imag axis is inserted before the orientation axis (axis 2 in
    # Yhr/Yhi), the orientation axis is pushed back by one
    if (ri_dim % 6) <= 2:
        o_dim = 3
    else:
        o_dim = 2

    ifm = DTCWTInverse(J=J, o_dim=o_dim, ri_dim=ri_dim).to(dev)
    X = ifm((torch.tensor(Yl, dtype=torch.float32, device=dev), Yh2))
    f1 = Transform2d_np()
    x = f1.inverse(Yl, Yh1)

    np.testing.assert_array_almost_equal(X.cpu(), x, decimal=PRECISION_FLOAT)
def __init__(self):
    super().__init__()
    self.xfm = DTCWTForward(J=3, C=3)
    self.ifm = DTCWTInverse(J=3, C=3)
    self.sparsify = SparsifyWaveCoeffs2(3, 3)
def main():
    xfm = DTCWTForward(J=1)
    # xfm.h0o = xfm.h0a
    # xfm.h1o = xfm.h1a
    ifm = DTCWTInverse()
    # ifm.g0o = ifm.g0a
    # ifm.g1o = ifm.g1a
    b1 = (ifm.g0o.data.numpy().ravel()[::-1],
          ifm.g1o.data.numpy().ravel()[::-1])
    xfm2 = DTCWTForward(J=1, biort=b1)
    b1 = (np.copy(xfm.h0o.data.numpy().ravel()[::-1]),
          np.copy(xfm.h1o.data.numpy().ravel()[::-1]))
    ifm2 = DTCWTInverse(biort=b1)
    # xfm2 = xfm
    # ifm2 = ifm
    wd = 1e-2
    N = 8
    pad = (N - 1) // 2
    # U = np.exp(-1j*2*np.pi*index/N)
    # Us = 1/N * np.conj(U)
    w = np.random.randn(8, 5, N, N).astype('float32')
    W = torch.randn(8, 5, N, N, requires_grad=True)
    W_hat_lp, (W_hat_bp, ) = xfm(W.data)
    W_hat_lp.requires_grad = True
    W_hat_bp.requires_grad = True
    optim1 = torch.optim.SGD([W, ], lr=0.1, momentum=0.0, weight_decay=0)
    optim2 = torch.optim.SGD([W_hat_lp, W_hat_bp], lr=0.1, momentum=0.0,
                             weight_decay=0)
    # optim1 = torch.optim.Adam([W,], lr=0.01)
    # optim2 = CplxAdam([W_hat,], lr=0.01)
    # optim1 = torch.optim.Adagrad([W,], lr=0.1)
    # optim2 = torch.optim.Adagrad([W_hat,], lr=0.1)
    loss1 = torch.nn.MSELoss()
    loss2 = torch.nn.MSELoss()

    for i in range(10):
        print('Testing step {}'.format(i))
        optim1.zero_grad()
        optim2.zero_grad()

        W2 = ifm((W_hat_lp, (W_hat_bp, )))
        W2.retain_grad()
        np.testing.assert_array_almost_equal(
            W2.detach().numpy(), W.detach().numpy(), decimal=4)

        x = torch.randn(10, 5, 32, 32)
        y1 = F.conv2d(x, W, padding=pad)
        y2 = F.conv2d(x, W2, padding=pad)
        np.testing.assert_array_almost_equal(
            y1.detach().numpy(), y2.detach().numpy(), decimal=4)

        y_target = torch.randn(10, 8, 31, 31)
        output1 = loss1(y1, y_target)
        output2 = loss2(y2, y_target)
        output1.backward()
        output2.backward()

        # Check the gradients are the same before regularization
        np.testing.assert_array_almost_equal(W2.grad.numpy(), W.grad.numpy(),
                                             decimal=4)
        reg1 = wd * reg_loss(W, 'l2')
        reg2 = wd * reg_loss(W_hat_lp, 'l2') + wd * reg_loss(W_hat_bp, 'l2')

        # Do some sanity checks on gradients
        np.testing.assert_array_almost_equal(W.grad.data.numpy(),
                                             W2.grad.data.numpy(), decimal=4)
        # DTCWT coeffs = a, DTCWT grads = da, pixel weights = b, pixel grads = db
        a_lp = W_hat_lp.data.clone()
        a_bp = W_hat_bp.data.clone()
        da_lp = W_hat_lp.grad.data.clone()
        da_bp = W_hat_bp.grad.data.clone()
        b = W.data.clone()
        db = W.grad.data.clone()

        # Check that b -> a and db -> da
        b_lp, (b_bp, ) = xfm(b)
        db_lp, (db_bp, ) = xfm2(db)
        np.testing.assert_array_almost_equal(a_lp, b_lp, decimal=4)
        np.testing.assert_array_almost_equal(a_bp, b_bp, decimal=4)
        np.testing.assert_array_almost_equal(da_lp, db_lp, decimal=4)
        np.testing.assert_array_almost_equal(da_bp, db_bp, decimal=4)

        # Check that a -> b and da -> db
        a = ifm((a_lp, (a_bp, )))
        da = ifm2((da_lp, (da_bp, )))
        np.testing.assert_array_almost_equal(a, b, decimal=4)
        np.testing.assert_array_almost_equal(da, db, decimal=4)

        # Add in regularization
        # reg1.backward()
        # reg2.backward()

        optim1.step()
        optim2.step()

    print('Done! They matched')