Esempio n. 1
0
 def inverse_log_det_jacobian(
     self, y: Array, event_ndims: Optional[int] = None) -> Array:
   """Computes the inverse log-det-Jacobian of the wrapped `base_bijector`.

   See `Bijector.inverse_log_det_jacobian`. `event_ndims` is validated by
   `self._check_ndims` against `base_bijector.event_ndims_out`. The base
   result is then summed over the number of trailing axes that call returns.
   """
   # Validation happens before any log-det computation.
   num_dims_to_sum = self._check_ndims(
       "Inverse", event_ndims, base_bijector.event_ndims_out)
   return math.sum_last(
       base_bijector.inverse_log_det_jacobian(y), num_dims_to_sum)
Esempio n. 2
0
 def forward_log_det_jacobian(
     self, x: Array, event_ndims: Optional[int] = None) -> Array:
   """Computes the forward log-det-Jacobian of the wrapped `base_bijector`.

   See `Bijector.forward_log_det_jacobian`. `event_ndims` is validated by
   `self._check_ndims` against `base_bijector.event_ndims_in`. The base
   result is then summed over the number of trailing axes that call returns.
   """
   # Validation happens before any log-det computation.
   num_dims_to_sum = self._check_ndims(
       "Forward", event_ndims, base_bijector.event_ndims_in)
   return math.sum_last(
       base_bijector.forward_log_det_jacobian(x), num_dims_to_sum)
Esempio n. 3
0
    def log_prob(bijector, inputs):
        """Log-density of `inputs` under the flow defined by `bijector`.

        Maps `inputs` to the latent space with
        `bijector.forward_and_log_det` and scores the result with the
        enclosing `base_dist`. Returns the latent log-probability plus the
        log-det term, each summed over every axis except the leading one.
        The leading axis is presumably the batch axis; confirm with callers.
        """
        # forward transformation into the latent space
        outputs, log_det = bijector.forward_and_log_det(inputs)

        # log-probability in the latent space
        latent_log_prob = base_dist.log_prob(outputs)

        def _sum_all_but_leading(values):
            # Sum every trailing axis and keep axis 0. Clamp at 0 so a 0-d
            # value never asks `sum_last` for a negative number of axes.
            return sum_last(values, ndims=max(values.ndim - 1, 0))

        # change of variables: log p(x) = log p_z(f(x)) + log|det J_f(x)|
        return (_sum_all_but_leading(latent_log_prob)
                + _sum_all_but_leading(log_det))
Esempio n. 4
0
 def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
   """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.

   Entries selected by `self._mask` are passed through unchanged. With all
   other entries zeroed, they form the conditioner's input. Entries selected
   by `self._neg_mask` are inverted by the inner bijector built from the
   conditioner's output, and only their log-det terms are kept.

   NOTE(review): assumes `self._mask` and `self._neg_mask` are complementary
   boolean (or 0/1) masks. Confirm against the constructor.

   Args:
     y: input to the inverse transformation.

   Returns:
     `(x, logdet)`. `logdet` is the masked inner log-det summed over the
     last `self._event_ndims` axes.
   """
   import jax.numpy as jnp  # Local import: needed for the selections below.

   self._check_inverse_input_shape(y)
   # Select with `jnp.where` instead of multiplying by a 0/1 mask. 0 * inf
   # and 0 * nan both evaluate to nan, so a non-finite value in a position
   # the mask is meant to discard would otherwise corrupt the result.
   masked_y = jnp.where(self._mask, y, 0.)
   params = self._conditioner(masked_y)
   x0, log_d = self._inner_bijector(params).inverse_and_log_det(y)
   x = jnp.where(self._neg_mask, x0, masked_y)
   logdet = math.sum_last(
       jnp.where(self._neg_mask, log_d, 0.), self._event_ndims)
   return x, logdet
Esempio n. 5
0
 def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
   """Computes y = f(x) and log|det J(f)(x)|.

   Entries selected by `self._mask` are passed through unchanged. With all
   other entries zeroed, they form the conditioner's input. Entries selected
   by `self._neg_mask` are transformed by the inner bijector built from the
   conditioner's output, and only their log-det terms are kept.

   NOTE(review): assumes `self._mask` and `self._neg_mask` are complementary
   boolean (or 0/1) masks. Confirm against the constructor.

   Args:
     x: input to the forward transformation.

   Returns:
     `(y, logdet)`. `logdet` is the masked inner log-det summed over the
     last `self._event_ndims` axes.
   """
   import jax.numpy as jnp  # Local import: needed for the selections below.

   self._check_forward_input_shape(x)
   # Select with `jnp.where` instead of multiplying by a 0/1 mask. 0 * inf
   # and 0 * nan both evaluate to nan, so a non-finite value in a position
   # the mask is meant to discard would otherwise corrupt the result.
   masked_x = jnp.where(self._mask, x, 0.)
   params = self._conditioner(masked_x)
   y0, log_d = self._inner_bijector(params).forward_and_log_det(x)
   y = jnp.where(self._neg_mask, y0, masked_x)
   logdet = math.sum_last(
       jnp.where(self._neg_mask, log_d, 0.), self._event_ndims)
   return y, logdet
