def resolution_curves_subplot(ax, data_to_plot, x_idx=0, disable_ax_labels=False, line_style='-'):
    """
    Does the actual work of the plot_resolution_curves function, but requires
    an axis as an input. It is useful e.g. when one desires to have several
    subplots.

    :param ax: plot axis (as in matplotlib subplot instance)
    :param data_to_plot: a FourierCorrelationDataCollection object with the FRC
                         results; iterating it yields (angle, dataset) pairs
    :param x_idx: if the data contains FRC results with different dynamic range
                  (pixel size), indicate here the index of the dataset with the
                  maximum range. The index refers to the datasets sorted by
                  angle. That dataset's threshold curve is the one drawn.
    :param disable_ax_labels: disable y- and x-axis labels (Correlation,
                              Frequency), in case you want to add them later
                              for instance
    :param line_style: matplotlib line style used for the curve-fit lines
    :return: the matplotlib axis ``ax``, which you can use to make further
             modifications to the plot
    :raises ValueError: if ``data_to_plot`` contains no datasets
    """
    assert isinstance(data_to_plot, FourierCorrelationDataCollection)

    # Sort datasets by angle. The sort key is the angle alone, so entries with
    # equal angles keep their collection order instead of falling back to
    # comparing the dataset objects themselves.
    ordered = sorted(((int(item[0]), item[1]) for item in data_to_plot),
                     key=lambda pair: pair[0])
    if not ordered:
        raise ValueError("data_to_plot does not contain any datasets")
    datasets = [dataset for _, dataset in ordered]

    # Plot the threshold curve of the selected dataset, extended to the
    # normalized frequency 1.0 if the data stops short of it.
    dataset = datasets[int(x_idx)]
    y = dataset.resolution["threshold"]
    x = dataset.correlation["frequency"]
    if x[-1] < 1.0:
        x = np.append(x, 1.0)
        y = np.append(y, y[-1])
    x_axis = arrayops.safe_divide(x, 2 * dataset.resolution["spacing"])
    ax.plot(x_axis, y, linestyle='--', color='#b5b5b3')

    if not disable_ax_labels:
        ax.set_xlabel(r'Frequency ($\mathrm{\mu m}^{-1}$)')
        ax.set_ylabel('Correlation')

    ax.set_ylim([0, 1.2])

    # Plot the curve fit of every dataset, in angle order.
    for dataset in datasets:
        y = dataset.correlation["curve-fit"]
        x = dataset.correlation["frequency"]
        x_axis = arrayops.safe_divide(x, 2 * dataset.resolution["spacing"])
        ax.plot(x_axis, y, linestyle=line_style)

    return ax
def calculate_snr_threshold_value(points_x_bin, snr):
    """
    Calculate an SNR-based resolution threshold for every FRC/FSC ring/shell.

    With ``n`` the number of points in a ring/shell, the threshold is::

        T(n) = (snr + (2*sqrt(snr) + 1) / sqrt(n)) / (snr + 1 + 2*sqrt(snr) / sqrt(n))

    All divisions go through ``arrayutils.safe_divide``.

    :param points_x_bin: a 1D array containing the number of points at each
                         FRC/FSC ring/shell
    :param snr: the expected SNR value
    :return: the threshold value for every ring/shell
    """
    sqrt_points = np.sqrt(points_x_bin)
    sqrt_snr = np.sqrt(snr)

    numerator = snr + arrayutils.safe_divide(2.0 * sqrt_snr + 1, sqrt_points)
    denominator = snr + 1 + arrayutils.safe_divide(2.0 * sqrt_snr, sqrt_points)

    return arrayutils.safe_divide(numerator, denominator)
