Esempio n. 1
0
def get_val_dataset(p,
                    transform=None,
                    to_neighbors_dataset=False,
                    to_similarity_dataset=False,
                    use_negatives=False,
                    use_simpred=False):
    """Build the validation dataset described by the config mapping ``p``.

    Args:
        p: Config mapping. ``p['val_db_name']`` selects the base dataset;
            ``p['train_db_name']`` is consulted for the image-folder
            datasets. ``p['topk_neighbors_val_path']`` (and
            ``p['topk_furthest_val_path']`` when ``use_negatives`` is set)
            are read only when ``to_neighbors_dataset`` is true.
        transform: Passed through to the base dataset constructor.
        to_neighbors_dataset: Wrap the base dataset in ``NeighborsDataset``.
        to_similarity_dataset: Wrap the base dataset in
            ``SimilarityDataset``; has no effect when
            ``to_neighbors_dataset`` is true.
        use_negatives: Also load the indices stored at
            ``p['topk_furthest_val_path']`` and hand them to the wrapper
            (``None`` is passed otherwise).
        use_simpred: Passed through to ``NeighborsDataset``.

    Returns:
        The base dataset, wrapped if one of the wrapper flags is set.

    Raises:
        ValueError: If no branch matches the configured dataset name.
    """
    val_name = p['val_db_name']
    folder_datasets = (
        'impact_kb', 'impact_full_balanced', 'impact_full_imbalanced',
        'hdi_balanced', 'hdi_imbalanced', 'tobacco3482', 'rvl-cdip',
        'wpi_demo',
    )

    # Select the base dataset. Each dataset module is imported inside its own
    # branch, so only the selected one has to be importable.
    if val_name == 'cifar-10':
        from data.cifar import CIFAR10
        base = CIFAR10(train=False, transform=transform, download=True)

    elif val_name == 'cifar-20':
        from data.cifar import CIFAR20
        base = CIFAR20(train=False, transform=transform, download=True)

    elif val_name == 'stl-10':
        from data.stl import STL10
        base = STL10(split='test', transform=transform, download=True)

    elif p['train_db_name'] in folder_datasets:
        # NOTE(review): this branch is selected by 'train_db_name', not
        # 'val_db_name', and is tested before the ImageNet branches --
        # confirm that this is intended.
        from data.imagefolderwrapper import ImageFolderWrapper
        base = ImageFolderWrapper(MyPath.db_root_dir(p['train_db_name']),
                                  split="test",
                                  transform=transform)

    elif val_name == 'imagenet':
        from data.imagenet import ImageNet
        base = ImageNet(split='val', transform=transform)

    elif val_name in ('imagenet_50', 'imagenet_100', 'imagenet_200'):
        from data.imagenet import ImageNetSubset
        base = ImageNetSubset(
            subset_file='./data/imagenet_subsets/%s.txt' % val_name,
            split='val',
            transform=transform)

    else:
        raise ValueError('Invalid validation dataset {}'.format(val_name))

    # Optional wrappers (they change what __getitem__ returns).
    if to_neighbors_dataset:
        # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        nearest = np.load(p['topk_neighbors_val_path'])
        furthest = (np.load(p['topk_furthest_val_path'])
                    if use_negatives else None)
        # "Only use 5": presumably the two trailing 5s cap how many nearest /
        # furthest neighbors are used -- TODO confirm against NeighborsDataset.
        return NeighborsDataset(base, nearest, furthest, use_simpred, 5, 5)

    if to_similarity_dataset:
        # Dataset returns an image and another random image.
        from data.custom_dataset import SimilarityDataset
        return SimilarityDataset(base)

    return base
Esempio n. 2
0
def get_val_dataset(p, transform=None, to_neighbors_dataset=False):
    """Return the validation dataset named by ``p['val_db_name']``.

    Args:
        p: Config mapping. ``p['val_db_name']`` selects the dataset;
            ``p['topk_neighbors_val_path']`` is read only when
            ``to_neighbors_dataset`` is true.
        transform: Passed through to the dataset constructor.
        to_neighbors_dataset: Wrap the result in ``NeighborsDataset``.

    Returns:
        The selected dataset, optionally wrapped in ``NeighborsDataset``.

    Raises:
        ValueError: If ``p['val_db_name']`` is not a supported name.
    """
    name = p['val_db_name']

    # Dataset modules are imported inside their branch so that only the
    # selected one has to be importable.
    if name == 'cifar-10':
        from data.cifar import CIFAR10
        base = CIFAR10(train=False, transform=transform, download=True)
    elif name == 'cifar-20':
        from data.cifar import CIFAR20
        base = CIFAR20(train=False, transform=transform, download=True)
    elif name == 'stl-10':
        from data.stl import STL10
        base = STL10(split='test', transform=transform, download=True)
    elif name == 'imagenet':
        from data.imagenet import ImageNet
        base = ImageNet(split='val', transform=transform)
    elif name in ('imagenet_50', 'imagenet_100', 'imagenet_200'):
        from data.imagenet import ImageNetSubset
        base = ImageNetSubset(
            subset_file='./data/imagenet_subsets/%s.txt' % name,
            split='val',
            transform=transform)
    else:
        raise ValueError('Invalid validation dataset {}'.format(name))

    if not to_neighbors_dataset:
        return base

    # Wrapped dataset returns an image and one of its nearest neighbors.
    from data.custom_dataset import NeighborsDataset
    neighbor_indices = np.load(p['topk_neighbors_val_path'])
    # "Only use 5": presumably the neighbor count -- TODO confirm against
    # NeighborsDataset.
    return NeighborsDataset(base, neighbor_indices, 5)
