def __init__(self, wave='db1', mode='zero', separable=True):
    """Prepare the 2-D reconstruction filters used by this module.

    Args:
        wave: The wavelet to use. One of:
            - a name accepted by ``pywt.Wavelet`` (e.g. ``'db1'``);
            - a ``pywt.Wavelet``, whose ``rec_lo`` / ``rec_hi`` filters are
              used for both the column and the row direction;
            - a sequence ``(g0, g1)``, which uses the same lowpass/highpass
              pair in both directions;
            - a sequence ``(g0_col, g1_col, g0_row, g1_row)``, which gives the
              column and row filters separately.
        mode: Stored as ``self.mode``. Presumably the padding mode used when
            the filters are applied; the consumer is not shown here.
        separable: If True, prepare four separable filters with
            ``lowlevel.prep_filt_sfb2d``. Otherwise build a single
            non-separable filter with ``lowlevel.prep_filt_sfb2d_nonsep``.
            Stored as ``self.separable``.

    Raises:
        ValueError: If ``wave`` is a sequence whose length is neither 2 nor 4.
    """
    super().__init__()
    if isinstance(wave, str):
        wave = pywt.Wavelet(wave)
    if isinstance(wave, pywt.Wavelet):
        g0_col, g1_col = wave.rec_lo, wave.rec_hi
        g0_row, g1_row = g0_col, g1_col
    elif len(wave) == 2:
        g0_col, g1_col = wave[0], wave[1]
        g0_row, g1_row = g0_col, g1_col
    elif len(wave) == 4:
        g0_col, g1_col = wave[0], wave[1]
        g0_row, g1_row = wave[2], wave[3]
    else:
        # Without this check, the filter names below would be unbound and
        # the call would fail with an opaque UnboundLocalError.
        raise ValueError(
            'wave must be a wavelet name, a pywt.Wavelet, or a sequence of '
            '2 or 4 filters; got a sequence of length {}'.format(len(wave)))

    # Prepare the filters
    if separable:
        filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
        # Frozen Parameters (requires_grad=False): no gradients are computed
        # for them.
        # NOTE(review): self.h aliases these exact Parameter objects. PyTorch's
        # default .to()/.cuda() conversion updates param.data in place, which
        # keeps that alias valid. Switching to register_buffer would likely
        # leave self.h stale, so confirm before changing the storage.
        self.g0_col = nn.Parameter(filts[0], requires_grad=False)
        self.g1_col = nn.Parameter(filts[1], requires_grad=False)
        self.g0_row = nn.Parameter(filts[2], requires_grad=False)
        self.g1_row = nn.Parameter(filts[3], requires_grad=False)
        self.h = (self.g0_col, self.g1_col, self.g0_row, self.g1_row)
    else:
        filts = lowlevel.prep_filt_sfb2d_nonsep(g0_col, g1_col, g0_row, g1_row)
        self.h = nn.Parameter(filts, requires_grad=False)
    self.mode = mode
    self.separable = separable
def icplxdual2D(yl, yh, level1='farras', qshift='qshift_a', mode='periodization'):
    """Reconstruct a signal from a 2-D complex dual-tree wavelet decomposition.

    The coefficients are split into four real DWT trees, indexed ``[m][n]``
    with ``m, n`` in ``{0, 1}``. Each tree is reconstructed with ``sfb2d``.
    The four results are summed and the total is divided by 2.

    Args:
        yl: 2x2 nested sequence of lowpass tensors, indexed ``yl[m][n]``, one
            per tree. The device of ``yl[0][0]`` is used for the filters.
        yh: Sequence of highpass tensors, one per level. Entry 0 is
            reconstructed last and uses the ``level1`` filters; every other
            level uses the ``qshift`` filters. Each entry is indexed as
            ``yh[j][:, k, :, :, :, r]``, with ``k`` in 0..5 and ``r`` in {0, 1}
            (presumably orientation and real/imaginary part; TODO confirm
            against the forward transform). Must contain at least one level.
        level1: Filter-set name passed to ``_level1`` for the first level.
        qshift: Filter-set name passed to ``_qshift`` for the deeper levels.
        mode: Padding mode forwarded to ``sfb2d``.

    Returns:
        The sum of the four tree reconstructions, divided by 2.
    """
    # Filter taps for the first level and for all deeper (q-shift) levels.
    _, _, g0a1, g0b1, _, _, g1a1, g1b1 = _level1(level1)
    _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
    device = yl[0][0].device

    def filter_grid(lo_a, hi_a, lo_b, hi_b):
        # 2x2 grid of prepared filters. Entry [m][n] passes tree m's pair as
        # the first two arguments and tree n's pair as the last two.
        trees = ((lo_a, hi_a), (lo_b, hi_b))
        return tuple(
            tuple(prep_filt_sfb2d(m_lo, m_hi, n_lo, n_hi, device=device)
                  for n_lo, n_hi in trees)
            for m_lo, m_hi in trees)

    first_level_filts = filter_grid(g0a1, g1a1, g0b1, g1b1)
    deeper_level_filts = filter_grid(g0a, g1a, g0b, g1b)

    # For each of the three real-DWT bands, the two indices (p, q) along
    # axis 1 of yh[j] that pm recombines into that band's four tree subbands.
    orientation_pairs = ((2, 3), (0, 5), (1, 4))

    # Convert the highs back to per-tree subbands: w[j][m][n] is a tensor
    # with the three bands stacked along dim 2.
    J = len(yh)
    w = []
    for coeffs in yh:
        bands = [[[None] * 3 for _ in range(2)] for _ in range(2)]
        for b, (p, q) in enumerate(orientation_pairs):
            bands[0][0][b], bands[1][1][b] = pm(
                coeffs[:, p, :, :, :, 0], coeffs[:, q, :, :, :, 1])
            bands[0][1][b], bands[1][0][b] = pm(
                coeffs[:, q, :, :, :, 0], coeffs[:, p, :, :, :, 1])
        w.append([[torch.stack(bands[m][n], dim=2) for n in range(2)]
                  for m in range(2)])

    # Reconstruct each tree from coarsest to finest and accumulate.
    total = None
    for m, n in ((0, 0), (0, 1), (1, 0), (1, 1)):
        lo = yl[m][n]
        for j in reversed(range(1, J)):
            lo = sfb2d(lo, w[j][m][n], deeper_level_filts[m][n], mode=mode)
        lo = sfb2d(lo, w[0][m][n], first_level_filts[m][n], mode=mode)
        total = lo if total is None else total + lo

    # Normalize
    return total / 2
