def __init__(self, cryodata, use_angular_correlation=False):
    """Attach this object to a dataset and initialise its working state.

    Args:
        cryodata: dataset object. It must expose ``dataset_params`` (a mapping)
            and either ``get_num_pixels()`` / ``get_pixel_size()`` itself or
            through an ``imgstack`` attribute.
        use_angular_correlation: flag stored unchanged on the instance.
    """
    self.cryodata = cryodata

    # Image geometry: ask the dataset directly first, and fall back to its
    # image stack when those accessors are missing.
    try:
        self.N = self.cryodata.get_num_pixels()
        self.psize = self.cryodata.get_pixel_size()
    except AttributeError:
        imgstack = self.cryodata.imgstack
        self.N = imgstack.get_num_pixels()
        self.psize = imgstack.get_pixel_size()

    self.is_sym = get_symmetryop(
        self.cryodata.dataset_params.get('symmetry', None))

    # Importance samplers; only the '_R' sampler is given the symmetry op.
    self.sampler_R = FixedFisherImportanceSampler('_R', self.is_sym)
    self.sampler_I = FixedFisherImportanceSampler('_I')

    # Slicing / correlation configuration.
    self.use_cached_slicing = True
    self.use_angular_correlation = use_angular_correlation
    self.ac_slices = None
    self.G_datatype = np.float32

    # Worker pool and per-instance caches.
    self.executor = ThreadPoolExecutor(max_workers=NUM_CORES)
    self.cached_workspace = {}
    self.cached_cphi = {}
def __init__(self, model, dataset_params, ctf_params,
             interp_params=None, load_cache=True):
    """Build an image dataset, either synthesised from a volume or loaded from disk.

    If ``model`` is given, ``dataset_params['num_images']`` Poisson-noised
    projections of the volume are generated. If ``model`` is None, images are
    read from ``dataset_params['inpath']`` (MRC) and ground-truth Euler angles
    (in degrees) from the parameter file ``dataset_params['gtpath']``.

    Args:
        model: 3-D ``np.ndarray`` volume, or None to load data from disk.
        dataset_params: dict of dataset settings. Keys read:
            synthetic path: 'num_images', 'pixel_size', 'euler_angles',
            'symmetry'. Load path: 'gtpath', 'inpath', 'num_images',
            'resolution', 'symmetry'.
        ctf_params: dict of CTF parameters, or None to disable the CTF. Only
            consulted when ``model`` is given.
        interp_params: interpolation settings forwarded to ``set_transform``.
            Defaults to lanczos, kernsize 4.0, zeropad 0, dopremult True.
        load_cache: currently unused; kept for interface compatibility.

    Raises:
        ValueError: if fewer Euler angles than ``num_images`` are supplied for
            a synthetic dataset.
    """
    # Fresh dict per call instead of a shared mutable default argument.
    if interp_params is None:
        interp_params = {'kern': 'lanczos', 'kernsize': 4.0, 'zeropad': 0, 'dopremult': True}
    self.dataset_params = dataset_params

    if model is not None:
        assert isinstance(model, np.ndarray), "Unexpected data type for input model"
        self.num_pixels = model.shape[0]
        N = self.num_pixels
        self.num_images = dataset_params['num_images']
        assert self.num_images > 1, "it's better to make num_images larger than 1."
        self.pixel_size = float(dataset_params['pixel_size'])
        euler_angles = dataset_params['euler_angles']
        self.is_sym = get_symmetryop(dataset_params.get('symmetry', None))

        if euler_angles is None and self.is_sym is None:
            # Random directions from normalised Gaussian vectors, plus a
            # uniform random third angle.
            pt = np.random.randn(self.num_images, 3)
            pt /= np.linalg.norm(pt, axis=1, keepdims=True)
            euler_angles = geometry.genEA(pt)
            euler_angles[:, 2] = 2 * np.pi * np.random.rand(self.num_images)
        elif euler_angles is None and self.is_sym is not None:
            # Rejection-sample directions until they fall in the asymmetric unit.
            euler_angles = np.zeros((self.num_images, 3))
            for ea in euler_angles:
                while True:
                    pt = np.random.randn(3)
                    pt /= np.linalg.norm(pt)
                    if self.is_sym.in_asymunit(pt.reshape(-1, 3)):
                        break
                ea[0:2] = geometry.genEA(pt)[0][0:2]
                ea[2] = 2 * np.pi * np.random.rand()

        # np.asarray also accepts user-supplied nested lists.
        self.euler_angles = np.asarray(euler_angles).reshape((-1, 3))
        if self.euler_angles.shape[0] < self.num_images:
            # Otherwise rows of the np.empty image buffer below would stay
            # uninitialised.
            raise ValueError(
                "got {0} euler angles for {1} images".format(
                    self.euler_angles.shape[0], self.num_images))

        if ctf_params is not None:
            self.use_ctf = True
            ctf_map = ctf.compute_full_ctf(None, N, ctf_params['psize'],
                                           ctf_params['akv'], ctf_params['cs'],
                                           ctf_params['wgh'], ctf_params['df1'],
                                           ctf_params['df2'], ctf_params['angast'],
                                           ctf_params['dscale'],
                                           ctf_params.get('bfactor', 500))
            self.ctf_params = copy(ctf_params)
            # The stored copy never carries 'bfactor'.
            self.ctf_params.pop('bfactor', None)
        else:
            self.use_ctf = False
            ctf_map = np.ones((N**2,), dtype=density.real_t)

        rad = 0.95
        TtoF = sincint.gentrunctofull(N=N, rad=rad)
        base_coords = geometry.gencoords(N, 2, rad)
        fM = model
        # Loop invariants: interpolation grid and 2-D CTF image.
        grid = (np.arange(N),) * 3
        ctf_image = ctf_map.reshape((N, N))

        print("Generating Dataset ... :")
        tic = time.time()
        imgdata = np.empty((self.num_images, N, N), dtype=density.real_t)
        for i, ea in zip(range(self.num_images), self.euler_angles):
            # First two columns of the 3-D rotation map the 2-D sample
            # coordinates into the volume; shift by N/2 into index space.
            R = geometry.rotmat3D_EA(*ea)[:, 0:2]
            rotated_coords = R.dot(base_coords.T).T + int(N/2)
            D = interpn(grid, fM, rotated_coords)
            np.maximum(D, 0.0, out=D)
            intensity = ctf_image * TtoF.dot(D).reshape((N, N))
            # Poisson rates must be positive.
            np.maximum(1e-8, intensity, out=intensity)
            # astype(np.float64) replaces np.float_, which NumPy 2.0 removed.
            counts = np.random.poisson(intensity).astype(np.float64)
            imgdata[i] = np.require(counts, dtype=density.real_t)
        self.imgdata = imgdata
        print(" cost {} seconds.".format(time.time()-tic))

        self.set_transform(interp_params)
    else:
        euler_angles = []
        with open(self.dataset_params['gtpath']) as par:
            par.readline()  # 'C PHI THETA PSI SHX SHY FILM DF1 DF2 ANGAST'
            # Read columns 1-3 (degrees) until the first line that cannot be
            # parsed, e.g. a blank or trailing line.
            for line in par:
                fields = line.split()
                try:
                    euler_angles.append([float(fields[1]), float(fields[2]), float(fields[3])])
                except (IndexError, ValueError):
                    break
        self.euler_angles = np.deg2rad(np.asarray(euler_angles))

        num_images = self.dataset_params.get('num_images', 200)
        imgdata = mrc.readMRCimgs(self.dataset_params['inpath'], 0, num_images)
        # Images are read with the image index on the last axis; move it first.
        self.imgdata = np.transpose(imgdata, axes=(2, 0, 1))
        self.num_images = self.imgdata.shape[0]
        self.num_pixels = self.imgdata.shape[1]
        self.pixel_size = self.dataset_params['resolution']
        # Same representation as the synthetic branch: a symmetry op (or None),
        # not the raw config value.
        self.is_sym = get_symmetryop(self.dataset_params.get('symmetry', None))
        self.use_ctf = False

        self.set_transform(interp_params)
