Example #1
0
def main(args):
    """Downsample a particle stack (or a volume) by Fourier cropping.

    Lazily loads images from ``args.mrcs``, keeps the central ``args.D`` x
    ``args.D`` region of each image's Hartley transform, and writes the
    inverse transform to ``args.o``. With ``--chunk``, the output is split
    into multiple .mrcs files plus a .txt manifest listing them.
    """
    mkbasedir(args.o)
    warnexists(args.o)
    # BUGFIX: the original check used endswith('mrc') (no dot), which
    # accepted any filename merely ending in the letters 'mrc'.
    assert args.o.endswith(('.mrc', '.mrcs')), \
        "Must specify output in .mrc(s) file format"

    old = dataset.load_particles(args.mrcs, lazy=True, datadir=args.datadir)
    oldD = old[0].get().shape[0]
    assert args.D <= oldD, f'New box size {args.D} cannot be larger than the original box size {oldD}'
    assert args.D % 2 == 0, 'New box size must be even'

    D = args.D
    # Index bounds of the central D-sized region in (Hartley) frequency space.
    start = int(oldD / 2 - D / 2)
    stop = int(oldD / 2 + D / 2)

    ### Downsample volume ###
    if args.is_vol:
        oldft = fft.htn_center(np.array([x.get() for x in old]))
        log(oldft.shape)
        newft = oldft[start:stop, start:stop, start:stop]
        log(newft.shape)
        new = fft.ihtn_center(newft).astype(np.float32)
        log(f'Saving {args.o}')
        mrc.write(args.o, new, is_vol=True)

    ### Downsample images ###
    elif args.chunk is None:
        new = []
        for i in range(len(old)):
            if i % 1000 == 0:
                log(f'Processing image {i} of {len(old)}')
            img = old[i]
            oldft = fft.ht2_center(img.get()).astype(np.float32)
            newft = oldft[start:stop, start:stop]
            new.append(fft.ihtn_center(newft).astype(np.float32))
        # Sanity check: the DC component must be unchanged by the crop.
        # (Only checks the last image processed.)
        assert oldft[int(oldD / 2), int(oldD / 2)] == newft[int(D / 2),
                                                            int(D / 2)]
        new = np.asarray(new)
        log(new.shape)
        log('Saving {}'.format(args.o))
        mrc.write(args.o, new, is_vol=False)

    ### Downsample images, saving chunks of N images ###
    else:
        chunk_names = []
        nchunks = math.ceil(len(old) / args.chunk)
        for i in range(nchunks):
            log('Processing chunk {}'.format(i))
            # e.g. particles.mrcs -> particles.0.mrcs for chunk 0
            out_mrcs = '.{}'.format(i).join(os.path.splitext(args.o))
            new = []
            for img in old[i * args.chunk:(i + 1) * args.chunk]:
                oldft = fft.ht2_center(img.get()).astype(np.float32)
                newft = oldft[start:stop, start:stop]
                new.append(fft.ihtn_center(newft).astype(np.float32))
            # Sanity check on the last image of the chunk: DC term preserved.
            assert oldft[int(oldD / 2), int(oldD / 2)] == newft[int(D / 2),
                                                                int(D / 2)]
            new = np.asarray(new)
            log(new.shape)
            # BUGFIX: dropped a redundant `.format(out_mrcs)` that was
            # applied to an already-evaluated f-string.
            log(f'Saving {out_mrcs}')
            mrc.write(out_mrcs, new, is_vol=False)
            chunk_names.append(os.path.basename(out_mrcs))
        # Write a text file with all chunks
        out_txt = '{}.txt'.format(os.path.splitext(args.o)[0])
        log(f'Saving {out_txt}')
        with open(out_txt, 'w') as f:
            f.write('\n'.join(chunk_names))
