# NOTE: this excerpt assumes the enclosing module's imports (numpy as np,
# time, cryoem, ctf, density, gencoords, Objective, FiniteRunningSum).
class UnknownRSLikelihood(Objective):
    def __init__(self):
        Objective.__init__(self, False)

    def setup(self, params, diagout, statout, ostream):
        Objective.setup(self, params, diagout, statout, ostream)

        if params['kernel'] == 'multicpu':
            from .cpu_kernel import UnknownRSThreadedCPUKernel
            self.kernel = UnknownRSThreadedCPUKernel()
        else:
            assert False, 'unknown kernel: {0}'.format(params['kernel'])
        self.kernel.setup(params, diagout, statout, ostream)

    def get_sigma2_map(self, nu, model, rotavg=True):
        N = self.cryodata.N
        N_D = float(self.cryodata.N_D_Train)
        num_batches = float(self.cryodata.num_batches)

        base_sigma2 = self.cryodata.get_noise_std()**2

        mean_sigma2 = self.error_history.get_mean().reshape((N, N))
        mean_mask = self.mask_history.get_mean().reshape((N, N))
        mask_w = self.mask_history.get_wsum() * (N_D / num_batches)

        if rotavg:
            mean_sigma2 = cryoem.rotational_average(mean_sigma2, normalize=True, doexpand=True)
            mean_mask = cryoem.rotational_average(mean_mask, normalize=False, doexpand=True)

        # MAP estimate: blend the observed noise power with the base noise
        # power, with nu acting as the prior's pseudo-count.
        obsw = mask_w * mean_mask
        map_sigma2 = (mean_sigma2 * obsw + nu * base_sigma2) / (obsw + nu)
        assert np.all(np.isfinite(map_sigma2))

        if model == 'coloured':
            pass  # keep the full per-frequency map
        elif model == 'white':
            map_sigma2 = np.mean(map_sigma2)
        else:
            assert False, 'model must be one of white or coloured'

        return map_sigma2

    def get_envelope_map(self, sigma2, rho, env_lb=None, env_ub=None,
                         minFreq=None, bfactor=None, rotavg=True):
        N = self.cryodata.N
        N_D = float(self.cryodata.N_D_Train)
        num_batches = float(self.cryodata.num_batches)
        psize = self.params['pixel_size']

        mean_corr = self.correlation_history.get_mean().reshape((N, N))
        mean_power = self.power_history.get_mean().reshape((N, N))
        mean_mask = self.mask_history.get_mean().reshape((N, N))
        mask_w = self.mask_history.get_wsum() * (N_D / num_batches)

        if rotavg:
            mean_corr = cryoem.rotational_average(mean_corr, normalize=True, doexpand=True)
            mean_power = cryoem.rotational_average(mean_power, normalize=True, doexpand=True)
            mean_mask = cryoem.rotational_average(mean_mask, normalize=False, doexpand=True)

        if isinstance(sigma2, np.ndarray):
            sigma2 = sigma2.reshape((N, N))

        if bfactor is not None:
            coords = gencoords(N, 2).reshape((N**2, 2))
            freqs = np.sqrt(np.sum(coords**2, axis=1)) / (psize * N)
            prior_envelope = ctf.envelope_function(freqs, bfactor).reshape((N, N))
        else:
            prior_envelope = 1.0

        obsw = (mask_w * mean_mask / sigma2)
        exp_env = (mean_corr * obsw + prior_envelope * rho) / (mean_power * obsw + rho)

        if minFreq is not None:
            # Only consider envelope parameters for frequencies above a threshold
            minRad = minFreq * 2.0 * psize
            _, _, minRadMask = gencoords(N, 2, minRad, True)
            exp_env[minRadMask.reshape((N, N))] = 1.0

        if env_lb is not None or env_ub is not None:
            np.clip(exp_env, env_lb, env_ub, out=exp_env)

        return exp_env

    def get_rmse(self):
        return np.sqrt(self.get_sigma2_mle().mean())

    def get_sigma2_mle(self, noise_model='coloured'):
        N = self.cryodata.N
        sigma2 = self.error_history.get_mean()
        mean_mask = self.mask_history.get_mean()
        # mse = mean_mask*sigma2 + (1-mean_mask)*self.cryodata.data['imgvar_freq']
        mse = mean_mask * sigma2 + (1 - mean_mask) * self.cryodata.data_var

        if noise_model == 'coloured':
            return mse.reshape((N, N))
        elif noise_model == 'white':
            return mse.mean()

    def get_envelope_mle(self, rotavg=False):
        N = self.cryodata.N

        mean_corr = self.correlation_history.get_mean()
        mean_power = self.power_history.get_mean()
        mean_mask = self.mask_history.get_mean()

        if rotavg:
            mean_corr = cryoem.rotational_average(mean_corr.reshape((N, N)), doexpand=True)
            mean_power = cryoem.rotational_average(mean_power.reshape((N, N)), doexpand=True)
            mean_mask = cryoem.rotational_average(mean_mask.reshape((N, N)), doexpand=True)

        obs_mask = mean_mask > 0
        exp_env = np.ones_like(mean_corr)
        exp_env[obs_mask] = (mean_corr[obs_mask] / mean_power[obs_mask])

        return exp_env.reshape((N, N))

    def set_samplers(self, sampler_R, sampler_I, sampler_S):
        self.kernel.set_samplers(sampler_R, sampler_I, sampler_S)

    def set_dataset(self, cryodata):
        Objective.set_dataset(self, cryodata)
        self.kernel.set_dataset(cryodata)

        self.error_history = FiniteRunningSum(second_order=False)
        self.correlation_history = FiniteRunningSum(second_order=False)
        self.power_history = FiniteRunningSum(second_order=False)
        self.mask_history = FiniteRunningSum(second_order=False)

    def set_data(self, cparams, minibatch):
        Objective.set_data(self, cparams, minibatch)
        self.kernel.set_data(cparams, minibatch)

    def eval(self, M=None, compute_gradient=True, fM=None, **kwargs):
        tic_start = time.time()

        if self.kernel.slice_premult is not None:
            pfM = density.real_to_fspace(self.kernel.slice_premult * M)
        else:
            pfM = density.real_to_fspace(M)
        pmtime = time.time() - tic_start

        ret = self.kernel.eval(fM=pfM, M=None, compute_gradient=compute_gradient)

        if not self.minibatch['test_batch'] and not kwargs.get('intermediate', False):
            # Record the per-minibatch statistics needed by the MAP estimators above.
            tic_record = time.time()

            curr_var = ret[-1]['sigma2_est']
            assert np.all(np.isfinite(curr_var))
            if self.error_history.N_sum != self.cryodata.N_batches:
                self.error_history.setup(curr_var, self.cryodata.N_batches, allow_decay=False)
            self.error_history.set_value(self.minibatch['id'], curr_var)

            curr_corr = ret[-1]['correlation']
            assert np.all(np.isfinite(curr_corr))
            if self.correlation_history.N_sum != self.cryodata.N_batches:
                self.correlation_history.setup(curr_corr, self.cryodata.N_batches, allow_decay=False)
            self.correlation_history.set_value(self.minibatch['id'], curr_corr)

            curr_power = ret[-1]['power']
            assert np.all(np.isfinite(curr_power))
            if self.power_history.N_sum != self.cryodata.N_batches:
                self.power_history.setup(curr_power, self.cryodata.N_batches, allow_decay=False)
            self.power_history.set_value(self.minibatch['id'], curr_power)

            curr_mask = self.kernel.truncmask
            if self.mask_history.N_sum != self.cryodata.N_batches:
                self.mask_history.setup(np.require(curr_mask, dtype=np.float32),
                                        self.cryodata.N_batches, allow_decay=False)
            self.mask_history.set_value(self.minibatch['id'], curr_mask)

            ret[-1]['like_timing']['record'] = time.time() - tic_record

        if compute_gradient and self.kernel.slice_premult is not None:
            tic_record = time.time()
            ret = (ret[0], self.kernel.slice_premult * density.fspace_to_real(ret[1]), ret[2])
            ret[-1]['like_timing']['premult'] = pmtime + time.time() - tic_record

        ret[-1]['like_timing']['total'] = time.time() - tic_start

        return ret
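
# --- Illustrative sketch (added; not from the original source) --------------
# get_sigma2_map() above blends the empirically observed per-frequency noise
# power with the base (prior) noise power using precision-style weights; nu
# acts as a pseudo-count for the prior.  The hypothetical helper below
# isolates that single line, assuming only numpy.
import numpy as np

def _map_noise_blend_sketch(mean_sigma2, obsw, base_sigma2, nu):
    # Where obsw >> nu the data estimate dominates; where obsw ~ 0 the result
    # falls back to base_sigma2, e.g. obsw=0, base_sigma2=1.0, nu=10 -> 1.0.
    return (mean_sigma2 * obsw + nu * base_sigma2) / (obsw + nu)
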
# NOTE: this excerpt assumes the enclosing module's imports (numpy as np,
# os, pickle, base, FiniteRunningSum, find_L).
class SAGDStep(base.BaseStep):
    def __init__(self, ind_L=False, ng=False, alt_sum=True):
        base.BaseStep.__init__(self)

        self.g_history = FiniteRunningSum(second_order=ng)
        self.L = None
        self.alt_sum = alt_sum  # FIXME: currently ignored
        self.ind_L = ind_L      # FIXME: currently ignored
        self.precond = None
        self.ng = ng
        self.prev_max_freq = None

    def add_batch_gradient(self, batch, curr_g, params):
        mu = params.get('sagd_momentum', None)

        if self.g_history.N_sum != params['num_batches']:
            # Reset gradient g_history if minibatch size changes.
            self.g_history.setup(curr_g, batch['num_batches'])

        return self.g_history.set_value(batch['id'], curr_g, mu)

    def save_L0(self, params):
        with open(os.path.join(params['exp_path'], 'sagd_L0.pkl'), 'wb') as fp:
            pickle.dump(self.L, fp, protocol=2)

    def load_L0(self, params):
        try:
            with open(os.path.join(params['exp_path'], 'sagd_L0.pkl'), 'rb') as fp:
                L0 = pickle.load(fp)
        except:
            L0 = params.get('sagd_L0', 1.0)
        return L0

    def do_step(self, x, params, cryodata, evalobj, batch, **kwargs):
        # Only print line-search details until an initial L has been found.
        ostream = self.ostream if self.L is None else None

        inc_L = params.get('sagd_incL', False)
        do_ls = self.L is None or params.get('sagd_linesearch', False)
        max_ls_its = params.get('sagd_linesearch_maxits', 5) if self.L is not None else None
        grad_check = params.get('sagd_gradcheck', False)
        minStep = 1.05 if self.L is not None else 2
        maxStep = 100.0 if self.L is not None else 10
        optThresh = params.get('sagd_linesearch_accuracy', 1.01)
        # weight of the prior which biases the covariance estimate towards lambda
        g2_lambdaw = params.get('sagd_lambdaw', 10)
        eps0 = params.get('sagd_learnrate', 1.0 / 16.0)
        reset_precond = params.get('sagd_reset_precond', False)
        F_max_range = params.get('sagd_precond_maxrange', 1000.0)
        use_saga = params.get('sagd_sagastep', False)

        if self.g_history is not None and reset_precond:
            self.g_history.reset_meansq()
            do_ls = True

        # Evaluate the gradient
        f, g, res_train = evalobj(x, all_grads=True)
        assert len(res_train['all_dlogPs']) == 2
        g_like = res_train['all_dlogPs'][0]
        g_prior = res_train['all_dlogPs'][1]

        if use_saga:
            prev_g_hat = self.g_history.get_mean()

        # Add the gradient to the g_history
        prev_g_like, _ = self.add_batch_gradient(batch, g_like, params)

        # Get the current average g
        if not use_saga:
            # SAG: step along the running mean of the stored minibatch gradients
            curr_g_hat = self.g_history.get_mean()
            totghat = curr_g_hat + g_prior
        else:
            # SAGA: correct the stale mean with the fresh minibatch gradient
            totghat = prev_g_hat + g_prior + (g_like - prev_g_like)

        if self.ng:
            # Update the preconditioner when we're going to do a linesearch
            if do_ls:
                curr_g_var = np.maximum(0, self.g_history.get_meansq())
                curr_w2sum = self.g_history.get_w2sum()
                mean_var = np.mean(curr_g_var)
                if curr_w2sum > 1:
                    F = mean_var * (g2_lambdaw / (curr_w2sum + g2_lambdaw)) + \
                        curr_g_var * (curr_w2sum / (curr_w2sum + g2_lambdaw))
                else:
                    F = mean_var
                minF = F.min()
                maxF = F.max()
                minF_trunc = np.maximum(minF, maxF / F_max_range)
                self.precond = np.sqrt(minF_trunc) / \
                    np.sqrt(np.maximum(minF_trunc, F))
                if ostream is not None:
                    ostream("  F min/max = {0} / {1}, var min/max/mean = {2} / {3} / {4}".format(
                        np.min(F), np.max(F),
                        np.min(curr_g_var), np.max(curr_g_var), mean_var))
            precond = self.precond
        else:
            precond = None

        init_L = self.L is None

        # Gradually increase the current L if requested
        if inc_L != 1.0 and not init_L:
            self.L *= inc_L

        if not init_L:
            L0 = self.L
        else:
            L0 = self.load_L0(params)

        if do_ls:
            # Perform line search if we haven't found a value of L yet
            # and/or check that the current L satisfies the conditions
            self.L = find_L(x, f, g, evalobj, L0, max_ls_its,
                            gradientCheck=grad_check, ostream=ostream,
                            precond=precond, minStep=minStep,
                            maxStep=maxStep, optThresh=optThresh)

        currL = self.L
        if init_L:
            self.save_L0(params)

        eps = eps0 / currL
        dx = -eps * totghat
        if self.ng:
            self.statout.output(sagd_precond_min=[precond.min()],
                                sagd_precond_max=[precond.max()])
            dx *= precond**2

        ghatnorm = np.linalg.norm(totghat)

        self.statout.output(sagd_L=[self.L],
                            sagd_gnorm=[ghatnorm],
                            sagd_eps=[eps])

        # 0 extra operations on data
        return f, g.reshape((-1, 1)), dx, res_train, 0
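
# --- Illustrative sketch (added; not from the original source) --------------
# The FiniteRunningSum bookkeeping in do_step() presumably implements the
# stochastic average gradient (SAG) table: store the last gradient seen for
# each minibatch and step along the mean of the table.  The SAGA variant
# ('sagd_sagastep') instead uses prev_mean + (g_new - g_old_entry), an
# unbiased gradient estimate.  A minimal, hypothetical version of the table:
import numpy as np

class _SAGTableSketch(object):
    def __init__(self, num_batches, dim):
        self.table = np.zeros((num_batches, dim))   # last gradient per batch
        self.seen = np.zeros(num_batches, dtype=bool)

    def set_value(self, batch_id, g):
        # Returns the stale entry, mirroring what add_batch_gradient() needs
        # for the SAGA correction term.
        prev_g = self.table[batch_id].copy()
        self.table[batch_id] = g
        self.seen[batch_id] = True
        return prev_g

    def get_mean(self):
        # SAG averages the stored table (over the batches visited so far).
        return self.table[self.seen].mean(axis=0)
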
# NOTE: this excerpt assumes the enclosing module's imports (numpy as n, os,
# sys, time, socket, datetime, Queue/Thread, copy/deepcopy, opj=os.path.join,
# plus the project's cryoem, density, gitutil, and IO helpers).
class CryoOptimizer(BackgroundWorker):
    def outputbatchinfo(self, batch, res, logP, prefix, name):
        diag = {}
        stat = {}
        like = {}

        N_M = batch['N_M']
        cepoch = self.cryodata.get_epoch(frac=True)
        epoch = self.cryodata.get_epoch()
        num_data = self.cryodata.N_D_Train

        sigma = n.sqrt(n.mean(res['Evar_like']))
        sigma_prior = n.sqrt(n.mean(res['Evar_prior']))

        self.ostream('  {0} Batch:'.format(name))

        for suff in ['R', 'I', 'S']:
            diag[prefix+'_CV2_'+suff] = res['CV2_'+suff]
        diag[prefix+'_idxs'] = batch['img_idxs']
        diag[prefix+'_sigma2_est'] = res['sigma2_est']
        diag[prefix+'_correlation'] = res['correlation']
        diag[prefix+'_power'] = res['power']

        self.ostream("    RMS Error: %g, Signal: %g" %
                     (sigma/n.sqrt(self.cryodata.noise_var),
                      sigma_prior/n.sqrt(self.cryodata.noise_var)))
        self.ostream("    Effective # of R / I / S: %.2f / %.2f / %.2f " %
                     (n.mean(res['CV2_R']), n.mean(res['CV2_I']), n.mean(res['CV2_S'])))

        # Importance Sampling Statistics
        is_speedups = []
        for suff in ['R', 'I', 'S', 'Total']:
            if self.cparams.get('is_on_'+suff, False) or (suff == 'Total' and len(is_speedups) > 0):
                spdup = N_M/res['N_'+suff+'_sampled_total']
                is_speedups.append((suff, spdup, n.mean(res['N_'+suff+'_sampled']), res['N_'+suff]))
                stat[prefix+'_is_speedup_'+suff] = [spdup]
            else:
                stat[prefix+'_is_speedup_'+suff] = [1.0]
        if len(is_speedups) > 0:
            lblstr = is_speedups[0][0]
            numstr = '%.2f (%d of %d)' % is_speedups[0][1:]
            for i in range(1, len(is_speedups)):
                lblstr += ' / ' + is_speedups[i][0]
                numstr += ' / %.2f (%d of %d)' % is_speedups[i][1:]
            self.ostream("    IS Speedup {0}: {1}".format(lblstr, numstr))

        stat[prefix+'_sigma'] = [sigma]
        stat[prefix+'_logp'] = [logP]
        stat[prefix+'_like'] = [res['L']]
        stat[prefix+'_num_data'] = [num_data]
        stat[prefix+'_num_data_evals'] = [self.num_data_evals]
        stat[prefix+'_iteration'] = [self.iteration]
        stat[prefix+'_epoch'] = [epoch]
        stat[prefix+'_cepoch'] = [cepoch]
        stat[prefix+'_time'] = [time.time()]
        for k, v in res['like_timing'].iteritems():
            stat[prefix+'_like_timing_'+k] = [v]

        Idxs = batch['img_idxs']
        self.img_likes[Idxs] = res['like']
        like['img_likes'] = self.img_likes
        like['train_idxs'] = self.cryodata.train_idxs
        like['test_idxs'] = self.cryodata.test_idxs

        keepidxs = self.cryodata.train_idxs if prefix == 'train' else self.cryodata.test_idxs
        keeplikes = self.img_likes[keepidxs]
        keeplikes = keeplikes[n.isfinite(keeplikes)]

        quants = n.percentile(keeplikes, range(0, 101))
        stat[prefix+'_full_like_quantiles'] = [quants]
        quants = n.percentile(res['like'], range(0, 101))
        stat[prefix+'_mini_like_quantiles'] = [quants]
        stat[prefix+'_num_like_quantiles'] = [len(keeplikes)]

        self.diagout.output(**diag)
        self.statout.output(**stat)
        self.likeout.output(**like)

    def ioworker(self):
        # Drains the IO queue on a background thread; see __init__ below.
        while True:
            iotype, fname, data = self.io_queue.get()
            try:
                if iotype == 'mrc':
                    writeMRC(fname, *data)
                elif iotype == 'pkl':
                    with open(fname, 'wb') as f:
                        cPickle.dump(data, f, protocol=-1)
                elif iotype == 'cp':
                    copyfile(fname, data)
            except:
                print "ERROR DUMPING {0}: {1}".format(fname, sys.exc_info()[0])
            self.io_queue.task_done()

    def __init__(self, expbase, cmdparams=None):
        """expbase is a path to the base folder where this experiment's files
        will be stored.  The folder above expbase will also be searched for
        .params files.  These will be loaded first."""
        BackgroundWorker.__init__(self)

        # Create a background thread which handles IO
        self.io_queue = Queue()
        self.io_thread = Thread(target=self.ioworker)
        self.io_thread.daemon = True
        self.io_thread.start()

        # General setup ----------------------------------------------------
        self.expbase = expbase
        self.outbase = None

        # Parameter setup ---------------------------------------------------
        # search above expbase for params files
        _, _, filenames = os.walk(opj(expbase, '../')).next()
        self.paramfiles = [opj(opj(expbase, '../'), fname)
                           for fname in filenames if fname.endswith('.params')]
        # search expbase for params files
        _, _, filenames = os.walk(opj(expbase)).next()
        self.paramfiles += [opj(expbase, fname)
                            for fname in filenames if fname.endswith('.params')]
        if 'local.params' in filenames:
            self.paramfiles += [opj(expbase, 'local.params')]
        # load parameter files
        self.params = Params(self.paramfiles)
        self.cparams = None

        if cmdparams is not None:
            # Set parameters specified on the command line
            for k, v in cmdparams.iteritems():
                self.params[k] = v

        # Dataset setup -------------------------------------------------------
        self.imgpath = self.params['inpath']
        psize = self.params['resolution']
        if not isinstance(self.imgpath, list):
            imgstk = MRCImageStack(self.imgpath, psize)
        else:
            imgstk = CombinedImageStack([MRCImageStack(cimgpath, psize)
                                         for cimgpath in self.imgpath])

        if self.params.get('float_images', True):
            imgstk.float_images()

        self.ctfpath = self.params['ctfpath']
        mscope_params = self.params['microscope_params']
        if not isinstance(self.ctfpath, list):
            ctfstk = CTFStack(self.ctfpath, mscope_params)
        else:
            ctfstk = CombinedCTFStack([CTFStack(cctfpath, mscope_params)
                                       for cctfpath in self.ctfpath])

        self.cryodata = CryoDataset(imgstk, ctfstk)
        self.cryodata.compute_noise_statistics()
        if self.params.get('window_images', True):
            imgstk.window_images()

        minibatch_size = self.params['minisize']
        testset_size = self.params['test_imgs']
        partition = self.params.get('partition', 0)
        num_partitions = self.params.get('num_partitions', 1)
        seed = self.params['random_seed']
        if isinstance(partition, str):
            partition = eval(partition)
        if isinstance(num_partitions, str):
            num_partitions = eval(num_partitions)
        if isinstance(seed, str):
            seed = eval(seed)

        self.cryodata.divide_dataset(minibatch_size, testset_size,
                                     partition, num_partitions, seed)
        self.cryodata.set_datasign(self.params.get('datasign', 'auto'))
        if self.params.get('normalize_data', True):
            self.cryodata.normalize_dataset()

        self.voxel_size = self.cryodata.pixel_size

        # Iterations setup --------------------------------------------------
        self.iteration = 0
        self.tic_epoch = None
        self.num_data_evals = 0

        self.eval_params()

        outdir = self.cparams.get('outdir', None)
        if outdir is None:
            if self.cparams.get('num_partitions', 1) > 1:
                outdir = 'partition{0}'.format(self.cparams['partition'])
            else:
                outdir = ''
        self.outbase = opj(self.expbase, outdir)
        if not os.path.isdir(self.outbase):
            os.makedirs(self.outbase)

        # Output setup ------------------------------------------------------
        self.ostream = OutputStream(opj(self.outbase, 'stdout'))

        self.ostream(80*"=")
        self.ostream("Experiment: " + expbase +
                     "    Kernel: " + self.params['kernel'])
        self.ostream("Started on " + socket.gethostname() +
                     "    At: " + time.strftime('%B %d %Y: %I:%M:%S %p'))
        self.ostream("Git SHA1: " + gitutil.git_get_SHA1())
        self.ostream(80*"=")
        gitutil.git_info_dump(opj(self.outbase, 'gitinfo'))
        self.startdatetime = datetime.now()

        # for diagnostics and parameters
        self.diagout = Output(opj(self.outbase, 'diag'), runningout=False)
        # for stats (per image etc)
        self.statout = Output(opj(self.outbase, 'stat'), runningout=True)
        # for likelihoods of individual images
        self.likeout = Output(opj(self.outbase, 'like'), runningout=False)

        self.img_likes = n.empty(self.cryodata.N_D)
        self.img_likes[:] = n.inf

        # optimization state vars ------------------------------------------
        init_model = self.cparams.get('init_model', None)
        if init_model is not None:
            filename = init_model
            if filename.upper().endswith('.MRC'):
                M = readMRC(filename)
            else:
                with open(filename) as fp:
                    M = cPickle.load(fp)
                if type(M) == list:
                    M = M[-1]['M']
            if M.shape != 3*(self.cryodata.N,):
                M = cryoem.resize_ndarray(M, 3*(self.cryodata.N,), axes=(0, 1, 2))
        else:
            init_seed = self.cparams.get('init_random_seed', 0) + self.cparams.get('partition', 0)
            print "Randomly generating initial density (init_random_seed = {0})...".format(init_seed),
            sys.stdout.flush()
            tic = time.time()
            M = cryoem.generate_phantom_density(self.cryodata.N,
                                                0.95*self.cryodata.N/2.0,
                                                5*self.cryodata.N/128.0, 30,
                                                seed=init_seed)
            print "done in {0}s".format(time.time() - tic)

        tic = time.time()
        print "Windowing and aligning initial density...",
        sys.stdout.flush()
        # window the initial density
        wfunc = self.cparams.get('init_window', 'circle')
        cryoem.window(M, wfunc)
        # Center and orient the initial density
        cryoem.align_density(M)
        print "done in {0:.2f}s".format(time.time() - tic)

        # apply the symmetry operator
        init_sym = get_symmetryop(self.cparams.get('init_symmetry',
                                                   self.cparams.get('symmetry', None)))
        if init_sym is not None:
            tic = time.time()
            print "Applying symmetry operator...",
            sys.stdout.flush()
            M = init_sym.apply(M)
            print "done in {0:.2f}s".format(time.time() - tic)

        tic = time.time()
        print "Scaling initial model...",
        sys.stdout.flush()
        modelscale = self.cparams.get('modelscale', 'auto')
        mleDC, _, mleDC_est_std = self.cryodata.get_dc_estimate()
        if modelscale == 'auto':
            # Err on the side of a weaker prior by using a larger value for modelscale
            modelscale = (n.abs(mleDC) + 2*mleDC_est_std)/self.cryodata.N
            print "estimated modelscale = {0:.3g}...".format(modelscale),
            sys.stdout.flush()
            self.params['modelscale'] = modelscale
            self.cparams['modelscale'] = modelscale
        M *= modelscale/M.sum()
        print "done in {0:.2f}s".format(time.time() - tic)
        if mleDC_est_std/n.abs(mleDC) > 0.05:
            print "  WARNING: the DC component estimate has a high relative variance, it may be inaccurate!"
        if ((modelscale*self.cryodata.N - n.abs(mleDC)) / mleDC_est_std) > 3:
            print "  WARNING: the selected modelscale value is more than 3 std devs different than the estimated one.  Be sure this is correct."
        self.M = n.require(M, dtype=density.real_t)
        self.fM = density.real_to_fspace(M)
        self.dM = density.zeros_like(self.M)

        self.step = eval(self.cparams['optim_algo'])
        self.step.setup(self.cparams, self.diagout, self.statout, self.ostream)

        # Objective function setup --------------------------------------------
        param_type = self.cparams.get('parameterization', 'real')
        cplx_param = param_type in ['complex', 'complex_coeff', 'complex_herm_coeff']
        self.like_func = eval_objective(self.cparams['likelihood'])
        self.prior_func = eval_objective(self.cparams['prior'])

        if self.cparams.get('penalty', None) is not None:
            self.penalty_func = eval_objective(self.cparams['penalty'])
            prior_func = SumObjectives(self.prior_func.fspace,
                                       [self.penalty_func, self.prior_func], None)
        else:
            prior_func = self.prior_func

        self.obj = SumObjectives(cplx_param,
                                 [self.like_func, prior_func], [None, None])
        self.obj.setup(self.cparams, self.diagout, self.statout, self.ostream)
        self.obj.set_dataset(self.cryodata)
        self.obj_wrapper = ObjectiveWrapper(param_type)

        self.last_save = time.time()

        self.logpost_history = FiniteRunningSum()
        self.like_history = FiniteRunningSum()

        # Importance Samplers -------------------------------------------------
        self.is_sym = get_symmetryop(self.cparams.get('is_symmetry',
                                                      self.cparams.get('symmetry', None)))
        self.sampler_R = FixedFisherImportanceSampler('_R', self.is_sym)
        self.sampler_I = FixedFisherImportanceSampler('_I')
        self.sampler_S = FixedGaussianImportanceSampler('_S')
        self.like_func.set_samplers(sampler_R=self.sampler_R,
                                    sampler_I=self.sampler_I,
                                    sampler_S=self.sampler_S)

    def eval_params(self):
        # cvars are state variables that can be used in parameter expressions
        cvars = {}
        cvars['cepoch'] = self.cryodata.get_epoch(frac=True)
        cvars['epoch'] = self.cryodata.get_epoch()
        cvars['iteration'] = self.iteration
        cvars['num_data'] = self.cryodata.N_D_Train
        cvars['num_batches'] = self.cryodata.N_batches
        cvars['noise_std'] = n.sqrt(self.cryodata.noise_var)
        cvars['data_std'] = n.sqrt(self.cryodata.data_var)
        cvars['voxel_size'] = self.voxel_size
        cvars['pixel_size'] = self.cryodata.pixel_size
        cvars['prev_max_frequency'] = self.cparams['max_frequency'] if self.cparams is not None else None

        # prelist fields are parameters that can be used in evaluating other
        # parameter expressions; they can only depend on values defined in cvars
        prelist = ['max_frequency']
        skipfields = set(['inpath', 'ctfpath'])

        cvars = self.params.partial_evaluate(prelist, **cvars)

        if self.cparams is None:
            self.max_frequency_changes = 0
        else:
            self.max_frequency_changes += cvars['max_frequency'] != cvars['prev_max_frequency']
        cvars['num_max_frequency_changes'] = self.max_frequency_changes
        cvars['max_frequency_changed'] = cvars['max_frequency'] != cvars['prev_max_frequency']

        self.cparams = self.params.evaluate(skipfields, **cvars)
        self.cparams['exp_path'] = self.expbase
        self.cparams['out_path'] = self.outbase

        if 'name' not in self.cparams:
            self.cparams['name'] = '{0} - {1} - {2} ({3})'.format(self.cparams['dataset_name'],
                                                                  self.cparams['prior_name'],
                                                                  self.cparams['optimizer_name'],
                                                                  self.cparams['kernel'])

    def run(self):
        while self.dowork():
            pass
        print "Waiting for IO queue to clear...",
        sys.stdout.flush()
        self.io_queue.join()
        print "done."
        sys.stdout.flush()

    def begin(self):
        BackgroundWorker.begin(self)

    def end(self):
        BackgroundWorker.end(self)

    def dowork(self):
        """Do one atom of work, i.e. execute one minibatch."""
        timing = {}

        # Time each minibatch
        tic_mini = time.time()

        self.iteration += 1

        # Fetch the current batches
        trainbatch = self.cryodata.get_next_minibatch(self.cparams.get('shuffle_minibatches', True))

        # Get the current epoch
        cepoch = self.cryodata.get_epoch(frac=True)
        epoch = self.cryodata.get_epoch()
        num_data = self.cryodata.N_D_Train

        # Evaluate the parameters
        self.eval_params()
        timing['setup'] = time.time() - tic_mini

        # Do hyperparameter learning
        if self.cparams.get('learn_params', False):
            tic_learn = time.time()
            if self.cparams.get('learn_prior_params', True):
                tic_learn_prior = time.time()
                self.prior_func.learn_params(self.params, self.cparams, M=self.M, fM=self.fM)
                timing['learn_prior'] = time.time() - tic_learn_prior

            if self.cparams.get('learn_likelihood_params', True):
                tic_learn_like = time.time()
                self.like_func.learn_params(self.params, self.cparams, M=self.M, fM=self.fM)
                timing['learn_like'] = time.time() - tic_learn_like

            if self.cparams.get('learn_prior_params', True) or self.cparams.get('learn_likelihood_params', True):
                timing['learn_total'] = time.time() - tic_learn

        # Time each epoch
        if self.tic_epoch is None:
            self.ostream("Epoch: %d" % epoch)
            self.tic_epoch = (tic_mini, epoch)
        elif self.tic_epoch[1] != epoch:
            self.ostream("Epoch Total - %.6f seconds " %
                         (tic_mini - self.tic_epoch[0]))
            self.tic_epoch = (tic_mini, epoch)

        sym = get_symmetryop(self.cparams.get('symmetry', None))
        if sym is not None:
            self.obj.ws[1] = 1.0/sym.get_order()

        tic_mstats = time.time()
        self.ostream(self.cparams['name'], "  Iteration:", self.iteration,
                     "  Epoch:", epoch, "  Host:", socket.gethostname())

        # Compute density statistics
        N = self.cryodata.N
        M_sum = self.M.sum(dtype=n.float64)
        M_zeros = (self.M == 0).sum()
        M_mean = M_sum/N**3
        M_max = self.M.max()
        M_min = self.M.min()
        self.statout.output(total_density=[M_sum],
                            avg_density=[M_mean],
                            nonzero_density=[M_zeros],
                            max_density=[M_max],
                            min_density=[M_min])
        timing['density_stats'] = time.time() - tic_mstats

        # evaluate test batch if requested
        if self.iteration <= 1 or self.cparams.get('evaluate_test_set', self.iteration % 5):
            tic_test = time.time()
            testbatch = self.cryodata.get_testbatch()
            self.obj.set_data(self.cparams, testbatch)
            testLogP, res_test = self.obj.eval(M=self.M, fM=self.fM,
                                               compute_gradient=False)
            self.outputbatchinfo(testbatch, res_test, testLogP, 'test', 'Test')
            timing['test_batch'] = time.time() - tic_test
        else:
            testLogP, res_test = None, None

        # setup the wrapper for the objective function
        tic_objsetup = time.time()
        self.obj.set_data(self.cparams, trainbatch)
        self.obj_wrapper.set_objective(self.obj)
        x0 = self.obj_wrapper.set_density(self.M, self.fM)
        evalobj = self.obj_wrapper.eval_obj
        timing['obj_setup'] = time.time() - tic_objsetup

        # Get step size
        self.num_data_evals += trainbatch['N_M']  # at least one gradient
        tic_objstep = time.time()
        trainLogP, dlogP, v, res_train, extra_num_data = \
            self.step.do_step(x0, self.cparams, self.cryodata, evalobj,
                              batch=trainbatch)

        # Apply the step
        x = x0 + v
        timing['step'] = time.time() - tic_objstep

        # Convert from parameters to value
        tic_stepfinalize = time.time()
        prevM = n.copy(self.M)
        self.M, self.fM = self.obj_wrapper.convert_parameter(x, comp_real=True)

        apply_sym = sym is not None and self.cparams.get('perfect_symmetry', True) \
            and self.cparams.get('apply_symmetry', True)
        if apply_sym:
            self.M = sym.apply(self.M)

        # Truncate the density to bounds if they exist
        if self.cparams['density_lb'] is not None:
            n.maximum(self.M,
                      self.cparams['density_lb']*self.cparams['modelscale'],
                      out=self.M)
        if self.cparams['density_ub'] is not None:
            n.minimum(self.M,
                      self.cparams['density_ub']*self.cparams['modelscale'],
                      out=self.M)

        # Compute net change
        self.dM = prevM - self.M

        # Convert to fourier space (may not be required)
        if self.fM is None or apply_sym \
                or self.cparams['density_lb'] is not None \
                or self.cparams['density_ub'] is not None:
            self.fM = density.real_to_fspace(self.M)
        timing['step_finalize'] = time.time() - tic_stepfinalize

        # Compute step statistics
        tic_stepstats = time.time()
        step_size = n.linalg.norm(self.dM)
        grad_size = n.linalg.norm(dlogP)
        M_norm = n.linalg.norm(self.M)

        self.num_data_evals += extra_num_data
        inc_ratio = step_size / M_norm
        self.statout.output(step_size=[step_size],
                            inc_ratio=[inc_ratio],
                            grad_size=[grad_size],
                            norm_density=[M_norm])
        timing['step_stats'] = time.time() - tic_stepstats

        # Update importance sampling distributions
        tic_isupdate = time.time()
        self.sampler_R.perform_update()
        self.sampler_I.perform_update()
        self.sampler_S.perform_update()
        self.diagout.output(global_phi_R=self.sampler_R.get_global_dist())
        self.diagout.output(global_phi_I=self.sampler_I.get_global_dist())
        self.diagout.output(global_phi_S=self.sampler_S.get_global_dist())
        timing['is_update'] = time.time() - tic_isupdate

        # Output basic diagnostics
        tic_diagnostics = time.time()
        self.diagout.output(iteration=self.iteration, epoch=epoch, cepoch=cepoch)

        if self.logpost_history.N_sum != self.cryodata.N_batches:
            self.logpost_history.setup(trainLogP, self.cryodata.N_batches)
        self.logpost_history.set_value(trainbatch['id'], trainLogP)

        if self.like_history.N_sum != self.cryodata.N_batches:
            self.like_history.setup(res_train['L'], self.cryodata.N_batches)
        self.like_history.set_value(trainbatch['id'], res_train['L'])

        self.outputbatchinfo(trainbatch, res_train, trainLogP, 'train', 'Train')

        # Dump parameters here to catch the defaults used in evaluation
        self.diagout.output(params=self.cparams,
                            envelope_mle=self.like_func.get_envelope_mle(),
                            sigma2_mle=self.like_func.get_sigma2_mle(),
                            hostname=socket.gethostname())
        self.statout.output(num_data=[num_data],
                            num_data_evals=[self.num_data_evals],
                            iteration=[self.iteration],
                            epoch=[epoch],
                            cepoch=[cepoch],
                            logp=[self.logpost_history.get_mean()],
                            like=[self.like_history.get_mean()],
                            sigma=[self.like_func.get_rmse()],
                            time=[time.time()])
        timing['diagnostics'] = time.time() - tic_diagnostics

        checkpoint_it = self.iteration % self.cparams.get('checkpoint_frequency', 50) == 0
        save_it = checkpoint_it or self.cparams['save_iteration'] or \
            time.time() - self.last_save > self.cparams.get('save_time', n.inf)
        if save_it:
            tic_save = time.time()
            self.last_save = tic_save
            if self.io_queue.qsize():
                print "Warning: IO queue has become backlogged with {0} remaining, waiting for it to clear".format(self.io_queue.qsize())
                self.io_queue.join()

            self.io_queue.put(('pkl', self.statout.fname, copy(self.statout.outdict)))
            self.io_queue.put(('pkl', self.diagout.fname, deepcopy(self.diagout.outdict)))
            self.io_queue.put(('pkl', self.likeout.fname, deepcopy(self.likeout.outdict)))

            self.io_queue.put(('mrc', opj(self.outbase, 'model.mrc'),
                               (n.require(self.M, dtype=density.real_t), self.voxel_size)))
            self.io_queue.put(('mrc', opj(self.outbase, 'dmodel.mrc'),
                               (n.require(self.dM, dtype=density.real_t), self.voxel_size)))

            if checkpoint_it:
                self.io_queue.put(('cp', self.diagout.fname,
                                   self.diagout.fname+'-{0:06}'.format(self.iteration)))
                self.io_queue.put(('cp', self.likeout.fname,
                                   self.likeout.fname+'-{0:06}'.format(self.iteration)))
                self.io_queue.put(('cp', opj(self.outbase, 'model.mrc'),
                                   opj(self.outbase, 'model-{0:06}.mrc'.format(self.iteration))))

            timing['save'] = time.time() - tic_save

        time_total = time.time() - tic_mini
        self.ostream("  Minibatch Total - %.2f seconds                 Total Runtime - %s" %
                     (time_total, format_timedelta(datetime.now() - self.startdatetime)))

        return self.iteration < self.cparams.get('max_iterations', n.inf) and \
            cepoch < self.cparams.get('max_epochs', n.inf)
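
# --- Illustrative sketch (added; not from the original source) --------------
# CryoOptimizer offloads checkpoint writes to a daemon thread draining a
# Queue, so dowork() only blocks on disk when the queue backs up (the
# qsize()/join() guard in the save block above).  The pattern in isolation,
# with a hypothetical handler argument (Python 2, matching this code):
from Queue import Queue
from threading import Thread

def _make_io_worker_sketch(handler):
    q = Queue()
    def worker():
        while True:
            item = q.get()
            try:
                handler(item)     # e.g. write an .mrc or a pickle to disk
            finally:
                q.task_done()     # always mark done so join() cannot hang
    t = Thread(target=worker)
    t.daemon = True               # don't keep the process alive at exit
    t.start()
    return q

# Usage: q = _make_io_worker_sketch(lambda item: None); q.put('x'); q.join()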