Esempio n. 1
0
def correlation_noise(num_images=1000, N=128, rad=0.6, stack_noise=False):
    """Compare autocorrelations of Gaussian noise in real and Fourier space.

    A stack of ``num_images`` standard-normal images of size ``N x N`` is
    drawn and one image is picked at random.  The full autocorrelation
    (``correlation.calc_full_ac`` with radius ``rad``) of that image and of its
    Fourier transform are passed, with the mask from ``geometry.gencoords``,
    to ``plot_noise_histogram``.

    When ``stack_noise`` is true, autocorrelations are additionally computed
    for every image of the stack, and the values of the raw real-space and
    Fourier-space stacks at one random pixel are passed to
    ``plot_stack_noise``.
    """
    stack = np.require(np.random.randn(num_images, N, N), dtype=density.real_t)
    picked = stack[np.random.randint(num_images)]

    real_ac = correlation.calc_full_ac(picked, rad=rad)
    picked_f = density.real_to_fspace(picked)
    fourier_ac = correlation.calc_full_ac(picked_f, rad=rad)

    _xy, _trunc_xy, mask = geometry.gencoords(N, 2, rad, True)
    plot_noise_histogram(real_ac, fourier_ac, mask, mask)

    if not stack_noise:
        return

    half = int(N / 2)
    # Random pixel index in [-N/2, N/2); negative indices wrap around.
    px, py = np.random.randint(-half, half, size=2)

    stack_f = density.real_to_fspace(stack, axes=(1, 2))
    real_ac_stack = np.zeros_like(stack, dtype=density.real_t)
    fourier_ac_stack = np.zeros_like(stack_f, dtype=density.complex_t)
    for idx in range(num_images):
        real_ac_stack[idx] = correlation.calc_full_ac(stack[idx], rad=rad)
        fourier_ac_stack[idx] = correlation.calc_full_ac(stack_f[idx], rad=rad)

    # NOTE(review): the autocorrelation stacks above are not used by the plot
    # below, which shows the raw noise stacks -- confirm this is intended.
    plot_stack_noise(stack[:, px, py], stack_f[:, px, py])
Esempio n. 2
0
    def eval(self, M=None, compute_gradient=True, fM=None, **kwargs):
        """Evaluate the likelihood kernel for the current minibatch.

        The density is multiplied by ``self.kernel.slice_premult`` (when set),
        transformed to Fourier space and passed to ``self.kernel.eval``.
        ``M`` is the real-space density.  If ``M`` is ``None``, the
        Fourier-space density ``fM`` is used instead.

        Unless the minibatch is a test batch or ``intermediate=True`` is
        passed in ``kwargs``, the statistics in ``ret[-1]`` (``sigma2_est``,
        ``correlation``, ``power``) are stored in the matching ``*_history``
        objects.  The kernel's truncation mask is stored the same way.

        When a gradient is requested and a premultiplier is set, the gradient
        ``ret[1]`` is transformed to real space and multiplied by the
        premultiplier.

        Returns the tuple from ``self.kernel.eval`` (with ``ret[1]`` replaced
        as described above).  Timings are added to
        ``ret[-1]['like_timing']``.
        """
        tic_start = time.time()

        premult = self.kernel.slice_premult
        if M is None and fM is not None:
            # Only the Fourier-space density was supplied.  Without a
            # premultiplier it can be used as is; otherwise go through real
            # space so the premultiplier can be applied.
            if premult is None:
                pfM = fM
            else:
                pfM = density.real_to_fspace(premult * density.fspace_to_real(fM))
        elif premult is not None:
            pfM = density.real_to_fspace(premult * M)
        else:
            pfM = density.real_to_fspace(M)
        pmtime = time.time() - tic_start

        ret = self.kernel.eval(fM=pfM,
                               M=None,
                               compute_gradient=compute_gradient)

        if not self.minibatch['test_batch'] and not kwargs.get(
                'intermediate', False):
            tic_record = time.time()
            stats = ret[-1]
            n_batches = self.cryodata.N_batches
            batch_id = self.minibatch['id']

            def record(history, value, setup_dtype=None, check_finite=True):
                # Store this batch's value.  The history is (re)initialised
                # whenever its size differs from the current number of
                # batches; setup_dtype optionally converts the value used for
                # initialisation.
                if check_finite:
                    assert n.all(n.isfinite(value))
                if history.N_sum != n_batches:
                    init = value if setup_dtype is None else n.require(
                        value, dtype=setup_dtype)
                    history.setup(init, n_batches, allow_decay=False)
                history.set_value(batch_id, value)

            record(self.error_history, stats['sigma2_est'])
            record(self.correlation_history, stats['correlation'])
            record(self.power_history, stats['power'])
            # The mask is not checked for finiteness and its history is
            # initialised with a float32 copy.
            record(self.mask_history, self.kernel.truncmask,
                   setup_dtype=n.float32, check_finite=False)
            stats['like_timing']['record'] = time.time() - tic_record

        if compute_gradient and self.kernel.slice_premult is not None:
            tic_record = time.time()
            ret = (ret[0],
                   self.kernel.slice_premult * density.fspace_to_real(ret[1]),
                   ret[2])
            ret[-1]['like_timing']['premult'] = (pmtime + time.time()
                                                 - tic_record)

        ret[-1]['like_timing']['total'] = time.time() - tic_start

        return ret
Esempio n. 3
0
def shift_vis(model):
    """Visualise the effect of random in-plane shifts on a Fourier slice.

    A random orientation is drawn and a Fourier slice of the volume ``model``
    is extracted, with the lanczos premultiplier applied to the volume.  The
    slice is then multiplied by the phase factors of three random shifts.

    Layout of the 4x4 figure:

    * Column 0 shows the unshifted slice: row 0 is the real-space projection
      and row 1 is log|F|.
    * Columns 1-3 show the shifted slices, one column per shift:
      - row 0: real-space projection;
      - row 1: log|F|;
      - rows 2 and 3: signed log of the real and imaginary parts.
    """
    def signed_log(values):
        # Real and imaginary parts are signed, so a plain log yields NaN for
        # the negative entries.  sign(x) * log1p(|x|) keeps every pixel
        # finite and preserves the sign.
        return np.sign(values) * np.log1p(np.abs(values))

    N = model.shape[0]
    rad = 0.8
    kernel = 'lanczos'
    kernsize = 4

    premult = cryoops.compute_premultiplier(N,
                                            kernel=kernel,
                                            kernsize=kernsize)
    TtoF = sincint.gentrunctofull(N=N, rad=rad)

    # Separable premultiplier applied along all three axes of the volume.
    premulter = (premult.reshape((1, 1, -1))
                 * premult.reshape((1, -1, 1))
                 * premult.reshape((-1, 1, 1)))
    prefM = density.real_to_fspace(premulter * model)

    pt = np.random.randn(3)
    pt /= np.linalg.norm(pt)
    psi = 2 * np.pi * np.random.rand()
    ea = geometry.genEA(pt)[0]
    ea[2] = psi
    print('project model for Euler angel: ({:.2f}, {:.2f}, {:.2f}) degree'.
          format(*np.rad2deg(ea)))

    rot_matrix = geometry.rotmat3D_EA(*ea)[:, 0:2]
    slop = cryoops.compute_projection_matrix([rot_matrix], N, kernel, kernsize,
                                             rad, 'rots')
    trunc_slice = cryoem.getslices(prefM, slop)
    fourier_slice = TtoF.dot(trunc_slice).reshape(N, N)
    real_proj = density.fspace_to_real(fourier_slice)

    fig, axes = plt.subplots(4, 4, figsize=(12.8, 8))

    axes[0, 0].imshow(real_proj)
    axes[1, 0].imshow(np.log(np.abs(fourier_slice)))

    # Each iteration fills one of columns 1-3 (a column of 4 axes).
    for column in axes[:, 1:].T:
        shift = np.random.randn(2) * (N / 4.0)
        S = cryoops.compute_shift_phases(shift.reshape(1, 2), N, rad)[0]

        shift_trunc_slice = S * trunc_slice
        shift_fourier_slice = TtoF.dot(shift_trunc_slice).reshape(N, N)
        shift_real_proj = density.fspace_to_real(shift_fourier_slice)

        column[0].imshow(shift_real_proj)
        column[1].imshow(np.log(np.abs(shift_fourier_slice)))
        column[2].imshow(signed_log(shift_fourier_slice.real))
        column[3].imshow(signed_log(shift_fourier_slice.imag))

    fig.tight_layout()
    plt.show()
Esempio n. 4
0
def premult_test(model, kernel='lanczos', kernsize=6):
    """Plot one random projection computed with and without the premultiplier.

    Parameters
    ----------
    model : str or numpy.ndarray
        Path to an MRC file, or the volume itself.  The volume must be a cube.
    kernel : str
        Interpolation kernel name passed to ``cryoops``.
    kernsize : int
        Interpolation kernel size passed to ``cryoops``.

    Raises
    ------
    TypeError
        If ``model`` is neither a path string nor an ndarray.
    ValueError
        If the volume is not a cubic 3D array.
    """
    if isinstance(model, str):
        M = mrc.readMRC(model)
    elif isinstance(model, np.ndarray):
        M = model
    else:
        raise TypeError('model must be a file path or a numpy.ndarray, '
                        'got {}'.format(type(model).__name__))

    # The premultiplier below is built with length N = M.shape[0] along every
    # axis, so all three axes must have the same length.
    if M.ndim != 3 or len(set(M.shape)) != 1:
        raise ValueError('model must be a cubic 3D volume, '
                         'got shape {}'.format(M.shape))

    N = M.shape[0]
    rad = 0.6

    premult = cryoops.compute_premultiplier(N, kernel, kernsize)
    TtoF = sincint.gentrunctofull(N=N, rad=rad)
    premulter = (premult.reshape((1, 1, -1))
                 * premult.reshape((1, -1, 1))
                 * premult.reshape((-1, 1, 1)))

    fM = density.real_to_fspace(M)
    prefM = density.real_to_fspace(premulter * M)

    pt = np.random.randn(3)
    pt /= np.linalg.norm(pt)
    psi = 2 * np.pi * np.random.rand()
    ea = geometry.genEA(pt)[0]
    ea[2] = psi
    print('project model for Euler angel: ({:.2f}, {:.2f}, {:.2f}) degree'.
          format(*np.rad2deg(ea)))

    rot_matrix = geometry.rotmat3D_EA(*ea)[:, 0:2]
    slop = cryoops.compute_projection_matrix([rot_matrix], N, kernel, kernsize,
                                             rad, 'rots')
    trunc_slice = slop.dot(fM.reshape((-1, )))
    premult_trunc_slice = slop.dot(prefM.reshape((-1, )))
    proj = density.fspace_to_real(TtoF.dot(trunc_slice).reshape(N, N))
    premult_proj = density.fspace_to_real(
        TtoF.dot(premult_trunc_slice).reshape(N, N))

    fig, ax = plt.subplots(1, 3, figsize=(14.4, 4.8))
    im_proj = ax[0].imshow(proj, origin='lower')
    fig.colorbar(im_proj, ax=ax[0])
    ax[0].set_title('no premulter')
    im_pre = ax[1].imshow(premult_proj, origin='lower')
    fig.colorbar(im_pre, ax=ax[1])
    ax[1].set_title('with premulter')
    im_diff = ax[2].imshow(proj - premult_proj, origin='lower')
    fig.colorbar(im_diff, ax=ax[2])
    ax[2].set_title('difference of two image')
    fig.tight_layout()
    plt.show()
