def testLogistic(self, dtype):
    """Kolmogorov-Smirnov check of random.logistic against scipy's logistic CDF.

    The eager sampler and its jit-compiled counterpart are both drawn with the
    same PRNG key and each is compared to the reference distribution.
    """
    def draw(k):
        return random.logistic(k, (10000,), dtype)

    jitted_draw = api.jit(draw)
    key = random.PRNGKey(0)
    eager_samples = draw(key)
    jitted_samples = jitted_draw(key)
    reference_cdf = scipy.stats.logistic().cdf
    for samples in (eager_samples, jitted_samples):
        self._CheckKolmogorovSmirnovCDF(samples, reference_cdf)
def conditional_params_to_sample(rng, conditional_params):
    """Draw a sample from a logistic mixture and snap it to the grid.

    A mixture component is chosen via ``_categorical_onehot`` and its mean and
    inverse scale are selected by a masked sum over axis -4. Logistic noise
    divided by the inverse scale is then added to the selected mean.

    Args:
      rng: PRNG key; split into one key for the mixture choice and one for
        the logistic noise.
      conditional_params: ``(means, inv_scales, logit_probs)`` tuple.
        NOTE(review): the mixture axis is assumed to be axis -4 of
        ``means``/``inv_scales`` (that is the axis summed over) — confirm
        against the producer of these params.

    Returns:
      ``snap_to_grid`` applied to the continuous sample.
    """
    means, inv_scales, logit_probs = conditional_params
    rng_mix, rng_logistic = random.split(rng)
    # Trailing singleton axis lets the one-hot indicator broadcast over the
    # channel dimension of means / inv_scales.
    one_hot = _categorical_onehot(rng_mix, logit_probs)[..., None]

    def select(param):
        # Zero out non-chosen components and collapse the mixture axis.
        return jnp.sum(param * one_hot, -4)

    mean = select(means)
    inv_scale = select(inv_scales)
    noise = random.logistic(rng_logistic, mean.shape)
    return snap_to_grid(mean + noise / inv_scale)
def conditional_params_to_sample(rng, conditional_params):
    """Draw a sample from a logistic mixture by gathering the chosen component.

    Args:
      rng: PRNG key; split into a mixture-choice key and a noise key.
      conditional_params: ``(means, inv_scales, logit_probs)``. ``means`` must
        be 4-D (the unpacking below enforces it); axis 0 is the axis indexed
        by the mixture choice.

    Returns:
      ``chosen_mean + logistic_noise / chosen_inv_scale``, with the noise in
      the dtype of ``means``.
    """
    means, inv_scales, logit_probs = conditional_params
    _, height, width, channels = means.shape
    rng_mix, rng_logistic = random.split(rng)
    component = _gumbel_max(rng_mix, logit_probs)
    # Repeat the chosen component index over channels and add a leading axis
    # so it can index axis 0 via take_along_axis.
    index = np.broadcast_to(
        component[..., np.newaxis], (height, width, channels))[np.newaxis]

    def pick(arr):
        return np.take_along_axis(arr, index, 0)[0]

    chosen_means = pick(means)
    chosen_inv_scales = pick(inv_scales)
    noise = random.logistic(rng_logistic, chosen_means.shape, chosen_means.dtype)
    return chosen_means + noise / chosen_inv_scales
def testLogistic(self, dtype):
    """Kolmogorov-Smirnov check of random.logistic, eager and jitted.

    Dtypes with an itemsize of at most 2 bytes are skipped on TPU.
    """
    on_tpu = jtu.device_under_test() == "tpu"
    if on_tpu and jnp.dtype(dtype).itemsize < 3:
        raise SkipTest("random.logistic() not supported on TPU for 16-bit types.")

    def draw(k):
        return random.logistic(k, (10000,), dtype)

    jitted_draw = api.jit(draw)
    key = random.PRNGKey(0)
    eager_samples = draw(key)
    jitted_samples = jitted_draw(key)
    reference_cdf = scipy.stats.logistic().cdf
    for samples in (eager_samples, jitted_samples):
        self._CheckKolmogorovSmirnovCDF(samples, reference_cdf)
