Exemple #1
0
    def convolve_volume_matrix(self, x):
        """
        Convolve volume matrix with kernel.

        Each of the six axes of `x` is transformed with `fft` to length N_ker (the
        kernel's length), multiplied by the kernel stored in ``self.kernel``,
        inverse transformed, and cropped back to its first N samples.

        :param x: An N-by-...-by-N (6 dimensions) volume matrix to be convolved.
            The caller's array is not overwritten.
        :return: The original volume matrix convolved by the kernel with the same dimensions as before.
        """
        shape = x.shape
        N = shape[0]
        kernel_f = self.kernel
        # All six axes must have length N: the crop at the end keeps N samples on every
        # axis, so any other length would change the output shape.
        ensure(
            x.ndim == 6 and len(set(shape)) == 1,
            "Volume matrix must be cubic and square",
        )

        # TODO from MATLAB code: Deal with rolled dimensions
        N_ker = kernel_f.shape[0]

        # Note from MATLAB code:
        # Order is important here.  It's about 20% faster to run from 1 through 6 compared with 6 through 1.
        # TODO: Experiment with scipy order; try overwrite_x argument
        for i in range(6):
            # Only buffers produced by this method may be overwritten; the first
            # transform reads the caller's array.
            x = fft(x, N_ker, i, overwrite_x=(i > 0))

        x *= kernel_f

        indices = list(range(N))
        for i in range(5, -1, -1):
            x = ifft(x, None, i, overwrite_x=True)
            # Crop this axis back to the original N samples.
            x = x.take(indices, axis=i)

        return np.real(x)
Exemple #2
0
 def precond_fun(S, x):
     """
     Apply the two-sided product ``S @ X @ S`` to a vectorised matrix.

     :param S: A p-by-p matrix.
     :param x: An array with p*p entries, reshaped (via ``m_reshape``) to p-by-p.
     :return: ``S @ X @ S`` reshaped back to a vector of length p**2.
     """
     dim = np.size(S, 0)
     ensure(np.size(x) == dim*dim, 'The sizes of S and x are not consistent.')
     product = S @ m_reshape(x, (dim, dim)) @ S
     return m_reshape(product, (dim**2,))
Exemple #3
0
    def backproject(self, rot_matrices):
        """
        Backproject this image stack along the given rotations.

        :param rot_matrices: An n-by-3-by-3 array of rotation matrices, one per image,
            corresponding to viewing directions.
        :return: Volume instance corresponding to the backprojected images.
        """
        res = self.res

        ensure(
            self.n_images == rot_matrices.shape[0],
            "Number of rotation matrices must match the number of images",
        )

        # TODO: rotated_grids might as well give us correctly shaped array in the first place
        # Rotated sampling points, rearranged and flattened to a 3-by-(n*res*res) array.
        grid = m_reshape(
            np.moveaxis(aspire.volume.rotated_grids(res, rot_matrices), 1, 2),
            (3, -1),
        )

        spectra = xp.asnumpy(fft.centered_fft2(xp.asarray(self.data))) / (res**2)
        if res % 2 == 0:
            # For even resolutions, zero row 0 and column 0 of every centered spectrum
            # (presumably the unpaired Nyquist frequency).
            spectra[:, 0, :] = 0
            spectra[:, :, 0] = 0

        vol = anufft(spectra.flatten(), grid, (res, res, res), real=True) / res

        return aspire.volume.Volume(vol)
Exemple #4
0
    def __init__(self, size, ell_max=None):
        """
        Check that `size` describes a cubic 3D domain, then defer to the parent class.

        :param size: Sequence of grid lengths; it must have exactly three equal entries.
        :param ell_max: Forwarded unchanged to the parent constructor.
        """
        ndim = len(size)
        ensure(ndim == 3, 'Only three-dimensional basis functions are supported.')
        ensure(len(set(size)) == 1, 'Only cubic domains are supported.')

        super().__init__(size, ell_max)
