Code example #1
File: test.py Project: StevenLai1994/AM_and_LM
def load_lm():
    # 2. Language model -------------------------------------------
    from model_language.transformer import Lm, lm_hparams

    lm_args = lm_hparams()
    lm_args.input_vocab_size = len(train_data.pny_vocab)
    lm_args.label_vocab_size = len(train_data.han_vocab)
    lm_args.dropout_rate = 0.
    print('loading language model...')
    lm = Lm(lm_args)
    sess = tf.Session(graph=lm.graph)
    with lm.graph.as_default():
        saver = tf.train.Saver()
    with sess.as_default():
        latest = tf.train.latest_checkpoint('logs_lm')
        saver.restore(sess, latest)
    return sess, lm
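For orientation, here is a minimal decoding sketch built on the returned session. It is not part of the project file: it assumes the transformer graph exposes `x` (pinyin ids in) and `preds` (character ids out) as in the other examples on this page, that numpy is imported as np, and that train_data.han_vocab maps ids back to characters.

# Hypothetical usage sketch (tensor and vocab names are assumptions):
sess, lm = load_lm()
pny_ids = np.array([[10, 52, 3]])  # an id-encoded pinyin sequence, batch of 1
han_ids = sess.run(lm.preds, {lm.x: pny_ids})
print(''.join(train_data.han_vocab[i] for i in han_ids[0]))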
Code example #2
    def __init__(self, test_flag=True):
        # 0. Prepare the dictionaries needed for decoding. The parameters must
        #    match training; the dictionaries could also be saved to disk and
        #    read back directly.
        self.test_flag = test_flag
        # print('Loading acoustic model...')
        if K_usePB:
            self.AM_sess = tf.Session()
            with tf.gfile.GFile(os.path.join(cur_path,'logs_am','amModel.pb'), 'rb') as f:  # load the frozen model
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                self.AM_sess.graph.as_default()
                tf.import_graph_def(graph_def, name='')  # import the computation graph
                self.AM_sess.run(tf.global_variables_initializer())  # initialization step (a no-op for a frozen graph, but harmless)
            self.AM_x = self.AM_sess.graph.get_tensor_by_name('the_inputs:0')  # must match the input name used when saving the model!
            self.AM_preds = self.AM_sess.graph.get_tensor_by_name('dense_2/truediv:0')
        else:
            from model_speech.cnn_ctc import Am, am_hparams
            am_args = am_hparams()
            am_args.vocab_size = len(pny_vocab)  # pitfall: this must match the vocabulary size used during training!
            self.am = Am(am_args)
            self.am.ctc_model.load_weights(os.path.join(cur_path,'logs_am','model.h5'))

        # print('Loading language model...')
        if tf_usePB:
            self.sess = tf.Session()
            with tf.gfile.GFile(os.path.join(cur_path,'logs_lm','lmModel.pb'), 'rb') as f:  # load the frozen model
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                self.sess.graph.as_default()
                tf.import_graph_def(graph_def, name='')  # import the computation graph
                self.sess.run(tf.global_variables_initializer())  # initialization step (a no-op for a frozen graph, but harmless)
            self.x = self.sess.graph.get_tensor_by_name('x:0')  # must match the input name used when saving the model!
            self.preds = self.sess.graph.get_tensor_by_name('preds:0')
        else:  # load from a ckpt checkpoint
            from model_language.transformer import Lm, lm_hparams
            lm_args = lm_hparams()
            lm_args.input_vocab_size = len(pny_vocab)
            lm_args.label_vocab_size = len(han_vocab)
            lm_args.dropout_rate = 0.
            self.lm = Lm(lm_args)
            self.sess = tf.Session(graph=self.lm.graph)
            with self.lm.graph.as_default():
                saver = tf.train.Saver()
            with self.sess.as_default():
                lmPath = tf.train.latest_checkpoint(os.path.join(cur_path,'logs_lm'))
                saver.restore(self.sess, lmPath)
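A hedged sketch of how the two halves of this class could be driven at inference time. The class name Recognizer, the feature helper compute_fbank, and the CTC decoder decode_ctc are placeholders rather than this project's actual names, and r.am.model assumes the Am class exposes an inference model alongside ctc_model; only the tensor handles come from the __init__ above.

# Hypothetical end-to-end call (helper names are placeholders):
r = Recognizer()
feats = compute_fbank('test.wav')                # acoustic features for one utterance
if K_usePB:
    am_out = r.AM_sess.run(r.AM_preds, {r.AM_x: feats})
else:
    am_out = r.am.model.predict(feats)
pny_ids = decode_ctc(am_out)                     # CTC-decode to pinyin ids
han_ids = r.sess.run(r.preds, {r.x: [pny_ids]})  # pinyin ids -> hanzi ids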
Code example #3
# 1. Acoustic model -----------------------------------
from model_speech.cnn_ctc import Am, am_hparams

am_args = am_hparams()  # initialize hyperparameters, e.g. the learning rate
# am_args.vocab_size = 230
am_args.vocab_size = len(train_data.am_vocab)   # set the vocabulary size
am = Am(am_args)        # build a model from the configured hyperparameters
print('loading acoustic model...')
am.ctc_model.load_weights('logs_am/model.h5')
am.ctc_model.summary()

# 2. Language model -------------------------------------------
from model_language.transformer import Lm, lm_hparams

lm_args = lm_hparams()
lm_args.input_vocab_size = len(train_data.pny_vocab)
lm_args.label_vocab_size = len(train_data.han_vocab)
lm_args.dropout_rate = 0.
print('loading language model...')
lm = Lm(lm_args)
sess = tf.Session(graph=lm.graph)
with lm.graph.as_default():
    saver = tf.train.Saver()
with sess.as_default():
    latest = tf.train.latest_checkpoint('logs_lm')
    saver.restore(sess, latest)

# 3. Prepare the test data. It does not have to match the training data; the
#    split is selected via data_args.data_type. This should really be 'test';
#    I used 'train' because the demo model is small, so 'test' shows no real
#    effect and produces words never seen in training.
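For context, the data preparation this note refers to usually looks like the sketch below in this project family; data_hparams/get_data and the variable name test_data are assumed helper names, and the exact fields may differ.

# Assumed data-preparation sketch for step 3:
data_args = data_hparams()
data_args.data_type = 'train'  # should be 'test' for a real evaluation (see note above)
data_args.shuffle = False
data_args.batch_size = 1
test_data = get_data(data_args)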
Code example #4
File: train.py Project: StevenLai1994/AM_and_LM
def train_lm(epochs):
    # 2. Language model training -------------------------------------------
    from model_language.transformer import Lm, lm_hparams
    lm_args = lm_hparams()
    lm_args.num_heads = 8
    lm_args.num_blocks = 6
    lm_args.input_vocab_size = len(train_data.pny_vocab)
    lm_args.label_vocab_size = len(train_data.han_vocab)
    lm_args.max_length = 100
    lm_args.hidden_units = 512
    lm_args.dropout_rate = 0.2
    lm_args.lr = 0.0003
    lm_args.is_training = True
    lm = Lm(lm_args)

    batch_num = len(train_data.wav_lst) // train_data.batch_size
    dev_batch_num = len(dev_data.wav_lst) // dev_data.batch_size
    pre_epoch = 0
    with lm.graph.as_default():
        saver = tf.train.Saver(max_to_keep=5)
    with tf.Session(graph=lm.graph) as sess:
        min_loss = 100  # default; overwritten below if a checkpoint is restored
        if os.path.exists('logs_lm/checkpoint'):
            # restore variables; the epoch and loss are parsed back out of the
            # checkpoint file name written by saver.save() below
            import re
            model_file = tf.train.latest_checkpoint('logs_lm')
            pre_epoch = int(re.findall(r'\d+', model_file)[0])
            min_loss = float(re.findall(r'\d+\.\d*', model_file)[0])
            print(
                "=====================restore latest save model, epoch=%d===loss=%.3f=================="
                % (pre_epoch, min_loss))
            saver.restore(sess, model_file)
        else:
            # initialize variables
            sess.run(tf.global_variables_initializer())
        merged = tf.summary.merge_all()
        writer = tf.summary.FileWriter('logs_lm/tensorboard',
                                       tf.get_default_graph())

        for k in range(pre_epoch, epochs):
            # training ==============================================
            total_loss = 0
            batch = train_data.get_lm_batch()
            for i in range(batch_num):
                input_batch, label_batch = next(batch)
                feed = {lm.x: input_batch, lm.y: label_batch}
                cost, _ = sess.run([lm.mean_loss, lm.train_op], feed_dict=feed)
                total_loss += cost
                if (k * batch_num + i) % 10 == 0:
                    rs = sess.run(merged, feed_dict=feed)
                    writer.add_summary(rs, k * batch_num + i)
                if (i + 1) % 500 == 0:
                    print(
                        'epoch:{:04d}, step:{:05d}, train_loss:{:0.3f}'.format(
                            k + 1, i + 1, total_loss / (i + 1)))
            epoch_loss = total_loss / batch_num
            print('epoch', k + 1, ': average loss =', epoch_loss)

            # validation ==============================================
            total_loss = 0
            dev_batch = dev_data.get_lm_batch()
            for i in range(dev_batch_num):
                input_batch, label_batch = next(dev_batch)
                feed = {lm.x: input_batch, lm.y: label_batch}
                # evaluate the loss only; running train_op here would also
                # update the weights on the dev set
                cost = sess.run(lm.mean_loss, feed_dict=feed)
                total_loss += cost
                if (i + 1) % 500 == 0:
                    print("dev_step:{:05d}, dev_loss:{:0.3f}".format(
                        i + 1, total_loss / (i + 1)))
            print('dev_loss:{:.3f}'.format(total_loss / dev_batch_num))
            if total_loss / dev_batch_num < min_loss:
                min_loss = total_loss / dev_batch_num
                print("save model")
                saver.save(
                    sess, 'logs_lm/model_epoch_{:04d}_val_loss_{:0.3f}'.format(
                        k + 1, min_loss))

        writer.close()
Code example #5
def train_lm(eStop=False):
    global epochs
    from model_language.transformer import Lm, lm_hparams
    import numpy as np
    lm_args = lm_hparams()
    lm_args.num_heads = 8
    lm_args.num_blocks = 6
    lm_args.input_vocab_size = len(utils.pny_vocab)
    lm_args.label_vocab_size = len(utils.han_vocab)
    lm_args.max_length = 100
    lm_args.hidden_units = 512
    lm_args.dropout_rate = 0.2
    lm_args.lr = 0.0003
    lm_args.is_training = True
    lm = Lm(lm_args)

    batch_num_list = list(range(batch_num))  # used by the tqdm progress bar to show progress within each epoch
    loss_list = []  # average loss per epoch; used for early stopping by comparing the mean of the last few values with the current average loss
    with lm.graph.as_default():
        saver = tf.train.Saver()
    with tf.Session(graph=lm.graph) as sess:
        merged = tf.summary.merge_all()
        sess.run(tf.global_variables_initializer())
        add_num = 0
        if os.path.exists(os.path.join(utils.cur_path,'logs_lm','checkpoint')):
            print('Loading language model...')
            latest = tf.train.latest_checkpoint(os.path.join(utils.cur_path,'logs_lm'))
            add_num = int(latest.split('_')[-1])
            saver.restore(sess, latest)
        #tensorboard --logdir=/media/yangjinming/DATA/GitHub/AboutPython/AboutDL/语音识别/logs_lm/tensorboard --host=127.0.0.1
        #writer = tf.summary.FileWriter(os.path.join(utils.cur_path,'logs_lm','tensorboard'), tf.get_default_graph())
        for k in range(epochs):
            total_loss = 0
            batch = train_data.get_lm_batch()
            for i in tqdm(batch_num_list, ncols=90):
                input_batch, label_batch = next(batch)
                feed = {lm.x: input_batch, lm.y: label_batch}
                cost, _ = sess.run([lm.mean_loss, lm.train_op], feed_dict=feed)
                total_loss += cost
                if (k * batch_num + i) % 10 == 0:
                    rs = sess.run(merged, feed_dict=feed)
                    #writer.add_summary(rs, k * batch_num + i)
            avg_loss = total_loss / batch_num
            print('epoch', k + 1, ': average loss =', avg_loss)

            loss_list.append(avg_loss)
            # early stopping: if the average loss is dropping by less than
            # roughly 0.0005 per epoch (current value vs. the mean of the
            # last five epochs minus 0.0015), stop training
            if eStop and len(loss_list) > 1 and avg_loss > np.mean(loss_list[-5:]) - 0.0015:
                # optionally ask first:
                # if input('The model has stopped improving; end training early? yes/no:') == 'yes':
                epochs = k + 1  # recorded for the checkpoint file name below
                break
        
        saver.save(sess, os.path.join(utils.cur_path,'logs_lm','model_%d' % (epochs + add_num)))
        #writer.close()
        # write out the serialized (frozen) PB file
        constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_node_names=['x', 'y', 'preds'])
        with tf.gfile.GFile(os.path.join(utils.cur_path,'logs_lm','lmModel.pb'), mode='wb') as f:
            f.write(constant_graph.SerializeToString())

        # exports for TF Serving. Three model formats are saved in total
        # (ckpt, frozen PB, SavedModel); pick whichever one you need.
        model_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={"pinyin": tf.saved_model.utils.build_tensor_info(lm.x)},
            outputs={"hanzi": tf.saved_model.utils.build_tensor_info(lm.preds)},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
        builder = tf.saved_model.builder.SavedModelBuilder(os.path.join(utils.cur_path,'logs_lm',modelVersion))
        builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.SERVING],
            {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: model_signature})
        builder.save()
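Once exported, the SavedModel can be reloaded for a quick sanity check with the standard TF1 loader. The tag, signature key, and the 'pinyin'/'hanzi' names follow the export code above; the input array itself is a stand-in.

# Sketch: reload the exported SavedModel in a fresh session (TF1 API).
with tf.Session(graph=tf.Graph()) as load_sess:
    meta = tf.saved_model.loader.load(
        load_sess, [tf.saved_model.tag_constants.SERVING],
        os.path.join(utils.cur_path, 'logs_lm', modelVersion))
    sig = meta.signature_def[
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    han_ids = load_sess.run(sig.outputs['hanzi'].name,
                            {sig.inputs['pinyin'].name: [[10, 52, 3]]})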