def get_squashed_gaussian_dist(mu, sigma, action_spec=None):
  if action_spec is not None:
    scale = 0.5 * (action_spec.maximum - action_spec.minimum)
    shift = action_spec.minimum
    bijector = distrax.Chain([
        distrax.ScalarAffine(shift=shift, scale=scale),
        distrax.ScalarAffine(shift=1.0),
        distrax.Tanh()
    ])
  else:
    bijector = distrax.Tanh()
  return distrax.Transformed(
      distribution=distrax.MultivariateNormalDiag(
          loc=mu_activation(mu), scale_diag=sigma_activation(sigma)),
      bijector=distrax.Block(bijector, ndims=1))

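# A minimal usage sketch of the squashed-Gaussian transform built above,
# assuming identity activations in place of `mu_activation` / `sigma_activation`
# and a simple stand-in object for the bounded action spec (both are
# assumptions; in the source they come from the enclosing scope).
import types

import distrax
import jax
import jax.numpy as jnp

spec = types.SimpleNamespace(minimum=jnp.array([-2.0, 0.0]),
                             maximum=jnp.array([2.0, 1.0]))
mu = jnp.zeros(2)
sigma = jnp.ones(2)

scale = 0.5 * (spec.maximum - spec.minimum)
bijector = distrax.Block(
    distrax.Chain([
        # Chain applies right-to-left: tanh -> [-1, 1], shift -> [0, 2],
        # then the affine rescales into [minimum, maximum].
        distrax.ScalarAffine(shift=spec.minimum, scale=scale),
        distrax.ScalarAffine(shift=1.0),
        distrax.Tanh(),
    ]),
    ndims=1)
dist = distrax.Transformed(
    distrax.MultivariateNormalDiag(loc=mu, scale_diag=sigma), bijector)

action = dist.sample(seed=jax.random.PRNGKey(0))  # lies in [minimum, maximum]
log_prob = dist.log_prob(action)
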
def multivariate_normal_kl_divergence(
    mu_0: Array,
    sigma_0: Numeric,
    mu_1: Array,
    sigma_1: Numeric,
) -> Array:
  """Computes the KL between two Gaussian distributions with diagonal covariance matrices.

  Args:
    mu_0: array-like of mean values for policy 0.
    sigma_0: array-like of std values for policy 0.
    mu_1: array-like of mean values for policy 1.
    sigma_1: array-like of std values for policy 1.

  Returns:
    The KL divergence between the two distributions.
  """
  warnings.warn(
      "Rlax multivariate_normal_kl_divergence will be deprecated. "
      "Please use distrax.MultivariateNormalDiag.kl_divergence instead.",
      PendingDeprecationWarning, stacklevel=2)
  return distrax.MultivariateNormalDiag(mu_0, sigma_0).kl_divergence(
      distrax.MultivariateNormalDiag(mu_1, sigma_1))

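# A hedged sanity check for the helper above: compare against the closed-form
# KL between diagonal Gaussians. The values below are illustrative only.
import jax.numpy as jnp

mu_0, sigma_0 = jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0])
mu_1, sigma_1 = jnp.array([1.0, -1.0]), jnp.array([0.5, 2.0])

kl = multivariate_normal_kl_divergence(mu_0, sigma_0, mu_1, sigma_1)

# KL(N0 || N1) = 0.5 * sum( sigma_0^2/sigma_1^2 + (mu_1 - mu_0)^2/sigma_1^2
#                           - 1 + 2 * log(sigma_1/sigma_0) )
kl_manual = 0.5 * jnp.sum(
    (sigma_0 / sigma_1) ** 2 + ((mu_1 - mu_0) / sigma_1) ** 2 - 1.0
    + 2.0 * jnp.log(sigma_1 / sigma_0))
# kl and kl_manual should agree up to numerical precision.
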
def loss_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
  """Loss = -ELBO, where ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z))."""
  outputs: VAEOutput = model.apply(params, rng_key, batch["image"])
  # p(z) = N(0, I)
  prior_z = distrax.MultivariateNormalDiag(
      loc=jnp.zeros((latent_size,)),
      scale_diag=jnp.ones((latent_size,)))
  log_likelihood = outputs.likelihood_distrib.log_prob(batch["image"])
  kl = outputs.variational_distrib.kl_divergence(prior_z)
  elbo = log_likelihood - kl
  return -jnp.mean(elbo)

def __call__(self, x: jnp.ndarray) -> VAEOutput:
  x = x.astype(jnp.float32)

  # q(z|x) = N(mean(x), covariance(x))
  mean, stddev = Encoder(self._hidden_size, self._latent_size)(x)
  variational_distrib = distrax.MultivariateNormalDiag(
      loc=mean, scale_diag=stddev)
  z = variational_distrib.sample(seed=hk.next_rng_key())

  # p(x|z) = \Prod Bernoulli(logits(z))
  logits = Decoder(self._hidden_size, self._output_shape)(z)
  likelihood_distrib = distrax.Independent(
      distrax.Bernoulli(logits=logits),
      reinterpreted_batch_ndims=len(self._output_shape))  # 3 non-batch dims

  # Generate images from the likelihood.
  image = likelihood_distrib.sample(seed=hk.next_rng_key())

  return VAEOutput(variational_distrib, likelihood_distrib, image)

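# A rough end-to-end sketch of wiring the VAE module and loss_fn together with
# Haiku. The class name `VAE`, its constructor signature, the hyperparameter
# values and the batch shape are all assumptions, not taken from the snippets
# above.
import haiku as hk
import jax
import jax.numpy as jnp

hidden_size, latent_size, output_shape = 512, 10, (28, 28, 1)

def forward(x):
  # Assumed constructor signature: VAE(hidden_size, latent_size, output_shape).
  return VAE(hidden_size, latent_size, output_shape)(x)

model = hk.transform(forward)

rng = jax.random.PRNGKey(42)
dummy_batch = {"image": jnp.zeros((8,) + output_shape)}
params = model.init(rng, dummy_batch["image"])

loss = loss_fn(params, rng, dummy_batch)              # scalar -ELBO
grads = jax.grad(loss_fn)(params, rng, dummy_batch)   # gradients w.r.t. params
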
def kl_to_standard_normal_fn(mu: Array, sigma: Array):
  return get_squashed_gaussian_dist(mu, sigma).distribution.kl_divergence(
      distrax.MultivariateNormalDiag(
          jnp.zeros_like(mu), jnp.ones_like(mu)))

def kl_fn(mu_0: Array, sigma_0: Numeric, mu_1: Array, sigma_1: Numeric):
  return distrax.MultivariateNormalDiag(
      mu_0, jnp.ones_like(mu_0) * sigma_0).kl_divergence(
          distrax.MultivariateNormalDiag(mu_1, jnp.ones_like(mu_1) * sigma_1))

def kl_to_standard_normal_fn(mu: Array, sigma: Array = sigma):
  return distrax.MultivariateNormalDiag(
      mu, jnp.ones_like(mu) * sigma).kl_divergence(
          distrax.MultivariateNormalDiag(jnp.zeros_like(mu), jnp.ones_like(mu)))

def entropy_fn(mu: Array, sigma: Array = sigma):
  return distrax.MultivariateNormalDiag(
      mu, jnp.ones_like(mu) * sigma).entropy()

def logprob_fn(sample: Array, mu: Array, sigma: Array = sigma):
  return distrax.MultivariateNormalDiag(
      mu, jnp.ones_like(mu) * sigma).log_prob(sample)

def sample_fn(key: Array, mu: Array, sigma: Array = sigma):
  return distrax.MultivariateNormalDiag(
      mu, jnp.ones_like(mu) * sigma).sample(seed=key)

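# A small sketch exercising the diagonal-Gaussian helpers above with explicit
# arguments. The default `sigma` bound in their signatures comes from an
# enclosing factory (not shown here); passing a std explicitly side-steps it.
# Values are illustrative only.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mu = jnp.array([0.0, 1.0])
std = jnp.array(0.5)  # scalar std, broadcast over the diagonal

action = sample_fn(key, mu, std)
logp = logprob_fn(action, mu, std)
ent = entropy_fn(mu, std)
kl_to_prior = kl_to_standard_normal_fn(mu, std)
kl = kl_fn(mu, std, jnp.array([1.0, -1.0]), jnp.array(1.0))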