Example #1
def plot(train, score, title, applying_pca=False):
    if applying_pca:
        pca = PCA(n_components=NUM_DIM)
        pca.fit(train)
        train = pca.transform(train)
        score = pca.transform(score)
    plot_figure(nrow=6, ncol=12)
    plot_scatter(x=train[:, 0],
                 y=train[:, 1],
                 z=None if NUM_DIM < 3 or train.shape[1] < 3 else train[:, 2],
                 size=POINT_SIZE,
                 color=y_train_color,
                 marker=y_train_marker,
                 fontsize=12,
                 legend=legends,
                 title='[train]' + str(title),
                 ax=(1, 2, 1))
    plot_scatter(x=score[:, 0],
                 y=score[:, 1],
                 z=None if NUM_DIM < 3 or score.shape[1] < 3 else score[:, 2],
                 size=POINT_SIZE,
                 color=y_score_color,
                 marker=y_score_marker,
                 fontsize=12,
                 legend=legends,
                 title='[score]' + str(title),
                 ax=(1, 2, 2))
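The snippet above relies on several module-level names (NUM_DIM, POINT_SIZE, legends, and the color/marker variables) plus plot_figure/plot_scatter from odin.visual. A minimal, hypothetical setup that would make it runnable (illustrative values only, the actual project defines these elsewhere):

import numpy as np
from sklearn.decomposition import PCA
from odin.visual import plot_figure, plot_scatter  # assumed import path

# illustrative values for the globals used by plot()
NUM_DIM = 3
POINT_SIZE = 8
legends = None
y_train_color, y_train_marker = 'b', '.'
y_score_color, y_score_marker = 'r', '.'

train = np.random.randn(500, 20).astype('float32')
score = np.random.randn(200, 20).astype('float32')
plot(train, score, title='demo', applying_pca=True)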
Example #2
def plot_evaluate_regressor(y_pred, y_true, labels, title):
    from matplotlib import pyplot as plt
    num_classes = len(labels)
    nbins = 120
    fontsize = 8
    y_pred = np.round(y_pred).astype('int32')
    y_true = y_true.astype('int32')
    # ====== basic scores ====== #
    r2 = r2_score(y_true, y_pred)
    var = explained_variance_score(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)

    # ====== helper ====== #
    def plot_hist(hist, ax, name):
        count, bins = plot_histogram(hist,
                                     bins=nbins,
                                     ax=ax,
                                     title=name,
                                     fontsize=fontsize)
        plt.xlim((np.min(bins), np.max(bins)))
        plt.xticks(np.linspace(start=np.min(bins),
                               stop=np.max(bins),
                               num=8,
                               dtype='int32'),
                   fontsize=6)
        plt.yticks(np.linspace(start=np.min(count),
                               stop=np.max(count),
                               num=8,
                               dtype='int32'),
                   fontsize=6)

    # ====== raw count prediction ====== #
    plot_figure(nrow=4, ncol=num_classes * 2)
    for i in range(num_classes):
        name = labels[i]
        r2_ = r2_score(y_true=y_true[:, i], y_pred=y_pred[:, i])
        pred = _clipping_quartile(y_pred[:, i])
        true = _clipping_quartile(y_true[:, i])
        plot_hist(hist=true,
                  ax=(2, num_classes, i + 1),
                  name='[%s] R2: %.6f' % (name, r2_))
        if i == 0:
            plt.ylabel('True')
        plot_hist(hist=pred,
                  ax=(2, num_classes, num_classes + i + 1),
                  name=None)
        if i == 0:
            plt.ylabel('Pred')
    # set the title
    plt.suptitle('[%s]  R2: %.6f   ExpVAR: %.6f   MSE: %.6f   MAE: %.6f' %
                 (str(title), r2, var, mse, mae),
                 fontsize=fontsize + 1)
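Example #2 calls a `_clipping_quartile` helper that is not shown on this page. A minimal sketch of such a helper, assuming a standard IQR-based clipping rule (the project's actual implementation may differ):

import numpy as np

def _clipping_quartile(x, alpha=1.5):
    # clip outliers to [Q1 - alpha*IQR, Q3 + alpha*IQR] before plotting histograms
    x = np.asarray(x)
    q1, q3 = np.percentile(x, [25, 75])
    iqr = q3 - q1
    return np.clip(x, q1 - alpha * iqr, q3 + alpha * iqr)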
Example #3
def evaluate_latent(fn, feeder, title):
    y_true = []
    Z = []
    for outputs in Progbar(feeder.set_batch(batch_mode='file'),
                           name=title,
                           print_report=True,
                           print_summary=False,
                           count_func=lambda x: x[-1].shape[0]):
        name = str(outputs[0])
        idx = int(outputs[1])
        data = outputs[2:]
        assert idx == 0
        y_true.append(name)
        Z.append(fn(*data))
    Z = np.concatenate(Z, axis=0)
    # ====== visualize spectrogram ====== #
    if Z.ndim >= 3:
        sample = np.random.choice(range(len(Z)), size=3, replace=False)
        spec = Z[sample.astype('int32')]
        y = [y_true[int(i)] for i in sample]
        plot_figure(nrow=6, ncol=6)
        for i, (s, tit) in enumerate(zip(spec, y)):
            s = s.reshape(len(s), -1)
            plot_spectrogram(s.T, ax=(1, 3, i + 1), title=tit)
    # ====== visualize each point ====== #
    # flatten to 2D
    Z = np.reshape(Z, newshape=(len(Z), -1))
    # tsne if necessary
    if Z.shape[-1] > 3:
        Z = fast_tsne(Z,
                      n_components=3,
                      n_jobs=8,
                      random_state=K.get_rng().randint(0, 10e8))
    # color and marker
    Z_color = [digit_color_map[i.split('_')[-1]] for i in y_true]
    Z_marker = [gender_marker_map[i.split('_')[1]] for i in y_true]
    plot_figure(nrow=6, ncol=20)
    for i, azim in enumerate((15, 60, 120)):
        plot_scatter(x=Z[:, 0],
                     y=Z[:, 1],
                     z=Z[:, 2],
                     ax=(1, 3, i + 1),
                     size=4,
                     color=Z_color,
                     marker=Z_marker,
                     azim=azim,
                     legend=legends if i == 1 else None,
                     legend_ncol=11,
                     fontsize=10,
                     title=title)
    plot_save(os.path.join(FIG_PATH, '%s.pdf' % title))
Example #4
def visualize_latent_space(X_org, X_latent, name, labels, title):
    """
  X_org : [n_samples, n_timesteps, n_features]
  X_latent : [n_samples, n_timesteps, n_latents]
  """
    assert X_org.shape[0] == X_latent.shape[0] == len(name) == len(labels)
    assert not np.any(np.isnan(X_org))
    assert not np.any(np.isnan(X_latent))
    X_org = X_org.astype('float32')
    X_latent = X_latent.astype('float32')
    # ====== evaluation of the latent space ====== #
    n_channels = 1 if X_latent.ndim == 3 else int(np.prod(X_latent.shape[3:]))
    n_samples = X_org.shape[0]
    # 1 for original, 1 for mean channel, then the rest
    n_row = 1 + 1 + n_channels
    n_col = 3
    V.plot_figure(nrow=n_row + 1, ncol=16)
    # only select 3 random samples
    for i, idx in enumerate(
            sampling_iter(it=range(n_samples), k=n_col, seed=1234)):
        x = X_org[idx]
        # latent tensor can be 3D or 4D
        z = X_latent[idx]
        if z.ndim > 3:
            z = np.reshape(z, newshape=(z.shape[0], z.shape[1], -1))
        elif z.ndim == 2:
            z = np.reshape(z, newshape=(z.shape[0], z.shape[1], 1))
        elif z.ndim == 3:
            pass
        else:
            raise ValueError("No support for z value: %s" % str(z.shape))
        # plot original acoustic
        ax = V.plot_spectrogram(x.T, ax=(n_row, n_col, i + 1), title='Org')
        if i == 0:
            ax.set_title("[%s]'%s-%s'" %
                         (str(title), str(name[idx]), str(labels[idx])),
                         fontsize=8)
        else:
            ax.set_title("'%s-%s'" % (str(name[idx]), str(labels[idx])),
                         fontsize=8)
        # plot the mean
        V.plot_spectrogram(np.mean(z, axis=-1).T,
                           ax=(n_row, n_col, i + 4),
                           title='Zmean')
        # plot at most the first 8 channels
        if n_channels > 1:
            for j in range(min(8, n_channels)):
                V.plot_spectrogram(z[:, :, j].T,
                                   ax=(n_row, n_col, j * 3 + 7 + i),
                                   title='Z%d' % j)
Example #5
File: utils.py Project: imito/odin
def visualize_latent_space(X_org, X_latent, name, labels, title):
  """
  X_org : [n_samples, n_timesteps, n_features]
  X_latent : [n_samples, n_timesteps, n_latents]
  """
  assert X_org.shape[0] == X_latent.shape[0] == len(name) == len(labels)
  assert not np.any(np.isnan(X_org))
  assert not np.any(np.isnan(X_latent))
  X_org = X_org.astype('float32')
  X_latent = X_latent.astype('float32')
  # ====== evaluation of the latent space ====== #
  n_channels = 1 if X_latent.ndim == 3 else int(np.prod(X_latent.shape[3:]))
  n_samples = X_org.shape[0]
  # 1 for original, 1 for mean channel, then the rest
  n_row = 1 + 1 + n_channels
  n_col = 3
  V.plot_figure(nrow=n_row + 1, ncol=16)
  # only select 3 random samples
  for i, idx in enumerate(
      sampling_iter(it=range(n_samples), k=n_col, seed=5218)):
    x = X_org[idx]
    # latent tensor can be 3D or 4D
    z = X_latent[idx]
    if z.ndim > 3:
      z = np.reshape(z, newshape=(z.shape[0], z.shape[1], -1))
    elif z.ndim == 2:
      z = np.reshape(z, newshape=(z.shape[0], z.shape[1], 1))
    elif z.ndim == 3:
      pass
    else:
      raise ValueError("No support for z value: %s" % str(z.shape))
    # plot original acoustic
    ax = V.plot_spectrogram(x.T, ax=(n_row, n_col, i + 1), title='Org')
    if i == 0:
      ax.set_title("[%s]'%s-%s'" % (str(title), str(name[idx]), str(labels[idx])),
                   fontsize=8)
    else:
      ax.set_title("'%s-%s'" % (str(name[idx]), str(labels[idx])),
                   fontsize=8)
    # plot the mean
    V.plot_spectrogram(np.mean(z, axis=-1).T,
                       ax=(n_row, n_col, i + 4), title='Zmean')
    # plot at most the first 8 channels
    if n_channels > 1:
      for j in range(min(8, n_channels)):
        V.plot_spectrogram(z[:, :, j].T,
                           ax=(n_row, n_col, j * 3 + 7 + i),
                           title='Z%d' % j)
Example #6
 def plot_histogram(self, histogram_bins=120, original_factors=True):
   r"""
   original_factors : optional original factors before being discretized by
     `Criticizer`
   """
   self.assert_sampled()
   from matplotlib import pyplot as plt
   ## prepare the data
   Z = np.concatenate(self.representations_mean, axis=0)
   F = np.concatenate(
       self.original_factors if original_factors else self.factors, axis=0)
   X = [i for i in F.T] + [i for i in Z.T]
   labels = self.factors_name.tolist() + self.codes_name.tolist()
   # create the figure
   ncol = int(np.ceil(np.sqrt(len(X)))) + 1
   nrow = int(np.ceil(len(X) / ncol))
   fig = vs.plot_figure(nrow=18, ncol=25, dpi=80)
   for i, (x, lab) in enumerate(zip(X, labels)):
     vs.plot_histogram(x,
                       ax=(nrow, ncol, i + 1),
                       bins=int(histogram_bins),
                       title=lab,
                       alpha=0.8,
                       color='blue',
                       fontsize=16)
   plt.tight_layout()
   self.add_figure(
       "histogram_%s" % ("original" if original_factors else "discretized"),
       fig)
   return self
Example #7
 def plot_histogram(
     self,
     histogram_bins: int = 120,
     original_factors: bool = True,
     return_figure: bool = False,
 ):
     Z = self.dist_to_tensor(self.latents).numpy()
     F = self.factors_original if original_factors else self.factors
     X = [i for i in F.T] + [i for i in Z.T]
     labels = self.factor_names + self.latent_names
     # create the figure
     ncol = int(np.ceil(np.sqrt(len(X)))) + 1
     nrow = int(np.ceil(len(X) / ncol))
     fig = vs.plot_figure(nrow=12, ncol=20, dpi=100)
     for i, (x, lab) in enumerate(zip(X, labels)):
         vs.plot_histogram(x,
                           ax=(nrow, ncol, i + 1),
                           bins=int(histogram_bins),
                           title=lab,
                           alpha=0.8,
                           color='blue',
                           fontsize=16)
     fig.tight_layout()
     if return_figure:
         return fig
     return self.add_figure(
         f"histogram_{'original' if original_factors else 'discretized'}",
         fig)
Example #8
    def plot_distribution(self, X, labels=None):
        X, labels, n_classes = self._check_input(X, labels)

        X_bin = self.predict(X)
        X_prob = self.predict_proba(X)

        normalize_to_01 = lambda x: x / np.sum(x)
        dist_raw = normalize_to_01(np.sum(X, axis=0))
        dist_bin = normalize_to_01(np.sum(X_bin, axis=0))
        dist_prob = normalize_to_01(np.sum(X_prob, axis=0))
        x = np.arange(n_classes)

        fig = plot_figure(nrow=3, ncol=int(n_classes * 1.2))
        ax = plt.gca()

        colors = sns.color_palette(n_colors=3)
        bar1 = ax.bar(x, dist_raw, width=0.2, color=colors[0], alpha=0.8)
        bar2 = ax.bar(x + 0.2, dist_bin, width=0.2, color=colors[1], alpha=0.8)
        bar3 = ax.bar(x + 0.4,
                      dist_prob,
                      width=0.2,
                      color=colors[2],
                      alpha=0.8)

        ax.set_xticks(x + 0.2)
        ax.set_xticklabels(labels, rotation=-10)
        ax.legend([bar1, bar2, bar3],
                  ['Original', 'Binarized', 'Probabilized'])

        ax.grid(True, axis='y')
        ax.set_axisbelow(True)
        self.add_figure('distribution', fig)
        return self
Example #9
 def plot_uncertainty_statistics(self, factors=None):
   r"""
   factors : list of integers or strings. The indices or names of the factors
     included in the analysis.
   """
   factors = self._check_factors(factors)
   zmean = np.concatenate(self.representations_mean, axis=0)
   zstd = np.sqrt(np.concatenate(self.representations_variance, axis=0))
   labels = self.factors_name[factors]
   factors = np.concatenate(self.original_factors, axis=0)[:, factors]
   X = np.arange(zmean.shape[0])
   # create the figure
   nrow = self.n_representations
   ncol = len(labels)
   fig = vs.plot_figure(nrow=nrow * 4, ncol=ncol * 4, dpi=80)
   plot = 1
   for row, (code, mean,
             std) in enumerate(zip(self.codes_name, zmean.T, zstd.T)):
     # prepare the code
     ids = np.argsort(mean)
     mean, std = mean[ids], std[ids]
     # show the factors
     for col, (name, y) in enumerate(zip(labels, factors.T)):
       axes = []
       # the variance
       ax = vs.plot_subplot(nrow, ncol, plot)
       ax.plot(mean, color='g', linestyle='--')
       ax.fill_between(X, mean - 2 * std, mean + 2 * std, alpha=0.2, color='b')
       if col == 0:
         ax.set_ylabel(code)
       if row == 0:
         ax.set_title(name)
       axes.append(ax)
       # factor
       y = y[ids]
       ax = ax.twinx()
       vs.plot_scatter(x=X,
                       y=y,
                       val=y,
                       size=12,
                       color='bwr',
                       alpha=0.5,
                       grid=False)
       axes.append(ax)
       # update plot index
       for ax in axes:
         ax.tick_params(axis='both',
                        which='both',
                        top=False,
                        bottom=False,
                        left=False,
                        right=False,
                        labeltop=False,
                        labelleft=False,
                        labelright=False,
                        labelbottom=False)
       plot += 1
   fig.tight_layout()
   self.add_figure("uncertainty_stats", fig)
   return self
Example #10
def plot(train, score, title, applying_pca=False):
  if applying_pca:
    pca = PCA(n_components=NUM_DIM)
    pca.fit(train)
    train = pca.transform(train)
    score = pca.transform(score)
  plot_figure(nrow=6, ncol=12)
  plot_scatter(x=train[:, 0], y=train[:, 1],
               z=None if NUM_DIM < 3 or train.shape[1] < 3 else train[:, 2],
               size=POINT_SIZE, color=y_train_color, marker=y_train_marker,
               fontsize=12, legend=legends,
               title='[train]' + str(title),
               ax=(1, 2, 1))
  plot_scatter(x=score[:, 0], y=score[:, 1],
               z=None if NUM_DIM < 3 or score.shape[1] < 3 else score[:, 2],
               size=POINT_SIZE, color=y_score_color, marker=y_score_marker,
               fontsize=12, legend=legends,
               title='[score]' + str(title),
               ax=(1, 2, 2))
Example #11
def to_image(X, grids):
    if X.shape[-1] == 1:  # grayscale image
        X = np.squeeze(X, axis=-1)
    else:  # color image
        X = np.transpose(X, (0, 3, 1, 2))
    nrows, ncols = grids
    fig = vs.plot_figure(nrows=nrows, ncols=ncols, dpi=100)
    vs.plot_images(X, grids=grids)
    image = vs.plot_to_image(fig)
    return image
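A hypothetical call to `to_image`, assuming a channels-last batch of grayscale images and that `vs` refers to odin.visual as in the snippet above:

import numpy as np

# 64 random grayscale 28x28 images arranged on an 8x8 grid (illustrative only)
batch = np.random.rand(64, 28, 28, 1).astype('float32')
image = to_image(batch, grids=(8, 8))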
Example #12
def save_images(pX_Z, name, step, path):
    X = pX_Z.mean().numpy()
    if X.shape[-1] == 1:
        X = np.squeeze(X, axis=-1)
    else:
        X = np.transpose(X, (0, 3, 1, 2))
    fig = vs.plot_figure(nrow=16, ncol=16, dpi=60)
    vs.plot_images(X, fig=fig, title="[%s]#Iter: %d" % (name, step))
    fig.savefig(path, dpi=60)
    plt.close(fig)
    del X
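A hypothetical usage of `save_images`, assuming `pX_Z` is a TensorFlow Probability distribution over a batch of images (e.g. a VAE decoder output) and that `vs` (odin.visual) and `plt` (matplotlib.pyplot) are imported in the surrounding module:

import tensorflow as tf
import tensorflow_probability as tfp

# hypothetical decoder output: a Bernoulli distribution over 16 grayscale 28x28 images
pX_Z = tfp.distributions.Bernoulli(probs=tf.random.uniform((16, 28, 28, 1)))
save_images(pX_Z, name='vae', step=100, path='/tmp/samples.png')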
Example #13
File: analyze_data.py Project: imito/odin
def plot_mean_std(_map, title):
  V.plot_figure(nrow=6, ncol=20)
  for i, dsname in enumerate(all_dataset):
    mean, _ = _map[dsname]
    plt.plot(mean,
             linewidth=1.,
             linestyle=linestyles[i % len(linestyles)],
             label=dsname)
  plt.legend()
  plt.suptitle("[%s]Mean" % title)

  V.plot_figure(nrow=6, ncol=20)
  for i, dsname in enumerate(all_dataset):
    _, std = _map[dsname]
    plt.plot(std,
             linewidth=1.,
             linestyle=linestyles[i % len(linestyles)],
             label=dsname)
  plt.legend()
  plt.suptitle("[%s]StandardDeviation" % title)
Example #14
def plot_mean_std(_map, title):
    V.plot_figure(nrow=6, ncol=20)
    for i, dsname in enumerate(all_dataset):
        mean, _ = _map[dsname]
        plt.plot(mean,
                 linewidth=1.,
                 linestyle=linestyles[i % len(linestyles)],
                 label=dsname)
    plt.legend()
    plt.suptitle("[%s]Mean" % title)

    V.plot_figure(nrow=6, ncol=20)
    for i, dsname in enumerate(all_dataset):
        _, std = _map[dsname]
        plt.plot(std,
                 linewidth=1.,
                 linestyle=linestyles[i % len(linestyles)],
                 label=dsname)
    plt.legend()
    plt.suptitle("[%s]StandardDeviation" % title)
Example #15
def evaluate_feeder(feeder, title):
    y_true_digit = []
    y_true_gender = []
    y_pred = []
    for outputs in Progbar(feeder.set_batch(batch_mode='file'),
                           name=title,
                           print_report=True,
                           print_summary=False,
                           count_func=lambda x: x[-1].shape[0]):
        name = str(outputs[0])
        idx = int(outputs[1])
        data = outputs[2:]
        assert idx == 0
        y_true_digit.append(f_digits(name))
        y_true_gender.append(f_genders(name))
        y_pred.append(f_pred(*data))
    # ====== post processing ====== #
    y_true_digit = np.array(y_true_digit, dtype='int32')
    y_true_gender = np.array(y_true_gender, dtype='int32')
    y_pred_proba = np.concatenate(y_pred, axis=0)
    y_pred_all = np.argmax(y_pred_proba, axis=-1).astype('int32')
    # ====== plotting for each gender ====== #
    plot_figure(nrow=6, ncol=25)
    for gen in range(len(genders)):
        y_true, y_pred = [], []
        for i, g in enumerate(y_true_gender):
            if g == gen:
                y_true.append(y_true_digit[i])
                y_pred.append(y_pred_all[i])
        if len(y_true) == 0:
            continue
        cm = confusion_matrix(y_true, y_pred, labels=range(len(digits)))
        plot_confusion_matrix(cm,
                              labels=digits,
                              fontsize=8,
                              ax=(1, 4, gen + 1),
                              title='[%s]%s' % (genders[gen], title))
    plot_save(os.path.join(FIG_PATH, '%s.pdf' % title))
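Example #16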
 def plot_histogram(self,
                    omic=OMIC.proteomic,
                    bins=80,
                    log_norm=True,
                    var_names=None,
                    max_plots=100,
                    fig=None,
                    return_figure=False):
     r""" Plot histogram for each variable of given OMIC type """
     omic = OMIC.parse(omic)
     x = self.numpy(omic)
     bins = min(int(bins), x.shape[0] // 2)
     max_plots = int(max_plots)
     ### prepare the data
     var_ids = self.get_var_indices(omic)
     if var_names is None:
         var_names = var_ids.keys()
     var_names = np.array([i for i in var_names if i in var_ids])
     assert len(var_names) > 0, \
       f"No matching variables found for {omic.name}"
     # randomly select variables
     if len(var_names) > max_plots:
         rand = np.random.RandomState(seed=1)
         ids = rand.permutation(len(var_names))[:max_plots]
         var_names = var_names[ids]
     ids = [var_ids[i] for i in var_names]
     x = x[:, ids]
     ### the figures
     ncol = 8
     nrow = int(np.ceil(x.shape[1] / ncol))
     if fig is None:
         fig = vs.plot_figure(nrow=nrow * 2, ncol=ncol * 3, dpi=80)
     # plot
     for idx, (y, name) in enumerate(zip(x.T, var_names)):
         sparsity = sparsity_percentage(y, batch_size=2048)
         y = y[y != 0.]
         if log_norm:
             y = np.log1p(y)
         vs.plot_histogram(x=y,
                           bins=bins,
                           alpha=0.8,
                           ax=(nrow, ncol, idx + 1),
                           title=f"{name}\n({sparsity*100:.1f}% zeros)")
         fig.gca().tick_params(axis='y', labelleft=False)
     ### adjust and return
     fig.suptitle(f"{omic.name}")
     fig.tight_layout(rect=[0.0, 0.03, 1.0, 0.97])
     if return_figure:
         return fig
     return self.add_figure(f"histogram_{omic.name}", fig)
Example #17
    def boxplot(self, X, labels=None):
        X, labels, n_classes = self._check_input(X, labels)

        nrow = n_classes
        ncol = 3
        fig = plot_figure(nrow=3 * nrow, ncol=int(1.5 * ncol))

        for i, (x, name) in enumerate(zip(X.T, labels)):
            start = i * ncol

            ax = plt.subplot(nrow, ncol, start + 1)
            ax.boxplot(x,
                       whis=1.5,
                       labels=['Original'],
                       flierprops={
                           'marker': '.',
                           'markersize': 8
                       },
                       showmeans=True,
                       meanline=True)
            ax.set_ylabel(name)

            ax = plt.subplot(nrow, ncol, start + 2)
            ax.boxplot(x[x > 0],
                       whis=1.5,
                       labels=['NonZeros'],
                       flierprops={
                           'marker': '.',
                           'markersize': 8
                       },
                       showmeans=True,
                       meanline=True)

            ax = plt.subplot(nrow, ncol, start + 3)
            ax.boxplot(self.normalize(x, test_mode=False),
                       whis=1.5,
                       labels=['Normalized'],
                       flierprops={
                           'marker': '.',
                           'markersize': 8
                       },
                       showmeans=True,
                       meanline=True)

        plt.tight_layout()
        self.add_figure('boxplot', fig)
        return self
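Example #18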
 def plot_percentile_histogram(self,
                               omic=OMIC.transcriptomic,
                               n_hist=10,
                               title="",
                               outlier=0.001,
                               non_zeros=False,
                               fig=None):
     r""" Data is chopped into multiple percentile (`n_hist`) and the
 histogram is plotted for each percentile. """
     omic = OMIC.parse(omic)
     arr = self.numpy(omic)
     if non_zeros:
         arr = arr[arr != 0]
     n_percentiles = n_hist + 1
     n_col = 5
     n_row = int(np.ceil(n_hist / n_col))
     if fig is None:
         fig = vs.plot_figure(nrow=int(n_row * 1.5), ncol=20)
     self.assert_figure(fig)
     percentile = np.linspace(start=np.min(arr),
                              stop=np.max(arr),
                              num=n_percentiles)
     n_samples = len(arr)
     for i, (p_min, p_max) in enumerate(zip(percentile, percentile[1:])):
         min_mask = arr >= p_min
         max_mask = arr <= p_max
         mask = np.logical_and(min_mask, max_mask)
         a = arr[mask]
         _, bins = vs.plot_histogram(
             a,
             bins=120,
             ax=(n_row, n_col, i + 1),
             fontsize=8,
             color='red' if len(a) / n_samples < outlier else 'blue',
             title=f"{len(a)}(samples)  Range:[{p_min:.2g},{p_max:.2g}]")
         plt.gca().set_xticks(np.linspace(np.min(bins), np.max(bins),
                                          num=8))
     if len(title) > 0:
         plt.suptitle(title)
     plt.tight_layout(rect=[0.0, 0.02, 1.0, 0.98])
     self.add_figure(f'histogram{n_hist}_{omic.name}', fig)
     return self
Example #19
import os

import numpy as np
import tensorflow as tf
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

from odin import ml  # assumed location of fast_umap/fast_tsne/fast_pca used below
from odin import visual as vs

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

tf.random.set_seed(8)
np.random.seed(8)

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

X_umap = ml.fast_umap(X_train, X_test)
X_tsne = ml.fast_tsne(X_train, X_test)
X_pca = ml.fast_pca(X_train, X_test, n_components=2)

styles = dict(size=12, alpha=0.6, centroids=True)

vs.plot_figure(6, 12)
vs.plot_scatter(x=X_pca[0], color=y_train, ax=(1, 2, 1), **styles)
vs.plot_scatter(x=X_pca[1], color=y_test, ax=(1, 2, 2), **styles)

vs.plot_figure(6, 12)
vs.plot_scatter(x=X_tsne[0], color=y_train, ax=(1, 2, 1), **styles)
vs.plot_scatter(x=X_tsne[1], color=y_test, ax=(1, 2, 2), **styles)

vs.plot_figure(6, 12)
vs.plot_scatter(x=X_umap[0], color=y_train, ax=(1, 2, 1), **styles)
vs.plot_scatter(x=X_umap[1], color=y_test, ax=(1, 2, 2), **styles)

vs.plot_save()
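For reference, `plot_save` also accepts an explicit output path and dpi, as the other examples on this page do; a hypothetical variant of the final call (the path is made up):

vs.plot_save('/tmp/digits_embeddings.pdf', dpi=100)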
Example #20
File: utils.py Project: imito/odin
def prepare_dnn_data(recipe, feat, utt_length, seed=52181208):
  """
  Return
  ------
  train_feeder : Feeder for training
  valid_feeder : Feeder for validating
  test_ids : Test indices
  test_dat : Data array
  all_speakers : list of all speakers in the training set
  """
  # Load dataset
  frame_length = int(utt_length / FRAME_SHIFT)
  ds = F.Dataset(os.path.join(PATH_ACOUSTIC_FEAT, recipe),
                 read_only=True)
  X = ds[feat]
  train_indices = {name: ds['indices'][name]
                   for name in TRAIN_DATA.keys()}
  test_indices = {name: start_end
                  for name, start_end in ds['indices'].items()
                  if name not in TRAIN_DATA}
  train_indices, valid_indices = train_valid_test_split(
      x=list(train_indices.items()), train=0.9, inc_test=False, seed=seed)
  all_speakers = sorted(set(TRAIN_DATA.values()))
  n_speakers = max(all_speakers) + 1
  print("#Train files:", ctext(len(train_indices), 'cyan'))
  print("#Valid files:", ctext(len(valid_indices), 'cyan'))
  print("#Test files:", ctext(len(test_indices), 'cyan'))
  print("#Speakers:", ctext(n_speakers, 'cyan'))
  recipes = [
      F.recipes.Sequencing(frame_length=frame_length, step_length=frame_length,
                           end='pad', pad_value=0, pad_mode='post',
                           data_idx=0),
      F.recipes.Name2Label(lambda name:TRAIN_DATA[name], ref_idx=0),
      F.recipes.LabelOneHot(nb_classes=n_speakers, data_idx=1)
  ]
  train_feeder = F.Feeder(
      data_desc=F.IndexedData(data=X, indices=train_indices),
      batch_mode='batch', ncpu=7, buffer_size=12)
  valid_feeder = F.Feeder(
      data_desc=F.IndexedData(data=X, indices=valid_indices),
      batch_mode='batch', ncpu=2, buffer_size=4)
  train_feeder.set_recipes(recipes)
  valid_feeder.set_recipes(recipes)
  print(train_feeder)
  # ====== cache the test data ====== #
  cache_dat = os.path.join(PATH_EXP, 'test_%s_%d.dat' % (feat, int(utt_length)))
  cache_ids = os.path.join(PATH_EXP, 'test_%s_%d.ids' % (feat, int(utt_length)))
  # validate cache files
  if os.path.exists(cache_ids):
    with open(cache_ids, 'rb') as f:
      ids = pickle.load(f)
    if len(ids) != len(test_indices):
      os.remove(cache_ids)
      if os.path.exists(cache_dat):
        os.remove(cache_dat)
  elif os.path.exists(cache_dat):
    os.remove(cache_dat)
  # caching
  if not os.path.exists(cache_dat):
    dat = F.MmapData(cache_dat, dtype='float16',
                     shape=(0, frame_length, X.shape[1]))
    ids = {}
    prog = Progbar(target=len(test_indices))
    s = 0
    for name, (start, end) in test_indices.items():
      y = X[start:end]
      y = segment_axis(y, axis=0,
                       frame_length=frame_length, step_length=frame_length,
                       end='pad', pad_value=0, pad_mode='post')
      dat.append(y)
      # update indices
      ids[name] = (s, s + len(y))
      s += len(y)
      # update progress
      prog.add(1)
    dat.flush()
    dat.close()
    with open(cache_ids, 'wb') as f:
      pickle.dump(ids, f)
  # ====== re-load ====== #
  dat = F.MmapData(cache_dat, read_only=True)
  with open(cache_ids, 'rb') as f:
    ids = pickle.load(f)
  # ====== save some sample ====== #
  sample_path = os.path.join(PATH_EXP,
                             'test_%s_%d.pdf' % (feat, int(utt_length)))
  V.plot_figure(nrow=9, ncol=6)
  for i, (name, (start, end)) in enumerate(
      sampling_iter(it=sorted(ids.items(), key=lambda x: x[0]), k=12, seed=52181208)):
    x = dat[start:end][:].astype('float32')
    ax = V.plot_spectrogram(x[np.random.randint(0, len(x))].T,
                            ax=(12, 1, i + 1), title='')
    ax.set_title(name)
  V.plot_save(sample_path)
  return (train_feeder, valid_feeder,
          ids, dat, all_speakers)
Example #21
def prepare_dnn_data(save_dir,
                     feat_name=None,
                     utt_length=None,
                     seq_mode=None,
                     min_dur=None,
                     min_utt=None,
                     exclude=None,
                     train_proportion=None,
                     return_dataset=False):
    assert os.path.isdir(save_dir), \
        "Path to '%s' is not a directory" % save_dir
    if feat_name is None:
        feat_name = FEATURE_NAME
    if utt_length is None:
        utt_length = int(_args.utt)
    if seq_mode is None:
        seq_mode = str(_args.seq).strip().lower()
    if min_dur is None:
        min_dur = MINIMUM_UTT_DURATION
    if min_utt is None:
        min_utt = MINIMUM_UTT_PER_SPEAKERS
    if exclude is None:
        exclude = str(_args.exclude).strip()
    print("Minimum duration: %s(s)" % ctext(min_dur, 'cyan'))
    print("Minimum utt/spk : %s(utt)" % ctext(min_utt, 'cyan'))
    # ******************** prepare dataset ******************** #
    path = os.path.join(PATH_ACOUSTIC_FEATURES, FEATURE_RECIPE)
    assert os.path.exists(
        path), "Cannot find acoustic dataset at path: %s" % path
    ds = F.Dataset(path=path, read_only=True)
    rand = np.random.RandomState(seed=Config.SUPER_SEED)
    # ====== find the right feature ====== #
    assert feat_name in ds, "Cannot find feature with name: %s" % feat_name
    X = ds[feat_name]
    ids_name = 'indices_%s' % feat_name
    assert ids_name in ds, "Cannot find indices with name: %s" % ids_name
    # ====== basic path ====== #
    path_filtered_data = os.path.join(save_dir, 'filtered_files.pkl')
    path_train_files = os.path.join(save_dir, 'train_files.pkl')
    path_speaker_info = os.path.join(save_dir, 'speaker_info.pkl')
    # ******************** cannot find cached data ******************** #
    if any(not os.path.exists(p)
           for p in [path_filtered_data, path_train_files, path_speaker_info]):
        # ====== exclude some dataset ====== #
        if len(exclude) > 0:
            exclude_dataset = {i: 1 for i in exclude.split(',')}
            print("* Excluded dataset:", ctext(exclude_dataset, 'cyan'))
            indices = {
                name: (start, end)
                for name, (start, end) in ds[ids_name].items()
                if ds['dsname'][name] not in exclude_dataset
            }
            # special case exclude all the noise data
            if 'noise' in exclude_dataset:
                indices = {
                    name: (start, end)
                    for name, (start, end) in indices.items()
                    if '/' not in name
                }
        else:
            indices = {i: j for i, j in ds[ids_name].items()}
        # ====== down-sampling if necessary ====== #
        if _args.downsample > 1000:
            dataset2name = defaultdict(list)
            # ordering the indices so we sample the same set every time
            for name in sorted(indices.keys()):
                dataset2name[ds['dsname'][name]].append(name)
            n_total_files = len(indices)
            n_sample_files = int(_args.downsample)
            # get the percentage of each dataset
            dataset2per = {
                i: len(j) / n_total_files
                for i, j in dataset2name.items()
            }
            # sampling based on percentage
            _ = {}
            for dsname, flist in dataset2name.items():
                rand.shuffle(flist)
                n_dataset_files = int(dataset2per[dsname] * n_sample_files)
                _.update({i: indices[i] for i in flist[:n_dataset_files]})
            indices = _
        # ====== * filter out "bad" sample ====== #
        indices = filter_utterances(X=X,
                                    indices=indices,
                                    spkid=ds['spkid'],
                                    min_utt=min_utt,
                                    min_dur=min_dur,
                                    remove_min_length=True,
                                    remove_min_uttspk=True,
                                    n_speakers=None,
                                    ncpu=None,
                                    save_path=path_filtered_data)
        # ====== all training file name ====== #
        # modify here to train full dataset
        all_name = sorted(indices.keys())
        rand.shuffle(all_name)
        rand.shuffle(all_name)
        n_files = len(all_name)
        print("#Files:", ctext(n_files, 'cyan'))
        # ====== speaker mapping ====== #
        name2spk = {name: ds['spkid'][name] for name in all_name}
        all_speakers = sorted(set(name2spk.values()))
        spk2label = {spk: i for i, spk in enumerate(all_speakers)}
        name2label = {name: spk2label[spk] for name, spk in name2spk.items()}
        assert len(name2label) == len(all_name)
        print("#Speakers:", ctext(len(all_speakers), 'cyan'))
        # ====== stratify sampling based on speaker ====== #
        valid_name = []
        # create speakers' cluster
        label2name = defaultdict(list)
        for name, label in sorted(name2label.items(), key=lambda x: x[0]):
            label2name[label].append(name)
        # for each speaker with >= 3 utterance
        for label, name_list in sorted(label2name.items(), key=lambda x: x[0]):
            if len(name_list) < 3:
                continue
            n = max(1, int(0.05 * len(name_list)))  # 5% for validation
            valid_name += rand.choice(a=name_list, size=n,
                                      replace=False).tolist()
        # train list is the rest
        _ = set(valid_name)
        train_name = [i for i in all_name if i not in _]
        # ====== split training and validation ====== #
        train_indices = {name: indices[name] for name in train_name}
        valid_indices = {name: indices[name] for name in valid_name}
        # ====== save cached data ====== #
        with open(path_train_files, 'wb') as fout:
            pickle.dump({'train': train_indices, 'valid': valid_indices}, fout)
        with open(path_speaker_info, 'wb') as fout:
            pickle.dump(
                {
                    'all_speakers': all_speakers,
                    'name2label': name2label,
                    'spk2label': spk2label
                }, fout)
    # ******************** load cached data ******************** #
    else:
        with open(path_train_files, 'rb') as fin:
            obj = pickle.load(fin)
            train_indices = obj['train']
            valid_indices = obj['valid']
        with open(path_speaker_info, 'rb') as fin:
            obj = pickle.load(fin)
            all_speakers = obj['all_speakers']
            name2label = obj['name2label']
            spk2label = obj['spk2label']

    # ******************** print log ******************** #

    def summary_indices(ids):
        datasets = defaultdict(int)
        speakers = defaultdict(list)
        text = ''
        for name in sorted(ids.keys()):
            text += name + str(ids[name])
            dsname = ds['dsname'][name]
            datasets[dsname] += 1
            speakers[dsname].append(ds['spkid'][name])
        for dsname in sorted(datasets.keys()):
            print('  %-18s: %s(utt) %s(spk)' %
                  (dsname, ctext('%6d' % datasets[dsname], 'cyan'),
                   ctext(len(set(speakers[dsname])), 'cyan')))
        print('  MD5 checksum:', ctext(crypto.md5_checksum(text), 'lightcyan'))

    # ====== training files ====== #
    print(
        "#Train files:", ctext('%-8d' % len(train_indices), 'cyan'), "#spk:",
        ctext(len(set(name2label[name] for name in train_indices.keys())),
              'cyan'), "#noise:",
        ctext(len([name for name in train_indices.keys() if '/' in name]),
              'cyan'))
    summary_indices(ids=train_indices)
    # ====== valid files ====== #
    print(
        "#Valid files:", ctext('%-8d' % len(valid_indices), 'cyan'), "#spk:",
        ctext(len(set(name2label[name] for name in valid_indices.keys())),
              'cyan'), "#noise:",
        ctext(len([name for name in valid_indices.keys() if '/' in name]),
              'cyan'))
    summary_indices(ids=valid_indices)
    # ******************** create the recipe ******************** #
    assert all(name in name2label for name in train_indices.keys())
    assert all(name in name2label for name in valid_indices.keys())
    recipes = prepare_dnn_feeder_recipe(name2label=name2label,
                                        n_speakers=len(all_speakers),
                                        utt_length=utt_length,
                                        seq_mode=seq_mode)
    # ====== downsample training set for analyzing if required ====== #
    if train_proportion is not None:
        assert 0 < train_proportion < 1
        n_training = len(train_indices)
        train_indices = list(train_indices.items())
        rand.shuffle(train_indices)
        rand.shuffle(train_indices)
        train_indices = dict(train_indices[:int(n_training *
                                                train_proportion)])
    # ====== create feeder ====== #
    train_feeder = F.Feeder(data_desc=F.IndexedData(data=X,
                                                    indices=train_indices),
                            batch_mode='batch',
                            ncpu=NCPU,
                            buffer_size=256)

    valid_feeder = F.Feeder(data_desc=F.IndexedData(data=X,
                                                    indices=valid_indices),
                            batch_mode='batch',
                            ncpu=max(2, NCPU // 4),
                            buffer_size=64)

    train_feeder.set_recipes(recipes)
    valid_feeder.set_recipes(recipes)
    print(train_feeder)
    print(valid_feeder)
    # ====== debugging ====== #
    if IS_DEBUGGING:
        import matplotlib
        matplotlib.use('Agg')
        prog = Progbar(target=len(valid_feeder),
                       print_summary=True,
                       name="Iterating validation set")
        samples = []
        n_visual = 250
        for name, idx, X, y in valid_feeder.set_batch(batch_size=100000,
                                                      batch_mode='file',
                                                      seed=None,
                                                      shuffle_level=0):
            assert idx == 0, "Utterances longer than %.2f(sec)" % (
                100000 * Config.STEP_LENGTH)
            prog['X'] = X.shape
            prog['y'] = y.shape
            prog.add(X.shape[0])
            # random sampling
            if rand.rand(1) < 0.5 and len(samples) < n_visual:
                for i in rand.randint(0, X.shape[0], size=4, dtype='int32'):
                    samples.append((name, X[i], np.argmax(y[i], axis=-1)))
        # plot the spectrogram
        n_visual = len(samples)
        V.plot_figure(nrow=n_visual, ncol=8)
        for i, (name, X, y) in enumerate(samples):
            is_noise = '/' in name
            assert name2label[
                name] == y, "Speaker label mismatch for file: %s" % name
            name = name.split('/')[0]
            dsname = ds['dsname'][name]
            spkid = ds['spkid'][name]
            y = np.argmax(y, axis=-1)
            ax = V.plot_spectrogram(X.T,
                                    ax=(n_visual, 1, i + 1),
                                    title='#%d' % (i + 1))
            ax.set_title(
                '[%s][%s]%s  %s' %
                ('noise' if is_noise else 'clean', dsname, name, spkid),
                fontsize=6)
        # doesn't need to be high resolution
        V.plot_save('/tmp/tmp.pdf', dpi=12)
        exit()
    # ====== return ====== #
    if bool(return_dataset):
        return train_feeder, valid_feeder, all_speakers, ds
    return train_feeder, valid_feeder, all_speakers
Example #22
    def plot_comparison_f1(self,
                           test=True,
                           model_id=lambda m: m.name,
                           fig_width=12):
        assert callable(model_id), "model_id must be callable"
        start_time = time.time()
        score_type = 'classifier'
        data_type = 'test' if bool(test) else 'train'

        n_system = len(self)
        fn_score = _get_score_fn(score_type)
        scores_name = None
        scores = []
        for pos in self.posteriors:
            s = fn_score(pos)
            s = s[1] if bool(test) else s[0]
            if score_type == 'classifier':
                del s['F1weight']
                del s['F1micro']
                del s['F1macro']
            n_labels = len(s)
            s = sorted(s.items(), key=lambda x: x[0])
            scores_name = [i[0].replace('F1_', '') for i in s]
            scores.append((model_id(pos.infer), [i[1] for i in s]))

        colors = sns.color_palette(n_colors=n_labels)
        fig, subplots = plt.subplots(nrows=1,
                                     ncols=n_system,
                                     sharey=True,
                                     squeeze=True,
                                     figsize=(int(fig_width), 2))

        for idx, (name, f1) in enumerate(scores):
            assert len(scores_name) == len(f1)
            f1_weight = np.mean(f1)
            ax = subplots[idx]
            ax.grid(True, axis='both', which='both', linewidth=0.5, alpha=0.6)
            for i, (f, c) in enumerate(zip(f1, colors)):
                ax.scatter(i, f, color=c, s=22, marker='o', alpha=0.8)
                ax.text(i - 0.2, f + 24, '%.1f' % f, fontsize=10, rotation=75)

            ax.plot(np.arange(n_labels), f1, linewidth=1.0, linestyle='--')
            ax.plot(np.arange(n_labels), [f1_weight for i in range(n_labels)],
                    linewidth=1.2,
                    linestyle=':',
                    color='black',
                    label=r"$\overline{F1}$:%.1f" % f1_weight)
            ax.legend(fontsize=14,
                      loc='lower left',
                      handletextpad=0.1,
                      frameon=False)

            ax.set_xticks(np.arange(n_labels))
            ax.set_xlabel(name, fontsize=12)

            ax.set_ylim(-8, 130)
            ax.set_yticks(np.linspace(0, 100, 5))

            ax.xaxis.set_ticklabels([])

            ax.tick_params(axis='x', length=0)
            ax.tick_params(axis='y', length=0, labelsize=8)

            plot_frame(ax,
                       right=False,
                       top=False,
                       left=True if idx == 0 else False)

        plt.tight_layout(w_pad=0)
        self.add_figure("compare_%s_%s" % (score_type, data_type), fig)

        fig = plot_figure(nrow=1, ncol=4)
        for name, c in zip(scores_name, colors):
            plt.plot(0, 0, 'o', label=name, color=c)
        plt.axis('off')
        plt.legend(ncol=int(np.ceil(len(scores_name) / 2)),
                   scatterpoints=1,
                   scatteryoffsets=[0.375, 0.5, 0.3125],
                   loc='upper center',
                   bbox_to_anchor=(0.5, -0.01),
                   handletextpad=0.1,
                   labelspacing=0.,
                   columnspacing=0.4)
        self.add_figure("compare_%s_%s_legend" % (score_type, data_type), fig)

        return self._log('plot_comparison_series[%s][%s] %s(s)' %
                         (score_type, data_type,
                          ctext(time.time() - start_time, 'lightyellow')))
Example #23
  print(ctext("[Epoch %d]" % epoch, 'yellow'), '%.2f(s)' % (timeit.default_timer() - start_time))
  print("[Training set] Loss: %.4f" % np.mean(train_losses))
  # ====== validation set ====== #
  code_samples, lo = K.eval([Z, loss], feed_dict={X: X_valid})
  print("[Valid set]    Loss: %.4f" % lo)
  # ====== record the history ====== #
  record_train_loss.append(np.mean(train_losses))
  record_valid_loss.append(lo)
  # ====== plotting ====== #
  if args.dim > 2:
    code_samples = ml.fast_pca(code_samples, n_components=2,
                               random_state=K.get_rng().randint(10e8))
  img_samples = f_samples()
  img_mean = f_X(X_valid[:25])

  V.plot_figure(nrow=3, ncol=12)

  ax = plt.subplot(1, 3, 1)
  ax.scatter(code_samples[:, 0], code_samples[:, 1], s=2, c=y_valid, alpha=0.3)
  ax.set_title('Epoch %d' % epoch)
  ax.set_aspect('equal', 'box')
  ax.axis('off')

  ax = plt.subplot(1, 3, 2)
  ax.imshow(V.tile_raster_images(img_samples), cmap=plt.cm.Greys_r)
  ax.axis('off')

  ax = plt.subplot(1, 3, 3)
  ax.imshow(V.tile_raster_images(img_mean), cmap=plt.cm.Greys_r)
  ax.axis('off')
  # ====== check exit condition ====== #
Example #24
File: utils.py Project: trungnt13/odin-ai
def prepare_dnn_data(recipe, feat, utt_length, seed=87654321):
    """
  Return
  ------
  train_feeder : Feeder for training
  valid_feeder : Feeder for validating
  test_ids : Test indices
  test_dat : Data array
  all_speakers : list of all speaker in training set
  """
    # Load dataset
    frame_length = int(utt_length / FRAME_SHIFT)
    ds = F.Dataset(os.path.join(PATH_ACOUSTIC_FEAT, recipe), read_only=True)
    X = ds[feat]
    train_indices = {name: ds['indices'][name] for name in TRAIN_DATA.keys()}
    test_indices = {
        name: start_end
        for name, start_end in ds['indices'].items() if name not in TRAIN_DATA
    }
    train_indices, valid_indices = train_valid_test_split(x=list(
        train_indices.items()),
                                                          train=0.9,
                                                          inc_test=False,
                                                          seed=seed)
    all_speakers = sorted(set(TRAIN_DATA.values()))
    n_speakers = max(all_speakers) + 1
    print("#Train files:", ctext(len(train_indices), 'cyan'))
    print("#Valid files:", ctext(len(valid_indices), 'cyan'))
    print("#Test files:", ctext(len(test_indices), 'cyan'))
    print("#Speakers:", ctext(n_speakers, 'cyan'))
    recipes = [
        F.recipes.Sequencing(frame_length=frame_length,
                             step_length=frame_length,
                             end='pad',
                             pad_value=0,
                             pad_mode='post',
                             data_idx=0),
        F.recipes.Name2Label(lambda name: TRAIN_DATA[name], ref_idx=0),
        F.recipes.LabelOneHot(nb_classes=n_speakers, data_idx=1)
    ]
    train_feeder = F.Feeder(data_desc=F.IndexedData(data=X,
                                                    indices=train_indices),
                            batch_mode='batch',
                            ncpu=7,
                            buffer_size=12)
    valid_feeder = F.Feeder(data_desc=F.IndexedData(data=X,
                                                    indices=valid_indices),
                            batch_mode='batch',
                            ncpu=2,
                            buffer_size=4)
    train_feeder.set_recipes(recipes)
    valid_feeder.set_recipes(recipes)
    print(train_feeder)
    # ====== cache the test data ====== #
    cache_dat = os.path.join(PATH_EXP,
                             'test_%s_%d.dat' % (feat, int(utt_length)))
    cache_ids = os.path.join(PATH_EXP,
                             'test_%s_%d.ids' % (feat, int(utt_length)))
    # validate cache files
    if os.path.exists(cache_ids):
        with open(cache_ids, 'rb') as f:
            ids = pickle.load(f)
        if len(ids) != len(test_indices):
            os.remove(cache_ids)
            if os.path.exists(cache_dat):
                os.remove(cache_dat)
    elif os.path.exists(cache_dat):
        os.remove(cache_dat)
    # caching
    if not os.path.exists(cache_dat):
        dat = F.MmapData(cache_dat,
                         dtype='float16',
                         shape=(0, frame_length, X.shape[1]))
        ids = {}
        prog = Progbar(target=len(test_indices))
        s = 0
        for name, (start, end) in test_indices.items():
            y = X[start:end]
            y = segment_axis(y,
                             axis=0,
                             frame_length=frame_length,
                             step_length=frame_length,
                             end='pad',
                             pad_value=0,
                             pad_mode='post')
            dat.append(y)
            # update indices
            ids[name] = (s, s + len(y))
            s += len(y)
            # update progress
            prog.add(1)
        dat.flush()
        dat.close()
        with open(cache_ids, 'wb') as f:
            pickle.dump(ids, f)
    # ====== re-load ====== #
    dat = F.MmapData(cache_dat, read_only=True)
    with open(cache_ids, 'rb') as f:
        ids = pickle.load(f)
    # ====== save some sample ====== #
    sample_path = os.path.join(PATH_EXP,
                               'test_%s_%d.pdf' % (feat, int(utt_length)))
    V.plot_figure(nrow=9, ncol=6)
    for i, (name, (start, end)) in enumerate(
            sampling_iter(it=sorted(ids.items(), key=lambda x: x[0]),
                          k=12,
                          seed=87654321)):
        x = dat[start:end][:].astype('float32')
        ax = V.plot_spectrogram(x[np.random.randint(0, len(x))].T,
                                ax=(12, 1, i + 1),
                                title='')
        ax.set_title(name)
    V.plot_save(sample_path)
    return (train_feeder, valid_feeder, ids, dat, all_speakers)
Example #25
  def plot_histogram_heatmap(self,
                             factors=None,
                             factor_bins=15,
                             histogram_bins=80,
                             n_codes_per_factor=6,
                             corr_method='average',
                             original_factors=True):
    r""" The histogram bars are colored by the value of factors

    Arguments:
      factors : which factors will be used
      factor_bins : factor is discretized into bins, then a LogisticRegression
        model will predict the bin (with color) given the code as input.
      original_factors : optional original factors before being discretized by
        `Criticizer`
    """
    self.assert_sampled()
    from matplotlib import pyplot as plt
    import seaborn as sns
    sns.set()
    if n_codes_per_factor is None:
      n_codes_per_factor = self.n_codes
    else:
      n_codes_per_factor = int(n_codes_per_factor)
    styles = dict(fontsize=12,
                  val_bins=int(factor_bins),
                  color='bwr',
                  bins=int(histogram_bins),
                  alpha=0.8)
    ## correlation
    train_corr, test_corr = self.cal_correlation_matrix(mean=True,
                                                        method=corr_method,
                                                        decode=False)
    corr = (train_corr + test_corr) / 2
    ## prepare the data
    factors = self._check_factors(factors)
    Z = np.concatenate(self.representations_mean, axis=0)
    F = np.concatenate(
        self.original_factors if original_factors else self.factors,
        axis=0)[:, factors]
    # annotations
    factors_name = self.factors_name[factors]
    codes_name = self.codes_name
    # create the figure
    nrow = F.shape[1]
    ncol = int(1 + n_codes_per_factor)
    fig = vs.plot_figure(nrow=nrow * 3, ncol=ncol * 3, dpi=80)
    plot_count = 1
    for fidx, (f, fname) in enumerate(zip(F.T, factors_name)):
      c = corr[:, fidx]
      vs.plot_histogram(f,
                        val=f,
                        ax=(nrow, ncol, plot_count),
                        cbar=True,
                        cbar_horizontal=False,
                        title=fname,
                        **styles)
      plot_count += 1
      # all codes are visualized
      if n_codes_per_factor == self.n_codes:
        all_codes = range(self.n_codes)
      # lower to higher correlation
      else:
        zids = np.argsort(c)
        bottom = zids[:n_codes_per_factor // 2]
        top = zids[-(n_codes_per_factor - n_codes_per_factor // 2):]
        all_codes = (top.tolist()[::-1] + bottom.tolist()[::-1])
      for i in all_codes:
        z = Z[:, i]
        zname = codes_name[i]
        vs.plot_histogram(z,
                          val=f,
                          ax=(nrow, ncol, plot_count),
                          title='[%.2g]%s' % (c[i], zname),
                          **styles)
        plot_count += 1
    fig.tight_layout()
    self.add_figure(
        "histogram_%s" % ("original" if original_factors else "discretized"),
        fig)
    return self
Example #26
def evaluate(vae: VariationalAutoencoder,
             ds: ImageDataset,
             expdir: str,
             title: str,
             batch_size: int = 64,
             take_count: int = -1,
             n_images: int = 36,
             seed: int = 1):
    n_rows = int(np.sqrt(n_images))
    is_semi = vae.is_semi_supervised()
    is_hierarchical = vae.is_hierarchical()
    ds_kw = dict(batch_size=batch_size, label_percent=1.0, shuffle=False)
    ## prepare
    rand = np.random.RandomState(seed=seed)
    if not os.path.exists(expdir):
        os.makedirs(expdir)
    ## data for training semi-supervised
    train = ds.create_dataset('train', **ds_kw)
    (llkx_train, llky_train, x_org_train, x_rec_train, y_true_train,
     y_pred_train, z_train, pz_train) = _call(vae,
                                              ds=train,
                                              rand=rand,
                                              take_count=take_count,
                                              n_images=n_images,
                                              verbose=True)
    ## data for testing
    test = ds.create_dataset('test', **ds_kw)
    (llkx_test, llky_test, x_org_test, x_rec_test, y_true_test, y_pred_test,
     z_test, pz_test) = _call(vae,
                              ds=test,
                              rand=rand,
                              take_count=take_count,
                              n_images=n_images,
                              verbose=True)
    # === 0. plotting latent-factor pairs
    for idx, z in enumerate(z_test):
        z = z.mean()
        f = y_true_test
        corr_mat = Correlation.Spearman(z, f)  # [n_latents, n_factors]
        plot_latents_pairs(z, f, corr_mat, ds.labels)
        vs.plot_save(f'{expdir}/latent{idx}_factor.pdf', dpi=100, verbose=True)
    # === 0. latent traverse plot
    x_travs = x_org_test
    if x_travs.ndim == 3:  # grayscale image
        x_travs = np.expand_dims(x_travs, -1)
    else:  # color image
        x_travs = np.transpose(x_travs, (0, 2, 3, 1))
    x_travs = x_travs[rand.permutation(x_travs.shape[0])]
    n_visual_samples = 5
    n_traverse_points = 21
    n_top_latents = 10
    plt.figure(figsize=(8, 3 * n_visual_samples))
    for i in range(n_visual_samples):
        images = vae.sample_traverse(x_travs[i:i + 1],
                                     min_val=-np.min(z_test[0].mean()),
                                     max_val=np.max(z_test[0].mean()),
                                     n_best_latents=n_top_latents,
                                     n_traverse_points=n_traverse_points,
                                     mode='linear')
        images = as_tuple(images)[0]
        images = _prepare_images(images.mean().numpy(), normalize=True)
        vs.plot_images(images,
                       grids=(n_top_latents, n_traverse_points),
                       ax=(n_visual_samples, 1, i + 1))
        if i == 0:
            plt.title('Latents traverse')
    plt.tight_layout()
    vs.plot_save(f'{expdir}/latents_traverse.pdf', dpi=180, verbose=True)
    # === 0. prior sampling plot
    images = as_tuple(vae.sample_observation(n=n_images, seed=seed))[0]
    images = _prepare_images(images.mean().numpy(), normalize=True)
    plt.figure(figsize=(5, 5))
    vs.plot_images(images, grids=(n_rows, n_rows), title='Sampled')
    # === 1. reconstruction plot
    plt.figure(figsize=(15, 15))
    vs.plot_images(x_org_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 1),
                   title='[Train]Original')
    vs.plot_images(x_rec_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 2),
                   title='[Train]Reconstructed')
    vs.plot_images(x_org_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 3),
                   title='[Test]Original')
    vs.plot_images(x_rec_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 4),
                   title='[Test]Reconstructed')
    plt.tight_layout()
    ## prepare the labels
    label_type = ds.label_type
    if label_type == 'categorical':
        labels_name = ds.labels
        true = np.argmax(y_true_test, axis=-1)
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = np.argmax(y_pred_test.mean().numpy(), axis=-1)
            labels_pred = np.array([labels_name[i] for i in pred])
    elif label_type == 'factor':  # dsprites, shapes3d
        labels_name = ['cube', 'cylinder', 'sphere', 'round'] \
          if 'shapes3d' in ds.name else ['square', 'ellipse', 'heart']
        true = y_true_test[:, 2].astype('int32')
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = get_ymean(y_pred_test)[:, 2].astype('int32')
            labels_pred = np.array([labels_name[i] for i in pred])
    else:  # CelebA
        raise NotImplementedError
    ## confusion matrix
    if is_semi:
        plt.figure(figsize=(8, 8))
        acc = accuracy_score(y_true=true, y_pred=pred)
        vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                                 labels=labels_name,
                                 cbar=True,
                                 fontsize=10,
                                 title=f'{title} Acc:{acc:.2f}')
    ## save arrays for later inspections
    with open(f'{expdir}/arrays', 'wb') as f:
        pickle.dump(
            dict(z_train=z_train,
                 y_pred_train=y_pred_train,
                 y_true_train=y_true_train,
                 z_test=z_test,
                 y_pred_test=y_pred_test,
                 y_true_test=y_true_test,
                 labels=labels_name,
                 ds=ds.name,
                 label_type=label_type), f)
    print(f'Exported arrays to "{expdir}/arrays"')
    ## semi-supervised
    z_mean_train = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_train], -1)
    z_mean_test = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_test], -1)
    # === 2. scatter points latents plot
    n_points = 5000
    ids = rand.permutation(len(labels_true))[:n_points]
    Y_true = labels_true[ids]
    Y_pred = labels_pred[ids]
    # tsne plot
    n_latents = 0 if len(z_train) == 1 else len(z_train)
    names = ['all'] + [f'latents{i}' for i in range(n_latents)]
    arrays = [z_mean_test[ids]] + \
        [z_test[i].mean().numpy()[ids] for i in range(n_latents)]
    for name, X in zip(names, arrays):
        print(f'Plot scatter points for {name}')
        X = X.reshape(X.shape[0], -1)  # flatten to 2D
        X = Pipeline([('zscore', StandardScaler()),
                      ('pca', PCA(min(X.shape[1], 512),
                                  random_state=seed))]).fit_transform(X)
        tsne = DimReduce.TSNE(X, n_components=2, framework='sklearn')
        kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.8)
        plt.figure(figsize=(12, 6))
        vs.plot_scatter(color=Y_true,
                        title=f'[True]{title}-{name}',
                        ax=(1, 2, 1),
                        **kw)
        vs.plot_scatter(color=Y_pred,
                        title=f'[Pred]{title}-{name}',
                        ax=(1, 2, 2),
                        **kw)
    ## save all plot
    vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)

    # === 3. show the latents statistics
    n_latents = len(z_train)
    colors = sns.color_palette(n_colors=len(labels_true))
    styles = dict(grid=False,
                  ticks_off=False,
                  alpha=0.6,
                  xlabel='mean',
                  ylabel='stddev')

    # scatter between latents and labels (assume categorical distribution)
    def _show_latents_labels(Z, Y, title):
        plt.figure(figsize=(5 * n_latents, 5), dpi=150)
        for idx, z in enumerate(Z):
            if len(z.batch_shape) == 0:
                mean = np.repeat(np.expand_dims(z.mean(), 0), Y.shape[0], 0)
                stddev = z.sample(Y.shape[0]) - mean
            else:
                mean = flatten(z.mean())
                stddev = flatten(z.stddev())
            y = np.argmax(Y, axis=-1)
            data = [[], [], []]
            for y_i, c in zip(np.unique(y), colors):
                mask = (y == y_i)
                data[0].append(np.mean(mean[mask], 0))
                data[1].append(np.mean(stddev[mask], 0))
                data[2].append([labels_true[y_i]] * mean.shape[1])
            vs.plot_scatter(
                x=np.concatenate(data[0], 0),
                y=np.concatenate(data[1], 0),
                color=np.concatenate(data[2], 0),
                ax=(1, n_latents, idx + 1),
                size=15 if mean.shape[1] < 128 else 8,
                title=f'[Test-{title}]#{idx} - {mean.shape[1]} (units)',
                **styles)
        plt.tight_layout()

    # simple scatter mean-stddev each latents
    def _show_latents(Z, title):
        plt.figure(figsize=(3.5 * n_latents, 3.5), dpi=150)
        for idx, z in enumerate(Z):
            mean = flatten(z.mean())
            stddev = flatten(z.stddev())
            if mean.ndim == 2:
                mean = np.mean(mean, 0)
                stddev = np.mean(stddev, 0)
            vs.plot_scatter(
                x=mean,
                y=stddev,
                ax=(1, n_latents, idx + 1),
                size=15 if len(mean) < 128 else 8,
                title=f'[Test-{title}]#{idx} - {len(mean)} (units)',
                **styles)

    _show_latents_labels(z_test, y_true_test, 'post')
    _show_latents_labels(pz_test, y_true_test, 'prior')
    _show_latents(z_test, 'post')
    _show_latents(pz_test, 'prior')

    # KL statistics
    vs.plot_figure()
    for idx, (qz, pz) in enumerate(zip(z_test, pz_test)):
        kl = []
        qz = Normal(loc=qz.mean(), scale=qz.stddev(), name=f'posterior{idx}')
        pz = Normal(loc=pz.mean(), scale=pz.stddev(), name=f'prior{idx}')
        for s, e in minibatch(batch_size=8, n=100):
            z = qz.sample(e - s)
            # don't do this in GPU, it explodes!
            kl.append((qz.log_prob(z) - pz.log_prob(z)).numpy())
        kl = np.concatenate(kl, 0)  # (mcmc, batch, event)
        # per sample
        kl_samples = np.sum(kl, as_tuple(list(range(2, kl.ndim))))
        kl_samples = logsumexp(kl_samples, 0)
        plt.subplot(n_latents, 2, idx * 2 + 1)
        sns.histplot(kl_samples, bins=50)
        plt.title(f'Z#{idx} KL per sample (nats)')
        # per latent
        kl_latents = np.mean(flatten(logsumexp(kl, 0)), 0)
        plt.subplot(n_latents, 2, idx * 2 + 2)
        plt.plot(np.sort(kl_latents))
        plt.title(f'Z#{idx} KL per dim (nats)')
    plt.tight_layout()

    vs.plot_save(f'{expdir}/latents.pdf', dpi=180, verbose=True)
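The "KL statistics" block above estimates KL(q||p) by Monte Carlo on the CPU. Below is a minimal NumPy/SciPy sketch of the same idea for diagonal Gaussians; the function name and the toy parameters are illustrative and not part of the library.

import numpy as np
from scipy.stats import norm

def mc_kl_diag_gaussian(q_mean, q_std, p_mean, p_std, n_samples=100, seed=1):
    # KL(q || p) ~= E_{z ~ q}[log q(z) - log p(z)], averaged over MC samples
    rng = np.random.RandomState(seed)
    z = rng.normal(q_mean, q_std, size=(n_samples,) + np.shape(q_mean))
    log_q = norm.logpdf(z, loc=q_mean, scale=q_std).sum(-1)
    log_p = norm.logpdf(z, loc=p_mean, scale=p_std).sum(-1)
    return np.mean(log_q - log_p, axis=0)

# toy posterior vs. standard-normal prior, 8 latent units
q_mu, q_sd = np.zeros(8), 0.5 * np.ones(8)
p_mu, p_sd = np.zeros(8), np.ones(8)
print(mc_kl_diag_gaussian(q_mu, q_sd, p_mu, p_sd))  # close to the analytic KL (~2.5 nats)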
Example #27
# ===========================================================================
# Comparison
# ===========================================================================
import seaborn as sns
from odin import visual as V
from matplotlib import pyplot as plt
from sisua.analysis.latent_benchmarks import (plot_latents_binary,
                                              clustering_scores,
                                              streamline_classifier)
from sisua.analysis.imputation_benchmarks import (imputation_score,
                                                  imputation_mean_score,
                                                  imputation_std_score,
                                                  plot_imputation)

# ====== training process ====== #
V.plot_figure(nrow=4, ncol=10)
plt.subplot(1, 2, 1)
plt.plot(scvi_loss[0], label='train')
plt.plot(scvi_loss[1], label='valid')
plt.legend()
plt.title('scVI')

plt.subplot(1, 2, 2)
plt.plot(sisua_loss[0], label='train')
plt.plot(sisua_loss[1], label='valid')
plt.legend()
plt.title('SISUA')
plt.tight_layout()

# ====== Latent space ====== #
V.plot_figure(nrow=8, ncol=18)
Example #28
with catch_warnings_ignore(RuntimeWarning), catch_warnings_ignore(
        FutureWarning):
    data_map = {}
    stats_map = {}
    spk_map = {}
    for dsname, text, data, stats, spk_stats in mpi.MPI(
            jobs=all_dataset, func=dataset_statistics, ncpu=None, batch=1):
        data_map[dsname] = data
        stats_map[dsname] = stats
        spk_map[dsname] = spk_stats
        print(text)

    for dsname in all_dataset:
        print("Plotting ...", ctext(dsname, 'cyan'))
        data = data_map[dsname]
        V.plot_figure(nrow=2, ncol=20)
        ax = plt.subplot(1, n_col, 1)
        plot_histogram(data[0], ax, title="Duration")

        ax = plt.subplot(1, n_col, 2)
        plot_histogram(data[1]['sum_per_spk'], ax, title="Dur/Spk")

        ax = plt.subplot(1, n_col, 3)
        plot_histogram(data[1]['nutt_per_spk'], ax, title="#Utt/Spk")

        plt.suptitle(dsname, fontsize=8)

    plot_mean_std(_map=stats_map, title='Data')
    plot_mean_std(_map=spk_map, title='Speaker')

V.plot_save(figure_path, dpi=32)
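The snippet above fans dataset_statistics out over the datasets with odin's mpi.MPI helper. For readers without that helper, a rough standard-library sketch of the same map-and-collect pattern is shown below; the stub function and dataset names are made up for illustration.

from multiprocessing import Pool

def dataset_statistics_stub(dsname):
    # hypothetical stand-in for dataset_statistics: returns the same 5-tuple shape
    return dsname, 'stats for %s' % dsname, [1.0, 2.0], {}, {}

if __name__ == '__main__':
    all_dataset_stub = ['voxceleb1', 'sre04', 'swb']
    data_map, stats_map, spk_map = {}, {}, {}
    with Pool(processes=3) as pool:
        for dsname, text, data, stats, spk_stats in pool.imap_unordered(
                dataset_statistics_stub, all_dataset_stub):
            data_map[dsname] = data
            stats_map[dsname] = stats
            spk_map[dsname] = spk_stats
            print(text)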
Example #29
    def plot_disentanglement(
        self,
        factor_indices: Optional[Union[int, str, List[Union[int,
                                                            str]]]] = None,
        n_bins_factors: int = 15,
        n_bins_codes: int = 80,
        corr_type: Union[Literal['spearman', 'pearson', 'lasso', 'average',
                                 'mi'], ndarray] = 'average',
        original_factors: bool = True,
        show_all_codes: bool = False,
        sort_pairs: bool = True,
        title: str = '',
        return_figure: bool = False,
        seed: int = 1,
    ):
        r""" To illustrate the disentanglement of the codes, the codes' histogram
    bars are colored by the value of factors.

    Arguments:
      factor_names : list of String or Integer.
        Name or index of which factors will be used for visualization.
      factor_bins : factor is discretized into bins, then a LogisticRegression
        model will predict the bin (with color) given the code as input.
      corr_type : {'spearman', 'pearson', 'lasso', 'average', 'mi', None, matrix}
        Type of correlation, with special case 'mi' for mutual information.
          - If None, no sorting by correlation provided.
          - If an array, the array must have shape `[n_codes, n_factors]`
      show_all_codes : a Boolean.
        if False, only show most correlated codes-factors, otherwise,
        all codes are shown for each factor.
        This option only in effect when `corr_type` is not `None`.
      original_factors : optional original factors before discretized by
        `Criticizer`
    """
        ### prepare styled plot
        styles = dict(fontsize=12,
                      cbar_horizontal=False,
                      bins_color=int(n_bins_factors),
                      bins=int(n_bins_codes),
                      color='bwr',
                      alpha=0.8)
        # get all relevant factors
        if factor_indices is None:
            factor_indices = list(range(self.n_factors))
        factor_indices = [
            int(i) if isinstance(i, Number) else self.factor_names.index(i)
            for i in as_tuple(factor_indices)
        ]
        ### correlation
        if isinstance(corr_type, string_types):
            if corr_type == 'mi':
                corr = self.mutualinfo_matrix(
                    convert_to_tensor=self.dist_to_tensor, seed=seed)
                score_type = 'mutual-info'
            else:
                corr = self.correlation_matrix(
                    convert_to_tensor=self.dist_to_tensor,
                    method=corr_type,
                    seed=seed)
                score_type = corr_type
            # [n_factors, n_codes]
            corr = corr.T[factor_indices]
        ### directly give the correlation matrix
        elif isinstance(corr_type, ndarray):
            corr = corr_type
            if self.n_latents != self.n_factors and corr.shape[
                    0] == self.n_latents:
                corr = corr.T
            assert corr.shape == (self.n_factors, self.n_latents), \
              (f"Correlation matrix expect shape (n_factors={self.n_factors}, "
               f"n_codes={self.n_codes}) but given shape: {corr.shape}")
            score_type = 'score'
            corr = corr[factor_indices]
        ### exception
        else:
            raise ValueError(
                "corr_type must be a string or a correlation matrix (ndarray), "
                f"but given: {type(corr_type)}")
        ### sorting the latents
        if sort_pairs:
            latent_indices = diagonal_linear_assignment(np.abs(corr),
                                                        nan_policy=0)
        else:
            latent_indices = np.arange(self.n_latents, dtype=np.int32)
        if not show_all_codes:
            latent_indices = latent_indices[:len(factor_indices)]
        corr = corr[:, latent_indices]
        ### prepare the data
        # factors
        F = (self.factors_original
             if original_factors else self.factors)[:, factor_indices]
        factor_names = np.asarray(self.factor_names)[factor_indices]
        # codes
        Z = self.dist_to_tensor(self.latents).numpy()[:, latent_indices]
        latent_names = np.asarray(self.latent_names)[latent_indices]
        ### create the figure
        nrow = F.shape[1]
        ncol = Z.shape[1] + 1
        fig = vs.plot_figure(nrow=nrow * 3, ncol=ncol * 2.8, dpi=100)
        count = 1
        for fidx, (f, fname) in enumerate(zip(F.T, factor_names)):
            # the first plot show how the factor clustered
            ax, _, _ = vs.plot_histogram(x=f,
                                         color_val=f,
                                         ax=(nrow, ncol, count),
                                         cbar=False,
                                         title=f"{fname}",
                                         **styles)
            ax.tick_params(axis='y', labelleft=False)
            count += 1
            # the rest of the row show how the codes align with the factor
            for zidx, (score, z,
                       zname) in enumerate(zip(corr[fidx], Z.T, latent_names)):
                text = "*" if fidx == zidx else ""
                ax, _, _ = vs.plot_histogram(
                    x=z,
                    color_val=f,
                    ax=(nrow, ncol, count),
                    cbar=False,
                    title=f"{text}{fname}-{zname} (${score:.2f}$)",
                    bold_title=True if fidx == zidx else False,
                    **styles)
                ax.tick_params(axis='y', labelleft=False)
                count += 1
        ### fine tune the plot
        fig.suptitle(f"[{score_type}]{title}", fontsize=12)
        fig.tight_layout(rect=[0.0, 0.03, 1.0, 0.97])
        if return_figure:
            return fig
        return self.add_figure(
            f"disentanglement_{'original' if original_factors else 'discretized'}",
            fig)
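When sort_pairs is enabled above, diagonal_linear_assignment reorders the latent codes so that the most correlated code sits on the diagonal of each factor's row. A simplified, self-contained greedy version of that pairing step is sketched below (assuming at least as many codes as factors); it only approximates the assignment used by the library.

import numpy as np

def greedy_factor_code_pairs(corr):
    # corr: [n_factors, n_codes]; pair each factor with its most correlated,
    # not-yet-taken latent code (greedy, row by row)
    corr = np.abs(np.asarray(corr))
    taken, pairs = set(), []
    for f in range(corr.shape[0]):
        order = np.argsort(-corr[f])
        code = next(c for c in order if c not in taken)
        taken.add(code)
        pairs.append((f, int(code)))
    return pairs

print(greedy_factor_code_pairs([[0.9, 0.1, 0.2],
                                [0.8, 0.7, 0.1]]))  # [(0, 0), (1, 1)]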
Example #30
File: analyze_data.py  Project: imito/odin
with catch_warnings_ignore(RuntimeWarning), catch_warnings_ignore(FutureWarning):
  data_map = {}
  stats_map = {}
  spk_map = {}
  for dsname, text, data, stats, spk_stats in mpi.MPI(jobs=all_dataset, func=dataset_statistics,
                            ncpu=None, batch=1):
    data_map[dsname] = data
    stats_map[dsname] = stats
    spk_map[dsname] = spk_stats
    print(text)

  for dsname in all_dataset:
    print("Plotting ...", ctext(dsname, 'cyan'))
    data = data_map[dsname]
    V.plot_figure(nrow=2, ncol=20)
    ax = plt.subplot(1, n_col, 1)
    plot_histogram(data[0], ax, title="Duration")

    ax = plt.subplot(1, n_col, 2)
    plot_histogram(data[1]['sum_per_spk'], ax, title="Dur/Spk")

    ax = plt.subplot(1, n_col, 3)
    plot_histogram(data[1]['nutt_per_spk'], ax, title="#Utt/Spk")

    plt.suptitle(dsname, fontsize=8)

  plot_mean_std(_map=stats_map, title='Data')
  plot_mean_std(_map=spk_map, title='Speaker')

V.plot_save(figure_path, dpi=32)
Example #31
def plot_latents_binary(Z,
                        y,
                        labels_name,
                        title=None,
                        elev=None,
                        azim=None,
                        algo='tsne',
                        ax=None,
                        show_legend=True,
                        size=12,
                        fontsize=12,
                        show_scores=True,
                        enable_argmax=True,
                        enable_separated=False):
    from matplotlib import pyplot as plt
    if title is None:
        title = ''
    title = '[%s]%s' % (algo, title)
    ax = to_axis(ax)
    # ====== Downsample if the data is huge ====== #
    Z, y = downsample_data(Z, y)
    # ====== checking inputs ====== #
    assert Z.ndim == 2, Z.shape
    assert Z.shape[0] == y.shape[0]
    num_classes = len(labels_name)
    # ====== preprocessing ====== #
    Z = dimension_reduction(Z, algo=algo)
    # ====== clustering metrics ====== #
    if show_scores:
        scores = clustering_scores(
            latent=Z,
            labels=np.argmax(y, axis=-1) if y.ndim == 2 else y,
            n_labels=num_classes)
        title += '\n'
        for k, v in sorted(scores.items(), key=lambda x: x[0]):
            title += '%s:%.2f ' % (k, v)
    # ====== plotting ====== #
    if enable_argmax:
        y_argmax = np.argmax(y, axis=-1) if y.ndim == 2 else y
        fast_scatter(x=Z,
                     y=y_argmax,
                     labels=labels_name,
                     ax=ax,
                     size=size,
                     title=title,
                     fontsize=fontsize,
                     enable_legend=bool(show_legend))
        ax.grid(False)
    # ====== plot each protein ====== #
    if enable_separated:
        colormap = 'Reds'  # bwr
        ncol = 5 if num_classes <= 20 else 9
        nrow = int(np.ceil(num_classes / ncol))
        fig = plot_figure(nrow=4 * nrow, ncol=20)
        for i, lab in enumerate(labels_name):
            val = K.log_norm(y[:, i], axis=0)
            plot_scatter(x=Z[:, 0],
                         y=Z[:, 1],
                         val=val / np.sum(val),
                         ax=(nrow, ncol, i + 1),
                         color=colormap,
                         size=size,
                         alpha=0.8,
                         fontsize=8,
                         grid=False,
                         title=lab)

        plt.grid(False)
        # big title
        plt.suptitle(title, fontsize=fontsize)
        # show the colorbar
        import matplotlib as mpl
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        cmap = mpl.cm.get_cmap(name=colormap)
        norm = mpl.colors.Normalize(vmin=0., vmax=1.)
        cb1 = mpl.colorbar.ColorbarBase(cbar_ax,
                                        cmap=cmap,
                                        norm=norm,
                                        orientation='vertical')
        cb1.set_label('Protein markers level')
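plot_latents_binary above delegates dimension reduction and scatter plotting to odin helpers. A minimal, self-contained equivalent using scikit-learn and matplotlib directly is sketched below with toy data; it only mirrors the argmax-colored scatter, not the clustering scores or the per-protein panels.

import numpy as np
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt

Z = np.random.randn(500, 32)            # toy latent codes
y = np.random.randint(0, 4, size=500)   # toy class labels (already argmax-ed)
Z2 = TSNE(n_components=2, random_state=1).fit_transform(Z)
plt.figure(figsize=(5, 5))
plt.scatter(Z2[:, 0], Z2[:, 1], c=y, s=12, cmap='tab10')
plt.title('[tsne] toy latents')
plt.grid(False)
plt.show()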
Example #32
  def epoch_end(self, task, epoch_results):
    output_name = self.output_name
    if len(output_name) == 0: # nothing to do
      return
    task_name = self._task_name

    if task.name in task_name:
      self._count -= 1
      # ====== processing results ====== #
      assert all(name in epoch_results for name in output_name), \
          "Given outputs with name: %s; but task: '%s' results only contain name: %s" % \
          (', '.join(self.output_name), str(task), ', '.join(tuple(epoch_results.keys())))

      for name in output_name:
        batch_results = epoch_results[name]
        if name not in self._epoch_results[task.name]:
          self._epoch_results[task.name][name] = []
        self._epoch_results[task.name][name].append(self.fn_reduce(batch_results))
      # ====== start plotting ====== #
      if self._count == 0:
        self._count = self._repeat_freq * len(task_name)
        from odin import visual as V
        n_col = len(task_name)
        n_row = len(output_name)
        if self.save_path is not None:
          from matplotlib import pyplot as plt
        save_figures = False
        override = True

        for o_idx, o_name in enumerate(output_name):
          results = {task_name: r[o_name]
                     for task_name, r in self._epoch_results.items()}
          if all(len(i) >= 2 for i in results.values()):
            # ====== print text plot ====== #
            if self.print_plot:
              text = []
              for t_name in task_name:
                values = results[t_name]
                if isinstance(values[0], Number):
                  t = V.print_bar(f=values, height=8,
                                  title=t_name + "/" + o_name)
                elif isinstance(values[0], np.ndarray) and values[0].ndim == 2 and \
                values[0].shape[0] == values[0].shape[1]:
                  t = V.print_confusion(arr=values[-1],
                                       side_bar=False, inc_stats=True,
                                       float_precision=2)
                else:
                  t = ''
                if len(t) > 0:
                  text.append(t)
              if len(text) > 1:
                print(V.merge_text_graph(*text, padding='  '))
              else:
                print(text[0])
            # ====== matplotlib plot and save pdf ====== #
            if self.save_path is not None:
              for t_idx, t_name in enumerate(task_name):
                values = results[t_name]
                # plotting series
                if isinstance(values[0], Number):
                  if not save_figures:
                    V.plot_figure(nrow=int(n_row * 1.8), ncol=20)
                  save_figures = True

                  max_epoch = np.argmax(values)
                  max_val = values[max_epoch]
                  min_epoch = np.argmin(values)
                  min_val = values[min_epoch]

                  plt.subplot(n_row, n_col, o_idx * len(task_name) + t_idx + 1)
                  plt.plot(values)
                  plt.scatter(max_epoch, max_val, s=180, alpha=0.4, c='r')
                  plt.scatter(min_epoch, min_val, s=180, alpha=0.4, c='g')

                  plt.xlim((0, len(values) - 1))
                  if not np.any(np.isinf(values)):
                    eps = 0.1 * (max_val - min_val)
                    plt.ylim((min_val - eps, max_val + eps))
                  plt.xticks(np.linspace(0, len(values) - 1, num=12,
                                         dtype='int32'))

                  title_text = '[%s]' % o_name if t_idx == 0 else ''
                  title_text += t_name
                  plt.title('%s' % title_text,
                            fontsize=8, fontweight='bold')
        # save figure to pdf or image files
        if save_figures:
          if override:
            save_path = self.save_path
          else:
            path, ext = os.path.splitext(self.save_path)
            save_path = path + ('.%d' % (task.curr_epoch + 1)) + ext
          V.plot_save(save_path, tight_plot=True,
                      clear_all=True, log=False, dpi=180)
          self.send_notification("Saved summary at: %s" % save_path)
    return None
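The per-epoch series plot in the callback above marks the best and worst epochs on each curve. A small self-contained version of that plot is sketched below (matplotlib only; the values are made up).

import numpy as np
from matplotlib import pyplot as plt

values = np.array([0.90, 0.70, 0.55, 0.60, 0.48, 0.50])   # toy validation loss per epoch
max_epoch, min_epoch = int(np.argmax(values)), int(np.argmin(values))
plt.plot(values)
plt.scatter(max_epoch, values[max_epoch], s=180, alpha=0.4, c='r')   # worst epoch
plt.scatter(min_epoch, values[min_epoch], s=180, alpha=0.4, c='g')   # best epoch
plt.xlim((0, len(values) - 1))
plt.title('valid/loss', fontsize=8, fontweight='bold')
plt.show()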
Example #33
def plot_epoch(task):
    if task is None:
        curr_epoch = 0
    else:
        curr_epoch = task.curr_epoch
        if not (curr_epoch < 5 or curr_epoch % 5 == 0):
            return
    rand = np.random.RandomState(seed=1234)

    X, y = X_test, y_test
    n_data = X.shape[0]
    Z = f_z(X)
    W, W_stdev_mcmc, W_stdev_analytic = f_w(X)

    X_pca, W_pca_1 = fast_pca(X,
                              W,
                              n_components=2,
                              random_state=rand.randint(10e8))
    W_pca_2 = fast_pca(W, n_components=2, random_state=rand.randint(10e8))
    X_count_sum = np.sum(X, axis=tuple(range(1, X.ndim)))
    W_count_sum = np.sum(W, axis=-1)

    n_visual_samples = 8
    nrow = 13 + n_visual_samples * 3
    V.plot_figure(nrow=int(nrow * 1.8), ncol=18)
    with V.plot_gridSpec(nrow=nrow + 3, ncol=6, hspace=0.8) as grid:
        # plot the latent space
        for i, (z, name) in enumerate(zip(Z, Z_names)):
            if z.shape[1] > 2:
                z = fast_pca(z,
                             n_components=2,
                             random_state=rand.randint(10e8))
            ax = V.subplot(grid[:3, (i * 2):(i * 2 + 2)])
            V.plot_scatter(x=z[:, 0],
                           y=z[:, 1],
                           color=y,
                           marker=y,
                           n_samples=4000,
                           ax=ax,
                           legend_enable=False,
                           legend_ncol=n_classes)
            ax.set_title(name, fontsize=12)
        # plot the reconstruction
        for i, (x, name) in enumerate(
                zip([X_pca, W_pca_1, W_pca_2], [
                    'Original data', 'Reconstruction',
                    'Reconstruction (separated PCA)'
                ])):
            ax = V.subplot(grid[3:6, (i * 2):(i * 2 + 2)])
            V.plot_scatter(x=x[:, 0],
                           y=x[:, 1],
                           color=y,
                           marker=y,
                           n_samples=4000,
                           ax=ax,
                           legend_enable=i == 1,
                           legend_ncol=n_classes,
                           title=name)
        # plot the reconstruction count sum
        for i, (x, count_sum, name) in enumerate(
                zip([X_pca, W_pca_1], [X_count_sum, W_count_sum], [
                    'Original data (Count-sum)', 'Reconstruction (Count-sum)'
                ])):
            ax = V.subplot(grid[6:10, (i * 3):(i * 3 + 3)])
            V.plot_scatter(x=x[:, 0],
                           y=x[:, 1],
                           val=count_sum,
                           n_samples=2000,
                           marker=y,
                           ax=ax,
                           size=8,
                           legend_enable=i == 0,
                           legend_ncol=n_classes,
                           title=name,
                           colorbar=True,
                           fontsize=10)
        # plot the count-sum series
        count_sum_observed = np.sum(X, axis=0).ravel()
        count_sum_expected = np.sum(W, axis=0)
        count_sum_stdev_explained = np.sum(W_stdev_mcmc, axis=0)
        count_sum_stdev_total = np.sum(W_stdev_analytic, axis=0)
        for i, kws in enumerate([
                dict(xscale='linear', yscale='linear', sort_by=None),
                dict(xscale='linear', yscale='linear', sort_by='expected'),
                dict(xscale='log', yscale='log', sort_by='expected')
        ]):
            ax = V.subplot(grid[10:10 + 3, (i * 2):(i * 2 + 2)])
            V.plot_series_statistics(count_sum_observed,
                                     count_sum_expected,
                                     explained_stdev=count_sum_stdev_explained,
                                     total_stdev=count_sum_stdev_total,
                                     fontsize=8,
                                     title="Count-sum" if i == 0 else None,
                                     **kws)
        # plot the mean and variances
        curr_grid_index = 13
        ids = rand.permutation(n_data)
        ids = ids[:n_visual_samples]
        for i in ids:
            observed, expected, stdev_explained, stdev_total = \
                X[i], W[i], W_stdev_mcmc[i], W_stdev_analytic[i]
            observed = observed.ravel()
            for j, kws in enumerate([
                    dict(xscale='linear', yscale='linear', sort_by=None),
                    dict(xscale='linear', yscale='linear', sort_by='expected'),
                    dict(xscale='log', yscale='log', sort_by='expected')
            ]):
                ax = V.subplot(grid[curr_grid_index:curr_grid_index + 3,
                                    (j * 2):(j * 2 + 2)])
                V.plot_series_statistics(observed,
                                         expected,
                                         explained_stdev=stdev_explained,
                                         total_stdev=stdev_total,
                                         fontsize=8,
                                         title="Test Sample #%d" %
                                         i if j == 0 else None,
                                         **kws)
            curr_grid_index += 3
    V.plot_save(os.path.join(FIGURE_PATH, 'latent_%d.png' % curr_epoch),
                dpi=200,
                log=True)
    exit()
Example #34
  print(ctext("[Epoch %d]" % epoch, 'yellow'), '%.2f(s)' % (timeit.default_timer() - start_time))
  print("[Training set] Loss: %.4f" % np.mean(train_losses))
  # ====== validation set ====== #
  code_samples, lo = K.eval([Z, loss], feed_dict={X: X_valid})
  print("[Valid set]    Loss: %.4f" % lo)
  # ====== record the history ====== #
  record_train_loss.append(np.mean(train_losses))
  record_valid_loss.append(lo)
  # ====== plotting ====== #
  if args.dim > 2:
    code_samples = ml.fast_pca(code_samples, n_components=2,
                               random_state=K.get_rng().randint(10e8))
  img_samples = f_samples()
  img_mean = f_X(X_valid[:25])

  V.plot_figure(nrow=3, ncol=12)

  ax = plt.subplot(1, 3, 1)
  ax.scatter(code_samples[:, 0], code_samples[:, 1], s=2, c=y_valid, alpha=0.3)
  ax.set_title('Epoch %d' % epoch)
  ax.set_aspect('equal', 'box')
  ax.axis('off')

  ax = plt.subplot(1, 3, 2)
  ax.imshow(V.tile_raster_images(img_samples), cmap=plt.cm.Greys_r)
  ax.axis('off')

  ax = plt.subplot(1, 3, 3)
  ax.imshow(V.tile_raster_images(img_mean), cmap=plt.cm.Greys_r)
  ax.axis('off')
  # ====== check exit condition ====== #
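V.tile_raster_images above packs a batch of generated or reconstructed digits into one image for plt.imshow. A NumPy-only sketch of that tiling (toy 28x28 images, square grid) is shown below; it is an illustration, not the odin implementation.

import numpy as np
from matplotlib import pyplot as plt

imgs = np.random.rand(25, 28, 28)      # toy batch of 25 grayscale images
n = int(np.sqrt(len(imgs)))            # 5 x 5 grid
grid = (imgs[:n * n]
        .reshape(n, n, 28, 28)         # [row, col, H, W]
        .transpose(0, 2, 1, 3)         # [row, H, col, W]
        .reshape(n * 28, n * 28))      # stitch into one big image
plt.imshow(grid, cmap=plt.cm.Greys_r)
plt.axis('off')
plt.show()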
Example #35
    def plot_diagnosis(self, X, labels=None, n_bins=200):
        X, labels, n_classes = self._check_input(X, labels)

        nrow = n_classes
        ncol = 1
        fig = plot_figure(nrow=nrow * 2, ncol=8)
        # add 1 for threshold color
        # add 1 for PDF color
        colors = sns.color_palette(n_colors=self.n_components_per_class + 2)

        for i, (name, (order, gmm)) in enumerate(zip(labels, self._models)):
            start = ncol * i

            means_ = gmm.means_.ravel()[order]
            precision_ = gmm.precisions_.ravel()[order]
            x = self.normalize(X[:, i], test_mode=False)

            # ====== scores ====== #
            # score
            score_llk = gmm.score(x[:, np.newaxis])
            score_bic = gmm.bic(x[:, np.newaxis])
            score_aic = gmm.aic(x[:, np.newaxis])

            # ====== the histogram ====== #
            ax = plt.subplot(nrow, ncol, start + 1)
            count, bins = _draw_hist(x,
                                     ax=ax,
                                     title="[%s] LLK:%.2f BIC:%.2f AIC:%.2f" %
                                     (name, score_llk, score_bic, score_aic),
                                     n_bins=n_bins,
                                     show_yticks=True)

            # ====== draw GMM PDF ====== #
            y_ = np.exp(gmm.score_samples(bins[:, np.newaxis]))
            y_ = (y_ - np.min(y_)) / (np.max(y_) - np.min(y_)) * np.max(count)
            ax.plot(bins,
                    y_,
                    color='red',
                    linestyle='-',
                    linewidth=1.5,
                    alpha=0.6)

            # ====== draw the threshold ====== #
            ci = stats.norm.interval(
                np.abs(self.ci_threshold),
                loc=gmm.means_[order[self.positive_component]],
                scale=np.sqrt(1 /
                              gmm.precisions_[order[self.positive_component]]))
            threshold = ci[0] if self.ci_threshold < 0 else ci[1]
            ids = np.where(bins >= threshold, True, False)
            ax.fill_between(bins[ids],
                            y1=0,
                            y2=np.max(count),
                            facecolor=colors[-2],
                            alpha=0.3)
            ax.text(np.min(bins[ids]), np.min(count), "%.2f" % threshold)

            # ====== plot GMM probability ====== #
            x_ = np.linspace(np.min(bins), np.max(bins), 1200)
            y_ = gmm.predict_proba(x_[:, np.newaxis]) * np.max(count)
            for c, j in zip(colors, y_.T):
                plt.plot(x_,
                         j,
                         color=c,
                         linestyle='--',
                         linewidth=1.8,
                         alpha=0.6)

            # ====== draw the each Gaussian bell ====== #
            ax = ax.twinx()
            _x = np.linspace(start=np.min(x), stop=np.max(x), num=800)
            for c, m, p in zip(colors, means_, precision_):
                with catch_warnings_ignore(Warning):
                    # mlab.normpdf was removed from matplotlib; scipy.stats gives the same PDF
                    j = stats.norm.pdf(_x, loc=m, scale=np.sqrt(1 / p))
                ax.plot(_x, j, color=c, linestyle='-', linewidth=1)
                ax.scatter(_x[np.argmax(j)],
                           np.max(j),
                           s=66,
                           alpha=0.8,
                           linewidth=0,
                           color=c)
            ax.yaxis.set_ticklabels([])

        fig.tight_layout()
        self.add_figure('diagnosis', fig)
        return self
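The shaded region in plot_diagnosis above starts at a threshold derived from a confidence interval around the "positive" Gaussian component. The snippet below reproduces that threshold computation in isolation; the component mean, precision and ci_threshold are illustrative numbers, not values from the class.

import numpy as np
from scipy import stats

mean, precision = 1.2, 4.0      # assumed parameters of the positive component
ci_threshold = -0.95            # negative sign -> use the lower interval bound
lo, hi = stats.norm.interval(abs(ci_threshold),
                             loc=mean,
                             scale=np.sqrt(1.0 / precision))
threshold = lo if ci_threshold < 0 else hi
print('decision threshold: %.3f' % threshold)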
Example #36
  def plot_uncertainty_scatter(self, factors=None, n_samples=2, algo='tsne'):
    r""" Plotting the scatter points of the mean and sampled latent codes,
    colored by the factors.

    Arguments:
      factors : list of Integer or String. The index or name of factors taken
        into account for analyzing.
    """
    factors = self._check_factors(factors)
    # this includes train and test data separately
    z_mean = np.concatenate(self.representations_mean)
    z_var = np.concatenate(
        [np.mean(var, axis=1) for var in self.representations_variance])
    z_samples = [
        z for z in np.concatenate(self.representations_sample(int(n_samples)),
                                  axis=1)
    ]
    F = np.concatenate(self.original_factors, axis=0)[:, factors]
    labels = self.factors_name[factors]
    # preprocessing
    inputs = tuple([z_mean] + z_samples)
    Z = dimension_reduce(*inputs,
                         algo=algo,
                         n_components=2,
                         return_model=False,
                         combined=True,
                         random_state=self.randint)
    V = utils.discretizing(z_var[:, np.newaxis], n_bins=10).ravel()
    # the figure
    nrow = 3
    ncol = int(np.ceil(len(labels) / nrow))
    fig = vs.plot_figure(nrow=nrow * 4, ncol=ncol * 4, dpi=80)
    for idx, (name, y) in enumerate(zip(labels, F.T)):
      ax = vs.plot_subplot(nrow, ncol, idx + 1)
      for i, x in enumerate(Z):
        kw = dict(val=y,
                  color="coolwarm",
                  ax=ax,
                  x=x,
                  grid=False,
                  legend_enable=False,
                  centroids=True,
                  fontsize=12)
        if i == 0:  # the mean value
          vs.plot_scatter(size=V,
                          size_range=(8, 80),
                          alpha=0.3,
                          linewidths=0,
                          cbar=True,
                          cbar_horizontal=True,
                          title=name,
                          **kw)
        else:  # the samples
          vs.plot_scatter_text(size=8,
                               marker='x',
                               alpha=0.8,
                               weight='light',
                               **kw)
    # fig.tight_layout()
    self.add_figure("uncertainty_scatter_%s" % algo, fig)
    return self
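In the scatter above, the marker size of the mean points encodes the (discretized) latent variance via size_range=(8, 80). A NumPy-only sketch of that variance-to-marker-size mapping is shown below; utils.discretizing is replaced by a plain histogram binning for illustration.

import numpy as np

z_var = np.random.rand(1000)                        # toy per-point latent variance
bins = np.digitize(z_var, np.linspace(z_var.min(), z_var.max(), 10))
size_min, size_max = 8.0, 80.0
span = max(1, bins.max() - bins.min())
sizes = size_min + (bins - bins.min()) / span * (size_max - size_min)
# `sizes` can now be passed as the `s` argument of matplotlib's scatter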
Example #37
File: helpers.py  Project: imito/odin
def prepare_dnn_data(save_dir, feat_name=None,
                     utt_length=None, seq_mode=None,
                     min_dur=None, min_utt=None,
                     exclude=None, train_proportion=None,
                     return_dataset=False):
  assert os.path.isdir(save_dir), \
      "Path to '%s' is not a directory" % save_dir
  if feat_name is None:
    feat_name = FEATURE_NAME
  if utt_length is None:
    utt_length = int(_args.utt)
  if seq_mode is None:
    seq_mode = str(_args.seq).strip().lower()
  if min_dur is None:
    min_dur = MINIMUM_UTT_DURATION
  if min_utt is None:
    min_utt = MINIMUM_UTT_PER_SPEAKERS
  if exclude is None:
    exclude = str(_args.exclude).strip()
  print("Minimum duration: %s(s)" % ctext(min_dur, 'cyan'))
  print("Minimum utt/spk : %s(utt)" % ctext(min_utt, 'cyan'))
  # ******************** prepare dataset ******************** #
  path = os.path.join(PATH_ACOUSTIC_FEATURES, FEATURE_RECIPE)
  assert os.path.exists(path), "Cannot find acoustic dataset at path: %s" % path
  ds = F.Dataset(path=path, read_only=True)
  rand = np.random.RandomState(seed=Config.SUPER_SEED)
  # ====== find the right feature ====== #
  assert feat_name in ds, "Cannot find feature with name: %s" % feat_name
  X = ds[feat_name]
  ids_name = 'indices_%s' % feat_name
  assert ids_name in ds, "Cannot find indices with name: %s" % ids_name
  # ====== basic path ====== #
  path_filtered_data = os.path.join(save_dir, 'filtered_files.pkl')
  path_train_files = os.path.join(save_dir, 'train_files.pkl')
  path_speaker_info = os.path.join(save_dir, 'speaker_info.pkl')
  # ******************** cannot find cached data ******************** #
  if any(not os.path.exists(p) for p in [path_filtered_data,
                                         path_train_files,
                                         path_speaker_info]):
    # ====== exclude some dataset ====== #
    if len(exclude) > 0:
      exclude_dataset = {i: 1 for i in exclude.split(',')}
      print("* Excluded dataset:", ctext(exclude_dataset, 'cyan'))
      indices = {name: (start, end)
                 for name, (start, end) in ds[ids_name].items()
                 if ds['dsname'][name] not in exclude_dataset}
      # special case exclude all the noise data
      if 'noise' in exclude_dataset:
        indices = {name: (start, end)
                   for name, (start, end) in indices.items()
                   if '/' not in name}
    else:
      indices = {i: j for i, j in ds[ids_name].items()}
    # ====== down-sampling if necessary ====== #
    if _args.downsample > 1000:
      dataset2name = defaultdict(list)
      # ordering the indices so we sample the same set every time
      for name in sorted(indices.keys()):
        dataset2name[ds['dsname'][name]].append(name)
      n_total_files = len(indices)
      n_sample_files = int(_args.downsample)
      # get the percentage of each dataset
      dataset2per = {i: len(j) / n_total_files
                     for i, j in dataset2name.items()}
      # sampling based on percentage
      _ = {}
      for dsname, flist in dataset2name.items():
        rand.shuffle(flist)
        n_dataset_files = int(dataset2per[dsname] * n_sample_files)
        _.update({i: indices[i]
                  for i in flist[:n_dataset_files]})
      indices = _
    # ====== * filter out "bad" sample ====== #
    indices = filter_utterances(X=X, indices=indices, spkid=ds['spkid'],
                                min_utt=min_utt, min_dur=min_dur,
                                remove_min_length=True,
                                remove_min_uttspk=True,
                                n_speakers=None, ncpu=None,
                                save_path=path_filtered_data)
    # ====== all training file name ====== #
    # modify here to train full dataset
    all_name = sorted(indices.keys())
    rand.shuffle(all_name); rand.shuffle(all_name)
    n_files = len(all_name)
    print("#Files:", ctext(n_files, 'cyan'))
    # ====== speaker mapping ====== #
    name2spk = {name: ds['spkid'][name]
                for name in all_name}
    all_speakers = sorted(set(name2spk.values()))
    spk2label = {spk: i
                 for i, spk in enumerate(all_speakers)}
    name2label = {name: spk2label[spk]
                  for name, spk in name2spk.items()}
    assert len(name2label) == len(all_name)
    print("#Speakers:", ctext(len(all_speakers), 'cyan'))
    # ====== stratify sampling based on speaker ====== #
    valid_name = []
    # create speakers' cluster
    label2name = defaultdict(list)
    for name, label in sorted(name2label.items(),
                              key=lambda x: x[0]):
      label2name[label].append(name)
    # for each speaker with >= 3 utterance
    for label, name_list in sorted(label2name.items(),
                                   key=lambda x: x[0]):
      if len(name_list) < 3:
        continue
      n = max(1, int(0.05 * len(name_list))) # 5% for validation
      valid_name += rand.choice(a=name_list, size=n, replace=False).tolist()
    # train list is the rest
    _ = set(valid_name)
    train_name = [i for i in all_name if i not in _]
    # ====== split training and validation ====== #
    train_indices = {name: indices[name] for name in train_name}
    valid_indices = {name: indices[name] for name in valid_name}
    # ====== save cached data ====== #
    with open(path_train_files, 'wb') as fout:
      pickle.dump({'train': train_indices, 'valid': valid_indices},
                  fout)
    with open(path_speaker_info, 'wb') as fout:
      pickle.dump({'all_speakers': all_speakers,
                   'name2label': name2label,
                   'spk2label': spk2label},
                  fout)
  # ******************** load cached data ******************** #
  else:
    with open(path_train_files, 'rb') as fin:
      obj = pickle.load(fin)
      train_indices = obj['train']
      valid_indices = obj['valid']
    with open(path_speaker_info, 'rb') as fin:
      obj = pickle.load(fin)
      all_speakers = obj['all_speakers']
      name2label = obj['name2label']
      spk2label = obj['spk2label']

  # ******************** print log ******************** #
  def summary_indices(ids):
    datasets = defaultdict(int)
    speakers = defaultdict(list)
    text = ''
    for name in sorted(ids.keys()):
      text += name + str(ids[name])
      dsname = ds['dsname'][name]
      datasets[dsname] += 1
      speakers[dsname].append(ds['spkid'][name])
    for dsname in sorted(datasets.keys()):
      print('  %-18s: %s(utt) %s(spk)' % (
          dsname,
          ctext('%6d' % datasets[dsname], 'cyan'),
          ctext(len(set(speakers[dsname])), 'cyan')))
    print('  MD5 checksum:', ctext(crypto.md5_checksum(text), 'lightcyan'))
  # ====== training files ====== #
  print("#Train files:", ctext('%-8d' % len(train_indices), 'cyan'),
        "#spk:", ctext(len(set(name2label[name]
                               for name in train_indices.keys())), 'cyan'),
        "#noise:", ctext(len([name for name in train_indices.keys()
                              if '/' in name]), 'cyan'))
  summary_indices(ids=train_indices)
  # ====== valid files ====== #
  print("#Valid files:", ctext('%-8d' % len(valid_indices), 'cyan'),
        "#spk:", ctext(len(set(name2label[name]
                               for name in valid_indices.keys())), 'cyan'),
        "#noise:", ctext(len([name for name in valid_indices.keys()
                              if '/' in name]), 'cyan'))
  summary_indices(ids=valid_indices)
  # ******************** create the recipe ******************** #
  assert all(name in name2label
             for name in train_indices.keys())
  assert all(name in name2label
            for name in valid_indices.keys())
  recipes = prepare_dnn_feeder_recipe(name2label=name2label,
                                      n_speakers=len(all_speakers),
                                      utt_length=utt_length, seq_mode=seq_mode)
  # ====== downsample training set for analyzing if required ====== #
  if train_proportion is not None:
    assert 0 < train_proportion < 1
    n_training = len(train_indices)
    train_indices = list(train_indices.items())
    rand.shuffle(train_indices); rand.shuffle(train_indices)
    train_indices = dict(train_indices[:int(n_training * train_proportion)])
  # ====== create feeder ====== #
  train_feeder = F.Feeder(
      data_desc=F.IndexedData(data=X,
                              indices=train_indices),
      batch_mode='batch', ncpu=NCPU, buffer_size=256)

  valid_feeder = F.Feeder(
      data_desc=F.IndexedData(data=X,
                              indices=valid_indices),
      batch_mode='batch', ncpu=max(2, NCPU // 4), buffer_size=64)

  train_feeder.set_recipes(recipes)
  valid_feeder.set_recipes(recipes)
  print(train_feeder)
  print(valid_feeder)
  # ====== debugging ====== #
  if IS_DEBUGGING:
    import matplotlib
    matplotlib.use('Agg')
    prog = Progbar(target=len(valid_feeder), print_summary=True,
                   name="Iterating validation set")
    samples = []
    n_visual = 250
    for name, idx, X, y in valid_feeder.set_batch(batch_size=100000,
                                                  batch_mode='file',
                                                  seed=None, shuffle_level=0):
      assert idx == 0, "Utterances longer than %.2f(sec)" % (100000 * Config.STEP_LENGTH)
      prog['X'] = X.shape
      prog['y'] = y.shape
      prog.add(X.shape[0])
      # random sampling
      if rand.rand(1) < 0.5 and len(samples) < n_visual:
        for i in rand.randint(0, X.shape[0], size=4, dtype='int32'):
          samples.append((name, X[i], np.argmax(y[i], axis=-1)))
    # plot the spectrogram
    n_visual = len(samples)
    V.plot_figure(nrow=n_visual, ncol=8)
    for i, (name, X, y) in enumerate(samples):
      is_noise = '/' in name
      assert name2label[name] == y, "Speaker label mismatch for file: %s" % name
      name = name.split('/')[0]
      dsname = ds['dsname'][name]
      spkid = ds['spkid'][name]
      y = np.argmax(y, axis=-1)
      ax = V.plot_spectrogram(X.T,
                              ax=(n_visual, 1, i + 1),
                              title='#%d' % (i + 1))
      ax.set_title('[%s][%s]%s  %s' %
                   ('noise' if is_noise else 'clean', dsname, name, spkid),
                   fontsize=6)
    # don't need to be high resolutions
    V.plot_save('/tmp/tmp.pdf', dpi=12)
    exit()
  # ====== return ====== #
  if bool(return_dataset):
    return train_feeder, valid_feeder, all_speakers, ds
  return train_feeder, valid_feeder, all_speakers
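prepare_dnn_data above holds out roughly 5% of each speaker's utterances (only for speakers with at least 3 utterances) as the validation set. A self-contained sketch of that per-speaker stratified split is given below; the function name, the seed and the toy mapping are illustrative.

import numpy as np
from collections import defaultdict

def split_train_valid(name2label, valid_ratio=0.05, seed=1234):
    rand = np.random.RandomState(seed)
    label2name = defaultdict(list)
    for name, label in sorted(name2label.items()):
        label2name[label].append(name)
    valid = []
    for label, names in sorted(label2name.items()):
        if len(names) < 3:          # too few utterances: keep them all for training
            continue
        n = max(1, int(valid_ratio * len(names)))
        valid += rand.choice(names, size=n, replace=False).tolist()
    valid = set(valid)
    train = [name for name in sorted(name2label) if name not in valid]
    return train, sorted(valid)

# toy usage: 2 speakers x 20 utterances each
toy = {'utt%03d' % i: 'spk%d' % (i % 2) for i in range(40)}
train_names, valid_names = split_train_valid(toy)
print(len(train_names), len(valid_names))   # 38 2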