def compare_virtual_with_real_trajectories(self, first_obs, game, horizon, plot=True):
    """
    First, MuZero plays a game using its model instead of the environment.
    Then, MuZero replays the actions of that virtual trajectory in the real environment,
    until it reaches an action that is illegal in the real environment. An MCTS is still
    run at each real step, but its result is not used to pick the action.
    All the information gathered during the two trajectories is recorded and displayed.
    """
    virtual_trajectory_info = self.get_virtual_trajectory_from_obs(first_obs, horizon, False)
    real_trajectory_info = Trajectoryinfo("Real trajectory", self.config)
    trajectory_divergence_index = None
    real_trajectory_end_reason = "Reached horizon"

    # Illegal moves are masked at the root
    root, mcts_info = MCTS(self.config).run(
        self.model,
        first_obs,
        game.legal_actions(),
        game.to_play(),
        True,
    )
    self.plot_mcts(root, plot)
    real_trajectory_info.store_info(root, mcts_info, None, numpy.NaN)

    for i, action in enumerate(virtual_trajectory_info.action_history):
        # Follow the virtual trajectory until it reaches an illegal move in the real env
        if action not in game.legal_actions():
            break  # Comment out this break to keep playing after the trajectory diverges
            action = SelfPlay.select_action(root, 0)
            if trajectory_divergence_index is None:
                trajectory_divergence_index = i
                real_trajectory_end_reason = (
                    f"Virtual trajectory reached an illegal move at timestep {trajectory_divergence_index}."
                )

        observation, reward, done = game.step(action)
        root, mcts_info = MCTS(self.config).run(
            self.model,
            observation,
            game.legal_actions(),
            game.to_play(),
            True,
        )
        real_trajectory_info.store_info(root, mcts_info, action, reward)
        if done:
            real_trajectory_end_reason = "Real trajectory reached Done"
            break

    if plot:
        virtual_trajectory_info.plot_trajectory()
        real_trajectory_info.plot_trajectory()
        print(real_trajectory_end_reason)

    return (
        virtual_trajectory_info,
        real_trajectory_info,
        trajectory_divergence_index,
    )
def get_virtual_trajectory_from_obs(self, observation, horizon, plot=True, to_play=0):
    """
    MuZero plays a game but uses its model instead of using the environment.
    We still do an MCTS at each step.
    """
    trajectory_info = Trajectoryinfo("Virtual trajectory", self.config)
    root, mcts_info = MCTS(self.config).run(
        self.model, observation, self.config.action_space, to_play, True
    )
    trajectory_info.store_info(root, mcts_info, None, numpy.NaN)

    virtual_to_play = to_play
    for i in range(horizon):
        action = SelfPlay.select_action(root, 0)

        # Players play turn by turn
        if virtual_to_play + 1 < len(self.config.players):
            virtual_to_play = self.config.players[virtual_to_play + 1]
        else:
            virtual_to_play = self.config.players[0]

        # Generate new root
        # TODO: Test keeping the old root
        value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
            root.hidden_state,
            torch.tensor([[action]]).to(root.hidden_state.device),
        )
        value = models.support_to_scalar(value, self.config.support_size).item()
        reward = models.support_to_scalar(reward, self.config.support_size).item()
        root = Node(0)
        root.expand(
            self.config.action_space,
            virtual_to_play,
            reward,
            policy_logits,
            hidden_state,
        )

        root, mcts_info = MCTS(self.config).run(
            self.model, None, self.config.action_space, virtual_to_play, True, root
        )
        trajectory_info.store_info(
            root, mcts_info, action, reward, new_prior_root_value=value
        )

    if plot:
        self.plot_trajectory(trajectory_info)

    return trajectory_info
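# Hedged usage sketch for the two diagnosis methods above. It assumes they live on a
# helper object exposing `model` and `config` (called `diagnose_model` here) and that
# `game` is a freshly constructed Game instance; those names and `horizon=10` are
# illustrative assumptions, not definitions made in this file.
def _example_trajectory_diagnosis(diagnose_model, game, horizon=10):
    # Purely model-based rollout, no real environment involved
    first_obs = game.reset()
    virtual_only = diagnose_model.get_virtual_trajectory_from_obs(first_obs, horizon, plot=True)

    # Same idea, but the virtual actions are also replayed in the real environment
    first_obs = game.reset()
    virtual, real, divergence = diagnose_model.compare_virtual_with_real_trajectories(
        first_obs, game, horizon, plot=True
    )
    if divergence is not None:
        print(f"Virtual and real trajectories diverged at timestep {divergence}")
    return virtual_only, virtual, real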
def update_policies(self):
    while True:
        keys = ray.get(self.replay_buffer.get_buffer_keys.remote())
        for game_id in keys:
            remcts_count = 0
            self.latest_network.set_weights(
                ray.get(self.shared_storage.get_network_weights.remote())
            )
            self.target_network.set_weights(
                ray.get(self.shared_storage.get_target_network_weights.remote())
            )
            game_history = copy.deepcopy(
                ray.get(self.replay_buffer.get_game_history.remote(game_id))
            )

            for pos in range(len(game_history.observation_history)):
                bootstrap_index = pos + self.config.td_steps
                if bootstrap_index < len(game_history.root_values):
                    if self.config.use_last_model_value:
                        # Use the last model to provide a fresher, stable n-step value (see paper appendix Reanalyze)
                        observation = torch.tensor(
                            game_history.get_stacked_observations(
                                bootstrap_index, self.config.stacked_observations
                            )
                        ).float()
                        value = models.support_to_scalar(
                            self.target_network.initial_inference(observation)[0],
                            self.config.support_size,
                        ).item()
                        game_history.root_values[bootstrap_index] = value

                if random.random() < self.config.policy_update_rate and pos < len(
                    game_history.root_values
                ):
                    with torch.no_grad():
                        stacked_obs = torch.tensor(
                            game_history.get_stacked_observations(
                                pos, self.config.stacked_observations
                            )
                        ).float()
                        root, _, _ = MCTS(self.config).run(
                            self.latest_network,
                            stacked_obs,
                            game_history.legal_actions[pos],
                            game_history.to_play_history[pos],
                            False,
                        )
                        game_history.store_search_statistics(
                            root, self.config.action_space, pos
                        )
                        remcts_count += 1

            self.shared_storage.update_infos.remote("remcts_count", remcts_count)
            self.shared_storage.update_infos.remote(
                "reanalyzed_count", len(game_history.priorities)
            )
            self.replay_buffer.update_game.remote(game_history, game_id)
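# Hedged sketch of how the endless loop above might be driven. `update_policies` reads
# like the main loop of a Ray actor that holds `replay_buffer`, `shared_storage`,
# `latest_network` and `target_network`; the handle name `reanalyse_worker` below is an
# illustrative assumption, not an API defined in this file.
def _example_start_policy_updates(reanalyse_worker):
    # Fire-and-forget: the remote call returns an ObjectRef immediately while the worker
    # keeps refreshing bootstrap values and re-running MCTS on stored games.
    return reanalyse_worker.update_policies.remote()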
def reanalyse_policy_and_value(self, game_history):
    """
    :param self_play.GameHistory game_history:
    :return:
    """
    game_history.reanalysed_predicted_root_values = None
    game_history.reanalysed_child_visits = None
    for i in range(len(game_history.root_values)):
        observation = game_history.get_stacked_observations(
            i, self.config.stacked_observations
        )
        root, _ = MCTS(self.config).run(
            self.model,
            observation,
            game_history.legal_actions_history[i],
            game_history.to_play_history[i],
            True,
        )
        game_history.store_search_statistics(
            root, self.config.action_space, reanalysed=True
        )
def play_against_algorithm(weight_file_path, config_name, seed, algo="expert", render=False):
    np.random.seed(seed)
    torch.manual_seed(seed)

    game_module = importlib.import_module("games." + config_name)
    config = game_module.MuZeroConfig()
    model = models.MuZeroNetwork(config)
    model.set_weights(torch.load(weight_file_path))
    model.eval()

    # Instantiate the scripted opponent class (e.g. "expert" -> Expert) by name
    algo = globals()[algo.capitalize()](-1, 1)

    game = Game(seed)
    observation = game.reset()
    game_history = GameHistory()
    game_history.action_history.append(0)
    game_history.reward_history.append(0)
    game_history.to_play_history.append(game.to_play())
    game_history.legal_actions.append(game.legal_actions())
    game_history.observation_history.append(observation)

    done = False
    depth = 9  # Remaining moves / search depth for the scripted opponent (9 for tic-tac-toe)
    reward = 0
    while not done:
        if game.to_play_real() == -1:
            # The scripted algorithm plays
            action = algo(game.get_state(), depth, game.to_play_real())
        else:
            # MuZero plays, selecting the most visited root action
            stacked_observations = game_history.get_stacked_observations(
                -1, config.stacked_observations
            )
            root, priority, tree_depth = MCTS(config).run(
                model,
                stacked_observations,
                game.legal_actions(),
                game.to_play(),
                False,
            )
            action = SelfPlay.select_action(root, 0)
            game_history.store_search_statistics(root, config.action_space)
            game_history.priorities.append(priority)

        observation, reward, done = game.step(action)
        if render:
            game.render()
        depth -= 1

        game_history.action_history.append(action)
        game_history.observation_history.append(observation)
        game_history.reward_history.append(reward)
        game_history.to_play_history.append(game.to_play())
        game_history.legal_actions.append(game.legal_actions())

    return reward, TictactoeComp.wins(game.get_state(), 1)
def play_against_other(weights1, config1, weights2, config2, seed, render=False):
    np.random.seed(seed)
    torch.manual_seed(seed)

    game_module = importlib.import_module("games." + config1)
    config1 = game_module.MuZeroConfig()
    model1 = models.MuZeroNetwork(config1)
    model1.set_weights(torch.load(weights1))
    model1.eval()

    game_module = importlib.import_module("games." + config2)
    config2 = game_module.MuZeroConfig()
    model2 = models.MuZeroNetwork(config2)
    model2.set_weights(torch.load(weights2))
    model2.eval()

    game = Game(seed)
    observation = game.reset()

    game_history1 = GameHistory()
    game_history1.action_history.append(0)
    game_history1.reward_history.append(0)
    game_history1.to_play_history.append(game.to_play())
    game_history1.legal_actions.append(game.legal_actions())
    observation1 = copy.deepcopy(observation)
    # observation1[0], observation1[1] = -observation[1], -observation[0]
    # observation1[2] = -observation[2]
    game_history1.observation_history.append(observation1)

    game_history2 = GameHistory()
    game_history2.action_history.append(0)
    game_history2.reward_history.append(0)
    game_history2.to_play_history.append(not game.to_play())
    game_history2.legal_actions.append(game.legal_actions())
    observation2 = copy.deepcopy(observation)
    # Present the board from player 2's perspective: swap and negate the player planes.
    # Read from the untouched `observation` so the second plane is not computed from an
    # already-overwritten value.
    observation2[0], observation2[1] = -observation[1], -observation[0]
    observation2[2] = -observation[2]
    game_history2.observation_history.append(observation2)

    done = False
    reward = 0
    while not done:
        # Alternate between the two models depending on whose turn it is
        if game.to_play_real() == 1:
            config = config1
            model = model1
            game_history = game_history1
        else:
            config = config2
            model = model2
            game_history = game_history2

        stacked_observations = game_history.get_stacked_observations(
            -1, config.stacked_observations
        )
        root, priority, tree_depth = MCTS(config).run(
            model,
            stacked_observations,
            game.legal_actions(),
            game.to_play(),
            False,
        )
        action = SelfPlay.select_action(root, 0)

        game_history1.store_search_statistics(root, config.action_space)
        game_history1.priorities.append(priority)
        game_history2.store_search_statistics(root, config.action_space)
        game_history2.priorities.append(priority)

        observation, reward, done = game.step(action)
        if render:
            game.render()

        game_history1.action_history.append(action)
        observation1 = copy.deepcopy(observation)
        # observation1[0], observation1[1] = -observation[1], -observation[0]
        # observation1[2] = -observation[2]
        game_history1.observation_history.append(observation1)
        game_history1.reward_history.append(reward)
        game_history1.to_play_history.append(game.to_play())
        game_history1.legal_actions.append(game.legal_actions())

        game_history2.action_history.append(action)
        observation2 = copy.deepcopy(observation)
        observation2[0], observation2[1] = -observation[1], -observation[0]
        observation2[2] = -observation[2]
        game_history2.observation_history.append(observation2)
        game_history2.reward_history.append(reward)
        game_history2.to_play_history.append(not game.to_play())
        game_history2.legal_actions.append(game.legal_actions())

    return reward, TictactoeComp.wins(game.get_state(), 1)
def make_target(self, game_history, state_index):
    """
    Generate targets for every unroll step.
    """
    target_values, target_rewards, target_policies, actions = [], [], [], []
    for current_index in range(
        state_index, state_index + self.config.num_unroll_steps + 1
    ):
        # The value target is the discounted root value of the search tree td_steps into the
        # future, plus the discounted sum of all rewards until then.
        bootstrap_index = current_index + self.config.td_steps
        if bootstrap_index < len(game_history.root_values):
            if self.config.use_last_model_value:
                # Use the last model to provide a fresher, stable n-step value (see paper appendix Reanalyze)
                observation = torch.tensor(
                    game_history.get_stacked_observations(
                        bootstrap_index, self.config.stacked_observations
                    )
                ).float()
                last_step_value = models.support_to_scalar(
                    self.target_network.initial_inference(observation)[0],
                    self.config.support_size,
                ).item()
            else:
                last_step_value = game_history.root_values[bootstrap_index]
            value = last_step_value * self.config.discount ** self.config.td_steps
        else:
            value = 0

        for i, reward in enumerate(
            game_history.reward_history[current_index + 1 : bootstrap_index + 1]
        ):
            # Rewards are counted from the perspective of the player at current_index
            value += (
                reward
                if game_history.to_play_history[current_index]
                == game_history.to_play_history[current_index + 1 + i]
                else -reward
            ) * self.config.discount ** i

        if current_index < len(game_history.root_values):
            if random.random() < self.config.policy_update_rate:
                # Occasionally refresh the stored policy target with a new MCTS from the latest network
                with torch.no_grad():
                    stacked_obs = torch.tensor(
                        game_history.get_stacked_observations(
                            current_index, self.config.stacked_observations
                        )
                    ).float()
                    root, _, _ = MCTS(self.config).run(
                        self.latest_network,
                        stacked_obs,
                        game_history.legal_actions[current_index],
                        game_history.to_play_history[current_index],
                        False,
                    )
                    game_history.store_search_statistics(
                        root, self.config.action_space, current_index
                    )
            target_values.append(value)
            target_rewards.append(game_history.reward_history[current_index])
            target_policies.append(game_history.child_visits[current_index])
            actions.append(game_history.action_history[current_index])
        elif current_index == len(game_history.root_values):
            target_values.append(0)
            target_rewards.append(game_history.reward_history[current_index])
            # Uniform policy
            target_policies.append(
                [
                    1 / len(game_history.child_visits[0])
                    for _ in range(len(game_history.child_visits[0]))
                ]
            )
            actions.append(game_history.action_history[current_index])
        else:
            # States past the end of games are treated as absorbing states
            target_values.append(0)
            target_rewards.append(0)
            # Uniform policy
            target_policies.append(
                [
                    1 / len(game_history.child_visits[0])
                    for _ in range(len(game_history.child_visits[0]))
                ]
            )
            actions.append(np.random.choice(game_history.action_history))

    return target_values, target_rewards, target_policies, actions
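# Worked example of the n-step value target computed above, on made-up numbers.
# The rewards, players, root value, discount and td_steps below are illustrative
# assumptions; only the formula (discounted bootstrap value plus discounted intermediate
# rewards, with a sign flip whenever the player to move differs) mirrors make_target.
def _example_value_target():
    discount = 0.997
    td_steps = 3
    rewards = [1.0, 0.0, -1.0]       # r_{t+1} .. r_{t+td_steps}
    to_play = [0, 1, 0, 1]           # player to move at t .. t+td_steps
    bootstrap_root_value = 0.5       # search root value at t+td_steps
    value = bootstrap_root_value * discount ** td_steps
    for i, reward in enumerate(rewards):
        same_player = to_play[0] == to_play[1 + i]
        value += (reward if same_player else -reward) * discount ** i
    return value  # roughly 0.49 with these numbers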
def add_action(self, opponent: str, temperature: float = 0, temperature_threshold: float = 0, human_action: int = None) -> (int, str, list):
    with torch.no_grad():
        if self.done:
            raise ValueError("Status is already 'done' but there is still another step.")
        if len(self.game_history.action_history) > self.config.max_moves:
            raise ValueError("Number of steps is already over the max moves.")

        stacked_observations = self.game_history.get_stacked_observations(
            -1, self.config.stacked_observations
        )

        # Choose the action
        action = None
        if opponent == "self":
            root, mcts_info = MCTS(self.config).run(
                self.model,
                stacked_observations,
                self.game.legal_actions(),
                self.game.to_play(),
                True,
            )
            action = SelfPlay.select_action(
                root,
                temperature
                if not temperature_threshold
                or len(self.game_history.action_history) < temperature_threshold
                else 0,
            )
        elif opponent == "random":
            action, root = numpy.random.choice(self.game.legal_actions()), None
        elif opponent == "expert":
            action, root = self.game.expert_agent(), None
        elif opponent == "human":
            action, root = human_action, None
        else:
            raise ValueError(
                'Wrong argument: "opponent" argument should be "self", "human", "expert" or "random"'
            )

        # Validate before casting so a missing human action raises a clear error
        if action is None or int(action) not in self.game.legal_actions():
            if opponent == "human":
                raise ValueError(f"Requested action '{action}' is illegal in this game.")
            raise Exception(f"Calculated action '{action}' by '{opponent}' is illegal in this game.")
        action = int(action)

        observation, reward, self.done = self.game.step(action)
        self.game_history.store_search_statistics(root, self.config.action_space)

        # Next batch
        self.game_history.action_history.append(action)
        self.game_history.observation_history.append(observation)
        self.game_history.reward_history.append(reward)
        self.game_history.to_play_history.append(self.game.to_play())

        if isinstance(observation, numpy.ndarray):
            observation = observation.tolist()

        return action, self.game.action_to_string(action), observation
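# Hedged usage sketch: add_action above appears to drive one step of an interactive game
# (e.g. from a GUI or API layer). The wrapper object name `session` and the way it was
# constructed are assumptions; only the argument semantics come from the code itself.
def _example_play_one_human_then_one_self_move(session, human_action):
    # The human plays first, then MuZero answers with a full MCTS ("self" opponent).
    a1, s1, _ = session.add_action("human", human_action=human_action)
    a2, s2, _ = session.add_action("self", temperature=0)
    return (a1, s1), (a2, s2)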