Example #1
    def test_train_multi_agent_cartpole_multi_policy(self):
        n = 10
        register_env("multi_agent_cartpole",
                     lambda _: MultiAgentCartPole({"num_agents": n}))
        single_env = gym.make("CartPole-v0")

        # Each policy gets its own randomly sampled gamma and n-step horizon.
        def gen_policy():
            config = {
                "gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
                "n_step": random.choice([1, 2, 3, 4, 5]),
            }
            obs_space = single_env.observation_space
            act_space = single_env.action_space
            return (None, obs_space, act_space, config)

        # Two independent policies; the mapping fn routes every agent to policy_1.
        pg = PGTrainer(env="multi_agent_cartpole",
                       config={
                           "num_workers": 0,
                           "multiagent": {
                               "policies": {
                                   "policy_1": gen_policy(),
                                   "policy_2": gen_policy(),
                               },
                               "policy_mapping_fn":
                               lambda agent_id: "policy_1",
                           },
                           "framework": "tf",
                       })

        # Just check that it runs without crashing
        for i in range(10):
            result = pg.train()
            print("Iteration {}, reward {}, timesteps {}".format(
                i, result["episode_reward_mean"], result["timesteps_total"]))
        self.assertTrue(
            pg.compute_action([0, 0, 0, 0], policy_id="policy_1") in [0, 1])
        self.assertTrue(
            pg.compute_action([0, 0, 0, 0], policy_id="policy_2") in [0, 1])
        # Requesting an unregistered policy ID should raise a KeyError.
        self.assertRaises(
            KeyError,
            lambda: pg.compute_action([0, 0, 0, 0], policy_id="policy_3"))
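
The mapping function above pins every agent to policy_1, so policy_2 is created but receives no training samples. A minimal sketch of a mapping that exercises both policies, assuming MultiAgentCartPole's integer agent IDs (the even/odd split is an illustrative assumption, not part of the original test):

# Sketch only: assumes agent IDs are the integers 0..n-1, as
# MultiAgentCartPole assigns them; the even/odd split is illustrative.
def split_mapping_fn(agent_id):
    # Even-numbered agents act under policy_1, odd-numbered under policy_2.
    return "policy_1" if agent_id % 2 == 0 else "policy_2"
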
Example #2
import logging
import os
import pickle

from ray.rllib.agents.pg import PGTrainer
from tqdm import trange

LOGGER = logging.getLogger(__name__)


class PGrl(object):
    def __init__(self, env, env_config, config):
        self.config = config
        self.config['env_config'] = env_config
        self.env = env(env_config)
        self.agent = PGTrainer(config=self.config, env=env)

    def fit(self, checkpoint=None):
        if checkpoint is None:
            checkpoint = os.path.join(os.getcwd(), 'data/checkpoint_rl.pkl')
        for idx in trange(5):
            result = self.agent.train()
            LOGGER.warning('result: %s', result)
            # Persist trainer state every 5 iterations (with 5 total, once at the end).
            if (idx + 1) % 5 == 0:
                LOGGER.warning('Save checkpoint at: {}'.format(idx + 1))
                state = self.agent.save_to_object()
                with open(checkpoint, 'wb') as fp:
                    pickle.dump(state, fp, protocol=pickle.HIGHEST_PROTOCOL)
        return result

    def predict(self, checkpoint=None):
        if checkpoint is not None:
            with open(checkpoint, 'rb') as fp:
                state = pickle.load(fp)
            self.agent.restore_from_object(state)
        done = False
        episode_reward = 0
        obs = self.env.reset()
        actions = []
        while not done:
            action = self.agent.compute_action(obs)
            actions.append(action)
            obs, reward, done, info = self.env.step(action)
            episode_reward += reward
        results = {'action': actions, 'reward': episode_reward}
        return results
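
A hedged usage sketch for PGrl: the CartPoleWrapper adapter, the DEFAULT_CONFIG import, and the checkpoint filename are illustrative assumptions, not part of the original example. The only requirement PGrl imposes is an env class that both it and RLlib can construct from an env_config dict.

import gym
import ray
from ray.rllib.agents.pg import DEFAULT_CONFIG

class CartPoleWrapper(gym.Wrapper):
    # Thin adapter (an assumption for this sketch): PGrl calls env(env_config)
    # directly, and RLlib builds the same class from config['env_config'].
    def __init__(self, env_config):
        super(CartPoleWrapper, self).__init__(gym.make('CartPole-v0'))

ray.init()
config = DEFAULT_CONFIG.copy()
config['num_workers'] = 0

model = PGrl(env=CartPoleWrapper, env_config={}, config=config)
model.fit(checkpoint='checkpoint_rl.pkl')       # train 5 iterations, then checkpoint
results = model.predict(checkpoint='checkpoint_rl.pkl')
print(results['reward'])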