Example #1
    def play_game(self):
        """Play the entire game for one epoch."""

        state = self._env.initial_state()
        obs = self._env.initial_obs()
        # Advance past chance nodes to reach the first decision state
        while state.is_chance():
            legal_actions, prob_list = state.chance_outcomes()
            action = np.random.choice(legal_actions, p=prob_list)
            step_record = self._env.step(state, action)
            state = step_record.next_state
            obs = step_record.obs

        # Create the root node and fill its particle bin with states consistent with obs
        root = ObservationNode(obs, depth=0)
        for _ in range(self.n_start_states):
            possible_states, prob_list = self._env.possible_states(obs)
            particle = np.random.choice(possible_states, p=prob_list)
            root.particle_bin.append(particle)

        history = History()

        # Plan and act one step at a time until a terminal state or the depth limit
        while not state.is_terminal() and root.depth < self.max_depth:
            assert not state.is_chance()
            # Get an action by planning
            action = self._solve_one_step(root)
            # Get step result
            step_record = self._env.step(state, action)

            # Show the step
            if not self.quiet:
                print_divider('small')
                console(3, module, "Step: " + str(root.depth))
                step_record.show()

            history.append(step_record)
            state = step_record.next_state

            # Advance through chance nodes to reach the next decision state
            while state.is_chance():
                legal_actions, prob_list = state.chance_outcomes()
                chance_action = np.random.choice(legal_actions, p=prob_list)
                step_record = self._env.step(state, chance_action)
                state = step_record.next_state  # advance, or this loop never terminates
            root = root.find_child(action).find_child(step_record.obs)

        return history
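play_game only touches a small part of the node API: particle_bin, depth, and find_child. A minimal sketch of the two node classes under that assumption (total_reward anticipates the UCT selection in the later examples); the bodies are illustrative, not the library's actual implementation:

    class ActionNode:
        """Tree node for an action; its children are ObservationNodes."""

        def __init__(self, action, depth):
            self.action = action
            self.depth = depth
            self.visit_count = 0
            self.total_reward = 0.0
            self.children = []

        def find_child(self, obs):
            # Return the observation child matching obs, or None
            return next((c for c in self.children if c.obs == obs), None)

    class ObservationNode:
        """Tree node for an observation; holds a particle bin of sampled states."""

        def __init__(self, obs, depth):
            self.obs = obs
            self.depth = depth
            self.visit_count = 0
            self.particle_bin = []  # hidden states consistent with obs
            self.children = []

        def find_child(self, action):
            # Return the action child matching action, or None
            return next((c for c in self.children if c.action == action), None)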
Example #2
    def _apply_tree_policy(self, state, root):
        """Select nodes according to the tree policy in the search tree."""

        visit_path = []
        record_history = History()
        working_state = state
        current_node = root
        depth = root.depth

        # Descend the tree until reaching an unvisited node, a terminal state, or the max depth
        while current_node.visit_count > 0 and not working_state.is_terminal() \
                and depth <= root.depth + self.max_depth:
            # For a new node, initialize its children, then choose a child as normal
            if not current_node.children:
                legal_actions = working_state.legal_actions()
                # Reduce bias from move generation order.
                np.random.shuffle(legal_actions)
                current_node.children = [
                    ActionNode(action, depth) for action in legal_actions
                ]

            # Choose the action child that maximizes the UCT value
            action_child = current_node.find_child_by_uct(self.uct_c)
            current_node = action_child

            # Step the environment, resolving any chance nodes along the way
            step_record = self._env.step(working_state, action_child.action)
            while step_record.next_state.is_chance():
                chance_state = step_record.next_state
                legal_actions, prob_list = chance_state.chance_outcomes()
                chance_action = np.random.choice(legal_actions, p=prob_list)
                step_record = self._env.step(chance_state, chance_action)
            depth += 1

            # Descend to the matching observation child, creating it if absent
            obs_child = current_node.find_child(step_record.obs)
            if not obs_child:
                obs_child = ObservationNode(step_record.obs, depth)
                current_node.children.append(obs_child)

            current_node = obs_child
            working_state = step_record.next_state

            # Record the visited (action, observation) pair and the step
            visit_path.append((action_child, obs_child))
            record_history.append(step_record)

        return visit_path, record_history, working_state
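The snippet calls current_node.find_child_by_uct(self.uct_c) without showing it. A sketch of the standard UCB1 rule it presumably implements, written as a free function and assuming each child tracks visit_count and total_reward as in the node sketch above:

    import math

    def find_child_by_uct(node, uct_c):
        """Select the child of node maximizing Q + uct_c * sqrt(ln(N) / n) (UCB1)."""
        # An unvisited child has an infinite UCT score; pick the first one.
        # Children were shuffled at expansion, so this breaks ties randomly.
        for child in node.children:
            if child.visit_count == 0:
                return child
        log_n = math.log(node.visit_count)
        return max(
            node.children,
            key=lambda c: c.total_reward / c.visit_count
            + uct_c * math.sqrt(log_n / c.visit_count),
        )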
Example #3
    def _rollout(self, state):
        """Rollout method to evaluate a state."""

        history = History()

        # Roll out to a terminal state and return the discounted return
        while not state.is_terminal():
            if state.is_chance():
                legal_actions, prob_list = state.chance_outcomes()
                action = np.random.choice(legal_actions, p=prob_list)
                state = self._env.step(state, action).next_state
            else:
                action = self.rollout_policy(state)
                step_record = self._env.step(state, action)
                state = step_record.next_state

                history.append(step_record)

        return history.get_return(self.discount)
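history.get_return is never shown. A sketch of History under the assumption that each StepRecord carries a scalar reward and that the constructor matches the call in Example #4 (a records list plus a back-reference to the environment):

    class History:
        """Ordered list of StepRecords, as consumed by the snippets above."""

        def __init__(self, records=None, env=None):
            self.records = list(records) if records else []
            self._env = env

        def append(self, step_record):
            self.records.append(step_record)

        def get_return(self, discount):
            # Discounted return: sum over t of discount**t * reward_t,
            # accumulated backwards so each step is discounted exactly once
            ret = 0.0
            for record in reversed(self.records):
                ret = record.reward + discount * ret
            return ret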
Example #4
    def initial_history(self):
        """Get new initial history with the initial record."""

        if not hasattr(self, '_initial_history'):
            # Build the initial history from the initial state and observation
            self._initial_history = History(
                [StepRecord(next_state=self.initial_state(),
                            obs=self.initial_obs())], self)

        return self._initial_history
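Finally, all four snippets read the same few fields off a StepRecord. A minimal sketch of that record with only the members the code actually uses (action, next_state, obs, reward, and the show method from Example #1); field names beyond those are assumptions:

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class StepRecord:
        """One environment transition, with the fields the snippets read."""
        action: Any = None
        next_state: Any = None
        obs: Any = None
        reward: float = 0.0

        def show(self):
            # Print a one-line summary of the transition
            print(f"action={self.action}, obs={self.obs}, reward={self.reward}")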