def test_iterator_state_dict(self):
    """Iteration progress must survive a ``state_dict`` round trip.

    Two batches are drawn, the iterator state is deep-copied and loaded
    into a freshly built iterator, and the third batch drawn from that
    new iterator must complete the first epoch.
    """
    dataset = [1, 2, 3, 4, 5, 6]
    total = len(dataset)

    def new_iterator():
        return iterators.MultithreadIterator(dataset, 2, **self.options)

    def draw(iterator):
        # Every batch is expected to be a two-element list.
        drawn = iterator.next()
        self.assertEqual(len(drawn), 2)
        self.assertIsInstance(drawn, list)
        return drawn

    def check_progress(iterator, consumed, consumed_before):
        self.assertAlmostEqual(iterator.epoch_detail, consumed / total)
        self.assertAlmostEqual(
            iterator.previous_epoch_detail, consumed_before / total)

    it = new_iterator()
    self.assertEqual(it.epoch, 0)
    self.assertAlmostEqual(it.epoch_detail, 0 / total)
    self.assertIsNone(it.previous_epoch_detail)

    # The first two batches stay inside the first epoch.
    head = []
    for consumed in (2, 4):
        head.append(draw(it))
        self.assertFalse(it.is_new_epoch)
        check_progress(it, consumed, consumed - 2)

    # Move the state over to a brand-new iterator instance.
    snapshot = copy.deepcopy(it.state_dict())
    it = new_iterator()
    it.load_state_dict(snapshot)
    self.assertFalse(it.is_new_epoch)
    check_progress(it, 4, 2)

    # The restored iterator continues where the old one stopped.
    tail = draw(it)
    self.assertTrue(it.is_new_epoch)
    self.assertEqual(sorted(head[0] + head[1] + tail), dataset)
    check_progress(it, 6, 4)
def test_iterator_tuple_type(self):
    """Iterate three epochs over ``(index, array)`` tuple samples.

    Each sample must come back as a tuple whose second item is still a
    numpy array, and every epoch must yield each dataset entry.
    """
    dataset = [(i, numpy.zeros((10, )) + i) for i in range(6)]
    it = iterators.MultithreadIterator(dataset, 2, **self.options)

    for epoch in range(3):
        self.assertEqual(it.epoch, epoch)
        self.assertAlmostEqual(it.epoch_detail, epoch)
        if epoch == 0:
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, epoch - 2 / 6)

        seen = {}
        for step in range(3):
            batch = it.next()
            self.assertEqual(len(batch), 2)

            # Only the third batch of an epoch flips ``is_new_epoch``.
            expect = self.assertTrue if step == 2 else self.assertFalse
            expect(it.is_new_epoch)

            batches_before = 3 * epoch + step
            self.assertAlmostEqual(
                it.epoch_detail, (batches_before + 1) * 2 / 6)
            self.assertAlmostEqual(
                it.previous_epoch_detail, batches_before * 2 / 6)

            for sample in batch:
                self.assertIsInstance(sample, tuple)
                self.assertIsInstance(sample[1], numpy.ndarray)
                seen[sample[0]] = sample[1]

        # Every index was produced once, with its matching array.
        self.assertEqual(len(seen), len(dataset))
        for index, array in six.iteritems(seen):
            numpy.testing.assert_allclose(dataset[index][1], array)
def test_invalid_order_sampler(self):
    """Using ``InvalidOrderSampler`` must end in a ``ValueError``.

    The error may be raised while the iterator is built or when the
    first batch is requested; both steps are covered by the check.
    """
    samples = [1, 2, 3, 4, 5, 6]

    def build_and_draw():
        iterator = iterators.MultithreadIterator(
            samples, 6, order_sampler=InvalidOrderSampler())
        iterator.next()

    self.assertRaises(ValueError, build_and_draw)
def test_iterator_repeat(self):
    """Run three full epochs over six samples in batches of two.

    For each epoch the test checks the epoch counter, the fractional
    ``epoch_detail`` / ``previous_epoch_detail`` progress after every
    batch, that ``is_new_epoch`` is set only by the third batch, and
    that the three batches together contain the whole dataset.
    """
    # Six samples are required here: every progress fraction below is
    # expressed in sixths and an epoch is expected to end after exactly
    # three batches of two. A three-element dataset contradicts all of
    # those assertions.
    dataset = [1, 2, 3, 4, 5, 6]
    it = iterators.MultithreadIterator(dataset, 2, **self.options)
    for i in range(3):
        # State at the start of epoch ``i``.
        self.assertEqual(it.epoch, i)
        self.assertAlmostEqual(it.epoch_detail, i + 0 / 6)
        if i == 0:
            self.assertIsNone(it.previous_epoch_detail)
        else:
            self.assertAlmostEqual(it.previous_epoch_detail, i - 2 / 6)

        batch1 = it.next()
        self.assertEqual(len(batch1), 2)
        self.assertIsInstance(batch1, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 2 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 0 / 6)

        batch2 = it.next()
        self.assertEqual(len(batch2), 2)
        self.assertIsInstance(batch2, list)
        self.assertFalse(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 4 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 2 / 6)

        # The third batch completes the epoch.
        batch3 = it.next()
        self.assertEqual(len(batch3), 2)
        self.assertIsInstance(batch3, list)
        self.assertTrue(it.is_new_epoch)
        self.assertAlmostEqual(it.epoch_detail, i + 6 / 6)
        self.assertAlmostEqual(it.previous_epoch_detail, i + 4 / 6)

        # One epoch yields every sample exactly once.
        self.assertEqual(sorted(batch1 + batch2 + batch3), dataset)
