Example #1
    def __init__(self,
                 dataset,
                 batch_size,
                 repeat=True,
                 shuffle=None,
                 order_sampler=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self._repeat = repeat
        self._shuffle = shuffle

        if self._shuffle is not None:
            if order_sampler is not None:
                raise ValueError('`shuffle` is not `None` and a custom '
                                 '`order_sampler` is set. Please set '
                                 '`shuffle` to `None` to use the custom '
                                 'order sampler.')
            else:
                if self._shuffle:
                    order_sampler = ShuffleOrderSampler()
        else:
            if order_sampler is None:
                order_sampler = ShuffleOrderSampler()
        self.order_sampler = order_sampler

        self.reset()
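
All five constructors share the same resolution of `shuffle` versus `order_sampler`: passing both is an error, `shuffle=True` installs a shuffling sampler, `shuffle=False` keeps sequential order, and `shuffle=None` falls back to the custom sampler or to shuffling by default. The helper below is a hypothetical, standalone restatement of that decision table; `resolve_order_sampler` and `_StandInShuffleOrderSampler` are illustrative names, not part of the code above.

import random


class _StandInShuffleOrderSampler(object):
    # Stand-in for ShuffleOrderSampler: returns a random permutation of the
    # index range each time a fresh epoch order is requested.
    def __call__(self, current_order, current_position):
        return random.sample(range(len(current_order)), len(current_order))


def resolve_order_sampler(shuffle, order_sampler,
                          default=_StandInShuffleOrderSampler):
    # Mirror of the validation block above:
    #   * shuffle=True/False plus a custom order_sampler is an error
    #   * shuffle=True  -> use a shuffling sampler
    #   * shuffle=False -> no sampler (sequential order)
    #   * shuffle=None  -> use the custom sampler, or shuffle by default
    if shuffle is not None:
        if order_sampler is not None:
            raise ValueError('`shuffle` is not `None` and a custom '
                             '`order_sampler` is set.')
        if shuffle:
            order_sampler = default()
    elif order_sampler is None:
        order_sampler = default()
    return order_sampler

For example, `resolve_order_sampler(None, None)` returns the default shuffling sampler, while `resolve_order_sampler(True, _StandInShuffleOrderSampler())` raises `ValueError`.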
Example #2
    def __init__(self, dataset, batch_size, repeat=True, shuffle=None,
                 n_threads=1, order_sampler=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self._repeat = repeat
        self._shuffle = shuffle
        self._prefetch_order = None  # used at the end of each epoch
        self.current_position = 0
        self.epoch = 0

        if self._shuffle is not None:
            if order_sampler is not None:
                raise ValueError('`shuffle` is not `None` and a custom '
                                 '`order_sampler` is set. Please set '
                                 '`shuffle` to `None` to use the custom '
                                 'order sampler.')
            else:
                if self._shuffle:
                    order_sampler = ShuffleOrderSampler()
        else:
            if order_sampler is None:
                order_sampler = ShuffleOrderSampler()
        self.order_sampler = order_sampler

        self.n_threads = n_threads
        self._pool = None

        self.reset()
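
Example #2 additionally stores `n_threads` and initializes `self._pool = None`, which suggests a lazily created thread pool for fetching batches. The snippet below is only a sketch of that idea using the standard-library `ThreadPool`; the toy `dataset` and `batch_indices` are assumptions, not taken from the class itself.

from multiprocessing.pool import ThreadPool

dataset = [{'id': i, 'value': i * i} for i in range(100)]  # toy dataset
batch_indices = [3, 14, 15, 92]                            # one batch of indices

# Create the pool on first use, as `self._pool = None` hints, then map
# dataset lookups over the batch indices with `n_threads` workers.
pool = ThreadPool(processes=2)
try:
    batch = pool.map(dataset.__getitem__, batch_indices)
finally:
    pool.close()
    pool.join()
print(batch)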
Example #3
    def __init__(self,
                 dataset,
                 batch_size,
                 repeat=True,
                 shuffle=None,
                 n_processes=None,
                 n_prefetch=1,
                 shared_mem=None,
                 order_sampler=None,
                 dataset_timeout=30.0,
                 maxtasksperchild=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.repeat = repeat
        self.shuffle = shuffle
        self.n_processes = n_processes or multiprocessing.cpu_count()
        self.n_prefetch = max(n_prefetch, 1)
        self.shared_mem = shared_mem
        self.dataset_timeout = dataset_timeout
        self._maxtasksperchild = maxtasksperchild

        if self.shuffle is not None:
            if order_sampler is not None:
                raise ValueError('`shuffle` is not `None` and a custom '
                                 '`order_sampler` is set. Please set '
                                 '`shuffle` to `None` to use the custom '
                                 'order sampler.')
            else:
                if self.shuffle:
                    order_sampler = ShuffleOrderSampler()
        else:
            if order_sampler is None:
                order_sampler = ShuffleOrderSampler()
        self.order_sampler = order_sampler

        self._comm = _Communicator(self.n_prefetch, dataset_timeout)
        self.reset()

        self._prefetch_loop = _PrefetchLoop(self.dataset, self.batch_size,
                                            self.repeat, self.n_processes,
                                            self.n_prefetch, self.shared_mem,
                                            self._comm, self.order_sampler,
                                            self._interruption_testing,
                                            self._maxtasksperchild)
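
Example #3 adds `dataset_timeout` and `maxtasksperchild`, forwarding both into the prefetch loop. `maxtasksperchild` is a standard `multiprocessing.Pool` option; assuming `_PrefetchLoop` hands it to such a pool (its code is not shown here), the effect can be sketched in isolation.

import multiprocessing


def load_example(index):
    # Stand-in for an expensive dataset[index] fetch.
    return index * index


if __name__ == '__main__':
    # With maxtasksperchild set, each worker process is replaced after it has
    # handled that many tasks, which keeps per-worker memory growth bounded.
    with multiprocessing.Pool(processes=2, maxtasksperchild=100) as pool:
        batch = pool.map(load_example, range(8))
    print(batch)  # [0, 1, 4, 9, 16, 25, 36, 49]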
Example #4
    def __init__(self,
                 dataset,
                 batch_size,
                 repeat=True,
                 shuffle=None,
                 n_processes=None,
                 n_prefetch=1,
                 shared_mem=None,
                 order_sampler=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.repeat = repeat
        self.shuffle = shuffle

        self.n_processes = n_processes or multiprocessing.cpu_count()
        self.n_prefetch = max(n_prefetch, 1)
        self.shared_mem = shared_mem

        if self.shuffle is not None:
            if order_sampler is not None:
                raise ValueError('`shuffle` is not `None` and a custom '
                                 '`order_sampler` is set. Please set '
                                 '`shuffle` to `None` to use the custom '
                                 'order sampler.')
            else:
                if self.shuffle:
                    order_sampler = ShuffleOrderSampler()
        else:
            if order_sampler is None:
                order_sampler = ShuffleOrderSampler()
        self.order_sampler = order_sampler

        self._comm = _Communicator(self.n_prefetch)
        self.reset()

        self._prefetch_loop = _PrefetchLoop(self.dataset, self.batch_size,
                                            self.repeat, self.n_processes,
                                            self.n_prefetch, self.shared_mem,
                                            self._comm, self.order_sampler,
                                            self._interruption_testing)
        # Defer launching the prefetch thread until the worker pool is created,
        # so as not to leave a background thread in forked processes.
        self._thread = None
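
The trailing comment explains why `self._thread` stays `None` here: a thread started before the worker processes are forked would not exist inside the children. The standalone sketch below (POSIX-only toy code, not part of the iterator) demonstrates that fork behavior.

import multiprocessing
import threading
import time


def report_threads():
    # A forked child contains only the thread that called fork(); the
    # parent's background threads are not recreated here.
    print('threads in child :', threading.active_count())


if __name__ == '__main__':
    threading.Thread(target=time.sleep, args=(5,), daemon=True).start()
    print('threads in parent:', threading.active_count())  # typically 2
    ctx = multiprocessing.get_context('fork')               # POSIX only
    child = ctx.Process(target=report_threads)               # typically prints 1
    child.start()
    child.join()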
Example #5
    def __init__(self,
                 dataset,
                 batch_size,
                 local_storage_base,
                 n_prefetch,
                 n_prefetch_from_backend,
                 n_generate_batch,
                 n_remove_example=1,
                 repeat=True,
                 shuffle=None,
                 shared_mem=None,
                 dataset_timeout=30.0,
                 waiting_id_queue_max_size=1000,
                 prefetched_id_queue_max_size=1000,
                 used_id_queue_max_size=1000,
                 dataset_start=0,
                 dataset_finish=0,
                 measure=False):

        self.dataset = dataset  # only ExtendedLabeledImageDataset is supported
        self.batch_size = batch_size
        self.local_storage_base = local_storage_base
        self.repeat = repeat
        self.shuffle = shuffle
        self.n_prefetch = n_prefetch
        self.n_prefetch_from_backend = n_prefetch_from_backend
        self.n_generate_batch = n_generate_batch
        self.n_remove_example = n_remove_example
        self.order_sampler = ShuffleOrderSampler()  # fixed, for now
        self.shared_mem = shared_mem
        self.dataset_timeout = dataset_timeout
        self.dataset_start = dataset_start
        self.dataset_finish = dataset_finish
        self._measure = measure
        self.iteration = 0

        self._comm = _Communicator(self.n_prefetch, dataset_timeout)
        self.reset()

        self._prefetch_pipeline = _PrefetchPipeline(
            self.dataset, self.batch_size, self.local_storage_base,
            self.n_prefetch_from_backend, self.n_generate_batch,
            self.n_remove_example, self.shared_mem, self._comm,
            self.order_sampler, self.repeat, self.n_prefetch,
            waiting_id_queue_max_size, prefetched_id_queue_max_size,
            used_id_queue_max_size, dataset_start, dataset_finish,
            self._measure)
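
Example #5 coordinates its pipeline stages through bounded id queues (`waiting_id_queue_max_size`, `prefetched_id_queue_max_size`, `used_id_queue_max_size`). Since `_PrefetchPipeline` itself is not shown, the sketch below only illustrates the underlying back-pressure idea with a single bounded `queue.Queue` and one prefetch thread; every name in it is hypothetical.

import queue
import threading

dataset = list(range(20))            # toy dataset
prefetched = queue.Queue(maxsize=4)  # bounded, like the *_queue_max_size knobs


def prefetch_worker():
    # Produce examples ahead of consumption; put() blocks once the queue
    # already holds `maxsize` items, applying back-pressure to this stage.
    for example in dataset:
        prefetched.put(example)
    prefetched.put(None)             # sentinel: one pass over the data is done


threading.Thread(target=prefetch_worker, daemon=True).start()

while True:
    example = prefetched.get()
    if example is None:
        break
    print(example)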