Ejemplo n.º 1
0
    def testBeta(self, a, b, dtype):
        """Beta draws (eager and jit-compiled) must pass a KS test vs. scipy's Beta CDF."""
        def sampler(prng, alpha, beta_param):
            return random.beta(prng, alpha, beta_param, (10000, ), dtype)

        prng_key = random.PRNGKey(0)
        jitted_sampler = api.jit(sampler)

        eager_draws = sampler(prng_key, a, b)
        jitted_draws = jitted_sampler(prng_key, a, b)

        # Both code paths are checked against the same reference distribution.
        for draws in (eager_draws, jitted_draws):
            reference = scipy.stats.beta(a, b)
            self._CheckKolmogorovSmirnovCDF(draws, reference.cdf)
Ejemplo n.º 2
0
  def testBeta(self, a, b, dtype):
    """Beta draws (eager and jit-compiled) must pass a KS test vs. scipy's Beta CDF.

    Only runs when 64-bit mode is enabled.
    """
    if not config.x64_enabled:
      raise SkipTest("skip test except on X64")

    def sampler(prng, alpha, beta_param):
      return random.beta(prng, alpha, beta_param, (10000,), dtype)

    prng_key = random.PRNGKey(0)
    jitted_sampler = api.jit(sampler)

    eager_draws = sampler(prng_key, a, b)
    jitted_draws = jitted_sampler(prng_key, a, b)

    # Both code paths are checked against the same reference distribution.
    for draws in (eager_draws, jitted_draws):
      reference = scipy.stats.beta(a, b)
      self._CheckKolmogorovSmirnovCDF(draws, reference.cdf)
Ejemplo n.º 3
0
def mixup(key, alpha, image_and_label):
    """https://arxiv.org/abs/1710.09412 mixup.

    Each example is blended with its mirror in the batch (``x[::-1]``) using a
    per-example weight drawn from ``Beta(alpha, alpha)``. The same weight is
    applied to the image and to the label of that example.

    Args:
        key: PRNG key consumed by ``random.beta``.
        alpha: concentration used for both Beta shape parameters.
        image_and_label: ``(image, label)`` pair of arrays sharing the leading
            batch dimension. Any trailing rank is supported; for example, 3-D or
            4-D images and 1-D or 2-D labels.

    Returns:
        ``(mixed_image, mixed_label)`` with the same shapes as the inputs.
    """
    image, label = image_and_label
    batch_size = image.shape[0]

    weight = random.beta(key, alpha, alpha, (batch_size, 1))

    # Reshape the per-example weight to (batch_size, 1, ..., 1) so it
    # broadcasts over every non-batch axis. Hard-coding the rank broadcasts
    # across the batch axis instead (giving a (B, B, ...) result) whenever
    # the input rank differs.
    label_weight = np.reshape(weight, (batch_size,) + (1,) * (label.ndim - 1))
    mixed_label = label_weight * label + (1.0 - label_weight) * label[::-1]

    image_weight = np.reshape(weight, (batch_size,) + (1,) * (image.ndim - 1))
    mixed_image = image_weight * image + (1.0 - image_weight) * image[::-1]

    return mixed_image, mixed_label
Ejemplo n.º 4
0
def mixup(key, alpha, image_and_label):
    """Apply mixup (https://arxiv.org/abs/1710.09412) to one batch.

    Each example is blended with its mirror in the batch (``x[::-1]``) using a
    per-example weight drawn from ``Beta(alpha, alpha)``. The same weight is
    applied to the image and to the label of that example.

    Args:
        key: PRNG key consumed by ``random.beta``.
        alpha: concentration used for both Beta shape parameters.
        image_and_label: ``(image, label)`` pair of arrays sharing the leading
            batch dimension; any trailing rank is supported.

    Returns:
        ``(mixed_image, mixed_label)`` with the same shapes as the inputs.
    """
    image, label = image_and_label

    N = image.shape[0]

    weight = random.beta(key, alpha, alpha, (N, 1))

    # Broadcast the per-example weight as (N, 1, ..., 1) over all non-batch
    # axes, whatever the input rank is; a fixed rank would broadcast across
    # the batch axis and produce an (N, N, ...) result for other ranks.
    label_weight = np.reshape(weight, (N,) + (1,) * (label.ndim - 1))
    mixed_label = label_weight * label + (1.0 - label_weight) * label[::-1]

    image_weight = np.reshape(weight, (N,) + (1,) * (image.ndim - 1))
    mixed_image = image_weight * image + (1.0 - image_weight) * image[::-1]

    return mixed_image, mixed_label
