Example #1
def get_val_dataset(p,
                    transform=None,
                    to_neighbors_dataset=False,
                    to_similarity_dataset=False,
                    use_negatives=False,
                    use_simpred=False):
    # Base dataset
    if p['val_db_name'] == 'cifar-10':
        from data.cifar import CIFAR10
        dataset = CIFAR10(train=False, transform=transform, download=True)

    elif p['val_db_name'] == 'cifar-20':
        from data.cifar import CIFAR20
        dataset = CIFAR20(train=False, transform=transform, download=True)

    elif p['val_db_name'] == 'stl-10':
        from data.stl import STL10
        dataset = STL10(split='test', transform=transform, download=True)

    # Note: these document datasets are keyed by 'train_db_name'; the test
    # split is still selected via split="test" below.
    elif p['train_db_name'] in [
            'impact_kb', 'impact_full_balanced', 'impact_full_imbalanced',
            'hdi_balanced', 'hdi_imbalanced', 'tobacco3482', 'rvl-cdip',
            'wpi_demo'
    ]:
        from data.imagefolderwrapper import ImageFolderWrapper
        root = MyPath.db_root_dir(p['train_db_name'])
        dataset = ImageFolderWrapper(root, split="test", transform=transform)

    elif p['val_db_name'] == 'imagenet':
        from data.imagenet import ImageNet
        dataset = ImageNet(split='val', transform=transform)

    elif p['val_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']:
        from data.imagenet import ImageNetSubset
        subset_file = './data/imagenet_subsets/%s.txt' % (p['val_db_name'])
        dataset = ImageNetSubset(subset_file=subset_file,
                                 split='val',
                                 transform=transform)

    else:
        raise ValueError('Invalid validation dataset {}'.format(
            p['val_db_name']))

    # Wrap into other dataset (__getitem__ changes)
    if to_neighbors_dataset:  # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        knn_indices = np.load(p['topk_neighbors_val_path'])

        if use_negatives:
            kfn_indices = np.load(p['topk_furthest_val_path'])
        else:
            kfn_indices = None

        # Only the 5 nearest (and, if used, 5 furthest) neighbors at validation time.
        dataset = NeighborsDataset(dataset, knn_indices, kfn_indices,
                                   use_simpred, 5, 5)
    elif to_similarity_dataset:  # Dataset returns an image and another random image.
        from data.custom_dataset import SimilarityDataset
        dataset = SimilarityDataset(dataset)

    return dataset
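All of these wrappers keep the base dataset intact and change only what __getitem__ returns. A minimal sketch of the neighbor-pair idea, where NeighborPairs is a hypothetical stand-in for the NeighborsDataset class in data/custom_dataset.py:

import numpy as np
from torch.utils.data import Dataset

class NeighborPairs(Dataset):
    """Returns a sample together with one of its precomputed nearest neighbors."""
    def __init__(self, base, knn_indices, num_neighbors):
        self.base = base
        # Keep only the first num_neighbors columns of the kNN index matrix.
        self.indices = knn_indices[:, :num_neighbors]

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        j = np.random.choice(self.indices[i])  # pick one neighbor at random
        return {'anchor': self.base[i], 'neighbor': self.base[j]}

Example #2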
def get_val_dataset(p, transform=None, to_neighbors_dataset=False):
    # Base dataset
    if p['val_db_name'] == 'cifar-10':
        from data.cifar import CIFAR10
        dataset = CIFAR10(train=False, transform=transform, download=True)
    
    elif p['val_db_name'] == 'cifar-20':
        from data.cifar import CIFAR20
        dataset = CIFAR20(train=False, transform=transform, download=True)

    elif p['val_db_name'] == 'stl-10':
        from data.stl import STL10
        dataset = STL10(split='test', transform=transform, download=True)
    
    elif p['val_db_name'] == 'imagenet':
        from data.imagenet import ImageNet
        dataset = ImageNet(split='val', transform=transform)
    
    elif p['val_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']:
        from data.imagenet import ImageNetSubset
        subset_file = './data/imagenet_subsets/%s.txt' % (p['val_db_name'])
        dataset = ImageNetSubset(subset_file=subset_file, split='val', transform=transform)
    
    else:
        raise ValueError('Invalid validation dataset {}'.format(p['val_db_name']))
    
    # Wrap into other dataset (__getitem__ changes) 
    if to_neighbors_dataset: # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        indices = np.load(p['topk_neighbors_val_path'])
        dataset = NeighborsDataset(dataset, indices, 5) # Only use 5

    return dataset
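A hypothetical call, assuming a SCAN-style config dict whose keys match the ones the function reads (the .npy path is a placeholder):

import torchvision.transforms as transforms

p = {'val_db_name': 'cifar-10',
     'topk_neighbors_val_path': 'results/cifar-10/topk_neighbors_val.npy'}  # placeholder
val_transform = transforms.Compose([transforms.ToTensor()])
val_dataset = get_val_dataset(p, transform=val_transform, to_neighbors_dataset=True)
sample = val_dataset[0]  # now holds the image and one of its 5 nearest neighbors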
Example #3
def get_train_dataset(p,
                      transform,
                      to_augmented_dataset=False,
                      to_neighbors_dataset=False,
                      split=None):
    # Base dataset
    if p['train_db_name'] == 'cifar-10':
        from data.cifar import CIFAR10
        dataset = CIFAR10(train=True, transform=transform, download=True)

    elif p['train_db_name'] == 'cifar-20':
        from data.cifar import CIFAR20
        dataset = CIFAR20(train=True, transform=transform, download=True)

    elif p['train_db_name'] == 'stl-10':
        from data.stl import STL10
        dataset = STL10(split=split, transform=transform, download=True)

    elif p['train_db_name'] == 'imagenet':
        from data.imagenet import ImageNet
        dataset = ImageNet(split='train', transform=transform)

    elif p['train_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']:
        from data.imagenet import ImageNetSubset
        subset_file = './data/imagenet_subsets/%s.txt' % (p['train_db_name'])
        dataset = ImageNetSubset(subset_file=subset_file,
                                 split='train',
                                 transform=transform)

    # Added by Johan
    elif p['train_db_name'] == 'tabledb':
        from data.tabledb import TableDB
        dataset = TableDB(split='train', transform=transform)

    # Added by Johan
    elif p['train_db_name'] == 'tablestrdb':
        from data.tablestrdb import TableStrDB
        dataset = TableStrDB(split='train', transform=transform)

    else:
        raise ValueError('Invalid train dataset {}'.format(p['train_db_name']))

    # Wrap into other dataset (__getitem__ changes)
    if to_augmented_dataset:  # Dataset returns an image and an augmentation of that image.
        from data.custom_dataset import AugmentedDataset
        dataset = AugmentedDataset(dataset)

    if to_neighbors_dataset:  # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        indices = np.load(p['topk_neighbors_train_path'])
        dataset = NeighborsDataset(dataset, indices, p['num_neighbors'])

    return dataset
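A minimal sketch of the two-view augmentation wrapper, with AugmentedPairs as a hypothetical stand-in for AugmentedDataset (it assumes the base dataset yields raw images and that the transform is stochastic):

from torch.utils.data import Dataset

class AugmentedPairs(Dataset):
    """Returns two independently augmented views of the same base image."""
    def __init__(self, base, augment):
        self.base = base        # base dataset yielding untransformed images
        self.augment = augment  # stochastic transform, applied twice per sample

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        img = self.base[i]
        return {'image': self.augment(img), 'image_augmented': self.augment(img)}

Example #4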
def get_train_dataset(p, transform, to_augmented_dataset=False,
                      to_neighbors_dataset=False, split=None):
    # Base dataset
    if p['train_db_name'] == 'cifar-10':
        from data.cifar import CIFAR10
        dataset = CIFAR10(train=True, transform=transform, download=True)

    elif p['train_db_name'] == 'cifar-20':
        from data.cifar import CIFAR20
        dataset = CIFAR20(train=True, transform=transform, download=True)

    elif p['train_db_name'] == 'stl-10':
        from data.stl import STL10
        dataset = STL10(split=split, transform=transform, download=True)

    elif p['train_db_name'] == 'imagenet':
        from data.imagenet import ImageNet
        dataset = ImageNet(split='train', transform=transform)

    elif p['train_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']:
        from data.imagenet import ImageNetSubset
        subset_file = './data/imagenet_subsets/%s.txt' % (p['train_db_name'])
        dataset = ImageNetSubset(subset_file=subset_file, split='train', transform=transform)

    elif p['train_db_name'] == 'celeb-a':
        import torchvision
        from data.celeba import CelebADataset
        # dataset = torchvision.datasets.CelebA(r'E:\datasets\celeb-a', 'train')
        dataset = CelebADataset('train', target_type=p['db_targets'], attr_index=p['attr_index'], transform=transform)

    elif p['train_db_name'] == 'birds-200-2011':
        from data.birds200 import Birds200_2011
        dataset = Birds200_2011(is_train=True, targets_type=p['db_targets'], transform=transform)

    else:
        raise ValueError('Invalid train dataset {}'.format(p['train_db_name']))
    
    # Wrap into other dataset (__getitem__ changes)
    if to_augmented_dataset: # Dataset returns an image and an augmentation of that image.
        from data.custom_dataset import AugmentedDataset
        dataset = AugmentedDataset(dataset)

    if to_neighbors_dataset: # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        indices = np.load(p['topk_neighbors_train_path'])
        dataset = NeighborsDataset(dataset, indices, p['num_neighbors'])
    
    return dataset
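A hypothetical call for the training side; num_neighbors and the .npy path are placeholders:

import torchvision.transforms as transforms

p = {'train_db_name': 'cifar-10',
     'topk_neighbors_train_path': 'results/cifar-10/topk_neighbors_train.npy',  # placeholder
     'num_neighbors': 20}
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor()])
train_dataset = get_train_dataset(p, train_transform, to_augmented_dataset=True)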
Example #5
def get_dataloader(opt, distributed):
    input_size = opt.input_size
    crop_ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
    resize = int(math.ceil(input_size / crop_ratio))
    transform_test = transforms_cv.Compose([
        transforms_cv.Resize((resize, resize)),
        transforms_cv.CenterCrop(input_size),
        transforms_cv.ToTensor(),
        transforms_cv.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_dataset = ImageNet(opt.data_dir, train=False, transform=transform_test)

    sampler = make_data_sampler(val_dataset, False, distributed)
    batch_sampler = data.BatchSampler(sampler=sampler, batch_size=opt.batch_size, drop_last=False)
    val_loader = data.DataLoader(val_dataset, batch_sampler=batch_sampler, num_workers=opt.num_workers)
    return val_loader
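With the defaults, the resize target works out to the conventional 256-to-224 center-crop evaluation pipeline:

import math

input_size, crop_ratio = 224, 0.875               # assumed typical values
resize = int(math.ceil(input_size / crop_ratio))  # 224 / 0.875 = 256

Example #6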
def get_val_dataset(p,
                    transform=None,
                    to_neighbors_dataset=False,
                    to_neighbors_strangers_dataset=False,
                    to_teachers_dataset=False):
    # Base dataset
    if p['val_db_name'] in ['cifar-10', 'cifar-10-d', 'cifar-10-f']:
        from data.cifar import CIFAR10
        dataset = CIFAR10(train=False, transform=transform, download=True)

    elif p['val_db_name'] in ['cifar-20', 'cifar-20-d', 'cifar-20-f']:
        from data.cifar import CIFAR20
        dataset = CIFAR20(train=False, transform=transform, download=False)

    elif p['val_db_name'] in ['stl-10', 'stl-10-d', 'stl-10-f']:
        from data.stl import STL10
        dataset = STL10(split='test', transform=transform, download=False)

    elif ('pascal-pretrained' in p['train_db_name']
          or p['train_db_name'] in ['pascal-large-batches', 'pascal-retrain']):
        from data.pascal_voc import PASCALVOC
        dataset = PASCALVOC(transform=transform)

    elif 'cub' in p['val_db_name']:
        from data.cub import CUB
        dataset = CUB(train=False, transform=transform)

    elif 'imagenet_' in p['val_db_name']:
        from data.imagenet import ImageNetSubset

        # Strip the dataset-variant suffixes to recover the base subset name.
        subset_name = p['val_db_name']
        for suffix in ('-d', '-f', '-0', '-1', '-2'):
            subset_name = subset_name.replace(suffix, '')
        subset_file = './data/imagenet_subsets/%s.txt' % (subset_name)
        dataset = ImageNetSubset(subset_file=subset_file,
                                 split='val',
                                 transform=transform)

    elif 'imagenet' in p['val_db_name']:
        from data.imagenet import ImageNet
        # Note: this branch loads the train split rather than the val split.
        dataset = ImageNet(split='train', transform=transform)

    else:
        raise ValueError('Invalid validation dataset {}'.format(
            p['val_db_name']))

    # Wrap into other dataset (__getitem__ changes)
    if to_neighbors_dataset:  # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        indices = np.load(p['topk_neighbors_val_path'])
        dataset = NeighborsDataset(dataset, indices, 5)  # Only use 5

    if to_neighbors_strangers_dataset:
        from data.custom_dataset import SCANFDataset
        neighbor_indices = np.load(p['topk_neighbors_val_path'])
        stranger_indices = np.load(p['topk_strangers_val_path'])
        dataset = SCANFDataset(dataset, neighbor_indices, stranger_indices,
                               5, 5)  # 5 neighbors, 5 strangers

    if to_teachers_dataset:
        from data.custom_dataset import TeachersDataset
        dataset = TeachersDataset(dataset, p['cluster_preds_path'])

    return dataset
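A hypothetical call exercising the neighbor/stranger wrapper; all .npy paths are placeholders:

p = {'val_db_name': 'cifar-10',
     'topk_neighbors_val_path': 'results/topk_neighbors_val.npy',
     'topk_strangers_val_path': 'results/topk_strangers_val.npy'}
dataset = get_val_dataset(p, to_neighbors_strangers_dataset=True)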
Example #7
def set_loader(opt):
    # construct data loader
    if opt.dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif opt.dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    elif opt.dataset == 'imagenet':
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),  # note: CIFAR-sized crop
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if opt.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=opt.data_folder,
                                         transform=train_transform,
                                         download=True)
        val_dataset = datasets.CIFAR10(root=opt.data_folder,
                                       train=False,
                                       transform=val_transform)
    elif opt.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=opt.data_folder,
                                          transform=train_transform,
                                          download=True)
        val_dataset = datasets.CIFAR100(root=opt.data_folder,
                                        train=False,
                                        transform=val_transform)
    elif opt.dataset == 'imagenet':
        train_dataset = ImageNet(root=opt.data_folder,
                                 split='train',
                                 transform=train_transform)
        val_dataset = ImageNet(root=opt.data_folder,
                               split='val',
                               transform=val_transform)

    else:
        raise ValueError(opt.dataset)

    if opt.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=opt.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=256,
                                             shuffle=False,
                                             num_workers=8,
                                             pin_memory=True)

    return train_loader, val_loader
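A hypothetical driver, assuming an argparse-style options object with the fields set_loader reads:

from types import SimpleNamespace

opt = SimpleNamespace(dataset='cifar10', data_folder='./data',
                      batch_size=128, num_workers=4, distributed=False)
train_loader, val_loader = set_loader(opt)
images, labels = next(iter(train_loader))  # images: [128, 3, 32, 32]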