def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    dtype = jnp.result_type(float)
    finfo = jnp.finfo(dtype)
    minval = finfo.tiny
    u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
    return self.base_dist.icdf(u * self._cdf_at_high)
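# Why the `minval=finfo.tiny` guard above matters: `random.uniform` draws from
# [0, 1), and u = 0 would send `icdf` to the boundary of the support (e.g. -inf
# for a base distribution with unbounded lower tail). A standalone sketch using
# jax.scipy's Normal as an illustrative stand-in for `self.base_dist`:
import jax.numpy as jnp
from jax.scipy.stats import norm

cdf_at_high = norm.cdf(1.0)  # right-truncation point at 1.0
print(norm.ppf(0.0 * cdf_at_high))   # -inf: u = 0 escapes the support
tiny = jnp.finfo(jnp.result_type(float)).tiny
print(norm.ppf(tiny * cdf_at_high))  # very negative, but finite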
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    u = random.uniform(key, sample_shape + self.batch_shape)
    loc = self.base_dist.loc
    sign = jnp.where(loc >= self.low, 1.0, -1.0)
    return (1 - sign) * loc + sign * self.base_dist.icdf(
        (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
    )
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    key_beta, key_binom = random.split(key)
    probs = self._beta.sample(key_beta, sample_shape)
    return BinomialProbs(total_count=self.total_count, probs=probs).sample(
        key_binom
    )
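# The beta-binomial above is an instance of compound sampling: draw the latent
# parameter first, then sample the likelihood conditioned on it. A minimal
# self-contained sketch of the same pattern with raw jax.random primitives
# (hypothetical helper, not the library implementation; the Bernoulli-sum
# binomial is only sensible for small total_count):
import jax.numpy as jnp
from jax import random

def beta_binomial_sample(key, concentration1, concentration0, total_count, sample_shape=()):
    key_beta, key_binom = random.split(key)
    p = random.beta(key_beta, concentration1, concentration0, shape=sample_shape)
    # Binomial(total_count, p) as a sum of Bernoulli draws.
    bern = random.bernoulli(key_binom, p[..., None], shape=sample_shape + (total_count,))
    return bern.sum(-1)

print(beta_binomial_sample(random.PRNGKey(0), 2.0, 3.0, 10, (5,)))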
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    key_bern, key_poisson = random.split(key)
    shape = sample_shape + self.batch_shape
    mask = random.bernoulli(key_bern, self.gate, shape)
    samples = random.poisson(key_poisson, device_put(self.rate), shape)
    return jnp.where(mask, 0, samples)
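# A quick sanity check on the zero-inflation pattern (a hypothetical snippet,
# not part of the class): the fraction of zeros should approach
# gate + (1 - gate) * exp(-rate), i.e. structural zeros plus Poisson zeros.
import jax.numpy as jnp
from jax import random

key_bern, key_poisson = random.split(random.PRNGKey(0))
gate, rate, shape = 0.3, 10.0, (100_000,)
mask = random.bernoulli(key_bern, gate, shape)
samples = jnp.where(mask, 0, random.poisson(key_poisson, rate, shape))
print(jnp.mean(samples == 0))  # ~ 0.3 (exp(-10) is negligible)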
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    key_bern, key_base = random.split(key)
    shape = sample_shape + self.batch_shape
    mask = random.bernoulli(key_bern, self.gate, shape)
    samples = self.base_dist(rng_key=key_base, sample_shape=sample_shape)
    return jnp.where(mask, 0, samples)
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    probs = self.probs
    dtype = jnp.result_type(probs)
    shape = sample_shape + self.batch_shape
    u = random.uniform(key, shape, dtype)
    return jnp.floor(jnp.log1p(-u) / jnp.log1p(-probs))
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    logits = self.logits
    dtype = jnp.result_type(logits)
    shape = sample_shape + self.batch_shape
    u = random.uniform(key, shape, dtype)
    return jnp.floor(jnp.log1p(-u) / -softplus(logits))
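# The two geometric samplers above are the same inverse-CDF transform,
# floor(log(1 - u) / log(1 - p)), written once in terms of probs and once in
# terms of logits via the identity log(1 - sigmoid(x)) = -softplus(x).
# A quick check of that identity (illustrative values):
import jax.numpy as jnp
from jax import random
from jax.nn import sigmoid, softplus

logits = jnp.array(0.5)
u = random.uniform(random.PRNGKey(0), (1000,))
via_probs = jnp.log1p(-u) / jnp.log1p(-sigmoid(logits))
via_logits = jnp.log1p(-u) / -softplus(logits)
print(jnp.allclose(via_probs, via_logits))  # True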
def sample_with_intermediates(self, key, sample_shape=()):
    """
    Same as ``sample`` except that the sampled mixture components are also returned.

    :param jax.random.PRNGKey key: the rng_key key to be used for the distribution.
    :param tuple sample_shape: the sample shape for the distribution.
    :return: Tuple (samples, indices)
    :rtype: tuple
    """
    assert is_prng_key(key)
    key_comp, key_ind = jax.random.split(key)
    # Samples from component distribution will have shape:
    # (*sample_shape, *batch_shape, mixture_size, *event_shape)
    samples = self.component_distribution.expand(
        sample_shape + self.batch_shape + (self.mixture_size,)
    ).sample(key_comp)
    # Sample selection indices from the categorical
    # (shape will be sample_shape + batch_shape)
    indices = self.mixing_distribution.expand(
        sample_shape + self.batch_shape
    ).sample(key_ind)
    n_expand = self.event_dim + 1
    indices_expanded = indices.reshape(indices.shape + (1,) * n_expand)
    # Select samples according to the indices sampled from the categorical
    samples_selected = jnp.take_along_axis(
        samples, indices=indices_expanded, axis=self.mixture_dim
    )
    # Final sample shape (*sample_shape, *batch_shape, *event_shape)
    return jnp.squeeze(samples_selected, axis=self.mixture_dim), indices
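# The gather step above can be seen in isolation. A minimal sketch of
# selecting one mixture component per draw with jnp.take_along_axis
# (hypothetical shapes, not the library API):
import jax.numpy as jnp
from jax import random

key_comp, key_ind = random.split(random.PRNGKey(0))
sample_shape, mixture_size, event_shape = (5,), 3, (2,)
# Component draws: (*sample_shape, mixture_size, *event_shape)
samples = random.normal(key_comp, sample_shape + (mixture_size,) + event_shape)
# One component index per draw: (*sample_shape,)
indices = random.categorical(key_ind, jnp.zeros(mixture_size), shape=sample_shape)
# Right-pad indices with singleton axes so they broadcast over the event dims.
idx = indices.reshape(indices.shape + (1, 1))
selected = jnp.take_along_axis(samples, idx, axis=-2)  # (5, 1, 2)
print(jnp.squeeze(selected, axis=-2).shape)  # (5, 2)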
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    key_dirichlet, key_multinom = random.split(key)
    probs = self._dirichlet.sample(key_dirichlet, sample_shape)
    return MultinomialProbs(total_count=self.total_count, probs=probs).sample(
        key_multinom
    )
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
    x = random.gamma(
        key, jnp.ones(self.batch_shape + sample_shape + (self.num_gamma_variates,))
    )
    x = jnp.sum(x / denom, axis=-1)
    return jnp.clip(x * (0.5 / jnp.pi ** 2), a_max=self.truncation_point)
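# The sum above is the truncated infinite-series representation of a
# Polya-Gamma PG(1, 0) variable: x = (1 / (2 * pi^2)) * sum_k g_k / (k - 1/2)^2
# with g_k ~ Gamma(1, 1). Since sum_k 1 / (k - 1/2)^2 = pi^2 / 2, the mean is
# 1/4, which a standalone check (illustrative constants) can confirm:
import jax.numpy as jnp
from jax import random

num_gamma_variates = 32
denom = jnp.square(jnp.arange(0.5, num_gamma_variates))  # (k - 1/2)^2
g = random.gamma(random.PRNGKey(0), jnp.ones((100_000, num_gamma_variates)))
x = jnp.sum(g / denom, axis=-1) * (0.5 / jnp.pi ** 2)
print(x.mean())  # ~ 0.25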
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    dtype = jnp.result_type(float)
    finfo = jnp.finfo(dtype)
    minval = finfo.tiny
    u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval)
    loc = self.base_dist.loc
    sign = jnp.where(loc >= self.low, 1.0, -1.0)
    return (1 - sign) * loc + sign * self.base_dist.icdf(
        (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
    )
def sample(self, key, sample_shape=()): """ Generate sample from von Mises distribution :param key: random number generator key :param sample_shape: shape of samples :return: samples from von Mises """ assert is_prng_key(key) samples = von_mises_centered(key, self.concentration, sample_shape + self.shape()) samples = samples + self.loc # VM(0, concentration) -> VM(loc,concentration) samples = (samples + jnp.pi) % (2. * jnp.pi) - jnp.pi return samples
def sample(self, key, sample_shape=()): """ ** References: ** 1. A New Unified Approach for the Simulation of a Wide Class of Directional Distributions John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018) """ assert is_prng_key(key) phi_key, psi_key = random.split(key) corr = self.correlation conc = jnp.stack((self.phi_concentration, self.psi_concentration)) eig = 0.5 * (conc[0] - corr**2 / conc[1]) eig = jnp.stack((jnp.zeros_like(eig), eig)) eigmin = jnp.where(eig[1] < 0, eig[1], jnp.zeros_like(eig[1], dtype=eig.dtype)) eig = eig - eigmin b0 = self._bfind(eig) total = _numel(sample_shape) phi_den = log_I1(0, conc[1]).squeeze(0) batch_size = _numel(self.batch_shape) phi_shape = (total, 2, batch_size) phi_state = SineBivariateVonMises._phi_marginal( phi_shape, phi_key, jnp.reshape(conc, (2, batch_size)), jnp.reshape(corr, (batch_size, )), jnp.reshape(eig, (2, batch_size)), jnp.reshape(b0, (batch_size, )), jnp.reshape(eigmin, (batch_size, )), jnp.reshape(phi_den, (batch_size, )), ) phi = jnp.arctan2(phi_state.phi[:, 1:], phi_state.phi[:, :1]) alpha = jnp.sqrt(conc[1]**2 + (corr * jnp.sin(phi))**2) beta = jnp.arctan(corr / conc[1] * jnp.sin(phi)) psi = VonMises(beta, alpha).sample(psi_key) phi_psi = jnp.concatenate( ( (phi + self.phi_loc + pi) % (2 * pi) - pi, (psi + self.psi_loc + pi) % (2 * pi) - pi, ), axis=1, ) phi_psi = jnp.transpose(phi_psi, (0, 2, 1)) return phi_psi.reshape(*sample_shape, *self.batch_shape, *self.event_shape)
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    u = random.uniform(key, sample_shape + self.batch_shape)

    # NB: for a symmetric base distribution we use a more numerically stable formula.
    # The direct transform
    #     A = icdf(cdf(low) + (cdf(high) - cdf(low)) * u)
    #       = icdf[(1 - u) * cdf(low) + u * cdf(high)]
    # suffers from precision issues when low is large, so:
    # if low < loc:
    #     A = icdf[(1 - u) * cdf(low) + u * cdf(high)]
    # else:
    #     A = 2 * loc - icdf[(1 - u) * cdf(2 * loc - low) + u * cdf(2 * loc - high)]
    loc = self.base_dist.loc
    sign = jnp.where(loc >= self.low, 1.0, -1.0)
    return (1 - sign) * loc + sign * self.base_dist.icdf(
        (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high
    )
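# A standalone illustration of the reflection trick documented above, with a
# standard Normal base truncated to [8, inf) (hypothetical numbers, and
# jax.scipy stand-ins for the class internals). Near the upper tail cdf(low)
# rounds to 1.0 in float32 and the naive interpolation collapses, while
# reflecting through loc works in the lower tail where cdf is well conditioned:
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

loc, low = 0.0, 8.0  # truncate Normal(0, 1) to [8, inf)
u = random.uniform(random.PRNGKey(0), (5,))
naive = norm.ppf((1 - u) * norm.cdf(low) + u)  # cdf(8) rounds to 1.0 in float32
reflected = 2 * loc - norm.ppf((1 - u) * norm.cdf(2 * loc - low))
print(naive)      # inf or collapsed values: precision lost
print(reflected)  # finite draws just above 8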
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    samples = random.bernoulli(
        key, self.probs, shape=sample_shape + self.batch_shape
    )
    return samples.astype(jnp.result_type(samples, int))
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    u = random.uniform(key, sample_shape + self.batch_shape)
    return self.base_dist.icdf(u * self._cdf_at_high)
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    return random.poisson(key, self.rate, shape=sample_shape + self.batch_shape)
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    return multinomial(
        key, self.probs, self.total_count, shape=sample_shape + self.batch_shape
    )
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    return random.categorical(key, self.logits, shape=sample_shape + self.batch_shape)
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    return categorical(key, self.probs, shape=sample_shape + self.batch_shape)
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    key_gamma, key_poisson = random.split(key)
    rate = self._gamma.sample(key_gamma, sample_shape)
    return Poisson(rate).sample(key_poisson)
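# The gamma-Poisson compound above is the negative binomial. A standalone
# moment check (illustrative constants, raw jax.random primitives) shows the
# characteristic overdispersion relative to a plain Poisson:
import jax.numpy as jnp
from jax import random

key_gamma, key_poisson = random.split(random.PRNGKey(0))
concentration, rate = 3.0, 2.0
lam = random.gamma(key_gamma, concentration, (100_000,)) / rate
counts = random.poisson(key_poisson, lam)
print(counts.mean())  # ~ concentration / rate = 1.5
print(counts.var())   # ~ mean * (1 + 1 / rate) = 2.25 > mean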
def sample(self, key, sample_shape=()):
    assert is_prng_key(key)
    u = random.uniform(key, sample_shape + self.batch_shape)
    return self.icdf(u)
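# The generic path above is plain inverse-transform sampling: push uniforms
# through the quantile function. A self-contained sketch for Exponential(1)
# (hypothetical helper, not the class's icdf):
import jax.numpy as jnp
from jax import random

def exponential_icdf(q):
    return -jnp.log1p(-q)  # quantile function of Exponential(1)

u = random.uniform(random.PRNGKey(0), (100_000,))
print(exponential_icdf(u).mean())  # ~ 1.0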