Exemple #5
0
def unique_coords_nd(N, ndim, shifted=False, normalized=True):
    """
    Generate unique polar coordinates from 2D or 3D rectangular coordinates.

    Grid points with radius ``r <= 1`` are selected; their radius and angle values
    are rounded to 5 decimals and de-duplicated with ``np.unique``. The masked
    values are gathered in an order that emulates MATLAB's column-major traversal.

    :param N: length size of a square or cube.
    :param ndim: number of dimension, 2 or 3.
    :param shifted: shifted half pixel or not for odd N.
    :param normalized: normalize the grid or not.
    :return: A dict with keys
        'r_unique' (sorted unique rounded radii),
        'ang_unique' (unique rounded angles: a 1D array of phi in 2D; in 3D a 2-by-K
        array whose columns are unique (theta, phi) pairs),
        'r_idx' / 'ang_idx' (inverse indices from ``np.unique`` mapping each masked
        point to its unique value),
        'mask' (boolean array of grid points with r <= 1, in the grid's own axis order).
    """
    ensure(ndim in (2, 3),
           'Only two- or three-dimensional basis functions are supported.')
    ensure(N > 0, 'Number of grid points should be greater than 0.')

    if ndim == 2:
        grid = grid_2d(N, shifted=shifted, normalized=normalized)
        mask = grid['r'] <= 1

        # Minor differences in r/theta/phi values are unimportant for the purpose
        # of this function, so round off before proceeding

        # TODO: numpy boolean indexing will return a 1d array (like MATLAB)
        # However, it always searches in row-major order, unlike MATLAB (column-major),
        # with no options to change the search order. The results we'll be getting back are thus not comparable.
        # We transpose the appropriate ndarrays before applying the mask to obtain the same behavior as MATLAB.
        # NOTE(review): only the value arrays are transposed; `mask` itself is not. That is
        # equivalent to using mask.T only if `grid['r'] <= 1` is symmetric under transpose,
        # presumably true for a square grid. Confirm against grid_2d.
        r = grid['r'].T[mask].round(5)
        phi = grid['phi'].T[mask].round(5)

        # The inverse indices map each masked point back to its unique value.
        r_unique, r_idx = np.unique(r, return_inverse=True)
        ang_unique, ang_idx = np.unique(phi, return_inverse=True)

    else:
        grid = grid_3d(N, shifted=shifted, normalized=normalized)
        mask = grid['r'] <= 1

        # In Numpy, elements in the indexed array are always iterated and returned in row-major (C-style) order.
        # To emulate a behavior where iteration happens in Fortran order, we swap axes 0 and 2 of both the array
        # being indexed (r/theta/phi), as well as the mask itself.
        # TODO: This is only for the purpose of getting the same behavior as MATLAB while porting the code, and is
        # likely not needed in the final version.

        # Minor differences in r/theta/phi values are unimportant for the purpose of this function,
        # so we round off before proceeding.

        mask_ = np.swapaxes(mask, 0, 2)
        r = np.swapaxes(grid['r'], 0, 2)[mask_].round(5)
        theta = np.swapaxes(grid['theta'], 0, 2)[mask_].round(5)
        phi = np.swapaxes(grid['phi'], 0, 2)[mask_].round(5)

        r_unique, r_idx = np.unique(r, return_inverse=True)
        # Stack angles as a 2-by-M array so that unique (theta, phi) pairs are found column-wise.
        ang_unique, ang_idx = np.unique(np.vstack([theta, phi]),
                                        axis=1,
                                        return_inverse=True)

    # `mask` is returned in the grid's own (unswapped, untransposed) orientation.
    return {
        'r_unique': r_unique,
        'ang_unique': ang_unique,
        'r_idx': r_idx,
        'ang_idx': ang_idx,
        'mask': mask
    }
Exemple #6
0
    def __init__(self, data, dtype=None):
        """
        A stack of one or more images.

        This is a wrapper of numpy.ndarray which provides methods
        for common processing tasks.

        :param data: Numpy array containing image data with shape `(n_images, res, res)`.
            A single 2D image of shape `(res, res)` is promoted to a stack of one.
        :param dtype: Optionally cast `data` to this dtype. Defaults to `data.dtype`.
        :return: Image instance storing `data`.
        :raises AssertionError: If `data` is not a numpy ndarray.
        """
        # Raise explicitly instead of using `assert`, so the check still runs under
        # `python -O`. AssertionError is kept for callers that already catch it.
        if not isinstance(data, np.ndarray):
            raise AssertionError("Image should be instantiated with an ndarray")

        if data.ndim == 2:
            data = data[np.newaxis, :, :]

        # Validate the image shape before storing any state on the instance.
        ensure(data.shape[1] == data.shape[2],
               "Only square ndarrays are supported.")

        self.dtype = data.dtype if dtype is None else np.dtype(dtype)

        # copy=False avoids a copy when `data` already has the requested dtype.
        self.data = data.astype(self.dtype, copy=False)
        self.ndim = self.data.ndim
        self.shape = self.data.shape
        self.n_images = self.shape[0]
        self.res = self.shape[1]
