Example #1
0
    def play_move(self, c):
        """Play move `c` and make the resulting node the new search root.

        Notable side effects:
          - Unless in two-player mode, the probability distribution built
            from the current root's visit counts is appended to the running
            tally `searches_pi`.
          - The current root's description is appended to `comments`.
          - The node associated with this move becomes the root, for future
            `inject_noise` calls, and the `children` attribute of the
            previous root is deleted.

        Args:
          c: the move to play; converted with `coords.to_flat`.

        Returns:
          True (GTP requires a positive result).

        Raises:
          go.IllegalMove: re-raised after the entries appended above have
            been removed again.
        """
        track_pi = not self.two_player_mode
        if track_pi:
            below_threshold = self.root.position.n < self.temp_threshold
            self.searches_pi.append(self.root.children_as_pi(below_threshold))
        self.comments.append(self.root.describe())

        try:
            next_root = self.root.maybe_add_child(coords.to_flat(c))
        except go.IllegalMove:
            dbg("Illegal move")
            # Roll back the bookkeeping recorded for the rejected move.
            if track_pi:
                self.searches_pi.pop()
            self.comments.pop()
            raise

        self.root = next_root
        self.position = next_root.position  # for showboard
        del next_root.parent.children
        return True  # GTP requires positive result.
Example #2
0
    def moves_from_games(self, start_game, end_game, moves, shuffle,
                         column_family, column):
        """Dataset of samples and/or shuffled moves from game range.

        Args:
          start_game:  game number used to build the first row of the scan.
          end_game:  game number used to build the end row of the scan.
          moves:  an integer indicating how many moves should be sampled
            from the games in that range.
          shuffle:  if True, shuffle the selected move examples.
          column_family:  name of the column family containing move examples.
          column:  name of the column containing move examples.

        Returns:
          A dataset containing no more than `moves` examples, sampled
            from the rows of the given game range.
        """
        first_row = ROW_PREFIX.format(start_game)
        last_row = ROW_PREFIX.format(end_game)
        move_count = self.count_moves_in_game_range(start_game, end_game)
        # NOTE:  The sampling probability is computed against a slightly
        # lower estimate of the total moves so that at least the required
        # number of moves comes back; the surplus is trimmed by `take`.
        sample_rate = moves / (move_count * 0.99)
        summary = 'Row range: %s - %s; total moves: %d; probability %.3f; moves %d' % (
            first_row, last_row, move_count, sample_rate, moves)
        utils.dbg(summary)

        scan_columns = [(column_family, column)]
        ds = self.tf_table.parallel_scan_range(
            first_row, last_row,
            probability=sample_rate,
            columns=scan_columns)
        if not shuffle:
            return ds.take(moves)
        utils.dbg('Doing a complete shuffle of %d moves' % moves)
        return ds.shuffle(moves).take(moves)
Example #3
0
    def wait_for_fresh_games(self, poll_interval=15.0):
        """Block the caller until the required new games have been played.

        The required game number comes from `read_wait_cell()`; when that
        value is falsy the method returns immediately.  Otherwise
        `latest_game_number` is re-read every `poll_interval` seconds until
        it is at least the required value, logging the progress made since
        the previous check.

        Args:
          poll_interval:  number of seconds to wait between checks
        """
        required_game = self.read_wait_cell()
        if not required_game:
            return

        current = self.latest_game_number
        previous = current
        while current < required_game:
            gained = current - previous
            utils.dbg('Latest game {} not yet at required game {} '
                      '(+{}, {:0.3f} games/sec)'.format(
                          current, required_game, gained,
                          gained / poll_interval))
            time.sleep(poll_interval)
            previous, current = current, self.latest_game_number
