def bandpass_denoising(a, name, save_flag=False,
                       save_dir='/Users/apple/Desktop/Lab/Zach_Project/Denoising_Result/Bandpass/'):
    """Denoise a volume by masking its shifted Fourier spectrum with a centred box.

    The spectrum of ``a`` is computed with ``fftn`` and centred with ``fftshift``.
    It is multiplied by a 0/1 mask that keeps only the central box spanning
    ``[size // 8 * 3, size // 8 * 5)`` along each of the first three axes.
    The result is transformed back and its real part is kept.

    NOTE(review): because ``fftshift`` places the zero frequency at the centre,
    this box keeps the lowest frequencies. Despite the name, it behaves like a
    low-pass box filter. Adjust the mask if a true band is wanted.

    :param a: input volume. Three axes are indexed below, so ``a`` is assumed
        to be 3-D (TODO confirm for callers).
    :param name: identifier used to build the output file names.
    :param save_flag: when true, write the central z-slice as a PNG and the
        filtered volume as an MRC file into ``save_dir``.
    :param save_dir: output directory. It defaults to the historical hard-coded
        location so existing callers behave as before.
    :return: copy of the central slice ``vf[:, :, vf.shape[2] // 2]`` of the
        filtered volume. Previously this was only defined when ``save_flag`` was
        true; the function raised UnboundLocalError otherwise.
    """
    b_time = time.time()

    # Only the shape of the radius grid is used to build the mask. The former
    # round/astype(np.int) step was unnecessary, and np.int no longer exists in
    # NumPy >= 1.24.
    grid = GV.grid_displacement_to_center(a.shape, GV.fft_mid_co(a.shape))
    mask_shape = GV.grid_distance_to_center(grid).shape

    # Mask that keeps only the central box of the shifted spectrum.
    # TODO: change the curve value as desired
    curve = np.zeros(mask_shape)
    lo = [dim // 8 * 3 for dim in mask_shape]
    hi = [dim // 8 * 5 for dim in mask_shape]
    curve[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]] = 1

    # FFT -> centre -> mask -> un-centre -> inverse FFT; keep the real part.
    vf = ifftn(ifftshift(fftshift(fftn(a)) * curve))
    vf = np.real(vf)

    end_time = time.time()
    print('Bandpass de-noise takes', end_time - b_time, 's')

    # Always compute the returned slice so the return value is defined even
    # when nothing is saved.
    img = vf[:, :, vf.shape[2] // 2].copy()

    if save_flag:
        import os
        stem = os.path.join(save_dir, str(name) + '_BP')
        plt.imsave(stem + '.png', img, cmap='gray')
        io_file.put_mrc_data(vf, stem + '.mrc')

    return img
def sphere_mask(shape, center=None, radius=None, smooth_sigma=None):
    """Return a volume of ``shape`` that is 1 inside a sphere and 0 outside.

    :param shape: size of the output volume along each axis.
    :param center: sphere centre. Defaults to ``(shape - 1) / 2``, the true
        symmetry centre when indices run from 0 to size-1.
    :param radius: sphere radius. Defaults to half of the smallest dimension.
    :param smooth_sigma: if given (must be > 0), voxels at or beyond the radius
        get a Gaussian fall-off ``exp(-((dist - radius) / smooth_sigma) ** 2)``.
        Values below ``exp(-3)`` are clipped to 0.
    :return: float array of the given shape.
    """
    dims = N.array(shape)
    mask = N.zeros(dims)

    # Python indices run 0..size-1, so (size-1)/2 is the real symmetry centre.
    centre = (dims - 1) / 2.0 if center is None else center
    centre = N.array(centre)

    r = N.min(dims / 2.0) if radius is None else radius

    # NOTE(review): this function uses the lowercase alias ``gv`` whereas other
    # helpers use ``GV``; confirm that ``gv`` is imported in this module.
    offsets = gv.grid_displacement_to_center(dims, mid_co=centre)
    dist = gv.grid_distance_to_center(offsets)

    mask[dist <= r] = 1.0

    if smooth_sigma is None:
        return mask

    assert smooth_sigma > 0
    falloff = N.exp(-((dist - r) / smooth_sigma) ** 2)
    # A cutoff at exp(-3) looks nicer, although the tom toolbox uses exp(-2).
    falloff[falloff < N.exp(-3)] = 0.0

    outside = dist >= r
    mask[outside] = falloff[outside]
    return mask
def filter_given_curve(v, curve):
    """Filter ``v`` in Fourier space with a radially symmetric transfer curve.

    Each voxel of the centred spectrum is scaled by ``curve[r]``, where ``r`` is
    its rounded distance from the FFT centre. Voxels whose rounded radius is not
    a valid index into ``curve`` (``r >= len(curve)``) are set to 0.

    :param v: input volume.
    :param curve: sequence of per-radius gains. Index ``i`` applies to radius ``i``.
    :return: real part of the inverse FFT of the filtered spectrum.
    """
    grid = GV.grid_displacement_to_center(v.shape, GV.fft_mid_co(v.shape))
    # ``N.int`` was removed in NumPy 1.24; the builtin ``int`` is the equivalent.
    rad = N.round(GV.grid_distance_to_center(grid)).astype(int)

    # Vectorized lookup. This replaces one boolean mask per curve entry with a
    # single fancy-index. ``list`` keeps support for arbitrary iterables, as
    # ``enumerate`` did.
    gains = N.asarray(list(curve))
    b = N.zeros(rad.shape)
    valid = (rad >= 0) & (rad < len(gains))
    b[valid] = gains[rad[valid]]

    vf = ifftn(ifftshift(fftshift(fftn(v)) * b))
    vf = N.real(vf)
    return vf
def ssnr__get_rad(siz):
    """Return, for every voxel of a ``siz``-shaped grid, its distance from the
    FFT mid coordinate given by ``GV.fft_mid_co(siz)``."""
    mid = GV.fft_mid_co(siz)
    return GV.grid_distance_to_center(GV.grid_displacement_to_center(siz, mid))