def plot_count_sum_series(x, w, p, row_start, tit):
    """Draw one row of count-sum series panels comparing ``x`` with ``w``.

    One panel is drawn per entry of ``series_config``. Each panel spans 3 grid
    rows starting at ``row_start`` and 3 grid columns, laid out left to right.

    NOTE(review): ``series_config``, ``grid``, ``subplot``,
    ``plot_series_statistics``, ``_show_zero_inflated_pi`` and ``plt`` are
    free names here and must be resolvable when this is called; at module
    level the visible code does not define ``grid``/``series_config`` -
    confirm this copy is actually used.

    Args:
        x: observed data, summed over axis 0.
        w: ``(expected, stdev_total, stdev_explained)``. Any other length
            means no statistics are available, so nothing is drawn.
        p: optional zero-inflation probabilities. Their mean over axis 0 is
            overlaid on every panel.
        row_start: first grid row used by the panels.
        tit: title for the left-most panel only.

    Returns:
        None.
    """
    # Only a full (expected, total, explained) triple can be plotted.
    if len(w) != 3:
        return None
    expected, stdev_total, stdev_explained = w
    observed_sum, expected_sum, total_sum, explained_sum = [
        np.sum(arr, axis=0)
        for arr in (x, expected, stdev_total, stdev_explained)
    ]
    show_pi = p is not None
    pi_mean = np.mean(p, axis=0) if show_pi else None
    row_span = slice(row_start, row_start + 3)

    for col, panel_kws in enumerate(series_config):
        first_col = col * 3
        panel = subplot(grid[row_span, first_col:first_col + 3])
        panel, handles, indices = plot_series_statistics(
            observed_sum,
            expected_sum,
            explained_stdev=explained_sum,
            total_stdev=total_sum,
            fontsize=8,
            ax=panel,
            legend_enable=False,
            title=None if col else tit,
            # keep the right spine when a second (pi) scale is overlaid
            despine=not show_pi,
            return_handles=True,
            return_indices=True,
            **panel_kws)
        if show_pi:
            _show_zero_inflated_pi(pi_mean, panel, handles, indices)
        plt.legend(handles=handles, loc='best', markerscale=4, fontsize=8)
def plot_countsum_series(original, imputed, p=None, reduce_axis=0,
                         title=None, ax=None):
    """Plot log1p count-sums of the observed data against the imputation.

    Parameters
    ----------
    original : array-like, [n_samples, n_genes]
        Observed counts.
    imputed : tuple/list ``(expected, stdev_total, stdev_explained)``, or an
        array of shape ``[3, n_samples, n_genes]`` stacking the same three
        arrays in that order.
    p : array-like, [n_samples, n_genes], optional
        Dropout probability. Its mean along ``reduce_axis`` is overlaid on
        the plot.
    reduce_axis : int
        Axis that is summed: 0 sums over samples, 1 over genes.
    title : str, optional
        Panel title.
    ax : optional
        Axes to draw on. When None, one is obtained from ``visual.to_axis``.

    Returns
    -------
    The axes that was drawn on.

    Raises
    ------
    ValueError
        If ``imputed`` does not hold exactly three arrays. The check runs
        before any axes is created.
    """
    reduce_axis = int(reduce_axis)

    # Validate first so a bad call does not leave an empty figure behind.
    if isinstance(imputed, (tuple, list)):
        if len(imputed) != 3:
            raise ValueError(
                "imputed must be (expected, stdev_total, stdev_explained), "
                "got %d element(s)" % len(imputed))
        expected, stdev_total, stdev_explained = imputed
    elif getattr(imputed, 'ndim', None) == 3:
        if imputed.shape[0] != 3:
            raise ValueError(
                "stacked imputed array must have shape [3, ...], got %s" %
                (imputed.shape,))
        expected, stdev_total, stdev_explained = (imputed[0], imputed[1],
                                                  imputed[2])
    else:
        raise ValueError(
            "imputed must be a 3-tuple/list or a [3, n_samples, n_genes] "
            "array, got %r" % (type(imputed),))

    if ax is None:
        ax = visual.to_axis(ax)

    def _log_count_sum(arr):
        return np.log1p(np.sum(arr, axis=reduce_axis))

    count_sum_observed = _log_count_sum(original)
    count_sum_expected = _log_count_sum(expected)
    # NOTE(review): the stdev arrays are summed and then log1p-transformed
    # like the means (not propagated through the log) - kept as-is.
    count_sum_stdev_total = _log_count_sum(stdev_total)
    count_sum_stdev_explained = _log_count_sum(stdev_explained)
    p_sum = None if p is None else np.mean(p, axis=reduce_axis)

    ax, handles, indices = plot_series_statistics(
        count_sum_observed,
        count_sum_expected,
        explained_stdev=count_sum_stdev_explained,
        total_stdev=count_sum_stdev_total,
        fontsize=8,
        ax=ax,
        legend_enable=False,
        title=title,
        # keep the right spine when the pi overlay needs a second scale
        despine=p is None,
        return_handles=True,
        return_indices=True,
        xscale='linear',
        yscale='linear',
        sort_by='expected')
    if p_sum is not None:
        _show_zero_inflated_pi(p_sum, ax, handles, indices)
    ax.legend(handles=handles, loc='best', markerscale=4, fontsize=8)
    return ax
