Example 1
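
Both examples rely on the same imports, which the listing omits. A plausible reconstruction, assuming the modules come from the Minigo project (the names `dual_net`, `example_buffer`, `main`, and `preprocessing` are inferred from the calls below; `preprocessing` is only touched in Example 2):

import os
import subprocess
import tempfile

from absl import flags

import dual_net
import example_buffer as eb
import main
import preprocessing
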
def rl_loop():
    """Run the reinforcement learning loop

    This is meant to be more of an integration test than a realistic way to run
    the reinforcement learning.
    """
    # TODO(brilee): move these all into appropriate local_flags file.
    # monkeypatch the hyperparams so that we get a quickly executing network.
    flags.FLAGS.conv_width = 8
    flags.FLAGS.fc_width = 16
    flags.FLAGS.trunk_layers = 1
    flags.FLAGS.train_batch_size = 16
    flags.FLAGS.shuffle_buffer_size = 1000
    dual_net.EXAMPLES_PER_GENERATION = 64

    # Use only a few MCTS readouts per move so selfplay games finish quickly.
    flags.FLAGS.num_readouts = 10

    with tempfile.TemporaryDirectory() as base_dir:
        flags.FLAGS.base_dir = base_dir
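        # Lay out the directories each pipeline stage reads from and writes to.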
        working_dir = os.path.join(base_dir, 'models_in_training')
        flags.FLAGS.model_dir = working_dir
        model_save_path = os.path.join(base_dir, 'models', '000000-bootstrap')
        local_eb_dir = os.path.join(base_dir, 'scratch')
        next_model_save_file = os.path.join(base_dir, 'models',
                                            '000001-nextmodel')
        selfplay_dir = os.path.join(base_dir, 'data', 'selfplay')
        model_selfplay_dir = os.path.join(selfplay_dir, '000000-bootstrap')
        gather_dir = os.path.join(base_dir, 'data', 'training_chunks')
        holdout_dir = os.path.join(base_dir, 'data', 'holdout',
                                   '000000-bootstrap')
        sgf_dir = os.path.join(base_dir, 'sgf', '000000-bootstrap')
        os.makedirs(os.path.join(base_dir, 'data'), exist_ok=True)

        print("Creating random initial weights...")
        main.bootstrap(working_dir, model_save_path)
        print("Playing some games...")
        # Do two selfplay runs to test gather functionality
        main.selfplay(load_file=model_save_path,
                      output_dir=model_selfplay_dir,
                      output_sgf=sgf_dir,
                      holdout_pct=0)
        main.selfplay(load_file=model_save_path,
                      output_dir=model_selfplay_dir,
                      output_sgf=sgf_dir,
                      holdout_pct=0)
        # Do one holdout run to test validation
        main.selfplay(load_file=model_save_path,
                      holdout_dir=holdout_dir,
                      output_dir=model_selfplay_dir,
                      output_sgf=sgf_dir,
                      holdout_pct=100)

        print("See sgf files here?")
        sgf_listing = subprocess.check_output(
            ["ls", "-l", os.path.join(sgf_dir, "full")])
        print(sgf_listing.decode("utf-8"))

        print("Gathering game output...")
        eb.make_chunk_for(output_dir=gather_dir,
                          local_dir=local_eb_dir,
                          game_dir=selfplay_dir,
                          model_num=1,
                          positions=dual_net.EXAMPLES_PER_GENERATION,
                          threads=8,
                          samples_per_game=200)

        print("Training on gathered game data...")
        main.train_dir(gather_dir, next_model_save_file)
        print("Trying validate on 'holdout' game...")
        main.validate(working_dir, holdout_dir)
        print("Verifying that new checkpoint is playable...")
        main.selfplay(load_file=next_model_save_file,
                      holdout_dir=holdout_dir,
                      output_dir=model_selfplay_dir,
                      output_sgf=sgf_dir)
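
A minimal way to invoke the loop as a script, assuming absl's `app` module handles flag parsing (this wrapper is an illustration, not part of the original example):

from absl import app


def run(unused_argv):
    rl_loop()


if __name__ == '__main__':
    # app.run parses command-line flags before calling the entry point.
    app.run(run)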

Example 2

def rl_loop():
    """Run the reinforcement learning loop

    This is meant to be more of an integration test than a realistic way to run
    the reinforcement learning.
    """
    # TODO(brilee): move these all into appropriate local_flags file.
    # monkeypatch the hyperparams so that we get a quickly executing network.
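    # Replacing get_default_hyperparams means any caller of dual_net builds
    # this tiny network instead of the full-size one.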
    dual_net.get_default_hyperparams = lambda **kwargs: {
        'k': 8,
        'fc_width': 16,
        'num_shared_layers': 1,
        'l2_strength': 1e-4,
        'momentum': 0.9
    }

    dual_net.TRAIN_BATCH_SIZE = 16
    dual_net.EXAMPLES_PER_GENERATION = 64

    # monkeypatch the shuffle buffer size so we don't spin forever shuffling up positions.
    preprocessing.SHUFFLE_BUFFER_SIZE = 1000

    # Use only a few MCTS readouts per move so selfplay games finish quickly.
    flags.FLAGS.num_readouts = 10

    with tempfile.TemporaryDirectory() as base_dir:
        working_dir = os.path.join(base_dir, 'models_in_training')
        model_save_path = os.path.join(base_dir, 'models', '000000-bootstrap')
        local_eb_dir = os.path.join(base_dir, 'scratch')
        next_model_save_file = os.path.join(base_dir, 'models',
                                            '000001-nextmodel')
        selfplay_dir = os.path.join(base_dir, 'data', 'selfplay')
        model_selfplay_dir = os.path.join(selfplay_dir, '000000-bootstrap')
        gather_dir = os.path.join(base_dir, 'data', 'training_chunks')
        holdout_dir = os.path.join(base_dir, 'data', 'holdout',
                                   '000000-bootstrap')
        sgf_dir = os.path.join(base_dir, 'sgf', '000000-bootstrap')
        os.makedirs(os.path.join(base_dir, 'data'), exist_ok=True)

        print("Creating random initial weights...")
        main.bootstrap(working_dir, model_save_path)
        print("Playing some games...")
        # Do two selfplay runs to test gather functionality
        main.selfplay(load_file=model_save_path,
                      output_dir=model_selfplay_dir,
                      output_sgf=sgf_dir,
                      holdout_pct=0)
        main.selfplay(load_file=model_save_path,
                      output_dir=model_selfplay_dir,
                      output_sgf=sgf_dir,
                      holdout_pct=0)
        # Do one holdout run to test validation
        main.selfplay(load_file=model_save_path,
                      holdout_dir=holdout_dir,
                      output_dir=model_selfplay_dir,
                      output_sgf=sgf_dir,
                      holdout_pct=100)

        print("See sgf files here?")
        sgf_listing = subprocess.check_output(
            ["ls", "-l", os.path.join(sgf_dir, "full")])
        print(sgf_listing.decode("utf-8"))

        print("Gathering game output...")
        eb.make_chunk_for(output_dir=gather_dir,
                          local_dir=local_eb_dir,
                          game_dir=selfplay_dir,
                          model_num=1,
                          positions=dual_net.EXAMPLES_PER_GENERATION,
                          threads=8,
                          samples_per_game=200)

        print("Training on gathered game data...")
        main.train_dir(working_dir,
                       gather_dir,
                       next_model_save_file,
                       generation_num=1)
        print("Trying validate on 'holdout' game...")
        main.validate(working_dir, holdout_dir)
        print("Verifying that new checkpoint is playable...")
        main.selfplay(load_file=next_model_save_file,
                      holdout_dir=holdout_dir,
                      output_dir=model_selfplay_dir,
                      output_sgf=sgf_dir)