def test_sit_multi_dataset_merge(self):
    """Merge two MNIST subsets into a 5-experience class-incremental benchmark.

    The second subset (built from ``range(5, 10)``) is wrapped with
    ``class_mapping=split_mapping``, which presumably relabels its classes
    onto 0-4 so both subsets share the same raw class IDs. With
    ``task_labels=False`` the benchmark must still expose 10 distinct
    classes, 2 in each of its 5 training experiences.

    NOTE(review): another method with this exact name is defined later in
    this class; that later definition replaces this one when the class body
    executes, so this copy never runs. Consider deleting one of the two.
    """
    split_mapping = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    mnist_train = MNIST(
        root=default_dataset_location("mnist"),
        train=True,
        download=True,
    )
    mnist_test = MNIST(
        root=default_dataset_location("mnist"),
        train=False,
        download=True,
    )
    train_part1 = make_nc_transformation_subset(
        mnist_train, None, None, range(5)
    )
    train_part2 = make_nc_transformation_subset(
        mnist_train, None, None, range(5, 10)
    )
    train_part2 = AvalancheSubset(train_part2, class_mapping=split_mapping)
    test_part1 = make_nc_transformation_subset(
        mnist_test, None, None, range(5)
    )
    test_part2 = make_nc_transformation_subset(
        mnist_test, None, None, range(5, 10)
    )
    test_part2 = AvalancheSubset(test_part2, class_mapping=split_mapping)

    # nc_benchmark replaces the legacy nc_scenario used by the old version of
    # this test, matching the other tests in this class.
    my_nc_benchmark = nc_benchmark(
        [train_part1, train_part2],
        [test_part1, test_part2],
        5,
        task_labels=False,
        shuffle=True,
        seed=1234,
    )
    self.assertEqual(5, my_nc_benchmark.n_experiences)
    self.assertEqual(10, my_nc_benchmark.n_classes)

    # classes_in_experience is keyed by stream name ("train"/"test") in the
    # API the rest of this class uses; indexing it directly with an int
    # experience id does not select a stream.
    train_classes = my_nc_benchmark.classes_in_experience["train"]
    for batch_id in range(5):
        self.assertEqual(2, len(train_classes[batch_id]))

    all_classes = set()
    for batch_id in range(5):
        all_classes.update(train_classes[batch_id])
    self.assertEqual(10, len(all_classes))
def test_sit_multi_dataset_merge(self):
    """Check that two relabelled MNIST subsets merge into a sane SIT benchmark.

    Both subsets share the 10-entry label map below, applied to the one
    built from ``range(5, 10)``. The resulting benchmark (no task labels,
    shuffled with a fixed seed) must have 5 experiences, 10 classes overall,
    and exactly 2 classes in every training experience.
    """
    label_map = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    data_root = default_dataset_location("mnist")
    full_train = MNIST(root=data_root, train=True, download=True)
    full_test = MNIST(root=data_root, train=False, download=True)

    low_ids = range(5)
    high_ids = range(5, 10)
    train_sets = [
        make_nc_transformation_subset(full_train, None, None, low_ids),
        AvalancheSubset(
            make_nc_transformation_subset(full_train, None, None, high_ids),
            class_mapping=label_map,
        ),
    ]
    test_sets = [
        make_nc_transformation_subset(full_test, None, None, low_ids),
        AvalancheSubset(
            make_nc_transformation_subset(full_test, None, None, high_ids),
            class_mapping=label_map,
        ),
    ]

    bench = nc_benchmark(
        train_sets,
        test_sets,
        5,
        task_labels=False,
        shuffle=True,
        seed=1234,
    )
    self.assertEqual(5, bench.n_experiences)
    self.assertEqual(10, bench.n_classes)

    # Per-experience class sets of the training stream.
    per_exp = bench.classes_in_experience["train"]
    seen_classes = set()
    for exp_idx in range(5):
        self.assertEqual(2, len(per_exp[exp_idx]))
        seen_classes.update(per_exp[exp_idx])
    self.assertEqual(10, len(seen_classes))