def plot_epoch(task):
    """Plot latent-space, reconstruction and count-sum diagnostics for an epoch.

    Only epochs 0-4 and every fifth epoch afterwards are plotted; other
    epochs return immediately. Uses the module-level ``X_test``/``y_test``,
    the model functions ``f_z`` (latent codes) and ``f_w`` (reconstruction
    plus MCMC and analytic standard deviations), ``Z_names``, ``n_classes``
    and ``FIGURE_PATH``. Saves ``FIGURE_PATH/latent_<epoch>.png``.

    Args:
        task: object exposing ``curr_epoch``, or None (treated as epoch 0).

    Returns:
        None. The function no longer terminates the interpreter after
        saving, so training can continue to later epochs.
    """
    curr_epoch = 0 if task is None else task.curr_epoch
    if not (curr_epoch < 5 or curr_epoch % 5 == 0):
        return
    rand = np.random.RandomState(seed=1234)

    X, y = X_test, y_test
    n_data = X.shape[0]
    Z = f_z(X)
    W, W_stdev_mcmc, W_stdev_analytic = f_w(X)

    X_pca, W_pca_1 = fast_pca(X, W, n_components=2,
                              random_state=rand.randint(10e8))
    W_pca_2 = fast_pca(W, n_components=2, random_state=rand.randint(10e8))
    # sum every non-batch axis of X (it may be image-shaped)
    X_count_sum = np.sum(X, axis=tuple(range(1, X.ndim)))
    W_count_sum = np.sum(W, axis=-1)

    # shared by the count-sum row and every per-sample row
    series_configs = (
        dict(xscale='linear', yscale='linear', sort_by=None),
        dict(xscale='linear', yscale='linear', sort_by='expected'),
        dict(xscale='log', yscale='log', sort_by='expected'),
    )

    n_visual_samples = 8
    nrow = 13 + n_visual_samples * 3
    V.plot_figure(nrow=int(nrow * 1.8), ncol=18)
    with V.plot_gridSpec(nrow=nrow + 3, ncol=6, hspace=0.8) as grid:
        # plot the latent space
        for i, (z, name) in enumerate(zip(Z, Z_names)):
            if z.shape[1] > 2:
                z = fast_pca(z, n_components=2,
                             random_state=rand.randint(10e8))
            ax = V.subplot(grid[:3, (i * 2):(i * 2 + 2)])
            V.plot_scatter(x=z[:, 0], y=z[:, 1], color=y, marker=y,
                           n_samples=4000, ax=ax, legend_enable=False,
                           legend_ncol=n_classes)
            ax.set_title(name, fontsize=12)

        # plot the reconstruction
        for i, (x, name) in enumerate(
                zip([X_pca, W_pca_1, W_pca_2], [
                    'Original data', 'Reconstruction',
                    'Reconstruction (separated PCA)'
                ])):
            ax = V.subplot(grid[3:6, (i * 2):(i * 2 + 2)])
            V.plot_scatter(x=x[:, 0], y=x[:, 1], color=y, marker=y,
                           n_samples=4000, ax=ax, legend_enable=i == 1,
                           legend_ncol=n_classes, title=name)

        # plot the reconstruction count sum
        for i, (x, count_sum, name) in enumerate(
                zip([X_pca, W_pca_1], [X_count_sum, W_count_sum], [
                    'Original data (Count-sum)', 'Reconstruction (Count-sum)'
                ])):
            ax = V.subplot(grid[6:10, (i * 3):(i * 3 + 3)])
            V.plot_scatter(x=x[:, 0], y=x[:, 1], val=count_sum,
                           n_samples=2000, marker=y, ax=ax, size=8,
                           legend_enable=i == 0, legend_ncol=n_classes,
                           title=name, colorbar=True, fontsize=10)

        # plot the count-sum series
        count_sum_observed = np.sum(X, axis=0).ravel()
        count_sum_expected = np.sum(W, axis=0)
        count_sum_stdev_explained = np.sum(W_stdev_mcmc, axis=0)
        count_sum_stdev_total = np.sum(W_stdev_analytic, axis=0)
        for i, kws in enumerate(series_configs):
            ax = V.subplot(grid[10:10 + 3, (i * 2):(i * 2 + 2)])
            V.plot_series_statistics(count_sum_observed, count_sum_expected,
                                     explained_stdev=count_sum_stdev_explained,
                                     total_stdev=count_sum_stdev_total,
                                     fontsize=8, ax=ax,
                                     title="Count-sum" if i == 0 else None,
                                     **kws)

        # plot the mean and variances of a few random test samples
        curr_grid_index = 13
        ids = rand.permutation(n_data)[:n_visual_samples]
        for i in ids:
            observed = X[i].ravel()
            expected = W[i]
            stdev_explained = W_stdev_mcmc[i]
            stdev_total = W_stdev_analytic[i]
            for j, kws in enumerate(series_configs):
                ax = V.subplot(grid[curr_grid_index:curr_grid_index + 3,
                                    (j * 2):(j * 2 + 2)])
                V.plot_series_statistics(
                    observed, expected,
                    explained_stdev=stdev_explained,
                    total_stdev=stdev_total,
                    fontsize=8, ax=ax,
                    title="Test Sample #%d" % i if j == 0 else None,
                    **kws)
            curr_grid_index += 3
    V.plot_save(os.path.join(FIGURE_PATH, 'latent_%d.png' % curr_epoch),
                dpi=200, log=True)
