Ejemplo n.º 1
0
def plot_projs_with_slider(mrcs_files, log_scale=True, show_ac_image=False):
    """Show each image stack in an interactive single-slice viewer.

    For every path in ``mrcs_files`` the stack is read with
    ``mrc.readMRCimgs``, a 2-D beamstop mask is generated for its size, and
    a figure is opened showing one slice, a colorbar and an index slider.
    ``plt.show()`` is called once per file, inside the loop.

    Parameters
    ----------
    mrcs_files : iterable of str
        Paths of the image stacks to display.
    log_scale : bool, optional
        If true, the natural logarithm of the processed slice is shown.
    show_ac_image : bool, optional
        If true, show ``correlation.calc_full_ac(img, 0.95) * mask``
        instead of the masked slice.
    """
    import numpy as np  # local import: only needed for the finite-value filter

    for mrcs in mrcs_files:
        image_stack = mrc.readMRCimgs(mrcs, 0)
        size = image_stack.shape
        N = size[0]
        mask = gen_dense_beamstop_mask(N, 2, 0.003, psize=9)
        print('image size: {0}x{1}, number of images: {2}'.format(*size))

        # Single source of truth for how a slice is turned into the displayed
        # image. Stack and mask are bound as defaults so the helper keeps
        # referring to *this* file's data even after the loop moves on.
        def prepare(idx, stack=image_stack, beamstop=mask):
            img = stack[:, :, idx] * beamstop
            if show_ac_image:
                img = correlation.calc_full_ac(img, 0.95) * beamstop
            if log_scale:
                # Masked (zero) pixels become -inf here; imshow treats
                # non-finite values as invalid.
                img = log(img)
            return img

        # plot projections
        fig = plt.figure(figsize=(8, 8))
        gs = GridSpec(
            2,
            2,
            width_ratios=[1, 0.075],
            height_ratios=[1, 0.075],
        )
        # original
        ax = fig.add_subplot(gs[0, 0])
        im = ax.imshow(prepare(0), origin='lower')
        ticks = [0, int(N / 4.0), int(N / 2.0), int(N / 4.0 * 3), int(N - 1)]
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_title('Slice Viewer (log scale: {}) for {}'.format(
            log_scale, os.path.basename(mrcs)))
        ax_divider = make_axes_locatable(ax)
        cax = ax_divider.append_axes("right", size="7%", pad="2%")
        fig.colorbar(im, cax=cax)  # colorbar follows the limits of `im`

        # slider
        ax_slider = fig.add_subplot(gs[1, 0])
        idx_slider = Slider(ax_slider,
                            'index:',
                            0,
                            size[2] - 1,
                            valinit=0,
                            valfmt='%d')

        def update(val, im=im, fig=fig, prepare=prepare, slider=idx_slider):
            """Redraw the image for the slice currently selected on the slider."""
            img = prepare(int(slider.val))
            im.set_data(img)
            # Rescale on finite values only: log() of masked pixels yields
            # -inf/NaN, which are not usable colour limits. Setting the
            # limits on the image also refreshes the attached colorbar.
            finite = img[np.isfinite(img)]
            if finite.size:
                im.set_clim(vmin=finite.min(), vmax=finite.max())
            fig.canvas.draw_idle()

        idx_slider.on_changed(update)

        plt.show()
