Beispiel #1
0
def _kl_divergence_transformed_transformed(
    dist1: Union[Transformed, tfd.TransformedDistribution],
    dist2: Union[Transformed, tfd.TransformedDistribution],
    *unused_args,
    input_hint: Optional[Array] = None,
    **unused_kwargs,
) -> Array:
    """Computes the KL divergence between two Transformed distributions.

    The result is the KL divergence between the two base distributions, which
    is only valid when both distributions use the same bijector. Bijector
    equality is decided as follows:
    - If the converted bijectors do not compare unequal with `!=` (e.g. both
      are the same Distrax bijector instance), they are declared equal.
    - Otherwise, each bijector's `forward` method is traced on `input_hint`
      with `jax.make_jaxpr`, and the string forms of the resulting Jaxprs are
      compared. Identical strings mean the bijectors are declared equal.
    - Otherwise, equality cannot be guaranteed and an error is raised.

    Extra positional and keyword arguments are accepted and ignored.

    Args:
      dist1: A Transformed distribution.
      dist2: A Transformed distribution.
      input_hint: an example sample from the base distribution, used to trace
        the `forward` method. If not specified, a zero array with the event
        shape and dtype of the first base distribution is used.

    Returns:
      Batchwise `KL(dist1 || dist2)`.

    Raises:
      NotImplementedError: If bijectors are not known to be equal.
      ValueError: If the base distributions do not have the same `event_shape`.
    """
    base1 = dist1.distribution
    base2 = dist2.distribution

    # The bases must live in the same event space for their KL to make sense.
    if base1.event_shape != base2.event_shape:
        raise ValueError(
            f'The two base distributions do not have the same event shape: '
            f'{base1.event_shape} and '
            f'{base2.event_shape}.')

    # `as_bijector` wraps TFP bijectors so both sides expose the Distrax API.
    bij1, bij2 = (conversion.as_bijector(d.bijector) for d in (dist1, dist2))

    if bij1 != bij2:
        # Distinct objects: fall back to comparing traced computations.
        if input_hint is None:
            input_hint = jnp.zeros(base1.event_shape, dtype=base1.dtype)
        # NOTE(review): only the `.jaxpr` text is compared. The values of array
        # constants captured by `forward` may not appear in that text. Confirm
        # that bijectors differing only in such constants cannot slip through.
        traced = [
            str(jax.make_jaxpr(b.forward)(input_hint).jaxpr)
            for b in (bij1, bij2)
        ]
        if traced[0] != traced[1]:
            raise NotImplementedError(
                f'The KL divergence cannot be obtained because it is not possible to '
                f'guarantee that the bijectors {dist1.bijector.name} and '
                f'{dist2.bijector.name} of the Transformed distributions are '
                f'equal. If possible, use the same instance of a Distrax bijector.'
            )

    return base1.kl_divergence(base2)
Beispiel #2
0
  def __init__(self, bijectors: Sequence[BijectorLike]):
    """Initializes a Chain bijector.

    Neighboring bijectors must have compatible event dimensions: each
    bijector's `event_ndims_in` must equal the `event_ndims_out` of the one
    after it. The chain's input therefore enters through the last bijector,
    and its output leaves through the first.

    Args:
      bijectors: a sequence of bijectors to be composed into one. Each bijector
        can be a distrax bijector, a TFP bijector, or a callable to be wrapped
        by `Lambda`. The sequence must contain at least one bijector.

    Raises:
      ValueError: if `bijectors` is empty, or if two neighboring bijectors have
        incompatible event dimensions.
    """
    # Convert first, then validate the materialized list. An empty iterator
    # (e.g. a generator) is truthy, so checking the raw argument would let it
    # through and fail later with an IndexError on `self._bijectors[-1]`.
    self._bijectors = [conversion.as_bijector(b) for b in bijectors]
    if not self._bijectors:
      raise ValueError("The sequence of bijectors cannot be empty.")

    # Check that neighboring bijectors in the chain have compatible dimensions.
    for i, (outer, inner) in enumerate(zip(self._bijectors[:-1],
                                           self._bijectors[1:])):
      if outer.event_ndims_in != inner.event_ndims_out:
        raise ValueError(
            f"The chain of bijector event shapes are incompatible. Bijector "
            f"{i} ({outer.name}) expects events with {outer.event_ndims_in} "
            f"dimensions, while Bijector {i+1} ({inner.name}) produces events "
            f"with {inner.event_ndims_out} dimensions.")

    # The composition has a constant Jacobian / log-det only if every part does.
    is_constant_jacobian = all(b.is_constant_jacobian for b in self._bijectors)
    is_constant_log_det = all(b.is_constant_log_det for b in self._bijectors)
    super().__init__(
        event_ndims_in=self._bijectors[-1].event_ndims_in,
        event_ndims_out=self._bijectors[0].event_ndims_out,
        is_constant_jacobian=is_constant_jacobian,
        is_constant_log_det=is_constant_log_det)
