Ejemplo n.º 1
0
def make_tf2_export(weights_path, export_dir):
    """Build YAMNet, save it as a TF2 SavedModel, and verify the export.

    If ``export_dir`` already exists, only a log message is emitted and the
    function returns immediately. Otherwise the steps are:
      1. build a ``YAMNet`` module from ``weights_path`` and run
         ``check_model`` on it;
      2. write it to ``export_dir`` with ``tf.saved_model.save``;
      3. reload the export with ``tfhub.load`` in TF2 and check it;
      4. reload it again inside a fresh TF1 graph + session and check it.

    Args:
      weights_path: Weights file handed to the ``YAMNet`` constructor.
      export_dir: Directory the SavedModel is written to.
    """
    already_exported = os.path.exists(export_dir)
    if already_exported:
        log('TF2 export already exists in {}, skipping TF2 export'.format(
            export_dir))
        return

    log('Building and checking TF2 Module ...')
    params = yamnet_params.Params()
    module = YAMNet(weights_path, params)
    check_model(module, module.class_map_path(), params)
    log('Done')

    log('Making TF2 SavedModel export ...')
    tf.saved_model.save(module, export_dir)
    log('Done')

    # Round-trip the export through TF-Hub under TF2.
    log('Checking TF2 SavedModel export in TF2 ...')
    eager_model = tfhub.load(export_dir)
    check_model(eager_model, eager_model.class_map_path(), params)
    log('Done')

    # Round-trip it again in TF1 graph mode: load into a private graph and
    # evaluate everything through an explicit session.
    log('Checking TF2 SavedModel export in TF1 ...')
    graph = tf.compat.v1.Graph()
    with graph.as_default():
        with tf.compat.v1.Session(graph=graph) as session:
            graph_model = tfhub.load(export_dir)
            session.run(tf.compat.v1.global_variables_initializer())

            def run_in_session(waveform):
                # Build the forward-pass ops for this waveform, then run them.
                return session.run(graph_model(waveform))

            # `.eval()` resolves against `session`, the default session here.
            class_map = graph_model.class_map_path().eval()
            check_model(run_in_session, class_map, params)
    log('Done')
Ejemplo n.º 2
0
def make_tflite_export(weights_path, model_path, export_dir):
    """Export a TF-Lite-compatible YAMNet and sanity-check every stage.

    If ``export_dir`` already exists, only a log message is emitted and
    ``None`` is returned. Otherwise:
      1. build a ``YAMNet`` module with ``tflite_compatible=True`` params from
         ``weights_path`` and ``model_path``, and run ``check_model`` on it;
      2. save it as a SavedModel under ``<export_dir>/saved_model``;
      3. reload that SavedModel with ``tf.saved_model.load`` and check it;
      4. convert the SavedModel to ``<export_dir>/yamnet.tflite``;
      5. check the TF-Lite model through a ``tf.lite.Interpreter``.

    Args:
      weights_path: Weights file handed to the ``YAMNet`` constructor.
      model_path: Extra argument forwarded to ``YAMNet``; its meaning is
        defined by that class (not visible here).
      export_dir: Root directory for the export artifacts. It must not exist
        yet, or the export is skipped.

    Returns:
      The SavedModel directory path, or ``None`` if the export was skipped.
    """
    if os.path.exists(export_dir):
        log('TF-Lite export already exists in {}, skipping TF-Lite export'.
            format(export_dir))
        return

    # Create a TF-Lite compatible Module wrapper around YAMNet.
    log('Building and checking TF-Lite Module ...')
    params = yamnet_params.Params(tflite_compatible=True)
    yamnet = YAMNet(weights_path, params, model_path)
    check_model(yamnet, yamnet.class_map_path(), params)
    log('Done')

    # Make TF-Lite SavedModel export.
    log('Making TF-Lite SavedModel export ...')
    saved_model_dir = os.path.join(export_dir, 'saved_model')
    os.makedirs(saved_model_dir)
    tf.saved_model.save(yamnet, saved_model_dir)
    log('Done')

    # Check that the export can be loaded and works.
    log('Checking TF-Lite SavedModel export in TF2 ...')
    model = tf.saved_model.load(saved_model_dir)
    check_model(model, model.class_map_path(), params)
    log('Done')

    # Make a TF-Lite model from the SavedModel.
    log('Making TF-Lite model ...')
    tflite_converter = tf.lite.TFLiteConverter.from_saved_model(
        saved_model_dir)
    tflite_model = tflite_converter.convert()
    tflite_model_path = os.path.join(export_dir, 'yamnet.tflite')
    with open(tflite_model_path, 'wb') as f:
        f.write(tflite_model)
    log('Done')

    # Check the TF-Lite export.
    log('Checking TF-Lite model ...')
    interpreter = tf.lite.Interpreter(tflite_model_path)
    audio_input_index = interpreter.get_input_details()[0]['index']
    # Outputs are addressed by position. This assumes the converter keeps
    # the module's output order (scores first, embeddings second). TODO:
    # confirm, e.g. by inspecting the output names/shapes.
    output_details = interpreter.get_output_details()
    scores_output_index = output_details[0]['index']
    embeddings_output_index = output_details[1]['index']
    # Only scores and embeddings are handed to check_model. A spectrogram
    # output, if the converted model has one, is deliberately not read.

    def run_model(waveform):
        """Run the TF-Lite model on one waveform; return (scores, embeddings)."""
        # Resize to this waveform's length on every call so clips of
        # different lengths can be fed, then re-allocate before setting input.
        interpreter.resize_tensor_input(audio_input_index, [len(waveform)],
                                        strict=True)
        interpreter.allocate_tensors()
        interpreter.set_tensor(audio_input_index, waveform)
        interpreter.invoke()
        return (interpreter.get_tensor(scores_output_index),
                interpreter.get_tensor(embeddings_output_index))

    check_model(run_model, 'yamnet_class_map.csv', params)
    log('Done')

    return saved_model_dir
