Пример #1
0
def pin_memory_batch(batch):
    """Recursively pin every tensor contained in ``batch``.

    Tensors are replaced by ``tensor.pin_memory()``. Mappings are rebuilt as
    plain ``dict`` objects and other sequences as ``list`` objects, with each
    element processed recursively. Strings and any other objects are returned
    unchanged.
    """
    if isinstance(batch, torch.Tensor):
        return batch.pin_memory()
    # Checked before the Sequence branch so strings (``string_classes``,
    # presumably str/bytes) stay whole instead of being split per element.
    if isinstance(batch, string_classes):
        return batch
    if isinstance(batch, container_abcs.Mapping):
        pinned = {}
        for key, value in batch.items():
            pinned[key] = pin_memory_batch(value)
        return pinned
    if isinstance(batch, container_abcs.Sequence):
        return list(map(pin_memory_batch, batch))
    return batch
    def __next__(self):
        """Return the next batch, either built in-process or received from workers.

        When ``self.num_workers == 0`` the batch is built synchronously:
        the next list of indices comes from ``self.sample_iter``, a random
        element of ``self.random_vars`` is chosen (``None`` when that list is
        empty), every sample is fetched with ``self.dataset.get(index, choice)``
        and the list is handed to ``self.collate_fn``. The result is pinned when
        ``self.pin_memory`` is truthy.

        Otherwise ``(idx, batch)`` pairs are read from ``self.data_queue`` and
        handed out strictly in ``self.rcvd_idx`` order via
        ``self._process_next_batch``; pairs that arrive early are kept in
        ``self.reorder_dict`` until their turn.

        Raises:
            StopIteration: when ``self.sample_iter`` is exhausted (in-process
                mode) or no batches are outstanding (worker mode, after
                ``self._shutdown_workers()`` has been called).
        """
        if self.num_workers == 0:  # same-process loading
            batch_indices = next(self.sample_iter)  # may raise StopIteration
            choice = random.choice(self.random_vars) if self.random_vars else None
            samples = [self.dataset.get(i, choice) for i in batch_indices]
            collated = self.collate_fn(samples)
            return pin_memory_batch(collated) if self.pin_memory else collated

        # A batch that arrived early may already be waiting for us.
        if self.rcvd_idx in self.reorder_dict:
            return self._process_next_batch(self.reorder_dict.pop(self.rcvd_idx))

        if self.batches_outstanding == 0:
            # Nothing left in flight: stop the workers and end iteration.
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert not self.shutdown and self.batches_outstanding > 0
            try:
                batch_idx, payload = self.data_queue.get()
            except Exception as err:
                print(err)
                raise
            self.batches_outstanding -= 1
            if batch_idx == self.rcvd_idx:
                return self._process_next_batch(payload)
            # Out-of-order arrival: park it until its index comes up.
            self.reorder_dict[batch_idx] = payload
Пример #3
0
    def __next__(self):
        """Return the next batch in sampler order.

        In-process mode (``self.num_workers == 0``): the dataset is indexed
        directly with the whole list of indices drawn from ``self.sample_iter``
        (no ``collate_fn`` is applied here). The result is pinned when
        ``self.pin_memory`` is truthy.

        Worker mode: ``(idx, batch)`` pairs are obtained from
        ``self._get_batch()`` and returned via ``self._process_next_batch`` in
        ``self.rcvd_idx`` order; early arrivals are buffered in
        ``self.reorder_dict``.

        Raises:
            StopIteration: when the sampler is exhausted (in-process mode) or
                no batches are outstanding (worker mode, after
                ``self._shutdown_workers()``).
        """
        if self.num_workers == 0:  # same-process loading
            selected = next(self.sample_iter)  # may raise StopIteration
            result = self.dataset[selected]
            if not self.pin_memory:
                return result
            return pin_memory_batch(result)

        # Serve a batch that was buffered on an earlier call, if any.
        if self.rcvd_idx in self.reorder_dict:
            early = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(early)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert not self.shutdown and self.batches_outstanding > 0
            arrived_idx, arrived = self._get_batch()
            self.batches_outstanding -= 1
            if arrived_idx == self.rcvd_idx:
                return self._process_next_batch(arrived)
            # Keep out-of-order samples until their index is requested.
            self.reorder_dict[arrived_idx] = arrived
Пример #4
0
def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
    """Forward batches from ``in_queue`` to ``out_queue`` after pinning them.

    Selects CUDA device ``device_id`` for this thread first. Each message read
    from ``in_queue`` is either ``None`` (the final signal, expected only once
    ``done_event`` is set) or an ``(idx, batch)`` pair. Pairs whose batch is an
    ``ExceptionWrapper`` are forwarded unchanged. Otherwise the batch goes
    through ``pin_memory_batch`` and ``(idx, pinned)`` is put on ``out_queue``.
    If pinning raises, ``(idx, ExceptionWrapper(sys.exc_info()))`` is put
    there instead. Once ``done_event`` is set, remaining messages are drained
    and dropped until ``None`` arrives.
    """
    torch.cuda.set_device(device_id)

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while True:
        try:
            message = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            # Timed out: poll again.
            continue
        except Exception:
            # Weird things can happen when shutting down, e.g., fd being
            # closed when tensors are shared via fds. Only tolerated once
            # shutdown has been requested.
            if not done_event.is_set():
                raise
            break

        if message is None:
            assert done_event.is_set()
            return
        if done_event.is_set():
            # Haven't seen the final signal yet. Keep getting until None.
            continue
        if isinstance(message[1], ExceptionWrapper):
            out_queue.put(message)
            continue

        idx, batch = message
        try:
            pinned = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, pinned))
