def test_cumprod_jac(shape):
    """Check ``cumprod`` (values and Jacobian) against an explicit product chain.

    The reference implementation multiplies out the running products of the
    first three entries along the last axis by hand, so ``shape`` is expected
    to have a last dimension of 3 (presumably supplied by a pytest
    parametrization outside this view — confirm).
    """
    key = random.PRNGKey(0)
    x = random.uniform(key, shape=shape)

    def expected_cumprod(v):
        # Running products along the last axis, built step by step.
        first = v[..., 0]
        second = first * v[..., 1]
        third = second * v[..., 2]
        return np.stack([first, second, third], -1)

    assert_allclose(cumprod(x), expected_cumprod(x))
    # Gradients must agree too, not just the forward values.
    assert_allclose(jacobian(cumprod)(x), jacobian(expected_cumprod)(x), atol=1e-7)
Example #2
0
 def __call__(self, x):
     """Map ``x`` onto the probability simplex by stick breaking.

     The last axis of ``x`` holds K-1 values. Each value decides what fraction
     of the remaining stick is broken off. The result has K entries along the
     last axis: the final piece takes whatever is left, so the entries sum to 1.

     Args:
         x: array whose last axis has K-1 entries.

     Returns:
         Array with the same leading axes as ``x`` and K entries on the last axis.
     """
     num_breaks = x.shape[-1]
     # Subtract log(number of pieces still to place) so that an all-zero input
     # maps exactly to the uniform point (1/K, 1/K, ..., 1/K).
     offsets = np.log(num_breaks - np.arange(num_breaks))
     # Fraction of the *remaining* stick taken at each break.
     # _clipped_expit is defined elsewhere (presumably a clipped sigmoid — confirm).
     fractions = _clipped_expit(x - offsets)
     # Stick length left over after each successive break.
     remaining = cumprod(1 - fractions)
     # Only the last axis is padded; every leading axis keeps a (0, 0) width.
     leading = [(0, 0)] * (x.ndim - 1)
     # Append a fraction of 1 so the last piece consumes the rest of the stick.
     fractions_ext = np.pad(fractions, leading + [(0, 1)], mode="constant", constant_values=1.)
     # Prepend a full stick (length 1) so piece i is scaled by what remained before it.
     remaining_ext = np.pad(remaining, leading + [(1, 0)], mode="constant", constant_values=1.)
     return fractions_ext * remaining_ext