def test_default_values(self):
  params = model_params.Params()

  # Validate default parameters against the flag defaults to avoid regressions.
  self.assertEqual(FLAGS.lr_schedule, params.lr_schedule)
  self.assertEqual(FLAGS.optimizer, params.optimizer)
  self.assertEqual(FLAGS.background_volume, params.background_volume)
  self.assertEqual(FLAGS.l2_weight_decay, params.l2_weight_decay)
  self.assertEqual(FLAGS.background_frequency, params.background_frequency)
  self.assertEqual(FLAGS.split_data, params.split_data)
  self.assertEqual(FLAGS.silence_percentage, params.silence_percentage)
  self.assertEqual(FLAGS.unknown_percentage, params.unknown_percentage)
  self.assertEqual(FLAGS.time_shift_ms, params.time_shift_ms)
  self.assertEqual(FLAGS.testing_percentage, params.testing_percentage)
  self.assertEqual(FLAGS.validation_percentage, params.validation_percentage)
  self.assertEqual(FLAGS.how_many_training_steps,
                   params.how_many_training_steps)
  self.assertEqual(FLAGS.eval_step_interval, params.eval_step_interval)
  self.assertEqual(FLAGS.learning_rate, params.learning_rate)
  self.assertEqual(FLAGS.batch_size, params.batch_size)
  self.assertEqual(FLAGS.optimizer_epsilon, params.optimizer_epsilon)
  self.assertEqual(FLAGS.resample, params.resample)
  self.assertEqual(FLAGS.sample_rate, params.sample_rate)
  self.assertEqual(FLAGS.volume_resample, params.volume_resample)
  self.assertEqual(FLAGS.clip_duration_ms, params.clip_duration_ms)
  self.assertEqual(FLAGS.window_size_ms, params.window_size_ms)
  self.assertEqual(FLAGS.window_stride_ms, params.window_stride_ms)
  self.assertEqual(FLAGS.preprocess, params.preprocess)
  self.assertEqual(FLAGS.feature_type, params.feature_type)
  self.assertEqual(FLAGS.preemph, params.preemph)
  self.assertEqual(FLAGS.window_type, params.window_type)
  self.assertEqual(FLAGS.mel_lower_edge_hertz, params.mel_lower_edge_hertz)
  self.assertEqual(FLAGS.mel_upper_edge_hertz, params.mel_upper_edge_hertz)
  self.assertEqual(FLAGS.log_epsilon, params.log_epsilon)
  self.assertEqual(FLAGS.dct_num_features, params.dct_num_features)
  self.assertEqual(FLAGS.use_tf_fft, params.use_tf_fft)
  self.assertEqual(FLAGS.mel_non_zero_only, params.mel_non_zero_only)
  self.assertEqual(FLAGS.fft_magnitude_squared, params.fft_magnitude_squared)
  self.assertEqual(FLAGS.mel_num_bins, params.mel_num_bins)
  self.assertEqual(FLAGS.use_spec_augment, params.use_spec_augment)
  self.assertEqual(FLAGS.time_masks_number, params.time_masks_number)
  self.assertEqual(FLAGS.time_mask_max_size, params.time_mask_max_size)
  self.assertEqual(FLAGS.frequency_masks_number, params.frequency_masks_number)
  self.assertEqual(FLAGS.frequency_mask_max_size,
                   params.frequency_mask_max_size)
  self.assertEqual(FLAGS.return_softmax, params.return_softmax)
  self.assertEqual(FLAGS.use_spec_cutout, params.use_spec_cutout)
  self.assertEqual(FLAGS.spec_cutout_masks_number,
                   params.spec_cutout_masks_number)
  self.assertEqual(FLAGS.spec_cutout_time_mask_size,
                   params.spec_cutout_time_mask_size)
  self.assertEqual(FLAGS.spec_cutout_frequency_mask_size,
                   params.spec_cutout_frequency_mask_size)
  self.assertEqual(FLAGS.pick_deterministically, params.pick_deterministically)
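# Note (added for clarity, not part of the original test): the origin of FLAGS
# is not shown in this section; it is presumably the parsed absl flags from the
# training flag definitions. Under that assumption, each assertion above
# cross-checks that the defaults in model_params.Params() stay in sync with the
# command-line defaults, so changing either side without the other breaks the
# test.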
def __init__(self, batch_size=512, version=1, preprocess="raw"):
  # Set the path to the data sets (for example, speech commands V1 or V2).
  # They can be downloaded from:
  # https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz
  # https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz
  # https://docs.google.com/uc?export=download&id=1OAN3h4uffi5HS7eb7goklWeI2XPm1jCS
  # The files should be downloaded and then extracted into the
  # google-speech-commands directory.
  dataset = "google-speech-commands"
  DATA_PATH = os.path.join("data", dataset, "data{}".format(version))

  FLAGS = model_params.Params()
  FLAGS.data_dir = DATA_PATH
  FLAGS.verbosity = logging.ERROR

  # Set wanted words for the V2_35 dataset.
  if version == 3:
    FLAGS.wanted_words = (
        'visual,wow,learn,backward,dog,two,left,happy,nine,go,up,bed,stop,'
        'one,zero,tree,seven,on,four,bird,right,eight,no,six,forward,house,'
        'marvin,sheila,five,off,three,down,cat,follow,yes')
    FLAGS.split_data = 0

  # Set speech feature extractor properties.
  FLAGS.mel_upper_edge_hertz = 7600
  FLAGS.window_size_ms = 30.0
  FLAGS.window_stride_ms = 10.0
  FLAGS.mel_num_bins = 80
  FLAGS.dct_num_features = 40
  FLAGS.feature_type = 'mfcc_tf'
  FLAGS.preprocess = preprocess

  # For numerical correctness of streaming and non-streaming models set it
  # to 1, but for a real streaming use case set it to 0.
  FLAGS.causal_data_frame_padding = 0

  FLAGS.use_tf_fft = True
  FLAGS.mel_non_zero_only = not FLAGS.use_tf_fft

  # Data augmentation parameters.
  FLAGS.resample = 0.15
  FLAGS.time_shift_ms = 100
  FLAGS.use_spec_augment = 1
  FLAGS.time_masks_number = 2
  FLAGS.time_mask_max_size = 25
  FLAGS.frequency_masks_number = 2
  FLAGS.frequency_mask_max_size = 7
  FLAGS.pick_deterministically = 1

  self.flags = model_flags.update_flags(FLAGS)
  import absl
  absl.logging.set_verbosity(self.flags.verbosity)
  self.flags.batch_size = batch_size
  self.time_shift_samples = int(
      (self.flags.time_shift_ms * self.flags.sample_rate) / 1000)

  tf1.disable_eager_execution()
  config = tf1.ConfigProto(device_count={'GPU': 0})
  self.sess = tf1.Session(config=config)
  # tf1.keras.backend.set_session(self.sess)
  self.audio_processor = input_data.AudioProcessor(self.flags)
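# Usage sketch (added; not part of the original file). The class name
# SpeechCommandsData below is a hypothetical stand-in for whatever class this
# __init__ belongs to. With the V2 archive extracted under
# data/google-speech-commands/data2, construction wires up the parsed flags, a
# TF1 session and the AudioProcessor over the dataset, e.g.:
#
#   data = SpeechCommandsData(batch_size=100, version=2, preprocess='raw')
#   data.flags.batch_size        # 100 (overrides the 512 default)
#   data.time_shift_samples      # 1600, assuming the 16 kHz sample_rate default
#   data.audio_processor.set_size('training')  # number of training examples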