Example #1
    def test_avail_actions_qmix(self):
        grouping = {
            "group_1": ["agent_1", "agent_2"],
        }
        obs_space = Tuple([
            AvailActionsTestEnv.observation_space,
            AvailActionsTestEnv.observation_space
        ])
        act_space = Tuple([
            AvailActionsTestEnv.action_space, AvailActionsTestEnv.action_space
        ])
        register_env(
            "action_mask_test",
            lambda config: AvailActionsTestEnv(config).with_agent_groups(
                grouping, obs_space=obs_space, act_space=act_space))

        trainer = QMixTrainer(
            env="action_mask_test",
            config={
                "num_envs_per_worker": 5,  # test with vectorization on
                "env_config": {
                    "avail_actions": [3, 4, 8],
                },
                "framework": "torch",
            })
        for _ in range(4):
            trainer.train()  # OK if it doesn't trip the action assertion error
        assert trainer.train()["episode_reward_mean"] == 30.0
        trainer.stop()
        ray.shutdown()
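Example #1 calls ray.shutdown() at the end but never calls ray.init(); in the original test file that bracketing typically lives in the test class's setUp()/tearDown(). A minimal sketch of such a wrapper, assuming a plain unittest.TestCase (the class name and the num_cpus value are assumptions, not taken from the examples):

import unittest

import ray


class QMixAvailActionsTest(unittest.TestCase):
    def setUp(self):
        # Start Ray once per test; the test body (or tearDown) shuts it down.
        ray.init(num_cpus=4)

    def tearDown(self):
        ray.shutdown()  # safe even if the test already called ray.shutdown()

    # test_avail_actions_qmix() from Example #1 would be a method of this class.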
Example #2
    def test_avail_actions_qmix(self):
        grouping = {
            "group_1": ["agent_1"],  # trivial grouping for testing
        }
        obs_space = Tuple([AvailActionsTestEnv.observation_space])
        act_space = Tuple([AvailActionsTestEnv.action_space])
        register_env(
            "action_mask_test",
            lambda config: AvailActionsTestEnv(config).with_agent_groups(
                grouping, obs_space=obs_space, act_space=act_space))

        ray.init()
        agent = QMixTrainer(
            env="action_mask_test",
            config={
                "num_envs_per_worker": 5,  # test with vectorization on
                "env_config": {
                    "avail_action": 3,
                },
            })
        for _ in range(5):
            agent.train()  # OK if it doesn't trip the action assertion error
        assert agent.train()["episode_reward_mean"] == 21.0
Example #3
                "Failed to obey available actions mask!"
        self.state += 1
        rewards = {"agent_1": 1}
        obs = {"agent_1": {"obs": 0, "action_mask": self.action_mask}}
        dones = {"__all__": self.state > 20}
        return obs, rewards, dones, {}


if __name__ == "__main__":
    grouping = {
        "group_1": ["agent_1"],  # trivial grouping for testing
    }
    obs_space = Tuple([AvailActionsTestEnv.observation_space])
    act_space = Tuple([AvailActionsTestEnv.action_space])
    register_env(
        "action_mask_test",
        lambda config: AvailActionsTestEnv(config).with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space))

    ray.init()
    agent = QMixTrainer(
        env="action_mask_test",
        config={
            "num_envs_per_worker": 5,  # test with vectorization on
            "env_config": {
                "avail_action": 3,
            },
        })
    for _ in range(5):
        agent.train()  # OK if it doesn't trip the action assertion error
    assert agent.train()["episode_reward_mean"] == 21.0
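The examples above all reference AvailActionsTestEnv, but only its step() method appears in Example #3. A minimal sketch of what the rest of such an env could look like, consistent with the step() fragment and the "avail_action" config key used above (the space definitions, num_actions value, and reset() return are assumptions, not copied from the original file):

import numpy as np
from gym.spaces import Box, Dict, Discrete
from ray.rllib.env.multi_agent_env import MultiAgentEnv


class AvailActionsTestEnv(MultiAgentEnv):
    num_actions = 10
    action_space = Discrete(num_actions)
    # Each per-agent observation carries the action mask alongside the obs.
    observation_space = Dict({
        "obs": Discrete(3),
        "action_mask": Box(0, 1, shape=(num_actions, )),
    })

    def __init__(self, env_config):
        self.state = None
        # Only this single action index is valid (e.g. "avail_action": 3 above).
        self.avail = env_config["avail_action"]
        self.action_mask = np.array([0.0] * self.num_actions)
        self.action_mask[self.avail] = 1.0

    def reset(self):
        self.state = 0
        return {"agent_1": {"obs": 0, "action_mask": self.action_mask}}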
Example #4
def main(args):
    logging.getLogger().setLevel(logging.INFO)
    date = datetime.now().strftime('%Y%m%d_%H%M%S')

    config_env = env_config(args)

    with open('/home/skylark/PycharmRemote/Gamma-Reward-Perfect/config/global_config.json') as f:
        global_config = json.load(f)
    roadnet = global_config['cityflow_config_file']
    roadnet_list = roadnet.split('_')
    num_row = int(roadnet_list[1])
    num_col = int(roadnet_list[2].split('.')[0])
    print('\033[1;35mThe scale of the current roadnet is {} x {}\033[0m'.format(num_row, num_col))

    num_agents = num_row * num_col

    if args.mod == 'DQN':
        obs_space = CityFlowEnvRay.observation_space
        act_space = CityFlowEnvRay.action_space
        ModelCatalog.register_custom_model("MyKerasQModel", MyKerasQModel)
    elif args.mod == 'QMIX':
        grouping = {
            "group_1": [id_ for id_ in CityFlowEnvRay.intersection_id]
        }
        obs_space = Tuple([
            CityFlowEnvRay.observation_space for _ in range(num_agents)
        ])
        act_space = Tuple([
            CityFlowEnvRay.action_space for _ in range(num_agents)
        ])
        register_env(
            "cityflow_multi",
            lambda config_: CityFlowEnvRay(config_).with_agent_groups(
                grouping, obs_space=obs_space, act_space=act_space))
    config_agent = agent_config(args, num_agents, obs_space, act_space, config_env, num_row)

    ray.init(local_mode=False, redis_max_memory=1024 * 1024 * 40, temp_dir='/home/skylark/log/')
    if args.tune:  # False
        tune.run(
            "DQN",
            stop={
                "timesteps_total": 400000,
            },
            config=config_agent,
            checkpoint_freq=2,
            checkpoint_at_end=True,
        )
    else:
        if args.mod == 'DQN':  # True
            trainer = DQNTrainer(config=config_agent,
                                 env=CityFlowEnvRay)

            for i in range(args.n_epoch):
                # Perform one iteration of training the policy with DQN
                result = trainer.train()
                print(pretty_print(result))
        elif args.mod == 'QMIX':  # False
            trainer = QMixTrainer(config=config_agent,
                                  env="cityflow_multi")

            for i in range(args.n_epoch):
                # Perform one iteration of training the policy with QMIX
                result = trainer.train()
                print(pretty_print(result))
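For reference, num_row and num_col in Example #4 are read straight out of the cityflow_config_file name. A tiny illustration with a hypothetical filename (the filename itself is an assumption; only the split logic mirrors the code above):

# Hypothetical filename; only the parsing mirrors main() above.
roadnet = "roadnet_3_4.json"
roadnet_list = roadnet.split('_')              # ['roadnet', '3', '4.json']
num_row = int(roadnet_list[1])                 # 3
num_col = int(roadnet_list[2].split('.')[0])   # 4
print(num_row * num_col)                       # 12 intersections -> 12 agents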