Ejemplo n.º 1
0
 def dimension_reduce(
     self,
     algorithm: Literal['pca', 'umap', 'tsne', 'knn', 'kmean'] = 'tsne',
     seed: int = 1,
 ) -> np.ndarray:
     """Reduce the dimension of the latents and memoize the outcome.

     The latents are converted with `self.dist_to_tensor`, turned into a numpy
     array and handed to the module-level `dimension_reduce` function (not
     this method). The result is stored in `_CACHE_LATENTS`, so calling again
     with the same algorithm and seed returns the stored array without
     recomputing it.

     Arguments:
       algorithm : name of the reduction algorithm, forwarded as `algo`.
       seed : forwarded as `random_state`; its `int` value is part of the
         cache key.

     Returns:
       the reduced array, which is the very same object kept in the cache.
     """
     # NOTE(review): the key is built from `id(...)` values, which are only
     # unique among objects alive at the same time -- confirm that a stale
     # entry cannot be hit by a later object reusing the same id.
     key_parts = (id(self), id(self.dist_to_tensor), algorithm, int(seed))
     cache_key = '_'.join(map(str, key_parts))
     if cache_key not in _CACHE_LATENTS:
         latents = self.dist_to_tensor(self.latents)
         reduced = dimension_reduce(latents.numpy(),
                                    algo=algorithm,
                                    random_state=seed)
         _CACHE_LATENTS[cache_key] = reduced
         return reduced
     return _CACHE_LATENTS[cache_key]
Ejemplo n.º 2
0
  def plot_uncertainty_scatter(self, factors=None, n_samples=2, algo='tsne'):
    r""" Scatter plot of the mean latent codes together with sampled latent
    codes, one subplot per factor, colored by the factor values.

    The mean codes are drawn as points sized by the discretized variance, the
    sampled codes are drawn with `plot_scatter_text` on the same axes.

    Arguments:
      factors : list of Integer or String. The index or name of factors taken
        into account for analyzing.
      n_samples : cast to `int` and passed to `self.representations_sample`.
      algo : name of the dimension reduction algorithm, forwarded to
        `dimension_reduce`; also used as the suffix of the figure name.

    Returns:
      `self`; the figure is registered through `self.add_figure`.
    """
    factors = self._check_factors(factors)
    # the representations keep train and test data as separate pieces,
    # hence the concatenations below
    mean_codes = np.concatenate(self.representations_mean)
    variance_pieces = [
        np.mean(piece, axis=1) for piece in self.representations_variance
    ]
    code_variance = np.concatenate(variance_pieces)
    stacked_samples = np.concatenate(
        self.representations_sample(int(n_samples)), axis=1)
    # one entry per slice along the first axis of the stacked samples
    sample_codes = list(stacked_samples)
    factor_values = np.concatenate(self.original_factors, axis=0)[:, factors]
    factor_names = self.factors_name[factors]
    # preprocessing: mean and sampled codes go through one reduction call
    projections = dimension_reduce(mean_codes,
                                   *sample_codes,
                                   algo=algo,
                                   n_components=2,
                                   return_model=False,
                                   combined=True,
                                   random_state=self.randint)
    point_sizes = utils.discretizing(code_variance[:, np.newaxis],
                                     n_bins=10).ravel()
    # the figure: a grid with a fixed number of rows, one cell per factor
    n_rows = 3
    n_cols = int(np.ceil(len(factor_names) / n_rows))
    fig = vs.plot_figure(nrow=n_rows * 4, ncol=n_cols * 4, dpi=80)
    cells = enumerate(zip(factor_names, factor_values.T), start=1)
    for position, (factor_name, values) in cells:
      ax = vs.plot_subplot(n_rows, n_cols, position)
      for order, points in enumerate(projections):
        shared = {
            'val': values,
            'color': "coolwarm",
            'ax': ax,
            'x': points,
            'grid': False,
            'legend_enable': False,
            'centroids': True,
            'fontsize': 12,
        }
        if order != 0:  # every entry after the first is a set of samples
          vs.plot_scatter_text(size=8,
                               marker='x',
                               alpha=0.8,
                               weight='light',
                               **shared)
          continue
        # the first entry is the mean value
        vs.plot_scatter(size=point_sizes,
                        size_range=(8, 80),
                        alpha=0.3,
                        linewidths=0,
                        cbar=True,
                        cbar_horizontal=True,
                        title=factor_name,
                        **shared)
    figure_name = "uncertainty_scatter_%s" % algo
    self.add_figure(figure_name, fig)
    return self