Example #1
    def __init__(self,
                 data_path: str = "",
                 split: str = "train",
                 download: bool = True,
                 seed: int = 1):
        if split not in ("train", "val", "test"):
            raise ValueError(
                f"Split must be train, val, or test; not {split}.")
        # "val" is drawn from the underlying train data as well; only "test"
        # uses the held-out test sets.
        train = split in ("train", "val")

        datasets = [
            CIFAR10(data_path=data_path, train=train, download=download),
            MNIST(data_path=data_path, train=train, download=download),
            DTD(data_path=data_path, train=train, download=download),
            FashionMNIST(data_path=data_path, train=train, download=download),
            SVHN(data_path=data_path, train=train, download=download),
            CIFAR10(data_path=data_path, train=train, download=download)
        ]

        if split == "train":
            proportions = [4000, 400, 400, 400, 400, 400]
        elif split == "val":
            proportions = [2000, 200, 200, 200, 200, 200]
        else:
            proportions = None

        # class_counter holds each dataset's label offset: the cumulative class
        # counts (CIFAR10 10, MNIST 10, DTD 47, FashionMNIST 10, SVHN 10), so
        # the second CIFAR10 is relabeled as ten new classes starting at 87.
        super().__init__(datasets=datasets,
                         proportions=proportions,
                         class_counter=[0, 10, 20, 67, 77, 87],
                         seed=seed,
                         split=split)
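
# The offsets in class_counter above are simply the cumulative number of
# classes contributed by the preceding datasets. A minimal sketch of that
# arithmetic, assuming the usual class counts (CIFAR10: 10, MNIST: 10,
# DTD: 47, FashionMNIST: 10, SVHN: 10); `num_classes` is illustrative only,
# not part of the library API.
from itertools import accumulate

num_classes = [("CIFAR10", 10), ("MNIST", 10), ("DTD", 47),
               ("FashionMNIST", 10), ("SVHN", 10), ("CIFAR10 again", 10)]
offsets = [0] + list(accumulate(n for _, n in num_classes[:-1]))
print(offsets)  # [0, 10, 20, 67, 77, 87] -- matches class_counter above
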
def get_permuted_CIFAR10(path, batch_size, train):
    im_width = im_height = 32

    rand_perm = RandomPermutation(0, 0, im_width, im_height)
    normalization = transforms.Normalize((0.1307,), (0.3081,))

    # TODO: rethink RandomPermutation usage; it slows down data loading by a
    # factor > 6. It should be applied directly on batches instead.
    transform = [
        transforms.ToTensor(),
        rand_perm,
        normalization
    ]

    # NOTE: batch_size is currently unused; callers build their own DataLoader.
    scenario = ClassIncremental(
        CIFAR10(data_path=path, download=True, train=train),
        increment=2,
        transformations=transform
    )

    return scenario
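
# A minimal usage sketch for get_permuted_CIFAR10, mirroring the DataLoader
# pattern used in the tests below. The data path and batch size are assumed
# values; continuum task sets yield (x, y, task_id) triples.
from torch.utils.data import DataLoader

scenario = get_permuted_CIFAR10("./src/data/CIFAR10", batch_size=32, train=True)
for task_id, taskset in enumerate(scenario):
    loader = DataLoader(taskset, batch_size=32, shuffle=True)
    x, y, t = next(iter(loader))
    print(task_id, x.shape, y.shape)
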
def test_background_transformation():
    """
    Example code using TransformationIncremental to create a setting with 3 tasks.
    """
    cifar = CIFAR10(DATA_PATH, train=True)
    mnist = MNIST(DATA_PATH, download=False, train=True)
    nb_task = 3
    list_trsf = []
    for i in range(nb_task):
        list_trsf.append([
            torchvision.transforms.ToTensor(),
            BackgroundSwap(cifar, bg_label=i, input_dim=(28, 28)),
            torchvision.transforms.ToPILImage()
        ])
    scenario = TransformationIncremental(
        mnist,
        base_transformations=[torchvision.transforms.ToTensor()],
        incremental_transformations=list_trsf)
    folder = "tests/samples/background_trsf/"
    if not os.path.exists(folder):
        os.makedirs(folder)
    for task_id, task_data in enumerate(scenario):
        task_data.plot(path=folder,
                       title=f"background_{task_id}.jpg",
                       nb_samples=100,
                       shape=[28, 28, 3])
        loader = DataLoader(task_data)
        _, _, _ = next(iter(loader))
def test_background_swap_numpy():
    """
    Test background swap on a single ndarray input.
    """
    mnist = MNIST(DATA_PATH, download=True, train=True)
    cifar = CIFAR10(DATA_PATH, download=True, train=True)

    bg_swap = BackgroundSwap(cifar, input_dim=(28, 28))

    # get_data() returns the raw arrays; grab the first image (an ndarray).
    im = mnist.get_data()[0][0]
    im = bg_swap(im)
def test_background_swap_torch():
    """
    Test background swap on a single tensor input.
    """
    cifar = CIFAR10(DATA_PATH, download=True, train=True)

    mnist = torchvision.datasets.MNIST(
        DATA_PATH,
        train=True,
        download=True,
        transform=torchvision.transforms.Compose(
            [torchvision.transforms.ToTensor()]))

    bg_swap = BackgroundSwap(cifar, input_dim=(28, 28))
    im = mnist[0][0]

    im = bg_swap(im)
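
# A minimal sketch extending the two tests above: putting BackgroundSwap inside
# the torchvision transform pipeline applies it to every MNIST digit as batches
# are loaded. DATA_PATH and the batch size are assumed values.
from torch.utils.data import DataLoader

cifar_bg = CIFAR10(DATA_PATH, download=True, train=True)
mnist_swapped = torchvision.datasets.MNIST(
    DATA_PATH,
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        BackgroundSwap(cifar_bg, input_dim=(28, 28)),
    ]))

loader = DataLoader(mnist_swapped, batch_size=16, shuffle=True)
images, labels = next(iter(loader))  # each digit now carries a CIFAR10 background
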
Example #6
    def __init__(self,
                 data_path: str = "",
                 split: str = "train",
                 download: bool = True,
                 seed: int = 1):
        if split not in ("train", "val", "test"):
            raise ValueError(
                f"Split must be train, val, or test; not {split}.")
        train = split in ("train", "val")

        # Draw two colors reproducibly from the seed; np.random.choice samples
        # with replacement by default, so color1 and color2 may coincide.
        color1, color2 = np.random.RandomState(seed=seed).choice(
            ["red", "blue", "green"], 2)
        datasets = [
            RainbowMNIST(data_path=data_path,
                         train=train,
                         download=download,
                         color=color1),
            CIFAR10(data_path=data_path, train=train, download=download),
            DTD(data_path=data_path, train=train, download=download),
            FashionMNIST(data_path=data_path, train=train, download=download),
            SVHN(data_path=data_path, train=train, download=download),
            RainbowMNIST(data_path=data_path,
                         train=train,
                         download=download,
                         color=color2)
        ]

        if split == "train":
            proportions = [4000, 400, 400, 400, 400, 50]
        elif split == "val":
            proportions = [2000, 200, 200, 200, 200, 30]
        else:
            proportions = None

        # class_counter: the final offset of 0 makes the second RainbowMNIST
        # share its label space with the first one.
        super().__init__(datasets=datasets,
                         proportions=proportions,
                         class_counter=[0, 10, 20, 67, 77, 0],
                         seed=seed,
                         split=split)
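
# A small sketch of the seeded color draw used above: with a fixed seed the
# two RainbowMNIST colors are reproducible across runs. The seed value 1 is
# just the constructor's default, reused here for illustration.
import numpy as np

color1, color2 = np.random.RandomState(seed=1).choice(
    ["red", "blue", "green"], 2)
print(color1, color2)  # the same pair is printed on every run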