Example #2
0
def main(args):
    """Downsample a particle stack (or a volume) by Fourier cropping.

    Images are processed in batches of ``args.b`` using a multiprocessing
    pool; the central ``args.D`` x ``args.D`` region of each image's Hartley
    transform is kept and inverse-transformed. With ``--chunk``, output is
    split across several .mrcs files plus a .txt manifest.
    """
    mkbasedir(args.o)
    warnexists(args.o)
    # BUGFIX: the original check used endswith('mrc') (no dot), which
    # accepted any filename merely ending in the letters 'mrc'.
    assert args.o.endswith(('.mrc', '.mrcs')), \
        "Must specify output in .mrc(s) file format"

    lazy = not args.is_vol
    old = dataset.load_particles(args.mrcs,
                                 lazy=lazy,
                                 datadir=args.datadir,
                                 relion31=args.relion31)

    oldD = old[0].get().shape[0] if lazy else old.shape[-1]
    assert args.D <= oldD, f'New box size {args.D} cannot be larger than the original box size {oldD}'
    assert args.D % 2 == 0, 'New box size must be even'

    D = args.D
    # Index bounds of the central D-sized region in (Hartley) frequency space.
    start = int(oldD / 2 - D / 2)
    stop = int(oldD / 2 + D / 2)

    def _combine_imgs(imgs):
        """Coalesce lazily-loaded images that are contiguous on disk into
        single multi-image records so they can be read in one pass."""
        ret = []
        for img in imgs:
            img.shape = (1, *img.shape)  # (D,D) -> (1,D,D)
        cur = imgs[0]
        for img in imgs[1:]:
            # Contiguity test: same file, and this image starts exactly
            # where the current run ends (4 bytes/pixel, i.e. float32 data).
            # BUGFIX: np.product is deprecated (removed in NumPy 2.0);
            # use np.prod instead.
            if img.fname == cur.fname and img.offset == cur.offset + 4 * np.prod(
                    cur.shape):
                cur.shape = (cur.shape[0] + 1, *cur.shape[1:])
            else:
                ret.append(cur)
                cur = img
        ret.append(cur)
        return ret

    def downsample_images(imgs):
        """Fourier-crop a batch of images, parallelized over a process pool."""
        if lazy:
            imgs = _combine_imgs(imgs)
            imgs = np.concatenate([i.get() for i in imgs])
        with Pool(min(args.max_threads, mp.cpu_count())) as p:
            oldft = np.asarray(p.map(fft.ht2_center, imgs))
            newft = oldft[:, start:stop, start:stop]
            new = np.asarray(p.map(fft.iht2_center, newft))
        return new

    def downsample_in_batches(old, b):
        """Downsample all of `old` in batches of `b` images at a time."""
        new = np.empty((len(old), D, D), dtype=np.float32)
        for ii in range(math.ceil(len(old) / b)):
            log(f'Processing batch {ii}')
            new[ii * b:(ii + 1) * b, :, :] = downsample_images(
                old[ii * b:(ii + 1) * b])
        return new

    ### Downsample volume ###
    if args.is_vol:
        oldft = fft.htn_center(old)
        log(oldft.shape)
        newft = oldft[start:stop, start:stop, start:stop]
        log(newft.shape)
        new = fft.ihtn_center(newft).astype(np.float32)
        log(f'Saving {args.o}')
        mrc.write(args.o, new, is_vol=True)

    ### Downsample images ###
    elif args.chunk is None:
        new = downsample_in_batches(old, args.b)
        log(new.shape)
        log('Saving {}'.format(args.o))
        mrc.write(args.o, new.astype(np.float32), is_vol=False)

    ### Downsample images, saving chunks of N images ###
    else:
        nchunks = math.ceil(len(old) / args.chunk)
        # e.g. particles.mrcs -> particles.0.mrcs, particles.1.mrcs, ...
        out_mrcs = [
            '.{}'.format(i).join(os.path.splitext(args.o))
            for i in range(nchunks)
        ]
        chunk_names = [os.path.basename(x) for x in out_mrcs]
        for i in range(nchunks):
            log('Processing chunk {}'.format(i))
            chunk = old[i * args.chunk:(i + 1) * args.chunk]
            new = downsample_in_batches(chunk, args.b)
            log(new.shape)
            log(f'Saving {out_mrcs[i]}')
            mrc.write(out_mrcs[i], new, is_vol=False)
        # Write a text file with all chunks
        out_txt = '{}.txt'.format(os.path.splitext(args.o)[0])
        log(f'Saving {out_txt}')
        with open(out_txt, 'w') as f:
            f.write('\n'.join(chunk_names))
