Example 1
def save_best_response_checkpoint(trainer,
                                  player: int,
                                  save_dir: str,
                                  timesteps_training_br: int,
                                  episodes_training_br: int,
                                  active_policy_num: int = None,
                                  average_br_reward: float = None):
    policy_name = active_policy_num if active_policy_num is not None else "unclaimed"
    date_time = datetime_str()
    checkpoint_name = f"policy_{policy_name}_{date_time}.h5"
    checkpoint_path = os.path.join(save_dir, checkpoint_name)
    br_weights = trainer.get_weights([f"best_response"])["best_response"]
    br_weights = {k.replace(".", "_dot_"): v for k, v in
                  br_weights.items()}  # periods cause HDF5 NaturalNaming warnings
    ensure_dir(file_path=checkpoint_path)
    num_save_attempts = 5
    for attempt in range(num_save_attempts):
        try:
            deepdish.io.save(path=checkpoint_path, data={
                "weights": br_weights,
                "player": player,
                "policy_num": active_policy_num,
                "date_time_str": date_time,
                "seconds_since_epoch": time.time(),
                "timesteps_training_br": timesteps_training_br,
                "episodes_training_br": episodes_training_br,
                "average_br_reward": average_br_reward,
            })
            break
        except HDF5ExtError:
            if attempt + 1 == num_save_attempts:
                raise
            time.sleep(1.0)
    return checkpoint_path
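
For reference, a minimal load-side sketch (not part of the original source) that reverses the "_dot_" key substitution and pushes the weights back into an RLlib trainer; the "best_response" policy id is taken from the save function above:

import deepdish

def load_best_response_checkpoint(trainer, checkpoint_path: str):
    # Hypothetical counterpart to save_best_response_checkpoint above.
    checkpoint_data = deepdish.io.load(path=checkpoint_path)
    # Undo the "." -> "_dot_" renaming applied to avoid HDF5 NaturalNaming warnings.
    br_weights = {k.replace("_dot_", "."): v
                  for k, v in checkpoint_data["weights"].items()}
    trainer.set_weights({"best_response": br_weights})
    return checkpoint_data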
Example 2
        def on_train_result(self, *, trainer, result: dict, **kwargs):
            result["scenario_name"] = trainer.scenario_name
            training_iteration = result["training_iteration"]
            super().on_train_result(trainer=trainer, result=result, **kwargs)

            if training_iteration % checkpoint_every_n_iters == 0 or training_iteration == 1:
                for player in range(2):
                    checkpoint_metadata = create_metadata_with_new_checkpoint(
                        policy_id_to_save=f"best_response_{player}",
                        br_trainer=trainer,
                        policy_player=player,
                        save_dir=checkpoint_dir(trainer=trainer),
                        timesteps_training=result["timesteps_total"],
                        episodes_training=result["episodes_total"],
                        checkpoint_name=f"best_response_player_{player}_iter_{training_iteration}.h5",
                    )
                    joint_pol_checkpoint_spec = StrategySpec(
                        strategy_id=f"best_response_player_{player}_iter_{training_iteration}",
                        metadata=checkpoint_metadata)
                    checkpoint_path = os.path.join(
                        spec_checkpoint_dir(trainer),
                        f"best_response_player_{player}_iter_{training_iteration}.json"
                    )
                    ensure_dir(checkpoint_path)
                    with open(checkpoint_path, "+w") as checkpoint_spec_file:
                        checkpoint_spec_file.write(
                            joint_pol_checkpoint_spec.to_json())
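
The spec file written above is plain JSON, so it can be inspected without grl; a minimal sketch (grl presumably also ships a StrategySpec deserializer, but it does not appear in this excerpt):

import json

with open(checkpoint_path, "r") as checkpoint_spec_file:
    spec_dict = json.load(checkpoint_spec_file)
# The exact field layout is an assumption; the metadata produced by
# create_metadata_with_new_checkpoint should be nested somewhere inside.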
Example 3
def save_policy_checkpoint(trainer: Trainer, player: int, save_dir: str,
                           policy_id_to_save: PolicyID, checkpoint_name: str,
                           additional_data: Dict[str, Any]):
    date_time = datetime_str()
    checkpoint_name = f"policy_{checkpoint_name}_{date_time}.h5"
    checkpoint_path = os.path.join(save_dir, checkpoint_name)
    br_weights = trainer.get_weights([policy_id_to_save])[policy_id_to_save]
    br_weights = {k.replace(".", "_dot_"): v
                  for k, v in br_weights.items()
                  }  # periods cause HDF5 NaturalNaming warnings
    ensure_dir(file_path=checkpoint_path)
    num_save_attempts = 5

    checkpoint_data = {
        "weights": br_weights,
        "player": player,
        "date_time_str": date_time,
        "seconds_since_epoch": time.time(),
    }
    checkpoint_data.update(additional_data)

    for attempt in range(num_save_attempts):
        try:
            deepdish.io.save(path=checkpoint_path, data=checkpoint_data)
            break
        except HDF5ExtError:
            if attempt + 1 == num_save_attempts:
                raise
            time.sleep(1.0)
    return checkpoint_path
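
Examples 1 and 3 duplicate the same HDF5 retry loop; a stdlib-only helper could factor it out. A sketch, not part of the original code:

import time

def retry_on(exc_type, fn, attempts: int = 5, delay_s: float = 1.0):
    # Call fn() up to `attempts` times, sleeping between failures;
    # re-raise the last exception if every attempt fails.
    for attempt in range(attempts):
        try:
            return fn()
        except exc_type:
            if attempt + 1 == attempts:
                raise
            time.sleep(delay_s)

# The save in save_policy_checkpoint would then reduce to:
# retry_on(HDF5ExtError, lambda: deepdish.io.save(path=checkpoint_path, data=checkpoint_data))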
Example 4
    def __init__(self, p2sro_manger, log_dir: str, scenario: PSROScenario):
        super().__init__(p2sro_manger=p2sro_manger, log_dir=log_dir)  # "p2sro_manger" (sic) matches the upstream argument name

        self._scenario = scenario
        if not issubclass(scenario.env_class,
                          (PokerMultiAgentEnv, OshiZumoMultiAgentEnv)):
            raise ValueError(
                "ExploitabilityP2SROManagerLogger is only meant to be used with "
                f"PokerMultiAgentEnv or OshiZumoMultiAgentEnv, not {scenario.env_class}.")
        if not scenario.calc_exploitability_for_openspiel_env:
            raise ValueError(
                "Only use ExploitabilityP2SROManagerLogger if "
                "scenario.calc_exploitability_for_openspiel_env is True.")

        self._exploitability_per_generation = []
        self._total_steps_per_generation = []
        self._total_episodes_per_generation = []
        self._num_policies_per_generation = []
        self._payoff_table_checkpoint_nums = []
        self._payoff_table_checkpoint_paths = []
        self._policy_nums_checkpoint_paths = []

        self._exploitability_stats_save_path = os.path.join(
            log_dir, "exploitability_stats.json")
        ensure_dir(self._exploitability_stats_save_path)
Example 5
    def __init__(self, p2sro_manger, log_dir: str, scenario: PSROScenario):
        super().__init__(p2sro_manger=p2sro_manger, log_dir=log_dir)

        self._scenario = scenario

        self._exploitability_per_generation = []
        self._total_steps_per_generation = []
        self._total_episodes_per_generation = []
        self._num_policies_per_generation = []
        self._payoff_table_checkpoint_nums = []
        self._payoff_table_checkpoint_paths = []
        self._policy_nums_checkpoint_paths = []

        self._exploitability_stats_save_path = os.path.join(log_dir, "approx_exploitability_stats.json")
        ensure_dir(self._exploitability_stats_save_path)
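
Neither __init__ shows how these per-generation lists reach the stats file; a loudly hypothetical sketch of what such a save step could look like, using only the attributes defined above (the method name and JSON field names are assumptions):

import json

def _save_exploitability_stats(self):
    stats = {
        "exploitability_per_generation": self._exploitability_per_generation,
        "total_steps_per_generation": self._total_steps_per_generation,
        "total_episodes_per_generation": self._total_episodes_per_generation,
        "num_policies_per_generation": self._num_policies_per_generation,
    }
    with open(self._exploitability_stats_save_path, "w") as stats_file:
        json.dump(obj=stats, fp=stats_file)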
Example 6
        def on_train_result(self, *, trainer, result: dict, **kwargs):
            super().on_train_result(trainer=trainer, result=result, **kwargs)
            result["scenario_name"] = trainer.scenario_name
            result["avg_br_reward_both_players"] = ray.get(trainer.avg_br_reward_deque.get_mean.remote())

            training_iteration = result["training_iteration"]
            if (calculate_openspiel_metanash and
                    (training_iteration == 1 or training_iteration % calc_metanash_every_n_iters == 0)):
                base_env = _create_env()
                open_spiel_env_config = base_env.open_spiel_env_config
                openspiel_game_version = base_env.game_version
                local_avg_policy_0 = trainer.workers.local_worker().policy_map["average_policy_0"]
                local_avg_policy_1 = trainer.workers.local_worker().policy_map["average_policy_1"]
                exploitability = nfsp_measure_exploitability_nonlstm(
                    rllib_policies=[local_avg_policy_0, local_avg_policy_1],
                    poker_game_version=openspiel_game_version,
                    open_spiel_env_config=open_spiel_env_config
                )
                result["avg_policy_exploitability"] = exploitability
                logger.info(colored(
                    f"(Graph this in a notebook) Exploitability: {exploitability} - Saving exploitability stats "
                    f"to {os.path.join(trainer.logdir, 'result.json')}", "green"))

            if checkpoint_every_n_iters and (training_iteration % checkpoint_every_n_iters == 0 or training_iteration == 1):
                for player in range(2):
                    checkpoint_metadata = create_metadata_with_new_checkpoint(
                        policy_id_to_save=f"average_policy_{player}",
                        br_trainer=br_trainer,
                        save_dir=checkpoint_dir(trainer=br_trainer),
                        timesteps_training=result["timesteps_total"],
                        episodes_training=result["episodes_total"],
                        checkpoint_name=f"average_policy_player_{player}_iter_{training_iteration}.h5"
                    )
                    avg_pol_checkpoint_spec = StrategySpec(
                        strategy_id=f"avg_pol_player_{player}_iter_{training_iteration}",
                        metadata=checkpoint_metadata)
                    checkpoint_path = os.path.join(spec_checkpoint_dir(br_trainer),
                                                   f"average_policy_player_{player}_iter_{training_iteration}.json")
                    ensure_dir(checkpoint_path)
                    with open(checkpoint_path, "+w") as checkpoint_spec_file:
                        checkpoint_spec_file.write(avg_pol_checkpoint_spec.to_json())
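
The exploitability value lands in RLlib's result.json, which is newline-delimited JSON with one training iteration per line; a minimal sketch for extracting the curve in a notebook:

import json
import os

def read_exploitability_curve(trainer_logdir: str):
    # Collect (training_iteration, exploitability) pairs from RLlib's result.json.
    curve = []
    with open(os.path.join(trainer_logdir, "result.json"), "r") as results_file:
        for line in results_file:
            row = json.loads(line)
            if "avg_policy_exploitability" in row:
                curve.append((row["training_iteration"], row["avg_policy_exploitability"]))
    return curve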
Example 7
def save_nfsp_average_policy_checkpoint(trainer: Trainer,
                                        policy_id_to_save: str,
                                        save_dir: str,
                                        timesteps_training: int,
                                        episodes_training: int,
                                        checkpoint_name=None):
    policy_name = policy_id_to_save
    date_time = datetime_str()
    if checkpoint_name is None:
        checkpoint_name = f"policy_{policy_name}_{date_time}.h5"
    checkpoint_path = os.path.join(save_dir, checkpoint_name)
    br_weights = trainer.get_weights([policy_id_to_save])[policy_id_to_save]
    br_weights = {k.replace(".", "_dot_"): v for k, v in
                  br_weights.items()}  # periods cause HDF5 NaturalNaming warnings
    ensure_dir(file_path=checkpoint_path)
    deepdish.io.save(path=checkpoint_path, data={
        "weights": br_weights,
        "date_time_str": date_time,
        "seconds_since_epoch": time.time(),
        "timesteps_training": timesteps_training,
        "episodes_training": episodes_training
    })
    return checkpoint_path
Example 8
def launch_manager(scenario: NXDOScenario,
                   nxdo_port: int,
                   block: bool = True) -> NXDOManagerWithServer:
    if not isinstance(scenario, NXDOScenario):
        raise TypeError(
            f"Only instances of {NXDOScenario} can be used here. {scenario.name} is a {type(scenario)}."
        )

    ray_head_address = init_ray_for_scenario(scenario=scenario,
                                             head_address=None,
                                             logging_level=logging.INFO)

    solve_restricted_game: SolveRestrictedGame = scenario.get_restricted_game_solver(
        scenario)

    log_dir = os.path.join(os.path.dirname(grl.__file__), "data",
                           scenario.name, f"manager_{datetime_str()}")

    name_file_path = os.path.join(log_dir, "scenario_name.txt")
    ensure_dir(name_file_path)
    with open(name_file_path, "w+") as name_file:
        name_file.write(scenario.name)

    manager = NXDOManagerWithServer(
        solve_restricted_game=solve_restricted_game,
        n_players=2,
        # Reuse the log_dir created above; recomputing it here would embed a
        # second datetime_str() timestamp and split the logs across two directories.
        log_dir=log_dir,
        port=nxdo_port,
        manager_metadata={"ray_head_address": ray_head_address},
    )

    if block:
        manager.wait_for_server_termination()

    return manager
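
Because wait_for_server_termination is also available on the returned manager, the block flag only controls whether launch_manager waits internally. A non-blocking launch would look like this (my_nxdo_scenario is a placeholder for an already-constructed NXDOScenario; the port is arbitrary):

manager = launch_manager(scenario=my_nxdo_scenario, nxdo_port=4545, block=False)
# ... perform other setup work here ...
manager.wait_for_server_termination()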
Example 9
    def on_active_policy_moved_to_fixed(self, player: int, policy_num: int, fixed_policy_spec: StrategySpec):
        logger.info(f"Player {player} policy {policy_num} moved to fixed.")

        # save a checkpoint of the payoff table
        data = self._manager.get_copy_of_latest_data()
        latest_payoff_table, active_policy_nums_per_player, fixed_policy_nums_per_player = data

        self._latest_numbered_payoff_table_checkpoint_path = os.path.join(
            self._payoff_table_checkpoint_dir,
            f"payoff_table_checkpoint_{self._payoff_table_checkpoint_count}.json")
        self._latest_numbered_policy_nums_path = os.path.join(
            self._payoff_table_checkpoint_dir,
            f"policy_nums_checkpoint_{self._payoff_table_checkpoint_count}.json")

        pt_checkpoint_paths = [os.path.join(self._payoff_table_checkpoint_dir, "payoff_table_checkpoint_latest.json"),
                               self._latest_numbered_payoff_table_checkpoint_path]
        policy_nums_paths = [os.path.join(self._payoff_table_checkpoint_dir, "policy_nums_checkpoint_latest.json"),
                             self._latest_numbered_policy_nums_path]

        for pt_checkpoint_path, policy_nums_path in zip(pt_checkpoint_paths, policy_nums_paths):
            ensure_dir(file_path=pt_checkpoint_path)
            ensure_dir(file_path=policy_nums_path)

            latest_payoff_table.to_json_file(file_path=pt_checkpoint_path)
            print(f"\n\n\nSaved payoff table checkpoint to {pt_checkpoint_path}")

            player_policy_nums = {}
            for player_i, (active_policy_nums, fixed_policy_nums) in enumerate(
                    zip(active_policy_nums_per_player, fixed_policy_nums_per_player)):
                player_policy_nums[player_i] = {
                    "active_policies": active_policy_nums,
                    "fixed_policies": fixed_policy_nums
                }

            with open(policy_nums_path, "w+") as policy_nums_file:
                json.dump(obj=player_policy_nums, fp=policy_nums_file)
            print(f"Saved policy nums checkpoint to {policy_nums_path}\n\n\n")

        # append checkpoint metadata to checkpoints_manifest.json (one JSON object per line)
        checkpoints_manifest_path = os.path.join(self._payoff_table_checkpoint_dir, "checkpoints_manifest.json")
        ensure_dir(file_path=checkpoints_manifest_path)
        with open(checkpoints_manifest_path, "a+") as manifest_file:
            if all(len(fixed_policy_nums) > 0 for fixed_policy_nums in fixed_policy_nums_per_player):
                highest_fixed_policies_for_all_players = min(
                    max(fixed_policy_nums) for fixed_policy_nums in fixed_policy_nums_per_player)
            else:
                highest_fixed_policies_for_all_players = None
            manifest_json_line = json.dumps({"payoff_table_checkpoint_num": self._payoff_table_checkpoint_count,
                                             "highest_fixed_policies_for_all_players": highest_fixed_policies_for_all_players,
                                             "payoff_table_json_path": self._latest_numbered_payoff_table_checkpoint_path,
                                             "policy_nums_json_path": self._latest_numbered_policy_nums_path})
            manifest_file.write(f"{manifest_json_line}\n")

        self._payoff_table_checkpoint_count += 1
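
checkpoints_manifest.json is append-only JSON lines, one object per checkpoint, so it can be replayed to locate any checkpoint; a sketch:

import json

def read_checkpoints_manifest(manifest_path: str):
    # Each line was written above with json.dumps(...) + "\n".
    entries = []
    with open(manifest_path, "r") as manifest_file:
        for line in manifest_file:
            if line.strip():
                entries.append(json.loads(line))
    return entries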
Example 10
    def submit_final_br_policy(self, player, policy_num, metadata_dict):
        with self.modification_lock:
            if player < 0 or player >= self._n_players:
                raise ValueError(
                    f"player {player} is out of range. Must be in [0, n_players)."
                )
            if policy_num != self._current_double_oracle_iteration:
                raise ValueError(
                    f"Policy {policy_num} isn't the same as the current double oracle iteration "
                    f"{self._current_double_oracle_iteration}.")

            br_policy_spec: StrategySpec = StrategySpec(
                strategy_id=self._strat_id(player=player,
                                           policy_num=policy_num),
                metadata=metadata_dict,
                pure_strategy_indexes={player: policy_num})

            self._br_episodes_this_iter += metadata_dict[
                "episodes_training_br"]
            self._br_timesteps_this_iter += metadata_dict[
                "timesteps_training_br"]

            self._next_iter_br_spec_lists_for_each_player[player].append(
                br_policy_spec)
            self._player_brs_are_finished_this_iter[player] = True

            all_players_finished_brs_this_iter = all(
                self._player_brs_are_finished_this_iter.values())
            if all_players_finished_brs_this_iter:
                print("Solving restricted game")
                restricted_game_solve_result = self._solve_restricted_game(
                    log_dir=self.log_dir,
                    br_spec_lists_for_each_player=self._next_iter_br_spec_lists_for_each_player,
                    manager_metadata=self.get_manager_metadata())
                self._latest_metanash_spec_for_each_player = restricted_game_solve_result.latest_metanash_spec_for_each_player

                self._restricted_game_episodes_this_iter += restricted_game_solve_result.episodes_spent_in_solve
                self._restricted_game_timesteps_this_iter += restricted_game_solve_result.timesteps_spent_in_solve

                self._episodes_total += (
                    self._br_episodes_this_iter +
                    self._restricted_game_episodes_this_iter)
                self._timesteps_total += (
                    self._br_timesteps_this_iter +
                    self._restricted_game_timesteps_this_iter)

                br_specs_added_this_iter = {
                    player: player_br_spec_list[-1]
                    for player, player_br_spec_list in
                    self._next_iter_br_spec_lists_for_each_player.items()
                }

                data_to_log = {
                    "episodes_total": self._episodes_total,
                    "timesteps_total": self._timesteps_total,
                    "br_episodes_this_iter": self._br_episodes_this_iter,
                    "br_timesteps_this_iter": self._br_timesteps_this_iter,
                    "restricted_game_episodes_this_iter": self._restricted_game_episodes_this_iter,
                    "restricted_game_timesteps_this_iter": self._restricted_game_timesteps_this_iter,
                    "br_specs_added_this_iter": {
                        player: spec.to_json()
                        for player, spec in br_specs_added_this_iter.items()
                    },
                    "metanash_specs": [
                        spec.to_json()
                        for spec in self._latest_metanash_spec_for_each_player
                    ],
                }
                if all("average_br_reward" in br_spec.metadata
                       for br_spec in br_specs_added_this_iter.values()):
                    data_to_log["player_br_rewards_vs_previous_metanash"] = {
                        player: br_spec.metadata["average_br_reward"]
                        for player, br_spec in
                        br_specs_added_this_iter.items()
                    }

                assert "episodes_total" not in restricted_game_solve_result.extra_data_to_log
                assert "timesteps_total" not in restricted_game_solve_result.extra_data_to_log
                data_to_log.update(
                    restricted_game_solve_result.extra_data_to_log)

                with open(self._json_log_path, "+a") as json_file:
                    json_file.writelines([json.dumps(data_to_log) + '\n'])
                print(
                    colored(
                        f"(Graph this in a notebook) Saved manager stats (including exploitability if applicable) "
                        f"to {self._json_log_path}", "green"))

                for checkpoint_player, player_metanash_spec in enumerate(
                        restricted_game_solve_result.latest_metanash_spec_for_each_player):
                    checkpoint_path = os.path.join(
                        self.log_dir, "xfdo_metanash_specs",
                        f"{checkpoint_player}_metanash_{self._current_double_oracle_iteration}.json"
                    )
                    ensure_dir(checkpoint_path)
                    with open(checkpoint_path, "+w") as checkpoint_spec_file:
                        checkpoint_spec_file.write(
                            player_metanash_spec.to_json())

                # Start the next double oracle iteration here.
                # A double oracle iteration is considered to be training BRs
                # followed by solving the new restricted game.
                self._current_double_oracle_iteration += 1
                self._br_episodes_this_iter = 0
                self._br_timesteps_this_iter = 0
                self._restricted_game_episodes_this_iter = 0
                self._restricted_game_timesteps_this_iter = 0
                self._player_brs_are_finished_this_iter = {
                    p: False
                    for p in range(self._n_players)
                }
                self._br_spec_lists_for_each_player = deepcopy(
                    self._next_iter_br_spec_lists_for_each_player)
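
The manager stats log written above is also newline-delimited JSON, and the code itself suggests graphing it in a notebook; with pandas (an extra dependency not used in this excerpt) that is a one-liner. The path is a placeholder for whatever self._json_log_path was:

import pandas as pd

stats = pd.read_json("path/to/manager_log.json", lines=True)
# e.g. stats.plot(x="timesteps_total", y="episodes_total")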
Example 11
def launch_manager(scenario: PSROScenario,
                   psro_port: int,
                   eval_port: int,
                   block: bool = True,
                   include_evals: bool = True) -> P2SROManagerWithServer:
    if not isinstance(scenario, PSROScenario):
        raise TypeError(
            f"Only instances of {PSROScenario} can be used here. {scenario.name} is a {type(scenario)}."
        )

    ray_head_address = init_ray_for_scenario(scenario=scenario,
                                             head_address=None,
                                             logging_level=logging.INFO)

    log_dir = os.path.join(os.path.dirname(grl.__file__), "data",
                           scenario.name, f"manager_{datetime_str()}")

    name_file_path = os.path.join(log_dir, "scenario.name.txt")
    ensure_dir(name_file_path)
    with open(name_file_path, "w+") as name_file:
        name_file.write(scenario.name)

    if scenario.calc_exploitability_for_openspiel_env:
        def get_manager_logger(_manager: P2SROManager):
            return ExploitabilityP2SROManagerLogger(p2sro_manger=_manager,
                                                    log_dir=_manager.log_dir,
                                                    scenario=scenario)
    else:
        def get_manager_logger(_manager: P2SROManager):
            return ApproxExploitabilityP2SROManagerLogger(p2sro_manger=_manager,
                                                          log_dir=_manager.log_dir,
                                                          scenario=scenario)

    manager = P2SROManagerWithServer(
        port=psro_port,
        eval_dispatcher_port=eval_port,
        n_players=2,
        is_two_player_symmetric_zero_sum=scenario.single_agent_symmetric_game,
        do_external_payoff_evals_for_new_fixed_policies=True,
        games_per_external_payoff_eval=scenario.games_per_payoff_eval,
        payoff_table_exponential_average_coeff=scenario.p2sro_payoff_table_exponential_avg_coeff,
        log_dir=log_dir,
        manager_metadata={"ray_head_address": ray_head_address},
        get_manager_logger=get_manager_logger)
    print(f"Launched P2SRO Manager with server.")

    if include_evals:
        launch_evals(scenario_name=scenario.name,
                     block=False,
                     ray_head_address=ray_head_address,
                     eval_dispatcher_port=eval_port,
                     eval_dispatcher_host='localhost')
        print(f"Launched evals")

    if block:
        manager.wait_for_server_termination()

    return manager
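
A minimal blocking invocation, with placeholder ports and an already-constructed scenario (my_psro_scenario is hypothetical):

manager = launch_manager(scenario=my_psro_scenario,
                         psro_port=4535,
                         eval_port=4536,
                         block=True,
                         include_evals=True)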