Esempio n. 1
0
def main(_):
    """Train a speech-enhancement model, or clean one test wav file.

    Behaviour is driven entirely by the module-level ``FLAGS``:

    * ``FLAGS.model`` selects the network: ``'gan'`` (SEGAN) or ``'ae'``
      (SEAE).
    * When ``FLAGS.test_wav`` is ``None`` the model is trained.
    * Otherwise the weights named by ``FLAGS.weights`` are loaded from
      ``FLAGS.save_path``, the test wav is cleaned and the result is written
      to ``FLAGS.save_clean_path`` under the same file name.

    Args:
        _: Unused positional argument (ignored).

    Raises:
        ValueError: If ``FLAGS.model`` is unknown, if a test wav is given
            without ``FLAGS.weights``, or if the test wav is not sampled at
            16 kHz.
    """
    print('Parsed arguments: ', FLAGS.__flags)

    # make save paths if required
    for out_dir in (FLAGS.save_path, FLAGS.synthesis_path):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
    np.random.seed(FLAGS.seed)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    # Use CPU devices only when nothing else is available. The match is
    # case-insensitive because device names are spelled differently across
    # TensorFlow releases (e.g. '/cpu:0' vs '/device:CPU:0'), and we fall back
    # to the full list so that udevices is never left empty on CPU-only hosts.
    accel_names = [d.name for d in devices if 'cpu' not in d.name.lower()]
    udevices = accel_names if accel_names else [d.name for d in devices]
    for dev_name in udevices:
        print('Using device: ', dev_name)
    # execute the session
    with tf.Session(config=config) as sess:
        if FLAGS.model == 'gan':
            print('Creating GAN model')
            se_model = SEGAN(sess, FLAGS, udevices)
        elif FLAGS.model == 'ae':
            print('Creating AE model')
            se_model = SEAE(sess, FLAGS, udevices)
        else:
            raise ValueError('{} model type not understood!'.format(
                FLAGS.model))
        if FLAGS.test_wav is None:
            se_model.train(FLAGS, udevices)
            return

        if FLAGS.weights is None:
            raise ValueError('weights must be specified!')
        # The output directory is not created anywhere else; make it before
        # the (potentially slow) cleaning so the final write cannot fail on
        # a missing folder.
        if FLAGS.save_clean_path and not os.path.exists(
                FLAGS.save_clean_path):
            os.makedirs(FLAGS.save_clean_path)
        print('Loading model weights...')
        se_model.load(FLAGS.save_path, FLAGS.weights)
        fm, wav_data = wavfile.read(FLAGS.test_wav)
        print(fm)
        wavname = os.path.basename(FLAGS.test_wav)
        if fm != 16000:
            raise ValueError('16kHz required! Test file is different')
        # Linear map of the 16-bit range [-32768, 32767] onto [-1, 1]
        # (assumes the file holds int16 PCM samples).
        wave = (2. / 65535.) * (wav_data.astype(np.float32) - 32767) + 1.
        if FLAGS.preemph > 0:
            print('preemph test wave with {}'.format(FLAGS.preemph))
            x_pholder, preemph_op = pre_emph_test(FLAGS.preemph,
                                                  wave.shape[0])
            wave = sess.run(preemph_op, feed_dict={x_pholder: wave})
        print('test wave shape: ', wave.shape)
        print('test wave min:{}  max:{}'.format(np.min(wave),
                                                np.max(wave)))
        c_wave = se_model.clean(wave)
        print('c wave min:{}  max:{}'.format(np.min(c_wave),
                                             np.max(c_wave)))
        clean_path = os.path.join(FLAGS.save_clean_path, wavname)
        wavfile.write(clean_path, 16000, c_wave)
        print('Done cleaning {} and saved '
              'to {}'.format(FLAGS.test_wav, clean_path))
