class SpeechCommandsDataset(AudioDataset):
    """
    Speech Commands Dataset extraction and set creation

    Splits the raw per-class folders of the dataset into the set wav folders
    (one folder per entry of dataset_cfg['split_percs']) and extracts MFCC or
    raw-sample features from them.
    """

    def __init__(self, dataset_cfg, feature_params, collect_wavs=False, verbose=False):
        """
        dataset_cfg: dataset config dict (reads e.g. 'recreate', 'file_ext',
            'split_percs', 'shuffle_wavs', 'n_examples', 'add_noise', 'enable_plot',
            'filter_damaged_files')
        feature_params: feature parameter dict, also passed to the FeatureExtractor
        collect_wavs: if True, pre-processed wavs are appended to self.pre_wavs
            during feature extraction
        verbose: print per-file information
        """
        # parent init
        super().__init__(dataset_cfg, feature_params, collect_wavs=collect_wavs, verbose=verbose)

        # feature extractor
        self.feature_extractor = FeatureExtractor(feature_params=self.feature_params)

        # short vars
        self.N = self.feature_extractor.N
        self.hop = self.feature_extractor.hop

        # create plot paths if they do not exist yet
        create_folder(list(self.plot_paths.values()))

        # (re)create the sets if requested or if the wav folders fail the existence / empty check
        if self.dataset_cfg['recreate'] or not check_folders_existance(self.wav_folders, empty_check=True):

            # delete old data
            delete_files_in_path(self.wav_folders, file_ext=self.dataset_cfg['file_ext'])

            # create wav folders
            create_folder(self.wav_folders)

            # create sets (specific to dataset)
            self.create_sets()

        # get audio files from sets
        self.get_audiofiles()
        self.get_annotation_files()

    def create_sets(self):
        """
        copy wav files from the dataset path into the wav folders, splitting each
        class according to dataset_cfg['split_percs']

        split sizes are rounded down; files left over by the rounding are not copied
        """
        # get all class directories except the ones starting with _
        class_dirs = glob(self.dataset_path + '[!_]*/')

        for class_dir in class_dirs:

            # the label is the name of the class directory
            label = class_dir.split('/')[-2]

            # get all audio files of this class
            wavs = glob(class_dir + '*' + self.dataset_cfg['file_ext'])

            # number of files per split and cumulative split end positions
            n_split = (len(wavs) * np.array(self.dataset_cfg['split_percs'])).astype(int)
            n_split_pos = np.cumsum(n_split)

            # print some info
            print("label: [{}]\tn_split: [{}]\ttotal:[{}]".format(label, n_split, np.sum(n_split)))

            # shuffle
            if self.dataset_cfg['shuffle_wavs']:
                np.random.shuffle(wavs)

            # only splits that have both a size and a target folder can receive files
            n_sets = min(len(n_split_pos), len(self.wav_folders))

            # index of the current split / wav folder
            p = 0

            for i, wav in enumerate(wavs):

                # advance past every exhausted split; a while (not an if) is needed so
                # that splits rounded down to zero files are skipped instead of receiving one
                while p < n_sets and i >= n_split_pos[p]:
                    p += 1

                # stop if out of range (leftovers from rounding the split sizes down)
                if p >= n_sets:
                    break

                # wav name without directory and extension
                wav_name = wav.split('/')[-1].split('.')[0]

                # copy file to its set folder
                copyfile(wav, self.wav_folders[p] + label + str(i) + '--' + wav_name + self.dataset_cfg['file_ext'])

    def extract_features(self):
        """
        extract features for every set and save them to the set's feature file

        MFCC features are extracted if feature_params['use_mfcc_features'] is set,
        raw samples otherwise; noise is only added to MFCC data
        """
        print("\n--feature extraction:")

        # create folder structure
        create_folder(self.feature_folders)

        for i, (set_name, wavs, annos) in enumerate(zip(self.set_names, self.set_audio_files, self.set_annotation_files)):

            print("{}) extract set: {} with label num: {}".format(i, set_name, len(wavs)))

            # examples per class for this set, scaled by the set's split percentage
            n_examples = int(self.dataset_cfg['n_examples'] * self.dataset_cfg['split_percs'][i])

            # extract data
            if self.feature_params['use_mfcc_features']:
                x, y, t, index = self.extract_mfcc_data(wavs=wavs, annos=annos, n_examples=n_examples, set_name=set_name)
            else:
                x, y, t, index = self.extract_raw_data(wavs=wavs, annos=annos, n_examples=n_examples, set_name=set_name)

            # add noise if requested
            if self.dataset_cfg['add_noise'] and self.feature_params['use_mfcc_features']:
                x, y, index = self.add_noise_to_dataset(x, y, index, n_examples)

            # print label stats
            self.label_stats(y)

            # save data file
            np.savez(self.feature_files[i], x=x, y=y, t=t, index=index, params=self.feature_params)
            print("--save data to: ", self.feature_files[i])

    def _class_annotation_lookup(self, class_annos):
        """
        map 'label + file index' of each annotation file (.TextGrid) to its path;
        if several files share a key, the first one wins
        """
        lookup = {}
        for anno_file in class_annos:
            _, anno_index, anno_label = self.file_naming_extraction(anno_file, file_ext='.TextGrid')
            lookup.setdefault(anno_label + anno_index, anno_file)
        return lookup

    @staticmethod
    def _example_name(label, file_index, set_name):
        """
        name of an example used for plots and collected wavs:
        label + file index, plus '_' + set name when a set name is given
        """
        name = label + str(file_index)
        return name + '_' + set_name if set_name is not None else name

    def extract_mfcc_data(self, wavs, annos, n_examples, set_name=None):
        """
        extract mfcc data from wav-files
        wavs must be in a 2D-array [[wavs_class1], [wavs_class2]] so that n_examples will work properly

        at most n_examples examples are taken per class
        returns (mfcc_data, label_data, None, index_data) with mfcc_data of shape
        [n x channel_size x feature_size x frame_size]
        """
        # per-example feature chunks, stacked once at the end (stacking inside the loop is quadratic)
        mfcc_chunks, label_data, index_data = [], [], []

        for class_wavs, class_annos in zip(wavs, annos):

            # annotation files of this class
            anno_lookup = self._class_annotation_lookup(class_annos)

            # number of class examples
            num_class_examples = 0

            for wav in class_wavs:

                # extract file namings
                file_name, file_index, label = self.file_naming_extraction(wav, file_ext=self.dataset_cfg['file_ext'])

                # get annotation if available
                anno = anno_lookup.get(label + file_index)

                # load and pre-process audio
                x, wav_is_useless = self.wav_pre_processing(wav)
                if wav_is_useless:
                    continue

                # print some info
                if self.verbose:
                    print("wav: [{}] with label: [{}], samples=[{}], time=[{}]s".format(wav, label, len(x), len(x) / self.feature_params['fs']))

                # extract feature vectors
                mfcc, bon_pos = self.feature_extractor.extract_mfcc(x, reduce_to_best_onset=False)

                example_name = self._example_name(label, file_index, set_name)

                # collect wavs
                if self.collect_wavs:
                    self.pre_wavs.append((librosa.util.normalize(x), example_name, bon_pos))

                # plot mfcc features
                plot_mfcc_profile(x, self.feature_params['fs'], self.feature_extractor.N, self.feature_extractor.hop, mfcc, anno_file=anno, onsets=None, bon_pos=bon_pos, mient=None, minreg=None, frame_size=self.frame_size, plot_path=self.plot_paths['mfcc'], name=example_name, enable_plot=self.dataset_cfg['enable_plot'])

                # skip damaged files if requested
                if self.dataset_cfg['filter_damaged_files'] and self.detect_damaged_file(mfcc, wav):
                    continue

                # keep frame_size frames starting at bon_pos
                mfcc_chunks.append(mfcc[np.newaxis, :, :, bon_pos:bon_pos + self.frame_size])
                label_data.append(label)
                index_data.append(label + file_index)

                # stop if desired examples are reached
                num_class_examples += 1
                if num_class_examples >= n_examples:
                    break

        # the empty float64 container keeps the output shape and dtype when no example was taken
        empty = np.empty(shape=(0, self.channel_size, self.feature_size, self.frame_size), dtype=np.float64)
        mfcc_data = np.vstack([empty] + mfcc_chunks)

        return mfcc_data, label_data, None, index_data

    def extract_raw_data(self, wavs, annos, n_examples, set_name=None):
        """
        raw data extraction

        wavs / annos are per-class lists as in extract_mfcc_data; at most n_examples
        examples are taken per class
        returns (raw_data, label_data, target_data, index_data) with raw_data of shape
        [n x channel_size x raw_frame_size] and the quantized targets in target_data
        """
        # per-example chunks, stacked once at the end (stacking inside the loop is quadratic)
        raw_chunks, target_chunks, label_data, index_data = [], [], [], []

        for class_wavs, class_annos in zip(wavs, annos):

            # annotation files of this class
            anno_lookup = self._class_annotation_lookup(class_annos)

            # number of class examples
            num_class_examples = 0

            for wav in class_wavs:

                # extract file namings
                file_name, file_index, label = self.file_naming_extraction(wav, file_ext=self.dataset_cfg['file_ext'])

                # get annotation if available (currently not used further)
                anno = anno_lookup.get(label + file_index)

                # load and pre-process audio
                x, wav_is_useless = self.wav_pre_processing(wav)
                if wav_is_useless:
                    continue

                # print some info
                if self.verbose:
                    print("wav: [{}] with label: [{}], samples=[{}], time=[{}]s".format(wav, label, len(x), len(x) / self.feature_params['fs']))

                # extract raw samples from region of energy
                raw, bon_pos = self.feature_extractor.get_best_raw_samples(x)

                # add dither and do normalization
                raw = self.wav_post_processing(raw)

                # quantize data
                t = self.feature_extractor.quantize(raw)

                # plot waveform
                if self.dataset_cfg['enable_plot']:
                    plot_waveform(x, self.feature_params['fs'], bon_samples=[bon_pos, bon_pos + self.raw_frame_size], title=label + file_index, plot_path=self.plot_paths['waveform'], name=label + file_index, show_plot=False, close_plot=True)

                # collect wavs (best onset position converted from samples to frames)
                if self.collect_wavs:
                    self.pre_wavs.append((librosa.util.normalize(x), self._example_name(label, file_index, set_name), bon_pos / self.hop))

                raw_chunks.append(raw[np.newaxis, :])
                target_chunks.append(t)
                label_data.append(label)
                index_data.append(label + file_index)

                # stop if desired examples are reached
                num_class_examples += 1
                if num_class_examples >= n_examples:
                    break

        # the empty containers keep output shapes and dtypes when no example was taken
        raw_data = np.vstack([np.empty(shape=(0, self.channel_size, self.raw_frame_size), dtype=np.float64)] + raw_chunks)
        target_data = np.vstack([np.empty(shape=(0, self.raw_frame_size), dtype=np.int64)] + target_chunks)

        return raw_data, label_data, target_data, index_data

    def detect_damaged_file(self, mfcc, wav):
        """
        heuristic check whether a file is damaged, based on a score over its mfcc features

        mfcc appears to be indexed [channel, feature, frame] (cf. the mfcc_data container
        in extract_mfcc_data); the score is appended to self.damaged_score_list and
        damaged files are appended as (wav, score) to self.damaged_file_list
        returns True if the score exceeds the limit
        """
        if mfcc.shape[1] == 39:
            # 39 features: sum of absolute values of features 37 and 38 of channel 0
            # (the energy deltas according to the original comment)
            z_est, z_lim = np.sum(np.abs(mfcc[0, 37:39, :])), 60
        else:
            # otherwise: first feature of channel 0 weighted by its absolute frame-to-frame change
            z_est, z_lim = mfcc[0, 0, :-1] @ np.abs(np.diff(mfcc[0, 0, :])).T, 3.5

        # add score to list
        self.damaged_score_list.append(z_est)

        # damaged file
        is_damaged = z_est > z_lim

        # add to damaged file list
        if is_damaged:
            self.damaged_file_list.append((wav, z_est))

        return is_damaged
