Пример #1
0
  def testGamma(self, a, dtype):
    """Compare gamma samples with the SciPy gamma CDF, eagerly and under jit.

    Draws 10000 gamma variates with concentration `a` in `dtype` twice: once
    through a plain Python sampler and once through its `api.jit`-compiled
    version. Each batch is then checked against `scipy.stats.gamma(a).cdf`
    with a Kolmogorov-Smirnov test.
    """
    def draw(prng_key, concentration):
      return random.gamma(prng_key, concentration, (10000,), dtype)

    jitted_draw = api.jit(draw)
    prng_key = random.PRNGKey(0)

    eager_batch = draw(prng_key, a)
    compiled_batch = jitted_draw(prng_key, a)

    reference_cdf = scipy.stats.gamma(a).cdf
    for batch in (eager_batch, compiled_batch):
      self._CheckKolmogorovSmirnovCDF(batch, reference_cdf)
Пример #2
0
def _gamma_jax(shape, alpha, beta=None, dtype=tf.float32, seed=None, name=None):  # pylint: disable=unused-argument
  """JAX-based reparameterized gamma sampler.

  Args:
    shape: leading sample shape. It is concatenated with
      `_bcast_shape((), [alpha, beta])` (presumably the broadcast parameter
      shape — confirm) to form the full output shape.
    alpha: concentration parameter(s).
    beta: optional rate parameter(s). When given, the unit-rate samples are
      divided by `beta`.
    dtype: dtype hint passed to `utils.common_dtype` together with
      `alpha` and `beta` to choose the output dtype.
    seed: JAX PRNGKey used for sampling. Required.
    name: unused; kept for signature compatibility.

  Returns:
    Gamma samples cast to the resolved dtype.

  Raises:
    ValueError: if `seed` is None.
  """
  dtype = utils.common_dtype([alpha, beta], dtype_hint=dtype)
  full_shape = _ensure_tuple(shape) + _bcast_shape((), [alpha, beta])
  import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
  if seed is None:
    raise ValueError('Must provide PRNGKey to sample in JAX.')
  # TODO(srvasude): Sample in the given dtype once
  # https://github.com/google/jax/issues/2130 is fixed.
  # Until then, request float64 and cast to the resolved dtype afterwards.
  unit_rate = jaxrand.gamma(
      key=seed, a=alpha, shape=full_shape, dtype=np.float64).astype(dtype)
  if beta is None:
    return unit_rate
  return unit_rate / beta
Пример #3
0
def _gamma_jax(shape,
               alpha,
               beta=None,
               dtype=tf.float32,
               seed=None,
               name=None):  # pylint: disable=unused-argument
    """Sample gamma variates with `jax.random.gamma`, directly in the target dtype.

    Args:
        shape: leading sample shape. It is concatenated with
            `_bcast_shape((), [alpha, beta])` (presumably the broadcast
            parameter shape — confirm) to form the full output shape.
        alpha: concentration parameter(s).
        beta: optional rate parameter(s). When given, the samples are
            divided by `beta`.
        dtype: dtype hint combined with `alpha`/`beta` by
            `utils.common_dtype`. The result is passed to `jax.random.gamma`.
        seed: JAX PRNGKey used for sampling. Required.
        name: unused; kept for signature compatibility.

    Returns:
        Gamma samples in the resolved dtype.

    Raises:
        ValueError: if `seed` is None.
    """
    out_dtype = utils.common_dtype([alpha, beta], dtype_hint=dtype)
    out_shape = _ensure_tuple(shape) + _bcast_shape((), [alpha, beta])
    import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
    if seed is None:
        raise ValueError('Must provide PRNGKey to sample in JAX.')
    draws = jaxrand.gamma(key=seed, a=alpha, shape=out_shape, dtype=out_dtype)
    if beta is not None:
        draws = draws / beta
    return draws
