예제 #1
0
    def evaluate(self, v):
        """
        Map polar Fourier coefficients onto the standard 2D coordinate basis.

        :param v: A coefficient vector (or an array of coefficient vectors)
            in polar Fourier basis. It is reshaped internally to
            `(-1, self.ntheta, self.nrad)`.
        :return: Image instance wrapping the adjoint NUFFT output computed
            for a grid of size `self.sz`.
        :raises TypeError: If the real type matching `v.dtype` differs from
            `self.dtype`.
        """
        if self.dtype != real_type(v.dtype):
            msg = (
                f"Input data type, {v.dtype}, is not consistent with"
                f" type defined in the class {self.dtype}."
            )
            logger.error(msg)
            raise TypeError(msg)

        rays = v.reshape(-1, self.ntheta, self.nrad)
        n_images = rays.shape[0]
        n_half = self.ntheta // 2

        # Fold the second half of the angular samples (conjugated) onto the
        # first half, so only `n_half` angles are passed to the NUFFT.
        folded = rays[:, :n_half, :] + rays[:, n_half:, :].conj()
        folded = folded.reshape(n_images, self.nrad * n_half)

        return Image(anufft(folded, self.freqs, self.sz, real=True))
예제 #2
0
    def backproject(self, rot_matrices):
        """
        Backproject the images held by this instance along the given rotations.

        :param rot_matrices: An n-by-3-by-3 array of rotation matrices
            corresponding to viewing directions. Its first dimension must
            equal `self.n_images`.
        :return: Volume instance corresponding to the backprojected images.
        """
        res = self.res

        ensure(
            self.n_images == rot_matrices.shape[0],
            "Number of rotation matrices must match the number of images",
        )

        # TODO: rotated_grids might as well give us correctly shaped array in the first place
        grid_pts = aspire.volume.rotated_grids(res, rot_matrices)
        grid_pts = m_reshape(np.moveaxis(grid_pts, 1, 2), (3, -1))

        # Centered 2D FFT of the image data, scaled by 1 / res**2.
        spectrum = fft.centered_fft2(xp.asarray(self.data))
        spectrum = xp.asnumpy(spectrum) / (res**2)
        if res % 2 == 0:
            # For even resolution, zero index 0 along both image axes.
            spectrum[:, 0, :] = 0
            spectrum[:, :, 0] = 0

        flat_spectrum = spectrum.flatten()
        vol = anufft(flat_spectrum, grid_pts, (res, res, res), real=True) / res

        return aspire.volume.Volume(vol)
예제 #3
0
    def compute_kernel(self):
        """
        Accumulate the (2L, 2L, 2L) kernel over all images and return it in
        the Fourier domain.

        Images are visited in batches of `self.batch_size`. For each batch the
        squared filter values, scaled by the squared amplitudes, are passed
        through an adjoint NUFFT on the rotated grid points and added to the
        kernel with a factor of `1 / (n * L**4)`.

        :return: FourierKernel built from the real part of the transformed
            kernel, flagged as non-centered.
        """
        L = self.L
        n = self.n
        kernel_shape = (2 * L, 2 * L, 2 * L)

        kernel = np.zeros(kernel_shape, dtype=self.dtype)
        sq_filters_f = self.src.eval_filter_grid(L, power=2)

        for start in range(0, n, self.batch_size):
            stop = min(n, start + self.batch_size)
            batch = np.arange(start, stop, dtype=int)

            pts_rot = rotated_grids(L, self.src.rots[batch, :, :])

            # Indexing with an index array yields a copy, so the in-place
            # updates below leave `sq_filters_f` untouched.
            weights = sq_filters_f[:, :, batch]
            weights *= self.src.amplitudes[batch] ** 2

            if L % 2 == 0:
                weights[0, :, :] = 0
                weights[:, 0, :] = 0

            contribution = anufft(
                m_flatten(weights),
                m_reshape(pts_rot, (3, -1)),
                kernel_shape,
                real=True,
            )
            # The scale factor is evaluated inside the loop on purpose: with
            # n == 0 the loop body never runs and no division takes place.
            kernel += 1 / (n * L**4) * contribution

        # Ensure symmetric kernel
        kernel[0] = 0
        kernel[:, 0] = 0
        kernel[:, :, 0] = 0

        logger.info("Computing non-centered Fourier Transform")
        shifted = mdim_ifftshift(kernel, range(3))
        kernel_f = np.real(fft2(shifted, axes=(0, 1, 2)))

        return FourierKernel(kernel_f, centered=False)
