Example 1
def launch_evals(scenario_name: str,
                 eval_dispatcher_port: int,
                 eval_dispatcher_host: str,
                 block=True,
                 ray_head_address=None):

    scenario: PSROScenario = scenario_catalog.get(scenario_name=scenario_name)

    init_ray_for_scenario(scenario=scenario, head_address=ray_head_address, logging_level=logging.INFO)

    num_workers = scenario.num_eval_workers
    evaluator_refs = [run_poker_evaluation_loop.remote(scenario_name, eval_dispatcher_port, eval_dispatcher_host)
                      for _ in range(num_workers)]
    if block:
        ray.wait(evaluator_refs, num_returns=num_workers)
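
# A hedged usage sketch (not part of the example above): a command-line entrypoint for
# launch_evals, modeled on the argparse launchers that appear later in this collection.
# The flag names and defaults here are illustrative assumptions, not the project's actual CLI.
import argparse
import logging

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument('--scenario', type=str, required=True)
    parser.add_argument('--eval_port', type=int, required=True)
    parser.add_argument('--eval_host', type=str, required=False, default='localhost')
    parser.add_argument('--ray_head_address', type=str, required=False, default=None)
    args = parser.parse_args()

    # block=True (the default) makes launch_evals wait on all evaluator refs before returning.
    launch_evals(scenario_name=args.scenario,
                 eval_dispatcher_port=args.eval_port,
                 eval_dispatcher_host=args.eval_host,
                 ray_head_address=args.ray_head_address)
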
Example 2
def train_self_play(results_dir, scenario_name, print_train_results=True):

    scenario: PSROScenario = scenario_catalog.get(scenario_name=scenario_name)

    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    single_agent_symmetric_game = scenario.single_agent_symmetric_game
    if single_agent_symmetric_game:
        raise NotImplementedError

    get_trainer_config = scenario.get_trainer_config
    should_log_result_fn = scenario.ray_should_log_result_filter

    checkpoint_every_n_iters = 500

    class PreAndPostEpisodeCallbacks(DefaultCallbacks):
        def on_train_result(self, *, trainer, result: dict, **kwargs):
            result["scenario_name"] = trainer.scenario_name
            training_iteration = result["training_iteration"]
            super().on_train_result(trainer=trainer, result=result, **kwargs)

            if training_iteration % checkpoint_every_n_iters == 0 or training_iteration == 1:
                for player in range(2):
                    checkpoint_metadata = create_metadata_with_new_checkpoint(
                        policy_id_to_save=f"best_response_{player}",
                        br_trainer=trainer,
                        policy_player=player,
                        save_dir=checkpoint_dir(trainer=trainer),
                        timesteps_training=result["timesteps_total"],
                        episodes_training=result["episodes_total"],
                        checkpoint_name=
                        f"best_response_player_{player}_iter_{training_iteration}.h5"
                    )
                    joint_pol_checkpoint_spec = StrategySpec(
                        strategy_id=
                        f"best_response_player_{player}_iter_{training_iteration}",
                        metadata=checkpoint_metadata)
                    checkpoint_path = os.path.join(
                        spec_checkpoint_dir(trainer),
                        f"best_response_player_{player}_iter_{training_iteration}.json"
                    )
                    ensure_dir(checkpoint_path)
                    with open(checkpoint_path, "+w") as checkpoint_spec_file:
                        checkpoint_spec_file.write(
                            joint_pol_checkpoint_spec.to_json())

    def select_policy(agent_id):
        if agent_id == 0:
            return "best_response_0"
        elif agent_id == 1:
            return "best_response_1"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

    init_ray_for_scenario(scenario=scenario,
                          head_address=None,
                          logging_level=logging.INFO)

    tmp_env = env_class(env_config=env_config)

    trainer_config = {
        "callbacks": PreAndPostEpisodeCallbacks,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": [f"best_response_0", "best_response_1"],
            "policies": {
                f"best_response_0":
                (policy_classes["best_response"], tmp_env.observation_space,
                 tmp_env.action_space, {}),
                f"best_response_1":
                (policy_classes["best_response"], tmp_env.observation_space,
                 tmp_env.action_space, {}),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_env))

    trainer = trainer_class(config=trainer_config,
                            logger_creator=get_trainer_logger_creator(
                                base_dir=results_dir,
                                scenario_name=scenario_name,
                                should_log_result_fn=should_log_result_fn))

    # scenario_name logged in on_train_result_callback
    trainer.scenario_name = scenario_name

    # Perform main RL training loop.
    while True:
        train_iter_results = trainer.train(
        )  # do a step (or several) in the main RL loop

        if print_train_results:
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            for key in ["best_response_0", "best_response_1"]:
                if "td_error" in train_iter_results["info"]["learner"][key]:
                    del train_iter_results["info"]["learner"][key]["td_error"]
            print(pretty_dict_str(train_iter_results))
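
# A minimal sketch (an assumption, not shown in the original) of how one of the StrategySpec
# JSON files written in on_train_result above could be read back and its weights restored into
# a live policy. It assumes the spec's metadata carries a "checkpoint_path" entry, as the
# PSRO/NXDO examples later in this collection suggest via final_policy_metadata["checkpoint_path"],
# and that load_pure_strat accepts a checkpoint_path keyword as in the NXDO best-response example.
import json

def restore_best_response_from_spec(trainer, spec_json_path: str, policy_id: str = "best_response_0"):
    with open(spec_json_path, "r") as f:
        spec_dict = json.load(f)
    checkpoint_path = spec_dict["metadata"]["checkpoint_path"]  # assumed metadata layout
    policy = trainer.workers.local_worker().policy_map[policy_id]
    load_pure_strat(policy=policy, checkpoint_path=checkpoint_path)
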
Example 3
def run_poker_evaluation_loop(scenario_name: str, eval_dispatcher_port: int, eval_dispatcher_host: str):
    scenario: PSROScenario = scenario_catalog.get(scenario_name=scenario_name)
    if not isinstance(scenario, PSROScenario):
        raise TypeError(f"Only instances of {PSROScenario} can be used here. {scenario.name} is a {type(scenario)}.")

    eval_dispatcher = RemoteEvalDispatcherClient(port=eval_dispatcher_port, remote_server_host=eval_dispatcher_host)

    env = scenario.env_class(env_config=scenario.env_config)
    num_players = 2

    trainer_config = scenario.get_trainer_config(env)
    trainer_config["explore"] = scenario.allow_stochastic_best_responses

    policies = [scenario.policy_classes["eval"](env.observation_space,
                                                env.action_space,
                                                with_common_config(trainer_config))
                for _ in range(num_players)]

    while True:
        policy_specs_for_each_player, required_games_to_play = eval_dispatcher.take_eval_job()

        if policy_specs_for_each_player is None:
            time.sleep(2)
        else:
            if len(policy_specs_for_each_player) != 2:
                raise NotImplementedError(f"This evaluation code only supports two player games. "
                                          f"{len(policy_specs_for_each_player)} players were requested.")

            # print(f"Got eval matchup:")
            # for spec in policy_specs_for_each_player:
            #     print(f"spec: {spec.to_json()}")

            for policy, spec in zip(policies, policy_specs_for_each_player):
                load_pure_strat(policy=policy, pure_strat_spec=spec)

            total_payoffs_per_player = np.zeros(shape=num_players, dtype=np.float64)

            # max_reward = None
            # min_reward = None
            # time_since_last_output = time.time()
            for game in range(required_games_to_play):
                # if game % 1000 == 0:
                #     now = time.time()
                #     print(f"{policy_specs_for_each_player[0].id} vs "
                #           f"{policy_specs_for_each_player[1].id}: "
                #           f"{game}/{required_games_to_play} games played, {now - time_since_last_output} seconds")
                #     time_since_last_output = now

                payoffs_per_player_this_episode = run_episode(env=env, policies_for_each_player=policies)
                total_payoffs_per_player += payoffs_per_player_this_episode

                # if max_reward is None or max(payoffs_per_player_this_episode) > max_reward:
                #     max_reward = max(payoffs_per_player_this_episode)
                # if min_reward is None or min(payoffs_per_player_this_episode) < min_reward:
                #     min_reward = min(payoffs_per_player_this_episode)

            payoffs_per_player = total_payoffs_per_player / required_games_to_play

            print(f"payoffs per player:"
                  f"{policy_specs_for_each_player[0].id} vs "
                  f"{policy_specs_for_each_player[1].id}: "
                  f"{payoffs_per_player}")

            eval_dispatcher.submit_eval_job_result(
                policy_specs_for_each_player_tuple=policy_specs_for_each_player,
                payoffs_for_each_player=payoffs_per_player,
                games_played=required_games_to_play
            )
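
# One plausible implementation of the run_episode helper used above (an assumption; the
# project's actual helper may differ). It steps a two-player RLlib MultiAgentEnv with the
# loaded eval policies and returns each player's total reward for a single episode, which
# is what the accumulation into total_payoffs_per_player above expects.
import numpy as np

def run_episode(env, policies_for_each_player):
    obs = env.reset()
    payoffs = np.zeros(shape=len(policies_for_each_player), dtype=np.float64)
    episode_done = False
    while not episode_done:
        action_dict = {}
        for player, player_obs in obs.items():
            # Policy.compute_actions operates on batches; wrap/unwrap a batch of size 1.
            actions, _, _ = policies_for_each_player[player].compute_actions(obs_batch=[player_obs])
            action_dict[player] = actions[0]
        obs, rewards, dones, infos = env.step(action_dict)
        episode_done = dones["__all__"]
        for player, reward in rewards.items():
            payoffs[player] += reward
    return payoffs
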
Example 4
def train_off_policy_rl_nfsp(results_dir: str,
                             scenario_name: str,
                             print_train_results: bool = True):
    
    scenario: NFSPScenario = scenario_catalog.get(scenario_name=scenario_name)
    if not isinstance(scenario, NFSPScenario):
        raise TypeError(f"Only instances of {NFSPScenario} can be used here. {scenario.name} is a {type(scenario)}.")

    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    avg_trainer_class = scenario.avg_trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    anticipatory_param: float = scenario.anticipatory_param
    get_trainer_config = scenario.get_trainer_config
    get_avg_trainer_config = scenario.get_avg_trainer_config
    calculate_openspiel_metanash: bool = scenario.calculate_openspiel_metanash
    calc_metanash_every_n_iters: int = scenario.calc_metanash_every_n_iters
    checkpoint_every_n_iters: Union[int, None] = scenario.checkpoint_every_n_iters
    nfsp_get_stopping_condition = scenario.nfsp_get_stopping_condition
    should_log_result_fn = scenario.ray_should_log_result_filter

    init_ray_for_scenario(scenario=scenario, head_address=None, logging_level=logging.INFO)

    def log(message, level=logging.INFO):
        logger.log(level, message)

    def select_policy(agent_id):
        random_sample = np.random.random()
        if agent_id == 0:
            if random_sample < anticipatory_param:
                return "best_response_0"
            return "average_policy_0"
        elif agent_id == 1:
            if random_sample < anticipatory_param:
                return "best_response_1"
            return "average_policy_1"
        else:
            raise ValueError(f"unexpected agent_id: {agent_id}")

    def assert_not_called(agent_id):
        assert False, "This function should never be called."

    tmp_env = env_class(env_config=env_config)

    def _create_env():
        return env_class(env_config=env_config)

    avg_policy_model_config = get_trainer_config(tmp_env)["model"]

    avg_trainer_config = merge_dicts({
        "log_level": "DEBUG",
        "framework": "torch",
        "env": env_class,
        "env_config": env_config,
        "num_gpus": 0.0,
        "num_gpus_per_worker": 0.0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["average_policy_0", "average_policy_1"],
            "policies": {
                "average_policy_0": (
                policy_classes["average_policy"], tmp_env.observation_space, tmp_env.action_space, {
                    "model": avg_policy_model_config
                }),
                "average_policy_1": (
                policy_classes["average_policy"], tmp_env.observation_space, tmp_env.action_space, {
                    "model": avg_policy_model_config
                }),
            },
            "policy_mapping_fn": assert_not_called,
        },

    }, get_avg_trainer_config(tmp_env))

    avg_trainer = avg_trainer_class(config=avg_trainer_config,
                                    logger_creator=get_trainer_logger_creator(
                                        base_dir=results_dir,
                                        scenario_name=f"{scenario_name}_avg_trainer",
                                        should_log_result_fn=should_log_result_fn))

    store_to_avg_policy_buffer = get_store_to_avg_policy_buffer_fn(nfsp_trainer=avg_trainer)

    class NFSPBestResponseCallbacks(DefaultCallbacks):

        def on_postprocess_trajectory(self, *, worker: "RolloutWorker", episode: MultiAgentEpisode, agent_id: AgentID,
                                      policy_id: PolicyID, policies: Dict[PolicyID, Policy],
                                      postprocessed_batch: SampleBatch,
                                      original_batches: Dict[Any, Tuple[Policy, SampleBatch]],
                                      **kwargs):
            super().on_postprocess_trajectory(worker=worker, episode=episode, agent_id=agent_id, policy_id=policy_id,
                                              policies=policies, postprocessed_batch=postprocessed_batch,
                                              original_batches=original_batches, **kwargs)

            postprocessed_batch.data["source_policy"] = [policy_id] * len(postprocessed_batch.data["rewards"])

            # All data from both policies will go into the best response's replay buffer.
            # Here we ensure policies not from the best response have the exact same preprocessing as the best response.
            for average_policy_id, br_policy_id in [("average_policy_0", "best_response_0"),
                                                    ("average_policy_1", "best_response_1")]:
                if policy_id == average_policy_id:

                    if "action_probs" in postprocessed_batch:
                        del postprocessed_batch.data["action_probs"]
                    if "behaviour_logits" in postprocessed_batch:
                        del postprocessed_batch.data["behaviour_logits"]

                    br_policy: Policy = policies[br_policy_id]

                    new_batch = br_policy.postprocess_trajectory(
                        sample_batch=postprocessed_batch,
                        other_agent_batches=original_batches,
                        episode=episode)
                    copy_attributes(src_obj=new_batch, dst_obj=postprocessed_batch)
                elif policy_id == br_policy_id:
                    if "q_values" in postprocessed_batch:
                        del postprocessed_batch.data["q_values"]
                    if "action_probs" in postprocessed_batch:
                        del postprocessed_batch.data["action_probs"]
                    del postprocessed_batch.data["action_dist_inputs"]

                if policy_id in ("average_policy_0", "best_response_0"):
                    assert agent_id == 0
                if policy_id in ("average_policy_1", "best_response_1"):
                    assert agent_id == 1

        def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy],
                           episode: MultiAgentEpisode, env_index: int, **kwargs):
            super().on_episode_end(worker=worker, base_env=base_env, policies=policies, episode=episode,
                                   env_index=env_index, **kwargs)

            episode_policies = set(episode.agent_rewards.keys())
            if episode_policies == {(0, "average_policy_0"), (1, "best_response_1")}:
                worker.avg_br_reward_deque.add.remote(episode.agent_rewards[(1, "best_response_1")])
            elif episode_policies == {(1, "average_policy_1"), (0, "best_response_0")}:
                worker.avg_br_reward_deque.add.remote(episode.agent_rewards[(0, "best_response_0")])

        def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs):
            super().on_sample_end(worker=worker, samples=samples, **kwargs)
            assert isinstance(samples, MultiAgentBatch)

            for policy_samples in samples.policy_batches.values():
                if "action_prob" in policy_samples.data:
                    del policy_samples.data["action_prob"]
                if "action_logp" in policy_samples.data:
                    del policy_samples.data["action_logp"]

            for average_policy_id, br_policy_id in [("average_policy_0", "best_response_0"),
                                                    ("average_policy_1", "best_response_1")]:
                for policy_id, policy_samples in samples.policy_batches.items():
                    if policy_id == br_policy_id:
                        store_to_avg_policy_buffer(MultiAgentBatch(policy_batches={
                            average_policy_id: policy_samples
                        }, env_steps=policy_samples.count))
                if average_policy_id in samples.policy_batches:

                    if br_policy_id in samples.policy_batches:
                        all_policies_samples = samples.policy_batches[br_policy_id].concat(
                            other=samples.policy_batches[average_policy_id])
                    else:
                        all_policies_samples = samples.policy_batches[average_policy_id]
                    del samples.policy_batches[average_policy_id]
                    samples.policy_batches[br_policy_id] = all_policies_samples

        def on_train_result(self, *, trainer, result: dict, **kwargs):
            super().on_train_result(trainer=trainer, result=result, **kwargs)
            result["scenario_name"] = trainer.scenario_name
            result["avg_br_reward_both_players"] = ray.get(trainer.avg_br_reward_deque.get_mean.remote())

            training_iteration = result["training_iteration"]
            if (calculate_openspiel_metanash and
                    (training_iteration == 1 or training_iteration % calc_metanash_every_n_iters == 0)):
                base_env = _create_env()
                open_spiel_env_config = base_env.open_spiel_env_config
                openspiel_game_version = base_env.game_version
                local_avg_policy_0 = trainer.workers.local_worker().policy_map["average_policy_0"]
                local_avg_policy_1 = trainer.workers.local_worker().policy_map["average_policy_1"]
                exploitability = nfsp_measure_exploitability_nonlstm(
                    rllib_policies=[local_avg_policy_0, local_avg_policy_1],
                    poker_game_version=openspiel_game_version,
                    open_spiel_env_config=open_spiel_env_config
                )
                result["avg_policy_exploitability"] = exploitability
                logger.info(colored(
                    f"(Graph this in a notebook) Exploitability: {exploitability} - Saving exploitability stats "
                    f"to {os.path.join(trainer.logdir, 'result.json')}", "green"))

            if checkpoint_every_n_iters and (training_iteration % checkpoint_every_n_iters == 0 or training_iteration == 1):
                for player in range(2):
                    checkpoint_metadata = create_metadata_with_new_checkpoint(
                        policy_id_to_save=f"average_policy_{player}",
                        br_trainer=br_trainer,
                        save_dir=checkpoint_dir(trainer=br_trainer),
                        timesteps_training=result["timesteps_total"],
                        episodes_training=result["episodes_total"],
                        checkpoint_name=f"average_policy_player_{player}_iter_{training_iteration}.h5"
                    )
                    avg_pol_checkpoint_spec = StrategySpec(
                        strategy_id=f"avg_pol_player_{player}_iter_{training_iteration}",
                        metadata=checkpoint_metadata)
                    checkpoint_path = os.path.join(spec_checkpoint_dir(br_trainer),
                                                   f"average_policy_player_{player}_iter_{training_iteration}.json")
                    ensure_dir(checkpoint_path)
                    with open(checkpoint_path, "+w") as checkpoint_spec_file:
                        checkpoint_spec_file.write(avg_pol_checkpoint_spec.to_json())

    br_trainer_config = {
        "log_level": "DEBUG",
        "callbacks": NFSPBestResponseCallbacks,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        "num_gpus": 0.0,
        "num_workers": 0,
        "num_gpus_per_worker": 0.0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["best_response_0", "best_response_1"],
            "policies": {
                "average_policy_0": (
                policy_classes["average_policy"], tmp_env.observation_space, tmp_env.action_space, {
                    "model": avg_policy_model_config,
                    "explore": False,
                }),
                "best_response_0": (
                policy_classes["best_response"], tmp_env.observation_space, tmp_env.action_space, {}),

                "average_policy_1": (
                policy_classes["average_policy"], tmp_env.observation_space, tmp_env.action_space, {
                    "model": avg_policy_model_config,
                    "explore": False,
                }),
                "best_response_1": (
                policy_classes["best_response"], tmp_env.observation_space, tmp_env.action_space, {}),
            },
            "policy_mapping_fn": select_policy,
        },
    }
    br_trainer_config = merge_dicts(br_trainer_config, get_trainer_config(tmp_env))

    br_trainer = trainer_class(config=br_trainer_config,
                               logger_creator=get_trainer_logger_creator(base_dir=results_dir,
                                                                         scenario_name=scenario_name,
                                                                         should_log_result_fn=should_log_result_fn))

    avg_br_reward_deque = StatDeque.remote(max_items=br_trainer_config["metrics_smoothing_episodes"])

    def _set_avg_br_rew_deque(worker: RolloutWorker):
        worker.avg_br_reward_deque = avg_br_reward_deque

    br_trainer.workers.foreach_worker(_set_avg_br_rew_deque)
    br_trainer.avg_br_reward_deque = avg_br_reward_deque

    # scenario_name logged in on_train_result_callback
    br_trainer.scenario_name = scenario_name

    br_trainer.latest_avg_trainer_result = None
    train_iter_count = 0

    for trainer in [br_trainer, avg_trainer]:
        for policy_id, policy in trainer.workers.local_worker().policy_map.items():
            policy.policy_id = policy_id

    avg_weights = avg_trainer.get_weights(["average_policy_0", "average_policy_1"])
    br_trainer.workers.foreach_worker(lambda worker: worker.set_weights(avg_weights))

    stopping_condition: StoppingCondition = nfsp_get_stopping_condition()

    print("starting")
    while True:
        print("avg train...")
        avg_train_results = avg_trainer.train()
        avg_weights = avg_trainer.get_weights(["average_policy_0", "average_policy_1"])
        br_trainer.workers.foreach_worker(lambda worker: worker.set_weights(avg_weights))
        br_trainer.latest_avg_trainer_result = copy.deepcopy(avg_train_results)
        print("br train...")
        train_iter_results = br_trainer.train()  # do a step (or several) in the main RL loop

        train_iter_count += 1
        print("printing results..")
        if print_train_results:
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"]["best_response_0"]:
                del train_iter_results["info"]["learner"]["best_response_0"]["td_error"]
            if "td_error" in train_iter_results["info"]["learner"]["best_response_1"]:
                del train_iter_results["info"]["learner"]["best_response_1"]["td_error"]
            log(pretty_dict_str(train_iter_results))
        log(f"Trainer logdir is {br_trainer.logdir}")

        if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
            print("stopping condition met.")
            break
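
# A minimal sketch of the StatDeque Ray actor used above (its internals are an assumption;
# only the calls made in this example -- StatDeque.remote(max_items=...), add.remote(value),
# and get_mean.remote() -- come from the code itself). It keeps a bounded window of recent
# best-response episode rewards and reports their mean for logging in on_train_result.
from collections import deque

import numpy as np
import ray

@ray.remote
class StatDeque:
    def __init__(self, max_items: int):
        self._items = deque(maxlen=max_items)

    def add(self, value: float):
        self._items.append(value)

    def get_mean(self):
        if len(self._items) == 0:
            return None
        return float(np.mean(self._items))
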
Example 5
    num_gpus = 0
    env_class = ThousandActionOshiZumoMultiAgentEnv

    br_player = 1
    avg_policy_player = 1 - br_player

    env_config = {
        'version': "oshi_zumo",
        'fixed_players': True,
        'append_valid_actions_mask_to_obs': True,
        'continuous_action_space': False,
        'individual_players_with_continuous_action_space': [br_player],
        'individual_players_with_orig_obs_space': [br_player],
    }

    avg_pol_scenario: NFSPScenario = scenario_catalog.get(
        scenario_name="1000_oshi_zumo_nfsp_larger_dqn_larger")

    trainer_class = PPOTrainer

    tmp_env = env_class(env_config=env_config)

    address_info = ray.init(num_cpus=num_cpus,
                            num_gpus=num_gpus,
                            object_store_memory=int(1073741824 * 1),
                            local_mode=False,
                            include_dashboard=True,
                            dashboard_host="0.0.0.0",
                            dashboard_port=find_free_port(),
                            ignore_reinit_error=True,
                            logging_level=logging.INFO,
                            log_to_driver=os.getenv("RAY_LOG_TO_DRIVER",
    num_cpus = 60
    num_gpus = 0
    env_class = TinyOshiZumoMultiAgentEnv

    br_player = 1
    avg_policy_player = 1 - br_player

    env_config = {
        'version': "oshi_zumo",
        'fixed_players': True,
        'append_valid_actions_mask_to_obs': True,
        'continuous_action_space': False,
        'illegal_actions_default_to_max_coin_value': False,
    }

    avg_pol_scenario: NFSPScenario = scenario_catalog.get(scenario_name="oshi_zumo_tiny_nfsp_dqn")

    trainer_class = DQNTrainer

    tmp_env = env_class(env_config=env_config)

    address_info = ray.init(
        num_cpus=num_cpus,
        num_gpus=num_gpus,
        object_store_memory=int(1073741824 * 1),
        local_mode=False,
        include_dashboard=True,
        dashboard_host="0.0.0.0",
        dashboard_port=find_free_port(),
        ignore_reinit_error=True,
        logging_level=logging.INFO,
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument('--scenario', type=str)
    parser.add_argument('--psro_port', type=int, required=False, default=None)
    parser.add_argument('--eval_port', type=int, required=False, default=None)
    parser.add_argument('--br_method',
                        type=str,
                        required=False,
                        default='standard')
    parser.add_argument('--use_prev_brs', default=False, action='store_true')
    commandline_args = parser.parse_args()

    scenario_name = commandline_args.scenario
    scenario: PSROScenario = scenario_catalog.get(scenario_name=scenario_name)

    psro_port = commandline_args.psro_port
    if psro_port is None:
        psro_port = establish_new_server_port_for_service(
            service_name=f"seed_{GRL_SEED}_{scenario.name}")

    eval_port = commandline_args.eval_port
    if eval_port is None:
        eval_port = establish_new_server_port_for_service(
            service_name=f"seed_{GRL_SEED}_{scenario.name}_evals")

    manager = launch_manager(scenario=scenario,
                             psro_port=psro_port,
                             eval_port=eval_port,
                             block=False,
Example 8
    br_obs_space = tmp_br_env.observation_space
    br_act_space = tmp_br_env.action_space

    experiment_name = f"loss_game_alpha_hparam_search"
    num_cpus = 90
    num_gpus = 0
    env_class = LossGameAlphaMultiAgentEnv

    br_player = 1

    env_config = {
        "total_moves": 10,
        "alpha": 2.9,
    }

    metanash_pol_scenario: PSROScenario = scenario_catalog.get(
        scenario_name="loss_game_psro_10_moves_alpha_2.9")

    trainer_class = PPOTrainer

    tmp_env = env_class(env_config=env_config)

    address_info = ray.init(num_cpus=num_cpus,
                            num_gpus=num_gpus,
                            object_store_memory=int(1073741824 * 1),
                            local_mode=False,
                            include_dashboard=True,
                            dashboard_host="0.0.0.0",
                            dashboard_port=find_free_port(),
                            ignore_reinit_error=True,
                            logging_level=logging.INFO,
                            log_to_driver=os.getenv("RAY_LOG_TO_DRIVER",
Example 9
        log_dir=os.path.join(os.path.dirname(grl.__file__), "data",
                             scenario.name, f"manager_{datetime_str()}"),
        port=nxdo_port,
        manager_metadata={"ray_head_address": ray_head_address},
    )

    if block:
        manager.wait_for_server_termination()

    return manager


if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument('--scenario', type=str)
    parser.add_argument('--nxdo_port', type=int, required=False, default=None)

    commandline_args = parser.parse_args()

    scenario_name = commandline_args.scenario
    scenario: NXDOScenario = scenario_catalog.get(scenario_name=scenario_name)

    nxdo_port = commandline_args.nxdo_port
    if nxdo_port is None:
        nxdo_port = establish_new_server_port_for_service(
            service_name=f"seed_{GRL_SEED}_{scenario.name}")

    launch_manager(scenario=scenario, nxdo_port=nxdo_port, block=True)
Example 10
def train_nxdo_best_response(br_player: int,
                             scenario_name: str,
                             nxdo_manager_port: int,
                             nxdo_manager_host: str,
                             print_train_results: bool = True,
                             previous_br_checkpoint_path=None):
    scenario: NXDOScenario = scenario_catalog.get(scenario_name=scenario_name)
    if not isinstance(scenario, NXDOScenario):
        raise TypeError(f"Only instances of {NXDOScenario} can be used here. {scenario.name} is a {type(scenario)}.")

    use_openspiel_restricted_game: bool = scenario.use_openspiel_restricted_game
    get_restricted_game_custom_model = scenario.get_restricted_game_custom_model

    env_class = scenario.env_class
    base_env_config = scenario.env_config
    trainer_class = scenario.trainer_class_br
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes_br
    get_trainer_config = scenario.get_trainer_config_br
    nxdo_br_get_stopping_condition = scenario.get_stopping_condition_br
    should_log_result_fn = scenario.ray_should_log_result_filter
    nxdo_metanash_method: str = scenario.xdo_metanash_method
    if nxdo_metanash_method != "nfsp":
        raise NotImplementedError("Only 'nfsp' is currently supported for the nxdo_metanash_method")

    nxdo_manager = RemoteNXDOManagerClient(n_players=2,
                                           port=nxdo_manager_port,
                                           remote_server_host=nxdo_manager_host)

    manager_metadata = nxdo_manager.get_manager_metadata()
    results_dir = nxdo_manager.get_log_dir()

    br_params = nxdo_manager.claim_new_active_policy_for_player(player=br_player)
    metanash_specs_for_players, delegate_specs_for_players, active_policy_num = br_params

    other_player = 1 - br_player
    br_learner_name = f"policy {active_policy_num} player {br_player}"

    def log(message):
        print(f"({br_learner_name}): {message}")

    def select_policy(agent_id):
        if agent_id == br_player:
            return f"best_response"
        elif agent_id == other_player:
            return f"metanash"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

    restricted_env_config = {
        "create_env_fn": lambda: env_class(env_config=base_env_config),
        "raise_if_no_restricted_players": metanash_specs_for_players is not None,
    }
    tmp_base_env = env_class(env_config=base_env_config)

    if use_openspiel_restricted_game:
        restricted_game_class = OpenSpielRestrictedGame
    else:
        restricted_game_class = RestrictedGame
        restricted_env_config["use_delegate_policy_exploration"] = scenario.allow_stochastic_best_responses

    tmp_env = restricted_game_class(env_config=restricted_env_config)

    if metanash_specs_for_players is None or use_openspiel_restricted_game:
        other_player_restricted_action_space = tmp_env.base_action_space
        metanash_class = policy_classes["best_response"]
    else:
        other_player_restricted_action_space = Discrete(n=len(delegate_specs_for_players[other_player]))
        metanash_class = policy_classes["metanash"]
        print(
            f"metanash class: {metanash_class}, other_player_restricted_action_space: {other_player_restricted_action_space}")

    if metanash_specs_for_players is None and use_openspiel_restricted_game:
        other_player_restricted_obs_space = tmp_env.base_observation_space
    else:
        other_player_restricted_obs_space = tmp_env.observation_space

    trainer_config = {
        "env": restricted_game_class,
        "env_config": restricted_env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": [f"best_response"],
            "policies": {
                f"metanash": (metanash_class, other_player_restricted_obs_space, other_player_restricted_action_space,
                              {"explore": False}),

                f"metanash_delegate": (policy_classes["best_response"], tmp_env.base_observation_space, tmp_env.base_action_space,
                            {"explore": scenario.allow_stochastic_best_responses}),

                f"best_response": (policy_classes["best_response"], tmp_env.base_observation_space, tmp_env.base_action_space, {}),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    if metanash_specs_for_players is not None and get_restricted_game_custom_model is not None:
        trainer_config["multiagent"]["policies"]["metanash"][3]["model"] = {
            "custom_model": get_restricted_game_custom_model(tmp_env)}

    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_base_env))

    ray_head_address = manager_metadata["ray_head_address"]
    init_ray_for_scenario(scenario=scenario, head_address=ray_head_address, logging_level=logging.INFO)

    trainer = trainer_class(config=trainer_config, logger_creator=get_trainer_logger_creator(
        base_dir=results_dir, scenario_name=scenario_name,
        should_log_result_fn=should_log_result_fn))

    # metanash is single pure strat spec
    def _set_worker_metanash(worker: RolloutWorker):
        if metanash_specs_for_players is not None:
            metanash_policy = worker.policy_map["metanash"]
            metanash_strategy_spec: StrategySpec = metanash_specs_for_players[other_player]
            load_pure_strat(policy=metanash_policy, pure_strat_spec=metanash_strategy_spec)
    trainer.workers.foreach_worker(_set_worker_metanash)

    trainer.weights_cache = {}
    if delegate_specs_for_players:
        if use_openspiel_restricted_game:
            set_restricted_game_conversions_for_all_workers_openspiel(
                trainer=trainer,
                tmp_base_env=tmp_base_env,
                delegate_policy_id="metanash_delegate",
                agent_id_to_restricted_game_specs={
                    other_player: delegate_specs_for_players[other_player]
                },
                load_policy_spec_fn=load_pure_strat)
        else:
            set_restricted_game_conversations_for_all_workers(
                trainer=trainer,
                delegate_policy_id="metanash_delegate",
                agent_id_to_restricted_game_specs={
                    other_player: delegate_specs_for_players[other_player]
                },
                load_policy_spec_fn=create_get_pure_strat_cached(cache=trainer.weights_cache))

    log(f"got policy {active_policy_num}")

    if previous_br_checkpoint_path is not None:
        def _set_br_initial_weights(worker: RolloutWorker):
            br_policy = worker.policy_map["best_response"]
            load_pure_strat(policy=br_policy, checkpoint_path=previous_br_checkpoint_path)
        trainer.workers.foreach_worker(_set_br_initial_weights)

    # Perform main RL training loop. Stop according to our StoppingCondition.
    stopping_condition: StoppingCondition = nxdo_br_get_stopping_condition()

    while True:
        train_iter_results = trainer.train()  # do a step (or several) in the main RL loop

        if print_train_results:
            train_iter_results["p2sro_active_policy_num"] = active_policy_num
            train_iter_results["best_response_player"] = br_player
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"][f"best_response"]:
                del train_iter_results["info"]["learner"][f"best_response"]["td_error"]
            log(f"Trainer log dir is {trainer.logdir}")
            log(pretty_dict_str(train_iter_results))

        total_timesteps_training_br = train_iter_results["timesteps_total"]
        total_episodes_training_br = train_iter_results["episodes_total"]
        br_reward_this_iter = train_iter_results["policy_reward_mean"]["best_response"]

        if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
            log("Stopping condition met.")
            break

    log(f"Training stopped. Setting active policy {active_policy_num} as fixed.")

    final_policy_metadata = create_metadata_with_new_checkpoint_for_current_best_response(
            trainer=trainer, player=br_player, save_dir=checkpoint_dir(trainer=trainer),
            timesteps_training_br=total_timesteps_training_br,
            episodes_training_br=total_episodes_training_br,
            active_policy_num=active_policy_num,
            average_br_reward=float(br_reward_this_iter),
        )
    nxdo_manager.submit_final_br_policy(
        player=br_player, policy_num=active_policy_num,
        metadata_dict=final_policy_metadata)

    # trainer.cleanup()
    # del trainer
    ray.shutdown()
    time.sleep(10)

    # wait for both player policies to be fixed.
    for player_to_wait_on in range(2):
        wait_count = 0
        while True:
            if nxdo_manager.is_policy_fixed(player=player_to_wait_on, policy_num=active_policy_num):
                break
            if wait_count % 10 == 0:
                log(f"Waiting for policy {active_policy_num} player {player_to_wait_on} to become fixed")
            time.sleep(2.0)
            wait_count += 1

    return final_policy_metadata["checkpoint_path"]
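
# A plausible sketch of the create_get_pure_strat_cached helper used above (assumed internals;
# the real helper may differ). The intent suggested by its call site is to avoid re-reading
# delegate checkpoints from disk on every restricted-game conversion by caching loaded policy
# weights in trainer.weights_cache, keyed by the strategy spec's id. The returned function is
# assumed to match load_pure_strat's (policy, pure_strat_spec) signature.
def create_get_pure_strat_cached(cache: dict):
    def load_pure_strat_cached(policy, pure_strat_spec):
        spec_id = pure_strat_spec.id  # StrategySpec ids are accessed this way in the eval example above
        if spec_id in cache:
            policy.set_weights(cache[spec_id])
        else:
            load_pure_strat(policy=policy, pure_strat_spec=pure_strat_spec)
            cache[spec_id] = policy.get_weights()
        # Mirror the policy_spec attribute that other callbacks in this collection read.
        policy.policy_spec = pure_strat_spec
    return load_pure_strat_cached
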
Example 11
    experiment_name = f"leduc_hyper_param_search_dqn"
    num_cpus = 32
    num_gpus = 0
    env_class = PokerMultiAgentEnv

    br_player = 1
    avg_policy_player = 1 - br_player

    env_config = {
        "version": "leduc_poker",
        "fixed_players": True,
        "append_valid_actions_mask_to_obs": True,
    }

    avg_pol_scenario: NFSPScenario = scenario_catalog.get(
        scenario_name="leduc_nfsp_dqn")

    trainer_class = DQNTrainer

    tmp_env = env_class(env_config=env_config)

    address_info = ray.init(num_cpus=num_cpus,
                            num_gpus=num_gpus,
                            object_store_memory=int(1073741824 * 1),
                            local_mode=False,
                            include_dashboard=True,
                            dashboard_host="0.0.0.0",
                            dashboard_port=find_free_port(),
                            ignore_reinit_error=True,
                            logging_level=logging.INFO,
                            log_to_driver=os.getenv("RAY_LOG_TO_DRIVER",
Example 12
if __name__ == "__main__":
    tmp_br_env = AttackCounterGameMultiAgentEnv(env_config={})
    br_obs_space = tmp_br_env.observation_space
    br_act_space = tmp_br_env.action_space

    experiment_name = f"attack_and_counter_game_alpha_hparam_search"
    num_cpus = 40
    num_gpus = 0
    env_class = AttackCounterGameMultiAgentEnv

    br_player = 1

    env_config = {}

    metanash_pol_scenario: PSROScenario = scenario_catalog.get(
        scenario_name="attack_and_counter_game_psro_1_moves")

    trainer_class = PPOTrainer

    tmp_env = env_class(env_config=env_config)

    address_info = ray.init(num_cpus=num_cpus,
                            num_gpus=num_gpus,
                            object_store_memory=int(1073741824 * 1),
                            local_mode=False,
                            include_dashboard=True,
                            dashboard_host="0.0.0.0",
                            dashboard_port=find_free_port(),
                            ignore_reinit_error=True,
                            logging_level=logging.INFO,
                            log_to_driver=os.getenv("RAY_LOG_TO_DRIVER",
    experiment_name = f"loss_game_hparam_search_dqn"
    num_cpus = 40
    num_gpus = 0
    env_class = LossGameAlphaMultiAgentEnv

    br_player = 1
    avg_policy_player = 1 - br_player

    env_config = {
        "total_moves": 10,
        "alpha": 2.9,
        "discrete_actions_for_players": [0, 1],
    }

    avg_pol_scenario: NFSPScenario = scenario_catalog.get(scenario_name="loss_game_nfsp_10_moves_alpha_2.9")

    trainer_class = DQNTrainer

    tmp_env = env_class(env_config=env_config)

    address_info = ray.init(
        num_cpus=num_cpus,
        num_gpus=num_gpus,
        object_store_memory=int(1073741824 * 10),
        local_mode=False,
        include_dashboard=True,
        dashboard_host="0.0.0.0",
        dashboard_port=find_free_port(),
        ignore_reinit_error=True,
        logging_level=logging.INFO,
def train_poker_approx_best_response_psro(br_player,
                                          ray_head_address,
                                          scenario_name,
                                          general_trainer_config_overrides,
                                          br_policy_config_overrides: dict,
                                          get_stopping_condition,
                                          metanash_policy_specs,
                                          metanash_weights,
                                          results_dir,
                                          print_train_results=True):
    scenario: PSROScenario = scenario_catalog.get(scenario_name=scenario_name)

    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    p2sro = scenario.p2sro
    get_trainer_config = scenario.get_trainer_config
    psro_get_stopping_condition = scenario.psro_get_stopping_condition
    mix_metanash_with_uniform_dist_coeff = scenario.mix_metanash_with_uniform_dist_coeff

    other_player = 1 - br_player

    br_learner_name = f"new_learner_{br_player}"

    def log(message, level=logging.INFO):
        logger.log(level, f"({br_learner_name}): {message}")

    def select_policy(agent_id):
        if agent_id == br_player:
            return f"best_response"
        elif agent_id == other_player:
            return f"metanash"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

    init_ray_for_scenario(scenario=scenario,
                          head_address=ray_head_address,
                          logging_level=logging.INFO)

    tmp_env = env_class(env_config=env_config)

    trainer_config = {
        "callbacks": P2SROPreAndPostEpisodeCallbacks,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": [f"best_response"],
            "policies": {
                f"metanash":
                (policy_classes["metanash"], tmp_env.observation_space,
                 tmp_env.action_space, {
                     "explore": scenario.allow_stochastic_best_responses
                 }),
                f"best_response":
                (policy_classes["best_response"], tmp_env.observation_space,
                 tmp_env.action_space, br_policy_config_overrides),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_env))
    trainer_config = merge_dicts(trainer_config,
                                 general_trainer_config_overrides)

    # trainer_config["rollout_fragment_length"] = trainer_config["rollout_fragment_length"] // max(1, trainer_config["num_workers"] * trainer_config["num_envs_per_worker"] )

    trainer = trainer_class(config=trainer_config,
                            logger_creator=get_trainer_logger_creator(
                                base_dir=results_dir,
                                scenario_name="approx_br",
                                should_log_result_fn=lambda result: result[
                                    "training_iteration"] % 100 == 0))

    update_all_workers_to_latest_metanash(
        trainer=trainer,
        metanash_policy_specs=metanash_policy_specs,
        metanash_weights=metanash_weights)

    train_iter_count = 0

    stopping_condition: StoppingCondition = get_stopping_condition()
    max_reward = None
    while True:
        train_iter_results = trainer.train(
        )  # do a step (or several) in the main RL loop
        train_iter_count += 1

        if print_train_results:
            train_iter_results["best_response_player"] = br_player
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"][
                    f"best_response"]:
                del train_iter_results["info"]["learner"][f"best_response"][
                    "td_error"]
            print(pretty_dict_str(train_iter_results))

        total_timesteps_training_br = train_iter_results["timesteps_total"]
        total_episodes_training_br = train_iter_results["episodes_total"]
        br_reward_this_iter = train_iter_results["policy_reward_mean"]["best_response"]
        if max_reward is None or br_reward_this_iter > max_reward:
            max_reward = br_reward_this_iter
        if stopping_condition.should_stop_this_iter(
                latest_trainer_result=train_iter_results):
            break

    trainer.cleanup()
    ray.shutdown()
    time.sleep(10)

    return max_reward
if __name__ == "__main__":

    experiment_name = f"attack_and_counter_hparam_search_dqn"
    num_cpus = 20
    num_gpus = 0
    env_class = AttackCounterGameMultiAgentEnv

    br_player = 1
    avg_policy_player = 1 - br_player

    env_config = {
        "discrete_actions_for_players": [br_player],
        "discrete_action_space_is_default": True,
    }

    avg_pol_scenario: NFSPScenario = scenario_catalog.get(
        scenario_name="attack_and_counter_game_nfsp")

    trainer_class = DQNTrainer

    tmp_env = env_class(env_config=env_config)

    address_info = ray.init(num_cpus=num_cpus,
                            num_gpus=num_gpus,
                            object_store_memory=int(1073741824 * 10),
                            local_mode=False,
                            include_dashboard=True,
                            dashboard_host="0.0.0.0",
                            dashboard_port=find_free_port(),
                            ignore_reinit_error=True,
                            logging_level=logging.INFO,
                            log_to_driver=os.getenv("RAY_LOG_TO_DRIVER",
Example 16
def train_psro_best_response(player: int, results_dir: str, scenario_name: str, psro_manager_port: int,
                             psro_manager_host: str, print_train_results=True, previous_br_checkpoint_path=None) -> str:
    
    scenario: PSROScenario = scenario_catalog.get(scenario_name=scenario_name)
    if not isinstance(scenario, PSROScenario):
        raise TypeError(f"Only instances of {PSROScenario} can be used here. {scenario.name} is a {type(scenario)}.")

    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    single_agent_symmetric_game = scenario.single_agent_symmetric_game
    if single_agent_symmetric_game and player != 0:
        if player is None:
            player = 0
        else:
            raise ValueError(f"If treating the game as single agent symmetric, only use player 0 "
                             f"(one agent plays all sides).")

    p2sro = scenario.p2sro
    p2sro_sync_with_payoff_table_every_n_episodes = scenario.p2sro_sync_with_payoff_table_every_n_episodes
    get_trainer_config = scenario.get_trainer_config
    psro_get_stopping_condition = scenario.psro_get_stopping_condition
    mix_metanash_with_uniform_dist_coeff = scenario.mix_metanash_with_uniform_dist_coeff
    allow_stochastic_best_response = scenario.allow_stochastic_best_responses
    should_log_result_fn = scenario.ray_should_log_result_filter

    class P2SROPreAndPostEpisodeCallbacks(DefaultCallbacks):

        # def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, episode: MultiAgentEpisode,
        #                     env_index: int,
        #                     **kwargs):
        #     super().on_episode_step(worker=worker, base_env=base_env, episode=episode, env_index=env_index, **kwargs)
        #
        #     # Debug render a single environment.
        #     if worker.worker_index == 1 and env_index == 0:
        #         base_env.get_unwrapped()[0].render()

        def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                             policies: Dict[str, Policy],
                             episode: MultiAgentEpisode, env_index: int, **kwargs):

            # Sample new pure strategy policy weights from the opponent strategy distribution for the best response to
            # train against. For better runtime performance, this function can be modified to load new weights
            # only every few episodes instead.
            resample_pure_strat_every_n_episodes = 1
            metanash_policy: Policy = policies[f"metanash"]
            opponent_policy_distribution: PolicySpecDistribution = worker.opponent_policy_distribution
            time_for_resample = (not hasattr(metanash_policy, "episodes_since_resample") or
                                 metanash_policy.episodes_since_resample >= resample_pure_strat_every_n_episodes)
            if time_for_resample and opponent_policy_distribution is not None:
                new_pure_strat_spec: StrategySpec = opponent_policy_distribution.sample_policy_spec()
                # noinspection PyTypeChecker
                load_pure_strat(policy=metanash_policy, pure_strat_spec=new_pure_strat_spec)
                metanash_policy.episodes_since_resample = 1
            elif opponent_policy_distribution is not None:
                metanash_policy.episodes_since_resample += 1

        def on_train_result(self, *, trainer, result: dict, **kwargs):
            result["scenario_name"] = trainer.scenario_name
            super().on_train_result(trainer=trainer, result=result, **kwargs)

        def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                           policies: Dict[str, Policy], episode: MultiAgentEpisode,
                           env_index: int, **kwargs):

            # If using P2SRO, report payoff results of the actively training BR to the payoff table.
            if not p2sro:
                return

            if not hasattr(worker, "p2sro_manager"):
                worker.p2sro_manager = RemoteP2SROManagerClient(n_players=2,
                                                                port=psro_manager_port,
                                                                remote_server_host=psro_manager_host)

            br_policy_spec: StrategySpec = worker.policy_map["best_response"].policy_spec
            if br_policy_spec.pure_strat_index_for_player(player=worker.br_player) == 0:
                # We're training policy 0 if True (first iteration of PSRO).
                # The PSRO subgame should be empty, and instead the metanash is a random neural network.
                # No need to report payoff results for this.
                return

            # Report payoff results for individual episodes to the p2sro manager to keep a real-time approximation
            # of the payoff matrix entries for (learning) active policies.
            policy_specs_for_each_player: List[StrategySpec] = [None, None]
            payoffs_for_each_player: List[float] = [None, None]
            for (player, policy_name), reward in episode.agent_rewards.items():
                assert policy_name in ["best_response", "metanash"]
                policy: Policy = worker.policy_map[policy_name]
                assert policy.policy_spec is not None
                policy_specs_for_each_player[player] = policy.policy_spec
                payoffs_for_each_player[player] = reward
            assert all(payoff is not None for payoff in payoffs_for_each_player)

            # Send payoff result to the manager for inclusion in the payoff table.
            worker.p2sro_manager.submit_empirical_payoff_result(
                policy_specs_for_each_player=tuple(policy_specs_for_each_player),
                payoffs_for_each_player=tuple(payoffs_for_each_player),
                games_played=1,
                override_all_previous_results=False)

    other_player = 1 - player
    br_learner_name = f"new_learner_{player}"

    def log(message):
        print(f"({br_learner_name}): {message}")

    def select_policy(agent_id):
        if agent_id == player:
            return "best_response"
        elif agent_id == other_player:
            return "metanash"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

    p2sro_manager = RemoteP2SROManagerClient(n_players=2, port=psro_manager_port,
                                             remote_server_host=psro_manager_host)
    manager_metadata = p2sro_manager.get_manager_metadata()
    ray_head_address = manager_metadata["ray_head_address"]
    init_ray_for_scenario(scenario=scenario, head_address=ray_head_address, logging_level=logging.INFO)

    tmp_env = env_class(env_config=env_config)

    trainer_config = {
        "callbacks": P2SROPreAndPostEpisodeCallbacks,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": [f"best_response"],
            "policies": {
                f"metanash": (
                policy_classes["metanash"], tmp_env.observation_space, tmp_env.action_space, {"explore": allow_stochastic_best_response}),
                f"best_response": (
                policy_classes["best_response"], tmp_env.observation_space, tmp_env.action_space, {}),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_env))

    trainer = trainer_class(config=trainer_config,
                            logger_creator=get_trainer_logger_creator(
                                base_dir=results_dir, scenario_name=scenario_name,
                                should_log_result_fn=should_log_result_fn))

    # scenario_name logged in on_train_result_callback
    trainer.scenario_name = scenario_name

    if previous_br_checkpoint_path is not None:
        def _set_br_initial_weights(worker: RolloutWorker):
            br_policy = worker.policy_map["best_response"]
            load_pure_strat(policy=br_policy, checkpoint_path=previous_br_checkpoint_path)
        trainer.workers.foreach_worker(_set_br_initial_weights)

    active_policy_spec: StrategySpec = p2sro_manager.claim_new_active_policy_for_player(
        player=player, new_policy_metadata_dict=create_metadata_with_new_checkpoint_for_current_best_response(
            trainer=trainer, player=player, save_dir=checkpoint_dir(trainer), timesteps_training_br=0,
            episodes_training_br=0,
            active_policy_num=None
        ))

    active_policy_num = active_policy_spec.pure_strat_index_for_player(player=player)
    br_learner_name = f"policy {active_policy_num} player {player}"

    log(f"got policy {active_policy_num}")

    set_best_response_active_policy_spec_and_player_for_all_workers(trainer=trainer, player=player,
                                                                    active_policy_spec=active_policy_spec)

    sync_active_policy_br_and_metanash_with_p2sro_manager(trainer=trainer,
                                                          player=player,
                                                          metanash_player=other_player,
                                                          one_agent_plays_all_sides=single_agent_symmetric_game,
                                                          p2sro_manager=p2sro_manager,
                                                          mix_metanash_with_uniform_dist_coeff=mix_metanash_with_uniform_dist_coeff,
                                                          active_policy_num=active_policy_num,
                                                          timesteps_training_br=0,
                                                          episodes_training_br=0)

    # Perform main RL training loop. Stop training according to our StoppingCondition.
    train_iter_count = 0
    episodes_since_last_sync_with_manager = 0
    stopping_condition: StoppingCondition = psro_get_stopping_condition()

    while True:
        train_iter_results = trainer.train()  # do a step (or several) in the main RL loop
        train_iter_count += 1

        if print_train_results:
            train_iter_results["p2sro_active_policy_num"] = active_policy_num
            train_iter_results["best_response_player"] = player
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"][f"best_response"]:
                del train_iter_results["info"]["learner"][f"best_response"]["td_error"]
            log(pretty_dict_str(train_iter_results))

        total_timesteps_training_br = train_iter_results["timesteps_total"]
        total_episodes_training_br = train_iter_results["episodes_total"]

        episodes_since_last_sync_with_manager += train_iter_results["episodes_this_iter"]
        if p2sro and episodes_since_last_sync_with_manager >= p2sro_sync_with_payoff_table_every_n_episodes:
            if p2sro_sync_with_payoff_table_every_n_episodes > 0:
                episodes_since_last_sync_with_manager = episodes_since_last_sync_with_manager % p2sro_sync_with_payoff_table_every_n_episodes
            else:
                episodes_since_last_sync_with_manager = 0

            sync_active_policy_br_and_metanash_with_p2sro_manager(trainer=trainer,
                                                                  player=player,
                                                                  metanash_player=other_player,
                                                                  one_agent_plays_all_sides=single_agent_symmetric_game,
                                                                  p2sro_manager=p2sro_manager,
                                                                  mix_metanash_with_uniform_dist_coeff=mix_metanash_with_uniform_dist_coeff,
                                                                  active_policy_num=active_policy_num,
                                                                  timesteps_training_br=total_timesteps_training_br,
                                                                  episodes_training_br=total_episodes_training_br)

        if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
            if p2sro_manager.can_active_policy_be_set_as_fixed_now(player=player, policy_num=active_policy_num):
                break
            else:
                log(f"Forcing training to continue since lower policies are still active.")

    log(f"Training stopped. Setting active policy {active_policy_num} as fixed.")

    final_policy_metadata = create_metadata_with_new_checkpoint_for_current_best_response(
            trainer=trainer, player=player, save_dir=checkpoint_dir(trainer=trainer),
            timesteps_training_br=total_timesteps_training_br,
            episodes_training_br=total_episodes_training_br,
            active_policy_num=active_policy_num,
            average_br_reward=train_iter_results["policy_reward_mean"]["best_response"])

    p2sro_manager.set_active_policy_as_fixed(
        player=player, policy_num=active_policy_num,
        final_metadata_dict=final_policy_metadata)

    trainer.cleanup()
    ray.shutdown()
    time.sleep(10)

    if not p2sro:
        # wait for both player policies to be fixed.
        for player_to_wait_on in range(2):
            wait_count = 0
            while True:
                if p2sro_manager.is_policy_fixed(player=player_to_wait_on, policy_num=active_policy_num):
                    break
                if wait_count % 10 == 0:
                    log(f"Waiting for policy {active_policy_num} player {player_to_wait_on} to become fixed")
                time.sleep(2.0)
                wait_count += 1

    return final_policy_metadata["checkpoint_path"]
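
# Several examples above stop training via stopping_condition.should_stop_this_iter(
# latest_trainer_result=...). A minimal sketch of one condition implementing that interface
# (an assumption about its shape; the actual conditions come from each scenario's
# *_get_stopping_condition factory):
class TimestepsStoppingCondition:
    """Stop once the trainer has collected a fixed number of environment timesteps."""

    def __init__(self, max_timesteps: int):
        self._max_timesteps = max_timesteps

    def should_stop_this_iter(self, latest_trainer_result: dict) -> bool:
        return latest_trainer_result["timesteps_total"] >= self._max_timesteps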