Exemple #1
0
def main(args):
    """Downsample an image stack or a volume by cropping its Hartley transform.

    The centered Hartley transform of the input is cropped to the central
    ``args.D`` samples along each axis, inverse-transformed, and written
    with ``mrc.write``.

    Attributes read from ``args``:
        mrcs: input particles, passed to ``dataset.load_particles``.
        datadir: forwarded to ``dataset.load_particles``.
        o: output path; must end in ``.mrc`` or ``.mrcs``.
        D: new box size; must be even and smaller than the input box size.
        is_vol: if truthy, the loaded slices are stacked and processed as a
            single 3D array.
        chunk: if not None, number of images per output file; chunk ``i`` is
            written to ``<root>.<i><ext>`` derived from ``args.o``.

    Raises:
        AssertionError: on an invalid output extension or box size.
    """
    mkbasedir(args.o)
    warnexists(args.o)
    # A tuple of suffixes (with the leading dot) rejects names such as
    # "foomrc" that the previous bare 'mrc' suffix let through.
    assert args.o.endswith(('.mrcs', '.mrc')), "Must specify output in .mrc(s) file format"

    old = dataset.load_particles(args.mrcs, lazy=True, datadir=args.datadir)
    oldD = old[0].get().shape[0]
    assert args.D < oldD, f'New box size {args.D} must be smaller than original box size {oldD}'
    assert args.D % 2 == 0, 'New box size must be even'

    D = args.D

    # Bounds of the central D-sample window of the centered transform.
    start = int(oldD/2 - D/2)
    stop = int(oldD/2 + D/2)

    def downsample_stack(images, total=None):
        """Crop each lazily-loaded image in Hartley space.

        Returns the downsampled images as one float32 array. When ``total``
        is given, progress is logged every 1000 images.
        """
        new = []
        oldft = newft = None
        for i, img in enumerate(images):
            if total is not None and i % 1000 == 0:
                log(f'Processing image {i} of {total}')
            oldft = fft.ht2_center(img.get()).astype(np.float32)
            newft = oldft[start:stop, start:stop]
            new.append(fft.ihtn_center(newft).astype(np.float32))
        if new:
            # Sanity check on the last image: the central sample of the
            # original transform must be the central sample of the crop.
            assert oldft[int(oldD/2), int(oldD/2)] == newft[int(D/2), int(D/2)]
        return np.asarray(new)

    if args.is_vol:
        oldft = fft.htn_center(np.array([x.get() for x in old]))
        log(oldft.shape)
        newft = oldft[start:stop, start:stop, start:stop]
        log(newft.shape)
        new = fft.ihtn_center(newft).astype(np.float32)
        log('Saving {}'.format(args.o))
        mrc.write(args.o, new)

    elif args.chunk is None:
        new = downsample_stack(old, total=len(old))
        log(new.shape)
        log('Saving {}'.format(args.o))
        mrc.write(args.o, new)
    else:
        # Ceiling division: `len(old) // chunk + 1` produced a trailing empty
        # chunk (and an empty output file) whenever the number of images was
        # an exact multiple of the chunk size.
        nchunks = (len(old) + args.chunk - 1) // args.chunk
        for i in range(nchunks):
            log('Processing chunk {}'.format(i))
            # Insert ".<i>" between the root and the extension of args.o.
            out = '.{}'.format(i).join(os.path.splitext(args.o))
            new = downsample_stack(old[i*args.chunk:(i+1)*args.chunk])
            log(new.shape)
            log('Saving {}'.format(out))
            mrc.write(out, new)