예제 #4
0
    def compute_kernel(self):
        """
        Accumulate the six-dimensional kernel (each axis of length 2L) over
        all images and return it in the Fourier domain.

        Images are visited in batches of `self.batch_size`. Every image in a
        batch gets its own adjoint NUFFT "factor" volume; the batch adds
        `factors.T @ factors`, converted with `vecmat_to_volmat` and divided by
        `n * L**8`, to the kernel.

        :return: FourierKernel built from the real part of the transformed
            kernel, flagged as non-centered.
        """
        # TODO: Most of this stuff is duplicated in MeanEstimator - move up the hierarchy?
        n, L = self.n, self.L
        _2L = 2 * L
        vol_shape = (_2L, _2L, _2L)

        kernel = np.zeros(vol_shape + vol_shape, dtype=self.dtype)
        sq_filters_f = self.src.eval_filter_grid(L, power=2)

        for start in tqdm(range(0, n, self.batch_size)):
            batch = np.arange(start, min(n, start + self.batch_size))
            pts_rot = rotated_grids(L, self.src.rots[batch, :, :])

            # Indexing with an index array yields a copy, so the in-place
            # updates below leave `sq_filters_f` untouched.
            weights = sq_filters_f[:, :, batch]
            weights *= self.src.amplitudes[batch] ** 2

            if L % 2 == 0:
                weights[0, :, :] = 0
                weights[:, 0, :] = 0

            # TODO: This is where this differs from MeanEstimator
            pts_rot = np.moveaxis(pts_rot, -1, 0).reshape(-1, 3, L**2)
            weights = weights.T.reshape((-1, L**2))

            batch_n = weights.shape[0]
            factors = np.zeros((batch_n,) + vol_shape, dtype=self.dtype)

            # One adjoint NUFFT per image in the batch.
            for j, image_weights in enumerate(weights):
                factors[j] = anufft(
                    image_weights, pts_rot[j], vol_shape, real=True
                )

            factors = Volume(factors).to_vec()
            kernel += vecmat_to_volmat(factors.T @ factors) / (n * L**8)

        # Ensure symmetric kernel: zero index 0 along each of the six axes.
        for axis in range(kernel.ndim):
            selector = [slice(None)] * kernel.ndim
            selector[axis] = 0
            kernel[tuple(selector)] = 0

        logger.info("Computing non-centered Fourier Transform")
        kernel = mdim_ifftshift(kernel, range(6))
        # Kernel is always symmetric in spatial domain and therefore real in Fourier
        kernel_f = np.real(fftn(kernel))

        return FourierKernel(kernel_f, centered=False)
