def test_avail_actions_qmix(self):
    """QMIX must obey the env's per-step action mask and hit max return.

    Groups two agents of AvailActionsTestEnv into a single QMIX team,
    trains for a few iterations, and asserts the mean episode reward
    reaches 30.0 — the env itself asserts if a masked action is taken.
    """
    grouping = {"group_1": ["agent_1", "agent_2"]}
    # Grouped Tuple spaces carry one entry per agent in the group.
    obs_space = Tuple([AvailActionsTestEnv.observation_space] * 2)
    act_space = Tuple([AvailActionsTestEnv.action_space] * 2)
    register_env(
        "action_mask_test",
        lambda config: AvailActionsTestEnv(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space))
    qmix = QMixTrainer(
        env="action_mask_test",
        config={
            "num_envs_per_worker": 5,  # test with vectorization on
            "env_config": {
                "avail_actions": [3, 4, 8],
            },
            "framework": "torch",
        })
    for _ in range(4):
        # OK if it doesn't trip the env's action assertion error.
        qmix.train()
    assert qmix.train()["episode_reward_mean"] == 30.0
    qmix.stop()
    ray.shutdown()
def test_avail_actions_qmix(self):
    """QMIX must obey the env's action mask (trivial one-agent grouping).

    Trains a QMixTrainer on AvailActionsTestEnv and asserts the mean
    episode reward reaches 21.0 — the env itself asserts if a masked
    action is ever taken.
    """
    grouping = {
        "group_1": ["agent_1"],  # trivial grouping for testing
    }
    obs_space = Tuple([AvailActionsTestEnv.observation_space])
    act_space = Tuple([AvailActionsTestEnv.action_space])
    register_env(
        "action_mask_test",
        lambda config: AvailActionsTestEnv(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space))
    # ignore_reinit_error guards against a previous test (or caller)
    # having already started Ray in this process.
    ray.init(ignore_reinit_error=True)
    agent = QMixTrainer(
        env="action_mask_test",
        config={
            "num_envs_per_worker": 5,  # test with vectorization on
            "env_config": {
                "avail_action": 3,
            },
        })
    try:
        for _ in range(5):
            # OK if it doesn't trip the env's action assertion error.
            agent.train()
        assert agent.train()["episode_reward_mean"] == 21.0
    finally:
        # BUG FIX: the original leaked the trainer's rollout workers and
        # left the Ray runtime up, breaking any later ray.init() in the
        # same process. Always tear down, even when the assert fails.
        agent.stop()
        ray.shutdown()
"Failed to obey available actions mask!" self.state += 1 rewards = {"agent_1": 1} obs = {"agent_1": {"obs": 0, "action_mask": self.action_mask}} dones = {"__all__": self.state > 20} return obs, rewards, dones, {} if __name__ == "__main__": grouping = { "group_1": ["agent_1"], # trivial grouping for testing } obs_space = Tuple([AvailActionsTestEnv.observation_space]) act_space = Tuple([AvailActionsTestEnv.action_space]) register_env( "action_mask_test", lambda config: AvailActionsTestEnv(config). with_agent_groups(grouping, obs_space=obs_space, act_space=act_space)) ray.init() agent = QMixTrainer( env="action_mask_test", config={ "num_envs_per_worker": 5, # test with vectorization on "env_config": { "avail_action": 3, }, }) for _ in range(5): agent.train() # OK if it doesn't trip the action assertion error assert agent.train()["episode_reward_mean"] == 21.0
def main(args):
    """Entry point: train traffic-signal agents with DQN or QMIX.

    Reads the CityFlow roadnet layout out of the global config file,
    registers the (grouped, for QMIX) environment, then either launches
    a ``tune.run`` sweep or trains the selected trainer for
    ``args.n_epoch`` iterations.

    Args:
        args: parsed CLI namespace; this function uses ``args.mod``
            ("DQN" or "QMIX"), ``args.tune`` (bool) and
            ``args.n_epoch`` (int).
    """
    logging.getLogger().setLevel(logging.INFO)
    date = datetime.now().strftime('%Y%m%d_%H%M%S')
    config_env = env_config(args)

    # BUG FIX: open the config through a context manager so the file
    # handle is always closed; the original `json.load(open(...))`
    # leaked it. NOTE(review): hard-coded absolute path — consider
    # promoting it to a CLI argument.
    with open('/home/skylark/PycharmRemote/Gamma-Reward-Perfect/config/global_config.json') as f:
        global_config = json.load(f)

    # Roadnet file names encode the grid size, e.g. "roadnet_3_4.json"
    # -> 3 rows x 4 columns.
    roadnet = global_config['cityflow_config_file']
    roadnet_list = roadnet.split('_')
    num_row = int(roadnet_list[1])
    num_col = int(roadnet_list[2].split('.')[0])
    print('\033[1;35mThe scale of the current roadnet is {} x {}\033[0m'
          .format(num_row, num_col))
    num_agents = num_row * num_col

    if args.mod == 'DQN':
        obs_space = CityFlowEnvRay.observation_space
        act_space = CityFlowEnvRay.action_space
        ModelCatalog.register_custom_model("MyKerasQModel", MyKerasQModel)
    elif args.mod == 'QMIX':
        # QMIX requires all agents grouped into one team whose joint
        # observation/action spaces are Tuples, one entry per agent.
        grouping = {
            "group_1": list(CityFlowEnvRay.intersection_id)
        }
        obs_space = Tuple(
            [CityFlowEnvRay.observation_space for _ in range(num_agents)])
        act_space = Tuple(
            [CityFlowEnvRay.action_space for _ in range(num_agents)])
        register_env(
            "cityflow_multi",
            lambda config_: CityFlowEnvRay(config_).with_agent_groups(
                grouping, obs_space=obs_space, act_space=act_space))

    config_agent = agent_config(args, num_agents, obs_space, act_space,
                                config_env, num_row)
    ray.init(local_mode=False,
             redis_max_memory=1024 * 1024 * 40,
             temp_dir='/home/skylark/log/')

    if args.tune:  # False
        tune.run(
            "DQN",
            stop={
                "timesteps_total": 400000,
            },
            config=config_agent,
            checkpoint_freq=2,
            checkpoint_at_end=True,
        )
    elif args.mod == 'DQN':  # True
        trainer = DQNTrainer(config=config_agent, env=CityFlowEnvRay)
        for i in range(args.n_epoch):
            # Perform one iteration of training the policy with DQN.
            result = trainer.train()
            print(pretty_print(result))
    elif args.mod == 'QMIX':  # False
        trainer = QMixTrainer(config=config_agent, env="cityflow_multi")
        for i in range(args.n_epoch):
            # Perform one iteration of training the policy with QMIX.
            result = trainer.train()
            print(pretty_print(result))