コード例 #1
0
def main(args):
    """Re-center a particle stack by applying pose translations in Fourier space.

    Loads the stack, converts each translation (fraction of box -> pixels),
    applies it as a phase shift in the Fourier domain, and writes the shifted
    images to args.o (optionally previewing the first 9 as a PNG).
    """
    # load particles
    particles = dataset.load_particles(args.mrcs, datadir=args.datadir)
    log(particles.shape)
    Nimg, D, D = particles.shape

    trans = utils.load_pkl(args.trans)
    if isinstance(trans, tuple):
        # pose pickle is (rotations, translations); keep translations only
        trans = trans[1]
    trans *= args.tscale
    assert np.all(trans <= 1), "ERROR: Old pose format detected. Translations must be in units of fraction of box."
    trans *= D # convert to pixels
    assert len(trans) == Nimg

    # 2D lattice of fractional coordinates, shape (D, D, 2)
    xx,yy = np.meshgrid(np.arange(-D/2,D/2),np.arange(-D/2,D/2))
    TCOORD = np.stack([xx, yy],axis=2)/D # DxDx2

    imgs = []
    for ii in range(Nimg):
        ff = fft.fft2_center(particles[ii])
        # phase ramp exp(-2*pi*i*<k, t>) implements the shift by trans[ii]
        tfilt = np.dot(TCOORD,trans[ii])*-2*np.pi
        tfilt = np.cos(tfilt) + np.sin(tfilt)*1j
        ff *= tfilt
        img = fft.ifftn_center(ff)
        imgs.append(img)

    imgs = np.asarray(imgs).astype(np.float32)
    mrc.write(args.o, imgs)

    if args.out_png:
        plot_projections(args.out_png, imgs[:9])
コード例 #2
0
ファイル: add_noise.py プロジェクト: kemeng89/cryodrgn
def main(args):
    """Add Gaussian noise to a particle stack and write the result.

    The noise level is either given directly (--sigma) or derived from a
    target SNR (--snr) using the std of (a subset of) the input particles,
    optionally restricted by a strict (>0) or circular mask.
    """
    assert (args.snr is None) != (args.sigma is None)  # xor

    # load particles
    particles = dataset.load_particles(args.mrcs, datadir=args.datadir)
    log(particles.shape)
    Nimg, D, D = particles.shape

    # compute noise variance
    # BUGFIX: test for None explicitly -- the old truthiness check sent an
    # explicit `--sigma 0` into the SNR branch, where args.snr is None.
    if args.sigma is not None:
        sigma = args.sigma
    else:
        Nstd = min(10000, Nimg)
        if args.mask == 'strict':
            # estimate signal std only over strictly positive pixels
            mask = np.where(particles[:Nstd] > 0)
            std = np.std(particles[mask])
        elif args.mask == 'circular':
            lattice = EvenLattice(D)
            mask = lattice.get_circular_mask(args.mask_r)
            mask = np.where(mask)[
                0]  # convert from torch uint mask to array index
            std = np.std(particles[:Nstd].reshape(Nstd, -1)[:, mask])
        else:
            std = np.std(particles[:Nstd])
        sigma = std / np.sqrt(args.snr)

    # add noise
    log('Adding noise with std {}'.format(sigma))
    particles += np.random.normal(0, sigma, particles.shape)

    # save particles
    mrc.write(args.o, particles.astype(np.float32))

    if args.out_png:
        plot_projections(args.out_png, particles[:9])
コード例 #3
0
def save_checkpoint(model, lattice, optim, epoch, norm, apix, out_mrc, out_weights):
    """Evaluate the current model volume to .mrc and save training state."""
    model.eval()
    volume = model.eval_volume(lattice.coords, lattice.D, lattice.extent, norm)
    mrc.write(out_mrc, volume.astype(np.float32), ax=apix, ay=apix, az=apix)
    state = {
        'norm': norm,
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
    }
    torch.save(state, out_weights)
