Example #1
0
def get_cifar(cfg):
    """Build training and validation data loaders for a CIFAR dataset.

    Args:
        cfg: A YACS config object. Reads ``cfg.DATASET.NAME`` (``"CIFAR10"`` or
            ``"CIFAR100"``), ``cfg.DATASET.ROOT``, ``cfg.DATASET.NUM_WORKERS``,
            ``cfg.SOLVER.TRAIN_BATCH_SIZE`` and ``cfg.SOLVER.TEST_BATCH_SIZE``.

    Returns:
        tuple: ``(train_loader, val_loader)``. The training loader shuffles and
        drops the last incomplete batch; the validation loader does neither.

    Raises:
        NotImplementedError: If ``cfg.DATASET.NAME`` is not a supported CIFAR
            variant.
    """
    # Lazy %-style arguments: works even when ROOT is not a str (e.g. a Path),
    # where string concatenation would raise TypeError.
    logging.info("==> Preparing to load data %s at %s", cfg.DATASET.NAME, cfg.DATASET.ROOT)

    supported = ("CIFAR10", "CIFAR100")
    name = cfg.DATASET.NAME
    if name not in supported:
        raise NotImplementedError(
            f"Unsupported dataset {name!r}; expected one of {supported}"
        )
    # Both CIFAR variants are constructed with identical arguments, so resolve
    # the dataset class by its (validated) name instead of duplicating branches.
    dataset_cls = getattr(torchvision.datasets, name)

    cifar_train_transform = get_transform("cifar", augment=True)
    cifar_test_transform = get_transform("cifar", augment=False)

    train_set = dataset_cls(
        cfg.DATASET.ROOT, train=True, download=True, transform=cifar_train_transform
    )
    val_set = dataset_cls(
        cfg.DATASET.ROOT, train=False, download=True, transform=cifar_test_transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=cfg.SOLVER.TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=cfg.DATASET.NUM_WORKERS,
        pin_memory=True,
        drop_last=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=cfg.SOLVER.TEST_BATCH_SIZE,
        shuffle=False,
        num_workers=cfg.DATASET.NUM_WORKERS,
        pin_memory=True,
    )

    return train_loader, val_loader
Example #2
0
def get_cifar(cfg):
    """Build training and validation data loaders for a CIFAR dataset.

    Args:
        cfg: A YACS config object. Reads ``cfg.DATASET.NAME`` (``"CIFAR10"`` or
            ``"CIFAR100"``), ``cfg.DATASET.ROOT``, ``cfg.DATASET.NUM_WORKERS``,
            ``cfg.SOLVER.TRAIN_BATCH_SIZE`` and ``cfg.SOLVER.TEST_BATCH_SIZE``.

    Returns:
        tuple: ``(train_loader, val_loader)``. The training loader shuffles and
        drops the last incomplete batch; the validation loader does neither.

    Raises:
        NotImplementedError: If ``cfg.DATASET.NAME`` is not a supported CIFAR
            variant.
    """
    # Lazy %-style arguments: works even when ROOT is not a str (e.g. a Path),
    # where string concatenation would raise TypeError.
    logging.info("==> Preparing to load data %s at %s", cfg.DATASET.NAME,
                 cfg.DATASET.ROOT)

    supported = ("CIFAR10", "CIFAR100")
    name = cfg.DATASET.NAME
    if name not in supported:
        raise NotImplementedError(
            f"Unsupported dataset {name!r}; expected one of {supported}")
    # Both CIFAR variants are constructed with identical arguments, so resolve
    # the dataset class by its (validated) name instead of duplicating branches.
    dataset_cls = getattr(torchvision.datasets, name)

    cifar_train_transform = get_transform("cifar", augment=True)
    cifar_test_transform = get_transform("cifar", augment=False)

    train_set = dataset_cls(cfg.DATASET.ROOT,
                            train=True,
                            download=True,
                            transform=cifar_train_transform)
    val_set = dataset_cls(cfg.DATASET.ROOT,
                          train=False,
                          download=True,
                          transform=cifar_test_transform)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=cfg.SOLVER.TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=cfg.DATASET.NUM_WORKERS,
        pin_memory=True,
        drop_last=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=cfg.SOLVER.TEST_BATCH_SIZE,
        shuffle=False,
        num_workers=cfg.DATASET.NUM_WORKERS,
        pin_memory=True,
    )

    return train_loader, val_loader
Example #3
0
 def __init__(self, data_path, transform_kind):
     """Initialise the dataset accessor.

     Calls the parent initialiser with ``n_classes=10``, then stores
     ``data_path`` and the transform returned by
     ``image_transform.get_transform(transform_kind)`` on the instance.

     Args:
         data_path: Location of the data; stored as ``self._data_path``.
         transform_kind: Key forwarded to ``image_transform.get_transform``.
     """
     super().__init__(n_classes=10)
     self._data_path = data_path
     built_transform = image_transform.get_transform(transform_kind)
     self._transform = built_transform
Example #4
0
    """
    def get_train(self):
        """Return the SVHN ``"train"`` split, downloading it if it is missing.

        The dataset is rooted at ``self._data_path`` and uses ``self._transform``.
        """
        svhn_options = dict(split="train", transform=self._transform, download=True)
        return datasets.SVHN(self._data_path, **svhn_options)

    def get_test(self):
        """Return the SVHN ``"test"`` split, downloading it if it is missing.

        The dataset is rooted at ``self._data_path`` and uses ``self._transform``.
        """
        svhn_options = dict(split="test", transform=self._transform, download=True)
        return datasets.SVHN(self._data_path, **svhn_options)


# Names of the domains in the Office dataset family.
OFFICE_DOMAINS = [
    "amazon",
    "caltech",
    "dslr",
    "webcam",
]
# Transform for Office images, built once when this module is imported.
office_transform = get_transform("office")


class OfficeAccess(MultiDomainImageFolder, DatasetAccess):
    """Common API for office dataset access

    Args:
        root (string): root directory of dataset
        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
            version. Defaults to office_transform.
        download (bool, optional): Whether to allow downloading the data if not found on disk. Defaults to False.

    References:
        [1] Saenko, K., Kulis, B., Fritz, M. and Darrell, T., 2010, September. Adapting visual category models to
        new domains. In European Conference on Computer Vision (pp. 213-226). Springer, Berlin, Heidelberg.
        [2] Griffin, Gregory and Holub, Alex and Perona, Pietro, 2007. Caltech-256 Object Category Dataset.