def calculate_fourier_plane_correlation(image1, image2, args, z_correction=1):
    """
    Calculate a Fourier-plane correlation between two 3D images at a series
    of rotation angles, and run the FSC analysis on the result.

    For each angle in ``np.arange(0, 360, args.d_angle)``, both images are
    rotated (``reshape=False``), Fourier transformed and centred with
    ``fftshift``. The normalized cross-correlation is then summed over axes
    0 and 2. Only the second half of the resulting 1D curve (from the centre
    index onwards) is kept.

    :param image1: first image; must expose ``spacing`` (``spacing[0]`` is
                   passed to the analysis)
    :param image2: second image, same shape as ``image1``
    :param args: options object; ``args.d_angle`` is the rotation step, and
                 the object is also forwarded to
                 ``fsc_analysis.FourierCorrelationAnalysis``
    :param z_correction: forwarded to ``FourierCorrelationAnalysis.execute``
    :return: the result of ``FourierCorrelationAnalysis.execute``
    """
    steps = np.arange(0, 360, args.d_angle)
    data = containers.FourierCorrelationDataCollection()

    for step in steps:
        im1_rot = np.fft.fftshift(
            np.fft.fftn(rotate(image1, step, reshape=False)))
        im2_rot = np.fft.fftshift(
            np.fft.fftn(rotate(image2, step, reshape=False)))

        numerator = np.sum(im1_rot * np.conjugate(im2_rot), axis=(0, 2))
        denominator = np.sum(
            np.sqrt(np.abs(im1_rot) ** 2 * np.abs(im2_rot) ** 2), axis=(0, 2))

        correlation = ndarray.safe_divide(numerator, denominator)

        # Keep the half from the centre (zero frequency after fftshift)
        # onwards. Integer division is required: a float slice index raises
        # TypeError in Python 3.
        zero = correlation.size // 2
        correlation = correlation[zero:]

        result = containers.FourierCorrelationData()
        result.correlation["correlation"] = correlation
        result.correlation["frequency"] = np.linspace(0, 1.0, num=correlation.size)
        # Every bin sums over the same number of voxels (axes 0 and 2).
        result.correlation["points-x-bin"] = np.ones(
            correlation.size) * (im2_rot.shape[2] * im2_rot.shape[0])

        data[int(step)] = result

    analyzer = fsc_analysis.FourierCorrelationAnalysis(data, image1.spacing[0], args)
    return analyzer.execute(z_correction=z_correction)
def wiener_deconvolution(image, psf, snr=30, add_pad=0):
    """
    Filter an image with a linear Wiener filter built from a PSF (NumPy FFT).

    The filter applied in Fourier space is
    ``(|H|^2 / (|H|^2 + snr)) / H``, where ``H`` is the FFT of the
    max-normalized, ``fftshift``-ed PSF.

    :param image: the image to filter (Image)
    :param psf: the point-spread function (Image). It is resampled to the
                image spacing and zero-padded to the (padded) image shape
                when needed. The caller's object is not modified.
    :param snr: regularization constant added to ``|H|^2`` in the filter
                denominator
    :param add_pad: number of zero pixels added on both sides of every axis
                    before filtering; removed again from the result
    :return: the filtered image as an Image with ``image.spacing``
    :raises ValueError: if the image and PSF dimensions do not match
    """
    assert isinstance(image, Image)
    assert isinstance(psf, Image)

    image_s = Image(image.copy(), image.spacing)
    orig_shape = image.shape

    if image.ndim != psf.ndim:
        raise ValueError("Image and psf dimensions do not match")

    if psf.spacing != image.spacing:
        psf = imops.zoom_to_spacing(psf, image.spacing)

    if add_pad != 0:
        new_shape = [i + 2 * add_pad for i in image_s.shape]
        image_s = imops.zero_pad_to_shape(image_s, new_shape)

    if psf.shape != image_s.shape:
        psf = imops.zero_pad_to_shape(psf, image_s.shape)

    # Normalize out of place: when no resampling or padding was needed,
    # ``psf`` is still the caller's object and an in-place ``/=`` would
    # silently modify it.
    psf = psf / psf.max()

    psf_f = fftn(fftshift(psf))
    psf_power = np.abs(psf_f) ** 2
    wiener = arrayops.safe_divide(psf_power / (psf_power + snr), psf_f)

    image_f = fftn(image_s)
    result = Image(np.abs(ifftn(image_f * wiener).real), image.spacing)

    return imops.remove_zero_padding(result, orig_shape)
