def build_model(self, vocab_size=230):
    """Build the acoustic model and return its CTC-trainable Keras model.

    Args:
        vocab_size: size of the acoustic-model output vocabulary. Defaults
            to 230, preserving the previously hard-coded value; callers with
            access to the training data should pass
            ``len(train_data.am_vocab)`` instead (the original code carried
            this as a commented-out expression).

    Returns:
        The ``ctc_model`` attribute of the constructed ``Am`` instance.
    """
    am_args = am_hparams()
    # Previously hard-coded to 230; now parameterized (default keeps the
    # old behavior for existing callers).
    am_args.vocab_size = vocab_size
    am = Am(am_args)
    return am.ctc_model
# Inference/test script setup: loads the trained acoustic model (CNN-CTC)
# and the language model (Transformer) for decoding.
import difflib
import tensorflow as tf
import numpy as np
from utils import decode_ctc, GetEditDistance

# 0. Prepare the vocabularies needed for decoding. The parameters must match
#    those used during training; the vocabularies could also be saved to disk
#    and read back directly instead of being rebuilt here.
from utils import get_data, data_hparams
data_args = data_hparams()
train_data = get_data(data_args)

# 1. Acoustic model -----------------------------------
from model_speech.cnn_ctc import Am, am_hparams
am_args = am_hparams()
am_args.vocab_size = len(train_data.am_vocab)
am = Am(am_args)
print('loading acoustic model...')
am.ctc_model.load_weights('logs_am/model.h5')

# 2. Language model -------------------------------------------
from model_language.transformer import Lm, lm_hparams
lm_args = lm_hparams()
lm_args.input_vocab_size = len(train_data.pny_vocab)
lm_args.label_vocab_size = len(train_data.han_vocab)
lm_args.dropout_rate = 0.  # disable dropout for inference
print('loading language model...')
lm = Lm(lm_args)
# TF1-style session bound to the language model's dedicated graph; a Saver is
# created inside that graph so checkpoint variables can be restored later.
sess = tf.Session(graph=lm.graph)
with lm.graph.as_default():
    saver = tf.train.Saver()
def train_am(x=None, y=None, fit_epoch=10):
    """Train (or fine-tune) the acoustic model and export it for serving.

    Args:
        x: optional real-time audio input (a single ``np.ndarray`` or a list
           of them). When given, the model is fine-tuned on it instead of the
           full training set.
        y: label(s) matching ``x``; same single-vs-list convention as ``x``.
        fit_epoch: number of epochs used for the real-time fine-tuning path.

    Side effects: loads/saves weights under ``logs_am``, writes checkpoints,
    exports a frozen PB graph and a TF-Serving SavedModel. Relies on
    module-level globals: ``utils``, ``train_data``, ``dev_data``,
    ``batch_num``, ``epochs``, ``modelVersion``, ``keras``.
    """
    from model_speech.cnn_ctc import Am, am_hparams
    am_args = am_hparams()
    am_args.vocab_size = len(utils.pny_vocab)
    am_args.gpu_nums = 1
    am_args.lr = 0.0008
    am_args.is_training = True
    am = Am(am_args)
    # Resume from previously saved weights when they exist.
    if os.path.exists(os.path.join(utils.cur_path, 'logs_am', 'model.h5')):
        print('加载声学模型...')
        am.ctc_model.load_weights(
            os.path.join(utils.cur_path, 'logs_am', 'model.h5'))
    # Keep only the best checkpoint (by validation loss) per improvement.
    checkpoint = ModelCheckpoint(os.path.join(
        utils.cur_path, 'checkpoint', "model_{epoch:02d}-{val_loss:.2f}.h5"),
                                 monitor='val_loss',
                                 save_best_only=True)
    eStop = EarlyStopping()  # stop when the loss stops decreasing (after `patience` epochs)
    #tensorboard --logdir=/media/yangjinming/DATA/GitHub/AboutPython/AboutDL/语音识别/logs_am/tbLog/ --host=127.0.0.1
    #tensbrd = TensorBoard(log_dir=os.path.join(utils.cur_path,'logs_am','tbLog'))
    if x is not None:
        # Fine-tune on real-time audio to personalize the model.
        size = 1
        if type(x) == np.ndarray:
            # A single sample: wrap in lists before conversion.
            x, y = utils.real_time2data([x], [y])
        else:
            size = len(x)
            x, y = utils.real_time2data(x, y)
        am.ctc_model.fit(x=x, y=y, batch_size=size, epochs=fit_epoch)
    else:
        # Train from the prepared dataset; both batches are generators.
        batch = train_data.get_am_batch()
        dev_batch = dev_data.get_am_batch()
        validate_step = 200  # average the validation result over N steps
        history = am.ctc_model.fit_generator(batch,
                                             steps_per_epoch=batch_num,
                                             epochs=epochs,
                                             callbacks=[eStop, checkpoint],
                                             workers=1,
                                             use_multiprocessing=False,
                                             verbose=1,
                                             validation_data=dev_batch,
                                             validation_steps=validate_step)
    # NOTE(review): the source was recovered from collapsed lines, so the
    # exact indentation of everything below is reconstructed — it is placed
    # at function level (runs for both paths) because the final
    # `if x is None` guard on sess.close() only makes sense when this code
    # also runs in the real-time branch; confirm against the original file.
    am.ctc_model.save_weights(
        os.path.join(utils.cur_path, 'logs_am', 'model.h5'))
    # Export a frozen (constant-folded) PB graph of the acoustic model.
    #with keras.backend.get_session() as sess:
    sess = keras.backend.get_session()
    constant_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
        sess, sess.graph_def,
        output_node_names=['the_inputs', 'dense_2/truediv'])
    with tf.gfile.GFile(os.path.join(utils.cur_path, 'logs_am', 'amModel.pb'),
                        mode='wb') as f:
        f.write(constant_graph.SerializeToString())
    # Export a SavedModel for TF Serving; `modelVersion` is presumably a
    # module-level version string naming the export directory — verify.
    builder = tf.saved_model.builder.SavedModelBuilder(
        os.path.join(utils.cur_path, 'logs_am', modelVersion))
    model_signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={'input': am.inputs}, outputs={'output': am.outputs})
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING], {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            model_signature
        })
    builder.save()
    # Only close the shared Keras session in batch-training mode; in the
    # real-time path the session stays open for continued use by the caller.
    if x is None:
        sess.close()