def set_slice_quad(self, rad):
    """Get (and generate if needed) the quadrature scheme for slicing.

    Reads quadrature and interpolation settings from ``self.params``.

    - When the direction quadrature changes, rebuilds ``self.quad_domain_R``
      and re-sets up ``self.sampler_R``.
    - When the domain or the interpolation settings change, rebuilds the
      slicing operators. They are precomputed when
      ``N_R * symmetry order < self.otf_thresh_R``; otherwise slicing is
      on-the-fly.
    - When the interpolation kernel changes, recomputes ``self.slice_premult``.

    Args:
        rad: radius forwarded to the quadrature and interpolation helpers.
    """
    params = self.params
    tic = time.time()
    N = self.N

    degree_R = params.get('quad_degree_R','auto')
    quad_scheme_R = params.get('quad_type_R','sk97')
    sym = get_symmetryop(params.get('symmetry',None)) if params.get('perfect_symmetry',True) else None
    usFactor_R = params.get('quad_undersample_R',params.get('quad_undersample',1.0))

    kern_R = params.get('interp_kernel_R',params.get('interp_kernel',None))
    kernsize_R = params.get('interp_kernel_size_R',params.get('interp_kernel_size',None))
    zeropad_R = params.get('interp_zeropad_R',params.get('interp_zeropad',0))
    dopremult_R = params.get('interp_premult_R',params.get('interp_premult',True))

    quad_R = quadrature.quad_schemes[('dir',quad_scheme_R)]

    # resolution_R is clamped to at least half of compute_max_angle(N, rad).
    min_resolution_R = 0.5*quadrature.compute_max_angle(self.N,rad)
    if degree_R == 'auto':
        degree_R,resolution_R = quad_R.compute_degree(N,rad,usFactor_R)
    else:
        # With an explicit degree, resolution_R used to be left unbound
        # (UnboundLocalError). NOTE(review): falling back to the lower bound;
        # confirm this is the intended resolution for a user-chosen degree.
        resolution_R = min_resolution_R
    resolution_R = max(min_resolution_R, resolution_R)

    slice_params = { 'quad_type':quad_scheme_R, 'degree':degree_R, 'sym':sym }
    interp_params_R = { 'N':self.N, 'kern':kern_R, 'kernsize':kernsize_R, 'rad':rad, 'zeropad':zeropad_R, 'dopremult':dopremult_R }

    domain_change_R = slice_params != self.slice_params
    interp_change_R = self.slice_interp != interp_params_R
    # Only kernel, kernel size and zero padding affect the premultiplier.
    transform_change = self.slice_interp is None or \
        self.slice_interp['kern'] != interp_params_R['kern'] or \
        self.slice_interp['kernsize'] != interp_params_R['kernsize'] or \
        self.slice_interp['zeropad'] != interp_params_R['zeropad']

    if domain_change_R:
        slice_quad = {}
        slice_quad['resolution'] = resolution_R
        slice_quad['degree'] = degree_R
        slice_quad['symop'] = sym
        slice_quad['dir'],slice_quad['W'] = quad_R.get_quad_points(degree_R,slice_quad['symop'])
        slice_quad['W'] = np.require(slice_quad['W'], dtype=np.float32)

        self.quad_domain_R = quadrature.FixedSphereDomain(slice_quad['dir'], slice_quad['resolution'],
                                                          sym=sym)
        self.N_R = len(self.quad_domain_R)
        self.sampler_R.setup(params, self.N_D, self.N_D_Train, self.quad_domain_R)

        self.slice_quad = slice_quad
        self.slice_params = slice_params

    if domain_change_R or interp_change_R:
        symorder = 1 if self.slice_quad['symop'] is None else self.slice_quad['symop'].get_order()
        print(" Slice Ops: %d, " % self.N_R); sys.stdout.flush()
        if self.N_R*symorder < self.otf_thresh_R:
            # Few enough directions: precompute the slicing operators.
            self.using_precomp_slicing = True
            print("generated in", end=''); sys.stdout.flush()
            self.slice_ops = self.quad_domain_R.compute_operator(interp_params_R)
            print(" {0} secs.".format(time.time() - tic))

            Gsz = (self.N_R,self.N_T)
            self.G = np.empty(Gsz, dtype=self.G_datatype)
            self.slices = np.empty(np.prod(Gsz), dtype=np.complex64)
        else:
            # Too many: slice on the fly and keep a full N^3 work buffer.
            self.using_precomp_slicing = False
            print("generating OTF.")
            self.slice_ops = None
            self.G = np.empty((self.N,self.N,self.N),dtype=self.G_datatype)
            self.slices = None

        self.slice_interp = interp_params_R

    if transform_change:
        if dopremult_R:
            premult = cryoops.compute_premultiplier(self.N + 2*int(interp_params_R['zeropad']*(self.N/2)),
                                                    interp_params_R['kern'],interp_params_R['kernsize'])
            # Separable 1-D premultiplier expanded to 3-D.
            premult = premult.reshape((-1,1,1)) * premult.reshape((1,-1,1)) * premult.reshape((1,1,-1))
        else:
            premult = None
        self.slice_premult = premult
        self.slice_zeropad = interp_params_R['zeropad']
        assert interp_params_R['zeropad'] == 0,'Zero padding for slicing not yet implemented'
