Example #1
def main(args):
    assert args.o.endswith('.star')
    assert args.particles.endswith(
        '.mrcs'
    ), "Only a single particle stack as an .mrcs is currently supported"
    particles = mrc.parse_mrc(args.particles, lazy=True)[0]
    ctf = utils.load_pkl(args.ctf)
    assert ctf.shape[1] == 9, "Incorrect CTF pkl format"
    assert len(particles) == len(
        ctf
    ), f"{len(particles)} != {len(ctf)}, Number of particles != number of CTF parameters"
    if args.poses:
        poses = utils.load_pkl(args.poses)
        assert len(particles) == len(
            poses[0]
        ), f"{len(particles)} != {len(poses[0])}, Number of particles != number of poses"
    log('{} particles'.format(len(particles)))

    if args.ind:
        ind = utils.load_pkl(args.ind)
        log(f'Filtering to {len(ind)} particles')
        ctf = ctf[ind]
        if args.poses:
            poses = (poses[0][ind], poses[1][ind])
    else:
        ind = np.arange(len(particles))

    # _rlnImageName
    ind += 1  # CHANGE TO 1-BASED INDEXING
    image_name = os.path.basename(
        args.particles) if not args.full_path else args.particles
    names = [f'{i}@{image_name}' for i in ind]

    ctf = ctf[:, 2:]

    # convert poses
    if args.poses:
        eulers = utils.R_to_relion_scipy(poses[0])
        D = particles[0].get().shape[0]
        trans = poses[1] * D  # convert from fraction to pixels

    data = {HEADERS[0]: names}
    for i in range(7):
        data[HEADERS[i + 1]] = ctf[:, i]
    if args.poses:
        for i in range(3):
            data[POSE_HDRS[i]] = eulers[:, i]
        for i in range(2):
            data[POSE_HDRS[3 + i]] = trans[:, i]
    df = pd.DataFrame(data=data)

    headers = HEADERS + POSE_HDRS if args.poses else HEADERS
    s = starfile.Starfile(headers, df)
    s.write(args.o)
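
This snippet depends on module-level HEADERS and POSE_HDRS constants that are not shown. Judging by the indexing above (one image-name header plus seven CTF columns, then three Euler angles and two origins), a plausible definition uses the RELION column names below; these exact names are an assumption, not part of the snippet:

# Assumed RELION header constants (hypothetical; not shown in the snippet)
HEADERS = [
    '_rlnImageName',
    '_rlnDefocusU',
    '_rlnDefocusV',
    '_rlnDefocusAngle',
    '_rlnVoltage',
    '_rlnSphericalAberration',
    '_rlnAmplitudeContrast',
    '_rlnPhaseShift',
]
POSE_HDRS = [
    '_rlnAngleRot',
    '_rlnAngleTilt',
    '_rlnAnglePsi',
    '_rlnOriginX',
    '_rlnOriginY',
]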
Example #2
def encoder_latent_shifts(workdir, outdir, epochs, E, LOG):
    '''
    Calculates and plots various metrics characterizing the per-particle latent vectors between successive epochs.

    Inputs
        workdir: path to directory containing cryodrgn training results
        outdir: path to base directory to save outputs
        E: int of epoch from which to evaluate convergence (0-indexed)
            Note that because three epochs are needed to define the two inter-epoch vectors analyzed "per epoch",
            the output contains metrics for E-1 epochs. Accordingly, plot x-axes are labeled from epoch 2 to epoch E.

    Outputs
        pkl of all statistics of shape (E-1, n_metrics)
        png of each statistic plotted over training
    '''
    metrics = ['dot product', 'magnitude', 'cosine distance']

    vector_metrics = np.zeros((E-1, len(metrics)))
    for i in np.arange(E-1):
        flog(f'Calculating vector metrics for epochs {i}-{i+1} and {i+1}-{i+2}', LOG)
        if i == 0:
            z1 = utils.load_pkl(f'{workdir}/z.{i}.pkl')
            z2 = utils.load_pkl(f'{workdir}/z.{i+1}.pkl')
            z3 = utils.load_pkl(f'{workdir}/z.{i+2}.pkl')
        else:
            z1 = z2.copy()
            z2 = z3.copy()
            z3 = utils.load_pkl(workdir + f'/z.{i+2}.pkl')

        diff21 = z2 - z1
        diff32 = z3 - z2

        vector_metrics[i, 0] = np.median(np.einsum('ij,ij->i', diff21, diff32), axis=0)  # median vector dot product
        vector_metrics[i, 1] = np.median(np.linalg.norm(diff32, axis=1), axis=0)  # median vector magnitude
        uv = np.sum(diff32 * diff21, axis=1)
        uu = np.sum(diff32 * diff32, axis=1)
        vv = np.sum(diff21 * diff21, axis=1)
        vector_metrics[i, 2] = np.median(1 - uv / (np.sqrt(uu) * np.sqrt(vv)))  # median vector cosine distance

    utils.save_pkl(vector_metrics, f'{outdir}/vector_metrics.pkl')

    fig, axes = plt.subplots(1, len(metrics), figsize=(10,3))
    fig.tight_layout()
    for i,ax in enumerate(axes.flat):
        ax.plot(np.arange(2,E+1), vector_metrics[:,i])
        ax.set_xlabel('epoch')
        ax.set_ylabel(metrics[i])
    plt.savefig(f'{outdir}/plots/02_encoder_latent_vector_shifts.png', dpi=300, format='png', transparent=True, bbox_inches='tight')

    flog(f'Saved latent vector shifts plots to {outdir}/plots/02_encoder_latent_vector_shifts.png', LOG)
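
A minimal, self-contained sketch of the three per-particle metrics computed in the loop above, on toy inter-epoch shift arrays of shape (n_particles, zdim):

import numpy as np

diff21 = np.array([[1., 0.], [0., 2.]])   # latent shifts from epoch i to i+1
diff32 = np.array([[1., 1.], [0., -2.]])  # latent shifts from epoch i+1 to i+2

dots = np.einsum('ij,ij->i', diff21, diff32)   # per-particle dot products
mags = np.linalg.norm(diff32, axis=1)          # per-particle shift magnitudes
cos_dist = 1 - dots / (np.linalg.norm(diff21, axis=1) * mags)
print(np.median(dots), np.median(mags), np.median(cos_dist))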
Example #3
def main(args):
    assert args.o.endswith('.star')
    particles = dataset.load_particles(args.particles, lazy=True, datadir=args.datadir)
    ctf = utils.load_pkl(args.ctf)
    assert ctf.shape[1] == 9, "Incorrect CTF pkl format"
    assert len(particles) == len(ctf), f"{len(particles)} != {len(ctf)}, Number of particles != number of CTF parameters"
    if args.poses:
        poses = utils.load_pkl(args.poses)
        assert len(particles) == len(poses[0]), f"{len(particles)} != {len(poses[0])}, Number of particles != number of poses"
    log('{} particles'.format(len(particles)))

    if args.ind:
        ind = utils.load_pkl(args.ind)
        log(f'Filtering to {len(ind)} particles')
        particles = [particles[ii] for ii in ind]
        ctf = ctf[ind]
        if args.poses: 
            poses = (poses[0][ind], poses[1][ind])
    else:
        ind = np.arange(len(particles))

    ind += 1 # CHANGE TO 1-BASED INDEXING
    image_names = [img.fname for img in particles]
    if args.full_path:
        image_names = [os.path.abspath(img.fname) for img in particles]
    names = [f'{i}@{name}' for i,name in zip(ind, image_names)]

    ctf = ctf[:,2:]

    # convert poses
    if args.poses:
        eulers = utils.R_to_relion_scipy(poses[0]) 
        D = particles[0].get().shape[0]
        trans = poses[1] * D # convert from fraction to pixels

    data = {HEADERS[0]:names}
    for i in range(7):
        data[HEADERS[i+1]] = ctf[:,i]
    if args.poses:
        for i in range(3):
            data[POSE_HDRS[i]] = eulers[:,i]
        for i in range(2):
            data[POSE_HDRS[3+i]] = trans[:,i]
    df = pd.DataFrame(data=data)

    headers = HEADERS + POSE_HDRS if args.poses else HEADERS
    s = starfile.Starfile(headers,df)
    s.write(args.o)
Example #4
def main(args):
    x = dataset.load_particles(args.input, lazy=True)
    log(f'Loaded {len(x)} particles')
    ind = utils.load_pkl(args.ind)
    x = np.array([x[i].get() for i in ind])
    log(x.shape)
    mrc.write(args.o, x)
Example #5
def main(args):
    x = dataset.load_particles(args.input, lazy=True)
    log(f'Loaded {len(x)} particles')
    ind = utils.load_pkl(args.ind)
    x = np.array([x[i].get() for i in ind])
    log(f'New stack dimensions: {x.shape}')
    mrc.write(args.o, x)
Example #6
def main(args):
    # load particles
    particles = dataset.load_particles(args.mrcs, datadir=args.datadir)
    log(particles.shape)
    Nimg, D, D = particles.shape

    trans = utils.load_pkl(args.trans)
    if type(trans) is tuple:
        trans = trans[1]
    trans *= args.tscale
    assert np.all(
        trans <= 1
    ), "ERROR: Old pose format detected. Translations must be in units of fraction of box."
    trans *= D  # convert to pixels
    assert len(trans) == Nimg

    xx, yy = np.meshgrid(np.arange(-D / 2, D / 2), np.arange(-D / 2, D / 2))
    TCOORD = np.stack([xx, yy], axis=2) / D  # DxDx2

    imgs = []
    for ii in range(Nimg):
        ff = fft.fft2_center(particles[ii])
        tfilt = np.dot(TCOORD, trans[ii]) * -2 * np.pi
        tfilt = np.cos(tfilt) + np.sin(tfilt) * 1j  # exp(i * tfilt): Fourier-space phase ramp for the translation
        ff *= tfilt
        img = fft.ifftn_center(ff)
        imgs.append(img)

    imgs = np.asarray(imgs).astype(np.float32)
    mrc.write(args.o, imgs)

    if args.out_png:
        plot_projections(args.out_png, imgs[:9])
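
The phase ramp built in the loop is the Fourier shift theorem: multiplying the centered FFT by exp(-2*pi*i * f.t) translates the image by t pixels. A standalone sketch of the same idea using numpy's fft in place of the project's fft helpers (the fft helper convention is assumed here):

import numpy as np

D = 8
img = np.zeros((D, D))
img[3, 3] = 1.0                 # a delta at (row 3, col 3)
t = np.array([1.0, 2.0])        # shift by 1 px in x (cols), 2 px in y (rows)

xx, yy = np.meshgrid(np.arange(-D / 2, D / 2), np.arange(-D / 2, D / 2))
coords = np.stack([xx, yy], axis=2) / D          # DxDx2, same as TCOORD above

ff = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(img)))  # centered FFT
ff *= np.exp(-2j * np.pi * (coords @ t))                   # phase ramp
out = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(ff))).real
print(np.unravel_index(out.argmax(), out.shape))  # prints (5, 4): row +2, col +1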
Example #7
def main(args):
    s = starfile.Starfile.load(args.input)
    ind = utils.load_pkl(args.ind)
    log('{} particles'.format(len(s.df)))
    s.df = s.df.loc[ind]
    log(f'Filtered to {len(s.df)} particles')
    s.write(args.o)
Example #8
def __init__(self, pose_pkl):
    poses = utils.load_pkl(pose_pkl)
    self.rots = torch.tensor(poses[0])
    self.trans = poses[1]
    self.N = len(poses[0])
    assert self.rots.shape == (self.N, 3, 3)
    assert self.trans.shape == (self.N, 2)
    assert self.trans.max() < 1
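
The asserts document the expected pose pkl layout: a (rotations, translations) tuple with rotations of shape (N, 3, 3) and translations of shape (N, 2) stored as a fraction of the box size (hence the max() < 1 check). A toy file with that layout, for illustration only:

import pickle
import numpy as np

N = 4
rots = np.tile(np.eye(3, dtype=np.float32), (N, 1, 1))  # N identity rotations
trans = np.zeros((N, 2), dtype=np.float32) + 0.01       # small fractional shifts
with open('toy_poses.pkl', 'wb') as f:
    pickle.dump((rots, trans), f)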
Example #9
def main(args):
    t1 = dt.now()
    E = args.epoch
    workdir = args.workdir
    zfile = f'{workdir}/z.{E}.pkl'
    weights = f'{workdir}/weights.{E}.pkl'
    config = f'{workdir}/config.pkl'
    outdir = f'{workdir}/analyze.{E}'
    if E == -1:
        zfile = f'{workdir}/z.pkl'
        weights = f'{workdir}/weights.pkl'
        outdir = f'{workdir}/analyze'
    
    if args.outdir:
        outdir = args.outdir
    log(f'Saving results to {outdir}')
    if not os.path.exists(outdir):
        os.mkdir(outdir)

    z = utils.load_pkl(zfile)
    zdim = z.shape[1]

    vol_args = dict(Apix=args.Apix, downsample=args.downsample, flip=args.flip, cuda=args.device)
    vg = VolumeGenerator(weights, config, vol_args, skip_vol=args.skip_vol)

    if zdim == 1:
        analyze_z1(z, outdir, vg)
    else:
        analyze_zN(z, outdir, vg, skip_umap=args.skip_umap, num_pcs=args.pc, num_ksamples=args.ksample)
       
    # copy over template if file doesn't exist
    out_ipynb = f'{outdir}/cryoDRGN_viz.ipynb'
    if not os.path.exists(out_ipynb):
        log(f'Creating jupyter notebook...')
        ipynb = f'{cryodrgn._ROOT}/templates/cryoDRGN_viz_template.ipynb'
        cmd = f'cp {ipynb} {out_ipynb}'
        subprocess.check_call(cmd, shell=True)
    else:
        log(f'{out_ipynb} already exists. Skipping')
    log(out_ipynb)

    # copy over template if file doesn't exist
    out_ipynb = f'{outdir}/cryoDRGN_filtering.ipynb'
    if not os.path.exists(out_ipynb):
        log(f'Creating jupyter notebook...')
        ipynb = f'{cryodrgn._ROOT}/templates/cryoDRGN_filtering_template.ipynb'
        cmd = f'cp {ipynb} {out_ipynb}'
        subprocess.check_call(cmd, shell=True)
    else:
        log(f'{out_ipynb} already exists. Skipping')
    log(out_ipynb)
    
    log(f'Finished in {dt.now()-t1}')
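
A side note on the template copies: the shelled-out cp could equally be done in-process, avoiding the shell entirely, e.g.

from shutil import copyfile
copyfile(ipynb, out_ipynb)  # same effect as the subprocess cp call above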
Example #10
def main(args):
    imgs = dataset.load_particles(args.mrcs, lazy=True, datadir=args.datadir)
    ctf_params = utils.load_pkl(args.ctf_params)
    assert len(imgs) == len(ctf_params)

    D = imgs[0].get().shape[0]
    fx, fy = np.meshgrid(np.linspace(-.5, .5, D, endpoint=False),
                         np.linspace(-.5, .5, D, endpoint=False))
    freqs = np.stack([fx.ravel(), fy.ravel()], 1)

    imgs_flip = np.empty((len(imgs), D, D), dtype=np.float32)
    for i in range(len(imgs)):
        if i % 1000 == 0: print(i)
        c = ctf.compute_ctf_np(freqs / ctf_params[i, 0], *ctf_params[i, 1:])
        c = c.reshape((D, D))
        ff = fft.fft2_center(imgs[i].get())
        ff *= np.sign(c)
        img = fft.ifftn_center(ff)
        imgs_flip[i] = img.astype(np.float32)

    mrc.write(args.o, imgs_flip)
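
The core of the phase-flipping loop is the multiplication by np.sign(c): it undoes the CTF's phase inversions while leaving its amplitude envelope untouched. A toy illustration:

import numpy as np

c = np.array([0.9, -0.5, 0.2, -0.8])            # toy CTF values at four frequencies
ff = np.array([1 + 1j, 2 + 0j, -1j, 0.5 + 0j])  # toy Fourier coefficients
print(ff * np.sign(c))  # coefficients flipped wherever the CTF is negative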
Example #11
def sketch_via_umap_local_maxima(outdir, E, LOG, n_bins=30, smooth=True, smooth_width=1, pruned_maxima=12, radius=5, final_maxima=10):
    '''
    Sketch the UMAP embedding of epoch E latent space via local maxima finding

    Inputs:
        E: epoch for which the (subset, see Convergence 1) umap distribution will be sketched for local maxima
        n_bins: the number of bins along UMAP1 and UMAP2
        smooth: whether to smooth the 2D histogram (aids local maxima finding for particularly continuous distributions)
        smooth_width: scalar multiple of one-bin-width defining sigma for gaussian kernel smoothing
        pruned_maxima: max number of local maxima above which pruning will be attempted
        radius: radius in bin-space (Euclidean distance) below which points are considered poorly-separated and are candidates for pruning
        final_maxima: the count of local maxima with highest associated bin count that will be returned as final to the user

    Outputs
        binned_ptcls_mask: binary mask of shape (n_particles_total, n_local_maxima) labeling all particles in the bin and neighboring 8 bins of each local maximum
        labels: a unique letter assigned to each local maximum
    '''
    def make_edges(umap, n_bins):
        '''
        Helper function to create two 1-D arrays defining @n_bins bin edges along axes x and y
        '''
        xedges = np.linspace(umap.min(axis=0)[0], umap.max(axis=0)[0], n_bins + 1)
        yedges = np.linspace(umap.min(axis=0)[1], umap.max(axis=0)[1], n_bins + 1)
        return xedges, yedges

    def local_maxima_2D(data):
        '''
        Helper function to find the coordinates and values of local maxima of a 2d hist
        Evaluates local maxima using a footprint equal to 3x3 set of bins
        '''
        size = 3
        footprint = np.ones((size, size))
        footprint[1, 1] = 0

        filtered = maximum_filter(data, footprint=footprint, mode='mirror')
        mask_local_maxima = data > filtered
        coords = np.asarray(np.where(mask_local_maxima)).T
        values = data[mask_local_maxima]

        return coords, values

    def gen_peaks_img(coords, values, edges):
        '''
        Helper function to scatter the values of the local maxima onto a hist with bins defined by the full umap
        '''
        filtered = np.zeros((edges[0].shape[0], edges[1].shape[0]))
        for peak in range(coords.shape[0]):
            filtered[tuple(coords[peak])] = values[peak]
        return filtered

    def prune_local_maxima(coords, values, n_maxima, radius):
        '''
        Helper function to prune "similar" local maxima and preserve UMAP diversity if more local maxima than desired are found
        Construct distance matrix of all coords to all coords in bin-space
        Find all maxima pairs closer than @radius
        While more than @n_maxima local maxima:
            if there are pairs closer than @radius:
                find single smallest distance d between two points
                compare points connected by d, remove lower value point from coords, values, and distance matrix
        Returns
        * coords
        * values
        '''
        dist_matrix = distance_matrix(coords, coords)
        dist_matrix[dist_matrix > radius] = 0  # ignore points separated by > @radius in bin-space

        while len(values) > n_maxima:
            if not np.count_nonzero(dist_matrix) == 0:  # some peaks are too close and need pruning
                indices_to_compare = np.where(dist_matrix == np.min(dist_matrix[np.nonzero(dist_matrix)]))[0]
                if values[indices_to_compare[0]] > values[indices_to_compare[1]]:
                    dist_matrix = np.delete(dist_matrix, indices_to_compare[1], axis=0)
                    dist_matrix = np.delete(dist_matrix, indices_to_compare[1], axis=1)
                    values = np.delete(values, indices_to_compare[1])
                    coords = np.delete(coords, indices_to_compare[1], axis=0)
                else:
                    dist_matrix = np.delete(dist_matrix, indices_to_compare[0], axis=0)
                    dist_matrix = np.delete(dist_matrix, indices_to_compare[0], axis=1)
                    values = np.delete(values, indices_to_compare[0])
                    coords = np.delete(coords, indices_to_compare[0], axis=0)
            else:  # local maxima are already well separated
                return coords, values
        return coords, values

    def coords_to_umap(umap, binned_ptcls_mask, values):
        '''
        Helper function to convert local maxima coords in bin-space to umap-space
        Defines each local maximum as the median UMAP1 and UMAP2 value across all particles in the 3x3 set of bins defining that local maximum
        '''
        umap_median_peaks = np.zeros((len(values), 2))
        for i in range(len(values)):
            umap_median_peaks[i, :] = np.median(umap[binned_ptcls_mask[:, i], :], axis=0)
        return umap_median_peaks

    flog('Using UMAP local maxima sketching', LOG)
    umap = utils.load_pkl(outdir + f'/umaps/umap.{E}.pkl')
    n_particles_sketch = umap.shape[0]

    # create 2d histogram of umap distribution
    edges = make_edges(umap, n_bins=n_bins)
    hist, xedges, yedges, bincount = stats.binned_statistic_2d(umap[:, 0], umap[:, 1], None, 'count', bins=edges, expand_binnumbers=True)
    to_plot = ['umap', 'hist']

    # optionally smooth the histogram to reduce the number of peaks, with sigma = smooth_width times the bin width
    if smooth:
        hist_smooth = gaussian_filter(hist, smooth_width * np.abs(xedges[1] - xedges[0]))
        coords, values = local_maxima_2D(hist_smooth)
        to_plot[-1] = 'hist_smooth'
    else:
        coords, values = local_maxima_2D(hist)
    flog(f'Found {len(values)} local maxima', LOG)

    # prune local maxima that are densely packed and low in value
    coords, values = prune_local_maxima(coords, values, pruned_maxima, radius)
    flog(f'Pruned to {len(values)} local maxima', LOG)

    # find subset of n_peaks highest local maxima
    indices = (-values).argsort()[:final_maxima]
    coords, values = coords[indices], values[indices]
    peaks_img_top = gen_peaks_img(coords, values, edges)
    to_plot.append('peaks_img_top')
    to_plot.append('sketched_umap')
    flog(f'Filtered to top {len(values)} local maxima', LOG)

    # write list of lists containing indices of all particles within maxima bins + all 8 neighboring bins (assumes footprint = (3,3))
    binned_ptcls_mask = np.zeros((n_particles_sketch, len(values)), dtype=bool)
    for i in range(len(values)):
        binned_ptcls_mask[:, i] = (bincount[0, :] >= coords[i, 0] + 0) & \
                                  (bincount[0, :] <= coords[i, 0] + 2) & \
                                  (bincount[1, :] >= coords[i, 1] + 0) & \
                                  (bincount[1, :] <= coords[i, 1] + 2)

    # find median umap coords of each maxima bin for plotting
    coords = coords_to_umap(umap, binned_ptcls_mask, values)

    # plot the original histogram, all peaks, and highest n_peaks
    fig, axes = plt.subplots(1, len(to_plot), figsize=(len(to_plot) * 3.6, 3))
    fig.tight_layout()
    labels = ascii_uppercase[:len(values)]
    for i, ax in enumerate(axes.flat):
        if to_plot[i] == 'umap':
            ax.hexbin(umap[:, 0], umap[:, 1], mincnt=1)
            ax.vlines(x=xedges, ymin=umap.min(axis=0)[1], ymax=umap.max(axis=0)[1], colors='red', linewidth=0.35)
            ax.hlines(y=yedges, xmin=umap.min(axis=0)[0], xmax=umap.max(axis=0)[0], colors='red', linewidth=0.35)
            ax.set_title(f'epoch {E} UMAP')
            ax.set_xlabel('UMAP1')
            ax.set_ylabel('UMAP2')
        elif to_plot[i] == 'hist':
            ax.imshow(np.rot90(hist))
            ax.set_title('UMAP histogram')
        elif to_plot[i] == 'hist_smooth':
            ax.imshow(np.rot90(hist_smooth))
            ax.set_title('UMAP smoothed histogram')
        elif to_plot[i] == 'peaks_img_top':
            ax.imshow(np.rot90(peaks_img_top))
            ax.set_title(f'final {len(labels)} local maxima')
        elif to_plot[i] == 'sketched_umap':
            ax.hexbin(umap[:, 0], umap[:, 1], mincnt=1)
            ax.scatter(*coords.T, c='r')
            ax.set_title(f'sketched epoch {E} UMAP')
            ax.set_xlabel('UMAP1')
            ax.set_ylabel('UMAP2')
            for k in range(len(values)):
                ax.text(x=coords[k, 0] + 0.3,
                        y=coords[k, 1] + 0.3,
                        s=labels[k],
                        fontdict=dict(color='r', size=10))
        ax.spines['bottom'].set_visible(True)
        ax.spines['left'].set_visible(True)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    plt.savefig(f'{outdir}/plots/03_decoder_UMAP-sketching.png', dpi=300, format='png', transparent=True, bbox_inches='tight')
    flog(f'Saved latent sketching plot to {outdir}/plots/03_decoder_UMAP-sketching.png', LOG)

    return binned_ptcls_mask, labels
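
The local_maxima_2D helper marks bins that are strictly greater than all 8 neighbors by comparing against a maximum filter whose footprint excludes the center. A minimal standalone check of that trick:

import numpy as np
from scipy.ndimage import maximum_filter

data = np.zeros((5, 5))
data[1, 1] = 3.0
data[3, 4] = 2.0

footprint = np.ones((3, 3))
footprint[1, 1] = 0  # compare each bin against its neighbors only
mask = data > maximum_filter(data, footprint=footprint, mode='mirror')
print(np.asarray(np.where(mask)).T)  # -> [[1 1], [3 4]]
print(data[mask])                    # -> [3. 2.]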
Example #12
def main(args):
    assert args.o.endswith('.star'), "Output file must be .star file"
    assert args.particles.endswith('.mrcs') or args.particles.endswith(
        '.txt'), "Input file must be .mrcs or .txt"

    particles = dataset.load_particles(args.particles,
                                       lazy=True,
                                       datadir=args.datadir)
    ctf = utils.load_pkl(args.ctf)
    assert ctf.shape[1] == 9, "Incorrect CTF pkl format"
    assert len(particles) == len(
        ctf
    ), f"{len(particles)} != {len(ctf)}, Number of particles != number of CTF parameters"
    if args.poses:
        poses = utils.load_pkl(args.poses)
        assert len(particles) == len(
            poses[0]
        ), f"{len(particles)} != {len(poses[0])}, Number of particles != number of poses"
    log(f'{len(particles)} particles in {args.particles}')

    if args.ref_star:
        ref_star = starfile.Starfile.load(args.ref_star)
        assert len(ref_star) == len(
            particles
        ), f"{len(particles)} != {len(ref_star)}, Number of particles in {args.particles} != number of particles in {args.ref_star}"

    # Get index for particles in each .mrcs file
    if args.particles.endswith('.txt'):
        N_per_chunk = parse_chunk_size(args.particles)
        particle_ind = np.concatenate([np.arange(nn) for nn in N_per_chunk])
        assert len(particle_ind) == len(particles)
    else:  # single .mrcs file
        particle_ind = np.arange(len(particles))

    if args.ind:
        ind = utils.load_pkl(args.ind)
        log(f'Filtering to {len(ind)} particles')
        particles = [particles[ii] for ii in ind]
        ctf = ctf[ind]
        if args.poses:
            poses = (poses[0][ind], poses[1][ind])
        if args.ref_star:
            ref_star.df = ref_star.df.loc[ind]
            # reset the index in the dataframe to avoid any downstream indexing issues
            ref_star.df.reset_index(inplace=True)
        particle_ind = particle_ind[ind]

    particle_ind += 1  # CHANGE TO 1-BASED INDEXING
    image_names = [img.fname for img in particles]
    if args.full_path:
        image_names = [os.path.abspath(img.fname) for img in particles]
    names = [f'{i}@{name}' for i, name in zip(particle_ind, image_names)]

    ctf = ctf[:, 2:]

    # convert poses
    if args.poses:
        eulers = utils.R_to_relion_scipy(poses[0])
        D = particles[0].get().shape[0]
        trans = poses[1] * D  # convert from fraction to pixels

    # Create a new dataframe with required star file headers
    data = {HEADERS[0]: names}
    for i in range(7):
        data[HEADERS[i + 1]] = ctf[:, i]
    if args.poses:
        for i in range(3):
            data[POSE_HDRS[i]] = eulers[:, i]
        for i in range(2):
            data[POSE_HDRS[3 + i]] = trans[:, i]
    df = pd.DataFrame(data=data)
    headers = HEADERS + POSE_HDRS if args.poses else HEADERS
    if args.keep_micrograph:
        assert args.ref_star, "Must provide reference .star file with micrograph coordinates"
        log(f'Copying micrograph coordinates from {args.ref_star}')
        # TODO: Prepend path from args.ref_star to MicrographName?
        for h in MICROGRAPH_HDRS:
            df[h] = ref_star.df[h]
        headers += MICROGRAPH_HDRS

    s = starfile.Starfile(headers, df)
    s.write(args.o)
Example #13
def test_write_starfile():
    subprocess.check_call('./test_utils.sh', shell=True)
    r1 = utils.load_pkl('data/toy_rot_trans.pkl')
    r2 = utils.load_pkl('output/test_pose.pkl')
    assert_array_almost_equal(r1[0], r2[0])
    assert_array_almost_equal(r1[1], r2[1])
Example #14
def main(args):
    mkbasedir(args.o)
    warnexists(args.o)
    assert (
        args.o.endswith('.mrcs')
        or args.o.endswith('.txt')), "Must specify output in .mrcs file format"

    # load images
    lazy = args.lazy
    images = dataset.load_particles(args.mrcs,
                                    lazy=lazy,
                                    datadir=args.datadir,
                                    relion31=args.relion31)

    # filter images
    if args.ind is not None:
        log(f'Filtering image dataset with {args.ind}')
        ind = utils.load_pkl(args.ind).astype(int)
        images = [images[i] for i in ind] if lazy else images[ind]

    original_D = images[0].get().shape[0] if lazy else images.shape[-1]
    log(f'Loading {len(images)} {original_D}x{original_D} images')
    window = args.window
    invert_data = args.invert_data
    downsample = (args.D and args.D < original_D)
    if downsample:
        assert args.D <= original_D, f'New box size {args.D} cannot be larger than the original box size {original_D}'
        assert args.D % 2 == 0, 'New box size must be even'
        start = int(original_D / 2 - args.D / 2)
        stop = int(original_D / 2 + args.D / 2)
        D = args.D
        log(f'Downsampling images to {D}x{D}')
    else:
        D = original_D

    def _combine_imgs(imgs):
        ret = []
        for img in imgs:
            img.shape = (1, *img.shape)  # (D,D) -> (1,D,D)
        cur = imgs[0]
        for img in imgs[1:]:
            if img.fname == cur.fname and img.offset == cur.offset + 4 * np.prod(
                    cur.shape):
                cur.shape = (cur.shape[0] + 1, *cur.shape[1:])
            else:
                ret.append(cur)
                cur = img
        ret.append(cur)
        return ret

    def preprocess(imgs):
        if lazy:
            imgs = _combine_imgs(imgs)
            imgs = np.concatenate([i.get() for i in imgs])
        with Pool(min(args.max_threads, mp.cpu_count())) as p:
            # todo: refactor as a routine in dataset.py

            # note: applying the window before downsampling is slightly
            # different than in the original workflow
            if window:
                imgs *= dataset.window_mask(original_D, args.window_r, .99)
            ret = np.asarray(p.map(fft.ht2_center, imgs))
            if invert_data:
                ret *= -1
            if downsample:
                ret = ret[:, start:stop, start:stop]
            ret = fft.symmetrize_ht(ret)
        return ret

    def preprocess_in_batches(imgs, b):
        ret = np.empty((len(imgs), D + 1, D + 1), dtype=np.float32)
        Nbatches = math.ceil(len(imgs) / b)
        for ii in range(Nbatches):
            log(f'Processing batch of {b} images ({ii+1} of {Nbatches})')
            ret[ii * b:(ii + 1) * b, :, :] = preprocess(imgs[ii * b:(ii + 1) *
                                                             b])
        return ret

    nchunks = math.ceil(len(images) / args.chunk)
    out_mrcs = [
        f'.{i}.ft'.join(os.path.splitext(args.o)) for i in range(nchunks)
    ]
    chunk_names = [os.path.basename(x) for x in out_mrcs]
    for i in range(nchunks):
        log(f'Processing chunk {i+1} of {nchunks}')
        chunk = images[i * args.chunk:(i + 1) * args.chunk]
        new = preprocess_in_batches(chunk, args.b)
        log(f'New shape: {new.shape}')
        log(f'Saving {out_mrcs[i]}')
        mrc.write(out_mrcs[i], new, is_vol=False)

    out_txt = f'{os.path.splitext(args.o)[0]}.ft.txt'
    log(f'Saving summary txt file {out_txt}')
    with open(out_txt, 'w') as f:
        f.write('\n'.join(chunk_names))
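
The chunk filenames are built by splicing '.{i}.ft' between the stem and extension of the output path; a quick illustration of the naming scheme:

import os

out = 'particles.mrcs'
print([f'.{i}.ft'.join(os.path.splitext(out)) for i in range(3)])
# ['particles.0.ft.mrcs', 'particles.1.ft.mrcs', 'particles.2.ft.mrcs']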
Example #15
def encoder_latent_umaps(workdir, outdir, epochs, n_particles_total, subset, random_seed, use_umap_gpu, random_state, n_epochs_umap, LOG):
    '''
    Calculates UMAP embeddings of subset of particles' selected epochs' latent encodings

    Inputs
        workdir: path to directory containing cryodrgn training results
        outdir: path to base directory to save outputs
        epochs: array of epochs for which to calculate UMAPs
        n_particles_total: int of total number of particles trained
        subset: int, size of subset on which to calculate umap, None means all
        random_seed: int, seed for random selection of subset particles
        use_umap_gpu: bool, whether to use the cuML library to GPU accelerate UMAP calculations (if available in env)
        random_state: int, random state seed used by UMAP for reproducibility at slight cost of performance (None means faster but non-reproducible)
        n_epochs_umap: int, number of training epochs used when fitting the GPU (cuML) UMAP

    Outputs
        pkl of each UMAP embedding stored in outdir/umaps/umap.epoch.pkl
        png of all UMAPs

    # apparently running multiple UMAP embeddings (i.e. for each epoch's z.pkl) in parallel on CPU requires difficult backend setup
    # see https://github.com/lmcinnes/umap/issues/707
    # therefore not implemented currently
    '''

    if subset == 'None':
        n_particles_subset = n_particles_total
        flog('Using full particle stack for UMAP', LOG)
    else:
        if random_seed is None:
            random_seed = random.randint(0, 100000)
            random.seed(random_seed)
        else:
            random.seed(random_seed)
        n_particles_subset = min(n_particles_total, int(subset))
        flog(f'Randomly selecting {n_particles_subset} particle subset on which to run UMAP (with random seed {random_seed})', LOG)
    ind_subset = sorted(random.sample(range(0, n_particles_total), k=n_particles_subset))
    utils.save_pkl(ind_subset, outdir + '/ind_subset.pkl')

    for epoch in epochs:
        flog(f'Now calculating UMAP for epoch {epoch} with random_state {random_state}', LOG)
        z = utils.load_pkl(workdir + f'/z.{epoch}.pkl')[ind_subset, :]
        if use_umap_gpu: #using cuML library GPU-accelerated UMAP
            reducer = cuUMAP(random_state=random_state, n_epochs=n_epochs_umap)
            umap_embedding = reducer.fit_transform(z)
        else: #using umap-learn library CPU-bound UMAP
            reducer = umap.UMAP(random_state=random_state)
            umap_embedding = reducer.fit_transform(z)
        utils.save_pkl(umap_embedding, f'{outdir}/umaps/umap.{epoch}.pkl')


    n_cols = int(np.ceil(len(epochs) ** 0.5))
    n_rows = int(np.ceil(len(epochs) / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), sharex='all', sharey='all')
    fig.tight_layout()

    for i, ax in enumerate(axes.flat):
        try:
            umap_embedding = utils.load_pkl(f'{outdir}/umaps/umap.{epochs[i]}.pkl')
            toplot = ax.hexbin(umap_embedding[:, 0], umap_embedding[:, 1], bins='log', mincnt=1)
            ax.set_title(f'epoch {epochs[i]}')
        except IndexError:
            pass
        except FileNotFoundError:
            flog(f'Could not find file {outdir}/umaps/umap.{epochs[i]}.pkl', LOG)

    if len(axes.shape) == 1:
        axes[0].set_ylabel('UMAP2')
        for a in axes[:]: a.set_xlabel('UMAP1')
    else:
        assert len(axes.shape) == 2  # there is more than one row and one column of axes
        for a in axes[:, 0]: a.set_ylabel('UMAP2')
        for a in axes[-1, :]: a.set_xlabel('UMAP1')
    fig.subplots_adjust(right=0.96)
    cbar_ax = fig.add_axes([0.98, 0.15, 0.02, 0.7])
    cbar = fig.colorbar(toplot, cax=cbar_ax)
    cbar.ax.set_ylabel('particle density', rotation=90)

    plt.subplots_adjust(wspace=0.1)
    plt.subplots_adjust(hspace=0.3)
    plt.savefig(f'{outdir}/plots/01_encoder_umaps.png', dpi=300, format='png', transparent=True, bbox_inches='tight')
    flog(f'Saved UMAP distribution plot to {outdir}/plots/01_encoder_umaps.png', LOG)
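
The n_cols/n_rows arithmetic above lays the epochs out on a near-square grid (columns = ceil(sqrt(n)), rows = ceil(n / columns)); a quick check of the layouts it produces:

import numpy as np

for n in [3, 5, 9, 10]:
    n_cols = int(np.ceil(n ** 0.5))
    n_rows = int(np.ceil(n / n_cols))
    print(n, '->', (n_rows, n_cols))  # e.g. 10 -> (3, 4)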
Example #16
def main(args):
    t1 = dt.now()

    # Configure paths
    E = args.epoch
    sampling = args.epoch_interval
    epochs = np.arange(4, E+1, sampling)
    if epochs[-1] != E:
        epochs = np.append(epochs, E)
    workdir = args.workdir
    config = f'{workdir}/config.pkl'
    logfile = f'{workdir}/run.log'

    # assert all required files are locatable
    for i in range(E):
        assert os.path.exists(f'{workdir}/z.{i}.pkl'), f'Could not find training file {workdir}/z.{i}.pkl'
    for epoch in epochs:
        assert os.path.exists(f'{workdir}/weights.{epoch}.pkl'), f'Could not find training file {workdir}/weights.{epoch}.pkl'
    assert os.path.exists(config), f'Could not find training file {config}'
    assert os.path.exists(logfile), f'Could not find training file {logfile}'

    # Configure output paths
    if args.outdir:
        outdir = args.outdir
    elif E == -1:
        outdir = f'{workdir}/convergence'
    else:
        outdir = f'{workdir}/convergence.{E}'
    os.makedirs(outdir, exist_ok=True)
    os.makedirs(f'{outdir}/plots', exist_ok=True)
    os.makedirs(f'{outdir}/umaps', exist_ok=True)
    os.makedirs(f'{outdir}/repr_particles', exist_ok=True)
    LOG = f'{outdir}/convergence.log'
    flog(args, LOG)
    if len(epochs) < 3:
        flog('WARNING: Too few epochs have been selected for some analyses. Try decreasing --epoch-interval, or analyzing a later epoch.', LOG)
    if len(epochs) < 2:
        flog('WARNING: Too few epochs have been selected for any analyses. Try decreasing --epoch-interval, or analyzing a later epoch.', LOG)
        sys.exit()
    flog(f'Saving all results to {outdir}', LOG)

    # Get total number of particles, latent space dimensionality, input image size
    n_particles_total, n_dim = utils.load_pkl(f'{workdir}/z.{E}.pkl').shape
    cfg = utils.load_pkl(config)
    img_size = cfg['lattice_args']['D'] - 1

    # Commonly used variables
    #plt.rcParams.update({'font.size': 16})
    plt.rcParams.update({'axes.linewidth': 1.5})
    chimerax_colors = np.divide(((192, 192, 192),
                                 (255, 255, 178),
                                 (178, 255, 255),
                                 (178, 178, 255),
                                 (255, 178, 255),
                                 (255, 178, 178),
                                 (178, 255, 178),
                                 (229, 191, 153),
                                 (153, 191, 229),
                                 (204, 204, 153)), 255)

    # Convergence 1: total loss
    flog('Convergence 1: plotting total loss curve ...', LOG)
    plot_loss(logfile, outdir, E, LOG)

    # Convergence 2: UMAP latent embeddings
    if args.skip_umap:
        flog('Skipping UMAP calculation ...', LOG)
    else:
        flog(f'Convergence 2: calculating and plotting UMAP embeddings of epochs {epochs} ...', LOG)
        use_umap_gpu = 'cuml.manifold.umap' in sys.modules
        if args.force_umap_cpu:
            use_umap_gpu = False
        if use_umap_gpu:
            flog('Using GPU-accelerated UMAP via cuML library', LOG)
        else:
            flog('Using CPU-bound UMAP via umap-learn library', LOG)
        subset = args.subset
        random_state = args.random_state
        random_seed = args.random_seed
        n_epochs_umap = args.n_epochs_umap
        encoder_latent_umaps(workdir, outdir, epochs, n_particles_total, subset, random_seed, use_umap_gpu, random_state, n_epochs_umap, LOG)

    # Convergence 3: latent encoding shifts
    flog(f'Convergence 3: calculating and plotting latent encoding vector shifts for all epochs up to epoch {E} ...', LOG)
    encoder_latent_shifts(workdir, outdir, epochs, E, LOG)

    # Convergence 4: correlation of generated volumes
    flog(f'Convergence 4: sketching epoch {E}\'s latent space to find representative and well-supported latent encodings  ...', LOG)
    n_bins = args.n_bins
    smooth = args.smooth
    smooth_width = args.smooth_width
    pruned_maxima = args.pruned_maxima
    radius = args.radius
    final_maxima = args.final_maxima
    binned_ptcls_mask, labels = sketch_via_umap_local_maxima(outdir, E, LOG, n_bins=n_bins, smooth=smooth, smooth_width=smooth_width, pruned_maxima=pruned_maxima, radius=radius, final_maxima=final_maxima)

    follow_candidate_particles(workdir, outdir, epochs, n_dim, binned_ptcls_mask, labels, LOG)

    if args.skip_volgen:
        flog('Skipping volume generation ...', LOG)
    else:
        flog(f'Generating volumes at representative latent encodings for epochs {epochs} ...', LOG)
        Apix = args.Apix
        flip = args.flip
        invert = True if args.invert else None
        downsample = args.downsample
        cuda = args.cuda
        generate_volumes(workdir, outdir, epochs, Apix, flip, invert, downsample, cuda, LOG)

        flog(f'Generating masked volumes at representative latent encodings for epochs {epochs} ...', LOG)
        thresh = args.thresh
        dilate = args.dilate
        dist = args.dist
        max_threads = min(args.max_threads, multiprocessing.cpu_count())
        flog(f'Using {max_threads} threads to parallelize masking', LOG)
        mask_volumes(outdir, epochs, labels, max_threads, LOG, Apix, thresh=thresh, dilate=dilate, dist=dist)

    flog(f'Calculating masked map-map CCs at representative latent encodings for epochs {epochs} ...', LOG)
    calculate_CCs(outdir, epochs, labels, chimerax_colors, LOG)

    flog(f'Calculating masked map-map FSCs at representative latent encodings for epochs {epochs} ...', LOG)
    if args.downsample:
        img_size = args.downsample
    calculate_FSCs(outdir, epochs, labels, img_size, chimerax_colors, LOG)

    flog(f'Finished in {dt.now() - t1}', LOG)
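
The 'cuml.manifold.umap' in sys.modules test only works if that module was conditionally imported earlier in the script. A guarded import along these lines is presumably what precedes it (an assumption; the import is not shown in the snippet):

try:
    from cuml.manifold.umap import UMAP as cuUMAP  # optional GPU-accelerated UMAP
except ImportError:
    pass  # fall back to the CPU umap-learn package imported elsewhere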
Example #17
def follow_candidate_particles(workdir, outdir, epochs, n_dim, binned_ptcls_mask, labels, LOG):
    '''
    Monitor how the labeled set of particles migrates within latent space at selected epochs over training

    Inputs:
        workdir: path to directory containing cryodrgn training results
        outdir: path to base directory to save outputs
        epochs: array of epochs for which to calculate UMAPs
        n_dim: latent dimensionality
        binned_ptcls_mask: (n_particles, len(labels)) binary mask of which particles belong to which class
        labels: unique identifier for each class of representative latent encodings

    Outputs
        plot.png tracking representative latent encodings through epochs
        latent.txt of representative latent encodings for each epoch
    '''

    # track sketched points from epoch E through selected previous epochs and plot them atop each epoch's UMAP embedding
    n_cols = int(np.ceil(len(epochs) ** 0.5))
    n_rows = int(np.ceil(len(epochs) / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), sharex='all', sharey='all')
    fig.tight_layout()

    ind_subset = utils.load_pkl(f'{outdir}/ind_subset.pkl')
    for i, ax in enumerate(axes.flat):
        try:
            umap = utils.load_pkl(f'{outdir}/umaps/umap.{epochs[i]}.pkl')
            z = utils.load_pkl(f'{workdir}/z.{epochs[i]}.pkl')[ind_subset,:]
            z_maxima_median = np.zeros((len(labels), n_dim))

            for k in range(len(labels)):
                z_maxima_median[k, :] = np.median(z[binned_ptcls_mask[:, k]], axis=0) # find median latent value of each maximum in a given epoch

            z_maxima_median_ondata, z_maxima_median_ondata_ind = analysis.get_nearest_point(z, z_maxima_median)  # find on-data latent encoding of each median latent value
            umap_maxima_median_ondata = umap[z_maxima_median_ondata_ind] # find on-data UMAP embedding of each median latent encoding

            # Write out the on-data median latent values of each labeled set of particles for each epoch in epochs
            with open(f'{outdir}/repr_particles/latent_representative.{epochs[i]}.txt', 'w') as f:
                np.savetxt(f, z_maxima_median_ondata, delimiter=' ', newline='\n', header='', footer='', comments='# ')
            flog(f'Saved representative latent encodings for epoch {epochs[i]} to {outdir}/repr_particles/latent_representative.{epochs[i]}.txt', LOG)

            for k in range(len(labels)):
                ax.text(x=umap_maxima_median_ondata[k, 0] + 0.3,
                        y=umap_maxima_median_ondata[k, 1] + 0.3,
                        s=labels[k],
                        fontdict=dict(color='r', size=10))
            toplot = ax.hexbin(*umap.T, bins='log', mincnt=1)
            ax.scatter(umap_maxima_median_ondata[:, 0], umap_maxima_median_ondata[:, 1], s=10, linewidth=0, c='r',
                       alpha=1)
            ax.set_title(f'epoch {epochs[i]}')
        except IndexError:
            pass

    if len(axes.shape) == 1:
        axes[0].set_ylabel('UMAP2')
        for a in axes[:]: a.set_xlabel('UMAP1')
    else:
        assert len(axes.shape) == 2  # there is more than one row and one column of axes
        for a in axes[:, 0]: a.set_ylabel('UMAP2')
        for a in axes[-1, :]: a.set_xlabel('UMAP1')
    fig.subplots_adjust(right=0.96)
    cbar_ax = fig.add_axes([0.98, 0.15, 0.02, 0.7])
    cbar = fig.colorbar(toplot, cax=cbar_ax)
    cbar.ax.set_ylabel('Particle Density', rotation=90)

    plt.subplots_adjust(wspace=0.1)
    plt.subplots_adjust(hspace=0.25)

    plt.savefig(f'{outdir}/plots/04_decoder_maxima-sketch-consistency.png', dpi=300, format='png', transparent=True, bbox_inches='tight')
    flog(f'Saved plot tracking representative latent encodings through epochs {epochs} to {outdir}/plots/04_decoder_maxima-sketch-consistency.png', LOG)
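
analysis.get_nearest_point returns, for each query point, the nearest row of the data and its index. A hypothetical reimplementation (not the library's actual code) to make the semantics concrete:

import numpy as np

def get_nearest_point(data, queries):
    # pairwise distances, shape (n_queries, n_data)
    d = np.linalg.norm(data[None, :, :] - queries[:, None, :], axis=2)
    ind = d.argmin(axis=1)
    return data[ind], ind

data = np.array([[0., 0.], [1., 1.], [2., 2.]])
queries = np.array([[0.9, 1.2]])
print(get_nearest_point(data, queries))  # -> (array([[1., 1.]]), array([1]))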
Example #18
def main(args):
    assert args.o.endswith('.mrc')

    t1 = time.time()    
    log(args)
    if not os.path.exists(os.path.dirname(args.o)):
        os.makedirs(os.path.dirname(args.o))

    ## set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if not use_cuda:
        log('WARNING: No GPUs detected')

    # load the particles
    if args.ind is not None:
        args.ind = utils.load_pkl(args.ind).astype(int)
    if args.tilt is None:
        if args.preprocessed:
            data = dataset.PreprocessedMRCData(args.particles, norm=(0,1), ind=args.ind)
        else:
            data = dataset.LazyMRCData(args.particles, norm=(0,1), invert_data=args.invert_data, datadir=args.datadir, ind=args.ind, relion31=args.relion31)
        tilt = None
    else: # tilt series 
        if args.relion31: raise NotImplementedError
        if args.preprocessed: raise NotImplementedError
        data = dataset.TiltMRCData(args.particles, args.tilt, norm=(0,1), invert_data=args.invert_data, datadir=args.datadir, ind=args.ind)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32), device=device)
    D = data.D
    Nimg = data.N

    lattice = Lattice(D, extent=D//2, device=device)

    posetracker = PoseTracker.load(args.poses, Nimg, D, None, args.ind, device=device)

    if args.ctf is not None:
        log('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D-1, args.ctf)
        ctf_params = torch.tensor(ctf_params, device=device)
        if args.ind is not None: ctf_params = ctf_params[args.ind]
    else: ctf_params = None
    Apix = ctf_params[0,0] if ctf_params is not None else 1

    V = torch.zeros((D,D,D), device=device)
    counts = torch.zeros((D,D,D), device=device)
    
    mask = lattice.get_circular_mask(D//2)

    if args.first:
        args.first = min(args.first, Nimg)
        iterator = range(args.first)
    else:
        iterator = range(Nimg)

    for ii in iterator:
        if ii%100==0: log('image {}'.format(ii))
        r, t = posetracker.get_pose(ii)
        ff = data.get(ii)
        if tilt is not None:
            ff, ff_tilt = ff  # unpack the (image, tilt image) pair
        ff = torch.tensor(ff, device=device)
        ff = ff.view(-1)[mask]
        if ctf_params is not None:
            freqs = lattice.freqs2d/ctf_params[ii,0]
            c = ctf.compute_ctf(freqs, *ctf_params[ii,1:]).view(-1)[mask]
            ff *= c.sign()
        if t is not None:
            ff = lattice.translate_ht(ff.view(1,-1),t.view(1,1,2), mask).view(-1)
        ff_coord = lattice.coords[mask] @ r
        add_slice(V, counts, ff_coord, ff, D)

        # tilt series
        if args.tilt is not None:
            ff_tilt = torch.tensor(ff_tilt, device=device)
            ff_tilt = ff_tilt.view(-1)[mask]
            if ctf_params is not None:
                ff_tilt *= c.sign()
            if t is not None:
                ff_tilt = lattice.translate_ht(ff_tilt.view(1,-1), t.view(1,1,2), mask).view(-1)
            ff_coord = lattice.coords[mask] @ tilt @ r
            add_slice(V, counts, ff_coord, ff_tilt, D)

    td = time.time()-t1
    log('Backprojected {} images in {}s ({}s per image)'.format(len(iterator), td, td/len(iterator)))
    counts[counts == 0] = 1
    V /= counts
    V = fft.ihtn_center(V[0:-1,0:-1,0:-1].cpu().numpy())
    mrc.write(args.o,V.astype('float32'), Apix=Apix)
Example #19
def main(args):
    t1 = dt.now()
    log(args)
    E = args.epoch
    workdir = args.workdir
    zfile = f'{workdir}/z.{E}.pkl'
    weights = f'{workdir}/weights.{E}.pkl'
    config = f'{workdir}/config.pkl'
    outdir = f'{workdir}/landscape.{E}'

    if args.outdir:
        outdir = args.outdir
    log(f'Saving results to {outdir}')
    if not os.path.exists(outdir):
        os.mkdir(outdir)

    z = utils.load_pkl(zfile)
    zdim = z.shape[1]
    K = args.sketch_size

    vol_args = dict(Apix=args.Apix,
                    downsample=args.downsample,
                    flip=args.flip,
                    cuda=args.device)
    vg = VolumeGenerator(weights, config, vol_args, skip_vol=args.skip_vol)

    if args.vol_ind is not None:
        args.vol_ind = utils.load_pkl(args.vol_ind)

    if not args.skip_vol:
        generate_volumes(z, outdir, vg, K)
    else:
        log('Skipping volume generation')

    if args.skip_umap:
        assert os.path.exists(f'{outdir}/umap.pkl')
        log('Skipping UMAP')
    else:
        log(f'Copying UMAP from {workdir}/analyze.{E}/umap.pkl')
        if os.path.exists(f'{workdir}/analyze.{E}/umap.pkl'):
            from shutil import copyfile
            copyfile(f'{workdir}/analyze.{E}/umap.pkl', f'{outdir}/umap.pkl')
        else:
            raise NotImplementedError

    if args.mask:
        log(f'Using custom mask {args.mask}')
    make_mask(outdir, K, args.dilate, args.thresh, args.mask)

    log('Analyzing volumes...')
    # get particle indices if the dataset was originally filtered
    c = utils.load_pkl(config)
    particle_ind = (utils.load_pkl(c['dataset_args']['ind'])
                    if c['dataset_args']['ind'] is not None else None)
    analyze_volumes(outdir,
                    K,
                    args.pc_dim,
                    args.M,
                    args.linkage,
                    vol_ind=args.vol_ind,
                    plot_dim=args.plot_dim,
                    particle_ind_orig=particle_ind)
    td = dt.now() - t1
    log(f'Finished in {td}')
Example #20
def analyze_volumes(outdir,
                    K,
                    dim,
                    M,
                    linkage,
                    vol_ind=None,
                    plot_dim=5,
                    particle_ind_orig=None):
    cmap = choose_cmap(M)

    # load mean volume, compute it if it does not exist
    if not os.path.exists(f'{outdir}/kmeans{K}/vol_mean.mrc'):
        volm = np.array([
            mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0]
            for i in range(K)
        ]).mean(axis=0)
        mrc.write(f'{outdir}/kmeans{K}/vol_mean.mrc', volm)
    else:
        volm = mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_mean.mrc')[0]

    # load mask
    mask = mrc.parse_mrc(f'{outdir}/mask.mrc')[0].astype(bool)
    log(f'{mask.sum()} voxels in mask')

    # load volumes
    vols = np.array([
        mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0][mask]
        for i in range(K)
    ])
    vols[vols < 0] = 0

    # load umap
    umap = utils.load_pkl(f'{outdir}/umap.pkl')
    ind = np.loadtxt(f'{outdir}/kmeans{K}/centers_ind.txt').astype(int)

    if vol_ind is not None:
        log(f'Filtering to {len(vol_ind)} volumes')
        vols = vols[vol_ind]
        ind = ind[vol_ind]

    # compute PCA
    pca = PCA(dim)
    pca.fit(vols)
    pc = pca.transform(vols)
    utils.save_pkl(pc, f'{outdir}/vol_pca_{K}.pkl')
    utils.save_pkl(pca, f'{outdir}/vol_pca_obj.pkl')
    log('Explained variance ratio:')
    log(pca.explained_variance_ratio_)

    # save rxn coordinates
    for i in range(plot_dim):
        subdir = f'{outdir}/vol_pcs/pc{i+1}'
        if not os.path.exists(subdir):
            os.makedirs(subdir)
        min_, max_ = pc[:, i].min(), pc[:, i].max()
        log((min_, max_))
        for j, val in enumerate(np.linspace(min_, max_, 10, endpoint=True)):
            v = volm.copy()
            v[mask] += pca.components_[i] * val
            mrc.write(f'{subdir}/{j}.mrc', v)

    # which plots to show???
    def plot(i, j):
        plt.figure()
        plt.scatter(pc[:, i], pc[:, j])
        plt.xlabel(
            f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:.3f})')
        plt.ylabel(
            f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:.3f})')
        plt.savefig(f'{outdir}/vol_pca_{K}_{i+1}_{j+1}.png')

    for i in range(plot_dim - 1):
        plot(i, i + 1)

    # clustering
    subdir = f'{outdir}/clustering_L2_{linkage}_{M}'
    if not os.path.exists(subdir):
        os.makedirs(subdir)
    cluster = AgglomerativeClustering(n_clusters=M,
                                      affinity='euclidean',
                                      linkage=linkage)
    labels = cluster.fit_predict(vols)
    utils.save_pkl(labels, f'{subdir}/state_labels.pkl')

    kmeans_labels = utils.load_pkl(f'{outdir}/kmeans{K}/labels.pkl')
    kmeans_counts = Counter(kmeans_labels)
    for i in range(M):
        vol_i = np.where(labels == i)[0]
        log(f'State {i}: {len(vol_i)} volumes')
        if vol_ind is not None:
            vol_i = np.arange(K)[vol_ind][vol_i]
        vol_i_all = np.array([
            mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0]
            for i in vol_i
        ])
        nparticles = np.array([kmeans_counts[i] for i in vol_i])
        vol_i_mean = np.average(vol_i_all, axis=0, weights=nparticles)
        vol_i_std = np.average((vol_i_all - vol_i_mean)**2,
                               axis=0,
                               weights=nparticles)**.5
        mrc.write(f'{subdir}/state_{i}_mean.mrc',
                  vol_i_mean.astype(np.float32))
        mrc.write(f'{subdir}/state_{i}_std.mrc', vol_i_std.astype(np.float32))
        if not os.path.exists(f'{subdir}/state_{i}'):
            os.makedirs(f'{subdir}/state_{i}')
        for v in vol_i:
            os.symlink(f'{outdir}/kmeans{K}/vol_{v:03d}.mrc',
                       f'{subdir}/state_{i}/vol_{v:03d}.mrc')
        particle_ind = analysis.get_ind_for_cluster(kmeans_labels, vol_i)
        log(f'State {i}: {len(particle_ind)} particles')
        if particle_ind_orig is not None:
            utils.save_pkl(particle_ind_orig[particle_ind],
                           f'{subdir}/state_{i}_particle_ind.pkl')
        else:
            utils.save_pkl(particle_ind,
                           f'{subdir}/state_{i}_particle_ind.pkl')

    # plot clustering results
    def hack_barplot(counts_):
        if M <= 20:  # HACK TO GET COLORS
            with sns.color_palette(cmap):
                g = sns.barplot(np.arange(M), counts_)
        else:  # default is husl
            g = sns.barplot(np.arange(M), counts_)
        return g

    plt.figure()
    counts = Counter(labels)
    g = hack_barplot([counts[i] for i in range(M)])
    for i in range(M):
        g.text(i - .1, counts[i] + 2, counts[i])
    plt.xlabel('State')
    plt.ylabel('Count')
    plt.savefig(f'{subdir}/state_volume_counts.png')

    plt.figure()
    particle_counts = [
        np.sum([kmeans_counts[ii] for ii in np.where(labels == i)[0]])
        for i in range(M)
    ]
    g = hack_barplot(particle_counts)
    for i in range(M):
        g.text(i - .1, particle_counts[i] + 2, particle_counts[i])
    plt.xlabel('State')
    plt.ylabel('Count')
    plt.savefig(f'{subdir}/state_particle_counts.png')

    def plot_w_labels(i, j):
        plt.figure()
        plt.scatter(pc[:, i], pc[:, j], c=labels, cmap=cmap)
        plt.xlabel(
            f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:.3f})')
        plt.ylabel(
            f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:.3f})')
        plt.savefig(f'{subdir}/vol_pca_{K}_{i+1}_{j+1}.png')

    for i in range(plot_dim - 1):
        plot_w_labels(i, i + 1)

    def plot_w_labels_annotated(i, j):
        fig, ax = plt.subplots(figsize=(16, 16))
        plt.scatter(pc[:, i], pc[:, j], c=labels, cmap=cmap)
        annots = np.arange(K)
        if vol_ind is not None:
            annots = annots[vol_ind]
        for ii, k in enumerate(annots):
            ax.annotate(str(k), pc[ii, [i, j]] + np.array([.1, .1]))
        plt.xlabel(
            f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:.3f})')
        plt.ylabel(
            f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:.3f})')
        plt.savefig(f'{subdir}/vol_pca_{K}_annotated_{i+1}_{j+1}.png')

    for i in range(plot_dim - 1):
        plot_w_labels_annotated(i, i + 1)

    # plot clusters on UMAP
    umap_i = umap[ind]
    fig, ax = plt.subplots(figsize=(8, 8))
    plt.scatter(umap[:, 0],
                umap[:, 1],
                alpha=.1,
                s=1,
                rasterized=True,
                color='lightgrey')
    colors = get_colors_for_cmap(cmap, M)
    for i in range(M):
        c = umap_i[np.where(labels == i)]
        plt.scatter(c[:, 0], c[:, 1], label=i, color=colors[i])
    plt.legend()
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    plt.savefig(f'{subdir}/umap.png')

    fig, ax = plt.subplots(figsize=(16, 16))
    plt.scatter(umap[:, 0],
                umap[:, 1],
                alpha=.1,
                s=1,
                rasterized=True,
                color='lightgrey')
    plt.scatter(umap_i[:, 0], umap_i[:, 1], c=labels, cmap=cmap)
    annots = np.arange(K)
    if vol_ind is not None:
        annots = annots[vol_ind]
    for i, k in enumerate(annots):
        ax.annotate(str(k), umap_i[i] + np.array([.1, .1]))
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    plt.savefig(f'{subdir}/umap_annotated.png')
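
The vol_pcs loop above renders the mean volume perturbed along each principal component at ten evenly spaced coordinate values. The same traversal idea in miniature, on generic data:

import numpy as np
from sklearn.decomposition import PCA

X = np.random.randn(20, 6)  # toy data standing in for the masked volumes
pca = PCA(2).fit(X)
for val in np.linspace(-2, 2, 5):
    sample = X.mean(axis=0) + pca.components_[0] * val  # move along PC1
    print(val, sample[:3])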