Example #1
0
def main(unused_argv):
  """Exports one Keras model as a SavedModel and sanity-checks the result.

  All configuration comes from FLAGS. The model is restored from
  `FLAGS.logdir` (optionally at `FLAGS.checkpoint_number`) via
  `model_export_utils.get_model`, saved to `FLAGS.output_dir` with
  `tf.keras.models.save_model`, and then checked with
  `model_export_utils.sanity_check`.

  Args:
    unused_argv: Positional command-line arguments; ignored.

  Raises:
    ValueError: If `FLAGS.output_dir` does not exist after saving.
  """
  if not tf.io.gfile.exists(FLAGS.output_dir):
    tf.io.gfile.makedirs(FLAGS.output_dir)

  # Read the frontend setting once so that the exported model and the sanity
  # check below agree on whether the frontend is included. (Previously the
  # model used FLAGS.frontend while the check used FLAGS.include_frontend.)
  include_frontend = FLAGS.frontend

  model = model_export_utils.get_model(
      FLAGS.logdir,
      params={
          'bd': FLAGS.bottleneck_dimension,
          'al': FLAGS.alpha,
          'ms': FLAGS.mobilenet_size,
          'ap': FLAGS.avg_pool,
          # Any falsy compressor value (e.g. an empty string) becomes None.
          'cop': FLAGS.compressor or None,
          'qat': FLAGS.qat,
      },
      tflite_friendly=False,
      checkpoint_number=FLAGS.checkpoint_number,
      include_frontend=include_frontend)
  tf.keras.models.save_model(model, FLAGS.output_dir)
  # Explicit check rather than `assert`, which `python -O` strips.
  # NOTE(review): output_dir is created above, so this only catches it being
  # removed; checking for a specific saved file would be a stronger test.
  if not tf.io.gfile.exists(FLAGS.output_dir):
    raise ValueError(f'Not written: {FLAGS.output_dir}')
  logging.info('Successfully wrote to: %s', FLAGS.output_dir)

  # Sanity check the resulting model with the same frontend setting it was
  # built with.
  logging.info('Sanity checking...')
  model_export_utils.sanity_check(
      include_frontend,
      FLAGS.output_dir,
      embedding_dim=FLAGS.bottleneck_dimension,
      tflite=False)
def main(_):
  """Converts every experiment under FLAGS.experiment_dir to a TFLite model.

  Each experiment name returned by `model_export_utils.get_experiment_dirs` is
  restored with `model_export_utils.get_model` (TFLite-friendly) and converted
  with `model_export_utils.convert_tflite_model` into
  `FLAGS.output_dir/model_<i>.tflite`, where `i` is the experiment's index in
  that list. When `FLAGS.sanity_checks` is set, each written flatbuffer is
  checked with `model_export_utils.sanity_check`.

  Raises:
    ValueError: If `model_*.tflite` files already exist in `FLAGS.output_dir`,
      or if no experiment directories are found.
  """
  # List existing outputs once and refuse to overwrite models from a
  # previous run.
  output_pattern = os.path.join(FLAGS.output_dir, 'model_*.tflite')
  existing_files = tf.io.gfile.glob(output_pattern)
  if existing_files:
    raise ValueError(f'Models cant already exist: {existing_files}')
  tf.io.gfile.makedirs(FLAGS.output_dir)

  # Get experiment dirs names, params, and output location.
  exp_names = model_export_utils.get_experiment_dirs(FLAGS.experiment_dir)
  if not exp_names:
    raise ValueError(f'No experiments found: {FLAGS.experiment_dir}')
  # Metadata is built positionally; `m.param_str` below presumably maps to the
  # first field (exp_name) -- confirm against Metadata's definition.
  metadata = [
      Metadata(
          exp_name,
          model_export_utils.get_params(exp_name),
          os.path.join(FLAGS.experiment_dir, exp_name),
          os.path.join(FLAGS.output_dir, f'model_{i}.tflite'))
      for i, exp_name in enumerate(exp_names)
  ]
  logging.info('Number of metadata: %i', len(metadata))

  for m in metadata:
    logging.info('Working on experiment dir: %s', m.param_str)

    # Export SavedModel & convert to TFLite
    # Note that we keep over-writing the SavedModel while converting experiments
    # to TFLite, since we only care about the final flatbuffer models.
    static_model = model_export_utils.get_model(
        checkpoint_folder_path=m.experiment_dir,
        params=m.params,
        tflite_friendly=True,
        checkpoint_number=FLAGS.checkpoint_number,
        include_frontend=FLAGS.include_frontend)

    # Quantization follows the experiment's own 'qat' param.
    model_export_utils.convert_tflite_model(
        static_model, quantize=m.params['qat'], model_path=m.output_filename)

    if FLAGS.sanity_checks:
      logging.info('Sanity checking...')
      model_export_utils.sanity_check(
          FLAGS.include_frontend,
          m.output_filename,
          embedding_dim=m.params['bd'],
          tflite=True)

  logging.info('Total TFLite models generated: %i', len(metadata))
