def test_trainersettings_structure():
    """
    Exercise TrainerSettings.structure with one valid SAC configuration
    and several inputs that are expected to be rejected.
    """
    valid_config = {
        "trainer_type": "sac",
        "hyperparameters": {"batch_size": 1024},
        "max_steps": 1.0,
        "reward_signals": {"curiosity": {"encoding_size": 64}},
    }
    structured = TrainerSettings.structure(valid_config, TrainerSettings)

    # The "sac" trainer type must select the SAC hyperparameter class.
    assert isinstance(structured.hyperparameters, SACSettings)
    assert structured.trainer_type == TrainerType.SAC
    # max_steps was supplied as a float (1.0) and must come back as an int.
    assert isinstance(structured.max_steps, int)
    assert RewardSignalType.CURIOSITY in structured.reward_signals

    # (expected exception, input) pairs, checked in order.
    rejected_inputs = [
        # Invalid trainer type.
        (
            ValueError,
            {
                "trainer_type": "puppo",
                "hyperparameters": {"batch_size": 1024},
                "max_steps": 1.0,
            },
        ),
        # Invalid hyperparameter.
        (
            TrainerConfigError,
            {
                "trainer_type": "ppo",
                "hyperparameters": {"notahyperparam": 1024},
                "max_steps": 1.0,
            },
        ),
        # Non-dict input.
        (TrainerConfigError, "notadict"),
        # Hyperparameters specified but trainer type left as default.
        # This shouldn't work as you could specify non-PPO hyperparameters.
        (TrainerConfigError, {"hyperparameters": {"batch_size": 1024}}),
    ]
    for expected_error, bad_input in rejected_inputs:
        with pytest.raises(expected_error):
            TrainerSettings.structure(bad_input, TrainerSettings)
def test_trainersettingsschedules_structure():
    """
    Check that schedule names given as strings in a PPO configuration
    are structured into ScheduleType members.
    """
    ppo_config = {
        "trainer_type": "ppo",
        "hyperparameters": {
            "learning_rate_schedule": "linear",
            "beta_schedule": "constant",
        },
    }
    structured = TrainerSettings.structure(ppo_config, TrainerSettings)
    hyperparams = structured.hyperparameters

    assert isinstance(hyperparams, PPOSettings)
    assert hyperparams.learning_rate_schedule == ScheduleType.LINEAR
    assert hyperparams.beta_schedule == ScheduleType.CONSTANT
    # epsilon_schedule is not present in the input; it is expected to be LINEAR.
    assert hyperparams.epsilon_schedule == ScheduleType.LINEAR