예제 #1
0
 def __init__(
     self,
     dataset: Dataset,
     batch_size: int = 1,
     shuffle: bool = False,
     sampler: Optional[Sampler] = None,
     batch_sampler: Optional[Sampler] = None,
     num_workers: int = 0,
     pin_memory: bool = False,
     drop_last: bool = False,
     timeout: float = 0.0,
     multiprocessing_context: Optional[Callable] = None,
 ) -> None:
     """Build the loader with ``collate_fn``/``worker_init_fn`` fixed to project helpers.

     In single-process mode (``num_workers == 0``) a seed is drawn from torch's
     default RNG and handed to ``set_rnd(dataset, seed)`` before the parent
     constructor runs. With ``num_workers > 0`` no seed is drawn here; random
     states are left to ``worker_init_fn``.

     Every other argument is forwarded unchanged to the parent ``__init__``
     (presumably ``torch.utils.data.DataLoader`` -- confirm against the class
     definition). ``collate_fn`` is always ``list_data_collate`` and
     ``worker_init_fn`` is always the module-level ``worker_init_fn``; callers
     cannot override either here.
     """
     single_process = num_workers == 0
     if single_process:
         # With workers, random states come from worker_init_fn; drawing a seed
         # here keeps the single-process behaviour consistent with that path.
         # int32 rather than int64: the latter misbehaves on some Windows builds.
         seed_holder = torch.empty((), dtype=torch.int32)
         drawn = seed_holder.random_(generator=None).item()
         set_rnd(dataset, int(drawn))

     forwarded = dict(
         batch_size=batch_size,
         shuffle=shuffle,
         sampler=sampler,
         batch_sampler=batch_sampler,
         num_workers=num_workers,
         collate_fn=list_data_collate,
         pin_memory=pin_memory,
         drop_last=drop_last,
         timeout=timeout,
         worker_init_fn=worker_init_fn,
         multiprocessing_context=multiprocessing_context,
     )
     super().__init__(dataset=dataset, **forwarded)  # type: ignore[call-overload]
예제 #2
0
    def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
        """Build the loader with project defaults for collation and worker init.

        Args:
            dataset: dataset to load from. When ``num_workers == 0`` it is
                re-seeded through ``set_rnd`` before the parent constructor runs.
            num_workers: number of worker processes; forwarded unchanged.
            kwargs: any other keyword arguments accepted by the parent
                ``__init__`` (presumably ``torch.utils.data.DataLoader`` --
                confirm against the class definition). ``collate_fn`` defaults
                to ``list_data_collate`` and ``worker_init_fn`` defaults to the
                module-level ``worker_init_fn`` unless the caller provides them.
                If ``generator`` is provided, the single-process seed is drawn
                from it instead of from torch's default RNG.
        """
        if num_workers == 0:
            # when num_workers > 0, random states are determined by worker_init_fn
            # this is to make the behavior consistent when num_workers == 0
            # NOTE(review): PyTorch's DataLoader derives the worker base seed from
            # the loader's ``generator``; draw from the same generator here so a
            # caller-supplied generator also governs the single-process seed.
            # ``kwargs.get`` yields None when absent -> torch's default RNG, as before.
            # torch.int64 doesn't work well on some versions of windows
            _seed = torch.empty((), dtype=torch.int32).random_(generator=kwargs.get("generator")).item()
            set_rnd(dataset, int(_seed))
        # Only fill in defaults; an explicitly passed value (even None) is kept.
        kwargs.setdefault("collate_fn", list_data_collate)
        kwargs.setdefault("worker_init_fn", worker_init_fn)

        super().__init__(  # type: ignore[call-overload]
            dataset=dataset,
            num_workers=num_workers,
            **kwargs,
        )