def reset(self, value: float, reward: int, policy_logits, hidden_state):
    self.history_trees = []
    self.root = Node(reward=reward,
                     value=value,
                     policy_logits=policy_logits,
                     hidden_state=hidden_state,
                     action=0,
                     to_play=Player(0))
    self.root.expand(to_play=Player(0), legal_actions=self.legal_actions, min_max_stats=self.min_max_stats)
def __init__(self, environment: core.Env, number_players: int, discount: float, max_moves: int):
    """
    :param environment: The gym environment to interact with
    :param number_players: The number of players alternating in this environment
    :param discount: The discount to apply to future rewards when calculating the value target
    :param max_moves: The maximum number of moves before the game is considered finished
    """
    if number_players not in [1, 2]:
        raise Exception('Game init', 'Valid number_player-values are: 1 or 2')

    self.environment = environment
    self.action_space_size = environment.action_space.n
    self.players = [Player(i) for i in range(number_players)]
    self.step = 0
    self.discount = discount
    self.action_history = [Action(0)]
    self.reward_history = [0]
    self.root_values = []
    self.probability_distributions = []
    self.observation_history = []
    self.observation_history.append(self.environment.reset())
    self.child_visits = []
    self.max_moves = max_moves
    self.done = False
def rollout(self):
    player_counter = 0
    for simulation in range(self.max_sims):
        to_play = Player(player_counter % self.num_players)
        player_counter += 1
        next_to_play = Player(player_counter % self.num_players)

        leaf = self.root.select(to_play=to_play, exploration_weight=self.exploration_weight)
        # End rollout if no valid leaf can be found
        if leaf is None:
            break
        # Explore the node (predict value and policy distribution and add child nodes as unexplored leaves)
        leaf.expand(to_play=next_to_play, legal_actions=self.legal_actions, min_max_stats=self.min_max_stats)
        # Update all the nodes above accordingly
        leaf.backup(to_play=to_play, min_max_stats=self.min_max_stats, discount=self.discount)
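# Hedged usage sketch, not part of the repository: shows how reset() and rollout() are
# meant to be combined for one search step. `tree`, `game` and `network` are assumed to
# expose the interfaces used elsewhere in this code (Tree, Game and Network).
def run_single_search(tree, game, network):
    value, reward, policy_logits, hidden_state = network.initial_inference(game.make_image(-1))
    tree.reset(value=value, reward=reward, policy_logits=policy_logits, hidden_state=hidden_state)
    tree.rollout()  # up to max_sims select/expand/backup iterations from the fresh root
    return tree.get_action(evaluation=False)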
def apply(self, action: Action, to_play: Player):
    """
    Applies an action to the environment and stores the action as well as the observed next state and reward.

    :param to_play: The player who takes the action
    :param action: The action to execute in the environment
    """
    if self.terminal():
        raise Exception('MuZero Games', 'You cannot continue to play a terminated game')
    if to_play != self.to_play():
        raise Exception('MuZero Games', 'The player on turn has to rotate for board games')

    observation, reward, self.done, _ = self.environment.step(action.action_id)

    self.observation_history.append(observation)
    self.reward_history.append(reward if to_play == Player(0) else -reward)
    self.action_history.append(action)
    self.step += 1
    if self.step >= self.max_moves:
        self.done = True
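# Minimal usage sketch, not part of the repository: applies one placeholder move and
# inspects the bookkeeping done by apply(). Assumes `game` is a configured Game whose
# action id 0 is legal. Rewards are stored from Player(0)'s perspective, so a move by
# Player(1) is recorded with a negated reward, and to_play() rotates afterwards.
from muzero.environment.action import Action

def play_one_move(game):
    current_player = game.to_play()
    game.apply(Action(0), current_player)
    return game.reward_history[-1], game.to_play()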
def test_eq(self):
    from muzero.environment.player import Player
    import random

    player_list = [Player(player_id) for player_id in range(50)]
    for _ in range(50):
        player_id_one = random.randint(0, 49)
        player_id_two = random.randint(0, 49)
        if player_id_one != player_id_two:
            self.assertNotEqual(
                player_list[player_id_one],
                player_list[player_id_two],
                'Two players with different ids must not be equal in comparison'
            )
        else:
            self.assertEqual(
                player_list[player_id_one],
                player_list[player_id_two],
                'Two players with the same ids have to be equal in comparison'
            )
def test_apply_changes_player_on_turn(self):
    from muzero.environment.games import Game
    from muzero.environment.player import Player

    game_one_player = self.game
    game_two_players = Game(environment=self.env, discount=0.995, number_players=2, max_moves=50)
    to_play_one = game_one_player.to_play()
    to_play_two = game_two_players.to_play()

    game_one_player.apply(self.default_action, self.default_player)
    game_two_players.apply(self.default_action, self.default_player)

    self.assertEqual(
        to_play_one,
        game_one_player.to_play(),
        'The player on turn must not change in single agent domains')
    self.assertNotEqual(
        to_play_two,
        game_two_players.to_play(),
        'The player on turn has to rotate in two agent domains')

    game_two_players.apply(self.default_action, Player(1))
    self.assertEqual(
        to_play_two,
        game_two_players.to_play(),
        'The player on turn has to rotate in two agent domains')
def setUp(self):
    from muzero.environment.games import Game
    from muzero.environment.action import Action
    from muzero.environment.player import Player
    from muzero.mcts.node import Node
    import gym

    self.env = gym.make('CartPole-v0')
    self.game = Game(environment=self.env, discount=0.995, number_players=1, max_moves=50)
    self.default_action = Action(0)
    self.default_player = Player(0)
    self.default_root_node = Node(value=1,
                                  action=self.default_action,
                                  hidden_state=0,
                                  policy_logits=[0],
                                  to_play=self.default_player,
                                  reward=0)

    # Add two child nodes, one for each possible action
    leaf_one = Node(value=1,
                    action=self.default_action,
                    hidden_state=0,
                    policy_logits=[0],
                    to_play=self.default_player,
                    reward=0)
    leaf_two = Node(value=1,
                    action=self.default_action,
                    hidden_state=0,
                    policy_logits=[0],
                    to_play=self.default_player,
                    reward=0)
    leaf_one.visit_count += 1
    leaf_two.visit_count += 1
    self.default_root_node.child_nodes.append(leaf_one)
    self.default_root_node.child_nodes.append(leaf_two)
    self.default_root_node.visit_count += 3
def play_game(self, network: Network) -> Game:
    """
    Each of the games produced by a process is played using the latest network,
    so that the quality of the training data increases quickly at the beginning of training.
    """
    player_helper = 0
    game = self.config.new_game()
    tree = Tree(action_list=game.legal_actions(),
                config=self.config,
                network=network,
                player_list=game.players,
                discount=self.config.discount)

    while not game.terminal() and len(game.root_values) < self.config.max_moves:
        self.frame_count += 1
        image = game.make_image(-1)
        value, reward, policy_logits, hidden_state = network.initial_inference(image)
        tree.reset(value=value, reward=reward, policy_logits=policy_logits, hidden_state=hidden_state)
        action = tree.get_action(evaluation=False)
        game.apply(action, Player(player_helper % len(game.players)))
        game.store_search_statistics(tree.root)
        # Advance the player counter so the next move is applied by the player now on turn
        player_helper += 1
    return game
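# Hedged sketch of the surrounding self-play loop, not part of the repository:
# `self_play` is assumed to be the object that owns play_game, and `store_game` is a
# hypothetical callback standing in for whatever replay storage the project uses.
def generate_games(self_play, network, num_games, store_game):
    for _ in range(num_games):
        game = self_play.play_game(network)  # one full episode with the latest network
        store_game(game)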
def to_play(self) -> Player:
    """
    :return: The player who is currently on turn
    """
    return Player(self.step % len(self.players))
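# Illustration only, not part of the repository: the turn order follows directly from
# step % len(players). With two players the result alternates 0, 1, 0, 1, ...; with a
# single player it is always 0, so Player(0) stays on turn.
from muzero.environment.player import Player

two_player_order = [Player(step % 2) for step in range(4)]  # Player(0), Player(1), Player(0), Player(1)
one_player_order = [Player(step % 1) for step in range(4)]  # always Player(0)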