def build_dataloader_and_sampler(
    dataset_instance: mmf_typings.DatasetType, training_config: mmf_typings.DictConfig
) -> mmf_typings.DataLoaderAndSampler:
    """Builds and returns a dataloader along with its sampler

    Args:
        dataset_instance (mmf_typings.DatasetType): Instance of dataset for which
            dataloader has to be created
        training_config (mmf_typings.DictConfig): Training configuration; required
            for inferring params for dataloader

    Returns:
        mmf_typings.DataLoaderAndSampler: Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    num_workers = training_config.num_workers
    pin_memory = training_config.pin_memory

    other_args = {}

    # IterableDataset returns batches directly, so no need to add a Sampler
    # or batch size as the user is expected to control those. This is a fine
    # assumption for now to not support single-item-based IterableDataset,
    # as it would add unnecessary complexity and config parameters
    # to the codebase.
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
        other_args["shuffle"] = training_config.shuffle

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin_memory,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        num_workers=num_workers,
        drop_last=False,  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if num_workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
def build_dataloader_and_sampler(
    dataset_instance: mmf_typings.DatasetType, training_config: mmf_typings.DictConfig
) -> mmf_typings.DataLoaderAndSampler:
    """Builds and returns a dataloader along with its sampler

    Args:
        dataset_instance (mmf_typings.DatasetType): Instance of dataset for which
            dataloader has to be created
        training_config (mmf_typings.DictConfig): Training configuration; required
            for inferring params for dataloader

    Returns:
        mmf_typings.DataLoaderAndSampler: Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    num_workers = training_config.num_workers
    pin_memory = training_config.pin_memory

    other_args = {}

    other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin_memory,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        num_workers=num_workers,
        **other_args,
    )

    if num_workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
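# Both variants above delegate sampler, shuffle, and batch-size selection to
# _add_extra_args_for_dataloader, which is not shown in this section. The sketch
# below is only an illustration of what such a helper could do (defaulting
# shuffle, choosing a batch size, swapping in a DistributedSampler when a
# process group is initialized); the helper name comes from the calls above,
# while the body, the test-split rule, and the fallback batch size of 32 are
# assumptions, not MMF's actual implementation.
import torch
import torch.distributed as dist


def _add_extra_args_for_dataloader_sketch(dataset_instance, other_args):
    # Assumed default: shuffle every split except "test".
    if other_args.get("shuffle") is None:
        other_args["shuffle"] = dataset_instance.dataset_type != "test"

    # In distributed runs, hand sharding and shuffling over to DistributedSampler;
    # DataLoader forbids passing shuffle and sampler together.
    if dist.is_available() and dist.is_initialized():
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance, shuffle=other_args["shuffle"]
        )
        other_args.pop("shuffle")

    # Assumed fallback batch size when the config does not provide one.
    if other_args.get("batch_size") is None:
        other_args["batch_size"] = 32

    return other_args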
def test_call(self):
    batch_collator = BatchCollator("vqa2", "train")
    sample_list = test_utils.build_random_sample_list()
    sample_list = batch_collator(sample_list)

    # Test an already-built sample list
    self.assertEqual(sample_list.dataset_name, "vqa2")
    self.assertEqual(sample_list.dataset_type, "train")

    sample = Sample()
    sample.a = torch.tensor([1, 2], dtype=torch.int)

    # Test a list of samples
    sample_list = batch_collator([sample, sample])
    self.assertTrue(
        test_utils.compare_tensors(
            sample_list.a, torch.tensor([[1, 2], [1, 2]], dtype=torch.int)
        )
    )

    # Test the IterableDataset case
    sample_list = test_utils.build_random_sample_list()
    new_sample_list = batch_collator([sample_list])
    self.assertEqual(new_sample_list, sample_list)
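# The test above exercises three input shapes for BatchCollator: an
# already-built SampleList, a plain list of Samples, and a one-element list
# wrapping a SampleList (the IterableDataset case). The class below is a
# minimal sketch with that behaviour; it mirrors what the test expects but is
# an illustration, not mmf.common.batch_collator.BatchCollator itself.
from mmf.common.sample import SampleList


class SketchBatchCollator:
    def __init__(self, dataset_name, dataset_type):
        self._dataset_name = dataset_name
        self._dataset_type = dataset_type

    def __call__(self, batch):
        if isinstance(batch, SampleList):
            # Already collated; just re-tag it below.
            sample_list = batch
        elif (
            isinstance(batch, list)
            and len(batch) == 1
            and isinstance(batch[0], SampleList)
        ):
            # IterableDataset yields whole batches, so unwrap the single element.
            sample_list = batch[0]
        else:
            # List of Samples: stack per-sample tensors into batched tensors.
            sample_list = SampleList(batch)

        sample_list.dataset_name = self._dataset_name
        sample_list.dataset_type = self._dataset_type
        return sample_list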
def build_dataloader(self, config, dataset):
    training = self._global_config.training
    num_workers = training.num_workers
    pin_memory = training.pin_memory

    other_args = {}

    self._add_extra_args_for_dataloader(dataset, config, other_args)

    loader = DataLoader(
        dataset=dataset,
        pin_memory=pin_memory,
        collate_fn=BatchCollator(dataset.name, dataset.dataset_type),
        num_workers=num_workers,
        **other_args,
    )

    if num_workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = self._dataset_type

    return loader, other_args.get("sampler", None)
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sampler

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }

    if version.parse(torch.__version__) >= version.parse("1.8"):
        # only use persistent workers in PyTorch 1.8 or higher
        # (PyTorch 1.7 also has this option but doesn't support it correctly due to
        # https://github.com/pytorch/pytorch/issues/48370)
        other_args["persistent_workers"] = datamodule_config.get(
            "persistent_workers", training_config.get("persistent_workers", True)
        )
        if other_args["persistent_workers"] and other_args["num_workers"] == 0:
            logger.warning(
                "persistent_workers cannot be used together with num_workers == 0; "
                "setting persistent_workers to False"
            )
            other_args["persistent_workers"] = False

    # IterableDataset returns batches directly, so no need to add a Sampler
    # or batch size as the user is expected to control those. This is a fine
    # assumption for now to not support single-item-based IterableDataset,
    # as it would add unnecessary complexity and config parameters
    # to the codebase.
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    # Set drop_last=True when using XLA to have constant batch size.
    # In this case we also need to set drop_last=True in DistributedSampler.
    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sampler

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }

    # IterableDataset returns batches directly, so no need to add a Sampler
    # or batch size as the user is expected to control those. This is a fine
    # assumption for now to not support single-item-based IterableDataset,
    # as it would add unnecessary complexity and config parameters
    # to the codebase.
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
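# Hypothetical usage sketch for the two newer variants above. The ToyDataset
# class, the config values, and the epoch count are made up for illustration;
# the only requirements taken from the code above are that the dataset exposes
# dataset_name/dataset_type and that MMF's global config is already set up so
# that get_global_config("training") resolves (as it does in a normal MMF run).
import torch
from omegaconf import OmegaConf
from mmf.common.sample import Sample


class ToyDataset(torch.utils.data.Dataset):
    dataset_name = "toy"
    dataset_type = "train"

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        sample = Sample()
        sample.a = torch.tensor([idx, idx], dtype=torch.int)
        return sample


datamodule_config = OmegaConf.create(
    {"num_workers": 2, "pin_memory": False, "shuffle": True, "batch_size": 4}
)
loader, sampler = build_dataloader_and_sampler(ToyDataset(), datamodule_config)

for epoch in range(2):
    if sampler is not None:
        # A DistributedSampler needs the epoch so every rank reshuffles its
        # shard consistently between epochs.
        sampler.set_epoch(epoch)
    for batch in loader:
        pass  # forward/backward would go here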
def build_dataloader_and_sampler(
    dataset_instance: mmf_typings.DatasetType, training_config: mmf_typings.DictConfig
) -> mmf_typings.DataLoaderAndSampler:
    """Builds and returns a dataloader along with its sampler

    Args:
        dataset_instance (mmf_typings.DatasetType): Instance of dataset for which
            dataloader has to be created
        training_config (mmf_typings.DictConfig): Training configuration; required
            for inferring params for dataloader

    Returns:
        mmf_typings.DataLoaderAndSampler: Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    num_workers = training_config.num_workers
    pin_memory = training_config.pin_memory

    other_args = {}

    # IterableDataset returns batches directly, so no need to add a Sampler
    # or batch size as the user is expected to control those. This is a fine
    # assumption for now to not support single-item-based IterableDataset,
    # as it would add unnecessary complexity and config parameters
    # to the codebase.
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)

    # Apply image augmentation only on the train split.
    if str(dataset_instance.dataset_type) == "train":
        train_transform = transforms.Compose(
            [
                transforms.RandomRotation(30),
                transforms.ColorJitter(
                    brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3
                ),
                transforms.RandomPerspective(
                    distortion_scale=0.5, p=0.5, interpolation=3, fill=0
                ),
                transforms.Grayscale(num_output_channels=3),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )
        dataset_instance.transform = train_transform
        print("With augmentation")
    else:
        dataset_instance.transform = None
        print("Without augmentation")
    print(str(dataset_instance.transform))

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin_memory,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        num_workers=num_workers,
        drop_last=False,  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if num_workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
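# The augmented variant above only assigns dataset_instance.transform; it
# relies on the dataset applying that attribute itself in __getitem__. The
# class below is a hypothetical sketch of such a dataset (not an MMF class)
# showing where the injected transform would be used.
import torch
from PIL import Image
from mmf.common.sample import Sample


class TransformAwareDataset(torch.utils.data.Dataset):
    dataset_name = "toy_images"
    dataset_type = "train"

    def __init__(self, image_paths):
        self.image_paths = image_paths
        self.transform = None  # injected by the builder above for the train split

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform is not None:
            # e.g. the torchvision Compose pipeline built in the function above
            image = self.transform(image)
        sample = Sample()
        sample.image = image
        return sample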