def test_avail_actions_qmix(self):
    """Train QMIX on a grouped, action-masked env and check the mean reward.

    ``agent_1`` and ``agent_2`` are merged into the single group
    ``group_1``; the group's observation and action spaces are ``Tuple``s
    holding one copy of the ``AvailActionsTestEnv`` space per agent. The env
    is registered as ``"action_mask_test"`` and trained with five vectorized
    env copies per worker on the torch framework.

    The test passes when none of the five training iterations raises and the
    last one reports ``episode_reward_mean == 30.0``.

    The trainer is stopped and Ray is shut down even when training or the
    final assertion fails, so a failure here does not leak a running trainer
    or Ray runtime into subsequent tests.
    """
    grouping = {
        "group_1": ["agent_1", "agent_2"],
    }
    obs_space = Tuple([
        AvailActionsTestEnv.observation_space,
        AvailActionsTestEnv.observation_space,
    ])
    act_space = Tuple([
        AvailActionsTestEnv.action_space,
        AvailActionsTestEnv.action_space,
    ])
    register_env(
        "action_mask_test",
        lambda config: AvailActionsTestEnv(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space))

    try:
        trainer = QMixTrainer(
            env="action_mask_test",
            config={
                "num_envs_per_worker": 5,  # test with vectorization on
                "env_config": {
                    "avail_actions": [3, 4, 8],
                },
                "framework": "torch",
            })
        try:
            for _ in range(4):
                # OK if it doesn't trip the action assertion error
                trainer.train()
            assert trainer.train()["episode_reward_mean"] == 30.0
        finally:
            # Release the trainer's resources on failure as well as success.
            trainer.stop()
    finally:
        # Always tear Ray down, including when trainer construction fails.
        ray.shutdown()
def test_avail_actions_qmix(self):
    """Train QMIX on a single-agent group of the action-masked test env.

    ``agent_1`` is wrapped in the trivial group ``group_1`` whose observation
    and action spaces are one-element ``Tuple``s of the
    ``AvailActionsTestEnv`` spaces. The env is registered as
    ``"action_mask_test"`` and trained with five vectorized env copies per
    worker.

    The test passes when none of the six training iterations raises and the
    last one reports ``episode_reward_mean == 21.0``.

    Ray is started by this test, so it is also shut down here; the trainer
    is stopped and Ray is torn down even when training or the final
    assertion fails.
    """
    grouping = {
        "group_1": ["agent_1"],  # trivial grouping for testing
    }
    obs_space = Tuple([AvailActionsTestEnv.observation_space])
    act_space = Tuple([AvailActionsTestEnv.action_space])
    register_env(
        "action_mask_test",
        lambda config: AvailActionsTestEnv(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space))

    ray.init()
    try:
        agent = QMixTrainer(
            env="action_mask_test",
            config={
                "num_envs_per_worker": 5,  # test with vectorization on
                "env_config": {
                    "avail_action": 3,
                },
            })
        try:
            for _ in range(5):
                # OK if it doesn't trip the action assertion error
                agent.train()
            assert agent.train()["episode_reward_mean"] == 21.0
        finally:
            # Release the trainer's resources on failure as well as success.
            agent.stop()
    finally:
        # Pair the ray.init() above so later tests can initialize Ray again.
        ray.shutdown()
"Failed to obey available actions mask!" self.state += 1 rewards = {"agent_1": 1} obs = {"agent_1": {"obs": 0, "action_mask": self.action_mask}} dones = {"__all__": self.state > 20} return obs, rewards, dones, {} if __name__ == "__main__": grouping = { "group_1": ["agent_1"], # trivial grouping for testing } obs_space = Tuple([AvailActionsTestEnv.observation_space]) act_space = Tuple([AvailActionsTestEnv.action_space]) register_env( "action_mask_test", lambda config: AvailActionsTestEnv(config). with_agent_groups(grouping, obs_space=obs_space, act_space=act_space)) ray.init() agent = QMixTrainer( env="action_mask_test", config={ "num_envs_per_worker": 5, # test with vectorization on "env_config": { "avail_action": 3, }, }) for _ in range(5): agent.train() # OK if it doesn't trip the action assertion error assert agent.train()["episode_reward_mean"] == 21.0