Exemplo n.º 1
0
    def test_shuffle_data_queue_size(self, queue_size):
        """Checks that `queue_size` bounds how far ahead shuffle_data can reach.

        The input stream is the consecutive integers 100..199, so the k-th
        yielded sample (0-indexed) is expected to be less than
        `100 + k + queue_size`. With `queue_size == 1` the output must be in
        input order; with a larger queue it is expected to differ from it.
        """
        samples = iter(range(100, 200))
        shuffled_stream = inputs.shuffle_data(samples, queue_size)
        first_ten = [next(shuffled_stream) for _ in range(10)]

        # Queue size limits how far ahead/upstream the current sample can reach.
        # Check the bound at every one of the first ten positions, not only a
        # few sampled ones; subTest reports which position failed.
        for position, sample in enumerate(first_ten):
            with self.subTest(position=position):
                self.assertLess(sample, 100 + position + queue_size)

        unshuffled_first_ten = list(range(100, 110))
        if queue_size == 1:  # Degenerate case: no shuffling can happen.
            self.assertEqual(first_ten, unshuffled_first_ten)
        elif queue_size > 1:
            # NOTE(review): this check is probabilistic. For small queue sizes
            # (e.g. 2) the first ten samples can come out in order by chance.
            # Confirm the parameterized sizes make that negligible, or seed
            # the RNG used by shuffle_data.
            self.assertNotEqual(first_ten, unshuffled_first_ten)
Exemplo n.º 2
0
 def test_shuffle_data_yields_all_samples(self, queue_size, n_samples):
     """Checks that shuffle_data yields every input sample exactly once.

     A length check alone would pass if one sample were duplicated while
     another was dropped. Comparing the sorted output with the input range
     catches that as well.
     """
     samples = iter(range(n_samples))
     shuffled_stream = inputs.shuffle_data(samples, queue_size)
     shuffled = list(shuffled_stream)
     self.assertLen(shuffled, n_samples)
     # Order is expected to change, so compare contents independent of order.
     self.assertEqual(sorted(shuffled), list(range(n_samples)))
Exemplo n.º 3
0
 def test_shuffle_data_raises_error_queue_size(self, queue_size):
     """Checks that shuffle_data raises ValueError for the given `queue_size`.

     The shuffled stream is built and fully consumed inside the assertion
     context. The error is therefore caught whether shuffle_data raises when
     called or only once iteration starts.
     """
     source = iter(range(10))
     with self.assertRaises(ValueError):
         list(inputs.shuffle_data(source, queue_size))