コード例 #1
0
ファイル: mcts_robot.py プロジェクト: DamonDeng/LambdaGo
    def reset(self):
        """Reinitialise the search tree, the training buffers and both boards.

        A fresh ``MCTSNode`` flagged as both root and leaf becomes
        ``self.root_node``; the three training lists are emptied; new
        ``LambdaGoBoard`` instances of ``self.board_size`` replace the
        simulation board and the game board.
        """
        self.root_node = fresh_root = MCTSNode()
        fresh_root.is_root = True
        fresh_root.is_leaf = True

        self.training_data, self.training_score, self.training_move = [], [], []

        self.simulate_board = LambdaGoBoard(self.board_size)
        self.go_board = LambdaGoBoard(self.board_size)
コード例 #2
0
ファイル: mcts.py プロジェクト: polyzer/MIS_solver
    def rollout(self, root_node, stop_at_leaf=False):
        """Walk down from ``root_node`` once and back the reached value up.

        At every step the child index returned by ``best_child()`` is
        followed.  A missing child is created by stepping a fresh MIS
        environment on the node's graph with that action.  The walk stops at
        a node whose ``is_end()`` is true, or right after the first newly
        created child when ``stop_at_leaf`` is set.

        The reached node's ``state_value()`` is then propagated towards the
        root, incremented by one per edge, through ``self.update_parent``.
        ``self.root_max`` keeps the largest value that reached the root.

        Returns:
            The value propagated to ``root_node``.
        """
        current = root_node
        while not current.is_end():
            action = current.best_child()
            halt_here = False
            if current.children[action] is None:
                env = MISEnv() if use_dense else MISEnv_Sparse()
                env.set_graph(current.graph)
                next_graph, _, _, _ = env.step(action)
                current.children[action] = MCTSNode(next_graph,
                                                    self,
                                                    idx=action,
                                                    parent=current)
                halt_here = stop_at_leaf
            current = current.children[action]
            if halt_here:
                break

        # Back up: each edge towards the root adds one to the value.
        value = current.state_value()
        while current is not root_node:
            value += 1
            self.update_parent(current, value)
            current = current.parent
        self.root_max = max(self.root_max, value)
        return value
コード例 #3
0
ファイル: mcts_robot.py プロジェクト: DamonDeng/LambdaGo
    def simulate_best_move(self, color, pos_filter=None):
        """Search for the best move for ``color`` from the current game board.

        A brand-new root is built on every call; the previous tree is not
        reused.  The root represents the position *before* ``color`` plays,
        so it carries that player's colour value.  It is expanded once with
        the model and then handed to ``mcts_search``.

        ``pos_filter`` is accepted but not used.

        Returns:
            Whatever ``self.mcts_search`` returns for this root and colour.
        """
        self.root_node = root = MCTSNode()
        root.set_simulate_board(self.go_board)
        root.player_color = LambdaGoBoard.get_color_value(color)
        root.is_root = True

        self.expand_mcts_node(self.root_node, None)
        return self.mcts_search(self.root_node, color)
コード例 #4
0
ファイル: mcts.py プロジェクト: drumehiron/MIS_solver
 def search(self, graph, iter_num=10):
     """Run ``iter_num`` rollouts from one fresh root built on ``graph``.

     Each rollout result is printed when ``self.performance`` is truthy.

     Returns:
         The ``iter_num`` rollout results, in order.
     """
     root_node = MCTSNode(graph, self)
     results = []
     for _ in range(iter_num):
         value = self.rollout(root_node)
         if self.performance:
             print(value)
         results.append(value)
     return results
コード例 #5
0
ファイル: mcts.py プロジェクト: polyzer/MIS_solver
 def search_for_exp(self, graph, time_limit=600, min_iter_num=100):
     """Keep rolling out from ``graph`` until both budgets are exhausted.

     Rollouts continue while fewer than ``min_iter_num`` have been done, or
     while less than ``time_limit`` wall-clock seconds (``time.time()``)
     have elapsed since the call started.

     Returns:
         All rollout results, in order.
     """
     started = time.time()
     root_node = MCTSNode(graph, self)
     results = []
     while len(results) < min_iter_num or time.time() - started < time_limit:
         results.append(self.rollout(root_node))
     return results