Esempio n. 5
0
def fourierspace_benchmark_comparison(model, ea=None, num_inplane_angles=360, modulus=False,
                                      save_animation=False, animation_name=None):
    """Compare a Fourier slice of ``model`` with slices at other orientations.

    ``gen_EAs_randomly`` returns a reference orientation ``ea`` and
    ``num_inplane_angles`` further orientations.  The model is projected at
    all of them with ``projector.project``, and the projections are
    transformed to Fourier space.  The slices and their correlation images
    are then compared against the reference.

    If ``modulus`` is true, magnitudes are compared and log-scaled images are
    shown with ``vis_real_space_comparison``.  Otherwise, real and imaginary
    parts are correlated separately and shown with
    ``vis_fourier_space_comparison``.
    """
    N = model.shape[0]
    ea, euler_angles = gen_EAs_randomly(ea, num_inplane_angles)

    proj = projector.project(model, ea)
    one_slice = density.real_to_fspace(proj)
    # The comparison slices are projections of the same model as the
    # reference slice.
    rot_projs = projector.project(model, euler_angles)
    rot_slices = np.zeros_like(rot_projs, dtype=density.complex_t)
    for i, pj in enumerate(rot_projs):
        rot_slices[i] = density.real_to_fspace(pj)

    if modulus:
        one_slice = np.abs(one_slice)
        rot_slices = np.abs(rot_slices)

        corr_slice = correlation.get_corr_img(one_slice)
        corr_rot_slices = correlation.get_corr_imgs(rot_slices)

        diff = calc_difference(one_slice, rot_slices, only_real=True)
        corr_diff = calc_difference(corr_slice, corr_rot_slices, only_real=True)

        vis_real_space_comparison(
            np.log(rot_slices),
            np.log(corr_rot_slices),
            diff, corr_diff,
            original_img=np.log(one_slice),
            original_corr_img=np.log(corr_slice),
            save_animation=save_animation, animation_name=animation_name
            )
    else:
        # NOTE(review): correlation images are assumed to be (N/2, 360); this
        # must match the shape returned by correlation.get_corr_img(s).
        half_n = int(N / 2.0)
        corr_slice = np.zeros((half_n, 360), dtype=density.complex_t)
        corr_slice.imag = correlation.get_corr_img(one_slice.imag)
        corr_slice.real = correlation.get_corr_img(one_slice.real)
        corr_rot_slices = np.zeros((num_inplane_angles, half_n, 360), dtype=density.complex_t)
        corr_rot_slices.real = correlation.get_corr_imgs(rot_slices.real)
        corr_rot_slices.imag = correlation.get_corr_imgs(rot_slices.imag)

        diff_real, diff_imag = calc_difference(one_slice, rot_slices)
        corr_diff_real, corr_diff_imag = calc_difference(corr_slice, corr_rot_slices)

        vis_fourier_space_comparison(
            rot_slices,
            corr_rot_slices,
            diff_real, diff_imag,
            corr_diff_real, corr_diff_imag,
            original_img=one_slice,
            original_corr_img=corr_slice,
            save_animation=save_animation, animation_name=animation_name
            )
Esempio n. 6
0
def no_correlation(num_images=1000, N=128, rad=0.8):
    """Plot statistics of uncorrelated Gaussian noise in real and Fourier space.

    A random pixel index is chosen first, then a stack of ``num_images``
    standard-normal ``N x N`` images is drawn.  For one randomly picked image,
    its real- and Fourier-space values are plotted with
    ``plot_noise_histogram``, using the ``gencoords`` mask of radius ``rad``.
    The values of the whole real- and Fourier-space stacks at the chosen
    pixel are then plotted with ``plot_stack_noise``.
    """
    half = int(N / 2)
    px, py = np.random.randint(-half, half, size=2)

    stack = np.require(np.random.randn(num_images, N, N), dtype=density.real_t)
    single = stack[np.random.randint(num_images)]
    stack_f = density.real_to_fspace(stack, axes=(1, 2))
    single_f = density.real_to_fspace(single)

    pixel_series = stack[:, px, py]
    pixel_series_f = stack_f[:, px, py]

    _xy, _trunc_xy, mask = geometry.gencoords(N, 2, rad, True)
    plot_noise_histogram(single, single_f, rmask=mask, fmask=mask)
    plot_stack_noise(pixel_series, pixel_series_f)
Esempio n. 7
0
    def apply(self,M,normalize=True,kernel='sinc',kernelsize=3,rad=2):
        """Symmetrize the density ``M`` under this object's symmetry group.

        ``M`` is returned unchanged when ``self.symclass`` is empty.

        Complex input is treated as a Fourier-space volume:
        - When ``rad * order * N < 500``, it is symmetrized directly with
          ``sincint.symmetrize_fspace_volume``.  Class ``'d'`` gets an extra
          pass with a fixed flip matrix.
        - Otherwise it is transformed to real space, symmetrized there and
          transformed back.

        Real input is symmetrized with ``sincint.symmetrize_volume``.  When
        ``normalize`` is true, the result is divided by ``self.get_order()``.
        """
        if self.symclass == '':
            return M

        if not n.iscomplexobj(M):
            rotations = n.array(self.get_rotations(), dtype=n.float32)
            out = sincint.symmetrize_volume(n.require(M, dtype=n.float32), rotations)
        else:
            rotations = n.array(self.get_rotations(exclude_dsym=True), dtype=n.float32)
            size = M.shape[0]
            # Heuristic cut-over: large problems take the real-space round trip.
            if not rad * self.get_order() * size < 500:
                real_M = density.fspace_to_real(M)
                return density.real_to_fspace(self.apply(real_M, normalize=normalize))
            out = sincint.symmetrize_fspace_volume(M, rad, kernel, kernelsize, symRs=rotations)
            if self.symclass == 'd':
                flip = n.array([[[-1.0, 0, 0],
                                 [0, 1.0, 0],
                                 [0, 0, -1.0]]], dtype=n.float32)
                out = sincint.symmetrize_fspace_volume(out, rad, kernel, kernelsize, symRs=flip)

        if normalize:
            out /= self.get_order()

        return out
Esempio n. 8
0
    def get_image(self, idx):
        """Return the Fourier transform of stack image ``idx``.

        Processing steps:
        - If ``self.zeropad`` is set, the image is written into the
          pre-allocated buffer ``self.zpimg`` at offset ``self.zeropad``.
        - If ``self.premult`` is set, the image is multiplied by it.
        - ``density.real_to_fspace`` is applied while holding
          ``self.fft_lock``.

        Results are cached in ``self.transformed``.  When ``self.caching`` is
        false, that dict is reset on every call.  Each newly transformed image
        is added to ``self.fspacesum`` and ``self.powersum``, and counted in
        ``self.nsum``.
        """
        if not self.caching:
            self.transformed = {}
        if idx not in self.transformed:
            self.fft_lock.acquire()
            try:
                if self.zeropad:
                    N = self.stack.get_num_pixels()
                    pad = self.zeropad
                    # self.zpimg is a reused buffer; here it is only written
                    # while the lock is held.
                    img = self.zpimg
                    img[pad:(N + pad), pad:(N + pad)] = self.stack.get_image(idx)
                else:
                    img = self.stack.get_image(idx)

                if self.premult is not None:
                    img = self.premult * img

                fimg = density.real_to_fspace(img)
                self.transformed[idx] = fimg
            finally:
                # Always release: a failure while reading or transforming the
                # image must not leave the lock held for every later call.
                self.fft_lock.release()

            self.fspacesum += fimg
            self.powersum += fimg.real**2 + fimg.imag**2
            self.nsum += 1

        return self.transformed[idx]
Esempio n. 9
0
    def get_image(self, idx):
        """Return the Fourier transform of stack image ``idx``.

        Processing steps:
        - If ``self.zeropad`` is set, the image is written into the
          pre-allocated buffer ``self.zpimg`` at offset ``self.zeropad``.
        - If ``self.premult`` is set, the image is multiplied by it.
        - ``density.real_to_fspace`` is applied while holding
          ``self.fft_lock``.

        Results are cached in ``self.transformed``.  When ``self.caching`` is
        false, that dict is reset on every call.  Each newly transformed image
        is added to ``self.fspacesum`` and ``self.powersum``, and counted in
        ``self.nsum``.
        """
        if not self.caching:
            self.transformed = {}
        if idx not in self.transformed:
            self.fft_lock.acquire()
            try:
                if self.zeropad:
                    N = self.stack.get_num_pixels()
                    pad = self.zeropad
                    # self.zpimg is a reused buffer; here it is only written
                    # while the lock is held.
                    img = self.zpimg
                    img[pad:(N + pad), pad:(N + pad)] = self.stack.get_image(idx)
                else:
                    img = self.stack.get_image(idx)

                if self.premult is not None:
                    img = self.premult * img

                fimg = density.real_to_fspace(img)
                self.transformed[idx] = fimg
            finally:
                # Always release: a failure while reading or transforming the
                # image must not leave the lock held for every later call.
                self.fft_lock.release()

            self.fspacesum += fimg
            self.powersum += fimg.real**2 + fimg.imag**2
            self.nsum += 1

        return self.transformed[idx]
    def eval(self, M=None, compute_gradient=True, fM=None, **kwargs):
        """Evaluate the likelihood kernel for the current minibatch.

        The density is multiplied by ``self.kernel.slice_premult`` (when set),
        transformed to Fourier space and passed to ``self.kernel.eval``.
        ``M`` is the real-space density.  If ``M`` is ``None``, the
        Fourier-space density ``fM`` is used instead.

        Unless the minibatch is a test batch or ``intermediate=True`` is
        passed in ``kwargs``, the statistics in ``ret[-1]`` (``sigma2_est``,
        ``correlation``, ``power``) are stored in the matching ``*_history``
        objects.  The kernel's truncation mask is stored the same way.

        When a gradient is requested and a premultiplier is set, the gradient
        ``ret[1]`` is transformed to real space and multiplied by the
        premultiplier.

        Returns the tuple from ``self.kernel.eval`` (with ``ret[1]`` replaced
        as described above).  Timings are added to
        ``ret[-1]['like_timing']``.
        """
        tic_start = time.time()

        premult = self.kernel.slice_premult
        if M is None and fM is not None:
            # Only the Fourier-space density was supplied.  Without a
            # premultiplier it can be used as is; otherwise go through real
            # space so the premultiplier can be applied.
            if premult is None:
                pfM = fM
            else:
                pfM = density.real_to_fspace(premult * density.fspace_to_real(fM))
        elif premult is not None:
            pfM = density.real_to_fspace(premult * M)
        else:
            pfM = density.real_to_fspace(M)
        pmtime = time.time() - tic_start

        ret = self.kernel.eval(fM=pfM, M=None, compute_gradient=compute_gradient)

        if not self.minibatch['test_batch'] and not kwargs.get('intermediate', False):
            tic_record = time.time()
            stats = ret[-1]
            n_batches = self.cryodata.N_batches
            batch_id = self.minibatch['id']

            def record(history, value, setup_dtype=None, check_finite=True):
                # Store this batch's value.  The history is (re)initialised
                # whenever its size differs from the current number of
                # batches; setup_dtype optionally converts the value used for
                # initialisation.
                if check_finite:
                    assert n.all(n.isfinite(value))
                if history.N_sum != n_batches:
                    init = value if setup_dtype is None else n.require(value, dtype=setup_dtype)
                    history.setup(init, n_batches, allow_decay=False)
                history.set_value(batch_id, value)

            record(self.error_history, stats['sigma2_est'])
            record(self.correlation_history, stats['correlation'])
            record(self.power_history, stats['power'])
            # The mask is not checked for finiteness and its history is
            # initialised with a float32 copy.
            record(self.mask_history, self.kernel.truncmask,
                   setup_dtype=n.float32, check_finite=False)
            stats['like_timing']['record'] = time.time() - tic_record

        if compute_gradient and self.kernel.slice_premult is not None:
            tic_record = time.time()
            ret = (ret[0], self.kernel.slice_premult * density.fspace_to_real(ret[1]), ret[2])
            ret[-1]['like_timing']['premult'] = pmtime + time.time() - tic_record

        ret[-1]['like_timing']['total'] = time.time() - tic_start

        return ret
Esempio n. 11
0
    def convert_parameter(self,x,comp_real=False,comp_fspace=False):
        """Convert the optimisation vector ``x`` into density form.

        Returns ``(M, fM)``: the real-space and the Fourier-space density.

        If ``x`` is the stored starting point ``self.x0``, the cached
        ``self.M0`` / ``self.fM0`` are used.  Otherwise ``param2density``
        decodes ``x``.  Either element may be ``None``.  When ``comp_real``
        (resp. ``comp_fspace``) is true and ``M`` (resp. ``fM``) is missing, it
        is computed from the other representation.  This applies to both
        paths, including the cached starting point.
        """
        if x is self.x0:
            M, fM = self.M0, self.fM0
        else:
            M, fM = param2density(x, self.xtype, self.M0.shape,
                                  precond=self.precond)

        if comp_real and M is None:
            M = density.fspace_to_real(fM)

        if comp_fspace and fM is None:
            fM = density.real_to_fspace(M)

        return M, fM
