コード例 #1
0
    def test_global_annotation(self):
        """Schema-level annotations are carried into the inferred schema.

        Registers a `BucketBoundaries` proto via `schema_inference.annotate`
        inside the graph. It then checks that the schema built by
        `infer_feature_schema` has exactly one `extra_metadata` entry, and that
        this entry unpacks back to the registered boundaries.
        """
        # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds.
        # pylint: disable=g-import-not-at-top
        try:
            from tensorflow_transform import annotations_pb2
        except ImportError:
            # Report a skip rather than returning early: an early return makes
            # the test show up as passing even though nothing was checked.
            self.skipTest('annotations_pb2 is not available in this build')
        # pylint: enable=g-import-not-at-top
        with tf.Graph().as_default() as graph:
            outputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
                'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
            }

            # Annotate an arbitrary proto at the schema level. Global bucket
            # boundaries have no real meaning here; the proto is only a
            # convenient payload for the round-trip check.
            boundaries = tf.constant([[1.0]])
            message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
            # One message in the batch, with one field ('boundaries') whose
            # repeated-value count is the number of boundary elements.
            sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
            message_proto = encode_proto_op.encode_proto(
                sizes, [tf.cast(boundaries, tf.float32)], ['boundaries'],
                message_type)[0]
            # Type URLs always use '/' as separator. os.path.join would insert
            # '\\' on Windows, and Any.Unpack would then fail to match the type.
            type_url = 'type.googleapis.com/' + message_type
            schema_inference.annotate(type_url, message_proto)

            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                self.assertLen(schema._schema_proto.annotation.extra_metadata,
                               1)
                for annotation in schema._schema_proto.annotation.extra_metadata:
                    # Extract the annotated message and validate its contents.
                    message = annotations_pb2.BucketBoundaries()
                    # Unpack returns False on a type mismatch; fail explicitly
                    # instead of validating an empty message.
                    self.assertTrue(annotation.Unpack(message))
                    self.assertAllClose(list(message.boundaries), [1])
コード例 #2
0
    def test_global_annotation(self):
        """Schema-level annotations are carried into the inferred schema.

        Registers a `BucketBoundaries` proto via `schema_inference.annotate`
        inside the graph. It then checks that the schema returned by
        `infer_feature_schema` has exactly one `annotation.extra_metadata`
        entry, and that this entry unpacks back to the registered boundaries.
        """
        with tf.compat.v1.Graph().as_default() as graph:
            outputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
                'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
            }

            # Annotate an arbitrary proto at the schema level. Global bucket
            # boundaries have no real meaning here; the proto is only a
            # convenient payload for the round-trip check.
            boundaries = tf.constant([[1.0]])
            message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
            # One message in the batch, with one field ('boundaries') whose
            # repeated-value count is the number of boundary elements.
            sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
            message_proto = tf.raw_ops.EncodeProto(
                sizes=sizes,
                values=[tf.cast(boundaries, tf.float32)],
                field_names=['boundaries'],
                message_type=message_type)[0]
            # Type URLs always use '/' as separator. os.path.join would insert
            # '\\' on Windows, and Any.Unpack would then fail to match the type.
            type_url = 'type.googleapis.com/' + message_type
            schema_inference.annotate(type_url, message_proto)

            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                self.assertLen(schema.annotation.extra_metadata, 1)
                for annotation in schema.annotation.extra_metadata:
                    # Extract the annotated message and validate its contents.
                    message = annotations_pb2.BucketBoundaries()
                    # Unpack returns False on a type mismatch; fail explicitly
                    # instead of validating an empty message.
                    self.assertTrue(annotation.Unpack(message))
                    self.assertAllClose(list(message.boundaries), [1])
コード例 #3
0
 def preprocessing_fn(_):
     """Register a schema-level annotation and emit two constant features.

     The input batch is ignored. As a side effect, a `BucketBoundaries` proto
     is attached to the schema through `schema_inference.annotate`.

     Returns:
       A dict with int64 tensors 'foo' and 'bar', each of length 4.
     """
     # Annotate an arbitrary proto at the schema level. Global bucket
     # boundaries have no real meaning here; the proto is only a convenient
     # payload for checking that annotations reach the schema.
     boundaries = tf.constant([[1.0]])
     message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
     # One message in the batch, with one field ('boundaries') whose
     # repeated-value count is the number of boundary elements.
     sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
     message_proto = tf.raw_ops.EncodeProto(
         sizes=sizes,
         values=[tf.cast(boundaries, tf.float32)],
         field_names=['boundaries'],
         message_type=message_type)[0]
     # Type URLs always use '/' as separator. os.path.join would insert '\\'
     # on Windows, and Any.Unpack on the consumer side would then fail.
     type_url = 'type.googleapis.com/' + message_type
     schema_inference.annotate(type_url, message_proto)
     return {
         'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
         'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
     }