예제 #5
0
    def evaluate(self, v):
        """
        Evaluate coefficients in standard 2D coordinate basis from those in FB basis

        :param v: A coefficient vector (or an array of coefficient vectors)
            in FB basis to be evaluated. The last dimension must equal `self.count`.
        :return: The evaluation of the coefficient vector(s) `v` in standard 2D
            coordinate basis, as an Image instance. The underlying array has
            shape `(*v.shape[:-1], *self.sz)`.
        """
        # A dtype mismatch is only reported at debug level; evaluation proceeds.
        if v.dtype != self.dtype:
            logger.debug(
                f"{self.__class__.__name__}::evaluate"
                f" Inconsistent dtypes v: {v.dtype} self: {self.dtype}")

        leading_shape = v.shape[:-1]
        v = v.reshape(-1, self.count)

        # number of 2D image samples
        n_data = v.shape[0]

        # polar grid sizes are read off the precomputed frequency array
        polar_freqs = self._precomp["freqs"]
        n_theta = np.size(polar_freqs, 2)
        n_r = np.size(polar_freqs, 1)

        pf = np.zeros((n_data, 2 * n_theta, n_r),
                      dtype=complex_type(self.dtype))

        # include the normalization factor of angular part into radial part
        radial_norm = self._precomp["radial"] / np.expand_dims(
            self.angular_norms, 1)

        # ell == 0: coefficients are selected by mask and fill angular slot 0
        zero_mask = self._indices["ells"] == 0
        radial_rows = np.arange(self.k_max[0], dtype=int)
        pf[:, 0, :] = v[:, zero_mask] @ radial_norm[radial_rows]

        # Running offsets: one into the rows of `radial_norm`, one into the
        # columns of `v`. The column offset advances by two blocks of
        # k_max[ell] per ell.
        radial_offset = np.size(radial_rows)
        coef_offset = radial_offset

        for ell in range(1, self.ell_max + 1):
            k_ell = self.k_max[ell]
            steps = np.arange(k_ell, dtype=int)

            radial_rows = radial_offset + steps
            pos_cols = coef_offset + steps
            neg_cols = pos_cols + k_ell

            coef_ell = (v[:, pos_cols] - 1j * v[:, neg_cols]) / 2.0

            is_odd = np.mod(ell, 2) == 1
            if is_odd:
                coef_ell = 1j * coef_ell

            pf_ell = coef_ell @ radial_norm[radial_rows]
            pf[:, ell, :] = pf_ell

            # The mirrored angular slot receives the conjugate, negated for
            # odd ell.
            mirrored = pf_ell.conjugate()
            if is_odd:
                pf[:, 2 * n_theta - ell, :] = -mirrored
            else:
                pf[:, 2 * n_theta - ell, :] = mirrored

            radial_offset = radial_offset + np.size(radial_rows)
            coef_offset = coef_offset + 2 * k_ell

        # 1D inverse FFT in the degree of polar angle
        pf = 2 * pi * xp.asnumpy(fft.ifft(xp.asarray(pf), axis=1))

        # Only need "positive" frequencies.
        n_keep = int(np.size(pf, 1) / 2)
        pf = pf[:, 0:n_keep, :]

        # Scale each radial sample by the product of its precomputed
        # `gl_weights` and `gl_nodes` entries.
        gl_weights = self._precomp["gl_weights"]
        gl_nodes = self._precomp["gl_nodes"]
        for i_r in range(n_r):
            pf[..., i_r] = pf[..., i_r] * (gl_weights[i_r] * gl_nodes[i_r])

        pf = np.reshape(pf, (n_data, n_r * n_theta))

        # perform inverse non-uniformly FFT transform back to 2D coordinate basis
        flat_freqs = m_reshape(self._precomp["freqs"], (2, n_r * n_theta))
        x = 2 * anufft(pf, 2 * pi * flat_freqs, self.sz, real=True)

        # Return X as Image instance with the last two dimensions as *self.sz
        x = x.reshape((*leading_shape, *self.sz))

        return Image(x)
