Example 1
import sys
import threading
import multiprocessing

import numpy as np

# random_worker_loop and fetcher_loop are assumed to be defined in the
# enclosing module; they are not part of this snippet.
class _RandomTransformMultiWorkerIter(object):
    """Interal multi-worker iterator for DataLoader with random transform functions."""
    def __init__(self, transform_fns, interval, num_workers, dataset, batchify_fn, batch_sampler,
                 pin_memory=False):
        assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
        assert isinstance(transform_fns, (list, tuple)) and len(transform_fns) > 1
        from mxnet.gluon.data.dataloader import Queue, SimpleQueue
        self._transform_fns = transform_fns
        self._fn_idx = np.random.randint(len(self._transform_fns))
        self._interval = max(int(interval), 1)
        self._num_workers = num_workers
        self._datasets = [dataset.transform(trans_fn) for trans_fn in self._transform_fns]
        self._batchify_fn = batchify_fn
        self._batch_sampler = batch_sampler
        self._key_queue = Queue()
        self._data_queue = Queue() if sys.version_info[0] <= 2 else SimpleQueue()
        self._data_buffer = {}
        self._data_buffer_lock = threading.Lock()
        self._rcvd_idx = 0
        self._sent_idx = 0
        self._iter = iter(self._batch_sampler)
        self._shutdown = False

        workers = []
        for _ in range(self._num_workers):
            worker = multiprocessing.Process(
                target=random_worker_loop,
                args=(self._datasets, self._key_queue, self._data_queue, self._batchify_fn))
            worker.daemon = True
            worker.start()
            workers.append(worker)
        self._workers = workers

        self._fetcher = threading.Thread(
            target=fetcher_loop,
            args=(self._data_queue, self._data_buffer, pin_memory))
        self._fetcher.daemon = True
        self._fetcher.start()

        # pre-fetch
        for _ in range(2 * self._num_workers):
            self._push_next()

    def __len__(self):
        return len(self._batch_sampler)

    def __del__(self):
        self.shutdown()

    def _push_next(self):
        """Assign next batch workload to workers."""
        r = next(self._iter, None)
        if r is None:
            return
        if (self._sent_idx + 1) % self._interval == 0:
            self._fn_idx = np.random.randint(len(self._transform_fns))
        self._key_queue.put((self._sent_idx, r, self._fn_idx))
        self._sent_idx += 1

    def __next__(self):
        assert not self._shutdown, "call __next__ after shutdown is forbidden"
        if self._rcvd_idx == self._sent_idx:
            assert not self._data_buffer, "Data buffer should be empty at this moment"
            self.shutdown()
            raise StopIteration

        while True:
            if self._rcvd_idx in self._data_buffer:
                with self._data_buffer_lock:
                    batch = self._data_buffer.pop(self._rcvd_idx)
                self._rcvd_idx += 1
                self._push_next()
                return batch

    def next(self):
        return self.__next__()

    def __iter__(self):
        return self

    def shutdown(self):
        """Shutdown internal workers by pushing terminate signals."""
        if not self._shutdown:
            # send shutdown signal to the fetcher and join data queue first
            # Remark:   the fetcher thread needs to be joined before the workers,
            #           otherwise the fetcher may fail at getting data
            self._data_queue.put((None, None))
            self._fetcher.join()
            # send shutdown signal to all worker processes
            for _ in range(self._num_workers):
                self._key_queue.put((None, None, None))
            # force shut down any alive worker processes
            for w in self._workers:
                if w.is_alive():
                    w.terminate()
            self._shutdown = True
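
The random-transform behaviour of this iterator comes down to how _push_next rotates _fn_idx: a new transform index is drawn every interval batches, and every batch pushed inside the same window is dispatched with the same index. Below is a minimal standalone sketch of just that rotation; the transform labels and the 12-batch loop are made up for illustration, not part of the original code.

import numpy as np

transform_fns = ["resize_320", "resize_416", "resize_512"]   # hypothetical labels
interval = 4                                                  # re-draw every 4 batches
fn_idx = np.random.randint(len(transform_fns))

for sent_idx in range(12):                                    # pretend 12 batches are pushed
    if (sent_idx + 1) % interval == 0:
        fn_idx = np.random.randint(len(transform_fns))
    # all batches inside the same window carry the same transform index
    print(sent_idx, transform_fns[fn_idx])
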
Example 2
class _RandomTransformMultiWorkerIter(object):
    """Interal multi-worker iterator for DataLoader with random transform functions."""
    def __init__(self, transform_fns, interval, num_workers, dataset, batchify_fn, batch_sampler,
                 pin_memory=False):
        assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
        assert isinstance(transform_fns, (list, tuple)) and len(transform_fns) > 1
        from mxnet.gluon.data.dataloader import Queue, SimpleQueue
        self._transform_fns = transform_fns
        self._fn_idx = np.random.randint(len(self._transform_fns))
        self._interval = max(int(interval), 1)
        self._num_workers = num_workers
        self._datasets = [dataset.transform(trans_fn) for trans_fn in self._transform_fns]
        self._batchify_fn = batchify_fn
        self._batch_sampler = batch_sampler
        self._key_queue = Queue()
        self._data_queue = Queue() if sys.version_info[0] <= 2 else SimpleQueue()
        self._data_buffer = {}
        self._rcvd_idx = 0
        self._sent_idx = 0
        self._iter = iter(self._batch_sampler)
        self._shutdown = False

        workers = []
        for _ in range(self._num_workers):
            worker = multiprocessing.Process(
                target=random_worker_loop,
                args=(self._datasets, self._key_queue, self._data_queue, self._batchify_fn))
            worker.daemon = True
            worker.start()
            workers.append(worker)

        self._fetcher = threading.Thread(
            target=fetcher_loop,
            args=(self._data_queue, self._data_buffer, pin_memory))
        self._fetcher.daemon = True
        self._fetcher.start()

        # pre-fetch
        for _ in range(2 * self._num_workers):
            self._push_next()

    def __len__(self):
        return len(self._batch_sampler)

    def __del__(self):
        self.shutdown()

    def _push_next(self):
        """Assign next batch workload to workers."""
        r = next(self._iter, None)
        if r is None:
            return
        if (self._sent_idx + 1) % self._interval == 0:
            self._fn_idx = np.random.randint(len(self._transform_fns))
        self._key_queue.put((self._sent_idx, r, self._fn_idx))
        self._sent_idx += 1

    def __next__(self):
        assert not self._shutdown, "call __next__ after shutdown is forbidden"
        if self._rcvd_idx == self._sent_idx:
            assert not self._data_buffer, "Data buffer should be empty at this moment"
            self.shutdown()
            raise StopIteration

        while True:
            if self._rcvd_idx in self._data_buffer:
                batch = self._data_buffer.pop(self._rcvd_idx)
                self._rcvd_idx += 1
                self._push_next()
                return batch

    def next(self):
        return self.__next__()

    def __iter__(self):
        return self

    def shutdown(self):
        """Shutdown internal workers by pushing terminate signals."""
        if not self._shutdown:
            for _ in range(self._num_workers):
                self._key_queue.put((None, None, None))
            self._data_queue.put((None, None))
            self._shutdown = True
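
This variant's shutdown only pushes sentinel tuples: one (None, None, None) per worker onto the key queue and a (None, None) onto the data queue, relying on the worker and fetcher loops to exit once they see a None index. Below is a toy sketch of that sentinel protocol, with threads standing in for the worker processes; the real random_worker_loop and fetcher_loop are assumed to exit on a None key, which this snippet only mimics.

import threading
import queue

def toy_worker(key_queue, data_queue):
    # mimics a worker loop: process keys until the sentinel arrives
    while True:
        idx, samples = key_queue.get()
        if idx is None:
            break
        data_queue.put((idx, [s * 10 for s in samples]))  # stand-in for batchify

key_q, data_q = queue.Queue(), queue.Queue()
workers = [threading.Thread(target=toy_worker, args=(key_q, data_q)) for _ in range(2)]
for w in workers:
    w.start()

key_q.put((0, [1, 2]))
key_q.put((1, [3, 4]))
for _ in workers:                     # one sentinel per worker, as in shutdown()
    key_q.put((None, None))
for w in workers:
    w.join()
print(sorted(data_q.get() for _ in range(2)))
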
Example 3
import sys
import threading
import multiprocessing

from mxnet.gluon.data.dataloader import Queue, SimpleQueue

# Queue/SimpleQueue are imported as in the earlier examples; worker_loop and
# fetcher_loop are assumed to be defined in the enclosing module.
class _ShardedMultiWorkerIter(object):
    """Interal multi-worker iterator for ShardedDataLoader."""
    def __init__(self,
                 num_workers,
                 dataset,
                 batchify_fn,
                 batch_sampler,
                 pin_memory=False):
        assert num_workers > 0, '_MultiWorkerIter is not for {} workers'.format(
            num_workers)
        self._num_workers = num_workers
        self._dataset = dataset
        self._batchify_fn = batchify_fn
        self._batch_sampler = batch_sampler
        self._key_queue = Queue()
        self._data_queue = Queue() if sys.version_info[0] <= 2 else SimpleQueue()
        self._data_buffer = {}
        self._rcvd_idx = 0
        self._sent_idx = 0
        self._iter = iter(self._batch_sampler)
        self._shutdown = False

        workers = []
        for _ in range(self._num_workers):
            worker = multiprocessing.Process(
                target=worker_loop,
                args=(self._dataset, self._key_queue, self._data_queue,
                      self._batchify_fn))
            worker.daemon = True
            worker.start()
            workers.append(worker)

        self._fetcher = threading.Thread(target=fetcher_loop,
                                         args=(self._data_queue,
                                               self._data_buffer, pin_memory))
        self._fetcher.daemon = True
        self._fetcher.start()

        # pre-fetch
        for _ in range(2 * self._num_workers):
            self._push_next()

    def __len__(self):
        return len(self._batch_sampler)

    def __del__(self):
        self.shutdown()

    def _push_next(self):
        """Assign next batch workload to workers."""
        r = next(self._iter, None)
        if r is None:
            return
        self._key_queue.put((self._sent_idx, r))
        self._sent_idx += 1

    def __next__(self):
        assert not self._shutdown, 'call __next__ after shutdown is forbidden'
        if self._rcvd_idx == self._sent_idx:
            assert not self._data_buffer, 'Data buffer should be empty at this moment'
            self.shutdown()
            raise StopIteration

        while True:
            if self._rcvd_idx in self._data_buffer:
                batch = self._data_buffer.pop(self._rcvd_idx)
                self._rcvd_idx += 1
                self._push_next()
                return batch

    def next(self):
        return self.__next__()

    def __iter__(self):
        return self

    def shutdown(self):
        """Shutdown internal workers by pushing terminate signals."""
        if not self._shutdown:
            for _ in range(self._num_workers):
                self._key_queue.put((None, None))
            self._data_queue.put((None, None))
            self._shutdown = True
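
All three iterators rely on the same re-ordering trick: batches can come back from the workers out of order, so the fetcher thread is expected to park them in _data_buffer keyed by batch index, and __next__ only pops the entry for _rcvd_idx. Below is a minimal sketch of that buffering behaviour, with the out-of-order arrivals hard-coded instead of coming from a data queue.

arrivals = [(2, "batch-2"), (0, "batch-0"), (1, "batch-1")]   # (idx, batch) pairs, out of order

data_buffer = {}
rcvd_idx = 0
for idx, batch in arrivals:
    data_buffer[idx] = batch            # what fetcher_loop is assumed to do
    while rcvd_idx in data_buffer:      # what __next__ does before returning a batch
        print("yield", data_buffer.pop(rcvd_idx))
        rcvd_idx += 1
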