Ejemplo n.º 1
0
def match_random_policy(parties: Tuple[Party, Party]) -> int:
    """
    Play one battle in which both sides pick uniformly at random among
    their legal actions.

    The two parties are deep-copied, so the caller's objects are not used
    by the field directly. The field's ``put_record`` hook is replaced with
    a no-op and a ``GameRNGRandom`` instance is installed as its RNG.

    :param parties: the two parties to battle (index 0 and index 1)
    :return: ``field.winner`` when the battle reaches ``GAME_END``;
             ``-1`` when it is cut off because ``turn_number`` reached 64
    """
    field = Field([copy.deepcopy(parties[side]) for side in range(2)])
    field.rng = GameRNGRandom()
    field.rng.set_field(field)
    field.put_record = lambda record: None

    phase = FieldPhase.BEGIN
    while True:
        # One entry per side; None when that side has no legal action.
        chosen = []
        for side in range(2):
            candidates = field.get_legal_actions(side)
            chosen.append(random.choice(candidates) if len(candidates) != 0 else None)

        # Hand the choices to the slot that matches the phase about to run.
        if phase is FieldPhase.BEGIN:
            field.actions_begin = chosen
        elif phase is FieldPhase.FAINT_CHANGE:
            field.actions_faint_change = chosen

        phase = field.step()
        if phase is FieldPhase.GAME_END:
            return field.winner
        if field.turn_number >= 64:
            # Turn limit reached without a winner.
            return -1