def calculate_resolution_threshold_curve(data_set, criterion, threshold, snr):
    """
    Calculate the resolution threshold curve for a single FRC/FSC data set.

    The FRC should be run first: all criteria except 'fixed' depend on the
    number of points on each Fourier ring/shell
    (``data_set.correlation["points-x-bin"]``).

    The evaluated curve is stored in ``data_set.resolution["threshold"]``.

    :param data_set: a FourierCorrelationData object with FRC results
    :param criterion: one of 'one-bit', 'half-bit', 'three-sigma', 'fixed'
                      or 'snr'
    :param threshold: the constant threshold value; only used with 'fixed'
    :param snr: the expected SNR; only used with 'snr'
    :return: ``None`` for 'fixed'. Otherwise a ``interp1d(kind='slinear')``
             callable mapping frequency to threshold value.
    :raises AttributeError: if ``criterion`` is not recognized
    """
    assert isinstance(data_set, FourierCorrelationData)

    points_x_bin = data_set.correlation["points-x-bin"]

    # An empty last ring/shell is given the point count of the previous one.
    # Note: this modifies the array stored in ``data_set`` in place.
    if points_x_bin[-1] == 0:
        points_x_bin[-1] = points_x_bin[-2]

    if criterion == 'one-bit':
        nominator = 0.5 + arrayutils.safe_divide(2.4142, np.sqrt(points_x_bin))
        denominator = 1.5 + arrayutils.safe_divide(1.4142, np.sqrt(points_x_bin))
        points = arrayutils.safe_divide(nominator, denominator)

    elif criterion == 'half-bit':
        nominator = 0.2071 + arrayutils.safe_divide(1.9102, np.sqrt(points_x_bin))
        denominator = 1.2071 + arrayutils.safe_divide(0.9102, np.sqrt(points_x_bin))
        points = arrayutils.safe_divide(nominator, denominator)

    elif criterion == 'three-sigma':
        points = arrayutils.safe_divide(np.full(points_x_bin.shape, 3.0),
                                        (np.sqrt(points_x_bin) + 3.0 - 1))

    elif criterion == 'fixed':
        points = threshold * np.ones(len(points_x_bin))

    elif criterion == 'snr':
        points = calculate_snr_threshold_value(points_x_bin, snr)

    else:
        raise AttributeError(
            "Unknown resolution criterion: {}".format(criterion))

    if criterion != 'fixed':
        # Piecewise-linear interpolation of the threshold over frequency.
        equation = interp1d(data_set.correlation["frequency"],
                            points, kind='slinear')
        curve = equation(data_set.correlation["frequency"])
    else:
        curve = points
        equation = None

    data_set.resolution["threshold"] = curve
    return equation