Esempio n. 2
0
def main(_):
    """Train a speech-enhancement model, or clean a list of test wav files.

    Behaviour is driven entirely by the module-level ``FLAGS``:

    * ``FLAGS.model`` selects the network: ``'gan'`` (SEGAN) or ``'ae'``
      (SEAE).
    * When ``FLAGS.test_wavs`` is ``None`` the model is trained.
    * Otherwise ``FLAGS.test_wavs`` is read as a comma-separated list of
      paths. Each file is loaded at 16 kHz, cleaned, and written to
      ``FLAGS.save_clean_path`` as ``<stem>_segan.wav``.

    Args:
        _: Unused positional argument (ignored).

    Raises:
        ValueError: If ``FLAGS.model`` is unknown or if test files are given
            without ``FLAGS.weights``.
    """
    print('Parsed arguments: ', FLAGS.__flags)

    # make save paths if required
    for out_dir in (FLAGS.save_path, FLAGS.synthesis_path):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
    np.random.seed(FLAGS.seed)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    # Use CPU devices only when nothing else is available. The match is
    # case-insensitive because device names are spelled differently across
    # TensorFlow releases (e.g. '/cpu:0' vs '/device:CPU:0'), and we fall back
    # to the full list so that udevices is never left empty on CPU-only hosts.
    accel_names = [d.name for d in devices if 'cpu' not in d.name.lower()]
    udevices = accel_names if accel_names else [d.name for d in devices]
    for dev_name in udevices:
        print('Using device: ', dev_name)
    # execute the session
    with tf.Session(config=config) as sess:
        if FLAGS.model == 'gan':
            print('Creating GAN model')
            se_model = SEGAN(sess, FLAGS, udevices)
        elif FLAGS.model == 'ae':
            print('Creating AE model')
            se_model = SEAE(sess, FLAGS, udevices)
        else:
            raise ValueError('{} model type not understood!'.format(FLAGS.model))

        if FLAGS.test_wavs is None:
            se_model.train(FLAGS, udevices)
            return

        if FLAGS.weights is None:
            raise ValueError('weights must be specified!')
        # The output directory is not created anywhere else; make it before
        # cleaning so the final write cannot fail on a missing folder.
        if FLAGS.save_clean_path and not os.path.exists(FLAGS.save_clean_path):
            os.makedirs(FLAGS.save_clean_path)
        print('Loading model weights...')
        se_model.load(FLAGS.save_path, FLAGS.weights)
        # Keep the parsed list local instead of overwriting the flag, and
        # ignore empty entries such as the one left by a trailing comma.
        wav_paths = [p.strip() for p in FLAGS.test_wavs.split(',') if p.strip()]
        for wav_path in wav_paths:
            wav_data, fm = librosa.load(wav_path, sr=16000)
            # Strip only the final extension so that 'a.1.wav' and 'a.2.wav'
            # do not both collapse to 'a_segan.wav'.
            stem = os.path.splitext(os.path.basename(wav_path))[0]
            wavname = stem + '_segan.wav'

            # librosa.load already returns floating-point samples scaled to
            # [-1, 1]; re-applying the int16 -> [-1, 1] mapping here would
            # squash the whole signal to (almost) zero.
            wave = wav_data.astype(np.float32)

            if FLAGS.preemph > 0:
                print('preemph test wave with {}'.format(FLAGS.preemph))
                x_pholder, preemph_op = pre_emph_test(FLAGS.preemph, wave.shape[0])
                wave = sess.run(preemph_op, feed_dict={x_pholder: wave})

            print("【*】Clean start time ...", time.asctime(time.localtime(time.time())))
            c_wave = se_model.clean(wave)
            print("【*】Clean end time ...", time.asctime(time.localtime(time.time())))
            print('c wave min:{}  max:{}'.format(np.min(c_wave), np.max(c_wave)))
            print(c_wave)

            clean_path = os.path.join(FLAGS.save_clean_path, wavname)
            # NOTE(review): librosa.output was removed in librosa 0.8, so this
            # call needs an older librosa release - confirm the pinned version.
            librosa.output.write_wav(clean_path, c_wave, sr=16000)

            print('Done cleaning {} and saved '
                  'to {}'.format(wav_path, clean_path))