Esempio n. 3
0
def get_train_dataset(p,
                      transform,
                      to_augmented_dataset=False,
                      to_neighbors_dataset=False,
                      split=None):
    """Return the training dataset named by ``p['train_db_name']``.

    Args:
        p: Config mapping. ``p['train_db_name']`` selects the dataset;
            ``p['topk_neighbors_train_path']`` and ``p['num_neighbors']``
            are read only when ``to_neighbors_dataset`` is true.
        transform: Passed through to the dataset constructor.
        to_augmented_dataset: Wrap the dataset in ``AugmentedDataset``.
        to_neighbors_dataset: Wrap the dataset in ``NeighborsDataset``
            (applied after the augmented wrapper when both are set).
        split: Passed to ``STL10`` only; the other branches ignore it.

    Returns:
        The selected dataset with the requested wrappers applied.

    Raises:
        ValueError: If ``p['train_db_name']`` is not a supported name.
    """
    name = p['train_db_name']

    # Dataset modules are imported inside their branch so that only the
    # selected one has to be importable.
    if name == 'cifar-10':
        from data.cifar import CIFAR10
        result = CIFAR10(train=True, transform=transform, download=True)
    elif name == 'cifar-20':
        from data.cifar import CIFAR20
        result = CIFAR20(train=True, transform=transform, download=True)
    elif name == 'stl-10':
        from data.stl import STL10
        result = STL10(split=split, transform=transform, download=True)
    elif name == 'imagenet':
        from data.imagenet import ImageNet
        result = ImageNet(split='train', transform=transform)
    elif name in ('imagenet_50', 'imagenet_100', 'imagenet_200'):
        from data.imagenet import ImageNetSubset
        result = ImageNetSubset(
            subset_file='./data/imagenet_subsets/%s.txt' % name,
            split='train',
            transform=transform)
    elif name == 'tabledb':
        # Added by Johan
        from data.tabledb import TableDB
        result = TableDB(split='train', transform=transform)
    elif name == 'tablestrdb':
        # Added by Johan
        from data.tablestrdb import TableStrDB
        result = TableStrDB(split='train', transform=transform)
    else:
        raise ValueError('Invalid train dataset {}'.format(name))

    # The wrappers below stack: each one wraps whatever came before it.
    if to_augmented_dataset:
        # Dataset returns an image and an augmentation of that image.
        from data.custom_dataset import AugmentedDataset
        result = AugmentedDataset(result)

    if to_neighbors_dataset:
        # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        neighbor_indices = np.load(p['topk_neighbors_train_path'])
        result = NeighborsDataset(result, neighbor_indices,
                                  p['num_neighbors'])

    return result
def get_train_dataset(p, transform, to_augmented_dataset=False,
                        to_neighbors_dataset=False, split=None):
    """Return the training dataset named by ``p['train_db_name']``.

    Args:
        p: Config mapping. ``p['train_db_name']`` selects the dataset.
            ``p['db_targets']`` is read by the 'celeb-a' and
            'birds-200-2011' branches and ``p['attr_index']`` by the
            'celeb-a' branch. ``p['topk_neighbors_train_path']`` and
            ``p['num_neighbors']`` are read only when
            ``to_neighbors_dataset`` is true.
        transform: Passed through to the dataset constructor.
        to_augmented_dataset: Wrap the dataset in ``AugmentedDataset``.
        to_neighbors_dataset: Wrap the dataset in ``NeighborsDataset``
            (applied after the augmented wrapper when both are set).
        split: Passed to ``STL10`` only; the other branches ignore it.

    Returns:
        The selected dataset with the requested wrappers applied.

    Raises:
        ValueError: If ``p['train_db_name']`` is not a supported name.
    """
    # Base dataset. Dataset modules are imported inside their branch so that
    # only the selected one has to be importable.
    if p['train_db_name'] == 'cifar-10':
        from data.cifar import CIFAR10
        dataset = CIFAR10(train=True, transform=transform, download=True)

    elif p['train_db_name'] == 'cifar-20':
        from data.cifar import CIFAR20
        dataset = CIFAR20(train=True, transform=transform, download=True)

    elif p['train_db_name'] == 'stl-10':
        from data.stl import STL10
        dataset = STL10(split=split, transform=transform, download=True)

    elif p['train_db_name'] == 'imagenet':
        from data.imagenet import ImageNet
        dataset = ImageNet(split='train', transform=transform)

    elif p['train_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']:
        from data.imagenet import ImageNetSubset
        subset_file = './data/imagenet_subsets/%s.txt' %(p['train_db_name'])
        dataset = ImageNetSubset(subset_file=subset_file, split='train', transform=transform)

    elif p['train_db_name'] == 'celeb-a':
        # The project-local CelebA wrapper is the only thing this branch
        # uses, so it no longer imports torchvision itself.
        from data.celeba import CelebADataset
        dataset = CelebADataset('train', target_type=p['db_targets'], attr_index=p['attr_index'], transform=transform)

    elif p['train_db_name'] == 'birds-200-2011':
        from data.birds200 import Birds200_2011
        dataset = Birds200_2011(is_train=True, targets_type=p['db_targets'], transform=transform)

    else:
        raise ValueError('Invalid train dataset {}'.format(p['train_db_name']))

    # Wrap into other dataset (__getitem__ changes). The wrappers stack.
    if to_augmented_dataset: # Dataset returns an image and an augmentation of that image.
        from data.custom_dataset import AugmentedDataset
        dataset = AugmentedDataset(dataset)

    if to_neighbors_dataset: # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        indices = np.load(p['topk_neighbors_train_path'])
        dataset = NeighborsDataset(dataset, indices, p['num_neighbors'])

    return dataset
