예제 #1
0
    def moments(cls, img, query_size):
        """Calculates the mean and standard deviation for each window of size (query_size x query_size)
        in the micrograph.

        Both statistics are obtained with FFT-based box filtering: the zero-padded image (and its
        square) is multiplied in the Fourier domain by the spectrum of a zero-padded, normalized
        (query_size x query_size) box filter. The zero padding makes this a linear (full) rather
        than circular convolution.

        Args:
            img: Micrograph image (2-D; it does not need to be square).
            query_size: Size of windows for which to compute mean and std.

        Returns:
            A tuple ``(mean, std)`` of matrices, each of shape
            ``(img.shape[0] + query_size - 1, img.shape[1] + query_size - 1)``, holding the mean
            intensity and the standard deviation of the window at every position of the full
            convolution (border positions include the zero padding).
        """
        rows, cols = img.shape[0], img.shape[1]

        # Normalized box filter, padded per axis so its spectrum has exactly the shape of the
        # padded image's spectrum (padding both axes by rows - 1 breaks non-square images).
        filt = xp.ones((query_size, query_size)) / (query_size * query_size)
        filt = xp.pad(filt, ((0, rows - 1), (0, cols - 1)),
                      "constant",
                      constant_values=(0, 0))
        filt_freq = fft.fft2(filt, axes=(0, 1))

        pad_img = xp.pad(img, (0, query_size - 1),
                         "constant",
                         constant_values=(0, 0))
        img_freq = fft.fft2(pad_img, axes=(0, 1))

        # E[X] over each window.
        mean_freq = xp.multiply(img_freq, filt_freq)
        mean_all = fft.ifft2(mean_freq, axes=(0, 1)).real

        # Var[X] = E[X^2] - E[X]^2, with E[X^2] from the same box filter.
        pad_img_square = xp.square(pad_img)
        img_var_freq = fft.fft2(pad_img_square, axes=(0, 1))
        var_freq = xp.multiply(img_var_freq, filt_freq)
        var_all = fft.ifft2(var_freq, axes=(0, 1))
        var_all = var_all.real - xp.power(mean_all, 2)

        # Round-off can push the variance slightly below zero; clamp before the sqrt.
        std_all = xp.sqrt(xp.maximum(0, var_all))

        return mean_all, std_all
예제 #2
0
def _im_translate2(im, shifts):
    """
    Translate image by shifts
    :param im: An Image instance to be translated. A plain array of size
        n-by-L-by-L is converted to an Image (with a warning).
    :param shifts: An array (or array-like) of size n-by-2 specifying the shifts in pixels.
        Alternatively, it can be a row vector of length 2, in which case the same shift is applied to each image.
    :return: An Image instance translated by the shifts.
    :raises ValueError: If `shifts` is not n-by-2, or n is neither 1 nor the number of images.

    TODO: This implementation has been moved here from aspire.aspire.abinitio and is faster than _im_translate.
    """

    if not isinstance(im, Image):
        logger.warning(
            "_im_translate2 expects an Image, attempting to convert array. "
            "Expects array of size n-by-L-by-L.")
        im = Image(im)

    shifts = np.asarray(shifts)
    if shifts.ndim == 1:
        shifts = shifts[np.newaxis, :]

    if shifts.ndim != 2 or shifts.shape[1] != 2:
        raise ValueError("Input `shifts` must be of size n-by-2")

    n_shifts = shifts.shape[0]

    if n_shifts != 1 and n_shifts != im.n_images:
        raise ValueError(
            "The number of shifts must be 1 or match the number of images")

    resolution = im.res
    # Integer frequency indices in FFT (unshifted) order.
    grid = xp.asnumpy(
        fft.ifftshift(
            xp.asarray(np.ceil(np.arange(-resolution / 2, resolution / 2)))))
    # Default "xy" meshgrid: om_x varies along axis 0, om_y along axis 1.
    om_y, om_x = np.meshgrid(grid, grid)
    # The shift index must be the leading output axis so the phase array is
    # (n_shifts, L, L) and lines up with the image stack. Building (L, L, n)
    # and reshaping would scramble pixels whenever n_shifts > 1.
    phase_shifts = np.einsum("ij, k -> kij", om_x, shifts[:, 0]) + np.einsum(
        "ij, k -> kij", om_y, shifts[:, 1])
    phase_shifts /= resolution

    mult_f = np.exp(-2 * np.pi * 1j * phase_shifts)
    im_f = xp.asnumpy(fft.fft2(xp.asarray(im.asnumpy())))
    im_translated_f = im_f * mult_f
    im_translated = np.real(xp.asnumpy(fft.ifft2(xp.asarray(im_translated_f))))

    return Image(im_translated)
