Example no. 1
# Imports for Ray's legacy (pre-2.0) Trainer/Trainable API used throughout these examples.
from copy import deepcopy

from ray.rllib.agents.ppo import PPOTrainer
from ray.tune import Trainable
from ray.tune.logger import pretty_print


class AlternateTraining(Trainable):
    def _setup(self, config):
        self.config = config
        self.env = config['env']
        # agent_config aliases the base config; adv_config is a deep copy so the
        # two trainers can each train a different policy on the same environment.
        agent_config = self.config
        adv_config = deepcopy(self.config)
        agent_config['multiagent']['policies_to_train'] = ['agent']
        adv_config['multiagent']['policies_to_train'] = ['adversary0']

        self.agent_trainer = PPOTrainer(env=self.env, config=agent_config)
        self.adv_trainer = PPOTrainer(env=self.env, config=adv_config)

    def _train(self):
        # improve the Adversary policy
        print("-- Adversary Training --")
        print(pretty_print(self.adv_trainer.train()))

        # swap weights to synchronize
        self.agent_trainer.set_weights(self.adv_trainer.get_weights(["adversary0"]))

        # improve the Agent policy
        print("-- Agent Training --")
        output = self.agent_trainer.train()
        print(pretty_print(output))

        # swap weights to synchronize
        self.adv_trainer.set_weights(self.agent_trainer.get_weights(["agent"]))
        return output

    def _save(self, tmp_checkpoint_dir):
        return self.agent_trainer._save(tmp_checkpoint_dir)
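
For context, a Trainable subclass like the one above is normally launched through Ray Tune rather than instantiated directly. Below is a minimal sketch, assuming the legacy (pre-2.0) Ray API used in these examples and that `config` is a complete multi-agent PPO config containing the 'agent' and 'adversary0' policies referenced above; the stopping criterion and checkpoint frequency are illustrative only.

import ray
from ray import tune

ray.init()

# Hypothetical launch of the alternating-training Trainable defined above.
# `config` is assumed to hold the env name plus the full 'multiagent' block.
tune.run(
    AlternateTraining,
    config=config,
    stop={"training_iteration": 100},  # illustrative stopping criterion
    checkpoint_freq=10,
)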
Example no. 2

import random
from collections import Counter

import numpy as np
from ray.rllib.agents.ppo import PPOTrainer
# `softmax` is assumed to come from scipy.special; `args` (parsed CLI arguments)
# and `tau` (a softmax temperature) are assumed to be defined at module scope.
from scipy.special import softmax


def my_train_fn(config, reporter):
    assert args.num_learners >= 4, 'Requires 4 or more trainable agents'
    ppo_trainer = PPOTrainer(env='c4', config=config)
    while True:
        result = ppo_trainer.train()
        if 'evaluation' in result:
            train_policies = config['multiagent']['policies_to_train']
            scores = {k: v for k, v in result['evaluation']['policy_reward_mean'].items() if k in train_policies}

            # Resample policy ids in proportion to their softmaxed evaluation
            # scores, then copy the sampled policies' weights over the current ones.
            scores_dist = softmax(np.array(list(scores.values())) / tau)
            new_trainables = random.choices(list(scores.keys()), scores_dist, k=len(scores))
            # new_trainables = train_policies
            # random.shuffle(new_trainables)

            weights = ppo_trainer.get_weights()
            new_weights = {old_pid: weights[new_pid] for old_pid, new_pid in zip(weights.keys(), new_trainables)}
            # new_weights = {pid: np.zeros_like(wt) for pid, wt in weights.items() if wt is not None}
            # new_weights = {pid: np.ones_like(wt)*-100 for pid, wt in weights.items() if wt is not None}
            # new_weights = {pid: np.random.rand(*wt.shape) for pid, wt in weights.items() if wt is not None}

            print('\n\n################\nSETTING WEIGHTS\n################\n\n')
            ppo_trainer.set_weights(new_weights)

            # Log how often each policy id was sampled and the largest sampling probabilities.
            num_metrics = 4
            c = Counter(new_trainables)
            result['custom_metrics'].update(
                {f'most_common{i:02d}': v[1] for i, v in enumerate(c.most_common(num_metrics))})
            result['custom_metrics'].update(
                {f'scores_dist{i:02d}': v for i, v in enumerate(sorted(scores_dist, reverse=True)[:num_metrics])})
            print('scores_dist', scores_dist)
            # result['custom_metrics'].update(
            #     {f'new_agent{i:02d}': int(v[-2:]) for i, v in enumerate(new_trainables)})
        reporter(**result)
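
With the legacy Tune API, a training function with the (config, reporter) signature such as my_train_fn can be passed to tune.run directly. A minimal sketch, again under the assumption that `config` is a multi-agent PPO config with evaluation enabled (so that 'evaluation' appears in the result dict) and that the resource numbers are placeholders:

from ray import tune

# Hypothetical launch of the custom training function defined above.
tune.run(
    my_train_fn,
    config=config,  # multi-agent PPO config with an evaluation interval set
    resources_per_trial={"cpu": 4, "gpu": 0},  # placeholder resources
)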
Example no. 3

# (Same imports as in Example no. 1.)
class AlternateTraining(Trainable):
    def _setup(self, config):
        self.config = config
        self.env = config['env']
        agent_config = self.config
        adv_config = deepcopy(self.config)
        agent_config['multiagent']['policies_to_train'] = ['agent']
        adv_config['multiagent']['policies_to_train'] = ['adversary']

        self.agent_trainer = PPOTrainer(env=self.env, config=agent_config)
        self.adv_trainer = PPOTrainer(env=self.env, config=adv_config)

    def _train(self):
        # improve the Adversary policy
        print("-- Adversary Training --")
        original_weight = self.adv_trainer.get_weights(
            ["adversary"])['adversary']['adversary/fc_1/kernel'][0, 0]
        print(pretty_print(self.adv_trainer.train()))
        first_weight = self.adv_trainer.get_weights(
            ["adversary"])['adversary']['adversary/fc_1/kernel'][0, 0]

        # Check that the weights are updating after training
        assert original_weight != first_weight, 'Adversary weights did not change after adversary training'

        # swap weights to synchronize
        self.agent_trainer.set_weights(
            self.adv_trainer.get_weights(["adversary"]))

        # improve the Agent policy
        print("-- Agent Training --")
        output = self.agent_trainer.train()

        # Read the adversary weight from the agent trainer again after agent training
        new_weight = self.agent_trainer.get_weights(
            ["adversary"])['adversary']['adversary/fc_1/kernel'][0, 0]

        # Check that the adversary is not being trained when the agent trainer is training
        assert first_weight == new_weight, 'The adversary weights changed during agent training, but the adversary should not have been updated!'

        # swap weights to synchronize
        self.adv_trainer.set_weights(self.agent_trainer.get_weights(["agent"]))

        return output

    def _save(self, tmp_checkpoint_dir):
        return self.agent_trainer._save(tmp_checkpoint_dir)
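
The two assertions in _train above spot-check a single kernel entry of the adversary network. As a hedged generalization of the same idea (the helper name and the use of numpy.allclose are choices made here, not part of the original example), one could compare every parameter of a policy between two get_weights() snapshots:

import numpy as np

def policy_changed(weights_before, weights_after, policy_id):
    # True if any parameter of `policy_id` differs between two snapshots
    # returned by trainer.get_weights([policy_id]).
    before = weights_before[policy_id]
    after = weights_after[policy_id]
    return any(not np.allclose(before[name], after[name]) for name in before)

# Sketch of how this would slot into _train:
#   before = self.adv_trainer.get_weights(["adversary"])
#   print(pretty_print(self.adv_trainer.train()))
#   after = self.adv_trainer.get_weights(["adversary"])
#   assert policy_changed(before, after, "adversary")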
Example no. 4
    # `policies`, `policy_mapping_fn` (and, further down, `ppo_trainer`) are
    # defined earlier in the full example; see the sketch after this snippet.
    dqn_trainer = DQNTrainer(env="multi_cartpole",
                             config={
                                 "multiagent": {
                                     "policies": policies,
                                     "policy_mapping_fn": policy_mapping_fn,
                                     "policies_to_train": ["dqn_policy"],
                                 },
                                 "gamma": 0.95,
                                 "n_step": 3,
                             })

    # You should see both the printed X and Y approach 200 as this trains:
    # info:
    #   policy_reward_mean:
    #     dqn_policy: X
    #     ppo_policy: Y
    for i in range(args.num_iters):
        print("== Iteration", i, "==")

        # improve the DQN policy
        print("-- DQN --")
        print(pretty_print(dqn_trainer.train()))

        # improve the PPO policy
        print("-- PPO --")
        print(pretty_print(ppo_trainer.train()))

        # swap weights to synchronize
        dqn_trainer.set_weights(ppo_trainer.get_weights(["ppo_policy"]))
        ppo_trainer.set_weights(dqn_trainer.get_weights(["dqn_policy"]))
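
The fragment above assumes that `policies`, `policy_mapping_fn`, and `ppo_trainer` have already been defined. A sketch of what those definitions might look like, modeled on RLlib's standard two-trainer multi-agent pattern; the CartPole spaces, the even/odd agent split, and the env registration are assumptions here, not taken from the snippet:

import gym
from ray.rllib.agents.dqn import DQNTFPolicy
from ray.rllib.agents.ppo import PPOTrainer, PPOTFPolicy

# Spaces of the underlying single-agent env (assumed to be CartPole).
single_env = gym.make("CartPole-v0")
obs_space = single_env.observation_space
act_space = single_env.action_space

# One policy per algorithm; each trainer only optimizes its own policy and
# receives the other one's weights via the set_weights/get_weights swap above.
policies = {
    "ppo_policy": (PPOTFPolicy, obs_space, act_space, {}),
    "dqn_policy": (DQNTFPolicy, obs_space, act_space, {}),
}

def policy_mapping_fn(agent_id):
    # Assumed split: even agent ids act with PPO, odd ones with DQN.
    return "ppo_policy" if agent_id % 2 == 0 else "dqn_policy"

# "multi_cartpole" is assumed to be registered elsewhere via tune.register_env.
ppo_trainer = PPOTrainer(env="multi_cartpole",
                         config={
                             "multiagent": {
                                 "policies": policies,
                                 "policy_mapping_fn": policy_mapping_fn,
                                 "policies_to_train": ["ppo_policy"],
                             },
                         })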