Example #3
0
def convert_and_write_model(m, include_frontend, sanity_check):
    """Convert model and write to disk for data prep.

    Restores the model described by `m` with `model_export_utils.get_model`
    and writes it to `m.output_filename`. The output is a TFLite flatbuffer
    when `m.conversion_type == TFLITE_` and a SavedModel when
    `m.conversion_type == SAVEDMODEL_`. It is then optionally sanity-checked.

    Args:
        m: Metadata-like record. This function reads its `experiment_dir`,
            `conversion_type`, `params`, and `output_filename` attributes.
        include_frontend: Forwarded to both `get_model` and `sanity_check`.
        sanity_check: If true, run `model_export_utils.sanity_check` on the
            written model.

    Raises:
        ValueError: If `m.conversion_type` is neither `TFLITE_` nor
            `SAVEDMODEL_`, if the parent directory of `m.output_filename`
            does not exist, or if nothing was written to `m.output_filename`.
    """
    logging.info('Working on experiment dir: %s', m.experiment_dir)

    # Validate the conversion type up front, and without `assert` (which
    # `python -O` strips). An unknown type then fails before the model is
    # built instead of silently being exported as a SavedModel.
    if m.conversion_type == TFLITE_:
        tflite_friendly = True
    elif m.conversion_type == SAVEDMODEL_:
        tflite_friendly = False
    else:
        raise ValueError(f'Unknown conversion type: {m.conversion_type}')

    # Check the destination before paying for the model restore.
    output_dir = os.path.dirname(m.output_filename)
    if not tf.io.gfile.exists(output_dir):
        raise ValueError(f'Existing dir didn\'t exist: {output_dir}')

    model = model_export_utils.get_model(
        checkpoint_folder_path=m.experiment_dir,
        params=m.params,
        tflite_friendly=tflite_friendly,
        checkpoint_number=None,
        include_frontend=include_frontend)

    if tflite_friendly:
        model_export_utils.convert_tflite_model(model,
                                                quantize=m.params['qat'],
                                                model_path=m.output_filename)
    else:
        tf.keras.models.save_model(model, m.output_filename)
    if not tf.io.gfile.exists(m.output_filename):
        raise ValueError(f'Not written: {m.output_filename}')

    if sanity_check:
        logging.info('Sanity checking...')

        def _p_or_flag(k):
            # Prefer the experiment's own param; otherwise fall back to the
            # global flag of the same name.
            return m.params[k] if k in m.params else getattr(flags.FLAGS, k)

        model_export_utils.sanity_check(
            include_frontend=include_frontend,
            model_path=m.output_filename,
            embedding_dim=m.params['bd'],
            tflite=tflite_friendly,
            n_required=_p_or_flag('n_required'),
            frame_width=_p_or_flag('frame_width'),
            num_mel_bins=_p_or_flag('num_mel_bins'))