# Example #1
0
def demo_manual():
    """Apply the custom method to a Setting, creating both manually in code.

    Builds a ``TaskIncrementalRLSetting`` on cartpole and applies two methods
    to that same setting: first a ``BaselineMethod``, then a ``CustomMethod``.
    It then prints both result summaries side by side.

    Returns:
        ``(base_results, new_results)``: the objects returned by
        ``setting.apply`` for the baseline and for the custom method.
    """
    # Create any Setting from the tree. For the supervised-learning variant,
    # also import ``TaskIncrementalSetting`` from ``sequoia.settings`` and use:
    # setting = TaskIncrementalSetting(dataset="mnist", nb_tasks=5)  # SL
    from sequoia.settings import TaskIncrementalRLSetting

    setting = TaskIncrementalRLSetting(  # RL
        dataset="cartpole",
        # NOTE(review): keys look like the step at which each set of env
        # parameters takes effect (only "length" changes) - confirm.
        train_task_schedule={
            0: {"gravity": 10, "length": 0.5},
            5000: {"gravity": 10, "length": 1.0},
        },
        observe_state_directly=True,  # state input, rather than pixel input.
        max_steps=10_000,
    )

    def _run(method_cls, model_cls):
        """Build ``method_cls`` with ``model_cls.HParams()`` and apply it to ``setting``."""
        # Each run gets its own Config / TrainerConfig. That way, nothing a
        # previous ``apply`` call may have mutated is shared with this run.
        config = Config(debug=True)
        trainer_options = TrainerConfig(max_epochs=1)
        hparams = model_cls.HParams()
        method = method_cls(
            hparams=hparams,
            config=config,
            trainer_options=trainer_options,
        )
        return setting.apply(method, config=config)

    ## Get the results of the baseline method:
    base_results = _run(BaselineMethod, BaselineModel)

    ## Get the results for the 'improved' method:
    new_results = _run(CustomMethod, CustomizedBaselineModel)

    print("\n\nComparison: BaselineMethod vs CustomMethod")
    print("\n BaselineMethod results: ")
    print(base_results.summary())

    print("\n CustomMethod results: ")
    print(new_results.summary())

    return base_results, new_results
def test_multi_task_setting():
    """Check BaselineMethod's per-task results on MNIST in MultiTaskSLSetting.

    Trains for one epoch with wandb disabled. It then asserts that, for each
    of tasks 0-4 in ``results.final_performance_metrics``, the sample count
    matches a hard-coded value and the accuracy is at least 0.95.
    """
    method = BaselineMethod(no_wandb=True, max_epochs=1)
    setting = MultiTaskSLSetting(dataset="mnist")
    results = setting.apply(method)
    print(results.summary())

    metrics = results.final_performance_metrics
    # Expected number of evaluation samples for tasks 0..4.
    # NOTE(review): presumably tied to the dataset split / batching; update if those change.
    expected_n_samples = [2112, 2016, 1888, 1984, 1984]
    min_accuracy = 0.95
    n_tasks = len(expected_n_samples)

    # Check every sample count before any accuracy.
    for task_id, expected in enumerate(expected_n_samples):
        n_samples = metrics[task_id].n_samples
        assert n_samples == expected, (
            f"task {task_id}: expected {expected} samples, got {n_samples}"
        )

    for task_id in range(n_tasks):
        accuracy = metrics[task_id].accuracy
        assert accuracy >= min_accuracy, (
            f"task {task_id}: accuracy {accuracy} below {min_accuracy}"
        )
def test_class_incremental_setting():
    """Check BaselineMethod's per-task results in a default ClassIncrementalSetting.

    Trains for one epoch with wandb disabled. It then asserts that, for each
    of tasks 0-4 in ``results.final_performance_metrics``, the sample count
    matches a hard-coded value and the accuracy lies within a per-task
    [low, high] band.
    """
    method = BaselineMethod(no_wandb=True, max_epochs=1)
    setting = ClassIncrementalSetting()
    results = setting.apply(method)
    print(results.summary())

    metrics = results.final_performance_metrics
    # Expected number of evaluation samples for tasks 0..4.
    expected_n_samples = [1984, 2016, 1984, 2016, 1984]
    # Inclusive (low, high) accuracy bounds per task. Earlier tasks get lower
    # bands. NOTE(review): presumably this reflects forgetting of earlier
    # tasks after sequential training - confirm.
    accuracy_bounds = [
        (0.48, 0.55),
        (0.48, 0.55),
        (0.60, 0.95),
        (0.75, 0.98),
        (0.99, 1.00),
    ]

    # Check every sample count before any accuracy.
    for task_id, expected in enumerate(expected_n_samples):
        n_samples = metrics[task_id].n_samples
        assert n_samples == expected, (
            f"task {task_id}: expected {expected} samples, got {n_samples}"
        )

    for task_id, (low, high) in enumerate(accuracy_bounds):
        accuracy = metrics[task_id].accuracy
        assert low <= accuracy <= high, (
            f"task {task_id}: accuracy {accuracy} outside [{low}, {high}]"
        )
# Example #4
0
def baseline_demo_simple():
    """Run BaselineMethod (one epoch) on a two-task cartpole RL setting.

    Any Setting from the tree could be used instead; a supervised-learning
    alternative is sketched in the comment below. Prints the results summary
    and returns the results object produced by ``setting.apply``.
    """
    cfg = Config()
    baseline = BaselineMethod(config=cfg, max_epochs=1)

    # Supervised-learning alternative:
    #     setting = TaskIncrementalSLSetting(dataset="cifar10", nb_tasks=2)
    rl_options = dict(dataset="cartpole", max_steps=4000, nb_tasks=2)
    rl_setting = TaskIncrementalRLSetting(**rl_options)

    outcome = rl_setting.apply(baseline, config=cfg)
    summary = outcome.summary()
    print(summary)
    return outcome