Exemple #1
0
 def testNextPowerOfTwo(self):
     """Check that utils.next_power_of_two maps 11 to 16."""
     result = utils.next_power_of_two(11)
     self.assertEqual(result, 16)
Exemple #2
0
def update_flags(flags):
    """Derive sample counts and feature dimensions and store them on flags.

    Sample counts are computed from the sample rate and the millisecond
    durations; the feature dimensions depend on the selected preprocessing
    mode. The results are written onto the `flags` object that was passed in
    (it is modified in place, not copied), together with the training mode
    and the TensorBoard summaries directory.

    Args:
        flags: All model and data parameters. The attributes read here are
            wanted_words (comma-separated string), split_data, sample_rate,
            clip_duration_ms, window_size_ms, window_stride_ms, preprocess,
            feature_bin_count and train_dir.

    Returns:
        The same `flags` object with these attributes set: mode, label_count,
        desired_samples, window_size_samples, window_stride_samples,
        spectrogram_length, fingerprint_width, fingerprint_size,
        average_window_width and summaries_dir.

    Raises:
        ValueError: If the preprocessing mode isn't recognized.
        ZeroDivisionError: If window_stride_samples works out to 0 while the
            clip is at least one window long, or if, in 'average' mode,
            the computed average window width is 0 (feature_bin_count is
            larger than the FFT bin count).
    """
    label_count = len(
        input_data.prepare_words_list(flags.wanted_words.split(','),
                                      flags.split_data))

    # Convert the millisecond durations into (truncated) sample counts.
    desired_samples = int(flags.sample_rate * flags.clip_duration_ms /
                          MS_PER_SECOND)
    window_size_samples = int(flags.sample_rate * flags.window_size_ms /
                              MS_PER_SECOND)
    window_stride_samples = int(flags.sample_rate * flags.window_stride_ms /
                                MS_PER_SECOND)

    # Number of windows that fit into the clip when stepping by the stride;
    # zero when the clip is shorter than a single window.
    length_minus_window = desired_samples - window_size_samples
    if length_minus_window < 0:
        spectrogram_length = 0
    else:
        spectrogram_length = 1 + int(
            length_minus_window / window_stride_samples)

    # average_window_width is only computed in 'average' mode; every other
    # mode stores -1 for it.
    if flags.preprocess == 'raw':
        # The whole clip is kept as one frame of desired_samples values.
        average_window_width = -1
        fingerprint_width = desired_samples
        spectrogram_length = 1
    elif flags.preprocess == 'average':
        fft_bin_count = 1 + (utils.next_power_of_two(window_size_samples) / 2)
        average_window_width = int(
            math.floor(fft_bin_count / flags.feature_bin_count))
        fingerprint_width = int(
            math.ceil(float(fft_bin_count) / average_window_width))
    elif flags.preprocess == 'mfcc':
        average_window_width = -1
        fingerprint_width = flags.feature_bin_count
    elif flags.preprocess == 'micro':
        average_window_width = -1
        fingerprint_width = flags.feature_bin_count
    else:
        # List every mode handled above, including "raw".
        raise ValueError('Unknown preprocess mode "%s" (should be "raw",'
                         ' "average", "mfcc", or "micro")' %
                         (flags.preprocess))

    fingerprint_size = fingerprint_width * spectrogram_length

    # upd_flags is an alias of flags, not a copy: the caller's object is
    # updated in place and also returned.
    upd_flags = flags
    upd_flags.mode = Modes.TRAINING
    upd_flags.label_count = label_count
    upd_flags.desired_samples = desired_samples
    upd_flags.window_size_samples = window_size_samples
    upd_flags.window_stride_samples = window_stride_samples
    upd_flags.spectrogram_length = spectrogram_length
    upd_flags.fingerprint_width = fingerprint_width
    upd_flags.fingerprint_size = fingerprint_size
    upd_flags.average_window_width = average_window_width

    # summary logs for TensorBoard
    upd_flags.summaries_dir = os.path.join(flags.train_dir, 'logs/')
    return upd_flags