def set_proj_quad(self, rad):
    """Get (and generate if needed) the joint direction x in-plane quadrature for projection.

    Builds or updates the following from ``self.params``:

    - ``self.quad_domain_RI`` (SO(3) domain)
    - ``self.quad_domain_R`` (directions) and ``self.sampler_R``
    - ``self.quad_domain_I`` (in-plane angles) and ``self.sampler_I``
    - the projection operators, precomputed when
      ``N_RI * symmetry order < self.otf_thresh_RI``, otherwise on-the-fly
    - ``self.slice_premult``

    Args:
        rad: radius forwarded to the quadrature and interpolation helpers.
    """
    params = self.params
    tic = time.time()
    N = self.N

    quad_scheme_R = params.get('quad_type_R','sk97')
    quad_R = quadrature.quad_schemes[('dir',quad_scheme_R)]

    degree_R = params.get('quad_degree_R','auto')
    degree_I = params.get('quad_degree_I','auto')

    usFactor_R = params.get('quad_undersample_R',params.get('quad_undersample',1.0))
    usFactor_I = params.get('quad_undersample_I',params.get('quad_undersample',1.0))

    kern_R = params.get('interp_kernel_R',params.get('interp_kernel',None))
    kernsize_R = params.get('interp_kernel_size_R',params.get('interp_kernel_size',None))
    zeropad_R = params.get('interp_zeropad_R',params.get('interp_zeropad',0))
    dopremult_R = params.get('interp_premult_R',params.get('interp_premult',True))

    sym = get_symmetryop(params.get('symmetry',None)) if params.get('perfect_symmetry',True) else None

    maxAngle = quadrature.compute_max_angle(self.N,rad,usFactor_I)
    if degree_I == 'auto':
        degree_I = np.uint32(np.ceil(2.0 * np.pi / maxAngle))

    # Both resolutions are clamped to at least half of compute_max_angle(N, rad).
    min_resolution = 0.5*quadrature.compute_max_angle(self.N,rad)
    if degree_R == 'auto':
        degree_R,resolution_R = quad_R.compute_degree(N,rad,usFactor_R)
    else:
        # With an explicit degree, resolution_R used to be left unbound
        # (UnboundLocalError). NOTE(review): falling back to the lower bound;
        # confirm this is the intended resolution for a user-chosen degree.
        resolution_R = min_resolution
    resolution_R = max(min_resolution, resolution_R)
    resolution_I = max(min_resolution, 2.0*np.pi / degree_I)

    slice_params = { 'quad_type':quad_scheme_R, 'degree':degree_R, 'sym':sym }
    inplane_params = { 'degree':degree_I }
    proj_params = { 'quad_type_R':quad_scheme_R, 'degree_R':degree_R,
                    'sym':sym, 'degree_I':degree_I }
    interp_params_RI = { 'N':self.N, 'kern':kern_R, 'kernsize':kernsize_R, 'rad':rad, 'zeropad':zeropad_R, 'dopremult':dopremult_R }
    interp_change_RI = self.proj_interp != interp_params_RI
    # NOTE(review): this compares against self.slice_interp (maintained by
    # set_slice_quad), not self.proj_interp, so if only this method is used
    # the premultiplier is recomputed on every call. Confirm intent.
    transform_change = self.slice_interp is None or \
        self.slice_interp['kern'] != interp_params_RI['kern'] or \
        self.slice_interp['kernsize'] != interp_params_RI['kernsize'] or \
        self.slice_interp['zeropad'] != interp_params_RI['zeropad']
    domain_change_R = self.slice_params != slice_params
    domain_change_I = self.inplane_params != inplane_params
    domain_change_RI = self.proj_params != proj_params

    if domain_change_RI:
        proj_quad = {}

        proj_quad['resolution'] = max(resolution_R,resolution_I)
        proj_quad['degree_R'] = degree_R
        proj_quad['degree_I'] = degree_I
        proj_quad['symop'] = sym

        proj_quad['dir'],proj_quad['W_R'] = quad_R.get_quad_points(degree_R,proj_quad['symop'])
        proj_quad['W_R'] = np.require(proj_quad['W_R'], dtype=np.float32)

        proj_quad['thetas'] = np.linspace(0, 2.0*np.pi, degree_I, endpoint=False)
        # Offset by half the angular spacing. pi/degree_I equals thetas[1]/2,
        # but unlike indexing thetas[1] it also works when degree_I == 1.
        proj_quad['thetas'] += np.pi/float(degree_I)
        proj_quad['W_I'] = np.require((2.0*np.pi/float(degree_I))*np.ones((degree_I,)),dtype=np.float32)

        self.quad_domain_RI = quadrature.FixedSO3Domain( proj_quad['dir'],
                                                         -proj_quad['thetas'],
                                                         proj_quad['resolution'],
                                                         sym=sym)
        self.N_RI = len(self.quad_domain_RI)
        self.proj_quad = proj_quad
        self.proj_params = proj_params

    # Read from self.proj_quad rather than a local variable: slice_params may
    # have been changed by set_slice_quad while proj_params did not change, in
    # which case no local proj_quad exists (previously a NameError).
    if domain_change_R:
        self.quad_domain_R = quadrature.FixedSphereDomain(self.proj_quad['dir'],
                                                          self.proj_quad['resolution'],
                                                          sym=sym)
        self.N_R = len(self.quad_domain_R)
        self.sampler_R.setup(params, self.N_D, self.N_D_Train, self.quad_domain_R)
        self.slice_params = slice_params

    if domain_change_I:
        self.quad_domain_I = quadrature.FixedCircleDomain(self.proj_quad['thetas'],
                                                          self.proj_quad['resolution'])
        self.N_I = len(self.quad_domain_I)
        self.sampler_I.setup(params, self.N_D, self.N_D_Train, self.quad_domain_I)
        self.inplane_params = inplane_params

    if domain_change_RI or interp_change_RI:
        symorder = 1 if self.proj_quad['symop'] is None else self.proj_quad['symop'].get_order()
        print(" Projection Ops: %d (%d slice, %d inplane), " % (self.N_RI, self.N_R, self.N_I)); sys.stdout.flush()
        if self.N_RI*symorder < self.otf_thresh_RI:
            # Few enough rotations: precompute the projection operators.
            self.using_precomp_slicing = True
            print("generated in", end=''); sys.stdout.flush()
            self.slice_ops = self.quad_domain_RI.compute_operator(interp_params_RI)
            print(" {0} secs.".format(time.time() - tic))

            Gsz = (self.N_RI,self.N_T)
            self.G = np.empty(Gsz, dtype=self.G_datatype)
            self.slices = np.empty(np.prod(Gsz), dtype=np.complex64)
        else:
            # Too many: slice on the fly and keep a full N^3 work buffer.
            self.using_precomp_slicing = False
            print("generating OTF.")
            self.slice_ops = None
            self.G = np.empty((N,N,N),dtype=self.G_datatype)
            self.slices = None

        self.using_precomp_inplane = False
        self.inplane_ops = None

        self.proj_interp = interp_params_RI

    if transform_change:
        if dopremult_R:
            premult = cryoops.compute_premultiplier(self.N + 2*int(interp_params_RI['zeropad']*(self.N/2)),
                                                    interp_params_RI['kern'],interp_params_RI['kernsize'])
            # Separable 1-D premultiplier expanded to 3-D.
            premult = premult.reshape((-1,1,1)) * premult.reshape((1,-1,1)) * premult.reshape((1,1,-1))
        else:
            premult = None
        self.slice_premult = premult
        self.slice_zeropad = interp_params_RI['zeropad']
        assert interp_params_RI['zeropad'] == 0,'Zero padding for slicing not yet implemented'
