Example #1
    def _task(self):
        # Do a single task in the prefetch thread.
        # Returns a bool indicating whether the loop should continue running.

        status, prefetch_state, reset_count = self._comm.check()
        if status == _Communicator.STATUS_RESET:
            self.prefetch_state = prefetch_state
        elif status == _Communicator.STATUS_TERMINATE:
            return False  # stop loop

        self.prefetch_state, indices = iterator_statemachine(
            self.prefetch_state, self.batch_size, self.repeat,
            self.order_sampler, len(self.dataset))
        if indices is None:  # stop iteration
            batch = None
        else:
            future = self._pool.map_async(_fetch_run, enumerate(indices))
            while True:
                try:
                    data_all = future.get(_response_time)
                except multiprocessing.TimeoutError:
                    if self._comm.is_terminated:
                        return False
                else:
                    break
            batch = [_unpack(data, self.mem_bulk) for data in data_all]

        self._comm.put(batch, self.prefetch_state, reset_count)
        return True
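The boolean returned by _task() is meant to drive a worker loop. A minimal
sketch of such a driver is below; the _run name and the pool cleanup are
assumptions for illustration, not part of the example above.

    def _run(self):
        # Hypothetical driver (assumed name): keep performing single
        # prefetch tasks until _task() reports that the loop should stop.
        alive = True
        try:
            while alive:
                alive = self._task()
        finally:
            self._pool.close()
            self._pool.join()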
Example #3
def _generate_random_id_loop(_prefetch_multiprocess_iterator_terminating,
                             _prefetch_multiprocess_iterator_waiting_id_queue,
                             _prefetch_multiprocess_iterator_fetch_dataset,
                             dataset_start, prefetch_batch_size, batch_size,
                             repeat, order_sampler):
    _prefetch_multiprocess_iterator_waiting_id_queue.cancel_join_thread()
    dataset_length = len(_prefetch_multiprocess_iterator_fetch_dataset)
    initial_order = order_sampler(numpy.arange(dataset_length), 0)
    random_id_state = IteratorState(0, 0, False, initial_order)

    while not _prefetch_multiprocess_iterator_terminating.is_set():
        random_id_state, indices = iterator_statemachine(
            random_id_state, batch_size, repeat, order_sampler, dataset_length)

        while True:
            try:
                # Note: `indices` is a numpy.ndarray
                _prefetch_multiprocess_iterator_waiting_id_queue.put(
                    dataset_start + indices, timeout=_response_time)
            except queue.Full:
                if _prefetch_multiprocess_iterator_terminating.is_set():
                    return
            else:
                break
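The loop above pushes batches of global indices (dataset_start + indices)
into the waiting-id queue with a timeout so it can notice termination. A
consumer on the other end would mirror the same retry pattern; the sketch
below is illustrative, with the function name and generator shape assumed.

def _consume_random_ids(terminating, waiting_id_queue):
    # Hypothetical consumer: drain the queue, re-checking the
    # terminating flag whenever a get() times out.
    while not terminating.is_set():
        try:
            indices = waiting_id_queue.get(timeout=_response_time)
        except queue.Empty:
            continue  # re-check the flag and retry
        yield indices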
Example #4
    def __next__(self):
        self._previous_epoch_detail = self.epoch_detail
        self._state, indices = iterator_statemachine(
            self._state, self.batch_size, self.repeat, self.order_sampler,
            len(self.dataset))
        if indices is None:
            raise StopIteration

        batch = [self.dataset[index] for index in indices]
        return batch

    def __next__(self):
        self._previous_epoch_detail = self.epoch_detail
        self._state, indices = _statemachine.iterator_statemachine(
            self._state, self.batch_size, self.repeat, self.order_sampler,
            len(self.dataset))
        if indices is None:
            raise StopIteration

        batch = [self.dataset[index] for index in indices]
        return batch
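Both variants advance a small immutable state through iterator_statemachine
and stop when it yields None. The self-contained sketch below shows that
contract in isolation; it assumes the module path
chainer.iterators._statemachine and uses a hypothetical shuffle sampler.

import numpy
from chainer.iterators._statemachine import (
    IteratorState, iterator_statemachine)

def shuffle_sampler(current_order, current_position):
    # Hypothetical order sampler: a fresh permutation each epoch.
    return numpy.random.permutation(len(current_order))

dataset = list(range(10))
state = IteratorState(0, 0, False, numpy.arange(len(dataset)))
for _ in range(5):
    state, indices = iterator_statemachine(
        state, 4, True, shuffle_sampler, len(dataset))
    print(state.epoch, state.is_new_epoch, indices)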
Example #6
    def _invoke_prefetch(self):
        assert self._next is None
        self._next_state, indices = iterator_statemachine(
            self._state, self.batch_size, self.repeat, self.order_sampler,
            len(self.dataset))

        if indices is None:
            self._next = None
        else:
            if self._pool is None:
                self._pool = pool.ThreadPool(self.n_threads)
            args = [(self.dataset, index) for index in indices]
            self._next = self._pool.map_async(MultithreadIterator._read, args)
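The batch prefetched here is consumed elsewhere. A minimal sketch of that
consuming side, with the _read helper and the _get method name assumed:

    @staticmethod
    def _read(args):
        # Assumed helper matching the (dataset, index) tuples built in
        # _invoke_prefetch: fetch a single example.
        dataset, index = args
        return dataset[index]

    def _get(self):
        # Hypothetical consumer: wait for the async map started by
        # _invoke_prefetch, commit the speculative state, and return
        # the batch (None at the end of a non-repeating epoch).
        future = self._next
        if future is None:
            return None
        batch = future.get()
        self._state = self._next_state
        self._next = None
        return batch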
Example #8
    def peek(self):
        """
        Return the next batch of data without updating its internal state.
        Several calls to peek() should return the same result. A call to
        next() after a call to peek() will return the same result as the
        previous peek.
        """
        if not self._repeat and self.epoch > 0:
            raise StopIteration

        if hasattr(self, "order_sampler"):
            state, indices = _statemachine.iterator_statemachine(
                self._state, self.batch_size, self.repeat, self.order_sampler,
                len(self.dataset))
            if indices is None:
                return []

            batch = [self.dataset[index] for index in indices]
            return batch

        else:

            i = self.current_position
            i_end = i + self.batch_size
            N = len(self.dataset)
            if (not hasattr(self, "_order")) or self._order is None:
                batch = self.dataset[i:i_end]
            else:
                batch = [self.dataset[index] for index in self._order[i:i_end]]

            if i_end >= N:
                if self._repeat:
                    rest = i_end - N

                    if hasattr(self, "_order") and self._order is not None:
                        numpy.random.shuffle(self._order)
                    if rest > 0:
                        if (not hasattr(self, "_order")) or self._order is None:
                            batch += list(self.dataset[:rest])
                        else:
                            batch += [self.dataset[index]
                                      for index in self._order[:rest]]
            return batch
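The docstring's contract can be checked directly. A hypothetical usage,
assuming an iterator class that exposes both peek() and __next__():

it = SomeIterator(dataset, batch_size=4)  # hypothetical iterator class
first = it.peek()
assert it.peek() == first   # repeated peeks agree
assert next(it) == first    # next() returns what peek() showed;
                            # only now does the internal state advance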
Example #9
    def measure(self, dataset_timeout):
        # dataset_timeout: timeout in seconds or None

        status, prefetch_state, _ = self._comm.check()
        if status == _Communicator.STATUS_RESET:
            self.prefetch_state = prefetch_state

        self.prefetch_state, indices = iterator_statemachine(
            self.prefetch_state, self.batch_size, self.repeat,
            self.order_sampler, len(self.dataset))
        if indices is None:  # stop iteration
            batch = None
        else:
            batch_ret = [None]

            def fetch_batch():
                batch_ret[0] = [self.dataset[idx] for idx in indices]

            if dataset_timeout is None:
                # Timeout is not set: fetch synchronously
                fetch_batch()
            else:
                # Timeout is set: fetch asynchronously and watch for timeout
                thr = threading.Thread(target=fetch_batch)
                thr.daemon = True
                thr.start()
                thr.join(dataset_timeout)
                if thr.is_alive():
                    _raise_timeout_warning()
                thr.join()

            batch = batch_ret[0]
            self.mem_size = max(map(_measure, batch))
            self._allocate_shared_memory()

        return batch, self.prefetch_state
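The _measure helper mapped over the first batch is not shown above. A
plausible stand-in, assuming serialized size is the metric used to size
each shared-memory slot:

import pickle

def _measure(example):
    # Assumed stand-in: report the pickled size of one example so a
    # shared-memory slot can hold the largest example in the batch.
    return len(pickle.dumps(example, protocol=pickle.HIGHEST_PROTOCOL))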
Example #11
    def _generate_batch_task(self):
        status, prefetch_state, reset_count = self._comm.check()

        if status == _Communicator.STATUS_RESET:
            self.prefetch_state = prefetch_state
        elif status == _Communicator.STATUS_TERMINATE:
            return False  # stop loop

        # Here, `_indices` is used only to decide whether iteration
        # should stop.
        self.prefetch_state, _indices = iterator_statemachine(
            self.prefetch_state, self.batch_size, self.repeat,
            self.order_sampler, len(self.dataset))
        # If repeat is False and one epoch has passed, `_indices` will be
        # None; see the implementation of `iterator_statemachine`.
        if _indices is None:
            batch = None
        else:
            cached_index_get_timer = time.time()
            while True:
                try:
                    indices, prefetcher_pid, prefetch_time \
                        = _prefetch_multiprocess_iterator_cached_id_queue.get(timeout=_response_time)
                except queue.Empty:
                    if _prefetch_multiprocess_iterator_terminating.is_set():
                        return False
                else:
                    break
            cached_index_get_time = time.time() - cached_index_get_timer
            self.cached_index_get_times.append(cached_index_get_time)
            fetch_data_timer = time.time()
            future = self._generate_batch_pool.map_async(
                _generate_batch, enumerate(indices))
            while True:
                try:
                    data_all = future.get(_response_time)
                except multiprocessing.TimeoutError:
                    if _prefetch_multiprocess_iterator_terminating.is_set():
                        return False
                else:
                    break
            self.fetch_data_time += time.time() - fetch_data_timer

            unpack_and_organize_batch_timer = time.time()
            batch = [_unpack(data[0], self.mem_bulk) for data in data_all]
            self.unpack_and_organize_batch_time += (
                time.time() - unpack_and_organize_batch_timer)

            if self._measure:
                for data in data_all:
                    self._generate_batch_times.append(data[1])
                    self._read_data_times.append(data[0][2])
                    self._get_example_times.append(data[0][3])

                if prefetcher_pid not in self._prefetch_time:
                    self._prefetch_time[prefetcher_pid] = []
                self._prefetch_time[prefetcher_pid].append(prefetch_time)

        self._comm.put(batch, self.prefetch_state, reset_count)

        return True
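When self._measure is enabled, the method accumulates several timing lists.
A hypothetical helper to summarize them (attribute names taken from the
code above, the reporting itself assumed):

    def _report_prefetch_stats(self):
        # Hypothetical summary of the timing lists collected above.
        import statistics
        print('cached-index get (mean s):',
              statistics.mean(self.cached_index_get_times))
        print('generate batch (mean s):',
              statistics.mean(self._generate_batch_times))
        for pid, times in self._prefetch_time.items():
            print('prefetcher %d (mean s):' % pid,
                  statistics.mean(times))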