def setup(self,params,diagout,statout,ostream):
        """Run the base Objective setup, then create and set up the kernel.

        params['kernel'] selects the kernel implementation; 'multicpu' is the
        only value handled here. The remaining arguments are forwarded
        unchanged to Objective.setup and to the kernel's own setup.

        Raises:
            AssertionError: if params['kernel'] is not 'multicpu'.
        """
        Objective.setup(self,params,diagout,statout,ostream)

        kernel_name = params['kernel']
        if kernel_name != 'multicpu':
            # Raised explicitly instead of `assert False` so the check still
            # fires under `python -O`; the exception type is kept for callers.
            raise AssertionError(
                "unsupported kernel %r (only 'multicpu' is available)" % (kernel_name,))

        # Imported lazily so the kernel module is only loaded when selected.
        from cpu_kernel import UnknownRSThreadedCPUKernel
        self.kernel = UnknownRSThreadedCPUKernel()
        self.kernel.setup(params,diagout,statout,ostream)
class UnknownRSLikelihood(Objective):
    """Objective whose evaluation is delegated to a kernel object.

    Besides forwarding configuration and data to the kernel, the class keeps
    per-minibatch running histories (sigma2 estimate, correlation, power and
    truncation mask) recorded in eval(), and derives noise-variance and
    envelope estimates from the means of those histories.
    """

    def __init__(self):
        Objective.__init__(self,False)

    def setup(self,params,diagout,statout,ostream):
        """Run the base Objective setup, then create and set up the kernel.

        params['kernel'] selects the kernel implementation; 'multicpu' is the
        only value handled here. The remaining arguments are forwarded
        unchanged to Objective.setup and to the kernel's own setup.

        Raises:
            AssertionError: if params['kernel'] is not 'multicpu'.
        """
        Objective.setup(self,params,diagout,statout,ostream)

        kernel_name = params['kernel']
        if kernel_name != 'multicpu':
            # Raised explicitly instead of `assert False` so the check still
            # fires under `python -O`; the exception type is kept for callers.
            raise AssertionError(
                "unsupported kernel %r (only 'multicpu' is available)" % (kernel_name,))

        # Imported lazily so the kernel module is only loaded when selected.
        from cpu_kernel import UnknownRSThreadedCPUKernel
        self.kernel = UnknownRSThreadedCPUKernel()
        self.kernel.setup(params,diagout,statout,ostream)

    def get_sigma2_map(self,nu,model,rotavg=True):
        """Blend the recorded sigma2 history with the dataset's base noise variance.

        The result is
            (mean_sigma2 * obsw + nu * base_sigma2) / (obsw + nu)
        where mean_sigma2 is the mean of the error history, base_sigma2 is
        cryodata.get_noise_std()**2 and obsw is the mean mask scaled by the
        mask history weight sum times N_D_Train / num_batches.

        Args:
            nu: weight given to base_sigma2 in the blend.
            model: 'coloured' returns the (N,N) map, 'white' returns its mean.
            rotavg: if true, rotationally average the mean sigma2 and mean
                mask (via cryoem.rotational_average) before blending.

        Raises:
            AssertionError: if model is neither 'white' nor 'coloured', or if
                the blended map contains non-finite values (the latter is a
                plain assert and is skipped under `python -O`).
        """
        N = self.cryodata.N
        N_D = float(self.cryodata.N_D_Train)
        num_batches = float(self.cryodata.num_batches)
        base_sigma2 = self.cryodata.get_noise_std()**2

        mean_sigma2 = self.error_history.get_mean().reshape((N,N))
        mean_mask = self.mask_history.get_mean().reshape((N,N))
        mask_w = self.mask_history.get_wsum() * (N_D / num_batches)

        if rotavg:
            mean_sigma2 = cryoem.rotational_average(mean_sigma2,normalize=True,doexpand=True)
            mean_mask = cryoem.rotational_average(mean_mask,normalize=False,doexpand=True)

        obsw = mask_w * mean_mask
        map_sigma2 = (mean_sigma2 * obsw + nu * base_sigma2) / (obsw + nu)

        assert n.all(n.isfinite(map_sigma2))

        if model == 'white':
            return n.mean(map_sigma2)
        if model == 'coloured':
            return map_sigma2
        # Explicit raise (same type and message as the former `assert False`)
        # so an invalid model is rejected even when asserts are disabled.
        raise AssertionError('model must be one of white or coloured')

    def get_envelope_map(self,sigma2,rho,env_lb=None,env_ub=None,minFreq=None,bfactor=None,rotavg=True):
        """Blend the recorded correlation/power histories with a prior envelope.

        The result is
            (mean_corr * obsw + prior_envelope * rho) / (mean_power * obsw + rho)
        with obsw = mask_w * mean_mask / sigma2, where mask_w is the mask
        history weight sum times N_D_Train / num_batches.

        Args:
            sigma2: noise variance dividing the observation weight; an ndarray
                is reshaped to (N,N), anything else is used as given.
            rho: weight given to the prior envelope.
            env_lb, env_ub: optional bounds; if either is set the result is
                clipped in place with n.clip.
            minFreq: if set, entries selected by the mask returned from
                gencoords(N, 2, minFreq*2.0*pixel_size, True) are forced to 1.0.
            bfactor: if set, the prior envelope is
                ctf.envelope_function(freqs, bfactor); otherwise it is 1.0.
            rotavg: if true, rotationally average the mean correlation, power
                and mask before blending.

        Returns:
            The (N,N) envelope array.
        """
        N = self.cryodata.N
        N_D = float(self.cryodata.N_D_Train)
        num_batches = float(self.cryodata.num_batches)
        psize = self.params['pixel_size']

        mean_corr = self.correlation_history.get_mean().reshape((N,N))
        mean_power = self.power_history.get_mean().reshape((N,N))
        mean_mask = self.mask_history.get_mean().reshape((N,N))
        mask_w = self.mask_history.get_wsum() * (N_D / num_batches)

        if rotavg:
            mean_corr = cryoem.rotational_average(mean_corr,normalize=True,doexpand=True)
            mean_power = cryoem.rotational_average(mean_power,normalize=True,doexpand=True)
            mean_mask = cryoem.rotational_average(mean_mask,normalize=False,doexpand=True)

        if isinstance(sigma2,n.ndarray):
            sigma2 = sigma2.reshape((N,N))

        if bfactor is not None:
            coords = gencoords(N,2).reshape((N**2,2))
            freqs = n.sqrt(n.sum(coords**2,axis=1))/(psize*N)
            prior_envelope = ctf.envelope_function(freqs,bfactor).reshape((N,N))
        else:
            prior_envelope = 1.0

        obsw = (mask_w * mean_mask / sigma2)
        exp_env = (mean_corr * obsw + prior_envelope*rho) / (mean_power * obsw + rho)

        if minFreq is not None:
            # Only consider envelope parameters for frequencies above a threshold
            minRad = minFreq*2.0*psize

            _, _, minRadMask = gencoords(N, 2, minRad, True)

            exp_env[minRadMask.reshape((N,N))] = 1.0

        if env_lb is not None or env_ub is not None:
            n.clip(exp_env,env_lb,env_ub,out=exp_env)

        return exp_env

    def get_rmse(self):
        """Return the square root of the mean of the coloured sigma2 MLE map."""
        return n.sqrt(self.get_sigma2_mle().mean())

    def get_sigma2_mle(self,noise_model='coloured'):
        """Return the sigma2 estimate taken directly from the recorded histories.

        The per-element value is
            mean_mask * sigma2 + (1 - mean_mask) * cryodata.data_var
        where sigma2 and mean_mask are the means of the error and mask
        histories.

        Args:
            noise_model: 'coloured' returns the values reshaped to (N,N),
                'white' returns their mean.

        Raises:
            ValueError: if noise_model is neither 'white' nor 'coloured'
                (previously this silently returned None).
        """
        N = self.cryodata.N
        sigma2 = self.error_history.get_mean()
        mean_mask = self.mask_history.get_mean()
