def get_loader(
    dataset: GenericSSLDataset,
    dataset_config: dict,
    num_dataloader_workers: int,
    pin_memory: bool,
    multi_processing_method: str,
    device: torch.device,
    sampler_seed=0,
    get_sampler=get_sampler,
    worker_init_fn=set_dataloader_seeds,
):
    """
    Get the dataloader for the given dataset and data split.

    Args:
        dataset (GenericSSLDataset): the dataset object for which the dataloader
            is constructed
        dataset_config (dict): configuration of the dataset. Should be the
            DATA.TRAIN or DATA.TEST settings
        num_dataloader_workers (int): number of workers per gpu (or cpu) for training
        pin_memory (bool): whether to pin memory or not
        multi_processing_method (str): method to use. options: forkserver | fork | spawn
        device (torch.device): training on cuda or cpu
        sampler_seed (int): seed for the sampler. Should be identical per process
        get_sampler (Callable): function that is used to build the sampler
        worker_init_fn (Callable): function executed during initialization of each
            dataloader worker. Defaults to set_dataloader_seeds.

    Returns:
        Instance of PyTorch DataLoader. The dataloader is wrapped with
        DataloaderAsyncGPUWrapper or DataloaderSyncGPUWrapper depending on
        whether the user wants to copy data to gpu asynchronously or not.
    """
    # pytorch dataloader requires setting the multiprocessing type.
    setup_multiprocessing_method(multi_processing_method)

    # we don't need to set the rank / replicas as the Sampler already does so
    # in its init function
    data_sampler = get_sampler(dataset, dataset_config, sampler_seed)
    collate_function = get_collator(
        dataset_config["COLLATE_FUNCTION"], dataset_config["COLLATE_FUNCTION_PARAMS"]
    )

    # Replace the worker_init_fn with a deterministic one when debugging
    if dataset_config["USE_DEBUGGING_SAMPLER"]:
        worker_init_fn = debugging_worker_init_fn

    # Create the pytorch dataloader
    dataloader = DataLoader(
        dataset=dataset,
        num_workers=num_dataloader_workers,
        pin_memory=pin_memory,
        shuffle=False,
        batch_size=dataset_config["BATCHSIZE_PER_REPLICA"],
        collate_fn=collate_function,
        sampler=data_sampler,
        drop_last=dataset_config["DROP_LAST"],
        worker_init_fn=worker_init_fn,
    )

    # If the targeted device is CUDA, set up the device copy:
    # - makes sure that samples end up on the device
    # - async mode overlaps the copy with the previous batch's computation
    if device.type == "cuda":
        if dataset.cfg["DATA"]["ENABLE_ASYNC_GPU_COPY"]:
            logging.info("Wrapping the dataloader to async device copies")
            dataloader = DataloaderAsyncGPUWrapper(dataloader)
        else:
            logging.info("Wrapping the dataloader to synchronous device copies")
            dataloader = DataloaderSyncGPUWrapper(dataloader)
    else:
        logging.warning("Selecting a CPU device")

    return dataloader
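
# Illustrative usage sketch (hypothetical names and values, not part of this module):
# the config layout and the "cuda" device below are assumptions for illustration only.
#
#     train_loader = get_loader(
#         dataset=train_dataset,                # a GenericSSLDataset built elsewhere
#         dataset_config=cfg["DATA"]["TRAIN"],  # assumed config layout
#         num_dataloader_workers=4,
#         pin_memory=True,
#         multi_processing_method="forkserver",
#         device=torch.device("cuda"),
#     )
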
def build_dataloader(
    dataset: GenericSSLDataset,
    dataset_config: dict,
    num_dataloader_workers: int,
    pin_memory: bool,
    multi_processing_method: str,
    device: torch.device,
    sampler_seed=0,
    get_sampler=get_sampler,
    worker_init_fn=set_dataloader_seeds,
    **kwargs,
):
    """
    Get the dataloader for the given dataset and data split.

    Args:
        dataset (GenericSSLDataset): the dataset object for which the dataloader
            is constructed
        dataset_config (dict): configuration of the dataset. Should be the
            DATA.TRAIN or DATA.TEST settings
        num_dataloader_workers (int): number of workers per gpu (or cpu) for training
        pin_memory (bool): whether to pin memory or not
        multi_processing_method (str): method to use. options: forkserver | fork | spawn
        device (torch.device): training on cuda or cpu
        sampler_seed (int): seed for the sampler. Should be identical per process
        get_sampler (Callable): function that is used to build the sampler
        worker_init_fn (Callable): function executed during initialization of each
            dataloader worker. Defaults to set_dataloader_seeds.

    Returns:
        Instance of PyTorch DataLoader. The dataloader is wrapped with
        DataloaderAsyncGPUWrapper or DataloaderSyncGPUWrapper depending on
        whether the user wants to copy data to gpu asynchronously or not.
    """
    # pytorch dataloader requires setting the multiprocessing type.
    setup_multiprocessing_method(multi_processing_method)

    # we don't need to set the rank / replicas as the Sampler already does so
    # in its init function
    data_sampler = get_sampler(dataset, dataset_config, sampler_seed)
    collate_function = get_collator(
        dataset_config["COLLATE_FUNCTION"], dataset_config["COLLATE_FUNCTION_PARAMS"]
    )

    # Replace the worker_init_fn with a deterministic one when debugging
    if dataset_config["USE_DEBUGGING_SAMPLER"]:
        worker_init_fn = debugging_worker_init_fn

    # Load the labels of the dataset before creating the data loader, otherwise
    # the file loads happen in each dataloader worker separately, decreasing
    # performance / hitting quotas on the data source
    dataset.load_labels()

    # Create the pytorch dataloader
    dataloader = DataLoader(
        dataset=dataset,
        num_workers=num_dataloader_workers,
        pin_memory=pin_memory,
        shuffle=False,
        batch_size=dataset_config["BATCHSIZE_PER_REPLICA"],
        collate_fn=collate_function,
        sampler=data_sampler,
        drop_last=dataset_config["DROP_LAST"],
        worker_init_fn=worker_init_fn,
    )

    enable_async_gpu_copy = dataset.cfg["DATA"]["ENABLE_ASYNC_GPU_COPY"]
    dataloader = wrap_dataloader(dataloader, enable_async_gpu_copy, device)
    return dataloader
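
# NOTE (illustrative sketch): `wrap_dataloader` is called above but not defined in
# this section. The version below is a hypothetical reconstruction, assuming it
# simply factors out the device-copy branching that appears inline in `get_loader`;
# the actual implementation may differ.
def wrap_dataloader(dataloader, enable_async_gpu_copy: bool, device: torch.device):
    """Sketch: wrap the dataloader so that batches are copied to the target device."""
    if device.type == "cuda":
        if enable_async_gpu_copy:
            # Async copy overlaps the host-to-device transfer with the
            # previous batch's computation.
            logging.info("Wrapping the dataloader to async device copies")
            dataloader = DataloaderAsyncGPUWrapper(dataloader)
        else:
            logging.info("Wrapping the dataloader to synchronous device copies")
            dataloader = DataloaderSyncGPUWrapper(dataloader)
    else:
        logging.warning("Selecting a CPU device")
    return dataloader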