def match_random_policy(parties: Tuple[Party, Party]) -> int:
    """
    Play one battle in which both sides pick their actions with ``random.choice``.

    :param parties: the two parties to battle; each is deep-copied before being placed on the field.
    :return: ``field.winner`` once the field reports ``GAME_END``, or -1 when the
        battle is cut off because ``turn_number`` reached 64.
    """
    field = Field([copy.deepcopy(parties[0]), copy.deepcopy(parties[1])])
    field.rng = GameRNGRandom()
    field.rng.set_field(field)
    field.put_record = lambda record: None  # battle records are discarded

    phase = FieldPhase.BEGIN
    while True:
        # One entry per side: a random legal action, or None when that side has none.
        chosen = []
        for side in (0, 1):
            candidates = field.get_legal_actions(side)
            chosen.append(random.choice(candidates) if len(candidates) != 0 else None)

        # Hand the actions to the slot that matches the phase about to be processed.
        if phase is FieldPhase.BEGIN:
            field.actions_begin = chosen
        elif phase is FieldPhase.FAINT_CHANGE:
            field.actions_faint_change = chosen

        phase = field.step()
        if phase is FieldPhase.GAME_END:
            return field.winner
        if field.turn_number >= 64:
            # Turn limit reached without a winner.
            return -1
