def __init__(self, biort='near_sym_a', qshift='qshift_a', mode='symmetric',
             magbias=1e-2, combine_colour=False):
    """Build the magnitude functions, two DTCWTs and the rotated filters.

    Args:
        biort: Filter spec forwarded as ``biort`` to both DTCWTForward
            instances.
        qshift: Filter spec forwarded as ``qshift`` to both DTCWTForward
            instances.
        mode: Padding mode forwarded to both DTCWTForward instances.
        magbias: Passed as ``b`` to every ``MagFn`` created here.
        combine_colour: Stored on ``self.combine_colour``; how it is used is
            decided outside this constructor.
    """
    super().__init__()
    self.combine_colour = combine_colour
    # Five independent MagFn instances, all with the same bias and c=1.
    self.MagFn1 = MagFn(b=magbias, c=1)
    self.MagFn2 = MagFn(b=magbias, c=1)
    self.MagFn1c = MagFn(b=magbias, c=1)
    self.MagFn2c = MagFn(b=magbias, c=1)
    self.MagFn3 = MagFn(b=magbias, c=1)
    # A two-level and a one-level transform sharing filters, padding mode
    # and output layout (o_dim=2, ri_dim=-1).
    self.xfm1 = DTCWTForward(J=2, biort=biort, qshift=qshift, o_dim=2,
                             ri_dim=-1, mode=mode)
    self.xfm2 = DTCWTForward(J=1, biort=biort, qshift=qshift, o_dim=2,
                             ri_dim=-1, mode=mode)
    # Registered as frozen Parameters so they move with the module on
    # .to()/.cuda() and appear in state_dict, but are never trained.
    Hr, Hi = filters_rotated()
    self.Hr = nn.Parameter(Hr, requires_grad=False)
    self.Hi = nn.Parameter(Hi, requires_grad=False)
def test_fwd(J, o_before_c):
    """Compare DTCWTForward's output with the numpy ``Transform2d_np`` result.

    Args:
        J: Number of levels for both transforms.
        o_before_c: Forwarded to DTCWTForward; when truthy, axes 1 and 2 of
            each highpass real/imag slice are swapped before comparing.
    """
    data = 100 * np.random.randn(3, 5, 100, 100)
    data_t = torch.tensor(data, dtype=torch.get_default_dtype(), device=dev)
    transform = DTCWTForward(J=J, o_before_c=o_before_c).to(dev)
    low, highs = transform(data_t)
    ref_low, ref_highs = Transform2d_np().forward(data, nlevels=J)

    np.testing.assert_array_almost_equal(
        low.cpu(), ref_low, decimal=PRECISION_FLOAT)
    for level, ref in enumerate(ref_highs):
        real = highs[level][..., 0].cpu()
        imag = highs[level][..., 1].cpu()
        if o_before_c:
            # Swap axes 1 and 2 so the layout lines up with the reference.
            real, imag = real.transpose(2, 1), imag.transpose(2, 1)
        np.testing.assert_array_almost_equal(
            real, ref.real, decimal=PRECISION_FLOAT)
        np.testing.assert_array_almost_equal(
            imag, ref.imag, decimal=PRECISION_FLOAT)
def test_fwd_ri_dim(o_dim, ri_dim):
    """Check DTCWTForward for a given orientation / real-imag axis layout.

    Each highpass is split into real and imaginary parts along ``ri_dim``.
    The six orientations are then indexed along the orientation axis and
    compared with the matching slice of the numpy ``Transform2d_np`` output.
    """
    nlevels = 3
    data = 100 * np.random.randn(3, 5, 100, 100)
    data_t = torch.tensor(data, dtype=torch.get_default_dtype(), device=dev)
    transform = DTCWTForward(J=nlevels, o_dim=o_dim, ri_dim=ri_dim).to(dev)
    low, highs = transform(data_t)
    ref_low, ref_highs = Transform2d_np().forward(data, nlevels=nlevels)

    np.testing.assert_array_almost_equal(
        low.cpu(), ref_low, decimal=PRECISION_FLOAT)

    # np.take along ri_dim removes that axis, so an orientation axis that
    # came after it shifts down by one. The "% 6" turns a negative ri_dim
    # into a positive index, presuming each highpass tensor has 6 dims
    # (TODO confirm).
    if (ri_dim % 6) < o_dim:
        orient_axis = o_dim - 1
    else:
        orient_axis = o_dim

    for level, ref in enumerate(ref_highs):
        coeffs = highs[level].cpu().numpy()
        real_part = np.take(coeffs, 0, ri_dim)
        imag_part = np.take(coeffs, 1, ri_dim)
        for orient in range(6):
            np.testing.assert_array_almost_equal(
                np.take(real_part, orient, orient_axis),
                ref[:, :, orient].real, decimal=PRECISION_FLOAT)
            np.testing.assert_array_almost_equal(
                np.take(imag_part, orient, orient_axis),
                ref[:, :, orient].imag, decimal=PRECISION_FLOAT)