Ejemplo n.º 2
0
class PokeEnv(gym.Env):
    """
    OpenAI Gym-compatible Pokemon battle environment.

    Player 0 is controlled through ``step``; player 1 acts uniformly at
    random among its legal actions. The class can also pit two agents
    against each other via ``match_agents``.
    """
    player_party: Party
    enemy_parties: List[Party]
    next_party_idxs: List[int]
    field: Optional[Field]  # None until reset() or match_agents() is called
    done: bool
    feature_types: List[str]
    # Turn number at which an unfinished battle is cut off as a draw.
    MAX_TURNS = 64

    def __init__(self, player_party: Party, enemy_parties: List[Party], feature_types: List[str]):
        """
        :param player_party: party used by player 0 (deep-copied on every reset)
        :param enemy_parties: pool of parties the opponent is drawn from
        :param feature_types: names of the feature groups included in the
            observation vector; recognised names are "enemy_type",
            "enemy_dexno", "hp_ratio", "nv_condition" and "rank"
        """
        super().__init__()
        self.player_party = player_party
        self.enemy_parties = enemy_parties
        self.next_party_idxs = []
        self.field = None
        self.done = True
        self.feature_types = feature_types
        self.action_space = gym.spaces.Discrete(4)  # four moves
        self.observation_space = gym.spaces.Box(0.0, 1.0, shape=self._get_observation_shape(), dtype=np.float32)
        self.reward_range = (-1.0, 1.0)

    def reset(self, enemy_party: Optional[Party] = None):
        """
        Choose the enemy party and put the field into its initial state.

        When ``enemy_party`` is omitted, the enemy is taken from a shuffled
        queue of indices into ``enemy_parties``; the queue is refilled and
        reshuffled whenever it runs empty.

        :param enemy_party: explicit enemy party to use, or None to draw one
        :return: observation vector for the initial state (player 0's view)
        """
        if enemy_party is None:
            if len(self.next_party_idxs) == 0:
                self.next_party_idxs = list(range(len(self.enemy_parties)))
                random.shuffle(self.next_party_idxs)
            enemy_party = self.enemy_parties[self.next_party_idxs.pop()]
        self.field = Field([copy.deepcopy(self.player_party), copy.deepcopy(enemy_party)])
        self.field.put_record = lambda x: None
        self.done = False
        return self._make_observation()

    def step(self, action: int):
        """
        Advance the battle by one turn.

        :param action: index of the move to use (0, 1, 2 or 3). If no legal
            MOVE action has that index, the first legal action is used.
        :return: tuple ``(observation, reward, done, info)``; reward is 1.0
            when player 0 wins, -1.0 when player 1 wins and 0.0 otherwise;
            ``info`` is always an empty dict
        :raises NotImplementedError: if the field enters a phase other than
            BEGIN or GAME_END (switching after a faint is not implemented)
        """
        assert not self.done, "call reset before step"
        assert 0 <= action <= 3
        player_action = self._select_action(0, action)
        enemy_action = random.choice(self.field.get_legal_actions(1))
        self.field.actions_begin = [player_action, enemy_action]
        phase = self.field.step()

        reward = 0.0
        if phase is FieldPhase.GAME_END:
            self.done = True
            reward = [1.0, -1.0][self.field.winner]
        else:
            if self.field.turn_number >= PokeEnv.MAX_TURNS:
                # Cut the episode off as a draw.
                self.done = True
            if phase is not FieldPhase.BEGIN:
                # Switching after a faint is not implemented.
                raise NotImplementedError

        return self._make_observation(), reward, self.done, {}

    def match_agents(self, parties: List[Party], action_samplers: List[Callable[[np.ndarray], int]],
                     put_record_func=None) -> int:
        """
        Play a full battle between two agents.

        NOTE: this replaces ``self.field`` and does not touch ``self.done``.

        :param parties: the parties to battle; each is deep-copied internally
        :param action_samplers: one callable per player that receives that
            player's observation vector and returns a move index
        :param put_record_func: callable installed as the field's
            ``put_record`` hook; a no-op is used when None
        :return: index of the winning party (0 or 1), or -1 for a draw
            (turn limit reached)
        :raises NotImplementedError: if the field enters a phase other than
            BEGIN or GAME_END (switching after a faint is not implemented)
        """
        self.field = Field([copy.deepcopy(p) for p in parties])
        self.field.put_record = put_record_func if put_record_func is not None else (lambda x: None)
        while True:
            # Choose actions (moves 0-3 only).
            actions_begin = []
            for player in range(2):
                obs = self._make_observation(player)
                move_idx = action_samplers[player](obs)
                actions_begin.append(self._select_action(player, move_idx))
            self.field.actions_begin = actions_begin
            phase = self.field.step()

            if phase is FieldPhase.GAME_END:
                return self.field.winner
            if self.field.turn_number >= PokeEnv.MAX_TURNS:
                # Cut the match off as a draw.
                return -1
            if phase is not FieldPhase.BEGIN:
                # Switching after a faint is not implemented.
                raise NotImplementedError

    def _select_action(self, player: int, move_idx: int):
        """
        Map a requested move index onto one of ``player``'s legal actions.

        Returns the legal MOVE action whose ``move_idx`` matches; otherwise
        the first legal action. Per the original author's note, during a
        multi-turn move the move is forced regardless of the request.

        :param player: player whose legal actions are queried (0 or 1)
        :param move_idx: requested move index
        :return: the chosen legal action object
        """
        legal_actions = self.field.get_legal_actions(player)
        chosen = legal_actions[0]
        for candidate in legal_actions:
            if candidate.action_type is FieldActionType.MOVE and candidate.move_idx == move_idx:
                chosen = candidate
                break
        return chosen

    def _get_observation_shape(self) -> Iterable[int]:
        """
        Compute the shape of the observation vector for the configured
        feature types.

        :return: 1-tuple holding the total number of dimensions
        """
        dims = 0
        if "enemy_type" in self.feature_types:
            dims += PokeType.DRAGON.value - PokeType.NORMAL.value + 1
        if "enemy_dexno" in self.feature_types:
            dims += Dexno.MEW.value - Dexno.BULBASAUR.value + 1
        if "hp_ratio" in self.feature_types:
            dims += 1 * 2
        if "nv_condition" in self.feature_types:
            # Same length expression as _obs_nv_condition, for both pokes.
            dims += (PokeNVCondition.FREEZE.value - PokeNVCondition.EMPTY.value + 1) * 2
        if "rank" in self.feature_types:
            dims += 6 * 2
        return (dims,)

    def _make_observation(self, player: int = 0) -> np.ndarray:
        """
        Build the vector describing the current state of the battle.

        The layout follows the order of the checks below; feature groups not
        listed in ``feature_types`` are omitted.

        :param player: the observing player; normally 0
        :return: float32 vector (empty if no recognised feature is enabled)
        """
        # Index 0: the observer's active poke, index 1: the opponent's.
        pokes = [self.field.parties[player].get(), self.field.parties[1 - player].get()]
        pokests = [poke.poke_static for poke in pokes]

        feats = []
        if "enemy_type" in self.feature_types:
            feats.append(self._obs_type(pokes[1]))
        if "enemy_dexno" in self.feature_types:
            feats.append(self._obs_dexno(pokests[1]))
        if "hp_ratio" in self.feature_types:
            feats.extend(self._obs_hp_ratio(poke) for poke in pokes)
        if "nv_condition" in self.feature_types:
            feats.extend(self._obs_nv_condition(poke) for poke in pokes)
        if "rank" in self.feature_types:
            feats.extend(self._obs_rank(poke) for poke in pokes)
        if not feats:
            # np.concatenate raises ValueError on an empty sequence; return a
            # vector that matches the (0,) shape from _get_observation_shape.
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(feats)

    def _obs_type(self, poke: Poke) -> np.ndarray:
        """
        Type feature: the dimension of every type the poke has is set to 1.

        :param poke: poke to encode
        :return: float32 vector indexed by type value relative to NORMAL
        """
        feat = np.zeros(PokeType.DRAGON.value - PokeType.NORMAL.value + 1, dtype=np.float32)
        for t in poke.poke_types:
            feat[t.value - PokeType.NORMAL.value] = 1
        return feat

    def _obs_dexno(self, pokest: PokeStatic) -> np.ndarray:
        """
        Pokedex number feature (one-hot).

        :param pokest: static data of the poke to encode
        :return: float32 one-hot vector indexed relative to BULBASAUR
        """
        feat = np.zeros(Dexno.MEW.value - Dexno.BULBASAUR.value + 1, dtype=np.float32)
        feat[pokest.dexno.value - Dexno.BULBASAUR.value] = 1
        return feat

    def _obs_hp_ratio(self, poke: Poke) -> np.ndarray:
        """
        HP ratio feature (1 at full HP, 0 when fainted).

        :param poke: poke to encode
        :return: float32 vector of length 1 holding ``hp / max_hp``
        """
        return np.array([poke.hp / poke.max_hp], dtype=np.float32)

    def _obs_nv_condition(self, poke: Poke) -> np.ndarray:
        """
        Non-volatile status condition feature (one-hot).

        :param poke: poke to encode
        :return: float32 one-hot vector indexed relative to EMPTY
        """
        feat = np.zeros(PokeNVCondition.FREEZE.value - PokeNVCondition.EMPTY.value + 1, dtype=np.float32)
        feat[poke.nv_condition.value - PokeNVCondition.EMPTY.value] = 1
        return feat

    def _obs_rank(self, poke: Poke) -> np.ndarray:
        """
        Rank (stat stage) modifiers, linearly mapping -6..6 onto 0..1.

        :param poke: poke to encode
        :return: float32 vector of length 6 in the order rank_a, rank_b,
            rank_c, rank_s, rank_evasion, rank_accuracy
        """
        ranks = [poke.rank_a, poke.rank_b, poke.rank_c, poke.rank_s, poke.rank_evasion, poke.rank_accuracy]
        return np.array([(rank.value + 6) / 12.0 for rank in ranks], dtype=np.float32)