def test_global_annotation(self):
        """Checks that a schema-level annotation reaches the inferred schema.

        An encoded `BucketBoundaries` proto is registered through
        `schema_inference.annotate`; the schema returned by
        `schema_inference.infer_feature_schema` must then carry exactly one
        `extra_metadata` entry that unpacks to the same boundaries.

        The test is reported as skipped when `annotations_pb2` cannot be
        imported.
        """
        # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds.
        # pylint: disable=g-import-not-at-top
        try:
            from tensorflow_transform import annotations_pb2
        except ImportError:
            # Report a skip: a bare `return` would be counted as a pass even
            # though nothing was verified.
            self.skipTest(
                'tensorflow_transform.annotations_pb2 is not available')
        # pylint: enable=g-import-not-at-top
        with tf.Graph().as_default() as graph:
            outputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
                'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
            }

            # Annotate an arbitrary proto at the schema level (not sure what global
            # schema boundaries would mean, but hey I'm just a test).
            boundaries = tf.constant([[1.0]])
            message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
            sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
            message_proto = encode_proto_op.encode_proto(
                sizes, [tf.cast(boundaries, tf.float32)], ['boundaries'],
                message_type)[0]
            # A type URL is always '/'-separated; os.path.join would insert the
            # platform path separator instead (a backslash on Windows).
            type_url = 'type.googleapis.com/' + message_type
            schema_inference.annotate(type_url, message_proto)

            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                self.assertLen(schema._schema_proto.annotation.extra_metadata,
                               1)
                for annotation in schema._schema_proto.annotation.extra_metadata:
                    # Extract the annotated message and validate its contents
                    message = annotations_pb2.BucketBoundaries()
                    annotation.Unpack(message)
                    self.assertAllClose(list(message.boundaries), [1])
    def test_global_annotation(self, use_compat_v1):
        """Checks that a schema-level annotation reaches the inferred schema.

        The preprocessing function registers an encoded `BucketBoundaries`
        proto through `schema_inference.annotate`; the schema obtained from
        `self._get_schema` must then carry exactly one `extra_metadata` entry
        that unpacks to the same boundaries.

        Args:
          use_compat_v1: Forwarded to `self._get_schema`. When falsy,
            `test_case.skip_if_not_tf2` is called first.
        """
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(_):
            # Annotate an arbitrary proto at the schema level (not sure what global
            # schema boundaries would mean, but hey I'm just a test).
            boundaries = tf.constant([[1.0]])
            message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
            sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
            message_proto = tf.raw_ops.EncodeProto(
                sizes=sizes,
                values=[tf.cast(boundaries, tf.float32)],
                field_names=['boundaries'],
                message_type=message_type)[0]
            # A type URL is always '/'-separated; os.path.join would insert the
            # platform path separator instead (a backslash on Windows).
            type_url = 'type.googleapis.com/' + message_type
            schema_inference.annotate(type_url, message_proto)
            return {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
                'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
            }

        schema = self._get_schema(preprocessing_fn,
                                  use_compat_v1,
                                  create_session=True)
        self.assertLen(schema.annotation.extra_metadata, 1)
        for annotation in schema.annotation.extra_metadata:
            # Extract the annotated message and validate its contents
            message = annotations_pb2.BucketBoundaries()
            annotation.Unpack(message)
            self.assertAllClose(list(message.boundaries), [1])
    def test_global_annotation(self):
        """Checks that a schema-level annotation reaches the inferred schema.

        An encoded `BucketBoundaries` proto is registered through
        `schema_inference.annotate` inside a fresh graph; the schema returned
        by `schema_inference.infer_feature_schema` must then carry exactly one
        `extra_metadata` entry that unpacks to the same boundaries.
        """
        with tf.compat.v1.Graph().as_default() as graph:
            outputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
                'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
            }

            # Annotate an arbitrary proto at the schema level (not sure what global
            # schema boundaries would mean, but hey I'm just a test).
            boundaries = tf.constant([[1.0]])
            message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
            sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
            message_proto = tf.raw_ops.EncodeProto(
                sizes=sizes,
                values=[tf.cast(boundaries, tf.float32)],
                field_names=['boundaries'],
                message_type=message_type)[0]
            # A type URL is always '/'-separated; os.path.join would insert the
            # platform path separator instead (a backslash on Windows).
            type_url = 'type.googleapis.com/' + message_type
            schema_inference.annotate(type_url, message_proto)

            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                self.assertLen(schema.annotation.extra_metadata, 1)
                for annotation in schema.annotation.extra_metadata:
                    # Extract the annotated message and validate its contents
                    message = annotations_pb2.BucketBoundaries()
                    annotation.Unpack(message)
                    self.assertAllClose(list(message.boundaries), [1])