def __init__(self, wave='db1', mode='zero'):
    """Prepare the separable 2-D reconstruction filters and register them as buffers.

    Args:
        wave: The wavelet to use. One of:
            - a name accepted by ``pywt.Wavelet`` (e.g. ``'db1'``);
            - a ``pywt.Wavelet``, whose ``rec_lo`` / ``rec_hi`` filters are
              used for both the column and the row direction;
            - a sequence ``(g0, g1)``, which uses the same lowpass/highpass
              pair in both directions;
            - a sequence ``(g0_col, g1_col, g0_row, g1_row)``, which gives the
              column and row filters separately.
        mode: Stored as ``self.mode``. Presumably the padding mode used when
            the filters are applied; the consumer is not shown here.

    Raises:
        ValueError: If ``wave`` is a sequence whose length is neither 2 nor 4.
    """
    super().__init__()
    if isinstance(wave, str):
        wave = pywt.Wavelet(wave)
    if isinstance(wave, pywt.Wavelet):
        g0_col, g1_col = wave.rec_lo, wave.rec_hi
        g0_row, g1_row = g0_col, g1_col
    elif len(wave) == 2:
        g0_col, g1_col = wave[0], wave[1]
        g0_row, g1_row = g0_col, g1_col
    elif len(wave) == 4:
        g0_col, g1_col = wave[0], wave[1]
        g0_row, g1_row = wave[2], wave[3]
    else:
        # Without this check, the filter names below would be unbound and
        # the call would fail with an opaque UnboundLocalError.
        raise ValueError(
            'wave must be a wavelet name, a pywt.Wavelet, or a sequence of '
            '2 or 4 filters; got a sequence of length {}'.format(len(wave)))

    # Prepare the filters
    filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
    # Buffers rather than Parameters: they are saved in the state_dict and
    # follow the module across devices, but are not returned by parameters().
    self.register_buffer('g0_col', filts[0])
    self.register_buffer('g1_col', filts[1])
    self.register_buffer('g0_row', filts[2])
    self.register_buffer('g1_row', filts[3])
    self.mode = mode
def __init__(self, wave="db1", padding_method="zero", device="cpu"):
    """Prepare the separable 2-D reconstruction filters and register them as buffers.

    Args:
        wave: The wavelet to use. One of:
            - a name accepted by ``pywt.Wavelet`` (e.g. ``"db1"``);
            - a ``pywt.Wavelet``, whose ``rec_lo`` / ``rec_hi`` filters are
              used for both the column and the row direction;
            - a sequence ``(g0, g1)``, which uses the same lowpass/highpass
              pair in both directions;
            - a sequence ``(g0_col, g1_col, g0_row, g1_row)``, which gives the
              column and row filters separately.
        padding_method: Stored as ``self.padding_method``. Presumably the
            padding scheme used when the filters are applied; the consumer is
            not shown here.
        device: Device passed to ``lowlevel.prep_filt_sfb2d`` for the
            prepared filter tensors.

    Raises:
        ValueError: If ``wave`` is a sequence whose length is neither 2 nor 4.
    """
    super().__init__()
    if isinstance(wave, str):
        wave = pywt.Wavelet(wave)
    if isinstance(wave, pywt.Wavelet):
        g0_col, g1_col = wave.rec_lo, wave.rec_hi
        g0_row, g1_row = g0_col, g1_col
    elif len(wave) == 2:
        g0_col, g1_col = wave[0], wave[1]
        g0_row, g1_row = g0_col, g1_col
    elif len(wave) == 4:
        g0_col, g1_col = wave[0], wave[1]
        g0_row, g1_row = wave[2], wave[3]
    else:
        # Without this check, the filter names below would be unbound and
        # the call would fail with an opaque UnboundLocalError.
        raise ValueError(
            "wave must be a wavelet name, a pywt.Wavelet, or a sequence of "
            "2 or 4 filters; got a sequence of length {}".format(len(wave)))

    # Prepare the filters
    filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row, device=device)
    # Buffers rather than Parameters: they are saved in the state_dict and
    # follow the module across devices, but are not returned by parameters().
    self.register_buffer("g0_col", filts[0])
    self.register_buffer("g1_col", filts[1])
    self.register_buffer("g0_row", filts[2])
    self.register_buffer("g1_row", filts[3])
    self.padding_method = padding_method