def test_class_order(dataset_png, mode, class_order, error):
    """Check that a custom ``class_order`` is honoured by every task.

    PNG fixtures are required here: JPEG compression is lossy and alters
    pixel values, and this test compares pixel intensities against class ids.

    When ``error`` is truthy, building the scenario with ``class_order`` must
    raise ``ValueError`` and nothing else is checked. Otherwise, for each task:
      * the remapped labels contain ``task_id + 1``;
      * the original class ``class_order[task_id]`` appears among the
        image pixel values (rescaled from [0, 1] by 255). Presumably the
        fixture images encode the class id as pixel intensity; confirm
        against the ``dataset_png`` fixture;
      * mapping the labels back with ``get_original_targets`` yields that
        same original class.
    """
    increments = [1, 1, 1, 1]

    if error:
        # An invalid class order must be rejected at construction time.
        with pytest.raises(ValueError):
            SegmentationClassIncremental(
                dataset_png,
                nb_classes=4,
                increment=increments,
                class_order=class_order,
                mode=mode,
            )
        return

    scenario = SegmentationClassIncremental(
        dataset_png,
        nb_classes=4,
        increment=increments,
        class_order=class_order,
        mode=mode,
    )
    for task_id, task_set in enumerate(scenario):
        # batch_size=200 is intended to pull the whole task in one batch.
        loader = DataLoader(task_set, batch_size=200, drop_last=False)
        images, labels, _ = next(iter(loader))

        pixel_values = torch.unique((images * 255).long())
        expected_label = task_id + 1
        assert expected_label in labels

        original_class = class_order[task_id]
        assert original_class in pixel_values, task_id

        restored = scenario.get_original_targets(labels)
        assert original_class in np.unique(restored), task_id
def test_save_indexes(tmpdir):
    """Check that ``save_indexes`` lets a scenario be rebuilt after its files vanish.

    Steps:
      1. Create the dataset and delete its generated files (presumably what
         ``_clean`` does with the ``seg_tmp*`` pattern). Building a scenario
         without ``save_indexes`` must then raise.
      2. Re-create the dataset and build a scenario with ``save_indexes``
         pointing at ``indexes.npy`` while the files still exist.
      3. Delete the files again and rebuild with the same ``save_indexes``
         path. This time construction must not raise, presumably because
         the indexes written in step 2 are reused.
    """
    indexes_path = os.path.join(tmpdir, "indexes.npy")
    files_pattern = os.path.join(tmpdir, "seg_tmp*")

    # Step 1: no saved indexes and no files on disk -> construction fails.
    dataset = create_dataset(tmpdir, "seg_tmp")
    _clean(files_pattern)
    with pytest.raises(Exception):
        SegmentationClassIncremental(
            dataset,
            nb_classes=4,
            increment=2,
            mode="overlap",
        )

    # Step 2: build once while the files exist, saving the indexes.
    dataset = create_dataset(tmpdir, "seg_tmp")
    SegmentationClassIncremental(
        dataset,
        nb_classes=4,
        increment=2,
        mode="overlap",
        save_indexes=indexes_path,
    )

    # Step 3: with the files removed, the saved indexes must suffice.
    _clean(files_pattern)
    SegmentationClassIncremental(
        dataset,
        nb_classes=4,
        increment=2,
        mode="overlap",
        save_indexes=indexes_path,
    )
def test_labels_test(dataset_test, mode, all_seen_tasks):
    """Check which label values are visible in each test-time task.

    ``ordered_labels`` lists labels in the order they should become visible:
    0, 255, then one new class per task (``increment=1``). For task ``k``:
      * labels at positions ``k + 3`` onward (future classes) must be absent;
      * if ``all_seen_tasks``, the scenario is sliced ``[:k + 1]`` and every
        label at positions ``0 .. k + 2`` must be present;
      * otherwise only task ``k`` is taken, and only its own class
        (position ``k + 2``) is required to be present.
    """
    scenario = SegmentationClassIncremental(
        dataset_test, nb_classes=4, increment=1, mode=mode
    )
    # 255 presumably marks ignored/unknown pixels; confirm against the fixture.
    ordered_labels = [0, 255, 1, 2, 3, 4]

    for task_id in range(len(scenario)):
        task_set = scenario[:task_id + 1] if all_seen_tasks else scenario[task_id]
        loader = DataLoader(task_set, batch_size=200, drop_last=False)
        _, targets, _ = next(iter(loader))
        present = torch.unique(targets)

        cutoff = task_id + 3
        for label in ordered_labels[cutoff:]:
            assert label not in present, task_id

        if not all_seen_tasks:
            assert ordered_labels[task_id + 2] in present
            continue

        for label in ordered_labels[:cutoff]:
            assert label in present, task_id
def test_advanced_indexing_step(dataset):
    """Slicing a scenario with an explicit step (here 2) must raise ValueError."""
    scenario = SegmentationClassIncremental(
        dataset, nb_classes=4, increment=1, mode="overlap"
    )
    with pytest.raises(ValueError):
        scenario[0:4:2]
