示例#1
0
def test_init(numpy_data, seed):
    """Check that Permutations is reproducible for a given seed and that tasks differ.

    Two scenarios are built from the same in-memory training data with the same
    ``seed``. Task by task, they must yield identical inputs, labels and task ids.
    Each task's inputs must also not be entirely equal to those of any earlier task.

    When ``seed`` is a list, the test expects ``len(seed) + 1`` tasks. Otherwise it
    requests and expects 3 tasks.
    """
    train, _ = numpy_data
    dummy = InMemoryDatasetTest(*train)

    nb_tasks = len(seed) + 1 if isinstance(seed, list) else 3

    scenario_1 = Permutations(cl_dataset=dummy, nb_tasks=nb_tasks, seed=seed)
    scenario_2 = Permutations(cl_dataset=dummy, nb_tasks=nb_tasks, seed=seed)

    # Checked for every kind of seed: zip() below stops at the shorter scenario,
    # so without this a missing or extra task would go unnoticed.
    assert len(scenario_1) == len(scenario_2) == nb_tasks

    previous_x = []
    for task_id, (train_taskset_1, train_taskset_2) in enumerate(zip(scenario_1, scenario_2)):
        assert task_id < nb_tasks

        assert len(train_taskset_1) == len(train_taskset_2)
        indexes = list(range(len(train_taskset_1)))

        x_1, y_1, t_1 = train_taskset_1.get_samples(indexes)
        x_2, y_2, t_2 = train_taskset_2.get_samples(indexes)

        # Same seed => same permutation => identical samples in both scenarios.
        assert (x_1 == x_2).all()
        assert (y_1 == y_2).all()
        assert (t_1 == t_2).all()

        # Every task must transform the inputs differently from all earlier tasks.
        for x in previous_x:
            assert not (x == x_1).all()
        # Keep an independent copy (assumes get_samples returns torch tensors).
        previous_x.append(x_1.clone())

    # One entry per visited task: ensures iteration really produced every task.
    assert len(previous_x) == nb_tasks
示例#2
0
def test_with_dataset(dataset, shared_label_space):
    """Check per-task class labels of a 5-task Permutations scenario on a real dataset.

    ``dataset`` is called with the training-split arguments to build the dataset.
    For every task, the number of classes reported by ``get_classes()`` must equal
    the largest label + 1. When ``shared_label_space`` is false, the labels of task
    ``t`` are first shifted back by ``t * n_classes``.
    """
    cl_dataset = dataset(data_path=DATA_PATH, download=True, train=True)
    scenario = Permutations(
        cl_dataset=cl_dataset,
        nb_tasks=5,
        seed=0,
        shared_label_space=shared_label_space,
    )

    for task_id, taskset in enumerate(scenario):
        classes = taskset.get_classes()
        nb_classes = len(classes)

        # Without a shared label space, each task's labels are expected to sit
        # task_id * nb_classes above those of the first task.
        offset = 0 if shared_label_space else task_id * nb_classes
        assert nb_classes == classes.max() + 1 - offset
示例#3
0
def test_visualization_permutations(tmpdir):
    """Smoke-test plotting of samples from each task of an MNIST Permutations scenario.

    Builds a 3-task scenario on the MNIST training split, downloaded into ``tmpdir``.
    For every task it calls ``taskset.plot`` with 100 samples shaped 28x28x1, using
    ``<tmpdir>/samples/permutation`` as the output path. No assertion is made beyond
    plotting not raising. Presumably one image file is written per task; confirm
    against ``plot``.
    """
    scenario = Permutations(cl_dataset=MNIST(data_path=tmpdir, download=True, train=True),
                            nb_tasks=3,
                            seed=0)

    folder = os.path.join(tmpdir, "samples", "permutation")
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(folder, exist_ok=True)

    for task_id, taskset in enumerate(scenario):
        taskset.plot(path=folder,
                     title="MNIST_Permutations_{}.jpg".format(task_id),
                     nb_samples=100,
                     shape=[28, 28, 1])