def __init__(self, biort='near_sym_a', qshift='qshift_a', mode='symmetric',
             magbias=1e-2, combine_colour=False):
    """Store the configuration and build the fixed (non-trainable) filters.

    Args:
        biort: name of the level-1 biorthogonal filter set, passed to
            ``_biort``.
        qshift: name of the q-shift filter set, passed to ``_qshift``.
            Must be ``'qshift_b_bp'`` when ``biort`` is ``'near_sym_b_bp'``.
        mode: padding mode name. Kept verbatim in ``self.mode_str`` and as
            an integer code (via ``mode_to_int``) in ``self.mode``.
        magbias: stored as ``self.magbias`` (presumably a bias used when
            taking magnitudes — TODO confirm in ``forward``).
        combine_colour: stored as ``self.combine_colour``.

    Every filter is registered as a ``torch.nn.Parameter`` with
    ``requires_grad=False``.
    """
    super().__init__()
    self.biort = biort
    # Previously this stored ``biort`` by mistake, so ``self.qshift`` never
    # reflected the q-shift filters actually in use.
    self.qshift = qshift
    # Have to convert the string to an int as the grad checks don't work
    # with string inputs
    self.mode_str = mode
    self.mode = mode_to_int(mode)
    self.magbias = magbias
    self.combine_colour = combine_colour

    def fixed(filt):
        # Wrap a prepared filter as a non-trainable module parameter.
        return torch.nn.Parameter(prep_filt(filt, 1), False)

    if biort == 'near_sym_b_bp':
        # The bandpass-diagonal variant needs the matching q-shift set,
        # which also provides the extra h2 filters.
        assert qshift == 'qshift_b_bp'
        self.bandpass_diag = True
        h0o, _, h1o, _, h2o, _ = _biort(biort)
        self.h0o = fixed(h0o)
        self.h1o = fixed(h1o)
        self.h2o = fixed(h2o)
        h0a, h0b, _, _, h1a, h1b, _, _, h2a, h2b, _, _ = _qshift('qshift_b_bp')
        self.h0a = fixed(h0a)
        self.h0b = fixed(h0b)
        self.h1a = fixed(h1a)
        self.h1b = fixed(h1b)
        self.h2a = fixed(h2a)
        self.h2b = fixed(h2b)
    else:
        self.bandpass_diag = False
        h0o, _, h1o, _ = _biort(biort)
        self.h0o = fixed(h0o)
        self.h1o = fixed(h1o)
        h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
        self.h0a = fixed(h0a)
        self.h0b = fixed(h0b)
        self.h1a = fixed(h1a)
        self.h1b = fixed(h1b)
def test_gradients_fwd(biort, qshift, size, J):
    """Backprop through DTCWTForward must equal DTCWTInverse run with the
    time-reversed analysis filters.

    The lowpass output is checked first, then each bandpass scale on its
    own, with the lowpass set to zeros and the other scales set to None.
    """
    x_np = np.random.randn(5, 6, *size).astype('float32')
    x = torch.tensor(x_np, dtype=torch.float32, requires_grad=True,
                     device=dev)
    fwd = DTCWTForward(biort=biort, qshift=qshift, J=J).to(dev)

    # Only the analysis filters are needed; reverse them to get the adjoint.
    h0o, _, h1o, _ = _biort(biort)
    h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
    adjoint = DTCWTInverse(
        biort=tuple(f[::-1] for f in (h0o, h1o)),
        qshift=tuple(f[::-1] for f in (h0a, h0b, h1a, h1b)),
    ).to(dev)

    lowpass, highpasses = fwd(x)

    def assert_grad_is(expected):
        np.testing.assert_array_almost_equal(x.grad.detach().cpu(),
                                             expected.cpu())

    # Lowpass only.
    lp_seed = torch.randn(*lowpass.shape, device=dev)
    lowpass.backward(lp_seed, retain_graph=True)
    assert_grad_is(adjoint((lp_seed, [None] * J)))

    # One bandpass scale at a time.
    for scale, band in enumerate(highpasses):
        x.grad.zero_()
        seed = torch.randn(*band.shape, device=dev)
        band.backward(seed, retain_graph=True)
        bands_in = [seed if k == scale else None for k in range(J)]
        assert_grad_is(adjoint((torch.zeros_like(lowpass), bands_in)))
