def calc_fsc(vol1_path, vol2_path):
    '''
    Helper function to calculate the FSC between two (assumed masked) volumes

    Inputs
        vol1_path: path to the first volume .mrc file
        vol2_path: path to the second volume .mrc file
            Both files are loaded with cryodrgn.mrc.parse_mrc and must contain
            maps with the same (cubic) box size D.

    Outputs
        x: spatial frequency of each Fourier shell in 1/px (shells 0 .. D//2 - 1)
        fsc: Fourier shell correlation of each shell in x (NaN for shells with zero power)

    Raises
        ValueError: if the two volumes do not have the same shape
    '''
    # load masked volumes in fourier space
    vol1, _ = mrc.parse_mrc(vol1_path)
    vol2, _ = mrc.parse_mrc(vol2_path)
    if vol1.shape != vol2.shape:
        raise ValueError(f'Volume shapes differ: {vol1.shape} != {vol2.shape}')
    vol1_ft = fft.fftn_center(vol1)
    vol2_ft = fft.fftn_center(vol2)

    # define fourier grid and label into shells
    D = vol1.shape[0]
    x = np.arange(-D // 2, D // 2)
    x0, x1, x2 = np.meshgrid(x, x, x, indexing='ij')
    r = np.sqrt(x0 ** 2 + x1 ** 2 + x2 ** 2)
    r_max = D // 2  # sphere inscribed within volume box
    r_step = 1  # unit-width shells
    bins = np.arange(0, r_max, r_step)
    # voxels with bins[k] <= r < bins[k+1] get label k + 1; voxels outside r_max get a label not in `index`
    bin_labels = np.searchsorted(bins, r, side='right')

    # calculate the FSC via labeled shells
    num = ndimage.sum(np.real(vol1_ft * np.conjugate(vol2_ft)), labels=bin_labels, index=bins + 1)
    den1 = ndimage.sum(np.abs(vol1_ft) ** 2, labels=bin_labels, index=bins + 1)
    den2 = ndimage.sum(np.abs(vol2_ft) ** 2, labels=bin_labels, index=bins + 1)
    fsc = num / np.sqrt(den1 * den2)

    x = bins / D  # x axis should be spatial frequency in 1/px
    return x, fsc
def main(args):
    '''
    Compute the Fourier shell correlation between args.vol1 and args.vol2.

    The volumes are optionally multiplied by args.mask first. The (frequency, FSC)
    table is saved to args.o or logged. The resolution (in Angstrom, using args.Apix)
    of the first shell whose FSC falls below 0.5 and below 0.143 is logged, and the
    curve is optionally plotted.
    '''
    vol1, _ = mrc.parse_mrc(args.vol1)
    vol2, _ = mrc.parse_mrc(args.vol2)
    if args.mask:
        mask = mrc.parse_mrc(args.mask)[0]
        vol1 *= mask
        vol2 *= mask

    D = vol1.shape[0]
    x = np.arange(-D // 2, D // 2)
    x2, x1, x0 = np.meshgrid(x, x, x, indexing='ij')
    coords = np.stack((x0, x1, x2), -1)
    r = (coords**2).sum(-1)**.5
    # the origin of the grid must sit at the box center
    assert r[D // 2, D // 2, D // 2] == 0.0

    vol1 = fft.fftn_center(vol1)
    vol2 = fft.fftn_center(vol2)

    # accumulate one correlation value per unit-width shell i-1 <= r < i
    prev_inside = np.zeros((D, D, D), dtype=bool)
    fsc = [1.0]
    for i in range(1, D // 2):
        inside = r < i
        shell = np.where(inside & np.logical_not(prev_inside))
        v1 = vol1[shell]
        v2 = vol2[shell]
        p = np.vdot(v1, v2) / (np.vdot(v1, v1) * np.vdot(v2, v2))**.5
        fsc.append(p.real)
        prev_inside = inside
    fsc = np.asarray(fsc)
    x = np.arange(D // 2) / D

    res = np.stack((x, fsc), 1)
    if args.o:
        np.savetxt(args.o, res)
    else:
        log(res)

    # report the first shell that drops below each threshold (np.where returns a tuple,
    # which is always truthy, so test the index array itself)
    for thresh in (0.5, 0.143):
        below = np.where(fsc < thresh)[0]
        if below.size:
            log('{}: {}'.format(thresh, 1 / x[below[0]] * args.Apix))

    if args.plot:
        plt.plot(x, fsc)
        plt.ylim((0, 1))
        plt.show()
def main(args):
    '''Flip an .mrc volume along its first array axis and write it, reusing the input header.'''
    assert args.input.endswith('.mrc'), "Input volume must be .mrc file"
    assert args.o.endswith('.mrc'), "Output volume must be .mrc file"

    volume, header = mrc.parse_mrc(args.input)
    flipped = volume[::-1]  # reverse the first axis

    mrc.write(args.o, flipped, header=header)
    log(f'Wrote {args.o}')
def mask_volume(volpath, outpath, Apix, thresh=None, dilate=3, dist=10):
    '''
    Helper function to apply a loose, soft-edged mask to the input density and save the masked volume

    The density is binarized at `thresh`, dilated outwards, and a soft cosine edge is applied.
    The resulting mask (values in [0, 1]) is multiplied into the volume.

    Inputs
        volpath: an absolute path to the volume to be masked
        outpath: an absolute path to write out the masked volume .mrc
        Apix: pixel size passed to mrc.write for the output file
        thresh: absolute intensity threshold used to binarize the density; if None, defaults to
            half of the 99.99th percentile of the volume's voxel values
        dilate: number of binary dilation iterations (voxels) applied to the thresholded density
        dist: width in voxels of the cosine falloff beyond the dilated density; must be > 0

    Outputs
        the masked volume (float32) written to outpath
    '''
    vol = mrc.parse_mrc(volpath)[0]
    if thresh is None:
        thresh = np.percentile(vol, 99.99) / 2
    x = binary_dilation(vol >= thresh, iterations=dilate)
    # distance (in voxels) from each voxel outside the dilated density to the nearest density voxel
    y = distance_transform_edt(~x)
    y[y > dist] = dist
    # cosine falloff: 1 inside/at the dilated density, 0 at `dist` voxels away and beyond
    z = np.cos(np.pi * y / dist / 2)

    # check that mask is in range [0,1]
    assert np.all(z >= 0)
    assert np.all(z <= 1)

    # the mask is not written separately; apply it and save the masked volume to minimize future I/O
    vol *= z
    mrc.write(outpath, vol.astype(np.float32), Apix=Apix)
def main(args):
    '''
    Save the first image of an .mrc stack with two scale bars.

    The red bar is args.scale1 Angstrom (default 100) and the cyan bar is args.scale2
    Angstrom (default 400). The image is optionally Gaussian-blurred by args.gblur and
    is saved next to the input as .png (or .tiff with args.tiff).
    '''
    stack, _ = mrc.parse_mrc(args.input)
    image = stack[0]
    # imshow maps array rows to the y axis and columns to the x axis
    y_dim = image.shape[0]
    x_dim = image.shape[1]
    print('image dimensions: ' + str(x_dim) + 'x' + str(y_dim) + ' pixels')

    ang_px = float(args.pixel_size)
    scale_a = float(args.scale1) if args.scale1 else float(100)
    scale_b = float(args.scale2) if args.scale2 else float(400)
    # scale bar lengths in pixels
    line_a = scale_a / ang_px
    line_b = scale_b / ang_px
    offset = x_dim / 20

    fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
    if args.gblur:
        image = ndimage.gaussian_filter(image, float(args.gblur))
    axes.imshow(image, cmap='Greys_r')

    bar_y = y_dim - offset
    tick_y = np.array([bar_y - y_dim / 100, bar_y + y_dim / 100])
    label_y = bar_y + y_dim / 50

    # left (red) scale bar with a centered tick and label
    center_a = (offset * 2 + line_a) / 2.0
    axes.plot(np.array([offset, offset + line_a]), np.array([bar_y, bar_y]), color='red')
    axes.plot(np.array([center_a, center_a]), tick_y, color='red')
    axes.text(center_a, label_y, str(scale_a) + ' A', color='r', ha='center', va='center')

    # right (cyan) scale bar with a centered tick and label
    center_b = (2 * x_dim - 2 * offset - line_b) / 2.0
    axes.plot(np.array([x_dim - offset, x_dim - offset - line_b]), np.array([bar_y, bar_y]), color='cyan')
    axes.plot(np.array([center_b, center_b]), tick_y, color='cyan')
    axes.text(center_b, label_y, str(scale_b) + ' A', color='cyan', ha='center', va='center')

    axes.axis('off')
    extension = '.tiff' if args.tiff else '.png'
    plt.savefig(args.input.split('.mrc')[0] + extension)
def main(args):
    '''Plot up to the first 9 images of a stack; save to args.o if given, otherwise show interactively.'''
    stack, _ = mrc.parse_mrc(args.input, lazy=True)
    print('{} {}x{} images'.format(len(stack), *stack[0].get().shape))
    # stacks with fewer than 9 images must not raise IndexError
    n_show = min(9, len(stack))
    stack = [stack[i].get() for i in range(n_show)]
    analysis.plot_projections(stack)
    if args.o:
        plt.savefig(args.o)
    else:
        plt.show()
def calculate_CCs(outdir, epochs, labels, chimerax_colors, LOG):
    '''
    Calculates the masked map-map correlation between temporally sequential volume pairs outdir/vols.{epochs}, for each class in labels

    Inputs:
        outdir: path to base directory to save outputs
        epochs: array of epochs whose volumes are compared; each consecutive pair (epochs[i], epochs[i+1]) yields one CC
        labels: unique identifier for each class of representative latent encodings
        chimerax_colors: approximate colors matching ChimeraX palette to facilitate comparison to volume visualization
            (indexed per class and multiplied by 0.75, so presumably a numpy array of RGB rows -- verify against caller)
        LOG: log destination passed through to flog

    Outputs:
        outdir/cc_masked.pkl: array of shape (len(labels), len(epochs) - 1) with the map-map CCs
        outdir/plots/05_decoder_CC.png: plot of sequential volume pairs map-map CC for each class in labels across training epochs
    '''
    def calc_cc(vol1, vol2):
        '''
        Helper function to calculate the zero-mean correlation coefficient as defined in eq 2 in https://journals.iucr.org/d/issues/2018/09/00/kw5139/index.html
        vol1 and vol2 should be maps of the same box size, structured as numpy arrays with ndim=3, i.e. by loading with cryodrgn.mrc.parse_mrc
        '''
        zmean1 = vol1 - np.mean(vol1)
        zmean2 = vol2 - np.mean(vol2)
        cc = (np.sum(zmean1 ** 2) ** -0.5) * (np.sum(zmean2 ** 2) ** -0.5) * np.sum(zmean1 * zmean2)
        return cc

    cc_masked = np.zeros((len(labels), len(epochs) - 1))
    for i in range(len(epochs) - 1):
        for cluster in np.arange(len(labels)):
            vol1, _ = mrc.parse_mrc(f'{outdir}/vols.{epochs[i]}/vol_{cluster:03d}.masked.mrc')
            vol2, _ = mrc.parse_mrc(f'{outdir}/vols.{epochs[i+1]}/vol_{cluster:03d}.masked.mrc')
            cc_masked[cluster, i] = calc_cc(vol1, vol2)

    utils.save_pkl(cc_masked, f'{outdir}/cc_masked.pkl')

    fig, ax = plt.subplots(1, 1)
    ax.set_xlabel('epoch')
    ax.set_ylabel('masked CC')
    for i in range(len(labels)):
        ax.plot(epochs[1:], cc_masked[i, :], c=chimerax_colors[i] * 0.75, linewidth=2.5)
    ax.legend(labels, ncol=3, fontsize='x-small')
    plt.savefig(f'{outdir}/plots/05_decoder_CC.png', dpi=300, format='png', transparent=True, bbox_inches='tight')
    # release the figure so repeated calls in a long pipeline do not accumulate open figures
    plt.close(fig)
    flog(f'Saved map-map correlation plot to {outdir}/plots/05_decoder_CC.png', LOG)
def main(args):
    '''
    Rewrite an .mrc volume with the pixel size set to args.apix.

    If args.invert is set, contrast is inverted in place. If args.flip is set,
    the first array axis is reversed.
    '''
    assert args.input.endswith('.mrc'), "Input volume must be .mrc file"
    assert args.o.endswith('.mrc'), "Output volume must be .mrc file"

    density, header = mrc.parse_mrc(args.input)
    header.update_apix(args.apix)

    if args.invert:
        density *= -1  # in-place contrast inversion
    if args.flip:
        density = density[::-1]

    mrc.write(args.o, density, header=header)
    log(f'Wrote {args.o}')
def main(args):
    '''
    Write a .star file for a single .mrcs particle stack.

    Rows come from a CTF pkl (9 columns, of which the last 7 are written) and,
    optionally, a pose pkl (rotations converted with utils.R_to_relion_scipy,
    translations converted from box fractions to pixels). An optional index pkl
    selects a subset of particles.
    '''
    assert args.o.endswith('.star'), "Output file must be a .star file"
    assert args.particles.endswith('.mrcs'), "Only a single particle stack as an .mrcs is currently supported"
    particles = mrc.parse_mrc(args.particles, lazy=True)[0]
    ctf = utils.load_pkl(args.ctf)
    assert ctf.shape[1] == 9, "Incorrect CTF pkl format"
    assert len(particles) == len(ctf), \
        f"{len(particles)} != {len(ctf)}, Number of particles != number of CTF parameters"
    if args.poses:
        poses = utils.load_pkl(args.poses)
        # poses is a (rotations, translations) pair, so compare against the rotations' length
        assert len(particles) == len(poses[0]), \
            f"{len(particles)} != {len(poses[0])}, Number of particles != number of poses"
    log('{} particles'.format(len(particles)))

    if args.ind:
        ind = utils.load_pkl(args.ind)
        log(f'Filtering to {len(ind)} particles')
        ctf = ctf[ind]
        if args.poses:
            poses = (poses[0][ind], poses[1][ind])
    else:
        ind = np.arange(len(particles))

    # _rlnImageName uses 1-based indexing; add 1 per item instead of mutating the loaded index
    image_name = args.particles if args.full_path else os.path.basename(args.particles)
    names = [f'{i + 1}@{image_name}' for i in ind]

    # the first two CTF pkl columns are not written to the star file
    ctf = ctf[:, 2:]

    # convert poses
    if args.poses:
        eulers = utils.R_to_relion_scipy(poses[0])
        D = particles[0].get().shape[0]
        trans = poses[1] * D  # convert from fraction to pixels

    data = {HEADERS[0]: names}
    for i in range(7):
        data[HEADERS[i + 1]] = ctf[:, i]
    if args.poses:
        for i in range(3):
            data[POSE_HDRS[i]] = eulers[:, i]
        for i in range(2):
            data[POSE_HDRS[3 + i]] = trans[:, i]
    df = pd.DataFrame(data=data)

    headers = HEADERS + POSE_HDRS if args.poses else HEADERS
    s = starfile.Starfile(headers, df)
    s.write(args.o)
def make_mask(outdir, K, dilate, thresh, in_mrc=None):
    '''
    Build a binary mask and save it to outdir/mask.mrc, with central slices in outdir/mask_slices.png

    If in_mrc is None, the mask is the union over the k-means volumes outdir/kmeans{K}/vol_{i:03d}.mrc
    of each volume thresholded at `thresh` and dilated. Otherwise, the provided .mrc is loaded and
    binarized (nonzero -> True).

    Inputs
        outdir: analysis directory
        K: number of k-means volumes
        dilate: number of binary dilation iterations applied to each thresholded volume
        thresh: absolute density threshold; if None, the mean over volumes of half their 99.99th percentile
        in_mrc: optional path to a precomputed mask .mrc
    '''
    if in_mrc is None:
        if thresh is None:
            thresh = []
            for i in range(K):
                vol = mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0]
                thresh.append(np.percentile(vol, 99.99) / 2)
            thresh = np.mean(thresh)
        log(f'Threshold: {thresh}')
        log(f'Dilating mask by: {dilate}')

        def binary_mask(vol):
            return binary_dilation(vol >= thresh, iterations=dilate)

        # combine all masks by taking their union
        mask = binary_mask(mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_000.mrc')[0])
        for i in range(1, K):
            mask |= binary_mask(mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0])
    else:
        # Load provided mrc and convert to a boolean mask
        mask, _ = mrc.parse_mrc(in_mrc)
        mask = mask.astype(bool)

    # save mask
    out_mrc = f'{outdir}/mask.mrc'
    log(f'Saving {out_mrc}')
    mrc.write(out_mrc, mask.astype(np.float32))

    # view slices; use the mask's own size (no volume is loaded when in_mrc is given)
    out_png = f'{outdir}/mask_slices.png'
    D = mask.shape[0]
    fig, ax = plt.subplots(1, 3, figsize=(10, 8))
    ax[0].imshow(mask[D // 2, :, :])
    ax[1].imshow(mask[:, D // 2, :])
    ax[2].imshow(mask[:, :, D // 2])
    plt.savefig(out_png)
    plt.close(fig)
def main(args):
    '''
    Align args.vol to the reference volume args.ref and write the transformed volume to args.o.

    The search is a coarse-to-fine grid search over rotations and translations.
    Each iteration evaluates all (rotation, translation) pairs, keeps the args.keep_r
    best rotations and args.keep_t best translations, and subdivides them for the
    next iteration. The best pose of the last evaluated grid is applied to the
    full-size volume.
    '''
    log(args)
    if args.niter < 1:
        raise ValueError('niter must be >= 1 to find an alignment')
    torch.set_grad_enabled(False)
    use_cuda = torch.cuda.is_available()
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    t1 = time.time()
    ref, _ = mrc.parse_mrc(args.ref)
    log('Loaded {} volume'.format(ref.shape))
    vol, _ = mrc.parse_mrc(args.vol)
    log('Loaded {} volume'.format(vol.shape))

    # the search runs with maxD=args.max_D; the final transform uses the full box
    projector = VolumeAligner(vol, vol_ref=ref, maxD=args.max_D, flip=args.flip)
    if use_cuda:
        projector.use_cuda()

    # base rotation grid
    r_resol = args.r_resol
    quats = so3_grid.grid_SO3(r_resol)
    q_id = np.arange(len(quats))
    q_id = np.stack([q_id // (6 * 2**r_resol), q_id % (6 * 2**r_resol)], -1)
    rots = GridPose(quats, q_id)

    # base translation grid
    t_resol = 0
    T_EXTENT = vol.shape[0] / 16 if args.t_extent is None else args.t_extent
    T_NGRID = args.t_grid
    trans = shift_grid3.base_shift_grid(T_EXTENT, T_NGRID)
    t_id = np.stack(shift_grid3.get_base_id(np.arange(len(trans)), T_NGRID), -1)
    trans = GridPose(trans, t_id)

    max_keep_r = args.keep_r
    max_keep_t = args.keep_t
    best_r = best_t = None
    for it in range(args.niter):
        log('Iteration {}'.format(it))
        log('Generating {} rotations'.format(len(rots)))
        log('Generating {} translations'.format(len(trans)))
        pose_err = np.empty((len(rots), len(trans)), dtype=np.float32)
        r_iterator = data.DataLoader(rots, batch_size=args.rb, shuffle=False)
        t_iterator = data.DataLoader(trans, batch_size=args.tb, shuffle=False)
        r_it = 0
        for rot, r_id in r_iterator:
            if use_cuda:
                rot = rot.cuda()
            vr, vi = projector.rotate(rot)
            t_it = 0
            for tr, t_id in t_iterator:
                if use_cuda:
                    tr = tr.cuda()
                vtr, vti = projector.translate(vr, vi, tr.expand(rot.size(0), *tr.shape))
                # todo: check volume
                err = projector.compute_err(vtr, vti)  # R x T
                pose_err[r_it:r_it + len(rot), t_it:t_it + len(tr)] = err.cpu().numpy()
                t_it += len(tr)
            r_it += len(rot)

        # best error of each rotation over all translations, and of each translation over all rotations
        r_err = pose_err.min(1)
        r_err_argmin = r_err.argsort()[:max_keep_r]
        t_err = pose_err.min(0)
        t_err_argmin = t_err.argsort()[:max_keep_t]

        # record the best pose of the CURRENT grid before it is replaced by its subdivision:
        # the argmin indices are only valid for the current rots/trans
        best_r = rots.pose[r_err_argmin[0]]
        best_t = trans.pose[t_err_argmin[0]]

        rots, rots_id = subdivide_r(rots.pose[r_err_argmin], rots.pose_id[r_err_argmin], r_resol)
        rots = GridPose(rots, rots_id)
        trans, trans_id = subdivide_t(trans.pose_id[t_err_argmin], t_resol, T_EXTENT, T_NGRID)
        trans = GridPose(trans, trans_id)
        r_resol += 1
        t_resol += 1
        vlog(r_err[r_err_argmin])
        vlog(t_err[t_err_argmin])

    r = best_r
    # rescale the translation found on the args.max_D grid to the full box size
    t = best_t * vol.shape[0] / args.max_D
    log('Best rot: {}'.format(r))
    log('Best trans: {}'.format(t))
    t *= 2 / vol.shape[0]

    projector = VolumeAligner(vol, vol_ref=ref, maxD=vol.shape[0], flip=args.flip)
    if use_cuda:
        projector.use_cuda()
    vr = projector.real_tform(torch.tensor(r).unsqueeze(0), torch.tensor(t).view(1, 1, 3))
    v = vr.squeeze().cpu().numpy()
    log('Saving {}'.format(args.o))
    mrc.write(args.o, v.astype(np.float32))

    td = time.time() - t1
    log('Finished in {}s'.format(td))
# coding: utf-8
# Sanity check: the toy projection stack loads identically via every supported input format.
import sys, os
from cryodrgn import mrc
import numpy as np

# lazy and eager .mrcs loading must agree
lazy_imgs, _ = mrc.parse_mrc('data/toy_projections.mrcs', lazy=True)
eager_imgs, _ = mrc.parse_mrc('data/toy_projections.mrcs', lazy=False)
data1 = np.asarray([img.get() for img in lazy_imgs])
assert (data1 == eager_imgs).all()
print('ok')

# .star and .txt particle lists must reproduce the same images
from cryodrgn import dataset
for particle_file in ('data/toy_projections.star', 'data/toy_projections.txt'):
    loaded = dataset.load_particles(particle_file)
    assert (data1 == loaded).all()
    print('ok')

print('all ok')
def analyze_volumes(outdir, K, dim, M, linkage, vol_ind=None, plot_dim=5, particle_ind_orig=None):
    '''
    PCA and agglomerative clustering of the K k-means volumes in outdir/kmeans{K}

    Inputs
        outdir: analysis directory containing kmeans{K}/, mask.mrc and umap.pkl
        K: number of k-means volumes (vol_000.mrc ... vol_{K-1}.mrc)
        dim: number of PCA components to fit
        M: number of states for agglomerative clustering
        linkage: linkage criterion passed to AgglomerativeClustering
        vol_ind: optional indices selecting a subset of the K volumes
        plot_dim: number of leading PCs to traverse and plot (must be <= dim)
        particle_ind_orig: optional array mapping particle indices back to the original dataset

    Outputs (written under outdir)
        vol_pca_{K}.pkl, vol_pca_obj.pkl, vol_pcs/pc*/{0..9}.mrc, vol_pca_*.png, and a
        clustering_L2_{linkage}_{M}/ directory with per-state mean/std volumes, symlinks to
        member volumes, particle indices, count barplots and PCA/UMAP scatter plots
    '''
    cmap = choose_cmap(M)

    # load mean volume, compute it if it does not exist
    if not os.path.exists(f'{outdir}/kmeans{K}/vol_mean.mrc'):
        volm = np.array([mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0] for i in range(K)]).mean(axis=0)
        mrc.write(f'{outdir}/kmeans{K}/vol_mean.mrc', volm)
    else:
        volm = mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_mean.mrc')[0]

    # load mask
    mask = mrc.parse_mrc(f'{outdir}/mask.mrc')[0].astype(bool)
    log(f'{mask.sum()} voxels in mask')

    # load volumes (voxels inside the mask only) and clip negative density to zero
    vols = np.array([mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{i:03d}.mrc')[0][mask] for i in range(K)])
    vols[vols < 0] = 0

    # load umap
    umap = utils.load_pkl(f'{outdir}/umap.pkl')
    ind = np.loadtxt(f'{outdir}/kmeans{K}/centers_ind.txt').astype(int)

    if vol_ind is not None:
        log(f'Filtering to {len(vol_ind)} volumes')
        vols = vols[vol_ind]
        ind = ind[vol_ind]

    # compute PCA
    pca = PCA(dim)
    pca.fit(vols)
    pc = pca.transform(vols)
    utils.save_pkl(pc, f'{outdir}/vol_pca_{K}.pkl')
    utils.save_pkl(pca, f'{outdir}/vol_pca_obj.pkl')
    log('Explained variance ratio:')
    log(pca.explained_variance_ratio_)

    # save rxn coordinates: 10 volumes spanning the observed range along each of the first plot_dim PCs
    for i in range(plot_dim):
        subdir = f'{outdir}/vol_pcs/pc{i+1}'
        os.makedirs(subdir, exist_ok=True)
        min_, max_ = pc[:, i].min(), pc[:, i].max()
        log((min_, max_))
        for j, val in enumerate(np.linspace(min_, max_, 10, endpoint=True)):
            v = volm.copy()
            v[mask] += pca.components_[i] * val
            mrc.write(f'{subdir}/{j}.mrc', v)

    # which plots to show???
    def plot(i, j):
        plt.figure()
        plt.scatter(pc[:, i], pc[:, j])
        plt.xlabel(f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:03f})')
        plt.ylabel(f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:03f})')
        plt.savefig(f'{outdir}/vol_pca_{K}_{i+1}_{j+1}.png')
        plt.close()

    for i in range(plot_dim - 1):
        plot(i, i + 1)

    # clustering
    subdir = f'{outdir}/clustering_L2_{linkage}_{M}'
    os.makedirs(subdir, exist_ok=True)
    cluster = AgglomerativeClustering(n_clusters=M, affinity='euclidean', linkage=linkage)
    labels = cluster.fit_predict(vols)
    utils.save_pkl(labels, f'{subdir}/state_labels.pkl')

    kmeans_labels = utils.load_pkl(f'{outdir}/kmeans{K}/labels.pkl')
    kmeans_counts = Counter(kmeans_labels)
    for i in range(M):
        vol_i = np.where(labels == i)[0]
        log(f'State {i}: {len(vol_i)} volumes')
        if vol_ind is not None:
            # map positions within the filtered subset back to k-means volume ids
            vol_i = np.arange(K)[vol_ind][vol_i]
        vol_i_all = np.array([mrc.parse_mrc(f'{outdir}/kmeans{K}/vol_{v:03d}.mrc')[0] for v in vol_i])
        nparticles = np.array([kmeans_counts[v] for v in vol_i])
        # particle-count-weighted mean and standard deviation of the member volumes
        vol_i_mean = np.average(vol_i_all, axis=0, weights=nparticles)
        vol_i_std = np.average((vol_i_all - vol_i_mean)**2, axis=0, weights=nparticles)**.5
        mrc.write(f'{subdir}/state_{i}_mean.mrc', vol_i_mean.astype(np.float32))
        mrc.write(f'{subdir}/state_{i}_std.mrc', vol_i_std.astype(np.float32))
        os.makedirs(f'{subdir}/state_{i}', exist_ok=True)
        for v in vol_i:
            link = f'{subdir}/state_{i}/vol_{v:03d}.mrc'
            # replace a stale symlink from a previous run instead of raising FileExistsError
            if os.path.islink(link):
                os.remove(link)
            os.symlink(f'{outdir}/kmeans{K}/vol_{v:03d}.mrc', link)
        particle_ind = analysis.get_ind_for_cluster(kmeans_labels, vol_i)
        log(f'State {i}: {len(particle_ind)} particles')
        if particle_ind_orig is not None:
            utils.save_pkl(particle_ind_orig[particle_ind], f'{subdir}/state_{i}_particle_ind.pkl')
        else:
            utils.save_pkl(particle_ind, f'{subdir}/state_{i}_particle_ind.pkl')

    # plot clustering results
    def hack_barplot(counts_):
        # keyword x=/y= vectors: positional data arguments are not accepted by newer seaborn
        if M <= 20:  # HACK TO GET COLORS
            with sns.color_palette(cmap):
                g = sns.barplot(x=np.arange(M), y=counts_)
        else:  # default is husl
            g = sns.barplot(x=np.arange(M), y=counts_)
        return g

    plt.figure()
    counts = Counter(labels)
    g = hack_barplot([counts[i] for i in range(M)])
    for i in range(M):
        g.text(i - .1, counts[i] + 2, counts[i])
    plt.xlabel('State')
    plt.ylabel('Count')
    plt.savefig(f'{subdir}/state_volume_counts.png')
    plt.close()

    plt.figure()
    particle_counts = [np.sum([kmeans_counts[ii] for ii in np.where(labels == i)[0]]) for i in range(M)]
    g = hack_barplot(particle_counts)
    for i in range(M):
        g.text(i - .1, particle_counts[i] + 2, particle_counts[i])
    plt.xlabel('State')
    plt.ylabel('Count')
    plt.savefig(f'{subdir}/state_particle_counts.png')
    plt.close()

    def plot_w_labels(i, j):
        plt.figure()
        plt.scatter(pc[:, i], pc[:, j], c=labels, cmap=cmap)
        plt.xlabel(f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:03f})')
        plt.ylabel(f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:03f})')
        plt.savefig(f'{subdir}/vol_pca_{K}_{i+1}_{j+1}.png')
        plt.close()

    for i in range(plot_dim - 1):
        plot_w_labels(i, i + 1)

    def plot_w_labels_annotated(i, j):
        fig, ax = plt.subplots(figsize=(16, 16))
        plt.scatter(pc[:, i], pc[:, j], c=labels, cmap=cmap)
        annots = np.arange(K)
        if vol_ind is not None:
            annots = annots[vol_ind]
        for ii, k in enumerate(annots):
            ax.annotate(str(k), pc[ii, [i, j]] + np.array([.1, .1]))
        plt.xlabel(f'Volume PC{i+1} (EV: {pca.explained_variance_ratio_[i]:03f})')
        plt.ylabel(f'Volume PC{j+1} (EV: {pca.explained_variance_ratio_[j]:03f})')
        plt.savefig(f'{subdir}/vol_pca_{K}_annotated_{i+1}_{j+1}.png')
        plt.close(fig)

    for i in range(plot_dim - 1):
        plot_w_labels_annotated(i, i + 1)

    # plot clusters on UMAP
    umap_i = umap[ind]
    fig, ax = plt.subplots(figsize=(8, 8))
    plt.scatter(umap[:, 0], umap[:, 1], alpha=.1, s=1, rasterized=True, color='lightgrey')
    colors = get_colors_for_cmap(cmap, M)
    for i in range(M):
        c = umap_i[np.where(labels == i)]
        plt.scatter(c[:, 0], c[:, 1], label=i, color=colors[i])
    plt.legend()
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    plt.savefig(f'{subdir}/umap.png')
    plt.close(fig)

    fig, ax = plt.subplots(figsize=(16, 16))
    plt.scatter(umap[:, 0], umap[:, 1], alpha=.1, s=1, rasterized=True, color='lightgrey')
    plt.scatter(umap_i[:, 0], umap_i[:, 1], c=labels, cmap=cmap)
    annots = np.arange(K)
    if vol_ind is not None:
        annots = annots[vol_ind]
    for i, k in enumerate(annots):
        ax.annotate(str(k), umap_i[i] + np.array([.1, .1]))
    plt.xlabel('UMAP1')
    plt.ylabel('UMAP2')
    plt.savefig(f'{subdir}/umap_annotated.png')
    plt.close(fig)
def main(args):
    '''
    Simulate projection images of the volume args.mrc and save them with their poses.

    Rotations are either random (args.N of them) or a grid (args.grid). Images are
    optionally shifted by random in-plane translations within +/- args.t_extent pixels.
    The images go to args.o and the poses to args.out_pose (pickle; rots alone, or
    (rots, trans) when shifted). The first 9 images are optionally previewed to
    args.out_png.
    '''
    for out in (args.o, args.out_png, args.out_pose):
        if not out:
            continue
        mkbasedir(out)
        warnexists(out)

    if args.t_extent == 0.:
        log('Not shifting images')
    else:
        assert args.t_extent > 0

    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    use_cuda = torch.cuda.is_available()
    log('Use cuda {}'.format(use_cuda))
    if use_cuda:
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    t1 = time.time()
    vol, _ = mrc.parse_mrc(args.mrc)
    log('Loaded {} volume'.format(vol.shape))

    if args.tilt:
        # replace the tilt angle (degrees) by the matrix of a rotation about the x axis
        theta = args.tilt * np.pi / 180
        args.tilt = np.array([[1., 0., 0.],
                              [0, np.cos(theta), -np.sin(theta)],
                              [0, np.sin(theta), np.cos(theta)]]).astype(np.float32)

    projector = Projector(vol, args.tilt)
    if use_cuda:
        projector.lattice = projector.lattice.cuda()
        projector.vol = projector.vol.cuda()

    if args.grid is not None:
        rots = GridRot(args.grid)
        log('Generating {} rotations at resolution level {}'.format(len(rots), args.grid))
    else:
        log('Generating {} random rotations'.format(args.N))
        rots = RandomRot(args.N)
    # with a grid, args.N is unrelated to the number of rotations; use the real count from here on
    N = len(rots)

    log('Projecting...')
    imgs = []
    iterator = data.DataLoader(rots, batch_size=args.b)
    n_done = 0
    for rot in iterator:
        n_done += len(rot)
        vlog('Projecting {}/{}'.format(n_done, N))
        projections = projector.project(rot)
        projections = projections.cpu().numpy()
        imgs.append(projections)

    rots = rots.rots.cpu().numpy()
    imgs = np.vstack(imgs)
    td = time.time() - t1
    log('Projected {} images in {}s ({}s per image)'.format(N, td, td / N))

    if args.t_extent:
        log('Shifting images between +/- {} pixels'.format(args.t_extent))
        trans = np.random.rand(N, 2) * 2 * args.t_extent - args.t_extent
        imgs = np.asarray([translate_img(img, t) for img, t in zip(imgs, trans)])
        # convention: we want the first column to be x shift and second column to be y shift
        # reverse columns since current implementation of translate_img uses scipy's
        # fourier_shift, which is flipped the other way
        # convention: save the translation that centers the image
        trans = -trans[:, ::-1]
        # convert translation from pixel to fraction
        D = imgs.shape[-1]
        assert D % 2 == 0
        trans /= D

    log('Saving {}'.format(args.o))
    mrc.write(args.o, imgs.astype(np.float32))
    log('Saving {}'.format(args.out_pose))
    with open(args.out_pose, 'wb') as f:
        if args.t_extent:
            pickle.dump((rots, trans), f)
        else:
            pickle.dump(rots, f)
    if args.out_png:
        log('Saving {}'.format(args.out_png))
        plot_projections(args.out_png, imgs[:9])