Example #1
 def SR_recognize(self, wavs, pre_type):
     hanzi = ''
     am_url = tf_serving_url.format('am')
     lm_url = tf_serving_url.format('lm')
     if ues_tf_serving:
         x, _, _ = utils.get_wav_Feature(wavsignal=wavs)
         try:
             receipt = requests.post(am_url,
                                     data='{"instances":%s}' %
                                     x.tolist()).json()['predictions'][0]
             receipt = np.array([receipt], dtype=np.float32)
         except Exception:
             return None, 'acoustic model request failed'
         _, pinyin = utils.decode_ctc(receipt, utils.pny_vocab)
         pinyin = [[
             utils.pny_vocab.index(p)
             for p in ' '.join(pinyin).strip('\n').split(' ')
         ]]
         if pre_type == 'H':
             #curl -d '{"instances": [[420,58]]}' -X POST http://localhost:8501/v1/models/lm:predict
             try:
                 hanzi = requests.post(lm_url,
                                       data='{"instances": %s}' %
                                       pinyin).json()['predictions'][0]
             except Exception:
                 return None, 'language model request failed'
             hanzi = ''.join(utils.han_vocab[idx] for idx in hanzi)
     else:
         if pre_type == 'H':
             pinyin, hanzi = yysb.predict(wavs)
         else:
             pinyin = yysb.predict(wavs, only_pinyin=True)
     return pinyin, hanzi
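
The TF Serving branch above assembles the request body with string formatting. Below is a minimal sketch of the same REST call (the helper name tf_serving_predict is hypothetical), assuming a model named 'am' served at localhost:8501 as in the commented curl line; passing a dict through requests' json= keyword lets the library handle serialization.

import numpy as np
import requests

def tf_serving_predict(features, url='http://localhost:8501/v1/models/am:predict'):
    # POST the feature tensor to a TensorFlow Serving REST endpoint and
    # return the first prediction as a float32 array.
    payload = {'instances': features.tolist()}
    response = requests.post(url, json=payload)
    response.raise_for_status()
    return np.array(response.json()['predictions'][0], dtype=np.float32)
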
Example #2
def test_batch():
    data_args.data_type = 'test'
    data_args.shuffle = False
    data_args.batch_size = 1
    test_data = get_data(data_args)
    am_batch = test_data.get_am_batch()
    word_num = 0
    word_error_num = 0
    for i in range(10):
        print('\n the ', i, 'th example.')
        # load the trained model and run recognition
        inputs, _ = next(am_batch)
        x = inputs['the_inputs']
        y = test_data.pny_lst[i]
        result = am.model.predict(x, steps=1)
        # convert the numeric output into text
        _, text = decode_ctc(result, train_data.am_vocab)
        text = ' '.join(text)
        print('text result:', text)
        print('reference:', ' '.join(y))
        with sess.as_default():
            text = text.strip('\n').split(' ')
            x = np.array([train_data.pny_vocab.index(pny) for pny in text])
            x = x.reshape(1, -1)
            preds = sess.run(lm.preds, {lm.x: x})
            label = test_data.han_lst[i]
            got = ''.join(train_data.han_vocab[idx] for idx in preds[0])
            print('reference characters:', label)
            print('recognition result:', got)
            word_error_num += min(len(label), GetEditDistance(label, got))
            word_num += len(label)
    print('word error rate:', word_error_num / word_num)
    sess.close()
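
GetEditDistance is called throughout these examples but never defined here. A minimal sketch of what it presumably computes, a plain Levenshtein (edit) distance between reference and hypothesis, following the argument order used above; the repo's own implementation may differ.

def GetEditDistance(reference, hypothesis):
    # Dynamic-programming Levenshtein distance; works on strings or lists of tokens.
    m, n = len(reference), len(hypothesis)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if reference[i - 1] == hypothesis[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                           dp[i][j - 1] + 1,         # insertion
                           dp[i - 1][j - 1] + cost)  # substitution
    return dp[m][n]
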
 def speech(self, filename):
     try:
         inputs, _ = self.data_processer.am_data_process(filename)
         audio_con = inputs['the_inputs']
         result = self.am.model.predict(audio_con, steps=1)
         _, text = decode_ctc(result, self.vocab.ampny_vocab)
         text = ' '.join(text)
         with self.sess.as_default():
             text = text.strip('\n').split(' ')
             x = np.array([self.vocab.pny_vocab.index(pny) for pny in text if pny != ''])
             if len(x) == 0:
                 x = np.array([120, 79, 1, 53])  # arbitrary filler added to avoid a warning when no pinyin was decoded; the values do not matter
             x = x.reshape(1, -1)
             preds = self.sess.run(self.lm.preds, {self.lm.x: x})
             ultimate_result = ''.join(self.vocab.han_vocab[idx] for idx in preds[0])
         return ultimate_result
     except KeyboardInterrupt:
         return
data_args.batch_size = 1
test_data = get_data(data_args)  # get_data is a class

# 4. run the test -------------------------------------------
am_batch = test_data.get_am_batch()
word_num = 0
word_error_num = 0
for i in range(10):
    print('\n the ', i, 'th example.')
    # load the trained model and run recognition
    inputs, _ = next(am_batch)
    x = inputs['the_inputs']
    y = test_data.pny_lst[i]
    result = am.model.predict(x, steps=1)
    # convert the numeric output into text
    _, text = decode_ctc(result, train_data.am_vocab)
    text = ' '.join(text)
    print('text result:', text)
    print('reference:', ' '.join(y))
    with sess.as_default():
        text = text.strip('\n').split(' ')
        x = np.array([train_data.pny_vocab.index(pny) for pny in text])
        x = x.reshape(1, -1)
        preds = sess.run(lm.preds, {lm.x: x})
        label = test_data.han_lst[i]
        got = ''.join(train_data.han_vocab[idx] for idx in preds[0])
        print('reference characters:', label)
        print('recognition result:', got)
        word_error_num += min(len(label), GetEditDistance(label, got))
        word_num += len(label)
print('word error rate:', word_error_num / word_num)
Example #5
am_args = am_hparams()
am_args.vocab_size = len(am_vocab)
am = Am(am_args)
print('loading acoustic model...')
am.ctc_model.load_weights(model_save_path)

data_args.data_type = 'test'
data_args.shuffle = False
data_args.batch_size = 1
test_data = get_data(data_args)

# 4. run the test -------------------------------------------
am_batch = test_data.get_am_batch()
word_num = 0
word_error_num = 0
for i in range(1):
    print('\n the ', i, 'th example.')
    # load the trained model and run recognition
    inputs, _ = next(am_batch)
    x = inputs['the_inputs']
    y = test_data.pny_lst[i]
    result = am.model.predict(x, steps=1)
    # convert the numeric output into text
    _, text = decode_ctc(result, am_vocab)
    text = ' '.join(text)
    print('text result:', text)
    print('reference:', ' '.join(y))

    word_error_num += min(len(y), GetEditDistance(y, text.split(' ')))
    word_num += len(y)
print('word error rate:', word_error_num / word_num)
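
decode_ctc is likewise used everywhere without its definition. A minimal sketch of a greedy CTC decode built on tf.keras.backend.ctc_decode, assuming result has shape (1, time_steps, vocab_size) and vocab maps output indices to pinyin tokens; the repo's helper may differ in its details.

import numpy as np
import tensorflow.keras.backend as K

def decode_ctc(result, vocab):
    # Greedy CTC decode: collapse repeats, drop blanks, then map indices to vocab entries.
    input_length = np.array([result.shape[1]])
    decoded, _ = K.ctc_decode(result, input_length, greedy=True)
    indices = K.get_value(decoded[0])[0]
    text = [vocab[i] for i in indices if i >= 0]  # ctc_decode pads with -1
    return indices, text
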
Example #6
def predict(thefile):
    if not thefile.endswith('wav'):
        print("[Error*] The file is not in .wav format!")
        return
    testfile = wave.open('./tmp/' + thefile, mode='rb')
    print("The input has {} channel(s)".format(testfile.getnchannels()))
    am_batch = ''
    framerate = testfile.getframerate()
    framenum = testfile.getnframes()
    length = framenum / framerate
    print("The length of {} is {} seconds.".format(thefile, length))
    max_len = 10
    if length > max_len:
        piece_len = (max_len // 3) * 2
        portion = piece_len * framerate
        n_pieces = length // piece_len + 1
        n_pieces = int(n_pieces)
        print(
            "The file exceeds the max length of {} seconds and needs to be split into {} pieces"
            .format(max_len, n_pieces))
        for i in range(n_pieces):
            apiece = testfile.readframes(framerate * max_len)
            testfile.setpos(testfile.tell() - portion // 2)  # use integer division: setpos() needs an integer frame index
            tmp = wave.open('./tmp/tmp{:04}.wav'.format(i), mode='wb')
            tmp.setnchannels(1)
            tmp.setframerate(16000)
            tmp.setsampwidth(2)
            tmp.writeframes(apiece)
            tmp.close()
        am_batch = test_data.get_dep_batch(os.listdir('./tmp/'))
        for i in range(n_pieces):
            inputs, _ = next(am_batch)
            x = inputs['the_inputs']
            #print(x.shape)
            #y = test_data.pny_lst[i]
            result = am.model.predict(x, steps=1)
            # convert the numeric output into text
            _, text = decode_ctc(result, train_data.am_vocab)
            text = ' '.join(text)
            print('text result:', text)
            #print('reference:', ' '.join(y))
            with sess.as_default():
                text = text.strip('\n').split(' ')
                x = np.array([train_data.pny_vocab.index(pny) for pny in text])
                x = x.reshape(1, -1)
                preds = sess.run(lm.preds, {lm.x: x})
                #label = test_data.han_lst[i]
                got = ''.join(train_data.han_vocab[idx] for idx in preds[0])
                #print('reference characters:', label)
                print('recognition result:', got)
                #word_error_num += min(len(label), GetEditDistance(label, got))

    else:
        thelist = [thefile]
        am_batch = test_data.get_dep_batch(thelist)
        inputs, _ = next(am_batch)
        x = inputs['the_inputs']
        #print(x.shape)
        #y = test_data.pny_lst[i]
        result = am.model.predict(x, steps=1)
        # convert the numeric output into text
        _, text = decode_ctc(result, train_data.am_vocab)
        text = ' '.join(text)
        print('text result:', text)
        #print('reference:', ' '.join(y))
        with sess.as_default():
            text = text.strip('\n').split(' ')
            x = np.array([train_data.pny_vocab.index(pny) for pny in text])
            x = x.reshape(1, -1)
            preds = sess.run(lm.preds, {lm.x: x})
            #label = test_data.han_lst[i]
            got = ''.join(train_data.han_vocab[idx] for idx in preds[0])
            #print('reference characters:', label)
            print('recognition result:', got)
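
The splitting loop above rewinds the read position with setpos() so consecutive pieces overlap. For reference, a simplified non-overlapping sketch of the same idea with the standard wave module (split_wav and the output directory are illustrative names), assuming a mono, 16 kHz, 16-bit source like the pieces written above.

import wave

def split_wav(path, piece_seconds=10, out_dir='./tmp'):
    # Cut a WAV file into fixed-length, non-overlapping pieces and return their paths.
    src = wave.open(path, mode='rb')
    framerate = src.getframerate()
    frames_per_piece = piece_seconds * framerate
    n_pieces = -(-src.getnframes() // frames_per_piece)  # ceiling division
    pieces = []
    for i in range(n_pieces):
        frames = src.readframes(frames_per_piece)
        out_path = '{}/piece{:04}.wav'.format(out_dir, i)
        dst = wave.open(out_path, mode='wb')
        dst.setnchannels(1)
        dst.setframerate(framerate)
        dst.setsampwidth(2)
        dst.writeframes(frames)
        dst.close()
        pieces.append(out_path)
    src.close()
    return pieces
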
def predict(thefile, conns, lock):
    if not thefile.endswith('wav'):
        print("[Error*] The file is not in .wav format!")
        return
    testfile = wave.open(thefile, mode='rb')
    print("The input has {} channel(s)".format(testfile.getnchannels()))
    am_batch = ''
    framerate = testfile.getframerate()
    framenum = testfile.getnframes()
    length = framenum / framerate
    print("The length of {} is {} seconds.".format(thefile, length))
    max_len = 10
    if length > max_len:
        piece_len = max_len  #(max_len // 3) * 2
        portion = piece_len * framerate
        n_pieces = length // piece_len + 1
        n_pieces = int(n_pieces)
        print(
            "The file exceeds the max length of {} seconds and needs to be split into {} pieces"
            .format(max_len, n_pieces))
        filelist = []
        for i in range(n_pieces):
            apiece = testfile.readframes(framerate * max_len)
            #testfile.setpos(testfile.tell()-portion)
            filename = './tmp/tmp{:04}.wav'.format(i)
            tmp = wave.open(filename, mode='wb')
            tmp.setnchannels(1)
            tmp.setframerate(16000)
            tmp.setsampwidth(2)
            tmp.writeframes(apiece)
            tmp.close()
            filelist.append(filename)
        #am_batch = test_data.get_dep_batch(os.listdir('./tmp/'))
        am_batch = test_data.get_dep_batch(filelist)
        #am_batch = test_data.get_dep_batch(os.listdir('/home/comp/15485625/data/speech/sp2chs/data_aishell/wav/test/'))
        for i in range(n_pieces):
            inputs, _ = next(am_batch)
            x = inputs['the_inputs']
            #print(x.shape)
            #y = test_data.pny_lst[i]
            result = am.model.predict(x, steps=1)
            # convert the numeric output into text
            _, text = decode_ctc(result, train_data.am_vocab)
            text = ' '.join(text)
            print('%s: %s' % (filelist[i], text))
            #print('reference:', ' '.join(y))
            with sess.as_default():
                text = text.strip('\n').split(' ')
                x = np.array([train_data.pny_vocab.index(pny) for pny in text])
                x = x.reshape(1, -1)
                preds = sess.run(lm.preds, {lm.x: x})
                #label = test_data.han_lst[i]
                got = ''.join(train_data.han_vocab[idx] for idx in preds[0])
                #print('reference characters:', label)
                #print('recognition result:', got)
                # modified_RecvMessage = data.decode('utf-8')
                #print(modified_RecvMessage)
                lock.acquire()
                for conn in conns:
                    try:
                        conn.send((got + ':').encode('utf-8'))
                    except Exception as e:
                        pass
                lock.release()
                print('%s: %s' % (filelist[i], got))
                #word_error_num += min(len(label), GetEditDistance(label, got))
                #word_num += len(label)

    else:
        filelist = [thefile]
        am_batch = test_data.get_dep_batch(filelist)
        inputs, _ = next(am_batch)
        x = inputs['the_inputs']
        #print(x.shape)
        #y = test_data.pny_lst[i]
        result = am.model.predict(x, steps=1)
        # convert the numeric output into text
        _, text = decode_ctc(result, train_data.am_vocab)
        text = ' '.join(text)
        print('%s: %s' % (filelist[0], text))
        #print('reference:', ' '.join(y))
        with sess.as_default():
            text = text.strip('\n').split(' ')
            x = np.array([train_data.pny_vocab.index(pny) for pny in text])
            x = x.reshape(1, -1)
            preds = sess.run(lm.preds, {lm.x: x})
            #label = test_data.han_lst[i]
            got = ''.join(train_data.han_vocab[idx] for idx in preds[0])
            #print('reference characters:', label)
            #print('recognition result:', got)
            # modified_RecvMessage = data.decode('utf-8')
            #print(modified_RecvMessage)
            lock.acquire()
            for conn in conns:
                try:
                    conn.send((got + ':').encode('utf-8'))
                except Exception as e:
                    pass
            lock.release()
            print('%s: %s' % (filelist[0], got))
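
Both branches of the second predict repeat the same send loop. A minimal standalone sketch of that broadcast pattern (the helper name broadcast is illustrative), assuming conns holds connected socket objects and lock is a threading.Lock shared with the accept loop.

def broadcast(text, conns, lock):
    # Push one recognition result to every connected client; silently skip
    # clients whose connection has dropped, as the original loop does.
    with lock:  # equivalent to the explicit lock.acquire()/lock.release() above
        for conn in list(conns):
            try:
                conn.send((text + ':').encode('utf-8'))
            except Exception:
                pass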