Ejemplo n.º 1
0
 def sample(self, key, sample_shape=()):
     """Draw samples from the zero-inflated Poisson distribution.

     Each output element is forced to zero when a Bernoulli draw with
     probability ``self.gate`` succeeds; otherwise it keeps a Poisson draw
     with rate ``self.rate``.

     :param key: PRNG key; must satisfy ``is_prng_key``.
     :param sample_shape: leading shape prepended to ``self.batch_shape``.
     :return: array of shape ``sample_shape + self.batch_shape``.
     """
     assert is_prng_key(key)
     # Independent sub-keys for the gate and for the count draws.
     gate_key, count_key = random.split(key)
     out_shape = sample_shape + self.batch_shape
     is_structural_zero = random.bernoulli(gate_key, self.gate, out_shape)
     counts = random.poisson(count_key, device_put(self.rate), out_shape)
     return jnp.where(is_structural_zero, 0, counts)
Ejemplo n.º 2
0
def test_modelSelection(family, prior, method):
    """Check that ``modelSelection`` ranks candidate models sensibly.

    Simulates ``n`` observations of ``p`` standard-normal covariates in which
    only columns 1 and 2 carry signal, then enumerates every non-empty subset
    of covariates as a candidate model. Asserts that all returned model
    probabilities are finite. For the logistic family, it also asserts that
    the top-ranked model is exactly the true one.

    :param family: ``"logistic"`` or ``"poisson"``.
    :param prior: forwarded to ``modelSelection``.
    :param method: forwarded to ``modelSelection``.
    :raises ValueError: if ``family`` is not one of the supported values.
    """
    p = 5
    n = 700
    # Split every key up front. Consuming a key with random.normal and then
    # splitting the same key is the key-reuse pattern JAX's PRNG forbids.
    key_x, key_logistic, key_poisson = random.split(random.PRNGKey(0), 3)
    X = random.normal(key_x, (n, p))
    mu = 1.7 * X[:, 1] - 1.6 * X[:, 2]
    truth = jnp.full((p, ), False)
    truth = truth.at[[1, 2]].set(True)
    if family == "logistic":
        y = (mu + random.normal(key_logistic, (n, )) > 0).astype(jnp.int32)
    elif family == "poisson":
        y = random.poisson(key_poisson, lam=jnp.exp(mu), shape=(n, ))
    else:
        # Previously this fell through and failed later with a NameError on `y`.
        raise ValueError(f"unsupported family: {family!r}")
    # Enumerate all 2**p inclusion patterns as p-digit binary strings, where a
    # "0" digit marks a covariate as included. The last pattern (all "1"s) is
    # the empty model and is dropped.
    fmt = f"{{:0{p}b}}"
    gammes = np.array([list(fmt.format(i)) for i in range(2**p)])
    gammes = jnp.array(gammes == "0")[:-1, :]

    _, modprobs = modelSelection(X,
                                 y,
                                 gammes,
                                 family=family,
                                 prior=prior,
                                 method=method)
    order = jnp.argsort(modprobs)[::-1]
    assert np.all(np.isfinite(modprobs))
    if family == "logistic":
        assert np.all(gammes[order[0], :] == truth), gammes[order[0], :]
Ejemplo n.º 3
0
  def testPoisson(self, lam, dtype):
    """Check random.poisson against scipy, both eagerly and under jit.

    Draws 10000 samples of the requested ``dtype`` with rate ``lam`` from the
    same key, once plainly and once through ``api.jit``. Each batch is
    checked with ``self._CheckChiSquared`` against
    ``scipy.stats.poisson(lam).pmf``. Its sample mean and variance are then
    compared with ``lam``.
    """
    def sampler(k, rate):
      return random.poisson(k, rate, (10000,), dtype)

    seed_key = random.PRNGKey(0)
    jitted_sampler = api.jit(sampler)
    batches = (sampler(seed_key, lam), jitted_sampler(seed_key, lam))

    for batch in batches:
      self._CheckChiSquared(batch, scipy.stats.poisson(lam).pmf)
      # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
      # based on the central limit theorem).
      # A Poisson distribution's mean and variance both equal lam.
      self.assertAllClose(batch.mean(), lam, rtol=0.01, check_dtypes=False)
      self.assertAllClose(batch.var(), lam, rtol=0.03, check_dtypes=False)
Ejemplo n.º 4
0
  def testPoisson(self, lam, dtype):
    """Check random.poisson against scipy, both eagerly and under jit.

    The test is skipped on TPU for dtypes whose itemsize is below 3 bytes.
    Otherwise it draws 10000 samples with rate ``lam`` from one key, both
    plainly and through ``api.jit``. Each batch is checked with
    ``self._CheckChiSquared`` against ``scipy.stats.poisson(lam).pmf``, and
    its sample mean and variance are compared with ``lam``.
    """
    on_tpu = jtu.device_under_test() == "tpu"
    if on_tpu and jnp.dtype(dtype).itemsize < 3:
      raise SkipTest("random.poisson() not supported on TPU for 16-bit types.")

    def sampler(k, rate):
      return random.poisson(k, rate, (10000,), dtype)

    seed_key = random.PRNGKey(0)
    jitted_sampler = api.jit(sampler)
    batches = (sampler(seed_key, lam), jitted_sampler(seed_key, lam))

    for batch in batches:
      self._CheckChiSquared(batch, scipy.stats.poisson(lam).pmf)
      # TODO(shoyer): determine error bounds for moments more rigorously (e.g.,
      # based on the central limit theorem).
      # A Poisson distribution's mean and variance both equal lam.
      self.assertAllClose(batch.mean(), lam, rtol=0.01, check_dtypes=False)
      self.assertAllClose(batch.var(), lam, rtol=0.03, check_dtypes=False)
