Example #1
0
    def call(self, state):
        """Sample a tanh-squashed Gaussian action and its log-probability.

        The policy network maps ``state`` through two dense layers to a mean
        and a log standard deviation. An action is drawn from the resulting
        ``Normal`` and squashed with ``tanh`` into (-1, 1).

        Args:
            state: batch of observations fed to ``self.dense1_layer``
                (presumably shaped ``(batch, state_dim)`` -- TODO confirm).

        Returns:
            action: ``tanh`` of the sampled pre-squash action.
            log_pi: log-density of ``action`` under the squashed policy,
                summed over axis 1 (the action dimensions) with
                ``keepdims=True``, i.e. shaped ``(batch, 1)``.
        """
        # Get mean and standard deviation from the policy network
        a1 = self.dense1_layer(state)
        a2 = self.dense2_layer(a1)
        mu = self.mean_layer(a2)

        # The stdev must be positive, so the network predicts the
        # unconstrained log-stdev and we exponentiate it.
        log_sigma = self.stdev_layer(a2)
        sigma = tf.exp(log_sigma)

        # Sample the pre-squash action. Assuming Normal is TFP's, its sampler
        # is reparameterized (mu + sigma * noise), so gradients reach mu/sigma.
        dist = Normal(mu, sigma)
        action_ = dist.sample()

        # Apply the tanh squashing to keep the action bounded in (-1, 1)
        action = tf.tanh(action_)

        # Change of variables for a = tanh(u):
        #   log pi(a) = sum_i [log N(u_i) - log(1 - tanh(u_i)^2)]
        # BOTH terms are summed over the action axis. Summing only the
        # correction and broadcasting it against the per-dimension Gaussian
        # log-density would give a (batch, act_dim) tensor instead of the
        # joint (batch, 1) log-probability.
        # log(1 - tanh(u)^2) == 2 * (log 2 - u - softplus(-2u)) exactly; this
        # form stays finite when tanh saturates, so no epsilon is required.
        log_two = 0.6931471805599453
        log_pi_ = tf.reduce_sum(dist.log_prob(action_), axis=1, keepdims=True)
        log_det = 2.0 * (log_two - action_ - tf.math.softplus(-2.0 * action_))
        log_pi = log_pi_ - tf.reduce_sum(log_det, axis=1, keepdims=True)
        return action, log_pi