def test_gradients_fwd(biort, qshift, size, J):
    """Check DTCWTForward's gradient against an inverse with reversed filters.

    A random upstream gradient is back-propagated to the input image. This is
    done first for the lowpass output, then for each highpass output in turn
    (clearing the input gradient in between). Each result must equal
    DTCWTInverse, built from the time-reversed analysis filters, applied to
    that same upstream gradient.
    """
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True,
                       device=dev)
    xfm = DTCWTForward(biort=biort, qshift=qshift, J=J).to(dev)

    # Only the analysis (h*) filters are needed for the reference transform.
    h0o, _, h1o, _ = _biort(biort)
    h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
    xfm_grad = DTCWTInverse(
        biort=(h0o[::-1], h1o[::-1]),
        qshift=(h0a[::-1], h0b[::-1], h1a[::-1], h1b[::-1])).to(dev)

    def check_input_grad(expected):
        np.testing.assert_array_almost_equal(
            imt.grad.detach().cpu(), expected.cpu())

    lowpass, highpasses = xfm(imt)

    lp_grad = torch.randn(*lowpass.shape, device=dev)
    lowpass.backward(lp_grad, retain_graph=True)
    check_input_grad(xfm_grad((lp_grad, [None] * J)))

    for level, band in enumerate(highpasses):
        imt.grad.zero_()  # isolate this level's contribution
        band_grad = torch.randn(*band.shape, device=dev)
        band.backward(band_grad, retain_graph=True)
        hps = [None] * J
        hps[level] = band_grad
        check_input_grad(xfm_grad((torch.zeros_like(lowpass), hps)))
def test_gradients_inv(biort, qshift, size, J):
    """Gradient of the inverse function should be the forward function with
    filters swapped.

    A random gradient ``g`` is back-propagated through DTCWTInverse to random
    lowpass/highpass inputs. Those input gradients must match the output of
    a DTCWTForward applied to ``g``. That forward transform is built from the
    time-reversed synthesis (g*) filters.
    """
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, device=dev)
    ifm = DTCWTInverse(biort=biort, qshift=qshift).to(dev)
    h0o, g0o, h1o, g1o = _biort(biort)
    h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = _qshift(qshift)
    ifm_grad = DTCWTForward(
        J=J, biort=(g0o[::-1], g1o[::-1]),
        qshift=(g0a[::-1], g0b[::-1], g1a[::-1], g1b[::-1])).to(dev)
    # Only used to obtain correctly shaped coefficient tensors.
    yl, yh = ifm_grad(imt)

    g = torch.randn(*imt.shape, device=dev)
    ylv = torch.randn(*yl.shape, requires_grad=True, device=dev)
    yhv = [torch.randn(*h.shape, requires_grad=True, device=dev) for h in yh]
    Y = ifm((ylv, yhv))
    Y.backward(g)

    # Check the lowpass gradient is the same
    ref_lp, ref_bp = ifm_grad(g)
    np.testing.assert_array_almost_equal(ylv.grad.detach().cpu(),
                                         ref_lp.cpu())
    # Check the bandpasses are the same
    for y, ref in zip(yhv, ref_bp):
        np.testing.assert_array_almost_equal(y.grad.detach().cpu(),
                                             ref.cpu())
