def sample(self, key, sample_shape=()):
    """Draw zero-inflated Poisson samples.

    Each entry is forced to 0 where a Bernoulli(``self.gate``) draw is true,
    and otherwise takes a Poisson(``self.rate``) draw. The output shape is
    ``sample_shape + self.batch_shape``.
    """
    assert is_prng_key(key)
    gate_key, count_key = random.split(key)
    out_shape = sample_shape + self.batch_shape
    zeroed = random.bernoulli(gate_key, self.gate, out_shape)
    counts = random.poisson(count_key, device_put(self.rate), out_shape)
    return jnp.where(zeroed, 0, counts)
def test_modelSelection(family, prior, method):
    """End-to-end check of ``modelSelection`` on simulated GLM data.

    Simulates ``n`` observations of ``p`` standard-normal covariates whose
    linear predictor uses only columns 1 and 2. It then scores every
    non-empty subset of covariates. The returned model probabilities must
    all be finite. For ``family == "logistic"`` the top-ranked model must be
    exactly the true one.

    Args:
        family: response family, ``"logistic"`` or ``"poisson"``.
        prior: forwarded unchanged to ``modelSelection``.
        method: forwarded unchanged to ``modelSelection``.

    Raises:
        ValueError: if ``family`` is not one of the supported values.
    """
    p = 5
    n = 700
    key = random.PRNGKey(0)
    X = random.normal(key, (n, p))
    # NOTE(review): ``key`` was already consumed by random.normal above and is
    # split again here. jax.random advises against this kind of key reuse. It
    # is kept deliberately so the fixed test data, and therefore the model
    # ranking asserted below, stay unchanged.
    key, subkey1, subkey2 = random.split(key, 3)

    mu = 1.7 * X[:, 1] - 1.6 * X[:, 2]
    truth = jnp.full((p, ), False).at[[1, 2]].set(True)

    if family == "logistic":
        y = (mu + random.normal(subkey1, (n, )) > 0).astype(jnp.int32)
    elif family == "poisson":
        y = random.poisson(subkey2, lam=jnp.exp(mu), shape=(n, ))
    else:
        # Fail loudly here instead of with an UnboundLocalError on ``y``.
        raise ValueError(f"unsupported family: {family!r}")

    # Enumerate all 2**p inclusion patterns as rows of a boolean matrix.
    # Row i is True where the p-bit binary form of i has a "0". So row 0 is
    # the full model and the last row is the empty model, which is dropped.
    fmt = f"{{:0{p}b}}"
    bits = np.array([list(fmt.format(i)) for i in range(2**p)])
    gammas = jnp.array(bits == "0")[:-1, :]

    _, modprobs = modelSelection(X, y, gammas, family=family, prior=prior, method=method)
    assert np.all(np.isfinite(modprobs))

    if family == "logistic":
        # Highest-probability model (last index among ties, as before).
        best = jnp.argsort(modprobs)[::-1][0]
        assert np.all(gammas[best, :] == truth), gammas[best, :]
def testPoisson(self, lam, dtype):
    """Check random.poisson against scipy's pmf, both eagerly and under jit.

    Draws 10000 samples with rate ``lam`` and dtype ``dtype``. It checks the
    empirical distribution with a chi-squared test, then the mean and
    variance.
    """
    def draw(rng, rate):
        return random.poisson(rng, rate, (10000,), dtype)

    key = random.PRNGKey(0)
    draw_jit = api.jit(draw)
    eager_samples = draw(key, lam)
    jit_samples = draw_jit(key, lam)
    for samples in (eager_samples, jit_samples):
        self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
        # For a Poisson law both the mean and the variance equal lam.
        # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
        # based on the central limit theorem).
        self.assertAllClose(samples.mean(), lam, rtol=0.01, check_dtypes=False)
        self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)