Example #4
0
    def trim_games_since(self, t, max_games=500000):
        """Trim off the games played since the given time.

        Looks back at most `max_games` games for the time point, locates
        the game there, deletes the rows of every game from that one up to
        the latest, and resets the latest game counter.

        Args:
          t:  either a `datetime.timedelta`, in which case the target time
            is the time of the most recent game minus that delta, or the
            target time itself.
          max_games:  how many games before the latest one to search.
        """
        newest_game = self.latest_game_number
        timeline = self.games_by_time(int(newest_game - max_games),
                                      newest_game)
        most_recent = timeline[-1]
        target = most_recent[0] - t if isinstance(
            t, datetime.timedelta) else t

        # NOTE(review): if `target` sorts after every entry in `timeline`,
        # `cut` equals len(timeline) and the lookup below raises IndexError
        # -- confirm callers never pass such a time.
        cut = bisect.bisect_right(timeline, (target, ))
        when, which = timeline[cut]
        utils.dbg('Most recent:  %s  %s' % most_recent)
        utils.dbg('     Target:  %s  %s' % (when, which))

        first_removed = int(which)
        for prefix in (ROW_PREFIX, ROWCOUNT_PREFIX):
            self.delete_row_range(prefix, first_removed, newest_game)
        self.latest_game_number = first_removed + 1
Example #5
0
 def tree_search(self, parallel_readouts=None):
     """Select a batch of leaves, evaluate them together and back up results.

     Leaves are selected from the root until `parallel_readouts` unfinished
     leaves have been gathered or twice that many selections were made.
     Finished games are backed up immediately with the true result and are
     not part of the batch.

     Args:
       parallel_readouts: batch size; defaults to FLAGS.parallel_readouts.

     Returns:
       The list of leaves that were evaluated by the network.
     """
     if parallel_readouts is None:
         parallel_readouts = FLAGS.parallel_readouts

     pending = []
     attempts = 0
     max_attempts = parallel_readouts * 2
     while attempts < max_attempts and len(pending) < parallel_readouts:
         attempts += 1
         node = self.root.select_leaf()
         if self.verbosity >= 4:
             dbg(self.show_path_to_root(node))
         if node.is_done():
             # if game is over, override the value estimate with the true score
             outcome = 1 if node.position.score() > 0 else -1
             node.backup_value(outcome, up_to=self.root)
         else:
             node.add_virtual_loss(up_to=self.root)
             pending.append(node)

     if not pending:
         return pending

     positions = [node.position for node in pending]
     move_probs, values = self.network.run_many(positions)
     for node, move_prob, value in zip(pending, move_probs, values):
         node.revert_virtual_loss(up_to=self.root)
         node.incorporate_results(move_prob, value, up_to=self.root)
     return pending
Example #6
0
def get_moves_from_games(start_game, end_game, moves, shuffle, column_family,
                         column):
    """Return a dataset of at most `moves` examples from a range of games.

    Args:
      start_game:  first game of the range.
      end_game:  end game of the range.
      moves:  maximum number of examples to return.
      shuffle:  if True, interleave the scan shards in random order and
        shuffle the result before trimming.
      column_family:  name of the column family containing move examples.
      column:  name of the column containing move examples.

    Returns:
      The sampled (and optionally shuffled) dataset, trimmed to `moves`.
    """
    num_shards = 8
    block = 1024

    first_row, last_row = get_game_range_row_names(start_game, end_game)
    move_count = count_moves_in_game_range(start_game, end_game)
    # NOTE:  The probability is computed against a slightly lower estimate
    # of the total moves so that at least the required number of moves is
    # returned; the surplus is trimmed by `take`.
    sample_rate = moves / (move_count * 0.99)
    utils.dbg('Row range: %s - %s; total moves: %d; probability %.3f' %
              (first_row, last_row, move_count, sample_rate))

    scanned = _tf_table.parallel_scan_range(
        first_row,
        last_row,
        probability=sample_rate,
        num_parallel_scans=num_shards,
        columns=[(column_family, column)])
    if not shuffle:
        return scanned.take(moves)

    shard_ids = tf.random_shuffle(tf.range(0, num_shards, dtype=tf.int64))
    shard_order = tf.data.Dataset.from_tensor_slices(shard_ids)
    interleaved = shard_order.apply(
        tf.contrib.data.parallel_interleave(
            lambda shard: scanned.shard(num_shards, shard),
            cycle_length=num_shards,
            block_length=block))
    return interleaved.shuffle(num_shards * block * 2).take(moves)
