def batch_parse_tf_example(batch_size, layout, example_batch, use_bf16=True):
    """Decode a batch of serialized tf.Example protos into model inputs.

    Args:
        batch_size: number of examples in `example_batch`; used as the static
            leading dimension of every returned tensor.
        layout: 'nhwc' gives a [batch_size, go.N, go.N, planes] feature
            tensor; any other value (intended: 'nchw') gives
            [batch_size, planes, go.N, go.N].
        example_batch: a batch of tf.Example protos.
        use_bf16: when True the decoded features are cast to bfloat16,
            otherwise to float32.

    Returns:
        A tuple (feature_tensor, dict of output tensors). The dict maps
        'pi_tensor' to a [batch_size, go.N * go.N + 1] float32 tensor and
        'value_tensor' to the parsed 'outcome' float32 tensor of shape
        [batch_size].
    """
    num_planes = dual_net.get_features_planes()
    board = go.N
    schema = {
        'x': tf.compat.v1.FixedLenFeature([], tf.string),
        'pi': tf.compat.v1.FixedLenFeature([], tf.string),
        'outcome': tf.compat.v1.FixedLenFeature([], tf.float32),
    }
    parsed = tf.compat.v1.parse_example(example_batch, schema)

    # Features are stored as raw uint8 bytes; widen them to the compute dtype.
    feature_dtype = tf.bfloat16 if use_bf16 else tf.float32
    if layout == 'nhwc':
        feature_shape = [batch_size, board, board, num_planes]
    else:
        feature_shape = [batch_size, num_planes, board, board]
    raw_features = tf.compat.v1.decode_raw(parsed['x'], tf.uint8)
    features = tf.reshape(tf.cast(raw_features, feature_dtype), feature_shape)

    # The policy target is a raw float32 buffer: one entry per board point
    # plus one extra entry.
    raw_policy = tf.compat.v1.decode_raw(parsed['pi'], tf.float32)
    policy = tf.reshape(raw_policy, [batch_size, board * board + 1])

    outcome = parsed['outcome']
    # Pin the static batch dimension on the value target.
    outcome.set_shape([batch_size])

    return features, {'pi_tensor': policy, 'value_tensor': outcome}
def create_random_data(self, num_examples):
    """Generate `num_examples` random (feature, pi, value) tuples.

    Each tuple contains:
      * feature: uint8 array of shape [go.N, go.N, planes], built by scaling
        np.random.random by 256 and truncating (values 0..255);
      * pi: float32 array of length go.N * go.N + 1 of np.random.random draws;
      * value: a single Python float from np.random.random().

    Returns:
        A list of `num_examples` such tuples.
    """
    num_planes = dual_net.get_features_planes()
    # Draw order per example (feature, then pi, then value) matters for
    # reproducibility under a fixed NumPy seed; tuple elements evaluate
    # left to right.
    return [
        (
            (np.random.random([go.N, go.N, num_planes]) * 256).astype(np.uint8),
            np.random.random([go.N * go.N + 1]).astype(np.float32),
            np.random.random(),
        )
        for _ in range(num_examples)
    ]
def batch_parse_tf_example(batch_size, example_batch):
    """Decode a batch of serialized tf.Example protos into float32 NHWC inputs.

    Args:
        batch_size: number of examples in `example_batch`; used as the static
            leading dimension of every returned tensor.
        example_batch: a batch of tf.Example protos.

    Returns:
        A tuple (feature_tensor, dict of output tensors):
          * feature_tensor: float32 tensor of shape
            [batch_size, go.N, go.N, planes];
          * dict with 'pi_tensor' -> float32 [batch_size, go.N * go.N + 1]
            and 'value_tensor' -> the parsed 'outcome' float32 [batch_size].
    """
    # tf.FixedLenFeature / tf.parse_example / tf.decode_raw are top-level
    # aliases that no longer exist in TF 2.x; the tf.compat.v1 spellings are
    # available in both TF >= 1.13 and TF 2.x and behave the same way.
    planes = dual_net.get_features_planes()
    features = {
        'x': tf.compat.v1.FixedLenFeature([], tf.string),
        'pi': tf.compat.v1.FixedLenFeature([], tf.string),
        'outcome': tf.compat.v1.FixedLenFeature([], tf.float32),
    }
    parsed = tf.compat.v1.parse_example(example_batch, features)

    # Features are stored as raw uint8 bytes; widen to float32 and restore
    # the NHWC board layout.
    x = tf.compat.v1.decode_raw(parsed['x'], tf.uint8)
    x = tf.cast(x, tf.float32)
    x = tf.reshape(x, [batch_size, go.N, go.N, planes])

    # Policy target: raw float32 buffer, one entry per board point plus one.
    pi = tf.compat.v1.decode_raw(parsed['pi'], tf.float32)
    pi = tf.reshape(pi, [batch_size, go.N * go.N + 1])

    outcome = parsed['outcome']
    # Pin the static batch dimension on the value target.
    outcome.set_shape([batch_size])
    return x, {'pi_tensor': pi, 'value_tensor': outcome}