Beispiel #1
0
    def __init__(self, cryodata, use_angular_correlation=False):
        """Prepare samplers, slicing flags, caches and a worker pool for *cryodata*.

        The image size (``self.N``) and pixel size (``self.psize``) are taken
        from ``cryodata.get_num_pixels()`` / ``cryodata.get_pixel_size()``; if
        either accessor raises AttributeError, both are read from
        ``cryodata.imgstack`` instead.

        Args:
            cryodata: dataset object; its ``dataset_params`` mapping is consulted
                for an optional ``'symmetry'`` entry.
            use_angular_correlation: stored unchanged as
                ``self.use_angular_correlation``.
        """
        self.cryodata = cryodata

        # Prefer the dataset's own accessors and fall back to its image stack.
        try:
            num_pixels = cryodata.get_num_pixels()
            pixel_size = cryodata.get_pixel_size()
        except AttributeError:
            stack = cryodata.imgstack
            num_pixels = stack.get_num_pixels()
            pixel_size = stack.get_pixel_size()
        self.N, self.psize = num_pixels, pixel_size

        # The symmetry operator is handed to the rotation sampler only.
        self.is_sym = get_symmetryop(cryodata.dataset_params.get('symmetry', None))
        self.sampler_R = FixedFisherImportanceSampler('_R', self.is_sym)
        self.sampler_I = FixedFisherImportanceSampler('_I')

        # Slicing configuration.
        self.use_cached_slicing = True
        self.use_angular_correlation = use_angular_correlation
        self.ac_slices = None

        self.G_datatype = np.float32

        self.executor = ThreadPoolExecutor(max_workers=NUM_CORES)

        # Both caches start empty; this method never fills them.
        self.cached_workspace = {}
        self.cached_cphi = {}
Beispiel #2
0
    def __init__(self, model, dataset_params, ctf_params,
                 interp_params=None,
                 load_cache=True):
        """Build an image dataset, either simulated from a model or loaded from disk.

        With a ``model``, ``num_images`` noisy images are simulated. Each image is
        built in four steps:

        1. Interpolate ``model`` with ``scipy.interpolate.interpn`` at a rotated 2-D
           coordinate grid.
        2. Clip the result at zero and map it through ``sincint.gentrunctofull``.
        3. Multiply by a CTF map, or by ones when there is no CTF.
        4. Replace the result by a Poisson sample.

        With ``model=None``, images are read from the MRC file
        ``dataset_params['inpath']``. Euler angles are read from the
        parameter file ``dataset_params['gtpath']`` (degrees, converted to
        radians).

        Args:
            model: ``np.ndarray`` used as the 3-D interpolation grid (its first
                axis length is the image size N), or None to load from disk.
            dataset_params: dict.
                Simulation reads ``num_images``, ``pixel_size``, optional
                ``euler_angles`` and optional ``symmetry``. Euler angles are rows of
                three values passed to ``geometry.rotmat3D_EA``; they are drawn at
                random when missing or None.
                Loading reads ``gtpath``, ``inpath``, ``resolution``, optional
                ``num_images`` (default 200) and optional ``symmetry``.
            ctf_params: dict of CTF parameters for ``ctf.compute_full_ctf``
                (simulation only), or None for no CTF.
            interp_params: interpolation settings forwarded to
                ``self.set_transform``. None selects
                ``{'kern': 'lanczos', 'kernsize': 4.0, 'zeropad': 0, 'dopremult': True}``.
            load_cache: currently unused.

        Raises:
            AssertionError: ``model`` is not an ndarray, or ``num_images <= 1``.
            ValueError: fewer Euler-angle rows than ``num_images`` were supplied.
        """
        if interp_params is None:
            # Build a fresh dict per call. A shared mutable default could be altered
            # by set_transform and leak into later instances.
            interp_params = {'kern': 'lanczos', 'kernsize': 4.0, 'zeropad': 0, 'dopremult': True}

        self.dataset_params = dataset_params

        if model is not None:
            assert isinstance(model, np.ndarray), "Unexpected data type for input model"

            self.num_pixels = model.shape[0]
            N = self.num_pixels
            self.num_images = dataset_params['num_images']
            assert self.num_images > 1, "it's better to make num_images larger than 1."
            self.pixel_size = float(dataset_params['pixel_size'])
            euler_angles = dataset_params.get('euler_angles', None)
            self.is_sym = get_symmetryop(dataset_params.get('symmetry', None))

            if euler_angles is None and self.is_sym is None:
                # Normalised Gaussian vectors give uniformly distributed directions;
                # the in-plane angle is uniform on [0, 2*pi).
                pt = np.random.randn(self.num_images, 3)
                pt /= np.linalg.norm(pt, axis=1, keepdims=True)
                euler_angles = geometry.genEA(pt)
                euler_angles[:, 2] = 2 * np.pi * np.random.rand(self.num_images)
            elif euler_angles is None:
                # Rejection-sample each direction until it lies in the asymmetric unit.
                euler_angles = np.zeros((self.num_images, 3))
                for ea in euler_angles:
                    while True:
                        pt = np.random.randn(3)
                        pt /= np.linalg.norm(pt)
                        if self.is_sym.in_asymunit(pt.reshape(-1, 3)):
                            break
                    ea[0:2] = geometry.genEA(pt)[0][0:2]
                    ea[2] = 2 * np.pi * np.random.rand()
            # np.asarray also accepts user-supplied lists of angles.
            self.euler_angles = np.asarray(euler_angles).reshape((-1, 3))
            if self.euler_angles.shape[0] < self.num_images:
                # Otherwise the trailing images would remain uninitialised (np.empty).
                raise ValueError(
                    "euler_angles has {0} rows but num_images is {1}".format(
                        self.euler_angles.shape[0], self.num_images))

            if ctf_params is not None:
                self.use_ctf = True
                ctf_map = ctf.compute_full_ctf(None, N, ctf_params['psize'],
                    ctf_params['akv'], ctf_params['cs'], ctf_params['wgh'],
                    ctf_params['df1'], ctf_params['df2'], ctf_params['angast'],
                    ctf_params['dscale'], ctf_params.get('bfactor', 500))
                # Keep a copy without 'bfactor', so the caller's dict is untouched.
                self.ctf_params = copy(ctf_params)
                self.ctf_params.pop('bfactor', None)
            else:
                self.use_ctf = False
                ctf_map = np.ones((N**2,), dtype=density.real_t)

            # Radius passed to both the truncated-to-full mapping and the 2-D grid.
            rad = 0.95
            TtoF = sincint.gentrunctofull(N=N, rad=rad)
            base_coords = geometry.gencoords(N, 2, rad)
            # The model is interpolated as given: no premultiplication or Fourier
            # transform is applied here.
            fM = model
            grid = (np.arange(N),) * 3
            center = N // 2
            ctf_image = ctf_map.reshape((N, N))

            print("Generating Dataset ... :")
            tic = time.time()
            imgdata = np.empty((self.num_images, N, N), dtype=density.real_t)
            for i, ea in enumerate(self.euler_angles[:self.num_images]):
                # First two columns of the rotation span the slice plane.
                R = geometry.rotmat3D_EA(*ea)[:, 0:2]
                rotated_coords = R.dot(base_coords.T).T + center
                D = interpn(grid, fM, rotated_coords)
                np.maximum(D, 0.0, out=D)

                intensity = ctf_image * TtoF.dot(D).reshape((N, N))
                # Keep the Poisson rate strictly positive.
                np.maximum(1e-8, intensity, out=intensity)
                # Cast the integer counts directly (np.float_ was removed in NumPy 2.0).
                imgdata[i] = np.random.poisson(intensity).astype(density.real_t)
            self.imgdata = imgdata
            print("  cost {} seconds.".format(time.time()-tic))

            self.set_transform(interp_params)

        else:
            euler_angles = []
            with open(self.dataset_params['gtpath']) as par:
                par.readline()  # skip the column-header line
                # 'C                 PHI      THETA        PSI        SHX        SHY       FILM        DF1        DF2     ANGAST'
                # Read angles until the first line without three numeric values in
                # columns 1-3 (end of file, blank line or trailing non-data line).
                for line in par:
                    fields = line.split()
                    try:
                        euler_angles.append([float(fields[1]), float(fields[2]), float(fields[3])])
                    except (IndexError, ValueError):
                        break
            self.euler_angles = np.deg2rad(np.asarray(euler_angles))
            num_images = self.dataset_params.get('num_images', 200)
            imgdata = mrc.readMRCimgs(self.dataset_params['inpath'], 0, num_images)

            # The reader returns the image index last; move it to the front.
            self.imgdata = np.transpose(imgdata, axes=(2, 0, 1))

            self.num_images = self.imgdata.shape[0]
            self.num_pixels = self.imgdata.shape[1]

            self.pixel_size = self.dataset_params['resolution']
            # Use the same symmetry-operator representation as the simulated branch.
            self.is_sym = get_symmetryop(self.dataset_params.get('symmetry', None))

            self.use_ctf = False
            self.set_transform(interp_params)
