def SR_recognize(self, wavs, pre_type):
    """Recognize speech from raw wav samples.

    Depending on the module-level ``ues_tf_serving`` flag, inference runs
    either through TF-Serving HTTP endpoints (acoustic model 'am', then
    language model 'lm') or through the in-process ``yysb`` predictor.

    Args:
        wavs: raw wav signal in whatever format ``utils.get_wav_Feature`` /
            ``yysb.predict`` expect — not validated here.
        pre_type: 'H' to also decode Chinese characters (hanzi); any other
            value returns pinyin only.

    Returns:
        ``(pinyin, hanzi)`` on success. On a serving failure returns
        ``(None, <Chinese error message>)`` so callers can detect the error
        from the message string.
    """
    hanzi = ''
    am_url = tf_serving_url.format('am')
    lm_url = tf_serving_url.format('lm')
    if ues_tf_serving:
        x, _, _ = utils.get_wav_Feature(wavsignal=wavs)
        try:
            # Let requests serialize the payload as proper JSON (with the
            # right Content-Type) instead of hand-building the body via
            # '%' formatting of a Python list repr.
            receipt = requests.post(
                am_url, json={'instances': x.tolist()}
            ).json()['predictions'][0]
            receipt = np.array([receipt], dtype=np.float32)
        except Exception:
            # Exception, not BaseException: never swallow KeyboardInterrupt
            # or SystemExit raised while the HTTP call is in flight.
            return None, '声学模型调用异常'
        _, pinyin = utils.decode_ctc(receipt, utils.pny_vocab)
        pinyin = [[
            utils.pny_vocab.index(p)
            for p in ' '.join(pinyin).strip('\n').split(' ')
        ]]
        if pre_type == 'H':
            # curl -d '{"instances": [[420,58]]}' -X POST http://localhost:8501/v1/models/lm:predict
            try:
                hanzi = requests.post(
                    lm_url, json={'instances': pinyin}
                ).json()['predictions'][0]
            except Exception:
                return None, '语言模型调用异常'
            hanzi = ''.join(utils.han_vocab[idx] for idx in hanzi)
    else:
        if pre_type == 'H':
            pinyin, hanzi = yysb.predict(wavs)
        else:
            pinyin = yysb.predict(wavs, only_pinyin=True)
    return pinyin, hanzi
def test_batch():
    """Evaluate the AM + LM pipeline on the first 10 test utterances.

    Prints the decoded pinyin and hanzi for each sample and a cumulative
    character error rate at the end. Relies on module-level globals:
    ``data_args``, ``get_data``, ``am``, ``sess``, ``lm``, ``train_data``,
    ``decode_ctc``, ``GetEditDistance``, ``np``.
    """
    data_args.data_type = 'test'
    data_args.shuffle = False
    data_args.batch_size = 1
    test_data = get_data(data_args)
    am_batch = test_data.get_am_batch()
    word_num = 0
    word_error_num = 0
    for i in range(10):
        print('\n the ', i, 'th example.')
        # Run the trained acoustic model on one utterance.
        inputs, _ = next(am_batch)
        x = inputs['the_inputs']
        y = test_data.pny_lst[i]
        result = am.model.predict(x, steps=1)
        # Convert the numeric CTC output into pinyin text.
        _, text = decode_ctc(result, train_data.am_vocab)
        text = ' '.join(text)
        print('文本结果:', text)
        print('原文结果:', ' '.join(y))
        with sess.as_default():
            # Skip empty tokens: split(' ') can yield '' and
            # pny_vocab.index('') would raise ValueError (same guard as
            # the `speech` service path elsewhere in this file).
            text = text.strip('\n').split(' ')
            x = np.array([train_data.pny_vocab.index(pny)
                          for pny in text if pny != ''])
            x = x.reshape(1, -1)
            preds = sess.run(lm.preds, {lm.x: x})
            label = test_data.han_lst[i]
            got = ''.join(train_data.han_vocab[idx] for idx in preds[0])
            print('原文汉字:', label)
            print('识别结果:', got)
            # Cap per-sample errors at the label length so CER stays <= 1.
            word_error_num += min(len(label), GetEditDistance(label, got))
            word_num += len(label)
    print('词错误率:', word_error_num / word_num)
    sess.close()
def speech(self, filename):
    """Transcribe one audio file to Chinese characters (AM -> LM pipeline).

    Returns the recognized hanzi string, or None if interrupted by the
    user (KeyboardInterrupt).
    """
    try:
        features, _ = self.data_processer.am_data_process(filename)
        am_out = self.am.model.predict(features['the_inputs'], steps=1)
        _, pny_text = decode_ctc(am_out, self.vocab.ampny_vocab)
        tokens = ' '.join(pny_text).strip('\n').split(' ')
        with self.sess.as_default():
            ids = [self.vocab.pny_vocab.index(t) for t in tokens if t != '']
            if not ids:
                # Dummy ids so the LM still receives non-empty input; the
                # actual values do not matter (original note: "temporary,
                # added only to avoid a warning — any ids would do").
                ids = [120, 79, 1, 53]
            lm_in = np.array(ids).reshape(1, -1)
            lm_out = self.sess.run(self.lm.preds, {self.lm.x: lm_in})
            return ''.join(self.vocab.han_vocab[i] for i in lm_out[0])
    except KeyboardInterrupt:
        return
