예제 #1
0
def validate_pro():
    """Run validate.py against the professional dataset.

    The command line is assembled from FLAGS.pro_dataset, TPU_NAME and the
    fsdb working directory, then handed to mask_flags.run. The result of
    that run is not returned.
    """
    argv = ['python3', 'validate.py', FLAGS.pro_dataset, '--use_tpu']
    argv.append('--tpu_name={}'.format(TPU_NAME))
    argv.append('--work_dir={}'.format(fsdb.working_dir()))
    argv.extend(
        ['--flagfile=rl_loop/distributed_flags', '--validate_name=pro'])
    mask_flags.run(argv)
예제 #2
0
def validate_holdout_selfplay():
    """Validate on held-out selfplay data.

    Walks the subdirectories of fsdb.holdout_dir() in reverse listing
    order, keeps just enough of them to cover 20,000 directory entries, and
    runs validate.py on the selected directories through mask_flags.run.
    The result of that run is not returned.
    """
    holdout_root = fsdb.holdout_dir()

    # Lazily yields a subdirectory's path once for every entry inside it, so
    # counting yielded items is the same as counting holdout files.
    per_entry_dirs = (
        os.path.join(holdout_root, d)
        for d in reversed(gfile.ListDirectory(holdout_root))
        if gfile.IsDirectory(os.path.join(holdout_root, d))
        for _ in gfile.ListDirectory(os.path.join(holdout_root, d)))

    # This is a roundabout way of computing how many hourly directories we
    # need to read in order to encompass 20,000 holdout games: take the
    # first 20,000 per-entry paths, then collapse the duplicates.
    # The stop count must be passed to islice, not to set().
    holdout_dirs = set(itertools.islice(per_entry_dirs, 20000))

    cmd = ['python3', 'validate.py'] + list(holdout_dirs) + [
        '--use_tpu', '--tpu_name={}'.format(TPU_NAME),
        '--flagfile=rl_loop/distributed_flags', '--expand_validation_dirs'
    ]
    mask_flags.run(cmd)
예제 #3
0
def freeze(save_path, rewrite_tpu=False):
    """Run freeze_graph.py on a saved model.

    Args:
        save_path: value passed to freeze_graph.py as --model_path.
        rewrite_tpu: when truthy, also pass --use_tpu and --tpu_name
            (taken from TPU_NAME) to freeze_graph.py.

    Returns:
        Whatever mask_flags.run returns for the assembled command.
    """
    argv = ['python3', 'freeze_graph.py']
    argv.append('--work_dir={}'.format(fsdb.working_dir()))
    argv.append('--model_path={}'.format(save_path))

    # The TPU flags are only added on request.
    if rewrite_tpu:
        argv += ['--use_tpu', '--tpu_name={}'.format(TPU_NAME)]

    return mask_flags.run(argv)
예제 #4
0
def train():
    """Train the next model generation and freeze its checkpoints.

    Looks up the latest model via fsdb.get_latest_model(), derives the next
    model number and name, runs train.py, and then freezes the result twice
    (once with TPU rewriting, once without).

    Returns:
        The object returned by mask_flags.run for the first step that
        failed (training, TPU freeze, or plain freeze), or the training
        step's result when every step succeeded.
    """
    model_num, model_name = fsdb.get_latest_model()
    print("Training on gathered game data, initializing from {}".format(
        model_name))
    new_model_num = model_num + 1
    new_model_name = shipname.generate(new_model_num)
    print("New model will be {}".format(new_model_name))
    save_file = os.path.join(fsdb.models_dir(), new_model_name)

    # TODO(jacksona): Refactor train.py to take the filepath as a flag.
    cmd = [
        'python3', 'train.py', '__unused_file__', '--use_tpu', '--use_bt',
        '--work_dir={}'.format(fsdb.working_dir()),
        '--tpu_name={}'.format(TPU_NAME),
        '--flagfile=rl_loop/distributed_flags',
        '--export_path={}'.format(save_file)
    ]

    completed_process = mask_flags.run(cmd)
    # Any non-zero return code is a failure. A child killed by a signal
    # reports a negative code on POSIX, which a `> 0` check would miss.
    # NOTE(review): assumes mask_flags.run returns a
    # subprocess.CompletedProcess-like object -- confirm.
    if completed_process.returncode != 0:
        print("Training failed!")
        return completed_process

    # Train.py already copies the {data,index,meta} files to $BUCKET/models
    # Persist the checkpoint two ways:
    # Freeze the .ckpt file in the work_dir for the TPU selfplayers
    # Freeze a non-tpu version of the graph for later GPU use.
    latest_checkpoint = tf.train.latest_checkpoint(fsdb.working_dir())
    p = freeze(latest_checkpoint, rewrite_tpu=True)
    if p.returncode != 0:
        print("== TPU freeze failed!")
        return p

    p = freeze(save_file, rewrite_tpu=False)
    if p.returncode != 0:
        print("== Model freeze failed!")
        return p

    return completed_process
예제 #5
0
def freeze(save_path, rewrite_tpu=False):
    """Run freeze_graph.py on a saved model.

    Args:
        save_path: value passed to freeze_graph.py as --model_path.
        rewrite_tpu: accepted for interface compatibility; this
            implementation does not read it.

    Returns:
        Whatever mask_flags.run returns for the assembled command.
    """
    work_dir_flag = '--work_dir={}'.format(fsdb.working_dir())
    model_path_flag = '--model_path={}'.format(save_path)
    return mask_flags.run(
        ['python3', 'freeze_graph.py', work_dir_flag, model_path_flag])