Beispiel #3
0
 def test_on_tfp_bijector(self):
     """`as_bijector` turns a TFP bijector into a Distrax one with same output."""
     points = jnp.array([0., 1.])
     tfp_exp = tfb.Exp()
     distrax_exp = conversion.as_bijector(tfp_exp)
     assert isinstance(distrax_exp, Bijector)
     actual = distrax_exp.forward(points)
     expected = tfp_exp.forward(points)
     np.testing.assert_array_almost_equal(actual, expected)
Beispiel #4
0
    def __init__(self, distribution: DistributionLike, bijector: BijectorLike):
        """Initializes a Transformed distribution.

        Args:
          distribution: the base distribution. Can be either a Distrax
            distribution or a TFP distribution.
          bijector: a differentiable bijective transformation. Can be a Distrax
            bijector, a TFP bijector, or a callable to be wrapped by `Lambda`.

        Raises:
          ValueError: if the number of dimensions of the base distribution's
            event shape differs from the bijector's `event_ndims_in`.
        """
        super().__init__()
        base_dist = conversion.as_distribution(distribution)
        bij = conversion.as_bijector(bijector)

        event_rank = len(base_dist.event_shape)
        if event_rank != bij.event_ndims_in:
            raise ValueError(
                f"Base distribution '{base_dist.name}' has event shape "
                f"{base_dist.event_shape}, but bijector '{bij.name}' expects "
                f"events to have {bij.event_ndims_in} dimensions. Perhaps use "
                f"`distrax.Block` or `distrax.Independent`?")

        self._distribution = base_dist
        self._bijector = bij
        # Batch shape, event shape and dtype start unset. Presumably they are
        # filled in lazily elsewhere in the class (TODO confirm).
        self._batch_shape = self._event_shape = self._dtype = None
Beispiel #5
0
 def _inner_bijector(self, params: BijectorParams) -> base.Bijector:
   """Builds the inner bijector from `params` and checks that it is scalar.

   Args:
     params: parameters forwarded to `self._bijector` to build the bijector.

   Returns:
     The built bijector, converted with `conversion.as_bijector`.

   Raises:
     ValueError: if its `event_ndims_in` or `event_ndims_out` is not 0.
   """
   candidate = conversion.as_bijector(self._bijector(params))
   ndims_in = candidate.event_ndims_in
   ndims_out = candidate.event_ndims_out
   if ndims_in == 0 and ndims_out == 0:
     return candidate
   raise ValueError(
       f'The inner bijector must be scalar: its `event_ndims_in` and'
       f' `event_ndims_out` must both be 0. Instead, got'
       f' `event_ndims_in={ndims_in}` and'
       f' `event_ndims_out={ndims_out}`.')
Beispiel #6
0
 def _inner_bijector(self, params: BijectorParams) -> base.Bijector:
     """Builds the inner bijector from `params` and validates its event ndims.

     Args:
       params: parameters forwarded to `self._bijector` to build the bijector.

     Returns:
       The built bijector, converted with `conversion.as_bijector`.

     Raises:
       ValueError: if its `event_ndims_in` or `event_ndims_out` differs from
         `self._inner_event_ndims`.
     """
     inner = conversion.as_bijector(self._bijector(params))
     required = self._inner_event_ndims
     ndims_match = (inner.event_ndims_in == required
                    and inner.event_ndims_out == required)
     if ndims_match:
         return inner
     raise ValueError(
         'The inner bijector event ndims in and out must match the'
         f' `inner_event_ndims={required}`. Instead, got'
         f' `event_ndims_in={inner.event_ndims_in}` and'
         f' `event_ndims_out={inner.event_ndims_out}`.')
Beispiel #7
0
 def test_num_bins_attr_of_rational_quadratic_spline(self):
     """A Distrax spline passes through `as_bijector` as the same object."""
     n_bins = 4
     # The test expects a parameter vector of length 3 * n_bins + 1 to yield
     # a spline with `n_bins` bins.
     spline_params = jnp.zeros((3 * n_bins + 1,))
     spline = RationalQuadraticSpline(spline_params, range_min=0., range_max=1.)
     converted = conversion.as_bijector(spline)
     assert isinstance(converted, RationalQuadraticSpline)
     self.assertIs(converted, spline)
     # The class-specific `num_bins` attribute stays reachable after conversion.
     np.testing.assert_equal(converted.num_bins, n_bins)
Beispiel #8
0
    def __init__(self, bijector: BijectorLike):
        """Initializes an Inverse bijector.

        The wrapped bijector's event dimensions are swapped: this bijector's
        input ndims are the wrapped bijector's output ndims, and vice versa.
        The constant-Jacobian and constant-log-det flags are inherited as is.

        Args:
          bijector: the bijector to be inverted. It can be a distrax bijector,
            a TFP bijector, or a callable to be wrapped by `Lambda`.
        """
        wrapped = conversion.as_bijector(bijector)
        self._bijector = wrapped
        super().__init__(
            event_ndims_in=wrapped.event_ndims_out,
            event_ndims_out=wrapped.event_ndims_in,
            is_constant_jacobian=wrapped.is_constant_jacobian,
            is_constant_log_det=wrapped.is_constant_log_det)
