def cryo_estimate_mean(im, params, basis=None, mean_est_opt=None):
    """
    Estimate the mean volume from a stack of images.

    1e-5 error from matlab

    :param im: Image stack; ``im.shape[1]`` is taken as the resolution L
        (presumably LxLxn, which cryo_mean_backproject enforces).
    :param params: Imaging parameters, passed through to the kernel and
        backprojection routines.
    :param basis: Basis in which the volume is represented; defaults to
        ``DiracBasis((L, L, L))``.
    :param mean_est_opt: Options struct. ``precision`` defaults to 'float64'
        and ``preconditioner`` to 'circulant'; the accepted preconditioner
        values are 'circulant' and 'none'. The caller's struct is not modified.
    :return: Tuple ``(mean_est, cg_info)`` from ``cryo_conj_grad_mean``.
    :raises ValueError: If ``preconditioner`` is neither 'circulant' nor 'none'.
    """
    import copy

    resolution = im.shape[1]
    if basis is None:
        basis = DiracBasis((resolution, resolution, resolution))

    # Work on a copy: ``preconditioner`` is overwritten with a function below,
    # and doing that on the caller's struct would make a second call with the
    # same options fail the string check with ValueError.
    mean_est_opt = fill_struct(copy.copy(mean_est_opt), {'precision': 'float64', 'preconditioner': 'circulant'})

    kernel_f = cryo_mean_kernel_f(resolution, params, mean_est_opt)

    # None signals that no preconditioner kernel should be applied.
    precond_kernel_f = None
    if mean_est_opt.preconditioner == 'circulant':
        precond_kernel_f = 1 / circularize_kernel_f(kernel_f)
    elif mean_est_opt.preconditioner != 'none':
        raise ValueError('Invalid preconditioner type')

    def identity(x):
        return x

    # The string option has been consumed; conj_grad calls the preconditioner
    # as a function, so install a no-op default.
    mean_est_opt.preconditioner = identity

    im_bp = cryo_mean_backproject(im, params, mean_est_opt)
    mean_est, cg_info = cryo_conj_grad_mean(kernel_f, im_bp, basis, precond_kernel_f, mean_est_opt)
    return mean_est, cg_info
def subset_params(params, ind):
    """
    Return a copy of ``params`` restricted to the images selected by ``ind``.

    The per-image fields are sliced; every other attribute (e.g. ``ctf``) is
    carried over unchanged. ``params`` itself is not modified.

    :param params: Imaging parameters with ``rot_matrices`` (images along
        axis 2), ``ctf_idx`` and ``ampl`` (images along the last axis) and
        ``shifts`` (images along axis 1).
    :param ind: Index array (or slice) selecting the images to keep.
    :return: A new parameter struct holding only the selected images.
    """
    import copy

    # A shallow copy keeps all other attributes (such as ``ctf``); the sliced
    # fields are then rebound on the copy only, leaving ``params`` intact.
    batch_params = copy.copy(params)
    batch_params.rot_matrices = params.rot_matrices[:, :, ind]
    # ``...`` indexes the last axis, so both 1-D (n,) and row-vector (1, n)
    # layouts work.
    batch_params.ctf_idx = params.ctf_idx[..., ind]
    batch_params.ampl = params.ampl[..., ind]
    batch_params.shifts = params.shifts[:, ind]
    return batch_params
