Example #1
        # Stepsize of SGD.
        "lr": 0.0005,
        # PPO clip parameter.
        "clip_param": 0.03,
        # Clip param for the value function. Note that this is sensitive to the
        # scale of the rewards. If your expected V is large, increase this.
        "vf_clip_param": 10.0,
        # If specified, clip the global norm of gradients by this amount.
        "grad_clip": None,
        # Target value for KL divergence.
        "kl_target": 0.001,
    }

    tune.run(
        run_or_experiment=trainer_class,
        name=experiment_name,
        metric="br_reward_mean",
        config=hyperparams,
        num_samples=2,
        search_alg=None,
        mode="max",
        local_dir=data_dir(),
        stop={"timesteps_total": int(3e6)},
        loggers=[
            get_trainer_logger_creator(base_dir=data_dir(),
                                       scenario_name=experiment_name,
                                       should_log_result_fn=lambda result:
                                       result["training_iteration"] % 20 == 0)
        ],
    )
Example #2
def train_self_play(results_dir, scenario_name, print_train_results=True):

    scenario: PSROScenario = scenario_catalog.get(scenario_name=scenario_name)

    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    single_agent_symmetric_game = scenario.single_agent_symmetric_game
    if single_agent_symmetric_game:
        raise NotImplementedError

    get_trainer_config = scenario.get_trainer_config
    should_log_result_fn = scenario.ray_should_log_result_filter

    checkpoint_every_n_iters = 500

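    # Callbacks that tag each train result with the scenario name and periodically write
    # checkpoints (plus StrategySpec JSON files) for both best-response policies.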
    class PreAndPostEpisodeCallbacks(DefaultCallbacks):
        def on_train_result(self, *, trainer, result: dict, **kwargs):
            result["scenario_name"] = trainer.scenario_name
            training_iteration = result["training_iteration"]
            super().on_train_result(trainer=trainer, result=result, **kwargs)

            if training_iteration % checkpoint_every_n_iters == 0 or training_iteration == 1:
                for player in range(2):
                    checkpoint_metadata = create_metadata_with_new_checkpoint(
                        policy_id_to_save=f"best_response_{player}",
                        br_trainer=trainer,
                        policy_player=player,
                        save_dir=checkpoint_dir(trainer=trainer),
                        timesteps_training=result["timesteps_total"],
                        episodes_training=result["episodes_total"],
                        checkpoint_name=
                        f"best_response_player_{player}_iter_{training_iteration}.h5"
                    )
                    joint_pol_checkpoint_spec = StrategySpec(
                        strategy_id=
                        f"best_response_player_{player}_iter_{training_iteration}",
                        metadata=checkpoint_metadata)
                    checkpoint_path = os.path.join(
                        spec_checkpoint_dir(trainer),
                        f"best_response_player_{player}_iter_{training_iteration}.json"
                    )
                    ensure_dir(checkpoint_path)
                    with open(checkpoint_path, "+w") as checkpoint_spec_file:
                        checkpoint_spec_file.write(
                            joint_pol_checkpoint_spec.to_json())

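    # Map agent ids 0 and 1 to their corresponding best-response policy ids.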
    def select_policy(agent_id):
        if agent_id == 0:
            return "best_response_0"
        elif agent_id == 1:
            return "best_response_1"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

    init_ray_for_scenario(scenario=scenario,
                          head_address=None,
                          logging_level=logging.INFO)

    tmp_env = env_class(env_config=env_config)

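    # Self-play config: each player trains its own instance of the best-response policy class.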
    trainer_config = {
        "callbacks": PreAndPostEpisodeCallbacks,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": [f"best_response_0", "best_response_1"],
            "policies": {
                f"best_response_0":
                (policy_classes["best_response"], tmp_env.observation_space,
                 tmp_env.action_space, {}),
                f"best_response_1":
                (policy_classes["best_response"], tmp_env.observation_space,
                 tmp_env.action_space, {}),
            },
            "policy_mapping_fn": select_policy,
        },
    }

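    # Settings returned by get_trainer_config take precedence over the base config above.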
    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_env))

    trainer = trainer_class(config=trainer_config,
                            logger_creator=get_trainer_logger_creator(
                                base_dir=results_dir,
                                scenario_name=scenario_name,
                                should_log_result_fn=should_log_result_fn))

    # scenario_name is logged in the on_train_result callback above.
    trainer.scenario_name = scenario_name

    # Perform main RL training loop.
    while True:
        train_iter_results = trainer.train(
        )  # do a step (or several) in the main RL loop

        if print_train_results:
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            for key in ["best_response_0", "best_response_1"]:
                if "td_error" in train_iter_results["info"]["learner"][key]:
                    del train_iter_results["info"]["learner"][key]["td_error"]
            print(pretty_dict_str(train_iter_results))
Example #3
def train_poker_approx_best_response_nfsp(
        br_player,
        ray_head_address,
        scenario: NFSPScenario,
        general_trainer_config_overrrides,
        br_policy_config_overrides,
        get_stopping_condition,
        avg_policy_specs_for_players: Dict[int, StrategySpec],
        results_dir: str,
        print_train_results: bool = True):
    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    get_trainer_config = scenario.get_trainer_config
    should_log_result_fn = scenario.ray_should_log_result_filter

    init_ray_for_scenario(scenario=scenario,
                          head_address=ray_head_address,
                          logging_level=logging.INFO)

    def log(message, level=logging.INFO):
        logger.log(level, message)

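    # The BR player uses the trainable "best_response" policy; the opponent plays a fixed average policy.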
    def select_policy(agent_id):
        if agent_id == br_player:
            return "best_response"
        else:
            return f"average_policy"

    tmp_env = env_class(env_config=env_config)

    avg_policy_model_config = get_trainer_config(tmp_env)["model"]

    br_trainer_config = {
        "log_level": "DEBUG",
        # "callbacks": None,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        # "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_gpus": 0.0,
        "num_workers": 0,
        "num_gpus_per_worker": 0.0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["best_response"],
            "policies": {
                "average_policy":
                (policy_classes["average_policy"], tmp_env.observation_space,
                 tmp_env.action_space, {
                     "model": avg_policy_model_config,
                     "explore": False,
                 }),
                "best_response":
                (policy_classes["best_response"], tmp_env.observation_space,
                 tmp_env.action_space, br_policy_config_overrides),
            },
            "policy_mapping_fn": select_policy,
        },
    }
    br_trainer_config = merge_dicts(br_trainer_config,
                                    get_trainer_config(tmp_env))
    br_trainer_config = merge_dicts(br_trainer_config,
                                    general_trainer_config_overrrides)

    br_trainer = trainer_class(config=br_trainer_config,
                               logger_creator=get_trainer_logger_creator(
                                   base_dir=results_dir,
                                   scenario_name="approx_br",
                                   should_log_result_fn=should_log_result_fn))

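    # Load the opponent's frozen average-policy weights onto every rollout worker.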
    def _set_avg_policy(worker: RolloutWorker):
        avg_policy = worker.policy_map["average_policy"]
        load_pure_strat(
            policy=avg_policy,
            pure_strat_spec=avg_policy_specs_for_players[1 - br_player])

    br_trainer.workers.foreach_worker(_set_avg_policy)

    br_trainer.latest_avg_trainer_result = None
    train_iter_count = 0

    stopping_condition: StoppingCondition = get_stopping_condition()

    max_reward = None
    while True:
        train_iter_results = br_trainer.train(
        )  # do a step (or several) in the main RL loop
        br_reward_this_iter = train_iter_results["policy_reward_mean"][
            f"best_response"]

        if max_reward is None or br_reward_this_iter > max_reward:
            max_reward = br_reward_this_iter

        train_iter_count += 1
        if print_train_results:
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"][
                    "best_response"]:
                del train_iter_results["info"]["learner"]["best_response"][
                    "td_error"]
            print(pretty_dict_str(train_iter_results))
            log(f"Trainer logdir is {br_trainer.logdir}")

        if stopping_condition.should_stop_this_iter(
                latest_trainer_result=train_iter_results):
            print("stopping condition met.")
            break

    return max_reward
Example #4
def train_off_policy_rl_nfsp(results_dir: str,
                             scenario_name: str,
                             print_train_results: bool = True):
    
    scenario: NFSPScenario = scenario_catalog.get(scenario_name=scenario_name)
    if not isinstance(scenario, NFSPScenario):
        raise TypeError(f"Only instances of {NFSPScenario} can be used here. {scenario.name} is a {type(scenario)}.")

    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    avg_trainer_class = scenario.avg_trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    anticipatory_param: float = scenario.anticipatory_param
    get_trainer_config = scenario.get_trainer_config
    get_avg_trainer_config = scenario.get_avg_trainer_config
    calculate_openspiel_metanash: bool = scenario.calculate_openspiel_metanash
    calc_metanash_every_n_iters: int = scenario.calc_metanash_every_n_iters
    checkpoint_every_n_iters: Union[int, None] = scenario.checkpoint_every_n_iters
    nfsp_get_stopping_condition = scenario.nfsp_get_stopping_condition
    should_log_result_fn = scenario.ray_should_log_result_filter

    init_ray_for_scenario(scenario=scenario, head_address=None, logging_level=logging.INFO)

    def log(message, level=logging.INFO):
        logger.log(level, message)

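    # NFSP anticipatory mixing: with probability anticipatory_param an agent acts with its
    # best response, otherwise with its average policy.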
    def select_policy(agent_id):
        random_sample = np.random.random()
        if agent_id == 0:
            if random_sample < anticipatory_param:
                return "best_response_0"
            return "average_policy_0"
        elif agent_id == 1:
            if random_sample < anticipatory_param:
                return "best_response_1"
            return "average_policy_1"
        else:
            raise ValueError(f"unexpected agent_id: {agent_id}")

    def assert_not_called(agent_id):
        assert False, "This function should never be called."

    tmp_env = env_class(env_config=env_config)

    def _create_env():
        return env_class(env_config=env_config)

    avg_policy_model_config = get_trainer_config(tmp_env)["model"]

    avg_trainer_config = merge_dicts({
        "log_level": "DEBUG",
        "framework": "torch",
        "env": env_class,
        "env_config": env_config,
        "num_gpus": 0.0,
        "num_gpus_per_worker": 0.0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["average_policy_0", "average_policy_1"],
            "policies": {
                "average_policy_0": (
                policy_classes["average_policy"], tmp_env.observation_space, tmp_env.action_space, {
                    "model": avg_policy_model_config
                }),
                "average_policy_1": (
                policy_classes["average_policy"], tmp_env.observation_space, tmp_env.action_space, {
                    "model": avg_policy_model_config
                }),
            },
            "policy_mapping_fn": assert_not_called,
        },

    }, get_avg_trainer_config(tmp_env))

    avg_trainer = avg_trainer_class(config=avg_trainer_config,
                                    logger_creator=get_trainer_logger_creator(
                                        base_dir=results_dir,
                                        scenario_name=f"{scenario_name}_avg_trainer",
                                        should_log_result_fn=should_log_result_fn))

    store_to_avg_policy_buffer = get_store_to_avg_policy_buffer_fn(nfsp_trainer=avg_trainer)

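    # Callbacks that wire NFSP together: route best-response experience into the average-policy
    # buffer, track BR-vs-average-policy rewards, and periodically compute exploitability and
    # write average-policy checkpoints.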
    class NFSPBestResponseCallbacks(DefaultCallbacks):

        def on_postprocess_trajectory(self, *, worker: "RolloutWorker", episode: MultiAgentEpisode, agent_id: AgentID,
                                      policy_id: PolicyID, policies: Dict[PolicyID, Policy],
                                      postprocessed_batch: SampleBatch,
                                      original_batches: Dict[Any, Tuple[Policy, SampleBatch]],
                                      **kwargs):
            super().on_postprocess_trajectory(worker=worker, episode=episode, agent_id=agent_id, policy_id=policy_id,
                                              policies=policies, postprocessed_batch=postprocessed_batch,
                                              original_batches=original_batches, **kwargs)

            postprocessed_batch.data["source_policy"] = [policy_id] * len(postprocessed_batch.data["rewards"])

            # All data from both policies will go into the best response's replay buffer.
            # Here we ensure that batches produced by the average policies get the exact same
            # postprocessing as the best response's own batches.
            for average_policy_id, br_policy_id in [("average_policy_0", "best_response_0"),
                                                    ("average_policy_1", "best_response_1")]:
                if policy_id == average_policy_id:

                    if "action_probs" in postprocessed_batch:
                        del postprocessed_batch.data["action_probs"]
                    if "behaviour_logits" in postprocessed_batch:
                        del postprocessed_batch.data["behaviour_logits"]

                    br_policy: Policy = policies[br_policy_id]

                    new_batch = br_policy.postprocess_trajectory(
                        sample_batch=postprocessed_batch,
                        other_agent_batches=original_batches,
                        episode=episode)
                    copy_attributes(src_obj=new_batch, dst_obj=postprocessed_batch)
                elif policy_id == br_policy_id:
                    if "q_values" in postprocessed_batch:
                        del postprocessed_batch.data["q_values"]
                    if "action_probs" in postprocessed_batch:
                        del postprocessed_batch.data["action_probs"]
                    del postprocessed_batch.data["action_dist_inputs"]

                if policy_id in ("average_policy_0", "best_response_0"):
                    assert agent_id == 0
                if policy_id in ("average_policy_1", "best_response_1"):
                    assert agent_id == 1

        def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy],
                           episode: MultiAgentEpisode, env_index: int, **kwargs):
            super().on_episode_end(worker=worker, base_env=base_env, policies=policies, episode=episode,
                                   env_index=env_index, **kwargs)

            episode_policies = set(episode.agent_rewards.keys())
            if episode_policies == {(0, "average_policy_0"), (1, "best_response_1")}:
                worker.avg_br_reward_deque.add.remote(episode.agent_rewards[(1, "best_response_1")])
            elif episode_policies == {(1, "average_policy_1"), (0, "best_response_0")}:
                worker.avg_br_reward_deque.add.remote(episode.agent_rewards[(0, "best_response_0")])

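        # Send each best-response batch to the average-policy buffer, then fold any
        # average-policy samples into the corresponding best-response batch.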
        def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs):
            super().on_sample_end(worker=worker, samples=samples, **kwargs)
            assert isinstance(samples, MultiAgentBatch)

            for policy_samples in samples.policy_batches.values():
                if "action_prob" in policy_samples.data:
                    del policy_samples.data["action_prob"]
                if "action_logp" in policy_samples.data:
                    del policy_samples.data["action_logp"]

            for average_policy_id, br_policy_id in [("average_policy_0", "best_response_0"),
                                                    ("average_policy_1", "best_response_1")]:
                for policy_id, policy_samples in samples.policy_batches.items():
                    if policy_id == br_policy_id:
                        store_to_avg_policy_buffer(MultiAgentBatch(policy_batches={
                            average_policy_id: policy_samples
                        }, env_steps=policy_samples.count))
                if average_policy_id in samples.policy_batches:

                    if br_policy_id in samples.policy_batches:
                        all_policies_samples = samples.policy_batches[br_policy_id].concat(
                            other=samples.policy_batches[average_policy_id])
                    else:
                        all_policies_samples = samples.policy_batches[average_policy_id]
                    del samples.policy_batches[average_policy_id]
                    samples.policy_batches[br_policy_id] = all_policies_samples

        def on_train_result(self, *, trainer, result: dict, **kwargs):
            super().on_train_result(trainer=trainer, result=result, **kwargs)
            result["scenario_name"] = trainer.scenario_name
            result["avg_br_reward_both_players"] = ray.get(trainer.avg_br_reward_deque.get_mean.remote())

            training_iteration = result["training_iteration"]
            if (calculate_openspiel_metanash and
                    (training_iteration == 1 or training_iteration % calc_metanash_every_n_iters == 0)):
                base_env = _create_env()
                open_spiel_env_config = base_env.open_spiel_env_config
                openspiel_game_version = base_env.game_version
                local_avg_policy_0 = trainer.workers.local_worker().policy_map["average_policy_0"]
                local_avg_policy_1 = trainer.workers.local_worker().policy_map["average_policy_1"]
                exploitability = nfsp_measure_exploitability_nonlstm(
                    rllib_policies=[local_avg_policy_0, local_avg_policy_1],
                    poker_game_version=openspiel_game_version,
                    open_spiel_env_config=open_spiel_env_config
                )
                result["avg_policy_exploitability"] = exploitability
                logger.info(colored(
                    f"(Graph this in a notebook) Exploitability: {exploitability} - Saving exploitability stats "
                    f"to {os.path.join(trainer.logdir, 'result.json')}", "green"))

            if checkpoint_every_n_iters and (training_iteration % checkpoint_every_n_iters == 0 or training_iteration == 1):
                for player in range(2):
                    checkpoint_metadata = create_metadata_with_new_checkpoint(
                        policy_id_to_save=f"average_policy_{player}",
                        br_trainer=br_trainer,
                        save_dir=checkpoint_dir(trainer=br_trainer),
                        timesteps_training=result["timesteps_total"],
                        episodes_training=result["episodes_total"],
                        checkpoint_name=f"average_policy_player_{player}_iter_{training_iteration}.h5"
                    )
                    avg_pol_checkpoint_spec = StrategySpec(
                        strategy_id=f"avg_pol_player_{player}_iter_{training_iteration}",
                        metadata=checkpoint_metadata)
                    checkpoint_path = os.path.join(spec_checkpoint_dir(br_trainer),
                                                   f"average_policy_player_{player}_iter_{training_iteration}.json")
                    ensure_dir(checkpoint_path)
                    with open(checkpoint_path, "+w") as checkpoint_spec_file:
                        checkpoint_spec_file.write(avg_pol_checkpoint_spec.to_json())

    br_trainer_config = {
        "log_level": "DEBUG",
        "callbacks": NFSPBestResponseCallbacks,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        "num_gpus": 0.0,
        "num_workers": 0,
        "num_gpus_per_worker": 0.0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["best_response_0", "best_response_1"],
            "policies": {
                "average_policy_0": (
                policy_classes["average_policy"], tmp_env.observation_space, tmp_env.action_space, {
                    "model": avg_policy_model_config,
                    "explore": False,
                }),
                "best_response_0": (
                policy_classes["best_response"], tmp_env.observation_space, tmp_env.action_space, {}),

                "average_policy_1": (
                policy_classes["average_policy"], tmp_env.observation_space, tmp_env.action_space, {
                    "model": avg_policy_model_config,
                    "explore": False,
                }),
                "best_response_1": (
                policy_classes["best_response"], tmp_env.observation_space, tmp_env.action_space, {}),
            },
            "policy_mapping_fn": select_policy,
        },
    }
    br_trainer_config = merge_dicts(br_trainer_config, get_trainer_config(tmp_env))

    br_trainer = trainer_class(config=br_trainer_config,
                               logger_creator=get_trainer_logger_creator(base_dir=results_dir,
                                                                         scenario_name=scenario_name,
                                                                         should_log_result_fn=should_log_result_fn))

    avg_br_reward_deque = StatDeque.remote(max_items=br_trainer_config["metrics_smoothing_episodes"])

    def _set_avg_br_rew_deque(worker: RolloutWorker):
        worker.avg_br_reward_deque = avg_br_reward_deque

    br_trainer.workers.foreach_worker(_set_avg_br_rew_deque)
    br_trainer.avg_br_reward_deque = avg_br_reward_deque

    # scenario_name is logged in the on_train_result callback above.
    br_trainer.scenario_name = scenario_name

    br_trainer.latest_avg_trainer_result = None
    train_iter_count = 0

    for trainer in [br_trainer, avg_trainer]:
        for policy_id, policy in trainer.workers.local_worker().policy_map.items():
            policy.policy_id = policy_id

    avg_weights = avg_trainer.get_weights(["average_policy_0", "average_policy_1"])
    br_trainer.workers.foreach_worker(lambda worker: worker.set_weights(avg_weights))

    stopping_condition: StoppingCondition = nfsp_get_stopping_condition()

    print("starting")
    while True:
        print("avg train...")
        avg_train_results = avg_trainer.train()
        avg_weights = avg_trainer.get_weights(["average_policy_0", "average_policy_1"])
        br_trainer.workers.foreach_worker(lambda worker: worker.set_weights(avg_weights))
        br_trainer.latest_avg_trainer_result = copy.deepcopy(avg_train_results)
        print("br train...")
        train_iter_results = br_trainer.train()  # do a step (or several) in the main RL loop

        train_iter_count += 1
        print("printing results..")
        if print_train_results:
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"]["best_response_0"]:
                del train_iter_results["info"]["learner"]["best_response_0"]["td_error"]
            if "td_error" in train_iter_results["info"]["learner"]["best_response_1"]:
                del train_iter_results["info"]["learner"]["best_response_1"]["td_error"]
            log(pretty_dict_str(train_iter_results))
        log(f"Trainer logdir is {br_trainer.logdir}")

        if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
            print("stopping condition met.")
            break
Example #5
def train_nxdo_best_response(br_player: int,
                             scenario_name: str,
                             nxdo_manager_port: int,
                             nxdo_manager_host: str,
                             print_train_results: bool = True,
                             previous_br_checkpoint_path=None):
    scenario: NXDOScenario = scenario_catalog.get(scenario_name=scenario_name)
    if not isinstance(scenario, NXDOScenario):
        raise TypeError(f"Only instances of {NXDOScenario} can be used here. {scenario.name} is a {type(scenario)}.")

    use_openspiel_restricted_game: bool = scenario.use_openspiel_restricted_game
    get_restricted_game_custom_model = scenario.get_restricted_game_custom_model

    env_class = scenario.env_class
    base_env_config = scenario.env_config
    trainer_class = scenario.trainer_class_br
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes_br
    get_trainer_config = scenario.get_trainer_config_br
    nxdo_br_get_stopping_condition = scenario.get_stopping_condition_br
    should_log_result_fn = scenario.ray_should_log_result_filter
    nxdo_metanash_method: str = scenario.xdo_metanash_method
    if nxdo_metanash_method != "nfsp":
        raise NotImplementedError("Only 'nfsp' is currently supported for the nxdo_metanash_method")

    nxdo_manager = RemoteNXDOManagerClient(n_players=2,
                                           port=nxdo_manager_port,
                                           remote_server_host=nxdo_manager_host)

    manager_metadata = nxdo_manager.get_manager_metadata()
    results_dir = nxdo_manager.get_log_dir()

    br_params = nxdo_manager.claim_new_active_policy_for_player(player=br_player)
    metanash_specs_for_players, delegate_specs_for_players, active_policy_num = br_params

    other_player = 1 - br_player
    br_learner_name = f"policy {active_policy_num} player {br_player}"

    def log(message):
        print(f"({br_learner_name}): {message}")

    def select_policy(agent_id):
        if agent_id == br_player:
            return "best_response"
        elif agent_id == other_player:
            return "metanash"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

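    # The restricted game wraps the base env; when a metanash strategy is provided, the metanash
    # player's actions index into delegate (previously trained best-response) policies.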
    restricted_env_config = {
        "create_env_fn": lambda: env_class(env_config=base_env_config),
        "raise_if_no_restricted_players": metanash_specs_for_players is not None,
    }
    tmp_base_env = env_class(env_config=base_env_config)

    if use_openspiel_restricted_game:
        restricted_game_class = OpenSpielRestrictedGame
    else:
        restricted_game_class = RestrictedGame
        restricted_env_config["use_delegate_policy_exploration"] = scenario.allow_stochastic_best_responses

    tmp_env = restricted_game_class(env_config=restricted_env_config)

    if metanash_specs_for_players is None or use_openspiel_restricted_game:
        other_player_restricted_action_space = tmp_env.base_action_space
        metanash_class = policy_classes["best_response"]
    else:
        other_player_restricted_action_space = Discrete(n=len(delegate_specs_for_players[other_player]))
        metanash_class = policy_classes["metanash"]
        print(
            f"metanash class: {metanash_class}, other_player_restricted_action_space: {other_player_restricted_action_space}")

    if metanash_specs_for_players is None and use_openspiel_restricted_game:
        other_player_restricted_obs_space = tmp_env.base_observation_space
    else:
        other_player_restricted_obs_space = tmp_env.observation_space

    trainer_config = {
        "env": restricted_game_class,
        "env_config": restricted_env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": [f"best_response"],
            "policies": {
                f"metanash": (metanash_class, other_player_restricted_obs_space, other_player_restricted_action_space,
                              {"explore": False}),

                f"metanash_delegate": (policy_classes["best_response"], tmp_env.base_observation_space, tmp_env.base_action_space,
                            {"explore": scenario.allow_stochastic_best_responses}),

                f"best_response": (policy_classes["best_response"], tmp_env.base_observation_space, tmp_env.base_action_space, {}),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    if metanash_specs_for_players is not None and get_restricted_game_custom_model is not None:
        trainer_config["multiagent"]["policies"]["metanash"][3]["model"] = {
            "custom_model": get_restricted_game_custom_model(tmp_env)}

    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_base_env))

    ray_head_address = manager_metadata["ray_head_address"]
    init_ray_for_scenario(scenario=scenario, head_address=ray_head_address, logging_level=logging.INFO)

    trainer = trainer_class(config=trainer_config, logger_creator=get_trainer_logger_creator(
        base_dir=results_dir, scenario_name=scenario_name,
        should_log_result_fn=should_log_result_fn))

    # The metanash opponent is a single pure-strategy spec; load it onto every rollout worker.
    def _set_worker_metanash(worker: RolloutWorker):
        if metanash_specs_for_players is not None:
            metanash_policy = worker.policy_map["metanash"]
            metanash_strategy_spec: StrategySpec = metanash_specs_for_players[other_player]
            load_pure_strat(policy=metanash_policy, pure_strat_spec=metanash_strategy_spec)
    trainer.workers.foreach_worker(_set_worker_metanash)

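    # If delegate specs are provided, install converters that translate the metanash player's
    # restricted-game actions into base-game actions via the "metanash_delegate" policy.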
    trainer.weights_cache = {}
    if delegate_specs_for_players:
        if use_openspiel_restricted_game:
            set_restricted_game_conversions_for_all_workers_openspiel(
                trainer=trainer,
                tmp_base_env=tmp_base_env,
                delegate_policy_id="metanash_delegate",
                agent_id_to_restricted_game_specs={
                    other_player: delegate_specs_for_players[other_player]
                },
                load_policy_spec_fn=load_pure_strat)
        else:
            set_restricted_game_conversations_for_all_workers(
                trainer=trainer,
                delegate_policy_id="metanash_delegate",
                agent_id_to_restricted_game_specs={
                    other_player: delegate_specs_for_players[other_player]
                },
                load_policy_spec_fn=create_get_pure_strat_cached(cache=trainer.weights_cache))

    log(f"got policy {active_policy_num}")

    if previous_br_checkpoint_path is not None:
        def _set_br_initial_weights(worker: RolloutWorker):
            br_policy = worker.policy_map["best_response"]
            load_pure_strat(policy=br_policy, checkpoint_path=previous_br_checkpoint_path)
        trainer.workers.foreach_worker(_set_br_initial_weights)

    # Perform main RL training loop. Stop according to our StoppingCondition.
    stopping_condition: StoppingCondition = nxdo_br_get_stopping_condition()

    while True:
        train_iter_results = trainer.train()  # do a step (or several) in the main RL loop

        if print_train_results:
            train_iter_results["p2sro_active_policy_num"] = active_policy_num
            train_iter_results["best_response_player"] = br_player
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"][f"best_response"]:
                del train_iter_results["info"]["learner"][f"best_response"]["td_error"]
            log(f"Trainer log dir is {trainer.logdir}")
            log(pretty_dict_str(train_iter_results))

        total_timesteps_training_br = train_iter_results["timesteps_total"]
        total_episodes_training_br = train_iter_results["episodes_total"]
        br_reward_this_iter = train_iter_results["policy_reward_mean"][f"best_response"]

        if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
            log("Stopping condition met.")
            break

    log(f"Training stopped. Setting active policy {active_policy_num} as fixed.")

    final_policy_metadata = create_metadata_with_new_checkpoint_for_current_best_response(
            trainer=trainer, player=br_player, save_dir=checkpoint_dir(trainer=trainer),
            timesteps_training_br=total_timesteps_training_br,
            episodes_training_br=total_episodes_training_br,
            active_policy_num=active_policy_num,
            average_br_reward=float(br_reward_this_iter),
        )
    nxdo_manager.submit_final_br_policy(
        player=br_player, policy_num=active_policy_num,
        metadata_dict=final_policy_metadata)

    # trainer.cleanup()
    # del trainer
    ray.shutdown()
    time.sleep(10)

    # wait for both player policies to be fixed.
    for player_to_wait_on in range(2):
        wait_count = 0
        while True:
            if nxdo_manager.is_policy_fixed(player=player_to_wait_on, policy_num=active_policy_num):
                break
            if wait_count % 10 == 0:
                log(f"Waiting for policy {active_policy_num} player {player_to_wait_on} to become fixed")
            time.sleep(2.0)
            wait_count += 1

    return final_policy_metadata["checkpoint_path"]
Example #6
def train_poker_approx_best_response_xdfo(
        br_player: int,
        ray_head_address,
        scenario: NXDOScenario,
        general_trainer_config_overrrides,
        br_policy_config_overrides: dict,
        get_stopping_condition: Callable[[], StoppingCondition],
        metanash_specs_for_players: Dict[int, StrategySpec],
        delegate_specs_for_players: Dict[int, List[StrategySpec]],
        results_dir: str,
        print_train_results: bool = True):
    use_openspiel_restricted_game: bool = scenario.use_openspiel_restricted_game
    get_restricted_game_custom_model = scenario.get_restricted_game_custom_model

    env_class = scenario.env_class
    base_env_config = scenario.env_config
    trainer_class = scenario.trainer_class_br
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes_br
    get_trainer_config = scenario.get_trainer_config_br
    should_log_result_fn = scenario.ray_should_log_result_filter
    nxdo_metanash_method: str = scenario.xdo_metanash_method
    if nxdo_metanash_method != "nfsp":
        raise NotImplementedError(
            "Only 'nfsp' is currently supported for the nxdo_metanash_method")

    other_player = 1 - br_player
    br_learner_name = f"approx br player {br_player}"

    def log(message, level=logging.INFO):
        logger.log(level, f"({br_learner_name}): {message}")

    def select_policy(agent_id):
        if agent_id == br_player:
            return "best_response"
        elif agent_id == other_player:
            return "metanash"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

    restricted_env_config = {
        "create_env_fn": lambda: env_class(env_config=base_env_config),
        "raise_if_no_restricted_players": metanash_specs_for_players
        is not None
    }
    tmp_base_env = env_class(env_config=base_env_config)

    if use_openspiel_restricted_game:
        restricted_game_class = OpenSpielRestrictedGame
    else:
        restricted_game_class = RestrictedGame
        restricted_env_config[
            "use_delegate_policy_exploration"] = scenario.allow_stochastic_best_responses

    tmp_env = restricted_game_class(env_config=restricted_env_config)

    if metanash_specs_for_players is None or use_openspiel_restricted_game:
        other_player_restricted_action_space = tmp_env.base_action_space
    else:
        other_player_restricted_action_space = Discrete(
            n=len(delegate_specs_for_players[other_player]))

    if metanash_specs_for_players is None and use_openspiel_restricted_game:
        other_player_restricted_obs_space = tmp_env.base_observation_space
    else:
        other_player_restricted_obs_space = tmp_env.observation_space

    trainer_config = {
        "env": restricted_game_class,
        "env_config": restricted_env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": [f"best_response"],
            "policies": {
                f"metanash":
                (policy_classes["metanash"], other_player_restricted_obs_space,
                 other_player_restricted_action_space, {
                     "explore": False
                 }),
                f"metanash_delegate":
                (policy_classes["best_response"],
                 tmp_env.base_observation_space, tmp_env.base_action_space, {
                     "explore": scenario.allow_stochastic_best_responses
                 }),
                f"best_response":
                (policy_classes["best_response"],
                 tmp_env.base_observation_space, tmp_env.base_action_space,
                 br_policy_config_overrides),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    if metanash_specs_for_players is not None:
        if get_restricted_game_custom_model is not None:
            restricted_game_custom_model = get_restricted_game_custom_model(
                tmp_base_env)
        else:
            restricted_game_custom_model = None
        trainer_config["multiagent"]["policies"]["metanash"][3]["model"] = {
            "custom_model": restricted_game_custom_model
        }

    trainer_config = merge_dicts(trainer_config,
                                 get_trainer_config(tmp_base_env))
    trainer_config = merge_dicts(trainer_config,
                                 general_trainer_config_overrrides)

    init_ray_for_scenario(scenario=scenario,
                          head_address=ray_head_address,
                          logging_level=logging.INFO)

    trainer = trainer_class(config=trainer_config,
                            logger_creator=get_trainer_logger_creator(
                                base_dir=results_dir,
                                scenario_name="approx_br",
                                should_log_result_fn=should_log_result_fn))

    # The metanash opponent is a single pure-strategy spec; load it onto every rollout worker.
    def _set_worker_metanash(worker: RolloutWorker):
        if metanash_specs_for_players is not None:
            metanash_policy = worker.policy_map["metanash"]
            metanash_strategy_spec: StrategySpec = metanash_specs_for_players[
                other_player]
            load_pure_strat(policy=metanash_policy,
                            pure_strat_spec=metanash_strategy_spec)

    trainer.workers.foreach_worker(_set_worker_metanash)

    trainer.weights_cache = {}
    if delegate_specs_for_players:
        if use_openspiel_restricted_game:
            set_restricted_game_conversions_for_all_workers_openspiel(
                trainer=trainer,
                tmp_base_env=tmp_base_env,
                delegate_policy_id="metanash_delegate",
                agent_id_to_restricted_game_specs={
                    other_player: delegate_specs_for_players[other_player]
                },
                load_policy_spec_fn=load_pure_strat)
        else:
            set_restricted_game_conversations_for_all_workers(
                trainer=trainer,
                delegate_policy_id="metanash_delegate",
                agent_id_to_restricted_game_specs={
                    other_player: delegate_specs_for_players[other_player]
                },
                load_policy_spec_fn=create_get_pure_strat_cached(
                    cache=trainer.weights_cache))

    # Perform main RL training loop. Stop according to our StoppingCondition.
    train_iter_count = 0
    stopping_condition: StoppingCondition = get_stopping_condition()
    max_reward = None

    while True:
        train_iter_results = trainer.train(
        )  # do a step (or several) in the main RL loop
        train_iter_count += 1

        if print_train_results:
            train_iter_results["best_response_player"] = br_player
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"][
                    f"best_response"]:
                del train_iter_results["info"]["learner"][f"best_response"][
                    "td_error"]
            log(f"Trainer log dir is {trainer.logdir}")
            print(pretty_dict_str(train_iter_results))

        br_reward_this_iter = train_iter_results["policy_reward_mean"][
            f"best_response"]
        if max_reward is None or br_reward_this_iter > max_reward:
            max_reward = br_reward_this_iter
        if stopping_condition.should_stop_this_iter(
                latest_trainer_result=train_iter_results):
            log("Stopping condition met.")
            break

    log(f"Training stopped.")

    # trainer.cleanup()
    # del trainer
    ray.shutdown()
    time.sleep(10)

    return max_reward
def train_off_policy_rl_nfsp_restricted_game(results_dir: str,
                                             scenario: NXDOScenario,
                                             player_to_base_game_action_specs: Dict[int, List[StrategySpec]],
                                             stopping_condition: StoppingCondition,
                                             manager_metadata: Union[dict, None],
                                             print_train_results: bool = True):

    use_openspiel_restricted_game: bool = scenario.use_openspiel_restricted_game
    get_restricted_game_custom_model = scenario.get_restricted_game_custom_model
    env_class = scenario.env_class
    base_env_config = scenario.env_config
    trainer_class = scenario.trainer_class_nfsp
    avg_trainer_class = scenario.avg_trainer_class_nfsp
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes_nfsp
    anticipatory_param: float = scenario.anticipatory_param_nfsp
    get_trainer_config = scenario.get_trainer_config_nfsp
    get_avg_trainer_config = scenario.get_avg_trainer_config_nfsp
    get_trainer_config_br = scenario.get_trainer_config_br
    calculate_openspiel_metanash: bool = scenario.calculate_openspiel_metanash
    calculate_openspiel_metanash_at_end: bool = scenario.calculate_openspiel_metanash_at_end
    calc_metanash_every_n_iters: int = scenario.calc_metanash_every_n_iters
    should_log_result_fn = scenario.ray_should_log_result_filter
    metrics_smoothing_episodes_override: int = scenario.metanash_metrics_smoothing_episodes_override

    assert scenario.xdo_metanash_method == "nfsp"

    ray_head_address = manager_metadata.get("ray_head_address", None) if manager_metadata is not None else None
    init_ray_for_scenario(scenario=scenario, head_address=ray_head_address, logging_level=logging.INFO)

    def select_policy(agent_id):
        random_sample = np.random.random()
        if agent_id == 0:
            if random_sample < anticipatory_param:
                return "best_response_0"
            return "average_policy_0"
        elif agent_id == 1:
            if random_sample < anticipatory_param:
                return "best_response_1"
            return "average_policy_1"
        else:
            raise ValueError(f"unexpected agent_id: {agent_id}")

    def assert_not_called(agent_id):
        assert False, "This function should never be called."

    def _create_base_env():
        return env_class(env_config=base_env_config)

    tmp_base_env = _create_base_env()
    restricted_env_config = {"create_env_fn": _create_base_env}

    if use_openspiel_restricted_game:
        restricted_game_class = OpenSpielRestrictedGame
        tmp_env = restricted_game_class(env_config=restricted_env_config)
        restricted_game_action_spaces = [tmp_env.base_action_space for _ in range(2)]
    else:
        restricted_game_class = RestrictedGame
        restricted_env_config["use_delegate_policy_exploration"] = scenario.allow_stochastic_best_responses
        tmp_env = restricted_game_class(env_config=restricted_env_config)
        restricted_game_action_spaces = [Discrete(n=len(player_to_base_game_action_specs[p])) for p in range(2)]

    assert all(restricted_game_action_spaces[0] == space for space in restricted_game_action_spaces)

    print(f"\n\n\n\n\nRestricted game action spaces {restricted_game_action_spaces}\n\n\n\n\n\n")

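    # Pull exploration_config out of the scenario-level configs so exploration can be
    # configured per policy below rather than trainer-wide.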
    scenario_avg_trainer_config = get_avg_trainer_config(tmp_base_env)
    scenario_avg_trainer_config_exploration_config = scenario_avg_trainer_config.get("exploration_config", {})
    if scenario_avg_trainer_config_exploration_config:
        del scenario_avg_trainer_config["exploration_config"]

    scenario_trainer_config = get_trainer_config(tmp_base_env)
    scenario_trainer_config_exploration_config = scenario_trainer_config.get("exploration_config", {})
    if scenario_trainer_config_exploration_config:
        del scenario_trainer_config["exploration_config"]

    delegate_policy_config = merge_dicts(get_trainer_config_br(tmp_base_env), {"explore": scenario.allow_stochastic_best_responses})

    avg_trainer_config = merge_dicts({
        "log_level": "DEBUG",
        "framework": "torch",
        "env": restricted_game_class,
        "env_config": restricted_env_config,
        "num_gpus": 0.0,
        "num_gpus_per_worker": 0.0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["average_policy_0", "average_policy_1"],
            "policies": {
                "average_policy_0": (
                policy_classes["average_policy"], tmp_env.observation_space, restricted_game_action_spaces[0],
                {"explore": False, "exploration_config": scenario_avg_trainer_config_exploration_config}),

                "average_policy_1": (
                policy_classes["average_policy"], tmp_env.observation_space, restricted_game_action_spaces[1],
                {"explore": False, "exploration_config": scenario_avg_trainer_config_exploration_config}),

                "delegate_policy": (
                policy_classes["delegate_policy"], tmp_base_env.observation_space, tmp_env.base_action_space,
                delegate_policy_config),
            },
            "policy_mapping_fn": assert_not_called,
        },

    }, scenario_avg_trainer_config)
    for _policy_id in ["average_policy_0", "average_policy_1"]:
        if get_restricted_game_custom_model is not None:
            avg_trainer_config["multiagent"]["policies"][_policy_id][3]["model"] = {
                "custom_model": get_restricted_game_custom_model(tmp_env)}

    avg_trainer = avg_trainer_class(config=avg_trainer_config,
                                    logger_creator=get_trainer_logger_creator(
                                        base_dir=results_dir,
                                        scenario_name=f"nfsp_restricted_game_avg_trainer",
                                        should_log_result_fn=should_log_result_fn))

    store_to_avg_policy_buffer = get_store_to_avg_policy_buffer_fn(nfsp_trainer=avg_trainer)

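    # Callbacks that route best-response experience into the average-policy buffer, track
    # BR-vs-average-policy rewards, and periodically measure exploitability through the
    # restricted-game converters.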
    class NFSPBestResponseCallbacks(DefaultCallbacks):

        def on_postprocess_trajectory(self, *, worker: "RolloutWorker", episode: MultiAgentEpisode, agent_id: AgentID,
                                      policy_id: PolicyID, policies: Dict[PolicyID, Policy],
                                      postprocessed_batch: SampleBatch,
                                      original_batches: Dict[Any, Tuple[Policy, SampleBatch]],
                                      **kwargs):
            super().on_postprocess_trajectory(worker=worker, episode=episode, agent_id=agent_id, policy_id=policy_id,
                                              policies=policies, postprocessed_batch=postprocessed_batch,
                                              original_batches=original_batches, **kwargs)

            postprocessed_batch.data["source_policy"] = [policy_id] * len(postprocessed_batch.data["rewards"])

            # All data from both policies will go into the best response's replay buffer.
            # Here we ensure that batches produced by the average policies get the exact same
            # postprocessing as the best response's own batches.
            for average_policy_id, br_policy_id in [("average_policy_0", "best_response_0"),
                                                    ("average_policy_1", "best_response_1")]:
                if policy_id == average_policy_id:

                    if "action_probs" in postprocessed_batch:
                        del postprocessed_batch.data["action_probs"]
                    if "behaviour_logits" in postprocessed_batch:
                        del postprocessed_batch.data["behaviour_logits"]

                    br_policy: Policy = policies[br_policy_id]

                    new_batch = br_policy.postprocess_trajectory(
                        sample_batch=postprocessed_batch,
                        other_agent_batches=original_batches,
                        episode=episode)
                    copy_attributes(src_obj=new_batch, dst_obj=postprocessed_batch)
                elif policy_id == br_policy_id:
                    if "q_values" in postprocessed_batch:
                        del postprocessed_batch.data["q_values"]
                    if "action_probs" in postprocessed_batch:
                        del postprocessed_batch.data["action_probs"]
                    del postprocessed_batch.data["action_dist_inputs"]

                if policy_id in ("average_policy_0", "best_response_0"):
                    assert agent_id == 0
                if policy_id in ("average_policy_1", "best_response_1"):
                    assert agent_id == 1

        def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs):
            super().on_sample_end(worker=worker, samples=samples, **kwargs)
            assert isinstance(samples, MultiAgentBatch)

            for policy_samples in samples.policy_batches.values():
                if "action_prob" in policy_samples.data:
                    del policy_samples.data["action_prob"]
                if "action_logp" in policy_samples.data:
                    del policy_samples.data["action_logp"]

            for average_policy_id, br_policy_id in [("average_policy_0", "best_response_0"),
                                                    ("average_policy_1", "best_response_1")]:
                for policy_id, policy_samples in samples.policy_batches.items():
                    if policy_id == br_policy_id:
                        store_to_avg_policy_buffer(MultiAgentBatch(policy_batches={
                            average_policy_id: policy_samples
                        }, env_steps=policy_samples.count))
                if average_policy_id in samples.policy_batches:

                    if br_policy_id in samples.policy_batches:
                        all_policies_samples = samples.policy_batches[br_policy_id].concat(
                            other=samples.policy_batches[average_policy_id])
                    else:
                        all_policies_samples = samples.policy_batches[average_policy_id]
                    del samples.policy_batches[average_policy_id]
                    samples.policy_batches[br_policy_id] = all_policies_samples

        def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv, policies: Dict[PolicyID, Policy],
                           episode: MultiAgentEpisode, env_index: int, **kwargs):
            super().on_episode_end(worker=worker, base_env=base_env, policies=policies, episode=episode,
                                   env_index=env_index, **kwargs)
            episode_policies = set(episode.agent_rewards.keys())
            if episode_policies == {(0, "average_policy_0"), (1, "best_response_1")}:
                worker.avg_br_reward_deque.add.remote(episode.agent_rewards[(1, "best_response_1")])
            elif episode_policies == {(1, "average_policy_1"), (0, "best_response_0")}:
                worker.avg_br_reward_deque.add.remote(episode.agent_rewards[(0, "best_response_0")])

        def on_train_result(self, *, trainer, result: dict, **kwargs):
            super().on_train_result(trainer=trainer, result=result, **kwargs)
            training_iteration = result["training_iteration"]

            result["avg_br_reward_both_players"] = ray.get(trainer.avg_br_reward_deque.get_mean.remote())

            if (calculate_openspiel_metanash and
                    (training_iteration == 1 or training_iteration % calc_metanash_every_n_iters == 0)):
                base_env = _create_base_env()
                open_spiel_env_config = base_env.open_spiel_env_config
                openspiel_game_version = base_env.game_version
                local_avg_policy_0 = trainer.workers.local_worker().policy_map["average_policy_0"]
                local_avg_policy_1 = trainer.workers.local_worker().policy_map["average_policy_1"]
                exploitability = nxdo_nfsp_measure_exploitability_nonlstm(
                    rllib_policies=[local_avg_policy_0, local_avg_policy_1],
                    poker_game_version=openspiel_game_version,
                    restricted_game_convertors=trainer.get_local_converters(),
                    open_spiel_env_config=open_spiel_env_config,
                    use_delegate_policy_exploration=scenario.allow_stochastic_best_responses
                )
                result["avg_policy_exploitability"] = exploitability

    br_trainer_config = {
        "log_level": "DEBUG",
        "callbacks": NFSPBestResponseCallbacks,
        "env": restricted_game_class,
        "env_config": restricted_env_config,
        "gamma": 1.0,
        "num_gpus": 0.0,
        "num_workers": 0,
        "num_gpus_per_worker": 0.0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": ["best_response_0", "best_response_1"],
            "policies": {
                "average_policy_0": (
                policy_classes["average_policy"], tmp_env.observation_space, restricted_game_action_spaces[0],
                {"explore": False, "exploration_config": scenario_avg_trainer_config_exploration_config}),

                "best_response_0": (
                policy_classes["best_response"], tmp_env.observation_space, restricted_game_action_spaces[0],
                {"exploration_config": scenario_trainer_config_exploration_config}),

                "average_policy_1": (
                policy_classes["average_policy"], tmp_env.observation_space, restricted_game_action_spaces[1],
                {"explore": False, "exploration_config": scenario_avg_trainer_config_exploration_config}),

                "best_response_1": (
                policy_classes["best_response"], tmp_env.observation_space, restricted_game_action_spaces[1],
                {"exploration_config": scenario_trainer_config_exploration_config}),

                "delegate_policy": (
                policy_classes["delegate_policy"], tmp_base_env.observation_space, tmp_env.base_action_space,
                delegate_policy_config),
            },
            "policy_mapping_fn": select_policy,
        },
    }
    assert all(restricted_game_action_spaces[0] == space for space in restricted_game_action_spaces), \
        "All players are assumed to share the same restricted-game action space; if they don't, " \
        "the get_trainer_config-based settings merged below need per-player handling."
    br_trainer_config = merge_dicts(br_trainer_config, scenario_trainer_config)
    for _policy_id in ["average_policy_0", "average_policy_1", "best_response_0", "best_response_1"]:
        if get_restricted_game_custom_model is not None:
            br_trainer_config["multiagent"]["policies"][_policy_id][3]["model"] = {
                "custom_model": get_restricted_game_custom_model(tmp_env)}

    br_trainer_config["metrics_smoothing_episodes"] = metrics_smoothing_episodes_override

    br_trainer = trainer_class(config=br_trainer_config,
                               logger_creator=get_trainer_logger_creator(
                                   base_dir=results_dir,
                                   scenario_name="nfsp_restricted_game_trainer",
                                   should_log_result_fn=should_log_result_fn))

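    # Shared Ray actor that accumulates best-response episode rewards from all rollout workers;
    # its running mean is reported as "avg_br_reward_both_players" in on_train_result above.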
    avg_br_reward_deque = StatDeque.remote(max_items=br_trainer_config["metrics_smoothing_episodes"])

    def _set_avg_br_rew_deque(worker: RolloutWorker):
        worker.avg_br_reward_deque = avg_br_reward_deque

    br_trainer.workers.foreach_worker(_set_avg_br_rew_deque)
    br_trainer.avg_br_reward_deque = avg_br_reward_deque

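    # Wire restricted-game <-> base-game conversions into every worker. With the OpenSpiel-based
    # restricted game, per-player observation conversions are precomputed once and shared; otherwise
    # each worker builds RestrictedToBaseGameActionSpaceConverter instances that use its own
    # delegate_policy copy to act in the base game.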
    if use_openspiel_restricted_game:
        local_delegate_policy = br_trainer.workers.local_worker().policy_map["delegate_policy"]
        player_converters = []
        for p in range(2):
            print("Creating restricted game obs conversions...")
            convertor = get_restricted_game_obs_conversions(player=p, delegate_policy=local_delegate_policy,
                                                            policy_specs=player_to_base_game_action_specs[p],
                                                            load_policy_spec_fn=create_get_pure_strat_cached(cache={}),
                                                            tmp_base_env=tmp_base_env)
            player_converters.append(convertor)
        for _trainer in [br_trainer, avg_trainer]:
            def _set_worker_converters(worker: RolloutWorker):
                worker_delegate_policy = worker.policy_map["delegate_policy"]
                for p in range(2):
                    worker.foreach_env(lambda env: env.set_obs_conversion_dict(p, player_converters[p]))
                worker_delegate_policy.player_converters = player_converters

            _trainer.workers.foreach_worker(_set_worker_converters)
            # Bind _trainer as a default argument so each trainer keeps its own accessor
            # (a bare closure would late-bind to the last trainer in this loop).
            _trainer.get_local_converters = lambda _trainer=_trainer: _trainer.workers.local_worker().policy_map[
                "delegate_policy"].player_converters
    else:
        weights_cache = {}
        for _trainer in [br_trainer, avg_trainer]:
            def _set_worker_converters(worker: RolloutWorker):
                worker_delegate_policy = worker.policy_map["delegate_policy"]
                player_converters = []
                for p in range(2):
                    player_converter = RestrictedToBaseGameActionSpaceConverter(
                        delegate_policy=worker_delegate_policy, policy_specs=player_to_base_game_action_specs[p],
                        load_policy_spec_fn=create_get_pure_strat_cached(cache=weights_cache))
                    player_converters.append(player_converter)
                    worker.foreach_env(lambda env: env.set_action_conversion(p, player_converter))
                worker_delegate_policy.player_converters = player_converters

            _trainer.workers.foreach_worker(_set_worker_converters)
            # Bind _trainer as a default argument so each trainer keeps its own accessor
            # (a bare closure would late-bind to the last trainer in this loop).
            _trainer.get_local_converters = lambda _trainer=_trainer: _trainer.workers.local_worker().policy_map[
                "delegate_policy"].player_converters

    br_trainer.latest_avg_trainer_result = None
    train_iter_count = 0

    for _trainer in [br_trainer, avg_trainer]:
        for policy_id, policy in _trainer.workers.local_worker().policy_map.items():
            policy.policy_id = policy_id

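    # Degenerate case: if the restricted game contains only a single pure strategy (checked for
    # player 0), there is nothing to train. Emit a zeroed result and invoke the callback once so
    # metrics (including exploitability, if enabled) are still produced.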
    if len(player_to_base_game_action_specs[0]) == 1:
        final_train_result = {"episodes_total": 0, "timesteps_total": 0, "training_iteration": 0}
        tmp_callback = NFSPBestResponseCallbacks()
        tmp_callback.on_train_result(trainer=br_trainer, result=final_train_result)
    else:
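        # Alternate the two trainers: the average-policy trainer takes a supervised step, its latest
        # weights are copied into the best-response trainer's frozen average policies, and then the
        # best-response trainer takes an RL step. Repeat until the stopping condition triggers.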
        avg_weights = avg_trainer.get_weights(["average_policy_0", "average_policy_1"])
        br_trainer.workers.foreach_worker(lambda worker: worker.set_weights(avg_weights))
        while True:
            avg_train_results = avg_trainer.train()
            avg_weights = avg_trainer.get_weights(["average_policy_0", "average_policy_1"])
            br_trainer.workers.foreach_worker(lambda worker: worker.set_weights(avg_weights))
            br_trainer.latest_avg_trainer_result = copy.deepcopy(avg_train_results)
            train_iter_results = br_trainer.train()  # do a step (or several) in the main RL loop

            train_iter_count += 1
            if print_train_results:
                # Delete verbose debugging info before printing
                if "hist_stats" in train_iter_results:
                    del train_iter_results["hist_stats"]
                if "td_error" in train_iter_results["info"]["learner"]["best_response_0"]:
                    del train_iter_results["info"]["learner"]["best_response_0"]["td_error"]
                if "td_error" in train_iter_results["info"]["learner"]["best_response_1"]:
                    del train_iter_results["info"]["learner"]["best_response_1"]["td_error"]
                print(pretty_dict_str(train_iter_results))
                print(f"Trainer logdir is {br_trainer.logdir}")

            if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
                print("stopping condition met.")
                final_train_result = deepcopy(train_iter_results)
                break

    if calculate_openspiel_metanash_at_end:
        base_env = _create_base_env()
        open_spiel_env_config = base_env.open_spiel_env_config
        openspiel_game_version = base_env.game_version
        local_avg_policy_0 = br_trainer.workers.local_worker().policy_map["average_policy_0"]
        local_avg_policy_1 = br_trainer.workers.local_worker().policy_map["average_policy_1"]
        exploitability = nxdo_nfsp_measure_exploitability_nonlstm(
            rllib_policies=[local_avg_policy_0, local_avg_policy_1],
            poker_game_version=openspiel_game_version,
            restricted_game_convertors=br_trainer.get_local_converters(),
            open_spiel_env_config=open_spiel_env_config,
            use_delegate_policy_exploration=scenario.allow_stochastic_best_responses
        )
        final_train_result["avg_policy_exploitability"] = exploitability

    if "avg_policy_exploitability" in final_train_result:
        print(f"\n\nexploitability: {final_train_result['avg_policy_exploitability']}\n\n")

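    # Checkpoint each player's final average policy and wrap it in a StrategySpec that also records
    # the delegate (base-game) policy specs needed to reconstruct the restricted game later.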
    avg_policy_specs = []
    for player in range(2):
        strategy_id = f"avg_policy_player_{player}_{datetime_str()}"

        checkpoint_path = save_nfsp_avg_policy_checkpoint(trainer=br_trainer,
                                                          policy_id_to_save=f"average_policy_{player}",
                                                          save_dir=checkpoint_dir(trainer=br_trainer),
                                                          timesteps_training=final_train_result["timesteps_total"],
                                                          episodes_training=final_train_result["episodes_total"],
                                                          checkpoint_name=f"{strategy_id}.h5")

        avg_policy_spec = StrategySpec(
            strategy_id=strategy_id,
            metadata={"checkpoint_path": checkpoint_path,
                      "delegate_policy_specs": [spec.to_json() for spec in player_to_base_game_action_specs[player]]
                      })
        avg_policy_specs.append(avg_policy_spec)

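    # Explicitly tear down both trainers and their Ray actors so resources are released before returning.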
    ray.kill(avg_trainer.workers.local_worker().replay_buffer_actor)
    avg_trainer.cleanup()
    br_trainer.cleanup()
    del avg_trainer
    del br_trainer
    del avg_br_reward_deque

    time.sleep(10)

    assert final_train_result is not None
    return avg_policy_specs, final_train_result


def train_poker_approx_best_response_psro(br_player,
                                          ray_head_address,
                                          scenario_name,
                                          general_trainer_config_overrrides,
                                          br_policy_config_overrides: dict,
                                          get_stopping_condition,
                                          metanash_policy_specs,
                                          metanash_weights,
                                          results_dir,
                                          print_train_results=True):
    scenario: PSROScenario = scenario_catalog.get(scenario_name=scenario_name)

    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    p2sro = scenario.p2sro
    get_trainer_config = scenario.get_trainer_config
    psro_get_stopping_condition = scenario.psro_get_stopping_condition
    mix_metanash_with_uniform_dist_coeff = scenario.mix_metanash_with_uniform_dist_coeff

    other_player = 1 - br_player

    br_learner_name = f"new_learner_{br_player}"

    def log(message, level=logging.INFO):
        logger.log(level, f"({br_learner_name}): {message}")

    def select_policy(agent_id):
        if agent_id == br_player:
            return "best_response"
        elif agent_id == other_player:
            return "metanash"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

    init_ray_for_scenario(scenario=scenario,
                          head_address=ray_head_address,
                          logging_level=logging.INFO)

    tmp_env = env_class(env_config=env_config)

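    # Two-policy setup: only "best_response" is trained; "metanash" stays frozen and is loaded with
    # the provided metanash policy specs/weights below.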
    trainer_config = {
        "callbacks": P2SROPreAndPostEpisodeCallbacks,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": [f"best_response"],
            "policies": {
                f"metanash":
                (policy_classes["metanash"], tmp_env.observation_space,
                 tmp_env.action_space, {
                     "explore": scenario.allow_stochastic_best_responses
                 }),
                f"best_response":
                (policy_classes["best_response"], tmp_env.observation_space,
                 tmp_env.action_space, br_policy_config_overrides),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_env))
    trainer_config = merge_dicts(trainer_config,
                                 general_trainer_config_overrrides)

    # trainer_config["rollout_fragment_length"] = trainer_config["rollout_fragment_length"] // max(1, trainer_config["num_workers"] * trainer_config["num_envs_per_worker"] )

    trainer = trainer_class(config=trainer_config,
                            logger_creator=get_trainer_logger_creator(
                                base_dir=results_dir,
                                scenario_name="approx_br",
                                should_log_result_fn=lambda result: result[
                                    "training_iteration"] % 100 == 0))

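    # Load the supplied metanash policy specs and weights into the frozen "metanash" policy on every worker.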
    update_all_workers_to_latest_metanash(
        trainer=trainer,
        metanash_policy_specs=metanash_policy_specs,
        metanash_weights=metanash_weights)

    train_iter_count = 0

    stopping_condition: StoppingCondition = get_stopping_condition()
    max_reward = None
    while True:
        train_iter_results = trainer.train()  # do a step (or several) in the main RL loop
        train_iter_count += 1

        if print_train_results:
            train_iter_results["best_response_player"] = br_player
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"][
                    f"best_response"]:
                del train_iter_results["info"]["learner"][f"best_response"][
                    "td_error"]
            print(pretty_dict_str(train_iter_results))

        total_timesteps_training_br = train_iter_results["timesteps_total"]
        total_episodes_training_br = train_iter_results["episodes_total"]
        br_reward_this_iter = train_iter_results["policy_reward_mean"]["best_response"]
        if max_reward is None or br_reward_this_iter > max_reward:
            max_reward = br_reward_this_iter
        if stopping_condition.should_stop_this_iter(
                latest_trainer_result=train_iter_results):
            break

    trainer.cleanup()
    ray.shutdown()
    time.sleep(10)

    return max_reward
Example #9
0
def train_psro_best_response(player: int, results_dir: str, scenario_name: str, psro_manager_port: int,
                             psro_manager_host: str, print_train_results=True, previous_br_checkpoint_path=None) -> str:
    
    scenario: PSROScenario = scenario_catalog.get(scenario_name=scenario_name)
    if not isinstance(scenario, PSROScenario):
        raise TypeError(f"Only instances of {PSROScenario} can be used here. {scenario.name} is a {type(scenario)}.")

    env_class = scenario.env_class
    env_config = scenario.env_config
    trainer_class = scenario.trainer_class
    policy_classes: Dict[str, Type[Policy]] = scenario.policy_classes
    single_agent_symmetric_game = scenario.single_agent_symmetric_game
    if single_agent_symmetric_game and player != 0:
        if player is None:
            player = 0
        else:
            raise ValueError(f"If treating the game as single agent symmetric, only use player 0 "
                             f"(one agent plays all sides).")

    p2sro = scenario.p2sro
    p2sro_sync_with_payoff_table_every_n_episodes = scenario.p2sro_sync_with_payoff_table_every_n_episodes
    get_trainer_config = scenario.get_trainer_config
    psro_get_stopping_condition = scenario.psro_get_stopping_condition
    mix_metanash_with_uniform_dist_coeff = scenario.mix_metanash_with_uniform_dist_coeff
    allow_stochastic_best_response = scenario.allow_stochastic_best_responses
    should_log_result_fn = scenario.ray_should_log_result_filter

    class P2SROPreAndPostEpisodeCallbacks(DefaultCallbacks):

        # def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, episode: MultiAgentEpisode,
        #                     env_index: int,
        #                     **kwargs):
        #     super().on_episode_step(worker=worker, base_env=base_env, episode=episode, env_index=env_index, **kwargs)
        #
        #     # Debug render a single environment.
        #     if worker.worker_index == 1 and env_index == 0:
        #         base_env.get_unwrapped()[0].render()

        def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                             policies: Dict[str, Policy],
                             episode: MultiAgentEpisode, env_index: int, **kwargs):

            # At the start of each episode, sample new pure-strategy weights for the metanash opponent from the
            # opponent strategy distribution so the best response trains against the full mixture. For better
            # runtime performance, raise resample_pure_strat_every_n_episodes to load new weights less often.
            resample_pure_strat_every_n_episodes = 1
            metanash_policy: Policy = policies["metanash"]
            opponent_policy_distribution: PolicySpecDistribution = worker.opponent_policy_distribution
            time_for_resample = (not hasattr(metanash_policy, "episodes_since_resample") or
                                 metanash_policy.episodes_since_resample >= resample_pure_strat_every_n_episodes)
            if time_for_resample and opponent_policy_distribution is not None:
                new_pure_strat_spec: StrategySpec = opponent_policy_distribution.sample_policy_spec()
                # noinspection PyTypeChecker
                load_pure_strat(policy=metanash_policy, pure_strat_spec=new_pure_strat_spec)
                metanash_policy.episodes_since_resample = 1
            elif opponent_policy_distribution is not None:
                metanash_policy.episodes_since_resample += 1

        def on_train_result(self, *, trainer, result: dict, **kwargs):
            result["scenario_name"] = trainer.scenario_name
            super().on_train_result(trainer=trainer, result=result, **kwargs)

        def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                           policies: Dict[str, Policy], episode: MultiAgentEpisode,
                           env_index: int, **kwargs):

            # If using P2SRO, report payoff results of the actively training BR to the payoff table.
            if not p2sro:
                return

            if not hasattr(worker, "p2sro_manager"):
                worker.p2sro_manager = RemoteP2SROManagerClient(n_players=2,
                                                                port=psro_manager_port,
                                                                remote_server_host=psro_manager_host)

            br_policy_spec: StrategySpec = worker.policy_map["best_response"].policy_spec
            if br_policy_spec.pure_strat_index_for_player(player=worker.br_player) == 0:
                # A pure-strategy index of 0 means this is the first PSRO iteration: the payoff
                # table is still empty and the "metanash" opponent is just a randomly initialized
                # network, so there are no payoff results worth reporting.
                return

            # Report payoff results for individual episodes to the p2sro manager to keep a real-time approximation
            # of the payoff matrix entries for (learning) active policies.
            policy_specs_for_each_player: List[StrategySpec] = [None, None]
            payoffs_for_each_player: List[float] = [None, None]
            for (player, policy_name), reward in episode.agent_rewards.items():
                assert policy_name in ["best_response", "metanash"]
                policy: Policy = worker.policy_map[policy_name]
                assert policy.policy_spec is not None
                policy_specs_for_each_player[player] = policy.policy_spec
                payoffs_for_each_player[player] = reward
            assert all(payoff is not None for payoff in payoffs_for_each_player)

            # Send payoff result to the manager for inclusion in the payoff table.
            worker.p2sro_manager.submit_empirical_payoff_result(
                policy_specs_for_each_player=tuple(policy_specs_for_each_player),
                payoffs_for_each_player=tuple(payoffs_for_each_player),
                games_played=1,
                override_all_previous_results=False)

    other_player = 1 - player
    br_learner_name = f"new_learner_{player}"

    def log(message):
        print(f"({br_learner_name}): {message}")

    def select_policy(agent_id):
        if agent_id == player:
            return "best_response"
        elif agent_id == other_player:
            return "metanash"
        else:
            raise ValueError(f"Unknown agent id: {agent_id}")

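    # Connect to the remote P2SRO manager, which assigns policy indices and maintains the payoff
    # table; its metadata also provides the Ray head address to join for this scenario.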
    p2sro_manager = RemoteP2SROManagerClient(n_players=2, port=psro_manager_port,
                                             remote_server_host=psro_manager_host)
    manager_metadata = p2sro_manager.get_manager_metadata()
    ray_head_address = manager_metadata["ray_head_address"]
    init_ray_for_scenario(scenario=scenario, head_address=ray_head_address, logging_level=logging.INFO)

    tmp_env = env_class(env_config=env_config)

    trainer_config = {
        "callbacks": P2SROPreAndPostEpisodeCallbacks,
        "env": env_class,
        "env_config": env_config,
        "gamma": 1.0,
        "num_gpus": 0,
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "multiagent": {
            "policies_to_train": [f"best_response"],
            "policies": {
                f"metanash": (
                policy_classes["metanash"], tmp_env.observation_space, tmp_env.action_space, {"explore": allow_stochastic_best_response}),
                f"best_response": (
                policy_classes["best_response"], tmp_env.observation_space, tmp_env.action_space, {}),
            },
            "policy_mapping_fn": select_policy,
        },
    }

    trainer_config = merge_dicts(trainer_config, get_trainer_config(tmp_env))

    trainer = trainer_class(config=trainer_config,
                            logger_creator=get_trainer_logger_creator(
                                base_dir=results_dir, scenario_name=scenario_name,
                                should_log_result_fn=should_log_result_fn))

    # scenario_name is attached to training results in P2SROPreAndPostEpisodeCallbacks.on_train_result.
    trainer.scenario_name = scenario_name

    if previous_br_checkpoint_path is not None:
        def _set_br_initial_weights(worker: RolloutWorker):
            br_policy = worker.policy_map["best_response"]
            load_pure_strat(policy=br_policy, checkpoint_path=previous_br_checkpoint_path)
        trainer.workers.foreach_worker(_set_br_initial_weights)

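    # Claim a new active policy slot from the manager, registering an initial checkpoint. The
    # returned spec's pure-strategy index identifies this best response in the payoff table.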
    active_policy_spec: StrategySpec = p2sro_manager.claim_new_active_policy_for_player(
        player=player, new_policy_metadata_dict=create_metadata_with_new_checkpoint_for_current_best_response(
            trainer=trainer, player=player, save_dir=checkpoint_dir(trainer), timesteps_training_br=0,
            episodes_training_br=0,
            active_policy_num=None
        ))

    active_policy_num = active_policy_spec.pure_strat_index_for_player(player=player)
    br_learner_name = f"policy {active_policy_num} player {player}"

    log(f"got policy {active_policy_num}")

    set_best_response_active_policy_spec_and_player_for_all_workers(trainer=trainer, player=player,
                                                                    active_policy_spec=active_policy_spec)

    sync_active_policy_br_and_metanash_with_p2sro_manager(trainer=trainer,
                                                          player=player,
                                                          metanash_player=other_player,
                                                          one_agent_plays_all_sides=single_agent_symmetric_game,
                                                          p2sro_manager=p2sro_manager,
                                                          mix_metanash_with_uniform_dist_coeff=mix_metanash_with_uniform_dist_coeff,
                                                          active_policy_num=active_policy_num,
                                                          timesteps_training_br=0,
                                                          episodes_training_br=0)

    # Perform main RL training loop. Stop training according to our StoppingCondition.
    train_iter_count = 0
    episodes_since_last_sync_with_manager = 0
    stopping_condition: StoppingCondition = psro_get_stopping_condition()

    while True:
        train_iter_results = trainer.train()  # do a step (or several) in the main RL loop
        train_iter_count += 1

        if print_train_results:
            train_iter_results["p2sro_active_policy_num"] = active_policy_num
            train_iter_results["best_response_player"] = player
            # Delete verbose debugging info before printing
            if "hist_stats" in train_iter_results:
                del train_iter_results["hist_stats"]
            if "td_error" in train_iter_results["info"]["learner"][f"best_response"]:
                del train_iter_results["info"]["learner"][f"best_response"]["td_error"]
            log(pretty_dict_str(train_iter_results))

        total_timesteps_training_br = train_iter_results["timesteps_total"]
        total_episodes_training_br = train_iter_results["episodes_total"]

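        # Count episodes since the last manager sync; once the threshold is reached, push an updated
        # best-response checkpoint and refresh the metanash opponent distribution from the manager.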
        episodes_since_last_sync_with_manager += train_iter_results["episodes_this_iter"]
        if p2sro and episodes_since_last_sync_with_manager >= p2sro_sync_with_payoff_table_every_n_episodes:
            if p2sro_sync_with_payoff_table_every_n_episodes > 0:
                episodes_since_last_sync_with_manager = episodes_since_last_sync_with_manager % p2sro_sync_with_payoff_table_every_n_episodes
            else:
                episodes_since_last_sync_with_manager = 0

            sync_active_policy_br_and_metanash_with_p2sro_manager(trainer=trainer,
                                                                  player=player,
                                                                  metanash_player=other_player,
                                                                  one_agent_plays_all_sides=single_agent_symmetric_game,
                                                                  p2sro_manager=p2sro_manager,
                                                                  mix_metanash_with_uniform_dist_coeff=mix_metanash_with_uniform_dist_coeff,
                                                                  active_policy_num=active_policy_num,
                                                                  timesteps_training_br=total_timesteps_training_br,
                                                                  episodes_training_br=total_episodes_training_br)

        if stopping_condition.should_stop_this_iter(latest_trainer_result=train_iter_results):
            if p2sro_manager.can_active_policy_be_set_as_fixed_now(player=player, policy_num=active_policy_num):
                break
            else:
                log(f"Forcing training to continue since lower policies are still active.")

    log(f"Training stopped. Setting active policy {active_policy_num} as fixed.")

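    # Save a final best-response checkpoint and report it to the manager, which marks this policy
    # as fixed in the payoff table.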
    final_policy_metadata = create_metadata_with_new_checkpoint_for_current_best_response(
            trainer=trainer, player=player, save_dir=checkpoint_dir(trainer=trainer),
            timesteps_training_br=total_timesteps_training_br,
            episodes_training_br=total_episodes_training_br,
            active_policy_num=active_policy_num,
            average_br_reward=train_iter_results["policy_reward_mean"]["best_response"])

    p2sro_manager.set_active_policy_as_fixed(
        player=player, policy_num=active_policy_num,
        final_metadata_dict=final_policy_metadata)

    trainer.cleanup()
    ray.shutdown()
    time.sleep(10)

    if not p2sro:
        # When not running full P2SRO, block until the manager reports this policy index as fixed
        # for both players, so that both best responses of the current PSRO iteration have finished.
        for player_to_wait_on in range(2):
            wait_count = 0
            while True:
                if p2sro_manager.is_policy_fixed(player=player_to_wait_on, policy_num=active_policy_num):
                    break
                if wait_count % 10 == 0:
                    log(f"Waiting for policy {active_policy_num} player {player_to_wait_on} to become fixed")
                time.sleep(2.0)
                wait_count += 1

    return final_policy_metadata["checkpoint_path"]
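
# A minimal usage sketch, assuming a running P2SRO manager and a scenario registered in
# scenario_catalog (the scenario name, port, and paths below are hypothetical placeholders):
#
# checkpoint_path = train_psro_best_response(
#     player=0,
#     results_dir="/tmp/psro_results",        # hypothetical output directory
#     scenario_name="my_psro_scenario",       # hypothetical scenario_catalog entry
#     psro_manager_port=4535,                 # hypothetical manager port
#     psro_manager_host="localhost",
#     print_train_results=True,
#     previous_br_checkpoint_path=None,
# )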