def test_sit_multi_dataset_one_batch_per_set(self):
    """Build a SIT benchmark in which each input dataset is one experience.

    The first subset is built from ``range(3)``. The second is built from
    ``range(3, 10)`` and wrapped with ``class_mapping=split_mapping``. With
    ``task_labels=False`` and ``one_dataset_per_exp=True``, the test expects:

    * exactly 2 experiences;
    * 10 distinct classes overall;
    * one experience holding ``{0, 1, 2}`` and the other holding 3-9.

    The experience order may vary because ``shuffle=True``.
    """
    split_mapping = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6]
    # Same dataset location helper as the other tests in this class, instead
    # of a hand-concatenated home-directory path.
    mnist_train = MNIST(
        root=default_dataset_location("mnist"),
        train=True,
        download=True,
    )
    mnist_test = MNIST(
        root=default_dataset_location("mnist"),
        train=False,
        download=True,
    )
    train_part1 = make_nc_transformation_subset(
        mnist_train, None, None, range(3)
    )
    train_part2 = make_nc_transformation_subset(
        mnist_train, None, None, range(3, 10)
    )
    train_part2 = AvalancheSubset(train_part2, class_mapping=split_mapping)
    test_part1 = make_nc_transformation_subset(
        mnist_test, None, None, range(3)
    )
    test_part2 = make_nc_transformation_subset(
        mnist_test, None, None, range(3, 10)
    )
    test_part2 = AvalancheSubset(test_part2, class_mapping=split_mapping)
    my_nc_benchmark = nc_benchmark(
        [train_part1, train_part2],
        [test_part1, test_part2],
        2,
        task_labels=False,
        shuffle=True,
        seed=1234,
        one_dataset_per_exp=True,
    )
    self.assertEqual(2, my_nc_benchmark.n_experiences)
    self.assertEqual(10, my_nc_benchmark.n_classes)

    train_classes = my_nc_benchmark.classes_in_experience["train"]
    all_classes = set()
    for batch_id in range(2):
        all_classes.update(train_classes[batch_id])
    self.assertEqual(10, len(all_classes))

    # shuffle=True may swap the two experiences, so both orders are valid.
    # assertIn keeps the original == comparisons but, unlike one large
    # assertTrue, reports the actual class sets when it fails.
    low_classes = {0, 1, 2}
    high_classes = set(range(3, 10))
    self.assertIn(
        (train_classes[0], train_classes[1]),
        [(low_classes, high_classes), (high_classes, low_classes)],
    )
def test_mt_multi_dataset_one_task_per_set(self):
    """Build a multi-task benchmark in which each input dataset is one task.

    The benchmark is built with ``task_labels=True``,
    ``one_dataset_per_exp=True`` and ``class_ids_from_zero_in_each_exp=True``.
    Both the train and test streams must contain 2 experiences:

    * one with classes ``[0, 1, 2]``, from the subset built on ``range(3)``;
    * one with classes ``[0..6]``, from the subset built on ``range(3, 10)``
      and relabelled with ``split_mapping``.

    Both streams must list them in the same order. ``n_classes`` is still
    expected to report 10.
    """
    split_mapping = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6]
    # Same dataset location helper as the other tests in this class, instead
    # of a hand-concatenated home-directory path.
    mnist_train = MNIST(
        root=default_dataset_location("mnist"),
        train=True,
        download=True,
    )
    mnist_test = MNIST(
        root=default_dataset_location("mnist"),
        train=False,
        download=True,
    )
    train_part1 = make_nc_transformation_subset(
        mnist_train, None, None, range(3)
    )
    train_part2 = make_nc_transformation_subset(
        mnist_train, None, None, range(3, 10)
    )
    train_part2 = AvalancheSubset(train_part2, class_mapping=split_mapping)
    test_part1 = make_nc_transformation_subset(
        mnist_test, None, None, range(3)
    )
    test_part2 = make_nc_transformation_subset(
        mnist_test, None, None, range(3, 10)
    )
    test_part2 = AvalancheSubset(test_part2, class_mapping=split_mapping)
    my_nc_benchmark = nc_benchmark(
        [train_part1, train_part2],
        [test_part1, test_part2],
        2,
        task_labels=True,
        seed=1234,
        class_ids_from_zero_in_each_exp=True,
        one_dataset_per_exp=True,
    )
    self.assertEqual(2, my_nc_benchmark.n_experiences)
    self.assertEqual(10, my_nc_benchmark.n_classes)
    self.assertEqual(2, len(my_nc_benchmark.train_stream))
    self.assertEqual(2, len(my_nc_benchmark.test_stream))

    exp_classes_train = []
    exp_classes_test = []
    all_classes_train = set()
    all_classes_test = set()
    task_info: NCExperience
    for task_id, task_info in enumerate(my_nc_benchmark.train_stream):
        self.assertLessEqual(task_id, 1)
        all_classes_train.update(
            my_nc_benchmark.classes_in_experience["train"][task_id]
        )
        exp_classes_train.append(task_info.classes_in_this_experience)
    # Class ids restart from zero per experience, so 0-6 covers everything.
    self.assertEqual(7, len(all_classes_train))

    for task_id, task_info in enumerate(my_nc_benchmark.test_stream):
        self.assertLessEqual(task_id, 1)
        all_classes_test.update(
            my_nc_benchmark.classes_in_experience["test"][task_id]
        )
        exp_classes_test.append(task_info.classes_in_this_experience)
    self.assertEqual(7, len(all_classes_test))

    # Either experience order is acceptable; assertIn reports the actual
    # class sets on failure instead of "False is not true".
    train_classes = my_nc_benchmark.classes_in_experience["train"]
    three_classes = {0, 1, 2}
    seven_classes = set(range(0, 7))
    self.assertIn(
        (train_classes[0], train_classes[1]),
        [(three_classes, seven_classes), (seven_classes, three_classes)],
    )

    exp_classes_ref1 = [list(range(3)), list(range(7))]
    exp_classes_ref2 = [list(range(7)), list(range(3))]
    self.assertIn(exp_classes_train, [exp_classes_ref1, exp_classes_ref2])
    # The train order is now known to be ref1 or ref2. Requiring the test
    # stream to match it is exactly the original if/else check.
    self.assertEqual(exp_classes_train, exp_classes_test)