def cryo_mean_kernel_f(resolution, params, mean_est_opt=None):
    """
    Compute the Fourier transform of the mean-estimation convolution kernel.

    8e-14 error from matlab

    :param resolution: Side length L of the images.
    :param params: Imaging parameters; ``rot_matrices`` (images along axis 2),
        ``ctf`` and ``ampl`` are read.
    :param mean_est_opt: Options struct; ``precision`` (default 'float64'),
        ``half_pixel`` (default False) and ``batch_size`` (default 0, meaning
        no batching) are used. The caller's struct is not modified.
    :return: Real array with the centred FFT of the kernel computed by
        ``anufft3`` on a ``2L`` grid per axis.
    """
    import copy

    # Private copy: batching below resets ``batch_size`` and must not leak
    # into the caller's options.
    mean_est_opt = fill_struct(copy.copy(mean_est_opt), {'precision': 'float64', 'half_pixel': False, 'batch_size': 0})
    n = params.rot_matrices.shape[2]

    if mean_est_opt.batch_size != 0:
        batch_size = int(mean_est_opt.batch_size)
        # Recursive calls handle a single batch each, without further batching.
        mean_est_opt.batch_size = 0
        # range() needs an int; np.ceil returns a float.
        batch_ct = int(np.ceil(n / batch_size))
        mean_kernel_f = np.zeros([2 * resolution] * 3, dtype=mean_est_opt.precision)
        for batch in range(batch_ct):
            start = batch_size * batch
            end = min(batch_size * (batch + 1), n)
            batch_params = subset_params(params, np.arange(start, end))
            batch_kernel_f = cryo_mean_kernel_f(resolution, batch_params, mean_est_opt)
            # Each batch kernel is normalised by its own image count; reweight
            # so the sum is normalised by n.
            mean_kernel_f += (end - start) / n * batch_kernel_f
        return mean_kernel_f

    pts_rot = rotated_grids(resolution, params.rot_matrices, mean_est_opt.half_pixel)
    # filt[:, :, k] = ctf**2 * ampl[k]**2
    filt = np.einsum('ij, k -> ijk', np.square(params.ctf), np.square(params.ampl), dtype=mean_est_opt.precision)

    if resolution % 2 == 0 and not mean_est_opt.half_pixel:
        # Drop the first sample along both in-plane axes.
        # NOTE(review): presumably removes the unpaired -L/2 frequency of an
        # even grid; the original author questioned whether it is necessary.
        pts_rot = pts_rot[:, 1:, 1:]
        filt = filt[1:, 1:]

    # Reshape inputs into appropriate sizes (Fortran order) and apply adjoint NUFFT
    pts_rot = pts_rot.reshape((3, -1), order='F')
    filt = filt.flatten('F')
    mean_kernel = anufft3(filt, pts_rot, [2 * resolution] * 3)
    mean_kernel /= n * resolution ** 2

    # Ensure symmetric kernel: zero the first slice along every axis.
    mean_kernel[0] = 0
    mean_kernel[:, 0] = 0
    mean_kernel[:, :, 0] = 0

    # Take the (centred) Fourier transform, since this is what convolution uses.
    mean_kernel = np.fft.fftshift(np.fft.fftn(np.fft.ifftshift(mean_kernel)))
    return np.real(mean_kernel)
def mesh_2d(resolution, inclusive=False):
    """
    Build a 2-D sampling mesh with ``resolution`` points along each axis.

    :param resolution: Number of grid points per axis.
    :param inclusive: When True, the 1-D grid is
        ``arange(-h, h + 1) / h`` with ``h = (resolution - 1) / 2``, spanning
        [-1, 1]. When False, it is ``ceil(arange(-h, h)) / h`` with
        ``h = resolution / 2``.
    :return: Struct with ``x``, ``y`` (from ``np.meshgrid``) and ``phi``,
        ``r`` (first two outputs of ``cart2pol(x, y)``).
    """
    half_width = (resolution - 1) / 2 if inclusive else resolution / 2
    raw_points = np.arange(-half_width, half_width + 1) if inclusive else np.ceil(np.arange(-half_width, half_width))
    axis = raw_points / half_width

    mesh = fill_struct()
    # np.meshgrid's first output varies along columns; it is stored as ``y``
    # (reversed from matlab).
    first, second = np.meshgrid(axis, axis)
    mesh.y = first
    mesh.x = second

    polar = cart2pol(mesh.x, mesh.y)
    mesh.phi, mesh.r, _ = polar
    return mesh
