예제 #1
0
 def __init__(self, wave='db1', mode='zero', separable=True):
     """Set up the 2-D synthesis filters from a wavelet specification.

     Args:
         wave: One of
             * a wavelet name (``str``), converted with ``pywt.Wavelet``;
             * a ``pywt.Wavelet``, whose reconstruction filters
               (``rec_lo``, ``rec_hi``) are used for both columns and rows;
             * a sequence of 2 filters ``(g0, g1)`` used for both columns
               and rows;
             * a sequence of 4 filters ``(g0_col, g1_col, g0_row, g1_row)``.
         mode: Stored unchanged on ``self.mode`` (presumably the padding
             mode used by ``forward``, which is not shown here).
         separable: If True, prepare column/row filters with
             ``lowlevel.prep_filt_sfb2d`` and store them as four frozen
             ``nn.Parameter`` objects (also grouped in ``self.h``);
             otherwise store the single filter returned by
             ``lowlevel.prep_filt_sfb2d_nonsep`` as ``self.h``.

     Raises:
         ValueError: If ``wave`` is a sequence whose length is not 2 or 4.
     """
     super().__init__()
     if isinstance(wave, str):
         wave = pywt.Wavelet(wave)
     if isinstance(wave, pywt.Wavelet):
         g0_col, g1_col = wave.rec_lo, wave.rec_hi
         g0_row, g1_row = g0_col, g1_col
     elif len(wave) == 2:
         g0_col, g1_col = wave[0], wave[1]
         g0_row, g1_row = g0_col, g1_col
     elif len(wave) == 4:
         g0_col, g1_col = wave[0], wave[1]
         g0_row, g1_row = wave[2], wave[3]
     else:
         # Without this guard the filter names below would be unbound and
         # the caller would only see an opaque UnboundLocalError.
         raise ValueError(
             'wave must be a wavelet name, a pywt.Wavelet, or a sequence '
             'of 2 or 4 filters; got a sequence of length {}'.format(len(wave)))
     # Prepare the filters
     if separable:
         filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
         # requires_grad=False: the wavelet filters are fixed, not learned.
         self.g0_col = nn.Parameter(filts[0], requires_grad=False)
         self.g1_col = nn.Parameter(filts[1], requires_grad=False)
         self.g0_row = nn.Parameter(filts[2], requires_grad=False)
         self.g1_row = nn.Parameter(filts[3], requires_grad=False)
         # Convenience tuple in (g0_col, g1_col, g0_row, g1_row) order.
         self.h = (self.g0_col, self.g1_col, self.g0_row, self.g1_row)
     else:
         filts = lowlevel.prep_filt_sfb2d_nonsep(g0_col, g1_col, g0_row, g1_row)
         self.h = nn.Parameter(filts, requires_grad=False)
     self.mode = mode
     self.separable = separable
예제 #2
0
def icplxdual2D(yl, yh, level1='farras', qshift='qshift_a', mode='periodization'):
    """Synthesise a signal from 2-D complex dual-tree wavelet coefficients.

    Each of the four trees ``(m, n)`` with ``m, n`` in {0, 1} is
    reconstructed separately. The four results are summed and the total
    is halved.

    Args:
        yl: 2x2 nested sequence of lowpass tensors, indexed ``yl[m][n]``.
            The device of ``yl[0][0]`` is used for all filter tensors.
        yh: Sequence of per-level highpass tensors. Each entry is indexed as
            ``yh[j][:, k, :, :, :, c]`` with subband ``k`` in 0..5 and ``c``
            in {0, 1} (presumably the real/imaginary parts; confirm against
            the forward transform). Levels ``len(yh)-1`` down to 1 are
            synthesised with the q-shift filters. Level 0 is synthesised
            last, with the first-level filters.
        level1: Name passed to ``_level1`` to select the first-level filters.
        qshift: Name passed to ``_qshift`` to select the other levels' filters.
        mode: Padding mode forwarded to every ``sfb2d`` call.

    Returns:
        The sum of the four tree reconstructions divided by 2.
    """
    # Only the synthesis filters are needed; the analysis ones are discarded.
    _, _, g0a1, g0b1, _, _, g1a1, g1b1 = _level1(level1)
    _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)

    device = yl[0][0].device

    def _filter_grid(lo_a, hi_a, lo_b, hi_b):
        # bank[m][n] = prep_filt_sfb2d(*pair[m], *pair[n]), built in
        # (0,0), (0,1), (1,0), (1,1) order.
        pairs = ((lo_a, hi_a), (lo_b, hi_b))
        return tuple(
            tuple(prep_filt_sfb2d(c_lo, c_hi, r_lo, r_hi, device=device)
                  for r_lo, r_hi in pairs)
            for c_lo, c_hi in pairs)

    first_level_bank = _filter_grid(g0a1, g1a1, g0b1, g1b1)
    qshift_bank = _filter_grid(g0a, g1a, g0b, g1b)

    # Orientation slot k of each tree's stacked highs comes from the
    # subband pair (a, b) = subband_pairs[k].
    subband_pairs = ((2, 3), (0, 5), (1, 4))
    n_levels = len(yh)
    highs = []
    for level in yh:
        slots = {(0, 0): [], (0, 1): [], (1, 0): [], (1, 1): []}
        for a, b in subband_pairs:
            first, second = pm(level[:, a, :, :, :, 0], level[:, b, :, :, :, 1])
            slots[0, 0].append(first)
            slots[1, 1].append(second)
            first, second = pm(level[:, b, :, :, :, 0], level[:, a, :, :, :, 1])
            slots[0, 1].append(first)
            slots[1, 0].append(second)
        highs.append({tree: torch.stack(parts, dim=2)
                      for tree, parts in slots.items()})

    total = None
    for m in (0, 1):
        for n in (0, 1):
            recon = yl[m][n]
            # Coarse-to-fine through the q-shift levels, then level 0.
            for j in reversed(range(1, n_levels)):
                recon = sfb2d(recon, highs[j][m, n], qshift_bank[m][n], mode=mode)
            recon = sfb2d(recon, highs[0][m, n], first_level_bank[m][n], mode=mode)
            total = recon if total is None else total + recon

    return total / 2
