def test_global_annotation(self):
  """Checks that a schema-level annotation shows up in the inferred schema.

  Builds a graph with two int64 output tensors, encodes a `BucketBoundaries`
  proto holding the single boundary 1.0, registers it through
  `schema_inference.annotate`, and then asserts that the inferred schema
  carries exactly one extra-metadata annotation that unpacks to boundaries
  close to [1].

  The test is reported as skipped (rather than silently passing) when
  `annotations_pb2` cannot be imported.
  """
  # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds.
  # pylint: disable=g-import-not-at-top
  try:
    from tensorflow_transform import annotations_pb2
  except ImportError:
    # skipTest raises, so nothing below runs; the runner records a skip
    # instead of a misleading pass.
    self.skipTest('tensorflow_transform.annotations_pb2 is not available.')
  # pylint: enable=g-import-not-at-top

  with tf.Graph().as_default() as graph:
    outputs = {
        'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
        'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
    }

    # Annotate an arbitrary proto at the schema level (not sure what global
    # schema boundaries would mean, but hey I'm just a test).
    boundaries = tf.constant([[1.0]])
    message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
    sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
    message_proto = encode_proto_op.encode_proto(
        sizes, [tf.cast(boundaries, tf.float32)], ['boundaries'],
        message_type)[0]
    # A type URL is a URL, not a filesystem path: always separate with '/',
    # independent of the host OS path separator.
    type_url = 'type.googleapis.com/' + message_type
    schema_inference.annotate(type_url, message_proto)

    with tf.compat.v1.Session(graph=graph) as session:
      schema = schema_inference.infer_feature_schema(outputs, graph, session)
      extra_metadata = schema._schema_proto.annotation.extra_metadata
      self.assertLen(extra_metadata, 1)
      for annotation in extra_metadata:
        # Extract the annotated message and validate its contents.
        message = annotations_pb2.BucketBoundaries()
        annotation.Unpack(message)
        self.assertAllClose(list(message.boundaries), [1])
def test_global_annotation(self):
  """Checks that a schema-level annotation shows up in the inferred schema.

  Builds a graph with two int64 output tensors, encodes a `BucketBoundaries`
  proto holding the single boundary 1.0 with `tf.raw_ops.EncodeProto`,
  registers it through `schema_inference.annotate`, and then asserts that the
  inferred schema carries exactly one extra-metadata annotation that unpacks
  to boundaries close to [1].
  """
  with tf.compat.v1.Graph().as_default() as graph:
    outputs = {
        'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
        'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
    }

    # Annotate an arbitrary proto at the schema level (not sure what global
    # schema boundaries would mean, but hey I'm just a test).
    boundaries = tf.constant([[1.0]])
    message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
    sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
    message_proto = tf.raw_ops.EncodeProto(
        sizes=sizes,
        values=[tf.cast(boundaries, tf.float32)],
        field_names=['boundaries'],
        message_type=message_type)[0]
    # A type URL is a URL, not a filesystem path: always separate with '/',
    # independent of the host OS path separator.
    type_url = 'type.googleapis.com/' + message_type
    schema_inference.annotate(type_url, message_proto)

    with tf.compat.v1.Session(graph=graph) as session:
      schema = schema_inference.infer_feature_schema(outputs, graph, session)
      extra_metadata = schema.annotation.extra_metadata
      self.assertLen(extra_metadata, 1)
      for annotation in extra_metadata:
        # Extract the annotated message and validate its contents.
        message = annotations_pb2.BucketBoundaries()
        annotation.Unpack(message)
        self.assertAllClose(list(message.boundaries), [1])
def preprocessing_fn(_):
  """Registers a schema-level annotation and returns two constant features.

  Encodes a `BucketBoundaries` proto holding the single boundary 1.0 with
  `tf.raw_ops.EncodeProto` and registers it through
  `schema_inference.annotate`.

  Args:
    _: Ignored; the returned tensors do not depend on the inputs.

  Returns:
    A dict with int64 tensors under the keys 'foo' and 'bar'.
  """
  # Annotate an arbitrary proto at the schema level (not sure what global
  # schema boundaries would mean, but hey I'm just a test).
  boundaries = tf.constant([[1.0]])
  message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
  sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
  message_proto = tf.raw_ops.EncodeProto(
      sizes=sizes,
      values=[tf.cast(boundaries, tf.float32)],
      field_names=['boundaries'],
      message_type=message_type)[0]
  # A type URL is a URL, not a filesystem path: always separate with '/',
  # independent of the host OS path separator.
  type_url = 'type.googleapis.com/' + message_type
  schema_inference.annotate(type_url, message_proto)
  return {
      'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
      'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
  }