示例#1
0
def test_class_order(dataset_png, mode, class_order, error):
    """Check that ``class_order`` maps original classes to task labels.

    PNG is required here because JPG compression alters some pixel values,
    which would break the pixel-value checks below.

    Every task holds a single class (increments of 1). Task ``i`` must
    expose label ``i + 1`` in its targets. Both the input pixels and the
    original targets recovered via ``scenario.get_original_targets`` must
    contain ``class_order[i]``.

    When ``error`` is true, building the scenario must raise ``ValueError``.
    """
    increments = [1, 1, 1, 1]

    def build_scenario():
        return SegmentationClassIncremental(dataset_png,
                                            nb_classes=4,
                                            increment=increments,
                                            class_order=class_order,
                                            mode=mode)

    if error:
        with pytest.raises(ValueError):
            build_scenario()
        return

    scenario = build_scenario()

    for task_id, task_set in enumerate(scenario):
        loader = DataLoader(task_set, batch_size=200, drop_last=False)
        x, y, _ = next(iter(loader))
        # NOTE(review): assumes x is scaled to [0, 1] and the fixture encodes
        # class ids as pixel intensities — confirm against dataset_png.
        pixels = torch.unique((x * 255).long())

        assert (task_id + 1) in y
        real_class = class_order[task_id]
        assert real_class in pixels, task_id
        original_y = scenario.get_original_targets(y)
        assert real_class in np.unique(original_y), task_id
示例#2
0
def test_save_indexes(tmpdir):
    """Check that a scenario can be rebuilt from saved indexes.

    The test runs three builds:

    1. The dataset files are removed before the first build, which must
       raise.
    2. The scenario is then built once with ``save_indexes`` while the
       files exist.
    3. The files are removed again, and a build with the same
       ``save_indexes`` path must now succeed.
    """
    indexes_path = os.path.join(tmpdir, "indexes.npy")
    files_pattern = os.path.join(tmpdir, "seg_tmp*")

    def build_scenario(dataset, **kwargs):
        return SegmentationClassIncremental(dataset,
                                            nb_classes=4,
                                            increment=2,
                                            mode="overlap",
                                            **kwargs)

    # Files removed and no index file given: construction must fail.
    dataset = create_dataset(tmpdir, "seg_tmp")
    _clean(files_pattern)
    with pytest.raises(Exception):
        build_scenario(dataset)

    # Files present: this build presumably writes the index file.
    dataset = create_dataset(tmpdir, "seg_tmp")
    build_scenario(dataset, save_indexes=indexes_path)

    # Files removed again: must succeed, presumably by loading the indexes.
    _clean(files_pattern)
    build_scenario(dataset, save_indexes=indexes_path)
示例#3
0
def test_labels_test(dataset_test, mode, all_seen_tasks):
    """Check which labels appear in the test targets of each task.

    ``classes`` lists labels 0 and 255 followed by the four incremental
    classes. For task ``t``, only the first ``t + 3`` entries may appear.

    When ``all_seen_tasks`` is true, tasks ``0..t`` are loaded together and
    every one of those labels must be present. Otherwise only task ``t`` is
    loaded, and its own class must be present.
    """
    scenario = SegmentationClassIncremental(dataset_test,
                                            nb_classes=4,
                                            increment=1,
                                            mode=mode)
    classes = [0, 255, 1, 2, 3, 4]

    for task_id in range(len(scenario)):
        cutoff = task_id + 3  # labels 0, 255 and classes 1..task_id + 1
        task_set = (scenario[:task_id + 1] if all_seen_tasks
                    else scenario[task_id])

        batch = next(iter(DataLoader(task_set, batch_size=200,
                                     drop_last=False)))
        _, targets, _ = batch
        seen_classes = torch.unique(targets)

        for label in classes[cutoff:]:
            assert label not in seen_classes, task_id

        if not all_seen_tasks:
            # Only the class introduced by this very task is required.
            assert classes[cutoff - 1] in seen_classes
            continue

        for label in classes[:cutoff]:
            assert label in seen_classes, task_id
示例#4
0
def test_advanced_indexing_step(dataset):
    """Slicing a scenario with a step (here ``0:4:2``) must raise ValueError."""
    scenario = SegmentationClassIncremental(
        dataset,
        nb_classes=4,
        increment=1,
        mode="overlap",
    )

    stepped_slice = slice(0, 4, 2)
    with pytest.raises(ValueError):
        scenario[stepped_slice]
