def main(args):
    assert args.o.endswith('.star')
    assert args.particles.endswith('.mrcs'), \
        "Only a single particle stack as an .mrcs is currently supported"
    particles = mrc.parse_mrc(args.particles, lazy=True)[0]
    ctf = utils.load_pkl(args.ctf)
    assert ctf.shape[1] == 9, "Incorrect CTF pkl format"
    assert len(particles) == len(ctf), \
        f"{len(particles)} != {len(ctf)}, Number of particles != number of CTF parameters"
    if args.poses:
        poses = utils.load_pkl(args.poses)
        assert len(particles) == len(poses[0]), \
            f"{len(particles)} != {len(poses[0])}, Number of particles != number of poses"
    log('{} particles'.format(len(particles)))

    if args.ind:
        ind = utils.load_pkl(args.ind)
        log(f'Filtering to {len(ind)} particles')
        ctf = ctf[ind]
        if args.poses:
            poses = (poses[0][ind], poses[1][ind])
    else:
        ind = np.arange(len(particles))

    # _rlnImageName
    ind += 1  # CHANGE TO 1-BASED INDEXING
    image_name = os.path.basename(args.particles) if not args.full_path else args.particles
    names = [f'{i}@{image_name}' for i in ind]

    ctf = ctf[:, 2:]

    # convert poses
    if args.poses:
        eulers = utils.R_to_relion_scipy(poses[0])
        D = particles[0].get().shape[0]
        trans = poses[1] * D  # convert from fraction to pixels

    data = {HEADERS[0]: names}
    for i in range(7):
        data[HEADERS[i + 1]] = ctf[:, i]
    if args.poses:
        for i in range(3):
            data[POSE_HDRS[i]] = eulers[:, i]
        for i in range(2):
            data[POSE_HDRS[3 + i]] = trans[:, i]
    df = pd.DataFrame(data=data)

    headers = HEADERS + POSE_HDRS if args.poses else HEADERS
    s = starfile.Starfile(headers, df)
    s.write(args.o)
def encoder_latent_shifts(workdir, outdir, epochs, E, LOG):
    '''
    Calculates and plots various metrics characterizing the per-particle latent
    vectors between successive epochs.

    Inputs
        workdir: path to directory containing cryodrgn training results
        outdir: path to base directory to save outputs
        E: int of epoch from which to evaluate convergence (0-indexed)
           Note that because three successive epochs are needed to define the two
           inter-epoch vectors analyzed "per epoch", the output contains metrics
           for E-1 epochs. Accordingly, plot x-axes are labeled from epoch 2 to epoch E.

    Outputs
        pkl of all statistics of shape (E-1, n_metrics)
        png of each statistic plotted over training
    '''
    metrics = ['dot product', 'magnitude', 'cosine distance']
    vector_metrics = np.zeros((E - 1, len(metrics)))
    for i in np.arange(E - 1):
        flog(f'Calculating vector metrics for epochs {i}-{i+1} and {i+1}-{i+2}', LOG)
        if i == 0:
            z1 = utils.load_pkl(f'{workdir}/z.{i}.pkl')
            z2 = utils.load_pkl(f'{workdir}/z.{i+1}.pkl')
            z3 = utils.load_pkl(f'{workdir}/z.{i+2}.pkl')
        else:
            z1 = z2.copy()
            z2 = z3.copy()
            z3 = utils.load_pkl(workdir + f'/z.{i+2}.pkl')

        diff21 = z2 - z1
        diff32 = z3 - z2

        vector_metrics[i, 0] = np.median(np.einsum('ij,ij->i', diff21, diff32))  # median vector dot product
        vector_metrics[i, 1] = np.median(np.linalg.norm(diff32, axis=1))  # median vector magnitude

        uv = np.sum(diff32 * diff21, axis=1)
        uu = np.sum(diff32 * diff32, axis=1)
        vv = np.sum(diff21 * diff21, axis=1)
        vector_metrics[i, 2] = np.median(1 - uv / (np.sqrt(uu) * np.sqrt(vv)))  # median vector cosine distance

    utils.save_pkl(vector_metrics, f'{outdir}/vector_metrics.pkl')

    fig, axes = plt.subplots(1, len(metrics), figsize=(10, 3))
    fig.tight_layout()
    for i, ax in enumerate(axes.flat):
        ax.plot(np.arange(2, E + 1), vector_metrics[:, i])
        ax.set_xlabel('epoch')
        ax.set_ylabel(metrics[i])
    plt.savefig(f'{outdir}/plots/02_encoder_latent_vector_shifts.png',
                dpi=300, format='png', transparent=True, bbox_inches='tight')
    flog(f'Saved latent vector shifts plots to {outdir}/plots/02_encoder_latent_vector_shifts.png', LOG)
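# A toy check of the three shift metrics computed above (a sketch; z1..z3
# stand in for consecutive-epoch latent arrays of shape (n_particles, zdim)).
import numpy as np

rng = np.random.default_rng(0)
z1, z2, z3 = (rng.normal(size=(100, 8)) for _ in range(3))
diff21, diff32 = z2 - z1, z3 - z2

dot = np.median(np.einsum('ij,ij->i', diff21, diff32))      # median dot product
mag = np.median(np.linalg.norm(diff32, axis=1))             # median shift magnitude
uv = np.sum(diff32 * diff21, axis=1)
uu = np.sum(diff32 * diff32, axis=1)
vv = np.sum(diff21 * diff21, axis=1)
cos_dist = np.median(1 - uv / (np.sqrt(uu) * np.sqrt(vv)))  # median cosine distance
print(dot, mag, cos_dist)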
def main(args):
    assert args.o.endswith('.star')
    particles = dataset.load_particles(args.particles, lazy=True, datadir=args.datadir)
    ctf = utils.load_pkl(args.ctf)
    assert ctf.shape[1] == 9, "Incorrect CTF pkl format"
    assert len(particles) == len(ctf), \
        f"{len(particles)} != {len(ctf)}, Number of particles != number of CTF parameters"
    if args.poses:
        poses = utils.load_pkl(args.poses)
        assert len(particles) == len(poses[0]), \
            f"{len(particles)} != {len(poses[0])}, Number of particles != number of poses"
    log('{} particles'.format(len(particles)))

    if args.ind:
        ind = utils.load_pkl(args.ind)
        log(f'Filtering to {len(ind)} particles')
        particles = [particles[ii] for ii in ind]
        ctf = ctf[ind]
        if args.poses:
            poses = (poses[0][ind], poses[1][ind])
    else:
        ind = np.arange(len(particles))

    ind += 1  # CHANGE TO 1-BASED INDEXING
    image_names = [img.fname for img in particles]
    if args.full_path:
        image_names = [os.path.abspath(img.fname) for img in particles]
    names = [f'{i}@{name}' for i, name in zip(ind, image_names)]

    ctf = ctf[:, 2:]

    # convert poses
    if args.poses:
        eulers = utils.R_to_relion_scipy(poses[0])
        D = particles[0].get().shape[0]
        trans = poses[1] * D  # convert from fraction to pixels

    data = {HEADERS[0]: names}
    for i in range(7):
        data[HEADERS[i + 1]] = ctf[:, i]
    if args.poses:
        for i in range(3):
            data[POSE_HDRS[i]] = eulers[:, i]
        for i in range(2):
            data[POSE_HDRS[3 + i]] = trans[:, i]
    df = pd.DataFrame(data=data)

    headers = HEADERS + POSE_HDRS if args.poses else HEADERS
    s = starfile.Starfile(headers, df)
    s.write(args.o)
def main(args):
    x = dataset.load_particles(args.input, lazy=True)
    log(f'Loaded {len(x)} particles')  # lazy loading returns a list, which has no .shape
    ind = utils.load_pkl(args.ind)
    x = np.array([x[i].get() for i in ind])
    log(x.shape)
    mrc.write(args.o, x)
def main(args):
    x = dataset.load_particles(args.input, lazy=True)
    log(f'Loaded {len(x)} particles')
    ind = utils.load_pkl(args.ind)
    x = np.array([x[i].get() for i in ind])
    log(f'New stack dimensions: {x.shape}')
    mrc.write(args.o, x)
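# A minimal sketch of producing the --ind input consumed above: a pickled
# integer array of indices into the original stack (filename hypothetical).
import pickle
import numpy as np

keep = np.arange(0, 10000, 2)  # e.g. keep every other particle
with open('ind.pkl', 'wb') as f:
    pickle.dump(keep, f)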
def main(args):
    # load particles
    particles = dataset.load_particles(args.mrcs, datadir=args.datadir)
    log(particles.shape)
    Nimg, D, D = particles.shape

    trans = utils.load_pkl(args.trans)
    if type(trans) is tuple:
        trans = trans[1]
    trans *= args.tscale
    assert np.all(trans <= 1), \
        "ERROR: Old pose format detected. Translations must be in units of fraction of box."
    trans *= D  # convert to pixels
    assert len(trans) == Nimg

    xx, yy = np.meshgrid(np.arange(-D / 2, D / 2), np.arange(-D / 2, D / 2))
    TCOORD = np.stack([xx, yy], axis=2) / D  # DxDx2

    imgs = []
    for ii in range(Nimg):
        ff = fft.fft2_center(particles[ii])
        tfilt = np.dot(TCOORD, trans[ii]) * -2 * np.pi
        tfilt = np.cos(tfilt) + np.sin(tfilt) * 1j  # exp(-2*pi*i*f.t) phase ramp
        ff *= tfilt
        img = fft.ifftn_center(ff)
        imgs.append(img)

    imgs = np.asarray(imgs).astype(np.float32)
    mrc.write(args.o, imgs)

    if args.out_png:
        plot_projections(args.out_png, imgs[:9])
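# Self-contained check of the phase ramp applied above: multiplying an
# image's FT by exp(-2*pi*i*f.t) translates the image by t pixels. A sketch
# using plain numpy FFTs rather than the repo's centered fft helpers.
import numpy as np

D = 32
img = np.zeros((D, D))
img[10, 12] = 1.0                       # delta at (row, col) = (10, 12)
t = np.array([3.0, 5.0])                # shift in pixels (x, y)

fx, fy = np.meshgrid(np.fft.fftfreq(D), np.fft.fftfreq(D))
ramp = np.exp(-2j * np.pi * (fx * t[0] + fy * t[1]))
shifted = np.fft.ifft2(np.fft.fft2(img) * ramp).real

i, j = np.unravel_index(np.argmax(shifted), shifted.shape)
assert (i, j) == (15, 15)               # row += t[1], col += t[0]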
def main(args):
    s = starfile.Starfile.load(args.input)
    ind = utils.load_pkl(args.ind)
    log('{} particles'.format(len(s.df)))
    s.df = s.df.loc[ind]
    log(f'Filtered to {len(s.df)} particles')
    s.write(args.o)
def __init__(self, pose_pkl):
    poses = utils.load_pkl(pose_pkl)
    self.rots = torch.tensor(poses[0])
    self.trans = poses[1]
    self.N = len(poses[0])
    assert self.rots.shape == (self.N, 3, 3)
    assert self.trans.shape == (self.N, 2)
    assert self.trans.max() < 1
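# Sketch of writing a pose pkl in the format this loader expects: a
# (rotations, translations) tuple with rotations of shape (N, 3, 3) and
# translations of shape (N, 2) in units of fraction of the box (hence < 1).
import pickle
import numpy as np

N = 100
rots = np.tile(np.eye(3, dtype=np.float32), (N, 1, 1))  # identity rotations
trans = np.zeros((N, 2), dtype=np.float32)              # no shift
with open('pose.pkl', 'wb') as f:
    pickle.dump((rots, trans), f)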
def main(args):
    t1 = dt.now()
    E = args.epoch
    workdir = args.workdir
    zfile = f'{workdir}/z.{E}.pkl'
    weights = f'{workdir}/weights.{E}.pkl'
    config = f'{workdir}/config.pkl'
    outdir = f'{workdir}/analyze.{E}'
    if E == -1:
        zfile = f'{workdir}/z.pkl'
        weights = f'{workdir}/weights.pkl'
        outdir = f'{workdir}/analyze'
    if args.outdir:
        outdir = args.outdir
    log(f'Saving results to {outdir}')
    if not os.path.exists(outdir):
        os.mkdir(outdir)

    z = utils.load_pkl(zfile)
    zdim = z.shape[1]

    vol_args = dict(Apix=args.Apix, downsample=args.downsample, flip=args.flip, cuda=args.device)
    vg = VolumeGenerator(weights, config, vol_args, skip_vol=args.skip_vol)

    if zdim == 1:
        analyze_z1(z, outdir, vg)
    else:
        analyze_zN(z, outdir, vg, skip_umap=args.skip_umap, num_pcs=args.pc, num_ksamples=args.ksample)

    # copy over template if file doesn't exist
    out_ipynb = f'{outdir}/cryoDRGN_viz.ipynb'
    if not os.path.exists(out_ipynb):
        log('Creating jupyter notebook...')
        ipynb = f'{cryodrgn._ROOT}/templates/cryoDRGN_viz_template.ipynb'
        cmd = f'cp {ipynb} {out_ipynb}'
        subprocess.check_call(cmd, shell=True)
    else:
        log(f'{out_ipynb} already exists. Skipping')
    log(out_ipynb)

    # copy over template if file doesn't exist
    out_ipynb = f'{outdir}/cryoDRGN_filtering.ipynb'
    if not os.path.exists(out_ipynb):
        log('Creating jupyter notebook...')
        ipynb = f'{cryodrgn._ROOT}/templates/cryoDRGN_filtering_template.ipynb'
        cmd = f'cp {ipynb} {out_ipynb}'
        subprocess.check_call(cmd, shell=True)
    else:
        log(f'{out_ipynb} already exists. Skipping')
    log(out_ipynb)

    log(f'Finished in {dt.now()-t1}')
def main(args):
    imgs = dataset.load_particles(args.mrcs, lazy=True, datadir=args.datadir)
    ctf_params = utils.load_pkl(args.ctf_params)
    assert len(imgs) == len(ctf_params)

    D = imgs[0].get().shape[0]
    fx, fy = np.meshgrid(np.linspace(-.5, .5, D, endpoint=False),
                         np.linspace(-.5, .5, D, endpoint=False))
    freqs = np.stack([fx.ravel(), fy.ravel()], 1)

    imgs_flip = np.empty((len(imgs), D, D), dtype=np.float32)
    for i in range(len(imgs)):
        if i % 1000 == 0:
            print(i)
        c = ctf.compute_ctf_np(freqs / ctf_params[i, 0], *ctf_params[i, 1:])
        c = c.reshape((D, D))
        ff = fft.fft2_center(imgs[i].get())
        ff *= np.sign(c)  # phase flip by the sign of the CTF
        img = fft.ifftn_center(ff)
        imgs_flip[i] = img.astype(np.float32)

    mrc.write(args.o, imgs_flip)
def sketch_via_umap_local_maxima(outdir, E, LOG, n_bins=30, smooth=True, smooth_width=1,
                                 pruned_maxima=12, radius=5, final_maxima=10):
    '''
    Sketch the UMAP embedding of epoch E latent space via local maxima finding

    Inputs:
        E: epoch for which the (subset, see Convergence 1) umap distribution will be sketched for local maxima
        n_bins: the number of bins along UMAP1 and UMAP2
        smooth: whether to smooth the 2D histogram (aids local maxima finding for particularly continuous distributions)
        smooth_width: scalar multiple of one-bin-width defining sigma for gaussian kernel smoothing
        pruned_maxima: max number of local maxima above which pruning will be attempted
        radius: radius in bin-space (Euclidean distance) below which points are considered poorly-separated
                and are candidates for pruning
        final_maxima: the count of local maxima with highest associated bin count that will be returned as final to the user

    Outputs
        binned_ptcls_mask: binary mask of shape (n_particles_total, n_local_maxima) labeling all particles
                           in the bin and neighboring 8 bins of a local maxima
        labels: a unique letter assigned to each local maxima
    '''
    def make_edges(umap, n_bins):
        '''
        Helper function to create two 1-D arrays defining @n_bins bin edges along axes x and y
        '''
        xedges = np.linspace(umap.min(axis=0)[0], umap.max(axis=0)[0], n_bins + 1)
        yedges = np.linspace(umap.min(axis=0)[1], umap.max(axis=0)[1], n_bins + 1)
        return xedges, yedges

    def local_maxima_2D(data):
        '''
        Helper function to find the coordinates and values of local maxima of a 2d hist
        Evaluates local maxima using a footprint equal to 3x3 set of bins
        '''
        size = 3
        footprint = np.ones((size, size))
        footprint[1, 1] = 0
        filtered = maximum_filter(data, footprint=footprint, mode='mirror')
        mask_local_maxima = data > filtered
        coords = np.asarray(np.where(mask_local_maxima)).T
        values = data[mask_local_maxima]
        return coords, values

    def gen_peaks_img(coords, values, edges):
        '''
        Helper function to scatter the values of the local maxima onto a hist with bins defined by the full umap
        '''
        filtered = np.zeros((edges[0].shape[0], edges[1].shape[0]))
        for peak in range(coords.shape[0]):
            filtered[tuple(coords[peak])] = values[peak]
        return filtered

    def prune_local_maxima(coords, values, n_maxima, radius):
        '''
        Helper function to prune "similar" local maxima and preserve UMAP diversity
        if more local maxima than desired are found

        Construct distance matrix of all coords to all coords in bin-space
        Find all maxima pairs closer than @radius
        While more than @n_maxima local maxima remain:
            if there are pairs closer than @radius:
                find the single smallest distance d between two points
                compare the points connected by d; remove the lower-value point
                from coords, values, and the distance matrix
        Returns
            * coords
            * values
        '''
        dist_matrix = distance_matrix(coords, coords)
        dist_matrix[dist_matrix > radius] = 0  # ignore points separated by > @radius in bin-space
        while len(values) > n_maxima:
            if not np.count_nonzero(dist_matrix) == 0:
                # some peaks are too close and need pruning
                indices_to_compare = np.where(dist_matrix == np.min(dist_matrix[np.nonzero(dist_matrix)]))[0]
                if values[indices_to_compare[0]] > values[indices_to_compare[1]]:
                    dist_matrix = np.delete(dist_matrix, indices_to_compare[1], axis=0)
                    dist_matrix = np.delete(dist_matrix, indices_to_compare[1], axis=1)
                    values = np.delete(values, indices_to_compare[1])
                    coords = np.delete(coords, indices_to_compare[1], axis=0)
                else:
                    dist_matrix = np.delete(dist_matrix, indices_to_compare[0], axis=0)
                    dist_matrix = np.delete(dist_matrix, indices_to_compare[0], axis=1)
                    values = np.delete(values, indices_to_compare[0])
                    coords = np.delete(coords, indices_to_compare[0], axis=0)
            else:
                # local maxima are already well separated
                return coords, values
        return coords, values

    def coords_to_umap(umap, binned_ptcls_mask, values):
        '''
        Helper function to convert local maxima coords in bin-space to umap-space

        Calculates each local maximum to be the median UMAP1 and UMAP2 value across
        all particles in each 3x3 set of bins defining a given local maximum
        '''
        umap_median_peaks = np.zeros((len(values), 2))
        for i in range(len(values)):
            umap_median_peaks[i, :] = np.median(umap[binned_ptcls_mask[:, i], :], axis=0)
        return umap_median_peaks

    flog('Using UMAP local maxima sketching', LOG)
    umap = utils.load_pkl(outdir + f'/umaps/umap.{E}.pkl')
    n_particles_sketch = umap.shape[0]

    # create 2d histogram of umap distribution
    edges = make_edges(umap, n_bins=n_bins)
    hist, xedges, yedges, bincount = stats.binned_statistic_2d(umap[:, 0], umap[:, 1], None, 'count',
                                                               bins=edges, expand_binnumbers=True)
    to_plot = ['umap', 'hist']

    # optionally smooth the histogram to reduce the number of peaks (sigma scaled by smooth_width bin-widths)
    if smooth:
        hist_smooth = gaussian_filter(hist, smooth_width * np.abs(xedges[1] - xedges[0]))
        coords, values = local_maxima_2D(hist_smooth)
        to_plot[-1] = 'hist_smooth'
    else:
        coords, values = local_maxima_2D(hist)
    flog(f'Found {len(values)} local maxima', LOG)

    # prune local maxima that are densely packed and low in value
    coords, values = prune_local_maxima(coords, values, pruned_maxima, radius)
    flog(f'Pruned to {len(values)} local maxima', LOG)

    # find subset of n_peaks highest local maxima
    indices = (-values).argsort()[:final_maxima]
    coords, values = coords[indices], values[indices]
    peaks_img_top = gen_peaks_img(coords, values, edges)
    to_plot.append('peaks_img_top')
    to_plot.append('sketched_umap')
    flog(f'Filtered to top {len(values)} local maxima', LOG)

    # mask of all particles within each maximum's bin + all 8 neighboring bins (assumes footprint = (3,3))
    binned_ptcls_mask = np.zeros((n_particles_sketch, len(values)), dtype=bool)
    for i in range(len(values)):
        binned_ptcls_mask[:, i] = (bincount[0, :] >= coords[i, 0] + 0) & \
                                  (bincount[0, :] <= coords[i, 0] + 2) & \
                                  (bincount[1, :] >= coords[i, 1] + 0) & \
                                  (bincount[1, :] <= coords[i, 1] + 2)

    # find median umap coords of each maxima bin for plotting
    coords = coords_to_umap(umap, binned_ptcls_mask, values)

    # plot the original histogram, all peaks, and highest n_peaks
    fig, axes = plt.subplots(1, len(to_plot), figsize=(len(to_plot) * 3.6, 3))
    fig.tight_layout()
    labels = ascii_uppercase[:len(values)]
    for i, ax in enumerate(axes.flat):
        if to_plot[i] == 'umap':
            ax.hexbin(umap[:, 0], umap[:, 1], mincnt=1)
            ax.vlines(x=xedges, ymin=umap.min(axis=0)[1], ymax=umap.max(axis=0)[1], colors='red', linewidth=0.35)
            ax.hlines(y=yedges, xmin=umap.min(axis=0)[0], xmax=umap.max(axis=0)[0], colors='red', linewidth=0.35)
            ax.set_title(f'epoch {E} UMAP')
            ax.set_xlabel('UMAP1')
            ax.set_ylabel('UMAP2')
        elif to_plot[i] == 'hist':
            ax.imshow(np.rot90(hist))
            ax.set_title('UMAP histogram')
        elif to_plot[i] == 'hist_smooth':
            ax.imshow(np.rot90(hist_smooth))
            ax.set_title('UMAP smoothed histogram')
        elif to_plot[i] == 'peaks_img_top':
            ax.imshow(np.rot90(peaks_img_top))
            ax.set_title(f'final {len(labels)} local maxima')
        elif to_plot[i] == 'sketched_umap':
            ax.hexbin(umap[:, 0], umap[:, 1], mincnt=1)
            ax.scatter(*coords.T, c='r')
            ax.set_title(f'sketched epoch {E} UMAP')
            ax.set_xlabel('UMAP1')
            ax.set_ylabel('UMAP2')
            for k in range(len(values)):
                ax.text(x=coords[k, 0] + 0.3, y=coords[k, 1] + 0.3, s=labels[k],
                        fontdict=dict(color='r', size=10))
        ax.spines['bottom'].set_visible(True)
        ax.spines['left'].set_visible(True)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    plt.savefig(f'{outdir}/plots/03_decoder_UMAP-sketching.png', dpi=300, format='png',
                transparent=True, bbox_inches='tight')
    flog(f'Saved latent sketching plot to {outdir}/plots/03_decoder_UMAP-sketching.png', LOG)

    return binned_ptcls_mask, labels
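# Toy check of the 3x3-footprint rule in local_maxima_2D above: a bin is a
# local maximum when it exceeds the max of its 8 neighbors.
import numpy as np
from scipy.ndimage import maximum_filter

data = np.zeros((7, 7))
data[2, 2], data[5, 5] = 3.0, 5.0
footprint = np.ones((3, 3))
footprint[1, 1] = 0
filtered = maximum_filter(data, footprint=footprint, mode='mirror')
print(np.argwhere(data > filtered))  # [[2 2] [5 5]]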
def main(args): assert args.o.endswith('.star'), "Output file must be .star file" assert args.particles.endswith('.mrcs') or args.particles.endswith( '.txt'), "Input file must be .mrcs or .txt" particles = dataset.load_particles(args.particles, lazy=True, datadir=args.datadir) ctf = utils.load_pkl(args.ctf) assert ctf.shape[1] == 9, "Incorrect CTF pkl format" assert len(particles) == len( ctf ), f"{len(particles)} != {len(ctf)}, Number of particles != number of CTF paraameters" if args.poses: poses = utils.load_pkl(args.poses) assert len(particles) == len( poses[0] ), f"{len(particles)} != {len(poses)}, Number of particles != number of poses" log(f'{len(particles)} particles in {args.particles}') if args.ref_star: ref_star = starfile.Starfile.load(args.ref_star) assert len(ref_star) == len( particles ), f"{len(particles)} != {len(ref_star)}, Number of particles in {args.particles} != number of particles in {args.ref_star}" # Get index for particles in each .mrcs file if args.particles.endswith('.txt'): N_per_chunk = parse_chunk_size(args.particles) particle_ind = np.concatenate([np.arange(nn) for nn in N_per_chunk]) assert len(particle_ind) == len(particles) else: # single .mrcs file particle_ind = np.arange(len(particles)) if args.ind: ind = utils.load_pkl(args.ind) log(f'Filtering to {len(ind)} particles') particles = [particles[ii] for ii in ind] ctf = ctf[ind] if args.poses: poses = (poses[0][ind], poses[1][ind]) if args.ref_star: ref_star.df = ref_star.df.loc[ind] # reset the index in the dataframe to avoid any downstream indexing issues ref_star.df.reset_index(inplace=True) particle_ind = particle_ind[ind] particle_ind += 1 # CHANGE TO 1-BASED INDEXING image_names = [img.fname for img in particles] if args.full_path: image_names = [os.path.abspath(img.fname) for img in particles] names = [f'{i}@{name}' for i, name in zip(particle_ind, image_names)] ctf = ctf[:, 2:] # convert poses if args.poses: eulers = utils.R_to_relion_scipy(poses[0]) D = particles[0].get().shape[0] trans = poses[1] * D # convert from fraction to pixels # Create a new dataframe with required star file headers data = {HEADERS[0]: names} for i in range(7): data[HEADERS[i + 1]] = ctf[:, i] if args.poses: for i in range(3): data[POSE_HDRS[i]] = eulers[:, i] for i in range(2): data[POSE_HDRS[3 + i]] = trans[:, i] df = pd.DataFrame(data=data) headers = HEADERS + POSE_HDRS if args.poses else HEADERS if args.keep_micrograph: assert args.ref_star, "Must provide reference .star file with micrograph coordinates" log(f'Copying micrograph coordinates from {args.ref_star}') # TODO: Prepend path from args.ref_star to MicrographName? for h in MICROGRAPH_HDRS: df[h] = ref_star.df[h] headers += MICROGRAPH_HDRS s = starfile.Starfile(headers, df) s.write(args.o)
def test_write_starfile():
    subprocess.check_call('./test_utils.sh', shell=True)
    r1 = utils.load_pkl('data/toy_rot_trans.pkl')
    r2 = utils.load_pkl('output/test_pose.pkl')
    assert_array_almost_equal(r1[0], r2[0])
    assert_array_almost_equal(r1[1], r2[1])
def main(args):
    mkbasedir(args.o)
    warnexists(args.o)
    assert (args.o.endswith('.mrcs') or args.o.endswith('.txt')), \
        "Must specify output in .mrcs or .txt file format"

    # load images
    lazy = args.lazy
    images = dataset.load_particles(args.mrcs, lazy=lazy, datadir=args.datadir, relion31=args.relion31)

    # filter images
    if args.ind is not None:
        log(f'Filtering image dataset with {args.ind}')
        ind = utils.load_pkl(args.ind).astype(int)
        images = [images[i] for i in ind] if lazy else images[ind]

    original_D = images[0].get().shape[0] if lazy else images.shape[-1]
    log(f'Loading {len(images)} {original_D}x{original_D} images')
    window = args.window
    invert_data = args.invert_data
    downsample = (args.D and args.D < original_D)
    if downsample:
        assert args.D <= original_D, \
            f'New box size {args.D} cannot be larger than the original box size {original_D}'
        assert args.D % 2 == 0, 'New box size must be even'
        start = int(original_D / 2 - args.D / 2)
        stop = int(original_D / 2 + args.D / 2)
        D = args.D
        log(f'Downsampling images to {D}x{D}')
    else:
        D = original_D

    def _combine_imgs(imgs):
        ret = []
        for img in imgs:
            img.shape = (1, *img.shape)  # (D,D) -> (1,D,D)
        cur = imgs[0]
        for img in imgs[1:]:
            if img.fname == cur.fname and img.offset == cur.offset + 4 * np.prod(cur.shape):
                cur.shape = (cur.shape[0] + 1, *cur.shape[1:])
            else:
                ret.append(cur)
                cur = img
        ret.append(cur)
        return ret

    def preprocess(imgs):
        if lazy:
            imgs = _combine_imgs(imgs)
            imgs = np.concatenate([i.get() for i in imgs])
        with Pool(min(args.max_threads, mp.cpu_count())) as p:
            # todo: refactor as a routine in dataset.py
            # note: applying the window before downsampling is slightly
            # different than in the original workflow
            if window:
                imgs *= dataset.window_mask(original_D, args.window_r, .99)
            ret = np.asarray(p.map(fft.ht2_center, imgs))
            if invert_data:
                ret *= -1
            if downsample:
                ret = ret[:, start:stop, start:stop]
            ret = fft.symmetrize_ht(ret)
        return ret

    def preprocess_in_batches(imgs, b):
        ret = np.empty((len(imgs), D + 1, D + 1), dtype=np.float32)
        Nbatches = math.ceil(len(imgs) / b)
        for ii in range(Nbatches):
            log(f'Processing batch of {b} images ({ii+1} of {Nbatches})')
            ret[ii * b:(ii + 1) * b, :, :] = preprocess(imgs[ii * b:(ii + 1) * b])
        return ret

    nchunks = math.ceil(len(images) / args.chunk)
    out_mrcs = [f'.{i}.ft'.join(os.path.splitext(args.o)) for i in range(nchunks)]
    chunk_names = [os.path.basename(x) for x in out_mrcs]
    for i in range(nchunks):
        log(f'Processing chunk {i+1} of {nchunks}')
        chunk = images[i * args.chunk:(i + 1) * args.chunk]
        new = preprocess_in_batches(chunk, args.b)
        log(f'New shape: {new.shape}')
        log(f'Saving {out_mrcs[i]}')
        mrc.write(out_mrcs[i], new, is_vol=False)

    out_txt = f'{os.path.splitext(args.o)[0]}.ft.txt'
    log(f'Saving summary txt file {out_txt}')
    with open(out_txt, 'w') as f:
        f.write('\n'.join(chunk_names))
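# Sketch of the Fourier cropping behind the downsampling above: keeping the
# central D x D block of a centered transform low-pass filters and resamples
# to D x D (plain numpy; the script itself uses centered Hartley transforms).
import numpy as np

original_D, D = 128, 64
img = np.random.default_rng(0).normal(size=(original_D, original_D))
ft = np.fft.fftshift(np.fft.fft2(img))
start = original_D // 2 - D // 2
stop = original_D // 2 + D // 2
cropped = ft[start:stop, start:stop]
small = np.fft.ifft2(np.fft.ifftshift(cropped)).real * (D / original_D) ** 2
print(small.shape)  # (64, 64)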
def encoder_latent_umaps(workdir, outdir, epochs, n_particles_total, subset, random_seed,
                         use_umap_gpu, random_state, n_epochs_umap, LOG):
    '''
    Calculates UMAP embeddings of subset of particles' selected epochs' latent encodings

    Inputs
        workdir: path to directory containing cryodrgn training results
        outdir: path to base directory to save outputs
        epochs: array of epochs for which to calculate UMAPs
        n_particles_total: int of total number of particles trained
        subset: int, size of subset on which to calculate umap, None means all
        random_seed: int, seed for random selection of subset particles
        use_umap_gpu: bool, whether to use the cuML library to GPU accelerate UMAP calculations (if available in env)
        random_state: int, random state seed used by UMAP for reproducibility at slight cost of performance
                      (None means faster but non-reproducible)

    Outputs
        pkl of each UMAP embedding stored in outdir/umaps/umap.epoch.pkl
        png of all UMAPs

    # apparently running multiple UMAP embeddings (i.e. for each epoch's z.pkl) in parallel on CPU
    # requires difficult backend setup; see https://github.com/lmcinnes/umap/issues/707
    # therefore not implemented currently
    '''
    if subset == 'None':
        n_particles_subset = n_particles_total
        flog('Using full particle stack for UMAP', LOG)
    else:
        if random_seed is None:
            random_seed = random.randint(0, 100000)
        random.seed(random_seed)
        n_particles_subset = min(n_particles_total, int(subset))
        flog(f'Randomly selecting {n_particles_subset} particle subset on which to run UMAP (with random seed {random_seed})', LOG)
    ind_subset = sorted(random.sample(range(0, n_particles_total), k=n_particles_subset))
    utils.save_pkl(ind_subset, outdir + '/ind_subset.pkl')

    for epoch in epochs:
        flog(f'Now calculating UMAP for epoch {epoch} with random_state {random_state}', LOG)
        z = utils.load_pkl(workdir + f'/z.{epoch}.pkl')[ind_subset, :]
        if use_umap_gpu:  # using cuML library GPU-accelerated UMAP
            reducer = cuUMAP(random_state=random_state, n_epochs=n_epochs_umap)
        else:  # using umap-learn library CPU-bound UMAP
            reducer = umap.UMAP(random_state=random_state)
        umap_embedding = reducer.fit_transform(z)
        utils.save_pkl(umap_embedding, f'{outdir}/umaps/umap.{epoch}.pkl')

    n_cols = int(np.ceil(len(epochs) ** 0.5))
    n_rows = int(np.ceil(len(epochs) / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), sharex='all', sharey='all')
    fig.tight_layout()
    for i, ax in enumerate(axes.flat):
        try:
            umap_embedding = utils.load_pkl(f'{outdir}/umaps/umap.{epochs[i]}.pkl')
            toplot = ax.hexbin(umap_embedding[:, 0], umap_embedding[:, 1], bins='log', mincnt=1)
            ax.set_title(f'epoch {epochs[i]}')
        except IndexError:
            pass
        except FileNotFoundError:
            flog(f'Could not find file {outdir}/umaps/umap.{epochs[i]}.pkl', LOG)

    if len(axes.shape) == 1:
        axes[0].set_ylabel('UMAP2')
        for a in axes[:]:
            a.set_xlabel('UMAP1')
    else:
        assert len(axes.shape) == 2  # more than one row and column of axes
        for a in axes[:, 0]:
            a.set_ylabel('UMAP2')
        for a in axes[-1, :]:
            a.set_xlabel('UMAP1')
    fig.subplots_adjust(right=0.96)
    cbar_ax = fig.add_axes([0.98, 0.15, 0.02, 0.7])
    cbar = fig.colorbar(toplot, cax=cbar_ax)
    cbar.ax.set_ylabel('particle density', rotation=90)
    plt.subplots_adjust(wspace=0.1)
    plt.subplots_adjust(hspace=0.3)
    plt.savefig(f'{outdir}/plots/01_encoder_umaps.png', dpi=300, format='png',
                transparent=True, bbox_inches='tight')
    flog(f'Saved UMAP distribution plot to {outdir}/plots/01_encoder_umaps.png', LOG)
def main(args):
    t1 = dt.now()

    # Configure paths
    E = args.epoch
    sampling = args.epoch_interval
    epochs = np.arange(4, E + 1, sampling)
    if epochs[-1] != E:
        epochs = np.append(epochs, E)
    workdir = args.workdir
    config = f'{workdir}/config.pkl'
    logfile = f'{workdir}/run.log'

    # assert all required files are locatable
    for i in range(E):
        assert os.path.exists(f'{workdir}/z.{i}.pkl'), f'Could not find training file {workdir}/z.{i}.pkl'
    for epoch in epochs:
        assert os.path.exists(f'{workdir}/weights.{epoch}.pkl'), \
            f'Could not find training file {workdir}/weights.{epoch}.pkl'
    assert os.path.exists(config), f'Could not find training file {config}'
    assert os.path.exists(logfile), f'Could not find training file {logfile}'

    # Configure output paths
    if E == -1:
        outdir = f'{workdir}/convergence'
        if args.outdir:
            outdir = args.outdir
    else:
        outdir = f'{workdir}/convergence.{E}'
    os.makedirs(outdir, exist_ok=True)
    os.makedirs(f'{outdir}/plots', exist_ok=True)
    os.makedirs(f'{outdir}/umaps', exist_ok=True)
    os.makedirs(f'{outdir}/repr_particles', exist_ok=True)
    LOG = f'{outdir}/convergence.log'
    flog(args, LOG)
    if len(epochs) < 3:
        flog('WARNING: Too few epochs have been selected for some analyses. '
             'Try decreasing --epoch-interval to a shorter interval, or analyzing a later epoch.', LOG)
    if len(epochs) < 2:
        flog('WARNING: Too few epochs have been selected for any analyses. '
             'Try decreasing --epoch-interval to a shorter interval, or analyzing a later epoch.', LOG)
        sys.exit()
    flog(f'Saving all results to {outdir}', LOG)

    # Get total number of particles, latent space dimensionality, input image size
    n_particles_total, n_dim = utils.load_pkl(f'{workdir}/z.{E}.pkl').shape
    cfg = utils.load_pkl(config)
    img_size = cfg['lattice_args']['D'] - 1

    # Commonly used variables
    # plt.rcParams.update({'font.size': 16})
    plt.rcParams.update({'axes.linewidth': 1.5})
    chimerax_colors = np.divide(((192, 192, 192),
                                 (255, 255, 178),
                                 (178, 255, 255),
                                 (178, 178, 255),
                                 (255, 178, 255),
                                 (255, 178, 178),
                                 (178, 255, 178),
                                 (229, 191, 153),
                                 (153, 191, 229),
                                 (204, 204, 153)), 255)

    # Convergence 1: total loss
    flog('Convergence 1: plotting total loss curve ...', LOG)
    plot_loss(logfile, outdir, E, LOG)

    # Convergence 2: UMAP latent embeddings
    if args.skip_umap:
        flog('Skipping UMAP calculation ...', LOG)
    else:
        flog(f'Convergence 2: calculating and plotting UMAP embeddings of epochs {epochs} ...', LOG)
        use_umap_gpu = 'cuml.manifold.umap' in sys.modules
        if args.force_umap_cpu:
            use_umap_gpu = False
        if use_umap_gpu:
            flog('Using GPU-accelerated UMAP via cuML library', LOG)
        else:
            flog('Using CPU-bound UMAP via umap-learn library', LOG)
        subset = args.subset
        random_state = args.random_state
        random_seed = args.random_seed
        n_epochs_umap = args.n_epochs_umap
        encoder_latent_umaps(workdir, outdir, epochs, n_particles_total, subset, random_seed,
                             use_umap_gpu, random_state, n_epochs_umap, LOG)

    # Convergence 3: latent encoding shifts
    flog(f'Convergence 3: calculating and plotting latent encoding vector shifts for all epochs up to epoch {E} ...', LOG)
    encoder_latent_shifts(workdir, outdir, epochs, E, LOG)

    # Convergence 4: correlation of generated volumes
    flog(f'Convergence 4: sketching epoch {E}\'s latent space to find representative and well-supported latent encodings ...', LOG)
    n_bins = args.n_bins
    smooth = args.smooth
    smooth_width = args.smooth_width
    pruned_maxima = args.pruned_maxima
    radius = args.radius
    final_maxima = args.final_maxima
    binned_ptcls_mask, labels = sketch_via_umap_local_maxima(outdir, E, LOG, n_bins=n_bins, smooth=smooth,
                                                             smooth_width=smooth_width,
                                                             pruned_maxima=pruned_maxima,
                                                             radius=radius, final_maxima=final_maxima)
    follow_candidate_particles(workdir, outdir, epochs, n_dim, binned_ptcls_mask, labels, LOG)

    if args.skip_volgen:
        flog('Skipping volume generation ...', LOG)
    else:
        flog(f'Generating volumes at representative latent encodings for epochs {epochs} ...', LOG)
        Apix = args.Apix
        flip = args.flip
        invert = True if args.invert else None
        downsample = args.downsample
        cuda = args.cuda
        generate_volumes(workdir, outdir, epochs, Apix, flip, invert, downsample, cuda, LOG)

        flog(f'Generating masked volumes at representative latent encodings for epochs {epochs} ...', LOG)
        thresh = args.thresh
        dilate = args.dilate
        dist = args.dist
        max_threads = min(args.max_threads, multiprocessing.cpu_count())
        flog(f'Using {max_threads} threads to parallelize masking', LOG)
        mask_volumes(outdir, epochs, labels, max_threads, LOG, Apix, thresh=thresh, dilate=dilate, dist=dist)

    flog(f'Calculating masked map-map CCs at representative latent encodings for epochs {epochs} ...', LOG)
    calculate_CCs(outdir, epochs, labels, chimerax_colors, LOG)

    flog(f'Calculating masked map-map FSCs at representative latent encodings for epochs {epochs} ...', LOG)
    if args.downsample:
        img_size = args.downsample
    calculate_FSCs(outdir, epochs, labels, img_size, chimerax_colors, LOG)

    flog(f'Finished in {dt.now() - t1}', LOG)
def follow_candidate_particles(workdir, outdir, epochs, n_dim, binned_ptcls_mask, labels, LOG):
    '''
    Monitor how the labeled set of particles migrates within latent space at selected epochs over training

    Inputs:
        workdir: path to directory containing cryodrgn training results
        outdir: path to base directory to save outputs
        epochs: array of epochs for which to calculate UMAPs
        n_dim: latent dimensionality
        binned_ptcls_mask: (n_particles, len(labels)) binary mask of which particles belong to which class
        labels: unique identifier for each class of representative latent encodings

    Outputs
        plot.png tracking representative latent encodings through epochs
        latent.txt of representative latent encodings for each epoch
    '''
    # track sketched points from epoch E through selected previous epochs and plot overtop UMAP embedding
    n_cols = int(np.ceil(len(epochs) ** 0.5))
    n_rows = int(np.ceil(len(epochs) / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), sharex='all', sharey='all')
    fig.tight_layout()

    ind_subset = utils.load_pkl(f'{outdir}/ind_subset.pkl')
    for i, ax in enumerate(axes.flat):
        try:
            umap = utils.load_pkl(f'{outdir}/umaps/umap.{epochs[i]}.pkl')
            z = utils.load_pkl(f'{workdir}/z.{epochs[i]}.pkl')[ind_subset, :]

            z_maxima_median = np.zeros((len(labels), n_dim))
            for k in range(len(labels)):
                # find median latent value of each maximum in a given epoch
                z_maxima_median[k, :] = np.median(z[binned_ptcls_mask[:, k]], axis=0)

            # find on-data latent encoding of each median latent value
            z_maxima_median_ondata, z_maxima_median_ondata_ind = analysis.get_nearest_point(z, z_maxima_median)
            # find on-data UMAP embedding of each median latent encoding
            umap_maxima_median_ondata = umap[z_maxima_median_ondata_ind]

            # Write out the on-data median latent values of each labeled set of particles for each epoch
            with open(f'{outdir}/repr_particles/latent_representative.{epochs[i]}.txt', 'w') as f:
                np.savetxt(f, z_maxima_median_ondata, delimiter=' ', newline='\n',
                           header='', footer='', comments='# ')
            flog(f'Saved representative latent encodings for epoch {epochs[i]} to '
                 f'{outdir}/repr_particles/latent_representative.{epochs[i]}.txt', LOG)

            for k in range(len(labels)):
                ax.text(x=umap_maxima_median_ondata[k, 0] + 0.3,
                        y=umap_maxima_median_ondata[k, 1] + 0.3,
                        s=labels[k], fontdict=dict(color='r', size=10))
            toplot = ax.hexbin(*umap.T, bins='log', mincnt=1)
            ax.scatter(umap_maxima_median_ondata[:, 0], umap_maxima_median_ondata[:, 1],
                       s=10, linewidth=0, c='r', alpha=1)
            ax.set_title(f'epoch {epochs[i]}')
        except IndexError:
            pass

    if len(axes.shape) == 1:
        axes[0].set_ylabel('UMAP2')
        for a in axes[:]:
            a.set_xlabel('UMAP1')
    else:
        assert len(axes.shape) == 2  # more than one row and column of axes
        for a in axes[:, 0]:
            a.set_ylabel('UMAP2')
        for a in axes[-1, :]:
            a.set_xlabel('UMAP1')
    fig.subplots_adjust(right=0.96)
    cbar_ax = fig.add_axes([0.98, 0.15, 0.02, 0.7])
    cbar = fig.colorbar(toplot, cax=cbar_ax)
    cbar.ax.set_ylabel('Particle Density', rotation=90)
    plt.subplots_adjust(wspace=0.1)
    plt.subplots_adjust(hspace=0.25)
    plt.savefig(f'{outdir}/plots/04_decoder_maxima-sketch-consistency.png', dpi=300, format='png',
                transparent=True, bbox_inches='tight')
    flog(f'Saved plot tracking representative latent encodings through epochs {epochs} to '
         f'{outdir}/plots/04_decoder_maxima-sketch-consistency.png', LOG)
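# Sketch of the on-data lookup used above (analysis.get_nearest_point): for
# each query point (e.g. a median latent value), return the nearest actual
# particle encoding and its index.
import numpy as np
from scipy.spatial.distance import cdist

def nearest_point(data, query):
    ind = cdist(query, data).argmin(axis=1)  # (n_query,) indices into data
    return data[ind], ind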
def main(args):
    assert args.o.endswith('.mrc')

    t1 = time.time()
    log(args)
    if not os.path.exists(os.path.dirname(args.o)):
        os.makedirs(os.path.dirname(args.o))

    ## set the device
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    log('Use cuda {}'.format(use_cuda))
    if not use_cuda:
        log('WARNING: No GPUs detected')

    # load the particles
    if args.ind is not None:
        args.ind = utils.load_pkl(args.ind).astype(int)
    if args.tilt is None:
        if args.preprocessed:
            data = dataset.PreprocessedMRCData(args.particles, norm=(0, 1), ind=args.ind)
        else:
            data = dataset.LazyMRCData(args.particles, norm=(0, 1), invert_data=args.invert_data,
                                       datadir=args.datadir, ind=args.ind, relion31=args.relion31)
        tilt = None
    else:  # tilt series
        if args.relion31:
            raise NotImplementedError
        if args.preprocessed:
            raise NotImplementedError
        data = dataset.TiltMRCData(args.particles, args.tilt, norm=(0, 1), invert_data=args.invert_data,
                                   datadir=args.datadir, ind=args.ind)
        tilt = torch.tensor(utils.xrot(args.tilt_deg).astype(np.float32), device=device)
    D = data.D
    Nimg = data.N

    lattice = Lattice(D, extent=D // 2, device=device)

    posetracker = PoseTracker.load(args.poses, Nimg, D, None, args.ind, device=device)

    if args.ctf is not None:
        log('Loading ctf params from {}'.format(args.ctf))
        ctf_params = ctf.load_ctf_for_training(D - 1, args.ctf)
        ctf_params = torch.tensor(ctf_params, device=device)
        if args.ind is not None:
            ctf_params = ctf_params[args.ind]
    else:
        ctf_params = None
    Apix = ctf_params[0, 0] if ctf_params is not None else 1

    V = torch.zeros((D, D, D), device=device)
    counts = torch.zeros((D, D, D), device=device)

    mask = lattice.get_circular_mask(D // 2)

    if args.first:
        args.first = min(args.first, Nimg)
        iterator = range(args.first)
    else:
        iterator = range(Nimg)

    for ii in iterator:
        if ii % 100 == 0:
            log('image {}'.format(ii))
        r, t = posetracker.get_pose(ii)
        ff = data.get(ii)
        if tilt is not None:
            ff, ff_tilt = ff  # EW
        ff = torch.tensor(ff, device=device)
        ff = ff.view(-1)[mask]
        if ctf_params is not None:
            freqs = lattice.freqs2d / ctf_params[ii, 0]
            c = ctf.compute_ctf(freqs, *ctf_params[ii, 1:]).view(-1)[mask]
            ff *= c.sign()
        if t is not None:
            ff = lattice.translate_ht(ff.view(1, -1), t.view(1, 1, 2), mask).view(-1)
        ff_coord = lattice.coords[mask] @ r
        add_slice(V, counts, ff_coord, ff, D)

        # tilt series
        if args.tilt is not None:
            ff_tilt = torch.tensor(ff_tilt, device=device)
            ff_tilt = ff_tilt.view(-1)[mask]
            if ctf_params is not None:
                ff_tilt *= c.sign()
            if t is not None:
                ff_tilt = lattice.translate_ht(ff_tilt.view(1, -1), t.view(1, 1, 2), mask).view(-1)
            ff_coord = lattice.coords[mask] @ tilt @ r
            add_slice(V, counts, ff_coord, ff_tilt, D)

    td = time.time() - t1
    log('Backprojected {} images in {}s ({}s per image)'.format(len(iterator), td, td / len(iterator)))
    counts[counts == 0] = 1
    V /= counts
    V = fft.ihtn_center(V[0:-1, 0:-1, 0:-1].cpu().numpy())
    mrc.write(args.o, V.astype('float32'), Apix=Apix)
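# Toy sketch of the normalization at the end of backprojection: each voxel
# accumulates inserted slice values plus a hit count, and empty voxels are
# guarded against division by zero.
import numpy as np

V = np.array([4.0, 0.0, 9.0])        # accumulated values
counts = np.array([2.0, 0.0, 3.0])   # hits per voxel
counts[counts == 0] = 1              # empty voxels stay zero instead of NaN
V /= counts
print(V)                             # [2. 0. 3.]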
def main(args):
    t1 = dt.now()
    log(args)
    E = args.epoch
    workdir = args.workdir
    zfile = f'{workdir}/z.{E}.pkl'
    weights = f'{workdir}/weights.{E}.pkl'
    config = f'{workdir}/config.pkl'
    outdir = f'{workdir}/landscape.{E}'

    if args.outdir:
        outdir = args.outdir
    log(f'Saving results to {outdir}')
    if not os.path.exists(outdir):
        os.mkdir(outdir)

    z = utils.load_pkl(zfile)
    zdim = z.shape[1]
    K = args.sketch_size

    vol_args = dict(Apix=args.Apix, downsample=args.downsample, flip=args.flip, cuda=args.device)
    vg = VolumeGenerator(weights, config, vol_args, skip_vol=args.skip_vol)

    if args.vol_ind is not None:
        args.vol_ind = utils.load_pkl(args.vol_ind)

    if not args.skip_vol:
        generate_volumes(z, outdir, vg, K)
    else:
        log('Skipping volume generation')

    if args.skip_umap:
        assert os.path.exists(f'{outdir}/umap.pkl')
        log('Skipping UMAP')
    else:
        log(f'Copying UMAP from {workdir}/analyze.{E}/umap.pkl')
        if os.path.exists(f'{workdir}/analyze.{E}/umap.pkl'):
            from shutil import copyfile
            copyfile(f'{workdir}/analyze.{E}/umap.pkl', f'{outdir}/umap.pkl')
        else:
            raise NotImplementedError

    if args.mask:
        log(f'Using custom mask {args.mask}')
    make_mask(outdir, K, args.dilate, args.thresh, args.mask)

    log('Analyzing volumes...')
    # get particle indices if the dataset was originally filtered
    c = utils.load_pkl(config)
    particle_ind = utils.load_pkl(c['dataset_args']['ind']) if c['dataset_args']['ind'] is not None else None
    analyze_volumes(outdir, K, args.pc_dim, args.M, args.linkage,
                    vol_ind=args.vol_ind, plot_dim=args.plot_dim,
                    particle_ind_orig=particle_ind)
    td = dt.now() - t1
    log(f'Finished in {td}')
def analyze_volumes(outdir, K, dim, M, linkage, vol_ind=None, plot_dim=5, particle_ind_orig=None):
    cmap = choose_cmap(M)

    # load mean volume, compute it if it does not exist
    if not os.path.exists(f'{outdir}/kmeans{K}/vol_mean.mrc'):
        volm = np.array([
            mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0] for i in range(K)
        ]).mean(axis=0)
        mrc.write(f'{outdir}/kmeans{K}/vol_mean.mrc', volm)
    else:
        volm = mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_mean.mrc')[0]

    # load mask
    mask = mrc.parse_mrc(f'{outdir}/mask.mrc')[0].astype(bool)
    log(f'{mask.sum()} voxels in mask')

    # load volumes
    vols = np.array([
        mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0][mask] for i in range(K)
    ])
    vols[vols < 0] = 0

    # load umap
    umap = utils.load_pkl(f'{outdir}/umap.pkl')
    ind = np.loadtxt(f'{outdir}/kmeans{K}/centers_ind.txt').astype(int)

    if vol_ind is not None:
        log(f'Filtering to {len(vol_ind)} volumes')
        vols = vols[vol_ind]
        ind = ind[vol_ind]

    # compute PCA
    pca = PCA(dim)
    pca.fit(vols)
    pc = pca.transform(vols)
    utils.save_pkl(pc, f'{outdir}/vol_pca_{K}.pkl')
    utils.save_pkl(pca, f'{outdir}/vol_pca_obj.pkl')
    log('Explained variance ratio:')
    log(pca.explained_variance_ratio_)

    # save rxn coordinates
    for i in range(plot_dim):
        subdir = f'{outdir}/vol_pcs/pc{i+1}'
        if not os.path.exists(subdir):
            os.makedirs(subdir)
        min_, max_ = pc[:, i].min(), pc[:, i].max()
        log((min_, max_))
        for j, val in enumerate(np.linspace(min_, max_, 10, endpoint=True)):
            v = volm.copy()
            v[mask] += pca.components_[i] * val
            mrc.write(f'{subdir}/{j}.mrc', v)

    # which plots to show???
    def plot(i, j):
        plt.figure()
        plt.scatter(pc[:, i], pc[:, j])
        plt.xlabel(f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:03f})')
        plt.ylabel(f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:03f})')
        plt.savefig(f'{outdir}/vol_pca_{K}_{i+1}_{j+1}.png')

    for i in range(plot_dim - 1):
        plot(i, i + 1)

    # clustering
    subdir = f'{outdir}/clustering_L2_{linkage}_{M}'
    if not os.path.exists(subdir):
        os.makedirs(subdir)
    cluster = AgglomerativeClustering(n_clusters=M, affinity='euclidean', linkage=linkage)
    labels = cluster.fit_predict(vols)
    utils.save_pkl(labels, f'{subdir}/state_labels.pkl')

    kmeans_labels = utils.load_pkl(f'{outdir}/kmeans{K}/labels.pkl')
    kmeans_counts = Counter(kmeans_labels)
    for i in range(M):
        vol_i = np.where(labels == i)[0]
        log(f'State {i}: {len(vol_i)} volumes')
        if vol_ind is not None:
            vol_i = np.arange(K)[vol_ind][vol_i]
        vol_i_all = np.array([
            mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0] for i in vol_i
        ])
        nparticles = np.array([kmeans_counts[i] for i in vol_i])
        vol_i_mean = np.average(vol_i_all, axis=0, weights=nparticles)
        vol_i_std = np.average((vol_i_all - vol_i_mean)**2, axis=0, weights=nparticles)**.5
        mrc.write(f'{subdir}/state_{i}_mean.mrc', vol_i_mean.astype(np.float32))
        mrc.write(f'{subdir}/state_{i}_std.mrc', vol_i_std.astype(np.float32))
        if not os.path.exists(f'{subdir}/state_{i}'):
            os.makedirs(f'{subdir}/state_{i}')
        for v in vol_i:
            os.symlink(f'{outdir}/kmeans{K}/vol_{v:03d}.mrc', f'{subdir}/state_{i}/vol_{v:03d}.mrc')
        particle_ind = analysis.get_ind_for_cluster(kmeans_labels, vol_i)
        log(f'State {i}: {len(particle_ind)} particles')
        if particle_ind_orig is not None:
            utils.save_pkl(particle_ind_orig[particle_ind], f'{subdir}/state_{i}_particle_ind.pkl')
        else:
            utils.save_pkl(particle_ind, f'{subdir}/state_{i}_particle_ind.pkl')

    # plot clustering results
    def hack_barplot(counts_):
        if M <= 20:  # HACK TO GET COLORS
            with sns.color_palette(cmap):
                g = sns.barplot(np.arange(M), counts_)
        else:  # default is husl
            g = sns.barplot(np.arange(M), counts_)
        return g

    plt.figure()
    counts = Counter(labels)
    g = hack_barplot([counts[i] for i in range(M)])
    for i in range(M):
        g.text(i - .1, counts[i] + 2, counts[i])
    plt.xlabel('State')
    plt.ylabel('Count')
    plt.savefig(f'{subdir}/state_volume_counts.png')

    plt.figure()
    particle_counts = [
        np.sum([kmeans_counts[ii] for ii in np.where(labels == i)[0]]) for i in range(M)
    ]
    g = hack_barplot(particle_counts)
    for i in range(M):
        g.text(i - .1, particle_counts[i] + 2, particle_counts[i])
    plt.xlabel('State')
    plt.ylabel('Count')
    plt.savefig(f'{subdir}/state_particle_counts.png')

    def plot_w_labels(i, j):
        plt.figure()
        plt.scatter(pc[:, i], pc[:, j], c=labels, cmap=cmap)
        plt.xlabel(f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:03f})')
        plt.ylabel(f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:03f})')
        plt.savefig(f'{subdir}/vol_pca_{K}_{i+1}_{j+1}.png')

    for i in range(plot_dim - 1):
        plot_w_labels(i, i + 1)

    def plot_w_labels_annotated(i, j):
        fig, ax = plt.subplots(figsize=(16, 16))
        plt.scatter(pc[:, i], pc[:, j], c=labels, cmap=cmap)
        annots = np.arange(K)
        if vol_ind is not None:
            annots = annots[vol_ind]
        for ii, k in enumerate(annots):
            ax.annotate(str(k), pc[ii, [i, j]] + np.array([.1, .1]))
        plt.xlabel(f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:03f})')
        plt.ylabel(f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:03f})')
        plt.savefig(f'{subdir}/vol_pca_{K}_annotated_{i+1}_{j+1}.png')

    for i in range(plot_dim - 1):
        plot_w_labels_annotated(i, i + 1)

    # plot clusters on UMAP
    umap_i = umap[ind]
    fig, ax = plt.subplots(figsize=(8, 8))
    plt.scatter(umap[:, 0], umap[:, 1], alpha=.1, s=1, rasterized=True, color='lightgrey')
    colors = get_colors_for_cmap(cmap, M)
    for i in range(M):
        c = umap_i[np.where(labels == i)]
        plt.scatter(c[:, 0], c[:, 1], label=i, color=colors[i])
    plt.legend()
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    plt.savefig(f'{subdir}/umap.png')

    fig, ax = plt.subplots(figsize=(16, 16))
    plt.scatter(umap[:, 0], umap[:, 1], alpha=.1, s=1, rasterized=True, color='lightgrey')
    plt.scatter(umap_i[:, 0], umap_i[:, 1], c=labels, cmap=cmap)
    annots = np.arange(K)
    if vol_ind is not None:
        annots = annots[vol_ind]
    for i, k in enumerate(annots):
        ax.annotate(str(k), umap_i[i] + np.array([.1, .1]))
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    plt.savefig(f'{subdir}/umap_annotated.png')
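# Sketch of the volume-PCA step above on toy data: flatten masked voxels of
# K volumes, fit a PCA, and synthesize a volume along PC1 (numbers are
# hypothetical stand-ins for K volumes and the mask size).
import numpy as np
from sklearn.decomposition import PCA

K, nvox = 20, 500
vols = np.random.default_rng(0).normal(size=(K, nvox))
pca = PCA(5)
pc = pca.fit_transform(vols)            # (K, 5) per-volume coordinates
vol_on_pc1 = vols.mean(axis=0) + pca.components_[0] * pc[:, 0].max()
print(pc.shape, vol_on_pc1.shape)       # (20, 5) (500,)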