def dowork(self):
    """Do one atom of work, i.e. execute one minibatch.

    Steps:

    - Fetch the next training minibatch and optionally learn prior/likelihood
      hyperparameters.
    - Optionally evaluate the test batch.
    - Take one optimizer step on the density, then apply symmetry and the
      density bounds.
    - Update the importance samplers and record statistics and diagnostics.
    - When saving is due, queue output files on the background IO thread.

    Returns:
        True while ``iteration < max_iterations`` and ``cepoch < max_epochs``
        (each defaulting to infinity), i.e. whether more work remains.
    """
    timing = {}
    # Time each minibatch
    tic_mini = time.time()

    self.iteration += 1

    # Fetch the current batches
    trainbatch = self.cryodata.get_next_minibatch(self.cparams.get('shuffle_minibatches',True))

    # Get the current epoch
    cepoch = self.cryodata.get_epoch(frac=True)
    epoch = self.cryodata.get_epoch()
    num_data = self.cryodata.N_D_Train

    # Evaluate the parameters
    self.eval_params()
    timing['setup'] = time.time() - tic_mini

    # Do hyperparameter learning
    if self.cparams.get('learn_params',False):
        tic_learn = time.time()
        if self.cparams.get('learn_prior_params',True):
            tic_learn_prior = time.time()
            self.prior_func.learn_params(self.params, self.cparams, M=self.M, fM=self.fM)
            timing['learn_prior'] = time.time() - tic_learn_prior

        if self.cparams.get('learn_likelihood_params',True):
            tic_learn_like = time.time()
            self.like_func.learn_params(self.params, self.cparams, M=self.M, fM=self.fM)
            timing['learn_like'] = time.time() - tic_learn_like

        if self.cparams.get('learn_prior_params',True) or self.cparams.get('learn_likelihood_params',True):
            timing['learn_total'] = time.time() - tic_learn

    # Time each epoch
    if self.tic_epoch is None:
        self.ostream("Epoch: %d" % epoch)
        self.tic_epoch = (tic_mini,epoch)
    elif self.tic_epoch[1] != epoch:
        self.ostream("Epoch Total - %.6f seconds " % (tic_mini - self.tic_epoch[0]))
        self.tic_epoch = (tic_mini,epoch)

    sym = get_symmetryop(self.cparams.get('symmetry',None))
    if sym is not None:
        self.obj.ws[1] = 1.0/sym.get_order()

    tic_mstats = time.time()
    self.ostream(self.cparams['name']," Iteration:", self.iteration,
                 " Epoch:", epoch, " Host:", socket.gethostname())

    # Compute density statistics
    N = self.cryodata.N
    M_sum = self.M.sum(dtype=n.float64)
    # NOTE(review): this counts ZERO voxels but is reported under the key
    # 'nonzero_density' below; confirm which one is intended.
    M_zeros = (self.M == 0).sum()
    M_mean = M_sum/N**3
    M_max = self.M.max()
    M_min = self.M.min()
    self.statout.output(total_density=[M_sum],
                        avg_density=[M_mean],
                        nonzero_density=[M_zeros],
                        max_density=[M_max],
                        min_density=[M_min])
    timing['density_stats'] = time.time() - tic_mstats

    # evaluate test batch if requested
    # NOTE(review): the default self.iteration%5 is truthy on 4 out of every 5
    # iterations; if "every 5th iteration" was intended it should be == 0.
    if self.iteration <= 1 or self.cparams.get('evaluate_test_set',self.iteration%5):
        tic_test = time.time()
        testbatch = self.cryodata.get_testbatch()

        self.obj.set_data(self.cparams,testbatch)
        testLogP, res_test = self.obj.eval(M=self.M, fM=self.fM,
                                           compute_gradient=False)

        self.outputbatchinfo(testbatch, res_test, testLogP, 'test', 'Test')
        timing['test_batch'] = time.time() - tic_test
    else:
        testLogP, res_test = None, None

    # setup the wrapper for the objective function
    tic_objsetup = time.time()
    self.obj.set_data(self.cparams,trainbatch)
    self.obj_wrapper.set_objective(self.obj)
    x0 = self.obj_wrapper.set_density(self.M,self.fM)
    evalobj = self.obj_wrapper.eval_obj
    timing['obj_setup'] = time.time() - tic_objsetup

    # Get step size
    self.num_data_evals += trainbatch['N_M']  # at least one gradient
    tic_objstep = time.time()
    trainLogP, dlogP, v, res_train, extra_num_data = self.step.do_step(x0,
                                                                       self.cparams,
                                                                       self.cryodata,
                                                                       evalobj,
                                                                       batch=trainbatch)

    # Apply the step
    x = x0 + v
    timing['step'] = time.time() - tic_objstep

    # Convert from parameters to value
    tic_stepfinalize = time.time()
    prevM = n.copy(self.M)
    self.M, self.fM = self.obj_wrapper.convert_parameter(x,comp_real=True)

    apply_sym = sym is not None and self.cparams.get('perfect_symmetry',True) and self.cparams.get('apply_symmetry',True)
    if apply_sym:
        self.M = sym.apply(self.M)

    # Truncate the density to bounds if they exist
    if self.cparams['density_lb'] is not None:
        n.maximum(self.M,self.cparams['density_lb']*self.cparams['modelscale'],out=self.M)
    if self.cparams['density_ub'] is not None:
        n.minimum(self.M,self.cparams['density_ub']*self.cparams['modelscale'],out=self.M)

    # Compute net change
    self.dM = prevM - self.M

    # Convert to fourier space (may not be required): symmetrisation and
    # clamping modify M after convert_parameter produced fM.
    if self.fM is None or apply_sym \
            or self.cparams['density_lb'] is not None \
            or self.cparams['density_ub'] is not None:
        self.fM = density.real_to_fspace(self.M)
    timing['step_finalize'] = time.time() - tic_stepfinalize

    # Compute step statistics
    tic_stepstats = time.time()
    step_size = n.linalg.norm(self.dM)
    grad_size = n.linalg.norm(dlogP)
    M_norm = n.linalg.norm(self.M)

    self.num_data_evals += extra_num_data
    inc_ratio = step_size / M_norm
    self.statout.output(step_size=[step_size],
                        inc_ratio=[inc_ratio],
                        grad_size=[grad_size],
                        norm_density=[M_norm])
    timing['step_stats'] = time.time() - tic_stepstats

    # Update import sampling distributions
    tic_isupdate = time.time()
    self.sampler_R.perform_update()
    self.sampler_I.perform_update()
    self.sampler_S.perform_update()

    self.diagout.output(global_phi_R=self.sampler_R.get_global_dist())
    self.diagout.output(global_phi_I=self.sampler_I.get_global_dist())
    self.diagout.output(global_phi_S=self.sampler_S.get_global_dist())
    timing['is_update'] = time.time() - tic_isupdate

    # Output basic diagnostics
    tic_diagnostics = time.time()
    self.diagout.output(iteration=self.iteration, epoch=epoch, cepoch=cepoch)

    # Running sums are (re)sized when the number of batches changes.
    if self.logpost_history.N_sum != self.cryodata.N_batches:
        self.logpost_history.setup(trainLogP,self.cryodata.N_batches)
    self.logpost_history.set_value(trainbatch['id'],trainLogP)

    if self.like_history.N_sum != self.cryodata.N_batches:
        self.like_history.setup(res_train['L'],self.cryodata.N_batches)
    self.like_history.set_value(trainbatch['id'],res_train['L'])

    self.outputbatchinfo(trainbatch, res_train, trainLogP, 'train', 'Train')

    # Dump parameters here to catch the defaults used in evaluation
    self.diagout.output(params=self.cparams,
                        envelope_mle=self.like_func.get_envelope_mle(),
                        sigma2_mle=self.like_func.get_sigma2_mle(),
                        hostname=socket.gethostname())
    self.statout.output(num_data=[num_data],
                        num_data_evals=[self.num_data_evals],
                        iteration=[self.iteration],
                        epoch=[epoch],
                        cepoch=[cepoch],
                        logp=[self.logpost_history.get_mean()],
                        like=[self.like_history.get_mean()],
                        sigma=[self.like_func.get_rmse()],
                        time=[time.time()])
    timing['diagnostics'] = time.time() - tic_diagnostics

    checkpoint_it = self.iteration % self.cparams.get('checkpoint_frequency',50) == 0
    save_it = checkpoint_it or self.cparams['save_iteration'] or \
        time.time() - self.last_save > self.cparams.get('save_time',n.inf)

    if save_it:
        tic_save = time.time()
        self.last_save = tic_save
        if self.io_queue.qsize():
            # Function-call form works on both Python 2 and 3 (the old print
            # statement is a syntax error in Python 3).
            print("Warning: IO queue has become backlogged with {0} remaining, waiting for it to clear".format(self.io_queue.qsize()))
            self.io_queue.join()
        self.io_queue.put(( 'pkl', self.statout.fname, copy(self.statout.outdict) ))
        self.io_queue.put(( 'pkl', self.diagout.fname, deepcopy(self.diagout.outdict) ))
        self.io_queue.put(( 'pkl', self.likeout.fname, deepcopy(self.likeout.outdict) ))
        self.io_queue.put(( 'mrc', opj(self.outbase,'model.mrc'),
                            (n.require(self.M,dtype=density.real_t),self.voxel_size) ))
        self.io_queue.put(( 'mrc', opj(self.outbase,'dmodel.mrc'),
                            (n.require(self.dM,dtype=density.real_t),self.voxel_size) ))

        if checkpoint_it:
            self.io_queue.put(( 'cp', self.diagout.fname, self.diagout.fname+'-{0:06}'.format(self.iteration) ))
            self.io_queue.put(( 'cp', self.likeout.fname, self.likeout.fname+'-{0:06}'.format(self.iteration) ))
            self.io_queue.put(( 'cp', opj(self.outbase,'model.mrc'), opj(self.outbase,'model-{0:06}.mrc'.format(self.iteration)) ))
        timing['save'] = time.time() - tic_save

    time_total = time.time() - tic_mini
    self.ostream(" Minibatch Total - %.2f seconds Total Runtime - %s" %
                 (time_total, format_timedelta(datetime.now() - self.startdatetime) ))

    return self.iteration < self.cparams.get('max_iterations',n.inf) and \
        cepoch < self.cparams.get('max_epochs',n.inf)
