Ejemplo n.º 1
0
    def calc_fsc(vol1_path, vol2_path):
        '''
        Compute the Fourier shell correlation (FSC) curve between two volumes stored on disk.

        vol1_path and vol2_path are read with mrc.parse_mrc; the maps are expected to be
        (already masked) 3D arrays sharing the same cubic box size.

        Returns (freq, fsc): the shell frequencies in 1/px and the correlation of each shell.
        '''
        # read both maps and Fourier-transform them with fft.fftn_center
        map_a, _ = mrc.parse_mrc(vol1_path)
        map_b, _ = mrc.parse_mrc(vol2_path)
        ft_a = fft.fftn_center(map_a)
        ft_b = fft.fftn_center(map_b)

        # distance of every Fourier voxel from the middle of the box
        box = map_a.shape[0]
        axis = np.arange(-box // 2, box // 2)
        g0, g1, g2 = np.meshgrid(axis, axis, axis, indexing='ij')
        radius = np.sqrt(g0 ** 2 + g1 ** 2 + g2 ** 2)

        # unit-width shells: label k+1 holds radii in [k, k+1); the last label also
        # collects everything beyond the final edge
        shell_edges = np.arange(0, box // 2, 1)
        shell_of_voxel = np.searchsorted(shell_edges, radius, side='right')
        shell_ids = shell_edges + 1

        def shell_sum(values):
            # sum `values` separately over each labeled shell
            return ndimage.sum(values, labels=shell_of_voxel, index=shell_ids)

        cross = shell_sum(np.real(ft_a * np.conjugate(ft_b)))
        power_a = shell_sum(np.abs(ft_a) ** 2)
        power_b = shell_sum(np.abs(ft_b) ** 2)
        fsc = cross / np.sqrt(power_a * power_b)

        # shell edges divided by the box size give spatial frequency in 1/px
        return shell_edges / box, fsc
Ejemplo n.º 2
0
def main(args):
    """Compute the Fourier shell correlation between two volumes and report it.

    Reads ``args.vol1`` and ``args.vol2`` with mrc.parse_mrc, optionally
    multiplies both by the volume in ``args.mask``, and correlates them shell
    by shell in Fourier space. The (frequency, FSC) table is saved to
    ``args.o`` when given, otherwise logged. The resolutions
    (``args.Apix`` / frequency) of the shells whose FSC is below 0.5 and 0.143
    are logged when such shells exist, and the curve is shown if ``args.plot``.
    """
    vol1, _ = mrc.parse_mrc(args.vol1)
    vol2, _ = mrc.parse_mrc(args.vol2)

    if args.mask:
        mask = mrc.parse_mrc(args.mask)[0]
        vol1 *= mask
        vol2 *= mask

    # radial distance of every voxel from the box center
    D = vol1.shape[0]
    x = np.arange(-D // 2, D // 2)
    x2, x1, x0 = np.meshgrid(x, x, x, indexing='ij')
    coords = np.stack((x0, x1, x2), -1)
    r = (coords**2).sum(-1)**.5

    assert r[D // 2, D // 2, D // 2] == 0.0

    vol1 = fft.fftn_center(vol1)
    vol2 = fft.fftn_center(vol2)

    # each shell is the set of voxels with r < i that were not in the previous ball
    prev_inside = np.zeros((D, D, D), dtype=bool)
    fsc = [1.0]
    for i in range(1, D // 2):
        inside = r < i
        shell = np.where(inside & np.logical_not(prev_inside))
        v1 = vol1[shell]
        v2 = vol2[shell]
        p = np.vdot(v1, v2) / (np.vdot(v1, v1) * np.vdot(v2, v2))**.5
        fsc.append(p.real)
        prev_inside = inside
    fsc = np.asarray(fsc)
    x = np.arange(D // 2) / D

    res = np.stack((x, fsc), 1)
    if args.o:
        np.savetxt(args.o, res)
    else:
        log(res)

    for cutoff in (0.5, 0.143):
        # np.where returns a tuple, which is always truthy; test the index
        # array itself so nothing is logged when no shell is below the cutoff
        below = np.where(fsc < cutoff)[0]
        if below.size:
            log('{}: {}'.format(cutoff, 1 / x[below] * args.Apix))

    if args.plot:
        plt.plot(x, fsc)
        plt.ylim((0, 1))
        plt.show()
Ejemplo n.º 3
0
def main(args):
    """Reverse an .mrc volume along its first axis and write it to ``args.o``.

    Both ``args.input`` and ``args.o`` must end in '.mrc'; the input header is
    passed through to the output file.
    """
    assert args.input.endswith('.mrc'), "Input volume must be .mrc file"
    assert args.o.endswith('.mrc'), "Output volume must be .mrc file"
    volume, header = mrc.parse_mrc(args.input)
    flipped = volume[::-1]
    mrc.write(args.o, flipped, header=header)
    log('Wrote {}'.format(args.o))
Ejemplo n.º 4
0
def mask_volume(volpath, outpath, Apix, thresh=None, dilate=3, dist=10):
    '''
    Apply a soft-edged mask derived from the volume's own density and save the masked volume.

    The density is binarized at `thresh`, grown by `dilate` binary-dilation iterations, and
    given a cosine edge that falls to (numerically) zero `dist` away from the dilated region.

    Inputs
        volpath: path of the volume, loaded with mrc.parse_mrc
        outpath: path the masked volume is written to
        Apix: forwarded to mrc.write as its Apix keyword
        thresh: intensity cutoff for the binarization; when None, half of the
            99.99th percentile of the volume is used
        dilate: number of binary_dilation iterations
        dist: width of the cosine edge, in units of distance_transform_edt's output

    Outputs
        the masked volume, written to outpath as float32
    '''
    vol = mrc.parse_mrc(volpath)[0]
    if thresh is None:
        thresh = np.percentile(vol, 99.99) / 2

    core = binary_dilation((vol >= thresh).astype(bool), iterations=dilate)

    # distance of each voxel from the dilated region, capped at dist
    falloff = distance_transform_edt(np.logical_not(core.astype(bool)))
    falloff = np.minimum(falloff, dist)
    soft_mask = np.cos(np.pi * falloff / dist / 2)

    # the soft mask must stay within [0, 1]
    assert np.all(soft_mask >= 0)
    assert np.all(soft_mask <= 1)

    # the mask is applied in place and only the masked volume is written out
    vol *= soft_mask
    mrc.write(outpath, vol.astype(np.float32), Apix=Apix)
Ejemplo n.º 5
0
def main(args):
    """Save the first image of an .mrc stack as a figure with two scale bars.

    A red bar of ``args.scale1`` Angstrom (default 100) is drawn bottom-left
    and a cyan bar of ``args.scale2`` Angstrom (default 400) bottom-right,
    both converted to pixels with ``args.pixel_size``. ``args.gblur``, when
    set, is the sigma of a Gaussian blur applied before display. The figure is
    written next to the input as .tiff if ``args.tiff`` else .png.
    """
    stack, _ = mrc.parse_mrc(args.input)
    image = stack[0]
    # imshow draws rows along the vertical axis and columns along the
    # horizontal one, so the width is shape[1] and the height is shape[0]
    y_dim = image.shape[0]
    x_dim = image.shape[1]
    print('image dimensions: ' + str(x_dim) + 'x' + str(y_dim) + ' pixels')

    ang_px = float(args.pixel_size)
    scale_a = float(args.scale1) if args.scale1 else float(100)
    scale_b = float(args.scale2) if args.scale2 else float(400)
    line_a = scale_a / ang_px  # bar lengths in pixels
    line_b = scale_b / ang_px
    offset = x_dim / 20

    bar_y = y_dim - offset
    tick_y = np.array([bar_y - y_dim / 100, bar_y + y_dim / 100])
    label_y = bar_y + y_dim / 50

    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
    if args.gblur:
        image = ndimage.gaussian_filter(image, float(args.gblur))
    axes.imshow(image, cmap='Greys_r')

    # left (red) bar with a center tick and its label
    mid_a = (offset * 2 + line_a) / 2.0
    axes.plot(np.array([offset, offset + line_a]),
              np.array([bar_y, bar_y]),
              color='red')
    axes.plot(np.array([mid_a, mid_a]), tick_y, color='red')
    axes.text(mid_a,
              label_y,
              str(scale_a) + ' A',
              color='r',
              ha='center',
              va='center')

    # right (cyan) bar with a center tick and its label
    mid_b = (2 * x_dim - 2 * offset - line_b) / 2.0
    axes.plot(np.array([x_dim - offset, x_dim - offset - line_b]),
              np.array([bar_y, bar_y]),
              color='cyan')
    axes.plot(np.array([mid_b, mid_b]), tick_y, color='cyan')
    axes.text(mid_b,
              label_y,
              str(scale_b) + ' A',
              color='cyan',
              ha='center',
              va='center')

    axes.axis('off')
    extension = '.tiff' if args.tiff else '.png'
    plt.savefig(args.input.split('.mrc')[0] + extension)
Ejemplo n.º 6
0
def main(args):
    """Plot the first images (at most nine) of a lazily loaded particle stack.

    The images are handed to analysis.plot_projections; the figure is saved to
    ``args.o`` when given and shown interactively otherwise.
    """
    stack, _ = mrc.parse_mrc(args.input,lazy=True)
    print('{} {}x{} images'.format(len(stack), *stack[0].get().shape))
    # stacks shorter than nine images must not be indexed past their end
    n_show = min(9, len(stack))
    stack = [stack[x].get() for x in range(n_show)]
    analysis.plot_projections(stack)
    if args.o:
        plt.savefig(args.o)
    else:
        plt.show()
Ejemplo n.º 7
0
def calculate_CCs(outdir, epochs, labels, chimerax_colors, LOG):
    '''
    Computes the map-map correlation between the masked volumes of consecutive epochs, for each class in labels

    Inputs:
        outdir: base directory; volumes are read from outdir/vols.{epoch}/vol_{class:03d}.masked.mrc
        epochs: array of epochs whose consecutive pairs are compared
        labels: unique identifier for each class; its length sets the number of classes and it is used as plot legend
        chimerax_colors: per-class colors (scaled by 0.75 for plotting), approximating the ChimeraX palette
        LOG: passed to flog together with the final status message

    Outputs:
        outdir/cc_masked.pkl holding the (n_classes, n_epochs - 1) array of correlations
        outdir/plots/05_decoder_CC.png plotting each class' correlation across training epochs
    '''
    def calc_cc(vol1, vol2):
        '''
        Zero-mean correlation coefficient of two same-sized arrays, following eq 2 in https://journals.iucr.org/d/issues/2018/09/00/kw5139/index.html
        '''
        centered1 = vol1 - np.mean(vol1)
        centered2 = vol2 - np.mean(vol2)
        inv_norm1 = np.sum(centered1 ** 2) ** -0.5
        inv_norm2 = np.sum(centered2 ** 2) ** -0.5
        return inv_norm1 * inv_norm2 * np.sum(centered1 * centered2)

    n_classes = len(labels)
    n_pairs = len(epochs) - 1
    cc_masked = np.zeros((n_classes, n_pairs))

    for pair in range(n_pairs):
        earlier, later = epochs[pair], epochs[pair + 1]
        for cluster in range(n_classes):
            vol_earlier, _ = mrc.parse_mrc(f'{outdir}/vols.{earlier}/vol_{cluster:03d}.masked.mrc')
            vol_later, _ = mrc.parse_mrc(f'{outdir}/vols.{later}/vol_{cluster:03d}.masked.mrc')
            cc_masked[cluster, pair] = calc_cc(vol_earlier, vol_later)

    utils.save_pkl(cc_masked, f'{outdir}/cc_masked.pkl')

    # one curve per class, drawn against the later epoch of each pair
    fig, ax = plt.subplots(1, 1)
    ax.set_xlabel('epoch')
    ax.set_ylabel('masked CC')
    for cluster, curve in enumerate(cc_masked):
        ax.plot(epochs[1:], curve, c=chimerax_colors[cluster] * 0.75, linewidth=2.5)
    ax.legend(labels, ncol=3, fontsize='x-small')

    plt.savefig(f'{outdir}/plots/05_decoder_CC.png', dpi=300, format='png', transparent=True, bbox_inches='tight')
    flog(f'Saved map-map correlation plot to {outdir}/plots/05_decoder_CC.png', LOG)
Ejemplo n.º 8
0
def main(args):
    """Rewrite an .mrc volume with an updated pixel size and optional edits.

    The header's pixel size is set via ``update_apix(args.apix)``. With
    ``args.invert`` the densities are multiplied by -1 in place, and with
    ``args.flip`` the volume is reversed along its first axis before writing
    to ``args.o``.
    """
    assert args.input.endswith('.mrc'), "Input volume must be .mrc file"
    assert args.o.endswith('.mrc'), "Output volume must be .mrc file"
    vol, header = mrc.parse_mrc(args.input)
    header.update_apix(args.apix)
    if args.invert:
        vol *= -1
    vol_out = vol[::-1] if args.flip else vol
    mrc.write(args.o, vol_out, header=header)
    log('Wrote {}'.format(args.o))
Ejemplo n.º 9
0
def main(args):
    """Write a .star file for a single .mrcs particle stack.

    Image names are 1-based ``index@stack`` entries (the ``_rlnImageName``
    convention noted in the code). CTF columns are taken from columns 2..8 of
    the ``args.ctf`` pickle. When ``args.poses`` is given, the rotations are
    converted with utils.R_to_relion_scipy and the translations are scaled
    from box fractions to pixels. ``args.ind`` optionally selects a subset of
    particles; ``args.full_path`` keeps the stack path as given instead of
    its basename.
    """
    assert args.o.endswith('.star')
    assert args.particles.endswith(
        '.mrcs'
    ), "Only a single particle stack as an .mrcs is currently supported"
    particles = mrc.parse_mrc(args.particles, lazy=True)[0]
    n_particles = len(particles)
    ctf = utils.load_pkl(args.ctf)
    assert ctf.shape[1] == 9, "Incorrect CTF pkl format"
    assert n_particles == len(
        ctf
    ), f"{n_particles} != {len(ctf)}, Number of particles != number of CTF paraameters"
    if args.poses:
        poses = utils.load_pkl(args.poses)
        # poses is a (rotations, translations) pair, so the count that is
        # checked -- and reported on failure -- is the number of rotations
        assert n_particles == len(
            poses[0]
        ), f"{n_particles} != {len(poses[0])}, Number of particles != number of poses"
    log('{} particles'.format(n_particles))

    if args.ind:
        ind = utils.load_pkl(args.ind)
        log(f'Filtering to {len(ind)} particles')
        ctf = ctf[ind]
        if args.poses:
            poses = (poses[0][ind], poses[1][ind])
    else:
        ind = np.arange(n_particles)

    # _rlnImageName: image numbers are 1-based; build them on a copy so the
    # 0-based index array is left untouched
    image_name = os.path.basename(
        args.particles) if not args.full_path else args.particles
    names = [f'{i}@{image_name}' for i in np.asarray(ind) + 1]

    # drop the first two CTF columns; the remaining seven are written out
    ctf = ctf[:, 2:]

    # convert poses
    if args.poses:
        eulers = utils.R_to_relion_scipy(poses[0])
        D = particles[0].get().shape[0]
        trans = poses[1] * D  # convert from fraction to pixels

    data = {HEADERS[0]: names}
    for i in range(7):
        data[HEADERS[i + 1]] = ctf[:, i]
    if args.poses:
        for i in range(3):
            data[POSE_HDRS[i]] = eulers[:, i]
        for i in range(2):
            data[POSE_HDRS[3 + i]] = trans[:, i]
    df = pd.DataFrame(data=data)

    headers = HEADERS + POSE_HDRS if args.poses else HEADERS
    s = starfile.Starfile(headers, df)
    s.write(args.o)
Ejemplo n.º 10
0
def make_mask(outdir, K, dilate, thresh, in_mrc=None):
    """Build a boolean mask, save it as outdir/mask.mrc and plot its central slices.

    Without ``in_mrc`` the mask is the union, over the K volumes in
    outdir/kmeans{K}, of each volume thresholded at ``thresh`` and grown by
    ``dilate`` binary-dilation iterations. When ``thresh`` is None it defaults
    to the mean over the volumes of half their 99.99th percentile. With
    ``in_mrc`` that file is loaded and cast to bool instead.

    The central slice along each axis is saved to outdir/mask_slices.png.
    """
    if in_mrc is None:

        def load_vol(i):
            return mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0]

        if thresh is None:
            thresh = np.mean(
                [np.percentile(load_vol(i), 99.99) / 2 for i in range(K)])
        log(f'Threshold: {thresh}')
        log(f'Dilating mask by: {dilate}')

        def binary_mask(vol):
            x = (vol >= thresh).astype(bool)
            return binary_dilation(x, iterations=dilate)

        # combine all masks by taking their union
        mask = binary_mask(load_vol(0))
        for i in range(1, K):
            mask |= binary_mask(load_vol(i))
    else:
        # Load provided mrc and convert to a boolean mask
        mask, _ = mrc.parse_mrc(in_mrc)
        mask = mask.astype(bool)

    # save mask
    out_mrc = f'{outdir}/mask.mrc'
    log(f'Saving {out_mrc}')
    mrc.write(out_mrc, mask.astype(np.float32))

    # view slices; the box size comes from the mask itself because no volume
    # is loaded when in_mrc is provided
    out_png = f'{outdir}/mask_slices.png'
    D = mask.shape[0]
    fig, ax = plt.subplots(1, 3, figsize=(10, 8))
    ax[0].imshow(mask[D // 2, :, :])
    ax[1].imshow(mask[:, D // 2, :])
    ax[2].imshow(mask[:, :, D // 2])
    plt.savefig(out_png)
Ejemplo n.º 11
0
def main(args):
    """Search rotations and translations relating args.vol to args.ref and save the transformed volume.

    For args.niter iterations, every (rotation, translation) pair on the
    current grids is scored with VolumeAligner.compute_err. The args.keep_r
    rotations and args.keep_t translations with the lowest error are passed to
    subdivide_r / subdivide_t to build the grids of the next iteration.
    Finally a pose is applied to the volume with VolumeAligner.real_tform and
    the result is written to args.o as float32.

    Gradients are disabled, and CUDA is used when available.
    """
    log(args)
    torch.set_grad_enabled(False)
    use_cuda = torch.cuda.is_available()
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    t1 = time.time()
    ref, _ = mrc.parse_mrc(args.ref)
    log('Loaded {} volume'.format(ref.shape))
    vol, _ = mrc.parse_mrc(args.vol)
    log('Loaded {} volume'.format(vol.shape))

    # aligner used during the search, constructed with maxD=args.max_D
    projector = VolumeAligner(vol,
                              vol_ref=ref,
                              maxD=args.max_D,
                              flip=args.flip)
    if use_cuda:
        projector.use_cuda()

    # initial rotation grid; each quaternion gets a two-part id derived from
    # its flat index (quotient and remainder by 6 * 2**r_resol)
    r_resol = args.r_resol
    quats = so3_grid.grid_SO3(r_resol)
    q_id = np.arange(len(quats))
    q_id = np.stack([q_id // (6 * 2**r_resol), q_id % (6 * 2**r_resol)], -1)
    rots = GridPose(quats, q_id)

    # initial translation grid; the extent defaults to 1/16 of the box size
    t_resol = 0
    T_EXTENT = vol.shape[0] / 16 if args.t_extent is None else args.t_extent
    T_NGRID = args.t_grid
    trans = shift_grid3.base_shift_grid(T_EXTENT, T_NGRID)
    t_id = np.stack(shift_grid3.get_base_id(np.arange(len(trans)), T_NGRID),
                    -1)
    trans = GridPose(trans, t_id)

    max_keep_r = args.keep_r
    max_keep_t = args.keep_t
    #rot_tracker = MinPoseTracker(max_keep_r, 4, 2)
    #tr_tracker = MinPoseTracker(max_keep_t, 3, 3)
    for it in range(args.niter):
        log('Iteration {}'.format(it))
        log('Generating {} rotations'.format(len(rots)))
        log('Generating {} translations'.format(len(trans)))
        # error of every (rotation, translation) pair on the current grids
        pose_err = np.empty((len(rots), len(trans)), dtype=np.float32)
        #rot_tracker.clear()
        #tr_tracker.clear()
        r_iterator = data.DataLoader(rots, batch_size=args.rb, shuffle=False)
        t_iterator = data.DataLoader(trans, batch_size=args.tb, shuffle=False)
        # r_it / t_it are the row / column offsets of the current batches in pose_err
        r_it = 0
        for rot, r_id in r_iterator:
            if use_cuda: rot = rot.cuda()
            vr, vi = projector.rotate(rot)
            t_it = 0
            for tr, t_id in t_iterator:
                if use_cuda: tr = tr.cuda()
                vtr, vti = projector.translate(
                    vr, vi, tr.expand(rot.size(0), *tr.shape))
                # todo: check volume
                err = projector.compute_err(vtr, vti)  # R x T
                pose_err[r_it:r_it + len(rot),
                         t_it:t_it + len(tr)] = err.cpu().numpy()
                #r_err = err.min(1)[0]
                #min_r_err, min_r_i = r_err.sort()
                #rot_tracker.add(min_r_err[:max_keep_r], rot[min_r_i][:max_keep_r], r_id[min_r_i][:max_keep_r])
                #t_err= err.min(0)[0]
                #min_t_err, min_t_i = t_err.sort()
                #tr_tracker.add(min_t_err[:max_keep_t], tr[min_t_i][:max_keep_t], t_id[min_t_i][:max_keep_t])
                t_it += len(tr)
            r_it += len(rot)

        # best error per rotation (over translations) and per translation
        # (over rotations); keep the indices of the lowest-error entries
        r_err = pose_err.min(1)
        r_err_argmin = r_err.argsort()[:max_keep_r]
        t_err = pose_err.min(0)
        t_err_argmin = t_err.argsort()[:max_keep_t]

        # lstart
        #r = rots.pose[r_err_argmin[0]]
        #t = trans.pose[t_err_argmin[0]]
        #log('Best rot: {}'.format(r))
        #log('Best trans: {}'.format(t))
        #vr, vi = projector_full.rotate(torch.tensor(r).unsqueeze(0))
        #vr, vi = projector_full.translate(vr, vi, torch.tensor(t).view(1,1,3))
        #err = projector_full.compute_err(vr,vi)

        #w = np.where(r_err[r_err_argmin] > err.item())[0]
        # replace the rotation grid by the subdivision of the kept rotations
        rots, rots_id = subdivide_r(rots.pose[r_err_argmin],
                                    rots.pose_id[r_err_argmin], r_resol)
        rots = GridPose(rots, rots_id)

        # recomputes the same t_err / t_err_argmin as above (pose_err is unchanged)
        t_err = pose_err.min(0)
        t_err_argmin = t_err.argsort()[:max_keep_t]
        # replace the translation grid by the subdivision of the kept translations
        trans, trans_id = subdivide_t(trans.pose_id[t_err_argmin], t_resol,
                                      T_EXTENT, T_NGRID)
        trans = GridPose(trans, trans_id)
        r_resol += 1
        t_resol += 1
        vlog(r_err[r_err_argmin])
        vlog(t_err[t_err_argmin])
        #log(rot_tracker.min_errs)
        #log(tr_tracker.min_errs)
    # NOTE(review): rots / trans are now the grids produced by the last
    # subdivision, while r_err_argmin / t_err_argmin index the grids that were
    # scored in the last iteration -- confirm this selects the intended pose.
    # Also, with args.niter == 0 these index arrays are never assigned.
    r = rots.pose[r_err_argmin[0]]
    t = trans.pose[t_err_argmin[0]] * vol.shape[0] / args.max_D
    log('Best rot: {}'.format(r))
    log('Best trans: {}'.format(t))
    # rescale the translation by 2 / box size before applying it
    # (presumably real_tform expects normalized coordinates -- TODO confirm)
    t *= 2 / vol.shape[0]
    # second aligner, built with maxD equal to the full box size, applies the final transform
    projector = VolumeAligner(vol,
                              vol_ref=ref,
                              maxD=vol.shape[0],
                              flip=args.flip)
    if use_cuda: projector.use_cuda()
    vr = projector.real_tform(
        torch.tensor(r).unsqueeze(0),
        torch.tensor(t).view(1, 1, 3))
    v = vr.squeeze().cpu().numpy()
    log('Saving {}'.format(args.o))
    mrc.write(args.o, v.astype(np.float32))

    td = time.time() - t1
    log('Finished in {}s'.format(td))
Ejemplo n.º 12
0
# coding: utf-8
import sys, os
from cryodrgn import mrc
import numpy as np
# Reference pixels: load the stack lazily and materialize every image.
lazy_stack, _ = mrc.parse_mrc('data/toy_projections.mrcs', lazy=True)
data1 = np.asarray([img.get() for img in lazy_stack])

# Eager loading of the same file must give identical pixels.
data2, _ = mrc.parse_mrc('data/toy_projections.mrcs', lazy=False)
assert (data1 == data2).all()
print('ok')

# The .star and .txt inputs must load to the same pixels through dataset.load_particles.
from cryodrgn import dataset
for particles_file in ('data/toy_projections.star',
                       'data/toy_projections.txt'):
    data2 = dataset.load_particles(particles_file)
    assert (data1 == data2).all()
    print('ok')

print('all ok')
Ejemplo n.º 13
0
def analyze_volumes(outdir,
                    K,
                    dim,
                    M,
                    linkage,
                    vol_ind=None,
                    plot_dim=5,
                    particle_ind_orig=None):
    """Run PCA and agglomerative clustering on the K volumes in outdir/kmeans{K}.

    Voxels inside outdir/mask.mrc are extracted from every volume (negative
    values are set to 0), projected with a `dim`-component PCA and clustered
    into M states. Outputs written under outdir:
      * vol_pca_{K}.pkl, vol_pca_obj.pkl and scatter plots of consecutive PCs
      * vol_pcs/pc{i}/{j}.mrc: mean volume plus PC i scaled by 10 evenly
        spaced values between the minimum and maximum projection
      * clustering_L2_{linkage}_{M}/: state labels, particle-weighted
        mean/std volume per state, symlinks to member volumes, particle
        indices per state, bar plots, PCA plots and UMAP plots

    Args:
        outdir: analysis directory holding kmeans{K}/, mask.mrc and umap.pkl
        K: number of k-means volumes
        dim: number of PCA components
        M: number of clusters (states)
        linkage: linkage passed to AgglomerativeClustering
        vol_ind: optional index selecting a subset of the K volumes
        plot_dim: number of PCs traversed and plotted; pc[:, i] is read for
            i < plot_dim, so it presumably must not exceed dim
        particle_ind_orig: optional array; when given, the saved particle
            indices are looked up in it before saving
    """
    cmap = choose_cmap(M)

    # load mean volume, compute it if it does not exist
    if not os.path.exists(f'{outdir}/kmeans{K}/vol_mean.mrc'):
        volm = np.array([
            mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0]
            for i in range(K)
        ]).mean(axis=0)
        mrc.write(f'{outdir}/kmeans{K}/vol_mean.mrc', volm)
    else:
        volm = mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_mean.mrc')[0]

    # load mask
    mask = mrc.parse_mrc(f'{outdir}/mask.mrc')[0].astype(bool)
    log(f'{mask.sum()} voxels in mask')

    # load volumes, keeping only the voxels inside the mask (one row per volume)
    vols = np.array([
        mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0][mask]
        for i in range(K)
    ])
    # clip negative densities to zero
    vols[vols < 0] = 0

    # load umap
    umap = utils.load_pkl(f'{outdir}/umap.pkl')
    # indices into the umap array, one per k-means volume
    ind = np.loadtxt(f'{outdir}/kmeans{K}/centers_ind.txt').astype(int)

    if vol_ind is not None:
        log(f'Filtering to {len(vol_ind)} volumes')
        vols = vols[vol_ind]
        ind = ind[vol_ind]

    # compute PCA
    pca = PCA(dim)
    pca.fit(vols)
    pc = pca.transform(vols)
    utils.save_pkl(pc, f'{outdir}/vol_pca_{K}.pkl')
    utils.save_pkl(pca, f'{outdir}/vol_pca_obj.pkl')
    log('Explained variance ratio:')
    log(pca.explained_variance_ratio_)

    # save rxn coordinates: for each PC, write 10 volumes that add the
    # component, scaled across the observed range, to the mean volume
    for i in range(plot_dim):
        subdir = f'{outdir}/vol_pcs/pc{i+1}'
        if not os.path.exists(subdir):
            os.makedirs(subdir)
        min_, max_ = pc[:, i].min(), pc[:, i].max()
        log((min_, max_))
        for j, val in enumerate(np.linspace(min_, max_, 10, endpoint=True)):
            v = volm.copy()
            v[mask] += pca.components_[i] * val
            mrc.write(f'{subdir}/{j}.mrc', v)

    # which plots to show???
    def plot(i, j):
        # scatter of PC i vs PC j, saved in outdir
        # NOTE(review): ':03f' is zero-padded width 3 with default precision;
        # '.3f' may have been intended -- confirm
        plt.figure()
        plt.scatter(pc[:, i], pc[:, j])
        plt.xlabel(
            f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:03f})')
        plt.ylabel(
            f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:03f})')
        plt.savefig(f'{outdir}/vol_pca_{K}_{i+1}_{j+1}.png')

    for i in range(plot_dim - 1):
        plot(i, i + 1)

    # clustering
    subdir = f'{outdir}/clustering_L2_{linkage}_{M}'
    if not os.path.exists(subdir):
        os.makedirs(subdir)
    cluster = AgglomerativeClustering(n_clusters=M,
                                      affinity='euclidean',
                                      linkage=linkage)
    labels = cluster.fit_predict(vols)
    utils.save_pkl(labels, f'{subdir}/state_labels.pkl')

    # number of particles assigned to each k-means label
    kmeans_labels = utils.load_pkl(f'{outdir}/kmeans{K}/labels.pkl')
    kmeans_counts = Counter(kmeans_labels)
    for i in range(M):
        vol_i = np.where(labels == i)[0]
        log(f'State {i}: {len(vol_i)} volumes')
        if vol_ind is not None:
            # map positions in the filtered set back to original volume numbers
            vol_i = np.arange(K)[vol_ind][vol_i]
        # the comprehension's own `i` does not affect the loop variable above
        vol_i_all = np.array([
            mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0]
            for i in vol_i
        ])
        # mean and standard deviation over member volumes, weighted by particle count
        nparticles = np.array([kmeans_counts[i] for i in vol_i])
        vol_i_mean = np.average(vol_i_all, axis=0, weights=nparticles)
        vol_i_std = np.average((vol_i_all - vol_i_mean)**2,
                               axis=0,
                               weights=nparticles)**.5
        mrc.write(f'{subdir}/state_{i}_mean.mrc',
                  vol_i_mean.astype(np.float32))
        mrc.write(f'{subdir}/state_{i}_std.mrc', vol_i_std.astype(np.float32))
        if not os.path.exists(f'{subdir}/state_{i}'):
            os.makedirs(f'{subdir}/state_{i}')
        # NOTE(review): os.symlink raises if the link already exists, e.g.
        # when rerunning into the same output directory
        for v in vol_i:
            os.symlink(f'{outdir}/kmeans{K}/vol_{v:03d}.mrc',
                       f'{subdir}/state_{i}/vol_{v:03d}.mrc')
        particle_ind = analysis.get_ind_for_cluster(kmeans_labels, vol_i)
        log(f'State {i}: {len(particle_ind)} particles')
        if particle_ind_orig is not None:
            utils.save_pkl(particle_ind_orig[particle_ind],
                           f'{subdir}/state_{i}_particle_ind.pkl')
        else:
            utils.save_pkl(particle_ind,
                           f'{subdir}/state_{i}_particle_ind.pkl')

    # plot clustering results
    def hack_barplot(counts_):
        # bar plot of one count per state, using cmap as palette when M <= 20
        if M <= 20:  # HACK TO GET COLORS
            with sns.color_palette(cmap):
                g = sns.barplot(np.arange(M), counts_)
        else:  # default is husl
            g = sns.barplot(np.arange(M), counts_)
        return g

    # number of volumes per state
    plt.figure()
    counts = Counter(labels)
    g = hack_barplot([counts[i] for i in range(M)])
    for i in range(M):
        g.text(i - .1, counts[i] + 2, counts[i])
    plt.xlabel('State')
    plt.ylabel('Count')
    plt.savefig(f'{subdir}/state_volume_counts.png')

    # number of particles per state
    plt.figure()
    particle_counts = [
        np.sum([kmeans_counts[ii] for ii in np.where(labels == i)[0]])
        for i in range(M)
    ]
    g = hack_barplot(particle_counts)
    for i in range(M):
        g.text(i - .1, particle_counts[i] + 2, particle_counts[i])
    plt.xlabel('State')
    plt.ylabel('Count')
    plt.savefig(f'{subdir}/state_particle_counts.png')

    def plot_w_labels(i, j):
        # PC scatter colored by state label, saved in the clustering subdir
        plt.figure()
        plt.scatter(pc[:, i], pc[:, j], c=labels, cmap=cmap)
        plt.xlabel(
            f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:03f})')
        plt.ylabel(
            f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:03f})')
        plt.savefig(f'{subdir}/vol_pca_{K}_{i+1}_{j+1}.png')

    for i in range(plot_dim - 1):
        plot_w_labels(i, i + 1)

    def plot_w_labels_annotated(i, j):
        # same scatter, with each point annotated by its volume number
        fig, ax = plt.subplots(figsize=(16, 16))
        plt.scatter(pc[:, i], pc[:, j], c=labels, cmap=cmap)
        annots = np.arange(K)
        if vol_ind is not None:
            annots = annots[vol_ind]
        for ii, k in enumerate(annots):
            ax.annotate(str(k), pc[ii, [i, j]] + np.array([.1, .1]))
        plt.xlabel(
            f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:03f})')
        plt.ylabel(
            f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:03f})')
        plt.savefig(f'{subdir}/vol_pca_{K}_annotated_{i+1}_{j+1}.png')

    for i in range(plot_dim - 1):
        plot_w_labels_annotated(i, i + 1)

    # plot clusters on UMAP: all points in light grey, selected volumes colored by state
    umap_i = umap[ind]
    fig, ax = plt.subplots(figsize=(8, 8))
    plt.scatter(umap[:, 0],
                umap[:, 1],
                alpha=.1,
                s=1,
                rasterized=True,
                color='lightgrey')
    colors = get_colors_for_cmap(cmap, M)
    for i in range(M):
        c = umap_i[np.where(labels == i)]
        plt.scatter(c[:, 0], c[:, 1], label=i, color=colors[i])
    plt.legend()
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    plt.savefig(f'{subdir}/umap.png')

    # same UMAP view, with each selected volume annotated by its number
    fig, ax = plt.subplots(figsize=(16, 16))
    plt.scatter(umap[:, 0],
                umap[:, 1],
                alpha=.1,
                s=1,
                rasterized=True,
                color='lightgrey')
    plt.scatter(umap_i[:, 0], umap_i[:, 1], c=labels, cmap=cmap)
    annots = np.arange(K)
    if vol_ind is not None:
        annots = annots[vol_ind]
    for i, k in enumerate(annots):
        ax.annotate(str(k), umap_i[i] + np.array([.1, .1]))
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    plt.savefig(f'{subdir}/umap_annotated.png')
Ejemplo n.º 14
0
def main(args):
    """Project a volume along a set of rotations and save the images and poses.

    Rotations come from GridRot(args.grid) when ``args.grid`` is set and from
    RandomRot(args.N) otherwise. The images are optionally shifted by uniform
    random translations within +/- ``args.t_extent`` pixels. The stack is
    written to ``args.o`` and the poses are pickled to ``args.out_pose``
    (``(rots, trans)`` with shifts, ``rots`` alone without). The first nine
    images are plotted to ``args.out_png`` when given.

    ``args.seed`` seeds numpy and torch; ``args.tilt`` (degrees) is converted
    to a rotation matrix about the x axis and passed to the Projector.
    """
    for out in (args.o, args.out_png, args.out_pose):
        if not out: continue
        mkbasedir(out)
        warnexists(out)

    if args.t_extent == 0.:
        log('Not shifting images')
    else:
        assert args.t_extent > 0

    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    use_cuda = torch.cuda.is_available()
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    t1 = time.time()
    vol, _ = mrc.parse_mrc(args.mrc)
    log('Loaded {} volume'.format(vol.shape))

    if args.tilt:
        theta = args.tilt*np.pi/180
        args.tilt = np.array([[1.,0.,0.],
                        [0, np.cos(theta), -np.sin(theta)],
                        [0, np.sin(theta), np.cos(theta)]]).astype(np.float32)

    projector = Projector(vol, args.tilt)
    if use_cuda:
        projector.lattice = projector.lattice.cuda()
        projector.vol = projector.vol.cuda()

    if args.grid is not None:
        rots = GridRot(args.grid)
        log('Generating {} rotations at resolution level {}'.format(len(rots), args.grid))
    else:
        log('Generating {} random rotations'.format(args.N))
        rots = RandomRot(args.N)

    # with --grid the number of images is set by the grid, not by args.N, so
    # every count below is taken from the rotations / images themselves
    n_total = len(rots)

    log('Projecting...')
    imgs = []
    n_done = 0
    iterator = data.DataLoader(rots, batch_size=args.b)
    for rot in iterator:
        # running total stays correct when the last batch is smaller
        n_done += len(rot)
        vlog('Projecting {}/{}'.format(n_done, n_total))
        projections = projector.project(rot)
        projections = projections.cpu().numpy()
        imgs.append(projections)

    rots = rots.rots.cpu().numpy()
    imgs = np.vstack(imgs)
    n_imgs = len(imgs)
    td = time.time()-t1
    log('Projected {} images in {}s ({}s per image)'.format(n_imgs, td, td/n_imgs))

    if args.t_extent:
        log('Shifting images between +/- {} pixels'.format(args.t_extent))
        # one random shift per projected image
        trans = np.random.rand(n_imgs,2)*2*args.t_extent - args.t_extent
        imgs = np.asarray([translate_img(img, t) for img,t in zip(imgs,trans)])
        # convention: we want the first column to be x shift and second column to be y shift
        # reverse columns since current implementation of translate_img uses scipy's
        # fourier_shift, which is flipped the other way
        # convention: save the translation that centers the image
        trans = -trans[:,::-1]
        # convert translation from pixel to fraction
        D = imgs.shape[-1]
        assert D % 2 == 0
        trans /= D

    log('Saving {}'.format(args.o))
    mrc.write(args.o,imgs.astype(np.float32))
    # NOTE(review): out_pose is treated as optional in the first loop but is
    # opened unconditionally here -- confirm the argument is required
    log('Saving {}'.format(args.out_pose))
    with open(args.out_pose,'wb') as f:
        if args.t_extent:
            pickle.dump((rots,trans),f)
        else:
            pickle.dump(rots, f)
    if args.out_png:
        log('Saving {}'.format(args.out_png))
        plot_projections(args.out_png, imgs[:9])