def get_virtual_trajectory_from_obs(
    self, observation, horizon, plot=True, to_play=0
):
    """
    MuZero plays a game but uses its model instead of using the environment.
    We still do an MCTS at each step.
    """
    trajectory_info = Trajectoryinfo("Virtual trajectory", self.config)
    root, mcts_info = MCTS(self.config).run(
        self.model, observation, self.config.action_space, to_play, True
    )
    trajectory_info.store_info(root, mcts_info, None, numpy.nan)

    virtual_to_play = to_play
    for _ in range(horizon):
        action = SelfPlay.select_action(root, 0)

        # Players play turn by turn
        if virtual_to_play + 1 < len(self.config.players):
            virtual_to_play = self.config.players[virtual_to_play + 1]
        else:
            virtual_to_play = self.config.players[0]

        # Generate new root
        value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
            root.hidden_state,
            torch.tensor([[action]]).to(root.hidden_state.device),
        )
        value = network.support_to_scalar(value, self.config.support_size).item()
        reward = network.support_to_scalar(reward, self.config.support_size).item()
        root = Node(0)
        root.expand(
            self.config.action_space,
            virtual_to_play,
            reward,
            policy_logits,
            hidden_state,
        )

        root, mcts_info = MCTS(self.config).run(
            self.model,
            None,
            self.config.action_space,
            virtual_to_play,
            True,
            root,
        )
        trajectory_info.store_info(
            root, mcts_info, action, reward, new_prior_root_value=value
        )

    if plot:
        trajectory_info.plot_trajectory()

    return trajectory_info
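# --- Illustrative sketch (not part of the original code) -----------------------
# Minimal example of driving the virtual-trajectory rollout above. The
# `diagnose_model` object (any instance exposing get_virtual_trajectory_from_obs)
# and the Gym-style `env` are assumed to exist; the names are hypothetical.
def example_virtual_rollout(diagnose_model, env, horizon=10):
    """Let the model "dream" `horizon` steps from the environment's initial observation."""
    observation = env.reset()
    # Only the first observation comes from the environment; every subsequent
    # step is produced by the learned dynamics model inside the method above.
    return diagnose_model.get_virtual_trajectory_from_obs(
        observation, horizon, plot=False, to_play=0
    )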
def reanalyse(self, replay_buffer, shared_storage):
    while ray.get(shared_storage.get_info.remote("num_played_games")) < 1:
        time.sleep(0.1)

    while ray.get(
        shared_storage.get_info.remote("training_step")
    ) < self.config.training_steps and not ray.get(
        shared_storage.get_info.remote("terminate")
    ):
        self.model.set_weights(ray.get(shared_storage.get_info.remote("weights")))

        game_id, game_history, _ = ray.get(
            replay_buffer.sample_game.remote(force_uniform=True)
        )

        # Use the last model to provide a fresher, stable n-step value
        # (See paper appendix Reanalyze)
        if self.config.use_last_model_value:
            observations = [
                game_history.get_stacked_observations(
                    i, self.config.stacked_observations
                )
                for i in range(len(game_history.root_values))
            ]
            observations = (
                torch.tensor(observations)
                .float()
                .to(next(self.model.parameters()).device)
            )
            values = network.support_to_scalar(
                self.model.initial_inference(observations)[0],
                self.config.support_size,
            )
            game_history.reanalysed_predicted_root_values = (
                torch.squeeze(values).detach().cpu().numpy()
            )

        replay_buffer.update_game_history.remote(game_id, game_history)
        self.num_reanalysed_games += 1
        shared_storage.set_info.remote(
            "num_reanalysed_games", self.num_reanalysed_games
        )
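# --- Illustrative sketch (not part of the original code) -----------------------
# Typical Ray wiring for the reanalyse loop above: the enclosing class is declared
# as a @ray.remote actor, so `reanalyse` runs concurrently with training. The
# actor handles below (`reanalyse_actor`, `replay_buffer_actor`,
# `shared_storage_actor`) are assumed to already exist.
def launch_reanalyse(reanalyse_actor, replay_buffer_actor, shared_storage_actor):
    # Fire-and-forget: the actor keeps refreshing stored root values with the
    # latest weights until the training loop signals termination.
    return reanalyse_actor.reanalyse.remote(replay_buffer_actor, shared_storage_actor)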
def run(
    self,
    model,
    observation,
    legal_actions,
    to_play,
    add_exploration_noise,
    override_root_with=None,
):
    """
    At the root of the search tree we use the representation function to obtain a
    hidden state given the current observation.
    We then run a Monte Carlo Tree Search using only action sequences and the model
    learned by the network.
    """
    if override_root_with:
        root = override_root_with
        root_predicted_value = None
    else:
        root = Node(0)
        observation = (
            torch.tensor(observation)
            .float()
            .unsqueeze(0)
            .to(next(model.parameters()).device)
        )
        (
            root_predicted_value,
            reward,
            policy_logits,
            hidden_state,
        ) = model.initial_inference(observation)
        root_predicted_value = network.support_to_scalar(
            root_predicted_value, self.config.support_size
        ).item()
        reward = network.support_to_scalar(reward, self.config.support_size).item()
        assert (
            legal_actions
        ), f"Legal actions should not be an empty array. Got {legal_actions}."
        assert set(legal_actions).issubset(
            set(self.config.action_space)
        ), "Legal actions should be a subset of the action space."
        root.expand(
            legal_actions,
            to_play,
            reward,
            policy_logits,
            hidden_state,
        )

    if add_exploration_noise:
        root.add_exploration_noise(
            dirichlet_alpha=self.config.root_dirichlet_alpha,
            exploration_fraction=self.config.root_exploration_fraction,
        )

    min_max_stats = MinMaxStats()

    max_tree_depth = 0
    for _ in range(self.config.num_simulations):
        # Perform several candidate descents and keep only one of them
        run_log = {}
        for k in range(self.runs):
            virtual_to_play = to_play
            node = root
            search_path = [node]
            current_tree_depth = 0

            while node.expanded():
                current_tree_depth += 1
                action, node = self.select_child(node, min_max_stats)
                search_path.append(node)

                # Players play turn by turn
                if virtual_to_play + 1 < len(self.config.players):
                    virtual_to_play = self.config.players[virtual_to_play + 1]
                else:
                    virtual_to_play = self.config.players[0]

            # Inside the search tree we use the dynamics function to obtain the next
            # hidden state given an action and the previous hidden state
            parent = search_path[-2]
            value, reward, policy_logits, hidden_state = model.recurrent_inference(
                parent.hidden_state,
                torch.tensor([[action]]).to(parent.hidden_state.device),
            )
            value = network.support_to_scalar(value, self.config.support_size).item()
            reward = network.support_to_scalar(reward, self.config.support_size).item()

            run_log[k] = [
                virtual_to_play,
                reward,
                policy_logits,
                hidden_state,
                search_path,
                value,
                current_tree_depth,
            ]

        # Keep the candidate descent with the lowest predicted reward
        chosen_run = min(run_log, key=lambda k: run_log[k][1])
        data_ret = run_log[chosen_run]

        # Expand the leaf of the chosen search path (not the leaf of the last descent)
        leaf = data_ret[4][-1]
        leaf.expand(
            self.config.action_space,
            data_ret[0],
            data_ret[1],
            data_ret[2],
            data_ret[3],
        )

        self.backpropagate(data_ret[4], data_ret[5], data_ret[0], min_max_stats)

        max_tree_depth = max(max_tree_depth, data_ret[6])

    extra_info = {
        "max_tree_depth": max_tree_depth,
        "root_predicted_value": root_predicted_value,
    }
    return root, extra_info
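# --- Illustrative sketch (not part of the original code) -----------------------
# `select_child` above is assumed to pick the child maximising a PUCT-style score,
# as in the MuZero paper. This is a minimal single-player reimplementation of that
# rule; the default constants (pb_c_base, pb_c_init, discount) and the Node
# attributes (visit_count, prior, reward, value()) are assumptions, not taken
# from this file.
import math

def puct_score(parent, child, min_max_stats, pb_c_base=19652, pb_c_init=1.25, discount=0.997):
    # Exploration weight grows slowly with the parent's visit count and shrinks
    # with the child's visit count.
    pb_c = math.log((parent.visit_count + pb_c_base + 1) / pb_c_base) + pb_c_init
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)
    prior_score = pb_c * child.prior
    # Exploitation term: normalised one-step return, only for visited children.
    if child.visit_count > 0:
        value_score = min_max_stats.normalize(
            child.reward + discount * child.value()
        )
    else:
        value_score = 0
    return prior_score + value_score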
def update_weights(self, batch):
    """
    Perform one training step.
    """
    (
        observation_batch,
        action_batch,
        target_value,
        target_reward,
        target_policy,
        weight_batch,
        gradient_scale_batch,
    ) = batch

    # Keep values as scalars for calculating the priorities for the prioritized replay
    target_value_scalar = numpy.array(target_value, dtype="float32")
    priorities = numpy.zeros_like(target_value_scalar)

    device = next(self.model.parameters()).device
    if self.config.PER:
        weight_batch = torch.tensor(weight_batch.copy()).float().to(device)
    observation_batch = torch.tensor(observation_batch).float().to(device)
    action_batch = torch.tensor(action_batch).long().to(device).unsqueeze(-1)
    target_value = torch.tensor(target_value).float().to(device)
    target_reward = torch.tensor(target_reward).float().to(device)
    target_policy = torch.tensor(target_policy).float().to(device)
    gradient_scale_batch = torch.tensor(gradient_scale_batch).float().to(device)
    # observation_batch: batch, channels, height, width
    # action_batch: batch, num_unroll_steps+1, 1 (unsqueeze)
    # target_value: batch, num_unroll_steps+1
    # target_reward: batch, num_unroll_steps+1
    # target_policy: batch, num_unroll_steps+1, len(action_space)
    # gradient_scale_batch: batch, num_unroll_steps+1

    target_value = network.scalar_to_support(target_value, self.config.support_size)
    target_reward = network.scalar_to_support(
        target_reward, self.config.support_size
    )
    # target_value: batch, num_unroll_steps+1, 2*support_size+1
    # target_reward: batch, num_unroll_steps+1, 2*support_size+1

    ## Generate predictions
    value, reward, policy_logits, hidden_state = self.model.initial_inference(
        observation_batch
    )
    predictions = [(value, reward, policy_logits)]
    for i in range(1, action_batch.shape[1]):
        value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
            hidden_state, action_batch[:, i]
        )
        # Scale the gradient at the start of the dynamics function (See paper appendix Training)
        hidden_state.register_hook(lambda grad: grad * 0.5)
        predictions.append((value, reward, policy_logits))
    # predictions: num_unroll_steps+1, 3, batch, 2*support_size+1 | 2*support_size+1 | 9 (according to the 2nd dim)

    ## Compute losses
    value_loss, reward_loss, policy_loss = (0, 0, 0)
    value, reward, policy_logits = predictions[0]
    # Ignore reward loss for the first batch step
    current_value_loss, _, current_policy_loss = self.loss_function(
        value.squeeze(-1),
        reward.squeeze(-1),
        policy_logits,
        target_value[:, 0],
        target_reward[:, 0],
        target_policy[:, 0],
    )
    value_loss += current_value_loss
    policy_loss += current_policy_loss
    # Compute priorities for the prioritized replay (See paper appendix Training)
    pred_value_scalar = (
        network.support_to_scalar(value, self.config.support_size)
        .detach()
        .cpu()
        .numpy()
        .squeeze()
    )
    priorities[:, 0] = (
        numpy.abs(pred_value_scalar - target_value_scalar[:, 0])
        ** self.config.PER_alpha
    )

    for i in range(1, len(predictions)):
        value, reward, policy_logits = predictions[i]
        (
            current_value_loss,
            current_reward_loss,
            current_policy_loss,
        ) = self.loss_function(
            value.squeeze(-1),
            reward.squeeze(-1),
            policy_logits,
            target_value[:, i],
            target_reward[:, i],
            target_policy[:, i],
        )

        # Scale gradient by the number of unroll steps (See paper appendix Training).
        # Bind i as a default argument so each hook uses its own unroll step.
        current_value_loss.register_hook(
            lambda grad, i=i: grad / gradient_scale_batch[:, i]
        )
        current_reward_loss.register_hook(
            lambda grad, i=i: grad / gradient_scale_batch[:, i]
        )
        current_policy_loss.register_hook(
            lambda grad, i=i: grad / gradient_scale_batch[:, i]
        )

        value_loss += current_value_loss
        reward_loss += current_reward_loss
        policy_loss += current_policy_loss

        # Compute priorities for the prioritized replay (See paper appendix Training)
        pred_value_scalar = (
            network.support_to_scalar(value, self.config.support_size)
            .detach()
            .cpu()
            .numpy()
            .squeeze()
        )
        priorities[:, i] = (
            numpy.abs(pred_value_scalar - target_value_scalar[:, i])
            ** self.config.PER_alpha
        )

    # Scale the value loss, paper recommends by 0.25 (See paper appendix Reanalyze)
    loss = value_loss * self.config.value_loss_weight + reward_loss + policy_loss
    if self.config.PER:
        # Correct PER bias by using importance-sampling (IS) weights
        loss *= weight_batch
    # Mean over batch dimension (the pseudocode does a sum)
    loss = loss.mean()

    # Optimize
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    self.training_step += 1

    return (
        priorities,
        # For log purpose
        loss.item(),
        value_loss.mean().item(),
        reward_loss.mean().item(),
        policy_loss.mean().item(),
    )
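# --- Illustrative sketch (not part of the original code) -----------------------
# `self.loss_function` above is assumed to compute a cross-entropy between the
# predicted categorical distributions and the scalar_to_support-encoded targets,
# kept per batch element (no reduction) so that PER weights and gradient scaling
# can be applied afterwards. A minimal version could look like this; the exact
# implementation in the original code may differ.
def loss_function(value, reward, policy_logits, target_value, target_reward, target_policy):
    # Cross-entropy with soft targets: -sum_k target_k * log softmax(prediction)_k
    value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1)
    reward_loss = (-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1)
    policy_loss = (-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(1)
    return value_loss, reward_loss, policy_loss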