Beispiel #1
0
    def __init__(self, J=1, wave='db1', mode='zero', separable=True):
        """Resolve the analysis filters from ``wave`` and store them.

        Args:
            J: Stored unchanged as ``self.J``. It is not used in this method
                (presumably the number of decomposition levels - confirm
                against the forward pass).
            wave: Selects the filters. Accepted forms:

                * ``str`` - passed to ``pywt.Wavelet`` to look the wavelet up.
                * ``pywt.Wavelet`` - its ``dec_lo``/``dec_hi`` filters are
                  used for both the columns and the rows.
                * length-2 sequence ``(h0, h1)`` - the same pair is used for
                  the columns and the rows.
                * length-4 sequence ``(h0_col, h1_col, h0_row, h1_row)``.
            mode: Stored unchanged as ``self.mode`` (presumably the padding
                mode - confirm against the forward pass).
            separable: If true, the four filters returned by
                ``lowlevel.prep_filt_afb2d`` are stored individually;
                otherwise the single result of
                ``lowlevel.prep_filt_afb2d_nonsep`` is stored as ``self.h``.

        Raises:
            ValueError: If ``wave`` is a sequence whose length is neither 2
                nor 4.
        """
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)

        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 2:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 4:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = wave[2], wave[3]
        else:
            # Fail here with a clear message instead of hitting an unbound
            # local variable when the filters are prepared below.
            raise ValueError(
                'wave must be a str, a pywt.Wavelet, or a sequence of 2 or 4 '
                'filters; got a sequence of length {}'.format(len(wave)))

        # Prepare the filters. They are wrapped as parameters with
        # requires_grad=False, i.e. no gradient is requested for them.
        if separable:
            filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
            self.h0_col = nn.Parameter(filts[0], requires_grad=False)
            self.h1_col = nn.Parameter(filts[1], requires_grad=False)
            self.h0_row = nn.Parameter(filts[2], requires_grad=False)
            self.h1_row = nn.Parameter(filts[3], requires_grad=False)
            self.h = (self.h0_col, self.h1_col, self.h0_row, self.h1_row)
        else:
            filts = lowlevel.prep_filt_afb2d_nonsep(h0_col, h1_col, h0_row, h1_row)
            self.h = nn.Parameter(filts, requires_grad=False)
        self.J = J
        self.mode = mode
        self.separable = separable
    def __init__(self, J=1, wave='db1', mode='zero', separable=True):
        """Resolve the analysis filters from ``wave`` and store them.

        Args:
            J: Stored unchanged as ``self.J``. It is not used in this method
                (presumably the number of decomposition levels - confirm).
            wave: Selects the filters. Accepted forms:

                * ``str`` - passed to ``pywt.Wavelet`` to look the wavelet up.
                * ``pywt.Wavelet`` - its ``dec_lo``/``dec_hi`` filters are
                  converted to arrays with a leading axis of length 1.
                * length-1 sequence ``(h0_col,)`` - only the lowpass filter is
                  given; ``self.get_high_from_low`` is set to ``True`` and
                  the remaining filters are passed on as ``None``.
                * length-2 sequence ``(h0_col, h1_col)`` - the row filters are
                  passed on as ``None``.
                * length-4 sequence ``(h0_col, h1_col, h0_row, h1_row)``.

                For the sequence forms ``self.n_scales`` is set to
                ``len(wave[0])``; otherwise it stays 1.
            mode: Stored unchanged as ``self.mode`` (presumably the padding
                mode - confirm).
            separable: Chooses between ``lowlevel.prep_filt_afb2d`` and
                ``lowlevel.prep_filt_afb2d_nonsep`` to prepare the filters.

        Raises:
            ValueError: If ``wave`` is a sequence whose length is not 1, 2
                or 4.
        """
        super().__init__()
        self.get_high_from_low = False
        self.n_scales = 1
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)

        if isinstance(wave, pywt.Wavelet):
            # Give the 1-D pywt filters a leading axis of length 1.
            h0_col = np.expand_dims(np.array(wave.dec_lo), 0)
            h1_col = np.expand_dims(np.array(wave.dec_hi), 0)
            # NOTE(review): the row filters are built from the already
            # expanded column filters, so they end up with two leading axes.
            # Kept as in the original; confirm this is what
            # lowlevel.prep_filt_afb2d* expects.
            h0_row = np.expand_dims(np.array(h0_col), 0)
            h1_row = np.expand_dims(np.array(h1_col), 0)
        else:
            self.n_scales = len(wave[0])
            n_filters = len(wave)
            if n_filters == 1:
                self.get_high_from_low = True
                h0_col, h1_col = wave[0], None
                h0_row, h1_row = None, None
            elif n_filters == 2:
                h0_col, h1_col = wave[0], wave[1]
                h0_row, h1_row = None, None
            elif n_filters == 4:
                h0_col, h1_col = wave[0], wave[1]
                h0_row, h1_row = wave[2], wave[3]
            else:
                # Fail here with a clear message instead of hitting an
                # unbound local variable when the filters are prepared below.
                raise ValueError(
                    'wave must be a str, a pywt.Wavelet, or a sequence of 1, '
                    '2 or 4 filters; got a sequence of length {}'.format(
                        n_filters))

        # Prepare the filters
        if separable:
            filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
            self.h0_col = filts[0]
            self.h1_col = filts[1]
            self.h0_row = filts[2]
            self.h1_row = filts[3]
        else:
            # The non-separable helper returns the combined filter first,
            # followed by the four individual filters.
            (filts, self.h0_col, self.h1_col, self.h0_row,
             self.h1_row) = lowlevel.prep_filt_afb2d_nonsep(
                 h0_col, h1_col, h0_row, h1_row)

            self.h = filts
        self.J = J
        self.mode = mode
        self.separable = separable
