Exemplo n.º 1
0
    def example(self, s, d, s_len, d_len, snr):
        """
        Build a single Deep Xi training pair: the observation (noisy-speech
        STMS) and the target (gain computed from the instantaneous SNRs).

        Argument/s:
            s - clean speech (dtype=tf.int32).
            d - noise (dtype=tf.int32).
            s_len - clean-speech length without padding (samples).
            d_len - noise length without padding (samples).
            snr - SNR level.

        Returns:
            x_STMS - noisy-speech short-time magnitude spectrum.
            gain - gain.
            n_frames - number of time-domain frames.
        """
        clean, noise, noisy, frame_count = self.mix(s, d, s_len, d_len, snr)

        # Only magnitude spectra are needed here; the phase halves are dropped.
        clean_mag, _clean_phase = self.polar_analysis(clean)
        noise_mag, _noise_phase = self.polar_analysis(noise)
        noisy_mag, _noisy_phase = self.polar_analysis(noisy)

        inst_a_priori = self.xi(clean_mag, noise_mag)  # instantaneous a priori SNR.
        inst_a_posteriori = self.gamma(noisy_mag, noise_mag)  # instantaneous a posteriori SNR.
        target_gain = gfunc(xi=inst_a_priori,
                            gamma=inst_a_posteriori,
                            gtype=self.gain)
        return noisy_mag, target_gain, frame_count
Exemplo n.º 2
0
    def enhanced_speech(self, x_STMS, x_STPS, xi_bar_hat, gtype):
        """
        Synthesise enhanced speech by applying a gain, computed from the
        estimated a priori SNR, to the noisy-speech magnitude spectrum.

        Argument/s:
            x_STMS - noisy-speech short-time magnitude spectrum.
            x_STPS - noisy-speech short-time phase spectrum.
            xi_bar_hat - mapped a priori SNR estimate.
            gtype - gain function type.

        Returns:
            enhanced speech.
        """
        a_priori = self.xi_map.inverse(xi_bar_hat)
        # The a posteriori SNR estimate is taken as the a priori estimate + 1.
        a_posteriori = tf.math.add(a_priori, self.one)
        gain = gfunc(a_priori, a_posteriori, gtype)
        enhanced_mag = tf.math.multiply(x_STMS, gain)
        # The noisy phase is reused for synthesis.
        return self.polar_synthesis(enhanced_mag, x_STPS)
Exemplo n.º 3
0
    def enhanced_speech(self, x_STMS, x_STPS_xi_hat_mat, gamma_bar_hat, gtype):
        """
        Compute enhanced speech.

        The a posteriori SNR comes from the network estimate, while the a
        priori SNR is taken from data previously saved to a .mat file.

        Argument/s:
            x_STMS - noisy-speech short-time magnitude spectrum.
            x_STPS_xi_hat_mat - tuple of (noisy-speech short-time phase
                spectrum, mapping loaded from a .mat file whose 'xi_hat'
                entry is the a priori SNR estimate).
            gamma_bar_hat - mapped a posteriori SNR estimate.
            gtype - gain function type.

        Returns:
            enhanced speech.
        """
        gamma_hat = self.gamma_map.inverse(gamma_bar_hat)
        x_STPS, xi_hat_mat = x_STPS_xi_hat_mat
        xi_hat = xi_hat_mat['xi_hat']
        y_STMS = tf.math.multiply(x_STMS, gfunc(xi_hat, gamma_hat, gtype))
        # The noisy phase is reused for synthesis.
        return self.polar_synthesis(y_STMS, x_STPS)
Exemplo n.º 4
0
    def enhanced_speech(self, x_STMS, x_STPS, xi_gamma_bar_hat, gtype):
        """
        Compute enhanced speech.

        Argument/s:
            x_STMS - noisy-speech short-time magnitude spectrum.
            x_STPS - noisy-speech short-time phase spectrum.
            xi_gamma_bar_hat - mapped a priori and a posteriori SNR estimates,
                concatenated along the last axis (a priori first, each taking
                half of that axis).
            gtype - gain function type.

        Returns:
            enhanced speech.
        """
        # Split the joint estimate into its a priori and a posteriori halves.
        xi_bar_hat, gamma_bar_hat = tf.split(xi_gamma_bar_hat,
                                             num_or_size_splits=2,
                                             axis=-1)
        xi_hat = self.xi_map.inverse(xi_bar_hat)
        gamma_hat = self.gamma_map.inverse(gamma_bar_hat)
        y_STMS = tf.math.multiply(x_STMS, gfunc(xi_hat, gamma_hat, gtype))
        # The noisy phase is reused for synthesis.
        return self.polar_synthesis(y_STMS, x_STPS)
