def __init__(
    self,
    dataset: Dataset,
    batch_size: int = 1,
    shuffle: bool = False,
    sampler: Optional[Sampler] = None,
    batch_sampler: Optional[Sampler] = None,
    num_workers: int = 0,
    pin_memory: bool = False,
    drop_last: bool = False,
    timeout: float = 0.0,
    multiprocessing_context: Optional[Callable] = None,
) -> None:
    if num_workers == 0:
        # when num_workers > 0, random states are determined by worker_init_fn
        # this is to make the behavior consistent when num_workers == 0
        # torch.int64 doesn't work well on some versions of windows
        _seed = torch.empty((), dtype=torch.int32).random_(generator=None).item()
        set_rnd(dataset, int(_seed))
    super().__init__(  # type: ignore[call-overload]
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=list_data_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn,
        multiprocessing_context=multiprocessing_context,
    )
def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
    if num_workers == 0:
        # when num_workers > 0, random states are determined by worker_init_fn
        # this is to make the behavior consistent when num_workers == 0
        # torch.int64 doesn't work well on some versions of windows
        _seed = torch.empty((), dtype=torch.int32).random_(generator=None).item()
        set_rnd(dataset, int(_seed))
    if "collate_fn" not in kwargs:
        kwargs.update({"collate_fn": list_data_collate})
    if "worker_init_fn" not in kwargs:
        kwargs.update({"worker_init_fn": worker_init_fn})
    super().__init__(  # type: ignore[call-overload]
        dataset=dataset,
        num_workers=num_workers,
        **kwargs,
    )
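
# A minimal usage sketch of the **kwargs-based constructor above, assuming it
# belongs to monai.data.DataLoader so that list_data_collate and worker_init_fn
# are injected as defaults when the caller does not pass them. The toy
# list-of-dicts dataset is hypothetical, chosen because that is the structure
# list_data_collate expects.
#
#     import torch
#     from monai.data import DataLoader, Dataset
#
#     data = [{"image": torch.rand(1, 32, 32), "label": i % 2} for i in range(8)]
#     ds = Dataset(data=data, transform=None)
#
#     # batch_size/shuffle are forwarded through **kwargs to the torch DataLoader;
#     # collate_fn and worker_init_fn fall back to the MONAI defaults
#     loader = DataLoader(ds, batch_size=4, shuffle=True, num_workers=0)
#
#     for batch in loader:
#         print(batch["image"].shape)  # torch.Size([4, 1, 32, 32])
#
# Forwarding **kwargs (rather than re-declaring every torch DataLoader argument,
# as the first version does) keeps the signature in sync with torch automatically;
# the "not in kwargs" guards preserve the caller's ability to override both defaults.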