Beispiel #3
0
    def __init__(self, J=1, wave='db1', mode='zero'):
        """Resolve the analysis filters from ``wave`` and register them.

        Args:
            J: Stored unchanged as ``self.J``. It is not used in this method
                (presumably the number of decomposition levels - confirm).
            wave: Selects the filters. Accepted forms:

                * ``str`` - passed to ``pywt.Wavelet`` to look the wavelet up.
                * ``pywt.Wavelet`` - its ``dec_lo``/``dec_hi`` filters are
                  used for both the columns and the rows.
                * length-2 sequence ``(h0, h1)`` - the same pair is used for
                  the columns and the rows.
                * length-4 sequence ``(h0_col, h1_col, h0_row, h1_row)``.
            mode: Stored unchanged as ``self.mode`` (presumably the padding
                mode - confirm).

        Raises:
            ValueError: If ``wave`` is a sequence whose length is neither 2
                nor 4.
        """
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)

        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 2:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 4:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = wave[2], wave[3]
        else:
            # Fail here with a clear message instead of hitting an unbound
            # local variable when the filters are prepared below.
            raise ValueError(
                'wave must be a str, a pywt.Wavelet, or a sequence of 2 or 4 '
                'filters; got a sequence of length {}'.format(len(wave)))

        # Prepare the filters and register them as buffers under fixed names
        # (register_buffer is assumed to come from torch.nn.Module; the class
        # header is not visible here).
        filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
        self.register_buffer('h0_col', filts[0])
        self.register_buffer('h1_col', filts[1])
        self.register_buffer('h0_row', filts[2])
        self.register_buffer('h1_row', filts[3])
        self.J = J
        self.mode = mode
