def testCauchy(self, dtype):
    key = random.PRNGKey(0)
    rand = lambda key: random.cauchy(key, (10000,), dtype)
    crand = api.jit(rand)

    uncompiled_samples = rand(key)
    compiled_samples = crand(key)
    for samples in [uncompiled_samples, compiled_samples]:
        self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf)
def testCauchy(self, dtype):
    if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
        raise SkipTest("random.cauchy() not supported on TPU for 16-bit types.")
    key = random.PRNGKey(0)
    rand = lambda key: random.cauchy(key, (10000,), dtype)
    crand = api.jit(rand)

    uncompiled_samples = rand(key)
    compiled_samples = crand(key)
    for samples in [uncompiled_samples, compiled_samples]:
        self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf)
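The two tests above depend on a _CheckKolmogorovSmirnovCDF helper that is not shown here. Below is a minimal, self-contained sketch of an equivalent check, assuming scipy.stats.kstest against the standard Cauchy CDF; the function name and the 1% significance level are illustrative choices, not part of the original test suite.

import jax.numpy as jnp
import scipy.stats
from jax import jit, random

def check_cauchy_ks(seed=0, n=10000, alpha=0.01):
    key = random.PRNGKey(seed)
    rand = lambda key: random.cauchy(key, (n,), jnp.float32)
    crand = jit(rand)
    # Check both eager and JIT-compiled sampling, as the tests above do.
    for samples in [rand(key), crand(key)]:
        # Kolmogorov-Smirnov goodness-of-fit test against the standard
        # Cauchy CDF; a large p-value means the samples are consistent.
        _, p_value = scipy.stats.kstest(samples, scipy.stats.cauchy().cdf)
        assert p_value > alpha, f"KS test failed: p={p_value:.4f}"

check_cauchy_ks()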
def sample(self, key, sample_shape=()):
    # Draw standard Cauchy noise, then shift and scale it: the Cauchy
    # distribution is a location-scale family.
    eps = random.cauchy(key, shape=sample_shape + self.batch_shape)
    return self.loc + eps * self.scale
def draw_samples(self, rng, alpha, scale):
    r"""Draw samples from the robust distribution.

    This function implements Algorithm 1 of the paper. This code is written
    to allow for sampling from a set of different distributions, each
    parametrized by its own alpha and scale values, as opposed to the more
    standard approach of drawing N samples from the same distribution. This
    is done by repeatedly performing N instances of rejection sampling for
    each of the N distributions until at least one proposal for each of the
    N distributions has been accepted. All samples assume a zero mean --- to
    get non-zero mean samples, just add each mean to each sample.

    Args:
        rng: A JAX pseudo random number key, from random.PRNGKey().
        alpha: A tensor where each element is the shape parameter of that
            element's distribution. Must be >= 0.
        scale: A tensor where each element is the scale parameter of that
            element's distribution. Must be > 0 and the same shape as
            `alpha`.

    Returns:
        A tensor with the same shape as `alpha` and `scale` where each
        element is a sample drawn from the zero-mean distribution specified
        for that element by `alpha` and `scale`.
    """
    assert jnp.all(scale > 0)
    assert jnp.all(alpha >= 0)
    assert jnp.all(jnp.array(alpha.shape) == jnp.array(scale.shape))

    shape = alpha.shape
    samples = jnp.zeros(shape)
    accepted = jnp.zeros(shape, dtype=bool)

    # Rejection sampling.
    while not jnp.all(accepted):
        # The sqrt(2) scaling of the Cauchy distribution corrects for our
        # differing conventions for standardization.
        rng, key = random.split(rng)
        cauchy_sample = random.cauchy(key, shape=shape) * jnp.sqrt(2)

        # Compute the likelihood of each sample under its target
        # distribution.
        nll = self.nllfun(cauchy_sample, alpha, 1)

        # Bound the NLL. We don't use the approximate loss as it may cause
        # unpredictable behavior in the context of sampling.
        nll_bound = (general.lossfun(cauchy_sample, 0, 1) +
                     self.log_base_partition_function(alpha))

        # Draw N samples from a uniform distribution, and use each uniform
        # sample to decide whether or not to accept each proposal sample.
        rng, key = random.split(rng)
        uniform_sample = random.uniform(key, shape=shape)
        accept = uniform_sample <= jnp.exp(nll_bound - nll)

        # If a sample is accepted, replace its element in `samples` with
        # the proposal sample, and set its bit in `accepted` to True.
        samples = jnp.where(accept, cauchy_sample, samples)
        accepted = accept | accepted

    # Because our distribution is a location-scale family, we sample from
    # p(x | 0, \alpha, 1) and then scale each sample by `scale`.
    samples *= scale
    return samples
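The loop above depends on the paper's nllfun, lossfun, and log_base_partition_function helpers, so it is not runnable on its own. Below is a self-contained sketch of the same vectorized rejection-sampling pattern with a standard Cauchy proposal, swapped to a standard normal target so the envelope constant can be written in closed form; the target choice and the envelope M = sqrt(2*pi) * exp(-1/2) (the maximum of normal_pdf / cauchy_pdf) are my illustrative assumptions, not the paper's machinery.

import jax.numpy as jnp
from jax import random

def sample_normal_via_cauchy(rng, shape):
    # log of the envelope constant M = sqrt(2*pi) * exp(-1/2), which
    # guarantees normal_pdf(x) <= M * cauchy_pdf(x) for all x.
    log_m = 0.5 * jnp.log(2 * jnp.pi) - 0.5
    samples = jnp.zeros(shape)
    accepted = jnp.zeros(shape, dtype=bool)
    while not jnp.all(accepted):
        rng, key = random.split(rng)
        proposal = random.cauchy(key, shape=shape)
        # Negative log-likelihood of the target (standard normal) and of
        # the envelope M * cauchy_pdf at each proposal.
        nll = 0.5 * proposal**2 + 0.5 * jnp.log(2 * jnp.pi)
        nll_bound = jnp.log(jnp.pi * (1 + proposal**2)) - log_m
        # Accept each proposal with probability p(x) / (M * q(x)).
        rng, key = random.split(rng)
        u = random.uniform(key, shape=shape)
        accept = u <= jnp.exp(nll_bound - nll)
        samples = jnp.where(accept, proposal, samples)
        accepted = accept | accepted
    return samples

samples = sample_normal_via_cauchy(random.PRNGKey(0), (10000,))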
def _rvs(self):
    return random.cauchy(self._random_state, shape=self._size)
def standard_cauchy(size=None):
    return JaxArray(jr.cauchy(DEFAULT.split_key(), shape=_size2shape(size)))
def sample(self, rng_key, sample_shape=()):
    shape = sample_shape + self.batch_shape + self.event_shape
    std_sample = random.cauchy(rng_key, shape)
    return self.loc + self.scale * std_sample
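Both sample() methods above rely on the location-scale property of the Cauchy family: if X ~ Cauchy(0, 1), then loc + scale * X ~ Cauchy(loc, scale). A quick sketch checking that identity against scipy; the loc/scale values and the 1% threshold are arbitrary demo choices.

import scipy.stats
from jax import random

loc, scale = -2.0, 3.0
key = random.PRNGKey(0)
# Shift and scale standard Cauchy noise, as the sample() methods above do.
samples = loc + scale * random.cauchy(key, shape=(10000,))
# A KS test against scipy's Cauchy(loc, scale) CDF should not reject.
_, p_value = scipy.stats.kstest(samples, scipy.stats.cauchy(loc=loc, scale=scale).cdf)
assert p_value > 0.01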