Esempio n. 1
0
    def set_data(self, model, cparams):
        """Store ``model`` and rebuild the truncated frequency-domain state.

        Resets the sampling-related attributes, derives the truncation and
        beamstop radii from ``cparams``, regenerates the coordinate grids via
        ``geometry.gencoords_centermask``, sets the envelope, calls
        ``set_quad`` with the truncation radius and fixes the inlier variance
        to 1.0.

        Args:
            model: numpy array holding the model; any other type trips the
                assertion.
            cparams: parameter dict. ``'max_frequency'`` is required;
                ``'rad_cutoff'`` (default 1.0), ``'beamstop_freq'`` (default
                0.003) and ``'learn_like_envelope_bfactor'`` (default None)
                are optional. ``cparams['iteration']`` is overwritten with 0.
        """
        assert isinstance(model, np.ndarray), "Unexpected input numpy array."
        self.model = model

        # Reset the sampling-related attributes.
        self.slices_sampled = self.rotd_sampled = self.rad = None

        cparams['iteration'] = 0
        max_freq = cparams['max_frequency']
        rad_cutoff = cparams.get('rad_cutoff', 1.0)
        fake_oversampling_factor = 1.0

        # Truncation radius: the frequency limit scaled by 2 * psize, capped
        # at rad_cutoff.
        freq_rad = max_freq * 2 * self.psize
        rad = min(rad_cutoff, freq_rad)
        beamstop_freq = cparams.get('beamstop_freq', 0.003)
        self.beamstop_rad = beamstop_freq * 2 * self.psize

        grids = geometry.gencoords_centermask(
            self.N, 2, rad, self.beamstop_rad, True)
        self.xy, self.trunc_xy, self.truncmask = grids
        freq_scale = self.N * self.psize
        self.trunc_freq = np.require(self.trunc_xy / freq_scale,
                                     dtype=np.float32)
        self.N_T = self.trunc_xy.shape[0]

        # Envelope: flat unless the dataset uses a CTF, in which case it is
        # evaluated at the radial frequency of every truncated coordinate.
        if not self.cryodata.use_ctf:
            self.envelope = np.ones(self.N_T, dtype=np.float32)
        else:
            squared_norm = np.sum(self.trunc_xy**2, axis=1)
            radius_freqs = np.sqrt(squared_norm) / (self.psize * self.N)
            bfactor = cparams.get('learn_like_envelope_bfactor', None)
            self.envelope = ctf.envelope_function(radius_freqs, bfactor)

        self.set_quad(rad)

        # Inlier model: variances are fixed to one here rather than taken
        # from cparams['sigma'] or the dataset's noise variance.
        self.inlier_sigma2 = 1.0
        base_sigma2 = 1.0
        self.inlier_sigma2_trunc = self.inlier_sigma2
