# Example #1
class TreeTest(unittest.TestCase):
    """Unit tests for the basic Tree pointer-navigation API.

    NOTE(review): the method names (``append_test`` etc.) do not start
    with ``test``, so standard unittest discovery will never run them.
    Renaming them to ``test_*`` is recommended, but is left unchanged
    here in case something invokes them by their current names.
    """

    def setUp(self):
        # Fresh single-node tree rooted at 0 before each test.
        self.t = Tree(0)
        self.val = 1
        self.LEFT = 0

    def append_test(self):
        # append() adds a child holding the value under the current pointer.
        self.t.append(self.val)
        # assertEquals is a deprecated alias (removed in Python 3.12);
        # use the canonical assertEqual instead.
        self.assertEqual(self.val, self.t.ptr.children[0].data)

    def forward_test(self):
        # forward(LEFT) moves the pointer down to the first child.
        self.t.append(self.val)
        self.t.forward(self.LEFT)
        self.assertEqual(self.val, self.t.ptr.data)

    def back_test(self):
        # back() returns the pointer to the parent (here, the root).
        self.t.append(self.val)
        self.t.forward(self.LEFT)
        self.t.back()
        self.assertEqual(self.t.root, self.t.ptr)

    def reset_test(self):
        # reset() rewinds the pointer to the root from any depth.
        self.t.append(self.val)
        self.t.forward(self.LEFT)
        self.t.append(self.val)
        self.t.forward(self.LEFT)
        self.t.reset()
        self.assertEqual(self.t.root, self.t.ptr)
# Example #2
 def test_cwn(self):
     """Write a two-entry column-wise-ntuple tree to tree2.root."""
     out_file = TFile('tree2.root','RECREATE')
     out_tree = Tree('test_tree', 'A test tree')
     # Scalar branch holding the vector length, plus a vector branch
     # capped at 20 elements whose length is driven by 'nvals'.
     out_tree.var('nvals', the_type=int)
     out_tree.vector('x', 'nvals', 20)
     # First entry: nvals = 10, x = 0..9.
     out_tree.fill('nvals', 10)
     out_tree.vfill('x', range(10))
     out_tree.tree.Fill()
     # Second entry after clearing the buffers: nvals = 5, x = 0..4.
     out_tree.reset()
     out_tree.fill('nvals', 5)
     out_tree.vfill('x', range(5))
     out_tree.tree.Fill()
     out_file.Write()
     out_file.Close()
# Example #3
 def test_cwn(self):
     """Fill two entries of a variable-length-vector tree and persist it."""
     root_file = TFile('tree2.root', 'RECREATE')
     t = Tree('test_tree', 'A test tree')
     t.var('nvals', the_type=int)
     t.vector('x', 'nvals', 20)
     # Entry 1: length 10 vector.
     t.fill('nvals', 10)
     t.vfill('x', range(10))
     t.tree.Fill()
     # Entry 2: length 5 vector, after resetting the fill buffers.
     t.reset()
     t.fill('nvals', 5)
     t.vfill('x', range(5))
     t.tree.Fill()
     root_file.Write()
     root_file.Close()
def main():
    """Entry point: tick the tree with today's date once per hour, forever."""
    ticker = Tree()
    ticker.reset()
    # Daemon-style loop — this function never returns.
    while True:
        ticker.tick(datetime.date.today())
        time.sleep(3600)  # one hour
