def get_virtual_trajectory_from_obs(
    self, observation, horizon, plot=True, to_play=0
):
    """
    MuZero plays a game but uses its model instead of using the environment.
    We still do an MCTS at each step.
    """
    trajectory_info = Trajectoryinfo("Virtual trajectory", self.config)
    root, mcts_info = MCTS(self.config).run(
        self.model, observation, self.config.action_space, to_play, True
    )
    trajectory_info.store_info(root, mcts_info, None, numpy.NaN)

    virtual_to_play = to_play
    for i in range(horizon):
        action = SelfPlay.select_action(root, 0)

        # Players play turn by turn
        if virtual_to_play + 1 < len(self.config.players):
            virtual_to_play = self.config.players[virtual_to_play + 1]
        else:
            virtual_to_play = self.config.players[0]

        # Generate new root
        # TODO: Test keeping the old root
        value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
            root.hidden_state,
            torch.tensor([[action]]).to(root.hidden_state.device),
        )
        value = models.support_to_scalar(value, self.config.support_size).item()
        reward = models.support_to_scalar(reward, self.config.support_size).item()
        root = Node(0)
        root.expand(
            self.config.action_space,
            virtual_to_play,
            reward,
            policy_logits,
            hidden_state,
        )

        root, mcts_info = MCTS(self.config).run(
            self.model, None, self.config.action_space, virtual_to_play, True, root
        )
        trajectory_info.store_info(
            root, mcts_info, action, reward, new_prior_root_value=value
        )

    if plot:
        self.plot_trajectory(trajectory_info)

    return trajectory_info
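# The function above calls SelfPlay.select_action(root, 0) without showing it.
# The following is a hedged sketch of such a helper, consistent with how it is
# used here (temperature 0 means "pick the most-visited child"); it is an
# assumption about the surrounding project, not necessarily its exact code.
def select_action(node, temperature):
    """
    Select an action according to the visit count distribution and the temperature.
    """
    visit_counts = numpy.array(
        [child.visit_count for child in node.children.values()], dtype="int32"
    )
    actions = [action for action in node.children.keys()]
    if temperature == 0:
        # Greedy: the most visited child
        action = actions[numpy.argmax(visit_counts)]
    elif temperature == float("inf"):
        # Fully random among the expanded children
        action = numpy.random.choice(actions)
    else:
        # Soften the visit count distribution with the temperature
        visit_count_distribution = visit_counts ** (1 / temperature)
        visit_count_distribution = visit_count_distribution / sum(
            visit_count_distribution
        )
        action = numpy.random.choice(actions, p=visit_count_distribution)

    return action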
def reanalyse(self, replay_buffer, shared_storage):
    while shared_storage.get_info("num_played_games") < 1:
        time.sleep(0.1)

    while (
        shared_storage.get_info("training_step") < self.config.training_steps
        and not shared_storage.get_info("terminate")
    ):
        self.model.set_weights(shared_storage.get_info("weights"))

        game_id, game_history, _ = replay_buffer.sample_game(force_uniform=True)

        # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
        if self.config.use_last_model_value:
            observations = [
                game_history.get_stacked_observations(
                    i, self.config.stacked_observations
                )
                for i in range(len(game_history.root_values))
            ]
            observations = (
                torch.tensor(observations)
                .float()
                .to(next(self.model.parameters()).device)
            )
            values = models.support_to_scalar(
                self.model.initial_inference(observations)[0],
                self.config.support_size,
            )
            # Move to CPU before converting to numpy (the model may live on GPU)
            game_history.reanalysed_predicted_root_values = (
                torch.squeeze(values).detach().cpu().numpy()
            )

        replay_buffer.update_game_history(game_id, game_history)
        self.num_reanalysed_games += 1
        shared_storage.set_info("num_reanalysed_games", self.num_reanalysed_games)
def compute_value(self, game_history, index):
    # The value target is the discounted root value of the search tree td_steps into the
    # future, plus the discounted sum of all rewards until then.
    bootstrap_index = index + self.config.td_steps
    if bootstrap_index < len(game_history.root_values):
        if self.config.use_last_model_value:
            # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
            observation = (
                torch.tensor(
                    game_history.get_stacked_observations(
                        bootstrap_index, self.config.stacked_observations
                    )
                )
                .float()
                .unsqueeze(0)
            )
            last_step_value = models.support_to_scalar(
                self.model.initial_inference(observation)[0],
                self.config.support_size,
            ).item()
        else:
            last_step_value = game_history.root_values[bootstrap_index]

        value = last_step_value * self.config.discount**self.config.td_steps
    else:
        value = 0

    for i, reward in enumerate(
        game_history.reward_history[index + 1 : bootstrap_index + 1]
    ):
        value += (
            reward
            if game_history.to_play_history[index]
            == game_history.to_play_history[index + 1 + i]
            else -reward
        ) * self.config.discount**i

    return value
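# Worked example (illustrative numbers only) of the n-step target computed by
# compute_value above: with td_steps = 3, discount = 0.97, rewards
# r_{t+1..t+3} = [1.0, 0.0, 2.0] all from the same player's perspective, and a
# bootstrap root value of 0.5 at step t+3, the target is
#   0.5 * 0.97**3 + 1.0 * 0.97**0 + 0.0 * 0.97**1 + 2.0 * 0.97**2
discount, td_steps = 0.97, 3
rewards = [1.0, 0.0, 2.0]      # reward_history[index + 1 : bootstrap_index + 1]
bootstrap_value = 0.5          # root value (or model value) at bootstrap_index
target = bootstrap_value * discount**td_steps + sum(
    reward * discount**i for i, reward in enumerate(rewards)
)
# target ~= 0.456 + 1.0 + 1.882 ~= 3.338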
def reanalyse(self, replay_buffer, shared_storage):
    while ray.get(shared_storage.get_info.remote())["num_played_games"] < 1:
        time.sleep(0.1)

    while (
        ray.get(shared_storage.get_info.remote())["training_step"]
        < self.config.training_steps
    ):
        self.model.set_weights(
            copy.deepcopy(ray.get(shared_storage.get_weights.remote()))
        )

        game_id, game_history, _ = ray.get(
            replay_buffer.sample_game.remote(force_uniform=True)
        )

        # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
        if self.config.use_last_model_value:
            observations = [
                game_history.get_stacked_observations(
                    i, self.config.stacked_observations
                )
                for i in range(len(game_history.root_values))
            ]
            observations = (
                torch.tensor(observations).float().to(self.config.reanalyse_device)
            )
            values = models.support_to_scalar(
                self.model.initial_inference(observations)[0],
                self.config.support_size,
                self.config.epsilon,
            )
            for i in range(len(game_history.root_values)):
                game_history.root_values[i] = values[i].item()

        replay_buffer.update_game_history.remote(game_id, game_history)
        self.num_reanalysed_games += 1
        shared_storage.set_info.remote(
            "num_reanalysed_games", self.num_reanalysed_games
        )
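# How a Reanalyse worker like the one above is typically wired up with Ray
# (hedged sketch; the actor options and the `Reanalyse`, `checkpoint`, and
# `config` names are assumptions about the surrounding project):
#
#     reanalyse_worker = ray.remote(num_gpus=0)(Reanalyse).remote(checkpoint, config)
#     reanalyse_worker.reanalyse.remote(replay_buffer_worker, shared_storage_worker)
#
# The .remote() call returns immediately; the worker then loops until
# training_step reaches config.training_steps, periodically refreshing its
# weights from shared storage and overwriting stored value targets in place.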
def update_policies(self):
    while True:
        keys = ray.get(self.replay_buffer.get_buffer_keys.remote())
        for game_id in keys:
            remcts_count = 0
            self.latest_network.set_weights(
                ray.get(self.shared_storage.get_network_weights.remote())
            )
            self.target_network.set_weights(
                ray.get(self.shared_storage.get_target_network_weights.remote())
            )
            game_history = copy.deepcopy(
                ray.get(self.replay_buffer.get_game_history.remote(game_id))
            )

            for pos in range(len(game_history.observation_history)):
                bootstrap_index = pos + self.config.td_steps
                if bootstrap_index < len(game_history.root_values):
                    if self.config.use_last_model_value:
                        # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
                        observation = torch.tensor(
                            game_history.get_stacked_observations(
                                bootstrap_index, self.config.stacked_observations
                            )
                        ).float()
                        value = models.support_to_scalar(
                            self.target_network.initial_inference(observation)[0],
                            self.config.support_size,
                        ).item()
                        game_history.root_values[bootstrap_index] = value

                if (
                    random.random() < self.config.policy_update_rate
                    and pos < len(game_history.root_values)
                ):
                    with torch.no_grad():
                        stacked_obs = torch.tensor(
                            game_history.get_stacked_observations(
                                pos, self.config.stacked_observations
                            )
                        ).float()
                        root, _, _ = MCTS(self.config).run(
                            self.latest_network,
                            stacked_obs,
                            game_history.legal_actions[pos],
                            game_history.to_play_history[pos],
                            False,
                        )
                        game_history.store_search_statistics(
                            root, self.config.action_space, pos
                        )
                        remcts_count += 1

            self.shared_storage.update_infos.remote("remcts_count", remcts_count)
            self.shared_storage.update_infos.remote(
                "reanalyzed_count", len(game_history.priorities)
            )
            self.replay_buffer.update_game.remote(game_history, game_id)
def run(
    self,
    model,
    observation,
    legal_actions,
    to_play,
    add_exploration_noise,
    override_root_with=None,
):
    """
    At the root of the search tree we use the representation function to obtain a
    hidden state given the current observation.
    We then run a Monte Carlo Tree Search using only action sequences and the model
    learned by the network.
    """
    if override_root_with:
        root = override_root_with
        root_predicted_value = None
    else:
        root = Node(0)
        observation = (
            torch.tensor(observation)
            .float()
            .unsqueeze(0)
            .to(next(model.parameters()).device)
        )
        (
            root_predicted_value,
            reward,
            policy_logits,
            hidden_state,
        ) = model.initial_inference(observation)
        root_predicted_value = models.support_to_scalar(
            root_predicted_value, self.config.support_size
        ).item()
        reward = models.support_to_scalar(reward, self.config.support_size).item()
        assert (
            legal_actions
        ), f"Legal actions should not be an empty array. Got {legal_actions}."
        assert set(legal_actions).issubset(
            set(self.config.action_space)
        ), "Legal actions should be a subset of the action space."
        root.expand(
            legal_actions,
            to_play,
            reward,
            policy_logits,
            hidden_state,
        )

    if add_exploration_noise:
        root.add_exploration_noise(
            dirichlet_alpha=self.config.root_dirichlet_alpha,
            exploration_fraction=self.config.root_exploration_fraction,
        )

    min_max_stats = MinMaxStats()

    max_tree_depth = 0
    for _ in range(self.config.num_simulations):
        virtual_to_play = to_play
        node = root
        search_path = [node]
        current_tree_depth = 0

        while node.expanded():
            current_tree_depth += 1
            action, node = self.select_child(node, min_max_stats)
            search_path.append(node)

            # Players play turn by turn
            if virtual_to_play + 1 < len(self.config.players):
                virtual_to_play = self.config.players[virtual_to_play + 1]
            else:
                virtual_to_play = self.config.players[0]

        # Inside the search tree we use the dynamics function to obtain the next hidden
        # state given an action and the previous hidden state
        parent = search_path[-2]
        value, reward, policy_logits, hidden_state = model.recurrent_inference(
            parent.hidden_state,
            torch.tensor([[action]]).to(parent.hidden_state.device),
        )
        value = models.support_to_scalar(value, self.config.support_size).item()
        reward = models.support_to_scalar(reward, self.config.support_size).item()
        node.expand(
            self.config.action_space,
            virtual_to_play,
            reward,
            policy_logits,
            hidden_state,
        )

        self.backpropagate(search_path, value, virtual_to_play, min_max_stats)

        max_tree_depth = max(max_tree_depth, current_tree_depth)

    extra_info = {
        "max_tree_depth": max_tree_depth,
        "root_predicted_value": root_predicted_value,
    }
    return root, extra_info
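# MCTS.run above relies on self.select_child and self.backpropagate, which are
# not shown. Below is a hedged sketch of the pUCT score that select_child would
# maximise, following the MuZero paper; the config names pb_c_base and
# pb_c_init, and the MinMaxStats.normalize method, are assumptions about the
# surrounding project rather than confirmed code.
import math

def ucb_score(self, parent, child, min_max_stats):
    """
    Score = normalised one-step look-ahead Q value plus an exploration bonus
    proportional to the prior and decreasing with the child's visit count.
    """
    pb_c = (
        math.log(
            (parent.visit_count + self.config.pb_c_base + 1) / self.config.pb_c_base
        )
        + self.config.pb_c_init
    )
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

    prior_score = pb_c * child.prior

    if child.visit_count > 0:
        # Mean value Q, negated for the opponent in two-player games
        value_score = min_max_stats.normalize(
            child.reward
            + self.config.discount
            * (child.value() if len(self.config.players) == 1 else -child.value())
        )
    else:
        value_score = 0

    return prior_score + value_score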
def update_weights(self, batch):
    """
    Perform one training step.
    """
    (
        observation_batch,
        action_batch,
        target_value,
        target_reward,
        target_policy,
        weight_batch,
        gradient_scale_batch,
    ) = batch

    # Keep values as scalars for calculating the priorities for the prioritized replay
    target_value_scalar = numpy.array(target_value)
    priorities = numpy.zeros_like(target_value_scalar)

    device = next(self.model.parameters()).device
    weight_batch = torch.tensor(weight_batch).float().to(device)
    observation_batch = torch.tensor(observation_batch).float().to(device)
    action_batch = torch.tensor(action_batch).float().to(device).unsqueeze(-1)
    target_value = torch.tensor(target_value).float().to(device)
    target_reward = torch.tensor(target_reward).float().to(device)
    target_policy = torch.tensor(target_policy).float().to(device)
    gradient_scale_batch = torch.tensor(gradient_scale_batch).float().to(device)
    # observation_batch: batch, channels, height, width
    # action_batch: batch, num_unroll_steps+1, 1 (unsqueeze)
    # target_value: batch, num_unroll_steps+1
    # target_reward: batch, num_unroll_steps+1
    # target_policy: batch, num_unroll_steps+1, len(action_space)
    # gradient_scale_batch: batch, num_unroll_steps+1

    target_value = models.scalar_to_support(target_value, self.config.support_size)
    target_reward = models.scalar_to_support(target_reward, self.config.support_size)
    # target_value: batch, num_unroll_steps+1, 2*support_size+1
    # target_reward: batch, num_unroll_steps+1, 2*support_size+1

    ## Generate predictions
    value, reward, policy_logits, hidden_state = self.model.initial_inference(
        observation_batch
    )
    predictions = [(value, reward, policy_logits)]
    for i in range(1, action_batch.shape[1]):
        value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
            hidden_state, action_batch[:, i]
        )
        # Scale the gradient at the start of the dynamics function (See paper appendix Training)
        hidden_state.register_hook(lambda grad: grad * 0.5)
        predictions.append((value, reward, policy_logits))
    # predictions: num_unroll_steps+1, 3, batch, 2*support_size+1 | 2*support_size+1 | 9 (according to the 2nd dim)

    ## Compute losses
    value_loss, reward_loss, policy_loss = (0, 0, 0)
    value, reward, policy_logits = predictions[0]
    # Ignore reward loss for the first batch step
    current_value_loss, _, current_policy_loss = self.loss_function(
        value.squeeze(-1),
        reward.squeeze(-1),
        policy_logits,
        target_value[:, 0],
        target_reward[:, 0],
        target_policy[:, 0],
    )
    value_loss += current_value_loss
    policy_loss += current_policy_loss
    # Compute priorities for the prioritized replay (See paper appendix Training)
    pred_value_scalar = (
        models.support_to_scalar(value, self.config.support_size)
        .detach()
        .cpu()
        .numpy()
        .squeeze()
    )
    priorities[:, 0] = (
        numpy.abs(pred_value_scalar - target_value_scalar[:, 0])
        ** self.config.PER_alpha
    )

    for i in range(1, len(predictions)):
        value, reward, policy_logits = predictions[i]
        (
            current_value_loss,
            current_reward_loss,
            current_policy_loss,
        ) = self.loss_function(
            value.squeeze(-1),
            reward.squeeze(-1),
            policy_logits,
            target_value[:, i],
            target_reward[:, i],
            target_policy[:, i],
        )

        # Scale gradient by the number of unroll steps (See paper appendix Training)
        current_value_loss.register_hook(
            lambda grad: grad / gradient_scale_batch[:, i]
        )
        current_reward_loss.register_hook(
            lambda grad: grad / gradient_scale_batch[:, i]
        )
        current_policy_loss.register_hook(
            lambda grad: grad / gradient_scale_batch[:, i]
        )

        value_loss += current_value_loss
        reward_loss += current_reward_loss
        policy_loss += current_policy_loss

        # Compute priorities for the prioritized replay (See paper appendix Training)
        pred_value_scalar = (
            models.support_to_scalar(value, self.config.support_size)
            .detach()
            .cpu()
            .numpy()
            .squeeze()
        )
        priorities[:, i] = (
            numpy.abs(pred_value_scalar - target_value_scalar[:, i])
            ** self.config.PER_alpha
        )

    # Scale the value loss, paper recommends by 0.25 (See paper appendix Reanalyze)
    loss = value_loss * self.config.value_loss_weight + reward_loss + policy_loss
    if self.config.PER:
        # Correct PER bias by using importance-sampling (IS) weights
        loss *= weight_batch
    # Mean over the batch dimension (the pseudocode does a sum)
    loss = loss.mean()

    # Optimize
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    self.training_step += 1

    return (
        priorities,
        # For log purpose
        loss.item(),
        value_loss.mean().item(),
        reward_loss.mean().item(),
        policy_loss.mean().item(),
    )
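# update_weights above calls self.loss_function, which is not shown. Below is a
# hedged sketch consistent with the categorical (support) targets built by
# models.scalar_to_support: each head is trained with a cross-entropy between
# the predicted logits and the target distribution, summed over the support
# dimension and left per-sample so that the hooks and PER weights above can act
# on it. In the project it would presumably be a static method on the trainer.
def loss_function(
    value, reward, policy_logits, target_value, target_reward, target_policy
):
    # Cross-entropy per batch element (no reduction over the batch dimension)
    value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1)
    reward_loss = (-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1)
    policy_loss = (-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(1)
    return value_loss, reward_loss, policy_loss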
def reanalyse(self, replay_buffer, shared_storage):
    while ray.get(shared_storage.get_info.remote("num_played_games")) < 1:
        time.sleep(0.1)

    while (
        ray.get(shared_storage.get_info.remote("training_step"))
        < self.config.training_steps
        and not ray.get(shared_storage.get_info.remote("terminate"))
    ):
        self.model.set_weights(ray.get(shared_storage.get_info.remote("weights")))

        # Update the target model periodically
        if self.config.use_last_model_value:
            training_step = ray.get(shared_storage.get_info.remote("training_step"))
            if (
                training_step - self.last_update_step
            ) >= self.config.value_target_update_freq:
                self.last_update_step = training_step
                self.target_model.set_weights(
                    ray.get(shared_storage.get_info.remote("weights"))
                )

        game_id, game_history = ray.get(
            replay_buffer.reanalyse_sample_game.remote()
        )

        # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
        if self.config.use_last_model_value:
            # Use the lagging network (representation + value) to obtain updated targets
            if not self.config.use_updated_mcts_value_targets:
                observations = [
                    game_history.get_stacked_observations(
                        i, self.config.stacked_observations
                    )
                    for i in range(len(game_history.root_values))
                ]
                observations = (
                    torch.tensor(observations)
                    .float()
                    .to(next(self.model.parameters()).device)
                )
                values = models.support_to_scalar(
                    # Use lagging parameters
                    self.target_model.initial_inference(observations)[0],
                    self.config.support_size,
                )
                root_values = torch.squeeze(values).detach().cpu().numpy()

            # Re-execute MCTS to update targets (child_visits and root_values)
            l = len(game_history.root_values)
            game_history.child_visits = []
            game_history.root_values = []
            priorities = []
            for i in range(l):
                stacked_observations = game_history.get_stacked_observations(
                    i,
                    self.config.stacked_observations,
                )
                root, mcts_info = MCTS(self.config).run(
                    # Use either fresh or lagging (recent) parameters
                    self.target_model
                    if self.config.use_updated_mcts_value_targets
                    else self.model,
                    stacked_observations,
                    self.game.legal_actions(),
                    self.game.to_play(),
                    True,
                )
                game_history.store_search_statistics(root, self.config.action_space)

            # Use the MCTS value targets
            if self.config.use_updated_mcts_value_targets:
                root_values = game_history.root_values

            # Update PER according to the initial prioritization (See paper appendix Training)
            if self.config.PER:
                priorities = []
                for i, root_value in enumerate(root_values):
                    priority = (
                        numpy.abs(
                            root_value
                            - compute_target_value(
                                game_history,
                                i,
                                self.config.td_steps,
                                self.config.discount,
                            )
                        )
                        ** self.config.PER_alpha
                    )
                    priorities.append(priority)
                game_history.priorities = numpy.array(priorities, dtype="float32")
                game_history.game_priority = numpy.max(game_history.priorities)

            game_history.reanalysed_predicted_root_values = root_values

        replay_buffer.update_game_history.remote(game_id, game_history, shared_storage)
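# The fork above calls a module-level compute_target_value(game_history, index,
# td_steps, discount) when refreshing priorities. Hedged sketch, mirroring the
# bootstrapped n-step target of compute_value earlier but taking the
# hyperparameters as arguments instead of reading them from a config:
def compute_target_value(game_history, index, td_steps, discount):
    # Discounted root value td_steps into the future, plus the discounted sum
    # of intermediate rewards, signed from the perspective of the player at `index`
    bootstrap_index = index + td_steps
    if bootstrap_index < len(game_history.root_values):
        value = game_history.root_values[bootstrap_index] * discount**td_steps
    else:
        value = 0

    for i, reward in enumerate(
        game_history.reward_history[index + 1 : bootstrap_index + 1]
    ):
        value += (
            reward
            if game_history.to_play_history[index]
            == game_history.to_play_history[index + 1 + i]
            else -reward
        ) * discount**i

    return value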
def make_target(self, game_history, state_index):
    """
    Generate targets for every unroll step.
    """
    target_values, target_rewards, target_policies, actions = [], [], [], []
    for current_index in range(
        state_index, state_index + self.config.num_unroll_steps + 1
    ):
        # The value target is the discounted root value of the search tree td_steps into the
        # future, plus the discounted sum of all rewards until then.
        bootstrap_index = current_index + self.config.td_steps
        if bootstrap_index < len(game_history.root_values):
            if self.config.use_last_model_value:
                # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
                observation = (
                    torch.tensor(
                        game_history.get_stacked_observations(
                            bootstrap_index, self.config.stacked_observations
                        )
                    )
                    .float()
                    .unsqueeze(0)
                )
                last_step_value = models.support_to_scalar(
                    self.model.initial_inference(observation)[0],
                    self.config.support_size,
                ).item()
            else:
                last_step_value = game_history.root_values[bootstrap_index]

            value = last_step_value * self.config.discount**self.config.td_steps
        else:
            value = 0

        for i, reward in enumerate(
            game_history.reward_history[current_index + 1 : bootstrap_index + 1]
        ):
            value += (
                reward
                if game_history.to_play_history[current_index]
                == game_history.to_play_history[current_index + 1 + i]
                else -reward
            ) * self.config.discount**i

        if current_index < len(game_history.root_values):
            target_values.append(value)
            target_rewards.append(game_history.reward_history[current_index])
            target_policies.append(game_history.child_visits[current_index])
            actions.append(game_history.action_history[current_index])
        elif current_index == len(game_history.root_values):
            target_values.append(0)
            target_rewards.append(game_history.reward_history[current_index])
            # Uniform policy
            target_policies.append(
                [
                    1 / len(game_history.child_visits[0])
                    for _ in range(len(game_history.child_visits[0]))
                ]
            )
            actions.append(game_history.action_history[current_index])
        else:
            # States past the end of games are treated as absorbing states
            target_values.append(0)
            target_rewards.append(0)
            # Uniform policy
            target_policies.append(
                [
                    1 / len(game_history.child_visits[0])
                    for _ in range(len(game_history.child_visits[0]))
                ]
            )
            actions.append(numpy.random.choice(game_history.action_history))

    return target_values, target_rewards, target_policies, actions
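# Several snippets above convert between scalars and the categorical "support"
# representation via models.scalar_to_support / models.support_to_scalar. Below
# is a hedged sketch of that pair, following the invertible value transform
# h(x) = sign(x)(sqrt(|x| + 1) - 1 + eps*x) with eps = 0.001 from the MuZero
# paper; the exact implementation in `models` may differ.
def support_to_scalar(logits, support_size):
    # Expected value under the softmax distribution over {-support_size, ..., support_size}
    probabilities = torch.softmax(logits, dim=1)
    support = (
        torch.arange(-support_size, support_size + 1)
        .float()
        .to(probabilities.device)
        .expand(probabilities.shape)
    )
    x = torch.sum(support * probabilities, dim=1, keepdim=True)

    # Invert the scaling h(x)
    x = torch.sign(x) * (
        ((torch.sqrt(1 + 4 * 0.001 * (torch.abs(x) + 1 + 0.001)) - 1) / (2 * 0.001))
        ** 2
        - 1
    )
    return x


def scalar_to_support(x, support_size):
    # Apply the scaling h(x), then spread each value over its two nearest
    # integer bins of the support (linear interpolation)
    x = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + 0.001 * x
    x = torch.clamp(x, -support_size, support_size)
    floor = x.floor()
    prob = x - floor
    logits = torch.zeros(x.shape[0], x.shape[1], 2 * support_size + 1).to(x.device)
    logits.scatter_(
        2, (floor + support_size).long().unsqueeze(-1), (1 - prob).unsqueeze(-1)
    )
    indexes = floor + support_size + 1
    prob = prob.masked_fill_(2 * support_size < indexes, 0.0)
    indexes = indexes.masked_fill_(2 * support_size < indexes, 0.0)
    logits.scatter_(2, indexes.long().unsqueeze(-1), prob.unsqueeze(-1))
    return logits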