import time

import torch

# Project-local helpers used below (Game, FeatureEncoder, load_diplomacy_model,
# load_late_game, decode_order_idxs, resample_duplicate_disbands_inplace, POWERS)
# are assumed to be imported from the surrounding package.

B = [1, 8, 64]  # batch sizes to sweep; illustrative values, normally configured elsewhere
N = 10  # timed forward passes per batch size; illustrative value


def profile_model(model_path):
    # Profile on both a fresh game and a saved late-game position.
    late_game = load_late_game()

    # Load the trained model onto the GPU in eval mode.
    model = load_diplomacy_model(model_path, map_location="cuda", eval=True)

    for game_name, game in [("new_game", Game()), ("late_game", late_game)]:
        print("\n#", game_name)
        inputs = FeatureEncoder().encode_inputs([game])
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

        for batch_size in B:
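            # Tile the single-game encoding along the batch dimension, then time N forward passes.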
            b_inputs = {
                k: v.repeat((batch_size,) + (1,) * (len(v.shape) - 1))
                for k, v in inputs.items()
            }
            with torch.no_grad():
                torch.cuda.synchronize()  # don't let earlier queued kernels leak into the timing
                tic = time.time()
                for _ in range(N):
                    order_idxs, order_scores, cand_scores, final_scores = model(
                        **b_inputs, temperature=1.0)
                torch.cuda.synchronize()  # CUDA kernels are asynchronous; wait before stopping the clock
                toc = time.time() - tic

                print(
                    f"[B={batch_size}] {toc:.3f}s / {N}, "
                    f"latency={1000 * toc / N:.1f}ms, "
                    f"throughput={N * batch_size / toc:.1f}/s"
                )
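

# Example invocation (sketch; the checkpoint path is hypothetical):
#
#     profile_model("models/dipnet.pth")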


class DipnetAgent:
    # Class name assumed; get_orders is a method of the agent class that owns
    # self.model, self.device, self.temperature, and self.top_p.

    def get_orders(self, game, power, *, temperature=None, top_p=None):
        # Nothing is orderable for this power in the current phase.
        if len(game.get_orderable_locations().get(power, [])) == 0:
            return []

        # Per-call overrides fall back to the agent-level defaults.
        temperature = temperature if temperature is not None else self.temperature
        top_p = top_p if top_p is not None else self.top_p

        # Encode the current game state and move the tensors to the model's device.
        inputs = FeatureEncoder().encode_inputs([game])
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            # A single forward pass samples an action for every power at once.
            order_idxs, cand_idxs, logits, final_scores = self.model(
                **inputs, temperature=temperature, top_p=top_p)

        # In adjustment phases the sampler can pick the same disband more than once;
        # resample such duplicates in place before decoding.
        resample_duplicate_disbands_inplace(order_idxs, cand_idxs, logits,
                                            inputs["x_possible_actions"],
                                            inputs["x_in_adj_phase"])

        # Decode only the requested power's sampled order indices into order strings.
        return decode_order_idxs(order_idxs[0, POWERS.index(power), :])
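

# Usage sketch (the constructor is not shown in this file, so its arguments are
# assumptions; power names follow the POWERS list, e.g. "FRANCE"):
#
#     agent = DipnetAgent(...)  # load model, choose device, set default temperature/top_p
#     orders = agent.get_orders(Game(), "FRANCE", temperature=0.5)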