class PokeEnv(gym.Env):
    """
    OpenAI Gym compatible Pokemon battle environment.

    ``reset`` / ``step`` run the player party against an enemy that acts with
    ``random.choice``. ``match_agents`` additionally lets two agents battle
    each other.
    """
    player_party: Party
    enemy_parties: List[Party]
    next_party_idxs: List[int]  # shuffled indices into enemy_parties still to be used by reset()
    field: Field  # None until reset() or match_agents() is called
    done: bool
    feature_types: List[str]
    MAX_TURNS = 64  # battles reaching this turn number are cut off

    def __init__(self, player_party: Party, enemy_parties: List[Party], feature_types: List[str]):
        """
        :param player_party: party controlled through ``step``.
        :param enemy_parties: pool of enemy parties ``reset`` draws from.
        :param feature_types: names of the observation features to emit; the names
            recognised are "enemy_type", "enemy_dexno", "hp_ratio", "nv_condition", "rank".
        """
        super().__init__()
        self.player_party = player_party
        self.enemy_parties = enemy_parties
        self.next_party_idxs = []
        self.field = None
        self.done = True  # forces a reset() before the first step()
        self.feature_types = feature_types
        self.action_space = gym.spaces.Discrete(4)  # four moves
        self.observation_space = gym.spaces.Box(0.0, 1.0, shape=self._get_observation_shape(), dtype=np.float32)
        self.reward_range = (-1.0, 1.0)

    def reset(self, enemy_party: Optional[Party] = None):
        """
        Choose the enemy party and put the field into its initial state.

        :param enemy_party: enemy to fight. When None, the next entry of a shuffled
            pass over ``enemy_parties`` is used; a new shuffled pass starts once the
            previous one is exhausted.
        :return: the initial observation vector.
        """
        if enemy_party is None:
            if len(self.next_party_idxs) == 0:
                self.next_party_idxs = list(range(len(self.enemy_parties)))
                random.shuffle(self.next_party_idxs)
            enemy_party = self.enemy_parties[self.next_party_idxs.pop()]
        self.field = Field([copy.deepcopy(self.player_party), copy.deepcopy(enemy_party)])
        self.field.put_record = lambda x: None  # battle records are discarded
        self.done = False
        return self._make_observation()

    def step(self, action: int):
        """
        Advance the battle by one turn.

        :param action: index of the move to use (0, 1, 2, 3). When no legal MOVE
            action has that index, the first legal action is used instead.
        :return: ``(observation, reward, done, info)``. ``reward`` is 1.0 / -1.0 when
            the game ended with winner 0 / 1, otherwise 0.0. ``done`` also becomes True
            when the turn limit ``MAX_TURNS`` is reached.
        :raises NotImplementedError: when the battle continues in a phase other than
            ``FieldPhase.BEGIN`` (switching after a faint is not implemented).
        """
        assert not self.done, "call reset before step"
        assert 0 <= action <= 3
        player_action = self._select_action(0, action)
        enemy_action = random.choice(self.field.get_legal_actions(1))
        self.field.actions_begin = [player_action, enemy_action]
        phase = self.field.step()

        reward = 0.0
        if phase is FieldPhase.GAME_END:
            self.done = True
            reward = [1.0, -1.0][self.field.winner]
        else:
            if self.field.turn_number >= PokeEnv.MAX_TURNS:
                # Cut off as a draw. Like match_agents, the turn limit takes
                # precedence over the unsupported-phase check, so the terminal
                # transition is reported instead of raising.
                self.done = True
            elif phase is not FieldPhase.BEGIN:
                # Switching after a faint is not implemented.
                raise NotImplementedError
        return self._make_observation(), reward, self.done, {}

    def match_agents(self, parties: List[Party], action_samplers: List[Callable[[np.ndarray], int]],
                     put_record_func=None) -> int:
        """
        Let two agents battle each other.

        Note that this replaces ``self.field`` with a fresh field for the match.

        :param parties: the parties that battle; each is deep-copied internally.
        :param action_samplers: one function per player that receives that player's
            observation vector and returns a move index.
        :param put_record_func: callback installed as ``field.put_record``; records
            are discarded when None.
        :return: index of the winning party (0 or 1), or -1 for a draw by turn limit.
        :raises NotImplementedError: when the battle continues in a phase other than
            ``FieldPhase.BEGIN`` (switching after a faint is not implemented).
        """
        self.field = Field([copy.deepcopy(p) for p in parties])
        if put_record_func is None:
            put_record_func = lambda x: None
        self.field.put_record = put_record_func
        while True:
            # Action selection (moves 0-3 only).
            actions_begin = []
            for player in range(2):
                obs = self._make_observation(player)
                move_idx = action_samplers[player](obs)
                actions_begin.append(self._select_action(player, move_idx))
            self.field.actions_begin = actions_begin
            phase = self.field.step()
            if phase is FieldPhase.GAME_END:
                return self.field.winner
            if self.field.turn_number >= PokeEnv.MAX_TURNS:
                # Cut off as a draw.
                return -1
            if phase is not FieldPhase.BEGIN:
                # Switching after a faint is not implemented.
                raise NotImplementedError

    def _select_action(self, player: int, move_idx: int):
        """
        Resolve a move index into one of ``player``'s legal actions.

        Returns the first legal action of type MOVE whose ``move_idx`` matches.
        If there is none, the first legal action is returned (original author's
        note: during a multi-turn move the move is forced regardless of the
        selection).
        """
        legal_actions = self.field.get_legal_actions(player)
        for candidate in legal_actions:
            if candidate.action_type is FieldActionType.MOVE and candidate.move_idx == move_idx:
                return candidate
        return legal_actions[0]

    def _get_observation_shape(self) -> Iterable[int]:
        """
        Compute the shape of the observation vector for the enabled features.

        :return: 1-tuple holding the total number of dimensions.
        """
        dims = 0
        if "enemy_type" in self.feature_types:
            dims += PokeType.DRAGON.value - PokeType.NORMAL.value + 1
        if "enemy_dexno" in self.feature_types:
            dims += Dexno.MEW.value - Dexno.BULBASAUR.value + 1
        if "hp_ratio" in self.feature_types:
            dims += 1 * 2
        if "nv_condition" in self.feature_types:
            # Same width expression as _obs_nv_condition, so the declared shape
            # cannot drift from the vector actually produced.
            dims += (PokeNVCondition.FREEZE.value - PokeNVCondition.EMPTY.value + 1) * 2
        if "rank" in self.feature_types:
            dims += 6 * 2
        return dims,

    def _make_observation(self, player: int = 0) -> np.ndarray:
        """
        Build the vector describing the current position. Values lie in 0..1.

        :param player: the observing player, normally 0. "Own" features refer to
            this player's active Pokemon and "enemy" features to the other one.
        :return: concatenation of the enabled feature vectors (float32); an empty
            float32 vector when no known feature is enabled.
        """
        pokes = [self.field.parties[player].get(), self.field.parties[1 - player].get()]  # self, opponent
        pokests = [poke.poke_static for poke in pokes]
        feats = []
        if "enemy_type" in self.feature_types:
            feats.append(self._obs_type(pokes[1]))
        if "enemy_dexno" in self.feature_types:
            feats.append(self._obs_dexno(pokests[1]))
        if "hp_ratio" in self.feature_types:
            feats.append(self._obs_hp_ratio(pokes[0]))
            feats.append(self._obs_hp_ratio(pokes[1]))
        if "nv_condition" in self.feature_types:
            feats.append(self._obs_nv_condition(pokes[0]))
            feats.append(self._obs_nv_condition(pokes[1]))
        if "rank" in self.feature_types:
            feats.append(self._obs_rank(pokes[0]))
            feats.append(self._obs_rank(pokes[1]))
        if not feats:
            # np.concatenate raises ValueError on an empty list; return a vector
            # matching the (0,) shape reported by _get_observation_shape instead.
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(feats)

    def _obs_type(self, poke: Poke) -> np.ndarray:
        """
        Type feature: the dimension of every type the Pokemon has is set to 1.

        :param poke: Pokemon to encode.
        :return: float32 vector with one slot per type from NORMAL to DRAGON.
        """
        feat = np.zeros(PokeType.DRAGON.value - PokeType.NORMAL.value + 1, dtype=np.float32)
        for t in poke.poke_types:
            feat[t.value - PokeType.NORMAL.value] = 1
        return feat

    def _obs_dexno(self, pokest: PokeStatic) -> np.ndarray:
        """
        Pokedex number feature (one-hot).

        :param pokest: static data of the Pokemon to encode.
        :return: float32 vector with one slot per dex number from BULBASAUR to MEW.
        """
        feat = np.zeros(Dexno.MEW.value - Dexno.BULBASAUR.value + 1, dtype=np.float32)
        feat[pokest.dexno.value - Dexno.BULBASAUR.value] = 1
        return feat

    def _obs_hp_ratio(self, poke: Poke) -> np.ndarray:
        """
        HP ratio feature (full HP is 1, fainted is 0).

        :param poke: Pokemon to encode.
        :return: float32 vector of length 1 holding ``hp / max_hp``.
        """
        feat = np.zeros(1, dtype=np.float32)
        feat[0] = poke.hp / poke.max_hp
        return feat

    def _obs_nv_condition(self, poke: Poke) -> np.ndarray:
        """
        Status condition feature (one-hot).

        :param poke: Pokemon to encode.
        :return: float32 vector with one slot per condition from EMPTY to FREEZE.
        """
        feat = np.zeros(PokeNVCondition.FREEZE.value - PokeNVCondition.EMPTY.value + 1, dtype=np.float32)
        feat[poke.nv_condition.value - PokeNVCondition.EMPTY.value] = 1
        return feat

    def _obs_rank(self, poke: Poke) -> np.ndarray:
        """
        Stat rank feature (-6..6 mapped linearly onto 0..1).

        :param poke: Pokemon to encode.
        :return: float32 vector of length 6 in the order
            rank_a, rank_b, rank_c, rank_s, rank_evasion, rank_accuracy.
        """
        feat = np.zeros(6, dtype=np.float32)
        for i, rank in enumerate(
                [poke.rank_a, poke.rank_b, poke.rank_c, poke.rank_s, poke.rank_evasion, poke.rank_accuracy]):
            feat[i] = (rank.value + 6) / 12.0
        return feat