Esempio n. 3
0
def main(_):
    """Train a speech-enhancement model, or clean one test wav file.

    Behaviour is driven entirely by the module-level ``FLAGS``:

    * ``FLAGS.model`` selects the network: ``'gan'`` (SEGAN) or ``'ae'``
      (SEAE).
    * When ``FLAGS.test_wav`` is ``None`` the model is trained.
    * Otherwise the weights named by ``FLAGS.weights`` are loaded from
      ``FLAGS.save_path``, the test wav is cleaned and the result is written
      to ``FLAGS.save_clean_path`` under the same file name.

    Args:
        _: Unused positional argument (ignored).

    Raises:
        ValueError: If ``FLAGS.model`` is unknown, if a test wav is given
            without ``FLAGS.weights``, or if the test wav is not sampled at
            16 kHz.
    """
    print('Parsed arguments: ', FLAGS.__flags)

    # make save paths if required
    for out_dir in (FLAGS.save_path, FLAGS.synthesis_path):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
    np.random.seed(FLAGS.seed)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    # Use CPU devices only when nothing else is available. The match is
    # case-insensitive because device names are spelled differently across
    # TensorFlow releases (e.g. '/cpu:0' vs '/device:CPU:0'), and we fall back
    # to the full list so that udevices is never left empty on CPU-only hosts.
    accel_names = [d.name for d in devices if 'cpu' not in d.name.lower()]
    udevices = accel_names if accel_names else [d.name for d in devices]
    for dev_name in udevices:
        print('Using device: ', dev_name)
    # execute the session
    with tf.Session(config=config) as sess:
        if FLAGS.model == 'gan':
            print('Creating GAN model')
            se_model = SEGAN(sess, FLAGS, udevices)
        elif FLAGS.model == 'ae':
            print('Creating AE model')
            se_model = SEAE(sess, FLAGS, udevices)
        else:
            raise ValueError('{} model type not understood!'.format(FLAGS.model))
        if FLAGS.test_wav is None:
            se_model.train(FLAGS, udevices)
            return

        if FLAGS.weights is None:
            raise ValueError('weights must be specified!')
        # The output directory is not created anywhere else; make it before
        # the (potentially slow) cleaning so the final write cannot fail on
        # a missing folder.
        if FLAGS.save_clean_path and not os.path.exists(FLAGS.save_clean_path):
            os.makedirs(FLAGS.save_clean_path)
        print('Loading model weights...')
        se_model.load(FLAGS.save_path, FLAGS.weights)
        fm, wav_data = wavfile.read(FLAGS.test_wav)
        wavname = os.path.basename(FLAGS.test_wav)
        if fm != 16000:
            raise ValueError('16kHz required! Test file is different')
        # Linear map of the 16-bit range [-32768, 32767] onto [-1, 1]
        # (assumes the file holds int16 PCM samples).
        wave = (2. / 65535.) * (wav_data.astype(np.float32) - 32767) + 1.
        if FLAGS.preemph > 0:
            print('preemph test wave with {}'.format(FLAGS.preemph))
            x_pholder, preemph_op = pre_emph_test(FLAGS.preemph, wave.shape[0])
            wave = sess.run(preemph_op, feed_dict={x_pholder: wave})
        print('test wave shape: ', wave.shape)
        print('test wave min:{}  max:{}'.format(np.min(wave), np.max(wave)))
        c_wave = se_model.clean(wave)
        print('c wave min:{}  max:{}'.format(np.min(c_wave), np.max(c_wave)))
        clean_path = os.path.join(FLAGS.save_clean_path, wavname)
        # The sample rate must be an integer: SciPy documents `rate` as int
        # and packs it into an integer header field, so a float such as 16e3
        # is not a valid argument.
        wavfile.write(clean_path, 16000, c_wave)
        print('Done cleaning {} and saved '
              'to {}'.format(FLAGS.test_wav, clean_path))