Пример #4
0
def test_studentst(shape=(1000, ), loc=0.0, scale=5.0, dof=5.0, seed=0):
    """A fitted Student-t should beat a fitted Normal on Student-t data.

    Data are drawn from a Student-t with ``dof`` degrees of freedom, location
    ``loc`` and scale ``scale``, using its Gaussian scale-mixture form::

        tau ~ Gamma(dof / 2, rate=dof / 2)
        x   = loc + scale * z / sqrt(tau),   z ~ Normal(0, 1)

    The test then checks two things:

    * consecutive entries of the ``lps`` sequence returned by
      ``dists.StudentT.fit`` never decrease by more than ``1e-3``
      (presumably a per-iteration log-probability trace — confirm);
    * the fitted Student-t has a higher mean log-probability on the data than
      the fitted Normal.

    Args:
        shape: shape of the generated data.
        loc: true location of the generating Student-t.
        scale: true scale of the generating Student-t.
        dof: true degrees of freedom of the generating Student-t.
        seed: integer seed for the PRNG key. It is fixed by default so that
            failures are reproducible.
    """
    key = jr.PRNGKey(seed)
    key_normal, key_gamma = jr.split(key, 2)
    zs = jr.normal(key_normal, shape=shape)
    # jr.gamma draws unit-rate variates, so dividing by the rate dof/2 gives
    # Gamma(dof/2, rate=dof/2). That tau has mean 1, which makes `scale` the
    # true scale of the resulting Student-t. The previous code divided by the
    # *scale* 2/dof instead of the rate.
    concentration = rate = dof / 2.0
    taus = jr.gamma(key_gamma, concentration, shape=shape) / rate
    data = loc + zs * scale / np.sqrt(taus)
    norm = dists.Normal.fit(data)
    stdt, lps = dists.StudentT.fit(data)
    assert np.all(np.diff(lps) > -1e-3)
    assert stdt.log_prob(data).mean() > norm.log_prob(data).mean()
Пример #5
0
    def sample(self, key, sample_shape=()):
        """Draw samples through a gamma / truncated-normal mixture.

        As implemented:

        * ``w = Gamma(df / 2, 1) / (df / 2)``, with shape ``sample_shape``;
        * ``z ~ TruncatedNormal(loc=0, scale=sqrt(1 / w), low=0)``;
        * ``delta = skew / sqrt(1 + skew ** 2)``;
        * the result is ``eps * scale * sqrt(1 - delta ** 2)
          + loc + scale * z * delta``, with ``eps ~ Normal(0, 1)`` of shape
          ``sample_shape``.

        NOTE(review): only ``z`` is scaled by ``sqrt(1 / w)``; the ``eps``
        term is not. If this is meant to be a skew Student-t, the Gaussian
        term probably needs the same mixing — confirm.

        NOTE(review): ``w`` and ``eps`` are drawn with ``sample_shape`` alone,
        so parameters with a non-empty batch shape may not broadcast —
        confirm.
        """
        key_gamma, key_tn, key_normal = random.split(key, 3)

        half_df = self.df / 2
        mixing = random.gamma(key_gamma, half_df, sample_shape) / half_df
        truncated = dist.TruncatedNormal(
            loc=0., scale=jnp.sqrt(1 / mixing), low=0.0)
        z = truncated.sample(key_tn)

        delta = self.skew / jnp.sqrt(1 + self.skew ** 2)
        shift = self.loc + self.scale * z * delta
        spread = self.scale * jnp.sqrt(1 - delta ** 2)

        noise = random.normal(key_normal, sample_shape)
        return noise * spread + shift
Пример #6
0
def _gamma_jax(shape, alpha, beta=None, dtype=np.float32, seed=None, name=None):  # pylint: disable=unused-argument
  """JAX-based reparameterized gamma sampler.

  Args:
    shape: output shape, normalised by `_ensure_shape_tuple`. `alpha` is
      presumably broadcast to it by `jax.random.gamma` — confirm.
    alpha: concentration parameter(s); converted to an array of the resolved
      dtype.
    beta: optional rate parameter(s); converted like `alpha`. When given, the
      samples are divided by it.
    dtype: dtype hint passed to `utils.common_dtype`.
    seed: JAX PRNGKey used for sampling. Required.
    name: unused; kept for signature compatibility.

  Returns:
    Gamma samples in the resolved dtype, clamped from below at
    `np.finfo(dtype).tiny`.

  Raises:
    ValueError: if `seed` is None.
  """
  dtype = utils.common_dtype([alpha, beta], dtype_hint=dtype)
  concentration = np.array(alpha, dtype=dtype)
  rate = None if beta is None else np.array(beta, dtype=dtype)
  sample_shape = _ensure_shape_tuple(shape)
  import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
  if seed is None:
    raise ValueError('Must provide PRNGKey to sample in JAX.')
  # TODO(srvasude): Sample in the given dtype once
  # https://github.com/google/jax/issues/2130 is fixed.
  draws = jaxrand.gamma(
      key=seed, a=concentration, shape=sample_shape,
      dtype=np.float64).astype(dtype)
  if rate is not None:
    draws = draws / rate
  # Match the 0->tiny behavior of tf.random.gamma.
  return np.maximum(np.finfo(dtype).tiny, draws)