Ejemplo n.º 2
0
def plot_projs(mrcs_files, log_scale=True, plot_randomly=True):
    """Plot up to nine masked images from each stack on a 3x3 grid.

    For every path in ``mrcs_files`` the stack is read with
    ``mrc.readMRCimgs`` and a 2-D beamstop mask is generated for its size.
    One figure per file is created; ``plt.show()`` is called once after all
    files have been processed.

    Parameters
    ----------
    mrcs_files : iterable of str
        Paths of the image stacks to display.
    log_scale : bool, optional
        If true, show ``log(maximum(img, 1e-6)) * mask`` instead of
        ``img * mask``.
    plot_randomly : bool, optional
        If true, each panel shows an index drawn with ``randint``; otherwise
        panel ``i`` shows image ``i``. Panels without a corresponding image
        are left blank.
    """
    for mrcs in mrcs_files:
        image_stack = mrc.readMRCimgs(mrcs, 0)
        size = image_stack.shape
        N = size[0]
        num_images = size[2]
        mask = gen_dense_beamstop_mask(N, 2, 0.003, psize=9)
        print('image size: {0}x{1}, number of images: {2}'.format(*size))
        print('Select indices randomly:', plot_randomly)

        # Tick positions are the same for every panel; compute them once.
        ticks = [
            0,
            int(N / 4.0),
            int(N / 2.0),
            int(N * 3.0 / 4.0),
            int(N - 1)
        ]

        fig, axes = plt.subplots(3, 3, figsize=(12.9, 9.6))
        im = None
        for i, ax in enumerate(axes.flat):
            row, col = unravel_index(i, (3, 3))

            # Fewer images than panels (or an empty stack): blank the panel
            # instead of indexing past the end of the stack.
            if num_images == 0 or (not plot_randomly and i >= num_images):
                ax.axis('off')
                continue

            if plot_randomly:
                # NOTE(review): which `randint` is in scope is not visible
                # here. The clamp keeps the index valid if its upper bound is
                # inclusive (stdlib random) and is a no-op if it is exclusive
                # (numpy.random) -- confirm against the import.
                num = min(randint(0, num_images), num_images - 1)
            else:
                num = i
            print('index:', num)
            if log_scale:
                img = log(maximum(image_stack[:, :, num], 1e-6)) * mask
            else:
                img = image_stack[:, :, num] * mask
            im = ax.imshow(img, origin='lower')  # cmap='Greys'

            # Label only the outer edges of the grid: x ticks on the bottom
            # row, y ticks on the left column (same layout as gen_slices).
            if row == 2:
                ax.set_xticks(ticks)
            else:
                ax.set_xticks([])
            if col == 0:
                ax.set_yticks(ticks)
            else:
                ax.set_yticks([])

        fig.subplots_adjust(right=0.8)
        if im is not None:
            # The shared colorbar reflects the last image drawn only.
            cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
            fig.colorbar(im, cax=cbar_ax)
        fig.suptitle('{} before normalization'.format(mrcs))
        # fig.tight_layout()
    plt.show()