def test_supported_reset_middle(self):
    """``reset()`` must not raise partway through a non-repeating pass."""
    samples = [1, 2, 3, 4, 5]
    iterator = iterators.MultithreadIterator(
        samples, 2, repeat=False, **self.options)

    # Consume a single batch, then rewind before the pass is complete.
    iterator.next()
    iterator.reset()
def test_supported_reset_repeat(self):
    """``reset()`` must not raise on a repeating iterator."""
    samples = [1, 2, 3, 4]
    iterator = iterators.MultithreadIterator(
        samples, 2, repeat=True, **self.options)

    # Two batches of two cover the four samples once.
    for _ in range(2):
        iterator.next()
    iterator.reset()
def test_supported_reset_finalized(self):
    """``reset()`` must not raise after ``finalize()`` has been called."""
    samples = [1, 2, 3, 4]
    iterator = iterators.MultithreadIterator(
        samples, 2, repeat=False, **self.options)

    # Draw both batches of the single pass before finalizing.
    for _ in range(2):
        iterator.next()
    iterator.finalize()
    iterator.reset()
def test_no_same_indices_order_sampler(self):
    """Batches drawn with ``NoSameIndicesOrderSampler`` hold no duplicates.

    Five batches of five are drawn from six samples; each batch must
    contain ``batchsize`` distinct values.
    """
    dataset = [1, 2, 3, 4, 5, 6]
    batchsize = 5
    sampler = NoSameIndicesOrderSampler(batchsize)
    it = iterators.MultithreadIterator(
        dataset, batchsize, order_sampler=sampler)

    for _ in range(5):
        distinct = numpy.unique(it.next())
        self.assertEqual(len(distinct), batchsize)
def test_iterator_not_repeat(self):
    """A non-repeating iterator yields the dataset once, then stops."""
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultithreadIterator(
        dataset, 2, repeat=False, **self.options)

    # Three batches are drawn to cover the five samples.
    collected = []
    for _ in range(3):
        collected = collected + it.next()
    self.assertEqual(sorted(collected), dataset)

    # Exhaustion must be reported on every further call, not just once.
    for _ in range(2):
        with self.assertRaises(StopIteration):
            it.next()
def test_iterator_not_repeat_not_even(self):
    """Single pass over five samples in batches of two.

    Progress is checked after every batch, the final batch must hold
    the one leftover sample, and a further call must raise
    ``StopIteration``.
    """
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultithreadIterator(
        dataset, 2, repeat=False, **self.options)

    self.assertAlmostEqual(it.epoch_detail, 0 / 5)
    self.assertIsNone(it.previous_epoch_detail)

    # Cumulative sample counts expected after each of the three batches.
    drawn = []
    consumed_before = 0
    for consumed in (2, 4, 5):
        drawn.append(it.next())
        self.assertAlmostEqual(it.epoch_detail, consumed / 5)
        self.assertAlmostEqual(
            it.previous_epoch_detail, consumed_before / 5)
        consumed_before = consumed

    self.assertRaises(StopIteration, it.next)

    first, second, last = drawn
    self.assertEqual(len(last), 1)
    self.assertEqual(sorted(first + second + last), dataset)
def test_iterator_shuffle_nondivisible(self):
    """The first ten drawn samples must differ from the next ten.

    Ten samples are read in batches of three, so batch boundaries do
    not line up with the dataset length.
    """
    dataset = list(range(10))
    it = iterators.MultithreadIterator(dataset, 3, **self.options)

    # Seven batches of three give 21 samples, enough for two slices of ten.
    stream = []
    for _ in range(7):
        stream = stream + it.next()

    first_ten = stream[0:10]
    next_ten = stream[10:20]
    self.assertNotEqual(first_ten, next_ten)
def test_iterator_shuffle_divisible(self):
    """Two consecutive whole-dataset batches must not be identical."""
    dataset = list(range(10))
    it = iterators.MultithreadIterator(dataset, 10, **self.options)

    # With a batch size equal to the dataset size, each batch is one epoch.
    first_epoch = it.next()
    second_epoch = it.next()
    self.assertNotEqual(first_epoch, second_epoch)
def test_iterator_repeat_not_even(self):
    """Five batches of two over five samples yield each sample twice."""
    dataset = [1, 2, 3, 4, 5]
    it = iterators.MultithreadIterator(dataset, 2, **self.options)

    # Ten samples in total, i.e. two passes over the dataset.
    collected = []
    for _ in range(5):
        collected = collected + it.next()

    expected = sorted(dataset * 2)
    self.assertEqual(sorted(collected), expected)