def plot_monitoring_epoch(X, X_drop, y, Z, Z_drop, W_outputs, W_drop_outputs,
                          pi, pi_drop, row_name, dropout_percentage,
                          curr_epoch, ds_name, labels, save_dir):
    """Save the monitoring figures for one training epoch.

    Up to four PNG files are written to ``save_dir``:

    * ``latent_epoch<N>.png``: PCA scatter of the latent space, with and
      without dropout. When ``W_outputs`` is given, the original data and
      its reconstruction are also shown.
    * ``countsum_epoch<N>.png``: count-sum scatters and count-sum series
      (only when ``W_outputs`` is given).
    * ``samples_epoch<N>.png``: per-sample series for up to 8 random rows
      (only when ``W_outputs`` holds all three statistics).
    * ``image_epoch<N>.png``: original vs. reconstructed 28x28 images
      (only when ``'mnist'`` is in ``ds_name``).

    Inputs with more than 8000 rows are randomly subsampled to 8000 rows,
    using a fixed seed.

    Args:
        X, X_drop: test data and its dropped-out copy.
        y: integer labels, or a 2-D one-hot array (argmax'ed). Mapped to
            names through ``labels``.
        Z, Z_drop: latent codes for ``X`` and ``X_drop``.
        W_outputs, W_drop_outputs: reconstructions of ``X``/``X_drop``, or
            None. Order is ``[W, W_stdev_total, W_stdev_explained]``; only
            ``W`` may be present.
        pi, pi_drop: zero-inflation probabilities, or None.
        row_name: per-row names used in the sample-panel titles.
        dropout_percentage: multiplied by 100 for display, so presumably a
            fraction in [0, 1].
        curr_epoch: epoch number embedded in the file names.
        ds_name: dataset name (checked for ``'mnist'``).
        labels: label names, indexed by the integer labels.
        save_dir: output directory.
    """
    # Order of W_outputs: [W, W_stdev_total, W_stdev_explained]
    from matplotlib import pyplot as plt

    if y.ndim == 2:
        y = np.argmax(y, axis=-1)
    y = np.array([labels[i] for i in y])
    dropout_percentage_text = '%g%%' % (dropout_percentage * 100)
    has_reconstruction = W_outputs is not None
    series_config = [
        dict(xscale='linear', yscale='linear', sort_by=None),
        dict(xscale='linear', yscale='linear', sort_by='expected'),
    ]

    def _save(prefix):
        # Format the file name *before* joining so that a '%' inside
        # save_dir is never treated as a format specifier.
        plot_save(os.path.join(save_dir,
                               '%s_epoch%d.png' % (prefix, curr_epoch)),
                  dpi=180, clear_all=True, log=True)

    # ====== PCA projections ====== #
    Z_pca = fast_pca(Z, n_components=2, random_state=5218)
    Z_pca_drop = fast_pca(Z_drop, n_components=2, random_state=5218)
    if has_reconstruction:
        X_pca, X_pca_drop, W_pca, W_pca_drop = fast_pca(
            X, X_drop, W_outputs[0], W_drop_outputs[0],
            n_components=2, random_state=5218)

    # ====== downsampling ====== #
    rand = np.random.RandomState(seed=5218)
    n_test_samples = len(y)
    ids = np.arange(n_test_samples, dtype='int32')
    if n_test_samples > 8000:
        ids = rand.choice(ids, size=8000, replace=False)
    y = y[ids]
    X = X[ids]
    X_drop = X_drop[ids]
    Z_pca = Z_pca[ids]
    Z_pca_drop = Z_pca_drop[ids]
    # The reconstruction-derived arrays exist only when W_outputs is given;
    # indexing them unconditionally made the W_outputs=None path crash.
    if has_reconstruction:
        X_pca = X_pca[ids]
        X_pca_drop = X_pca_drop[ids]
        W_pca = W_pca[ids]
        W_pca_drop = W_pca_drop[ids]
        W_outputs = [w[ids] for w in W_outputs]
        W_drop_outputs = [w[ids] for w in W_drop_outputs]
    if pi is not None:
        pi = pi[ids]
        pi_drop = pi_drop[ids]

    # ====== latent-space figure ====== #
    config = dict(size=6, labels=None)
    if not has_reconstruction:
        plot_figure(nrow=8, ncol=20)
        fast_scatter(x=Z_pca, y=y, title="[PCA] Test data latent space",
                     enable_legend=True, ax=(1, 2, 1), **config)
        fast_scatter(x=Z_pca_drop, y=y,
                     title="[PCA][Dropped:%s] Test data latent space" %
                     dropout_percentage_text,
                     ax=(1, 2, 2), **config)
    else:
        plot_figure(nrow=16, ncol=20)
        # original test data WITHOUT dropout
        fast_scatter(x=X_pca, y=y, title="[PCA][Test Data] Original",
                     ax=(2, 3, 1), **config)
        fast_scatter(x=W_pca, y=y, title="Reconstructed",
                     ax=(2, 3, 2), **config)
        fast_scatter(x=Z_pca, y=y, title="Latent space",
                     ax=(2, 3, 3), **config)
        # original test data WITH dropout
        fast_scatter(x=X_pca_drop, y=y,
                     title="[PCA][Dropped:%s][Test Data] Original" %
                     dropout_percentage_text,
                     ax=(2, 3, 4), **config)
        fast_scatter(x=W_pca_drop, y=y, title="Reconstructed",
                     ax=(2, 3, 5), enable_legend=True, **config)
        fast_scatter(x=Z_pca_drop, y=y, title="Latent space",
                     ax=(2, 3, 6), **config)
    _save('latent')

    # ====== plot count-sum ====== #
    if has_reconstruction:
        X_countsum = _clip_count_sum(np.sum(X, axis=-1))
        W_countsum = _clip_count_sum(np.sum(W_outputs[0], axis=-1))
        X_drop_countsum = _clip_count_sum(np.sum(X_drop, axis=-1))
        W_drop_countsum = _clip_count_sum(np.sum(W_drop_outputs[0], axis=-1))
        if pi is not None:
            pi_sum = np.mean(pi, axis=-1)
            pi_drop_sum = np.mean(pi_drop, axis=-1)

        # plot the reconstruction count sum
        plot_figure(nrow=3 * 5 + 8, ncol=18)
        with plot_gridSpec(nrow=3 * (2 if pi is None else 3) + 4 * 3 + 1,
                           ncol=6, wspace=1.0, hspace=0.8) as grid:
            kws = dict(colorbar=True, fontsize=10, size=10, marker=y,
                       n_samples=1200)
            # without dropout
            ax = subplot(grid[:3, 0:3])
            plot_scatter(x=X_pca, val=X_countsum, ax=ax, legend_enable=False,
                         title='Original data (Count-sum)', **kws)
            ax = subplot(grid[:3, 3:6])
            plot_scatter(x=W_pca, val=W_countsum, ax=ax, legend_enable=False,
                         title='Reconstruction (Count-sum)', **kws)
            # with dropout
            ax = subplot(grid[3:6, 0:3])
            plot_scatter(x=X_pca_drop, val=X_drop_countsum, ax=ax,
                         legend_enable=pi is None,
                         legend_ncol=len(labels),
                         title='[Dropped:%s]Original data (Count-sum)' %
                         dropout_percentage_text,
                         **kws)
            ax = subplot(grid[3:6, 3:6])
            plot_scatter(x=W_pca_drop, val=W_drop_countsum, ax=ax,
                         legend_enable=False,
                         title='[Dropped:%s]Reconstruction (Count-sum)' %
                         dropout_percentage_text,
                         **kws)
            row_start = 6
            # zero-inflated pi
            if pi is not None:
                ax = subplot(grid[6:9, 0:3])
                plot_scatter(x=X_pca, val=pi_sum, ax=ax, legend_enable=True,
                             legend_ncol=len(labels),
                             title='Zero-inflated probabilities', **kws)
                ax = subplot(grid[6:9, 3:6])
                plot_scatter(x=X_pca, val=pi_drop_sum, ax=ax,
                             legend_enable=False,
                             title='[Dropped:%s]Zero-inflated probabilities' %
                             dropout_percentage_text,
                             **kws)
                row_start += 3

            # plot the count-sum series
            def plot_count_sum_series(x, w, p, row_start, tit):
                # w must be (expected, stdev_total, stdev_explained);
                # anything else carries no statistics to plot.
                if len(w) != 3:
                    return
                expected, stdev_total, stdev_explained = w
                count_sum_observed = np.sum(x, axis=0)
                count_sum_expected = np.sum(expected, axis=0)
                count_sum_stdev_total = np.sum(stdev_total, axis=0)
                count_sum_stdev_explained = np.sum(stdev_explained, axis=0)
                if p is not None:
                    p_sum = np.mean(p, axis=0)
                for i, kws in enumerate(series_config):
                    ax = subplot(grid[row_start:row_start + 3,
                                      (i * 3):(i * 3 + 3)])
                    ax, handles, indices = plot_series_statistics(
                        count_sum_observed,
                        count_sum_expected,
                        explained_stdev=count_sum_stdev_explained,
                        total_stdev=count_sum_stdev_total,
                        fontsize=8,
                        ax=ax,
                        legend_enable=False,
                        title=tit if i == 0 else None,
                        despine=p is None,
                        return_handles=True,
                        return_indices=True,
                        **kws)
                    if p is not None:
                        _show_zero_inflated_pi(p_sum, ax, handles, indices)
                    plt.legend(handles=handles, loc='best', markerscale=4,
                               fontsize=8)

            # Three 3-row series blocks separated by one padding row each:
            # they start at row_start + 1, + 5 and + 9.
            series_rows = [
                (X, W_outputs, pi, "Count-sum X_original - W_original"),
                (X_drop, W_drop_outputs, pi_drop,
                 "[Dropped:%s]Count-sum X_drop - W_dropout" %
                 dropout_percentage_text),
                (X, W_drop_outputs, pi_drop,
                 "[Dropped:%s]Count-sum X_original - W_dropout" %
                 dropout_percentage_text),
            ]
            for k, (x_k, w_k, p_k, tit_k) in enumerate(series_rows):
                plot_count_sum_series(x=x_k, w=w_k, p=p_k,
                                      row_start=row_start + 1 + 4 * k,
                                      tit=tit_k)
        _save('countsum')

    # ====== plot series of samples ====== #
    if has_reconstruction and len(W_outputs) == 3:
        # NOTE: turn off pi here; the per-sample zero-inflation overlay is
        # disabled, and the `pi is not None` branches below are kept so it
        # can be re-enabled by removing this line.
        pi = None
        n_visual_samples = 8
        plot_figure(nrow=3 * n_visual_samples + 8, ncol=25)
        col_width = 5
        col_start = col_width * 2
        with plot_gridSpec(nrow=3 * n_visual_samples, ncol=4 * col_width,
                           wspace=5.0, hspace=1.0) as grid:
            curr_grid_index = 0
            for i in rand.permutation(len(X))[:n_visual_samples]:
                observed = X[i]
                # Must follow the W_outputs order
                # [W, W_stdev_total, W_stdev_explained].
                expected, stdev_total, stdev_explained = [
                    w[i] for w in W_outputs
                ]
                expected_drop, stdev_total_drop, stdev_explained_drop = [
                    w[i] for w in W_drop_outputs
                ]
                if pi is not None:
                    p_zi = pi[i]
                    p_zi_drop = pi_drop[i]
                # compare to W_original
                for j, kws in enumerate(series_config):
                    left = j * col_width
                    ax = subplot(grid[curr_grid_index:curr_grid_index + 3,
                                      left:left + col_width])
                    ax, handles, indices = plot_series_statistics(
                        observed,
                        expected,
                        explained_stdev=stdev_explained,
                        total_stdev=stdev_total,
                        fontsize=8,
                        ax=ax,
                        legend_enable=False,
                        despine=pi is None,
                        title=("'%s' X_original - W_original" % row_name[i])
                        if j == 0 else None,
                        return_handles=True,
                        return_indices=True,
                        **kws)
                    if pi is not None:
                        _show_zero_inflated_pi(p_zi, ax, handles, indices)
                    plt.legend(handles=handles, loc='best', markerscale=4,
                               fontsize=8)
                # compare to W_dropout
                for j, kws in enumerate(series_config):
                    left = col_start + j * col_width
                    ax = subplot(grid[curr_grid_index:curr_grid_index + 3,
                                      left:left + col_width])
                    ax, handles, indices = plot_series_statistics(
                        observed,
                        expected_drop,
                        explained_stdev=stdev_explained_drop,
                        total_stdev=stdev_total_drop,
                        fontsize=8,
                        ax=ax,
                        legend_enable=False,
                        despine=pi is None,
                        title=("[Dropped:%s]'%s' X_original - W_dropout" %
                               (dropout_percentage_text, row_name[i]))
                        if j == 0 else None,
                        return_handles=True,
                        return_indices=True,
                        **kws)
                    if pi is not None:
                        _show_zero_inflated_pi(p_zi_drop, ax, handles,
                                               indices)
                    plt.legend(handles=handles, loc='best', markerscale=4,
                               fontsize=8)
                curr_grid_index += 3
        _save('samples')

    # ====== special case for mnist ====== #
    if 'mnist' in ds_name and has_reconstruction:
        plot_figure(nrow=3, ncol=18)
        # choice(..., replace=False) raises if asked for more items than
        # exist, so cap the preview at the number of available rows.
        n_images = min(32, X.shape[0])
        ids = rand.choice(np.arange(X.shape[0], dtype='int32'),
                          size=n_images, replace=False)
        meta_data = [("Org", X[ids]),
                     ("Rec", W_outputs[0][ids]),
                     ("OrgDropout", X_drop[ids]),
                     ("RecDropout", W_drop_outputs[0][ids])]
        count = 1
        for name, data in meta_data:
            for i in range(n_images):
                plt.subplot(4, n_images, count)
                show_image(data[i].reshape(28, 28))
                if i == 0:
                    plt.ylabel(name, fontsize=8)
                count += 1
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        _save('image')