def test_fwd_skip_hps(J, o_before_c):
    """Check DTCWTForward when a random subset of highpass levels is skipped.

    Skipped levels must come back as empty tensors. The lowpass and every
    level that was not skipped must match the numpy reference.
    """
    data = 100 * np.random.randn(3, 5, 100, 100)
    # Random on/off mask: True means that level's highpass is skipped.
    skip = np.random.binomial(size=J, n=1, p=0.5).astype('bool')
    transform = DTCWTForward(J=J, skip_hps=skip,
                             o_before_c=o_before_c).to(dev)
    low, highs = transform(
        torch.tensor(data, dtype=torch.float32, device=dev))
    ref_low, ref_highs = Transform2d_np().forward(data, nlevels=J)

    np.testing.assert_array_almost_equal(
        low.cpu(), ref_low, decimal=PRECISION_FLOAT)

    for level, skipped in enumerate(skip):
        if skipped:
            assert highs[level].shape == torch.Size([0])
            continue
        real = highs[level][..., 0].cpu()
        imag = highs[level][..., 1].cpu()
        if o_before_c:
            # Swap axes 1 and 2 so the layout lines up with the reference.
            real, imag = real.transpose(2, 1), imag.transpose(2, 1)
        np.testing.assert_array_almost_equal(
            real, ref_highs[level].real, decimal=PRECISION_FLOAT)
        np.testing.assert_array_almost_equal(
            imag, ref_highs[level].imag, decimal=PRECISION_FLOAT)
def forward(size, no_grad, J, no_hp=False, dev='cuda'):
    """Run one DTCWT forward pass (o_dim=1) on random data.

    Args:
        size: Shape of the random input tensor.
        no_grad: If falsy, the input requires grad and an all-ones gradient
            is back-propagated from the lowpass output.
        J: Number of levels.
        no_hp: Forwarded to DTCWTForward as ``skip_hps``.
        dev: Device for the input and the transform.

    Returns:
        The mean of the lowpass output, and a list holding the mean of each
        highpass output.
    """
    track_grad = not no_grad
    inp = torch.randn(*size, requires_grad=track_grad).to(dev)
    transform = DTCWTForward(J=J, skip_hps=no_hp, o_dim=1).to(dev)
    low, highs = transform(inp)
    if track_grad:
        low.backward(torch.ones_like(low))
    return low.mean(), [band.mean() for band in highs]
def forward(size, no_grad, J, no_hp=False, dev='cuda'):
    """Apply a symmetric-mode DTCWT (o_dim=1) five times to one random input.

    When ``no_grad`` is falsy, every pass also back-propagates an all-ones
    gradient from its lowpass output.

    Returns:
        The lowpass and highpass outputs of the last pass.
    """
    track_grad = not no_grad
    inp = torch.randn(*size, requires_grad=track_grad).to(dev)
    transform = DTCWTForward(J=J, skip_hps=no_hp, o_dim=1,
                             mode='symmetric').to(dev)
    repeats = 5
    for _ in range(repeats):
        low, highs = transform(inp)
        if track_grad:
            low.backward(torch.ones_like(low))
    return low, highs
def end_to_end(size, no_grad, J, no_hp=False, dev='cuda'):
    """Run a forward-then-inverse DTCWT round trip on random data.

    When ``no_grad`` is falsy, the input requires grad and an all-ones
    gradient is back-propagated from the reconstruction.

    Returns:
        The reconstructed tensor.
    """
    track_grad = not no_grad
    inp = torch.randn(*size, requires_grad=track_grad).to(dev)
    fwd = DTCWTForward(J=J, skip_hps=no_hp).to(dev)
    inv = DTCWTInverse(J=J).to(dev)
    low, highs = fwd(inp)
    recon = inv((low, highs))
    if track_grad:
        recon.backward(torch.ones_like(recon))
    return recon
def test_bwd_include_scale(scales):
    """Check that every scale output from ``include_scale`` is differentiable.

    ``scales`` holds one flag per level. Its length sets ``J``, and it is
    passed as DTCWTForward's ``include_scale``. Only the absence of errors
    during ``backward`` is checked.
    """
    X = 100 * np.random.randn(3, 5, 100, 100)
    J = len(scales)
    xfm = DTCWTForward(J=J, include_scale=scales).to(dev)
    Ys, Yh = xfm(
        torch.tensor(X, dtype=torch.float32, requires_grad=True, device=dev))
    # All scale outputs come from one forward pass, so the graph must
    # survive several backward calls.
    for y in Ys:
        y.backward(torch.ones_like(y), retain_graph=True)
