def compare_virtual_with_real_trajectories(self,
                                               first_obs,
                                               game,
                                               horizon,
                                               plot=True):
        """
        First, MuZero plays a game but uses its model instead of using the environment.
        Then, MuZero plays the optimal trajectory according precedent trajectory but performs it in the
        real environment until arriving at an action impossible in the real environment.
        It does an MCTS too, but doesn't take it into account.
        All information during the two trajectories are recorded and displayed.
        """
        virtual_trajectory_info = self.get_virtual_trajectory_from_obs(
            first_obs, horizon, False)
        real_trajectory_info = Trajectoryinfo("Real trajectory", self.config)
        trajectory_divergence_index = None
        real_trajectory_end_reason = "Reached horizon"

        # Illegal moves are masked at the root
        root, mcts_info = MCTS(self.config).run(
            self.model,
            first_obs,
            game.legal_actions(),
            game.to_play(),
            True,
        )
        self.plot_mcts(root, plot)
        real_trajectory_info.store_info(root, mcts_info, None, numpy.nan)
        for i, action in enumerate(virtual_trajectory_info.action_history):
            # Follow the virtual trajectory until it reaches a move that is illegal in the real env
            if action not in game.legal_actions():
                break  # Comment out this line to keep playing after the trajectories diverge
                # The lines below only run when the break above is commented out:
                # fall back to the real MCTS action and record where the divergence happened
                action = SelfPlay.select_action(root, 0)
                if trajectory_divergence_index is None:
                    trajectory_divergence_index = i
                    real_trajectory_end_reason = f"Virtual trajectory reached an illegal move at timestep {trajectory_divergence_index}."

            observation, reward, done = game.step(action)
            root, mcts_info = MCTS(self.config).run(
                self.model,
                observation,
                game.legal_actions(),
                game.to_play(),
                True,
            )
            real_trajectory_info.store_info(root, mcts_info, action, reward)
            if done:
                real_trajectory_end_reason = "Real trajectory reached Done"
                break

        if plot:
            virtual_trajectory_info.plot_trajectory()
            real_trajectory_info.plot_trajectory()
            print(real_trajectory_end_reason)

        return (
            virtual_trajectory_info,
            real_trajectory_info,
            trajectory_divergence_index,
        )
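
    # Minimal usage sketch for the method above (an addition to this excerpt,
    # not from the original source). It assumes the enclosing diagnostic class,
    # which is not shown here, already holds `self.model` and `self.config`;
    # the instance name `diagnoser` and the `Game(seed)` constructor (seen in
    # the later examples) are assumptions.
    #
    #     game = Game(seed=0)
    #     first_obs = game.reset()
    #     virtual, real, divergence_index = diagnoser.compare_virtual_with_real_trajectories(
    #         first_obs, game, horizon=10, plot=True)
    #     print("Trajectories diverged at timestep:", divergence_index)
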
    def get_virtual_trajectory_from_obs(self,
                                        observation,
                                        horizon,
                                        plot=True,
                                        to_play=0):
        """
        MuZero plays a game but uses its model instead of using the environment.
        We still do an MCTS at each step.
        """
        trajectory_info = Trajectoryinfo("Virtual trajectory", self.config)
        root, mcts_info = MCTS(self.config).run(self.model, observation,
                                                self.config.action_space,
                                                to_play, True)
        trajectory_info.store_info(root, mcts_info, None, numpy.nan)

        virtual_to_play = to_play
        for i in range(horizon):
            action = SelfPlay.select_action(root, 0)

            # Players play turn by turn
            if virtual_to_play + 1 < len(self.config.players):
                virtual_to_play = self.config.players[virtual_to_play + 1]
            else:
                virtual_to_play = self.config.players[0]

            # Generate new root
            # TODO: Test keeping the old root
            value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
                root.hidden_state,
                torch.tensor([[action]]).to(root.hidden_state.device),
            )
            value = models.support_to_scalar(value,
                                             self.config.support_size).item()
            reward = models.support_to_scalar(reward,
                                              self.config.support_size).item()
            root = Node(0)
            root.expand(
                self.config.action_space,
                virtual_to_play,
                reward,
                policy_logits,
                hidden_state,
            )

            root, mcts_info = MCTS(self.config).run(self.model, None,
                                                    self.config.action_space,
                                                    virtual_to_play, True,
                                                    root)
            trajectory_info.store_info(root,
                                       mcts_info,
                                       action,
                                       reward,
                                       new_prior_root_value=value)

        if plot:
            self.plot_trajectory(trajectory_info)

        return trajectory_info
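
    # Minimal usage sketch for `get_virtual_trajectory_from_obs` (an addition
    # to this excerpt, not from the original source); `diagnoser` and
    # `first_obs` are the same assumed names as in the sketch above.
    #
    #     trajectory = diagnoser.get_virtual_trajectory_from_obs(
    #         first_obs, horizon=10, plot=True, to_play=0)
    #     trajectory.plot_trajectory()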
Example #3
def play_against_algorithm(weight_file_path,
                           config_name,
                           seed,
                           algo="expert",
                           render=False):
    np.random.seed(seed)
    torch.manual_seed(seed)

    game_module = importlib.import_module("games." + config_name)
    config = game_module.MuZeroConfig()
    model = models.MuZeroNetwork(config)
    model.set_weights(torch.load(weight_file_path))
    model.eval()

    # Look up the opponent algorithm class by name (e.g. "expert" -> Expert) and instantiate it
    algo = globals()[algo.capitalize()](-1, 1)

    game = Game(seed)
    observation = game.reset()

    game_history = GameHistory()
    game_history.action_history.append(0)
    game_history.reward_history.append(0)
    game_history.to_play_history.append(game.to_play())
    game_history.legal_actions.append(game.legal_actions())
    game_history.observation_history.append(observation)

    done = False
    depth = 9  # Remaining search depth for the expert algorithm (tic-tac-toe has at most 9 moves)
    reward = 0

    while not done:
        if game.to_play_real() == -1:
            action = algo(game.get_state(), depth, game.to_play_real())
        else:
            stacked_observations = game_history.get_stacked_observations(
                -1,
                config.stacked_observations,
            )

            root, priority, tree_depth = MCTS(config).run(
                model,
                stacked_observations,
                game.legal_actions(),
                game.to_play(),
                False,
            )

            action = SelfPlay.select_action(
                root,
                0,
            )

            game_history.store_search_statistics(root, config.action_space)
            game_history.priorities.append(priority)
        observation, reward, done = game.step(action)
        if render:
            game.render()
        depth -= 1

        game_history.action_history.append(action)
        game_history.observation_history.append(observation)
        game_history.reward_history.append(reward)
        game_history.to_play_history.append(game.to_play())
        game_history.legal_actions.append(game.legal_actions())

    return reward, TictactoeComp.wins(game.get_state(), 1)
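
# Minimal usage sketch (an addition to this excerpt, not from the original
# source). The weights path and the "tictactoe" config name are hypothetical
# placeholders; `algo` must name a class available in this module's globals
# (e.g. "expert" -> Expert). The second return value indicates whether
# player 1 has won.
#
#     reward, player1_won = play_against_algorithm(
#         "results/tictactoe/model.weights",  # hypothetical path
#         "tictactoe",
#         seed=0,
#         algo="expert",
#         render=True,
#     )
#     print("Final reward:", reward, "player 1 won:", player1_won)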
Example #4
def play_against_other(weights1,
                       config1,
                       weights2,
                       config2,
                       seed,
                       render=False):
    np.random.seed(seed)
    torch.manual_seed(seed)
    game_module = importlib.import_module("games." + config1)
    config1 = game_module.MuZeroConfig()
    model1 = models.MuZeroNetwork(config1)
    model1.set_weights(torch.load(weights1))
    model1.eval()

    game_module = importlib.import_module("games." + config2)
    config2 = game_module.MuZeroConfig()
    model2 = models.MuZeroNetwork(config2)
    model2.set_weights(torch.load(weights2))
    model2.eval()

    game = Game(seed)
    observation = game.reset()

    game_history1 = GameHistory()
    game_history1.action_history.append(0)
    game_history1.reward_history.append(0)
    game_history1.to_play_history.append(game.to_play())
    game_history1.legal_actions.append(game.legal_actions())
    observation1 = copy.deepcopy(observation)
    # observation1[0] = -observation1[1]
    # observation1[1] = -observation1[0]
    # observation1[2] = -observation1[2]
    game_history1.observation_history.append(observation1)

    game_history2 = GameHistory()
    game_history2.action_history.append(0)
    game_history2.reward_history.append(0)
    game_history2.to_play_history.append(not game.to_play())
    game_history2.legal_actions.append(game.legal_actions())
    observation2 = copy.deepcopy(observation)
    observation2[0] = -observation2[1]
    observation2[1] = -observation2[0]
    observation2[2] = -observation2[2]
    game_history2.observation_history.append(observation2)

    done = False
    reward = 0

    while not done:
        if game.to_play_real() == 1:
            config = config1
            model = model1
            game_history = game_history1
        else:
            config = config2
            model = model2
            game_history = game_history2

        stacked_observations = game_history.get_stacked_observations(
            -1,
            config.stacked_observations,
        )

        root, priority, tree_depth = MCTS(config).run(
            model,
            stacked_observations,
            game.legal_actions(),
            game.to_play(),
            False,
        )

        action = SelfPlay.select_action(
            root,
            0,
        )

        game_history1.store_search_statistics(root, config.action_space)
        game_history1.priorities.append(priority)
        game_history2.store_search_statistics(root, config.action_space)
        game_history2.priorities.append(priority)
        observation, reward, done = game.step(action)
        if render:
            game.render()

        game_history1.action_history.append(action)
        observation1 = copy.deepcopy(observation)
        # observation1[0] = -observation1[1]
        # observation1[1] = -observation1[0]
        # observation1[2] = -observation1[2]
        game_history1.observation_history.append(observation1)
        game_history1.reward_history.append(reward)
        game_history1.to_play_history.append(game.to_play())
        game_history1.legal_actions.append(game.legal_actions())

        game_history2.action_history.append(action)
        observation2 = copy.deepcopy(observation)
        observation2[0] = -observation2[1]
        observation2[1] = -observation2[0]
        observation2[2] = -observation2[2]
        game_history2.observation_history.append(observation2)
        game_history2.reward_history.append(reward)
        game_history2.to_play_history.append(not game.to_play())
        game_history2.legal_actions.append(game.legal_actions())

    return reward, TictactoeComp.wins(game.get_state(), 1)
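
# Minimal usage sketch (an addition to this excerpt, not from the original
# source). Both weight paths and config names are hypothetical placeholders;
# as above, the second return value indicates whether player 1 has won.
#
#     reward, player1_won = play_against_other(
#         "results/tictactoe/model1.weights",  # hypothetical path
#         "tictactoe",
#         "results/tictactoe/model2.weights",  # hypothetical path
#         "tictactoe",
#         seed=0,
#         render=False,
#     )
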
    def add_action(self,
                   opponent: str,
                   temperature: float = 0,
                   temperature_threshold: float = 0,
                   human_action: int = None) -> (int, str, list):
        with torch.no_grad():
            if self.done:
                raise ValueError(
                    "Status is already 'done' but another step was requested.")
            if len(self.game_history.action_history) > self.config.max_moves:
                raise ValueError(
                    "The number of steps already exceeds max_moves.")

            stacked_observations = self.game_history.get_stacked_observations(
                -1,
                self.config.stacked_observations,
            )

            # Choose the action
            action = None
            if opponent == "self":
                root, mcts_info = MCTS(self.config).run(
                    self.model,
                    stacked_observations,
                    self.game.legal_actions(),
                    self.game.to_play(),
                    True,
                )
                action = SelfPlay.select_action(
                    root,
                    temperature if not temperature_threshold
                    or len(self.game_history.action_history) <
                    temperature_threshold else 0,
                )
            elif opponent == "random":
                action, root = numpy.random.choice(
                    self.game.legal_actions()), None
            elif opponent == "expert":
                action, root = self.game.expert_agent(), None
            elif opponent == "human":
                action, root = human_action, None
            else:
                raise ValueError(
                    'Wrong argument: "opponent" argument should be "self", "human", "expert" or "random"'
                )

            # Cast the action to a plain Python int, but only if one was provided;
            # casting None would raise a TypeError before the legality check below
            if action is not None:
                action = int(action)

            if action is None or action not in self.game.legal_actions():
                if opponent == "human":
                    raise ValueError(
                        f"Requested action '{action}' is illegal in this game."
                    )
                else:
                    raise Exception(
                        f"Calculated action '{action}' by '{opponent}' is illegal in this game."
                    )
            observation, reward, self.done = self.game.step(action)

            self.game_history.store_search_statistics(root,
                                                      self.config.action_space)

            # Next batch
            self.game_history.action_history.append(action)
            self.game_history.observation_history.append(observation)
            self.game_history.reward_history.append(reward)
            self.game_history.to_play_history.append(self.game.to_play())

            if isinstance(observation, numpy.ndarray):
                observation = observation.tolist()
            return action, self.game.action_to_string(action), observation
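
    # Minimal usage sketch for `add_action` (an addition to this excerpt, not
    # from the original source). It assumes `add_action` belongs to a
    # game-session wrapper class that is not shown here; the instance name
    # `session` is a placeholder.
    #
    #     action, action_string, observation = session.add_action(opponent="self")
    #     # For a human move, pass the chosen action index explicitly:
    #     action, action_string, observation = session.add_action(
    #         opponent="human", human_action=4)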