コード例 #1
0
def tree_str(node: TreeNode,
             edge: Optional[TreeNodeChildEdge] = None,
             level=0,
             highlight_nodes: Union[None, TreeNode, List[TreeNode]] = None,
             print_only_highlighted: bool = False):
    """
    Renders the subtree rooted at the given node as a multi-line string.

    Every printed node occupies its own line (each line starts with a
    newline character). Non-root lines are prefixed with indentation
    proportional to their depth and a "|___edge___ " connector containing
    the string form of the edge leading to the node.

    Args:
        node: root of the subtree to render.
        edge: the edge leading from the parent to ``node``; None for the root.
        level: depth of ``node`` relative to the node the rendering started at.
        highlight_nodes: a single node or a list of nodes to wrap in
            ``Bcolors.YELLOW`` ... ``Bcolors.ENDC``. Matching uses ``==`` / ``in``.
        print_only_highlighted: when True, lines are emitted only for
            highlighted nodes. Non-highlighted nodes are skipped, but their
            descendants are still visited, so highlighted nodes anywhere in
            the tree are printed.

    Returns: the rendered string (empty if nothing was printed)
    """
    indent_len = 25

    if isinstance(highlight_nodes, TreeNode):
        highlight_this_node = highlight_nodes == node
    elif isinstance(highlight_nodes, list):
        highlight_this_node = node in highlight_nodes
    else:
        highlight_this_node = False

    parts = []
    if highlight_this_node or not print_only_highlighted:
        edge_str = str(edge) if edge is not None else ""
        indent = repeat_str((level - 1), repeat_str(indent_len, " "))
        # The root has no incoming edge, so it gets no connector.
        connector = ("|" + fixed_size_str_center(indent_len - 2, edge_str, "_") + " "
                     if level != 0 else "")
        label = (f"{Bcolors.YELLOW}{str(node)}{Bcolors.ENDC}"
                 if highlight_this_node else str(node))
        parts.append("\n" + indent + connector + label)

    # Always descend: previously a non-highlighted node cut off its whole
    # subtree, hiding highlighted descendants when print_only_highlighted
    # was set. A separate loop name avoids shadowing the ``edge`` parameter.
    for child, child_edge in zip(node.get_children(), node.get_children_edges()):
        parts.append(tree_str(child,
                              edge=child_edge,
                              level=level + 1,
                              highlight_nodes=highlight_nodes,
                              print_only_highlighted=print_only_highlighted))

    return "".join(parts)
コード例 #2
0
def follow_most_traversed_child_edge(node: TreeNode) -> TreeNode:
    """
    Selects a child of the given node using the traversal counts of the
    edges leading to the children as weights.

    The actual selection is delegated to ``max_with_probabilities``
    (presumably it returns the child with the largest weight - confirm
    against its definition).

    Returns: the selected child node
    Raises: ValueError if the node's next player is None or outside 0..1
    """
    player = node.next_player
    if player is None or not 0 <= player <= 1:
        raise ValueError("nodes next player is not assigned")

    candidates = node.get_children()
    traversal_counts = []
    for child_edge in node.get_children_edges():
        traversal_counts.append(child_edge.traversals)

    return max_with_probabilities(candidates, traversal_counts)
コード例 #3
0
def expand_node(state_manager: StateManager, node: TreeNode) -> bool:
    """
    Attaches one new child node per successor state of the given node's
    game state, unless that state is terminal.

    Returns: True if children were added, False if the state is terminal
    Raises: ValueError if the node already has children
    """
    if len(node.get_children()) > 0:
        raise ValueError("Should not expand node that already has children")

    current_state = node.game_state
    if state_manager.is_terminal_state(current_state):
        return False

    successors = state_manager.get_successor_states(current_state)
    node.add_children([TreeNode(successor) for successor in successors])
    return True