class TestBench:
    """
    test bench class for evaluating models

    Loads a trained model and its saved parameters, then tests its shift and
    noise invariance on the test wavs listed in the config.
    """

    def __init__(self, cfg_tb, test_model_path, root_path='./'):
        """
        cfg_tb: test bench config dict; read here: 'paths', 'model_file_names',
            'params_file_name', 'test_wavs' (the test methods also read 'enable_plot',
            'enable_info_prints', 'snrs', 'shift_frame_step')
        test_model_path: folder of the model under test; the model name is taken from
            its second-to-last '/'-separated part, so it presumably ends with '/'
        root_path: prefix for the configured output paths and test wavs

        raises IndexError if none of cfg_tb['model_file_names'] exists in the model folder
        """
        # arguments
        self.cfg_tb = cfg_tb
        self.test_model_path = test_model_path
        self.root_path = root_path

        # shortcuts (set from the loaded net params below)
        self.feature_params, self.data_size = None, None

        # output paths relative to the root path
        self.paths = {k: self.root_path + v for k, v in self.cfg_tb['paths'].items()}

        # test model name: last folder of the model path
        self.test_model_name = self.test_model_path.split('/')[-2]

        # model files in the model folder that are also listed in the config
        available_files = [f.split('/')[-1] for f in glob(self.test_model_path + '*model.pth')]
        self.model_files = [self.test_model_path + f for f in available_files if f in self.cfg_tb['model_file_names']]

        # pick just the first one (raises IndexError if there is none)
        self.model_file = self.model_files[0]

        # param file
        self.params_file = self.test_model_path + self.cfg_tb['params_file_name']

        # test wavs relative to the root path
        self.test_wavs = [self.root_path + wav for wav in self.cfg_tb['test_wavs']]

        # create output folders
        create_folder(list(self.paths.values()))

        # parameter loading; the with block closes the .npz file handle afterwards
        # NOTE: allow_pickle=True unpickles stored objects - only load trusted param files
        with np.load(self.params_file, allow_pickle=True) as net_params:

            # [()] unwraps the stored (presumably 0-d object) arrays
            self.nn_arch = net_params['nn_arch'][()]
            self.train_params = net_params['train_params'][()]
            self.class_dict = net_params['class_dict'][()]

            # legacy stuff
            self.data_size, self.feature_params = legacy_adjustments_net_params(net_params)

        # init feature extractor
        self.feature_extractor = FeatureExtractor(self.feature_params)

        # init net handler
        self.net_handler = NetHandler(nn_arch=self.nn_arch, class_dict=self.class_dict, data_size=self.data_size, use_cpu=True)

        # load model
        self.net_handler.load_models(model_files=[self.model_file])

        # set evaluation mode
        self.net_handler.set_eval_mode()

    def test_invariances(self):
        """
        run the shift and noise invariance tests on every test wav and plot the
        collected results into the model folder
        """
        all_labels = []
        all_corrects_shift, all_corrects_noise = [], []
        all_probs_shift, all_probs_noise = [], []

        print("\n--Test Bench\ntest model: [{}]".format(self.test_model_name))

        for wav in self.test_wavs:

            print("\ntest wav: ", wav)

            # the actual label is extracted from the file name
            file_name, file_index, actual_label = self.file_naming_extraction(wav)
            all_labels.append(actual_label)

            # read audio from file
            x_wav, _ = librosa.load(wav, sr=self.feature_params['fs'])

            # shift invariance
            corrects_shift, probs_shift = self.test_shift_invariance(x_wav, actual_label)

            # noise invariance
            corrects_noise, probs_noise = self.test_noise_invariance(x_wav, actual_label, mu=0)

            # collect results
            all_corrects_shift.append(corrects_shift)
            all_corrects_noise.append(corrects_noise)
            all_probs_shift.append(probs_shift)
            all_probs_noise.append(probs_noise)

        # some prints
        if self.cfg_tb['enable_info_prints']:
            print("\nall_corrects_shift:\n", all_corrects_shift)
            print("\nall_corrects_noise:\n", all_corrects_noise)
            print("\nall labels: ", all_labels)

        # plots
        plot_test_bench_shift(x=all_corrects_shift, y=all_labels, context='bench-shift-2', title='shift ' + self.test_model_name, plot_path=self.test_model_path, name='test_bench_shift', show_plot=False)
        plot_test_bench_shift(x=all_probs_shift, y=all_labels, context='bench-shift', title='shift ' + self.test_model_name, plot_path=self.test_model_path, name='test_bench_shift-prob', show_plot=False)
        plot_test_bench_noise(x=all_corrects_noise, y=all_labels, snrs=self.cfg_tb['snrs'], context='bench-noise-2', title='noise ' + self.test_model_name, plot_path=self.test_model_path, name='test_bench_noise', show_plot=False)
        plot_test_bench_noise(x=all_probs_noise, y=all_labels, snrs=self.cfg_tb['snrs'], context='bench-noise', title='noise ' + self.test_model_name, plot_path=self.test_model_path, name='test_bench_noise-prob', show_plot=False)

    def test_noise_invariance(self, x_wav, actual_label, mu=0):
        """
        test model against noise invariance

        for every snr [dB] in cfg_tb['snrs'], gaussian noise with mean mu is added to
        x_wav; its standard deviation is derived from the signal power so that, for
        mu=0, the expected snr matches
        returns (corrects, probs): per snr, 1/0 whether the prediction was correct and
        the model output at the actual label's class index
        """
        pred_label_list, probs = [], []

        # origin
        # NOTE(review): the origin plot goes to the 'shift_wavs' path, not 'noise_wavs' - confirm intended
        if self.cfg_tb['enable_plot']:
            plot_waveform(x_wav, self.feature_params['fs'], title='origin actual: [{}]'.format(actual_label), plot_path=self.paths['shift_wavs'], name='{}_origin'.format(actual_label))

        # effective signal power (independent of the snr, so computed once)
        p_x_eff = x_wav @ x_wav.T / len(x_wav)

        for snr in self.cfg_tb['snrs']:

            # noise standard deviation for the desired snr
            sigma = np.sqrt(p_x_eff / (10**(snr / 10)))

            # add noise
            x_noise = x_wav + np.random.normal(mu, sigma, len(x_wav))

            # feature extraction (wavenet works on raw samples)
            if self.net_handler.nn_arch != 'wavenet':
                x, _ = self.feature_extractor.extract_mfcc(x_noise, reduce_to_best_onset=True)
            else:
                x, _ = self.feature_extractor.get_best_raw_samples(x_noise, add_channel_dim=True)

            # classify
            _, o, pred_label = self.net_handler.classify_sample(x)

            # append predicted label and output at the actual class
            pred_label_list.append(pred_label)
            probs.append(float(o[0, self.class_dict[actual_label]]))

            # plot wavs
            if self.cfg_tb['enable_plot']:
                plot_waveform(x_noise, self.feature_params['fs'], title='snr: [{}] actual: [{}] pred: [{}]'.format(snr, actual_label, pred_label), plot_path=self.paths['noise_wavs'], name='{}_snr{}'.format(actual_label, snr))

        # correct list
        corrects = [int(actual_label == l) for l in pred_label_list]

        print("test bench noise acc: ", np.sum(corrects) / len(corrects))

        return corrects, probs

    def test_shift_invariance(self, x_wav, actual_label):
        """
        test model against shift invariance

        the signal (mfcc features, or raw samples for wavenet) is cut into windows of
        self.data_size advanced by cfg_tb['shift_frame_step'] frames, and each window
        is classified
        returns (corrects, probs): per window, 1/0 whether the prediction was correct
        and the model output at the actual label's class index
        """
        pred_label_list, probs = [], []

        if self.net_handler.nn_arch != 'wavenet':
            # mfcc features of the whole signal, windowed in frames
            x, _ = self.feature_extractor.extract_mfcc(x_wav, reduce_to_best_onset=False)
            x_win = np.squeeze(view_as_windows(x, self.data_size, step=self.cfg_tb['shift_frame_step']), axis=(0, 1))
        else:
            # raw samples; the frame step is converted to samples via the hop size
            x = x_wav[np.newaxis, :]
            x_win = np.squeeze(view_as_windows(x, self.data_size, step=self.cfg_tb['shift_frame_step'] * self.feature_extractor.hop), axis=0)

        for i, x_frame in enumerate(x_win):

            # classify
            _, o, pred_label = self.net_handler.classify_sample(x_frame)

            # append predicted label and output at the actual class
            pred_label_list.append(pred_label)
            probs.append(float(o[0, self.class_dict[actual_label]]))

            # plot waveform of the window
            if self.cfg_tb['enable_plot']:
                time_s = frames_to_sample(i * self.cfg_tb['shift_frame_step'], self.feature_params['fs'], self.feature_extractor.hop)
                time_e = frames_to_sample(i * self.cfg_tb['shift_frame_step'] + self.feature_params['frame_size'], self.feature_params['fs'], self.feature_extractor.hop)
                plot_waveform(x_wav[time_s:time_e], self.feature_params['fs'], title='frame{} actual: [{}] pred: [{}]'.format(i, actual_label, pred_label), plot_path=self.paths['shift_wavs'], name='{}_frame{}'.format(actual_label, i))

        # correct list
        corrects = [int(actual_label == l) for l in pred_label_list]

        print("test bench shift acc: ", np.sum(corrects) / len(corrects))

        return corrects, probs

    def file_naming_extraction(self, file, file_ext='.wav'):
        """
        extract file name ergo label

        e.g. 'path/left12.wav' -> ('left12.wav', '12', 'left'):
        file_name is the base name incl. extension, file_index the name with all
        letters, spaces and the extension removed, label the name with the trailing
        digits and extension removed
        raises IndexError if the path does not contain a match with file_ext
        """
        # extract filename
        file_name = re.findall(r'[\w+ 0-9]*' + re.escape(file_ext), file)[0]

        # extract file index from filename
        file_index = re.sub(r'[a-z A-Z]|(' + re.escape(file_ext) + r')', '', file_name)

        # extract label from filename
        label = re.sub(r'([0-9]*' + re.escape(file_ext) + r')', '', file_name)

        return file_name, file_index, label