Beispiel #1
0
 def _rvs(self, a, b):
     """Draw samples as a ratio of two gamma variates.

     Computes ``G_a / (G_a + G_b)``, where ``G_a`` and ``G_b`` come from
     ``standard_gamma`` with concentrations ``a`` and ``b``. Each uses its own
     key split off ``self._random_state``, and both use ``shape=self._size``.
     If ``standard_gamma`` draws unit-rate Gamma variates (assumed here, not
     shown), the result is Beta(a, b) distributed.

     Args:
         a: concentration of the numerator gamma draw.
         b: concentration of the second gamma draw in the denominator.

     Returns:
         The elementwise ratio ``G_a / (G_a + G_b)``.
     """
     # XXX: this differs from PyTorch, which generates the sample from a
     # Dirichlet distribution instead.
     numer_key, other_key = random.split(self._random_state)
     numer = standard_gamma(numer_key, a, shape=self._size)
     other = standard_gamma(other_key, b, shape=self._size)
     return numer / (numer + other)
def test_standard_gamma_batch():
    """Check that a vmapped standard_gamma matches element-by-element calls."""
    alphas = np.array([1., 2., 3.])
    keys = random.split(random.PRNGKey(0), 3)

    batched = vmap(lambda key, concentration: standard_gamma(key, concentration))(keys, alphas)
    # Each batched sample must equal the sample drawn with the same key/alpha alone.
    for key, concentration, draw in zip(keys, alphas, batched):
        assert_allclose(draw, standard_gamma(key, concentration))
def test_standard_gamma_grad(alpha):
    """Compare grad of standard_gamma w.r.t. its concentration to a reference.

    The reference uses implicit differentiation of ``F(z; alpha) = u``:
    ``dz/dalpha = -(dF/dalpha)(z; alpha) / f(z; alpha)``. Here ``F`` and ``f``
    are the scipy Gamma CDF and PDF, and ``dF/dalpha`` is a central finite
    difference. ``alpha`` is presumably supplied by pytest parametrization
    (not visible here).
    """
    key = random.PRNGKey(0)
    concentration = np.full((100,), alpha)
    samples = standard_gamma(key, concentration)
    actual_grad = grad(lambda c: np.sum(standard_gamma(key, c)))(concentration)

    # Finite-difference step; it shrinks relative to alpha as alpha grows.
    step = 0.01 * alpha / (1.0 + np.sqrt(alpha))
    cdf_upper = osp_stats.gamma.cdf(samples, alpha + step)
    cdf_lower = osp_stats.gamma.cdf(samples, alpha - step)
    cdf_dot = (cdf_upper - cdf_lower) / (2 * step)
    density = osp_stats.gamma.pdf(samples, alpha)
    expected_grad = -cdf_dot / density

    assert_allclose(actual_grad, expected_grad, atol=1e-8, rtol=0.0005)
Beispiel #4
0
 def _rvs(self, df):
     """Draw Student's t samples as ``Z * sqrt((df / 2) / G)``.

     ``Z`` is a standard normal draw and ``G`` is a ``standard_gamma`` draw with
     concentration ``df / 2``. Both use ``shape=self._size``. If ``G`` is
     unit-rate Gamma (assumed, not shown here), ``2 * G`` is chi-square with
     ``df`` degrees of freedom, so the result equals ``Z / sqrt(chi2 / df)``.

     Args:
         df: degrees of freedom.

     Returns:
         The product ``Z * sqrt((df / 2) / G)``.
     """
     # TODO: use upstream implementation when available
     key_n, key_g = random.split(self._random_state)
     normal = random.normal(key_n, shape=self._size)
     half_df = df / 2.0
     # Use the dedicated gamma key. Reusing key_n would make the gamma draw
     # statistically dependent on the normal draw, which breaks the
     # independence the Z / sqrt(chi2 / df) construction relies on.
     gamma = standard_gamma(key_g, half_df, shape=self._size)
     return normal * np.sqrt(half_df / gamma)