Example #7
0
 def cmd_gamestate(self):
     """Log the current game state as a sorted-key JSON `mg-gamestate` message.

     The message holds the board as a row-major string ("X" black, "O"
     white, "." otherwise), the side to play, the last move in KGS
     notation (None when no move was played), the move number, and `q`,
     taken from the root's parent when the root has a grandparent and 0
     otherwise.
     """
     position = self._player.get_position()
     root = self._player.get_root()

     stones = "".join(
         "X" if stone == go.BLACK else "O" if stone == go.WHITE else "."
         for stone in (position.board[row, col]
                       for row in range(go.N)
                       for col in range(go.N)))

     msg = {
         "board": stones,
         "toPlay": "Black" if position.to_play == 1 else "White",
         "lastMove": (coords.to_kgs(position.recent[-1].move)
                      if position.recent else None),
         "n": position.n,
         "q": root.parent.Q if root.parent and root.parent.parent else 0,
     }
     dbg("mg-gamestate:%s", json.dumps(msg, sort_keys=True))
Example #8
0
def _get_excld_packages(excld_csv):
    if not excld_csv:
        return []

    if not os.path.exists(excld_csv):
        warn("Skipping Non-Existent exclude-package File: %s" % excld_csv)
        return []

    dbg("Importing Excluded Packages from %s" % excld_csv)

    excld_pkgs = set()
    try:
        with open(excld_csv) as csv_in:
            reader = csv.reader(csv_in)
            for row in reader:
                if not len(row):
                    continue
                if row[0].startswith('#'):
                    continue

                pkg = row[0].strip().lower()
                excld_pkgs.add(pkg.replace(' ', '-'))
    except Exception as e:
        warn("exclude-packages: %s" % e)
        return []

    dbg("Requested packages to exclude: %s" % list(excld_pkgs))
    return list(excld_pkgs)
Example #9
0
    def _run_threads(self):
        """Run inference threads and optionally a thread that updates the model.

        Synchronization between the inference threads and the model update
        thread is performed using a RwLock that protects access to self.sess.
        The inference threads enter the critical section using a read lock, so
        they can both run inference concurrently. The model update thread enters
        the critical section using a write lock for exclusive access.
        """
        # The worker threads come first in the list, ahead of the checkpoint
        # thread: if the parent process dies, the worker thread RPCs will
        # fail and the thread will exit. Joining it first gives us a chance
        # to set self._running to False, telling the checkpoint thread to
        # exit.
        threads = [
            threading.Thread(target=self._worker_thread, args=[worker_id])
            for worker_id in range(NUM_WORKER_THREADS)
        ]
        if FLAGS.checkpoint_dir:
            threads.append(threading.Thread(target=self._checkpoint_thread))

        for thread in threads:
            thread.start()

        for index, thread in enumerate(threads):
            thread.join()
            dbg("joined thread %d" % index)
            # Once the first thread has joined, tell the remaining ones to stop.
            self._running = False
0
def _get_user_whitelist(whtlst_csv):
    if not whtlst_csv:
        return []

    if not os.path.exists(whtlst_csv):
        warn("Skipping Non-Existent CVE Whitelist File: %s" % excld_csv)
        return []

    dbg("Importing Whitelisted CVEs from %s" % whtlst_csv)

    whtlst_cves = set()
    try:
        with open(whtlst_csv) as csv_in:
            reader = csv.reader(csv_in)
            for row in reader:
                if not len(row):
                    continue
                if row[0].startswith('#'):
                    continue

                pkg = row[0].strip().upper()
                whtlst_cves.add(pkg.replace(' ', '-'))
    except Exception as e:
        warn("whitelist-cves: %s" % e)
        return []

    dbg("Requested CVEs to Ignore: %s" % list(whtlst_cves))
    return whtlst_cves
