Exemplo n.º 1
0
    def preprocess_fn(inputs):
        """Parse serving inputs and letterbox them for the YOLO model.

        Args:
          inputs: Raw serving inputs, handed to `export_utils.parse_image`
            together with the enclosing `input_type`, `input_image_size` and
            `num_channels`.

        Returns:
          If `input_type == 'tflite'`, the parsed image tensor, unchanged.
          Otherwise a tuple `(images, image_info)`:
            - `images`: every element along axis 0 cast to float32, divided by
              255 and passed through `yolo_model_fn.letterbox`. Its per-image
              shape is `input_image_size + [num_channels]`.
            - `image_info`: a float32 tensor of shape [4, 2] per image, as
              returned by `yolo_model_fn.letterbox`.
        """
        image_tensor = export_utils.parse_image(inputs, input_type,
                                                input_image_size, num_channels)
        # If input_type is `tflite`, do not apply image preprocessing.
        if input_type == 'tflite':
            return image_tensor

        def preprocess_image_fn(image):
            """Normalize and letterbox one image taken from the batch."""
            image = tf.cast(image, dtype=tf.float32) / 255.
            # Whether to letterbox (vs. plain resize) follows the validation
            # parser config.
            image, image_info = yolo_model_fn.letterbox(
                image,
                input_image_size,
                letter_box=params.task.validation_data.parser.letter_box)
            return image, image_info

        # Declare the channel dimension from `num_channels` (previously a
        # hard-coded 3) so the signature matches what parse_image produced.
        # NOTE(review): assumes letterbox only resizes/pads spatially and
        # keeps the channel count -- confirm against yolo_model_fn.letterbox.
        images_spec = tf.TensorSpec(shape=input_image_size + [num_channels],
                                    dtype=tf.float32)
        image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)

        # NOTE(review): the tf.identity wrap is preserved from the original;
        # its purpose (e.g. giving outputs their own ops/names) is not
        # visible from here.
        images, image_info = tf.nest.map_structure(
            tf.identity,
            tf.map_fn(preprocess_image_fn,
                      elems=image_tensor,
                      fn_output_signature=(images_spec, image_info_spec),
                      parallel_iterations=32))

        return images, image_info
Exemplo n.º 2
0
    def preprocess_fn(inputs):
        """Turn serving inputs into a batch of preprocessed classification images.

        Args:
          inputs: Raw serving inputs, parsed by `export_utils.parse_image`
            using the enclosing `input_type`, `input_image_size` and
            `num_channels`.

        Returns:
          The parsed tensor as-is when `input_type == 'tflite'`. Otherwise the
          result of mapping `classification_input.Parser.inference_fn` over
          axis 0, with each element declared as float32 of shape
          `input_image_size + [num_channels]`.
        """
        parsed = export_utils.parse_image(inputs, input_type, input_image_size,
                                          num_channels)
        # TFLite inputs are passed through without any preprocessing.
        if input_type == 'tflite':
            return parsed

        def _infer_single(single_image):
            # Delegate per-image preprocessing to the classification parser.
            return classification_input.Parser.inference_fn(
                single_image, input_image_size, num_channels)

        per_image_spec = tf.TensorSpec(
            shape=input_image_size + [num_channels], dtype=tf.float32)

        return tf.map_fn(_infer_single,
                         elems=parsed,
                         fn_output_signature=per_image_spec)