Example #1
0
    def MCTS(self,
             state,
             max_node,
             C_puct,
             tau,
             showNQ=False,
             noise=0.,
             random_flip=False):
        """Run a batched Monte Carlo tree search from ``state`` and pick a move.

        The search tree stored in ``self.prev_tree`` is reused between calls.
        When ``self.prev_action`` is set, the search continues from the child
        reached by that action, or from a fresh ``Tree`` if that child was
        never expanded.

        Args:
            state: Position to search from.
            max_node: The search loop runs while the root's visit total at
                entry plus the number of positions evaluated in this call is
                below this value.
            C_puct: Exploration constant forwarded to ``self.select``.
            tau: Temperature. ``0`` picks the most visited action; otherwise
                the action is sampled proportionally to ``N ** (1 / tau)``.
            showNQ: When true, print the root prior, N, Q and value estimate.
            noise: Weight of uniform random noise mixed into the root prior.
                Entries whose prior is exactly 0 stay 0.
            random_flip: Forwarded to ``self.p_array`` / ``self.v_array``.

        Returns:
            ``(action, visit_distribution)`` where ``visit_distribution`` is
            the root visit counts normalised to sum to 1.
        """
        prior = self.p(state)
        # A prior of exactly 0 marks a move that must remain impossible after
        # noise has been mixed in.
        illegal = (prior == 0.)
        root_p = (1. - noise) * prior + noise * np.random.rand(len(prior))
        root_p[illegal] = 0.
        root_p = root_p / np.sum(root_p)

        root_tree = self.prev_tree
        root_tree.s = state
        if self.prev_action is not None:
            if self.prev_action in root_tree.children:
                root_tree = root_tree.children[self.prev_action]
            else:
                root_tree = Tree(state, root_p)
        root_tree.P = root_p

        node_num = np.sum(root_tree.N)
        while node_num < max_node:
            # Select up to n_parallel paths. Virtual loss discourages the
            # same path from being selected again within this batch.
            paths = []
            for _ in range(min(self.n_parallel, max_node)):
                _, _, nodes, actions = self.select(root_tree, C_puct)
                if nodes is None:
                    break
                paths.append((nodes, actions))

                for node, action in zip(nodes, actions):
                    node.N[action] += self.virtual_loss_n
                    node.W[action] -= self.virtual_loss_n
                    node.Q[action] = node.W[action] / node.N[action]

            if not paths:
                # Nothing could be selected, so the tree would be left
                # unchanged and node_num would never advance: stop here
                # instead of looping forever.
                break

            # Undo the virtual loss before the real statistics are backed up.
            for nodes, actions in paths:
                for node, action in zip(nodes, actions):
                    node.N[action] -= self.virtual_loss_n
                    node.W[action] += self.virtual_loss_n
                    if node.N[action] == 0:
                        node.Q[action] = 0.
                    else:
                        node.Q[action] = node.W[action] / node.N[action]

            # Build the position reached at the end of every selected path.
            states = []
            for nodes, actions in paths:
                s = state_copy(nodes[-1].s)
                s.accept_action_str(actionid2str(s, actions[-1]))
                states.append(s)
            node_num += len(states)

            # Evaluate the whole batch at once.
            p = self.p_array(states, random_flip=random_flip)
            v = self.v_array(states, random_flip=random_flip)

            # Expand: attach a child for every non-terminal new position
            # (two paths in one batch may end at the same leaf).
            for i, (s, (nodes, actions)) in enumerate(zip(states, paths)):
                if s.terminate:
                    continue
                leaf = nodes[-1]
                last_action = actions[-1]
                if last_action not in leaf.children:
                    leaf.children[last_action] = Tree(s, p[i])

            # Backup: add the evaluated value to every edge on the path.
            for i, (nodes, actions) in enumerate(paths):
                for node, action in zip(nodes, actions):
                    node.N[action] += 1
                    node.W[action] += v[i]
                    node.Q[action] = node.W[action] / node.N[action]

        if showNQ:
            print("p=")
            self.display_parameter(np.asarray(prior * 1000, dtype="int32"))
            print("N=")
            self.display_parameter(np.asarray(root_tree.N, dtype="int32"))
            print("Q=")
            self.display_parameter(
                np.asarray(root_tree.Q * 1000, dtype="int32"))
            print("v={}".format(self.v(root_tree.s)))

        if tau == 0:
            action = np.argmax(root_tree.N)
        else:
            N2 = np.power(np.asarray(root_tree.N, dtype="float64"), 1. / tau)
            pi = N2 / np.sum(N2)
            action = np.random.choice(len(pi), p=pi)

        # The chosen action may have no child: per the original note, the
        # only actions leading to a leaf are winning ones.
        if action in root_tree.children:
            self.prev_tree = root_tree.children[action]
        return action, root_tree.N / np.sum(root_tree.N)