def wiener_deconvolution(image, psf, snr=30, add_pad=0):
    """
    A GPU accelerated implementation of a linear Wiener filter. Some effort is
    made to allow processing even relatively large images, but some kind of
    block-based processing (as in the RL implementation) may be required in
    some cases.

    The filter applied in Fourier space is
    ``(|H|^2 / (|H|^2 + snr)) / H``, where ``H`` is the FFT of the
    max-normalized, ``fftshift``-ed PSF.

    :param image: the image to filter (Image)
    :param psf: the point-spread function (Image). It is resampled to the
                image spacing and zero-padded to the (padded) image shape
                when needed. The caller's object is not modified.
    :param snr: regularization constant added to ``|H|^2`` in the filter
                denominator
    :param add_pad: number of zero pixels added on both sides of every axis
                    before filtering; removed again from the result
    :return: the filtered image as an Image with ``image.spacing``
    :raises ValueError: if the image and PSF dimensions do not match
    """
    assert isinstance(image, Image)
    assert isinstance(psf, Image)

    image_s = Image(image.copy(), image.spacing)
    orig_shape = image.shape

    if image.ndim != psf.ndim:
        raise ValueError("Image and psf dimensions do not match")

    if psf.spacing != image.spacing:
        psf = imops.zoom_to_spacing(psf, image.spacing)

    if add_pad != 0:
        new_shape = [i + 2 * add_pad for i in image_s.shape]
        image_s = imops.zero_pad_to_shape(image_s, new_shape)

    if psf.shape != image_s.shape:
        psf = imops.zero_pad_to_shape(psf, image_s.shape)

    # Normalize out of place: when no resampling or padding was needed,
    # ``psf`` is still the caller's object and an in-place ``/=`` would
    # silently modify it.
    psf = psf / psf.max()
    psf = fftshift(psf)

    # PSF spectrum H, computed on the device.
    psf_dev = cp.asarray(psf.astype(np.complex64))
    with get_fft_plan(psf_dev):
        psf_dev = fftn(psf_dev, overwrite_x=True)

    below = cp.asnumpy(psf_dev)            # H
    psf_abs = cp.abs(psf_dev) ** 2
    psf_abs /= (psf_abs + snr)
    above = cp.asnumpy(psf_abs)            # |H|^2 / (|H|^2 + snr)

    # Drop the PSF device arrays before uploading the image (presumably to
    # keep peak GPU memory use down).
    psf_abs = None
    psf_dev = None

    image_dev = cp.asarray(image_s.astype(np.complex64))
    with get_fft_plan(image_dev):
        image_dev = fftn(image_dev, overwrite_x=True)

    wiener_dev = cp.asarray(arrayops.safe_divide(above, below))
    image_dev *= wiener_dev

    result = cp.asnumpy(cp.abs(ifftn(image_dev, overwrite_x=True)).real)
    result = Image(result, image.spacing)
    return imops.remove_zero_padding(result, orig_shape)
def execute(self):
    """
    Calculate the FRC for every rotation angle.

    First, for every (ring, shell, rotation) step produced by
    ``self.iterator``, three sums are accumulated over the selected Fourier
    coefficients of ``self.fft_image1`` and ``self.fft_image2``:
    - the real part of the cross-correlation sum;
    - the power of each image;
    - the number of coefficients.

    Then an FRC curve, ``cross / sqrt(power1 * power2)``, is formed per
    rotation angle.

    :return: a FourierCorrelationDataCollection keyed by rotation angle. Each
             entry is a FourierCorrelationData with "correlation",
             "frequency" (radius / nyquist) and "points-x-bin" filled in.
    """
    data_structure = containers.FourierCorrelationDataCollection()
    radii, angles = self.iterator.steps
    nyquist = self.iterator.nyquist

    grid_shape = (angles.shape[0], radii.shape[0])
    cross, power1, power2, counts = (
        np.zeros(grid_shape, dtype=np.float32) for _ in range(4))

    # Accumulate the per-ring sums for every orientation.
    for ring, shell, rotation in self.iterator:
        ring1 = self.fft_image1[ring]
        ring2 = self.fft_image2[ring]

        cross[rotation, shell] = np.sum(ring1 * np.conjugate(ring2)).real
        power1[rotation, shell] = np.sum(np.abs(ring1) ** 2)
        power2[rotation, shell] = np.sum(np.abs(ring2) ** 2)
        counts[rotation, shell] = len(ring1)

    # Build one FRC result per rotation angle.
    for row, angle in enumerate(angles):
        result = containers.FourierCorrelationData()
        result.correlation["correlation"] = ndarray.safe_divide(
            cross[row], np.sqrt(power1[row] * power2[row]))
        result.correlation["frequency"] = radii.astype(np.float32) / nyquist
        result.correlation["points-x-bin"] = np.array(counts[row])

        data_structure[angle] = result

    return data_structure
