Example #1
0
    def get_train_val_loaders(self):
        """Build train/validation loaders from one split of the training set.

        Only the *training* split of ``self.args.dataset`` is loaded. Its
        indices (in dataset order) are cut at
        ``floor(self.args.train_portion * len(train_data))``. The leading part
        feeds ``train_queue`` and the remainder feeds ``valid_queue``. Each
        loader visits its own subset in random order via
        ``SubsetRandomSampler``.

        Returns:
            tuple: ``(train_queue, valid_queue, train_transform,
            valid_transform)``.

        Raises:
            ValueError: if ``self.args.dataset`` is not one of ``'cifar10'``,
                ``'cifar100'`` or ``'svhn'``.
        """
        dataset = self.args.dataset
        if dataset == 'cifar10':
            train_transform, valid_transform = utils._data_transforms_cifar10(self.args)
            train_data = dset.CIFAR10(root=self.args.data, train=True, download=True, transform=train_transform)
        elif dataset == 'cifar100':
            train_transform, valid_transform = utils._data_transforms_cifar100(self.args)
            train_data = dset.CIFAR100(root=self.args.data, train=True, download=True, transform=train_transform)
        elif dataset == 'svhn':
            train_transform, valid_transform = utils._data_transforms_svhn(self.args)
            train_data = dset.SVHN(root=self.args.data, split='train', download=True, transform=train_transform)
        else:
            # Fail loudly here; otherwise ``train_data`` would be unbound below
            # and surface as an opaque UnboundLocalError.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'cifar10', 'cifar100' "
                "or 'svhn'".format(dataset))

        num_train = len(train_data)
        indices = list(range(num_train))
        split = int(np.floor(self.args.train_portion * num_train))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
            pin_memory=True, num_workers=2)

        # Both queues wrap the same ``train_data`` object, so validation
        # samples also go through ``train_transform``. Here ``valid_transform``
        # is only returned to the caller. NOTE(review): confirm that is intended.
        valid_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
            pin_memory=True, num_workers=2)

        return train_queue, valid_queue, train_transform, valid_transform
Example #2
0
    def get_train_val_loaders(self):
        """Build loaders over the dataset's official train and held-out splits.

        The training split (``train=True`` / ``split='train'``) uses
        ``train_transform``. The held-out split (``train=False`` /
        ``split='test'``) uses ``valid_transform``. ``train_queue`` shuffles
        every epoch; ``valid_queue`` iterates in dataset order.

        Returns:
            tuple: ``(train_queue, valid_queue, train_transform,
            valid_transform)``.

        Raises:
            ValueError: if ``self.args.dataset`` is not one of ``'cifar10'``,
                ``'cifar100'`` or ``'svhn'``.
        """
        dataset = self.args.dataset
        if dataset == 'cifar10':
            train_transform, valid_transform = utils._data_transforms_cifar10(
                self.args)
            train_data = dset.CIFAR10(root=self.args.data,
                                      train=True,
                                      download=True,
                                      transform=train_transform)
            valid_data = dset.CIFAR10(root=self.args.data,
                                      train=False,
                                      download=True,
                                      transform=valid_transform)
        elif dataset == 'cifar100':
            train_transform, valid_transform = utils._data_transforms_cifar100(
                self.args)
            train_data = dset.CIFAR100(root=self.args.data,
                                       train=True,
                                       download=True,
                                       transform=train_transform)
            valid_data = dset.CIFAR100(root=self.args.data,
                                       train=False,
                                       download=True,
                                       transform=valid_transform)
        elif dataset == 'svhn':
            train_transform, valid_transform = utils._data_transforms_svhn(
                self.args)
            train_data = dset.SVHN(root=self.args.data,
                                   split='train',
                                   download=True,
                                   transform=train_transform)
            valid_data = dset.SVHN(root=self.args.data,
                                   split='test',
                                   download=True,
                                   transform=valid_transform)
        else:
            # Fail loudly here; otherwise ``train_data``/``valid_data`` would be
            # unbound below and surface as an opaque UnboundLocalError.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'cifar10', 'cifar100' "
                "or 'svhn'".format(dataset))

        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=2)

        # No shuffling for evaluation: order does not affect the metrics and
        # keeps iteration deterministic.
        valid_queue = torch.utils.data.DataLoader(
            valid_data,
            batch_size=self.args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=2)

        return train_queue, valid_queue, train_transform, valid_transform