Beispiel #3
0
    def set_slice_quad(self,rad):
        """(Re)build the slice quadrature, slicing operators and premultiplier.

        Settings are read from ``self.params``. Only the pieces whose inputs
        changed since the previous call are recomputed:

        - the sphere quadrature domain and ``sampler_R``, when the quadrature
          type, degree or symmetry changed;
        - the slicing operators and ``G``/``slices`` buffers, when the domain or
          the interpolation settings changed;
        - the premultiplier, when the kernel, kernel size or zero padding changed.

        Args:
            rad: radius forwarded to the degree/max-angle computations and stored
                in the interpolation parameters.

        Raises:
            AssertionError: zero padding was requested (not yet implemented).
        """
        params = self.params

        tic = time.time()

        N = self.N
        degree_R = params.get('quad_degree_R','auto')
        quad_scheme_R = params.get('quad_type_R','sk97')
        sym = get_symmetryop(params.get('symmetry',None)) if params.get('perfect_symmetry',True) else None
        usFactor_R = params.get('quad_undersample_R',params.get('quad_undersample',1.0))

        kern_R = params.get('interp_kernel_R',params.get('interp_kernel',None))
        kernsize_R = params.get('interp_kernel_size_R',params.get('interp_kernel_size',None))
        zeropad_R = params.get('interp_zeropad_R',params.get('interp_zeropad',0))
        dopremult_R = params.get('interp_premult_R',params.get('interp_premult',True))

        # Reject the unsupported option before any attribute of self is modified.
        assert zeropad_R == 0,'Zero padding for slicing not yet implemented'

        quad_R = quadrature.quad_schemes[('dir',quad_scheme_R)]

        if degree_R == 'auto':
            degree_R,resolution_R = quad_R.compute_degree(N,rad,usFactor_R)
        else:
            # An explicit degree comes without a resolution estimate from the
            # scheme, so only the lower bound on the next line applies.
            # NOTE(review): confirm this fallback is acceptable for explicit degrees.
            resolution_R = 0.0
        resolution_R = max(0.5*quadrature.compute_max_angle(self.N,rad), resolution_R)

        slice_params = { 'quad_type':quad_scheme_R, 'degree':degree_R,
                         'sym':sym }
        interp_params_R = { 'N':self.N, 'kern':kern_R, 'kernsize':kernsize_R, 'rad':rad, 'zeropad':zeropad_R, 'dopremult':dopremult_R }

        # Compare with the settings stored by the previous call to decide what to rebuild.
        domain_change_R = slice_params != self.slice_params
        interp_change_R = self.slice_interp != interp_params_R
        transform_change = self.slice_interp is None or \
                        self.slice_interp['kern'] != interp_params_R['kern'] or \
                        self.slice_interp['kernsize'] != interp_params_R['kernsize'] or \
                        self.slice_interp['zeropad'] != interp_params_R['zeropad']

        if domain_change_R:
            slice_quad = {}

            slice_quad['resolution'] = resolution_R
            slice_quad['degree'] = degree_R
            slice_quad['symop'] = sym

            slice_quad['dir'],slice_quad['W'] = quad_R.get_quad_points(degree_R,slice_quad['symop'])
            slice_quad['W'] = np.require(slice_quad['W'], dtype=np.float32)

            self.quad_domain_R = quadrature.FixedSphereDomain(slice_quad['dir'],
                                                              slice_quad['resolution'],
                                                              sym=sym)
            self.N_R = len(self.quad_domain_R)
            self.sampler_R.setup(params, self.N_D, self.N_D_Train, self.quad_domain_R)

            self.slice_quad = slice_quad
            self.slice_params = slice_params

        if domain_change_R or interp_change_R:
            symorder = 1 if self.slice_quad['symop'] is None else self.slice_quad['symop'].get_order()
            print("  Slice Ops: %d, " % self.N_R); sys.stdout.flush()
            # Precompute slicing operators only while the symmetry-expanded slice
            # count is below otf_thresh_R. Otherwise slice_ops stays None and G is
            # a full N^3 volume ("OTF").
            if self.N_R*symorder < self.otf_thresh_R:
                self.using_precomp_slicing = True
                print("generated in", end=''); sys.stdout.flush()
                self.slice_ops = self.quad_domain_R.compute_operator(interp_params_R)
                print(" {0} secs.".format(time.time() - tic))

                Gsz = (self.N_R,self.N_T)
                self.G = np.empty(Gsz, dtype=self.G_datatype)
                self.slices = np.empty(np.prod(Gsz), dtype=np.complex64)
            else:
                self.using_precomp_slicing = False
                print("generating OTF.")
                self.slice_ops = None
                self.G = np.empty((self.N,self.N,self.N),dtype=self.G_datatype)
                self.slices = None
            self.slice_interp = interp_params_R

        if transform_change:
            if dopremult_R:
                premult = cryoops.compute_premultiplier(self.N + 2*int(interp_params_R['zeropad']*(self.N/2)),
                                                        interp_params_R['kern'],interp_params_R['kernsize'])
                # Separable 3-D premultiplier: outer product of the 1-D profile along each axis.
                premult = premult.reshape((-1,1,1)) * premult.reshape((1,-1,1)) * premult.reshape((1,1,-1))
            else:
                premult = None
            self.slice_premult = premult
            self.slice_zeropad = interp_params_R['zeropad']