Example #11
0
 def tree_search(self, parallel_readouts=None):
     """Select a batch of leaves, evaluate their features and back up results.

     Leaves are selected from the root until `parallel_readouts` unfinished
     leaves have been gathered or twice that many selections were made.
     Finished games are backed up immediately with the true result and are
     not part of the batch.

     Args:
       parallel_readouts: batch size; defaults to the smaller of
         `strat_args.parallel_readouts` and `self.num_readouts`.

     Returns:
       The list of leaves that were evaluated by the network.
     """
     if parallel_readouts is None:
         parallel_readouts = min(strat_args.parallel_readouts,
                                 self.num_readouts)

     pending = []
     features = []
     attempts = 0
     max_attempts = parallel_readouts * 2
     while attempts < max_attempts and len(pending) < parallel_readouts:
         attempts += 1
         node = self.root.select_leaf()
         if self.verbosity >= 4:
             dbg(self.show_path_to_root(node))
         if node.is_done():
             # if game is over, override the value estimate with the true score
             outcome = 1 if node.position.score() > 0 else -1
             node.backup_value(outcome, up_to=self.root)
         else:
             node.add_virtual_loss(up_to=self.root)
             pending.append(node)
             features.append(extract_features(node.position))

     if not pending:
         return pending

     move_probs, values = self.network.policy_value_fn(
         np.array(features), device=self.device)
     for node, move_prob, value in zip(pending, move_probs, values):
         node.revert_virtual_loss(up_to=self.root)
         node.incorporate_results(move_prob, value, up_to=self.root)
     return pending
Example #12
0
def count_elements_in_dataset(ds, batch_size=1*1024, parallel_batch=8):
    """Count and return all the elements in the given dataset.

    Debugging function.  The elements in a dataset cannot be counted
    without enumerating all of them.  By counting in batch and in
    parallel, this method allows rapid traversal of the dataset.

    Args:
      ds:  The dataset whose elements should be counted.
      batch_size:  the number of elements to count at a time.
      parallel_batch:  how many batches to count in parallel.

    Returns:
      The number of elements in the dataset (0 when it is empty).
    """
    with tf.Session() as sess:
        dsc = ds.apply(tf.contrib.data.enumerate_dataset())
        dsc = dsc.apply(
            tf.contrib.data.map_and_batch(lambda c, v: c, batch_size,
                                          num_parallel_batches=parallel_batch))
        iterator = dsc.make_initializable_iterator()
        sess.run(iterator.initializer)
        # Build the reduction op once: creating it inside the loop would
        # add a new node to the graph for every batch that is counted.
        batch_max = tf.reduce_max(iterator.get_next())
        counted = 0
        try:
            while True:
                # The numbers in the tensors are 0-based indices,
                # so add 1 to get the number counted.
                counted = sess.run(batch_max) + 1
                utils.dbg('Counted so far: %d' % counted)
        except tf.errors.OutOfRangeError:
            # Raised once the iterator is exhausted; `counted` is final.
            pass
        utils.dbg('Counted total: %d' % counted)
        return counted
Example #13
0
def get_unparsed_moves_from_last_n_games(n,
                                         moves=2**21,
                                         shuffle=True,
                                         column_family='tfexample',
                                         column='example'):
    """Get a dataset of serialized TFExamples from the last N games.

    Waits for fresh games first (see `_games.wait_for_fresh_games`).

    Args:
      n:  an integer indicating how many past games should be sourced.
      moves:  an integer indicating how many moves should be sampled
        from those N games.
      shuffle:  if True, shuffle the selected move examples.
      column_family:  name of the column family containing move examples.
      column:  name of the column containing move examples.

    Returns:
      A dataset containing no more than `moves` examples, sampled
        randomly from the last `n` games in the table.

    Raises:
      ValueError:  if the latest game number is 0.
    """
    _games.wait_for_fresh_games()
    newest = int(_games.latest_game_number())
    utils.dbg('Latest game: %s' % newest)
    if newest == 0:
        raise ValueError('Cannot find a latest game in the table')

    first = int(max(0, newest - n))
    examples = _games.moves_from_games(first, newest, moves, shuffle,
                                       column_family, column)
    # Keep only the second component of each (row_name, s) pair.
    return examples.map(lambda row_name, s: s)
