def load_games(game_jsons) -> List[Game]:
    if not game_jsons:
        return [Game()]

    games = []
    for path in game_jsons:
        with open(path) as f:
            game_json = f.read()
        games.append(Game.from_json(game_json))
    return games
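
# A minimal usage sketch (hypothetical file names): load_games falls back to
# a single fresh Game when given None or an empty list, so callers can treat
# "no JSONs supplied" and "JSONs supplied" uniformly.
def _demo_load_games():
    games = load_games(["game_1.json", "game_2.json"])  # hypothetical paths
    assert all(isinstance(g, Game) for g in games)
    assert len(load_games(None)) == 1  # no paths -> [Game()]
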
def make_game_yielder(game_json_paths, seed):
    # Enclosing function reconstructed from the `nonlocal` declarations below;
    # the original outer definition was truncated in this excerpt, and the
    # name `make_game_yielder` is a best guess.
    def yield_game():
        nonlocal game_json_paths
        nonlocal seed

        while True:
            if game_json_paths is None:
                yield Game()
            else:
                rng = np.random.RandomState(seed=seed)
                p = game_json_paths[rng.choice(len(game_json_paths))]
                with open(p) as stream:
                    game_serialized = stream.read()
                yield Game.from_json(game_serialized)
                # Advance the seed so successive draws differ; this mutation
                # is implied by the `nonlocal seed` declaration above.
                seed += 1

    return yield_game
def profile_model(model_path):
    late_game = load_late_game()

    model = load_diplomacy_model(model_path, map_location="cuda", eval=True)

    for game_name, game in [("new_game", Game()), ("late_game", late_game)]:
        print("\n#", game_name)
        inputs = FeatureEncoder().encode_inputs([game])
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

        for batch_size in B:
            b_inputs = {
                k: v.repeat((batch_size, ) + (1, ) * (len(v.shape) - 1))
                for k, v in inputs.items()
            }
            with torch.no_grad():
                tic = time.time()
                for _ in range(N):
                    order_idxs, order_scores, cand_scores, final_scores = model(
                        **b_inputs, temperature=1.0)
                toc = time.time() - tic

                print(
                    f"[B={batch_size}] {toc}s / {N}, latency={1000*toc/N}ms, throughput={N*batch_size/toc}/s"
                )
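
# profile_model reads two module-level constants defined outside this
# excerpt: B (batch sizes to sweep) and N (timed forward passes per batch
# size). Plausible placeholder values, stated here as assumptions:
B = [1, 8, 64]  # hypothetical batch-size sweep
N = 10  # hypothetical number of timed iterations per batch size
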
def load_late_game():
    late_game_path = "/checkpoint/jsgray/diplomacy/slurm/cmp_mila__mila/game_TUR.2.json"
    with open(late_game_path) as f:
        late_game = Game.from_json(f.read())
    late_game.set_phase_data(late_game.get_phase_history()[-2])
    return late_game
Example No. 5
class Env:
    def __init__(self,
                 policy_profile,
                 seed=0,
                 cf_agent=None,
                 max_year=PYDIPCC_MAX_YEAR,
                 game_obj=None):
        self.game = Game(game_obj) if game_obj is not None else Game()

        # set random seeds
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.policy_profile = policy_profile
        self.cf_agent = cf_agent
        assert max_year <= PYDIPCC_MAX_YEAR, (
            f"pydipcc doesn't allow going beyond {PYDIPCC_MAX_YEAR}")
        self.max_year = max_year

    def process_turn(self, timeout=10):
        logging.debug("Starting turn {}".format(self.game.phase))

        power_orders = self.policy_profile.get_all_power_orders(self.game)

        for power, orders in power_orders.items():
            if not self.game.get_orderable_locations().get(power):
                logging.debug(f"Skipping orders for {power}")
                continue
            logging.info("Set orders {} {} {}".format(
                self.game.current_short_phase, power, orders))
            if self.cf_agent:
                cf_orders = self.cf_agent.get_orders(self.game, power)
                logging.debug("CF  orders {} {} {}".format(
                    self.game.current_short_phase, power, cf_orders))
            self.game.set_orders(power, orders)

        self.game.process()

    def process_all_turns(self, max_turns=0):
        """Process all turns until game is over

        Returns a dict mapping power -> supply count
        """
        turn_id = 0
        while not self.game.is_game_done:
            if max_turns and turn_id >= max_turns:
                break
            _, year, _ = self.game.phase.split()
            if int(year) > self.max_year:
                break
            self.process_turn()
            turn_id += 1

        return {k: len(v) for k, v in self.game.get_state()["centers"].items()}

    def save(self, output_path):
        logging.info("Saving to {}".format(output_path))
        with open(output_path, "w") as stream:
            stream.write(self.game.to_json())
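
# A usage sketch for Env, assuming policy_profile exposes
# get_all_power_orders(game) -> {power: orders}, as called in process_turn
# above (the concrete profile class is outside this excerpt):
def _demo_env(policy_profile):
    env = Env(policy_profile, seed=0, max_year=1905)
    center_counts = env.process_all_turns(max_turns=20)  # {power: SC count}
    env.save("/tmp/demo_game.json")  # hypothetical output path
    return center_counts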
Example No. 6
def compute_xpower_supports_from_saved(path, max_year=None, cf_agent=None):
    """Computes cross-power supports from a JSON file in fairdiplomacy saved game format."""
    with open(path) as f:
        try:
            j = json.load(f)
        except json.JSONDecodeError as e:
            # JSONDecodeError has no .message attribute; re-raise with the
            # offending path attached for easier debugging.
            raise json.JSONDecodeError(
                f"Error loading {path}: {e.msg}", e.doc, e.pos) from e

    game = Game.from_saved_game_format(j)
    return compute_xpower_supports(game,
                                   max_year=max_year,
                                   cf_agent=cf_agent,
                                   name=os.path.basename(path))
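
# A usage sketch (hypothetical paths): aggregate the per-game statistics
# returned above across a batch of saved games.
def _demo_xpower_rate(paths):
    stats = [compute_xpower_supports_from_saved(p, max_year="1905") for p in paths]
    total_supports = sum(s["s"] for s in stats)
    total_xpower = sum(s["x"] for s in stats)
    return total_xpower / max(total_supports, 1)  # cross-power share of supports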
def encode_phase(
    encoder: FeatureEncoder,
    game: Game,
    game_id: str,
    phase_idx: int,
    *,
    only_with_min_final_score: Optional[int],
    cf_agent=None,
    n_cf_agent_samples=1,
    value_decay_alpha,
    input_valid_power_idxs,
    exclude_n_holds,
):
    """
    Arguments:
    - game: Game object
    - game_id: unique id for game
    - phase_idx: int, the index of the phase to encode
    - only_with_min_final_score: if specified, only encode for powers who
      finish the game with some # of supply centers (i.e. only learn from
      winners). MILA uses 7.

    Returns: DataFields, including y_actions and y_final_score
    """
    phase = game.get_phase_history()[phase_idx]
    rolled_back_game = game.rolled_back_to_phase_start(phase.name)
    data_fields = encoder.encode_inputs([rolled_back_game])

    # encode final scores
    y_final_scores = encode_weighted_sos_scores(game, phase_idx,
                                                value_decay_alpha)

    # encode actions
    valid_power_idxs = torch.tensor(input_valid_power_idxs, dtype=torch.bool)
    y_actions_lst = []
    if cf_agent is None:
        power_orders_samples = {
            power: [phase.orders.get(power, [])] for power in POWERS
        }
    else:
        power_orders_samples = get_cf_agent_order_samples(
            rolled_back_game, phase.name, cf_agent, n_cf_agent_samples)
    for power_i, power in enumerate(POWERS):
        orders_samples = power_orders_samples[power]
        if len(orders_samples) == 0:
            valid_power_idxs[power_i] = False
            y_actions_lst.append(
                torch.empty(n_cf_agent_samples, MAX_SEQ_LEN,
                            dtype=torch.int32).fill_(EOS_IDX))
            continue
        encoded_power_actions_lst = []
        for orders in orders_samples:
            encoded_power_actions, valid = encode_power_actions(
                orders, data_fields["x_possible_actions"][0, power_i])
            encoded_power_actions_lst.append(encoded_power_actions)
            if 0 <= exclude_n_holds <= len(orders):
                if all(o.endswith(" H") for o in orders):
                    valid = 0
            valid_power_idxs[power_i] &= valid
        y_actions_lst.append(torch.stack(encoded_power_actions_lst,
                                         dim=0))  # [N, 17]

    y_actions = torch.stack(y_actions_lst, dim=0)  # [7, N, 17]

    # filter away powers that have no orders
    valid_power_idxs &= (y_actions != EOS_IDX).any(dim=2).all(dim=1)
    assert valid_power_idxs.ndimension() == 1

    # Maybe filter away powers that don't finish with enough SC.
    # If all players finish with fewer SC, include everybody.
    # cf. get_top_victors() in mila's state_space.py
    if only_with_min_final_score is not None:
        final_score = {
            k: len(v)
            for k, v in game.get_state()["centers"].items()
        }
        if max(final_score.values()) >= only_with_min_final_score:
            for i, power in enumerate(POWERS):
                if final_score.get(power, 0) < only_with_min_final_score:
                    valid_power_idxs[i] = 0

    data_fields["y_final_scores"] = y_final_scores.unsqueeze(0)
    data_fields["y_actions"] = y_actions.unsqueeze(0)
    data_fields["valid_power_idxs"] = valid_power_idxs.unsqueeze(0)
    data_fields["x_possible_actions"] = TensorList.from_padded(
        data_fields["x_possible_actions"].view(
            len(POWERS) * MAX_SEQ_LEN, MAX_VALID_LEN),
        padding_value=EOS_IDX,
    )

    return data_fields
def encode_game(
    game_id: Union[int, str],
    data_dir: str,
    only_with_min_final_score=7,
    *,
    cf_agent=None,
    n_cf_agent_samples=1,
    input_valid_power_idxs,
    value_decay_alpha,
    game_metadata,
    exclude_n_holds,
):
    """
    Arguments:
    - game: Game object
    - only_with_min_final_score: if specified, only encode for powers who
      finish the game with some # of supply centers (i.e. only learn from
      winners). MILA uses 7.
    - input_valid_power_idxs: bool tensor, true if power should a priori be included in
      the dataset based on e.g. player rating)
    Return: game_id, DataFields dict of tensors:
    L is game length, P is # of powers above min_final_score, N is n_cf_agent_samples
    - board_state: shape=(L, 81, 35)
    - prev_state: shape=(L, 81, 35)
    - prev_orders: shape=(L, 2, 100), dtype=long
    - power: shape=(L, 7, 7)
    - season: shape=(L, 3)
    - in_adj_phase: shape=(L, 1)
    - build_numbers: shape=(L, 7)
    - final_scores: shape=(L, 7)
    - possible_actions: TensorList shape=(L x 7, 17 x 469)
    - loc_idxs: shape=(L, 7, 81), int8
    - actions: shape=(L, 7, N, 17) int order idxs, N=n_cf_agent_samples
    - valid_power_idxs: shape=(L, 7) bool mask of valid powers at each phase
    """

    torch.set_num_threads(1)
    encoder = FeatureEncoder()

    if isinstance(game_id, str):  # Hacky fix to handle game_ids that are already paths.
        game_path = game_id
    else:
        game_path = os.path.join(data_dir, f"game_{game_id}.json")

    try:
        with open(game_path) as f:
            game = Game.from_json(f.read())
    except (FileNotFoundError, json.decoder.JSONDecodeError) as e:
        print(f"Error while loading game at {game_path}: {e}")
        return None, None

    num_phases = len(game.get_phase_history())
    logging.info(f"Encoding {game.game_id} with {num_phases} phases")

    phase_encodings = [
        encode_phase(
            encoder,
            game,
            game_id,
            phase_idx,
            only_with_min_final_score=only_with_min_final_score,
            cf_agent=cf_agent,
            n_cf_agent_samples=n_cf_agent_samples,
            value_decay_alpha=value_decay_alpha,
            input_valid_power_idxs=input_valid_power_idxs,
            exclude_n_holds=exclude_n_holds,
        ) for phase_idx in range(num_phases)
    ]

    stacked_encodings = DataFields.cat(phase_encodings)

    return game_id, stacked_encodings.to_storage_fmt_()
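
# A usage sketch (hypothetical id and directory), wiring up the keyword-only
# arguments that encode_game requires:
def _demo_encode_game():
    game_id, fields = encode_game(
        12345,  # hypothetical id, resolved to {data_dir}/game_12345.json
        data_dir="/path/to/games",  # hypothetical directory
        input_valid_power_idxs=[True] * len(POWERS),
        value_decay_alpha=1.0,
        game_metadata=None,
        exclude_n_holds=-1,  # negative disables the all-holds filter
    )
    return game_id, fields  # fields is None if the game failed to load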
def compute_xpower_supports(
    game: Union[Game, pydipcc.Game],
    max_year: Optional[str] = None,
    cf_agent: Optional[fairdiplomacy.agents.base_agent.BaseAgent] = None,
    name: str = "",
):
    """Computes average cross-power supports for an entire game.

        Arguments:
        - game: a Game object
        - max_year: If set, only compute cross-power supports up to this year.
        - cf_agent: If set, look at supports orders generated by `cf_agent`,  in the context of the game state
                    and game orders generated from the other powers.
        - name: a label for this game, returned in the output.

        Returns a dict of statistics for this game.
          name --> copied name from the argument
          o --> total number of orders
          s --> total number of support orders
          x --> number of support orders that are cross-power
          e --> number of cross-power support orders that "had an effect" (vs. a hold order for that unit)
        """

    num_orders, num_supports, num_xpower, num_eff = 0, 0, 0, 0
    for phase in game.get_phase_history():
        state = phase.state
        if state["name"][1:-1] == max_year:
            break
        if not state["name"].endswith("M"):
            # This is required as, e.g., in retreat phase a single location
            # could be occupied by several powers and everything is weird.
            continue
        loc_power = {
            unit.split()[1]: power
            for power, units in state["units"].items() for unit in units
        }
        # If power owns, e.g., BUL/SC, make it also own BUL. Support targets do
        # not use "/SC" so we need both.
        for loc, power in list(loc_power.items()):
            if "/" in loc:
                loc_land = loc.split("/")[0]
                assert loc_land not in loc_power, (loc_land, loc_power)
                loc_power[loc_land] = power

        if cf_agent is not None:
            assert isinstance(game, Game), "Not implemented"
            cf_orders = cf_agent.get_orders_many_powers(Game.clone_from(
                game, up_to_phase=state["name"]),
                                                        powers=POWERS)

        for power, power_orders in phase.orders.items():
            if cf_agent is not None:
                power_orders = cf_orders[power]

            for order in power_orders:
                num_orders += 1
                order_tokens = order.split()
                is_support = (len(order_tokens) >= 5 and order_tokens[2] == "S"
                              and order_tokens[3] in ("A", "F"))
                if not is_support:
                    continue
                num_supports += 1
                src = order_tokens[4]
                if src not in loc_power:
                    if not isinstance(game, Game):
                        # Only support dumping debug state for pydipcc games.
                        torch.save(
                            dict(
                                game=game.to_json(),
                                power=power,
                                power_orders=power_orders,
                                order=order,
                                phase=phase,
                                loc_power=loc_power,
                            ),
                            "xpower_debug.pt",
                        )
                    raise RuntimeError(f"{order}: {src} not in {loc_power}")
                if loc_power[src] == power:
                    continue
                num_xpower += 1

                cf_states = []
                for do_support in (False, True):
                    g_cf = game.rolled_back_to_phase_start(state["name"])
                    g_cf.set_orders(power, power_orders)

                    assert g_cf.get_state()["name"] == state["name"]
                    if not do_support:
                        hold_order = " ".join(order_tokens[:2] + ["H"])
                        g_cf.set_orders(power, [hold_order])

                    g_cf.process()
                    assert g_cf.get_state()["name"] != state["name"]
                    s = g_cf.get_state()
                    cf_states.append((s["name"], s["units"], s["retreats"]))

                if cf_states[0] != cf_states[1]:
                    num_eff += 1

    return {
        "name": name,
        "s": num_supports,
        "x": num_xpower,
        "e": num_eff,
        "o": num_orders
    }
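
# A small helper showing how the returned statistics compose into ratios;
# the keys match the dict above, the helper itself is illustrative only.
def _demo_support_ratios(stats):
    xpower_share = stats["x"] / max(stats["s"], 1)  # supports that are cross-power
    effective_share = stats["e"] / max(stats["x"], 1)  # cross-power supports with an effect
    return xpower_share, effective_share
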
def num_orderable_units(game_state, power):
    # NOTE: the signature and the adjustment-phase ("A") branch below were
    # reconstructed from a truncated excerpt; the names are best guesses.
    if game_state["name"][-1] == "A":
        return abs(game_state["builds"].get(power, {"count": 0})["count"])
    if game_state["name"][-1] == "R":
        return len(game_state["retreats"].get(power, []))
    else:
        return len(game_state["units"].get(power, []))
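
# A usage sketch with hand-built state dicts, matching the state layout
# used above ("name", "builds", "retreats", "units"):
def _demo_num_orderable_units():
    adj = {"name": "W1901A", "builds": {"FRANCE": {"count": -2}}}
    assert num_orderable_units(adj, "FRANCE") == 2  # two disbands owed
    move = {"name": "S1901M", "units": {"FRANCE": ["A PAR", "F BRE"]}}
    assert num_orderable_units(move, "FRANCE") == 2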


if __name__ == "__main__":
    from fairdiplomacy.pydipcc import Game

    logging.basicConfig(format="%(asctime)s [%(levelname)s]: %(message)s",
                        level=logging.INFO)

    np.random.seed(0)
    torch.manual_seed(0)

    agent = FP1PAgent(
        n_rollouts=10,
        max_rollout_length=5,
        model_path=model_path,  # model_path is assumed defined earlier; its value is truncated from this excerpt
        postman_sync_batches=False,
        rollout_temperature=0.5,
        n_rollout_procs=24 * 7,
        rollout_top_p=0.9,
        mix_square_ratio_scoring=0.1,
        n_plausible_orders=24,
        average_n_rollouts=3,
        bp_iters=3,
    )
    print(agent.get_orders(Game(), "GERMANY"))
Example No. 12
def run_situation_check(meta, agent, extra_plausible_orders_str: str = ""):
    extra_plausible_orders: Optional[Dict[str, List[Tuple[str, ...]]]]
    if extra_plausible_orders_str:
        assert isinstance(agent, SearchBotAgent)
        extra_plausible_orders = _parse_extra_plausible_orders(
            extra_plausible_orders_str)
    else:
        extra_plausible_orders = None
    results = {}
    for name, config in meta.items():
        logging.info("=" * 80)
        comment = config.get("comment", "")
        logging.info(f"{name}: {comment} (phase={config.get('phase')})")
        # If path is not absolute, treat as relative to code root.
        game_path = heyhi.PROJ_ROOT / config["game_path"]
        logging.info(f"path: {game_path}")
        with open(game_path) as f:
            game = Game.from_json(f.read())
        if "phase" in config:
            game.rolled_back_to_phase_start(config["phase"])

        if hasattr(agent, "get_all_power_prob_distributions"):
            if isinstance(agent, SearchBotAgent):
                prob_distributions = agent.get_all_power_prob_distributions(
                    game, extra_plausible_orders=extra_plausible_orders)
            else:
                prob_distributions = agent.get_all_power_prob_distributions(
                    game)  # FIXME: early exit
            logging.info("CFR strategy:")
        else:
            # this is a supervised agent, sample N times to get a distribution
            NUM_ROLLOUTS = 100
            prob_distributions = {p: defaultdict(float) for p in POWERS}
            for power in POWERS:
                for _ in range(NUM_ROLLOUTS):
                    orders = agent.get_orders(game, power)
                    prob_distributions[power][tuple(orders)] += 1 / NUM_ROLLOUTS

        if hasattr(agent, "get_values"):
            logging.info(
                "Values: %s",
                " ".join(f"{p}={v:.3f}"
                         for p, v in zip(POWERS, agent.get_values(game))),
            )
        for power in POWERS:
            pd = prob_distributions[power]
            pdl = sorted(list(pd.items()), key=lambda x: -x[1])
            logging.info(f"   {power}")

            for order, prob in pdl:
                if prob < 0.02:
                    break
                logging.info(f"       {prob:5.2f} {order}")

        for i, (test_desc,
                test_func_str) in enumerate(config.get("tests", {}).items()):
            test_func = eval(test_func_str)
            passed = test_func(prob_distributions)
            results[f"{name}.{i}"] = int(passed)
            res_string = "PASSED" if passed else "FAILED"
            logging.info(f"Result: {res_string:8s}  {name:20s} {test_desc}")
            logging.info(f"        {test_func_str}")
    logging.info("Passed: %d/%d", sum(results.values()), len(results))
    logging.info("JSON: %s", results)
    return results
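
# The meta argument is a mapping from situation name to config; a minimal
# hypothetical entry, matching the keys read above ("comment", "game_path",
# "phase", and "tests" mapping a description to an eval-able predicate):
EXAMPLE_META = {
    "tur_defense": {  # hypothetical situation name
        "comment": "Turkey should defend SMY",
        "game_path": "situations/game_TUR.json",  # relative to heyhi.PROJ_ROOT
        "phase": "S1905M",
        "tests": {
            "defends_smy": "lambda pd: max(pd['TURKEY'].values(), default=0) > 0.5",
        },
    },
}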
Example No. 13
        # ... (beginning of this method is truncated in this excerpt) ...

        logging.info("cum_strats= {}".format(self.cum_sigma))
        # return best order: sample from average policy
        ps = self.avg_strategy(power, power_plausible_orders[power])
        idx = np.random.choice(range(len(ps)), p=ps)
        return list(power_plausible_orders[power][idx])

    def strategy(self, power, actions) -> List[float]:
        try:
            return [self.sigma[(power, a)] for a in actions]
        except KeyError:
            return [1.0 / len(actions) for _ in actions]

    def avg_strategy(self, power, actions) -> List[float]:
        sigmas = [self.cum_sigma[(power, a)] for a in actions]
        sum_sigmas = sum(sigmas)
        if sum_sigmas == 0:
            return [1 / len(actions) for _ in actions]
        else:
            return [s / sum_sigmas for s in sigmas]
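
    # Worked example for avg_strategy (hypothetical numbers): with cum_sigma
    # values [2.0, 1.0, 1.0] over three actions, the average policy is
    # [0.5, 0.25, 0.25]; if every cum_sigma is 0 (nothing accumulated yet),
    # the method falls back to the uniform policy.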


if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)s [%(levelname)s]: %(message)s",
                        level=logging.DEBUG)

    print(CE1PAgent().get_orders(Game(), "ITALY"))