import numpy as np

# MyPath resolves dataset root directories; the import below assumes the
# repo's usual utils/mypath.py layout.
from utils.mypath import MyPath


def get_val_dataset(p, transform=None, to_neighbors_dataset=False,
                    to_similarity_dataset=False, use_negatives=False,
                    use_simpred=False):
    # Base dataset
    if p['val_db_name'] == 'cifar-10':
        from data.cifar import CIFAR10
        dataset = CIFAR10(train=False, transform=transform, download=True)

    elif p['val_db_name'] == 'cifar-20':
        from data.cifar import CIFAR20
        dataset = CIFAR20(train=False, transform=transform, download=True)

    elif p['val_db_name'] == 'stl-10':
        from data.stl import STL10
        dataset = STL10(split='test', transform=transform, download=True)

    elif p['train_db_name'] in ['impact_kb', 'impact_full_balanced',
                                'impact_full_imbalanced', 'hdi_balanced',
                                'hdi_imbalanced', 'tobacco3482', 'rvl-cdip',
                                'wpi_demo']:
        # Note: these document datasets are looked up via train_db_name
        # rather than val_db_name; only the split differs.
        from data.imagefolderwrapper import ImageFolderWrapper
        root = MyPath.db_root_dir(p['train_db_name'])
        dataset = ImageFolderWrapper(root, split="test", transform=transform)

    elif p['val_db_name'] == 'imagenet':
        from data.imagenet import ImageNet
        dataset = ImageNet(split='val', transform=transform)

    elif p['val_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']:
        from data.imagenet import ImageNetSubset
        subset_file = './data/imagenet_subsets/%s.txt' % (p['val_db_name'])
        dataset = ImageNetSubset(subset_file=subset_file, split='val',
                                 transform=transform)

    else:
        raise ValueError('Invalid validation dataset {}'.format(p['val_db_name']))

    # Wrap into other dataset (__getitem__ changes)
    if to_neighbors_dataset:
        # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        knn_indices = np.load(p['topk_neighbors_val_path'])
        if use_negatives:
            kfn_indices = np.load(p['topk_furthest_val_path'])
        else:
            kfn_indices = None
        dataset = NeighborsDataset(dataset, knn_indices, kfn_indices,
                                   use_simpred, 5, 5)  # Only use 5

    elif to_similarity_dataset:
        # Dataset returns an image and another random image.
        from data.custom_dataset import SimilarityDataset
        dataset = SimilarityDataset(dataset)

    return dataset
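# Illustrative sketch of a NeighborsDataset-style wrapper, to show what the
# wrapping above produces. This is an assumption about the interface, not the
# repo's implementation: it assumes the base dataset's __getitem__ returns a
# dict with 'image' and 'target' keys, and that `indices` is an (N, k) array
# of precomputed nearest-neighbor ids.
from torch.utils.data import Dataset


class NeighborsDatasetSketch(Dataset):
    def __init__(self, dataset, indices, num_neighbors):
        self.dataset = dataset
        # Keep only the top `num_neighbors` ids per sample (e.g. 5 at val time).
        self.indices = indices[:, :num_neighbors]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        anchor = self.dataset[index]
        # Sample one of the anchor's precomputed nearest neighbors.
        neighbor = self.dataset[np.random.choice(self.indices[index])]
        return {'anchor': anchor['image'],
                'neighbor': neighbor['image'],
                'target': anchor['target']}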
def get_val_dataset(p, transform=None, to_neighbors_dataset=False):
    # Base dataset
    if p['val_db_name'] == 'cifar-10':
        from data.cifar import CIFAR10
        dataset = CIFAR10(train=False, transform=transform, download=True)

    elif p['val_db_name'] == 'cifar-20':
        from data.cifar import CIFAR20
        dataset = CIFAR20(train=False, transform=transform, download=True)

    elif p['val_db_name'] == 'stl-10':
        from data.stl import STL10
        dataset = STL10(split='test', transform=transform, download=True)

    elif p['val_db_name'] == 'imagenet':
        from data.imagenet import ImageNet
        dataset = ImageNet(split='val', transform=transform)

    elif p['val_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']:
        from data.imagenet import ImageNetSubset
        subset_file = './data/imagenet_subsets/%s.txt' % (p['val_db_name'])
        dataset = ImageNetSubset(subset_file=subset_file, split='val',
                                 transform=transform)

    else:
        raise ValueError('Invalid validation dataset {}'.format(p['val_db_name']))

    # Wrap into other dataset (__getitem__ changes)
    if to_neighbors_dataset:
        # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        indices = np.load(p['topk_neighbors_val_path'])
        dataset = NeighborsDataset(dataset, indices, 5)  # Only use 5

    return dataset
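# Example usage -- a minimal sketch. The config keys are the ones read by
# get_val_dataset above; the .npy path is a placeholder, not a real repo path.
from torchvision import transforms

p_example = {
    'val_db_name': 'cifar-10',
    'topk_neighbors_val_path': 'results/cifar-10/topk-val-neighbors.npy',  # placeholder
}
val_transform = transforms.Compose([transforms.ToTensor()])
plain_val = get_val_dataset(p_example, transform=val_transform)
neighbors_val = get_val_dataset(p_example, transform=val_transform,
                                to_neighbors_dataset=True)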
def get_train_dataset(p, transform, to_augmented_dataset=False,
                      to_neighbors_dataset=False, split=None):
    # Base dataset
    if p['train_db_name'] == 'partnet':
        from data.partnet import PARTNET
        if p['train_type_name'] in ['chair', 'table', 'bed', 'bag']:
            dataset = PARTNET(split='train', type=p['train_type_name'],
                              transform=transform)
        else:
            # Fail fast on an unknown part category instead of leaving
            # `dataset` unbound.
            raise ValueError('Invalid train type {}'.format(p['train_type_name']))
    else:
        raise ValueError('Invalid train dataset {}'.format(p['train_db_name']))

    # Wrap into other dataset (__getitem__ changes)
    if to_augmented_dataset:
        # Dataset returns an image and an augmentation of that image.
        from data.custom_dataset import AugmentedDataset
        dataset = AugmentedDataset(dataset)

    if to_neighbors_dataset:
        # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        indices = np.load(p['topk_neighbors_train_path'])
        dataset = NeighborsDataset(dataset, indices, p['num_neighbors'])

    return dataset
def get_train_dataset(p, transform, to_augmented_dataset=False,
                      to_neighbors_dataset=False, split=None):
    # Base dataset
    if p['train_db_name'] == 'cifar-10':
        from data.cifar import CIFAR10
        dataset = CIFAR10(train=True, transform=transform, download=True)

    elif p['train_db_name'] == 'cifar-20':
        from data.cifar import CIFAR20
        dataset = CIFAR20(train=True, transform=transform, download=True)

    elif p['train_db_name'] == 'stl-10':
        from data.stl import STL10
        dataset = STL10(split=split, transform=transform, download=True)

    elif p['train_db_name'] == 'imagenet':
        from data.imagenet import ImageNet
        dataset = ImageNet(split='train', transform=transform)

    elif p['train_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']:
        from data.imagenet import ImageNetSubset
        subset_file = './data/imagenet_subsets/%s.txt' % (p['train_db_name'])
        dataset = ImageNetSubset(subset_file=subset_file, split='train',
                                 transform=transform)

    # Added by Johan
    elif p['train_db_name'] == 'tabledb':
        from data.tabledb import TableDB
        dataset = TableDB(split='train', transform=transform)

    # Added by Johan
    elif p['train_db_name'] == 'tablestrdb':
        from data.tablestrdb import TableStrDB
        dataset = TableStrDB(split='train', transform=transform)

    else:
        raise ValueError('Invalid train dataset {}'.format(p['train_db_name']))

    # Wrap into other dataset (__getitem__ changes)
    if to_augmented_dataset:
        # Dataset returns an image and an augmentation of that image.
        from data.custom_dataset import AugmentedDataset
        dataset = AugmentedDataset(dataset)

    if to_neighbors_dataset:
        # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        indices = np.load(p['topk_neighbors_train_path'])
        dataset = NeighborsDataset(dataset, indices, p['num_neighbors'])

    return dataset
def get_train_dataset(p, transform, to_augmented_dataset=False,
                      to_neighbors_dataset=False, split=None):
    # Base dataset
    if p['train_db_name'] == 'cifar-10':
        from data.cifar import CIFAR10
        dataset = CIFAR10(train=True, transform=transform, download=True)

    elif p['train_db_name'] == 'cifar-20':
        from data.cifar import CIFAR20
        dataset = CIFAR20(train=True, transform=transform, download=True)

    elif p['train_db_name'] == 'stl-10':
        from data.stl import STL10
        dataset = STL10(split=split, transform=transform, download=True)

    elif p['train_db_name'] == 'imagenet':
        from data.imagenet import ImageNet
        dataset = ImageNet(split='train', transform=transform)

    elif p['train_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']:
        from data.imagenet import ImageNetSubset
        subset_file = './data/imagenet_subsets/%s.txt' % (p['train_db_name'])
        dataset = ImageNetSubset(subset_file=subset_file, split='train',
                                 transform=transform)

    elif p['train_db_name'] == 'celeb-a':
        import torchvision
        from data.celeba import CelebADataset
        # dataset = torchvision.datasets.CelebA(r'E:\datasets\celeb-a', 'train')
        dataset = CelebADataset('train', target_type=p['db_targets'],
                                attr_index=p['attr_index'], transform=transform)

    elif p['train_db_name'] == 'birds-200-2011':
        from data.birds200 import Birds200_2011
        dataset = Birds200_2011(is_train=True, targets_type=p['db_targets'],
                                transform=transform)

    else:
        raise ValueError('Invalid train dataset {}'.format(p['train_db_name']))

    # Wrap into other dataset (__getitem__ changes)
    if to_augmented_dataset:
        # Dataset returns an image and an augmentation of that image.
        from data.custom_dataset import AugmentedDataset
        dataset = AugmentedDataset(dataset)

    if to_neighbors_dataset:
        # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        indices = np.load(p['topk_neighbors_train_path'])
        dataset = NeighborsDataset(dataset, indices, p['num_neighbors'])

    return dataset
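# Train-side usage -- a sketch along the same lines. num_neighbors and the
# .npy path are illustrative values, not taken from a real config.
p_train = {
    'train_db_name': 'cifar-10',
    'topk_neighbors_train_path': 'results/cifar-10/topk-train-neighbors.npy',  # placeholder
    'num_neighbors': 20,
}
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                      transforms.ToTensor()])
# Pretext-style training: each sample yields an image and an augmented view.
pretext_set = get_train_dataset(p_train, train_transform, to_augmented_dataset=True)
# Clustering-style training: each sample yields an image and a mined neighbor.
cluster_set = get_train_dataset(p_train, train_transform, to_neighbors_dataset=True)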
def get_val_dataset(p, transform=None, to_neighbors_dataset=False,
                    to_neighbors_strangers_dataset=False,
                    to_teachers_dataset=False):
    # Base dataset
    if p['val_db_name'] in ['cifar-10', 'cifar-10-d', 'cifar-10-f']:
        from data.cifar import CIFAR10
        dataset = CIFAR10(train=False, transform=transform, download=True)

    elif p['val_db_name'] in ['cifar-20', 'cifar-20-d', 'cifar-20-f']:
        from data.cifar import CIFAR20
        dataset = CIFAR20(train=False, transform=transform, download=False)

    elif p['val_db_name'] in ['stl-10', 'stl-10-d', 'stl-10-f']:
        from data.stl import STL10
        dataset = STL10(split='test', transform=transform, download=False)

    elif ('pascal-pretrained' in p['train_db_name']
          or p['train_db_name'] == 'pascal-large-batches'
          or p['train_db_name'] == 'pascal-retrain'):
        # Note: the Pascal variants are looked up via train_db_name.
        from data.pascal_voc import PASCALVOC
        dataset = PASCALVOC(transform=transform)

    elif 'cub' in p['val_db_name']:
        from data.cub import CUB
        dataset = CUB(train=False, transform=transform)

    elif 'imagenet_' in p['val_db_name']:
        from data.imagenet import ImageNetSubset
        # Strip the '-d'/'-f'/'-0'/'-1'/'-2' variant suffixes to recover the
        # subset name.
        subset_name = (p['val_db_name'].replace('-d', '').replace('-f', '')
                       .replace('-0', '').replace('-1', '').replace('-2', ''))
        subset_file = './data/imagenet_subsets/%s.txt' % (subset_name)
        dataset = ImageNetSubset(subset_file=subset_file, split='val',
                                 transform=transform)

    elif 'imagenet' in p['val_db_name']:
        from data.imagenet import ImageNet
        dataset = ImageNet(split='val', transform=transform)

    else:
        raise ValueError('Invalid validation dataset {}'.format(p['val_db_name']))

    # Wrap into other dataset (__getitem__ changes)
    if to_neighbors_dataset:
        # Dataset returns an image and one of its nearest neighbors.
        from data.custom_dataset import NeighborsDataset
        indices = np.load(p['topk_neighbors_val_path'])
        dataset = NeighborsDataset(dataset, indices, 5)  # Only use 5

    if to_neighbors_strangers_dataset:
        # Dataset returns an image, a nearest neighbor and a "stranger"
        # (a far-away sample).
        from data.custom_dataset import SCANFDataset
        neighbor_indices = np.load(p['topk_neighbors_val_path'])
        stranger_indices = np.load(p['topk_strangers_val_path'])
        dataset = SCANFDataset(dataset, neighbor_indices, stranger_indices, 5, 5)

    if to_teachers_dataset:
        # Dataset returns an image together with a precomputed cluster
        # prediction loaded from disk.
        from data.custom_dataset import TeachersDataset
        dataset = TeachersDataset(dataset, p['cluster_preds_path'])

    return dataset
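# Illustrative sketch of a TeachersDataset-style wrapper -- an assumed
# interface, not the repo's implementation. It assumes cluster_preds_path
# points to an .npy file with one precomputed cluster prediction per sample,
# and that the base dataset returns dict samples.
import torch


class TeachersDatasetSketch(Dataset):
    def __init__(self, dataset, cluster_preds_path):
        self.dataset = dataset
        # One predicted cluster id (or soft assignment) per sample.
        self.cluster_preds = np.load(cluster_preds_path)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        # Attach the frozen "teacher" prediction alongside the usual fields.
        sample['teacher'] = torch.as_tensor(self.cluster_preds[index])
        return sample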