Example #14
0
    def run(self):
        r"""
        Run command _cmd and pass its stdout+stderr output
        to _watch object. watch object is supposed to be
        an instance of ProcessWatch object. The running process
        handle is stored in self._process.

        \return return code of the process or -1 on watch error
        """

        dbg('|> {0}'.format(' '.join(self._cmd)), prefix='')

        # run the command and keep its handle in self._process,
        # so that we can easily kill this process
        # on timeout or signal. This way we can run only one
        # process at the time, but we don't need more (yet?)
        self._process = Popen(self._cmd, stdout=PIPE, stderr=STDOUT)

        while True:
            line = self._process.stdout.readline()
            # At EOF readline() returns an empty str for a text pipe and
            # empty bytes for a binary pipe (the Python 3 default), so test
            # for emptiness instead of comparing against ''.
            if not line and self._process.poll() is not None:
                break

            self._watch.putLine(line)
            if not self._watch.ok():
                # watch told us to kill the process for some reason
                self._process.terminate()
                self._process.kill()
                return -1

        return self._process.wait()
Example #15
0
    def _minigui_report_position(self):
        """Log the root position as a sorted-key JSON `mg-position` message.

        The message holds the root's id, the side to play, the move number,
        the board as a row-major string ("X" black, "O" white, "."
        otherwise), the game-over flag and the capture counts.  `parentId`
        and `q` are added when the root has a grandparent, and `move` (GTP
        notation) when at least one move has been played.
        """
        root = self._player.get_root()
        position = root.position

        stones = "".join(
            "X" if stone == go.BLACK else "O" if stone == go.WHITE else "."
            for stone in (position.board[row, col]
                          for row in range(go.N)
                          for col in range(go.N)))

        msg = {
            "id": hex(id(root)),
            "toPlay": "B" if position.to_play == 1 else "W",
            "moveNum": position.n,
            "stones": stones,
            "gameOver": position.is_game_over(),
            "caps": position.caps,
        }
        if root.parent and root.parent.parent:
            msg.update(parentId=hex(id(root.parent)),
                       q=float(root.parent.Q))
        if position.recent:
            msg["move"] = coords.to_gtp(position.recent[-1].move)
        dbg("mg-position:%s" % json.dumps(msg, sort_keys=True))
Example #16
0
def play(network):
    """Plays out a self-play match, returning a MCTSPlayer object containing:
        - the final position
        - the n x 362 tensor of floats representing the mcts search probabilities
        - the n-ary tensor of floats representing the original value-net estimate
          where n is the number of moves in the game
    """
    readouts = FLAGS.num_readouts  # defined in strategies.py
    verbose = FLAGS.verbose

    # Resignation is disabled (threshold -1.0) for the fraction of games
    # given by FLAGS.resign_disable_pct; otherwise the player's default
    # threshold applies.
    resign_threshold = (-1.0 if random.random() < FLAGS.resign_disable_pct
                        else None)
    player = MCTSPlayer(network, resign_threshold=resign_threshold)
    player.initialize_game()

    # Must run this once at the start to expand the root node.
    first_node = player.root.select_leaf()
    prob, val = network.run(first_node.position)
    first_node.incorporate_results(prob, val, first_node)

    while True:
        started = time.time()
        player.root.inject_noise()
        # we want to do "X additional readouts", rather than "up to X readouts".
        target_n = player.root.N + readouts
        while player.root.N < target_n:
            player.tree_search()

        if verbose >= 3:
            print(player.root.position)
            print(player.root.describe())

        if player.should_resign():
            player.set_result(-1 * player.root.position.to_play,
                              was_resign=True)
            break

        player.play_move(player.pick_move())
        if player.root.is_done():
            player.set_result(player.root.position.result(), was_resign=False)
            break

        if verbose >= 2 or (verbose >= 1
                            and player.root.position.n % 10 == 9):
            print("Q: {:.5f}".format(player.root.Q))
            elapsed = time.time() - started
            print("%d: %d readouts, %.3f s/100. (%.2f sec)" % (
                player.root.position.n, readouts,
                elapsed / readouts * 100.0, elapsed), flush=True)
        if verbose >= 3:
            print("Played >>",
                  coords.to_gtp(coords.from_flat(player.root.fmove)))

    if verbose >= 2:
        utils.dbg("%s: %.3f" % (player.result_string, player.root.Q))
        utils.dbg(player.root.position, player.root.position.score())

    return player