Esempio n. 12
0
    def apply(self, M, normalize=True, kernel='sinc', kernelsize=3, rad=2):
        """Symmetrize the density ``M`` under this object's symmetry group.

        ``M`` is returned unchanged when ``self.symclass`` is empty.

        Complex input is treated as a Fourier-space volume:
        - When ``rad * order * N < 500``, it is symmetrized directly with
          ``sincint.symmetrize_fspace_volume``.  Class ``'d'`` gets an extra
          pass with a fixed flip matrix.
        - Otherwise it is transformed to real space, symmetrized there and
          transformed back.

        Real input is symmetrized with ``sincint.symmetrize_volume``.  When
        ``normalize`` is true, the result is divided by ``self.get_order()``.
        """
        if self.symclass == '':
            return M

        if not np.iscomplexobj(M):
            rotations = np.array(self.get_rotations(), dtype=np.float32)
            out = sincint.symmetrize_volume(np.require(M, dtype=np.float32),
                                            rotations)
        else:
            rotations = np.array(self.get_rotations(exclude_dsym=True),
                                 dtype=np.float32)
            size = M.shape[0]
            # Heuristic cut-over: large problems take the real-space round trip.
            if not rad * self.get_order() * size < 500:
                real_M = density.fspace_to_real(M)
                return density.real_to_fspace(
                    self.apply(real_M, normalize=normalize))
            out = sincint.symmetrize_fspace_volume(M, rad, kernel, kernelsize,
                                                   symRs=rotations)
            if self.symclass == 'd':
                flip = np.array([[[-1.0, 0, 0],
                                  [0, 1.0, 0],
                                  [0, 0, -1.0]]], dtype=np.float32)
                out = sincint.symmetrize_fspace_volume(out, rad, kernel,
                                                       kernelsize, symRs=flip)

        if normalize:
            out /= self.get_order()

        return out
Esempio n. 13
0
def project(model, euler_angles, rad=0.95, truncate=False):
    """Project a 3D density along one or more orientations.

    Parameters
    ----------
    model : str or numpy.ndarray
        Path to an MRC file, or the volume itself (N along axis 0).
    euler_angles : array_like
        Euler angles, reshaped to ``(-1, 3)``; one projection per row.
    rad : float
        Fourier-space truncation radius used for the slice operators.
    truncate : bool
        If true, return the truncated Fourier slices, one row of
        ``TtoF.shape[1]`` entries per projection.  Otherwise return
        real-space projections of shape ``(num_projs, N, N)``.

    Returns
    -------
    numpy.ndarray
        The projections.  A single real-space projection is returned with
        shape ``(N, N)``.

    Raises
    ------
    TypeError
        If ``model`` is neither a path string nor an ndarray.
    """
    if isinstance(model, str):
        M = mrc.readMRC(model)
    elif isinstance(model, np.ndarray):
        M = model
    else:
        raise TypeError('model must be a file path or a numpy.ndarray, '
                        'got {}'.format(type(model).__name__))

    N = M.shape[0]
    kernel = 'lanczos'
    ksize = 6

    premult = cryoops.compute_premultiplier(N, kernel, ksize)
    TtoF = sincint.gentrunctofull(N=N, rad=rad)
    premulter = (premult.reshape((1, 1, -1))
                 * premult.reshape((1, -1, 1))
                 * premult.reshape((-1, 1, 1)))
    fM = density.real_to_fspace(premulter * M)

    euler_angles = np.asarray(euler_angles).reshape((-1, 3))
    num_projs = euler_angles.shape[0]
    if truncate:
        projs = np.zeros((num_projs, TtoF.shape[1]), dtype=fM.dtype)
    else:
        # Use at least float32 so that projections of integer-valued volumes
        # are not truncated when stored; float inputs keep their dtype.
        projs = np.zeros((num_projs, N, N),
                         dtype=np.result_type(M.dtype, np.float32))
    for i, ea in enumerate(euler_angles):
        rot_matrix = geometry.rotmat3D_EA(*ea)[:, 0:2]
        slop = cryoops.compute_projection_matrix([rot_matrix], N, kernel, ksize, rad, 'rots')
        trunc_slice = slop.dot(fM.reshape((-1,)))
        if truncate:
            projs[i, :] = trunc_slice
        else:
            projs[i, :, :] = density.fspace_to_real(TtoF.dot(trunc_slice).reshape(N, N))

    if num_projs == 1 and not truncate:
        projs = projs.reshape((N, N))

    return projs
Esempio n. 14
0
    def learn_params(self, params, cparams, M=None, fM=None):
        """Learn the hyperparameters of every sub-objective in ``self.objs``.

        Each object's ``fspace`` flag says whether it needs the Fourier-space
        density.  Whichever of ``M`` (real space) and ``fM`` (Fourier space)
        is missing but needed by some object is computed from the other; at
        least one of the two must be supplied.  Both are then forwarded to
        each object's ``learn_params``.
        """
        needs_fspace = any([obj.fspace for obj in self.objs])
        needs_rspace = any([not obj.fspace for obj in self.objs])

        size = None
        if fM is not None:
            size = fM.shape[0]
        elif needs_fspace:
            assert M is not None, 'M or fM must be set!'
            size = M.shape[0]
            fM = density.real_to_fspace(M)

        if M is not None:
            assert size is None or size == M.shape[0]
            size = M.shape[0]
        elif needs_rspace:
            assert fM is not None, 'M or fM must be set!'
            size = fM.shape[0]
            M = density.fspace_to_real(fM)

        assert size is not None

        for obj in self.objs:
            obj.learn_params(params, cparams, M=M, fM=fM)
Esempio n. 15
0
    def learn_params(self, params, cparams, M=None, fM=None):
        """Learn the hyperparameters of every sub-objective in ``self.objs``.

        Each object's ``fspace`` flag says whether it needs the Fourier-space
        density.  Whichever of ``M`` (real space) and ``fM`` (Fourier space)
        is missing but needed by some object is computed from the other; at
        least one of the two must be supplied.  Both are then forwarded to
        each object's ``learn_params``.
        """
        flags = [obj.fspace for obj in self.objs]
        needs_fspace = any(flags)
        needs_rspace = any([not flag for flag in flags])

        size = None
        if fM is not None:
            size = fM.shape[0]
        elif needs_fspace:
            assert M is not None, 'M or fM must be set!'
            size = M.shape[0]
            fM = density.real_to_fspace(M)

        if M is not None:
            assert size is None or size == M.shape[0]
            size = M.shape[0]
        elif needs_rspace:
            assert fM is not None, 'M or fM must be set!'
            size = fM.shape[0]
            M = density.fspace_to_real(fM)

        assert size is not None

        for obj in self.objs:
            obj.learn_params(params, cparams, M=M, fM=fM)
Esempio n. 16
0
# Crop the volume ``M`` (defined before this point) to its first 124 voxels
# along each axis and save the cropped map.
M = M[:124, :124, :124]

mrc.writeMRC('./particle/EMD-6044-cropped.mrc', M, psz=3.0)

N = M.shape[0]
print(M.shape)
rad = 1
kernel = 'lanczos'
ksize = 4

# Of these three outputs only trunc_xy is used below (to size the slices).
xy, trunc_xy, truncmask = geometry.gencoords(N, 2, rad, True)
# premult = cryoops.compute_premultiplier(N, kernel='lanczos', kernsize=6)
premult = cryoops.compute_premultiplier(N, kernel, ksize)
TtoF = sincint.gentrunctofull(N=N, rad=rad)

# Fourier transforms of the plain volume and of the volume multiplied by the
# separable premultiplier along all three axes.
fM = density.real_to_fspace(M)
prefM = density.real_to_fspace(premult.reshape(
        (1, 1, -1)) * premult.reshape((1, -1, 1)) * premult.reshape((-1, 1, 1)) * M)

# Orientation grid from healpix (nside=2, psi_step=360); each orientation is
# turned into the first two columns of its rotation matrix.
EAs_grid = healpix.gen_EAs_grid(nside=2, psi_step=360)
Rs = [geometry.rotmat3D_EA(*EA)[:, 0:2] for EA in EAs_grid]
# NOTE(review): the kernel name is hard-coded here instead of using the
# ``kernel`` variable above (both are 'lanczos').
slice_ops = cryoops.compute_projection_matrix(Rs, N, kern='lanczos', kernsize=ksize, rad=rad, projdirtype='rots')

# One truncated Fourier slice per grid orientation, without and with the
# premultiplier.
slices_sampled = cryoem.getslices(fM, slice_ops).reshape((EAs_grid.shape[0], trunc_xy.shape[0]))

premult_slices_sampled = cryoem.getslices(prefM, slice_ops).reshape((EAs_grid.shape[0], trunc_xy.shape[0]))

# Phase factors for the single shift (100, -20); presumably in pixels -- confirm.
S = cryoops.compute_shift_phases(np.asarray([100, -20]).reshape((1,2)), N, rad)[0]

# Slices for the first orientation of the grid.
trunc_slice = slices_sampled[0]
premult_trunc_slice = premult_slices_sampled[0]
Esempio n. 17
0
    def dowork(self):
        """Do one atom of work, i.e. execute one minibatch.

        Steps, in order:
        - fetch the next training minibatch;
        - optionally learn prior/likelihood hyperparameters;
        - report epoch timing and record density statistics;
        - optionally evaluate the test batch;
        - take one optimisation step on the training batch, then apply
          symmetry and the density bounds;
        - update the importance samplers and write diagnostics/statistics;
        - when due, queue model and statistics files for saving.

        Returns True while ``self.iteration`` is below
        ``cparams['max_iterations']`` and the fractional epoch is below
        ``cparams['max_epochs']``; both limits default to infinity.
        """

        # NOTE(review): ``timing`` collects per-phase durations below but is
        # not output or returned by this method.
        timing = {}
        # Time each minibatch
        tic_mini = time.time()

        self.iteration += 1

        # Fetch the current batches
        trainbatch = self.cryodata.get_next_minibatch(self.cparams.get('shuffle_minibatches',True))

        # Get the current epoch (cepoch is the fractional epoch, frac=True)
        cepoch = self.cryodata.get_epoch(frac=True)
        epoch = self.cryodata.get_epoch()
        num_data = self.cryodata.N_D_Train

        # Evaluate the parameters
        self.eval_params()
        timing['setup'] = time.time() - tic_mini

        # Do hyperparameter learning
        if self.cparams.get('learn_params',False):
            tic_learn = time.time()
            if self.cparams.get('learn_prior_params',True):
                tic_learn_prior = time.time()
                self.prior_func.learn_params(self.params, self.cparams, M=self.M, fM=self.fM)
                timing['learn_prior'] = time.time() - tic_learn_prior 

            if self.cparams.get('learn_likelihood_params',True):
                tic_learn_like = time.time()
                self.like_func.learn_params(self.params, self.cparams, M=self.M, fM=self.fM)
                timing['learn_like'] = time.time() - tic_learn_like
                
            # Total learning time is recorded only if at least one of the
            # two learning steps was enabled.
            if self.cparams.get('learn_prior_params',True) or self.cparams.get('learn_likelihood_params',True):
                timing['learn_total'] = time.time() - tic_learn   

        # Time each epoch: self.tic_epoch holds (start time, epoch number);
        # the elapsed time is reported when the epoch number changes.
        if self.tic_epoch == None:
            self.ostream("Epoch: %d" % epoch)
            self.tic_epoch = (tic_mini,epoch)
        elif self.tic_epoch[1] != epoch:
            self.ostream("Epoch Total - %.6f seconds " % \
                         (tic_mini - self.tic_epoch[0]))
            self.tic_epoch = (tic_mini,epoch)

        # With a symmetry operator, objective weight ws[1] is set to
        # 1/order.  NOTE(review): presumably ws[1] weights the prior -- confirm.
        sym = get_symmetryop(self.cparams.get('symmetry',None))
        if sym is not None:
            self.obj.ws[1] = 1.0/sym.get_order()

        tic_mstats = time.time()
        self.ostream(self.cparams['name']," Iteration:", self.iteration,\
                     " Epoch:", epoch, " Host:", socket.gethostname())

        # Compute density statistics (mean is taken over N**3 voxels)
        N = self.cryodata.N
        M_sum = self.M.sum(dtype=n.float64)
        M_zeros = (self.M == 0).sum()
        M_mean = M_sum/N**3
        M_max = self.M.max()
        M_min = self.M.min()
