def get_dataloader(batch_size_train=100, batch_size_test=200, data_dir='../data',
                   base='CIFAR100', num_classes=100, train=True, download=False,
                   val_ratio=0.01, num_workers=4, pin_memory=False, seed=1,
                   **dataset_kwargs):
    """Build DataLoaders over the full label set of a classification dataset.

    Args:
        batch_size_train: batch size for the training loader.
        batch_size_test: batch size for the validation/test loader.
        data_dir: root directory passed through to ``get_dataset``.
        base: dataset name passed through to ``get_dataset``.
        num_classes: number of classes; also used to set ``loader.classes``.
        train: if True, return ``(train_loader, val_loader)`` built from a
            seeded shuffle of the training split; otherwise return a single
            sequential test loader.
        download: forwarded to ``get_dataset``.
        val_ratio: fraction of the training set held out for validation.
        num_workers, pin_memory: forwarded to ``DataLoader``.
        seed: seed for the train/val split shuffle.
        **dataset_kwargs: extra keyword arguments for ``get_dataset``.

    Returns:
        ``(train_loader, val_loader)`` when ``train`` is True, else a single
        test ``DataLoader``. Every returned loader gets a ``classes``
        attribute listing the integer class labels it covers.
    """
    dataset = get_dataset(data_dir=data_dir, base=base, num_classes=num_classes,
                          train=train, download=download, **dataset_kwargs)
    if train:
        # Use a local RandomState instead of seeding the global NumPy RNG:
        # np.random.seed() here would silently reset randomness for the whole
        # process. RandomState(seed).shuffle yields the same permutation as
        # the legacy global seed+shuffle, so the split is unchanged.
        rng = np.random.RandomState(seed)
        data_idxs = np.arange(len(dataset))
        rng.shuffle(data_idxs)
        val_size = int(len(dataset) * val_ratio)
        # First val_size shuffled indices go to validation, the rest to train.
        train_sampler = SubsetRandomSampler(data_idxs[val_size:])
        val_sampler = SubsetRandomSampler(data_idxs[:val_size])
        train_loader = DataLoader(dataset, batch_size=batch_size_train,
                                  sampler=train_sampler,
                                  num_workers=num_workers,
                                  pin_memory=pin_memory)
        val_loader = DataLoader(dataset, batch_size=batch_size_test,
                                sampler=val_sampler,
                                num_workers=num_workers,
                                pin_memory=pin_memory)
        train_loader.classes = val_loader.classes = list(range(num_classes))
        return train_loader, val_loader
    else:
        test_loader = DataLoader(dataset, batch_size=batch_size_test,
                                 shuffle=False, num_workers=num_workers,
                                 pin_memory=pin_memory)
        test_loader.classes = list(range(num_classes))
        return test_loader
