예제 #1
0
def quad_afb2d_nonsep(x, filts, mode='zero'):
    """ Does a 1 level 2d wavelet decomposition of an input. Doesn't do separate
    row and column filtering; a single grouped 2d convolution produces all four
    subbands at once.

    Inputs:
        x (torch.Tensor): Input to decompose, indexed as (batch, channel,
            height, width).
        filts (torch.Tensor): Prepared non-separable filter bank, of the form
            created by
            :py:func:`pytorch_wavelets.dwt.lowlevel.prep_filt_afb2d_nonsep`.
            It is repeated once per input channel and applied with
            ``groups=C``, and the output is reshaped into 4 subbands per
            channel, so it is expected to have shape (4, 1, Ly, Lx).
            Lists of raw filters are not accepted here; prepare them first.
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which
            padding to use. If periodization, the output size will be half the
            input size.  Otherwise, the output size will be slightly larger than
            half.

    Returns:
        (yl, yh): ``yl`` is the lowpass band with shape (batch, C, H', W');
        ``yh`` holds the three highpass bands with shape (batch, C, 3, H', W').

    Raises:
        TypeError: if ``filts`` is not a tensor.
        ValueError: if ``mode`` is not a supported padding mode.
    """
    if not torch.is_tensor(filts):
        raise TypeError(
            "filts must be a prepared filter tensor (see "
            "prep_filt_afb2d_nonsep), got {}".format(type(filts).__name__))

    C = x.shape[1]
    Ny = x.shape[2]
    Nx = x.shape[3]

    # One copy of the filter bank per input channel for the grouped conv.
    f = torch.cat([filts] * C, dim=0)
    Ly = f.shape[2]
    Lx = f.shape[3]

    if mode in ('periodization', 'per'):
        # Periodization needs even sizes: replicate the last row / column.
        if Ny % 2 == 1:
            x = torch.cat((x, x[:, :, -1:]), dim=2)
            Ny += 1
        if Nx % 2 == 1:
            x = torch.cat((x, x[:, :, :, -1:]), dim=3)
            Nx += 1
        pad = (Ly - 1, Lx - 1)
        stride = (2, 2)
        x = roll(roll(x, -Ly // 2, dim=2), -Lx // 2, dim=3)
        y = F.conv2d(x, f, padding=pad, stride=stride, groups=C)
        # Fold the overhang past the half-size back onto the start so the
        # linear convolution behaves like a circular one.
        y[:, :, :Ly // 2] += y[:, :, Ny // 2:Ny // 2 + Ly // 2]
        y[:, :, :, :Lx // 2] += y[:, :, :, Nx // 2:Nx // 2 + Lx // 2]
        y = y[:, :, :Ny // 2, :Nx // 2]
    elif mode in ('zero', 'symmetric', 'reflect'):
        # Total padding needed so the output length matches pywt's.
        out1 = pywt.dwt_coeff_len(Ny, Ly, mode=mode)
        out2 = pywt.dwt_coeff_len(Nx, Lx, mode=mode)
        p1 = 2 * (out1 - 1) - Ny + Ly
        p2 = 2 * (out2 - 1) - Nx + Lx
        if mode == 'zero':
            # conv2d only pads the same amount on both sides, so for an odd
            # total pad add the extra zero row/column at the end first.
            if p1 % 2 == 1 or p2 % 2 == 1:
                x = F.pad(x, (0, p2 % 2, 0, p1 % 2))
            y = F.conv2d(x, f, padding=(p1 // 2, p2 // 2), stride=2, groups=C)
        else:
            # symmetric / reflect: pad explicitly (left, right, top, bottom).
            pad = (p2 // 2, (p2 + 1) // 2, p1 // 2, (p1 + 1) // 2)
            x = mypad(x, pad=pad, mode=mode)
            y = F.conv2d(x, f, stride=2, groups=C)
    else:
        raise ValueError("Unknown pad type: {}".format(mode))

    # Split the 4*C output channels into (C, 4 subbands).
    y = y.reshape((y.shape[0], C, 4, y.shape[-2], y.shape[-1]))
    yl = y[:, :, 0].contiguous()
    yh = y[:, :, 1:].contiguous()
    return yl, yh
예제 #2
0
def quad_afb2d(x, cols, rows, mode='zero', split=True, stride=2):
    """ Does a single level 2d wavelet decomposition of an input with four
    filter trees at once. Does separate column and row filtering as two grouped
    convolutions, then recombines the subbands into six oriented complex
    highpass bands and one interleaved lowpass image.

    Inputs:
        x (torch.Tensor): Input to decompose, indexed as (batch, channel,
            height, width). It is halved before filtering.
        cols (torch.Tensor): Column filter bank, applied along dim 2 with
            ``groups=C`` after being repeated once per input channel. The row
            stage uses ``groups=8 * C``, so this is expected to yield 8
            channels per input channel (presumably shape (8, 1, L, 1) -- TODO
            confirm).
        rows (torch.Tensor): Row filter bank, applied along dim 3 with
            ``groups=8 * C``. Its output is viewed as (batch, C, 4, 4, H', W')
            (presumably shape (16, 1, 1, L) -- TODO confirm).
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. Which
            padding to use. If periodization, the output size will be half the
            input size.  Otherwise, the output size will be slightly larger than
            half.
        split (bool): Not used by this function; kept for interface
            compatibility.
        stride (int): Decimation factor used by both filtering passes.
            NOTE(review): the periodization fold uses ``N // 2`` and so appears
            to assume the default stride of 2.

    Returns:
        (yl, yh): ``yl`` has shape (batch, C, 2*H', 2*W') and interleaves the
        four tree lowpass images on a 2x2 grid. ``yh`` stacks the six oriented
        bands (15, 45, 75, 105, 135, 165 degrees, per the variable names) along
        dim 1, with a trailing axis of size 2 holding (real, imag).

    Raises:
        ValueError: if ``mode`` is not a supported padding mode.
    """
    x = x / 2
    C = x.shape[1]
    # One copy of each filter bank per input channel for the grouped convs.
    cols = torch.cat([cols] * C, dim=0)
    rows = torch.cat([rows] * C, dim=0)

    if mode in ('per', 'periodization'):
        # Do column filtering (along dim 2)
        L = cols.shape[2]
        L2 = L // 2
        if x.shape[2] % 2 == 1:
            # Odd height: replicate the last row so the length is even.
            x = torch.cat((x, x[:, :, -1:]), dim=2)
        N2 = x.shape[2] // 2
        x = roll(x, -L2, dim=2)
        pad = (L - 1, 0)
        lohi = F.conv2d(x, cols, padding=pad, stride=(stride, 1), groups=C)
        # Fold the overhang back onto the start to emulate circular conv.
        lohi[:, :, :L2] = lohi[:, :, :L2] + lohi[:, :, N2:N2 + L2]
        lohi = lohi[:, :, :N2]

        # Do row filtering (along dim 3)
        L = rows.shape[3]
        L2 = L // 2
        if lohi.shape[3] % 2 == 1:
            lohi = torch.cat((lohi, lohi[:, :, :, -1:]), dim=3)
        # Use the (possibly padded) width of lohi, not of x: for odd widths x
        # is one column narrower, which would drop an output column and fold
        # the wrap-around from the wrong offset.
        N2 = lohi.shape[3] // 2
        lohi = roll(lohi, -L2, dim=3)
        pad = (0, L - 1)
        w = F.conv2d(lohi, rows, padding=pad, stride=(1, stride), groups=8 * C)
        w[:, :, :, :L2] = w[:, :, :, :L2] + w[:, :, :, N2:N2 + L2]
        w = w[:, :, :, :N2]
    elif mode == 'zero':
        # Do column filtering
        N = x.shape[2]
        L = cols.shape[2]
        outsize = pywt.dwt_coeff_len(N, L, mode='zero')
        p = 2 * (outsize - 1) - N + L

        # conv2d only pads the same amount before and after; for an odd total
        # pad, add the extra zero row at the end first.
        if p % 2 == 1:
            x = F.pad(x, (0, 0, 0, 1))
        pad = (p // 2, 0)
        lohi = F.conv2d(x, cols, padding=pad, stride=(stride, 1), groups=C)

        # Do row filtering
        N = lohi.shape[3]
        L = rows.shape[3]
        outsize = pywt.dwt_coeff_len(N, L, mode='zero')
        p = 2 * (outsize - 1) - N + L
        if p % 2 == 1:
            lohi = F.pad(lohi, (0, 1, 0, 0))
        pad = (0, p // 2)
        w = F.conv2d(lohi, rows, padding=pad, stride=(1, stride), groups=8 * C)
    elif mode in ('symmetric', 'reflect'):
        # Do column filtering: pad explicitly (top, bottom) then convolve.
        N = x.shape[2]
        L = cols.shape[2]
        outsize = pywt.dwt_coeff_len(N, L, mode=mode)
        p = 2 * (outsize - 1) - N + L
        x = mypad(x, pad=(0, 0, p // 2, (p + 1) // 2), mode=mode)
        lohi = F.conv2d(x, cols, stride=(stride, 1), groups=C)

        # Do row filtering: pad explicitly (left, right) then convolve.
        N = lohi.shape[3]
        L = rows.shape[3]
        outsize = pywt.dwt_coeff_len(N, L, mode=mode)
        p = 2 * (outsize - 1) - N + L
        lohi = mypad(lohi, pad=(p // 2, (p + 1) // 2, 0, 0), mode=mode)
        w = F.conv2d(lohi, rows, stride=(1, stride), groups=8 * C)
    else:
        raise ValueError("Unknown pad type: {}".format(mode))

    # Channels split as (C, 4 trees, 4 subbands); subband 0 is the lowpass.
    y = w.view((w.shape[0], C, 4, 4, w.shape[-2], w.shape[-1]))
    yl = y[:, :, :, 0]
    yh = y[:, :, :, 1:]
    # Pair trees with pm (defined elsewhere; presumably scaled sum and
    # difference -- TODO confirm) to form real/imag parts of each orientation.
    deg75r, deg105i = pm(yh[:, :, 0, 0], yh[:, :, 3, 0])
    deg105r, deg75i = pm(yh[:, :, 1, 0], yh[:, :, 2, 0])
    deg15r, deg165i = pm(yh[:, :, 0, 1], yh[:, :, 3, 1])
    deg165r, deg15i = pm(yh[:, :, 1, 1], yh[:, :, 2, 1])
    deg135r, deg45i = pm(yh[:, :, 0, 2], yh[:, :, 3, 2])
    deg45r, deg135i = pm(yh[:, :, 1, 2], yh[:, :, 2, 2])
    yhr = torch.stack((deg15r, deg45r, deg75r, deg105r, deg135r, deg165r),
                      dim=1)
    yhi = torch.stack((deg15i, deg45i, deg75i, deg105i, deg135i, deg165i),
                      dim=1)
    yh = torch.stack((yhr, yhi), dim=-1)

    # Interleave the four tree lowpass images on a 2x2 grid:
    #   out[2i, 2j] = tree 3, out[2i, 2j+1] = tree 2,
    #   out[2i+1, 2j] = tree 1, out[2i+1, 2j+1] = tree 0.
    yl_rowa = torch.stack((yl[:, :, 1], yl[:, :, 0]), dim=-1)
    yl_rowb = torch.stack((yl[:, :, 3], yl[:, :, 2]), dim=-1)
    yl_rowa = yl_rowa.view(yl.shape[0], C, yl.shape[-2], yl.shape[-1] * 2)
    yl_rowb = yl_rowb.view(yl.shape[0], C, yl.shape[-2], yl.shape[-1] * 2)
    z = torch.stack((yl_rowb, yl_rowa), dim=-2)
    yl = z.view(yl.shape[0], C, yl.shape[-2] * 2, yl.shape[-1] * 2)

    return yl.contiguous(), yh