def moments(cls, img, query_size):
    """Calculates the mean and standard deviation for each window of size
    (query_size x query_size) in the micrograph.

    Args:
        img: Micrograph image.
        query_size: Size of windows for which to compute mean and std.

    Returns:
        A matrix of mean intensity and a matrix of standard deviation, each
        containing a single entry for each possible (query_size x query_size)
        window in the micrograph.
    """
    # Normalized box filter, zero-padded to the micrograph size and applied in
    # Fourier space, so every window is averaged with a single convolution.
    filt = xp.ones((query_size, query_size)) / (query_size * query_size)
    filt = xp.pad(filt, (0, img.shape[0] - 1), "constant", constant_values=(0, 0))
    filt_freq = fft.fft2(filt, axes=(0, 1))

    pad_img = xp.pad(img, (0, query_size - 1), "constant", constant_values=(0, 0))
    img_freq = fft.fft2(pad_img, axes=(0, 1))

    # E[X] per window
    mean_freq = xp.multiply(img_freq, filt_freq)
    mean_all = fft.ifft2(mean_freq, axes=(0, 1)).real

    # Var[X] = E[X^2] - (E[X])^2 per window
    pad_img_square = xp.square(pad_img)
    img_var_freq = fft.fft2(pad_img_square, axes=(0, 1))
    var_freq = xp.multiply(img_var_freq, filt_freq)
    var_all = fft.ifft2(var_freq, axes=(0, 1))
    var_all = var_all.real - xp.power(mean_all, 2)
    std_all = xp.sqrt(xp.maximum(0, var_all))

    return mean_all, std_all
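# Illustrative sketch (not from the source): the windowed mean/std that
# `moments` computes via FFT convolution, done directly with
# scipy.ndimage.uniform_filter using the identity Var[X] = E[X^2] - (E[X])^2.
# `windowed_moments_direct` is a hypothetical helper in plain NumPy/SciPy
# rather than the xp/fft wrappers; boundary handling differs from the
# zero-padded FFT version above.
import numpy as np
from scipy.ndimage import uniform_filter


def windowed_moments_direct(img, query_size):
    """Per-window mean and standard deviation without FFTs (illustration only)."""
    img = np.asarray(img, dtype=np.float64)
    mean = uniform_filter(img, size=query_size)           # E[X] per window
    mean_sq = uniform_filter(img * img, size=query_size)  # E[X^2] per window
    var = np.maximum(0.0, mean_sq - mean ** 2)            # clip tiny negatives
    return mean, np.sqrt(var)


# Example: per-pixel window statistics for a random 64x64 "micrograph".
_mean, _std = windowed_moments_direct(
    np.random.default_rng(0).standard_normal((64, 64)), query_size=8
)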
def _im_translate2(im, shifts):
    """
    Translate image by shifts

    :param im: An Image instance to be translated.
    :param shifts: An array of size n-by-2 specifying the shifts in pixels.
        Alternatively, it can be a row vector of length 2, in which case
        the same shift is applied to each image.
    :return: An Image instance translated by the shifts.

    TODO: This implementation has been moved here from aspire.aspire.abinitio
        and is faster than _im_translate.
    """
    if not isinstance(im, Image):
        logger.warning(
            "_im_translate2 expects an Image, attempting to convert array. "
            "Expects array of size n-by-L-by-L."
        )
        im = Image(im)

    if shifts.ndim == 1:
        shifts = shifts[np.newaxis, :]
    n_shifts = shifts.shape[0]

    if shifts.shape[1] != 2:
        raise ValueError("Input `shifts` must be of size n-by-2")
    if n_shifts != 1 and n_shifts != im.n_images:
        raise ValueError(
            "The number of shifts must be 1 or match the number of images"
        )

    resolution = im.res
    # Integer DFT frequency grid (0, 1, ..., then the negative frequencies).
    grid = xp.asnumpy(
        fft.ifftshift(
            xp.asarray(np.ceil(np.arange(-resolution / 2, resolution / 2)))
        )
    )
    om_y, om_x = np.meshgrid(grid, grid)
    phase_shifts = np.einsum("ij, k -> ijk", om_x, shifts[:, 0]) + np.einsum(
        "ij, k -> ijk", om_y, shifts[:, 1]
    )
    # TODO: figure out why the result of einsum requires reshape
    phase_shifts = phase_shifts.reshape(n_shifts, resolution, resolution)
    phase_shifts /= resolution

    mult_f = np.exp(-2 * np.pi * 1j * phase_shifts)
    im_f = xp.asnumpy(fft.fft2(xp.asarray(im.asnumpy())))
    im_translated_f = im_f * mult_f
    im_translated = np.real(xp.asnumpy(fft.ifft2(xp.asarray(im_translated_f))))

    return Image(im_translated)
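# Illustrative sketch (not from the source): the Fourier shift theorem that
# `_im_translate2` applies per image, reduced to a single 2-D array in plain
# NumPy. `fourier_shift_2d` is a hypothetical helper; for integer shifts the
# result matches a circular roll, which is the check below.
import numpy as np


def fourier_shift_2d(im, dx, dy):
    """Circularly shift `im` by dx pixels along axis 1 and dy along axis 0."""
    n = im.shape[0]
    freqs = np.fft.fftfreq(n)  # DFT frequencies in cycles per pixel
    fy, fx = np.meshgrid(freqs, freqs, indexing="ij")
    phase = np.exp(-2j * np.pi * (fx * dx + fy * dy))
    return np.real(np.fft.ifft2(np.fft.fft2(im) * phase))


_im = np.random.default_rng(0).standard_normal((16, 16))
assert np.allclose(fourier_shift_2d(_im, 3, 5), np.roll(_im, shift=(5, 3), axis=(0, 1)))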
def _im_translate(self, shifts):
    """
    Translate this instance's images by shifts

    :param shifts: An array of size n-by-2 specifying the shifts in pixels.
        Alternatively, it can be a row vector of length 2, in which case
        the same shift is applied to each image.
    :return: The images translated by the shifts, with periodic boundaries.

    TODO: This implementation is slower than _im_translate2
    """
    im = self.data

    if shifts.ndim == 1:
        shifts = shifts[np.newaxis, :]
    n_shifts = shifts.shape[0]

    ensure(shifts.shape[-1] == 2, "shifts must be nx2")
    ensure(
        n_shifts == 1 or n_shifts == self.n_images,
        "number of shifts must be 1 or match the number of images",
    )
    # Cast shifts to this instance's internal dtype
    shifts = shifts.astype(self.dtype)

    L = self.res
    im_f = xp.asnumpy(fft.fft2(xp.asarray(im)))
    # Angular frequency grid: 2*pi*k/L for the standard DFT frequencies k.
    grid_shifted = fft.ifftshift(
        xp.asarray(np.ceil(np.arange(-L / 2, L / 2, dtype=self.dtype)))
    )
    grid_1d = xp.asnumpy(grid_shifted) * 2 * np.pi / L
    om_x, om_y = np.meshgrid(grid_1d, grid_1d, indexing="ij")

    phase_shifts_x = -shifts[:, 0].reshape((n_shifts, 1, 1))
    phase_shifts_y = -shifts[:, 1].reshape((n_shifts, 1, 1))

    phase_shifts = (
        om_x[np.newaxis, :, :] * phase_shifts_x
        + om_y[np.newaxis, :, :] * phase_shifts_y
    )
    mult_f = np.exp(-1j * phase_shifts)
    im_translated_f = im_f * mult_f
    im_translated = xp.asnumpy(fft.ifft2(xp.asarray(im_translated_f)))
    im_translated = np.real(im_translated)

    return Image(im_translated)
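# Illustrative sketch (not from the source): the 1-D frequency grid that
# `_im_translate` builds with ifftshift(ceil(arange(-L/2, L/2))). For both
# even and odd L it enumerates the standard integer DFT frequencies
# 0, 1, ..., followed by the negative ones, i.e. fftfreq(L) * L.
import numpy as np

for _L in (4, 5):
    _grid = np.fft.ifftshift(np.ceil(np.arange(-_L / 2, _L / 2)))
    assert np.allclose(_grid, np.fft.fftfreq(_L) * _L)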
def downsample(insamples, szout, mask=None):
    """
    Blur and downsample 1D to 3D objects such as curves, images or volumes.

    The function handles odd and even-sized arrays correctly. The center of
    an odd array is taken to be at (n+1)/2, and that of an even array at n/2+1.

    :param insamples: Set of objects to be downsampled in the form of an array,
        the first dimension being the number of objects.
    :param szout: The desired resolution of the output objects.
    :param mask: Optional mask applied to each centered Fourier transform
        before the inverse transform (default: no masking).
    :return: An array consisting of the blurred and downsampled objects.
    """
    ensure(
        insamples.ndim - 1 == np.size(szout),
        "The number of downsampling dimensions is not the same as that of objects.",
    )

    L_in = insamples.shape[1]
    L_out = szout[0]
    ndata = insamples.shape[0]
    outdims = np.r_[ndata, szout]

    outsamples = np.zeros(outdims, dtype=insamples.dtype)

    if mask is None:
        mask = 1.0

    if insamples.ndim == 2:
        # stack of one dimension objects
        for idata in range(ndata):
            insamples_shifted = fft.fftshift(fft.fft(xp.asarray(insamples[idata])))
            insamples_fft = crop_pad(insamples_shifted, L_out) * mask

            outsamples_shifted = fft.ifft(fft.ifftshift(xp.asarray(insamples_fft)))
            outsamples[idata] = np.real(
                xp.asnumpy(outsamples_shifted) * (L_out / L_in)
            )

    elif insamples.ndim == 3:
        # stack of two dimension objects
        for idata in range(ndata):
            insamples_shifted = fft.fftshift(fft.fft2(xp.asarray(insamples[idata])))
            insamples_fft = crop_pad(insamples_shifted, L_out) * mask

            outsamples_shifted = fft.ifft2(fft.ifftshift(xp.asarray(insamples_fft)))
            outsamples[idata] = np.real(
                xp.asnumpy(outsamples_shifted) * (L_out ** 2 / L_in ** 2)
            )

    elif insamples.ndim == 4:
        # stack of three dimension objects
        for idata in range(ndata):
            insamples_shifted = fft.fftshift(
                fft.fftn(xp.asarray(insamples[idata]), axes=(0, 1, 2))
            )
            insamples_fft = crop_pad(insamples_shifted, L_out) * mask

            outsamples_shifted = fft.ifftn(
                fft.ifftshift(xp.asarray(insamples_fft)), axes=(0, 1, 2)
            )
            outsamples[idata] = np.real(
                xp.asnumpy(outsamples_shifted) * (L_out ** 3 / L_in ** 3)
            )

    else:
        raise RuntimeError("Number of dimensions > 3 for input objects.")

    return outsamples
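# Illustrative sketch (not from the source): Fourier cropping of a 1-D signal,
# the operation `downsample` applies per object (centre the spectrum, keep the
# L_out lowest frequencies, invert, rescale by L_out / L_in). `downsample_1d`
# is a hypothetical helper in plain NumPy; its centring convention is the
# simple even-size one and is not necessarily identical to crop_pad's.
import numpy as np


def downsample_1d(x, L_out):
    """Band-limit and resample a 1-D signal by cropping its centered spectrum."""
    L_in = x.shape[0]
    X = np.fft.fftshift(np.fft.fft(x))
    start = (L_in - L_out) // 2
    X_crop = X[start:start + L_out]
    return np.real(np.fft.ifft(np.fft.ifftshift(X_crop))) * (L_out / L_in)


# A low-frequency cosine is reproduced exactly at the coarser sampling.
_x = np.cos(2 * np.pi * 3 * np.arange(64) / 64)
_y = downsample_1d(_x, 32)
assert np.allclose(_y, np.cos(2 * np.pi * 3 * np.arange(32) / 32))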
def query_score(self, show_progress=True):
    """Calculates score for each query image.

    Extracts query images and reference windows. Computes the cross-correlation
    between these windows, and applies a threshold to compute a score for each
    query image.

    Args:
        show_progress: Whether to show a progress bar.

    Returns:
        Matrix containing a score for each query image.
    """
    micro_img = xp.asarray(self.im)
    logger.info("Extracting query images")
    query_box = PickerHelper.extract_query(micro_img, self.query_size // 2)
    logger.info("Extracting query images complete")

    # Conjugated FFT of the query windows; multiplying by the FFT of a
    # reference window and inverting yields their cross-correlation.
    query_box = xp.conj(fft.fft2(query_box, axes=(2, 3)))

    reference_box = PickerHelper.extract_references(
        micro_img, self.query_size, self.container_size
    )

    reference_size = PickerHelper.reference_size(micro_img, self.container_size)
    conv_map = xp.zeros(
        (reference_size, query_box.shape[0], query_box.shape[1])
    )

    def _work(index):
        reference_box_i = fft.fft2(reference_box[index], axes=(0, 1))
        window_t = xp.multiply(reference_box_i, query_box)
        cc = fft.ifft2(window_t, axes=(2, 3))
        return index, cc.real.max((2, 3)) - cc.real.mean((2, 3))

    n_works = reference_size
    n_threads = config.apple.conv_map_nthreads
    pbar = tqdm(total=reference_size, disable=not show_progress)

    # Ideally we'd like something like 'SerialExecutor' to enable easy debugging
    # but for now do an if-else
    if n_threads > 1:
        with futures.ThreadPoolExecutor(n_threads) as executor:
            to_do = [executor.submit(_work, i) for i in range(n_works)]

            for future in futures.as_completed(to_do):
                i, res = future.result()
                conv_map[i, :, :] = res
                pbar.update(1)
    else:
        for i in range(n_works):
            _, conv_map[i, :, :] = _work(i)
            pbar.update(1)

    pbar.close()

    conv_map = xp.transpose(conv_map, (1, 2, 0))

    min_val = xp.min(conv_map)
    max_val = xp.max(conv_map)
    thresh = (
        min_val
        + (max_val - min_val) / config.apple.response_thresh_norm_factor
    )

    return xp.asnumpy(xp.sum(conv_map >= thresh, axis=2))
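# Illustrative sketch (not from the source): the FFT identity behind
# `query_score` / `_work` -- multiplying FFT(reference) by conj(FFT(query))
# and inverting gives the circular cross-correlation of the two windows,
# from which the picker takes max - mean as the response. Plain NumPy on a
# single pair of same-sized windows, whereas the real code batches stacks.
import numpy as np

_rng = np.random.default_rng(0)
_ref = _rng.standard_normal((8, 8))
_query = _rng.standard_normal((8, 8))

_cc_fft = np.real(np.fft.ifft2(np.fft.fft2(_ref) * np.conj(np.fft.fft2(_query))))

# Direct circular cross-correlation for comparison.
_cc_direct = np.empty_like(_cc_fft)
for _dy in range(8):
    for _dx in range(8):
        _cc_direct[_dy, _dx] = np.sum(
            _ref * np.roll(_query, shift=(_dy, _dx), axis=(0, 1))
        )

assert np.allclose(_cc_fft, _cc_direct)

_score = _cc_fft.max() - _cc_fft.mean()  # per-query response used by the picker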