Beispiel #4
0
    def set_proj_quad(self,rad):
        """(Re)build the projection quadrature (rotations x in-plane angles) and operators.

        Settings are read from ``self.params``. Only the pieces whose inputs
        changed since the previous call are recomputed:

        - the SO(3) quadrature domain, and within it the sphere domain
          (``sampler_R``) and circle domain (``sampler_I``), when their
          parameters changed;
        - the projection operators and ``G``/``slices`` buffers, when the domain
          or the interpolation settings changed;
        - the premultiplier, when kernel settings differ from ``self.slice_interp``.

        Args:
            rad: radius forwarded to the degree/max-angle computations and stored
                in the interpolation parameters.

        Raises:
            AssertionError: zero padding was requested (not yet implemented).
        """
        params = self.params

        tic = time.time()

        N = self.N

        quad_scheme_R = params.get('quad_type_R','sk97')
        quad_R = quadrature.quad_schemes[('dir',quad_scheme_R)]

        degree_R = params.get('quad_degree_R','auto')
        degree_I = params.get('quad_degree_I','auto')

        usFactor_R = params.get('quad_undersample_R',params.get('quad_undersample',1.0))
        usFactor_I = params.get('quad_undersample_I',params.get('quad_undersample',1.0))

        kern_R = params.get('interp_kernel_R',params.get('interp_kernel',None))
        kernsize_R = params.get('interp_kernel_size_R',params.get('interp_kernel_size',None))
        zeropad_R = params.get('interp_zeropad_R',params.get('interp_zeropad',0))
        dopremult_R = params.get('interp_premult_R',params.get('interp_premult',True))

        # Reject the unsupported option before any attribute of self is modified.
        assert zeropad_R == 0,'Zero padding for slicing not yet implemented'

        sym = get_symmetryop(params.get('symmetry',None)) if params.get('perfect_symmetry',True) else None

        maxAngle = quadrature.compute_max_angle(self.N,rad,usFactor_I)
        if degree_I == 'auto':
            # Enough in-plane angles to cover the full circle at spacing <= maxAngle.
            degree_I = np.uint32(np.ceil(2.0 * np.pi / maxAngle))

        if degree_R == 'auto':
            degree_R,resolution_R = quad_R.compute_degree(N,rad,usFactor_R)
        else:
            # An explicit degree comes without a resolution estimate from the
            # scheme, so only the lower bound on the next line applies.
            # NOTE(review): confirm this fallback is acceptable for explicit degrees.
            resolution_R = 0.0

        resolution_R = max(0.5*quadrature.compute_max_angle(self.N,rad), resolution_R)
        resolution_I = max(0.5*quadrature.compute_max_angle(self.N,rad), 2.0*np.pi / degree_I)

        slice_params = { 'quad_type':quad_scheme_R, 'degree':degree_R,
                         'sym':sym }
        inplane_params = { 'degree':degree_I }
        proj_params = { 'quad_type_R':quad_scheme_R, 'degree_R':degree_R,
                         'sym':sym, 'degree_I':degree_I }
        interp_params_RI = { 'N':self.N, 'kern':kern_R, 'kernsize':kernsize_R, 'rad':rad, 'zeropad':zeropad_R, 'dopremult':dopremult_R }
        interp_change_RI = self.proj_interp != interp_params_RI

        # NOTE(review): this compares against self.slice_interp, which this method
        # never updates (it stores self.proj_interp). If only this method is used,
        # the premultiplier is rebuilt on every call; confirm this is intended.
        transform_change = self.slice_interp is None or \
                        self.slice_interp['kern'] != interp_params_RI['kern'] or \
                        self.slice_interp['kernsize'] != interp_params_RI['kernsize'] or \
                        self.slice_interp['zeropad'] != interp_params_RI['zeropad']

        domain_change_R = self.slice_params != slice_params
        domain_change_I = self.inplane_params != inplane_params
        domain_change_RI = self.proj_params != proj_params

        if domain_change_RI:
            proj_quad = {}

            proj_quad['resolution'] = max(resolution_R,resolution_I)
            proj_quad['degree_R'] = degree_R
            proj_quad['degree_I'] = degree_I
            proj_quad['symop'] = sym

            proj_quad['dir'],proj_quad['W_R'] = quad_R.get_quad_points(degree_R,proj_quad['symop'])
            proj_quad['W_R'] = np.require(proj_quad['W_R'], dtype=np.float32)

            proj_quad['thetas'] = np.linspace(0, 2.0*np.pi, degree_I, endpoint=False)
            # Shift by half the angular spacing (pi/degree_I). Unlike indexing
            # thetas[1], this also works when degree_I == 1.
            proj_quad['thetas'] += np.pi / degree_I
            proj_quad['W_I'] = np.require((2.0*np.pi/float(degree_I))*np.ones((degree_I,)),dtype=np.float32)

            self.quad_domain_RI = quadrature.FixedSO3Domain( proj_quad['dir'],
                                                            -proj_quad['thetas'],
                                                             proj_quad['resolution'],
                                                             sym=sym)
            self.N_RI = len(self.quad_domain_RI)
            self.proj_quad = proj_quad
            self.proj_params = proj_params

            if domain_change_R:
                self.quad_domain_R = quadrature.FixedSphereDomain(proj_quad['dir'],
                                                                  proj_quad['resolution'],
                                                                  sym=sym)
                self.N_R = len(self.quad_domain_R)
                self.sampler_R.setup(params, self.N_D, self.N_D_Train, self.quad_domain_R)
                self.slice_params = slice_params

            if domain_change_I:
                self.quad_domain_I = quadrature.FixedCircleDomain(proj_quad['thetas'],
                                                                  proj_quad['resolution'])
                self.N_I = len(self.quad_domain_I)
                self.sampler_I.setup(params, self.N_D, self.N_D_Train, self.quad_domain_I)
                self.inplane_params = inplane_params

        if domain_change_RI or interp_change_RI:
            symorder = 1 if self.proj_quad['symop'] is None else self.proj_quad['symop'].get_order()
            print("  Projection Ops: %d (%d slice, %d inplane), " % (self.N_RI, self.N_R, self.N_I)); sys.stdout.flush()
            # Precompute projection operators only while the symmetry-expanded
            # count is below otf_thresh_RI. Otherwise slice_ops stays None and G is
            # a full N^3 volume ("OTF").
            if self.N_RI*symorder < self.otf_thresh_RI:
                self.using_precomp_slicing = True
                print("generated in", end=''); sys.stdout.flush()
                self.slice_ops = self.quad_domain_RI.compute_operator(interp_params_RI)
                print(" {0} secs.".format(time.time() - tic))

                Gsz = (self.N_RI,self.N_T)
                self.G = np.empty(Gsz, dtype=self.G_datatype)
                self.slices = np.empty(np.prod(Gsz), dtype=np.complex64)
            else:
                self.using_precomp_slicing = False
                print("generating OTF.")
                self.slice_ops = None
                self.G = np.empty((N,N,N),dtype=self.G_datatype)
                self.slices = None
            self.using_precomp_inplane = False
            self.inplane_ops = None

            self.proj_interp = interp_params_RI

        if transform_change:
            if dopremult_R:
                premult = cryoops.compute_premultiplier(self.N + 2*int(interp_params_RI['zeropad']*(self.N/2)),
                                                        interp_params_RI['kern'],interp_params_RI['kernsize'])
                # Separable 3-D premultiplier: outer product of the 1-D profile along each axis.
                premult = premult.reshape((-1,1,1)) * premult.reshape((1,-1,1)) * premult.reshape((1,1,-1))
            else:
                premult = None
            self.slice_premult = premult
            self.slice_zeropad = interp_params_RI['zeropad']