Exemple #7
0
def vec_to_symmat(vec):
    """
    Convert packed lower triangular vector to symmetric matrix.

    :param vec: A vector of size N*(N+1)/2-by-... describing a symmetric (or Hermitian) matrix.
    :return: An array of size N-by-N-by-... which indexes symmetric/Hermitian matrices that occupy the first two
        dimensions. The lower triangular parts of these matrices consist of the corresponding vectors in vec.
    :raises NotImplementedError: If any entry of `vec` has a nonzero imaginary part.
    """
    # TODO: Handle complex values in vec
    if np.iscomplex(vec).any():
        raise NotImplementedError("Coming soon")

    # The leading axis holds n_packed = N(N+1)/2 entries; solve the quadratic for N.
    n_packed = vec.shape[0]
    N = int(round(np.sqrt(2 * n_packed + 0.25) - 0.5))
    ensure(
        (n_packed == 0.5 * N * (N + 1)) and N != 0,
        "Vector must be of size N*(N+1)/2 for some N>0.",
    )

    vec, sz_roll = unroll_dim(vec, 2)

    # lookup[i, j] is the row of `vec` holding matrix entry (i, j): the upper triangle
    # is numbered in row-major order and mirrored into the lower triangle.
    lookup = np.empty((N, N), dtype=int)
    upper = np.triu_indices(N)
    lookup[upper] = np.arange(n_packed)
    lookup.T[upper] = lookup[upper]

    # Gather in column-major order so m_reshape lays the entries out as N-by-N.
    mat = vec[lookup.flatten("F")]
    mat = m_reshape(mat, (N, N) + mat.shape[1:])
    return roll_dim(mat, sz_roll)
Exemple #8
0
def im_backproject(im, rot_matrices):
    """
    Backproject a stack of images along the given rotations.

    :param im: An L-by-L-by-n array of images to backproject.
    :param rot_matrices: A 3-by-3-by-n array of rotation matrices corresponding to viewing directions.
    :return: An L-by-L-by-L volume corresponding to the sum of the backprojected images.
    """
    res, _, n_im = im.shape
    ensure(res == im.shape[1], "im must be LxLxK")
    ensure(n_im == rot_matrices.shape[2],
           "No. of rotation matrices must match the number of images")

    # Rotated sampling points, flattened to a 3-by-(L*L*n) array.
    fourier_pts = m_reshape(rotated_grids(res, rot_matrices), (3, -1))

    spectra = centered_fft2(im) / (res**2)
    if res % 2 == 0:
        # For even L, zero the first row and first column of every centered spectrum.
        spectra[0, :, :] = 0
        spectra[:, 0, :] = 0
    flat_spectra = m_flatten(spectra)

    plan = Plan(sz=(res, res, res), fourier_pts=fourier_pts)
    return np.real(plan.adjoint(flat_spectra)) / res
Exemple #9
0
    def __init__(self, xfer_fn_array):
        """
        A Filter corresponding to the filter with the specified transfer function.

        :param xfer_fn_array: The transfer function of the filter in the form of an array of one or two dimensions.
            Even-length axes are extended by one sample before storage; `self.sz` keeps the
            original shape.
        """
        ndim = xfer_fn_array.ndim
        ensure(ndim in (1, 2), "Only dimensions 1 and 2 supported.")

        super().__init__(dim=ndim, radial=False)

        # Record the shape before any of the extensions below.
        self.sz = xfer_fn_array.shape

        # Ported from MATLAB; though written differently, the result matches that code.
        # TODO: This could use documentation - very unintuitive!
        arr = xfer_fn_array
        if ndim == 1:
            # Even length: repeat the first element at the end.
            if arr.shape[0] % 2 == 0:
                arr = np.concatenate((arr, arr[:1]))
        elif ndim == 2:
            # Even number of rows: append the first row, reversed, at the bottom.
            if arr.shape[0] % 2 == 0:
                arr = np.concatenate((arr, arr[:1, ::-1]), axis=0)
            # Even number of columns: append the first column of the (possibly already
            # extended) array, reversed, at the right.
            if arr.shape[1] % 2 == 0:
                arr = np.concatenate((arr, arr[::-1, :1]), axis=1)

        self.xfer_fn_array = arr