def testPoisson(self, lam, dtype):
    """Check random.poisson against scipy's pmf, both eagerly and under jit.

    The test is skipped on TPU for integer dtypes narrower than 3 bytes. It
    otherwise draws 10000 samples with rate ``lam`` and checks the empirical
    distribution (chi-squared test), mean and variance.
    """
    on_tpu = jtu.device_under_test() == "tpu"
    if on_tpu and jnp.dtype(dtype).itemsize < 3:
        raise SkipTest("random.poisson() not supported on TPU for 16-bit types.")

    def draw(rng, rate):
        return random.poisson(rng, rate, (10000,), dtype)

    key = random.PRNGKey(0)
    draw_jit = api.jit(draw)
    eager_samples = draw(key, lam)
    jit_samples = draw_jit(key, lam)
    for samples in (eager_samples, jit_samples):
        self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf)
        # For a Poisson law both the mean and the variance equal lam.
        # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
        # based on the central limit theorem).
        self.assertAllClose(samples.mean(), lam, rtol=0.01, check_dtypes=False)
        self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False)
def Encoding(self, intensities):
    """Poisson-encode non-negative intensities into a binary spike train.

    Every element of ``intensities`` is an independent channel. For a
    non-zero value ``v``, inter-spike intervals (in simulation steps) are
    drawn from a Poisson distribution with mean ``(1 / v) * (1000 / dt)``.
    Zero intervals are bumped to 1. Zero-valued channels never spike.

    Args:
        intensities: float32/float64 array of non-negative values, any shape.

    Returns:
        float32 array of 0/1 spikes with shape ``(time, *intensities.shape)``,
        where ``time = int(self.duration // self.dt)``.

    Raises:
        AssertionError: on negative inputs or a non-float dtype.
    """
    assert jnp.all(intensities >= 0), "Inputs must be non-negative"
    assert intensities.dtype == jnp.float32 or intensities.dtype == jnp.float64, "Intensities must be of type Float."

    shape, size = jnp.shape(intensities), jnp.size(intensities)
    flat = intensities.reshape(-1)
    # int() so that a float duration/dt still gives a valid array dimension.
    time = int(self.duration // self.dt)

    # Mean inter-spike interval in steps. Presumably intensities are rates in
    # Hz and dt is in ms -- TODO confirm units against callers. A safe
    # denominator avoids evaluating 1/0 for zero channels.
    non_zero = flat != 0
    safe = jnp.where(non_zero, flat, 1)
    rate = jnp.where(non_zero, 1 / safe * (1000 / self.dt), 0)

    intervals = random.poisson(key=self.key_x, lam=rate, shape=(time, size)).astype(jnp.float32)
    # Active channels must advance by at least one step per spike.
    intervals = intervals + ((intervals == 0) & non_zero).astype(jnp.float32)

    # Spike times are cumulative interval sums. Times beyond the window are
    # redirected to row 0, which is discarded below. Clip in float before the
    # int cast so very large sums cannot overflow int32.
    cum = jnp.cumsum(intervals, dtype=jnp.float32, axis=0)
    times = jnp.where(cum >= time + 1, 0, cum).astype(jnp.int32)

    # Scatter: spikes[times[i, j], j] = 1 for every interval i of channel j.
    spikes = jnp.zeros((time + 1, size), dtype=jnp.float32)
    spikes = spikes.at[times, jnp.arange(size)].set(1)
    return spikes[1:].reshape(time, *shape)
def testPoissonShape(self): key = random.PRNGKey(0) x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2)) assert x.shape == (3, 2)
def testPoissonBatched(self):
    """Each half of a batched rate vector follows its own Poisson law."""
    half = 10000
    rates = jnp.concatenate([2 * jnp.ones(half), 20 * jnp.ones(half)])
    draws = random.poisson(random.PRNGKey(0), rates, shape=(2 * half, ))
    for block, mean in ((draws[:half], 2.0), (draws[half:], 20.0)):
        self._CheckChiSquared(block, scipy.stats.poisson(mean).pmf)
def sample(self, key, sample_shape=()):
    """Return Poisson(``self.rate``) draws of shape ``sample_shape + self.batch_shape``."""
    assert is_prng_key(key)
    target_shape = sample_shape + self.batch_shape
    return random.poisson(key, self.rate, shape=target_shape)