Example #17
0
 def _checkpoint_thread(self):
     """Poll the checkpoint directory and load the newest checkpoint.

     Runs for as long as `self._running` is true: every 5 seconds the
     latest checkpoint in FLAGS.checkpoint_dir is looked up and, when one
     exists, handed to `self.sess.maybe_load_model`.
     """
     dbg("starting model loader thread")
     while self._running:
         freshest = saver.latest_checkpoint(FLAGS.checkpoint_dir)
         if freshest:
             self.sess.maybe_load_model(freshest)
         # Wait a few seconds before checking again.
         time.sleep(5)
Example #18
0
    def _minigui_report_search_status(self, leaves):
        """Prints the current MCTS search status to stderr.

        Reports the current search path, root node's child_Q, root node's
        child_N, the most visited path in a format that can be parsed by
        one of the STDERR_HANDLERS in minigui.ts.

        Args:
          leaves: list of leaf MCTSNodes returned by tree_search().
        """
        root = self._player.get_root()
        # NOTE(review): `position` is not used below.
        position = root.position

        msg = {
            "id": hex(id(root)),
            "n": int(root.N),
            "q": float(root.Q),
            "childQ": [int(round(q * 1000)) for q in root.child_Q],
            "childN": [int(n) for n in root.child_N],
        }

        # Principal variations of up to 15 top-ranked children, stopping at
        # the first one that is unvisited or not expanded.
        variations = {}
        for move_idx in root.rank_children()[:15]:
            if root.child_N[move_idx] == 0 or move_idx not in root.children:
                break
            label = coords.to_gtp(coords.from_flat(move_idx))
            followups = [
                coords.to_gtp(coords.from_flat(node.fmove))
                for node in root.children[move_idx].most_visited_path_nodes()
            ]
            variations[label] = {
                "n": int(root.child_N[move_idx]),
                "q": float(root.child_Q[move_idx]),
                "moves": [label] + followups,
            }

        if leaves:
            # Walk from the first leaf up to the root to get the live path.
            path = []
            node = leaves[0]
            while node != root:
                path.append(node.fmove)
                node = node.parent
            if path:
                path.reverse()
                first_move = path[0]
                variations["live"] = {
                    "n": int(root.child_N[first_move]),
                    "q": float(root.child_Q[first_move]),
                    "moves":
                    [coords.to_gtp(coords.from_flat(m)) for m in path]
                }

        if variations:
            msg["variations"] = variations

        dbg("mg-update:%s" % json.dumps(msg, sort_keys=True))
0
 def run(self):
     """Run the threads, then shut the session down.

     `self._running` is set for the duration of `_run_threads()`.  The
     flag is cleared and `self.sess.shutdown()` is called in a `finally`
     clause, so both happen even when `_run_threads()` raises.
     """
     self._running = True
     try:
         self._run_threads()
     finally:
         self._running = False
         dbg("shutting down session")
         self.sess.shutdown()
         dbg("all done!")
Example #20
0
 def parse(self, line):
     """Route one output line to the matching printer.

     The first matching rule wins: lines containing 'INFO' go to the debug
     output, lines containing 'ERROR' or 'error' to stderr, lines
     containing 'Statistics' to stdout, and everything else to the debug
     output.

     Args:
       line: the line of output to dispatch.
     """
     if 'INFO' in line:
         dbg(line, domain='slicer', print_nl=False)
         return
     if 'ERROR' in line or 'error' in line:
         print_stderr(line)
         return
     if 'Statistics' in line:
         print_stdout(line, print_nl=False)
         return
     dbg(line, 'slicer', False)