Exemple #2
0
    def eval_volume(self, coords, D, extent, norm, zval=None):
        '''
        Evaluate the model on a DxDxD volume, one z-slice at a time

        Inputs:
            coords: lattice coords on the x-y plane (D^2 x 3)
            D: size of lattice
            extent: extent of lattice [-extent, extent]
            norm: data normalization; the decoded values are multiplied by
                norm[1] and then shifted by norm[0]
            zval: value of latent (zdim x 1), or None for no latent

        Returns the result of fft.ihtn_center on the decoded volume with the
        last index along every axis dropped.
        '''
        has_latent = zval is not None
        if has_latent:
            # one copy of the latent per lattice point of a slice
            zdim = len(zval)
            z = torch.zeros(D**2, zdim, dtype=torch.float32)
            z += torch.tensor(zval, dtype=torch.float32)

        vol_f = np.zeros((D, D, D), dtype=np.float32)
        assert not self.training
        # decode slice by slice along z to avoid memory overflows
        z_offsets = np.linspace(-extent, extent, D, endpoint=True)
        for slice_idx, dz in enumerate(z_offsets):
            plane = coords + torch.tensor([0, 0, dz])
            if has_latent:
                plane = torch.cat((plane, z), dim=-1)
            with torch.no_grad():
                decoded = self.decode(plane)
                # the slice value is the difference of the first two output channels
                combined = decoded[..., 0] - decoded[..., 1]
                vol_f[slice_idx] = combined.view(D, D).cpu().numpy()
        # undo the data normalization, then apply utils.zero_sphere
        vol_f = utils.zero_sphere(vol_f * norm[1] + norm[0])
        # remove last +k freq for inverse FFT
        trimmed = vol_f[:-1, :-1, :-1]
        return fft.ihtn_center(trimmed)
sys.path.insert(0,'../lib-python')
import fft
import models
import mrc
from lattice import Lattice

# Demo script: shift an image by translating its Hartley transform on the
# lattice, then display the original and the shifted image.

# First image of the stack and its symmetrized, centered Hartley transform.
imgs, _, _ = mrc.parse_mrc('data/hand.mrcs')
img = imgs[0]
ht = fft.symmetrize_ht(fft.ht2_center(img))

# The lattice is one sample larger than the image along each axis
# (presumably matching the size of the symmetrized transform -- TODO confirm).
D = img.shape[0] + 1

lattice = Lattice(D)
model = models.FTSliceDecoder(D**2, D, 10, 10, nn.ReLU)

coords = lattice.coords[..., 0:2] / 2
# flatten the transform to a batch of one
ht = torch.tensor(ht.astype(np.float32)).view(1, -1)

# Apply one translation of (5, 10) to the flattened transform.
trans = torch.tensor([5., 10.]).view(1, 1, 2)
ht_shifted = lattice.translate_ht(ht, trans)

# Drop the last row and column before inverting the transform.
ht_np = ht_shifted.view(D, D).numpy()[:-1, :-1]
img_shifted = fft.ihtn_center(ht_np)

# One figure per image: the original first, then the shifted one.
for panel in (img, img_shifted):
    plt.figure()
    plt.imshow(panel)
