コード例 #1
0
    def _minigui_report_search_status(self, leaves):
        """Sends a snapshot of the current MCTS search through dbg().

        Builds a JSON message from the player's root node and passes it to
        dbg() with an "mg-update:" prefix, so that it can be parsed by one of
        the STDERR_HANDLERS in minigui.ts. The message holds:

          - "id", "n", "q": hex id, N and Q of the root node.
          - "childQ": root.child_Q scaled by 1000 and rounded to ints.
          - "childN": root.child_N as ints.
          - "variations" (only when non-empty): for the top-ranked children
            (at most 15), the visit count, Q and most visited path of each
            child, keyed by the child's GTP move; plus a "live" entry for
            the path from the root down to leaves[0].

        Args:
          leaves: list of leaf MCTSNodes returned by tree_search(). Only the
            first entry is reported; an empty list skips the "live" entry.
        """
        root = self._player.get_root()

        msg = {
            "id": hex(id(root)),
            "n": int(root.N),
            "q": float(root.Q),
            "childQ": [int(round(q * 1000)) for q in root.child_Q],
            "childN": [int(n) for n in root.child_N],
        }

        variations = {}
        for fcoord in root.rank_children()[:15]:
            # Children are visited in ranked order, so the first unvisited
            # or unexpanded child ends the report.
            if root.child_N[fcoord] == 0 or fcoord not in root.children:
                break
            gtp_move = coords.to_gtp(coords.from_flat(fcoord))
            followups = [
                coords.to_gtp(coords.from_flat(step.fmove))
                for step in root.children[fcoord].most_visited_path_nodes()
            ]
            variations[gtp_move] = {
                "n": int(root.child_N[fcoord]),
                "q": float(root.child_Q[fcoord]),
                "moves": [gtp_move] + followups,
            }

        if leaves:
            # Walk from the first leaf up to the root, collecting the moves
            # leaf-first, then flip them into root-first order.
            fmoves_leaf_first = []
            cursor = leaves[0]
            while cursor != root:
                fmoves_leaf_first.append(cursor.fmove)
                cursor = cursor.parent
            live_path = fmoves_leaf_first[::-1]
            if live_path:
                first_fmove = live_path[0]
                variations["live"] = {
                    "n": int(root.child_N[first_fmove]),
                    "q": float(root.child_Q[first_fmove]),
                    "moves": [
                        coords.to_gtp(coords.from_flat(fmove))
                        for fmove in live_path
                    ],
                }

        if variations:
            msg["variations"] = variations

        dbg("mg-update:%s" % json.dumps(msg, sort_keys=True))
コード例 #2
0
def print_example(examples, i):
    """Prints examples[i] as its board next to a grid of formatted pi values.

    A header line gives the example's position in the list, the side to play
    and the winner (taken from the sign of example.value). When example.n is
    not -1, a second line reports N, Q and the picked move. The board text
    (minus its last two lines) is then printed side by side with a 'PI'
    header, one row of format_pi() cells per board row and a final cell for
    the last pi entry; zip() stops at the shorter of the two columns.

    Args:
      examples: sequence of examples accepted by parse_board().
      i: index of the example to print.
    """
    example = examples[i]
    position = parse_board(example)

    to_play = 'Black' if position.to_play == 1 else 'White'
    winner = 'Black' if example.value > 0 else 'White'
    print('\nExample %d of %d, %s to play, winner is %s' %
          (i + 1, len(examples), to_play, winner))

    if example.n != -1:
        picked_move = coords.to_gtp(coords.from_flat(example.c))
        print('N:%d  Q:%.3f  picked:%s' % (example.n, example.q, picked_move))
    board_lines = str(position).split('\n')[:-2]

    # Mean of the strictly positive pi entries and the overall maximum are
    # handed to format_pi() for every cell.
    mean = np.mean(example.pi[example.pi > 0])
    mx = np.max(example.pi)

    pi_lines = ['PI']
    for row in range(go.N):
        cells = []
        for col in range(go.N):
            idx = row * go.N + col
            # example.c == -1 means no move was recorded as picked.
            picked = (example.c == idx) if example.c != -1 else False
            cells.append(
                format_pi(example.pi[idx], position.board[row, col], mean, mx,
                          picked))
        pi_lines.append(' '.join(cells))

    # The last pi entry is reported on its own line; it counts as picked
    # when example.c equals go.N * go.N.
    pi_lines.append(
        format_pi(example.pi[-1], go.EMPTY, mean, mx,
                  example.c == go.N * go.N))

    for board_line, pi_line in zip(board_lines, pi_lines):
        print('%s  |  %s' % (board_line, pi_line))
