Beispiel #1
0
def batch_parse_tf_example(batch_size, layout, example_batch, use_bf16=True):
    """Parse a batch of serialized tf.Examples into model input tensors.

    Args:
        batch_size: batch size. Used as the leading dimension of every
            output tensor, so it must match the number of examples in
            `example_batch`.
        layout: 'nchw' or 'nhwc' (case-insensitive). Selects the shape the
            raw feature bytes are reshaped into.
        example_batch: a batch of serialized tf.Example protos.
        use_bf16: if True, cast the features to tf.bfloat16; otherwise cast
            them to tf.float32.
    Returns:
        A tuple (feature_tensor, dict of output tensors). The dict has keys
        'pi_tensor' (float32, shape [batch_size, go.N * go.N + 1]) and
        'value_tensor' (float32, shape [batch_size]).
    Raises:
        ValueError: if `layout` is neither 'nchw' nor 'nhwc'.
    """
    # Reject unknown layouts up front instead of silently treating them as
    # NCHW, which would give a correctly-sized tensor with scrambled contents.
    normalized_layout = str(layout).lower()
    if normalized_layout not in ('nchw', 'nhwc'):
        raise ValueError(
            "layout must be 'nchw' or 'nhwc', got %r" % (layout,))

    planes = dual_net.get_features_planes()

    features = {
        'x': tf.compat.v1.FixedLenFeature([], tf.string),
        'pi': tf.compat.v1.FixedLenFeature([], tf.string),
        'outcome': tf.compat.v1.FixedLenFeature([], tf.float32),
    }
    parsed = tf.compat.v1.parse_example(example_batch, features)

    # 'x' is a raw byte string per example: decode to uint8, then cast.
    # uint8 values (0..255) are exactly representable in bfloat16.
    x = tf.compat.v1.decode_raw(parsed['x'], tf.uint8)
    x = tf.cast(x, tf.bfloat16 if use_bf16 else tf.float32)

    # NOTE(review): this is a reshape, not a transpose, so it assumes the
    # serialized bytes are already stored in `layout` order -- confirm
    # against the code that writes these examples.
    if normalized_layout == 'nhwc':
        shape = [batch_size, go.N, go.N, planes]
    else:
        shape = [batch_size, planes, go.N, go.N]
    x = tf.reshape(x, shape)

    # 'pi' is stored as raw float32 bytes: go.N * go.N entries plus one
    # extra (presumably the pass move -- TODO confirm).
    pi = tf.compat.v1.decode_raw(parsed['pi'], tf.float32)
    pi = tf.reshape(pi, [batch_size, go.N * go.N + 1])

    outcome = parsed['outcome']
    # Pin the batch dimension statically; it may be unknown after parsing.
    outcome.set_shape([batch_size])
    return x, {'pi_tensor': pi, 'value_tensor': outcome}
 def create_random_data(self, num_examples):
     """Generate synthetic (feature, pi, value) training tuples.

     Args:
         num_examples: number of tuples to produce.
     Returns:
         A list of `num_examples` tuples ``(feature, pi, value)`` where
         ``feature`` is a uint8 array of shape [go.N, go.N, planes] with
         values in 0..255, ``pi`` is a float32 array of length
         go.N * go.N + 1 with values in [0, 1), and ``value`` is a float
         in [0, 1).
     """
     num_planes = dual_net.get_features_planes()
     board_shape = [go.N, go.N, num_planes]
     pi_length = go.N * go.N + 1

     def _random_example():
         # Scaling [0, 1) by 256 and truncating to uint8 gives bytes 0..255.
         board = (256 * np.random.random(board_shape)).astype(np.uint8)
         probs = np.random.random([pi_length]).astype(np.float32)
         # Draw order is board -> probs -> value, so a fixed np.random seed
         # reproduces the same data.
         return board, probs, np.random.random()

     return [_random_example() for _ in range(num_examples)]
Beispiel #3
0
def batch_parse_tf_example(batch_size, example_batch):
    """Parse a batch of serialized tf.Examples into NHWC model input tensors.

    Args:
        batch_size: batch size. Used as the leading dimension of every
            output tensor, so it must match the number of examples in
            `example_batch`.
        example_batch: a batch of serialized tf.Example protos.
    Returns:
        A tuple (feature_tensor, dict of output tensors). The feature tensor
        is float32 with shape [batch_size, go.N, go.N, planes]; the dict has
        keys 'pi_tensor' (float32, shape [batch_size, go.N * go.N + 1]) and
        'value_tensor' (float32, shape [batch_size]).
    """
    planes = dual_net.get_features_planes()

    # Use the tf.compat.v1 endpoints: the bare tf.FixedLenFeature,
    # tf.parse_example and tf.decode_raw names do not exist in TF 2.x.
    features = {
        'x': tf.compat.v1.FixedLenFeature([], tf.string),
        'pi': tf.compat.v1.FixedLenFeature([], tf.string),
        'outcome': tf.compat.v1.FixedLenFeature([], tf.float32),
    }
    parsed = tf.compat.v1.parse_example(example_batch, features)

    x = tf.compat.v1.decode_raw(parsed['x'], tf.uint8)
    x = tf.cast(x, tf.float32)
    # Reshape only (no transpose): assumes the bytes are serialized in
    # H, W, C order -- TODO confirm against the code that writes them.
    x = tf.reshape(x, [batch_size, go.N, go.N, planes])

    pi = tf.compat.v1.decode_raw(parsed['pi'], tf.float32)
    pi = tf.reshape(pi, [batch_size, go.N * go.N + 1])

    outcome = parsed['outcome']
    # Pin the batch dimension statically; it may be unknown after parsing.
    outcome.set_shape([batch_size])
    return x, {'pi_tensor': pi, 'value_tensor': outcome}