Code example #1
File: strategies.py  Project: zysilence/MuGo
class PolicyNetworkBestMovePlayer(GtpInterface):
    def __init__(self, read_file):
        self.policy_network = PolicyNetwork(DEFAULT_FEATURES.planes,
                                            use_cpu=True)
        self.read_file = read_file
        super().__init__()

    def clear(self):
        super().clear()
        self.refresh_network()

    def refresh_network(self):
        # Ensure that the player is using the latest version of the network
        # so that the network can be continually trained even as it's playing.
        self.policy_network.initialize_variables(self.read_file)

    def suggest_move(self, position):
        if position.recent and position.n > 100 and position.recent[-1] is None:
            # Pass if the opponent passes
            return None
        move_probabilities = self.policy_network.run(position)
        for move in sorted_moves(move_probabilities):
            if go.is_reasonable(position, move):
                return move
        return None
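
A minimal usage sketch for this player (the checkpoint path, board size, and import paths are assumptions, not taken from the listing): clear() loads the network weights, and suggest_move() returns the highest-probability move that go.is_reasonable accepts.

import go
from strategies import PolicyNetworkBestMovePlayer

go.set_board_size(9)                                # illustrative board size
player = PolicyNetworkBestMovePlayer('saved_models/policy.ckpt')  # hypothetical path
player.clear()                                      # loads the network weights
print(player.suggest_move(go.Position()))
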
Code example #2
import os
import tensorflow as tf

import go
from policy import PolicyNetwork  # module paths assumed from the project layout


def save_trained_policy(t, save_dir):
    # NOTE: the opening of this function is truncated in the original
    # listing; the header and loop are inferred from the fragment and from
    # the call save_trained_policy(1, 'model/rl') below. trained_values is
    # an assumed {variable_name: numpy array} mapping exported from the
    # training graph.
    with tf.Graph().as_default(), tf.Session() as sess:
        new_vars = []
        for name, v in trained_values.items():
            new_vars.append(
                tf.Variable(v,
                            name=name.replace('PolicNetwork',
                                              'PlayerNetwork')))
        saver = tf.train.Saver(new_vars)
        sess.run(tf.global_variables_initializer())
        saver.save(sess,
                   os.path.join(save_dir, str(t), 'player' + str(t) + '.ckpt'))


g1 = tf.Graph()
with g1.as_default():
    train_net = PolicyNetwork(scope="PolicNetwork")
    train_net.initialize_variables('model/sl/epoch_48.ckpt')

pos = go.Position()
train_net.run(pos)  # forward pass on an empty position

g2 = tf.Graph()
with g2.as_default():
    player_net = PolicyNetwork(scope="PlayerNetwork")
    player_net.initialize_variables('model/rl/2/player2.ckpt')
pos = go.Position()
player_net.run(pos)

save_trained_policy(1, 'model/rl')

print("===========load new model=================")
g2 = tf.Graph()
with g2.as_default():
    player_net = PolicyNetwork(scope="PlayerNetwork")
    player_net.initialize_variables('model/rl/5/player5.ckpt')
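
The script above keeps the training and playing networks in separate tf.Graph objects. A small self-contained sketch of why that works (all names here are illustrative): a variable belongs to whichever graph was the default when it was created, so identically named variables in different graphs never collide.

import tensorflow as tf

g_a, g_b = tf.Graph(), tf.Graph()
with g_a.as_default():
    v = tf.Variable(1.0, name='weights')        # lives only in g_a
    init_a = tf.global_variables_initializer()  # init op also belongs to g_a
with g_b.as_default():
    w = tf.Variable(2.0, name='weights')        # same name, different graph: no clash

with tf.Session(graph=g_a) as sess:
    sess.run(init_a)
    print(sess.run(v))                          # 1.0; g_b's 'weights' is invisible here
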
Code example #3
File: strategies.py  Project: zysilence/MuGo
class MCTS(GtpInterface):
    def __init__(self, read_file, seconds_per_move=5):
        self.seconds_per_move = seconds_per_move
        self.max_rollout_depth = go.N * go.N * 3
        self.policy_network = PolicyNetwork(DEFAULT_FEATURES.planes,
                                            use_cpu=True)
        self.read_file = read_file
        super().__init__()

    def clear(self):
        super().clear()
        self.refresh_network()

    def refresh_network(self):
        # Ensure that the player is using the latest version of the network
        # so that the network can be continually trained even as it's playing.
        self.policy_network.initialize_variables(self.read_file)

    def suggest_move(self, position):
        if position.caps[0] + 50 < position.caps[1]:
            # resign when far behind on captures
            return gtp.RESIGN
        start = time.time()
        move_probs = self.policy_network.run(position)
        root = MCTSNode.root_node(position, move_probs)
        while time.time() - start < self.seconds_per_move:
            self.tree_search(root)
        # there's a theoretical bug here: if you refuse to pass, this AI will
        # eventually start filling in its own eyes.
        return max(root.children.keys(),
                   key=lambda move: root.children[move].N)

    def tree_search(self, root):
        print("tree search", file=sys.stderr)
        # selection
        chosen_leaf = root.select_leaf()
        # expansion
        position = chosen_leaf.compute_position()
        if position is None:
            print("illegal move!", file=sys.stderr)
            # See go.Position.play_move for notes on detecting legality
            del chosen_leaf.parent.children[chosen_leaf.move]
            return
        print("Investigating following position:\n%s" %
              (chosen_leaf.position, ),
              file=sys.stderr)
        move_probs = self.policy_network.run(position)
        chosen_leaf.expand(move_probs)
        # evaluation
        value = self.estimate_value(chosen_leaf)
        # backup
        print("value: %s" % value, file=sys.stderr)
        chosen_leaf.backup_value(value)

    def estimate_value(self, chosen_leaf):
        # Estimate value of position using rollout only (for now).
        # (TODO: Value network; average the value estimations from rollout + value network)
        leaf_position = chosen_leaf.position
        current = leaf_position
        while current.n < self.max_rollout_depth:
            move_probs = self.policy_network.run(current)
            current = self.play_valid_move(current, move_probs)
            if (len(current.recent) > 2
                    and current.recent[-1] is None
                    and current.recent[-2] is None):
                break
        else:
            # while-else: reached only if the loop ended without break,
            # i.e. the rollout hit max depth rather than two consecutive passes.
            print("max rollout depth exceeded!", file=sys.stderr)

        perspective = 1 if leaf_position.player1turn else -1
        return current.score() * perspective

    def play_valid_move(self, position, move_probs):
        # Play the most probable move that is legal and does not fill an
        # eye-like point; fall back to passing if nothing qualifies.
        for move in sorted_moves(move_probs):
            if go.is_eyeish(position.board, move):
                continue
            candidate_pos = position.play_move(move, mutate=True)
            if candidate_pos is not None:
                return candidate_pos
        return position.pass_move(mutate=True)
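
As with the first example, a short usage sketch (checkpoint path and board size are assumptions): clear() loads the weights, and suggest_move() runs tree search for seconds_per_move seconds before returning the most-visited child move.

import go
from strategies import MCTS

go.set_board_size(9)                    # illustrative board size
bot = MCTS('saved_models/policy.ckpt', seconds_per_move=2)  # hypothetical path
bot.clear()                             # loads the network weights
print(bot.suggest_move(go.Position()))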