def TrainModel(self, datapath, epoch = 2, save_step = 1000, batch_size = 32, filename = 'model_speech/speech_model2'):
    '''
    Train the model on the 'train' data list.

    Parameters:
        datapath: path where the data is stored
        epoch: number of training epochs
        save_step: number of steps passed to each fit_generator call;
            the model is saved after every such call
        batch_size: batch size handed to the data generator
        filename: default save file name, without the file extension
            (not read anywhere in this method)
    '''
    speech_data = DataSpeech(datapath)
    speech_data.LoadDataList('train')
    speech_data.GetDataNum()  # number of samples; the value is not used below

    for epoch_idx in range(epoch):
        print('[running] train epoch %d .' % epoch_idx)
        rounds_done = 0  # completed fit_generator calls in this epoch

        # Runs until the fit call lets a StopIteration escape.
        while True:
            try:
                print('[message] epoch %d . Have train datas %d+'%(epoch_idx, rounds_done*save_step))
                # data_genetator is a generator function; a fresh generator
                # is created for every round.
                batches = speech_data.data_genetator(batch_size, self.AUDIO_LENGTH)
                self._model.fit_generator(batches, save_step)
            except StopIteration:
                print('[error] generator error. please check data format.')
                break

            rounds_done += 1
            steps_so_far = rounds_done * save_step
            suffix = ''.join(('_e_', str(epoch_idx), '_step_', str(steps_so_far)))
            self.SaveModel(comment=suffix)
def TestModel(self, datapath, str_dataset='dev'):
    '''
    Evaluate the model on a single batch covering one data list.

    Parameters:
        datapath: path where the data is stored
        str_dataset: name of the data list to load (default 'dev')

    The value returned by test_on_batch is printed; nothing is returned.
    '''
    data = DataSpeech(datapath)
    data.LoadDataList(str_dataset)
    # Query the loaded instance, not the class: GetDataNum needs the
    # instance whose data list was just loaded.
    num_data = data.GetDataNum()

    try:
        # One batch sized to the whole data list.
        # NOTE(review): TrainModel also passes self.AUDIO_LENGTH as a second
        # argument to data_genetator; confirm the default used here matches
        # the model's input length.
        gen = data.data_genetator(num_data)
        # Pull one batch from the generator; unpacking the generator object
        # itself would not yield an (X, y) pair.
        X, y = next(gen)
        r = self._model.test_on_batch(X, y)
        print(r)
    except StopIteration:
        print('[Error] Model Test Error. please check data format.')