コード例 #4
0
    def test_tree_search_random():
        """
        Builds a random tree of 11 nodes, prints it, runs ``tree_search`` from
        the root and prints the tree again with the search result highlighted.
        """
        # Grow the tree by attaching each new node to a randomly chosen
        # node that is already part of the tree.
        root = TreeNode(None)
        nodes_so_far = [root]
        for _ in range(10):
            new_node = TreeNode(None)
            parent = random.choice(nodes_so_far)
            parent.add_child(new_node)
            nodes_so_far.append(new_node)

        print_tree(root)

        selected = tree_search(root)

        print_tree(root, highlight_nodes=selected)
コード例 #5
0
    def follow_policy(self, node: TreeNode) -> TreeNode:
        """
        Selects a child of the given node by scoring every child edge with
        its q-value plus/minus the ``uct`` term.

        For player 0 the uct term is added and the selection is delegated to
        ``max_with_probabilities``; otherwise the term is subtracted and
        ``min_with_probabilities`` is used.

        Returns: the selected child node
        Raises: ValueError if the node's next player is None or outside 0..1
        """
        next_player = node.next_player
        if next_player is None or not 0 <= next_player <= 1:
            raise ValueError("nodes next player is not assigned")

        if next_player == 0:
            uct_sign = 1
        else:
            uct_sign = -1

        children = node.get_children()
        # zip keeps the scores aligned with (and no longer than) the children.
        scored_edges = [
            (child_edge.q_value, child_edge.traversals)
            for _, child_edge in zip(children, node.get_children_edges())
        ]
        scores = []
        for q_value, traversals in scored_edges:
            scores.append(q_value + uct_sign * uct(self.uct_c, node.visits, traversals))

        if next_player == 0:
            return max_with_probabilities(children, scores)
        return min_with_probabilities(children, scores)
コード例 #6
0
def perform_episode(
        config: GameSimulatorConfig) -> Tuple[List[GameState], List[TreeNode]]:
    """
    Plays one complete game, choosing every move from the results of
    ``config.simulations_per_move`` tree-search simulations.

    The starting player is taken from the config when it is 0 or 1,
    otherwise it is drawn at random. After each move the chosen child is
    detached with ``copy_and_remove_tree`` and becomes the root for the
    next move.

    Returns: (game states in the order they were reached,
              root nodes used for each of those states)
    """
    n_simulations = config.simulations_per_move
    is_verbose = config.verbose
    show_tree = config.print_tree_every_move

    starting_player = config.starting_player
    if not 0 <= starting_player <= 1:
        starting_player = random.randint(0, 1)

    state_manager = config.game_state_manager
    tree_policy = UctTreePolicy(uct_c=1)
    default_policy = RandomDefaultPolicy(state_manager=state_manager)

    current_root = TreeNode(state_manager.get_initial_state(),
                            next_player=starting_player)
    state_history = []  # states of the game actually played
    root_history = []   # root node used at each of those states

    while True:
        state_history.append(current_root.game_state)
        root_history.append(current_root)

        if is_verbose:
            latest_state = state_history[-1]
            previous_state = state_history[-2] if len(state_history) > 1 else None
            print(state_manager.action_str(latest_state, previous_state))

        # The terminal state is recorded (and printed) before stopping.
        if state_manager.is_terminal_state(current_root.game_state):
            break

        for _ in range(n_simulations):
            perform_simulation(state_manager, current_root, tree_policy,
                               default_policy)

        # Making the actual move: pick the child to continue from and
        # take a detached copy of it as the next root.
        chosen_child = follow_most_traversed_child_edge(current_root)
        detached_child = chosen_child.copy_and_remove_tree()

        if show_tree:
            print_tree(current_root, highlight_nodes=[chosen_child])

        current_root = detached_child

    return state_history, root_history
コード例 #7
0
 def build_tree():
     """
     Builds a fixed three-level tree: a root with two children, each of
     which has two children of its own.

     Returns: the root node
     """
     root, left, right = TreeNode(None), TreeNode(None), TreeNode(None)
     left_leaves = [TreeNode(None), TreeNode(None)]
     right_leaves = [TreeNode(None), TreeNode(None)]

     root.add_children([left, right])
     left.add_children(left_leaves)
     right.add_children(right_leaves)
     return root