def MCTS(self, state, max_node, C_puct, tau, showNQ=False, noise=0., random_flip=False):
    """Run a batched PUCT Monte-Carlo tree search from ``state`` and choose an action.

    The search reuses ``self.prev_tree`` (descending into its child for
    ``self.prev_action`` when that child exists, otherwise starting a fresh
    ``Tree``). Each iteration selects up to ``self.n_parallel`` paths with
    ``self.select``, applying a virtual loss so consecutive selections
    diverge. It then evaluates the resulting leaf states in one batch with
    ``self.p_array`` / ``self.v_array``, expands the non-terminal leaves and
    backs the leaf values up every path. Iteration stops once the visit
    total (``np.sum(root_tree.N)`` at entry plus the number of evaluated
    states) reaches ``max_node``, or when ``self.select`` yields no path.

    Args:
        state: current game state; also assigned to ``self.prev_tree.s``.
        max_node: target visit total for the search.
        C_puct: exploration constant forwarded to ``self.select``.
        tau: temperature. ``0`` picks the most-visited root action; otherwise
            the action is sampled with probability proportional to
            ``N ** (1 / tau)``.
        showNQ: if true, print the root prior, visit counts, Q values and the
            value estimate of the root state.
        noise: weight of uniform random noise mixed into the root prior;
            actions whose prior is exactly 0 (illegal) are kept at 0.
        random_flip: forwarded to ``self.p_array`` / ``self.v_array``.

    Returns:
        ``(action, pi)``: the chosen action id and the root visit
        distribution ``N / sum(N)``. If the root received no visits at all,
        the noised root prior is used in place of ``N``.
    """
    p = self.p(state)
    illegal = (p == 0.)
    old_p = p  # un-noised prior, only used by the showNQ display
    p = (1. - noise) * p + noise * np.random.rand(len(p))
    p[illegal] = 0.
    p = p / np.sum(p)

    # Reuse the subtree kept from the previous call when possible.
    root_tree = self.prev_tree
    root_tree.s = state
    if self.prev_action is not None:
        if self.prev_action in root_tree.children:
            root_tree = root_tree.children[self.prev_action]
        else:
            root_tree = Tree(state, p)
    root_tree.P = p

    node_num = np.sum(root_tree.N)
    while node_num < max_node:
        # Selection: collect a batch of root-to-leaf paths. The virtual loss
        # applied after each select() steers the next one elsewhere.
        batch_nodes = []
        batch_actions = []
        for _ in range(min(self.n_parallel, max_node)):
            _, _, nodes, actions = self.select(root_tree, C_puct)
            if nodes is None:
                break
            batch_nodes.append(nodes)
            batch_actions.append(actions)
            for node, action in zip(nodes, actions):
                node.N[action] += self.virtual_loss_n
                node.W[action] -= self.virtual_loss_n
                node.Q[action] = node.W[action] / node.N[action]

        if not batch_nodes:
            # Nothing was selected, so node_num cannot grow; without this
            # guard the while loop would never terminate.
            break

        # Undo the virtual loss before the real backup.
        for nodes, actions in zip(batch_nodes, batch_actions):
            for node, action in zip(nodes, actions):
                node.N[action] -= self.virtual_loss_n
                node.W[action] += self.virtual_loss_n
                if node.N[action] == 0:
                    node.Q[action] = 0.
                else:
                    node.Q[action] = node.W[action] / node.N[action]

        # Evaluation: play each path's last action on a copy of its leaf
        # state, then score all resulting states in a single batch.
        states = []
        for nodes, actions in zip(batch_nodes, batch_actions):
            s = state_copy(nodes[-1].s)
            s.accept_action_str(actionid2str(s, actions[-1]))
            states.append(s)
        node_num += len(states)
        leaf_p = self.p_array(states, random_flip=random_flip)
        leaf_v = self.v_array(states, random_flip=random_flip)

        # Expansion: every non-terminal new state becomes a child node whose
        # prior is the policy row at that state's own batch index.
        for i, (s, nodes, actions) in enumerate(zip(states, batch_nodes, batch_actions)):
            if s.terminate:
                continue
            parent = nodes[-1]
            a = actions[-1]
            if a not in parent.children:
                parent.children[a] = Tree(s, leaf_p[i])

        # Backup: add the path's leaf value to every edge along that path.
        for value, nodes, actions in zip(leaf_v, batch_nodes, batch_actions):
            for node, action in zip(nodes, actions):
                node.N[action] += 1
                node.W[action] += value
                node.Q[action] = node.W[action] / node.N[action]

    if showNQ:
        print("p=")
        self.display_parameter(np.asarray(old_p * 1000, dtype="int32"))
        print("N=")
        self.display_parameter(np.asarray(root_tree.N, dtype="int32"))
        print("Q=")
        self.display_parameter(
            np.asarray(root_tree.Q * 1000, dtype="int32"))
        print("v={}".format(self.v(root_tree.s)))

    # With no root visits (max_node <= 0, or select() found nothing) N is all
    # zeros and N / sum(N) would be NaN; fall back to the noised prior.
    counts = root_tree.N if np.sum(root_tree.N) > 0 else p

    if tau == 0:
        action = np.argmax(counts)
    else:
        N2 = np.power(np.asarray(counts, dtype="float64"), 1. / tau)
        pi = N2 / np.sum(N2)
        action = np.random.choice(len(pi), p=pi)

    # Keep the chosen subtree for the next call. A visited action without a
    # child led to a terminal state (only winning actions lead to leaves),
    # so there is no subtree to reuse.
    if action in root_tree.children:
        self.prev_tree = root_tree.children[action]

    return action, counts / np.sum(counts)
