def plot_y_centroids(yi, lami, wi, vlim=(-4, 4), fs_xlabel=12, fs_ylabel=12,
                     fs_xticks=12, fs_yticks=12, rotation=90, ha="center",
                     markernames=[], ylab="subpopulations",
                     gridlines_color='black', gridlines_lw=1,
                     cm=blue2red.cm(9), population=None, Zi=None):
    """Draw a heatmap of cluster centroids on the current axes.

    A centroid is the column-wise mean of the rows of `yi` carrying a given
    label in `lami`.  Only clusters with strictly positive weight in `wi`
    are drawn, one row per cluster, ordered by increasing weight.  Each row
    is labelled ``'<name> (<weight in percent, 1 decimal>%)'`` where
    ``<name>`` is the 1-based cluster label, or
    ``population.label(Zi[:, label - 1])`` when both `population` and `Zi`
    are given.

    Args:
        yi: 2-D array; its columns are the markers (``J = yi.shape[1]``).
        lami: label of each row of `yi`.  Labels are 1-based: entry ``k``
            of `wi` corresponds to label ``k + 1``.
        wi: 1-D array of cluster weights.
        vlim: ``(vmin, vmax)`` colour limits of the heatmap.
        fs_xlabel, fs_ylabel: font sizes of the axis labels.
        fs_xticks, fs_yticks: font sizes of the tick labels.
        rotation, ha: rotation and horizontal alignment of x tick labels.
        markernames: x tick labels; when empty, markers are numbered 1..J.
        ylab: y-axis label.
        gridlines_color, gridlines_lw: forwarded to ``gridlines``.
        cm: colormap for the heatmap.
        population, Zi: optional; when both are given, ``Zi[:, k]`` is
            passed to ``population.label`` to name the cluster with label
            ``k + 1``.
    """
    J = yi.shape[1]

    # Clusters with positive weight, sorted by increasing weight.
    positive = np.argwhere(wi > 0)[:, 0]
    selected_features = positive[np.argsort(wi[positive])]
    wi_sel_sorted = wi[selected_features]
    K_sel = selected_features.shape[0]

    y_centers = np.zeros((J, K_sel))
    yticks = []
    for k, k_ in enumerate(selected_features + 1):  # k_ is the 1-based label
        # Column means over the rows assigned to this cluster; NaN (with a
        # RuntimeWarning) if no row carries label k_.
        y_centers[:, k] = yi[lami == k_, :].mean(axis=0)

        # The row label depends only on the cluster, so build it once here
        # instead of once per marker.
        w_ik_perc = (wi_sel_sorted[k] * 100).round(1)
        if population is None or Zi is None:
            name = k_
        else:
            name = population.label(Zi[:, k_ - 1])
        yticks.append('{} ({}%)'.format(name, w_ik_perc))

    im = plt.imshow(y_centers.T, aspect='auto', cmap=cm,
                    vmin=vlim[0], vmax=vlim[1])

    # Use len(): `markernames == []` compares elementwise when markernames
    # is a NumPy array or pandas Index, and raises on the shape mismatch in
    # recent NumPy/pandas versions.
    if len(markernames) == 0:
        markernames = np.arange(J) + 1
    plt.xticks(range(J), markernames, rotation=rotation, fontsize=fs_xticks,
               ha=ha)
    plt.yticks(range(K_sel), yticks, fontsize=fs_yticks)
    plt.xlabel('markers', fontsize=fs_xlabel)
    plt.ylabel(ylab, fontsize=fs_ylabel)
    gridlines(y_centers.T, color=gridlines_color, lw=gridlines_lw)
    colorbar_horizontal(im)
