def model_fn():
    """Coroutine model: two wrapped Uniform priors feeding a wrapped Laplace.

    Each `yield` hands a TFP-compatible distribution to the driving joint
    distribution and receives the sampled value back.
    """

    def _uniform_prior(label):
        # Uniform with low = 2 * ones(3) and high = 4, wrapped for TFP.
        return tfp_compatible_distribution(
            Uniform(2. * jnp.ones(3), 4.), name=label)

    a = yield _uniform_prior('a')
    b = yield _uniform_prior('b')
    # The location `a` is broadcast against a (2, 1) array of ones; `b` is the scale.
    yield tfp_compatible_distribution(
        Laplace(a * jnp.ones((2, 1)), b), name='c')
def test_with_joint_distribution_named_auto_batched(self):
    """Wrapped Distrax distributions can sit inside JointDistributionNamedAutoBatched."""

    def make_uniform():
        return tfp_compatible_distribution(Uniform(2. * jnp.ones(3), 4.))

    # Parameter names `a` and `b` are how the joint distribution wires this
    # node to the 'a' and 'b' entries, so they must not be renamed.
    def laplace(a, b):
        return tfp_compatible_distribution(Laplace(a * jnp.ones((2, 1)), b))

    model = {
        'a': make_uniform(),
        'b': make_uniform(),
        'c': laplace,
    }
    meta_dist = tfd.JointDistributionNamedAutoBatched(model, validate_args=True)
    draws = meta_dist.sample(4, seed=self._key)
    meta_dist.log_prob(draws)
def to_tfp(obj: Union[bijector.Bijector, tfb.Bijector,
                      distribution.Distribution, tfd.Distribution],
           name: Optional[str] = None):
    """Converts a distribution or bijector to a TFP-compatible equivalent object.

    The returned object is not necessarily of type `tfb.Bijector` or
    `tfd.Distribution`; rather, it is a Distrax object that implements TFP
    functionality so that it can be used in TFP.

    If the input is already of TFP type, it is returned unchanged (and `name`
    is not applied).

    Args:
      obj: The distribution or bijector to be converted to TFP.
      name: The name of the resulting object.

    Returns:
      A TFP-compatible equivalent distribution or bijector.

    Raises:
      TypeError: If `obj` is neither a Distrax nor a TFP bijector/distribution.
    """
    # Already speaks the TFP API: hand it back untouched.
    if isinstance(obj, (tfb.Bijector, tfd.Distribution)):
        return obj

    if isinstance(obj, bijector.Bijector):
        return tfp_compatible_bijector.tfp_compatible_bijector(obj, name)

    if isinstance(obj, distribution.Distribution):
        return tfp_compatible_distribution.tfp_compatible_distribution(
            obj, name)

    raise TypeError(
        f"`to_tfp` can only convert objects of type: `distrax.Bijector`,"
        f" `tfb.Bijector`, `distrax.Distribution`, `tfd.Distribution`. Got type"
        f" `{type(obj)}`.")
def test_with_independent(self):
    """tfd.Independent over a wrapped Normal agrees with distrax Independent."""
    normal = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
    tfp_independent = tfd.Independent(
        tfp_compatible_distribution(normal), 1, validate_args=True)

    draws = tfp_independent.sample((), seed=self._key)
    actual = tfp_independent.log_prob(draws)

    # Reference value computed purely with Distrax.
    reference = Independent(normal, 1).log_prob(draws)
    self.assertion_fn(actual, reference)
def test_with_transformed_distribution(self):
    """tfd.TransformedDistribution over a wrapped Normal agrees with distrax Transformed."""
    normal = Normal(loc=jnp.array([0., 0.]), scale=jnp.array([1., 1.]))
    tfp_transformed = tfd.TransformedDistribution(
        distribution=tfp_compatible_distribution(normal),
        bijector=tfb.Exp(),
        validate_args=True)

    draws = tfp_transformed.sample(seed=self._key)
    actual = tfp_transformed.log_prob(draws)

    # Reference: the same Exp transform applied via Distrax's Transformed.
    reference = Transformed(distribution=normal, bijector=tfb.Exp()).log_prob(draws)
    self.assertion_fn(actual, reference)
def wrapped_dist(self):
    """Return `self.base_dist` wrapped by `tfp_compatible_distribution`."""
    base = self.base_dist
    return tfp_compatible_distribution(base)
def laplace(a, b):
    """Build a TFP-compatible Laplace with location `a * ones((2, 1))` and scale `b`."""
    # Multiplying by a (2, 1) array of ones broadcasts `a` to a leading axis of 2.
    loc = a * jnp.ones((2, 1))
    return tfp_compatible_distribution(Laplace(loc, b))
def test_with_sample(self):
    """tfd.Sample over a wrapped scalar Normal gives the summed i.i.d. log-prob.

    Previously this test only checked that `log_prob` ran without raising.
    It now also compares the value against a Distrax reference.
    """
    base_dist = Normal(0., 1.)
    wrapped_dist = tfp_compatible_distribution(base_dist)
    meta_dist = tfd.Sample(
        wrapped_dist, sample_shape=[1, 3], validate_args=True)

    samples = meta_dist.sample(2, seed=self._key)
    log_prob = meta_dist.log_prob(samples)

    # tfd.Sample turns `sample_shape` ([1, 3]) into event dimensions, so its
    # log-prob is the base log-prob summed over those trailing two axes.
    expected_log_prob = base_dist.log_prob(samples).sum(axis=(-2, -1))
    self.assertion_fn(log_prob, expected_log_prob)