Ejemplo n.º 3
0
def gen_slices(model_files, fspace=False, log_scale=True):
    """Plot nine randomly oriented, Poisson-sampled slices of each model.

    For every path in ``model_files`` the volume is read with
    ``mrc.readMRC``. Unless ``fspace`` is true, the volume is rescaled to a
    fixed total mass, transformed with
    ``density.real_to_fspace_with_oversampling``, converted to squared
    magnitude, multiplied by a 3-D beamstop mask and written to the
    ``particle/`` directory. Nine slices are then interpolated from the
    volume at random orientations, passed through ``np.random.poisson`` and
    shown on a 3x3 grid. ``plt.show()`` is called once after all models.

    Parameters
    ----------
    model_files : iterable of str
        Paths of the model volumes.
    fspace : bool, optional
        If true, the volume read from disk is sliced as-is and nothing is
        written to disk.
    log_scale : bool, optional
        If true, the logarithm of each slice (floored at 1e-6) is shown.
    """
    for model in model_files:
        M = mrc.readMRC(model)
        N = M.shape[0]
        print('model size: {0}x{0}x{0}'.format(N))

        oversampling_factor = 3
        psize = 3.0 * oversampling_factor
        beamstop_freq = 0.003
        mask = geometry.gen_dense_beamstop_mask(N, 2, beamstop_freq, psize=psize)
        # mask = None

        if fspace:
            fM = M
        else:
            M_totalmass = 1000000
            M *= M_totalmass / M.sum()

            V = density.real_to_fspace_with_oversampling(M, oversampling_factor)
            fM = V.real ** 2 + V.imag ** 2

            mask_3D = geometry.gen_dense_beamstop_mask(N, 3, beamstop_freq, psize=psize)
            fM *= mask_3D

            # Make sure the output directory exists before writing into it.
            out_dir = 'particle'
            if not os.path.isdir(out_dir):
                os.makedirs(out_dir)
            mrc.writeMRC('particle/{}_fM_totalmass_{}_oversampling_{}.mrc'.format(
                os.path.splitext(os.path.basename(model))[0], str(int(M_totalmass)).zfill(5), oversampling_factor
                ), fM, psz=psize)

        # Points falling outside the volume evaluate to 0 rather than raising.
        slicing_func = RegularGridInterpolator([np.arange(N),]*3, fM, bounds_error=False, fill_value=0.0)
        coords = geometry.gencoords_base(N, 2)
        ticks = [0, int(N/4.0), int(N/2.0), int(N*3.0/4.0), int(N-1)]

        fig, axes = plt.subplots(3, 3, figsize=(12.9, 9.6))
        for i, ax in enumerate(axes.flat):
            row, col = np.unravel_index(i, (3, 3))

            # Randomly generate the viewing direction/shift
            pt = np.random.randn(3)
            pt /= np.linalg.norm(pt)
            psi = 2 * np.pi * np.random.rand()
            EA = geometry.genEA(pt)[0]
            EA[2] = psi
            R = geometry.rotmat3D_EA(*EA)[:, 0:2]
            # Rotate the 2-D sampling grid into 3-D and shift it by N/2 so
            # it indexes into the interpolator's [0, N) grid.
            rotated_coords = R.dot(coords.T).T + int(N/2)
            img = slicing_func(rotated_coords).reshape(N, N)
            img = np.require(np.random.poisson(img), dtype=np.float32)

            if log_scale:
                # Poisson counts are frequently exactly 0; a positive floor
                # avoids log(0) = -inf (and -inf * 0 = NaN once masked).
                # Same floor as plot_projs.
                img = np.log(np.maximum(img, 1e-6))
            if mask is not None:
                img *= mask

            im = ax.imshow(img, origin='lower')  # cmap='Greys'
            # Ticks only on the outer edges: bottom row and left column.
            if row == 2:
                ax.set_xticks(ticks)
            else:
                ax.set_xticks([])
            if col == 0:
                ax.set_yticks(ticks)
            else:
                ax.set_yticks([])
            fig.colorbar(im, ax=ax)

        # fig.subplots_adjust(right=0.8)
        # cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        # fig.colorbar(im, cax=cbar_ax)
        fig.suptitle('simulated experimental data of XFEL for {}'.format(model))
        # fig.tight_layout()
    plt.show()
Ejemplo n.º 4
0
from __future__ import print_function, division

import os
import sys
sys.path.append(os.path.dirname(sys.path[0]))
import argparse

from matplotlib import pyplot as plt
from numpy import unravel_index, log, maximum

from cryoio import mrc
from geometry import gen_dense_beamstop_mask


# Command line: one or more MRC volume paths.
parser = argparse.ArgumentParser()
parser.add_argument("mrcs_files", nargs='+', help="list of mrcs files.")
args = parser.parse_args()

# Wrap a bare value in a list so the loop below always iterates over paths.
mrcs_files = (args.mrcs_files
              if isinstance(args.mrcs_files, list)
              else [args.mrcs_files])

# Multiply every volume by a 3-D beamstop mask and write the result back to
# the path it was read from (the input file is overwritten).
for ph in mrcs_files:
    M = mrc.readMRC(ph)
    N = M.shape[0]
    mask_3D = gen_dense_beamstop_mask(N, 3, 0.01, psize=18)
    masked_volume = M * mask_3D
    mrc.writeMRC(ph, masked_volume, psz=18)