コード例 #4
0
ファイル: add_psize.py プロジェクト: kemeng89/cryodrgn
def main(args):
    """Stamp a pixel size onto an .mrc volume, optionally inverting/flipping it."""
    assert args.input.endswith('.mrc'), "Input volume must be .mrc file"
    assert args.o.endswith('.mrc'), "Output volume must be .mrc file"
    vol, _, _ = mrc.parse_mrc(args.input)
    apix = args.apix
    if args.invert:
        vol *= -1  # invert contrast
    if args.flip:
        vol = vol[::-1]  # flip along the first axis
    mrc.write(args.o, vol, ax=apix, ay=apix, az=apix)
    log(f'Wrote {args.o}')
コード例 #5
0
def main(args):
    """Phase-flip a lazy-loaded particle stack by the sign of each image's CTF."""
    imgs = dataset.load_particles(args.mrcs, lazy=True, datadir=args.datadir)
    ctf_params = utils.load_pkl(args.ctf_params)
    assert len(imgs) == len(ctf_params)

    D = imgs[0].get().shape[0]
    # flattened 2D lattice of spatial frequencies in [-0.5, 0.5)
    fx, fy = np.meshgrid(np.linspace(-.5, .5, D, endpoint=False),
                         np.linspace(-.5, .5, D, endpoint=False))
    freqs = np.stack([fx.ravel(), fy.ravel()], 1)

    imgs_flip = np.empty((len(imgs), D, D), dtype=np.float32)
    for i, lazy_img in enumerate(imgs):
        if i % 1000 == 0: print(i)
        # CTF evaluated at frequencies scaled by 1/Apix (ctf_params[i, 0])
        c = ctf.compute_ctf_np(freqs / ctf_params[i, 0], *ctf_params[i, 1:])
        c = c.reshape((D, D))
        ff = fft.fft2_center(lazy_img.get())
        flipped = fft.ifftn_center(ff * np.sign(c))
        imgs_flip[i] = flipped.astype(np.float32)

    mrc.write(args.o, imgs_flip)
コード例 #6
0
def main(args):
    """Downsample an image stack or volume by cropping in Fourier (Hartley) space.

    Crops the central DxD (or DxDxD) block of the Hartley transform and
    inverts, writing either one output file or fixed-size chunks (--chunk).
    """
    mkbasedir(args.o)
    warnexists(args.o)
    assert (args.o.endswith('.mrcs') or args.o.endswith('mrc')), "Must specify output in .mrc(s) file format"

    old = dataset.load_particles(args.mrcs, lazy=True, datadir=args.datadir)
    oldD = old[0].get().shape[0]
    assert args.D < oldD, f'New box size {args.D} must be smaller than original box size {oldD}'
    assert args.D % 2 == 0, 'New box size must be even'

    D = args.D

    # central crop bounds in the frequency domain
    start = int(oldD/2 - D/2)
    stop = int(oldD/2 + D/2)

    if args.is_vol:
        oldft = fft.htn_center(np.array([x.get() for x in old]))
        log(oldft.shape)
        newft = oldft[start:stop,start:stop,start:stop]
        log(newft.shape)
        new = fft.ihtn_center(newft).astype(np.float32)
        log('Saving {}'.format(args.o))
        mrc.write(args.o,new)

    elif args.chunk is None:
        new = []
        for i in range(len(old)):
            if i % 1000 == 0:
                log(f'Processing image {i} of {len(old)}')
            img = old[i]
            oldft = fft.ht2_center(img.get()).astype(np.float32)
            newft = oldft[start:stop, start:stop]
            new.append(fft.ihtn_center(newft).astype(np.float32))
        # sanity check: the DC component must be unchanged by the crop
        assert oldft[int(oldD/2),int(oldD/2)] == newft[int(D/2),int(D/2)]
        new = np.asarray(new)
        log(new.shape)
        log('Saving {}'.format(args.o))
        mrc.write(args.o,new)
    else:
        # BUGFIX: ceil division -- the old `len(old)//args.chunk + 1` wrote an
        # extra empty chunk whenever len(old) was an exact multiple of chunk.
        nchunks = (len(old) + args.chunk - 1) // args.chunk
        for i in range(nchunks):
            log('Processing chunk {}'.format(i))
            out = '.{}'.format(i).join(os.path.splitext(args.o))
            new = []
            for img in old[i*args.chunk:(i+1)*args.chunk]:
                oldft = fft.ht2_center(img.get()).astype(np.float32)
                newft = oldft[start:stop, start:stop]
                new.append(fft.ihtn_center(newft).astype(np.float32))
            assert oldft[int(oldD/2),int(oldD/2)] == newft[int(D/2),int(D/2)]
            new = np.asarray(new)
            log(new.shape)
            log('Saving {}'.format(out))
            mrc.write(out,new)