def __init__(self, stride=2):
    """Set up a one-level DTCWT, a magnitude reshape and resampling callables.

    Args:
        stride: With 2, ``lp`` is a 2x average pool, ``bp`` is the identity,
            and ``avg`` is twice a 2x average pool. With 1, ``lp`` is the
            identity and ``bp`` bilinearly upsamples by 2 and scales by 0.5.

    Raises:
        ValueError: if ``stride`` is neither 1 nor 2.
    """
    super().__init__()
    self.xfm = DTCWTForward(J=1, o_dim=1, ri_dim=2)
    self.mag = MagReshape(b=1e-5, o_dim=1, ri_dim=2)

    def identity(t):
        return t

    if stride == 2:
        self.lp = nn.AvgPool2d(2)
        self.bp = identity
        self.avg = lambda t: 2 * func.avg_pool2d(t, 2)
    elif stride == 1:
        self.lp = identity
        self.bp = lambda t: 0.5 * func.interpolate(
            t, scale_factor=2, mode='bilinear', align_corners=False)
        # NOTE(review): self.avg is only assigned when stride == 2; confirm
        # nothing reads it for stride 1.
    else:
        raise ValueError("Can only do 1 or 2 stride")
def __init__(self, C, F, lp_size=3, bp_sizes=(1, ), biort='near_sym_a',
             qshift='qshift_a', xfm=True, ifm=True, wd=0, wd1=None,
             lp_nl=None, bp_nl=None, lp_nl_kwargs=None, bp_nl_kwargs=None):
    """Assemble forward DTCWT, gain layer, nonlinearity and inverse DTCWT.

    Args:
        C: Stored on ``self.C`` and passed to WaveGainLayer (presumably the
            input channel count; TODO confirm).
        F: Stored on ``self.F`` and passed to WaveGainLayer and
            WaveNonLinearity (presumably the output channel count; TODO
            confirm).
        lp_size: Forwarded to WaveGainLayer.
        bp_sizes: One entry per scale; its length sets ``self.J``. A zero
            entry skips that scale's highpass in the forward DTCWT.
        biort, qshift: Filter specs for the forward and inverse DTCWT.
        xfm: If falsy, ``self.XFM`` is an identity function.
        ifm: If falsy, ``self.IFM`` is an identity function.
        wd, wd1: Stored on the module and forwarded to WaveGainLayer.
        lp_nl: Forwarded to WaveNonLinearity.
        bp_nl: Forwarded to WaveNonLinearity. A value that is not a list or
            tuple is repeated once per scale.
        lp_nl_kwargs, bp_nl_kwargs: Keyword dicts forwarded to
            WaveNonLinearity. ``None`` (the default) means an empty dict.
    """
    super().__init__()
    # Fresh dicts per call. A shared ``{}`` default would leak any mutation
    # made downstream into every later instance.
    if lp_nl_kwargs is None:
        lp_nl_kwargs = {}
    if bp_nl_kwargs is None:
        bp_nl_kwargs = {}
    self.C = C
    self.F = F
    # If any of the mixing for a scale is 0, don't calculate the dtcwt at
    # that scale
    skip_hps = [bool(s == 0) for s in bp_sizes]
    self.J = len(bp_sizes)
    self.wd = wd
    self.wd1 = wd1
    if xfm:
        self.XFM = DTCWTForward(biort=biort, qshift=qshift, J=self.J,
                                skip_hps=skip_hps, o_dim=2, ri_dim=-1)
    else:
        self.XFM = lambda x: x
    self.GainLayer = WaveGainLayer(C, F, lp_size, bp_sizes, wd=wd, wd1=wd1)
    if not isinstance(bp_nl, (list, tuple)):
        bp_nl = [bp_nl] * self.J
    self.NL = WaveNonLinearity(F, lp_nl, bp_nl, lp_nl_kwargs, bp_nl_kwargs)
    if ifm:
        self.IFM = DTCWTInverse(biort=biort, qshift=qshift, o_dim=2,
                                ri_dim=-1)
    else:
        self.IFM = lambda x: x