def sample(p, temperature, key, num_samples=1):
    """
    Draw relaxed Bernoulli (Binomial Concrete) samples.

    :param p: Bernoulli probabilities parameterising the relaxation
        (jax.numpy array); clipped into [eps, 1 - eps] to keep the logit finite
    :param temperature: relaxation temperature; eps is added before dividing
        so a zero temperature does not divide by zero
    :param key: PRNG key
    :param num_samples: number of samples, stacked along a new leading axis
    :return: array of shape (num_samples, *p.shape) with values from sigmoid
    """
    eps = 1e-7
    clipped = np.clip(p, eps, 1 - eps)
    logits = logit(clipped)
    noise_shape = (num_samples,) + tuple(clipped.shape)
    noise = random.logistic(key, shape=noise_shape)
    scaled = (logits + noise) / (temperature + eps)
    return nn.sigmoid(scaled)
def logistic_mix_sample(nn_out, rng):
    """Sample an RGB image from a mixture of logistics.

    The network output is unpacked by ``logistic_preprocess``. A mixture
    component is drawn with ``random.categorical`` over axis -3 of the
    weights, and the chosen component's parameters are gathered along
    axis -4. Colour channels are then generated autoregressively:
      - green depends on red;
      - blue depends on red and green,
    with coefficients taken from the second preprocessed tensor.

    Args:
      nn_out: raw network output, passed to ``logistic_preprocess``.
      rng: PRNG key; split into a mixture key and a noise key.

    Returns:
      The red, green and blue channels stacked on the last axis.
    """
    means, coeffs, inv_scales, logit_weights = logistic_preprocess(nn_out)
    rng_mix, rng_logistic = random.split(rng)
    mix_idx = random.categorical(rng_mix, logit_weights, -3)
    # Index with singleton axes at -4 (mixture) and -1 (channel) so it can
    # gather along the mixture axis and broadcast over channels.
    gather_idx = jnp.expand_dims(mix_idx, (-4, -1))

    def channels_first(arr):
        # Gather the chosen component, drop the mixture axis, then move the
        # last (channel) axis to the front so channels can be indexed as [k].
        picked = jnp.take_along_axis(arr, gather_idx, -4)
        return jnp.moveaxis(jnp.squeeze(picked, -4), -1, 0)

    means, coeffs, inv_scales = (
        channels_first(arr) for arr in (means, coeffs, inv_scales))
    noise = random.logistic(rng_logistic, means.shape) / inv_scales
    red = means[0] + noise[0]
    green = means[1] + coeffs[0] * red + noise[1]
    blue = means[2] + coeffs[1] * red + coeffs[2] * green + noise[2]
    return jnp.stack([red, green, blue], -1)
def logistic(loc=0.0, scale=1.0, size=None):
    """Draw samples from a logistic distribution (NumPy-style API).

    Standard logistic noise is drawn with ``jr.logistic`` using a key from
    ``DEFAULT.split_key()``. The noise is then scaled by ``scale`` and shifted
    by ``loc``; this is the location-scale form of the logistic distribution.

    Args:
      loc: Location of the distribution. Scalar, or array broadcastable
        against the sampled shape.
      scale: Scale of the distribution; must be non-negative.
      size: Requested output shape, converted by ``_size2shape``
        (presumably NumPy semantics: None / int / tuple — confirm).

    Returns:
      A ``JaxArray`` wrapping the samples.

    Raises:
      ValueError: If ``scale`` is a negative Python number (matches NumPy).
    """
    # Non-default loc/scale used to be rejected with `assert`. Asserts are
    # stripped under `python -O`, which silently ignored the arguments.
    # The affine transform below supports them directly instead.
    if isinstance(scale, (int, float)) and scale < 0:
        raise ValueError("scale < 0")
    samples = jr.logistic(DEFAULT.split_key(), shape=_size2shape(size))
    return JaxArray(samples * scale + loc)
def sample(self, key, sample_shape=()):
    """Draw samples via ``loc + scale * standard_logistic_noise``.

    Args:
      key: PRNG key.
      sample_shape: leading shape prepended to ``batch_shape + event_shape``
        (concatenated with ``+``, so a tuple is expected).

    Returns:
      Array of shape ``sample_shape + self.batch_shape + self.event_shape``.
    """
    full_shape = sample_shape + self.batch_shape + self.event_shape
    standard = random.logistic(key, shape=full_shape)
    return self.loc + standard * self.scale