def wavelets():
    """Plot every component of every wavelet in a small filter bank.

    Builds a bank with ``filter_bank(M, M, J, L=L)`` and calls
    ``plot_wavelets`` six times. Each call gets a selector function that maps
    the list of ``psi`` entries to a lazy iterator of 2-D slices ``f[key][..., part]``:

    * ``key`` is 0, 1 or 2. These are the integer keys of each psi entry,
      presumably successive resolutions; confirm against ``filter_bank``.
    * ``part`` is 0 or 1. The same last-axis layout is combined as
      real + 1j * imag elsewhere in this file.

    The original selectors were lambdas of the form
    ``[(yield x) for f in elems]``, which is a SyntaxError on Python >= 3.8.
    On 3.7 they evaluated to plain generators, so the generator expressions
    below preserve that behaviour.
    """
    M = 7
    J = 3
    L = 32
    filters_set = filter_bank(M, M, J, L=L)
    # NOTE(review): the axes are not passed to plot_wavelets; presumably it
    # draws into the current figure. Confirm before removing this setup.
    fig, _ = plt.subplots(J,
                          L,
                          sharex=True,
                          sharey=True,
                          gridspec_kw={
                              'width_ratios': [1] * L,
                              'wspace': 0.5,
                              'hspace': 0.5,
                              'top': 0.95,
                              'bottom': 0.05,
                              'left': 0.1,
                              'right': 0.95
                          })
    fig.set_figheight(5)
    fig.set_figwidth(60)
    plt.rc('text', usetex=False)
    plt.rc('font', family='serif')

    def make_selector(key, part, min_len):
        """Return a function yielding ``f[key][..., part]`` for each entry ``f``.

        When ``min_len`` is not None, only entries with ``len(f) > min_len``
        are kept. Presumably this skips psi entries that lack the requested
        key; verify against filter_bank's dict layout.
        """
        def select(elems):
            return (f[key][..., part]
                    for f in elems
                    if min_len is None or len(f) > min_len)
        return select

    # (key, last-axis component, minimum entry length or None for "no filter"),
    # in the same order as the original six plot calls.
    selector_specs = (
        (0, 0, None),
        (0, 1, None),
        (1, 0, 3),
        (1, 1, 3),
        (2, 0, 4),
        (2, 1, 4),
    )
    for key, part, min_len in selector_specs:
        plot_wavelets(J, L, filters_set['psi'], make_selector(key, part, min_len))
# Example No. 2 ("Ejemplo n.º 2")
# 0
def main():
    """Show the wavelet filter bank and its low-pass filter.

    Builds a bank of wavelets with ``filter_bank(M, M, J, L=L)`` and opens
    two figures:

    1. A ``J`` x ``L`` grid with one panel per wavelet. Each panel shows the
       fftshift-ed ``fft2`` of the complex filter, rendered with ``colorize``.
       Per the figure caption, saturation encodes magnitude and hue encodes
       phase.
    2. The magnitude of the fftshift-ed ``fft2`` of the low-pass filter
       ``phi``, drawn in the ``gray_r`` colormap.

    Returns nothing; figures are displayed through matplotlib. The titles use
    ``usetex=True``, so a working LaTeX installation is required.
    """
    M = 32
    J = 3
    L = 8
    filters_set = filter_bank(M, M, J, L=L)

    def as_complex(component):
        # Last axis index 0 is the real part and index 1 the imaginary part.
        # ``.numpy()`` suggests torch tensors; confirm against filter_bank.
        return component[..., 0].numpy() + 1j * component[..., 1].numpy()

    def centered_fft(image):
        # fftshift moves the zero-frequency term to the centre of the image.
        return np.fft.fftshift(fft2(image))

    fig, axs = plt.subplots(J, L, sharex=True, sharey=True)
    fig.set_figheight(6)
    fig.set_figwidth(6)
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')

    for index, psi in enumerate(filters_set['psi']):
        # Assumes filter_bank emits the L orientations of each scale
        # consecutively (scale-major order), as the panel titles imply.
        j, theta = divmod(index, L)
        ax = axs[j, theta]
        ax.imshow(colorize(centered_fft(as_complex(psi[0]))))
        ax.axis('off')
        ax.set_title("$j = {}$ \n $\\theta={}$".format(j, theta))

    fig.suptitle(
        "Wavelets for each scales $j$ and angles $\\theta$ used."
        "\n Color saturation and color hue respectively denote complex magnitude and complex phase.",
        fontsize=13)
    fig.show()

    plt.figure()
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    plt.axis('off')
    plt.set_cmap('gray_r')

    phi_spectrum = centered_fft(as_complex(filters_set['phi'][0]))
    plt.suptitle(
        "The corresponding low-pass filter, also known as scaling function.",
        fontsize=13)
    plt.imshow(np.abs(phi_spectrum))
    plt.show()

# NOTE(review): the return value of rot() is discarded; confirm this call is
# needed for a side effect.
rot(0.0)

# Grid size shared by the test image, the filter bank and the output arrays.
M = 512

# Test image: f() evaluated on an M-point grid over [-1, 1] (presumably an
# anisotropic Gaussian, going by the name; confirm against f), plus a copy
# whose first nZeros rows are zeroed out.
xlocs = np.linspace(-1, 1, num=M)
gaussianExample = f(xlocs, xlocs, .6, .3, np.pi / 6)
mutilated = np.copy(gaussianExample)
nZeros = 12
mutilated[:nZeros, :] = 0
plt.imshow(gaussianExample)
plt.show()
plt.imshow(mutilated)
plt.show()

J = 2
L = 7
filters_set = filter_bank(M, M, J, L=L)
transformed = np.zeros((M, M, len(filters_set['psi'])))
transMut = np.zeros((M, M, len(filters_set['psi'])))

# The image spectra do not depend on the filter, so compute them once.
gaussianSpectrum = fft2(gaussianExample)
mutilatedSpectrum = fft2(mutilated)

# Use distinct loop names so the module-level builtin ``filter`` and the
# function ``f`` called above are not rebound.
for i, psiEntry in enumerate(filters_set['psi']):
    psiComplex = psiEntry[0][..., 0].numpy() + 1j * psiEntry[0][..., 1].numpy()
    # Pointwise product of spectra followed by ifft2, i.e. a circular
    # convolution. The output arrays are real, so only the real part is kept.
    # Writing the complex result directly would drop the imaginary part
    # implicitly (with a ComplexWarning).
    # NOTE(review): the modulus may have been intended; confirm.
    transformed[:, :, i] = np.fft.ifft2(gaussianSpectrum * psiComplex).real
    transMut[:, :, i] = np.fft.ifft2(mutilatedSpectrum * psiComplex).real

filter0 = filters_set['psi'][2][0]
filter0 = filter0[..., 0].numpy() + 1j * filter0[..., 1].numpy()
plt.imshow(np.abs(filter0))
plt.show()

plt.imshow(transformed[:, :, 4])