# NOTE: `export_util`, `common_modules`, `get_export_config_from_flags`, and
# the flag definitions come from the surrounding project and are not shown
# here; the stdlib/absl/TF imports below are what this snippet itself needs.
import os

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

FLAGS = flags.FLAGS


def _dump_tflite(model, config):
    """Converts a Keras model to TFLite and writes it to `config.output_dir`."""
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    export_util.configure_tflite_converter(config, converter)
    tflite_buffer = converter.convert()
    # Create the directory the .tflite file is written into (the original
    # `os.path.dirname(config.output_dir)` only created its parent).
    tf.io.gfile.makedirs(config.output_dir)
    tflite_path = os.path.join(config.output_dir,
                               f'{config.model_name}.tflite')
    with tf.io.gfile.GFile(tflite_path, 'wb') as f:
        f.write(tflite_buffer)
    return tflite_path
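

# For reference, a minimal sketch of what a helper like
# `export_util.configure_tflite_converter` typically does for post-training
# quantization (PTQ). This is an illustrative assumption, not the project's
# actual implementation; the `quantize` and `representative_dataset` fields
# on `config` are hypothetical.
def _configure_tflite_converter_sketch(config, converter):
    if getattr(config, 'quantize', False):
        # Default optimizations enable at least weight quantization.
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        if getattr(config, 'representative_dataset', None) is not None:
            # Full-integer PTQ needs sample inputs to calibrate activations.
            converter.representative_dataset = config.representative_dataset
            converter.target_spec.supported_ops = [
                tf.lite.OpsSet.TFLITE_BUILTINS_INT8
            ]
            converter.inference_input_type = tf.uint8
            converter.inference_output_type = tf.uint8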
def run_export():
    """Exports TFLite with PTQ."""
    export_config = get_export_config_from_flags()
    model = export_util.build_experiment_model(
        experiment_type=export_config.model_name)

    if export_config.ckpt_path:
        logging.info('Loading checkpoint from %s', export_config.ckpt_path)
        common_modules.load_weights(
            model,
            export_config.ckpt_path,
            checkpoint_format=export_config.ckpt_format)
    else:
        logging.info(
            'No checkpoint provided. Using randomly initialized weights.')

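    # Optionally truncate the model at a named intermediate layer so the
    # export ends at that layer's output.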
    if export_config.output_layer is not None:
        all_layer_names = {layer.name for layer in model.layers}
        if export_config.output_layer not in all_layer_names:
            model.summary()
            logging.error(
                'Cannot find layer %s in the model. See the summary above to '
                'choose an output layer.', export_config.output_layer)
            return
        output_layer = model.get_layer(export_config.output_layer)
        model = tf.keras.Model(model.input, output_layer.output)

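    # Rebuild the model with a fixed input signature (batch size 1, square
    # RGB image) so the TFLite converter sees fully static shapes.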
    model_input = tf.keras.Input(shape=(export_config.image_size,
                                        export_config.image_size, 3),
                                 batch_size=1)
    model_output = export_util.finalize_serving(model(model_input),
                                                export_config)
    model_for_inference = tf.keras.Model(model_input, model_output)

    # Convert to tflite. Quantize if quantization parameters are specified.
    converter = tf.lite.TFLiteConverter.from_keras_model(model_for_inference)
    export_util.configure_tflite_converter(export_config, converter)
    tflite_buffer = converter.convert()

    # Make sure the output directory exists, then write the TFLite buffer.
    tf.io.gfile.makedirs(export_config.output_dir)
    tflite_path = os.path.join(export_config.output_dir,
                               f'{export_config.model_name}.tflite')
    with tf.io.gfile.GFile(tflite_path, 'wb') as f:
        f.write(tflite_buffer)
    print(f'TFLite model exported to {tflite_path}')

    # Also export a SavedModel: the Keras saving path preserves Keras
    # metadata, while tf.saved_model.save writes a plain SavedModel.
    saved_model_path = os.path.join(export_config.output_dir,
                                    export_config.model_name)
    if FLAGS.export_keras_model:
        model_for_inference.save(saved_model_path)
    else:
        tf.saved_model.save(model_for_inference, saved_model_path)
    print(f'SavedModel exported to {saved_model_path}')
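

# A minimal entry point, assuming this script is launched with absl.app; the
# flag definitions behind `get_export_config_from_flags` are project-local
# and not shown here.
def main(_):
    run_export()


if __name__ == '__main__':
    app.run(main)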