def get_dataloader_incr(batch_size_train=100, batch_size_test=200,
                        data_dir='../data', base='CIFAR100', num_classes=100,
                        train=True, download=False, val_ratio=0.01,
                        num_workers=4, pin_memory=False, seed=1,
                        classes_per_exposure=10, exposure_class_splits=None,
                        scale_batch_size=False, val_idxs_path=None,
                        train_idxs_path=None, **dataset_kwargs):
    """Build per-exposure DataLoaders for incremental (class-incremental) learning.

    Each "exposure" covers a subset of the class labels; one loader (or
    train/val loader pair) is built per exposure.

    Args:
        batch_size_train, batch_size_test: loader batch sizes.
        data_dir, base, num_classes, download, **dataset_kwargs: forwarded to
            ``get_dataset``.
        train: if True return ``(train_loaders, val_loaders)`` lists;
            otherwise return a list of test loaders.
        val_ratio: per-class fraction held out for validation (only used when
            no precomputed index files are given).
        num_workers, pin_memory: forwarded to ``DataLoader``.
        seed: seed for the per-class shuffles of the train/val split.
        classes_per_exposure: class-count per exposure when
            ``exposure_class_splits`` is not given; must divide ``num_classes``.
        exposure_class_splits: explicit list of class-label lists, one per
            exposure; overrides ``classes_per_exposure``.
        scale_batch_size: divide ``batch_size_train`` by the number of
            exposures (for concurrent loading across loaders).
        val_idxs_path, train_idxs_path: optional .npy files with precomputed
            train/val indices; both must be given together.

    Returns:
        ``(train_loaders, val_loaders)`` when ``train`` is True, else a list
        of test loaders. Each loader gets a ``classes`` attribute listing the
        class labels of its exposure.
    """
    train_val_split = None
    if val_idxs_path is not None:
        assert train_idxs_path is not None, 'must specify both val and train indices'
        train_val_split = np.load(train_idxs_path), np.load(val_idxs_path)
    dataset = get_dataset(data_dir=data_dir, base=base, num_classes=num_classes,
                          train=train, download=download,
                          train_val_split=train_val_split, **dataset_kwargs)
    if exposure_class_splits is None:
        assert num_classes % classes_per_exposure == 0, \
            "specified classes per exposure (%d) does not evenly divide " \
            "specified number of classes (%d)" % (classes_per_exposure, num_classes)
        exposure_class_splits = [
            list(range(c, c + classes_per_exposure))
            for c in range(0, num_classes, classes_per_exposure)
        ]
    if scale_batch_size:
        # Scale down batch size by the total loader count when data will be
        # loaded across loaders concurrently. Clamp to 1 so many exposures
        # cannot silently produce a zero batch size.
        batch_size_train = max(1, batch_size_train // len(exposure_class_splits))
    targets = np.array(dataset.targets)
    if train:
        train_loaders = []
        val_loaders = []
        # Local RandomState instead of np.random.seed(): avoids clobbering the
        # global NumPy RNG while producing the identical shuffle sequence.
        rng = np.random.RandomState(seed)
        for classes in exposure_class_splits:
            if val_idxs_path is not None:
                # Precomputed split: select this exposure's classes from the
                # stored train/val index arrays.
                val_idxs = dataset.val_indices
                train_idxs = dataset.train_indices
                val_idxs_by_class = []
                train_idxs_by_class = []
                for c in classes:
                    val_mask = targets[val_idxs] == c
                    train_mask = targets[train_idxs] == c
                    val_idxs_by_class += [val_idxs[val_mask]]
                    train_idxs_by_class += [train_idxs[train_mask]]
                train_sampler = SubsetRandomSampler(
                    np.concatenate(train_idxs_by_class))
                val_sampler = SubsetRandomSampler(
                    np.concatenate(val_idxs_by_class))
            else:
                # No precomputed split: shuffle each class's indices and hold
                # out the first val_ratio fraction per class for validation.
                idxs_by_class = []
                val_sizes = []
                for c in classes:
                    c_idxs = np.where(targets == c)[0]
                    rng.shuffle(c_idxs)
                    idxs_by_class += [c_idxs]
                    val_sizes += [int(len(c_idxs) * val_ratio)]
                train_sampler = SubsetRandomSampler(
                    np.concatenate([
                        c_idxs[val_size:]
                        for c_idxs, val_size in zip(idxs_by_class, val_sizes)
                    ]))
                val_sampler = SubsetRandomSampler(
                    np.concatenate([
                        c_idxs[:val_size]
                        for c_idxs, val_size in zip(idxs_by_class, val_sizes)
                    ]))
            train_loader = DataLoader(dataset, batch_size=batch_size_train,
                                      sampler=train_sampler,
                                      num_workers=num_workers,
                                      pin_memory=pin_memory)
            val_loader = DataLoader(dataset, batch_size=batch_size_test,
                                    sampler=val_sampler,
                                    num_workers=num_workers,
                                    pin_memory=pin_memory)
            train_loader.classes = val_loader.classes = classes
            train_loaders += [train_loader]
            val_loaders += [val_loader]
        return train_loaders, val_loaders
    else:
        test_loaders = []
        for classes in exposure_class_splits:
            exposure_idxs = np.concatenate(
                [np.where(targets == c)[0] for c in classes])
            exposure_loader = DataLoader(
                dataset,
                batch_size=batch_size_test,
                sampler=SubsetRandomSampler(exposure_idxs),
                num_workers=num_workers,
                pin_memory=pin_memory)
            exposure_loader.classes = classes
            test_loaders += [exposure_loader]
        return test_loaders