Example #1
0
    x_PS_ph = tf.placeholder(tf.float32,
                             shape=[None, args.input_dim],
                             name='x_PS_ph')  # noisy speech PS placeholder.
    xi_hat_ph = tf.placeholder(
        tf.float32, shape=[None, args.input_dim],
        name='xi_hat_ph')  # a priori SNR estimate placeholder.
    G_ph = tf.placeholder(tf.float32,
                          shape=[None, args.input_dim],
                          name='G_ph')  # gain function placeholder.

    ## ANALYSIS
    x = tf.div(
        tf.cast(tf.slice(tf.squeeze(x_ph), [0], [tf.squeeze(x_len_ph)]),
                tf.float32), args.nconst)  # remove padding and normalise.
    x_DFT = feat.stft(
        x, args.Nw, args.Ns,
        args.NFFT)  # noisy speech single-sided short-time Fourier transform.
    x_MS_3D = tf.expand_dims(
        tf.abs(x_DFT),
        0)  # noisy speech single-sided magnitude spectrum (in 3D form).
    x_MS = tf.abs(x_DFT)  # noisy speech single-sided magnitude spectrum.
    x_PS = tf.angle(x_DFT)  # noisy speech single-sided phase spectrum.
    x_seq_len = feat.nframes(x_len_ph, args.Ns)  # length of each sequence.

    ## ENHANCEMENT
    if args.gain == 'ibm':
        G = tf.cast(tf.greater(xi_hat_ph, 1), tf.float32)  # IBM gain function.
    if args.gain == 'wf':
        G = tf.div(xi_hat_ph, tf.add(xi_hat_ph, 1.0))  # WF gain function.
    if args.gain == 'srwf':
        G = tf.sqrt(tf.div(xi_hat_ph, tf.add(xi_hat_ph,
Example #2
0
    def __init__(self, args):
        """Build the TensorFlow (v1, graph-mode) training and inference graphs.

        Creates the input placeholders, the feature-extraction graph
        (``feat.xi_mapped``), the residual network (``residual.Residual``),
        the loss from ``residual.loss(..., 'sigmoid_cross_entropy')``, the
        training op from ``residual.optimizer(..., optimizer='adam',
        grad_clip=True)`` and a ``tf.train.Saver``. When ``args.infer`` is
        truthy, an additional analysis / gain / synthesis graph is built for
        speech enhancement.

        Args:
            args: configuration object. Attributes read here: ``input_dim``,
                ``mu_mat``, ``sigma_mat``, ``Nw``, ``Ns``, ``NFFT``, ``fs``,
                ``nconst``, ``num_outputs``, ``verbose``, ``infer`` and (only
                when ``infer`` is set) ``gain``. ``args`` is also passed
                whole to ``residual.Residual``.
        """

        ## PLACEHOLDERS
        # Waveforms are fed as int16 samples in rank-2 tensors (presumably
        # [batch, samples], zero-padded to a common length -- TODO confirm).
        self.s_ph = tf.placeholder(tf.int16, shape=[None, None],
                                   name='s_ph')  # clean speech placeholder.
        self.d_ph = tf.placeholder(tf.int16, shape=[None, None],
                                   name='d_ph')  # noise placeholder.
        self.x_ph = tf.placeholder(tf.int16, shape=[None, None],
                                   name='x_ph')  # noisy speech placeholder.
        # One unpadded length per waveform in the batch.
        self.s_len_ph = tf.placeholder(
            tf.int32, shape=[None],
            name='s_len_ph')  # clean speech sequence length placeholder.
        self.d_len_ph = tf.placeholder(
            tf.int32, shape=[None],
            name='d_len_ph')  # noise sequence length placeholder.
        self.x_len_ph = tf.placeholder(
            tf.int32, shape=[None],
            name='x_len_ph')  # noisy speech sequence length placeholder.
        self.snr_ph = tf.placeholder(tf.float32, shape=[None],
                                     name='snr_ph')  # SNR placeholder.
        # Network input: rank-3 magnitude spectra whose last axis has
        # args.input_dim bins (presumably [batch, frames, bins] -- TODO confirm).
        self.x_MS_ph = tf.placeholder(
            tf.float32, shape=[None, None, args.input_dim],
            name='x_MS_ph')  # noisy speech MS placeholder.
        self.x_MS_len_ph = tf.placeholder(
            tf.int32, shape=[None],
            name='x_MS_len_ph')  # noisy speech MS sequence length placeholder.
        # NOTE(review): the tensor name 'target_phh' looks like a typo for
        # 'target_ph'. It is left unchanged because other code may fetch this
        # tensor by name.
        self.target_ph = tf.placeholder(
            tf.float32, shape=[None, args.input_dim],
            name='target_phh')  # training target placeholder.
        # No shape is given, so these accept a tensor of any shape
        # (presumably fed as scalars).
        self.keep_prob_ph = tf.placeholder(
            tf.float32, name='keep_prob_ph')  # keep probability placeholder.
        self.training_ph = tf.placeholder(
            tf.bool, name='training_ph')  # training placeholder.

        ## A PRIORI SNR IN DB STATISTICS
        # Taken from the 'mu' / 'sigma' entries of args.mu_mat / args.sigma_mat
        # (presumably loaded .mat files -- TODO confirm) and passed to
        # feat.xi_mapped below.
        self.mu = tf.constant(args.mu_mat['mu'], dtype=tf.float32)
        self.sigma = tf.constant(args.sigma_mat['sigma'], dtype=tf.float32)

        ## FEATURE GRAPH
        print('Preparing graph...')
        # Longest clean-speech length in the batch.
        self.P = tf.reduce_max(self.s_len_ph)  # padded waveform length.
        self.feature = feat.xi_mapped(self.s_ph, self.d_ph, self.s_len_ph,
                                      self.d_len_ph, self.snr_ph, args.Nw,
                                      args.Ns, args.NFFT, args.fs, self.P,
                                      args.nconst, self.mu,
                                      self.sigma)  # feature graph.

        ## RESNET
        self.output = residual.Residual(self.x_MS_ph, self.x_MS_len_ph,
                                        self.keep_prob_ph, self.training_ph,
                                        args.num_outputs, args)

        ## LOSS & OPTIMIZER
        self.loss = residual.loss(self.target_ph, self.output,
                                  'sigmoid_cross_entropy')
        # Mean of the per-element loss over the first axis.
        self.total_loss = tf.reduce_mean(self.loss, axis=0)
        # residual.optimizer returns a pair; only the training op is kept.
        self.trainer, _ = residual.optimizer(self.total_loss,
                                             optimizer='adam',
                                             grad_clip=True)

        ## SAVE VARIABLES
        # Keeps up to 256 of the most recent checkpoints on disk.
        self.saver = tf.train.Saver(max_to_keep=256)

        ## NUMBER OF PARAMETERS
        # Sum over trainable variables of the product of their static shapes.
        if args.verbose:
            print("No. of trainable parameters: %g." % (np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])))

        ## INFERENCE GRAPH
        if args.infer:

            ## PLACEHOLDERS
            # 2-D (frames x args.input_dim) inputs for the enhancement /
            # synthesis stage.
            self.output_ph = tf.placeholder(
                tf.float32, shape=[None, args.input_dim],
                name='output_ph')  # network output placeholder.
            self.x_MS_2D_ph = tf.placeholder(
                tf.float32, shape=[None, args.input_dim],
                name='x_MS_2D_ph')  # noisy speech MS placeholder (in 2D form).
            self.x_PS_ph = tf.placeholder(
                tf.float32, shape=[None, args.input_dim],
                name='x_PS_ph')  # noisy speech PS placeholder.
            self.xi_hat_ph = tf.placeholder(
                tf.float32, shape=[None, args.input_dim],
                name='xi_hat_ph')  # a priori SNR estimate placeholder.
            self.G_ph = tf.placeholder(
                tf.float32, shape=[None, args.input_dim],
                name='G_ph')  # gain function placeholder.

            ## ANALYSIS
            # tf.squeeze followed by a 1-D tf.slice requires x_ph/x_len_ph to
            # collapse to rank 1 / a scalar, i.e. presumably a single
            # utterance (batch size 1) -- TODO confirm with the caller.
            self.x = tf.truediv(
                tf.cast(
                    tf.slice(tf.squeeze(self.x_ph), [0],
                             [tf.squeeze(self.x_len_ph)]), tf.float32),
                args.nconst)  # remove padding and normalise.
            self.x_DFT = feat.stft(
                self.x, args.Nw, args.Ns, args.NFFT
            )  # noisy speech single-sided short-time Fourier transform.
            # Leading axis of size 1 added so it matches the rank-3 x_MS_ph.
            self.x_MS_3D = tf.expand_dims(
                tf.abs(self.x_DFT), 0
            )  # noisy speech single-sided magnitude spectrum (in 3D form).
            self.x_MS = tf.abs(
                self.x_DFT)  # noisy speech single-sided magnitude spectrum.
            self.x_PS = tf.angle(
                self.x_DFT)  # noisy speech single-sided phase spectrum.
            self.x_seq_len = feat.nframes(self.x_len_ph,
                                          args.Ns)  # length of each sequence.

            ## MODIFICATION (SPEECH ENHANCEMENT)
            # Gain function selected by args.gain from the a priori SNR
            # estimate. It is thresholded at 1 and used directly in the ratios
            # below, so it presumably holds linear (not dB) values -- TODO
            # confirm.
            # NOTE(review): if args.gain matches none of these strings,
            # self.G is never assigned.
            if args.gain == 'ibm':
                self.G = tf.cast(tf.greater(self.xi_hat_ph, 1),
                                 tf.float32)  # IBM gain function.
            if args.gain == 'wf':
                self.G = tf.truediv(self.xi_hat_ph,
                                    tf.add(self.xi_hat_ph,
                                           1.0))  # WF gain function.
            # 'srwf' and 'irm' build the same expression: sqrt(xi / (xi + 1)).
            if args.gain == 'srwf':
                self.G = tf.sqrt(
                    tf.truediv(self.xi_hat_ph,
                               tf.add(self.xi_hat_ph,
                                      1.0)))  # SRWF gain function.
            if args.gain == 'irm':
                self.G = tf.sqrt(
                    tf.truediv(self.xi_hat_ph,
                               tf.add(self.xi_hat_ph,
                                      1.0)))  # IRM gain function.
            # sqrt(xi) / (sqrt(xi) + 1).
            if args.gain == 'cwf':
                self.G = tf.sqrt(self.xi_hat_ph)
                self.G = tf.truediv(self.G, tf.add(self.G,
                                                   1.0))  # cWF gain function.
            # Uses the *fed* gain G_ph, not self.G built above. Presumably
            # self.G is evaluated first and its value fed back through G_ph
            # -- verify against the caller.
            self.s_hat_MS = tf.multiply(
                self.x_MS_2D_ph,
                self.G_ph)  # enhanced speech single-sided magnitude spectrum.

            ## SYNTHESIS GRAPH
            # Combine the enhanced magnitude with the fed noisy phase:
            # |S_hat| * exp(j * phase).
            self.y_DFT = tf.cast(self.s_hat_MS, tf.complex64) * tf.exp(
                1j * tf.cast(self.x_PS_ph, tf.complex64)
            )  # enhanced speech single-sided short-time Fourier transform.
            # inverse_stft(stfts, frame_length=Nw, frame_step=Ns,
            # fft_length=NFFT, window_fn=...), with the window paired to a
            # forward Hamming window.
            self.y = tf.contrib.signal.inverse_stft(
                self.y_DFT, args.Nw, args.Ns, args.NFFT,
                tf.contrib.signal.inverse_stft_window_fn(
                    args.Ns,
                    forward_window_fn=tf.contrib.signal.hamming_window)
            )  # synthesis.