class Node:
    """Game-tree node carrying playout counters and links to its parents.

    `parents` is a `BoardCache` that stores, per parent position, the node
    belonging to that position. `visits`, `wins`, `losses` and `draws` are
    plain integer counters that start at zero.
    """

    def __init__(self):
        self.parents = BoardCache()
        self.visits = self.wins = self.losses = self.draws = 0

    def add_parent_node(self, node_cache, parent_board):
        """Record the node for `parent_board` as a parent of this node.

        Nothing happens unless the lookup in `self.parents` reports exactly
        `False` for "found"; in that case the parent node is obtained through
        `find_or_create_node(node_cache, parent_board)` and stored.
        """
        _, found = self.parents.get_for_position(parent_board)
        if found is not False:
            return
        parent = find_or_create_node(node_cache, parent_board)
        self.parents.set_for_position(parent_board, parent)

    def get_total_visits_for_parent_nodes(self):
        """Return the sum of `visits` over all nodes stored in `self.parents`."""
        total = 0
        for parent in self.parents.cache.values():
            total += parent.visits
        return total

    def value(self):
        """Return `(wins + draws) / visits`, or 0 when `visits` is 0."""
        if self.visits == 0:
            return 0
        return (self.wins + self.draws) / self.visits
class QTable:
    """Q-value table keyed by the position reached after playing a move.

    Entries read back from the underlying `BoardCache` are unpacked as a
    `(qvalue, transform)` pair.
    """

    def __init__(self):
        self.qtable = BoardCache()

    def get_q_values(self, board):
        """Return a dict mapping every valid move index of `board` to its Q-value."""
        # Ask for the valid moves exactly once: the same sequence supplies the
        # keys and drives the lookups, so keys and values cannot get out of step.
        move_indexes = board.get_valid_move_indexes()
        return {mi: self.get_q_value(board, mi) for mi in move_indexes}

    def get_q_value(self, board, move_index):
        """Return the Q-value for playing `move_index` on `board`.

        Falls back to `get_initial_q_value` for the resulting position when
        the cache has no entry for it.
        """
        new_position = board.play_move(move_index)
        result, found = self.qtable.get_for_position(new_position)
        if found is True:
            qvalue, _ = result
            return qvalue

        return get_initial_q_value(new_position)

    def update_q_value(self, board, move_index, qvalue):
        """Store `qvalue` for the position reached by `move_index` on `board`."""
        new_position = board.play_move(move_index)

        result, found = self.qtable.get_for_position(new_position)
        if found is True:
            # The lookup also yields a transform; re-store under the board
            # obtained by applying it, rather than under `new_position` itself.
            _, t = result
            new_position_transformed = Board(
                t.transform(new_position.board_2d).flatten())
            self.qtable.set_for_position(new_position_transformed, qvalue)
            return

        self.qtable.set_for_position(new_position, qvalue)

    def get_move_index_and_max_q_value(self, board):
        """Return the `(move_index, qvalue)` pair with the largest Q-value."""
        q_values = self.get_q_values(board)
        return max(q_values.items(), key=operator.itemgetter(1))

    def print_q_values(self):
        """Print the number of cached entries, then each board with its value."""
        print(f"num q_values = {len(self.qtable.cache)}")
        for k, v in self.qtable.cache.items():
            # assumes cache keys are the raw bytes of an int board array — TODO confirm
            b = np.frombuffer(k, dtype=int)
            board = Board(b)
            board.print_board()
            print(f"qvalue = {v}")
class QTable:
    """Q-value table that stores, per position, a dict of move index -> Q-value.

    Cache lookups yield a `(qvalues, transform)` pair; the transform is passed
    to `reverse_transform_qvalues` / `transform_board_and_qvalues` to convert
    between the requested board and the stored one.
    """

    def __init__(self):
        self.qtable = BoardCache()

    def get_q_values(self, board):
        """Return the move index -> Q-value dict for `board`.

        On a cache miss, a dict giving every valid move the initial Q-value is
        created, stored for `board`, and returned.
        """
        cached, found = self.qtable.get_for_position(board)
        if not found:
            moves = board.get_valid_move_indexes()
            start_value = get_initial_q_value(board)
            fresh = dict(zip(moves, [start_value for _ in moves]))
            self.qtable.set_for_position(board, fresh)
            return fresh

        stored_qvalues, transform = cached
        return reverse_transform_qvalues(stored_qvalues, transform)

    def get_q_value(self, board, move_index):
        """Return the Q-value of `move_index` on `board`."""
        all_values = self.get_q_values(board)
        return all_values[move_index]

    def update_q_value(self, board, move_index, qvalue):
        """Set the Q-value of `move_index` on `board` and write it back."""
        updated = self.get_q_values(board)
        updated[move_index] = qvalue

        # get_q_values stores an entry on a miss, so the lookup must now hit.
        entry, found = self.qtable.get_for_position(board)
        assert found is True, "position must be cached at this point"
        _, transform = entry

        stored_board, stored_qvalues = transform_board_and_qvalues(
            board, updated, transform)
        self.qtable.set_for_position(stored_board, stored_qvalues)

    def get_move_index_and_max_q_value(self, board):
        """Return the `(move_index, qvalue)` pair with the largest Q-value."""
        by_qvalue = operator.itemgetter(1)
        candidates = self.get_q_values(board).items()
        return max(candidates, key=by_qvalue)