Ejemplo n.º 5
0
    def __init__(self, expbase, cmdparams=None):
        """Load parameters and data, then build the optimizer state.

        expbase is a path to the base folder of this experiment. Output is
        written below ``<expbase>/logs``. ``.params`` files in the folder
        above expbase are listed first, followed by those in expbase itself.

        cmdparams is an optional mapping of parameter names to values; each
        entry is assigned into the loaded parameters, replacing any value
        that came from the files.
        """
        BackgroundWorker.__init__(self)

        # Create a background thread which handles IO
        self.io_queue = Queue()
        self.io_thread = Thread(target=self.ioworker)
        self.io_thread.daemon = True
        self.io_thread.start()

        # General setup ----------------------------------------------------
        self.expbase = os.path.join(expbase, 'logs')
        self.outbase = None

        # Parameter setup --------------------------------------------------
        # search above expbase for params files
        _, _, filenames = next(os.walk(opj(expbase, '../')))
        self.paramfiles = [opj(opj(expbase,'../'), fname) \
                           for fname in filenames if fname.endswith('.params')]
        # search expbase for params files
        _, _, filenames = next(os.walk(opj(expbase)))
        self.paramfiles += [opj(expbase,fname)  \
                            for fname in filenames if fname.endswith('.params')]
        # local.params is already in the list above; it is appended a second
        # time so that it is also the last entry (presumably so its values
        # take precedence -- depends on how Params orders its files).
        if 'local.params' in filenames:
            self.paramfiles += [opj(expbase, 'local.params')]
        # load parameter files
        self.params = Params(self.paramfiles)
        self.cparams = None

        if cmdparams is not None:
            # Set parameter specified on the command line
            for k, v in cmdparams.items():
                self.params[k] = v

        # Dataset setup -------------------------------------------------------
        self.imgpath = self.params['inpath']
        psize = self.params['resolution']
        # A single path gives one stack; a list of paths gives a combined stack.
        if not isinstance(self.imgpath, list):
            imgstk = MRCImageStack(self.imgpath, psize)
        else:
            imgstk = CombinedImageStack(
                [MRCImageStack(cimgpath, psize) for cimgpath in self.imgpath])

        if self.params.get('float_images', True):
            imgstk.float_images()

        self.ctfpath = self.params['ctfpath']
        mscope_params = self.params['microscope_params']

        if not isinstance(self.ctfpath, list):
            ctfstk = CTFStack(self.ctfpath, mscope_params)
        else:
            ctfstk = CombinedCTFStack([
                CTFStack(cctfpath, mscope_params) for cctfpath in self.ctfpath
            ])

        self.cryodata = CryoDataset(imgstk, ctfstk)
        self.cryodata.compute_noise_statistics()
        if self.params.get('window_images', True):
            imgstk.window_images()
        minibatch_size = self.params['minisize']
        testset_size = self.params['test_imgs']
        partition = self.params.get('partition', 0)
        num_partitions = self.params.get('num_partitions', 1)
        seed = self.params['random_seed']
        # NOTE(review): eval() executes arbitrary code. This is only safe as
        # long as params files and command-line values are trusted input.
        if isinstance(partition, str):
            partition = eval(partition)
        if isinstance(num_partitions, str):
            num_partitions = eval(num_partitions)
        if isinstance(seed, str):
            seed = eval(seed)
        self.cryodata.divide_dataset(minibatch_size, testset_size, partition,
                                     num_partitions, seed)

        # self.cryodata.set_datasign(self.params.get('datasign','auto'))
        # if self.params.get('normalize_data',True):
        #     self.cryodata.normalize_dataset()

        self.voxel_size = self.cryodata.pixel_size

        # Iterations setup -------------------------------------------------
        self.iteration = 0
        self.tic_epoch = None
        self.num_data_evals = 0
        # self.cparams is None up to here and is read right below, so
        # eval_params() is expected to populate it.
        self.eval_params()

        outdir = self.cparams.get('outdir', None)
        if outdir is None:
            if self.cparams.get('num_partitions', 1) > 1:
                outdir = 'partition{0}'.format(self.cparams['partition'])
            else:
                outdir = ''
        self.outbase = opj(self.expbase, outdir)
        if not os.path.isdir(self.outbase):
            os.makedirs(self.outbase)

        # Output setup -----------------------------------------------------
        self.ostream = OutputStream(opj(self.outbase, 'stdout'))

        self.ostream(80 * "=")
        self.ostream("Experiment: " + expbase + \
                     "    Kernel: " + self.params['kernel'])
        self.ostream("Started on " + socket.gethostname() + \
                     "    At: " + time.strftime('%B %d %Y: %I:%M:%S %p'))
        # Recording git information is best-effort: any failure is reported
        # and the setup continues.
        try:
            print('gitutil:', gitutil.git_get_SHA1().decode('utf-8'))
            self.ostream("Git SHA1: " + gitutil.git_get_SHA1().decode('utf-8'))
            gitutil.git_info_dump(opj(self.outbase, 'gitinfo'))
        except Exception:
            print("Git info is not found.")
            self.ostream("Fail to dump git information")
        self.ostream(80 * "=")
        self.startdatetime = datetime.now()

        # for diagnostics and parameters
        self.diagout = Output(opj(self.outbase, 'diag'), runningout=False)
        # for stats (per image etc)
        self.statout = Output(opj(self.outbase, 'stat'), runningout=True)
        # for likelihoods of individual images
        self.likeout = Output(opj(self.outbase, 'like'), runningout=False)

        self.img_likes = np.empty(self.cryodata.N_D)
        self.img_likes[:] = np.inf

        # optimization state vars ------------------------------------------
        init_model = self.cparams.get('init_model', None)
        if init_model is not None:
            filename = init_model
            if filename.upper().endswith('.MRC'):
                M = readMRC(filename)
            else:
                # Pickle data is binary: the file must be opened in 'rb'
                # mode, otherwise pickle.load fails on Python 3.
                # NOTE(review): unpickling executes code from the file; only
                # load init_model files from trusted sources.
                with open(filename, 'rb') as fp:
                    M = pickle.load(fp)
                    # A list is indexed as M[-1]['M']: the 'M' entry of its
                    # last element is used as the model.
                    if isinstance(M, list):
                        M = M[-1]['M']
            if M.shape != 3 * (self.cryodata.N, ):
                M = cryoem.resize_ndarray(M,
                                          3 * (self.cryodata.N, ),
                                          axes=(0, 1, 2))
        else:
            # The seed is offset by the partition index so partitions differ.
            init_seed = self.cparams.get(
                'init_random_seed', np.random.randint(10)) + self.cparams.get(
                    'partition', 0)
            print(
                "Randomly generating initial density (init_random_seed = {0})..."
                .format(init_seed))
            sys.stdout.flush()
            tic = time.time()
            M = cryoem.generate_phantom_density(self.cryodata.N, 0.95*self.cryodata.N/2.0, \
                                                2*self.cryodata.N/128.0, 30, seed=init_seed)
            print("done in {0}s".format(time.time() - tic))

        # tic = time.time()
        # print("Windowing and aligning initial density..."); sys.stdout.flush()
        # window the initial density
        # wfunc = self.cparams.get('init_window','circle')
        # cryoem.window(M,wfunc)

        # Center and orient the initial density
        # cryoem.align_density(M)
        # print("done in {0:.2f}s".format(time.time() - tic))

        M_totalmass = self.params.get('M_totalmass', None)
        if M_totalmass is not None:
            M *= M_totalmass / M.sum()
        N = M.shape[0]

        # oversampling
        oversampling_factor = self.params['oversampling_factor']
        V = density.real_to_fspace_with_oversampling(M, oversampling_factor)
        # From here on M holds the squared magnitude of V, not the density.
        M = V.real**2 + V.imag**2
        lowpass_freq = self.cparams.get('lowpass_freq', None)
        if lowpass_freq is not None:
            lowpass_filter = 1.0 - geometry.gen_dense_beamstop_mask(
                N, 3, lowpass_freq, psize=self.cparams['pixel_size'])
            # Where the filter is 0, M is replaced by the constant 1.0.
            M = lowpass_filter * M + 1.0 - lowpass_filter

        # NOTE(review): beamstop_freq is None when the parameter is absent;
        # confirm gen_dense_beamstop_mask accepts None.
        beamstop_freq = self.cparams.get('beamstop_freq', None)
        mask_3D = geometry.gen_dense_beamstop_mask(
            N, 3, beamstop_freq, psize=self.cparams['pixel_size'])

        # apply the symmetry operator
        init_sym = get_symmetryop(
            self.cparams.get('init_symmetry',
                             self.cparams.get('symmetry', None)))
        if init_sym is not None:
            tic = time.time()
            print("Applying symmetry operator...")
            sys.stdout.flush()
            M = init_sym.apply(M)
            print("done in {0:.2f}s".format(time.time() - tic))

        # tic = time.time()
        # print("Scaling initial model..."); sys.stdout.flush()
        modelscale = self.cparams.get('modelscale', 'auto')
        # mleDC, _, mleDC_est_std = self.cryodata.get_dc_estimate()
        if modelscale == 'auto':
            #     # Err on the side of a weaker prior by using a larger value for modelscale
            #     modelscale = (np.abs(mleDC) + 2*mleDC_est_std)/self.cryodata.N
            #     print("estimated modelscale = {0:.3g}...".format(modelscale)); sys.stdout.flush()
            modelscale = 1.0
            self.params['modelscale'] = modelscale
            self.cparams['modelscale'] = modelscale
        # M *= modelscale/M.sum()
        # print("done in {0:.2f}s".format(time.time() - tic))
        # if mleDC_est_std/np.abs(mleDC) > 0.05:
        #     print("  WARNING: the DC component estimate has a high relative variance, it may be inaccurate!")
        # if ((modelscale*self.cryodata.N - np.abs(mleDC)) / mleDC_est_std) > 3:
        #     print("  WARNING: the selected modelscale value is more than 3 std devs different than the estimated one.  Be sure this is correct.")

        # save initial model
        tic = time.time()
        print("Saving initial model...")
        sys.stdout.flush()
        init_model_fname = os.path.join(self.expbase, 'init_model.mrc')
        # The beamstop mask is applied to the saved file only; self.M below
        # keeps the unmasked values.
        writeMRC(init_model_fname, M * mask_3D, psz=self.cparams['pixel_size'])
        print("done in {0:.2f}s".format(time.time() - tic))

        self.M = np.require(M, dtype=density.real_t)
        # self.fM = density.real_to_fspace(M)
        self.fM = M
        self.dM = density.zeros_like(self.M)

        # NOTE(review): eval() of a parameter string -- trusted input only.
        self.step = eval(self.cparams['optim_algo'])
        self.step.setup(self.cparams, self.diagout, self.statout, self.ostream)

        # Objective function setup --------------------------------------------
        param_type = self.cparams.get('parameterization', 'real')
        cplx_param = param_type in [
            'complex', 'complex_coeff', 'complex_herm_coeff'
        ]
        self.like_func = eval_objective(self.cparams['likelihood'])
        self.prior_func = eval_objective(self.cparams['prior'])

        # An optional penalty term is summed with the prior.
        if self.cparams.get('penalty', None) is not None:
            self.penalty_func = eval_objective(self.cparams['penalty'])
            prior_func = SumObjectives(self.prior_func.fspace, \
                                       [self.penalty_func,self.prior_func], None)
        else:
            prior_func = self.prior_func

        self.obj = SumObjectives(cplx_param, [self.like_func, prior_func],
                                 [None, None])
        self.obj.setup(self.cparams, self.diagout, self.statout, self.ostream)
        self.obj.set_dataset(self.cryodata)
        self.obj_wrapper = ObjectiveWrapper(param_type)

        self.last_save = time.time()

        self.logpost_history = FiniteRunningSum()
        self.like_history = FiniteRunningSum()

        # Importance Samplers -------------------------------------------------
        self.is_sym = get_symmetryop(
            self.cparams.get('is_symmetry', self.cparams.get('symmetry',
                                                             None)))
        self.sampler_R = FixedFisherImportanceSampler('_R', self.is_sym)
        self.sampler_I = FixedFisherImportanceSampler('_I')
        # self.sampler_S = FixedGaussianImportanceSampler('_S')
        self.sampler_S = None
        self.like_func.set_samplers(sampler_R=self.sampler_R,
                                    sampler_I=self.sampler_I,
                                    sampler_S=self.sampler_S)