Exemple #10
0
    def __init__(self, data):
        """
        Create a volume initialized with data.

        Volumes should be N x L x L x L,
        or L x L x L which implies N=1.

        :param data: Volume data as an ndarray; a 3D array is promoted to a stack of one.

        :return: A volume instance.
        """
        # Promote a single L x L x L volume to a stack holding one volume.
        if data.ndim == 3:
            data = data[np.newaxis, :, :, :]

        ensure(
            data.ndim == 4,
            "Volume data should be ndarray with shape NxLxLxL or LxLxL.",
        )

        dims = data.shape
        ensure(
            dims[1] == dims[2] == dims[3],
            "Only cubed ndarrays are supported.",
        )

        self._data = data
        stored = self._data
        self.n_vols = stored.shape[0]
        self.dtype = stored.dtype
        self.resolution = stored.shape[1]
        self.shape = stored.shape
        self.volume_shape = stored.shape[1:]
Exemple #11
0
    def evaluate(self, omega):
        """
        Evaluate the filter at specified frequencies.

        :param omega: A vector of size n (for 1d filters), or an array of size 2-by-n, representing the spatial
            frequencies at which the filter is to be evaluated. These are normalized so that pi is equal to the Nyquist
            frequency.
        :return: The value of the filter at the specified frequencies. For radial filters,
            ``self._evaluate`` is called once per distinct radius and the results are scattered
            back to the input order.
        """
        if omega.ndim == 1:
            ensure(self.radial,
                   "Cannot evaluate a non-radial filter on 1D input array.")
        elif omega.ndim == 2 and self.dim:
            ensure(omega.shape[0] == self.dim,
                   f"Omega must be of size {self.dim} x n")

        if not self.radial:
            return self._evaluate(omega)

        # Radial filter: reduce each column to its Euclidean norm (1D input is used as-is).
        radii = omega if omega.ndim <= 1 else np.sqrt(np.sum(omega**2, axis=0))
        radii, inverse = np.unique(radii, return_inverse=True)
        # Evaluate on (r, 0) pairs only for the distinct radii.
        values = self._evaluate(np.vstack((radii, np.zeros_like(radii))))
        return np.take(values, inverse)
Exemple #12
0
    def adjoint(self, signal):
        """
        Compute the NUFFT adjoint using this plan instance.

        :param signal: Signal to be transformed. For a single transform,
            this should be a 1D array of len `num_pts`.
            For a batch, signal should have shape `(ntransforms, num_pts)`.

        :returns: Transformed signal `(sz)` or `(ntransforms, sz)`.
        """
        # Corner case: with ntransforms == 1, a 2D signal is still validated as a stack.
        stacked = self.ntransforms > 1 or (
            self.ntransforms == 1 and len(signal.shape) == 2
        )

        if stacked:
            ensure(
                len(signal.shape) == 2,  # Stack and num_pts
                f"For multiple {self.dim}D adjoints, signal should be"
                f" a {self.ntransforms} element stack of {self.num_pts}.",
            )
            ensure(
                signal.shape[0] == self.ntransforms,
                "For multiple transforms, signal stack length"
                f" should match ntransforms {self.ntransforms}.",
            )

            # finufft is expecting flat array for 1D case.
            if self.ntransforms == 1:
                signal = signal.reshape(self.num_pts)

        return self._adjoint_plan.execute(signal)
    def syncmatrix_vote(self):
        """
        Construct the synchronization matrix using voting method.

        Builds ``self.clmatrix`` first if it is None. The result is stored in
        ``self.syncmatrix`` as a (2*n_img)-by-(2*n_img) matrix of 2x2 blocks: block (i, j)
        for i < j is the vote-based rotation block, block (j, i) is its transpose, and
        the diagonal blocks are identity.
        """
        if self.clmatrix is None:
            self.build_clmatrix()

        clmatrix = self.clmatrix
        shape = clmatrix.shape
        n_theta = self.n_theta

        ensure(shape[0] == shape[1], "clmatrix must be a square matrix.")

        n_img = shape[0]
        # View an identity matrix as an n_img x n_img grid of 2x2 blocks.
        blocks = np.eye(2 * n_img, dtype=self.dtype).reshape(n_img, 2, n_img, 2)

        # Fill each off-diagonal pair of blocks from the voted rotation block.
        for i in range(n_img):
            for j in range(i + 1, n_img):
                rot_block = self._syncmatrix_ij_vote(
                    clmatrix, i, j, np.arange(n_img), n_theta
                )
                blocks[i, :, j, :] = rot_block
                blocks[j, :, i, :] = rot_block.T

        self.syncmatrix = blocks.reshape(2 * n_img, 2 * n_img)