Beispiel #4
0
    def __init__(self, J=1, wave='db1', mode='periodization'):
        """Resolve the analysis filters from ``wave`` and store them.

        Args:
            J: Stored unchanged as ``self.J``. It is not used in this method
                (presumably the number of decomposition levels - confirm).
            wave: Selects the filters. Accepted forms:

                * ``str`` - passed to ``pywt.Wavelet`` to look the wavelet up.
                * ``pywt.Wavelet`` - its ``dec_lo``/``dec_hi`` filters are
                  used for both the columns and the rows.
                * length-2 sequence ``(h0, h1)`` - the same pair is used for
                  the columns and the rows.
                * length-4 sequence ``(h0_col, h1_col, h0_row, h1_row)``.
            mode: Stored unchanged as ``self.mode`` (presumably the padding
                mode - confirm).

        Raises:
            ValueError: If ``wave`` is a sequence whose length is neither 2
                nor 4.
        """
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)

        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 2:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 4:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = wave[2], wave[3]
        else:
            # Fail here with a clear message instead of hitting an unbound
            # local variable when the filters are prepared below.
            raise ValueError(
                'wave must be a str, a pywt.Wavelet, or a sequence of 2 or 4 '
                'filters; got a sequence of length {}'.format(len(wave)))

        # Prepare the filters. They are wrapped as parameters with
        # requires_grad=False, i.e. no gradient is requested for them.
        filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
        self.h0_col = nn.Parameter(filts[0], requires_grad=False)
        self.h1_col = nn.Parameter(filts[1], requires_grad=False)
        self.h0_row = nn.Parameter(filts[2], requires_grad=False)
        self.h1_row = nn.Parameter(filts[3], requires_grad=False)

        self.J = J
        self.mode = mode
Beispiel #5
0
    def __init__(self, decomp_level=1, wave="db1", padding_method="zero", device="cpu"):
        """Resolve the analysis filters from ``wave`` and register them.

        Args:
            decomp_level: Stored unchanged as ``self.decomp_level``; not used
                in this method.
            wave: Selects the filters. Accepted forms:

                * ``str`` - passed to ``pywt.Wavelet`` to look the wavelet up.
                * ``pywt.Wavelet`` - its ``dec_lo``/``dec_hi`` filters are
                  used for both the columns and the rows.
                * length-2 sequence ``(h0, h1)`` - the same pair is used for
                  the columns and the rows.
                * length-4 sequence ``(h0_col, h1_col, h0_row, h1_row)``.
            padding_method: Stored unchanged as ``self.padding_method``; not
                used in this method.
            device: Forwarded to ``lowlevel.prep_filt_afb2d`` as ``device``.

        Raises:
            ValueError: If ``wave`` is a sequence whose length is neither 2
                nor 4.
        """
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)

        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 2:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = h0_col, h1_col
        elif len(wave) == 4:
            h0_col, h1_col = wave[0], wave[1]
            h0_row, h1_row = wave[2], wave[3]
        else:
            # Fail here with a clear message instead of hitting an unbound
            # local variable when the filters are prepared below.
            raise ValueError(
                "wave must be a str, a pywt.Wavelet, or a sequence of 2 or 4 "
                "filters; got a sequence of length {}".format(len(wave)))

        # Prepare the filters and register them as buffers under fixed names
        # (register_buffer is assumed to come from torch.nn.Module; the class
        # header is not visible here).
        filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row, device=device)
        self.register_buffer("h0_col", filts[0])
        self.register_buffer("h1_col", filts[1])
        self.register_buffer("h0_row", filts[2])
        self.register_buffer("h1_row", filts[3])
        # TODO: register the buffer for the output
        self.decomp_level = decomp_level
        self.padding_method = padding_method
