def compute_kernel(self):
    """
    Build this estimator's convolution kernel and return it in Fourier space.

    Images are processed in batches of at most `self.batch_size`. For each
    batch, the squared filter values on an L-by-L grid are multiplied by the
    squared image amplitudes. They are then passed to `anufft3` together with
    the rotated grid points, targeting a (2L, 2L, 2L) grid. Each result is
    scaled by 1 / (n * L**4) and accumulated.

    Afterwards, the first plane along every axis is zeroed. The array is then
    ifftshifted, transformed with `fft2` over all three axes, and only its
    real part is kept.

    :return: A `FourierKernel` wrapping the real transformed kernel, built
        with `centered=False`.
    """
    size = 2 * self.L
    grid_shape = (size, size, size)
    kernel = np.zeros(grid_shape, dtype=self.as_type)
    filters_sq = np.array(self.src.filters.evaluate_grid(self.L) ** 2, dtype=self.as_type)

    for start in range(0, self.n, self.batch_size):
        batch = slice(start, start + self.batch_size)
        pts_rot = rotated_grids(self.L, self.src.rots[:, :, batch])

        # Select the filter belonging to each image in the batch, then weight
        # it by that image's squared amplitude.
        weights = filters_sq[:, :, self.src.filters.indices[batch]]
        weights *= self.src.amplitudes[batch] ** 2

        if self.L % 2 == 0:
            # Even L: presumably drops the unpaired frequency at index 0 of
            # the centered grid — TODO confirm against rotated_grids.
            weights[0, :, :] = 0
            weights[:, 0, :] = 0

        pts_flat = m_reshape(pts_rot, (3, -1))
        weights_flat = m_flatten(weights)
        contribution = anufft3(weights_flat, pts_flat, grid_shape, real=True)
        kernel += 1 / (self.n * self.L ** 4) * contribution

    # Ensure symmetric kernel: zero the first plane along each axis of the
    # even-sized (2L) grid.
    kernel[0, :, :] = 0
    kernel[:, 0, :] = 0
    kernel[:, :, 0] = 0

    logger.info('Computing non-centered Fourier Transform')
    kernel = mdim_ifftshift(kernel, range(0, 3))
    kernel_f = np.real(fft2(kernel, axes=(0, 1, 2)))
    return FourierKernel(kernel_f, centered=False)
def im_backproject(im, rot_matrices):
    """
    Backproject a stack of images along their viewing directions.

    :param im: An L-by-L-by-n array of images to backproject.
    :param rot_matrices: A 3-by-3-by-n array of rotation matrices, one per
        image, giving the viewing directions.
    :return: An L-by-L-by-L volume holding the sum of the backprojected
        images: the real part of the adjoint NUFFT, divided by L.
    """
    L, _, n = im.shape
    ensure(L == im.shape[1], "im must be LxLxK")
    ensure(n == rot_matrices.shape[2], "No. of rotation matrices must match the number of images")

    # Rotated Fourier sample points for all images, one column per point.
    fourier_pts = m_reshape(rotated_grids(L, rot_matrices), (3, -1))

    # Centered 2D spectra, normalized by the pixel count L**2.
    spectra = centered_fft2(im) / (L ** 2)
    if L % 2 == 0:
        # Even L: presumably drops the unpaired frequency at index 0 of the
        # centered grid — TODO confirm against centered_fft2.
        spectra[0, :, :] = 0
        spectra[:, 0, :] = 0
    spectra_flat = m_flatten(spectra)

    plan = Plan(sz=(L, L, L), fourier_pts=fourier_pts)
    return np.real(plan.adjoint(spectra_flat)) / L
def evaluate_t(self, v):
    """
    Evaluate an array in the dual basis (the transpose of `evaluate`).

    For every basis function, the grid values selected by the basis mask are
    multiplied by that function's angular-times-normalized-radial profile
    and summed.

    Coefficients are laid out by increasing `ell`. Each `ell` gets one
    angular column (ell == 0) or two (ell > 0), and each angular column
    covers `self.k_max[ell]` consecutive coefficients.

    :param v: The array to be evaluated. Its first dimensions must equal
        `self.sz`; any further dimensions are carried through.
    :return: An array whose first dimension equals `self.basis_count` and
        whose remaining dimensions correspond to the higher dimensions of
        `v`. It is allocated with NumPy's default float dtype.
    """
    samples, sz_roll = unroll_dim(v, self.d + 1)
    flat_shape = tuple([np.prod(self.sz)] + list(samples.shape[self.d:]))
    samples = m_reshape(samples, new_shape=flat_shape)

    coords = self.basis_coords
    r_idx = coords['r_idx']
    ang_idx = coords['ang_idx']
    mask = m_flatten(coords['mask'])

    coeffs = np.zeros(shape=tuple([self.basis_count] + list(samples.shape[1:])))
    coeff_start = radial_start = ang_col = 0
    for ell in range(self.ell_max + 1):
        k_max = self.k_max[ell]
        radial_cols = radial_start + np.arange(k_max)
        radial = self._precomp['radial'][:, radial_cols] / self._norms[radial_cols]

        # ell == 0 uses a single angular column; every other ell uses two.
        for _ in range(1 if ell == 0 else 2):
            ang = self._precomp['ang'][:, ang_col]
            ang_radial = np.expand_dims(ang[ang_idx], axis=1) * radial[r_idx]
            rows = coeff_start + np.arange(k_max)
            coeffs[rows] = ang_radial.T @ samples[mask]
            coeff_start += len(rows)
            ang_col += 1

        radial_start += len(radial_cols)

    return roll_dim(coeffs, sz_roll)