Beispiel #5
0
    def dowork(self):
        """Execute one minibatch (one atom of work).

        A call performs these steps:

        1. Fetch the next training minibatch.
        2. Optionally learn prior and likelihood hyperparameters.
        3. Log density statistics.
        4. Optionally evaluate the test batch.
        5. Take one optimisation step on the density ``self.M`` and apply
           symmetry and density bounds.
        6. Update the importance samplers and record diagnostics.
        7. Queue output and checkpoint files on ``self.io_queue``.

        Returns:
            bool: True while ``self.iteration`` is below ``max_iterations`` and the
            fractional epoch is below ``max_epochs`` (both from ``self.cparams``,
            default unlimited).
        """
        # NOTE(review): ``timing`` is only populated here; this method never emits it.
        timing = {}
        # Time each minibatch
        tic_mini = time.time()

        self.iteration += 1

        # Fetch the current batches
        trainbatch = self.cryodata.get_next_minibatch(self.cparams.get('shuffle_minibatches',True))

        # Get the current epoch
        cepoch = self.cryodata.get_epoch(frac=True)
        epoch = self.cryodata.get_epoch()
        num_data = self.cryodata.N_D_Train

        # Evaluate the parameters
        self.eval_params()
        timing['setup'] = time.time() - tic_mini

        # Do hyperparameter learning
        if self.cparams.get('learn_params',False):
            tic_learn = time.time()
            if self.cparams.get('learn_prior_params',True):
                tic_learn_prior = time.time()
                self.prior_func.learn_params(self.params, self.cparams, M=self.M, fM=self.fM)
                timing['learn_prior'] = time.time() - tic_learn_prior

            if self.cparams.get('learn_likelihood_params',True):
                tic_learn_like = time.time()
                self.like_func.learn_params(self.params, self.cparams, M=self.M, fM=self.fM)
                timing['learn_like'] = time.time() - tic_learn_like

            if self.cparams.get('learn_prior_params',True) or self.cparams.get('learn_likelihood_params',True):
                timing['learn_total'] = time.time() - tic_learn

        # Time each epoch
        if self.tic_epoch is None:
            self.ostream("Epoch: %d" % epoch)
            self.tic_epoch = (tic_mini,epoch)
        elif self.tic_epoch[1] != epoch:
            self.ostream("Epoch Total - %.6f seconds " % \
                         (tic_mini - self.tic_epoch[0]))
            self.tic_epoch = (tic_mini,epoch)

        sym = get_symmetryop(self.cparams.get('symmetry',None))
        if sym is not None:
            self.obj.ws[1] = 1.0/sym.get_order()

        tic_mstats = time.time()
        self.ostream(self.cparams['name']," Iteration:", self.iteration,\
                     " Epoch:", epoch, " Host:", socket.gethostname())

        # Compute density statistics
        N = self.cryodata.N
        M_sum = self.M.sum(dtype=n.float64)
        M_zeros = (self.M == 0).sum()
        M_mean = M_sum/N**3
        M_max = self.M.max()
        M_min = self.M.min()
        self.statout.output(total_density=[M_sum],
                            avg_density=[M_mean],
                            nonzero_density=[M_zeros],
                            max_density=[M_max],
                            min_density=[M_min])
        timing['density_stats'] = time.time() - tic_mstats

        # evaluate test batch if requested
        # NOTE(review): without an 'evaluate_test_set' setting, the default
        # ``self.iteration % 5`` is truthy on every iteration NOT divisible by 5.
        # That evaluates the test batch on 4 of every 5 iterations; confirm this is
        # intended rather than ``% 5 == 0``.
        if self.iteration <= 1 or self.cparams.get('evaluate_test_set',self.iteration%5):
            tic_test = time.time()
            testbatch = self.cryodata.get_testbatch()

            self.obj.set_data(self.cparams,testbatch)
            testLogP, res_test = self.obj.eval(M=self.M, fM=self.fM,
                                               compute_gradient=False)

            self.outputbatchinfo(testbatch, res_test, testLogP, 'test', 'Test')
            timing['test_batch'] = time.time() - tic_test
        else:
            testLogP, res_test = None, None

        # setup the wrapper for the objective function
        tic_objsetup = time.time()
        self.obj.set_data(self.cparams,trainbatch)
        self.obj_wrapper.set_objective(self.obj)
        x0 = self.obj_wrapper.set_density(self.M,self.fM)
        evalobj = self.obj_wrapper.eval_obj
        timing['obj_setup'] = time.time() - tic_objsetup

        # Get step size
        self.num_data_evals += trainbatch['N_M']  # at least one gradient
        tic_objstep = time.time()
        trainLogP, dlogP, v, res_train, extra_num_data = self.step.do_step(x0,
                                                         self.cparams,
                                                         self.cryodata,
                                                         evalobj,
                                                         batch=trainbatch)

        # Apply the step
        x = x0 + v
        timing['step'] = time.time() - tic_objstep

        # Convert from parameters to value
        tic_stepfinalize = time.time()
        prevM = n.copy(self.M)
        self.M, self.fM = self.obj_wrapper.convert_parameter(x,comp_real=True)

        apply_sym = sym is not None and self.cparams.get('perfect_symmetry',True) and self.cparams.get('apply_symmetry',True)
        if apply_sym:
            self.M = sym.apply(self.M)

        # Truncate the density to bounds if they exist
        if self.cparams['density_lb'] is not None:
            n.maximum(self.M,self.cparams['density_lb']*self.cparams['modelscale'],out=self.M)
        if self.cparams['density_ub'] is not None:
            n.minimum(self.M,self.cparams['density_ub']*self.cparams['modelscale'],out=self.M)

        # Compute net change
        self.dM = prevM - self.M

        # Recompute the Fourier representation whenever M was modified after
        # convert_parameter (symmetry or bounds) or none was returned.
        if self.fM is None or apply_sym \
           or self.cparams['density_lb'] is not None \
           or self.cparams['density_ub'] is not None:
            self.fM = density.real_to_fspace(self.M)
        timing['step_finalize'] = time.time() - tic_stepfinalize

        # Compute step statistics
        tic_stepstats = time.time()
        step_size = n.linalg.norm(self.dM)
        grad_size = n.linalg.norm(dlogP)
        M_norm = n.linalg.norm(self.M)

        self.num_data_evals += extra_num_data
        inc_ratio = step_size / M_norm
        self.statout.output(step_size=[step_size],
                            inc_ratio=[inc_ratio],
                            grad_size=[grad_size],
                            norm_density=[M_norm])
        timing['step_stats'] = time.time() - tic_stepstats

        # Update import sampling distributions
        tic_isupdate = time.time()
        self.sampler_R.perform_update()
        self.sampler_I.perform_update()
        self.sampler_S.perform_update()

        self.diagout.output(global_phi_R=self.sampler_R.get_global_dist())
        self.diagout.output(global_phi_I=self.sampler_I.get_global_dist())
        self.diagout.output(global_phi_S=self.sampler_S.get_global_dist())
        timing['is_update'] = time.time() - tic_isupdate

        # Output basic diagnostics
        tic_diagnostics = time.time()
        self.diagout.output(iteration=self.iteration, epoch=epoch, cepoch=cepoch)

        # (Re)initialise the running histories when the number of batches changes.
        if self.logpost_history.N_sum != self.cryodata.N_batches:
            self.logpost_history.setup(trainLogP,self.cryodata.N_batches)
        self.logpost_history.set_value(trainbatch['id'],trainLogP)

        if self.like_history.N_sum != self.cryodata.N_batches:
            self.like_history.setup(res_train['L'],self.cryodata.N_batches)
        self.like_history.set_value(trainbatch['id'],res_train['L'])

        self.outputbatchinfo(trainbatch, res_train, trainLogP, 'train', 'Train')

        # Dump parameters here to catch the defaults used in evaluation
        self.diagout.output(params=self.cparams,
                            envelope_mle=self.like_func.get_envelope_mle(),
                            sigma2_mle=self.like_func.get_sigma2_mle(),
                            hostname=socket.gethostname())
        self.statout.output(num_data=[num_data],
                            num_data_evals=[self.num_data_evals],
                            iteration=[self.iteration],
                            epoch=[epoch],
                            cepoch=[cepoch],
                            logp=[self.logpost_history.get_mean()],
                            like=[self.like_history.get_mean()],
                            sigma=[self.like_func.get_rmse()],
                            time=[time.time()])
        timing['diagnostics'] = time.time() - tic_diagnostics

        checkpoint_it = self.iteration % self.cparams.get('checkpoint_frequency',50) == 0
        save_it = checkpoint_it or self.cparams.get('save_iteration', False) or \
                  time.time() - self.last_save > self.cparams.get('save_time',n.inf)

        if save_it:
            tic_save = time.time()
            self.last_save = tic_save
            if self.io_queue.qsize():
                print("Warning: IO queue has become backlogged with {0} remaining, waiting for it to clear".format(self.io_queue.qsize()))
                self.io_queue.join()
            # Copies are queued so the IO thread does not pickle dicts that later
            # iterations keep modifying.
            self.io_queue.put(( 'pkl', self.statout.fname, copy(self.statout.outdict) ))
            self.io_queue.put(( 'pkl', self.diagout.fname, deepcopy(self.diagout.outdict) ))
            self.io_queue.put(( 'pkl', self.likeout.fname, deepcopy(self.likeout.outdict) ))
            self.io_queue.put(( 'mrc', opj(self.outbase,'model.mrc'), \
                                (n.require(self.M,dtype=density.real_t),self.voxel_size) ))
            self.io_queue.put(( 'mrc', opj(self.outbase,'dmodel.mrc'), \
                                (n.require(self.dM,dtype=density.real_t),self.voxel_size) ))

            if checkpoint_it:
                self.io_queue.put(( 'cp', self.diagout.fname, self.diagout.fname+'-{0:06}'.format(self.iteration) ))
                self.io_queue.put(( 'cp', self.likeout.fname, self.likeout.fname+'-{0:06}'.format(self.iteration) ))
                self.io_queue.put(( 'cp', opj(self.outbase,'model.mrc'), opj(self.outbase,'model-{0:06}.mrc'.format(self.iteration)) ))
            timing['save'] = time.time() - tic_save

        time_total = time.time() - tic_mini
        self.ostream("  Minibatch Total - %.2f seconds                         Total Runtime - %s" %
                     (time_total, format_timedelta(datetime.now() - self.startdatetime) ))

        return self.iteration < self.cparams.get('max_iterations',n.inf) and \
               cepoch < self.cparams.get('max_epochs',n.inf)