コード例 #6
0
ファイル: mcts.py プロジェクト: polyzer/MIS_solver
    def train(self, graph, TAU, batch_size=10, iter_p=2, stop_at_leaf=False):
        """Play one self-play episode on ``graph`` and fit the GNN to it.

        Episode: at every step an MCTS-improved policy ``pi`` is computed with
        ``self.get_improved_pi``, an action is sampled from it, and the action
        is applied in the MIS environment until it reports ``done``.  For each
        visited state the graph, the sampled action, ``pi`` and the fresh
        root node's ``reward_mean`` / ``reward_std`` are recorded.

        Training: the recorded steps are shuffled and split into minibatches
        of at most ``batch_size``.  For step ``idx`` the value target is the
        number of remaining steps ``T - idx``, normalised with that step's
        mean/std.  The minibatch loss is the average of
        ``MSE(z, v[action]) - sum(pi * log(p + EPS))``.

        Args:
            graph: Initial graph given to the environment.
            TAU: Temperature forwarded to ``get_improved_pi``.
            batch_size: Maximum number of steps per optimizer update (>= 1).
            iter_p: Forwarded to ``get_improved_pi``.
            stop_at_leaf: Forwarded to ``get_improved_pi``.

        Raises:
            ValueError: if ``batch_size`` < 1 (previously this looped forever).
        """
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1')

        self.gnnhash.clear()
        mse = torch.nn.MSELoss()
        env = MISEnv() if use_dense else MISEnv_Sparse()
        env.set_graph(graph)

        graphs = []
        actions = []
        pis = []
        means = []
        stds = []
        done = False
        while not done:
            n, _ = graph.shape
            node = MCTSNode(graph, self)
            means.append(node.reward_mean)
            stds.append(node.reward_std)
            pi = self.get_improved_pi(node,
                                      TAU,
                                      iter_p=iter_p,
                                      stop_at_leaf=stop_at_leaf)
            action = np.random.choice(n, p=pi)
            graphs.append(graph)
            actions.append(action)
            pis.append(pi)
            graph, _, done, _ = env.step(action)

        T = len(graphs)
        idxs = list(range(T))
        np.random.shuffle(idxs)
        for start in range(0, T, batch_size):
            batch = idxs[start:start + batch_size]
            self.optimizer.zero_grad()
            loss = torch.zeros(1, dtype=torch.float32)
            for idx in batch:
                Timer.start('gnn')
                p, v = self.gnn(graphs[idx], True)
                Timer.end('gnn')
                # Normalise the remaining-steps return with this state's
                # mean/std.  NOTE(review): a zero std gives inf/nan here.
                z = torch.tensor(((T - idx) - means[idx]) / stds[idx],
                                 dtype=torch.float32)
                target_pi = torch.tensor(pis[idx], dtype=torch.float32)
                # Keep the loss attached to the GNN outputs.  Wrapping it in
                # torch.tensor(...) copied and detached it, so backward()
                # never reached the network parameters.
                loss = (loss + mse(z, v[actions[idx]])
                        - (target_pi * torch.log(p + EPS)).sum())
            loss = loss / len(batch)
            loss.backward()
            self.optimizer.step()
コード例 #7
0
ファイル: mcts.py プロジェクト: polyzer/MIS_solver
    def best_search1(self, graph, TAU=0.1, iter_p=1):
        """Play one MCTS-guided episode and track the best root value seen.

        At every step a fresh root is built for the current graph, the
        improved policy is computed with ``get_improved_pi``, and an action is
        sampled from it.  Before each step the running maximum is updated with
        ``self.root_max + reward``.  Here ``reward`` is the value returned by
        the previous ``env.step`` (0 before the first step).
        ``self.root_max`` is presumably refreshed by the rollouts inside
        ``get_improved_pi``; confirm.

        Returns:
            ``(best, reward)``: the running maximum and the reward of the
            final environment step.
        """
        self.gnnhash.clear()
        env = MISEnv() if use_dense else MISEnv_Sparse()
        env.set_graph(graph)

        best = 0
        reward = 0
        done = False
        while not done:
            num_nodes, _ = graph.shape
            pi = self.get_improved_pi(MCTSNode(graph, self), TAU, iter_p=iter_p)
            best = max(best, self.root_max + reward)
            chosen = np.random.choice(num_nodes, p=pi)
            graph, reward, done, _ = env.step(chosen)
        return best, reward