def test_gradients_inv(biort, qshift, size, J):
    """Backprop through DTCWTInverse must equal DTCWTForward run with the
    time-reversed synthesis filters."""
    x_np = np.random.randn(5, 6, *size).astype('float32')
    x = torch.tensor(x_np, dtype=torch.float32, device=dev)
    inv = DTCWTInverse(biort=biort, qshift=qshift).to(dev)

    # Only the synthesis filters matter; reverse them to get the adjoint.
    _, g0o, _, g1o = _biort(biort)
    _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
    adjoint = DTCWTForward(
        J=J,
        biort=tuple(f[::-1] for f in (g0o, g1o)),
        qshift=tuple(f[::-1] for f in (g0a, g0b, g1a, g1b)),
    ).to(dev)

    # Run the adjoint once just to learn the coefficient shapes.
    yl0, yh0 = adjoint(x)
    seed = torch.randn(*x.shape, device=dev)
    lp_in = torch.randn(*yl0.shape, requires_grad=True, device=dev)
    hp_in = [torch.randn(*band.shape, requires_grad=True, device=dev)
             for band in yh0]
    inv((lp_in, hp_in)).backward(seed)

    expected_lp, expected_hps = adjoint(seed)
    # Lowpass gradient
    np.testing.assert_array_almost_equal(lp_in.grad.detach().cpu(),
                                         expected_lp.cpu())
    # Bandpass gradients, scale by scale
    for got, want in zip(hp_in, expected_hps):
        np.testing.assert_array_almost_equal(got.grad.detach().cpu(),
                                             want.cpu())
def __init__(self, biort='farras', qshift='qshift_a', mode='symmetric'):
    """Build the synthesis side of a dual-tree transform from inverse DWTs.

    Two stages of four ``DWTInverse`` modules are created, one module per
    (a|b, a|b) tree pairing (aa, ab, ba, bb):

    * ``self.level1`` uses the level-1 synthesis filters (``g0*1``/``g1*1``).
    * ``self.level2`` uses the q-shift synthesis filters.

    Args:
        biort: level-1 filter-set name (looked up with ``_level1``) or an
            8-element sequence in the same layout.
        qshift: q-shift filter-set name (looked up with ``_qshift``) or an
            8-element sequence in the same layout.
        mode: padding mode forwarded to every ``DWTInverse``.
    """
    super().__init__()
    self.biort = biort
    self.qshift = qshift

    # Fixed: this previously called ``level1(biort)``; the filter lookup
    # helper used throughout this code is ``_level1``.
    if isinstance(biort, str):
        biort = _level1(biort)
    assert len(biort) == 8
    _, _, g0a1, g0b1, _, _, g1a1, g1b1 = biort
    IWTaa1 = DWTInverse(wave=(g0a1, g1a1, g0a1, g1a1), mode=mode)
    IWTab1 = DWTInverse(wave=(g0a1, g1a1, g0b1, g1b1), mode=mode)
    IWTba1 = DWTInverse(wave=(g0b1, g1b1, g0a1, g1a1), mode=mode)
    IWTbb1 = DWTInverse(wave=(g0b1, g1b1, g0b1, g1b1), mode=mode)
    self.level1 = nn.ModuleList([IWTaa1, IWTab1, IWTba1, IWTbb1])

    if isinstance(qshift, str):
        qshift = _qshift(qshift)
    assert len(qshift) == 8
    _, _, g0a, g0b, _, _, g1a, g1b = qshift
    IWTaa = DWTInverse(wave=(g0a, g1a, g0a, g1a), mode=mode)
    IWTab = DWTInverse(wave=(g0a, g1a, g0b, g1b), mode=mode)
    IWTba = DWTInverse(wave=(g0b, g1b, g0a, g1a), mode=mode)
    IWTbb = DWTInverse(wave=(g0b, g1b, g0b, g1b), mode=mode)
    self.level2 = nn.ModuleList([IWTaa, IWTab, IWTba, IWTbb])
