def inverse_log_det_jacobian(
    self, y: Array, event_ndims: Optional[int] = None) -> Array:
  """See `Bijector.inverse_log_det_jacobian`."""
  extra_event_ndims = self._check_ndims(
      "Inverse", event_ndims, base_bijector.event_ndims_out)
  ildj = base_bijector.inverse_log_det_jacobian(y)
  return math.sum_last(ildj, extra_event_ndims)
def forward_log_det_jacobian(
    self, x: Array, event_ndims: Optional[int] = None) -> Array:
  """See `Bijector.forward_log_det_jacobian`."""
  extra_event_ndims = self._check_ndims(
      "Forward", event_ndims, base_bijector.event_ndims_in)
  fldj = base_bijector.forward_log_det_jacobian(x)
  return math.sum_last(fldj, extra_event_ndims)
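To make the role of `extra_event_ndims` concrete, here is a small, self-contained shape illustration; the array sizes and the scalar-bijector setting are invented for the example, and `_check_ndims` is assumed to return the number of requested event dimensions beyond the bijector's native ones:

import jax.numpy as jnp

# A scalar (elementwise) bijector produces one log-det value per element.
per_element_ildj = jnp.ones((5, 3, 4))  # batch of 5, event shape (3, 4)

event_ndims = 2        # requested by the caller
base_event_ndims = 0   # the wrapped bijector is elementwise
extra_event_ndims = event_ndims - base_event_ndims

# Equivalent to math.sum_last(per_element_ildj, extra_event_ndims): the extra
# event axes are reduced so that one log-det remains per batch entry.
ildj = jnp.sum(per_element_ildj, axis=tuple(range(-extra_event_ndims, 0)))
print(ildj.shape)  # (5,)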
def log_prob(bijector, inputs):
  # Forward transformation from data space into the latent space.
  outputs, log_det = bijector.forward_and_log_det(inputs)
  # Log probability of the transformed samples under the base distribution.
  latent_prob = base_dist.log_prob(outputs)
  # Total log probability: sum the event dimensions of both terms, keeping
  # only the leading batch dimension.
  log_prob = sum_last(latent_prob, ndims=latent_prob.ndim - 1) + sum_last(
      log_det, ndims=log_det.ndim - 1)
  # # log probability
  # log_prob = sum_last(latent_prob, ndims=latent_prob.ndim) + sum_last(
  #     log_det, ndims=latent_prob.ndim
  # )
  return log_prob
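As a usage sketch (not from the original source), the helper above can be driven with a distrax bijector and a base distribution; the standard-normal latent, the `Block(Tanh)` flow, and the local stand-in for `sum_last` are all assumptions made for illustration:

import distrax
import jax
import jax.numpy as jnp

# Assumed behaviour of `sum_last`: reduce the trailing `ndims` axes.
sum_last = lambda x, ndims: jnp.sum(x, axis=tuple(range(-ndims, 0)))

# Assumed setup: a 2-D standard-normal latent and an elementwise Tanh flow
# whose last axis is treated as the event.
base_dist = distrax.MultivariateNormalDiag(
    loc=jnp.zeros(2), scale_diag=jnp.ones(2))
bijector = distrax.Block(distrax.Tanh(), ndims=1)

inputs = jax.random.normal(jax.random.PRNGKey(0), (8, 2))
print(log_prob(bijector, inputs).shape)  # (8,): one log-prob per batch entry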
def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
  """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
  self._check_inverse_input_shape(y)
  masked_y = self._mask * y
  params = self._conditioner(masked_y)
  x0, log_d = self._inner_bijector(params).inverse_and_log_det(y)
  x = self._neg_mask * x0 + masked_y
  logdet = math.sum_last(self._neg_mask * log_d, self._event_ndims)
  return x, logdet
def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
  """Computes y = f(x) and log|det J(f)(x)|."""
  self._check_forward_input_shape(x)
  masked_x = self._mask * x
  params = self._conditioner(masked_x)
  y0, log_d = self._inner_bijector(params).forward_and_log_det(x)
  y = self._neg_mask * y0 + masked_x
  logdet = math.sum_last(self._neg_mask * log_d, self._event_ndims)
  return y, logdet
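A hedged usage sketch of a masked coupling layer of this form, assuming the public `distrax.MaskedCoupling(mask=..., conditioner=..., bijector=...)` constructor; the mask pattern, the toy linear conditioner, and the `ScalarAffine` inner bijector are invented for the example:

import distrax
import jax
import jax.numpy as jnp

# Coordinates where the mask is True are passed through unchanged and fed to
# the conditioner; the remaining coordinates are transformed elementwise.
mask = jnp.arange(8) % 2 == 0
conditioner = lambda masked_x: 0.5 * masked_x  # stand-in for a neural network
layer = distrax.MaskedCoupling(
    mask=mask,
    conditioner=conditioner,
    bijector=lambda params: distrax.ScalarAffine(shift=params, log_scale=params),
)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
y, log_det = layer.forward_and_log_det(x)
print(y.shape, log_det.shape)  # (4, 8) and (4,): log-det summed over the event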
def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
  """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
  self._check_inverse_input_shape(y)
  masked_y = jnp.where(self._event_mask, y, 0.)
  params = self._conditioner(masked_y)
  x0, log_d = self._inner_bijector(params).inverse_and_log_det(y)
  x = jnp.where(self._event_mask, y, x0)
  logdet = math.sum_last(
      jnp.where(self._mask, 0., log_d),
      self._event_ndims - self._inner_event_ndims)
  return x, logdet
def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
  """Computes y = f(x) and log|det J(f)(x)|."""
  self._check_forward_input_shape(x)
  masked_x = jnp.where(self._event_mask, x, 0.)
  params = self._conditioner(masked_x)
  y0, log_d = self._inner_bijector(params).forward_and_log_det(x)
  y = jnp.where(self._event_mask, x, y0)
  logdet = math.sum_last(
      jnp.where(self._mask, 0., log_d),
      self._event_ndims - self._inner_event_ndims)
  return y, logdet
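A shape-level illustration of the log-det reduction used above, with invented sizes: the inner bijector acts on events of shape `(3,)` (`inner_event_ndims = 1`) inside an overall event of shape `(8, 3)` (`event_ndims = 2`), so the mask covers only the 8 outer positions and only `event_ndims - inner_event_ndims` axes remain to be summed:

import jax.numpy as jnp

event_ndims, inner_event_ndims = 2, 1
mask = jnp.arange(8) % 2 == 0      # shape (8,): outer event positions only
log_d = jnp.ones((16, 8))          # batch of 16; inner event axis already reduced

# Masked positions contribute nothing, then the remaining outer event axis is
# summed, leaving one log-det per batch element.
masked_log_d = jnp.where(mask, 0., log_d)
logdet = jnp.sum(
    masked_log_d, axis=tuple(range(-(event_ndims - inner_event_ndims), 0)))
print(logdet.shape)  # (16,)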
def inverse_log_det_jacobian(self, y: Array) -> Array:
  """Computes log|det J(f^{-1})(y)|."""
  self._check_inverse_input_shape(y)
  log_det = self._bijector.inverse_log_det_jacobian(y)
  return math.sum_last(log_det, self._ndims)
def forward_log_det_jacobian(self, x: Array) -> Array:
  """Computes log|det J(f)(x)|."""
  self._check_forward_input_shape(x)
  log_det = self._bijector.forward_log_det_jacobian(x)
  return math.sum_last(log_det, self._ndims)
def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
  """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
  self._check_inverse_input_shape(y)
  x, log_det = self._bijector.inverse_and_log_det(y)
  return x, math.sum_last(log_det, self._ndims)
def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
  """Computes y = f(x) and log|det J(f)(x)|."""
  self._check_forward_input_shape(x)
  y, log_det = self._bijector.forward_and_log_det(x)
  return y, math.sum_last(log_det, self._ndims)
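A usage sketch of a `Block`-style wrapper of this kind, assuming the public `distrax.Block(bijector, ndims)` constructor; the `ScalarAffine` base bijector and the array shapes are chosen only for illustration:

import distrax
import jax
import jax.numpy as jnp

# Promote an elementwise bijector to one whose event is the trailing 2 axes;
# the per-element log-dets are then reduced over those axes.
block = distrax.Block(distrax.ScalarAffine(shift=0., scale=2.), ndims=2)

x = jax.random.normal(jax.random.PRNGKey(0), (5, 3, 4))
y, log_det = block.forward_and_log_det(x)
print(y.shape, log_det.shape)  # (5, 3, 4) and (5,)
print(log_det[0])              # 12 * log(2): one log(2) per element in the event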
def test_sum_last(self):
  x = jax.random.normal(jax.random.PRNGKey(42), (2, 3, 4))
  np.testing.assert_array_equal(math.sum_last(x, 0), x)
  np.testing.assert_array_equal(math.sum_last(x, 1), x.sum(-1))
  np.testing.assert_array_equal(math.sum_last(x, 2), x.sum((-2, -1)))
  np.testing.assert_array_equal(math.sum_last(x, 3), x.sum())
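For reference, a minimal implementation that would satisfy the test above; this is a sketch of the assumed behaviour of `math.sum_last`, not necessarily the library's actual code:

import jax.numpy as jnp

def sum_last(x, ndims):
  """Sums over the last `ndims` axes of `x`; `ndims=0` leaves `x` unchanged."""
  return jnp.sum(x, axis=tuple(range(-ndims, 0)))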