def evaluate(self, v):
    """
    Evaluate coefficient vector(s) in this basis.

    Each basis function contributes its angular-times-normalized-radial
    profile, weighted by its coefficient, to the grid points selected by the
    basis mask. All other grid points stay zero.

    Coefficients are read by increasing `ell`. Each `ell` gets one angular
    column (ell == 0) or two (ell > 0), and each angular column covers
    `self.k_max[ell]` consecutive coefficients.

    :param v: A coefficient vector (or an array of coefficient vectors) to be
        evaluated. The first dimension must equal `self.basis_count`.
    :return: The evaluation of the coefficient vector(s) `v` for this basis.
        This is an array whose first dimensions equal `self.sz` and whose
        remaining dimensions correspond to dimensions two and higher of `v`.
        It is allocated with NumPy's default float dtype.
    """
    coeffs, sz_roll = unroll_dim(v, 2)

    coords = self.basis_coords
    r_idx = coords['r_idx']
    ang_idx = coords['ang_idx']
    mask = m_flatten(coords['mask'])

    # Output on the flattened grid; only masked points receive contributions.
    x = np.zeros(shape=tuple([np.prod(self.sz)] + list(coeffs.shape[1:])))
    coeff_start = radial_start = ang_col = 0
    for ell in range(self.ell_max + 1):
        k_max = self.k_max[ell]
        radial_cols = radial_start + np.arange(k_max)
        radial = self._precomp['radial'][:, radial_cols] / self._norms[radial_cols]

        # ell == 0 uses a single angular column; every other ell uses two.
        for _ in range(1 if ell == 0 else 2):
            ang = self._precomp['ang'][:, ang_col]
            ang_radial = np.expand_dims(ang[ang_idx], axis=1) * radial[r_idx]
            rows = coeff_start + np.arange(k_max)
            x[mask] += ang_radial @ coeffs[rows]
            coeff_start += len(rows)
            ang_col += 1

        radial_start += len(radial_cols)

    x = m_reshape(x, self.sz + x.shape[1:])
    return roll_dim(x, sz_roll)
