def plot_confusion_matrix(self, y_true='celltype', y_pred='icelltype'):
  r""" Draw the confusion matrix between two label OMICs of the dataset.

  Arguments:
    y_true : identifier of the reference labels, anything accepted by
      `OMIC.parse`.
    y_pred : identifier of the predicted labels, anything accepted by
      `OMIC.parse`.

  Returns:
    the value returned by `self.add_figure` for the new figure, which
    is registered under the name `cm_true:<y_true>_pred:<y_pred>`.
  """
  omic_true = OMIC.parse(y_true)
  omic_pred = OMIC.parse(y_pred)
  name = f"true:{omic_true.name}_pred:{omic_pred.name}"

  def _as_class_ids(values):
    # arrays with more than one axis are reduced to the index of the
    # largest entry along the last axis; 1-D arrays are used as-is
    if values.ndim > 1:
      return np.argmax(values, axis=-1)
    return values

  ids_true = _as_class_ids(self.dataset.get_omic(omic_true))
  ids_pred = _as_class_ids(self.dataset.get_omic(omic_pred))
  figure = plt.figure(figsize=(6, 6))
  # NOTE: inside the method body this name resolves to the module-level
  # plotting function, not to this method
  plot_confusion_matrix(y_true=ids_true,
                        y_pred=ids_pred,
                        labels=self.dataset.get_var_names(omic_true),
                        cmap="Blues",
                        ax=figure.gca(),
                        fontsize=12,
                        cbar=True,
                        title=name)
  return self.add_figure(name=f"cm_{name}", fig=figure)
def evaluate_feeder(feeder, title):
  r""" Run the predictor over every file of `feeder` and save one
  digit confusion matrix per gender.

  Arguments:
    feeder : object providing `set_batch(batch_mode='file')`; each item
      it yields is indexed as `(name, idx, *data)`.
    title : used as the progress-bar name, as part of each subplot
      title and as the name of the saved PDF in `FIG_PATH`.
  """
  digits_true = []
  genders_true = []
  batch_preds = []
  progress = Progbar(feeder.set_batch(batch_mode='file'),
                     name=title,
                     print_report=True,
                     print_summary=False,
                     count_func=lambda x: x[-1].shape[0])
  for outputs in progress:
    name = str(outputs[0])
    idx = int(outputs[1])
    data = outputs[2:]
    assert idx == 0
    digits_true.append(f_digits(name))
    genders_true.append(f_genders(name))
    batch_preds.append(f_pred(*data))
  # ====== post processing ====== #
  digits_true = np.array(digits_true, dtype='int32')
  genders_true = np.array(genders_true, dtype='int32')
  proba = np.concatenate(batch_preds, axis=0)
  digits_pred = np.argmax(proba, axis=-1).astype('int32')
  # ====== plotting for each gender ====== #
  plot_figure(nrow=6, ncol=25)
  digit_ids = range(len(digits))
  for gen in range(len(genders)):
    # positions of all the files belonging to this gender
    chosen = [i for i, g in enumerate(genders_true) if g == gen]
    if not chosen:
      continue
    cm = confusion_matrix([digits_true[i] for i in chosen],
                          [digits_pred[i] for i in chosen],
                          labels=digit_ids)
    plot_confusion_matrix(cm,
                          labels=digits,
                          fontsize=8,
                          ax=(1, 4, gen + 1),
                          title='[%s]%s' % (genders[gen], title))
  plot_save(os.path.join(FIG_PATH, '%s.pdf' % title))
