예제 #1
0
    ap = AudioProcessor(data_dirs=data_dirs,
                        wanted_words=classes,
                        silence_percentage=13.0,
                        unknown_percentage=60.0,
                        validation_percentage=10.0,
                        testing_percentage=10.0,
                        model_settings=model_settings,
                        output_representation=output_representation)
    train_gen = data_gen(ap, sess, batch_size=batch_size, mode='training')
    val_gen = data_gen(ap, sess, batch_size=batch_size, mode='validation')

    model = speech_model(
        'conv_1d_time_stacked',
        model_settings['fingerprint_size'] if output_representation != 'raw'
        else model_settings['desired_samples'],
        # noqa
        num_classes=model_settings['label_count'],
        **model_settings)

    # embed()
    checkpoints_path = os.path.join('checkpoints', 'conv_bite_new')
    if not os.path.exists(checkpoints_path):
        os.makedirs(checkpoints_path)

    callbacks = [
        #   ConfusionMatrixCallback(
        #       val_gen,
        #       ap.set_size('validation') // batch_size,
        #       wanted_words=prepare_words_list(get_classes(wanted_only=True)),
        #       all_words=prepare_words_list(classes),
예제 #2
0
    ap = AudioProcessor(data_dirs=data_dirs,
                        wanted_words=classes,
                        silence_percentage=13.0,
                        unknown_percentage=60.0,
                        validation_percentage=10.0,
                        testing_percentage=0.0,
                        model_settings=model_settings,
                        output_representation=output_representation)
    train_gen = data_gen(ap, sess, batch_size=batch_size, mode='training')
    data = next(train_gen)
    print(data[0].shape)
    val_gen = data_gen(ap, sess, batch_size=batch_size, mode='validation')

    model = speech_model(model_settings['desired_samples'],
                         num_classes=model_settings['label_count'],
                         **model_settings)

    # embed()
    checkpoint_path = 'ml/checkpoints/spectrogram_model/' + \
                      datetime.now().strftime("%Y%m%d-%H%M%S") + "/cp.ckpt"
    checkpoint_dir = os.path.dirname(checkpoint_path)
    callbacks = [
        ConfusionMatrixCallback(val_gen,
                                ap.set_size('validation') // batch_size,
                                wanted_words=prepare_words_list(
                                    get_classes(wanted_only=True)),
                                all_words=prepare_words_list(classes),
                                label2int=ap.word_to_index),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_categorical_accuracy',
예제 #3
0
파일: train.py 프로젝트: indranig/examples
  ap = AudioProcessor(
      data_dirs=data_dirs,
      wanted_words=classes,
      silence_percentage=13.0,
      unknown_percentage=60.0,
      validation_percentage=10.0,
      testing_percentage=0.0,
      model_settings=model_settings,
      output_representation=output_representation)
  train_gen = data_gen(ap, sess, batch_size=batch_size, mode='training')
  val_gen = data_gen(ap, sess, batch_size=batch_size, mode='validation')

  model = speech_model(
      'conv_1d_time_stacked',
      model_settings['fingerprint_size']
      if output_representation != 'raw' else model_settings['desired_samples'],
      # noqa
      num_classes=model_settings['label_count'],
      **model_settings)

  # embed()
  checkpoints_path = os.path.join('checkpoints', 'conv_1d_time_stacked_model')
  if not os.path.exists(checkpoints_path):
    os.makedirs(checkpoints_path)

  callbacks = [
      ConfusionMatrixCallback(
          val_gen,
          ap.set_size('validation') // batch_size,
          wanted_words=prepare_words_list(get_classes(wanted_only=True)),
          all_words=prepare_words_list(classes),
import argparse
import os

import tensorflow as tf

from model import speech_model

# Convert a trained speech model checkpoint into a TensorFlow Lite model file.
#
# The script rebuilds the Keras model, loads the weights stored as `cp.ckpt`
# inside the directory given by -checkpoint_dir, converts the model with the
# TFLite converter and writes the result to ml/models/tflite_model.tflite.
parser = argparse.ArgumentParser(description='set input arguments')

# No sensible default exists for the checkpoint location, so the option is
# required: argparse then reports a clear usage error instead of the script
# failing later while building the checkpoint path.
parser.add_argument(
    '-checkpoint_dir',
    action='store',
    type=str,
    required=True,
    help='model weights checkpoint path')

args = parser.parse_args()

# NOTE(review): 16000 is presumably the raw input length in samples and 7 the
# number of labels the checkpoint was trained with -- confirm against training.
model = speech_model(16000,
                     num_classes=7)
# check the latest directory and pass it
# os.path.join works whether or not the directory ends with a path separator.
checkpoint_path = os.path.join(args.checkpoint_dir, 'cp.ckpt')
model.load_weights(checkpoint_path)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tf_lite_model = converter.convert()

# Create the output directory on demand and close the file deterministically.
output_path = "ml/models/tflite_model.tflite"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as tflite_file:
    tflite_file.write(tf_lite_model)