Beispiel #6
0
    def __init__(self, expbase, cmdparams=None):
        """cryodata is a CryoData instance. 
        expbase is a path to the base of folder where this experiment's files
        will be stored.  The folder above expbase will also be searched
        for .params files. These will be loaded first."""
        BackgroundWorker.__init__(self)

        # Create a background thread which handles IO
        self.io_queue = Queue()
        self.io_thread = Thread(target=self.ioworker)
        self.io_thread.daemon = True
        self.io_thread.start()

        # General setup ----------------------------------------------------
        self.expbase = expbase
        self.outbase = None

        # Paramter setup ---------------------------------------------------
        # search above expbase for params files
        _,_,filenames = os.walk(opj(expbase,'../')).next()
        self.paramfiles = [opj(opj(expbase,'../'), fname) \
                           for fname in filenames if fname.endswith('.params')]
        # search expbase for params files
        _,_,filenames = os.walk(opj(expbase)).next()
        self.paramfiles += [opj(expbase,fname)  \
                            for fname in filenames if fname.endswith('.params')]
        if 'local.params' in filenames:
            self.paramfiles += [opj(expbase,'local.params')]
        # load parameter files
        self.params = Params(self.paramfiles)
        self.cparams = None
        
        if cmdparams is not None:
            # Set parameter specified on the command line
            for k,v in cmdparams.iteritems():
                self.params[k] = v
                
        # Dataset setup -------------------------------------------------------
        self.imgpath = self.params['inpath']
        psize = self.params['resolution']
        if not isinstance(self.imgpath,list):
            imgstk = MRCImageStack(self.imgpath,psize)
        else:
            imgstk = CombinedImageStack([MRCImageStack(cimgpath,psize) for cimgpath in self.imgpath])

        if self.params.get('float_images',True):
            imgstk.float_images()
        
        self.ctfpath = self.params['ctfpath']
        mscope_params = self.params['microscope_params']
         
        if not isinstance(self.ctfpath,list):
            ctfstk = CTFStack(self.ctfpath,mscope_params)
        else:
            ctfstk = CombinedCTFStack([CTFStack(cctfpath,mscope_params) for cctfpath in self.ctfpath])


        self.cryodata = CryoDataset(imgstk,ctfstk)
        self.cryodata.compute_noise_statistics()
        if self.params.get('window_images',True):
            imgstk.window_images()
        minibatch_size = self.params['minisize']
        testset_size = self.params['test_imgs']
        partition = self.params.get('partition',0)
        num_partitions = self.params.get('num_partitions',1)
        seed = self.params['random_seed']
        if isinstance(partition,str):
            partition = eval(partition)
        if isinstance(num_partitions,str):
            num_partitions = eval(num_partitions)
        if isinstance(seed,str):
            seed = eval(seed)
        self.cryodata.divide_dataset(minibatch_size,testset_size,partition,num_partitions,seed)
        
        self.cryodata.set_datasign(self.params.get('datasign','auto'))
        if self.params.get('normalize_data',True):
            self.cryodata.normalize_dataset()

        self.voxel_size = self.cryodata.pixel_size


        # Iterations setup -------------------------------------------------
        self.iteration = 0 
        self.tic_epoch = None
        self.num_data_evals = 0
        self.eval_params()

        outdir = self.cparams.get('outdir',None)
        if outdir is None:
            if self.cparams.get('num_partitions',1) > 1:
                outdir = 'partition{0}'.format(self.cparams['partition'])
            else:
                outdir = ''
        self.outbase = opj(self.expbase,outdir)
        if not os.path.isdir(self.outbase):
            os.makedirs(self.outbase) 

        # Output setup -----------------------------------------------------
        self.ostream = OutputStream(opj(self.outbase,'stdout'))

        self.ostream(80*"=")
        self.ostream("Experiment: " + expbase + \
                     "    Kernel: " + self.params['kernel'])
        self.ostream("Started on " + socket.gethostname() + \
                     "    At: " + time.strftime('%B %d %Y: %I:%M:%S %p'))
        self.ostream("Git SHA1: " + gitutil.git_get_SHA1())
        self.ostream(80*"=")
        gitutil.git_info_dump(opj(self.outbase, 'gitinfo'))
        self.startdatetime = datetime.now()


        # for diagnostics and parameters
        self.diagout = Output(opj(self.outbase, 'diag'),runningout=False)
        # for stats (per image etc)
        self.statout = Output(opj(self.outbase, 'stat'),runningout=True)
        # for likelihoods of individual images
        self.likeout = Output(opj(self.outbase, 'like'),runningout=False)

        self.img_likes = n.empty(self.cryodata.N_D)
        self.img_likes[:] = n.inf

        # optimization state vars ------------------------------------------
        init_model = self.cparams.get('init_model',None)
        if init_model is not None:
            filename = init_model
            if filename.upper().endswith('.MRC'):
                M = readMRC(filename)
            else:
                with open(filename) as fp:
                    M = cPickle.load(fp)
                    if type(M)==list:
                        M = M[-1]['M'] 
            if M.shape != 3*(self.cryodata.N,):
                M = cryoem.resize_ndarray(M,3*(self.cryodata.N,),axes=(0,1,2))
        else:
            init_seed = self.cparams.get('init_random_seed',0)  + self.cparams.get('partition',0)
            print "Randomly generating initial density (init_random_seed = {0})...".format(init_seed), ; sys.stdout.flush()
            tic = time.time()
            M = cryoem.generate_phantom_density(self.cryodata.N, 0.95*self.cryodata.N/2.0, \
                                                5*self.cryodata.N/128.0, 30, seed=init_seed)
            print "done in {0}s".format(time.time() - tic)

        tic = time.time()
        print "Windowing and aligning initial density...", ; sys.stdout.flush()
        # window the initial density
        wfunc = self.cparams.get('init_window','circle')
        cryoem.window(M,wfunc)

        # Center and orient the initial density
        cryoem.align_density(M)
        print "done in {0:.2f}s".format(time.time() - tic)

        # apply the symmetry operator
        init_sym = get_symmetryop(self.cparams.get('init_symmetry',self.cparams.get('symmetry',None)))
        if init_sym is not None:
            tic = time.time()
            print "Applying symmetry operator...", ; sys.stdout.flush()
            M = init_sym.apply(M)
            print "done in {0:.2f}s".format(time.time() - tic)

        tic = time.time()
        print "Scaling initial model...", ; sys.stdout.flush()
        modelscale = self.cparams.get('modelscale','auto')
        mleDC, _, mleDC_est_std = self.cryodata.get_dc_estimate()
        if modelscale == 'auto':
            # Err on the side of a weaker prior by using a larger value for modelscale
            modelscale = (n.abs(mleDC) + 2*mleDC_est_std)/self.cryodata.N
            print "estimated modelscale = {0:.3g}...".format(modelscale), ; sys.stdout.flush()
            self.params['modelscale'] = modelscale
            self.cparams['modelscale'] = modelscale
        M *= modelscale/M.sum()
        print "done in {0:.2f}s".format(time.time() - tic)
        if mleDC_est_std/n.abs(mleDC) > 0.05:
            print "  WARNING: the DC component estimate has a high relative variance, it may be inaccurate!"
        if ((modelscale*self.cryodata.N - n.abs(mleDC)) / mleDC_est_std) > 3:
            print "  WARNING: the selected modelscale value is more than 3 std devs different than the estimated one.  Be sure this is correct."

        self.M = n.require(M,dtype=density.real_t)
        self.fM = density.real_to_fspace(M)
        self.dM = density.zeros_like(self.M)

        self.step = eval(self.cparams['optim_algo'])
        self.step.setup(self.cparams, self.diagout, self.statout, self.ostream)

        # Objective function setup --------------------------------------------
        param_type = self.cparams.get('parameterization','real')
        cplx_param = param_type in ['complex','complex_coeff','complex_herm_coeff']
        self.like_func = eval_objective(self.cparams['likelihood'])
        self.prior_func = eval_objective(self.cparams['prior'])

        if self.cparams.get('penalty',None) is not None:
            self.penalty_func = eval_objective(self.cparams['penalty'])
            prior_func = SumObjectives(self.prior_func.fspace, \
                                       [self.penalty_func,self.prior_func], None)
        else:
            prior_func = self.prior_func

        self.obj = SumObjectives(cplx_param,
                                 [self.like_func,prior_func], [None,None])
        self.obj.setup(self.cparams, self.diagout, self.statout, self.ostream)
        self.obj.set_dataset(self.cryodata)
        self.obj_wrapper = ObjectiveWrapper(param_type)

        self.last_save = time.time()
        
        self.logpost_history = FiniteRunningSum()
        self.like_history = FiniteRunningSum()

        # Importance Samplers -------------------------------------------------
        self.is_sym = get_symmetryop(self.cparams.get('is_symmetry',self.cparams.get('symmetry',None)))
        self.sampler_R = FixedFisherImportanceSampler('_R',self.is_sym)
        self.sampler_I = FixedFisherImportanceSampler('_I')
        self.sampler_S = FixedGaussianImportanceSampler('_S')
        self.like_func.set_samplers(sampler_R=self.sampler_R,sampler_I=self.sampler_I,sampler_S=self.sampler_S)
