def get_virtual_trajectory_from_obs(self, observation, horizon, plot=True, to_play=0):
    """
    MuZero plays a game but uses its model instead of using the environment.
    We still do an MCTS at each step.
    """
    trajectory_info = Trajectoryinfo("Virtual trajectory", self.config)
    root, mcts_info = MCTS(self.config).run(
        self.model, observation, self.config.action_space, to_play, True
    )
    trajectory_info.store_info(root, mcts_info, None, numpy.NaN)

    virtual_to_play = to_play
    for i in range(horizon):
        action = SelfPlay.select_action(root, 0)

        # Players play turn by turn
        if virtual_to_play + 1 < len(self.config.players):
            virtual_to_play = self.config.players[virtual_to_play + 1]
        else:
            virtual_to_play = self.config.players[0]

        # Generate new root
        value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
            root.hidden_state,
            torch.tensor([[action]]).to(root.hidden_state.device),
        )
        value = models.support_to_scalar(value, self.config.support_size).item()
        reward = models.support_to_scalar(reward, self.config.support_size).item()
        root = Node(0)
        root.expand(
            self.config.action_space,
            virtual_to_play,
            reward,
            policy_logits,
            hidden_state,
        )

        root, mcts_info = MCTS(self.config).run(
            self.model, None, self.config.action_space, virtual_to_play, True, root
        )
        trajectory_info.store_info(
            root, mcts_info, action, reward, new_prior_root_value=value
        )

    if plot:
        trajectory_info.plot_trajectory()

    return trajectory_info
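# `SelfPlay.select_action(root, 0)` above is assumed to pick an action from the
# root's children based on their visit counts, with temperature 0 meaning a
# greedy pick. A minimal sketch of that assumed rule, written against a
# hypothetical `node.children` dict mapping actions to child nodes that expose
# a `visit_count` attribute:

import numpy

def select_action_sketch(node, temperature):
    # Gather visit counts for every child of the (root) node.
    visit_counts = numpy.array(
        [child.visit_count for child in node.children.values()], dtype="float32"
    )
    actions = list(node.children.keys())
    if temperature == 0:
        # Greedy: return the most visited action.
        return actions[numpy.argmax(visit_counts)]
    # Otherwise sample proportionally to visit_count ** (1 / temperature).
    distribution = visit_counts ** (1 / temperature)
    distribution = distribution / distribution.sum()
    return numpy.random.choice(actions, p=distribution)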
def reanalyse(self, replay_buffer, shared_storage):
    while ray.get(shared_storage.get_info.remote("num_played_games")) < 1:
        time.sleep(0.1)

    while ray.get(
        shared_storage.get_info.remote("training_step")
    ) < self.config.training_steps and not ray.get(
        shared_storage.get_info.remote("terminate")
    ):
        self.model.set_weights(ray.get(shared_storage.get_info.remote("weights")))

        game_id, game_history, _ = ray.get(
            replay_buffer.sample_game.remote(force_uniform=True)
        )

        # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
        if self.config.use_last_model_value:
            observations = [
                game_history.get_stacked_observations(
                    i, self.config.stacked_observations
                )
                for i in range(len(game_history.root_values))
            ]
            observations = tf.convert_to_tensor(observations, dtype=tf.float32)
            values = models.support_to_scalar(
                self.model.initial_inference(observations)[0],
                self.config.support_size,
            )
            game_history.reanalysed_predicted_root_values = tf.squeeze(values).numpy()

        replay_buffer.update_game_history.remote(game_id, game_history)
        self.num_reanalysed_games += 1
        shared_storage.set_info.remote(
            "num_reanalysed_games", self.num_reanalysed_games
        )
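# `models.support_to_scalar`, used in reanalyse and throughout this file, is
# assumed to invert MuZero's categorical value representation: take the
# expectation over an integer support in [-support_size, support_size], then
# undo the scaling transform h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
# from the paper's appendix. A minimal TensorFlow sketch of that assumption:

import tensorflow as tf

def support_to_scalar_sketch(logits, support_size, eps=0.001):
    # Expectation of the support under the predicted categorical distribution.
    probabilities = tf.nn.softmax(logits, axis=-1)
    support = tf.range(-support_size, support_size + 1, dtype=tf.float32)
    x = tf.reduce_sum(support * probabilities, axis=-1)
    # Invert h: x = sign(x) * (((sqrt(1 + 4*eps*(|x| + 1 + eps)) - 1) / (2*eps))**2 - 1)
    x = tf.sign(x) * (
        ((tf.sqrt(1.0 + 4.0 * eps * (tf.abs(x) + 1.0 + eps)) - 1.0) / (2.0 * eps)) ** 2
        - 1.0
    )
    return x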
def run(
    self,
    model,
    observation,
    legal_actions,
    to_play,
    add_exploration_noise,
    override_root_with=None,
):
    """
    At the root of the search tree we use the representation function to obtain a
    hidden state given the current observation.
    We then run a Monte Carlo Tree Search using only action sequences and the model
    learned by the network.
    """
    if override_root_with:
        root = override_root_with
        root_predicted_value = None
    else:
        root = Node(0)
        observation = (
            torch.tensor(observation)
            .float()
            .unsqueeze(0)
            .to(next(model.parameters()).device)
        )
        (
            root_predicted_value,
            reward,
            policy_logits,
            hidden_state,
        ) = model.initial_inference(observation)
        root_predicted_value = models.support_to_scalar(
            root_predicted_value, self.config.support_size
        ).item()
        reward = models.support_to_scalar(reward, self.config.support_size).item()
        assert (
            legal_actions
        ), f"Legal actions should not be an empty array. Got {legal_actions}."
        assert set(legal_actions).issubset(
            set(self.config.action_space)
        ), "Legal actions should be a subset of the action space."
        root.expand(
            legal_actions,
            to_play,
            reward,
            policy_logits,
            hidden_state,
        )

    if add_exploration_noise:
        root.add_exploration_noise(
            dirichlet_alpha=self.config.root_dirichlet_alpha,
            exploration_fraction=self.config.root_exploration_fraction,
        )

    min_max_stats = MinMaxStats()

    max_tree_depth = 0
    for _ in range(self.config.num_simulations):
        virtual_to_play = to_play
        node = root
        search_path = [node]
        current_tree_depth = 0

        while node.expanded():
            current_tree_depth += 1
            action, node = self.select_child(node, min_max_stats)
            search_path.append(node)

            # Players play turn by turn
            if virtual_to_play + 1 < len(self.config.players):
                virtual_to_play = self.config.players[virtual_to_play + 1]
            else:
                virtual_to_play = self.config.players[0]

        # Inside the search tree we use the dynamics function to obtain the next hidden
        # state given an action and the previous hidden state
        parent = search_path[-2]
        value, reward, policy_logits, hidden_state = model.recurrent_inference(
            parent.hidden_state,
            torch.tensor([[action]]).to(parent.hidden_state.device),
        )
        value = models.support_to_scalar(value, self.config.support_size).item()
        reward = models.support_to_scalar(reward, self.config.support_size).item()
        node.expand(
            self.config.action_space,
            virtual_to_play,
            reward,
            policy_logits,
            hidden_state,
        )

        self.backpropagate(search_path, value, virtual_to_play, min_max_stats)

        max_tree_depth = max(max_tree_depth, current_tree_depth)

    extra_info = {
        "max_tree_depth": max_tree_depth,
        "root_predicted_value": root_predicted_value,
    }
    return root, extra_info
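# `self.select_child(node, min_max_stats)` above is assumed to choose the child
# maximizing a pUCT score as described in the MuZero paper's search appendix. A
# minimal sketch of that score, assuming config fields `pb_c_base`, `pb_c_init`
# and `discount` and a MinMaxStats object with a `normalize` method (none of
# which are shown in this section):

import math

def ucb_score_sketch(parent, child, min_max_stats, pb_c_base, pb_c_init, discount):
    # Exploration term: the child's prior, weighted by parent/child visit counts.
    pb_c = math.log((parent.visit_count + pb_c_base + 1) / pb_c_base) + pb_c_init
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)
    prior_score = pb_c * child.prior

    # Exploitation term: the child's backed-up value, normalized with the
    # min-max statistics gathered during this search.
    # (In two-player games the child's value would typically be negated first.)
    if child.visit_count > 0:
        value_score = min_max_stats.normalize(child.reward + discount * child.value())
    else:
        value_score = 0

    return prior_score + value_score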
def compute_loss():
    nonlocal observation_batch
    nonlocal action_batch
    nonlocal target_value
    nonlocal target_reward
    nonlocal target_policy
    nonlocal weight_batch
    nonlocal gradient_scale_batch

    # Keep values as scalars for calculating the priorities for the prioritized replay
    target_value_scalar = numpy.array(target_value, dtype="float32")
    priorities = numpy.zeros_like(target_value_scalar, dtype="float32")

    if self.config.PER:
        weight_batch = tf.identity(tf.cast(weight_batch, dtype=tf.float32))
    observation_batch = tf.identity(tf.cast(observation_batch, dtype=tf.float32))
    action_batch = tf.expand_dims(tf.identity(action_batch), axis=-1)
    target_value = tf.identity(tf.cast(target_value, dtype=tf.float32))
    target_reward = tf.identity(tf.cast(target_reward, dtype=tf.float32))
    target_policy = tf.identity(tf.cast(target_policy, dtype=tf.float32))
    gradient_scale_batch = tf.identity(tf.cast(gradient_scale_batch, dtype=tf.float32))
    # observation_batch: batch, channels, height, width
    # action_batch: batch, num_unroll_steps+1, 1 (unsqueeze)
    # target_value: batch, num_unroll_steps+1
    # target_reward: batch, num_unroll_steps+1
    # target_policy: batch, num_unroll_steps+1, len(action_space)
    # gradient_scale_batch: batch, num_unroll_steps+1

    target_value = models.scalar_to_support(target_value, self.config.support_size)
    target_reward = models.scalar_to_support(target_reward, self.config.support_size)
    # target_value: batch, num_unroll_steps+1, 2*support_size+1
    # target_reward: batch, num_unroll_steps+1, 2*support_size+1

    # obs batch:    B x H x W x C   128 x 1 x 1 x 4 (cartpole)
    # value/reward: B x N           128 x 21 (cartpole)
    # policy:       B x A           128 x 2 (cartpole)
    # hidden state: B x X           128 x 8 (cartpole)

    ## Generate predictions
    value, reward, policy_logits, hidden_state = self.model.initial_inference(
        observation_batch, training=True
    )
    predictions = [(value, reward, policy_logits)]
    for i in range(1, action_batch.shape[1]):
        value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
            hidden_state, action_batch[:, i], training=True
        )
        # Scale the gradient at the start of the dynamics function (See paper appendix Training)
        hidden_state = scale_gradient(hidden_state, 0.5)
        predictions.append((value, reward, policy_logits))
    # predictions: num_unroll_steps+1, 3, batch, 2*support_size+1 | 2*support_size+1 | 9 (according to the 2nd dim)

    ## Compute losses
    value_loss, reward_loss, policy_loss = (0, 0, 0)
    value, reward, policy_logits = predictions[0]
    value_sq = tf.squeeze(value, axis=-1) if value.shape[-1] == 1 else value
    reward_sq = tf.squeeze(reward, axis=-1) if reward.shape[-1] == 1 else reward
    # Ignore reward loss for the first batch step
    current_value_loss, _, current_policy_loss = self.loss_function(
        value_sq,
        reward_sq,
        policy_logits,
        target_value[:, 0],
        target_reward[:, 0],
        target_policy[:, 0],
    )
    value_loss += current_value_loss
    policy_loss += current_policy_loss
    # Compute priorities for the prioritized replay (See paper appendix Training)
    pred_value_scalar = (
        models.support_to_scalar(value, self.config.support_size).numpy().squeeze()
    )
    priorities[:, 0] = (
        numpy.abs(pred_value_scalar - target_value_scalar[:, 0])
        ** self.config.PER_alpha
    )

    for i in range(1, len(predictions)):
        value, reward, policy_logits = predictions[i]
        value_sq = tf.squeeze(value, axis=-1) if value.shape[-1] == 1 else value
        reward_sq = tf.squeeze(reward, axis=-1) if reward.shape[-1] == 1 else reward
        (
            current_value_loss,
            current_reward_loss,
            current_policy_loss,
        ) = self.loss_function(
            value_sq,
            reward_sq,
            policy_logits,
            target_value[:, i],
            target_reward[:, i],
            target_policy[:, i],
        )

        # Scale gradient by the number of unroll steps (See paper appendix Training)
        current_value_loss = scale_gradient(current_value_loss, gradient_scale_batch[:, i])
        current_reward_loss = scale_gradient(current_reward_loss, gradient_scale_batch[:, i])
        current_policy_loss = scale_gradient(current_policy_loss, gradient_scale_batch[:, i])

        value_loss += current_value_loss
        reward_loss += current_reward_loss
        policy_loss += current_policy_loss

        # Compute priorities for the prioritized replay (See paper appendix Training)
        pred_value_scalar = (
            models.support_to_scalar(value, self.config.support_size).numpy().squeeze()
        )
        priorities[:, i] = (
            numpy.abs(pred_value_scalar - target_value_scalar[:, i])
            ** self.config.PER_alpha
        )

    l2_loss = 0
    for t in self.model.trainable_variables:
        l2_loss += self.config.weight_decay * tf.nn.l2_loss(t).numpy()

    # Scale the value loss, paper recommends by 0.25 (See paper appendix Reanalyze)
    loss = value_loss * self.config.value_loss_weight + reward_loss + policy_loss
    if self.config.PER:
        # Correct PER bias by using importance-sampling (IS) weights
        loss *= weight_batch
    # Mean over batch dimension (pseudocode do a sum)
    loss = tf.math.reduce_mean(loss) + l2_loss

    result.append(priorities)
    # For log purpose
    result.append(loss.numpy())
    result.append(tf.math.reduce_mean(value_loss).numpy())
    result.append(tf.math.reduce_mean(reward_loss).numpy())
    result.append(tf.math.reduce_mean(policy_loss).numpy())

    return loss
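# `scale_gradient`, used in compute_loss above, is not defined in this section.
# A common TensorFlow way to leave the forward value unchanged while scaling
# only the backward gradient (as the MuZero training appendix prescribes) is
# the stop_gradient trick below; this is a sketch of that assumption, not
# necessarily the exact helper used here.

import tensorflow as tf

def scale_gradient_sketch(tensor, scale):
    # Forward pass: tensor * scale + tensor * (1 - scale) == tensor.
    # Backward pass: only the first term carries gradient, so it is scaled by `scale`.
    return tensor * scale + tf.stop_gradient(tensor) * (1.0 - scale)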
def update_weights(self, batch):
    """
    Perform one training step.
    """
    (
        observation_batch,
        action_batch,
        target_value,
        target_reward,
        target_policy,
        weight_batch,
        gradient_scale_batch,
    ) = batch

    # Keep values as scalars for calculating the priorities for the prioritized replay
    target_value_scalar = numpy.array(target_value, dtype="float32")
    priorities = numpy.zeros_like(target_value_scalar)

    device = next(self.model.parameters()).device
    if self.config.PER:
        weight_batch = torch.tensor(weight_batch.copy()).float().to(device)
    observation_batch = torch.tensor(observation_batch).float().to(device)
    action_batch = torch.tensor(action_batch).long().to(device).unsqueeze(-1)
    target_value = torch.tensor(target_value).float().to(device)
    target_reward = torch.tensor(target_reward).float().to(device)
    target_policy = torch.tensor(target_policy).float().to(device)
    gradient_scale_batch = torch.tensor(gradient_scale_batch).float().to(device)
    # observation_batch: batch, channels, height, width
    # action_batch: batch, num_unroll_steps+1, 1 (unsqueeze)
    # target_value: batch, num_unroll_steps+1
    # target_reward: batch, num_unroll_steps+1
    # target_policy: batch, num_unroll_steps+1, len(action_space)
    # gradient_scale_batch: batch, num_unroll_steps+1

    target_value = models.scalar_to_support(target_value, self.config.support_size)
    target_reward = models.scalar_to_support(target_reward, self.config.support_size)
    # target_value: batch, num_unroll_steps+1, 2*support_size+1
    # target_reward: batch, num_unroll_steps+1, 2*support_size+1

    ## Generate predictions
    value, reward, policy_logits, hidden_state = self.model.initial_inference(
        observation_batch
    )
    predictions = [(value, reward, policy_logits)]
    for i in range(1, action_batch.shape[1]):
        value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
            hidden_state, action_batch[:, i]
        )
        # Scale the gradient at the start of the dynamics function (See paper appendix Training)
        hidden_state.register_hook(lambda grad: grad * 0.5)
        predictions.append((value, reward, policy_logits))
    # predictions: num_unroll_steps+1, 3, batch, 2*support_size+1 | 2*support_size+1 | 9 (according to the 2nd dim)

    ## Compute losses
    value_loss, reward_loss, policy_loss = (0, 0, 0)
    value, reward, policy_logits = predictions[0]
    # Ignore reward loss for the first batch step
    current_value_loss, _, current_policy_loss = self.loss_function(
        value.squeeze(-1),
        reward.squeeze(-1),
        policy_logits,
        target_value[:, 0],
        target_reward[:, 0],
        target_policy[:, 0],
    )
    value_loss += current_value_loss
    policy_loss += current_policy_loss
    # Compute priorities for the prioritized replay (See paper appendix Training)
    pred_value_scalar = (
        models.support_to_scalar(value, self.config.support_size)
        .detach()
        .cpu()
        .numpy()
        .squeeze()
    )
    priorities[:, 0] = (
        numpy.abs(pred_value_scalar - target_value_scalar[:, 0])
        ** self.config.PER_alpha
    )

    for i in range(1, len(predictions)):
        value, reward, policy_logits = predictions[i]
        (
            current_value_loss,
            current_reward_loss,
            current_policy_loss,
        ) = self.loss_function(
            value.squeeze(-1),
            reward.squeeze(-1),
            policy_logits,
            target_value[:, i],
            target_reward[:, i],
            target_policy[:, i],
        )

        # Scale gradient by the number of unroll steps (See paper appendix Training).
        # Bind i as a default argument so each hook keeps its own unroll step
        # instead of the loop's final value when backward() runs after the loop.
        current_value_loss.register_hook(
            lambda grad, i=i: grad / gradient_scale_batch[:, i]
        )
        current_reward_loss.register_hook(
            lambda grad, i=i: grad / gradient_scale_batch[:, i]
        )
        current_policy_loss.register_hook(
            lambda grad, i=i: grad / gradient_scale_batch[:, i]
        )

        value_loss += current_value_loss
        reward_loss += current_reward_loss
        policy_loss += current_policy_loss

        # Compute priorities for the prioritized replay (See paper appendix Training)
        pred_value_scalar = (
            models.support_to_scalar(value, self.config.support_size)
            .detach()
            .cpu()
            .numpy()
            .squeeze()
        )
        priorities[:, i] = (
            numpy.abs(pred_value_scalar - target_value_scalar[:, i])
            ** self.config.PER_alpha
        )

    # Scale the value loss, paper recommends by 0.25 (See paper appendix Reanalyze)
    loss = value_loss * self.config.value_loss_weight + reward_loss + policy_loss
    if self.config.PER:
        # Correct PER bias by using importance-sampling (IS) weights
        loss *= weight_batch
    # Mean over batch dimension (pseudocode do a sum)
    loss = loss.mean()

    # Optimize
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    self.training_step += 1

    return (
        priorities,
        # For log purpose
        loss.item(),
        value_loss.mean().item(),
        reward_loss.mean().item(),
        policy_loss.mean().item(),
    )
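# `self.loss_function`, called in update_weights above, is not shown in this
# section. It is assumed to compute a per-sample cross-entropy between the
# categorical targets (support distributions for value/reward, MCTS policy for
# the policy head) and the network's logits, with no reduction over the batch
# dimension. A minimal PyTorch sketch of that assumption:

import torch

def loss_function_sketch(
    value, reward, policy_logits, target_value, target_reward, target_policy
):
    # Cross-entropy against soft targets, returning one scalar per batch element.
    value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1)
    reward_loss = (-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1)
    policy_loss = (-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(1)
    return value_loss, reward_loss, policy_loss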