Ejemplo n.º 5
0
 def body(state):
     """Propose one new candidate point, evaluate it, and advance the loop state.

     Args:
         state: tuple ``(key, i, u_test, x_test, log_L_test)``. Only ``key`` and
             ``i`` are actually read; the other three entries are overwritten.

     Returns:
         ``(key, i + 1, u_test, x_test, log_L_test)``: a fresh carried key, the
         incremented counter, the clipped candidate ``u_test``, its image under
         ``prior_transform`` and its log-likelihood.
     """
     (key, i, u_test, x_test, log_L_test) = state
     # Split off subkeys for the uniform and Beta draws; the new ``key`` is
     # returned so the next iteration does not reuse randomness.
     key, uniform_key, beta_key = random.split(key, 3)
     # [M]  (presumably a shape note for U_scale — TODO confirm)
     U_scale = random.uniform(uniform_key,
                              shape=spawn_point_U.shape,
                              minval=t_L,
                              maxval=t_R)
     # A Beta(n, 1) draw raised to the power 1/d, with
     # n = live_points_U.shape[0] and d = spawn_point_U.size.
     # NOTE(review): looks like a volume/radius shrink factor — confirm
     # against the sampler's derivation.
     t_shrink = random.beta(beta_key, live_points_U.shape[0],
                            1)**jnp.reciprocal(spawn_point_U.size)
     u_test_white = U_scale / t_shrink
     # With whitening enabled, map the whitened offset back through R, L and
     # u_mean:  u = L @ (spawn_point_U + R @ u_white) + u_mean.
     # Without whitening the offset is used directly as the candidate.
     if whiten:
         u_test = L @ (spawn_point_U + R @ u_test_white) + u_mean
     else:
         u_test = u_test_white
     # Keep the candidate inside [0, 1] in every coordinate.
     u_test = jnp.clip(u_test, 0., 1.)
     # prior_transform's result is splatted as keyword arguments, so it is
     # expected to return a mapping.
     x_test = prior_transform(u_test)
     log_L_test = loglikelihood_from_constrained(**x_test)
     return (key, i + 1, u_test, x_test, log_L_test)
Ejemplo n.º 6
0
def stochastic_result_computation(n_per_sample, key, samples, log_L_samples):
    """
    Estimate evidence and posterior summaries from one stochastic draw of the
    per-sample shrinkage factors.

    Each shrinkage factor t_i is drawn from Beta(n_per_sample[i], 1). It is set
    to exactly 1 where n_per_sample is infinite. Cumulative products of t
    give X_i. The weight of sample i is (X_{i-1} - X_i) times the average of
    consecutive likelihoods (L_{i-1} + L_i) / 2, with L_0 = 0 and X_0 = 1.
    All of this is computed in log space.

    Args:
        n_per_sample: per-sample Beta concentration, presumably the number of
            live points at each step (TODO confirm); ``inf`` entries give t = 1.
        key: JAX PRNG key for the Beta draw.
        samples: dict of arrays with one leading entry per sample; mapped
            leaf-wise with ``dict_multimap`` (defined elsewhere).
        log_L_samples: log-likelihood of each sample, same length as
            ``n_per_sample`` (presumably in sampling order — TODO confirm).

    Returns:
        Tuple ``(logZ, m, cov, ESS, H)``:
            logZ: log of the summed (unnormalized) weights.
            m: dict of weighted means, one per entry of ``samples``.
            cov: dict of weighted covariances (outer product over the last
                axis), one per entry of ``samples``.
            ESS: Kish's effective sample size of the normalized weights.
            H: sum of normalized weights times averaged log-likelihood, with
                NaN terms treated as 0.
    """
    # N  (one entry per sample)
    # jnp.where evaluates both branches; entries with n == inf are replaced by 1.
    t = jnp.where(n_per_sample == jnp.inf, 1., random.beta(key, n_per_sample, 1))
    log_t = jnp.log(t)
    # log X_i = sum_{j <= i} log t_j
    log_X = jnp.cumsum(log_t)
    # Prepend L_0 = 0 (log -inf) and X_0 = 1 (log 0) so index i-1 exists for i = 1.
    log_L_samples = jnp.concatenate([jnp.array([-jnp.inf]), log_L_samples])
    log_X = jnp.concatenate([jnp.array([0.]), log_X])
    # log_dX = log(1-t_i) + log(X[i-1])
    log_dX = jnp.log(1. - t) + log_X[:-1]  # same as log(-diff(exp(log_X))) but in log space; -inf where n_per_sample=inf
    # Average of consecutive likelihoods, (L_{i-1} + L_i) / 2, in log space.
    log_avg_L = jnp.logaddexp(log_L_samples[:-1], log_L_samples[1:]) - jnp.log(2.)
    log_p = log_dX + log_avg_L
    # param calculation: normalize the weights, then weighted moments.
    logZ = logsumexp(log_p)
    log_w = log_p - logZ
    weights = jnp.exp(log_w)
    m = dict_multimap(lambda samples: jnp.sum(left_broadcast_mul(weights, samples), axis=0), samples)
    dsamples = dict_multimap(jnp.subtract, samples, m)
    # Weighted sum of outer products of the centred samples over the last axis.
    cov = dict_multimap(lambda dsamples: jnp.sum(
        left_broadcast_mul(weights, (dsamples[..., :, None] * dsamples[..., None, :])), axis=0), dsamples)
    # Kish's ESS = [sum weights]^2 / [sum weights^2]
    ESS = jnp.exp(2. * logsumexp(log_w) - logsumexp(2. * log_w))
    # H = sum w_i log(L_avg_i); NaNs (e.g. from 0 * -inf) contribute 0.
    _H = jnp.exp(log_w) * log_avg_L
    H = jnp.sum(jnp.where(jnp.isnan(_H), 0., _H))
    return logZ, m, cov, ESS, H