def evaluate(vae,
             ds,
             expdir: str,
             title: str,
             batch_size: int = 32,
             seed: int = 1):
  r""" Evaluate a semi-supervised VAE and export arrays, a text report
  and diagnostic plots to `expdir`.

  Outputs written:
    * `{expdir}/arrays.npz`  : train/test features and labels
    * `{expdir}/results.txt` : log-likelihoods and classification
      reports of a `LogisticRegression` fitted on several fractions of
      the training features
    * `{expdir}/analysis.pdf`: every figure created by this function

  Arguments:
    vae : the model; must provide `encode(x, training=False)`,
      `__call__(x, training=False)` returning `px, (qz, qy)`, and `step`.
    ds : the dataset; must provide `name`, `labels` and `create_dataset`.
    expdir : output directory, created if missing.
    title : used for progress bars and plot titles.
    batch_size : batch size of both the train and the test partition.
    seed : seed of the `RandomState` used for shuffling and for picking
      the batches shown in the reconstruction figure.

  Raises:
    ValueError : if `ds.name` is not one of the supported datasets.
  """
  from odin.bay.vi import Correlation
  categorical_names = ('mnist', 'fashionmnist', 'celeba')
  shape_names = {
      'shapes3d': ['cube', 'cylinder', 'sphere', 'round'],
      'dsprites': ['square', 'ellipse', 'heart'],
  }
  # fail fast: an unknown dataset previously surfaced as a NameError
  # only after the whole dataset had been processed
  if ds.name not in categorical_names and ds.name not in shape_names:
    raise ValueError(f'No support for dataset with name: "{ds.name}"')
  rand = np.random.RandomState(seed=seed)
  os.makedirs(expdir, exist_ok=True)
  tanh = ds.name.lower() == 'celeba'
  normalize = 'tanh' if tanh else 'probs'
  ## data for training semi-supervised
  # careful don't allow any data leakage!
  train = ds.create_dataset('train',
                            batch_size=batch_size,
                            label_percent=True,
                            shuffle=False,
                            normalize=normalize)
  data = [(vae.encode(x, training=False), y)
          for x, y in tqdm(train, desc=title)]
  x_semi_train = tf.concat(
      [tf.concat([i.mean(), _ymean(j)], axis=1) for (i, j), _ in data],
      axis=0).numpy()
  y_semi_train = tf.concat([i for _, i in data], axis=0).numpy()
  # shuffle
  ids = rand.permutation(x_semi_train.shape[0])
  x_semi_train = x_semi_train[ids]
  y_semi_train = y_semi_train[ids]
  ## data for testing
  test = ds.create_dataset('test',
                           batch_size=batch_size,
                           label_percent=True,
                           shuffle=False,
                           normalize=normalize)
  prog = tqdm(test, desc=title)
  llk_x = []
  llk_y = []
  z = []
  y_true = []
  y_pred = []
  x_org, x_rec = [], []
  for x, y in prog:
    px, (qz, qy) = vae(x, training=False)
    y_true.append(y)
    y_pred.append(_ymean(qy))
    z.append(qz.mean())
    llk_x.append(px.log_prob(x))
    llk_y.append(qy.log_prob(y))
    # always keep the first two batches, then a random ~0.5% of the rest
    if rand.uniform() < 0.005 or len(x_org) < 2:
      x_org.append(x)
      x_rec.append(px.mean())
  ## llk
  llk_x = tf.reduce_mean(tf.concat(llk_x, axis=0)).numpy()
  llk_y = tf.reduce_mean(tf.concat(llk_y, axis=0)).numpy()
  ## the latents
  z = tf.concat(z, axis=0).numpy()
  y_true = tf.concat(y_true, axis=0).numpy()
  y_pred = tf.concat(y_pred, axis=0).numpy()
  x_semi_test = tf.concat([z, y_pred], axis=-1).numpy()
  # shuffle
  ids = rand.permutation(z.shape[0])
  z = z[ids]
  y_true = y_true[ids]
  y_pred = y_pred[ids]
  x_semi_test = x_semi_test[ids]
  ## saving reconstruction images
  x_org = tf.concat(x_org, axis=0).numpy()
  x_rec = tf.concat(x_rec, axis=0).numpy()
  ids = rand.permutation(x_org.shape[0])
  x_org = x_org[ids][:36]
  x_rec = x_rec[ids][:36]
  # fewer than 36 images may have been collected (small batch size or
  # small test set), so the per-image min/max must use the real count
  n_images = x_rec.shape[0]
  vmin = x_rec.reshape((n_images, -1)).min(axis=1).reshape(
      (n_images, 1, 1, 1))
  vmax = x_rec.reshape((n_images, -1)).max(axis=1).reshape(
      (n_images, 1, 1, 1))
  if tanh:
    x_org = (x_org + 1.) / 2.
  x_rec = (x_rec - vmin) / (vmax - vmin)
  if x_org.shape[-1] == 1:  # grayscale image
    x_org = np.squeeze(x_org, -1)
    x_rec = np.squeeze(x_rec, -1)
  else:  # color image
    x_org = np.transpose(x_org, (0, 3, 1, 2))
    x_rec = np.transpose(x_rec, (0, 3, 1, 2))
  plt.figure(figsize=(15, 8))
  ax = plt.subplot(1, 2, 1)
  vs.plot_images(x_org, grids=(6, 6), ax=ax, title='Original')
  ax = plt.subplot(1, 2, 2)
  vs.plot_images(x_rec, grids=(6, 6), ax=ax, title='Reconstructed')
  plt.tight_layout()
  ## prepare the labels
  if ds.name in categorical_names:
    true = np.argmax(y_true, axis=-1)
    pred = np.argmax(y_pred, axis=-1)
    y_semi_train = np.argmax(y_semi_train, axis=-1)
    y_semi_test = true
    labels_name = ds.labels
  else:  # shapes3d dsprites
    true = y_true[:, 2].astype(np.int32)
    pred = y_pred[:, 2].astype(np.int32)
    y_semi_train = y_semi_train[:, 2].astype(np.int32)
    y_semi_test = true
    labels_name = shape_names[ds.name]
  plt.figure(figsize=(8, 8))
  vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                           labels=labels_name,
                           cbar=True,
                           fontsize=10,
                           title=title)
  labels = np.array([labels_name[i] for i in true])
  labels_pred = np.array([labels_name[i] for i in pred])
  ## save arrays for later inspection
  np.savez_compressed(f'{expdir}/arrays',
                      x_train=x_semi_train,
                      y_train=y_semi_train,
                      x_test=x_semi_test,
                      y_test=y_semi_test,
                      zdim=z.shape[1],
                      labels=labels_name)
  print(f'Export arrays to "{expdir}/arrays.npz"')
  ## semi-supervised
  with open(f'{expdir}/results.txt', 'w') as f:
    print(f'Export results to "{expdir}/results.txt"')
    f.write(f'Steps: {vae.step.numpy()}\n')
    f.write(f'llk_x: {llk_x}\n')
    f.write(f'llk_y: {llk_y}\n')
    # fraction of the training features used to fit the classifier
    for p in [0.004, 0.06, 0.2, 0.99]:
      x_train, x_test, y_train, y_test = train_test_split(
          x_semi_train,
          y_semi_train,
          train_size=int(np.round(p * x_semi_train.shape[0])),
          random_state=1,
      )
      m = LogisticRegression(max_iter=3000, random_state=1)
      m.fit(x_train, y_train)
      # write the report
      f.write(f'{m.__class__.__name__} Number of labels: '
              f'{p} {x_train.shape[0]}/{x_test.shape[0]}')
      f.write('\nValidation:\n')
      f.write(classification_report(y_true=y_test, y_pred=m.predict(x_test)))
      f.write('\nTest:\n')
      f.write(
          classification_report(y_true=y_semi_test,
                                y_pred=m.predict(x_semi_test)))
      f.write('------------\n')
  ## scatter plot
  n_points = 4000
  # tsne plot
  tsne = DimReduce.TSNE(z[:n_points], n_components=2)
  kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.6)
  plt.figure(figsize=(8, 8))
  vs.plot_scatter(color=labels[:n_points], title=f'[True-tSNE]{title}', **kw)
  plt.figure(figsize=(8, 8))
  vs.plot_scatter(color=labels_pred[:n_points],
                  title=f'[Pred-tSNE]{title}',
                  **kw)
  # pca plot
  pca = DimReduce.PCA(z, n_components=2)
  kw = dict(x=pca[:, 0], y=pca[:, 1], grid=False, size=12.0, alpha=0.6)
  plt.figure(figsize=(8, 8))
  vs.plot_scatter(color=labels, title=f'[True-PCA]{title}', **kw)
  plt.figure(figsize=(8, 8))
  vs.plot_scatter(color=labels_pred, title=f'[Pred-PCA]{title}', **kw)
  ## factors plot
  corr = (Correlation.Spearman(z, y_true) +
          Correlation.Pearson(z, y_true)) / 2.
  # for every column of `y_true`, the two latent units with the largest
  # absolute correlation
  best_z = np.argsort(np.abs(corr), axis=0)[-2:]
  style = dict(size=15.0, alpha=0.6, grid=False)
  for fi, (z1, z2) in enumerate(best_z.T):
    plt.figure(figsize=(8, 4))
    ax = plt.subplot(1, 2, 1)
    vs.plot_scatter(x=z[:n_points, z1],
                    y=z[:n_points, z2],
                    val=y_true[:n_points, fi],
                    ax=ax,
                    title=ds.labels[fi],
                    **style)
    ax = plt.subplot(1, 2, 2)
    vs.plot_scatter(x=z[:n_points, z1],
                    y=z[:n_points, z2],
                    val=y_pred[:n_points, fi],
                    ax=ax,
                    title=ds.labels[fi],
                    **style)
    plt.tight_layout()
  ## save all plot
  vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)
