    def __init__(self, args, dataset):
        self.args = args
        self.dataset = dataset
        self.sortedDict = SortedByCountsDict(dump_dir=self.args.vocab_dump_dir)
        self.acoustic_vocab_size, self.acoustic_vocab = Util.get_acoustic_vocab_list()
        self.language_vocab_size, self.language_vocab = Util.get_language_vocab_list()

        # Keep separate shallow copies so the LM-specific settings below do not leak
        # into the acoustic-model configuration (assumes `import copy` at module level).
        self.am_args, self.lm_args = copy.copy(args), copy.copy(args)
        self.am_args.data_length = None

        self.lm_args.lr = args.lm_lr
        self.lm_args.is_training = False
        self.lm_args.max_len = args.lm_max_len
        self.lm_args.hidden_units = args.lm_hidden_units
        self.lm_args.feature_dim = args.lm_feature_dim
        self.lm_args.num_heads = args.lm_num_heads
        self.lm_args.num_blocks = args.lm_num_blocks
        self.lm_args.position_max_length = args.lm_position_max_length
        self.lm_args.dropout_rate = args.lm_dropout_rate
    def predict(self, **data):
        # Create the prediction-result folder if it does not exist yet.
        if not os.path.exists(self.lm_args.PredResultFolder):
            os.mkdir(self.lm_args.PredResultFolder)
        audio_path = self.dataset.predict_data(**data)[0]
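        # Pipeline: audio file -> fbank features -> CNN-CTC acoustic model (pinyin)
        # -> Transformer language model (Chinese characters).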

        # Acoustic model
        am_model = CNNCTCModel(args=self.am_args, vocab_size=self.acoustic_vocab_size)
        am_model.load_model(os.path.join(self.am_args.AmModelFolder, self.am_args.am_ckpt))

        # Language model
        lm_model = TransformerModel(arg=self.lm_args, acoustic_vocab_size=self.acoustic_vocab_size,
                                    language_vocab_size=self.language_vocab_size)
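        # Limit the language-model session to roughly 10% of GPU memory.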
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.1)
        sess = tf.Session(graph=lm_model.graph, config=tf.ConfigProto(gpu_options=gpu_options))
        with lm_model.graph.as_default():
            saver = tf.train.Saver()
        latest = tf.train.latest_checkpoint(self.lm_args.LmModelFolder)
        saver.restore(sess, latest)

        # Acoustic model prediction
        # signal, sample_rate = sf.read(audio_path)
        signal, sample_rate = librosa.load(audio_path, sr=48000)
        signal = librosa.resample(y=signal, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000
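        # Compute filterbank (fbank) features and add a trailing channel axis so the
        # input matches the acoustic model's (time, freq, 1) layout.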
        fbank = Util.compute_fbank_from_api(signal, sample_rate)
        input_data = fbank.reshape([fbank.shape[0], fbank.shape[1], 1])
        print('input_data_length:{}'.format(len(input_data)))
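        # The acoustic model presumably downsamples the time axis by a factor of 8,
        # hence the // 8 when deriving the CTC input length.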
        data_length = input_data.shape[0] // 8 + 1

        predict, pinyin = Util.predict_pinyin(model=am_model, inputs=input_data, input_length=data_length,
                                              acoustic_vocab=self.acoustic_vocab)
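        # `predict` is the decoded pinyin index sequence and `pinyin` its
        # human-readable form, as returned by Util.predict_pinyin.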
        print('predict: {}'.format(predict))
        print('pinyin: {}'.format(pinyin))
        # Language model prediction
        with sess.as_default():
            py_in = predict.reshape(1, -1)
            han_pred = sess.run(lm_model.preds, {lm_model.x: py_in})
            han = ''.join(self.language_vocab[idx] for idx in han_pred[0])

        print(han)
    def predict(self, data_input, length, batch_size=1):
        """
        Return the CTC-decoded prediction for one utterance.
        :param data_input: input feature sequence (fbank frames)
        :param length: effective input length for CTC decoding
        :param batch_size: number of utterances per batch
        :return: CTC-decoded result
        """
        x_in = np.zeros((batch_size, 1600, self.feature_length, 1), dtype=np.float32)
        for i in range(batch_size):
            if len(data_input) > 1600:
                x_in[i, 0:1600] = data_input[:1600]
            else:
                x_in[i, 0:len(data_input)] = data_input

        # Run the padded input through the network; the prediction is the softmax
        # output of the network, shape = [1, 200, 1424].
        # It still has to be CTC-decoded to recover the label sequence.
        pred = self.model.predict(x_in, steps=1)

        return Util.decode_ctc(pred, length)
    def speech_predict(self, am_model, lm_model, predict_data, num, sess):
        """
        预测结果
        :param am_model:
        :param lm_model:
        :param predict_data:
        :param num:
        :param sess:
        :return:
        """

        if os.path.exists(self.lm_args.PredResultFolder):
            os.mkdir(self.lm_args.PredResultFolder)

        num_data = len(predict_data.pny_lst)
        length = predict_data.data_length
        if length is None:
            length = num_data
        ran_num = random.randint(0, length - 1)
        words_num, word_error_num, han_num, han_error_num = 0, 0, 0, 0
        data = ''
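        # Starting from a random index, walk through `num` consecutive samples,
        # wrapping around the dataset with the modulo below.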
        for i in range(num):
            print('\nExample %d:' % (i + 1))
            # Run the trained acoustic and language models on this sample.
            index = (ran_num + i) % num_data
            try:
                hanzi = predict_data.han_lst[index]
                hanzi_vec = [self.language_vocab.index(idx) for idx in hanzi]
                inputs, input_length, label, _ = Util.get_fbank_and_pinyin_data(index=index,
                                                                                acoustic_vocab=self.acoustic_vocab)
                pred, pinyin = Util.predict_pinyin(model=am_model, inputs=inputs, input_length=input_length,
                                                   acoustic_vocab=self.acoustic_vocab)
                y = predict_data.pny_lst[index]

                # Language model prediction
                with sess.as_default():
                    py_in = pred.reshape(1, -1)
                    han_pred = sess.run(lm_model.preds, {lm_model.x: py_in})
                    han = ''.join(self.language_vocab[idx] for idx in han_pred[0])
            except ValueError:
                continue
            print('Reference characters:', ''.join(hanzi))
            print('Reference pinyin:', ''.join(y))
            print('Predicted pinyin:', pinyin)
            print('Predicted characters:', han)
            data += 'Reference characters: ' + ''.join(hanzi) + '\n'
            data += 'Reference pinyin: ' + ''.join(y) + '\n'
            data += 'Predicted pinyin: ' + pinyin + '\n'
            data += 'Predicted characters: ' + han + '\n'

            # Pinyin edit distance
            words_n = label.shape[0]
            # Add this sentence's syllable count to the running total.
            words_num += words_n
            py_edit_distance = Util.GetEditDistance(label, pred)
            if py_edit_distance <= words_n:
                # Use the edit distance as the number of errors.
                word_error_num += py_edit_distance
            else:
                # The prediction inserted extra junk; cap the error count at the
                # sentence's own length.
                word_error_num += words_n

            # Chinese-character edit distance
            words_n = np.array(hanzi_vec).shape[0]
            # Add this sentence's character count to the running total.
            han_num += words_n
            han_edit_distance = Util.GetEditDistance(np.array(hanzi_vec), han_pred[0])
            if han_edit_distance <= words_n:
                # Use the edit distance as the number of errors.
                han_error_num += han_edit_distance
            else:
                # Cap the error count at the sentence's own length.
                han_error_num += words_n

        data += '*[Predict Result] Speech Recognition set pinyin accuracy ratio: ' + str(
            (1 - word_error_num / words_num) * 100) + '%'
        filename = str(datetime.datetime.now()) + '_' + str(num)
        with open(os.path.join(self.lm_args.PredResultFolder, filename), mode='w', encoding='utf-8') as f:
            f.writelines(data)
        print('*[Predict Result] Speech Recognition set pinyin accuracy ratio: ',
              (1 - word_error_num / words_num) * 100, '%')
        print('*[Predict Result] Speech Recognition set Chinese character accuracy ratio: ',
              (1 - han_error_num / han_num) * 100, '%')
    def train_lm(self, train_labels, train_pinyins):
        """
        Train the Transformer language model.
        :param train_labels: Chinese-character label sequences
        :param train_pinyins: corresponding pinyin sequences
        :return:
        """
        hp = self.args
        hp.batch_size = self.args.lm_batch_size
        hp.epochs = self.args.lm_epochs
        hp.data_type = 'train'
        hp.max_len = self.args.lm_max_len
        hp.hidden_units = self.args.lm_hidden_units
        hp.is_training = self.args.lm_is_training
        hp.feature_dim = self.args.lm_feature_dim
        hp.num_heads = self.args.lm_num_heads
        hp.num_blocks = self.args.lm_num_blocks
        hp.position_max_length = self.args.lm_position_max_length
        hp.lr = self.args.lm_lr
        hp.dropout_rate = self.args.lm_dropout_rate

        epochs = hp.epochs
        lm_model = TransformerModel(
            arg=hp,
            acoustic_vocab_size=self.acoustic_vocab_size,
            language_vocab_size=self.language_vocab_size)

        batch_num = len(train_pinyins) // hp.batch_size
        with lm_model.graph.as_default():
            saver = tf.train.Saver(max_to_keep=50)
            config = tf.ConfigProto()
            # Allow the session to use up to 90% of GPU memory.
            config.gpu_options.per_process_gpu_memory_fraction = 0.9
        with tf.Session(graph=lm_model.graph, config=config) as sess:
            merged = tf.summary.merge_all()
            sess.run(tf.global_variables_initializer())
            if os.path.exists(hp.LmModelFolder):
                print('loading language model...')
                latest = tf.train.latest_checkpoint(hp.LmModelFolder)
                if latest is not None:
                    saver.restore(sess, latest)
            writer = tf.summary.FileWriter(hp.LmModelTensorboard,
                                           tf.get_default_graph())
            for k in range(epochs):
                total_loss = 0
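                # Util.get_lm_batch presumably yields (pinyin-index, character-index)
                # batch pairs for one pass over the training data.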
                batch = Util.get_lm_batch(args=hp,
                                          pny_lst=train_pinyins,
                                          han_lst=train_labels,
                                          acoustic_vocab=self.acoustic_vocab,
                                          language_vocab=self.language_vocab)
                for i in range(batch_num):
                    input_batch, label_batch = next(batch)
                    feed = {lm_model.x: input_batch, lm_model.y: label_batch}
                    cost, _ = sess.run([lm_model.mean_loss, lm_model.train_op],
                                       feed_dict=feed)
                    total_loss += cost
                    if i % 10 == 0:
                        print("epoch: %d step: %d/%d  train loss=6%f" %
                              (k + 1, i, batch_num, cost))
                        if i % 5000 == 0:
                            rs = sess.run(merged, feed_dict=feed)
                            writer.add_summary(rs, k * batch_num + i)
                print('epochs', k + 1, ': average loss = ',
                      total_loss / batch_num)
                saver.save(sess, os.path.join(hp.LmModelFolder, hp.lm_ckpt))
            writer.close()