Example #1
def estimate_kl_best_effort(
        distribution_a: DistributionLike,
        distribution_b: DistributionLike,
        rng_key: PRNGKey,
        num_samples: int,
        proposal_distribution: Optional[DistributionLike] = None):
    """Estimates KL(distribution_a, distribution_b) exactly or with DiCE.

  If the kl_divergence(distribution_a, distribution_b) is not supported,
  the DiCE estimator is used instead.

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    rng_key: The PRNGKey random key.
    num_samples: The number of samples, if using the DiCE estimator.
    proposal_distribution: A proposal distribution for the samples, if using
      the DiCE estimator. If None, use `distribution_a` as proposal.

  Returns:
    The estimated KL divergence.
  """
    distribution_a = conversion.as_distribution(distribution_a)
    distribution_b = conversion.as_distribution(distribution_b)
    # If possible, compute the exact KL.
    try:
        return tfd.kl_divergence(distribution_a, distribution_b)
    except NotImplementedError:
        pass
    return mc_estimate_kl(distribution_a,
                          distribution_b,
                          rng_key,
                          num_samples=num_samples,
                          proposal_distribution=proposal_distribution)
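
A minimal usage sketch (hypothetical, assuming `tfd` is TFP-on-JAX and `jax` is available, as elsewhere in these examples): two Normals take the exact path, so `num_samples` is never consumed.

# Hypothetical usage sketch for estimate_kl_best_effort.
key = jax.random.PRNGKey(0)
dist_a = tfd.Normal(loc=0., scale=1.)
dist_b = tfd.Normal(loc=1., scale=2.)
# An analytic KL is registered for two Normals, so no sampling happens here.
kl = estimate_kl_best_effort(dist_a, dist_b, key, num_samples=1_000)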
Example #2
    def __init__(self, mixture_distribution: CategoricalLike,
                 components_distribution: DistributionLike):
        """Initializes a mixture distribution for components of a shared family.

    Args:
      mixture_distribution: Distribution over selecting components.
      components_distribution: Component distribution, with rightmost batch
        dimension indexing components.
    """
        super().__init__()
        mixture_distribution = conversion.as_distribution(mixture_distribution)
        components_distribution = conversion.as_distribution(
            components_distribution)
        self._mixture_distribution = mixture_distribution
        self._components_distribution = components_distribution

        # Store normalized weights (last axis of logits is for components).
        # This uses the TFP API, which is replicated in Distrax.
        self._mixture_log_probs = jax.nn.log_softmax(
            mixture_distribution.logits_parameter(), axis=-1)

        batch_shape_mixture = mixture_distribution.batch_shape
        batch_shape_components = components_distribution.batch_shape
        if batch_shape_mixture != batch_shape_components[:-1]:
            msg = (
                f'`mixture_distribution.batch_shape` '
                f'({mixture_distribution.batch_shape}) is not compatible with '
                f'`components_distribution.batch_shape` '
                f'({components_distribution.batch_shape})')
            raise ValueError(msg)
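
A shape sketch for the compatibility check above (hypothetical values; `tfd` and `jnp` as used elsewhere in these examples): the rightmost batch dimension of the components distribution indexes the components.

# Hypothetical shape sketch for the batch-shape check above.
mixture = tfd.Categorical(logits=jnp.zeros((3,)))   # batch_shape == ()
components = tfd.Normal(loc=jnp.zeros((3,)),        # batch_shape == (3,)
                        scale=jnp.ones((3,)))
# batch_shape_mixture == () equals batch_shape_components[:-1] == (),
# so the check passes; the trailing axis of size 3 indexes the components.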
Example #3
    def test_on_tfp_distribution(self):
        dist = tfd.Normal(loc=0., scale=1.)
        wrapped_dist = conversion.as_distribution(dist)
        self.assertIsInstance(wrapped_dist, tfd.Normal)
        self.assertIsInstance(wrapped_dist, Distribution)
        # Access the `loc` attribute of a wrapped Normal.
        np.testing.assert_almost_equal(wrapped_dist.loc, 0.)
Example #4
    def test_num_categories_attr_of_categorical(self):
        dist = Categorical(logits=jnp.array([0., 0., 0.]))
        wrapped_dist = conversion.as_distribution(dist)
        self.assertIsInstance(wrapped_dist, Categorical)
        self.assertIs(wrapped_dist, dist)
        # Access the `num_categories` attribute of a wrapped Categorical.
        np.testing.assert_equal(wrapped_dist.num_categories, 3)
Example #5
    def test_loc_attr_of_normal(self):
        dist = Normal(loc=0., scale=1.)
        wrapped_dist = conversion.as_distribution(dist)
        self.assertIsInstance(wrapped_dist, Normal)
        self.assertIs(wrapped_dist, dist)
        # Access the `loc` attribute of a wrapped Normal.
        np.testing.assert_almost_equal(wrapped_dist.loc, 0.)
Example #6
    def __init__(self,
                 distribution: DistributionLike,
                 reinterpreted_batch_ndims: Optional[int] = None):
        """Initializes an Independent distribution.

    Args:
      distribution: Base distribution instance.
      reinterpreted_batch_ndims: Number of event dimensions.
    """
        super().__init__()
        distribution = conversion.as_distribution(distribution)
        self._distribution = distribution
        dist_batch_shape = distribution.batch_shape
        if reinterpreted_batch_ndims is not None:
            dist_batch_ndims = len(dist_batch_shape)
            if reinterpreted_batch_ndims > dist_batch_ndims:
                raise ValueError(
                    f'`reinterpreted_batch_ndims` is {reinterpreted_batch_ndims}, but'
                    f' distribution `{distribution.name}` has only {dist_batch_ndims}'
                    f' batch dimensions.')
            elif reinterpreted_batch_ndims < 0:
                raise ValueError(
                    f'`reinterpreted_batch_ndims` can\'t be negative; got'
                    f' {reinterpreted_batch_ndims}.')
            self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
        else:
            self._reinterpreted_batch_ndims = max(len(dist_batch_shape) - 1, 0)
        event_ndims = len(dist_batch_shape) - self._reinterpreted_batch_ndims
        self._event_shape = (dist_batch_shape[event_ndims:] +
                             distribution.event_shape)
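
A shape sketch of the reinterpretation arithmetic above (hypothetical usage; the class name `Independent` is taken from the docstring):

# Hypothetical shape sketch for reinterpreted_batch_ndims.
base = tfd.Normal(loc=jnp.zeros((3, 4)), scale=jnp.ones((3, 4)))
# base.batch_shape == (3, 4), base.event_shape == ().
independent = Independent(base, reinterpreted_batch_ndims=1)
# event_ndims = 2 - 1 = 1, so the event shape becomes
# (3, 4)[1:] + () == (4,), leaving (3,) as the batch shape.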
Example #7
    def __init__(self, distribution: DistributionLike, bijector: BijectorLike):
        """Initializes a Transformed distribution.

    Args:
      distribution: the base distribution. Can be either a Distrax distribution
        or a TFP distribution.
      bijector: a differentiable bijective transformation. Can be a Distrax
        bijector, a TFP bijector, or a callable to be wrapped by `Lambda`.
    """
        super().__init__()
        distribution = conversion.as_distribution(distribution)
        bijector = conversion.as_bijector(bijector)

        if len(distribution.event_shape) != bijector.event_ndims_in:
            raise ValueError(
                f"Base distribution '{distribution.name}' has event shape "
                f"{distribution.event_shape}, but bijector '{bijector.name}' expects "
                f"events to have {bijector.event_ndims_in} dimensions. Perhaps use "
                f"`distrax.Block` or `distrax.Independent`?")

        self._distribution = distribution
        self._bijector = bijector
        self._batch_shape = None
        self._event_shape = None
        self._dtype = None
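
A minimal construction sketch (hypothetical usage; as in the `Transformed(Normal(...), bijector=lambda x: x)` test later in these examples, a plain callable is wrapped as a scalar bijector, so the event-shape check passes for a scalar-event base):

# Hypothetical usage sketch: a callable bijector on a scalar-event base.
base = tfd.Normal(loc=0., scale=1.)
transformed = Transformed(base, lambda x: 2. * x + 1.)
# len(base.event_shape) == 0 matches the wrapped bijector's event_ndims_in.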
Example #8
def mc_estimate_kl_with_reparameterized(distribution_a: DistributionLike,
                                        distribution_b: DistributionLike,
                                        rng_key: PRNGKey, num_samples: int):
    """Estimates KL(distribution_a, distribution_b)."""
    if isinstance(distribution_a, tfd.Distribution):
        if distribution_a.reparameterization_type != tfd.FULLY_REPARAMETERIZED:
            raise ValueError(
                f'Distribution `{distribution_a.name}` cannot be reparameterized.'
            )
    distribution_a = conversion.as_distribution(distribution_a)
    distribution_b = conversion.as_distribution(distribution_b)

    samples, logp_a = distribution_a.sample_and_log_prob(
        seed=rng_key, sample_shape=[num_samples])
    logp_b = distribution_b.log_prob(samples)
    log_ratio = logp_b - logp_a
    kl_estimator = -log_ratio
    return jnp.mean(kl_estimator, axis=0)
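
A gradient-check sketch (hypothetical; `Normal` is fully reparameterized, and since KL(N(mu, 1) || N(0, 1)) = mu**2 / 2, the gradient at mu = 0.5 should land near 0.5):

# Hypothetical gradient-check sketch for the reparameterized estimator.
def kl_of_loc(loc):
    dist_a = tfd.Normal(loc=loc, scale=1.)
    dist_b = tfd.Normal(loc=0., scale=1.)
    return mc_estimate_kl_with_reparameterized(
        dist_a, dist_b, jax.random.PRNGKey(0), num_samples=10_000)

grad_estimate = jax.grad(kl_of_loc)(0.5)  # Expected to be close to 0.5.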
Example #9
    def test_attrs_of_transformed_distribution(self):
        dist = Transformed(Normal(loc=0., scale=1.), bijector=lambda x: x)
        wrapped_dist = conversion.as_distribution(dist)
        self.assertIsInstance(wrapped_dist, Transformed)
        self.assertIs(wrapped_dist, dist)
        # Access the `distribution` attribute of a wrapped Transformed.
        self.assertIsInstance(wrapped_dist.distribution, Normal)
        # Access the `loc` attribute of a transformed Normal within a wrapped
        # Transformed.
        np.testing.assert_almost_equal(wrapped_dist.distribution.loc, 0.)
Example #10
def mc_estimate_mode(distribution: DistributionLike, rng_key: PRNGKey,
                     num_samples: int):
    """Returns a Monte Carlo estimate of the mode of a distribution."""
    distribution = conversion.as_distribution(distribution)
    # Obtain samples from the distribution and their log probability.
    samples, log_probs = distribution.sample_and_log_prob(
        seed=rng_key, sample_shape=[num_samples])
    # Do argmax over the sample_shape.
    index = jnp.expand_dims(jnp.argmax(log_probs, axis=0), axis=0)
    mode = jnp.squeeze(jnp.take_along_axis(samples, index, axis=0), axis=0)
    return mode
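
A usage sketch (hypothetical; for a unimodal distribution the highest-probability sample approaches the true mode as `num_samples` grows):

# Hypothetical usage sketch: the estimated mode of N(2, 1) should be near 2.
dist = tfd.Normal(loc=2., scale=1.)
mode_estimate = mc_estimate_mode(dist, jax.random.PRNGKey(0),
                                 num_samples=10_000)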
Example #11
def mc_estimate_kl(distribution_a: DistributionLike,
                   distribution_b: DistributionLike,
                   rng_key: PRNGKey,
                   num_samples: int,
                   proposal_distribution: Optional[DistributionLike] = None):
    """Estimates KL(distribution_a, distribution_b) with the DiCE estimator.

  To get correct gradients with respect the `distribution_a`, we use the DiCE
  estimator, i.e., we stop the gradient with respect to the samples and with
  respect to the denominator in the importance weights. We then do not need
  reparametrized distributions.

  Args:
    distribution_a: The first distribution.
    distribution_b: The second distribution.
    rng_key: The PRNGKey random key.
    num_samples: The number of samples, if using the DiCE estimator.
    proposal_distribution: A proposal distribution for the samples, if using the
      DiCE estimator. If None, use `distribution_a` as proposal.

  Returns:
    The estimated KL divergence.
  """
    if proposal_distribution is None:
        proposal_distribution = distribution_a
    proposal_distribution = conversion.as_distribution(proposal_distribution)
    distribution_a = conversion.as_distribution(distribution_a)
    distribution_b = conversion.as_distribution(distribution_b)

    samples, logp_proposal = proposal_distribution.sample_and_log_prob(
        seed=rng_key, sample_shape=[num_samples])
    samples = jax.lax.stop_gradient(samples)
    logp_proposal = jax.lax.stop_gradient(logp_proposal)
    logp_a = distribution_a.log_prob(samples)
    logp_b = distribution_b.log_prob(samples)
    importance_weight = jnp.exp(logp_a - logp_proposal)
    log_ratio = logp_b - logp_a
    kl_estimator = -importance_weight * log_ratio
    return jnp.mean(kl_estimator, axis=0)
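
A gradient sketch (hypothetical; with the default proposal `distribution_a`, the importance weight equals one in value but carries the score-function gradient, so even a non-reparameterizable `Categorical` gets a gradient with respect to its logits):

# Hypothetical gradient sketch with a non-reparameterizable distribution.
def kl_of_logits(logits):
    dist_a = tfd.Categorical(logits=logits)
    dist_b = tfd.Categorical(logits=jnp.zeros((3,)))
    return mc_estimate_kl(dist_a, dist_b, jax.random.PRNGKey(0),
                          num_samples=10_000)

grad_logits = jax.grad(kl_of_logits)(jnp.array([0.5, 0., -0.5]))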
Example #12
    def __init__(self,
                 distribution: DistributionLike,
                 low: Optional[Numeric] = None,
                 high: Optional[Numeric] = None):
        """Initializes a Quantized distribution.

    Args:
      distribution: The base distribution to be quantized.
      low: Lowest possible quantized value, such that samples are
        `y >= ceil(low)`. Its shape must broadcast with the shape of samples
        from `distribution` and must not result in additional batch dimensions
        after broadcasting.
      high: Highest possible quantized value, such that samples are
        `y <= floor(high)`. Its shape must broadcast with the shape of samples
        from `distribution` and must not result in additional batch dimensions
        after broadcasting.
    """
        self._dist = conversion.as_distribution(distribution)
        if self._dist.event_shape:
            raise ValueError(
                f'The base distribution must be univariate, but its '
                f'`event_shape` is {self._dist.event_shape}.')
        dtype = self._dist.dtype
        if low is None:
            self._low = None
        else:
            self._low = jnp.asarray(jnp.ceil(low), dtype=dtype)
            if len(self._low.shape) > len(self._dist.batch_shape):
                raise ValueError(
                    'The parameter `low` must not result in additional '
                    'batch dimensions.')
        if high is None:
            self._high = None
        else:
            self._high = jnp.asarray(jnp.floor(high), dtype=dtype)
            if len(self._high.shape) > len(self._dist.batch_shape):
                raise ValueError(
                    'The parameter `high` must not result in additional '
                    'batch dimensions.')
        super().__init__()
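
A construction sketch (hypothetical usage; the class name `Quantized` is taken from the docstring, and the bounds are rounded exactly as in the code above):

# Hypothetical usage sketch: quantizing a scalar Uniform onto {1, ..., 5}.
base = tfd.Uniform(low=0., high=10.)   # event_shape == (), batch_shape == ()
quantized = Quantized(base, low=0.5, high=5.5)
# self._low == ceil(0.5) == 1.0, self._high == floor(5.5) == 5.0.
# Neither bound adds batch dimensions, so both checks pass.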