Esempio n. 4
0
def main(_):
    """Train a speech-enhancement model, or clean every file in a directory.

    Behaviour is driven entirely by the module-level ``FLAGS``:

    * ``FLAGS.feature_type`` selects the module the models are imported
      from: ``'wavform'`` -> ``model``, ``'logspec'`` -> ``spec_model``.
    * ``FLAGS.model`` selects the network: ``'gan'`` (SEGAN) or ``'ae'``
      (SEAE).
    * When ``FLAGS.test_wav`` is ``None`` the model is trained with mode
      ``'stage2'``.
    * Otherwise ``FLAGS.test_wav`` is walked recursively as a directory.
      Every file found is read as a 16 kHz wav, cleaned, and written as
      16-bit PCM into ``FLAGS.save_clean_path`` under its base name.

    Args:
        _: Unused positional argument (ignored).

    Raises:
        ValueError: If ``FLAGS.feature_type`` or ``FLAGS.model`` is unknown,
            if test data is given without ``FLAGS.weights``, or if a test
            file is not sampled at 16 kHz.
    """
    if FLAGS.feature_type == 'wavform':
        from model import SEGAN, SEAE
    elif FLAGS.feature_type == 'logspec':
        from spec_model import SEGAN, SEAE
    else:
        # Without this guard SEGAN/SEAE stay unbound and the failure only
        # surfaces later as an obscure name error.
        raise ValueError('{} feature type not understood!'.format(
            FLAGS.feature_type))

    print('Parsed arguments: ', FLAGS.__flags)

    # make save paths if required
    for out_dir in (FLAGS.save_path, FLAGS.synthesis_path):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
    np.random.seed(FLAGS.seed)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.gpu_options.allocator_type = 'BFC'
    print("Device lists:{}".format(devices))
    # Use CPU devices only when nothing else is available. The match is
    # case-insensitive because device names are spelled differently across
    # TensorFlow releases (e.g. '/cpu:0' vs '/device:CPU:0'), and we fall back
    # to the full list so that udevices is never left empty on CPU-only hosts.
    accel_names = [d.name for d in devices if 'cpu' not in d.name.lower()]
    udevices = accel_names if accel_names else [d.name for d in devices]
    for dev_name in udevices:
        print('Using device: ', dev_name)
    print("device:{}".format(udevices))
    # execute the session
    with tf.Session(config=config) as sess:
        if FLAGS.model == 'gan':
            print('Creating GAN model')
            se_model = SEGAN(sess, FLAGS, udevices)
        elif FLAGS.model == 'ae':
            print('Creating AE model')
            se_model = SEAE(sess, FLAGS, udevices)
        else:
            raise ValueError('{} model type not understood!'.format(
                FLAGS.model))

        if FLAGS.test_wav is None:
            mode = 'stage2'
            se_model.train(FLAGS, udevices, mode)
            return

        if FLAGS.weights is None:
            raise ValueError('weights must be specified!')
        # The output directory is not created anywhere else; make it before
        # cleaning so the writes cannot fail on a missing folder.
        if FLAGS.save_clean_path and not os.path.exists(
                FLAGS.save_clean_path):
            os.makedirs(FLAGS.save_clean_path)
        print('Loading model weights...')
        se_model.load(FLAGS.save_path, FLAGS.weights)

        # Collect every file below the test directory.
        noisy_test_filelist = [
            os.path.join(dirpath, filename)
            for dirpath, _dirnames, filenames in os.walk(FLAGS.test_wav)
            for filename in filenames
        ]

        for name in noisy_test_filelist:
            t1 = time.time()
            fm, wav_data = wavfile.read(name)

            wavname = os.path.basename(name)
            if fm != 16000:
                raise ValueError('16kHz required! Test file is different')

            if FLAGS.feature_type == 'wavform':
                # Linear map of the 16-bit range [-32768, 32767] onto
                # [-1, 1] (assumes int16 PCM samples).
                wave = (2. / 65535.) * (wav_data.astype(np.float32) -
                                        32767) + 1.
            elif wav_data.dtype != 'float32':
                wave = np.float32(wav_data / 32767.)
            else:
                # Already float32: use it as is. Previously `wave` was left
                # unset here (or still held the previous file's samples).
                wave = wav_data

            if FLAGS.preemph > 0:
                print('preemph test wave with {}'.format(FLAGS.preemph))
                x_pholder, preemph_op = pre_emph_test(
                    FLAGS.preemph, wave.shape[0])
                wave = sess.run(preemph_op, feed_dict={x_pholder: wave})
            print('test wave shape: ', wave.shape)
            print('test wave min:{}  max:{}'.format(
                np.min(wave), np.max(wave)))
            c_wave = se_model.clean(wave)
            print('c wave min:{}  max:{}'.format(
                np.min(c_wave), np.max(c_wave)))

            clean_path = os.path.join(FLAGS.save_clean_path, wavname)
            # Clip before the int16 conversion so that samples outside
            # [-1, 1] saturate instead of wrapping around. The rate must be
            # an integer (SciPy documents `rate` as int), not the float 16e3.
            pcm = np.int16(np.clip(c_wave, -1., 1.) * 32767)
            wavfile.write(clean_path, 16000, pcm)
            t2 = time.time()
            print('Done cleaning {}/{}s and saved '
                  'to {}'.format(name, t2 - t1, clean_path))
    '''