Example #1
def audio_set_wavs(cfg):
    """
  audio set wavs
  """

    # plot path
    plot_path = '../docu/thesis/5_exp/figs/'

    # audio sets
    a1 = AudioDataset(cfg['datasets']['speech_commands'],
                      cfg['feature_params'],
                      root_path='../')
    a2 = AudioDataset(cfg['datasets']['my_recordings'],
                      cfg['feature_params'],
                      root_path='../')

    # feature extractor
    feature_extractor = FeatureExtractor(cfg['feature_params'])

    # get audio files
    a1.get_audiofiles()

    # random seed
    np.random.seed(1234)
    r = np.random.randint(low=0, high=150, size=len(a1.set_audio_files[1]))

    wav_grid = []

    # process wavs
    for wav in sorted([
            label_wavs[r[i]]
            for i, label_wavs in enumerate(a1.set_audio_files[1])
    ]):

        # info
        print("wav: ", wav)

        # get raw
        x, _ = a1.wav_pre_processing(wav)

        # extract feature vectors [m x l]
        _, bon_pos = feature_extractor.extract_mfcc(x,
                                                    reduce_to_best_onset=False)

        # append to wav grid
        wav_grid.append((librosa.util.normalize(x),
                         re.sub(r'[0-9]+-', '',
                                wav.split('/')[-1].split('.')[0]), bon_pos))

    # plot wav grid
    plot_wav_grid(wav_grid,
                  feature_params=a1.feature_params,
                  grid_size=(6, 5),
                  plot_path=plot_path,
                  name='wav_grid_c30',
                  show_plot=True)
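
A minimal usage sketch for this example; loading the config via yaml and the file name are assumptions based on the keys accessed above:

import yaml

# load the config with the dataset and feature keys used above (path is an assumption)
cfg = yaml.safe_load(open('./config.yaml'))

# plot the wav grid of both audio sets
audio_set_wavs(cfg)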
Example #2
def time_measurements(x, u, feature_params):
    """
  time measurements
  """

    # create feature extractor
    feature_extractor = FeatureExtractor(feature_params)

    # list of measured times
    delta_time_list = []

    for i in range(100):

        # measure extraction time - start
        start_time = time.time()

        # time: 0.030081419944763182
        #y = calc_mfcc39(x, fs, N=400, hop=160, n_filter_bands=32, n_ceps_coeff=12, use_librosa=False)

        # time: 0.009309711456298829
        #y = calc_mfcc39(x, fs, N=400, hop=160, n_filter_bands=32, n_ceps_coeff=12, use_librosa=True)

        # time: 0.00014737367630004883
        #y = (custom_dct(np.log(u), n_filter_bands).T)

        # time: 6.929159164428711e-05
        #y = scipy.fftpack.dct(np.log(u), type=2, n=n_filter_bands, axis=1, norm=None, overwrite_x=False).T

        # time: 0.00418839693069458 *** winner
        y, _ = feature_extractor.extract_mfcc(x)

        # time: 0.015525884628295898
        #y, _ = feature_extractor.extract_mfcc39_slow(x)

        # time: 0.011266257762908936s
        #y = custom_stft(x, N=N, hop=hop, norm=True)

        # time: 0.0005800390243530274s
        #y = 2 / N * librosa.stft(x, n_fft=N, hop_length=hop, win_length=N, window='hann', center=True, dtype=None, pad_mode='reflect')

        # time: 0.00044193744659423826s
        #_, _, y = scipy.signal.stft(x, fs=1.0, window='hann', nperseg=N, noverlap=N-hop, nfft=N, detrend=False, return_onesided=True, boundary='zeros', padded=False, axis=- 1)

        # result of measured time diff
        delta_time_list.append(time.time() - start_time)

    # data shape
    print("y: ", y.shape)

    # mean extraction time
    print("delta_time: ", np.mean(delta_time_list))
Example #3
def showcase_wavs(cfg,
                  raw_plot=True,
                  spec_plot=True,
                  mfcc_plot=True,
                  show_plot=False):
    """
  showcase wavs
  """

    # plot path
    plot_path = '../docu/thesis/3_signal/figs/'

    # change params
    feature_params = cfg['feature_params'].copy()
    feature_params['n_ceps_coeff'] = 32
    feature_params['norm_features'] = True

    # init feature extractor
    feature_extractor = FeatureExtractor(feature_params)

    # wav, anno dir
    wav_dir, anno_dir = '../ignore/my_recordings/showcase_wavs/', '../ignore/my_recordings/showcase_wavs/annotation/'

    # analyze some wavs
    for wav, anno in zip(glob(wav_dir + '*.wav'),
                         glob(anno_dir + '*.TextGrid')):

        # info
        print("\nwav: ", wav), print("anno: ", anno)

        # load file
        x, _ = librosa.load(wav, sr=feature_params['fs'])

        # raw waveform
        if raw_plot:
            plot_waveform(x,
                          feature_params['fs'],
                          anno_file=anno,
                          hop=feature_extractor.hop,
                          plot_path=plot_path,
                          name='signal_raw_' +
                          wav.split('/')[-1].split('.')[0] + '_my',
                          show_plot=show_plot)

        # spectrogram
        if spec_plot:
            plot_spec_profile(x,
                              feature_extractor.calc_spectogram(x).T,
                              feature_params['fs'],
                              feature_extractor.N,
                              feature_extractor.hop,
                              anno_file=anno,
                              plot_path=plot_path,
                              title=wav.split('/')[-1].split('.')[0] + '_my',
                              name='signal_spec-lin_' +
                              wav.split('/')[-1].split('.')[0] + '_my',
                              show_plot=show_plot)
            plot_spec_profile(x,
                              feature_extractor.calc_spectogram(x).T,
                              feature_params['fs'],
                              feature_extractor.N,
                              feature_extractor.hop,
                              log_scale=True,
                              anno_file=anno,
                              plot_path=plot_path,
                              title=wav.split('/')[-1].split('.')[0] + '_my',
                              name='signal_spec-log_' +
                              wav.split('/')[-1].split('.')[0] + '_my',
                              show_plot=show_plot)

        # mfcc
        if mfcc_plot:
            mfcc, bon_pos = feature_extractor.extract_mfcc(
                x, reduce_to_best_onset=False)
            plot_mfcc_profile(x,
                              cfg['feature_params']['fs'],
                              feature_extractor.N,
                              feature_extractor.hop,
                              mfcc,
                              anno_file=anno,
                              sep_features=True,
                              bon_pos=bon_pos,
                              frame_size=cfg['feature_params']['frame_size'],
                              plot_path=plot_path,
                              name='signal_mfcc_' +
                              wav.split('/')[-1].split('.')[0] + '_my',
                              close_plot=False,
                              show_plot=show_plot)
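
The stem extraction wav.split('/')[-1].split('.')[0] recurs in every plot name above; an equivalent with pathlib, shown only as a style note:

from pathlib import Path

# 'path/to/left0.wav' -> 'left0', same as wav.split('/')[-1].split('.')[0]
print(Path('path/to/left0.wav').stem)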
Example #4
    # assumed context: this snippet is part of a larger script where
    # feature_params comes from the config and
    #   feature_extractor = FeatureExtractor(feature_params)
    #   fs, N, hop = feature_params['fs'], feature_extractor.N, feature_extractor.hop
    n_filter_bands = 16
    n_ceps_coeff = 12

    # --
    # test signal

    # generate test signal
    x = some_test_signal(fs, t=1, save_to_file=False)

    # stft
    x_stft = 2 / N * librosa.stft(
        x, n_fft=N, hop_length=hop, win_length=N, window='hann',
        center=False).T

    # mfcc
    mfcc, _ = feature_extractor.extract_mfcc(x)
    if len(mfcc.shape) == 3:
        mfcc = np.squeeze(mfcc.reshape(1, -1, mfcc.shape[2]))

    print("mfcc: ", mfcc.shape)

    print("e_sum: ", energy_with_sum(mfcc)), print(
        "e_m: ", energy_with_matrix(mfcc)), print("e_e: ", energy_einsum(mfcc))
    #print("pn: ", power_spec_naive(x_stft)), print("pc: ", power_spec_conj(x_stft))

    print("\ntime measures: ")
    time_measure_callable(mfcc, energy_einsum)
    time_measure_callable(mfcc, energy_with_sum)
    time_measure_callable(mfcc, energy_with_matrix)
    #time_measure_callable(x_stft, power_spec_naive), time_measure_callable(x_stft, power_spec_conj)

    # --
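
The energy helpers and the timing helper compared above are not part of this snippet; a sketch of what they might look like for a 2D mfcc array [m x l], where the names follow the calls above but the bodies are assumptions:

import time
import numpy as np

def energy_with_sum(x):
    # frame energies by element-wise square and sum over the coefficient axis
    return np.sum(x * x, axis=0)

def energy_with_matrix(x):
    # frame energies from the diagonal of the gram matrix
    return np.diag(x.T @ x)

def energy_einsum(x):
    # frame energies via einsum without building the full gram matrix
    return np.einsum('ij,ij->j', x, x)

def time_measure_callable(x, f, n_runs=100):
    # mean runtime of f(x) over n_runs calls
    start_time = time.time()
    for _ in range(n_runs): f(x)
    print("{}: {}s".format(f.__name__, (time.time() - start_time) / n_runs))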
Example #5
class Mic():
    """
  Mic class
  """
    def __init__(self,
                 classifier,
                 mic_params,
                 is_audio_record=False,
                 root_path='./'):

        # arguments
        self.classifier = classifier
        self.mic_params = mic_params
        self.is_audio_record = is_audio_record
        self.root_path = root_path

        # plot path
        self.plot_path = self.root_path + self.mic_params['plot_path']

        # create folder for plot path
        create_folder([self.plot_path])

        # shortcuts
        self.feature_params = classifier.feature_params

        # feature extractor
        self.feature_extractor = FeatureExtractor(self.feature_params)

        # windowing params
        self.N, self.hop = self.feature_extractor.N, self.feature_extractor.hop

        # queue
        self.q = queue.Queue()

        # collector
        self.collector = Collector(
            N=self.N,
            hop=self.hop,
            frame_size=self.feature_params['frame_size'],
            update_size=self.mic_params['update_size'],
            frames_post=self.mic_params['frames_post'],
            is_audio_record=self.is_audio_record)

        # device
        self.device = (self.mic_params['device'] if self.mic_params['select_device']
                       else sd.default.device[0])

        # downsample factor from device rate to feature sampling rate
        self.downsample = self.mic_params['fs_device'] // self.feature_params['fs']

        # get input devices
        self.input_dev_dict = self.extract_devices()

        # show devices
        print("\ndevice list: \n", sd.query_devices())
        print("\ninput devs: ", self.input_dev_dict.keys())

        # stream
        self.stream = None

        # change device flag
        self.change_device_flag = False

    def load_user_settings(self, user_setting_file):
        """
    load user settings like device and energy threshold
    """

        # load user settings
        user_settings = yaml.safe_load(open(
            user_setting_file)) if os.path.isfile(user_setting_file) else {}

        # update mic params
        self.mic_params.update(user_settings)

        # device
        self.device = (self.mic_params['device'] if self.mic_params['select_device']
                       else sd.default.device[0])

    def init_stream(self):
        """
    init stream
    """
        self.stream = sd.InputStream(device=self.device,
                                     samplerate=self.mic_params['fs_device'],
                                     blocksize=int(self.hop * self.downsample),
                                     channels=self.mic_params['channels'],
                                     callback=self.callback_mic)
        self.change_device_flag = False

    def change_device(self, device):
        """
    change to device
    """
        self.change_device_flag = True
        self.device = device

    def extract_devices(self):
        """
    extract only input devices
    """
        return {
            i: dev
            for i, dev in enumerate(sd.query_devices())
            if dev['max_input_channels']
        }

    def callback_mic(self, indata, frames, time, status):
        """
    Input Stream Callback
    """

        # debug
        if status: print(status)

        #self.q.put(indata[:, 0].copy())

        # add to queue with primitive downsampling
        self.q.put(indata[::self.downsample, 0].copy())

    def clear_mic_queue(self):
        """
    clear the queue after classification
    """

        # process data
        for i in range(self.q.qsize()):

            # get chunk
            x = self.q.get()

            # energy for the archive
            e, _ = self.onset_energy_level(
                x, alpha=self.mic_params['energy_thresh'])

            # update collector
            self.collector.x_all = np.append(self.collector.x_all, x)
            self.collector.e_all = np.append(self.collector.e_all, e)

    def read_mic_data(self):
        """
    reads the input from the queue
    """

        # init
        x_collect = np.empty(shape=(0), dtype=np.float32)
        e_collect = np.empty(shape=(0), dtype=np.float32)

        # onset flag
        is_onset = False

        # process data
        if self.q.qsize():

            for i in range(self.q.qsize()):

                # get data
                x = self.q.get()

                # append chunk
                x_collect = np.append(x_collect, x.copy())

                # append energy level
                e_collect = np.append(e_collect, 1)

            # detect onset
            e_onset, is_onset = self.onset_energy_level(
                x_collect, alpha=self.mic_params['energy_thresh'])

            # collection update
            self.collector.update_collect(x_collect.copy(),
                                          e=e_collect.copy() * e_onset,
                                          on=is_onset)

        return is_onset

    def onset_energy_level(self, x, alpha=0.01):
        """
    onset detection with energy level
    x: [n x c]
    n: samples
    c: channels
    """

        e = x.T @ x / len(x)

        return e, e > alpha

    def update_read_command(self):
        """
    update mic
    """

        # read chunk
        is_onset = self.read_mic_data()

        # onset was detected
        if is_onset:

            # start collection of items
            self.collector.start_collecting()

        # collection is full
        if self.collector.is_full():

            # read out collection
            x_onset = self.collector.read_collection()

            # extract features
            mfcc_bon, bon_pos = self.feature_extractor.extract_mfcc(x_onset)

            # classify collection
            y_hat, label = self.classifier.classify(mfcc_bon)

            # best onset segment in samples
            x_bon = x_onset[bon_pos * self.hop:
                            (bon_pos + self.feature_params['frame_size']) * self.hop]

            # plot
            plot_mfcc_profile(x_bon,
                              self.feature_params['fs'],
                              self.N,
                              self.hop,
                              mfcc_bon,
                              frame_size=self.feature_params['frame_size'],
                              plot_path=self.plot_path,
                              name='collect-{}_label-{}'.format(
                                  self.collector.collection_counter, label),
                              enable_plot=self.mic_params['enable_plot'])

            # clear read queue
            self.clear_mic_queue()

            return label

        return None

    def stop_mic_condition(self, time_duration):
        """
    stop mic if time duration is exceeded (memory issue in recording)
    """

        return (self.collector.x_all.shape[0] >=
                (time_duration *
                 self.feature_params['fs'])) and self.is_audio_record

    def save_audio_file(self):
        """
    saves collection to audio file
    """

        # has not recorded audio
        if not self.is_audio_record:
            print("***you did not set the record flag!")
            return

        import soundfile

        # save audio
        soundfile.write('{}out_audio.wav'.format(self.plot_path),
                        self.collector.x_all,
                        self.feature_params['fs'],
                        subtype=None,
                        endian=None,
                        format=None,
                        closefd=True)
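
A sketch of how this class might be driven; the classifier object and the mic_params values are assumptions, only the keys follow the accesses in the class:

# mic params with the keys read by the class (values are assumptions)
mic_params = {'plot_path': 'plots/', 'update_size': 10, 'frames_post': 5,
              'select_device': False, 'device': 0, 'fs_device': 48000,
              'channels': 1, 'energy_thresh': 0.0001, 'enable_plot': False}

# init mic with a trained classifier (assumed to exist)
mic = Mic(classifier=classifier, mic_params=mic_params)
mic.init_stream()

# read the stream until a command is classified
with mic.stream:
    label = None
    while label is None:
        label = mic.update_read_command()
print("classified: ", label)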
Example #6
class SpeechCommandsDataset(AudioDataset):
  """
  Speech Commands Dataset extraction and set creation
  """

  def __init__(self, dataset_cfg, feature_params, collect_wavs=False, verbose=False):

    # parent init
    super().__init__(dataset_cfg, feature_params, collect_wavs=collect_wavs, verbose=verbose)

    # feature extractor
    self.feature_extractor = FeatureExtractor(feature_params=self.feature_params)

    # short vars
    self.N = self.feature_extractor.N
    self.hop = self.feature_extractor.hop

    # create plot paths if they do not exist yet
    create_folder(list(self.plot_paths.values()))

    # recreate
    if self.dataset_cfg['recreate'] or not check_folders_existance(self.wav_folders, empty_check=True):

      # delete old data
      delete_files_in_path(self.wav_folders, file_ext=self.dataset_cfg['file_ext'])

      # create the wav folders
      create_folder(self.wav_folders)

      # create sets (specific to dataset)
      self.create_sets()

    # get audio files from sets
    self.get_audiofiles()
    self.get_annotation_files()


  def create_sets(self):
    """
    copy wav files from dataset path to wav folders with splitting
    """

    # get all class directories except the ones starting with _
    class_dirs = glob(self.dataset_path + '[!_]*/')

    # run through all class directories
    for class_dir in class_dirs:

      # extract label
      label = class_dir.split('/')[-2]

      # get all .wav files
      wavs = glob(class_dir + '*' + self.dataset_cfg['file_ext'])

      # calculate split numbers in train, test, eval and split position
      n_split = (len(wavs) * np.array(self.dataset_cfg['split_percs'])).astype(int)
      n_split_pos = np.cumsum(n_split)

      # print some info
      print("label: [{}]\tn_split: [{}]\ttotal:[{}]".format(label, n_split, np.sum(n_split)))

      # current split folder index
      p = 0

      # shuffle
      if self.dataset_cfg['shuffle_wavs']: np.random.shuffle(wavs)

      # run through each path
      for i, wav in enumerate(wavs):

        # move on to the next split folder
        if i >= n_split_pos[p]: p += 1
        # stop if out of range (can happen due to rounding)
        if p >= len(self.wav_folders): break

        # wav name
        wav_name = wav.split('/')[-1].split('.')[0]

        # copy files to folder
        copyfile(wav, self.wav_folders[p] + label + str(i) + '--' + wav_name + self.dataset_cfg['file_ext'])


  def extract_features(self):
    """
    extract mfcc features and save them
    """

    print("\n--feature extraction:")

    # create folder structure
    create_folder(self.feature_folders)

    for i, (set_name, wavs, annos) in enumerate(zip(self.set_names, self.set_audio_files, self.set_annotation_files)):

      print("{}) extract set: {} with label num: {}".format(i, set_name, len(wavs)))

      # number of examples for this split
      n_examples = int(self.dataset_cfg['n_examples'] * self.dataset_cfg['split_percs'][i])

      # extract data
      x, y, t, index = self.extract_mfcc_data(wavs=wavs, annos=annos, n_examples=n_examples, set_name=set_name) if self.feature_params['use_mfcc_features'] else self.extract_raw_data(wavs=wavs, annos=annos, n_examples=n_examples, set_name=set_name)

      # add noise if requested
      if self.dataset_cfg['add_noise'] and self.feature_params['use_mfcc_features']: x, y, index = self.add_noise_to_dataset(x, y, index, n_examples)

      # print label stats
      self.label_stats(y)

      # save mfcc data file
      np.savez(self.feature_files[i], x=x, y=y, t=t, index=index, params=self.feature_params)
      print("--save data to: ", self.feature_files[i])


  def extract_mfcc_data(self, wavs, annos, n_examples, set_name=None):
    """
    extract mfcc data from wav-files
    wavs must be in a 2D-array [[wavs_class1], [wavs_class2]] so that n_examples will work properly
    """

    # mfcc_data: [n x c x m x l], labels and index
    mfcc_data, label_data, index_data = np.empty(shape=(0, self.channel_size, self.feature_size, self.frame_size), dtype=np.float64), [], []

    # extract class wavs
    for class_wavs, class_annos in zip(wavs, annos):

      # class annotation file names extraction
      class_annos_file_names = [l + i for f, i, l in [self.file_naming_extraction(a, file_ext='.TextGrid') for a in class_annos]]

      # number of class examples
      num_class_examples = 0

      # run through each example in class wavs
      for wav in class_wavs:
        
        # extract file namings
        file_name, file_index, label = self.file_naming_extraction(wav, file_ext=self.dataset_cfg['file_ext'])

        # get annotation if available
        anno = None
        if label + file_index in class_annos_file_names: anno = class_annos[class_annos_file_names.index(label + file_index)]

        # load and pre-process audio
        x, wav_is_useless = self.wav_pre_processing(wav)
        if wav_is_useless: continue

        # print some info
        if self.verbose: print("wav: [{}] with label: [{}], samples=[{}], time=[{}]s".format(wav, label, len(x), len(x) / self.feature_params['fs']))

        # extract feature vectors [m x l]
        mfcc, bon_pos = self.feature_extractor.extract_mfcc(x, reduce_to_best_onset=False)

        # collect wavs
        if self.collect_wavs: self.pre_wavs.append((librosa.util.normalize(x), label + str(file_index) + '_' + set_name, bon_pos))

        # plot mfcc features
        plot_mfcc_profile(x, self.feature_params['fs'], self.feature_extractor.N, self.feature_extractor.hop, mfcc, anno_file=anno, onsets=None, bon_pos=bon_pos, mient=None, minreg=None, frame_size=self.frame_size, plot_path=self.plot_paths['mfcc'], name=label + str(file_index) + '_' + set_name, enable_plot=self.dataset_cfg['enable_plot'])

        # damaged file check
        if self.dataset_cfg['filter_damaged_files']:

          # handle damaged files
          if self.detect_damaged_file(mfcc, wav): continue

        # add to mfcc_data container
        mfcc_data = np.vstack((mfcc_data, mfcc[np.newaxis, :, :, bon_pos:bon_pos+self.frame_size]))
        label_data.append(label)
        index_data.append(label + file_index)

        # update number of examples per class
        num_class_examples += 1

        # stop if desired examples are reached
        if num_class_examples >= n_examples: break


    return mfcc_data, label_data, None, index_data


  def extract_raw_data(self, wavs, annos, n_examples, set_name=None):
    """
    raw data extraction
    """

    # raw data: [n x c x m], targets, labels and index
    raw_data, label_data, target_data, index_data = np.empty(shape=(0, self.channel_size, self.raw_frame_size), dtype=np.float64), [], np.empty(shape=(0, self.raw_frame_size), dtype=np.int64), []

    # extract class wavs
    for class_wavs, class_annos in zip(wavs, annos):

      # class annotation file names extraction
      class_annos_file_names = [l + i for f, i, l in [self.file_naming_extraction(a, file_ext='.TextGrid') for a in class_annos]]

      # number of class examples
      num_class_examples = 0

      # run through each example in class wavs
      for wav in class_wavs:
        
        # extract file namings
        file_name, file_index, label = self.file_naming_extraction(wav, file_ext=self.dataset_cfg['file_ext'])

        # get annotation if available
        anno = class_annos[class_annos_file_names.index(label + file_index)] if label + file_index in class_annos_file_names else None

        # load and pre-process audio
        x, wav_is_useless = self.wav_pre_processing(wav)
        if wav_is_useless: continue

        # print some info
        if self.verbose: print("wav: [{}] with label: [{}], samples=[{}], time=[{}]s".format(wav, label, len(x), len(x) / self.feature_params['fs']))

        # extract raw samples from region of energy
        raw, bon_pos = self.feature_extractor.get_best_raw_samples(x)

        # add dither and do normalization
        raw = self.wav_post_processing(raw)

        # quantize data
        t = self.feature_extractor.quantize(raw)

        # plot waveform
        if self.dataset_cfg['enable_plot']: plot_waveform(x, self.feature_params['fs'],  bon_samples=[bon_pos, bon_pos+self.raw_frame_size], title=label + file_index, plot_path=self.plot_paths['waveform'], name=label + file_index, show_plot=False, close_plot=True)

        # collect wavs
        if self.collect_wavs: self.pre_wavs.append((librosa.util.normalize(x), label + str(file_index) + '_' + set_name, bon_pos / self.hop))

        # add to raw data container
        raw_data = np.vstack((raw_data, raw[np.newaxis, :]))
        target_data = np.vstack((target_data, t))
        label_data.append(label)
        index_data.append(label + file_index)

        # update number of examples per class
        num_class_examples += 1

        # stop if desired examples are reached
        if num_class_examples >= n_examples: break

    return raw_data, label_data, target_data, index_data


  def detect_damaged_file(self, mfcc, wav):
    """
    detect if file is damaged
    """

    # energy calc
    #e = np.einsum('ij,ji->j', mfcc, mfcc.T)
    #e = e / np.max(e)

    # calculate damaged score of energy deltas
    if mfcc.shape[1] == 39: z_est, z_lim = np.sum(np.abs(mfcc[0, 37:39, :])), 60
    #if mfcc.shape[0] == 39: z_est = np.sum(mfcc[37:39, :] @ mfcc[37:39, :].T)
    #else: z_est = np.sum(np.abs(np.diff(mfcc[-1, :])))
    #else: z_est = np.diff(mfcc[-1, :]) @ np.diff(mfcc[-1, :]).T
    #else: z_est = np.sum(mfcc[0, :])
    #else: z_est = np.diff(mfcc[0, :]) @ np.diff(mfcc[0, :]).T
    #else: z_est = np.abs(np.diff(mfcc[0, :])) @ mfcc[0, :-1].T
    #else: z_est = np.sum(np.diff(mfcc, axis=1) @ np.diff(mfcc, axis=1).T)
    #else: z_est = np.sum(e)
    #else: z_est = np.diff(e) @ np.diff(e).T
    else: z_est, z_lim = mfcc[0, 0, :-1] @ np.abs(np.diff(mfcc[0, 0, :])).T, 3.5

    # add score to list
    self.damaged_score_list.append(z_est)

    # damaged file
    is_damaged = z_est > z_lim

    # add to damaged file list
    if is_damaged: self.damaged_file_list.append((wav, z_est))

    # return damaged indicator
    return is_damaged
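
The split arithmetic in create_sets can be checked in isolation; a worked sketch with assumed numbers:

import numpy as np

# e.g. 100 wavs and assumed split percentages for train, test, eval
split_percs = [0.8, 0.1, 0.1]
n_split = (100 * np.array(split_percs)).astype(int)   # -> [80 10 10]
n_split_pos = np.cumsum(n_split)                      # -> [80 90 100]

# wav i lands in the first folder p with i < n_split_pos[p]
for i in [0, 79, 80, 99]:
  p = int(np.searchsorted(n_split_pos, i, side='right'))
  print("wav {} -> folder {}".format(i, p))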
Example #7
class TestBench():
    """
  test bench class for evaluating models
  """
    def __init__(self, cfg_tb, test_model_path, root_path='./'):

        # arguments
        self.cfg_tb = cfg_tb
        self.test_model_path = test_model_path
        self.root_path = root_path

        # shortcuts
        self.feature_params, self.data_size = None, None

        # paths
        self.paths = dict(
            (k, self.root_path + v) for k, v in self.cfg_tb['paths'].items())

        # test model name
        self.test_model_name = self.test_model_path.split('/')[-2]

        # determine available model files
        model_files_av = [
            f.split('/')[-1] for f in glob(self.test_model_path + '*model.pth')
        ]

        # model file
        self.model_files = [
            self.test_model_path + f for f in model_files_av
            if f in self.cfg_tb['model_file_names']
        ]

        # pick the first one (there should only be a single match)
        self.model_file = self.model_files[0]

        # param file
        self.params_file = self.test_model_path + self.cfg_tb['params_file_name']

        # wavs
        self.test_wavs = [
            self.root_path + wav for wav in self.cfg_tb['test_wavs']
        ]

        # create folder
        create_folder(list(self.paths.values()))

        # parameter loading
        net_params = np.load(self.params_file, allow_pickle=True)

        # extract params
        self.nn_arch = net_params['nn_arch'][()]
        self.train_params = net_params['train_params'][()]
        self.class_dict = net_params['class_dict'][()]

        # legacy adjustments
        #self.data_size, self.feature_params = self.legacy_adjustments_tb(net_params)
        self.data_size, self.feature_params = legacy_adjustments_net_params(
            net_params)

        # init feature extractor
        self.feature_extractor = FeatureExtractor(self.feature_params)

        # init net handler
        self.net_handler = NetHandler(nn_arch=self.nn_arch,
                                      class_dict=self.class_dict,
                                      data_size=self.data_size,
                                      use_cpu=True)

        # load model
        self.net_handler.load_models(model_files=[self.model_file])

        # set evaluation mode
        self.net_handler.set_eval_mode()

    def test_invariances(self):
        """
    test all invariances
    """

        # init lists
        all_labels, all_corrects_shift, all_corrects_noise, all_probs_shift, all_probs_noise = [], [], [], [], []

        # test model
        print("\n--Test Bench\ntest model: [{}]".format(self.test_model_name))

        # go through each test wav
        for wav in self.test_wavs:

            # print message
            print("\ntest wav: ", wav)

            # file naming extraction
            file_name, file_index, actual_label = self.file_naming_extraction(
                wav)

            # update labels
            all_labels.append(actual_label)

            # read audio from file
            x_wav, _ = librosa.load(wav, sr=self.feature_params['fs'])

            # shift invariance
            corrects_shift, probs_shift = self.test_shift_invariance(
                x_wav, actual_label)

            # noise invariance
            corrects_noise, probs_noise = self.test_noise_invariance(
                x_wav, actual_label, mu=0)

            # collect corrects
            all_corrects_shift.append(corrects_shift)
            all_corrects_noise.append(corrects_noise)

            all_probs_shift.append(probs_shift)
            all_probs_noise.append(probs_noise)

        # some prints
        if self.cfg_tb['enable_info_prints']:
            print("\nall_corrects_shift:\n", all_corrects_shift)
            print("\nall_corrects_noise:\n", all_corrects_noise)
            print("\nall labels: ", all_labels)

        # plots
        plot_test_bench_shift(x=all_corrects_shift,
                              y=all_labels,
                              context='bench-shift-2',
                              title='shift ' + self.test_model_name,
                              plot_path=self.test_model_path,
                              name='test_bench_shift',
                              show_plot=False)
        plot_test_bench_shift(x=all_probs_shift,
                              y=all_labels,
                              context='bench-shift',
                              title='shift ' + self.test_model_name,
                              plot_path=self.test_model_path,
                              name='test_bench_shift-prob',
                              show_plot=False)

        plot_test_bench_noise(x=all_corrects_noise,
                              y=all_labels,
                              snrs=self.cfg_tb['snrs'],
                              context='bench-noise-2',
                              title='noise ' + self.test_model_name,
                              plot_path=self.test_model_path,
                              name='test_bench_noise',
                              show_plot=False)
        plot_test_bench_noise(x=all_probs_noise,
                              y=all_labels,
                              snrs=self.cfg_tb['snrs'],
                              context='bench-noise',
                              title='noise ' + self.test_model_name,
                              plot_path=self.test_model_path,
                              name='test_bench_noise-prob',
                              show_plot=False)

    def test_noise_invariance(self, x_wav, actual_label, mu=0):
        """
    test model against noise invariance
    """

        # init lists
        pred_label_list, probs = [], []

        # origin
        if self.cfg_tb['enable_plot']:
            plot_waveform(x_wav,
                          self.feature_params['fs'],
                          title='origin actual: [{}]'.format(actual_label),
                          plot_path=self.paths['shift_wavs'],
                          name='{}_origin'.format(actual_label))

        # test model with different snr values
        for snr in self.cfg_tb['snrs']:

            # signal power
            p_x_eff = x_wav @ x_wav.T / len(x_wav)

            # calculate noise signal power
            sigma = np.sqrt(p_x_eff / (10**(snr / 10)))

            # noise generation
            n = np.random.normal(mu, sigma, len(x_wav))

            # add noise
            x_noise = x_wav + n

            # noise signal power
            p_n_eff = n @ n.T / len(n)

            # print energy info
            # print("sigma: ", sigma), print("p_x: ", p_x_eff), print("p_n: ", p_n_eff), print("db: ", 10 * np.log10(p_x_eff / p_n_eff))

            # feature extraction
            if self.net_handler.nn_arch != 'wavenet':
                x, _ = self.feature_extractor.extract_mfcc(
                    x_noise, reduce_to_best_onset=True)
            else:
                x, _ = self.feature_extractor.get_best_raw_samples(
                    x_noise, add_channel_dim=True)

            # classify
            y_hat, o, pred_label = self.net_handler.classify_sample(x)

            # append predicted label and probs
            pred_label_list.append(pred_label)
            probs.append(float(o[0, self.class_dict[actual_label]]))

            # plot wavs
            if self.cfg_tb['enable_plot']:
                plot_waveform(x_noise,
                              self.feature_params['fs'],
                              title='snr: [{}] actual: [{}] pred: [{}]'.format(
                                  snr, actual_label, pred_label),
                              plot_path=self.paths['noise_wavs'],
                              name='{}_snr{}'.format(actual_label, snr))

        # correct list
        corrects = [int(actual_label == l) for l in pred_label_list]

        # print message
        print("test bench noise acc: ", np.sum(corrects) / len(corrects))

        return corrects, probs

    def test_shift_invariance(self, x_wav, actual_label):
        """
    test model against shift invariance
    """

        # init lists
        pred_label_list, probs = [], []

        # feature extraction
        if self.net_handler.nn_arch != 'wavenet':
            x, _ = self.feature_extractor.extract_mfcc(
                x_wav, reduce_to_best_onset=False)
        else:
            x, _ = x_wav[np.newaxis, :], 0

        # windowed view with the given shift step
        if self.net_handler.nn_arch != 'wavenet':
            x_win = np.squeeze(view_as_windows(
                x, self.data_size, step=self.cfg_tb['shift_frame_step']),
                axis=(0, 1))
        else:
            x_win = np.squeeze(view_as_windows(
                x, self.data_size,
                step=self.cfg_tb['shift_frame_step'] * self.feature_extractor.hop),
                axis=0)

        for i, x in enumerate(x_win):

            # classify
            y_hat, o, pred_label = self.net_handler.classify_sample(x)

            # append predicted label
            pred_label_list.append(pred_label)
            probs.append(float(o[0, self.class_dict[actual_label]]))

            # plot
            time_s = frames_to_sample(i * self.cfg_tb['shift_frame_step'],
                                      self.feature_params['fs'],
                                      self.feature_extractor.hop)
            time_e = frames_to_sample(
                i * self.cfg_tb['shift_frame_step'] +
                self.feature_params['frame_size'], self.feature_params['fs'],
                self.feature_extractor.hop)

            # plot waveform
            if self.cfg_tb['enable_plot']:
                plot_waveform(x_wav[time_s:time_e],
                              self.feature_params['fs'],
                              title='frame{} actual: [{}] pred: [{}]'.format(
                                  i, actual_label, pred_label),
                              plot_path=self.paths['shift_wavs'],
                              name='{}_frame{}'.format(actual_label, i))

        # correct list
        corrects = [int(actual_label == l) for l in pred_label_list]

        # print message
        print("test bench shift acc: ", np.sum(corrects) / len(corrects))

        return corrects, probs

    def file_naming_extraction(self, file, file_ext='.wav'):
        """
    extract file name ergo label
    """

        # extract filename
        file_name = re.findall(r'[\w+ 0-9]*' + re.escape(file_ext), file)[0]

        # extract file index from filename
        file_index = re.sub(r'[a-z A-Z]|(' + re.escape(file_ext) + r')', '',
                            file_name)

        # extract label from filename
        label = re.sub(r'([0-9]*' + re.escape(file_ext) + r')', '', file_name)

        return file_name, file_index, label
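
The noise scaling in test_noise_invariance follows the usual SNR definition, sigma = sqrt(p_x / 10^(snr/10)); a standalone sketch that verifies the measured SNR against the target (the test signal is an assumption):

import numpy as np

# white test signal and a target snr in dB (assumed values)
x = np.random.randn(16000)
snr = 6

# noise std so that 10 * log10(p_x / p_n) is close to snr
p_x = x @ x.T / len(x)
sigma = np.sqrt(p_x / (10**(snr / 10)))
n = np.random.normal(0, sigma, len(x))

# measured snr
p_n = n @ n.T / len(n)
print("measured snr: ", 10 * np.log10(p_x / p_n))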