Пример #7
0
def gen_samples_A_ii_0(key, samples_nb, T, S_i_C_hat_i):
    """
    Draw samples of ``A_ii_0`` and ``A_ii_0^T A_ii_0`` for a 1x1 block.

    ``A_ii_0^T A_ii_0`` is drawn from a Gamma distribution with shape
    ``T / 2 + 1`` and scale ``2 / S_i_C_hat_i``. The draw is obtained by
    multiplying a unit-scale Gamma draw by the scale (see
    https://en.wikipedia.org/wiki/Gamma_distribution#Scaling). Because the
    block is 1x1, ``A_ii_0`` is the square root of that draw.

    NOTE(review): presumably this is the scalar special case of a
    Wishart-type posterior draw for the diagonal block ``A_ii``. The general
    ``m_i > 1`` case is not implemented — confirm against the model.

    Parameters
    ----------
    key : jax PRNG key
        Key passed to ``random.gamma``.
    samples_nb : int
        Number of samples, i.e. the length of the leading output axis.
    T : int or float
        Enters the Gamma shape parameter as ``T / 2 + 1``.
    S_i_C_hat_i : array
        Array whose first dimension is the block size ``m_i``; only
        ``m_i == 1`` is supported. Enters the Gamma scale parameter as
        ``2 / S_i_C_hat_i``.

    Returns
    -------
    A_ii_0 : array
        Elementwise square root of ``A_ii_0_T_A_ii_0``.
    A_ii_0_T_A_ii_0 : array
        Gamma draws of shape ``(samples_nb, 1, 1)``, broadcast against
        ``2 / S_i_C_hat_i``.

    Raises
    ------
    NotImplementedError
        If ``m_i != 1``.
    """
    m_i = S_i_C_hat_i.shape[0]

    # Only the scalar block has a closed-form sampler here.
    if m_i != 1:
        raise NotImplementedError(
            f"gen_samples_A_ii_0 only supports m_i == 1, got m_i={m_i}")

    shape_param = T / 2 + 1
    scale_param = 2 / S_i_C_hat_i
    size = (samples_nb, m_i, m_i)

    # Gamma scaling: c * Gamma(k, 1) ~ Gamma(k, scale=c).
    A_ii_0_T_A_ii_0 = scale_param * random.gamma(key, shape_param, shape=size)
    A_ii_0 = np.sqrt(A_ii_0_T_A_ii_0)  # valid because the block is 1x1

    return A_ii_0, A_ii_0_T_A_ii_0
Пример #8
0
def test_standard_gamma_batch():
    """vmap-ed ``random.gamma`` must agree with per-element calls.

    Three batching patterns are covered:

    * keys and concentrations batched together;
    * keys batched with a fixed length-2 concentration vector;
    * concentrations batched with a fixed key.
    """
    base_key = random.PRNGKey(0)
    concentrations = np.array([1., 2., 3.])
    keys = random.split(base_key, 3)

    # Batch over keys and concentrations together.
    batched = vmap(lambda k, c: random.gamma(k, c))(keys, concentrations)
    for row, k, c in zip(batched, keys, concentrations):
        assert_allclose(row, random.gamma(k, c))

    # Batch over keys only; every call sees the same two concentrations.
    fixed_pair = concentrations[:2]
    batched = vmap(lambda k: random.gamma(k, fixed_pair))(keys)
    for row, k in zip(batched, keys):
        assert_allclose(row, random.gamma(k, fixed_pair))

    # Batch over concentrations only; every call reuses the base key.
    batched = vmap(lambda c: random.gamma(base_key, c))(concentrations)
    for row, c in zip(batched, concentrations):
        assert_allclose(row, random.gamma(base_key, c))
Пример #9
0
def test_standard_gamma_stats(alpha):
    rng = random.PRNGKey(0)
    z = random.gamma(rng, np.full((1000, ), alpha))
    assert_allclose(np.mean(z), alpha, rtol=0.06)
    assert_allclose(np.var(z), alpha, rtol=0.2)
Пример #10
0
def test_standard_gamma_shape(alpha, shape):
    rng = random.PRNGKey(0)
    expected_shape = lax.broadcast_shapes(np.shape(alpha), shape)
    assert np.shape(random.gamma(rng, alpha, shape=shape)) == expected_shape
Пример #11
0
def standard_gamma(shape, size=None):
    """Draw unit-scale Gamma samples and wrap them in a ``JaxArray``.

    ``shape`` is the concentration passed to ``jr.gamma``. ``size`` is
    converted to an output shape by ``_size2shape`` (presumably mirroring
    ``numpy.random.standard_gamma`` — confirm). A fresh key is taken from
    ``DEFAULT.split_key()`` on every call.
    """
    key = DEFAULT.split_key()
    draws = jr.gamma(key, a=shape, shape=_size2shape(size))
    return JaxArray(draws)
Пример #12
0
 def sample(self, key, sample_shape=()):
     """Draw Gamma(concentration, rate) samples.

     The output shape is ``sample_shape + batch_shape + event_shape``.
     ``random.gamma`` yields unit-rate variates; dividing by ``self.rate``
     applies the rate.
     """
     out_shape = sample_shape + self.batch_shape + self.event_shape
     unit_rate = random.gamma(key, self.concentration, shape=out_shape)
     return unit_rate / self.rate
