class AlternateTraining(Trainable):
    """Tune ``Trainable`` that alternates PPO optimization between an 'agent'
    policy and an 'adversary0' policy.

    Two ``PPOTrainer`` instances share the same env/config but each trains a
    different policy; after each trainer's update the freshly trained weights
    are copied into the other trainer so both stay synchronized.
    """

    def _setup(self, config):
        self.config = config
        self.env = config['env']
        # BUGFIX: the agent config must be deep-copied as well. Previously
        # `agent_config = self.config` aliased the shared dict, so setting
        # 'policies_to_train' below mutated self.config in place (and any
        # other consumer of that dict). Only adv_config was copied.
        agent_config = deepcopy(self.config)
        adv_config = deepcopy(self.config)
        agent_config['multiagent']['policies_to_train'] = ['agent']
        adv_config['multiagent']['policies_to_train'] = ['adversary0']
        self.agent_trainer = PPOTrainer(env=self.env, config=agent_config)
        self.adv_trainer = PPOTrainer(env=self.env, config=adv_config)

    def _train(self):
        """Run one adversary update then one agent update.

        Returns the agent trainer's result dict (so Tune reports the agent's
        metrics for this iteration).
        """
        # Improve the Adversary policy first.
        print("-- Adversary Training --")
        print(pretty_print(self.adv_trainer.train()))

        # Copy the updated adversary weights into the agent trainer so its
        # (frozen) adversary policy matches the one just trained.
        self.agent_trainer.set_weights(
            self.adv_trainer.get_weights(["adversary0"]))

        # Improve the Agent policy.
        print("-- Agent Training --")
        output = self.agent_trainer.train()
        print(pretty_print(output))

        # Copy the updated agent weights back into the adversary trainer.
        self.adv_trainer.set_weights(self.agent_trainer.get_weights(["agent"]))

        return output

    def _save(self, tmp_checkpoint_dir):
        # Checkpoint via the agent trainer; it holds the synchronized weights
        # for both policies after _train().
        return self.agent_trainer._save(tmp_checkpoint_dir)
class AlternateTraining(Trainable):
    """Tune ``Trainable`` that alternates PPO optimization between an 'agent'
    policy and an 'adversary' policy, with runtime sanity checks that the
    correct policy (and only that policy) is being updated.

    Two ``PPOTrainer`` instances share the same env/config but each trains a
    different policy; after each trainer's update the freshly trained weights
    are copied into the other trainer so both stay synchronized.
    """

    def _setup(self, config):
        self.config = config
        self.env = config['env']
        # BUGFIX: the agent config must be deep-copied as well. Previously
        # `agent_config = self.config` aliased the shared dict, so setting
        # 'policies_to_train' below mutated self.config in place (and any
        # other consumer of that dict). Only adv_config was copied.
        agent_config = deepcopy(self.config)
        adv_config = deepcopy(self.config)
        agent_config['multiagent']['policies_to_train'] = ['agent']
        adv_config['multiagent']['policies_to_train'] = ['adversary']
        self.agent_trainer = PPOTrainer(env=self.env, config=agent_config)
        self.adv_trainer = PPOTrainer(env=self.env, config=adv_config)

    def _train(self):
        """Run one adversary update then one agent update.

        A single weight entry ('adversary/fc_1/kernel'[0, 0]) is probed
        before/after each phase to assert that the adversary trains during
        its phase and stays frozen during the agent's phase.

        Returns the agent trainer's result dict (so Tune reports the agent's
        metrics for this iteration).
        """
        # Improve the Adversary policy first.
        print("-- Adversary Training --")
        original_weight = self.adv_trainer.get_weights(
            ["adversary"])['adversary']['adversary/fc_1/kernel'][0, 0]
        print(pretty_print(self.adv_trainer.train()))
        first_weight = self.adv_trainer.get_weights(
            ["adversary"])['adversary']['adversary/fc_1/kernel'][0, 0]

        # Sanity check: the adversary's weights must have moved during its
        # own training phase.
        assert original_weight != first_weight, \
            'The weight hasn\'t changed after training what gives'

        # Copy the updated adversary weights into the agent trainer so its
        # (frozen) adversary policy matches the one just trained.
        self.agent_trainer.set_weights(
            self.adv_trainer.get_weights(["adversary"]))

        # Improve the Agent policy.
        print("-- Agent Training --")
        output = self.agent_trainer.train()

        # Sanity check: the adversary policy must NOT change while the agent
        # trainer runs (it is excluded from 'policies_to_train' there).
        new_weight = self.agent_trainer.get_weights(
            ["adversary"])['adversary']['adversary/fc_1/kernel'][0, 0]
        assert first_weight == new_weight, \
            'The weight of the adversary matrix has changed but it shouldnt have been updated!'

        # Copy the updated agent weights back into the adversary trainer.
        self.adv_trainer.set_weights(self.agent_trainer.get_weights(["agent"]))

        return output

    def _save(self, tmp_checkpoint_dir):
        # Checkpoint via the agent trainer; it holds the synchronized weights
        # for both policies after _train().
        return self.agent_trainer._save(tmp_checkpoint_dir)