Beispiel #1
0
            mode="max",
        )
    else:
        best_checkpoint = "/home/jippo/ray_results/YanivTrainer_2021-03-28_13-22-27/YanivTrainer_yaniv_4258f_00000_0_2021-03-28_13-22-27/checkpoint_60/checkpoint_60/checkpoint-60"

    agent = ppo.PPOTrainer(config=config, env="yaniv")
    agent.restore(best_checkpoint)

    rule_agent = YanivNoviceRuleAgent(single_step=True)

    env = YanivEnv(env_config)

    for _ in range(args.eval_num):
        episode_reward = 0
        done = {"__all__": False}
        obs = env.reset()

        agent_id = "player_0"
        rules_id = "player_1"

        steps = 0
        while not done["__all__"]:
            if env.current_player == 0:
                action = agent.compute_action(obs[agent_id])
                obs, reward, done, info = env.step({agent_id: action})
            else:
                state = env.game.get_state(1)
                extracted_state = {}
                extracted_state["raw_obs"] = state
                extracted_state["raw_legal_actions"] = [
                    a for a in state["legal_actions"]
Beispiel #2
0
class YanivTournament:
    """Run Yaniv evaluation episodes and accumulate per-player statistics.

    The first ``len(trainers)`` seats of the environment are played by the
    given trainers (queried through ``compute_action`` with the policy id
    ``"policy_1"``); every remaining seat is played by one shared rule-based
    agent selected with ``opponent``.
    """

    def __init__(self, env_config, trainers=None, opponent="novice"):
        """Build the environment, the rule agent and the seat assignment.

        Args:
            env_config: Mapping passed to ``YanivEnv``; its optional
                ``"single_step"`` entry (default ``True``) is forwarded to
                the rule agent.
            trainers: Sequence of trained agents, one per leading seat.
                ``None`` (the default) means no trained agents.
            opponent: ``"novice"`` or ``"intermediate"``; selects the rule
                agent class used for the remaining seats.

        Raises:
            ValueError: If ``opponent`` is not one of the supported names.
        """
        self.env_config = env_config
        # ``None`` sentinel instead of a ``[]`` default: a literal list
        # default is created once and shared by every instance.
        self.trainers = [] if trainers is None else trainers

        if opponent == "novice":
            self.rule_agent = YanivNoviceRuleAgent(
                single_step=env_config.get("single_step", True))
        elif opponent == "intermediate":
            self.rule_agent = YanivIntermediateRuleAgent(
                single_step=env_config.get("single_step", True))
        else:
            raise ValueError("opponent wrong {}".format(opponent))

        self.env = YanivEnv(env_config)

        # Trainers occupy the leading seats; the rule agent fills the rest.
        self.players = [
            self.trainers[seat] if seat < len(self.trainers)
            else self.rule_agent
            for seat in range(self.env.num_players)
        ]

        self.reset_stats()

    def _record_action(self, player_id, raw_action):
        """Count one decoded action in ``player_id``'s frequency tables.

        Must be called before the action is applied to the environment,
        because it reads the round's current ``discarding`` flag to decide
        which table the action belongs to.
        """
        if self.env.game.round.discarding:
            # Calling Yaniv is not a discard, so it is not counted.
            if raw_action != utils.YANIV_ACTION:
                # Key is half the action length, as a string.
                # NOTE(review): presumably two characters per card - confirm.
                num_cards = str(len(raw_action) // 2)
                self.player_stats[player_id]["discard_freqs"][num_cards] += 1
        else:
            self.player_stats[player_id]["pickup_freqs"][raw_action] += 1

    def run_episode(self, render=False):
        """Play one episode to completion and update the running statistics.

        Args:
            render: When true, render the game after the reset and after
                every step.
        """
        obs = self.env.reset()
        if render:
            self.env.game.render()

        done = {"__all__": False}

        # One model state per trainer, indexed by seat.
        states = [
            t.get_policy("policy_1").model.get_initial_state()
            for t in self.trainers
        ]

        steps = 0
        while not done["__all__"]:
            seat = self.env.current_player
            player = self.players[seat]
            player_id = self.env.current_player_string

            # Seats below len(self.trainers) are exactly the trainer seats
            # (see __init__), which also keeps the ``states`` index in range.
            if seat < len(self.trainers):
                action, state, _ = player.compute_action(
                    obs[player_id],
                    policy_id="policy_1",
                    state=states[seat],
                    full_fetch=True,
                )
                states[seat] = state

                self._record_action(player_id,
                                    self.env._decode_action(action))

                obs, reward, done, info = self.env.step({player_id: action})
            else:
                game_state = self.env.game.get_state(seat)
                extracted_state = {
                    "raw_obs": game_state,
                    "raw_legal_actions": list(game_state["legal_actions"]),
                }
                action = self.rule_agent.step(extracted_state)

                # The rule agent's action is recorded and stepped as-is.
                self._record_action(player_id, action)

                obs, reward, done, info = self.env.step({player_id: action},
                                                        raw_action=True)

            steps += 1

            if render:
                self.env.game.render()

        # The "avg_*" entries hold running sums until get_average_stats().
        self.game_stats["avg_roundlen"] += steps

        winner = self.env.game.round.winner
        if winner == -1:
            # A winner of -1 is counted as a draw.
            self.game_stats["avg_draws"] += 1
        else:
            winner_id = self.env._get_player_string(winner)
            self.player_stats[winner_id]["avg_wins"] += 1
            self.player_stats[winner_id]["winning_hands"].append(
                utils.get_hand_score(self.env.game.players[winner].hand))

        assaf = self.env.game.round.assaf
        if assaf is not None:
            assaf_id = self.env._get_player_string(assaf)
            self.player_stats[assaf_id]["avg_assafs"] += 1

        round_scores = self.env.game.round.scores
        if round_scores is not None:
            for seat in range(self.env.num_players):
                # Only strictly positive scores are kept; they feed
                # "avg_losing_score" in get_average_stats().
                if round_scores[seat] > 0:
                    self.player_stats[self.env._get_player_string(
                        seat)]["scores"].append(round_scores[seat])

        self.games_played += 1

    def reset_game(self):
        """Reset ``self.scores`` to one empty list per player."""
        self.scores = [[] for _ in range(self.env.num_players)]

    def run_game(self):
        """Reset the per-game score lists, then play a single episode."""
        self.reset_game()
        self.run_episode()

    def run(self, eval_num):
        """Reset all statistics, play ``eval_num`` episodes, return averages.

        Args:
            eval_num: Number of episodes to play.

        Returns:
            The dictionary produced by ``get_average_stats``.
        """
        self.reset_stats()

        for _ in range(eval_num):
            self.run_episode()

        return self.get_average_stats()

    def reset_stats(self):
        """Zero the episode counter and all game/player accumulators."""
        self.games_played = 0

        self.game_stats = {
            "avg_roundlen": 0,
            "avg_draws": 0,
        }

        self.player_stats = {
            player_id: {
                "avg_wins": 0,
                "avg_assafs": 0,
                "scores": [],
                "winning_hands": [],
                "discard_freqs": {
                    "1": 0,
                    "2": 0,
                    "3": 0,
                    "4": 0,
                    "5": 0,
                },
                "pickup_freqs": {a: 0
                                 for a in utils.pickup_actions},
            }
            for player_id in self.env._get_players()
        }

    def get_average_stats(self):
        """Return a snapshot of the statistics with sums turned into means.

        Every key starting with ``"avg"`` is divided by the number of games
        played. The raw ``"scores"`` and ``"winning_hands"`` lists are
        replaced by their means under ``"avg_losing_score"`` and
        ``"avg_winning_hand"`` (``0`` when the list is empty). The instance's
        own accumulators are left untouched.

        Returns:
            ``{"game": {...}, "player": {player_id: {...}}}``. Before any
            episode has been played, all averages are ``0``.
        """
        stats = {
            "game": deepcopy(self.game_stats),
            "player": deepcopy(self.player_stats),
        }

        # Guard against ZeroDivisionError when no episode has run yet; all
        # accumulators are zero in that case, so the averages are zero too.
        games = self.games_played or 1

        for key in stats["game"]:
            if key.startswith("avg"):
                stats["game"][key] /= games

        for player_stats in stats["player"].values():
            # Normalise first; the derived averages added below are already
            # means and must not be divided again.
            for key in player_stats:
                if key.startswith("avg"):
                    player_stats[key] /= games

            scores = player_stats.pop("scores")
            player_stats["avg_losing_score"] = np.mean(scores) if scores else 0

            winning_hands = player_stats.pop("winning_hands")
            player_stats["avg_winning_hand"] = (np.mean(winning_hands)
                                                if winning_hands else 0)

        return stats

    def print_stats(self):
        """Print the averaged statistics as block-style YAML."""
        avg_stats = self.get_average_stats()
        # JSON round trip before dumping; presumably this normalises NumPy
        # scalars to plain Python numbers for yaml.safe_dump - TODO confirm.
        cleaned = json.dumps(avg_stats)

        print(yaml.safe_dump(json.loads(cleaned), default_flow_style=False))