Example No. 1
def analyze_zN(z, outdir, vg, skip_umap=False):
    zdim = z.shape[1]

    # Principal component analysis
    log('Performing principal component analysis...')
    pc, pca = analysis.run_pca(z)
    start, end = np.percentile(pc[:,0],(5,95))
    z_pc1 = analysis.get_pc_traj(pca, z.shape[1], 10, 1, start, end)
    start, end = np.percentile(pc[:,1],(5,95))
    z_pc2 = analysis.get_pc_traj(pca, z.shape[1], 10, 2, start, end)

    # kmeans clustering
    log('K-means clustering...')
    K = 20
    kmeans_labels, centers = analysis.cluster_kmeans(z, K)
    centers, centers_ind = analysis.get_nearest_point(z, centers)
    if not os.path.exists(f'{outdir}/kmeans20'): 
        os.mkdir(f'{outdir}/kmeans20')
    utils.save_pkl(kmeans_labels, f'{outdir}/kmeans20/labels.pkl')
    np.savetxt(f'{outdir}/kmeans20/centers.txt', centers)
    np.savetxt(f'{outdir}/kmeans20/centers_ind.txt', centers_ind, fmt='%d')

    # Generate volumes
    log('Generating volumes...')
    vg.gen_volumes(f'{outdir}/pc1', z_pc1)
    vg.gen_volumes(f'{outdir}/pc2', z_pc2)
    vg.gen_volumes(f'{outdir}/kmeans20', centers)

    # UMAP -- slow step
    if zdim > 2 and not skip_umap:
        log('Running UMAP...')
        umap_emb = analysis.run_umap(z)
        utils.save_pkl(umap_emb, f'{outdir}/umap.pkl')

    # Make some plots
    log('Generating plots...')
    plt.figure(1)
    plt.scatter(pc[:,0], pc[:,1], alpha=.1, s=2)
    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.savefig(f'{outdir}/z_pca.png')
    
    if zdim > 2 and not skip_umap:
        plt.figure(2)
        plt.scatter(umap_emb[:,0], umap_emb[:,1], alpha=.1, s=2)
        plt.xlabel('UMAP1')
        plt.ylabel('UMAP2')
        plt.savefig(f'{outdir}/umap.png')

    analysis.plot_by_cluster(pc[:,0], pc[:,1], K, kmeans_labels, centers_ind=centers_ind, annotate=True)
    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.savefig(f'{outdir}/kmeans20/z_pca.png')

    if zdim > 2 and not skip_umap:
        analysis.plot_by_cluster(umap_emb[:,0], umap_emb[:,1], K, kmeans_labels, centers_ind=centers_ind, annotate=True)
        plt.xlabel('UMAP1')
        plt.ylabel('UMAP2')
        plt.savefig(f'{outdir}/kmeans20/umap.png')
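A minimal standalone sketch of the PC-traversal idea used above (10 points between the 5th and 95th percentile along a principal component, mapped back into latent space) with scikit-learn on synthetic latent vectors; the array names and sizes are illustrative, not the cryodrgn API.

# Sketch only: fit PCA on fake per-particle latents, then walk PC1 between its
# 5th and 95th percentile and invert back to latent space, analogous to
# analysis.get_pc_traj above.
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
z = rng.normal(size=(5000, 8))                # stand-in for per-particle latent vectors

pca = PCA(n_components=z.shape[1]).fit(z)
pc = pca.transform(z)

start, end = np.percentile(pc[:, 0], (5, 95))
grid = np.zeros((10, z.shape[1]))
grid[:, 0] = np.linspace(start, end, 10)      # move only along PC1
z_pc1_traj = pca.inverse_transform(grid)      # latent vectors one could decode into volumes
print(z_pc1_traj.shape)                       # (10, 8)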
Example No. 2
def encoder_latent_shifts(workdir, outdir, epochs, E, LOG):
    '''
    Calculates and plots various metrics characterizing the per-particle latent vectors between successive epochs.

    Inputs
        workdir: path to directory containing cryodrgn training results
        outdir: path to base directory to save outputs
        E: int of epoch from which to evaluate convergence (0-indexed)
            Note that because three epochs are needed to define the two inter-epoch vectors analyzed "per epoch",
            the output contains metrics for E-1 epochs. Accordingly, plot x-axes are labeled from epoch 2 through epoch E.

    Outputs
        pkl of all statistics with shape (E-1, n_metrics)
        png of each statistic plotted over training
    '''
    metrics = ['dot product', 'magnitude', 'cosine distance']

    vector_metrics = np.zeros((E-1, len(metrics)))
    for i in np.arange(E-1):
        flog(f'Calculating vector metrics for epochs {i}-{i+1} and {i+1}-{i+2}', LOG)
        if i == 0:
            z1 = utils.load_pkl(f'{workdir}/z.{i}.pkl')
            z2 = utils.load_pkl(f'{workdir}/z.{i+1}.pkl')
            z3 = utils.load_pkl(f'{workdir}/z.{i+2}.pkl')
        else:
            z1 = z2.copy()
            z2 = z3.copy()
            z3 = utils.load_pkl(workdir + f'/z.{i+2}.pkl')

        diff21 = z2 - z1
        diff32 = z3 - z2

        vector_metrics[i, 0] = np.median(np.einsum('ij,ij->i', diff21, diff32), axis=0)  # median vector dot product
        vector_metrics[i, 1] = np.median(np.linalg.norm(diff32, axis=1), axis=0)  # median vector magnitude
        uv = np.sum(diff32 * diff21, axis=1)
        uu = np.sum(diff32 * diff32, axis=1)
        vv = np.sum(diff21 * diff21, axis=1)
        vector_metrics[i, 2] = np.median(1 - uv / (np.sqrt(uu) * np.sqrt(vv))) #median vector cosine distance

    utils.save_pkl(vector_metrics, f'{outdir}/vector_metrics.pkl')

    fig, axes = plt.subplots(1, len(metrics), figsize=(10,3))
    fig.tight_layout()
    for i,ax in enumerate(axes.flat):
        ax.plot(np.arange(2,E+1), vector_metrics[:,i])
        ax.set_xlabel('epoch')
        ax.set_ylabel(metrics[i])
    plt.savefig(f'{outdir}/plots/02_encoder_latent_vector_shifts.png', dpi=300, format='png', transparent=True, bbox_inches='tight')

    flog(f'Saved latent vector shifts plots to {outdir}/plots/02_encoder_latent_vector_shifts.png', LOG)
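The three shift metrics computed above are easy to reproduce on toy arrays; the following sketch uses plain NumPy with made-up latents for three consecutive epochs and is not the cryodrgn helper itself.

# Per-particle shift metrics between two consecutive inter-epoch difference vectors:
# dot product, magnitude of the later shift, and cosine distance, reduced by the median.
import numpy as np

rng = np.random.default_rng(1)
z1, z2, z3 = (rng.normal(size=(1000, 8)) for _ in range(3))   # latents from epochs i, i+1, i+2

diff21 = z2 - z1
diff32 = z3 - z2

dot = np.einsum('ij,ij->i', diff21, diff32)                   # per-particle dot product
mag = np.linalg.norm(diff32, axis=1)                          # per-particle shift magnitude
cos_dist = 1 - dot / (np.linalg.norm(diff21, axis=1) * np.linalg.norm(diff32, axis=1))

print(np.median(dot), np.median(mag), np.median(cos_dist))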
Example No. 3
def generate_volumes(z, outdir, vg, K):
    # kmeans clustering
    log('Sketching distribution...')
    kmeans_labels, centers = analysis.cluster_kmeans(z,
                                                     K,
                                                     on_data=True,
                                                     reorder=True)
    centers, centers_ind = analysis.get_nearest_point(z, centers)
    if not os.path.exists(f'{outdir}/kmeans{K}'):
        os.mkdir(f'{outdir}/kmeans{K}')
    utils.save_pkl(kmeans_labels, f'{outdir}/kmeans{K}/labels.pkl')
    np.savetxt(f'{outdir}/kmeans{K}/centers.txt', centers)
    np.savetxt(f'{outdir}/kmeans{K}/centers_ind.txt', centers_ind, fmt='%d')
    log('Generating volumes...')
    vg.gen_volumes(f'{outdir}/kmeans{K}', centers)
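A hedged sketch of the same sketching step with scikit-learn and SciPy on synthetic latents: run k-means, then snap each centroid to the nearest real data point, which is roughly the role of analysis.get_nearest_point above.

# Sketch only: k-means centroids are rarely actual data points, so replace each
# centroid with the closest latent vector before generating volumes from it.
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans

rng = np.random.default_rng(2)
z = rng.normal(size=(2000, 8))

K = 5
km = KMeans(n_clusters=K, n_init=10, random_state=0).fit(z)
kmeans_labels = km.labels_

centers_ind = cdist(km.cluster_centers_, z).argmin(axis=1)    # nearest real particle per centroid
centers = z[centers_ind]                                      # on-data representative latents
print(centers_ind)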
Example No. 4
def calculate_CCs(outdir, epochs, labels, chimerax_colors, LOG):
    '''
    Returns the masked map-map correlation between temporally sequential volume pairs outdir/vols.{epochs}, for each class in labels

    Inputs:
        outdir: path to base directory to save outputs
        epochs: array of epochs for which to calculate map-map CCs
        labels: unique identifier for each class of representative latent encodings
        chimerax_colors: approximate colors matching ChimeraX palette to facilitate comparison to volume visualization

    Outputs:
        plot.png of sequential volume pairs map-map CC for each class in labels across training epochs
    '''
    def calc_cc(vol1, vol2):
        '''
        Helper function to calculate the zero-mean correlation coefficient as defined in eq 2 in https://journals.iucr.org/d/issues/2018/09/00/kw5139/index.html
        vol1 and vol2 should be maps of the same box size, structured as numpy arrays with ndim=3, i.e. by loading with cryodrgn.mrc.parse_mrc
        '''
        zmean1 = (vol1 - np.mean(vol1))
        zmean2 = (vol2 - np.mean(vol2))
        cc = (np.sum(zmean1 ** 2) ** -0.5) * (np.sum(zmean2 ** 2) ** -0.5) * np.sum(zmean1 * zmean2)
        return cc

    cc_masked = np.zeros((len(labels), len(epochs) - 1))

    for i in range(len(epochs) - 1):
        for cluster in np.arange(len(labels)):
            vol1, _ = mrc.parse_mrc(f'{outdir}/vols.{epochs[i]}/vol_{cluster:03d}.masked.mrc')
            vol2, _ = mrc.parse_mrc(f'{outdir}/vols.{epochs[i+1]}/vol_{cluster:03d}.masked.mrc')

            cc_masked[cluster, i] = calc_cc(vol1, vol2)

    utils.save_pkl(cc_masked, f'{outdir}/cc_masked.pkl')

    fig, ax = plt.subplots(1, 1)

    ax.set_xlabel('epoch')
    ax.set_ylabel('masked CC')
    for i in range(len(labels)):
        ax.plot(epochs[1:], cc_masked[i,:], c=chimerax_colors[i] * 0.75, linewidth=2.5)
    ax.legend(labels, ncol=3, fontsize='x-small')

    plt.savefig(f'{outdir}/plots/05_decoder_CC.png', dpi=300, format='png', transparent=True, bbox_inches='tight')
    flog(f'Saved map-map correlation plot to {outdir}/plots/05_decoder_CC.png', LOG)
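As a standalone check of the zero-mean correlation coefficient in calc_cc: it coincides with the Pearson correlation of the flattened arrays, so np.corrcoef returns the same value on toy volumes (illustrative sketch, not part of the pipeline above).

# Zero-mean CC on two random "volumes" versus the off-diagonal of np.corrcoef.
import numpy as np

def calc_cc(vol1, vol2):
    zmean1 = vol1 - vol1.mean()
    zmean2 = vol2 - vol2.mean()
    return np.sum(zmean1 * zmean2) / np.sqrt(np.sum(zmean1 ** 2) * np.sum(zmean2 ** 2))

rng = np.random.default_rng(3)
vol1 = rng.normal(size=(32, 32, 32))
vol2 = 0.5 * vol1 + rng.normal(size=(32, 32, 32))

print(calc_cc(vol1, vol2))
print(np.corrcoef(vol1.ravel(), vol2.ravel())[0, 1])          # agrees with the line above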
Example No. 5
def encoder_latent_umaps(workdir, outdir, epochs, n_particles_total, subset, random_seed, use_umap_gpu, random_state, n_epochs_umap, LOG):
    '''
    Calculates UMAP embeddings of subset of particles' selected epochs' latent encodings

    Inputs
        workdir: path to directory containing cryodrgn training results
        outdir: path to base directory to save outputs
        epochs: array of epochs for which to calculate UMAPs
        n_particles_total: int of total number of particles trained
        subset: int, size of subset on which to calculate umap, None means all
        random_seed: int, seed for random selection of subset particles
        use_umap_gpu: bool, whether to use the cuML library to GPU accelerate UMAP calculations (if available in env)
        random_state: int, random state seed used by UMAP for reproducibility at slight cost of performance (None means faster but non-reproducible)
        n_epochs_umap: int, number of epochs to run the GPU-accelerated UMAP optimization (passed to cuML UMAP as n_epochs)

    Outputs
        pkl of each UMAP embedding stored in outdir/umaps/umap.epoch.pkl
        png of all UMAPs

    # apparently running multiple UMAP embeddings (i.e. for each epoch's z.pkl) in parallel on CPU requires difficult backend setup
    # see https://github.com/lmcinnes/umap/issues/707
    # therefore not implemented currently
    '''

    if subset == 'None':
        n_particles_subset = n_particles_total
        flog('Using full particle stack for UMAP', LOG)
    else:
        if random_seed is None:
            random_seed = random.randint(0, 100000)
            random.seed(random_seed)
        else:
            random.seed(random_seed)
        n_particles_subset = min(n_particles_total, int(subset))
        flog(f'Randomly selecting {n_particles_subset} particle subset on which to run UMAP (with random seed {random_seed})', LOG)
    ind_subset = sorted(random.sample(range(0, n_particles_total), k=n_particles_subset))
    utils.save_pkl(ind_subset, outdir + '/ind_subset.pkl')

    for epoch in epochs:
        flog(f'Now calculating UMAP for epoch {epoch} with random_state {random_state}', LOG)
        z = utils.load_pkl(workdir + f'/z.{epoch}.pkl')[ind_subset, :]
        if use_umap_gpu: #using cuML library GPU-accelerated UMAP
            reducer = cuUMAP(random_state=random_state, n_epochs=n_epochs_umap)
            umap_embedding = reducer.fit_transform(z)
        else: #using umap-learn library CPU-bound UMAP
            reducer = umap.UMAP(random_state=random_state)
            umap_embedding = reducer.fit_transform(z)
        utils.save_pkl(umap_embedding, f'{outdir}/umaps/umap.{epoch}.pkl')


    n_cols = int(np.ceil(len(epochs) ** 0.5))
    n_rows = int(np.ceil(len(epochs) / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), sharex='all', sharey='all')
    fig.tight_layout()

    for i, ax in enumerate(axes.flat):
        try:
            umap_embedding = utils.load_pkl(f'{outdir}/umaps/umap.{epochs[i]}.pkl')
            toplot = ax.hexbin(umap_embedding[:, 0], umap_embedding[:, 1], bins='log', mincnt=1)
            ax.set_title(f'epoch {epochs[i]}')
        except IndexError:
            pass
        except FileNotFoundError:
            flog(f'Could not find file {outdir}/umaps/umap.{epochs[i]}.pkl', LOG)
            pass

    if len(axes.shape) == 1:
        axes[0].set_ylabel('UMAP2')
        for a in axes[:]: a.set_xlabel('UMAP1')
    else:
        assert len(axes.shape) == 2 #there are more than one row and column of axes
        for a in axes[:, 0]: a.set_ylabel('UMAP2')
        for a in axes[-1, :]: a.set_xlabel('UMAP1')
    fig.subplots_adjust(right=0.96)
    cbar_ax = fig.add_axes([0.98, 0.15, 0.02, 0.7])
    cbar = fig.colorbar(toplot, cax=cbar_ax)
    cbar.ax.set_ylabel('particle density', rotation=90)

    plt.subplots_adjust(wspace=0.1)
    plt.subplots_adjust(hspace=0.3)
    plt.savefig(f'{outdir}/plots/01_encoder_umaps.png', dpi=300, format='png', transparent=True, bbox_inches='tight')
    flog(f'Saved UMAP distribution plot to {outdir}/plots/01_encoder_umaps.png', LOG)
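A minimal sketch of the subset-then-UMAP pattern above, assuming the umap-learn package is installed and using synthetic latents in place of a z.{epoch}.pkl file; cuML's cuUMAP is the drop-in GPU alternative selected when use_umap_gpu is True.

# Select a reproducible random particle subset, then embed it with CPU UMAP.
import random
import numpy as np
import umap

rng = np.random.default_rng(4)
z_all = rng.normal(size=(20000, 8))                           # stand-in for one epoch's latents

subset = 5000
random.seed(42)
ind_subset = sorted(random.sample(range(z_all.shape[0]), k=min(subset, z_all.shape[0])))

reducer = umap.UMAP(random_state=42)                          # random_state makes the embedding reproducible
umap_embedding = reducer.fit_transform(z_all[ind_subset, :])
print(umap_embedding.shape)                                   # (5000, 2)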
Example No. 6
def calculate_FSCs(outdir, epochs, labels, img_size, chimerax_colors, LOG):
    '''
    Returns the masked FSC between temporally sequential volume pairs outdir/vols.{epochs}, for each class in labels

    Inputs:
        outdir: path to base directory to save outputs
        epochs: array of epochs for which to calculate FSCs
        labels: unique identifier for each class of representative latent encodings
        img_size: box size of input images in pixels
        chimerax_colors: approximate colors matching ChimeraX palette to facilitate comparison to volume visualization

    Outputs:
        plot.png of sequential volume pairs map-map FSC for each class in labels across training epochs
        plot.png of sequential volume pairs map-map FSC at Nyquist for each class in labels across training epochs

    TODO: accelerate via multiprocessing (create iterable list of calc_fsc calls?)

    '''
    def calc_fsc(vol1_path, vol2_path):
        '''
        Helper function to calculate the FSC between two (assumed masked) volumes
        vol1 and vol2 should be maps of the same box size, structured as numpy arrays with ndim=3, i.e. by loading with cryodrgn.mrc.parse_mrc
        '''
        # load masked volumes in fourier space
        vol1, _ = mrc.parse_mrc(vol1_path)
        vol2, _ = mrc.parse_mrc(vol2_path)

        vol1_ft = fft.fftn_center(vol1)
        vol2_ft = fft.fftn_center(vol2)

        # define fourier grid and label into shells
        D = vol1.shape[0]
        x = np.arange(-D // 2, D // 2)
        x0, x1, x2 = np.meshgrid(x, x, x, indexing='ij')
        r = np.sqrt(x0 ** 2 + x1 ** 2 + x2 ** 2)
        r_max = D // 2  # sphere inscribed within volume box
        r_step = 1  # int(np.min(r[r>0]))
        bins = np.arange(0, r_max, r_step)
        bin_labels = np.searchsorted(bins, r, side='right')

        # calculate the FSC via labeled shells
        num = ndimage.sum(np.real(vol1_ft * np.conjugate(vol2_ft)), labels=bin_labels, index=bins + 1)
        den1 = ndimage.sum(np.abs(vol1_ft) ** 2, labels=bin_labels, index=bins + 1)
        den2 = ndimage.sum(np.abs(vol2_ft) ** 2, labels=bin_labels, index=bins + 1)
        fsc = num / np.sqrt(den1 * den2)

        x = bins / D  # x axis should be spatial frequency in 1/px
        return x, fsc

    # calculate masked FSCs for all volumes
    fsc_masked = np.zeros((len(labels), len(epochs) - 1, img_size // 2))

    for cluster in range(len(labels)):
        flog(f'Calculating all FSCs for cluster {cluster}', LOG)

        for i in range(len(epochs) - 1):
            vol1_path = f'{outdir}/vols.{epochs[i]}/vol_{cluster:03d}.masked.mrc'
            vol2_path = f'{outdir}/vols.{epochs[i+1]}/vol_{cluster:03d}.masked.mrc'

            x, fsc_masked[cluster, i, :] = calc_fsc(vol1_path, vol2_path)

    utils.save_pkl(fsc_masked, f'{outdir}/fsc_masked.pkl')
    utils.save_pkl(x, f'{outdir}/fsc_xaxis.pkl')

    # plot all fscs
    n_cols = int(np.ceil(len(labels) ** 0.5))
    n_rows = int(np.ceil(len(labels) / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), sharex='all', sharey='all')
    fig.tight_layout()
    for cluster, ax in enumerate(axes.flat):
        try:
            colors = plt.cm.viridis(np.linspace(0, 1, len(epochs) - 1))
            ax.set_ylim(0, 1.02)
            ax.set_title(f'maximum {labels[cluster]}')
            legend = []
            for i in range(len(epochs) - 1):
                ax.plot(x, fsc_masked[cluster, i, :], color=colors[i])
                legend.append(f'epoch {epochs[i+1]}')
        except IndexError:
            pass

    x_center, y_center = n_cols//2, n_rows//2
    axes[y_center, 0].set_ylabel('FSC')
    axes[-1,x_center].set_xlabel('frequency (1/px)')
    axes[-1, 0].legend(legend, loc='lower left', ncol=2, fontsize=6.5)
    plt.subplots_adjust(hspace=0.3)
    plt.subplots_adjust(wspace=0.1)
    plt.savefig(f'{outdir}/plots/06_decoder_FSC.png', dpi=300, format='png', transparent=True, bbox_inches='tight')
    flog(f'Saved map-map FSC plot to {outdir}/plots/06_decoder_FSC.png', LOG)

    # plot all FSCs at Nyquist only
    fig, ax = plt.subplots(1, 1)

    ax.set_xlabel('epoch')
    ax.set_ylabel('masked FSC at nyquist')
    for i in range(len(labels)):
        ax.plot(epochs[1:], fsc_masked[i, :, -1], c=chimerax_colors[i] * 0.75, linewidth=2.5)
    ax.legend(labels, ncol=3, fontsize='x-small')

    plt.savefig(f'{outdir}/plots/07_decoder_FSC-nyquist.png', dpi=300, format='png', transparent=True, bbox_inches='tight')
    flog(f'Saved map-map FSC (Nyquist) plot to {outdir}/plots/07_decoder_FSC-nyquist.png', LOG)
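A standalone sanity check of the shell-based FSC in calc_fsc, using NumPy/SciPy only and a centered FFT standing in for fft.fftn_center: the FSC of a volume with itself should be 1 in every shell.

# Same shell-binning scheme as calc_fsc, applied to a small random volume against itself.
import numpy as np
from scipy import ndimage

D = 32
rng = np.random.default_rng(5)
vol = rng.normal(size=(D, D, D))
vol_ft = np.fft.fftshift(np.fft.fftn(vol))                    # centered FFT

x = np.arange(-D // 2, D // 2)
x0, x1, x2 = np.meshgrid(x, x, x, indexing='ij')
r = np.sqrt(x0 ** 2 + x1 ** 2 + x2 ** 2)
bins = np.arange(0, D // 2, 1)
bin_labels = np.searchsorted(bins, r, side='right')

num = ndimage.sum(np.real(vol_ft * np.conjugate(vol_ft)), labels=bin_labels, index=bins + 1)
den = ndimage.sum(np.abs(vol_ft) ** 2, labels=bin_labels, index=bins + 1)
print(np.allclose(num / den, 1.0))                            # True: FSC(v, v) == 1 everywhere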
Example No. 7
def analyze_zN(z, outdir, vg, skip_umap=False, num_pcs=2, num_ksamples=20):
    zdim = z.shape[1]

    # Principal component analysis
    log('Performing principal component analysis...')
    pc, pca = analysis.run_pca(z)
    log('Generating volumes...')
    for i in range(num_pcs):
        start, end = np.percentile(pc[:, i], (5, 95))
        z_pc = analysis.get_pc_traj(pca, z.shape[1], 10, i + 1, start, end)
        vg.gen_volumes(f'{outdir}/pc{i+1}', z_pc)

    # kmeans clustering
    log('K-means clustering...')
    K = num_ksamples
    kmeans_labels, centers = analysis.cluster_kmeans(z, K)
    centers, centers_ind = analysis.get_nearest_point(z, centers)
    if not os.path.exists(f'{outdir}/kmeans{K}'):
        os.mkdir(f'{outdir}/kmeans{K}')
    utils.save_pkl(kmeans_labels, f'{outdir}/kmeans{K}/labels.pkl')
    np.savetxt(f'{outdir}/kmeans{K}/centers.txt', centers)
    np.savetxt(f'{outdir}/kmeans{K}/centers_ind.txt', centers_ind, fmt='%d')
    log('Generating volumes...')
    vg.gen_volumes(f'{outdir}/kmeans{K}', centers)

    # UMAP -- slow step
    if zdim > 2 and not skip_umap:
        log('Running UMAP...')
        umap_emb = analysis.run_umap(z)
        utils.save_pkl(umap_emb, f'{outdir}/umap.pkl')

    # Make some plots
    log('Generating plots...')
    plt.figure(1)
    g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], alpha=.1, s=2)
    g.set_axis_labels('PC1', 'PC2')
    plt.tight_layout()
    plt.savefig(f'{outdir}/z_pca.png')

    plt.figure(2)
    g = sns.jointplot(x=pc[:, 0], y=pc[:, 1], kind='hex')
    g.set_axis_labels('PC1', 'PC2')
    plt.tight_layout()
    plt.savefig(f'{outdir}/z_pca_hexbin.png')

    if zdim > 2 and not skip_umap:
        plt.figure(3)
        g = sns.jointplot(x=umap_emb[:, 0], y=umap_emb[:, 1], alpha=.1, s=2)
        g.set_axis_labels('UMAP1', 'UMAP2')
        plt.tight_layout()
        plt.savefig(f'{outdir}/umap.png')

        plt.figure(4)
        g = sns.jointplot(x=umap_emb[:, 0], y=umap_emb[:, 1], kind='hex')
        g.set_axis_labels('UMAP1', 'UMAP2')
        plt.tight_layout()
        plt.savefig(f'{outdir}/umap_hexbin.png')

    analysis.scatter_annotate(pc[:, 0],
                              pc[:, 1],
                              centers_ind=centers_ind,
                              annotate=True)
    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.savefig(f'{outdir}/kmeans{K}/z_pca.png')

    g = analysis.scatter_annotate_hex(pc[:, 0],
                                      pc[:, 1],
                                      centers_ind=centers_ind,
                                      annotate=True)
    g.set_axis_labels('PC1', 'PC2')
    plt.tight_layout()
    plt.savefig(f'{outdir}/kmeans{K}/z_pca_hex.png')

    if zdim > 2 and not skip_umap:
        analysis.scatter_annotate(umap_emb[:, 0],
                                  umap_emb[:, 1],
                                  centers_ind=centers_ind,
                                  annotate=True)
        plt.xlabel('UMAP1')
        plt.ylabel('UMAP2')
        plt.savefig(f'{outdir}/kmeans{K}/umap.png')

        g = analysis.scatter_annotate_hex(umap_emb[:, 0],
                                          umap_emb[:, 1],
                                          centers_ind=centers_ind,
                                          annotate=True)
        g.set_axis_labels('UMAP1', 'UMAP2')
        plt.tight_layout()
        plt.savefig(f'{outdir}/kmeans{K}/umap_hex.png')

    for i in range(num_pcs):
        if zdim > 2 and not skip_umap:
            analysis.scatter_color(umap_emb[:, 0],
                                   umap_emb[:, 1],
                                   pc[:, i],
                                   label=f'PC{i+1}')
            plt.xlabel('UMAP1')
            plt.ylabel('UMAP2')
            plt.tight_layout()
            plt.savefig(f'{outdir}/pc{i+1}/umap.png')
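The final loop colors the UMAP embedding by each PC value via analysis.scatter_color; a bare matplotlib equivalent on synthetic data looks roughly like the following (illustrative only, not the cryodrgn plotting helper).

# Scatter a 2D embedding with point color driven by a per-particle scalar (here a fake PC1).
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(6)
emb = rng.normal(size=(5000, 2))                              # stand-in for a UMAP embedding
pc1 = emb[:, 0] + 0.1 * rng.normal(size=5000)                 # stand-in for PC1 values

fig, ax = plt.subplots()
sc = ax.scatter(emb[:, 0], emb[:, 1], c=pc1, s=2, alpha=.5, cmap='viridis')
fig.colorbar(sc, ax=ax, label='PC1')
ax.set_xlabel('UMAP1')
ax.set_ylabel('UMAP2')
fig.savefig('umap_colored_by_pc1.png', dpi=150)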
Example No. 8
def analyze_volumes(outdir,
                    K,
                    dim,
                    M,
                    linkage,
                    vol_ind=None,
                    plot_dim=5,
                    particle_ind_orig=None):
    cmap = choose_cmap(M)

    # load mean volume, compute it if it does not exist
    if not os.path.exists(f'{outdir}/kmeans{K}/vol_mean.mrc'):
        volm = np.array([
            mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0]
            for i in range(K)
        ]).mean(axis=0)
        mrc.write(f'{outdir}/kmeans{K}/vol_mean.mrc', volm)
    else:
        volm = mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_mean.mrc')[0]

    # load mask
    mask = mrc.parse_mrc(f'{outdir}/mask.mrc')[0].astype(bool)
    log(f'{mask.sum()} voxels in mask')

    # load volumes
    vols = np.array([
        mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0][mask]
        for i in range(K)
    ])
    vols[vols < 0] = 0

    # load umap
    umap = utils.load_pkl(f'{outdir}/umap.pkl')
    ind = np.loadtxt(f'{outdir}/kmeans{K}/centers_ind.txt').astype(int)

    if vol_ind is not None:
        log(f'Filtering to {len(vol_ind)} volumes')
        vols = vols[vol_ind]
        ind = ind[vol_ind]

    # compute PCA
    pca = PCA(dim)
    pca.fit(vols)
    pc = pca.transform(vols)
    utils.save_pkl(pc, f'{outdir}/vol_pca_{K}.pkl')
    utils.save_pkl(pca, f'{outdir}/vol_pca_obj.pkl')
    log('Explained variance ratio:')
    log(pca.explained_variance_ratio_)

    # save rxn coordinates
    for i in range(plot_dim):
        subdir = f'{outdir}/vol_pcs/pc{i+1}'
        if not os.path.exists(subdir):
            os.makedirs(subdir)
        min_, max_ = pc[:, i].min(), pc[:, i].max()
        log((min_, max_))
        for j, val in enumerate(np.linspace(min_, max_, 10, endpoint=True)):
            v = volm.copy()
            v[mask] += pca.components_[i] * val
            mrc.write(f'{subdir}/{j}.mrc', v)

    # which plots to show???
    def plot(i, j):
        plt.figure()
        plt.scatter(pc[:, i], pc[:, j])
        plt.xlabel(
            f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:.3f})')
        plt.ylabel(
            f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:.3f})')
        plt.savefig(f'{outdir}/vol_pca_{K}_{i+1}_{j+1}.png')

    for i in range(plot_dim - 1):
        plot(i, i + 1)

    # clustering
    subdir = f'{outdir}/clustering_L2_{linkage}_{M}'
    if not os.path.exists(subdir):
        os.makedirs(subdir)
    cluster = AgglomerativeClustering(n_clusters=M,
                                      affinity='euclidean',
                                      linkage=linkage)
    labels = cluster.fit_predict(vols)
    utils.save_pkl(labels, f'{subdir}/state_labels.pkl')

    kmeans_labels = utils.load_pkl(f'{outdir}/kmeans{K}/labels.pkl')
    kmeans_counts = Counter(kmeans_labels)
    for i in range(M):
        vol_i = np.where(labels == i)[0]
        log(f'State {i}: {len(vol_i)} volumes')
        if vol_ind is not None:
            vol_i = np.arange(K)[vol_ind][vol_i]
        vol_i_all = np.array([
            mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{vi:03d}.mrc')[0]
            for vi in vol_i
        ])
        nparticles = np.array([kmeans_counts[vi] for vi in vol_i])
        vol_i_mean = np.average(vol_i_all, axis=0, weights=nparticles)
        vol_i_std = np.average((vol_i_all - vol_i_mean)**2,
                               axis=0,
                               weights=nparticles)**.5
        mrc.write(f'{subdir}/state_{i}_mean.mrc',
                  vol_i_mean.astype(np.float32))
        mrc.write(f'{subdir}/state_{i}_std.mrc', vol_i_std.astype(np.float32))
        if not os.path.exists(f'{subdir}/state_{i}'):
            os.makedirs(f'{subdir}/state_{i}')
        for v in vol_i:
            os.symlink(f'{outdir}/kmeans{K}/vol_{v:03d}.mrc',
                       f'{subdir}/state_{i}/vol_{v:03d}.mrc')
        particle_ind = analysis.get_ind_for_cluster(kmeans_labels, vol_i)
        log(f'State {i}: {len(particle_ind)} particles')
        if particle_ind_orig is not None:
            utils.save_pkl(particle_ind_orig[particle_ind],
                           f'{subdir}/state_{i}_particle_ind.pkl')
        else:
            utils.save_pkl(particle_ind,
                           f'{subdir}/state_{i}_particle_ind.pkl')

    # plot clustering results
    def hack_barplot(counts_):
        if M <= 20:  # HACK TO GET COLORS
            with sns.color_palette(cmap):
                g = sns.barplot(np.arange(M), counts_)
        else:  # default is husl
            g = sns.barplot(np.arange(M), counts_)
        return g

    plt.figure()
    counts = Counter(labels)
    g = hack_barplot([counts[i] for i in range(M)])
    for i in range(M):
        g.text(i - .1, counts[i] + 2, counts[i])
    plt.xlabel('State')
    plt.ylabel('Count')
    plt.savefig(f'{subdir}/state_volume_counts.png')

    plt.figure()
    particle_counts = [
        np.sum([kmeans_counts[ii] for ii in np.where(labels == i)[0]])
        for i in range(M)
    ]
    g = hack_barplot(particle_counts)
    for i in range(M):
        g.text(i - .1, particle_counts[i] + 2, particle_counts[i])
    plt.xlabel('State')
    plt.ylabel('Count')
    plt.savefig(f'{subdir}/state_particle_counts.png')

    def plot_w_labels(i, j):
        plt.figure()
        plt.scatter(pc[:, i], pc[:, j], c=labels, cmap=cmap)
        plt.xlabel(
            f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:.3f})')
        plt.ylabel(
            f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:.3f})')
        plt.savefig(f'{subdir}/vol_pca_{K}_{i+1}_{j+1}.png')

    for i in range(plot_dim - 1):
        plot_w_labels(i, i + 1)

    def plot_w_labels_annotated(i, j):
        fig, ax = plt.subplots(figsize=(16, 16))
        plt.scatter(pc[:, i], pc[:, j], c=labels, cmap=cmap)
        annots = np.arange(K)
        if vol_ind is not None:
            annots = annots[vol_ind]
        for ii, k in enumerate(annots):
            ax.annotate(str(k), pc[ii, [i, j]] + np.array([.1, .1]))
        plt.xlabel(
            f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:.3f})')
        plt.ylabel(
            f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:.3f})')
        plt.savefig(f'{subdir}/vol_pca_{K}_annotated_{i+1}_{j+1}.png')

    for i in range(plot_dim - 1):
        plot_w_labels_annotated(i, i + 1)

    # plot clusters on UMAP
    umap_i = umap[ind]
    fig, ax = plt.subplots(figsize=(8, 8))
    plt.scatter(umap[:, 0],
                umap[:, 1],
                alpha=.1,
                s=1,
                rasterized=True,
                color='lightgrey')
    colors = get_colors_for_cmap(cmap, M)
    for i in range(M):
        c = umap_i[np.where(labels == i)]
        plt.scatter(c[:, 0], c[:, 1], label=i, color=colors[i])
    plt.legend()
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    plt.savefig(f'{subdir}/umap.png')

    fig, ax = plt.subplots(figsize=(16, 16))
    plt.scatter(umap[:, 0],
                umap[:, 1],
                alpha=.1,
                s=1,
                rasterized=True,
                color='lightgrey')
    plt.scatter(umap_i[:, 0], umap_i[:, 1], c=labels, cmap=cmap)
    annots = np.arange(K)
    if vol_ind is not None:
        annots = annots[vol_ind]
    for i, k in enumerate(annots):
        ax.annotate(str(k), umap_i[i] + np.array([.1, .1]))
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    plt.savefig(f'{subdir}/umap_annotated.png')
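For orientation, a compressed sketch of the volume-space analysis above on synthetic data: PCA on masked voxel vectors, agglomerative clustering into M states, and a particle-count-weighted mean volume per state. Shapes, counts, and names are made up for illustration and do not reflect real map data.

# Sketch only: replace the .mrc loading with random volumes and a random mask.
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA

rng = np.random.default_rng(7)
K, M, D = 20, 3, 24
vols_full = rng.normal(size=(K, D, D, D))                     # stand-in for the kmeans{K} volumes
mask = rng.random((D, D, D)) > 0.5                            # stand-in for mask.mrc
vols = vols_full[:, mask]                                     # (K, n_masked_voxels)
vols[vols < 0] = 0

pc = PCA(n_components=5).fit_transform(vols)                  # volume-space PCA coordinates

labels = AgglomerativeClustering(n_clusters=M, linkage='average').fit_predict(vols)

nparticles = rng.integers(50, 500, size=K)                    # stand-in for per-k-means-class particle counts
for state in range(M):
    members = np.where(labels == state)[0]
    state_mean = np.average(vols_full[members], axis=0, weights=nparticles[members])
    print(state, len(members), state_mean.shape)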