def __init__(self, expbase, cmdparams=None):
    """Set up an experiment: parameters, dataset, output streams, initial model and objective.

    expbase is a path to the base of folder where this experiment's files will
    be stored. The folder above expbase will also be searched for .params files.
    These will be loaded first.

    Args:
        expbase: experiment directory. ``*.params`` files are collected from
            its parent first, then from expbase itself. ``local.params`` (if
            present) is appended once more at the end.
        cmdparams: optional mapping of parameter overrides (e.g. from the
            command line), applied after the parameter files are loaded.
    """
    BackgroundWorker.__init__(self)

    # Create a background thread which handles IO
    self.io_queue = Queue()
    self.io_thread = Thread(target=self.ioworker)
    self.io_thread.daemon = True
    self.io_thread.start()

    # General setup ----------------------------------------------------
    self.expbase = expbase
    self.outbase = None

    # Parameter setup ---------------------------------------------------
    # next(os.walk(...)) works on Python 2 and 3; os.walk(...).next() is
    # Python 2 only.
    # search above expbase for params files
    parentdir = opj(expbase,'../')
    _,_,filenames = next(os.walk(parentdir))
    self.paramfiles = [opj(parentdir, fname)
                       for fname in filenames if fname.endswith('.params')]
    # search expbase for params files
    _,_,filenames = next(os.walk(opj(expbase)))
    self.paramfiles += [opj(expbase,fname)
                        for fname in filenames if fname.endswith('.params')]
    if 'local.params' in filenames:
        # Appended again so that it is the last file in the list.
        self.paramfiles += [opj(expbase,'local.params')]

    # load parameter files
    self.params = Params(self.paramfiles)
    self.cparams = None

    if cmdparams is not None:
        # Set parameter specified on the command line
        for k,v in cmdparams.items():
            self.params[k] = v

    # Dataset setup -------------------------------------------------------
    self.imgpath = self.params['inpath']
    psize = self.params['resolution']
    if not isinstance(self.imgpath,list):
        imgstk = MRCImageStack(self.imgpath,psize)
    else:
        imgstk = CombinedImageStack([MRCImageStack(cimgpath,psize) for cimgpath in self.imgpath])

    if self.params.get('float_images',True):
        imgstk.float_images()

    self.ctfpath = self.params['ctfpath']
    mscope_params = self.params['microscope_params']

    if not isinstance(self.ctfpath,list):
        ctfstk = CTFStack(self.ctfpath,mscope_params)
    else:
        ctfstk = CombinedCTFStack([CTFStack(cctfpath,mscope_params) for cctfpath in self.ctfpath])

    self.cryodata = CryoDataset(imgstk,ctfstk)
    self.cryodata.compute_noise_statistics()
    if self.params.get('window_images',True):
        imgstk.window_images()
    minibatch_size = self.params['minisize']
    testset_size = self.params['test_imgs']
    partition = self.params.get('partition',0)
    num_partitions = self.params.get('num_partitions',1)
    seed = self.params['random_seed']
    # NOTE(security): string-valued parameters are evaluated as Python
    # expressions; only use trusted .params files / command lines.
    if isinstance(partition,str):
        partition = eval(partition)
    if isinstance(num_partitions,str):
        num_partitions = eval(num_partitions)
    if isinstance(seed,str):
        seed = eval(seed)
    self.cryodata.divide_dataset(minibatch_size,testset_size,partition,num_partitions,seed)

    self.cryodata.set_datasign(self.params.get('datasign','auto'))
    if self.params.get('normalize_data',True):
        self.cryodata.normalize_dataset()

    self.voxel_size = self.cryodata.pixel_size

    # Iterations setup -------------------------------------------------
    self.iteration = 0
    self.tic_epoch = None
    self.num_data_evals = 0
    self.eval_params()

    outdir = self.cparams.get('outdir',None)
    if outdir is None:
        if self.cparams.get('num_partitions',1) > 1:
            outdir = 'partition{0}'.format(self.cparams['partition'])
        else:
            outdir = ''
    self.outbase = opj(self.expbase,outdir)
    if not os.path.isdir(self.outbase):
        os.makedirs(self.outbase)

    # Output setup -----------------------------------------------------
    self.ostream = OutputStream(opj(self.outbase,'stdout'))

    self.ostream(80*"=")
    self.ostream("Experiment: " + expbase +
                 " Kernel: " + self.params['kernel'])
    self.ostream("Started on " + socket.gethostname() +
                 " At: " + time.strftime('%B %d %Y: %I:%M:%S %p'))
    self.ostream("Git SHA1: " + gitutil.git_get_SHA1())
    self.ostream(80*"=")
    gitutil.git_info_dump(opj(self.outbase, 'gitinfo'))
    self.startdatetime = datetime.now()

    # for diagnostics and parameters
    self.diagout = Output(opj(self.outbase, 'diag'),runningout=False)
    # for stats (per image etc)
    self.statout = Output(opj(self.outbase, 'stat'),runningout=True)
    # for likelihoods of individual images
    self.likeout = Output(opj(self.outbase, 'like'),runningout=False)

    self.img_likes = n.empty(self.cryodata.N_D)
    self.img_likes[:] = n.inf

    # optimization state vars ------------------------------------------
    # Progress messages use sys.stdout.write / single-argument print(...) so
    # they behave the same on Python 2 and 3.
    init_model = self.cparams.get('init_model',None)
    if init_model is not None:
        filename = init_model
        if filename.upper().endswith('.MRC'):
            M = readMRC(filename)
        else:
            try:
                import cPickle as pickle
            except ImportError:
                import pickle
            # NOTE(security): unpickling can execute arbitrary code; only
            # load model files from trusted sources. Binary mode is required
            # for pickles on Python 3.
            with open(filename, 'rb') as fp:
                M = pickle.load(fp)
            if isinstance(M, list):
                M = M[-1]['M']
        if M.shape != 3*(self.cryodata.N,):
            M = cryoem.resize_ndarray(M,3*(self.cryodata.N,),axes=(0,1,2))
    else:
        init_seed = self.cparams.get('init_random_seed',0) + self.cparams.get('partition',0)
        sys.stdout.write("Randomly generating initial density (init_random_seed = {0})... ".format(init_seed)); sys.stdout.flush()
        tic = time.time()
        M = cryoem.generate_phantom_density(self.cryodata.N, 0.95*self.cryodata.N/2.0,
                                            5*self.cryodata.N/128.0, 30, seed=init_seed)
        print("done in {0}s".format(time.time() - tic))

    tic = time.time()
    sys.stdout.write("Windowing and aligning initial density... "); sys.stdout.flush()
    # window the initial density
    wfunc = self.cparams.get('init_window','circle')
    cryoem.window(M,wfunc)

    # Center and orient the initial density
    cryoem.align_density(M)
    print("done in {0:.2f}s".format(time.time() - tic))

    # apply the symmetry operator
    init_sym = get_symmetryop(self.cparams.get('init_symmetry',self.cparams.get('symmetry',None)))
    if init_sym is not None:
        tic = time.time()
        sys.stdout.write("Applying symmetry operator... "); sys.stdout.flush()
        M = init_sym.apply(M)
        print("done in {0:.2f}s".format(time.time() - tic))

    tic = time.time()
    sys.stdout.write("Scaling initial model... "); sys.stdout.flush()
    modelscale = self.cparams.get('modelscale','auto')
    mleDC, _, mleDC_est_std = self.cryodata.get_dc_estimate()
    if modelscale == 'auto':
        # Err on the side of a weaker prior by using a larger value for modelscale
        modelscale = (n.abs(mleDC) + 2*mleDC_est_std)/self.cryodata.N
        sys.stdout.write("estimated modelscale = {0:.3g}... ".format(modelscale)); sys.stdout.flush()
        self.params['modelscale'] = modelscale
        self.cparams['modelscale'] = modelscale
    # Rescale so that the total density equals modelscale.
    M *= modelscale/M.sum()
    print("done in {0:.2f}s".format(time.time() - tic))
    if mleDC_est_std/n.abs(mleDC) > 0.05:
        print(" WARNING: the DC component estimate has a high relative variance, it may be inaccurate!")
    if ((modelscale*self.cryodata.N - n.abs(mleDC)) / mleDC_est_std) > 3:
        print(" WARNING: the selected modelscale value is more than 3 std devs different than the estimated one. Be sure this is correct.")

    self.M = n.require(M,dtype=density.real_t)
    self.fM = density.real_to_fspace(M)
    self.dM = density.zeros_like(self.M)

    # NOTE(security): 'optim_algo' is evaluated as a Python expression.
    self.step = eval(self.cparams['optim_algo'])
    self.step.setup(self.cparams, self.diagout, self.statout, self.ostream)

    # Objective function setup --------------------------------------------
    param_type = self.cparams.get('parameterization','real')
    cplx_param = param_type in ['complex','complex_coeff','complex_herm_coeff']
    self.like_func = eval_objective(self.cparams['likelihood'])
    self.prior_func = eval_objective(self.cparams['prior'])

    if self.cparams.get('penalty',None) is not None:
        self.penalty_func = eval_objective(self.cparams['penalty'])
        prior_func = SumObjectives(self.prior_func.fspace,
                                   [self.penalty_func,self.prior_func], None)
    else:
        prior_func = self.prior_func

    self.obj = SumObjectives(cplx_param,
                             [self.like_func,prior_func], [None,None])
    self.obj.setup(self.cparams, self.diagout, self.statout, self.ostream)
    self.obj.set_dataset(self.cryodata)
    self.obj_wrapper = ObjectiveWrapper(param_type)

    self.last_save = time.time()

    self.logpost_history = FiniteRunningSum()
    self.like_history = FiniteRunningSum()

    # Importance Samplers -------------------------------------------------
    self.is_sym = get_symmetryop(self.cparams.get('is_symmetry',self.cparams.get('symmetry',None)))
    self.sampler_R = FixedFisherImportanceSampler('_R',self.is_sym)
    self.sampler_I = FixedFisherImportanceSampler('_I')
    self.sampler_S = FixedGaussianImportanceSampler('_S')
    self.like_func.set_samplers(sampler_R=self.sampler_R,sampler_I=self.sampler_I,sampler_S=self.sampler_S)