Esempio n. 2
0
    def get_envelope_map(self,sigma2,rho,env_lb=None,env_ub=None,minFreq=None,bfactor=None,rotavg=True):
        """Estimate the N x N envelope map from the accumulated histories.

        The estimate is
        ``(mean_corr * obsw + prior * rho) / (mean_power * obsw + rho)`` with
        ``obsw = mask_w * mean_mask / sigma2``, where the means come from the
        correlation, power and mask histories.

        Args:
            sigma2: noise variance; a scalar, or an array that is reshaped to
                (N, N).
            rho: weight of the prior envelope term.
            env_lb, env_ub: optional lower/upper clipping bounds; clipping is
                applied when at least one of them is given.
            minFreq: if given, entries selected by the mask generated for the
                radius ``minFreq * 2 * pixel_size`` are set to 1.0.
            bfactor: if given, the prior envelope is
                ``ctf.envelope_function`` evaluated on the full grid;
                otherwise the prior is the constant 1.0.
            rotavg: rotationally average the three history means first.

        Returns:
            The (N, N) envelope array.
        """
        N = self.cryodata.N
        grid_shape = (N, N)
        N_D = float(self.cryodata.N_D_Train)
        num_batches = float(self.cryodata.num_batches)
        psize = self.params['pixel_size']
        beamstop_freq = self.params.get('beamstop_freq', None)

        mean_corr = self.correlation_history.get_mean().reshape(grid_shape)
        mean_power = self.power_history.get_mean().reshape(grid_shape)
        mean_mask = self.mask_history.get_mean().reshape(grid_shape)
        # Rescale the accumulated mask weight by images-per-batch.
        images_per_batch = N_D / num_batches
        mask_w = self.mask_history.get_wsum() * images_per_batch

        if rotavg:
            rot_avg = cryoem.rotational_average
            mean_corr = rot_avg(mean_corr, normalize=True, doexpand=True)
            mean_power = rot_avg(mean_power, normalize=True, doexpand=True)
            mean_mask = rot_avg(mean_mask, normalize=False, doexpand=True)

        if isinstance(sigma2, np.ndarray):
            sigma2 = sigma2.reshape(grid_shape)

        prior_envelope = 1.0
        if bfactor is not None:
            coords = gencoords(N, 2).reshape((N**2, 2))
            freqs = np.sqrt(np.sum(coords**2, axis=1)) / (psize * N)
            prior_envelope = ctf.envelope_function(freqs, bfactor).reshape(grid_shape)

        obsw = (mask_w * mean_mask / sigma2)
        numerator = mean_corr * obsw + prior_envelope * rho
        denominator = mean_power * obsw + rho
        exp_env = numerator / denominator

        if minFreq is not None:
            # Entries covered by the generated mask are pinned to 1.0 so that
            # only the remaining frequencies carry envelope parameters.
            minRad = minFreq * 2.0 * psize
            if beamstop_freq is not None:
                maskRad = beamstop_freq * 2.0 * psize
                _, _, minRadMask = gencoords_centermask(N, 2, minRad, maskRad, True)
            else:
                _, _, minRadMask = gencoords(N, 2, minRad, True)
            exp_env[minRadMask.reshape(grid_shape)] = 1.0

        if not (env_lb is None and env_ub is None):
            np.clip(exp_env, env_lb, env_ub, out=exp_env)

        return exp_env
Esempio n. 3
0
def merge_slices(slices, Rs, N, rad, beamstop_rad=None, res=None):
    """Average 2-D slice samples into an N x N x N volume, voxel by voxel.

    Every in-plane coordinate is rotated by the slice's matrix, shifted by
    ``int(N / 2)`` and rounded to the nearest voxel. Samples that land inside
    the volume update that voxel's running mean; the rest are ignored.

    Args:
        slices: array indexed as ``slices[i, j]``; its second dimension must
            equal the number of generated coordinates.
        Rs: iterable of rotation matrices, one per slice.
        N: edge length of the output volume.
        rad: radius passed to the coordinate generator.
        beamstop_rad: if given, coordinates come from
            ``gencoords_centermask`` instead of ``gencoords``.
        res: optional output buffer of shape (N, N, N) with the same dtype
            as ``slices``; it is zeroed before use.

    Returns:
        The volume holding the per-voxel mean of all samples mapped to it.
    """
    if beamstop_rad is not None:
        coords = gencoords_centermask(N, 2, rad, beamstop_rad)
    else:
        coords = gencoords(N, 2, rad)
    assert slices.shape[1] == coords.shape[0]

    vol_shape = (N, N, N)
    if res is not None:
        assert res.shape == vol_shape
        assert res.dtype == slices.dtype
        res[:] = 0.0
    else:
        # NOTE(review): the default buffer is float32 regardless of
        # slices.dtype, unlike the caller-supplied case checked above.
        res = np.zeros(vol_shape, dtype=np.float32)
    # Number of samples accumulated so far in each voxel.
    hit_count = np.zeros(vol_shape)

    origin = int(N / 2)
    for slice_no, R in enumerate(Rs):
        samples = slices[slice_no, :]

        for sample_no, xy in enumerate(coords):
            sample = samples[sample_no]

            rotated = R.dot(xy.T).reshape(1, -1)[0] + origin
            voxel = np.int_(np.round(rotated))

            vx, vy, vz = voxel[0], voxel[1], voxel[2]
            if not (0 <= vx < N and 0 <= vy < N and 0 <= vz < N):
                continue

            target = tuple(voxel)
            previous = res[target]
            hit_count[target] += 1
            # Incremental mean: mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k
            res[target] = previous + (sample - previous) / hit_count[target]

    return res