def plot_y(yi, wi_mean, lami_est, fs_lab=10, fs_cbar=10, lw=3,
           cm=blue2red.cm(6), vlim=(-3, 3), fs_xlab=10, fs_ylab=10,
           ylab="cells", markernames=[], interpolation=None, rotation=90,
           ha="center"):
    """Draw `yi` as a heatmap with rows grouped by cluster, on the current axes.

    Rows are reordered by the labels returned from
    ``relabel_lam(lami_est, wi_mean)``.  Yellow horizontal lines mark the
    boundaries between consecutive groups.  A horizontal colorbar is placed
    in a thin axes appended above the heatmap.

    Args:
        yi: 2-D array; rows are plotted on the y axis, columns (markers) on
            the x axis.
        wi_mean: array of cluster weights, or an integer K.  With an integer,
            the weights are the empirical frequencies of labels 1..K in
            `lami_est`.
        lami_est: label of each row of `yi`.  These are 1-based, as the
            ``k + 1`` comparison below assumes.
        fs_lab: font size of both axis labels.
        fs_cbar: font size of the colorbar tick labels.
        lw: line width of the group-boundary lines.
        cm: colormap for the heatmap.
        vlim: ``(vmin, vmax)`` colour limits.
        fs_xlab: font size of the x tick labels (marker names).
        fs_ylab: font size of the y tick labels.
        ylab: y-axis label.
        markernames: x tick labels; when empty, markers are numbered 1..J.
        interpolation: forwarded to ``plt.imshow``.
        rotation, ha: rotation and horizontal alignment of x tick labels.
    """
    J = yi.shape[1]
    vmin, vmax = vlim

    # isinstance (rather than `type(...) == int`) also accepts NumPy integer
    # scalars such as np.int64.
    if isinstance(wi_mean, (int, np.integer)):
        K = wi_mean
        wi_mean = np.array([(lami_est == k + 1).mean() for k in range(K)])

    # NOTE(review): assumes relabel_lam returns per-group counts ordered by
    # the new label, so cumsum(counts) gives the row boundaries after the
    # argsort below — confirm against relabel_lam.
    lami_new, counts = relabel_lam(lami_est, wi_mean)
    counts_cumsum = np.cumsum(counts)
    yi_sorted = yi[np.argsort(lami_new), :]

    im = plt.imshow(yi_sorted, aspect='auto', vmin=vmin, vmax=vmax, cmap=cm,
                    interpolation=interpolation)
    # Image row r is centred on y = r.  The edge between the last row of one
    # group (c - 1) and the first row of the next (c) is therefore c - 0.5.
    for c in counts_cumsum[:-1]:
        plt.axhline(c - 0.5, color='yellow', linewidth=lw)

    if len(markernames) == 0:
        markernames = np.arange(J) + 1
    plt.xticks(np.arange(J), markernames, fontsize=fs_xlab,
               rotation=rotation, ha=ha)
    plt.yticks(fontsize=fs_ylab)
    plt.xlabel("markers", fontsize=fs_lab)
    plt.ylabel(ylab, fontsize=fs_lab)

    # Horizontal colorbar in a thin axes appended above the heatmap.
    ax = plt.gca()
    ax_divider = make_axes_locatable(ax)
    cax = ax_divider.append_axes("top", size="7%", pad="2%")
    cax.xaxis.set_ticks_position("top")
    cbar = colorbar(im, cax=cax, orientation="horizontal")
    cbar.ax.tick_params(labelsize=fs_cbar)
# Work on a deep copy so subsampling below does not modify data['y'].
y = copy.deepcopy(data['y'])

# Get a subsample of data: when 0 < subsample < 1, keep a random fraction
# `subsample` of the rows of each y[i], drawn without replacement.
# NOTE(review): np.random is not seeded here; the subsample is reproducible
# only if a seed is set earlier in the script — confirm.
if 0 < subsample < 1:
    for i in range(len(y)):
        Ni = y[i].shape[0]
        idx = np.random.choice(Ni, int(Ni * subsample), replace=False)
        y[i] = y[i][idx, :]