Beispiel #9
0
 def test_forward_inverse_work_as_expected(self, bijector_fn, ndims):
     """Wrapping in `Block` leaves the transformed values unchanged.

     Checks `forward` and `inverse` directly. For `forward_and_log_det` and
     `inverse_and_log_det`, checks the first element of each returned pair.
     """
     inner = conversion.as_bijector(bijector_fn())
     x = jax.random.normal(self.seed, [2, 3])
     blocked = block_bijector.Block(inner, ndims)
     for method_name in ('forward', 'inverse'):
         np.testing.assert_array_equal(
             self.variant(getattr(inner, method_name))(x),
             self.variant(getattr(blocked, method_name))(x))
     for method_name in ('forward_and_log_det', 'inverse_and_log_det'):
         np.testing.assert_array_equal(
             self.variant(getattr(inner, method_name))(x)[0],
             self.variant(getattr(blocked, method_name))(x)[0])
Beispiel #10
0
    def __init__(self, bijector: BijectorLike, ndims: int):
        """Initializes a Block.

        The resulting bijector's event ndims (in and out) are the wrapped
        bijector's event ndims plus `ndims`. The constant-Jacobian and
        constant-log-det flags are inherited from the wrapped bijector.

        Args:
          bijector: the bijector to be promoted to a block bijector. It can be
            a distrax bijector, a TFP bijector, or a callable to be wrapped by
            `Lambda`.
          ndims: number of batch dimensions to promote to event dimensions.

        Raises:
          ValueError: if `ndims` is negative.
        """
        if ndims < 0:
            raise ValueError(f"`ndims` must be non-negative; got {ndims}.")
        inner = conversion.as_bijector(bijector)
        self._bijector = inner
        self._ndims = ndims
        super().__init__(
            event_ndims_in=ndims + inner.event_ndims_in,
            event_ndims_out=ndims + inner.event_ndims_out,
            is_constant_jacobian=inner.is_constant_jacobian,
            is_constant_log_det=inner.is_constant_log_det)
Beispiel #11
0
 def _inner_bijector(self, params: BijectorParams) -> base.Bijector:
     """Returns an inner bijector for the passed params.

     The bijector built by `self._bijector(params)` must have equal input and
     output event ndims, no larger than this bijector's `event_ndims_in`. If it
     has fewer event ndims, it is wrapped in `block.Block` to make up the
     difference.

     Raises:
       ValueError: if the inner bijector's input and output event ndims differ,
         or if it has more event ndims than this bijector.
     """
     inner = conversion.as_bijector(self._bijector(params))
     inner_ndims = inner.event_ndims_in
     if inner_ndims != inner.event_ndims_out:
         raise ValueError(
             f'The inner bijector must have `event_ndims_in==event_ndims_out`. '
             f'Instead, it has `event_ndims_in=={inner_ndims}` and '
             f'`event_ndims_out=={inner.event_ndims_out}`.')
     missing_ndims = self.event_ndims_in - inner_ndims
     if missing_ndims < 0:
         raise ValueError(
             f'The inner bijector can\'t have more event dimensions than the '
             f'coupling bijector. Got {inner_ndims} for the inner '
             f'bijector and {self.event_ndims_in} for the coupling bijector.'
         )
     # Zero missing dims: use as is. Otherwise promote batch dims to events.
     return block.Block(inner, missing_ndims) if missing_ndims else inner
Beispiel #12
0
 def test_log_det_jacobian_works_as_expected(self, bijector_fn, ndims):
     """`Block` log-dets equal the wrapped log-dets summed over promoted axes.

     The last `ndims` axes of the wrapped bijector's log-det are summed before
     comparison. For the `*_and_log_det` methods, the second element of each
     returned pair is compared.
     """
     inner = conversion.as_bijector(bijector_fn())
     x = jax.random.normal(self.seed, [2, 3])
     blocked = block_bijector.Block(inner, ndims)
     summed_axes = tuple(range(-ndims, 0))
     # (method name, index into the returned tuple or None for a bare array)
     cases = (
         ('forward_log_det_jacobian', None),
         ('inverse_log_det_jacobian', None),
         ('forward_and_log_det', 1),
         ('inverse_and_log_det', 1),
     )
     for method_name, index in cases:
         expected = self.variant(getattr(inner, method_name))(x)
         actual = self.variant(getattr(blocked, method_name))(x)
         if index is not None:
             expected, actual = expected[index], actual[index]
         np.testing.assert_allclose(
             expected.sum(summed_axes), actual, rtol=RTOL)
Beispiel #13
0
 def test_invalid_properties(self):
     """Building a `Block` with a negative `ndims` raises a ValueError."""
     tanh_bijector = conversion.as_bijector(jnp.tanh)
     negative_ndims = -1
     with self.assertRaises(ValueError):
         block_bijector.Block(tanh_bijector, negative_ndims)
Beispiel #14
0
 def test_properties(self):
     """`Block` exposes its `ndims` and a Distrax `bijector`."""
     tanh_bijector = conversion.as_bijector(jnp.tanh)
     blocked = block_bijector.Block(tanh_bijector, 1)
     assert blocked.ndims == 1
     assert isinstance(blocked.bijector, base.Bijector)