Example #2
0
 def latent_units(model: VariationalModel, valid_ds: tf.data.Dataset) -> None:
     """Log per-unit statistics of every latent variable as TF summaries.

     Runs ``model`` on up to 10 batches of ``valid_ds``. For each latent
     variable returned by ``model.get_latents(return_prior=True)`` it:

     * estimates KL(q|p) per latent unit from one sample ``z ~ q`` per batch,
       after replacing q and p by Normals with the same mean and stddev. For
       a ``Vamprior`` prior, log p(z) is a logsumexp over the prior's
       components minus ``p.C``. The estimates are averaged over all
       collected samples;
     * averages the posterior mean and stddev over the batches;
     * plots mean and stddev (sorted by stddev), the KL on a secondary y-axis,
       and the row norms of any ``/kernel`` weight whose first dimension
       equals the number of units;
     * writes the figure with ``tf.summary.image`` and the three statistics
       with ``tf.summary.histogram`` under ``z{i}/...`` at ``model.step``.

     Args:
       model: the variational model. Presumably ``_call`` runs the forward
         pass that ``get_latents`` then reads back -- TODO confirm.
       valid_ds: dataset yielding ``(x, y)`` batches.

     Raises:
       IndexError: if ``valid_ds`` yields no batches (``Qz`` stays empty).

     NOTE(review): the summaries presumably rely on a default summary writer
     set by the caller. A new figure is opened for every latent; it is only
     released if ``vs.plot_to_image`` closes it -- confirm to avoid leaking
     figures.
     """
     weights = model.weights
     # Posteriors / priors collected per batch: Qz[b][i] is latent i of batch b.
     Qz = []
     Pz = []
     for x, y in valid_ds.take(10):
         _call(model, x, y, decode=True)
         qz, pz = model.get_latents(return_prior=True)
         Qz.append(as_tuple(qz))
         Pz.append(as_tuple(pz))
     n_latents = len(Qz[0])
     for i in range(n_latents):
         qz: Sequence[Distribution] = [q[i] for q in Qz]
         pz: Sequence[Distribution] = [p[i] for p in Pz]
         # tracking kl
         kld = []
         for q, p in zip(qz, pz):
             # Gaussian approximation of the posterior (same mean/stddev).
             q = Normal(loc=q.mean(), scale=q.stddev())
             z = q.sample()
             if isinstance(p, Vamprior):
                 # presumably C = log(n_components), which turns the
                 # logsumexp into a uniform mixture density -- TODO confirm
                 C = p.C
                 p = p.distribution  # [n_components, zdim]
                 p = Normal(loc=p.mean(), scale=p.stddev())
                 # Broadcast z against every component, then reduce axis 1.
                 kld.append(
                     q.log_prob(z) - (tf.reduce_logsumexp(
                         p.log_prob(tf.expand_dims(z, 1)), 1) - C))
             else:
                 p = Normal(loc=p.mean(), scale=p.stddev())
                 kld.append(q.log_prob(z) - p.log_prob(z))
         # Average over all collected samples -> one KL value per unit.
         kld = tf.reshape(tf.reduce_mean(tf.concat(kld, 0), axis=0),
                          -1).numpy()
         # mean and stddev
         mean = tf.reduce_mean(tf.concat([d.mean() for d in qz], axis=0), 0)
         stddev = tf.reduce_mean(
             tf.concat([d.stddev() for d in qz], axis=0), 0)
         mean = tf.reshape(mean, -1).numpy()
         stddev = tf.reshape(stddev, -1).numpy()
         zdim = mean.shape[0]
         # the figure
         plt.figure(figsize=(12, 5), dpi=50)
         lines = []
         # Units are ordered by increasing posterior stddev in every curve.
         ids = np.argsort(stddev)
         styles = dict(marker='o', markersize=2, linewidth=0, alpha=0.8)
         lines += plt.plot(mean[ids], label='mean', color='r', **styles)
         lines += plt.plot(stddev[ids], label='stddev', color='b', **styles)
         plt.grid(False)
         # KL goes on a secondary y-axis (twinx) since its scale differs.
         plt.twinx()
         lines += plt.plot(kld[ids],
                           label='KL(q|p)',
                           linestyle='--',
                           linewidth=1.0,
                           alpha=0.6)
         # show weights if exists: row norms of kernels whose first axis
         # matches the number of units (presumably layers fed by this latent)
         for w in weights:
             name = w.name
             if w.shape.rank > 0 and w.shape[
                     0] == zdim and '/kernel' in name:
                 w = tf.linalg.norm(tf.reshape(w, (w.shape[0], -1)),
                                    axis=1).numpy()
                 lines += plt.plot(w[ids],
                                   label=name.split(':')[0],
                                   linestyle='--',
                                   alpha=0.6)
         plt.grid(False)
         # One legend for the lines of both y-axes.
         plt.legend(lines, [ln.get_label() for ln in lines], fontsize=6)
         # save summary
         tf.summary.image(f'z{i}',
                          vs.plot_to_image(plt.gcf()),
                          step=model.step)
         tf.summary.histogram(f'z{i}/mean', mean, step=model.step)
         tf.summary.histogram(f'z{i}/stddev', stddev, step=model.step)
         tf.summary.histogram(f'z{i}/kld', kld, step=model.step)
Example #3
0
from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.models import Model
from tensorflow_probability.python.distributions import Normal

# NOTE(review): gym, relu, softplus, expand_dims and clip are not imported in
# the lines above; they are presumably imported elsewhere -- TODO confirm.
# Create the environment and read one observation to size the model input.
env = gym.make('MountainCarContinuous-v0')
# NOTE(review): newer gym releases return (obs, info) from reset(); this code
# needs reset() to return the observation array itself -- confirm the version.
state = env.reset()
state_input = Input(state.shape, name='state_input')
# Actor: one hidden layer and two outputs per state. Column 0 is the Gaussian
# mean; column 1, passed through softplus to keep it positive, is the stddev.
actor_dense1_layer = Dense(4, activation=relu, name='actor_dense1')
actor_dense1 = actor_dense1_layer(state_input)
actor_out_layer = Dense(2, activation=None, name='actor_nn_out')
actor_out = actor_out_layer(actor_dense1)
mean = actor_out[:, 0]
std = softplus(actor_out[:, 1])
dist = Normal(mean, std, name='dist')
# Sample an action and keep its log-probability as a model output.
action = dist.sample((), name='sample')
action_log_prob = dist.log_prob(action, name='log_prob')
# Re-add a trailing feature axis so the action can be concatenated with the
# state below (assumes these expand_dims default to axis=-1 -- TODO confirm).
action_log_prob = expand_dims(action_log_prob)
action = expand_dims(action)
# Clip the action to [-1, 1]. action_log_prob refers to the unclipped sample.
action = clip(action, -1.0, 1.0)
# Critic: scores the (state, clipped action) pair with a single output.
critic_concat = Concatenate()([state_input, action])
critic_dense1_layer = Dense(units=8, activation=relu, name='critic_dense1')
critic_dense1 = critic_dense1_layer(critic_concat)
critic_out_layer = Dense(1, activation=None)
critic_out = critic_out_layer(critic_dense1)
# One model producing (action, action log-prob, critic value) for a state.
actor_critic_model = Model(inputs=[state_input],
                           outputs=[action, action_log_prob, critic_out])