#         self.ostream("  Density (min/max/avg/sum/zeros): " +
#                      "%.2e / %.2e / %.2e / %.2e / %g " %
#                      (M_min, M_max, M_mean, M_sum, M_zeros))
        # NOTE(review): M_zeros counts voxels equal to zero but is reported
        # as 'nonzero_density' -- confirm the intended meaning.
        self.statout.output(total_density=[M_sum],
                            avg_density=[M_mean],
                            nonzero_density=[M_zeros],
                            max_density=[M_max],
                            min_density=[M_min])
        timing['density_stats'] = time.time() - tic_mstats

        # evaluate test batch if requested
        # NOTE(review): when 'evaluate_test_set' is absent, the default
        # ``self.iteration%5`` is truthy for iterations that are NOT
        # multiples of 5, so the test batch runs on 4 of every 5 iterations
        # (plus the first).  Confirm this is intended.
        if self.iteration <= 1 or self.cparams.get('evaluate_test_set',self.iteration%5):
            tic_test = time.time()
            testbatch = self.cryodata.get_testbatch()

            self.obj.set_data(self.cparams,testbatch)
            testLogP, res_test = self.obj.eval(M=self.M, fM=self.fM,
                                               compute_gradient=False)

            self.outputbatchinfo(testbatch, res_test, testLogP, 'test', 'Test')
            timing['test_batch'] = time.time() - tic_test
        else:
            testLogP, res_test = None, None

        # setup the wrapper for the objective function 
        tic_objsetup = time.time()
        self.obj.set_data(self.cparams,trainbatch)
        self.obj_wrapper.set_objective(self.obj)
        x0 = self.obj_wrapper.set_density(self.M,self.fM)
        evalobj = self.obj_wrapper.eval_obj
        timing['obj_setup'] = time.time() - tic_objsetup

        # Get step size
        self.num_data_evals += trainbatch['N_M']  # at least one gradient
        tic_objstep = time.time()
        trainLogP, dlogP, v, res_train, extra_num_data = self.step.do_step(x0,
                                                         self.cparams,
                                                         self.cryodata,
                                                         evalobj,
                                                         batch=trainbatch)

        # Apply the step
        x = x0 + v
        timing['step'] = time.time() - tic_objstep

        # Convert from parameters to value
        tic_stepfinalize = time.time()
        prevM = n.copy(self.M)
        self.M, self.fM = self.obj_wrapper.convert_parameter(x,comp_real=True)
 
        # Symmetry is enforced only when both 'perfect_symmetry' and
        # 'apply_symmetry' are enabled (both default to True).
        apply_sym = sym is not None and self.cparams.get('perfect_symmetry',True) and self.cparams.get('apply_symmetry',True)
        if apply_sym:
            self.M = sym.apply(self.M)

        # Truncate the density to bounds if they exist (bounds are scaled
        # by cparams['modelscale'] and applied in place)
        if self.cparams['density_lb'] is not None:
            n.maximum(self.M,self.cparams['density_lb']*self.cparams['modelscale'],out=self.M)
        if self.cparams['density_ub'] is not None:
            n.minimum(self.M,self.cparams['density_ub']*self.cparams['modelscale'],out=self.M)

        # Compute net change (previous minus new density, i.e. the negative
        # of the change actually applied)
        self.dM = prevM - self.M

        # Convert to fourier space (may not be required): fM is recomputed
        # whenever M was modified after convert_parameter.
        # NOTE(review): ``!= None`` would be elementwise if a bound were an
        # array; ``is not None`` (as used above) is the safe form.
        if self.fM is None or apply_sym \
           or self.cparams['density_lb'] != None \
           or self.cparams['density_ub'] != None:
            self.fM = density.real_to_fspace(self.M)
        timing['step_finalize'] = time.time() - tic_stepfinalize

        # Compute step statistics
        tic_stepstats = time.time()
        step_size = n.linalg.norm(self.dM)
        grad_size = n.linalg.norm(dlogP)
        M_norm = n.linalg.norm(self.M)

        self.num_data_evals += extra_num_data
        inc_ratio = step_size / M_norm
        self.statout.output(step_size=[step_size],
                            inc_ratio=[inc_ratio],
                            grad_size=[grad_size],
                            norm_density=[M_norm])
        timing['step_stats'] = time.time() - tic_stepstats


        # Update import sampling distributions
        tic_isupdate = time.time()
        self.sampler_R.perform_update()
        self.sampler_I.perform_update()
        self.sampler_S.perform_update()

        self.diagout.output(global_phi_R=self.sampler_R.get_global_dist())
        self.diagout.output(global_phi_I=self.sampler_I.get_global_dist())
        self.diagout.output(global_phi_S=self.sampler_S.get_global_dist())
        timing['is_update'] = time.time() - tic_isupdate
        
        # Output basic diagnostics
        tic_diagnostics = time.time()
        self.diagout.output(iteration=self.iteration, epoch=epoch, cepoch=cepoch)

        # Per-batch histories are (re)initialised whenever their size does
        # not match the current number of batches.
        if self.logpost_history.N_sum != self.cryodata.N_batches:
            self.logpost_history.setup(trainLogP,self.cryodata.N_batches)
        self.logpost_history.set_value(trainbatch['id'],trainLogP)

        if self.like_history.N_sum != self.cryodata.N_batches:
            self.like_history.setup(res_train['L'],self.cryodata.N_batches)
        self.like_history.set_value(trainbatch['id'],res_train['L'])

        self.outputbatchinfo(trainbatch, res_train, trainLogP, 'train', 'Train')

        # Dump parameters here to catch the defaults used in evaluation
        self.diagout.output(params=self.cparams,
                            envelope_mle=self.like_func.get_envelope_mle(),
                            sigma2_mle=self.like_func.get_sigma2_mle(),
                            hostname=socket.gethostname())
        self.statout.output(num_data=[num_data],
                            num_data_evals=[self.num_data_evals],
                            iteration=[self.iteration],
                            epoch=[epoch],
                            cepoch=[cepoch],
                            logp=[self.logpost_history.get_mean()],
                            like=[self.like_history.get_mean()],
                            sigma=[self.like_func.get_rmse()],
                            time=[time.time()])
        timing['diagnostics'] = time.time() - tic_diagnostics

        # Save on every checkpoint iteration (default every 50), when
        # 'save_iteration' is set, or when more than 'save_time' seconds
        # have passed since the last save.
        checkpoint_it = self.iteration % self.cparams.get('checkpoint_frequency',50) == 0 
        save_it = checkpoint_it or self.cparams['save_iteration'] or \
                  time.time() - self.last_save > self.cparams.get('save_time',n.inf)
                  
        if save_it:
            tic_save = time.time()
            self.last_save = tic_save
            # Block until earlier queued writes are done before queueing more.
            if self.io_queue.qsize():
                print "Warning: IO queue has become backlogged with {0} remaining, waiting for it to clear".format(self.io_queue.qsize())
                self.io_queue.join()
            # Copies are queued so the outputs can keep changing while the
            # queued items wait to be processed.
            self.io_queue.put(( 'pkl', self.statout.fname, copy(self.statout.outdict) ))
            self.io_queue.put(( 'pkl', self.diagout.fname, deepcopy(self.diagout.outdict) ))
            self.io_queue.put(( 'pkl', self.likeout.fname, deepcopy(self.likeout.outdict) ))
            self.io_queue.put(( 'mrc', opj(self.outbase,'model.mrc'), \
                                (n.require(self.M,dtype=density.real_t),self.voxel_size) ))
            self.io_queue.put(( 'mrc', opj(self.outbase,'dmodel.mrc'), \
                                (n.require(self.dM,dtype=density.real_t),self.voxel_size) ))

            # On checkpoint iterations, also queue copies of the diagnostic,
            # likelihood and model files under iteration-numbered names.
            if checkpoint_it:
                self.io_queue.put(( 'cp', self.diagout.fname, self.diagout.fname+'-{0:06}'.format(self.iteration) ))
                self.io_queue.put(( 'cp', self.likeout.fname, self.likeout.fname+'-{0:06}'.format(self.iteration) ))
                self.io_queue.put(( 'cp', opj(self.outbase,'model.mrc'), opj(self.outbase,'model-{0:06}.mrc'.format(self.iteration)) ))
            timing['save'] = time.time() - tic_save
                
            
        time_total = time.time() - tic_mini
        self.ostream("  Minibatch Total - %.2f seconds                         Total Runtime - %s" %
                     (time_total, format_timedelta(datetime.now() - self.startdatetime) ))

        
        # Keep working while neither the iteration limit nor the epoch limit
        # has been reached.
        return self.iteration < self.cparams.get('max_iterations',n.inf) and \
               cepoch < self.cparams.get('max_epochs',n.inf)