コード例 #8
0
    def build_tree(self, num_simulations=5000):
        """Spend a budget of ``num_simulations`` simulations over the possible codes.

        Phase 1: one ``MCTSNode`` is created for every code in
        ``self.possible_codes.get_all()``.  Each node runs one simulation and
        is stored in ``self.nodes``.

        Phase 2: the remaining budget,
        ``num_simulations - len(self.possible_codes)``, is spent one
        simulation at a time.  Nothing runs when that budget is not positive.
        Each simulation goes to the node returned by
        ``self._get_best_node()``; the original comment calls this
        roulette-wheel selection.
        """
        for code in self.possible_codes.get_all():
            seed_node = MCTSNode(code=code)
            seed_node.perform_simulation(self.possible_codes, next_code=code)
            self.nodes.append(seed_node)

        remaining = num_simulations - len(self.possible_codes)
        for _ in range(remaining):
            chosen = self._get_best_node()
            chosen.perform_simulation(self.possible_codes,
                                      next_code=chosen.code)
コード例 #9
0
class MCTS(object):
    """UCT search over a two-sided environment (sides 0 and 1).

    The two sides act one at a time through ``env.take_side_action``.  After
    both have acted, ``env.step()`` advances the game.  Tree statistics live
    in ``MCTSNode`` objects that share one ``TreeManager``.
    """

    def __init__(self, side, search_size, logger):
        """Create a searcher.

        Args:
            side: The side this agent actually plays (kept as
                ``actual_side``).
            search_size: Number of select/expand/rollout/backup iterations
                per ``uct()`` call.
            logger: Logger for debug traces and the top-3 summary.
        """
        self.tree_manager = TreeManager()
        self.tree_manager.c_uct = 2
        self.actual_side = side
        # The root is tagged with the opponent's side.  Children are created
        # with side ``1 - parent.side``, so the root's children are our moves.
        self.side = 1 - side
        self.search_size = search_size
        self.logger = logger

    def init(self, actions, env):
        """Start a new tree whose root descendants are ``actions``; search in ``env``."""
        self.root = MCTSNode(self.tree_manager, self.side, descendant=actions)
        self.env = env

    def uct(self):
        """Run ``search_size`` UCT iterations and return the chosen root action.

        Each iteration works on a clone of ``self.env``.  After the search the
        top three children (in ``sorted_child()`` order) are logged.  The
        first child becomes the new root, and its action is returned.
        """
        self.logger.debug(f'=== UCT start === Step {self.env.game.step_count}')
        self.logger.debug(f'game map {self.env.game.map}')
        for i in range(self.search_size):
            self.logger.debug('=== Search start ===')
            node = self.root
            env = self.env.clone()
            # Number of side-actions still to be taken before env.step()
            # advances the game; both sides act once per step.
            player_move = 2
            self.logger.debug(f'game map {env.game.map}')
            self.logger.debug(f'real side {self.actual_side}')
            self.logger.debug(f'step {env.game.step_count}')

            self.logger.debug('== Select ==')
            # Select
            while node.untried_actions == [] and node.children != []:
                # node is fully expanded and non-terminal
                node = node.select()
                env.take_side_action(node.side, node.action)
                self.logger.debug(f'Select side {node.side}, '
                                  f'action {node.action}')
                player_move -= 1
                if player_move == 0:
                    env.step()
                    self.logger.debug('Select env step')
                    player_move = 2

            self.logger.debug('== Expand ==')
            # Expand
            if node.untried_actions:
                # if we can expand (i.e. state/node is non-terminal)
                action = random.choice(node.untried_actions)
                env.take_side_action(1 - node.side, action)
                self.logger.debug(f'Expand side {1 - node.side}, '
                                  f'action {action}')
                player_move -= 1
                if player_move == 0:
                    env.step()
                    self.logger.debug('Expand env step')
                    player_move = 2
                # The arguments are evaluated with the *old* node, so the new
                # child's descendants are the actions now available to
                # ``node.side``, which moves after the child.
                node = node.add_node(1 - node.side, action,
                                     env.get_side_action(node.side))
                self.logger.debug(f'Expand add node, side {node.side}, '
                                  f'action {node.action}')
            side = 1 - node.side

            # Counts env steps taken during the rollout; currently never read.
            roll_out = 0
            self.logger.debug('== Rollout ==')
            # Rollout: random actions, alternating sides, until the side to
            # move has no actions.
            while env.get_side_action(side):
                # while state is non-terminal
                actions = env.get_side_action(side)
                action = random.choice(actions)
                env.take_side_action(side, action)
                self.logger.debug(f'Rollout side {side}, action {action}')
                side = 1 - side
                player_move -= 1
                if player_move == 0:
                    env.step()
                    self.logger.debug('Rollout env step')
                    roll_out += 1
                    player_move = 2

            # Score from the point of view of the node where the rollout
            # started: +1 if its side won, -1 if the other side won.  z is 0
            # when get_winner() < 0 (presumably a draw; confirm).
            winner = env.get_winner()
            self.logger.debug(f'winner {winner}')
            if winner >= 0:
                if winner == node.side:
                    z = 1
                else:
                    z = -1
            else:
                z = 0

            self.logger.debug(f'z {z}')

            # Backpropagate: sides alternate per level, so the sign flips on
            # every step towards the root.
            while node is not None:
                node.update(z, 0)
                self.logger.debug(f'side {node.side}, '
                                  f'action {node.action}, z {z}')
                node = node.parent
                z *= -1

            self.logger.debug('=== Search end ===')

        sorted_nodes = self.root.sorted_child()

        self.logger.info(f'=== Side {self.actual_side} ===')
        _count = 0
        for node in sorted_nodes:
            if _count < 3:
                self.logger.info(node.str())
                _count += 1
            # Sanity check on each child (presumably virtual-loss
            # bookkeeping); skipped entirely under ``python -O``.
            assert node.check_vls()

        # NOTE(review): raises IndexError if the root has no children.  The
        # selected node also keeps its ``parent`` link (if MCTSNode sets one),
        # so a later backup loop would walk past this new root into the old
        # tree unless init() is called first; confirm.
        selected = sorted_nodes[0]
        self.root = selected
        return selected.action
