def estimate_kl_best_effort(
    distribution_a: DistributionLike,
    distribution_b: DistributionLike,
    rng_key: PRNGKey,
    num_samples: int,
    proposal_distribution: Optional[DistributionLike] = None):
  """Estimates KL(distribution_a, distribution_b) exactly or with DiCE.

  If kl_divergence(distribution_a, distribution_b) is not supported, the
  DiCE estimator is used instead.

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    rng_key: The PRNGKey random key.
    num_samples: The number of samples, if using the DiCE estimator.
    proposal_distribution: A proposal distribution for the samples, if using
      the DiCE estimator. If None, use `distribution_a` as proposal.

  Returns:
    The estimated KL divergence.
  """
  distribution_a = conversion.as_distribution(distribution_a)
  distribution_b = conversion.as_distribution(distribution_b)
  # If possible, compute the exact KL.
  try:
    return tfd.kl_divergence(distribution_a, distribution_b)
  except NotImplementedError:
    pass
  return mc_estimate_kl(
      distribution_a, distribution_b, rng_key,
      num_samples=num_samples,
      proposal_distribution=proposal_distribution)
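# A minimal usage sketch (illustrative, not part of the module; assumes
# `jax` and `distrax` are importable). Normal-vs-Normal has a registered
# analytic KL, so the exact path is taken; a pair without one falls back to
# the DiCE estimate. The helper name below is hypothetical.
def _example_estimate_kl_best_effort():
  import jax
  import distrax
  key = jax.random.PRNGKey(0)
  p = distrax.Normal(loc=0., scale=1.)
  q = distrax.Normal(loc=1., scale=2.)
  # Exact path: tfd.kl_divergence supports Normal/Normal, so no sampling.
  return estimate_kl_best_effort(p, q, key, num_samples=1000)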
def __init__(self,
             mixture_distribution: CategoricalLike,
             components_distribution: DistributionLike):
  """Initializes a mixture distribution for components of a shared family.

  Args:
    mixture_distribution: Distribution over selecting components.
    components_distribution: Component distribution, with rightmost batch
      dimension indexing components.
  """
  super().__init__()
  mixture_distribution = conversion.as_distribution(mixture_distribution)
  components_distribution = conversion.as_distribution(
      components_distribution)
  self._mixture_distribution = mixture_distribution
  self._components_distribution = components_distribution

  # Store normalized weights (last axis of logits is for components).
  # This uses the TFP API, which is replicated in Distrax.
  self._mixture_log_probs = jax.nn.log_softmax(
      mixture_distribution.logits_parameter(), axis=-1)

  batch_shape_mixture = mixture_distribution.batch_shape
  batch_shape_components = components_distribution.batch_shape
  if batch_shape_mixture != batch_shape_components[:-1]:
    msg = (f'`mixture_distribution.batch_shape` '
           f'({mixture_distribution.batch_shape}) is not compatible with '
           f'`components_distribution.batch_shape` '
           f'({components_distribution.batch_shape})')
    raise ValueError(msg)
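# A minimal usage sketch (illustrative, assuming the class is exposed as
# `distrax.MixtureSameFamily`): a mixture of three Normals, where the
# rightmost batch axis of the components (size 3) indexes the components.
def _example_mixture_same_family():
  import jax.numpy as jnp
  import distrax
  mixture = distrax.MixtureSameFamily(
      mixture_distribution=distrax.Categorical(logits=jnp.zeros(3)),
      components_distribution=distrax.Normal(
          loc=jnp.array([-1., 0., 1.]), scale=jnp.ones(3)))
  # The component axis is consumed, so the mixture has an empty batch shape.
  return mixture.log_prob(0.5)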
def test_on_tfp_distribution(self):
  dist = tfd.Normal(loc=0., scale=1.)
  wrapped_dist = conversion.as_distribution(dist)
  assert isinstance(wrapped_dist, tfd.Normal)
  assert isinstance(wrapped_dist, Distribution)
  # Access the `loc` attribute of a wrapped Normal.
  np.testing.assert_almost_equal(wrapped_dist.loc, 0.)
def test_num_categories_attr_of_categorical(self):
  dist = Categorical(logits=jnp.array([0., 0., 0.]))
  wrapped_dist = conversion.as_distribution(dist)
  assert isinstance(wrapped_dist, Categorical)
  self.assertIs(wrapped_dist, dist)
  # Access the `num_categories` attribute of a wrapped Categorical.
  np.testing.assert_equal(wrapped_dist.num_categories, 3)
def test_loc_attr_of_normal(self):
  dist = Normal(loc=0., scale=1.)
  wrapped_dist = conversion.as_distribution(dist)
  assert isinstance(wrapped_dist, Normal)
  self.assertIs(wrapped_dist, dist)
  # Access the `loc` attribute of a wrapped Normal.
  np.testing.assert_almost_equal(wrapped_dist.loc, 0.)
def __init__(self,
             distribution: DistributionLike,
             reinterpreted_batch_ndims: Optional[int] = None):
  """Initializes an Independent distribution.

  Args:
    distribution: Base distribution instance.
    reinterpreted_batch_ndims: Number of rightmost batch dimensions to
      reinterpret as event dimensions.
  """
  super().__init__()
  distribution = conversion.as_distribution(distribution)
  self._distribution = distribution

  dist_batch_shape = distribution.batch_shape
  if reinterpreted_batch_ndims is not None:
    dist_batch_ndims = len(dist_batch_shape)
    if reinterpreted_batch_ndims > dist_batch_ndims:
      raise ValueError(
          f'`reinterpreted_batch_ndims` is {reinterpreted_batch_ndims}, but'
          f' distribution `{distribution.name}` has only {dist_batch_ndims}'
          f' batch dimensions.')
    elif reinterpreted_batch_ndims < 0:
      raise ValueError(
          f'`reinterpreted_batch_ndims` can\'t be negative; got'
          f' {reinterpreted_batch_ndims}.')
    self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
  else:
    self._reinterpreted_batch_ndims = max(len(dist_batch_shape) - 1, 0)

  event_ndims = len(dist_batch_shape) - self._reinterpreted_batch_ndims
  self._event_shape = (dist_batch_shape[event_ndims:]
                       + distribution.event_shape)
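# A minimal usage sketch (illustrative, assuming the class is exposed as
# `distrax.Independent`): reinterpret the trailing batch axis of a
# factorized Normal as an event axis.
def _example_independent():
  import jax.numpy as jnp
  import distrax
  base = distrax.Normal(loc=jnp.zeros((4, 3)), scale=jnp.ones((4, 3)))
  dist = distrax.Independent(base, reinterpreted_batch_ndims=1)
  # log_prob sums over the reinterpreted axis of size 3, returning one value
  # per remaining batch entry: shape (4,).
  return dist.log_prob(jnp.zeros((4, 3)))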
def __init__(self, distribution: DistributionLike, bijector: BijectorLike):
  """Initializes a Transformed distribution.

  Args:
    distribution: the base distribution. Can be either a Distrax distribution
      or a TFP distribution.
    bijector: a differentiable bijective transformation. Can be a Distrax
      bijector, a TFP bijector, or a callable to be wrapped by `Lambda`.
  """
  super().__init__()
  distribution = conversion.as_distribution(distribution)
  bijector = conversion.as_bijector(bijector)

  if len(distribution.event_shape) != bijector.event_ndims_in:
    raise ValueError(
        f"Base distribution '{distribution.name}' has event shape "
        f"{distribution.event_shape}, but bijector '{bijector.name}' expects "
        f"events to have {bijector.event_ndims_in} dimensions. Perhaps use "
        f"`distrax.Block` or `distrax.Independent`?")

  self._distribution = distribution
  self._bijector = bijector
  self._batch_shape = None
  self._event_shape = None
  self._dtype = None
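# A minimal usage sketch (illustrative, assuming the class is exposed as
# `distrax.Transformed`): a plain callable is wrapped into a `Lambda`
# bijector, giving the affine push-forward y = 2 * x + 1 of a standard
# Normal.
def _example_transformed():
  import distrax
  base = distrax.Normal(loc=0., scale=1.)
  dist = distrax.Transformed(base, lambda x: 2. * x + 1.)
  # log_prob accounts for the log-det-Jacobian of the transformation.
  return dist.log_prob(1.)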
def mc_estimate_kl_with_reparameterized(
    distribution_a: DistributionLike,
    distribution_b: DistributionLike,
    rng_key: PRNGKey,
    num_samples: int):
  """Estimates KL(distribution_a, distribution_b) via reparameterized samples."""
  if isinstance(distribution_a, tfd.Distribution):
    if distribution_a.reparameterization_type != tfd.FULLY_REPARAMETERIZED:
      raise ValueError(
          f'Distribution `{distribution_a.name}` cannot be reparameterized.')
  distribution_a = conversion.as_distribution(distribution_a)
  distribution_b = conversion.as_distribution(distribution_b)

  samples, logp_a = distribution_a.sample_and_log_prob(
      seed=rng_key, sample_shape=[num_samples])
  logp_b = distribution_b.log_prob(samples)
  log_ratio = logp_b - logp_a
  kl_estimator = -log_ratio
  return jnp.mean(kl_estimator, axis=0)
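# A minimal usage sketch (illustrative): for two Normals (fully
# reparameterized), the Monte Carlo estimate should approach the analytic KL
# as `num_samples` grows.
def _example_mc_kl_reparameterized():
  import jax
  import distrax
  key = jax.random.PRNGKey(42)
  p = distrax.Normal(loc=0., scale=1.)
  q = distrax.Normal(loc=0.5, scale=1.5)
  approx = mc_estimate_kl_with_reparameterized(p, q, key, num_samples=10000)
  # Compare against the analytic value for a sanity check.
  return approx, p.kl_divergence(q)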
def test_attrs_of_transformed_distribution(self):
  dist = Transformed(Normal(loc=0., scale=1.), bijector=lambda x: x)
  wrapped_dist = conversion.as_distribution(dist)
  assert isinstance(wrapped_dist, Transformed)
  self.assertIs(wrapped_dist, dist)
  # Access the `distribution` attribute of a wrapped Transformed.
  assert isinstance(wrapped_dist.distribution, Normal)
  # Access the `loc` attribute of a transformed Normal within a wrapped
  # Transformed.
  np.testing.assert_almost_equal(wrapped_dist.distribution.loc, 0.)
def mc_estimate_mode(distribution: DistributionLike,
                     rng_key: PRNGKey,
                     num_samples: int):
  """Returns a Monte Carlo estimate of the mode of a distribution."""
  distribution = conversion.as_distribution(distribution)
  # Obtain samples from the distribution and their log probability.
  samples, log_probs = distribution.sample_and_log_prob(
      seed=rng_key, sample_shape=[num_samples])
  # Do argmax over the sample_shape.
  index = jnp.expand_dims(jnp.argmax(log_probs, axis=0), axis=0)
  mode = jnp.squeeze(jnp.take_along_axis(samples, index, axis=0), axis=0)
  return mode
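# A minimal usage sketch (illustrative): the sample with the highest
# log-density approximates the mode, here close to `loc = 2.0`.
def _example_mc_estimate_mode():
  import jax
  import distrax
  key = jax.random.PRNGKey(0)
  dist = distrax.Normal(loc=2., scale=0.5)
  return mc_estimate_mode(dist, key, num_samples=1000)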
def mc_estimate_kl(distribution_a: DistributionLike,
                   distribution_b: DistributionLike,
                   rng_key: PRNGKey,
                   num_samples: int,
                   proposal_distribution: Optional[DistributionLike] = None):
  """Estimates KL(distribution_a, distribution_b) with the DiCE estimator.

  To get correct gradients with respect to the parameters of
  `distribution_a`, we use the DiCE estimator, i.e., we stop the gradient
  with respect to the samples and with respect to the denominator in the
  importance weights. We then do not need reparameterized distributions.

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    rng_key: The PRNGKey random key.
    num_samples: The number of samples, if using the DiCE estimator.
    proposal_distribution: A proposal distribution for the samples, if using
      the DiCE estimator. If None, use `distribution_a` as proposal.

  Returns:
    The estimated KL divergence.
  """
  if proposal_distribution is None:
    proposal_distribution = distribution_a
  proposal_distribution = conversion.as_distribution(proposal_distribution)
  distribution_a = conversion.as_distribution(distribution_a)
  distribution_b = conversion.as_distribution(distribution_b)

  samples, logp_proposal = proposal_distribution.sample_and_log_prob(
      seed=rng_key, sample_shape=[num_samples])
  # DiCE: block gradients through the samples and the proposal density.
  samples = jax.lax.stop_gradient(samples)
  logp_proposal = jax.lax.stop_gradient(logp_proposal)

  logp_a = distribution_a.log_prob(samples)
  logp_b = distribution_b.log_prob(samples)
  importance_weight = jnp.exp(logp_a - logp_proposal)
  log_ratio = logp_b - logp_a
  kl_estimator = -importance_weight * log_ratio
  return jnp.mean(kl_estimator, axis=0)
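# A minimal usage sketch (illustrative): DiCE keeps the estimate
# differentiable w.r.t. the logits of a Categorical, even though Categorical
# samples are not reparameterizable. The helper name is hypothetical.
def _example_mc_estimate_kl_dice():
  import jax
  import jax.numpy as jnp
  import distrax

  def kl_to_uniform(logits, key):
    p = distrax.Categorical(logits=logits)
    q = distrax.Categorical(logits=jnp.zeros_like(logits))
    return mc_estimate_kl(p, q, key, num_samples=1000)

  key = jax.random.PRNGKey(0)
  # Gradient flows through `logp_a` and the importance weights only.
  return jax.grad(kl_to_uniform)(jnp.array([0.5, -0.5, 0.]), key)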
def __init__(self,
             distribution: DistributionLike,
             low: Optional[Numeric] = None,
             high: Optional[Numeric] = None):
  """Initializes a Quantized distribution.

  Args:
    distribution: The base distribution to be quantized.
    low: Lowest possible quantized value, such that samples are
      `y >= ceil(low)`. Its shape must broadcast with the shape of samples
      from `distribution` and must not result in additional batch dimensions
      after broadcasting.
    high: Highest possible quantized value, such that samples are
      `y <= floor(high)`. Its shape must broadcast with the shape of samples
      from `distribution` and must not result in additional batch dimensions
      after broadcasting.
  """
  self._dist = conversion.as_distribution(distribution)
  if self._dist.event_shape:
    raise ValueError(f'The base distribution must be univariate, but its '
                     f'`event_shape` is {self._dist.event_shape}.')
  dtype = self._dist.dtype
  if low is None:
    self._low = None
  else:
    self._low = jnp.asarray(jnp.ceil(low), dtype=dtype)
    if len(self._low.shape) > len(self._dist.batch_shape):
      raise ValueError('The parameter `low` must not result in additional '
                       'batch dimensions.')
  if high is None:
    self._high = None
  else:
    self._high = jnp.asarray(jnp.floor(high), dtype=dtype)
    if len(self._high.shape) > len(self._dist.batch_shape):
      raise ValueError('The parameter `high` must not result in additional '
                       'batch dimensions.')
  super().__init__()
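# A minimal usage sketch (illustrative, assuming the class is exposed as
# `distrax.Quantized`): quantizing a Uniform(0, 10) with bounds [1, 9]
# yields integer-valued samples y with 1 <= y <= 9.
def _example_quantized():
  import jax
  import distrax
  dist = distrax.Quantized(
      distrax.Uniform(low=0., high=10.), low=1., high=9.)
  return dist.sample(seed=jax.random.PRNGKey(0), sample_shape=(5,))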