コード例 #1
0
class Node:
    """MCTS statistics for one position plus links to its parent nodes.

    ``parents`` is a ``BoardCache`` holding parent positions and their nodes.
    The integer counters record how often the node was visited and the
    outcomes observed.
    """

    def __init__(self):
        self.parents = BoardCache()
        self.visits = self.wins = self.losses = self.draws = 0

    def add_parent_node(self, node_cache, parent_board):
        """Link the node for *parent_board* as a parent unless already linked.

        The parent node is fetched (or created) in *node_cache* through
        ``find_or_create_node`` and stored in ``self.parents``.
        """
        _, already_linked = self.parents.get_for_position(parent_board)
        if already_linked is not False:
            return
        self.parents.set_for_position(
            parent_board, find_or_create_node(node_cache, parent_board))

    def get_total_visits_for_parent_nodes(self):
        """Return the sum of ``visits`` over every linked parent node."""
        return sum(parent.visits for parent in self.parents.cache.values())

    def value(self):
        """Return ``(wins + draws) / visits``, or 0 for an unvisited node."""
        if self.visits == 0:
            return 0
        # Draws count as successes alongside wins.
        return (self.wins + self.draws) / self.visits
コード例 #2
0
ファイル: qtable.py プロジェクト: sayedkamal2016/tictac
class QTable:
    """Q-value table keyed by the position that each move leads to.

    Values are stored in a ``BoardCache``. ``get_for_position`` returns a
    ``(result, found)`` pair; when ``found`` is true, ``result`` unpacks as
    ``(value, transform)``.
    """

    def __init__(self):
        self.qtable = BoardCache()

    def get_q_values(self, board):
        """Return ``{move_index: q_value}`` for every valid move on *board*."""
        # Query the valid moves once and reuse that sequence for both keys
        # and values. Calling get_valid_move_indexes() twice did redundant
        # work and relied on both calls yielding the same moves in order.
        move_indexes = board.get_valid_move_indexes()
        return {mi: self.get_q_value(board, mi) for mi in move_indexes}

    def get_q_value(self, board, move_index):
        """Return the Q-value of playing *move_index* on *board*.

        Falls back to ``get_initial_q_value`` for the resulting position when
        the cache has no entry for it.
        """
        new_position = board.play_move(move_index)
        result, found = self.qtable.get_for_position(new_position)
        if found:
            qvalue, _ = result
            return qvalue

        return get_initial_q_value(new_position)

    def update_q_value(self, board, move_index, qvalue):
        """Store *qvalue* for the position reached by *move_index* on *board*."""
        new_position = board.play_move(move_index)

        result, found = self.qtable.get_for_position(new_position)
        if found:
            # Re-orient the position with the transform the cache returned
            # before storing. Presumably the existing entry is keyed by a
            # symmetric orientation; TODO confirm against BoardCache.
            _, t = result
            new_position_transformed = Board(
                t.transform(new_position.board_2d).flatten())
            self.qtable.set_for_position(new_position_transformed, qvalue)
            return

        self.qtable.set_for_position(new_position, qvalue)

    def get_move_index_and_max_q_value(self, board):
        """Return ``(move_index, q_value)`` for the highest Q-value on *board*.

        Ties go to the first move in ``get_valid_move_indexes()`` order,
        because ``max`` keeps the first maximal item. Raises ``ValueError``
        when the board has no valid moves.
        """
        q_values = self.get_q_values(board)
        return max(q_values.items(), key=operator.itemgetter(1))

    def print_q_values(self):
        """Print the number of cached entries, then each position and entry."""
        print(f"num q_values = {len(self.qtable.cache)}")
        for k, v in self.qtable.cache.items():
            # Assumes cache keys are the raw bytes of an int array; TODO
            # confirm against BoardCache's key encoding.
            b = np.frombuffer(k, dtype=int)
            board = Board(b)
            board.print_board()
            print(f"qvalue = {v}")