# Example #5
class MCTSActor(GameActor):
    """Game actor that chooses moves with Monte Carlo Tree Search.

    Each call to get_action() reuses (or rebuilds) the search tree,
    runs `num_expansions` select/expand/simulate/backup iterations,
    and returns the action leading to the best-valued root child.
    """

    def __init__(self, game, num_expansions=300, value_network=None):
        # game: mutated in place via set_state()/step() during search.
        self.game = game
        # Number of MCTS iterations performed per get_action() call.
        self.num_expansions = num_expansions
        # Optional callable(board_tensor) -> value estimate; when set,
        # it is blended 50/50 with the rollout result (see get_action()).
        self.value_network = value_network

        self.tree = Tree()

    def compute_heuristic(self, node_data):
        """Exploration score for a node.

        Score = sign-normalized value estimate plus an exploration bonus
        sqrt(N / (1 + n_visits)), where N is the root's visit count.
        NOTE(review): this differs from the textbook UCT bonus
        sqrt(log(N) / n) — presumably intentional; confirm.
        """
        value_est = node_data.value_est
        # Negate for PLAYER2 so that a larger score is always better
        # from the moving player's perspective.
        if (node_data.player == GameBase.Player.PLAYER2):
            value_est = -value_est
        # Total simulations so far = visit count of the root (index 0).
        num_simulations = self.tree.get_node_data(0).num_visits
        heuristic = value_est + np.sqrt(num_simulations /
                                        (1 + node_data.num_visits))
        return heuristic

    def select(self, idx):
        """Return the child of `idx` with the highest exploration score,
        or None when `idx` has no children (i.e. it is a leaf)."""
        children = self.tree.get_children(idx)
        if len(children) == 0:
            return None

        exploration_vals = []
        for child_idx in children:
            node_data = self.tree.get_node_data(child_idx)
            exp_val = self.compute_heuristic(node_data)
            exploration_vals.append(exp_val)

        new_idx = children[np.argmax(exploration_vals)]
        return new_idx

    def simulate(self, game_state, player):
        """Random playout from `game_state` until the game terminates.

        Returns +1 for a PLAYER1 win, -1 for a PLAYER2 win, 0 otherwise.
        NOTE(review): the `player` argument is unused here.
        """
        self.game.set_state(game_state)
        while (self.game.get_game_status() == GameBase.Status.IN_PROGRESS):
            actions = self.game.get_valid_actions()
            rand_idx = np.random.randint(len(actions))
            action = actions[rand_idx]
            self.game.step(action)

        status = self.game.get_game_status()
        if status == GameBase.Status.PLAYER1_WIN:
            return 1
        elif status == GameBase.Status.PLAYER2_WIN:
            return -1
        else:
            return 0

    def _check_tree(self, idx, level=0):
        """Debug helper: recursively verify that each child's prev_action
        is a valid action in its parent's game state.

        Prints diagnostics and returns False on the first violation,
        True if the whole subtree under `idx` is consistent.
        """
        data = self.tree.get_node_data(idx)
        self.game.set_state(data.game_state)
        children = self.tree.get_children(idx)
        actions = self.game.get_valid_actions()

        valid_list = []
        for child_idx in children:
            prev_action = self.tree.get_node_data(child_idx).prev_action
            if prev_action not in actions:
                # assumes the first 9 state entries form a 3x3 board
                # (tic-tac-toe) — TODO confirm for other games.
                print(
                    "========\nFailure:\n{}\nlevel: {}\nprev_action: {}\nidx: {}\n"
                    .format(data.game_state[:9].reshape((3, 3)), level,
                            prev_action, child_idx))
                return False
            valid_list.append(self._check_tree(child_idx, level + 1))

        valid = np.all(valid_list)
        return valid

    def get_action(self, game_state):
        """Run `num_expansions` MCTS iterations from `game_state` and
        return the best action for the current player.

        The existing tree is reused when `game_state` is found in it
        (the tree is rebased onto that node); otherwise it is rebuilt
        from scratch with `game_state` as the new root.
        """
        self.game.set_state(game_state)
        curr_player = self.game.get_curr_player()

        if self.tree.num_nodes() == 0:
            # add in root node
            initial_node = MCTSNodeData(value_est=0,
                                        num_visits=0,
                                        game_state=game_state,
                                        prev_action=None,
                                        player=curr_player)
            self.tree.insert_node(initial_node, None)
        else:
            # do breadth first search for a matching child
            def breadth_first_search(tree, game_state):
                # Returns the index of the first descendant of the root
                # whose stored state matches `game_state`, else None.
                child_list = []
                child_list += tree.get_children(0)

                found_idx = None
                child_list_idx = 0

                while child_list_idx < len(child_list):
                    child_idx = child_list[child_list_idx]
                    data = tree.get_node_data(child_idx)
                    if np.all(data.game_state == game_state):
                        found_idx = child_idx
                        break
                    child_list += tree.get_children(child_idx)
                    child_list_idx += 1
                return found_idx

            found_idx = breadth_first_search(self.tree, game_state)
            if found_idx is None:
                # clear tree
                self.tree.reset()

                # add in root node
                initial_node = MCTSNodeData(value_est=0,
                                            num_visits=0,
                                            game_state=game_state,
                                            prev_action=None,
                                            player=curr_player)
                self.tree.insert_node(initial_node, None)
            else:
                # prev_tree = copy.deepcopy(self.tree)
                # Reuse prior search effort: make the matching node the root.
                self.tree.rebase(found_idx)

        for _ in range(self.num_expansions):
            # select nodes in tree until leaf node
            curr_idx = 0
            idx_list = [curr_idx]
            curr_idx = self.select(curr_idx)
            while curr_idx is not None:
                idx_list.append(curr_idx)
                curr_idx = self.select(curr_idx)

            leaf_node_idx = idx_list[-1]

            # expand
            leaf_data = self.tree.get_node_data(leaf_node_idx)

            # this node has never been visited yet, don't expand, just simulate().
            # If it has been visited and is a leaf node, expand the node
            # and choose a new leaf node from the children actions to simulate.
            if leaf_data.num_visits != 0:
                self.game.set_state(leaf_data.game_state)
                actions = self.game.get_valid_actions()
                if len(actions) != 0:
                    for action in actions:
                        # Re-seat the game at the leaf before every step:
                        # step() mutates the game in place.
                        self.game.set_state(leaf_data.game_state)
                        self.game.step(action)
                        new_node_data = MCTSNodeData(
                            value_est=0,
                            num_visits=0,
                            game_state=self.game.get_state(),
                            prev_action=action,
                            player=self.game.get_curr_player())
                        self.tree.insert_node(new_node_data, leaf_node_idx)

                    leaf_node_idx = self.select(leaf_node_idx)
                    leaf_data = self.tree.get_node_data(leaf_node_idx)
                    idx_list.append(leaf_node_idx)

            # simulate and update values for every node visited
            value_est = self.simulate(leaf_data.game_state, curr_player)
            if self.value_network is not None:
                # assumes the first 9 state entries form a 3x3 board fed
                # to the network as (1, 3, 3, 1) — TODO confirm.
                board_state = leaf_data.game_state[:9].reshape((1, 3, 3, 1))
                board_state = np.array(board_state, dtype=np.float32)
                if leaf_data.value_net_cache is None:
                    # cache value_net_cache for later
                    leaf_data.value_net_cache = self.value_network(board_state)
                    self.tree.update_node_data(leaf_node_idx, leaf_data)
                value_est = (0.5) * value_est + (
                    0.5) * leaf_data.value_net_cache

            # Backup: incremental running average of value estimates
            # along the selected path, plus visit counts.
            for idx in idx_list:
                node_data = self.tree.get_node_data(idx)
                node_data.value_est = float(value_est + node_data.num_visits *
                                            node_data.value_est) / (
                                                node_data.num_visits + 1)
                node_data.num_visits += 1
                self.tree.update_node_data(idx, node_data)

        children = self.tree.get_children(0)

        # Pick the root child with the best value for the current player
        # (PLAYER1 maximizes, PLAYER2 minimizes).
        value_list = []
        for child_idx in children:
            child_data = self.tree.get_node_data(child_idx)
            value_list.append(child_data.value_est)

        if curr_player == GameBase.Player.PLAYER1:
            best_idx = np.argmax(value_list)
        else:
            best_idx = np.argmin(value_list)
        best_action = self.tree.get_node_data(children[best_idx]).prev_action
        return best_action