def sample_from_prior(key, num=100):
    """Draw ``num`` weight vectors and ``num`` alpha values from the prior.

    ``alpha`` is ``Gamma(a0, 1) / b0`` (i.e. rate ``b0``), with shape
    ``(num,)``; presumably a precision — confirm. The weights are standard
    normal with shape ``(num, num_features)``. ``a0``, ``b0`` and
    ``num_features`` are read from module scope.

    Returns:
        ``(weights, log(alpha))``.
    """
    alpha_key, weight_key = random.split(key)
    alpha_draws = random.gamma(alpha_key, a0, shape=(num, )) / b0
    weights = random.normal(weight_key, shape=(num, num_features))
    log_alpha = np.log(alpha_draws)
    return weights, log_alpha
Пример #14
0
 def _rvs(self, alpha):
     K = alpha.shape[-1]
     gamma_samples = random.gamma(self._random_state,
                                  alpha,
                                  shape=self._size + (K, ))
     return gamma_samples / jnp.sum(gamma_samples, axis=-1, keepdims=True)
Пример #15
0
 def _rvs(self, a):
     return random.gamma(self._random_state, a, shape=self._size)
Пример #16
0
 def f(x):
     """Return a Gamma(x) draw made with the fixed key ``PRNGKey(0)``."""
     fixed_key = random.PRNGKey(0)
     draw = random.gamma(fixed_key, x)
     return draw
Пример #17
0
 def sample(self, rng_key, sample_shape=()):
     """Draw Gamma samples with location ``self.loc`` and scale ``self.scale``.

     Assumes the SciPy-style parameterisation ``X = loc + scale * G`` with
     ``G ~ Gamma(a, 1)`` — confirm against the class definition. The output
     shape is ``sample_shape + batch_shape + event_shape``.
     """
     shape = sample_shape + self.batch_shape + self.event_shape
     # jax.random.gamma(key, a, shape, dtype) has no loc/scale arguments.
     # The previous positional call passed `loc` as `shape` and `scale` as
     # `dtype`, so the affine transform is applied explicitly here.
     standard = random.gamma(rng_key, self.a, shape=shape)
     return self.loc + self.scale * standard
Пример #18
0
 def sample(self, key, sample_shape=()):
     """Draw Dirichlet samples by normalising independent Gamma draws.

     Gamma draws of shape ``sample_shape + batch_shape + event_shape`` are
     divided by their sum over the last axis, so each vector sums to one.
     """
     out_shape = sample_shape + self.batch_shape + self.event_shape
     draws = random.gamma(key, self.concentration, shape=out_shape)
     normaliser = np.sum(draws, axis=-1, keepdims=True)
     return draws / normaliser
Пример #19
0
def gamma(shape, scale=1.0, size=None):
    """Draw Gamma samples with the given shape (concentration) and scale.

    ``jr.gamma`` produces unit-scale variates. Multiplying by ``scale`` gives
    ``Gamma(shape, scale)`` by the Gamma scaling property; with the default
    ``scale=1.0`` the samples are unchanged.

    Args:
      shape: concentration parameter(s).
      scale: scale parameter(s), broadcast against the samples.
      size: output size, converted to a JAX shape by ``_size2shape``.

    Returns:
      A ``JaxArray`` of samples. A fresh key is taken from
      ``DEFAULT.split_key()`` on every call.
    """
    # Previously `assert scale == 1.` rejected other scales, and it was
    # stripped under `python -O`, which silently returned unit-scale samples.
    # NOTE(review): if `scale` may itself be a JaxArray, confirm that
    # `JaxArray(...)` accepts the wrapped product.
    samples = jr.gamma(DEFAULT.split_key(), a=shape, shape=_size2shape(size))
    return JaxArray(samples * scale)
Пример #20
0
 def testGammaShape(self):
     key = random.PRNGKey(0)
     x = random.gamma(key, onp.array([0.2, 0.3]), shape=(3, 2))
     assert x.shape == (3, 2)
Пример #21
0
 def sample(self, rng_key, sample_shape=()):
     """Draw Inverse-Gamma(a, scale=b) samples.

     If ``X ~ Gamma(a, scale=1/b)`` then ``1/X ~ Inverse-Gamma(a, scale=b)``.
     Equivalently, the result is ``b / G`` with ``G ~ Gamma(a, 1)``, which is
     what is computed here. The output shape is
     ``sample_shape + batch_shape + event_shape``.
     """
     out_shape = sample_shape + self.batch_shape + self.event_shape
     unit_gamma = random.gamma(rng_key, self.a, out_shape)
     return self.b / unit_gamma