Example #1
    def __init__(self, ind_L=False, ng=False, alt_sum=True):
        base.BaseStep.__init__(self)

        self.g_history = FiniteRunningSum(second_order=ng)

        self.L = None

        self.alt_sum = alt_sum  # FIXME: currently ignored
        self.ind_L = ind_L  # FIXME: currently ignored

        self.precond = None
        self.ng = ng

        self.prev_max_freq = None
Example #2
    def set_dataset(self,cryodata):
        Objective.set_dataset(self,cryodata)
        self.kernel.set_dataset(cryodata)

        self.error_history = FiniteRunningSum(second_order=False)
        self.correlation_history = FiniteRunningSum(second_order=False)
        self.power_history = FiniteRunningSum(second_order=False)
        self.mask_history = FiniteRunningSum(second_order=False)
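
Note: all of these examples lean on FiniteRunningSum, whose definition is not
shown anywhere in this section. The sketch below is a reconstruction inferred
purely from the call sites (setup / set_value / get_mean / get_wsum / get_w2sum
/ get_meansq / reset_meansq and the N_sum attribute); the real class almost
certainly differs, and in particular the momentum argument mu and allow_decay
are guesses that this sketch simply ignores.

    import numpy as np

    class FiniteRunningSum(object):
        """Weighted running sum with one slot per minibatch (hypothetical sketch)."""

        def __init__(self, second_order=False):
            self.second_order = second_order
            self.N_sum = None  # callers compare this against their batch count

        def setup(self, value0, N_sum, allow_decay=True):
            # One slot per minibatch; value0 only supplies the shape.
            v0 = np.atleast_1d(np.asarray(value0, dtype=np.float64))
            self.N_sum = N_sum
            self.values = np.zeros((N_sum,) + v0.shape)
            self.weights = np.zeros(N_sum)
            self.sqvalues = np.zeros_like(self.values) if self.second_order else None

        def set_value(self, i, value, mu=None):
            # mu would be a momentum/decay coefficient; ignored in this sketch.
            prev_value, prev_weight = self.values[i].copy(), self.weights[i]
            self.values[i] = np.atleast_1d(value)
            self.weights[i] = 1.0
            if self.second_order:
                self.sqvalues[i] = self.values[i]**2
            return prev_value, prev_weight

        def get_wsum(self):
            return self.weights.sum()

        def get_w2sum(self):
            return (self.weights**2).sum()

        def _wmean(self, vals):
            flat = vals.reshape((self.N_sum, -1))
            m = np.dot(self.weights, flat) / max(self.get_wsum(), 1.0)
            return m.reshape(vals.shape[1:])

        def get_mean(self):
            return self._wmean(self.values)

        def get_meansq(self):
            # Used by SAGDStep as a variance proxy: E[g**2] - E[g]**2.
            return self._wmean(self.sqvalues) - self.get_mean()**2

        def reset_meansq(self):
            self.sqvalues[:] = 0.0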
Example #3
class UnknownRSLikelihood(Objective):
    def __init__(self):
        Objective.__init__(self,False)

    def setup(self,params,diagout,statout,ostream):
        Objective.setup(self,params,diagout,statout,ostream)

        if params['kernel'] == 'multicpu':
            from .cpu_kernel import UnknownRSThreadedCPUKernel
            self.kernel = UnknownRSThreadedCPUKernel()
        else:
            assert False, 'unsupported kernel: {0}'.format(params['kernel'])
        self.kernel.setup(params,diagout,statout,ostream)

    def get_sigma2_map(self,nu,model,rotavg=True):
        N = self.cryodata.N
        N_D = float(self.cryodata.N_D_Train)
        num_batches = float(self.cryodata.num_batches)
        base_sigma2 = self.cryodata.get_noise_std()**2

        mean_sigma2 = self.error_history.get_mean().reshape((N,N))
        mean_mask = self.mask_history.get_mean().reshape((N,N))
        mask_w = self.mask_history.get_wsum() * (N_D / num_batches)
        
        if rotavg:
            mean_sigma2 = cryoem.rotational_average(mean_sigma2,normalize=True,doexpand=True)
            mean_mask = cryoem.rotational_average(mean_mask,normalize=False,doexpand=True)

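        # MAP estimate of the noise power: a precision-weighted blend of the
        # observed mean squared error and the base noise estimate, with nu
        # acting as the prior's effective number of observations.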
        obsw = mask_w * mean_mask
        map_sigma2 = (mean_sigma2 * obsw + nu * base_sigma2) / (obsw + nu)

        assert np.all(np.isfinite(map_sigma2))

        if model == 'white':
            map_sigma2 = np.mean(map_sigma2)
        else:
            assert model == 'coloured', 'model must be one of white or coloured'

        return map_sigma2

    def get_envelope_map(self,sigma2,rho,env_lb=None,env_ub=None,minFreq=None,bfactor=None,rotavg=True):
        N = self.cryodata.N
        N_D = float(self.cryodata.N_D_Train)
        num_batches = float(self.cryodata.num_batches)
        psize = self.params['pixel_size']

        mean_corr = self.correlation_history.get_mean().reshape((N,N))
        mean_power = self.power_history.get_mean().reshape((N,N))
        mean_mask = self.mask_history.get_mean().reshape((N,N))
        mask_w = self.mask_history.get_wsum() * (N_D / num_batches)
        
        if rotavg:
            mean_corr = cryoem.rotational_average(mean_corr,normalize=True,doexpand=True)
            mean_power = cryoem.rotational_average(mean_power,normalize=True,doexpand=True)
            mean_mask = cryoem.rotational_average(mean_mask,normalize=False,doexpand=True)

        if isinstance(sigma2,np.ndarray):
            sigma2 = sigma2.reshape((N,N))

        if bfactor is not None:
            coords = gencoords(N,2).reshape((N**2,2))
            freqs = np.sqrt(np.sum(coords**2,axis=1))/(psize*N)
            prior_envelope = ctf.envelope_function(freqs,bfactor).reshape((N,N))
        else:
            prior_envelope = 1.0

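        # Regularized envelope estimate: per-frequency correlation over power,
        # shrunk towards prior_envelope with prior strength rho.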
        obsw = (mask_w * mean_mask / sigma2)
        exp_env = (mean_corr * obsw + prior_envelope*rho) / (mean_power * obsw + rho)
        
        if minFreq is not None:
            # Only consider envelope parameters for frequencies above a threshold
            minRad = minFreq*2.0*psize
    
            _, _, minRadMask = gencoords(N, 2, minRad, True)
            
            exp_env[minRadMask.reshape((N,N))] = 1.0
        
        if env_lb is not None or env_ub is not None:
            np.clip(exp_env,env_lb,env_ub,out=exp_env)

        return exp_env


    def get_rmse(self):
        return np.sqrt(self.get_sigma2_mle().mean())

    def get_sigma2_mle(self,noise_model='coloured'):
        N = self.cryodata.N
        sigma2 = self.error_history.get_mean()
        mean_mask = self.mask_history.get_mean()
        # mse = mean_mask*sigma2 + (1-mean_mask)*self.cryodata.data['imgvar_freq']
        mse = mean_mask*sigma2 + (1-mean_mask)*self.cryodata.data_var
        if noise_model == 'coloured':
            return mse.reshape((N,N))
        elif noise_model == 'white':
            return mse.mean()
        else:
            assert False, 'noise_model must be one of white or coloured'
            

    def get_envelope_mle(self,rotavg=False):
        N = self.cryodata.N

        mean_corr = self.correlation_history.get_mean()
        mean_power = self.power_history.get_mean()
        mean_mask = self.mask_history.get_mean()
        
        if rotavg:
            mean_corr = cryoem.rotational_average(mean_corr.reshape((N,N)),doexpand=True)
            mean_power = cryoem.rotational_average(mean_power.reshape((N,N)),doexpand=True)
            mean_mask = cryoem.rotational_average(mean_mask.reshape((N,N)),doexpand=True)

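        # MLE envelope: ratio of correlation to power, defined only at
        # frequencies that have actually been observed.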
        obs_mask = mean_mask > 0
        exp_env = np.ones_like(mean_corr)
        exp_env[obs_mask] = (mean_corr[obs_mask] / mean_power[obs_mask])

        return exp_env.reshape((N,N))
        

    def set_samplers(self,sampler_R,sampler_I,sampler_S):
        self.kernel.set_samplers(sampler_R,sampler_I,sampler_S)

    def set_dataset(self,cryodata):
        Objective.set_dataset(self,cryodata)
        self.kernel.set_dataset(cryodata)

        self.error_history = FiniteRunningSum(second_order=False)
        self.correlation_history = FiniteRunningSum(second_order=False)
        self.power_history = FiniteRunningSum(second_order=False)
        self.mask_history = FiniteRunningSum(second_order=False)
        
    def set_data(self,cparams,minibatch):
        Objective.set_data(self,cparams,minibatch)
        self.kernel.set_data(cparams,minibatch)
        
    def eval(self,M=None, compute_gradient=True, fM=None, **kwargs):
        tic_start = time.time()
        
        if self.kernel.slice_premult is not None:
            pfM = density.real_to_fspace(self.kernel.slice_premult * M)
        else:
            pfM = density.real_to_fspace(M)
        pmtime = time.time() - tic_start

        ret = self.kernel.eval(fM=pfM,M=None,compute_gradient=compute_gradient)
        
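        # Record per-frequency statistics (noise, correlation, power, mask)
        # for non-test, non-intermediate evaluations.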
        if not self.minibatch['test_batch'] and not kwargs.get('intermediate',False):
            tic_record = time.time()
            curr_var = ret[-1]['sigma2_est']
            assert np.all(np.isfinite(curr_var))
            if self.error_history.N_sum != self.cryodata.N_batches:
                self.error_history.setup(curr_var,self.cryodata.N_batches,allow_decay=False)
            self.error_history.set_value(self.minibatch['id'],curr_var)

            curr_corr = ret[-1]['correlation']
            assert np.all(np.isfinite(curr_corr))
            if self.correlation_history.N_sum != self.cryodata.N_batches:
                self.correlation_history.setup(curr_corr,self.cryodata.N_batches,allow_decay=False)
            self.correlation_history.set_value(self.minibatch['id'],curr_corr)

            curr_power = ret[-1]['power']
            assert np.all(np.isfinite(curr_power))
            if self.power_history.N_sum != self.cryodata.N_batches:
                self.power_history.setup(curr_power,self.cryodata.N_batches,allow_decay=False)
            self.power_history.set_value(self.minibatch['id'],curr_power)

            curr_mask = self.kernel.truncmask
            if self.mask_history.N_sum != self.cryodata.N_batches:
                self.mask_history.setup(np.require(curr_mask,dtype=np.float32),self.cryodata.N_batches,allow_decay=False)
            self.mask_history.set_value(self.minibatch['id'],curr_mask)
            ret[-1]['like_timing']['record'] = time.time() - tic_record
        
        if compute_gradient and self.kernel.slice_premult is not None:
            tic_record = time.time()
            ret = (ret[0],self.kernel.slice_premult * density.fspace_to_real(ret[1]),ret[2])
            ret[-1]['like_timing']['premult'] = pmtime + time.time() - tic_record
            
        ret[-1]['like_timing']['total'] = time.time() - tic_start
        
        return ret
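
Note: get_sigma2_map and get_envelope_map above are precision-weighted
shrinkage estimators of the same form. A tiny self-contained illustration
(function and argument names here are hypothetical, not part of the code
above):

    import numpy as np

    def shrink(observed, obs_weight, prior, prior_weight):
        # Precision-weighted blend: with little data the prior dominates,
        # with lots of data the observed estimate wins out.
        return (observed * obs_weight + prior * prior_weight) / \
               (obs_weight + prior_weight)

    # get_sigma2_map:   shrink(mean_sigma2, obsw, base_sigma2, nu)
    # get_envelope_map: shrink(mean_corr/mean_power, mean_power*obsw,
    #                          prior_envelope, rho)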
Example #4
class CryoOptimizer(BackgroundWorker):
    def outputbatchinfo(self,batch,res,logP,prefix,name):
        diag = {}
        stat = {}
        like = {}
        
        N_M = batch['N_M']
        cepoch = self.cryodata.get_epoch(frac=True)
        epoch = self.cryodata.get_epoch()
        num_data = self.cryodata.N_D_Train
        sigma = n.sqrt(n.mean(res['Evar_like']))
        sigma_prior = n.sqrt(n.mean(res['Evar_prior']))
        
        self.ostream('  {0} Batch:'.format(name))

        for suff in ['R','I','S']:
            diag[prefix+'_CV2_'+suff] = res['CV2_'+suff]

        diag[prefix+'_idxs'] = batch['img_idxs']
        diag[prefix+'_sigma2_est'] = res['sigma2_est']
        diag[prefix+'_correlation'] = res['correlation']
        diag[prefix+'_power'] = res['power']

#         self.ostream("    RMS Error: %g" % (sigma/n.sqrt(self.cryodata.noise_var)))
        self.ostream("    RMS Error: %g, Signal: %g" % (sigma/n.sqrt(self.cryodata.noise_var), \
                                                        sigma_prior/n.sqrt(self.cryodata.noise_var)))
        self.ostream("    Effective # of R / I / S:     %.2f / %.2f / %.2f " %\
                      (n.mean(res['CV2_R']), n.mean(res['CV2_I']),n.mean(res['CV2_S'])))

        # Importance Sampling Statistics
        is_speedups = []
        for suff in ['R','I','S','Total']:
            if self.cparams.get('is_on_'+suff,False) or (suff == 'Total' and len(is_speedups) > 0):
                spdup = N_M/res['N_' + suff + '_sampled_total']
                is_speedups.append((suff,spdup,n.mean(res['N_'+suff+'_sampled']),res['N_'+suff]))
                stat[prefix+'_is_speedup_'+suff] = [spdup]
            else:
                stat[prefix+'_is_speedup_'+suff] = [1.0]

        if len(is_speedups) > 0:
            lblstr = is_speedups[0][0]
            numstr = '%.2f (%d of %d)' % is_speedups[0][1:]
            for i in range(1,len(is_speedups)):
                lblstr += ' / ' + is_speedups[i][0]
                numstr += ' / %.2f (%d of %d)' % is_speedups[i][1:]
            
            self.ostream("    IS Speedup {0}: {1}".format(lblstr,numstr))

        stat[prefix+'_sigma'] = [sigma]
        stat[prefix+'_logp'] = [logP]
        stat[prefix+'_like'] = [res['L']]
        stat[prefix+'_num_data'] = [num_data]
        stat[prefix+'_num_data_evals'] = [self.num_data_evals]
        stat[prefix+'_iteration'] = [self.iteration]
        stat[prefix+'_epoch'] = [epoch]
        stat[prefix+'_cepoch'] = [cepoch]
        stat[prefix+'_time'] = [time.time()]

        for k,v in res['like_timing'].iteritems():
            stat[prefix+'_like_timing_'+k] = [v]
        
        Idxs = batch['img_idxs']
        self.img_likes[Idxs] = res['like']
        like['img_likes'] = self.img_likes
        like['train_idxs'] = self.cryodata.train_idxs
        like['test_idxs'] = self.cryodata.test_idxs
        keepidxs = self.cryodata.train_idxs if prefix == 'train' else self.cryodata.test_idxs
        keeplikes = self.img_likes[keepidxs]
        keeplikes = keeplikes[n.isfinite(keeplikes)]
        quants = n.percentile(keeplikes, range(0,101))
        stat[prefix+'_full_like_quantiles'] = [quants]
        quants = n.percentile(res['like'], range(0,101))
        stat[prefix+'_mini_like_quantiles'] = [quants]
        stat[prefix+'_num_like_quantiles'] = [len(keeplikes)]

        self.diagout.output(**diag)
        self.statout.output(**stat)
        self.likeout.output(**like)

    def ioworker(self):
        while True:
            iotype,fname,data = self.io_queue.get()
            
            try:
                if iotype == 'mrc':
                    writeMRC(fname,*data)
                elif iotype == 'pkl':
                    with open(fname, 'wb') as f:
                        cPickle.dump(data, f, protocol=-1)
                elif iotype == 'cp':
                    copyfile(fname,data)
            except:
                print "ERROR DUMPING {0}: {1}".format(fname, sys.exc_info()[0])
                
            self.io_queue.task_done()
        
    def __init__(self, expbase, cmdparams=None):
        """expbase is the path to the base folder where this experiment's
        files will be stored.  The folder above expbase will also be
        searched for .params files; these will be loaded first."""
        BackgroundWorker.__init__(self)

        # Create a background thread which handles IO
        self.io_queue = Queue()
        self.io_thread = Thread(target=self.ioworker)
        self.io_thread.daemon = True
        self.io_thread.start()

        # General setup ----------------------------------------------------
        self.expbase = expbase
        self.outbase = None

        # Parameter setup ---------------------------------------------------
        # search above expbase for params files
        _,_,filenames = os.walk(opj(expbase,'../')).next()
        self.paramfiles = [opj(opj(expbase,'../'), fname) \
                           for fname in filenames if fname.endswith('.params')]
        # search expbase for params files
        _,_,filenames = os.walk(opj(expbase)).next()
        self.paramfiles += [opj(expbase,fname)  \
                            for fname in filenames if fname.endswith('.params')]
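        # local.params is appended again, presumably so that it is loaded
        # last and overrides the other .params files.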
        if 'local.params' in filenames:
            self.paramfiles += [opj(expbase,'local.params')]
        # load parameter files
        self.params = Params(self.paramfiles)
        self.cparams = None
        
        if cmdparams is not None:
            # Set parameters specified on the command line
            for k,v in cmdparams.iteritems():
                self.params[k] = v
                
        # Dataset setup -------------------------------------------------------
        self.imgpath = self.params['inpath']
        psize = self.params['resolution']
        if not isinstance(self.imgpath,list):
            imgstk = MRCImageStack(self.imgpath,psize)
        else:
            imgstk = CombinedImageStack([MRCImageStack(cimgpath,psize) for cimgpath in self.imgpath])

        if self.params.get('float_images',True):
            imgstk.float_images()
        
        self.ctfpath = self.params['ctfpath']
        mscope_params = self.params['microscope_params']
         
        if not isinstance(self.ctfpath,list):
            ctfstk = CTFStack(self.ctfpath,mscope_params)
        else:
            ctfstk = CombinedCTFStack([CTFStack(cctfpath,mscope_params) for cctfpath in self.ctfpath])


        self.cryodata = CryoDataset(imgstk,ctfstk)
        self.cryodata.compute_noise_statistics()
        if self.params.get('window_images',True):
            imgstk.window_images()
        minibatch_size = self.params['minisize']
        testset_size = self.params['test_imgs']
        partition = self.params.get('partition',0)
        num_partitions = self.params.get('num_partitions',1)
        seed = self.params['random_seed']
        if isinstance(partition,str):
            partition = eval(partition)
        if isinstance(num_partitions,str):
            num_partitions = eval(num_partitions)
        if isinstance(seed,str):
            seed = eval(seed)
        self.cryodata.divide_dataset(minibatch_size,testset_size,partition,num_partitions,seed)
        
        self.cryodata.set_datasign(self.params.get('datasign','auto'))
        if self.params.get('normalize_data',True):
            self.cryodata.normalize_dataset()

        self.voxel_size = self.cryodata.pixel_size


        # Iterations setup -------------------------------------------------
        self.iteration = 0 
        self.tic_epoch = None
        self.num_data_evals = 0
        self.eval_params()

        outdir = self.cparams.get('outdir',None)
        if outdir is None:
            if self.cparams.get('num_partitions',1) > 1:
                outdir = 'partition{0}'.format(self.cparams['partition'])
            else:
                outdir = ''
        self.outbase = opj(self.expbase,outdir)
        if not os.path.isdir(self.outbase):
            os.makedirs(self.outbase) 

        # Output setup -----------------------------------------------------
        self.ostream = OutputStream(opj(self.outbase,'stdout'))

        self.ostream(80*"=")
        self.ostream("Experiment: " + expbase + \
                     "    Kernel: " + self.params['kernel'])
        self.ostream("Started on " + socket.gethostname() + \
                     "    At: " + time.strftime('%B %d %Y: %I:%M:%S %p'))
        self.ostream("Git SHA1: " + gitutil.git_get_SHA1())
        self.ostream(80*"=")
        gitutil.git_info_dump(opj(self.outbase, 'gitinfo'))
        self.startdatetime = datetime.now()


        # for diagnostics and parameters
        self.diagout = Output(opj(self.outbase, 'diag'),runningout=False)
        # for stats (per image etc)
        self.statout = Output(opj(self.outbase, 'stat'),runningout=True)
        # for likelihoods of individual images
        self.likeout = Output(opj(self.outbase, 'like'),runningout=False)

        self.img_likes = n.empty(self.cryodata.N_D)
        self.img_likes[:] = n.inf

        # optimization state vars ------------------------------------------
        init_model = self.cparams.get('init_model',None)
        if init_model is not None:
            filename = init_model
            if filename.upper().endswith('.MRC'):
                M = readMRC(filename)
            else:
                with open(filename) as fp:
                    M = cPickle.load(fp)
                    if type(M)==list:
                        M = M[-1]['M'] 
            if M.shape != 3*(self.cryodata.N,):
                M = cryoem.resize_ndarray(M,3*(self.cryodata.N,),axes=(0,1,2))
        else:
            init_seed = self.cparams.get('init_random_seed',0)  + self.cparams.get('partition',0)
            print "Randomly generating initial density (init_random_seed = {0})...".format(init_seed), ; sys.stdout.flush()
            tic = time.time()
            M = cryoem.generate_phantom_density(self.cryodata.N, 0.95*self.cryodata.N/2.0, \
                                                5*self.cryodata.N/128.0, 30, seed=init_seed)
            print "done in {0}s".format(time.time() - tic)

        tic = time.time()
        print "Windowing and aligning initial density...", ; sys.stdout.flush()
        # window the initial density
        wfunc = self.cparams.get('init_window','circle')
        cryoem.window(M,wfunc)

        # Center and orient the initial density
        cryoem.align_density(M)
        print "done in {0:.2f}s".format(time.time() - tic)

        # apply the symmetry operator
        init_sym = get_symmetryop(self.cparams.get('init_symmetry',self.cparams.get('symmetry',None)))
        if init_sym is not None:
            tic = time.time()
            print "Applying symmetry operator...", ; sys.stdout.flush()
            M = init_sym.apply(M)
            print "done in {0:.2f}s".format(time.time() - tic)

        tic = time.time()
        print "Scaling initial model...", ; sys.stdout.flush()
        modelscale = self.cparams.get('modelscale','auto')
        mleDC, _, mleDC_est_std = self.cryodata.get_dc_estimate()
        if modelscale == 'auto':
            # Err on the side of a weaker prior by using a larger value for modelscale
            modelscale = (n.abs(mleDC) + 2*mleDC_est_std)/self.cryodata.N
            print "estimated modelscale = {0:.3g}...".format(modelscale), ; sys.stdout.flush()
            self.params['modelscale'] = modelscale
            self.cparams['modelscale'] = modelscale
        M *= modelscale/M.sum()
        print "done in {0:.2f}s".format(time.time() - tic)
        if mleDC_est_std/n.abs(mleDC) > 0.05:
            print "  WARNING: the DC component estimate has a high relative variance, it may be inaccurate!"
        if ((modelscale*self.cryodata.N - n.abs(mleDC)) / mleDC_est_std) > 3:
            print "  WARNING: the selected modelscale value is more than 3 std devs away from the estimated one.  Be sure this is correct."

        self.M = n.require(M,dtype=density.real_t)
        self.fM = density.real_to_fspace(M)
        self.dM = density.zeros_like(self.M)

        self.step = eval(self.cparams['optim_algo'])
        self.step.setup(self.cparams, self.diagout, self.statout, self.ostream)

        # Objective function setup --------------------------------------------
        param_type = self.cparams.get('parameterization','real')
        cplx_param = param_type in ['complex','complex_coeff','complex_herm_coeff']
        self.like_func = eval_objective(self.cparams['likelihood'])
        self.prior_func = eval_objective(self.cparams['prior'])

        if self.cparams.get('penalty',None) is not None:
            self.penalty_func = eval_objective(self.cparams['penalty'])
            prior_func = SumObjectives(self.prior_func.fspace, \
                                       [self.penalty_func,self.prior_func], None)
        else:
            prior_func = self.prior_func

        self.obj = SumObjectives(cplx_param,
                                 [self.like_func,prior_func], [None,None])
        self.obj.setup(self.cparams, self.diagout, self.statout, self.ostream)
        self.obj.set_dataset(self.cryodata)
        self.obj_wrapper = ObjectiveWrapper(param_type)

        self.last_save = time.time()
        
        self.logpost_history = FiniteRunningSum()
        self.like_history = FiniteRunningSum()

        # Importance Samplers -------------------------------------------------
        self.is_sym = get_symmetryop(self.cparams.get('is_symmetry',self.cparams.get('symmetry',None)))
        self.sampler_R = FixedFisherImportanceSampler('_R',self.is_sym)
        self.sampler_I = FixedFisherImportanceSampler('_I')
        self.sampler_S = FixedGaussianImportanceSampler('_S')
        self.like_func.set_samplers(sampler_R=self.sampler_R,sampler_I=self.sampler_I,sampler_S=self.sampler_S)

    def eval_params(self):
        # cvars are state variables that can be used in parameter expressions
        cvars = {}
        cvars['cepoch'] = self.cryodata.get_epoch(frac=True)
        cvars['epoch'] = self.cryodata.get_epoch()
        cvars['iteration'] = self.iteration
        cvars['num_data'] = self.cryodata.N_D_Train
        cvars['num_batches'] = self.cryodata.N_batches
        cvars['noise_std'] = n.sqrt(self.cryodata.noise_var)
        cvars['data_std'] = n.sqrt(self.cryodata.data_var)
        cvars['voxel_size'] = self.voxel_size
        cvars['pixel_size'] = self.cryodata.pixel_size
        cvars['prev_max_frequency'] = self.cparams['max_frequency'] if self.cparams is not None else None

        # prelist fields are parameters that can be used in evaluating other parameter
        # expressions, they can only depend on values defined in cvars
        prelist = ['max_frequency']
        
        skipfields = set(['inpath','ctfpath'])

        cvars = self.params.partial_evaluate(prelist,**cvars)
        if self.cparams is None:
            self.max_frequency_changes = 0
        else:
            self.max_frequency_changes += cvars['max_frequency'] != cvars['prev_max_frequency']
                
        cvars['num_max_frequency_changes'] =  self.max_frequency_changes
        cvars['max_frequency_changed'] = cvars['max_frequency'] != cvars['prev_max_frequency']
        self.cparams = self.params.evaluate(skipfields,**cvars)

        self.cparams['exp_path'] = self.expbase
        self.cparams['out_path'] = self.outbase

        if 'name' not in self.cparams:
            self.cparams['name'] = '{0} - {1} - {2} ({3})'.format(self.cparams['dataset_name'], self.cparams['prior_name'], self.cparams['optimizer_name'], self.cparams['kernel'])

    def run(self):
        while self.dowork(): pass
        print "Waiting for IO queue to clear...",  ; sys.stdout.flush()
        self.io_queue.join()
        print "done."  ; sys.stdout.flush()

    def begin(self):
        BackgroundWorker.begin(self)

    def end(self):
        BackgroundWorker.end(self)

    def dowork(self):
        """Do one atom of work, i.e., execute one minibatch."""

        timing = {}
        # Time each minibatch
        tic_mini = time.time()

        self.iteration += 1

        # Fetch the current batches
        trainbatch = self.cryodata.get_next_minibatch(self.cparams.get('shuffle_minibatches',True))

        # Get the current epoch
        cepoch = self.cryodata.get_epoch(frac=True)
        epoch = self.cryodata.get_epoch()
        num_data = self.cryodata.N_D_Train

        # Evaluate the parameters
        self.eval_params()
        timing['setup'] = time.time() - tic_mini

        # Do hyperparameter learning
        if self.cparams.get('learn_params',False):
            tic_learn = time.time()
            if self.cparams.get('learn_prior_params',True):
                tic_learn_prior = time.time()
                self.prior_func.learn_params(self.params, self.cparams, M=self.M, fM=self.fM)
                timing['learn_prior'] = time.time() - tic_learn_prior 

            if self.cparams.get('learn_likelihood_params',True):
                tic_learn_like = time.time()
                self.like_func.learn_params(self.params, self.cparams, M=self.M, fM=self.fM)
                timing['learn_like'] = time.time() - tic_learn_like
                
            if self.cparams.get('learn_prior_params',True) or self.cparams.get('learn_likelihood_params',True):
                timing['learn_total'] = time.time() - tic_learn   

        # Time each epoch
        if self.tic_epoch is None:
            self.ostream("Epoch: %d" % epoch)
            self.tic_epoch = (tic_mini,epoch)
        elif self.tic_epoch[1] != epoch:
            self.ostream("Epoch Total - %.6f seconds " % \
                         (tic_mini - self.tic_epoch[0]))
            self.tic_epoch = (tic_mini,epoch)

        sym = get_symmetryop(self.cparams.get('symmetry',None))
        if sym is not None:
            self.obj.ws[1] = 1.0/sym.get_order()

        tic_mstats = time.time()
        self.ostream(self.cparams['name']," Iteration:", self.iteration,\
                     " Epoch:", epoch, " Host:", socket.gethostname())

        # Compute density statistics
        N = self.cryodata.N
        M_sum = self.M.sum(dtype=n.float64)
        M_zeros = (self.M == 0).sum()
        M_mean = M_sum/N**3
        M_max = self.M.max()
        M_min = self.M.min()
#         self.ostream("  Density (min/max/avg/sum/zeros): " +
#                      "%.2e / %.2e / %.2e / %.2e / %g " %
#                      (M_min, M_max, M_mean, M_sum, M_zeros))
        self.statout.output(total_density=[M_sum],
                            avg_density=[M_mean],
                            nonzero_density=[M_zeros],
                            max_density=[M_max],
                            min_density=[M_min])
        timing['density_stats'] = time.time() - tic_mstats

        # evaluate test batch if requested
        if self.iteration <= 1 or self.cparams.get('evaluate_test_set',self.iteration%5):
            tic_test = time.time()
            testbatch = self.cryodata.get_testbatch()

            self.obj.set_data(self.cparams,testbatch)
            testLogP, res_test = self.obj.eval(M=self.M, fM=self.fM,
                                               compute_gradient=False)

            self.outputbatchinfo(testbatch, res_test, testLogP, 'test', 'Test')
            timing['test_batch'] = time.time() - tic_test
        else:
            testLogP, res_test = None, None

        # setup the wrapper for the objective function 
        tic_objsetup = time.time()
        self.obj.set_data(self.cparams,trainbatch)
        self.obj_wrapper.set_objective(self.obj)
        x0 = self.obj_wrapper.set_density(self.M,self.fM)
        evalobj = self.obj_wrapper.eval_obj
        timing['obj_setup'] = time.time() - tic_objsetup

        # Get step size
        self.num_data_evals += trainbatch['N_M']  # at least one gradient
        tic_objstep = time.time()
        trainLogP, dlogP, v, res_train, extra_num_data = self.step.do_step(x0,
                                                         self.cparams,
                                                         self.cryodata,
                                                         evalobj,
                                                         batch=trainbatch)

        # Apply the step
        x = x0 + v
        timing['step'] = time.time() - tic_objstep

        # Convert from parameters to value
        tic_stepfinalize = time.time()
        prevM = n.copy(self.M)
        self.M, self.fM = self.obj_wrapper.convert_parameter(x,comp_real=True)
 
        apply_sym = sym is not None and self.cparams.get('perfect_symmetry',True) and self.cparams.get('apply_symmetry',True)
        if apply_sym:
            self.M = sym.apply(self.M)

        # Truncate the density to bounds if they exist
        if self.cparams['density_lb'] is not None:
            n.maximum(self.M,self.cparams['density_lb']*self.cparams['modelscale'],out=self.M)
        if self.cparams['density_ub'] is not None:
            n.minimum(self.M,self.cparams['density_ub']*self.cparams['modelscale'],out=self.M)

        # Compute net change
        self.dM = prevM - self.M

        # Convert to fourier space (may not be required)
        if self.fM is None or apply_sym \
           or self.cparams['density_lb'] is not None \
           or self.cparams['density_ub'] is not None:
            self.fM = density.real_to_fspace(self.M)
        timing['step_finalize'] = time.time() - tic_stepfinalize

        # Compute step statistics
        tic_stepstats = time.time()
        step_size = n.linalg.norm(self.dM)
        grad_size = n.linalg.norm(dlogP)
        M_norm = n.linalg.norm(self.M)

        self.num_data_evals += extra_num_data
        inc_ratio = step_size / M_norm
        self.statout.output(step_size=[step_size],
                            inc_ratio=[inc_ratio],
                            grad_size=[grad_size],
                            norm_density=[M_norm])
        timing['step_stats'] = time.time() - tic_stepstats


        # Update importance sampling distributions
        tic_isupdate = time.time()
        self.sampler_R.perform_update()
        self.sampler_I.perform_update()
        self.sampler_S.perform_update()

        self.diagout.output(global_phi_R=self.sampler_R.get_global_dist())
        self.diagout.output(global_phi_I=self.sampler_I.get_global_dist())
        self.diagout.output(global_phi_S=self.sampler_S.get_global_dist())
        timing['is_update'] = time.time() - tic_isupdate
        
        # Output basic diagnostics
        tic_diagnostics = time.time()
        self.diagout.output(iteration=self.iteration, epoch=epoch, cepoch=cepoch)

        if self.logpost_history.N_sum != self.cryodata.N_batches:
            self.logpost_history.setup(trainLogP,self.cryodata.N_batches)
        self.logpost_history.set_value(trainbatch['id'],trainLogP)

        if self.like_history.N_sum != self.cryodata.N_batches:
            self.like_history.setup(res_train['L'],self.cryodata.N_batches)
        self.like_history.set_value(trainbatch['id'],res_train['L'])

        self.outputbatchinfo(trainbatch, res_train, trainLogP, 'train', 'Train')

        # Dump parameters here to catch the defaults used in evaluation
        self.diagout.output(params=self.cparams,
                            envelope_mle=self.like_func.get_envelope_mle(),
                            sigma2_mle=self.like_func.get_sigma2_mle(),
                            hostname=socket.gethostname())
        self.statout.output(num_data=[num_data],
                            num_data_evals=[self.num_data_evals],
                            iteration=[self.iteration],
                            epoch=[epoch],
                            cepoch=[cepoch],
                            logp=[self.logpost_history.get_mean()],
                            like=[self.like_history.get_mean()],
                            sigma=[self.like_func.get_rmse()],
                            time=[time.time()])
        timing['diagnostics'] = time.time() - tic_diagnostics

        checkpoint_it = self.iteration % self.cparams.get('checkpoint_frequency',50) == 0 
        save_it = checkpoint_it or self.cparams['save_iteration'] or \
                  time.time() - self.last_save > self.cparams.get('save_time',n.inf)
                  
        if save_it:
            tic_save = time.time()
            self.last_save = tic_save
            if self.io_queue.qsize():
                print "Warning: IO queue has become backlogged with {0} remaining, waiting for it to clear".format(self.io_queue.qsize())
                self.io_queue.join()
            self.io_queue.put(( 'pkl', self.statout.fname, copy(self.statout.outdict) ))
            self.io_queue.put(( 'pkl', self.diagout.fname, deepcopy(self.diagout.outdict) ))
            self.io_queue.put(( 'pkl', self.likeout.fname, deepcopy(self.likeout.outdict) ))
            self.io_queue.put(( 'mrc', opj(self.outbase,'model.mrc'), \
                                (n.require(self.M,dtype=density.real_t),self.voxel_size) ))
            self.io_queue.put(( 'mrc', opj(self.outbase,'dmodel.mrc'), \
                                (n.require(self.dM,dtype=density.real_t),self.voxel_size) ))

            if checkpoint_it:
                self.io_queue.put(( 'cp', self.diagout.fname, self.diagout.fname+'-{0:06}'.format(self.iteration) ))
                self.io_queue.put(( 'cp', self.likeout.fname, self.likeout.fname+'-{0:06}'.format(self.iteration) ))
                self.io_queue.put(( 'cp', opj(self.outbase,'model.mrc'), opj(self.outbase,'model-{0:06}.mrc'.format(self.iteration)) ))
            timing['save'] = time.time() - tic_save
                
            
        time_total = time.time() - tic_mini
        self.ostream("  Minibatch Total - %.2f seconds                         Total Runtime - %s" %
                     (time_total, format_timedelta(datetime.now() - self.startdatetime) ))

        
        return self.iteration < self.cparams.get('max_iterations',n.inf) and \
               cepoch < self.cparams.get('max_epochs',n.inf)
Example #5
class SAGDStep(base.BaseStep):
    def __init__(self, ind_L=False, ng=False, alt_sum=True):
        base.BaseStep.__init__(self)

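        # Running per-batch gradient history; second-order statistics are
        # tracked when natural-gradient preconditioning (ng) is enabled.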
        self.g_history = FiniteRunningSum(second_order=ng)

        self.L = None

        self.alt_sum = alt_sum  # FIXME: currently ignored
        self.ind_L = ind_L  # FIXME: currently ignored

        self.precond = None
        self.ng = ng

        self.prev_max_freq = None

    def add_batch_gradient(self, batch, curr_g, params):
        mu = params.get('sagd_momentum', None)

        if self.g_history.N_sum != params['num_batches']:
            # Reset the gradient history if the number of minibatches changes.
            self.g_history.setup(curr_g, params['num_batches'])

        return self.g_history.set_value(batch['id'], curr_g, mu)

    def save_L0(self, params):
        with open(os.path.join(params['exp_path'], 'sagd_L0.pkl'), 'wb') as fp:
            pickle.dump(self.L, fp, protocol=2)

    def load_L0(self, params):
        try:
            with open(os.path.join(params['exp_path'], 'sagd_L0.pkl'), 'rb') as fp:
                L0 = pickle.load(fp)
        except:
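            # No saved L0 (or it failed to load): fall back to the configured default.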
            L0 = params.get('sagd_L0', 1.0)
        return L0

    def do_step(self, x, params, cryodata, evalobj, batch, **kwargs):
        # ostream = self.ostream
        # ostream = None
        ostream = self.ostream if self.L is None else None
        inc_L = params.get('sagd_incL', False)
        do_ls = self.L is None or params.get('sagd_linesearch', False)
        max_ls_its = params.get('sagd_linesearch_maxits',
                                5) if self.L is not None else None
        grad_check = params.get('sagd_gradcheck', False)
        minStep = 1.05 if self.L is not None else 2
        maxStep = 100.0 if self.L is not None else 10
        optThresh = params.get('sagd_linesearch_accuracy', 1.01)
        # weight of the prior which biases the covariance estimate towards
        # lambda
        g2_lambdaw = params.get('sagd_lambdaw', 10)
        eps0 = params.get('sagd_learnrate', 1.0 / 16.0)
        reset_precond = params.get('sagd_reset_precond', False)
        F_max_range = params.get('sagd_precond_maxrange', 1000.0)
        use_saga = params.get('sagd_sagastep', False)

        if self.g_history is not None and reset_precond:
            self.g_history.reset_meansq()
            do_ls = True

        # Evaluate the gradient
        f, g, res_train = evalobj(x, all_grads=True)

        assert len(res_train['all_dlogPs']) == 2
        g_like = res_train['all_dlogPs'][0]
        g_prior = res_train['all_dlogPs'][1]

        if use_saga:
            prev_g_hat = self.g_history.get_mean()

        # Add the gradient to the g_history
        prev_g_like, _ = self.add_batch_gradient(batch, g_like, params)

        # Get the current average g
        if not use_saga:
            curr_g_hat = self.g_history.get_mean()
            totghat = curr_g_hat + g_prior
        else:
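            # SAGA-style estimate: correct the stored running mean with the
            # fresh-minus-stored gradient for this batch (unbiased, unlike SAG).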
            totghat = prev_g_hat + g_prior + (g_like - prev_g_like)

        if self.ng:
            # Update the preconditioner when we're going to do a linesearch
            if do_ls:
                curr_g_var = np.maximum(0, self.g_history.get_meansq())
                curr_w2sum = self.g_history.get_w2sum()
                mean_var = np.mean(curr_g_var)
                if curr_w2sum > 1:
                    F = mean_var * (g2_lambdaw / (curr_w2sum + g2_lambdaw)) + \
                        curr_g_var * (curr_w2sum / (curr_w2sum + g2_lambdaw))
                else:
                    F = mean_var

                minF = F.min()
                maxF = F.max()
                minF_trunc = np.maximum(minF, maxF / F_max_range)

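                # Diagonal preconditioner in [1/sqrt(F_max_range), 1]: damp
                # directions with high gradient variance, keep unit scale on
                # the quietest ones.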
                self.precond = np.sqrt(minF_trunc) / \
                    np.sqrt(np.maximum(minF_trunc, F))
                if ostream is not None:
                    ostream("  F min/max = {0} / {1}, var min/max/mean = {2} / {3} / {4}".format(np.min(F), np.max(F),
                                                                                                 np.min(curr_g_var), np.max(curr_g_var), mean_var))
                    # ostream("  precond range = ({0},{1}), Lscale = {2}".format(precond.min(),precond.max(),Lscale))
            precond = self.precond
        else:
            precond = None

        init_L = self.L is None

        # Gradually increase the current L if requested
        if inc_L != 1.0 and not init_L:
            self.L *= inc_L

        if not init_L:
            L0 = self.L
        else:
            L0 = self.load_L0(params)

        if do_ls:
            # Perform line search if we haven't found a value of L yet
            # and/or check that the current L satisfies the conditions
            self.L = find_L(x, f, g, evalobj, L0, max_ls_its,
                            gradientCheck=grad_check,
                            ostream=ostream, precond=precond,
                            minStep=minStep, maxStep=maxStep,
                            optThresh=optThresh)

        currL = self.L

        if init_L:
            self.save_L0(params)

        eps = eps0 / currL

        dx = -eps * totghat
        if self.ng:
            self.statout.output(sagd_precond_min=[precond.min()],
                                sagd_precond_max=[precond.max()])
            dx *= precond**2
        #     if ostream is not None:
        #         ostream("  step size range = ({0},{1})".format(precond.min()**2/currL,precond.max()**2/currL))
        # else:
        #     if ostream is not None:
        #         ostream("  step size = {0}".format(1.0/currL))

        # dMnorm = n.linalg.norm(dM)
        # print "||dM||/eps = {0}, ||dM|| / (||M|| * eps) =
        # {1}".format(dMnorm/eps,dMnorm/(Mnorm*eps))
        ghatnorm = np.linalg.norm(totghat)
        self.statout.output(sagd_L=[self.L],
                            sagd_gnorm=[ghatnorm],
                            sagd_eps=[eps])

        # 0 extra operations on data
        return f, g.reshape((-1, 1)), dx, res_train, 0
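
Note: CryoOptimizer in Example #4 shows the intended call pattern for this
class: the step object is constructed from the optim_algo parameter, setup()
is called once with (cparams, diagout, statout, ostream), and every minibatch
then calls do_step(x0, cparams, cryodata, evalobj, batch=trainbatch), which
returns (logP, gradient, step direction, train results, extra data
evaluations); the optimizer applies the result as x = x0 + v.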