def get_val_dataset(p,
                    transform=None,
                    to_neighbors_dataset=False,
                    to_neighbors_strangers_dataset=False,
                    to_teachers_dataset=False):
    """Build the validation dataset described by the config mapping ``p``.

    Args:
        p: Config mapping. ``p['val_db_name']`` selects the base dataset;
            ``p['train_db_name']`` is consulted for the PASCAL branch.
            ``p['topk_neighbors_val_path']``, ``p['topk_strangers_val_path']``
            and ``p['cluster_preds_path']`` are read only by the wrappers
            that need them.
        transform: Passed through to the base dataset constructor.
        to_neighbors_dataset: Wrap the dataset in ``NeighborsDataset``.
        to_neighbors_strangers_dataset: Wrap the dataset in ``SCANFDataset``.
        to_teachers_dataset: Wrap the dataset in ``TeachersDataset``.

    Returns:
        The base dataset with every requested wrapper applied, in the order
        neighbors, neighbors/strangers, teachers.

    Raises:
        ValueError: If no branch matches the configured dataset name.
    """
    val_name = p['val_db_name']

    # Select the base dataset. Each dataset module is imported inside its own
    # branch, so only the selected one has to be importable.
    if val_name in ('cifar-10', 'cifar-10-d', 'cifar-10-f'):
        from data.cifar import CIFAR10
        base = CIFAR10(train=False, transform=transform, download=True)

    elif val_name in ('cifar-20', 'cifar-20-d', 'cifar-20-f'):
        from data.cifar import CIFAR20
        base = CIFAR20(train=False, transform=transform, download=False)

    elif val_name in ('stl-10', 'stl-10-d', 'stl-10-f'):
        from data.stl import STL10
        base = STL10(split='test', transform=transform, download=False)

    elif ('pascal-pretrained' in p['train_db_name']
          or p['train_db_name'] in ('pascal-large-batches',
                                    'pascal-retrain')):
        # NOTE(review): this branch is selected by 'train_db_name', not
        # 'val_db_name' -- confirm that this is intended.
        from data.pascal_voc import PASCALVOC
        base = PASCALVOC(transform=transform)

    elif 'cub' in val_name:
        from data.cub import CUB
        base = CUB(train=False, transform=transform)

    elif 'imagenet_' in val_name:
        from data.imagenet import ImageNetSubset

        # Remove the variant tags from the name to get the subset file stem.
        subset_name = val_name
        for tag in ('-d', '-f', '-0', '-1', '-2'):
            subset_name = subset_name.replace(tag, '')
        base = ImageNetSubset(
            subset_file='./data/imagenet_subsets/%s.txt' % subset_name,
            split='val',
            transform=transform)

    elif 'imagenet' in val_name:
        # NOTE(review): the validation loader requests split='train' here --
        # confirm that this is intended.
        from data.imagenet import ImageNet
        base = ImageNet(split='train', transform=transform)

    else:
        raise ValueError('Invalid validation dataset {}'.format(val_name))

    # Optional wrappers (they change what __getitem__ returns). They are
    # independent `if`s, so several can stack on top of each other.
    wrapped = base

    if to_neighbors_dataset:
        # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        # "Only use 5": presumably the neighbor count -- TODO confirm.
        wrapped = NeighborsDataset(wrapped,
                                   np.load(p['topk_neighbors_val_path']), 5)

    if to_neighbors_strangers_dataset:
        from data.custom_dataset import SCANFDataset
        wrapped = SCANFDataset(wrapped,
                               np.load(p['topk_neighbors_val_path']),
                               np.load(p['topk_strangers_val_path']),
                               5, 5)

    if to_teachers_dataset:
        from data.custom_dataset import TeachersDataset
        wrapped = TeachersDataset(wrapped, p['cluster_preds_path'])

    return wrapped