def sample(self, rng_key, sample_shape=()):
    """Draw Poisson(``self.lmbda``) samples.

    Args:
        rng_key: JAX PRNG key.
        sample_shape: leading sample dimensions. Defaults to ``()``, which
            gives one draw per batch/event element.

    Returns:
        Integer array of shape
        ``sample_shape + self.batch_shape + self.event_shape``.
    """
    shape = sample_shape + self.batch_shape + self.event_shape
    return random.poisson(rng_key, self.lmbda, shape)
def sample(self, key, sample_shape=()):
    """Draw Poisson(``self.rate``) samples.

    The output shape is the requested ``sample_shape`` followed by
    ``self.batch_shape``.
    """
    full_shape = sample_shape + self.batch_shape
    draws = random.poisson(key, self.rate, shape=full_shape)
    return draws
def poisson(lam=1.0, size=None):
    """Draw Poisson(``lam``) samples using a key taken from ``DEFAULT``.

    ``size`` is converted to a shape by ``_size2shape``, and the result is
    wrapped in a ``JaxArray``.
    """
    # Take the key before converting ``size``, matching the original
    # argument evaluation order.
    subkey = DEFAULT.split_key()
    draws = jr.poisson(subkey, lam=lam, shape=_size2shape(size))
    return JaxArray(draws)
def testPoissonZeros(self):
    """A rate of exactly zero must always yield zero counts."""
    rates = jnp.concatenate([jnp.zeros(10), jnp.ones(10) * 20])
    draws = random.poisson(random.PRNGKey(0), rates, shape=(2, 20))
    zero_block = draws[:, :10]
    self.assertArraysEqual(zero_block, jnp.zeros_like(zero_block))
# Simulate a Poisson regression data set (continues a script whose earlier
# part defines key, rng, p and N):
#   beta ~ MVN(0, Omega),  rows of X ~ MVN(0, 0.1 * I_p),  y ~ Poisson(exp(X @ beta)).
# Each draw uses a fresh key split from rng, so the order of these
# statements determines the generated data.
Omega = random.normal(key, shape=(p, p))
# Gram matrix: symmetric positive semi-definite, so usable as a covariance.
Omega = Omega @ Omega.T
rng, key = random.split(rng)
beta = random.multivariate_normal(key, jnp.zeros(shape=(p, )), Omega)
rng, key = random.split(rng)
# Design matrix of shape (N, p).
X = random.multivariate_normal(key, jnp.zeros(shape=(p, )), 0.1 * jnp.identity(p), shape=(N, ))
# Poisson means via the log link.
lam = jnp.exp(X @ beta)
rng, key = random.split(rng)
y = random.poisson(key, lam, shape=(N, ))
# sns.histplot(y, discrete=True)
# plt.show()
########################################
########################################
## Functions
# Returns exp(X @ mu + 0.5 * diag(X @ sigma @ X.T)) per row of X. The einsum
# computes diag(X sigma X^T) without forming the N x N product. This equals
# E[exp(x_i . b)] for b ~ N(mu, sigma). Assumes mu is (p,) and sigma is
# (p, p). X is read as a global; under jit it is captured as a constant when
# the function is first traced.
@jit
def xmu_diagxsigx(mu, sigma):
    xsigx = 0.5 * jnp.einsum('ij,jk,ik->i', X, sigma, X)
    return jnp.exp(jnp.dot(X, mu) + xsigx)
# Decorator for the next function definition, which continues past this chunk.
@jit
def test_poisson(shape=(1000, 5)):
    """``dists.Poisson.fit`` should return the column-wise sample mean as its rate.

    Args:
        shape: shape of the simulated count matrix. The assertion compares
            the fitted rate with ``data.mean(axis=0)``.
    """
    # A fixed seed keeps the test reproducible. The previous wall-clock seed
    # (time.time_ns()) made any failure impossible to replay.
    key = jr.PRNGKey(0)
    data = jr.poisson(key, 5.0, shape=shape)
    pois = dists.Poisson.fit(data)
    # The maximum-likelihood estimate of a Poisson rate is the sample mean.
    assert np.allclose(data.mean(axis=0), pois.rate, atol=1e-6)