def sagd_dostep(data_dir, model_file, use_angular_correlation=True,
                num_iterations=10):
    """Run ``num_iterations`` SAGD minibatch steps and return the resulting density.

    Args:
        data_dir, model_file, use_angular_correlation: forwarded to ``sagd_init``.
        num_iterations: number of minibatch steps (default 10, the value that
            used to be hard-coded).

    Returns:
        ``(M, fM)`` after the last step, as produced by
        ``obj_wrapper.convert_parameter``. When ``num_iterations`` is 0, the
        initial values from ``sagd_init`` are returned. (Previously the result
        was discarded and None was returned.)
    """
    cryodata, (M, fM), cparams = sagd_init(data_dir, model_file, use_angular_correlation)

    iteration = cparams['iteration']
    # NOTE(review): the numeric schedule entries below are computed once from
    # the starting iteration and not refreshed in the loop, while 'iteration'
    # is overwritten with the loop index. Confirm this is intended.
    sagd_params = {
        'iteration': iteration,
        'exp_path': 'exp/',
        'num_batches': cryodata.N_batches,
        # Expression strings; whitespace inside them is irrelevant.
        'sagd_linesearch': 'max_frequency_changed or ((iteration%5 == 0) if iteration < 2500 else (iteration%3 == 0) if iteration < 5000 else True)',
        'sagd_linesearch_accuracy': 1.01 if iteration < 10 else
                                    1.10 if iteration < 2500 else
                                    1.25 if iteration < 5000 else
                                    1.50,
        'sagd_linesearch_maxits': 5 if iteration < 2500 else 3,
        'sagd_incL': 1.0,
        'sagd_momentum': 1 - 1.0/(1.0 + 0.1 * iteration),
        'shuffle_minibatches': 'iteration >= 1000',
    }

    # initial logging
    if os.path.exists('exp/sagd_L0.pkl'):
        os.remove('exp/sagd_L0.pkl')

    # for diagnostics and parameters,
    # for stats (per image etc),
    # for likelihoods of individual images
    diagout = Output(os.path.join('exp', 'diag'), runningout=False)
    statout = Output(os.path.join('exp', 'stat'), runningout=True)
    # Created as in the original setup even though it is not used below.
    likeout = Output(os.path.join('exp', 'like'), runningout=False)
    ostream = OutputStream(None)

    # Setup SAGD optimizer
    step = SAGDStep()
    step.setup(cparams, diagout, statout, ostream)

    # Objective function setup
    param_type = cparams.get('parameterization', 'real')
    cplx_param = param_type in [
        'complex', 'complex_coeff', 'complex_herm_coeff'
    ]
    like_func = eval_objective(cparams['likelihood'])
    prior_func = eval_objective(cparams['prior'])

    obj = SumObjectives(cplx_param, [like_func, prior_func], [None, None])
    obj.setup(cparams, diagout, statout, ostream)
    obj.set_dataset(cryodata)
    obj_wrapper = ObjectiveWrapper(param_type)

    is_sym = get_symmetryop(
        cparams.get('is_symmetry', cparams.get('symmetry', None)))
    sampler_R = FixedFisherImportanceSampler('_R', is_sym)
    sampler_I = FixedFisherImportanceSampler('_I')
    # No shift sampler is used here.
    sampler_S = None
    like_func.set_samplers(sampler_R=sampler_R, sampler_I=sampler_I, sampler_S=sampler_S)

    # Start iteration
    for i in range(num_iterations):
        cparams['iteration'] = i
        sagd_params['iteration'] = i
        print('Iteration #:', i)

        minibatch = cryodata.get_next_minibatch(True)

        # setup the wrapper for the objective function
        obj.set_data(cparams, minibatch)
        obj_wrapper.set_objective(obj)
        x0 = obj_wrapper.set_density(M, fM)
        evalobj = obj_wrapper.eval_obj

        # Get step size (only the step v is used here)
        _, _, v, _, _ = step.do_step(x0, sagd_params, cryodata, evalobj, batch=minibatch)

        # Apply the step and convert from parameters to value
        M, fM = obj_wrapper.convert_parameter(x0 + v, comp_real=True)

        # Update import sampling distributions
        sampler_R.perform_update()
        sampler_I.perform_update()

    return M, fM