예제 #3
0
    def _im_translate(self, shifts):
        """
        Translate the images held by this instance, with periodic boundaries.

        :param shifts: An n-by-2 array of pixel shifts (one row per image), or a
            length-2 vector whose single shift is applied to every image.
        :return: A new Image containing the translated images.

        TODO: This implementation is slower than _im_translate2
        """
        data = self.data

        # A lone shift vector becomes a 1-by-2 matrix.
        if shifts.ndim == 1:
            shifts = shifts[np.newaxis, :]
        count = shifts.shape[0]

        ensure(shifts.shape[-1] == 2, "shifts must be nx2")
        ensure(
            count == 1 or count == self.n_images,
            "number of shifts must be 1 or match the number of images",
        )

        # From here on, shifts use this instance's internal dtype.
        shifts = shifts.astype(self.dtype)

        res = self.res
        spectrum = xp.asnumpy(fft.fft2(xp.asarray(data)))

        # Integer frequency indices in FFT order, scaled to angular frequencies.
        freqs = np.ceil(np.arange(-res / 2, res / 2, dtype=self.dtype))
        freqs = xp.asnumpy(fft.ifftshift(xp.asarray(freqs))) * 2 * np.pi / res
        wx, wy = np.meshgrid(freqs, freqs, indexing="ij")

        # Shape each shift component as (count, 1, 1) so it broadcasts over the grid.
        neg_sx = -shifts[:, 0].reshape((count, 1, 1))
        neg_sy = -shifts[:, 1].reshape((count, 1, 1))

        phase = wx[np.newaxis, :, :] * neg_sx + wy[np.newaxis, :, :] * neg_sy
        shifted_spectrum = spectrum * np.exp(-1j * phase)

        spatial = xp.asnumpy(fft.ifft2(xp.asarray(shifted_spectrum)))
        return Image(np.real(spatial))
예제 #4
0
def downsample(insamples, szout, mask=None):
    """
    Blur and downsample 1D to 3D objects such as curves, images or volumes.

    The function handles odd and even-sized arrays correctly. The center of
    an odd array is taken to be at (n+1)/2, and an even array is n/2+1.
    Each object is Fourier transformed, its centered spectrum is cropped (or
    padded) to the output size with `crop_pad`, multiplied by `mask`, and
    transformed back.

    :param insamples: Set of objects to be downsampled in the form of an array;
        the first dimension is the number of objects.
    :param szout: The desired size of the output objects, one entry per object
        dimension. A bare integer is accepted for 1D objects. Objects are
        presumably square/cubic: `szout[0]` sets the side length passed to
        `crop_pad`.
    :param mask: Optional array or scalar multiplied into the cropped,
        centered spectrum of every object. Defaults to 1.0 (no masking).
    :return: An array of the blurred and downsampled objects, with the same
        dtype as `insamples`.
    """
    # Allow a bare integer output size (only meaningful for 1D objects).
    if np.ndim(szout) == 0:
        szout = (szout,)

    ensure(
        insamples.ndim - 1 == np.size(szout),
        "The number of downsampling dimensions is not the same as that of objects.",
    )

    L_in = insamples.shape[1]
    L_out = szout[0]
    ndata = insamples.shape[0]
    ndim_obj = insamples.ndim - 1

    if ndim_obj not in (1, 2, 3):
        raise RuntimeError("Number of dimensions > 3 for input objects.")

    outdims = np.r_[ndata, szout]
    outsamples = np.zeros(outdims, dtype=insamples.dtype)

    if mask is None:
        mask = 1.0

    # Transform over every axis of a single object (1, 2 or 3 of them).
    axes = tuple(range(ndim_obj))
    # The inverse FFT normalizes by the smaller output size, so rescale by
    # (L_out / L_in) ** d to keep values on the input's scale.
    scale = L_out ** ndim_obj / L_in ** ndim_obj

    for idata in range(ndata):
        spectrum = fft.fftshift(fft.fftn(xp.asarray(insamples[idata]), axes=axes))
        spectrum = crop_pad(spectrum, L_out) * mask

        resized = fft.ifftn(fft.ifftshift(xp.asarray(spectrum)), axes=axes)
        outsamples[idata] = np.real(xp.asnumpy(resized) * scale)

    return outsamples