Exemplo n.º 5
0
    def enhanced_speech(self, x_STMS_STPS, dummy, xi_s_stps_bar_hat, gtype):
        """
        Synthesise enhanced speech from the gain-weighted noisy-speech
        magnitude spectrum and the estimated clean-speech phase spectrum.

        Argument/s:
            x_STMS_STPS - noisy-speech short-time magnitude and phase spectrum,
                concatenated along the last axis (magnitude first).
            dummy - dummy variable (not used).
            xi_s_stps_bar_hat - mapped a priori SNR and clean-speech STPS
                estimate, concatenated along the last axis (a priori SNR
                first).
            gtype - gain function type.

        Returns:
            enhanced speech.
        """
        # Only the magnitude half of the noisy input is used; the noisy phase
        # is replaced by the estimated clean-speech phase.
        noisy_mag, _noisy_phase = tf.split(x_STMS_STPS,
                                           num_or_size_splits=2,
                                           axis=-1)
        mapped_xi, mapped_stps = tf.split(xi_s_stps_bar_hat,
                                          num_or_size_splits=2,
                                          axis=-1)
        a_priori = self.xi_map.inverse(mapped_xi)
        a_posteriori = tf.math.add(a_priori, self.one)
        clean_phase = self.s_stps_map.inverse(mapped_stps)
        gain = gfunc(a_priori, a_posteriori, gtype)
        enhanced_mag = tf.math.multiply(noisy_mag, gain)
        return self.polar_synthesis(enhanced_mag, clean_phase)
Exemplo n.º 6
0
    def enhanced_speech(self, x_STDCT, dummy, xi_cd_bar_hat, gtype):
        """
        Compute enhanced speech.

        Argument/s:
            x_STDCT - noisy-speech short-time DCT (STDCT) representation, i.e.
                the representation inverted by self.stdct_synthesis.
            dummy - dummy variable (not used).
            xi_cd_bar_hat - mapped a priori SNR estimate and mapped 'cd'
                estimate, concatenated along the last axis (a priori SNR
                first, each taking half of that axis).
            gtype - gain function type.

        Returns:
            enhanced speech.
        """
        xi_bar_hat, cd_bar_hat = tf.split(xi_cd_bar_hat,
                                          num_or_size_splits=2,
                                          axis=-1)
        xi_hat = self.xi_map.inverse(xi_bar_hat)
        gamma_hat = tf.math.add(xi_hat, self.one)
        cd_hat = self.cd_map.inverse(cd_bar_hat)
        # Boolean mask: True where the inverse-mapped cd estimate is positive.
        # It is passed to gfunc as an extra (fourth) argument.
        # NOTE(review): the meaning of 'cd' is not defined in this block —
        # confirm against cd_map / gfunc.
        cdm_hat = tf.math.greater(cd_hat, 0.0)
        y_STDCT = tf.math.multiply(x_STDCT,
                                   gfunc(xi_hat, gamma_hat, gtype, cdm_hat))
        return self.stdct_synthesis(y_STDCT)
