示例#1
0
    def test_counting_iterator_buffered_iterator_take(self):
        """Check that ``CountingIterator.take`` interacts correctly with a ``BufferedIterator``.

        Four scenarios run, each on a freshly built pair of iterators:
        the truncated length seen through the counting wrapper, the truncated
        length seen directly on the buffered iterator, item-by-item
        consumption with ``skip``, and truncation when the counter starts at
        a non-zero offset.
        """
        data = list(range(10))

        def make_pair(source, **counting_kwargs):
            # Build a BufferedIterator (buffer size 2) and wrap it in a
            # CountingIterator; return both so either side can be inspected.
            inner = iterators.BufferedIterator(2, source)
            return inner, iterators.CountingIterator(inner, **counting_kwargs)

        # The length reported after take() matches the number of items yielded.
        _, counter = make_pair(data)
        counter.take(5)
        self.assertEqual(len(counter), len(list(iter(counter))))
        self.assertEqual(len(counter), 5)

        # The test expects take() to be visible on the underlying buffered iterator.
        inner, counter = make_pair(data)
        counter.take(5)
        self.assertEqual(len(inner), 5)
        self.assertEqual(len(list(iter(inner))), 5)

        # Read two items, skip two, read the fifth, then both sides are exhausted.
        inner, counter = make_pair(data)
        counter.take(5)
        for expected in data[:2]:
            self.assertEqual(next(counter), expected)
        counter.skip(2)
        self.assertEqual(next(counter), data[4])
        self.assertFalse(counter.has_next())
        with self.assertRaises(StopIteration):
            next(inner)

        # Counter starting at position 4: after take(5) the test expects
        # exactly one item left in the buffered iterator.
        tail = list(range(4, 10))
        inner, counter = make_pair(tail, start=4)
        counter.take(5)
        self.assertEqual(len(counter), 5)
        self.assertEqual(len(inner), 1)
        self.assertEqual(next(counter), tail[0])
        self.assertFalse(counter.has_next())
        with self.assertRaises(StopIteration):
            next(inner)
示例#2
0
 def test_counting_iterator_length_mismatch(self):
     """CountingIterator honours ``total`` even when it disagrees with the data.

     A ``total`` smaller than the underlying list truncates iteration; a
     ``total`` larger than the list makes iteration raise ``IndexError``
     once the list runs out.
     """
     data = list(range(10))

     # Source longer than the declared total: surplus items are ignored.
     truncated = iterators.CountingIterator(data, total=8)
     self.assertEqual(list(truncated), data[:8])

     # Source shorter than the declared total: early exhaustion is an error.
     overlong = iterators.CountingIterator(data, total=12)
     with self.assertRaises(IndexError):
         list(overlong)
示例#3
0
    def test_counting_iterator_take(self):
        """``take`` truncates a CountingIterator built over a plain list."""
        data = list(range(10))

        # The truncated length agrees with the number of items actually produced.
        truncated = iterators.CountingIterator(data)
        truncated.take(5)
        self.assertEqual(len(truncated), len(list(iter(truncated))))
        self.assertEqual(len(truncated), 5)

        # Walk the truncated iterator: read two, skip two, read the fifth
        # item, after which nothing remains.
        stepper = iterators.CountingIterator(data)
        stepper.take(5)
        for expected in (data[0], data[1]):
            self.assertEqual(next(stepper), expected)
        stepper.skip(2)
        self.assertEqual(next(stepper), data[4])
        self.assertFalse(stepper.has_next())
示例#4
0
 def test_counting_iterator(self):
     """Basic ``next`` / ``skip`` / ``has_next`` behaviour over ``range(10)``."""
     values = list(range(10))
     counter = iterators.CountingIterator(values)
     self.assertTrue(counter.has_next())
     # Two plain reads from the front.
     for expected in (0, 1):
         self.assertEqual(next(counter), expected)
     # Then twice: skip three items and read the next one (5, then 9).
     for expected in (5, 9):
         counter.skip(3)
         self.assertEqual(next(counter), expected)
     self.assertFalse(counter.has_next())
示例#5
0
def tpu_data_loader(args, itr):
    """Wrap ``itr`` in a torch_xla ``ParallelLoader`` and re-wrap it in a CountingIterator.

    Args:
        args: forwarded to ``utils.get_tpu_device`` to select the XLA device.
        itr: the iterator to load from. Its ``n`` attribute (0 if absent) is
            used as the starting position, and ``len(itr)`` as the total.

    Returns:
        An ``iterators.CountingIterator`` over the per-device loader for the
        selected device, starting at ``itr.n`` with total ``len(itr)``.
    """
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    xm.rendezvous('tpu_data_loader')  # wait for all workers
    xm.mark_step()
    device = utils.get_tpu_device(args)
    # Snapshot position and length BEFORE constructing the ParallelLoader.
    # NOTE(review): torch_xla's ParallelLoader starts its loader thread in the
    # constructor, and that thread begins pulling from ``itr`` immediately.
    # If ``itr`` advances ``n`` as it is consumed (as a CountingIterator
    # does), reading ``n`` afterwards can return an inflated offset, and the
    # returned wrapper would then stop early and drop trailing batches.
    # Verify against the installed torch_xla version.
    start = getattr(itr, 'n', 0)
    total = len(itr)
    return iterators.CountingIterator(
        pl.ParallelLoader(itr, [device]).per_device_loader(device),
        start=start,
        total=total,
    )