Example #21
0
    def _locked_load_model(self, path):
        """Restore the checkpoint at `path` into the session.

        When a model path is already recorded, the TPU shutdown op is run
        first; the TPU init op is always run after the restore.

        NOTE(review): the name suggests the caller must already hold the
        model lock -- confirm against the call sites.

        Args:
          path: checkpoint path handed to `tf.train.Saver().restore`.
        """
        if self._model_path:
            dbg("shutting down tpu")
            self._sess.run(self._tpu_shutdown)

        # The Saver is created inside the session's graph.
        with self._sess.graph.as_default():
            tf.train.Saver().restore(self._sess, path)

        dbg("initializing tpu")
        self._sess.run(self._tpu_init)
Example #22
0
 def save_model(self, filename="model"):
     """Save the model and its weights under a not-yet-used file name.

     The target is `model_data_dir + filename`; while a file with that
     name exists, "-1", "-2", ... is appended to the base name instead.
     The weights are saved next to it with an ".hdf5" suffix.

     Args:
       filename: base name of the model file inside `model_data_dir`.
     """
     base_path = self.model_data_dir + filename
     model_file_path = base_path
     suffix = 1
     while os.path.isfile(model_file_path):
         model_file_path = f"{base_path}-{suffix}"
         suffix += 1

     weights_path = model_file_path + ".hdf5"
     self.model.save(model_file_path)
     self.model.save_weights(weights_path)
     dbg(f"model saved: {model_file_path}\tweights saved: {weights_path}")
Example #23
0
def main(argv):
    """Run Minigo in GTP mode.

    Builds the GTP engine from the command-line flags, then feeds it the
    stripped lines of stdin until `handle_msg` returns a falsy value or the
    input ends.

    Args:
      argv: ignored.
    """
    del argv  # Unused.
    engine = make_gtp_instance(
        FLAGS.load_file,
        cgos_mode=FLAGS.cgos_mode,
        kgs_mode=FLAGS.kgs_mode,
        minigui_mode=FLAGS.minigui_mode)
    dbg("GTP engine ready\n")
    for raw_line in sys.stdin:
        keep_going = engine.handle_msg(raw_line.strip())
        if not keep_going:
            break
Example #24
0
def main(argv):
    """Run Minigo in GTP mode.

    Builds the GTP engine from the command-line flags, then feeds it the
    stripped lines of stdin until `handle_msg` returns a falsy value or the
    input ends.

    Args:
      argv: ignored.
    """
    del argv  # Unused.
    engine = make_gtp_instance(
        FLAGS.load_file,
        cgos_mode=FLAGS.cgos_mode,
        kgs_mode=FLAGS.kgs_mode,
        verbosity=FLAGS.verbose)
    dbg("GTP engine ready\n")
    for raw_line in sys.stdin:
        keep_going = engine.handle_msg(raw_line.strip())
        if not keep_going:
            break
Example #25
0
 def save_model(self, filename=""):
     """Save the model and its weights, named after the architecture.

     The base name is `model_data_dir + filename + architecture`; while a
     file with that name exists, "-1", "-2", ... is appended.  The model is
     saved as "<base>-model" and the weights, in 'tf' format, as
     "<base>-weights".

     Args:
       filename: prefix placed before the architecture name.
     """
     stem = self.model_data_dir + filename + self.architecture
     candidate = stem
     counter = 1
     # NOTE(review): the existence check looks at the un-suffixed base
     # path, while the files written below carry "-model" / "-weights"
     # suffixes -- confirm this is the intended collision check.
     while os.path.isfile(candidate):
         candidate = stem + '-' + str(counter)
         counter += 1

     weights_path = candidate + "-weights"
     model_file_path = candidate + '-model'
     self.model.save(model_file_path)
     self.model.save_weights(weights_path, save_format='tf')
     dbg(f"model saved: {model_file_path}\tweights saved: {weights_path}")