def compute_estimate(self):
    """
    Calculates a single Richardson-Lucy (RL) estimate update, processing the
    image block by block.

    There is no reason to call this function -- it is used internally by the
    class during the iteration process.

    The per-block results are written into ``self.estimate_new``, which is
    then passed, together with ``self.estimate``, to
    ``ops_ext.update_estimate_poisson``. Per the original comment, that call
    updates the estimate in place and returns convergence statistics; its
    return value is passed through unchanged.
    """
    # Reset the buffer that collects the per-block results.
    self.estimate_new[:] = np.zeros(self.image_size, dtype=np.float32)

    # Iterate over blocks: one range of block start offsets per axis.
    iterables = (range(0, m, n) for m, n in zip(self.image_size, self.block_size))
    pad = self.options.block_pad
    # Slice that strips the padding from a processed (padded) block.
    block_idx = tuple(slice(pad, pad + block) for block in self.block_size)

    for pos in itertools.product(*iterables):

        # Location of this block in the full-size estimate.
        estimate_idx = tuple(slice(j, j + k) for j, k in zip(pos, self.block_size))
        index = np.array(pos, dtype=int)

        if self.options.block_pad > 0:
            h_estimate_block = self.get_padded_block(self.estimate, index.copy()).astype(np.complex64)
        else:
            h_estimate_block = self.estimate[estimate_idx].astype(np.complex64)

        # Execute: cache = convolve(PSF, estimate), non-normalized
        h_estimate_block_new = self._fft_convolve(h_estimate_block, self.psf_fft)

        # Execute: cache = data/cache. Add background bias if requested.
        h_image_block = self.get_padded_block(self.image, index.copy()).astype(np.float32)
        if self.options.rl_background != 0:
            h_image_block += self.options.rl_background
        ops_ext.inverse_division_inplace(h_estimate_block_new, h_image_block)

        # Execute correlation with PSF (uses the adjoint PSF spectrum).
        h_estimate_block_new = self._fft_convolve(h_estimate_block_new, self.adj_psf_fft).real

        # Store the un-padded part of the block result.
        self.estimate_new[estimate_idx] = h_estimate_block_new[block_idx]

    # TV Regularization (doesn't seem to do anything miraculous).
    # NOTE(review): the guard reads options.tv_lambda while the update uses
    # options.rltv_lambda -- confirm both names refer to the same setting.
    if self.options.tv_lambda > 0 and self.iteration_count > 0:
        dv_est = ops_ext.div_unit_grad(self.estimate, self.image_spacing)
        self.estimate_new = ops_array.safe_divide(self.estimate_new, (1.0 - self.options.rltv_lambda * dv_est))

    # Update estimate inplace. Get convergence statistics.
    return ops_ext.update_estimate_poisson(self.estimate, self.estimate_new, self.options.convergence_epsilon)
