def test_hash(self):
    """Check that hashing an Action yields exactly its id, over 100 random ids in [0, 10000]."""
    from muzero.environment.action import Action
    import random
    remaining_trials = 100
    while remaining_trials > 0:
        remaining_trials -= 1
        expected_id = random.randint(0, 10000)
        candidate = Action(expected_id)
        self.assertEqual(
            candidate.__hash__(),
            expected_id,
            'The hash of an action should be equal to the id',
        )
def __init__(self, environment: core.Env, number_players: int, discount: float, max_moves: int):
    """
    Create a new game wrapping a gym environment and reset that environment.

    :param environment: The gym environment to interact with
    :param number_players: The number of players alternating in this environment (must be 1 or 2)
    :param discount: The discount to apply to future rewards when calculating the value target
    :param max_moves: Maximum number of moves, stored as ``self.max_moves``
        (how it is enforced is not visible here — TODO confirm against callers)
    :raises Exception: If ``number_players`` is not 1 or 2
    """
    # Validate first so an invalid configuration fails before the
    # environment is inspected or reset.
    if number_players not in [1, 2]:
        raise Exception('Game init', 'Valid number_player-values are: 1 or 2')

    self.environment = environment
    self.action_space_size = environment.action_space.n
    self.players = [Player(i) for i in range(number_players)]
    self.step = 0
    self.discount = discount
    # Histories start with a placeholder action 0 and reward 0
    # (presumably to align with the initial observation — verify).
    self.action_history = [Action(0)]
    self.reward_history = [0]
    self.root_values = []
    self.probability_distributions = []
    # The first observation is whatever the environment's reset() returns.
    self.observation_history = [self.environment.reset()]
    self.child_visits = []
    self.max_moves = max_moves
    self.done = False
def setUp(self):
    """Prepare a fresh DynamicsModel, an all-ones hidden-state batch of shape [4, 3, 3, 1] and a default Action(0)."""
    from muzero.models.dynamics_model import DynamicsModel
    from muzero.environment.action import Action
    import tensorflow as tf

    hidden_state_batch_shape = [4, 3, 3, 1]
    model = DynamicsModel()
    self.dynamics_model = model
    hidden_states = tf.ones(hidden_state_batch_shape)
    self.batch_of_hidden_states = hidden_states
    self.default_action = Action(0)
def legal_actions(self) -> List[Action]:
    """
    Return every action of the environment's action space.

    :return: A list of all legal actions in this environment, one ``Action(i)``
        for each ``i`` in ``range(self.action_space_size)``; every index in the
        action space is treated as legal.
    """
    # Build the list directly instead of appending in a manual loop.
    return [Action(i) for i in range(self.action_space_size)]
def get_action(self, evaluation=False):
    """
    Run a rollout and then choose the next action.

    :param evaluation: If True, return ``self.get_action_with_highest_visit_count()``.
        Otherwise, sample an action index according to the probabilities
        returned by ``self.get_probability_distribution()``.
    :return: The selected Action
    """
    self.rollout()
    if evaluation:
        return self.get_action_with_highest_visit_count()
    probability_distribution = self.get_probability_distribution()
    # With size=1, np.random.choice returns a length-1 ndarray rather than a
    # scalar, which would make the Action id an array. Draw one scalar sample
    # and convert it to a plain int, matching the int ids used for Action elsewhere.
    sampled_index = np.random.choice(a=len(probability_distribution), p=probability_distribution)
    return Action(int(sampled_index))
def test_eq(self):
    """
    Check that Actions compare equal exactly when their ids are equal.

    Pre-built actions are compared against freshly constructed ones, so the
    equality assertions exercise ``Action.__eq__``. Comparing an object with
    itself would pass even with identity-based equality.
    """
    from muzero.environment.action import Action
    import random
    action_list = [Action(action_id) for action_id in range(100)]

    # Deterministically cover the "same id" case for every id. The random
    # loop below only hits it by chance (~1% per iteration).
    for action_id in range(100):
        self.assertEqual(
            action_list[action_id],
            Action(action_id),
            "Two actions with the same ids have to be equal in comparison"
        )

    for _ in range(100):
        action_id_one = random.randint(0, 99)
        action_id_two = random.randint(0, 99)
        other_action = Action(action_id_two)
        if action_id_one != action_id_two:
            self.assertNotEqual(
                action_list[action_id_one],
                other_action,
                "Two actions with different ids must not be equal in comparison"
            )
        else:
            self.assertEqual(
                action_list[action_id_one],
                other_action,
                "Two actions with the same ids have to be equal in comparison"
            )
def setUp(self):
    """
    Build a CartPole-backed Game plus a root Node with two visited leaf children.

    The root ends with visit_count increased by 3; each leaf has visit_count
    increased by 1. All nodes share the same default action and player.
    """
    from muzero.environment.games import Game
    from muzero.environment.action import Action
    from muzero.environment.player import Player
    from muzero.mcts.node import Node
    import gym

    self.env = gym.make('CartPole-v0')
    self.game = Game(environment=self.env, discount=0.995, number_players=1, max_moves=50)
    self.default_action = Action(0)
    self.default_player = Player(0)

    def build_node():
        # Every node in this fixture uses identical constructor arguments.
        return Node(value=1, action=self.default_action, hidden_state=0,
                    policy_logits=[0], to_play=self.default_player, reward=0)

    self.default_root_node = build_node()

    # Two child leaves, one per possible action.
    # NOTE(review): both leaves use Action(0); confirm whether the second
    # was meant to use Action(1).
    leaves = [build_node(), build_node()]
    for leaf in leaves:
        leaf.visit_count += 1
    self.default_root_node.child_nodes.extend(leaves)
    self.default_root_node.visit_count += 3