コード例 #3
0
class QTable:
    """Q-table holding, per position, a dict of move index -> Q-value.

    Entries live in a ``BoardCache``. ``get_for_position`` returns a
    ``(result, found)`` pair; when ``found`` is true, ``result`` unpacks as
    ``(qvalues, transform)``.
    """

    def __init__(self):
        self.qtable = BoardCache()

    def get_q_values(self, board):
        """Return the ``{move_index: q_value}`` dict for *board*.

        On a cache hit, the stored dict and transform are passed through
        ``reverse_transform_qvalues``. Presumably this maps them back to
        *board*'s orientation; TODO confirm. On a miss, every valid move is
        seeded with ``get_initial_q_value(board)`` and the new dict is cached.
        In that case the returned dict is the cached object itself.
        """
        cached, hit = self.qtable.get_for_position(board)
        if hit:
            stored_qvalues, stored_transform = cached
            return reverse_transform_qvalues(stored_qvalues, stored_transform)

        moves = board.get_valid_move_indexes()
        seed = get_initial_q_value(board)
        fresh = {move: seed for move in moves}
        self.qtable.set_for_position(board, fresh)
        return fresh

    def get_q_value(self, board, move_index):
        """Return the Q-value of *move_index* on *board* (KeyError if absent)."""
        q_values = self.get_q_values(board)
        return q_values[move_index]

    def update_q_value(self, board, move_index, qvalue):
        """Set the Q-value of *move_index* on *board* and write it back.

        The updated dict is converted with ``transform_board_and_qvalues``
        using the transform the cache reports for *board*. The result is
        stored under the returned board.
        """
        updated = self.get_q_values(board)
        updated[move_index] = qvalue

        cached, hit = self.qtable.get_for_position(board)
        # get_q_values() above guarantees an entry exists for this board.
        assert hit is True, "position must be cached at this point"
        _, stored_transform = cached

        stored_board, stored_qvalues = transform_board_and_qvalues(
            board, updated, stored_transform)
        self.qtable.set_for_position(stored_board, stored_qvalues)

    def get_move_index_and_max_q_value(self, board):
        """Return ``(move_index, q_value)`` for the best move on *board*.

        Ties resolve to the first maximal move in the dict's order.
        """
        q_values = self.get_q_values(board)
        best_move = max(q_values, key=q_values.__getitem__)
        return best_move, q_values[best_move]
コード例 #4
0
def test_play_mcts_move():
    """Drive MCTS playouts from a position with exactly two legal replies.

    The board's only empty cells are indexes 2 and 5. The test checks the
    (visits, wins, draws, losses) counters of the root and both children, and
    the per-move values from ``calculate_values``. It does so after one
    playout, after a second playout, and after 100 more training playouts.
    """

    def counters(node):
        # Same field order as the expected tuples below: draws before losses.
        return (node.visits, node.wins, node.draws, node.losses)

    board = Board(np.array([[1, 1, 0], [1, -1, 0], [-1, 1, -1]]).flatten())
    node_cache = BoardCache()

    # Fresh root: nothing explored yet, so both moves are valued at inf.
    root = find_or_create_node(node_cache, board)
    assert counters(root) == (0, 0, 0, 0)
    assert list(calculate_values(node_cache, board)) == [
        (2, math.inf), (5, math.inf)]

    # The first playout goes through move 2.
    perform_game_playout(node_cache, board)
    assert counters(root) == (1, 0, 0, 1)

    child_2 = find_or_create_node(node_cache, board.play_move(2))
    assert counters(child_2) == (1, 1, 0, 0)

    child_5 = find_or_create_node(node_cache, board.play_move(5))
    assert counters(child_5) == (0, 0, 0, 0)

    assert list(calculate_values(node_cache, board)) == [
        (2, 1.0), (5, math.inf)]

    # The second playout takes the still-unvisited move 5.
    perform_game_playout(node_cache, board)
    for node, expected in ((root, (2, 1, 0, 1)),
                           (child_2, (1, 1, 0, 0)),
                           (child_5, (1, 0, 0, 1))):
        assert counters(node) == expected

    assert list(calculate_values(node_cache, board)) == [
        (2, 2.177410022515475), (5, 1.1774100225154747)]

    # After 100 more playouts, move 2 holds the bulk of the visits.
    perform_training_playouts(node_cache, board, 100, False)
    for node, expected in ((root, (102, 6, 0, 96)),
                           (child_2, (96, 96, 0, 0)),
                           (child_5, (6, 0, 0, 6))):
        assert counters(node) == expected

    assert list(calculate_values(node_cache, board)) == [
        (2, 1.3104087632087014), (5, 1.2416350528348057)]