def __init__(self, biort='near_sym_a', qshift='qshift_a', o_dim=2,
             ri_dim=-1, mode='symmetric'):
    """Set up the fixed synthesis filters for an inverse DTCWT.

    ``biort`` / ``qshift`` may be filter-set names (looked up with
    ``_biort`` / ``_qshift``) or indexable sequences holding
    ``(g0o, g1o)`` and ``(g0a, g0b, g1a, g1b)`` respectively. Each filter
    becomes a ``torch.nn.Parameter`` with ``requires_grad=False``.
    """
    super().__init__()
    self.biort = biort
    self.qshift = qshift
    self.o_dim = o_dim
    self.ri_dim = ri_dim
    self.mode = mode

    if isinstance(biort, str):
        _, g0o, _, g1o = _biort(biort)
    else:
        g0o, g1o = biort[0], biort[1]
    for name, filt in (('g0o', g0o), ('g1o', g1o)):
        setattr(self, name, torch.nn.Parameter(prep_filt(filt, 1), False))

    if isinstance(qshift, str):
        _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
    else:
        g0a, g0b, g1a, g1b = qshift[0], qshift[1], qshift[2], qshift[3]
    for name, filt in (('g0a', g0a), ('g0b', g0b), ('g1a', g1a),
                       ('g1b', g1b)):
        setattr(self, name, torch.nn.Parameter(prep_filt(filt, 1), False))
def __init__(self, biort='near_sym_a', qshift='qshift_a', J=3, o_dim=2,
             ri_dim=-1):
    """Set up fixed synthesis filters and pick the ``J``-level inverse.

    Filters come from the named sets (``_biort`` / ``_qshift``) or from the
    given sequences ``(g0o, g1o)`` / ``(g0a, g0b, g1a, g1b)``. They are
    stored as non-trainable ``torch.nn.Parameter`` objects. The
    implementation is looked up on ``tf`` as ``ifm<J>``.
    """
    super().__init__()
    self.biort = biort
    self.qshift = qshift
    self.o_dim = o_dim
    self.ri_dim = ri_dim
    self.J = J

    def fixed(filt):
        # Wrap a prepared filter as a non-trainable module parameter.
        return torch.nn.Parameter(prep_filt(filt, 1), False)

    if isinstance(biort, str):
        _, g0o, _, g1o = _biort(biort)
    else:
        g0o, g1o = biort[0], biort[1]
    self.g0o, self.g1o = fixed(g0o), fixed(g1o)

    if isinstance(qshift, str):
        _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
    else:
        g0a, g0b, g1a, g1b = qshift[0], qshift[1], qshift[2], qshift[3]
    self.g0a, self.g0b = fixed(g0a), fixed(g0b)
    self.g1a, self.g1b = fixed(g1a), fixed(g1b)

    # Create the function to do the DTCWT
    self.dtcwt_func = getattr(tf, 'ifm{J}'.format(J=J))
def __init__(self, biort='farras', qshift='qshift_a', J=3, mode='symmetric'):
    """Build a forward dual-tree transform from separable DWTs.

    ``self.level1`` holds four single-level ``DWTForward`` modules built
    from the level-1 analysis filters, one per (a|b, a|b) tree pairing.
    When ``J > 1``, ``self.level2`` holds four ``DWTForward`` modules with
    ``J - 1`` levels built from the q-shift analysis filters. For
    ``J <= 1`` no ``level2`` attribute is created.
    """
    super().__init__()
    self.biort = biort
    self.qshift = qshift
    self.J = J

    first = _level1(biort) if isinstance(biort, str) else biort
    assert len(first) == 8
    a0, b0, _, _, a1, b1, _, _ = first
    # Filter quadruples for the aa, ab, ba, bb pairings.
    combos = ((a0, a1, a0, a1), (a0, a1, b0, b1),
              (b0, b1, a0, a1), (b0, b1, b0, b1))
    self.level1 = nn.ModuleList(
        [DWTForward(J=1, wave=wave, mode=mode) for wave in combos])

    if J <= 1:
        return

    second = _qshift(qshift) if isinstance(qshift, str) else qshift
    assert len(second) == 8
    qa0, qb0, _, _, qa1, qb1, _, _ = second
    combos = ((qa0, qa1, qa0, qa1), (qa0, qa1, qb0, qb1),
              (qb0, qb1, qa0, qa1), (qb0, qb1, qb0, qb1))
    self.level2 = nn.ModuleList(
        [DWTForward(J - 1, wave, mode=mode) for wave in combos])
