def _rvs(self, a, b):
    """Draw Beta(a, b) samples of shape ``self._size``.

    Uses the ratio construction X / (X + Y), where X ~ Gamma(a) and
    Y ~ Gamma(b) are drawn from independent subkeys.

    NOTE: PyTorch takes a different route and draws Beta samples through
    a Dirichlet distribution, so samples will not match PyTorch's.
    """
    subkey_a, subkey_b = random.split(self._random_state)
    size = self._size
    x = standard_gamma(subkey_a, a, shape=size)
    y = standard_gamma(subkey_b, b, shape=size)
    return x / (x + y)
def test_standard_gamma_batch():
    """vmap-ing standard_gamma over (key, concentration) pairs must agree
    with calling it once per pair."""
    base_key = random.PRNGKey(0)
    concentrations = np.array([1., 2., 3.])
    keys = random.split(base_key, 3)
    batched = vmap(lambda k, conc: standard_gamma(k, conc))(keys, concentrations)
    for k, conc, got in zip(keys, concentrations, batched):
        assert_allclose(got, standard_gamma(k, conc))
def test_standard_gamma_grad(alpha):
    """Check d(sample)/d(alpha) against the implicit reparameterization
    formula -(dF/dalpha)(z) / f(z), where F and f are the Gamma(alpha)
    CDF and PDF. dF/dalpha is approximated by a central difference."""
    key = random.PRNGKey(0)
    concentrations = np.full((100,), alpha)
    draws = standard_gamma(key, concentrations)
    grad_fn = grad(lambda conc: np.sum(standard_gamma(key, conc)))
    actual_grad = grad_fn(concentrations)
    # The step size shrinks relative to alpha to keep the difference well conditioned.
    h = 0.01 * alpha / (1.0 + np.sqrt(alpha))
    upper = osp_stats.gamma.cdf(draws, alpha + h)
    lower = osp_stats.gamma.cdf(draws, alpha - h)
    dcdf_dalpha = (upper - lower) / (2 * h)
    density = osp_stats.gamma.pdf(draws, alpha)
    assert_allclose(actual_grad, -dcdf_dalpha / density, atol=1e-8, rtol=0.0005)
def _rvs(self, df):
    """Draw Student's t samples with ``df`` degrees of freedom.

    Uses T = Z * sqrt((df / 2) / G), where Z ~ Normal(0, 1) and
    G ~ Gamma(df / 2) with unit rate. This equals Z / sqrt(V / df) with
    V = 2 * G ~ chi2(df). The construction requires Z and G to be
    independent.

    Returns an array of shape ``self._size`` (broadcast with ``df``).
    """
    # TODO: use upstream implementation when available
    key_n, key_g = random.split(self._random_state)
    normal = random.normal(key_n, shape=self._size)
    half_df = df / 2.0
    # Use the second subkey here. Reusing ``key_n`` (as before) ties the
    # gamma draw to the normal draw and breaks the required independence.
    gamma = standard_gamma(key_g, half_df, shape=self._size)
    return normal * np.sqrt(half_df / gamma)
def _rvs(self, alpha):
    """Draw Dirichlet(alpha) samples by normalizing independent Gamma draws.

    The last axis of ``alpha`` is the event dimension. The output shape is
    ``self._size + (alpha.shape[-1],)``.
    """
    event_dim = alpha.shape[-1]
    draws = standard_gamma(self._random_state, alpha, shape=self._size + (event_dim, ))
    totals = np.sum(draws, axis=-1, keepdims=True)
    return draws / totals
def _rvs(self, a):
    """Draw unit-rate Gamma(a) samples of shape ``self._size``."""
    out_shape = self._size
    return standard_gamma(self._random_state, a, shape=out_shape)
def test_standard_gamma_stats(alpha):
    """A unit-rate Gamma(alpha) has mean alpha and variance alpha, so the
    empirical moments of 1000 draws should be close to alpha."""
    draws = standard_gamma(random.PRNGKey(0), np.full((1000,), alpha))
    sample_mean = np.mean(draws)
    assert_allclose(sample_mean, alpha, rtol=0.06)
    sample_var = np.var(draws)
    assert_allclose(sample_var, alpha, rtol=0.2)
def test_standard_gamma_shape(alpha, shape):
    """The sample shape must be the broadcast of alpha's shape with ``shape``."""
    wanted = lax.broadcast_shapes(np.shape(alpha), shape)
    actual = np.shape(standard_gamma(random.PRNGKey(0), alpha, shape=shape))
    assert actual == wanted
def sample(self, key, sample_shape=()):
    """Draw Gamma(concentration, rate) samples.

    The output shape is ``sample_shape + batch_shape + event_shape``. Unit-rate
    draws are rescaled by dividing by ``self.rate``.
    """
    full_shape = sample_shape + self.batch_shape + self.event_shape
    unit_draws = standard_gamma(key, self.concentration, shape=full_shape)
    return unit_draws / self.rate
def sample(self, key, sample_shape=()):
    """Draw Dirichlet samples by normalizing Gamma draws over the last axis.

    The output shape is ``sample_shape + batch_shape + event_shape``.
    """
    full_shape = sample_shape + self.batch_shape + self.event_shape
    draws = standard_gamma(key, self.concentration, shape=full_shape)
    normalizer = np.sum(draws, axis=-1, keepdims=True)
    return draws / normalizer
def _rvs(self, df):
    """Draw Student's t samples with ``df`` degrees of freedom.

    Uses T = Z * sqrt((df / 2) / G), where Z ~ Normal(0, 1) and
    G ~ Gamma(df / 2) with unit rate. This equals Z / sqrt(V / df) with
    V = 2 * G ~ chi2(df). The construction requires Z and G to be
    independent.

    Returns an array of shape ``self._size`` (broadcast with ``df``).
    """
    key_n, key_g = random.split(self._random_state)
    normal = random.normal(key_n, shape=self._size)
    half_df = df / 2.0
    # Use the second subkey here. Reusing ``key_n`` (as before) ties the
    # gamma draw to the normal draw and breaks the required independence.
    gamma = standard_gamma(key_g, half_df, shape=self._size)
    return normal * np.sqrt(half_df / gamma)