Exemple #14
0
    def find_registration(self, rots_ref):
        """
        Register estimated orientations to reference ones.

        Finds the orthogonal transformation that best aligns the estimated rotations
        to the reference rotations. Two candidates are built: one from the estimated
        rotations as given (Q1), and one from their J-conjugates ``J @ R @ J`` with
        ``J = diag(1, 1, -1)`` (Q2).

        :param rots_ref: The reference Rotation object to which we would like to align,
            whose `matrices` attribute is an n-by-3-by-3 array.
        :return: o_mat, optimal orthogonal 3x3 matrix to align the two sets;
                flag, flag==1 then J conjugacy is required and 0 is not.
        """
        rots = self._matrices
        rots_ref = rots_ref.matrices.astype(self.dtype)
        ensure(
            rots.shape == rots_ref.shape,
            "Two sets of rotations must have same dimensions.",
        )
        K = rots.shape[0]

        # Reflection matrix, built in the rotations' dtype so that Q1 and Q2, and hence
        # the returned matrix, share a dtype whichever candidate is chosen.
        J = np.diag(np.array([1, 1, -1], dtype=rots.dtype))

        # Compute the two possible orthogonal matrices which register the
        # estimated rotations to the true ones: the stack averages of
        # R_k @ Rref_k.T and (J R_k J) @ Rref_k.T.
        rots_ref_T = np.transpose(rots_ref, (0, 2, 1))
        Q1 = np.sum(rots @ rots_ref_T, axis=0) / K
        Q2 = np.sum((J @ rots @ J) @ rots_ref_T, axis=0) / K

        # We are registering one set of rotations (the estimated ones) to
        # another set of rotations (the true ones). Thus, the transformation
        # matrix between the two sets of rotations should be orthogonal. This
        # matrix is either Q1 if we recover the non-reflected solution, or Q2,
        # if we got the reflected one. In any case, one of them should be
        # orthogonal.
        err1 = norm(Q1 @ Q1.T - np.eye(3), ord="fro")
        err2 = norm(Q2 @ Q2.T - np.eye(3), ord="fro")

        # Project the better candidate onto the orthogonal matrices through its SVD:
        # with Q = U S Vh, U @ Vh is the closest orthogonal matrix in Frobenius norm.
        # This assumes `svd` returns (U, s, Vh) as numpy/scipy do.
        if err1 < err2:
            U, _, Vh = svd(Q1)
            flag = 0
        else:
            U, _, Vh = svd(Q2)
            flag = 1

        Q_mat = U @ Vh

        return Q_mat, flag
Exemple #15
0
    def mse(self, rots_ref):
        """
        Calculate MSE between the estimated orientations to reference ones.

        The estimated rotations are first aligned to `rots_ref` with ``self.register``.
        The result is the mean, over the stack, of the squared Frobenius norm of each
        aligned rotation minus its reference.

        :param rots_ref: The reference Rotation object, whose `matrices` attribute is an
            n-by-3-by-3 array.
        :return: The MSE value between two sets of rotations.
        :raises ZeroDivisionError: If the rotation stacks are empty.
        """
        aligned_rots = self.register(rots_ref)
        rots_reg = aligned_rots.matrices
        rots_ref = rots_ref.matrices
        ensure(
            rots_reg.shape == rots_ref.shape,
            "Two sets of rotations must have same dimensions.",
        )
        K = rots_reg.shape[0]

        # Frobenius norm of every 3x3 difference at once. Accumulate in float64 so that
        # float32 inputs do not lower the precision of the sum.
        diff = norm(rots_reg - rots_ref, ord="fro", axis=(1, 2)).astype(np.float64)
        return float(np.sum(diff**2)) / K
Exemple #16
0
 def __init__(self, basis):
     """
     Constructor of an object for 2D covariance analysis.

     :param basis: The basis to work in. Its `dtype` is adopted by this object, and
         its `ndim` must be 2.
     """
     self.basis = basis
     self.dtype = self.basis.dtype
     ensure(basis.ndim == 2, "Only two-dimensional basis functions are needed.")