コード例 #10
0
 def init(self, actions, env):
     """Start a fresh search tree and remember the environment to search in.

     The new root is an ``MCTSNode`` for ``self.side`` built from the shared
     tree manager, with ``actions`` as its descendants.
     """
     fresh_root = MCTSNode(self.tree_manager, self.side, descendant=actions)
     self.root, self.env = fresh_root, env
コード例 #11
0
ファイル: mcts_robot.py プロジェクト: DamonDeng/LambdaGo
class MCTSRobot(object):
    """Go agent combining MCTS with a dual-head policy/value network.

    ``DualHeadModel.predict`` yields a policy array aligned with
    ``self.PosArray``: every board point in row-major order, then ``None``
    (presumably pass).  It also yields a value.  MCTS uses both to expand and
    score nodes.  Moves played through ``apply_move`` are recorded as
    training examples.  ``train`` relabels them with the game result and
    trains the model on them.
    """

    def __init__(self,
                 name='DefaultMCTSRobot',
                 layer_number=19,
                 old_model=None,
                 search_time=100,
                 board_size=19,
                 komi=7.5,
                 train_iter=2):
        """Build the robot and create or load its model.

        Args:
            name: Robot name, also passed to ``DualHeadModel``.
            layer_number: Forwarded to ``DualHeadModel``.
            old_model: File name under ``./model/`` to continue training
                from, or ``None`` for a new model.
            search_time: Number of MCTS iterations per move.
            board_size: Board edge length.
            komi: Threshold the final score is compared against in ``train``.
            train_iter: ``steps`` forwarded to ``DualHeadModel.train``.
        """
        self.name = name
        self.layer_number = layer_number
        self.search_time = search_time
        self.train_iter = train_iter
        self.board_size = board_size
        self.komi = komi
        self.stone_number = board_size * board_size

        self.simulate_board_list = []
        self.max_play_move = 1024
        self.ColorBlackChar = 'b'
        self.ColorWhiteChar = 'w'

        # Constant planes appended to the board in every network input to
        # mark the colour to move: all +1 for black, all -1 for white.
        self.BlackIdentify = np.ones((board_size, board_size), dtype=int)
        self.WhiteIdentify = np.full((board_size, board_size), -1, dtype=int)

        # Policy index -> move: each (row, col) in row-major order, then None.
        self.PosArray = [(row, col)
                         for row in range(board_size)
                         for col in range(board_size)]
        self.PosArray.append(None)

        self.reset()

        if old_model is None:
            self.model = DualHeadModel(self.name,
                                       self.board_size,
                                       layer_number=self.layer_number)
        else:
            model_path = './model/' + old_model
            print('Trying to load old model for continue training:' +
                  model_path)
            self.model = DualHeadModel(self.name,
                                       self.board_size,
                                       model_path=model_path,
                                       layer_number=self.layer_number)

    def reset(self):
        """Recreate the root node, clear training buffers, and make fresh boards."""
        self.root_node = MCTSNode()
        self.root_node.is_root = True
        self.root_node.is_leaf = True

        self.training_data = []
        self.training_score = []
        self.training_move = []

        self.simulate_board = LambdaGoBoard(self.board_size)

        self.go_board = LambdaGoBoard(self.board_size)

    def reset_board(self):
        """Reset the game board to an empty ``board_size`` board."""
        self.go_board.reset(self.board_size)

    def _color_identify(self, color_value):
        """Return the colour plane for a ``LambdaGoBoard`` colour value.

        Raises:
            Exception: if ``color_value`` is neither black nor white.
        """
        if color_value == LambdaGoBoard.ColorBlack:
            return self.BlackIdentify
        if color_value == LambdaGoBoard.ColorWhite:
            return self.WhiteIdentify
        raise Exception('Incorrect color character')

    def _update_node_stats(self, node, value):
        """Record one visit of ``node`` with ``value`` and refresh its averages.

        The stored average is the mean value divided by ``stone_number``.
        """
        node.visit_count = node.visit_count + 1
        node.total_value = node.total_value + value
        node.average_value = (node.total_value /
                              node.visit_count) / self.stone_number
        node.current_value = node.average_value

    def apply_move(self, color, pos):
        """Record a training example for this move, then play it on the board.

        The example is ``[board, colour plane]`` plus a one-hot move vector of
        length ``board_size**2 + 1``.  ``pos`` None sets the last slot.

        Args:
            color: Colour character understood by ``LambdaGoBoard``.
            pos: ``(row, col)`` or ``None``.

        Raises:
            Exception: if ``color`` is neither black nor white.
        """
        current_data = [self.go_board.get_board(),
                        self._color_identify(
                            LambdaGoBoard.get_color_value(color))]
        self.training_data.append(current_data)

        point_count = self.board_size * self.board_size
        if pos is None:
            move_index = point_count
        else:
            (row, col) = pos
            move_index = row * self.board_size + col
        current_move = np.zeros(point_count + 1, dtype=int)
        current_move[move_index] = 1
        self.training_move.append(current_move)

        self.go_board.apply_move(color, pos)

    def showboard(self):
        """Return the string rendering of the game board."""
        return str(self.go_board)

    def simulate_best_move(self, color, pos_filter=None):
        """Search from the current game position for ``color``.

        A new root is built on every call; the previous tree is not reused.
        The root is the position before ``color`` plays, so it carries that
        player's colour.  ``pos_filter`` is accepted but unused.

        Returns:
            ``(move, value)`` from ``mcts_search``.
        """
        self.root_node = MCTSNode()

        self.root_node.set_simulate_board(self.go_board)
        self.root_node.player_color = LambdaGoBoard.get_color_value(color)
        self.root_node.is_root = True

        self.expand_mcts_node(self.root_node, None)

        return self.mcts_search(self.root_node, color)

    def mcts_search(self, root_node, color):
        """Run ``self.search_time`` iterations from ``root_node`` and choose a move.

        The chosen child is the most visited one, with ties broken by higher
        ``policy_value``.  The choice is printed via ``display_result``.

        Returns:
            ``(move, value)`` of the chosen child, ``value`` being its
            ``get_value()``.
        """
        for _ in range(self.search_time):
            self.search_to_expand(root_node)

        # Choose from the tree that was actually searched.  This previously
        # read self.root_node and ignored the root_node argument.
        right_node = self.lookup_right_node(root_node)

        right_move = (right_node.move, right_node.get_value())

        self.display_result(color, right_move, right_node)

        return right_move

    def search_to_expand(self, current_node):
        """Descend to the best-valued leaf, expand it, and back its value up.

        The first child with the highest ``get_value()`` is followed.  A leaf
        child is expanded with the model; otherwise the search recurses into
        it.  The resulting value is added to ``current_node``'s statistics
        and returned.

        NOTE(review): the value is backed up without negation, so both
        colours maximise the same quantity.  Confirm this matches the value
        head's point of view.

        Raises:
            Exception: if no child has a comparable value (e.g. no children).
        """
        best_value = float('-inf')
        best_node = None
        for child in current_node.children:
            node_value = child.get_value()
            if node_value > best_value:
                best_value = node_value
                best_node = child

        if best_node is None:
            raise Exception('Node without valid child')

        if best_node.is_leaf:
            value = self.expand_mcts_node(best_node, current_node)
        else:
            value = self.search_to_expand(best_node)

        self._update_node_stats(current_node, value)
        return value

    def expand_mcts_node(self, current_node, parent_node):
        """Evaluate ``current_node`` with the model and attach its legal children.

        The network input is the node's board plus the colour plane for its
        ``player_color``.  Each move in ``self.PosArray`` is tried on a copy
        of the board for ``player_color``.  Every legal move becomes a child
        with the opposite colour and the move's policy value.  The node is
        then marked non-leaf and credited with the predicted value.

        Args:
            current_node: Leaf with ``simulate_board`` and ``player_color``
                set.
            parent_node: Unused; kept for interface compatibility.

        Returns:
            The model's value prediction ``result[1][0]``.

        Raises:
            Exception: if ``player_color`` is neither black nor white.
        """
        current_color = current_node.player_color
        child_color = LambdaGoBoard.reverse_color_value(current_color)

        current_data = [current_node.simulate_board.get_board(),
                        self._color_identify(current_color)]

        result = self.model.predict(current_data)
        policy_array = result[0][0]
        current_value = result[1][0]

        color_char = LambdaGoBoard.get_color_char(current_color)
        for move, policy_value in zip(self.PosArray, policy_array):
            new_child = MCTSNode()
            new_child.simulate_board = LambdaGoBoard(self.board_size)
            new_child.simulate_board.copy_from(current_node.simulate_board)

            is_valid, _ = new_child.simulate_board.apply_move(color_char, move)

            if is_valid:
                new_child.move = move
                new_child.policy_value = policy_value
                new_child.player_color = child_color
                current_node.children.append(new_child)

        current_node.is_leaf = False
        self._update_node_stats(current_node, current_value)

        return current_value

    def lookup_right_node(self, current_node):
        """Return the most visited child (ties: higher policy value, then first).

        Returns ``None`` when ``current_node`` has no children.
        """
        return max(current_node.children,
                   key=lambda child: (child.visit_count, child.policy_value),
                   default=None)

    def select_move(self, color):
        """Search for ``color``, play the chosen move on the game board, return it.

        Unlike ``apply_move``, this does not record a training example.
        """
        right_move = self.simulate_best_move(color)

        self.go_board.apply_move(color, right_move[0])

        return right_move[0]

    def display_result(self, color, right_move, right_node):
        """Print a one-line summary of the chosen move for ``color``.

        Black's line comes before a ``'# '`` spacer line, white's after it.
        If the node's ``average_value`` has magnitude above 1 the node is
        dumped and the process sleeps for 20 seconds.
        """
        display_string = "# Player: "
        if LambdaGoBoard.get_color_value(color) == LambdaGoBoard.ColorBlack:
            display_string = display_string + "Black    "
        else:
            display_string = display_string + "White    "

        # Each field is padded with spaces and truncated to a fixed width so
        # the columns line up.
        move_string = ' Move:' + str(right_move[0]) + '                   '
        value_string = ' Value:' + str(right_move[1]) + '                    '
        visit_count_string = '    Count:' + str(
            right_node.visit_count) + '                     '
        node_value_string = '   NodeValue:' + str(
            right_node.average_value) + '                    '
        policy_value_string = '   Policy:' + str(
            right_node.policy_value) + '                     '

        display_string = display_string + move_string[
            0:20] + visit_count_string[0:20]
        display_string = display_string + value_string[
            0:25] + node_value_string[0:25] + policy_value_string[0:25]

        if LambdaGoBoard.get_color_value(color) == LambdaGoBoard.ColorBlack:
            print(display_string)
            print('# ')
        else:
            print('# ')
            print(display_string)

        if abs(right_node.average_value) > 1:
            print('# incorrect node:')
            print(str(right_node))
            # NOTE(review): debugging pause; stalls the game for 20 seconds.
            time.sleep(20)

    def train(self, board_states, move_sequence, score_board):
        """Relabel the recorded moves with the game result and train the model.

        ``board_states``, ``move_sequence`` and ``score_board`` are unused;
        training uses the examples recorded by ``apply_move``.  Even indices
        are black's moves and odd indices white's.  If the final score is
        above ``komi`` (black wins), white's moves are replaced by their
        complement ``1 - move``; otherwise black's moves are.  Every example
        is paired with the final score board.

        Raises:
            Exception: if the recorded data and moves differ in length.
        """
        train_data_len = len(self.training_data)

        if train_data_len != len(self.training_move):
            raise Exception(
                'Inconsist state, training data and training move not in same length'
            )

        self.go_board.update_score_board()

        result_score = self.get_score()

        losing_parity = 1 if result_score > self.komi else 0
        for i in range(losing_parity, len(self.training_move), 2):
            # Invert the loser's own move.  This previously read
            # training_move[1] for every i.
            self.training_move[i] = self.training_move[i] * -1 + 1

        result_score_board = self.go_board.score_board
        self.training_score.extend([result_score_board] * train_data_len)

        print('# robot ' + self.name + ' is in training.........')

        self.model.train(self.training_data,
                         self.training_score,
                         self.training_move,
                         steps=self.train_iter)

    def new_game(self):
        """Reset the game board and allocate one spare board per point."""
        self.reset_board()

        self.simulate_board_list = [
            LambdaGoBoard(self.board_size)
            for _ in range(self.board_size * self.board_size)
        ]

    def get_current_state(self):
        """Return the raw ``board`` attribute of the game board."""
        return self.go_board.board

    def get_score_board(self):
        """Return the game board's score board, updating it first if stale."""
        if not self.go_board.score_board_updated:
            self.go_board.update_score_board()

        return self.go_board.score_board

    def get_score(self):
        """Return the game board's ``score_board_sum``, updating it first if stale."""
        if not self.go_board.score_board_updated:
            self.go_board.update_score_board()

        return self.go_board.score_board_sum
