def test_init(numpy_data, seed):
    """Check that Permutations is reproducible for a seed and that its tasks differ.

    Two scenarios built from the same dataset with the same ``seed`` must
    yield identical inputs, labels and task ids, task by task. In addition,
    the inputs of every task must differ from the inputs of all earlier tasks.

    Args:
        numpy_data: fixture providing a ``(train, test)`` pair; only ``train``
            is used, unpacked into ``InMemoryDatasetTest``.
        seed: an int, or a list of ints, forwarded to ``Permutations``.
    """
    train, _ = numpy_data
    dummy = InMemoryDatasetTest(*train)

    # A list of seeds must produce one task more than there are seeds.
    # With an int seed the test asks for 3 tasks.
    nb_tasks = len(seed) + 1 if isinstance(seed, list) else 3

    scenario_1 = Permutations(cl_dataset=dummy, nb_tasks=nb_tasks, seed=seed)
    scenario_2 = Permutations(cl_dataset=dummy, nb_tasks=nb_tasks, seed=seed)

    # Check the task count for every seed type, not only for lists: zip()
    # below stops at the shorter scenario, so a wrong or empty task list
    # would otherwise let the loop pass silently.
    assert len(scenario_1) == len(scenario_2) == nb_tasks

    previous_x = []
    for task_id, (train_taskset_1, train_taskset_2) in enumerate(zip(scenario_1, scenario_2)):
        assert task_id < nb_tasks
        assert len(train_taskset_1) == len(train_taskset_2)

        indexes = list(range(len(train_taskset_1)))
        x_1, y_1, t_1 = train_taskset_1.get_samples(indexes)
        x_2, y_2, t_2 = train_taskset_2.get_samples(indexes)

        # Scenarios built with the same seed must produce identical tasks.
        assert (x_1 == x_2).all()
        assert (y_1 == y_2).all()
        assert (t_1 == t_2).all()

        # Each task's inputs must differ from those of every earlier task.
        for x in previous_x:
            assert not (x == x_1).all()
        previous_x.append(x_1.clone())
def test_with_dataset(dataset, shared_label_space):
    """Check the class ids of each Permutations task built on a real dataset.

    Let ``n`` be the number of classes in a task and ``t`` its index. With a
    shared label space the largest class id must be ``n - 1``. Otherwise it
    must be ``(t + 1) * n - 1``, i.e. each task's ids are offset by ``t * n``.

    Args:
        dataset: dataset class, called with ``data_path``, ``download`` and
            ``train`` keyword arguments.
        shared_label_space: forwarded to ``Permutations``.
    """
    cl_dataset = dataset(data_path=DATA_PATH, download=True, train=True)
    scenario = Permutations(
        cl_dataset=cl_dataset,
        nb_tasks=5,
        seed=0,
        shared_label_space=shared_label_space,
    )

    for task_index, task_set in enumerate(scenario):
        task_classes = task_set.get_classes()
        n_classes = len(task_classes)
        # Without a shared label space, task ``task_index`` is expected to
        # start its ids after those of all previous tasks.
        offset = 0 if shared_label_space else task_index * n_classes
        assert n_classes == task_classes.max() + 1 - offset
def test_visualization_permutations(tmpdir):
    """Plot samples of every task of a 3-task MNIST Permutations scenario.

    Calls ``plot`` once per task set with 100 samples of shape ``[28, 28, 1]``,
    targeting ``<tmpdir>/samples/permutation``. The per-task ``title`` ends
    in ``.jpg`` and is presumably used as the output file name; confirm
    against ``TaskSet.plot``.

    Args:
        tmpdir: pytest temporary directory, used both as the MNIST download
            location and as the root of the output folder.
    """
    scenario = Permutations(
        cl_dataset=MNIST(data_path=tmpdir, download=True, train=True),
        nb_tasks=3,
        seed=0,
    )

    folder = os.path.join(tmpdir, "samples", "permutation")
    # exist_ok=True replaces the exists()-then-makedirs() check, which is
    # redundant and open to a check-then-create race.
    os.makedirs(folder, exist_ok=True)

    for task_id, taskset in enumerate(scenario):
        taskset.plot(
            path=folder,
            title="MNIST_Permutations_{}.jpg".format(task_id),
            nb_samples=100,
            shape=[28, 28, 1],
        )