Exemplo n.º 7
0
    def test(self,
             test_x,
             test_x_len,
             test_x_base_names,
             test_s,
             test_s_len,
             test_s_base_names,
             test_epoch,
             model_path='model',
             gain='mmse-lsa',
             stats_path=None):
        """
        Deep Xi testing. Objective measures are used to evaluate the performance
        of Deep Xi.

        For every (epoch, gain function) pair:
            - STOI, eSTOI, PESQ (pesq 'nb' mode) and MOS-LQO (pesq 'wb' mode)
              scores are written per noise source and SNR level to
              'log/results/<ver>_e<epoch>_<gain>.csv';
            - the scores at SNR levels within [self.min_snr, self.max_snr] are
              averaged per metric and appended to 'log/results/average.csv'.

        The last two '_'-separated fields of each noisy-speech base name are
        parsed as the noise source and the SNR level. The SNR field has its
        final two characters (presumably a 'dB' suffix) stripped.

        Argument/s:
            test_x - noisy-speech test batch.
            test_x_len - noisy-speech test batch lengths.
            test_x_base_names - noisy-speech base names.
            test_s - clean-speech test batch.
            test_s_len - clean-speech test batch lengths.
            test_s_base_names - clean-speech base names.
            test_epoch - epoch (or list of epochs) to test.
            model_path - path to model directory.
            gain - gain function (or list of gain functions) (see deepxi/args.py).
            stats_path - path to the saved statistics.

        Raises:
            ValueError - if any epoch is less than 1, or a noisy-speech file
                has no matching clean-speech reference.
        """
        if not isinstance(test_epoch, list): test_epoch = [test_epoch]
        if not isinstance(gain, list): gain = [gain]
        for e in test_epoch:
            if e < 1:
                raise ValueError("test_epoch must be greater than 0.")

        # The statistics and observations depend on neither the epoch nor the
        # gain function, so they are computed once.
        self.sample_stats(stats_path)
        print("Processing observations...")
        x_STMS_batch, x_STPS_batch, n_frames = self.observation_batch(
            test_x, test_x_len)

        os.makedirs("log/results", exist_ok=True)
        batch_size = len(test_x_len)
        for e in test_epoch:
            self.model.load_weights(model_path + '/epoch-' + str(e - 1) +
                                    '/variables/variables')

            # The network output does not depend on the gain function, so
            # inference is performed once per epoch.
            print("Performing inference...")
            xi_bar_hat_batch = self.model.predict(x_STMS_batch,
                                                  batch_size=1,
                                                  verbose=1)

            for g in gain:
                print("Performing synthesis and objective scoring...")
                results = {}
                for i in tqdm(range(batch_size)):
                    x_base_name = test_x_base_names[i]
                    x_STMS = x_STMS_batch[i, :n_frames[i], :]
                    x_STPS = x_STPS_batch[i, :n_frames[i], :]
                    xi_hat = self.xi_hat(xi_bar_hat_batch[i, :n_frames[i], :])
                    y_STMS = np.multiply(x_STMS,
                                         gfunc(xi_hat, xi_hat + 1, gtype=g))
                    y = self.polar_synthesis(y_STMS, x_STPS).numpy()

                    ref_idx = self._clean_speech_index(x_base_name,
                                                       test_s_base_names)
                    s = self.normalise(test_s[ref_idx,
                                              0:test_s_len[ref_idx]]).numpy()
                    # Trim the enhanced speech to the reference length.
                    y = y[0:len(s)]

                    name_fields = x_base_name.split("_")
                    noise_source = name_fields[-2]
                    snr_level = int(name_fields[-1][:-2])

                    results = self.add_score(
                        results, (noise_source, snr_level, 'STOI'),
                        100 * stoi(s, y, self.f_s, extended=False))
                    results = self.add_score(
                        results, (noise_source, snr_level, 'eSTOI'),
                        100 * stoi(s, y, self.f_s, extended=True))
                    results = self.add_score(results,
                                             (noise_source, snr_level, 'PESQ'),
                                             pesq(self.f_s, s, y, 'nb'))
                    results = self.add_score(
                        results, (noise_source, snr_level, 'MOS-LQO'),
                        pesq(self.f_s, s, y, 'wb'))

                tag = self.ver + "_e" + str(e) + '_' + g
                self._write_results_csv(results,
                                        "log/results/" + tag + ".csv")
                self._append_average_csv(results, tag)

    @staticmethod
    def _clean_speech_index(x_base_name, s_base_names):
        """
        Return the index of the clean-speech reference for a noisy-speech file.

        A clean-speech base name matches when it is a substring of
        'x_base_name'. If several match, the last one is used.

        Raises:
            ValueError - if no clean-speech base name matches.
        """
        ref_idx = None
        for j, s_base_name in enumerate(s_base_names):
            if s_base_name in x_base_name:
                ref_idx = j
        if ref_idx is None:
            raise ValueError("No clean-speech reference found for '" +
                             str(x_base_name) + "'.")
        return ref_idx

    @staticmethod
    def _write_results_csv(results, path):
        """
        Write per-(noise source, SNR level) mean scores to the CSV file 'path'.

        'results' maps (noise_source, snr_level, metric) keys to score lists.
        The file gets one row per (noise source, SNR level) pair and one column
        per metric. Cells with no score are left empty so every row has the
        same number of columns.
        """
        noise_sources = sorted({key[0] for key in results})
        snr_levels = sorted({key[1] for key in results})
        metrics = sorted({key[2] for key in results})
        with open(path, "w") as f:
            f.write(",".join(["noise", "snr_db"] + metrics) + '\n')
            for noise_source in noise_sources:
                for snr_level in snr_levels:
                    row = [str(noise_source), str(snr_level)]
                    for metric in metrics:
                        key = (noise_source, snr_level, metric)
                        if key in results:
                            row.append("{:.2f}".format(np.mean(results[key])))
                        else:
                            row.append("")
                    f.write(",".join(row) + '\n')

    def _append_average_csv(self, results, tag,
                            path="log/results/average.csv"):
        """
        Append one row of per-metric average scores to the CSV file 'path'.

        Scores in 'results' whose SNR level lies in [self.min_snr,
        self.max_snr] are pooled per metric with self.add_score. The mean of
        each pool is written under the row label 'tag'. A header row is
        written first if the file does not exist yet. Metrics with no
        in-range scores get an empty cell so that columns stay aligned.
        """
        metrics = sorted({key[2] for key in results})
        avg_results = {}
        # Keys are visited in sorted (noise, SNR, metric) order, so pooling
        # order is deterministic.
        for noise_source, snr_level, metric in sorted(results):
            if self.min_snr <= snr_level <= self.max_snr:
                avg_results = self.add_score(
                    avg_results, metric,
                    results[(noise_source, snr_level, metric)])

        if not os.path.exists(path):
            with open(path, "w") as f:
                f.write(",".join(["ver"] + metrics) + '\n')

        with open(path, "a") as f:
            row = [tag]
            for metric in metrics:
                if metric in avg_results:
                    row.append("{:.2f}".format(np.mean(avg_results[metric])))
                else:
                    row.append("")
            f.write(",".join(row) + '\n')