plt.show()
Exemple #4
0
def main(args):
    """Backproject a particle stack into a 3D volume using known poses.

    Each image is loaded, optionally multiplied by the sign of its CTF and
    translated, and added with ``add_slice`` at its rotated lattice
    coordinates. The accumulated volume is divided by the per-voxel counts,
    passed through ``fft.ihtn_center`` and written to ``args.o``.

    Attributes read from ``args``:
        mrcs: input particle stack; must end in ``.mrcs``.
        o: output volume path; must end in ``.mrc``.
        poses: pose file passed to ``PoseTracker.load``.
        ctf: optional CTF parameter file.
        tilt, tilt_deg: optional tilt-pair stack and its tilt angle.
        invert_data, datadir: forwarded to the dataset loader.
        ind: optional pickle file holding the image indices to use.
        first: optional number of leading images to use.

    Raises:
        AssertionError: if the input or output extension is wrong.
    """
    assert args.mrcs.endswith('.mrcs')
    assert args.o.endswith('.mrc')

    t1 = time.time()
    log(args)
    # dirname is '' for a bare filename; os.makedirs('') would raise, so only
    # create the directory when there is one to create.
    outdir = os.path.dirname(args.o)
    if outdir and not os.path.exists(outdir):
        os.makedirs(outdir)

    ## set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # load the particles
    if args.tilt is None:
        data = dataset.LazyMRCData(args.mrcs,
                                   norm=(0, 1),
                                   invert_data=args.invert_data,
                                   datadir=args.datadir)
        tilt = None
    else:
        data = dataset.TiltMRCData(args.mrcs,
                                   args.tilt,
                                   norm=(0, 1),
                                   invert_data=args.invert_data,
                                   datadir=args.datadir)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32))
    D = data.D
    Nimg = data.N

    lattice = Lattice(D, extent=D // 2)

    posetracker = PoseTracker.load(args.poses, Nimg, D, None, None)

    if args.ctf is not None:
        log('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        ctf_params = torch.tensor(ctf_params)
    else:
        ctf_params = None
    # first CTF column of the first image is written as the pixel size; 1 otherwise
    apix = ctf_params[0, 0] if ctf_params is not None else 1

    # accumulated slice values and per-voxel counts
    V = torch.zeros((D, D, D))
    counts = torch.zeros((D, D, D))

    mask = lattice.get_circular_mask(D // 2)

    if args.ind:
        # NOTE(security): pickle.load can execute arbitrary code; only pass
        # index files from a trusted source.
        with open(args.ind, 'rb') as f:
            iterator = pickle.load(f)
    elif args.first:
        args.first = min(args.first, Nimg)
        iterator = range(args.first)
    else:
        iterator = range(Nimg)

    for ii in iterator:
        if ii % 100 == 0:
            log('image {}'.format(ii))
        r, t = posetracker.get_pose(ii)
        ff = data.get(ii)
        if tilt is not None:
            # the tilt dataset returns an (untilted, tilted) pair
            ff, ff_tilt = ff
        ff = torch.tensor(ff)
        ff = ff.view(-1)[mask]
        if ctf_params is not None:
            freqs = lattice.freqs2d / ctf_params[ii, 0]
            c = ctf.compute_ctf(freqs, *ctf_params[ii, 1:]).view(-1)[mask]
            # multiply by the sign of the CTF
            ff *= c.sign()
        if t is not None:
            ff = lattice.translate_ht(ff.view(1, -1), t.view(1, 1, 2),
                                      mask).view(-1)
        ff_coord = lattice.coords[mask] @ r
        add_slice(V, counts, ff_coord, ff, D)

        # tilt series
        if tilt is not None:
            ff_tilt = torch.tensor(ff_tilt)
            ff_tilt = ff_tilt.view(-1)[mask]
            # Guard on the CTF being available: the previous
            # `ff_tilt is not None` test was always true and raised a
            # NameError on `c` when no CTF parameters were supplied.
            if ctf_params is not None:
                ff_tilt *= c.sign()
            if t is not None:
                ff_tilt = lattice.translate_ht(ff_tilt.view(1, -1),
                                               t.view(1, 1, 2), mask).view(-1)
            ff_coord = lattice.coords[mask] @ tilt @ r
            add_slice(V, counts, ff_coord, ff_tilt, D)

    td = time.time() - t1
    # Average over the images actually processed (not the full dataset size),
    # guarding against an empty selection.
    n_processed = len(iterator)
    log('Backprojected {} images in {}s ({}s per image)'.format(
        n_processed, td, td / max(n_processed, 1)))
    # avoid dividing by zero in voxels that received no slice
    counts[counts == 0] = 1
    V /= counts
    # drop the last index along each axis before the inverse transform
    V = fft.ihtn_center(V[0:-1, 0:-1, 0:-1].cpu().numpy())
    mrc.write(args.o, V.astype('float32'), ax=apix, ay=apix, az=apix)