コード例 #3
0
ファイル: mcts.py プロジェクト: shjwudp/training_results_v0.7
 def mvp_gg(self):
   """Returns the most visited path as a space-separated string of moves.

   Walks most_visited_path_nodes() in order and renders each node's move
   with coords.to_gtp(). The walk ends at the first node none of whose
   children has more than one visit; that node and everything after it are
   left out.

   NOTE(review): no colour tokens are added here; whether the result is
   directly usable as go-gui VAR input depends on what coords.to_gtp()
   returns -- confirm.
   """

   def explored_moves(nodes):
     # Yield moves until the first barely-explored node is reached.
     for node in nodes:
       if max(node.child_N) <= 1:
         return
       yield coords.to_gtp(coords.from_flat(node.fmove))

   return ' '.join(explored_moves(self.most_visited_path_nodes()))
コード例 #4
0
ファイル: mcts.py プロジェクト: shjwudp/training_results_v0.7
  def most_visited_path(self):
    """Returns a one-line summary of the most visited path from this node.

    Each node from most_visited_path_nodes() contributes
    '<move> (<N>) ==> ', and the line ends with 'Q: <value>' plus a newline,
    where the value is the Q of the last node on the path, or of this node
    when the path is empty.
    """
    path_nodes = list(self.most_visited_path_nodes())
    steps = [
        '%s (%d) ==> ' % (coords.to_gtp(coords.from_flat(step.fmove)), step.N)
        for step in path_nodes
    ]
    last_node = path_nodes[-1] if path_nodes else self
    return ''.join(steps) + 'Q: {:.5f}\n'.format(last_node.Q)
コード例 #5
0
ファイル: mcts.py プロジェクト: shjwudp/training_results_v0.7
 def describe(self):
   """Returns a multi-line text summary of this node's search statistics.

   The text starts with this node's Q, followed by most_visited_path() and
   a column header. One row then follows for each of the top-ranked
   children (at most 15), ending at the first child with zero visits. The
   row columns are: move, action score, Q, U, prior, original prior, visit
   count, normalised visit count (soft-N), soft-N minus prior (p-delta),
   and p-delta divided by prior (p-rel, 0 where the prior is 0).
   """
   row_format = (
       '\n{!s:4} : {: .3f} {: .3f} {:.3f} {:.3f} {:.3f} {:5d} {:.4f} {: .5f} {: .2f}'
   )

   ranked = self.rank_children()
   # Clamp the denominator to 1 so an unvisited node gives all-zero soft-N.
   total_visits = max(1, sum(self.child_N))
   soft_n = self.child_N / total_visits
   prior = self.child_prior
   p_delta = soft_n - prior
   # Entries with a zero prior keep the 0 supplied through `out`.
   p_rel = np.divide(
       p_delta, prior, out=np.zeros_like(p_delta), where=prior != 0)

   parts = [
       '{q:.4f}\n'.format(q=self.Q),
       self.most_visited_path(),
       'move : action    Q     U     P   P-Dir    N  soft-N  p-delta  p-rel',
   ]
   for move in ranked[:15]:
     if self.child_N[move] == 0:
       break
     parts.append(
         row_format.format(
             coords.to_gtp(coords.from_flat(move)),
             self.child_action_score[move],
             self.child_Q[move],
             self.child_U[move],
             prior[move],
             self.original_prior[move],
             int(self.child_N[move]),
             soft_n[move],
             p_delta[move],
             p_rel[move]))
   return ''.join(parts)