Exemplo n.º 8
0
    def infer(  ## NEED TO ADD DeepMMSE
        self,
        test_x,
        test_x_len,
        test_x_base_names,
        test_epoch,
        model_path='model',
        out_type='y',
        gain='mmse-lsa',
        out_path='out',
        stats_path=None,
        n_filters=40,
    ):
        """
        Deep Xi inference. The specified 'out_type' is saved.

        Outputs are written to '<out_path>/<out_type>', except for out_type
        'y', which is written to '<out_path>/y/<gain>'.

        Argument/s:
            test_x - noisy-speech test batch.
            test_x_len - noisy-speech test batch lengths.
            test_x_base_names - noisy-speech base names.
            test_epoch - epoch to test.
            model_path - path to model directory.
            out_type - output type (see deepxi/args.py).
            gain - gain function (see deepxi/args.py).
            out_path - path to save output files.
            stats_path - path to the saved statistics.
            n_filters - number of mel-scale filters; only used when out_type
                is 'subband_ibm_hat'.

        Raises:
            ValueError - if out_type is invalid or test_epoch is less than 1.
        """
        if out_type == 'xi_hat': out_path = out_path + '/xi_hat'
        elif out_type == 'y': out_path = out_path + '/y/' + gain
        elif out_type == 'deepmmse': out_path = out_path + '/deepmmse'
        elif out_type == 'ibm_hat': out_path = out_path + '/ibm_hat'
        elif out_type == 'subband_ibm_hat':
            out_path = out_path + '/subband_ibm_hat'
        else:
            raise ValueError('Invalid output type.')

        if test_epoch < 1:
            raise ValueError("test_epoch must be greater than 0.")

        # Create the output directory only once the arguments are known valid.
        os.makedirs(out_path, exist_ok=True)

        # The mel-scale filter bank is to compute an ideal binary mask (IBM)
        # estimate for log-spectral subband energies (LSSE).
        if out_type == 'subband_ibm_hat':
            mel_filter_bank = self.mel_filter_bank(n_filters)

        self.sample_stats(stats_path)
        self.model.load_weights(model_path + '/epoch-' + str(test_epoch - 1) +
                                '/variables/variables')

        print("Processing observations...")
        x_STMS_batch, x_STPS_batch, n_frames = self.observation_batch(
            test_x, test_x_len)
        print("Performing inference...")
        xi_bar_hat_batch = self.model.predict(x_STMS_batch,
                                              batch_size=1,
                                              verbose=1)

        print("Performing synthesis...")
        for i in tqdm(range(len(test_x_len))):
            base_name = test_x_base_names[i]
            mat_path = out_path + '/' + base_name + '.mat'
            x_STMS = x_STMS_batch[i, :n_frames[i], :]
            x_STPS = x_STPS_batch[i, :n_frames[i], :]
            xi_hat = self.xi_hat(xi_bar_hat_batch[i, :n_frames[i], :])
            if out_type == 'xi_hat':
                # Saved into the out_type-specific directory computed above.
                save_mat(mat_path, xi_hat, 'xi_hat')
            elif out_type == 'y':
                y_STMS = np.multiply(x_STMS,
                                     gfunc(xi_hat, xi_hat + 1, gtype=gain))
                y = self.polar_synthesis(y_STMS, x_STPS).numpy()
                save_wav(out_path + '/' + base_name + '.wav', y, self.f_s)
            elif out_type == 'deepmmse':
                d_PSD_hat = np.multiply(
                    np.square(x_STMS),
                    gfunc(xi_hat, xi_hat + 1, gtype='deepmmse'))
                save_mat(mat_path, d_PSD_hat, 'd_psd_hat')
            elif out_type == 'ibm_hat':
                # A time-frequency unit is marked True where xi_hat > 1.
                ibm_hat = np.greater(xi_hat, 1.0).astype(bool)
                save_mat(mat_path, ibm_hat, 'ibm_hat')
            elif out_type == 'subband_ibm_hat':
                xi_hat_subband = np.matmul(xi_hat, mel_filter_bank.transpose())
                subband_ibm_hat = np.greater(xi_hat_subband, 1.0).astype(bool)
                save_mat(mat_path, subband_ibm_hat, 'subband_ibm_hat')
            else:
                raise ValueError('Invalid output type.')