def updatevis(self, levels=None):
    """Redraw all diagnostic figures from the current model state.

    Returns immediately unless ``self.M``, ``self.diag`` and ``self.stat`` are
    all set. Otherwise it refreshes these views:

    * the objective, error and noise plots;
    * the envelope plot, or closes it when ``'envelope_mle'`` is absent;
    * the density slices, aligned first only when no symmetry is configured;
    * the 3D contour view;
    * the global importance-sampling distributions;
    * optionally, step statistics (``self.show_grad``) and density statistics
      (``self.extra_plots``).

    :param levels: contour levels passed to ``plot_density``; defaults to
        ``[0.2, 0.5, 0.8]``. A fresh list is built per call so the default is
        never shared between calls.
    """
    if self.M is None or self.diag is None or self.stat is None:
        return
    if levels is None:
        levels = [0.2, 0.5, 0.8]

    cdiag = self.diag
    cparams = cdiag['params']
    sym = get_symmetryop(cparams.get('symmetry', None))
    # The quadrature only uses the symmetry when it is assumed to be exact.
    quad_sym = sym if cparams.get('perfect_symmetry', True) else None

    resolution = cparams['voxel_size']
    name = cparams['name']
    maxfreq = cparams['max_frequency']
    N = self.M.shape[0]
    rad_cutoff = cparams.get('rad_cutoff', 1.0)
    rad = min(rad_cutoff, maxfreq * 2.0 * resolution)

    # Objective function, error and noise
    self.show_objective_plot(self.get_figure('stats'))
    self.show_error_plot(self.get_figure('error'))
    self.show_noise_plot(self.get_figure('noise'))

    # Envelope function, only if it was estimated
    if 'envelope_mle' in cdiag:
        self.show_envelope_plot(self.get_figure('envelope'))
    else:
        self.close_figure('envelope')

    if sym is None:
        assert quad_sym is None
        alignedM, R = cryoem.align_density(self.M)
        aligneddM = cryoem.rotate_density(self.dM, R) if self.show_grad else None
    else:
        alignedM, aligneddM = self.M, self.dM
        R = np.identity(3)
    self.alignedM, self.aligneddM, self.alignedR = alignedM, aligneddM, R
    self.fM = density.real_to_fspace(self.M)
    self.figMslices.set_data(alignedM)

    glbl_phi_R = np.array([cdiag['global_phi_R']]).ravel()
    if len(glbl_phi_R) == 1:
        glbl_phi_R = None
    glbl_phi_I = cdiag['global_phi_I']
    glbl_phi_S = cdiag['global_phi_S'] if 'global_phi_S' in cdiag else None

    # Direction quadrature
    quad_R = quadrature.quad_schemes[('dir', cparams.get('quad_type_R', 'sk97'))]
    quad_degree_R = cparams.get('quad_degree_R', 'auto')
    if quad_degree_R == 'auto':
        usFactor_R = cparams.get('quad_undersample_R', cparams.get('quad_undersample', 1.0))
        quad_degree_R, _ = quad_R.compute_degree(N, rad, usFactor_R)
    origlebDirs, _ = quad_R.get_quad_points(quad_degree_R, quad_sym)
    lebDirs = np.dot(origlebDirs, R)
    # glbl_phi_R is None when only a single value was stored; len(None) raises.
    vmax_R = 5.0 / len(glbl_phi_R) if glbl_phi_R is not None else None

    # Shift quadrature, only when shift distributions are available
    if glbl_phi_S is not None:
        quad_S = quadrature.quad_schemes[('shift', cparams.get('quad_type_S', 'hermite'))]
        quad_degree_S = cparams.get('quad_degree_S', 'auto')
        if quad_degree_S == 'auto':
            usFactor_S = cparams.get('quad_undersample_S', cparams.get('quad_undersample', 1.0))
            quad_degree_S = quad_S.get_degree(N, rad,
                                              cparams['quad_shiftsigma'] / resolution,
                                              cparams['quad_shiftextent'] / resolution,
                                              usFactor_S)
        pts_S, _ = quad_S.get_quad_points(quad_degree_S,
                                          cparams['quad_shiftsigma'] / resolution,
                                          cparams['quad_shiftextent'] / resolution,
                                          cparams.get('quad_shifttrunc', 'circ'))
        vmax_S = 5.0 / len(glbl_phi_S)
    else:
        pts_S = np.zeros_like([0])
        vmax_S = None

    # Density visualization
    mlab.figure(self.fig1)
    mlab.clf()
    self.curr_contours = plot_density(alignedM, self.contours, levels)
    center = alignedM.shape[0] / 2.0
    mlab.view(focalpoint=[center, center, center], distance=1.5 * alignedM.shape[0])

    if glbl_phi_R is not None:
        plt.figure(self.get_figure('global_is_dists').number)
        plt.clf()
        plot_importance_dists(name, lebDirs, pts_S * resolution, glbl_phi_R,
                              glbl_phi_I, glbl_phi_S, vmax_R, vmax_S)

    if self.show_grad:
        # Statistics of dM
        self.figdMslices.set_data(aligneddM)
        plt.figure(self.get_figure('step_stats').number)
        plt.clf()
        plt.suptitle(name + ' Step Statistics')

        plt.subplot(1, 2, 1)
        # numpy/matplotlib require an integer bin count.
        nbins = max(1, int(0.5 * self.dM.shape[0]))
        plt.hist(self.dM.reshape((-1,)), bins=nbins, log=True)
        plt.title('Voxel Histogram')

        (fs, raps) = rot_power_spectra(self.dM, resolution=resolution)
        plt.subplot(1, 2, 2)
        plt.plot(fs / (N / 2.0) / (2.0 * resolution), raps, label='RAPS')
        positive = raps[raps > 0]
        if positive.size:
            # Vertical marker at the current frequency cutoff.
            plt.plot((rad / (2.0 * resolution)) * np.ones((2,)),
                     np.array([positive.min(), raps.max()]))
        plt.yscale('log')
        plt.title('RAPS Step')

    if not self.extra_plots:
        self.close_figure('density_stats')
        return

    # Statistics of M
    self.show_density_plot(self.get_figure('density_stats'))