示例#6
0
def tpu_data_loader(itr):
    """Wrap ``itr`` in a torch_xla ``ParallelLoader`` and re-wrap it in a CountingIterator.

    Args:
        itr: the iterator to load from. Its ``n`` attribute (0 if absent) is
            used as the starting position, and ``len(itr)`` as the total.

    Returns:
        A ``fairseq.data.iterators.CountingIterator`` over the per-device
        loader for ``xm.xla_device()``, starting at ``itr.n`` with total
        ``len(itr)``.
    """
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    from fairseq.data import iterators

    xm.rendezvous("tpu_data_loader")  # wait for all workers
    xm.mark_step()
    device = xm.xla_device()
    # Snapshot position and length BEFORE constructing the ParallelLoader.
    # NOTE(review): torch_xla's ParallelLoader starts its loader thread in the
    # constructor, and that thread begins pulling from ``itr`` immediately.
    # If ``itr`` advances ``n`` as it is consumed (as a CountingIterator
    # does), reading ``n`` afterwards can return an inflated offset, and the
    # returned wrapper would then stop early and drop trailing batches.
    # Verify against the installed torch_xla version.
    start = getattr(itr, "n", 0)
    total = len(itr)
    return iterators.CountingIterator(
        pl.ParallelLoader(itr, [device]).per_device_loader(device),
        start=start,
        total=total,
    )
示例#7
0
 def test_counting_iterator(self, ref=None, itr=None):
     """Check that a CountingIterator tracks its position ``n`` through next/skip.

     Call with no arguments to test a fresh iterator over ``range(10)``, or
     pass a 10-element ``ref`` list together with an ``itr`` that is expected
     to yield the elements of ``ref`` in order.
     """
     if ref is not None:
         assert len(ref) == 10
         assert itr is not None
     else:
         assert itr is None
         ref = list(range(10))
         itr = iterators.CountingIterator(ref)

     self.assertTrue(itr.has_next())
     self.assertEqual(itr.n, 0)
     # Two plain reads: the position advances by one each time.
     for pos in (0, 1):
         self.assertEqual(next(itr), ref[pos])
         self.assertEqual(itr.n, pos + 1)
     # skip(3) moves the position from 2 to 5; then read ref[5].
     itr.skip(3)
     self.assertEqual(itr.n, 5)
     self.assertEqual(next(itr), ref[5])
     # skip(3) moves the position from 6 to 9; then read the final element.
     itr.skip(3)
     self.assertEqual(itr.n, 9)
     self.assertEqual(next(itr), ref[9])
     self.assertFalse(itr.has_next())
示例#8
0
    def test_counting_iterator_index(self, ref=None, itr=None):
        """Check CountingIterator's position counter ``n`` and tail consumption.

        Despite the method name, no item indexing (``itr[i]``) is exercised.
        The test verifies that ``n`` tracks the number of items consumed via
        ``next`` and ``skip``, and that ``list(itr)`` returns exactly the
        remaining items and leaves the iterator exhausted.

        Args:
            ref: optional 10-element list of reference items; when omitted,
                ``range(10)`` is used and a fresh ``CountingIterator`` is
                built over it.
            itr: iterator expected to yield ``ref`` in order; must be given
                if and only if ``ref`` is given.
        """
        if ref is None:
            assert itr is None
            ref = list(range(10))
            itr = iterators.CountingIterator(ref)
        else:
            assert len(ref) == 10
            assert itr is not None

        self.assertTrue(itr.has_next())
        self.assertEqual(itr.n, 0)
        self.assertEqual(next(itr), ref[0])
        self.assertEqual(itr.n, 1)
        self.assertEqual(next(itr), ref[1])
        self.assertEqual(itr.n, 2)
        # skip(3) advances the position from 2 to 5.
        itr.skip(3)
        self.assertEqual(itr.n, 5)
        self.assertEqual(next(itr), ref[5])
        # skip(2) advances from 6 to 8, leaving exactly the last two items.
        itr.skip(2)
        self.assertEqual(itr.n, 8)
        self.assertEqual(list(itr), [ref[8], ref[9]])
        self.assertFalse(itr.has_next())