def __init__(self, data_path: str = "", split: str = "train", download: bool = True, seed: int = 1):
    """Assemble a six-task dataset sequence and forward it to the parent class.

    Tasks, in order: CIFAR10, MNIST, DTD, FashionMNIST, SVHN and CIFAR10 again.

    Args:
        data_path: Root directory forwarded to every dataset constructor.
        split: One of "train", "val" or "test". "train" and "val" both build
            the datasets with ``train=True``; "test" builds them with
            ``train=False``.
        download: Forwarded to every dataset constructor.
        seed: Forwarded to the parent constructor.

    Raises:
        ValueError: If ``split`` is not "train", "val" or "test".
    """
    allowed_splits = ("train", "val", "test")
    if split not in allowed_splits:
        raise ValueError(
            f"Split must be train, val, or test; not {split}.")
    use_train_data = split in ("train", "val")

    # One dataset per task; CIFAR10 is deliberately used for both the first
    # and the last task.
    task_dataset_classes = (CIFAR10, MNIST, DTD, FashionMNIST, SVHN, CIFAR10)
    datasets = [
        dataset_cls(data_path=data_path, train=use_train_data, download=download)
        for dataset_cls in task_dataset_classes
    ]

    # Per-task sizes for the train/val splits (the values look like sample
    # counts despite the parameter name — confirm in the parent class).
    # The "test" split has no entry, so it gets None.
    per_task_budget = {
        "train": [4000, 400, 400, 400, 400, 400],
        "val": [2000, 200, 200, 200, 200, 200],
    }
    proportions = per_task_budget.get(split)

    # NOTE(review): class_counter looks like a per-task label offset (gaps of
    # 10, 10, 47, 10, 10 would match the class counts of CIFAR10, MNIST, DTD,
    # FashionMNIST, SVHN). If so, the final CIFAR10 task (offset 87) gets
    # labels disjoint from the first one (offset 0) — confirm in the parent.
    super().__init__(
        datasets=datasets,
        proportions=proportions,
        class_counter=[0, 10, 20, 67, 77, 87],
        seed=seed,
        split=split,
    )
def get_permuted_CIFAR10(path, batch_size, train):
    """Build a class-incremental CIFAR10 scenario whose images are pixel-permuted.

    Each task adds 2 classes (``increment=2``). The permutation and
    normalization transforms are passed to the scenario, so every sample
    drawn from it is transformed. Previously they were built and then
    discarded, so the returned scenario was not permuted at all.

    Args:
        path: Directory holding (or receiving, since ``download=True`` is
            passed) the CIFAR10 data. If empty or None, falls back to
            "./src/data/CIFAR10", the previously hard-coded location.
        batch_size: Unused. Kept so existing callers' signatures still match.
        train: Truthy for the training split, falsy for the test split.

    Returns:
        The ``ClassIncremental`` scenario.
    """
    im_width = im_height = 32
    rand_perm = RandomPermutation(0, 0, im_width, im_height)
    # NOTE(review): (0.1307,)/(0.3081,) are the usual MNIST statistics and are
    # broadcast over CIFAR10's three channels. Confirm whether per-channel
    # CIFAR10 statistics were intended.
    normalization = transforms.Normalize((0.1307,), (0.3081,))
    # TODO: applying RandomPermutation per sample slows data loading by a
    # factor > 6; consider permuting whole batches instead.
    # Order matters: the permutation and normalization operate on tensors,
    # so ToTensor must run first.
    transformations = [transforms.ToTensor(), rand_perm, normalization]

    scenario = ClassIncremental(
        CIFAR10(
            data_path=path or "./src/data/CIFAR10",
            download=True,
            # The original branched on truthiness and passed a literal bool.
            train=bool(train),
        ),
        increment=2,
        transformations=transformations,
    )
    return scenario
def test_background_tranformation():
    """
    Example code using TransformationIncremental to create a setting with 3 tasks.

    Task ``i`` applies ToTensor -> BackgroundSwap(bg_label=i) -> ToPILImage
    to MNIST training images, using CIFAR10 training data as the background
    source. For each task the test saves a sample plot and draws one
    DataLoader batch.
    """
    cifar = CIFAR10(DATA_PATH, train=True)
    # download=False: relies on MNIST already being present under DATA_PATH
    # (e.g. fetched by another test) — this test fails on its own otherwise.
    mnist = MNIST(DATA_PATH, download=False, train=True)
    nb_task = 3
    # One transformation pipeline per task; only bg_label differs between
    # tasks (presumably selecting which CIFAR10 background is used).
    list_trsf = [
        [
            torchvision.transforms.ToTensor(),
            BackgroundSwap(cifar, bg_label=i, input_dim=(28, 28)),
            torchvision.transforms.ToPILImage(),
        ]
        for i in range(nb_task)
    ]
    scenario = TransformationIncremental(
        mnist,
        base_transformations=[torchvision.transforms.ToTensor()],
        incremental_transformations=list_trsf,
    )

    folder = "tests/samples/background_trsf/"
    # exist_ok replaces the exists()/makedirs() check-then-create race.
    os.makedirs(folder, exist_ok=True)

    for task_id, task_data in enumerate(scenario):
        task_data.plot(
            path=folder,
            title=f"background_{task_id}.jpg",
            nb_samples=100,
            shape=[28, 28, 3],
        )
        loader = DataLoader(task_data)
        # Drawing one batch checks the transformed samples can be loaded;
        # the three-way unpack also asserts each batch has three elements.
        _, _, _ = next(iter(loader))