예제 #3
0
 def __init__(self, wave='db1', mode='zero'):
     """Prepare the 2-D synthesis filters and register them as buffers.

     Args:
         wave: One of
             * a wavelet name (``str``), converted with ``pywt.Wavelet``;
             * a ``pywt.Wavelet``, whose reconstruction filters
               (``rec_lo``, ``rec_hi``) are used for both columns and rows;
             * a sequence of 2 filters ``(g0, g1)`` used for both columns
               and rows;
             * a sequence of 4 filters ``(g0_col, g1_col, g0_row, g1_row)``.
         mode: Stored unchanged on ``self.mode`` (presumably the padding
             mode used by ``forward``, which is not shown here).

     Raises:
         ValueError: If ``wave`` is a sequence whose length is not 2 or 4.
     """
     super().__init__()
     if isinstance(wave, str):
         wave = pywt.Wavelet(wave)
     if isinstance(wave, pywt.Wavelet):
         g0_col, g1_col = wave.rec_lo, wave.rec_hi
         g0_row, g1_row = g0_col, g1_col
     elif len(wave) == 2:
         g0_col, g1_col = wave[0], wave[1]
         g0_row, g1_row = g0_col, g1_col
     elif len(wave) == 4:
         g0_col, g1_col = wave[0], wave[1]
         g0_row, g1_row = wave[2], wave[3]
     else:
         # Without this guard the filter names below would be unbound and
         # the caller would only see an opaque UnboundLocalError.
         raise ValueError(
             'wave must be a wavelet name, a pywt.Wavelet, or a sequence '
             'of 2 or 4 filters; got a sequence of length {}'.format(len(wave)))
     # Prepare the filters. Registering them as buffers keeps them out of
     # ``parameters()`` (so they are never optimised).
     filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
     self.register_buffer('g0_col', filts[0])
     self.register_buffer('g1_col', filts[1])
     self.register_buffer('g0_row', filts[2])
     self.register_buffer('g1_row', filts[3])
     self.mode = mode
예제 #4
0
    def __init__(self, wave="db1", padding_method="zero", device="cpu"):
        """Prepare the 2-D synthesis filters and register them as buffers.

        Args:
            wave: One of
                * a wavelet name (``str``), converted with ``pywt.Wavelet``;
                * a ``pywt.Wavelet``, whose reconstruction filters
                  (``rec_lo``, ``rec_hi``) are used for both columns and rows;
                * a sequence of 2 filters ``(g0, g1)`` used for both columns
                  and rows;
                * a sequence of 4 filters ``(g0_col, g1_col, g0_row, g1_row)``.
            padding_method: Stored unchanged on ``self.padding_method``
                (presumably consumed by ``forward``, which is not shown here).
            device: Forwarded to ``lowlevel.prep_filt_sfb2d`` (presumably the
                device the filter tensors are created on).

        Raises:
            ValueError: If ``wave`` is a sequence whose length is not 2 or 4.
        """
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0_col, g1_col = wave.rec_lo, wave.rec_hi
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 2:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 4:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = wave[2], wave[3]
        else:
            # Without this guard the filter names below would be unbound and
            # the caller would only see an opaque UnboundLocalError.
            raise ValueError(
                "wave must be a wavelet name, a pywt.Wavelet, or a sequence "
                "of 2 or 4 filters; got a sequence of length {}".format(len(wave))
            )

        # Prepare the filters. Registering them as buffers keeps them out of
        # ``parameters()`` (so they are never optimised).
        filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row, device=device)
        self.register_buffer("g0_col", filts[0])
        self.register_buffer("g1_col", filts[1])
        self.register_buffer("g0_row", filts[2])
        self.register_buffer("g1_row", filts[3])
        self.padding_method = padding_method