Exemple #17
0
    def process_micrograph(self,
                           filepath,
                           return_centers=True,
                           return_img=False,
                           show_progress=True,
                           create_jpg=False):
        """
        Pick particles in one micrograph file.

        :param filepath: Micrograph file passed to ``Picker``.
        :param return_centers: Return the extracted particle centers.
        :param return_img: Return the particle image (only when `return_centers` is false).
        :param show_progress: Forwarded to ``Picker.query_score``.
        :param create_jpg: Also save the particle image as ``<name>_result.jpg`` in
            ``self.output_dir`` when that directory is set.
        :return: The centers, the particle image, or None if neither is requested.
        """
        ensure(not (return_centers and return_img),
               "Cannot specify both return_centers and return_img")

        picker = Picker(self.particle_size, self.max_particle_size,
                        self.min_particle_size, self.query_image_size,
                        self.tau1, self.tau2, self.minimum_overlap_amount,
                        self.container_size, filepath, self.output_dir)

        logger.info('Computing scores for query images')
        # Score every query window using normalized cross-correlations.
        score = picker.query_score(show_progress=show_progress)

        # Retrain the SVM, raising tau2 (all windows positive) or tau1 (none positive),
        # until the classification is mixed.
        while True:
            logger.info(
                f'Running svm with tau1={picker.tau1}, tau2={picker.tau2}')
            segmentation = picker.run_svm(score)
            if np.all(segmentation):
                picker.tau2 += 500
                continue
            if not np.any(segmentation):
                picker.tau1 += 500
                continue
            break

        logger.info('Discarding suspected artifacts')
        segmentation = picker.morphology_ops(segmentation)

        logger.info('Getting particle centers')
        centers = picker.extract_particles(segmentation)

        particle_image = None
        if create_jpg and self.output_dir is not None:
            particle_image = self.particle_image(picker.original_im,
                                                 picker.particle_size, centers)
            stem = os.path.splitext(os.path.basename(picker.filename))[0]
            # NOTE(review): scipy.misc.imsave was removed from SciPy (1.2+); confirm which
            # `misc` this resolves to in the deployed environment.
            misc.imsave(os.path.join(self.output_dir, stem + '_result.jpg'),
                        particle_image)

        if return_centers:
            return centers
        if return_img:
            if particle_image is None:
                particle_image = self.particle_image(picker.original_im,
                                                     picker.particle_size, centers)
            return particle_image
        return None
Exemple #18
0
 def _init_margins(self, margin):
     """
     Store the top/right/bottom/left margins on the instance.

     :param margin: None (all four margins None), a 4-element tuple/list unpacked as
         (top, right, bottom, left), or a scalar converted with `int()` and used for all
         four sides.
     """
     if margin is None:
         top = right = bottom = left = None
     elif isinstance(margin, (tuple, list)):
         ensure(len(margin) == 4, 'If specifying margins as a tuple/list, specify the top/right/bottom/left margins.')
         top, right, bottom, left = margin
     else:  # assume scalar
         top = right = bottom = left = int(margin)
     self.margin_top, self.margin_right, self.margin_bottom, self.margin_left = top, right, bottom, left
Exemple #19
0
def downsample(insamples, szout, mask=None):
    """
    Blur and downsample 1D to 3D objects such as curves, images or volumes.

    The function handles odd and even-sized arrays correctly. The center of
    an odd array is taken to be at (n+1)/2, and an even array is n/2+1.
    :param insamples: Set of objects to be downsampled in the form of an array, the last dimension
                    is the number of objects.
    :param szout: The desired resolution of the output objects.
    :param mask: Optional factor multiplied into each cropped, centered spectrum before
        the inverse transform. Defaults to 1.0 (no masking).
    :return: An array consisting of the blurred and downsampled objects, with the dtype of `insamples`.
    """
    ensure(
        insamples.ndim - 1 == np.size(szout),
        'The number of downsampling dimensions is not the same as that of objects.'
    )

    L_in = insamples.shape[0]
    L_out = szout[0]
    ndata = insamples.shape[-1]
    outsamples = np.zeros(np.r_[szout, ndata], dtype=insamples.dtype)

    if mask is None:
        mask = 1.0

    # Forward/inverse transform pair for each supported object dimensionality.
    transforms = {1: (fft, ifft), 2: (fft2, ifft2), 3: (fftn, ifftn)}
    obj_ndim = insamples.ndim - 1
    if obj_ndim not in transforms:
        raise RuntimeError('Number of dimensions > 3 for input objects.')
    forward, inverse = transforms[obj_ndim]
    # Normalization (L_out / L_in)^d applied after the inverse transform.
    scale = L_out**obj_ndim / L_in**obj_ndim

    for idata in range(ndata):
        spectrum = crop_pad(fftshift(forward(insamples[..., idata])),
                            L_out) * mask
        outsamples[..., idata] = np.real(
            inverse(ifftshift(spectrum)) * scale)

    return outsamples