#         mse = mean_mask*sigma2 + (1-mean_mask)*self.cryodata.data['imgvar_freq']
        mse = mean_mask*sigma2 + (1-mean_mask)*self.cryodata.data_var
        if noise_model == 'coloured':
            return mse.reshape((N,N))
        if noise_model == 'white':
            return mse.mean()
        raise ValueError(
            'noise_model must be one of white or coloured, got %r' % (noise_model,))

    def get_envelope_mle(self,rotavg=False):
        """Return the envelope estimate mean_corr / mean_power as an (N,N) array.

        The ratio is only evaluated where the mean mask is positive; all other
        entries are left at 1.0.

        Args:
            rotavg: if true, the mean correlation, power and mask are reshaped
                to (N,N) and rotationally averaged before taking the ratio.
        """
        N = self.cryodata.N

        mean_corr = self.correlation_history.get_mean()
        mean_power = self.power_history.get_mean()
        mean_mask = self.mask_history.get_mean()

        if rotavg:
            mean_corr = cryoem.rotational_average(mean_corr.reshape((N,N)),doexpand=True)
            mean_power = cryoem.rotational_average(mean_power.reshape((N,N)),doexpand=True)
            mean_mask = cryoem.rotational_average(mean_mask.reshape((N,N)),doexpand=True)

        obs_mask = mean_mask > 0
        exp_env = n.ones_like(mean_corr)
        exp_env[obs_mask] = (mean_corr[obs_mask] / mean_power[obs_mask])

        return exp_env.reshape((N,N))

    def set_samplers(self,sampler_R,sampler_I,sampler_S):
        """Forward the three samplers to the kernel."""
        self.kernel.set_samplers(sampler_R,sampler_I,sampler_S)

    def set_dataset(self,cryodata):
        """Attach the dataset to this objective and its kernel, and reset the histories."""
        Objective.set_dataset(self,cryodata)
        self.kernel.set_dataset(cryodata)

        # Fresh, not-yet-sized running sums; eval() sizes them on first use.
        self.error_history = FiniteRunningSum(second_order=False)
        self.correlation_history = FiniteRunningSum(second_order=False)
        self.power_history = FiniteRunningSum(second_order=False)
        self.mask_history = FiniteRunningSum(second_order=False)

    def set_data(self,cparams,minibatch):
        """Forward the current parameters and minibatch to the base class and the kernel."""
        Objective.set_data(self,cparams,minibatch)
        self.kernel.set_data(cparams,minibatch)

    def _record_history(self,history,value,dtype=None):
        """Store value for the current minibatch in the given running sum.

        The running sum is (re)initialised first whenever its N_sum differs
        from cryodata.N_batches. If dtype is given, the array used for that
        initialisation is converted with n.require; the stored value itself is
        always passed through unchanged.
        """
        n_batches = self.cryodata.N_batches
        if history.N_sum != n_batches:
            template = value if dtype is None else n.require(value,dtype=dtype)
            history.setup(template,n_batches,allow_decay=False)
        history.set_value(self.minibatch['id'],value)

    def eval(self,M=None, compute_gradient=True, fM=None, **kwargs):
        """Evaluate the objective through the kernel and record statistics.

        M (multiplied by kernel.slice_premult when that is not None) is
        transformed with density.real_to_fspace and handed to the kernel.
        Unless the minibatch is a test batch or kwargs['intermediate'] is
        true, the kernel's 'sigma2_est', 'correlation' and 'power' outputs and
        the kernel's truncmask are stored in the running histories.

        Args:
            M: density passed to density.real_to_fspace. It is used directly,
                so the None default is not handled here.
            compute_gradient: forwarded to the kernel. When true and
                kernel.slice_premult is not None, element 1 of the result is
                replaced by slice_premult * density.fspace_to_real(ret[1]) and
                the result is rebuilt as a 3-tuple.
            fM: accepted but not used by this method.
            **kwargs: only 'intermediate' is consulted.

        Returns:
            The kernel's result tuple, whose last element has timing entries
            added under 'like_timing' ('record', 'premult', 'total').
        """
        tic_start = time.time()

        premult = self.kernel.slice_premult
        if premult is not None:
            pfM = density.real_to_fspace(premult * M)
        else:
            pfM = density.real_to_fspace(M)
        pmtime = time.time() - tic_start

        ret = self.kernel.eval(fM=pfM,M=None,compute_gradient=compute_gradient)

        if not self.minibatch['test_batch'] and not kwargs.get('intermediate',False):
            tic_record = time.time()
            outputs = ret[-1]
            tracked = (('sigma2_est',self.error_history),
                       ('correlation',self.correlation_history),
                       ('power',self.power_history))
            for key,history in tracked:
                value = outputs[key]
                assert n.all(n.isfinite(value))
                self._record_history(history,value)

            # The mask history is initialised from a float32 copy of the mask.
            self._record_history(self.mask_history,self.kernel.truncmask,dtype=n.float32)
            ret[-1]['like_timing']['record'] = time.time() - tic_record

        if compute_gradient and self.kernel.slice_premult is not None:
            tic_record = time.time()
            ret = (ret[0],self.kernel.slice_premult * density.fspace_to_real(ret[1]),ret[2])
            ret[-1]['like_timing']['premult'] = pmtime + time.time() - tic_record

        ret[-1]['like_timing']['total'] = time.time() - tic_start

        return ret