Esempio n. 1
0
def test_trainersettings_structure():
    """
    Exercise TrainerSettings.structure with one valid SAC config and
    several malformed inputs that are expected to be rejected.
    """
    sac_config = {
        "trainer_type": "sac",
        "hyperparameters": {"batch_size": 1024},
        "max_steps": 1.0,
        "reward_signals": {"curiosity": {"encoding_size": 64}},
    }
    structured = TrainerSettings.structure(sac_config, TrainerSettings)
    assert isinstance(structured.hyperparameters, SACSettings)
    assert structured.trainer_type == TrainerType.SAC
    # max_steps was given as the float 1.0; the structured value must be an int.
    assert isinstance(structured.max_steps, int)
    assert RewardSignalType.CURIOSITY in structured.reward_signals

    # An unrecognised trainer type is expected to raise ValueError.
    unknown_trainer_config = {
        "trainer_type": "puppo",
        "hyperparameters": {"batch_size": 1024},
        "max_steps": 1.0,
    }
    with pytest.raises(ValueError):
        TrainerSettings.structure(unknown_trainer_config, TrainerSettings)

    # A hyperparameter name that does not exist is expected to be rejected.
    unknown_hyperparam_config = {
        "trainer_type": "ppo",
        "hyperparameters": {"notahyperparam": 1024},
        "max_steps": 1.0,
    }
    with pytest.raises(TrainerConfigError):
        TrainerSettings.structure(unknown_hyperparam_config, TrainerSettings)

    # Input that is not a dict at all is expected to be rejected.
    with pytest.raises(TrainerConfigError):
        TrainerSettings.structure("notadict", TrainerSettings)

    # Hyperparameters given while the trainer type is left unspecified.
    # This must fail, since the hyperparameters could belong to a non-PPO trainer.
    missing_trainer_type_config = {"hyperparameters": {"batch_size": 1024}}
    with pytest.raises(TrainerConfigError):
        TrainerSettings.structure(missing_trainer_type_config, TrainerSettings)
Esempio n. 2
0
def test_trainersettingsschedules_structure():
    """
    Check that schedule names in a PPO config are structured into
    ScheduleType members on the resulting hyperparameters.
    """
    ppo_config = {
        "trainer_type": "ppo",
        "hyperparameters": {
            "learning_rate_schedule": "linear",
            "beta_schedule": "constant",
        },
    }
    structured = TrainerSettings.structure(ppo_config, TrainerSettings)
    hyperparams = structured.hyperparameters
    assert isinstance(hyperparams, PPOSettings)
    # The two schedules given explicitly map to the matching enum members.
    assert hyperparams.learning_rate_schedule == ScheduleType.LINEAR
    assert hyperparams.beta_schedule == ScheduleType.CONSTANT
    # epsilon_schedule is absent from the config; the test expects LINEAR.
    assert hyperparams.epsilon_schedule == ScheduleType.LINEAR