def profile_model(model_path):
    late_game = load_late_game()

    model = load_diplomacy_model(model_path, map_location="cuda", eval=True)

    for game_name, game in [("new_game", Game()), ("late_game", late_game)]:
        print("\n#", game_name)
        inputs = FeatureEncoder().encode_inputs([game])
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # B (list of batch sizes) and N (number of timed iterations per batch
        # size) are module-level constants; see the sketch below for assumed values.
        for batch_size in B:
            b_inputs = {
                k: v.repeat((batch_size, ) + (1, ) * (len(v.shape) - 1))
                for k, v in inputs.items()
            }
            with torch.no_grad():
                tic = time.time()
                for _ in range(N):
                    order_idxs, order_scores, cand_scores, final_scores = model(
                        **b_inputs, temperature=1.0)
                toc = time.time() - tic

                print(
                    f"[B={batch_size}] {toc}s / {N}, latency={1000*toc/N}ms, throughput={N*batch_size/toc}/s"
                )
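
# A minimal sketch, assuming B and N are the module-level constants referenced
# in profile_model above; these values are illustrative, not the originals:
B = [1, 8, 64]  # batch sizes to sweep
N = 10          # timed forward passes per batch size
# profile_model("/path/to/model.ckpt") would then print latency and throughput
# for a fresh game and a late game at each batch size.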
    def get_orders(self, game, power, *, temperature=None, top_p=None):
        if len(game.get_orderable_locations().get(power, [])) == 0:
            return []

        temperature = temperature if temperature is not None else self.temperature
        top_p = top_p if top_p is not None else self.top_p
        inputs = FeatureEncoder().encode_inputs([game])
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            order_idxs, cand_idxs, logits, final_scores = self.model(
                **inputs, temperature=temperature, top_p=top_p)

        resample_duplicate_disbands_inplace(order_idxs, cand_idxs, logits,
                                            inputs["x_possible_actions"],
                                            inputs["x_in_adj_phase"])
        return decode_order_idxs(order_idxs[0, POWERS.index(power), :])
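
    # Usage sketch (hedged; the agent object and Game come from the surrounding
    # module):
    #
    #     game = Game()
    #     orders = agent.get_orders(game, "FRANCE", temperature=0.5)
    #     # e.g. ["A PAR - BUR", "A MAR S A PAR - BUR", "F BRE - MAO"]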
    def get_values(self, game) -> np.ndarray:
        batch_inputs = FeatureEncoder().encode_inputs([game])
        batch_est_final_scores = self.do_model_request(
            batch_inputs,
            self.rollout_temperature,
            self.rollout_top_p,
            values_only=True)
        return batch_est_final_scores[0]
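
    # Usage sketch: the value head yields one estimated final score per power,
    # so (using the module's POWERS list) France's value estimate would be:
    #
    #     values = agent.get_values(game)  # np.ndarray of length 7
    #     france_value = values[POWERS.index("FRANCE")]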
    def do_rollout(
        cls,
        *,
        game_json,
        hostport,
        set_orders_dict={},
        temperature,
        top_p,
        max_rollout_length,
        batch_size=1,
        use_predicted_final_scores,
        mix_square_ratio_scoring=0,
        value_hostport=None,
        rollout_value_frac=0,
    ) -> Tuple[Tuple[Dict, List[Dict]], TimingCtx]:
        """Complete game, optionally setting orders for the current turn

        This method can safely be called in a subprocess

        Arguments:
        - game_json: json-formatted game string, e.g. output of to_saved_game_format(game)
        - hostport: string, "{host}:{port}" of model server
        - set_orders_dict: Dict[power, orders] to set for current turn
        - temperature: model softmax temperature for rollout policy
        - top_p: probability mass to sample from for rollout policy
        - max_rollout_length: return SC count after at most # steps
        - batch_size: rollout # of games in parallel
        - use_predicted_final_scores: if True, use model's value head for final SC predictions
        - mix_square_ratio_scoring: weight of each power's current square-ratio
          score mixed into its final score (0 disables mixing)
        - value_hostport: optional "{host}:{port}" of a separate value-model server
        - rollout_value_frac: per-step value mixing fraction; must be 0 when a
          separate value model is used

        Returns a 2-tuple:
        - results, a 2-tuple:
          - set_orders_dict: Dict[power, orders]
          - list of Dict[power, final_score], len=batch_size
        - timings: a TimingCtx
        """
        timings = TimingCtx()

        with timings("postman.client"):
            client = postman.Client(hostport)
            client.connect(3)
            if value_hostport is not None:
                value_client = postman.Client(value_hostport)
                value_client.connect(3)
            else:
                value_client = client

        with timings("setup"):
            faulthandler.register(signal.SIGUSR2)
            torch.set_num_threads(1)

            games = [
                pydipcc.Game.from_json(game_json) for _ in range(batch_size)
            ]
            for i in range(len(games)):
                games[i].game_id += f"_{i}"

            est_final_scores = {}  # game id -> np.array len=7

            # set orders if specified
            for power, orders in set_orders_dict.items():
                for game in games:
                    game.set_orders(power, list(orders))

            other_powers = [p for p in POWERS if p not in set_orders_dict]

        rollout_start_phase = games[0].current_short_phase
        rollout_end_phase = n_move_phases_later(rollout_start_phase,
                                                max_rollout_length)
        while True:
            if max_rollout_length == 0:
                # Handled separately.
                break

            # exit loop if all games are done before max_rollout_length
            ongoing_game_phases = [
                game.current_short_phase for game in games
                if not game.is_game_done
            ]
            if len(ongoing_game_phases) == 0:
                break

            # step games together at the pace of the slowest game, e.g. process
            # games with retreat phases alone before moving on to the next move phase
            min_phase = min(ongoing_game_phases, key=sort_phase_key)

            batch_data = []
            for game in games:
                if not game.is_game_done and game.current_short_phase == min_phase:
                    with timings("encode.all_poss_orders"):
                        all_possible_orders = game.get_all_possible_orders()
                    with timings("encode.inputs"):
                        inputs = FeatureEncoder().encode_inputs([game])
                    batch_data.append((game, inputs))

            with timings("cat_pad"):
                xs: List[Tuple] = [b[1] for b in batch_data]
                batch_inputs = cls.cat_pad_inputs(xs)

            with timings("model"):
                if client != value_client:
                    assert (
                        rollout_value_frac == 0
                    ), "If separate value model, you can't add in value each step (slow)"

                cur_client = value_client if min_phase == rollout_end_phase else client

                batch_orders, _, batch_est_final_scores = cls.do_model_request(
                    batch_inputs, temperature, top_p, client=cur_client)

            if min_phase == rollout_end_phase:
                with timings("score.accumulate"):
                    for game_idx, (game, _) in enumerate(batch_data):
                        est_final_scores[game.game_id] = np.array(
                            batch_est_final_scores[game_idx])

                # skip env step and exit loop once we've accumulated the estimated
                # scores for all games up to max_rollout_length
                break

            with timings("env"):
                assert len(batch_data) == len(batch_orders), "{} != {}".format(
                    len(batch_data), len(batch_orders))

                # set_orders and process
                assert len(batch_data) == len(batch_orders)
                for (game, _), power_orders in zip(batch_data, batch_orders):
                    if game.is_game_done:
                        continue
                    power_orders = dict(zip(POWERS, power_orders))
                    for other_power in other_powers:
                        game.set_orders(other_power,
                                        list(power_orders[other_power]))

                    assert game.current_short_phase == min_phase
                    game.process()

            for (game, _) in batch_data:
                if game.is_game_done:
                    with timings("score.gameover"):
                        final_scores = np.array(
                            get_square_scores_from_game(game))
                        est_final_scores[game.game_id] = final_scores

            other_powers = POWERS  # no set orders on subsequent turns

        # out of rollout loop

        if max_rollout_length > 0:
            assert len(est_final_scores) == len(games)
        else:
            assert (
                not other_powers
            ), "If max_rollout_length=0 it's assumed that all orders are pre-defined."
            # All orders are set. Step env. Now only need to get values.
            for game in games:
                game.process()

            batch_data = []
            for game in games:
                if not game.is_game_done:
                    with timings("encode.inputs"):
                        inputs = FeatureEncoder().encode_inputs([game])
                    batch_data.append((game, inputs))
                else:
                    est_final_scores[
                        game.game_id] = get_square_scores_from_game(game)

            if batch_data:
                with timings("cat_pad"):
                    xs: List[Tuple] = [b[1] for b in batch_data]
                    batch_inputs = cls.cat_pad_inputs(xs)

                with timings("model"):
                    _, _, batch_est_final_scores = self.do_model_request(
                        batch_inputs, temperature, top_p, client=value_client)
                    assert batch_est_final_scores.shape[0] == len(batch_data)
                    assert batch_est_final_scores.shape[1] == len(POWERS)
                    for game_idx, (game, _) in enumerate(batch_data):
                        est_final_scores[
                            game.game_id] = batch_est_final_scores[game_idx]

        with timings("final_scores"):
            # get GameScores objects for current game state
            current_game_scores = [{
                p: compute_game_scores_from_state(i, game.get_state())
                for i, p in enumerate(POWERS)
            } for game in games]

            # get estimated or current sum of squares scoring
            final_game_scores = [
                dict(zip(POWERS, est_final_scores[game.game_id]))
                for game in games
            ]

            # mix in current sum of squares ratio to encourage losing powers to try hard
            if mix_square_ratio_scoring > 0:
                for game, final_scores, current_scores in zip(
                        games, final_game_scores, current_game_scores):
                    for p in POWERS:
                        final_scores[p] = (
                            1 - mix_square_ratio_scoring) * final_scores[p] + (
                                mix_square_ratio_scoring *
                                current_scores[p].square_ratio)
            result = (set_orders_dict, final_game_scores)

        return result, timings
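
    # Usage sketch (hedged; "SearchAgent" and the server setup are assumptions,
    # not part of this excerpt):
    #
    #     results, timings = SearchAgent.do_rollout(
    #         game_json=to_saved_game_format(game),
    #         hostport="localhost:24565",
    #         set_orders_dict={"FRANCE": ("A PAR - BUR",)},
    #         temperature=0.5,
    #         top_p=0.9,
    #         max_rollout_length=3,
    #         batch_size=8,
    #         use_predicted_final_scores=True,
    #     )
    #     set_orders_dict, score_dicts = results  # len(score_dicts) == batch_size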
    def preprocess(self):
        """
        Pre-processes dataset
        :return:
        """
        assert not self.debug_only_opening_phase, "FIXME"

        logging.info(
            f"Building dataset from {len(self.game_ids)} games, "
            f"only_with_min_final_score={self.only_with_min_final_score} "
            f"value_decay_alpha={self.value_decay_alpha} cf_agent={self.cf_agent}"
        )

        torch.set_num_threads(1)
        encoded_game_tuples = self.mp_encode_games()

        # remove "empty" games (e.g. json didn't exist)
        encoded_games = [g for (_, g) in encoded_game_tuples if g is not None]

        logging.info(
            f"Found data for {len(encoded_games)} / {len(self.game_ids)} games"
        )

        encoded_games = [
            g for g in encoded_games if g["valid_power_idxs"][0].any()
        ]
        logging.info(
            f"{len(encoded_games)} games had data for at least one power")

        # Update game_ids
        self.game_ids = [
            g_id for (g_id, g) in encoded_game_tuples
            if g_id is not None and g["valid_power_idxs"][0].any()
        ]

        game_idxs, phase_idxs, power_idxs, x_idxs = [], [], [], []
        x_idx = 0
        for game_idx, encoded_game in enumerate(encoded_games):
            for phase_idx, valid_power_idxs in enumerate(
                    encoded_game["valid_power_idxs"]):
                assert valid_power_idxs.nelement() == len(POWERS), (
                    encoded_game["valid_power_idxs"].shape,
                    valid_power_idxs.shape,
                )
                for power_idx in valid_power_idxs.nonzero()[:, 0]:
                    game_idxs.append(game_idx)
                    phase_idxs.append(phase_idx)
                    power_idxs.append(power_idx)
                    x_idxs.append(x_idx)
                x_idx += 1

        self.game_idxs = torch.tensor(game_idxs, dtype=torch.long)
        self.phase_idxs = torch.tensor(phase_idxs, dtype=torch.long)
        self.power_idxs = torch.tensor(power_idxs, dtype=torch.long)
        self.x_idxs = torch.tensor(x_idxs, dtype=torch.long)

        # now collate the data into giant tensors!
        self.encoded_games = DataFields.cat(encoded_games)

        self.num_games = len(encoded_games)
        self.num_phases = len(
            self.encoded_games["x_board_state"]) if self.encoded_games else 0
        self.num_elements = len(self.x_idxs)

        self.validate_dataset()

        self._preprocessed = True
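
    # Index bookkeeping sketch: each phase of each game occupies one row (x_idx)
    # in the collated tensors, and each valid (phase, power) pair becomes one
    # dataset element. E.g. for 2 games of 2 phases each, where only power 0 is
    # valid, the parallel index tensors would be:
    #
    #     game_idxs  = [0, 0, 1, 1]
    #     phase_idxs = [0, 1, 0, 1]
    #     power_idxs = [0, 0, 0, 0]
    #     x_idxs     = [0, 1, 2, 3]   # row into self.encoded_games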
def encode_phase(
    encoder: FeatureEncoder,
    game: Game,
    game_id: str,
    phase_idx: int,
    *,
    only_with_min_final_score: Optional[int],
    cf_agent=None,
    n_cf_agent_samples=1,
    value_decay_alpha,
    input_valid_power_idxs,
    exclude_n_holds,
):
    """
    Arguments:
    - game: Game object
    - game_id: unique id for game
    - phase_idx: int, the index of the phase to encode
    - only_with_min_final_score: if specified, only encode for powers who
      finish the game with some # of supply centers (i.e. only learn from
      winners). MILA uses 7.

    Returns: DataFields, including y_actions and y_final_score
    """
    phase = game.get_phase_history()[phase_idx]
    rolled_back_game = game.rolled_back_to_phase_start(phase.name)
    data_fields = encoder.encode_inputs([rolled_back_game])

    # encode final scores
    y_final_scores = encode_weighted_sos_scores(game, phase_idx,
                                                value_decay_alpha)

    # encode actions
    valid_power_idxs = torch.tensor(input_valid_power_idxs, dtype=torch.bool)
    y_actions_lst = []
    power_orders_samples = ({
        power: [phase.orders.get(power, [])]
        for power in POWERS
    } if cf_agent is None else get_cf_agent_order_samples(
        rolled_back_game, phase.name, cf_agent, n_cf_agent_samples))
    for power_i, power in enumerate(POWERS):
        orders_samples = power_orders_samples[power]
        if len(orders_samples) == 0:
            valid_power_idxs[power_i] = False
            y_actions_lst.append(
                torch.empty(n_cf_agent_samples, MAX_SEQ_LEN,
                            dtype=torch.int32).fill_(EOS_IDX))
            continue
        encoded_power_actions_lst = []
        for orders in orders_samples:
            encoded_power_actions, valid = encode_power_actions(
                orders, data_fields["x_possible_actions"][0, power_i])
            encoded_power_actions_lst.append(encoded_power_actions)
            if 0 <= exclude_n_holds <= len(orders):
                if all(o.endswith(" H") for o in orders):
                    valid = 0
            valid_power_idxs[power_i] &= valid
        y_actions_lst.append(torch.stack(encoded_power_actions_lst,
                                         dim=0))  # [N, 17]

    y_actions = torch.stack(y_actions_lst, dim=0)  # [7, N, 17]

    # filter away powers that have no orders
    valid_power_idxs &= (y_actions != EOS_IDX).any(dim=2).all(dim=1)
    assert valid_power_idxs.ndimension() == 1

    # Maybe filter away powers that don't finish with enough SC.
    # If all players finish with fewer SC, include everybody.
    # cf. get_top_victors() in mila's state_space.py
    if only_with_min_final_score is not None:
        final_score = {
            k: len(v)
            for k, v in game.get_state()["centers"].items()
        }
        if max(final_score.values()) >= only_with_min_final_score:
            for i, power in enumerate(POWERS):
                if final_score.get(power, 0) < only_with_min_final_score:
                    valid_power_idxs[i] = 0

    data_fields["y_final_scores"] = y_final_scores.unsqueeze(0)
    data_fields["y_actions"] = y_actions.unsqueeze(0)
    data_fields["valid_power_idxs"] = valid_power_idxs.unsqueeze(0)
    data_fields["x_possible_actions"] = TensorList.from_padded(
        data_fields["x_possible_actions"].view(
            len(POWERS) * MAX_SEQ_LEN, MAX_VALID_LEN),
        padding_value=EOS_IDX,
    )

    return data_fields
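
# A hedged example of encoding a single phase with the defaults used elsewhere
# in this file; the argument values here are illustrative assumptions:
def example_encode_one_phase(game: Game, game_id: str) -> DataFields:
    # every returned field gains a leading batch dim, e.g.
    # fields["y_actions"].shape == (1, 7, n_cf_agent_samples, MAX_SEQ_LEN)
    return encode_phase(
        FeatureEncoder(),
        game,
        game_id,
        phase_idx=0,
        only_with_min_final_score=7,  # MILA's choice, per the docstring
        value_decay_alpha=1.0,  # assumed: no decay
        input_valid_power_idxs=[True] * len(POWERS),
        exclude_n_holds=-1,  # negative disables the all-holds filter
    )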
def encode_game(
    game_id: Union[int, str],
    data_dir: str,
    only_with_min_final_score=7,
    *,
    cf_agent=None,
    n_cf_agent_samples=1,
    input_valid_power_idxs,
    value_decay_alpha,
    game_metadata,
    exclude_n_holds,
):
    """
    Arguments:
    - game_id: numeric game id, or a direct path to the game json
    - data_dir: directory containing game_{game_id}.json files
    - only_with_min_final_score: if specified, only encode for powers who
      finish the game with some # of supply centers (i.e. only learn from
      winners). MILA uses 7.
    - input_valid_power_idxs: bool tensor, true if power should a priori be included in
      the dataset (based on e.g. player rating)
    Return: game_id, DataFields dict of tensors:
    L is game length, P is # of powers above min_final_score, N is n_cf_agent_samples
    - board_state: shape=(L, 81, 35)
    - prev_state: shape=(L, 81, 35)
    - prev_orders: shape=(L, 2, 100), dtype=long
    - power: shape=(L, 7, 7)
    - season: shape=(L, 3)
    - in_adj_phase: shape=(L, 1)
    - build_numbers: shape=(L, 7)
    - final_scores: shape=(L, 7)
    - possible_actions: TensorList shape=(L x 7, 17 x 469)
    - loc_idxs: shape=(L, 7, 81), int8
    - actions: shape=(L, 7, N, 17) int order idxs, N=n_cf_agent_samples
    - valid_power_idxs: shape=(L, 7) bool mask of valid powers at each phase
    """

    torch.set_num_threads(1)
    encoder = FeatureEncoder()

    if isinstance(game_id, str):  # hacky fix to handle game_ids that are paths
        game_path = game_id
    else:
        game_path = os.path.join(data_dir, f"game_{game_id}.json")

    try:
        with open(game_path) as f:
            game = Game.from_json(f.read())
    except (FileNotFoundError, json.decoder.JSONDecodeError) as e:
        print(f"Error while loading game at {game_path}: {e}")
        return None, None

    num_phases = len(game.get_phase_history())
    logging.info(f"Encoding {game.game_id} with {num_phases} phases")

    phase_encodings = [
        encode_phase(
            encoder,
            game,
            game_id,
            phase_idx,
            only_with_min_final_score=only_with_min_final_score,
            cf_agent=cf_agent,
            n_cf_agent_samples=n_cf_agent_samples,
            value_decay_alpha=value_decay_alpha,
            input_valid_power_idxs=input_valid_power_idxs,
            exclude_n_holds=exclude_n_holds,
        ) for phase_idx in range(num_phases)
    ]

    stacked_encodings = DataFields.cat(phase_encodings)

    return game_id, stacked_encodings.to_storage_fmt_()
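
# A sketch of fanning encode_game out across workers, in the spirit of the
# mp_encode_games() call in preprocess(); the argument values are assumptions:
def example_encode_games_parallel(game_ids, data_dir, n_procs=8):
    from functools import partial
    from multiprocessing import Pool

    fn = partial(
        encode_game,
        data_dir=data_dir,
        only_with_min_final_score=7,
        input_valid_power_idxs=[True] * len(POWERS),  # assumed: include all powers
        value_decay_alpha=1.0,  # assumed: no decay
        game_metadata=None,  # unused in the body shown above
        exclude_n_holds=-1,  # negative disables the all-holds filter
    )
    # each item is (game_id, DataFields) or (None, None) on load failure
    with Pool(n_procs) as pool:
        return pool.map(fn, game_ids)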
    def do_rollouts(self,
                    game_init,
                    set_orders_dicts,
                    average_n_rollouts=1,
                    timings=None,
                    log_timings=False):
        if timings is None:
            timings = TimingCtx()

        if self.clear_old_all_possible_orders:
            with timings("clear_old_orders"):
                game_init = pydipcc.Game(game_init)
                game_init.clear_old_all_possible_orders()
        with timings("clone"):
            games = [
                pydipcc.Game(game_init)  # clones
                for _ in set_orders_dicts for _ in range(average_n_rollouts)
            ]
        with timings("setup"):
            for i in range(len(games)):
                games[i].game_id += f"_{i}"
            est_final_scores = {}  # game id -> np.array len=7

            # set orders if specified; repeat() pairs each set_orders_dict with
            # its average_n_rollouts cloned games (matching the clones above)
            for game, set_orders_dict in zip(
                    games, repeat(set_orders_dicts, average_n_rollouts)):
                for power, orders in set_orders_dict.items():
                    game.set_orders(power, list(orders))

            # for each game, a list of powers whose orders need to be generated
            # by the model on the first phase.
            missing_start_orders = {
                game.game_id:
                frozenset(p for p in POWERS if p not in set_orders_dict)
                for game, set_orders_dict in zip(
                    games, repeat(set_orders_dicts, average_n_rollouts))
            }

        if self.max_rollout_length > 0:
            rollout_end_phase_id = sort_phase_key(
                n_move_phases_later(game_init.current_short_phase,
                                    self.max_rollout_length))
            max_steps = 1000000
        else:
            # Really far ahead.
            rollout_end_phase_id = sort_phase_key(
                n_move_phases_later(game_init.current_short_phase, 10))
            max_steps = 1

        # This loop steps the games until one of the conditions is true:
        #   - all games are done
        #   - at least one game was stepped for max_steps steps
        #   - all games are either completed or reach a phase such that
        #     sort_phase_key(phase) >= rollout_end_phase_id
        for step_id in range(max_steps):
            ongoing_game_phases = [
                game.current_short_phase for game in games
                if not game.is_game_done
            ]

            if len(ongoing_game_phases) == 0:
                # all games are done
                break

            # step games together at the pace of the slowest game, e.g. process
            # games with retreat phases alone before moving on to the next move phase
            min_phase = min(ongoing_game_phases, key=sort_phase_key)

            if sort_phase_key(min_phase) >= rollout_end_phase_id:
                break

            games_to_step = [
                game for game in games if not game.is_game_done
                and game.current_short_phase == min_phase
            ]

            if step_id > 0 or any(missing_start_orders.values()):
                with timings("encoding"):
                    batch_inputs = FeatureEncoder().encode_inputs(
                        games_to_step)

                batch_orders, _, _ = self.do_model_request(
                    batch_inputs,
                    self.rollout_temperature,
                    self.rollout_top_p,
                    timings=timings)

                with timings("env.set_orders"):
                    assert len(games_to_step) == len(batch_orders)
                    for game, orders_per_power in zip(games_to_step,
                                                      batch_orders):
                        for power, orders in zip(POWERS, orders_per_power):
                            if step_id == 0 and power not in missing_start_orders[
                                    game.game_id]:
                                continue
                            game.set_orders(power, list(orders))

            with timings("env.step"):
                self.thread_pool.process_multi(games_to_step)

        # Compute SoS for done game and query the net for not-done games.
        not_done_games = [game for game in games if not game.is_game_done]
        if not_done_games:
            with timings("encoding"):
                batch_inputs = FeatureEncoder().encode_inputs(not_done_games)

            batch_est_final_scores = self.do_model_request(
                batch_inputs,
                self.rollout_temperature,
                self.rollout_top_p,
                timings=timings,
                values_only=True,
            )
            for game_idx, game in enumerate(not_done_games):
                est_final_scores[game.game_id] = np.array(
                    batch_est_final_scores[game_idx])
        for game in games:
            if game.is_game_done:
                est_final_scores[game.game_id] = np.array(
                    get_square_scores_from_game(game))

        with timings("final_scores"):
            final_game_scores = [
                dict(zip(POWERS, est_final_scores[game.game_id]))
                for game in games
            ]

            # mix in current sum of squares ratio to encourage losing powers to try hard
            if self.mix_square_ratio_scoring > 0:
                for game, final_scores in zip(games, final_game_scores):
                    current_scores = game.get_square_scores()
                    for pi, p in enumerate(POWERS):
                        final_scores[p] = (1 - self.mix_square_ratio_scoring
                                           ) * final_scores[p] + (
                                               self.mix_square_ratio_scoring *
                                               current_scores[pi])

            r = [(set_orders_dict, average_score_dicts(scores_dicts))
                 for set_orders_dict, scores_dicts in zip(
                     set_orders_dicts,
                     groups_of(final_game_scores, average_n_rollouts))]

        if log_timings:
            timings.pprint(logging.getLogger("timings").info)

        return r
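
    # Usage sketch (hedged; the agent object comes from outside this excerpt):
    # evaluate two candidate order-sets for FRANCE, averaging 2 rollouts each:
    #
    #     results = agent.do_rollouts(
    #         game,
    #         set_orders_dicts=[
    #             {"FRANCE": ("A PAR - BUR",)},
    #             {"FRANCE": ("A PAR H",)},
    #         ],
    #         average_n_rollouts=2,
    #     )
    #     for set_orders_dict, avg_scores in results:
    #         print(set_orders_dict["FRANCE"], avg_scores["FRANCE"])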
def yield_rollouts(
    *,
    exploit_hostports: Sequence[str],
    blueprint_hostports: Optional[Sequence[str]],
    game_json_paths: Optional[Sequence[str]],
    mode: RolloutMode,
    fast_finish: bool = False,
    temperature=0.05,
    max_rollout_length=40,
    batch_size=1,
    seed=0,
) -> Generator[ExploitRollout, None, None]:
    """Do non-stop rollout for 1 (exploit) vs 6 (blueprint).

    This method can safely be called in a subprocess

    Arguments:
    - exploit_hostports: list of "{host}:{port}" of model servers for the
        agents that is training.
    - blueprint_hostports: list of "{host}:{port}" of model servers for the
        agents that is exploited. Ignored in SELFPLAY model
    - game_jsons: either None or a list of paths to json-serialized games.
    - mode: what kind of rollout to do. Defiens what will be outputed.
    - fast_finish: if True, the rollout is stopped once all non-blueprint agents lost.
    - temperature: model softmax temperature for rollout policy on the blueprint agent.
    - max_rollout_length: return SC count after at most # steps
    - batch_size: rollout # of games in parallel
    - seed: random seed.

    yields a ExploitRollout.

    """
    timings = TimingCtx()

    def create_client_selector(hostports):
        clients = []
        for hostport in hostports:
            client = postman.Client(hostport)
            client.connect(3)
            clients.append(client)
        return iter(itertools.cycle(clients))

    exploit_client_selector = create_client_selector(exploit_hostports)
    if mode != RolloutMode.SELFPLAY:
        assert blueprint_hostports is not None
        blueprint_client_selector = create_client_selector(blueprint_hostports)

    faulthandler.register(signal.SIGUSR2)
    torch.set_num_threads(1)

    exploit_power_selector = itertools.cycle(tuple(range(len(POWERS))))
    for _ in range(seed % len(POWERS)):
        next(exploit_power_selector)

    def yield_game():
        nonlocal game_json_paths
        nonlocal seed

        while True:
            if game_json_paths is None:
                yield Game()
            else:
                rng = np.random.RandomState(seed=seed)
                p = game_json_paths[rng.choice(len(game_json_paths))]
                # advance the seed so successive games differ (this is why
                # seed is declared nonlocal above)
                seed += 1
                with open(p) as stream:
                    game_serialized = stream.read()
                yield Game.from_json(game_serialized)

    game_selector = yield_game()

    for exploit_power_id in exploit_power_selector:
        if mode == RolloutMode.SELFPLAY:
            # Not used.
            del exploit_power_id
            exploit_ids = frozenset(range(len(POWERS)))
        else:
            exploit_ids = frozenset([exploit_power_id])

        with timings("setup"):
            games = [next(game_selector) for _ in range(batch_size)]
            first_phases = [len(game.get_phase_history()) for game in games]
            turn_idx = 0
            observations = {i: [] for i in range(batch_size)}
            actions = {i: [] for i in range(batch_size)}
            cand_actions = {i: [] for i in range(batch_size)}
            logprobs = {i: [] for i in range(batch_size)}

        while turn_idx < max_rollout_length:
            with timings("prep"):
                batch_data = []
                for batch_idx, game in enumerate(games):
                    if game.is_game_done:
                        continue
                    if fast_finish:
                        last_centers = game.get_state()["centers"]
                        if not any(
                            len(last_centers[p]) > 0
                            for i, p in enumerate(POWERS)
                            if i not in exploit_ids
                        ):
                            continue
                    inputs = FeatureEncoder().encode_inputs([game])
                    batch_data.append((game, inputs, batch_idx))

            if not batch_data:
                # All games are done.
                break

            with timings("cat_pad"):
                xs: List[Tuple] = [b[1] for b in batch_data]
                batch_inputs = cat_pad_inputs(xs)

            with timings("model"):
                if mode != RolloutMode.SELFPLAY:
                    (blueprint_batch_order_ids,) = do_model_request(
                        next(blueprint_client_selector), batch_inputs, temperature
                    )
                (
                    exploit_batch_order_ids,
                    exploit_cand_ids,
                    exploit_order_logprobs,
                ) = do_model_request(next(exploit_client_selector), batch_inputs)

            with timings("merging"):
                if mode == RolloutMode.SELFPLAY:
                    batch_order_idx = exploit_batch_order_ids
                else:
                    # Use all orders from the blueprint model except those for the exploit power.
                    batch_order_idx = blueprint_batch_order_ids
                    batch_order_idx[:, exploit_power_id] = exploit_batch_order_ids[
                        :, exploit_power_id
                    ]

                batch_orders = strigify_orders_idxs(batch_order_idx)

            with timings("env"):
                assert len(batch_data) == len(batch_orders), "{} != {}".format(
                    len(batch_data), len(batch_orders)
                )

                # set_orders and process
                for (game, _, _), power_orders in zip(batch_data, batch_orders):
                    for power, orders in zip(POWERS, power_orders):
                        game.set_orders(power, list(orders))

                for inner_index, (game, _, i) in enumerate(batch_data):
                    if not game.is_game_done:
                        game.process()
                        if mode == RolloutMode.SELFPLAY:
                            actions[i].append(exploit_batch_order_ids[inner_index])
                            cand_actions[i].append(exploit_cand_ids[inner_index])
                            logprobs[i].append(exploit_order_logprobs[inner_index])
                        else:
                            actions[i].append(
                                exploit_batch_order_ids[inner_index, exploit_power_id]
                            )
                            cand_actions[i].append(exploit_cand_ids[inner_index, exploit_power_id])
                            logprobs[i].append(
                                exploit_order_logprobs[inner_index, exploit_power_id]
                            )
                        observations[i].append(batch_inputs.select(inner_index))

            turn_idx += 1

        logging.debug(
            f"end do_rollout pid {os.getpid()} for {batch_size} games in {turn_idx} turns. timings: "
            f"{ {k : float('{:.3}'.format(v)) for k, v in timings.items()} }."
        )
        for i in range(batch_size):
            final_game_json = json.loads(games[i].to_json())
            if mode == RolloutMode.SELFPLAY:
                for power_id in range(len(POWERS)):
                    extended_obs = DataFields.stack(observations[i])
                    extended_obs["cand_indices"] = torch.stack(
                        [x[power_id] for x in cand_actions[i]], 0
                    )
                    yield ExploitRollout(
                        power_id=power_id,
                        game_json=final_game_json,
                        actions=torch.stack([x[power_id] for x in actions[i]], 0),
                        logprobs=torch.stack([x[power_id] for x in logprobs[i]], 0),
                        observations=extended_obs,
                        first_phase=first_phases[i],
                    )
            else:
                extended_obs = DataFields.stack(observations[i])
                extended_obs["cand_indices"] = torch.stack(cand_actions[i], 0)
                yield ExploitRollout(
                    power_id=exploit_power_id,
                    game_json=final_game_json,
                    actions=torch.stack(actions[i], 0),
                    logprobs=torch.stack(logprobs[i], 0),
                    observations=extended_obs,
                    first_phase=first_phases[i],
                )
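
# A hedged consumption sketch: stream self-play rollouts from a single model
# server; the hostport value is an assumption:
def example_consume_rollouts(hostport: str, n: int):
    rollouts = yield_rollouts(
        exploit_hostports=[hostport],
        blueprint_hostports=None,  # ignored in SELFPLAY mode
        game_json_paths=None,  # start every game from scratch
        mode=RolloutMode.SELFPLAY,
        max_rollout_length=10,
        batch_size=2,
    )
    for rollout in itertools.islice(rollouts, n):
        print(rollout.power_id, rollout.actions.shape, rollout.first_phase)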
    def get_plausible_orders(
        self,
        game,
        *,
        n=1000,
        temperature=1.0,
        limit: Union[int, Sequence[int]],  # limit, or list of limits per power
        batch_size=500,
        top_p=1.0,
    ) -> Dict[str, Dict[Tuple[str], float]]:
        assert n % batch_size == 0, f"{n}, {batch_size}"

        # limits is a list of 7 limits, one per power
        limits = [limit] * 7 if isinstance(limit, int) else limit
        assert len(limits) == 7
        del limit

        # query the model for sampled order-sets
        counters = {p: Counter() for p in POWERS}

        # encode once; n references to the same read-only inputs, batched in chunks below
        x = [FeatureEncoder().encode_inputs([game])] * n

        orders_to_logprobs = {}
        for x_chunk in [x[i : i + batch_size] for i in range(0, n, batch_size)]:
            batch_inputs = self.cat_pad_inputs(x_chunk)
            batch_orders, batch_order_logprobs, _ = self.do_model_request(
                batch_inputs, temperature, top_p
            )
            batch_orders = list(zip(*batch_orders))  # power -> list[orders]
            batch_order_logprobs = batch_order_logprobs.t()  # [7 x B]
            for p, power in enumerate(POWERS):
                counters[power].update(batch_orders[p])

            # slow and steady
            for power_orders, power_scores in zip(batch_orders, batch_order_logprobs):
                for order, score in zip(power_orders, power_scores):
                    if order not in orders_to_logprobs:
                        orders_to_logprobs[order] = score
                    assert (
                        abs(orders_to_logprobs[order] - score) < 1e-2
                    ), f"{order} : {orders_to_logprobs[order]} != {score}"

        logging.info(
            "get_plausible_orders(n={}, t={}) found {} unique sets, choosing top {}".format(
                n, temperature, list(map(len, counters.values())), limits
            )
        )

        # filter out badly-coordinated actions
        counters = {
            power: (
                filter_keys(counter, self.is_plausible_orders) if len(counter) > limit else counter
            )
            for (power, counter), limit in zip(counters.items(), limits)
        }

        most_common = {
            power: sorted(counter.most_common(), key=lambda o: -orders_to_logprobs[o[0]])[:limit]
            for (power, counter), limit in zip(counters.items(), limits)
        }

        try:
            logging.info(
                "get_plausible_orders filtered down to {} unique sets, n_0={}, n_cut={}".format(
                    list(map(len, counters.values())),
                    [safe_idx(most_common[p], 0, default=(None, None))[1] for p in POWERS],
                    [
                        safe_idx(most_common[p], limit - 1, default=(None, None))[1]
                        for (p, limit) in zip(POWERS, limits)
                    ],
                )
            )
        except Exception:
            # TODO: remove this if not seen in production
            logging.warning("error in get_plausible_orders logging")

        logging.info("Plausible orders:")
        logging.info("        count,count_frac,prob")
        for power, orders_and_counts in most_common.items():
            logging.info(f"    {power}")
            for orders, count in orders_and_counts:
                logging.info(
                    f"        {count:5d} {count/n:10.5f} {np.exp(orders_to_logprobs[orders]):10.5f}  {orders}"
                )

        return {
            power: {orders: orders_to_logprobs[orders] for orders, _ in orders_and_counts}
            for power, orders_and_counts in most_common.items()
        }
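
    # Usage sketch: sample 1000 action-sets at temperature 1.0 and keep the
    # top 5 per power (a single int limit applies to every power):
    #
    #     plausible = agent.get_plausible_orders(game, n=1000, limit=5)
    #     for orders, logprob in plausible["FRANCE"].items():
    #         print(np.exp(logprob), orders)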