Пример #1
0
def main(args):
    '''
    Compute the Fourier shell correlation (FSC) between two volumes.

    Reads args.vol1 and args.vol2 (optionally multiplied by the volume in
    args.mask), correlates their centered Fourier transforms shell by shell,
    and writes the (frequency, FSC) table to args.o, or logs it if args.o is
    not set. Also logs the resolution (1/frequency * args.Apix) of the first
    shell whose FSC falls below 0.5 and below 0.143, and optionally plots
    the curve.
    '''
    vol1, _, _ = mrc.parse_mrc(args.vol1)
    vol2, _, _ = mrc.parse_mrc(args.vol2)

    if args.mask:
        mask = mrc.parse_mrc(args.mask)[0]
        vol1 *= mask
        vol2 *= mask

    D = vol1.shape[0]
    x = np.arange(-D//2, D//2)
    x2, x1, x0 = np.meshgrid(x, x, x, indexing='ij')
    coords = np.stack((x0, x1, x2), -1)
    r = (coords**2).sum(-1)**.5

    # the zero-frequency voxel must sit at index D//2 (holds for even D)
    assert r[D//2, D//2, D//2] == 0.0

    vol1 = fft.fftn_center(vol1)
    vol2 = fft.fftn_center(vol2)

    prev_mask = np.zeros((D, D, D), dtype=bool)
    fsc = [1.0]
    for i in range(1, D//2):
        mask = r < i
        # voxels with i-1 <= radius < i, i.e. this shell only
        shell = np.where(mask & np.logical_not(prev_mask))
        v1 = vol1[shell]
        v2 = vol2[shell]
        # normalized cross-correlation of the two shells
        p = np.vdot(v1, v2) / (np.vdot(v1, v1) * np.vdot(v2, v2))**.5
        fsc.append(p.real)
        prev_mask = mask
    fsc = np.asarray(fsc)
    x = np.arange(D//2) / D

    res = np.stack((x, fsc), 1)
    if args.o:
        np.savetxt(args.o, res)
    else:
        log(res)

    # np.where returns a tuple, which is always truthy, so test the index
    # array itself and report only the first crossing.
    w = np.where(fsc < 0.5)[0]
    if len(w) > 0:
        log('0.5: {}'.format(1/x[w[0]]*args.Apix))

    w = np.where(fsc < 0.143)[0]
    if len(w) > 0:
        log('0.143: {}'.format(1/x[w[0]]*args.Apix))

    if args.plot:
        plt.plot(x, fsc)
        plt.ylim((0, 1))
        plt.show()
Пример #2
0
def main(args):
    '''
    Plot up to the first 9 images of a particle stack.

    Prints the stack size and image dimensions, then saves the figure to
    args.o if given, otherwise shows it interactively.
    '''
    stack, _, _ = mrc.parse_mrc(args.input, lazy=True)
    print('{} {}x{} images'.format(len(stack), *stack[0].get().shape))
    # stacks with fewer than 9 images used to raise IndexError here
    # NOTE(review): assumes analysis.plot_projections accepts < 9 images — confirm
    n_show = min(9, len(stack))
    stack = [stack[i].get() for i in range(n_show)]
    analysis.plot_projections(stack)
    if args.o:
        plt.savefig(args.o)
    else:
        plt.show()
Пример #3
0
def main(args):
    '''
    Optionally invert and/or flip a volume, then rewrite it with a new pixel size.

    args.input / args.o: input and output paths, both required to end in .mrc
    args.apix: pixel size written for all three axes of the output
    args.invert: negate every density value
    args.flip: reverse the volume along its first array axis
    '''
    assert args.input.endswith('.mrc'), "Input volume must be .mrc file"
    assert args.o.endswith('.mrc'), "Output volume must be .mrc file"
    vol, _, _ = mrc.parse_mrc(args.input)
    apix = args.apix

    if args.invert:
        vol *= -1
    if args.flip:
        # reversed slice along axis 0
        vol = vol[::-1]

    mrc.write(args.o, vol, ax=apix, ay=apix, az=apix)
    log(f'Wrote {args.o}')
Пример #4
0
def load_particles(mrcs_txt_star, lazy=False, datadir=None):
    '''
    Load particle stack from either a .mrcs file, a .star file, a .txt file containing paths to .mrcs files, or a cryosparc particles.cs file

    lazy (bool): Return a list of LazyImages if True, or a numpy array if False
    datadir (str or None): Base directory overwrite for .star or .cs file parsing

    Raises RuntimeError (chained to the original error) if a .star file
    cannot be loaded with an explicitly given datadir.
    '''
    if mrcs_txt_star.endswith('.txt'):
        particles = mrc.parse_mrc_list(mrcs_txt_star, lazy=lazy)
    elif mrcs_txt_star.endswith('.star'):
        # not exactly sure what the default behavior should be for the data paths if parsing a starfile
        try:
            particles = starfile.Starfile.load(mrcs_txt_star).get_particles(datadir=datadir, lazy=lazy)
        except Exception as e:
            if datadir is None:
                # retry assuming the .mrcs files are in the same directory as the starfile
                datadir = os.path.dirname(mrcs_txt_star)
                particles = starfile.Starfile.load(mrcs_txt_star).get_particles(datadir=datadir, lazy=lazy)
            else:
                # chain so the original traceback is preserved
                raise RuntimeError(e) from e
    elif mrcs_txt_star.endswith('.cs'):
        particles = starfile.csparc_get_particles(mrcs_txt_star, datadir, lazy)
    else:
        particles, _, _ = mrc.parse_mrc(mrcs_txt_star, lazy=lazy)
    return particles
Пример #5
0
def main(args):
    '''
    Add noise and CTF corruption to a stack of particle images.

    Steps as implemented below:
      1. load all particles, or the first args.Nimg lazily
      2. add Gaussian noise with stdev s1 (before the CTF)
      3. compute and apply a CTF
      4. add Gaussian noise with stdev s2 (after the CTF)
      5. optionally normalize, and invert unless args.noinvert
      6. write the stack to args.o and a .star file to args.out_star
      7. unless args.ctf_pkl was given, write a CTF parameter pickle

    When args.s1 / args.s2 are not given, they are derived from the target
    SNRs args.snr1 / args.snr2. The derivation uses the stdev of the
    positive-valued pixels of the first (up to) 100 images.

    Raises ValueError if args.s1 is given without args.s2.
    '''
    np.random.seed(args.seed)
    log('RUN CMD:\n' + ' '.join(sys.argv))
    log('Arguments:\n' + str(args))
    if args.Nimg is None:
        log('Loading all particles')
        particles = mrc.parse_mrc(args.particles, lazy=False)[0]
        Nimg = len(particles)
    else:
        Nimg = args.Nimg
        log('Lazy loading ' + str(args.Nimg) + ' particles')
        particle_list = mrc.parse_mrc(args.particles, lazy=True, Nimg=Nimg)[0]
        particles = np.array([i.get() for i in particle_list])
    D, D2 = particles[0].shape
    assert D == D2, 'Images must be square'

    log('Loaded {} images'.format(Nimg))

    # Explicit check rather than assert: under `python -O` an assert is
    # stripped, and the s2 branch below would then fail on an undefined `mask`.
    if args.s1 is not None and args.s2 is None:
        raise ValueError("Need to provide both --s1 and --s2")

    if args.s1 is None:
        # stdev of the positive-valued pixels of up to the first 100 images
        Nstd = min(100, Nimg)
        mask = np.where(particles[:Nstd] > 0)
        std = np.std(particles[mask])
        s1 = std / np.sqrt(args.snr1)
    else:
        s1 = args.s1
    if s1 > 0:
        log('Adding noise with stdev {}'.format(s1))
        particles = add_noise(particles, D, s1)

    log('Calculating the CTF')
    ctf, defocus_list = compute_full_ctf(D, Nimg, args)
    log('Applying the CTF')
    particles = add_ctf(particles, ctf)

    if args.s2 is None:
        # same pixel positions as for s1, measured after noise + CTF
        std = np.std(particles[mask])
        # cascading of noise processes according to Frank and Al-Ali (1975) & Baxter (2009)
        snr2 = (1 + 1 / args.snr1) / (1 / args.snr2 - 1 / args.snr1)
        log('SNR2 target {} for total snr of {}'.format(snr2, args.snr2))
        s2 = std / np.sqrt(snr2)
    else:
        s2 = args.s2
    if s2 > 0:
        log('Adding noise with stdev {}'.format(s2))
        particles = add_noise(particles, D, s2)

    if args.normalize:
        log('Normalizing particles')
        particles = normalize(particles)

    if not args.noinvert:
        log('Inverting particles')
        particles = invert(particles)

    log('Writing image stack to {}'.format(args.o))
    mrc.write(args.o, particles.astype(np.float32))

    if args.out_star is None:
        args.out_star = f'{args.o}.star'
    log(f'Writing associated .star file to {args.out_star}')
    if args.ctf_pkl:
        # NOTE(review): pickle.load can execute arbitrary code; only use with trusted ctf.pkl files
        with open(args.ctf_pkl, 'rb') as f:
            params = pickle.load(f)
        # Explicit comparison instead of try/assert: under -O the assert was
        # stripped and the truncation below was silently skipped.
        if len(params) != Nimg:
            log('Note that the input ctf.pkl contains ' + str(len(params)) +
                ' particles, but that you have only chosen to output the first '
                + str(Nimg) + ' particle')
            params = params[:Nimg]
        args.kv = params[0][5]
        args.cs = params[0][6]
        args.wgh = params[0][7]
        args.Apix = params[0][1]
    write_starfile(args.out_star, args.o, Nimg, defocus_list, args.kv,
                   args.wgh, args.cs, args.Apix)

    if not args.ctf_pkl:
        if args.out_pkl is None:
            args.out_pkl = f'{args.o}.pkl'
        log(f'Writing CTF params pickle to {args.out_pkl}')
        # column layout: D, Apix, defocus (2 columns), ang, kv, cs, wgh, ps
        params = np.ones((Nimg, 9), dtype=np.float32)
        params[:, 0] = D
        params[:, 1] = args.Apix
        params[:, 2:4] = defocus_list
        params[:, 4] = args.ang
        params[:, 5] = args.kv
        params[:, 6] = args.cs
        params[:, 7] = args.wgh
        params[:, 8] = args.ps
        log(params[0])
        with open(args.out_pkl, 'wb') as f:
            pickle.dump(params, f)
Пример #6
0
import numpy as np
import sys, os
import argparse
import pickle
import matplotlib.pyplot as plt 

import torch
import torch.nn as nn

sys.path.insert(0,'../lib-python')
import fft
import models
import mrc
from lattice import Lattice

# Exploratory check of Lattice.translate_ht: take the centered Hartley
# transform of the first image in data/hand.mrcs and shift it.
imgs,_,_ = mrc.parse_mrc('data/hand.mrcs')
img = imgs[0]
D = img.shape[0]
ht = fft.ht2_center(img)
# NOTE(review): symmetrize_ht presumably grows the transform by one row and
# column; the D += 1 below and the [0:-1, 0:-1] crop at the end depend on it — confirm
ht = fft.symmetrize_ht(ht)
D += 1

lattice = Lattice(D)
# NOTE(review): model and coords are constructed but not used below
model = models.FTSliceDecoder(D**2, D, 10,10,nn.ReLU)

coords = lattice.coords[...,0:2]/2
# flatten to a batch of one row
ht = torch.tensor(ht.astype(np.float32)).view(1,-1)

# shift of (5, 10); units presumably pixels — confirm against Lattice.translate_ht
trans = torch.tensor([5.,10.]).view(1,1,2)
ht_shifted = lattice.translate_ht(ht, trans)
# reshape back to D x D and drop the last row and column
ht_np = ht_shifted.view(D,D).numpy()[0:-1, 0:-1]
Пример #7
0
# coding: utf-8
# Consistency check: the toy particle stack must come out identical whether
# it is read eagerly, lazily, through a .star file, or through a .txt list.
import sys, os
DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, f'{DIR}/../lib-python')
import mrc
import numpy as np

# lazy vs. eager parsing of the same .mrcs file
data, _, _ = mrc.parse_mrc(f'{DIR}/data/toy_projections.mrcs', lazy=True)
data2, _, _ = mrc.parse_mrc(f'{DIR}/data/toy_projections.mrcs', lazy=False)
data1 = np.asarray([lazy_img.get() for lazy_img in data])
assert (data1 == data2).all()
print('ok')

import dataset

# loading through a .star file
data2 = dataset.load_particles(f'{DIR}/data/toy_projections.star')
assert (data1 == data2).all()
print('ok')

# loading through a .txt file listing .mrcs paths
data2 = dataset.load_particles(f'{DIR}/data/toy_projections.txt')
assert (data1 == data2).all()
print('ok')

print('all ok')
Пример #8
0
def main(args):
    '''
    Align args.vol to the reference volume args.ref, then save the result.

    Performs a coarse-to-fine grid search over 3D rotations and translations.
    Each of the args.niter iterations:
      - scores every (rotation, translation) pair on the current grids
      - keeps the args.keep_r best rotations and args.keep_t best translations
      - subdivides those to form the next, finer grids

    The best pose found is applied to args.vol at full size and written to
    args.o.

    Raises ValueError if args.niter < 1.
    '''
    if args.niter < 1:
        raise ValueError('niter must be at least 1')
    log(args)
    torch.set_grad_enabled(False)
    use_cuda = torch.cuda.is_available()
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    t1 = time.time()
    ref, _, _ = mrc.parse_mrc(args.ref)
    log('Loaded {} volume'.format(ref.shape))
    vol, _, _ = mrc.parse_mrc(args.vol)
    log('Loaded {} volume'.format(vol.shape))

    projector = VolumeAligner(vol,
                              vol_ref=ref,
                              maxD=args.max_D,
                              flip=args.flip)
    if use_cuda:
        projector.use_cuda()

    # initial rotation grid
    r_resol = args.r_resol
    quats = so3_grid.grid_SO3(r_resol)
    q_id = np.arange(len(quats))
    q_id = np.stack([q_id // (6 * 2**r_resol), q_id % (6 * 2**r_resol)], -1)
    rots = GridPose(quats, q_id)

    # initial translation grid
    t_resol = 0
    T_EXTENT = vol.shape[0] / 16 if args.t_extent is None else args.t_extent
    T_NGRID = args.t_grid
    trans = shift_grid3.base_shift_grid(T_EXTENT, T_NGRID)
    t_id = np.stack(shift_grid3.get_base_id(np.arange(len(trans)), T_NGRID),
                    -1)
    trans = GridPose(trans, t_id)

    max_keep_r = args.keep_r
    max_keep_t = args.keep_t
    for it in range(args.niter):
        log('Iteration {}'.format(it))
        log('Generating {} rotations'.format(len(rots)))
        log('Generating {} translations'.format(len(trans)))
        pose_err = np.empty((len(rots), len(trans)), dtype=np.float32)
        r_iterator = data.DataLoader(rots, batch_size=args.rb, shuffle=False)
        t_iterator = data.DataLoader(trans, batch_size=args.tb, shuffle=False)
        r_it = 0
        for rot, _ in r_iterator:
            if use_cuda: rot = rot.cuda()
            vr, vi = projector.rotate(rot)
            t_it = 0
            for tr, _ in t_iterator:
                if use_cuda: tr = tr.cuda()
                vtr, vti = projector.translate(
                    vr, vi, tr.expand(rot.size(0), *tr.shape))
                # todo: check volume
                err = projector.compute_err(vtr, vti)  # R x T
                pose_err[r_it:r_it + len(rot),
                         t_it:t_it + len(tr)] = err.cpu().numpy()
                t_it += len(tr)
            r_it += len(rot)

        # best error of each rotation over all translations, and vice versa
        r_err = pose_err.min(1)
        r_err_argmin = r_err.argsort()[:max_keep_r]
        t_err = pose_err.min(0)
        t_err_argmin = t_err.argsort()[:max_keep_t]

        # Record the best pose from the grids that were just scored, before
        # they are replaced by their subdivisions. Indexing the subdivided
        # grids with these argmins (as previously done after the loop) picks
        # an unrelated pose or goes out of range.
        best_r = rots.pose[r_err_argmin[0]]
        best_t = trans.pose[t_err_argmin[0]]

        rots, rots_id = subdivide_r(rots.pose[r_err_argmin],
                                    rots.pose_id[r_err_argmin], r_resol)
        rots = GridPose(rots, rots_id)

        trans, trans_id = subdivide_t(trans.pose_id[t_err_argmin], t_resol,
                                      T_EXTENT, T_NGRID)
        trans = GridPose(trans, trans_id)
        r_resol += 1
        t_resol += 1
        vlog(r_err[r_err_argmin])
        vlog(t_err[t_err_argmin])

    r = best_r
    # rescale the translation from the args.max_D grid to the full volume size
    t = best_t * vol.shape[0] / args.max_D
    log('Best rot: {}'.format(r))
    log('Best trans: {}'.format(t))
    t *= 2 / vol.shape[0]
    projector = VolumeAligner(vol,
                              vol_ref=ref,
                              maxD=vol.shape[0],
                              flip=args.flip)
    if use_cuda: projector.use_cuda()
    vr = projector.real_tform(
        torch.tensor(r).unsqueeze(0),
        torch.tensor(t).view(1, 1, 3))
    v = vr.squeeze().cpu().numpy()
    log('Saving {}'.format(args.o))
    mrc.write(args.o, v.astype(np.float32))

    td = time.time() - t1
    log('Finished in {}s'.format(td))
Пример #9
0
def main(args):
    '''
    Project the 3D volume in args.mrc into 2D images.

    Rotations come from a grid (args.grid resolution level) or are
    args.N random rotations. An optional fixed tilt (args.tilt, degrees) is
    passed to the Projector. If args.t_extent > 0, each image also gets a
    random shift in [-t_extent, t_extent) pixels.

    Outputs:
      - args.o: the image stack
      - args.out_pose: the rotations, or (rotations, translations) when
        shifting; translations are stored as the shift that re-centers the
        image, as a fraction of the box size
      - args.out_png (optional): a preview of the first 9 images
    '''
    for out in (args.o, args.out_png, args.out_pose):
        if not out: continue
        mkbasedir(out)
        warnexists(out)

    if args.t_extent == 0.:
        log('Not shifting images')
    else:
        assert args.t_extent > 0

    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    use_cuda = torch.cuda.is_available()
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    t1 = time.time()
    vol, _, _ = mrc.parse_mrc(args.mrc)
    log('Loaded {} volume'.format(vol.shape))

    if args.tilt:
        # rotation of args.tilt degrees about the x axis
        theta = args.tilt*np.pi/180
        args.tilt = np.array([[1.,0.,0.],
                        [0, np.cos(theta), -np.sin(theta)],
                        [0, np.sin(theta), np.cos(theta)]]).astype(np.float32)

    projector = Projector(vol, args.tilt)
    if use_cuda:
        projector.lattice = projector.lattice.cuda()
        projector.vol = projector.vol.cuda()

    if args.grid is not None:
        rots = GridRot(args.grid)
        log('Generating {} rotations at resolution level {}'.format(len(rots), args.grid))
    else:
        log('Generating {} random rotations'.format(args.N))
        rots = RandomRot(args.N)

    log('Projecting...')
    imgs = []
    # In --grid mode the image count is len(rots), not args.N
    n_total = len(rots)
    n_done = 0
    iterator = data.DataLoader(rots, batch_size=args.b)
    for rot in iterator:
        n_done += len(rot)
        vlog('Projecting {}/{}'.format(n_done, n_total))
        projections = projector.project(rot)
        projections = projections.cpu().numpy()
        imgs.append(projections)

    rots = rots.rots.cpu().numpy()
    imgs = np.vstack(imgs)
    Nimg = len(imgs)
    td = time.time()-t1
    log('Projected {} images in {}s ({}s per image)'.format(Nimg, td, td/Nimg))

    if args.t_extent:
        log('Shifting images between +/- {} pixels'.format(args.t_extent))
        # one shift per projected image (args.N is not the count in --grid mode)
        trans = np.random.rand(Nimg,2)*2*args.t_extent - args.t_extent
        imgs = np.asarray([translate_img(img, t) for img,t in zip(imgs,trans)])
        # convention: we want the first column to be x shift and second column to be y shift
        # reverse columns since current implementation of translate_img uses scipy's
        # fourier_shift, which is flipped the other way
        # convention: save the translation that centers the image
        trans = -trans[:,::-1]
        # convert translation from pixel to fraction
        D = imgs.shape[-1]
        assert D % 2 == 0
        trans /= D

    log('Saving {}'.format(args.o))
    mrc.write(args.o,imgs.astype(np.float32))
    log('Saving {}'.format(args.out_pose))
    with open(args.out_pose,'wb') as f:
        if args.t_extent:
            pickle.dump((rots,trans),f)
        else:
            pickle.dump(rots, f)
    if args.out_png:
        log('Saving {}'.format(args.out_png))
        plot_projections(args.out_png, imgs[:9])