Esempio n. 18
0
    def __init__(self, expbase, cmdparams=None):
        """cryodata is a CryoData instance. 
        expbase is a path to the base of folder where this experiment's files
        will be stored.  The folder above expbase will also be searched
        for .params files. These will be loaded first."""
        BackgroundWorker.__init__(self)

        # Create a background thread which handles IO
        self.io_queue = Queue()
        self.io_thread = Thread(target=self.ioworker)
        self.io_thread.daemon = True
        self.io_thread.start()

        # General setup ----------------------------------------------------
        self.expbase = expbase
        self.outbase = None

        # Paramter setup ---------------------------------------------------
        # search above expbase for params files
        _,_,filenames = os.walk(opj(expbase,'../')).next()
        self.paramfiles = [opj(opj(expbase,'../'), fname) \
                           for fname in filenames if fname.endswith('.params')]
        # search expbase for params files
        _,_,filenames = os.walk(opj(expbase)).next()
        self.paramfiles += [opj(expbase,fname)  \
                            for fname in filenames if fname.endswith('.params')]
        if 'local.params' in filenames:
            self.paramfiles += [opj(expbase,'local.params')]
        # load parameter files
        self.params = Params(self.paramfiles)
        self.cparams = None
        
        if cmdparams is not None:
            # Set parameter specified on the command line
            for k,v in cmdparams.iteritems():
                self.params[k] = v
                
        # Dataset setup -------------------------------------------------------
        self.imgpath = self.params['inpath']
        psize = self.params['resolution']
        if not isinstance(self.imgpath,list):
            imgstk = MRCImageStack(self.imgpath,psize)
        else:
            imgstk = CombinedImageStack([MRCImageStack(cimgpath,psize) for cimgpath in self.imgpath])

        if self.params.get('float_images',True):
            imgstk.float_images()
        
        self.ctfpath = self.params['ctfpath']
        mscope_params = self.params['microscope_params']
         
        if not isinstance(self.ctfpath,list):
            ctfstk = CTFStack(self.ctfpath,mscope_params)
        else:
            ctfstk = CombinedCTFStack([CTFStack(cctfpath,mscope_params) for cctfpath in self.ctfpath])


        self.cryodata = CryoDataset(imgstk,ctfstk)
        self.cryodata.compute_noise_statistics()
        if self.params.get('window_images',True):
            imgstk.window_images()
        minibatch_size = self.params['minisize']
        testset_size = self.params['test_imgs']
        partition = self.params.get('partition',0)
        num_partitions = self.params.get('num_partitions',1)
        seed = self.params['random_seed']
        if isinstance(partition,str):
            partition = eval(partition)
        if isinstance(num_partitions,str):
            num_partitions = eval(num_partitions)
        if isinstance(seed,str):
            seed = eval(seed)
        self.cryodata.divide_dataset(minibatch_size,testset_size,partition,num_partitions,seed)
        
        self.cryodata.set_datasign(self.params.get('datasign','auto'))
        if self.params.get('normalize_data',True):
            self.cryodata.normalize_dataset()

        self.voxel_size = self.cryodata.pixel_size


        # Iterations setup -------------------------------------------------
        self.iteration = 0 
        self.tic_epoch = None
        self.num_data_evals = 0
        self.eval_params()

        outdir = self.cparams.get('outdir',None)
        if outdir is None:
            if self.cparams.get('num_partitions',1) > 1:
                outdir = 'partition{0}'.format(self.cparams['partition'])
            else:
                outdir = ''
        self.outbase = opj(self.expbase,outdir)
        if not os.path.isdir(self.outbase):
            os.makedirs(self.outbase) 

        # Output setup -----------------------------------------------------
        self.ostream = OutputStream(opj(self.outbase,'stdout'))

        self.ostream(80*"=")
        self.ostream("Experiment: " + expbase + \
                     "    Kernel: " + self.params['kernel'])
        self.ostream("Started on " + socket.gethostname() + \
                     "    At: " + time.strftime('%B %d %Y: %I:%M:%S %p'))
        self.ostream("Git SHA1: " + gitutil.git_get_SHA1())
        self.ostream(80*"=")
        gitutil.git_info_dump(opj(self.outbase, 'gitinfo'))
        self.startdatetime = datetime.now()


        # for diagnostics and parameters
        self.diagout = Output(opj(self.outbase, 'diag'),runningout=False)
        # for stats (per image etc)
        self.statout = Output(opj(self.outbase, 'stat'),runningout=True)
        # for likelihoods of individual images
        self.likeout = Output(opj(self.outbase, 'like'),runningout=False)

        self.img_likes = n.empty(self.cryodata.N_D)
        self.img_likes[:] = n.inf

        # optimization state vars ------------------------------------------
        init_model = self.cparams.get('init_model',None)
        if init_model is not None:
            filename = init_model
            if filename.upper().endswith('.MRC'):
                M = readMRC(filename)
            else:
                with open(filename) as fp:
                    M = cPickle.load(fp)
                    if type(M)==list:
                        M = M[-1]['M'] 
            if M.shape != 3*(self.cryodata.N,):
                M = cryoem.resize_ndarray(M,3*(self.cryodata.N,),axes=(0,1,2))
        else:
            init_seed = self.cparams.get('init_random_seed',0)  + self.cparams.get('partition',0)
            print "Randomly generating initial density (init_random_seed = {0})...".format(init_seed), ; sys.stdout.flush()
            tic = time.time()
            M = cryoem.generate_phantom_density(self.cryodata.N, 0.95*self.cryodata.N/2.0, \
                                                5*self.cryodata.N/128.0, 30, seed=init_seed)
            print "done in {0}s".format(time.time() - tic)

        tic = time.time()
        print "Windowing and aligning initial density...", ; sys.stdout.flush()
        # window the initial density
        wfunc = self.cparams.get('init_window','circle')
        cryoem.window(M,wfunc)

        # Center and orient the initial density
        cryoem.align_density(M)
        print "done in {0:.2f}s".format(time.time() - tic)

        # apply the symmetry operator
        init_sym = get_symmetryop(self.cparams.get('init_symmetry',self.cparams.get('symmetry',None)))
        if init_sym is not None:
            tic = time.time()
            print "Applying symmetry operator...", ; sys.stdout.flush()
            M = init_sym.apply(M)
            print "done in {0:.2f}s".format(time.time() - tic)

        tic = time.time()
        print "Scaling initial model...", ; sys.stdout.flush()
        modelscale = self.cparams.get('modelscale','auto')
        mleDC, _, mleDC_est_std = self.cryodata.get_dc_estimate()
        if modelscale == 'auto':
            # Err on the side of a weaker prior by using a larger value for modelscale
            modelscale = (n.abs(mleDC) + 2*mleDC_est_std)/self.cryodata.N
            print "estimated modelscale = {0:.3g}...".format(modelscale), ; sys.stdout.flush()
            self.params['modelscale'] = modelscale
            self.cparams['modelscale'] = modelscale
        M *= modelscale/M.sum()
        print "done in {0:.2f}s".format(time.time() - tic)
        if mleDC_est_std/n.abs(mleDC) > 0.05:
            print "  WARNING: the DC component estimate has a high relative variance, it may be inaccurate!"
        if ((modelscale*self.cryodata.N - n.abs(mleDC)) / mleDC_est_std) > 3:
            print "  WARNING: the selected modelscale value is more than 3 std devs different than the estimated one.  Be sure this is correct."

        self.M = n.require(M,dtype=density.real_t)
        self.fM = density.real_to_fspace(M)
        self.dM = density.zeros_like(self.M)

        self.step = eval(self.cparams['optim_algo'])
        self.step.setup(self.cparams, self.diagout, self.statout, self.ostream)

        # Objective function setup --------------------------------------------
        param_type = self.cparams.get('parameterization','real')
        cplx_param = param_type in ['complex','complex_coeff','complex_herm_coeff']
        self.like_func = eval_objective(self.cparams['likelihood'])
        self.prior_func = eval_objective(self.cparams['prior'])

        if self.cparams.get('penalty',None) is not None:
            self.penalty_func = eval_objective(self.cparams['penalty'])
            prior_func = SumObjectives(self.prior_func.fspace, \
                                       [self.penalty_func,self.prior_func], None)
        else:
            prior_func = self.prior_func

        self.obj = SumObjectives(cplx_param,
                                 [self.like_func,prior_func], [None,None])
        self.obj.setup(self.cparams, self.diagout, self.statout, self.ostream)
        self.obj.set_dataset(self.cryodata)
        self.obj_wrapper = ObjectiveWrapper(param_type)

        self.last_save = time.time()
        
        self.logpost_history = FiniteRunningSum()
        self.like_history = FiniteRunningSum()

        # Importance Samplers -------------------------------------------------
        self.is_sym = get_symmetryop(self.cparams.get('is_symmetry',self.cparams.get('symmetry',None)))
        self.sampler_R = FixedFisherImportanceSampler('_R',self.is_sym)
        self.sampler_I = FixedFisherImportanceSampler('_I')
        self.sampler_S = FixedGaussianImportanceSampler('_S')
        self.like_func.set_samplers(sampler_R=self.sampler_R,sampler_I=self.sampler_I,sampler_S=self.sampler_S)
Esempio n. 19
0
    def updatevis(self, levels=[0.2,0.5,0.8]):
        if self.M is None or self.diag is None or self.stat is None:
            return

        cdiag = self.diag
        cparams = cdiag['params']
        sym = get_symmetryop(cparams.get('symmetry',None))
        quad_sym = sym if cparams.get('perfect_symmetry',True) else None

        resolution = cparams['voxel_size']

        name = cparams['name']
        maxfreq = cparams['max_frequency']
        N = self.M.shape[0]
        rad_cutoff = cparams.get('rad_cutoff', 1.0)
        rad = min(rad_cutoff,maxfreq*2.0*resolution)

        # Show objective function
        self.show_objective_plot(self.get_figure('stats'))
        
        # Show information about noise and error
        self.show_error_plot(self.get_figure('error'))
        self.show_noise_plot(self.get_figure('noise'))
        
        # Plot the envelope function if we have the info
        if 'envelope_mle' in cdiag:
            self.show_envelope_plot(self.get_figure('envelope'))
        else:
            self.close_figure('envelope')

        if sym is None:
            assert quad_sym is None
            alignedM,R = c.align_density(self.M)
            if self.show_grad:
                aligneddM = c.rotate_density(self.dM,R)
            else:
                aligneddM = None
        else:
            alignedM, aligneddM = self.M, self.dM
            R = n.identity(3)

        self.alignedM,self.aligneddM,self.alignedR = alignedM,aligneddM,R
        self.fM = density.real_to_fspace(self.M)

        self.figMslices.set_data(alignedM)

        glbl_phi_R = n.array([cdiag['global_phi_R']]).ravel()
        if len(glbl_phi_R) == 1:
            glbl_phi_R = None
        glbl_phi_I = cdiag['global_phi_I']
        glbl_phi_S = cdiag['global_phi_S']

        # Get direction quadrature
        quad_R = quadrature.quad_schemes[('dir',cparams.get('quad_type_R','sk97'))]
        quad_degree_R = cparams.get('quad_degree_R','auto')
        if quad_degree_R == 'auto':
            usFactor_R = cparams.get('quad_undersample_R',
                                     cparams.get('quad_undersample',1.0))
            quad_degree_R,_ = quad_R.compute_degree(N,rad,usFactor_R)
        origlebDirs,_ = quad_R.get_quad_points(quad_degree_R,quad_sym)
        lebDirs = n.dot(origlebDirs,R)

        # Get shift quadrature
        quad_S = quadrature.quad_schemes[('shift',cparams.get('quad_type_S','hermite'))]
        quad_degree_S = cparams.get('quad_degree_S','auto')
        if quad_degree_S == 'auto':
            usFactor_S = cparams.get('quad_undersample_S',
                                     cparams.get('quad_undersample',1.0))
            quad_degree_S = quad_S.get_degree(N,rad,
                                              cparams['quad_shiftsigma']/resolution,
                                              cparams['quad_shiftextent']/resolution,
                                              usFactor_S)
        pts_S,_ = quad_S.get_quad_points(quad_degree_S,
                                         cparams['quad_shiftsigma']/resolution,
                                         cparams['quad_shiftextent']/resolution,
                                         cparams.get('quad_shifttrunc','circ'))
        vmax_R = 5.0/len(glbl_phi_R)
        vmax_S = 5.0/len(glbl_phi_S)

        # Density visualization
        mlab.figure(self.fig1)
        mlab.clf()
        self.curr_contours = plot_density(alignedM, self.contours, levels)
#         dispPhiR = glbl_phi_R
#         dispDirs = lebDirs
#         plot_directions(alignedM.shape[0]*dispDirs + alignedM.shape[0]/2.0,
#                         dispPhiR,
#                         0, vmax_R)
        mlab.view(focalpoint=[alignedM.shape[0]/2.0,alignedM.shape[0]/2.0,alignedM.shape[0]/2.0],distance=1.5*alignedM.shape[0])


        if glbl_phi_R is not None:
            plt.figure(self.get_figure('global_is_dists').number)
            plt.clf()
            plot_importance_dists(name,lebDirs,pts_S*resolution,glbl_phi_R,glbl_phi_I,glbl_phi_S,vmax_R,vmax_S)

        if self.show_grad:
            # Statistics of dM
            self.figdMslices.set_data(aligneddM)

            plt.figure(self.get_figure('step_stats').number)
            plt.clf()
            plt.suptitle(name + ' Step Statistics')

            plt.subplot(1,2,1)
            plt.hist(self.dM.reshape((-1,)),bins=0.5*self.dM.shape[0],log=True)
            plt.title('Voxel Histogram')

            (fs,raps) = rot_power_spectra(self.dM,resolution=resolution)
            plt.subplot(1,2,2)
            plt.plot(fs/(N/2.0)/(2.0*resolution),raps,label='RAPS')
            plt.plot((rad/(2.0*resolution))*n.ones((2,)), 
                     n.array([raps[raps > 0].min(),raps.max()]))
            plt.yscale('log')
            plt.title('RAPS Step')


        if not self.extra_plots:
            self.close_figure('density_stats')
            return

        # Statistics of M
        self.show_density_plot(self.get_figure('density_stats'))
Esempio n. 20
0
def genphantomdata(N_D, phantompath, ctfparfile):
    mscope_params = {
        'akv': 200,
        'wgh': 0.07,
        'cs': 2.0,
        'psize': 2.8,
        'bfactor': 500.0
    }
    N = 128
    rad = 0.95
    shift_sigma = 3.0
    sigma_noise = 25.0
    M_totalmass = 80000
    kernel = 'lanczos'
    ksize = 6

    premult = cryoops.compute_premultiplier(N, kernel, ksize)

    tic = time.time()

    N_D = int(N_D)
    N = int(N)
    rad = float(rad)
    psize = mscope_params['psize']
    bfactor = mscope_params['bfactor']
    shift_sigma = float(shift_sigma)
    sigma_noise = float(sigma_noise)
    M_totalmass = float(M_totalmass)

    srcctf_stack = CTFStack(ctfparfile, mscope_params)
    genctf_stack = GeneratedCTFStack(
        mscope_params, parfields=['PHI', 'THETA', 'PSI', 'SHX', 'SHY'])

    TtoF = sincint.gentrunctofull(N=N, rad=rad)
    Cmap = n.sort(
        n.random.random_integers(0,
                                 srcctf_stack.get_num_ctfs() - 1, N_D))

    M = mrc.readMRC(phantompath)
    cryoem.window(M, 'circle')
    M[M < 0] = 0
    if M_totalmass is not None:
        M *= M_totalmass / M.sum()

    V = density.real_to_fspace(
        premult.reshape((1, 1, -1)) * premult.reshape(
            (1, -1, 1)) * premult.reshape((-1, 1, 1)) * M)

    print "Generating data..."
    sys.stdout.flush()
    imgdata = n.empty((N_D, N, N), dtype=density.real_t)

    pardata = {'R': [], 't': []}

    prevctfI = None
    for i, srcctfI in enumerate(Cmap):
        ellapse_time = time.time() - tic
        remain_time = float(N_D - i) * ellapse_time / max(i, 1)
        print "\r%.2f Percent.. (Elapsed: %s, Remaining: %s)      " % (
            i / float(N_D) * 100.0, format_timedelta(ellapse_time),
            format_timedelta(remain_time)),
        sys.stdout.flush()

        # Get the CTF for this image
        cCTF = srcctf_stack.get_ctf(srcctfI)
        if prevctfI != srcctfI:
            genctfI = genctf_stack.add_ctf(cCTF)
            C = cCTF.dense_ctf(N, psize, bfactor).reshape((N**2, ))
            prevctfI = srcctfI

        # Randomly generate the viewing direction/shift
        pt = n.random.randn(3)
        pt /= n.linalg.norm(pt)
        psi = 2 * n.pi * n.random.rand()
        EA = geom.genEA(pt)[0]
        EA[2] = psi
        shift = n.random.randn(2) * shift_sigma

        R = geom.rotmat3D_EA(*EA)[:, 0:2]
        slop = cryoops.compute_projection_matrix([R], N, kernel, ksize, rad,
                                                 'rots')
        S = cryoops.compute_shift_phases(shift.reshape((1, 2)), N, rad)[0]

        D = slop.dot(V.reshape((-1, )))
        D *= S

        imgdata[i] = density.fspace_to_real((C * TtoF.dot(D)).reshape(
            (N, N))) + n.require(n.random.randn(N, N) * sigma_noise,
                                 dtype=density.real_t)

        genctf_stack.add_img(genctfI,
                             PHI=EA[0] * 180.0 / n.pi,
                             THETA=EA[1] * 180.0 / n.pi,
                             PSI=EA[2] * 180.0 / n.pi,
                             SHX=shift[0],
                             SHY=shift[1])

        pardata['R'].append(R)
        pardata['t'].append(shift)

    print "\rDone in ", time.time() - tic, " seconds."
    return imgdata, genctf_stack, pardata, mscope_params