Beispiel #7
0
def sagd_dostep(data_dir, model_file, use_angular_correlation=True,
                num_iterations=10):
    """Run a fixed number of SAGD optimisation steps on a dataset/model pair.

    The dataset, initial density and parameters come from ``sagd_init``.
    Each iteration draws the next minibatch from ``cryodata`` and performs
    one ``SAGDStep`` update of the density. Diagnostic, statistics and
    likelihood ``Output`` objects are created under the hard-coded ``exp/``
    directory.

    Args:
        data_dir: Dataset directory, forwarded to ``sagd_init``.
        model_file: Initial density file, forwarded to ``sagd_init``.
        use_angular_correlation: Forwarded to ``sagd_init``.
        num_iterations: Number of SAGD steps to run. Defaults to 10, the value
            that was previously hard-coded.

    Returns:
        Tuple ``(M, fM, num_data_evals)``. ``M``/``fM`` are the density after
        the last step, as produced by ``obj_wrapper.convert_parameter``. If
        ``num_iterations`` is 0 they are the values from ``sagd_init``.
        ``num_data_evals`` is the accumulated minibatch sizes plus the extra
        evaluations reported by ``step.do_step``.
    """
    cryodata, (M, fM), cparams = sagd_init(data_dir, model_file,
                                           use_angular_correlation)

    iteration = cparams['iteration']

    # Optimiser hyper-parameters. Only 'iteration' is refreshed inside the
    # loop below; the other iteration-dependent values are fixed here from
    # the initial cparams['iteration']. 'sagd_linesearch' and
    # 'shuffle_minibatches' are expression strings (presumably evaluated by
    # the step object).
    sagd_params = {
        'iteration': iteration,
        'exp_path': 'exp/',

        'num_batches': cryodata.N_batches,

        'sagd_linesearch':          'max_frequency_changed or ((iteration%5 == 0) if iteration < 2500 else \
                                                            (iteration%3 == 0) if iteration < 5000 else \
                                                            True)'                                                                  ,
        'sagd_linesearch_accuracy': 1.01 if iteration < 10 else \
                                    1.10 if iteration < 2500 else \
                                    1.25 if iteration < 5000 else \
                                    1.50,
        'sagd_linesearch_maxits':   5 if iteration < 2500 else 3,
        'sagd_incL':                1.0,

        'sagd_momentum':      1 - 1.0/(1.0 + 0.1 * iteration),

        'shuffle_minibatches': 'iteration >= 1000',
    }

    # Remove a leftover exp/sagd_L0.pkl from an earlier run so it is not
    # picked up again.
    if os.path.exists('exp/sagd_L0.pkl'):
        os.remove('exp/sagd_L0.pkl')
    # Outputs for diagnostics/parameters, per-image stats and per-image
    # likelihoods. likeout is not used below; it is still constructed in case
    # Output's constructor has side effects (not verifiable from here).
    diagout = Output(os.path.join('exp', 'diag'), runningout=False)
    statout = Output(os.path.join('exp', 'stat'), runningout=True)
    likeout = Output(os.path.join('exp', 'like'), runningout=False)
    ostream = OutputStream(None)

    # Setup SAGD optimizer
    step = SAGDStep()
    step.setup(cparams, diagout, statout, ostream)

    # Objective function: likelihood + prior, over a real or complex
    # parameterisation of the density.
    param_type = cparams.get('parameterization', 'real')
    cplx_param = param_type in [
        'complex', 'complex_coeff', 'complex_herm_coeff'
    ]
    like_func = eval_objective(cparams['likelihood'])
    prior_func = eval_objective(cparams['prior'])

    obj = SumObjectives(cplx_param, [like_func, prior_func], [None, None])
    obj.setup(cparams, diagout, statout, ostream)
    obj.set_dataset(cryodata)
    obj_wrapper = ObjectiveWrapper(param_type)

    # Importance samplers for rotations (R) and in-plane angles (I).
    # Shift sampling is disabled: no sampler_S is created.
    is_sym = get_symmetryop(
        cparams.get('is_symmetry', cparams.get('symmetry', None)))
    sampler_R = FixedFisherImportanceSampler('_R', is_sym)
    sampler_I = FixedFisherImportanceSampler('_I')
    sampler_S = None
    like_func.set_samplers(sampler_R=sampler_R,
                           sampler_I=sampler_I,
                           sampler_S=sampler_S)

    num_data_evals = 0
    for i in range(num_iterations):
        cparams['iteration'] = i
        sagd_params['iteration'] = i
        print('Iteration #:', i)
        minibatch = cryodata.get_next_minibatch(True)
        num_data_evals += minibatch['N_M']

        # Bind the objective to this minibatch and turn the current density
        # into the optimiser's parameter vector x0.
        obj.set_data(cparams, minibatch)
        obj_wrapper.set_objective(obj)
        x0 = obj_wrapper.set_density(M, fM)
        evalobj = obj_wrapper.eval_obj

        # Compute the update direction v (plus objective/gradient values).
        trainLogP, dlogP, v, res_train, extra_num_data = \
            step.do_step(x0, sagd_params, cryodata, evalobj, batch=minibatch)

        # Apply the step and convert the parameters back into a density.
        M, fM = obj_wrapper.convert_parameter(x0 + v, comp_real=True)

        num_data_evals += extra_num_data

        # Update importance sampling distributions
        sampler_R.perform_update()
        sampler_I.perform_update()

    return M, fM, num_data_evals