def compute_estimate(self):
    """
    Calculates a single multi-view RL fusion estimate. There is no reason to
    call this function -- it is used internally by the class during the
    fusion process.

    For every view, the current estimate is processed block by block:
    1. convolve it with the view's PSF;
    2. scale it by the view weight and add the view background;
    3. divide it into the registered image block;
    4. correlate the result with the adjoint PSF.

    Each block's contribution is handed to ``self._write_estimate_block``.
    The accumulated ``self.estimate_new`` is then normalized according to
    ``options.fusion_method``:
    - "summative": divided by the number of views;
    - otherwise: negatives are clipped to zero and the ``n_views``-th root is
      taken.

    :return: the value returned by ``ops_ext.update_estimate_poisson``
    """
    print(
        f'Beginning the computation of the {self.iteration_count + 1}. estimate'
    )

    if "multiplicative" in self.options.fusion_method:
        self.estimate_new[:] = numpy.ones(self.image_size, dtype=numpy.float32)
    else:
        self.estimate_new[:] = numpy.zeros(self.image_size, dtype=numpy.float32)

    # The block grid is identical for every view, so set it up once.
    pad = self.options.block_pad
    block_starts = [range(0, m, n) for m, n in zip(self.image_size, self.block_size)]
    # Slice that strips the padding from a processed (padded) block.
    block_idx = tuple(slice(pad, pad + block) for block in self.block_size)

    # Iterate over views
    for idx, view in enumerate(self.views):

        psf_fft = self.psfs_fft[idx]
        adj_psf_fft = self.adj_psfs_fft[idx]

        self.data.set_active_image(view, self.options.channel,
                                   self.options.scale, "registered")

        weighting = self.weights[idx]
        background = self.background[idx]

        for pos in itertools.product(*block_starts):

            estimate_idx = tuple(
                slice(j, j + k) for j, k in zip(pos, self.block_size))
            index = numpy.array(pos, dtype=int)

            if pad > 0:
                h_estimate_block = self.get_padded_block(
                    self.estimate, index.copy()).astype(numpy.complex64)
            else:
                h_estimate_block = self.estimate[estimate_idx].astype(
                    numpy.complex64)

            # Convolve estimate block with the PSF
            h_estimate_block_new = self._fft_convolve(h_estimate_block, psf_fft)

            # Apply view weighting and background
            h_estimate_block_new *= weighting
            h_estimate_block_new += background

            # Divide the registered image block by the convolution result.
            # The result lands in h_estimate_block_new (in-place helper).
            h_image_block = self.data.get_registered_block(
                self.block_size, self.options.block_pad,
                index.copy()).astype(numpy.float32)
            ops_ext.inverse_division_inplace(h_estimate_block_new, h_image_block)

            # Correlate with adj PSF
            h_estimate_block_new = self._fft_convolve(
                h_estimate_block_new, adj_psf_fft).real

            # Update the contribution from a single view to the new estimate
            self._write_estimate_block(h_estimate_block_new, estimate_idx, block_idx)

    # Combine the contributions of all views
    if "summative" in self.options.fusion_method:
        self.estimate_new *= (1.0 / self.n_views)
    else:
        self.estimate_new[self.estimate_new < 0] = 0
        self.estimate_new[:] = ops_array.nroot(self.estimate_new, self.n_views)

    # TV Regularization (doesn't seem to do anything miraculous).
    # The fused update (estimate_new) is what gets divided. Dividing
    # self.estimate here would throw away everything computed above in this
    # iteration.
    # NOTE(review): the guard reads options.tv_lambda while the update uses
    # options.rltv_lambda -- confirm both names refer to the same setting.
    if self.options.tv_lambda > 0 and self.iteration_count > 0:
        dv_est = ops_ext.div_unit_grad(self.estimate, self.voxel_size)
        self.estimate_new = ops_array.safe_divide(
            self.estimate_new, (1.0 - self.options.rltv_lambda * dv_est))

    # Update estimate inplace. Get convergence statistics.
    return ops_ext.update_estimate_poisson(
        self.estimate, self.estimate_new, self.options.convergence_epsilon)
def __make_printable_frc_subplot(ax, frc, title=None, coerce_ticks=True):
    """
    Plot a single FRC result on the given axis in a print-friendly style.

    The following are drawn:
    - the FRC values as a scatter;
    - the curve fit as a line;
    - the threshold curve, extended to frequency 1.0 if needed;
    - the resolution point, with a dashed drop line to the x-axis.

    :param ax: matplotlib axis to draw on
    :param frc: a FourierCorrelationData object with the FRC results
    :param title: optional axis title
    :param coerce_ticks: currently unused; kept for backward compatibility
    """
    assert isinstance(frc, FourierCorrelationData)

    # Axis labelling
    ax.set_xlabel('Frequency')
    ax.set_ylabel('Correlation')
    ax.set_ylim([0, 1.2])

    if title is not None:
        ax.set_title(title)

    spacing = frc.resolution["spacing"]

    # Plot calculated FRC values as xy scatter.
    y = frc.correlation["correlation"]
    x_raw = frc.correlation["frequency"]
    x = arrayops.safe_divide(x_raw, 2 * spacing)
    ax.plot(x, y, '^', color='#b5b5b3', label='FRC')

    # Plot the curve fit as a line plot over the FRC scatter
    y = frc.correlation["curve-fit"]
    ax.plot(x, y, color='#61a2da', label='Least-squares fit')

    # Plot the resolution threshold curve
    y = frc.resolution["threshold"]
    res_crit = frc.resolution["criterion"]
    if res_crit == 'one-bit':
        label = 'One-bit curve'
    elif res_crit == 'half-bit':
        label = 'Half-bit curve'
    elif res_crit == 'fixed':
        label = 'y = %f' % y[0]
    else:
        label = "Threshold"

    if x_raw[-1] < 1.0:
        x_th = arrayops.safe_divide(np.append(x_raw, 1.0), 2 * spacing)
        y = np.append(y, y[-1])
    else:
        x_th = x
    ax.plot(x_th, y, color='#d77186', label=label, linestyle='--')

    # Plot resolution point. The colour is given only via the keyword
    # argument: a colour in the format string as well ('ro') is redundant
    # and makes matplotlib emit a warning.
    y0 = frc.resolution["resolution-point"][0]
    x0 = frc.resolution["resolution-point"][1] / (2 * spacing)
    ax.plot(x0, y0, 'o', label='Resolution point', color='#D75725')

    # Dashed drop line from the x-axis up to the resolution point.
    xs, ys = (x0, x0), (0, y0)
    ax.plot(xs, ys, 'x--', color='#D75725', ms=10)