Esempio n. 21
0
# Script fragment: crop a previously loaded density M, save it, and sample
# Fourier slices of it on a HEALPix orientation grid.
# NOTE(review): M must already be defined before this fragment runs.
M = M[:124, :124, :124]

mrc.writeMRC('./particle/EMD-6044-cropped.mrc', M, psz=3.0)

N = M.shape[0]
print(M.shape)
rad = 1
kernel = 'lanczos'
ksize = 4

# Of these three results, only trunc_xy is used in the lines shown here.
xy, trunc_xy, truncmask = geometry.gencoords(N, 2, rad, True)
# premult = cryoops.compute_premultiplier(N, kernel='lanczos', kernsize=6)
premult = cryoops.compute_premultiplier(N, kernel, ksize)
TtoF = sincint.gentrunctofull(N=N, rad=rad)

fM = density.real_to_fspace(M)
# Fourier transform of M after multiplying by the premultiplier along each
# of the three axes.
# NOTE(review): prefM is not used below; the slices are taken from the
# un-premultiplied fM. Confirm whether that is intended.
prefM = density.real_to_fspace(
    premult.reshape((1, 1, -1)) * premult.reshape(
        (1, -1, 1)) * premult.reshape((-1, 1, 1)) * M)

# Orientation grid (presumably psi_step=360 means a single in-plane angle --
# TODO confirm), converted to the first two columns of each rotation matrix.
EAs_grid = healpix.gen_EAs_grid(nside=2, psi_step=360)
Rs = [geometry.rotmat3D_EA(*EA)[:, 0:2] for EA in EAs_grid]
slice_ops = cryoops.compute_projection_matrix(Rs,
                                              N,
                                              kern='lanczos',
                                              kernsize=ksize,
                                              rad=rad,
                                              projdirtype='rots')

# One row per orientation, one column per truncated 2D Fourier coordinate.
slices_sampled = cryoem.getslices(fM, slice_ops).reshape(
    (EAs_grid.shape[0], trunc_xy.shape[0]))
Esempio n. 22
0
def sagd_init(data_dir, model_file, use_angular_correlation=False):
    """Load a dataset and build the initial model and config for SAGD.

    data_dir: directory containing 'imgdata.mrc' and 'defocus.txt'.
    model_file: optional MRC density to start from.  When None, a random
        phantom density is generated with seed 0.
    use_angular_correlation: copied into the returned cparams.

    Returns (cryodata, (M, fM), cparams):
      M is the squared magnitude of the Fourier transform of the density,
        after the density is rescaled to a total mass of 5000 and
        premultiplied.
      fM is M with every frequency of radius <= 0.015 set to zero.
      cparams merges the objective settings with the importance-sampling
        settings.
    """
    data_params = {
        'dataset_name': "1AON",
        'inpath': os.path.join(data_dir, 'imgdata.mrc'),
        'ctfpath': os.path.join(data_dir, 'defocus.txt'),
        'microscope_params': {
            'akv': 200,
            'wgh': 0.07,
            'cs': 2.0
        },
        'resolution': 2.8,
        'sigma': 'noise_std',
        'sigma_out': 'data_std',
        'minisize': 20,
        'test_imgs': 20,
        'partition': 0,
        # NOTE(review): other setup code in this file defaults
        # num_partitions to 1; confirm that 0 is handled by the loader.
        'num_partitions': 0,
        'random_seed': 1,
        # 'symmetry': 'C7'
    }
    # Setup dataset
    print("Loading dataset %s" % data_dir)
    cryodata, _ = dataset_loading_test(data_params)
    # mleDC, _, mleDC_est_std = cryodata.get_dc_estimate()
    # modelscale = (np.abs(mleDC) + 2*mleDC_est_std)/cryodata.N
    modelscale = 1.0

    if model_file is not None:
        print("Loading density map %s" % model_file)
        M = readMRC(model_file)
    else:
        print("Generating random initial density map ...")
        M = cryoem.generate_phantom_density(cryodata.N, 0.95 * cryodata.N / 2.0, \
                                            5 * cryodata.N / 128.0, 30, seed=0)
        # Superseded by the total-mass rescaling below; kept so the result
        # stays bit-for-bit unchanged.
        M *= modelscale / M.sum()
    slice_interp = {
        'kern': 'lanczos',
        'kernsize': 4,
        'zeropad': 0,
        'dopremult': True
    }
    # fM = SimpleKernel.get_fft(M, slice_interp)

    M_totalmass = 5000
    M *= M_totalmass / M.sum()
    N = M.shape[0]
    kernel = 'lanczos'
    ksize = 6
    premult = cryoops.compute_premultiplier(N, kernel, ksize)
    V = density.real_to_fspace(
        premult.reshape((1, 1, -1)) * premult.reshape(
            (1, -1, 1)) * premult.reshape((-1, 1, 1)) * M)
    # Keep only the squared magnitude of the Fourier transform.
    M = V.real**2 + V.imag**2

    freqs_3d = geometry.gencoords_base(N, 3) / (N * data_params['resolution'])
    freq_radius_3d = np.sqrt((freqs_3d**2).sum(axis=1))
    # Boolean "outside the central low-frequency region" mask, converted
    # directly to real_t (0.0 / 1.0).  The former np.float_ alias was removed
    # in NumPy 2.0.
    mask_3d_outlier = np.require((freq_radius_3d > 0.015).reshape((N, N, N)),
                                 dtype=density.real_t)
    fM = M * mask_3d_outlier

    cparams = {
        'use_angular_correlation': use_angular_correlation,
        'likelihood': 'UnknownRSLikelihood()',
        'kernel': 'multicpu',
        'prior_name': "'Null'",
        'sparsity_lambda': 0.9,
        'prior': 'NullPrior()',

        # 'prior_name': "'CAR'",
        # 'prior': 'CARPrior()',
        # 'car_type': 'gauss0.5',
        # 'car_tau': 75.0,
        'iteration': 0,
        'pixel_size': cryodata.pixel_size,
        'max_frequency': 0.02,
        'num_batches': cryodata.N_batches,
        'interp_kernel_R': 'lanczos',
        'interp_kernel_size_R': 4,
        'interp_zeropad_R': 0.0,
        'interp_premult_R': True,
        'interp_kernel_I': 'lanczos',
        'interp_kernel_size_I': 8,
        'interp_zeropad_I': 0.0,  # 1.0,
        'interp_premult_I': True,
        'sigma': cryodata.noise_var,
        'modelscale': modelscale,
        # 'symmetry': 'C7'
    }

    is_params = {
        # importance sampling
        # Ignore the first 50 iterations entirely
        'is_prior_prob': max(0.05,
                             2**(-0.005 * max(0, cparams['iteration'] - 50))),
        'is_temperature': max(1.0,
                              2**(750.0 / max(1, cparams['iteration'] - 50))),
        'is_ess_scale': 10,
        'is_fisher_chirality_flip': cparams['iteration'] < 2500,
        'is_on_R': True,
        'is_global_prob_R': 0.9,
        'is_on_I': True,
        'is_global_prob_I': 1e-10,
        'is_on_S': True,
        'is_global_prob_S': 0.9,
        'is_gaussian_sigmascale_S': 0.67,
    }

    cparams.update(is_params)
    return cryodata, (M, fM), cparams
Esempio n. 23
0
# Script fragment: zero-pad the density M, take the squared magnitude of the
# premultiplied Fourier transform, crop its central block, and load one image.
# NOTE(review): zeropad, N and M must be defined before this fragment.
# N is used here before it is reassigned from M.shape[0] below, so M must
# already be N^3 for the slicer to fit.
zeropad_size = int(zeropad * (N / 2))
zp_N = zeropad_size * 2 + N
zp_M_shape = (zp_N,) * 3
ZP_M = np.zeros(zp_M_shape, dtype=density.real_t)
# Same slice on every axis: it selects the centred N^3 region of the padded cube.
zp_M_slicer = (slice( zeropad_size, (N + zeropad_size) ),) * 3

M_totalmass = 5000
M *= M_totalmass / M.sum()

ZP_M[zp_M_slicer] = M

N = M.shape[0]
kernel = 'lanczos'
ksize = 6
premult = cryoops.compute_premultiplier(zp_N, kernel, ksize)
V = density.real_to_fspace(premult.reshape((1, 1, -1)) * premult.reshape((1, -1, 1)) * premult.reshape((-1, 1, 1)) * ZP_M)
# V = density.real_to_fspace(ZP_M)
# Squared magnitude of the Fourier transform.
ZP_fM = V.real ** 2 + V.imag ** 2

# Central N^3 block of the padded power spectrum.
fM = ZP_fM[zp_M_slicer]

# mask_3d_outlier = geometry.gen_dense_beamstop_mask(N, 3, 0.015, psize=2.8)
# fM *= mask_3d_outlier

# fM = mrc.readMRC('particle/1AON_fM_totalmass_5000.mrc') * mask_3d_outlier