def MCTS(self, state, max_node, C_puct, tau, showNQ=False, noise=0., random_flip=False):
    """Run a batched PUCT Monte-Carlo tree search from ``state`` and choose an action.

    The search reuses ``self.prev_tree`` (descending into its child for
    ``self.prev_action`` when that child exists, otherwise starting a fresh
    ``Tree``). Each iteration selects up to ``self.n_parallel`` paths with
    ``self.select`` under a virtual loss. It then evaluates the resulting
    leaf states in one batch with ``self.v_array``, expands the non-terminal
    leaves and backs the leaf values up every path.

    Expanded children are created as ``Tree(s, None)``. Their policy prior is
    not computed here (presumably it is filled in lazily elsewhere; TODO
    confirm). Iteration stops once the visit total (``np.sum(root_tree.N)``
    at entry plus the number of evaluated states) reaches ``max_node``, or
    when ``self.select`` yields no path.

    The virtual loss depends on the player to move at each node. ``W`` is
    decreased where ``node.s.turn % 2 == self.color`` and increased
    otherwise, because Q flips sign between the two players.
    NOTE(review): the backup adds ``v`` without such a flip; presumably
    ``v_array`` reports values from ``self.color``'s viewpoint. Confirm.

    Args:
        state: current game state; also assigned to ``self.prev_tree.s``.
        max_node: target visit total for the search.
        C_puct: exploration constant forwarded to ``self.select``.
        tau: temperature. ``0`` picks uniformly among the most-visited root
            actions; otherwise the action is sampled with probability
            proportional to ``N ** (1 / tau)``.
        showNQ: if true, print the root prior, visit counts, Q values and the
            value estimate of the root state.
        noise: weight of uniform random noise mixed into the root prior;
            actions whose prior is exactly 0 (illegal) are kept at 0.
        random_flip: forwarded to ``self.v_array``.

    Returns:
        ``(action, pi)``: the chosen action id and the root visit
        distribution ``N / sum(N)``. If the root received no visits at all,
        the noised root prior is used in place of ``N``.
    """
    # TODO: when neither player has walls left and the side to move has a
    # single shortest-path move, the search could be skipped entirely. This
    # stays disabled because behaviour would still differ through the
    # prev_tree reuse.
    p = self.p(state)
    illegal = (p == 0.)
    old_p = p  # un-noised prior, only used by the showNQ display
    p = (1. - noise) * p + noise * np.random.rand(len(p))
    p[illegal] = 0.
    p = p / np.sum(p)

    # Reuse the subtree kept from the previous call when possible.
    root_tree = self.prev_tree
    root_tree.s = state
    if self.prev_action is not None:
        if self.prev_action in root_tree.children:
            root_tree = root_tree.children[self.prev_action]
        else:
            root_tree = Tree(state, p)
    root_tree.P = p

    node_num = np.sum(root_tree.N)
    while node_num < max_node:
        # Selection: collect a batch of root-to-leaf paths. The virtual loss
        # applied after each select() steers the next one elsewhere.
        batch_nodes = []
        batch_actions = []
        for _ in range(min(self.n_parallel, max_node)):
            _, _, nodes, actions = self.select(root_tree, C_puct)
            if nodes is None:
                break
            batch_nodes.append(nodes)
            batch_actions.append(actions)
            for node, action in zip(nodes, actions):
                node.N[action] += self.virtual_loss_n
                # Q flips sign between first and second player.
                if self.color == node.s.turn % 2:
                    node.W[action] -= self.virtual_loss_n
                else:
                    node.W[action] += self.virtual_loss_n
                node.Q[action] = node.W[action] / node.N[action]

        if not batch_nodes:
            # Nothing was selected, so node_num cannot grow; without this
            # guard the while loop would never terminate.
            break

        # Undo the virtual loss before the real backup.
        for nodes, actions in zip(batch_nodes, batch_actions):
            for node, action in zip(nodes, actions):
                node.N[action] -= self.virtual_loss_n
                if self.color == node.s.turn % 2:
                    node.W[action] += self.virtual_loss_n
                else:
                    node.W[action] -= self.virtual_loss_n
                if node.N[action] == 0:
                    node.Q[action] = 0.
                else:
                    node.Q[action] = node.W[action] / node.N[action]

        # Evaluation: play each path's last action on a copy of its leaf
        # state, then score all resulting states in a single batch.
        states = []
        for nodes, actions in zip(batch_nodes, batch_actions):
            s = state_copy(nodes[-1].s)
            s.accept_action_str(actionid2str(s, actions[-1]))
            states.append(s)
        node_num += len(states)
        leaf_v = self.v_array(states, random_flip=random_flip)

        # Expansion: every non-terminal new state becomes a child node; its
        # prior is left as None here.
        for s, nodes, actions in zip(states, batch_nodes, batch_actions):
            if s.terminate:
                continue
            parent = nodes[-1]
            a = actions[-1]
            if a not in parent.children:
                parent.children[a] = Tree(s, None)

        # Backup: add the path's leaf value to every edge along that path.
        for value, nodes, actions in zip(leaf_v, batch_nodes, batch_actions):
            for node, action in zip(nodes, actions):
                node.N[action] += 1
                node.W[action] += value
                node.Q[action] = node.W[action] / node.N[action]

    if showNQ:
        print("p=")
        self.display_parameter(np.asarray(old_p * 1000, dtype="int32"))
        print("N=")
        self.display_parameter(np.asarray(root_tree.N, dtype="int32"))
        print("Q=")
        self.display_parameter(
            np.asarray(root_tree.Q * 1000, dtype="int32"))
        print("v={}".format(self.v(root_tree.s)))

    # With no root visits (max_node <= 0, or select() found nothing) N is all
    # zeros, so both pi and N / sum(N) would be NaN; fall back to the prior.
    counts = root_tree.N if np.sum(root_tree.N) > 0 else p

    if tau == 0:
        # Keep only the maximal entries so ties are broken proportionally
        # (i.e. uniformly among equally-visited best actions).
        N2 = counts * (counts == np.max(counts))
    else:
        N2 = np.power(np.asarray(counts, dtype="float64"), 1. / tau)
    pi = N2 / np.sum(N2)
    action = np.random.choice(len(pi), p=pi)

    # Keep the chosen subtree for the next call. A visited action without a
    # child led to a terminal state (only winning actions lead to leaves),
    # so there is no subtree to reuse.
    if action in root_tree.children:
        self.prev_tree = root_tree.children[action]

    self.tree_for_visualize = root_tree
    return action, counts / np.sum(counts)