Example #1
0
class QTable:
    """Q-value table keyed by the position reached after playing a move.

    ``self.qtable`` is a ``BoardCache``; ``get_for_position`` returns a
    ``(result, found)`` pair, and when ``found`` is true ``result`` unpacks
    as ``(qvalue, transform)``.
    """

    def __init__(self):
        self.qtable = BoardCache()

    def get_q_values(self, board):
        """Return a dict mapping each valid move index on *board* to its q-value."""
        # Ask for the valid moves only once so keys and values are built from
        # the same sequence (and the work is not repeated).
        move_indexes = board.get_valid_move_indexes()
        return {
            move_index: self.get_q_value(board, move_index)
            for move_index in move_indexes
        }

    def get_q_value(self, board, move_index):
        """Return the q-value of playing *move_index* on *board*.

        The value is looked up under the resulting position; positions not
        yet cached fall back to ``get_initial_q_value(new_position)``.
        """
        new_position = board.play_move(move_index)
        result, found = self.qtable.get_for_position(new_position)
        if not found:
            return get_initial_q_value(new_position)

        qvalue, _ = result
        return qvalue

    def update_q_value(self, board, move_index, qvalue):
        """Store *qvalue* for the position reached by playing *move_index* on *board*."""
        new_position = board.play_move(move_index)

        result, found = self.qtable.get_for_position(new_position)
        if found:
            # Re-orient the position with the transform returned by the cache
            # before storing; presumably this maps it onto the orientation of
            # the existing entry so that entry is overwritten — confirm
            # against BoardCache.
            _, transform = result
            new_position = Board(
                transform.transform(new_position.board_2d).flatten())

        self.qtable.set_for_position(new_position, qvalue)

    def get_move_index_and_max_q_value(self, board):
        """Return ``(move_index, qvalue)`` for the move with the highest q-value.

        Raises:
            ValueError: if *board* has no valid moves (``max`` of an empty
                sequence).
        """
        q_values = self.get_q_values(board)
        return max(q_values.items(), key=operator.itemgetter(1))

    def print_q_values(self):
        """Print every cached position with its stored q-value."""
        print(f"num q_values = {len(self.qtable.cache)}")
        for k, v in self.qtable.cache.items():
            # Cache keys are raw buffers; rebuild the board array from them.
            b = np.frombuffer(k, dtype=int)
            board = Board(b)
            board.print_board()
            print(f"qvalue = {v}")
class QTable:
    """Q-value table storing, per position, a dict of move index -> q-value.

    Entries live in ``self.qtable`` (a ``BoardCache``). A successful lookup
    yields ``(qvalues, transform)``; ``reverse_transform_qvalues`` maps the
    stored q-values back onto the queried board, and
    ``transform_board_and_qvalues`` maps them forward when writing.
    (Presumably the transform is a board symmetry — confirm in BoardCache.)
    """

    def __init__(self):
        self.qtable = BoardCache()

    def get_q_values(self, board):
        """Return a dict mapping each valid move index of *board* to its q-value.

        On a cache miss, every valid move is seeded with
        ``get_initial_q_value(board)`` and that dict is cached; the caller
        receives a copy, so mutating the result does not silently edit the
        cache. Use ``update_q_value`` to persist changes.
        """
        result, found = self.qtable.get_for_position(board)
        if found:
            qvalues, transform = result
            return reverse_transform_qvalues(qvalues, transform)

        valid_move_indexes = board.get_valid_move_indexes()
        initial_q_value = get_initial_q_value(board)
        qvalues = dict.fromkeys(valid_move_indexes, initial_q_value)

        self.qtable.set_for_position(board, qvalues)

        # Returning the cached dict itself would alias the cache entry:
        # any caller mutation would change the stored q-values.
        return dict(qvalues)

    def get_q_value(self, board, move_index):
        """Return the q-value of *move_index* on *board* (``KeyError`` if not a valid move)."""
        return self.get_q_values(board)[move_index]

    def update_q_value(self, board, move_index, qvalue):
        """Set the q-value of *move_index* on *board* to *qvalue* and write it back."""
        qvalues = self.get_q_values(board)
        qvalues[move_index] = qvalue

        # get_q_values() guarantees an entry exists; fetch its transform so
        # the updated q-values are written in the cached entry's orientation.
        result, found = self.qtable.get_for_position(board)
        assert found, "position must be cached at this point"
        _, transform = result

        transformed_board, transformed_qvalues = transform_board_and_qvalues(
            board, qvalues, transform)

        self.qtable.set_for_position(transformed_board, transformed_qvalues)

    def get_move_index_and_max_q_value(self, board):
        """Return ``(move_index, qvalue)`` for the move with the highest q-value.

        Raises:
            ValueError: if *board* has no valid moves (``max`` of an empty
                sequence).
        """
        q_values = self.get_q_values(board)
        return max(q_values.items(), key=operator.itemgetter(1))
class Node:
    """Search-tree node holding visit/outcome counters and its parent nodes.

    ``parents`` is a ``BoardCache`` mapping parent positions to parent
    ``Node`` objects.
    """

    def __init__(self):
        self.parents = BoardCache()
        self.visits = 0
        self.wins = 0
        self.losses = 0
        self.draws = 0

    def add_parent_node(self, node_cache, parent_board):
        """Record the node for *parent_board* as a parent unless already present.

        The parent node is obtained from ``find_or_create_node(node_cache,
        parent_board)``.
        """
        _, found = self.parents.get_for_position(parent_board)
        if not found:
            parent_node = find_or_create_node(node_cache, parent_board)
            self.parents.set_for_position(parent_board, parent_node)

    def get_total_visits_for_parent_nodes(self):
        """Return the sum of ``visits`` over all recorded parent nodes."""
        return sum(parent_node.visits
                   for parent_node in self.parents.cache.values())

    def value(self):
        """Return ``(wins + draws) / visits``, or 0 if the node has no visits."""
        if self.visits == 0:
            return 0

        success_percentage = (self.wins + self.draws) / self.visits
        return success_percentage