# Evaluation script fragment: run the AM + LM pipeline over the first 10
# test samples and report the cumulative character error rate.
data_args.batch_size = 1
test_data = get_data(data_args)  # get_data is a class
# 4. Run the test-------------------------------------------
am_batch = test_data.get_am_batch()
word_num = 0
word_error_num = 0
for i in range(10):
    print('\n the ', i, 'th example.')
    # Run the trained acoustic model on one utterance.
    inputs, _ = next(am_batch)
    x = inputs['the_inputs']
    y = test_data.pny_lst[i]
    result = am.model.predict(x, steps=1)
    # Convert the numeric CTC output into pinyin text.
    _, text = decode_ctc(result, train_data.am_vocab)
    text = ' '.join(text)
    print('文本結果:', text)
    print('原文結果:', ' '.join(y))
    with sess.as_default():
        # Skip empty tokens: split(' ') can yield '' and
        # pny_vocab.index('') would raise ValueError.
        text = text.strip('\n').split(' ')
        x = np.array([train_data.pny_vocab.index(pny)
                      for pny in text if pny != ''])
        x = x.reshape(1, -1)
        preds = sess.run(lm.preds, {lm.x: x})
        label = test_data.han_lst[i]
        got = ''.join(train_data.han_vocab[idx] for idx in preds[0])
        print('原文漢字:', label)
        print('识别結果:', got)
        # Cap per-sample errors at the label length so CER stays <= 1.
        word_error_num += min(len(label), GetEditDistance(label, got))
        word_num += len(label)
print('詞错誤率:', word_error_num / word_num)
# --- AM-only evaluation: build the acoustic model, load trained weights,
# and score one test sample (pinyin-level error rate, no language model). ---
am_args = am_hparams()
am_args.vocab_size = len(am_vocab)
am = Am(am_args)
print('loading acoustic model...')
am.ctc_model.load_weights(model_save_path)
data_args.data_type = 'test'
data_args.shuffle = False
data_args.batch_size = 1
test_data = get_data(data_args)
# 4. Run the test-------------------------------------------
am_batch = test_data.get_am_batch()
word_num = 0
word_error_num = 0
for i in range(1):
    print('\n the ', i, 'th example.')
    # Load the trained model output for one utterance and recognize it.
    inputs, _ = next(am_batch)
    x = inputs['the_inputs']
    y = test_data.pny_lst[i]
    result = am.model.predict(x, steps=1)
    # Convert the numeric CTC output into pinyin text.
    _, text = decode_ctc(result, am_vocab)
    text = ' '.join(text)
    print('文本结果:', text)
    print('原文结果:', ' '.join(y))
    # NOTE(review): y is a list of pinyin tokens and text.split(' ') is a
    # list of decoded tokens — presumably GetEditDistance accepts sequences,
    # not only strings; confirm against its definition.
    word_error_num += min(len(y), GetEditDistance(y, text.split(' ')))
    word_num += len(y)