Esempio n. 4
0
def getslices_interp(V, Rs, rad, beamstop_rad=None, res=None):
    """Sample ``V`` on rotated coordinate sets by grid interpolation.

    For each matrix in ``Rs`` the generated 2-D coordinates are rotated,
    shifted by ``int(N / 2)`` and evaluated with a
    ``RegularGridInterpolator`` built on the integer grid of ``V``. Points
    outside the grid evaluate to 0.0.

    Args:
        V: array with at least two dimensions; ``V.shape[0]`` is used as the
            grid size along every axis.
        Rs: sequence of rotation matrices.
        rad: radius passed to the coordinate generator.
        beamstop_rad: if given, coordinates come from
            ``gencoords_centermask`` instead of ``gencoords``.
        res: optional output buffer with one row per rotation and the same
            dtype as ``V``; it is zeroed before use.

    Returns:
        Array with one row of interpolated samples per rotation.
    """
    ndim = V.ndim
    assert ndim > 1
    num_slices = len(Rs)
    N = V.shape[0]
    offset = int(N / 2)

    if beamstop_rad is not None:
        coords = gencoords_centermask(N, 2, rad, beamstop_rad)
    else:
        coords = gencoords(N, 2, rad)
    N_T = coords.shape[0]

    # Same integer tick positions along every axis of V.
    ticks = np.arange(N)
    interpolate = RegularGridInterpolator((ticks, ) * ndim,
                                          V,
                                          bounds_error=False,
                                          fill_value=0.0)

    if res is not None:
        assert res.shape[0] == Rs.shape[0]
        assert res.dtype == V.dtype
        res[:] = 0
    else:
        res = np.zeros((num_slices, N_T), dtype=V.dtype)

    for row, R in enumerate(Rs):
        sample_points = R.dot(coords.T).T + offset
        res[row] = interpolate(sample_points)

    return res