Пример #5
0
def _pin_memory_loop(in_queue: multiprocessing.SimpleQueue,
                     out_queue: queue.Queue, done_event: threading.Event):
    """Pin the data part of batches read from ``in_queue``.

    Each message is ``(idx, batch_content)``. ``batch_content[0]`` holds the
    batch's sample indices; the remaining items are the data, which are passed
    to ``torch_loader.pin_memory_batch``. The message is re-emitted on
    ``out_queue`` as ``(idx, [indices] + pinned)``. Messages whose second
    element is a ``torch_loader.ExceptionWrapper`` are forwarded unchanged. A
    pinning failure is forwarded as ``(idx, ExceptionWrapper(sys.exc_info()))``.

    The loop ends on a ``None`` message. It also ends quietly when
    ``in_queue.get`` raises after ``done_event`` has been set; the same
    failure before that point is re-raised.
    """
    while True:
        try:
            message = in_queue.get()
        except Exception:
            if not done_event.is_set():
                raise
            return
        if message is None:
            return
        if isinstance(message[1], torch_loader.ExceptionWrapper):
            out_queue.put(message)
            continue

        idx, content = message
        indices, payload = content[0], content[1:]
        try:
            pinned = torch_loader.pin_memory_batch(payload)
        except Exception:
            out_queue.put((idx, torch_loader.ExceptionWrapper(sys.exc_info())))
            continue
        out_queue.put((idx, [indices] + pinned))
Пример #6
0
    def __next__(
            self) -> Tuple[List[int], torch.FloatTensor, torch.FloatTensor]:
        """Return the next batch with its sample indices prepended.

        In-process mode (``self.num_workers == 0``): draw the next list of
        indices from ``self.sample_iter`` and collate ``self.dataset[i]`` for
        each index with ``self.collate_fn``. Pin the result when
        ``self.pin_memory`` is truthy, then return ``[indices] + batch``.

        Worker mode: ``(idx, batch)`` pairs are read from ``self.data_queue`` and
        returned via ``self._process_next_batch`` in ``self.rcvd_idx`` order;
        early arrivals are buffered in ``self.reorder_dict``.

        NOTE(review): despite the ``Tuple`` return annotation, the in-process
        path returns a ``list``.

        Raises:
            StopIteration: when the sampler is exhausted (in-process mode) or
                no batches are outstanding (worker mode, after
                ``self._shutdown_workers()``).
        """
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = torch_loader.pin_memory_batch(batch)
            if isinstance(batch, tuple):
                # ``[indices] + batch`` needs a list. A collate_fn returning a
                # tuple would otherwise raise TypeError, but only when pinning
                # is off.
                batch = list(batch)
            return [indices] + batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self.data_queue.get()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)
Пример #7
0
def worker(bs, tsfm, df_queue, out_queue):
    """Turn chunks from ``df_queue`` into pinned batches on ``out_queue``.

    For every chunk taken from ``df_queue`` (presumably a pandas DataFrame,
    since rows are read with ``.iloc``):

    1. ``tsfm`` is applied to each row ``df.iloc[i]``.
    2. The results are combined with ``collate(..., cuda=False)``.
    3. The batch is pinned with ``pin_memory_batch`` and put on ``out_queue``.

    A ``None`` taken from ``df_queue`` is treated as a shutdown sentinel: the
    function returns instead of looping forever.

    Args:
        bs: not used in this function; kept for call compatibility
            (presumably the batch size).
        tsfm: callable applied to each row.
        df_queue: queue supplying chunks, or ``None`` to stop.
        out_queue: queue receiving the pinned batches.
    """
    while True:
        df = df_queue.get()
        if df is None:  # shutdown sentinel
            return
        samples = [tsfm(df.iloc[i]) for i in range(len(df))]
        batch = collate(samples, cuda=False)
        out_queue.put(pin_memory_batch(batch))
Пример #8
0
 def _next_without_worker(self):
     """Build the next batch synchronously in the calling process.

     ``self.collate_fn`` receives the dataset itself plus the index list drawn
     from ``self.sample_iter``. The collated result is pinned when
     ``self.pin_memory`` is truthy.

     Raises:
         StopIteration: propagated from ``self.sample_iter`` when exhausted.
     """
     picked = next(self.sample_iter)  # StopIteration ends the epoch
     collated = self.collate_fn(self.dataset, picked)
     if not self.pin_memory:
         return collated
     return pin_memory_batch(collated)