def get_virtual_trajectory_from_obs(self,
                                        observation,
                                        horizon,
                                        plot=True,
                                        to_play=0):
        """
        MuZero plays a game but uses its model instead of using the environment.
        We still do an MCTS at each step.
        """
        trajectory_info = Trajectoryinfo("Virtual trajectory", self.config)
        root, mcts_info = MCTS(self.config).run(self.model, observation,
                                                self.config.action_space,
                                                to_play, True)
        trajectory_info.store_info(root, mcts_info, None, numpy.NaN)

        virtual_to_play = to_play
        for i in range(horizon):
            action = SelfPlay.select_action(root, 0)

            # Players play turn by turn
            if virtual_to_play + 1 < len(self.config.players):
                virtual_to_play = self.config.players[virtual_to_play + 1]
            else:
                virtual_to_play = self.config.players[0]

            # Generate new root
            # TODO: Test keeping the old root
            value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
                root.hidden_state,
                torch.tensor([[action]]).to(root.hidden_state.device),
            )
            value = models.support_to_scalar(value,
                                             self.config.support_size).item()
            reward = models.support_to_scalar(reward,
                                              self.config.support_size).item()
            root = Node(0)
            root.expand(
                self.config.action_space,
                virtual_to_play,
                reward,
                policy_logits,
                hidden_state,
            )

            root, mcts_info = MCTS(self.config).run(self.model, None,
                                                    self.config.action_space,
                                                    virtual_to_play, True,
                                                    root)
            trajectory_info.store_info(root,
                                       mcts_info,
                                       action,
                                       reward,
                                       new_prior_root_value=value)

        if plot:
            self.plot_trajectory(trajectory_info)

        return trajectory_info
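
SelfPlay.select_action(root, 0), used above, is not included in these snippets. Below is a minimal, self-contained sketch of the usual MuZero rule it implements: sample an action from the root children's visit counts sharpened by a temperature, where temperature 0 degenerates to an argmax. The visit_counts dict and the function name are illustrative stand-ins for root.children.

import numpy

def select_action_from_visits(visit_counts, temperature):
    """Pick an action from the root visit counts.

    visit_counts: dict mapping action -> visit count of the corresponding child.
    temperature 0 is greedy; larger temperatures flatten the distribution.
    """
    actions = list(visit_counts.keys())
    counts = numpy.array([visit_counts[a] for a in actions], dtype="float64")
    if temperature == 0:
        return actions[int(numpy.argmax(counts))]
    distribution = counts ** (1 / temperature)
    distribution /= distribution.sum()
    return int(numpy.random.choice(actions, p=distribution))

# With temperature 0 the most visited action (here 2) is always returned.
print(select_action_from_visits({0: 5, 1: 12, 2: 33}, temperature=0))
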
Example #2
    def reanalyse(self, replay_buffer, shared_storage):
        while shared_storage.get_info("num_played_games") < 1:
            time.sleep(0.1)

        while shared_storage.get_info("training_step") < self.config.training_steps and not shared_storage.get_info("terminate"):
            self.model.set_weights(shared_storage.get_info("weights"))

            game_id, game_history, _ = replay_buffer.sample_game(force_uniform=True)

            # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
            if self.config.use_last_model_value:
                observations = [
                    game_history.get_stacked_observations(
                        i, self.config.stacked_observations
                    )
                    for i in range(len(game_history.root_values))
                ]

                observations = (
                    torch.tensor(observations)
                    .float()
                    .to(next(self.model.parameters()).device)
                )
                values = models.support_to_scalar(
                    self.model.initial_inference(observations)[0],
                    self.config.support_size,
                )
                game_history.reanalysed_predicted_root_values = (
                    torch.squeeze(values).detach().cpu().numpy()
                )

            replay_buffer.update_game_history(game_id, game_history)
            self.num_reanalysed_games += 1
            shared_storage.set_info("num_reanalysed_games", self.num_reanalysed_games)
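
GameHistory.get_stacked_observations, called here and in most of the other examples, is also not shown. The sketch below is a simplified, self-contained version that only stacks past frames along the channel axis with zero padding; the real helper in MuZero-style implementations typically also encodes past actions as extra constant planes, so treat the signature as illustrative.

import numpy

def get_stacked_observations(observation_history, index, num_stacked):
    """Concatenate the frame at `index` with `num_stacked` previous frames
    along the channel axis, zero-padding when the history is too short."""
    stacked = observation_history[index].copy()
    for past in range(index - 1, index - 1 - num_stacked, -1):
        if past >= 0:
            frame = observation_history[past]
        else:
            frame = numpy.zeros_like(observation_history[0])
        stacked = numpy.concatenate((stacked, frame), axis=0)
    return stacked

# Toy check: three 1x2x2 frames with 2 stacked frames -> shape (3, 2, 2).
history = [numpy.full((1, 2, 2), v, dtype="float32") for v in range(3)]
print(get_stacked_observations(history, 2, 2).shape)
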
Example #3
    def compute_value(self, game_history, index):
        # The value target is the discounted root value of the search tree td_steps into the
        # future, plus the discounted sum of all rewards until then.
        bootstrap_index = index + self.config.td_steps
        if bootstrap_index < len(game_history.root_values):
            if self.config.use_last_model_value:
                # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
                observation = (torch.tensor(
                    game_history.get_stacked_observations(
                        bootstrap_index,
                        self.config.stacked_observations)).float().unsqueeze(0)
                               )
                last_step_value = models.support_to_scalar(
                    self.model.initial_inference(observation)[0],
                    self.config.support_size,
                ).item()
            else:
                last_step_value = game_history.root_values[bootstrap_index]

            value = last_step_value * self.config.discount**self.config.td_steps
        else:
            value = 0

        for i, reward in enumerate(
                game_history.reward_history[index + 1:bootstrap_index + 1]):
            value += (reward if game_history.to_play_history[index]
                      == game_history.to_play_history[index + 1 + i] else
                      -reward) * self.config.discount**i

        return value
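
To make the bootstrapping concrete, here is a toy single-player calculation with hypothetical numbers: td_steps=3, discount=0.9, the three rewards observed after `index`, and the root value found td_steps into the future. The target is the discounted reward sum plus the discounted bootstrap value, exactly as in compute_value above (the sign flip only applies when the players alternate).

discount = 0.9
td_steps = 3
rewards_after_index = [1.0, 0.0, 2.0]   # reward_history[index + 1 : index + td_steps + 1]
bootstrap_root_value = 5.0              # root_values[index + td_steps]

value = bootstrap_root_value * discount ** td_steps
for i, reward in enumerate(rewards_after_index):
    value += reward * discount ** i     # single player: rewards keep their sign

print(value)  # 1.0 + 0.0 * 0.9 + 2.0 * 0.81 + 5.0 * 0.729 = 6.265
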
Example #4
    def reanalyse(self, replay_buffer, shared_storage):
        while ray.get(
                shared_storage.get_info.remote())["num_played_games"] < 1:
            time.sleep(0.1)

        while (ray.get(shared_storage.get_info.remote())["training_step"] <
               self.config.training_steps):
            self.model.set_weights(
                copy.deepcopy(ray.get(shared_storage.get_weights.remote())))

            game_id, game_history, _ = ray.get(
                replay_buffer.sample_game.remote(force_uniform=True))

            # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
            if self.config.use_last_model_value:
                observations = [
                    game_history.get_stacked_observations(
                        i, self.config.stacked_observations)
                    for i in range(len(game_history.root_values))
                ]

                observations = (torch.tensor(observations).float().to(
                    self.config.reanalyse_device))
                values = models.support_to_scalar(
                    self.model.initial_inference(observations)[0],
                    self.config.support_size, self.config.epsilon)
                for i in range(len(game_history.root_values)):
                    game_history.root_values[i] = values[i].item()

            replay_buffer.update_game_history.remote(game_id, game_history)
            self.num_reanalysed_games += 1
            shared_storage.set_info.remote("num_reanalysed_games",
                                           self.num_reanalysed_games)
Example #5
    def update_policies(self):
        while True:
            keys = ray.get(self.replay_buffer.get_buffer_keys.remote())
            for game_id in keys:
                remcts_count = 0
                self.latest_network.set_weights(
                    ray.get(self.shared_storage.get_network_weights.remote()))
                self.target_network.set_weights(
                    ray.get(self.shared_storage.get_target_network_weights.
                            remote()))

                game_history = copy.deepcopy(
                    ray.get(
                        self.replay_buffer.get_game_history.remote(game_id)))

                for pos in range(len(game_history.observation_history)):
                    bootstrap_index = pos + self.config.td_steps
                    if bootstrap_index < len(game_history.root_values):
                        if self.config.use_last_model_value:
                            # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
                            observation = torch.tensor(
                                game_history.get_stacked_observations(
                                    bootstrap_index,
                                    self.config.stacked_observations)).float()
                            value = models.support_to_scalar(
                                self.target_network.initial_inference(
                                    observation)[0],
                                self.config.support_size,
                            ).item()
                            game_history.root_values[bootstrap_index] = value

                    if (random.random() < self.config.policy_update_rate
                            and pos < len(game_history.root_values)):
                        with torch.no_grad():
                            stacked_obs = torch.tensor(
                                game_history.get_stacked_observations(
                                    pos,
                                    self.config.stacked_observations)).float()

                            root, _, _ = MCTS(self.config).run(
                                self.latest_network, stacked_obs,
                                game_history.legal_actions[pos],
                                game_history.to_play_history[pos], False)
                            game_history.store_search_statistics(
                                root, self.config.action_space, pos)
                        remcts_count += 1

                self.shared_storage.update_infos.remote(
                    "remcts_count", remcts_count)
                self.shared_storage.update_infos.remote(
                    "reanalyzed_count", len(game_history.priorities))
                self.replay_buffer.update_game.remote(game_history, game_id)
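
store_search_statistics, called above and again in Example #8, is another helper that is not reproduced here. It normally turns the root's child visit counts into a normalized policy target over the full action space; a self-contained sketch, with the visit counts passed in as a plain dict, follows.

def visits_to_policy(visit_counts, action_space):
    """Normalize root child visit counts into a policy target over the
    whole action space; actions that were never visited get probability 0."""
    total = sum(visit_counts.values())
    return [visit_counts.get(action, 0) / total for action in action_space]

# 30 visits split over actions 0 and 2, out of an action space of size 3.
print(visits_to_policy({0: 10, 2: 20}, action_space=[0, 1, 2]))
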
Example #6
    def run(
        self,
        model,
        observation,
        legal_actions,
        to_play,
        add_exploration_noise,
        override_root_with=None,
    ):
        """
        At the root of the search tree we use the representation function to obtain a
        hidden state given the current observation.
        We then run a Monte Carlo Tree Search using only action sequences and the model
        learned by the network.
        """
        if override_root_with:
            root = override_root_with
            root_predicted_value = None
        else:
            root = Node(0)
            observation = (torch.tensor(observation).float().unsqueeze(0).to(
                next(model.parameters()).device))
            (
                root_predicted_value,
                reward,
                policy_logits,
                hidden_state,
            ) = model.initial_inference(observation)
            root_predicted_value = models.support_to_scalar(
                root_predicted_value, self.config.support_size).item()
            reward = models.support_to_scalar(reward,
                                              self.config.support_size).item()
            assert (
                legal_actions
            ), f"Legal actions should not be an empty array. Got {legal_actions}."
            assert set(legal_actions).issubset(
                set(self.config.action_space
                    )), "Legal actions should be a subset of the action space."
            root.expand(
                legal_actions,
                to_play,
                reward,
                policy_logits,
                hidden_state,
            )

        if add_exploration_noise:
            root.add_exploration_noise(
                dirichlet_alpha=self.config.root_dirichlet_alpha,
                exploration_fraction=self.config.root_exploration_fraction,
            )

        min_max_stats = MinMaxStats()

        max_tree_depth = 0
        for _ in range(self.config.num_simulations):
            virtual_to_play = to_play
            node = root
            search_path = [node]
            current_tree_depth = 0

            while node.expanded():
                current_tree_depth += 1
                action, node = self.select_child(node, min_max_stats)
                search_path.append(node)

                # Players play turn by turn
                if virtual_to_play + 1 < len(self.config.players):
                    virtual_to_play = self.config.players[virtual_to_play + 1]
                else:
                    virtual_to_play = self.config.players[0]

            # Inside the search tree we use the dynamics function to obtain the next hidden
            # state given an action and the previous hidden state
            parent = search_path[-2]
            value, reward, policy_logits, hidden_state = model.recurrent_inference(
                parent.hidden_state,
                torch.tensor([[action]]).to(parent.hidden_state.device),
            )
            value = models.support_to_scalar(value,
                                             self.config.support_size).item()
            reward = models.support_to_scalar(reward,
                                              self.config.support_size).item()
            node.expand(
                self.config.action_space,
                virtual_to_play,
                reward,
                policy_logits,
                hidden_state,
            )

            self.backpropagate(search_path, value, virtual_to_play,
                               min_max_stats)

            max_tree_depth = max(max_tree_depth, current_tree_depth)

        extra_info = {
            "max_tree_depth": max_tree_depth,
            "root_predicted_value": root_predicted_value,
        }
        return root, extra_info
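
select_child and backpropagate are not included in this snippet. Child selection in MuZero-style MCTS maximizes the pUCT score from the paper; the sketch below shows that score in isolation, with the constants c1=1.25 and c2=19652 from the paper and the Q value assumed to be already normalized by min_max_stats to [0, 1].

import math

def puct_score(parent_visits, child_visits, child_prior, normalized_q,
               c1=1.25, c2=19652):
    """pUCT score: exploitation term (normalized Q) plus an exploration term
    driven by the prior and the visit counts."""
    pb_c = math.log((parent_visits + c2 + 1) / c2) + c1
    pb_c *= math.sqrt(parent_visits) / (child_visits + 1)
    return normalized_q + pb_c * child_prior

# A rarely visited child with a strong prior can outrank a well visited
# sibling with a slightly better value estimate.
print(puct_score(50, 2, 0.4, 0.50) > puct_score(50, 20, 0.1, 0.55))  # True
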
Example #7
    def update_weights(self, batch):
        """
        Perform one training step.
        """

        (
            observation_batch,
            action_batch,
            target_value,
            target_reward,
            target_policy,
            weight_batch,
            gradient_scale_batch,
        ) = batch

        # Keep values as scalars for calculating the priorities for the prioritized replay
        target_value_scalar = numpy.array(target_value)
        priorities = numpy.zeros_like(target_value_scalar)

        device = next(self.model.parameters()).device
        weight_batch = torch.tensor(weight_batch).float().to(device)
        observation_batch = torch.tensor(observation_batch).float().to(device)
        action_batch = torch.tensor(action_batch).float().to(device).unsqueeze(
            -1)
        target_value = torch.tensor(target_value).float().to(device)
        target_reward = torch.tensor(target_reward).float().to(device)
        target_policy = torch.tensor(target_policy).float().to(device)
        gradient_scale_batch = torch.tensor(gradient_scale_batch).float().to(
            device)
        # observation_batch: batch, channels, height, width
        # action_batch: batch, num_unroll_steps+1, 1 (unsqueeze)
        # target_value: batch, num_unroll_steps+1
        # target_reward: batch, num_unroll_steps+1
        # target_policy: batch, num_unroll_steps+1, len(action_space)
        # gradient_scale_batch: batch, num_unroll_steps+1

        target_value = models.scalar_to_support(target_value,
                                                self.config.support_size)
        target_reward = models.scalar_to_support(target_reward,
                                                 self.config.support_size)
        # target_value: batch, num_unroll_steps+1, 2*support_size+1
        # target_reward: batch, num_unroll_steps+1, 2*support_size+1

        ## Generate predictions
        value, reward, policy_logits, hidden_state = self.model.initial_inference(
            observation_batch)
        predictions = [(value, reward, policy_logits)]
        for i in range(1, action_batch.shape[1]):
            value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
                hidden_state, action_batch[:, i])
            # Scale the gradient at the start of the dynamics function (See paper appendix Training)
            hidden_state.register_hook(lambda grad: grad * 0.5)
            predictions.append((value, reward, policy_logits))
        # predictions: num_unroll_steps+1, 3, batch, 2*support_size+1 | 2*support_size+1 | 9 (according to the 2nd dim)

        ## Compute losses
        value_loss, reward_loss, policy_loss = (0, 0, 0)
        value, reward, policy_logits = predictions[0]
        # Ignore reward loss for the first batch step
        current_value_loss, _, current_policy_loss = self.loss_function(
            value.squeeze(-1),
            reward.squeeze(-1),
            policy_logits,
            target_value[:, 0],
            target_reward[:, 0],
            target_policy[:, 0],
        )
        value_loss += current_value_loss
        policy_loss += current_policy_loss
        # Compute priorities for the prioritized replay (See paper appendix Training)
        pred_value_scalar = (models.support_to_scalar(
            value, self.config.support_size).detach().cpu().numpy().squeeze())
        priorities[:, 0] = (
            numpy.abs(pred_value_scalar -
                      target_value_scalar[:, 0])**self.config.PER_alpha)

        for i in range(1, len(predictions)):
            value, reward, policy_logits = predictions[i]
            (
                current_value_loss,
                current_reward_loss,
                current_policy_loss,
            ) = self.loss_function(
                value.squeeze(-1),
                reward.squeeze(-1),
                policy_logits,
                target_value[:, i],
                target_reward[:, i],
                target_policy[:, i],
            )

            # Scale gradient by the number of unroll steps (See paper appendix Training)
            # Bind i as a default argument: the hooks fire later, during backward()
            current_value_loss.register_hook(
                lambda grad, i=i: grad / gradient_scale_batch[:, i])
            current_reward_loss.register_hook(
                lambda grad, i=i: grad / gradient_scale_batch[:, i])
            current_policy_loss.register_hook(
                lambda grad, i=i: grad / gradient_scale_batch[:, i])

            value_loss += current_value_loss
            reward_loss += current_reward_loss
            policy_loss += current_policy_loss

            # Compute priorities for the prioritized replay (See paper appendix Training)
            pred_value_scalar = (models.support_to_scalar(
                value,
                self.config.support_size).detach().cpu().numpy().squeeze())
            priorities[:, i] = (
                numpy.abs(pred_value_scalar -
                          target_value_scalar[:, i])**self.config.PER_alpha)

        # Scale the value loss, paper recommends by 0.25 (See paper appendix Reanalyze)
        loss = value_loss * self.config.value_loss_weight + reward_loss + policy_loss
        if self.config.PER:
            # Correct PER bias by using importance-sampling (IS) weights
            loss *= weight_batch
        # Mean over the batch dimension (the pseudocode does a sum)
        loss = loss.mean()

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.training_step += 1

        return (
            priorities,
            # For log purpose
            loss.item(),
            value_loss.mean().item(),
            reward_loss.mean().item(),
            policy_loss.mean().item(),
        )
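
The loss_function used above is not part of this snippet. In this family of implementations it is usually the per-sample cross-entropy between the predicted logits and the target distributions for value, reward, and policy (the batch reduction happens later in update_weights); a minimal sketch under that assumption:

import torch

def loss_function(value, reward, policy_logits,
                  target_value, target_reward, target_policy):
    """Per-sample cross-entropy for the three heads; each returned tensor
    has shape (batch,), matching the element-wise weighting in update_weights."""
    value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1)
    reward_loss = (-target_reward * torch.nn.LogSoftmax(dim=1)(reward)).sum(1)
    policy_loss = (-target_policy * torch.nn.LogSoftmax(dim=1)(policy_logits)).sum(1)
    return value_loss, reward_loss, policy_loss

# Toy shapes: batch of 2, support of size 3 (2 * support_size + 1), 4 actions.
value, reward, policy_logits = torch.randn(2, 3), torch.randn(2, 3), torch.randn(2, 4)
targets = [torch.softmax(torch.randn(2, n), dim=1) for n in (3, 3, 4)]
print([loss.shape for loss in loss_function(value, reward, policy_logits, *targets)])
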
Example #8
    def reanalyse(self, replay_buffer, shared_storage):
        while ray.get(shared_storage.get_info.remote("num_played_games")) < 1:
            time.sleep(0.1)

        while ray.get(shared_storage.get_info.remote(
                "training_step")) < self.config.training_steps and not ray.get(
                    shared_storage.get_info.remote("terminate")):
            self.model.set_weights(
                ray.get(shared_storage.get_info.remote("weights")))

            # update target model periodically
            if self.config.use_last_model_value:
                training_step = ray.get(
                    shared_storage.get_info.remote("training_step"))
                if (training_step - self.last_update_step
                    ) >= self.config.value_target_update_freq:
                    self.last_update_step = training_step
                    self.target_model.set_weights(
                        ray.get(shared_storage.get_info.remote("weights")))

            game_id, game_history = ray.get(
                replay_buffer.reanalyse_sample_game.remote())

            # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
            if self.config.use_last_model_value:

                # use the lagging network (representation + value) to obtain updated targets
                if not self.config.use_updated_mcts_value_targets:
                    observations = [
                        game_history.get_stacked_observations(
                            i, self.config.stacked_observations)
                        for i in range(len(game_history.root_values))
                    ]

                    observations = (torch.tensor(observations).float().to(
                        next(self.model.parameters()).device))
                    values = models.support_to_scalar(
                        # use lagging parameters
                        self.target_model.initial_inference(observations)[0],
                        self.config.support_size,
                    )
                    root_values = (
                        torch.squeeze(values).detach().cpu().numpy())

                # re-execute MCTS to update the targets (child_visits and root_values)
                num_positions = len(game_history.root_values)
                game_history.child_visits = []
                game_history.root_values = []
                priorities = []
                for i in range(num_positions):
                    stacked_observations = game_history.get_stacked_observations(
                        i,
                        self.config.stacked_observations,
                    )

                    root, mcts_info = MCTS(self.config).run(
                        # use either fresh or lagging (recent) parameters
                        self.target_model
                        if self.config.use_updated_mcts_value_targets else
                        self.model,
                        stacked_observations,
                        self.game.legal_actions(),
                        self.game.to_play(),
                        True,
                    )

                    game_history.store_search_statistics(
                        root, self.config.action_space)

                # use mcts values targets
                if self.config.use_updated_mcts_value_targets:
                    root_values = game_history.root_values

            # Update PER according to the initial prioritization (See paper appendix Training)
            if self.config.PER:
                priorities = []
                for i, root_value in enumerate(root_values):
                    priority = (numpy.abs(root_value - compute_target_value(
                        game_history, i, self.config.td_steps,
                        self.config.discount))**self.config.PER_alpha)
                    priorities.append(priority)

                game_history.priorities = numpy.array(priorities,
                                                      dtype="float32")
                game_history.game_priority = numpy.max(game_history.priorities)

            game_history.reanalysed_predicted_root_values = root_values
            replay_buffer.update_game_history.remote(game_id, game_history,
                                                     shared_storage)
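
The PER update above follows the prioritized-replay rule from the MuZero appendix: the priority of a position is |root value - n-step target|^alpha, and the game priority is the maximum over its positions. A toy numeric sketch with hypothetical values:

import numpy

PER_alpha = 1.0                                   # hypothetical exponent
root_values = numpy.array([0.8, -0.1, 1.5])       # e.g. reanalysed root values
target_values = numpy.array([0.5, 0.0, 1.5])      # e.g. from compute_target_value

priorities = numpy.abs(root_values - target_values) ** PER_alpha
game_priority = numpy.max(priorities)

print(priorities, game_priority)                  # ~[0.3 0.1 0.] and 0.3
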
Example #9
    def make_target(self, game_history, state_index):
        """
        Generate targets for every unroll steps.
        """
        target_values, target_rewards, target_policies, actions = [], [], [], []
        for current_index in range(
                state_index, state_index + self.config.num_unroll_steps + 1):
            # The value target is the discounted root value of the search tree td_steps into the
            # future, plus the discounted sum of all rewards until then.
            bootstrap_index = current_index + self.config.td_steps
            if bootstrap_index < len(game_history.root_values):
                if self.config.use_last_model_value:
                    # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
                    observation = (torch.tensor(
                        game_history.get_stacked_observations(
                            bootstrap_index, self.config.stacked_observations)
                    ).float().unsqueeze(0))
                    last_step_value = models.support_to_scalar(
                        self.model.initial_inference(observation)[0],
                        self.config.support_size,
                    ).item()
                else:
                    last_step_value = game_history.root_values[bootstrap_index]

                value = last_step_value * self.config.discount**self.config.td_steps
            else:
                value = 0

            for i, reward in enumerate(
                    game_history.reward_history[current_index +
                                                1:bootstrap_index + 1]):
                value += (reward if game_history.to_play_history[current_index]
                          == game_history.to_play_history[current_index + 1 +
                                                          i] else
                          -reward) * self.config.discount**i

            if current_index < len(game_history.root_values):
                target_values.append(value)
                target_rewards.append(
                    game_history.reward_history[current_index])
                target_policies.append(
                    game_history.child_visits[current_index])
                actions.append(game_history.action_history[current_index])
            elif current_index == len(game_history.root_values):
                target_values.append(0)
                target_rewards.append(
                    game_history.reward_history[current_index])
                # Uniform policy
                target_policies.append([
                    1 / len(game_history.child_visits[0])
                    for _ in range(len(game_history.child_visits[0]))
                ])
                actions.append(game_history.action_history[current_index])
            else:
                # States past the end of games are treated as absorbing states
                target_values.append(0)
                target_rewards.append(0)
                # Uniform policy
                target_policies.append([
                    1 / len(game_history.child_visits[0])
                    for _ in range(len(game_history.child_visits[0]))
                ])
                actions.append(numpy.random.choice(
                    game_history.action_history))

        return target_values, target_rewards, target_policies, actions
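
models.scalar_to_support and models.support_to_scalar, used throughout these examples, implement the categorical value representation from the MuZero paper appendix: the scalar is squashed with h(x) = sign(x)(sqrt(|x| + 1) - 1) + eps*x and then split between the two nearest integer bins of a support of size 2*support_size + 1. The sketch below covers the encoding direction for a single scalar, which is simpler than the batched torch version the snippets call.

import math

def scalar_to_support(x, support_size, eps=0.001):
    """Encode one scalar as a categorical distribution over the
    integer support [-support_size, ..., support_size]."""
    # Invertible squashing transform from the MuZero paper appendix.
    x = math.copysign(math.sqrt(abs(x) + 1) - 1, x) + eps * x
    x = max(-support_size, min(support_size, x))
    low = math.floor(x)
    probs = [0.0] * (2 * support_size + 1)
    probs[low + support_size] = 1 - (x - low)
    if x - low > 0:
        probs[low + support_size + 1] = x - low
    return probs

# With support_size=5, the value 2.0 squashes to ~0.734 and is split
# between the bins for 0 (weight ~0.266) and 1 (weight ~0.734).
print(scalar_to_support(2.0, 5))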