Exemple #20
0
def vol_to_vec(X):
    """
    Roll up volumes into vectors.

    :param X: N-by-N-by-N-by-... array.
    :return: An N^3-by-... array whose first three axes are merged with ``m_reshape``.
    """
    dims = X.shape
    ensure(X.ndim >= 3, "Array should have at least 3 dimensions")
    ensure(dims[0] == dims[1] == dims[2], "Array should have first 3 dimensions identical")

    side = dims[0]
    return m_reshape(X, (side**3, *dims[3:]))
Exemple #21
0
    def transform(self, signal):
        """
        Apply the forward transform of the underlying plan to `signal`.

        :param signal: Array whose shape must equal ``self.sz``.
        :return: The transformed signal. It is cast to complex64 when `signal` is float32;
            otherwise it is returned as produced by the plan.
        """
        ensure(signal.shape == self.sz,
               f'Signal to be transformed must have shape {self.sz}')

        # Give the plan a double-precision complex array. Casting to complex64 here would
        # silently round float64/complex128 input to single precision.
        self._plan.f_hat = signal.astype(np.complex128)
        f = self._plan.trafo()

        # Single-precision input gets single-precision output.
        if signal.dtype == np.float32:
            f = f.astype('complex64')

        return f
Exemple #22
0
def vec_to_vol(X):
    """
    Unroll vectors to volumes.

    :param X: N^3-by-... array.
    :return: An N-by-N-by-N-by-... array.
    """
    dims = X.shape
    # Recover N from the leading length; the ensure rejects non-cubes.
    side = round(dims[0] ** (1 / 3))
    ensure(side ** 3 == dims[0], "First dimension of X must be cubic")

    return m_reshape(X, (side, side, side, *dims[1:]))
Exemple #23
0
def vec_to_im(X):
    """
    Unroll vectors to images.

    :param X: N^2-by-... array.
    :return: An N-by-N-by-... array.
    """
    dims = X.shape
    # Recover N from the leading length; the ensure rejects non-squares.
    side = round(dims[0] ** (1 / 2))
    ensure(side ** 2 == dims[0], "First dimension of X must be square")

    return m_reshape(X, (side, side, *dims[1:]))
Exemple #24
0
def im_to_vec(im):
    """
    Roll up images into vectors.

    :param im: An N-by-N-by-... array.
    :return: An N^2-by-... array whose first two axes are merged with ``m_reshape``.
    """
    dims = im.shape
    ensure(im.ndim >= 2, "Array should have at least 2 dimensions")
    ensure(dims[0] == dims[1], "Array should have first 2 dimensions identical")

    side = dims[0]
    return m_reshape(im, (side ** 2, *dims[2:]))
Exemple #25
0
    def __init__(self, data):
        """
        Wrap a stack of square images laid out as res-by-res-by-n.

        :param data: ndarray whose first two axes must be equal. A 2D array is promoted
            to a stack of one by appending a trailing axis.
        """
        ensure(data.shape[0] == data.shape[1],
               'Only square ndarrays are supported.')

        # A single image gets a trailing "number of images" axis of length 1.
        if data.ndim == 2:
            data = data[..., np.newaxis]

        self.data = data
        self.dtype = self.data.dtype
        self.shape = self.data.shape
        # The image index is the last axis; the resolution is the first.
        self.n_images = self.shape[-1]
        self.res = self.shape[0]
Exemple #26
0
    def eval_clustering(self, vol_idx):
        """
        Evaluate clustering estimation.

        :param vol_idx: Indexes of the volumes determined (0-indexed), one per image.
        :return: Accuracy [0-1] in terms of proportion of correctly assigned labels
        """
        ensure(
            len(vol_idx) == self.n,
            f'Need {self.n} vol indexes to evaluate clustering')

        # `states` is 1-indexed and `vol_idx` 0-indexed, so shift states down before comparing.
        hits = (self.states - 1) == vol_idx
        return np.sum(hits) / self.n