def test_fwd_include_scale(scales):
    """Check the per-level scale outputs selected by ``include_scale``.

    ``scales`` holds one flag per level, and its length sets ``J``. A level
    whose flag is falsy must return an empty tensor. Every other level must
    match the numpy reference computed with ``include_scale=True``.
    """
    X = 100*np.random.randn(3, 5, 100, 100)
    # One level per flag; the flags choose which levels return a scale output.
    J = len(scales)
    xfm = DTCWTForward(J=J, include_scale=scales).to(dev)
    Ys, Yh = xfm(torch.tensor(X, dtype=torch.float32, device=dev))
    f1 = Transform2d_np()
    yl, yh, ys = f1.forward(X, nlevels=J, include_scale=True)

    for j in range(J):
        if not scales[j]:
            assert Ys[j].shape == torch.Size([0])
        else:
            np.testing.assert_array_almost_equal(
                Ys[j].cpu(), ys[j], decimal=PRECISION_FLOAT)
def test_end2end(biort, qshift, size, J):
    """Round-trip through DTCWTForward/DTCWTInverse and compare with numpy.

    The reconstruction must match ``Transform2d_np``'s forward followed by
    inverse with the same filters. Back-propagating through the round trip
    must not raise.
    """
    im = np.random.randn(5, 6, *size).astype('float32')
    imt = torch.tensor(im, dtype=torch.float32, requires_grad=True,
                       device=dev)
    fwd = DTCWTForward(J=J, biort=biort, qshift=qshift).to(dev)
    low, highs = fwd(imt)
    inv = DTCWTInverse(J=J, biort=biort, qshift=qshift).to(dev)
    recon = inv((low, highs))

    reference = Transform2d_np(biort=biort, qshift=qshift)
    ref_low, ref_highs = reference.forward(im, nlevels=J)
    ref_recon = reference.inverse(ref_low, ref_highs)
    np.testing.assert_array_almost_equal(
        recon.detach().cpu(), ref_recon, decimal=PRECISION_FLOAT)

    # Gradient smoke test: only checks that backward runs.
    recon.backward(torch.ones_like(recon))
def __init__(self, wt_type='DTCWT', biort='near_sym_b', qshift='qshift_b',
             J=5, wave='db3', mode='zero', device='cuda',
             requires_grad=True):
    """Create a matched forward/inverse wavelet transform pair on ``device``.

    Args:
        wt_type: 'DTCWT' (uses ``biort``/``qshift``) or 'DWT' (uses
            ``wave``/``mode``).
        J: Passed to the forward transform and stored on ``self.J``.
        device: Device both transforms are moved to.
        requires_grad: Accepted but not used by this constructor.

    Raises:
        ValueError: for any other ``wt_type``.
    """
    super().__init__()
    # NOTE(review): ``requires_grad`` is ignored here; confirm whether it was
    # meant to freeze or unfreeze the transforms' parameters.
    if wt_type == 'DTCWT':
        analysis = DTCWTForward(biort=biort, qshift=qshift, J=J)
        synthesis = DTCWTInverse(biort=biort, qshift=qshift)
    elif wt_type == 'DWT':
        analysis = DWTForward(wave=wave, J=J, mode=mode)
        synthesis = DWTInverse(wave=wave, mode=mode)
    else:
        raise ValueError('no such type of wavelet transform is supported')
    self.xfm = analysis.to(device)
    self.ifm = synthesis.to(device)
    self.J = J
    self.wt_type = wt_type
def test_fwd_double(J, o_dim):
    """Check DTCWTForward in float64 against the numpy reference.

    The transform runs under ``set_double_precision()``. The lowpass must be
    float64, and every output must match ``Transform2d_np`` to
    ``PRECISION_DOUBLE`` decimals, with orientations indexed along ``o_dim``.
    """
    with set_double_precision():
        data = 100*np.random.randn(3, 5, 100, 100)
        data_t = torch.tensor(data, dtype=torch.get_default_dtype(),
                              device=dev)
        transform = DTCWTForward(J=J, o_dim=o_dim).to(dev)
        low, highs = transform(data_t)
    assert low.dtype == torch.float64

    ref_low, ref_highs = Transform2d_np().forward(data, nlevels=J)
    np.testing.assert_array_almost_equal(
        low.cpu(), ref_low, decimal=PRECISION_DOUBLE)
    for level, ref in enumerate(ref_highs):
        real = highs[level][..., 0].cpu().numpy()
        imag = highs[level][..., 1].cpu().numpy()
        for orient in range(6):
            np.testing.assert_array_almost_equal(
                np.take(real, orient, o_dim), ref[:, :, orient].real,
                decimal=PRECISION_DOUBLE)
            np.testing.assert_array_almost_equal(
                np.take(imag, orient, o_dim), ref[:, :, orient].imag,
                decimal=PRECISION_DOUBLE)
