class PolicyNetworkBestMovePlayer(GtpInterface):
    def __init__(self, read_file):
        self.policy_network = PolicyNetwork(DEFAULT_FEATURES.planes, use_cpu=True)
        self.read_file = read_file
        super().__init__()

    def clear(self):
        super().clear()
        self.refresh_network()

    def refresh_network(self):
        # Ensure that the player is using the latest version of the network
        # so that the network can be continually trained even as it's playing.
        self.policy_network.initialize_variables(self.read_file)

    def suggest_move(self, position):
        if position.recent and position.n > 100 and position.recent[-1] is None:
            # Pass if the opponent passes
            return None
        move_probabilities = self.policy_network.run(position)
        for move in sorted_moves(move_probabilities):
            if go.is_reasonable(position, move):
                return move
        return None
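Both this player and the MCTS player below rank candidate moves with sorted_moves, which isn't shown in this section. Here is a minimal sketch of what it has to do, assuming the network's output is an N x N array of move probabilities (the helper's body is an assumption reconstructed from how it is called, not code from the project):

import go

def sorted_moves(probability_board):
    # Enumerate every (row, col) coordinate on the board and order them from
    # most to least probable according to the policy network's output.
    coords = [(a, b) for a in range(go.N) for b in range(go.N)]
    return sorted(coords, key=lambda c: probability_board[c], reverse=True)

Because the list comes back best-first, the first move that survives the go.is_reasonable filter is the network's strongest acceptable suggestion.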
def save_trained_policy(t, save_dir):
    # The opening of this function is truncated in the source. From what
    # survives: it collects the trained network's variables as (name, v)
    # pairs inside a session, re-creates each one under the 'PlayerNetwork'
    # scope, and checkpoints the copies for use by the self-play opponent.
    new_vars = []
    for name, v in ...:  # truncated: enumeration of the trained variables
        new_vars.append(tf.Variable(
            v, name=name.replace('PolicNetwork', 'PlayerNetwork')))
    saver = tf.train.Saver(new_vars)
    sess.run(tf.global_variables_initializer())  # 'sess' is defined in the truncated part
    saver.save(sess, os.path.join(save_dir, str(t), 'player' + str(t) + '.ckpt'))

g1 = tf.Graph()
with g1.as_default():
    train_net = PolicyNetwork(scope="PolicNetwork")
    train_net.initialize_variables('model/sl/epoch_48.ckpt')
    pos = go.Position()
    train_net.run(pos)

g2 = tf.Graph()
with g2.as_default():
    player_net = PolicyNetwork(scope="PlayerNetwork")
    player_net.initialize_variables('model/rl/2/player2.ckpt')
    pos = go.Position()
    player_net.run(pos)

save_trained_policy(1, 'model/rl')

print("===========load new model=================")
g2 = tf.Graph()
with g2.as_default():
    player_net = PolicyNetwork(scope="PlayerNetwork")
    player_net.initialize_variables('model/rl/5/player5.ckpt')
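The two tf.Graph() objects are the important detail here: the training network and the frozen opponent define identically named ops, so giving each its own graph (and its own variable scope) is what lets the player reload a fresh checkpoint without colliding with the trainer's variables. To verify that save_trained_policy really rewrote the scope, one can list the variables stored in the checkpoint it wrote. A quick sketch, assuming TensorFlow 1.x and the save_trained_policy(1, 'model/rl') call above:

import tensorflow as tf

# Every variable in the new checkpoint should now be prefixed with
# 'PlayerNetwork/' instead of the training scope 'PolicNetwork/'.
reader = tf.train.NewCheckpointReader('model/rl/1/player1.ckpt')
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)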
class MCTS(GtpInterface):
    def __init__(self, read_file, seconds_per_move=5):
        self.seconds_per_move = seconds_per_move
        self.max_rollout_depth = go.N * go.N * 3
        self.policy_network = PolicyNetwork(DEFAULT_FEATURES.planes, use_cpu=True)
        self.read_file = read_file
        super().__init__()

    def clear(self):
        super().clear()
        self.refresh_network()

    def refresh_network(self):
        # Ensure that the player is using the latest version of the network
        # so that the network can be continually trained even as it's playing.
        self.policy_network.initialize_variables(self.read_file)

    def suggest_move(self, position):
        if position.caps[0] + 50 < position.caps[1]:
            return gtp.RESIGN
        start = time.time()
        move_probs = self.policy_network.run(position)
        root = MCTSNode.root_node(position, move_probs)
        while time.time() - start < self.seconds_per_move:
            self.tree_search(root)
        # there's a theoretical bug here: if you refuse to pass, this AI will
        # eventually start filling in its own eyes.
        return max(root.children.keys(), key=lambda move, root=root: root.children[move].N)

    def tree_search(self, root):
        print("tree search", file=sys.stderr)
        # selection
        chosen_leaf = root.select_leaf()
        # expansion
        position = chosen_leaf.compute_position()
        if position is None:
            print("illegal move!", file=sys.stderr)
            # See go.Position.play_move for notes on detecting legality
            del chosen_leaf.parent.children[chosen_leaf.move]
            return
        print("Investigating following position:\n%s" % (chosen_leaf.position,), file=sys.stderr)
        move_probs = self.policy_network.run(position)
        chosen_leaf.expand(move_probs)
        # evaluation
        value = self.estimate_value(chosen_leaf)
        # backup
        print("value: %s" % value, file=sys.stderr)
        chosen_leaf.backup_value(value)

    def estimate_value(self, chosen_leaf):
        # Estimate value of position using rollout only (for now).
        # (TODO: Value network; average the value estimations from rollout + value network)
        leaf_position = chosen_leaf.position
        current = leaf_position
        while current.n < self.max_rollout_depth:
            move_probs = self.policy_network.run(current)
            current = self.play_valid_move(current, move_probs)
            if len(current.recent) > 2 and current.recent[-1] is None and current.recent[-2] is None:
                break
        else:
            print("max rollout depth exceeded!", file=sys.stderr)
        perspective = 1 if leaf_position.player1turn else -1
        return current.score() * perspective

    def play_valid_move(self, position, move_probs):
        for move in sorted_moves(move_probs):
            if go.is_eyeish(position.board, move):
                continue
            candidate_pos = position.play_move(move, mutate=True)
            if candidate_pos is not None:
                return candidate_pos
        return position.pass_move(mutate=True)
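The MCTS class leans entirely on MCTSNode, which isn't shown in this section. Below is a minimal sketch of the node it assumes, reconstructed from the calls above (root_node, select_leaf, compute_position, expand, backup_value); the c_PUCT constant and the exact Q/U update rules are assumptions, written in the AlphaGo-style PUCT form, not the project's confirmed code:

import math

import numpy as np

c_PUCT = 5  # exploration constant; an assumed, tunable value

class MCTSNode:
    # A node starts "unexpanded", holding only a prior probability; its
    # Position is computed lazily the first time the search reaches it.
    def __init__(self, parent, move, prior):
        self.parent = parent
        self.move = move
        self.prior = prior
        self.position = None   # filled in by compute_position()
        self.children = {}     # map of move -> MCTSNode
        self.Q = parent.Q if parent is not None else 0  # running value average
        self.U = prior         # exploration bonus
        self.N = 0             # visit count

    @staticmethod
    def root_node(position, move_probabilities):
        node = MCTSNode(None, None, 0)
        node.position = position
        node.expand(move_probabilities)
        return node

    @property
    def action_score(self):
        # PUCT-style score: exploitation (Q) plus exploration bonus (U).
        return self.Q + self.U

    def is_expanded(self):
        return self.position is not None

    def select_leaf(self):
        # Walk down the tree, always taking the child with the best score,
        # until reaching a node that hasn't been expanded yet.
        current = self
        while current.is_expanded():
            current = max(current.children.values(),
                          key=lambda node: node.action_score)
        return current

    def compute_position(self):
        # play_move returns None for an illegal move, which tree_search
        # uses to prune this child from the tree.
        self.position = self.parent.position.play_move(self.move)
        return self.position

    def expand(self, move_probabilities):
        # One child per board coordinate, with the network output as prior;
        # passing (move None) is always available.
        self.children = {move: MCTSNode(self, move, prob)
                         for move, prob in np.ndenumerate(move_probabilities)}
        self.children[None] = MCTSNode(self, None, 0)

    def backup_value(self, value):
        # Update the running average Q, refresh the exploration term, and
        # propagate the rollout result up toward the root.
        self.N += 1
        if self.parent is None:
            return
        self.Q += (value - self.Q) / self.N
        self.U = c_PUCT * math.sqrt(self.parent.N) * self.prior / self.N
        self.parent.backup_value(value)

This is enough structure for suggest_move's final line to make sense: after the time budget runs out, the root's children carry visit counts N, and the most-visited move is returned rather than the highest-scoring one, which is the standard robustness choice in MCTS.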