def main():
    """Extract PPGs for every wav file listed in ``meta_path`` and save them.

    ``meta_path`` is read as one wav path per line; blank lines are skipped.
    A CNN-BLSTM classifier is built on an MFCC placeholder and its weights
    are restored from ``ckpt_path``. For each wav file this:

    * computes unnormalized MFCCs and runs the network to get PPGs
      (softmax over the classifier logits);
    * computes normalized dB mel and linear spectrograms and checks that
      PPG, MFCC, mel and spec all have the same number of frames;
    * writes ``<name>_mel_rec.wav`` / ``<name>_spec_rec.wav`` into
      ``rec_wav_dir``, reconstructed from the mel / spec features;
    * runs ``check_ppg`` and saves the PPGs to ``ppg_dir/<name>.npy``.

    Any exception raised for a file is printed, the file is recorded as
    bad, and processing continues with the next file. Good/bad counts and
    the list of bad files are printed at the end.

    Relies on module-level configuration (``meta_path``, ``ckpt_path``,
    ``rec_wav_dir``, ``ppg_dir``, ``MFCC_DIM``, ``PPG_DIM``) and on helpers
    defined elsewhere in the project.
    """
    print('start...')
    # Read and close the metadata file right away instead of leaking the handle.
    with open(meta_path, 'r') as meta_file:
        wav_paths = [line.strip() for line in meta_file if line.strip()]

    # NN->PPG
    # Set up network
    mfcc_pl = tf.placeholder(dtype=tf.float32,
                             shape=[None, None, MFCC_DIM],
                             name='mfcc_pl')
    classifier = CNNBLSTMCalssifier(out_dims=PPG_DIM,
                                    n_cnn=3,
                                    cnn_hidden=256,
                                    cnn_kernel=3,
                                    n_blstm=2,
                                    lstm_hidden=128)
    predicted_ppgs = tf.nn.softmax(classifier(inputs=mfcc_pl)['logits'])
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    cnt = 0
    bad_list = []
    # Using the session as a context manager closes it when processing ends.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        print('Restoring model from {}'.format(ckpt_path))
        saver.restore(sess, ckpt_path)

        for wav_f in tqdm(wav_paths):
            try:
                print('process:', wav_f)
                # Extract acoustic features.
                wav_arr = load_wav(wav_f)
                mfcc_feats = wav2unnormalized_mfcc(wav_arr)
                ppgs = sess.run(
                    predicted_ppgs,
                    feed_dict={mfcc_pl: np.expand_dims(mfcc_feats, axis=0)})
                # Remove only the batch axis added by expand_dims above; a bare
                # squeeze would also drop the time axis of a 1-frame utterance.
                ppgs = np.squeeze(ppgs, axis=0)
                mel_feats = wav2normalized_db_mel(wav_arr)
                spec_feats = wav2normalized_db_spec(wav_arr)

                # e.g. .../2020-11-09T11-09-00/0_sample_spec.wav -> '0_sample_spec'.
                # splitext only strips the final extension, so 'a.b.wav' maps
                # to 'a.b' and distinct inputs do not collide on one name.
                fname = os.path.splitext(os.path.basename(wav_f))[0]
                save_mel_rec_name = fname + '_mel_rec.wav'
                save_spec_rec_name = fname + '_spec_rec.wav'

                # Explicit raise instead of assert: survives `python -O` and
                # gives the error printout below a meaningful message.
                n_frames = (ppgs.shape[0], mfcc_feats.shape[0],
                            mel_feats.shape[0], spec_feats.shape[0])
                if len(set(n_frames)) != 1:
                    raise ValueError(
                        'frame count mismatch (ppg, mfcc, mel, spec): {}'
                        .format(n_frames))

                write_wav(os.path.join(rec_wav_dir, save_mel_rec_name),
                          normalized_db_mel2wav(mel_feats))
                write_wav(os.path.join(rec_wav_dir, save_spec_rec_name),
                          normalized_db_spec2wav(spec_feats))
                check_ppg(ppgs)

                # Store the PPG features.
                ppg_save_name = os.path.join(ppg_dir, fname + '.npy')
                np.save(ppg_save_name, ppgs)

                cnt += 1
            except Exception as e:
                # Best-effort batch job: record the failure and keep going.
                bad_list.append(wav_f)
                print(str(e))

    print('good:', cnt)
    print('bad:', len(bad_list))
    print(bad_list)