def _init(self):
    """Initialise the wavelet-domain parameters ``u_lp`` and ``uj``.

    A ``(self.F, self.C, self.k, self.k)`` tensor is Xavier-initialised and
    passed through a ``self.J``-level DTCWTForward. The lowpass and the last
    highpass become the parameters ``u_lp`` and ``uj``. When ``self.J == 1``
    the lowpass is first average-pooled by 2 and ``self.downsample`` is set
    to True. ``self.pad`` is picked from ``self.k`` and ``self.right``.

    Returns:
        The (possibly pooled) lowpass and the highpasses from the transform.

    Raises:
        NotImplementedError: if ``self.k`` is not 4 or 8.
    """
    # To get the right coeff sizes, perform a forward dtcwt on a kernel size
    # you ultimately want after reconstruction.
    x = torch.zeros(self.F, self.C, self.k, self.k)
    torch.nn.init.xavier_uniform_(x)
    xfm = DTCWTForward(J=self.J)
    yl, yh = xfm(x)
    self.downsample = False

    # Padding tuples per supported kernel size: (used when self.right,
    # used otherwise). Both sizes previously had duplicated branches that
    # differed only in these values.
    pads = {
        4: ((1, 2, 1, 2), (2, 1, 2, 1)),
        8: ((3, 4, 3, 4), (4, 3, 4, 3)),
    }
    if self.k not in pads:
        raise NotImplementedError

    if self.J == 1:
        self.downsample = True
        yl = func.avg_pool2d(yl, 2)
    self.u_lp = nn.Parameter(torch.zeros_like(yl))
    self.uj = nn.Parameter(torch.zeros_like(yh[-1]))
    # Point the parameters at the transform outputs' storage (yl is also
    # returned, so it shares memory with u_lp).
    self.u_lp.data = yl.data
    self.uj.data = yh[-1].data

    right_pad, left_pad = pads[self.k]
    self.pad = right_pad if self.right else left_pad
    return yl, yh
def test_odd_rows_and_cols():
    """Forward DTCWT on an input with an odd number of rows and columns.

    The last two dims of ``barbara_t`` are cropped to 509 entries. The test
    then checks the output structure instead of only checking that the call
    does not raise.
    """
    xfm = DTCWTForward(J=3).to(dev)
    Yl, Yh = xfm(barbara_t[:, :, :509, :509])
    assert len(Yl.shape) == 4
    assert len(Yh) == 3
    assert Yh[0].shape[-1] == 2
def test_specific_wavelet():
    """Smoke-test DTCWTForward with explicitly named filters on barbara_t."""
    levels = 3
    transform = DTCWTForward(J=levels, biort='antonini',
                             qshift='qshift_06').to(dev)
    lowpass, highpasses = transform(barbara_t)
    assert lowpass.dim() == 4
    assert len(highpasses) == levels
    # Size-2 last axis (presumably the real/imaginary pair; TODO confirm).
    assert highpasses[0].shape[-1] == 2
def test_simple():
    """Smoke-test a 3-level DTCWTForward with default filters on barbara_t."""
    levels = 3
    lowpass, highpasses = DTCWTForward(J=levels).to(dev)(barbara_t)
    assert lowpass.dim() == 4
    assert len(highpasses) == levels
    # Size-2 last axis (presumably the real/imaginary pair; TODO confirm).
    assert highpasses[0].shape[-1] == 2
