Code example #1
def plot_confusion_matrix(self, y_true='celltype', y_pred='icelltype'):
    r""" Confusion matrix for binary labels """
    y_true = OMIC.parse(y_true)
    y_pred = OMIC.parse(y_pred)
    name = f"true:{y_true.name}_pred:{y_pred.name}"
    x_true = self.dataset.get_omic(y_true)
    x_pred = self.dataset.get_omic(y_pred)
    if x_true.ndim > 1:
        x_true = np.argmax(x_true, axis=-1)
    if x_pred.ndim > 1:
        x_pred = np.argmax(x_pred, axis=-1)
    fig = plt.figure(figsize=(6, 6))
    plot_confusion_matrix(y_true=x_true,
                          y_pred=x_pred,
                          labels=self.dataset.get_var_names(y_true),
                          cmap="Blues",
                          ax=fig.gca(),
                          fontsize=12,
                          cbar=True,
                          title=name)
    return self.add_figure(name=f"cm_{name}", fig=fig)
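Example #1 wraps the plotting call behind a dataset-aware method of odin-ai's analysis class. For readers who only want the figure, a roughly equivalent plot can be produced with plain scikit-learn and matplotlib; the sketch below uses ConfusionMatrixDisplay rather than odin's plot_confusion_matrix, and the toy labels are made up for illustration.

# Minimal standalone sketch (scikit-learn + matplotlib, no odin): compute the
# matrix with sklearn and draw it with ConfusionMatrixDisplay.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

y_true = np.array([0, 1, 2, 2, 1, 0])   # toy class indices
y_pred = np.array([0, 2, 2, 2, 1, 0])
labels = ['A', 'B', 'C']                # hypothetical class names

cm = confusion_matrix(y_true, y_pred)
fig, ax = plt.subplots(figsize=(6, 6))
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels).plot(
    ax=ax, cmap="Blues", colorbar=True)
ax.set_title("true:celltype_pred:icelltype")
fig.savefig("cm_example.pdf")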
Code example #2
def evaluate_feeder(feeder, title):
    y_true_digit = []
    y_true_gender = []
    y_pred = []
    for outputs in Progbar(feeder.set_batch(batch_mode='file'),
                           name=title,
                           print_report=True,
                           print_summary=False,
                           count_func=lambda x: x[-1].shape[0]):
        name = str(outputs[0])
        idx = int(outputs[1])
        data = outputs[2:]
        assert idx == 0
        y_true_digit.append(f_digits(name))
        y_true_gender.append(f_genders(name))
        y_pred.append(f_pred(*data))
    # ====== post processing ====== #
    y_true_digit = np.array(y_true_digit, dtype='int32')
    y_true_gender = np.array(y_true_gender, dtype='int32')
    y_pred_proba = np.concatenate(y_pred, axis=0)
    y_pred_all = np.argmax(y_pred_proba, axis=-1).astype('int32')
    # ====== plotting for each gender ====== #
    plot_figure(nrow=6, ncol=25)
    for gen in range(len(genders)):
        y_true, y_pred = [], []
        for i, g in enumerate(y_true_gender):
            if g == gen:
                y_true.append(y_true_digit[i])
                y_pred.append(y_pred_all[i])
        if len(y_true) == 0:
            continue
        cm = confusion_matrix(y_true, y_pred, labels=range(len(digits)))
        plot_confusion_matrix(cm,
                              labels=digits,
                              fontsize=8,
                              ax=(1, 4, gen + 1),
                              title='[%s]%s' % (genders[gen], title))
    plot_save(os.path.join(FIG_PATH, '%s.pdf' % title))
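Example #2 draws one confusion matrix per gender by filtering the predictions before plotting. The grouping logic on its own can be sketched with scikit-learn alone, as below; genders, digits and the prediction arrays are random placeholders, not the output of odin's feeder/Progbar pipeline.

# Per-group confusion matrices with scikit-learn only (toy data).
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

digits = list(range(10))
genders = ['female', 'male']
rng = np.random.default_rng(0)
y_true_digit = rng.integers(0, 10, size=200)
y_true_gender = rng.integers(0, 2, size=200)
y_pred_all = rng.integers(0, 10, size=200)   # stand-in for model predictions

fig, axes = plt.subplots(1, len(genders), figsize=(12, 5))
for gen, ax in enumerate(axes):
    mask = y_true_gender == gen              # keep only this gender's samples
    cm = confusion_matrix(y_true_digit[mask], y_pred_all[mask], labels=digits)
    ConfusionMatrixDisplay(cm, display_labels=digits).plot(ax=ax, colorbar=False)
    ax.set_title(f'[{genders[gen]}]demo')
fig.tight_layout()
fig.savefig('demo.pdf')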
Code example #3
File: hyperparams.py Project: trungnt13/odin-ai
def evaluate(vae,
             ds,
             expdir: str,
             title: str,
             batch_size: int = 32,
             seed: int = 1):
    from odin.bay.vi import Correlation
    rand = np.random.RandomState(seed=seed)
    if not os.path.exists(expdir):
        os.makedirs(expdir)
    tanh = True if ds.name.lower() == 'celeba' else False
    ## data for training semi-supervised
    # careful don't allow any data leakage!
    train = ds.create_dataset('train',
                              batch_size=batch_size,
                              label_percent=True,
                              shuffle=False,
                              normalize='tanh' if tanh else 'probs')
    data = [(vae.encode(x, training=False), y) \
      for x, y in tqdm(train, desc=title)]
    x_semi_train = tf.concat(
        [tf.concat([i.mean(), _ymean(j)], axis=1) for (i, j), _ in data],
        axis=0).numpy()
    y_semi_train = tf.concat([i for _, i in data], axis=0).numpy()
    # shuffle
    ids = rand.permutation(x_semi_train.shape[0])
    x_semi_train = x_semi_train[ids]
    y_semi_train = y_semi_train[ids]
    ## data for testing
    test = ds.create_dataset('test',
                             batch_size=batch_size,
                             label_percent=True,
                             shuffle=False,
                             normalize='tanh' if tanh else 'probs')
    prog = tqdm(test, desc=title)
    llk_x = []
    llk_y = []
    z = []
    y_true = []
    y_pred = []
    x_true = []
    x_pred = []
    x_org, x_rec = [], []
    for x, y in prog:
        px, (qz, qy) = vae(x, training=False)
        y_true.append(y)
        y_pred.append(_ymean(qy))
        z.append(qz.mean())
        llk_x.append(px.log_prob(x))
        llk_y.append(qy.log_prob(y))
        if rand.uniform() < 0.005 or len(x_org) < 2:
            x_org.append(x)
            x_rec.append(px.mean())
    ## llk
    llk_x = tf.reduce_mean(tf.concat(llk_x, axis=0)).numpy()
    llk_y = tf.reduce_mean(tf.concat(llk_y, axis=0)).numpy()
    ## the latents
    z = tf.concat(z, axis=0).numpy()
    y_true = tf.concat(y_true, axis=0).numpy()
    y_pred = tf.concat(y_pred, axis=0).numpy()
    x_semi_test = tf.concat([z, y_pred], axis=-1).numpy()
    # shuffle
    ids = rand.permutation(z.shape[0])
    z = z[ids]
    y_true = y_true[ids]
    y_pred = y_pred[ids]
    x_semi_test = x_semi_test[ids]
    ## saving reconstruction images
    x_org = tf.concat(x_org, axis=0).numpy()
    x_rec = tf.concat(x_rec, axis=0).numpy()
    ids = rand.permutation(x_org.shape[0])
    x_org = x_org[ids][:36]
    x_rec = x_rec[ids][:36]
    vmin = x_rec.reshape((36, -1)).min(axis=1).reshape((36, 1, 1, 1))
    vmax = x_rec.reshape((36, -1)).max(axis=1).reshape((36, 1, 1, 1))
    if tanh:
        x_org = (x_org + 1.) / 2.
    x_rec = (x_rec - vmin) / (vmax - vmin)
    if x_org.shape[-1] == 1:  # grayscale image
        x_org = np.squeeze(x_org, -1)
        x_rec = np.squeeze(x_rec, -1)
    else:  # color image
        x_org = np.transpose(x_org, (0, 3, 1, 2))
        x_rec = np.transpose(x_rec, (0, 3, 1, 2))
    plt.figure(figsize=(15, 8))
    ax = plt.subplot(1, 2, 1)
    vs.plot_images(x_org, grids=(6, 6), ax=ax, title='Original')
    ax = plt.subplot(1, 2, 2)
    vs.plot_images(x_rec, grids=(6, 6), ax=ax, title='Reconstructed')
    plt.tight_layout()
    ## prepare the labels
    if ds.name in ('mnist', 'fashionmnist', 'celeba'):
        true = np.argmax(y_true, axis=-1)
        pred = np.argmax(y_pred, axis=-1)
        y_semi_train = np.argmax(y_semi_train, axis=-1)
        y_semi_test = true
        labels_name = ds.labels
    else:  # shapes3d dsprites
        true = y_true[:, 2].astype(np.int32)
        pred = y_pred[:, 2].astype(np.int32)
        y_semi_train = y_semi_train[:, 2].astype(np.int32)
        y_semi_test = true
        if ds.name == 'shapes3d':
            labels_name = ['cube', 'cylinder', 'sphere', 'round']
        elif ds.name == 'dsprites':
            labels_name = ['square', 'ellipse', 'heart']
    plt.figure(figsize=(8, 8))
    vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                             labels=labels_name,
                             cbar=True,
                             fontsize=10,
                             title=title)
    labels = np.array([labels_name[i] for i in true])
    labels_pred = np.array([labels_name[i] for i in pred])
    ## save arrays for later inspection
    np.savez_compressed(f'{expdir}/arrays',
                        x_train=x_semi_train,
                        y_train=y_semi_train,
                        x_test=x_semi_test,
                        y_test=y_semi_test,
                        zdim=z.shape[1],
                        labels=labels_name)
    print(f'Export arrays to "{expdir}/arrays.npz"')
    ## semi-supervised
    with open(f'{expdir}/results.txt', 'w') as f:
        print(f'Export results to "{expdir}/results.txt"')
        f.write(f'Steps: {vae.step.numpy()}\n')
        f.write(f'llk_x: {llk_x}\n')
        f.write(f'llk_y: {llk_y}\n')
        for p in [0.004, 0.06, 0.2, 0.99]:
            x_train, x_test, y_train, y_test = train_test_split(
                x_semi_train,
                y_semi_train,
                train_size=int(np.round(p * x_semi_train.shape[0])),
                random_state=1,
            )
            m = LogisticRegression(max_iter=3000, random_state=1)
            m.fit(x_train, y_train)
            # write the report
            f.write(f'{m.__class__.__name__} Number of labels: '
                    f'{p} {x_train.shape[0]}/{x_test.shape[0]}')
            f.write('\nValidation:\n')
            f.write(
                classification_report(y_true=y_test, y_pred=m.predict(x_test)))
            f.write('\nTest:\n')
            f.write(
                classification_report(y_true=y_semi_test,
                                      y_pred=m.predict(x_semi_test)))
            f.write('------------\n')
    ## scatter plot
    n_points = 4000
    # tsne plot
    tsne = DimReduce.TSNE(z[:n_points], n_components=2)
    kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.6)
    plt.figure(figsize=(8, 8))
    vs.plot_scatter(color=labels[:n_points], title=f'[True-tSNE]{title}', **kw)
    plt.figure(figsize=(8, 8))
    vs.plot_scatter(color=labels_pred[:n_points],
                    title=f'[Pred-tSNE]{title}',
                    **kw)
    # pca plot
    pca = DimReduce.PCA(z, n_components=2)
    kw = dict(x=pca[:, 0], y=pca[:, 1], grid=False, size=12.0, alpha=0.6)
    plt.figure(figsize=(8, 8))
    vs.plot_scatter(color=labels, title=f'[True-PCA]{title}', **kw)
    plt.figure(figsize=(8, 8))
    vs.plot_scatter(color=labels_pred, title=f'[Pred-PCA]{title}', **kw)
    ## factors plot
    corr = (Correlation.Spearman(z, y_true) +
            Correlation.Pearson(z, y_true)) / 2.
    best_z = np.argsort(np.abs(corr), axis=0)[-2:]
    style = dict(size=15.0, alpha=0.6, grid=False)
    for fi, (z1, z2) in enumerate(best_z.T):
        plt.figure(figsize=(8, 4))
        ax = plt.subplot(1, 2, 1)
        vs.plot_scatter(x=z[:n_points, z1],
                        y=z[:n_points, z2],
                        val=y_true[:n_points, fi],
                        ax=ax,
                        title=ds.labels[fi],
                        **style)
        ax = plt.subplot(1, 2, 2)
        vs.plot_scatter(x=z[:n_points, z1],
                        y=z[:n_points, z2],
                        val=y_pred[:n_points, fi],
                        ax=ax,
                        title=ds.labels[fi],
                        **style)
        plt.tight_layout()
    ## save all plot
    vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)
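The middle of example #3 trains a logistic-regression probe on growing fractions of labelled latent codes. That step in isolation is sketched below with scikit-learn only; z_train, z_test and the labels are synthetic placeholders standing in for the VAE's encoder means, not data produced by odin.

# Semi-supervised probe sketch: linear classifier on a fraction p of the labels.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(1)
y_train = rng.integers(0, 10, size=1000)        # stand-in for class labels
y_test = rng.integers(0, 10, size=200)
z_train = rng.normal(size=(1000, 16))           # stand-in for encoder means
z_test = rng.normal(size=(200, 16))
z_train[:, :10] += 2.0 * np.eye(10)[y_train]    # make the latents informative
z_test[:, :10] += 2.0 * np.eye(10)[y_test]

for p in [0.004, 0.06, 0.2, 0.99]:
    # keep at least 20 labelled points so every class is likely represented
    n_labelled = max(20, int(np.round(p * z_train.shape[0])))
    x_sub, _, y_sub, _ = train_test_split(z_train, y_train,
                                          train_size=n_labelled,
                                          random_state=1)
    clf = LogisticRegression(max_iter=3000, random_state=1).fit(x_sub, y_sub)
    print(f'{p:.3f} labelled -> test accuracy {clf.score(z_test, y_test):.3f}')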
Code example #4
File: compare_vaes.py Project: trungnt13/odin-ai
def evaluate(vae: VariationalAutoencoder,
             ds: ImageDataset,
             expdir: str,
             title: str,
             batch_size: int = 64,
             take_count: int = -1,
             n_images: int = 36,
             seed: int = 1):
    n_rows = int(np.sqrt(n_images))
    is_semi = vae.is_semi_supervised()
    is_hierarchical = vae.is_hierarchical()
    ds_kw = dict(batch_size=batch_size, label_percent=1.0, shuffle=False)
    ## prepare
    rand = np.random.RandomState(seed=seed)
    if not os.path.exists(expdir):
        os.makedirs(expdir)
    ## data for training semi-supervised
    train = ds.create_dataset('train', **ds_kw)
    (llkx_train, llky_train, x_org_train, x_rec_train, y_true_train,
     y_pred_train, z_train, pz_train) = _call(vae,
                                              ds=train,
                                              rand=rand,
                                              take_count=take_count,
                                              n_images=n_images,
                                              verbose=True)
    ## data for testing
    test = ds.create_dataset('test', **ds_kw)
    (llkx_test, llky_test, x_org_test, x_rec_test, y_true_test, y_pred_test,
     z_test, pz_test) = _call(vae,
                              ds=test,
                              rand=rand,
                              take_count=take_count,
                              n_images=n_images,
                              verbose=True)
    # === 0. plotting latent-factor pairs
    for idx, z in enumerate(z_test):
        z = z.mean()
        f = y_true_test
        corr_mat = Correlation.Spearman(z, f)  # [n_latents, n_factors]
        plot_latents_pairs(z, f, corr_mat, ds.labels)
        vs.plot_save(f'{expdir}/latent{idx}_factor.pdf', dpi=100, verbose=True)
    # === 0. latent traverse plot
    x_travs = x_org_test
    if x_travs.ndim == 3:  # grayscale image
        x_travs = np.expand_dims(x_travs, -1)
    else:  # color image
        x_travs = np.transpose(x_travs, (0, 2, 3, 1))
    x_travs = x_travs[rand.permutation(x_travs.shape[0])]
    n_visual_samples = 5
    n_traverse_points = 21
    n_top_latents = 10
    plt.figure(figsize=(8, 3 * n_visual_samples))
    for i in range(n_visual_samples):
        images = vae.sample_traverse(x_travs[i:i + 1],
                                     min_val=-np.min(z_test[0].mean()),
                                     max_val=np.max(z_test[0].mean()),
                                     n_best_latents=n_top_latents,
                                     n_traverse_points=n_traverse_points,
                                     mode='linear')
        images = as_tuple(images)[0]
        images = _prepare_images(images.mean().numpy(), normalize=True)
        vs.plot_images(images,
                       grids=(n_top_latents, n_traverse_points),
                       ax=(n_visual_samples, 1, i + 1))
        if i == 0:
            plt.title('Latents traverse')
    plt.tight_layout()
    vs.plot_save(f'{expdir}/latents_traverse.pdf', dpi=180, verbose=True)
    # === 0. prior sampling plot
    images = as_tuple(vae.sample_observation(n=n_images, seed=seed))[0]
    images = _prepare_images(images.mean().numpy(), normalize=True)
    plt.figure(figsize=(5, 5))
    vs.plot_images(images, grids=(n_rows, n_rows), title='Sampled')
    # === 1. reconstruction plot
    plt.figure(figsize=(15, 15))
    vs.plot_images(x_org_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 1),
                   title='[Train]Original')
    vs.plot_images(x_rec_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 2),
                   title='[Train]Reconstructed')
    vs.plot_images(x_org_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 3),
                   title='[Test]Original')
    vs.plot_images(x_rec_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 4),
                   title='[Test]Reconstructed')
    plt.tight_layout()
    ## prepare the labels
    label_type = ds.label_type
    if label_type == 'categorical':
        labels_name = ds.labels
        true = np.argmax(y_true_test, axis=-1)
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = np.argmax(y_pred_test.mean().numpy(), axis=-1)
            labels_pred = np.array([labels_name[i] for i in pred])
    elif label_type == 'factor':  # dsprites, shapes3d
        labels_name = ['cube', 'cylinder', 'sphere', 'round'] \
          if 'shapes3d' in ds.name else ['square', 'ellipse', 'heart']
        true = y_true_test[:, 2].astype('int32')
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = get_ymean(y_pred_test)[:, 2].astype('int32')
            labels_pred = np.array([labels_name[i] for i in pred])
    else:  # CelebA
        raise NotImplementedError
    ## confusion matrix
    if is_semi:
        plt.figure(figsize=(8, 8))
        acc = accuracy_score(y_true=true, y_pred=pred)
        vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                                 labels=labels_name,
                                 cbar=True,
                                 fontsize=10,
                                 title=f'{title} Acc:{acc:.2f}')
    ## save arrays for later inspections
    with open(f'{expdir}/arrays', 'wb') as f:
        pickle.dump(
            dict(z_train=z_train,
                 y_pred_train=y_pred_train,
                 y_true_train=y_true_train,
                 z_test=z_test,
                 y_pred_test=y_pred_test,
                 y_true_test=y_true_test,
                 labels=labels_name,
                 ds=ds.name,
                 label_type=label_type), f)
    print(f'Exported arrays to "{expdir}/arrays"')
    ## semi-supervised
    z_mean_train = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_train], -1)
    z_mean_test = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_test], -1)
    # === 2. scatter points latents plot
    n_points = 5000
    ids = rand.permutation(len(labels_true))[:n_points]
    Y_true = labels_true[ids]
    Y_pred = labels_pred[ids]
    # tsne plot
    n_latents = 0 if len(z_train) == 1 else len(z_train)
    scatter_names = ['all'] + [f'latents{i}' for i in range(n_latents)]
    scatter_data = ([z_mean_test[ids]] +
                    [z_test[i].mean().numpy()[ids] for i in range(n_latents)])
    for name, X in zip(scatter_names, scatter_data):
        print(f'Plot scatter points for {name}')
        X = X.reshape(X.shape[0], -1)  # flatten to 2D
        X = Pipeline([('zscore', StandardScaler()),
                      ('pca', PCA(min(X.shape[1], 512),
                                  random_state=seed))]).fit_transform(X)
        tsne = DimReduce.TSNE(X, n_components=2, framework='sklearn')
        kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.8)
        plt.figure(figsize=(12, 6))
        vs.plot_scatter(color=Y_true,
                        title=f'[True]{title}-{name}',
                        ax=(1, 2, 1),
                        **kw)
        vs.plot_scatter(color=Y_pred,
                        title=f'[Pred]{title}-{name}',
                        ax=(1, 2, 2),
                        **kw)
    ## save all plot
    vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)

    # === 3. show the latents statistics
    n_latents = len(z_train)
    colors = sns.color_palette(n_colors=len(labels_true))
    styles = dict(grid=False,
                  ticks_off=False,
                  alpha=0.6,
                  xlabel='mean',
                  ylabel='stddev')

    # scatter between latents and labels (assume categorical distribution)
    def _show_latents_labels(Z, Y, title):
        plt.figure(figsize=(5 * n_latents, 5), dpi=150)
        for idx, z in enumerate(Z):
            if len(z.batch_shape) == 0:
                mean = np.repeat(np.expand_dims(z.mean(), 0), Y.shape[0], 0)
                stddev = z.sample(Y.shape[0]) - mean
            else:
                mean = flatten(z.mean())
                stddev = flatten(z.stddev())
            y = np.argmax(Y, axis=-1)
            data = [[], [], []]
            for y_i, c in zip(np.unique(y), colors):
                mask = (y == y_i)
                data[0].append(np.mean(mean[mask], 0))
                data[1].append(np.mean(stddev[mask], 0))
                data[2].append([labels_true[y_i]] * mean.shape[1])
            vs.plot_scatter(
                x=np.concatenate(data[0], 0),
                y=np.concatenate(data[1], 0),
                color=np.concatenate(data[2], 0),
                ax=(1, n_latents, idx + 1),
                size=15 if mean.shape[1] < 128 else 8,
                title=f'[Test-{title}]#{idx} - {mean.shape[1]} (units)',
                **styles)
        plt.tight_layout()

    # simple scatter mean-stddev each latents
    def _show_latents(Z, title):
        plt.figure(figsize=(3.5 * n_latents, 3.5), dpi=150)
        for idx, z in enumerate(Z):
            mean = flatten(z.mean())
            stddev = flatten(z.stddev())
            if mean.ndim == 2:
                mean = np.mean(mean, 0)
                stddev = np.mean(stddev, 0)
            vs.plot_scatter(
                x=mean,
                y=stddev,
                ax=(1, n_latents, idx + 1),
                size=15 if len(mean) < 128 else 8,
                title=f'[Test-{title}]#{idx} - {len(mean)} (units)',
                **styles)

    _show_latents_labels(z_test, y_true_test, 'post')
    _show_latents_labels(pz_test, y_true_test, 'prior')
    _show_latents(z_test, 'post')
    _show_latents(pz_test, 'prior')

    # KL statistics
    vs.plot_figure()
    for idx, (qz, pz) in enumerate(zip(z_test, pz_test)):
        kl = []
        qz = Normal(loc=qz.mean(), scale=qz.stddev(), name=f'posterior{idx}')
        pz = Normal(loc=pz.mean(), scale=pz.stddev(), name=f'prior{idx}')
        for s, e in minibatch(batch_size=8, n=100):
            z = qz.sample(e - s)
            # don't do this on GPU, it explodes!
            kl.append((qz.log_prob(z) - pz.log_prob(z)).numpy())
        kl = np.concatenate(kl, 0)  # (mcmc, batch, event)
        # per sample
        kl_samples = np.sum(kl, as_tuple(list(range(2, kl.ndim))))
        kl_samples = logsumexp(kl_samples, 0)
        plt.subplot(n_latents, 2, idx * 2 + 1)
        sns.histplot(kl_samples, bins=50)
        plt.title(f'Z#{idx} KL per sample (nats)')
        # per latent
        kl_latents = np.mean(flatten(logsumexp(kl, 0)), 0)
        plt.subplot(n_latents, 2, idx * 2 + 2)
        plt.plot(np.sort(kl_latents))
        plt.title(f'Z#{idx} KL per dim (nats)')
    plt.tight_layout()

    vs.plot_save(f'{expdir}/latents.pdf', dpi=180, verbose=True)
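The KL-statistics loop at the end of example #4 estimates KL(q || p) per latent by Monte Carlo. The sketch below reproduces that estimate for diagonal Gaussians in plain NumPy and checks it against the closed-form KL; the means and scales are invented placeholders, not posteriors from a trained VAE.

# Monte-Carlo vs analytic KL between two diagonal Gaussians (NumPy only).
import numpy as np

rng = np.random.default_rng(1)
dim = 32
q_mean = rng.normal(size=dim)                    # "posterior" parameters
q_std = np.exp(rng.normal(scale=0.1, size=dim))
p_mean, p_std = np.zeros(dim), np.ones(dim)      # standard-normal "prior"

def log_prob(x, mean, std):
    # elementwise log-density of a univariate Gaussian
    return -0.5 * (((x - mean) / std) ** 2 + np.log(2 * np.pi)) - np.log(std)

z = rng.normal(loc=q_mean, scale=q_std, size=(10_000, dim))   # z ~ q
kl_mc = np.mean(log_prob(z, q_mean, q_std) - log_prob(z, p_mean, p_std), axis=0)
kl_exact = (np.log(p_std / q_std)
            + (q_std ** 2 + (q_mean - p_mean) ** 2) / (2 * p_std ** 2) - 0.5)
print('per-dim KL (MC)   :', np.round(kl_mc[:4], 4))
print('per-dim KL (exact):', np.round(kl_exact[:4], 4))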
Code example #5
def plot_evaluate_classifier(y_pred,
                             y_true,
                             labels,
                             title,
                             show_plot=True,
                             return_figure=False):
    r"""
  Arguments:
    fig : Figure or tuple (`float`, `float`), optional (default=`None`)
      width, height in inches

  Returns:
    Return a dictionary of scores
    {
        F1micro=f1_micro * 100,
        F1macro=f1_macro * 100,
        F1weight=f1_weight * 100,
        F1_[classname]=...
    }
  """
    from matplotlib import pyplot as plt
    fontsize = 12
    num_classes = len(labels)
    nrow = int(np.ceil(num_classes / 5))
    ncol = int(np.ceil(num_classes / nrow))

    if y_pred.ndim == 1:
        y_pred = one_hot(y_pred, nb_classes=num_classes)
    if y_true.ndim == 1:
        y_true = one_hot(y_true, nb_classes=num_classes)

    if show_plot:
        fig = plot_figure(nrow=4 * nrow + 3, ncol=4 * ncol)

    f1_classes = []
    for i, (name, pred, true) in enumerate(zip(labels, y_pred.T, y_true.T)):
        f1_classes.append(f1_score(true, pred))
        if show_plot:
            plot_confusion_matrix(confusion_matrix(y_true=true, y_pred=pred),
                                  labels=[0, 1],
                                  fontsize=fontsize,
                                  ax=(nrow, ncol, i + 1),
                                  title=name + '\n')

    f1_micro = f1_score(y_true=y_true.ravel(), y_pred=y_pred.ravel())
    f1_macro = np.mean(f1_classes)
    f1_weight = f1_score(y_true=y_true, y_pred=y_pred, average='weighted')

    if show_plot:
        plt.suptitle('%s\nF1-micro:%.2f  F1-macro:%.2f  F1-weight:%.2f' %
                     (title, f1_micro * 100, f1_macro * 100, f1_weight * 100),
                     fontsize=fontsize + 6)
        plt.tight_layout(rect=[0, 0.04, 1, 0.96])
    results = dict(
        F1micro=f1_micro * 100,
        F1macro=f1_macro * 100,
        F1weight=f1_weight * 100,
    )
    for name, f1 in zip(labels, f1_classes):
        results['F1_' + name] = f1 * 100
    if show_plot and return_figure:
        return results, fig
    return results
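The scores returned by example #5 can be reproduced with scikit-learn's f1_score alone, as sketched below on toy one-hot arrays; the class names are hypothetical.

# Per-class, micro, macro and weighted F1 over one-hot labels (toy data).
import numpy as np
from sklearn.metrics import f1_score

labels = ['B-cell', 'T-cell', 'NK']              # hypothetical class names
rng = np.random.default_rng(0)
y_true = np.eye(len(labels), dtype=int)[rng.integers(0, 3, size=100)]
y_pred = np.eye(len(labels), dtype=int)[rng.integers(0, 3, size=100)]

scores = {'F1micro': 100 * f1_score(y_true.ravel(), y_pred.ravel()),
          'F1macro': 100 * np.mean([f1_score(t, p)
                                    for t, p in zip(y_true.T, y_pred.T)]),
          'F1weight': 100 * f1_score(y_true, y_pred, average='weighted')}
for name, t, p in zip(labels, y_true.T, y_pred.T):
    scores['F1_' + name] = 100 * f1_score(t, p)  # one binary F1 per class
print(scores)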
Code example #6
def plot_evaluate_reconstruction(X,
                                 W,
                                 y_raw,
                                 y_prob,
                                 X_row,
                                 X_col,
                                 labels,
                                 title,
                                 pi=None,
                                 enable_image=True,
                                 enable_tsne=True,
                                 enable_sparsity=True):
    """
  pi : zero-inflated rate (imputation rate), or dropout probabilities
  """
    print("Evaluate: [Reconstruction]", ctext(title, 'lightyellow'))
    from matplotlib import pyplot as plt
    fontsize = 12

    W_stdev_total, W_stdev_explained = None, None
    if isinstance(W, (tuple, list)):
        if len(W) == 1:
            W = W[0]
        elif len(W) == 3:
            W, W_stdev_total, W_stdev_explained = W
        else:
            raise RuntimeError()
    elif W.ndim == 3:
        W, W_stdev_total, W_stdev_explained = W[0], W[1], W[2]
    # convert the prediction to integer
    # W = W.astype('int32')

    assert (X.shape[0] == W.shape[0] == y_raw.shape[0]) and \
    (X.shape == W.shape) and \
    (y_raw.shape == y_prob.shape)

    X, X_row, W, y_raw, y_prob, pi = downsample_data(X, X_row, W, y_raw,
                                                     y_prob, pi)
    y_argmax = np.argmax(y_prob, axis=-1)

    # ====== prepare count-sum ====== #
    X_log = K.log_norm(X, axis=1)
    W_log = K.log_norm(W, axis=1)

    X_gene_countsum = np.sum(X, axis=0)
    X_cell_countsum = np.sum(X, axis=1)
    X_gene_nzeros = np.sum(X == 0, axis=0)
    X_cell_nzeros = np.sum(X == 0, axis=1)

    gene_sort = np.argsort(X_gene_countsum)
    cell_sort = np.argsort(X_cell_countsum)

    W_gene_countsum = np.sum(W, axis=0)
    W_cell_countsum = np.sum(W, axis=1)
    W_gene_nzeros = np.sum(W == 0, axis=0)
    W_cell_nzeros = np.sum(W == 0, axis=1)

    X_col_sorted = X_col[gene_sort] if X_col is not None else None
    X_row_sorted = X_row[cell_sort] if X_row is not None else None

    if pi is not None:
        pi_cell_countsum = np.mean(pi, axis=1)
        pi_gene_countsum = np.mean(pi, axis=0)

    # ====== Compare image ====== #
    if enable_image:
        _RAND = np.random.RandomState(seed=87654321)
        n_img = 12
        n_img_row = min(3, X.shape[0] // n_img)
        n_row_per_row = 2 if pi is None else 3
        plot_figure(nrow=n_img_row * 4, ncol=18)
        count = 1
        all_ids = _RAND.choice(np.arange(0, X.shape[0]),
                               size=n_img * n_img_row,
                               replace=False)

        for img_row in range(n_img_row):
            ids = all_ids[img_row * n_img:(img_row + 1) * n_img]

            # plot original images
            for _, i in enumerate(ids):
                ax = plt.subplot(n_row_per_row * n_img_row, n_img, count)
                show_image(X[i])
                if _ == 0:
                    plt.ylabel("Original")
                if X_row is not None:
                    ax.set_title(X_row[i], fontsize=8)
                count += 1
            # plot reconstructed images
            for _, i in enumerate(ids):
                plt.subplot(n_row_per_row * n_img_row, n_img, count)
                show_image(W[i])
                if _ == 0:
                    plt.ylabel("Reconstructed")
                count += 1
            # plot zero-inflated rate
            if pi is not None:
                for _, i in enumerate(ids):
                    plt.subplot(n_row_per_row * n_img_row, n_img, count)
                    show_image(pi[i], is_probability=True)
                    if _ == 0:
                        plt.ylabel("$p_{zero-inflated}$")
                    count += 1
        plt.tight_layout()
    # ====== compare the T-SNE plot ====== #
    if enable_tsne:

        def pca_and_tsne(x, w):
            x_pca, w_pca = fast_pca(x,
                                    w,
                                    n_components=512,
                                    random_state=87654321)
            x_tsne = fast_tsne(x_pca, n_components=2, random_state=87654321)
            w_tsne = fast_tsne(w_pca, n_components=2, random_state=87654321)
            return x_pca[:, :2], x_tsne, w_pca[:, :2], w_tsne

        # transforming the data
        (X_cell_pca, X_cell_tsne, W_cell_pca,
         W_cell_tsne) = pca_and_tsne(X_log, W_log)
        (X_gene_pca, X_gene_tsne, W_gene_pca,
         W_gene_tsne) = pca_and_tsne(X_log.T, W_log.T)
        # prepare the figure
        n_plot = 3 + 2  # 3 for cells, 2 for genes
        if pi is not None:
            n_plot += 2  # 2 more row for pi
        plot_figure(nrow=n_plot * 5, ncol=18)
        # Cells
        fast_scatter(x=X_cell_pca,
                     y=y_argmax,
                     labels=labels,
                     title="[PCA]Original Cell Data",
                     ax=(n_plot, 2, 1),
                     enable_legend=False)
        fast_scatter(x=W_cell_pca,
                     y=y_argmax,
                     labels=labels,
                     title="[PCA]Reconstructed Cell Data",
                     ax=(n_plot, 2, 2),
                     enable_legend=False)

        fast_scatter(x=X_cell_tsne,
                     y=y_argmax,
                     labels=labels,
                     title="[t-SNE]Original Cell Data",
                     ax=(n_plot, 2, 3),
                     enable_legend=True)
        fast_scatter(x=W_cell_tsne,
                     y=y_argmax,
                     labels=labels,
                     title="[t-SNE]Reconstructed Cell Data",
                     ax=(n_plot, 2, 4),
                     enable_legend=False)

        fast_log = lambda x: K.log_norm(x, axis=0)

        plot_scatter(
            x=X_cell_tsne,
            val=fast_log(X_cell_countsum),
            title="[t-SNE]Original Cell Data + Original Cell Countsum",
            ax=(n_plot, 2, 5),
            colorbar=True)
        plot_scatter(
            x=X_cell_tsne,
            val=fast_log(W_cell_countsum),
            title="[t-SNE]Original Cell Data + Reconstructed Cell Countsum",
            ax=(n_plot, 2, 6),
            colorbar=True)
        # Genes
        plot_scatter(x=X_gene_pca,
                     val=fast_log(X_gene_countsum),
                     title="[PCA]Original Gene Data + Original Gene Countsum",
                     ax=(n_plot, 2, 7),
                     colorbar=True)
        plot_scatter(
            x=W_gene_pca,
            val=fast_log(X_gene_countsum),
            title="[PCA]Reconstructed Gene Data + Original Gene Countsum",
            ax=(n_plot, 2, 8),
            colorbar=True)

        plot_scatter(
            x=X_gene_tsne,
            val=fast_log(X_gene_countsum),
            title="[t-SNE]Original Gene Data + Original Gene Countsum",
            ax=(n_plot, 2, 9),
            colorbar=True)
        plot_scatter(
            x=X_gene_tsne,
            val=fast_log(W_gene_countsum),
            title="[t-SNE]Original Gene Data + Reconstructed Gene Countsum",
            ax=(n_plot, 2, 10),
            colorbar=True)
        # zero-inflation rate
        if pi is not None:
            plot_scatter(
                x=X_cell_tsne,
                val=X_cell_countsum,
                title="[t-SNE]Original Cell Data + Original Cell Countsum",
                ax=(n_plot, 2, 11),
                colorbar=True)
            plot_scatter(
                x=X_cell_tsne,
                val=pi_cell_countsum,
                title="[t-SNE]Original Cell Data + Zero-inflated rate",
                ax=(n_plot, 2, 12),
                colorbar=True)

            plot_scatter(
                x=X_gene_tsne,
                val=X_gene_countsum,
                title="[t-SNE]Original Gene Data + Original Gene Countsum",
                ax=(n_plot, 2, 13),
                colorbar=True)
            plot_scatter(
                x=X_gene_tsne,
                val=pi_gene_countsum,
                title="[t-SNE]Original Gene Data + Zero-inflated rate",
                ax=(n_plot, 2, 14),
                colorbar=True)
        plt.tight_layout()
    # ******************** sparsity ******************** #
    if enable_sparsity:
        plot_figure(nrow=8, ncol=8)
        # ====== sparsity ====== #
        z = (X.ravel() == 0).astype('int32')
        z_res = (W.ravel() == 0).astype('int32')
        plot_confusion_matrix(ax=None,
                              cm=confusion_matrix(y_true=z,
                                                  y_pred=z_res,
                                                  labels=(0, 1)),
                              labels=('Not Zero', 'Zero'),
                              colorbar=True,
                              fontsize=fontsize + 4,
                              title="Sparsity")
Code example #7
def evaluate(y_true, y_pred_proba=None, y_pred_log_proba=None,
             labels=None, title='', path=None,
             xlims=None, ylims=None, print_log=True):
  from odin.backend import to_llr
  from odin.backend.metrics import (det_curve, compute_EER, roc_curve,
                                    compute_Cavg, compute_Cnorm,
                                    compute_minDCF)

  def format_score(s):
    return ctext('%.4f' % s if is_number(s) else s, 'yellow')
  nb_classes = None
  # ====== check y_pred ====== #
  if y_pred_proba is None and y_pred_log_proba is None:
    raise ValueError("At least one of `y_pred_proba` or `y_pred_log_proba` "
                     "must not be None")
  y_pred_llr = to_llr(y_pred_proba) if y_pred_log_proba is None \
      else to_llr(y_pred_log_proba)
  nb_classes = y_pred_llr.shape[1]
  y_pred = np.argmax(y_pred_llr, axis=-1)
  # ====== check y_true ====== #
  if isinstance(y_true, (tuple, list)):
    y_true = np.array(y_true)
  if y_true.ndim == 2: # convert one-hot to labels
    y_true = np.argmax(y_true, axis=-1)
  # ====== check labels ====== #
  if labels is None:
    labels = [str(i) for i in range(nb_classes)]
  # ====== scoring ====== #
  if y_pred_proba is None:
    ll = 'unknown'
  else:
    ll = log_loss(y_true=y_true, y_pred=y_pred_proba)
  acc = accuracy_score(y_true=y_true, y_pred=y_pred)
  cm = confusion_matrix(y_true=y_true, y_pred=y_pred)
  # C_norm
  cnorm, cnorm_arr = compute_Cnorm(y_true=y_true,
                                   y_score=y_pred_llr,
                                   Ptrue=[0.1, 0.5],
                                   probability_input=False)
  if y_pred_log_proba is not None:
    cnorm_, cnorm_arr_ = compute_Cnorm(y_true=y_true,
                                       y_score=y_pred_log_proba,
                                       Ptrue=[0.1, 0.5],
                                       probability_input=False)
    if np.mean(cnorm) > np.mean(cnorm_): # smaller is better
      cnorm, cnorm_arr = cnorm_, cnorm_arr_
  # DET
  Pfa, Pmiss = det_curve(y_true=y_true, y_score=y_pred_llr)
  eer = compute_EER(Pfa=Pfa, Pmiss=Pmiss)
  minDCF = compute_minDCF(Pfa, Pmiss)[0]
  # PRINT LOG
  if print_log:
    print(ctext("--------", 'red'), ctext(title, 'cyan'))
    print("Log loss :", format_score(ll))
    print("Accuracy :", format_score(acc))
    print("C_norm   :", format_score(np.mean(cnorm)))
    print("EER      :", format_score(eer))
    print("minDCF   :", format_score(minDCF))
    print(print_confusion(arr=cm, labels=labels))
  # ====== save report to PDF files if necessary ====== #
  if path is not None:
    if y_pred_proba is None:
      y_pred_proba = y_pred_llr
    from matplotlib import pyplot as plt
    plt.figure(figsize=(nb_classes, nb_classes + 1))
    plot_confusion_matrix(cm, labels)
    # Cavg
    plt.figure(figsize=(nb_classes + 1, 3))
    plot_Cnorm(cnorm=cnorm_arr, labels=labels, Ptrue=[0.1, 0.5],
               fontsize=14)
    # binary classification
    if nb_classes == 2 and \
    (y_pred_proba.ndim == 1 or (y_pred_proba.ndim == 2 and
                                y_pred_proba.shape[1] == 1)):
      fpr, tpr = roc_curve(y_true=y_true, y_score=y_pred_proba.ravel())
      # det curve
      plt.figure()
      plot_detection_curve(Pfa, Pmiss, curve='det',
                           xlims=xlims, ylims=ylims, linewidth=1.2)
      # roc curve
      plt.figure()
      plot_detection_curve(fpr, tpr, curve='roc')
    # multiclasses
    else:
      y_true = one_hot(y_true, nb_classes=nb_classes)
      fpr_micro, tpr_micro, _ = roc_curve(y_true=y_true.ravel(),
                                          y_score=y_pred_proba.ravel())
      Pfa_micro, Pmiss_micro = Pfa, Pmiss
      fpr, tpr = [], []
      Pfa, Pmiss = [], []
      for i, yi in enumerate(y_true.T):
        curve = roc_curve(y_true=yi, y_score=y_pred_proba[:, i])
        fpr.append(curve[0])
        tpr.append(curve[1])
        curve = det_curve(y_true=yi, y_score=y_pred_llr[:, i])
        Pfa.append(curve[0])
        Pmiss.append(curve[1])
      plt.figure()
      plot_detection_curve(fpr_micro, tpr_micro, curve='roc',
                           linewidth=1.2, title="ROC Micro")
      plt.figure()
      plot_detection_curve(fpr, tpr, curve='roc',
                           labels=labels, linewidth=1.0,
                           title="ROC for each classes")
      plt.figure()
      plot_detection_curve(Pfa_micro, Pmiss_micro, curve='det',
                           xlims=xlims, ylims=ylims, linewidth=1.2,
                           title="DET Micro")
      plt.figure()
      plot_detection_curve(Pfa, Pmiss, curve='det',
                           xlims=xlims, ylims=ylims,
                           labels=labels, linewidth=1.0,
                           title="DET for each classes")
    plot_save(path)
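Example #7 builds its micro-averaged ROC curve by flattening the one-hot targets together with the score matrix. A sketch of that computation with scikit-learn's roc_curve (which returns thresholds as a third value) on toy probabilities, not odin's metrics module:

# Micro-averaged and per-class ROC from one-hot targets (toy scores).
import numpy as np
from sklearn.metrics import roc_curve, auc

rng = np.random.default_rng(0)
nb_classes = 4
y_true = np.eye(nb_classes, dtype=int)[rng.integers(0, nb_classes, size=500)]
logits = rng.normal(size=(500, nb_classes)) + 2.0 * y_true   # informative scores
proba = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# micro average: flatten the one-hot matrix and the score matrix together
fpr_micro, tpr_micro, _ = roc_curve(y_true.ravel(), proba.ravel())
print('micro AUC:', auc(fpr_micro, tpr_micro))

# one curve per class, as in the per-class ROC/DET plots above
for i in range(nb_classes):
    fpr, tpr, _ = roc_curve(y_true[:, i], proba[:, i])
    print(f'class {i} AUC: {auc(fpr, tpr):.3f}')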
Code example #8
File: base.py Project: imito/odin
def evaluate(y_true, y_pred_proba=None, y_pred_log_proba=None,
             labels=None, title='', path=None,
             xlims=None, ylims=None, print_log=True):
  from odin.backend import to_llr
  from odin.backend.metrics import (det_curve, compute_EER, roc_curve,
                                    compute_Cavg, compute_Cnorm,
                                    compute_minDCF)

  def format_score(s):
    return ctext('%.4f' % s if is_number(s) else s, 'yellow')
  nb_classes = None
  # ====== check y_pred ====== #
  if y_pred_proba is None and y_pred_log_proba is None:
    raise ValueError("At least one of `y_pred_proba` or `y_pred_log_proba` "
                     "must not be None")
  y_pred_llr = to_llr(y_pred_proba) if y_pred_log_proba is None \
      else to_llr(y_pred_log_proba)
  nb_classes = y_pred_llr.shape[1]
  y_pred = np.argmax(y_pred_llr, axis=-1)
  # ====== check y_true ====== #
  if isinstance(y_true, Data):
    y_true = y_true.array
  if isinstance(y_true, (tuple, list)):
    y_true = np.array(y_true)
  if y_true.ndim == 2: # convert one-hot to labels
    y_true = np.argmax(y_true, axis=-1)
  # ====== check labels ====== #
  if labels is None:
    labels = [str(i) for i in range(nb_classes)]
  # ====== scoring ====== #
  if y_pred_proba is None:
    ll = 'unknown'
  else:
    ll = log_loss(y_true=y_true, y_pred=y_pred_proba)
  acc = accuracy_score(y_true=y_true, y_pred=y_pred)
  cm = confusion_matrix(y_true=y_true, y_pred=y_pred)
  # C_norm
  cnorm, cnorm_arr = compute_Cnorm(y_true=y_true,
                                   y_score=y_pred_llr,
                                   Ptrue=[0.1, 0.5],
                                   probability_input=False)
  if y_pred_log_proba is not None:
    cnorm_, cnorm_arr_ = compute_Cnorm(y_true=y_true,
                                       y_score=y_pred_log_proba,
                                       Ptrue=[0.1, 0.5],
                                       probability_input=False)
    if np.mean(cnorm) > np.mean(cnorm_): # smaller is better
      cnorm, cnorm_arr = cnorm_, cnorm_arr_
  # DET
  Pfa, Pmiss = det_curve(y_true=y_true, y_score=y_pred_llr)
  eer = compute_EER(Pfa=Pfa, Pmiss=Pmiss)
  minDCF = compute_minDCF(Pfa, Pmiss)[0]
  # PRINT LOG
  if print_log:
    print(ctext("--------", 'red'), ctext(title, 'cyan'))
    print("Log loss :", format_score(ll))
    print("Accuracy :", format_score(acc))
    print("C_norm   :", format_score(np.mean(cnorm)))
    print("EER      :", format_score(eer))
    print("minDCF   :", format_score(minDCF))
    print(print_confusion(arr=cm, labels=labels))
  # ====== save report to PDF files if necessary ====== #
  if path is not None:
    if y_pred_proba is None:
      y_pred_proba = y_pred_llr
    from matplotlib import pyplot as plt
    plt.figure(figsize=(nb_classes, nb_classes + 1))
    plot_confusion_matrix(cm, labels)
    # Cavg
    plt.figure(figsize=(nb_classes + 1, 3))
    plot_Cnorm(cnorm=cnorm_arr, labels=labels, Ptrue=[0.1, 0.5],
               fontsize=14)
    # binary classification
    if nb_classes == 2 and \
    (y_pred_proba.ndim == 1 or (y_pred_proba.ndim == 2 and
                                y_pred_proba.shape[1] == 1)):
      fpr, tpr = roc_curve(y_true=y_true, y_score=y_pred_proba.ravel())
      # det curve
      plt.figure()
      plot_detection_curve(Pfa, Pmiss, curve='det',
                           xlims=xlims, ylims=ylims, linewidth=1.2)
      # roc curve
      plt.figure()
      plot_detection_curve(fpr, tpr, curve='roc')
    # multiclasses
    else:
      y_true = one_hot(y_true, nb_classes=nb_classes)
      fpr_micro, tpr_micro, _ = roc_curve(y_true=y_true.ravel(),
                                          y_score=y_pred_proba.ravel())
      Pfa_micro, Pmiss_micro = Pfa, Pmiss
      fpr, tpr = [], []
      Pfa, Pmiss = [], []
      for i, yi in enumerate(y_true.T):
        curve = roc_curve(y_true=yi, y_score=y_pred_proba[:, i])
        fpr.append(curve[0])
        tpr.append(curve[1])
        curve = det_curve(y_true=yi, y_score=y_pred_llr[:, i])
        Pfa.append(curve[0])
        Pmiss.append(curve[1])
      plt.figure()
      plot_detection_curve(fpr_micro, tpr_micro, curve='roc',
                           linewidth=1.2, title="ROC Micro")
      plt.figure()
      plot_detection_curve(fpr, tpr, curve='roc',
                           labels=labels, linewidth=1.0,
                           title="ROC for each classes")
      plt.figure()
      plot_detection_curve(Pfa_micro, Pmiss_micro, curve='det',
                           xlims=xlims, ylims=ylims, linewidth=1.2,
                           title="DET Micro")
      plt.figure()
      plot_detection_curve(Pfa, Pmiss, curve='det',
                           xlims=xlims, ylims=ylims,
                           labels=labels, linewidth=1.0,
                           title="DET for each classes")
    plot_save(path)