コード例 #6
0
 def _heatmap(self, sort_order, node, prop):
     return "\n".join([
         "{!s:6} {}".format(coords.to_gtp(coords.from_flat(key)),
                            node.__dict__.get(prop)[key])
         for key in sort_order if node.child_N[key] > 0
     ][:20])
コード例 #7
0
def play(network):
    """Plays one self-play game and returns the MCTSPlayer that played it.

    The game runs until the player resigns or the root position is done; in
    both cases the result is recorded on the player with set_result() before
    returning. Progress is printed according to FLAGS.verbose.

    NOTE(review): per the earlier documentation of this function, the
    returned player also carries the final position, the n x 362 tensor of
    MCTS search probabilities and the n value-net estimates (n = number of
    moves in the game). That is bookkeeping done inside MCTSPlayer and is
    not visible here -- confirm against MCTSPlayer.

    Args:
      network: object whose run(position) returns a (prob, val) pair; it is
        also handed to MCTSPlayer.

    Returns:
      The MCTSPlayer after the game has ended.
    """
    readouts = FLAGS.num_readouts
    # Resignation is disabled (threshold -1.0) for a
    # FLAGS.resign_disable_pct fraction of games; otherwise None is passed
    # and MCTSPlayer chooses the threshold.
    resign_threshold = (
        -1.0 if random.random() < FLAGS.resign_disable_pct else None)

    player = MCTSPlayer(network, resign_threshold=resign_threshold)
    player.initialize_game()

    # The root node has to be expanded once before the search loop starts.
    first_node = player.root.select_leaf()
    prob, val = network.run(first_node.position)
    first_node.incorporate_results(prob, val, first_node)

    while True:
        start = time.time()
        player.root.inject_noise()
        # Run `readouts` additional readouts on top of whatever the root
        # already has, rather than stopping at `readouts` in total.
        target_n = player.root.N + readouts
        while player.root.N < target_n:
            player.tree_search()

        if FLAGS.verbose >= 3:
            print(player.root.position)
            print(player.root.describe())

        if player.should_resign():
            # The side to move resigns, so the opponent gets the result.
            player.set_result(
                -1 * player.root.position.to_play, was_resign=True)
            break

        player.play_move(player.pick_move())
        if player.root.is_done():
            player.set_result(player.root.position.result(), was_resign=False)
            break

        # Verbosity 2 reports every move; verbosity 1 only when the move
        # number is 9 modulo 10.
        report_timing = (FLAGS.verbose >= 2) or (
            FLAGS.verbose >= 1 and player.root.position.n % 10 == 9)
        if report_timing:
            print('Q: {:.5f}'.format(player.root.Q))
            dur = time.time() - start
            print('%d: %d readouts, %.3f s/100. (%.2f sec)' %
                  (player.root.position.n, readouts, dur / readouts * 100.0,
                   dur))
        if FLAGS.verbose >= 3:
            print('Played >>',
                  coords.to_gtp(coords.from_flat(player.root.fmove)))

    if FLAGS.verbose >= 2:
        utils.dbg('%s: %.3f' % (player.result_string, player.root.Q))
        utils.dbg(player.root.position, player.root.position.score())

    return player
コード例 #8
0
ファイル: mcts.py プロジェクト: shjwudp/training_results_v0.7
 def maybe_add_child(self, fcoord):
   """Returns the child node for fcoord, creating and storing it if absent.

   A missing child is built by playing fcoord on this node's position and
   wrapping the resulting position in an MCTSNode whose parent is this
   node.
   """
   if fcoord in self.children:
     return self.children[fcoord]
   next_position = self.position.play_move(coords.from_flat(fcoord))
   child = MCTSNode(next_position, fmove=fcoord, parent=self)
   self.children[fcoord] = child
   return child