Ejemplo n.º 7
0
 def sample(self, rng_key, sample_shape=()):
     """Draw Beta(self.a, self.b) variates.

     The result has shape ``sample_shape + self.batch_shape + self.event_shape``.
     """
     full_shape = sample_shape + self.batch_shape + self.event_shape
     return random.beta(rng_key, a=self.a, b=self.b, shape=full_shape)
Ejemplo n.º 8
0
def beta(a, b, size=None):
    """Draw Beta(a, b) samples using a subkey split from ``DEFAULT``.

    ``JaxArray`` arguments are unwrapped to their ``.value`` first. ``size`` is
    converted to a shape with ``_size2shape``, and the result is wrapped back
    into a ``JaxArray``.
    """
    def _unwrap(x):
        return x.value if isinstance(x, JaxArray) else x

    a, b = _unwrap(a), _unwrap(b)
    sub_key = DEFAULT.split_key()
    out_shape = _size2shape(size)
    draws = jr.beta(sub_key, a=a, b=b, shape=out_shape)
    return JaxArray(draws)
Ejemplo n.º 9
0
def weights_one_concentration(concentration, key, num_draws, num_components):
    """Stick-breaking weights built from Beta(1, concentration) fractions.

    Draws a ``(num_draws, num_components)`` array of fractions and maps
    ``stick_breaking_weights`` (defined elsewhere) over the leading axis.

    Returns:
        ``(occupied_probability, weights)`` as produced by the vmapped helper.
    """
    draw_shape = (num_draws, num_components)
    fractions = random.beta(key=key, a=1, b=concentration, shape=draw_shape)
    per_draw = vmap(stick_breaking_weights)
    occupied, stick_weights = per_draw(fractions)
    return occupied, stick_weights
Ejemplo n.º 10
0
 def step(stick_length: float, key, concentration: float):
     """Break one piece off the stick.

     A fraction is drawn from Beta(1, concentration); that fraction of
     ``stick_length`` is broken off.

     Returns:
         ``(remainder, stick)``: the length left over and the piece broken off.
     """
     broken_off = stick_length * random.beta(key, a=1, b=concentration)
     return stick_length - broken_off, broken_off
Ejemplo n.º 11
0
 def sample(self, rng_key, sample_shape=()):
     """Draw p ~ Beta(self.a, self.b), then a binomial count with ``self.n`` trials.

     Args:
         rng_key: JAX PRNG key; it is split so each stage gets its own subkey.
         sample_shape: leading shape of independent draws.

     Returns:
         The result of ``_random_binomial`` for success probability ``p`` of
         shape ``sample_shape + self.batch_shape + self.event_shape``.
     """
     shape = sample_shape + self.batch_shape + self.event_shape
     # JAX keys must not be reused: drawing both p and the binomial count from
     # the same key makes the two draws statistically dependent.
     key_beta, key_binom = random.split(rng_key)
     p = random.beta(key_beta, self.a, self.b, shape=shape)
     # Concrete Python scalar via .item() (forces a host sync and cannot be
     # traced under jit); presumably a static upper bound on trials — TODO confirm.
     n_max = jnp.max(self.n).item()
     return _random_binomial(key_binom, p, self.n, n_max, shape)