def precomp(self):
    """
    Precompute the basis functions on a polar Fourier 3D grid.

    Gaussian quadrature points and weights are generated in the radial and
    phi dimensions. Theta uses equispaced points.

    Grid sizes: n_r = n_phi = ell_max + 1 and n_theta = 2 * sz[0].

    :return: A dict with keys

        'radial_wtd'
            An (n_r, max(k_max), ell_max + 1) array. Slice
            [:, :k_max[ell], ell] holds sph_bessel(ell, r * r0[k, ell] / c),
            divided by |sph_bessel(ell + 1, r0[k, ell]) / 4| and multiplied
            by r**2 * wt_r. Columns beyond k_max[ell] stay zero.
        'ang_phi_wtd_even', 'ang_phi_wtd_odd'
            Lists indexed by m = 0..ell_max of (n_phi, count) arrays. Their
            columns are norm_assoc_legendre(ell, m, z) * sqrt(0.5 / pi) *
            wt_phi for ell = m..ell_max, split by the parity of ell.
        'ang_theta_wtd'
            An (n_theta, 2 * ell_max + 1) array of
            [sqrt(2) sin(k theta) for k = ell_max..1, 1,
            sqrt(2) cos(k theta) for k = 1..ell_max], times 2 * pi / n_theta.
        'fourier_pts'
            The (theta, phi, r) grid converted to Cartesian coordinates,
            multiplied by 2 * pi and stacked as x, y, z rows.
    """
    n_r = int(self.ell_max + 1)
    n_theta = int(2 * self.sz[0])
    n_phi = int(self.ell_max + 1)

    # Quadrature nodes/weights: r on [0, self.c], z = cos(phi) on [-1, 1]
    # (lgwt presumably returns Legendre-Gauss nodes and weights).
    r, wt_r = lgwt(n_r, 0.0, self.c)
    z, wt_z = lgwt(n_phi, -1, 1)
    r = m_reshape(r, (n_r, 1))
    wt_r = m_reshape(wt_r, (n_r, 1))
    z = m_reshape(z, (n_phi, 1))
    wt_z = m_reshape(wt_z, (n_phi, 1))
    phi = np.arccos(z)
    wt_phi = wt_z
    # NOTE(review): the spacing here is 2*pi/(2*n_theta) = pi/n_theta, so
    # theta covers [0, pi), yet ang_theta_wtd below weights every sample by
    # 2*pi/n_theta. Confirm the half-range grid is intentional (e.g.
    # completed by symmetry in the caller) before relying on it.
    theta = 2 * pi * np.arange(n_theta).T / (2 * n_theta)
    theta = m_reshape(theta, (n_theta, 1))

    # evaluate basis function in the radial dimension
    radial_wtd = np.zeros(shape=(n_r, np.max(self.k_max), self.ell_max + 1))
    for ell in range(0, self.ell_max + 1):
        k_max_ell = self.k_max[ell]
        # Column k holds the radial nodes scaled by r0[k, ell] / c.
        rmat = r * self.r0[0:k_max_ell, ell].T / self.c
        radial_ell = np.zeros_like(rmat)
        for ik in range(0, k_max_ell):
            radial_ell[:, ik] = sph_bessel(ell, rmat[:, ik])
        # Per-column normalization by |j_{ell+1}(r0) / 4|.
        nrm = np.abs(sph_bessel(ell + 1, self.r0[0:k_max_ell, ell].T) / 4)
        radial_ell = radial_ell / nrm
        # Fold in the r**2 Jacobian and the radial quadrature weights.
        radial_ell_wtd = r**2 * wt_r * radial_ell
        radial_wtd[:, 0:k_max_ell, ell] = radial_ell_wtd

    # evaluate basis function in the phi dimension
    ang_phi_wtd_even = []
    ang_phi_wtd_odd = []
    for m in range(0, self.ell_max + 1):
        # Number of even / odd ell values in [m, ell_max].
        n_even_ell = int(
            np.floor((self.ell_max - m) / 2) + 1 - np.mod(self.ell_max, 2) * np.mod(m, 2))
        n_odd_ell = int(self.ell_max - m + 1 - n_even_ell)
        phi_wtd_m_even = np.zeros((n_phi, n_even_ell), dtype=phi.dtype)
        phi_wtd_m_odd = np.zeros((n_phi, n_odd_ell), dtype=phi.dtype)

        ind_even = 0
        ind_odd = 0
        for ell in range(m, self.ell_max + 1):
            phi_m_ell = norm_assoc_legendre(ell, m, z)
            # Constant factor; recomputed every iteration although it never changes.
            nrm_inv = np.sqrt(0.5 / pi)
            phi_m_ell = nrm_inv * phi_m_ell
            phi_wtd_m_ell = wt_phi * phi_m_ell
            # Route the column into the even- or odd-ell table by parity.
            if np.mod(ell, 2) == 0:
                phi_wtd_m_even[:, ind_even] = phi_wtd_m_ell[:, 0]
                ind_even = ind_even + 1
            else:
                phi_wtd_m_odd[:, ind_odd] = phi_wtd_m_ell[:, 0]
                ind_odd = ind_odd + 1

        ang_phi_wtd_even.append(phi_wtd_m_even)
        ang_phi_wtd_odd.append(phi_wtd_m_odd)

    # evaluate basis function in the theta dimension
    # Column layout: [sin(ell_max*theta) .. sin(theta), 1, cos(theta) .. cos(ell_max*theta)],
    # with the sin/cos columns scaled by sqrt(2).
    ang_theta = np.zeros((n_theta, 2 * self.ell_max + 1), dtype=theta.dtype)
    ang_theta[:, 0:self.ell_max] = np.sqrt(2) * np.sin(
        theta @ m_reshape(np.arange(self.ell_max, 0, -1), (1, self.ell_max)))
    ang_theta[:, self.ell_max] = np.ones(n_theta, dtype=theta.dtype)
    ang_theta[:, self.ell_max + 1:2 * self.ell_max + 1] = np.sqrt(2) * np.cos(
        theta @ m_reshape(np.arange(1, self.ell_max + 1), (1, self.ell_max)))

    ang_theta_wtd = (2 * pi / n_theta) * ang_theta

    # meshgrid flattens the (n, 1) inputs; with 'ij' indexing every grid has
    # shape (n_theta, n_phi, n_r).
    theta_grid, phi_grid, r_grid = np.meshgrid(theta, phi, r, sparse=False, indexing='ij')
    # Spherical -> Cartesian: (r sin(phi) cos(theta), r sin(phi) sin(theta), r cos(phi)).
    fourier_x = m_flatten(r_grid * np.cos(theta_grid) * np.sin(phi_grid))
    fourier_y = m_flatten(r_grid * np.sin(theta_grid) * np.sin(phi_grid))
    fourier_z = m_flatten(r_grid * np.cos(phi_grid))
    fourier_pts = 2 * pi * np.vstack(
        (fourier_x[np.newaxis, ...], fourier_y[np.newaxis, ...], fourier_z[np.newaxis, ...]))

    return {
        'radial_wtd': radial_wtd,
        'ang_phi_wtd_even': ang_phi_wtd_even,
        'ang_phi_wtd_odd': ang_phi_wtd_odd,
        'ang_theta_wtd': ang_theta_wtd,
        'fourier_pts': fourier_pts
    }