Esempio n. 1
0
def build_dataloader_and_sampler(
    dataset_instance: mmf_typings.DatasetType,
    training_config: mmf_typings.DictConfig
) -> mmf_typings.DataLoaderAndSampler:
    """Builds and returns a dataloader along with its sampler

    Args:
        dataset_instance (mmf_typings.DatasetType): Instance of dataset for which
            dataloader has to be created
        training_config (mmf_typings.DictConfig): Training configuration; required
            for inferring params for dataloader (``num_workers``, ``pin_memory``
            and ``shuffle`` are read from it)

    Returns:
        mmf_typings.DataLoaderAndSampler: Tuple of Dataloader and Sampler instance;
            the sampler is ``None`` unless ``_add_extra_args_for_dataloader``
            put one into the loader arguments
    """
    from mmf.common.batch_collator import BatchCollator

    num_workers = training_config.num_workers
    pin_memory = training_config.pin_memory

    other_args = {}

    # IterableDataset returns batches directly, so no need to add Sampler
    # or batch size as user is expected to control those. This is a fine
    # assumption for now to not support single item based IterableDataset
    # as it will add unnecessary complexity and config parameters
    # to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance,
                                                    other_args)
        # The config may override shuffling, but only when no sampler was
        # chosen: DataLoader treats ``shuffle`` and ``sampler`` as mutually
        # exclusive, and a sampler is responsible for its own ordering.
        # IterableDatasets never get ``shuffle`` because DataLoader rejects
        # shuffle=True for them.
        if "sampler" not in other_args:
            other_args['shuffle'] = training_config.shuffle

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin_memory,
        collate_fn=BatchCollator(dataset_instance.dataset_name,
                                 dataset_instance.dataset_type),
        num_workers=num_workers,
        drop_last=False,  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    # NOTE(review): DataLoader rejects negative num_workers, so this branch
    # appears to always run — confirm whether `> 0` was intended.
    if num_workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
Esempio n. 2
0
File: build.py Progetto: zpppy/mmf
def build_dataloader_and_sampler(
    dataset_instance: mmf_typings.DatasetType, training_config: mmf_typings.DictConfig
) -> mmf_typings.DataLoaderAndSampler:
    """Builds and returns a dataloader along with its sampler

    Args:
        dataset_instance (mmf_typings.DatasetType): Instance of dataset for which
            dataloader has to be created
        training_config (mmf_typings.DictConfig): Training configuration; required
            for inferring params for dataloader (``num_workers`` and
            ``pin_memory`` are read from it)

    Returns:
        mmf_typings.DataLoaderAndSampler: Tuple of Dataloader and Sampler instance;
            the sampler is ``None`` unless ``_add_extra_args_for_dataloader``
            put one into the loader arguments
    """
    from mmf.common.batch_collator import BatchCollator

    num_workers = training_config.num_workers
    pin_memory = training_config.pin_memory

    other_args = {}

    # IterableDataset returns batches directly, so sampler / batch size /
    # shuffle are left to the dataset itself; DataLoader also rejects
    # shuffle=True and samplers for IterableDataset.
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin_memory,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        num_workers=num_workers,
        **other_args,
    )

    # NOTE(review): DataLoader rejects negative num_workers, so this branch
    # appears to always run — confirm whether `> 0` was intended.
    if num_workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
Esempio n. 3
0
    def test_call(self):
        """Exercise BatchCollator on the three input shapes it accepts."""
        collator = BatchCollator("vqa2", "train")

        # An already-built sample list: the collated result carries the
        # collator's dataset name and type.
        collated = collator(test_utils.build_random_sample_list())
        self.assertEqual(collated.dataset_name, "vqa2")
        self.assertEqual(collated.dataset_type, "train")

        # A list of Sample objects: field tensors are stacked into a batch.
        item = Sample()
        item.a = torch.tensor([1, 2], dtype=torch.int)
        batched = collator([item, item])
        expected = torch.tensor([[1, 2], [1, 2]], dtype=torch.int)
        self.assertTrue(test_utils.compare_tensors(batched.a, expected))

        # IterableDataset case: a one-element list holding a sample list
        # collates to something equal to that sample list.
        prebuilt = test_utils.build_random_sample_list()
        self.assertEqual(collator([prebuilt]), prebuilt)
