def compare_virtual_with_real_trajectories(self,
                                               first_obs,
                                               game,
                                               horizon,
                                               plot=True):
        """
        First, MuZero plays a game but uses its model instead of using the environment.
        Then, MuZero plays the optimal trajectory according precedent trajectory but performs it in the
        real environment until arriving at an action impossible in the real environment.
        It does an MCTS too, but doesn't take it into account.
        All information during the two trajectories are recorded and displayed.
        """
        virtual_trajectory_info = self.get_virtual_trajectory_from_obs(
            first_obs, horizon, False)
        real_trajectory_info = Trajectoryinfo("Real trajectory", self.config)
        trajectory_divergence_index = None
        real_trajectory_end_reason = "Reached horizon"

        # Illegal moves are masked at the root
        root, mcts_info = MCTS(self.config).run(
            self.model,
            first_obs,
            game.legal_actions(),
            game.to_play(),
            True,
        )
        self.plot_mcts(root, plot)
        real_trajectory_info.store_info(root, mcts_info, None, numpy.NaN)
        for i, action in enumerate(virtual_trajectory_info.action_history):
            # Follow virtual trajectory until it reaches an illegal move in the real env
            if action not in game.legal_actions():
                if trajectory_divergence_index is None:
                    trajectory_divergence_index = i
                    real_trajectory_end_reason = f"Virtual trajectory reached an illegal move at timestep {trajectory_divergence_index}."
                break  # Comment out this break to keep playing after the trajectories diverge
                # Only reached if the break above is commented out: fall back to the best
                # action found by the real-environment MCTS instead.
                action = SelfPlay.select_action(root, 0)

            observation, reward, done = game.step(action)
            root, mcts_info = MCTS(self.config).run(
                self.model,
                observation,
                game.legal_actions(),
                game.to_play(),
                True,
            )
            real_trajectory_info.store_info(root, mcts_info, action, reward)
            if done:
                real_trajectory_end_reason = "Real trajectory reached Done"
                break

        if plot:
            virtual_trajectory_info.plot_trajectory()
            real_trajectory_info.plot_trajectory()
            print(real_trajectory_end_reason)

        return (
            virtual_trajectory_info,
            real_trajectory_info,
            trajectory_divergence_index,
        )
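A minimal usage sketch for the method above. The DiagnoseModel wrapper, the games.tictactoe module, and the checkpoint path are assumptions for illustration and may differ in your fork.

# Hedged usage sketch: the wrapper class, game module and checkpoint path are assumptions.
import torch
from diagnose_model import DiagnoseModel  # assumed location of the class defining the method above
from games.tictactoe import Game, MuZeroConfig  # assumed game module

config = MuZeroConfig()
diagnostics = DiagnoseModel(torch.load("results/model.checkpoint"), config)  # hypothetical wrapper

game = Game(seed=0)
first_obs = game.reset()
virtual_info, real_info, divergence_index = diagnostics.compare_virtual_with_real_trajectories(
    first_obs, game, horizon=10, plot=True)
print("Trajectories diverged at step:", divergence_index)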
Example 2
    def get_virtual_trajectory_from_obs(self,
                                        observation,
                                        horizon,
                                        plot=True,
                                        to_play=0):
        """
        MuZero plays a game but uses its model instead of using the environment.
        We still do an MCTS at each step.
        """
        trajectory_info = Trajectoryinfo("Virtual trajectory", self.config)
        root, mcts_info = MCTS(self.config).run(self.model, observation,
                                                self.config.action_space,
                                                to_play, True)
        trajectory_info.store_info(root, mcts_info, None, numpy.NaN)

        virtual_to_play = to_play
        for i in range(horizon):
            action = SelfPlay.select_action(root, 0)

            # Players play turn by turn
            if virtual_to_play + 1 < len(self.config.players):
                virtual_to_play = self.config.players[virtual_to_play + 1]
            else:
                virtual_to_play = self.config.players[0]

            # Generate new root
            # TODO: Test keeping the old root
            value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
                root.hidden_state,
                torch.tensor([[action]]).to(root.hidden_state.device),
            )
            value = models.support_to_scalar(value,
                                             self.config.support_size).item()
            reward = models.support_to_scalar(reward,
                                              self.config.support_size).item()
            root = Node(0)
            root.expand(
                self.config.action_space,
                virtual_to_play,
                reward,
                policy_logits,
                hidden_state,
            )

            root, mcts_info = MCTS(self.config).run(self.model, None,
                                                    self.config.action_space,
                                                    virtual_to_play, True,
                                                    root)
            trajectory_info.store_info(root,
                                       mcts_info,
                                       action,
                                       reward,
                                       new_prior_root_value=value)

        if plot:
            self.plot_trajectory(trajectory_info)

        return trajectory_info
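
The virtual rollout above relies on models.support_to_scalar to turn the categorical value and reward heads back into scalars. Below is a simplified standalone sketch of that conversion, following the invertible value transform described in the MuZero paper appendix; the exact epsilon and tensor handling in models.py may differ.

# Simplified sketch of support_to_scalar over the support {-support_size, ..., support_size}.
# The epsilon constant and the inverse h(x) transform follow the MuZero paper appendix.
import torch

def support_to_scalar_sketch(logits, support_size, eps=0.001):
    probabilities = torch.softmax(logits, dim=1)
    support = torch.arange(-support_size, support_size + 1,
                           dtype=torch.float32, device=logits.device)
    # Expected value under the categorical distribution
    x = torch.sum(support * probabilities, dim=1, keepdim=True)
    # Invert the scaling h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    x = torch.sign(x) * (
        ((torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps)) - 1) / (2 * eps)) ** 2 - 1
    )
    return x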
Example 3
    def update_policies(self):
        while True:
            keys = ray.get(self.replay_buffer.get_buffer_keys.remote())
            for game_id in keys:
                remcts_count = 0
                self.latest_network.set_weights(
                    ray.get(self.shared_storage.get_network_weights.remote()))
                self.target_network.set_weights(
                    ray.get(self.shared_storage.get_target_network_weights.
                            remote()))

                game_history = copy.deepcopy(
                    ray.get(
                        self.replay_buffer.get_game_history.remote(game_id)))

                for pos in range(len(game_history.observation_history)):
                    bootstrap_index = pos + self.config.td_steps
                    if bootstrap_index < len(game_history.root_values):
                        if self.config.use_last_model_value:
                            # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
                            observation = torch.tensor(
                                game_history.get_stacked_observations(
                                    bootstrap_index,
                                    self.config.stacked_observations)).float()
                            value = models.support_to_scalar(
                                self.target_network.initial_inference(
                                    observation)[0],
                                self.config.support_size,
                            ).item()
                            game_history.root_values[bootstrap_index] = value

                    if (random.random() < self.config.policy_update_rate
                            and pos < len(game_history.root_values)):
                        with torch.no_grad():
                            stacked_obs = torch.tensor(
                                game_history.get_stacked_observations(
                                    pos,
                                    self.config.stacked_observations)).float()

                            root, _, _ = MCTS(self.config).run(
                                self.latest_network, stacked_obs,
                                game_history.legal_actions[pos],
                                game_history.to_play_history[pos], False)
                            game_history.store_search_statistics(
                                root, self.config.action_space, pos)
                        remcts_count += 1

                self.shared_storage.update_infos.remote(
                    "remcts_count", remcts_count)
                self.shared_storage.update_infos.remote(
                    "reanalyzed_count", len(game_history.priorities))
                self.replay_buffer.update_game.remote(game_history, game_id)
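
update_policies is written as the endless loop of a Ray worker that keeps rewriting stored games with fresher values and policies. Below is a hedged sketch of how such a worker might be launched next to the shared storage and replay buffer; the actor names and constructor signatures are assumptions.

# Hedged sketch: launching a reanalyse-style worker with Ray.
# SharedStorage, ReplayBuffer and ReanalyseWorker names/signatures are assumptions,
# and each class is assumed to be decorated with @ray.remote.
import importlib
import ray

ray.init()
config = importlib.import_module("games.tictactoe").MuZeroConfig()  # placeholder game config
shared_storage = SharedStorage.remote(config)
replay_buffer = ReplayBuffer.remote(config)
reanalyse_worker = ReanalyseWorker.remote(config, shared_storage, replay_buffer)

# Returns immediately; update_policies loops forever inside the actor.
reanalyse_worker.update_policies.remote()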
Example 4
    def reanalyse_policy_and_value(self, game_history):
        """

        :param self_play.GameHistory game_history:
        :return:
        """
        game_history.reanalysed_predicted_root_values = None
        game_history.reanalysed_child_visits = None
        for i in range(len(game_history.root_values)):
            observation = game_history.get_stacked_observations(
                i, self.config.stacked_observations)
            root, _ = MCTS(self.config).run(
                self.model,
                observation,
                game_history.legal_actions_history[i],
                game_history.to_play_history[i],
                True,
            )
            game_history.store_search_statistics(root,
                                                 self.config.action_space,
                                                 reanalysed=True)
Example 5
def play_against_algorithm(weight_file_path,
                           config_name,
                           seed,
                           algo="expert",
                           render=False):
    np.random.seed(seed)
    torch.manual_seed(seed)

    game_module = importlib.import_module("games." + config_name)
    config = game_module.MuZeroConfig()
    model = models.MuZeroNetwork(config)
    model.set_weights(torch.load(weight_file_path))
    model.eval()

    # Instantiate the opponent class (e.g. "expert" -> Expert), playing as -1 against MuZero as 1
    algo = globals()[algo.capitalize()](-1, 1)

    game = Game(seed)
    observation = game.reset()

    game_history = GameHistory()
    game_history.action_history.append(0)
    game_history.reward_history.append(0)
    game_history.to_play_history.append(game.to_play())
    game_history.legal_actions.append(game.legal_actions())
    game_history.observation_history.append(observation)

    done = False
    depth = 9
    reward = 0

    while not done:
        if game.to_play_real() == -1:
            action = algo(game.get_state(), depth, game.to_play_real())
        else:
            stacked_observations = game_history.get_stacked_observations(
                -1,
                config.stacked_observations,
            )

            root, priority, tree_depth = MCTS(config).run(
                model,
                stacked_observations,
                game.legal_actions(),
                game.to_play(),
                False,
            )

            action = SelfPlay.select_action(
                root,
                0,
            )

            game_history.store_search_statistics(root, config.action_space)
            game_history.priorities.append(priority)
        observation, reward, done = game.step(action)
        if render:
            game.render()
        depth -= 1

        game_history.action_history.append(action)
        game_history.observation_history.append(observation)
        game_history.reward_history.append(reward)
        game_history.to_play_history.append(game.to_play())
        game_history.legal_actions.append(game.legal_actions())

    return reward, TictactoeComp.wins(game.get_state(), 1)
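
A hedged sketch of how play_against_algorithm could be driven to estimate a win rate against the expert opponent over several seeds; the weight file path and config name are placeholders.

# Hedged evaluation sketch: placeholder paths, assumes play_against_algorithm is in scope.
results = []
for seed in range(20):
    reward, muzero_won = play_against_algorithm(
        "results/tictactoe/model.weights", "tictactoe", seed, algo="expert")
    results.append((reward, muzero_won))

wins = sum(1 for _, won in results if won)
print(f"MuZero won {wins}/{len(results)} games against the expert")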
Example 6
def play_against_other(weights1,
                       config1,
                       weights2,
                       config2,
                       seed,
                       render=False):
    np.random.seed(seed)
    torch.manual_seed(seed)
    game_module = importlib.import_module("games." + config1)
    config1 = game_module.MuZeroConfig()
    model1 = models.MuZeroNetwork(config1)
    model1.set_weights(torch.load(weights1))
    model1.eval()

    game_module = importlib.import_module("games." + config2)
    config2 = game_module.MuZeroConfig()
    model2 = models.MuZeroNetwork(config2)
    model2.set_weights(torch.load(weights2))
    model2.eval()

    game = Game(seed)
    observation = game.reset()

    game_history1 = GameHistory()
    game_history1.action_history.append(0)
    game_history1.reward_history.append(0)
    game_history1.to_play_history.append(game.to_play())
    game_history1.legal_actions.append(game.legal_actions())
    observation1 = copy.deepcopy(observation)
    # observation1[0] = -observation1[1]
    # observation1[1] = -observation1[0]
    # observation1[2] = -observation1[2]
    game_history1.observation_history.append(observation1)

    game_history2 = GameHistory()
    game_history2.action_history.append(0)
    game_history2.reward_history.append(0)
    game_history2.to_play_history.append(not game.to_play())
    game_history2.legal_actions.append(game.legal_actions())
    observation2 = copy.deepcopy(observation)
    observation2[0] = -observation2[1]
    observation2[1] = -observation2[0]
    observation2[2] = -observation2[2]
    game_history2.observation_history.append(observation2)

    done = False
    reward = 0

    while not done:
        if game.to_play_real() == 1:
            config = config1
            model = model1
            game_history = game_history1
        else:
            config = config2
            model = model2
            game_history = game_history2

        stacked_observations = game_history.get_stacked_observations(
            -1,
            config.stacked_observations,
        )

        root, priority, tree_depth = MCTS(config).run(
            model,
            stacked_observations,
            game.legal_actions(),
            game.to_play(),
            False,
        )

        action = SelfPlay.select_action(
            root,
            0,
        )

        game_history1.store_search_statistics(root, config.action_space)
        game_history1.priorities.append(priority)
        game_history2.store_search_statistics(root, config.action_space)
        game_history2.priorities.append(priority)
        observation, reward, done = game.step(action)
        if render:
            game.render()

        game_history1.action_history.append(action)
        observation1 = copy.deepcopy(observation)
        # observation1[0] = -observation1[1]
        # observation1[1] = -observation1[0]
        # observation1[2] = -observation1[2]
        game_history1.observation_history.append(observation1)
        game_history1.reward_history.append(reward)
        game_history1.to_play_history.append(game.to_play())
        game_history1.legal_actions.append(game.legal_actions())

        game_history2.action_history.append(action)
        observation2 = copy.deepcopy(observation)
        observation2[0] = -observation2[1]
        observation2[1] = -observation2[0]
        observation2[2] = -observation2[2]
        game_history2.observation_history.append(observation2)
        game_history2.reward_history.append(reward)
        game_history2.to_play_history.append(not game.to_play())
        game_history2.legal_actions.append(game.legal_actions())

    return reward, TictactoeComp.wins(game.get_state(), 1)
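
A hedged sketch of pitting two checkpoints against each other over several seeds with play_against_other; the weight paths and config names are placeholders, and the second return value is read as "player 1 won".

# Hedged comparison sketch: placeholder paths, assumes play_against_other is in scope.
n_games = 10
player1_wins = 0
for seed in range(n_games):
    reward, p1_won = play_against_other(
        "results/run_a/model.weights", "tictactoe",
        "results/run_b/model.weights", "tictactoe",
        seed)
    player1_wins += int(bool(p1_won))
print(f"Checkpoint A won {player1_wins}/{n_games} games against checkpoint B")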
Example 7
    def make_target(self, game_history, state_index):
        """
        Generate targets for every unroll step.
        """
        target_values, target_rewards, target_policies, actions = [], [], [], []
        for current_index in range(
                state_index, state_index + self.config.num_unroll_steps + 1):
            # The value target is the discounted root value of the search tree td_steps into the
            # future, plus the discounted sum of all rewards until then.
            bootstrap_index = current_index + self.config.td_steps
            if bootstrap_index < len(game_history.root_values):
                if self.config.use_last_model_value:
                    # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
                    observation = torch.tensor(
                        game_history.get_stacked_observations(
                            bootstrap_index,
                            self.config.stacked_observations)).float()
                    last_step_value = models.support_to_scalar(
                        self.target_network.initial_inference(observation)[0],
                        self.config.support_size,
                    ).item()
                else:
                    last_step_value = game_history.root_values[bootstrap_index]

                value = last_step_value * self.config.discount**self.config.td_steps
            else:
                value = 0

            for i, reward in enumerate(
                    game_history.reward_history[current_index + 1:bootstrap_index + 1]):
                # Rewards are credited from the perspective of the player to play at current_index
                value += (
                    reward
                    if game_history.to_play_history[current_index]
                    == game_history.to_play_history[current_index + 1 + i]
                    else -reward
                ) * self.config.discount**i

            if current_index < len(game_history.root_values):
                if random.random() < self.config.policy_update_rate:
                    with torch.no_grad():
                        stacked_obs = torch.tensor(
                            game_history.get_stacked_observations(
                                current_index,
                                self.config.stacked_observations)).float()

                        root, _, _ = MCTS(self.config).run(
                            self.latest_network, stacked_obs,
                            game_history.legal_actions[current_index],
                            game_history.to_play_history[current_index], False)
                        game_history.store_search_statistics(
                            root, self.config.action_space, current_index)
                target_values.append(value)
                target_rewards.append(
                    game_history.reward_history[current_index])
                target_policies.append(
                    game_history.child_visits[current_index])
                actions.append(game_history.action_history[current_index])
            elif current_index == len(game_history.root_values):
                target_values.append(0)
                target_rewards.append(
                    game_history.reward_history[current_index])
                # Uniform policy
                target_policies.append([
                    1 / len(game_history.child_visits[0])
                    for _ in range(len(game_history.child_visits[0]))
                ])
                actions.append(game_history.action_history[current_index])
            else:
                # States past the end of games are treated as absorbing states
                target_values.append(0)
                target_rewards.append(0)
                # Uniform policy
                target_policies.append([
                    1 / len(game_history.child_visits[0])
                    for _ in range(len(game_history.child_visits[0]))
                ])
                actions.append(np.random.choice(game_history.action_history))

        return target_values, target_rewards, target_policies, actions
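
The value target computed above is an n-step return: the discounted root value td_steps into the future plus the discounted rewards collected in between (sign-flipped when a reward belongs to the other player). Below is a standalone sketch of that computation for the single-player case, with the perspective flip omitted; the helper name and the example numbers are illustrative only.

# Standalone sketch of the n-step value target (single-player, no perspective flip).
def n_step_value(root_values, rewards, index, td_steps, discount):
    bootstrap_index = index + td_steps
    if bootstrap_index < len(root_values):
        # Bootstrap from the search value td_steps into the future
        value = root_values[bootstrap_index] * discount ** td_steps
    else:
        value = 0.0
    # Add the discounted rewards collected before the bootstrap position
    for i, reward in enumerate(rewards[index + 1:bootstrap_index + 1]):
        value += reward * discount ** i
    return value

# Example: with discount 0.997 and td_steps 3, the target at index 0 bootstraps from the
# root value at index 3 and adds the three intermediate rewards.
print(n_step_value([0.5, 0.6, 0.7, 0.8], [0.0, 0.0, 0.0, 1.0], 0, 3, 0.997))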
    def add_action(self,
                   opponent: str,
                   temperature: float = 0,
                   temperature_threshold: float = 0,
                   human_action: int = None) -> (int, str, list):
        with torch.no_grad():
            if self.done:
                raise ValueError(
                    "The game is already done, but another step was requested.")
            if len(self.game_history.action_history) > self.config.max_moves:
                raise ValueError(
                    "The number of steps already exceeds max_moves.")

            stacked_observations = self.game_history.get_stacked_observations(
                -1,
                self.config.stacked_observations,
            )

            # Choose the action
            action = None
            if opponent == "self":
                root, mcts_info = MCTS(self.config).run(
                    self.model,
                    stacked_observations,
                    self.game.legal_actions(),
                    self.game.to_play(),
                    True,
                )
                action = SelfPlay.select_action(
                    root,
                    temperature
                    if not temperature_threshold
                    or len(self.game_history.action_history) < temperature_threshold
                    else 0,
                )
            elif opponent == "random":
                action, root = numpy.random.choice(
                    self.game.legal_actions()), None
            elif opponent == "expert":
                action, root = self.game.expert_agent(), None
            elif opponent == "human":
                action, root = human_action, None
            else:
                raise ValueError(
                    'Wrong argument: "opponent" argument should be "self", "human", "expert" or "random"'
                )

            # Cast the action to a plain int (guard against a missing action before casting)
            action = int(action) if action is not None else None

            if action is None or action not in self.game.legal_actions():
                if opponent == "human":
                    raise ValueError(
                        f"Requested action '{action}' is illegal in this game."
                    )
                else:
                    raise Exception(
                        f"Calculated action '{action}' by '{opponent}' is illegal in this game."
                    )
            observation, reward, self.done = self.game.step(action)

            self.game_history.store_search_statistics(root,
                                                      self.config.action_space)

            # Record the transition in the game history
            self.game_history.action_history.append(action)
            self.game_history.observation_history.append(observation)
            self.game_history.reward_history.append(reward)
            self.game_history.to_play_history.append(self.game.to_play())

            if isinstance(observation, numpy.ndarray):
                observation = observation.tolist()
            return action, self.game.action_to_string(action), observation
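
A hedged sketch of driving add_action from a simple human-vs-MuZero loop; the GameController name, its constructor, and the mapping of player index 0 to MuZero's turn are assumptions for illustration.

# Hedged interactive loop built on add_action; GameController is a hypothetical wrapper
# exposing the method above together with .done and .game.
controller = GameController(config, checkpoint)

while not controller.done:
    if controller.game.to_play() == 0:
        # MuZero's turn: run MCTS and pick greedily (temperature 0)
        action, action_str, obs = controller.add_action("self", temperature=0)
    else:
        human_action = int(input("Enter your action: "))
        action, action_str, obs = controller.add_action("human", human_action=human_action)
    print(action_str)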