Example #1
class SpeechCommandsDataset(AudioDataset):
  """
  Speech Commands Dataset extraction and set creation
  """

  def __init__(self, dataset_cfg, feature_params, collect_wavs=False, verbose=False):

    # parent init
    super().__init__(dataset_cfg, feature_params, collect_wavs=collect_wavs, verbose=verbose)

    # feature extractor
    self.feature_extractor = FeatureExtractor(feature_params=self.feature_params)

    # short vars
    self.N = self.feature_extractor.N
    self.hop = self.feature_extractor.hop

    # create plot paths if they do not already exist
    create_folder(list(self.plot_paths.values()))

    # recreate
    if self.dataset_cfg['recreate'] or not check_folders_existance(self.wav_folders, empty_check=True):

      # delete old data
      delete_files_in_path(self.wav_folders, file_ext=self.dataset_cfg['file_ext'])

      # create wav folders
      create_folder(self.wav_folders)

      # create sets (specific to dataset)
      self.create_sets()

    # get audio files from sets
    self.get_audiofiles()
    self.get_annotation_files()


  def create_sets(self):
    """
    copy wav files from the dataset path into the wav folders, split into sets
    """

    # get all class directories except the ones starting with _
    class_dirs = glob(self.dataset_path + '[!_]*/')

    # run through all class directories
    for class_dir in class_dirs:

      # extract label
      label = class_dir.split('/')[-2]

      # get all .wav files
      wavs = glob(class_dir + '*' + self.dataset_cfg['file_ext'])

      # calculate the number of files per split (train, test, eval) and the cumulative split positions
      n_split = (len(wavs) * np.array(self.dataset_cfg['split_percs'])).astype(int)
      n_split_pos = np.cumsum(n_split)

      # print some info
      print("label: [{}]\tn_split: [{}]\ttotal:[{}]".format(label, n_split, np.sum(n_split)))

      # current split folder index
      p = 0

      # shuffle
      if self.dataset_cfg['shuffle_wavs']: np.random.shuffle(wavs)

      # run through each path
      for i, wav in enumerate(wavs):

        # split in new path
        if i >= n_split_pos[p]: p += 1
        # stop if out of range (can happen due to rounding errors)
        if p >= len(self.wav_folders): break

        # wav name
        wav_name = wav.split('/')[-1].split('.')[0]

        # copy files to folder
        copyfile(wav, self.wav_folders[p] + label + str(i) + '--' + wav_name + self.dataset_cfg['file_ext'])


  def extract_features(self):
    """
    extract mfcc features and save them
    """

    print("\n--feature extraction:")

    # create folder structure
    create_folder(self.feature_folders)

    for i, (set_name, wavs, annos) in enumerate(zip(self.set_names, self.set_audio_files, self.set_annotation_files)):

      print("{}) extract set: {} with label num: {}".format(i, set_name, len(wavs)))

      # number of examples for this split
      n_examples = int(self.dataset_cfg['n_examples'] * self.dataset_cfg['split_percs'][i])

      # extract data
      x, y, t, index = self.extract_mfcc_data(wavs=wavs, annos=annos, n_examples=n_examples, set_name=set_name) if self.feature_params['use_mfcc_features'] else self.extract_raw_data(wavs=wavs, annos=annos, n_examples=n_examples, set_name=set_name)

      # add noise if requested
      if self.dataset_cfg['add_noise'] and self.feature_params['use_mfcc_features']: x, y, index = self.add_noise_to_dataset(x, y, index, n_examples)

      # print label stats
      self.label_stats(y)

      # save mfcc data file
      np.savez(self.feature_files[i], x=x, y=y, t=t, index=index, params=self.feature_params)
      print("--save data to: ", self.feature_files[i])


  def extract_mfcc_data(self, wavs, annos, n_examples, set_name=None):
    """
    extract mfcc data from wav-files
    wavs must be grouped per class: [[wavs_class1], [wavs_class2], ...] so that n_examples is applied per class
    """

    # mfcc_data: [n x m x l], labels and index
    mfcc_data, label_data, index_data = np.empty(shape=(0, self.channel_size, self.feature_size, self.frame_size), dtype=np.float64), [], []

    # extract class wavs
    for class_wavs, class_annos in zip(wavs, annos):

      # class annotation file names extraction
      class_annos_file_names = [l + i for f, i, l in [self.file_naming_extraction(a, file_ext='.TextGrid') for a in class_annos]]

      # number of class examples
      num_class_examples = 0

      # run through each example in class wavs
      for wav in class_wavs:
        
        # extract file namings
        file_name, file_index, label = self.file_naming_extraction(wav, file_ext=self.dataset_cfg['file_ext'])

        # get annotation if available
        anno = None
        if label + file_index in class_annos_file_names: anno = class_annos[class_annos_file_names.index(label + file_index)]

        # load and pre-process audio
        x, wav_is_useless = self.wav_pre_processing(wav)
        if wav_is_useless: continue

        # print some info
        if self.verbose: print("wav: [{}] with label: [{}], samples=[{}], time=[{}]s".format(wav, label, len(x), len(x) / self.feature_params['fs']))

        # extract feature vectors [m x l]
        mfcc, bon_pos = self.feature_extractor.extract_mfcc(x, reduce_to_best_onset=False)

        # collect wavs
        if self.collect_wavs: self.pre_wavs.append((librosa.util.normalize(x), label + str(file_index) + '_' + set_name, bon_pos))

        # plot mfcc features
        plot_mfcc_profile(x, self.feature_params['fs'], self.feature_extractor.N, self.feature_extractor.hop, mfcc, anno_file=anno, onsets=None, bon_pos=bon_pos, mient=None, minreg=None, frame_size=self.frame_size, plot_path=self.plot_paths['mfcc'], name=label + str(file_index) + '_' + set_name, enable_plot=self.dataset_cfg['enable_plot'])

        # damaged file check
        if self.dataset_cfg['filter_damaged_files']:

          # handle damaged files
          if self.detect_damaged_file(mfcc, wav): continue

        # add to mfcc_data container
        mfcc_data = np.vstack((mfcc_data, mfcc[np.newaxis, :, :, bon_pos:bon_pos+self.frame_size]))
        label_data.append(label)
        index_data.append(label + file_index)

        # update number of examples per class
        num_class_examples += 1

        # stop if desired examples are reached
        if num_class_examples >= n_examples: break


    return mfcc_data, label_data, None, index_data


  def extract_raw_data(self, wavs, annos, n_examples, set_name=None):
    """
    raw data extraction
    """

    # raw data: [n x m], labels and index
    raw_data, label_data, target_data, index_data = np.empty(shape=(0, self.channel_size, self.raw_frame_size), dtype=np.float64), [], np.empty(shape=(0, self.raw_frame_size), dtype=np.int64), []

    # extract class wavs
    for class_wavs, class_annos in zip(wavs, annos):

      # class annotation file names extraction
      class_annos_file_names = [l + i for f, i, l in [self.file_naming_extraction(a, file_ext='.TextGrid') for a in class_annos]]

      # number of class examples
      num_class_examples = 0

      # run through each example in class wavs
      for wav in class_wavs:
        
        # extract file namings
        file_name, file_index, label = self.file_naming_extraction(wav, file_ext=self.dataset_cfg['file_ext'])

        # get annotation if available
        anno = class_annos[class_annos_file_names.index(label + file_index)] if label + file_index in class_annos_file_names else None

        # load and pre-process audio
        x, wav_is_useless = self.wav_pre_processing(wav)
        if wav_is_useless: continue

        # print some info
        if self.verbose: print("wav: [{}] with label: [{}], samples=[{}], time=[{}]s".format(wav, label, len(x), len(x) / self.feature_params['fs']))

        # extract raw samples from region of energy
        raw, bon_pos = self.feature_extractor.get_best_raw_samples(x)

        # add dither and do normalization
        raw = self.wav_post_processing(raw)

        # quantize data
        t = self.feature_extractor.quantize(raw)

        # plot waveform
        if self.dataset_cfg['enable_plot']: plot_waveform(x, self.feature_params['fs'],  bon_samples=[bon_pos, bon_pos+self.raw_frame_size], title=label + file_index, plot_path=self.plot_paths['waveform'], name=label + file_index, show_plot=False, close_plot=True)

        # collect wavs
        if self.collect_wavs: self.pre_wavs.append((librosa.util.normalize(x), label + str(file_index) + '_' + set_name, bon_pos / self.hop))

        # add to raw_data container
        raw_data = np.vstack((raw_data, raw[np.newaxis, :]))
        target_data = np.vstack((target_data, t))
        label_data.append(label)
        index_data.append(label + file_index)

        # update number of examples per class
        num_class_examples += 1

        # stop if desired examples are reached
        if num_class_examples >= n_examples: break

    return raw_data, label_data, target_data, index_data


  def detect_damaged_file(self, mfcc, wav):
    """
    detect if file is damaged
    """

    # energy calc
    #e = np.einsum('ij,ji->j', mfcc, mfcc.T)
    #e = e / np.max(e)

    # calculate damaged score of energy deltas
    if mfcc.shape[1] == 39: z_est, z_lim = np.sum(np.abs(mfcc[0, 37:39, :])), 60
    #if mfcc.shape[0] == 39: z_est = np.sum(mfcc[37:39, :] @ mfcc[37:39, :].T)
    #else: z_est = np.sum(np.abs(np.diff(mfcc[-1, :])))
    #else: z_est = np.diff(mfcc[-1, :]) @ np.diff(mfcc[-1, :]).T
    #else: z_est = np.sum(mfcc[0, :])
    #else: z_est = np.diff(mfcc[0, :]) @ np.diff(mfcc[0, :]).T
    #else: z_est = np.abs(np.diff(mfcc[0, :])) @ mfcc[0, :-1].T
    #else: z_est = np.sum(np.diff(mfcc, axis=1) @ np.diff(mfcc, axis=1).T)
    #else: z_est = np.sum(e)
    #else: z_est = np.diff(e) @ np.diff(e).T
    else: z_est, z_lim = mfcc[0, 0, :-1] @ np.abs(np.diff(mfcc[0, 0, :])).T, 3.5

    # add score to list
    self.damaged_score_list.append(z_est)

    # damaged file
    is_damaged = z_est > z_lim

    # add to damaged file list
    if is_damaged: self.damaged_file_list.append((wav, z_est))

    # return score and damaged indicator
    return is_damaged
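
A minimal usage sketch for the class above; the configuration keys mirror those referenced in the code, while the concrete values, the 'use_mfcc_features' flag, and any additional keys required by the parent AudioDataset are assumptions for illustration.

# hypothetical usage sketch -- values are assumptions; the parent AudioDataset
# will typically require further keys (wav/plot/feature folder paths, set names, ...)
dataset_cfg = {
  'recreate': False,               # rebuild the wav split folders from the raw dataset
  'file_ext': '.wav',
  'split_percs': [0.8, 0.1, 0.1],  # train / test / eval split fractions
  'shuffle_wavs': True,
  'n_examples': 500,               # examples per class, distributed over the splits
  'add_noise': False,
  'enable_plot': False,
  'filter_damaged_files': True,
}
feature_params = {'fs': 16000, 'use_mfcc_features': True}

# create the dataset handler (copies wavs into split folders) and extract features per set
speech_dataset = SpeechCommandsDataset(dataset_cfg, feature_params, verbose=True)
speech_dataset.extract_features()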
Example #2
class TestBench():
    """
  test bench class for evaluating models
  """
    def __init__(self, cfg_tb, test_model_path, root_path='./'):

        # arguments
        self.cfg_tb = cfg_tb
        self.test_model_path = test_model_path
        self.root_path = root_path

        # shortcuts
        self.feature_params, self.data_size = None, None

        # paths
        self.paths = dict(
            (k, self.root_path + v) for k, v in self.cfg_tb['paths'].items())

        # test model name (last folder of the model path)
        self.test_model_name = self.test_model_path.split('/')[-2]

        # determine available model files
        model_files_av = [
            f.split('/')[-1] for f in glob(self.test_model_path + '*model.pth')
        ]

        # model file
        self.model_files = [
            self.test_model_path + f for f in model_files_av
            if f in self.cfg_tb['model_file_names']
        ]

        # pick just the first one (errors should not occur)
        self.model_file = self.model_files[0]

        # param file
        self.params_file = self.test_model_path + self.cfg_tb['params_file_name']

        # wavs
        self.test_wavs = [
            self.root_path + wav for wav in self.cfg_tb['test_wavs']
        ]

        # create folder
        create_folder(list(self.paths.values()))

        # parameter loading
        net_params = np.load(self.params_file, allow_pickle=True)

        # extract params
        self.nn_arch = net_params['nn_arch'][()]
        self.train_params = net_params['train_params'][()]
        self.class_dict = net_params['class_dict'][()]

        # legacy stuff
        #self.data_size, self.feature_params = self.legacy_adjustments_tb(net_params)
        self.data_size, self.feature_params = legacy_adjustments_net_params(net_params)

        # init feature extractor
        self.feature_extractor = FeatureExtractor(self.feature_params)

        # init net handler
        self.net_handler = NetHandler(nn_arch=self.nn_arch,
                                      class_dict=self.class_dict,
                                      data_size=self.data_size,
                                      use_cpu=True)

        # load model
        self.net_handler.load_models(model_files=[self.model_file])

        # set evaluation mode
        self.net_handler.set_eval_mode()

    def test_invariances(self):
        """
    test all invariances
    """

        # init lists
        all_labels, all_corrects_shift, all_corrects_noise, all_probs_shift, all_probs_noise = [], [], [], [], []

        # test model
        print("\n--Test Bench\ntest model: [{}]".format(self.test_model_name))

        # go through each test wav
        for wav in self.test_wavs:

            # print message
            print("\ntest wav: ", wav)

            # file naming extraction
            file_name, file_index, actual_label = self.file_naming_extraction(
                wav)

            # update labels
            all_labels.append(actual_label)

            # read audio from file
            x_wav, _ = librosa.load(wav, sr=self.feature_params['fs'])

            # shift invariance
            corrects_shift, probs_shift = self.test_shift_invariance(
                x_wav, actual_label)

            # noise invariance
            corrects_noise, probs_noise = self.test_noise_invariance(
                x_wav, actual_label, mu=0)

            # collect corrects
            all_corrects_shift.append(corrects_shift)
            all_corrects_noise.append(corrects_noise)

            all_probs_shift.append(probs_shift)
            all_probs_noise.append(probs_noise)

        # some prints
        if self.cfg_tb['enable_info_prints']:
            print("\nall_corrects_shift:\n", all_corrects_shift), print(
                "\nall_corrects_noise:\n",
                all_corrects_noise), print("\nall labels: ", all_labels)

        # plots
        plot_test_bench_shift(x=all_corrects_shift,
                              y=all_labels,
                              context='bench-shift-2',
                              title='shift ' + self.test_model_name,
                              plot_path=self.test_model_path,
                              name='test_bench_shift',
                              show_plot=False)
        plot_test_bench_shift(x=all_probs_shift,
                              y=all_labels,
                              context='bench-shift',
                              title='shift ' + self.test_model_name,
                              plot_path=self.test_model_path,
                              name='test_bench_shift-prob',
                              show_plot=False)

        plot_test_bench_noise(x=all_corrects_noise,
                              y=all_labels,
                              snrs=self.cfg_tb['snrs'],
                              context='bench-noise-2',
                              title='noise ' + self.test_model_name,
                              plot_path=self.test_model_path,
                              name='test_bench_noise',
                              show_plot=False)
        plot_test_bench_noise(x=all_probs_noise,
                              y=all_labels,
                              snrs=self.cfg_tb['snrs'],
                              context='bench-noise',
                              title='noise ' + self.test_model_name,
                              plot_path=self.test_model_path,
                              name='test_bench_noise-prob',
                              show_plot=False)

    def test_noise_invariance(self, x_wav, actual_label, mu=0):
        """
    test model against noise invariance
    """

        # init lists
        pred_label_list, probs = [], []

        # origin
        if self.cfg_tb['enable_plot']:
            plot_waveform(x_wav,
                          self.feature_params['fs'],
                          title='origin actual: [{}]'.format(actual_label),
                          plot_path=self.paths['shift_wavs'],
                          name='{}_origin'.format(actual_label))

        # test model with different snr values
        for snr in self.cfg_tb['snrs']:

            # signal power
            p_x_eff = x_wav @ x_wav.T / len(x_wav)

            # calculate noise signal power
            sigma = np.sqrt(p_x_eff / (10**(snr / 10)))

            # noise generation
            n = np.random.normal(mu, sigma, len(x_wav))

            # add noise
            x_noise = x_wav + n

            # noise signal power
            p_n_eff = n @ n.T / len(n)

            # print energy info
            # print("sigma: ", sigma), print("p_x: ", p_x_eff), print("p_n: ", p_n_eff), print("db: ", 10 * np.log10(p_x_eff / p_n_eff))

            # feature extraction
            #x_mfcc, _ = self.feature_extractor.extract_mfcc(x_noise, reduce_to_best_onset=True)

            # feature extraction
            if self.net_handler.nn_arch != 'wavenet':
                x, _ = self.feature_extractor.extract_mfcc(x_noise, reduce_to_best_onset=True)
            else:
                x, _ = self.feature_extractor.get_best_raw_samples(x_noise, add_channel_dim=True)

            # classify
            y_hat, o, pred_label = self.net_handler.classify_sample(x)

            # append predicted label and probs
            pred_label_list.append(pred_label)
            probs.append(float(o[0, self.class_dict[actual_label]]))

            # plot wavs
            if self.cfg_tb['enable_plot']:
                plot_waveform(x_noise,
                              self.feature_params['fs'],
                              title='snr: [{}] actual: [{}] pred: [{}]'.format(
                                  snr, actual_label, pred_label),
                              plot_path=self.paths['noise_wavs'],
                              name='{}_snr{}'.format(actual_label, snr))

        # correct list
        corrects = [int(actual_label == l) for l in pred_label_list]

        # print message
        print("test bench noise acc: ", np.sum(corrects) / len(corrects))

        return corrects, probs

    def test_shift_invariance(self, x_wav, actual_label):
        """
    test model against shift invariance
    """

        # init lists
        pred_label_list, probs = [], []

        # feature extraction
        if self.net_handler.nn_arch != 'wavenet':
            x, _ = self.feature_extractor.extract_mfcc(x_wav, reduce_to_best_onset=False)
        else:
            x, _ = x_wav[np.newaxis, :], 0

        # windowed
        if self.net_handler.nn_arch != 'wavenet':
            x_win = np.squeeze(view_as_windows(x, self.data_size, step=self.cfg_tb['shift_frame_step']), axis=(0, 1))
        else:
            x_win = np.squeeze(view_as_windows(x, self.data_size, step=self.cfg_tb['shift_frame_step'] * self.feature_extractor.hop), axis=0)

        for i, x in enumerate(x_win):

            # classify
            y_hat, o, pred_label = self.net_handler.classify_sample(x)

            # append predicted label
            pred_label_list.append(pred_label)
            probs.append(float(o[0, self.class_dict[actual_label]]))

            # plot
            time_s = frames_to_sample(i * self.cfg_tb['shift_frame_step'],
                                      self.feature_params['fs'],
                                      self.feature_extractor.hop)
            time_e = frames_to_sample(
                i * self.cfg_tb['shift_frame_step'] +
                self.feature_params['frame_size'], self.feature_params['fs'],
                self.feature_extractor.hop)

            # plot waveform
            if self.cfg_tb['enable_plot']:
                plot_waveform(x_wav[time_s:time_e],
                              self.feature_params['fs'],
                              title='frame{} actual: [{}] pred: [{}]'.format(
                                  i, actual_label, pred_label),
                              plot_path=self.paths['shift_wavs'],
                              name='{}_frame{}'.format(actual_label, i))

        # correct list
        corrects = [int(actual_label == l) for l in pred_label_list]

        # print message
        print("test bench shift acc: ", np.sum(corrects) / len(corrects))

        return corrects, probs

    def file_naming_extraction(self, file, file_ext='.wav'):
        """
    extract file name ergo label
    """

        # extract filename
        file_name = re.findall(r'[\w+ 0-9]*' + re.escape(file_ext), file)[0]

        # extract file index from filename
        file_index = re.sub(r'[a-z A-Z]|(' + re.escape(file_ext) + r')', '',
                            file_name)

        # extract label from filename
        label = re.sub(r'([0-9]*' + re.escape(file_ext) + r')', '', file_name)

        return file_name, file_index, label
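
A minimal sketch of how the test bench might be driven, followed by a quick numeric check of the noise scaling used in test_noise_invariance, sigma = sqrt(p_x / 10^(snr/10)). The paths, file names, and config values here are assumptions for illustration.

# hypothetical usage sketch -- paths, file names, and config values are assumptions
import numpy as np

cfg_tb = {
    'paths': {'shift_wavs': 'plots/tb_shift/', 'noise_wavs': 'plots/tb_noise/'},
    'model_file_names': ['cnn_model.pth'],
    'params_file_name': 'params.npz',
    'test_wavs': ['test_wavs/left0.wav', 'test_wavs/right0.wav'],
    'snrs': [16, 10, 6, 0, -6],
    'shift_frame_step': 1,
    'enable_plot': False,
    'enable_info_prints': True,
}

# run shift and noise invariance tests on a trained model folder
test_bench = TestBench(cfg_tb, test_model_path='./models/conv-trad/', root_path='./')
test_bench.test_invariances()

# quick check of the noise scaling: sigma = sqrt(p_x / 10^(snr/10)) produces noise
# at (on average) the requested SNR relative to the signal power p_x
x = np.random.randn(16000)
snr_db = 10
p_x = x @ x.T / len(x)
sigma = np.sqrt(p_x / (10 ** (snr_db / 10)))
n = np.random.normal(0, sigma, len(x))
p_n = n @ n.T / len(n)
print("measured snr: {:.1f} dB".format(10 * np.log10(p_x / p_n)))  # ~10 dB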