def main(): xfm = DTCWTForward(J=1) # xfm.h0o = xfm.h0a # xfm.h1o = xfm.h1a ifm = DTCWTInverse() # ifm.g0o = ifm.g0a # ifm.g1o = ifm.g1a b1 = (ifm.g0o.data.numpy().ravel()[::-1], ifm.g1o.data.numpy().ravel()[::-1]) xfm2 = DTCWTForward(J=1, biort=b1) b1 = (np.copy(xfm.h0o.data.numpy().ravel()[::-1]), np.copy(xfm.h1o.data.numpy().ravel()[::-1])) ifm2 = DTCWTInverse(biort=b1) # xfm2 = xfm # ifm2 = ifm wd = 1e-2 N = 8 pad = (N - 1) // 2 # U = np.exp(-1j*2*np.pi*index/N) # Us = 1/N * np.conj(U) w = np.random.randn(8, 5, N, N).astype('float32') W = torch.randn(8, 5, N, N, requires_grad=True) W_hat_lp, (W_hat_bp, ) = xfm(W.data) W_hat_lp.requires_grad = True W_hat_bp.requires_grad = True optim1 = torch.optim.SGD([ W, ], lr=0.1, momentum=0.0, weight_decay=0) optim2 = torch.optim.SGD([W_hat_lp, W_hat_bp], lr=0.1, momentum=0.0, weight_decay=0) # optim1 = torch.optim.Adam([W,], lr=0.01) # optim2 = CplxAdam([W_hat,], lr=0.01) # optim1 = torch.optim.Adagrad([W,], lr=0.1) # optim2 = torch.optim.Adagrad([W_hat,], lr=0.1) loss1 = torch.nn.MSELoss() loss2 = torch.nn.MSELoss() for i in range(10): print('Testing step {}'.format(i)) optim1.zero_grad() optim2.zero_grad() W2 = ifm((W_hat_lp, (W_hat_bp, ))) W2.retain_grad() np.testing.assert_array_almost_equal(W2.detach().numpy(), W.detach().numpy(), decimal=4) x = torch.randn(10, 5, 32, 32) y1 = F.conv2d(x, W, padding=pad) y2 = F.conv2d(x, W2, padding=pad) np.testing.assert_array_almost_equal(y1.detach().numpy(), y2.detach().numpy(), decimal=4) y_target = torch.randn(10, 8, 31, 31) output1 = loss1(y1, y_target) output2 = loss2(y2, y_target) output1.backward() output2.backward() # Check the gradients are the same before regularization np.testing.assert_array_almost_equal(W2.grad.numpy(), W.grad.numpy(), decimal=4) reg1 = wd * reg_loss(W, 'l2') reg2 = wd * reg_loss(W_hat_lp, 'l2') + wd * reg_loss(W_hat_bp, 'l2') # Do some sanity checks on gradients np.testing.assert_array_almost_equal(W.grad.data.numpy(), W2.grad.data.numpy(), decimal=4) 
# DTCWT = a, DTCWT_grad = b, pixel = c, pixel_grad = d a_lp = W_hat_lp.data.clone() a_bp = W_hat_bp.data.clone() da_lp = W_hat_lp.grad.data.clone() da_bp = W_hat_bp.grad.data.clone() b = W.data.clone() db = W.grad.data.clone() # Check that c -> a and d -> b b_lp, (b_bp, ) = xfm(b) db_lp, (db_bp, ) = xfm2(db) np.testing.assert_array_almost_equal(a_lp, b_lp, decimal=4) np.testing.assert_array_almost_equal(a_bp, b_bp, decimal=4) np.testing.assert_array_almost_equal(da_lp, db_lp, decimal=4) np.testing.assert_array_almost_equal(da_bp, db_bp, decimal=4) # Check that a -> c and b -> d a = ifm((a_lp, (a_bp, ))) da = ifm2((da_lp, (da_bp, ))) np.testing.assert_array_almost_equal(a, b, decimal=4) np.testing.assert_array_almost_equal(da, db, decimal=4) # Add in regularization # reg1.backward() # reg2.backward() optim1.step() optim2.step() print('Done! They matched')
def __init__(self):
    """Build a 3-level forward/inverse DTCWT pair (C=3) and a sparsifier.

    The sparsifier is ``SparsifyWaveCoeffs2(3, 3)``.
    """
    super().__init__()
    self.xfm = DTCWTForward(C=3, J=3)
    self.ifm = DTCWTInverse(C=3, J=3)
    self.sparsify = SparsifyWaveCoeffs2(3, 3)
def forward(size, no_grad, J, no_hp=False, dev='cuda'):
    """Run one default-layout DTCWT forward pass on random data.

    When ``no_grad`` is falsy, the input requires grad and an all-ones
    gradient is back-propagated from the lowpass output. Nothing is returned.
    """
    track_grad = not no_grad
    inp = torch.randn(*size, requires_grad=track_grad).to(dev)
    transform = DTCWTForward(J=J, skip_hps=no_hp).to(dev)
    low, _ = transform(inp)
    if track_grad:
        low.backward(torch.ones_like(low))