Esempio n. 6
0
 def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
     """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)| for the coupling.

     Entries of `y` where `self._event_mask` is set are copied to `x`
     unchanged. With all other entries zeroed, they form the conditioner's
     input. The inner bijector built from the conditioner's output inverts
     the remaining entries. Log-det terms where `self._mask` is set are
     zeroed, and the rest are summed over the last
     `self._event_ndims - self._inner_event_ndims` axes.
     """
     self._check_inverse_input_shape(y)
     keep = self._event_mask
     conditioner_input = jnp.where(keep, y, 0.)
     inner = self._inner_bijector(self._conditioner(conditioner_input))
     transformed, inner_log_det = inner.inverse_and_log_det(y)
     x = jnp.where(keep, y, transformed)
     reduce_ndims = self._event_ndims - self._inner_event_ndims
     logdet = math.sum_last(jnp.where(self._mask, 0., inner_log_det),
                            reduce_ndims)
     return x, logdet
Esempio n. 7
0
 def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
     """Computes y = f(x) and log|det J(f)(x)| for the coupling.

     Entries of `x` where `self._event_mask` is set are copied to `y`
     unchanged. With all other entries zeroed, they form the conditioner's
     input. The inner bijector built from the conditioner's output
     transforms the remaining entries. Log-det terms where `self._mask` is
     set are zeroed, and the rest are summed over the last
     `self._event_ndims - self._inner_event_ndims` axes.
     """
     self._check_forward_input_shape(x)
     keep = self._event_mask
     conditioner_input = jnp.where(keep, x, 0.)
     inner = self._inner_bijector(self._conditioner(conditioner_input))
     transformed, inner_log_det = inner.forward_and_log_det(x)
     y = jnp.where(keep, x, transformed)
     reduce_ndims = self._event_ndims - self._inner_event_ndims
     logdet = math.sum_last(jnp.where(self._mask, 0., inner_log_det),
                            reduce_ndims)
     return y, logdet
Esempio n. 8
0
 def inverse_log_det_jacobian(self, y: Array) -> Array:
     """Returns log|det J(f^{-1})(y)| summed over the last `self._ndims` axes.

     The shape of `y` is validated first. The wrapped bijector's log-det is
     then reduced with `math.sum_last`.
     """
     self._check_inverse_input_shape(y)
     return math.sum_last(
         self._bijector.inverse_log_det_jacobian(y), self._ndims)
Esempio n. 9
0
 def forward_log_det_jacobian(self, x: Array) -> Array:
     """Returns log|det J(f)(x)| summed over the last `self._ndims` axes.

     The shape of `x` is validated first. The wrapped bijector's log-det is
     then reduced with `math.sum_last`.
     """
     self._check_forward_input_shape(x)
     return math.sum_last(
         self._bijector.forward_log_det_jacobian(x), self._ndims)
Esempio n. 10
0
 def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
     """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)| for the wrapped bijector.

     The log-det returned by `self._bijector` is summed over the last
     `self._ndims` axes before being returned alongside `x`.
     """
     self._check_inverse_input_shape(y)
     inverted, raw_log_det = self._bijector.inverse_and_log_det(y)
     summed_log_det = math.sum_last(raw_log_det, self._ndims)
     return inverted, summed_log_det
Esempio n. 11
0
 def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
     """Computes y = f(x) and log|det J(f)(x)| for the wrapped bijector.

     The log-det returned by `self._bijector` is summed over the last
     `self._ndims` axes before being returned alongside `y`.
     """
     self._check_forward_input_shape(x)
     transformed, raw_log_det = self._bijector.forward_and_log_det(x)
     summed_log_det = math.sum_last(raw_log_det, self._ndims)
     return transformed, summed_log_det
Esempio n. 12
0
 def test_sum_last(self):
     # Check `math.sum_last(x, k)` for every k from 0 to x.ndim against the
     # equivalent explicit reduction over the last k axes.
     key = jax.random.PRNGKey(42)
     x = jax.random.normal(key, (2, 3, 4))
     expected_by_ndims = {
         0: x,
         1: x.sum(-1),
         2: x.sum((-2, -1)),
         3: x.sum(),
     }
     for ndims, expected in expected_by_ndims.items():
         np.testing.assert_array_equal(math.sum_last(x, ndims), expected)