# Presumably the last two arguments select which images to read -- TODO confirm.
imgdata = mrc.readMRCimgs('data/1AON_xfel_5000_totalmass_05000/imgdata.mrc', 420, 1)
curr_img = imgdata[:, :, 0]
# Place the image in the centre of a zero-padded zp_N x zp_N canvas.
zp_img = np.zeros((zp_N,)*2, dtype=density.real_t)
zp_img[zp_M_slicer[0:2]] = curr_img
# curr_img = zp_img
Esempio n. 24
0
    def eval(self, M=None, fM=None, compute_gradient=True, all_grads=False,**kwargs):
        """Evaluate the weighted sum of all component objectives.

        Each objective in self.objs is evaluated on the model in real space
        (M) and/or Fourier space (fM).  Whichever representation an
        objective needs but the caller did not pass is derived from the
        other.  A weight in self.ws that is neither None nor 1 scales that
        objective's log-probability and gradient.

        Returns (logP, dlogP, outputs) when compute_gradient is set, and
        (logP, outputs) otherwise.  outputs merges every objective's outputs
        and adds 'all_logPs', the per-objective weighted values.  With
        all_grads it also adds 'all_dlogPs', the per-objective gradients
        expressed in this object's space.
        """
        has_fspace_obj = any([o.fspace for o in self.objs])
        has_rspace_obj = any([not o.fspace for o in self.objs])

        # Resolve the Fourier-space model (derived from M only if needed).
        if fM is not None:
            size = fM.shape[0]
        elif has_fspace_obj:
            assert M is not None, 'M or fM must be set!'
            size = M.shape[0]
            fM = density.real_to_fspace(M)
        else:
            size = None

        # Resolve the real-space model; its size must agree with fM's.
        if M is not None:
            assert size is None or size == M.shape[0]
            size = M.shape[0]
        elif has_rspace_obj:
            assert fM is not None, 'M or fM must be set!'
            size = fM.shape[0]
            M = density.fspace_to_real(fM)

        assert size is not None

        total_logP = 0
        per_obj_logPs = []
        if compute_gradient:
            if all_grads:
                # Every gradient is converted to this object's own space.
                total_grad = density.zeros_like(fM) if self.fspace else density.zeros_like(M)
                per_obj_dlogPs = []
            else:
                # Accumulate real- and Fourier-space gradients separately and
                # combine them once after the loop.
                if not self.fspace or has_rspace_obj:
                    grad_real = density.zeros_like(M)
                if self.fspace or has_fspace_obj:
                    grad_fourier = density.zeros_like(fM)

        outputs = {}
        for weight, obj in zip(self.ws, self.objs):
            result = obj.eval(M = M, fM = fM,
                              compute_gradient = compute_gradient,
                              **kwargs)
            if compute_gradient:
                clogP, cdlogP, coutputs = result
            else:
                clogP, coutputs = result

            if weight is not None and weight != 1:
                clogP *= weight
                if compute_gradient:
                    cdlogP *= weight

            if compute_gradient:
                if all_grads:
                    if obj.fspace == self.fspace:
                        converted = cdlogP
                    elif self.fspace:
                        converted = density.real_to_fspace(cdlogP)
                    else:
                        converted = density.fspace_to_real(cdlogP)
                    per_obj_dlogPs.append(converted)
                    total_grad += converted
                elif obj.fspace:
                    grad_fourier += cdlogP
                else:
                    grad_real += cdlogP

            total_logP += clogP
            per_obj_logPs.append(clogP)
            outputs.update(coutputs)

        if compute_gradient and not all_grads:
            if self.fspace:
                total_grad = grad_fourier
                if has_rspace_obj:
                    total_grad += density.real_to_fspace(grad_real)
            else:
                total_grad = grad_real
                if has_fspace_obj:
                    total_grad += density.fspace_to_real(grad_fourier)

        outputs['all_logPs'] = per_obj_logPs
        if compute_gradient and all_grads:
            outputs['all_dlogPs'] = per_obj_dlogPs

        if not compute_gradient:
            return total_logP, outputs
        return total_logP, total_grad, outputs
Esempio n. 25
0
def genphantomdata(N_D, phantompath, ctfparfile):
    """Generate a synthetic particle-image stack from a phantom density.

    Every image is a projection of the phantom at a random viewing direction,
    random in-plane angle and random 2-D shift. The projection is multiplied
    by a CTF taken from ``ctfparfile`` and has Gaussian noise added to it.

    Parameters
    ----------
    N_D : int-like
        Number of images to generate (converted with ``int``).
    phantompath : str
        Path of the MRC file holding the phantom density (read with
        ``mrc.readMRC``).
    ctfparfile : str
        Parameter file handed to ``CTFStack``; CTFs are sampled from it.

    Returns
    -------
    imgdata : ndarray of shape (N_D, N, N), dtype ``density.real_t``
        The generated noisy images.
    genctf_stack : GeneratedCTFStack
        Records the CTF used for each image together with its Euler angles
        (in degrees, fields PHI/THETA/PSI) and shift (SHX/SHY).
    pardata : dict
        ``{'R': [...], 't': [...]}``: the rotation matrix (first two columns
        of ``geometry.rotmat3D_EA``) and the 2-vector shift of every image.
    mscope_params : dict
        The microscope parameters used for the CTFs.
    """
    mscope_params = {'akv': 200, 'wgh': 0.07,
                     'cs': 2.0, 'psize': 3.0, 'bfactor': 500.0}

    M = mrc.readMRC(phantompath)

    N_D = int(N_D)
    N = int(M.shape[0])
    rad = 0.95
    shift_sigma = 3.0
    sigma_noise = 25.0
    M_totalmass = 80000.0
    kernel = 'lanczos'
    ksize = 6
    psize = mscope_params['psize']
    bfactor = mscope_params['bfactor']

    premult = cryoops.compute_premultiplier(N, kernel, ksize)

    tic = time.time()

    srcctf_stack = CTFStack(ctfparfile, mscope_params)
    genctf_stack = GeneratedCTFStack(
        mscope_params, parfields=['PHI', 'THETA', 'PSI', 'SHX', 'SHY'])

    TtoF = sincint.gentrunctofull(N=N, rad=rad)
    # One source-CTF index per image. Sorting groups equal indices, so the
    # dense CTF below is only rebuilt when the index changes.
    # np.random.random_integers (inclusive high) is deprecated; randint's high
    # is exclusive, which is why there is no "- 1" here. The random stream is
    # unchanged because random_integers delegated to randint(low, high + 1).
    Cmap = np.sort(np.random.randint(0, srcctf_stack.get_num_ctfs(), N_D))

    # cryoem.window presumably masks M in place (its result is not assigned);
    # negatives are then clamped and the total mass normalised.
    cryoem.window(M, 'circle')
    M[M < 0] = 0
    M *= M_totalmass / M.sum()

    # Apply the 1-D premultiplier separably along all three axes before
    # transforming the volume to Fourier space.
    V = density.real_to_fspace(
        premult.reshape((1, 1, -1)) * premult.reshape((1, -1, 1))
        * premult.reshape((-1, 1, 1)) * M)

    print("Generating data...")
    sys.stdout.flush()
    imgdata = np.empty((N_D, N, N), dtype=density.real_t)

    pardata = {'R': [], 't': []}

    prevctfI = None
    for i, srcctfI in enumerate(Cmap):
        ellapse_time = time.time() - tic
        remain_time = float(N_D - i) * ellapse_time / max(i, 1)
        # end="" so the leading "\r" rewrites one console line instead of
        # printing a new line for every image.
        print("\r%.2f Percent.. (Elapsed: %s, Remaining: %s)"
              % (i / float(N_D) * 100.0, format_timedelta(ellapse_time),
                 format_timedelta(remain_time)),
              end="")
        sys.stdout.flush()

        # Get the CTF for this image; only register / rebuild it when the
        # source index changes (Cmap is sorted).
        cCTF = srcctf_stack.get_ctf(srcctfI)
        if prevctfI != srcctfI:
            genctfI = genctf_stack.add_ctf(cCTF)
            C = cCTF.dense_ctf(N, psize, bfactor).reshape((N**2,))
            prevctfI = srcctfI

        # Random viewing direction (normalised Gaussian vector), random
        # in-plane angle, and Gaussian shift.
        pt = np.random.randn(3)
        pt /= np.linalg.norm(pt)
        psi = 2 * np.pi * np.random.rand()
        EA = geometry.genEA(pt)[0]
        EA[2] = psi
        shift = np.random.randn(2) * shift_sigma

        R = geometry.rotmat3D_EA(*EA)[:, 0:2]
        slop = cryoops.compute_projection_matrix(
            [R], N, kernel, ksize, rad, 'rots')
        S = cryoops.compute_shift_phases(shift.reshape((1, 2)), N, rad)[0]

        D = slop.dot(V.reshape((-1,)))
        D *= S

        imgdata[i] = density.fspace_to_real(
            (C * TtoF.dot(D)).reshape((N, N))) + np.require(
                np.random.randn(N, N) * sigma_noise, dtype=density.real_t)

        genctf_stack.add_img(genctfI,
                             PHI=EA[0] * 180.0 / np.pi,
                             THETA=EA[1] * 180.0 / np.pi,
                             PSI=EA[2] * 180.0 / np.pi,
                             SHX=shift[0], SHY=shift[1])

        pardata['R'].append(R)
        pardata['t'].append(shift)

    print("\rDone in ", time.time() - tic, " seconds.")
    return imgdata, genctf_stack, pardata, mscope_params
Esempio n. 26
0
    def updatevis(self, levels=None):
        """Redraw the diagnostic figures for the current model state.

        Returns immediately unless ``self.M``, ``self.diag`` and ``self.stat``
        are all set. Otherwise it redraws:
        - the objective, error and noise plots;
        - the envelope plot, which is closed when ``'envelope_mle'`` is
          absent from the diagnostics;
        - the density slices and the 3-D contour view;
        - the global importance-sampling distributions;
        - the step statistics, when ``self.show_grad`` is set;
        - the density statistics, when ``self.extra_plots`` is set.

        Side effects: sets ``self.alignedM``, ``self.aligneddM``,
        ``self.alignedR``, ``self.fM`` and ``self.curr_contours``.

        Parameters
        ----------
        levels : sequence of float, optional
            Contour levels forwarded to ``plot_density``. ``None`` (the
            default) means ``[0.2, 0.5, 0.8]``. A fresh list is built on every
            call instead of sharing a mutable default.
        """
        if levels is None:
            levels = [0.2, 0.5, 0.8]
        if self.M is None or self.diag is None or self.stat is None:
            return

        cdiag = self.diag
        cparams = cdiag['params']
        sym = get_symmetryop(cparams.get('symmetry', None))
        quad_sym = sym if cparams.get('perfect_symmetry', True) else None

        resolution = cparams['voxel_size']

        name = cparams['name']
        maxfreq = cparams['max_frequency']
        N = self.M.shape[0]
        rad_cutoff = cparams.get('rad_cutoff', 1.0)
        rad = min(rad_cutoff, maxfreq * 2.0 * resolution)

        # Objective function, error and noise plots.
        self.show_objective_plot(self.get_figure('stats'))
        self.show_error_plot(self.get_figure('error'))
        self.show_noise_plot(self.get_figure('noise'))

        # Plot the envelope function only if the diagnostics carry it.
        if 'envelope_mle' in cdiag:
            self.show_envelope_plot(self.get_figure('envelope'))
        else:
            self.close_figure('envelope')

        if sym is None:
            assert quad_sym is None
            alignedM, R = cryoem.align_density(self.M)
            if self.show_grad:
                aligneddM = cryoem.rotate_density(self.dM, R)
            else:
                aligneddM = None
        else:
            # With a symmetry operator the density is shown unaligned.
            alignedM, aligneddM = self.M, self.dM
            R = np.identity(3)

        self.alignedM, self.aligneddM, self.alignedR = alignedM, aligneddM, R
        self.fM = density.real_to_fspace(self.M)

        self.figMslices.set_data(alignedM)

        glbl_phi_R = np.array([cdiag['global_phi_R']]).ravel()
        if len(glbl_phi_R) == 1:
            glbl_phi_R = None
        glbl_phi_I = cdiag['global_phi_I']
        have_shifts = 'global_phi_S' in cdiag
        glbl_phi_S = cdiag['global_phi_S'] if have_shifts else None

        # Get direction quadrature
        quad_R = quadrature.quad_schemes[('dir', cparams.get('quad_type_R', 'sk97'))]
        quad_degree_R = cparams.get('quad_degree_R', 'auto')
        if quad_degree_R == 'auto':
            usFactor_R = cparams.get('quad_undersample_R',
                                     cparams.get('quad_undersample', 1.0))
            quad_degree_R, _ = quad_R.compute_degree(N, rad, usFactor_R)
        origlebDirs, _ = quad_R.get_quad_points(quad_degree_R, quad_sym)
        lebDirs = np.dot(origlebDirs, R)
        # glbl_phi_R may have been replaced by None above, and len(None)
        # raises. vmax_R is only consumed when glbl_phi_R is not None.
        vmax_R = 5.0 / len(glbl_phi_R) if glbl_phi_R is not None else None

        # Get shift quadrature
        if have_shifts:
            quad_S = quadrature.quad_schemes[('shift', cparams.get('quad_type_S', 'hermite'))]
            quad_degree_S = cparams.get('quad_degree_S', 'auto')
            if quad_degree_S == 'auto':
                usFactor_S = cparams.get('quad_undersample_S',
                                         cparams.get('quad_undersample', 1.0))
                quad_degree_S = quad_S.get_degree(N, rad,
                                                  cparams['quad_shiftsigma'] / resolution,
                                                  cparams['quad_shiftextent'] / resolution,
                                                  usFactor_S)
            pts_S, _ = quad_S.get_quad_points(quad_degree_S,
                                              cparams['quad_shiftsigma'] / resolution,
                                              cparams['quad_shiftextent'] / resolution,
                                              cparams.get('quad_shifttrunc', 'circ'))
            vmax_S = 5.0 / len(glbl_phi_S)
        else:
            pts_S = np.zeros_like([0])
            vmax_S = None

        # Density visualization
        mlab.figure(self.fig1)
        mlab.clf()
        self.curr_contours = plot_density(alignedM, self.contours, levels)
        center = alignedM.shape[0] / 2.0
        mlab.view(focalpoint=[center, center, center],
                  distance=1.5 * alignedM.shape[0])

        if glbl_phi_R is not None:
            plt.figure(self.get_figure('global_is_dists').number)
            plt.clf()
            plot_importance_dists(name, lebDirs, pts_S * resolution, glbl_phi_R,
                                  glbl_phi_I, glbl_phi_S, vmax_R, vmax_S)

        if self.show_grad:
            # Statistics of dM
            self.figdMslices.set_data(aligneddM)

            plt.figure(self.get_figure('step_stats').number)
            plt.clf()
            plt.suptitle(name + ' Step Statistics')

            plt.subplot(1, 2, 1)
            # The bin count must be an integer: NumPy rejects a float count.
            plt.hist(self.dM.reshape((-1,)),
                     bins=max(1, int(0.5 * self.dM.shape[0])), log=True)
            plt.title('Voxel Histogram')

            (fs, raps) = rot_power_spectra(self.dM, resolution=resolution)
            plt.subplot(1, 2, 2)
            plt.plot(fs / (N / 2.0) / (2.0 * resolution), raps, label='RAPS')
            plt.plot((rad / (2.0 * resolution)) * np.ones((2,)),
                     np.array([raps[raps > 0].min(), raps.max()]))
            plt.yscale('log')
            plt.title('RAPS Step')

        if not self.extra_plots:
            self.close_figure('density_stats')
            return

        # Statistics of M
        self.show_density_plot(self.get_figure('density_stats'))
