def test_global_annotation(self):
  # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds.
  # pylint: disable=g-import-not-at-top
  try:
    from tensorflow_transform import annotations_pb2
  except ImportError:
    return
  # pylint: enable=g-import-not-at-top
  with tf.Graph().as_default() as graph:
    outputs = {
        'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
        'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
    }
    # Annotate an arbitrary proto at the schema level (not sure what global
    # schema boundaries would mean, but hey I'm just a test).
    boundaries = tf.constant([[1.0]])
    message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
    sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
    message_proto = encode_proto_op.encode_proto(
        sizes, [tf.cast(boundaries, tf.float32)], ['boundaries'],
        message_type)[0]
    type_url = os.path.join('type.googleapis.com', message_type)
    schema_inference.annotate(type_url, message_proto)
    with tf.compat.v1.Session(graph=graph) as session:
      schema = schema_inference.infer_feature_schema(outputs, graph, session)
      self.assertLen(schema._schema_proto.annotation.extra_metadata, 1)
      for annotation in schema._schema_proto.annotation.extra_metadata:
        # Extract the annotated message and validate its contents.
        message = annotations_pb2.BucketBoundaries()
        annotation.Unpack(message)
        self.assertAllClose(list(message.boundaries), [1])

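# A sanity sketch, not part of the original test: since `boundaries` above is
# [[1.0]] with sizes [[1]], the encode_proto call should produce the same
# bytes as serializing a BucketBoundaries message directly. The helper name
# here is hypothetical.
def _expected_bucket_boundaries_bytes():
  from tensorflow_transform import annotations_pb2
  expected = annotations_pb2.BucketBoundaries()
  expected.boundaries.append(1.0)
  return expected.SerializeToString()
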
def testBadInputs(self):
  # Invalid field name.
  with self.test_session():
    with self.assertRaisesOpError('Unknown field: non_existent_field'):
      encode_proto_op.encode_proto(
          sizes=[[1]],
          values=[np.array([[0.0]], dtype=np.int32)],
          message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
          field_names=['non_existent_field']).eval()

  # Incorrect types.
  with self.test_session():
    with self.assertRaisesOpError('Incompatible type for field double_value.'):
      encode_proto_op.encode_proto(
          sizes=[[1]],
          values=[np.array([[0.0]], dtype=np.int32)],
          message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
          field_names=['double_value']).eval()

  # Incorrect shapes of sizes.
  with self.test_session():
    with self.assertRaisesOpError(
        r'sizes should be batch_size \+ \[len\(field_names\)\]'):
      sizes = array_ops.placeholder(dtypes.int32)
      values = array_ops.placeholder(dtypes.float64)
      encode_proto_op.encode_proto(
          sizes=sizes,
          values=[values],
          message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
          field_names=['double_value']).eval(feed_dict={
              sizes: [[[0, 0]]],
              values: [[0.0]]
          })

  # Inconsistent shapes of values.
  with self.test_session():
    with self.assertRaisesOpError('Values must match up to the last dimension'):
      sizes = array_ops.placeholder(dtypes.int32)
      values1 = array_ops.placeholder(dtypes.float64)
      values2 = array_ops.placeholder(dtypes.int32)
      (encode_proto_op.encode_proto(
          sizes=[[1, 1]],
          values=[values1, values2],
          message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
          field_names=['double_value', 'int32_value']).eval(feed_dict={
              values1: [[0.0]],
              values2: [[0], [0]]
          }))

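# For contrast with the failure cases above, a minimal well-formed call
# (an illustrative addition, not from the original suite): with batch_size 1
# and one field, `sizes` has shape [1, 1] and each value tensor has shape
# [batch_size, max_repeat_count], so the op yields one serialized proto.
def testMinimalWellFormedEncode(self):
  with self.test_session():
    buf = encode_proto_op.encode_proto(
        sizes=[[1]],
        values=[np.array([[0.0]], dtype=np.float64)],
        message_type='tensorflow.contrib.proto.RepeatedPrimitiveValue',
        field_names=['double_value']).eval()
    self.assertEqual((1,), buf.shape)
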
def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
  with self.test_session() as sess:
    request_tensors = encode_proto_op.encode_proto(
        message_type='tensorflow.contrib.rpc.TestCase',
        field_names=['shape'],
        sizes=[[3]] * 20,
        values=[
            [[i, i + 1, i + 2] for i in range(20)],
        ])
    response_tensor_strings = self.rpc(
        method=self.get_method_name('IncrementTestShapes'),
        address=self._address,
        request=request_tensors)
    _, (response_shape,) = decode_proto_op.decode_proto(
        bytes=response_tensor_strings,
        message_type='tensorflow.contrib.rpc.TestCase',
        field_names=['shape'],
        output_types=[dtypes.int32])
    response_shape_values = sess.run(response_shape)
  self.assertAllEqual([[i + 1, i + 2, i + 3] for i in range(20)],
                      response_shape_values)

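# Spelling out the shape contract of the encode_proto call above: the batch
# is 20 TestCase messages, each with a repeated 'shape' field of 3 entries,
# hence sizes [[3]] * 20 and a single values tensor of shape [20, 3]. This
# decode-side check mirrors the request (an illustrative helper, not part of
# the original test).
def _check_request_roundtrip(self, request_tensors):
  with self.test_session() as sess:
    _, (shape,) = decode_proto_op.decode_proto(
        bytes=request_tensors,
        message_type='tensorflow.contrib.rpc.TestCase',
        field_names=['shape'],
        output_types=[dtypes.int32])
    self.assertAllEqual([[i, i + 1, i + 2] for i in range(20)],
                        sess.run(shape))
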
def _testRoundtrip(self, in_bufs, message_type, fields):
  field_names = [f.name for f in fields]
  out_types = [f.dtype for f in fields]

  with self.test_session() as sess:
    sizes, field_tensors = decode_proto_op.decode_proto(
        in_bufs,
        message_type=message_type,
        field_names=field_names,
        output_types=out_types)

    out_tensors = encode_proto_op.encode_proto(
        sizes,
        field_tensors,
        message_type=message_type,
        field_names=field_names)

    out_bufs, = sess.run([out_tensors])

    # Check that the re-encoded tensor has the same shape.
    self.assertEqual(in_bufs.shape, out_bufs.shape)

    # Compare the input and output.
    for in_buf, out_buf in zip(in_bufs.flat, out_bufs.flat):
      in_obj = test_example_pb2.RepeatedPrimitiveValue()
      in_obj.ParseFromString(in_buf)
      out_obj = test_example_pb2.RepeatedPrimitiveValue()
      out_obj.ParseFromString(out_buf)

      # Check that the deserialized objects are identical.
      self.assertEqual(in_obj, out_obj)

      # Check that the input and output serialized messages are identical.
      # If we fail here, there is a difference in the serialized
      # representation but the new serialization still parses. This could
      # be harmless (a change in map ordering?) or it could be bad (e.g.
      # loss of packing in the encoding).
      self.assertEqual(in_buf, out_buf)

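# A sketch of how _testRoundtrip might be driven (illustrative; the original
# suite builds its field descriptors from test parameters). FieldInfo is a
# hypothetical stand-in exposing the `name` and `dtype` attributes that
# _testRoundtrip reads; np, dtypes, and test_example_pb2 are the module-level
# imports already used above.
def testRoundtripExample(self):
  import collections  # Local import to keep the sketch self-contained.
  field_info = collections.namedtuple('FieldInfo', ['name', 'dtype'])
  msg = test_example_pb2.RepeatedPrimitiveValue()
  msg.double_value.append(2.0)
  in_bufs = np.array([msg.SerializeToString()], dtype=object)
  self._testRoundtrip(
      in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue',
      [field_info(name='double_value', dtype=dtypes.float64)])
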
def _encode_proto(values_dict, message_type):
  """A wrapper around encode_proto_op.encode_proto."""
  field_names = []
  sizes = []
  values = []
  for field_name, value in sorted(values_dict.items(), key=lambda x: x[0]):
    if isinstance(value, tf.SparseTensor):
      size = tf.sparse_reduce_sum(
          tf.SparseTensor(value.indices,
                          tf.ones_like(value.values, dtype=tf.int32),
                          value.dense_shape),
          axis=1)
      value = tf.sparse_tensor_to_dense(value,
                                        _DEFAULT_VALUE_BY_DTYPE[value.dtype])
    else:
      value = tf.reshape(value, [tf.shape(value)[0], -1])
      size = tf.fill((tf.shape(value)[0],), tf.shape(value)[1])
    field_names.append(field_name)
    values.append(value)
    sizes.append(size)

  sizes = tf.stack(sizes, axis=1)
  return encode_proto_op.encode_proto(sizes, values, field_names, message_type)

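# A minimal usage sketch for _encode_proto above. The message type
# 'toy.Example' and its fields 'weights' (repeated float) and 'ids'
# (repeated int64) are hypothetical stand-ins; substitute a real proto
# registered in the default descriptor pool, and note the sparse branch
# assumes _DEFAULT_VALUE_BY_DTYPE has an entry for tf.int64.
def _example_encode_proto_usage():
  values_dict = {
      # Dense input: every row contributes tf.shape(value)[1] elements.
      'weights': tf.constant([[1.0, 2.0], [3.0, 4.0]]),
      # Sparse input: per-row sizes come from summing indicator ones.
      'ids': tf.SparseTensor(
          indices=[[0, 0], [1, 0], [1, 1]],
          values=tf.constant([7, 8, 9], dtype=tf.int64),
          dense_shape=[2, 2]),
  }
  # Returns a string tensor of shape [2]: one serialized proto per row.
  return _encode_proto(values_dict, 'toy.Example')
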
def loop_body(a, unused_b):
  """Loop body for the tf.while_loop op.

  Args:
    a: a constant 0
    unused_b: a string placeholder (to satisfy the requirement that a
      while_loop's condition and body accept the same args as the loop
      returns).

  Returns:
    A TensorFlow subgraph.
  """

  # Request features.
  raw_response = tf.contrib.rpc.rpc(
      address=config.address,
      method=config.get_features_method,
      request="",
      protocol="grpc",
      fail_fast=True,
      timeout_in_ms=0,
      name="get_features")

  # Decode features from a proto to a flat tensor.
  _, (batch_id, flat_features) = decode_proto_op.decode_proto(
      bytes=raw_response,
      message_type='minigo.GetFeaturesResponse',
      field_names=['batch_id', 'features'],
      output_types=[dtypes.int32, dtypes.float32],
      descriptor_source=config.descriptor_path,
      name="decode_raw_features")

  # Reshape flat features.
  features = tf.reshape(
      flat_features, [-1, go.N, go.N, features_lib.NEW_FEATURES_PLANES],
      name="unflatten_features")

  # Run inference.
  policy_output, value_output, _ = dual_net.model_inference_fn(
      features, False)

  # Flatten model outputs.
  flat_policy = tf.reshape(policy_output, [-1], name="flatten_policy")
  flat_value = value_output  # value_output is already flat.

  # Encode outputs from flat tensors to a proto.
  request_tensors = encode_proto_op.encode_proto(
      message_type='minigo.PutOutputsRequest',
      field_names=['batch_id', 'policy', 'value'],
      sizes=[[1, policy_output_size, value_output_size]],
      values=[[batch_id], [flat_policy], [flat_value]],
      descriptor_source=config.descriptor_path,
      name="encode_outputs")

  # Send outputs.
  response = tf.contrib.rpc.rpc(
      address=config.address,
      method=config.put_outputs_method,
      request=request_tensors,
      protocol="grpc",
      fail_fast=True,
      timeout_in_ms=0,
      name="put_outputs")

  return a, response[0]

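# A sketch of how loop_body might be wired into tf.while_loop, assuming the
# enclosing module repeats the get-features / infer / put-outputs cycle until
# the session is cancelled externally. The function name and loop wiring here
# are illustrative, not the actual minigo entry point.
def make_inference_loop():
  def loop_cond(unused_a, unused_b):
    # Always true: the loop runs until the session is shut down.
    return tf.constant(True)

  return tf.while_loop(
      loop_cond, loop_body,
      loop_vars=[tf.constant(0), tf.constant("")],
      parallel_iterations=1,
      back_prop=False)
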