def test_init_cartpole_rllib_model():
    """Test runner initialisation with the native RLlib model on CartPole.

    Loads the ``conf_rllib`` hydra config with the ``rllib`` model override,
    sets up the runner and checks the resulting ray / rllib / tune configs.
    Every entry of the generated RLlib model config must exist in, and agree
    with, RLlib's ``MODEL_DEFAULTS`` for the installed ray version.

    If the config contains an ObservationNormalizationWrapper, its statistics
    dump is expected to exist after setup. The dump is removed in a
    ``finally`` clause so a failing run does not leave it behind.
    """
    hydra_overrides = {'rllib/runner': 'dev', 'model': 'rllib'}
    cfg = load_hydra_config('maze.conf', 'conf_rllib', hydra_overrides)

    try:
        runner = Factory(base_type=MazeRLlibRunner).instantiate(cfg.runner)
        runner.setup(cfg)
        ray_config, rllib_config, tune_config = runner.ray_config, runner.rllib_config, runner.tune_config

        assert isinstance(runner.env_factory(), CartPoleEnv)
        assert isinstance(ray_config, dict)
        assert isinstance(rllib_config, dict)
        assert isinstance(tune_config, dict)

        assert rllib_config['env'] == 'maze_env'
        assert rllib_config['framework'] == 'torch'
        assert rllib_config['num_workers'] == 1

        for k, v in rllib_config['model'].items():
            # The maze config carries the placeholder as the plain string
            # "DEPRECATED_VALUE"; swap in RLlib's DEPRECATED_VALUE object so the
            # comparison against MODEL_DEFAULTS is like-for-like.
            if v == "DEPRECATED_VALUE":
                v = DEPRECATED_VALUE
            assert k in MODEL_DEFAULTS, (
                f"Maze RLlib model parameter '{k}' not in RLlib MODEL_DEFAULTS (rllib version: "
                f"{ray.__version__})"
            )
            assert MODEL_DEFAULTS[k] == v, (
                f"Rllib key:'{k}',value:'{MODEL_DEFAULTS[k]}' does not match with the "
                f"maze defined config '{v}' with rllib version: {ray.__version__}"
            )

        if 'ObservationNormalizationWrapper' in cfg.wrappers:
            assert os.path.exists(
                cfg.wrappers.ObservationNormalizationWrapper.statistics_dump)
    finally:
        # Always clean up the statistics dump, even if an assertion above
        # failed or setup raised, so later tests start from a clean state.
        if 'ObservationNormalizationWrapper' in cfg.wrappers:
            dump_file = cfg.wrappers.ObservationNormalizationWrapper.statistics_dump
            if os.path.exists(dump_file):
                os.remove(dump_file)
def test_init_cartpole_maze_model():
    """Test runner initialisation with a Maze custom model on CartPole.

    Loads the ``conf_rllib`` hydra config with the gym env, the
    ``vector_obs`` model and wrappers, and the ``template_state`` critic, then
    sets up the runner. It checks:

    * the ``maze_dist`` / ``maze_model`` entries in ray's global registry
      resolve to the Maze RLlib action distribution / policy model classes;
    * the ray / rllib / tune configs have the expected basic values;
    * the RLlib model config points at the Maze custom model and action
      distribution, with the Maze model composer config and spaces dump file
      passed through ``custom_model_config``.

    If the config contains an ObservationNormalizationWrapper, its statistics
    dump is expected to exist after setup. The dump is removed in a
    ``finally`` clause so a failing run does not leave it behind.
    """
    hydra_overrides = {
        'rllib/runner': 'dev',
        'configuration': 'test',
        'env': 'gym_env',
        'model': 'vector_obs',
        'wrappers': 'vector_obs',
        'critic': 'template_state',
    }
    cfg = load_hydra_config('maze.conf', 'conf_rllib', hydra_overrides)

    try:
        runner = Factory(base_type=MazeRLlibRunner).instantiate(cfg.runner)
        runner.setup(cfg)
        ray_config, rllib_config, tune_config = runner.ray_config, runner.rllib_config, runner.tune_config

        assert isinstance(runner.env_factory(), CartPoleEnv)
        assert issubclass(_global_registry.get(RLLIB_ACTION_DIST, 'maze_dist'), MazeRLlibActionDistribution)
        assert issubclass(_global_registry.get(RLLIB_MODEL, 'maze_model'), MazeRLlibPolicyModel)

        assert isinstance(ray_config, dict)
        assert isinstance(rllib_config, dict)
        assert isinstance(tune_config, dict)

        assert rllib_config['env'] == 'maze_env'
        assert rllib_config['framework'] == 'torch'
        assert rllib_config['num_workers'] == 1

        model_config = rllib_config['model']
        assert model_config['custom_action_dist'] == 'maze_dist'
        assert model_config['custom_model'] == 'maze_model'
        assert model_config['vf_share_layers'] is False

        custom_model_config = model_config['custom_model_config']
        assert custom_model_config['maze_model_composer_config'] == cfg.model
        assert custom_model_config['spaces_config_dump_file'] == cfg.runner.spaces_config_dump_file

        if 'ObservationNormalizationWrapper' in cfg.wrappers:
            assert os.path.exists(
                cfg.wrappers.ObservationNormalizationWrapper.statistics_dump)
    finally:
        # Always clean up the statistics dump, even if an assertion above
        # failed or setup raised, so later tests start from a clean state.
        if 'ObservationNormalizationWrapper' in cfg.wrappers:
            dump_file = cfg.wrappers.ObservationNormalizationWrapper.statistics_dump
            if os.path.exists(dump_file):
                os.remove(dump_file)