def __make_frc_subplot(ax, frc, title):
    """
    Plot a single FRC result on the given axis and write the resolution
    value below the plot.

    The following are drawn:
    - the FRC values as a scatter;
    - the curve fit as a line;
    - the threshold curve, extended to frequency 1.0 if needed;
    - the resolution point, with a dashed drop line to the x-axis.

    :param ax: matplotlib axis to draw on
    :param frc: a FourierCorrelationData object with the FRC results
    :param title: axis title
    """
    assert isinstance(frc, FourierCorrelationData)

    # Axis labelling
    ax.set_xlabel('Frequency (1/um)')
    ax.set_ylabel('Correlation')
    ax.set_ylim([0, 1.2])

    ax.set_title(title)

    spacing = frc.resolution["spacing"]

    # Plot calculated FRC values as xy scatter.
    y = frc.correlation["correlation"]
    x = frc.correlation["frequency"]
    x_axis = arrayops.safe_divide(x, 2 * spacing)
    ax.plot(x_axis, y, '^', markersize=6, color='#b5b5b3', label='FRC')

    # Plot the curve fit as a line plot over the FRC scatter
    y = frc.correlation["curve-fit"]
    ax.plot(x_axis, y, linewidth=3, color='#61a2da', label='Least-squares fit')

    # Plot the resolution threshold curve
    y = frc.resolution["threshold"]
    res_crit = frc.resolution["criterion"]
    if res_crit == 'one-bit':
        label = 'One-bit curve'
    elif res_crit == 'half-bit':
        label = 'Half-bit curve'
    elif res_crit == 'fixed':
        label = 'y = %f' % y[0]
    else:
        label = "Threshold"

    if x[-1] < 1.0:
        x = np.append(x, 1.0)
        y = np.append(y, y[-1])
        x_axis = arrayops.safe_divide(x, 2 * spacing)
    ax.plot(x_axis, y, color='#d77186', label=label, lw=2, linestyle='--')

    # Plot resolution point. The colour is given only via the keyword
    # argument: a colour in the format string as well ('ro') is redundant
    # and makes matplotlib emit a warning.
    y0 = frc.resolution["resolution-point"][0]
    x0 = frc.resolution["resolution-point"][1] / (2 * spacing)
    ax.plot(x0, y0, 'o', markersize=8, label='Resolution point', color='#D75725')

    # Dashed drop line from the x-axis up to the resolution point.
    xs, ys = (x0, x0), (0, y0)
    ax.plot(xs, ys, 'x--', lw=3, color='#D75725', ms=10)

    resolution = "The resolution is {} um.".format(
        frc.resolution["resolution"])

    # NOTE(review): no transform is passed, so (0.5, -0.3) is interpreted in
    # data coordinates. If the text is meant to sit centred below the axes,
    # transform=ax.transAxes may be what was intended -- confirm.
    ax.text(0.5, -0.3, resolution, ha="center", fontsize=12)