Code example #1
0
 def test_default_values(self):
     """Guard against regressions in the default flag values.

     Each listed flag default is compared with the attribute of the same
     name on a freshly constructed ``model_params.Params()``. The exceptions
     are ``batch_size`` and ``clip_duration_ms``, which are compared with
     fixed literal values. Checks run in a fixed order and the first
     mismatch fails the test.
     """
     defaults = model_params.Params()
     # Flags whose expected default is pinned to a literal instead of the
     # matching Params attribute (the Params attribute is never read).
     literal_expectations = {'batch_size': 100, 'clip_duration_ms': 1000}
     # Order mirrors the historical sequence of checks, so the first reported
     # mismatch is unchanged.
     flag_names = (
         'lr_schedule',
         'optimizer',
         'background_volume',
         'l2_weight_decay',
         'background_frequency',
         'split_data',
         'silence_percentage',
         'unknown_percentage',
         'time_shift_ms',
         'testing_percentage',
         'validation_percentage',
         'how_many_training_steps',
         'eval_step_interval',
         'learning_rate',
         'batch_size',
         'optimizer_epsilon',
         'resample',
         'sample_rate',
         'volume_resample',
         'clip_duration_ms',
         'window_size_ms',
         'window_stride_ms',
         'preprocess',
         'feature_type',
         'preemph',
         'window_type',
         'mel_lower_edge_hertz',
         'mel_upper_edge_hertz',
         'log_epsilon',
         'dct_num_features',
         'use_tf_fft',
         'mel_non_zero_only',
         'fft_magnitude_squared',
         'mel_num_bins',
         'use_spec_augment',
         'time_masks_number',
         'time_mask_max_size',
         'frequency_masks_number',
         'frequency_mask_max_size',
         'return_softmax',
         'use_spec_cutout',
         'spec_cutout_masks_number',
         'spec_cutout_time_mask_size',
         'spec_cutout_frequency_mask_size',
         'pick_deterministically',
     )
     for flag_name in flag_names:
         # Read the flag first, then the expectation, matching the
         # argument evaluation order of a direct assertEqual call.
         actual = getattr(FLAGS, flag_name)
         if flag_name in literal_expectations:
             expected = literal_expectations[flag_name]
         else:
             expected = getattr(defaults, flag_name)
         self.assertEqual(actual, expected)
Code example #2
0
File: datagen.py  Project: wdjose/keyword-transformer
    def __init__(self, batch_size: int = 512, version: int = 1, preprocess: str = "raw") -> None:
        """Build the speech-commands configuration and the audio data processor.

        Starts from ``model_params.Params()`` and overrides the data location,
        feature-extraction and augmentation settings. It then post-processes
        them with ``model_flags.update_flags`` (stored as ``self.flags``) and
        creates ``input_data.AudioProcessor(self.flags)``.

        Args:
          batch_size: Written to ``self.flags.batch_size`` after
            ``update_flags`` has run.
          version: Selects the data directory
            ``data/google-speech-commands/data{version}``. When ``version == 3``
            a fixed 35-word ``wanted_words`` list is used and ``split_data``
            is set to 0. NOTE(review): presumably versions 1 and 2 correspond
            to the v0.01 / v0.02 archives linked below; confirm.
          preprocess: Copied verbatim into ``FLAGS.preprocess``.

        Side effects:
          Calls ``absl.logging.set_verbosity`` and ``tf1.disable_eager_execution()``.
          Both change library-wide state, not just this instance (assuming
          ``tf1`` is ``tensorflow.compat.v1``). Also opens a ``tf1.Session``
          stored on ``self.sess``. NOTE(review): confirm the session is closed
          by whoever owns this object.
        """

        # Set PATH to data sets (for example to speech commands V2):
        # They can be downloaded from
        # https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz
        # https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz
        # https://docs.google.com/uc?export=download&id=1OAN3h4uffi5HS7eb7goklWeI2XPm1jCS
        # Download the archive, then extract it so that DATA_PATH below
        # (data/google-speech-commands/data{version}) points at its contents.
        dataset = "google-speech-commands"
        DATA_PATH = os.path.join("data", dataset, "data{}".format(version))

        FLAGS = model_params.Params()
        FLAGS.data_dir = DATA_PATH
        # NOTE(review): this value is later passed to absl.logging.set_verbosity.
        # If `logging` here is the stdlib module, logging.ERROR == 40, which absl
        # treats as a very high (verbose) level rather than "errors only".
        # Confirm that `logging` is absl.logging.
        FLAGS.verbosity = logging.ERROR

        # Set wanted words for the V2_35 dataset (35 keywords) and disable
        # split_data for it.
        if version == 3:
            FLAGS.wanted_words = 'visual,wow,learn,backward,dog,two,left,happy,nine,go,up,bed,stop,one,zero,tree,seven,on,four,bird,right,eight,no,six,forward,house,marvin,sheila,five,off,three,down,cat,follow,yes'
            FLAGS.split_data = 0

        # Speech feature extractor properties: 30 ms windows with a 10 ms
        # stride, 80 mel bins capped at 7600 Hz, 40 DCT features.
        FLAGS.mel_upper_edge_hertz = 7600
        FLAGS.window_size_ms = 30.0
        FLAGS.window_stride_ms = 10.0
        FLAGS.mel_num_bins = 80
        FLAGS.dct_num_features = 40
        FLAGS.feature_type = 'mfcc_tf'
        FLAGS.preprocess = preprocess

        # For numerical equivalence of streaming and non-streaming models set
        # this to 1; for real streaming use cases set it to 0.
        FLAGS.causal_data_frame_padding = 0

        FLAGS.use_tf_fft = True
        # Always the inverse of use_tf_fft, so False with the setting above.
        FLAGS.mel_non_zero_only = not FLAGS.use_tf_fft

        # Data augmentation parameters (resampling, time shift, SpecAugment
        # time/frequency masks).
        FLAGS.resample = 0.15
        FLAGS.time_shift_ms = 100
        FLAGS.use_spec_augment = 1
        FLAGS.time_masks_number = 2
        FLAGS.time_mask_max_size = 25
        FLAGS.frequency_masks_number = 2
        FLAGS.frequency_mask_max_size = 7
        FLAGS.pick_deterministically = 1

        # From here on the object returned by update_flags() is used, not FLAGS.
        self.flags = model_flags.update_flags(FLAGS)
        # NOTE(review): `import absl` alone does not import the absl.logging
        # submodule; this works only if absl.logging was already imported
        # elsewhere. `from absl import logging` would make that explicit.
        import absl
        absl.logging.set_verbosity(self.flags.verbosity)


        # batch_size is applied after update_flags() has run.
        self.flags.batch_size = batch_size
        # Convert the time-shift window from milliseconds to samples
        # (assumes sample_rate is in Hz).
        self.time_shift_samples = int((self.flags.time_shift_ms * self.flags.sample_rate) / 1000)


        # Switch TF to graph (non-eager) mode before creating the v1 Session.
        tf1.disable_eager_execution()
        # device_count={'GPU': 0} exposes no GPU devices to the session (CPU only).
        config = tf1.ConfigProto(device_count={'GPU': 0})
        self.sess = tf1.Session(config=config)
        # tf1.keras.backend.set_session(self.sess)

        self.audio_processor = input_data.AudioProcessor(self.flags)