def __init__(self, C=None, biort='near_sym_a', qshift='qshift_a', J=3,
             o_before_c=False):
    """Set up fixed synthesis filters and pick the ``J``-level inverse.

    ``C`` is accepted only for backwards compatibility; passing it emits a
    deprecation warning and it is otherwise ignored. Filters are stored as
    non-trainable ``torch.nn.Parameter`` objects, and the implementation is
    looked up on ``tf`` as ``ifm<J>``.
    """
    super().__init__()
    if C is not None:
        warnings.warn('C parameter is deprecated. do not need to pass it')

    self.biort = biort
    self.qshift = qshift
    self.o_before_c = o_before_c
    self.J = J

    if isinstance(biort, str):
        _, g0o, _, g1o = _biort(biort)
        lvl1 = (g0o, g1o)
    else:
        lvl1 = (biort[0], biort[1])
    for name, filt in zip(('g0o', 'g1o'), lvl1):
        setattr(self, name, torch.nn.Parameter(prep_filt(filt, 1), False))

    if isinstance(qshift, str):
        _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
        lvl2 = (g0a, g0b, g1a, g1b)
    else:
        lvl2 = (qshift[0], qshift[1], qshift[2], qshift[3])
    for name, filt in zip(('g0a', 'g0b', 'g1a', 'g1b'), lvl2):
        setattr(self, name, torch.nn.Parameter(prep_filt(filt, 1), False))

    # Create the function to do the DTCWT
    self.dtcwt_func = getattr(tf, 'ifm{J}'.format(J=J))
def __init__(self, biort='near_sym_a', qshift='qshift_a', o_dim=2,
             ri_dim=-1, mode='symmetric'):
    """Set up the synthesis filters for an inverse DTCWT as module buffers.

    ``biort`` / ``qshift`` may be filter-set names (looked up with
    ``_biort`` / ``_qshift``) or indexable sequences holding
    ``(g0o, g1o)`` and ``(g0a, g0b, g1a, g1b)``. Each prepared filter is
    registered with ``register_buffer`` rather than as a parameter.
    """
    super().__init__()
    self.biort = biort
    self.qshift = qshift
    self.o_dim = o_dim
    self.ri_dim = ri_dim
    self.mode = mode

    if isinstance(biort, str):
        _, g0o, _, g1o = _biort(biort)
        lvl1 = (g0o, g1o)
    else:
        lvl1 = (biort[0], biort[1])
    for name, filt in zip(('g0o', 'g1o'), lvl1):
        self.register_buffer(name, prep_filt(filt, 1))

    if isinstance(qshift, str):
        _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
        lvl2 = (g0a, g0b, g1a, g1b)
    else:
        lvl2 = (qshift[0], qshift[1], qshift[2], qshift[3])
    for name, filt in zip(('g0a', 'g0b', 'g1a', 'g1b'), lvl2):
        self.register_buffer(name, prep_filt(filt, 1))
def __init__(self, biort='near_sym_a', qshift='qshift_a', J=3,
             skip_hps=False, include_scale=False, o_dim=2, ri_dim=-1,
             mode='symmetric'):
    """Build a forward DTCWT whose analysis filters are trainable.

    Every filter (``h0o, h1o, h0a, h0b, h1a, h1b``) is registered as an
    ``nn.Parameter`` with ``requires_grad=True``. ``skip_hps`` and
    ``include_scale`` are either used as given (list/tuple/ndarray) or
    repeated once per level.

    Raises:
        ValueError: if ``o_dim`` and ``ri_dim`` are the same dimension.
    """
    super().__init__()
    if o_dim == ri_dim:
        raise ValueError("Orientations and real/imaginary parts must be "
                         "in different dimensions.")
    self.biort = biort
    self.qshift = qshift
    self.J = J
    self.o_dim = o_dim
    self.ri_dim = ri_dim
    self.mode = mode

    if isinstance(biort, str):
        h0o, _, h1o, _ = _biort(biort)
        lvl1 = (h0o, h1o)
    else:
        lvl1 = (biort[0], biort[1])
    prepared = [prep_filt(f, 1) for f in lvl1]

    if isinstance(qshift, str):
        h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
        lvl2 = (h0a, h0b, h1a, h1b)
    else:
        lvl2 = (qshift[0], qshift[1], qshift[2], qshift[3])
    prepared += [prep_filt(f, 1) for f in lvl2]

    names = ('h0o', 'h1o', 'h0a', 'h0b', 'h1a', 'h1b')
    for name, filt in zip(names, prepared):
        setattr(self, name, nn.Parameter(filt, requires_grad=True))

    # Per-level options: sequences are used as given, scalars are repeated
    # once per level.
    if isinstance(skip_hps, (list, tuple, ndarray)):
        self.skip_hps = skip_hps
    else:
        self.skip_hps = [skip_hps] * self.J
    if isinstance(include_scale, (list, tuple, ndarray)):
        self.include_scale = include_scale
    else:
        self.include_scale = [include_scale] * self.J
