Пример #1
0
class AlternateTraining(Trainable):
    """Alternately train an 'agent' policy and an 'adversary0' policy with two PPO trainers.

    Two ``PPOTrainer`` instances are built from the same config. Each one is
    restricted (via ``multiagent.policies_to_train``) to updating a single
    policy. After each trainer's step, the policy it just trained is copied
    into the other trainer, so both hold the latest weights of both policies.
    """

    def _setup(self, config):
        """Build the agent and adversary trainers.

        Args:
            config: Trainer config dict. It is read for ``'env'`` and must
                contain a ``'multiagent'`` sub-dict.
        """
        self.config = config
        self.env = config['env']
        # Copy the config for *both* trainers. Previously the agent config was
        # an alias of ``config``, so the policies_to_train override below
        # silently mutated the caller's dict (and self.config).
        agent_config = deepcopy(config)
        adv_config = deepcopy(config)
        agent_config['multiagent']['policies_to_train'] = ['agent']
        adv_config['multiagent']['policies_to_train'] = ['adversary0']

        self.agent_trainer = PPOTrainer(env=self.env, config=agent_config)
        self.adv_trainer = PPOTrainer(env=self.env, config=adv_config)

    def _train(self):
        """Run one adversary training step, then one agent training step.

        Returns:
            The result dict from the agent trainer's ``train()`` call.
        """
        # improve the Adversary policy
        print("-- Adversary Training --")
        print(pretty_print(self.adv_trainer.train()))

        # swap weights to synchronize: the agent trainer gets the new adversary
        self.agent_trainer.set_weights(self.adv_trainer.get_weights(["adversary0"]))

        # improve the Agent policy
        print("-- Agent Training --")
        output = self.agent_trainer.train()
        print(pretty_print(output))

        # swap weights to synchronize: the adversary trainer gets the new agent
        self.adv_trainer.set_weights(self.agent_trainer.get_weights(["agent"]))
        return output

    def _save(self, tmp_checkpoint_dir):
        """Checkpoint by delegating to the agent trainer's ``_save``.

        After ``_train`` the agent trainer holds the latest weights of both
        policies, but the adversary trainer's own state is not saved.
        NOTE(review): no ``_restore`` is defined in this class. Confirm that
        restoring is handled elsewhere, if it is needed.
        """
        return self.agent_trainer._save(tmp_checkpoint_dir)
class AlternateTraining(Trainable):
    """Alternately train 'agent' and 'adversary' policies, with weight-update sanity checks.

    One ``PPOTrainer`` only trains ``'agent'`` and the other only trains
    ``'adversary'``. After each step, the freshly trained policy's weights
    are copied into the other trainer. ``_train`` also asserts, using a single
    probe element of the adversary's ``fc_1`` kernel, two things: training the
    adversary changes its weights, and training the agent does not.
    """

    def _setup(self, config):
        """Build the agent and adversary trainers.

        Args:
            config: Trainer config dict. It is read for ``'env'`` and must
                contain a ``'multiagent'`` sub-dict.
        """
        self.config = config
        self.env = config['env']
        # Copy the config for *both* trainers. Previously the agent config was
        # an alias of ``config``, so the policies_to_train override below
        # silently mutated the caller's dict (and self.config).
        agent_config = deepcopy(config)
        adv_config = deepcopy(config)
        agent_config['multiagent']['policies_to_train'] = ['agent']
        adv_config['multiagent']['policies_to_train'] = ['adversary']

        self.agent_trainer = PPOTrainer(env=self.env, config=agent_config)
        self.adv_trainer = PPOTrainer(env=self.env, config=adv_config)

    @staticmethod
    def _adversary_kernel_probe(trainer):
        """Return element [0, 0] of the 'adversary/fc_1/kernel' weights held by ``trainer``."""
        weights = trainer.get_weights(["adversary"])
        return weights['adversary']['adversary/fc_1/kernel'][0, 0]

    def _train(self):
        """Run one adversary step, then one agent step, checking weight updates.

        Returns:
            The result dict from the agent trainer's ``train()`` call.

        Raises:
            AssertionError: if the adversary probe weight did not change during
                adversary training, or changed during agent training. These
                are ``assert`` statements, so they are skipped under ``python -O``.
        """
        # improve the Adversary policy
        print("-- Adversary Training --")
        original_weight = self._adversary_kernel_probe(self.adv_trainer)
        print(pretty_print(self.adv_trainer.train()))
        first_weight = self._adversary_kernel_probe(self.adv_trainer)

        # Check that the weights are updating after training.
        # NOTE: only one kernel element is probed, so this is a cheap sanity
        # check rather than proof that every parameter moved.
        assert original_weight != first_weight, 'The weight hasn\'t changed after training what gives'

        # swap weights to synchronize: the agent trainer gets the new adversary
        self.agent_trainer.set_weights(
            self.adv_trainer.get_weights(["adversary"]))

        # improve the Agent policy
        print("-- Agent Training --")
        output = self.agent_trainer.train()

        # The agent trainer only trains 'agent', so the adversary weights it
        # holds must still be exactly the ones copied in above. Exact float
        # equality is intended: no update should touch them at all.
        new_weight = self._adversary_kernel_probe(self.agent_trainer)
        assert first_weight == new_weight, 'The weight of the adversary matrix has changed but it shouldnt have been updated!'

        # swap weights to synchronize: the adversary trainer gets the new agent
        self.adv_trainer.set_weights(self.agent_trainer.get_weights(["agent"]))

        return output

    def _save(self, tmp_checkpoint_dir):
        """Checkpoint by delegating to the agent trainer's ``_save``.

        After ``_train`` the agent trainer holds the latest weights of both
        policies, but the adversary trainer's own state is not saved.
        NOTE(review): no ``_restore`` is defined in this class. Confirm that
        restoring is handled elsewhere, if it is needed.
        """
        return self.agent_trainer._save(tmp_checkpoint_dir)