Esempio n. 4
0
    def build_dataloader(self, config, dataset):
        """Create a DataLoader for ``dataset`` and return it with its sampler.

        Worker count and pin-memory come from the global training config;
        any extra loader kwargs are added by ``_add_extra_args_for_dataloader``.

        Returns:
            tuple: ``(loader, sampler)`` where ``sampler`` is ``None`` if the
            extra kwargs did not include one.
        """
        train_cfg = self._global_config.training
        workers = train_cfg.num_workers
        pin = train_cfg.pin_memory

        loader_kwargs = {}
        # Presumably fills loader_kwargs in place (its return value is unused).
        self._add_extra_args_for_dataloader(dataset, config, loader_kwargs)

        collator = BatchCollator(dataset.name, dataset.dataset_type)
        loader = DataLoader(
            dataset=dataset,
            pin_memory=pin,
            collate_fn=collator,
            num_workers=workers,
            **loader_kwargs,
        )

        if workers >= 0:
            # Suppress leaking semaphore warning
            os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

        # NOTE(review): tags the loader with self._dataset_type rather than
        # dataset.dataset_type — confirm these always agree.
        loader.dataset_type = self._dataset_type

        return loader, loader_kwargs.get("sampler", None)
Esempio n. 5
0
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Builds and returns a dataloader along with its sampler

    Loader options are read from ``datamodule_config`` first and fall back to
    the global ``training`` config (``num_workers`` defaults to 4,
    ``pin_memory`` to False and, on PyTorch >= 1.8, ``persistent_workers``
    to True).

    Args:
        dataset_instance (torch.utils.data.Dataset): Instance of dataset for which
            dataloader has to be created
        datamodule_config (omegaconf.DictConfig): Datamodule configuration; required
            for inferring params for dataloader

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            Tuple of Dataloader and Sampler instance
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")
    # Support params coming in from dataloader params
    other_args = {
        "num_workers": datamodule_config.get(
            "num_workers", training_config.get("num_workers", 4)
        ),
        "pin_memory": datamodule_config.get(
            "pin_memory", training_config.get("pin_memory", False)
        ),
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }
    if version.parse(torch.__version__) >= version.parse("1.8"):
        # only use persistent workers in PyTorch 1.8 or higher
        # (PyTorch 1.7 also has this option but doesn't support it correctly due to
        # https://github.com/pytorch/pytorch/issues/48370)
        # Assign the value itself: wrapping it in a 1-tuple would make it
        # always truthy and silently ignore persistent_workers=False.
        other_args["persistent_workers"] = datamodule_config.get(
            "persistent_workers", training_config.get("persistent_workers", True)
        )
        if other_args["persistent_workers"] and other_args["num_workers"] == 0:
            logger.warning(
                "persistent_workers cannot be used together with num_workers == 0; "
                "setting persistent_workers to False"
            )
            other_args["persistent_workers"] = False

    # IterableDataset returns batches directly, so no need to add Sampler
    # or batch size as user is expected to control those. This is a fine
    # assumption for now to not support single item based IterableDataset
    # as it will add unnecessary complexity and config parameters
    # to the codebase
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)
    else:
        other_args.pop("shuffle")

    # Set drop_last=True when using XLA to have constant batch size.
    # In this case we also need to set drop_last=True in DistributedSampler.
    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=BatchCollator(
            dataset_instance.dataset_name, dataset_instance.dataset_type
        ),
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **other_args,
    )

    if is_xla():
        device = xm.xla_device()
        loader = xla_pl.MpDeviceLoader(loader, device)

    if other_args["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, other_args.get("sampler", None)
Esempio n. 6
0
def build_dataloader_and_sampler(
    dataset_instance: torch.utils.data.Dataset, datamodule_config: DictConfig
) -> Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
    """Construct a DataLoader for ``dataset_instance`` and report its sampler.

    ``num_workers`` and ``pin_memory`` are taken from ``datamodule_config``
    when present, otherwise from the global ``training`` config (defaults 4
    and False). ``shuffle`` and ``batch_size`` come only from
    ``datamodule_config`` (default None).

    Args:
        dataset_instance (torch.utils.data.Dataset): Dataset to wrap.
        datamodule_config (omegaconf.DictConfig): Datamodule configuration used
            to infer the loader parameters.

    Returns:
        Tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.Sampler]]:
            The loader and the sampler it uses (``None`` if none was set).
    """
    from mmf.common.batch_collator import BatchCollator

    training_config = get_global_config("training")

    workers = datamodule_config.get(
        "num_workers", training_config.get("num_workers", 4)
    )
    pin = datamodule_config.get(
        "pin_memory", training_config.get("pin_memory", False)
    )
    loader_kwargs = {
        "num_workers": workers,
        "pin_memory": pin,
        "shuffle": datamodule_config.get("shuffle", None),
        "batch_size": datamodule_config.get("batch_size", None),
    }

    if isinstance(dataset_instance, torch.utils.data.IterableDataset):
        # An IterableDataset yields whole batches itself, so sampler, batch
        # size and shuffling are left to it; single-item IterableDatasets are
        # intentionally unsupported to avoid extra config surface.
        loader_kwargs.pop("shuffle")
    else:
        loader_kwargs = _add_extra_args_for_dataloader(dataset_instance,
                                                       loader_kwargs)

    collator = BatchCollator(dataset_instance.dataset_name,
                             dataset_instance.dataset_type)
    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        collate_fn=collator,
        drop_last=is_xla(),  # see also MultiDatasetLoader.__len__
        **loader_kwargs,
    )

    if is_xla():
        # Wrap so batches are moved onto the XLA device.
        loader = xla_pl.MpDeviceLoader(loader, xm.xla_device())

    if loader_kwargs["num_workers"] >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, loader_kwargs.get("sampler", None)