def test_play_mcts_move():
    """Run playouts from a fixed position and check node stats and move values."""
    grid = np.array([[1, 1, 0],
                     [1, -1, 0],
                     [-1, 1, -1]])
    board = Board(grid.flatten())
    nc = BoardCache()

    def stats(node):
        # Same tuple order throughout: visits, wins, draws, losses.
        return (node.visits, node.wins, node.draws, node.losses)

    def current_values():
        return list(calculate_values(nc, board))

    parent_node = find_or_create_node(nc, board)
    assert stats(parent_node) == (0, 0, 0, 0)
    assert current_values() == [(2, math.inf), (5, math.inf)]

    # First single playout.
    perform_game_playout(nc, board)
    assert stats(parent_node) == (1, 0, 0, 1)

    child_node_2 = find_or_create_node(nc, board.play_move(2))
    assert stats(child_node_2) == (1, 1, 0, 0)

    child_node_5 = find_or_create_node(nc, board.play_move(5))
    assert stats(child_node_5) == (0, 0, 0, 0)

    assert current_values() == [(2, 1.0), (5, math.inf)]

    # Second single playout.
    perform_game_playout(nc, board)
    assert stats(parent_node) == (2, 1, 0, 1)
    assert stats(child_node_2) == (1, 1, 0, 0)
    assert stats(child_node_5) == (1, 0, 0, 1)
    assert current_values() == [(2, 2.177410022515475),
                                (5, 1.1774100225154747)]

    # One hundred further training playouts.
    perform_training_playouts(nc, board, 100, False)
    assert stats(parent_node) == (102, 6, 0, 96)
    assert stats(child_node_2) == (96, 96, 0, 0)
    assert stats(child_node_5) == (6, 0, 0, 6)
    assert current_values() == [(2, 1.3104087632087014),
                                (5, 1.2416350528348057)]
def __init__(self):
    """Initialise `self.qtable` with a new `BoardCache`."""
    self.qtable = BoardCache()
def __init__(self):
    """Start with a new `BoardCache` for parents and every counter at zero."""
    self.parents = BoardCache()
    # A single chained assignment zeroes all four counters.
    self.visits = self.wins = self.losses = self.draws = 0
import math

from tictac.board import play_game
from tictac.board import (Board, BoardCache, CELL_X, CELL_O, RESULT_X_WINS,
                          RESULT_O_WINS, is_draw)

# presumably the shared position -> Node cache; verify against its users
nodecache = BoardCache()


class Node:
    """Node holding visit/win/loss/draw counters and a cache of parent nodes."""

    def __init__(self):
        """Create a new parents cache and set every counter to zero."""
        self.parents = BoardCache()
        self.visits = 0
        self.wins = 0
        self.losses = 0
        self.draws = 0

    def add_parent_node(self, node_cache, parent_board):
        """Store the node for `parent_board` in `self.parents` if it is absent.

        The parent node comes from `find_or_create_node(node_cache,
        parent_board)`; nothing is done unless the lookup reports `False`.
        """
        result, found = self.parents.get_for_position(parent_board)
        if found is False:
            parent_node = find_or_create_node(node_cache, parent_board)
            self.parents.set_for_position(parent_board, parent_node)

    def get_total_visits_for_parent_nodes(self):
        """Return the sum of `visits` over all nodes stored in `self.parents`."""
        return sum([parent_node.visits
                    for parent_node in self.parents.cache.values()])

    def value(self):
        # A node with zero visits is valued at 0.
        if self.visits == 0:
            return 0
import random

from tictac.board import BoardCache
from tictac.board import CELL_O
from tictac.board import is_empty

# presumably caches evaluated positions for minimax; verify against its users
cache = BoardCache()


def create_minimax_player(randomize):
    """Return a `play(board)` function that calls `play_minimax_move` with `randomize`."""
    def play(board):
        return play_minimax_move(board, randomize)

    return play


def play_minimax_move(board, randomize=False):
    """Choose a move for `board` and return the result of `board.play_move`.

    Move/value pairs come from `get_move_value_pairs`; `filter_best_move`
    picks the move and receives `randomize` unchanged.
    """
    move_value_pairs = get_move_value_pairs(board)
    move = filter_best_move(board, move_value_pairs, randomize)

    return board.play_move(move)


def get_move_value_pairs(board):
    # Pair each valid move index with `get_position_value` of the position
    # that playing it produces.
    valid_move_indexes = board.get_valid_move_indexes()

    # There must be at least one valid move to evaluate.
    assert not is_empty(valid_move_indexes), "never call with an end position"

    move_value_pairs = [(m, get_position_value(board.play_move(m)))
                        for m in valid_move_indexes]