예제 #6
0
    def evaluate(self, v):
        """
        Evaluate coefficients in standard 3D coordinate basis from those in 3D FB basis

        The evaluation applies three precomputed weighted factors in turn
        (radial, then phi, then theta), combines the even-ell and odd-ell
        results into one complex array, and hands it to an adjoint NUFFT.

        :param v: A coefficient vector (or an array of coefficient vectors) in FB basis
            to be evaluated. The last dimension must equal `self.count`.
        :return x: The evaluation of the coefficient vector(s) `v` in standard 3D
            coordinate basis. This is an array whose last three dimensions equal
            `self.sz` and the remaining dimensions correspond to `v`.
        """
        # roll dimensions of v
        sz_roll = v.shape[:-1]
        v = v.reshape((-1, self.count))

        # get information on polar grids from precomputed data
        n_theta = np.size(self._precomp["ang_theta_wtd"], 0)
        n_phi = np.size(self._precomp["ang_phi_wtd_even"][0], 0)
        n_r = np.size(self._precomp["radial_wtd"], 0)

        # number of 3D image samples
        n_data = v.shape[0]

        # Work buffers for even and odd ell. Axes are
        # (radial sample, order slot, data sample, ell slot); the order axis
        # has 2 * ell_max + 1 slots and the ell axis has one slot per even
        # (respectively odd) ell in 0..ell_max.
        u_even = np.zeros(
            (
                n_r,
                int(2 * self.ell_max + 1),
                n_data,
                int(np.floor(self.ell_max / 2) + 1),
            ),
            dtype=v.dtype,
        )
        u_odd = np.zeros(
            (n_r, int(2 * self.ell_max + 1), n_data,
             int(np.ceil(self.ell_max / 2))),
            dtype=v.dtype,
        )

        # go through each basis function and find corresponding coefficient
        # evaluate the radial parts
        for ell in range(0, self.ell_max + 1):
            k_max_ell = self.k_max[ell]
            radial_wtd = self._precomp["radial_wtd"][:, 0:k_max_ell, ell]

            # boolean mask selecting the coefficients that belong to this ell
            ind = self._indices["ells"] == ell

            # NOTE(review): m_reshape is a project helper not visible here;
            # presumably a column-major (Fortran-order) reshape - confirm.
            v_ell = m_reshape(v[:, ind].T, (k_max_ell, (2 * ell + 1) * n_data))
            v_ell = radial_wtd @ v_ell
            v_ell = m_reshape(v_ell, (n_r, 2 * ell + 1, n_data))

            # Store the 2 * ell + 1 entries in the order slots centered on
            # index ell_max; even ell goes to slot ell / 2 of u_even, odd ell
            # to slot (ell - 1) / 2 of u_odd.
            if np.mod(ell, 2) == 0:
                u_even[:,
                       int(self.ell_max - ell):int(self.ell_max + ell + 1), :,
                       int(ell / 2), ] = v_ell
            else:
                u_odd[:,
                      int(self.ell_max - ell):int(self.ell_max + ell + 1), :,
                      int((ell - 1) / 2), ] = v_ell

        # Move the ell-slot axis to the front:
        # (ell slot, radial sample, order slot, data sample).
        u_even = np.transpose(u_even, (3, 0, 1, 2))
        u_odd = np.transpose(u_odd, (3, 0, 1, 2))
        w_even = np.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1),
                          dtype=v.dtype)
        w_odd = np.zeros((n_phi, n_r, n_data, 2 * self.ell_max + 1),
                         dtype=v.dtype)

        # evaluate the phi parts
        for m in range(0, self.ell_max + 1):
            ang_phi_wtd_m_even = self._precomp["ang_phi_wtd_even"][m]
            ang_phi_wtd_m_odd = self._precomp["ang_phi_wtd_odd"][m]

            # number of ell slots covered by the precomputed matrices for this m
            n_even_ell = np.size(ang_phi_wtd_m_even, 1)
            n_odd_ell = np.size(ang_phi_wtd_m_odd, 1)

            # m == 0 has a single order slot; every other m is processed for
            # both +m and -m with the same precomputed matrices.
            if m == 0:
                sgns = (1, )
            else:
                sgns = (1, -1)

            for sgn in sgns:

                # Take the last n_even_ell (n_odd_ell) ell slots at order
                # slot ell_max + sgn * m.
                end = np.size(u_even, 0)
                u_m_even = u_even[end - n_even_ell:end, :,
                                  self.ell_max + sgn * m, :]
                end = np.size(u_odd, 0)
                u_m_odd = u_odd[end - n_odd_ell:end, :,
                                self.ell_max + sgn * m, :]

                u_m_even = m_reshape(u_m_even, (n_even_ell, n_r * n_data))
                u_m_odd = m_reshape(u_m_odd, (n_odd_ell, n_r * n_data))

                w_m_even = ang_phi_wtd_m_even @ u_m_even
                w_m_odd = ang_phi_wtd_m_odd @ u_m_odd

                w_m_even = m_reshape(w_m_even, (n_phi, n_r, n_data))
                w_m_odd = m_reshape(w_m_odd, (n_phi, n_r, n_data))

                w_even[:, :, :, self.ell_max + sgn * m] = w_m_even
                w_odd[:, :, :, self.ell_max + sgn * m] = w_m_odd

        # Move the order-slot axis to the front:
        # (order slot, phi sample, radial sample, data sample).
        w_even = np.transpose(w_even, (3, 0, 1, 2))
        w_odd = np.transpose(w_odd, (3, 0, 1, 2))
        u_even = w_even
        u_odd = w_odd

        u_even = m_reshape(u_even,
                           (2 * self.ell_max + 1, n_phi * n_r * n_data))
        u_odd = m_reshape(u_odd, (2 * self.ell_max + 1, n_phi * n_r * n_data))

        # evaluate the theta parts
        w_even = self._precomp["ang_theta_wtd"] @ u_even
        w_odd = self._precomp["ang_theta_wtd"] @ u_odd

        # The even-ell result forms the real term and the odd-ell result is
        # multiplied by 1j.
        pf = w_even + 1j * w_odd
        pf = m_reshape(pf, (n_theta * n_phi * n_r, n_data))
        # put the data-sample axis first, as passed to anufft below
        pf = np.moveaxis(pf, 0, -1)

        # perform inverse non-uniformly FFT transformation back to 3D rectangular coordinates
        freqs = m_reshape(self._precomp["fourier_pts"],
                          (3, n_r * n_theta * n_phi))
        x = anufft(pf, freqs, self.sz, real=True)

        # Roll, return the x with the last three dimensions as self.sz
        # Higher dimensions should be like v.
        x = x.reshape((*sz_roll, *self.sz))
        return x