def test_cumsum_jac(shape):
    """Check ``cumsum`` and its Jacobian against an explicit running-sum reference.

    The reference accumulates the last axis left to right, one column at a
    time, so it works for any trailing dimension (not only 3). For a trailing
    dimension of 3 it performs exactly the same additions as a hand-written
    ``x0``, ``x0 + x1``, ``x0 + x1 + x2`` stack.
    """
    rng = random.PRNGKey(0)
    x = random.normal(rng, shape=shape)

    def test_fn(x):
        # Column i is x[..., 0] + ... + x[..., i], summed in index order so the
        # floating-point result matches a sequential running sum.
        running = x[..., 0]
        columns = [running]
        for i in range(1, x.shape[-1]):
            running = running + x[..., i]
            columns.append(running)
        return np.stack(columns, -1)

    assert_allclose(cumsum(x), test_fn(x))
    assert_allclose(jacobian(cumsum)(x), jacobian(test_fn)(x))
def inv(self, y):
    """Map a lower-triangular (Cholesky-style) factor back to unconstrained values.

    Undoes a stick-breaking construction along the last axis, then applies the
    inverse of ``tanh`` to the recovered values.

    :param y: lower-triangular factor(s); ``cumsum`` accumulates along the last
        axis. NOTE(review): assumed to be a Cholesky factor of a correlation
        matrix, i.e. rows of unit norm — confirm against the forward map.
    :return: the entries selected by ``matrix_to_tril_vec(..., diagonal=-1)``
        (presumably the flattened strictly-lower triangle), in unconstrained space.
    """
    # inverse stick-breaking: 1 - running sum of squares along each row
    z1m_cumprod = 1 - cumsum(y * y)
    # shift one column to the right, padding the first column with 1.
    pad_width = [(0, 0)] * y.ndim
    pad_width[-1] = (1, 0)
    z1m_cumprod_shifted = np.pad(z1m_cumprod[..., :-1], pad_width,
                                 mode="constant", constant_values=1.)
    t = matrix_to_tril_vec(y, diagonal=-1) / np.sqrt(
        matrix_to_tril_vec(z1m_cumprod_shifted, diagonal=-1))
    # inverse of tanh. Written with log1p: the form log((1 + t) / (1 - t)) / 2
    # rounds the ratio to ~1 for small |t| and loses relative precision, while
    # log1p(t) and log1p(-t) have opposite signs so their difference is exact-ish.
    x = (np.log1p(t) - np.log1p(-t)) / 2
    return x
def log_abs_det_jacobian(self, x, y, intermediates=None):
    """Log |det J| from ``x`` to the flattened lower-triangular part of ``y``.

    Domain and codomain have different dimensions, so the Jacobian determinant
    is not well-defined as such. This returns the log-abs-det of the map from
    ``x`` to the flattened lower-triangular entries of ``y``. It is the sum of
    a stick-breaking term and a ``tanh`` term. ``intermediates`` is not used.
    """
    # Stick-breaking part: 1 - running sum of y**2 along the last axis.
    remaining = 1 - cumsum(y * y)
    # Selecting diagonal=-2 gives the same log-terms as right-shifting and then
    # selecting diagonal=-1 (the padded entries would be log(1) = 0), so no shift
    # is needed. For a 2 x 2 input this selection is an empty array.
    remaining_below = matrix_to_tril_vec(remaining, diagonal=-2)
    stick_term = 0.5 * np.sum(np.log(remaining_below), axis=-1)

    # tanh part: log(1 - tanh(x)**2) == -2 * (x + softplus(-2 x) - log 2).
    tanh_term = -2 * np.sum(x + softplus(-2 * x) - np.log(2.), axis=-1)

    return stick_term + tanh_term
def sample(self, key, sample_shape=()):
    """Draw random-walk samples.

    Standard-normal increments with shape
    ``sample_shape + batch_shape + event_shape`` are accumulated with
    ``cumsum``. The result is then multiplied by ``self.scale``, expanded with
    a new trailing axis so it broadcasts over the event dimension.
    """
    full_shape = sample_shape + self.batch_shape + self.event_shape
    increments = random.normal(key, shape=full_shape)
    path = cumsum(increments)
    scale = np.expand_dims(self.scale, axis=-1)
    return path * scale
def inv(self, y):
    """Map a stick-breaking output ``y`` back to unconstrained values.

    The last coordinate of ``y`` is dropped. The logit of each remaining
    stick-breaking fraction is recovered, and ``log(K - i)`` is added at
    position ``i``, where ``K`` is the number of returned coordinates.
    """
    head = y[..., :-1]
    # 1 - (y_0 + ... + y_i), floored at the smallest positive normal value of
    # y's dtype so the division below never divides by zero.
    remaining = np.clip(1 - cumsum(head), a_min=np.finfo(y.dtype).tiny)
    # x = logit(z) = log(z / (1 - z)) = log(y[..., :-1] / remaining)
    logits = np.log(head / remaining)
    size = logits.shape[-1]
    offset = np.log(size - np.arange(size))
    return logits + offset
def __call__(self, x):
    """Map an unconstrained vector to an increasing one along the last axis.

    The first coordinate passes through unchanged. Every later coordinate is
    exponentiated, so it becomes a positive step, and the running sum of the
    result is returned.
    """
    first = x[..., :1]
    positive_steps = np.exp(x[..., 1:])
    return cumsum(np.concatenate([first, positive_steps], axis=-1))