def test_equal_small_in():
    """Row-filtering a tiny 4x4 crop must match the NumPy reference."""
    filt = _qshift('qshift_b')[0]
    patch = barbara[:, 0:4, 0:4]
    patch_t = torch.tensor(patch, dtype=torch.float32).unsqueeze(0).to(dev)
    expected = ref_rowfilter(patch, filt)
    out = rowfilter(patch_t, prep_filt(filt, 1).to(dev))
    np.testing.assert_array_almost_equal(out[0].cpu(), expected, decimal=4)
def test_equal_numpy_qshift2():
    """Row-filtering a non-square crop with odd sides (355 x 371) using the
    qshift_c filter must match the NumPy reference."""
    h = _qshift('qshift_c')[0]
    im = barbara[:, 52:407, 30:401]
    im_t = torch.unsqueeze(torch.tensor(im, dtype=torch.float32), dim=0).to(dev)
    ref = ref_rowfilter(im, h)
    y = rowfilter(im_t, prep_filt(h, 1).to(dev))
    # Move the result to host memory first: NumPy cannot convert a CUDA
    # tensor, so without .cpu() this test fails whenever dev is a GPU.
    np.testing.assert_array_almost_equal(y[0].cpu(), ref, decimal=4)
def __init__(self, biort='near_sym_a', qshift='qshift_a', J=3,
             skip_hps=False, include_scale=False, downsample=False,
             o_dim=2, ri_dim=-1):
    """Build a forward DTCWT with fixed analysis filters.

    Filters come from the named sets (``_biort`` / ``_qshift``) or from the
    given sequences ``(h0o, h1o)`` / ``(h0a, h0b, h1a, h1b)``. They are
    stored as non-trainable ``torch.nn.Parameter`` objects. The
    implementation is looked up on ``tf`` as ``xfm<J>scale`` when any level
    requests its scale output, otherwise as ``xfm<J>``.

    Raises:
        ValueError: if ``o_dim`` and ``ri_dim`` are the same dimension.
    """
    super().__init__()
    if o_dim == ri_dim:
        raise ValueError("Orientations and real/imaginary parts must be "
                         "in different dimensions.")
    self.biort = biort
    self.qshift = qshift
    self.J = J
    self.downsample = downsample
    self.o_dim = o_dim
    self.ri_dim = ri_dim

    def frozen(filt):
        # Wrap a prepared filter as a non-trainable module parameter.
        return torch.nn.Parameter(prep_filt(filt, 1), False)

    if isinstance(biort, str):
        h0o, _, h1o, _ = _biort(biort)
    else:
        h0o, h1o = biort[0], biort[1]
    self.h0o = frozen(h0o)
    self.h1o = frozen(h1o)

    if isinstance(qshift, str):
        h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
    else:
        h0a, h0b, h1a, h1b = qshift[0], qshift[1], qshift[2], qshift[3]
    self.h0a = frozen(h0a)
    self.h0b = frozen(h0b)
    self.h1a = frozen(h1a)
    self.h1b = frozen(h1b)

    # Per-level options: sequences are used as given, scalars are repeated.
    per_level = (list, tuple, ndarray)
    self.skip_hps = (skip_hps if isinstance(skip_hps, per_level)
                     else [skip_hps] * self.J)
    self.include_scale = (include_scale
                          if isinstance(include_scale, per_level)
                          else [include_scale] * self.J)

    # Get the function to do the DTCWT
    template = 'xfm{J}scale' if True in self.include_scale else 'xfm{J}'
    self.dtcwt_func = getattr(tf, template.format(J=J))
