def moments(cls, img, query_size):
    """Calculate the mean and standard deviation of every query window.

    Both statistics are computed with a normalized box filter of size
    (query_size x query_size). The filter is applied by multiplying FFTs of
    the zero-padded micrograph (and of its square) with the FFT of the
    zero-padded filter.

    Args:
        img: Micrograph image (2D array).
        query_size: Side length of the square windows over which the mean
            and standard deviation are computed.

    Returns:
        A tuple ``(mean_all, std_all)``. ``mean_all`` is the matrix of mean
        intensities and ``std_all`` is the matrix of standard deviations. Both
        have shape ``(img.shape[0] + query_size - 1,
        img.shape[1] + query_size - 1)``, with one entry per window position.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]

    # Normalized box filter: multiplying by it in Fourier space averages
    # each (query_size x query_size) window.
    filt = xp.ones((query_size, query_size)) / (query_size * query_size)
    # Pad the filter to the same size as the padded image below. Each axis
    # gets its own pad width so that non-square micrographs are supported;
    # padding both axes by img.shape[0] - 1 made the spectra mismatch.
    filt = xp.pad(
        filt,
        ((0, n_rows - 1), (0, n_cols - 1)),
        'constant',
        constant_values=(0, 0),
    )
    filt_freq = xp.fft2(filt, axes=(0, 1))

    pad_img = xp.pad(img, (0, query_size - 1), 'constant', constant_values=(0, 0))
    img_freq = xp.fft2(pad_img, axes=(0, 1))

    # E[x] over each window.
    mean_freq = xp.multiply(img_freq, filt_freq)
    mean_all = xp.ifft2(mean_freq, axes=(0, 1)).real

    # E[x^2] over each window. Stay on the xp backend rather than calling
    # numpy directly.
    pad_img_square = xp.power(pad_img, 2)
    img_var_freq = xp.fft2(pad_img_square, axes=(0, 1))
    var_freq = xp.multiply(img_var_freq, filt_freq)
    var_all = xp.ifft2(var_freq, axes=(0, 1))

    # Var = E[x^2] - E[x]^2. Clamp at 0 because FFT round-off can make it
    # slightly negative.
    var_all = var_all.real - xp.power(mean_all, 2)
    std_all = xp.sqrt(xp.maximum(0, var_all))
    return mean_all, std_all
def query_score(self, show_progress=True):
    """Calculate a score for each query image.

    Extracts query images and reference windows from the micrograph, then
    computes the cross-correlation between every reference window and the
    query images in the Fourier domain. Each reference yields a response map
    of (max - mean) correlation per query image. A global threshold is
    applied, and each query image is scored by the number of reference
    windows whose response reaches that threshold.

    Args:
        show_progress: Whether to show a progress bar.

    Returns:
        Matrix (converted with ``xp.asnumpy``) containing a score for each
        query image.
    """
    micro_img = xp.asarray(self.im)
    logger.info('Extracting query images')
    query_box = PickerHelper.extract_query(micro_img, self.query_size // 2)
    logger.info('Extracting query images complete')

    # Conjugating the query spectra turns the product in _work into a
    # cross-correlation.
    query_box = xp.conj(xp.fft2(query_box, axes=(2, 3)))

    reference_box = PickerHelper.extract_references(micro_img, self.query_size, self.container_size)
    reference_size = PickerHelper.reference_size(micro_img, self.container_size)

    # conv_map[r, i, j] holds the response of reference r at query (i, j).
    conv_map = xp.zeros((reference_size, query_box.shape[0], query_box.shape[1]))

    def _work(index):
        """Return ``(index, response)`` for the reference window ``index``."""
        reference_box_i = xp.fft2(reference_box[index], axes=(0, 1))
        window_t = xp.multiply(reference_box_i, query_box)
        cc = xp.ifft2(window_t, axes=(2, 3)).real
        return index, cc.max((2, 3)) - cc.mean((2, 3))

    n_works = reference_size
    n_threads = config.apple.conv_map_nthreads

    # Use tqdm as a context manager so the bar is closed even if a worker
    # raises (for example when future.result() re-raises).
    with tqdm(total=reference_size, disable=not show_progress) as pbar:
        # Ideally we'd like something like 'SerialExecutor' to enable easy
        # debugging, but for now do an if-else.
        if n_threads > 1:
            with futures.ThreadPoolExecutor(n_threads) as executor:
                to_do = [executor.submit(_work, i) for i in range(n_works)]
                for future in futures.as_completed(to_do):
                    i, res = future.result()
                    conv_map[i, :, :] = res
                    pbar.update(1)
        else:
            for i in range(n_works):
                _, conv_map[i, :, :] = _work(i)
                pbar.update(1)

    min_val = xp.min(conv_map)
    max_val = xp.max(conv_map)
    thresh = min_val + (max_val - min_val) / config.apple.response_thresh_norm_factor
    # Count, for each query position, how many references reach the
    # threshold. Summing over the reference axis (0) directly is equivalent
    # to transposing it to the last axis and summing there.
    return xp.asnumpy(xp.sum(conv_map >= thresh, axis=0))
def _work(index):
    """Score one reference window against all query images.

    Returns a pair ``(index, response)``. ``response`` is the maximum minus
    the mean of the real part of the inverse FFT over axes 2 and 3.
    """
    # NOTE(review): reference_box and query_box are free names here. This
    # mirrors the helper nested inside query_score, where they are closure
    # variables (query_box already holds conjugated spectra).
    spectrum = xp.fft2(reference_box[index], axes=(0, 1))
    product = xp.multiply(spectrum, query_box)
    correlation = xp.ifft2(product, axes=(2, 3)).real
    peak = correlation.max((2, 3))
    average = correlation.mean((2, 3))
    return index, peak - average
def testFft2(self):
    """A random 2D array survives an fft2 -> ifft2 round trip."""
    original = xp.random.random((100, 100))
    recovered = xp.ifft2(xp.fft2(original))
    round_trip_ok = np.allclose(original, recovered)
    self.assertTrue(round_trip_ok)