def cplxdual2D(x,
               J,
               level1='farras',
               qshift='qshift_a',
               mode='periodization',
               mag=False):
    """ Do a complex dtcwt

    Four fully decimated 2-D filter-bank trees are run, one for every
    combination of the 'a'/'b' column filters and 'a'/'b' row filters. The
    bandpass outputs of the trees are then combined pairwise with ``pm`` into
    the real and imaginary parts of six oriented subbands per scale.

    Args:
        x: Input tensor. It is halved before any filtering. Must be whatever
            ``afb2d`` accepts (presumably shaped (N, C, H, W) - confirm).
        J: Number of scales to compute. Must be at least 1.
        level1: Name passed to ``_level1`` to select the first-scale filters.
        qshift: Name passed to ``_qshift`` to select the filters used for
            every scale after the first.
        mode: Passed to ``afb2d`` as ``mode`` for every filtering step.
        mag: If true, return a smoothed magnitude per orientation instead of
            separate real and imaginary parts.

    Returns:
        lows: lowpass outputs from each of the 4 trees. Is a 2x2 list of
            lists; ``lows[m][n]`` is the lowpass output of tree (m, n) at the
            last scale.
        yh: list of length J, one entry per scale. Each entry stacks the six
            orientations along dim 1 in the order 15, 45, 75, 105, 135, 165
            degrees. If ``mag`` is false, the real and imaginary parts are
            additionally stacked along a new last dimension (index 0 = real,
            1 = imaginary). If ``mag`` is true, the entry is
            ``sqrt(real**2 + imag**2 + 0.01) - sqrt(0.01)``.

    Raises:
        ValueError: If ``J`` is smaller than 1.
    """
    # Without at least one scale there is nothing to index below; reject it
    # before any filter is built or any transform is run.
    if J < 1:
        raise ValueError('J must be at least 1, got {}'.format(J))

    x = x / 2
    # Get the filters
    h0a1, h0b1, _, _, h1a1, h1b1, _, _ = _level1(level1)
    h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)

    # Faf[m][n] / af[m][n]: prepared filters for tree (m, n) at the first
    # scale / at all later scales. m picks the column filters, n the row
    # filters (0 = 'a', 1 = 'b').
    Faf = ((prep_filt_afb2d(h0a1, h1a1, h0a1, h1a1, device=x.device),
            prep_filt_afb2d(h0a1, h1a1, h0b1, h1b1, device=x.device)),
           (prep_filt_afb2d(h0b1, h1b1, h0a1, h1a1, device=x.device),
            prep_filt_afb2d(h0b1, h1b1, h0b1, h1b1, device=x.device)))
    af = ((prep_filt_afb2d(h0a, h1a, h0a, h1a, device=x.device),
           prep_filt_afb2d(h0a, h1a, h0b, h1b, device=x.device)),
          (prep_filt_afb2d(h0b, h1b, h0a, h1a, device=x.device),
           prep_filt_afb2d(h0b, h1b, h0b, h1b, device=x.device)))

    # Do 4 fully decimated dwts.
    # w[j][m][n] holds the three bandpass outputs of tree (m, n) at scale j.
    w = [[[None for _ in range(2)] for _ in range(2)] for _ in range(J)]
    lows = [[None for _ in range(2)] for _ in range(2)]
    for m in range(2):
        for n in range(2):
            ll = x
            for j in range(J):
                # First scale uses the level-1 filters, later scales the
                # q-shift filters; each scale filters the previous lowpass.
                filts = Faf[m][n] if j == 0 else af[m][n]
                bands = afb2d(ll, filts, mode=mode)
                # Separate the low and bandpasses: split the channel axis
                # into groups of 4 subbands.
                s = bands.shape
                bands = bands.reshape(s[0], -1, 4, s[-2], s[-1])
                ll = bands[:, :, 0].contiguous()
                # Subbands 1 and 2 are swapped when stored.
                w[j][m][n] = [bands[:, :, 2], bands[:, :, 1], bands[:, :, 3]]
            lows[m][n] = ll

    # Convert the quads into real and imaginary parts. Each pm() call
    # combines the same bandpass output from two trees: (0,0) with (1,1),
    # or (0,1) with (1,0).
    yh = [None] * J
    for j in range(J):
        deg75r, deg105i = pm(w[j][0][0][0], w[j][1][1][0])
        deg105r, deg75i = pm(w[j][0][1][0], w[j][1][0][0])
        deg15r, deg165i = pm(w[j][0][0][1], w[j][1][1][1])
        deg165r, deg15i = pm(w[j][0][1][1], w[j][1][0][1])
        deg135r, deg45i = pm(w[j][0][0][2], w[j][1][1][2])
        deg45r, deg135i = pm(w[j][0][1][2], w[j][1][0][2])
        yhr = torch.stack((deg15r, deg45r, deg75r, deg105r, deg135r, deg165r),
                          dim=1)
        yhi = torch.stack((deg15i, deg45i, deg75i, deg105i, deg135i, deg165i),
                          dim=1)
        if mag:
            # The 0.01 offset keeps the argument of sqrt strictly positive;
            # subtracting sqrt(0.01) maps a zero coefficient back to zero.
            yh[j] = torch.sqrt(yhr**2 + yhi**2 + 0.01) - np.sqrt(0.01)
        else:
            yh[j] = torch.stack((yhr, yhi), dim=-1)

    return lows, yh