Exemplo n.º 9
0
    def infer(
        self,
        test_x,
        test_x_len,
        test_x_base_names,
        test_epoch,
        model_path='model',
        out_type='y',
        gain='mmse-lsa',
        out_path='out',
        n_filters=40,
        saved_data_path=None,
    ):
        """
        Deep Xi inference. The specified 'out_type' is saved.

        Outputs for epoch e are written below '<out_path>/<self.ver>/e<e>/'.
        They go into a sub-directory named after out_type. For out_type 'y'
        the sub-directory is 'y/<gain>', or just 'y' when
        self.inp_tgt_type is 'MagIRM'.

        Argument/s:
            test_x - noisy-speech test batch.
            test_x_len - noisy-speech test batch lengths.
            test_x_base_names - noisy-speech base names.
            test_epoch - epoch (or list of epochs) to test.
            model_path - path to model directory.
            out_type - output type (see deepxi/args.py).
            gain - gain function (or list of gain functions) (see
                deepxi/args.py).
            out_path - path to save output files.
            n_filters - number of mel-scale filters; only used when out_type
                is 'subband_ibm_hat'.
            saved_data_path - path to saved data necessary for enhancement.

        Raises:
            ValueError - if out_type is invalid or any epoch is less than 1.
        """
        valid_out_types = ('xi_hat', 'gamma_hat', 's_STPS_hat', 'y',
                           'deepmmse', 'ibm_hat', 'subband_ibm_hat', 'cd_hat')
        if out_type not in valid_out_types:
            raise ValueError('Invalid output type.')

        out_path_base = out_path
        if not isinstance(test_epoch, list): test_epoch = [test_epoch]
        if not isinstance(gain, list): gain = [gain]
        # Validate every epoch before doing any work.
        if any(e < 1 for e in test_epoch):
            raise ValueError("test_epoch must be greater than 0.")

        # The mel-scale filter bank is to compute an ideal binary mask (IBM)
        # estimate for log-spectral subband energies (LSSE).
        if out_type == 'subband_ibm_hat':
            mel_filter_bank = self.mel_filter_bank(n_filters)

        # Observations depend only on the test batch, so they are computed
        # once rather than once per (epoch, gain) pair.
        print("Processing observations...")
        inp_batch, supplementary_batch, n_frames = self.observation_batch(
            test_x, test_x_len)
        batch_size = len(test_x_len)

        for e in test_epoch:
            epoch_path = out_path_base + '/' + self.ver + '/' + 'e' + str(e)
            self.model.load_weights(model_path + '/epoch-' + str(e - 1) +
                                    '/variables/variables')

            # The network output does not depend on the gain function, so
            # inference is performed once per epoch.
            print("Performing inference...")
            tgt_hat_batch = self.model.predict(inp_batch,
                                               batch_size=1,
                                               verbose=1)

            for g in gain:
                if out_type != 'y':
                    out_path = epoch_path + '/' + out_type
                elif self.inp_tgt_type == 'MagIRM':
                    out_path = epoch_path + '/y'
                else:
                    out_path = epoch_path + '/y/' + g
                os.makedirs(out_path, exist_ok=True)

                print("Saving outputs...")
                for i in tqdm(range(batch_size)):
                    base_name = test_x_base_names[i]
                    mat_path = out_path + '/' + base_name + '.mat'
                    inp = inp_batch[i, :n_frames[i], :]
                    tgt_hat = tgt_hat_batch[i, :n_frames[i], :]

                    # NOTE(review): assumes supplementary_batch can be sliced
                    # like inp_batch (per item, per frame) — confirm for every
                    # inp_tgt type.
                    supplementary = supplementary_batch[i, :n_frames[i], :]

                    if saved_data_path is not None:
                        # Previously saved per-file data is passed on to
                        # enhanced_speech together with the supplementary data.
                        saved_data = read_mat(saved_data_path + '/' +
                                              base_name + '.mat')
                        supplementary = (supplementary, saved_data)

                    if out_type == 'xi_hat':
                        xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                        save_mat(mat_path, xi_hat, 'xi_hat')
                    elif out_type == 'gamma_hat':
                        gamma_hat = self.inp_tgt.gamma_hat(tgt_hat)
                        save_mat(mat_path, gamma_hat, 'gamma_hat')
                    elif out_type == 's_STPS_hat':
                        s_STPS_hat = self.inp_tgt.s_stps_hat(tgt_hat)
                        save_mat(mat_path, s_STPS_hat, 's_STPS_hat')
                    elif out_type == 'y':
                        y = self.inp_tgt.enhanced_speech(
                            inp, supplementary, tgt_hat, g).numpy()
                        save_wav(out_path + '/' + base_name + '.wav', y,
                                 self.inp_tgt.f_s)
                    elif out_type == 'deepmmse':
                        xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                        d_PSD_hat = np.multiply(
                            np.square(inp),
                            gfunc(xi_hat, xi_hat + 1.0, gtype='deepmmse'))
                        save_mat(mat_path, d_PSD_hat, 'd_psd_hat')
                    elif out_type == 'ibm_hat':
                        # A time-frequency unit is marked True where
                        # xi_hat > 1.
                        xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                        ibm_hat = np.greater(xi_hat, 1.0).astype(bool)
                        save_mat(mat_path, ibm_hat, 'ibm_hat')
                    elif out_type == 'subband_ibm_hat':
                        xi_hat = self.inp_tgt.xi_hat(tgt_hat)
                        xi_hat_subband = np.matmul(xi_hat,
                                                   mel_filter_bank.transpose())
                        subband_ibm_hat = np.greater(xi_hat_subband,
                                                     1.0).astype(bool)
                        save_mat(mat_path, subband_ibm_hat, 'subband_ibm_hat')
                    elif out_type == 'cd_hat':
                        cd_hat = self.inp_tgt.cd_hat(tgt_hat)
                        save_mat(mat_path, cd_hat, 'cd_hat')
                    else:
                        raise ValueError('Invalid output type.')