def evaluate(vae: VariationalAutoencoder,
             ds: ImageDataset,
             expdir: str,
             title: str,
             batch_size: int = 64,
             take_count: int = -1,
             n_images: int = 36,
             seed: int = 1):
  r""" Evaluate a (possibly semi-supervised) VAE on an image dataset and
  export diagnostic plots and arrays to `expdir`.

  Files written:
    * `{expdir}/latent{idx}_factor.pdf` : latent/factor pair plots, one
      file per latent layer
    * `{expdir}/latents_traverse.pdf`   : latent traversal images
    * `{expdir}/arrays`                 : pickled latents, labels and
      predictions
    * `{expdir}/analysis.pdf`           : prior samples, reconstructions,
      confusion matrix (semi-supervised only) and t-SNE scatter plots
    * `{expdir}/latents.pdf`            : latent statistics and KL plots

  Arguments:
    vae : the model to evaluate.
    ds : the dataset; provides `create_dataset`, `labels`, `label_type`
      and `name`.
    expdir : output directory, created if missing.
    title : used for plot titles.
    batch_size : batch size of the train and test partitions.
    take_count : forwarded to `_call`.
    n_images : number of images forwarded to `_call` and sampled from
      the prior; plotted on a `sqrt(n_images)`-sided grid.
    seed : seed for the `RandomState`, the prior sampling and the PCA.

  Raises:
    NotImplementedError : if `ds.label_type` is neither `'categorical'`
      nor `'factor'`.
  """
  n_rows = int(np.sqrt(n_images))
  is_semi = vae.is_semi_supervised()
  # not used below; kept in case the call matters - TODO confirm
  is_hierarchical = vae.is_hierarchical()
  ds_kw = dict(batch_size=batch_size, label_percent=1.0, shuffle=False)
  ## prepare
  rand = np.random.RandomState(seed=seed)
  os.makedirs(expdir, exist_ok=True)
  ## data for training semi-supervised
  train = ds.create_dataset('train', **ds_kw)
  (llkx_train, llky_train, x_org_train, x_rec_train, y_true_train,
   y_pred_train, z_train, pz_train) = _call(vae,
                                            ds=train,
                                            rand=rand,
                                            take_count=take_count,
                                            n_images=n_images,
                                            verbose=True)
  ## data for testing
  test = ds.create_dataset('test', **ds_kw)
  (llkx_test, llky_test, x_org_test, x_rec_test, y_true_test, y_pred_test,
   z_test, pz_test) = _call(vae,
                            ds=test,
                            rand=rand,
                            take_count=take_count,
                            n_images=n_images,
                            verbose=True)
  # === 0. plotting latent-factor pairs
  for idx, z in enumerate(z_test):
    z = z.mean()
    f = y_true_test
    corr_mat = Correlation.Spearman(z, f)  # [n_latents, n_factors]
    plot_latents_pairs(z, f, corr_mat, ds.labels)
    vs.plot_save(f'{expdir}/latent{idx}_factor.pdf', dpi=100, verbose=True)
  # === 0. latent traverse plot
  x_travs = x_org_test
  if x_travs.ndim == 3:  # grayscale image
    x_travs = np.expand_dims(x_travs, -1)
  else:  # color image
    x_travs = np.transpose(x_travs, (0, 2, 3, 1))
  x_travs = x_travs[rand.permutation(x_travs.shape[0])]
  n_visual_samples = 5
  n_traverse_points = 21
  n_top_latents = 10
  # the traverse range does not depend on the sample, compute it once
  z0_mean = z_test[0].mean()
  # NOTE(review): `min_val` is the NEGATED minimum of the posterior
  # means, i.e. positive whenever the minimum is negative; this looks
  # unintended but depends on `sample_traverse` - confirm before changing
  traverse_min = -np.min(z0_mean)
  traverse_max = np.max(z0_mean)
  plt.figure(figsize=(8, 3 * n_visual_samples))
  for i in range(n_visual_samples):
    images = vae.sample_traverse(x_travs[i:i + 1],
                                 min_val=traverse_min,
                                 max_val=traverse_max,
                                 n_best_latents=n_top_latents,
                                 n_traverse_points=n_traverse_points,
                                 mode='linear')
    images = as_tuple(images)[0]
    images = _prepare_images(images.mean().numpy(), normalize=True)
    vs.plot_images(images,
                   grids=(n_top_latents, n_traverse_points),
                   ax=(n_visual_samples, 1, i + 1))
    if i == 0:
      plt.title('Latents traverse')
  plt.tight_layout()
  vs.plot_save(f'{expdir}/latents_traverse.pdf', dpi=180, verbose=True)
  # === 0. prior sampling plot
  images = as_tuple(vae.sample_observation(n=n_images, seed=seed))[0]
  images = _prepare_images(images.mean().numpy(), normalize=True)
  plt.figure(figsize=(5, 5))
  vs.plot_images(images, grids=(n_rows, n_rows), title='Sampled')
  # === 1. reconstruction plot
  plt.figure(figsize=(15, 15))
  vs.plot_images(x_org_train,
                 grids=(n_rows, n_rows),
                 ax=(2, 2, 1),
                 title='[Train]Original')
  vs.plot_images(x_rec_train,
                 grids=(n_rows, n_rows),
                 ax=(2, 2, 2),
                 title='[Train]Reconstructed')
  vs.plot_images(x_org_test,
                 grids=(n_rows, n_rows),
                 ax=(2, 2, 3),
                 title='[Test]Original')
  vs.plot_images(x_rec_test,
                 grids=(n_rows, n_rows),
                 ax=(2, 2, 4),
                 title='[Test]Reconstructed')
  plt.tight_layout()
  ## prepare the labels
  label_type = ds.label_type
  if label_type == 'categorical':
    labels_name = ds.labels
    true = np.argmax(y_true_test, axis=-1)
    labels_true = np.array([labels_name[i] for i in true])
    labels_pred = labels_true
    if is_semi:
      pred = np.argmax(y_pred_test.mean().numpy(), axis=-1)
      labels_pred = np.array([labels_name[i] for i in pred])
  elif label_type == 'factor':  # dsprites, shapes3d
    labels_name = ['cube', 'cylinder', 'sphere', 'round'] \
      if 'shapes3d' in ds.name else ['square', 'ellipse', 'heart']
    true = y_true_test[:, 2].astype('int32')
    labels_true = np.array([labels_name[i] for i in true])
    labels_pred = labels_true
    if is_semi:
      pred = get_ymean(y_pred_test)[:, 2].astype('int32')
      labels_pred = np.array([labels_name[i] for i in pred])
  else:  # CelebA
    raise NotImplementedError
  ## confusion matrix
  if is_semi:
    plt.figure(figsize=(8, 8))
    acc = accuracy_score(y_true=true, y_pred=pred)
    vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                             labels=labels_name,
                             cbar=True,
                             fontsize=10,
                             title=f'{title} Acc:{acc:.2f}')
  ## save arrays for later inspections
  with open(f'{expdir}/arrays', 'wb') as f:
    pickle.dump(
        dict(z_train=z_train,
             y_pred_train=y_pred_train,
             y_true_train=y_true_train,
             z_test=z_test,
             y_pred_test=y_pred_test,
             y_true_test=y_true_test,
             labels=labels_name,
             ds=ds.name,
             label_type=label_type), f)
  print(f'Exported arrays to "{expdir}/arrays"')
  ## semi-supervised
  z_mean_train = np.concatenate(
      [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_train], -1)
  z_mean_test = np.concatenate(
      [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_test], -1)
  # === 2. scatter points latents plot
  n_points = 5000
  ids = rand.permutation(len(labels_true))[:n_points]
  Y_true = labels_true[ids]
  Y_pred = labels_pred[ids]
  # tsne plot
  # with a single latent layer only the 'all' plot is drawn, because the
  # per-layer plot would show the same values
  n_latents = 0 if len(z_train) == 1 else len(z_train)
  for name, X in zip(
      ['all'] + [f'latents{i}' for i in range(n_latents)],
      [z_mean_test[ids]] +
      [z_test[i].mean().numpy()[ids] for i in range(n_latents)]):
    print(f'Plot scatter points for {name}')
    X = X.reshape(X.shape[0], -1)  # flatten to 2D
    X = Pipeline([('zscore', StandardScaler()),
                  ('pca', PCA(min(X.shape[1], 512),
                              random_state=seed))]).fit_transform(X)
    tsne = DimReduce.TSNE(X, n_components=2, framework='sklearn')
    kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.8)
    plt.figure(figsize=(12, 6))
    vs.plot_scatter(color=Y_true,
                    title=f'[True]{title}-{name}',
                    ax=(1, 2, 1),
                    **kw)
    vs.plot_scatter(color=Y_pred,
                    title=f'[Pred]{title}-{name}',
                    ax=(1, 2, 2),
                    **kw)
  ## save all plot
  vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)
  # === 3. show the latents statistics
  n_latents = len(z_train)
  styles = dict(grid=False,
                ticks_off=False,
                alpha=0.6,
                xlabel='mean',
                ylabel='stddev')

  # scatter between latents and labels (assume categorical distribution)
  def _show_latents_labels(Z, Y, title):
    plt.figure(figsize=(5 * n_latents, 5), dpi=150)
    # the dominant column of `Y` for each sample; same for every latent
    y = np.argmax(Y, axis=-1)
    classes = np.unique(y)
    for idx, z in enumerate(Z):
      if len(z.batch_shape) == 0:
        # no batch axis: repeat the single mean for every sample and use
        # sampled deviations from it in place of the stddev
        mean = np.repeat(np.expand_dims(z.mean(), 0), Y.shape[0], 0)
        stddev = z.sample(Y.shape[0]) - mean
      else:
        mean = flatten(z.mean())
        stddev = flatten(z.stddev())
      data = [[], [], []]
      for y_i in classes:
        mask = (y == y_i)
        data[0].append(np.mean(mean[mask], 0))
        data[1].append(np.mean(stddev[mask], 0))
        # `y_i` is a column index of `Y`, so it must be named through
        # `ds.labels` (one name per column); indexing the per-sample
        # `labels_true` array returned the label of an unrelated sample
        data[2].append([ds.labels[y_i]] * mean.shape[1])
      vs.plot_scatter(
          x=np.concatenate(data[0], 0),
          y=np.concatenate(data[1], 0),
          color=np.concatenate(data[2], 0),
          ax=(1, n_latents, idx + 1),
          size=15 if mean.shape[1] < 128 else 8,
          title=f'[Test-{title}]#{idx} - {mean.shape[1]} (units)',
          **styles)
    plt.tight_layout()

  # simple scatter mean-stddev each latents
  def _show_latents(Z, title):
    plt.figure(figsize=(3.5 * n_latents, 3.5), dpi=150)
    for idx, z in enumerate(Z):
      mean = flatten(z.mean())
      stddev = flatten(z.stddev())
      if mean.ndim == 2:
        mean = np.mean(mean, 0)
        stddev = np.mean(stddev, 0)
      vs.plot_scatter(x=mean,
                      y=stddev,
                      ax=(1, n_latents, idx + 1),
                      size=15 if len(mean) < 128 else 8,
                      title=f'[Test-{title}]#{idx} - {len(mean)} (units)',
                      **styles)

  _show_latents_labels(z_test, y_true_test, 'post')
  _show_latents_labels(pz_test, y_true_test, 'prior')
  _show_latents(z_test, 'post')
  _show_latents(pz_test, 'prior')
  # KL statistics
  vs.plot_figure()
  for idx, (qz, pz) in enumerate(zip(z_test, pz_test)):
    kl = []
    qz = Normal(loc=qz.mean(), scale=qz.stddev(), name=f'posterior{idx}')
    pz = Normal(loc=pz.mean(), scale=pz.stddev(), name=f'prior{idx}')
    for s, e in minibatch(batch_size=8, n=100):
      z = qz.sample(e - s)
      # don't do this in GPU, it explodes!
      kl.append((qz.log_prob(z) - pz.log_prob(z)).numpy())
    kl = np.concatenate(kl, 0)  # (mcmc, batch, event)
    # per sample
    kl_samples = np.sum(kl, as_tuple(list(range(2, kl.ndim))))
    kl_samples = logsumexp(kl_samples, 0)
    plt.subplot(n_latents, 2, idx * 2 + 1)
    sns.histplot(kl_samples, bins=50)
    plt.title(f'Z#{idx} KL per sample (nats)')
    # per latent
    kl_latents = np.mean(flatten(logsumexp(kl, 0)), 0)
    plt.subplot(n_latents, 2, idx * 2 + 2)
    plt.plot(np.sort(kl_latents))
    plt.title(f'Z#{idx} KL per dim (nats)')
  plt.tight_layout()
  vs.plot_save(f'{expdir}/latents.pdf', dpi=180, verbose=True)
