Example #1
def get_loader(
    dataset: GenericSSLDataset,
    dataset_config: dict,
    num_dataloader_workers: int,
    pin_memory: bool,
    multi_processing_method: str,
    device: torch.device,
    sampler_seed=0,
    get_sampler=get_sampler,
    worker_init_fn=set_dataloader_seeds,
):
    """
    Get the dataloader for the given dataset and data split

    Args:
        dataset (GenericSSLDataset):    the dataset object for which the dataloader is constructed
        dataset_config (dict):          configuration of the dataset.
                                        should be DATA.TRAIN or DATA.TEST settings
        num_dataloader_workers (int):   number of dataloader workers per gpu (or cpu) for training
        pin_memory (bool):              whether to pin memory or not
        multi_processing_method (str):  method to use. options: forkserver | fork | spawn
        sampler_seed (int):             seed for the sampler. Should be identical per process
        device (torch.device):          training on cuda or cpu
        get_sampler (callable):         function that is used to get the sampler
        worker_init_fn (callable):      function executed during the
                                        initialization of each dataloader worker

    Returns:
        Instance of PyTorch DataLoader. The dataloader is wrapped with
        DataloaderAsyncGPUWrapper or DataloaderSyncGPUWrapper depending
        on whether the user wants to copy data to the GPU asynchronously or not.
    """

    # pytorch dataloader requires setting the multiprocessing type.
    setup_multiprocessing_method(multi_processing_method)

    # we don't need to set the rank/replicas as the Sampler already does so in
    # its init function
    data_sampler = get_sampler(dataset, dataset_config, sampler_seed)
    collate_function = get_collator(dataset_config["COLLATE_FUNCTION"],
                                    dataset_config["COLLATE_FUNCTION_PARAMS"])

    # Replace the worker_init_fn with a deterministic one when debugging
    if dataset_config["USE_DEBUGGING_SAMPLER"]:
        worker_init_fn = debugging_worker_init_fn

    # Create the pytorch dataloader
    dataloader = DataLoader(
        dataset=dataset,
        num_workers=num_dataloader_workers,
        pin_memory=pin_memory,
        shuffle=False,
        batch_size=dataset_config["BATCHSIZE_PER_REPLICA"],
        collate_fn=collate_function,
        sampler=data_sampler,
        drop_last=dataset_config["DROP_LAST"],
        worker_init_fn=worker_init_fn,
    )

    # If the targeted device is CUDA, set up async device copy:
    # - makes sure that samples are on the device
    # - overlaps the copy with the previous batch's computation.
    if device.type == "cuda":
        if dataset.cfg["DATA"]["ENABLE_ASYNC_GPU_COPY"]:
            logging.info(
                "Wrapping the dataloader for async device copies")  # NOQA
            dataloader = DataloaderAsyncGPUWrapper(dataloader)
        else:
            logging.info(
                "Wrapping the dataloader for synchronous device copies")  # NOQA
            dataloader = DataloaderSyncGPUWrapper(dataloader)

    else:
        logging.warning("Selecting a CPU device")

    return dataloader
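
A minimal usage sketch (not part of the original example): train_dataset stands in for an already-built GenericSSLDataset, the config values are illustrative placeholders for the keys this function reads, and the collator name depends on which collators are registered in your setup. Note that the async/sync GPU wrapping is driven by the dataset's own config (dataset.cfg["DATA"]["ENABLE_ASYNC_GPU_COPY"]), not by this dict.

import torch

train_config = {
    "COLLATE_FUNCTION": "default_collate",   # name of a registered collator (assumed)
    "COLLATE_FUNCTION_PARAMS": {},
    "USE_DEBUGGING_SAMPLER": False,
    "BATCHSIZE_PER_REPLICA": 32,
    "DROP_LAST": True,
    # get_sampler() may read additional keys from this dict; omitted here.
}

train_loader = get_loader(
    dataset=train_dataset,                 # an already-constructed GenericSSLDataset
    dataset_config=train_config,           # typically the DATA.TRAIN settings
    num_dataloader_workers=4,
    pin_memory=True,
    multi_processing_method="forkserver",
    device=torch.device("cuda"),
    sampler_seed=0,
)
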
Example #2
def build_dataloader(
    dataset: GenericSSLDataset,
    dataset_config: dict,
    num_dataloader_workers: int,
    pin_memory: bool,
    multi_processing_method: str,
    device: torch.device,
    sampler_seed=0,
    get_sampler=get_sampler,
    worker_init_fn=set_dataloader_seeds,
    **kwargs,
):
    """
    Get the dataloader for the given dataset and data split

    Args:
        dataset (GenericSSLDataset):    the dataset object for which the dataloader is constructed
        dataset_config (dict):          configuration of the dataset.
                                        should be DATA.TRAIN or DATA.TEST settings
        num_dataloader_workers (int):   number of dataloader workers per gpu (or cpu) for training
        pin_memory (bool):              whether to pin memory or not
        multi_processing_method (str):  method to use. options: forkserver | fork | spawn
        sampler_seed (int):             seed for the sampler. Should be identical per process
        device (torch.device):          training on cuda or cpu
        get_sampler (callable):         function that is used to get the sampler
        worker_init_fn (callable):      function executed during the
                                        initialization of each dataloader worker

    Returns:
        Instance of PyTorch DataLoader. The dataloader is wrapped with
        DataloaderAsyncGPUWrapper or DataloaderSyncGPUWrapper depending
        on whether the user wants to copy data to the GPU asynchronously or not.
    """

    # pytorch dataloader requires setting the multiprocessing type.
    setup_multiprocessing_method(multi_processing_method)

    # we don't need to set the rank/replicas as the Sampler already does so in
    # its init function
    data_sampler = get_sampler(dataset, dataset_config, sampler_seed)
    collate_function = get_collator(dataset_config["COLLATE_FUNCTION"],
                                    dataset_config["COLLATE_FUNCTION_PARAMS"])

    # Replace the worker_init_fn with a deterministic one when debugging
    if dataset_config["USE_DEBUGGING_SAMPLER"]:
        worker_init_fn = debugging_worker_init_fn

    # Load the labels of the dataset before creating the data loader,
    # otherwise the file loading would happen in each data loader worker
    # separately, hurting performance and potentially hitting quotas on the
    # data source
    dataset.load_labels()

    # Create the pytorch dataloader
    dataloader = DataLoader(
        dataset=dataset,
        num_workers=num_dataloader_workers,
        pin_memory=pin_memory,
        shuffle=False,
        batch_size=dataset_config["BATCHSIZE_PER_REPLICA"],
        collate_fn=collate_function,
        sampler=data_sampler,
        drop_last=dataset_config["DROP_LAST"],
        worker_init_fn=worker_init_fn,
    )
    enable_async_gpu_copy = dataset.cfg["DATA"]["ENABLE_ASYNC_GPU_COPY"]
    dataloader = wrap_dataloader(dataloader, enable_async_gpu_copy, device)

    return dataloader
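
Example #2 factors the device-copy wrapping out into a wrap_dataloader helper whose body is not shown above; only its call site fixes the argument order. Going by the inline branch in Example #1, a minimal sketch of such a helper could look like the following (the log messages and internal details are assumptions, not the library's actual implementation):

def wrap_dataloader(dataloader, enable_async_gpu_copy, device):
    # Sketch reconstructed from the inline logic in Example #1; the helper
    # shipped with the library may differ in its details.
    if device.type == "cuda":
        if enable_async_gpu_copy:
            # Async copy overlaps host-to-device transfers with the previous
            # batch's computation.
            logging.info("Wrapping the dataloader for async device copies")
            dataloader = DataloaderAsyncGPUWrapper(dataloader)
        else:
            logging.info("Wrapping the dataloader for synchronous device copies")
            dataloader = DataloaderSyncGPUWrapper(dataloader)
    else:
        logging.warning("Selecting a CPU device")
    return dataloader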