critic_loss_op = MeanSquaredError(name='critic_loss')
# Separate variable lists, presumably so actor and critic get separate
# gradient updates later in the script.
actor_vars = actor_dense1_layer.trainable_variables + actor_out_layer.trainable_variables
critic_vars = critic_dense1_layer.trainable_variables + critic_out_layer.trainable_variables
# Presumably the number of episodes to play (used past this chunk).
GAMES = 100
Example #4
0
 def call(self, states, **kwargs):
     """Draw one action per state and return its joint log-probability.

     ``self._get_dist`` supplies the location and scale of an elementwise
     ``Normal``. The per-dimension log-densities of the drawn action are
     summed over the last axis. Extra keyword arguments are accepted and
     ignored.
     """
     loc, scale = self._get_dist(states)
     policy = Normal(loc, scale)
     sampled_action = policy.sample()
     per_dim_log_prob = policy.log_prob(sampled_action)
     return sampled_action, tf.reduce_sum(per_dim_log_prob, axis=-1)
Example #5
0
def evaluate(vae: VariationalAutoencoder,
             ds: ImageDataset,
             expdir: str,
             title: str,
             batch_size: int = 64,
             take_count: int = -1,
             n_images: int = 36,
             seed: int = 1):
    """Evaluate a trained VAE and export plots and arrays into ``expdir``.

    Runs ``_call`` on the 'train' and 'test' splits of ``ds``, then writes:

    * ``latent{idx}_factor.pdf``: latent/label pair plots with Spearman
      correlations, one file per test latent;
    * ``latents_traverse.pdf``: latent traversals of a few test images;
    * ``analysis.pdf``: saved after the prior-sample, reconstruction,
      confusion-matrix (semi-supervised only) and t-SNE figures are drawn;
    * ``arrays``: a pickle of latents, labels and predictions;
    * ``latents.pdf``: saved after the mean/stddev scatters and KL plots.

    (Which figures end up in each PDF depends on ``vs.plot_save``; presumably
    it saves and closes every open figure -- TODO confirm.)

    Args:
      vae: the trained model.
      ds: dataset providing the splits, ``labels``, ``label_type`` and ``name``.
      expdir: output directory; created if missing.
      title: prefix used in plot titles.
      batch_size: batch size for ``ds.create_dataset``.
      take_count: forwarded to ``_call`` (presumably the number of batches,
        -1 meaning all -- TODO confirm).
      n_images: number of images in the grids (``int(sqrt(n_images))`` per side).
      seed: seed for the numpy RandomState, prior sampling and PCA.

    Raises:
      NotImplementedError: if ``ds.label_type`` is neither 'categorical' nor
        'factor'.
    """
    n_rows = int(np.sqrt(n_images))
    is_semi = vae.is_semi_supervised()
    ds_kw = dict(batch_size=batch_size, label_percent=1.0, shuffle=False)
    ## prepare
    # One RandomState drives all subsampling, so results are reproducible.
    rand = np.random.RandomState(seed=seed)
    os.makedirs(expdir, exist_ok=True)
    ## data for training semi-supervised
    train = ds.create_dataset('train', **ds_kw)
    (llkx_train, llky_train, x_org_train, x_rec_train, y_true_train,
     y_pred_train, z_train, pz_train) = _call(vae,
                                              ds=train,
                                              rand=rand,
                                              take_count=take_count,
                                              n_images=n_images,
                                              verbose=True)
    ## data for testing
    test = ds.create_dataset('test', **ds_kw)
    (llkx_test, llky_test, x_org_test, x_rec_test, y_true_test, y_pred_test,
     z_test, pz_test) = _call(vae,
                              ds=test,
                              rand=rand,
                              take_count=take_count,
                              n_images=n_images,
                              verbose=True)
    # === 0. plotting latent-factor pairs
    for idx, z in enumerate(z_test):
        z = z.mean()
        f = y_true_test
        corr_mat = Correlation.Spearman(z, f)  # [n_latents, n_factors]
        plot_latents_pairs(z, f, corr_mat, ds.labels)
        vs.plot_save(f'{expdir}/latent{idx}_factor.pdf', dpi=100, verbose=True)
    # === 0. latent traverse plot
    x_travs = x_org_test
    if x_travs.ndim == 3:  # grayscale image
        x_travs = np.expand_dims(x_travs, -1)
    else:  # color image
        x_travs = np.transpose(x_travs, (0, 2, 3, 1))
    x_travs = x_travs[rand.permutation(x_travs.shape[0])]
    n_visual_samples = 5
    n_traverse_points = 21
    n_top_latents = 10
    # Traverse over [min, max] of the first latent's posterior means on the
    # test set. The lower bound is the minimum itself, not its negation, and
    # the range is image-independent, so it is computed once.
    z0_mean = z_test[0].mean()
    traverse_min = np.min(z0_mean)
    traverse_max = np.max(z0_mean)
    plt.figure(figsize=(8, 3 * n_visual_samples))
    for i in range(n_visual_samples):
        images = vae.sample_traverse(x_travs[i:i + 1],
                                     min_val=traverse_min,
                                     max_val=traverse_max,
                                     n_best_latents=n_top_latents,
                                     n_traverse_points=n_traverse_points,
                                     mode='linear')
        images = as_tuple(images)[0]
        images = _prepare_images(images.mean().numpy(), normalize=True)
        vs.plot_images(images,
                       grids=(n_top_latents, n_traverse_points),
                       ax=(n_visual_samples, 1, i + 1))
        if i == 0:
            plt.title('Latents traverse')
    plt.tight_layout()
    vs.plot_save(f'{expdir}/latents_traverse.pdf', dpi=180, verbose=True)
    # === 0. prior sampling plot
    images = as_tuple(vae.sample_observation(n=n_images, seed=seed))[0]
    images = _prepare_images(images.mean().numpy(), normalize=True)
    plt.figure(figsize=(5, 5))
    vs.plot_images(images, grids=(n_rows, n_rows), title='Sampled')
    # === 1. reconstruction plot
    plt.figure(figsize=(15, 15))
    vs.plot_images(x_org_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 1),
                   title='[Train]Original')
    vs.plot_images(x_rec_train,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 2),
                   title='[Train]Reconstructed')
    vs.plot_images(x_org_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 3),
                   title='[Test]Original')
    vs.plot_images(x_rec_test,
                   grids=(n_rows, n_rows),
                   ax=(2, 2, 4),
                   title='[Test]Reconstructed')
    plt.tight_layout()
    ## prepare the labels
    label_type = ds.label_type
    if label_type == 'categorical':
        labels_name = ds.labels
        true = np.argmax(y_true_test, axis=-1)
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = np.argmax(y_pred_test.mean().numpy(), axis=-1)
            labels_pred = np.array([labels_name[i] for i in pred])
    elif label_type == 'factor':  # dsprites, shapes3d
        # Column 2 of the factors is used as the (integer) shape label.
        labels_name = ['cube', 'cylinder', 'sphere', 'round'] \
          if 'shapes3d' in ds.name else ['square', 'ellipse', 'heart']
        true = y_true_test[:, 2].astype('int32')
        labels_true = np.array([labels_name[i] for i in true])
        labels_pred = labels_true
        if is_semi:
            pred = get_ymean(y_pred_test)[:, 2].astype('int32')
            labels_pred = np.array([labels_name[i] for i in pred])
    else:  # CelebA
        raise NotImplementedError
    ## confusion matrix
    if is_semi:
        plt.figure(figsize=(8, 8))
        acc = accuracy_score(y_true=true, y_pred=pred)
        vs.plot_confusion_matrix(cm=confusion_matrix(y_true=true, y_pred=pred),
                                 labels=labels_name,
                                 cbar=True,
                                 fontsize=10,
                                 title=f'{title} Acc:{acc:.2f}')
    ## save arrays for later inspections
    with open(f'{expdir}/arrays', 'wb') as f:
        pickle.dump(
            dict(z_train=z_train,
                 y_pred_train=y_pred_train,
                 y_true_train=y_true_train,
                 z_test=z_test,
                 y_pred_test=y_pred_test,
                 y_true_test=y_true_test,
                 labels=labels_name,
                 ds=ds.name,
                 label_type=label_type), f)
    print(f'Exported arrays to "{expdir}/arrays"')
    ## semi-supervised
    # Per-sample posterior means of all latents, flattened and concatenated.
    z_mean_train = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_train], -1)
    z_mean_test = np.concatenate(
        [z.mean().numpy().reshape(z.batch_shape[0], -1) for z in z_test], -1)
    # === 2. scatter points latents plot
    n_points = 5000
    ids = rand.permutation(len(labels_true))[:n_points]
    Y_true = labels_true[ids]
    Y_pred = labels_pred[ids]
    # tsne plot: one for all latents jointly, plus one per latent when there
    # is more than one latent variable.
    n_latents = 0 if len(z_train) == 1 else len(z_train)
    for name, X in zip(
        ['all'] + [f'latents{i}'
                   for i in range(n_latents)], [z_mean_test[ids]] +
        [z_test[i].mean().numpy()[ids] for i in range(n_latents)]):
        print(f'Plot scatter points for {name}')
        X = X.reshape(X.shape[0], -1)  # flatten to 2D
        X = Pipeline([('zscore', StandardScaler()),
                      ('pca', PCA(min(X.shape[1], 512),
                                  random_state=seed))]).fit_transform(X)
        tsne = DimReduce.TSNE(X, n_components=2, framework='sklearn')
        kw = dict(x=tsne[:, 0], y=tsne[:, 1], grid=False, size=12.0, alpha=0.8)
        plt.figure(figsize=(12, 6))
        vs.plot_scatter(color=Y_true,
                        title=f'[True]{title}-{name}',
                        ax=(1, 2, 1),
                        **kw)
        vs.plot_scatter(color=Y_pred,
                        title=f'[Pred]{title}-{name}',
                        ax=(1, 2, 2),
                        **kw)
    ## save all plot
    vs.plot_save(f'{expdir}/analysis.pdf', dpi=180, verbose=True)

    # === 3. show the latents statistics
    n_latents = len(z_train)
    # One color per label column/class. The palette only bounds the zip over
    # unique labels below, so it needs len(ds.labels) entries, not one per
    # test sample.
    colors = sns.color_palette(n_colors=len(ds.labels))
    styles = dict(grid=False,
                  ticks_off=False,
                  alpha=0.6,
                  xlabel='mean',
                  ylabel='stddev')

    # scatter between latents and labels (assume categorical distribution)
    def _show_latents_labels(Z, Y, title):
        plt.figure(figsize=(5 * n_latents, 5), dpi=150)
        for idx, z in enumerate(Z):
            if len(z.batch_shape) == 0:
                mean = np.repeat(np.expand_dims(z.mean(), 0), Y.shape[0], 0)
                stddev = z.sample(Y.shape[0]) - mean
            else:
                mean = flatten(z.mean())
                stddev = flatten(z.stddev())
            y = np.argmax(Y, axis=-1)
            data = [[], [], []]
            for y_i, c in zip(np.unique(y), colors):
                mask = (y == y_i)
                data[0].append(np.mean(mean[mask], 0))
                data[1].append(np.mean(stddev[mask], 0))
                # y_i is a column index of Y, so name it via ds.labels (the
                # column names of y_true, as used for plot_latents_pairs).
                # labels_true holds per-sample labels and is the wrong table.
                data[2].append([ds.labels[y_i]] * mean.shape[1])
            vs.plot_scatter(
                x=np.concatenate(data[0], 0),
                y=np.concatenate(data[1], 0),
                color=np.concatenate(data[2], 0),
                ax=(1, n_latents, idx + 1),
                size=15 if mean.shape[1] < 128 else 8,
                title=f'[Test-{title}]#{idx} - {mean.shape[1]} (units)',
                **styles)
        plt.tight_layout()

    # simple scatter mean-stddev each latents
    def _show_latents(Z, title):
        plt.figure(figsize=(3.5 * n_latents, 3.5), dpi=150)
        for idx, z in enumerate(Z):
            mean = flatten(z.mean())
            stddev = flatten(z.stddev())
            if mean.ndim == 2:
                mean = np.mean(mean, 0)
                stddev = np.mean(stddev, 0)
            vs.plot_scatter(
                x=mean,
                y=stddev,
                ax=(1, n_latents, idx + 1),
                size=15 if len(mean) < 128 else 8,
                title=f'[Test-{title}]#{idx} - {len(mean)} (units)',
                **styles)

    _show_latents_labels(z_test, y_true_test, 'post')
    _show_latents_labels(pz_test, y_true_test, 'prior')
    _show_latents(z_test, 'post')
    _show_latents(pz_test, 'prior')

    # KL statistics
    vs.plot_figure()
    for idx, (qz, pz) in enumerate(zip(z_test, pz_test)):
        kl = []
        qz = Normal(loc=qz.mean(), scale=qz.stddev(), name=f'posterior{idx}')
        pz = Normal(loc=pz.mean(), scale=pz.stddev(), name=f'prior{idx}')
        for s, e in minibatch(batch_size=8, n=100):
            z = qz.sample(e - s)
            # don't do this in GPU, it explodes!
            kl.append((qz.log_prob(z) - pz.log_prob(z)).numpy())
        kl = np.concatenate(kl, 0)  # (mcmc, batch, event)
        # KL(q|p) = E_q[log q(z) - log p(z)]: its Monte-Carlo estimate is the
        # MEAN of the log-ratio over the samples on axis 0 (a logsumexp over
        # that axis does not estimate the KL).
        # per sample
        kl_samples = np.sum(kl, as_tuple(list(range(2, kl.ndim))))
        kl_samples = np.mean(kl_samples, 0)
        plt.subplot(n_latents, 2, idx * 2 + 1)
        sns.histplot(kl_samples, bins=50)
        plt.title(f'Z#{idx} KL per sample (nats)')
        # per latent
        kl_latents = np.mean(flatten(np.mean(kl, 0)), 0)
        plt.subplot(n_latents, 2, idx * 2 + 2)
        plt.plot(np.sort(kl_latents))
        plt.title(f'Z#{idx} KL per dim (nats)')
    plt.tight_layout()

    vs.plot_save(f'{expdir}/latents.pdf', dpi=180, verbose=True)
