def test_counting_iterator_buffered_iterator_take(self):
    """CountingIterator.take() must limit both itself and a wrapped BufferedIterator."""
    source = list(range(10))

    def wrap(items, **counter_kwargs):
        # Build a fresh BufferedIterator (first argument 2) wrapped in a
        # CountingIterator, then cap the counter at 5 items.
        inner = iterators.BufferedIterator(2, items)
        outer = iterators.CountingIterator(inner, **counter_kwargs)
        outer.take(5)
        return inner, outer

    # The reported length agrees with what iteration actually yields.
    _, counter = wrap(source)
    reported = len(counter)
    self.assertEqual(reported, len(list(iter(counter))))
    self.assertEqual(len(counter), 5)

    # The cap is propagated down into the BufferedIterator itself.
    buffered, _ = wrap(source)
    self.assertEqual(len(buffered), 5)
    self.assertEqual(len(list(iter(buffered))), 5)

    # next()/skip() respect the cap, and the inner iterator is exhausted too.
    buffered, counter = wrap(source)
    self.assertEqual(next(counter), source[0])
    self.assertEqual(next(counter), source[1])
    counter.skip(2)
    self.assertEqual(next(counter), source[4])
    self.assertFalse(counter.has_next())
    self.assertRaises(StopIteration, next, buffered)

    # With start=4 the counter still reports 5 items, but only one further
    # item is expected to remain in the buffered iterator.
    shifted = list(range(4, 10))
    buffered, counter = wrap(shifted, start=4)
    self.assertEqual(len(counter), 5)
    self.assertEqual(len(buffered), 1)
    self.assertEqual(next(counter), shifted[0])
    self.assertFalse(counter.has_next())
    self.assertRaises(StopIteration, next, buffered)
def test_counting_iterator_length_mismatch(self):
    """Check CountingIterator when ``total`` disagrees with the real length."""
    data = list(range(10))

    # total smaller than the data: iteration stops after ``total`` items and
    # the remaining elements of the underlying iterable are ignored.
    truncated = iterators.CountingIterator(data, total=8)
    self.assertEqual(list(truncated), data[:8])

    # total larger than the data: running out early surfaces as IndexError
    # once the underlying iterable is exhausted.
    overlong = iterators.CountingIterator(data, total=12)
    self.assertRaises(IndexError, list, overlong)
def test_counting_iterator_take(self):
    """take(5) caps a CountingIterator over ten items at five."""
    values = list(range(10))

    # Length after take() matches the number of items actually produced.
    capped = iterators.CountingIterator(values)
    capped.take(5)
    claimed = len(capped)
    self.assertEqual(claimed, len(list(iter(capped))))
    self.assertEqual(len(capped), 5)

    # Reading two, skipping two, reading one more consumes exactly five.
    capped = iterators.CountingIterator(values)
    capped.take(5)
    for idx in (0, 1):
        self.assertEqual(next(capped), values[idx])
    capped.skip(2)
    self.assertEqual(next(capped), values[4])
    self.assertFalse(capped.has_next())
def test_counting_iterator(self):
    """Walk a CountingIterator over 0..9, interleaving next() and skip()."""
    counter = iterators.CountingIterator(list(range(10)))
    self.assertTrue(counter.has_next())
    # Each step: (items to skip first, value the following next() must return).
    for to_skip, expected in ((0, 0), (0, 1), (3, 5), (3, 9)):
        if to_skip:
            counter.skip(to_skip)
        self.assertEqual(next(counter), expected)
    self.assertFalse(counter.has_next())
