def train_self_play(results_dir, scenario_name, print_train_results=True):
    scenario: PSROScenario = scenario_catalog.get(scenario_name=scenario_name)

    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    single_agent_symmetric_game = scenario.single_agent_symmetric_game
    if single_agent_symmetric_game:
        raise NotImplementedError
    get_trainer_config = scenario.get_trainer_config
    should_log_result_fn = scenario.ray_should_log_result_filter
    checkpoint_every_n_iters = 500

    class PreAndPostEpisodeCallbacks(DefaultCallbacks):

        def on_train_result(self, *, trainer, result: dict, **kwargs):
            result["scenario_name"] = trainer.scenario_name
            training_iteration = result["training_iteration"]
            super().on_train_result(trainer=trainer, result=result, **kwargs)

            if training_iteration % checkpoint_every_n_iters == 0 or training_iteration == 1:
                for player in range(2):
                    checkpoint_metadata = create_metadata_with_new_checkpoint(
                        policy_id_to_save=f"best_response_{player}",
                        br_trainer=trainer,
                        policy_player=player,
                        save_dir=checkpoint_dir(trainer=trainer),
                        timesteps_training=result["timesteps_total"],
                        episodes_training=result["episodes_total"],
                        checkpoint_name=f"best_response_player_{player}_iter_{training_iteration}.h5")
                    joint_pol_checkpoint_spec = StrategySpec(
                        strategy_id=f"best_response_player_{player}_iter_{training_iteration}",
                        metadata=checkpoint_metadata)
                    checkpoint_path = os.path.join(
                        spec_checkpoint_dir(trainer),
                        f"best_response_player_{player}_iter_{training_iteration}.json")
                    ensure_dir(checkpoint_path)
                    with open(checkpoint_path, "w") as checkpoint_spec_file:
                        checkpoint_spec_file.write(joint_pol_checkpoint_spec.to_json())

    def select_policy(agent_id):
        if agent_id == 0:
            return "best_response_0"
        elif agent_id == 1:
            return "best_response_1"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

    init_ray_for_scenario(scenario=scenario, head_address=None, logging_level=logging.INFO)

    tmp_env = env_class(env_config=env_config)

    trainer_config = {
        "callbacks": PreAndPostEpisodeCallbacks,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["best_response_0", "best_response_1"],
            "policies": {
                "best_response_0": (policy_classes["best_response"], tmp_env.observation_space,
                                    tmp_env.action_space, {}),
                "best_response_1": (policy_classes["best_response"], tmp_env.observation_space,
                                    tmp_env.action_space, {}),
            },
            "policy_mapping_fn": select_policy,
        },
    }
    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_env))

    trainer = trainer_class(config=trainer_config,
                            logger_creator=get_trainer_logger_creator(
                                base_dir=results_dir,
                                scenario_name=scenario_name,
                                should_log_result_fn=should_log_result_fn))

    # scenario_name logged in on_train_result_callback
    trainer.scenario_name = scenario_name

    # Perform main RL training loop.
    while True:
        train_iter_results = trainer.train()  # do a step (or several) in the main RL loop

        if print_train_results:
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            for key in ["best_response_0", "best_response_1"]:
                if "td_error" in train_iter_results["info"]["learner"][key]:
                    del train_iter_results["info"]["learner"][key]["td_error"]
            print(pretty_dict_str(train_iter_results))

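# Usage sketch for train_self_play (not from the original source): the results
# directory and scenario name below are hypothetical placeholders and must be
# replaced with a real PSROScenario registered in scenario_catalog.
#
#     train_self_play(results_dir="/tmp/self_play_results",
#                     scenario_name="my_self_play_scenario",  # hypothetical
#                     print_train_results=True)
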
def train_poker_approx_best_response_nfsp(br_player,
                                          ray_head_address,
                                          scenario: NFSPScenario,
                                          general_trainer_config_overrrides,
                                          br_policy_config_overrides,
                                          get_stopping_condition,
                                          avg_policy_specs_for_players: Dict[int, StrategySpec],
                                          results_dir: str,
                                          print_train_results: bool = True):
    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    get_trainer_config = scenario.get_trainer_config
    should_log_result_fn = scenario.ray_should_log_result_filter

    init_ray_for_scenario(scenario=scenario, head_address=ray_head_address, logging_level=logging.INFO)

    def log(message, level=logging.INFO):
        logger.log(level, message)

    def select_policy(agent_id):
        if agent_id == br_player:
            return "best_response"
        else:
            return "average_policy"

    tmp_env = env_class(env_config=env_config)

    avg_policy_model_config = get_trainer_config(tmp_env)["model"]

    br_trainer_config = {
        "log_level": "DEBUG",
        # "callbacks": None,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        # "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_gpus": 0.0,
        "num_workers": 0,
        "num_gpus_per_worker": 0.0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["best_response"],
            "policies": {
                "average_policy": (policy_classes["average_policy"], tmp_env.observation_space,
                                   tmp_env.action_space, {
                                       "model": avg_policy_model_config,
                                       "explore": False,
                                   }),
                "best_response": (policy_classes["best_response"], tmp_env.observation_space,
                                  tmp_env.action_space, br_policy_config_overrides),
            },
            "policy_mapping_fn": select_policy,
        },
    }
    br_trainer_config = merge_dicts(br_trainer_config, get_trainer_config(tmp_env))
    br_trainer_config = merge_dicts(br_trainer_config, general_trainer_config_overrrides)

    br_trainer = trainer_class(config=br_trainer_config,
                               logger_creator=get_trainer_logger_creator(
                                   base_dir=results_dir,
                                   scenario_name="approx_br",
                                   should_log_result_fn=should_log_result_fn))

    def _set_avg_policy(worker: RolloutWorker):
        avg_policy = worker.policy_map["average_policy"]
        load_pure_strat(policy=avg_policy,
                        pure_strat_spec=avg_policy_specs_for_players[1 - br_player])

    br_trainer.workers.foreach_worker(_set_avg_policy)

    br_trainer.latest_avg_trainer_result = None
    train_iter_count = 0

    stopping_condition: StoppingCondition = get_stopping_condition()

    max_reward = None
    while True:
        train_iter_results = br_trainer.train()  # do a step (or several) in the main RL loop
        br_reward_this_iter = train_iter_results["policy_reward_mean"]["best_response"]
        if max_reward is None or br_reward_this_iter > max_reward:
            max_reward = br_reward_this_iter
        train_iter_count += 1

        if print_train_results:
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"]["best_response"]:
                del train_iter_results["info"]["learner"]["best_response"]["td_error"]
            print(pretty_dict_str(train_iter_results))
            log(f"Trainer logdir is {br_trainer.logdir}")

        if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
            print("stopping condition met.")
            break

    return max_reward

def train_off_policy_rl_nfsp(results_dir: str, scenario_name: str, print_train_results: bool = True):
    scenario: NFSPScenario = scenario_catalog.get(scenario_name=scenario_name)
    if not isinstance(scenario, NFSPScenario):
        raise TypeError(f"Only instances of {NFSPScenario} can be used here. "
                        f"{scenario.name} is a {type(scenario)}.")

    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    avg_trainer_class = scenario.avg_trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    anticipatory_param: float = scenario.anticipatory_param
    get_trainer_config = scenario.get_trainer_config
    get_avg_trainer_config = scenario.get_avg_trainer_config
    calculate_openspiel_metanash: bool = scenario.calculate_openspiel_metanash
    calc_metanash_every_n_iters: int = scenario.calc_metanash_every_n_iters
    checkpoint_every_n_iters: Union[int, None] = scenario.checkpoint_every_n_iters
    nfsp_get_stopping_condition = scenario.nfsp_get_stopping_condition
    should_log_result_fn = scenario.ray_should_log_result_filter

    init_ray_for_scenario(scenario=scenario, head_address=None, logging_level=logging.INFO)

    def log(message, level=logging.INFO):
        logger.log(level, message)

    def select_policy(agent_id):
        random_sample = np.random.random()
        if agent_id == 0:
            if random_sample < anticipatory_param:
                return "best_response_0"
            return "average_policy_0"
        elif agent_id == 1:
            if random_sample < anticipatory_param:
                return "best_response_1"
            return "average_policy_1"
        else:
            raise ValueError(f"unexpected agent_id: {agent_id}")

    def assert_not_called(agent_id):
        assert False, "This function should never be called."

    tmp_env = env_class(env_config=env_config)

    def _create_env():
        return env_class(env_config=env_config)

    avg_policy_model_config = get_trainer_config(tmp_env)["model"]

    avg_trainer_config = merge_dicts({
        "log_level": "DEBUG",
        "framework": "torch",
        "env": env_class,
        "env_config": env_config,
        "num_gpus": 0.0,
        "num_gpus_per_worker": 0.0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["average_policy_0", "average_policy_1"],
            "policies": {
                "average_policy_0": (policy_classes["average_policy"], tmp_env.observation_space,
                                     tmp_env.action_space, {"model": avg_policy_model_config}),
                "average_policy_1": (policy_classes["average_policy"], tmp_env.observation_space,
                                     tmp_env.action_space, {"model": avg_policy_model_config}),
            },
            "policy_mapping_fn": assert_not_called,
        },
    }, get_avg_trainer_config(tmp_env))

    avg_trainer = avg_trainer_class(config=avg_trainer_config,
                                    logger_creator=get_trainer_logger_creator(
                                        base_dir=results_dir,
                                        scenario_name=f"{scenario_name}_avg_trainer",
                                        should_log_result_fn=should_log_result_fn))

    store_to_avg_policy_buffer = get_store_to_avg_policy_buffer_fn(nfsp_trainer=avg_trainer)

    class NFSPBestResponseCallbacks(DefaultCallbacks):

        def on_postprocess_trajectory(self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
                                      agent_id: AgentID, policy_id: PolicyID,
                                      policies: Dict[PolicyID, Policy],
                                      postprocessed_batch: SampleBatch,
                                      original_batches: Dict[Any, Tuple[Policy, SampleBatch]],
                                      **kwargs):
            super().on_postprocess_trajectory(worker=worker, episode=episode, agent_id=agent_id,
                                              policy_id=policy_id, policies=policies,
                                              postprocessed_batch=postprocessed_batch,
                                              original_batches=original_batches, **kwargs)

            postprocessed_batch.data["source_policy"] = [policy_id] * len(postprocessed_batch.data["rewards"])

            # All data from both policies will go into the best response's replay buffer.
            # Here we ensure policies not from the best response have the exact same
            # preprocessing as the best response.
            for average_policy_id, br_policy_id in [("average_policy_0", "best_response_0"),
                                                    ("average_policy_1", "best_response_1")]:
                if policy_id == average_policy_id:
                    if "action_probs" in postprocessed_batch:
                        del postprocessed_batch.data["action_probs"]
                    if "behaviour_logits" in postprocessed_batch:
                        del postprocessed_batch.data["behaviour_logits"]

                    br_policy: Policy = policies[br_policy_id]

                    new_batch = br_policy.postprocess_trajectory(
                        sample_batch=postprocessed_batch,
                        other_agent_batches=original_batches,
                        episode=episode)
                    copy_attributes(src_obj=new_batch, dst_obj=postprocessed_batch)
                elif policy_id == br_policy_id:
                    if "q_values" in postprocessed_batch:
                        del postprocessed_batch.data["q_values"]
                    if "action_probs" in postprocessed_batch:
                        del postprocessed_batch.data["action_probs"]
                    del postprocessed_batch.data["action_dist_inputs"]

                if policy_id in ("average_policy_0", "best_response_0"):
                    assert agent_id == 0
                if policy_id in ("average_policy_1", "best_response_1"):
                    assert agent_id == 1

        def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                           policies: Dict[PolicyID, Policy], episode: MultiAgentEpisode,
                           env_index: int, **kwargs):
            super().on_episode_end(worker=worker, base_env=base_env, policies=policies,
                                   episode=episode, env_index=env_index, **kwargs)

            episode_policies = set(episode.agent_rewards.keys())
            if episode_policies == {(0, "average_policy_0"), (1, "best_response_1")}:
                worker.avg_br_reward_deque.add.remote(episode.agent_rewards[(1, "best_response_1")])
            elif episode_policies == {(1, "average_policy_1"), (0, "best_response_0")}:
                worker.avg_br_reward_deque.add.remote(episode.agent_rewards[(0, "best_response_0")])

        def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs):
            super().on_sample_end(worker=worker, samples=samples, **kwargs)
            assert isinstance(samples, MultiAgentBatch)

            for policy_samples in samples.policy_batches.values():
                if "action_prob" in policy_samples.data:
                    del policy_samples.data["action_prob"]
                if "action_logp" in policy_samples.data:
                    del policy_samples.data["action_logp"]

            for average_policy_id, br_policy_id in [("average_policy_0", "best_response_0"),
                                                    ("average_policy_1", "best_response_1")]:
                for policy_id, policy_samples in samples.policy_batches.items():
                    if policy_id == br_policy_id:
                        store_to_avg_policy_buffer(MultiAgentBatch(
                            policy_batches={average_policy_id: policy_samples},
                            env_steps=policy_samples.count))
                if average_policy_id in samples.policy_batches:
                    if br_policy_id in samples.policy_batches:
                        all_policies_samples = samples.policy_batches[br_policy_id].concat(
                            other=samples.policy_batches[average_policy_id])
                    else:
                        all_policies_samples = samples.policy_batches[average_policy_id]
                    del samples.policy_batches[average_policy_id]
                    samples.policy_batches[br_policy_id] = all_policies_samples

        def on_train_result(self, *, trainer, result: dict, **kwargs):
            super().on_train_result(trainer=trainer, result=result, **kwargs)
            result["scenario_name"] = trainer.scenario_name
            result["avg_br_reward_both_players"] = ray.get(trainer.avg_br_reward_deque.get_mean.remote())
            training_iteration = result["training_iteration"]

            if (calculate_openspiel_metanash and
                    (training_iteration == 1 or training_iteration % calc_metanash_every_n_iters == 0)):
                base_env = _create_env()
                open_spiel_env_config = base_env.open_spiel_env_config
                openspiel_game_version = base_env.game_version
                local_avg_policy_0 = trainer.workers.local_worker().policy_map["average_policy_0"]
                local_avg_policy_1 = trainer.workers.local_worker().policy_map["average_policy_1"]
                exploitability = nfsp_measure_exploitability_nonlstm(
                    rllib_policies=[local_avg_policy_0, local_avg_policy_1],
                    poker_game_version=openspiel_game_version,
                    open_spiel_env_config=open_spiel_env_config)
                result["avg_policy_exploitability"] = exploitability
                logger.info(colored(
                    f"(Graph this in a notebook) Exploitability: {exploitability} - "
                    f"Saving exploitability stats to {os.path.join(trainer.logdir, 'result.json')}", "green"))

            if checkpoint_every_n_iters and (training_iteration % checkpoint_every_n_iters == 0
                                             or training_iteration == 1):
                for player in range(2):
                    checkpoint_metadata = create_metadata_with_new_checkpoint(
                        policy_id_to_save=f"average_policy_{player}",
                        br_trainer=br_trainer,
                        save_dir=checkpoint_dir(trainer=br_trainer),
                        timesteps_training=result["timesteps_total"],
                        episodes_training=result["episodes_total"],
                        checkpoint_name=f"average_policy_player_{player}_iter_{training_iteration}.h5")
                    avg_pol_checkpoint_spec = StrategySpec(
                        strategy_id=f"avg_pol_player_{player}_iter_{training_iteration}",
                        metadata=checkpoint_metadata)
                    checkpoint_path = os.path.join(
                        spec_checkpoint_dir(br_trainer),
                        f"average_policy_player_{player}_iter_{training_iteration}.json")
                    ensure_dir(checkpoint_path)
                    with open(checkpoint_path, "w") as checkpoint_spec_file:
                        checkpoint_spec_file.write(avg_pol_checkpoint_spec.to_json())

    br_trainer_config = {
        "log_level": "DEBUG",
        "callbacks": NFSPBestResponseCallbacks,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        "num_gpus": 0.0,
        "num_workers": 0,
        "num_gpus_per_worker": 0.0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["best_response_0", "best_response_1"],
            "policies": {
                "average_policy_0": (policy_classes["average_policy"], tmp_env.observation_space,
                                     tmp_env.action_space, {
                                         "model": avg_policy_model_config,
                                         "explore": False,
                                     }),
                "best_response_0": (policy_classes["best_response"], tmp_env.observation_space,
                                    tmp_env.action_space, {}),
                "average_policy_1": (policy_classes["average_policy"], tmp_env.observation_space,
                                     tmp_env.action_space, {
                                         "model": avg_policy_model_config,
                                         "explore": False,
                                     }),
                "best_response_1": (policy_classes["best_response"], tmp_env.observation_space,
                                    tmp_env.action_space, {}),
            },
            "policy_mapping_fn": select_policy,
        },
    }
    br_trainer_config = merge_dicts(br_trainer_config, get_trainer_config(tmp_env))

    br_trainer = trainer_class(config=br_trainer_config,
                               logger_creator=get_trainer_logger_creator(
                                   base_dir=results_dir,
                                   scenario_name=scenario_name,
                                   should_log_result_fn=should_log_result_fn))

    avg_br_reward_deque = StatDeque.remote(max_items=br_trainer_config["metrics_smoothing_episodes"])

    def _set_avg_br_rew_deque(worker: RolloutWorker):
        worker.avg_br_reward_deque = avg_br_reward_deque

    br_trainer.workers.foreach_worker(_set_avg_br_rew_deque)
    br_trainer.avg_br_reward_deque = avg_br_reward_deque

    # scenario_name logged in on_train_result_callback
    br_trainer.scenario_name = scenario_name

    br_trainer.latest_avg_trainer_result = None
    train_iter_count = 0

    for trainer in [br_trainer, avg_trainer]:
        for policy_id, policy in trainer.workers.local_worker().policy_map.items():
            policy.policy_id = policy_id

    avg_weights = avg_trainer.get_weights(["average_policy_0", "average_policy_1"])
    br_trainer.workers.foreach_worker(lambda worker: worker.set_weights(avg_weights))

    stopping_condition: StoppingCondition = nfsp_get_stopping_condition()

    print("starting")
    while True:
        print("avg train...")
        avg_train_results = avg_trainer.train()
        avg_weights = avg_trainer.get_weights(["average_policy_0", "average_policy_1"])
        br_trainer.workers.foreach_worker(lambda worker: worker.set_weights(avg_weights))
        br_trainer.latest_avg_trainer_result = copy.deepcopy(avg_train_results)
        print("br train...")
        train_iter_results = br_trainer.train()  # do a step (or several) in the main RL loop
        train_iter_count += 1

        print("printing results..")
        if print_train_results:
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"]["best_response_0"]:
                del train_iter_results["info"]["learner"]["best_response_0"]["td_error"]
            if "td_error" in train_iter_results["info"]["learner"]["best_response_1"]:
                del train_iter_results["info"]["learner"]["best_response_1"]["td_error"]
            log(pretty_dict_str(train_iter_results))
            log(f"Trainer logdir is {br_trainer.logdir}")

        if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
            print("stopping condition met.")
            break

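# Usage sketch for train_off_policy_rl_nfsp (not from the original source): the
# scenario name below is a hypothetical placeholder for an NFSPScenario registered
# in scenario_catalog; the function runs until the scenario's stopping condition is met.
#
#     train_off_policy_rl_nfsp(results_dir="/tmp/nfsp_results",
#                              scenario_name="my_nfsp_scenario",  # hypothetical
#                              print_train_results=True)
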
def train_nxdo_best_response(br_player: int, scenario_name: str, nxdo_manager_port: int,
                             nxdo_manager_host: str, print_train_results=True,
                             previous_br_checkpoint_path=None):
    scenario: NXDOScenario = scenario_catalog.get(scenario_name=scenario_name)
    if not isinstance(scenario, NXDOScenario):
        raise TypeError(f"Only instances of {NXDOScenario} can be used here. "
                        f"{scenario.name} is a {type(scenario)}.")

    use_openspiel_restricted_game: bool = scenario.use_openspiel_restricted_game
    get_restricted_game_custom_model = scenario.get_restricted_game_custom_model
    env_class = scenario.env_class
    base_env_config = scenario.env_config
    trainer_class = scenario.trainer_class_br
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes_br
    get_trainer_config = scenario.get_trainer_config_br
    nxdo_br_get_stopping_condition = scenario.get_stopping_condition_br
    should_log_result_fn = scenario.ray_should_log_result_filter

    nxdo_metanash_method: str = scenario.xdo_metanash_method
    if nxdo_metanash_method != "nfsp":
        raise NotImplementedError("Only 'nfsp' is currently supported for the nxdo_metanash_method")

    nxdo_manager = RemoteNXDOManagerClient(n_players=2,
                                           port=nxdo_manager_port,
                                           remote_server_host=nxdo_manager_host)

    manager_metadata = nxdo_manager.get_manager_metadata()
    results_dir = nxdo_manager.get_log_dir()

    br_params = nxdo_manager.claim_new_active_policy_for_player(player=br_player)
    metanash_specs_for_players, delegate_specs_for_players, active_policy_num = br_params

    other_player = 1 - br_player
    br_learner_name = f"policy {active_policy_num} player {br_player}"

    def log(message):
        print(f"({br_learner_name}): {message}")

    def select_policy(agent_id):
        if agent_id == br_player:
            return "best_response"
        elif agent_id == other_player:
            return "metanash"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

    restricted_env_config = {
        "create_env_fn": lambda: env_class(env_config=base_env_config),
        "raise_if_no_restricted_players": metanash_specs_for_players is not None,
    }
    tmp_base_env = env_class(env_config=base_env_config)

    if use_openspiel_restricted_game:
        restricted_game_class = OpenSpielRestrictedGame
    else:
        restricted_game_class = RestrictedGame
        restricted_env_config["use_delegate_policy_exploration"] = scenario.allow_stochastic_best_responses

    tmp_env = restricted_game_class(env_config=restricted_env_config)

    if metanash_specs_for_players is None or use_openspiel_restricted_game:
        other_player_restricted_action_space = tmp_env.base_action_space
        metanash_class = policy_classes["best_response"]
    else:
        other_player_restricted_action_space = Discrete(n=len(delegate_specs_for_players[other_player]))
        metanash_class = policy_classes["metanash"]

    print(f"metanash class: {metanash_class}, "
          f"other_player_restricted_action_space: {other_player_restricted_action_space}")

    if metanash_specs_for_players is None and use_openspiel_restricted_game:
        other_player_restricted_obs_space = tmp_env.base_observation_space
    else:
        other_player_restricted_obs_space = tmp_env.observation_space

    trainer_config = {
        "env": restricted_game_class,
        "env_config": restricted_env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["best_response"],
            "policies": {
                "metanash": (metanash_class, other_player_restricted_obs_space,
                             other_player_restricted_action_space, {"explore": False}),
                "metanash_delegate": (policy_classes["best_response"], tmp_env.base_observation_space,
                                      tmp_env.base_action_space,
                                      {"explore": scenario.allow_stochastic_best_responses}),
                "best_response": (policy_classes["best_response"], tmp_env.base_observation_space,
                                  tmp_env.base_action_space, {}),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    if metanash_specs_for_players is not None and get_restricted_game_custom_model is not None:
        trainer_config["multiagent"]["policies"]["metanash"][3]["model"] = {
            "custom_model": get_restricted_game_custom_model(tmp_env)}

    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_base_env))

    ray_head_address = manager_metadata["ray_head_address"]
    init_ray_for_scenario(scenario=scenario, head_address=ray_head_address, logging_level=logging.INFO)

    trainer = trainer_class(config=trainer_config,
                            logger_creator=get_trainer_logger_creator(
                                base_dir=results_dir,
                                scenario_name=scenario_name,
                                should_log_result_fn=should_log_result_fn))

    # metanash is a single pure strat spec
    def _set_worker_metanash(worker: RolloutWorker):
        if metanash_specs_for_players is not None:
            metanash_policy = worker.policy_map["metanash"]
            metanash_strategy_spec: StrategySpec = metanash_specs_for_players[other_player]
            load_pure_strat(policy=metanash_policy, pure_strat_spec=metanash_strategy_spec)

    trainer.workers.foreach_worker(_set_worker_metanash)

    trainer.weights_cache = {}
    if delegate_specs_for_players:
        if use_openspiel_restricted_game:
            set_restricted_game_conversions_for_all_workers_openspiel(
                trainer=trainer,
                tmp_base_env=tmp_base_env,
                delegate_policy_id="metanash_delegate",
                agent_id_to_restricted_game_specs={
                    other_player: delegate_specs_for_players[other_player]},
                load_policy_spec_fn=load_pure_strat)
        else:
            set_restricted_game_conversations_for_all_workers(
                trainer=trainer,
                delegate_policy_id="metanash_delegate",
                agent_id_to_restricted_game_specs={
                    other_player: delegate_specs_for_players[other_player]},
                load_policy_spec_fn=create_get_pure_strat_cached(cache=trainer.weights_cache))

    log(f"got policy {active_policy_num}")

    if previous_br_checkpoint_path is not None:
        def _set_br_initial_weights(worker: RolloutWorker):
            br_policy = worker.policy_map["best_response"]
            load_pure_strat(policy=br_policy, checkpoint_path=previous_br_checkpoint_path)

        trainer.workers.foreach_worker(_set_br_initial_weights)

    # Perform main RL training loop. Stop according to our StoppingCondition.
    stopping_condition: StoppingCondition = nxdo_br_get_stopping_condition()

    while True:
        train_iter_results = trainer.train()  # do a step (or several) in the main RL loop

        if print_train_results:
            train_iter_results["p2sro_active_policy_num"] = active_policy_num
            train_iter_results["best_response_player"] = br_player
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"]["best_response"]:
                del train_iter_results["info"]["learner"]["best_response"]["td_error"]
            log(f"Trainer log dir is {trainer.logdir}")
            log(pretty_dict_str(train_iter_results))

        total_timesteps_training_br = train_iter_results["timesteps_total"]
        total_episodes_training_br = train_iter_results["episodes_total"]
        br_reward_this_iter = train_iter_results["policy_reward_mean"]["best_response"]

        if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
            log("Stopping condition met.")
            break

    log(f"Training stopped. Setting active policy {active_policy_num} as fixed.")

    final_policy_metadata = create_metadata_with_new_checkpoint_for_current_best_response(
        trainer=trainer,
        player=br_player,
        save_dir=checkpoint_dir(trainer=trainer),
        timesteps_training_br=total_timesteps_training_br,
        episodes_training_br=total_episodes_training_br,
        active_policy_num=active_policy_num,
        average_br_reward=float(br_reward_this_iter),
    )

    nxdo_manager.submit_final_br_policy(
        player=br_player, policy_num=active_policy_num,
        metadata_dict=final_policy_metadata)

    # trainer.cleanup()
    # del trainer
    ray.shutdown()
    time.sleep(10)

    # wait for both player policies to be fixed.
    for player_to_wait_on in range(2):
        wait_count = 0
        while True:
            if nxdo_manager.is_policy_fixed(player=player_to_wait_on, policy_num=active_policy_num):
                break
            if wait_count % 10 == 0:
                log(f"Waiting for policy {active_policy_num} player {player_to_wait_on} to become fixed")
            time.sleep(2.0)
            wait_count += 1

    return final_policy_metadata["checkpoint_path"]

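# Usage sketch for train_nxdo_best_response (not from the original source): the
# scenario name, manager host, and port are hypothetical placeholders; the returned
# checkpoint path can seed the next best-response iteration via
# previous_br_checkpoint_path.
#
#     checkpoint_path = train_nxdo_best_response(br_player=0,
#                                                scenario_name="my_nxdo_scenario",  # hypothetical
#                                                nxdo_manager_port=4545,            # hypothetical
#                                                nxdo_manager_host="localhost",
#                                                print_train_results=True,
#                                                previous_br_checkpoint_path=None)
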
def train_poker_approx_best_response_xdfo(br_player: int,
                                          ray_head_address,
                                          scenario: NXDOScenario,
                                          general_trainer_config_overrrides,
                                          br_policy_config_overrides: dict,
                                          get_stopping_condition: Callable[[], StoppingCondition],
                                          metanash_specs_for_players: Dict[int, StrategySpec],
                                          delegate_specs_for_players: Dict[int, List[StrategySpec]],
                                          results_dir: str,
                                          print_train_results: bool = True):
    use_openspiel_restricted_game: bool = scenario.use_openspiel_restricted_game
    get_restricted_game_custom_model = scenario.get_restricted_game_custom_model
    env_class = scenario.env_class
    base_env_config = scenario.env_config
    trainer_class = scenario.trainer_class_br
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes_br
    get_trainer_config = scenario.get_trainer_config_br
    should_log_result_fn = scenario.ray_should_log_result_filter

    nxdo_metanash_method: str = scenario.xdo_metanash_method
    if nxdo_metanash_method != "nfsp":
        raise NotImplementedError("Only 'nfsp' is currently supported for the nxdo_metanash_method")

    other_player = 1 - br_player
    br_learner_name = f"approx br player {br_player}"

    def log(message, level=logging.INFO):
        logger.log(level, f"({br_learner_name}): {message}")

    def select_policy(agent_id):
        if agent_id == br_player:
            return "best_response"
        elif agent_id == other_player:
            return "metanash"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

    restricted_env_config = {
        "create_env_fn": lambda: env_class(env_config=base_env_config),
        "raise_if_no_restricted_players": metanash_specs_for_players is not None,
    }
    tmp_base_env = env_class(env_config=base_env_config)

    if use_openspiel_restricted_game:
        restricted_game_class = OpenSpielRestrictedGame
    else:
        restricted_game_class = RestrictedGame
        restricted_env_config["use_delegate_policy_exploration"] = scenario.allow_stochastic_best_responses

    tmp_env = restricted_game_class(env_config=restricted_env_config)

    if metanash_specs_for_players is None or use_openspiel_restricted_game:
        other_player_restricted_action_space = tmp_env.base_action_space
    else:
        other_player_restricted_action_space = Discrete(n=len(delegate_specs_for_players[other_player]))

    if metanash_specs_for_players is None and use_openspiel_restricted_game:
        other_player_restricted_obs_space = tmp_env.base_observation_space
    else:
        other_player_restricted_obs_space = tmp_env.observation_space

    trainer_config = {
        "env": restricted_game_class,
        "env_config": restricted_env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["best_response"],
            "policies": {
                "metanash": (policy_classes["metanash"], other_player_restricted_obs_space,
                             other_player_restricted_action_space, {"explore": False}),
                "metanash_delegate": (policy_classes["best_response"], tmp_env.base_observation_space,
                                      tmp_env.base_action_space,
                                      {"explore": scenario.allow_stochastic_best_responses}),
                "best_response": (policy_classes["best_response"], tmp_env.base_observation_space,
                                  tmp_env.base_action_space, br_policy_config_overrides),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    if metanash_specs_for_players is not None:
        if get_restricted_game_custom_model is not None:
            restricted_game_custom_model = get_restricted_game_custom_model(tmp_base_env)
        else:
            restricted_game_custom_model = None
        trainer_config["multiagent"]["policies"]["metanash"][3]["model"] = {
            "custom_model": restricted_game_custom_model}

    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_base_env))
    trainer_config = merge_dicts(trainer_config, general_trainer_config_overrrides)

    init_ray_for_scenario(scenario=scenario, head_address=ray_head_address, logging_level=logging.INFO)

    trainer = trainer_class(config=trainer_config,
                            logger_creator=get_trainer_logger_creator(
                                base_dir=results_dir,
                                scenario_name="approx_br",
                                should_log_result_fn=should_log_result_fn))

    # metanash is a single pure strat spec
    def _set_worker_metanash(worker: RolloutWorker):
        if metanash_specs_for_players is not None:
            metanash_policy = worker.policy_map["metanash"]
            metanash_strategy_spec: StrategySpec = metanash_specs_for_players[other_player]
            load_pure_strat(policy=metanash_policy, pure_strat_spec=metanash_strategy_spec)

    trainer.workers.foreach_worker(_set_worker_metanash)

    trainer.weights_cache = {}
    if delegate_specs_for_players:
        if use_openspiel_restricted_game:
            set_restricted_game_conversions_for_all_workers_openspiel(
                trainer=trainer,
                tmp_base_env=tmp_base_env,
                delegate_policy_id="metanash_delegate",
                agent_id_to_restricted_game_specs={
                    other_player: delegate_specs_for_players[other_player]},
                load_policy_spec_fn=load_pure_strat)
        else:
            set_restricted_game_conversations_for_all_workers(
                trainer=trainer,
                delegate_policy_id="metanash_delegate",
                agent_id_to_restricted_game_specs={
                    other_player: delegate_specs_for_players[other_player]},
                load_policy_spec_fn=create_get_pure_strat_cached(cache=trainer.weights_cache))

    # Perform main RL training loop. Stop according to our StoppingCondition.
    train_iter_count = 0
    stopping_condition: StoppingCondition = get_stopping_condition()

    max_reward = None
    while True:
        train_iter_results = trainer.train()  # do a step (or several) in the main RL loop
        train_iter_count += 1

        if print_train_results:
            train_iter_results["best_response_player"] = br_player
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"]["best_response"]:
                del train_iter_results["info"]["learner"]["best_response"]["td_error"]
            log(f"Trainer log dir is {trainer.logdir}")
            print(pretty_dict_str(train_iter_results))

        br_reward_this_iter = train_iter_results["policy_reward_mean"]["best_response"]
        if max_reward is None or br_reward_this_iter > max_reward:
            max_reward = br_reward_this_iter

        if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
            log("Stopping condition met.")
            break

    log("Training stopped.")

    # trainer.cleanup()
    # del trainer
    ray.shutdown()
    time.sleep(10)

    return max_reward

def train_off_policy_rl_nfsp_restricted_game(results_dir: str,
                                             scenario: NXDOScenario,
                                             player_to_base_game_action_specs: Dict[int, List[StrategySpec]],
                                             stopping_condition: StoppingCondition,
                                             manager_metadata: Union[dict, None],
                                             print_train_results: bool = True):
    use_openspiel_restricted_game: bool = scenario.use_openspiel_restricted_game
    get_restricted_game_custom_model = scenario.get_restricted_game_custom_model
    env_class = scenario.env_class
    base_env_config = scenario.env_config
    trainer_class = scenario.trainer_class_nfsp
    avg_trainer_class = scenario.avg_trainer_class_nfsp
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes_nfsp
    anticipatory_param: float = scenario.anticipatory_param_nfsp
    get_trainer_config = scenario.get_trainer_config_nfsp
    get_avg_trainer_config = scenario.get_avg_trainer_config_nfsp
    get_trainer_config_br = scenario.get_trainer_config_br
    calculate_openspiel_metanash: bool = scenario.calculate_openspiel_metanash
    calculate_openspiel_metanash_at_end: bool = scenario.calculate_openspiel_metanash_at_end
    calc_metanash_every_n_iters: int = scenario.calc_metanash_every_n_iters
    should_log_result_fn = scenario.ray_should_log_result_filter
    metrics_smoothing_episodes_override: int = scenario.metanash_metrics_smoothing_episodes_override

    assert scenario.xdo_metanash_method == "nfsp"

    ray_head_address = manager_metadata.get("ray_head_address", None) if manager_metadata is not None else None
    init_ray_for_scenario(scenario=scenario, head_address=ray_head_address, logging_level=logging.INFO)

    def select_policy(agent_id):
        random_sample = np.random.random()
        if agent_id == 0:
            if random_sample < anticipatory_param:
                return "best_response_0"
            return "average_policy_0"
        elif agent_id == 1:
            if random_sample < anticipatory_param:
                return "best_response_1"
            return "average_policy_1"
        else:
            raise ValueError(f"unexpected agent_id: {agent_id}")

    def assert_not_called(agent_id):
        assert False, "This function should never be called."

    def _create_base_env():
        return env_class(env_config=base_env_config)

    tmp_base_env = _create_base_env()

    restricted_env_config = {"create_env_fn": _create_base_env}

    if use_openspiel_restricted_game:
        restricted_game_class = OpenSpielRestrictedGame
        tmp_env = restricted_game_class(env_config=restricted_env_config)
        restricted_game_action_spaces = [tmp_env.base_action_space for _ in range(2)]
    else:
        restricted_game_class = RestrictedGame
        restricted_env_config["use_delegate_policy_exploration"] = scenario.allow_stochastic_best_responses
        tmp_env = restricted_game_class(env_config=restricted_env_config)
        restricted_game_action_spaces = [Discrete(n=len(player_to_base_game_action_specs[p])) for p in range(2)]

    assert all(restricted_game_action_spaces[0] == space for space in restricted_game_action_spaces)

    print(f"\n\n\n\n\nRestricted game action spaces {restricted_game_action_spaces}\n\n\n\n\n\n")

    scenario_avg_trainer_config = get_avg_trainer_config(tmp_base_env)
    scenario_avg_trainer_config_exploration_config = scenario_avg_trainer_config.get("exploration_config", {})
    if scenario_avg_trainer_config_exploration_config:
        del scenario_avg_trainer_config["exploration_config"]

    scenario_trainer_config = get_trainer_config(tmp_base_env)
    scenario_trainer_config_exploration_config = scenario_trainer_config.get("exploration_config", {})
    if scenario_trainer_config_exploration_config:
        del scenario_trainer_config["exploration_config"]

    delegate_policy_config = merge_dicts(get_trainer_config_br(tmp_base_env),
                                         {"explore": scenario.allow_stochastic_best_responses})

    avg_trainer_config = merge_dicts({
        "log_level": "DEBUG",
        "framework": "torch",
        "env": restricted_game_class,
        "env_config": restricted_env_config,
        "num_gpus": 0.0,
        "num_gpus_per_worker": 0.0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["average_policy_0", "average_policy_1"],
            "policies": {
                "average_policy_0": (policy_classes["average_policy"], tmp_env.observation_space,
                                     restricted_game_action_spaces[0],
                                     {"explore": False,
                                      "exploration_config": scenario_avg_trainer_config_exploration_config}),
                "average_policy_1": (policy_classes["average_policy"], tmp_env.observation_space,
                                     restricted_game_action_spaces[1],
                                     {"explore": False,
                                      "exploration_config": scenario_avg_trainer_config_exploration_config}),
                "delegate_policy": (policy_classes["delegate_policy"], tmp_base_env.observation_space,
                                    tmp_env.base_action_space, delegate_policy_config),
            },
            "policy_mapping_fn": assert_not_called,
        },
    }, scenario_avg_trainer_config)

    for _policy_id in ["average_policy_0", "average_policy_1"]:
        if get_restricted_game_custom_model is not None:
            avg_trainer_config["multiagent"]["policies"][_policy_id][3]["model"] = {
                "custom_model": get_restricted_game_custom_model(tmp_env)}

    avg_trainer = avg_trainer_class(config=avg_trainer_config,
                                    logger_creator=get_trainer_logger_creator(
                                        base_dir=results_dir,
                                        scenario_name="nfsp_restricted_game_avg_trainer",
                                        should_log_result_fn=should_log_result_fn))

    store_to_avg_policy_buffer = get_store_to_avg_policy_buffer_fn(nfsp_trainer=avg_trainer)

    class NFSPBestResponseCallbacks(DefaultCallbacks):

        def on_postprocess_trajectory(self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
                                      agent_id: AgentID, policy_id: PolicyID,
                                      policies: Dict[PolicyID, Policy],
                                      postprocessed_batch: SampleBatch,
                                      original_batches: Dict[Any, Tuple[Policy, SampleBatch]],
                                      **kwargs):
            super().on_postprocess_trajectory(worker=worker, episode=episode, agent_id=agent_id,
                                              policy_id=policy_id, policies=policies,
                                              postprocessed_batch=postprocessed_batch,
                                              original_batches=original_batches, **kwargs)

            postprocessed_batch.data["source_policy"] = [policy_id] * len(postprocessed_batch.data["rewards"])

            # All data from both policies will go into the best response's replay buffer.
            # Here we ensure policies not from the best response have the exact same
            # preprocessing as the best response.
            for average_policy_id, br_policy_id in [("average_policy_0", "best_response_0"),
                                                    ("average_policy_1", "best_response_1")]:
                if policy_id == average_policy_id:
                    if "action_probs" in postprocessed_batch:
                        del postprocessed_batch.data["action_probs"]
                    if "behaviour_logits" in postprocessed_batch:
                        del postprocessed_batch.data["behaviour_logits"]

                    br_policy: Policy = policies[br_policy_id]

                    new_batch = br_policy.postprocess_trajectory(
                        sample_batch=postprocessed_batch,
                        other_agent_batches=original_batches,
                        episode=episode)
                    copy_attributes(src_obj=new_batch, dst_obj=postprocessed_batch)
                elif policy_id == br_policy_id:
                    if "q_values" in postprocessed_batch:
                        del postprocessed_batch.data["q_values"]
                    if "action_probs" in postprocessed_batch:
                        del postprocessed_batch.data["action_probs"]
                    del postprocessed_batch.data["action_dist_inputs"]

                if policy_id in ("average_policy_0", "best_response_0"):
                    assert agent_id == 0
                if policy_id in ("average_policy_1", "best_response_1"):
                    assert agent_id == 1

        def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs):
            super().on_sample_end(worker=worker, samples=samples, **kwargs)
            assert isinstance(samples, MultiAgentBatch)

            for policy_samples in samples.policy_batches.values():
                if "action_prob" in policy_samples.data:
                    del policy_samples.data["action_prob"]
                if "action_logp" in policy_samples.data:
                    del policy_samples.data["action_logp"]

            for average_policy_id, br_policy_id in [("average_policy_0", "best_response_0"),
                                                    ("average_policy_1", "best_response_1")]:
                for policy_id, policy_samples in samples.policy_batches.items():
                    if policy_id == br_policy_id:
                        store_to_avg_policy_buffer(MultiAgentBatch(
                            policy_batches={average_policy_id: policy_samples},
                            env_steps=policy_samples.count))
                if average_policy_id in samples.policy_batches:
                    if br_policy_id in samples.policy_batches:
                        all_policies_samples = samples.policy_batches[br_policy_id].concat(
                            other=samples.policy_batches[average_policy_id])
                    else:
                        all_policies_samples = samples.policy_batches[average_policy_id]
                    del samples.policy_batches[average_policy_id]
                    samples.policy_batches[br_policy_id] = all_policies_samples

        def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                           policies: Dict[PolicyID, Policy], episode: MultiAgentEpisode,
                           env_index: int, **kwargs):
            super().on_episode_end(worker=worker, base_env=base_env, policies=policies,
                                   episode=episode, env_index=env_index, **kwargs)

            episode_policies = set(episode.agent_rewards.keys())
            if episode_policies == {(0, "average_policy_0"), (1, "best_response_1")}:
                worker.avg_br_reward_deque.add.remote(episode.agent_rewards[(1, "best_response_1")])
            elif episode_policies == {(1, "average_policy_1"), (0, "best_response_0")}:
                worker.avg_br_reward_deque.add.remote(episode.agent_rewards[(0, "best_response_0")])

        def on_train_result(self, *, trainer, result: dict, **kwargs):
            super().on_train_result(trainer=trainer, result=result, **kwargs)
            training_iteration = result["training_iteration"]
            result["avg_br_reward_both_players"] = ray.get(trainer.avg_br_reward_deque.get_mean.remote())

            if (calculate_openspiel_metanash and
                    (training_iteration == 1 or training_iteration % calc_metanash_every_n_iters == 0)):
                base_env = _create_base_env()
                open_spiel_env_config = base_env.open_spiel_env_config
                openspiel_game_version = base_env.game_version
                local_avg_policy_0 = trainer.workers.local_worker().policy_map["average_policy_0"]
                local_avg_policy_1 = trainer.workers.local_worker().policy_map["average_policy_1"]
                exploitability = nxdo_nfsp_measure_exploitability_nonlstm(
                    rllib_policies=[local_avg_policy_0, local_avg_policy_1],
                    poker_game_version=openspiel_game_version,
                    restricted_game_convertors=trainer.get_local_converters(),
                    open_spiel_env_config=open_spiel_env_config,
                    use_delegate_policy_exploration=scenario.allow_stochastic_best_responses)
                result["avg_policy_exploitability"] = exploitability

    br_trainer_config = {
        "log_level": "DEBUG",
        "callbacks": NFSPBestResponseCallbacks,
        "env": restricted_game_class,
        "env_config": restricted_env_config,
        "gamma": 1.0,
        "num_gpus": 0.0,
        "num_workers": 0,
        "num_gpus_per_worker": 0.0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["best_response_0", "best_response_1"],
            "policies": {
                "average_policy_0": (policy_classes["average_policy"], tmp_env.observation_space,
                                     restricted_game_action_spaces[0],
                                     {"explore": False,
                                      "exploration_config": scenario_avg_trainer_config_exploration_config}),
                "best_response_0": (policy_classes["best_response"], tmp_env.observation_space,
                                    restricted_game_action_spaces[0],
                                    {"exploration_config": scenario_trainer_config_exploration_config}),
                "average_policy_1": (policy_classes["average_policy"], tmp_env.observation_space,
                                     restricted_game_action_spaces[1],
                                     {"explore": False,
                                      "exploration_config": scenario_avg_trainer_config_exploration_config}),
                "best_response_1": (policy_classes["best_response"], tmp_env.observation_space,
                                    restricted_game_action_spaces[1],
                                    {"exploration_config": scenario_trainer_config_exploration_config}),
                "delegate_policy": (policy_classes["delegate_policy"], tmp_base_env.observation_space,
                                    tmp_env.base_action_space, delegate_policy_config),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    assert all(restricted_game_action_spaces[0] == space for space in restricted_game_action_spaces), \
        "If not true, the line below with \"get_trainer_config\" may need to be changed to a better solution."

    br_trainer_config = merge_dicts(br_trainer_config, scenario_trainer_config)

    for _policy_id in ["average_policy_0", "average_policy_1", "best_response_0", "best_response_1"]:
        if get_restricted_game_custom_model is not None:
            br_trainer_config["multiagent"]["policies"][_policy_id][3]["model"] = {
                "custom_model": get_restricted_game_custom_model(tmp_env)}

    br_trainer_config["metrics_smoothing_episodes"] = metrics_smoothing_episodes_override

    br_trainer = trainer_class(config=br_trainer_config,
                               logger_creator=get_trainer_logger_creator(
                                   base_dir=results_dir,
                                   scenario_name="nfsp_restricted_game_trainer",
                                   should_log_result_fn=should_log_result_fn))

    avg_br_reward_deque = StatDeque.remote(max_items=br_trainer_config["metrics_smoothing_episodes"])

    def _set_avg_br_rew_deque(worker: RolloutWorker):
        worker.avg_br_reward_deque = avg_br_reward_deque

    br_trainer.workers.foreach_worker(_set_avg_br_rew_deque)
    br_trainer.avg_br_reward_deque = avg_br_reward_deque

    if use_openspiel_restricted_game:
        local_delegate_policy = br_trainer.workers.local_worker().policy_map["delegate_policy"]
        player_converters = []
        for p in range(2):
            print("Creating restricted game obs conversions...")
            convertor = get_restricted_game_obs_conversions(
                player=p,
                delegate_policy=local_delegate_policy,
                policy_specs=player_to_base_game_action_specs[p],
                load_policy_spec_fn=create_get_pure_strat_cached(cache={}),
                tmp_base_env=tmp_base_env)
            player_converters.append(convertor)
        for _trainer in [br_trainer, avg_trainer]:
            def _set_worker_converters(worker: RolloutWorker):
                worker_delegate_policy = worker.policy_map["delegate_policy"]
                for p in range(2):
                    worker.foreach_env(lambda env: env.set_obs_conversion_dict(p, player_converters[p]))
                worker_delegate_policy.player_converters = player_converters

            _trainer.workers.foreach_worker(_set_worker_converters)
            _trainer.get_local_converters = lambda: _trainer.workers.local_worker().policy_map[
                "delegate_policy"].player_converters
    else:
        weights_cache = {}
        for _trainer in [br_trainer, avg_trainer]:
            def _set_worker_converters(worker: RolloutWorker):
                worker_delegate_policy = worker.policy_map["delegate_policy"]
                player_converters = []
                for p in range(2):
                    player_converter = RestrictedToBaseGameActionSpaceConverter(
                        delegate_policy=worker_delegate_policy,
                        policy_specs=player_to_base_game_action_specs[p],
                        load_policy_spec_fn=create_get_pure_strat_cached(cache=weights_cache))
                    player_converters.append(player_converter)
                    worker.foreach_env(lambda env: env.set_action_conversion(p, player_converter))
                worker_delegate_policy.player_converters = player_converters

            _trainer.workers.foreach_worker(_set_worker_converters)
            _trainer.get_local_converters = lambda: _trainer.workers.local_worker().policy_map[
                "delegate_policy"].player_converters

    br_trainer.latest_avg_trainer_result = None
    train_iter_count = 0

    for _trainer in [br_trainer, avg_trainer]:
        for policy_id, policy in _trainer.workers.local_worker().policy_map.items():
            policy.policy_id = policy_id

    if len(player_to_base_game_action_specs[0]) == 1:
        final_train_result = {"episodes_total": 0, "timesteps_total": 0, "training_iteration": 0}
        tmp_callback = NFSPBestResponseCallbacks()
        tmp_callback.on_train_result(trainer=br_trainer, result=final_train_result)
    else:
        avg_weights = avg_trainer.get_weights(["average_policy_0", "average_policy_1"])
        br_trainer.workers.foreach_worker(lambda worker: worker.set_weights(avg_weights))
        while True:
            avg_train_results = avg_trainer.train()
            avg_weights = avg_trainer.get_weights(["average_policy_0", "average_policy_1"])
            br_trainer.workers.foreach_worker(lambda worker: worker.set_weights(avg_weights))
            br_trainer.latest_avg_trainer_result = copy.deepcopy(avg_train_results)
            train_iter_results = br_trainer.train()  # do a step (or several) in the main RL loop
            train_iter_count += 1

            if print_train_results:
                # Delete verbose debugging info before printing
                if "hist_stats" in train_iter_results:
                    del train_iter_results["hist_stats"]
                if "td_error" in train_iter_results["info"]["learner"]["best_response_0"]:
                    del train_iter_results["info"]["learner"]["best_response_0"]["td_error"]
                if "td_error" in train_iter_results["info"]["learner"]["best_response_1"]:
                    del train_iter_results["info"]["learner"]["best_response_1"]["td_error"]
                print(pretty_dict_str(train_iter_results))
                print(f"Trainer logdir is {br_trainer.logdir}")

            if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
                print("stopping condition met.")
                final_train_result = deepcopy(train_iter_results)
                break

    if calculate_openspiel_metanash_at_end:
        base_env = _create_base_env()
        open_spiel_env_config = base_env.open_spiel_env_config
        openspiel_game_version = base_env.game_version
        local_avg_policy_0 = br_trainer.workers.local_worker().policy_map["average_policy_0"]
        local_avg_policy_1 = br_trainer.workers.local_worker().policy_map["average_policy_1"]
        exploitability = nxdo_nfsp_measure_exploitability_nonlstm(
            rllib_policies=[local_avg_policy_0, local_avg_policy_1],
            poker_game_version=openspiel_game_version,
            restricted_game_convertors=br_trainer.get_local_converters(),
            open_spiel_env_config=open_spiel_env_config,
            use_delegate_policy_exploration=scenario.allow_stochastic_best_responses)
        final_train_result["avg_policy_exploitability"] = exploitability

    if "avg_policy_exploitability" in final_train_result:
        print(f"\n\nexploitability: {final_train_result['avg_policy_exploitability']}\n\n")

    avg_policy_specs = []
    for player in range(2):
        strategy_id = f"avg_policy_player_{player}_{datetime_str()}"

        checkpoint_path = save_nfsp_avg_policy_checkpoint(
            trainer=br_trainer,
            policy_id_to_save=f"average_policy_{player}",
            save_dir=checkpoint_dir(trainer=br_trainer),
            timesteps_training=final_train_result["timesteps_total"],
            episodes_training=final_train_result["episodes_total"],
            checkpoint_name=f"{strategy_id}.h5")

        avg_policy_spec = StrategySpec(
            strategy_id=strategy_id,
            metadata={"checkpoint_path": checkpoint_path,
                      "delegate_policy_specs": [spec.to_json()
                                                for spec in player_to_base_game_action_specs[player]]})
        avg_policy_specs.append(avg_policy_spec)

    ray.kill(avg_trainer.workers.local_worker().replay_buffer_actor)
    avg_trainer.cleanup()
    br_trainer.cleanup()
    del avg_trainer
    del br_trainer
    del avg_br_reward_deque

    time.sleep(10)

    assert final_train_result is not None
    return avg_policy_specs, final_train_result

def train_poker_approx_best_response_psro(br_player,
                                          ray_head_address,
                                          scenario_name,
                                          general_trainer_config_overrrides,
                                          br_policy_config_overrides: dict,
                                          get_stopping_condition,
                                          metanash_policy_specs,
                                          metanash_weights,
                                          results_dir,
                                          print_train_results=True):
    scenario: PSROScenario = scenario_catalog.get(scenario_name=scenario_name)

    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    p2sro = scenario.p2sro
    get_trainer_config = scenario.get_trainer_config
    psro_get_stopping_condition = scenario.psro_get_stopping_condition
    mix_metanash_with_uniform_dist_coeff = scenario.mix_metanash_with_uniform_dist_coeff

    other_player = 1 - br_player
    br_learner_name = f"new_learner_{br_player}"

    def log(message, level=logging.INFO):
        logger.log(level, f"({br_learner_name}): {message}")

    def select_policy(agent_id):
        if agent_id == br_player:
            return "best_response"
        elif agent_id == other_player:
            return "metanash"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

    init_ray_for_scenario(scenario=scenario, head_address=ray_head_address, logging_level=logging.INFO)

    tmp_env = env_class(env_config=env_config)

    trainer_config = {
        "callbacks": P2SROPreAndPostEpisodeCallbacks,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["best_response"],
            "policies": {
                "metanash": (policy_classes["metanash"], tmp_env.observation_space, tmp_env.action_space,
                             {"explore": scenario.allow_stochastic_best_responses}),
                "best_response": (policy_classes["best_response"], tmp_env.observation_space,
                                  tmp_env.action_space, br_policy_config_overrides),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_env))
    trainer_config = merge_dicts(trainer_config, general_trainer_config_overrrides)
    # trainer_config["rollout_fragment_length"] = trainer_config["rollout_fragment_length"] // max(
    #     1, trainer_config["num_workers"] * trainer_config["num_envs_per_worker"])

    trainer = trainer_class(config=trainer_config,
                            logger_creator=get_trainer_logger_creator(
                                base_dir=results_dir,
                                scenario_name="approx_br",
                                should_log_result_fn=lambda result: result["training_iteration"] % 100 == 0))

    update_all_workers_to_latest_metanash(trainer=trainer,
                                          metanash_policy_specs=metanash_policy_specs,
                                          metanash_weights=metanash_weights)

    train_iter_count = 0

    stopping_condition: StoppingCondition = get_stopping_condition()

    max_reward = None
    while True:
        train_iter_results = trainer.train()  # do a step (or several) in the main RL loop
        train_iter_count += 1

        if print_train_results:
            train_iter_results["best_response_player"] = br_player
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"]["best_response"]:
                del train_iter_results["info"]["learner"]["best_response"]["td_error"]
            print(pretty_dict_str(train_iter_results))

        total_timesteps_training_br = train_iter_results["timesteps_total"]
        total_episodes_training_br = train_iter_results["episodes_total"]
        br_reward_this_iter = train_iter_results["policy_reward_mean"]["best_response"]
        if max_reward is None or br_reward_this_iter > max_reward:
            max_reward = br_reward_this_iter

        if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
            break

    trainer.cleanup()
    ray.shutdown()
    time.sleep(10)

    return max_reward

def train_psro_best_response(player: int, results_dir: str, scenario_name: str, psro_manager_port: int,
                             psro_manager_host: str, print_train_results=True,
                             previous_br_checkpoint_path=None) -> str:
    scenario: PSROScenario = scenario_catalog.get(scenario_name=scenario_name)
    if not isinstance(scenario, PSROScenario):
        raise TypeError(f"Only instances of {PSROScenario} can be used here. "
                        f"{scenario.name} is a {type(scenario)}.")

    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    single_agent_symmetric_game = scenario.single_agent_symmetric_game
    if single_agent_symmetric_game and player != 0:
        if player is None:
            player = 0
        else:
            raise ValueError("If treating the game as single agent symmetric, only use player 0 "
                             "(one agent plays all sides).")

    p2sro = scenario.p2sro
    p2sro_sync_with_payoff_table_every_n_episodes = scenario.p2sro_sync_with_payoff_table_every_n_episodes
    get_trainer_config = scenario.get_trainer_config
    psro_get_stopping_condition = scenario.psro_get_stopping_condition
    mix_metanash_with_uniform_dist_coeff = scenario.mix_metanash_with_uniform_dist_coeff
    allow_stochastic_best_response = scenario.allow_stochastic_best_responses
    should_log_result_fn = scenario.ray_should_log_result_filter

    class P2SROPreAndPostEpisodeCallbacks(DefaultCallbacks):

        # def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
        #                     episode: MultiAgentEpisode, env_index: int, **kwargs):
        #     super().on_episode_step(worker=worker, base_env=base_env, episode=episode,
        #                             env_index=env_index, **kwargs)
        #
        #     # Debug render a single environment.
        #     if worker.worker_index == 1 and env_index == 0:
        #         base_env.get_unwrapped()[0].render()

        def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                             policies: Dict[str, Policy], episode: MultiAgentEpisode,
                             env_index: int, **kwargs):
            # Sample new pure strategy policy weights from the opponent strategy distribution for the
            # best response to train against. For better runtime performance, this function can be
            # modified to load new weights only every few episodes instead.
            resample_pure_strat_every_n_episodes = 1
            metanash_policy: Policy = policies["metanash"]
            opponent_policy_distribution: PolicySpecDistribution = worker.opponent_policy_distribution
            time_for_resample = (not hasattr(metanash_policy, "episodes_since_resample") or
                                 metanash_policy.episodes_since_resample >= resample_pure_strat_every_n_episodes)
            if time_for_resample and opponent_policy_distribution is not None:
                new_pure_strat_spec: StrategySpec = opponent_policy_distribution.sample_policy_spec()
                # noinspection PyTypeChecker
                load_pure_strat(policy=metanash_policy, pure_strat_spec=new_pure_strat_spec)
                metanash_policy.episodes_since_resample = 1
            elif opponent_policy_distribution is not None:
                metanash_policy.episodes_since_resample += 1

        def on_train_result(self, *, trainer, result: dict, **kwargs):
            result["scenario_name"] = trainer.scenario_name
            super().on_train_result(trainer=trainer, result=result, **kwargs)

        def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                           policies: Dict[str, Policy], episode: MultiAgentEpisode,
                           env_index: int, **kwargs):
            # If using P2SRO, report payoff results of the actively training BR to the payoff table.
            if not p2sro:
                return

            if not hasattr(worker, "p2sro_manager"):
                worker.p2sro_manager = RemoteP2SROManagerClient(n_players=2,
                                                                port=psro_manager_port,
                                                                remote_server_host=psro_manager_host)

            br_policy_spec: StrategySpec = worker.policy_map["best_response"].policy_spec
            if br_policy_spec.pure_strat_index_for_player(player=worker.br_player) == 0:
                # If this is true, we're training policy 0 (the first iteration of PSRO).
                # The PSRO subgame should be empty, and instead the metanash is a random neural network.
                # No need to report payoff results for this.
                return

            # Report payoff results for individual episodes to the p2sro manager to keep a real-time
            # approximation of the payoff matrix entries for (learning) active policies.
            policy_specs_for_each_player: List[StrategySpec] = [None, None]
            payoffs_for_each_player: List[float] = [None, None]
            for (player, policy_name), reward in episode.agent_rewards.items():
                assert policy_name in ["best_response", "metanash"]
                policy: Policy = worker.policy_map[policy_name]
                assert policy.policy_spec is not None
                policy_specs_for_each_player[player] = policy.policy_spec
                payoffs_for_each_player[player] = reward
            assert all(payoff is not None for payoff in payoffs_for_each_player)

            # Send payoff result to the manager for inclusion in the payoff table.
            worker.p2sro_manager.submit_empirical_payoff_result(
                policy_specs_for_each_player=tuple(policy_specs_for_each_player),
                payoffs_for_each_player=tuple(payoffs_for_each_player),
                games_played=1,
                override_all_previous_results=False)

    other_player = 1 - player
    br_learner_name = f"new_learner_{player}"

    def log(message):
        print(f"({br_learner_name}): {message}")

    def select_policy(agent_id):
        if agent_id == player:
            return "best_response"
        elif agent_id == other_player:
            return "metanash"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

    p2sro_manager = RemoteP2SROManagerClient(n_players=2, port=psro_manager_port,
                                             remote_server_host=psro_manager_host)
    manager_metadata = p2sro_manager.get_manager_metadata()
    ray_head_address = manager_metadata["ray_head_address"]
    init_ray_for_scenario(scenario=scenario, head_address=ray_head_address, logging_level=logging.INFO)

    tmp_env = env_class(env_config=env_config)

    trainer_config = {
        "callbacks": P2SROPreAndPostEpisodeCallbacks,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["best_response"],
            "policies": {
                "metanash": (policy_classes["metanash"], tmp_env.observation_space, tmp_env.action_space,
                             {"explore": allow_stochastic_best_response}),
                "best_response": (policy_classes["best_response"], tmp_env.observation_space,
                                  tmp_env.action_space, {}),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_env))

    trainer = trainer_class(config=trainer_config,
                            logger_creator=get_trainer_logger_creator(
                                base_dir=results_dir,
                                scenario_name=scenario_name,
                                should_log_result_fn=should_log_result_fn))

    # scenario_name logged in on_train_result_callback
    trainer.scenario_name = scenario_name

    if previous_br_checkpoint_path is not None:
        def _set_br_initial_weights(worker: RolloutWorker):
            br_policy = worker.policy_map["best_response"]
            load_pure_strat(policy=br_policy, checkpoint_path=previous_br_checkpoint_path)

        trainer.workers.foreach_worker(_set_br_initial_weights)

    active_policy_spec: StrategySpec = p2sro_manager.claim_new_active_policy_for_player(
        player=player,
        new_policy_metadata_dict=create_metadata_with_new_checkpoint_for_current_best_response(
            trainer=trainer,
            player=player,
            save_dir=checkpoint_dir(trainer),
            timesteps_training_br=0,
            episodes_training_br=0,
            active_policy_num=None))

    active_policy_num = active_policy_spec.pure_strat_index_for_player(player=player)
    br_learner_name = f"policy {active_policy_num} player {player}"

    log(f"got policy {active_policy_num}")

    set_best_response_active_policy_spec_and_player_for_all_workers(trainer=trainer, player=player,
                                                                    active_policy_spec=active_policy_spec)

    sync_active_policy_br_and_metanash_with_p2sro_manager(
        trainer=trainer,
        player=player,
        metanash_player=other_player,
        one_agent_plays_all_sides=single_agent_symmetric_game,
        p2sro_manager=p2sro_manager,
        mix_metanash_with_uniform_dist_coeff=mix_metanash_with_uniform_dist_coeff,
        active_policy_num=active_policy_num,
        timesteps_training_br=0,
        episodes_training_br=0)

    # Perform main RL training loop. Stop training according to our StoppingCondition.
    train_iter_count = 0
    episodes_since_last_sync_with_manager = 0
    stopping_condition: StoppingCondition = psro_get_stopping_condition()

    while True:
        train_iter_results = trainer.train()  # do a step (or several) in the main RL loop
        train_iter_count += 1

        if print_train_results:
            train_iter_results["p2sro_active_policy_num"] = active_policy_num
            train_iter_results["best_response_player"] = player
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"]["best_response"]:
                del train_iter_results["info"]["learner"]["best_response"]["td_error"]
            log(pretty_dict_str(train_iter_results))

        total_timesteps_training_br = train_iter_results["timesteps_total"]
        total_episodes_training_br = train_iter_results["episodes_total"]

        episodes_since_last_sync_with_manager += train_iter_results["episodes_this_iter"]
        if p2sro and episodes_since_last_sync_with_manager >= p2sro_sync_with_payoff_table_every_n_episodes:
            if p2sro_sync_with_payoff_table_every_n_episodes > 0:
                episodes_since_last_sync_with_manager = (episodes_since_last_sync_with_manager %
                                                         p2sro_sync_with_payoff_table_every_n_episodes)
            else:
                episodes_since_last_sync_with_manager = 0

            sync_active_policy_br_and_metanash_with_p2sro_manager(
                trainer=trainer,
                player=player,
                metanash_player=other_player,
                one_agent_plays_all_sides=single_agent_symmetric_game,
                p2sro_manager=p2sro_manager,
                mix_metanash_with_uniform_dist_coeff=mix_metanash_with_uniform_dist_coeff,
                active_policy_num=active_policy_num,
                timesteps_training_br=total_timesteps_training_br,
                episodes_training_br=total_episodes_training_br)

        if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
            if p2sro_manager.can_active_policy_be_set_as_fixed_now(player=player,
                                                                   policy_num=active_policy_num):
                break
            else:
                log("Forcing training to continue since lower policies are still active.")

    log(f"Training stopped. Setting active policy {active_policy_num} as fixed.")

    final_policy_metadata = create_metadata_with_new_checkpoint_for_current_best_response(
        trainer=trainer,
        player=player,
        save_dir=checkpoint_dir(trainer=trainer),
        timesteps_training_br=total_timesteps_training_br,
        episodes_training_br=total_episodes_training_br,
        active_policy_num=active_policy_num,
        average_br_reward=train_iter_results["policy_reward_mean"]["best_response"])

    p2sro_manager.set_active_policy_as_fixed(
        player=player, policy_num=active_policy_num,
        final_metadata_dict=final_policy_metadata)

    trainer.cleanup()
    ray.shutdown()
    time.sleep(10)

    if not p2sro:
        # wait for both player policies to be fixed.
        for player_to_wait_on in range(2):
            wait_count = 0
            while True:
                if p2sro_manager.is_policy_fixed(player=player_to_wait_on, policy_num=active_policy_num):
                    break
                if wait_count % 10 == 0:
                    log(f"Waiting for policy {active_policy_num} player {player_to_wait_on} to become fixed")
                time.sleep(2.0)
                wait_count += 1

    return final_policy_metadata["checkpoint_path"]

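# Usage sketch for train_psro_best_response (not from the original source): the scenario
# name, manager host, and port are hypothetical placeholders. Each call trains one best
# response; the returned checkpoint path can warm-start the next call through
# previous_br_checkpoint_path.
#
#     previous_checkpoint = None
#     while True:
#         previous_checkpoint = train_psro_best_response(
#             player=0,
#             results_dir="/tmp/psro_results",
#             scenario_name="my_psro_scenario",  # hypothetical
#             psro_manager_port=4535,            # hypothetical
#             psro_manager_host="localhost",
#             print_train_results=True,
#             previous_br_checkpoint_path=previous_checkpoint)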