def test_cumsum_jac(shape):
    """Check ``cumsum`` against an explicit running sum over the last axis.

    Both the values and the Jacobians (w.r.t. the input) must agree.
    NOTE(review): the reference below only covers a last axis of length 3;
    presumably the test parametrization guarantees ``shape[-1] == 3``.
    """
    key = random.PRNGKey(0)
    values = random.normal(key, shape=shape)

    def reference_cumsum(v):
        # Build the partial sums one at a time: v0, v0 + v1, (v0 + v1) + v2.
        partial_sums = [v[..., 0]]
        for idx in (1, 2):
            partial_sums.append(partial_sums[-1] + v[..., idx])
        return np.stack(partial_sums, -1)

    assert_allclose(cumsum(values), reference_cumsum(values))
    assert_allclose(jacobian(cumsum)(values), jacobian(reference_cumsum)(values))
Beispiel #2
0
 def inv(self, y):
     """Recover the unconstrained vector ``x`` from ``y``.

     Two steps undo the forward map:
       1. Inverse stick-breaking recovers ``t`` from the strictly lower
          triangular part of ``y``.
       2. The inverse of tanh maps ``t`` back to ``x``.

     NOTE(review): assumes ``y`` is a lower-triangular matrix (e.g. a
     Cholesky factor) with the matrix in the last two axes — TODO confirm.
     """
     # inverse stick-breaking: running "1 - sum of squares" along each row
     z1m_cumprod = 1 - cumsum(y * y)
     # Shift right by one along the last axis and fill with 1, so entry j only
     # accounts for the squared entries strictly before column j.
     pad_width = [(0, 0)] * y.ndim
     pad_width[-1] = (1, 0)
     z1m_cumprod_shifted = np.pad(z1m_cumprod[..., :-1], pad_width,
                                  mode="constant", constant_values=1.)
     # presumably matrix_to_tril_vec(..., diagonal=-1) flattens the strictly
     # lower triangle; each entry is divided by the remaining stick length.
     t = matrix_to_tril_vec(y, diagonal=-1) / np.sqrt(
         matrix_to_tril_vec(z1m_cumprod_shifted, diagonal=-1))
     # inverse of tanh. arctanh(t) equals log((1 + t) / (1 - t)) / 2 exactly in
     # real arithmetic. It avoids the precision loss of taking log of a ratio
     # that rounds to ~1 when |t| is small (severe in float32).
     x = np.arctanh(t)
     return x
Beispiel #3
0
    def log_abs_det_jacobian(self, x, y, intermediates=None):
        """Return the log-abs-det-Jacobian term for the map ``x -> y``.

        NB: domain and codomain have different dimensions, so the Jacobian
        determinant is not well-defined. The value returned is for the map from
        ``x`` to the flattened lower-triangular part of ``y``.
        ``intermediates`` is accepted for interface compatibility and unused.
        """
        # tanh part: log(1 - tanh(x)**2) == -2 * (x + softplus(-2x) - log 2),
        # assuming softplus(z) == log(1 + exp(z)).
        tanh_term = -2 * np.sum(x + softplus(-2 * x) - np.log(2.), axis=-1)

        # Stick-breaking part: log(y / r) = log(z_cumprod), up to a right shift.
        # Selecting with diagonal=-2 makes that shift unnecessary. For a 2 x 2
        # input the selection is empty, so its sum contributes nothing.
        remaining = 1 - cumsum(y * y)
        remaining_tril = matrix_to_tril_vec(remaining, diagonal=-2)
        stick_term = 0.5 * np.sum(np.log(remaining_tril), axis=-1)

        return stick_term + tanh_term
Beispiel #4
0
 def sample(self, key, sample_shape=()):
     """Draw samples as running sums of normal increments over the last axis.

     Increments have shape ``sample_shape + batch_shape + event_shape``. Their
     running sum is multiplied by ``self.scale``, which gets a trailing
     singleton axis so it broadcasts across the last (event) axis.
     """
     full_shape = sample_shape + self.batch_shape + self.event_shape
     steps = random.normal(key, shape=full_shape)
     path = cumsum(steps)
     scale = np.expand_dims(self.scale, axis=-1)
     return path * scale
Beispiel #5
0
 def inv(self, y):
     """Invert the stick-breaking map, recovering ``x`` from ``y``.

     ``x`` has one fewer entry than ``y`` along the last axis.
     NOTE(review): presumably ``y`` lies on the probability simplex along its
     last axis, with the final entry being the leftover stick — TODO confirm.
     """
     y_crop = y[..., :-1]
     # Remaining stick length after each break. Floor it at the smallest
     # positive normal float so the division below never divides by zero.
     # np.maximum is used instead of clip-with-only-a-lower-bound because:
     #   - NumPy < 2.1 requires a_max, so np.clip(..., a_min=...) raises there;
     #   - recent jax.numpy deprecates the a_min keyword.
     z1m_cumprod = np.maximum(1 - cumsum(y_crop), np.finfo(y.dtype).tiny)
     # z / (1 - z) == y_crop / z1m_cumprod,
     # hence x = logit(z) = log(y_crop / z1m_cumprod)
     x = np.log(y_crop / z1m_cumprod)
     # Add the per-position offset log(K - i) for i = 0..K-1, K = x.shape[-1].
     return x + np.log(x.shape[-1] - np.arange(x.shape[-1]))
Beispiel #6
0
 def __call__(self, x):
     """Map ``x`` to an ordered sequence along the last axis.

     The first entry passes through unchanged. The rest are exponentiated
     into positive increments, and a running sum over the last axis yields
     the output.
     """
     head = x[..., :1]
     positive_steps = np.exp(x[..., 1:])
     increments = np.concatenate([head, positive_steps], axis=-1)
     return cumsum(increments)