示例#1
0
def test_MNIST_Fellowship():
    """Check that a ClassIncremental split of MNISTFellowship has 3 tasks.

    Data lives in the fixed directory ``./tests/Datasets`` (not a temporary
    directory) and is requested with ``download=True``.

    NOTE(review): another test named ``test_MNIST_Fellowship`` appears in
    this listing; if both live in one module, the later definition shadows
    this one and only that one is collected by pytest.
    """
    dataset = MNISTFellowship(
        data_path="./tests/Datasets",
        train=True,
        download=True,
    )
    # The returned arrays are discarded; presumably this call just forces
    # the data to be loaded before building the scenario — confirm.
    dataset.get_data()

    tasks = ClassIncremental(dataset, increment=10)
    assert len(tasks) == 3
示例#2
0
def test_MNIST_Fellowship_nb_classes(tmpdir):
    """Check the number of distinct labels MNISTFellowship produces.

    With default constructor arguments the labels contain 30 unique values;
    with ``update_labels=False`` they contain 10. Presumably
    ``update_labels`` remaps each sub-dataset's labels so they do not
    overlap — confirm against MNISTFellowship.
    """
    # (extra constructor kwargs, expected number of unique labels)
    cases = (
        ({}, 30),
        ({"update_labels": False}, 10),
    )
    for extra_kwargs, expected_nb_classes in cases:
        dataset = MNISTFellowship(
            data_path=tmpdir, train=True, download=True, **extra_kwargs
        )
        _, labels, _ = dataset.get_data()
        assert len(np.unique(labels)) == expected_nb_classes
示例#3
0
def test_visualization_MNISTFellowship(tmpdir):
    """Plot samples from every ClassIncremental task of MNISTFellowship.

    The dataset is split with ``increment=10``. Each task set is passed to
    ``taskset.plot`` with ``nb_samples=100`` and a per-task ``.jpg`` title,
    targeting ``<tmpdir>/samples/fellowship``. The test makes no assertions;
    it passes as long as no step raises.
    """
    cl_dataset = MNISTFellowship(data_path=tmpdir, download=True, train=True)
    scenario = ClassIncremental(cl_dataset=cl_dataset, increment=10)

    folder = os.path.join(tmpdir, "samples", "fellowship")
    # exist_ok=True replaces the check-then-create pattern, which could race
    # with another process creating the same directory in between.
    os.makedirs(folder, exist_ok=True)

    for task_id, taskset in enumerate(scenario):
        taskset.plot(
            path=folder,
            title="MNISTFellowship_Incremental_{}.jpg".format(task_id),
            nb_samples=100,
            # Presumably (height, width, channels) for 28x28 grayscale
            # images — confirm against taskset.plot's expected layout.
            shape=[28, 28, 1],
        )
示例#4
0
def test_MNIST_Fellowship_Instance_Incremental(nb_tasks, tmpdir):
    """Check that InstanceIncremental over MNISTFellowship yields 3 tasks.

    ``nb_tasks`` is injected by pytest, presumably from a parametrize
    decorator that is not visible here.

    NOTE(review): the expected length is 3 for every ``nb_tasks`` value;
    this only holds if each parametrized value resolves to 3 tasks —
    confirm against the parametrization.
    """
    cl_dataset = MNISTFellowship(data_path=tmpdir, train=True, download=True)
    # Return value unused; presumably forces loading before the split.
    cl_dataset.get_data()

    scenario = InstanceIncremental(cl_dataset, nb_tasks=nb_tasks)
    expected_nb_tasks = 3
    assert len(scenario) == expected_nb_tasks
示例#5
0
def test_MNIST_Fellowship(tmpdir):
    """Check that a ClassIncremental split of MNISTFellowship has 3 tasks.

    Data is stored under pytest's per-test ``tmpdir`` and requested with
    ``download=True``; the split uses ``increment=10``.
    """
    cl_dataset = MNISTFellowship(data_path=tmpdir, train=True, download=True)
    # Return value unused; presumably forces loading before the split.
    cl_dataset.get_data()

    scenario = ClassIncremental(cl_dataset, increment=10)
    assert len(scenario) == 3