def updatevis(self, levels=None):
    """Redraw all diagnostic figures from the current model state.

    Returns immediately unless ``self.M``, ``self.diag`` and ``self.stat`` are
    all set. Otherwise it refreshes these views:

    * the objective, error and noise plots;
    * the envelope plot, or closes it when ``'envelope_mle'`` is absent;
    * the density slices, aligned first only when no symmetry is configured;
    * the 3D contour view;
    * the global importance-sampling distributions;
    * optionally, step statistics (``self.show_grad``) and density statistics
      (``self.extra_plots``).

    :param levels: contour levels passed to ``plot_density``; defaults to
        ``[0.2, 0.5, 0.8]``. A fresh list is built per call so the default is
        never shared between calls.
    """
    if self.M is None or self.diag is None or self.stat is None:
        return
    if levels is None:
        levels = [0.2, 0.5, 0.8]

    cdiag = self.diag
    cparams = cdiag['params']
    sym = get_symmetryop(cparams.get('symmetry', None))
    # The quadrature only uses the symmetry when it is assumed to be exact.
    quad_sym = sym if cparams.get('perfect_symmetry', True) else None

    resolution = cparams['voxel_size']
    name = cparams['name']
    maxfreq = cparams['max_frequency']
    N = self.M.shape[0]
    rad_cutoff = cparams.get('rad_cutoff', 1.0)
    rad = min(rad_cutoff, maxfreq * 2.0 * resolution)

    # Objective function, error and noise
    self.show_objective_plot(self.get_figure('stats'))
    self.show_error_plot(self.get_figure('error'))
    self.show_noise_plot(self.get_figure('noise'))

    # Envelope function, only if it was estimated
    if 'envelope_mle' in cdiag:
        self.show_envelope_plot(self.get_figure('envelope'))
    else:
        self.close_figure('envelope')

    if sym is None:
        assert quad_sym is None
        alignedM, R = c.align_density(self.M)
        aligneddM = c.rotate_density(self.dM, R) if self.show_grad else None
    else:
        alignedM, aligneddM = self.M, self.dM
        R = n.identity(3)
    self.alignedM, self.aligneddM, self.alignedR = alignedM, aligneddM, R
    self.fM = density.real_to_fspace(self.M)
    self.figMslices.set_data(alignedM)

    glbl_phi_R = n.array([cdiag['global_phi_R']]).ravel()
    if len(glbl_phi_R) == 1:
        glbl_phi_R = None
    glbl_phi_I = cdiag['global_phi_I']
    # Shift distributions may be absent (e.g. shifts not sampled).
    glbl_phi_S = cdiag['global_phi_S'] if 'global_phi_S' in cdiag else None

    # Direction quadrature
    quad_R = quadrature.quad_schemes[('dir', cparams.get('quad_type_R', 'sk97'))]
    quad_degree_R = cparams.get('quad_degree_R', 'auto')
    if quad_degree_R == 'auto':
        usFactor_R = cparams.get('quad_undersample_R', cparams.get('quad_undersample', 1.0))
        quad_degree_R, _ = quad_R.compute_degree(N, rad, usFactor_R)
    origlebDirs, _ = quad_R.get_quad_points(quad_degree_R, quad_sym)
    lebDirs = n.dot(origlebDirs, R)
    # glbl_phi_R is None when only a single value was stored; len(None) raises.
    vmax_R = 5.0 / len(glbl_phi_R) if glbl_phi_R is not None else None

    # Shift quadrature, only when shift distributions are available
    if glbl_phi_S is not None:
        quad_S = quadrature.quad_schemes[('shift', cparams.get('quad_type_S', 'hermite'))]
        quad_degree_S = cparams.get('quad_degree_S', 'auto')
        if quad_degree_S == 'auto':
            usFactor_S = cparams.get('quad_undersample_S', cparams.get('quad_undersample', 1.0))
            quad_degree_S = quad_S.get_degree(N, rad,
                                              cparams['quad_shiftsigma'] / resolution,
                                              cparams['quad_shiftextent'] / resolution,
                                              usFactor_S)
        pts_S, _ = quad_S.get_quad_points(quad_degree_S,
                                          cparams['quad_shiftsigma'] / resolution,
                                          cparams['quad_shiftextent'] / resolution,
                                          cparams.get('quad_shifttrunc', 'circ'))
        vmax_S = 5.0 / len(glbl_phi_S)
    else:
        pts_S = n.zeros_like([0])
        vmax_S = None

    # Density visualization
    mlab.figure(self.fig1)
    mlab.clf()
    self.curr_contours = plot_density(alignedM, self.contours, levels)
    center = alignedM.shape[0] / 2.0
    mlab.view(focalpoint=[center, center, center], distance=1.5 * alignedM.shape[0])

    if glbl_phi_R is not None:
        plt.figure(self.get_figure('global_is_dists').number)
        plt.clf()
        plot_importance_dists(name, lebDirs, pts_S * resolution, glbl_phi_R,
                              glbl_phi_I, glbl_phi_S, vmax_R, vmax_S)

    if self.show_grad:
        # Statistics of dM
        self.figdMslices.set_data(aligneddM)
        plt.figure(self.get_figure('step_stats').number)
        plt.clf()
        plt.suptitle(name + ' Step Statistics')

        plt.subplot(1, 2, 1)
        # numpy/matplotlib require an integer bin count.
        nbins = max(1, int(0.5 * self.dM.shape[0]))
        plt.hist(self.dM.reshape((-1,)), bins=nbins, log=True)
        plt.title('Voxel Histogram')

        (fs, raps) = rot_power_spectra(self.dM, resolution=resolution)
        plt.subplot(1, 2, 2)
        plt.plot(fs / (N / 2.0) / (2.0 * resolution), raps, label='RAPS')
        positive = raps[raps > 0]
        if positive.size:
            # Vertical marker at the current frequency cutoff.
            plt.plot((rad / (2.0 * resolution)) * n.ones((2,)),
                     n.array([positive.min(), raps.max()]))
        plt.yscale('log')
        plt.title('RAPS Step')

    if not self.extra_plots:
        self.close_figure('density_stats')
        return

    # Statistics of M
    self.show_density_plot(self.get_figure('density_stats'))
if it<=3: if subd_R: r_thresh = min(n.percentile(errs, 100*r_factor),best_thresh) rs = rs[errs <= r_thresh] rs, dr = geometry.subdivde(rs, dr) gc.collect() if radwn == max_radwn and ((not subd_R) or \ (delta_R < delta_R_thresh)): break if dr < radwn_ang: radwn = n.minimum(radwn + 5, 10)#max_radwn) """ print(time.time()-tic)/5000. return best_R, n.zeros(3) import os import symmetry os.environ["CRYOSPARC_ROOT_DIR"] = "../runtest" symop = symmetry.get_symmetryop('C1') data_path_abs='../runtest/vol.mrc' N = 256//8 map_r = readMRC(data_path_abs) input_structure_datas = [] input_structure_datas.append(map_r) initmodel = input_structure_datas[0] initmodel = fourier.resample_real(initmodel, N) align_symmetryEEE(initmodel,symop, 10, verbose=1, cuda_dev=0)
def load_kernel(data_dir, model_file, use_angular_correlation=False, sample_shifts=False):
    """Build an ``UnknownRSThreadedCPUKernel`` primed with one minibatch.

    Loads the 1AON test dataset from *data_dir*. The density comes from
    *model_file* when given; otherwise a random phantom is generated and
    normalized to sum to ``modelscale`` (1.0). The kernel is wired with
    rotation and in-plane Fisher importance samplers and a fixed set of
    interpolation/quadrature parameters, then fed the first unshuffled
    minibatch. Precomputed slicing and in-plane rotation are disabled.

    NOTE(review): ``sample_shifts`` is accepted but never used; the shift
    sampler is always ``None``.
    NOTE(review): ``kernel.fM`` is set to the same real-space array as
    ``kernel.M``. Confirm whether a Fourier-space transform was intended.

    :returns: the configured kernel.
    """
    data_params = dict(
        dataset_name="1AON",
        inpath=os.path.join(data_dir, 'imgdata.mrc'),
        ctfpath=os.path.join(data_dir, 'defocus.txt'),
        microscope_params={'akv': 200, 'wgh': 0.07, 'cs': 2.0},
        resolution=2.8,
        sigma='noise_std',
        sigma_out='data_std',
        minisize=20,
        test_imgs=20,
        partition=0,
        num_partitions=0,
        random_seed=1,
        # a 'symmetry' entry (e.g. 'C7') may be added here
    )
    print("Loading dataset %s" % data_dir)
    cryodata, _ = dataset_loading_test(data_params)

    modelscale = 1.0
    if model_file is None:
        print("Generating random initial density map ...")
        n_pix = cryodata.N
        M = cryoem.generate_phantom_density(n_pix, 0.95 * n_pix / 2.0,
                                            5 * n_pix / 128.0, 30, seed=0)
        M *= modelscale / M.sum()
    else:
        print("Loading density map %s" % model_file)
        M = mrc.readMRC(model_file)

    minibatch = cryodata.get_next_minibatch(shuffle_minibatches=False)

    symmetry_op = get_symmetryop(data_params.get('symmetry', None))
    # (rotation sampler, in-plane sampler, shift sampler)
    samplers = (FixedFisherImportanceSampler('_R', symmetry_op),
                FixedFisherImportanceSampler('_I'),
                None)

    cparams = dict(
        use_angular_correlation=use_angular_correlation,
        iteration=0,
        pixel_size=cryodata.pixel_size,
        max_frequency=0.02,
        interp_kernel_R='lanczos',
        interp_kernel_size_R=4,
        interp_zeropad_R=0,
        interp_premult_R=True,
        interp_kernel_I='lanczos',
        interp_kernel_size_I=4,
        interp_zeropad_I=0.0,
        interp_premult_I=True,
        quad_shiftsigma=10,
        quad_shiftextent=60,
        sigma=cryodata.noise_var,
    )

    kernel = UnknownRSThreadedCPUKernel()
    kernel.setup(cparams, None, None, None)
    kernel.set_samplers(*samplers)
    kernel.set_dataset(cryodata)
    kernel.precomp_slices = None
    kernel.set_data(cparams, minibatch)
    kernel.using_precomp_slicing = False
    kernel.using_precomp_inplane = False
    kernel.M = kernel.fM = M
    return kernel