Example #3
0
def main(args):
    """Backproject 2D particle images into a 3D volume.

    For each image (and optional tilt mate), applies the CTF sign, the
    stored translation, and the stored rotation, then accumulates the image
    values into a voxel grid via `add_slice`. The count-normalized grid is
    inverse-Hartley-transformed and written to ``args.o``.
    """
    assert args.o.endswith('.mrc')

    t1 = time.time()
    log(args)
    if not os.path.exists(os.path.dirname(args.o)):
        os.makedirs(os.path.dirname(args.o))

    ## set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        log('WARNING: No GPUs detected')

    # load the particles
    if args.tilt is None:
        data = dataset.LazyMRCData(args.particles,
                                   norm=(0, 1),
                                   invert_data=args.invert_data,
                                   datadir=args.datadir)
        tilt = None
    else:
        data = dataset.TiltMRCData(args.particles,
                                   args.tilt,
                                   norm=(0, 1),
                                   invert_data=args.invert_data,
                                   datadir=args.datadir)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32))
    D = data.D
    Nimg = data.N

    lattice = Lattice(D, extent=D // 2)

    posetracker = PoseTracker.load(args.poses, Nimg, D, None, None)

    if args.ctf is not None:
        log('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        ctf_params = torch.tensor(ctf_params)
    else:
        ctf_params = None
    # ctf_params[:, 0] holds the pixel size; fall back to 1 if no CTF given.
    Apix = ctf_params[0, 0] if ctf_params is not None else 1

    V = torch.zeros((D, D, D))
    counts = torch.zeros((D, D, D))

    mask = lattice.get_circular_mask(D // 2)

    # Select which images to backproject: an explicit index file, the first
    # N images, or the full dataset.
    if args.ind:
        # BUGFIX: use a context manager so the index file handle is closed
        # (the original leaked the handle from pickle.load(open(...))).
        with open(args.ind, 'rb') as f:
            iterator = pickle.load(f)
    elif args.first:
        args.first = min(args.first, Nimg)
        iterator = range(args.first)
    else:
        iterator = range(Nimg)

    for ii in iterator:
        if ii % 100 == 0: log('image {}'.format(ii))
        r, t = posetracker.get_pose(ii)
        ff = data.get(ii)
        if tilt is not None:
            ff, ff_tilt = ff  # EW
        ff = torch.tensor(ff)
        ff = ff.view(-1)[mask]
        if ctf_params is not None:
            freqs = lattice.freqs2d / ctf_params[ii, 0]
            c = ctf.compute_ctf(freqs, *ctf_params[ii, 1:]).view(-1)[mask]
            # Phase-flip only (multiply by the CTF's sign).
            ff *= c.sign()
        if t is not None:
            ff = lattice.translate_ht(ff.view(1, -1), t.view(1, 1, 2),
                                      mask).view(-1)
        ff_coord = lattice.coords[mask] @ r
        add_slice(V, counts, ff_coord, ff, D)

        # tilt series
        if args.tilt is not None:
            ff_tilt = torch.tensor(ff_tilt)
            ff_tilt = ff_tilt.view(-1)[mask]
            if ctf_params is not None:
                # Reuses `c` computed for the untilted mate above.
                ff_tilt *= c.sign()
            if t is not None:
                ff_tilt = lattice.translate_ht(ff_tilt.view(1, -1),
                                               t.view(1, 1, 2), mask).view(-1)
            ff_coord = lattice.coords[mask] @ tilt @ r
            add_slice(V, counts, ff_coord, ff_tilt, D)

    td = time.time() - t1
    # BUGFIX: per-image time divides by the number of images actually
    # processed, not by the full dataset size Nimg (they differ when
    # --first or --ind selects a subset).
    log('Backprojected {} images in {}s ({}s per image)'.format(
        len(iterator), td, td / len(iterator)))
    # Avoid division by zero for voxels that received no contributions.
    counts[counts == 0] = 1
    V /= counts
    V = fft.ihtn_center(V[0:-1, 0:-1, 0:-1].cpu().numpy())
    mrc.write(args.o, V.astype('float32'), Apix=Apix)