コード例 #5
0
 def __init__(self):
     """Create the table with an empty ``BoardCache`` as its storage."""
     storage = BoardCache()
     self.qtable = storage
コード例 #6
0
 def __init__(self):
     """Start with no linked parents and every counter at zero."""
     self.parents = BoardCache()
     self.visits = self.wins = self.losses = self.draws = 0
コード例 #7
0
import math

from tictac.board import play_game
from tictac.board import (Board, BoardCache, CELL_X, CELL_O, RESULT_X_WINS,
                          RESULT_O_WINS, is_draw)

# Module-level node cache. It is not referenced in the visible part of this
# module; presumably it is the shared cache for MCTS nodes (TODO confirm).
nodecache: BoardCache = BoardCache()


class Node:
    """MCTS statistics for one position plus links to its parent nodes.

    ``parents`` is a ``BoardCache`` holding parent positions and their nodes.
    The integer counters record visits and observed outcomes.
    """

    def __init__(self):
        self.parents = BoardCache()
        self.visits = 0
        self.wins = 0
        self.losses = 0
        self.draws = 0

    def add_parent_node(self, node_cache, parent_board):
        """Link the node for *parent_board* as a parent unless already linked.

        The parent node is fetched (or created) in *node_cache* through
        ``find_or_create_node``.
        """
        _, found = self.parents.get_for_position(parent_board)
        if found is False:
            parent_node = find_or_create_node(node_cache, parent_board)
            self.parents.set_for_position(parent_board, parent_node)

    def get_total_visits_for_parent_nodes(self):
        """Return the sum of ``visits`` over every linked parent node."""
        return sum(parent_node.visits
                   for parent_node in self.parents.cache.values())

    def value(self):
        """Return the success rate ``(wins + draws) / visits``.

        Returns 0 for a node that has never been visited, avoiding a division
        by zero.
        """
        if self.visits == 0:
            return 0

        # Draws count as successes alongside wins.
        success_percentage = (self.wins + self.draws) / self.visits
        return success_percentage
コード例 #8
0
import random

from tictac.board import BoardCache
from tictac.board import CELL_O
from tictac.board import is_empty

# Module-level BoardCache. Its use (presumably memoising position values for
# minimax) is outside the visible part of this module; TODO confirm.
cache: BoardCache = BoardCache()


def create_minimax_player(randomize):
    """Build a player callable that answers each board with a minimax move.

    The returned ``play(board)`` forwards *randomize* to
    ``play_minimax_move`` and returns the position that call produces.
    """
    def play(board):
        next_position = play_minimax_move(board, randomize=randomize)
        return next_position

    return play


def play_minimax_move(board, randomize=False):
    """Play the move chosen by ``filter_best_move`` and return the new position.

    Candidate ``(move, value)`` pairs come from ``get_move_value_pairs``.
    *randomize* is passed through to ``filter_best_move``.
    """
    chosen_move = filter_best_move(
        board, get_move_value_pairs(board), randomize)
    return board.play_move(chosen_move)


def get_move_value_pairs(board):
    valid_move_indexes = board.get_valid_move_indexes()

    assert not is_empty(valid_move_indexes), "never call with an end position"

    move_value_pairs = [(m, get_position_value(board.play_move(m)))
                        for m in valid_move_indexes]