示例#1
0
def rl_track_setting(tmp_path):
    """Build a shortened MonsterKong ``IncrementalRLSetting`` for testing.

    NOTE: Instead of loading the `rl_track.yaml` preset, the setting is
    instantiated directly, because we want to reduce the length of each task
    for testing, and it isn't currently possible to both pass a preset yaml
    file and also pass kwargs to the SettingProxy.

    Args:
        tmp_path: Not used by this function. NOTE(review): presumably the
            pytest ``tmp_path`` fixture; kept so the signature is unchanged.

    Returns:
        The configured ``SettingProxy``.
    """
    steps_per_task = 2_000  # Reduced length for testing.
    # MonsterKong levels, one per task, in schedule order.
    levels = [0, 1, 10, 11, 20, 21, 30, 31]
    train_task_schedule = {
        task_index: {"level": level} for task_index, level in enumerate(levels)
    }
    setting = SettingProxy(
        IncrementalRLSetting,
        dataset="monsterkong",
        train_task_schedule=train_task_schedule,
        steps_per_task=steps_per_task,
        test_steps_per_task=steps_per_task,
        monitor_training_performance=True,
        task_labels_at_train_time=True,
    )
    assert setting.steps_per_phase == steps_per_task
    # Per this assertion, the setting is expected to expose the schedule keyed
    # by the step at which each task starts (task_index * steps_per_task).
    expected_keys = [i * steps_per_task for i in range(len(levels))]
    assert sorted(setting.train_task_schedule.keys()) == expected_keys
    return setting
示例#2
0
def cartpole_state_setting():
    """Return an ``RLSetting`` proxy on the ``"cartpole"`` dataset.

    Uses ``max_steps=5_000`` and ``test_steps=2_000``, with training
    performance monitoring enabled.
    """
    overrides = dict(
        dataset="cartpole",
        max_steps=5_000,
        test_steps=2_000,
        monitor_training_performance=True,
    )
    return SettingProxy(RLSetting, **overrides)
示例#3
0
def sl_track_setting():
    """Return a ``ClassIncrementalSetting`` proxy built from ``"sl_track"``.

    No options are overridden here. ``dataset``, ``nb_tasks`` and
    ``class_order`` were previously sketched as commented-out kwargs but are
    left to whatever ``"sl_track"`` provides.
    """
    return SettingProxy(ClassIncrementalSetting, "sl_track")
示例#4
0
def incremental_cartpole_state_setting():
    """Return an ``IncrementalRLSetting`` proxy on ``"cartpole"`` with 2 tasks.

    Uses ``max_steps=10_000`` and ``test_steps=2_000``, with training
    performance monitoring enabled.
    """
    options = {
        "dataset": "cartpole",
        "max_steps": 10_000,
        "nb_tasks": 2,
        "test_steps": 2_000,
        "monitor_training_performance": True,
    }
    return SettingProxy(IncrementalRLSetting, **options)
示例#5
0
def sl_track_setting():
    """Create the ``"sl_track"`` ``ClassIncrementalSetting`` proxy.

    Nothing is overridden: ``dataset``, ``nb_tasks``, ``class_order`` and
    ``monitor_training_performance`` (formerly commented-out kwargs) are all
    left to ``"sl_track"``.
    """
    proxy = SettingProxy(ClassIncrementalSetting, "sl_track")
    return proxy
示例#6
0
def rl_track_setting():
    """Return an ``IncrementalRLSetting`` proxy from ``"rl_track"``, shortened.

    ``steps_per_task`` and ``test_steps_per_task`` are both overridden to
    2,000, just for testing.

    NOTE(review): another snippet in this file states that a preset yaml file
    and extra kwargs cannot currently both be passed to ``SettingProxy``;
    confirm the step overrides below actually take effect.
    """
    # TODO: Levels 0-20 work for now in MonsterKong.
    short_run = dict(steps_per_task=2_000, test_steps_per_task=2_000)
    return SettingProxy(IncrementalRLSetting, "rl_track", **short_run)
示例#7
0
def run_track(method: Method, setting: Setting, yamlfile: str) -> Results:
    """Apply ``method`` to a setting configured from ``yamlfile``; print results.

    Args:
        method: The method to apply.
        setting: Forwarded as the first argument of ``SettingProxy``.
            NOTE(review): other ``SettingProxy`` calls in this file pass a
            setting *class* (e.g. ``IncrementalRLSetting``) in this position,
            so the annotation should probably be ``Type[Setting]`` — confirm.
        yamlfile: Forwarded as the second argument of ``SettingProxy``.

    Returns:
        The ``Results`` produced by ``apply``. Previously the function printed
        them but implicitly returned ``None`` despite its ``-> Results``
        annotation.
    """
    # Use a separate name so the ``setting`` parameter isn't shadowed.
    proxy = SettingProxy(setting, yamlfile)
    results = proxy.apply(method)
    print(f"Results summary:\n{results.summary()}")
    print("=====================")
    print(results.to_log_dict())
    return results
示例#8
0
def mnist_setting():
    """Create a ``ClassIncrementalSetting`` proxy on ``"mnist"``.

    Training performance monitoring is enabled.
    """
    proxy = SettingProxy(
        ClassIncrementalSetting,
        monitor_training_performance=True,
        dataset="mnist",
    )
    return proxy
示例#9
0
def mnist_setting():
    """Create a ``ClassIncrementalSetting`` proxy on the ``"mnist"`` dataset."""
    dataset_name = "mnist"
    return SettingProxy(ClassIncrementalSetting, dataset=dataset_name)
示例#10
0
def fashion_mnist_setting():
    """Create a ``ClassIncrementalSetting`` proxy on ``"fashionmnist"``."""
    proxy = SettingProxy(ClassIncrementalSetting, dataset="fashionmnist")
    return proxy