Example #1
    def call(self, state):
        # Get mean and standard deviation from the policy network
        a1 = self.dense1_layer(state)
        a2 = self.dense2_layer(a1)
        mu = self.mean_layer(a2)

        # The standard deviation must be non-negative, so the network outputs
        # log stdev, which is unconstrained in (-inf, inf)
        log_sigma = self.stdev_layer(a2)
        sigma = tf.exp(log_sigma)

        # Use the reparameterization trick to sample an action from the policy:
        # sampling from Normal(mu, sigma) is equivalent to sampling standard
        # normal noise with the action's shape and scaling/shifting it by sigma and mu
        dist = Normal(mu, sigma)
        action_ = dist.sample()

        # Apply the tanh squashing to keep the gaussian bounded in (-1,1)
        action = tf.tanh(action_)

        # Calculate the log probability, correcting for the tanh change of variables
        log_pi_ = dist.log_prob(action_)
        log_pi = log_pi_ - tf.reduce_sum(
            tf.math.log(1 - action**2 + eps), axis=1, keepdims=True)
        return action, log_pi
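The method above assumes the enclosing policy network defines the dense layers and the `eps` constant; a minimal sketch of such a wrapper (a hypothetical class with illustrative layer sizes, not taken from the source):

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow_probability.python.distributions import Normal

eps = 1e-6  # small constant assumed by the tanh log-prob correction above


class GaussianPolicy(tf.keras.Model):
    """Hypothetical container for the `call` method shown above."""

    def __init__(self, action_dim):
        super().__init__()
        self.dense1_layer = layers.Dense(256, activation='relu')
        self.dense2_layer = layers.Dense(256, activation='relu')
        self.mean_layer = layers.Dense(action_dim)
        self.stdev_layer = layers.Dense(action_dim)

    # call(self, state) as defined above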
Example #2
def total_correlation(z_samples: Tensor, qZ_X: Distribution) -> Tensor:
    r"""Estimate of total correlation using Gaussian distribution on a batch.

  We need to compute the expectation over a batch of:
  `E_j [log(q(z(x_j))) - log(prod_l q(z(x_j)_l))]`

  We ignore the constants as they do not matter for the minimization.
  The constant should be equal to
  `(num_latents - 1) * log(batch_size * dataset_size)`

  If `alpha = gamma = 1`, Eq(4) can be written as `ELBO + (1 - beta) * TC`.
  (i.e. `(1. - beta) * total_correlation(z_sampled, qZ_X)`)

  Parameters
  ----------
  z_samples : Tensor
      shape `[batch_size, num_latents]` - tensor with sampled representation.
  qZ_X : Distribution
      the posterior distribution, shape `[batch_size, num_latents]`

  Note
  ----
  This involve calculating pair-wise distance, memory complexity up to
      `O(n*n*d)`.

  Returns
  -------
  Total correlation estimated on a batch.

  References
  ----------
  Chen, R.T.Q., Li, X., Grosse, R., Duvenaud, D., 2019. Isolating Sources of
      Disentanglement in Variational Autoencoders. arXiv:1802.04942 [cs, stat].
  Github code https://github.com/google-research/disentanglement_lib
  """
    gaus = Normal(loc=tf.expand_dims(qZ_X.mean(), 0),
                  scale=tf.expand_dims(qZ_X.stddev(), 0))
    # Compute log(q(z(x_j)|x_i)) for every sample in the batch, which is a
    # tensor of size [batch_size, batch_size, num_latents]. In the following
    # comments, [batch_size, batch_size, num_latents] are indexed by [j, i, l].
    log_qz_prob = gaus.log_prob(tf.expand_dims(z_samples, 1))
    # Compute log prod_l p(z(x_j)_l) = sum_l(log(sum_i(q(z(x_j)_l|x_i)))
    # + constant) for each sample in the batch, which is a vector of size
    # [batch_size,].
    log_qz_product = tf.reduce_sum(tf.reduce_logsumexp(log_qz_prob,
                                                       axis=1,
                                                       keepdims=False),
                                   axis=1,
                                   keepdims=False)
    # Compute log(q(z(x_j))) as log(sum_i(q(z(x_j)|x_i))) + constant =
    # log(sum_i(prod_l q(z(x_j)_l|x_i))) + constant.
    log_qz = tf.reduce_logsumexp(tf.reduce_sum(log_qz_prob,
                                               axis=2,
                                               keepdims=False),
                                 axis=1,
                                 keepdims=False)
    return tf.reduce_mean(log_qz - log_qz_product)
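A minimal usage sketch, assuming the posterior is a batch of factorized Normal distributions (all shapes here are illustrative):

import tensorflow as tf
from tensorflow_probability.python.distributions import Normal

batch_size, num_latents = 32, 10
loc = tf.random.normal((batch_size, num_latents))
scale = tf.nn.softplus(tf.random.normal((batch_size, num_latents)))
qZ_X = Normal(loc=loc, scale=scale)        # stands in for the encoder posterior
z_samples = qZ_X.sample()                  # [batch_size, num_latents]
tc = total_correlation(z_samples, qZ_X)    # scalar estimate on this batch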
Example #3
    def compute_KL_univariate_prior(self, univariateprior, samples):
        """
        :param univariateprior: (mean2, std2) of the univariate Normal prior, i.e. Normal(mean2, std2)
        :param samples: weights sampled from the posterior q(w|theta) = Normal(self.mu, self.sigma)
        :return: Monte Carlo estimate of KL(q || p)
        """

        samples = tf.reshape(samples, [-1])  # flatten vector
        (mean2, std2) = univariateprior
        prior = Normal(mean2, std2)
        posterior = Normal(self.mu, self.sigma)
        q_theta = tf.reduce_sum(posterior.log_prob(samples))
        p_d = tf.reduce_sum(prior.log_prob(samples))
        KL = tf.subtract(q_theta, p_d)

        return KL
Example #4
 def elbo_components(self, inputs, training=None, mask=None):
   X_u, y_u, X_l, y_l = _prepare_elbo(self,
                                      inputs,
                                      training=training,
                                      mask=mask)
   y_l = tf.clip_by_value(y_l, 1e-8, 1. - 1e-8)
   px_z_u, (qz_x_u, qzc_x_u, qy_zx_u) = self(X_u, training=training)
   px_z_l, (qz_x_l, qzc_x_l, qy_zx_l) = self(X_l, training=training)
   z_exc = tf.concat(
     [tf.convert_to_tensor(qz_x_u),
      tf.convert_to_tensor(qz_x_l)], axis=0)
   z_c = tf.concat(
     [tf.convert_to_tensor(qzc_x_u),
      tf.convert_to_tensor(qzc_x_l)], axis=0)
   # Convert y to one-hot vectors and sample y for the unlabeled inputs
   y_sup = y_l
   y_uns = tf.convert_to_tensor(qy_zx_u)
   y = tf.concat((y_uns, y_sup), axis=0)
   # log q(y|z_c)
   h = tf.concat([qy_zx_u.logits, qy_zx_l.logits], axis=0)
   log_q_y_zc = tf.reduce_sum(h * y, axis=1)
   # log p(x|z)
   log_p_x_z = tf.concat([px_z_u.log_prob(X_u), px_z_l.log_prob(X_l)], axis=0)
   # log p(z_c|y)
   pzc_y = self.regressor(y)
   log_p_zc_y = pzc_y.log_prob(z_c)
   # log p(z_\c)
   dist = Normal(tf.cast(0., self.dtype), 1.)
   log_p_zexc = tf.reduce_sum(dist.log_prob(z_exc), axis=-1)
   # log p(z|y)
   log_p_z_y = log_p_zc_y + log_p_zexc
   # log q(y|x)  (draw n_resamples points from q(z_c|x); supervised samples only)
   h = qzc_x_l.sample(self.n_resamples)
   h = tf.reshape(h, (-1, h.shape[-1]))
   qy_x = self.classify(h, training=training)
   qy_x_logits = tf.reshape(qy_x.logits,
                            (self.n_resamples, -1, qy_x.logits.shape[-1]))
   h = tf.reduce_logsumexp(qy_x_logits, axis=0) - tf.math.log(
     tf.cast(self.n_resamples, self.dtype))
   log_q_y_x = tf.reduce_sum(h * y_l, axis=1)
   # log q(z|x)
   log_qz_x = tf.concat([qz_x_u.log_prob(qz_x_u),
                         qz_x_l.log_prob(qz_x_l)],
                        axis=0)
   log_qzc_x = tf.concat(
     [qzc_x_u.log_prob(qzc_x_u),
      qzc_x_l.log_prob(qzc_x_l)], axis=0)
   log_q_z_x = log_qz_x + log_qzc_x
   # Calculate the lower bound
   n_uns = ps.shape(X_u)[0]
   h = log_p_x_z + log_p_z_y - log_q_y_zc - log_q_z_x
   coef_sup = tf.math.exp(log_q_y_zc[n_uns:] - log_q_y_x)
   coef_uns = tf.ones((n_uns,), dtype=self.dtype)
   coef = tf.concat((coef_uns, coef_sup), axis=0)
   zeros = tf.zeros((n_uns,), dtype=self.dtype)
   lb = coef * h + tf.concat((zeros, log_q_y_x), axis=0)
   return {'elbo': lb}, {}
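The `log q(y|x)` term above is a Monte Carlo estimate over resampled `z_c`; a standalone sketch of that estimate with made-up shapes (the sizes and the use of normalized log-probabilities are assumptions for clarity):

import tensorflow as tf

M, batch, n_classes = 4, 8, 10                          # illustrative sizes
log_qy_zc = tf.nn.log_softmax(tf.random.normal((M, batch, n_classes)))
# log( (1/M) * sum_m q(y | z_c^(m)) ), computed in log-space for stability
log_qy_x = tf.reduce_logsumexp(log_qy_zc, axis=0) - tf.math.log(float(M))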
Example #5
def compute_KL_univariate_prior(univariateprior, theta, sample):
    """
        :param prior:  assuming univariate prior of Normal(m,s);
        :param posterior: (theta: mean,std) to create posterior q(w/theta) i.e. Normal(mean,std)
        :param sample: Number of sample
        """
    sample = tf.reshape(sample, [-1])  #flatten vector
    (mean, std) = theta
    mean = tf.reshape(mean, [-1])
    std = tf.reshape(std, [-1])
    posterior = Normal(mean, std)

    (mean2, std2) = univariateprior
    prior = Normal(mean2, std2)

    q_theta = tf.reduce_sum(posterior.log_prob(sample))
    p_d = tf.reduce_sum(prior.log_prob(sample))

    KL = tf.subtract(q_theta, p_d)

    return KL
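A minimal usage sketch in a Bayes-by-Backprop style setting, where weights are drawn from the variational posterior with the reparameterization trick (variable names and shapes are illustrative):

import tensorflow as tf

# hypothetical variational parameters for a small weight matrix
mu = tf.Variable(tf.random.normal((5, 3)))
rho = tf.Variable(tf.fill((5, 3), -3.0))
sigma = tf.math.softplus(rho)                        # keep the std positive
w = mu + sigma * tf.random.normal(tf.shape(mu))      # reparameterized weight sample
kl = compute_KL_univariate_prior((0.0, 1.0), (mu, sigma), w)  # Normal(0, 1) prior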
Example #6
def get_KL_univariate_prior(univariateprior, theta, sample):
    """
        :param prior:  assuming univatier prior of Normal(m,s); i.e. Normal(s,
        :param posterior: (theta: mean,std) to create posterior q(w/theta) i.e. Normal(mean,std)
        :param sample:
        :return:

        """

    sample = tf.reshape(sample, [-1])  #flatten vector
    (mean, std) = theta
    (mean2, std2) = univariateprior
    prior = Normal(mean2, std2)
    posterior = Normal(mean, std)

    q_theta = tf.reduce_sum(posterior.log_prob(sample))
    p_d = tf.reduce_sum(prior.log_prob(sample))

    KL = tf.subtract(q_theta, p_d)

    return KL
Example #7
    def get_KL_multivariate_prior(self, multivariateprior, theta, sample):
        """
        :param prior:  assuming univatier prior of Normal(m,s); i.e. Normal(m1,s1) and Normal(m2,s2)
        :param posterior: (theta: mean,std) to create posterior q(w/theta) i.e. Normal(mean,std)
        :param sample:
        :return:

        """

        sample = tf.reshape(sample, [-1])  #flatten vector
        (mean, std) = theta
        posterior = Normal(mean, std)

        (std1, std2) = multivariateprior
        prior1 = Normal(0, std1)
        prior2 = Normal(0, std2)

        q_theta = tf.reduce_sum(posterior.log_prob(sample))
        p1 = tf.reduce_sum(prior1.log_prob(sample))
        p2 = tf.reduce_sum(
            prior2.log_prob(sample))  #this is wrong need to work this out
        KL = tf.subtract(q_theta, tf.reduce_logsumexp([p1, p2]))

        return KL
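The comment above flags the mixture term as wrong: the two component densities should be combined per weight, with explicit mixing coefficients, before summing. A sketch of that correction in the style of a scale-mixture prior (Blundell et al., Bayes by Backprop); the mixing coefficient `pi` is an assumption, and this is not the author's code:

import tensorflow as tf
from tensorflow_probability.python.distributions import Normal

def scale_mixture_log_prob(sample, std1, std2, pi=0.5):
    """log p(w) under pi * N(0, std1) + (1 - pi) * N(0, std2), summed over weights."""
    sample = tf.reshape(sample, [-1])
    log_p1 = Normal(0.0, std1).log_prob(sample) + tf.math.log(pi)
    log_p2 = Normal(0.0, std2).log_prob(sample) + tf.math.log(1.0 - pi)
    # combine the two components per weight, then sum over all weights
    return tf.reduce_sum(
        tf.reduce_logsumexp(tf.stack([log_p1, log_p2], axis=0), axis=0))

The KL estimate is then `q_theta` minus this prior log-probability.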
Example #8
    def compute_KL_univariate_prior(self, univariateprior, theta, sample):

        """
        :param prior:  assuming univariate prior of Normal(m,s);
        :param posterior: (theta: mean,std) to create posterior q(w/theta) i.e. Normal(mean,std)
        :param sample:
        :return: KL (analytical)

        """

        sample = tf.reshape(sample, [-1])  # flatten vector
        (mean, std) = theta
        mean = tf.reshape(mean, [-1])
        std = tf.reshape(std, [-1])
        posterior = Normal(mean, std)
        (mean2, std2) = univariateprior
        prior = Normal(mean2, std2)
        q_theta = tf.reduce_sum(posterior.log_prob(sample))
        p_d = tf.reduce_sum(prior.log_prob(sample))
        KL = tf.subtract(q_theta, p_d)

        print("computed KL loss" + self.layer_name)

        return KL
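The body above is a single-sample Monte Carlo estimate; since both distributions are univariate Gaussians, the KL is also available in closed form and can be summed over weights. A sketch under the same `(mean, std)` parameterization (not the author's implementation):

import tensorflow as tf

def kl_univariate_gaussians(mean, std, mean2, std2):
    """Closed-form KL( N(mean, std) || N(mean2, std2) ), summed over elements."""
    mean, std = tf.reshape(mean, [-1]), tf.reshape(std, [-1])
    kl = (tf.math.log(std2 / std)
          + (tf.square(std) + tf.square(mean - mean2)) / (2.0 * tf.square(std2))
          - 0.5)
    return tf.reduce_sum(kl)

tensorflow_probability also exposes `kl_divergence(posterior, prior)`, which returns the same closed form element-wise for Normal pairs.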
Example #9
class NormalPyramid(Distribution):

    # means is indexed by *, y, x, channel
    # base_sigma is a scalar, and is the std used for the 'raw' pixels; subsequent levels use smaller stds

    def __init__(self,
                 means,
                 base_sigma,
                 levels=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='NormalPyramid'):
        with ops.name_scope(name, values=[means, base_sigma]) as ns:
            self._means = array_ops.identity(means, name='means')
            self._base_sigma = array_ops.identity(base_sigma,
                                                  name='base_sigma')
            self._base_dist = Normal(loc=self._means, scale=self._base_sigma)
            self._standard_normal = Normal(loc=0., scale=1.)
            self._levels = levels
            super(NormalPyramid,
                  self).__init__(dtype=tf.float32,
                                 parameters={
                                     'means': means,
                                     'base_sigma': base_sigma
                                 },
                                 reparameterization_type=FULLY_REPARAMETERIZED,
                                 validate_args=validate_args,
                                 allow_nan_stats=allow_nan_stats,
                                 name=ns)

    def _log_prob(self, x):
        # The resulting density here will be indexed by *, i.e. we sum over x, y, channel, and pyramid-levels
        z = (x - self._means) / self._base_sigma
        z_shape = list(map(int, z.get_shape()))
        z_pyramid = gaussian_pyramid(z, self._levels)
        return sum(
            tf.reduce_mean(self._standard_normal.log_prob(z_level),
                           axis=[-3, -2, -1])  # ** check the rescaling here!
            for level_index, z_level in enumerate(z_pyramid)) / len(z_pyramid)

    def _sample_n(self, n, seed=None):
        return self._base_dist._sample_n(n, seed)

    def _mean(self):
        return self._means

    def _mode(self):
        return self._means
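`gaussian_pyramid` is not shown in this snippet; a minimal stand-in that repeatedly blurs and downsamples an NHWC tensor with a 5-tap binomial kernel (the kernel, the default level count, and the NHWC layout are assumptions, not the original helper):

import tensorflow as tf

def gaussian_pyramid(x, levels=None):
    """Return [x, blur+downsample(x), ...] for a [batch, height, width, channels] tensor."""
    levels = 3 if levels is None else levels            # default level count is an assumption
    k = tf.constant([1., 4., 6., 4., 1.]) / 16.0        # binomial approximation of a Gaussian
    kernel = tf.tile((k[:, None] * k[None, :])[:, :, None, None],
                     [1, 1, x.shape[-1], 1])            # depthwise 5x5 filter
    pyramid = [x]
    for _ in range(levels):
        x = tf.nn.depthwise_conv2d(x, kernel, strides=[1, 2, 2, 1], padding='SAME')
        pyramid.append(x)
    return pyramid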
Example #10
File: utils.py  Project: trungnt13/odin-ai
 def latent_units(model: VariationalModel, valid_ds: tf.data.Dataset):
     weights = model.weights
     Qz = []
     Pz = []
     for x, y in valid_ds.take(10):
         _call(model, x, y, decode=True)
         qz, pz = model.get_latents(return_prior=True)
         Qz.append(as_tuple(qz))
         Pz.append(as_tuple(pz))
     n_latents = len(Qz[0])
     for i in range(n_latents):
         qz: Sequence[Distribution] = [q[i] for q in Qz]
         pz: Sequence[Distribution] = [p[i] for p in Pz]
         # tracking kl
         kld = []
         for q, p in zip(qz, pz):
             q = Normal(loc=q.mean(), scale=q.stddev())
             z = q.sample()
             if isinstance(p, Vamprior):
                 C = p.C
                 p = p.distribution  # [n_components, zdim]
                 p = Normal(loc=p.mean(), scale=p.stddev())
                 kld.append(
                     q.log_prob(z) - (tf.reduce_logsumexp(
                         p.log_prob(tf.expand_dims(z, 1)), 1) - C))
             else:
                 p = Normal(loc=p.mean(), scale=p.stddev())
                 kld.append(q.log_prob(z) - p.log_prob(z))
         kld = tf.reshape(tf.reduce_mean(tf.concat(kld, 0), axis=0),
                          -1).numpy()
         # mean and stddev
         mean = tf.reduce_mean(tf.concat([d.mean() for d in qz], axis=0), 0)
         stddev = tf.reduce_mean(
             tf.concat([d.stddev() for d in qz], axis=0), 0)
         mean = tf.reshape(mean, -1).numpy()
         stddev = tf.reshape(stddev, -1).numpy()
         zdim = mean.shape[0]
         # the figure
         plt.figure(figsize=(12, 5), dpi=50)
         lines = []
         ids = np.argsort(stddev)
         styles = dict(marker='o', markersize=2, linewidth=0, alpha=0.8)
         lines += plt.plot(mean[ids], label='mean', color='r', **styles)
         lines += plt.plot(stddev[ids], label='stddev', color='b', **styles)
         plt.grid(False)
         # show weights if exists
         plt.twinx()
         lines += plt.plot(kld[ids],
                           label='KL(q|p)',
                           linestyle='--',
                           linewidth=1.0,
                           alpha=0.6)
         for w in weights:
             name = w.name
             if w.shape.rank > 0 and w.shape[
                     0] == zdim and '/kernel' in name:
                 w = tf.linalg.norm(tf.reshape(w, (w.shape[0], -1)),
                                    axis=1).numpy()
                 lines += plt.plot(w[ids],
                                   label=name.split(':')[0],
                                   linestyle='--',
                                   alpha=0.6)
         plt.grid(False)
         plt.legend(lines, [ln.get_label() for ln in lines], fontsize=6)
         # save summary
         tf.summary.image(f'z{i}',
                          vs.plot_to_image(plt.gcf()),
                          step=model.step)
         tf.summary.histogram(f'z{i}/mean', mean, step=model.step)
         tf.summary.histogram(f'z{i}/stddev', stddev, step=model.step)
         tf.summary.histogram(f'z{i}/kld', kld, step=model.step)
Example #11
import gym
from tensorflow.keras.activations import relu, softplus
# clip and expand_dims here are the Keras backend ops (expand_dims defaults to axis=-1)
from tensorflow.keras.backend import clip, expand_dims
from tensorflow.keras.layers import Concatenate, Dense, Input
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.models import Model
from tensorflow_probability.python.distributions import Normal

env = gym.make('MountainCarContinuous-v0')
state = env.reset()
state_input = Input(state.shape, name='state_input')
actor_dense1_layer = Dense(4, activation=relu, name='actor_dense1')
actor_dense1 = actor_dense1_layer(state_input)
actor_out_layer = Dense(2, activation=None, name='actor_nn_out')
actor_out = actor_out_layer(actor_dense1)
mean = actor_out[:, 0]
std = softplus(actor_out[:, 1])
dist = Normal(mean, std, name='dist')
action = dist.sample((), name='sample')
action_log_prob = dist.log_prob(action, name='log_prob')
action_log_prob = expand_dims(action_log_prob)
action = expand_dims(action)
action = clip(action, -1.0, 1.0)
critic_concat = Concatenate()([state_input, action])
critic_dense1_layer = Dense(units=8, activation=relu, name='critic_dense1')
critic_dense1 = critic_dense1_layer(critic_concat)
critic_out_layer = Dense(1, activation=None)
critic_out = critic_out_layer(critic_dense1)
actor_critic_model = Model(inputs=[state_input],
                           outputs=[action, action_log_prob, critic_out])
critic_loss_op = MeanSquaredError(name='critic_loss')
actor_vars = actor_dense1_layer.trainable_variables + actor_out_layer.trainable_variables
critic_vars = critic_dense1_layer.trainable_variables + critic_out_layer.trainable_variables
GAMES = 100
LR = 1e-4
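The snippet ends after defining the trainable-variable lists and hyper-parameters; one illustrative interaction step with the model might look as follows (the training loop itself is not in the source and is only hinted at in the comments):

import numpy as np

obs = env.reset()                         # old-style gym API, matching the code above
action_t, log_prob_t, value_t = actor_critic_model(obs[None].astype(np.float32))
next_obs, reward, done, info = env.step(action_t.numpy().reshape(-1))
# GAMES episodes and an optimizer such as tf.keras.optimizers.Adam(LR) applied to
# actor_vars and critic_vars would drive the (omitted) update loop.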
Example #12
 def get_log_prob(self, states, actions):
     mean, std = self._get_dist(states)
     dist = Normal(mean, std)
     log_prob = tf.reduce_sum(dist.log_prob(actions), -1)
     return log_prob
Example #13
 def call(self, states, **kwargs):
     mean, std = self._get_dist(states)
     dist = Normal(mean, std)
     action = dist.sample()
     log_prob = tf.reduce_sum(dist.log_prob(action), -1)
     return action, log_prob
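Both this method and `get_log_prob` above rely on a `_get_dist` helper that is not shown; a plausible sketch, assuming a small MLP policy with a state-independent log-std variable (all attribute names and sizes are illustrative):

import tensorflow as tf

def _get_dist(self, states):
    """Map states to the mean and std of a diagonal Gaussian policy (sketch)."""
    h = self.hidden1(states)        # e.g. Dense(64, activation='tanh')
    h = self.hidden2(h)
    mean = self.mean_head(h)        # Dense(action_dim)
    std = tf.exp(self.log_std)      # trainable tf.Variable of shape [action_dim]
    return mean, std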
Example #14
def evaluate(vae: VariationalAutoencoder,
             ds: ImageDataset,
             expdir: str,
             title: str,
             batch_size: int = 64,
             take_count: int = -1,
             n_images: int = 36,
             seed: int = 1):
    n_rows = int(np.sqrt(n_images))
    is_semi = vae.is_semi_supervised()
    is_hierarchical = vae.is_hierarchical()
    ds_kw = dict(batch_size=batch_size, label_percent=1.0, shuffle=False)
    ## prepare
    rand = np.random.RandomState(seed=seed)
    if not os.path.exists(expdir):
        os.makedirs(expdir)
    ## data for training semi-supervised
    train = ds.create_dataset('train', **ds_kw)
    (llkx_train, llky_train, x_org_train, x_rec_train, y_true_train,
     y_pred_train, z_train, pz_train) = _call(vae,
                                              ds=train,
                                              rand=rand,
                                              take_count=take_count,
                                              n_images=n_images,
                                              verbose=True)
    ## data for testing
    test = ds.create_dataset('test', **ds_kw)
    (llkx_test, llky_test, x_org_test, x_rec_test, y_true_test, y_pred_test,
     z_test, pz_test) = _call(vae,
                              ds=test,
                              rand=rand,
                              take_count=take_count,
                              n_images=n_images,
                              verbose=True)
    # === 0. plotting latent-factor pairs
    for idx, z in enumerate(z_test):
        z = z.mean()
        f = y_true_test
        corr_mat = Correlation.Spearman(z, f)  # [n_latents, n_factors]
        plot_latents_pairs(z, f, corr_mat, ds.labels)
        vs.plot_save(f'{expdir}/latent{idx}_factor.pdf', dpi=100, verbose=True)
    # === 0. latent traverse plot
    x_travs = x_org_test
    if x_travs.ndim == 3:  # grayscale image
        x_travs = np.expand_dims(x_travs, -1)
    else:  # color image
        x_travs = np.transpose(x_travs, (0, 2, 3, 1))
    x_travs = x_travs[rand.permutation(x_travs.shape[0])]
    n_visual_samples = 5
    n_traverse_points = 21
    n_top_latents = 10
    plt.figure(figsize=(8, 3 * n_visual_samples))
    for i in range(n_visual_samples):
        images = vae.sample_traverse(x_travs[i:i + 1],
                                     min_val=-np.min(z_test[0].mean()),
                                     max_val=np.max(z_test[0].mean()),
                                     n_best_latents=n_top_latents,
                                     n_traverse_points=n_traverse_points,
                                     mode='linear')
        images = as_tuple(images)[0]
        images = _prepare_images(images.mean().numpy(), normalize=True)
        vs.plot_images(images,
                       grids=(n_top_latents, n_traverse_points),
                       ax=(n_visual_samples, 1, i + 1))
        if i == 0:
            plt.title('Latents traverse')
    plt.tight_layout()
    vs.plot_save(f'{expdir}/latents_traverse.pdf', dpi=180, verbose=True)
    # === 0. prior sampling plot
    images = as_tuple(vae.sample_observation(n=n_images, seed=seed))[0]
    images = _prepare_images(images.mean().numpy(), normalize=True)
    plt.figure(figsize=(5, 5))
    vs.plot_images(images, grids=(n_rows, n_rows), title='Sampled')
    # === 1. reconstruction plot
    plt.figure(figsize=(15, 15))
    vs.plot_images(x_org_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 1),
                   title='[Train]Original')
    vs.plot_images(x_rec_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 2),
                   title='[Train]Reconstructed')
    vs.plot_images(x_org_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 3),
                   title='[Test]Original')
    vs.plot_images(x_rec_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 4),
                   title='[Test]Reconstructed')
    plt.tight_layout()
    ## prepare the labels
    label_type = ds.label_type
    if label_type == 'categorical':
        labels_name = ds.labels
        true = np.argmax(y_true_test, axis=-1)
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = np.argmax(y_pred_test.mean().numpy(), axis=-1)
            labels_pred = np.array([labels_name[i] for i in pred])
    elif label_type == 'factor':  # dsprites, shapes3d
        labels_name = ['cube', 'cylinder', 'sphere', 'round'] \
          if 'shapes3d' in ds.name else ['square', 'ellipse', 'heart']
        true = y_true_test[:, 2].astype('int32')
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = get_ymean(y_pred_test)[:, 2].astype('int32')
            labels_pred = np.array([labels_name[i] for i in pred])
    else:  # CelebA
        raise NotImplementedError
    ## confusion matrix
    if is_semi:
        plt.figure(figsize=(8, 8))
        acc = accuracy_score(y_true=true, y_pred=pred)
        vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                                 labels=labels_name,
                                 cbar=True,
                                 fontsize=10,
                                 title=f'{title} Acc:{acc:.2f}')
    ## save arrays for later inspections
    with open(f'{expdir}/arrays', 'wb') as f:
        pickle.dump(
            dict(z_train=z_train,
                 y_pred_train=y_pred_train,
                 y_true_train=y_true_train,
                 z_test=z_test,
                 y_pred_test=y_pred_test,
                 y_true_test=y_true_test,
                 labels=labels_name,
                 ds=ds.name,
                 label_type=label_type), f)
    print(f'Exported arrays to "{expdir}/arrays"')
    ## semi-supervised
    z_mean_train = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_train], -1)
    z_mean_test = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_test], -1)
    # === 2. scatter points latents plot
    n_points = 5000
    ids = rand.permutation(len(labels_true))[:n_points]
    Y_true = labels_true[ids]
    Y_pred = labels_pred[ids]
    # tsne plot
    n_latents = 0 if len(z_train) == 1 else len(z_train)
    for name, X in zip(
        ['all'] + [f'latents{i}'
                   for i in range(n_latents)], [z_mean_test[ids]] +
        [z_test[i].mean().numpy()[ids] for i in range(n_latents)]):
        print(f'Plot scatter points for {name}')
        X = X.reshape(X.shape[0], -1)  # flatten to 2D
        X = Pipeline([('zscore', StandardScaler()),
                      ('pca', PCA(min(X.shape[1], 512),
                                  random_state=seed))]).fit_transform(X)
        tsne = DimReduce.TSNE(X, n_components=2, framework='sklearn')
        kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.8)
        plt.figure(figsize=(12, 6))
        vs.plot_scatter(color=Y_true,
                        title=f'[True]{title}-{name}',
                        ax=(1, 2, 1),
                        **kw)
        vs.plot_scatter(color=Y_pred,
                        title=f'[Pred]{title}-{name}',
                        ax=(1, 2, 2),
                        **kw)
    ## save all plot
    vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)

    # === 3. show the latents statistics
    n_latents = len(z_train)
    colors = sns.color_palette(n_colors=len(labels_true))
    styles = dict(grid=False,
                  ticks_off=False,
                  alpha=0.6,
                  xlabel='mean',
                  ylabel='stddev')

    # scatter between latents and labels (assume categorical distribution)
    def _show_latents_labels(Z, Y, title):
        plt.figure(figsize=(5 * n_latents, 5), dpi=150)
        for idx, z in enumerate(Z):
            if len(z.batch_shape) == 0:
                mean = np.repeat(np.expand_dims(z.mean(), 0), Y.shape[0], 0)
                stddev = z.sample(Y.shape[0]) - mean
            else:
                mean = flatten(z.mean())
                stddev = flatten(z.stddev())
            y = np.argmax(Y, axis=-1)
            data = [[], [], []]
            for y_i, c in zip(np.unique(y), colors):
                mask = (y == y_i)
                data[0].append(np.mean(mean[mask], 0))
                data[1].append(np.mean(stddev[mask], 0))
                data[2].append([labels_true[y_i]] * mean.shape[1])
            vs.plot_scatter(
                x=np.concatenate(data[0], 0),
                y=np.concatenate(data[1], 0),
                color=np.concatenate(data[2], 0),
                ax=(1, n_latents, idx + 1),
                size=15 if mean.shape[1] < 128 else 8,
                title=f'[Test-{title}]#{idx} - {mean.shape[1]} (units)',
                **styles)
        plt.tight_layout()

    # simple scatter mean-stddev each latents
    def _show_latents(Z, title):
        plt.figure(figsize=(3.5 * n_latents, 3.5), dpi=150)
        for idx, z in enumerate(Z):
            mean = flatten(z.mean())
            stddev = flatten(z.stddev())
            if mean.ndim == 2:
                mean = np.mean(mean, 0)
                stddev = np.mean(stddev, 0)
            vs.plot_scatter(
                x=mean,
                y=stddev,
                ax=(1, n_latents, idx + 1),
                size=15 if len(mean) < 128 else 8,
                title=f'[Test-{title}]#{idx} - {len(mean)} (units)',
                **styles)

    _show_latents_labels(z_test, y_true_test, 'post')
    _show_latents_labels(pz_test, y_true_test, 'prior')
    _show_latents(z_test, 'post')
    _show_latents(pz_test, 'prior')

    # KL statistics
    vs.plot_figure()
    for idx, (qz, pz) in enumerate(zip(z_test, pz_test)):
        kl = []
        qz = Normal(loc=qz.mean(), scale=qz.stddev(), name=f'posterior{idx}')
        pz = Normal(loc=pz.mean(), scale=pz.stddev(), name=f'prior{idx}')
        for s, e in minibatch(batch_size=8, n=100):
            z = qz.sample(e - s)
            # don't do this on the GPU, it blows up memory!
            kl.append((qz.log_prob(z) - pz.log_prob(z)).numpy())
        kl = np.concatenate(kl, 0)  # (mcmc, batch, event)
        # per sample
        kl_samples = np.sum(kl, as_tuple(list(range(2, kl.ndim))))
        kl_samples = logsumexp(kl_samples, 0)
        plt.subplot(n_latents, 2, idx * 2 + 1)
        sns.histplot(kl_samples, bins=50)
        plt.title(f'Z#{idx} KL per sample (nats)')
        # per latent
        kl_latents = np.mean(flatten(logsumexp(kl, 0)), 0)
        plt.subplot(n_latents, 2, idx * 2 + 2)
        plt.plot(np.sort(kl_latents))
        plt.title(f'Z#{idx} KL per dim (nats)')
    plt.tight_layout()

    vs.plot_save(f'{expdir}/latents.pdf', dpi=180, verbose=True)
Example #15
class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    Note: this is not very numerically stable.
    """
    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        z = self.normal.sample(n)
        if return_pre_tanh_value:
            return tf.math.tanh(z), z
        else:
            return tf.math.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """

        :param value: some value, x
        :param pre_tanh_value: arctanh(x)
        :return:
        """
        if pre_tanh_value is None:
            pre_tanh_value = tf.math.log((1 + value) / (1 - value)) / 2
        return self.normal.log_prob(pre_tanh_value) - tf.math.log(
            1 - value * value + self.epsilon)

    def sample(self, return_pretanh_value=False):
        """
        Gradients will and should *not* pass through this operation.

        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        z = self.normal.sample()

        if return_pretanh_value:
            return tf.math.tanh(z), z
        else:
            return tf.math.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.
        """
        z = (self.normal_mean +
             self.normal_std * Normal(tf.zeros_like(self.normal_mean),
                                      tf.ones_like(self.normal_std)).sample())

        if return_pretanh_value:
            return tf.math.tanh(z), z
        else:
            return tf.math.tanh(z)
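A short usage sketch for SAC-style action sampling, where the per-dimension log-probability correction is summed over action dimensions (shapes are illustrative):

import tensorflow as tf

mean = tf.zeros((4, 2))                       # batch of 4 two-dimensional actions
std = tf.ones((4, 2))
dist = TanhNormal(mean, std)
action, pre_tanh = dist.rsample(return_pretanh_value=True)
log_pi = tf.reduce_sum(dist.log_prob(action, pre_tanh_value=pre_tanh), axis=-1)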