Esempio n. 7
0
def build_dataloader_and_sampler(
    dataset_instance: mmf_typings.DatasetType,
    training_config: mmf_typings.DictConfig
) -> mmf_typings.DataLoaderAndSampler:
    """Build a DataLoader (plus sampler), attaching augmentation for training.

    For the ``train`` split, an augmentation pipeline is assigned to
    ``dataset_instance.transform``. For any other split, the transform is set
    to ``None``. Either way, the choice is printed to stdout.

    Args:
        dataset_instance (mmf_typings.DatasetType): Dataset to wrap; its
            ``transform`` attribute is overwritten.
        training_config (mmf_typings.DictConfig): Training configuration
            supplying ``num_workers`` and ``pin_memory``.

    Returns:
        mmf_typings.DataLoaderAndSampler: ``(loader, sampler)``; the sampler is
            ``None`` unless one was added to the loader arguments.
    """
    from mmf.common.batch_collator import BatchCollator

    workers = training_config.num_workers
    pin = training_config.pin_memory

    loader_kwargs = {}
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        # IterableDatasets yield ready-made batches, so sampler/batch size
        # are only configured for map-style datasets.
        loader_kwargs = _add_extra_args_for_dataloader(dataset_instance,
                                                       loader_kwargs)

    is_train_split = str(dataset_instance.dataset_type) == 'train'
    if is_train_split:
        # NOTE(review): Grayscale runs after ColorJitter, so much of the
        # saturation/hue jitter is likely washed out — confirm the ordering.
        # interpolation=3 is PIL's BICUBIC code; newer torchvision prefers
        # InterpolationMode — verify against the installed version.
        augmentations = [
            transforms.RandomRotation(30),
            transforms.ColorJitter(brightness=0.3,
                                   contrast=0.3,
                                   saturation=0.3,
                                   hue=0.3),
            transforms.RandomPerspective(distortion_scale=0.5,
                                         p=0.5,
                                         interpolation=3,
                                         fill=0),
            transforms.Grayscale(num_output_channels=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
        dataset_instance.transform = transforms.Compose(augmentations)
        print('With Augmented')
    else:
        # NOTE(review): this discards any transform the dataset already had.
        dataset_instance.transform = None
        print('With out Augmentation ')

    # NOTE(review): debug output via print; consider a logger.
    print(str(dataset_instance.transform))

    collator = BatchCollator(dataset_instance.dataset_name,
                             dataset_instance.dataset_type)
    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin,
        collate_fn=collator,
        num_workers=workers,
        drop_last=False,  # see also MultiDatasetLoader.__len__
        **loader_kwargs,
    )

    if workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"

    loader.dataset_type = dataset_instance.dataset_type

    return loader, loader_kwargs.get("sampler", None)