Example #1
import gc
import time

import numpy as np
import torch

# Project-specific modules (EnvWrapper, CheckpointManager, UCTnode,
# PolicyWrapper) are assumed to be importable from the surrounding repository.


class UCT:
    def __init__(self,
                 env_params,
                 max_steps=1000,
                 max_depth=20,
                 max_width=5,
                 gamma=1.0,
                 policy="Random",
                 seed=123,
                 device=torch.device("cpu")):
        self.env_params = env_params
        self.max_steps = max_steps
        self.max_depth = max_depth
        self.max_width = max_width
        self.gamma = gamma
        self.policy = policy
        self.seed = seed
        self.device = device

        self.policy_wrapper = None

        # Environment
        self.wrapped_env = EnvWrapper(**env_params)

        # Environment properties
        self.action_n = self.wrapped_env.get_action_n()
        self.max_width = min(self.action_n, self.max_width)

        assert self.max_depth > 0 and 0 < self.max_width <= self.action_n

        # Checkpoint data manager
        self.checkpoint_data_manager = CheckpointManager()
        self.checkpoint_data_manager.hock_env("main", self.wrapped_env)

        # For MCTS tree
        self.root_node = None
        self.global_saving_idx = 0

        self.init_policy()

    def init_policy(self):
        self.policy_wrapper = PolicyWrapper(self.policy,
                                            self.env_params["env_name"],
                                            self.action_n, self.device)

    # Entry point of the UCT algorithm
    def simulate_trajectory(self, max_episode_length=-1):
        state = self.wrapped_env.reset()
        accu_reward = 0.0
        done = False
        step_count = 0
        rewards = []
        times = []

        game_start_time = time.time()

        while not done and (max_episode_length == -1
                            or step_count < max_episode_length):
            simulation_start_time = time.time()
            action = self.simulate_single_move(state)
            simulation_end_time = time.time()

            next_state, reward, done = self.wrapped_env.step(action)
            rewards.append(reward)
            times.append(simulation_end_time - simulation_start_time)

            print(
                "> Time step {}, take action {}, immediate reward {}, cumulative reward {}, used {} seconds"
                .format(step_count, action, reward, accu_reward + reward,
                        simulation_end_time - simulation_start_time))

            accu_reward += reward
            state = next_state
            step_count += 1

        game_end_time = time.time()
        print("> game ended. total reward: {}, used time {} s".format(
            accu_reward, game_end_time - game_start_time))

        return accu_reward, np.array(rewards, dtype=np.float32), np.array(
            times, dtype=np.float32)

    def simulate_single_move(self, state):
        # Clear cache
        self.root_node = None
        self.global_saving_idx = 0
        self.checkpoint_data_manager.clear()

        gc.collect()

        # Construct root node
        self.checkpoint_data_manager.checkpoint_env("main",
                                                    self.global_saving_idx)

        self.root_node = UCTnode(action_n=self.action_n,
                                 state=state,
                                 checkpoint_idx=self.global_saving_idx,
                                 parent=None,
                                 tree=self,
                                 is_head=True)

        self.global_saving_idx += 1

        for _ in range(self.max_steps):
            self.simulate_single_step()

        best_action = self.root_node.max_utility_action()

        self.checkpoint_data_manager.load_checkpoint_env(
            "main", self.root_node.checkpoint_idx)

        return best_action

    def simulate_single_step(self):
        # Go into root node
        curr_node = self.root_node

        # Selection
        curr_depth = 1
        while True:
            if curr_node.no_child_available() or (not curr_node.all_child_visited() and
                    curr_node != self.root_node and np.random.random() < 0.5) or \
                    (not curr_node.all_child_visited() and curr_node == self.root_node):
                # Expand if no child is available, or if the root node is not
                # fully visited, or (with probability 1/2) if a non-root node
                # is not fully visited.

                need_expansion = True
                break

            else:
                action = curr_node.select_action()

            curr_node.update_history(action, curr_node.rewards[action])

            if curr_node.dones[action] or curr_depth >= self.max_depth:
                need_expansion = False
                break

            next_node = curr_node.children[action]

            curr_depth += 1
            curr_node = next_node

        # Expansion
        if need_expansion:
            expand_action = curr_node.select_expand_action()

            self.checkpoint_data_manager.load_checkpoint_env(
                "main", curr_node.checkpoint_idx)
            next_state, reward, done = self.wrapped_env.step(expand_action)
            self.checkpoint_data_manager.checkpoint_env(
                "main", self.global_saving_idx)

            curr_node.rewards[expand_action] = reward
            curr_node.dones[expand_action] = done

            curr_node.update_history(action_taken=expand_action, reward=reward)

            curr_node.add_child(expand_action,
                                next_state,
                                self.global_saving_idx,
                                prior_prob=self.get_prior_prob(next_state))
            self.global_saving_idx += 1
        else:
            self.checkpoint_data_manager.load_checkpoint_env(
                "main", curr_node.checkpoint_idx)
            next_state, reward, done = self.wrapped_env.step(action)

            curr_node.rewards[action] = reward
            curr_node.dones[action] = done

        # Simulation: roll out with the default policy until termination.
        # (If the selected or expanded action already ended the episode, the
        # rollout is skipped and the backup uses a zero rollout return.)
        accu_reward = 0.0
        accu_gamma = 1.0

        while not done:
            action = self.get_action(next_state)

            next_state, reward, done = self.wrapped_env.step(action)

            accu_reward += reward * accu_gamma

            accu_gamma *= self.gamma

        # Complete Update
        self.complete_update(curr_node, self.root_node, accu_reward)

    def get_action(self, state):
        return self.policy_wrapper.get_action(state)

    def get_prior_prob(self, state):
        return self.policy_wrapper.get_prior_prob(state)

    def close(self):
        pass

    @staticmethod
    def complete_update(curr_node, curr_node_head, accu_reward):
        while curr_node != curr_node_head:
            accu_reward = curr_node.update(accu_reward)
            curr_node = curr_node.parent

        curr_node_head.update(accu_reward)
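
A minimal usage sketch for the UCT class above, assuming the repository's EnvWrapper accepts an env_name entry in env_params; the key and the environment id below are illustrative assumptions, not the confirmed interface.

# Hypothetical usage of the UCT class; env_params contents are assumed.
if __name__ == "__main__":
    env_params = {"env_name": "PongNoFrameskip-v4"}  # assumed key and value

    agent = UCT(env_params,
                max_steps=128,   # simulations per planned move
                max_depth=20,
                max_width=5,
                gamma=1.0,
                policy="Random")

    total_reward, rewards, times = agent.simulate_trajectory(max_episode_length=100)
    print("Episode return:", total_reward)

    agent.close()
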
Example #2
import gc
import time
import logging
from copy import deepcopy

import numpy as np

# Project-specific modules (EnvWrapper, CheckpointManager, WU_UCTnode,
# PoolManager) are assumed to be importable from the surrounding repository.


class WU_UCT:
    def __init__(self,
                 env_params,
                 max_steps=1000,
                 max_depth=20,
                 max_width=5,
                 gamma=1.0,
                 expansion_worker_num=16,
                 simulation_worker_num=16,
                 policy="Random",
                 seed=123,
                 device="cpu",
                 record_video=False):
        self.env_params = env_params
        self.max_steps = max_steps
        self.max_depth = max_depth
        self.max_width = max_width
        self.gamma = gamma
        self.expansion_worker_num = expansion_worker_num
        self.simulation_worker_num = simulation_worker_num
        self.policy = policy
        self.device = device
        self.record_video = record_video

        # Environment
        record_path = "Records/P-UCT_" + env_params["env_name"] + ".mp4"
        self.wrapped_env = EnvWrapper(**env_params,
                                      enable_record=record_video,
                                      record_path=record_path)

        # Environment properties
        self.action_n = self.wrapped_env.get_action_n()
        self.max_width = min(self.action_n, self.max_width)

        assert self.max_depth > 0 and 0 < self.max_width <= self.action_n

        # Expansion worker pool
        self.expansion_worker_pool = PoolManager(
            worker_num=expansion_worker_num,
            env_params=env_params,
            policy=policy,
            gamma=gamma,
            seed=seed,
            device=device,
            need_policy=False)

        # Simulation worker pool
        self.simulation_worker_pool = PoolManager(
            worker_num=simulation_worker_num,
            env_params=env_params,
            policy=policy,
            gamma=gamma,
            seed=seed,
            device=device,
            need_policy=True)

        # Checkpoint data manager
        self.checkpoint_data_manager = CheckpointManager()
        self.checkpoint_data_manager.hock_env("main", self.wrapped_env)

        # For MCTS tree
        self.root_node = None
        self.global_saving_idx = 0

        # Task recorder
        self.expansion_task_recorder = dict()
        self.unscheduled_expansion_tasks = list()
        self.simulation_task_recorder = dict()
        self.unscheduled_simulation_tasks = list()

        # Simulation count
        self.simulation_count = 0

        # Logging
        logging.basicConfig(filename="Logs/P-UCT_" +
                            self.env_params["env_name"] + "_" +
                            str(self.simulation_worker_num) + ".log",
                            level=logging.INFO)

    # Entry point of the P-UCT algorithm.
    # This is the outer loop of P-UCT simulation, where the P-UCT agent repeatedly plans the best action and
    # interacts with the environment.
    def simulate_trajectory(self, max_episode_length=-1):
        state = self.wrapped_env.reset()
        accu_reward = 0.0
        done = False
        step_count = 0
        rewards = []
        times = []

        game_start_time = time.perf_counter()

        logging.info("Start simulation")

        while not done and (max_episode_length == -1
                            or step_count < max_episode_length):
            # Plan a best action under the current state
            simulation_start_time = time.perf_counter()
            action = self.simulate_single_move(state)
            simulation_end_time = time.perf_counter()

            # Interact with the environment
            next_state, reward, done = self.wrapped_env.step(action)
            rewards.append(reward)
            times.append(simulation_end_time - simulation_start_time)

            message = (
                "> Time step {}, take action {}, immediate reward {}, cumulative reward {}, used {} seconds"
                .format(step_count, action, reward, accu_reward + reward,
                        simulation_end_time - simulation_start_time))
            print(message)
            logging.info(message)

            # Record video
            if self.record_video:
                self.wrapped_env.capture_frame()
                self.wrapped_env.store_video_files()

            # update game status
            accu_reward += reward
            state = next_state
            step_count += 1

        game_end_time = time.perf_counter()
        print("> game ended. total reward: {}, used time {} s".format(
            accu_reward, game_end_time - game_start_time))
        logging.info("> game ended. total reward: {}, used time {} s".format(
            accu_reward, game_end_time - game_start_time))

        return accu_reward, np.array(rewards, dtype=np.float32), np.array(
            times, dtype=np.float32)

    # This is the planning process of P-UCT. Starting from a tree that contains only the root node,
    # P-UCT performs selection, expansion, simulation, and backpropagation on it.
    def simulate_single_move(self, state):
        # Clear cache
        self.root_node = None
        self.global_saving_idx = 0
        self.checkpoint_data_manager.clear()

        # Clear recorders
        self.expansion_task_recorder.clear()
        self.unscheduled_expansion_tasks.clear()
        self.simulation_task_recorder.clear()
        self.unscheduled_simulation_tasks.clear()

        gc.collect()

        # Free all workers
        self.expansion_worker_pool.wait_until_all_envs_idle()
        self.simulation_worker_pool.wait_until_all_envs_idle()

        # Construct root node
        self.checkpoint_data_manager.checkpoint_env("main",
                                                    self.global_saving_idx)
        self.root_node = WU_UCTnode(action_n=self.action_n,
                                    state=state,
                                    checkpoint_idx=self.global_saving_idx,
                                    parent=None,
                                    tree=self,
                                    is_head=True)

        # An index used to retrieve game-states
        self.global_saving_idx += 1

        # t_complete in the original paper; measures the number of completed simulations
        self.simulation_count = 0

        # Repeatedly invoke the master loop (Figure 2 of the paper)
        sim_idx = 0
        while self.simulation_count < self.max_steps:
            self.simulate_single_step(sim_idx)

            sim_idx += 1

        # Select the best root action
        best_action = self.root_node.max_utility_action()

        # Retrieve the game-state before simulation begins
        self.checkpoint_data_manager.load_checkpoint_env(
            "main", self.root_node.checkpoint_idx)

        return best_action

    def simulate_single_step(self, sim_idx):
        # Go into root node
        curr_node = self.root_node

        # Selection
        curr_depth = 1
        while True:
            if curr_node.no_child_available() or (not curr_node.all_child_visited() and
                    curr_node != self.root_node and np.random.random() < 0.5) or \
                    (not curr_node.all_child_visited() and curr_node == self.root_node):
                # Expand if no child is available, or if the root node is not
                # fully visited, or (with probability 1/2) if a non-root node
                # is not fully visited.

                cloned_curr_node = curr_node.shallow_clone()
                checkpoint_data = self.checkpoint_data_manager.retrieve(
                    curr_node.checkpoint_idx)

                # Record the task
                self.expansion_task_recorder[sim_idx] = (checkpoint_data,
                                                         cloned_curr_node,
                                                         curr_node)
                self.unscheduled_expansion_tasks.append(sim_idx)

                need_expansion = True
                break

            else:
                action = curr_node.select_action()

            curr_node.update_history(sim_idx, action,
                                     curr_node.rewards[action])

            if curr_node.dones[action] or curr_depth >= self.max_depth:
                # Terminal action taken or maximum depth exceeded
                need_expansion = False
                break

            if curr_node.children[action] is None:
                need_expansion = False
                break

            next_node = curr_node.children[action]

            curr_depth += 1
            curr_node = next_node

        # Expansion
        if not need_expansion:
            if not curr_node.dones[action]:
                # Reached the maximum depth but the episode has not terminated.
                # Record a simulation task.

                self.simulation_task_recorder[sim_idx] = (
                    action, curr_node, curr_node.checkpoint_idx, None)
                self.unscheduled_simulation_tasks.append(sim_idx)
            else:
                # Reached a terminal node.
                # In this case, update directly.

                self.incomplete_update(curr_node, self.root_node, sim_idx)
                self.complete_update(curr_node, self.root_node, 0.0, sim_idx)

                self.simulation_count += 1

        else:
            # Assign expansion tasks to idle servers
            while len(self.unscheduled_expansion_tasks
                      ) > 0 and self.expansion_worker_pool.has_idle_server():
                # Get a task
                curr_idx = np.random.randint(
                    0, len(self.unscheduled_expansion_tasks))
                task_idx = self.unscheduled_expansion_tasks.pop(curr_idx)

                # Assign the task to server
                checkpoint_data, cloned_curr_node, _ = self.expansion_task_recorder[
                    task_idx]
                self.expansion_worker_pool.assign_expansion_task(
                    checkpoint_data, cloned_curr_node, self.global_saving_idx,
                    task_idx)
                self.global_saving_idx += 1

            # Wait for an expansion task to complete
            if self.expansion_worker_pool.server_occupied_rate() >= 0.99:
                expand_action, next_state, reward, done, checkpoint_data, \
                saving_idx, task_idx = self.expansion_worker_pool.get_complete_expansion_task()

                curr_node = self.expansion_task_recorder.pop(task_idx)[2]
                curr_node.update_history(task_idx, expand_action, reward)

                # Record info
                curr_node.dones[expand_action] = done
                curr_node.rewards[expand_action] = reward

                if done:
                    # If this expansion results in a terminal node, perform the update directly.
                    # (simulation is not needed)

                    self.incomplete_update(curr_node, self.root_node, task_idx)
                    self.complete_update(curr_node, self.root_node, 0.0,
                                         task_idx)

                    self.simulation_count += 1

                else:
                    # Schedule the task to the simulation task buffer.

                    self.checkpoint_data_manager.store(saving_idx,
                                                       checkpoint_data)

                    self.simulation_task_recorder[task_idx] = (
                        expand_action, curr_node, saving_idx,
                        deepcopy(next_state))
                    self.unscheduled_simulation_tasks.append(task_idx)

        # Assign simulation tasks to idle environment servers
        while len(self.unscheduled_simulation_tasks
                  ) > 0 and self.simulation_worker_pool.has_idle_server():
            # Get a task
            idx = np.random.randint(0, len(self.unscheduled_simulation_tasks))
            task_idx = self.unscheduled_simulation_tasks.pop(idx)

            checkpoint_data = self.checkpoint_data_manager.retrieve(
                self.simulation_task_recorder[task_idx][2])

            first_action = None if self.simulation_task_recorder[task_idx][3] \
                is not None else self.simulation_task_recorder[task_idx][0]

            # Assign the task to server
            self.simulation_worker_pool.assign_simulation_task(
                task_idx, checkpoint_data, first_action=first_action)

            # Perform incomplete update
            self.incomplete_update(
                self.simulation_task_recorder[task_idx]
                [1],  # This is the corresponding node
                self.root_node,
                task_idx)

        # Wait for a simulation task to complete
        if self.simulation_worker_pool.server_occupied_rate() >= 0.99:
            args = self.simulation_worker_pool.get_complete_simulation_task()
            if len(args) == 3:
                task_idx, accu_reward, prior_prob = args
            else:
                task_idx, accu_reward, reward, done = args
            expand_action, curr_node, saving_idx, next_state = self.simulation_task_recorder.pop(
                task_idx)

            if len(args) == 4:
                curr_node.rewards[expand_action] = reward
                curr_node.dones[expand_action] = done

            # Add node
            if next_state is not None:
                curr_node.add_child(expand_action,
                                    next_state,
                                    saving_idx,
                                    prior_prob=prior_prob)

            # Complete Update
            self.complete_update(curr_node, self.root_node, accu_reward,
                                 task_idx)

            self.simulation_count += 1

    def close(self):
        # Free sub-processes
        self.expansion_worker_pool.close_pool()
        self.simulation_worker_pool.close_pool()

    # Incomplete update tracks unobserved samples (Algorithm 2 in the paper)
    @staticmethod
    def incomplete_update(curr_node, curr_node_head, idx):
        while curr_node != curr_node_head:
            curr_node.update_incomplete(idx)
            curr_node = curr_node.parent

        curr_node_head.update_incomplete(idx)

    # Complete update tracks the observed samples (Algorithm 3 in the paper)
    @staticmethod
    def complete_update(curr_node, curr_node_head, accu_reward, idx):
        while curr_node != curr_node_head:
            accu_reward = curr_node.update_complete(idx, accu_reward)
            curr_node = curr_node.parent

        curr_node_head.update_complete(idx, accu_reward)
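
A similar sketch for the WU_UCT class, again with assumed env_params contents; the worker counts would typically be set to the number of available CPU cores, and close() should be called to terminate the worker pools.

# Hypothetical usage of the WU_UCT class; env_params contents are assumed.
if __name__ == "__main__":
    env_params = {"env_name": "PongNoFrameskip-v4"}  # assumed key and value

    agent = WU_UCT(env_params,
                   max_steps=128,
                   expansion_worker_num=4,
                   simulation_worker_num=4,
                   policy="Random",
                   device="cpu")

    try:
        total_reward, rewards, times = agent.simulate_trajectory(max_episode_length=100)
        print("Episode return:", total_reward)
    finally:
        agent.close()  # shut down the expansion and simulation worker pools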