def test_on_task_switch_is_called_multi_task():
    """Check that `on_task_switch` is never called in a multi-task setting.

    In a MultiTaskSetting the method trains on all tasks jointly, so the
    DummyMethod should record zero task switches, and no task ids or
    training flags.
    """
    setting = MultiTaskSetting(
        dataset="mnist",
        nb_tasks=5,
        # steps_per_task=100,
        # max_steps=500,
        # test_steps_per_task=100,
        train_transforms=[],
        test_transforms=[],
        val_transforms=[],
    )
    method = DummyMethod()
    # The results object isn't inspected here; only the callbacks that were
    # recorded on the method matter, so don't keep an unused local around.
    setting.apply(method)
    assert method.n_task_switches == 0
    assert method.received_task_ids == []
    assert method.received_while_training == []
# ---- Example 2 ----
def test_fit_and_on_task_switch_calls():
    """Check fit / on_task_switch accounting in a smooth ContinualRLSetting.

    In the "smooth" continual setting there are no explicit task boundaries,
    so `on_task_switch` is never called, and `fit` runs exactly once over the
    whole non-stationary stream.
    """
    setting = ContinualRLSetting(
        dataset=DummyEnvironment,
        nb_tasks=5,
        steps_per_task=100,
        max_steps=500,
        test_steps_per_task=100,
        train_transforms=[],
        test_transforms=[],
        val_transforms=[],
    )
    method = DummyMethod()
    setting.apply(method)
    # No task boundaries are announced in the smooth continual setting, hence
    # zero switches (the previous "== 30 in total" comment was stale).
    assert method.n_task_switches == 0
    assert method.n_fit_calls == 1  # TODO: Add something like this.
    assert not method.received_task_ids
    assert not method.received_while_training
# ---- Example 3 ----
def test_on_task_switch_is_called_incremental_rl():
    """`on_task_switch` fires at every boundary; task ids are hidden at test time."""
    setting = IncrementalRLSetting(
        dataset=DummyEnvironment,
        nb_tasks=5,
        steps_per_task=100,
        max_steps=500,
        test_steps_per_task=100,
        train_transforms=[],
        test_transforms=[],
        val_transforms=[],
    )
    method = DummyMethod()
    _ = setting.apply(method)
    # Per task: one switch when training starts on it, then 5 switches during
    # each test loop. Task labels are unavailable at test time, so the test
    # switches carry `None`. 5 tasks * (1 + 5) == 30 switches in total.
    expected_task_ids = []
    expected_training_flags = []
    for task_id in range(5):
        expected_task_ids += [task_id] + [None] * 5
        expected_training_flags += [True] + [False] * 5
    assert method.n_task_switches == 30
    assert method.received_task_ids == expected_task_ids
    assert method.received_while_training == expected_training_flags
# ---- Example 4 ----
def test_metaworld_auto_task_schedule(
        pass_env_id_instead_of_env_instance: bool):
    """ Test that when passing just an env id from metaworld and a number of tasks,
    the task schedule is created automatically.
    """
    import metaworld
    from metaworld import MetaWorldEnv

    # Construct the ML10 benchmark, which samples its tasks on creation.
    benchmark = metaworld.ML10()

    env_name = "reach-v1"
    env_type: Type[MetaWorldEnv] = benchmark.train_classes[env_name]
    env = env_type()

    # TODO: When not passing a nb_tasks, the number of available tasks for that env
    # is used.
    # setting = TaskIncrementalRLSetting(
    #     dataset=env_name if pass_env_id_instead_of_env_instance else env,
    #     steps_per_task=1000,
    # )
    # assert setting.nb_tasks == 50
    # assert setting.steps_per_task == 1000
    # assert sorted(setting.train_task_schedule.keys()) == list(range(0, 50_000, 1000))

    # Test passing an explicit number of tasks:
    dataset = env_name if pass_env_id_instead_of_env_instance else env
    setting = TaskIncrementalRLSetting(
        dataset=dataset,
        steps_per_task=1000,
        nb_tasks=2,
        test_steps_per_task=1000,
        transforms=[],
    )
    assert setting.nb_tasks == 2
    assert setting.steps_per_task == 1000
    # One schedule entry per task boundary: steps 0 and 1000.
    assert sorted(setting.train_task_schedule.keys()) == [0, 1000]
    from sequoia.common.metrics.rl_metrics import EpisodeMetrics

    method = DummyMethod()
    results: IncrementalRLSetting.Results[EpisodeMetrics] = setting.apply(
        method)
