Example #1
def tabular_policies_from_weighted_policies(game: OpenSpielGame,
                                            policy_iterable,
                                            weights: List[Tuple[float,
                                                                float]]):
    """Converts multiple Policy instances into an weighted averaged TabularPolicy.

    Args:
      game: The game for which we want a TabularPolicy.
      policy_iterable: for each player, an iterable that returns tuples of Openspiel policies
      weights: for each player, probabilities of selecting each policy

    Returns:
      A averaged OpenSpiel Policy over the policy_iterable.
    """
    num_players = game.num_players()
    # Per-player average policies; each entry becomes a callable that takes a
    # state and returns a list of (action, probability) tuples.
    avg_policies = [None] * num_players
    total_weights_added = np.zeros(num_players)
    for index, (best_responses,
                weights_for_each_br) in enumerate(zip(policy_iterable,
                                                      weights)):
        weights_for_each_br = np.asarray(weights_for_each_br, dtype=np.float64)
        total_weights_added += weights_for_each_br
        if index == 0:
            for i in range(num_players):
                avg_policies[i] = tabular_policy_from_callable(
                    game=game, callable_policy=best_responses[i])
        else:
            br_reach_probs = np.ones(num_players)
            avg_reach_probs = np.ones(num_players)
            average_policy_tables = [{} for _ in range(num_players)]
            _recursively_update_average_policies(
                state=game.new_initial_state(),
                avg_reach_probs=avg_reach_probs,
                br_reach_probs=br_reach_probs,
                avg_policies=avg_policies,
                best_responses=best_responses,
                alpha=weights_for_each_br / total_weights_added,
                avg_policy_tables=average_policy_tables)
            for i in range(num_players):
                avg_policies[i] = _callable_tabular_policy(
                    average_policy_tables[i])

    for i in range(num_players):
        avg_policies[i] = tabular_policy_from_callable(
            game=game, callable_policy=avg_policies[i], players=[i])

    return avg_policies
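A minimal usage sketch for the helper above, assuming it and its private helpers (_recursively_update_average_policies, _callable_tabular_policy) are in scope; kuhn_poker and the two toy policies are illustrative placeholders, not part of the original example:

import pyspiel

game = pyspiel.load_game("kuhn_poker")

def uniform_policy(state):
    legal = state.legal_actions()
    return {a: 1.0 / len(legal) for a in legal}

def first_action_policy(state):
    legal = state.legal_actions()
    return {a: (1.0 if a == legal[0] else 0.0) for a in legal}

# Two entries, each supplying one policy per player with equal weight.
policy_iterable = [(uniform_policy, uniform_policy),
                   (first_action_policy, first_action_policy)]
weights = [(1.0, 1.0), (1.0, 1.0)]

avg_policies = tabular_policies_from_weighted_policies(game, policy_iterable, weights)
# avg_policies[i] is the weighted-average TabularPolicy for player i.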
Example #2
 def save_deepcfr():  # Saves the current Deep CFR policy and prints training diagnostics.
     print("---------iteration " + str(it) + "----------")
     for player, losses in six.iteritems(advantage_losses):
         print("Advantage for player ", player, losses)
         print("Advantage Buffer Size for player", player,
               len(deep_cfr_solver.advantage_buffers[player]))
     print("Strategy Buffer Size: ",
           len(deep_cfr_solver.strategy_buffer))
     print("policy loss: ", policy_loss)
     # deep_cfr_solver.action_probabilities is a callable mapping a state to
     # {action: probability}, so it converts directly to a TabularPolicy.
     tabular_policy = tabular_policy_from_callable(game, deep_cfr_solver.action_probabilities)
     policy = dict(zip(tabular_policy.state_lookup, tabular_policy.action_probability_array))
     # Save under folder (save_prefix)_(num_travers).
     return policy_handler.save_to_tabular_policy(game, policy, "policies/DEEPCFR/{}/{}".format(
         save_prefix + "_" + str(num_travers), it))
Example #3
def main(unused_argv):
  logging.info("Loading %s", FLAGS.game_name)
  game = pyspiel.load_game(FLAGS.game_name)
  with tf.Session() as sess:
    deep_cfr_solver = deep_cfr.DeepCFRSolver(
        sess,
        game,
        policy_network_layers=(32, 32),
        advantage_network_layers=(16, 16),
        num_iterations=FLAGS.num_iterations,
        num_traversals=FLAGS.num_traversals,
        learning_rate=1e-3,
        batch_size_advantage=None,
        batch_size_strategy=None,
        memory_capacity=1e7)
    sess.run(tf.global_variables_initializer())
    _, advantage_losses, policy_loss = deep_cfr_solver.solve()
    for player, losses in six.iteritems(advantage_losses):
      logging.info("Advantage for player %d: %s", player,
                   losses[:2] + ["..."] + losses[-2:])
      logging.info("Advantage Buffer Size for player %s: '%s'", player,
                   len(deep_cfr_solver.advantage_buffers[player]))
    logging.info("Strategy Buffer Size: '%s'",
                 len(deep_cfr_solver.strategy_buffer))
    logging.info("Final policy loss: '%s'", policy_loss)
    conv = exploitability.nash_conv(
        game,
        policy.tabular_policy_from_callable(
            game, deep_cfr_solver.action_probabilities))
    logging.info("Deep CFR in '%s' - NashConv: %s", FLAGS.game_name, conv)
Example #4
def print_algorithm_results(game, callable_policy, algorithm_name):
    print(algorithm_name.upper())
    tabular_policy = tabular_policy_from_callable(game, callable_policy)
    policy_exploitability = exploitability(game, tabular_policy)
    policy_nashconv = nash_conv(game, tabular_policy)
    print("exploitability = {}".format(policy_exploitability))
    print("nashconv = {}".format(policy_nashconv))
Example #5
def get_algo_policies(algo_path, files, game):
    print("Extracting the policies...")
    algo_policies = {}
    for file in files:
        algo_iterations = int(file.split(algo_path)[1])
        algo_policy = policy_handler.load_to_tabular_policy(file)
        algo_policies[algo_iterations] = policy.tabular_policy_from_callable(game, algo_policy)
    return algo_policies
Example #6
def create_random_tabular_policy(game, players=(0, 1)):
    def _random_action_callable_policy(state) -> dict:
        legal_actions_list = state.legal_actions()
        chosen_legal_action = np.random.choice(legal_actions_list)
        return {
            action: (1.0 if action == chosen_legal_action else 0.0)
            for action in legal_actions_list
        }

    return tabular_policy_from_callable(
        game=game,
        callable_policy=_random_action_callable_policy,
        players=players)
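A short usage sketch; kuhn_poker and the NashConv check are illustrative additions. Because the callable is queried once per information state, the returned TabularPolicy is an arbitrary deterministic policy:

import pyspiel
from open_spiel.python.algorithms import exploitability

game = pyspiel.load_game("kuhn_poker")
random_deterministic_policy = create_random_tabular_policy(game)
# A randomly chosen deterministic policy is typically highly exploitable.
print("NashConv:", exploitability.nash_conv(game, random_deterministic_policy))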
Example #7
  def _compute_best_responses(self):
    """Computes each player best-response against the pool of other players."""

    def policy_fn(state):
      key = state.information_state_string()
      return self._get_infostate_policy(key)

    current_policy = policy.tabular_policy_from_callable(self._game, policy_fn)

    for player_id in range(self._game.num_players()):
      self._best_responses[player_id] = exploitability.best_response(
          self._game, current_policy, player_id)
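A standalone sketch of the same pattern: wrap an info-state lookup as a callable, tabularize it with policy.tabular_policy_from_callable, then compute best responses. The uniform fallback store and kuhn_poker are assumptions for illustration:

import pyspiel
from open_spiel.python import policy
from open_spiel.python.algorithms import exploitability

game = pyspiel.load_game("kuhn_poker")
infostate_policies = {}  # hypothetical {info_state_string: {action: prob}} store

def policy_fn(state):
    key = state.information_state_string()
    if key not in infostate_policies:
        legal = state.legal_actions()
        infostate_policies[key] = {a: 1.0 / len(legal) for a in legal}
    return infostate_policies[key]

current_policy = policy.tabular_policy_from_callable(game, policy_fn)
best_responses = [
    exploitability.best_response(game, current_policy, player_id)
    for player_id in range(game.num_players())
]
print("NashConv:", exploitability.nash_conv(game, current_policy))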
Example #8
 def test_outcome_sampling_kuhn_2p(self):
     np.random.seed(SEED)
     game = pyspiel.load_game("kuhn_poker")
     os_solver = outcome_sampling_mccfr.OutcomeSamplingSolver(game)
     for _ in range(10000):
         os_solver.iteration()
     conv = exploitability.nash_conv(
         game,
         policy.tabular_policy_from_callable(
             game, os_solver.callable_avg_policy()))
     print("Kuhn2P, conv = {}".format(conv))
     self.assertLess(conv, 0.17)
Example #9
 def test_external_sampling_kuhn_2p_full(self):
   np.random.seed(SEED)
   game = pyspiel.load_game("kuhn_poker")
   es_solver = external_sampling_mccfr.ExternalSamplingSolver(
       game, external_sampling_mccfr.AverageType.FULL)
   for _ in range(10):
     es_solver.iteration()
   conv = exploitability.nash_conv(
       game,
       policy.tabular_policy_from_callable(game,
                                           es_solver.callable_avg_policy()))
   print("Kuhn2P, conv = {}".format(conv))
   self.assertLess(conv, 1)
Example #10
 def disabled_test_external_sampling_liars_dice_2p_simple(self):
   np.random.seed(SEED)
   game = pyspiel.load_game("liars_dice")
   es_solver = external_sampling_mccfr.ExternalSamplingSolver(
       game, external_sampling_mccfr.AverageType.SIMPLE)
   for _ in range(1):
     es_solver.iteration()
   conv = exploitability.nash_conv(
       game,
       policy.tabular_policy_from_callable(game,
                                           es_solver.callable_avg_policy()))
   print("Liar's dice, conv = {}".format(conv))
   self.assertLess(conv, 2)
Example #11
 def test_external_sampling_kuhn_3p_simple(self):
   np.random.seed(SEED)
   game = pyspiel.load_game("kuhn_poker",
                            {"players": pyspiel.GameParameter(3)})
   es_solver = external_sampling_mccfr.ExternalSamplingSolver(
       game, external_sampling_mccfr.AverageType.SIMPLE)
   for _ in range(10):
     es_solver.iteration()
   conv = exploitability.nash_conv(
       game,
       policy.tabular_policy_from_callable(game,
                                           es_solver.callable_avg_policy()))
   print("Kuhn3P, conv = {}".format(conv))
   self.assertLess(conv, 2)
Example #12
def main(_):
  game = pyspiel.load_game(FLAGS.game,
                           {"players": pyspiel.GameParameter(FLAGS.players)})
  if FLAGS.sampling == "external":
    cfr_solver = external_mccfr.ExternalSamplingSolver(
        game, external_mccfr.AverageType.SIMPLE)
  else:
    cfr_solver = outcome_mccfr.OutcomeSamplingSolver(game)
  for i in range(FLAGS.iterations):
    cfr_solver.iteration()
    if i % FLAGS.print_freq == 0:
      conv = exploitability.nash_conv(
          game,
          policy.tabular_policy_from_callable(game,
                                              cfr_solver.callable_avg_policy()))
      print("Iteration {} exploitability {}".format(i, conv))
Example #13
 def test_matching_pennies_3p(self):
     game = pyspiel.load_game_as_turn_based('matching_pennies_3p')
     deep_cfr_solver = deep_cfr.DeepCFRSolver(game,
                                              policy_network_layers=(16, 8),
                                              advantage_network_layers=(32,
                                                                        16),
                                              num_iterations=2,
                                              num_traversals=2,
                                              learning_rate=1e-3,
                                              batch_size_advantage=None,
                                              batch_size_strategy=None,
                                              memory_capacity=1e7)
     deep_cfr_solver.solve()
     conv = pyspiel.nash_conv(
         game,
         policy.python_policy_to_pyspiel_policy(
             policy.tabular_policy_from_callable(
                 game, deep_cfr_solver.action_probabilities)))
     logging.info('Deep CFR in Matching Pennies 3p. NashConv: %.2f', conv)
Example #14
 def test_matching_pennies_3p(self):
     # We don't expect Deep CFR to necessarily converge on 3-player games but
     # it's nonetheless interesting to see this result.
     game = pyspiel.load_game_as_turn_based('matching_pennies_3p')
     deep_cfr_solver = deep_cfr.DeepCFRSolver(game,
                                              policy_network_layers=(16, 8),
                                              advantage_network_layers=(32,
                                                                        16),
                                              num_iterations=2,
                                              num_traversals=2,
                                              learning_rate=1e-3,
                                              batch_size_advantage=8,
                                              batch_size_strategy=8,
                                              memory_capacity=1e7)
     deep_cfr_solver.solve()
     conv = exploitability.nash_conv(
         game,
         policy.tabular_policy_from_callable(
             game, deep_cfr_solver.action_probabilities))
     print('Deep CFR in Matching Pennies 3p. NashConv: {}'.format(conv))
Example #15
def main(unused_argv):
    logging.info("Loading %s", FLAGS.game_name)
    game = pyspiel.load_game(FLAGS.game_name)
    with tf.Session() as sess:
        deep_cfr_solver = deep_cfr.DeepCFRSolver(
            sess,
            game,
            policy_network_layers=(16, ),
            advantage_network_layers=(16, ),
            num_iterations=FLAGS.num_iterations,
            num_traversals=FLAGS.num_traversals,
            learning_rate=1e-3,
            batch_size_advantage=128,
            batch_size_strategy=1024,
            memory_capacity=1e7,
            policy_network_train_steps=400,
            advantage_network_train_steps=20,
            reinitialize_advantage_networks=False)
        sess.run(tf.global_variables_initializer())
        _, advantage_losses, policy_loss = deep_cfr_solver.solve()
        for player, losses in six.iteritems(advantage_losses):
            logging.info("Advantage for player %d: %s", player,
                         losses[:2] + ["..."] + losses[-2:])
            logging.info("Advantage Buffer Size for player %s: '%s'", player,
                         len(deep_cfr_solver.advantage_buffers[player]))
        logging.info("Strategy Buffer Size: '%s'",
                     len(deep_cfr_solver.strategy_buffer))
        logging.info("Final policy loss: '%s'", policy_loss)

        average_policy = policy.tabular_policy_from_callable(
            game, deep_cfr_solver.action_probabilities)

        conv = exploitability.nash_conv(game, average_policy)
        logging.info("Deep CFR in '%s' - NashConv: %s", FLAGS.game_name, conv)

        average_policy_values = expected_game_score.policy_value(
            game.new_initial_state(), [average_policy] * 2)
        print("Computed player 0 value: {}".format(average_policy_values[0]))
        print("Expected player 0 value: {}".format(-1 / 18))
        print("Computed player 1 value: {}".format(average_policy_values[1]))
        print("Expected player 1 value: {}".format(1 / 18))
Example #16
def main(unused_argv):
    logging.info("Loading %s", FLAGS.game_name)
    game = pyspiel.load_game(FLAGS.game_name)
    deep_cfr_solver = deep_cfr_tf2.DeepCFRSolver(
        game,
        policy_network_layers=(64, 64, 64, 64),
        advantage_network_layers=(64, 64, 64, 64),
        num_iterations=FLAGS.num_iterations,
        num_traversals=FLAGS.num_traversals,
        learning_rate=1e-3,
        batch_size_advantage=2048,
        batch_size_strategy=2048,
        memory_capacity=1e6,
        policy_network_train_steps=5000,
        advantage_network_train_steps=500,
        reinitialize_advantage_networks=True,
        infer_device="cpu",
        train_device="cpu")
    _, advantage_losses, policy_loss = deep_cfr_solver.solve()
    for player, losses in six.iteritems(advantage_losses):
        logging.info("Advantage for player %d: %s", player,
                     losses[:2] + ["..."] + losses[-2:])
        logging.info("Advantage Buffer Size for player %s: '%s'", player,
                     len(deep_cfr_solver.advantage_buffers[player]))
    logging.info("Strategy Buffer Size: '%s'",
                 len(deep_cfr_solver.strategy_buffer))
    logging.info("Final policy loss: '%s'", policy_loss)

    average_policy = policy.tabular_policy_from_callable(
        game, deep_cfr_solver.action_probabilities)

    conv = exploitability.nash_conv(game, average_policy)
    logging.info("Deep CFR in '%s' - NashConv: %s", FLAGS.game_name, conv)

    average_policy_values = expected_game_score.policy_value(
        game.new_initial_state(), [average_policy] * 2)
    print("Computed player 0 value: {}".format(average_policy_values[0]))
    print("Computed player 1 value: {}".format(average_policy_values[1]))
Example #17
def main(unused_argv):
    logging.info("Loading %s", FLAGS.game_name)
    game = pyspiel.load_game(FLAGS.game_name)

    deep_cfr_solver = deep_cfr.DeepCFRSolver(
        game,
        policy_network_layers=(32, 32),
        advantage_network_layers=(16, 16),
        num_iterations=FLAGS.num_iterations,
        num_traversals=FLAGS.num_traversals,
        learning_rate=1e-3,
        batch_size_advantage=None,
        batch_size_strategy=None,
        memory_capacity=int(1e7))

    _, advantage_losses, policy_loss = deep_cfr_solver.solve()
    for player, losses in six.iteritems(advantage_losses):
        logging.info("Advantage for player %d: %s", player,
                     losses[:2] + ["..."] + losses[-2:])
        logging.info("Advantage Buffer Size for player %s: '%s'", player,
                     len(deep_cfr_solver.advantage_buffers[player]))
    logging.info("Strategy Buffer Size: '%s'",
                 len(deep_cfr_solver.strategy_buffer))
    logging.info("Final policy loss: '%s'", policy_loss)

    average_policy = policy.tabular_policy_from_callable(
        game, deep_cfr_solver.action_probabilities)
    pyspiel_policy = policy.python_policy_to_pyspiel_policy(average_policy)
    conv = pyspiel.nash_conv(game, pyspiel_policy)
    logging.info("Deep CFR in '%s' - NashConv: %s", FLAGS.game_name, conv)

    average_policy_values = expected_game_score.policy_value(
        game.new_initial_state(), [average_policy] * 2)
    logging.info("Computed player 0 value: %.2f (expected: %.2f).",
                 average_policy_values[0], -1 / 18)
    logging.info("Computed player 1 value: %.2f (expected: %.2f).",
                 average_policy_values[1], 1 / 18)
Example #18
 def save_nfsp():
     tabular_policy = policy.tabular_policy_from_callable(game, expl_policies_avg)
     policy_handler.save_tabular_policy(game, tabular_policy, "policies/NFSP/{}/{}".format(save_prefix, it))
Example #19
    def __call__(self, player, player_policy, info_states):
        """Computes action values per state for the player.

    Args:
      player: The id of the player (0 <= player < game.num_players()). This
        player will play `player_policy`, while the opponent will play a best
        response.
      player_policy: A `policy.Policy` object.
      info_states: A list of info state strings.

    Returns:
      A `_CalculatorReturn` nametuple. See its docstring for the documentation.
    """
        self.player = player
        opponent = 1 - player

        def best_response_policy(state):
            infostate = state.information_state_string(opponent)
            action = best_response_actions[infostate]
            return [(action, 1.0)]

        # If the policy is a TabularPolicy, we can directly copy the infostate
        # strings & values from the class. This is significantly faster than having
        # to create the infostate strings.
        if isinstance(player_policy, policy.TabularPolicy):
            tabular_policy = {
                key: _tuples_from_policy(player_policy.policy_for_key(key))
                for key in player_policy.state_lookup
            }
        # Otherwise, we have to calculate all the infostate strings every time.
        # This is ~2x slower.
        else:
            # We cache these as they are expensive to compute & do not change.
            if self._all_states is None:
                self._all_states = get_all_states.get_all_states(
                    self.game,
                    depth_limit=-1,
                    include_terminals=False,
                    include_chance_states=False)
                self._state_to_information_state = {
                    state: self._all_states[state].information_state_string()
                    for state in self._all_states
                }
            tabular_policy = policy_utils.policy_to_dict(
                player_policy, self.game, self._all_states,
                self._state_to_information_state)

        # When constructed, TabularBestResponse does a lot of work; we can save that
        # work by caching it.
        if self._best_responder[player] is None:
            self._best_responder[player] = pyspiel.TabularBestResponse(
                self.game, opponent, tabular_policy)
        else:
            self._best_responder[player].set_policy(tabular_policy)

        # Computing the value at the root calculates best responses everywhere.
        history = str(self.game.new_initial_state())
        best_response_value = self._best_responder[player].value(history)
        best_response_actions = self._best_responder[
            player].get_best_response_actions()

        # Compute action values
        self._action_value_calculator.compute_all_states_action_values({
            player:
            player_policy,
            opponent:
            policy.tabular_policy_from_callable(self.game,
                                                best_response_policy,
                                                [opponent]),
        })
        obj = self._action_value_calculator._get_tabular_statistics(  # pylint: disable=protected-access
            ((player, s) for s in info_states))

        # Return values
        return _CalculatorReturn(
            exploitability=best_response_value,
            values_vs_br=obj.action_values,
            counterfactual_reach_probs_vs_br=obj.counterfactual_reach_probs,
            player_reach_probs_vs_br=obj.player_reach_probs)
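A hedged usage sketch, assuming this __call__ belongs to a class like OpenSpiel's action_value_vs_best_response.Calculator (the class name, module path, and the Kuhn info-state strings are assumptions, not stated in the example):

import pyspiel
from open_spiel.python import policy
from open_spiel.python.algorithms import action_value_vs_best_response

game = pyspiel.load_game("kuhn_poker")
calc = action_value_vs_best_response.Calculator(game)
# "0", "1" and "2" are assumed to be player 0's initial info-state strings in Kuhn poker.
result = calc(0, policy.UniformRandomPolicy(game), ["0", "1", "2"])
print("Best-response value against player 0:", result.exploitability)
print("Action values vs. best response:", result.values_vs_br)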
Example #20
def openspiel_policy_from_nonlstm_rllib_policy(openspiel_game: OpenSpielGame,
                                               game_version: str,
                                               game_parameters: dict,
                                               rllib_policy: Policy):
    if openspiel_game.get_type().short_name == "universal_poker":
        print(
            "Converting universal_poker rllib policy to tabular. This will take a while..."
        )

    def policy_callable(state: pyspiel.State):

        valid_actions_mask = state.legal_actions_mask()
        legal_actions_list = state.legal_actions()
        # assert np.array_equal(valid_actions, np.ones_like(valid_actions)) # should be always true for Kuhn Poker
        try:
            info_state_vector = state.information_state_tensor()
        except pyspiel.SpielError:
            assert openspiel_game.get_type(
            ).short_name == "turn_based_simultaneous_game"
            assert game_version == "oshi_zumo", game_version
            info_state_vector = state.observation_tensor(
                state.current_player())[4:]
            info_state_vector = get_oshi_zumo_obs(
                openspiel_observation_tensor=info_state_vector,
                starting_coins=int(str(game_parameters["coins"])))

        if game_version in [
                "leduc_poker", "oshi_zumo", "oshi_zumo_tiny", "universal_poker"
        ]:
            obs = np.concatenate(
                (np.asarray(info_state_vector, dtype=np.float32),
                 np.asarray(valid_actions_mask, dtype=np.float32)),
                axis=0)
        else:
            obs = np.asarray(info_state_vector, dtype=np.float32)

        _, _, action_info = rllib_policy.compute_single_action(obs=obs,
                                                               state=[],
                                                               explore=False)

        action_probs = None
        for key in ['policy_targets', 'action_probs']:
            if key in action_info:
                action_probs = action_info[key]
                break
        if action_probs is None:
            action_logits = action_info['behaviour_logits']
            action_probs = softmax(action_logits)

        if len(action_probs) > len(valid_actions_mask) and len(
                action_probs) % len(valid_actions_mask) == 0:
            # we may be using a dummy action variant of poker
            dummy_action_probs = action_probs.copy()
            action_probs = np.zeros_like(valid_actions_mask, dtype=np.float64)
            for i, action_prob in enumerate(dummy_action_probs):
                action_probs[i % len(valid_actions_mask)] += action_prob
            assert np.isclose(sum(action_probs), 1.0), sum(action_probs)

        assert np.isclose(sum(action_probs), 1.0)

        legal_action_probs = []
        valid_action_prob_sum = 0.0
        for idx in range(len(valid_actions_mask)):
            if valid_actions_mask[idx] == 1.0:
                legal_action_probs.append(action_probs[idx])
                valid_action_prob_sum += action_probs[idx]
        assert np.isclose(valid_action_prob_sum,
                          1.0), (action_probs, valid_actions_mask,
                                 action_info.get('behaviour_logits'))

        # Zero out tiny near-zero probabilities to avoid triggering downstream assertions.
        for i in range(len(legal_action_probs)):
            if np.isclose(legal_action_probs[i], 0.0):
                legal_action_probs[i] = 0.0

        return {
            action_name: action_prob
            for action_name, action_prob in zip(legal_actions_list,
                                                legal_action_probs)
        }

    # Defensive copy to a tabular policy in case the RLlib policy changes after this function is called.
    return tabular_policy_from_callable(game=openspiel_game,
                                        callable_policy=policy_callable)
Example #21
def openspiel_policy_from_nonlstm_rllib_nxdo_policy(
    openspiel_game: OpenSpielGame, rllib_policy: Policy,
    use_delegate_policy_exploration: bool, restricted_game_convertor: Union[
        RestrictedToBaseGameActionSpaceConverter,
        AgentRestrictedGameOpenSpielObsConversions]):
    is_openspiel_restricted_game = isinstance(
        restricted_game_convertor, AgentRestrictedGameOpenSpielObsConversions)

    def policy_callable(state: pyspiel.State):

        valid_actions_mask = state.legal_actions_mask()
        legal_actions_list = state.legal_actions()

        # assert np.array_equal(valid_actions, np.ones_like(valid_actions)) # should be always true at least for Kuhn

        info_state_vector = state.information_state_tensor()

        if openspiel_game.get_type().short_name in [
                "leduc_poker", "oshi_zumo", "oshi_zumo_tiny", "universal_poker"
        ] or is_openspiel_restricted_game:
            # Observation includes both the info_state and legal actions, but agent isn't forced to take legal actions.
            # Taking an illegal action will result in a random legal action being played.
            # Allows easy compatibility with standard RL implementations for small action-space games like this one.
            obs = np.concatenate(
                (np.asarray(info_state_vector, dtype=np.float32),
                 np.asarray(valid_actions_mask, dtype=np.float32)),
                axis=0)
        else:
            obs = np.asarray(info_state_vector, dtype=np.float32)

        if is_openspiel_restricted_game:
            os_restricted_game_convertor: AgentRestrictedGameOpenSpielObsConversions = restricted_game_convertor
            try:
                obs = os_restricted_game_convertor.orig_obs_to_restricted_game_obs[
                    tuple(obs)]
            except KeyError:
                print(
                    f"missing key: {tuple(obs)}\nexample key: {list(os_restricted_game_convertor.orig_obs_to_restricted_game_obs.keys())[0]}"
                )
                raise
        action, _, restricted_action_info = rllib_policy.compute_single_action(
            obs=obs, state=[], explore=False)
        restricted_game_action_probs = _parse_action_probs_from_action_info(
            action=action,
            action_info=restricted_action_info,
            legal_actions_list=legal_actions_list,
            total_num_discrete_actions=len(valid_actions_mask))

        if is_openspiel_restricted_game:
            action_probs = restricted_game_action_probs
        else:
            base_action_probs_for_each_restricted_game_action = []
            for restricted_game_action in range(
                    len(restricted_game_action_probs)):
                sampled_base_game_action, _, action_info = restricted_game_convertor.get_base_game_action(
                    obs=obs,
                    restricted_game_action=restricted_game_action,
                    use_delegate_policy_exploration=
                    use_delegate_policy_exploration,
                    clip_base_game_actions=False,
                    delegate_policy_state=None)

                base_game_action_probs_for_rstr_action = _parse_action_probs_from_action_info(
                    action=sampled_base_game_action,
                    action_info=action_info,
                    legal_actions_list=legal_actions_list,
                    total_num_discrete_actions=len(valid_actions_mask))
                base_action_probs_for_each_restricted_game_action.append(
                    base_game_action_probs_for_rstr_action)

            action_probs = np.zeros_like(
                base_action_probs_for_each_restricted_game_action[0])
            for base_action_probs, restricted_action_prob in zip(
                    base_action_probs_for_each_restricted_game_action,
                    restricted_game_action_probs):
                action_probs += (base_action_probs * restricted_action_prob)

        assert np.isclose(sum(action_probs), 1.0)

        if len(action_probs) > len(valid_actions_mask) and len(
                action_probs) % len(valid_actions_mask) == 0:
            # we may be using a dummy action variant of poker
            dummy_action_probs = action_probs.copy()
            action_probs = np.zeros_like(valid_actions_mask, dtype=np.float64)
            for i, action_prob in enumerate(dummy_action_probs):
                action_probs[i % len(valid_actions_mask)] += action_prob
            assert np.isclose(sum(action_probs), 1.0)

        # Since the rl env will execute a random legal action if an illegal action is chosen, redistribute probability
        # of choosing an illegal action evenly across all legal actions.
        # num_legal_actions = sum(valid_actions_mask)
        # if num_legal_actions > 0:
        #     total_legal_action_probability = sum(action_probs * valid_actions_mask)
        #     total_illegal_action_probability = 1.0 - total_legal_action_probability
        #     action_probs = (action_probs + (total_illegal_action_probability / num_legal_actions)) * valid_actions_mask

        assert np.isclose(sum(action_probs), 1.0)

        legal_action_probs = []
        valid_action_prob_sum = 0.0
        for idx in range(len(valid_actions_mask)):
            if valid_actions_mask[idx] == 1.0:
                legal_action_probs.append(action_probs[idx])
                valid_action_prob_sum += action_probs[idx]
        assert np.isclose(valid_action_prob_sum, 1.0)

        return {
            action_name: action_prob
            for action_name, action_prob in zip(legal_actions_list,
                                                legal_action_probs)
        }

    # callable_policy = PolicyFromCallable(game=openspiel_game, callable_policy=policy_callable)

    # Convert to a tabular policy in case the RLlib policy changes after this function is called.
    return tabular_policy_from_callable(game=openspiel_game,
                                        callable_policy=policy_callable)