Example #6
0
class TanhNormal(Distribution):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)

    ``sample`` and ``sample_n`` return values that gradients do not flow
    through; ``rsample`` returns reparameterized (differentiable) samples.

    Note: this is not very numerically stable.
    """
    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = Normal(normal_mean, normal_std)
        self.epsilon = epsilon

    def sample_n(self, n, return_pre_tanh_value=False):
        """
        Draw ``n`` samples (leading sample dimension of size ``n``).

        Like ``sample``, gradients are blocked. TFP's Normal sampler is
        reparameterized, so without ``tf.stop_gradient`` the draw would be
        differentiable w.r.t. the mean and std.

        :param n: number of samples
        :param return_pre_tanh_value: also return the pre-tanh sample z
        :return: tanh(z), or (tanh(z), z)
        """
        z = tf.stop_gradient(self.normal.sample(n))
        if return_pre_tanh_value:
            return tf.math.tanh(z), z
        return tf.math.tanh(z)

    def log_prob(self, value, pre_tanh_value=None):
        """
        Elementwise log-density of ``value`` by change of variables:
        log p(x) = log N(arctanh(x)) - log(1 - x^2 + epsilon).
        The result is not summed over dimensions.

        :param value: some value, x (expected in (-1, 1))
        :param pre_tanh_value: arctanh(x); pass it when known, since inverting
            tanh is inaccurate near +/-1
        :return: log-probability with the broadcast shape of the inputs
        """
        if pre_tanh_value is None:
            pre_tanh_value = tf.math.atanh(value)
        return self.normal.log_prob(pre_tanh_value) - tf.math.log(
            1 - value * value + self.epsilon)

    def sample(self, return_pretanh_value=False):
        """
        Gradients will and should *not* pass through this operation.

        TFP's Normal sampler is reparameterized, so the draw is wrapped in
        ``tf.stop_gradient`` to honour that contract. Use ``rsample`` when
        gradients are needed.

        See https://github.com/pytorch/pytorch/issues/4620 for discussion.
        """
        z = tf.stop_gradient(self.normal.sample())

        if return_pretanh_value:
            return tf.math.tanh(z), z
        return tf.math.tanh(z)

    def rsample(self, return_pretanh_value=False):
        """
        Sampling in the reparameterization case.

        z = mean + std * noise with noise ~ N(0, 1), so gradients flow to
        ``normal_mean`` and ``normal_std``.
        """
        noise = Normal(tf.zeros_like(self.normal_mean),
                       tf.ones_like(self.normal_std)).sample()
        z = self.normal_mean + self.normal_std * noise

        if return_pretanh_value:
            return tf.math.tanh(z), z
        return tf.math.tanh(z)