Ejemplo n.º 5
0
    def Encoding(self, intensities):
        """Poisson-encode non-negative intensities into a binary spike train.

        Each input element drives one neuron. For a non-zero intensity ``v``,
        inter-spike intervals are drawn from a Poisson distribution with mean
        ``1 / v * (1000 / self.dt)`` time steps. Zero intervals are bumped to
        1. Cumulative sums of the intervals give the spike times, and spikes
        that would land after the last step are discarded. Zero-intensity
        elements never spike.

        :param intensities: non-negative float32/float64 array of any shape.
            Presumably firing rates in Hz, given the ``1000 / self.dt``
            ms-based factor -- TODO confirm.
        :return: float32 array of shape ``(time, *intensities.shape)`` with
            ``time = int(self.duration // self.dt)``. Entries are 0.0 or 1.0.
        :raises AssertionError: if ``intensities`` has a negative entry or its
            dtype is not float32/float64.
        """
        assert jnp.all(intensities >= 0), "Inputs must be non-negative"
        assert intensities.dtype == jnp.float32 or intensities.dtype == jnp.float64, "Intensities must be of type Float."

        # Get shape and size of data.
        shape, size = jnp.shape(intensities), jnp.size(intensities)
        intensities = intensities.reshape(-1)

        # int(...) so the step count is usable as an array dimension even when
        # duration/dt are floats (floor division of floats yields a float).
        time = int(self.duration // self.dt)

        # Mean inter-spike interval (in time steps) per neuron. Zero
        # intensities get 0. Their divisor is replaced by 1 so the masked-out
        # branch of `where` never divides by zero.
        non_zero = intensities != 0
        safe_intensities = jnp.where(non_zero, intensities, 1)
        rate = jnp.where(non_zero, 1 / safe_intensities * (1000 / self.dt), 0)

        # Sample `time` intervals per neuron. Active neurons' intervals are at
        # least 1 step, so the k-th cumulative sum is >= k and `time` intervals
        # cover the whole window. Zero intervals of active neurons become 1.
        # NOTE(review): self.key_x is used as-is (never split), so identical
        # inputs produce identical spike trains on every call.
        intervals = random.poisson(key=self.key_x,
                                   lam=rate,
                                   shape=(time, size)).astype(jnp.float32)
        intervals = intervals + ((intervals == 0) & non_zero).astype(jnp.float32)

        # Spike times are cumulative sums of intervals. Times past the window,
        # and every time of a zero-intensity neuron (which stays 0), are sent
        # to row 0, which is dropped below.
        times = jnp.cumsum(intervals, axis=0, dtype=jnp.float32)
        times = jnp.where(times >= time + 1, 0, times).astype(jnp.int32)

        # Scatter a 1 at (spike time, neuron) for every sampled interval.
        # `times` is (time, size) and arange(size) broadcasts across rows.
        spikes = jnp.zeros((time + 1, size), dtype=jnp.float32)
        spikes = spikes.at[times, jnp.arange(size)].set(1)
        spikes = spikes[1:]
        return spikes.reshape(time, *shape)
Ejemplo n.º 6
0
 def testPoissonShape(self):
     key = random.PRNGKey(0)
     x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2))
     assert x.shape == (3, 2)
Ejemplo n.º 7
0
 def testPoissonBatched(self):
     """Per-element rates in one call: the first half uses 2, the second 20.

     Each half of the draw is checked with ``self._CheckChiSquared`` against
     the matching ``scipy.stats.poisson`` pmf.
     """
     n_each = 10000
     rates = jnp.concatenate([jnp.full(n_each, 2.0), jnp.full(n_each, 20.0)])
     draws = random.poisson(random.PRNGKey(0), rates, shape=(2 * n_each, ))
     low_rate, high_rate = draws[:n_each], draws[n_each:]
     self._CheckChiSquared(low_rate, scipy.stats.poisson(2.0).pmf)
     self._CheckChiSquared(high_rate, scipy.stats.poisson(20.0).pmf)
Ejemplo n.º 8
0
 def sample(self, key, sample_shape=()):
     """Draw Poisson counts with rate ``self.rate``.

     :param key: PRNG key; must satisfy ``is_prng_key``.
     :param sample_shape: leading shape prepended to ``self.batch_shape``.
     :return: array of shape ``sample_shape + self.batch_shape``.
     """
     assert is_prng_key(key)
     full_shape = sample_shape + self.batch_shape
     return random.poisson(key=key, lam=self.rate, shape=full_shape)
