예제 #1
0
def dataloader(dataset, batch_size, train, workers, length=None):
    """Build a ``DataLoader`` for a named image dataset (or a ready-made one).

    Args:
        dataset: One of ``'mnist'``, ``'cifar10'``, ``'cifar100'``,
            ``'tiny-imagenet'`` or ``'imagenet'``. Any non-string object is
            treated as an already-constructed dataset and used unchanged.
        batch_size: Samples per batch.
        train: Truthy selects the training split. It is also forwarded as
            ``preprocess`` to ``get_transform`` (except for MNIST, which always
            uses ``preprocess=False``). For ImageNet it switches between the
            random-crop/grayscale/jitter/flip pipeline and the
            resize/center-crop pipeline.
        workers: ``num_workers`` for the loader. It is only forwarded (together
            with ``pin_memory=True``) when CUDA is available.
        length: If given, keep only a random subset of at most ``length``
            samples, drawn with the global torch RNG.

    Returns:
        A ``torch.utils.data.DataLoader`` that shuffles only when
        ``train is True``.

    Raises:
        ValueError: If ``dataset`` is a string naming no supported dataset.
    """
    # Dataset
    if isinstance(dataset, str):
        name = dataset
        if name == 'mnist':
            mean, std = (0.1307,), (0.3081,)
            transform = get_transform(size=28, padding=0, mean=mean, std=std, preprocess=False)
            dataset = datasets.MNIST('Data', train=train, download=True, transform=transform)
        elif name == 'cifar10':
            mean, std = (0.491, 0.482, 0.447), (0.247, 0.243, 0.262)
            transform = get_transform(size=32, padding=4, mean=mean, std=std, preprocess=train)
            dataset = datasets.CIFAR10('Data', train=train, download=True, transform=transform)
        elif name == 'cifar100':
            mean, std = (0.507, 0.487, 0.441), (0.267, 0.256, 0.276)
            transform = get_transform(size=32, padding=4, mean=mean, std=std, preprocess=train)
            dataset = datasets.CIFAR100('Data', train=train, download=True, transform=transform)
        elif name == 'tiny-imagenet':
            mean, std = (0.480, 0.448, 0.397), (0.276, 0.269, 0.282)
            transform = get_transform(size=64, padding=4, mean=mean, std=std, preprocess=train)
            dataset = custom_datasets.TINYIMAGENET('Data', train=train, download=True, transform=transform)
        elif name == 'imagenet':
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            if train:
                transform = transforms.Compose([
                    transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
                    transforms.RandomGrayscale(p=0.2),
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std)])
            else:
                transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std)])
            folder = 'Data/imagenet_raw/{}'.format('train' if train else 'val')
            dataset = datasets.ImageFolder(folder, transform=transform)
        else:
            # Without this, the raw string would reach DataLoader, which
            # would silently iterate over its characters.
            raise ValueError('Unknown dataset: {!r}'.format(name))

    # Dataloader
    use_cuda = torch.cuda.is_available()
    # NOTE: num_workers / pin_memory are only set when CUDA is available;
    # on CPU-only machines DataLoader's defaults are used.
    kwargs = {'num_workers': workers, 'pin_memory': True} if use_cuda else {}
    shuffle = train is True
    if length is not None:
        # Unseeded: the subset differs between calls unless torch is seeded.
        indices = torch.randperm(len(dataset))[:length]
        dataset = torch.utils.data.Subset(dataset, indices)

    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       **kwargs)