コード例 #7
0
#!/usr/bin/env python

# Convert a TIFF image to MRC format (Python 2 / Leginon-era script).
# Usage: <script> <input.tif> <output.mrc>

import tifffile
import sys
import numpil
import mrc

# NOTE(review): `numpil` is imported but unused in this excerpt.

infile = sys.argv[1]
outfile = sys.argv[2]

# read the whole TIFF into a numpy array
tif = tifffile.TIFFfile(infile)
a = tif.asarray()
print 'MINMAX', a.min(), a.max()

# NOTE: this mrc module takes (data, filename) argument order
mrc.write(a, outfile)
コード例 #8
0
def main(args):
    """Simulate a noisy, CTF-corrupted particle stack from clean projections.

    Pipeline: pre-CTF noise (s1) -> CTF -> post-CTF noise (s2), with noise
    levels given directly (--s1/--s2) or derived from target SNRs, then
    optional normalization/inversion.  Writes the stack plus .star and
    ctf.pkl metadata.
    """
    np.random.seed(args.seed)
    log('RUN CMD:\n' + ' '.join(sys.argv))
    log('Arguments:\n' + str(args))
    if args.Nimg is None:
        log('Loading all particles')
        particles = mrc.parse_mrc(args.particles, lazy=False)[0]
        Nimg = len(particles)
    else:
        Nimg = args.Nimg
        log('Lazy loading ' + str(args.Nimg) + ' particles')
        particle_list = mrc.parse_mrc(args.particles, lazy=True, Nimg=Nimg)[0]
        particles = np.array([i.get() for i in particle_list])
    D, D2 = particles[0].shape
    assert D == D2, 'Images must be square'

    log('Loaded {} images'.format(Nimg))

    # Mask of above-zero pixels, used to estimate the signal std.
    # BUGFIX: computed unconditionally -- previously `mask` was only defined
    # in the `args.s1 is None` branch, so reaching the `args.s2 is None`
    # branch with an explicit --s1 raised a NameError.
    Nstd = min(100, Nimg)
    mask = np.where(particles[:Nstd] > 0)

    if args.s1 is not None:
        assert args.s2 is not None, "Need to provide both --s1 and --s2"

    if args.s1 is None:
        std = np.std(particles[mask])
        s1 = std / np.sqrt(args.snr1)
    else:
        s1 = args.s1
    if s1 > 0:
        log('Adding noise with stdev {}'.format(s1))
        particles = add_noise(particles, D, s1)

    log('Calculating the CTF')
    ctf, defocus_list = compute_full_ctf(D, Nimg, args)
    log('Applying the CTF')
    particles = add_ctf(particles, ctf)

    if args.s2 is None:
        std = np.std(particles[mask])
        # cascading of noise processes according to Frank and Al-Ali (1975) & Baxter (2009)
        snr2 = (1 + 1 / args.snr1) / (1 / args.snr2 - 1 / args.snr1)
        log('SNR2 target {} for total snr of {}'.format(snr2, args.snr2))
        s2 = std / np.sqrt(snr2)
    else:
        s2 = args.s2
    if s2 > 0:
        log('Adding noise with stdev {}'.format(s2))
        particles = add_noise(particles, D, s2)

    if args.normalize:
        log('Normalizing particles')
        particles = normalize(particles)

    if not (args.noinvert):
        log('Inverting particles')
        particles = invert(particles)

    log('Writing image stack to {}'.format(args.o))
    mrc.write(args.o, particles.astype(np.float32))

    if args.out_star is None:
        args.out_star = f'{args.o}.star'
    log(f'Writing associated .star file to {args.out_star}')
    if args.ctf_pkl:
        # reuse CTF parameters from an existing pickle for the .star file
        params = pickle.load(open(args.ctf_pkl, 'rb'))
        try:
            assert len(params) == Nimg
        except AssertionError:
            log('Note that the input ctf.pkl contains ' + str(len(params)) +
                ' particles, but that you have only chosen to output the first '
                + str(Nimg) + ' particle')
            params = params[:Nimg]
        args.kv = params[0][5]
        args.cs = params[0][6]
        args.wgh = params[0][7]
        args.Apix = params[0][1]
    write_starfile(args.out_star, args.o, Nimg, defocus_list, args.kv,
                   args.wgh, args.cs, args.Apix)

    if not args.ctf_pkl:
        if args.out_pkl is None:
            args.out_pkl = f'{args.o}.pkl'
        log(f'Writing CTF params pickle to {args.out_pkl}')
        # columns: D, Apix, defocusU, defocusV, astig. angle, kV, Cs, w, phase shift
        params = np.ones((Nimg, 9), dtype=np.float32)
        params[:, 0] = D
        params[:, 1] = args.Apix
        params[:, 2:4] = defocus_list
        params[:, 4] = args.ang
        params[:, 5] = args.kv
        params[:, 6] = args.cs
        params[:, 7] = args.wgh
        params[:, 8] = args.ps
        log(params[0])
        with open(args.out_pkl, 'wb') as f:
            pickle.dump(params, f)