Beispiel #8
0
    def updatevis(self, levels=None):
        """Redraw every figure from the current density and diagnostics.

        Does nothing unless ``self.M``, ``self.diag`` and ``self.stat`` are all
        set. Otherwise it refreshes, in order:

        - the objective, error and noise plots;
        - the envelope plot, only when the diagnostics contain 'envelope_mle';
        - the slice view and the 3-D contour view of the density. The density
          is aligned first when no symmetry is configured.
        - the global importance-sampling distributions, when more than one
          global rotation weight exists;
        - the step statistics, when ``self.show_grad`` is set;
        - the density statistics, when ``self.extra_plots`` is set.

        Args:
            levels: Contour levels passed to ``plot_density``. Defaults to
                ``[0.2, 0.5, 0.8]``.

        Side effects:
            Sets ``self.alignedM``, ``self.aligneddM``, ``self.alignedR``,
            ``self.fM`` and ``self.curr_contours``.
        """
        if levels is None:
            # Fresh list per call instead of a shared mutable default.
            levels = [0.2, 0.5, 0.8]
        if self.M is None or self.diag is None or self.stat is None:
            return

        cdiag = self.diag
        cparams = cdiag['params']
        sym = get_symmetryop(cparams.get('symmetry',None))
        # The direction quadrature only uses the symmetry when
        # 'perfect_symmetry' is enabled (default True).
        quad_sym = sym if cparams.get('perfect_symmetry',True) else None

        resolution = cparams['voxel_size']

        name = cparams['name']
        maxfreq = cparams['max_frequency']
        N = self.M.shape[0]
        rad_cutoff = cparams.get('rad_cutoff', 1.0)
        rad = min(rad_cutoff,maxfreq*2.0*resolution)

        # Show objective function
        self.show_objective_plot(self.get_figure('stats'))

        # Show information about noise and error
        self.show_error_plot(self.get_figure('error'))
        self.show_noise_plot(self.get_figure('noise'))

        # Plot the envelope function if we have the info
        if 'envelope_mle' in cdiag:
            self.show_envelope_plot(self.get_figure('envelope'))
        else:
            self.close_figure('envelope')

        # Only asymmetric densities are aligned; a symmetric density is shown
        # as-is with an identity rotation.
        if sym is None:
            assert quad_sym is None
            alignedM,R = cryoem.align_density(self.M)
            if self.show_grad:
                aligneddM = cryoem.rotate_density(self.dM,R)
            else:
                aligneddM = None
        else:
            alignedM, aligneddM = self.M, self.dM
            R = np.identity(3)

        self.alignedM,self.aligneddM,self.alignedR = alignedM,aligneddM,R
        self.fM = density.real_to_fspace(self.M)

        self.figMslices.set_data(alignedM)

        # A single global rotation weight has no distribution worth plotting;
        # it is marked with None.
        glbl_phi_R = np.array([cdiag['global_phi_R']]).ravel()
        if len(glbl_phi_R) == 1:
            glbl_phi_R = None
        glbl_phi_I = cdiag['global_phi_I']
        if 'global_phi_S' in cdiag:
            glbl_phi_S = cdiag['global_phi_S']
        else:
            glbl_phi_S = None

        # Get direction quadrature
        quad_R = quadrature.quad_schemes[('dir',cparams.get('quad_type_R','sk97'))]
        quad_degree_R = cparams.get('quad_degree_R','auto')
        if quad_degree_R == 'auto':
            usFactor_R = cparams.get('quad_undersample_R',
                                     cparams.get('quad_undersample',1.0))
            quad_degree_R,_ = quad_R.compute_degree(N,rad,usFactor_R)
        origlebDirs,_ = quad_R.get_quad_points(quad_degree_R,quad_sym)
        lebDirs = np.dot(origlebDirs,R)
        # glbl_phi_R may be None (see above); len(None) would raise TypeError.
        # vmax_R is only consumed when glbl_phi_R is not None.
        vmax_R = 5.0/len(glbl_phi_R) if glbl_phi_R is not None else None

        # Get shift quadrature (only when shift weights are available)
        if 'global_phi_S' in cdiag:
            quad_S = quadrature.quad_schemes[('shift',cparams.get('quad_type_S','hermite'))]
            quad_degree_S = cparams.get('quad_degree_S','auto')
            if quad_degree_S == 'auto':
                usFactor_S = cparams.get('quad_undersample_S',
                                        cparams.get('quad_undersample',1.0))
                quad_degree_S = quad_S.get_degree(N,rad,
                                                cparams['quad_shiftsigma']/resolution,
                                                cparams['quad_shiftextent']/resolution,
                                                usFactor_S)
            pts_S,_ = quad_S.get_quad_points(quad_degree_S,
                                            cparams['quad_shiftsigma']/resolution,
                                            cparams['quad_shiftextent']/resolution,
                                            cparams.get('quad_shifttrunc','circ'))
            vmax_S = 5.0/len(glbl_phi_S)
        else:
            pts_S = np.zeros_like([0])
            vmax_S = None

        # Density visualization
        mlab.figure(self.fig1)
        mlab.clf()
        self.curr_contours = plot_density(alignedM, self.contours, levels)
        mlab.view(focalpoint=[alignedM.shape[0]/2.0,alignedM.shape[0]/2.0,alignedM.shape[0]/2.0],distance=1.5*alignedM.shape[0])


        if glbl_phi_R is not None:
            plt.figure(self.get_figure('global_is_dists').number)
            plt.clf()
            plot_importance_dists(name,lebDirs,pts_S*resolution,glbl_phi_R,glbl_phi_I,glbl_phi_S,vmax_R,vmax_S)

        if self.show_grad:
            # Statistics of dM
            self.figdMslices.set_data(aligneddM)

            plt.figure(self.get_figure('step_stats').number)
            plt.clf()
            plt.suptitle(name + ' Step Statistics')

            plt.subplot(1,2,1)
            # numpy/matplotlib require an integer bin count (a float such as
            # 0.5*N is rejected); keep at least one bin.
            plt.hist(self.dM.reshape((-1,)),bins=max(1, self.dM.shape[0]//2),log=True)
            plt.title('Voxel Histogram')

            (fs,raps) = rot_power_spectra(self.dM,resolution=resolution)
            plt.subplot(1,2,2)
            plt.plot(fs/(N/2.0)/(2.0*resolution),raps,label='RAPS')
            # Vertical marker at the current radial cutoff frequency.
            plt.plot((rad/(2.0*resolution))*np.ones((2,)), 
                     np.array([raps[raps > 0].min(),raps.max()]))
            plt.yscale('log')
            plt.title('RAPS Step')


        if not self.extra_plots:
            self.close_figure('density_stats')
            return

        # Statistics of M
        self.show_density_plot(self.get_figure('density_stats'))
Beispiel #9
0
    def updatevis(self, levels=None):
        """Redraw every figure from the current density and diagnostics.

        Does nothing unless ``self.M``, ``self.diag`` and ``self.stat`` are all
        set. Otherwise it refreshes, in order:

        - the objective, error and noise plots;
        - the envelope plot, only when the diagnostics contain 'envelope_mle';
        - the slice view and the 3-D contour view of the density. The density
          is aligned first when no symmetry is configured.
        - the global importance-sampling distributions, when more than one
          global rotation weight exists;
        - the step statistics, when ``self.show_grad`` is set;
        - the density statistics, when ``self.extra_plots`` is set.

        Args:
            levels: Contour levels passed to ``plot_density``. Defaults to
                ``[0.2, 0.5, 0.8]``.

        Side effects:
            Sets ``self.alignedM``, ``self.aligneddM``, ``self.alignedR``,
            ``self.fM`` and ``self.curr_contours``.
        """
        if levels is None:
            # Fresh list per call instead of a shared mutable default.
            levels = [0.2, 0.5, 0.8]
        if self.M is None or self.diag is None or self.stat is None:
            return

        cdiag = self.diag
        cparams = cdiag['params']
        sym = get_symmetryop(cparams.get('symmetry',None))
        # The direction quadrature only uses the symmetry when
        # 'perfect_symmetry' is enabled (default True).
        quad_sym = sym if cparams.get('perfect_symmetry',True) else None

        resolution = cparams['voxel_size']

        name = cparams['name']
        maxfreq = cparams['max_frequency']
        N = self.M.shape[0]
        rad_cutoff = cparams.get('rad_cutoff', 1.0)
        rad = min(rad_cutoff,maxfreq*2.0*resolution)

        # Show objective function
        self.show_objective_plot(self.get_figure('stats'))

        # Show information about noise and error
        self.show_error_plot(self.get_figure('error'))
        self.show_noise_plot(self.get_figure('noise'))

        # Plot the envelope function if we have the info
        if 'envelope_mle' in cdiag:
            self.show_envelope_plot(self.get_figure('envelope'))
        else:
            self.close_figure('envelope')

        # Only asymmetric densities are aligned; a symmetric density is shown
        # as-is with an identity rotation.
        if sym is None:
            assert quad_sym is None
            alignedM,R = c.align_density(self.M)
            if self.show_grad:
                aligneddM = c.rotate_density(self.dM,R)
            else:
                aligneddM = None
        else:
            alignedM, aligneddM = self.M, self.dM
            R = n.identity(3)

        self.alignedM,self.aligneddM,self.alignedR = alignedM,aligneddM,R
        self.fM = density.real_to_fspace(self.M)

        self.figMslices.set_data(alignedM)

        # A single global rotation weight has no distribution worth plotting;
        # it is marked with None.
        glbl_phi_R = n.array([cdiag['global_phi_R']]).ravel()
        if len(glbl_phi_R) == 1:
            glbl_phi_R = None
        glbl_phi_I = cdiag['global_phi_I']
        # Shift weights may be absent from the diagnostics; previously this
        # raised KeyError.
        if 'global_phi_S' in cdiag:
            glbl_phi_S = cdiag['global_phi_S']
        else:
            glbl_phi_S = None

        # Get direction quadrature
        quad_R = quadrature.quad_schemes[('dir',cparams.get('quad_type_R','sk97'))]
        quad_degree_R = cparams.get('quad_degree_R','auto')
        if quad_degree_R == 'auto':
            usFactor_R = cparams.get('quad_undersample_R',
                                     cparams.get('quad_undersample',1.0))
            quad_degree_R,_ = quad_R.compute_degree(N,rad,usFactor_R)
        origlebDirs,_ = quad_R.get_quad_points(quad_degree_R,quad_sym)
        lebDirs = n.dot(origlebDirs,R)
        # glbl_phi_R may be None (see above); len(None) would raise TypeError.
        # vmax_R is only consumed when glbl_phi_R is not None.
        vmax_R = 5.0/len(glbl_phi_R) if glbl_phi_R is not None else None

        # Get shift quadrature (only when shift weights are available)
        if 'global_phi_S' in cdiag:
            quad_S = quadrature.quad_schemes[('shift',cparams.get('quad_type_S','hermite'))]
            quad_degree_S = cparams.get('quad_degree_S','auto')
            if quad_degree_S == 'auto':
                usFactor_S = cparams.get('quad_undersample_S',
                                         cparams.get('quad_undersample',1.0))
                quad_degree_S = quad_S.get_degree(N,rad,
                                                  cparams['quad_shiftsigma']/resolution,
                                                  cparams['quad_shiftextent']/resolution,
                                                  usFactor_S)
            pts_S,_ = quad_S.get_quad_points(quad_degree_S,
                                             cparams['quad_shiftsigma']/resolution,
                                             cparams['quad_shiftextent']/resolution,
                                             cparams.get('quad_shifttrunc','circ'))
            vmax_S = 5.0/len(glbl_phi_S)
        else:
            pts_S = n.zeros_like([0])
            vmax_S = None

        # Density visualization
        mlab.figure(self.fig1)
        mlab.clf()
        self.curr_contours = plot_density(alignedM, self.contours, levels)
        mlab.view(focalpoint=[alignedM.shape[0]/2.0,alignedM.shape[0]/2.0,alignedM.shape[0]/2.0],distance=1.5*alignedM.shape[0])


        if glbl_phi_R is not None:
            plt.figure(self.get_figure('global_is_dists').number)
            plt.clf()
            plot_importance_dists(name,lebDirs,pts_S*resolution,glbl_phi_R,glbl_phi_I,glbl_phi_S,vmax_R,vmax_S)

        if self.show_grad:
            # Statistics of dM
            self.figdMslices.set_data(aligneddM)

            plt.figure(self.get_figure('step_stats').number)
            plt.clf()
            plt.suptitle(name + ' Step Statistics')

            plt.subplot(1,2,1)
            # numpy/matplotlib require an integer bin count (a float such as
            # 0.5*N is rejected); keep at least one bin.
            plt.hist(self.dM.reshape((-1,)),bins=max(1, self.dM.shape[0]//2),log=True)
            plt.title('Voxel Histogram')

            (fs,raps) = rot_power_spectra(self.dM,resolution=resolution)
            plt.subplot(1,2,2)
            plt.plot(fs/(N/2.0)/(2.0*resolution),raps,label='RAPS')
            # Vertical marker at the current radial cutoff frequency.
            plt.plot((rad/(2.0*resolution))*n.ones((2,)), 
                     n.array([raps[raps > 0].min(),raps.max()]))
            plt.yscale('log')
            plt.title('RAPS Step')


        if not self.extra_plots:
            self.close_figure('density_stats')
            return

        # Statistics of M
        self.show_density_plot(self.get_figure('density_stats'))
Beispiel #10
0
			if it<=3:
				if subd_R:
					r_thresh = min(n.percentile(errs, 100*r_factor),best_thresh)
					rs = rs[errs <= r_thresh]
					rs, dr = geometry.subdivde(rs, dr)
				gc.collect()
				if radwn == max_radwn and ((not subd_R) or \
										   (delta_R < delta_R_thresh)):
					break
				if dr < radwn_ang:
					radwn = n.minimum(radwn + 5, 10)#max_radwn)
		"""
	print(time.time()-tic)/5000.
	return best_R, n.zeros(3)

    
    
import os
import symmetry
# Set CRYOSPARC_ROOT_DIR to the local test tree (presumably read by the
# project code invoked below).
os.environ["CRYOSPARC_ROOT_DIR"] = "../runtest"
# 'C1' symmetry operator; passed to align_symmetryEEE below.
symop = symmetry.get_symmetryop('C1')
# NOTE: despite its name this path is relative to the working directory.
data_path_abs = '../runtest/vol.mrc'
# Resampling target size (256 // 8 == 32).
N = 256 // 8
map_r = readMRC(data_path_abs)
input_structure_datas = [map_r]

# Resample the first loaded map to size N (presumably N voxels per side),
# then run the symmetry alignment on it.
initmodel = fourier.resample_real(input_structure_datas[0], N)
align_symmetryEEE(initmodel, symop, 10, verbose=1, cuda_dev=0)
Beispiel #11
0
def load_kernel(data_dir, model_file, use_angular_correlation=False, sample_shifts=False):
    """Build an ``UnknownRSThreadedCPUKernel`` primed with the first minibatch.

    Loads the dataset described by ``data_params``: ``imgdata.mrc`` and
    ``defocus.txt`` inside ``data_dir``. The density is read from
    ``model_file``, or a random phantom is generated when it is None. The
    function then wires the importance samplers, dataset and first minibatch
    into a fresh kernel.

    Args:
        data_dir: Directory containing ``imgdata.mrc`` and ``defocus.txt``.
        model_file: Path to an MRC density, or None to generate a phantom.
        use_angular_correlation: Stored in the kernel's ``cparams``.
        sample_shifts: Currently ignored. No shift sampler is created and
            ``sampler_S`` is always None.

    Returns:
        The configured kernel, with ``kernel.M`` and ``kernel.fM`` set and
        precomputed slicing/in-plane caches disabled.
    """
    data_params = {
        'dataset_name': "1AON",
        'inpath': os.path.join(data_dir, 'imgdata.mrc'),
        'ctfpath': os.path.join(data_dir, 'defocus.txt'),
        'microscope_params': {'akv': 200, 'wgh': 0.07, 'cs': 2.0},
        'resolution': 2.8,
        'sigma': 'noise_std',
        'sigma_out': 'data_std',
        'minisize': 20,
        'test_imgs': 20,
        'partition': 0,
        'num_partitions': 0,
        'random_seed': 1,
        # 'symmetry': 'C7'
    }
    print("Loading dataset %s" % data_dir)
    cryodata, _ = dataset_loading_test(data_params)
    # A fixed scale is used instead of the DC-based estimate
    # ((|mleDC| + 2*mleDC_est_std) / N) from cryodata.get_dc_estimate().
    modelscale = 1.0

    if model_file is not None:
        print("Loading density map %s" % model_file)
        M = mrc.readMRC(model_file)
    else:
        print("Generating random initial density map ...")
        M = cryoem.generate_phantom_density(cryodata.N, 0.95 * cryodata.N / 2.0, \
                                            5 * cryodata.N / 128.0, 30, seed=0)
        # Normalise so the phantom's voxels sum to modelscale.
        M *= modelscale/M.sum()
    # NOTE(review): fM is the real-space M here. Elsewhere in this project
    # fM holds density.real_to_fspace(M); confirm the kernel does not need a
    # Fourier-space fM.
    fM = M

    minibatch = cryodata.get_next_minibatch(shuffle_minibatches=False)

    is_sym = get_symmetryop(data_params.get('symmetry',None))
    sampler_R = FixedFisherImportanceSampler('_R', is_sym)
    sampler_I = FixedFisherImportanceSampler('_I')
    sampler_S = None

    cparams = {
        'use_angular_correlation': use_angular_correlation,

        'iteration': 0,
        'pixel_size': cryodata.pixel_size,
        'max_frequency': 0.02,

        'interp_kernel_R': 'lanczos',
        'interp_kernel_size_R':	4,
        'interp_zeropad_R':	0,
        'interp_premult_R':	True,

        'interp_kernel_I': 'lanczos',
        'interp_kernel_size_I':	4,
        'interp_zeropad_I':	0.0,
        'interp_premult_I':	True,

        'quad_shiftsigma': 10,
        'quad_shiftextent': 60,

        'sigma': cryodata.noise_var,
        # 'symmetry': 'C7'
    }
    kernel = UnknownRSThreadedCPUKernel()
    kernel.setup(cparams, None, None, None)
    kernel.set_samplers(sampler_R, sampler_I, sampler_S)
    kernel.set_dataset(cryodata)
    # Disable precomputed slice/in-plane caches before binding the data.
    kernel.precomp_slices = None
    kernel.set_data(cparams, minibatch)
    kernel.using_precomp_slicing = False
    kernel.using_precomp_inplane = False
    kernel.M = M
    kernel.fM = fM
    return kernel