示例#5
0
def test_length_taskset(dataset, mode, lengths, increment):
    """Check the number of tasks and the length of each task set.

    The scenario must have exactly ``len(lengths)`` tasks, and
    ``len(scenario[i])`` must equal ``lengths[i]``.
    """
    scenario = SegmentationClassIncremental(
        dataset,
        nb_classes=4,
        increment=increment,
        initial_increment=2,
        mode=mode,
    )

    assert len(scenario) == len(lengths)
    for task_index, expected_length in enumerate(lengths):
        assert len(scenario[task_index]) == expected_length, task_index
示例#6
0
def test_labels(dataset, mode, increment):
    """Check shapes, task ids and visible labels of every training task.

    The first task holds ``initial_increment`` classes. Later tasks hold
    ``increment`` classes each, or follow an explicit per-task list when
    ``increment`` is not an int.

    For every task the test checks that:

    * label 0 is present;
    * label 255 is present whenever class 4 is absent;
    * the visible incremental classes (1..nb_classes) depend on ``mode``:

      - ``"overlap"`` / ``"disjoint"``: exactly the current task's classes;
      - ``"sequential"``: every class up to and including the current task's.
    """
    initial_increment = 2
    nb_classes = 4
    first_cls = 1  # incremental class labels start at 1
    min_cls = first_cls

    if isinstance(increment, int):
        # Reproduces [2, 2] for increment=2 and [2, 1, 1] for increment=1
        # instead of a hard-coded table that broke for any other int.
        remaining = nb_classes - initial_increment
        increments = ([initial_increment]
                      + [increment] * (remaining // increment))
    else:
        increments = list(increment)

    for task_id, task_set in enumerate(scenario_tasks := SegmentationClassIncremental(
            dataset,
            nb_classes=nb_classes,
            increment=increment,
            initial_increment=initial_increment,
            mode=mode)) if False else enumerate(SegmentationClassIncremental(
                dataset,
                nb_classes=nb_classes,
                increment=increment,
                initial_increment=initial_increment,
                mode=mode)):
        pass
示例#7
0
def test_background_test(dataset, dataset_test, mode, test_background, train):
    """Check whether label 0 appears in the targets of each task.

    Label 0 must be present for the training set, or for the test set when
    ``test_background`` is enabled. Otherwise it must be absent.
    """
    source = dataset if train else dataset_test
    scenario = SegmentationClassIncremental(source,
                                            nb_classes=4,
                                            increment=2,
                                            mode=mode,
                                            test_background=test_background)
    expect_background = train or test_background

    for task_set in scenario:
        _, targets, _ = next(iter(
            DataLoader(task_set, batch_size=200, drop_last=False)))
        if expect_background:
            assert 0 in targets
        else:
            assert 0 not in targets
示例#8
0
def test_advanced_indexing(dataset, dataset_test, mode, start, end, classes,
                           train):
    """Check the task ids and labels yielded by ``scenario[start:end]``.

    Every sample of the sliced set must carry the single task id
    ``end - 1``. Apart from labels 0 and 255, the targets must contain
    exactly ``classes``.
    """
    scenario = SegmentationClassIncremental(
        dataset if train else dataset_test,
        nb_classes=4,
        increment=1,
        mode=mode,
    )

    loader = DataLoader(scenario[start:end], batch_size=200, drop_last=False)
    _, targets, task_ids = next(iter(loader))

    unique_tasks = torch.unique(task_ids)
    unique_labels = set(torch.unique(targets).numpy().tolist())

    assert len(unique_tasks) == 1 and unique_tasks[0] == end - 1
    assert unique_labels - {0, 255} == set(classes)
示例#9
0
def test_labels_overlap_dense_test(dataset_dense_test):
    """Check the test labels of each task in "overlap" mode on dense data.

    Each task holds one class. Task ``t`` must contain exactly the first
    ``t + 3`` entries of ``classes``: labels 0 and 255 plus classes
    ``1..t+1``. None of the later entries may appear, and the loaded batch
    must hold 20 images.
    """
    scenario = SegmentationClassIncremental(dataset_dense_test,
                                            nb_classes=4,
                                            increment=1,
                                            mode="overlap")
    classes = [0, 255, 1, 2, 3, 4]

    for task_id, task_set in enumerate(scenario):
        images, targets, _ = next(iter(
            DataLoader(task_set, batch_size=200, drop_last=False)))
        seen_classes = torch.unique(targets)
        cutoff = task_id + 3

        for label in classes[:cutoff]:
            assert label in seen_classes, task_id
        for label in classes[cutoff:]:
            assert label not in seen_classes, task_id
        assert len(images) == 20