예제 #2
0
def dataloader(dataset, batch_size, train, workers, length=None, args=None):
    """Build a ``DataLoader`` (optionally DDP-sharded) for an image dataset.

    Args:
        dataset: One of ``'mnist'``, ``'cifar10'``, ``'cifar100'``,
            ``'tiny-imagenet'`` or ``'imagenet'``. Any non-string object is
            treated as an already-constructed dataset and used unchanged.
        batch_size: Samples per batch (per process when DDP is used).
        train: Truthy selects the training split. It is also forwarded as
            ``preprocess`` to ``get_transform`` (except for MNIST). For
            ImageNet it selects the augmenting vs. resize/center-crop pipeline.
        workers: ``num_workers`` for the loader. It is only forwarded (together
            with ``pin_memory=True``) when CUDA is available.
        length: If given, keep only a random subset of at most ``length``
            samples, drawn with the global torch RNG.
        args: Optional namespace. When ``args.ddp`` is truthy, a
            ``DistributedSampler`` is used. It is built without explicit
            ``num_replicas``/``rank``, so it relies on an initialised default
            process group.

    Returns:
        ``(loader, sampler)``, where ``sampler`` is the ``DistributedSampler``
        or ``None``. The sampler is presumably returned so callers can call
        ``sampler.set_epoch``; confirm at call sites.

    Raises:
        ValueError: If ``dataset`` is a string naming no supported dataset.
    """
    # Dataset
    if isinstance(dataset, str):
        name = dataset
        if name == 'mnist':
            mean, std = (0.1307, ), (0.3081, )
            transform = get_transform(size=28,
                                      padding=0,
                                      mean=mean,
                                      std=std,
                                      preprocess=False)
            dataset = datasets.MNIST('Data',
                                     train=train,
                                     download=True,
                                     transform=transform)
        elif name == 'cifar10':
            mean, std = (0.491, 0.482, 0.447), (0.247, 0.243, 0.262)
            transform = get_transform(size=32,
                                      padding=4,
                                      mean=mean,
                                      std=std,
                                      preprocess=train)
            dataset = datasets.CIFAR10('Data',
                                       train=train,
                                       download=True,
                                       transform=transform)
        elif name == 'cifar100':
            mean, std = (0.507, 0.487, 0.441), (0.267, 0.256, 0.276)
            transform = get_transform(size=32,
                                      padding=4,
                                      mean=mean,
                                      std=std,
                                      preprocess=train)
            dataset = datasets.CIFAR100('Data',
                                        train=train,
                                        download=True,
                                        transform=transform)
        elif name == 'tiny-imagenet':
            mean, std = (0.480, 0.448, 0.397), (0.276, 0.269, 0.282)
            transform = get_transform(size=64,
                                      padding=4,
                                      mean=mean,
                                      std=std,
                                      preprocess=train)
            dataset = custom_datasets.TINYIMAGENET('Data',
                                                   train=train,
                                                   download=True,
                                                   transform=transform)
        elif name == 'imagenet':
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            if train:
                transform = transforms.Compose([
                    transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
                    transforms.RandomGrayscale(p=0.2),
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std)
                ])
            else:
                transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std)
                ])
            folder = 'Data/imagenet_raw/{}'.format('train' if train else 'val')
            dataset = datasets.ImageFolder(folder, transform=transform)
        else:
            # Without this, the raw string would reach DataLoader, which
            # would silently iterate over its characters.
            raise ValueError('Unknown dataset: {!r}'.format(name))

    # Dataloader
    use_cuda = torch.cuda.is_available()
    # NOTE: num_workers / pin_memory are only set when CUDA is available.
    kwargs = {'num_workers': workers, 'pin_memory': True} if use_cuda else {}
    shuffle = train is True
    if length is not None:
        # Unseeded: the subset differs between calls unless torch is seeded.
        indices = torch.randperm(len(dataset))[:length]
        dataset = torch.utils.data.Subset(dataset, indices)

    if args is not None and args.ddp:
        print('creating sampler to divide the data for dataloader')
        # DistributedSampler shuffles by default; pass the loader's own
        # shuffle flag so evaluation splits keep a fixed order under DDP too.
        sampler = torch.utils.data.DistributedSampler(dataset, shuffle=shuffle)
    else:
        sampler = None

    # DataLoader forbids shuffle=True together with a custom sampler; the
    # sampler handles the shuffling in that case.
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         shuffle=shuffle and (sampler is None),
                                         sampler=sampler,
                                         **kwargs)

    return loader, sampler