def plot_evaluate_classifier(y_pred,
                             y_true,
                             labels,
                             title,
                             show_plot=True,
                             return_figure=False):
  r""" Compute per-class and aggregated F1 scores of a classifier and
  optionally plot one binary confusion matrix per class.

  Arguments:
    y_pred : array of predictions; a 1-D array is converted with
      `one_hot`, otherwise one column per class is expected.
    y_true : array of ground truth, same conventions as `y_pred`.
    labels : list of class names, one per class; non-string names are
      converted with `str` where a string is needed.
    title : first line of the figure title.
    show_plot : if `True`, create a figure with one confusion matrix
      per class.
    return_figure : if `True` (and `show_plot` is `True`), also return
      the created figure.

  Returns:
    Return a dictionary of scores
    {
        F1micro=f1_micro * 100,
        F1macro=f1_macro * 100,
        F1weight=f1_weight * 100,
        F1_[classname]=...
    }
    or the tuple `(scores, figure)` when both `show_plot` and
    `return_figure` are `True`.
  """
  from matplotlib import pyplot as plt
  fontsize = 12
  num_classes = len(labels)
  # at most 5 confusion matrices on each row of the grid
  nrow = int(np.ceil(num_classes / 5))
  ncol = int(np.ceil(num_classes / nrow))
  if y_pred.ndim == 1:
    y_pred = one_hot(y_pred, nb_classes=num_classes)
  if y_true.ndim == 1:
    y_true = one_hot(y_true, nb_classes=num_classes)
  if show_plot:
    fig = plot_figure(nrow=4 * nrow + 3, ncol=4 * ncol)

  f1_classes = []
  for i, (name, pred, true) in enumerate(zip(labels, y_pred.T, y_true.T)):
    f1_classes.append(f1_score(true, pred))
    if show_plot:
      # `labels=[0, 1]` forces a 2x2 matrix even when the class never
      # occurs in `true` and `pred`; otherwise the matrix would be 1x1
      # and not match the two tick labels
      plot_confusion_matrix(confusion_matrix(y_true=true,
                                             y_pred=pred,
                                             labels=[0, 1]),
                            labels=[0, 1],
                            fontsize=fontsize,
                            ax=(nrow, ncol, i + 1),
                            title=f'{name}\n')

  f1_micro = f1_score(y_true=y_true.ravel(), y_pred=y_pred.ravel())
  f1_macro = np.mean(f1_classes)
  f1_weight = f1_score(y_true=y_true, y_pred=y_pred, average='weighted')

  if show_plot:
    plt.suptitle('%s\nF1-micro:%.2f  F1-macro:%.2f  F1-weight:%.2f' %
                 (title, f1_micro * 100, f1_macro * 100, f1_weight * 100),
                 fontsize=fontsize + 6)
    plt.tight_layout(rect=[0, 0.04, 1, 0.96])

  results = dict(
      F1micro=f1_micro * 100,
      F1macro=f1_macro * 100,
      F1weight=f1_weight * 100,
  )
  for name, f1 in zip(labels, f1_classes):
    results[f'F1_{name}'] = f1 * 100

  if show_plot and return_figure:
    return results, fig
  return results