def cryo_conj_grad_mean(kernel_f, im_bp, basis, precond_kernel_f=None, mean_est_opt=None):
    """
    Solve for the mean volume with conjugate gradient in ``basis``.

    The operator is ``apply_mean_kernel`` with ``kernel_f``. The right-hand
    side is ``basis.evaluate_t(im_bp)``, and the solution is mapped back with
    ``basis.evaluate``.

    :param kernel_f: Fourier-domain mean kernel.
    :param im_bp: Backprojected images, an LxLxL array.
    :param basis: Basis object providing ``evaluate`` and ``evaluate_t``.
    :param precond_kernel_f: Optional kernel for a preconditioner applied via
        ``apply_mean_kernel``. ``None`` or an empty array means no kernel-based
        preconditioner.
    :param mean_est_opt: Options passed on to ``conj_grad``; the caller's
        struct is not modified.
    :return: Tuple ``(mean_est, cg_info)``.
    :raises ValueError: If ``im_bp`` is not an LxLxL array.
    """
    import copy

    # Private copy: the preconditioner is installed below.
    mean_est_opt = fill_struct(copy.copy(mean_est_opt))

    # Check the rank before indexing the shape, so bad input raises ValueError.
    shape = im_bp.shape
    if len(shape) != 3 or not shape[0] == shape[1] == shape[2]:
        raise ValueError('im_bp must be an array of size LxLxL')

    def fun(vol_basis):
        return apply_mean_kernel(vol_basis, kernel_f, basis)

    def identity(vol_basis):
        return vol_basis

    if precond_kernel_f is not None and np.size(precond_kernel_f) > 0:
        def precond_fun(vol_basis):
            return apply_mean_kernel(vol_basis, precond_kernel_f, basis)
        mean_est_opt.preconditioner = precond_fun
    elif not callable(getattr(mean_est_opt, 'preconditioner', None)):
        # conj_grad calls the preconditioner; a leftover option string
        # (e.g. 'none') would crash there, so fall back to a no-op.
        mean_est_opt.preconditioner = identity

    im_bp_basis = basis.evaluate_t(im_bp)
    mean_est_basis, _, cg_info = conj_grad(fun, im_bp_basis, mean_est_opt)
    mean_est = basis.evaluate(mean_est_basis)
    return mean_est, cg_info
def cryo_mean_backproject(im, params, mean_est_opt=None):
    """
    Backproject an image stack into a volume for mean estimation.

    1e-7 error from matlab

    The images are multiplied by ``params.ampl``, translated by
    ``-params.shifts``, filtered with ``params.ctf``, backprojected along
    ``params.rot_matrices`` and divided by the number of images n.

    :param im: Image stack of shape (L, L, n) with L > 1.
    :param params: Imaging parameters (``ampl``, ``shifts``, ``ctf``,
        ``ctf_idx``, ``rot_matrices``).
    :param mean_est_opt: Options struct; ``precision`` (default 'float64'),
        ``half_pixel`` (default False) and ``batch_size`` (default 0, meaning
        no batching) are used. The caller's struct is not modified.
    :return: The backprojected volume.
    :raises ValueError: If ``im`` is not an (L, L, n) array with L > 1.
    """
    import copy

    # Private copy: batching below resets ``batch_size``.
    mean_est_opt = fill_struct(copy.copy(mean_est_opt), {'precision': 'float64', 'half_pixel': False, 'batch_size': 0})

    # Check the rank first so that malformed input raises ValueError rather
    # than an IndexError from ``im.shape[1]``.
    if len(im.shape) != 3 or im.shape[0] != im.shape[1] or im.shape[0] == 1:
        raise ValueError('im must be 3 dimensional LxLxn where L > 1')

    resolution = im.shape[1]
    n = im.shape[2]

    if mean_est_opt.batch_size != 0:
        batch_size = int(mean_est_opt.batch_size)
        # Recursive calls handle a single batch each, without further batching.
        mean_est_opt.batch_size = 0
        # range() needs an int; np.ceil returns a float.
        batch_ct = int(np.ceil(n / batch_size))
        # LxLxL, matching the volume cryo_conj_grad_mean expects.
        im_bp = np.zeros([resolution] * 3, dtype=mean_est_opt.precision)
        for batch in range(batch_ct):
            start = batch_size * batch
            end = min(batch_size * (batch + 1), n)
            batch_params = subset_params(params, np.arange(start, end))
            batch_im = im[:, :, start:end]
            batch_im_bp = cryo_mean_backproject(batch_im, batch_params, mean_est_opt)
            # Each batch result is divided by its own image count; reweight
            # so the total is divided by n.
            im_bp += (end - start) / n * batch_im_bp
        return im_bp

    if mean_est_opt.precision == 'float32' or mean_est_opt.precision == 'single':
        im = im.astype('float32')

    # One copy of params.ctf per nonzero entry of params.ctf_idx.
    filter_f = np.einsum('ij, k -> ijk', params.ctf, np.ones(np.count_nonzero(params.ctf_idx)))
    im = im * params.ampl
    im = im_translate(im, -params.shifts)
    im = im_filter(im, filter_f)
    im = im_backproject(im, params.rot_matrices, mean_est_opt.half_pixel)
    im /= n
    return im