def test_background_swap_numpy():
    """Apply BackgroundSwap to one raw MNIST image taken from ``get_data()``.

    The docstring of the original test describes this input as a single
    ndarray. The test only checks that the swap call does not raise.
    """
    digits = MNIST(DATA_PATH, download=True, train=True)
    backgrounds = CIFAR10(DATA_PATH, download=True, train=True)
    swap = BackgroundSwap(backgrounds, input_dim=(28, 28))

    images = digits.get_data()[0]
    first_image = images[0]
    swap(first_image)
def test_background_swap_torch():
    """Apply BackgroundSwap to one MNIST sample that is already a tensor.

    The torchvision MNIST dataset is loaded with a ToTensor transform, so
    the first element of ``digits[0]`` is a tensor. The test only checks
    that the swap call does not raise.
    """
    backgrounds = CIFAR10(DATA_PATH, download=True, train=True)
    to_tensor = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()])
    digits = torchvision.datasets.MNIST(
        DATA_PATH, train=True, download=True, transform=to_tensor)
    swap = BackgroundSwap(backgrounds, input_dim=(28, 28))

    first_sample = digits[0]
    swap(first_sample[0])
def __init__(self, data_path: str = "", split: str = "train", download: bool = True, seed: int = 1):
    """Build a six-task sequence that starts and ends with RainbowMNIST in two distinct colors.

    Tasks, in order: RainbowMNIST (``color1``), CIFAR10, DTD, FashionMNIST,
    SVHN, RainbowMNIST (``color2``). Both colors come from
    {"red", "blue", "green"} via a ``RandomState`` seeded with ``seed``, and
    are guaranteed to differ.

    Args:
        data_path: Root directory forwarded to every dataset constructor.
        split: One of "train", "val" or "test". "train" and "val" both build
            the datasets with ``train=True``; "test" uses ``train=False``.
        download: Forwarded to every dataset constructor.
        seed: Seeds the color draw and is forwarded to the parent constructor.

    Raises:
        ValueError: If ``split`` is not "train", "val" or "test".
    """
    if split not in ("train", "val", "test"):
        raise ValueError(
            f"Split must be train, val, or test; not {split}.")
    train = split in ("train", "val")

    colors = ["red", "blue", "green"]
    rng = np.random.RandomState(seed=seed)
    # RandomState.choice samples WITH replacement by default, so both draws
    # can land on the same color. That would make the first and last
    # RainbowMNIST tasks use the same color. Keep the original draw, so
    # seeds that already gave two colors are unchanged, and re-draw the
    # second color from the remaining ones only on a collision.
    color1, color2 = rng.choice(colors, 2)
    if color2 == color1:
        color2 = rng.choice([c for c in colors if c != color1])

    datasets = [
        RainbowMNIST(data_path=data_path, train=train, download=download, color=color1),
        CIFAR10(data_path=data_path, train=train, download=download),
        DTD(data_path=data_path, train=train, download=download),
        FashionMNIST(data_path=data_path, train=train, download=download),
        SVHN(data_path=data_path, train=train, download=download),
        RainbowMNIST(data_path=data_path, train=train, download=download, color=color2),
    ]

    # Per-task sizes (the values look like sample counts despite the name —
    # confirm in the parent class). The last task gets a much smaller budget
    # than the first. The "test" split passes None.
    if split == "train":
        proportions = [4000, 400, 400, 400, 400, 50]
    elif split == "val":
        proportions = [2000, 200, 200, 200, 200, 30]
    else:
        proportions = None

    # NOTE(review): class_counter looks like a per-task label offset. The
    # final task reuses offset 0, i.e. presumably the same label range as
    # the first RainbowMNIST task — confirm in the parent class.
    super().__init__(
        datasets=datasets,
        proportions=proportions,
        class_counter=[0, 10, 20, 67, 77, 0],
        seed=seed,
        split=split,
    )