def icplxdual2D(yl, yh, level1='farras', qshift='qshift_a',
                mode='periodization'):
    """Invert the complex dual-tree 2-D transform produced by ``cplxdual2D``.

    Args:
        yl: 2x2 nested list of lowpass tensors, ``yl[m][n]`` for tree
            ``(m, n)`` (the ``lows`` output of ``cplxdual2D``).
        yh: list of J bandpass tensors, finest level first (``yh[0]`` is
            reconstructed with the level-1 filters). Each is indexed as
            ``yh[j][:, o, :, :, :, c]``: ``o`` in 0..5 is the orientation,
            ``c`` is 0 for the real part and 1 for the imaginary part.
        level1: name of the first-level filter set, passed to ``_level1``.
        qshift: name of the q-shift filter set used for levels 2..J.
        mode: padding mode forwarded to ``sfb2d``.

    Returns:
        The reconstruction: the sum of the four trees' inverse DWTs,
        divided by 2.
    """
    # Get the filters
    # NOTE: despite the analysis-style names, ``Faf`` and ``af`` hold
    # *synthesis* filters (built with prep_filt_sfb2d). ``Faf`` is used only
    # for the first (finest) level; ``af`` for every coarser level.
    _, _, g0a1, g0b1, _, _, g1a1, g1b1 = _level1(level1)
    _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
    dev = yl[0][0].device
    Faf = ((prep_filt_sfb2d(g0a1, g1a1, g0a1, g1a1, device=dev),
            prep_filt_sfb2d(g0a1, g1a1, g0b1, g1b1, device=dev)),
           (prep_filt_sfb2d(g0b1, g1b1, g0a1, g1a1, device=dev),
            prep_filt_sfb2d(g0b1, g1b1, g0b1, g1b1, device=dev)))
    af = ((prep_filt_sfb2d(g0a, g1a, g0a, g1a, device=dev),
           prep_filt_sfb2d(g0a, g1a, g0b, g1b, device=dev)),
          (prep_filt_sfb2d(g0b, g1b, g0a, g1a, device=dev),
           prep_filt_sfb2d(g0b, g1b, g0b, g1b, device=dev)))

    # Convert the highs back to subbands
    J = len(yh)
    # w[j][m][n][b]: level j, tree (m, n), bandpass subband b in 0..2.
    w = [[[[None for i in range(3)] for j in range(2)] for k in range(2)]
         for l in range(J)]
    for j in range(J):
        # Each pm call undoes one pairing made in cplxdual2D: tree (0, 0)
        # is recovered together with tree (1, 1), and (0, 1) with (1, 0).
        # The orientation indices follow cplxdual2D's stacking order
        # (15, 45, 75, 105, 135, 165 degrees). pm is assumed to be a scaled
        # sum/difference that is its own inverse — TODO confirm.
        w[j][0][0][0], w[j][1][1][0] = pm(yh[j][:,2,:,:,:,0], yh[j][:,3,:,:,:,1])
        w[j][0][1][0], w[j][1][0][0] = pm(yh[j][:,3,:,:,:,0], yh[j][:,2,:,:,:,1])
        w[j][0][0][1], w[j][1][1][1] = pm(yh[j][:,0,:,:,:,0], yh[j][:,5,:,:,:,1])
        w[j][0][1][1], w[j][1][0][1] = pm(yh[j][:,5,:,:,:,0], yh[j][:,0,:,:,:,1])
        w[j][0][0][2], w[j][1][1][2] = pm(yh[j][:,1,:,:,:,0], yh[j][:,4,:,:,:,1])
        w[j][0][1][2], w[j][1][0][2] = pm(yh[j][:,4,:,:,:,0], yh[j][:,1,:,:,:,1])
        # Pack each tree's three subbands into one tensor along dim 2, the
        # layout handed to sfb2d below.
        w[j][0][0] = torch.stack(w[j][0][0], dim=2)
        w[j][0][1] = torch.stack(w[j][0][1], dim=2)
        w[j][1][0] = torch.stack(w[j][1][0], dim=2)
        w[j][1][1] = torch.stack(w[j][1][1], dim=2)

    y = None
    for m in range(2):
        for n in range(2):
            lo = yl[m][n]
            # Coarse-to-fine: levels J..2 with the q-shift synthesis filters,
            # then level 1 with the first-level synthesis filters.
            for j in range(J-1, 0, -1):
                lo = sfb2d(lo, w[j][m][n], af[m][n], mode=mode)
            lo = sfb2d(lo, w[0][m][n], Faf[m][n], mode=mode)

            # Add to the output
            if y is None:
                y = lo
            else:
                y = y + lo

    # Normalize: the four trees are summed, so halve the result
    # (presumably the counterpart of the x / 2 in cplxdual2D — confirm).
    y = y/2
    return y