# ---- Example 5 ----
def test_on_task_switch_is_called_task_incremental_rl():
    """With task labels at test time, real task ids are received while testing."""
    setting = IncrementalRLSetting(
        dataset=DummyEnvironment,
        nb_tasks=5,
        steps_per_task=100,
        test_steps_per_task=100,
        max_steps=500,
        train_transforms=[],
        test_transforms=[],
        val_transforms=[],
        task_labels_at_test_time=True,
    )
    method = DummyMethod()
    _ = setting.apply(method)
    # 5 tasks * (1 training switch + 5 test-loop switches) == 30 in total.
    assert method.n_task_switches == 30
    expected_ids = []
    expected_flags = []
    for current_task in range(5):
        # Training switch announces the current task, then the test loop
        # announces each of the 5 task ids in order.
        expected_ids.extend([current_task, *range(5)])
        expected_flags.extend([True, *[False] * 5])
    assert method.received_task_ids == expected_ids
    assert method.received_while_training == expected_flags
# ---- Example 6 ----
def test_monsterkong(task_labels_at_test_time: bool, state: bool):
    """ checks that the MonsterKong env works fine with monsterkong and state input. """
    setting = IncrementalRLSetting(
        dataset="monsterkong",
        observe_state_directly=state,
        nb_tasks=5,
        steps_per_task=100,
        test_steps_per_task=100,
        train_transforms=[],
        test_transforms=[],
        val_transforms=[],
        task_labels_at_test_time=task_labels_at_test_time,
        max_episode_steps=10,
    )

    if state:
        # State-based monsterkong: a flattened version of the game state
        # (20 x 20 grid + player cell and goal cell, IIRC.)
        expected_x_space = spaces.Box(0, 292, (402, ), np.int16)
    else:
        expected_x_space = Image(0, 255, (64, 64, 3), np.uint8)
    assert setting.observation_space.x == expected_x_space

    if task_labels_at_test_time:
        expected_label_space = spaces.Discrete(5)
    else:
        expected_label_space = Sparse(spaces.Discrete(5), sparsity=0.0)
    assert setting.observation_space.task_labels == expected_label_space

    # 5 tasks * 100 test steps per task.
    assert setting.test_steps == 500
    with setting.train_dataloader() as env:
        first_obs = env.reset()
        assert first_obs in setting.observation_space

    method = DummyMethod()
    _ = setting.apply(method)

    # 5 tasks * (1 training switch + 5 test-loop switches) == 30 in total.
    assert method.n_task_switches == 30
    expected_ids = []
    for task in range(5):
        if task_labels_at_test_time:
            # Test-loop switches carry the real task ids.
            expected_ids.extend([task, *range(5)])
        else:
            # Task labels hidden at test time => `None` ids.
            expected_ids.extend([task, *[None] * 5])
    assert method.received_task_ids == expected_ids
    # Training flags: True for the training switch, False for the 5 test ones.
    assert method.received_while_training == [True, *[False] * 5] * 5
# ---- Example 7 ----
def test_monsterkong_pixels(task_labels_at_test_time: bool):
    """Checks that the MonsterKong env works fine with pixel (image) observations.

    (The previous docstring was copy-pasted from the state-based test; this
    variant forces `observe_state_directly=False`, i.e. image inputs.)
    """
    setting = IncrementalRLSetting(
        dataset="monsterkong",
        observe_state_directly=False,
        nb_tasks=5,
        steps_per_task=100,
        test_steps_per_task=100,
        train_transforms=[],
        test_transforms=[],
        val_transforms=[],
        task_labels_at_test_time=task_labels_at_test_time,
        max_episode_steps=10,
    )
    # 5 tasks * 100 test steps per task.
    assert setting.test_steps == 500
    assert setting.observation_space.x == Image(0, 255, (64, 64, 3), np.uint8)
    with setting.train_dataloader() as env:
        obs = env.reset()
        assert obs in setting.observation_space

    method = DummyMethod()
    # The results object isn't inspected; only the callbacks recorded on the
    # method matter, so don't keep an unused local around.
    setting.apply(method)

    # 5 tasks * (1 training switch + 5 test-loop switches) == 30 in total.
    assert method.n_task_switches == 30
    if task_labels_at_test_time:
        assert method.received_task_ids == [
            0,
            *list(range(5)),
            1,
            *list(range(5)),
            2,
            *list(range(5)),
            3,
            *list(range(5)),
            4,
            *list(range(5)),
        ]
    else:
        assert method.received_task_ids == [
            0,
            *[None for _ in range(5)],
            1,
            *[None for _ in range(5)],
            2,
            *[None for _ in range(5)],
            3,
            *[None for _ in range(5)],
            4,
            *[None for _ in range(5)],
        ]
    assert method.received_while_training == [
        True,
        *[False for _ in range(5)],
        True,
        *[False for _ in range(5)],
        True,
        *[False for _ in range(5)],
        True,
        *[False for _ in range(5)],
        True,
        *[False for _ in range(5)],
    ]