Example 1
def resolution_curves_subplot(ax, data_to_plot, x_idx=0, disable_ax_labels=False, line_style='-'):
    """
    Does the actual owrk fo the plot_resolution_curves function, but requires an axis as an input. It is useful eg.
    when one desires to have several subplots.

    :param ax:           plot axis (as in matplotlib subplot instance)
    :param data_to_plot: a FourierCorrelationDataCollection object with the FRC results
    :param x_idx:        if the data contains FRC results with different dynamic range (pixel size), indicate
                         the index of the dataset here with the maximum range
    :param size:        size of the plot
    :param disable_ax_labels: disable y- and x-axis labels (Correlation, Frequency), in case you want to add them
                        later for instance
    :return:            returns the matplotlib.pyplot.Figure object that you can use to make further modificaitons
                        to the plot
        """
    assert isinstance(data_to_plot, FourierCorrelationDataCollection)

    angles = list()
    datasets = list()

    # Sort datasets by angle.
    for dataset in data_to_plot:
        angles.append((int(dataset[0])))
        datasets.append(dataset[1])

    angles, datasets = list(zip(*sorted(zip(angles, datasets))))

    # plot threshold
    dataset = datasets[int(x_idx)]

    y = dataset.resolution["threshold"]
    x = dataset.correlation["frequency"]
    if x[-1] < 1.0:
        x = np.append(x, 1.0)
        y = np.append(y, y[-1])

    x_axis = arrayops.safe_divide(x, 2 * dataset.resolution["spacing"])

    ax.plot(x_axis, y, linestyle='--', color='#b5b5b3')

    if not disable_ax_labels:
        xlabel = r'Frequency ($\mathrm{\mu m}^{-1}$)'
        ylabel = 'Correlation'
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    for idx, dataset in enumerate(datasets):
        ax.set_ylim([0, 1.2])

        # Plot calculated FRC values as xy scatter.
        y = dataset.correlation["curve-fit"]
        x = dataset.correlation["frequency"]
        x_axis = arrayops.safe_divide(x, 2 * dataset.resolution["spacing"])

        ax.plot(x_axis, y, linestyle=line_style)

    return ax
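As a quick sanity check of the frequency-axis scaling used above (the normalized frequency axis in [0, 1] is divided by 2 * pixel spacing), here is a minimal standalone NumPy sketch with an assumed 0.05 um pixel size, for which the Nyquist frequency is 10 1/um:

# Standalone check of the x-axis scaling (hypothetical 0.05 um pixel size)
import numpy as np

spacing = 0.05                      # um / pixel, assumed value for illustration
freq = np.linspace(0, 1.0, num=5)   # normalized frequency axis, as in the FRC data
print(freq / (2 * spacing))         # [ 0.   2.5  5.   7.5 10. ] -> Nyquist at 10 1/um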
Example 2
def calculate_snr_threshold_value(points_x_bin, snr):
    """
    A function to calculate a SNR based resolution threshold, as described
    in ...

    :param points_x_bin: a 1D Array containing the numbers of points at each
    FRC/FSC ring/shell
    :param snr: the expected SNR value
    :return:
    """
    nominator = snr + arrayutils.safe_divide(2.0 * np.sqrt(snr) + 1,
                                             np.sqrt(points_x_bin))
    denominator = snr + 1 + arrayutils.safe_divide(2.0 * np.sqrt(snr),
                                                   np.sqrt(points_x_bin))
    return arrayutils.safe_divide(nominator, denominator)
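A small standalone numeric check of the same formula (plain NumPy with toy, strictly positive counts, so safe_divide is not needed): the threshold falls towards snr / (snr + 1) as the rings become better sampled, and rises for sparsely sampled rings.

import numpy as np

snr = 0.5
points = np.array([10.0, 100.0, 1000.0, 1e6])   # toy ring sizes
threshold = (snr + (2.0 * np.sqrt(snr) + 1) / np.sqrt(points)) / \
            (snr + 1 + 2.0 * np.sqrt(snr) / np.sqrt(points))
print(threshold)   # decreases towards snr / (snr + 1) = 0.333...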
Example 3
def calculate_fourier_plane_correlation(image1, image2, args, z_correction=1):
    steps = np.arange(0, 360, args.d_angle)
    data = containers.FourierCorrelationDataCollection()

    for idx, step in enumerate(steps):
        im1_rot = np.fft.fftshift(
            np.fft.fftn(rotate(image1, step, reshape=False)))
        im2_rot = np.fft.fftshift(
            np.fft.fftn(rotate(image2, step, reshape=False)))

        numerator = np.sum(im1_rot * np.conjugate(im2_rot), axis=(0, 2))
        denominator = np.sum(np.sqrt(np.abs(im1_rot)**2 * np.abs(im2_rot)**2),
                             axis=(0, 2))

        correlation = ndarray.safe_divide(numerator, denominator)

        zero = correlation.size // 2
        correlation = correlation[zero:]

        result = containers.FourierCorrelationData()
        result.correlation["correlation"] = correlation
        result.correlation["frequency"] = np.linspace(0,
                                                      1.0,
                                                      num=correlation.size)
        result.correlation["points-x-bin"] = np.ones(
            correlation.size) * (im2_rot.shape[2] * im2_rot.shape[0])

        data[int(step)] = result

    analyzer = fsc_analysis.FourierCorrelationAnalysis(data, image1.spacing[0],
                                                       args)
    return analyzer.execute(z_correction=z_correction)
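The integer midpoint of an fftshift'ed axis is where the zero frequency sits, so slicing from there keeps the non-negative frequencies. A minimal standalone check (even-sized axis, plain NumPy):

import numpy as np

freqs = np.fft.fftshift(np.fft.fftfreq(8))
mid = freqs.size // 2
print(freqs[mid:])   # [0.    0.125 0.25  0.375] -- zero frequency upwards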
Example 4
def wiener_deconvolution(image, psf, snr=30, add_pad=0):
    assert isinstance(image, Image)
    assert isinstance(psf, Image)

    image_s = Image(image.copy(), image.spacing)
    orig_shape = image.shape

    if image.ndim != psf.ndim:
        raise ValueError("Image and psf dimensions do not match")

    if psf.spacing != image.spacing:
        psf = imops.zoom_to_spacing(psf, image.spacing)

    if add_pad != 0:
        new_shape = list(i + 2 * add_pad for i in image_s.shape)
        image_s = imops.zero_pad_to_shape(image_s, new_shape)

    if psf.shape != image_s.shape:
        psf = imops.zero_pad_to_shape(psf, image_s.shape)

    psf /= psf.max()

    psf_f = fftn(fftshift(psf))

    wiener = arrayops.safe_divide(
        np.abs(psf_f)**2 / (np.abs(psf_f)**2 + snr), psf_f)

    image_s = fftn(image_s)

    image_s = Image(np.abs(ifftn(image_s * wiener).real), image.spacing)

    return imops.remove_zero_padding(image_s, orig_shape)
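The core of the filter above is the frequency-domain expression |H|^2 / ((|H|^2 + snr) * H), which is algebraically the same as conj(H) / (|H|^2 + snr), with `snr` acting as a regularization constant. Below is a minimal plain-NumPy sketch of that formula on a synthetic 1D signal, independent of the Image/imops helpers used above:

import numpy as np

rng = np.random.default_rng(0)
signal = np.zeros(128)
signal[40:60] = 1.0                                        # toy object
psf = np.exp(-0.5 * ((np.arange(128) - 64) / 3.0) ** 2)    # centered Gaussian blur kernel
psf /= psf.sum()

H = np.fft.fft(np.fft.fftshift(psf))                       # kernel -> transfer function
blurred = np.fft.ifft(np.fft.fft(signal) * H).real
blurred += 0.01 * rng.standard_normal(blurred.size)        # add a little noise

reg = 0.01                                                 # regularizer, plays the role of `snr`
wiener = np.conj(H) / (np.abs(H) ** 2 + reg)               # same as |H|^2 / ((|H|^2 + reg) * H)
restored = np.fft.ifft(np.fft.fft(blurred) * wiener).real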
Example 5
def calculate_resolution_threshold_curve(data_set, criterion, threshold, snr):
    """
    Calculate the two sigma curve. The FRC should be run first, as the results of the two sigma
    depend on the number of points on the fourier rings.

    :return:  Adds the
    """
    assert isinstance(data_set, FourierCorrelationData)

    points_x_bin = data_set.correlation["points-x-bin"]

    if points_x_bin[-1] == 0:
        points_x_bin[-1] = points_x_bin[-2]

    if criterion == 'one-bit':
        nominator = 0.5 + arrayutils.safe_divide(2.4142, np.sqrt(points_x_bin))
        denominator = 1.5 + arrayutils.safe_divide(1.4142,
                                                   np.sqrt(points_x_bin))
        points = arrayutils.safe_divide(nominator, denominator)

    elif criterion == 'half-bit':
        nominator = 0.2071 + arrayutils.safe_divide(1.9102,
                                                    np.sqrt(points_x_bin))
        denominator = 1.2071 + arrayutils.safe_divide(0.9102,
                                                      np.sqrt(points_x_bin))
        points = arrayutils.safe_divide(nominator, denominator)

    elif criterion == 'three-sigma':
        points = arrayutils.safe_divide(np.full(points_x_bin.shape, 3.0),
                                        (np.sqrt(points_x_bin) + 3.0 - 1))

    elif criterion == 'fixed':
        points = threshold * np.ones(len(data_set.correlation["points-x-bin"]))
    elif criterion == 'snr':
        points = calculate_snr_threshold_value(points_x_bin, snr)

    else:
        raise AttributeError(f"Unknown resolution criterion: {criterion}")

    if criterion != 'fixed':
        #coeff = np.polyfit(data_set.correlation["frequency"], points, 3)
        #equation = np.poly1d(coeff)
        equation = interp1d(data_set.correlation["frequency"],
                            points,
                            kind='slinear')
        curve = equation(data_set.correlation["frequency"])
    else:
        curve = points
        equation = None

    data_set.resolution["threshold"] = curve
    return equation
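A small standalone comparison of the one-bit and half-bit thresholds above for a toy set of ring sizes (plain NumPy; the counts are strictly positive, so ordinary division suffices):

import numpy as np

points = np.array([16.0, 64.0, 256.0, 4096.0])
one_bit = (0.5 + 2.4142 / np.sqrt(points)) / (1.5 + 1.4142 / np.sqrt(points))
half_bit = (0.2071 + 1.9102 / np.sqrt(points)) / (1.2071 + 0.9102 / np.sqrt(points))
print(one_bit)    # tends towards 1/3 for well-sampled rings
print(half_bit)   # tends towards ~0.172 for well-sampled rings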
Example 6
def wiener_deconvolution(image, psf, snr=30, add_pad=0):
    """ A GPU accelerated implementation of a linear Wiener filter. Some effort is made
    to allow processing even relatively large images, but some kind of block-based processing
     (as in the RL implementation) may be required in some cases."""
    assert isinstance(image, Image)
    assert isinstance(psf, Image)

    image_s = Image(image.copy(), image.spacing)
    orig_shape = image.shape

    if image.ndim != psf.ndim:
        raise ValueError("Image and psf dimensions do not match")

    if psf.spacing != image.spacing:
        psf = imops.zoom_to_spacing(psf, image.spacing)

    if add_pad != 0:
        new_shape = list(i + 2 * add_pad for i in image_s.shape)
        image_s = imops.zero_pad_to_shape(image_s, new_shape)

    if psf.shape != image_s.shape:
        psf = imops.zero_pad_to_shape(psf, image_s.shape)

    psf /= psf.max()
    psf = fftshift(psf)

    psf_dev = cp.asarray(psf.astype(np.complex64))
    with get_fft_plan(psf_dev):
        psf_dev = fftn(psf_dev, overwrite_x=True)

    below = cp.asnumpy(psf_dev)
    psf_abs = cp.abs(psf_dev)**2
    psf_abs /= (psf_abs + snr)
    above = cp.asnumpy(psf_abs)
    psf_abs = None
    psf_dev = None

    image_dev = cp.asarray(image_s.astype(np.complex64))
    with get_fft_plan(image_dev):
        image_dev = fftn(image_dev, overwrite_x=True)

    wiener_dev = cp.asarray(arrayops.safe_divide(above, below))

    image_dev *= wiener_dev

    result = cp.asnumpy(cp.abs(ifftn(image_dev, overwrite_x=True)).real)
    result = Image(result, image.spacing)

    return imops.remove_zero_padding(result, orig_shape)
Example 7
    def execute(self):
        """
        Calculate the FRC
        :return: Returns the FRC results. They are also saved inside the class.
                 The return value is just for convenience.
        """

        data_structure = containers.FourierCorrelationDataCollection()
        radii, angles = self.iterator.steps
        freq_nyq = self.iterator.nyquist
        shape = (angles.shape[0], radii.shape[0])
        c1 = np.zeros(shape, dtype=np.float32)
        c2 = np.zeros(shape, dtype=np.float32)
        c3 = np.zeros(shape, dtype=np.float32)
        points = np.zeros(shape, dtype=np.float32)

        # Iterate through the sphere and calculate initial values
        for ind_ring, shell_idx, rotation_idx in self.iterator:
            subset1 = self.fft_image1[ind_ring]
            subset2 = self.fft_image2[ind_ring]

            c1[rotation_idx,
               shell_idx] = np.sum(subset1 * np.conjugate(subset2)).real
            c2[rotation_idx, shell_idx] = np.sum(np.abs(subset1)**2)
            c3[rotation_idx, shell_idx] = np.sum(np.abs(subset2)**2)

            points[rotation_idx, shell_idx] = len(subset1)

        # Finish up the FRC calculation for every rotation angle and save the
        # results to the data structure.
        for i in range(angles.size):

            # Calculate FRC for every orientation
            spatial_freq = radii.astype(np.float32) / freq_nyq
            n_points = np.array(points[i])
            frc = ndarray.safe_divide(c1[i], np.sqrt(c2[i] * c3[i]))

            result = containers.FourierCorrelationData()
            result.correlation["correlation"] = frc
            result.correlation["frequency"] = spatial_freq
            result.correlation["points-x-bin"] = n_points

            # Save result to the structure and move to next
            # angle
            data_structure[angles[i]] = result

        return data_structure
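The per-shell quantities above combine into FRC = Re(sum(F1 * conj(F2))) / sqrt(sum(|F1|^2) * sum(|F2|^2)). A minimal standalone check with synthetic data: identical inputs (up to a scale factor) give a correlation of 1.

import numpy as np

rng = np.random.default_rng(1)
f1 = rng.standard_normal(100) + 1j * rng.standard_normal(100)
f2 = 2.5 * f1                                    # same "image", different intensity scale
frc = np.sum(f1 * np.conjugate(f2)).real / np.sqrt(np.sum(np.abs(f1) ** 2) * np.sum(np.abs(f2) ** 2))
print(frc)   # ~1.0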
Example 8
    def compute_estimate(self):
        """
            Calculates a single RL fusion estimate. There is no reason to call this
            function -- it is used internally by the class during fusion process.
        """
        self.estimate_new[:] = np.zeros(self.image_size, dtype=np.float32)

        # Iterate over blocks
        iterables = (range(0, m, n) for m, n in zip(self.image_size, self.block_size))
        pad = self.options.block_pad
        block_idx = tuple(slice(pad, pad + block) for block in self.block_size)

        for pos in itertools.product(*iterables):

            estimate_idx = tuple(slice(j, j + k) for j, k in zip(pos, self.block_size))
            index = np.array(pos, dtype=int)

            if self.options.block_pad > 0:
                h_estimate_block = self.get_padded_block(self.estimate, index.copy()).astype(np.complex64)
            else:
                h_estimate_block = self.estimate[estimate_idx].astype(np.complex64)

            # Execute: cache = convolve(PSF, estimate), non-normalized
            h_estimate_block_new = self._fft_convolve(h_estimate_block, self.psf_fft)

            # Execute: cache = data/cache. Add background bias if requested.
            h_image_block = self.get_padded_block(self.image, index.copy()).astype(np.float32)
            if self.options.rl_background != 0:
                h_image_block += self.options.rl_background
            ops_ext.inverse_division_inplace(h_estimate_block_new, h_image_block)

            # Execute correlation with PSF
            h_estimate_block_new = self._fft_convolve(h_estimate_block_new, self.adj_psf_fft).real

            # Get new weights
            self.estimate_new[estimate_idx] = h_estimate_block_new[block_idx]

        # TV Regularization (doesn't seem to do anything miraculous).
        if self.options.tv_lambda > 0 and self.iteration_count > 0:
            dv_est = ops_ext.div_unit_grad(self.estimate, self.image_spacing)
            self.estimate_new = ops_array.safe_divide(self.estimate_new,
                                                      (1.0 - self.options.rltv_lambda * dv_est))

        # Update estimate inplace. Get convergence statistics.
        return ops_ext.update_estimate_poisson(self.estimate,
                                               self.estimate_new,
                                               self.options.convergence_epsilon)
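A small standalone illustration of the block-iteration pattern used above (toy sizes, no padding), showing that the generated slices tile the full array exactly once:

import itertools
import numpy as np

image_size, block_size = (4, 6), (2, 3)
covered = np.zeros(image_size, dtype=int)
for pos in itertools.product(*(range(0, m, n) for m, n in zip(image_size, block_size))):
    idx = tuple(slice(j, j + k) for j, k in zip(pos, block_size))
    covered[idx] += 1
print(covered)   # every element is visited exactly once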
Example 9
    def compute_estimate(self):
        """
            Calculates a single RL fusion estimate. There is no reason to call this
            function -- it is used internally by the class during fusion process.
        """
        print(f'Beginning the computation of estimate {self.iteration_count + 1}')

        if "multiplicative" in self.options.fusion_method:
            self.estimate_new[:] = numpy.ones(self.image_size,
                                              dtype=numpy.float32)
        else:
            self.estimate_new[:] = numpy.zeros(self.image_size,
                                               dtype=numpy.float32)

        # Iterate over views
        for idx, view in enumerate(self.views):

            psf_fft = self.psfs_fft[idx]
            adj_psf_fft = self.adj_psfs_fft[idx]

            self.data.set_active_image(view, self.options.channel,
                                       self.options.scale, "registered")

            weighting = self.weights[idx]
            background = self.background[idx]

            iterables = (range(0, m, n)
                         for m, n in zip(self.image_size, self.block_size))
            pad = self.options.block_pad
            block_idx = tuple(
                slice(pad, pad + block) for block in self.block_size)

            for pos in itertools.product(*iterables):

                estimate_idx = tuple(
                    slice(j, j + k) for j, k in zip(pos, self.block_size))
                index = numpy.array(pos, dtype=int)

                if self.options.block_pad > 0:
                    h_estimate_block = self.get_padded_block(
                        self.estimate, index.copy()).astype(numpy.complex64)
                else:
                    h_estimate_block = self.estimate[estimate_idx].astype(
                        numpy.complex64)

                # Convolve estimate block with the PSF
                h_estimate_block_new = self._fft_convolve(
                    h_estimate_block, psf_fft)

                # Apply weighting
                h_estimate_block_new *= weighting

                # Apply background
                h_estimate_block_new += background

                # Divide image block with the convolution result
                h_image_block = self.data.get_registered_block(
                    self.block_size, self.options.block_pad,
                    index.copy()).astype(numpy.float32)

                #h_estimate_block_new = ops_array.safe_divide(h_image_block, h_estimate_block_new)
                ops_ext.inverse_division_inplace(h_estimate_block_new,
                                                 h_image_block)

                # Correlate with adj PSF
                h_estimate_block_new = self._fft_convolve(
                    h_estimate_block_new, adj_psf_fft).real

                # Update the contribution from a single view to the new estimate
                self._write_estimate_block(h_estimate_block_new, estimate_idx,
                                           block_idx)

        # Divide with the number of projections
        if "summative" in self.options.fusion_method:
            # self.estimate_new[:] = self.float_vmult(self.estimate_new,
            #                                         self.scaler)
            self.estimate_new *= (1.0 / self.n_views)
        else:
            self.estimate_new[self.estimate_new < 0] = 0
            self.estimate_new[:] = ops_array.nroot(self.estimate_new,
                                                   self.n_views)

        # TV Regularization (doesn't seem to do anything miraculous).
        if self.options.tv_lambda > 0 and self.iteration_count > 0:
            dv_est = ops_ext.div_unit_grad(self.estimate, self.voxel_size)
            self.estimate_new = ops_array.safe_divide(
                self.estimate_new, (1.0 - self.options.rltv_lambda * dv_est))

        # Update estimate inplace. Get convergence statistics.
        return ops_ext.update_estimate_poisson(
            self.estimate, self.estimate_new, self.options.convergence_epsilon)
Example 10
    def __make_printable_frc_subplot(ax, frc, title=None, coerce_ticks=True):
        """
        Creates a plot of the FRC curves in the curve_list. Single or multiple vurves can
        be plotted.
        """
        assert isinstance(frc, FourierCorrelationData)

        # # Font setting
        # font0 = FontProperties()
        # font1 = font0.copy()
        # font1.set_size('medium')
        # font = font1.copy()
        # font.set_family('sans')
        # rc('text', usetex=True)

        # Enable grid
        gridLineWidth = 0.2
        # ax.yaxis.grid(True, linewidth=gridLineWidth, linestyle='-', color='0.05')

        # Marker setup
        colorArray = [
            'blue', 'green', 'red', 'orange', 'brown', 'black', 'violet',
            'pink'
        ]
        marker_array = ['^', 's', 'o', 'd', '1', 'v', '*', 'p']

        # Axis labelling
        xlabel = 'Frequency'
        ylabel = 'Correlation'
        # ax.set_xlabel(xlabel, fontsize=12, position=(0.5, -0.2))
        # ax.set_ylabel(ylabel, fontsize=12, position=(0.5, 0.5))
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        ax.set_ylim([0, 1.2])

        # Title
        if title is not None:
            ax.set_title(title)

        # Plot calculated FRC values as xy scatter.
        y = frc.correlation["correlation"]
        x_raw = frc.correlation["frequency"]
        x = arrayops.safe_divide(x_raw, 2 * frc.resolution["spacing"])
        ax.plot(x, y, marker_array[0], color='#b5b5b3', label='FRC')

        # Plot polynomial fit as a line plot over the FRC scatter
        y = frc.correlation["curve-fit"]
        ax.plot(x, y, color='#61a2da', label='Least-squares fit')

        # Plot the resolution threshold curve
        y = frc.resolution["threshold"]
        res_crit = frc.resolution["criterion"]
        if res_crit == 'one-bit':
            label = 'One-bit curve'
        elif res_crit == 'half-bit':
            label = 'Half-bit curve'
        elif res_crit == 'fixed':
            label = 'y = %f' % y[0]
        else:
            label = "Threshold"

        if x_raw[-1] < 1.0:
            x_th = arrayops.safe_divide(np.append(x_raw, 1.0),
                                        2 * frc.resolution["spacing"])
            y = np.append(y, y[-1])
        else:
            x_th = x

        ax.plot(x_th, y, color='#d77186', label=label, linestyle='--')

        # Plot resolution point
        y0 = frc.resolution["resolution-point"][0]
        x0 = frc.resolution["resolution-point"][1] / (
            2 * frc.resolution["spacing"])

        ax.plot(x0, y0, 'ro', label='Resolution point', color='#D75725')

        verts = [(x0, 0), (x0, y0)]
        xs, ys = list(zip(*verts))

        ax.plot(xs, ys, 'x--', color='#D75725', ms=10)
Example 11
    def __make_frc_subplot(ax, frc, title):
        """
        Creates a plot of the FRC curves in the curve_list. Single or multiple vurves can
        be plotted.
        """
        assert isinstance(frc, FourierCorrelationData)

        # # Font setting
        # font0 = FontProperties()
        # font1 = font0.copy()
        # font1.set_size('medium')
        # font = font1.copy()
        # font.set_family('sans')
        # rc('text', usetex=True)

        # Enable grid
        gridLineWidth = 0.2
        # ax.yaxis.grid(True, linewidth=gridLineWidth, linestyle='-', color='0.05')

        # Axis labelling
        xlabel = 'Frequency (1/um)'
        ylabel = 'Correlation'
        # ax.set_xlabel(xlabel, fontsize=12, position=(0.5, -0.2))
        # ax.set_ylabel(ylabel, fontsize=12, position=(0.5, 0.5))
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        ax.set_ylim([0, 1.2])

        # Title
        ax.set_title(title)

        # Plot calculated FRC values as xy scatter.
        y = frc.correlation["correlation"]
        x = frc.correlation["frequency"]
        x_axis = arrayops.safe_divide(x, 2 * frc.resolution["spacing"])

        ax.plot(x_axis, y, '^', markersize=6, color='#b5b5b3', label='FRC')

        # Plot polynomial fit as a line plot over the FRC scatter
        y = frc.correlation["curve-fit"]
        ax.plot(x_axis,
                y,
                linewidth=3,
                color='#61a2da',
                label='Least-squares fit')

        # Plot the resolution threshold curve
        y = frc.resolution["threshold"]
        res_crit = frc.resolution["criterion"]
        if res_crit == 'one-bit':
            label = 'One-bit curve'
        elif res_crit == 'half-bit':
            label = 'Half-bit curve'
        elif res_crit == 'fixed':
            label = 'y = %f' % y[0]
        else:
            label = "Threshold"

        if x[-1] < 1.0:
            x = np.append(x, 1.0)
            y = np.append(y, y[-1])

        x_axis = arrayops.safe_divide(x, 2 * frc.resolution["spacing"])

        ax.plot(x_axis, y, color='#d77186', label=label, lw=2, linestyle='--')

        # Plot resolution point
        y0 = frc.resolution["resolution-point"][0]
        x0 = frc.resolution["resolution-point"][1] / (
            2 * frc.resolution["spacing"])

        ax.plot(x0,
                y0,
                'ro',
                markersize=8,
                label='Resolution point',
                color='#D75725')

        verts = [(x0, 0), (x0, y0)]
        xs, ys = list(zip(*verts))

        ax.plot(xs, ys, 'x--', lw=3, color='#D75725', ms=10)
        # ax.text(x0, y0 + 0.10, 'RESOL-FREQ', fontsize=12)

        resolution = "The resolution is {} um.".format(
            frc.resolution["resolution"])
        ax.text(0.5, -0.3, resolution, ha="center", fontsize=12)