Example #2
0
    def MCTS(self,
             state,
             max_node,
             C_puct,
             tau,
             showNQ=False,
             noise=0.,
             random_flip=False):
        """Run a batched Monte Carlo tree search from ``state`` and pick a move.

        The search tree stored in ``self.prev_tree`` is reused between calls.
        When ``self.prev_action`` is set, the search continues from the child
        reached by that action, or from a fresh ``Tree`` if that child was
        never expanded. The root that was searched is kept in
        ``self.tree_for_visualize``.

        Args:
            state: Position to search from.
            max_node: The search loop runs while the root's visit total at
                entry plus the number of positions evaluated in this call is
                below this value.
            C_puct: Exploration constant forwarded to ``self.select``.
            tau: Temperature. ``0`` samples uniformly among the most visited
                actions; otherwise the action is sampled proportionally to
                ``N ** (1 / tau)``.
            showNQ: When true, print the root prior, N, Q and value estimate.
            noise: Weight of uniform random noise mixed into the root prior.
                Entries whose prior is exactly 0 stay 0.
            random_flip: Forwarded to ``self.v_array``.

        Returns:
            ``(action, visit_distribution)`` where ``visit_distribution`` is
            the root visit counts normalised to sum to 1.
        """
        # On hold: skip the search when neither side has walls left and there
        # is no branching. Even then the behaviour changes because of
        # prev_tree and related state, so this stays disabled for now.
        #search_node_num = max_node
        #if state.black_walls == 0 and state.white_walls == 0:
        #    x, y = state.color_p(state.turn % 2)
        #    if int(np.sum(state.movable_array(x, y, shortest_only=True))) == 1:
        #        search_node_num = 1
        prior = self.p(state)
        # A prior of exactly 0 marks a move that must remain impossible after
        # noise has been mixed in.
        illegal = (prior == 0.)
        root_p = (1. - noise) * prior + noise * np.random.rand(len(prior))
        root_p[illegal] = 0.
        root_p = root_p / np.sum(root_p)

        root_tree = self.prev_tree
        root_tree.s = state
        if self.prev_action is not None:
            if self.prev_action in root_tree.children:
                root_tree = root_tree.children[self.prev_action]
            else:
                root_tree = Tree(state, root_p)
        root_tree.P = root_p

        node_num = np.sum(root_tree.N)
        while node_num < max_node:
            # Select up to n_parallel paths. Virtual loss discourages the
            # same path from being selected again within this batch.
            paths = []
            for _ in range(min(self.n_parallel, max_node)):
                _, _, nodes, actions = self.select(root_tree, C_puct)
                if nodes is None:
                    break
                paths.append((nodes, actions))

                for node, action in zip(nodes, actions):
                    node.N[action] += self.virtual_loss_n
                    # The sign of Q flips depending on which side is to move.
                    if self.color == node.s.turn % 2:
                        node.W[action] -= self.virtual_loss_n
                    else:
                        node.W[action] += self.virtual_loss_n
                    node.Q[action] = node.W[action] / node.N[action]

            if not paths:
                # Nothing could be selected, so the tree would be left
                # unchanged and node_num would never advance: stop here
                # instead of looping forever.
                break

            # Undo the virtual loss before the real statistics are backed up.
            for nodes, actions in paths:
                for node, action in zip(nodes, actions):
                    node.N[action] -= self.virtual_loss_n
                    if self.color == node.s.turn % 2:
                        node.W[action] += self.virtual_loss_n
                    else:
                        node.W[action] -= self.virtual_loss_n
                    if node.N[action] == 0:
                        node.Q[action] = 0.
                    else:
                        node.Q[action] = node.W[action] / node.N[action]

            # Build the position reached at the end of every selected path.
            states = []
            for nodes, actions in paths:
                s = state_copy(nodes[-1].s)
                s.accept_action_str(actionid2str(s, actions[-1]))
                states.append(s)
            node_num += len(states)

            # Only values are evaluated here; children are created with a
            # prior of None (presumably filled in elsewhere -- verify).
            v = self.v_array(states, random_flip=random_flip)

            # Expand: attach a child for every non-terminal new position
            # (two paths in one batch may end at the same leaf).
            for s, (nodes, actions) in zip(states, paths):
                if s.terminate:
                    continue
                leaf = nodes[-1]
                last_action = actions[-1]
                if last_action not in leaf.children:
                    leaf.children[last_action] = Tree(s, None)

            # Backup: add the evaluated value to every edge on the path.
            for i, (nodes, actions) in enumerate(paths):
                for node, action in zip(nodes, actions):
                    node.N[action] += 1
                    node.W[action] += v[i]
                    node.Q[action] = node.W[action] / node.N[action]

        if showNQ:
            print("p=")
            self.display_parameter(np.asarray(prior * 1000, dtype="int32"))
            print("N=")
            self.display_parameter(np.asarray(root_tree.N, dtype="int32"))
            print("Q=")
            self.display_parameter(
                np.asarray(root_tree.Q * 1000, dtype="int32"))
            print("v={}".format(self.v(root_tree.s)))

        if tau == 0:
            # Keep only the most visited actions so ties are broken at random.
            N2 = root_tree.N * (root_tree.N == np.max(root_tree.N))
        else:
            N2 = np.power(np.asarray(root_tree.N, dtype="float64"), 1. / tau)
        pi = N2 / np.sum(N2)
        action = np.random.choice(len(pi), p=pi)

        # The chosen action may have no child: per the original note, the
        # only actions leading to a leaf are winning ones.
        if action in root_tree.children:
            self.prev_tree = root_tree.children[action]
        self.tree_for_visualize = root_tree
        return action, root_tree.N / np.sum(root_tree.N)