Example #26
0
def main(argv):
    """Run Minigo in GTP mode.

    Builds the GTP engine from the command-line flags, then feeds it the
    stripped lines of stdin until `handle_msg` returns a falsy value or the
    input ends.

    Args:
      argv: ignored.
    """
    del argv  # Unused.
    engine = make_gtp_instance(
        FLAGS.model_path,
        FLAGS.model_name,
        cgos_mode=FLAGS.cgos_mode,
        kgs_mode=FLAGS.kgs_mode,
        minigui_mode=FLAGS.minigui_mode)
    dbg("GTP engine ready\n")
    for raw_line in sys.stdin:
        keep_going = engine.handle_msg(raw_line.strip())
        if not keep_going:
            break
Example #27
0
    def moves_from_last_n_games(self, n, moves, shuffle,
                                column_family, column):
        """Dataset of moves sampled from the last `n` games of this table.

        Waits for fresh games first, then delegates to `moves_from_games`
        with the range ending at the latest game number.

        Args:
          n:  how many past games should be sourced.
          moves:  how many moves should be sampled from those games.
          shuffle:  passed through to `moves_from_games`.
          column_family:  name of the column family containing move examples.
          column:  name of the column containing move examples.

        Returns:
          The dataset returned by `moves_from_games`.

        Raises:
          ValueError:  if the latest game number is 0.
        """
        self.wait_for_fresh_games()
        newest = int(self.latest_game_number())
        utils.dbg('Latest game in %s: %s' % (self.table_name, newest))
        if newest == 0:
            raise ValueError('Cannot find a latest game in the table')

        oldest = int(max(0, newest - n))
        return self.moves_from_games(oldest, newest, moves, shuffle,
                                     column_family, column)
Example #28
0
    def maybe_load_model(self, path):
        """Loads the given model if it's different from the current one.

        The comparison with the current model path happens under the read
        lock; the load itself and the path update happen under the write
        lock.  `model_available` is set once a load has completed.

        Args:
          path: path of the model to load.
        """
        with self._mutex.read_lock():
            unchanged = path == self._model_path
        if unchanged:
            return

        with self._mutex.write_lock():
            dbg(time.time(), "loading %s" % path)
            self._locked_load_model(path)
            self._model_path = path
            dbg(time.time(), "loaded %s" % path)
        self.model_available.set()
Example #29
0
 def cmd_clear_board(self):
     """Start a new game, saving the finished one as an SGF file first.

     The SGF is written only when the player reports a result string and
     the position has more than one recent move.  The file is named after
     the current time ("%Y-%m-%d-%H:%M.sgf") and written to the working
     directory.  Saving is best-effort: a NotImplementedError is ignored
     silently and any other error is only logged.  The game is
     re-initialised with the configured komi in every case.
     """
     position = self._player.get_position()
     if (self._player.get_result_string() and
             position and len(position.recent) > 1):
         try:
             sgf = self._player.to_sgf()
             with open(datetime.now().strftime("%Y-%m-%d-%H:%M.sgf"), 'w') as f:
                 f.write(sgf)
         except NotImplementedError:
             pass
         except Exception:
             # Deliberately not a bare `except:`, so KeyboardInterrupt and
             # SystemExit still propagate.
             dbg("Error saving sgf")
     self._player.initialize_game(go.Position(komi=self._komi))
Example #30
0
def amend_manifest(vgls, manifest):
    """Apply the user's additions, exclusions and CVE whitelist to `manifest`.

    The manifest is modified in place: additional packages (from
    `vgls['addl']`) are merged in, `manifest['packages']` is filtered by
    the exclusion list (from `vgls['excld']`), and a non-empty whitelist is
    stored sorted under `manifest['whitelist']`.

    Args:
      vgls: settings mapping; the 'addl' and 'excld' keys are read here and
        the whole mapping is passed to `_build_whitelist`.
      manifest: the manifest mapping to amend.
    """
    extra = _get_addl_packages(vgls['addl'])
    if extra:
        manifest.update(extra)

    excluded = _get_excld_packages(vgls['excld'])
    _filter_excluded_packages(manifest['packages'], excluded)

    whitelist = _build_whitelist(vgls, manifest)
    if not whitelist:
        return
    dbg("Ignoring CVEs: %s" %
        json.dumps(whitelist, indent=4, sort_keys=True))
    manifest['whitelist'] = sorted(whitelist)