コード例 #1
0
def as_bijector(obj: BijectorLike) -> bijector.BijectorT:
    """Returns a Distrax bijector equivalent to a bijector-like object.

    The input is inspected in this order:

    * a `distrax.Bijector` is returned unchanged;
    * a TFP bijector (`tfb.Bijector`) is wrapped in `BijectorFromTFP`;
    * exactly `jax.nn.sigmoid` or exactly `jnp.tanh` (checked by identity)
      is mapped to the dedicated `Sigmoid` / `Tanh` bijector;
    * any other callable is wrapped in `distrax.Lambda`.

    Args:
      obj: The bijector-like object to convert.

    Returns:
      A Distrax bijector.

    Raises:
      TypeError: If `obj` is neither a bijector nor callable.
    """
    if isinstance(obj, bijector.Bijector):
        return obj
    if isinstance(obj, tfb.Bijector):
        return bijector_from_tfp.BijectorFromTFP(obj)
    # Identity (not equality) checks: only these exact function objects have
    # explicit implementations; every other callable falls through to Lambda.
    if obj is jax.nn.sigmoid:
        return sigmoid.Sigmoid()
    if obj is jnp.tanh:
        return tanh.Tanh()
    if not callable(obj):
        raise TypeError(
            "A bijector-like object can be a `distrax.Bijector`, a `tfb.Bijector`,"
            f" or a callable. Got type `{type(obj)}`.")
    return lambda_bijector.Lambda(obj)
コード例 #2
0
ファイル: tanh_test.py プロジェクト: stjordanis/distrax
  def test_jittable(self):
    """Tanh can be passed as an argument to a jitted function and is correct."""

    @jax.jit
    def f(x, b):
      return b.forward(x)

    bijector = tanh.Tanh()
    x = np.zeros(())
    # Check the jitted value, not merely that tracing/compilation succeeds.
    y = f(x, bijector)
    np.testing.assert_allclose(y, np.tanh(x))
コード例 #3
0
ファイル: tanh_test.py プロジェクト: stjordanis/distrax
  def test_integer_inputs(self, inputs):
    """forward_and_log_det accepts integer inputs and matches references."""
    fwd_and_ldj = self.variant(tanh.Tanh().forward_and_log_det)
    y, ldj = fwd_and_ldj(inputs)

    # NOTE(review): log|d tanh(x)/dx| = log(1 - tanh(x)**2) is zero only at
    # x == 0, so this expectation assumes the parameterized `inputs` are all
    # zeros -- confirm against the test's parameter list.
    ldj_ref = jnp.zeros_like(inputs, dtype=jnp.float32)
    y_ref = jnp.tanh(inputs).astype(jnp.float32)

    np.testing.assert_array_equal(y, y_ref)
    np.testing.assert_array_equal(ldj, ldj_ref)
コード例 #4
0
ファイル: tanh_test.py プロジェクト: stjordanis/distrax
  def test_event_shape(self, mu, sigma, base_dist):
    """Event shape of the Tanh-transformed distribution matches TFP's."""
    base = base_dist(mu, sigma)
    reference = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfb.Tanh())
    distrax_dist = transformed.Transformed(base, tanh.Tanh())
    assert distrax_dist.event_shape == reference.event_shape
コード例 #5
0
ファイル: tanh_test.py プロジェクト: stjordanis/distrax
  def test_prob(self, mu, sigma, value, base_dist):
    """Density of the Tanh-transformed distribution matches TFP's."""
    base = base_dist(mu, sigma)
    distrax_dist = transformed.Transformed(base, tanh.Tanh())
    got = self.variant(distrax_dist.prob)(value)

    reference = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfb.Tanh())
    np.testing.assert_allclose(got, reference.prob(value), atol=1e-9)
コード例 #6
0
ファイル: tanh_test.py プロジェクト: stjordanis/distrax
  def test_stability(self):
    """Log-det Jacobians agree with TFP, including at |x| = 10."""
    ours = tanh.Tanh()
    reference = tfb.Tanh()
    xs = np.array([-10.0, -3.3, 0.0, 3.3, 10.0], dtype=np.float32)

    # Forward direction, evaluated at xs.
    expected_fwd = reference.forward_log_det_jacobian(xs, event_ndims=0)
    got_fwd = self.variant(ours.forward_log_det_jacobian)(xs)
    np.testing.assert_allclose(got_fwd, expected_fwd, rtol=RTOL)

    # Inverse direction, evaluated at the forward images of xs.
    ys = ours.forward(xs)
    expected_inv = reference.inverse_log_det_jacobian(ys, event_ndims=0)
    got_inv = self.variant(ours.inverse_log_det_jacobian)(ys)
    np.testing.assert_allclose(got_inv, expected_inv, rtol=RTOL)
コード例 #7
0
ファイル: tanh_test.py プロジェクト: stjordanis/distrax
  def test_sample_shape(self, mu, sigma, sample_shape, base_dist):
    """Samples of the Tanh-transformed distribution match TFP's shape."""
    base = base_dist(mu, sigma)
    bijector = tanh.Tanh()
    dist = transformed.Transformed(base, bijector)

    def sample_fn(seed, sample_shape):
      return dist.sample(seed=seed, sample_shape=sample_shape)

    # `sample_shape` (positional arg 1) is presumably a Python shape, so it is
    # marked static for jitted variants and listed in ignore_argnums; the
    # tuple form matches test_sample_and_log_prob.
    samples = self.variant(
        sample_fn, ignore_argnums=(1,), static_argnums=(1,))(
            self.seed, sample_shape)

    tfp_bijector = tfb.Tanh()
    tfp_dist = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfp_bijector)
    tfp_samples = tfp_dist.sample(sample_shape=sample_shape, seed=self.seed)

    chex.assert_equal_shape([samples, tfp_samples])
コード例 #8
0
ファイル: tanh_test.py プロジェクト: stjordanis/distrax
  def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
    """Joint sampling agrees with TFP and with pushing base samples forward."""
    base = base_dist(mu, sigma)
    bij = tanh.Tanh()
    dist = transformed.Transformed(base, bij)

    def draw(seed, shape):
      return dist.sample_and_log_prob(seed=seed, sample_shape=shape)

    jitted_draw = self.variant(
        draw, ignore_argnums=(1,), static_argnums=(1,))
    samples, log_prob = jitted_draw(self.seed, sample_shape)

    # The last assertion relies on Transformed sampling its base with the same
    # seed usage as `base.sample`, so that mapping base samples through the
    # bijector reproduces `samples` exactly (up to RTOL).
    expected_samples = bij.forward(
        base.sample(seed=self.seed, sample_shape=sample_shape))

    reference = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfb.Tanh())
    tfp_samples = reference.sample(seed=self.seed, sample_shape=sample_shape)
    tfp_log_prob = reference.log_prob(samples)

    chex.assert_equal_shape([samples, tfp_samples])
    np.testing.assert_allclose(log_prob, tfp_log_prob, rtol=RTOL)
    np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)