from copy import deepcopy

from ray.rllib.agents.ppo import PPOTrainer
from ray.tune import Trainable
from ray.tune.logger import pretty_print


class AlternateTraining(Trainable):
    def _setup(self, config):
        self.config = config
        self.env = config['env']
        # adv_config must be copied *before* agent_config (an alias of
        # self.config) is mutated below, or both trainers would train 'agent'.
        agent_config = self.config
        adv_config = deepcopy(self.config)
        agent_config['multiagent']['policies_to_train'] = ['agent']
        adv_config['multiagent']['policies_to_train'] = ['adversary0']
        self.agent_trainer = PPOTrainer(env=self.env, config=agent_config)
        self.adv_trainer = PPOTrainer(env=self.env, config=adv_config)

    def _train(self):
        # improve the Adversary policy
        print("-- Adversary Training --")
        print(pretty_print(self.adv_trainer.train()))

        # swap weights to synchronize
        self.agent_trainer.set_weights(self.adv_trainer.get_weights(["adversary0"]))

        # improve the Agent policy
        print("-- Agent Training --")
        output = self.agent_trainer.train()
        print(pretty_print(output))

        # swap weights to synchronize
        self.adv_trainer.set_weights(self.agent_trainer.get_weights(["agent"]))

        return output

    def _save(self, tmp_checkpoint_dir):
        return self.agent_trainer._save(tmp_checkpoint_dir)
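This Trainable drops straight into Tune. A minimal launch sketch, assuming `config` carries the `env` key and a `multiagent` block defining both the `agent` and `adversary0` policies (the stopping criterion and checkpoint frequency below are placeholders, not values from this project):

import ray
from ray import tune

ray.init()
tune.run(
    AlternateTraining,
    config=config,                     # must contain 'env' and 'multiagent'
    stop={'training_iteration': 100},  # placeholder stopping criterion
    checkpoint_freq=10,                # exercises _save periodically
)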
import random
from collections import Counter

import numpy as np
from scipy.special import softmax

from ray.rllib.agents.ppo import PPOTrainer

# `args` (CLI arguments) and `tau` (softmax temperature) are defined in the
# surrounding script.
def my_train_fn(config, reporter):
    assert args.num_learners >= 4, 'Requires 4 or more trainable agents'
    ppo_trainer = PPOTrainer(env='c4', config=config)
    while True:
        result = ppo_trainer.train()
        if 'evaluation' in result:
            train_policies = config['multiagent']['policies_to_train']
            # Score each trainable policy by its mean evaluation reward.
            scores = {k: v for k, v in result['evaluation']['policy_reward_mean'].items()
                      if k in train_policies}
            # Sample replacement policies in proportion to the softmaxed scores.
            scores_dist = softmax(np.array(list(scores.values())) / tau)
            new_trainables = random.choices(list(scores.keys()), scores_dist, k=len(scores))
            # new_trainables = train_policies
            # random.shuffle(new_trainables)
            weights = ppo_trainer.get_weights()
            # Overwrite each policy's weights with those of its sampled replacement.
            new_weights = {old_pid: weights[new_pid]
                           for old_pid, new_pid in zip(weights.keys(), new_trainables)}
            # Debugging alternatives: overwrite with zeros, constants, or noise.
            # new_weights = {pid: np.zeros_like(wt) for pid, wt in weights.items() if wt is not None}
            # new_weights = {pid: np.ones_like(wt)*-100 for pid, wt in weights.items() if wt is not None}
            # new_weights = {pid: np.random.rand(*wt.shape) for pid, wt in weights.items() if wt is not None}
            print('\n\n################\nSETTING WEIGHTS\n################\n\n')
            ppo_trainer.set_weights(new_weights)

            # Log which policies were duplicated and the sampling distribution.
            num_metrics = 4
            c = Counter(new_trainables)
            result['custom_metrics'].update(
                {f'most_common{i:02d}': v[1]
                 for i, v in enumerate(c.most_common(num_metrics))})
            result['custom_metrics'].update(
                {f'scores_dist{i:02d}': v
                 for i, v in enumerate(sorted(scores_dist, reverse=True)[:num_metrics])})
            print('scores_dist', scores_dist)
            # result['custom_metrics'].update(
            #     {f'new_agent{i:02d}': int(v[-2:]) for i, v in enumerate(new_trainables)})
        reporter(**result)
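Because `my_train_fn` takes the `(config, reporter)` signature, it can be handed to Tune as a function trainable. A minimal launch sketch (resource requests are omitted; pass `resources_per_trial` to match your cluster):

import ray
from ray import tune

ray.init()
tune.run(
    my_train_fn,
    config=config,  # the full PPO config, including the 'multiagent' block
)

Note that `tau` controls how aggressively weights migrate: a low temperature sharpens the softmax so the best-scoring policies dominate the sampling, while a high temperature approaches a uniform shuffle.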
class AlternateTraining(Trainable):
    def _setup(self, config):
        self.config = config
        self.env = config['env']
        agent_config = self.config
        adv_config = deepcopy(self.config)
        agent_config['multiagent']['policies_to_train'] = ['agent']
        adv_config['multiagent']['policies_to_train'] = ['adversary']
        self.agent_trainer = PPOTrainer(env=self.env, config=agent_config)
        self.adv_trainer = PPOTrainer(env=self.env, config=adv_config)

    def _train(self):
        # improve the Adversary policy
        print("-- Adversary Training --")
        original_weight = self.adv_trainer.get_weights(
            ["adversary"])['adversary']['adversary/fc_1/kernel'][0, 0]
        print(pretty_print(self.adv_trainer.train()))
        first_weight = self.adv_trainer.get_weights(
            ["adversary"])['adversary']['adversary/fc_1/kernel'][0, 0]

        # Check that the adversary's weights update after it trains
        assert original_weight != first_weight, \
            'The adversary weights did not change after training'

        # swap weights to synchronize
        self.agent_trainer.set_weights(
            self.adv_trainer.get_weights(["adversary"]))

        # improve the Agent policy
        print("-- Agent Training --")
        output = self.agent_trainer.train()

        new_weight = self.agent_trainer.get_weights(
            ["adversary"])['adversary']['adversary/fc_1/kernel'][0, 0]

        # Check that the adversary is frozen while the agent trainer runs
        assert first_weight == new_weight, \
            'The adversary weights changed, but the agent trainer should not update the adversary!'

        # swap weights to synchronize
        self.adv_trainer.set_weights(self.agent_trainer.get_weights(["agent"]))

        return output

    def _save(self, tmp_checkpoint_dir):
        return self.agent_trainer._save(tmp_checkpoint_dir)
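The nested indexing above mirrors the return shape of `get_weights`: a dict keyed by policy ID whose values map variable names to arrays. A hypothetical helper (the name and default variable are mine, not RLlib's) that captures the fingerprinting trick:

def first_kernel_entry(trainer, policy_id, var_name='adversary/fc_1/kernel'):
    # get_weights([pid]) returns {pid: {variable_name: ndarray, ...}}
    return trainer.get_weights([policy_id])[policy_id][var_name][0, 0]

With it, each check reduces to comparing `first_kernel_entry(...)` before and after a `train()` call.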
from ray.rllib.agents.dqn import DQNTrainer

# `policies`, `policy_mapping_fn`, and `ppo_trainer` are defined earlier in
# the example; the first two are sketched below.
dqn_trainer = DQNTrainer(env="multi_cartpole", config={
    "multiagent": {
        "policies": policies,
        "policy_mapping_fn": policy_mapping_fn,
        "policies_to_train": ["dqn_policy"],
    },
    "gamma": 0.95,
    "n_step": 3,
})

# You should see both the printed X and Y approach 200 as this trains:
# info:
#   policy_reward_mean:
#     dqn_policy: X
#     ppo_policy: Y
for i in range(args.num_iters):
    print("== Iteration", i, "==")

    # improve the DQN policy
    print("-- DQN --")
    print(pretty_print(dqn_trainer.train()))

    # improve the PPO policy
    print("-- PPO --")
    print(pretty_print(ppo_trainer.train()))

    # swap weights to synchronize
    dqn_trainer.set_weights(ppo_trainer.get_weights(["ppo_policy"]))
    ppo_trainer.set_weights(dqn_trainer.get_weights(["dqn_policy"]))
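The config above references the `policies` dict and `policy_mapping_fn` from RLlib's two-trainer example. A minimal sketch of that setup, following the same pattern (the policy-class import paths shift between Ray versions, so treat them as indicative):

import gym

from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy

# Borrow the per-agent spaces from a single CartPole instance.
single_env = gym.make("CartPole-v0")
obs_space = single_env.observation_space
act_space = single_env.action_space

# Each entry: (policy class, observation space, action space, extra config).
policies = {
    "ppo_policy": (PPOTFPolicy, obs_space, act_space, {}),
    "dqn_policy": (DQNTFPolicy, obs_space, act_space, {}),
}

# Even-numbered agents act under PPO, odd-numbered agents under DQN.
def policy_mapping_fn(agent_id):
    return "ppo_policy" if agent_id % 2 == 0 else "dqn_policy"

Both trainers are built with this same `policies` dict and only differ in `policies_to_train`, which is why the weight swap at the end of each iteration keeps the two copies of each policy in sync.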