def __init__(self, C=None, biort='near_sym_a', qshift='qshift_a', J=3,
             skip_hps=False, o_before_c=False, include_scale=False,
             downsample=False):
    """Build a forward DTCWT with fixed analysis filters.

    ``C`` is accepted only for backwards compatibility; passing it emits a
    deprecation warning and it is otherwise ignored. Filters are stored as
    non-trainable ``torch.nn.Parameter`` objects. The implementation is
    looked up on ``tf`` as ``xfm<J>scale`` when any level requests its
    scale output, otherwise as ``xfm<J>``.
    """
    super().__init__()
    if C is not None:
        warnings.warn('C parameter is deprecated. do not need to pass it '
                      'anymore.')

    self.biort = biort
    self.qshift = qshift
    self.o_before_c = o_before_c
    self.J = J
    self.downsample = downsample

    if isinstance(biort, str):
        h0o, _, h1o, _ = _biort(biort)
        lvl1 = (h0o, h1o)
    else:
        lvl1 = (biort[0], biort[1])
    for name, filt in zip(('h0o', 'h1o'), lvl1):
        setattr(self, name, torch.nn.Parameter(prep_filt(filt, 1), False))

    if isinstance(qshift, str):
        h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
        lvl2 = (h0a, h0b, h1a, h1b)
    else:
        lvl2 = (qshift[0], qshift[1], qshift[2], qshift[3])
    for name, filt in zip(('h0a', 'h0b', 'h1a', 'h1b'), lvl2):
        setattr(self, name, torch.nn.Parameter(prep_filt(filt, 1), False))

    # Per-level options: sequences are used as given, scalars are repeated.
    self.skip_hps = (skip_hps
                     if isinstance(skip_hps, (list, tuple, ndarray))
                     else [skip_hps] * self.J)
    self.include_scale = (include_scale
                          if isinstance(include_scale, (list, tuple, ndarray))
                          else [include_scale] * self.J)

    # Get the function to do the DTCWT
    if True in self.include_scale:
        func_name = 'xfm{J}scale'.format(J=J)
    else:
        func_name = 'xfm{J}'.format(J=J)
    self.dtcwt_func = getattr(tf, func_name)