Ejemplo n.º 9
0
 def sample(self, rng_key, sample_shape=()):
     """Draw Poisson counts with rate ``self.lmbda``.

     :param rng_key: JAX PRNG key.
     :param sample_shape: leading shape of the draw. Any sequence of ints is
         accepted. Defaults to ``()``, i.e. one draw per batch element.
     :return: array of shape ``sample_shape + batch_shape + event_shape``.
     """
     # Normalise every part to a tuple so list-valued shapes concatenate too.
     shape = tuple(sample_shape) + tuple(self.batch_shape) + tuple(self.event_shape)
     return random.poisson(rng_key, self.lmbda, shape)
Ejemplo n.º 10
0
 def sample(self, key, sample_shape=()):
     """Draw Poisson counts with rate ``self.rate``.

     :param key: JAX PRNG key.
     :param sample_shape: leading shape prepended to ``self.batch_shape``.
     :return: array of shape ``sample_shape + self.batch_shape``.
     """
     out_shape = sample_shape + self.batch_shape
     return random.poisson(key=key, lam=self.rate, shape=out_shape)
Ejemplo n.º 11
0
def poisson(lam=1.0, size=None):
    """Draw Poisson samples using a key obtained from ``DEFAULT.split_key()``.

    :param lam: Poisson rate, forwarded to ``jr.poisson``.
    :param size: output size, converted to a shape by ``_size2shape``.
    :return: the samples wrapped in a ``JaxArray``.
    """
    # The key is taken before the shape is computed, matching the original
    # argument evaluation order.
    subkey = DEFAULT.split_key()
    draws = jr.poisson(subkey, lam=lam, shape=_size2shape(size))
    return JaxArray(draws)
Ejemplo n.º 12
0
 def testPoissonZeros(self):
     """A rate of exactly zero must always produce zero counts.

     The first 10 columns have rate 0 and the last 10 have rate 20. Only the
     zero-rate columns are checked.
     """
     half = 10
     rates = jnp.concatenate([jnp.zeros(half), jnp.full(half, 20.0)])
     draws = random.poisson(random.PRNGKey(0), rates, shape=(2, 20))
     zero_rate_part = draws[:, :half]
     self.assertArraysEqual(zero_rate_part, jnp.zeros_like(zero_rate_part))
Ejemplo n.º 13
0
# Simulate a Poisson-regression data set.
# NOTE(review): `key`, `rng`, `p` and `N` are defined above this chunk and are
# not visible here.

# Random covariance for the coefficient draw: G @ G.T, with G a p x p
# standard-normal matrix, so Omega is symmetric positive semi-definite.
Omega = random.normal(key, shape=(p, p))
Omega = Omega @ Omega.T

# True coefficients beta ~ N(0, Omega). Each draw below uses a fresh key
# split from `rng`.
rng, key = random.split(rng)
beta = random.multivariate_normal(key, jnp.zeros(shape=(p, )), Omega)

# N design rows x_i ~ N(0, 0.1 * I_p).
rng, key = random.split(rng)
X = random.multivariate_normal(key,
                               jnp.zeros(shape=(p, )),
                               0.1 * jnp.identity(p),
                               shape=(N, ))

# Poisson means via the log link: lam_i = exp(x_i . beta).
lam = jnp.exp(X @ beta)

# Responses y_i ~ Poisson(lam_i).
rng, key = random.split(rng)
y = random.poisson(key, lam, shape=(N, ))

# sns.histplot(y, discrete=True)
# plt.show()
########################################


########################################
## Functions
@jit
def xmu_diagxsigx(mu, sigma):
    """Return exp(X @ mu + 0.5 * diag(X @ sigma @ X.T)), one value per row of X.

    Reads the module-level design matrix ``X``. For each row x_i, this equals
    E[exp(x_i . b)] when b ~ N(mu, sigma) (the log-normal mean).
    """
    linear_term = jnp.dot(X, mu)
    # einsum computes only the diagonal of X @ sigma @ X.T, not the full matrix.
    half_quad = 0.5 * jnp.einsum('ij,jk,ik->i', X, sigma, X)
    return jnp.exp(linear_term + half_quad)


@jit
Ejemplo n.º 14
0
def test_poisson(shape=(1000, 5)):
    """Fitting a Poisson distribution recovers the per-column sample means.

    Draws counts with rate 5.0 from a fixed key so the test is reproducible.
    It then fits ``dists.Poisson`` to them and checks that the fitted rate
    matches ``data.mean(axis=0)``.

    :param shape: shape of the simulated data. The fit is compared column-wise.
    """
    # Fixed seed: seeding from time.time_ns() made any failure impossible to
    # reproduce.
    key = jr.PRNGKey(0)
    data = jr.poisson(key, 5.0, shape=shape)
    pois = dists.Poisson.fit(data)
    assert np.allclose(data.mean(axis=0), pois.rate, atol=1e-6)