Exemple #27
0
    def set_max_resolution(self, max_L):
        """
        Lower the working resolution to `max_L`.

        Sets ``self.L`` to `max_L`, scales the filters and divides the offsets by
        ``self._L / max_L``, and clears the cached images (``self._im``).

        :param max_L: New resolution; must not exceed the current ``self.L``.
        """
        ensure(
            max_L <= self.L,
            "Max desired resolution should be less than the current resolution"
        )
        self.L = max_L

        # NOTE(review): the factor is relative to `self._L`, not to the resolution in
        # effect before this call. If `_L` holds the original resolution, a second call
        # would rescale filters/offsets by the full cumulative factor again. Confirm.
        factor = self._L / max_L
        self.filters.scale(factor)
        self.offsets /= factor

        # Invalidate images
        self._im = None
Exemple #28
0
def acorr(x, y, axes=None):
    """
    Calculate array correlation along given axes.

    :param x: An array of arbitrary shape
    :param y: An array of same shape as x
    :param axes: The axes along which to compute the correlation. If None, all axes are used.
    :return: ``ainner(x, y, axes) / (anorm(x, axes) * anorm(y, axes))``.
    """
    ensure(x.shape == y.shape, "The shapes of the inputs have to match")

    axes = range(x.ndim) if axes is None else axes
    numerator = ainner(x, y, axes)
    return numerator / (anorm(x, axes) * anorm(y, axes))
Exemple #29
0
def ainner(x, y, axes=None):
    """
    Calculate array inner product along given axes.

    :param x: An array of arbitrary shape
    :param y: An array of same shape as x
    :param axes: The axes to contract. If None, all axes are used.
    :return: ``np.tensordot(x, y, axes=(axes, axes))``. The listed axes of x are
        contracted against the same axes of y, with no complex conjugation. Any remaining
        axes of x, then of y, form the result shape; the result is a scalar when all axes
        are contracted.
    """
    ensure(x.shape == y.shape, "The shapes of the inputs have to match")

    if axes is None:
        axes = range(x.ndim)
    paired_axes = (axes, axes)
    return np.tensordot(x, y, axes=paired_axes)
Exemple #30
0
def shrink_covar(covar_in, noise_var, gamma, shrinker=None):
    """
    Shrink the covariance matrix.

    The covariance is divided by `noise_var`, symmetrized with ``make_symmat`` and
    eigendecomposed. Eigenvalues not strictly above ``(1 + sqrt(gamma))**2`` are set to
    zero; the rest are shrunk by the selected rule. The matrix is then rebuilt and
    multiplied back by `noise_var`.

    :param covar_in: An input covariance matrix
    :param noise_var: The estimated variance of noise
    :param gamma: An input parameter to specify the maximum values of eigen values to be neglected.
    :param shrinker: Shrinking method: "frobenius_norm" (default), "operator_norm" or "soft_threshold".
    :return: The shrunk covariance matrix.
    """
    if shrinker is None:
        shrinker = "frobenius_norm"
    ensure(
        shrinker in ("frobenius_norm", "operator_norm", "soft_threshold"),
        "Unsupported shrink method",
    )

    covar = covar_in / noise_var

    # make_symmat is assumed to return a symmetric (Hermitian) matrix (confirm). eigh then
    # gives real eigenvalues and orthonormal eigenvectors, which the reconstruction
    # V @ diag(lambs) @ V^H below relies on; general `eig` guarantees neither.
    lambs, eig_vec = np.linalg.eigh(make_symmat(covar))

    lambda_max = (1 + np.sqrt(gamma))**2

    # Keep only eigenvalues strictly above the threshold and zero everything else,
    # including values exactly at the threshold.
    keep = lambs > lambda_max
    lambs[~keep] = 0
    kept = lambs[keep]

    if shrinker == "soft_threshold":
        kept = kept - lambda_max
    else:
        # Shared by operator_norm and frobenius_norm:
        # ell = ((l - gamma + 1) + sqrt((l - gamma + 1)^2 - 4 l)) / 2 - 1.
        shifted = kept - gamma + 1
        kept = 1 / 2 * (shifted + np.sqrt(shifted**2 - 4 * kept)) - 1
        if shrinker == "frobenius_norm":
            kept = kept * ((1 - gamma / kept**2) / (1 + gamma / kept))
    lambs[keep] = kept

    # V @ diag(lambs) @ V^H, scaling the columns of V instead of forming a diagonal matrix.
    shrinked_covar = (eig_vec * lambs) @ eig_vec.conj().T
    shrinked_covar *= noise_var

    return shrinked_covar