# Example #6
class MinMaxSearchTree(object):
    """
    MinMax Search Tree.

    Exhaustively expands the full game tree from a given state and
    annotates every node with its minimax value and optimal action.
    """
    def __init__(self, game):
        """
        ctor for min max search tree.

        Arguments:
            game {GameBase} -- A game object that will be used and mutated by the minmax tree search.
                Don't use the game outside of this class as it will be manipulated within here.
        """
        self.tree = Tree()
        self.game = game

    def search(self, game_state, player):
        """
        Perform MinMax search and return the search tree.
        The returned search tree will contain a tuple of data at each node.
        This tuple consists of (game_state, minmax_value, optimal_action)

        Arguments:
            game_state {np.array} -- state of the game as returned by ttt.get_state()
            player {which player to solve minmax tree for} -- PLAYER1 or PLAYER2
        """
        # clear out any previous searches
        self.tree.reset()

        # Insert the parent node
        root_idx = self.tree.insert_node(
            MinMaxNodeState(game_state, None, None), None)

        # Start expanding from parent node
        self._expand_node(root_idx, player)
        return self.tree

    def _expand_node(self, node_idx, player):
        """Recursively expand `node_idx`, filling in its minmax value
        and optimal action, and return that minmax value (from the
        perspective of `player`)."""
        # get possible actions
        node_data = self.tree.get_node_data(node_idx)
        self.game.set_state(node_data.game_state)
        curr_player = self.game.get_curr_player()
        actions = self.game.get_valid_actions()

        # If we have reached a leaf node, get the value and return
        # 1 for winning, -1 for losing, 0 for tie
        if len(actions) == 0:
            val = self.game.get_outcome(player)
            node_data.minmax_value = val
            self.tree.update_node_data(node_idx, node_data)
            return val

        # Recursively expand each child node
        # and collect the minmax values
        minmax_vals = []
        for action in actions:
            # Re-seat the game at this node's state before every step:
            # the recursion below mutates the shared game object.
            self.game.set_state(node_data.game_state)
            self.game.step(action)

            new_node_idx = self.tree.insert_node(
                MinMaxNodeState(self.game.get_state(), None, None), node_idx)
            val = self._expand_node(new_node_idx, player)
            minmax_vals.append(val)

        # Compute minimum or maximum of values depending on what level the
        # search is currently on
        # (maximize on the searching player's turns, minimize otherwise)
        if player == curr_player:
            val_idx = np.argmax(minmax_vals)
        else:
            val_idx = np.argmin(minmax_vals)
        val = minmax_vals[val_idx]
        opt_action = actions[val_idx]

        # update the expanded node with the value and optimal action
        node_data.minmax_value = val
        node_data.optimal_action = opt_action
        self.tree.update_node_data(node_idx, node_data)
        return val