# Print size of data (number of rows of each y[i]).
print('N: {}'.format([yi.shape[0] for yi in y]))

# Color maps.
# `plt.cm.get_cmap` (i.e. matplotlib.cm.get_cmap) was deprecated in
# Matplotlib 3.7 and removed in 3.9.  `plt.get_cmap(name, lut)` returns the
# same resampled colormap and works on old and new versions.
cm_greys = plt.get_cmap('Greys', 5)
VMIN, VMAX = VLIM = (-4, 4)
cm = blue2red.cm(9)

# Heatmaps: one PDF per y[i], saved as {img_dir}/y{i+1}.pdf.
for i in range(I):
    plt.imshow(y[i], aspect='auto', vmin=VMIN, vmax=VMAX, cmap=cm)
    plt.colorbar()
    plt.savefig('{}/y{}.pdf'.format(img_dir, i + 1))
    plt.close()

K = 30
L = [5, 3]
yi_path = f'{path}/img/txt/y{i}_mean.csv' yi = np.loadtxt(yi_path, delimiter=',') lami_path = f'{path}/img/txt/lam{i}_best.txt' lami = np.loadtxt(lami_path, dtype=int) wi_path = f'{path}/img/txt/W{i}_best.txt' wi = np.loadtxt(wi_path) zi_path = f'{path}/img/txt/Z{i}_best.txt' zi = np.loadtxt(zi_path) # Plot plt.figure(figsize=(6, 6)) plot_yz.plot_y_centroids(yi, lami, wi, vlim=(-3, 3), cm=blue2red.cm(6), population=population, Zi=zi, fs_xlabel=16, fs_ylabel=16, fs_xticks=16, fs_yticks=16) outpath = f'{path}/img/y{i}_centroid.pdf' plt.savefig(outpath, bbox_inches="tight") plt.close() # Plot Z estimate. plt.figure(figsize=(6, 6)) plot_yz.plot_Z(Z_mean=zi, wi_mean=wi, lami_est=lami,
population = Population() # Rcounts R_path = f'{path}/img/txt/Rcounts.txt' R = np.loadtxt(R_path) plt.figure(figsize=(5,5)) zinfo.plot_num_selected_features(R, refs=[6, 5], ymax=1.05) plt.savefig(f'{path}/img/Rcounts.pdf', bbox_inches='tight') plt.close() for i in (1, 2): # Read data yi_path = f'{path}/img/txt/y{i}_mean.csv' yi = np.loadtxt(yi_path, delimiter=',') lami_path = f'{path}/img/txt/lam{i}_best.txt' lami = np.loadtxt(lami_path, dtype=int) wi_path = f'{path}/img/txt/W{i}_best.txt' wi = np.loadtxt(wi_path) zi_path = f'{path}/img/txt/Z{i}_best.txt' zi = np.loadtxt(zi_path) # Plot plt.figure(figsize=(6,6)) plot_yz.plot_y_centroids(yi, lami, wi, vlim=(-3, 3), cm=blue2red.cm(6), population=population, Zi=zi, fs_xlabel=16, fs_ylabel=16, fs_xticks=16, fs_yticks=16) outpath = f'{path}/img/y{i}_centroid.pdf' plt.savefig(outpath, bbox_inches="tight") plt.close()
def plot_yz(yi, Z_mean, wi_mean, lami_est, w_thresh=.01,
            cm_greys=plt.get_cmap('Greys', 5), markernames=[], rotation=90,
            ha="center", cm_y=blue2red.cm(6), vlim_y=(-3, 3), fs_w=10,
            w_digits=1):
    """Draw `yi` grouped by cluster (top panel) above `Z_mean` (bottom panel).

    Top panel: heatmap of `yi` with rows reordered by the labels from
    ``relabel_lam(lami_est, wi_mean)``.  Yellow lines separate the groups
    and a horizontal colorbar sits above the panel.

    Bottom panel: ``Z_mean[:, k]`` (transposed, one row per k) in greyscale
    on [0, 1], for every column k with ``wi_mean[k] > w_thresh``, ordered by
    increasing weight.  Left tick labels are the 1-based column indices.
    Right tick labels are the corresponding weights in percent, rounded to
    `w_digits` decimals.

    Args:
        yi: 2-D array; columns are the markers (``J = yi.shape[1]``).
        Z_mean: 2-D array whose columns are selected by weight.
        wi_mean: 1-D NumPy array of weights, one per column of `Z_mean`.
        lami_est: label of each row of `yi`, passed to ``relabel_lam``.
        w_thresh: columns with weight <= w_thresh are not shown.
        cm_greys: colormap of the bottom panel.
        markernames: x tick labels of the top panel; when empty, markers
            are numbered 1..J.
        rotation, ha: rotation and horizontal alignment of x tick labels.
        cm_y: colormap of the top panel.
        vlim_y: ``(vmin, vmax)`` colour limits of the top panel.
        fs_w: font size of the bottom panel's tick labels.
        w_digits: decimals used for the weight labels.
    """
    J = yi.shape[1]
    vmin_y, vmax_y = vlim_y
    gs = gridspec.GridSpec(2, 1, height_ratios=[5, 2])

    # ---- Top panel: y with rows grouped by cluster ----
    # NOTE(review): assumes relabel_lam returns counts ordered by the new
    # label, matching the argsort below — confirm against relabel_lam.
    lami_new, counts = relabel_lam(lami_est, wi_mean)
    counts_cumsum = np.cumsum(counts)
    yi_sorted = yi[np.argsort(lami_new), :]

    plt.subplot(gs[0])
    im = plt.imshow(yi_sorted, aspect='auto', vmin=vmin_y, vmax=vmax_y,
                    cmap=cm_y)
    # Image row r is centred on y = r, so group boundaries lie at c - 0.5.
    for c in counts_cumsum[:-1]:
        plt.axhline(c - 0.5, color='yellow')
    if len(markernames) == 0:
        markernames = np.arange(J) + 1
    plt.xticks(np.arange(J), markernames, rotation=rotation, ha=ha)

    ax = plt.gca()
    ax_divider = make_axes_locatable(ax)
    cax = ax_divider.append_axes("top", size="7%", pad="2%")
    cax.xaxis.set_ticks_position("top")
    colorbar(im, cax=cax, orientation="horizontal")

    # ---- Bottom panel: columns of Z whose weight exceeds w_thresh ----
    # Sort by increasing weight, then keep those above the threshold.
    # Mask selection keeps an integer index dtype even when nothing passes.
    k_ord = wi_mean.argsort()
    z_cols = k_ord[wi_mean[k_ord] > w_thresh]
    Z_hat = Z_mean[:, z_cols].T

    plt.subplot(gs[1])
    im_z = plt.imshow(Z_hat, aspect='auto', vmin=0, vmax=1, cmap=cm_greys)
    ax = plt.gca()
    plt.xticks([])
    plt.yticks(np.arange(len(z_cols)), z_cols + 1, fontsize=fs_w)
    add_gridlines_Z(Z_hat)
    plt.colorbar(im_z, orientation='horizontal', pad=.05)

    # Weights of the displayed columns on a twin y axis (right side).  The
    # twin gets the image's own (inverted) y-limits, so tick k sits exactly
    # on image row k, which shows column z_cols[k].
    K = z_cols.shape[0]
    ax2 = ax.twinx()
    ax2.set_ylim(ax.get_ylim())
    w_perc = [str((wp * 100).round(w_digits)) + '%'
              for wp in wi_mean[z_cols]]
    ax2.set_yticks(np.arange(K))
    ax2.set_yticklabels(w_perc, fontsize=fs_w)
    ax2.tick_params(length=0)

    fig = plt.gcf()
    fig.subplots_adjust(hspace=0.2)