# 예제 #2 (Example #2)
# 0
def main():
    """Extract and save acoustic features for an LJSpeech-format dataset.

    Each line of ``meta_path`` is ``id|...``. Only the first ``|``-separated
    field (the utterance id) is used, and blank lines are skipped. The id
    list is then passed through ``PPG_get_restore``, which is defined
    elsewhere; presumably it filters out ids whose features already exist
    (TODO confirm).

    For each id, ``wav_dir/<id>.wav`` is loaded and:

    * unnormalized MFCCs are computed and fed to the restored CNN-BLSTM
      classifier to obtain PPGs (softmax over the logits);
    * normalized dB mel and linear spectrograms are computed, and PPG,
      MFCC, mel and spec are checked to have the same number of frames;
    * reconstructions ``<id>_mel_rec.wav`` / ``<id>_spec_rec.wav`` are
      written to ``rec_wav_dir`` so the feature extraction can be verified;
    * ``check_ppg`` is run, then the MFCC / PPG / mel / spec arrays are saved
      as ``<id>.npy`` in ``mfcc_dir`` / ``ppg_dir`` / ``mel_dir`` /
      ``spec_dir``, and the id is appended to ``f_good_meta``.

    Any exception raised for an id is printed, the id is recorded as bad,
    and processing continues. Good/bad counts and the list of bad ids are
    printed at the end.

    Relies on module-level configuration and helpers defined elsewhere in
    the project (``meta_path``, ``ckpt_path``, ``wav_dir``, the output
    directories, ``f_good_meta``, ...).
    """
    # Read and close the metadata file right away instead of leaking the handle.
    with open(meta_path, 'r') as meta_file:
        names = [line.strip().split('|')[0] for line in meta_file
                 if line.strip()]

    # NOTE(review): ppg_dir is passed twice here, while features are saved to
    # mfcc_dir/ppg_dir/mel_dir/spec_dir below; the first argument may have
    # been meant to be mfcc_dir. Confirm against PPG_get_restore's signature.
    names = PPG_get_restore(names, ppg_dir, ppg_dir, mel_dir, spec_dir)

    # NN->PPG
    # Set up network
    mfcc_pl = tf.placeholder(dtype=tf.float32, shape=[None, None, MFCC_DIM], name='mfcc_pl')
    classifier = CNNBLSTMCalssifier(out_dims=PPG_DIM, n_cnn=3, cnn_hidden=256, cnn_kernel=3, n_blstm=2, lstm_hidden=128)
    predicted_ppgs = tf.nn.softmax(classifier(inputs=mfcc_pl)['logits'])
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    cnt = 0
    bad_list = []
    # Using the session as a context manager closes it when processing ends.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        print('Restoring model from {}'.format(ckpt_path))
        saver.restore(sess, ckpt_path)

        for fname in tqdm(names):
            try:
                # Extract acoustic features.
                wav_f = os.path.join(wav_dir, fname + '.wav')
                wav_arr = load_wav(wav_f)
                mfcc_feats = wav2unnormalized_mfcc(wav_arr)
                ppgs = sess.run(predicted_ppgs, feed_dict={mfcc_pl: np.expand_dims(mfcc_feats, axis=0)})
                # Remove only the batch axis added by expand_dims above; a bare
                # squeeze would also drop the time axis of a 1-frame utterance.
                ppgs = np.squeeze(ppgs, axis=0)
                mel_feats = wav2normalized_db_mel(wav_arr)
                spec_feats = wav2normalized_db_spec(wav_arr)

                # Verify that the acoustic features were extracted correctly.
                save_name = fname + '.npy'
                save_mel_rec_name = fname + '_mel_rec.wav'
                save_spec_rec_name = fname + '_spec_rec.wav'
                # Explicit raise instead of assert: survives `python -O` and
                # gives the error printout below a meaningful message.
                n_frames = (ppgs.shape[0], mfcc_feats.shape[0],
                            mel_feats.shape[0], spec_feats.shape[0])
                if len(set(n_frames)) != 1:
                    raise ValueError(
                        'frame count mismatch (ppg, mfcc, mel, spec): {}'
                        .format(n_frames))
                write_wav(os.path.join(rec_wav_dir, save_mel_rec_name), normalized_db_mel2wav(mel_feats))
                write_wav(os.path.join(rec_wav_dir, save_spec_rec_name), normalized_db_spec2wav(spec_feats))
                check_ppg(ppgs)

                # Store the acoustic features.
                np.save(os.path.join(mfcc_dir, save_name), mfcc_feats)
                np.save(os.path.join(ppg_dir, save_name), ppgs)
                np.save(os.path.join(mel_dir, save_name), mel_feats)
                np.save(os.path.join(spec_dir, save_name), spec_feats)

                f_good_meta.write(fname + '\n')
                cnt += 1
            except Exception as e:
                # Best-effort batch job: record the failure and keep going.
                bad_list.append(fname)
                print(str(e))

    print('good:', cnt)
    print('bad:', len(bad_list))
    print(bad_list)