def test_log_normal_log_prob(self):
        """Compare `log_prob` of an exp-of-normal sampler with a reference.

        The sampler returns `np.exp(random_normal(rng))`; the reference is
        `bd.Normal(0., 1.)` transformed by `bb.Exp()`. Both log densities
        are evaluated at 2.0 and must be equal.
        """

        def sample_log_normal(rng):
            # Exponentiate a normal draw taken with the supplied key.
            return np.exp(random_normal(rng))

        reference = bd.TransformedDistribution(bd.Normal(0., 1.), bb.Exp())
        derived_log_prob = log_prob(sample_log_normal)
        self.assertEqual(derived_log_prob(2.), reference.log_prob(2.))
# limitations under the License.
# ============================================================================
# Lint as: python3
"""Tests for tensorflow_probability.spinoffs.oryx.bijectors.bijectors_extensions."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
import numpy as onp

from oryx import bijectors as bb
from oryx import core
from oryx.internal import test_util

# Cases fed to `parameterized.named_parameters`. Each tuple is
# (test name, zero-argument bijector factory, input, flat), lining up with the
# `bij, inp, flat` parameters of `test_forward`, which calls `bij()` to build
# the bijector and passes `inp` to its `forward`.
# NOTE(review): `flat` is discarded by `test_forward`; presumably it lists the
# bijector's expected flattened parameters — confirm against the test that
# consumes it.
BIJECTORS = [
    ('exp', lambda: bb.Exp(), 1., []),  # pylint: disable=unnecessary-lambda
    ('affine_scalar', lambda: bb.AffineScalar(1., 2.), 1., [2., 1.]),
    ('transform_diagonal', lambda: bb.TransformDiagonal(bb.Exp()),
     onp.eye(2).astype(onp.float32), []),
    ('invert', lambda: bb.Invert(bb.Exp()), 1., []),
]


class BijectorsExtensionsTest(test_util.TestCase):
    @parameterized.named_parameters(BIJECTORS)
    def test_forward(self, bij, inp, flat):
        """Smoke-test `forward`: build the bijector and apply it to `inp`.

        The result is not inspected; the case passes when no exception is
        raised. `flat` belongs to the shared parameter tuple and is unused.
        """
        del flat
        bij().forward(inp)

    @parameterized.named_parameters(BIJECTORS)
# Example no. 3
# 0
# limitations under the License.
# ============================================================================
# Lint as: python3
"""Tests for tensorflow_probability.spinoffs.oryx.bijectors.bijectors_extensions."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
import numpy as onp
from oryx import bijectors as bb
from oryx import core

# Cases fed to `parameterized.named_parameters`. Each tuple is
# (test name, bijector class, positional args, keyword args, input, flat),
# lining up with the `bij, args, kwargs, inp, flat` parameters of
# `test_forward`, which builds the bijector as `bij(*args, **kwargs)` and
# passes `inp` to its `forward`.
# The nested `bb.Exp()` instances inside the args tuples are constructed once,
# when this list is evaluated, rather than per test case.
# NOTE(review): `flat` is discarded by `test_forward`; presumably it lists the
# bijector's expected flattened parameters — confirm against the test that
# consumes it.
BIJECTORS = [
    ('exp', bb.Exp, (), {}, 1., []),
    ('affine_scalar', bb.AffineScalar, (1., 2.), {}, 1., [1., 2.]),
    ('transform_diagonal', bb.TransformDiagonal, (bb.Exp(), ), {},
     onp.eye(2).astype(onp.float32), []),
    ('invert', bb.Invert, (bb.Exp(), ), {}, 1., []),
]


class BijectorsExtensionsTest(parameterized.TestCase):
    @parameterized.named_parameters(BIJECTORS)
    def test_forward(self, bij, args, kwargs, inp, flat):
        """Smoke-test `forward` on a bijector built from `args` and `kwargs`.

        The bijector is constructed as `bij(*args, **kwargs)` and applied to
        `inp`; the result is not inspected, so the case passes when no
        exception is raised. `flat` belongs to the shared parameter tuple and
        is unused.
        """
        del flat
        bij(*args, **kwargs).forward(inp)

    @parameterized.named_parameters(BIJECTORS)
    def test_inverse(self, bij, args, kwargs, inp, flat):
        del flat