Code example #1
def sync(root, force_all=False):
    last_ts = last_timestamp()
    if last_ts and not force_all:
        # Build a list of days from the day before our last timestamp to today
        num_days = (dt.datetime.utcnow() -
                    dt.datetime.utcfromtimestamp(last_ts) +
                    dt.timedelta(days=1)).days
        ds = [
            (dt.datetime.utcnow() - dt.timedelta(days=d)).strftime("%Y-%m-%d")
            for d in range(num_days + 1)
        ]
        for d in ds:
            if not os.path.isdir(os.path.join(root, d)):
                os.mkdir(os.path.join(root, d))
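            # Rsync that day's eval games from the remote eval dir down to the local mirror.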
            cmd = [
                "gsutil", "-m", "rsync", "-r",
                os.path.join(fsdb.eval_dir(), d),
                os.path.join(root, d)
            ]
            print(" ".join(cmd))
            subprocess.call(cmd)

        ingest_dirs(root, ds)
    else:
        cmd = ["gsutil", "-m", "rsync", "-r", fsdb.eval_dir(), root]
        print(" ".join(cmd))
        subprocess.call(cmd)
        dirs = os.listdir(root)
        ingest_dirs(root, dirs)
Code example #2
File: eval_models.py  Project: TAKAKEYA/mlpef
def main(unused_argv):
    logging.getLogger('mlperf_compliance').propagate = False

    sgf_dir = os.path.join(fsdb.eval_dir(), 'target')
    target = 'tf,' + os.path.join(fsdb.models_dir(), 'target.pb.og')
    models = load_train_times()

    timestamp_to_log = 0
    iter_evaluated = 0

    for i, (timestamp, name, path) in enumerate(models):
        minigo_print(key=constants.EVAL_START, metadata={'epoch_num': i + 1})

        iter_evaluated += 1
        winrate = wait(evaluate_model(path + '.og', target, sgf_dir, i + 1))

        minigo_print(key=constants.EVAL_ACCURACY,
                     value=winrate,
                     metadata={'epoch_num': i + 1})
        minigo_print(key=constants.EVAL_STOP, metadata={'epoch_num': i + 1})

        if winrate >= 0.50:
            timestamp_to_log = timestamp
            print('Model {} beat target after {}s'.format(name, timestamp))
            break

    minigo_print(key='eval_result',
                 metadata={
                     'iteration': iter_evaluated,
                     'timestamp': timestamp_to_log
                 })
Code example #3
def main(unused_argv):
    """Run the reinforcement learning loop."""

    print('Wiping dir %s' % FLAGS.base_dir, flush=True)
    shutil.rmtree(FLAGS.base_dir, ignore_errors=True)

    utils.ensure_dir_exists(fsdb.models_dir())
    utils.ensure_dir_exists(fsdb.selfplay_dir())
    utils.ensure_dir_exists(fsdb.holdout_dir())
    utils.ensure_dir_exists(fsdb.eval_dir())
    utils.ensure_dir_exists(fsdb.golden_chunk_dir())
    utils.ensure_dir_exists(fsdb.working_dir())

    # Copy the target model to the models directory so we can find it easily.
    shutil.copy('ml_perf/target.pb', fsdb.models_dir())

    logging.getLogger().addHandler(
        logging.FileHandler(os.path.join(FLAGS.base_dir, 'reinforcement.log')))
    formatter = logging.Formatter('[%(asctime)s] %(message)s',
                                  '%Y-%m-%d %H:%M:%S')
    for handler in logging.getLogger().handlers:
        handler.setFormatter(formatter)

    with utils.logged_timer('Total time'):
        for target_win_rate in rl_loop():
            if target_win_rate > 0.5:
                return logging.info('Passed exit criteria.')
        logging.info('Failed to converge.')
Code example #4
def main(unused_argv):
    """Run the reinforcement learning loop."""

    print('Wiping dir %s' % FLAGS.base_dir, flush=True)
    shutil.rmtree(FLAGS.base_dir, ignore_errors=True)

    utils.ensure_dir_exists(fsdb.models_dir())
    utils.ensure_dir_exists(fsdb.selfplay_dir())
    utils.ensure_dir_exists(fsdb.holdout_dir())
    utils.ensure_dir_exists(fsdb.eval_dir())
    utils.ensure_dir_exists(fsdb.golden_chunk_dir())
    utils.ensure_dir_exists(fsdb.working_dir())

    # Copy the flag files so there's no chance of them getting accidentally
    # overwritten while the RL loop is running.
    flags_dir = os.path.join(FLAGS.base_dir, 'flags')
    shutil.copytree(FLAGS.flags_dir, flags_dir)
    FLAGS.flags_dir = flags_dir

    # Copy the target model to the models directory so we can find it easily.
    shutil.copy('ml_perf/target.pb', fsdb.models_dir())

    logging.getLogger().addHandler(
        logging.FileHandler(os.path.join(FLAGS.base_dir, 'rl_loop.log')))
    formatter = logging.Formatter('[%(asctime)s] %(message)s',
                                  '%Y-%m-%d %H:%M:%S')
    for handler in logging.getLogger().handlers:
        handler.setFormatter(formatter)

    with utils.logged_timer('Total time'):
        try:
            rl_loop()
        finally:
            asyncio.get_event_loop().close()
Code example #5
async def evaluate_trained_model_parallel(state):
    """Evaluate the most recently trained model against the current best model.

    Args:
      state: the RL loop State instance.
    """
    all_tasks = []
    loop = asyncio.get_event_loop()
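    # Launch one evaluation task per parallel worker; they all run concurrently on the event loop.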
    for i in range(FLAGS.num_parallel_eval):
        all_tasks.append(
            loop.create_task(
                evaluate_model_parallel(
                    state.train_model_path_eval, state.best_model_path_eval,
                    os.path.join(fsdb.eval_dir(), state.train_model_name),
                    state.seed, i, 'parallel_eval')))
    all_lines = await asyncio.gather(*all_tasks, return_exceptions=True)

    total_games = 0
    total_wins = 0
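    # gather() was called with return_exceptions=True, so failures arrive here as exception objects; re-raise fatal ones.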
    for lines in all_lines:
        if type(lines) == RuntimeError or type(lines) == OSError:
            raise lines
        total_games = total_games + lines[0]
        total_wins = total_wins + lines[1]

    print('Iter = {} Eval of {} against best={}, games={}, wins={}'.format(
        state.iter_num, state.train_model_name, state.best_model_name,
        total_games, total_wins))

    win_rate = total_wins / total_games
    return win_rate
Code example #6
def main(unused_argv):
  """Run the reinforcement learning loop."""

  print('Wiping dir %s' % FLAGS.base_dir, flush=True)
  shutil.rmtree(FLAGS.base_dir, ignore_errors=True)
  dirs = [fsdb.models_dir(), fsdb.selfplay_dir(), fsdb.holdout_dir(),
          fsdb.eval_dir(), fsdb.golden_chunk_dir(), fsdb.working_dir()]
  for d in dirs:
    ensure_dir_exists(d)

  # Copy the flag files so there's no chance of them getting accidentally
  # overwritten while the RL loop is running.
  flags_dir = os.path.join(FLAGS.base_dir, 'flags')
  shutil.copytree(FLAGS.flags_dir, flags_dir)
  FLAGS.flags_dir = flags_dir

  # Copy the target model to the models directory so we can find it easily.
  for file_name in [
        "target.pb", "target_raw.ckpt.data-00000-of-00001",
        "target_raw.ckpt.index", "target_raw.ckpt.meta"]:
    shutil.copy(FLAGS.target_path[:-len("target.pb")] + file_name,
                os.path.join(fsdb.models_dir(), file_name))

  logging.getLogger().addHandler(
      logging.FileHandler(os.path.join(FLAGS.base_dir, 'rl_loop.log')))
  formatter = logging.Formatter('[%(asctime)s] %(message)s',
                                '%Y-%m-%d %H:%M:%S')
  for handler in logging.getLogger().handlers:
    handler.setFormatter(formatter)

  with logged_timer('Total time'):
    try:
      rl_loop()
    finally:
      asyncio.get_event_loop().close()
Code example #7
File: eval_models.py  Project: yihanpan1999/minigo
def main(unused_argv):
    sgf_dir = os.path.join(fsdb.eval_dir(), 'target')
    target = 'tf,' + os.path.join(fsdb.models_dir(), 'target.pb')
    models = load_train_times()
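    # Evaluate each trained model in order, stopping at the first one that reaches a 50% win rate against the target.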
    for i, (timestamp, name, path) in enumerate(models):
        winrate = wait(evaluate_model(path, name, target, sgf_dir))
        if winrate >= 0.50:
            break
Code example #8
File: eval_models.py  Project: mengxr/training-1
def main(unused_argv):
    sgf_dir = os.path.join(fsdb.eval_dir(), 'target')
    target = 'tf,' + os.path.join(fsdb.models_dir(), 'target.pb')
    models = load_train_times()
    for i, (timestamp, name, path) in enumerate(models):
        winrate = wait(evaluate_model(path, target, sgf_dir, i + 1))
        if winrate >= 0.50:
            print('Model {} beat target after {}s'.format(name, timestamp))
            break
Code example #9
async def evaluate_trained_model(state):
  """Evaluate the most recently trained model against the current best model.

  Args:
    state: the RL loop State instance.
  """

  return await evaluate_model(
      state.train_model_path, state.best_model_path,
      os.path.join(fsdb.eval_dir(), state.train_model_name), state.seed)
Code example #10
File: reinforcement.py  Project: delock/training
def evaluate(state, args, name, slice):
  sgf_dir = os.path.join(fsdb.eval_dir(), state.train_model_name)
  result = checked_run([
      'external/minigo/cc/main', '--mode=eval', '--parallel_games=100',
      '--model={}'.format(state.train_model_path),
      '--sgf_dir={}'.format(sgf_dir)
  ] + args, name)
  result = get_lines(result, slice)
  logging.info(result)
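  # Parse the trained model's win percentage out of the eval tool's summary output.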
  pattern = r'{}\s+\d+\s+(\d+\.\d+)%'.format(state.train_model_name)
  return float(re.search(pattern, result).group(1)) * 0.01
Code example #11
def evaluate_trained_model_noasync(state):
    """Evaluate the most recently trained model against the current best model.

    Args:
      state: the RL loop State instance.
    """

    return evaluate_model_noasync(
        state.train_model_path_eval, state.best_model_path_eval,
        os.path.join(fsdb.eval_dir(), state.train_model_name), state.seed,
        'parallel_eval')
Code example #12
async def eval_vs_target(state):
    # If we're bootstrapping a checkpoint, evaluate the newly trained model
    # against the target.
    # TODO(tommadams): evaluate the previously trained model against the
    # target in parallel with training the next model.
    if FLAGS.bootstrap and state.iter_num > 2:
        target_model_path = os.path.join(fsdb.models_dir(), 'target.pb')
        sgf_dir = os.path.join(fsdb.eval_dir(),
                               '{}-vs-target'.format(state.train_model_name))
        win_rate_vs_target = await evaluate_model(state.selfplay_model_path,
                                                  target_model_path, sgf_dir)
        return win_rate_vs_target
    return None
Code example #13
def main(unused_argv):
    """Run the reinforcement learning loop."""
    logging.getLogger('mlperf_compliance').propagate = False

    ##-->multi-node setup
    if FLAGS.use_multinode:
        mpi_comm = MPI.COMM_WORLD
        mpi_rank = mpi_comm.Get_rank()
        mpi_size = mpi_comm.Get_size()
        print('[MPI Init] MPI rank {}, mpi size is {} host is {}'.format(
            mpi_rank, mpi_size, socket.gethostname()))
    else:
        mpi_comm = None
        mpi_rank = 0
        mpi_size = 1

    print('Wiping dir %s' % FLAGS.base_dir, flush=True)
    shutil.rmtree(FLAGS.base_dir, ignore_errors=True)
    dirs = [
        fsdb.models_dir(),
        fsdb.selfplay_dir(),
        fsdb.holdout_dir(),
        fsdb.eval_dir(),
        fsdb.golden_chunk_dir(),
        fsdb.working_dir()
    ]

    ## Shared filesystem for data exchange (temporary solution, 5/6/2019).
    if FLAGS.use_multinode:
        ensure_dir_exists(FLAGS.shared_dir_exchange)
    for d in dirs:
        ensure_dir_exists(d)

    # Copy the flag files so there's no chance of them getting accidentally
    # overwritten while the RL loop is running.
    flags_dir = os.path.join(FLAGS.base_dir, 'flags')
    shutil.copytree(FLAGS.flags_dir, flags_dir)
    FLAGS.flags_dir = flags_dir

    # Copy the target model to the models directory so we can find it easily.
    shutil.copy(FLAGS.target_path, os.path.join(fsdb.models_dir(),
                                                'target.pb'))
    shutil.copy(FLAGS.target_path + '.og',
                os.path.join(fsdb.models_dir(), 'target.pb.og'))

    with logged_timer('Total time from mpi_rank={}'.format(mpi_rank)):
        try:
            rl_loop(mpi_comm, mpi_rank, mpi_size)
        finally:
            asyncio.get_event_loop().close()
Code example #14
def rl_loop():
    """The main reinforcement learning (RL) loop."""

    state = State()
    prev_win_rate_vs_target = 0

    if FLAGS.bootstrap:
        wait(bootstrap_selfplay(state))
    else:
        initialize_from_checkpoint(state)

    # Start the selfplay workers. They will wait for a model to become available
    # in the training directory before starting to play.
    selfplay_processes, selfplay_logs = wait(start_selfplay())

    try:
        # Now start the full training loop.
        while state.iter_num < FLAGS.iterations:
            state.iter_num += 1

            wait_for_training_examples(state, FLAGS.min_games_per_iteration)
            tf_records = wait(sample_training_examples(state))

            wait(train(state, tf_records))

            # If we're bootstrapping a checkpoint, evaluate the newly trained model
            # against the target.
            # TODO(tommadams): evaluate the previously trained model against the
            # target in parallel with training the next model.
            if FLAGS.bootstrap and state.iter_num > 15:
                target_model_path = os.path.join(fsdb.models_dir(),
                                                 'target.pb')
                sgf_dir = os.path.join(
                    fsdb.eval_dir(),
                    '{}-vs-target'.format(state.train_model_name))
                win_rate_vs_target = wait(
                    evaluate_model(state.train_model_path, target_model_path,
                                   sgf_dir))
                if (win_rate_vs_target >= FLAGS.bootstrap_target_win_rate
                        and prev_win_rate_vs_target > 0):
                    # The trained model won a sufficient number of games against
                    # the target. Create the checkpoint that will be used to start
                    # the real benchmark and exit.
                    create_checkpoint()
                    break
                prev_win_rate_vs_target = win_rate_vs_target
    finally:
        wait(end_selfplay(selfplay_processes, selfplay_logs))
Code example #15
def main(unused_argv):
  sgf_dir = os.path.join(fsdb.eval_dir(), 'target')
  target = 'tf,' + os.path.join(fsdb.models_dir(), 'target.pb')
  models = load_train_times()
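  # Evaluate each model against the target, logging MLPerf eval markers, until one wins at least 50% of games.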
  for i, (timestamp, name, path) in enumerate(models):
    mll.eval_start(i)
    winrate = wait(evaluate_model(path, target, sgf_dir, i + 1))
    mll.eval_stop(i)
    mll.eval_accuracy(i, winrate)
    if winrate >= 0.50:
      print('Model {} beat target after {}s'.format(name, timestamp))
      mll.eval_result(i, timestamp)
      mll.run_stop('success')
      return
  mll.eval_result(i, 0)
  mll.run_stop('aborted')
Code example #16
def evaluate(state, against_model):
    eval_model = state.train_model_name
    eval_model_path = os.path.join(fsdb.models_dir(), eval_model)
    against_model_path = os.path.join(fsdb.models_dir(), against_model)
    sgf_dir = os.path.join(fsdb.eval_dir(), eval_model)
    result = checked_run([
        'bazel-bin/cc/eval', '--num_readouts=100', '--parallel_games=100',
        '--model={}.pb'.format(eval_model_path),
        '--model_two={}.pb'.format(against_model_path),
        '--sgf_dir={}'.format(sgf_dir)
    ] + cc_flags(state), 'evaluation against ' + against_model)
    result = get_lines(result, make_slice[-7:])
    logging.info(result)
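    # Extract the eval model's win percentage from the printed results table.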
    pattern = r'{}\s+\d+\s+(\d+\.\d+)%'.format(eval_model)
    win_rate = float(re.search(pattern, result).group(1)) * 0.01
    logging.info('Win rate %s vs %s: %.3f', eval_model, against_model,
                 win_rate)
    return win_rate
Code example #17
def evaluate(state):
    eval_model = state.train_model_name
    best_model = state.best_model_name
    eval_model_path = os.path.join(fsdb.models_dir(), eval_model)
    best_model_path = os.path.join(fsdb.models_dir(), best_model)
    sgf_dir = os.path.join(fsdb.eval_dir(), eval_model)
    result = checked_run(
        'evaluation', 'bazel-bin/cc/eval',
        '--flagfile={}'.format(os.path.join(FLAGS.flags_dir, 'eval.flags')),
        '--model={}.pb'.format(eval_model_path),
        '--model_two={}.pb'.format(best_model_path),
        '--sgf_dir={}'.format(sgf_dir), '--seed={}'.format(state.seed))
    result = get_lines(result, make_slice[-7:])
    logging.info(result)
    pattern = r'{}\s+\d+\s+(\d+\.\d+)%'.format(eval_model)
    win_rate = float(re.search(pattern, result).group(1)) * 0.01
    logging.info('Win rate %s vs %s: %.3f', eval_model, best_model, win_rate)
    return win_rate
Code example #18
def main(unused_argv):
  """Run the reinforcement learning loop."""

  mll.init_start()
  print('Wiping dir %s' % FLAGS.base_dir, flush=True)
  shutil.rmtree(FLAGS.base_dir, ignore_errors=True)
  dirs = [fsdb.models_dir(), fsdb.selfplay_dir(), fsdb.holdout_dir(),
          fsdb.eval_dir(), fsdb.golden_chunk_dir(), fsdb.working_dir(),
          fsdb.mpi_log_dir()]
  for d in dirs:
    ensure_dir_exists(d)

  # Copy the flag files so there's no chance of them getting accidentally
  # overwritten while the RL loop is running.
  flags_dir = os.path.join(FLAGS.base_dir, 'flags')
  shutil.copytree(FLAGS.flags_dir, flags_dir)
  FLAGS.flags_dir = flags_dir

  # Copy the target model to the models directory so we can find it easily.
  shutil.copy(FLAGS.target_path, os.path.join(fsdb.models_dir(), 'target.pb'))

  logging.getLogger().addHandler(
      logging.FileHandler(os.path.join(FLAGS.base_dir, 'rl_loop.log')))
  formatter = logging.Formatter('[%(asctime)s] %(message)s',
                                '%Y-%m-%d %H:%M:%S')
  for handler in logging.getLogger().handlers:
    handler.setFormatter(formatter)

  logging.info('Selfplay nodes = {}'.format(FLAGS.selfplay_node))
  logging.info('Train nodes = {}'.format(FLAGS.train_node))
  logging.info('Eval nodes = {}'.format(FLAGS.eval_node))

  with logged_timer('Total time'):
    try:
      mll.init_stop()
      mll.run_start()
      rl_loop()
    finally:
      asyncio.get_event_loop().close()
Code example #19
def main(unused_argv):

    for i in range(0, NUM_LOOP):
        if i == 0:
            src_model_name = shipname.generate(0)
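            # Point the fsdb path helpers at this model's own base directory.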
            fsdb.switch_base(os.path.join(base_dir, src_model_name))
            src_model_path = os.path.join(fsdb.models_dir(), src_model_name)
            bootstrap_model_path = os.path.join(fsdb.models_dir(),
                                                src_model_name)
            mask_flags.checked_run([
                'python3', 'bootstrap.py',
                '--export_path={}'.format(bootstrap_model_path),
                '--work_dir={}'.format(fsdb.working_dir()),
                '--flagfile=rl_loop/local_flags'
            ])
            dst_model_name = shipname.generate(1)
            fsdb.switch_base(os.path.join(base_dir, dst_model_name))
        else:
            src_model_name = dst_model_name
            src_model_path = os.path.join(fsdb.models_dir(), src_model_name)
            dst_model_name = shipname.generate(i + 1)
            fsdb.switch_base(os.path.join(base_dir, dst_model_name))

        utils.ensure_dir_exists(fsdb.models_dir())
        utils.ensure_dir_exists(fsdb.selfplay_dir())
        utils.ensure_dir_exists(fsdb.holdout_dir())
        utils.ensure_dir_exists(fsdb.sgf_dir())
        utils.ensure_dir_exists(fsdb.eval_dir())
        utils.ensure_dir_exists(fsdb.golden_chunk_dir())
        utils.ensure_dir_exists(fsdb.working_dir())

        #bootstrap_name = shipname.generate(0)
        #bootstrap_model_path = os.path.join(fsdb.models_dir(), bootstrap_name)

        print(src_model_name)
        print(src_model_path)
        selfplay_cmd = [
            'python3', 'selfplay.py',
            '--load_file={}'.format(src_model_path),
            '--selfplay_dir={}'.format(
                os.path.join(fsdb.selfplay_dir(), dst_model_name)),
            '--holdout_dir={}'.format(
                os.path.join(fsdb.holdout_dir(), dst_model_name)),
            '--sgf_dir={}'.format(fsdb.sgf_dir()),
            '--holdout_pct=0',
            '--flagfile=rl_loop/local_flags'
        ]

        # Selfplay twice
        mask_flags.checked_run(selfplay_cmd)
        mask_flags.checked_run(selfplay_cmd)

        # and once more to generate a held-out game for validation.
        # This exploits the flags behavior where, if a flag is passed twice, the second one wins.
        mask_flags.checked_run(selfplay_cmd + ['--holdout_pct=100'])

        # Double check that at least one sgf has been generated.
        assert os.listdir(os.path.join(fsdb.sgf_dir(), 'full'))

        print("Making shuffled golden chunk from selfplay data...")
        # TODO(amj): refactor example_buffer so it can be called the same way
        # as everything else.
        eb.make_chunk_for(output_dir=fsdb.golden_chunk_dir(),
                          local_dir=fsdb.working_dir(),
                          game_dir=fsdb.selfplay_dir(),
                          model_num=1,
                          positions=64,
                          threads=8,
                          sampling_frac=1)

        tf_records = sorted(
            gfile.Glob(os.path.join(fsdb.golden_chunk_dir(), '*.tfrecord.zz')))

        #trained_model_name = shipname.generate(1)
        trained_model_name = dst_model_name
        trained_model_path = os.path.join(fsdb.models_dir(),
                                          trained_model_name)

        # Train on shuffled game data
        mask_flags.checked_run([
            'python3', 'train.py', *tf_records,
            '--work_dir={}'.format(fsdb.working_dir()),
            '--export_path={}'.format(trained_model_path),
            '--flagfile=rl_loop/local_flags'
        ])

    print("Finished!")
Code example #20
async def evaluate_target_model(state):
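    """Evaluate the most recently trained model against the target model."""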
    sgf_dir = os.path.join(fsdb.eval_dir(), 'target')
    target = 'tf,' + os.path.join(fsdb.models_dir(), 'target.pb')
    return await evaluate_model(state.train_model_path, target, sgf_dir,
                                state.iter_num)
Code example #21
def rl_loop():
    """The main reinforcement learning (RL) loop."""

    state = State()

    if FLAGS.bootstrap:
        # Play the first round of selfplay games with a fake model that returns
        # random noise. We do this instead of playing multiple games using a
        # single model bootstrapped with random noise to avoid any initial bias.
        wait(bootstrap_selfplay(state))

        # Train a real model from the random selfplay games.
        tf_records = wait(sample_training_examples(state))
        state.iter_num += 1
        wait(train(state, tf_records))

        # Select the newly trained model as the best.
        state.best_model_name = state.train_model_name
        state.gen_num += 1

        # Run selfplay using the new model.
        wait(selfplay(state))
    else:
        # Start from a partially trained model.
        initialize_from_checkpoint(state)

    prev_win_rate_vs_target = 0

    # Now start the full training loop.
    while state.iter_num <= FLAGS.iterations:
        tf_records = wait(sample_training_examples(state))
        state.iter_num += 1

        # Run selfplay in parallel with sequential (train, eval).
        model_win_rate, _ = wait(
            [train_eval(state, tf_records),
             selfplay(state)])

        # If we're bootstrapping a checkpoint, evaluate the newly trained model
        # against the target.
        if FLAGS.bootstrap:
            target_model_path = os.path.join(fsdb.models_dir(), 'target.pb')
            sgf_dir = os.path.join(
                fsdb.eval_dir(), '{}-vs-target'.format(state.train_model_name))
            win_rate_vs_target = wait(
                evaluate_model(state.train_model_path, target_model_path,
                               sgf_dir))
            if (win_rate_vs_target >= FLAGS.bootstrap_target_win_rate
                    and prev_win_rate_vs_target > 0):
                # The trained model won a sufficient number of games against
                # the target. Create the checkpoint that will be used to start
                # the real benchmark and exit.
                create_checkpoint()
                break
            prev_win_rate_vs_target = win_rate_vs_target

        if model_win_rate >= FLAGS.gating_win_rate:
            # Promote the trained model to the best model and increment the
            # generation number.
            state.best_model_name = state.train_model_name
            state.gen_num += 1
Code example #22
def main(unused_argv):
    """Run the reinforcement learning loop."""
    utils.ensure_dir_exists(fsdb.models_dir())
    utils.ensure_dir_exists(fsdb.selfplay_dir())
    utils.ensure_dir_exists(fsdb.holdout_dir())
    utils.ensure_dir_exists(fsdb.sgf_dir())
    utils.ensure_dir_exists(fsdb.eval_dir())
    utils.ensure_dir_exists(fsdb.golden_chunk_dir())
    utils.ensure_dir_exists(fsdb.working_dir())

    bootstrap_name = shipname.generate(0)
    bootstrap_model_path = os.path.join(fsdb.models_dir(), bootstrap_name)
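    # Bootstrap an initial model and export it into the models directory.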
    mask_flags.checked_run([
        'python3', 'bootstrap.py',
        '--export_path={}'.format(bootstrap_model_path),
        '--work_dir={}'.format(fsdb.working_dir()),
        '--flagfile=rl_loop/local_flags'
    ])

    selfplay_cmd = [
        'python3', 'selfplay.py',
        '--load_file={}'.format(bootstrap_model_path),
        '--selfplay_dir={}'.format(
            os.path.join(fsdb.selfplay_dir(), bootstrap_name)),
        '--holdout_dir={}'.format(
            os.path.join(fsdb.holdout_dir(), bootstrap_name)),
        '--sgf_dir={}'.format(fsdb.sgf_dir()),
        '--holdout_pct=0',
        '--flagfile=rl_loop/local_flags'
    ]

    # Selfplay twice
    mask_flags.checked_run(selfplay_cmd)
    mask_flags.checked_run(selfplay_cmd)
    # and once more to generate a held-out game for validation.
    # This exploits the flags behavior where, if a flag is passed twice, the second one wins.
    mask_flags.checked_run(selfplay_cmd + ['--holdout_pct=100'])

    # Double check that at least one sgf has been generated.
    assert os.listdir(os.path.join(fsdb.sgf_dir(), 'full'))

    print("Making shuffled golden chunk from selfplay data...")
    # TODO(amj): refactor example_buffer so it can be called the same way
    # as everything else.
    eb.make_chunk_for(output_dir=fsdb.golden_chunk_dir(),
                      local_dir=fsdb.working_dir(),
                      game_dir=fsdb.selfplay_dir(),
                      model_num=1,
                      positions=64,
                      threads=8,
                      sampling_frac=1)

    tf_records = sorted(
        gfile.Glob(os.path.join(fsdb.golden_chunk_dir(), '*.tfrecord.zz')))

    trained_model_name = shipname.generate(1)
    trained_model_path = os.path.join(fsdb.models_dir(), trained_model_name)

    # Train on shuffled game data
    mask_flags.checked_run([
        'python3', 'train.py', *tf_records,
        '--work_dir={}'.format(fsdb.working_dir()),
        '--export_path={}'.format(trained_model_path),
        '--flagfile=rl_loop/local_flags'
    ])

    # Validate the trained model on held out game
    mask_flags.checked_run([
        'python3', 'validate.py',
        os.path.join(fsdb.holdout_dir(), bootstrap_name),
        '--work_dir={}'.format(fsdb.working_dir()),
        '--flagfile=rl_loop/local_flags'
    ])

    # Verify that the trained model works for selfplay.
    # This exploits the flags behavior where, if a flag is passed twice, the second one wins.
    mask_flags.checked_run(selfplay_cmd +
                           ['--load_file={}'.format(trained_model_path)])

    mask_flags.checked_run([
        'python3', 'evaluate.py', bootstrap_model_path, trained_model_path,
        '--games=1', '--eval_sgf_dir={}'.format(fsdb.eval_dir()),
        '--flagfile=rl_loop/local_flags'
    ])
    print("Completed integration test!")