def plot_evaluate_reconstruction(X,
                                 W,
                                 y_raw,
                                 y_prob,
                                 X_row,
                                 X_col,
                                 labels,
                                 title,
                                 pi=None,
                                 enable_image=True,
                                 enable_tsne=True,
                                 enable_sparsity=True):
  r""" Plot a comparison between the original data `X` and its
  reconstruction `W`.

  Up to three figures are created: sample images, PCA/t-SNE scatter
  plots of cells and genes, and a zero/non-zero confusion matrix.

  Arguments:
    X : the original data matrix.
    W : the reconstruction with the same shape as `X`; alternatively a
      tuple/list of 1 or 3 elements, or a 3-D array, whose first
      element is the reconstruction.
    y_raw : labels with the same shape as `y_prob`.
    y_prob : label probabilities; their argmax colours the scatter plots.
    X_row : row names used as image titles, or `None`.
    X_col : column names (currently not used by any plot).
    labels : label names forwarded to `fast_scatter`.
    title : printed in the header line.
    pi : zero-inflated rate (imputation rate), or dropout probabilities
    enable_image : plot original/reconstructed images.
    enable_tsne : plot the PCA and t-SNE comparison.
    enable_sparsity : plot the sparsity confusion matrix.

  Raises:
    RuntimeError : if `W` is a tuple/list whose length is not 1 or 3.
  """
  print("Evaluate: [Reconstruction]", ctext(title, 'lightyellow'))
  from matplotlib import pyplot as plt
  fontsize = 12

  # only the reconstruction itself is plotted; when three elements are
  # given the remaining two are ignored
  if isinstance(W, (tuple, list)):
    if len(W) == 1:
      W = W[0]
    elif len(W) == 3:
      W = W[0]
    else:
      raise RuntimeError(
          "`W` must contain 1 or 3 elements, but given %d" % len(W))
  elif W.ndim == 3:
    W = W[0]
  assert (X.shape[0] == W.shape[0] == y_raw.shape[0]) and \
  (X.shape == W.shape) and \
  (y_raw.shape == y_prob.shape)

  X, X_row, W, y_raw, y_prob, pi = downsample_data(X, X_row, W, y_raw, y_prob,
                                                   pi)
  y_argmax = np.argmax(y_prob, axis=-1)
  # ====== prepare count-sum ====== #
  X_log = K.log_norm(X, axis=1)
  W_log = K.log_norm(W, axis=1)

  X_gene_countsum = np.sum(X, axis=0)
  X_cell_countsum = np.sum(X, axis=1)

  W_gene_countsum = np.sum(W, axis=0)
  W_cell_countsum = np.sum(W, axis=1)

  if pi is not None:
    pi_cell_countsum = np.mean(pi, axis=1)
    pi_gene_countsum = np.mean(pi, axis=0)
  # ====== Compare image ====== #
  if enable_image:
    _RAND = np.random.RandomState(seed=87654321)
    n_img = 12
    n_img_row = min(3, X.shape[0] // n_img)
    # original + reconstructed (+ zero-inflated rate when available)
    n_row_per_row = 2 if pi is None else 3
    plot_figure(nrow=n_img_row * 4, ncol=18)
    count = 1
    all_ids = _RAND.choice(np.arange(0, X.shape[0]),
                           size=n_img * n_img_row,
                           replace=False)

    for img_row in range(n_img_row):
      ids = all_ids[img_row * n_img:(img_row + 1) * n_img]

      # plot original images
      for col, i in enumerate(ids):
        ax = plt.subplot(n_row_per_row * n_img_row, n_img, count)
        show_image(X[i])
        if col == 0:
          plt.ylabel("Original")
        if X_row is not None:
          ax.set_title(X_row[i], fontsize=8)
        count += 1
      # plot reconstructed images
      for col, i in enumerate(ids):
        plt.subplot(n_row_per_row * n_img_row, n_img, count)
        show_image(W[i])
        if col == 0:
          plt.ylabel("Reconstructed")
        count += 1
      # plot zero-inflated rate
      if pi is not None:
        for col, i in enumerate(ids):
          plt.subplot(n_row_per_row * n_img_row, n_img, count)
          show_image(pi[i], is_probability=True)
          if col == 0:
            plt.ylabel("$p_{zero-inflated}$")
          count += 1
    plt.tight_layout()
  # ====== compare the T-SNE plot ====== #
  if enable_tsne:

    def pca_and_tsne(x, w):
      # PCA first, then t-SNE on the PCA output; only the first two PCA
      # components are returned for plotting
      x_pca, w_pca = fast_pca(x, w, n_components=512, random_state=87654321)
      x_tsne = fast_tsne(x_pca, n_components=2, random_state=87654321)
      w_tsne = fast_tsne(w_pca, n_components=2, random_state=87654321)
      return x_pca[:, :2], x_tsne, w_pca[:, :2], w_tsne

    def fast_log(x):
      return K.log_norm(x, axis=0)

    # transforming the data
    (X_cell_pca, X_cell_tsne, W_cell_pca,
     W_cell_tsne) = pca_and_tsne(X_log, W_log)
    (X_gene_pca, X_gene_tsne, W_gene_pca,
     W_gene_tsne) = pca_and_tsne(X_log.T, W_log.T)
    # prepare the figure
    n_plot = 3 + 2  # 3 for cells, 2 for genes
    if pi is not None:
      n_plot += 2  # 2 more row for pi
    plot_figure(nrow=n_plot * 5, ncol=18)
    # Cells
    fast_scatter(x=X_cell_pca,
                 y=y_argmax,
                 labels=labels,
                 title="[PCA]Original Cell Data",
                 ax=(n_plot, 2, 1),
                 enable_legend=False)
    fast_scatter(x=W_cell_pca,
                 y=y_argmax,
                 labels=labels,
                 title="[PCA]Reconstructed Cell Data",
                 ax=(n_plot, 2, 2),
                 enable_legend=False)

    fast_scatter(x=X_cell_tsne,
                 y=y_argmax,
                 labels=labels,
                 title="[t-SNE]Original Cell Data",
                 ax=(n_plot, 2, 3),
                 enable_legend=True)
    fast_scatter(x=W_cell_tsne,
                 y=y_argmax,
                 labels=labels,
                 title="[t-SNE]Reconstructed Cell Data",
                 ax=(n_plot, 2, 4),
                 enable_legend=False)

    plot_scatter(x=X_cell_tsne,
                 val=fast_log(X_cell_countsum),
                 title="[t-SNE]Original Cell Data + Original Cell Countsum",
                 ax=(n_plot, 2, 5),
                 colorbar=True)
    plot_scatter(
        x=X_cell_tsne,
        val=fast_log(W_cell_countsum),
        title="[t-SNE]Original Cell Data + Reconstructed Cell Countsum",
        ax=(n_plot, 2, 6),
        colorbar=True)
    # Genes
    plot_scatter(x=X_gene_pca,
                 val=fast_log(X_gene_countsum),
                 title="[PCA]Original Gene Data + Original Gene Countsum",
                 ax=(n_plot, 2, 7),
                 colorbar=True)
    plot_scatter(
        x=W_gene_pca,
        val=fast_log(X_gene_countsum),
        title="[PCA]Reconstructed Gene Data + Original Gene Countsum",
        ax=(n_plot, 2, 8),
        colorbar=True)

    plot_scatter(
        x=X_gene_tsne,
        val=fast_log(X_gene_countsum),
        title="[t-SNE]Original Gene Data + Original Gene Countsum",
        ax=(n_plot, 2, 9),
        colorbar=True)
    plot_scatter(
        x=X_gene_tsne,
        val=fast_log(W_gene_countsum),
        title="[t-SNE]Original Gene Data + Reconstructed Gene Countsum",
        ax=(n_plot, 2, 10),
        colorbar=True)
    # zero-inflation rate
    if pi is not None:
      plot_scatter(
          x=X_cell_tsne,
          val=X_cell_countsum,
          title="[t-SNE]Original Cell Data + Original Cell Countsum",
          ax=(n_plot, 2, 11),
          colorbar=True)
      plot_scatter(x=X_cell_tsne,
                   val=pi_cell_countsum,
                   title="[t-SNE]Original Cell Data + Zero-inflated rate",
                   ax=(n_plot, 2, 12),
                   colorbar=True)

      plot_scatter(
          x=X_gene_tsne,
          val=X_gene_countsum,
          title="[t-SNE]Original Gene Data + Original Gene Countsum",
          ax=(n_plot, 2, 13),
          colorbar=True)
      plot_scatter(x=X_gene_tsne,
                   val=pi_gene_countsum,
                   title="[t-SNE]Original Gene Data + Zero-inflated rate",
                   ax=(n_plot, 2, 14),
                   colorbar=True)
    plt.tight_layout()
  # ******************** sparsity ******************** #
  if enable_sparsity:
    plot_figure(nrow=8, ncol=8)
    # ====== sparsity ====== #
    # 1 where an entry is exactly zero, 0 otherwise
    z = (X.ravel() == 0).astype('int32')
    z_res = (W.ravel() == 0).astype('int32')
    plot_confusion_matrix(ax=None,
                          cm=confusion_matrix(y_true=z,
                                              y_pred=z_res,
                                              labels=(0, 1)),
                          labels=('Not Zero', 'Zero'),
                          colorbar=True,
                          fontsize=fontsize + 4,
                          title="Sparsity")
def evaluate(y_true,
             y_pred_proba=None,
             y_pred_log_proba=None,
             labels=None,
             title='',
             path=None,
             xlims=None,
             ylims=None,
             print_log=True):
  r""" Score the predictions of a classifier; optionally print a report
  and save the diagnostic figures.

  Arguments:
    y_true : ground truth as a tuple, list or array; a 2-D input is
      treated as one-hot and converted with argmax.
    y_pred_proba : predicted probabilities, or `None`.
    y_pred_log_proba : predicted log-probabilities, or `None`; takes
      precedence over `y_pred_proba` for the log-likelihood ratios.
    labels : class names; defaults to `'0', '1', ...`.
    title : printed in the header of the report.
    path : if not `None`, the confusion matrix, C_norm, ROC and DET
      figures are created and passed to `plot_save(path)`.
    xlims, ylims : axis limits forwarded to the DET curves.
    print_log : if `True`, print log loss, accuracy, C_norm, EER,
      minDCF and the confusion matrix.

  Raises:
    ValueError : if both `y_pred_proba` and `y_pred_log_proba` are `None`.
  """
  from odin.backend import to_llr
  from odin.backend.metrics import (det_curve, compute_EER, roc_curve,
                                    compute_Cnorm, compute_minDCF)

  def format_score(s):
    return ctext('%.4f' % s if is_number(s) else s, 'yellow')

  # ====== check y_pred ====== #
  if y_pred_proba is None and y_pred_log_proba is None:
    raise ValueError("At least one of `y_pred_proba` or `y_pred_log_proba` "
                     "must not be None")
  y_pred_llr = to_llr(y_pred_proba) if y_pred_log_proba is None \
      else to_llr(y_pred_log_proba)
  nb_classes = y_pred_llr.shape[1]
  y_pred = np.argmax(y_pred_llr, axis=-1)
  # ====== check y_true ====== #
  if isinstance(y_true, (tuple, list)):
    y_true = np.array(y_true)
  if y_true.ndim == 2:  # convert one-hot to labels
    y_true = np.argmax(y_true, axis=-1)
  # ====== check labels ====== #
  if labels is None:
    labels = [str(i) for i in range(nb_classes)]
  # ====== scoring ====== #
  if y_pred_proba is None:
    ll = 'unknown'
  else:
    ll = log_loss(y_true=y_true, y_pred=y_pred_proba)
  acc = accuracy_score(y_true=y_true, y_pred=y_pred)
  # explicit `labels` keeps the matrix at nb_classes x nb_classes even if
  # a class is absent from both `y_true` and `y_pred`, so it always
  # matches the names in `labels`
  cm = confusion_matrix(y_true=y_true,
                        y_pred=y_pred,
                        labels=list(range(nb_classes)))
  # C_norm
  cnorm, cnorm_arr = compute_Cnorm(y_true=y_true,
                                   y_score=y_pred_llr,
                                   Ptrue=[0.1, 0.5],
                                   probability_input=False)
  if y_pred_log_proba is not None:
    # also score the raw log-probabilities and keep the better result
    cnorm_, cnorm_arr_ = compute_Cnorm(y_true=y_true,
                                       y_score=y_pred_log_proba,
                                       Ptrue=[0.1, 0.5],
                                       probability_input=False)
    if np.mean(cnorm) > np.mean(cnorm_):  # smaller is better
      cnorm, cnorm_arr = cnorm_, cnorm_arr_
  # DET
  Pfa, Pmiss = det_curve(y_true=y_true, y_score=y_pred_llr)
  eer = compute_EER(Pfa=Pfa, Pmiss=Pmiss)
  minDCF = compute_minDCF(Pfa, Pmiss)[0]
  # PRINT LOG
  if print_log:
    print(ctext("--------", 'red'), ctext(title, 'cyan'))
    print("Log loss :", format_score(ll))
    print("Accuracy :", format_score(acc))
    print("C_norm   :", format_score(np.mean(cnorm)))
    print("EER      :", format_score(eer))
    print("minDCF   :", format_score(minDCF))
    print(print_confusion(arr=cm, labels=labels))
  # ====== save report to PDF files if necessary ====== #
  if path is not None:
    if y_pred_proba is None:
      y_pred_proba = y_pred_llr
    from matplotlib import pyplot as plt
    plt.figure(figsize=(nb_classes, nb_classes + 1))
    plot_confusion_matrix(cm, labels)
    # Cavg
    plt.figure(figsize=(nb_classes + 1, 3))
    plot_Cnorm(cnorm=cnorm_arr, labels=labels, Ptrue=[0.1, 0.5], fontsize=14)
    # binary classification
    if nb_classes == 2 and \
    (y_pred_proba.ndim == 1 or (y_pred_proba.ndim == 2 and
                                y_pred_proba.shape[1] == 1)):
      # index the result instead of unpacking it: the original code
      # unpacked `roc_curve` into 2 values here but 3 values below
      curve = roc_curve(y_true=y_true, y_score=y_pred_proba.ravel())
      fpr, tpr = curve[0], curve[1]
      # det curve
      plt.figure()
      plot_detection_curve(Pfa,
                           Pmiss,
                           curve='det',
                           xlims=xlims,
                           ylims=ylims,
                           linewidth=1.2)
      # roc curve
      plt.figure()
      plot_detection_curve(fpr, tpr, curve='roc')
    # multiclasses
    else:
      y_true = one_hot(y_true, nb_classes=nb_classes)
      curve = roc_curve(y_true=y_true.ravel(), y_score=y_pred_proba.ravel())
      fpr_micro, tpr_micro = curve[0], curve[1]
      # the DET computed above over all classes is the "micro" curve
      Pfa_micro, Pmiss_micro = Pfa, Pmiss
      fpr, tpr = [], []
      Pfa, Pmiss = [], []
      for i, yi in enumerate(y_true.T):
        curve = roc_curve(y_true=yi, y_score=y_pred_proba[:, i])
        fpr.append(curve[0])
        tpr.append(curve[1])
        curve = det_curve(y_true=yi, y_score=y_pred_llr[:, i])
        Pfa.append(curve[0])
        Pmiss.append(curve[1])
      plt.figure()
      plot_detection_curve(fpr_micro,
                           tpr_micro,
                           curve='roc',
                           linewidth=1.2,
                           title="ROC Micro")
      plt.figure()
      plot_detection_curve(fpr,
                           tpr,
                           curve='roc',
                           labels=labels,
                           linewidth=1.0,
                           title="ROC for each classes")
      plt.figure()
      plot_detection_curve(Pfa_micro,
                           Pmiss_micro,
                           curve='det',
                           xlims=xlims,
                           ylims=ylims,
                           linewidth=1.2,
                           title="DET Micro")
      plt.figure()
      plot_detection_curve(Pfa,
                           Pmiss,
                           curve='det',
                           xlims=xlims,
                           ylims=ylims,
                           labels=labels,
                           linewidth=1.0,
                           title="DET for each classes")
    plot_save(path)
def evaluate(y_true,
             y_pred_proba=None,
             y_pred_log_proba=None,
             labels=None,
             title='',
             path=None,
             xlims=None,
             ylims=None,
             print_log=True):
  r""" Score the predictions of a classifier; optionally print a report
  and save the diagnostic figures.

  Arguments:
    y_true : ground truth as a `Data` object (its `array` attribute is
      used), tuple, list or array; a 2-D input is treated as one-hot
      and converted with argmax.
    y_pred_proba : predicted probabilities, or `None`.
    y_pred_log_proba : predicted log-probabilities, or `None`; takes
      precedence over `y_pred_proba` for the log-likelihood ratios.
    labels : class names; defaults to `'0', '1', ...`.
    title : printed in the header of the report.
    path : if not `None`, the confusion matrix, C_norm, ROC and DET
      figures are created and passed to `plot_save(path)`.
    xlims, ylims : axis limits forwarded to the DET curves.
    print_log : if `True`, print log loss, accuracy, C_norm, EER,
      minDCF and the confusion matrix.

  Raises:
    ValueError : if both `y_pred_proba` and `y_pred_log_proba` are `None`.
  """
  from odin.backend import to_llr
  from odin.backend.metrics import (det_curve, compute_EER, roc_curve,
                                    compute_Cnorm, compute_minDCF)

  def format_score(s):
    return ctext('%.4f' % s if is_number(s) else s, 'yellow')

  # ====== check y_pred ====== #
  if y_pred_proba is None and y_pred_log_proba is None:
    raise ValueError("At least one of `y_pred_proba` or `y_pred_log_proba` "
                     "must not be None")
  y_pred_llr = to_llr(y_pred_proba) if y_pred_log_proba is None \
      else to_llr(y_pred_log_proba)
  nb_classes = y_pred_llr.shape[1]
  y_pred = np.argmax(y_pred_llr, axis=-1)
  # ====== check y_true ====== #
  if isinstance(y_true, Data):
    y_true = y_true.array
  if isinstance(y_true, (tuple, list)):
    y_true = np.array(y_true)
  if y_true.ndim == 2:  # convert one-hot to labels
    y_true = np.argmax(y_true, axis=-1)
  # ====== check labels ====== #
  if labels is None:
    labels = [str(i) for i in range(nb_classes)]
  # ====== scoring ====== #
  if y_pred_proba is None:
    ll = 'unknown'
  else:
    ll = log_loss(y_true=y_true, y_pred=y_pred_proba)
  acc = accuracy_score(y_true=y_true, y_pred=y_pred)
  # explicit `labels` keeps the matrix at nb_classes x nb_classes even if
  # a class is absent from both `y_true` and `y_pred`, so it always
  # matches the names in `labels`
  cm = confusion_matrix(y_true=y_true,
                        y_pred=y_pred,
                        labels=list(range(nb_classes)))
  # C_norm
  cnorm, cnorm_arr = compute_Cnorm(y_true=y_true,
                                   y_score=y_pred_llr,
                                   Ptrue=[0.1, 0.5],
                                   probability_input=False)
  if y_pred_log_proba is not None:
    # also score the raw log-probabilities and keep the better result
    cnorm_, cnorm_arr_ = compute_Cnorm(y_true=y_true,
                                       y_score=y_pred_log_proba,
                                       Ptrue=[0.1, 0.5],
                                       probability_input=False)
    if np.mean(cnorm) > np.mean(cnorm_):  # smaller is better
      cnorm, cnorm_arr = cnorm_, cnorm_arr_
  # DET
  Pfa, Pmiss = det_curve(y_true=y_true, y_score=y_pred_llr)
  eer = compute_EER(Pfa=Pfa, Pmiss=Pmiss)
  minDCF = compute_minDCF(Pfa, Pmiss)[0]
  # PRINT LOG
  if print_log:
    print(ctext("--------", 'red'), ctext(title, 'cyan'))
    print("Log loss :", format_score(ll))
    print("Accuracy :", format_score(acc))
    print("C_norm   :", format_score(np.mean(cnorm)))
    print("EER      :", format_score(eer))
    print("minDCF   :", format_score(minDCF))
    print(print_confusion(arr=cm, labels=labels))
  # ====== save report to PDF files if necessary ====== #
  if path is not None:
    if y_pred_proba is None:
      y_pred_proba = y_pred_llr
    from matplotlib import pyplot as plt
    plt.figure(figsize=(nb_classes, nb_classes + 1))
    plot_confusion_matrix(cm, labels)
    # Cavg
    plt.figure(figsize=(nb_classes + 1, 3))
    plot_Cnorm(cnorm=cnorm_arr, labels=labels, Ptrue=[0.1, 0.5], fontsize=14)
    # binary classification
    if nb_classes == 2 and \
    (y_pred_proba.ndim == 1 or (y_pred_proba.ndim == 2 and
                                y_pred_proba.shape[1] == 1)):
      # index the result instead of unpacking it: the original code
      # unpacked `roc_curve` into 2 values here but 3 values below
      curve = roc_curve(y_true=y_true, y_score=y_pred_proba.ravel())
      fpr, tpr = curve[0], curve[1]
      # det curve
      plt.figure()
      plot_detection_curve(Pfa,
                           Pmiss,
                           curve='det',
                           xlims=xlims,
                           ylims=ylims,
                           linewidth=1.2)
      # roc curve
      plt.figure()
      plot_detection_curve(fpr, tpr, curve='roc')
    # multiclasses
    else:
      y_true = one_hot(y_true, nb_classes=nb_classes)
      curve = roc_curve(y_true=y_true.ravel(), y_score=y_pred_proba.ravel())
      fpr_micro, tpr_micro = curve[0], curve[1]
      # the DET computed above over all classes is the "micro" curve
      Pfa_micro, Pmiss_micro = Pfa, Pmiss
      fpr, tpr = [], []
      Pfa, Pmiss = [], []
      for i, yi in enumerate(y_true.T):
        curve = roc_curve(y_true=yi, y_score=y_pred_proba[:, i])
        fpr.append(curve[0])
        tpr.append(curve[1])
        curve = det_curve(y_true=yi, y_score=y_pred_llr[:, i])
        Pfa.append(curve[0])
        Pmiss.append(curve[1])
      plt.figure()
      plot_detection_curve(fpr_micro,
                           tpr_micro,
                           curve='roc',
                           linewidth=1.2,
                           title="ROC Micro")
      plt.figure()
      plot_detection_curve(fpr,
                           tpr,
                           curve='roc',
                           labels=labels,
                           linewidth=1.0,
                           title="ROC for each classes")
      plt.figure()
      plot_detection_curve(Pfa_micro,
                           Pmiss_micro,
                           curve='det',
                           xlims=xlims,
                           ylims=ylims,
                           linewidth=1.2,
                           title="DET Micro")
      plt.figure()
      plot_detection_curve(Pfa,
                           Pmiss,
                           curve='det',
                           xlims=xlims,
                           ylims=ylims,
                           labels=labels,
                           linewidth=1.0,
                           title="DET for each classes")
    plot_save(path)