Example #1
    # Shuffle each model's self-play records into fixed-size training chunks,
    # skipping any model whose files have all been processed before.
    for model_name, record_files in sorted(model_gamedata.items()):
        with timer("Processing %s" % model_name):
            if set(record_files) <= already_processed:
                print("%s is already fully processed" % model_name)
                continue
            for i, example_batch in enumerate(
                    tqdm(
                        preprocessing.shuffle_tf_examples(
                            examples_per_record, record_files))):
                # Write each shuffled batch out as one numbered chunk.
                output_record = os.path.join(
                    output_directory,
                    '{}-{}.tfrecord.zz'.format(model_name, i))
                preprocessing.write_tf_examples(output_record,
                                                example_batch,
                                                serialize=False)
            already_processed.update(record_files)

    print("Processed %s new files" %
          (len(already_processed) - num_already_processed))
    with gfile.GFile(meta_file, 'w') as f:
        f.write('\n'.join(sorted(already_processed)))


# Expose each pipeline stage as a CLI subcommand.
parser = argparse.ArgumentParser()
argh.add_commands(parser, [gtp, bootstrap, train, selfplay, gather, evaluate])

if __name__ == '__main__':
    cloud_logging.configure()
    argh.dispatch(parser)
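
The example ends by registering each pipeline stage (gtp, bootstrap, train, ...) as a CLI subcommand via argh. Below is a minimal, self-contained sketch of that same dispatch pattern, using hypothetical toy commands (greet, shout) in place of the real pipeline functions, which are defined elsewhere in the project:

import argparse

import argh


def greet(name='world'):
    """Toy subcommand; argh maps the keyword default to a --name flag."""
    print("hello, %s" % name)


def shout(text):
    """Toy subcommand with one positional argument."""
    print(text.upper())


parser = argparse.ArgumentParser()
argh.add_commands(parser, [greet, shout])

if __name__ == '__main__':
    # e.g. `python cli.py greet --name minigo` or `python cli.py shout hi`
    argh.dispatch(parser)

Because argh derives flags from the function signatures, adding a new stage to the pipeline only requires appending its function to the list passed to add_commands.
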
Example #2
        # No meta file yet, so nothing has been processed.
        already_processed = set()

    num_already_processed = len(already_processed)

    # Shuffle each model's self-play records into fixed-size training chunks,
    # skipping any model whose files were all processed in a previous run.
    for model_name, record_files in sorted(model_gamedata.items()):
        if set(record_files) <= already_processed:
            continue
        print("Gathering files for %s:" % model_name)
        for i, example_batch in enumerate(
                tqdm(preprocessing.shuffle_tf_examples(examples_per_record, record_files))):
            # Write each shuffled batch out as one numbered chunk.
            output_record = os.path.join(output_directory,
                                         '{}-{}.tfrecord.zz'.format(model_name, i))
            preprocessing.write_tf_examples(
                output_record, example_batch, serialize=False)
        already_processed.update(record_files)

    print("Processed %s new files" %
          (len(already_processed) - num_already_processed))
    with gfile.GFile(meta_file, 'w') as f:
        f.write('\n'.join(sorted(already_processed)))
    qmeas.stop_time('gather')


# Expose each pipeline stage as a CLI subcommand.
parser = argparse.ArgumentParser()
argh.add_commands(parser, [gtp, bootstrap, train,
                           selfplay, gather, evaluate, validate])

if __name__ == '__main__':
    cloud_logging.configure()
    argh.dispatch(parser)
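
In both examples, preprocessing.shuffle_tf_examples and preprocessing.write_tf_examples come from the project's preprocessing module and are not shown in this listing. As a rough sketch of what writing one chunk could involve, assuming that serialize=False means the batch items are already-serialized tf.train.Example bytes and that the .zz suffix denotes zlib compression (both are assumptions, not taken from this listing):

import tensorflow as tf


def write_examples_zlib(output_path, serialized_examples):
    # Hypothetical stand-in for preprocessing.write_tf_examples with
    # serialize=False; the real helper may behave differently.
    options = tf.io.TFRecordOptions(compression_type='ZLIB')
    with tf.io.TFRecordWriter(output_path, options=options) as writer:
        for record in serialized_examples:
            writer.write(record)
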