예제 #5
0
 def _work(index):
     """Score reference window `index` against every query box.

     The reference spectrum is multiplied by `query_box` and inverted over
     axes 2-3. Each resulting correlation map is reduced to (peak - mean).
     Assumes `query_box` holds conjugated query spectra on axes 2-3 -- TODO confirm.
     Returns the pair (index, scores).
     """
     ref_spectrum = fft.fft2(reference_box[index], axes=(0, 1))
     corr = fft.ifft2(xp.multiply(ref_spectrum, query_box), axes=(2, 3)).real
     peak = corr.max((2, 3))
     return index, peak - corr.mean((2, 3))
예제 #6
0
    def query_score(self, show_progress=True):
        """Calculates score for each query image.

        Extracts query images and reference windows. Computes the cross-correlation between these
        windows, and applies a threshold to compute a score for each query image.

        Each (reference window, query image) correlation is reduced to ``max - mean``. With
        ``lo``/``hi`` the minimum/maximum of all those values, a query image's score is the number
        of reference windows whose value is at least
        ``lo + (hi - lo) / config.apple.response_thresh_norm_factor``.

        Args:
            show_progress: Whether to show a progress bar

        Returns:
            Matrix containing a score for each query image.
        """

        micro_img = xp.asarray(self.im)
        logger.info("Extracting query images")
        query_box = PickerHelper.extract_query(micro_img, self.query_size // 2)
        logger.info("Extracting query images complete")

        # Conjugated query spectra: multiplying by a reference spectrum and
        # inverting the product gives the cross-correlation.
        query_box = xp.conj(fft.fft2(query_box, axes=(2, 3)))

        reference_box = PickerHelper.extract_references(
            micro_img, self.query_size, self.container_size)

        reference_size = PickerHelper.reference_size(micro_img,
                                                     self.container_size)
        conv_map = xp.zeros(
            (reference_size, query_box.shape[0], query_box.shape[1]))

        def _work(index):
            # Correlate reference window `index` with every query box and reduce
            # each correlation map to (peak - mean).
            reference_box_i = fft.fft2(reference_box[index], axes=(0, 1))
            window_t = xp.multiply(reference_box_i, query_box)
            cc = fft.ifft2(window_t, axes=(2, 3)).real
            return index, cc.max((2, 3)) - cc.mean((2, 3))

        n_works = reference_size
        n_threads = config.apple.conv_map_nthreads

        # The context manager closes the progress bar even if a worker raises.
        with tqdm(total=n_works, disable=not show_progress) as pbar:
            # Ideally we'd like something like 'SerialExecutor' to enable easy debugging
            # but for now do an if-else
            if n_threads > 1:
                with futures.ThreadPoolExecutor(n_threads) as executor:
                    to_do = [executor.submit(_work, i) for i in range(n_works)]

                    # Results arrive out of order; workers return their index so
                    # the main thread alone writes into conv_map.
                    for future in futures.as_completed(to_do):
                        i, res = future.result()
                        conv_map[i, :, :] = res
                        pbar.update(1)
            else:
                for i in range(n_works):
                    _, conv_map[i, :, :] = _work(i)
                    pbar.update(1)

        # Put the reference axis last so the count below runs over references.
        conv_map = xp.transpose(conv_map, (1, 2, 0))

        min_val = xp.min(conv_map)
        max_val = xp.max(conv_map)
        thresh = (
            min_val +
            (max_val - min_val) / config.apple.response_thresh_norm_factor)
        return xp.asnumpy(xp.sum(conv_map >= thresh, axis=2))