Esempio n. 5
0
def calc_angular_correlation(
    trunc_slices,
    N,
    rad,
    beamstop_rad=None,
    pixel_size=1.0,
    interpolation='nearest',
    sort_theta=True,
    clip=True,
    outside=False,
):
    """Compute the angular correlation of truncated slices, ring by ring.

    Steps:
      1. Regenerate the coordinates matching the truncated samples.
      2. Sort the samples by polar radius (and optionally by angle).
      3. Group samples by radius and correlate each group along the last
         axis, either by explicit circular shifts or through an FFT.
      4. Optionally clip each group to mean +/- 3 standard deviations.
      5. Undo the sort so the result is in the input order.

    Args:
        trunc_slices: one truncated slice (``N_T``) or a stack
            (``N_R x N_T``); may be real or complex.
        N: grid size passed to the coordinate generators.
        rad: truncation radius passed to the coordinate generators.
        beamstop_rad: if given (and ``outside`` is False), coordinates come
            from ``geometry.gencoords_centermask``.
        pixel_size: currently unused; kept for interface compatibility.
        interpolation: ``'none'`` keeps exact radii, ``'nearest'`` rounds
            them; ``'linear'`` is not implemented.
        sort_theta: sort by angle within equal radius as well.
        clip: clip every radius group to mean +/- 3 standard deviations.
        outside: use ``gencoords_outside`` (coordinates outside ``rad``).

    Returns:
        Array shaped like ``trunc_slices`` holding the angular correlation,
        in the original sample order.

    Raises:
        NotImplementedError: for ``'linear'`` interpolation.
        ValueError: for an unrecognised interpolation name.
    """
    # Groups with fewer samples than this are correlated by explicit shifts;
    # larger groups go through the FFT.
    minimum_sample_points = 2000

    # 1. Coordinates corresponding to the truncated samples.
    iscomplex = np.iscomplexobj(trunc_slices)
    if outside:
        trunc_xy = gencoords_outside(N, 2, rad)
    elif beamstop_rad is None:
        trunc_xy = geometry.gencoords(N, 2, rad)
    else:
        trunc_xy = geometry.gencoords_centermask(N, 2, rad, beamstop_rad)
    if trunc_slices.ndim < 2:
        n_samples = trunc_slices.shape[0]
    else:
        n_samples = trunc_slices.shape[1]
    assert trunc_xy.shape[0] == n_samples, \
        "wrong length of trunc slice or wrong radius"

    # 2. Sort by radius (primary key) and optionally angle (secondary key).
    pol_trunc_xy = cart2pol(trunc_xy)
    if sort_theta:
        sorted_idx = np.lexsort((pol_trunc_xy[:, 1], pol_trunc_xy[:, 0]))
    else:
        sorted_idx = np.argsort(pol_trunc_xy[:, 0])
    axis = trunc_slices.ndim - 1
    sorted_rho = np.take(pol_trunc_xy[:, 0], sorted_idx)
    sorted_slice = np.take(trunc_slices, sorted_idx, axis=axis)

    # 3. Group by radius and correlate within each group.
    if 'none' in interpolation:
        pass
    elif 'nearest' in interpolation:
        sorted_rho = np.round(sorted_rho)
    elif 'linear' in interpolation:
        raise NotImplementedError()
    else:
        raise ValueError('unsupported method for interpolation')

    # sorted_rho is ascending, so each unique radius is one contiguous run
    # starting at unique_idx[i] and spanning unique_counts[i] samples.
    _, unique_idx, unique_counts = np.unique(sorted_rho,
                                             return_index=True,
                                             return_counts=True)
    # Index template; it is always converted to a tuple before use because
    # NumPy no longer treats a list of slices as a multidimensional index.
    indices = [slice(None)] * trunc_slices.ndim
    angular_correlation = np.zeros_like(trunc_slices)
    for start, count in zip(unique_idx, unique_counts):
        indices[axis] = slice(start, start + count)
        same_rho = np.copy(sorted_slice[tuple(indices)])
        if count < minimum_sample_points:
            for shift in range(count):
                curr_delta_phi = same_rho * np.roll(same_rho, shift, axis=axis)
                indices[axis] = start + shift
                angular_correlation[tuple(indices)] = np.mean(curr_delta_phi,
                                                              axis=axis)
        else:
            # Basic slicing returns a view, so assigning to .real / .imag
            # writes through to angular_correlation.
            ring_view = angular_correlation[tuple(indices)]
            fpcimg_real = density.real_to_fspace(
                same_rho.real, axes=(axis, ))  # polar image in Fourier space
            ring_view.real = density.fspace_to_real(
                fpcimg_real * fpcimg_real.conjugate(), axes=(axis, )).real
            if iscomplex:  # FIXME: real and imaginary parts handled separately
                fpcimg_fourier = density.real_to_fspace(
                    same_rho.imag,
                    axes=(axis, ))  # polar image in Fourier space
                ring_view.imag = density.fspace_to_real(
                    fpcimg_fourier * fpcimg_fourier.conjugate(),
                    axes=(axis, )).real

    # Replace non-finite values by zero, real and imaginary parts separately.
    if np.any(np.isinf(angular_correlation)):
        warnings.warn(
            "Some values in angular correlation occur inf. These values have been set to zeros."
        )
        angular_correlation.real[np.isinf(angular_correlation.real)] = 0
        if iscomplex:
            angular_correlation.imag[np.isinf(angular_correlation.imag)] = 0
    if np.any(np.isnan(angular_correlation)):
        warnings.warn(
            "Some values in angular correlation occur nan. These values have been set to zeros."
        )
        angular_correlation.real[np.isnan(angular_correlation.real)] = 0
        if iscomplex:
            angular_correlation.imag[np.isnan(angular_correlation.imag)] = 0

    # 4. Pull outliers back to mean +/- factor * std within each group.
    if clip:
        factor = 3.0
        for start, count in zip(unique_idx, unique_counts):
            indices[axis] = slice(start, start + count)
            ring = tuple(indices)
            group = angular_correlation[ring]
            # keepdims lets the per-slice statistics broadcast along `axis`.
            mean = group.mean(axis=axis, keepdims=True)
            std = group.std(axis=axis, keepdims=True)
            angular_correlation[ring] = np.clip(group,
                                                mean - factor * std,
                                                mean + factor * std)

    # 5. Restore the original sample order.
    corr_trunc_slices = np.take(angular_correlation,
                                np.argsort(sorted_idx),
                                axis=axis)
    return corr_trunc_slices