Пример #4
0
    def test_bucketization_annotation(self):
        """Checks that `mappers.apply_buckets` annotates each output feature.

        Two features are bucketized with different boundaries; every feature
        in the inferred schema must carry exactly one `extra_metadata` entry
        that unpacks to the `BucketBoundaries` used for that feature.

        The test is reported as skipped when `annotations_pb2` cannot be
        imported.
        """
        # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds.
        # pylint: disable=g-import-not-at-top
        try:
            from tensorflow_transform import annotations_pb2
        except ImportError:
            # Report a skip: a bare `return` would be counted as a pass even
            # though nothing was verified.
            self.skipTest(
                'tensorflow_transform.annotations_pb2 is not available')
        # pylint: enable=g-import-not-at-top
        with tf.Graph().as_default() as graph:
            inputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3]),
                'bar': tf.convert_to_tensor([0, 2, 0, 2]),
            }
            boundaries_foo = tf.expand_dims(tf.convert_to_tensor([.5, 1.5]),
                                            axis=0)
            boundaries_bar = tf.expand_dims(tf.convert_to_tensor([.1, .2]),
                                            axis=0)
            outputs = {}

            # tft.apply_buckets will annotate the feature in the output schema to
            # indicate the bucket boundaries that were applied.
            outputs['Bucketized_foo'] = mappers.apply_buckets(
                inputs['foo'], boundaries_foo)
            outputs['Bucketized_bar'] = mappers.apply_buckets(
                inputs['bar'], boundaries_bar)
            # Create a session to actually evaluate the annotations and extract
            # the output schema with annotations applied.
            with tf.compat.v1.Session(graph=graph) as session:
                schema = schema_inference.infer_feature_schema(
                    outputs, graph, session)
                self.assertLen(schema.feature, 2)
                for feature in schema.feature:
                    self.assertLen(feature.annotation.extra_metadata, 1)
                    for annotation in feature.annotation.extra_metadata:

                        # Extract the annotated message and validate its contents
                        message = annotations_pb2.BucketBoundaries()
                        annotation.Unpack(message)
                        if feature.name == 'Bucketized_foo':
                            self.assertAllClose(list(message.boundaries),
                                                [.5, 1.5])
                        elif feature.name == 'Bucketized_bar':
                            self.assertAllClose(list(message.boundaries),
                                                [.1, .2])
                        else:
                            raise RuntimeError('Unexpected features in schema')
    def test_bucketization_annotation(self, use_compat_v1):
        """Checks that `mappers.apply_buckets` annotates each output feature.

        Two features are bucketized with different boundaries; every feature
        in the schema obtained from `self._get_schema` must carry exactly one
        `extra_metadata` entry that unpacks to the `BucketBoundaries` used for
        that feature.

        Args:
          use_compat_v1: Forwarded to `self._get_schema`. When falsy,
            `test_case.skip_if_not_tf2` is called first.
        """
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(_):
            foo = tf.convert_to_tensor([0, 1, 2, 3])
            bar = tf.convert_to_tensor([0, 2, 0, 2])
            foo_bounds = tf.expand_dims(
                tf.convert_to_tensor([.5, 1.5]), axis=0)
            bar_bounds = tf.expand_dims(
                tf.convert_to_tensor([.1, .2]), axis=0)
            # tft.apply_buckets annotates each output feature in the schema
            # with the bucket boundaries that were applied to it.
            return {
                'Bucketized_foo': mappers.apply_buckets(foo, foo_bounds),
                'Bucketized_bar': mappers.apply_buckets(bar, bar_bounds),
            }

        # Boundaries each bucketized feature is expected to be annotated with.
        expected_boundaries = {
            'Bucketized_foo': [.5, 1.5],
            'Bucketized_bar': [.1, .2],
        }

        schema = self._get_schema(
            preprocessing_fn, use_compat_v1, create_session=True)
        self.assertLen(schema.feature, 2)
        for feature in schema.feature:
            metadata = feature.annotation.extra_metadata
            self.assertLen(metadata, 1)
            for packed in metadata:
                # Unpack the annotated message before validating its contents.
                unpacked = annotations_pb2.BucketBoundaries()
                packed.Unpack(unpacked)
                if feature.name not in expected_boundaries:
                    raise RuntimeError('Unexpected features in schema')
                self.assertAllClose(list(unpacked.boundaries),
                                    expected_boundaries[feature.name])