Example #1
def convert(paths):
    position, in_path, out_path = paths
    assert tf.gfile.Exists(in_path)
    assert tf.gfile.Exists(os.path.dirname(out_path))

    in_size = get_size(in_path)
    if tf.gfile.Exists(out_path):
        # Make sure out_path is about the size of in_path
        size = get_size(out_path)
        error = (size - in_size) / (in_size + 1)
        # 5% smaller to 20% larger
        if -0.05 < error < 0.20:
            return out_path + " already existed"
        return "ERROR on file size ({:.1f}% diff) {}".format(
            100 * error, out_path)

    #assert abs(in_size/2**20 - 670) <= 80, in_size
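    # Upper bound on the batch count; used only as the tqdm progress total.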
    num_batches = dual_net.EXAMPLES_PER_GENERATION // FLAGS.batch_size + 1

    with tf.python_io.TFRecordWriter(out_path, OPTS) as writer:
        record_iter = tqdm(batched_reader(in_path),
                           desc=os.path.basename(in_path),
                           position=position,
                           total=num_batches)
        for record in record_iter:
            xs, rs = preprocessing.batch_parse_tf_example(len(record), record)
            # Undo cast in batch_parse_tf_example.
            xs = tf.cast(xs, tf.uint8)

            # Map the random rotation over the batch.
            x_rot, r_rot = preprocessing._random_rotation(xs, rs)

            with tf.Session() as sess:
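                # Evaluate the rotation ops eagerly; the default graph is
                # reset after each batch so it doesn't grow without bound.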
                x_rot, r_rot = sess.run([x_rot, r_rot])
            tf.reset_default_graph()

            pi_rot = r_rot['pi_tensor']
            val_rot = r_rot['value_tensor']
            for r, x, pi, val in zip(record, x_rot, pi_rot, val_rot):
                record_out = preprocessing.make_tf_example(x, pi, val)
                serialized = record_out.SerializeToString()
                writer.write(serialized)
                # Sanity check: the rotated example should serialize to the
                # same length as the original record.
                assert len(r) == len(serialized), (len(r), len(serialized))
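
The example above leans on a few module-level helpers that the listing doesn't show. Below is a minimal sketch of what they could look like, assuming ZLIB-compressed TFRecords; the names OPTS, get_size, and batched_reader come from the example, but the bodies are illustrative guesses, not the project's actual code.

import itertools

import tensorflow as tf

# Assumed: ZLIB compression, the usual choice for .tfrecord.zz files.
OPTS = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)


def get_size(path):
    # File size in bytes; tf.gfile.Stat handles local and GCS paths alike.
    return tf.gfile.Stat(path).length


def batched_reader(path, batch_size=256):
    # Yield lists of serialized tf.Example strings, batch_size at a time.
    reader = tf.python_io.tf_record_iterator(path, OPTS)
    while True:
        batch = list(itertools.islice(reader, batch_size))
        if not batch:
            break
        yield batch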
Example #2
def main(unused_argv):
    in_path = FLAGS.in_path
    out_path = FLAGS.out_path

    assert tf.gfile.Exists(in_path)
    # TODO(amj): Why does ensure_dir_exists skip gs paths?
    #tf.gfile.MakeDirs(os.path.dirname(out_path))
    #assert tf.gfile.Exists(os.path.dirname(out_path))

    policy_err = []
    value_err = []

    print()
    with tf.python_io.TFRecordWriter(out_path, OPTS) as writer:
        # Build the input pipeline: no shuffling, no random rotation, and
        # filter_amount=1.0 so every example is kept.
        ds_iter = preprocessing.get_input_tensors(FLAGS.batch_size, [in_path],
                                                  shuffle_examples=False,
                                                  random_rotation=False,
                                                  filter_amount=1.0)

        with tf.Session() as sess:
            features, labels = ds_iter
            p_in = labels['pi_tensor']
            v_in = labels['value_tensor']

            p_out, v_out, logits = dual_net.model_inference_fn(
                features, False, FLAGS.flag_values_dict())
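            # Restore the model weights into the freshly built inference graph.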
            tf.train.Saver().restore(sess, FLAGS.model)

            # TODO(seth): Add policy entropy.

            # Per-example losses: softmax cross-entropy between the model's
            # logits and the input policy labels, squared error on the value.
            p_err = tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=logits, labels=tf.stop_gradient(p_in))
            v_err = tf.square(v_out - v_in)

            for _ in tqdm(itertools.count(1)):
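                # itertools.count makes an unbounded tqdm counter; the loop
                # exits when the dataset raises OutOfRangeError.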
                try:
                    # Undo cast in batch_parse_tf_example.
                    x_in = tf.cast(features, tf.int8)

                    x, pi, val, pi_err, val_err = sess.run(
                        [x_in, p_out, v_out, p_err, v_err])

                    for x_i, pi_i, val_i in zip(x, pi, val):
                        # NOTE: The teacher's policy has much higher entropy
                        # than the selfplay policy labels, which are mostly 0;
                        # expect the resulting file to be 3-5x larger.

                        r = preprocessing.make_tf_example(x_i, pi_i, val_i)
                        serialized = r.SerializeToString()
                        writer.write(serialized)

                    policy_err.extend(pi_err)
                    value_err.extend(val_err)

                except tf.errors.OutOfRangeError:
                    print()
                    print("Breaking OutOfRangeError")
                    break

    print("Counts", len(policy_err), len(value_err))
    test()

    plt.subplot(121)
    plt.hist(policy_err, 40)
    plt.title('Policy Error histogram')

    plt.subplot(122)
    plt.hist(value_err, 40)
    plt.title('Value Error histogram')

    plt.show()
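
Example #2 pulls its configuration from absl flags defined elsewhere in the module. Below is a minimal sketch of the surrounding scaffolding, assuming the usual absl.app/absl.flags setup: the flag names match the FLAGS attributes used above, the defaults and help strings are illustrative, and batch_size is assumed to come from dual_net's own flag definitions.

import itertools

from absl import app, flags
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import tqdm

import dual_net
import preprocessing

flags.DEFINE_string('in_path', None, 'TFRecord file of input examples.')
flags.DEFINE_string('out_path', None, 'TFRecord file to write re-labeled examples to.')
flags.DEFINE_string('model', None, 'Checkpoint path of the model to run inference with.')

FLAGS = flags.FLAGS

OPTS = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)


if __name__ == '__main__':
    app.run(main)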