コード例 #12
0
ファイル: mcts_robot.py プロジェクト: DamonDeng/LambdaGo
    def expand_mcts_node(self, current_node, parent_node):
        """Run the model on ``current_node`` and attach one child per legal move.

        The network input is the node's board plus the colour plane
        (``BlackIdentify`` / ``WhiteIdentify``) for the node's
        ``player_color``.  Each entry of ``self.PosArray`` is tried on a
        fresh copy of the board for that colour.  Legal moves become children
        of the opposite colour, carrying the move's policy value.  The node
        is then marked non-leaf and its visit/value statistics are updated
        with the predicted value.  The average is divided by
        ``stone_number``.

        Args:
            current_node: Leaf with ``simulate_board`` and ``player_color``.
            parent_node: Not used.

        Returns:
            The model's value prediction ``result[1][0]``.

        Raises:
            Exception: if ``player_color`` is neither black nor white.
        """
        node_color = current_node.player_color
        opponent_color = LambdaGoBoard.reverse_color_value(node_color)

        board_snapshot = current_node.simulate_board.get_board()
        if node_color == LambdaGoBoard.ColorBlack:
            turn_plane = self.BlackIdentify
        elif node_color == LambdaGoBoard.ColorWhite:
            turn_plane = self.WhiteIdentify
        else:
            raise Exception('Incorrect color character')

        prediction = self.model.predict([board_snapshot, turn_plane])
        policy_array = prediction[0][0]
        predicted_value = prediction[1][0]

        for move, policy_value in zip(self.PosArray, policy_array):
            child = MCTSNode()
            child.simulate_board = LambdaGoBoard(self.board_size)
            child.simulate_board.copy_from(current_node.simulate_board)
            is_valid, _ = child.simulate_board.apply_move(
                LambdaGoBoard.get_color_char(node_color), move)
            if not is_valid:
                continue
            child.move = move
            child.policy_value = policy_value
            child.player_color = opponent_color
            current_node.children.append(child)

        current_node.is_leaf = False
        current_node.visit_count = current_node.visit_count + 1
        current_node.total_value = current_node.total_value + predicted_value
        mean_value = current_node.total_value / current_node.visit_count
        current_node.average_value = mean_value / self.stone_number
        current_node.current_value = current_node.average_value

        return predicted_value