def conj_grad(a_fun, b, cg_opt=None, init=None):
    """
    Solve ``a_fun(x) = b`` with the (preconditioned) conjugate gradient method.

    :param a_fun: Function applying the linear operator to an array shaped like
        ``b``. CG presumes the operator is Hermitian positive definite.
    :param b: Right-hand side array.
    :param cg_opt: Options struct. Defaults: ``max_iter`` 50, ``verbose`` 0,
        ``iter_callback`` [] (not used here), ``preconditioner`` identity,
        ``rel_tolerance`` 1e-15, ``store_iterates`` False.
    :param init: Optional struct with an initial iterate ``x`` and search
        direction ``p``; the caller's arrays are not modified.
    :return: Tuple ``(x, obj, info)``.
        - ``x``: the final iterate.
        - ``obj``: the final objective
          ``real(sum(conj(x) * A(x), 0) - 2 * real(sum(conj(b * x), 0)))``.
        - ``info``: per-iteration ``iter``, ``res`` (residual norm) and
          ``obj`` lists, plus ``x``/``r``/``p`` snapshot lists when
          ``store_iterates`` is set.
        When ``b`` is zero, the initial iterate is returned immediately.
    """
    def identity(input_x):
        return input_x

    def objective(x_cur, a_x_cur):
        return np.real(np.sum(x_cur.conj() * a_x_cur, 0) - 2 * np.real(np.sum(np.conj(b * x_cur), 0)))

    cg_opt = fill_struct(cg_opt, {'max_iter': 50, 'verbose': 0, 'iter_callback': [], 'preconditioner': identity, 'rel_tolerance': 1e-15, 'store_iterates': False})
    init = fill_struct(init, {'x': None, 'p': None})

    if init.x is None:
        # float64 for real b (as before); complex b gets a complex iterate.
        x = np.zeros(b.shape, dtype=np.result_type(b, np.float64))
    else:
        # Copy: x is updated in place below and must not alias init.x.
        x_init = np.asarray(init.x)
        x = x_init.astype(np.result_type(x_init, b), copy=True)

    b_norm = np.linalg.norm(b)
    r = b.copy()
    s = cg_opt.preconditioner(r)

    if np.any(x != 0):
        if cg_opt.verbose:
            print('[CG] Calculating initial residual')
        # Copy: a_x is updated in place and a_fun may return an alias of x.
        a_x = np.array(a_fun(x))
        r = r - a_x
        s = cg_opt.preconditioner(r)
    else:
        a_x = np.zeros_like(x)

    obj = objective(x, a_x)

    # p is updated in place, so it must never share memory with s (the
    # identity preconditioner returns r itself) or with the caller's init.p.
    p = np.array(s if init.p is None else init.p)

    info = fill_struct(att_vals={'iter': [0], 'res': [np.linalg.norm(r)], 'obj': [obj]})
    if cg_opt.store_iterates:
        info = fill_struct(info, att_vals={'x': [x.copy()], 'r': [r.copy()], 'p': [p.copy()]})

    if cg_opt.verbose:
        print('[CG] Initialized. Residual: {}. Objective: {}'.format(np.linalg.norm(info.res[0]), np.sum(info.obj[0])))

    if b_norm == 0:
        print('b_norm == 0')
        # Return a full result tuple instead of None so callers can unpack it.
        return x, obj, info

    # Iterations are numbered 1..max_iter, i.e. max_iter of them.
    for i in range(1, cg_opt.max_iter + 1):
        if cg_opt.verbose:
            print('[CG] Applying matrix & preconditioner')
        a_p = a_fun(p)
        old_gamma = np.real(np.sum(s.conj() * r))
        alpha = old_gamma / np.real(np.sum(p.conj() * a_p))

        x += alpha * p
        a_x += alpha * a_p
        r -= alpha * a_p

        s = cg_opt.preconditioner(r)
        new_gamma = np.real(np.sum(r.conj() * s))
        beta = new_gamma / old_gamma
        p *= beta
        p += s

        obj = objective(x, a_x)
        res = np.linalg.norm(r)
        info.iter.append(i)
        info.res.append(res)
        info.obj.append(obj)
        if cg_opt.store_iterates:
            # Snapshots: x, r and p keep changing in place.
            info.x.append(x.copy())
            info.r.append(r.copy())
            info.p.append(p.copy())

        if cg_opt.verbose:
            print('[CG] Iteration {}. Residual: {}. Objective: {}'.format(i, res, np.sum(obj)))

        if np.all(res < b_norm * cg_opt.rel_tolerance):
            break

    return x, obj, info