# Esempio n. 1 (Example 1)
def quad_afb2d_nonsep(x, filts, mode='zero'):
    """ Does a 1 level 2d wavelet decomposition of an input. Doesn't do separate
    row and column filtering: a single grouped 2d convolution applies the whole
    2d filter bank.

    Inputs:
        x (torch.Tensor): Input to decompose, indexed as (N, C, H, W).
        filts (torch.Tensor): Prepared 2d filter bank. It is repeated once per
            input channel and used in a grouped ``F.conv2d`` (``groups=C``)
            whose output is reshaped to ``(N, C, 4, H', W')``, so it must have
            shape ``(4, 1, Ly, Lx)``.
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization' ('per').
            Which padding to use. If periodization, the output size will be
            half the input size (odd sizes are first extended by repeating the
            last row/column). Otherwise, the output size will be slightly
            larger than half.

    Returns:
        (yl, yh): yl is subband 0 of the 4 filter outputs, shape
        (N, C, H', W'); yh holds subbands 1..3, shape (N, C, 3, H', W').
        Both are contiguous.

    Raises:
        ValueError: if ``mode`` is not one of the supported padding modes.
    """
    C = x.shape[1]
    Ny = x.shape[2]
    Nx = x.shape[3]

    # Repeat the filter bank once per input channel so one grouped
    # convolution (groups=C) filters every channel independently.
    f = torch.cat([filts] * C, dim=0)
    Ly = f.shape[2]
    Lx = f.shape[3]

    if mode in ('periodization', 'per'):
        # Periodization works on even sizes; extend odd dimensions by
        # repeating the last row/column.
        if x.shape[2] % 2 == 1:
            x = torch.cat((x, x[:, :, -1:]), dim=2)
            Ny += 1
        if x.shape[3] % 2 == 1:
            x = torch.cat((x, x[:, :, :, -1:]), dim=3)
            Nx += 1
        pad = (Ly - 1, Lx - 1)
        stride = (2, 2)
        x = roll(roll(x, -Ly // 2, dim=2), -Lx // 2, dim=3)
        y = F.conv2d(x, f, padding=pad, stride=stride, groups=C)
        # Wrap the outputs that overhang the end back onto the start, then
        # crop to half the (even) input size.
        y[:, :, :Ly // 2] += y[:, :, Ny // 2:Ny // 2 + Ly // 2]
        y[:, :, :, :Lx // 2] += y[:, :, :, Nx // 2:Nx // 2 + Lx // 2]
        y = y[:, :, :Ny // 2, :Nx // 2]
    elif mode in ('zero', 'symmetric', 'reflect'):
        # Total padding per axis so that the stride-2 convolution yields the
        # number of coefficients reported by pywt.dwt_coeff_len.
        out1 = pywt.dwt_coeff_len(Ny, Ly, mode=mode)
        out2 = pywt.dwt_coeff_len(Nx, Lx, mode=mode)
        p1 = 2 * (out1 - 1) - Ny + Ly
        p2 = 2 * (out2 - 1) - Nx + Lx
        if mode == 'zero':
            # F.conv2d only pads the same amount before and after, so when
            # the total padding is odd, add the extra sample at the end
            # (bottom/right) up front.
            if p1 % 2 == 1 and p2 % 2 == 1:
                x = F.pad(x, (0, 1, 0, 1))
            elif p1 % 2 == 1:
                x = F.pad(x, (0, 0, 0, 1))
            elif p2 % 2 == 1:
                x = F.pad(x, (0, 1, 0, 0))
            # Calculate the high and lowpass outputs
            y = F.conv2d(x, f, padding=(p1 // 2, p2 // 2), stride=2, groups=C)
        else:
            # symmetric / reflect: pad explicitly, then convolve unpadded.
            pad = (p2 // 2, (p2 + 1) // 2, p1 // 2, (p1 + 1) // 2)
            x = mypad(x, pad=pad, mode=mode)
            y = F.conv2d(x, f, stride=2, groups=C)
    else:
        raise ValueError("Unkown pad type: {}".format(mode))

    # Split the 4*C output channels into (N, C, 4, H', W').
    y = y.reshape((y.shape[0], C, 4, y.shape[-2], y.shape[-1]))
    yl = y[:, :, 0].contiguous()
    yh = y[:, :, 1:].contiguous()
    return yl, yh
# Esempio n. 2 (Example 2)
def quad_afb2d(x, cols, rows, mode='zero', split=True, stride=2):
    """ Does a single level 2d wavelet decomposition of an input. Does separate
    row and column filtering: one grouped convolution along dim 2 (columns)
    followed by one along dim 3 (rows).

    Inputs:
        x (torch.Tensor): Input to decompose, indexed as (N, C, H, W). It is
            halved before filtering.
        cols (torch.Tensor): Column filter bank. It is repeated once per input
            channel and applied along dim 2 with ``groups=C``; its filter
            length is ``cols.shape[2]``. The row stage uses ``groups=8 * C``,
            so the column stage must produce a multiple of 8 channels per
            input channel (presumably ``cols`` is ``(8, 1, L, 1)`` -- confirm
            against the filter preparation code).
        rows (torch.Tensor): Row filter bank. It is repeated once per input
            channel and applied along dim 3 with ``groups=8 * C``; its filter
            length is ``rows.shape[3]``. The result is viewed as
            ``(N, C, 4, 4, H', W')``, so ``rows.shape[0]`` must be 16.
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization' ('per').
            Which padding to use. If periodization, the output size will be
            half the input size (odd sizes are first extended by repeating the
            last row/column). Otherwise, the output size will be slightly
            larger than half.
        split (bool): Not used by this implementation; kept for interface
            compatibility.
        stride (int): Convolution stride along the filtered dimension. The
            periodization branch crops its output to half the input length,
            so it is only consistent with the default ``stride=2``.

    Returns:
        (yl, yh):
            yl: Tensor of shape (N, C, 2H', 2W'). The four ``y[:, :, k, 0]``
                subbands are interleaved on a 2x2 grid: even rows hold k=3 /
                k=2 at even / odd columns, odd rows hold k=1 / k=0.
            yh: Tensor of shape (N, C, 6, H', W', 2), assuming ``pm`` returns
                tensors shaped like its inputs. The six entries along dim 2
                are the combinations named deg15 ... deg165 below, and the
                last dim holds their (real, imag) parts.

    Raises:
        ValueError: if ``mode`` is not one of the supported padding modes.
    """
    x = x / 2
    C = x.shape[1]
    # Repeat each filter bank once per input channel for the grouped convs.
    cols = torch.cat([cols] * C, dim=0)
    rows = torch.cat([rows] * C, dim=0)

    if mode in ('per', 'periodization'):
        # Column filtering (dim 2): extend an odd height by repeating the last
        # row, shift, convolve, wrap the overhang back onto the start and
        # crop to half the (even) height.
        L = cols.shape[2]
        L2 = L // 2
        if x.shape[2] % 2 == 1:
            x = torch.cat((x, x[:, :, -1:]), dim=2)
        N2 = x.shape[2] // 2
        x = roll(x, -L2, dim=2)
        pad = (L - 1, 0)
        lohi = F.conv2d(x, cols, padding=pad, stride=(stride, 1), groups=C)
        lohi[:, :, :L2] = lohi[:, :, :L2] + lohi[:, :, N2:N2 + L2]
        lohi = lohi[:, :, :N2]

        # Row filtering (dim 3), same scheme as for the columns.
        L = rows.shape[3]
        L2 = L // 2
        if lohi.shape[3] % 2 == 1:
            lohi = torch.cat((lohi, lohi[:, :, :, -1:]), dim=3)
        # Use lohi's (possibly extended) width: x is never extended along
        # dim 3, so x.shape[3] // 2 would be one short for odd widths and
        # drop the last output column.
        N2 = lohi.shape[3] // 2
        lohi = roll(lohi, -L2, dim=3)
        pad = (0, L - 1)
        w = F.conv2d(lohi, rows, padding=pad, stride=(1, stride), groups=8 * C)
        w[:, :, :, :L2] = w[:, :, :, :L2] + w[:, :, :, N2:N2 + L2]
        w = w[:, :, :, :N2]
    elif mode == 'zero':
        # Column filtering (dim 2)
        N = x.shape[2]
        L = cols.shape[2]
        outsize = pywt.dwt_coeff_len(N, L, mode='zero')
        p = 2 * (outsize - 1) - N + L

        # F.conv2d only pads the same amount before and after, so an odd
        # total padding gets its extra sample added at the end up front.
        if p % 2 == 1:
            x = F.pad(x, (0, 0, 0, 1))
        pad = (p // 2, 0)
        lohi = F.conv2d(x, cols, padding=pad, stride=(stride, 1), groups=C)

        # Row filtering (dim 3)
        N = lohi.shape[3]
        L = rows.shape[3]
        outsize = pywt.dwt_coeff_len(N, L, mode='zero')
        p = 2 * (outsize - 1) - N + L
        if p % 2 == 1:
            lohi = F.pad(lohi, (0, 1, 0, 0))
        pad = (0, p // 2)
        w = F.conv2d(lohi, rows, padding=pad, stride=(1, stride), groups=8 * C)
    elif mode in ('symmetric', 'reflect'):
        # Column filtering (dim 2): pad explicitly, then convolve unpadded.
        N = x.shape[2]
        L = cols.shape[2]
        outsize = pywt.dwt_coeff_len(N, L, mode=mode)
        p = 2 * (outsize - 1) - N + L
        x = mypad(x, pad=(0, 0, p // 2, (p + 1) // 2), mode=mode)
        lohi = F.conv2d(x, cols, stride=(stride, 1), groups=C)

        # Row filtering (dim 3)
        N = lohi.shape[3]
        L = rows.shape[3]
        outsize = pywt.dwt_coeff_len(N, L, mode=mode)
        p = 2 * (outsize - 1) - N + L
        lohi = mypad(lohi, pad=(p // 2, (p + 1) // 2, 0, 0), mode=mode)
        w = F.conv2d(lohi, rows, stride=(1, stride), groups=8 * C)
    else:
        raise ValueError("Unkown pad type: {}".format(mode))

    # View the 16*C output channels as (N, C, 4, 4, H', W'). Index 0 of the
    # 4th dim feeds the lowpass output, indices 1..3 the highpass output.
    y = w.view((w.shape[0], C, 4, 4, w.shape[-2], w.shape[-1]))
    yl = y[:, :, :, 0]
    yh = y[:, :, :, 1:]

    # Combine pairs of trees with pm() into six oriented subbands; the degree
    # labels come from the variable names.
    deg75r, deg105i = pm(yh[:, :, 0, 0], yh[:, :, 3, 0])
    deg105r, deg75i = pm(yh[:, :, 1, 0], yh[:, :, 2, 0])
    deg15r, deg165i = pm(yh[:, :, 0, 1], yh[:, :, 3, 1])
    deg165r, deg15i = pm(yh[:, :, 1, 1], yh[:, :, 2, 1])
    deg135r, deg45i = pm(yh[:, :, 0, 2], yh[:, :, 3, 2])
    deg45r, deg135i = pm(yh[:, :, 1, 2], yh[:, :, 2, 2])
    # NOTE(review): dim=2 is assumed so that yh is (N, C, 6, H', W', 2),
    # matching the channel-first layout of yl; confirm against callers.
    yhr = torch.stack((deg15r, deg45r, deg75r, deg105r, deg135r, deg165r),
                      dim=2)
    yhi = torch.stack((deg15i, deg45i, deg75i, deg105i, deg135i, deg165i),
                      dim=2)
    yh = torch.stack((yhr, yhi), dim=-1)

    # Interleave the four lowpass subbands on a 2x2 grid:
    #   even rows: subband 3 | subband 2, odd rows: subband 1 | subband 0.
    yl_rowa = torch.stack((yl[:, :, 1], yl[:, :, 0]), dim=-1)
    yl_rowb = torch.stack((yl[:, :, 3], yl[:, :, 2]), dim=-1)
    yl_rowa = yl_rowa.view(yl.shape[0], C, yl.shape[-2], yl.shape[-1] * 2)
    yl_rowb = yl_rowb.view(yl.shape[0], C, yl.shape[-2], yl.shape[-1] * 2)
    z = torch.stack((yl_rowb, yl_rowa), dim=-2)
    yl = z.view(yl.shape[0], C, yl.shape[-2] * 2, yl.shape[-1] * 2)

    return yl.contiguous(), yh