def get_virtual_trajectory_from_obs(self,
                                        observation,
                                        horizon,
                                        plot=True,
                                        to_play=0):
        """
        MuZero plays a game but uses its model instead of using the environment.
        We still do an MCTS at each step.
        """
        trajectory_info = Trajectoryinfo("Virtual trajectory", self.config)
        root, mcts_info = MCTS(self.config).run(self.model, observation,
                                                self.config.action_space,
                                                to_play, True)
        trajectory_info.store_info(root, mcts_info, None, numpy.nan)

        virtual_to_play = to_play
        for i in range(horizon):
            action = SelfPlay.select_action(root, 0)

            # Players play turn by turn (see the standalone sketch after this function)
            if virtual_to_play + 1 < len(self.config.players):
                virtual_to_play = self.config.players[virtual_to_play + 1]
            else:
                virtual_to_play = self.config.players[0]

            # Generate new root
            value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
                root.hidden_state,
                torch.tensor([[action]]).to(root.hidden_state.device),
            )
            value = models.support_to_scalar(value,
                                             self.config.support_size).item()
            reward = models.support_to_scalar(reward,
                                              self.config.support_size).item()
            root = Node(0)
            root.expand(
                self.config.action_space,
                virtual_to_play,
                reward,
                policy_logits,
                hidden_state,
            )

            root, mcts_info = MCTS(self.config).run(self.model, None,
                                                    self.config.action_space,
                                                    virtual_to_play, True,
                                                    root)
            trajectory_info.store_info(root,
                                       mcts_info,
                                       action,
                                       reward,
                                       new_prior_root_value=value)

        if plot:
            trajectory_info.plot_trajectory()

        return trajectory_info
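# The "Players play turn by turn" bookkeeping above (and again inside MCTS.run
# further down) rotates through config.players, which is assumed here to be a
# list of consecutive player ids such as [0, 1]. A minimal standalone sketch of
# the same rotation, with a quick self-check:
def next_player(current, players):
    """Return the id of the player who acts after `current`."""
    if current + 1 < len(players):
        return players[current + 1]
    return players[0]

assert next_player(0, [0, 1]) == 1
assert next_player(1, [0, 1]) == 0
assert next_player(0, [0]) == 0  # single-player games stay on player 0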
Example #2
    def reanalyse(self, replay_buffer, shared_storage):
        while ray.get(shared_storage.get_info.remote("num_played_games")) < 1:
            time.sleep(0.1)

        while ray.get(shared_storage.get_info.remote(
                "training_step")) < self.config.training_steps and not ray.get(
                    shared_storage.get_info.remote("terminate")):
            self.model.set_weights(
                ray.get(shared_storage.get_info.remote("weights")))

            game_id, game_history, _ = ray.get(
                replay_buffer.sample_game.remote(force_uniform=True))

            # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
            if self.config.use_last_model_value:
                observations = [
                    game_history.get_stacked_observations(
                        i, self.config.stacked_observations)
                    for i in range(len(game_history.root_values))
                ]

                observations = (tf.convert_to_tensor(observations,
                                                     dtype=tf.float32))
                values = models.support_to_scalar(
                    self.model.initial_inference(observations)[0],
                    self.config.support_size,
                )
                game_history.reanalysed_predicted_root_values = (
                    tf.squeeze(values).numpy())

            replay_buffer.update_game_history.remote(game_id, game_history)
            self.num_reanalysed_games += 1
            shared_storage.set_info.remote("num_reanalysed_games",
                                           self.num_reanalysed_games)
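# models.support_to_scalar (used above and throughout these examples) converts
# the network's categorical value/reward head back to a scalar. A standalone
# numpy sketch, assuming the implementation follows the MuZero paper's value
# transform (softmax over 2*support_size+1 bins, expectation over the support,
# then inversion of h(x) = sign(x)(sqrt(|x|+1)-1) + eps*x with eps = 0.001):
import numpy as np

def support_to_scalar_sketch(logits, support_size, eps=0.001):
    """Expected value over the discrete support, then undo the value scaling."""
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    support = np.arange(-support_size, support_size + 1, dtype="float32")
    x = (probs * support).sum(axis=-1)
    return np.sign(x) * (
        ((np.sqrt(1 + 4 * eps * (np.abs(x) + 1 + eps)) - 1) / (2 * eps)) ** 2 - 1
    )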
    def run(self,
            model,
            observation,
            legal_actions,
            to_play,
            add_exploration_noise,
            override_root_with=None):
        """
        At the root of the search tree we use the representation function to obtain a
        hidden state given the current observation.
        We then run a Monte Carlo Tree Search using only action sequences and the model
        learned by the network.
        """
        if override_root_with:
            root = override_root_with
            root_predicted_value = None
        else:
            root = Node(0)
            observation = (torch.tensor(observation).float().unsqueeze(0).to(
                next(model.parameters()).device))
            (
                root_predicted_value,
                reward,
                policy_logits,
                hidden_state,
            ) = model.initial_inference(observation)
            root_predicted_value = models.support_to_scalar(
                root_predicted_value, self.config.support_size).item()
            reward = models.support_to_scalar(reward,
                                              self.config.support_size).item()
            assert (
                legal_actions
            ), f"Legal actions should not be an empty array. Got {legal_actions}."
            assert set(legal_actions).issubset(
                set(self.config.action_space)
            ), "Legal actions should be a subset of the action space."
            root.expand(
                legal_actions,
                to_play,
                reward,
                policy_logits,
                hidden_state,
            )
        if add_exploration_noise:
            root.add_exploration_noise(
                dirichlet_alpha=self.config.root_dirichlet_alpha,
                exploration_fraction=self.config.root_exploration_fraction,
            )

        min_max_stats = MinMaxStats()
        max_tree_depth = 0
        for _ in range(self.config.num_simulations):
            virtual_to_play = to_play
            node = root
            search_path = [node]
            current_tree_depth = 0

            while node.expanded():
                current_tree_depth += 1

                action, node = self.select_child(node, min_max_stats)

                search_path.append(node)

                # Players play turn by turn
                if virtual_to_play + 1 < len(self.config.players):
                    virtual_to_play = self.config.players[virtual_to_play + 1]
                else:
                    virtual_to_play = self.config.players[0]

            # Inside the search tree we use the dynamics function to obtain the next hidden
            # state given an action and the previous hidden state
            parent = search_path[-2]
            value, reward, policy_logits, hidden_state = model.recurrent_inference(
                parent.hidden_state,
                torch.tensor([[action]]).to(parent.hidden_state.device),
            )
            value = models.support_to_scalar(value,
                                             self.config.support_size).item()
            reward = models.support_to_scalar(reward,
                                              self.config.support_size).item()
            node.expand(
                self.config.action_space,
                virtual_to_play,
                reward,
                policy_logits,
                hidden_state,
            )
            self.backpropagate(search_path, value, virtual_to_play,
                               min_max_stats)

            max_tree_depth = max(max_tree_depth, current_tree_depth)

        extra_info = {
            "max_tree_depth": max_tree_depth,
            "root_predicted_value": root_predicted_value,
        }
        return root, extra_info
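# self.select_child (called in the selection loop above) is not shown in this
# snippet. In MuZero it picks the child that maximizes a pUCT score; a minimal
# sketch of that score, assuming the pb_c_base / pb_c_init constants from the
# paper's pseudocode (19652 and 1.25) and a Q value already normalized by
# MinMaxStats:
import math

def ucb_score_sketch(parent_visits, child_visits, child_prior, normalized_q,
                     pb_c_base=19652, pb_c_init=1.25):
    """pUCT: normalized Q(s, a) plus a prior- and visit-count-based exploration bonus."""
    pb_c = math.log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
    pb_c *= math.sqrt(parent_visits) / (child_visits + 1)
    return normalized_q + pb_c * child_prior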
Example #4
        def compute_loss():
            nonlocal observation_batch
            nonlocal action_batch
            nonlocal target_value
            nonlocal target_reward
            nonlocal target_policy
            nonlocal weight_batch
            nonlocal gradient_scale_batch

            # Keep values as scalars for calculating the priorities for the prioritized replay
            target_value_scalar = numpy.array(target_value, dtype="float32")
            priorities = numpy.zeros_like(target_value_scalar, dtype="float32")

            if self.config.PER:
                weight_batch = tf.identity(
                    tf.cast(weight_batch, dtype=tf.float32))
            observation_batch = tf.identity(
                tf.cast(observation_batch, dtype=tf.float32))
            action_batch = tf.expand_dims(tf.identity(action_batch), axis=-1)
            target_value = tf.identity(tf.cast(target_value, dtype=tf.float32))
            target_reward = tf.identity(
                tf.cast(target_reward, dtype=tf.float32))
            target_policy = tf.identity(
                tf.cast(target_policy, dtype=tf.float32))
            gradient_scale_batch = tf.identity(
                tf.cast(gradient_scale_batch, dtype=tf.float32))
            # observation_batch: batch, channels, height, width
            # action_batch: batch, num_unroll_steps+1, 1 (unsqueeze)
            # target_value: batch, num_unroll_steps+1
            # target_reward: batch, num_unroll_steps+1
            # target_policy: batch, num_unroll_steps+1, len(action_space)
            # gradient_scale_batch: batch, num_unroll_steps+1

            target_value = models.scalar_to_support(target_value,
                                                    self.config.support_size)
            target_reward = models.scalar_to_support(target_reward,
                                                     self.config.support_size)
            # target_value: batch, num_unroll_steps+1, 2*support_size+1
            # target_reward: batch, num_unroll_steps+1, 2*support_size+1

            # obs batch
            # B x H x W x C
            # 128 x 1 x 1 x 4 (cartpole)
            # value/reward
            # B x N
            # 128 x 21 (cartpole)
            # policy
            # B x A
            # 128 x 2 (cartpole)
            # hidden state
            # B x X
            # 128 x 8 (cartpole)

            ## Generate predictions
            value, reward, policy_logits, hidden_state = self.model.initial_inference(
                observation_batch, training=True)
            predictions = [(value, reward, policy_logits)]
            for i in range(1, action_batch.shape[1]):
                value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
                    hidden_state, action_batch[:, i], training=True)
                # Scale the gradient at the start of the dynamics function (See paper appendix Training)
                hidden_state = scale_gradient(hidden_state, 0.5)
                predictions.append((value, reward, policy_logits))
            # predictions: num_unroll_steps+1, 3, batch, 2*support_size+1 | 2*support_size+1 | 9 (according to the 2nd dim)

            ## Compute losses
            value_loss, reward_loss, policy_loss = (0, 0, 0)
            value, reward, policy_logits = predictions[0]
            value_sq = tf.squeeze(value,
                                  axis=-1) if value.shape[-1] == 1 else value
            reward_sq = tf.squeeze(
                reward, axis=-1) if reward.shape[-1] == 1 else reward
            # Ignore reward loss for the first batch step
            current_value_loss, _, current_policy_loss = self.loss_function(
                value_sq,
                reward_sq,
                policy_logits,
                target_value[:, 0],
                target_reward[:, 0],
                target_policy[:, 0],
            )
            value_loss += current_value_loss
            policy_loss += current_policy_loss
            # Compute priorities for the prioritized replay (See paper appendix Training)
            pred_value_scalar = (models.support_to_scalar(
                value, self.config.support_size).numpy().squeeze())
            priorities[:, 0] = (
                numpy.abs(pred_value_scalar -
                          target_value_scalar[:, 0])**self.config.PER_alpha)

            for i in range(1, len(predictions)):
                value, reward, policy_logits = predictions[i]
                value_sq = tf.squeeze(
                    value, axis=-1) if value.shape[-1] == 1 else value
                reward_sq = tf.squeeze(
                    reward, axis=-1) if reward.shape[-1] == 1 else reward
                (
                    current_value_loss,
                    current_reward_loss,
                    current_policy_loss,
                ) = self.loss_function(
                    value_sq,
                    reward_sq,
                    policy_logits,
                    target_value[:, i],
                    target_reward[:, i],
                    target_policy[:, i],
                )
                # Scale gradient by the number of unroll steps (See paper appendix Training)
                current_value_loss = scale_gradient(current_value_loss,
                                                    gradient_scale_batch[:, i])
                current_reward_loss = scale_gradient(
                    current_reward_loss, gradient_scale_batch[:, i])
                current_policy_loss = scale_gradient(
                    current_policy_loss, gradient_scale_batch[:, i])

                value_loss += current_value_loss
                reward_loss += current_reward_loss
                policy_loss += current_policy_loss

                # Compute priorities for the prioritized replay (See paper appendix Training)
                pred_value_scalar = (models.support_to_scalar(
                    value, self.config.support_size).numpy().squeeze())
                priorities[:, i] = (numpy.abs(pred_value_scalar -
                                              target_value_scalar[:, i])**
                                    self.config.PER_alpha)

            l2_loss = 0
            for t in self.model.trainable_variables:
                # Keep the regularizer as a tensor so weight decay contributes gradients
                l2_loss += self.config.weight_decay * tf.nn.l2_loss(t)

            # Scale the value loss, paper recommends by 0.25 (See paper appendix Reanalyze)
            loss = value_loss * self.config.value_loss_weight + reward_loss + policy_loss
            if self.config.PER:
                # Correct PER bias by using importance-sampling (IS) weights
                loss *= weight_batch
            # Mean over batch dimension (the pseudocode does a sum)
            loss = tf.math.reduce_mean(loss) + l2_loss

            result.append(priorities)
            # For log purpose
            result.append(loss.numpy())
            result.append(tf.math.reduce_mean(value_loss).numpy())
            result.append(tf.math.reduce_mean(reward_loss).numpy())
            result.append(tf.math.reduce_mean(policy_loss).numpy())

            return loss
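# scale_gradient (used throughout compute_loss above) is not defined in this
# snippet. A typical TensorFlow definition consistent with those call sites,
# following the trick from the MuZero pseudocode: the forward value is
# unchanged while the gradient flowing back is multiplied by `scale`.
import tensorflow as tf

def scale_gradient(tensor, scale):
    """Pass `tensor` through unchanged, but scale its gradient by `scale`."""
    return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)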
    def update_weights(self, batch):
        """
        Perform one training step.
        """

        (
            observation_batch,
            action_batch,
            target_value,
            target_reward,
            target_policy,
            weight_batch,
            gradient_scale_batch,
        ) = batch

        # Keep values as scalars for calculating the priorities for the prioritized replay
        target_value_scalar = numpy.array(target_value, dtype="float32")
        priorities = numpy.zeros_like(target_value_scalar)

        device = next(self.model.parameters()).device
        if self.config.PER:
            weight_batch = torch.tensor(weight_batch.copy()).float().to(device)
        observation_batch = torch.tensor(observation_batch).float().to(device)
        action_batch = torch.tensor(action_batch).long().to(device).unsqueeze(
            -1)
        target_value = torch.tensor(target_value).float().to(device)
        target_reward = torch.tensor(target_reward).float().to(device)
        target_policy = torch.tensor(target_policy).float().to(device)
        gradient_scale_batch = torch.tensor(gradient_scale_batch).float().to(
            device)
        # observation_batch: batch, channels, height, width
        # action_batch: batch, num_unroll_steps+1, 1 (unsqueeze)
        # target_value: batch, num_unroll_steps+1
        # target_reward: batch, num_unroll_steps+1
        # target_policy: batch, num_unroll_steps+1, len(action_space)
        # gradient_scale_batch: batch, num_unroll_steps+1

        target_value = models.scalar_to_support(target_value,
                                                self.config.support_size)
        target_reward = models.scalar_to_support(target_reward,
                                                 self.config.support_size)
        # target_value: batch, num_unroll_steps+1, 2*support_size+1
        # target_reward: batch, num_unroll_steps+1, 2*support_size+1

        ## Generate predictions
        value, reward, policy_logits, hidden_state = self.model.initial_inference(
            observation_batch)
        predictions = [(value, reward, policy_logits)]
        for i in range(1, action_batch.shape[1]):
            value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
                hidden_state, action_batch[:, i])
            # Scale the gradient at the start of the dynamics function (See paper appendix Training)
            hidden_state.register_hook(lambda grad: grad * 0.5)
            predictions.append((value, reward, policy_logits))
        # predictions: num_unroll_steps+1, 3, batch, 2*support_size+1 | 2*support_size+1 | 9 (according to the 2nd dim)

        ## Compute losses
        value_loss, reward_loss, policy_loss = (0, 0, 0)
        value, reward, policy_logits = predictions[0]
        # Ignore reward loss for the first batch step
        current_value_loss, _, current_policy_loss = self.loss_function(
            value.squeeze(-1),
            reward.squeeze(-1),
            policy_logits,
            target_value[:, 0],
            target_reward[:, 0],
            target_policy[:, 0],
        )
        value_loss += current_value_loss
        policy_loss += current_policy_loss
        # Compute priorities for the prioritized replay (See paper appendix Training)
        pred_value_scalar = (models.support_to_scalar(
            value, self.config.support_size).detach().cpu().numpy().squeeze())
        priorities[:, 0] = (
            numpy.abs(pred_value_scalar -
                      target_value_scalar[:, 0])**self.config.PER_alpha)

        for i in range(1, len(predictions)):
            value, reward, policy_logits = predictions[i]
            (
                current_value_loss,
                current_reward_loss,
                current_policy_loss,
            ) = self.loss_function(
                value.squeeze(-1),
                reward.squeeze(-1),
                policy_logits,
                target_value[:, i],
                target_reward[:, i],
                target_policy[:, i],
            )

            # Scale gradient by the number of unroll steps (See paper appendix Training)
            # Bind i as a default argument so each hook divides by its own column
            # when backward() runs after the loop (a plain closure would capture
            # only the final value of i).
            current_value_loss.register_hook(
                lambda grad, i=i: grad / gradient_scale_batch[:, i])
            current_reward_loss.register_hook(
                lambda grad, i=i: grad / gradient_scale_batch[:, i])
            current_policy_loss.register_hook(
                lambda grad, i=i: grad / gradient_scale_batch[:, i])

            value_loss += current_value_loss
            reward_loss += current_reward_loss
            policy_loss += current_policy_loss

            # Compute priorities for the prioritized replay (See paper appendix Training)
            pred_value_scalar = (models.support_to_scalar(
                value,
                self.config.support_size).detach().cpu().numpy().squeeze())
            priorities[:, i] = (
                numpy.abs(pred_value_scalar -
                          target_value_scalar[:, i])**self.config.PER_alpha)

        # Scale the value loss, paper recommends by 0.25 (See paper appendix Reanalyze)
        loss = value_loss * self.config.value_loss_weight + reward_loss + policy_loss
        if self.config.PER:
            # Correct PER bias by using importance-sampling (IS) weights
            loss *= weight_batch
        # Mean over batch dimension (the pseudocode does a sum)
        loss = loss.mean()

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.training_step += 1

        return (
            priorities,
            # For log purpose
            loss.item(),
            value_loss.mean().item(),
            reward_loss.mean().item(),
            policy_loss.mean().item(),
        )
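# The priority updates in both training functions above reduce, per unroll
# position, to priority = |predicted value - target value| ** PER_alpha. A tiny
# numpy illustration (PER_alpha is set to 1.0 here purely for the arithmetic;
# the real exponent comes from self.config.PER_alpha):
import numpy as np

pred_value = np.array([0.8, -0.1, 0.3], dtype="float32")
target_value = np.array([1.0, 0.0, 0.0], dtype="float32")
PER_alpha = 1.0
priorities = np.abs(pred_value - target_value) ** PER_alpha
# priorities == [0.2, 0.1, 0.3]: positions where the value head is further from
# its target are sampled more often by the prioritized replay buffer.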