示例#1
0
  def from_name(cls,
                model_name: str,
                model_weights_path: Optional[str] = None,
                checkpoint_format: Optional[str] = 'tf_checkpoint',
                overrides: Optional[Mapping[str, Any]] = None):
    """Builds a MobilenetEdgeTPUV2 model from a predefined model name.

    Example: `MobilenetEdgeTPUV2.from_name('mobilenet_edgetpu_v2_s')`.

    Args:
      model_name: name of the predefined config; forwarded to the constructor
        as `model_config_name`.
      model_weights_path: optional path to the weights (h5 file or saved model
        dir). Weights are loaded only when this is truthy.
      checkpoint_format: the model weights format, one of 'tf_checkpoint' or
        'keras_checkpoint'; forwarded to `common_modules.load_weights`.
      overrides: optional mapping of config overrides. A 'model_config' entry,
        if present, is removed and merged into the module-level
        `MODEL_CONFIGS`; the remaining entries are forwarded to the
        constructor. The caller's mapping itself is never modified.

    Returns:
      The instance built by `cls(model_config_name=..., overrides=...)`.
    """
    # Work on a private copy so popping 'model_config' never touches the
    # caller's mapping.
    remaining = {} if not overrides else dict(overrides)

    # Custom model configs are registered globally: this mutates the
    # module-level MODEL_CONFIGS, so they stay visible to later calls too.
    # NOTE(review): the constructor only receives the config *name*, so it
    # presumably resolves it via MODEL_CONFIGS — confirm before making this
    # merge local.
    custom_configs = remaining.pop('model_config', {})
    MODEL_CONFIGS.update(custom_configs)

    instance = cls(model_config_name=model_name, overrides=remaining)

    if model_weights_path:
      common_modules.load_weights(
          instance, model_weights_path, checkpoint_format=checkpoint_format)
    return instance
示例#2
0
def run_export():
    """Exports the experiment model to TFLite (with optional PTQ) and SavedModel.

    Steps:
      1. Reads the export config from flags and builds the experiment model.
      2. Restores `export_config.ckpt_path` if set; otherwise keeps the random
         initialization.
      3. If `export_config.output_layer` is set, truncates the model at that
         layer.
      4. Wraps the model with a batch-1 square image input and
         `export_util.finalize_serving`.
      5. Converts it to TFLite and writes both artifacts inside
         `export_config.output_dir`: `<model_name>.tflite` and a SavedModel
         directory named `<model_name>`.

    If `output_layer` does not name a layer of the model, the model summary is
    printed and the function returns without exporting anything.
    """
    export_config = get_export_config_from_flags()
    model = export_util.build_experiment_model(
        experiment_type=export_config.model_name)

    if export_config.ckpt_path:
        # Log the path that is actually loaded (from the export config).
        logging.info('Loading checkpoint from %s', export_config.ckpt_path)
        common_modules.load_weights(
            model,
            export_config.ckpt_path,
            checkpoint_format=export_config.ckpt_format)
    else:
        logging.info(
            'No checkpoint provided. Using randomly initialized weights.')

    if export_config.output_layer is not None:
        all_layer_names = {layer.name for layer in model.layers}
        if export_config.output_layer not in all_layer_names:
            # Print the available layers so the user can pick a valid one.
            model.summary()
            logging.info(
                'Cannot find the layer %s in the model. See the above summary to '
                'chose an output layer.', export_config.output_layer)
            return
        output_layer = model.get_layer(export_config.output_layer)
        model = tf.keras.Model(model.input, output_layer.output)

    # Fixed batch size of 1 gives the TFLite model a static input shape.
    model_input = tf.keras.Input(shape=(export_config.image_size,
                                        export_config.image_size, 3),
                                 batch_size=1)
    model_output = export_util.finalize_serving(model(model_input),
                                                export_config)
    model_for_inference = tf.keras.Model(model_input, model_output)

    # Convert to tflite. Quantize if quantization parameters are specified.
    converter = tf.lite.TFLiteConverter.from_keras_model(model_for_inference)
    export_util.configure_tflite_converter(export_config, converter)
    tflite_buffer = converter.convert()

    # Create output_dir itself, not just its parent: both artifacts are
    # written directly inside it.
    tf.io.gfile.makedirs(export_config.output_dir)
    tflite_path = os.path.join(export_config.output_dir,
                               f'{export_config.model_name}.tflite')
    # Context manager guarantees the file is flushed and closed.
    with tf.io.gfile.GFile(tflite_path, 'wb') as tflite_file:
        tflite_file.write(tflite_buffer)
    print('TfLite model exported to {}'.format(tflite_path))

    # Export saved model.
    saved_model_path = os.path.join(export_config.output_dir,
                                    export_config.model_name)
    if FLAGS.export_keras_model:
        model_for_inference.save(saved_model_path)
    else:
        tf.saved_model.save(model_for_inference, saved_model_path)
    print('SavedModel exported to {}'.format(saved_model_path))
    def from_name(cls,
                  model_name: str,
                  model_weights_path: Optional[str] = None,
                  checkpoint_format: Optional[str] = 'tf_checkpoint',
                  overrides: Optional[Dict[str, Any]] = None):
        """Builds a MobilenetEdgeTPU model from a predefined model name.

        Example: `MobilenetEdgeTPU.from_name('mobilenet_edgetpu')`.

        Args:
          model_name: key of the config to look up among the predefined
            configs (plus any custom ones supplied via `overrides`).
          model_weights_path: optional path to the weights (h5 file or saved
            model dir). Weights are loaded only when this is truthy.
          checkpoint_format: the model weights format, one of 'tf_checkpoint'
            or 'keras_checkpoint'; forwarded to `common_modules.load_weights`.
          overrides: optional dict of config overrides. A 'model_config'
            entry, if present, is removed and merged into a shallow copy of
            `MODEL_CONFIGS` (the module-level registry is left untouched); the
            remaining entries are forwarded to the constructor. The caller's
            dict is never modified.

        Returns:
          The instance built by `cls(config=..., overrides=...)`.

        Raises:
          ValueError: if `model_name` is not a known or custom config name.
        """
        # Private copy of the registry so custom configs stay local to this call.
        known_configs = dict(MODEL_CONFIGS)
        remaining = {} if not overrides else dict(overrides)

        # One can define their own custom models if necessary.
        known_configs.update(remaining.pop('model_config', {}))

        if model_name not in known_configs:
            raise ValueError('Unknown model name {}'.format(model_name))

        instance = cls(config=known_configs[model_name], overrides=remaining)

        if not model_weights_path:
            return instance

        common_modules.load_weights(instance,
                                    model_weights_path,
                                    checkpoint_format=checkpoint_format)
        return instance