Esempio n. 27
0
    def eval(self,
             M=None,
             fM=None,
             compute_gradient=True,
             all_grads=False,
             **kwargs):
        """Evaluate the weighted sum of the component objectives.

        Every ``obj`` in ``self.objs`` is evaluated on the same density. Its
        log-probability (and gradient) is scaled by the matching weight in
        ``self.ws`` (``None`` or ``1`` means unweighted), and the results are
        summed.

        Parameters
        ----------
        M, fM : array, optional
            The density in real space and/or in Fourier space. At least one
            must be given. A missing representation is derived with
            ``density.real_to_fspace`` / ``density.fspace_to_real`` whenever a
            component objective or the returned gradient needs it.
        compute_gradient : bool
            Whether to compute and return the gradient.
        all_grads : bool
            With ``compute_gradient``, additionally return every (weighted)
            component gradient, converted to this objective's space, in
            ``outputs['all_dlogPs']``.
        **kwargs
            Forwarded unchanged to each component's ``eval``.

        Returns
        -------
        ``(logP, dlogP, outputs)`` if ``compute_gradient`` else
        ``(logP, outputs)``.
        - ``dlogP`` is in Fourier space when ``self.fspace`` is true,
          otherwise in real space.
        - ``outputs`` merges every component's outputs.
        - It also holds ``'all_logPs'`` (the weighted per-component logP).
        - With ``all_grads`` it also holds ``'all_dlogPs'``.
        """
        anyfspace = any(obj.fspace for obj in self.objs)
        anyrspace = any(not obj.fspace for obj in self.objs)

        # The gradient accumulator is allocated in this objective's own space,
        # so that representation is required even if no component uses it.
        # Previously it was only derived for components, which made
        # density.zeros_like() receive None in that case.
        need_fM = anyfspace or (compute_gradient and self.fspace)
        need_M = anyrspace or (compute_gradient and not self.fspace)

        N = None
        if fM is None and need_fM:
            assert M is not None, 'M or fM must be set!'
            N = M.shape[0]
            fM = density.real_to_fspace(M)
        elif fM is not None:
            N = fM.shape[0]

        if M is None and need_M:
            assert fM is not None, 'M or fM must be set!'
            N = fM.shape[0]
            M = density.fspace_to_real(fM)
        elif M is not None:
            assert N is None or N == M.shape[0]
            N = M.shape[0]

        assert N is not None

        logP = 0
        logPs = []
        if compute_gradient:
            if all_grads:
                dlogP = density.zeros_like(
                    fM) if self.fspace else density.zeros_like(M)
                dlogPs = []
            else:
                # Separate accumulators per space; merged after the loop.
                if (not self.fspace) or anyrspace:
                    dlogPdM = density.zeros_like(M)
                if self.fspace or anyfspace:
                    dlogPdfM = density.zeros_like(fM)
        outputs = {}
        for w, obj in zip(self.ws, self.objs):
            result = obj.eval(M=M,
                              fM=fM,
                              compute_gradient=compute_gradient,
                              **kwargs)
            if compute_gradient:
                clogP, cdlogP, coutputs = result
            else:
                clogP, coutputs = result

            if w is not None and w != 1:
                # Out-of-place so an array owned by the component objective
                # is never rescaled behind its back.
                clogP = clogP * w
                if compute_gradient:
                    cdlogP = cdlogP * w

            if compute_gradient:
                if all_grads:
                    if obj.fspace == self.fspace:
                        cgrad = cdlogP
                    elif self.fspace:
                        cgrad = density.real_to_fspace(cdlogP)
                    else:
                        cgrad = density.fspace_to_real(cdlogP)
                    dlogPs.append(cgrad)
                    dlogP += cgrad
                elif obj.fspace:
                    dlogPdfM += cdlogP
                else:
                    dlogPdM += cdlogP

            logP += clogP
            logPs.append(clogP)
            outputs.update(coutputs)

        if compute_gradient and not all_grads:
            # Convert the other-space accumulator into self's space and add.
            if self.fspace:
                dlogP = dlogPdfM
                if anyrspace:
                    dlogP += density.real_to_fspace(dlogPdM)
            else:
                dlogP = dlogPdM
                if anyfspace:
                    dlogP += density.fspace_to_real(dlogPdfM)

        outputs['all_logPs'] = logPs
        if compute_gradient and all_grads:
            outputs['all_dlogPs'] = dlogPs

        if compute_gradient:
            return logP, dlogP, outputs
        else:
            return logP, outputs
Esempio n. 28
0
def calc_angular_correlation(
    trunc_slices,
    N,
    rad,
    beamstop_rad=None,
    pixel_size=1.0,
    interpolation='nearest',
    sort_theta=True,
    clip=True,
    outside=False,
):
    """Compute the ring-wise angular correlation of truncated slices.

    Points are grouped into rings of equal polar radius. For every ring the
    correlation of its values with circularly shifted copies of themselves
    is computed along the last axis. The result is returned in the original
    point order.

    Parameters
    ----------
    trunc_slices : ndarray, shape (N_T,) or (N_R, N_T)
        Values on the truncated grid. The last axis must have one entry per
        coordinate produced by the matching ``gencoords*`` call.
    N : int
        Grid size forwarded to the coordinate generator.
    rad : float
        Truncation radius forwarded to the coordinate generator.
    beamstop_rad : float, optional
        When given (and ``outside`` is False), coordinates are generated with
        ``geometry.gencoords_centermask`` using this inner radius.
    pixel_size : float
        Currently unused; kept for interface compatibility.
    interpolation : str
        ``'none'`` groups by exact radius and ``'nearest'`` by rounded
        radius. ``'linear'`` is not implemented.
    sort_theta : bool (default: True)
        Also sort the points inside each ring by polar angle.
    clip : bool (default: True)
        Clip each ring's values to mean +/- 3 standard deviations.
    outside : bool (default: False)
        Use coordinates outside ``rad`` (``gencoords_outside``) instead of
        inside it.

    Returns
    -------
    ndarray
        Same shape and dtype as ``trunc_slices``.

    Raises
    ------
    NotImplementedError
        For ``'linear'`` interpolation.
    ValueError
        For any other unsupported interpolation.
    """
    # 1. Coordinates of the truncated grid and length check.
    iscomplex = np.iscomplexobj(trunc_slices)
    if outside:
        trunc_xy = gencoords_outside(N, 2, rad)
    elif beamstop_rad is None:
        trunc_xy = geometry.gencoords(N, 2, rad)
    else:
        trunc_xy = geometry.gencoords_centermask(N, 2, rad, beamstop_rad)
    axis = trunc_slices.ndim - 1
    assert trunc_xy.shape[0] == trunc_slices.shape[axis], \
        "wrong length of trunc slice or wrong radius"

    # 2. Sort points by rho (and optionally by theta within equal rho).
    pol_trunc_xy = cart2pol(trunc_xy)
    if sort_theta:
        # lexsort: the last key is primary -> sort by rho, then by theta.
        sorted_idx = np.lexsort((pol_trunc_xy[:, 1], pol_trunc_xy[:, 0]))
    else:
        sorted_idx = np.argsort(pol_trunc_xy[:, 0])
    sorted_rho = np.take(pol_trunc_xy[:, 0], sorted_idx)
    sorted_slice = np.take(trunc_slices, sorted_idx, axis=axis)

    # 3. Group into rings and correlate each ring.
    if 'none' in interpolation:
        pass
    elif 'nearest' in interpolation:
        sorted_rho = np.round(sorted_rho)
    elif 'linear' in interpolation:
        raise NotImplementedError()
    else:
        raise ValueError('unsupported method for interpolation')

    # sorted_rho is sorted, so each unique value is a contiguous run that
    # starts at its first index.
    _, unique_idx, unique_counts = np.unique(sorted_rho,
                                             return_index=True,
                                             return_counts=True)
    # NumPy >= 1.23 no longer accepts a *list* of slices as a
    # multidimensional index, so every lookup goes through tuple(indices).
    indices = [slice(None)] * trunc_slices.ndim
    angular_correlation = np.zeros_like(trunc_slices, dtype=trunc_slices.dtype)
    # Rings with fewer points are correlated directly (one roll per shift);
    # larger rings go through the FFT path.
    minimum_sample_points = 2000
    for start, count in zip(unique_idx, unique_counts):
        indices[axis] = slice(start, start + count)
        ring = tuple(indices)
        same_rho = np.copy(sorted_slice[ring])
        if count < minimum_sample_points:
            for shift in range(count):
                curr_delta_phi = same_rho * np.roll(same_rho, shift, axis=axis)
                indices[axis] = start + shift
                angular_correlation[tuple(indices)] = np.mean(curr_delta_phi,
                                                              axis=axis)
        else:
            # Real and imaginary parts are autocorrelated independently.
            fpcimg_real = density.real_to_fspace(same_rho.real, axes=(axis, ))
            angular_correlation.real[ring] = density.fspace_to_real(
                fpcimg_real * fpcimg_real.conjugate(), axes=(axis, )).real
            if iscomplex:
                fpcimg_imag = density.real_to_fspace(same_rho.imag,
                                                     axes=(axis, ))
                angular_correlation.imag[ring] = density.fspace_to_real(
                    fpcimg_imag * fpcimg_imag.conjugate(), axes=(axis, )).real

    # Replace inf and nan by zeros.
    if np.any(np.isinf(angular_correlation)):
        warnings.warn(
            "Some values in angular correlation occur inf. These values have been set to zeros."
        )
        angular_correlation.real[np.isinf(angular_correlation.real)] = 0
        if iscomplex:
            angular_correlation.imag[np.isinf(angular_correlation.imag)] = 0
    if np.any(np.isnan(angular_correlation)):
        warnings.warn(
            "Some values in angular correlation occur nan. These values have been set to zeros."
        )
        angular_correlation.real[np.isnan(angular_correlation.real)] = 0
        if iscomplex:
            angular_correlation.imag[np.isnan(angular_correlation.imag)] = 0

    # 4. Clip outliers (beyond `factor` sigma) per ring to the boundary.
    if clip:
        factor = 3.0
        for start, count in zip(unique_idx, unique_counts):
            indices[axis] = slice(start, start + count)
            ring = tuple(indices)
            block = angular_correlation[ring]
            mean = block.mean(axis=axis, keepdims=True)
            std = block.std(axis=axis, keepdims=True)
            angular_correlation[ring] = np.clip(block, mean - factor * std,
                                                mean + factor * std)

    # 5. Undo the sort so the output matches the input point order.
    corr_trunc_slices = np.take(angular_correlation,
                                sorted_idx.argsort(),
                                axis=axis)
    return corr_trunc_slices