Example #1
import distrax


def get_squashed_gaussian_dist(mu, sigma, action_spec=None):
    """Returns a tanh-squashed diagonal Gaussian, rescaled to the action bounds if given."""
    if action_spec is not None:
        # distrax.Chain applies right to left: tanh(x) in (-1, 1) is shifted
        # to (0, 2), then scaled and shifted into [minimum, maximum].
        scale = 0.5 * (action_spec.maximum - action_spec.minimum)
        shift = action_spec.minimum
        bijector = distrax.Chain([
            distrax.ScalarAffine(shift=shift, scale=scale),
            distrax.ScalarAffine(shift=1.0),
            distrax.Tanh(),
        ])
    else:
        bijector = distrax.Tanh()
    # mu_activation and sigma_activation are defined elsewhere in the source module.
    return distrax.Transformed(
        distribution=distrax.MultivariateNormalDiag(
            loc=mu_activation(mu), scale_diag=sigma_activation(sigma)),
        bijector=distrax.Block(bijector, ndims=1))
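
A minimal usage sketch, with identity/softplus stand-ins for the module's `mu_activation`/`sigma_activation` and a hypothetical action spec exposing `minimum`/`maximum` fields:

import jax
import jax.numpy as jnp
import distrax

mu_activation = lambda x: x          # stand-in; the real module defines this
sigma_activation = jax.nn.softplus   # stand-in; the real module defines this

class DummySpec:
    minimum = jnp.array([-2.0, -2.0])
    maximum = jnp.array([2.0, 2.0])

dist = get_squashed_gaussian_dist(jnp.zeros(2), jnp.zeros(2), DummySpec())
action, log_prob = dist.sample_and_log_prob(seed=jax.random.PRNGKey(0))
# `action` lies in [-2, 2]^2; `log_prob` accounts for the tanh/affine Jacobian.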
Example #2
import warnings

import chex
import distrax

Array = chex.Array
Numeric = chex.Numeric


def multivariate_normal_kl_divergence(
    mu_0: Array,
    sigma_0: Numeric,
    mu_1: Array,
    sigma_1: Numeric,
) -> Array:
  """Computes the KL divergence between two Gaussian distributions with diagonal covariance matrices.

  Args:
    mu_0: array-like of mean values for policy 0
    sigma_0: array-like of std values for policy 0
    mu_1: array-like of mean values for policy 1
    sigma_1: array-like of std values for policy 1

  Returns:
    the KL divergence between the distributions.
  """
  warnings.warn(
      "Rlax multivariate_normal_kl_divergence will be deprecated. "
      "Please use distrax.MultivariateNormalDiag.kl_divergence instead.",
      PendingDeprecationWarning,
      stacklevel=2)
  return distrax.MultivariateNormalDiag(mu_0, sigma_0).kl_divergence(
      distrax.MultivariateNormalDiag(mu_1, sigma_1))
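
As the warning suggests, the same quantity comes straight from distrax; a quick sketch with made-up inputs, checked against the closed-form diagonal-Gaussian KL:

import jax.numpy as jnp
import distrax

mu_0, sigma_0 = jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0])
mu_1, sigma_1 = jnp.array([0.5, -0.5]), jnp.array([2.0, 0.5])

kl = distrax.MultivariateNormalDiag(mu_0, sigma_0).kl_divergence(
    distrax.MultivariateNormalDiag(mu_1, sigma_1))

# Closed form: sum_i [log(s1/s0) + (s0^2 + (m0 - m1)^2) / (2 s1^2) - 1/2]
analytic = jnp.sum(jnp.log(sigma_1 / sigma_0)
                   + (sigma_0**2 + (mu_0 - mu_1)**2) / (2.0 * sigma_1**2)
                   - 0.5)
# `kl` and `analytic` agree up to float32 rounding.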
Example #3
    def loss_fn(params: hk.Params, rng_key: PRNGKey,
                batch: Batch) -> jnp.ndarray:
        """Loss = -ELBO, where ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z))."""

        outputs: VAEOutput = model.apply(params, rng_key, batch["image"])

        # p(z) = N(0, I)
        prior_z = distrax.MultivariateNormalDiag(
            loc=jnp.zeros((latent_size,)),
            scale_diag=jnp.ones((latent_size,)))

        log_likelihood = outputs.likelihood_distrib.log_prob(batch["image"])
        kl = outputs.variational_distrib.kl_divergence(prior_z)
        elbo = log_likelihood - kl

        return -jnp.mean(elbo)
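
A sketch of how a loss like this is typically wired into an update step, following the standard Haiku/Optax pattern (the `optimizer` here is a hypothetical choice):

import jax
import optax

optimizer = optax.adam(1e-3)  # hypothetical choice

@jax.jit
def update(params, rng_key, opt_state, batch):
    grads = jax.grad(loss_fn)(params, rng_key, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), new_opt_state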
Example #4
    def __call__(self, x: jnp.ndarray) -> VAEOutput:
        x = x.astype(jnp.float32)

        # q(z|x) = N(mean(x), covariance(x))
        mean, stddev = Encoder(self._hidden_size, self._latent_size)(x)
        variational_distrib = distrax.MultivariateNormalDiag(
            loc=mean, scale_diag=stddev)
        z = variational_distrib.sample(seed=hk.next_rng_key())

        # p(x|z) = \Prod Bernoulli(logits(z))
        logits = Decoder(self._hidden_size, self._output_shape)(z)
        likelihood_distrib = distrax.Independent(
            distrax.Bernoulli(logits=logits),
            reinterpreted_batch_ndims=len(self._output_shape))  # 3 non-batch dims

        # Generate images from the likelihood
        image = likelihood_distrib.sample(seed=hk.next_rng_key())

        return VAEOutput(variational_distrib, likelihood_distrib, image)
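
Because the forward pass calls `hk.next_rng_key()`, the module has to run inside a Haiku transform. A minimal sketch, assuming the enclosing class is named `VAE` with these (hypothetical) constructor arguments and MNIST-like shapes:

import haiku as hk
import jax
import jax.numpy as jnp

model = hk.transform(
    lambda x: VAE(hidden_size=512, latent_size=10, output_shape=(28, 28, 1))(x))

rng = jax.random.PRNGKey(42)
dummy_batch = jnp.zeros((8, 28, 28, 1))
params = model.init(rng, dummy_batch)
outputs = model.apply(params, rng, dummy_batch)  # a VAEOutput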
Example #5
def kl_to_standard_normal_fn(mu: Array, sigma: Array):
    # `.distribution` is the base (pre-squashing) Gaussian of the transformed
    # distribution, so the KL is computed in the unsquashed space.
    return get_squashed_gaussian_dist(mu, sigma).distribution.kl_divergence(
        distrax.MultivariateNormalDiag(jnp.zeros_like(mu), jnp.ones_like(mu)))
Example #6
def kl_fn(mu_0: Array, sigma_0: Numeric, mu_1: Array, sigma_1: Numeric):
    return distrax.MultivariateNormalDiag(
        mu_0, jnp.ones_like(mu_0) * sigma_0).kl_divergence(
            distrax.MultivariateNormalDiag(mu_1, jnp.ones_like(mu_1) * sigma_1))
Example #7
def kl_to_standard_normal_fn(mu: Array, sigma: Array = sigma):
    # `sigma` defaults to the value captured from the enclosing scope.
    return distrax.MultivariateNormalDiag(
        mu, jnp.ones_like(mu) * sigma).kl_divergence(
            distrax.MultivariateNormalDiag(jnp.zeros_like(mu), jnp.ones_like(mu)))
Example #8
def entropy_fn(mu: Array, sigma: Array = sigma):
    return distrax.MultivariateNormalDiag(
        mu, jnp.ones_like(mu) * sigma).entropy()
Example #9
def logprob_fn(sample: Array, mu: Array, sigma: Array = sigma):
    return distrax.MultivariateNormalDiag(
        mu, jnp.ones_like(mu) * sigma).log_prob(sample)
Example #10
def sample_fn(key: Array, mu: Array, sigma: Array = sigma):
    return distrax.MultivariateNormalDiag(
        mu, jnp.ones_like(mu) * sigma).sample(seed=key)
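
Examples #7-#10 all close over a `sigma` bound in the enclosing scope and broadcast it against `mu` with `jnp.ones_like`. A sketch of batching these helpers with `jax.vmap`, assuming `sigma` was a scalar when the closures were built:

import jax
import jax.numpy as jnp

mus = jnp.zeros((32, 4))                            # batch of 32 mean vectors
keys = jax.random.split(jax.random.PRNGKey(0), 32)

samples = jax.vmap(sample_fn)(keys, mus)            # one draw per batch element
logps = jax.vmap(logprob_fn)(samples, mus)          # shape (32,)
entropies = jax.vmap(entropy_fn)(mus)               # shape (32,)
kls = jax.vmap(kl_to_standard_normal_fn)(mus)       # shape (32,)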