コード例 #9
0
def mrc_write(im, filename):
    """Write *im* to *filename* as MRC, but only when debugging is enabled."""
    if not debug:
        return
    mrc.write(im, filename)
コード例 #10
0
def mrc_write(im, filename):
    """Dump *im* to an MRC file when the module-level debug flag is set."""
    if not debug:
        return
    mrc.write(im, filename)
コード例 #11
0
def main(args):
    """Backproject 2D particle images into a 3D volume using known poses.

    Accumulates (optionally CTF phase-flipped and translated) Hartley-space
    slices of each image into a voxel grid, normalizes by per-voxel counts,
    and writes the inverse transform as an .mrc volume.
    """
    assert args.mrcs.endswith('.mrcs')
    assert args.o.endswith('.mrc')

    t1 = time.time()
    log(args)
    if not os.path.exists(os.path.dirname(args.o)):
        os.makedirs(os.path.dirname(args.o))

    ## set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # load the particles
    if args.tilt is None:
        data = dataset.LazyMRCData(args.mrcs,
                                   norm=(0, 1),
                                   invert_data=args.invert_data,
                                   datadir=args.datadir)
        tilt = None
    else:
        data = dataset.TiltMRCData(args.mrcs,
                                   args.tilt,
                                   norm=(0, 1),
                                   invert_data=args.invert_data,
                                   datadir=args.datadir)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32))
    D = data.D
    Nimg = data.N

    lattice = Lattice(D, extent=D // 2)

    posetracker = PoseTracker.load(args.poses, Nimg, D, None, None)

    if args.ctf is not None:
        log('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        ctf_params = torch.tensor(ctf_params)
    else:
        ctf_params = None
    apix = ctf_params[0, 0] if ctf_params is not None else 1

    V = torch.zeros((D, D, D))
    counts = torch.zeros((D, D, D))

    mask = lattice.get_circular_mask(D // 2)

    # choose which images to backproject
    if args.ind:
        iterator = pickle.load(open(args.ind, 'rb'))
    elif args.first:
        args.first = min(args.first, Nimg)
        iterator = range(args.first)
    else:
        iterator = range(Nimg)

    for ii in iterator:
        if ii % 100 == 0: log('image {}'.format(ii))
        r, t = posetracker.get_pose(ii)
        ff = data.get(ii)
        if tilt is not None:
            ff, ff_tilt = ff  # EW
        ff = torch.tensor(ff)
        ff = ff.view(-1)[mask]
        if ctf_params is not None:
            freqs = lattice.freqs2d / ctf_params[ii, 0]
            c = ctf.compute_ctf(freqs, *ctf_params[ii, 1:]).view(-1)[mask]
            ff *= c.sign()  # phase flip
        if t is not None:
            ff = lattice.translate_ht(ff.view(1, -1), t.view(1, 1, 2),
                                      mask).view(-1)
        ff_coord = lattice.coords[mask] @ r
        add_slice(V, counts, ff_coord, ff, D)

        # tilt series
        if args.tilt is not None:
            ff_tilt = torch.tensor(ff_tilt)
            ff_tilt = ff_tilt.view(-1)[mask]
            # BUGFIX: guard on ctf_params (the old `if ff_tilt is not None:`
            # was always true, and `c` is undefined when no CTF is given).
            if ctf_params is not None:
                ff_tilt *= c.sign()
            if t is not None:
                ff_tilt = lattice.translate_ht(ff_tilt.view(1, -1),
                                               t.view(1, 1, 2), mask).view(-1)
            ff_coord = lattice.coords[mask] @ tilt @ r
            add_slice(V, counts, ff_coord, ff_tilt, D)

    td = time.time() - t1
    log('Backprojected {} images in {}s ({}s per image)'.format(
        len(iterator), td, td / Nimg))
    counts[counts == 0] = 1  # avoid division by zero in empty voxels
    V /= counts
    V = fft.ihtn_center(V[0:-1, 0:-1, 0:-1].cpu().numpy())
    mrc.write(args.o, V.astype('float32'), ax=apix, ay=apix, az=apix)
コード例 #12
0
ファイル: align3d.py プロジェクト: kemeng89/cryodrgn
def main(args):
    """Exhaustively align a volume to a reference via coarse-to-fine grid search.

    Each iteration scores all (rotation, translation) pairs, keeps the best
    few of each, and subdivides the pose grids around them.  The best final
    pose is applied to the input volume and the result written to args.o.
    """
    log(args)
    torch.set_grad_enabled(False)
    use_cuda = torch.cuda.is_available()
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    t1 = time.time()
    ref, _, _ = mrc.parse_mrc(args.ref)
    log('Loaded {} volume'.format(ref.shape))
    vol, _, _ = mrc.parse_mrc(args.vol)
    log('Loaded {} volume'.format(vol.shape))

    projector = VolumeAligner(vol,
                              vol_ref=ref,
                              maxD=args.max_D,
                              flip=args.flip)
    if use_cuda:
        projector.use_cuda()

    # initial SO(3) quaternion grid at resolution r_resol; q_id keeps each
    # quaternion's 2D grid indices so the grid can be subdivided later
    r_resol = args.r_resol
    quats = so3_grid.grid_SO3(r_resol)
    q_id = np.arange(len(quats))
    q_id = np.stack([q_id // (6 * 2**r_resol), q_id % (6 * 2**r_resol)], -1)
    rots = GridPose(quats, q_id)

    # initial 3D translation grid spanning +/- T_EXTENT with T_NGRID points per axis
    t_resol = 0
    T_EXTENT = vol.shape[0] / 16 if args.t_extent is None else args.t_extent
    T_NGRID = args.t_grid
    trans = shift_grid3.base_shift_grid(T_EXTENT, T_NGRID)
    t_id = np.stack(shift_grid3.get_base_id(np.arange(len(trans)), T_NGRID),
                    -1)
    trans = GridPose(trans, t_id)

    max_keep_r = args.keep_r
    max_keep_t = args.keep_t
    #rot_tracker = MinPoseTracker(max_keep_r, 4, 2)
    #tr_tracker = MinPoseTracker(max_keep_t, 3, 3)
    for it in range(args.niter):
        log('Iteration {}'.format(it))
        log('Generating {} rotations'.format(len(rots)))
        log('Generating {} translations'.format(len(trans)))
        # pose_err[r, t] = alignment error of rotation r combined with translation t
        pose_err = np.empty((len(rots), len(trans)), dtype=np.float32)
        #rot_tracker.clear()
        #tr_tracker.clear()
        r_iterator = data.DataLoader(rots, batch_size=args.rb, shuffle=False)
        t_iterator = data.DataLoader(trans, batch_size=args.tb, shuffle=False)
        r_it = 0
        for rot, r_id in r_iterator:
            if use_cuda: rot = rot.cuda()
            vr, vi = projector.rotate(rot)
            t_it = 0
            for tr, t_id in t_iterator:
                if use_cuda: tr = tr.cuda()
                vtr, vti = projector.translate(
                    vr, vi, tr.expand(rot.size(0), *tr.shape))
                # todo: check volume
                err = projector.compute_err(vtr, vti)  # R x T
                pose_err[r_it:r_it + len(rot),
                         t_it:t_it + len(tr)] = err.cpu().numpy()
                #r_err = err.min(1)[0]
                #min_r_err, min_r_i = r_err.sort()
                #rot_tracker.add(min_r_err[:max_keep_r], rot[min_r_i][:max_keep_r], r_id[min_r_i][:max_keep_r])
                #t_err= err.min(0)[0]
                #min_t_err, min_t_i = t_err.sort()
                #tr_tracker.add(min_t_err[:max_keep_t], tr[min_t_i][:max_keep_t], t_id[min_t_i][:max_keep_t])
                t_it += len(tr)
            r_it += len(rot)

        # keep the best max_keep_r rotations and max_keep_t translations
        r_err = pose_err.min(1)
        r_err_argmin = r_err.argsort()[:max_keep_r]
        t_err = pose_err.min(0)
        t_err_argmin = t_err.argsort()[:max_keep_t]

        # lstart
        #r = rots.pose[r_err_argmin[0]]
        #t = trans.pose[t_err_argmin[0]]
        #log('Best rot: {}'.format(r))
        #log('Best trans: {}'.format(t))
        #vr, vi = projector_full.rotate(torch.tensor(r).unsqueeze(0))
        #vr, vi = projector_full.translate(vr, vi, torch.tensor(t).view(1,1,3))
        #err = projector_full.compute_err(vr,vi)

        #w = np.where(r_err[r_err_argmin] > err.item())[0]
        # refine: subdivide the pose grids around the kept candidates
        rots, rots_id = subdivide_r(rots.pose[r_err_argmin],
                                    rots.pose_id[r_err_argmin], r_resol)
        rots = GridPose(rots, rots_id)

        t_err = pose_err.min(0)
        t_err_argmin = t_err.argsort()[:max_keep_t]
        trans, trans_id = subdivide_t(trans.pose_id[t_err_argmin], t_resol,
                                      T_EXTENT, T_NGRID)
        trans = GridPose(trans, trans_id)
        r_resol += 1
        t_resol += 1
        vlog(r_err[r_err_argmin])
        vlog(t_err[t_err_argmin])
        #log(rot_tracker.min_errs)
        #log(tr_tracker.min_errs)
    # apply the overall best pose to the full-resolution volume and save it
    r = rots.pose[r_err_argmin[0]]
    t = trans.pose[t_err_argmin[0]] * vol.shape[0] / args.max_D
    log('Best rot: {}'.format(r))
    log('Best trans: {}'.format(t))
    t *= 2 / vol.shape[0]
    projector = VolumeAligner(vol,
                              vol_ref=ref,
                              maxD=vol.shape[0],
                              flip=args.flip)
    if use_cuda: projector.use_cuda()
    vr = projector.real_tform(
        torch.tensor(r).unsqueeze(0),
        torch.tensor(t).view(1, 1, 3))
    v = vr.squeeze().cpu().numpy()
    log('Saving {}'.format(args.o))
    mrc.write(args.o, v.astype(np.float32))

    td = time.time() - t1
    log('Finished in {}s'.format(td))
コード例 #13
0
ファイル: makecorrections.py プロジェクト: nramm/maskiton
# Build CCD bias and dark reference images from a Leginon session
# (Python 2 / Leginon-era script; uses `print` statement syntax).
db = sinedon.getConnection("leginondata")
corrector = ccd.Corrector()

session = leginondata.SessionData(name="07jun05b")

# accumulate all "bias1" acquisition images into a bias reference
print "BIAS"
query = leginondata.AcquisitionImageData(session=session, label="bias1")
images = db.query(query, readimages=False)
print "Found %d images" % (len(images),)
for image in images:
    filename = image["filename"]
    print "Inserting:  ", filename
    corrector.insertBias(image["image"])
finalbias = corrector.bias()
print "saving bias.mrc"
mrc.write(finalbias, "bias.mrc")

# accumulate "dark1" images (weighted by exposure time) into a dark reference
print "DARK"
query = leginondata.AcquisitionImageData(session=session, label="dark1")
images = db.query(query, readimages=False)
print "Found %d images" % (len(images),)
for image in images:
    filename = image["filename"]
    exptime = image["camera"]["exposure time"]
    print "Inserting:  ", filename
    corrector.insertDark(image["image"], exptime)
finaldark = corrector.dark()
print "saving dark.mrc"
mrc.write(finaldark, "dark.mrc")

# NOTE(review): the bright-field section appears truncated in this excerpt
print "BRIGHT"
コード例 #14
0
def main(args):
    """Project a 3D volume into a stack of 2D images, optionally shifted.

    Rotations come from a fixed SO(3) grid (--grid) or uniform random
    sampling (--N); optional in-plane translations are drawn uniformly in
    [-t_extent, t_extent] pixels.  Writes the image stack, a poses pickle,
    and optionally a PNG preview of the first 9 projections.
    """
    for out in (args.o, args.out_png, args.out_pose):
        if not out: continue
        mkbasedir(out)
        warnexists(out)

    if args.t_extent == 0.:
        log('Not shifting images')
    else:
        assert args.t_extent > 0

    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    use_cuda = torch.cuda.is_available()
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    t1 = time.time()
    vol, _ , _ = mrc.parse_mrc(args.mrc)
    log('Loaded {} volume'.format(vol.shape))

    if args.tilt:
        # rotation about the x-axis by the tilt angle (degrees -> radians)
        theta = args.tilt*np.pi/180
        args.tilt = np.array([[1.,0.,0.],
                        [0, np.cos(theta), -np.sin(theta)],
                        [0, np.sin(theta), np.cos(theta)]]).astype(np.float32)

    projector = Projector(vol, args.tilt)
    if use_cuda:
        projector.lattice = projector.lattice.cuda()
        projector.vol = projector.vol.cuda()

    if args.grid is not None:
        rots = GridRot(args.grid)
        log('Generating {} rotations at resolution level {}'.format(len(rots), args.grid))
    else:
        log('Generating {} random rotations'.format(args.N))
        rots = RandomRot(args.N)

    log('Projecting...')
    imgs = []
    iterator = data.DataLoader(rots, batch_size=args.b)
    for i, rot in enumerate(iterator):
        vlog('Projecting {}/{}'.format((i+1)*len(rot), args.N))
        projections = projector.project(rot)
        projections = projections.cpu().numpy()
        imgs.append(projections)

    rots = rots.rots.cpu().numpy()
    imgs = np.vstack(imgs)
    td = time.time()-t1
    log('Projected {} images in {}s ({}s per image)'.format(args.N, td, td/args.N ))

    if args.t_extent:
        log('Shifting images between +/- {} pixels'.format(args.t_extent))
        # BUGFIX: draw one translation per generated image -- with --grid the
        # image count is len(imgs), not args.N, and the old args.N draw made
        # zip() silently drop the unshifted images from the output stack.
        trans = np.random.rand(len(imgs),2)*2*args.t_extent - args.t_extent
        imgs = np.asarray([translate_img(img, t) for img,t in zip(imgs,trans)])
        # convention: we want the first column to be x shift and second column to be y shift
        # reverse columns since current implementation of translate_img uses scipy's
        # fourier_shift, which is flipped the other way
        # convention: save the translation that centers the image
        trans = -trans[:,::-1]
        # convert translation from pixel to fraction
        D = imgs.shape[-1]
        assert D % 2 == 0
        trans /= D

    log('Saving {}'.format(args.o))
    mrc.write(args.o,imgs.astype(np.float32))
    log('Saving {}'.format(args.out_pose))
    with open(args.out_pose,'wb') as f:
        if args.t_extent:
            pickle.dump((rots,trans),f)
        else:
            pickle.dump(rots, f)
    if args.out_png:
        log('Saving {}'.format(args.out_png))
        plot_projections(args.out_png, imgs[:9])
コード例 #15
0
ファイル: eval_decoder.py プロジェクト: kemeng89/cryodrgn
def _eval_vol(model, lattice, args, z):
    """Decode one volume at latent *z*, honoring --downsample and --flip."""
    if args.downsample:
        # evaluate on a smaller lattice spanning the proportional extent
        extent = lattice.extent * (args.downsample / args.D)
        vol = model.decoder.eval_volume(
            lattice.get_downsample_coords(args.downsample + 1),
            args.downsample + 1, extent, args.norm, z)
    else:
        vol = model.decoder.eval_volume(lattice.coords, lattice.D,
                                        lattice.extent, args.norm, z)
    if args.flip:
        vol = vol[::-1]  # flip along the first axis
    return vol


def main(args):
    """Evaluate a trained decoder at latent coordinates and write .mrc volumes.

    Latent inputs come from --z (single volume), --z-start/--z-end (linear
    trajectory of --n points), or --zfile (one z per row).
    """
    t1 = dt.now()

    ## set the device
    use_cuda = torch.cuda.is_available()
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.config is not None:
        args = config.load_config(args.config, args)
    log(args)

    if args.downsample:
        assert args.downsample % 2 == 0
        assert args.downsample < args.D, "Must be smaller than original box size"
    D = args.D + 1
    lattice = Lattice(D, extent=args.l_extent)
    if args.enc_mask:
        args.enc_mask = lattice.get_circular_mask(args.enc_mask)
        in_dim = args.enc_mask.sum()
    else:
        in_dim = lattice.D**2
    model = HetOnlyVAE(lattice,
                       args.qlayers,
                       args.qdim,
                       args.players,
                       args.pdim,
                       in_dim,
                       args.zdim,
                       encode_mode=args.encode_mode,
                       enc_mask=args.enc_mask,
                       enc_type=args.pe_type,
                       domain=args.domain)

    log('Loading weights from {}'.format(args.weights))
    checkpoint = torch.load(args.weights)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()

    if args.z_start or args.zfile:
        if args.z_start:
            # linear trajectory of args.n points from z_start to z_end
            assert args.z_end
            assert not args.z
            assert not args.zfile
            args.z_start = np.array(args.z_start)
            args.z_end = np.array(args.z_end)
            z = np.repeat(np.arange(args.n, dtype=np.float32),
                          args.zdim).reshape((args.n, args.zdim))
            z *= ((args.z_end - args.z_start) / (args.n - 1))
            z += args.z_start
        else:
            assert not args.z_start
            z = np.loadtxt(args.zfile).reshape(-1, args.zdim)

        if not os.path.exists(args.o):
            os.makedirs(args.o)

        # one output volume per latent coordinate
        for i, zz in enumerate(z):
            log(zz)
            vol = _eval_vol(model, lattice, args, zz)
            out_mrc = '{}/{}{:03d}.mrc'.format(args.o, args.prefix, i)
            mrc.write(out_mrc,
                      vol.astype(np.float32),
                      ax=args.Apix,
                      ay=args.Apix,
                      az=args.Apix)

    else:
        z = np.array(args.z)
        log(z)
        vol = _eval_vol(model, lattice, args, z)
        mrc.write(args.o,
                  vol.astype(np.float32),
                  ax=args.Apix,
                  ay=args.Apix,
                  az=args.Apix)

    td = dt.now() - t1
    log('Finished in {}'.format(td))  # typo fixed: was "Finsihed"