# Imports used throughout this section.
import itertools
import logging
import math

import numpy as np
import pandas as pd

import boardgame2
from boardgame2 import BLACK, WHITE, strfboard


def self_play(env, agent, return_trajectory=False, verbose=False):
    if return_trajectory:
        trajectory = []
    observation = env.reset()
    for step in itertools.count():
        board, player = observation
        action, prob = agent.decide(observation, return_prob=True)
        if verbose:
            print(boardgame2.strfboard(board))
            logging.info('step {}: player {}, action {}'.format(step, player, action))
        observation, winner, done, _ = env.step(action)
        if return_trajectory:
            trajectory.append((player, board, prob))
        if done:
            if verbose:
                print(boardgame2.strfboard(observation[0]))
                logging.info('winner {}'.format(winner))
            break
    if return_trajectory:
        df_trajectory = pd.DataFrame(trajectory,
                columns=['player', 'board', 'prob'])
        df_trajectory['winner'] = winner
        return df_trajectory
    else:
        return winner
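For training, the function above can be called repeatedly and its trajectories concatenated. The sketch below is only an assumed usage example: `env` and `agent` are presumed to be an already constructed boardgame2 environment and an agent exposing the `decide` interface used above, and the game count is arbitrary.

# Assumed usage sketch: collect self-play data for training.
dfs = []
for _ in range(25):                                    # hypothetical number of self-play games
    dfs.append(self_play(env, agent, return_trajectory=True))
df_train = pd.concat(dfs, ignore_index=True)           # columns: player, board, prob, winner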
def run(self, episodes=1, black_first=True, verbose=False):
    black_win = 0
    white_win = 0
    for episode in range(1, episodes + 1):
        observation = self.env.reset()
        done = False
        turn = BLACK if black_first else WHITE
        for step in range(1, 1000):
            action = self.black_agent.decide(observation) if turn == BLACK \
                    else self.white_agent.decide(observation)
            observation, winner, done, info = self.env.step(action)
            if verbose:
                print('step {}: {} plays {}'.format(step,
                        'black' if turn == BLACK else 'white', action))
                print(boardgame2.strfboard(observation[0]))
            turn *= -1
            if done:
                break
        if winner == BLACK:
            black_win += 1
            print('episode {}: black wins!'.format(episode))
        else:  # note: a draw (winner == 0) is counted in the white tally here
            white_win += 1
            print('episode {}: white wins!'.format(episode))
    self.env.close()
    print('total {} episodes, black win rate: {}, white win rate: {}'.format(
            episode, black_win / episode, white_win / episode))
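The method above reads like part of an evaluation (arena) class holding `env`, `black_agent`, and `white_agent`; that class is not shown in this section. The sketch below is an assumed wrapper, with a hypothetical class name and constructor, showing how the method might be driven.

# Assumed sketch: the enclosing evaluator class and its constructor are not
# from the original code; they only illustrate how run() could be invoked.
class Arena:
    def __init__(self, env, black_agent, white_agent):
        self.env = env
        self.black_agent = black_agent
        self.white_agent = white_agent

Arena.run = run    # attach the run() method defined above

arena = Arena(env, black_agent, white_agent)    # env and the two agents assumed to exist
arena.run(episodes=10)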
def search(self, board, prior_noise=False):  # MCTS search
    s = boardgame2.strfboard(board)
    if s not in self.winner:
        self.winner[s] = self.env.get_winner((board, BLACK))  # determine the winner
    if self.winner[s] is not None:  # the winner is already decided
        return self.winner[s]
    if s not in self.policy:  # leaf node whose policy has not been evaluated yet
        pis, vs = self.net.predict(board[np.newaxis])
        pi, v = pis[0], vs[0]
        valid = self.env.get_valid((board, BLACK))
        masked_pi = pi * valid
        total_masked_pi = np.sum(masked_pi)
        if total_masked_pi <= 0:  # no valid action got any probability; can occasionally happen
            masked_pi = valid  # workaround
            total_masked_pi = np.sum(masked_pi)
        self.policy[s] = masked_pi / total_masked_pi
        self.valid[s] = valid
        return v

    # compute the PUCT upper bound
    count_sum = self.count[s].sum()
    coef = (self.c_init + np.log1p((1 + count_sum) / self.c_base)) * \
            math.sqrt(count_sum) / (1. + self.count[s])
    if prior_noise:  # exploration noise on the prior
        alpha = 1. / self.valid[s].sum()
        noise = np.random.gamma(alpha, 1., board.shape)
        noise *= self.valid[s]
        noise /= noise.sum()
        prior = (1. - self.prior_exploration_fraction) * self.policy[s] + \
                self.prior_exploration_fraction * noise
    else:
        prior = self.policy[s]
    ub = np.where(self.valid[s], self.q[s] + coef * prior, np.nan)
    location_index = np.nanargmax(ub)
    location = np.unravel_index(location_index, board.shape)
    (next_board, next_player), _, _, _ = self.env.next_step(
            (board, BLACK), np.array(location))
    next_canonical_board = next_player * next_board
    next_v = self.search(next_canonical_board)  # recursive search
    v = next_player * next_v
    self.count[s][location] += 1
    self.q[s][location] += (v - self.q[s][location]) / self.count[s][location]
    return v
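For reference, the upper bound computed above corresponds to the standard PUCT rule, writing N for visit counts, Q for action values, and P for the (possibly noise-perturbed) prior:

$$U(s,a) = Q(s,a) + \left(c_{\text{init}} + \log\frac{1 + \sum_b N(s,b) + c_{\text{base}}}{c_{\text{base}}}\right) P(s,a)\,\frac{\sqrt{\sum_b N(s,b)}}{1 + N(s,a)}$$

The optional prior noise is a Dirichlet-style perturbation: Gamma(alpha, 1) samples with alpha = 1 / (number of valid moves), masked to the valid locations, normalized, and mixed into the prior with weight prior_exploration_fraction.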
def decide(self, observation, greedy=False, return_prob=False):
    # compute the policy from repeated MCTS searches
    board, player = observation
    canonical_board = player * board
    s = boardgame2.strfboard(canonical_board)
    while self.count[s].sum() < self.sim_count:  # run MCTS until enough simulations
        self.search(canonical_board, prior_noise=True)
    prob = self.count[s] / self.count[s].sum()

    # sample an action from the visit-count distribution
    location_index = np.random.choice(prob.size, p=prob.reshape(-1))
    location = np.unravel_index(location_index, prob.shape)
    if return_prob:
        return location, prob
    return location
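The `greedy` flag is accepted but not used in the body shown above; a deterministic variant would typically replace the sampling step with an argmax over the visit-count distribution. The helper below is an assumed sketch (`pick_location` is not part of the original code):

# Assumed helper: choose a board location from the visit-count distribution,
# either greedily (most-visited move) or by sampling.
def pick_location(prob, greedy=False):
    if greedy:
        location_index = np.argmax(prob.reshape(-1))
    else:
        location_index = np.random.choice(prob.size, p=prob.reshape(-1))
    return np.unravel_index(location_index, prob.shape)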
def hash_convert(self, state):  # turn a (board, player) state into a hashable key
    board, player = state
    return (strfboard(board), player)
def __contains__(self, state):
    board, player = state
    board_ = board if isinstance(board, str) else strfboard(board)
    return super().__contains__((board_, player))
def __setitem__(self, state, value):
    board, player = state
    board_ = board if isinstance(board, str) else strfboard(board)
    return super().__setitem__((board_, player), value)
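The three methods above presumably belong to a dict subclass that keys entries by the string form of the board, so that (numpy board, player) states can be used directly as dictionary keys. The sketch below shows such an enclosing class; the class name and the `__getitem__` counterpart are assumptions for illustration only.

# Assumed sketch: an enclosing dict subclass for the methods above.
# StateDict and __getitem__ are illustrative additions, not from the original.
class StateDict(dict):
    def hash_convert(self, state):
        board, player = state
        return (strfboard(board), player)

    def __contains__(self, state):
        board, player = state
        board_ = board if isinstance(board, str) else strfboard(board)
        return super().__contains__((board_, player))

    def __setitem__(self, state, value):
        board, player = state
        board_ = board if isinstance(board, str) else strfboard(board)
        return super().__setitem__((board_, player), value)

    def __getitem__(self, state):
        board, player = state
        board_ = board if isinstance(board, str) else strfboard(board)
        return super().__getitem__((board_, player))

empty_board = np.zeros((3, 3), dtype=int)    # hypothetical 3x3 board
values = StateDict()
values[(empty_board, BLACK)] = 0.
print((empty_board, BLACK) in values)        # True
print(values[(empty_board, BLACK)])          # 0.0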