コード例 #13
0
ファイル: mcts.py プロジェクト: Seraphli/TankAI
class MCTSP(object):
    """Parallel UCT search whose rollouts run asynchronously in a worker pool.

    The search thread selects a path, applies virtual loss along it
    (``apply_vl``), and submits ``random_rollout`` to ``self.pool``.  The
    pool's callback, ``backup``, expands the tree and backs the result up.
    It presumably runs on another thread, hence the locks.
    ``select_lock`` serialises tree traversal and mutation; ``backup_lock``
    guards the in-flight job table and the progress bar.  Both locks are
    taken with ``with`` so an exception can never leave one held.
    """

    def __init__(self):
        import threading

        self.pool = Pool('tankai', 0)
        self.pool.setup()
        self.pool.reg_task([random_rollout])
        self.backup_lock = threading.Lock()
        self.select_lock = threading.Lock()
        self.log_info = {}

    def setup(self, side, search_size, logger):
        """Configure the searcher for ``side``; the root is tagged ``1 - side``."""
        self.tree_manager = TreeManager()
        self.tree_manager.c_uct = 2
        self.actual_side = side
        self.side = 1 - side
        self.search_size = search_size
        self.logger = logger

    def reset_root(self, actions):
        """Replace the root with a fresh node whose descendants are ``actions``."""
        self.root = MCTSNode(self.tree_manager, self.side, descendant=actions)

    def select(self, node):
        """Descend from ``node`` through fully expanded nodes.

        Returns:
            ``(action_list, node)``: the ``(side, action)`` pairs taken and
            the node where descent stopped (untried actions or no children).
        """
        action_list = []
        # ``with`` releases the lock even if node.select() raises; the
        # previous acquire()/release() pair would leave it held forever.
        with self.select_lock:
            while node.untried_actions == [] and node.children:
                # node is fully expanded and non-terminal
                node = node.select()
                action_list.append((node.side, node.action))
        return action_list, node

    def _backup(self, node, flag, result):
        """Propagate ``result['z']`` from ``node`` to the root, flipping sign per level.

        Nodes that carry virtual loss from ``uct()`` are updated with 3, which
        presumably removes it.  Two kinds of node never had virtual loss
        applied and are updated with 0: the root, and the node this backup
        just added (a leaf when ``flag`` is set).  Finally the job is removed
        from the in-flight table and the progress bar is advanced.
        """
        z = result['z']

        while node is not None:
            if (node.is_leaf() and flag) or node.is_root():
                node.update(z, 0, {})
            else:
                node.update(z, 3, {})
            node = node.parent
            z *= -1

        with self.backup_lock:
            self.jobs.pop(str(result['action_list']))
            self.pbar.update()

    def backup(self, result):
        """Pool callback: expand the node for ``result['uid']`` if needed, then back up."""
        with self.select_lock:
            uid = result['uid']
            node, flag = self.backup_dict[uid]
            if flag:
                # add child and descend tree
                _, action = result['action_list'][-1]
                node = node.add_node(1 - node.side, action,
                                     result['descendants'])

            self._backup(node, flag, result)
            self.backup_dict.pop(uid, None)

    def uct(self, env):
        """Run ``search_size`` asynchronous playouts on ``env`` and pick a child.

        Each iteration selects a path from the root, retrying while an
        identical path is still in flight.  It may append one untried action,
        applies virtual loss to the path's existing nodes, and submits
        ``random_rollout(uid, compressed_env, action_list, 0)`` with
        ``backup`` as callback.  After ``pool.join()`` the top three children
        are logged.

        Returns:
            The first node of ``self.root.sorted_child()``.
        """
        import sys
        import pickle
        import zlib
        import random
        import time
        import uuid

        self.info = {}
        self.info['results'] = []
        self.jobs = {}
        self.backup_dict = {}
        self.pbar = tqdm(total=self.search_size, file=sys.stdout)
        # Serialise the environment once; every rollout task receives it.
        _env = zlib.compress(pickle.dumps(env, -1))

        for _ in range(self.search_size):
            wait = True
            while wait:
                untried_flag = False
                action_list, node = self.select(self.root)

                if node.untried_actions:
                    # Expand: one untried action for the opposite side.
                    untried_flag = True
                    action = random.choice(node.untried_actions)
                    action_list.append((1 - node.side, action))

                if str(action_list) in self.jobs:
                    # The same path is already being simulated; retry.
                    time.sleep(0.01)
                else:
                    with self.backup_lock:
                        self.jobs[str(action_list)] = 1
                        # Virtual loss only on nodes already in the tree;
                        # the expanded action has no node yet.
                        path = action_list[:-1] if untried_flag else action_list
                        tmp_node = self.root
                        for _, path_action in path:
                            tmp_node = tmp_node.select_child_by_action(
                                path_action)
                            tmp_node.apply_vl(3)
                    wait = False

            uid = str(uuid.uuid4())
            self.backup_dict[uid] = [node, untried_flag]
            self.pool.apply_async(random_rollout, (uid, _env, action_list, 0),
                                  self.backup)

        self.pool.join()
        self.pbar.close()

        sorted_nodes = self.root.sorted_child()
        for rank, child in enumerate(sorted_nodes):
            if rank < 3:
                self.logger.info(child.str())
            assert child.check_vls()

        selected = sorted_nodes[0]
        return selected
コード例 #14
0
ファイル: mcts.py プロジェクト: Seraphli/TankAI
 def reset_root(self, actions):
     """Replace the search root with a fresh ``MCTSNode`` for ``self.side``.

     ``actions`` is passed through as the node's ``descendant`` argument.
     """
     new_root = MCTSNode(self.tree_manager, self.side, descendant=actions)
     self.root = new_root