Beispiel #5
0
 def _rvs(self, alpha):
     """Draw samples by normalizing gamma draws along the last axis.

     ``standard_gamma`` is called with ``shape=self._size + (K,)``, where ``K``
     is the length of ``alpha``'s last axis. ``self._size`` must therefore be
     a tuple. Each draw is divided by its sum over the last axis so that it
     sums to one. With unit-rate Gamma draws (assumed), this gives
     Dirichlet(alpha).
     """
     num_categories = alpha.shape[-1]
     draw_shape = self._size + (num_categories, )
     draws = standard_gamma(self._random_state, alpha, shape=draw_shape)
     totals = np.sum(draws, axis=-1, keepdims=True)
     return draws / totals
Beispiel #6
0
 def _rvs(self, a):
     """Draw ``standard_gamma`` samples with concentration ``a``.

     Uses ``self._random_state`` as the PRNG key and ``self._size`` as the
     requested ``shape``.
     """
     key = self._random_state
     requested_shape = self._size
     return standard_gamma(key, a, shape=requested_shape)
def test_standard_gamma_stats(alpha):
    """Check that 1000 draws have mean and variance close to ``alpha``.

    For a unit-rate Gamma(alpha), both the mean and the variance equal alpha.
    The tolerances are loose because the estimates come from a finite sample.
    """
    draws = standard_gamma(random.PRNGKey(0), np.full((1000,), alpha))
    sample_mean = np.mean(draws)
    sample_var = np.var(draws)
    assert_allclose(sample_mean, alpha, rtol=0.06)
    assert_allclose(sample_var, alpha, rtol=0.2)
def test_standard_gamma_shape(alpha, shape):
    """Check that the output shape is ``alpha``'s shape broadcast with ``shape``."""
    key = random.PRNGKey(0)
    draw = standard_gamma(key, alpha, shape=shape)
    assert np.shape(draw) == lax.broadcast_shapes(np.shape(alpha), shape)
Beispiel #9
0
 def sample(self, key, sample_shape=()):
     """Draw samples and divide them by ``self.rate``.

     ``standard_gamma`` is called with ``self.concentration`` and
     ``shape=sample_shape + self.batch_shape + self.event_shape``. The draws
     are then divided by ``self.rate``.

     Args:
         key: PRNG key passed to ``standard_gamma``.
         sample_shape: leading sample dimensions, concatenated in front of
             the batch and event shapes.
     """
     out_shape = sample_shape + self.batch_shape + self.event_shape
     unit_draws = standard_gamma(key, self.concentration, shape=out_shape)
     scaled = unit_draws / self.rate
     return scaled
Beispiel #10
0
 def sample(self, key, sample_shape=()):
     """Draw gamma samples and normalize them over the last axis.

     ``standard_gamma`` is called with ``self.concentration`` and
     ``shape=sample_shape + self.batch_shape + self.event_shape``. Each draw is
     divided by its sum over the last axis, so every vector along that axis
     sums to one.

     Args:
         key: PRNG key passed to ``standard_gamma``.
         sample_shape: leading sample dimensions, concatenated in front of
             the batch and event shapes.
     """
     out_shape = sample_shape + self.batch_shape + self.event_shape
     draws = standard_gamma(key, self.concentration, shape=out_shape)
     totals = np.sum(draws, axis=-1, keepdims=True)
     return draws / totals
Beispiel #11
0
 def _rvs(self, df):
     """Draw Student's t samples as ``Z * sqrt((df / 2) / G)``.

     ``Z`` is a standard normal draw and ``G`` is a ``standard_gamma`` draw with
     concentration ``df / 2``. Both use ``shape=self._size``. If ``G`` is
     unit-rate Gamma (assumed, not shown here), ``2 * G`` is chi-square with
     ``df`` degrees of freedom, so the result equals ``Z / sqrt(chi2 / df)``.

     Args:
         df: degrees of freedom.

     Returns:
         The product ``Z * sqrt((df / 2) / G)``.
     """
     key_n, key_g = random.split(self._random_state)
     normal = random.normal(key_n, shape=self._size)
     half_df = df / 2.0
     # Use the dedicated gamma key. Reusing key_n would make the gamma draw
     # statistically dependent on the normal draw, which breaks the
     # independence the Z / sqrt(chi2 / df) construction relies on.
     gamma = standard_gamma(key_g, half_df, shape=self._size)
     return normal * np.sqrt(half_df / gamma)