Example #1
0
 def sample(self, key, sample_shape=()):
     """Draw samples by inverse-CDF sampling below ``self._cdf_at_high``.

     A uniform variate in ``[tiny, 1)`` (``tiny`` being the smallest positive
     normal number of the default float dtype) is scaled by
     ``self._cdf_at_high`` and mapped through ``self.base_dist.icdf``. Keeping
     the variate away from 0 avoids evaluating ``icdf`` at exactly 0.
     """
     assert is_prng_key(key)
     smallest_positive = jnp.finfo(jnp.result_type(float)).tiny
     draw_shape = sample_shape + self.batch_shape
     unit = random.uniform(key, shape=draw_shape, minval=smallest_positive)
     scaled_prob = unit * self._cdf_at_high
     return self.base_dist.icdf(scaled_prob)
Example #2
0
 def sample(self, key, sample_shape=()):
     """Draw samples by inverse-CDF sampling with reflection about ``loc``.

     The probability fed to ``self.base_dist.icdf`` interpolates between
     ``self._tail_prob_at_low`` and ``self._tail_prob_at_high``. When
     ``loc < self.low`` the result is reflected about ``loc``
     (``2 * loc - icdf(...)``), which keeps ``icdf`` in the more accurate tail
     for a symmetric base distribution.

     The uniform variate is bounded below by the smallest positive normal float
     of the default float dtype. This ensures ``icdf`` is never evaluated at
     exactly 0 (which gives ``-inf``) when the tail probability at ``low``
     underflows to zero and ``u`` happens to be drawn as 0.
     """
     assert is_prng_key(key)
     dtype = jnp.result_type(float)
     minval = jnp.finfo(dtype).tiny
     u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
     loc = self.base_dist.loc
     # +1: sample directly (low <= loc); -1: sample the reflected problem
     sign = jnp.where(loc >= self.low, 1.0, -1.0)
     return (1 - sign) * loc + sign * self.base_dist.icdf(
         (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
     )
Example #3
0
 def sample(self, key, sample_shape=()):
     """Draw samples by compounding a beta draw with a binomial draw.

     Success probabilities are sampled from ``self._beta`` with the requested
     ``sample_shape``. A ``BinomialProbs`` with ``self.total_count`` trials is
     then sampled once at those probabilities.
     """
     assert is_prng_key(key)
     beta_key, binomial_key = random.split(key)
     sampled_probs = self._beta.sample(beta_key, sample_shape)
     compound = BinomialProbs(total_count=self.total_count, probs=sampled_probs)
     return compound.sample(binomial_key)
Example #4
0
 def sample(self, key, sample_shape=()):
     """Draw zero-inflated Poisson samples.

     An entry is 0 wherever a Bernoulli(``self.gate``) draw is true. Otherwise
     it is a Poisson draw with rate ``self.rate``. Both draws have shape
     ``sample_shape + self.batch_shape``.
     """
     assert is_prng_key(key)
     gate_key, count_key = random.split(key)
     out_shape = sample_shape + self.batch_shape
     is_zero = random.bernoulli(gate_key, self.gate, out_shape)
     rate_on_device = device_put(self.rate)
     counts = random.poisson(count_key, rate_on_device, out_shape)
     return jnp.where(is_zero, 0, counts)
Example #5
0
 def sample(self, key, sample_shape=()):
     """Draw zero-inflated samples from ``self.base_dist``.

     Positions where a Bernoulli(``self.gate``) draw (shape
     ``sample_shape + self.batch_shape``) is true are set to 0. All other
     positions take the base distribution's draw, obtained by calling it with
     ``rng_key`` and ``sample_shape``.
     """
     assert is_prng_key(key)
     gate_key, base_key = random.split(key)
     mask_shape = sample_shape + self.batch_shape
     zero_mask = random.bernoulli(gate_key, self.gate, mask_shape)
     base_draws = self.base_dist(rng_key=base_key, sample_shape=sample_shape)
     return jnp.where(zero_mask, 0, base_draws)
Example #6
0
 def sample(self, key, sample_shape=()):
     """Draw geometric samples by inverting the CDF.

     With ``u ~ Uniform[0, 1)`` drawn in the dtype of ``self.probs``, each
     sample is ``floor(log(1 - u) / log(1 - probs))``.
     """
     assert is_prng_key(key)
     p = self.probs
     draw_dtype = jnp.result_type(p)
     u = random.uniform(key, sample_shape + self.batch_shape, draw_dtype)
     log_fail = jnp.log1p(-p)
     return jnp.floor(jnp.log1p(-u) / log_fail)
Example #7
0
 def sample(self, key, sample_shape=()):
     """Draw geometric samples parameterised by logits via inverse CDF.

     For ``p = sigmoid(logits)``, ``log(1 - p) == -softplus(logits)``. Each
     sample is therefore ``floor(log(1 - u) / -softplus(logits))``, where
     ``u ~ Uniform[0, 1)`` is drawn in the logits' dtype.
     """
     assert is_prng_key(key)
     lg = self.logits
     draw_dtype = jnp.result_type(lg)
     u = random.uniform(key, sample_shape + self.batch_shape, draw_dtype)
     log_fail = -softplus(lg)
     return jnp.floor(jnp.log1p(-u) / log_fail)
Example #8
0
    def sample_with_intermediates(self, key, sample_shape=()):
        """
        Same as ``sample`` except that the sampled mixture components are also returned.

        One sample is drawn from every mixture component. Then, for each
        ``sample_shape + batch_shape`` position, the component chosen by the
        mixing distribution is kept.

        :param jax.random.PRNGKey key: the random key to be used for the distribution.
        :param tuple sample_shape: the sample shape for the distribution.
        :return: Tuple (samples, indices)
        :rtype: tuple
        """
        assert is_prng_key(key)
        key_comp, key_ind = jax.random.split(key)
        outer_shape = sample_shape + self.batch_shape
        # One draw per component:
        #  (*sample_shape, *batch_shape, mixture_size, *event_shape)
        per_component = self.component_distribution.expand(
            outer_shape + (self.mixture_size,)
        ).sample(key_comp)
        # Which component to keep at each position:
        #  (*sample_shape, *batch_shape)
        indices = self.mixing_distribution.expand(outer_shape).sample(key_ind)
        # Append length-1 axes for the mixture axis and every event axis so the
        # indices line up with ``per_component`` inside take_along_axis.
        index_shape = indices.shape + (1,) * (self.event_dim + 1)
        chosen = jnp.take_along_axis(
            per_component,
            indices=indices.reshape(index_shape),
            axis=self.mixture_dim,
        )
        # Drop the now length-1 mixture axis:
        #  (*sample_shape, *batch_shape, *event_shape)
        return jnp.squeeze(chosen, axis=self.mixture_dim), indices
Example #9
0
 def sample(self, key, sample_shape=()):
     """Draw samples by compounding a Dirichlet draw with a multinomial draw.

     Category probabilities are sampled from ``self._dirichlet`` with the
     requested ``sample_shape``. A ``MultinomialProbs`` with
     ``self.total_count`` trials is then sampled once at those probabilities.
     """
     assert is_prng_key(key)
     dirichlet_key, multinomial_key = random.split(key)
     sampled_probs = self._dirichlet.sample(dirichlet_key, sample_shape)
     compound = MultinomialProbs(total_count=self.total_count, probs=sampled_probs)
     return compound.sample(multinomial_key)
Example #10
0
 def sample(self, key, sample_shape=()):
     """Draw samples from a truncated series representation.

     Each sample is ``sum_k g_k / (k - 1/2)**2 / (2 * pi**2)`` over
     ``k = 1..self.num_gamma_variates``, where the ``g_k`` are gamma draws with
     concentration 1. The result is capped above at ``self.truncation_point``.

     The returned array has shape ``sample_shape + self.batch_shape``, the same
     layout as the other ``sample`` methods. (The previous code built
     ``batch_shape + sample_shape``.)
     """
     assert is_prng_key(key)
     # (k - 1/2)**2 for k = 1..num_gamma_variates
     denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
     # Sample axes first, then batch axes, then the series axis. Every entry is
     # i.i.d., so this order only affects the output layout.
     x = random.gamma(
         key, jnp.ones(sample_shape + self.batch_shape + (self.num_gamma_variates,))
     )
     x = jnp.sum(x / denom, axis=-1)
     # Upper bound only. jnp.minimum is what jnp.clip(a_max=...) computes, and it
     # avoids the a_max/max keyword deprecation across JAX versions.
     return jnp.minimum(x * (0.5 / jnp.pi ** 2), self.truncation_point)
Example #11
0
 def sample(self, key, sample_shape=()):
     """Draw samples by inverse-CDF sampling with reflection about ``loc``.

     ``u`` is drawn from ``[tiny, 1)`` so ``icdf`` never sees an exact 0 from a
     zero draw. The probability passed to ``self.base_dist.icdf`` interpolates
     between ``self._tail_prob_at_low`` and ``self._tail_prob_at_high``. When
     ``loc < self.low`` the result is reflected about ``loc``.
     """
     assert is_prng_key(key)
     tiny = jnp.finfo(jnp.result_type(float)).tiny
     u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=tiny)
     center = self.base_dist.loc
     flip = jnp.where(center >= self.low, 1.0, -1.0)
     mixed_prob = (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
     return (1 - flip) * center + flip * self.base_dist.icdf(mixed_prob)
Example #12
0
    def sample(self, key, sample_shape=()):
        """Generate samples from the von Mises distribution.

        Draws from a zero-centred von Mises with ``self.concentration``, shifts
        the draws by ``self.loc``, and wraps them back into ``[-pi, pi)``.

        :param key: random number generator key
        :param sample_shape: shape of samples
        :return: samples from von Mises
        """
        assert is_prng_key(key)
        draw_shape = sample_shape + self.shape()
        centered = von_mises_centered(key, self.concentration, draw_shape)
        shifted = centered + self.loc  # VM(0, concentration) -> VM(loc, concentration)
        full_turn = 2. * jnp.pi
        return (shifted + jnp.pi) % full_turn - jnp.pi
Example #13
0
    def sample(self, key, sample_shape=()):
        """
        Draw ``(phi, psi)`` angle pairs from the sine bivariate von Mises distribution.

        ``phi`` is drawn via ``SineBivariateVonMises._phi_marginal``. ``psi`` is
        then drawn from a von Mises whose location and concentration depend on
        ``phi``. Both angles are shifted by their location parameters and
        wrapped into ``[-pi, pi)`` before being reshaped to
        ``sample_shape + batch_shape + event_shape``.

        :param key: random number generator key
        :param sample_shape: shape of samples
        :return: array of shape ``(*sample_shape, *batch_shape, *event_shape)``

        ** References: **
            1. A New Unified Approach for the Simulation of a Wide Class of Directional Distributions
               John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018)
        """
        assert is_prng_key(key)
        # Independent keys for the phi draw and the conditional psi draw.
        phi_key, psi_key = random.split(key)

        corr = self.correlation
        # conc[0] is the phi concentration, conc[1] the psi concentration.
        conc = jnp.stack((self.phi_concentration, self.psi_concentration))

        # Two-entry vector (0, 0.5 * (phi_conc - corr**2 / psi_conc)).
        eig = 0.5 * (conc[0] - corr**2 / conc[1])
        eig = jnp.stack((jnp.zeros_like(eig), eig))
        # eigmin is eig[1] when it is negative, else 0; subtracting it makes both
        # entries of eig non-negative.
        eigmin = jnp.where(eig[1] < 0, eig[1],
                           jnp.zeros_like(eig[1], dtype=eig.dtype))
        eig = eig - eigmin
        # NOTE(review): _bfind presumably solves for the envelope parameter of the
        # reference-1 rejection sampler used by _phi_marginal — confirm.
        b0 = self._bfind(eig)

        # presumably _numel returns the element count of a shape, so sample and
        # batch dimensions are each flattened to a single axis below — verify.
        total = _numel(sample_shape)
        # presumably the log of the order-0 modified Bessel function at the psi
        # concentration (normalizer of the phi marginal) — confirm in log_I1.
        phi_den = log_I1(0, conc[1]).squeeze(0)
        batch_size = _numel(self.batch_shape)
        phi_shape = (total, 2, batch_size)
        phi_state = SineBivariateVonMises._phi_marginal(
            phi_shape,
            phi_key,
            jnp.reshape(conc, (2, batch_size)),
            jnp.reshape(corr, (batch_size, )),
            jnp.reshape(eig, (2, batch_size)),
            jnp.reshape(b0, (batch_size, )),
            jnp.reshape(eigmin, (batch_size, )),
            jnp.reshape(phi_den, (batch_size, )),
        )

        # Axis 1 of phi_state.phi is treated as (x, y) coordinates of a point on
        # the circle; arctan2(y, x) turns it into an angle, keeping a length-1
        # axis 1 (assumes phi_state.phi has shape phi_shape — TODO confirm).
        phi = jnp.arctan2(phi_state.phi[:, 1:], phi_state.phi[:, :1])

        # Concentration (alpha) and location (beta) of psi conditioned on phi.
        # NOTE(review): these combine batch-shaped parameters with the
        # (total, 1, batch_size) phi array; the broadcasting presumably only lines
        # up for batch_shape of rank <= 1 — verify for multi-dim batches.
        alpha = jnp.sqrt(conc[1]**2 + (corr * jnp.sin(phi))**2)
        beta = jnp.arctan(corr / conc[1] * jnp.sin(phi))

        psi = VonMises(beta, alpha).sample(psi_key)

        # Shift by the location parameters, wrap into [-pi, pi), and stack phi and
        # psi along axis 1.
        phi_psi = jnp.concatenate(
            (
                (phi + self.phi_loc + pi) % (2 * pi) - pi,
                (psi + self.psi_loc + pi) % (2 * pi) - pi,
            ),
            axis=1,
        )
        # Move the (phi, psi) axis last before restoring the original sample and
        # batch dimensions (presumably event_shape == (2,)).
        phi_psi = jnp.transpose(phi_psi, (0, 2, 1))
        return phi_psi.reshape(*sample_shape, *self.batch_shape,
                               *self.event_shape)
Example #14
0
    def sample(self, key, sample_shape=()):
        """Draw samples by inverse-CDF sampling with reflection about ``loc``.

        The probability passed to ``self.base_dist.icdf`` interpolates between
        ``self._tail_prob_at_low`` and ``self._tail_prob_at_high``. When
        ``loc < self.low`` the result is reflected about ``loc``.

        ``u`` is bounded below by the smallest positive normal float of the
        default float dtype. This ensures ``icdf`` is never evaluated at exactly
        0 (which gives ``-inf``) when the tail probability at ``low`` underflows
        to zero and ``u`` happens to be drawn as 0.
        """
        assert is_prng_key(key)
        dtype = jnp.result_type(float)
        minval = jnp.finfo(dtype).tiny
        u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)

        # NB: we use a more numerically stable formula for a symmetric base distribution.
        #   A = icdf(cdf(low) + (cdf(high) - cdf(low)) * u) = icdf[(1 - u) * cdf(low) + u * cdf(high)]
        # suffers from precision issues when low is large;
        # If low <= loc:
        #   A = icdf[(1 - u) * cdf(low) + u * cdf(high)]
        # Else
        #   A = 2 * loc - icdf[(1 - u) * cdf(2*loc - low) + u * cdf(2*loc - high)]
        loc = self.base_dist.loc
        sign = jnp.where(loc >= self.low, 1.0, -1.0)
        return (1 - sign) * loc + sign * self.base_dist.icdf(
            (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high)
Example #15
0
 def sample(self, key, sample_shape=()):
     """Draw Bernoulli samples with success probability ``self.probs``.

     The boolean draws are cast to the dtype obtained by promoting their dtype
     with Python ``int``.
     """
     assert is_prng_key(key)
     p = self.probs
     draw_shape = sample_shape + self.batch_shape
     flags = random.bernoulli(key, p, shape=draw_shape)
     target_dtype = jnp.result_type(flags, int)
     return flags.astype(target_dtype)
Example #16
0
 def sample(self, key, sample_shape=()):
     """Draw samples by inverse-CDF sampling below ``self._cdf_at_high``.

     ``u`` is drawn from ``[tiny, 1)`` rather than ``[0, 1)``, where ``tiny`` is
     the smallest positive normal float of the default float dtype. With
     ``u == 0`` the base ``icdf`` would be evaluated at exactly 0, which gives
     ``-inf`` for bases with unbounded lower support (e.g. a Normal).
     """
     assert is_prng_key(key)
     dtype = jnp.result_type(float)
     minval = jnp.finfo(dtype).tiny
     u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
     return self.base_dist.icdf(u * self._cdf_at_high)
Example #17
0
 def sample(self, key, sample_shape=()):
     """Draw Poisson counts with rate ``self.rate``.

     The result has shape ``sample_shape + self.batch_shape``.
     """
     assert is_prng_key(key)
     lam = self.rate
     out_shape = sample_shape + self.batch_shape
     return random.poisson(key, lam, shape=out_shape)
Example #18
0
 def sample(self, key, sample_shape=()):
     """Draw multinomial counts via the ``multinomial`` helper.

     Uses category probabilities ``self.probs`` and ``self.total_count`` trials,
     with ``shape=sample_shape + self.batch_shape``.
     """
     assert is_prng_key(key)
     p, n_trials = self.probs, self.total_count
     out_shape = sample_shape + self.batch_shape
     return multinomial(key, p, n_trials, shape=out_shape)
Example #19
0
 def sample(self, key, sample_shape=()):
     """Draw category indices from ``self.logits`` with ``random.categorical``.

     The result has shape ``sample_shape + self.batch_shape``.
     """
     assert is_prng_key(key)
     lg = self.logits
     out_shape = sample_shape + self.batch_shape
     return random.categorical(key, lg, shape=out_shape)
Example #20
0
 def sample(self, key, sample_shape=()):
     """Draw category indices from ``self.probs`` via the ``categorical`` helper.

     The helper is called with ``shape=sample_shape + self.batch_shape``.
     """
     assert is_prng_key(key)
     p = self.probs
     out_shape = sample_shape + self.batch_shape
     return categorical(key, p, shape=out_shape)
Example #21
0
 def sample(self, key, sample_shape=()):
     """Draw gamma-Poisson mixture samples.

     A rate is sampled from ``self._gamma`` with the requested
     ``sample_shape``. A Poisson count is then drawn at each sampled rate.
     """
     assert is_prng_key(key)
     gamma_key, poisson_key = random.split(key)
     sampled_rate = self._gamma.sample(gamma_key, sample_shape)
     return Poisson(sampled_rate).sample(poisson_key)
Example #22
0
 def sample(self, key, sample_shape=()):
     """Draw samples by applying ``self.icdf`` to uniform ``[0, 1)`` variates.

     The uniform draw has shape ``sample_shape + self.batch_shape``.
     """
     assert is_prng_key(key)
     return self.icdf(random.uniform(key, sample_shape + self.batch_shape))