Esempio n. 1
0
    def test_sample_from_array_correct_shape(self):
        """Sampling along an axis replaces only that axis's length with n_vals."""
        num_rows, num_cols = 1000, 200
        data = jax.random.uniform(jax.random.PRNGKey(124),
                                  shape=(num_rows, num_cols))
        key = jax.random.PRNGKey(0)
        count = 38
        # Expected output shape for each sampled axis; the other axis is kept.
        expected_by_axis = {
            0: (count, num_cols),
            1: (num_rows, count),
        }
        for axis, expected_shape in expected_by_axis.items():
            sampled = util.sample_from_array(key, data, count, axis)
            self.assertEqual(expected_shape, jnp.shape(sampled))
Esempio n. 2
0
 def test_sample_from_array_almost_full_shuffle(self):
     """Drawing n-1 of n values yields n-1 distinct elements of x."""
     x = jnp.arange(0, 100) + 100
     rng_key = jax.random.PRNGKey(0)
     n_vals = 99
     shuffled = util.sample_from_array(rng_key, x, n_vals, 0)
     # Only the sampled axis changes length (same contract as the 2-D case).
     self.assertEqual((n_vals,), jnp.shape(shuffled))
     unq_vals = np.unique(shuffled)
     # Distinctness: no element of x may be drawn twice.
     self.assertEqual(n_vals, np.size(unq_vals))
     # Every draw must be an actual element of x, not merely >= its minimum.
     # jnp.all replaces the deprecated/removed jnp.alltrue.
     self.assertTrue(bool(jnp.all(jnp.isin(shuffled, x))))
Esempio n. 3
0
    def get_batch_without_replacement(i, rng_key):
        """Fetches batch number ``i`` of the current epoch.

        Row indices are drawn with ``sample_from_array`` from
        ``jnp.arange(num_records)`` using a key derived from ``rng_key`` and
        ``i``. The same ``(i, rng_key)`` pair therefore always reproduces the
        same batch.

        :param i: The number of the batch in the epoch.
        :param rng_key: JAX PRNG key (presumably the per-epoch key); ``i`` is
            folded into it to obtain this batch's own key.
        :return: the batch, as a tuple with one array per entry of
            ``dataset``, each containing the selected rows (gathered along
            axis 0).
        """
        # Derive a distinct, deterministic key for this batch from the epoch key.
        batch_rng_key = jax.random.fold_in(rng_key, i)
        # NOTE(review): every batch samples from the full index range
        # independently, so "without replacement" holds within a batch only;
        # rows may repeat across batches of the same epoch. This also assumes
        # sample_from_array draws distinct entries; confirm against its
        # implementation.
        ret_idx = sample_from_array(batch_rng_key, jnp.arange(num_records),
                                    batch_size, 0)
        # Gather the same rows from every array so the batch stays aligned.
        return tuple(jnp.take(a, ret_idx, axis=0) for a in dataset)
Esempio n. 4
0
 def test_sample_from_array_single_sample(self):
     """Drawing a single value returns exactly one element of x."""
     x = jnp.arange(0, 100) + 100
     rng_key = jax.random.PRNGKey(0)
     n_vals = 1
     shuffled = util.sample_from_array(rng_key, x, n_vals, 0)
     # Only the sampled axis changes length (same contract as the 2-D case).
     self.assertEqual((n_vals,), jnp.shape(shuffled))
     # The draw must be an actual element of x, not merely >= its minimum.
     # jnp.all replaces the deprecated/removed jnp.alltrue.
     self.assertTrue(bool(jnp.all(jnp.isin(shuffled, x))))