Esempio n. 6
0
    def set_data(self,cparams,minibatch):
        """Configure the object for a new minibatch and parameter set.

        Stores ``cparams`` and ``minibatch``, derives the truncation radius,
        regenerates the truncated coordinate grids, refreshes the quadrature
        schemes and sets up the inlier model and envelope.

        Args:
            cparams: parameter dict. Required keys: ``'max_frequency'``,
                ``'pixel_size'``, ``'sigma'``; ``'iteration'`` is also read
                when the radius or slicing mode changes. Optional keys:
                ``'likelihood_factored_slicing'`` (default True),
                ``'rad_cutoff'`` (default 1.0), ``'beamstop_freq'`` (default
                None), ``'use_angular_correlation'`` (default False),
                ``'exp_envelope'`` (default None) and
                ``'learn_like_envelope_bfactor'`` (default 500.0).
            minibatch: dict; ``minibatch['N_M']`` sizes the per-image arrays.
        """
        self.params = cparams
        self.minibatch = minibatch

        factoredRI = cparams.get('likelihood_factored_slicing',True)
        max_freq = cparams['max_frequency']
        psize = cparams['pixel_size']
        rad_cutoff = cparams.get('rad_cutoff', 1.0)
        rad = min(rad_cutoff, max_freq * 2.0 * psize)
        beamstop_freq = cparams.get('beamstop_freq', None)

        # The beamstop radius is only computed when a frequency is
        # configured; multiplying None would raise before the unset case
        # could be handled.
        if beamstop_freq is None:
            beamstop_rad = None
            print('Beamstop freq is unset.')
            self.xy, self.trunc_xy, self.truncmask = gencoords(self.N, 2, rad, True)
        else:
            beamstop_rad = beamstop_freq * 2.0 * psize
            self.xy, self.trunc_xy, self.truncmask = gencoords_centermask(self.N, 2, rad, beamstop_rad, True)
        self.trunc_freq = np.require(self.trunc_xy / (self.N*psize), dtype=np.float32)
        self.N_T = self.trunc_xy.shape[0]

        self.use_angular_correlation = cparams.get('use_angular_correlation', False)
        if self.ostream is not None:
            self.ostream("\nUsing angular correlation: {0}".format(self.use_angular_correlation))

        # Log and record the new settings only when the radius or the
        # slicing mode differs from what is currently stored.
        interp_change = self.rad != rad or self.factoredRI != factoredRI
        if interp_change:
            print("Iteration {0}: freq = {3}, rad = {1}, N_T = {2}".format(cparams['iteration'], rad, self.N_T, max_freq))
            self.rad = rad
            self.beamstop_rad = beamstop_rad
            self.factoredRI = factoredRI

        # Setup the quadrature schemes
        if not factoredRI:
            self.set_proj_quad(rad)
        else:
            self.set_slice_quad(rad)
            self.set_inplane_quad(rad)

        # Check shift quadrature
        if self.sampler_S is not None:
            self.set_shift_quad(rad)

        # Setup inlier model: a single variance from cparams['sigma'] and a
        # zero normalisation constant.
        self.inlier_sigma2 = cparams['sigma']**2
        self.inlier_sigma2_trunc = self.inlier_sigma2
        self.inlier_const = 0.0

        # The likelihood of image content outside of rad is not evaluated;
        # the corresponding per-image terms are all zero.
        n_images = self.minibatch['N_M']
        self.imgpower = np.zeros((n_images,),dtype=density.real_t)
        self.imgpower_trunc = np.zeros((n_images,), dtype=density.real_t)
        self.inlier_like_trunc = np.zeros((n_images,), dtype=density.real_t)

        # Setup the envelope function: an explicit 'exp_envelope' wins;
        # otherwise it is built from the B-factor, and stays None when the
        # B-factor is explicitly None.
        envelope = self.params.get('exp_envelope',None)
        if envelope is not None:
            envelope = envelope.reshape((-1,))
            envelope = envelope[self.truncmask != 0]
            envelope = np.require(envelope,dtype=np.float32)
        else:
            bfactor = self.params.get('learn_like_envelope_bfactor',500.0)
            if bfactor is not None:
                freqs = np.sqrt(np.sum(self.trunc_xy**2,axis=1))/(psize*self.N)
                envelope = ctf.envelope_function(freqs,bfactor)
        self.envelope = envelope