コード例 #1
0
 def model_fn():
   """Coroutine model yielding three wrapped distributions named a, b and c.

   The first two yields are wrapped `Uniform` distributions built from
   `2. * jnp.ones(3)` and `4.`. The values sent back into the generator for
   them are used to construct the final wrapped `Laplace` distribution.
   """
   uniform_a = Uniform(2. * jnp.ones(3), 4.)
   a_value = yield tfp_compatible_distribution(uniform_a, name='a')

   uniform_b = Uniform(2. * jnp.ones(3), 4.)
   b_value = yield tfp_compatible_distribution(uniform_b, name='b')

   # `a_value` is multiplied by a (2, 1) array of ones before being used.
   laplace_c = Laplace(a_value * jnp.ones((2, 1)), b_value)
   yield tfp_compatible_distribution(laplace_c, name='c')
コード例 #2
0
 def test_with_joint_distribution_named_auto_batched(self):
   """Smoke-test wrapped distributions in `JointDistributionNamedAutoBatched`.

   Draws a sample of size 4 with `self._key` and evaluates `log_prob` on it.
   The result is not compared against anything; the test only checks that
   both calls run with `validate_args=True`.
   """
   def laplace(a, b):
     # NOTE(review): the parameter names presumably bind to the 'a' and 'b'
     # entries of the model dict by name, so they are kept as-is.
     return tfp_compatible_distribution(Laplace(a * jnp.ones((2, 1)), b))

   model = {
       'a': tfp_compatible_distribution(Uniform(2. * jnp.ones(3), 4.)),
       'b': tfp_compatible_distribution(Uniform(2. * jnp.ones(3), 4.)),
       'c': laplace,
   }
   meta_dist = tfd.JointDistributionNamedAutoBatched(
       model, validate_args=True)

   draws = meta_dist.sample(4, seed=self._key)
   meta_dist.log_prob(draws)
コード例 #3
0
def to_tfp(obj: Union[bijector.Bijector, tfb.Bijector,
                      distribution.Distribution, tfd.Distribution],
           name: Optional[str] = None):
    """Wraps a distribution or bijector so that it can be used with TFP.

    The result is not necessarily an instance of `tfb.Bijector` or
    `tfd.Distribution`. For Distrax inputs it is a Distrax object that
    implements the TFP functionality needed to be used in TFP.

    Objects that are already TFP bijectors or distributions are returned
    unchanged.

    Args:
      obj: The distribution or bijector to be converted to TFP.
      name: The name of the resulting object.

    Returns:
      A TFP-compatible equivalent distribution or bijector.

    Raises:
      TypeError: If `obj` is neither a Distrax nor a TFP bijector or
        distribution.
    """
    # TFP-native objects need no wrapping.
    if isinstance(obj, (tfb.Bijector, tfd.Distribution)):
        return obj

    if isinstance(obj, bijector.Bijector):
        return tfp_compatible_bijector.tfp_compatible_bijector(obj, name)

    if isinstance(obj, distribution.Distribution):
        return tfp_compatible_distribution.tfp_compatible_distribution(
            obj, name)

    message = (
        f"`to_tfp` can only convert objects of type: `distrax.Bijector`,"
        f" `tfb.Bijector`, `distrax.Distribution`, `tfd.Distribution`. Got type"
        f" `{type(obj)}`.")
    raise TypeError(message)
コード例 #4
0
  def test_with_independent(self):
    """Checks `tfd.Independent` over a wrapped `Normal` against Distrax.

    A sample is drawn from the TFP meta-distribution and its log-probability
    is compared, via `self.assertion_fn`, with the one computed by Distrax's
    own `Independent` on the same base distribution.
    """
    normal = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))

    # TFP path: wrap the base distribution, then apply tfd.Independent.
    tfp_independent = tfd.Independent(
        tfp_compatible_distribution(normal), 1, validate_args=True)
    draw = tfp_independent.sample((), self._key)
    actual = tfp_independent.log_prob(draw)

    # Reference path: Independent applied directly to the base distribution.
    expected = Independent(normal, 1).log_prob(draw)

    self.assertion_fn(actual, expected)
コード例 #5
0
  def test_with_transformed_distribution(self):
    """Checks `tfd.TransformedDistribution` over a wrapped `Normal`.

    A sample is drawn from the TFP meta-distribution (with a `tfb.Exp`
    bijector) and its log-probability is compared, via `self.assertion_fn`,
    with the one computed by `Transformed` on the same base distribution.
    """
    normal = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))

    # TFP path: wrap the base distribution, then transform it with tfb.Exp.
    tfp_transformed = tfd.TransformedDistribution(
        distribution=tfp_compatible_distribution(normal),
        bijector=tfb.Exp(),
        validate_args=True)
    draw = tfp_transformed.sample(seed=self._key)
    actual = tfp_transformed.log_prob(draw)

    # Reference path: Transformed applied directly to the base distribution.
    reference = Transformed(distribution=normal, bijector=tfb.Exp())
    expected = reference.log_prob(draw)

    self.assertion_fn(actual, expected)
コード例 #6
0
 def wrapped_dist(self):
   """Returns `self.base_dist` passed through `tfp_compatible_distribution`."""
   base = self.base_dist
   return tfp_compatible_distribution(base)
コード例 #7
0
 def laplace(a, b):
   """Builds a wrapped `Laplace` from `a` (times a (2, 1) ones array) and `b`."""
   a_times_ones = a * jnp.ones((2, 1))
   return tfp_compatible_distribution(Laplace(a_times_ones, b))
コード例 #8
0
 def test_with_sample(self):
   """Smoke-test a wrapped `Normal(0., 1.)` inside `tfd.Sample`.

   Draws a sample of size 2 with `self._key` and evaluates `log_prob` on it.
   The result is not compared against anything; the test only checks that
   both calls run with `validate_args=True`.
   """
   wrapped = tfp_compatible_distribution(Normal(0., 1.))
   meta_dist = tfd.Sample(wrapped, sample_shape=[1, 3], validate_args=True)

   draws = meta_dist.sample(2, seed=self._key)
   meta_dist.log_prob(draws)