def tpu_data_loader(args, itr):
    """Feed *itr* through a torch_xla per-device loader, re-wrapped for counting.

    Args:
        args: passed to ``utils.get_tpu_device`` to select the XLA device.
        itr: the source iterator. Its optional ``n`` attribute (defaulting to
            0) and ``len(itr)`` become the ``start`` and ``total`` of the
            returned CountingIterator.

    Returns:
        ``iterators.CountingIterator`` wrapping
        ``pl.ParallelLoader(itr, [device]).per_device_loader(device)``.
    """
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    xm.rendezvous('tpu_data_loader')  # wait for all workers
    xm.mark_step()
    device = utils.get_tpu_device(args)

    # Snapshot the position and length BEFORE creating the ParallelLoader.
    # torch_xla's ParallelLoader starts a background worker in its constructor
    # that begins pulling from ``itr`` (confirm for your torch_xla version); if
    # ``itr.n`` were read afterwards it could already have advanced, giving a
    # wrong start offset.
    start = getattr(itr, 'n', 0)
    total = len(itr)

    per_device = pl.ParallelLoader(itr, [device]).per_device_loader(device)
    return iterators.CountingIterator(per_device, start=start, total=total)
def tpu_data_loader(itr):
    """Feed *itr* through a torch_xla per-device loader, re-wrapped for counting.

    Args:
        itr: the source iterator. Its optional ``n`` attribute (defaulting to
            0) and ``len(itr)`` become the ``start`` and ``total`` of the
            returned CountingIterator.

    Returns:
        ``fairseq.data.iterators.CountingIterator`` wrapping
        ``pl.ParallelLoader(itr, [device]).per_device_loader(device)`` where
        ``device`` is ``xm.xla_device()``.
    """
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    from fairseq.data import iterators

    xm.rendezvous("tpu_data_loader")  # wait for all workers
    xm.mark_step()
    device = xm.xla_device()

    # Snapshot the position and length BEFORE creating the ParallelLoader.
    # torch_xla's ParallelLoader starts a background worker in its constructor
    # that begins pulling from ``itr`` (confirm for your torch_xla version); if
    # ``itr.n`` were read afterwards it could already have advanced, giving a
    # wrong start offset.
    start = getattr(itr, "n", 0)
    total = len(itr)

    per_device = pl.ParallelLoader(itr, [device]).per_device_loader(device)
    return iterators.CountingIterator(per_device, start=start, total=total)
def test_counting_iterator(self, ref=None, itr=None):
    """Check that ``itr.n`` tracks the position across next() and skip().

    Called with no arguments, this builds a CountingIterator over
    ``list(range(10))``. A caller may instead supply a 10-element ``ref``
    together with an ``itr`` that should yield ``ref`` in order.
    """
    if ref is not None:
        assert len(ref) == 10
        assert itr is not None
    else:
        assert itr is None
        ref = list(range(10))
        itr = iterators.CountingIterator(ref)

    def read_at(position):
        # The counter must sit at ``position`` and the next item be ref[position].
        self.assertEqual(itr.n, position)
        self.assertEqual(next(itr), ref[position])

    self.assertTrue(itr.has_next())
    read_at(0)
    read_at(1)
    self.assertEqual(itr.n, 2)
    itr.skip(3)
    read_at(5)
    itr.skip(3)
    read_at(9)
    self.assertFalse(itr.has_next())
def test_counting_iterator_index(self, ref=None, itr=None):
    """Check ``itr.n`` through next(), skip(), and a final list() drain.

    Called with no arguments, this builds a CountingIterator over
    ``list(range(10))``. A caller may instead supply a 10-element ``ref``
    together with an ``itr`` that should yield ``ref`` in order.
    """
    if ref is not None:
        assert len(ref) == 10
        assert itr is not None
    else:
        assert itr is None
        ref = list(range(10))
        itr = iterators.CountingIterator(ref)

    def read_at(position):
        # The counter must sit at ``position`` and the next item be ref[position].
        self.assertEqual(itr.n, position)
        self.assertEqual(next(itr), ref[position])

    self.assertTrue(itr.has_next())
    read_at(0)
    read_at(1)
    self.assertEqual(itr.n, 2)
    itr.skip(3)
    read_at(5)
    itr.skip(2)
    self.assertEqual(itr.n, 8)
    # Draining what is left must yield exactly the last two reference items.
    self.assertEqual(list(itr), [ref[8], ref[9]])
    self.assertFalse(itr.has_next())