Example no. 1
0
    def __iter__(self):
        """
        Iterate over this rank/epoch's data range in batches.

        This works with H5Dataset's preload capability to load a single
        rank/epoch datarange into memory at once.

        * The main difference is we set the H5Dataset range, which
          means that __getitem__ indexing will always start at 0

        Yields:
            torch.Tensor: one batch per step (a single item when
            ``batch_size == 1``), optionally pinned in page-locked memory.
        """
        # Unpacked only as a sanity check that data_range is a 2-tuple;
        # after _preload() the dataset is re-indexed from 0 (see docstring),
        # so the offsets themselves are intentionally unused here.
        start, stop = self.data_range
        self._preload()

        for i in range(self.n_iter):
            if self.batch_size == 1:
                data = self.dataset[i]
            else:
                # Clamp the final batch to the preloaded dataset length.
                b_start = i * self.batch_size
                b_stop = min((i + 1) * self.batch_size, len(self.dataset))
                data = self.dataset[slice(b_start, b_stop)]

            # isinstance (not `type(...) ==`) so tensor subclasses pass
            # through without an unnecessary copy.
            if not isinstance(data, torch.Tensor):
                data = torch.tensor(data)

            if self.pin_memory:
                # Copy into page-locked memory for faster host-to-device
                # transfer.
                data = pm.pin_memory(data)

            yield data
Example no. 2
0
    def __next__(self):
        """
        Return the next collated batch.

        Single-process mode (``num_workers == 0``) collates directly from
        the dataset; otherwise batches are pulled from the worker result
        queue and re-ordered so they are returned in submission order.

        Raises:
            StopIteration: when the sample iterator (single-process mode)
                or the pool of outstanding worker batches is exhausted.
        """
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            random_var = None
            if self.random_vars:
                random_var = random.choice(self.random_vars)
            batch = self.collate_fn(
                [self.dataset.get(i, random_var) for i in indices])
            if self.pin_memory:
                batch = pin_memory(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self.data_queue.get()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            # BUGFIX: a leftover debug print here previously invoked
            # _process_next_batch(batch) a second time, running its side
            # effects (index advancement / error re-raising) twice per batch.
            return self._process_next_batch(batch)
Example no. 3 (the header for example no. 4, the second method below, is missing)
0
    def collate(batch,
                config: DataConfig,
                field_permutation: bool,
                pin: bool = False):
        """
        Pad every sample in *batch* to the longest state/action length,
        convert each to a dict, and stack them with ``default_collate``.

        When *pin* is true the collated batch is copied into page-locked
        memory before being returned.
        """
        longest_state = max(len(sample.state) for sample in batch)
        longest_action = max(len(sample.actions) for sample in batch)

        as_dicts = [
            sample.to_dict(longest_state, longest_action,
                           field_permutation, config)
            for sample in batch
        ]
        collated = default_collate(as_dicts)

        if pin:
            return pin_memory(collated)
        return collated
    def join_streams_thread(self, out_queue, device_id, done_event):
        """
        Worker thread body: repeatedly fetch a batch and hand it to
        *out_queue* for __iter__ to collect, until *done_event* is set.
        """
        torch.set_num_threads(1)
        torch.cuda.set_device(device_id)

        while not done_event.is_set():
            data = self.get_batch()
            # Pin only real data (not wrapped worker exceptions), and skip
            # the copy when shutdown was requested mid-fetch.
            should_pin = (
                self.pin_memory
                and not done_event.is_set()
                and not isinstance(data, ExceptionWrapper)
            )
            if should_pin:
                data = pin_memory(data)
            out_queue.put(data, timeout=MP_STATUS_CHECK_INTERVAL)
Example no. 5
0
    def join_streams_thread(self, out_queue, device_id, done_event):
        """
        Worker thread body: merge per-stream batch parts into one tensor
        and push it onto *out_queue* for __iter__ to collect; signals the
        memory-thread done event once all stream loaders are exhausted.
        """
        torch.set_num_threads(1)
        torch.cuda.set_device(device_id)

        for parts in self.get_stream_loaders():
            items = list(chain(*parts))

            # Concatenate items along a new second axis.
            # NOTE(review): assumes each item is an indexable tensor of
            # compatible shape — confirm against get_stream_loaders().
            merged = torch.cat([item[:, None] for item in items], dim=1)

            can_pin = (not done_event.is_set()
                       and not isinstance(merged, ExceptionWrapper))
            if can_pin:
                merged = pin_memory(merged)

            out_queue.put(merged, timeout=MP_STATUS_CHECK_INTERVAL)

        self._join_memory_thread_done_event.set()
Example no. 6
0
    def __iter__(self):
        """
        Iterate over ``[start, stop)`` of the underlying dataset in batches.

        Yields:
            torch.Tensor: one batch per step (a single item when
            ``batch_size == 1``), optionally pinned in page-locked memory.
        """
        start, stop = self.data_range

        for i in range(self.n_iter):
            if self.batch_size == 1:
                data = self.dataset[start + i]
            else:
                # Clamp the final batch to the end of this range.
                b_start = start + i * self.batch_size
                b_stop = min(start + (i + 1) * self.batch_size, stop)
                data = self.dataset[slice(b_start, b_stop)]

            # isinstance (not `type(...) ==`) so tensor subclasses pass
            # through without an unnecessary copy.
            if not isinstance(data, torch.Tensor):
                data = torch.tensor(data)

            if self.pin_memory:
                # Copy into page-locked memory for faster host-to-device
                # transfer.
                data = pm.pin_memory(data)

            if self.drop_last:
                # NOTE(review): this only asserts that the batch is full —
                # it does not actually drop a short final batch, and the
                # check is stripped under `python -O`. Kept as-is so the
                # observable behavior (AssertionError) is unchanged.
                assert data.shape[0] == self.batch_size

            yield data