Example #1
def main(_):
  # We want to see all the logging messages for this tutorial.
  tf.logging.set_verbosity(tf.logging.INFO)

  # Start a new TensorFlow session.
  sess = tf.InteractiveSession()

  # Begin by making sure we have the training data we need. If you already have
  # training data of your own, use `--data_url= ` on the command line to avoid
  # downloading.
  model_settings = models.prepare_model_settings(
      len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
      FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
      FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess)
  audio_processor = input_data.AudioProcessor(
      FLAGS.data_url, FLAGS.data_dir,
      FLAGS.silence_percentage, FLAGS.unknown_percentage,
      FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
      FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir)
  fingerprint_size = model_settings['fingerprint_size']
  label_count = model_settings['label_count']
  time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
  # Figure out the learning rates for each training phase. Since it's often
  # effective to have high learning rates at the start of training, followed by
  # lower levels towards the end, the number of steps and learning rates can be
  # specified as comma-separated lists to define the rate at each stage. For
  # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
  # will run 13,000 training loops in total, with a rate of 0.001 for the first
  # 10,000, and 0.0001 for the final 3,000.
  training_steps_list = list(map(int, FLAGS.how_many_training_steps.split(',')))
  learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
  if len(training_steps_list) != len(learning_rates_list):
    raise Exception(
        '--how_many_training_steps and --learning_rate must be equal length '
        'lists, but are %d and %d long instead' % (len(training_steps_list),
                                                   len(learning_rates_list)))

  input_placeholder = tf.placeholder(
      tf.float32, [None, fingerprint_size], name='fingerprint_input')
  if FLAGS.quantize:
    fingerprint_min, fingerprint_max = input_data.get_features_range(
        model_settings)
    fingerprint_input = tf.fake_quant_with_min_max_args(
        input_placeholder, fingerprint_min, fingerprint_max)
  else:
    fingerprint_input = input_placeholder

  logits, dropout_prob = models.create_model(
      fingerprint_input,
      model_settings,
      FLAGS.model_architecture,
      is_training=True)

  # Define loss and optimizer
  ground_truth_input = tf.placeholder(
      tf.int64, [None], name='groundtruth_input')

  # Optionally we can add runtime checks to spot when NaNs or other symptoms of
  # numerical errors start occurring during training.
  control_dependencies = []
  if FLAGS.check_nans:
    checks = tf.add_check_numerics_ops()
    control_dependencies = [checks]

  # Create the back propagation and training evaluation machinery in the graph.
  with tf.name_scope('cross_entropy'):
    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)
  if FLAGS.quantize:
    tf.contrib.quantize.create_training_graph(quant_delay=0)
  with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
    learning_rate_input = tf.placeholder(
        tf.float32, [], name='learning_rate_input')
    train_step = tf.train.GradientDescentOptimizer(
        learning_rate_input).minimize(cross_entropy_mean)
  predicted_indices = tf.argmax(logits, 1)
  correct_prediction = tf.equal(predicted_indices, ground_truth_input)
  confusion_matrix = tf.confusion_matrix(
      ground_truth_input, predicted_indices, num_classes=label_count)
  evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  with tf.get_default_graph().name_scope('eval'):
    tf.summary.scalar('cross_entropy', cross_entropy_mean)
    tf.summary.scalar('accuracy', evaluation_step)

  global_step = tf.train.get_or_create_global_step()
  increment_global_step = tf.assign(global_step, global_step + 1)

  saver = tf.train.Saver(tf.global_variables())

  # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
  merged_summaries = tf.summary.merge_all(scope='eval')
  train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                       sess.graph)
  validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')

  tf.global_variables_initializer().run()

  start_step = 1

  if FLAGS.start_checkpoint:
    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
    start_step = global_step.eval(session=sess)

  tf.logging.info('Training from step: %d ', start_step)

  # Save graph.pbtxt.
  tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                       FLAGS.model_architecture + '.pbtxt')

  # Save list of words.
  with gfile.GFile(
      os.path.join(FLAGS.train_dir, FLAGS.model_architecture + '_labels.txt'),
      'w') as f:
    f.write('\n'.join(audio_processor.words_list))

  # Training loop.
  training_steps_max = np.sum(training_steps_list)
  for training_step in xrange(start_step, training_steps_max + 1):
    # Figure out what the current learning rate is.
    training_steps_sum = 0
    for i in range(len(training_steps_list)):
      training_steps_sum += training_steps_list[i]
      if training_step <= training_steps_sum:
        learning_rate_value = learning_rates_list[i]
        break
    # Pull the audio samples we'll use for training.
    train_fingerprints, train_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
        FLAGS.background_volume, time_shift_samples, 'training', sess)
    # Run the graph with this batch of training data.
    train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
        [
            merged_summaries,
            evaluation_step,
            cross_entropy_mean,
            train_step,
            increment_global_step,
        ],
        feed_dict={
            fingerprint_input: train_fingerprints,
            ground_truth_input: train_ground_truth,
            learning_rate_input: learning_rate_value,
            dropout_prob: 0.5
        })
    train_writer.add_summary(train_summary, training_step)
    tf.logging.info('Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
                    (training_step, learning_rate_value, train_accuracy * 100,
                     cross_entropy_value))
    is_last_step = (training_step == training_steps_max)
    if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
      set_size = audio_processor.set_size('validation')
      total_accuracy = 0
      total_conf_matrix = None
      for i in xrange(0, set_size, FLAGS.batch_size):
        validation_fingerprints, validation_ground_truth = (
            audio_processor.get_data(FLAGS.batch_size, i, model_settings, 0.0,
                                     0.0, 0, 'validation', sess))
        # Run a validation step and capture validation summaries for
        # TensorBoard with the `merged` op.
        validation_summary, validation_accuracy, conf_matrix = sess.run(
            [merged_summaries, evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: validation_fingerprints,
                ground_truth_input: validation_ground_truth,
                dropout_prob: 1.0
            })
        validation_writer.add_summary(validation_summary, training_step)
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (validation_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
          total_conf_matrix = conf_matrix
        else:
          total_conf_matrix += conf_matrix
      tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
      tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                      (training_step, total_accuracy * 100, set_size))

    # Save the model checkpoint periodically.
    if (training_step % FLAGS.save_step_interval == 0 or
        training_step == training_steps_max):
      checkpoint_path = os.path.join(FLAGS.train_dir,
                                     FLAGS.model_architecture + '.ckpt')
      tf.logging.info('Saving to "%s-%d"', checkpoint_path, training_step)
      saver.save(sess, checkpoint_path, global_step=training_step)

  set_size = audio_processor.set_size('testing')
  tf.logging.info('set_size=%d', set_size)
  total_accuracy = 0
  total_conf_matrix = None
  for i in xrange(0, set_size, FLAGS.batch_size):
    test_fingerprints, test_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
    test_accuracy, conf_matrix = sess.run(
        [evaluation_step, confusion_matrix],
        feed_dict={
            fingerprint_input: test_fingerprints,
            ground_truth_input: test_ground_truth,
            dropout_prob: 1.0
        })
    batch_size = min(FLAGS.batch_size, set_size - i)
    total_accuracy += (test_accuracy * batch_size) / set_size
    if total_conf_matrix is None:
      total_conf_matrix = conf_matrix
    else:
      total_conf_matrix += conf_matrix
  tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
  tf.logging.info('Final test accuracy = %.1f%% (N=%d)' % (total_accuracy * 100,
                                                           set_size))
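
The learning-rate selection inside the training loop above is a piecewise-constant schedule: a step belongs to the first phase whose cumulative step count it does not exceed. A minimal standalone sketch of that lookup (the helper name lookup_learning_rate is ours, not part of the script):

def lookup_learning_rate(step, steps_list, rates_list):
    """Return the learning rate for the phase containing `step` (1-indexed)."""
    steps_sum = 0
    for phase_steps, rate in zip(steps_list, rates_list):
        steps_sum += phase_steps
        if step <= steps_sum:
            return rate
    return rates_list[-1]  # past the end of the schedule: keep the final rate

# With --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001:
assert lookup_learning_rate(1, [10000, 3000], [0.001, 0.0001]) == 0.001
assert lookup_learning_rate(10000, [10000, 3000], [0.001, 0.0001]) == 0.001
assert lookup_learning_rate(10001, [10000, 3000], [0.001, 0.0001]) == 0.0001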
Example #2
def main(_):
    # Set the verbosity based on flags (default is INFO, so we see all messages)
    tf.compat.v1.logging.set_verbosity(FLAGS.verbosity)

    # Start a new TensorFlow session.
    sess = tf.compat.v1.InteractiveSession()

    summaries_dir = os.path.join(FLAGS.train_dir, 'summaries')

    # Begin by making sure we have the training data we need. If you already have
    # training data of your own, use `--data_url= ` on the command line to avoid
    # downloading.
    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess)
    audio_processor = input_data.AudioProcessor(FLAGS.data_url, FLAGS.data_dir,
                                                FLAGS.silence_percentage,
                                                FLAGS.unknown_percentage,
                                                FLAGS.wanted_words.split(','),
                                                FLAGS.validation_percentage,
                                                FLAGS.testing_percentage,
                                                model_settings, summaries_dir)

    wav_file = FLAGS.wav

    fingerprint_size = model_settings['fingerprint_size']

    # Figure out the learning rates for each training phase. Since it's often
    # effective to have high learning rates at the start of training, followed by
    # lower levels towards the end, the number of steps and learning rates can be
    # specified as comma-separated lists to define the rate at each stage. For
    # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
    # will run 13,000 training loops in total, with a rate of 0.001 for the first
    # 10,000, and 0.0001 for the final 3,000.
    training_steps_list = list(
        map(int, FLAGS.how_many_training_steps.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_steps_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_steps and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_steps_list), len(learning_rates_list)))
    input_placeholder = tf.compat.v1.placeholder(tf.float32,
                                                 [None, fingerprint_size],
                                                 name='fingerprint_input')
    if FLAGS.quantize:
        fingerprint_min, fingerprint_max = input_data.get_features_range(
            model_settings)
        fingerprint_input = tf.quantization.fake_quant_with_min_max_args(
            input_placeholder, fingerprint_min, fingerprint_max)
    else:
        fingerprint_input = input_placeholder

    print('fingerprint input:', fingerprint_input)

    logits, dropout_prob = models.create_model(
        fingerprint_input,
        model_settings,
        FLAGS.model_architecture,
        is_training=True,
    )

    # Define loss and optimizer
    ground_truth_input = tf.compat.v1.placeholder(tf.int64, [None],
                                                  name='groundtruth_input')

    # Create the back propagation and training evaluation machinery in the graph.
    with tf.compat.v1.name_scope('cross_entropy'):
        cross_entropy_mean = tf.compat.v1.losses.sparse_softmax_cross_entropy(
            labels=ground_truth_input, logits=logits)
    if FLAGS.quantize:
        tf.contrib.quantize.create_training_graph(quant_delay=0)
    predicted_indices = tf.argmax(input=logits, axis=1)
    correct_prediction = tf.equal(predicted_indices, ground_truth_input)
    evaluation_step = tf.reduce_mean(
        input_tensor=tf.cast(correct_prediction, tf.float32))
    with tf.compat.v1.get_default_graph().name_scope('eval'):
        tf.compat.v1.summary.scalar('cross_entropy', cross_entropy_mean)
        tf.compat.v1.summary.scalar('accuracy', evaluation_step)

    global_step = tf.compat.v1.train.get_or_create_global_step()

    tf.compat.v1.global_variables_initializer().run()

    start_step = 1

    if FLAGS.checkpoint:
        models.load_variables_from_checkpoint(sess, FLAGS.checkpoint)
        start_step = global_step.eval(session=sess)
        tf.compat.v1.logging.info('Checkpoint: {}'.format(FLAGS.checkpoint))

    tf.compat.v1.logging.info(
        'Recovering checkpoint from step: {}'.format(start_step))

    input_features = audio_processor.get_features_for_wav(
        wav_file, model_settings, sess)
    print('features:', input_features)
    print('number of feature tensors:', len(input_features))
    input_features = input_features[0]
    print('feature shape:', input_features.shape)
    input_features = np.expand_dims(input_features.flatten(), 0)

    y_pred = sess.run(predicted_indices,
                      feed_dict={
                          fingerprint_input: input_features,
                          dropout_prob: 1.0
                      })

    print('Predict:', y_pred)
    print('Label:', audio_processor.words_list[y_pred[0]])
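
Before the prediction above can run, the 2-D feature map returned by get_features_for_wav has to be flattened into the [1, fingerprint_size] layout that the fingerprint_input placeholder expects. A NumPy-only sketch of that reshaping (the 49x40 shape is illustrative):

import numpy as np

# An illustrative feature map: 49 spectrogram timeslices x 40 feature bins.
feature_map = np.zeros((49, 40), dtype=np.float32)

# Flatten into one fingerprint vector, then add a leading batch axis so the
# array matches the [None, fingerprint_size] placeholder shape.
fingerprint = np.expand_dims(feature_map.flatten(), 0)
assert fingerprint.shape == (1, 49 * 40)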
Example #3
def wav_to_features(sample_rate, clip_duration_ms, window_size_ms,
                    window_stride_ms, feature_bin_count, quantize, preprocess,
                    input_wav, output_c_file):
    """Converts an audio file into its corresponding feature map.

  Args:
    sample_rate: Expected sample rate of the wavs.
    clip_duration_ms: Expected duration in milliseconds of the wavs.
    window_size_ms: How long each spectrogram timeslice is.
    window_stride_ms: How far to move in time between spectrogram timeslices.
    feature_bin_count: How many bins to use for the feature fingerprint.
    quantize: Whether to train the model for eight-bit deployment.
    preprocess: Spectrogram processing mode; "mfcc", "average" or "micro".
    input_wav: Path to the audio WAV file to read.
    output_c_file: Where to save the generated C source file.
  """

    # Start a new TensorFlow session.
    sess = tf.compat.v1.InteractiveSession()

    model_settings = models.prepare_model_settings(
        0, sample_rate, clip_duration_ms, window_size_ms, window_stride_ms,
        feature_bin_count, preprocess)
    audio_processor = input_data.AudioProcessor(None, None, 0, 0, '', 0, 0,
                                                model_settings, None)

    results = audio_processor.get_features_for_wav(input_wav, model_settings,
                                                   sess)
    features = results[0]

    variable_base = os.path.splitext(os.path.basename(input_wav).lower())[0]

    # Save a C source file containing the feature data as an array.
    with gfile.GFile(output_c_file, 'w') as f:
        f.write('/* File automatically created by\n')
        f.write(
            ' * tensorflow/examples/speech_commands/wav_to_features.py \\\n')
        f.write(' * --sample_rate=%d \\\n' % sample_rate)
        f.write(' * --clip_duration_ms=%d \\\n' % clip_duration_ms)
        f.write(' * --window_size_ms=%d \\\n' % window_size_ms)
        f.write(' * --window_stride_ms=%d \\\n' % window_stride_ms)
        f.write(' * --feature_bin_count=%d \\\n' % feature_bin_count)
        if quantize:
            f.write(' * --quantize=1 \\\n')
        f.write(' * --preprocess="%s" \\\n' % preprocess)
        f.write(' * --input_wav="%s" \\\n' % input_wav)
        f.write(' * --output_c_file="%s" \\\n' % output_c_file)
        f.write(' */\n\n')
        f.write('const int g_%s_width = %d;\n' %
                (variable_base, model_settings['fingerprint_width']))
        f.write('const int g_%s_height = %d;\n' %
                (variable_base, model_settings['spectrogram_length']))
        if quantize:
            features_min, features_max = input_data.get_features_range(
                model_settings)
            f.write('const unsigned char g_%s_data[] = {' % variable_base)
            i = 0
            for value in features.flatten():
                quantized_value = int(
                    round((255 * (value - features_min)) /
                          (features_max - features_min)))
                if quantized_value < 0:
                    quantized_value = 0
                if quantized_value > 255:
                    quantized_value = 255
                if i == 0:
                    f.write('\n  ')
                f.write('%d, ' % (quantized_value))
                i = (i + 1) % 10
        else:
            f.write('const float g_%s_data[] = {' % variable_base)
            i = 0
            for value in features.flatten():
                if i == 0:
                    f.write('\n  ')
                # Trailing comma keeps the generated array valid C, matching
                # the quantized branch above.
                f.write('%f, ' % value)
                i = (i + 1) % 10
        f.write('\n};\n')
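
The quantized branch above maps each float feature onto an unsigned byte by rescaling the [features_min, features_max] range to [0, 255] and clamping anything outside it. The same arithmetic as a standalone helper (a sketch; the name quantize_to_uint8 is ours):

def quantize_to_uint8(value, features_min, features_max):
    """Rescale a float in [features_min, features_max] to an int in [0, 255]."""
    scaled = (255 * (value - features_min)) / (features_max - features_min)
    return min(255, max(0, int(round(scaled))))  # clamp out-of-range values

assert quantize_to_uint8(0.0, 0.0, 1.0) == 0
assert quantize_to_uint8(1.0, 0.0, 1.0) == 255
assert quantize_to_uint8(2.0, 0.0, 1.0) == 255  # clamped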
Example #4
def wav_to_features(sample_rate, clip_duration_ms, window_size_ms,
                    window_stride_ms, feature_bin_count, quantize, preprocess,
                    input_wav, output_c_file):
  """Converts an audio file into its corresponding feature map.

  Args:
    sample_rate: Expected sample rate of the wavs.
    clip_duration_ms: Expected duration in milliseconds of the wavs.
    window_size_ms: How long each spectrogram timeslice is.
    window_stride_ms: How far to move in time between spectrogram timeslices.
    feature_bin_count: How many bins to use for the feature fingerprint.
    quantize: Whether to train the model for eight-bit deployment.
    preprocess: Spectrogram processing mode; "mfcc", "average" or "micro".
    input_wav: Path to the audio WAV file to read.
    output_c_file: Where to save the generated C source file.
  """

  # Start a new TensorFlow session.
  sess = tf.InteractiveSession()

  model_settings = models.prepare_model_settings(
      0, sample_rate, clip_duration_ms, window_size_ms, window_stride_ms,
      feature_bin_count, preprocess)
  audio_processor = input_data.AudioProcessor(None, None, 0, 0, '', 0, 0,
                                              model_settings, None)

  results = audio_processor.get_features_for_wav(input_wav, model_settings,
                                                 sess)
  features = results[0]

  variable_base = os.path.splitext(os.path.basename(input_wav).lower())[0]

  # Save a C source file containing the feature data as an array.
  with gfile.GFile(output_c_file, 'w') as f:
    f.write('/* File automatically created by\n')
    f.write(' * tensorflow/examples/speech_commands/wav_to_features.py \\\n')
    f.write(' * --sample_rate=%d \\\n' % sample_rate)
    f.write(' * --clip_duration_ms=%d \\\n' % clip_duration_ms)
    f.write(' * --window_size_ms=%d \\\n' % window_size_ms)
    f.write(' * --window_stride_ms=%d \\\n' % window_stride_ms)
    f.write(' * --feature_bin_count=%d \\\n' % feature_bin_count)
    if quantize:
      f.write(' * --quantize=1 \\\n')
    f.write(' * --preprocess="%s" \\\n' % preprocess)
    f.write(' * --input_wav="%s" \\\n' % input_wav)
    f.write(' * --output_c_file="%s" \\\n' % output_c_file)
    f.write(' */\n\n')
    f.write('const int g_%s_width = %d;\n' %
            (variable_base, model_settings['fingerprint_width']))
    f.write('const int g_%s_height = %d;\n' %
            (variable_base, model_settings['spectrogram_length']))
    if quantize:
      features_min, features_max = input_data.get_features_range(model_settings)
      f.write('const unsigned char g_%s_data[] = {' % variable_base)
      i = 0
      for value in features.flatten():
        quantized_value = int(
            round(
                (255 * (value - features_min)) / (features_max - features_min)))
        if quantized_value < 0:
          quantized_value = 0
        if quantized_value > 255:
          quantized_value = 255
        if i == 0:
          f.write('\n  ')
        f.write('%d, ' % (quantized_value))
        i = (i + 1) % 10
    else:
      f.write('const float g_%s_data[] = {' % variable_base)
      i = 0
      for value in features.flatten():
        if i == 0:
          f.write('\n  ')
        # Trailing comma keeps the generated array valid C, matching the
        # quantized branch above.
        f.write('%f, ' % value)
        i = (i + 1) % 10
    f.write('\n};\n')
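
A hypothetical call to the function above, using the usual speech-commands defaults for the audio parameters (the WAV and output paths are illustrative only):

wav_to_features(
    sample_rate=16000,
    clip_duration_ms=1000,
    window_size_ms=30,
    window_stride_ms=10,
    feature_bin_count=40,
    quantize=True,
    preprocess='micro',
    input_wav='/tmp/speech_dataset/yes/example.wav',
    output_c_file='/tmp/yes_features_data.cc')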
Example #6
def main(_):
    # Set the verbosity based on flags (default is INFO, so we see all messages)
    tf.logging.set_verbosity(FLAGS.verbosity)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    # Begin by making sure we have the training data we need. If you already have
    # training data of your own, use `--data_url= ` on the command line to avoid
    # downloading.
    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess)
    audio_processor = input_data.AudioProcessor(None, FLAGS.data_dir,
                                                FLAGS.silence_percentage,
                                                FLAGS.unknown_percentage,
                                                FLAGS.wanted_words.split(','),
                                                FLAGS.validation_percentage,
                                                FLAGS.testing_percentage,
                                                model_settings, None)

    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    # Figure out the learning rates for each training phase. Since it's often
    # effective to have high learning rates at the start of training, followed by
    # lower levels towards the end, the number of steps and learning rates can be
    # specified as comma-separated lists to define the rate at each stage. For
    # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
    # will run 13,000 training loops in total, with a rate of 0.001 for the first
    # 10,000, and 0.0001 for the final 3,000.
    training_steps_list = list(
        map(int, FLAGS.how_many_training_steps.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_steps_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_steps and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_steps_list), len(learning_rates_list)))
    input_placeholder = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')
    if FLAGS.quantize:
        fingerprint_min, fingerprint_max = input_data.get_features_range(
            model_settings)
        fingerprint_input = tf.quantization.fake_quant_with_min_max_args(
            input_placeholder, fingerprint_min, fingerprint_max)
    else:
        fingerprint_input = input_placeholder

    logits = models.create_model(
        fingerprint_input,
        model_settings,
        FLAGS.model_architecture,
        is_training=False,
    )

    # Define loss and optimizer
    ground_truth_input = tf.placeholder(tf.int64, [None],
                                        name='groundtruth_input')

    # Create the back propagation and training evaluation machinery in the graph.
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
            labels=ground_truth_input, logits=logits)
    if FLAGS.quantize:
        tf.contrib.quantize.create_training_graph(quant_delay=0)
    predicted_indices = tf.argmax(input=logits, axis=1)
    correct_prediction = tf.equal(predicted_indices, ground_truth_input)
    confusion_matrix = tf.math.confusion_matrix(labels=ground_truth_input,
                                                predictions=predicted_indices,
                                                num_classes=label_count)
    evaluation_step = tf.reduce_mean(
        input_tensor=tf.cast(correct_prediction, tf.float32))

    global_step = tf.train.get_or_create_global_step()

    tf.global_variables_initializer().run()

    start_step = 1

    if FLAGS.checkpoint:
        models.load_variables_from_checkpoint(sess, FLAGS.checkpoint)
        start_step = global_step.eval(session=sess)
        tf.logging.info('Checkpoint: {}'.format(FLAGS.checkpoint))

    tf.logging.info('Evaluating from step: {}'.format(start_step))

    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)
    total_accuracy = 0
    total_conf_matrix = None
    for i in range(0, set_size, FLAGS.batch_size):
        test_fingerprints, test_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
        test_accuracy, conf_matrix = sess.run(
            [evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: test_fingerprints,
                ground_truth_input: test_ground_truth
            })
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (test_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
            total_conf_matrix = conf_matrix
        else:
            total_conf_matrix += conf_matrix
    tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
    tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
                    (total_accuracy * 100, set_size))
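
The evaluation loop computes a size-weighted average because the final batch can hold fewer than --batch_size clips. The accumulation pattern in isolation, with made-up numbers:

set_size = 25
batch_accuracies = [0.9, 0.8, 0.6]  # the last batch holds only 5 of 25 clips
batch_sizes = [10, 10, 5]

total_accuracy = 0.0
for accuracy, batch_size in zip(batch_accuracies, batch_sizes):
    # Weight each batch by its true size so the short final batch does not
    # count as much as a full one.
    total_accuracy += (accuracy * batch_size) / set_size

assert abs(total_accuracy - 0.8) < 1e-9  # (0.9*10 + 0.8*10 + 0.6*5) / 25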
Example #7
def main(_):
  # Set the verbosity based on flags (default is INFO, so we see all messages)
  tf.compat.v1.logging.set_verbosity(FLAGS.verbosity)

  # Start a new TensorFlow session.
  sess = tf.compat.v1.InteractiveSession()

  # Begin by making sure we have the training data we need. If you already have
  # training data of your own, use `--data_url= ` on the command line to avoid
  # downloading.
  model_settings = models.prepare_model_settings(
      len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
      FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
      FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess,
      FLAGS.in_repeat
  )
  model_settings['n_hidden'] = FLAGS.n_hidden
  model_settings['n_layer'] = FLAGS.n_layer
  model_settings['dropout_prob'] = FLAGS.dropout_prob
  model_settings['n_lif_frac'] = FLAGS.n_lif_frac
  model_settings['tau'] = FLAGS.tau
  model_settings['refr'] = FLAGS.refr
  model_settings['beta'] = FLAGS.beta
  model_settings['n_thr_spikes'] = FLAGS.n_thr_spikes
  model_settings['n_delay'] = FLAGS.n_delay
  model_settings['eprop'] = FLAGS.eprop
  model_settings['random_eprop'] = FLAGS.random_eprop
  model_settings['avg_spikes'] = FLAGS.avg_spikes
  # model_settings['in_repeat'] = FLAGS.in_repeat
  audio_processor = input_data.AudioProcessor(
      FLAGS.data_url, FLAGS.data_dir,
      FLAGS.silence_percentage, FLAGS.unknown_percentage,
      FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
      FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir,
      FLAGS.n_thr_spikes, FLAGS.in_repeat
  )
  time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)
  # Figure out the learning rates for each training phase. Since it's often
  # effective to have high learning rates at the start of training, followed by
  # lower levels towards the end, the number of steps and learning rates can be
  # specified as comma-separated lists to define the rate at each stage. For
  # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
  # will run 13,000 training loops in total, with a rate of 0.001 for the first
  # 10,000, and 0.0001 for the final 3,000.
  training_steps_list = list(map(int, FLAGS.how_many_training_steps.split(',')))
  learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
  if len(training_steps_list) != len(learning_rates_list):
    raise Exception(
        '--how_many_training_steps and --learning_rate must be equal length '
        'lists, but are %d and %d long instead' % (len(training_steps_list),
                                                   len(learning_rates_list)))
  n_thr_spikes = max(1, FLAGS.n_thr_spikes)
  training_placeholder = tf.compat.v1.placeholder(tf.bool, name='is_training')
  input_placeholder = tf.compat.v1.placeholder(
    tf.float32, [None, model_settings['fingerprint_size'] * (2 * n_thr_spikes - 1) * model_settings['in_repeat']],
    name='fingerprint_input')
  if FLAGS.quantize:
    fingerprint_min, fingerprint_max = input_data.get_features_range(
        model_settings)
    fingerprint_input = tf.quantization.fake_quant_with_min_max_args(
        input_placeholder, fingerprint_min, fingerprint_max)
  else:
    fingerprint_input = input_placeholder

  model_out = models.create_model(
      fingerprint_input,
      model_settings,
      FLAGS.model_architecture,
      is_training=training_placeholder)

  if FLAGS.model_architecture == 'lsnn':
    logits, spikes, dropout_prob = model_out
    av = tf.reduce_mean(spikes, axis=(0, 1))
  else:
    logits, dropout_prob = model_out

  # Define loss and optimizer
  ground_truth_input = tf.compat.v1.placeholder(
      tf.int64, [None], name='groundtruth_input')

  # Optionally we can add runtime checks to spot when NaNs or other symptoms of
  # numerical errors start occurring during training.
  control_dependencies = []
  if FLAGS.check_nans:
    checks = tf.compat.v1.add_check_numerics_ops()
    control_dependencies = [checks]

  # Create the back propagation and training evaluation machinery in the graph.
  with tf.compat.v1.name_scope('cross_entropy'):
    cross_entropy_mean = tf.compat.v1.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)

  if FLAGS.model_architecture == 'lsnn':
    regularization_f0 = 10 / 1000  # 10Hz
    loss_reg = tf.reduce_sum(tf.square(av - regularization_f0) * FLAGS.reg)
    cross_entropy_mean += loss_reg

  if FLAGS.quantize:
    try:
      tf.contrib.quantize.create_training_graph(quant_delay=0)
    except (AttributeError, ImportError) as e:  # tf.contrib is absent in TF 2.x
      msg = e.args[0]
      msg += ('\n\n The --quantize option still requires contrib, which is not '
              'part of TensorFlow 2.0. Please install a previous version:'
              '\n    `pip install tensorflow<=1.15`')
      e.args = (msg,)
      raise e

  with tf.compat.v1.name_scope('train'), tf.control_dependencies(
      control_dependencies):
    learning_rate_input = tf.compat.v1.placeholder(
        tf.float32, [], name='learning_rate_input')
    if FLAGS.optimizer == 'gradient_descent':
      train_step = tf.compat.v1.train.GradientDescentOptimizer(
          learning_rate_input).minimize(cross_entropy_mean)
    elif FLAGS.optimizer == 'momentum':
      train_step = tf.compat.v1.train.MomentumOptimizer(
          learning_rate_input, .9,
          use_nesterov=True).minimize(cross_entropy_mean)
    elif FLAGS.optimizer == 'adam':
      train_step = tf.compat.v1.train.AdamOptimizer(
          learning_rate_input).minimize(cross_entropy_mean)
    else:
      raise Exception('Invalid Optimizer')
  predicted_indices = tf.argmax(input=logits, axis=1)
  correct_prediction = tf.equal(predicted_indices, ground_truth_input)
  confusion_matrix = tf.math.confusion_matrix(labels=ground_truth_input,
                                              predictions=predicted_indices,
                                              num_classes=model_settings['label_count'])
  evaluation_step = tf.reduce_mean(input_tensor=tf.cast(correct_prediction,
                                                        tf.float32))
  with tf.compat.v1.get_default_graph().name_scope('eval'):
    tf.compat.v1.summary.scalar('cross_entropy', cross_entropy_mean)
    tf.compat.v1.summary.scalar('accuracy', evaluation_step)

  global_step = tf.compat.v1.train.get_or_create_global_step()
  increment_global_step = tf.compat.v1.assign(global_step, global_step + 1)

  saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())

  # Merge all the summaries and write them out to results/retrain_logs (by default)
  merged_summaries = tf.compat.v1.summary.merge_all(scope='eval')
  train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                                 sess.graph)
  validation_writer = tf.compat.v1.summary.FileWriter(
      FLAGS.summaries_dir + '/validation')

  tf.compat.v1.global_variables_initializer().run()

  start_step = 1

  if FLAGS.start_checkpoint:
    models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
    start_step = global_step.eval(session=sess)

  tf.compat.v1.logging.info('Training from step: %d ', start_step)

  # Save graph.pbtxt.
  tf.io.write_graph(sess.graph_def, FLAGS.train_dir,
                    stored_name + '.pbtxt')

  # Save list of words.
  with gfile.GFile(
      os.path.join(FLAGS.train_dir, stored_name + '_labels.txt'),
      'w') as f:
    f.write('\n'.join(audio_processor.words_list))

  # Training loop.
  performance_metrics = {'val': [], 'test': [], 'firing_rates': []}
  training_steps_max = np.sum(training_steps_list)
  for training_step in xrange(start_step, training_steps_max + 1):
    # Figure out what the current learning rate is.
    training_steps_sum = 0
    for i in range(len(training_steps_list)):
      training_steps_sum += training_steps_list[i]
      if training_step <= training_steps_sum:
        learning_rate_value = learning_rates_list[i]
        break
    # Pull the audio samples we'll use for training.
    train_fingerprints, train_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
        FLAGS.background_volume, time_shift_samples, 'training', sess)
    # Run the graph with this batch of training data.
    train_nodes = [
            merged_summaries,
            evaluation_step,
            cross_entropy_mean,
            train_step,
            increment_global_step,
        ]

    train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
        train_nodes,
        feed_dict={
            fingerprint_input: train_fingerprints,
            ground_truth_input: train_ground_truth,
            learning_rate_input: learning_rate_value,
            dropout_prob: FLAGS.dropout_prob,
            training_placeholder: True,
        })
    train_writer.add_summary(train_summary, training_step)
    if training_step % FLAGS.print_every == 0:
        if FLAGS.model_architecture != 'lsnn':
            tf.compat.v1.logging.info(
                'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
                (training_step, learning_rate_value, train_accuracy * 100,
                 cross_entropy_value))
        else:
            tf.compat.v1.logging.info(
                'Step #%d: rate %.4f, accuracy %.1f%%, cross entropy %.3f' %
                (training_step, learning_rate_value, train_accuracy * 100,
                 cross_entropy_value))
    is_last_step = (training_step == training_steps_max)
    if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
      set_size = audio_processor.set_size('validation')
      total_accuracy = 0
      total_conf_matrix = None
      for i in xrange(0, set_size, FLAGS.batch_size):
        validation_fingerprints, validation_ground_truth = (
            audio_processor.get_data(FLAGS.batch_size, i, model_settings, 0.0,
                                     0.0, 0, 'validation', sess))
        # Run a validation step and capture validation summaries for
        # TensorBoard with the `merged` op.
        val_nodes = [merged_summaries, evaluation_step, confusion_matrix]
        if FLAGS.model_architecture == 'lsnn':
            val_nodes.append(spikes)

        val_nodes_results = sess.run(
            val_nodes,
            feed_dict={
                fingerprint_input: validation_fingerprints,
                ground_truth_input: validation_ground_truth,
                dropout_prob: 1.0,
                training_placeholder: False,
            })
        if FLAGS.model_architecture == 'lsnn':
            validation_summary, validation_accuracy, conf_matrix, val_spikes = val_nodes_results
        else:
            validation_summary, validation_accuracy, conf_matrix = val_nodes_results
        validation_writer.add_summary(validation_summary, training_step)
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (validation_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
          total_conf_matrix = conf_matrix
        else:
          total_conf_matrix += conf_matrix
        if FLAGS.model_architecture == 'lsnn':
            neuron_rates = np.mean(val_spikes, axis=(0, 1)) * 1000
            firing_stats = [np.mean(neuron_rates), np.min(neuron_rates), np.max(neuron_rates)]
      performance_metrics['val'].append(total_accuracy)
      tf.compat.v1.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
      tf.compat.v1.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                                (training_step, total_accuracy * 100, set_size))
      if FLAGS.model_architecture == 'lsnn':
          tf.compat.v1.logging.info('Firing rates: avg %.1f min %.1f max %.1f' %
                                    (firing_stats[0], firing_stats[1], firing_stats[2]))
          performance_metrics['firing_rates'].append('avg %.1f min %.1f max %.1f' %
                                                     (firing_stats[0], firing_stats[1], firing_stats[2]))
      with open(os.path.join(FLAGS.summaries_dir, 'performance.json'), 'w') as f:
          json.dump({**performance_metrics, 'flags': {**vars(FLAGS)}}, f, indent=4, sort_keys=True)

    # Save the model checkpoint periodically.
    if (training_step % FLAGS.save_step_interval == 0 or
        training_step == training_steps_max):
      checkpoint_path = os.path.join(FLAGS.train_dir,
                                     stored_name + '.ckpt')
      tf.compat.v1.logging.info('Saving to "%s-%d"', checkpoint_path,
                                training_step)
      saver.save(sess, checkpoint_path, global_step=training_step)

  set_size = audio_processor.set_size('testing')
  tf.compat.v1.logging.info('set_size=%d', set_size)
  total_accuracy = 0
  total_conf_matrix = None
  for i in xrange(0, set_size, FLAGS.batch_size):
    test_fingerprints, test_ground_truth = audio_processor.get_data(
        FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
    test_accuracy, conf_matrix = sess.run(
        [evaluation_step, confusion_matrix],
        feed_dict={
            fingerprint_input: test_fingerprints,
            ground_truth_input: test_ground_truth,
            dropout_prob: 1.0,
            training_placeholder: False,
        })
    batch_size = min(FLAGS.batch_size, set_size - i)
    total_accuracy += (test_accuracy * batch_size) / set_size
    if total_conf_matrix is None:
      total_conf_matrix = conf_matrix
    else:
      total_conf_matrix += conf_matrix
  tf.compat.v1.logging.warn('Confusion Matrix:\n %s' % (total_conf_matrix))
  tf.compat.v1.logging.warn('Final test accuracy = %.1f%% (N=%d)' %
                            (total_accuracy * 100, set_size))
  performance_metrics['test'].append(total_accuracy)

  with open(os.path.join(FLAGS.summaries_dir, 'performance.json'), 'w') as f:
    json.dump({**performance_metrics, 'flags': {**vars(FLAGS)}}, f, indent=4, sort_keys=True)
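
The LSNN-specific loss term above nudges each neuron's average firing rate towards a 10 Hz target by adding a squared-error penalty. A NumPy sketch of the same computation (the spike-tensor shape and the reg weight are illustrative; time steps are assumed to be 1 ms, matching the * 1000 Hz conversion in the validation loop):

import numpy as np

# spikes: [batch, time, neurons], with 1 ms time steps.
spikes = np.random.binomial(1, 0.01, size=(32, 100, 128)).astype(np.float32)
reg = 1.0                # illustrative regularisation weight (FLAGS.reg)
target_rate = 10 / 1000  # 10 Hz expressed as spikes per millisecond

av = spikes.mean(axis=(0, 1))  # per-neuron average spike probability
loss_reg = np.sum(np.square(av - target_rate) * reg)
neuron_rates_hz = av * 1000    # back to Hz, as logged during validation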
Example #8
def main(_):

    ########################################################################
    #                     INITIALISATION AND SETUP
    ########################################################################
    print('Initialising...')

    # We want to see all the logging messages for this tutorial.
    tf.logging.set_verbosity(tf.logging.INFO)

    # Start a new TensorFlow session.
    sess = tf.InteractiveSession()

    # Output key parameters about audio inputs, preprocessing and analysis
    model_settings = models.prepare_model_settings(
        label_count=len(
            input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        sample_rate=FLAGS.sample_rate,
        clip_duration_ms=FLAGS.clip_duration_ms,
        window_size_ms=FLAGS.window_size_ms,
        window_stride_ms=FLAGS.window_stride_ms,
        feature_bin_count=FLAGS.feature_bin_count,
        preprocess=FLAGS.preprocess)
    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) /
                             1000)  # Why is this not in model_settings?
    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']

    # Create instance of AudioProcessor class (which handles loading, partitioning,
    # and preparing audio training data)
    print('Loading audio data...')
    audio_processor = input_data.AudioProcessor(
        FLAGS.data_url, FLAGS.data_dir,
        FLAGS.silence_percentage, FLAGS.unknown_percentage,
        FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
        FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir)

    # Figure out the learning rates for each training step
    '''Since it's often effective to have high learning rates at the start of
    training, followed by lower levels towards the end, the number of steps
    and learning rates can be specified as comma-separated lists to define the
    rate at each stage. For example --how_many_training_steps=10000,3000
    --learning_rate=0.001,0.0001 will run 13,000 training loops in total, with
    a rate of 0.001 for the first 10,000, and 0.0001 for the final 3,000'''
    training_steps_list = list(
        map(int, FLAGS.how_many_training_steps.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_steps_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_steps and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_steps_list), len(learning_rates_list)))

    ########################################################################
    #                             DEFINE THE MODEL
    ########################################################################
    print('Defining model...')

    # Initialise input placeholder
    input_placeholder = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')

    # If the model needs to be quantized (e.g. for mobile deployment)
    if FLAGS.quantize:
        fingerprint_min, fingerprint_max = input_data.get_features_range(
            model_settings)
        fingerprint_input = tf.fake_quant_with_min_max_args(
            input_placeholder, fingerprint_min, fingerprint_max)
    else:
        fingerprint_input = input_placeholder

    # Initialise ground truth placeholder
    ground_truth_input = tf.placeholder(tf.int64, [None],
                                        name='groundtruth_input')

    # Get the actual model
    '''Default model architecture is conv'''
    logits, dropout_prob = models.create_model(
        fingerprint_input=fingerprint_input,
        model_settings=model_settings,
        model_architecture=FLAGS.model_architecture,
        is_training=True)

    # Optionally we can add runtime checks to spot when NaNs or other symptoms of
    # numerical errors start occurring during training.
    '''This will add a check to make sure all numbers are valid (i.e. no NaNs or Infs)'''
    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    # Define the loss function
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
            labels=ground_truth_input, logits=logits)

    # Quantize the network (i.e. reduce the number of bits that represent a number)
    if FLAGS.quantize:
        tf.contrib.quantize.create_training_graph(quant_delay=0)

    # Define the training operation
    ''' What you specify in the argument to control_dependencies is ensured to be
    evaluated before anything you define in the with block. I imagine that here
    tf.control_dependencies() is used to make sure that there are no invalid
    numbers in the system before the training step is executed'''
    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        learning_rate_input = tf.placeholder(tf.float32, [],
                                             name='learning_rate_input')
        train_step = tf.train.GradientDescentOptimizer(
            learning_rate_input).minimize(cross_entropy_mean)

    # Define operation to get model predictions (i.e. output with highest activation value)
    predicted_indices = tf.argmax(logits, 1)

    # Compare predicted with true labels
    correct_prediction = tf.equal(predicted_indices, ground_truth_input)

    confusion_matrix = tf.confusion_matrix(ground_truth_input,
                                           predicted_indices,
                                           num_classes=label_count)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.get_default_graph().name_scope('eval'):
        tf.summary.scalar('cross_entropy', cross_entropy_mean)
        tf.summary.scalar('accuracy', evaluation_step)

    global_step = tf.train.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    saver = tf.train.Saver(tf.global_variables())

    # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
    merged_summaries = tf.summary.merge_all(scope='eval')
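    # merge_all(scope='eval') bundles only the summaries defined inside the
    # 'eval' name scope above into a single op that both writers can fetch.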
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir +
                                              '/validation')

    tf.global_variables_initializer().run()

    start_step = 1

    if FLAGS.start_checkpoint:
        models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
        start_step = global_step.eval(session=sess)

    tf.logging.info('Training from step: %d ', start_step)

    # Save graph.pbtxt.
    tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                         FLAGS.model_architecture + '.pbtxt')

    # Save list of words.
    with gfile.GFile(
            os.path.join(FLAGS.train_dir,
                         FLAGS.model_architecture + '_labels.txt'), 'w') as f:
        f.write('\n'.join(audio_processor.words_list))

    # Training loop.
    print('Training...')
    training_steps_max = np.sum(training_steps_list)
    for training_step in xrange(start_step, training_steps_max + 1):
        # Figure out what the current learning rate is.
        training_steps_sum = 0
        for i in range(len(training_steps_list)):
            training_steps_sum += training_steps_list[i]
            if training_step <= training_steps_sum:
                learning_rate_value = learning_rates_list[i]
                break
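        # Worked example: with --how_many_training_steps=10000,3000 and
        # --learning_rate=0.001,0.0001, step 9000 falls in the first phase
        # (rate 0.001) while step 10500 uses the second rate (0.0001).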
        # Pull the audio samples we'll use for training.
        train_fingerprints, train_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
            FLAGS.background_volume, time_shift_samples, 'training', sess)
        # Run the graph with this batch of training data.
        train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
            [
                merged_summaries,
                evaluation_step,
                cross_entropy_mean,
                train_step,
                increment_global_step,
            ],
            feed_dict={
                fingerprint_input: train_fingerprints,
                ground_truth_input: train_ground_truth,
                learning_rate_input: learning_rate_value,
                dropout_prob: 0.5
            })
        train_writer.add_summary(train_summary, training_step)
        tf.logging.info(
            'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
            (training_step, learning_rate_value, train_accuracy * 100,
             cross_entropy_value))
        is_last_step = (training_step == training_steps_max)
        if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
            set_size = audio_processor.set_size('validation')
            total_accuracy = 0
            total_conf_matrix = None
            for i in xrange(0, set_size, FLAGS.batch_size):
                validation_fingerprints, validation_ground_truth = (
                    audio_processor.get_data(FLAGS.batch_size, i,
                                             model_settings, 0.0, 0.0, 0,
                                             'validation', sess))
                # Run a validation step and capture its summaries for
                # TensorBoard with the `merged` op.
                validation_summary, validation_accuracy, conf_matrix = sess.run(
                    [merged_summaries, evaluation_step, confusion_matrix],
                    feed_dict={
                        fingerprint_input: validation_fingerprints,
                        ground_truth_input: validation_ground_truth,
                        dropout_prob: 1.0
                    })
                validation_writer.add_summary(validation_summary,
                                              training_step)
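                # The final batch may be smaller than FLAGS.batch_size, so
                # weight each batch's accuracy by its true size when
                # averaging over the whole validation set.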
                batch_size = min(FLAGS.batch_size, set_size - i)
                total_accuracy += (validation_accuracy * batch_size) / set_size
                if total_conf_matrix is None:
                    total_conf_matrix = conf_matrix
                else:
                    total_conf_matrix += conf_matrix
            tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
            tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                            (training_step, total_accuracy * 100, set_size))

        # Save the model checkpoint periodically.
        if (training_step % FLAGS.save_step_interval == 0
                or training_step == training_steps_max):
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           FLAGS.model_architecture + '.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)

    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)
    total_accuracy = 0
    total_conf_matrix = None
    for i in xrange(0, set_size, FLAGS.batch_size):
        test_fingerprints, test_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
        test_accuracy, conf_matrix = sess.run(
            [evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: test_fingerprints,
                ground_truth_input: test_ground_truth,
                dropout_prob: 1.0
            })
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (test_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
            total_conf_matrix = conf_matrix
        else:
            total_conf_matrix += conf_matrix
    tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
    tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
                    (total_accuracy * 100, set_size))
Example No. 9
def train(audio_processor, model_settings, params):

    tf.compat.v1.logging.set_verbosity(params["verbosity"])

    # Start a new TensorFlow session.
    sess = tf.compat.v1.InteractiveSession()

    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    time_shift_samples = int(
        (params["time_shift_ms"] * params["sample_rate"]) / 1000)
    # Figure out the learning rates for each training phase. Since it's often
    # effective to have high learning rates at the start of training, followed by
    # lower levels towards the end, the number of steps and learning rates can be
    # specified as comma-separated lists to define the rate at each stage. For
    # example --how_many_training_steps=10000,3000 --learning_rate=0.001,0.0001
    # will run 13,000 training loops in total, with a rate of 0.001 for the first
    # 10,000, and 0.0001 for the final 3,000.
    training_steps_list = list(
        map(int, params["how_many_training_steps"].split(',')))
    learning_rates_list = list(map(float, params["learning_rate"].split(',')))
    if len(training_steps_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_steps and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_steps_list), len(learning_rates_list)))

    input_placeholder = tf.compat.v1.placeholder(tf.float32,
                                                 [None, fingerprint_size],
                                                 name='fingerprint_input')
    if params["quantize"]:
        fingerprint_min, fingerprint_max = input_data.get_features_range(
            model_settings)
        fingerprint_input = tf.quantization.fake_quant_with_min_max_args(
            input_placeholder, fingerprint_min, fingerprint_max)
    else:
        fingerprint_input = input_placeholder

    logits, dropout_rate = models.create_model(fingerprint_input,
                                               model_settings,
                                               params["model_architecture"],
                                               is_training=True)
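    # Note this variant returns `dropout_rate` (the fraction of units to
    # drop) rather than a keep probability, which is why evaluation below
    # feeds 0.0 instead of 1.0.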

    # Define loss and optimizer
    ground_truth_input = tf.compat.v1.placeholder(tf.int64, [None],
                                                  name='groundtruth_input')

    # Optionally we can add runtime checks to spot when NaNs or other symptoms of
    # numerical errors start occurring during training.
    control_dependencies = []
    if params["check_nans"]:
        checks = tf.compat.v1.add_check_numerics_ops()
        control_dependencies = [checks]

    # Create the back propagation and training evaluation machinery in the graph.
    with tf.compat.v1.name_scope('cross_entropy'):
        cross_entropy_mean = tf.compat.v1.losses.sparse_softmax_cross_entropy(
            labels=ground_truth_input, logits=logits)

    if params["quantize"]:
        try:
            tf.contrib.quantize.create_training_graph(quant_delay=0)
        except AttributeError as e:
            msg = e.args[0]
            msg += (
                '\n\n The --quantize option still requires contrib, which is not '
                'part of TensorFlow 2.0. Please install a previous version:'
                '\n    `pip install tensorflow<=1.15`')
            e.args = (msg, )
            raise e

    with tf.compat.v1.name_scope('train'), tf.control_dependencies(
            control_dependencies):
        learning_rate_input = tf.compat.v1.placeholder(
            tf.float32, [], name='learning_rate_input')
        if params["optimizer"] == 'gradient_descent':
            train_step = tf.compat.v1.train.GradientDescentOptimizer(
                learning_rate_input).minimize(cross_entropy_mean)
        elif params["optimizer"] == 'momentum':
            train_step = tf.compat.v1.train.MomentumOptimizer(
                learning_rate_input, .9,
                use_nesterov=True).minimize(cross_entropy_mean)
        else:
            raise Exception('Invalid Optimizer')
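        # MomentumOptimizer with momentum 0.9 and use_nesterov=True applies
        # Nesterov-accelerated SGD, which usually converges faster than
        # plain gradient descent at the same learning rate.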
    predicted_indices = tf.argmax(input=logits, axis=1)
    correct_prediction = tf.equal(predicted_indices, ground_truth_input)
    confusion_matrix = tf.math.confusion_matrix(labels=ground_truth_input,
                                                predictions=predicted_indices,
                                                num_classes=label_count)
    evaluation_step = tf.reduce_mean(
        input_tensor=tf.cast(correct_prediction, tf.float32))
    with tf.compat.v1.get_default_graph().name_scope('eval'):
        tf.compat.v1.summary.scalar('cross_entropy', cross_entropy_mean)
        tf.compat.v1.summary.scalar('accuracy', evaluation_step)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    increment_global_step = tf.compat.v1.assign(global_step, global_step + 1)

    saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables())

    # Merge all the summaries and write them out to /tmp/retrain_logs (by default)
    merged_summaries = tf.compat.v1.summary.merge_all(scope='eval')
    train_writer = tf.compat.v1.summary.FileWriter(
        params["summaries_dir"] + '/train', sess.graph)
    validation_writer = tf.compat.v1.summary.FileWriter(
        params["summaries_dir"] + '/validation')

    tf.compat.v1.global_variables_initializer().run()

    start_step = 1

    if params["start_checkpoint"]:
        models.load_variables_from_checkpoint(sess, params["start_checkpoint"])
        start_step = global_step.eval(session=sess)

    tf.compat.v1.logging.info('Training from step: %d ', start_step)

    # Save graph.pbtxt.
    tf.io.write_graph(sess.graph_def, params["train_dir"],
                      params["model_architecture"] + '.pbtxt')

    # Save list of words.
    with gfile.GFile(
            os.path.join(params["train_dir"],
                         params["model_architecture"] + '_labels.txt'),
            'w') as f:
        f.write('\n'.join(audio_processor.words_list))

    # Training loop.
    training_steps_max = np.sum(training_steps_list)
    for training_step in xrange(start_step, training_steps_max + 1):
        # Figure out what the current learning rate is.
        training_steps_sum = 0
        for i in range(len(training_steps_list)):
            training_steps_sum += training_steps_list[i]
            if training_step <= training_steps_sum:
                learning_rate_value = learning_rates_list[i]
                break
        # Pull the audio samples we'll use for training.
        train_fingerprints, train_ground_truth = audio_processor.get_data(
            params["batch_size"], 0, model_settings,
            params["background_frequency"], params["background_volume"],
            time_shift_samples, 'training', sess)
        # Run the graph with this batch of training data.
        train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
            [
                merged_summaries,
                evaluation_step,
                cross_entropy_mean,
                train_step,
                increment_global_step,
            ],
            feed_dict={
                fingerprint_input: train_fingerprints,
                ground_truth_input: train_ground_truth,
                learning_rate_input: learning_rate_value,
                dropout_rate: 0.5
            })
        train_writer.add_summary(train_summary, training_step)
        tf.compat.v1.logging.debug(
            'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
            (training_step, learning_rate_value, train_accuracy * 100,
             cross_entropy_value))
        is_last_step = (training_step == training_steps_max)
        if (training_step % params["eval_step_interval"]) == 0 or is_last_step:
            tf.compat.v1.logging.info(
                'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
                (training_step, learning_rate_value, train_accuracy * 100,
                 cross_entropy_value))
            set_size = audio_processor.set_size('validation')
            total_accuracy = 0
            total_conf_matrix = None
            for i in xrange(0, set_size, params["batch_size"]):
                validation_fingerprints, validation_ground_truth = (
                    audio_processor.get_data(params["batch_size"], i,
                                             model_settings, 0.0, 0.0, 0,
                                             'validation', sess))
                # Run a validation step and capture its summaries for
                # TensorBoard with the `merged` op.
                validation_summary, validation_accuracy, conf_matrix = sess.run(
                    [merged_summaries, evaluation_step, confusion_matrix],
                    feed_dict={
                        fingerprint_input: validation_fingerprints,
                        ground_truth_input: validation_ground_truth,
                        dropout_rate: 0.0
                    })
                validation_writer.add_summary(validation_summary,
                                              training_step)
                batch_size = min(params["batch_size"], set_size - i)
                total_accuracy += (validation_accuracy * batch_size) / set_size
                if total_conf_matrix is None:
                    total_conf_matrix = conf_matrix
                else:
                    total_conf_matrix += conf_matrix
            tf.compat.v1.logging.info('Confusion Matrix:\n %s' %
                                      (total_conf_matrix))
            tf.compat.v1.logging.info(
                'Step %d: Validation accuracy = %.1f%% (N=%d)' %
                (training_step, total_accuracy * 100, set_size))

        # Save the model checkpoint periodically.
        if (training_step % params["save_step_interval"] == 0
                or training_step == training_steps_max):
            checkpoint_path = os.path.join(
                params["train_dir"], params["model_architecture"] + '.ckpt')
            tf.compat.v1.logging.info('Saving to "%s-%d"', checkpoint_path,
                                      training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)

    set_size = audio_processor.set_size('testing')
    tf.compat.v1.logging.info('set_size=%d', set_size)
    total_accuracy = 0
    total_conf_matrix = None
    for i in xrange(0, set_size, params["batch_size"]):
        test_fingerprints, test_ground_truth = audio_processor.get_data(
            params["batch_size"], i, model_settings, 0.0, 0.0, 0, 'testing',
            sess)
        test_accuracy, conf_matrix = sess.run(
            [evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: test_fingerprints,
                ground_truth_input: test_ground_truth,
                dropout_rate: 0.0
            })
        batch_size = min(params["batch_size"], set_size - i)
        total_accuracy += (test_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
            total_conf_matrix = conf_matrix
        else:
            total_conf_matrix += conf_matrix
    tf.compat.v1.logging.warn('Confusion Matrix:\n %s' % (total_conf_matrix))
    tf.compat.v1.logging.warn('Final test accuracy = %.1f%% (N=%d)' %
                              (total_accuracy * 100, set_size))
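
A minimal sketch of how this train() might be invoked, assuming
`audio_processor` and `model_settings` are built as in the neighbouring
examples; every value below is an illustrative assumption rather than a
default taken from the source:

params = {
    "verbosity": tf.compat.v1.logging.INFO,  # assumed logging level
    "sample_rate": 16000,
    "time_shift_ms": 100.0,
    "how_many_training_steps": "10000,3000",
    "learning_rate": "0.001,0.0001",
    "quantize": False,
    "model_architecture": "conv",
    "check_nans": False,
    "optimizer": "gradient_descent",  # or 'momentum'
    "summaries_dir": "/tmp/retrain_logs",
    "start_checkpoint": "",
    "train_dir": "/tmp/speech_commands_train",  # assumed output directory
    "batch_size": 100,
    "background_frequency": 0.8,
    "background_volume": 0.1,
    "eval_step_interval": 400,
    "save_step_interval": 100,
}
train(audio_processor, model_settings, params)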
Example No. 10
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    sess = tf.InteractiveSession()

    model_settings = models.prepare_model_settings(
        len(input_data.prepare_words_list(FLAGS.wanted_words.split(','))),
        FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.window_size_ms,
        FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.preprocess)
    audio_processor = input_data.AudioProcessor(
        FLAGS.data_url, FLAGS.data_dir,
        FLAGS.silence_percentage, FLAGS.unknown_percentage,
        FLAGS.wanted_words.split(','), FLAGS.validation_percentage,
        FLAGS.testing_percentage, model_settings, FLAGS.summaries_dir)
    fingerprint_size = model_settings['fingerprint_size']
    label_count = model_settings['label_count']
    time_shift_samples = int((FLAGS.time_shift_ms * FLAGS.sample_rate) / 1000)

    training_steps_list = list(
        map(int, FLAGS.how_many_training_steps.split(',')))
    learning_rates_list = list(map(float, FLAGS.learning_rate.split(',')))
    if len(training_steps_list) != len(learning_rates_list):
        raise Exception(
            '--how_many_training_steps and --learning_rate must be equal length '
            'lists, but are %d and %d long instead' %
            (len(training_steps_list), len(learning_rates_list)))

    input_placeholder = tf.placeholder(tf.float32, [None, fingerprint_size],
                                       name='fingerprint_input')
    if FLAGS.quantize:
        fingerprint_min, fingerprint_max = input_data.get_features_range(
            model_settings)
        fingerprint_input = tf.fake_quant_with_min_max_args(
            input_placeholder, fingerprint_min, fingerprint_max)
    else:
        fingerprint_input = input_placeholder

    logits, dropout_prob = models.create_model(fingerprint_input,
                                               model_settings,
                                               FLAGS.model_architecture,
                                               is_training=True)

    ground_truth_input = tf.placeholder(tf.int64, [None],
                                        name='groundtruth_input')

    control_dependencies = []
    if FLAGS.check_nans:
        checks = tf.add_check_numerics_ops()
        control_dependencies = [checks]

    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
            labels=ground_truth_input, logits=logits)
    if FLAGS.quantize:
        tf.contrib.quantize.create_training_graph(quant_delay=0)
    with tf.name_scope('train'), tf.control_dependencies(control_dependencies):
        learning_rate_input = tf.placeholder(tf.float32, [],
                                             name='learning_rate_input')
        train_step = tf.train.GradientDescentOptimizer(
            learning_rate_input).minimize(cross_entropy_mean)
    predicted_indices = tf.argmax(logits, 1)
    correct_prediction = tf.equal(predicted_indices, ground_truth_input)
    confusion_matrix = tf.confusion_matrix(ground_truth_input,
                                           predicted_indices,
                                           num_classes=label_count)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    with tf.get_default_graph().name_scope('eval'):
        tf.summary.scalar('cross_entropy', cross_entropy_mean)
        tf.summary.scalar('accuracy', evaluation_step)

    global_step = tf.train.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    saver = tf.train.Saver(tf.global_variables())

    merged_summaries = tf.summary.merge_all(scope='eval')
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)
    validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir +
                                              '/validation')

    tf.global_variables_initializer().run()

    start_step = 1

    if FLAGS.start_checkpoint:
        models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
        start_step = global_step.eval(session=sess)

    tf.logging.info('Training from step: %d ', start_step)

    tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                         FLAGS.model_architecture + '.pbtxt')

    with gfile.GFile(
            os.path.join(FLAGS.train_dir,
                         FLAGS.model_architecture + '_labels.txt'), 'w') as f:
        f.write('\n'.join(audio_processor.words_list))

    training_steps_max = np.sum(training_steps_list)
    for training_step in xrange(start_step, training_steps_max + 1):
        training_steps_sum = 0
        for i in range(len(training_steps_list)):
            training_steps_sum += training_steps_list[i]
            if training_step <= training_steps_sum:
                learning_rate_value = learning_rates_list[i]
                break
        train_fingerprints, train_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, 0, model_settings, FLAGS.background_frequency,
            FLAGS.background_volume, time_shift_samples, 'training', sess)
        train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
            [
                merged_summaries,
                evaluation_step,
                cross_entropy_mean,
                train_step,
                increment_global_step,
            ],
            feed_dict={
                fingerprint_input: train_fingerprints,
                ground_truth_input: train_ground_truth,
                learning_rate_input: learning_rate_value,
                dropout_prob: 0.5
            })
        train_writer.add_summary(train_summary, training_step)
        tf.logging.info(
            'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
            (training_step, learning_rate_value, train_accuracy * 100,
             cross_entropy_value))
        is_last_step = (training_step == training_steps_max)
        if (training_step % FLAGS.eval_step_interval) == 0 or is_last_step:
            set_size = audio_processor.set_size('validation')
            total_accuracy = 0
            total_conf_matrix = None
            for i in xrange(0, set_size, FLAGS.batch_size):
                validation_fingerprints, validation_ground_truth = (
                    audio_processor.get_data(FLAGS.batch_size, i,
                                             model_settings, 0.0, 0.0, 0,
                                             'validation', sess))
                validation_summary, validation_accuracy, conf_matrix = sess.run(
                    [merged_summaries, evaluation_step, confusion_matrix],
                    feed_dict={
                        fingerprint_input: validation_fingerprints,
                        ground_truth_input: validation_ground_truth,
                        dropout_prob: 1.0
                    })
                validation_writer.add_summary(validation_summary,
                                              training_step)
                batch_size = min(FLAGS.batch_size, set_size - i)
                total_accuracy += (validation_accuracy * batch_size) / set_size
                if total_conf_matrix is None:
                    total_conf_matrix = conf_matrix
                else:
                    total_conf_matrix += conf_matrix
            tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
            tf.logging.info('Step %d: Validation accuracy = %.1f%% (N=%d)' %
                            (training_step, total_accuracy * 100, set_size))

        if (training_step % FLAGS.save_step_interval == 0
                or training_step == training_steps_max):
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           FLAGS.model_architecture + '.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)

    set_size = audio_processor.set_size('testing')
    tf.logging.info('set_size=%d', set_size)
    total_accuracy = 0
    total_conf_matrix = None
    for i in xrange(0, set_size, FLAGS.batch_size):
        test_fingerprints, test_ground_truth = audio_processor.get_data(
            FLAGS.batch_size, i, model_settings, 0.0, 0.0, 0, 'testing', sess)
        test_accuracy, conf_matrix = sess.run(
            [evaluation_step, confusion_matrix],
            feed_dict={
                fingerprint_input: test_fingerprints,
                ground_truth_input: test_ground_truth,
                dropout_prob: 1.0
            })
        batch_size = min(FLAGS.batch_size, set_size - i)
        total_accuracy += (test_accuracy * batch_size) / set_size
        if total_conf_matrix is None:
            total_conf_matrix = conf_matrix
        else:
            total_conf_matrix += conf_matrix
    tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
    tf.logging.info('Final test accuracy = %.1f%% (N=%d)' %
                    (total_accuracy * 100, set_size))