예제 #3
0
def dataloader(dataset,
               batch_size,
               train,
               workers,
               corrupt_prob=0.0,
               length=None,
               seed=None):
    """Build a ``DataLoader`` with optional label corruption and subsetting.

    Args:
        dataset: One of ``'mnist'``, ``'cifar10'``, ``'cifar100'``,
            ``'tiny-imagenet'`` or ``'imagenet'``. Any non-string object is
            treated as an already-constructed dataset and used unchanged.
        batch_size: Samples per batch.
        train: Split selector. For ``'cifar10'`` it must be one of ``'train'``,
            ``'val'``, ``'trainval'`` or ``'test'`` and is passed through to
            ``cifar.get_cifar_dataset``. A string counts as a training split
            iff it is ``'train'`` or ``'trainval'``; a non-string is
            interpreted by truthiness. That training flag selects the
            train/test split, the ``preprocess`` flag / ImageNet pipeline, and
            whether the loader shuffles.
        workers: ``num_workers`` for the loader. It is only forwarded (together
            with ``pin_memory=True``) when CUDA is available.
        corrupt_prob: Probability that each label is replaced by a uniformly
            random class in ``[0, 10)``. Only 10-class datasets are supported.
            The corruption pattern is reproducible: it uses a private
            generator seeded with ``SEED`` and leaves the global torch RNG
            seed untouched.
        length: If given, keep the first ``length`` indices of a permutation
            drawn from a generator seeded with 0.
        seed: Forwarded to ``cifar.get_cifar_dataset`` (cifar10 only).

    Returns:
        A ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If ``dataset`` is a string naming no supported dataset.
    """
    print("Which dataset?", train)

    # `train` is a split name on the cifar10 path but a bool elsewhere.
    # Normalise it once so every branch and the shuffle flag agree. Before
    # this, e.g. train=True never shuffled, and train='test' loaded training
    # data for the torchvision datasets.
    if isinstance(train, str):
        is_train = train in ('train', 'trainval')
    else:
        is_train = bool(train)

    # Dataset
    if isinstance(dataset, str):
        name = dataset
        if name == 'mnist':
            mean, std = (0.1307, ), (0.3081, )
            transform = get_transform(size=28,
                                      padding=0,
                                      mean=mean,
                                      std=std,
                                      preprocess=False)
            dataset = datasets.MNIST('Data',
                                     train=is_train,
                                     download=True,
                                     transform=transform)
        elif name == 'cifar10':
            assert (train in ['train', 'val', 'trainval', 'test'])
            mean, std = (0.491, 0.482, 0.447), (0.247, 0.243, 0.262)
            transform = get_transform(size=32,
                                      padding=4,
                                      mean=mean,
                                      std=std,
                                      preprocess=is_train)
            dataset = cifar.get_cifar_dataset(train, transform, seed=seed)
        elif name == 'cifar100':
            mean, std = (0.507, 0.487, 0.441), (0.267, 0.256, 0.276)
            transform = get_transform(size=32,
                                      padding=4,
                                      mean=mean,
                                      std=std,
                                      preprocess=is_train)
            dataset = datasets.CIFAR100('Data',
                                        train=is_train,
                                        download=True,
                                        transform=transform)
        elif name == 'tiny-imagenet':
            mean, std = (0.480, 0.448, 0.397), (0.276, 0.269, 0.282)
            transform = get_transform(size=64,
                                      padding=4,
                                      mean=mean,
                                      std=std,
                                      preprocess=is_train)
            dataset = custom_datasets.TINYIMAGENET('Data',
                                                   train=is_train,
                                                   download=True,
                                                   transform=transform)
        elif name == 'imagenet':
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            if is_train:
                transform = transforms.Compose([
                    transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
                    transforms.RandomGrayscale(p=0.2),
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std)
                ])
            else:
                transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std)
                ])
            folder = 'Data/imagenet_raw/{}'.format(
                'train' if is_train else 'val')
            dataset = datasets.ImageFolder(folder, transform=transform)
        else:
            # Without this, the raw string would reach DataLoader, which
            # would silently iterate over its characters.
            raise ValueError('Unknown dataset: {!r}'.format(name))

    # Corrupt labels as needed (only supports 10-class datasets for now).
    SEED = 2
    if corrupt_prob > 0.0:
        full_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=200000,
                                                  shuffle=False)
        # Collect *every* batch. Previously only the last batch survived,
        # which dropped data for datasets larger than one batch.
        image_batches, label_batches = [], []
        for images, labels in full_loader:
            image_batches.append(images)
            label_batches.append(labels)

        if label_batches:  # empty dataset: nothing to corrupt
            images = torch.cat(image_batches)
            labels = torch.cat(label_batches)

            # A private generator keeps the corruption reproducible without
            # reseeding the caller's global RNG.
            gen = torch.Generator()
            gen.manual_seed(SEED)
            probs = torch.full(labels.shape, corrupt_prob)
            mask = torch.bernoulli(probs, generator=gen).bool()

            n_mask = int(mask.sum())
            print("Number of labels being corrupted", n_mask)
            labels[mask] = torch.randint(10, (n_mask, ), generator=gen)
            dataset = TensorDataset(images, labels)

    # Dataloader
    use_cuda = torch.cuda.is_available()
    # NOTE: num_workers / pin_memory are only set when CUDA is available.
    kwargs = {'num_workers': workers, 'pin_memory': True} if use_cuda else {}
    shuffle = is_train
    print("Will I be shuffling?", shuffle)
    if length is not None:
        # Fixed seed so the same subset is chosen on every call.
        g = torch.Generator()
        g.manual_seed(0)
        indices = torch.randperm(len(dataset), generator=g)[:length]
        dataset = torch.utils.data.Subset(dataset, indices)

    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       **kwargs)