print('词错误率:', word_error_num / word_num)
def predict(thefile):
    """Transcribe a wav file, splitting long audio into overlapping pieces.

    Files longer than ``max_len`` seconds are cut into pieces written under
    ./tmp/ and decoded one by one; shorter files are decoded directly.
    Results are printed, not returned. Relies on module-level globals:
    ``test_data``, ``am``, ``sess``, ``lm``, ``train_data``, ``decode_ctc``,
    ``wave``, ``os``, ``np``.
    """
    if not thefile.endswith('wav'):
        print("[Error*] The file is not in .wav format!")
        return  # bail out instead of trying to open a non-wav file anyway
    testfile = wave.open('./tmp/' + thefile, mode='rb')
    print("The input has {} channel(s)".format(testfile.getnchannels()))
    am_batch = ''
    framerate = testfile.getframerate()
    framenum = testfile.getnframes()
    length = framenum / framerate
    print("The length of {} is {} seconds.".format(thefile, length))
    max_len = 10
    if length > max_len:
        # Overlap consecutive pieces by half a piece so words cut at a
        # boundary still appear whole in one of the pieces.
        piece_len = (max_len // 3) * 2
        portion = piece_len * framerate
        n_pieces = int(length // piece_len + 1)
        print(
            "The file exceeds the max length of {} seconds and needs to be split into {} pieces"
            .format(max_len, n_pieces))
        for i in range(n_pieces):
            apiece = testfile.readframes(framerate * max_len)
            # Integer division: Wave_read.setpos expects an integer frame
            # position, `portion / 2` would pass a float.
            testfile.setpos(testfile.tell() - portion // 2)
            tmp = wave.open('./tmp/tmp{:04}.wav'.format(i), mode='wb')
            tmp.setnchannels(1)
            tmp.setframerate(16000)
            tmp.setsampwidth(2)
            tmp.writeframes(apiece)
            tmp.close()
        am_batch = test_data.get_dep_batch(os.listdir('./tmp/'))
        for i in range(n_pieces):
            inputs, _ = next(am_batch)
            x = inputs['the_inputs']
            result = am.model.predict(x, steps=1)
            # Convert the numeric CTC output into pinyin text.
            _, text = decode_ctc(result, train_data.am_vocab)
            text = ' '.join(text)
            print('文本结果:', text)
            with sess.as_default():
                # Skip empty tokens so pny_vocab.index cannot raise
                # ValueError on '' produced by split(' ').
                text = text.strip('\n').split(' ')
                x = np.array([train_data.pny_vocab.index(pny)
                              for pny in text if pny != ''])
                x = x.reshape(1, -1)
                preds = sess.run(lm.preds, {lm.x: x})
                got = ''.join(train_data.han_vocab[idx] for idx in preds[0])
                print('识别结果:', got)
    else:
        thelist = [thefile]
        am_batch = test_data.get_dep_batch(thelist)
        inputs, _ = next(am_batch)
        x = inputs['the_inputs']
        result = am.model.predict(x, steps=1)
        # Convert the numeric CTC output into pinyin text.
        _, text = decode_ctc(result, train_data.am_vocab)
        text = ' '.join(text)
        print('文本结果:', text)
        with sess.as_default():
            text = text.strip('\n').split(' ')
            x = np.array([train_data.pny_vocab.index(pny)
                          for pny in text if pny != ''])
            x = x.reshape(1, -1)
            preds = sess.run(lm.preds, {lm.x: x})
            got = ''.join(train_data.han_vocab[idx] for idx in preds[0])
            print('识别结果:', got)
def predict(thefile, conns, lock):
    """Transcribe a wav file and broadcast the result to connected clients.

    Long files are split into ``max_len``-second pieces under ./tmp/ and
    decoded piece by piece; each recognized hanzi string is sent to every
    socket in ``conns`` (guarded by ``lock``). Relies on module-level
    globals: ``test_data``, ``am``, ``sess``, ``lm``, ``train_data``,
    ``decode_ctc``, ``wave``, ``np``.
    """
    if not thefile.endswith('wav'):
        print("[Error*] The file is not in .wav format!")
        return  # bail out instead of trying to open a non-wav file anyway
    testfile = wave.open(thefile, mode='rb')
    print("The input has {} channel(s)".format(testfile.getnchannels()))
    am_batch = ''
    framerate = testfile.getframerate()
    framenum = testfile.getnframes()
    length = framenum / framerate
    print("The length of {} is {} seconds.".format(thefile, length))
    max_len = 10
    if length > max_len:
        piece_len = max_len  # (max_len // 3) * 2
        portion = piece_len * framerate
        n_pieces = int(length // piece_len + 1)
        print(
            "The file exceeds the max length of {} seconds and needs to be split into {} pieces"
            .format(max_len, n_pieces))
        filelist = []
        for i in range(n_pieces):
            apiece = testfile.readframes(framerate * max_len)
            filename = './tmp/tmp{:04}.wav'.format(i)
            tmp = wave.open(filename, mode='wb')
            tmp.setnchannels(1)
            tmp.setframerate(16000)
            tmp.setsampwidth(2)
            tmp.writeframes(apiece)
            tmp.close()
            filelist.append(filename)
        am_batch = test_data.get_dep_batch(filelist)
        for i in range(n_pieces):
            inputs, _ = next(am_batch)
            text = _decode_am(inputs['the_inputs'])
            print('%s: %s' % (filelist[i], text))
            got = _pinyin_to_hanzi(text)
            _broadcast(conns, lock, got)
            print('%s: %s' % (filelist[i], got))
    else:
        filelist = [thefile]
        am_batch = test_data.get_dep_batch(filelist)
        inputs, _ = next(am_batch)
        text = _decode_am(inputs['the_inputs'])
        print('%s: %s' % (filelist[0], text))
        got = _pinyin_to_hanzi(text)
        _broadcast(conns, lock, got)
        print('%s: %s' % (filelist[0], got))


def _decode_am(x):
    """Run the acoustic model on one feature batch; return pinyin text."""
    result = am.model.predict(x, steps=1)
    _, text = decode_ctc(result, train_data.am_vocab)
    return ' '.join(text)


def _pinyin_to_hanzi(text):
    """Map a space-separated pinyin string to hanzi via the language model."""
    with sess.as_default():
        # Skip empty tokens so pny_vocab.index cannot raise ValueError on
        # '' produced by split(' ').
        tokens = text.strip('\n').split(' ')
        x = np.array([train_data.pny_vocab.index(pny)
                      for pny in tokens if pny != ''])
        x = x.reshape(1, -1)
        preds = sess.run(lm.preds, {lm.x: x})
        return ''.join(train_data.han_vocab[idx] for idx in preds[0])


def _broadcast(conns, lock, got):
    """Best-effort send of the recognized text to every client socket."""
    # `with lock` guarantees the lock is released even if iterating conns
    # raises; the original acquire()/release() pair could deadlock peers.
    with lock:
        for conn in conns:
            try:
                conn.send((got + ':').encode('utf-8'))
            except Exception:
                # Dead client sockets are ignored on purpose (best effort).
                pass