def save_best_response_checkpoint(trainer,
                                  player: int,
                                  save_dir: str,
                                  timesteps_training_br: int,
                                  episodes_training_br: int,
                                  active_policy_num: int = None,
                                  average_br_reward: float = None):
    policy_name = active_policy_num if active_policy_num is not None else "unclaimed"
    date_time = datetime_str()
    checkpoint_name = f"policy_{policy_name}_{date_time}.h5"
    checkpoint_path = os.path.join(save_dir, checkpoint_name)
    br_weights = trainer.get_weights(["best_response"])["best_response"]
    # periods cause HDF5 NaturalNaming warnings
    br_weights = {k.replace(".", "_dot_"): v for k, v in br_weights.items()}
    ensure_dir(file_path=checkpoint_path)
    num_save_attempts = 5
    for attempt in range(num_save_attempts):
        try:
            deepdish.io.save(path=checkpoint_path, data={
                "weights": br_weights,
                "player": player,
                "policy_num": active_policy_num,
                "date_time_str": date_time,
                "seconds_since_epoch": time.time(),
                "timesteps_training_br": timesteps_training_br,
                "episodes_training_br": episodes_training_br,
                "average_br_reward": average_br_reward,
            })
            break
        except HDF5ExtError:
            if attempt + 1 == num_save_attempts:
                raise
            time.sleep(1.0)
    return checkpoint_path
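# Usage sketch (not part of the original module): load a checkpoint written by
# save_best_response_checkpoint and reverse the "_dot_" key substitution so the weights can be
# handed back to an RLlib trainer. The helper name `load_best_response_checkpoint` is hypothetical;
# only deepdish.io.load and Trainer.set_weights are assumed to exist.
def load_best_response_checkpoint(trainer, checkpoint_path: str) -> dict:
    checkpoint_data = deepdish.io.load(path=checkpoint_path)
    # Undo the save-time substitution so keys match the live policy's weight names again.
    br_weights = {k.replace("_dot_", "."): v for k, v in checkpoint_data["weights"].items()}
    trainer.set_weights({"best_response": br_weights})
    return checkpoint_data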
def on_train_result(self, *, trainer, result: dict, **kwargs): result["scenario_name"] = trainer.scenario_name training_iteration = result["training_iteration"] super().on_train_result(trainer=trainer, result=result, **kwargs) if training_iteration % checkpoint_every_n_iters == 0 or training_iteration == 1: for player in range(2): checkpoint_metadata = create_metadata_with_new_checkpoint( policy_id_to_save=f"best_response_{player}", br_trainer=trainer, policy_player=player, save_dir=checkpoint_dir(trainer=trainer), timesteps_training=result["timesteps_total"], episodes_training=result["episodes_total"], checkpoint_name= f"best_response_player_{player}_iter_{training_iteration}.h5" ) joint_pol_checkpoint_spec = StrategySpec( strategy_id= f"best_response_player_{player}_iter_{training_iteration}", metadata=checkpoint_metadata) checkpoint_path = os.path.join( spec_checkpoint_dir(trainer), f"best_response_player_{player}_iter_{training_iteration}.json" ) ensure_dir(checkpoint_path) with open(checkpoint_path, "+w") as checkpoint_spec_file: checkpoint_spec_file.write( joint_pol_checkpoint_spec.to_json())
def save_policy_checkpoint(trainer: Trainer,
                           player: int,
                           save_dir: str,
                           policy_id_to_save: PolicyID,
                           checkpoint_name: str,
                           additional_data: Dict[str, Any]):
    date_time = datetime_str()
    checkpoint_name = f"policy_{checkpoint_name}_{date_time}.h5"
    checkpoint_path = os.path.join(save_dir, checkpoint_name)
    br_weights = trainer.get_weights([policy_id_to_save])[policy_id_to_save]
    # periods cause HDF5 NaturalNaming warnings
    br_weights = {k.replace(".", "_dot_"): v for k, v in br_weights.items()}
    ensure_dir(file_path=checkpoint_path)
    num_save_attempts = 5
    checkpoint_data = {
        "weights": br_weights,
        "player": player,
        "date_time_str": date_time,
        "seconds_since_epoch": time.time(),
    }
    checkpoint_data.update(additional_data)
    for attempt in range(num_save_attempts):
        try:
            deepdish.io.save(path=checkpoint_path, data=checkpoint_data)
            break
        except HDF5ExtError:
            if attempt + 1 == num_save_attempts:
                raise
            time.sleep(1.0)
    return checkpoint_path
def __init__(self, p2sro_manger, log_dir: str, scenario: PSROScenario):
    super(ExploitabilityP2SROManagerLogger, self).__init__(p2sro_manger=p2sro_manger, log_dir=log_dir)
    self._scenario = scenario
    if not issubclass(scenario.env_class, (PokerMultiAgentEnv, OshiZumoMultiAgentEnv)):
        raise ValueError(
            f"ExploitabilityP2SROManagerLogger is only meant to be used with PokerMultiAgentEnv or "
            f"OshiZumoMultiAgentEnv, not {scenario.env_class}.")
    if not scenario.calc_exploitability_for_openspiel_env:
        raise ValueError("Only use ExploitabilityP2SROManagerLogger if "
                         "scenario.calc_exploitability_for_openspiel_env is True.")
    self._exploitability_per_generation = []
    self._total_steps_per_generation = []
    self._total_episodes_per_generation = []
    self._num_policies_per_generation = []
    self._payoff_table_checkpoint_nums = []
    self._payoff_table_checkpoint_paths = []
    self._policy_nums_checkpoint_paths = []
    self._exploitability_stats_save_path = os.path.join(log_dir, "exploitability_stats.json")
    ensure_dir(self._exploitability_stats_save_path)
def __init__(self, p2sro_manger, log_dir: str, scenario: PSROScenario):
    super(ApproxExploitabilityP2SROManagerLogger, self).__init__(p2sro_manger=p2sro_manger, log_dir=log_dir)
    self._scenario = scenario
    self._exploitability_per_generation = []
    self._total_steps_per_generation = []
    self._total_episodes_per_generation = []
    self._num_policies_per_generation = []
    self._payoff_table_checkpoint_nums = []
    self._payoff_table_checkpoint_paths = []
    self._policy_nums_checkpoint_paths = []
    self._exploitability_stats_save_path = os.path.join(log_dir, "approx_exploitability_stats.json")
    ensure_dir(self._exploitability_stats_save_path)
def on_train_result(self, *, trainer, result: dict, **kwargs):
    super().on_train_result(trainer=trainer, result=result, **kwargs)
    result["scenario_name"] = trainer.scenario_name
    result["avg_br_reward_both_players"] = ray.get(trainer.avg_br_reward_deque.get_mean.remote())
    training_iteration = result["training_iteration"]

    if (calculate_openspiel_metanash and
            (training_iteration == 1 or training_iteration % calc_metanash_every_n_iters == 0)):
        base_env = _create_env()
        open_spiel_env_config = base_env.open_spiel_env_config
        openspiel_game_version = base_env.game_version
        local_avg_policy_0 = trainer.workers.local_worker().policy_map["average_policy_0"]
        local_avg_policy_1 = trainer.workers.local_worker().policy_map["average_policy_1"]
        exploitability = nfsp_measure_exploitability_nonlstm(
            rllib_policies=[local_avg_policy_0, local_avg_policy_1],
            poker_game_version=openspiel_game_version,
            open_spiel_env_config=open_spiel_env_config)
        result["avg_policy_exploitability"] = exploitability
        logger.info(colored(
            f"(Graph this in a notebook) Exploitability: {exploitability} - Saving exploitability stats "
            f"to {os.path.join(trainer.logdir, 'result.json')}", "green"))

    if checkpoint_every_n_iters and (training_iteration % checkpoint_every_n_iters == 0 or training_iteration == 1):
        for player in range(2):
            checkpoint_metadata = create_metadata_with_new_checkpoint(
                policy_id_to_save=f"average_policy_{player}",
                br_trainer=br_trainer,
                save_dir=checkpoint_dir(trainer=br_trainer),
                timesteps_training=result["timesteps_total"],
                episodes_training=result["episodes_total"],
                checkpoint_name=f"average_policy_player_{player}_iter_{training_iteration}.h5")
            avg_pol_checkpoint_spec = StrategySpec(
                strategy_id=f"avg_pol_player_{player}_iter_{training_iteration}",
                metadata=checkpoint_metadata)
            checkpoint_path = os.path.join(spec_checkpoint_dir(br_trainer),
                                           f"average_policy_player_{player}_iter_{training_iteration}.json")
            ensure_dir(checkpoint_path)
            with open(checkpoint_path, "w+") as checkpoint_spec_file:
                checkpoint_spec_file.write(avg_pol_checkpoint_spec.to_json())
def save_nfsp_average_policy_checkpoint(trainer: Trainer,
                                        policy_id_to_save: str,
                                        save_dir: str,
                                        timesteps_training: int,
                                        episodes_training: int,
                                        checkpoint_name=None):
    policy_name = policy_id_to_save
    date_time = datetime_str()
    if checkpoint_name is None:
        checkpoint_name = f"policy_{policy_name}_{date_time}.h5"
    checkpoint_path = os.path.join(save_dir, checkpoint_name)
    br_weights = trainer.get_weights([policy_id_to_save])[policy_id_to_save]
    # periods cause HDF5 NaturalNaming warnings
    br_weights = {k.replace(".", "_dot_"): v for k, v in br_weights.items()}
    ensure_dir(file_path=checkpoint_path)
    deepdish.io.save(path=checkpoint_path, data={
        "weights": br_weights,
        "date_time_str": date_time,
        "seconds_since_epoch": time.time(),
        "timesteps_training": timesteps_training,
        "episodes_training": episodes_training
    })
    return checkpoint_path
def launch_manager(scenario: NXDOScenario, nxdo_port: int, block: bool = True) -> NXDOManagerWithServer:
    if not isinstance(scenario, NXDOScenario):
        raise TypeError(
            f"Only instances of {NXDOScenario} can be used here. {scenario.name} is a {type(scenario)}.")

    ray_head_address = init_ray_for_scenario(scenario=scenario, head_address=None, logging_level=logging.INFO)

    solve_restricted_game: SolveRestrictedGame = scenario.get_restricted_game_solver(scenario)

    log_dir = os.path.join(os.path.dirname(grl.__file__), "data", scenario.name, f"manager_{datetime_str()}")

    name_file_path = os.path.join(log_dir, "scenario_name.txt")
    ensure_dir(name_file_path)
    with open(name_file_path, "w+") as name_file:
        name_file.write(scenario.name)

    manager = NXDOManagerWithServer(
        solve_restricted_game=solve_restricted_game,
        n_players=2,
        log_dir=log_dir,  # reuse the directory created above rather than generating a second timestamped path
        port=nxdo_port,
        manager_metadata={"ray_head_address": ray_head_address},
    )

    if block:
        manager.wait_for_server_termination()

    return manager
def on_active_policy_moved_to_fixed(self, player: int, policy_num: int, fixed_policy_spec: StrategySpec):
    logger.info(f"Player {player} policy {policy_num} moved to fixed.")

    # save a checkpoint of the payoff table
    data = self._manager.get_copy_of_latest_data()
    latest_payoff_table, active_policy_nums_per_player, fixed_policy_nums_per_player = data
    self._latest_numbered_payoff_table_checkpoint_path = os.path.join(
        self._payoff_table_checkpoint_dir,
        f"payoff_table_checkpoint_{self._payoff_table_checkpoint_count}.json")
    self._latest_numbered_policy_nums_path = os.path.join(
        self._payoff_table_checkpoint_dir,
        f"policy_nums_checkpoint_{self._payoff_table_checkpoint_count}.json")

    pt_checkpoint_paths = [os.path.join(self._payoff_table_checkpoint_dir, "payoff_table_checkpoint_latest.json"),
                           self._latest_numbered_payoff_table_checkpoint_path]
    policy_nums_paths = [os.path.join(self._payoff_table_checkpoint_dir, "policy_nums_checkpoint_latest.json"),
                         self._latest_numbered_policy_nums_path]

    for pt_checkpoint_path, policy_nums_path in zip(pt_checkpoint_paths, policy_nums_paths):
        ensure_dir(file_path=pt_checkpoint_path)
        ensure_dir(file_path=policy_nums_path)

        latest_payoff_table.to_json_file(file_path=pt_checkpoint_path)
        print(f"\n\n\nSaved payoff table checkpoint to {pt_checkpoint_path}")

        player_policy_nums = {}
        for player_i, (active_policy_nums, fixed_policy_nums) in enumerate(
                zip(active_policy_nums_per_player, fixed_policy_nums_per_player)):
            player_policy_nums[player_i] = {
                "active_policies": active_policy_nums,
                "fixed_policies": fixed_policy_nums
            }
        with open(policy_nums_path, "w+") as policy_nums_file:
            json.dump(obj=player_policy_nums, fp=policy_nums_file)
        print(f"Saved policy nums checkpoint to {policy_nums_path}\n\n\n")

    # append checkpoint metadata to checkpoints_manifest.json
    checkpoints_manifest_path = os.path.join(self._payoff_table_checkpoint_dir, "checkpoints_manifest.json")
    ensure_dir(file_path=checkpoints_manifest_path)
    with open(checkpoints_manifest_path, "a+") as manifest_file:
        if all(len(fixed_policy_nums) > 0 for fixed_policy_nums in fixed_policy_nums_per_player):
            highest_fixed_policies_for_all_players = min(
                max(fixed_policy_nums) for fixed_policy_nums in fixed_policy_nums_per_player)
        else:
            highest_fixed_policies_for_all_players = None
        manifest_json_line = json.dumps({
            "payoff_table_checkpoint_num": self._payoff_table_checkpoint_count,
            "highest_fixed_policies_for_all_players": highest_fixed_policies_for_all_players,
            "payoff_table_json_path": self._latest_numbered_payoff_table_checkpoint_path,
            "policy_nums_json_path": self._latest_numbered_policy_nums_path})
        manifest_file.write(f"{manifest_json_line}\n")

    self._payoff_table_checkpoint_count += 1
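# Sketch of a manifest consumer (hypothetical helper, not part of the original logger): the
# checkpoints_manifest.json written above is newline-delimited JSON, one entry per payoff table
# checkpoint, so the latest checkpoint covering fixed policies for all players can be found by
# scanning the file line by line.
def latest_complete_checkpoint_from_manifest(checkpoints_manifest_path: str):
    latest_entry = None
    with open(checkpoints_manifest_path, "r") as manifest_file:
        for line in manifest_file:
            entry = json.loads(line)
            if entry["highest_fixed_policies_for_all_players"] is not None:
                latest_entry = entry
    # Returns a dict containing payoff_table_json_path and policy_nums_json_path,
    # or None if no checkpoint yet has fixed policies for every player.
    return latest_entry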
def submit_final_br_policy(self, player, policy_num, metadata_dict):
    with self.modification_lock:
        if player < 0 or player >= self._n_players:
            raise ValueError(f"player {player} is out of range. Must be in [0, n_players).")
        if policy_num != self._current_double_oracle_iteration:
            raise ValueError(
                f"Policy {policy_num} isn't the same as the current double oracle iteration "
                f"{self._current_double_oracle_iteration}.")

        br_policy_spec: StrategySpec = StrategySpec(
            strategy_id=self._strat_id(player=player, policy_num=policy_num),
            metadata=metadata_dict,
            pure_strategy_indexes={player: policy_num})

        self._br_episodes_this_iter += metadata_dict["episodes_training_br"]
        self._br_timesteps_this_iter += metadata_dict["timesteps_training_br"]

        self._next_iter_br_spec_lists_for_each_player[player].append(br_policy_spec)
        self._player_brs_are_finished_this_iter[player] = True

        all_players_finished_brs_this_iter = all(self._player_brs_are_finished_this_iter.values())
        if all_players_finished_brs_this_iter:
            print("Solving restricted game")
            restricted_game_solve_result = self._solve_restricted_game(
                log_dir=self.log_dir,
                br_spec_lists_for_each_player=self._next_iter_br_spec_lists_for_each_player,
                manager_metadata=self.get_manager_metadata())
            self._latest_metanash_spec_for_each_player = restricted_game_solve_result.latest_metanash_spec_for_each_player

            self._restricted_game_episodes_this_iter += restricted_game_solve_result.episodes_spent_in_solve
            self._restricted_game_timesteps_this_iter += restricted_game_solve_result.timesteps_spent_in_solve

            self._episodes_total += (self._br_episodes_this_iter + self._restricted_game_episodes_this_iter)
            self._timesteps_total += (self._br_timesteps_this_iter + self._restricted_game_timesteps_this_iter)

            br_specs_added_this_iter = {
                player: player_br_spec_list[-1]
                for player, player_br_spec_list in self._next_iter_br_spec_lists_for_each_player.items()
            }

            data_to_log = {
                "episodes_total": self._episodes_total,
                "timesteps_total": self._timesteps_total,
                "br_episodes_this_iter": self._br_episodes_this_iter,
                "br_timesteps_this_iter": self._br_timesteps_this_iter,
                "restricted_game_episodes_this_iter": self._restricted_game_episodes_this_iter,
                "restricted_game_timesteps_this_iter": self._restricted_game_timesteps_this_iter,
                "br_specs_added_this_iter": {
                    player: spec.to_json() for player, spec in br_specs_added_this_iter.items()
                },
                "metanash_specs": [spec.to_json() for spec in self._latest_metanash_spec_for_each_player],
            }

            if all("average_br_reward" in br_spec.metadata for br_spec in br_specs_added_this_iter.values()):
                data_to_log["player_br_rewards_vs_previous_metanash"] = {
                    player: br_spec.metadata["average_br_reward"]
                    for player, br_spec in br_specs_added_this_iter.items()
                }

            assert "episodes_total" not in restricted_game_solve_result.extra_data_to_log
            assert "timesteps_total" not in restricted_game_solve_result.extra_data_to_log
            data_to_log.update(restricted_game_solve_result.extra_data_to_log)

            with open(self._json_log_path, "a+") as json_file:
                json_file.writelines([json.dumps(data_to_log) + '\n'])
            print(colored(
                f"(Graph this in a notebook) Saved manager stats (including exploitability if applicable) "
                f"to {self._json_log_path}", "green"))

            for checkpoint_player, player_metanash_spec in enumerate(
                    restricted_game_solve_result.latest_metanash_spec_for_each_player):
                checkpoint_path = os.path.join(
                    self.log_dir, "xfdo_metanash_specs",
                    f"{checkpoint_player}_metanash_{self._current_double_oracle_iteration}.json")
                ensure_dir(checkpoint_path)
                with open(checkpoint_path, "w+") as checkpoint_spec_file:
                    checkpoint_spec_file.write(player_metanash_spec.to_json())

            # Start the next double oracle iteration here.
            # A double oracle iteration is considered to be training BRs
            # followed by solving the new restricted game.
            self._current_double_oracle_iteration += 1
            self._br_episodes_this_iter = 0
            self._br_timesteps_this_iter = 0
            self._restricted_game_episodes_this_iter = 0
            self._restricted_game_timesteps_this_iter = 0
            self._player_brs_are_finished_this_iter = {p: False for p in range(self._n_players)}
            self._br_spec_lists_for_each_player = deepcopy(self._next_iter_br_spec_lists_for_each_player)
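# Illustrative-only sketch of the result object that submit_final_br_policy expects
# self._solve_restricted_game to return, inferred from the attribute accesses above. The class name
# and layout here are hypothetical; the real definition lives elsewhere in the codebase.
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class RestrictedGameSolveResultSketch:
    latest_metanash_spec_for_each_player: List[StrategySpec]  # one metanash StrategySpec per player
    episodes_spent_in_solve: int  # added to the manager's episode totals
    timesteps_spent_in_solve: int  # added to the manager's timestep totals
    extra_data_to_log: Dict[str, Any]  # merged into the per-iteration JSON log entry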
def launch_manager(scenario: PSROScenario, psro_port: int, eval_port: int, block: bool = True,
                   include_evals: bool = True) -> P2SROManagerWithServer:
    if not isinstance(scenario, PSROScenario):
        raise TypeError(
            f"Only instances of {PSROScenario} can be used here. {scenario.name} is a {type(scenario)}.")

    ray_head_address = init_ray_for_scenario(scenario=scenario, head_address=None, logging_level=logging.INFO)

    log_dir = os.path.join(os.path.dirname(grl.__file__), "data", scenario.name, f"manager_{datetime_str()}")

    name_file_path = os.path.join(log_dir, "scenario_name.txt")
    ensure_dir(name_file_path)
    with open(name_file_path, "w+") as name_file:
        name_file.write(scenario.name)

    if scenario.calc_exploitability_for_openspiel_env:
        def get_manager_logger(_manager: P2SROManager):
            return ExploitabilityP2SROManagerLogger(p2sro_manger=_manager, log_dir=_manager.log_dir,
                                                    scenario=scenario)
    else:
        def get_manager_logger(_manager: P2SROManager):
            return ApproxExploitabilityP2SROManagerLogger(p2sro_manger=_manager, log_dir=_manager.log_dir,
                                                          scenario=scenario)

    manager = P2SROManagerWithServer(
        port=psro_port,
        eval_dispatcher_port=eval_port,
        n_players=2,
        is_two_player_symmetric_zero_sum=scenario.single_agent_symmetric_game,
        do_external_payoff_evals_for_new_fixed_policies=True,
        games_per_external_payoff_eval=scenario.games_per_payoff_eval,
        payoff_table_exponential_average_coeff=scenario.p2sro_payoff_table_exponential_avg_coeff,
        log_dir=log_dir,
        manager_metadata={"ray_head_address": ray_head_address},
        get_manager_logger=get_manager_logger)

    print("Launched P2SRO Manager with server.")

    if include_evals:
        launch_evals(scenario_name=scenario.name,
                     block=False,
                     ray_head_address=ray_head_address,
                     eval_dispatcher_port=eval_port,
                     eval_dispatcher_host='localhost')
        print("Launched evals")

    if block:
        manager.wait_for_server_termination()

    return manager