def test_length_taskset(dataset, mode, lengths, increment):
    """Check the number of tasks and the size of each task set.

    The scenario is built with ``initial_increment=2`` and the given
    ``increment``. ``lengths[i]`` is the expected ``len()`` of task ``i``, and
    the scenario must have exactly ``len(lengths)`` tasks.
    """
    scenario = SegmentationClassIncremental(
        dataset,
        nb_classes=4,
        increment=increment,
        initial_increment=2,
        mode=mode,
    )
    assert len(scenario) == len(lengths)

    for task_id, expected_size in enumerate(lengths):
        assert len(scenario[task_id]) == expected_size, task_id
def test_labels(dataset, mode, increment):
    """Check tensor shapes, task ids and per-mode label visibility per task.

    For every task the first batch must have rank-4 images, rank-3 targets
    and rank-1 task ids. Images and targets must share their last two
    dimensions, and every task id must equal the task index. Label 0 must
    always be present, and if label 4 is absent then 255 must be present.

    Visibility of each checked class ``c`` depends on ``mode``:
      * "overlap" / "disjoint": ``c`` is visible iff it is one of this
        task's new classes;
      * "sequential": ``c`` is visible iff it is below the end of this
        task's class range.
    Any other mode skips the per-class checks.
    """
    nb_classes = 4
    initial_increment = 2
    scenario = SegmentationClassIncremental(
        dataset,
        nb_classes=nb_classes,
        increment=increment,
        initial_increment=initial_increment,
        mode=mode,
    )

    # New classes per task. Integer increments of 1 or 2 are expanded on
    # top of initial_increment=2; anything else is used as given.
    if isinstance(increment, int):
        increments = {1: [2, 1, 1], 2: [2, 2]}.get(increment, increment)
    else:
        increments = increment

    first_new = 1  # first class id introduced by the current task
    for task_id, task_set in enumerate(scenario):
        loader = DataLoader(task_set, batch_size=200, drop_last=False)
        images, targets, task_ids = next(iter(loader))

        assert len(images.shape) == 4
        assert len(targets.shape) == 3
        assert len(task_ids.shape) == 1
        assert images.shape[2:] == targets.shape[1:]
        assert (task_ids == task_id).all()

        present = set(torch.unique(targets).numpy().tolist())
        end_new = first_new + increments[task_id]  # exclusive upper bound

        assert 0 in present, task_id
        if 4 not in present:
            assert 255 in present, task_id

        # NOTE(review): this range starts at the current task's first class,
        # so classes from earlier tasks are never checked. It also extends
        # past nb_classes for later tasks. Confirm this is intended.
        for c in range(first_new, nb_classes + first_new):
            if mode in ("overlap", "disjoint"):
                expected = first_new <= c < end_new
            elif mode == "sequential":
                expected = c < end_new
            else:
                continue
            assert (c in present) == expected, (c, task_id, first_new, end_new)

        first_new = end_new
def test_background_test(dataset, dataset_test, mode, test_background, train):
    """Check when label 0 (background) appears in task targets.

    The scenario is built from ``dataset`` when ``train`` is truthy and from
    ``dataset_test`` otherwise. Label 0 must be present in every task when
    ``train`` or ``test_background`` is truthy, and absent otherwise.
    """
    source = dataset if train else dataset_test
    scenario = SegmentationClassIncremental(
        source,
        nb_classes=4,
        increment=2,
        mode=mode,
        test_background=test_background,
    )
    expect_background = train or test_background

    for task_set in scenario:
        loader = DataLoader(task_set, batch_size=200, drop_last=False)
        _, targets, _ = next(iter(loader))
        if expect_background:
            assert 0 in targets
        else:
            assert 0 not in targets
def test_advanced_indexing(dataset, dataset_test, mode, start, end, classes, train):
    """Check the task ids and labels produced by slicing ``scenario[start:end]``.

    The sliced task set must carry a single task id, equal to ``end - 1``.
    Its labels, excluding 0 and 255, must be exactly ``classes``.
    """
    source = dataset if train else dataset_test
    scenario = SegmentationClassIncremental(
        source, nb_classes=4, increment=1, mode=mode
    )
    loader = DataLoader(scenario[start:end], batch_size=200, drop_last=False)
    _, targets, task_ids = next(iter(loader))

    distinct_tasks = torch.unique(task_ids)
    distinct_labels = torch.unique(targets)

    assert len(distinct_tasks) == 1 and distinct_tasks[0] == end - 1
    assert set(distinct_labels.numpy().tolist()) - {0, 255} == set(classes)
def test_labels_overlap_dense_test(dataset_dense_test):
    """Check label visibility per task in "overlap" mode on the dense test fixture.

    ``ordered_labels`` lists labels in the order they should become visible:
    0, 255, then one new class per task. For task ``k``, labels at positions
    ``0 .. k + 2`` must be present and those at ``k + 3`` onward must be
    absent. Each task's first batch must also contain exactly 20 samples.
    """
    scenario = SegmentationClassIncremental(
        dataset_dense_test, nb_classes=4, increment=1, mode="overlap"
    )
    ordered_labels = [0, 255, 1, 2, 3, 4]

    for task_id, task_set in enumerate(scenario):
        loader = DataLoader(task_set, batch_size=200, drop_last=False)
        images, targets, _ = next(iter(loader))
        present = torch.unique(targets)

        cutoff = task_id + 3
        for label in ordered_labels[:cutoff]:
            assert label in present, task_id
        for label in ordered_labels[cutoff:]:
            assert label not in present, task_id

        assert len(images) == 20