def main():
    """Display a wavelet filter bank in two matplotlib figures.

    The first figure is a ``J`` x ``L`` grid holding one image per entry of
    ``filters_set['psi']``; the second shows the modulus of the transformed
    ``filters_set['phi']`` entry. Sizes are fixed to ``M = 32``, ``J = 3``
    and ``L = 8``. Nothing is returned.
    """
    M = 32
    J = 3
    L = 8
    filters_set = filter_bank(M, M, J, L=L)

    def centered_spectrum(packed):
        # Channels 0 and 1 of the last axis are combined as real and
        # imaginary parts, then transformed and shifted for display.
        real_part = packed[..., 0].numpy()
        imag_part = packed[..., 1].numpy()
        return np.fft.fftshift(fft2(real_part + 1j * imag_part))

    fig, axs = plt.subplots(J, L, sharex=True, sharey=True)
    fig.set_figheight(6)
    fig.set_figwidth(6)
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')

    # Filters are laid out row by row: row = index // L, column = index % L.
    for index, psi in enumerate(filters_set['psi']):
        spectrum = centered_spectrum(psi[0])
        scale, angle = divmod(index, L)
        cell = axs[scale, angle]
        cell.imshow(colorize(spectrum))
        cell.axis('off')
        cell.set_title("$j = {}$ \n $\\theta={}$".format(scale, angle))

    fig.suptitle(
        "Wavelets for each scales $j$ and angles $\\theta$ used."
        "\n Color saturation and color hue respectively denote complex magnitude and complex phase.",
        fontsize=13)
    fig.show()

    # Second figure: modulus of the 'phi' filter on an inverted gray colormap.
    plt.figure()
    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    plt.axis('off')
    plt.set_cmap('gray_r')
    lowpass = centered_spectrum(filters_set['phi'][0])
    plt.suptitle(
        "The corresponding low-pass filter, also known as scaling function.",
        fontsize=13)
    plt.imshow(np.abs(lowpass))
    plt.show()
def plot_wavelets(J, L, elems, it):
    """Draw the items produced by ``it(elems)`` on a ``J`` x ``L`` grid.

    Args:
        J: number of rows in the subplot grid.
        L: number of columns in the subplot grid.
        elems: object handed unchanged to ``it``.
        it: callable returning an iterable; every item must expose
            ``.numpy()``.

    Each item is transformed with ``fft2`` and ``fftshift`` and shown in
    grayscale in grid cell ``(index // L, index % L)``. Items whose spectrum
    has a zero-length axis, or a NaN in its first row, get a titled but
    empty cell. Nothing is returned.
    """
    grid_layout = {
        'width_ratios': [1] * L,
        'wspace': 0.5,
        'hspace': 0.5,
        'top': 0.95,
        'bottom': 0.05,
        'left': 0.1,
        'right': 0.95
    }
    fig, axs = plt.subplots(J, L, sharex=True, sharey=True,
                            gridspec_kw=grid_layout)
    fig.set_figheight(5)
    fig.set_figwidth(60)
    plt.rc('text', usetex=False)
    plt.rc('font', family='serif')

    for index, wavelet in enumerate(it(elems)):
        spectrum = np.fft.fftshift(fft2(wavelet.numpy()))
        row, col = divmod(index, L)
        # Only the first row is inspected for NaNs, as in the original check.
        undrawable = 0 in spectrum.shape or np.isnan(spectrum[0]).any()
        cell = axs[row, col]
        if not undrawable:
            # NOTE(review): if fft2 returns a complex array, astype(float)
            # keeps only the real part - confirm that this is intended.
            cell.imshow(normalize(spectrum.astype(float)), cmap='gray')
            # cell.imshow(colorize(spectrum))
        cell.axis('off')
        cell.set_title("j = {} \n theta={}".format(row, col))
        cell.title.set_fontsize(12)

    plt.tight_layout()
    plt.margins(0, 0)
    # plt.gca() is the current axes only, so the tick locators are removed
    # on that single axes.
    current_axes = plt.gca()
    current_axes.xaxis.set_major_locator(plt.NullLocator())
    current_axes.yaxis.set_major_locator(plt.NullLocator())
    fig.suptitle((r""), fontsize=13)
    fig.show()
# Blank out the first nZeros rows of `mutilated`, then show both images.
# (`mutilated`, `nZeros` and `gaussianExample` are defined earlier in the
# script.)
mutilated[0:nZeros, :] = np.zeros((nZeros, mutilated.shape[1]))
plt.imshow(gaussianExample)
plt.show()
plt.imshow(mutilated)
plt.show()

M = 512
J = 2
L = 7
filters_set = filter_bank(M, M, J, L=L)

# One M x M response per 'psi' filter, for each of the two images. The
# spatial size reuses M rather than repeating the literal 512.
n_filters = len(filters_set['psi'])
transformed = np.zeros((M, M, n_filters))
transMut = np.zeros((M, M, n_filters))

# The image spectra do not depend on the filter, so compute them once.
gaussian_hat = fft2(gaussianExample)
mutilated_hat = fft2(mutilated)

# NOTE: the loop variable shadows the builtin `filter`; the name is kept
# because later parts of the script are not visible from here.
for i, filter in enumerate(filters_set['psi']):
    f_r = filter[0][..., 0].numpy()
    f_i = filter[0][..., 1].numpy()
    f = f_r + 1j * f_i
    # The result arrays are real-valued (they are passed to plt.imshow
    # below), so keep the real part explicitly instead of relying on an
    # implicit complex-to-float cast, which triggers a ComplexWarning.
    transformed[:, :, i] = np.fft.ifft2(gaussian_hat * f).real
    transMut[:, :, i] = np.fft.ifft2(mutilated_hat * f).real

# Show the modulus of the third 'psi' filter and one of the responses.
filter0 = filters_set['psi'][2]
filter0 = filter0[0][..., 0].numpy() + 1j * filter0[0][..., 1].numpy()
plt.imshow(abs(filter0))
plt.show()
plt.imshow(transformed[:, :, 4])
plt.show()

fig, axs = plt.subplots(J, L, sharex=True, sharey=True)
fig.set_figheight(6)
fig.set_figwidth(6)
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
# One extra row (J + 1) is reserved below the band-pass filters.
fig, axs = plt.subplots(J + 1, L, sharex=True, sharey=True)
plt.rc('text', usetex=True)
plt.rc('font', family='serif')

###############################################################################
# Bandpass filters
# ----------------
# First, we display each wavelet according to its scale and orientation.
# NOTE: the loop variable shadows the builtin `filter`; the name is kept
# because later parts of the script are not visible from here.
i = 0
for filter in filters_set['psi']:
    f_r = filter[0][..., 0].numpy()
    f_i = filter[0][..., 1].numpy()
    f = f_r + 1j * f_i
    filter_c = fft2(f)
    filter_c = np.fft.fftshift(filter_c)
    # Filters fill the grid row by row.
    row, col = divmod(i, L)
    ax = axs[row, col]
    ax.imshow(colorize(filter_c))
    ax.axis('off')
    # The closing `$` after the theta value keeps the math segment balanced.
    ax.set_title("$j = {}$ \n $\\theta={}$".format(row, col))
    i = i + 1

# Add blanks for pretty display: hide the axes of the next L cells.
for z in range(L):
    axs[i // L, i % L].axis('off')
    i = i + 1

###############################################################################
# Lowpass filter
# ----------------