Ejemplo n.º 3
0
def make_tflite_export(weights_path, export_dir):
  """Export a TF-Lite-compatible YAMNet and verify every intermediate artifact.

  If ``export_dir`` already exists, only a log message is emitted and ``None``
  is returned. Otherwise the steps are:
    1. build and check a YAMNet module with ``tflite_compatible=True`` params;
    2. save it as a SavedModel under ``<export_dir>/saved_model``, exposing
       ``__call__`` as the ``serving_default`` signature;
    3. reload and check the SavedModel;
    4. convert that signature to ``<export_dir>/yamnet.tflite``;
    5. check the TF-Lite model through a signature runner.

  Args:
    weights_path: Weights file handed to the ``YAMNet`` constructor.
    export_dir: Root directory for all export artifacts. It must not exist
      yet, or the export is skipped.

  Returns:
    The SavedModel directory path, or ``None`` if the export was skipped.
  """
  if os.path.exists(export_dir):
    log('TF-Lite export already exists in {}, skipping TF-Lite export'.format(
        export_dir))
    return

  # One signature name is used for saving, converting and running.
  signature_key = 'serving_default'

  log('Building and checking TF-Lite Module ...')
  lite_params = yamnet_params.Params(tflite_compatible=True)
  module = YAMNet(weights_path, lite_params)
  check_model(module, module.class_map_path(), lite_params)
  log('Done')

  log('Making TF-Lite SavedModel export ...')
  saved_model_dir = os.path.join(export_dir, 'saved_model')
  os.makedirs(saved_model_dir)
  serving_fn = module.__call__.get_concrete_function()
  tf.saved_model.save(module, saved_model_dir,
                      signatures={signature_key: serving_fn})
  log('Done')

  log('Checking TF-Lite SavedModel export in TF2 ...')
  reloaded = tf.saved_model.load(saved_model_dir)
  check_model(reloaded, reloaded.class_map_path(), lite_params)
  log('Done')

  log('Making TF-Lite model ...')
  converter = tf.lite.TFLiteConverter.from_saved_model(
      saved_model_dir, signature_keys=[signature_key])
  flatbuffer = converter.convert()
  tflite_model_path = os.path.join(export_dir, 'yamnet.tflite')
  with open(tflite_model_path, 'wb') as out_file:
    out_file.write(flatbuffer)
  log('Done')

  log('Checking TF-Lite model ...')
  interpreter = tf.lite.Interpreter(tflite_model_path)
  runner = interpreter.get_signature_runner(signature_key)
  check_model(runner, 'yamnet_class_map.csv', lite_params)
  log('Done')

  return saved_model_dir