def cplxdual2D(x, J, level1='farras', qshift='qshift_a', mode='periodization',
               mag=False):
    """Forward complex dual-tree 2-D wavelet transform built from four DWTs.

    Four fully decimated separable DWTs are run, one per tree pairing
    ``(m, n)`` with ``m, n`` in {0, 1}. The first level uses the ``level1``
    filters and levels 2..J use the ``qshift`` filters. The trees'
    bandpass subbands are then combined pairwise with ``pm`` into real and
    imaginary parts for six orientations.

    Args:
        x: input tensor; it is scaled by 1/2 before filtering. Presumably
            (N, C, H, W) — TODO confirm against ``afb2d``.
        J: number of decomposition levels.
        level1: name of the first-level filter set, passed to ``_level1``.
        qshift: name of the q-shift filter set, passed to ``_qshift``.
        mode: padding mode forwarded to ``afb2d``.
        mag: if True, return magnitudes instead of real/imaginary pairs.

    Returns:
        lows: 2x2 list of lists; ``lows[m][n]`` is tree ``(m, n)``'s final
            lowpass tensor.
        yh: list of ``J`` tensors, finest level first. Dim 1 indexes the
            six orientations in the order 15, 45, 75, 105, 135, 165
            degrees. If ``mag`` is False, a trailing dim of size 2 holds
            (real, imaginary). If ``mag`` is True there is no trailing dim,
            and the entries are ``sqrt(real**2 + imag**2 + 0.01) - sqrt(0.01)``.
    """
    x = x / 2
    # Get the filters
    h0a1, h0b1, _, _, h1a1, h1b1, _, _ = _level1(level1)
    h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
    # Faf: first-level analysis filters per tree (m, n); af: same for the
    # coarser levels.
    Faf = ((prep_filt_afb2d(h0a1, h1a1, h0a1, h1a1, device=x.device),
            prep_filt_afb2d(h0a1, h1a1, h0b1, h1b1, device=x.device)),
           (prep_filt_afb2d(h0b1, h1b1, h0a1, h1a1, device=x.device),
            prep_filt_afb2d(h0b1, h1b1, h0b1, h1b1, device=x.device)))
    af = ((prep_filt_afb2d(h0a, h1a, h0a, h1a, device=x.device),
           prep_filt_afb2d(h0a, h1a, h0b, h1b, device=x.device)),
          (prep_filt_afb2d(h0b, h1b, h0a, h1a, device=x.device),
           prep_filt_afb2d(h0b, h1b, h0b, h1b, device=x.device)))

    # Do 4 fully decimated dwts
    # w[j][m][n] will hold the list of three bandpass subbands of tree
    # (m, n) at level j.
    w = [[[None for _ in range(2)] for _ in range(2)] for j in range(J)]
    lows = [[None for _ in range(2)] for _ in range(2)]
    for m in range(2):
        for n in range(2):
            # Do the first level transform with the first level filters
            bands = afb2d(x, Faf[m][n], mode=mode)
            # Separate the low and bandpasses: afb2d's channels are regrouped
            # in blocks of 4. Index 0 is the lowpass (fed to the next level);
            # indices 1-3 are the bandpasses, kept in the order 2, 1, 3
            # (presumably lh, hl, hh — TODO confirm against afb2d).
            s = bands.shape
            bands = bands.reshape(s[0], -1, 4, s[-2], s[-1])
            ll = bands[:, :, 0].contiguous()
            w[0][m][n] = [bands[:, :, 2], bands[:, :, 1], bands[:, :, 3]]

            # Do the second+ level transform with the second level filters
            for j in range(1, J):
                bands = afb2d(ll, af[m][n], mode=mode)
                # Separate the low and bandpasses
                s = bands.shape
                bands = bands.reshape(s[0], -1, 4, s[-2], s[-1])
                ll = bands[:, :, 0].contiguous()
                w[j][m][n] = [bands[:, :, 2], bands[:, :, 1], bands[:, :, 3]]
            lows[m][n] = ll

    # Convert the quads into real and imaginary parts
    # Tree (0, 0) is combined with tree (1, 1), and (0, 1) with (1, 0).
    # pm is assumed to return a scaled sum and difference — TODO confirm.
    yh = [None, ] * J
    for j in range(J):
        deg75r, deg105i = pm(w[j][0][0][0], w[j][1][1][0])
        deg105r, deg75i = pm(w[j][0][1][0], w[j][1][0][0])
        deg15r, deg165i = pm(w[j][0][0][1], w[j][1][1][1])
        deg165r, deg15i = pm(w[j][0][1][1], w[j][1][0][1])
        deg135r, deg45i = pm(w[j][0][0][2], w[j][1][1][2])
        deg45r, deg135i = pm(w[j][0][1][2], w[j][1][0][2])
        # Stack orientations along dim 1 in increasing angle order.
        yhr = torch.stack((deg15r, deg45r, deg75r, deg105r, deg135r, deg165r), dim=1)
        yhi = torch.stack((deg15i, deg45i, deg75i, deg105i, deg135i, deg165i), dim=1)
        if mag:
            # Biased magnitude; subtracting sqrt(0.01) maps zero input to 0.
            yh[j] = torch.sqrt(yhr**2 + yhi**2 + 0.01) - np.sqrt(0.01)
        else:
            yh[j] = torch.stack((yhr, yhi), dim=-1)

    return lows, yh
def test_equal_numpy_qshift1():
    """Row-filtering the full image with qshift_c must match NumPy."""
    filt = _qshift('qshift_c')[0]
    expected = ref_rowfilter(barbara, filt)
    got = rowfilter(barbara_t, prep_filt(filt, 1).to(dev))[0]
    np.testing.assert_array_almost_equal(got.cpu(), expected, decimal=4)
def test_qshift():
    """rowfilter with a qshift_a filter yields the expected output shape."""
    filt = _qshift('qshift_a')[0]
    out = rowfilter(barbara_t, prep_filt(filt, 1).to(dev))
    assert list(out.shape)[1:] == bshape_extracol