def create_classification_export_module(params: cfg.ExperimentConfig,
                                        input_type: str,
                                        batch_size: int,
                                        input_image_size: List[int],
                                        num_channels: int = 3):
    """Creates an export module for an image classification model.

    Args:
        params: Experiment configuration; `params.task.model` supplies the
            model config used to build the classifier.
        input_type: Serving input type string (e.g. `'tflite'`); forwarded to
            `export_utils` helpers to shape the input signature and parsing.
        batch_size: Static batch size used in the input signature/spec.
        input_image_size: Spatial size `[height, width]` of the model input.
        num_channels: Number of image channels. Defaults to 3.

    Returns:
        An `export_base.ExportModule` wrapping the built model with the
        pre/post-processing functions defined here.
    """
    signature = export_utils.get_image_input_signatures(
        input_type, batch_size, input_image_size, num_channels)
    specs = tf.keras.layers.InputSpec(
        shape=[batch_size, *input_image_size, num_channels])

    classifier = factory.build_classification_model(
        input_specs=specs, model_config=params.task.model, l2_regularizer=None)

    def _preprocess(raw_inputs):
        """Parses raw serving inputs into a float image batch."""
        parsed = export_utils.parse_image(raw_inputs, input_type,
                                          input_image_size, num_channels)
        # If input_type is `tflite`, do not apply image preprocessing.
        if input_type == 'tflite':
            return parsed

        per_image_spec = tf.TensorSpec(
            shape=input_image_size + [num_channels], dtype=tf.float32)
        return tf.map_fn(
            lambda img: classification_input.Parser.inference_fn(
                img, input_image_size, num_channels),
            elems=parsed,
            fn_output_signature=per_image_spec)

    def _postprocess(logits):
        """Returns raw logits together with their softmax probabilities."""
        return {'logits': logits, 'probs': tf.nn.softmax(logits)}

    return export_base.ExportModule(params,
                                    model=classifier,
                                    input_signature=signature,
                                    preprocessor=_preprocess,
                                    postprocessor=_postprocess)
    def test_postprocessor(self):
        """Verifies a custom postprocessor is applied in the exported signature."""
        export_root = self.get_temp_dir()
        test_model = TestModel()
        batch = tf.ones([2, 4], tf.float32)

        def double_outputs(logits):
            # Same mapping as the original lambda: doubles the 'outputs' entry.
            return {'outputs': 2 * logits['outputs']}

        export_module = export_base_v2.ExportModule(
            params=None,
            model=test_model,
            input_signature=tf.TensorSpec(shape=[2, 4]),
            postprocessor=double_outputs)

        # Expected value: the postprocessor applied to the raw model output.
        want = double_outputs(test_model(batch))

        ckpt_path = tf.train.Checkpoint(model=test_model).save(
            os.path.join(export_root, 'ckpt'))
        saved_dir = export_base.export(export_module, ['serving_default'],
                                       export_savedmodel_dir=export_root,
                                       checkpoint_path=ckpt_path,
                                       timestamped=False)

        restored = tf.saved_model.load(saved_dir)
        got = restored.signatures['serving_default'](batch)
        self.assertAllClose(got['outputs'].numpy(), want['outputs'].numpy())