def compare(pair):
    """Compare two TFRecord files batch-by-batch and report how many match.

    Args:
        pair: ``(position, in_path, out_path)`` — the tqdm progress-bar row
            and the two record files to compare in lockstep.

    Side effects:
        Updates a tqdm progress bar with a running "equal" ratio postfix.
    """
    position, in_path, out_path = pair
    num_batches = dual_net.EXAMPLES_PER_GENERATION // FLAGS.batch_size + 1
    compare_iter = tqdm(zip(batched_reader(in_path), batched_reader(out_path)),
                        desc=os.path.basename(in_path),
                        position=position,
                        total=num_batches)

    count = 0
    equal = 0
    results = {}
    for a, b in compare_iter:
        # a, b are batched records
        xa, ra = preprocessing.batch_parse_tf_example(len(a), a)
        xb, rb = preprocessing.batch_parse_tf_example(len(b), b)
        # BUG FIX: the original created a bare tf.Session() per batch and
        # never closed it, leaking one session (and its resources) per
        # iteration. A context manager closes each session promptly.
        with tf.Session() as sess:
            xa, xb, ra, rb = sess.run([xa, xb, ra, rb])

        # NOTE: This relies on python3 deterministic dictionaries.
        # Assumes ra/rb each hold exactly two tensors (presumably
        # pi_tensor and value_tensor) — TODO confirm against
        # batch_parse_tf_example.
        values = [xa] + list(ra.values()) + [xb] + list(rb.values())
        for xa, pa, va, xb, pb, vb in zip(*values):
            count += 1
            assert va == vb
            # NOTE(review): `equal` increments by up to 2 per record (board
            # match + pi match), so the ratio can exceed 1.0 — confirm
            # whether that display is intended.
            equal += (xa == xb).all() + (pa == pb).all()
        results['equal'] = "{}/{} = {:.3f}".format(equal, count, equal / count)
        compare_iter.set_postfix(results)
# --- Example #2 ("Пример #2") --- scraper artifact kept as a comment; the
# original bare text would raise a NameError at import time.
def convert(paths):
    """Rotate every example in ``in_path`` and write the result to ``out_path``.

    Args:
        paths: ``(position, in_path, out_path)`` — the tqdm progress-bar row,
            the source TFRecord file, and the destination file.

    Returns:
        A status string when ``out_path`` already exists (either "already
        existed" or a size-mismatch error); ``None`` after writing a new file.
    """
    position, in_path, out_path = paths
    assert tf.gfile.Exists(in_path)
    assert tf.gfile.Exists(os.path.dirname(out_path))

    in_size = get_size(in_path)
    if tf.gfile.Exists(out_path):
        # Make sure out_path is about the size of in_path
        size = get_size(out_path)
        # BUG FIX: the original read `(size - good_size)`, but `good_size`
        # is undefined (NameError on this path). The comment above shows the
        # intended baseline is in_size. `+ 1` guards against division by zero.
        error = (size - in_size) / (in_size + 1)
        # 5% smaller to 20% larger
        if -0.05 < error < 0.20:
            return out_path + " already existed"
        return "ERROR on file size ({:.1f}% diff) {}".format(
            100 * error, out_path)

    #assert abs(in_size/2**20 - 670) <= 80, in_size
    num_batches = dual_net.EXAMPLES_PER_GENERATION // FLAGS.batch_size + 1

    with tf.python_io.TFRecordWriter(out_path, OPTS) as writer:
        record_iter = tqdm(batched_reader(in_path),
                           desc=os.path.basename(in_path),
                           position=position,
                           total=num_batches)
        for record in record_iter:
            xs, rs = preprocessing.batch_parse_tf_example(len(record), record)
            # Undo cast in batch_parse_tf_example.
            xs = tf.cast(xs, tf.uint8)

            # map the rotation function.
            x_rot, r_rot = preprocessing._random_rotation(xs, rs)

            with tf.Session() as sess:
                x_rot, r_rot = sess.run([x_rot, r_rot])
            # A fresh graph is built per batch above; reset so graph memory
            # does not grow without bound across batches.
            tf.reset_default_graph()

            pi_rot = r_rot['pi_tensor']
            val_rot = r_rot['value_tensor']
            for r, x, pi, val in zip(record, x_rot, pi_rot, val_rot):
                record_out = preprocessing.make_tf_example(x, pi, val)
                serialized = record_out.SerializeToString()
                writer.write(serialized)
                # Rotation should not change the serialized record length;
                # fail loudly if it does.
                assert len(r) == len(serialized), (len(r), len(serialized))