Example #1
def cryo_estimate_mean(im, params, basis=None, mean_est_opt=None):
    """
    1e-5 error from matlab
    :param im:
    :param params:
    :param basis:
    :param mean_est_opt:
    :return:
    """
    resolution = im.shape[1]

    if basis is None:
        basis = DiracBasis((resolution, resolution, resolution))

    mean_est_opt = fill_struct(mean_est_opt, {'precision': 'float64', 'preconditioner': 'circulant'})

    kernel_f = cryo_mean_kernel_f(resolution, params, mean_est_opt)

    precond_kernel_f = None
    if mean_est_opt.preconditioner == 'circulant':
        precond_kernel_f = 1 / circularize_kernel_f(kernel_f)
    elif mean_est_opt.preconditioner != 'none':
        raise ValueError('Invalid preconditioner type')

    def identity(x):
        return x

    mean_est_opt.preconditioner = identity
    im_bp = cryo_mean_backproject(im, params, mean_est_opt)

    mean_est, cg_info = cryo_conj_grad_mean(kernel_f, im_bp, basis, precond_kernel_f, mean_est_opt)
    return mean_est, cg_info
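The 'circulant' branch above builds a preconditioner by inverting a circulant approximation of the kernel in Fourier space. The following self-contained numpy sketch illustrates that idea on a 1D circular convolution; the kernel here is a hypothetical stand-in, not the cryo-EM kernel, and circularize_kernel_f is not reproduced.

import numpy as np

# Hypothetical 1D stand-in for the kernel: circular convolution with a strictly
# positive Fourier spectrum, which the FFT diagonalizes exactly.
n = 64
spectrum = 1.0 + np.abs(np.fft.fft(np.exp(-np.linspace(-3, 3, n) ** 2)))

def apply_kernel(x):
    # A x: multiply by the spectrum in Fourier space (a circulant operator)
    return np.real(np.fft.ifft(spectrum * np.fft.fft(x)))

def circulant_precond(x):
    # M^{-1} x: divide by the spectrum, i.e. invert the circulant operator exactly
    return np.real(np.fft.ifft(np.fft.fft(x) / spectrum))

b = np.random.default_rng(0).standard_normal(n)
print(np.allclose(circulant_precond(apply_kernel(b)), b))  # True: M^{-1} A = I here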
Example #2
def subset_params(params, ind):
    # Build a new params struct restricted to the images indexed by ind,
    # rather than modifying the input struct in place.
    batch_params = fill_struct()

    batch_params.rot_matrices = params.rot_matrices[:, :, ind]
    batch_params.ctf = params.ctf  # the CTF filters themselves are shared across images
    batch_params.ctf_idx = params.ctf_idx[:, ind]
    batch_params.ampl = params.ampl[:, ind]
    batch_params.shifts = params.shifts[:, ind]

    return batch_params
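A minimal, self-contained illustration of the same indexing pattern, with types.SimpleNamespace standing in for the fill_struct container and array shapes that follow the conventions used above (these shapes are assumptions for the sketch):

import numpy as np
from types import SimpleNamespace

n = 10
params = SimpleNamespace(
    rot_matrices=np.tile(np.eye(3)[:, :, None], (1, 1, n)),  # 3 x 3 x n
    ctf_idx=np.zeros((1, n), dtype=int),                      # 1 x n
    ampl=np.ones((1, n)),                                     # 1 x n
    shifts=np.zeros((2, n)),                                  # 2 x n
)

ind = np.arange(4)
batch = SimpleNamespace(
    rot_matrices=params.rot_matrices[:, :, ind],
    ctf_idx=params.ctf_idx[:, ind],
    ampl=params.ampl[:, ind],
    shifts=params.shifts[:, ind],
)
print(batch.rot_matrices.shape, batch.ampl.shape)  # (3, 3, 4) (1, 4)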
Example #3
def cryo_mean_kernel_f(resolution, params, mean_est_opt=None):
    """
    8e-14 error from matlab
    :param resolution:
    :param params:
    :param mean_est_opt:
    :return:
    """
    mean_est_opt = fill_struct(mean_est_opt, {'precision': 'float64', 'half_pixel': False, 'batch_size': 0})
    n = params.rot_matrices.shape[2]

    # Process the images in batches and accumulate a weighted average of the batch kernels
    if mean_est_opt.batch_size != 0:
        batch_size = mean_est_opt.batch_size
        mean_est_opt.batch_size = 0

        batch_ct = int(np.ceil(n / batch_size))
        mean_kernel_f = np.zeros([2 * resolution] * 3, dtype=mean_est_opt.precision)

        for batch in range(batch_ct):
            start = batch_size * batch
            end = min(batch_size * (batch + 1), n)

            batch_params = subset_params(params, np.arange(start, end))
            batch_kernel_f = cryo_mean_kernel_f(resolution, batch_params, mean_est_opt)
            mean_kernel_f += (end - start) / n * batch_kernel_f

        return mean_kernel_f

    pts_rot = rotated_grids(resolution, params.rot_matrices, mean_est_opt.half_pixel)
    filt = np.einsum('ij, k -> ijk', np.square(params.ctf), np.square(params.ampl), dtype=mean_est_opt.precision)

    if resolution % 2 == 0 and not mean_est_opt.half_pixel:
        # For even resolution without the half-pixel shift, drop the unmatched -L/2 frequency
        pts_rot = pts_rot[:, 1:, 1:]
        filt = filt[1:, 1:]

    # Reshape inputs into appropriate sizes and apply adjoint NUFFT
    pts_rot = pts_rot.reshape((3, -1), order='F')
    filt = filt.flatten('F')
    mean_kernel = anufft3(filt, pts_rot, [2 * resolution] * 3)
    mean_kernel /= n * resolution ** 2

    # Ensure symmetric kernel
    mean_kernel[0] = 0
    mean_kernel[:, 0] = 0
    mean_kernel[:, :, 0] = 0

    mean_kernel = mean_kernel.copy()
    # Take the Fourier transform since this is what we want to use when convolving
    mean_kernel = np.fft.ifftshift(mean_kernel)
    mean_kernel = np.fft.fftn(mean_kernel)
    mean_kernel = np.fft.fftshift(mean_kernel)
    mean_kernel = np.real(mean_kernel)
    return mean_kernel
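The batched path works because the kernel is an average of per-image contributions, so weighting each batch by (end - start) / n recombines the batch results into the full average. A small numpy check of that accounting, with a toy per-image contribution standing in for the NUFFT-based kernel:

import numpy as np

def toy_kernel(contribs):
    # Stand-in for the per-batch kernel: the average of the per-image contributions
    return contribs.mean(axis=0)

rng = np.random.default_rng(0)
n, batch_size = 10, 4
contribs = rng.standard_normal((n, 8))

full = toy_kernel(contribs)
acc = np.zeros(8)
for batch in range(int(np.ceil(n / batch_size))):
    start = batch_size * batch
    end = min(batch_size * (batch + 1), n)
    acc += (end - start) / n * toy_kernel(contribs[start:end])

print(np.allclose(acc, full))  # True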
Example #4
def mesh_2d(resolution, inclusive=False):
    # Generate a resolution-by-resolution grid of (x, y) coordinates in [-1, 1],
    # together with the corresponding polar coordinates (phi, r).
    if inclusive:
        cons = (resolution - 1) / 2
        grid = np.arange(-cons, cons + 1) / cons
    else:
        cons = resolution / 2
        grid = np.ceil(np.arange(-cons, cons)) / cons

    mesh = fill_struct()
    mesh.y, mesh.x = np.meshgrid(grid, grid)  # reversed from matlab
    mesh.phi, mesh.r, _ = cart2pol(mesh.x, mesh.y)
    return mesh
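A self-contained look at the two grid conventions, using plain numpy in place of fill_struct and cart2pol:

import numpy as np

def grid_1d(resolution, inclusive=False):
    if inclusive:
        cons = (resolution - 1) / 2
        return np.arange(-cons, cons + 1) / cons   # endpoints at exactly -1 and 1
    cons = resolution / 2
    return np.ceil(np.arange(-cons, cons)) / cons  # starts at -1, stops short of 1

print(grid_1d(4, inclusive=True))   # [-1. -0.3333  0.3333  1.]
print(grid_1d(4, inclusive=False))  # [-1. -0.5  0.  0.5]

# The 2D mesh is the outer product of this grid with itself; phi and r are just
# np.arctan2 and np.hypot applied to the meshgrid.
y, x = np.meshgrid(grid_1d(4), grid_1d(4))
print(np.hypot(x, y)[0, 0])  # 1.414..., the corner of the grid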
Example #5
def cryo_conj_grad_mean(kernel_f, im_bp, basis, precond_kernel_f=None, mean_est_opt=None):
    # Solve for the mean volume by conjugate gradient in the given basis, using the mean
    # kernel (and, optionally, a preconditioner kernel) in Fourier space.
    mean_est_opt = fill_struct(mean_est_opt)
    resolution = im_bp.shape[0]
    if len(im_bp.shape) != 3 or im_bp.shape[1] != resolution or im_bp.shape[2] != resolution:
        raise ValueError('im_bp must be an array of size LxLxL')

    def fun(vol_basis):
        return apply_mean_kernel(vol_basis, kernel_f, basis)

    if precond_kernel_f is not None:
        def precond_fun(vol_basis):
            return apply_mean_kernel(vol_basis, precond_kernel_f, basis)

        mean_est_opt.preconditioner = precond_fun

    im_bp_basis = basis.evaluate_t(im_bp)
    mean_est_basis, _, cg_info = conj_grad(fun, im_bp_basis, mean_est_opt)
    mean_est = basis.evaluate(mean_est_basis)
    return mean_est, cg_info
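The evaluate_t / evaluate calls map the problem into basis coefficients and back. With a Dirac (voxel) basis the two maps are essentially flatten and reshape, which the toy stand-in below uses; the real DiracBasis class is not reproduced here:

import numpy as np

class ToyDiracBasis:
    # Voxel-basis stand-in: coefficients are just the flattened volume
    def __init__(self, shape):
        self.shape = shape

    def evaluate_t(self, vol):   # volume -> coefficient vector
        return vol.reshape(-1)

    def evaluate(self, coef):    # coefficient vector -> volume
        return coef.reshape(self.shape)

basis = ToyDiracBasis((4, 4, 4))
vol = np.random.default_rng(0).standard_normal((4, 4, 4))
print(np.allclose(basis.evaluate(basis.evaluate_t(vol)), vol))  # True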
Example #6
def cryo_mean_backproject(im, params, mean_est_opt=None):
    """
    1e-7 error from matlab
    :param im:
    :param params:
    :param mean_est_opt:
    :return:
    """
    mean_est_opt = fill_struct(mean_est_opt, {'precision': 'float64', 'half_pixel': False, 'batch_size': 0})
    if im.shape[0] != im.shape[1] or im.shape[0] == 1 or len(im.shape) != 3:
        raise ValueError('im must be 3 dimensional LxLxn where L > 1')

    resolution = im.shape[1]
    n = im.shape[2]

    # Process the images in batches and accumulate a weighted average of the batch backprojections
    if mean_est_opt.batch_size != 0:
        batch_size = mean_est_opt.batch_size
        mean_est_opt.batch_size = 0

        batch_ct = int(np.ceil(n / batch_size))
        im_bp = np.zeros([resolution] * 3, dtype=mean_est_opt.precision)

        for batch in range(batch_ct):
            start = batch_size * batch
            end = min(batch_size * (batch + 1), n)

            batch_params = subset_params(params, np.arange(start, end))
            batch_im = im[:, :, start:end]
            batch_im_bp = cryo_mean_backproject(batch_im, batch_params, mean_est_opt)
            im_bp += (end - start) / n * batch_im_bp

        return im_bp

    if mean_est_opt.precision == 'float32' or mean_est_opt.precision == 'single':
        im = im.astype('float32')

    filter_f = np.einsum('ij, k -> ijk', params.ctf, np.ones(np.count_nonzero(params.ctf_idx)))
    im = im * params.ampl
    im = im_translate(im, -params.shifts)
    im = im_filter(im, filter_f)
    im = im_backproject(im, params.rot_matrices, mean_est_opt.half_pixel)
    im /= n
    return im
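The einsum pattern 'ij, k -> ijk' used above simply broadcasts a 2D filter against a length-k weight vector to produce a 3D stack, one weighted copy of the filter per image. A quick check of that behavior:

import numpy as np

ctf = np.arange(6.0).reshape(2, 3)   # 2 x 3 stand-in for a CTF image
weights = np.array([1.0, 0.5])       # per-image weights

stack = np.einsum('ij, k -> ijk', ctf, weights)
print(stack.shape)                             # (2, 3, 2)
print(np.allclose(stack[:, :, 1], 0.5 * ctf))  # True: slice k is weights[k] * ctf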
Example #7
def conj_grad(a_fun, b, cg_opt=None, init=None):
    # Solve a_fun(x) = b for a symmetric positive definite linear operator a_fun
    # using preconditioned conjugate gradient.

    def identity(input_x):
        return input_x

    cg_opt = fill_struct(cg_opt, {'max_iter': 50, 'verbose': 0, 'iter_callback': [], 'preconditioner': identity,
                                  'rel_tolerance': 1e-15, 'store_iterates': False})
    init = fill_struct(init, {'x': None, 'p': None})
    if init.x is None:
        x = np.zeros(b.shape)
    else:
        x = init.x

    b_norm = np.linalg.norm(b)
    r = b.copy()
    s = cg_opt.preconditioner(r)

    if np.any(x != 0):
        if cg_opt.verbose:
            print('[CG] Calculating initial residual')
        a_x = a_fun(x)
        r = r-a_x
        s = cg_opt.preconditioner(r)
    else:
        a_x = np.zeros(x.shape)

    obj = np.real(np.sum(x.conj() * a_x, 0) - 2 * np.real(np.sum(np.conj(b) * x, 0)))

    if init.p is None:
        # Copy so that the in-place updates to p below do not also modify s and r
        p = s.copy()
    else:
        p = init.p

    info = fill_struct(att_vals={'iter': [0], 'res': [np.linalg.norm(r)], 'obj': [obj]})
    if cg_opt.store_iterates:
        info = fill_struct(info, att_vals={'x': [x.copy()], 'r': [r.copy()], 'p': [p.copy()]})

    if cg_opt.verbose:
        print('[CG] Initialized. Residual: {}. Objective: {}'.format(np.linalg.norm(info.res[0]), np.sum(info.obj[0])))

    if b_norm == 0:
        if cg_opt.verbose:
            print('[CG] Right-hand side is zero, returning zero solution')
        return x, obj, info

    for i in range(1, cg_opt.max_iter):
        if cg_opt.verbose:
            print('[CG] Applying matrix & preconditioner')

        a_p = a_fun(p)
        old_gamma = np.real(np.sum(s.conj() * r))

        alpha = old_gamma / np.real(np.sum(p.conj() * a_p))
        x += alpha * p
        a_x += alpha * a_p

        r -= alpha * a_p
        s = cg_opt.preconditioner(r)
        new_gamma = np.real(np.sum(r.conj() * s))
        beta = new_gamma / old_gamma
        p *= beta
        p += s

        obj = np.real(np.sum(x.conj() * a_x, 0) - 2 * np.real(np.sum(np.conj(b) * x, 0)))
        res = np.linalg.norm(r)
        info.iter.append(i)
        info.res.append(res)
        info.obj.append(obj)
        if cg_opt.store_iterates:
            info.x.append(x.copy())
            info.r.append(r.copy())
            info.p.append(p.copy())

        if cg_opt.verbose:
            print('[CG] Iteration {}. Residual: {}. Objective: {}'.format(i, res, np.sum(obj)))

        if np.all(res < b_norm * cg_opt.rel_tolerance):
            break

    # if i == cg_opt.max_iter - 1:
    #     raise Warning('Conjugate gradient reached maximum number of iterations!')
    return x, obj, info
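As a sanity check on the recurrence above, here is a compact, self-contained preconditioned CG on a small symmetric positive definite system (plain numpy, Jacobi preconditioner). It follows the same alpha/beta updates as conj_grad but is only an illustration, not a drop-in replacement:

import numpy as np

rng = np.random.default_rng(0)
m = 20
q = rng.standard_normal((m, m))
a = q @ q.T + m * np.eye(m)      # SPD test matrix
b = rng.standard_normal(m)
m_inv = 1.0 / np.diag(a)         # Jacobi (diagonal) preconditioner

x = np.zeros(m)
r = b.copy()
s = m_inv * r
p = s.copy()
gamma = r @ s
for _ in range(100):
    a_p = a @ p
    alpha = gamma / (p @ a_p)
    x += alpha * p
    r -= alpha * a_p
    if np.linalg.norm(r) < 1e-12 * np.linalg.norm(b):
        break
    s = m_inv * r
    gamma_new = r @ s
    p = s + (gamma_new / gamma) * p
    gamma = gamma_new

print(np.allclose(a @ x, b))     # True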