Ejemplo n.º 1
0
def test_save_memory(tmpdir,
                     scenario,
                     memory_size=50,
                     method="random",
                     fixed=True):
    """Round-trip a RehearsalMemory through ``save``/``load`` after every task.

    After each task's raw samples are added, the memory is saved to an
    ``.npz`` file in ``tmpdir``. That file is then loaded into a fresh
    ``RehearsalMemory`` built with the same settings. Both memories must agree
    on ``seen_classes``, on length, and on the stored ``_x``/``_y``/``_t``
    arrays. Finally, the same file is loaded back into the original,
    already-populated memory as a smoke test.
    """
    memory = rehearsal.RehearsalMemory(memory_size, method, fixed, 10)
    assert len(memory) == 0

    for task_id, taskset in enumerate(scenario):
        x, y, t = taskset.get_raw_samples()
        memory.add(x, y, t, x)

        path = os.path.join(tmpdir, f"memory_{task_id}.npz")
        memory.save(path)

        new_memory = rehearsal.RehearsalMemory(memory_size, method, fixed, 10)
        new_memory.load(path)

        assert memory.seen_classes == new_memory.seen_classes
        assert len(memory) == len(new_memory)
        # assert_array_equal also verifies shapes and prints a readable diff,
        # instead of relying on an elementwise ``==`` that is ill-defined
        # (or raises) when the shapes differ.
        np.testing.assert_array_equal(memory._x, new_memory._x)
        np.testing.assert_array_equal(memory._y, new_memory._y)
        np.testing.assert_array_equal(memory._t, new_memory._t)

        # Smoke test: loading into an already-populated memory must not raise.
        memory.load(path)
Ejemplo n.º 2
0
def test_memory_slice():
    """Check that slicing with all 20 labels kept returns every label 0..19."""
    memory = rehearsal.RehearsalMemory(20, "random", True, 10)
    # 20 random samples with distinct labels 0..19 and matching task ids.
    memory.add(np.random.randn(20, 3, 4, 4), np.arange(20), np.arange(20),
               None)

    _, sliced_y, _ = memory.slice(keep_classes=list(range(20)))
    # assert_array_equal checks shape as well as values. A slice that drops
    # a class therefore fails with a clear message, rather than through an
    # ill-defined elementwise ``==`` between arrays of different lengths.
    np.testing.assert_array_equal(np.unique(sliced_y), np.arange(20))
Ejemplo n.º 3
0
def test_memory_add_past_data(fixed):
    """Check that the memory never exceeds its capacity.

    The capacity must hold even when a batch of an already-seen label is
    re-added after each new label.
    """
    capacity = 20
    memory = rehearsal.RehearsalMemory(capacity, "random", fixed, 10)

    def add_batch(label):
        # 20 random samples, all tagged with `label` and task id 1.
        memory.add(np.random.randn(20, 3, 4, 4),
                   np.ones(20) * label, np.ones(20), None)
        assert len(memory) <= capacity

    for step in range(10):
        add_batch(step)
        if step == 0:
            continue
        add_batch(step - 1)  # past data
Ejemplo n.º 4
0
def test_memory(scenario, memory_size, method, fixed):
    """Check the class count and memory size after each task is added.

    When ``fixed`` is set, the memory must hold exactly
    ``min(memory_size, 100) // 10`` samples per seen class. Otherwise its
    size must lie in ``[target - 5, target]``, where
    ``target = min(memory_size, 100, 10 * nb_seen)``.
    """
    memory = rehearsal.RehearsalMemory(memory_size, method, fixed, 10)
    assert len(memory) == 0

    capped_size = min(memory_size, 100)
    nb_seen = 0
    for taskset in scenario:
        x, y, t = taskset.get_raw_samples()
        nb_seen += 2  # each task is expected to introduce 2 new classes
        memory.add(x, y, t, x)
        assert memory.nb_classes == nb_seen

        if not fixed:
            target = min(capped_size, 10 * nb_seen)
            assert target - 5 <= len(memory) <= target
            continue

        assert len(memory) == (capped_size // 10) * nb_seen
Ejemplo n.º 5
0
def test_memory_name(name, method):
    """Check that the herding name given to the constructor maps to ``method``."""
    herding_memory = rehearsal.RehearsalMemory(20, name, True, 10)
    expected_method = method
    assert herding_memory.herding_method == expected_method