if __name__ == '__main__':
    path = '/search/speech/hubo/git/tf-code-acoustics/chain_source_7300/8-cegs-scp/'
    path = '/search/speech/hubo/git/tf-code-acoustics/chain_source_7300/'
    conf_dict = {'batch_size': 64,
                 'skip_offset': 0,
                 'skip_frame': 3,
                 'shuffle': False,
                 'queue_cache': 2,
                 'io_thread_num': 5}
    feat_trans_file = '../conf/final.feature_transform'
    feat_trans = FeatureTransform()
    feat_trans.LoadTransform(feat_trans_file)
    io_read = KaldiDataReadParallel()
    #io_read.Initialize(conf_dict, scp_file=path+'cegs.all.scp_0',
    io_read.Initialize(conf_dict, scp_file=path + 'cegs.1.scp',
                       feature_transform=feat_trans, criterion='chain')
    io_read.Reset(shuffle=False)

    def Gen():
        while True:
            inputs = io_read.GetInput()
            if inputs[0] is not None:
                indexs, in_labels, weights, statesinfo, num_states = inputs[3]
                yield (inputs[0], inputs[1], inputs[2], indexs,
                       in_labels, weights, statesinfo, num_states)
            else:
                print("-----end io----")
                break
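    # Hedged usage sketch (not part of the original script): a Gen()-style
    # generator like the one above is usually consumed through
    # tf.data.Dataset.from_generator in TensorFlow 1.x instead of a hand-rolled
    # feeding loop. The output dtypes below are assumptions inferred from the
    # chain trainer's placeholders (features and fst weights as float32,
    # index/label/state-count tensors as int32).
    #
    #   dataset = tf.data.Dataset.from_generator(
    #       Gen,
    #       output_types=(tf.float32, tf.float32, tf.int32, tf.int32,
    #                     tf.int32, tf.float32, tf.int32, tf.int32))
    #   next_batch = dataset.prefetch(2).make_one_shot_iterator().get_next()
    #   with tf.Session() as sess:
    #       while True:
    #           try:
    #               feats, deriv_weights, lengths, indexs, in_labels, weights, \
    #                   statesinfo, num_states = sess.run(next_batch)
    #           except tf.errors.OutOfRangeError:
    #               break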
if __name__ == '__main__':
    path = './'
    conf_dict = {'batch_size': 64,
                 'skip_offset': 0,
                 'shuffle': False,
                 'queue_cache': 2,
                 'io_thread_num': 2}
    feat_trans_file = '../../conf/final.feature_transform'
    feat_trans = FeatureTransform()
    feat_trans.LoadTransform(feat_trans_file)
    logging.basicConfig(filename='test.log')
    logging.getLogger().setLevel('INFO')
    io_read = KaldiDataReadParallel()
    io_read.Initialize(conf_dict, scp_file=path + 'scp',
                       feature_transform=feat_trans, criterion='chain')
    start = time.time()
    io_read.Reset(shuffle=False)
    batch_num = 0
    den_fst = '../../chain_source/den.fst'
    (den_indexs, den_in_labels, den_weights, den_statesinfo,
     den_num_states, den_start_state, laststatesuperfinal) = Fst2SparseMatrix(den_fst)
    leaky_hmm_coefficient = 0.1
    l2_regularize = 0.00005
    xent_regularize = 0.0
    delete_laststatesuperfinal = True
class TrainClass(object): ''' ''' def __init__(self, conf_dict): # configure paramtere self.conf_dict = conf_dict self.print_trainable_variables_cf = False self.use_normal_cf = False self.use_sgd_cf = True self.restore_training_cf = True self.checkpoint_dir_cf = None self.num_threads_cf = 1 self.queue_cache_cf = 100 self.task_index_cf = -1 self.grad_clip_cf = 5.0 self.feature_transfile_cf = None self.learning_rate_cf = 0.1 self.learning_rate_decay_steps_cf = 100000 self.learning_rate_decay_rate_cf = 0.96 self.batch_size_cf = 16 self.num_frames_batch_cf = 20 self.steps_per_checkpoint_cf = 1000 self.criterion_cf = 'ctc' # initial configuration parameter for attr in self.__dict__: if len(attr.split('_cf')) != 2: continue; key = attr.split('_cf')[0] if key in conf_dict.keys(): if key in strset or type(conf_dict[key]) is not str: self.__dict__[attr] = conf_dict[key] else: print('***************',key) self.__dict__[attr] = eval(conf_dict[key]) if self.feature_transfile_cf == None: logging.info('No feature_transfile,it must have.') sys.exit(1) feat_trans = FeatureTransform() feat_trans.LoadTransform(self.feature_transfile_cf) # init train file self.kaldi_io_nstream_train = KaldiDataReadParallel() self.input_dim = self.kaldi_io_nstream_train.Initialize(conf_dict, scp_file = conf_dict['tr_scp'], label = conf_dict['tr_label'], feature_transform = feat_trans, criterion = self.criterion_cf) # init cv file self.kaldi_io_nstream_cv = KaldiDataReadParallel() #self.kaldi_io_nstream_cv.Initialize(conf_dict, # scp_file = conf_dict['cv_scp'], label = conf_dict['cv_label'], # feature_transform = feat_trans, criterion = 'ctc') self.num_batch_total = 0 self.tot_lab_err_rate = 0.0 self.tot_num_batch = 0.0 logging.info(self.kaldi_io_nstream_train.__repr__()) #logging.info(self.kaldi_io_nstream_cv.__repr__()) # Initial input queue. 
self.input_queue = Queue.Queue(self.queue_cache_cf) self.acc_label_error_rate = [] self.all_lab_err_rate = [] self.num_save = 0 for i in range(5): self.all_lab_err_rate.append(1.1) self.num_batch = [] for i in range(self.num_threads_cf): self.acc_label_error_rate.append(1.0) self.num_batch.append(0) return # multi computers construct train graph def ConstructGraph(self, device, server): with tf.device(device): if 'cnn' in self.criterion_cf: self.X = tf.placeholder(tf.float32, [None, self.input_dim[0], self.input_dim[1], 1], name='feature') else: self.X = tf.placeholder(tf.float32, [None, self.batch_size_cf, self.input_dim], name='feature') if 'ctc' in self.criterion_cf: self.Y = tf.sparse_placeholder(tf.int32, name="labels") elif 'whole' in self.criterion_cf: self.Y = tf.placeholder(tf.int32, [self.batch_size_cf, None], name="labels") elif 'ce' in self.criterion_cf: self.Y = tf.placeholder(tf.int32, [self.batch_size_cf, self.num_frames_batch_cf], name="labels") self.seq_len = tf.placeholder(tf.int32,[None], name = 'seq_len') #self.learning_rate_var_tf = tf.Variable(float(self.learning_rate_cf), # trainable=False, name='learning_rate') # init global_step and learning rate decay criterion global_step=tf.train.get_or_create_global_step() exponential_decay = True if exponential_decay == True: self.learning_rate_var_tf = tf.train.exponential_decay( float(self.learning_rate_cf), global_step, self.learning_rate_decay_steps_cf, self.learning_rate_decay_rate_cf, staircase=True, name = 'learning_rate_exponential_decay') elif piecewise_constant == True: boundaries = [100000, 110000] values = [1.0, 0.5, 0.1] self.learning_rate_var_tf = tf.train.piecewise_constant( global_step, boundaries, values) elif inverse_time_decay == True: # decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step)) # decay_rate = 0.5 , decay_step = 100000 self.learning_rate_var_tf = tf.train.inverse_time_decay( float(self.learning_rate_cf), global_step, self.learning_rate_decay_steps_cf, self.learning_rate_decay_rate_cf, staircase=True, name = 'learning_rate_inverse_time_decay') if self.use_sgd_cf: optimizer = tf.train.GradientDescentOptimizer(self.learning_rate_var_tf) else: optimizer = tf.train.AdamOptimizer(learning_rate= self.learning_rate_var_tf, beta1=0.9, beta2=0.999, epsilon=1e-08) nnet_model = LstmModel(self.conf_dict) mean_loss = None loss = None rnn_state_zero_op = None rnn_keep_state_op = None if 'ctc' in self.criterion_cf: ctc_mean_loss, ctc_loss , label_error_rate, _ = nnet_model.CtcLoss(self.X, self.Y, self.seq_len) mean_loss = ctc_mean_loss loss = ctc_loss # elif 'ce' in self.criterion_cf and 'cnn' in self.criterion_cf and 'whole' in self.criterion_cf: # ce_mean_loss, ce_loss , label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.CeCnnBlstmLoss(self.X, self.Y, self.seq_len) # mean_loss = ce_mean_loss # loss = ce_loss elif 'ce' in self.criterion_cf: ce_mean_loss, ce_loss , label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.CeLoss(self.X, self.Y, self.seq_len) mean_loss = ce_mean_loss loss = ce_loss if self.use_sgd_cf and self.use_normal_cf: tvars = tf.trainable_variables() grads, _ = tf.clip_by_global_norm(tf.gradients( mean_loss, tvars), self.grad_clip_cf) train_op = optimizer.apply_gradients( zip(grads, tvars), global_step=tf.train.get_or_create_global_step()) else: train_op = optimizer.minimize(mean_loss, global_step=tf.train.get_or_create_global_step()) # set run operation self.run_ops = {'train_op':train_op, 'mean_loss':mean_loss, 'loss':loss, 
'label_error_rate':label_error_rate, 'rnn_keep_state_op':rnn_keep_state_op, 'rnn_state_zero_op':rnn_state_zero_op} # set initial parameter self.init_para = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer()) #tmp_variables = tf.trainable_variables() #self.saver = tf.train.Saver(tmp_variables, max_to_keep=100) self.total_variables = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]) logging.info('total parameters : %d' % self.total_variables) # set gpu option gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0) # session config sess_config = tf.ConfigProto(intra_op_parallelism_threads=self.num_threads_cf, inter_op_parallelism_threads=self.num_threads_cf, allow_soft_placement=True, log_device_placement=False,gpu_options=gpu_options) global_step = tf.train.get_or_create_global_step() # add saver hook self.saver = tf.train.Saver(max_to_keep=50, sharded=True, allow_empty=True) scaffold = tf.train.Scaffold(saver = self.saver) self.sess = tf.train.MonitoredTrainingSession( master = server.target, is_chief = (self.task_index_cf==0), checkpoint_dir = self.checkpoint_dir_cf, scaffold= scaffold, hooks=None, chief_only_hooks=None, save_checkpoint_secs=None, save_summaries_steps=self.steps_per_checkpoint_cf, save_summaries_secs=None, config=sess_config, stop_grace_period_secs=120, log_step_count_steps=100, max_wait_secs=7200, save_checkpoint_steps=self.steps_per_checkpoint_cf) # summary_dir = self.checkpoint_dir_cf + "_summary_dir") ''' sv = tf.train.Supervisor(is_chief=(self.task_index_cf==0), global_step=global_step, init_op = self.init_para, logdir = self.checkpoint_dir_cf, saver=self.saver, save_model_secs=600, checkpoint_basename='model.ckpt') self.sess = sv.prepare_or_wait_for_session(server.target, config=sess_config) ''' return def SaveTextModel(self): if self.print_trainable_variables_cf == True: ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir_cf) if ckpt and ckpt.model_checkpoint_path: print_trainable_variables(self.sess, ckpt.model_checkpoint_path+'.txt') def InputFeat(self, input_lock): while True: input_lock.acquire() ''' if 'ctc' in self.criterion_cf or 'whole' in self.criterion_cf: if 'cnn' in self.criterion_cf: feat,label,length = self.kaldi_io_nstream.CnnLoadNextNstreams() else: feat,label,length = self.kaldi_io_nstream.WholeLoadNextNstreams() if length is None: break print(np.shape(feat),np.shape(label), np.shape(length)) if len(label) != self.batch_size_cf: break if 'ctc' in self.criterion_cf: sparse_label = sparse_tuple_from(label) self.input_queue.put((feat,sparse_label,length)) else: self.input_queue.put((feat,label,length)) elif 'ce' in self.criterion_cf: if 'cnn' in self.criterion_cf: feat_array, label_array, length_array = self.kaldi_io_nstream.CnnSliceLoadNextNstreams() else: feat_array, label_array, length_array = self.kaldi_io_nstream.SliceLoadNextNstreams() if length_array is None: break print(np.shape(feat_array),np.shape(label_array), np.shape(length_array)) if len(label_array[0]) != self.batch_size_cf: break self.input_queue.put((feat_array, label_array, length_array)) ''' feat,label,length = self.kaldi_io_nstream.LoadBatch() if length is None: break print(np.shape(feat),np.shape(label), np.shape(length)) sys.stdout.flush() if 'ctc' in self.criterion_cf: sparse_label = sparse_tuple_from(label) self.input_queue.put((feat,sparse_label,length)) else: self.input_queue.put((feat,label,length)) self.num_batch_total += 1 # if self.num_batch_total % 3000 == 0: # self.SaveModel() # self.AdjustLearnRate() 
print('total_batch_num**********',self.num_batch_total,'***********') input_lock.release() self.input_queue.put((None, None, None)) def ThreadInputFeatAndLab(self): input_thread = [] input_lock = threading.Lock() for i in range(1): input_thread.append(threading.Thread(group=None, target=self.InputFeat, args=(input_lock,),name='read_thread'+str(i))) for thr in input_thread: logging.info('ThreadInputFeatAndLab start') thr.start() return input_thread def SaveModel(self): while True: time.sleep(1.0) if self.input_queue.empty(): checkpoint_path = os.path.join(self.checkpoint_dir_cf, str(self.num_batch_total)+'_model'+'.ckpt') logging.info('save model: '+checkpoint_path+ ' --- learn_rate: ' + str(self.sess.run(self.learning_rate_var_tf))) self.saver.save(self.sess, checkpoint_path) break # if current label error rate less then previous five def AdjustLearnRate(self): curr_lab_err_rate = self.GetAverageLabelErrorRate() logging.info("current label error rate : %f" % curr_lab_err_rate) all_lab_err_rate_len = len(self.all_lab_err_rate) for i in range(all_lab_err_rate_len): if curr_lab_err_rate < self.all_lab_err_rate[i]: break if i == len(self.all_lab_err_rate)-1: self.DecayLearningRate(0.8) logging.info('learn_rate decay to '+str(self.sess.run(self.learning_rate_var_tf))) self.all_lab_err_rate[self.num_save%all_lab_err_rate_len] = curr_lab_err_rate self.num_save += 1 def DecayLearningRate(self, lr_decay_factor): learning_rate_decay_op = self.learning_rate_var_tf.assign(tf.multiply(self.learning_rate_var_tf, lr_decay_factor)) self.sess.run(learning_rate_decay_op) logging.info('learn_rate decay to '+str(self.sess.run(self.learning_rate_var_tf))) logging.info('lr_decay_factor is '+str(lr_decay_factor)) # get restore model number def GetNum(self,str): return int(str.split('/')[-1].split('_')[0]) # train_loss is a open train or cv . def TrainLogic(self, device, shuffle = False, train_loss = True, skip_offset = 0): if train_loss == True: logging.info('TrainLogic train start.') logging.info('Start global step is %d---learn_rate is %f' % (self.sess.run(tf.train.get_or_create_global_step()), self.sess.run(self.learning_rate_var_tf))) self.kaldi_io_nstream = self.kaldi_io_nstream_train # set run operation if 'ctc' in self.criterion_cf or 'whole' in self.criterion_cf: run_op = {'train_op':self.run_ops['train_op'], 'label_error_rate': self.run_ops['label_error_rate'], 'mean_loss':self.run_ops['mean_loss'], 'loss':self.run_ops['loss']} elif 'ce' in self.criterion_cf: run_op = {'train_op':self.run_ops['train_op'], 'label_error_rate': self.run_ops['label_error_rate'], 'mean_loss':self.run_ops['mean_loss'], 'loss':self.run_ops['loss'], 'rnn_keep_state_op':self.run_ops['rnn_keep_state_op'], 'rnn_state_zero_op':self.run_ops['rnn_state_zero_op']} else: assert 'No train criterion.' 
else: logging.info('TrainLogic cv start.') self.kaldi_io_nstream = self.kaldi_io_nstream_cv run_op = {'label_error_rate':self.run_ops['label_error_rate'], 'mean_loss':self.run_ops['mean_loss']} # reset io and start input thread self.kaldi_io_nstream.Reset(shuffle = shuffle, skip_offset = skip_offset) threadinput = self.ThreadInputFeatAndLab() time.sleep(3) with tf.device(device): if 'ctc' in self.criterion_cf or 'whole' in self.criterion_cf: self.WholeTrainFunction(0, run_op, 'train_ctc_thread_hubo') elif 'ce' in self.criterion_cf: self.SliceTrainFunction(0, run_op, 'train_ce_thread_hubo') tmp_label_error_rate = self.GetAverageLabelErrorRate() logging.info("current averagelabel error rate : %f" % tmp_label_error_rate) logging.info('learn_rate is '+str(self.sess.run(self.learning_rate_var_tf))) if train_loss == True: self.AdjustLearnRate() logging.info('TrainLogic train end.') else: logging.info('TrainLogic cv end.') # End input thread for i in range(len(threadinput)): threadinput[i].join() self.ResetAccuracy() return tmp_label_error_rate def WholeTrainFunction(self, gpu_id, run_op, thread_name): logging.info('******start WholeTrainFunction******') total_curr_error_rate = 0.0 num_batch = 0 self.acc_label_error_rate[gpu_id] = 0.0 self.num_batch[gpu_id] = 0 #print_trainable_variables(self.sess, 'save.model.txt') while True: time1=time.time() feat, label, length = self.GetFeatAndLabel() if feat is None: logging.info('train thread end : %s' % thread_name) break time2=time.time() feed_dict = {self.X : feat, self.Y : label, self.seq_len : length} time3 = time.time() calculate_return = self.sess.run(run_op, feed_dict = feed_dict) time4 = time.time() print("thread_name: ", thread_name, num_batch, " time:",time2-time1,time3-time2,time4-time3,time4-time1) print('label_error_rate:',calculate_return['label_error_rate']) print('mean_loss:',calculate_return['mean_loss']) num_batch += 1 total_curr_error_rate += calculate_return['label_error_rate'] self.acc_label_error_rate[gpu_id] += calculate_return['label_error_rate'] self.num_batch[gpu_id] += 1 if self.num_batch[gpu_id] % int(self.steps_per_checkpoint_cf/50) == 0: logging.info("Batch: %d current averagelabel error rate : %f" % (int(self.steps_per_checkpoint_cf/50), total_curr_error_rate / int(self.steps_per_checkpoint_cf/50))) total_curr_error_rate = 0.0 logging.info("Batch: %d current total averagelabel error rate : %f" % (self.num_batch[gpu_id], self.acc_label_error_rate[gpu_id] / self.num_batch[gpu_id])) logging.info('******end TrainFunction******') def SliceTrainFunction(self, gpu_id, run_op, thread_name): logging.info('******start TrainFunction******') total_acc_error_rate = 0.0 num_batch = 0 self.acc_label_error_rate[gpu_id] = 0.0 self.num_batch[gpu_id] = 0 while True: time1=time.time() feat, label, length = self.GetFeatAndLabel() if feat is None: logging.info('train thread end : %s' % thread_name) break time2=time.time() self.sess.run(run_op['rnn_state_zero_op']) for i in range(len(feat)): time3 = time.time() feed_dict = {self.X : feat[i], self.Y : label[i], self.seq_len : length[i]} time4 = time.time() run_need_op = {'train_op':run_op['train_op'], 'mean_loss':run_op['mean_loss'], 'loss':run_op['loss'], 'rnn_keep_state_op':run_op['rnn_keep_state_op'], 'label_error_rate':run_op['label_error_rate']} calculate_return = self.sess.run(run_need_op, feed_dict = feed_dict) time5 = time.time() print("thread_name: ", thread_name, num_batch, " time:",time4-time3,time5-time4) print('label_error_rate:',calculate_return['label_error_rate']) 
print('mean_loss:',calculate_return['mean_loss']) print("thread_name: ", thread_name, num_batch, " time:",time2-time1,time3-time2,time4-time3,time4-time1) num_batch += 1 total_acc_error_rate += calculate_return['label_error_rate'] self.acc_label_error_rate[gpu_id] += calculate_return['label_error_rate'] self.num_batch[gpu_id] += 1 if self.num_batch[gpu_id] % int(self.steps_per_checkpoint_cf/50) == 0: logging.info("Batch: %d current averagelabel error rate : %f" % (self.num_batch[gpu_id], self.acc_label_error_rate[gpu_id] / self.num_batch[gpu_id])) logging.info('******end TrainFunction******') def GetFeatAndLabel(self): return self.input_queue.get() def GetAverageLabelErrorRate(self): tot_label_error_rate = 0.0 tot_num_batch = 0 for i in range(self.num_threads_cf): tot_label_error_rate += self.acc_label_error_rate[i] tot_num_batch += self.num_batch[i] if tot_num_batch == 0: average_label_error_rate = 1.0 else: average_label_error_rate = tot_label_error_rate / tot_num_batch self.tot_lab_err_rate += tot_label_error_rate self.tot_num_batch += tot_num_batch #self.ResetAccuracy(tot_reset = False) return average_label_error_rate def GetTotLabErrRate(self): return self.tot_lab_err_rate/self.tot_num_batch def ResetAccuracy(self, tot_reset = True): for i in range(len(self.acc_label_error_rate)): self.acc_label_error_rate[i] = 0.0 self.num_batch[i] = 0 if tot_reset: self.tot_lab_err_rate = 0 self.tot_num_batch = 0 for i in range(5): self.all_lab_err_rate.append(1.1) self.num_save = 0
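# Hedged illustration (not from the repo): the '*_cf' attribute convention used in
# TrainClass.__init__ above maps keys of conf_dict onto typed defaults, eval()-ing
# string values for parameters that are not declared in strset. The helper name
# below is made up for illustration only; it relies on nothing but builtins.
def _sketch_apply_cf_config(obj, conf_dict, strset=()):
    # For every attribute ending in '_cf', look up the same name without the
    # suffix in conf_dict; declared string keys are copied as-is, everything
    # else read from a text config is eval()-ed back into numbers/booleans.
    for attr in list(obj.__dict__):
        if len(attr.split('_cf')) != 2:
            continue
        key = attr.split('_cf')[0]
        if key in conf_dict:
            value = conf_dict[key]
            if key in strset or not isinstance(value, str):
                setattr(obj, attr, value)
            else:
                setattr(obj, attr, eval(value))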
class TrainClass(object): ''' ''' def __init__(self, conf_dict): # configure paramtere self.conf_dict = conf_dict self.print_trainable_variables_cf = False self.use_normal_cf = False self.use_sgd_cf = True self.optimizer_cf = 'GD' self.use_sync_cf = False self.use_clip_cf = False self.restore_training_cf = True self.checkpoint_dir_cf = None self.num_threads_cf = 1 self.queue_cache_cf = 100 self.task_index_cf = -1 self.grad_clip_cf = 5.0 self.feature_transfile_cf = None self.learning_rate_cf = 0.1 self.learning_rate_decay_steps_cf = 100000 self.learning_rate_decay_rate_cf = 0.96 self.l2_scale_cf = 0.00005 self.batch_size_cf = 16 self.num_frames_batch_cf = 20 self.reset_global_step_cf = True self.steps_per_checkpoint_cf = 1000 self.criterion_cf = 'ctc' self.silence_phones = [] # initial configuration parameter for attr in self.__dict__: if len(attr.split('_cf')) != 2: continue key = attr.split('_cf')[0] if key in conf_dict.keys(): if key in strset or type(conf_dict[key]) is not str: self.__dict__[attr] = conf_dict[key] else: print('***************', key) self.__dict__[attr] = eval(conf_dict[key]) if self.feature_transfile_cf == None: logging.info('No feature_transfile,it must have.') sys.exit(1) feat_trans = FeatureTransform() feat_trans.LoadTransform(self.feature_transfile_cf) # init train file self.kaldi_io_nstream_train = KaldiDataReadParallel() self.input_dim = self.kaldi_io_nstream_train.Initialize( self.conf_dict, scp_file=conf_dict['tr_scp'], label=conf_dict['tr_label'], lat_scp_file=conf_dict['lat_scp_file'], feature_transform=feat_trans, criterion=self.criterion_cf) # init cv file self.kaldi_io_nstream_cv = KaldiDataReadParallel() #self.kaldi_io_nstream_cv.Initialize(conf_dict, # scp_file = conf_dict['cv_scp'], label = conf_dict['cv_label'], # feature_transform = feat_trans, criterion = 'ctc') self.num_batch_total = 0 self.tot_lab_err_rate = 0.0 self.tot_num_batch = 0.0 logging.info(self.kaldi_io_nstream_train.__repr__()) #logging.info(self.kaldi_io_nstream_cv.__repr__()) # Initial input queue. 
#self.input_queue = Queue.Queue(self.queue_cache_cf) self.acc_label_error_rate = [] self.all_lab_err_rate = [] self.num_save = 0 for i in range(5): self.all_lab_err_rate.append(1.1) self.num_batch = [] for i in range(self.num_threads_cf): self.acc_label_error_rate.append(1.0) self.num_batch.append(0) return # multi computers construct train graph def ConstructGraph(self, device, server): with tf.device(device): if 'cnn' in self.criterion_cf: self.X = tf.placeholder( tf.float32, [None, self.input_dim[0], self.input_dim[1], 1], name='feature') else: self.X = tf.placeholder( tf.float32, [None, self.batch_size_cf, self.input_dim], name='feature') if 'ctc' in self.criterion_cf: self.Y = tf.sparse_placeholder(tf.int32, name="labels") elif 'whole' in self.criterion_cf: self.Y = tf.placeholder(tf.int32, [self.batch_size_cf, None], name="labels") elif 'ce' in self.criterion_cf: self.Y = tf.placeholder( tf.int32, [self.batch_size_cf, self.num_frames_batch_cf], name="labels") elif 'chain' in self.criterion_cf: self.Y = tf.placeholder(tf.float32, [self.batch_size_cf, None], name="labels") # if 'mmi' in self.criterion_cf or 'smbr' in self.criterion_cf or 'mpfe' in self.criterion_cf: self.indexs = tf.placeholder(tf.int32, [self.batch_size_cf, None, 2], name="indexs") self.pdf_values = tf.placeholder(tf.int32, [self.batch_size_cf, None], name="pdf_values") self.lm_ws = tf.placeholder(tf.float32, [self.batch_size_cf, None], name="lm_ws") self.am_ws = tf.placeholder(tf.float32, [self.batch_size_cf, None], name="am_ws") self.statesinfo = tf.placeholder(tf.int32, [self.batch_size_cf, None, 2], name="statesinfo") self.num_states = tf.placeholder(tf.int32, [self.batch_size_cf], name="num_states") self.lattice = [ self.indexs, self.pdf_values, self.lm_ws, self.am_ws, self.statesinfo, self.num_states ] elif 'chain' in self.criterion_cf: self.indexs = tf.placeholder(tf.int32, [self.batch_size_cf, None, 2], name="indexs") self.in_labels = tf.placeholder(tf.int32, [self.batch_size_cf, None], name="in_labels") self.weights = tf.placeholder(tf.float32, [self.batch_size_cf, None], name="weights") self.statesinfo = tf.placeholder(tf.int32, [self.batch_size_cf, None, 2], name="statesinfo") self.num_states = tf.placeholder(tf.int32, [self.batch_size_cf], name="num_states") self.length = tf.placeholder(tf.int32, [None], name="length") self.fst = [ self.indexs, self.in_labels, self.weights, self.statesinfo, self.num_states ] self.seq_len = tf.placeholder(tf.int32, [None], name='seq_len') #self.learning_rate_var_tf = tf.Variable(float(self.learning_rate_cf), # trainable=False, name='learning_rate') # init global_step and learning rate decay criterion #self.global_step=tf.train.get_or_create_global_step() self.global_step = tf.Variable(0, trainable=False, name='global_step', dtype=tf.int64) exponential_decay = True piecewise_constant = False inverse_time_decay = False if exponential_decay == True: self.learning_rate_var_tf = tf.train.exponential_decay( float(self.learning_rate_cf), self.global_step, self.learning_rate_decay_steps_cf, self.learning_rate_decay_rate_cf, staircase=True, name='learning_rate_exponential_decay') elif piecewise_constant == True: boundaries = [100000, 110000] values = [1.0, 0.5, 0.1] self.learning_rate_var_tf = tf.train.piecewise_constant( self.global_step, boundaries, values) elif inverse_time_decay == True: # decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step)) # decay_rate = 0.5 , decay_step = 100000 self.learning_rate_var_tf = tf.train.inverse_time_decay( 
float(self.learning_rate_cf), self.global_step, self.learning_rate_decay_steps_cf, self.learning_rate_decay_rate_cf, staircase=True, name='learning_rate_inverse_time_decay') else: #self.learning_rate_var_tf = float(self.learning_rate_cf) #self.learning_rate_var_tf = tf.Variable(float(self.learning_rate_cf), trainable=False) self.learning_rate_var_tf = tf.convert_to_tensor( self.learning_rate_cf, name="learning_rate") # trainable=False, name='learning_rate') if self.optimizer_cf == 'GD': optimizer = tf.train.GradientDescentOptimizer( self.learning_rate_var_tf) elif self.optimizer_cf == 'Adam': optimizer = tf.train.AdamOptimizer( learning_rate=self.learning_rate_var_tf, beta1=0.9, beta2=0.999, epsilon=1e-08, use_locking=True) elif self.optimizer_cf == 'Adadelta': tf.train.AdadeltaOptimizer( learning_rate=self.learning_rate_var_tf, rho=0.95, epsilon=1e-08, use_locking=False, name='Adadelta') elif self.optimizer_cf == 'AdagradDA': tf.train.AdagradDAOptimizer( learning_rate=self.learning_rate_var_tf, global_step=self.global_step, initial_gradient_squared_accumulator_value=0.1, l1_regularization_strength=0.0, l2_regularization_strength=0.0, use_locking=False, name='AdagradDA') elif self.optimizer_cf == 'Adagrad': tf.train.AdagradOptimizer( learning_rate=self.learning_rate_var_tf, initial_accumulator_value=0.1, use_locking=False, name='Adagrad') else: logging.error("no this opyimizer.") sys.exit(1) # sync train if self.use_sync_cf: optimizer = tf.train.SyncReplicasOptimizer( optimizer, replicas_to_aggregate=32, total_num_replicas=32, use_locking=True) sync_replicas_hook = [ optimizer.make_session_run_hook( is_chief=(self.task_index_cf == 0)) ] logging.info("******use synchronization train******") else: sync_replicas_hook = None logging.info("******use asynchronization train******") nnet_model = LstmModel(self.conf_dict) mean_loss = None loss = None rnn_state_zero_op = None rnn_keep_state_op = None if 'ctc' in self.criterion_cf: ctc_mean_loss, ctc_loss, label_error_rate, _ = nnet_model.CtcLoss( self.X, self.Y, self.seq_len) mean_loss = ctc_mean_loss loss = ctc_loss # elif 'ce' in self.criterion_cf and 'cnn' in self.criterion_cf and 'whole' in self.criterion_cf: # ce_mean_loss, ce_loss , label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.CeCnnBlstmLoss(self.X, self.Y, self.seq_len) # mean_loss = ce_mean_loss # loss = ce_loss elif 'ce' in self.criterion_cf: ce_mean_loss, ce_loss, label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.CeLoss( self.X, self.Y, self.seq_len) mean_loss = ce_mean_loss loss = ce_loss elif 'mmi' in self.criterion_cf: mmi_mean_loss, mmi_loss, label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.MmiLoss( self.X, self.Y, self.seq_len, self.indexs, self.pdf_values, self.lm_ws, self.am_ws, self.statesinfo, self.num_states, old_acoustic_scale=0.0, acoustic_scale=0.083, time_major=True, drop_frames=True) mean_loss = mmi_mean_loss loss = mmi_loss elif 'smbr' in self.criterion_cf or 'mpfe' in self.criterion_cf: if 'smbr' in self.criterion_cf: criterion = 'smbr' else: criterion = 'mpfe' pdf_to_phone = self.kaldi_io_nstream_train.pdf_to_phone log_priors = self.kaldi_io_nstream_train.pdf_prior mpe_mean_loss, mpe_loss, label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.MpeLoss( self.X, self.Y, self.seq_len, self.indexs, self.pdf_values, self.lm_ws, self.am_ws, self.statesinfo, self.num_states, log_priors=log_priors, silence_phones=self.silence_phones, pdf_to_phone=pdf_to_phone, one_silence_class=True, criterion=criterion, 
old_acoustic_scale=0.0, acoustic_scale=0.083, time_major=True) mean_loss = mpe_mean_loss loss = mpe_loss elif 'chain' in self.criterion_cf: den_indexs, den_in_labels, den_weights, den_statesinfo, den_num_states, den_start_state, laststatesuperfinal = Fst2SparseMatrix( self.conf_dict['den_fst']) label_dim = self.conf_dict['label_dim'] delete_laststatesuperfinal = True l2_regularize = 0.00005 leaky_hmm_coefficient = 0.1 #xent_regularize = 0.025 xent_regularize = 0.0 #den_indexs = tf.convert_to_tensor(den_indexs, name='den_indexs') #den_in_labels = tf.convert_to_tensor(den_in_labels, name='den_in_labels') #den_weights = tf.convert_to_tensor(den_weights, name='den_weights') #den_statesinfo = tf.convert_to_tensor(den_statesinfo, name='den_statesinfo') #den_indexs = tf.make_tensor_proto(den_indexs) #den_in_labels = tf.make_tensor_proto(den_in_labels) #den_weights = tf.make_tensor_proto(den_weights) #den_statesinfo = tf.make_tensor_proto(den_statesinfo) den_indexs = np.reshape(den_indexs, [-1]).tolist() den_in_labels = np.reshape(den_in_labels, [-1]).tolist() den_weights = np.reshape(den_weights, [-1]).tolist() den_statesinfo = np.reshape(den_statesinfo, [-1]).tolist() if 'xent' in self.criterion_cf: xent_regularize = 0.025 chain_mean_loss, chain_loss, label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.ChainXentLoss( self.X, self.Y, self.indexs, self.in_labels, self.weights, self.statesinfo, self.num_states, self.length, label_dim, den_indexs, den_in_labels, den_weights, den_statesinfo, den_num_states, den_start_state, delete_laststatesuperfinal, l2_regularize, leaky_hmm_coefficient, xent_regularize) else: chain_mean_loss, chain_loss, label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.ChainLoss( self.X, self.Y, self.indexs, self.in_labels, self.weights, self.statesinfo, self.num_states, self.length, label_dim, den_indexs, den_in_labels, den_weights, den_statesinfo, den_num_states, den_start_state, delete_laststatesuperfinal, l2_regularize, leaky_hmm_coefficient, xent_regularize) mean_loss = chain_mean_loss loss = chain_loss else: logging.info("no criterion.") sys.exit(1) if self.use_sgd_cf: tvars = tf.trainable_variables() if self.use_normal_cf: apply_l2_regu = tf.add_n([tf.nn.l2_loss(v) for v in tvars ]) * self.l2_scale_cf #l2_regu = tf.contrib.layers.l2_regularizer(0.5) #lstm_vars = [ var for var in tvars if 'lstm' in var.name ] #apply_l2_regu = tf.contrib.layers.apply_regularization(l2_regu, lstm_vars) #apply_l2_regu = tf.contrib.layers.apply_regularization(l2_regu, tvars) mean_loss = mean_loss + apply_l2_regu #mean_loss = mean_loss + apply_l2_regu #grads_var = optimizer.compute_gradients(mean_loss+apply_l2_regu, # var_list = tvars) #grads = [ g for g,_ in grads_var ] if self.use_clip_cf: grads, gradient_norms = tf.clip_by_global_norm( tf.gradients(mean_loss, tvars), self.grad_clip_cf, use_norm=None) #grads, gradient_norms = tf.clip_by_global_norm(grads, # self.grad_clip_cf, # use_norm=None) train_op = optimizer.apply_gradients( zip(grads, tvars), global_step=self.global_step) else: train_op = optimizer.minimize(mean_loss, global_step=self.global_step) else: train_op = optimizer.minimize(mean_loss, global_step=self.global_step) # set run operation self.run_ops = { 'train_op': train_op, 'mean_loss': mean_loss, 'loss': loss, 'label_error_rate': label_error_rate, 'rnn_keep_state_op': rnn_keep_state_op, 'rnn_state_zero_op': rnn_state_zero_op } # set initial parameter self.init_para = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 
#tmp_variables = tf.trainable_variables() #self.saver = tf.train.Saver(tmp_variables, max_to_keep=100) self.total_variables = np.sum([ np.prod(v.get_shape().as_list()) for v in tf.trainable_variables() ]) logging.info('total parameters : %d' % self.total_variables) # set gpu option gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5) # session config sess_config = tf.ConfigProto( intra_op_parallelism_threads=self.num_threads_cf, inter_op_parallelism_threads=self.num_threads_cf, allow_soft_placement=True, log_device_placement=False, gpu_options=gpu_options) sess_config = tf.ConfigProto(intra_op_parallelism_threads=2, inter_op_parallelism_threads=5, allow_soft_placement=True, log_device_placement=False, gpu_options=gpu_options) if self.reset_global_step_cf and self.task_index_cf == 0: non_train_variables = [self.global_step] + tf.local_variables() else: non_train_variables = tf.local_variables() ready_for_local_init_op = tf.variables_initializer( non_train_variables) local_init_op = tf.report_uninitialized_variables( var_list=non_train_variables) # add saver hook self.saver = tf.train.Saver(max_to_keep=50, sharded=False, allow_empty=False) scaffold = tf.train.Scaffold( init_op=None, init_feed_dict=None, init_fn=None, ready_op=None, ready_for_local_init_op=ready_for_local_init_op, local_init_op=local_init_op, summary_op=None, saver=self.saver, copy_from_scaffold=None) self.sess = tf.train.MonitoredTrainingSession( master=server.target, is_chief=(self.task_index_cf == 0), checkpoint_dir=self.checkpoint_dir_cf, scaffold=scaffold, hooks=sync_replicas_hook, chief_only_hooks=None, save_checkpoint_secs=None, save_summaries_steps=self.steps_per_checkpoint_cf, save_summaries_secs=None, config=sess_config, stop_grace_period_secs=120, log_step_count_steps=100, max_wait_secs=600, save_checkpoint_steps=self.steps_per_checkpoint_cf) # summary_dir = self.checkpoint_dir_cf + "_summary_dir") #self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess) ''' sv = tf.train.Supervisor(is_chief=(self.task_index_cf==0), global_step=global_step, init_op = self.init_para, logdir = self.checkpoint_dir_cf, saver=self.saver, save_model_secs=600, checkpoint_basename='model.ckpt') self.sess = sv.prepare_or_wait_for_session(server.target, config=sess_config) ''' return def SaveTextModel(self): if self.print_trainable_variables_cf == True: ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir_cf) if ckpt and ckpt.model_checkpoint_path: print_trainable_variables(self.sess, ckpt.model_checkpoint_path + '.txt') # def InputFeat(self, input_lock): # while True: # input_lock.acquire() # ''' # if 'ctc' in self.criterion_cf or 'whole' in self.criterion_cf: # if 'cnn' in self.criterion_cf: # feat,label,length = self.kaldi_io_nstream.CnnLoadNextNstreams() # else: # feat,label,length = self.kaldi_io_nstream.WholeLoadNextNstreams() # if length is None: # break # print(np.shape(feat),np.shape(label), np.shape(length)) # if len(label) != self.batch_size_cf: # break # if 'ctc' in self.criterion_cf: # sparse_label = sparse_tuple_from(label) # self.input_queue.put((feat,sparse_label,length)) # else: # self.input_queue.put((feat,label,length)) # # elif 'ce' in self.criterion_cf: # if 'cnn' in self.criterion_cf: # feat_array, label_array, length_array = self.kaldi_io_nstream.CnnSliceLoadNextNstreams() # else: # feat_array, label_array, length_array = self.kaldi_io_nstream.SliceLoadNextNstreams() # if length_array is None: # break # print(np.shape(feat_array),np.shape(label_array), np.shape(length_array)) # if 
len(label_array[0]) != self.batch_size_cf: # break # self.input_queue.put((feat_array, label_array, length_array)) # ''' # feat,label,length = self.kaldi_io_nstream.LoadBatch() # if length is None: # break # print(np.shape(feat),np.shape(label), np.shape(length)) # sys.stdout.flush() # if 'ctc' in self.criterion_cf: # sparse_label = sparse_tuple_from(label) # self.input_queue.put((feat,sparse_label,length)) # else: # self.input_queue.put((feat,label,length)) # self.num_batch_total += 1 ## if self.num_batch_total % 3000 == 0: ## self.SaveModel() ## self.AdjustLearnRate() # print('total_batch_num**********',self.num_batch_total,'***********') # input_lock.release() # self.input_queue.put((None, None, None)) # # def ThreadInputFeatAndLab(self): # input_thread = [] # input_lock = threading.Lock() # for i in range(1): # input_thread.append(threading.Thread(group=None, target=self.InputFeat, # args=(input_lock,),name='read_thread'+str(i))) # # for thr in input_thread: # logging.info('ThreadInputFeatAndLab start') # thr.start() # # return input_thread # def SaveModel(self): # while True: # time.sleep(1.0) # if self.input_queue.empty(): # checkpoint_path = os.path.join(self.checkpoint_dir_cf, # str(self.num_batch_total)+'_model'+'.ckpt') # logging.info('save model: '+checkpoint_path+ # ' --- learn_rate: ' + # str(self.sess.run(self.learning_rate_var_tf))) # self.saver.save(self.sess, checkpoint_path) # break # if current label error rate less then previous five def AdjustLearnRate(self): curr_lab_err_rate = self.GetAverageLabelErrorRate() logging.info("current label error rate : %f" % curr_lab_err_rate) all_lab_err_rate_len = len(self.all_lab_err_rate) for i in range(all_lab_err_rate_len): if curr_lab_err_rate < self.all_lab_err_rate[i]: break if i == len(self.all_lab_err_rate) - 1: self.DecayLearningRate(0.8) logging.info('learn_rate decay to ' + str(self.sess.run(self.learning_rate_var_tf))) self.all_lab_err_rate[self.num_save % all_lab_err_rate_len] = curr_lab_err_rate self.num_save += 1 def DecayLearningRate(self, lr_decay_factor): learning_rate_decay_op = self.learning_rate_var_tf.assign( tf.multiply(self.learning_rate_var_tf, lr_decay_factor)) self.sess.run(learning_rate_decay_op) logging.info('learn_rate decay to ' + str(self.sess.run(self.learning_rate_var_tf))) logging.info('lr_decay_factor is ' + str(lr_decay_factor)) # get restore model number def GetNum(self, str): return int(str.split('/')[-1].split('_')[0]) # train_loss is a open train or cv . 
def TrainLogic(self, device, shuffle=False, train_loss=True, skip_offset=0): if train_loss == True: logging.info('TrainLogic train start.') logging.info('Start global step is %d---learn_rate is %f' % (self.sess.run(tf.train.get_or_create_global_step()), self.sess.run(self.learning_rate_var_tf))) self.kaldi_io_nstream = self.kaldi_io_nstream_train # set run operation if 'ctc' in self.criterion_cf or 'whole' in self.criterion_cf: run_op = { 'train_op': self.run_ops['train_op'], 'label_error_rate': self.run_ops['label_error_rate'], 'mean_loss': self.run_ops['mean_loss'], 'loss': self.run_ops['loss'] } elif 'ce' in self.criterion_cf: run_op = { 'train_op': self.run_ops['train_op'], 'label_error_rate': self.run_ops['label_error_rate'], 'mean_loss': self.run_ops['mean_loss'], 'loss': self.run_ops['loss'], 'rnn_keep_state_op': self.run_ops['rnn_keep_state_op'], 'rnn_state_zero_op': self.run_ops['rnn_state_zero_op'] } elif 'chain' in self.criterion_cf: run_op = { 'train_op': self.run_ops['train_op'], 'mean_loss': self.run_ops['mean_loss'], 'loss': self.run_ops['loss'] } else: assert 'No train criterion.' else: logging.info('TrainLogic cv start.') self.kaldi_io_nstream = self.kaldi_io_nstream_cv run_op = { 'label_error_rate': self.run_ops['label_error_rate'], 'mean_loss': self.run_ops['mean_loss'] } # reset io and start input thread self.kaldi_io_nstream.Reset(shuffle=shuffle, skip_offset=skip_offset) #threadinput = self.ThreadInputFeatAndLab() time.sleep(3) with tf.device(device): if 'ctc' in self.criterion_cf or 'whole' in self.criterion_cf or 'chain' in self.criterion_cf: self.WholeTrainFunction(0, run_op, 'train_ctc_thread_hubo') elif 'ce' in self.criterion_cf: self.SliceTrainFunction(0, run_op, 'train_ce_thread_hubo') else: logging.info('no criterion in train') sys.exit(1) tmp_label_error_rate = self.GetAverageLabelErrorRate() logging.info("current averagelabel error rate : %f" % tmp_label_error_rate) logging.info('learn_rate is ' + str(self.sess.run(self.learning_rate_var_tf))) if train_loss == True: self.AdjustLearnRate() logging.info('TrainLogic train end.') else: logging.info('TrainLogic cv end.') # End input thread self.kaldi_io_nstream.JoinInput() #for i in range(len(threadinput)): # threadinput[i].join() self.ResetAccuracy() return tmp_label_error_rate def WholeTrainFunction(self, gpu_id, run_op, thread_name): logging.info('******start WholeTrainFunction******') total_curr_error_rate = 0.0 num_batch = 0 self.acc_label_error_rate[gpu_id] = 0.0 self.num_batch[gpu_id] = 0 total_curr_mean_loss = 0.0 total_mean_loss = 0.0 #print_trainable_variables(self.sess, 'save.model.txt') while True: time1 = time.time() feat, label, length, lat_list = self.GetFeatAndLabel() if feat is None: logging.info('train thread end : %s' % thread_name) break time2 = time.time() if 'mmi' in self.criterion_cf or 'smbr' in self.criterion_cf or 'mpfe' in self.criterion_cf: feed_dict = { self.X: feat, self.Y: label, self.seq_len: length, self.indexs: lat_list[0], self.pdf_values: lat_list[1], self.lm_ws: lat_list[2], self.am_ws: lat_list[3], self.statesinfo: lat_list[4], self.num_states: lat_list[5] } elif 'chain' in self.criterion_cf: # length is valid_length is int # self.Y is deriv_weights feed_dict = { self.X: feat, self.Y: label, self.length: [length], self.indexs: lat_list[0], self.in_labels: lat_list[1], self.weights: lat_list[2], self.statesinfo: lat_list[3], self.num_states: lat_list[4] } else: feed_dict = {self.X: feat, self.Y: label, self.seq_len: length} time3 = time.time() calculate_return = 
self.sess.run(run_op, feed_dict=feed_dict) time4 = time.time() print("thread_name: ", thread_name, num_batch, " time:", time2 - time1, time3 - time2, time4 - time3, time4 - time1) if 'label_error_rate' in calculate_return.keys(): print('label_error_rate:', calculate_return['label_error_rate']) total_curr_error_rate += calculate_return['label_error_rate'] self.acc_label_error_rate[gpu_id] += calculate_return[ 'label_error_rate'] else: total_curr_error_rate += 0.0 self.acc_label_error_rate[gpu_id] += 0.0 print('mean_loss:', calculate_return['mean_loss']) print('loss:', calculate_return['loss']) if type(calculate_return['mean_loss']) is list: total_curr_mean_loss += calculate_return['mean_loss'][0] total_mean_loss += calculate_return['mean_loss'][0] else: total_curr_mean_loss += calculate_return['mean_loss'] total_mean_loss += calculate_return['mean_loss'] num_batch += 1 self.num_batch[gpu_id] += 1 if self.num_batch[gpu_id] % int( self.steps_per_checkpoint_cf / 50) == 0: logging.info( "Batch: %d current averagelabel error rate : %s, mean loss : %s" % (int(self.steps_per_checkpoint_cf / 50), str(total_curr_error_rate / int(self.steps_per_checkpoint_cf / 50)), str(total_curr_mean_loss / int(self.steps_per_checkpoint_cf / 50)))) total_curr_error_rate = 0.0 total_curr_mean_loss = 0.0 logging.info( "Batch: %d current total averagelabel error rate : %s, mean loss : %s" % (self.num_batch[gpu_id], str(self.acc_label_error_rate[gpu_id] / self.num_batch[gpu_id]), str(total_mean_loss / self.num_batch[gpu_id]))) logging.info('******end TrainFunction******') def SliceTrainFunction(self, gpu_id, run_op, thread_name): logging.info('******start SliceTrainFunction******') total_curr_error_rate = 0.0 num_batch = 0 self.acc_label_error_rate[gpu_id] = 0.0 self.num_batch[gpu_id] = 0 total_curr_mean_loss = 0.0 total_mean_loss = 0.0 num_sentence = 0 while True: time1 = time.time() feat, label, length, lat_list = self.GetFeatAndLabel() if feat is None: logging.info('train thread end : %s' % thread_name) break time2 = time.time() self.sess.run(run_op['rnn_state_zero_op']) for i in range(len(feat)): print('************input info**********:', np.shape(feat[i]), np.shape(label[i]), length[i], flush=True) time3 = time.time() feed_dict = { self.X: feat[i], self.Y: label[i], self.seq_len: length[i] } time4 = time.time() run_need_op = { 'train_op': run_op['train_op'], 'mean_loss': run_op['mean_loss'], 'loss': run_op['loss'], 'rnn_keep_state_op': run_op['rnn_keep_state_op'], 'label_error_rate': run_op['label_error_rate'] } calculate_return = self.sess.run(run_need_op, feed_dict=feed_dict) time5 = time.time() print("thread_name: ", thread_name, num_batch, " time:", time4 - time3, time5 - time4) print('label_error_rate:', calculate_return['label_error_rate']) print('mean_loss:', calculate_return['mean_loss']) total_curr_mean_loss += calculate_return['mean_loss'] total_mean_loss += calculate_return['mean_loss'] num_batch += 1 total_curr_error_rate += calculate_return['label_error_rate'] self.acc_label_error_rate[gpu_id] += calculate_return[ 'label_error_rate'] self.num_batch[gpu_id] += 1 if self.num_batch[gpu_id] % int( self.steps_per_checkpoint_cf / 50) == 0: logging.info( "Batch: %d current averagelabel error rate : %f, mean loss : %f" % (int(self.steps_per_checkpoint_cf / 50), total_curr_error_rate / int(self.steps_per_checkpoint_cf / 50), total_curr_mean_loss / int(self.steps_per_checkpoint_cf / 50))) total_curr_error_rate = 0.0 total_curr_mean_loss = 0.0 logging.info( "Batch: %d current total averagelabel error rate : %f, 
mean loss : %f" % (self.num_batch[gpu_id], self.acc_label_error_rate[gpu_id] / self.num_batch[gpu_id], total_mean_loss / self.num_batch[gpu_id])) # print batch sentence time info print("******thread_name: ", thread_name, num_sentence, "io time:", time2 - time1, "calculation time:", time5 - time1) num_sentence += 1 logging.info('******end SliceTrainFunction******') def GetFeatAndLabel(self): return self.kaldi_io_nstream.GetInput() def GetAverageLabelErrorRate(self): tot_label_error_rate = 0.0 tot_num_batch = 0 for i in range(self.num_threads_cf): tot_label_error_rate += self.acc_label_error_rate[i] tot_num_batch += self.num_batch[i] if tot_num_batch == 0: average_label_error_rate = 1.0 else: average_label_error_rate = tot_label_error_rate / tot_num_batch self.tot_lab_err_rate += tot_label_error_rate self.tot_num_batch += tot_num_batch #self.ResetAccuracy(tot_reset = False) return average_label_error_rate def GetTotLabErrRate(self): return self.tot_lab_err_rate / self.tot_num_batch def ResetAccuracy(self, tot_reset=True): for i in range(len(self.acc_label_error_rate)): self.acc_label_error_rate[i] = 0.0 self.num_batch[i] = 0 if tot_reset: self.tot_lab_err_rate = 0 self.tot_num_batch = 0 for i in range(5): self.all_lab_err_rate.append(1.1) self.num_save = 0
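# Hedged driver sketch (not from the repo): one way a distributed worker script
# could wire this TrainClass to a tf.train.Server using the methods defined above
# (ConstructGraph, TrainLogic). Host names, job names, conf_dict parsing and the
# epoch count are placeholders.
#
#   cluster = tf.train.ClusterSpec({'ps': ['ps0:2222'],
#                                   'worker': ['worker0:2223', 'worker1:2224']})
#   server = tf.train.Server(cluster, job_name='worker', task_index=0)
#   device = tf.train.replica_device_setter(worker_device='/job:worker/task:0',
#                                           cluster=cluster)
#   train_class = TrainClass(conf_dict)
#   train_class.ConstructGraph(device, server)
#   for epoch in range(num_epochs):
#       train_class.TrainLogic(device, shuffle=True, train_loss=True,
#                              skip_offset=epoch)
#   train_class.TrainLogic(device, shuffle=False, train_loss=False)  # cv pass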
class TrainClass(object):
    '''
    '''
    def __init__(self, conf_dict):
        # configure parameters
        self.conf_dict = conf_dict
        self.print_trainable_variables_cf = False
        self.use_normal_cf = False
        self.use_sgd_cf = True
        self.restore_training_cf = True
        self.checkpoint_dir_cf = None
        self.num_threads_cf = 1
        self.queue_cache_cf = 100
        self.task_index_cf = -1
        self.grad_clip_cf = 5.0
        self.feature_transfile_cf = None
        self.learning_rate_cf = 0.001
        self.batch_size_cf = 10

        # initialize configuration parameters from conf_dict
        for attr in self.__dict__:
            if len(attr.split('_cf')) != 2:
                continue
            key = attr.split('_cf')[0]
            if key in conf_dict.keys():
                self.__dict__[attr] = conf_dict[key]

        if self.feature_transfile_cf is None:
            logging.info('feature_transfile is required but was not provided.')
            sys.exit(1)
        feat_trans = FeatureTransform()
        feat_trans.LoadTransform(self.feature_transfile_cf)

        # init train file
        self.kaldi_io_nstream_train = KaldiDataReadParallel()
        self.input_dim = self.kaldi_io_nstream_train.Initialize(conf_dict,
                scp_file=conf_dict['tr_scp'], label=conf_dict['tr_label'],
                feature_transform=feat_trans, criterion='ctc')
        # init cv file
        self.kaldi_io_nstream_cv = KaldiDataReadParallel()
        self.kaldi_io_nstream_cv.Initialize(conf_dict,
                scp_file=conf_dict['cv_scp'], label=conf_dict['cv_label'],
                feature_transform=feat_trans, criterion='ctc')

        self.num_batch_total = 0
        self.tot_lab_err_rate = 0.0
        self.tot_num_batch = 0.0
        logging.info(self.kaldi_io_nstream_train.__repr__())
        logging.info(self.kaldi_io_nstream_cv.__repr__())

        # Initial input queue.
        self.input_queue = Queue.Queue(self.queue_cache_cf)
        self.acc_label_error_rate = []
        self.all_lab_err_rate = []
        self.num_save = 0
        for i in range(5):
            self.all_lab_err_rate.append(1.1)
        self.num_batch = []
        for i in range(self.num_threads_cf):
            self.acc_label_error_rate.append(1.0)
            self.num_batch.append(0)
        return

    # construct the training graph for multi-machine (distributed) training
    def ConstructGraph(self, device, server):
        with tf.device(device):
            self.X = tf.placeholder(tf.float32, [None, None, self.input_dim], name='feature')
            self.Y = tf.sparse_placeholder(tf.int32, name="labels")
            self.seq_len = tf.placeholder(tf.int32, [None], name='seq_len')
            self.learning_rate_var_tf = tf.Variable(float(self.learning_rate_cf),
                    trainable=False, name='learning_rate')

            if self.use_sgd_cf:
                optimizer = tf.train.GradientDescentOptimizer(self.learning_rate_var_tf)
            else:
                optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate_var_tf,
                        beta1=0.9, beta2=0.999, epsilon=1e-08)

            nnet_model = LstmModel(self.conf_dict)
            ctc_mean_loss, ctc_loss, label_error_rate, decoded = nnet_model.CtcLoss(
                    self.X, self.Y, self.seq_len)

            if self.use_sgd_cf and self.use_normal_cf:
                tvars = tf.trainable_variables()
                grads, _ = tf.clip_by_global_norm(
                        tf.gradients(ctc_mean_loss, tvars), self.grad_clip_cf)
                train_op = optimizer.apply_gradients(
                        zip(grads, tvars),
                        global_step=tf.contrib.framework.get_or_create_global_step())
            else:
                train_op = optimizer.minimize(ctc_mean_loss)

            # set run operation
            self.run_ops = {'train_op': train_op,
                            'ctc_mean_loss': ctc_mean_loss,
                            'ctc_loss': ctc_loss,
                            'label_error_rate': label_error_rate}

            # set initial parameter
            self.init_para = tf.group(tf.global_variables_initializer(),
                                      tf.local_variables_initializer())
            tmp_variables = tf.trainable_variables()
            self.saver = tf.train.Saver(tmp_variables, max_to_keep=100)
            self.total_variables = np.sum([np.prod(v.get_shape().as_list())
                                           for v in tf.trainable_variables()])
            logging.info('total parameters : %d' % self.total_variables)

            # set gpu option
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
            # session config
            sess_config = tf.ConfigProto(intra_op_parallelism_threads=self.num_threads_cf,
                                         inter_op_parallelism_threads=self.num_threads_cf,
                                         allow_soft_placement=True,
                                         log_device_placement=False,
                                         gpu_options=gpu_options)

            global_step = tf.contrib.framework.get_or_create_global_step()
            sv = tf.train.Supervisor(is_chief=(self.task_index_cf == 0),
                                     global_step=global_step,
                                     init_op=self.init_para,
                                     logdir=self.checkpoint_dir_cf,
                                     saver=self.saver,
                                     save_model_secs=3600,
                                     checkpoint_basename='model.ckpt')
            self.sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
        return

    def SaveTextModel(self):
        if self.print_trainable_variables_cf == True:
            ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir_cf)
            if ckpt and ckpt.model_checkpoint_path:
                print_trainable_variables(self.sess, ckpt.model_checkpoint_path + '.txt')

    def InputFeat(self, input_lock):
        while True:
            input_lock.acquire()
            feat, label, length = self.kaldi_io_nstream.LoadNextNstreams()
            if length is None:
                break
            if len(label) != self.batch_size_cf:
                break
            sparse_label = sparse_tuple_from(label)
            self.input_queue.put((feat, sparse_label, length))
            self.num_batch_total += 1
            if self.num_batch_total % 3000 == 0:
                self.SaveModel()
                self.AdjustLearnRate()
            print('total_batch_num**********', self.num_batch_total, '***********')
            input_lock.release()
        self.input_queue.put((None, None, None))

    def ThreadInputFeatAndLab(self):
        input_thread = []
        input_lock = threading.Lock()
        for i in range(1):
            input_thread.append(threading.Thread(group=None, target=self.InputFeat,
                    args=(input_lock,), name='read_thread' + str(i)))
        for thr in input_thread:
            logging.info('ThreadInputFeatAndLab start')
            thr.start()
        return input_thread

    def SaveModel(self):
        while True:
            time.sleep(1.0)
            if self.input_queue.empty():
                checkpoint_path = os.path.join(self.checkpoint_dir_cf,
                        str(self.num_batch_total) + '_model' + '.ckpt')
                logging.info('save model: ' + checkpoint_path +
                             ' --- learn_rate: ' +
                             str(self.sess.run(self.learning_rate_var_tf)))
                self.saver.save(self.sess, checkpoint_path)
                break

    # decay the learning rate if the current label error rate is not lower than
    # any of the previous five
    def AdjustLearnRate(self):
        curr_lab_err_rate = self.GetAvergaeLabelErrorRate()
        logging.info("current label error rate : %f\n" % curr_lab_err_rate)
        all_lab_err_rate_len = len(self.all_lab_err_rate)
        for i in range(all_lab_err_rate_len):
            if curr_lab_err_rate < self.all_lab_err_rate[i]:
                break
        if i == len(self.all_lab_err_rate) - 1:
            self.decay_learning_rate(0.8)
        self.all_lab_err_rate[self.num_save % all_lab_err_rate_len] = curr_lab_err_rate
        self.num_save += 1

    def decay_learning_rate(self, lr_decay_factor):
        # Not present in this excerpt; reconstructed from the DecayLearningRate
        # method of the other trainer variants so AdjustLearnRate has a target.
        learning_rate_decay_op = self.learning_rate_var_tf.assign(
                tf.multiply(self.learning_rate_var_tf, lr_decay_factor))
        self.sess.run(learning_rate_decay_op)
        logging.info('learn_rate decay to ' + str(self.sess.run(self.learning_rate_var_tf)))

    # train_loss selects training (True) or cross-validation (False).
    def TrainLogic(self, device, shuffle=False, train_loss=True, skip_offset=0):
        if train_loss == True:
            self.kaldi_io_nstream = self.kaldi_io_nstream_train
            run_op = self.run_ops
        else:
            self.kaldi_io_nstream = self.kaldi_io_nstream_cv
            run_op = {'label_error_rate': self.run_ops['label_error_rate'],
                      'ctc_mean_loss': self.run_ops['ctc_mean_loss']}
        threadinput = self.ThreadInputFeatAndLab()
        time.sleep(3)
        logging.info('train start.')
        with tf.device(device):
            self.TrainFunction(0, run_op, 'train_thread_hubo')
        logging.info('train end.')
        threadinput[0].join()
        tmp_label_error_rate = self.GetAvergaeLabelErrorRate()
        self.kaldi_io_nstream.Reset(shuffle=shuffle, skip_offset=skip_offset)
        self.ResetAccuracy()
        return tmp_label_error_rate

    def TrainFunction(self, gpu_id, run_op, thread_name):
        logging.info('******start TrainFunction******')
        total_acc_error_rate = 0.0
        num_batch = 0
        self.acc_label_error_rate[gpu_id] = 0.0
        self.num_batch[gpu_id] = 0
        while True:
            time1 = time.time()
            feat, sparse_label, length = self.GetFeatAndLabel()
            if feat is None:
                logging.info('train thread end : %s' % thread_name)
                break
            time2 = time.time()
            feed_dict = {self.X: feat, self.Y: sparse_label, self.seq_len: length}
            time3 = time.time()
            calculate_return = self.sess.run(run_op, feed_dict=feed_dict)
            time4 = time.time()
            print("thread_name: ", thread_name, num_batch, " time:",
                  time2 - time1, time3 - time2, time4 - time3, time4 - time1)
            print('label_error_rate:', calculate_return['label_error_rate'])
            print('ctc_mean_loss:', calculate_return['ctc_mean_loss'])
            #print('ctc_loss:',calculate_return['ctc_loss'])
            num_batch += 1
            total_acc_error_rate += calculate_return['label_error_rate']
            self.acc_label_error_rate[gpu_id] += calculate_return['label_error_rate']
            self.num_batch[gpu_id] += 1
        logging.info('******end TrainFunction******')

    def GetFeatAndLabel(self):
        return self.input_queue.get()

    def GetAvergaeLabelErrorRate(self):
        tot_label_error_rate = 0.0
        tot_num_batch = 0
        for i in range(self.num_threads_cf):
            tot_label_error_rate += self.acc_label_error_rate[i]
            tot_num_batch += self.num_batch[i]
        if tot_num_batch == 0:
            average_label_error_rate = 1.0
        else:
            average_label_error_rate = tot_label_error_rate / tot_num_batch
        self.tot_lab_err_rate += tot_label_error_rate
        self.tot_num_batch += tot_num_batch
        self.ResetAccuracy(tot_reset=False)
        return average_label_error_rate

    def GetTotLabErrRate(self):
        return self.tot_lab_err_rate / self.tot_num_batch

    def ResetAccuracy(self, tot_reset=True):
        for i in range(len(self.acc_label_error_rate)):
            self.acc_label_error_rate[i] = 0.0
            self.num_batch[i] = 0
        if tot_reset:
            self.tot_lab_err_rate = 0
            self.tot_num_batch = 0
            for i in range(5):
                self.all_lab_err_rate.append(1.1)
            self.num_save = 0
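# Hedged sketch (the real sparse_tuple_from used above is imported from elsewhere
# in this repo): CTC labels are fed to tf.sparse_placeholder as an
# (indices, values, dense_shape) triple, and a conventional conversion looks like
# the following. It relies on the module-level numpy import already used here.
def _sketch_sparse_tuple_from(sequences, dtype=np.int32):
    # indices are (batch_index, time_index) pairs; values are the flattened
    # label ids; dense_shape is [num_sequences, max_sequence_length].
    indices = []
    values = []
    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)
    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    max_len = max(len(seq) for seq in sequences) if len(sequences) else 0
    dense_shape = np.asarray([len(sequences), max_len], dtype=np.int64)
    return indices, values, dense_shape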
class train_class(object): def __init__(self, conf_dict): self.nnet_conf = ProjConfig() self.nnet_conf.initial(conf_dict) self.kaldi_io_nstream = None feat_trans = FeatureTransform() feat_trans.LoadTransform(conf_dict['feature_transfile']) # init train file self.kaldi_io_nstream_train = KaldiDataReadParallel() self.input_dim = self.kaldi_io_nstream_train.Initialize( conf_dict, scp_file=conf_dict['scp_file'], label=conf_dict['label'], feature_transform=feat_trans, criterion='ce') # init cv file self.kaldi_io_nstream_cv = KaldiDataReadParallel() self.kaldi_io_nstream_cv.Initialize(conf_dict, scp_file=conf_dict['cv_scp'], label=conf_dict['cv_label'], feature_transform=feat_trans, criterion='ce') self.num_batch_total = 0 self.num_frames_total = 0 logging.info(self.nnet_conf.__repr__()) logging.info(self.kaldi_io_nstream_train.__repr__()) logging.info(self.kaldi_io_nstream_cv.__repr__()) self.print_trainable_variables = False if conf_dict.has_key('print_trainable_variables'): self.print_trainable_variables = conf_dict[ 'print_trainable_variables'] self.tf_async_model_prefix = conf_dict['checkpoint_dir'] self.num_threads = conf_dict['num_threads'] self.queue_cache = conf_dict['queue_cache'] self.input_queue = Queue.Queue(self.queue_cache) self.acc_label_error_rate = [] for i in range(self.num_threads): self.acc_label_error_rate.append(1.1) if conf_dict.has_key('use_normal'): self.use_normal = conf_dict['use_normal'] else: self.use_normal = False if conf_dict.has_key('use_sgd'): self.use_sgd = conf_dict['use_sgd'] else: self.use_sgd = True if conf_dict.has_key('restore_training'): self.restore_training = conf_dict['restore_training'] else: self.restore_training = False def get_num(self, str): return int(str.split('/')[-1].split('_')[0]) #model_48434.ckpt.final def construct_graph(self): with tf.Graph().as_default(): self.run_ops = [] #self.X = tf.placeholder(tf.float32, [None, None, self.input_dim], name='feature') print(self.nnet_conf.num_frames_batch, self.nnet_conf.batch_size, self.input_dim) self.X = tf.placeholder(tf.float32, [ self.nnet_conf.num_frames_batch, self.nnet_conf.batch_size, self.input_dim ], name='feature') #self.Y = tf.sparse_placeholder(tf.int32, name="labels") self.Y = tf.placeholder( tf.int32, [self.nnet_conf.batch_size, self.nnet_conf.num_frames_batch], name="labels") self.seq_len = tf.placeholder(tf.int32, [None], name='seq_len') self.learning_rate_var = tf.Variable(float( self.nnet_conf.learning_rate), trainable=False, name='learning_rate') if self.use_sgd: optimizer = tf.train.GradientDescentOptimizer( self.learning_rate_var) else: optimizer = tf.train.AdamOptimizer( learning_rate=self.learning_rate_var, beta1=0.9, beta2=0.999, epsilon=1e-08) for i in range(self.num_threads): with tf.device("/gpu:%d" % i): initializer = tf.random_uniform_initializer( -self.nnet_conf.init_scale, self.nnet_conf.init_scale) model = LSTM_Model(self.nnet_conf) mean_loss, ce_loss, rnn_keep_state_op, rnn_state_zero_op, label_error_rate, softval = model.ce_train( self.X, self.Y, self.seq_len) if self.use_sgd and self.use_normal: tvars = tf.trainable_variables() grads, _ = tf.clip_by_global_norm( tf.gradients(mean_loss, tvars), self.nnet_conf.grad_clip) train_op = optimizer.apply_gradients( zip(grads, tvars), global_step=tf.contrib.framework. 
get_or_create_global_step()) else: train_op = optimizer.minimize(mean_loss) run_op = { 'train_op': train_op, 'mean_loss': mean_loss, 'ce_loss': ce_loss, 'rnn_keep_state_op': rnn_keep_state_op, 'rnn_state_zero_op': rnn_state_zero_op, 'label_error_rate': label_error_rate, 'softval': softval } self.run_ops.append(run_op) tf.get_variable_scope().reuse_variables() gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95) self.sess = tf.Session(config=tf.ConfigProto( intra_op_parallelism_threads=self.num_threads, allow_soft_placement=True, log_device_placement=False, gpu_options=gpu_options)) init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) tmp_variables = tf.trainable_variables() self.saver = tf.train.Saver(tmp_variables, max_to_keep=100) #self.saver = tf.train.Saver(max_to_keep=100) if self.restore_training: self.sess.run(init) ckpt = tf.train.get_checkpoint_state( self.tf_async_model_prefix) if ckpt and ckpt.model_checkpoint_path: logging.info("restore training") self.saver.restore(self.sess, ckpt.model_checkpoint_path) self.num_batch_total = self.get_num( ckpt.model_checkpoint_path) if self.print_trainable_variables == True: print_trainable_variables( self.sess, ckpt.model_checkpoint_path + '.txt') sys.exit(0) logging.info('model:' + ckpt.model_checkpoint_path) logging.info('restore learn_rate:' + str(self.sess.run(self.learning_rate_var))) #print('*******************',self.num_batch_total) #time.sleep(3) #model_48434.ckpt.final #print("ckpt.model_checkpoint_path",ckpt.model_checkpoint_path) #print("self.tf_async_model_prefix",self.tf_async_model_prefix) #self.saver.restore(self.sess, self.tf_async_model_prefix) else: logging.info('No checkpoint file found') self.sess.run(init) logging.info('init learn_rate:' + str(self.sess.run(self.learning_rate_var))) else: self.sess.run(init) self.total_variables = np.sum([ np.prod(v.get_shape().as_list()) for v in tf.trainable_variables() ]) logging.info('total parameters : %d' % self.total_variables) def train_function(self, gpu_id, run_op, thread_name): total_acc_error_rate = 0.0 num_batch = 0 num_bptt = 0 while True: time1 = time.time() feat, label, length = self.get_feat_and_label() if feat is None: logging.info('train thread ok: %s' % thread_name) break time2 = time.time() print('******time:', time2 - time1, thread_name) self.sess.run(run_op['rnn_state_zero_op']) for i in range(len(feat)): feed_dict = { self.X: feat[i], self.Y: label[i], self.seq_len: length[i] } run_need_op = { 'train_op': run_op['train_op'], 'mean_loss': run_op['mean_loss'], 'ce_loss': run_op['ce_loss'], 'rnn_keep_state_op': run_op['rnn_keep_state_op'], 'label_error_rate': run_op['label_error_rate'] } time3 = time.time() calculate_return = self.sess.run(run_need_op, feed_dict=feed_dict) print('mean_loss:', calculate_return['mean_loss']) #print('ce_loss:',calculate_return['ce_loss']) #self.sess.run(run_op['rnn_keep_state_op']) time4 = time.time() print(num_batch, " time:", time4 - time3) print('label_error_rate:', calculate_return['label_error_rate']) total_acc_error_rate += calculate_return['label_error_rate'] num_bptt += 1 time5 = time.time() print(num_batch, " time:", time2 - time1, time5 - time2) num_batch += 1 #total_acc_error_rate += calculate_return['label_error_rate'] self.acc_label_error_rate[gpu_id] = total_acc_error_rate / num_bptt self.acc_label_error_rate[gpu_id] = total_acc_error_rate / num_bptt def cv_function(self, gpu_id, run_op, thread_name): total_acc_error_rate = 0.0 num_batch = 0 num_bptt = 0 while True: feat, label, 
length = self.get_feat_and_label() if feat is None: logging.info('cv ok : %s\n' % thread_name) break self.sess.run(run_op['rnn_state_zero_op']) for i in range(len(feat)): feed_dict = { self.X: feat[i], self.Y: label[i], self.seq_len: length[i] } run_need_op = { 'mean_loss': run_op['mean_loss'], 'ce_loss': run_op['ce_loss'], 'rnn_keep_state_op': run_op['rnn_keep_state_op'], 'label_error_rate': run_op['label_error_rate'] } #'softval':run_op['softval']} calculate_return = self.sess.run(run_need_op, feed_dict=feed_dict) total_acc_error_rate += calculate_return['label_error_rate'] # print('feat:',feat[i]) print('label_error_rate:', calculate_return['label_error_rate']) print('mean_loss:', calculate_return['mean_loss']) print('ce_loss', calculate_return['ce_loss']) # print(i,'softval', calculate_return['softval']) # print('rnn_keep_state_op', calculate_return['rnn_keep_state_op']) num_bptt += 1 num_batch += 1 self.acc_label_error_rate[gpu_id] = total_acc_error_rate / num_bptt self.acc_label_error_rate[gpu_id] = total_acc_error_rate / num_bptt def get_feat_and_label(self): return self.input_queue.get() def input_feat_and_label(self): feat, label, length = self.kaldi_io_nstream.LoadNextNstreams() if length is None: return False if len(label) != self.nnet_conf.batch_size: return False sparse_label = sparse_tuple_from(label) self.input_queue.put((feat, sparse_label, length)) self.num_batch_total += 1 for i in length: self.num_frames_total += i print('total_batch_num**********', self.num_batch_total, '***********') return True def input_ce_feat_and_label(self): feat_array, label_array, length_array = self.kaldi_io_nstream.SliceLoadNextNstreams( ) if length_array is None: return False if len(label_array[0]) != self.nnet_conf.batch_size: return False #process feature #sparse_label_array = [] #for lab in label: # sparse_label_array.append(sparse_tuple_from(lab)) self.input_queue.put((feat_array, label_array, length_array)) self.num_batch_total += 1 for batch_len in length_array: for i in batch_len: self.num_frames_total += i print('total_batch_num**********', self.num_batch_total, '***********') return True def cv_logic(self): self.kaldi_io_nstream = self.kaldi_io_nstream_cv train_thread = [] #first start cv thread for i in range(self.num_threads): train_thread.append( threading.Thread(group=None, target=self.cv_function, args=(i, self.run_ops[i], 'thread_hubo_' + str(i)), name='thread_hubo_' + str(i))) for thr in train_thread: thr.start() logging.info('start cv thread.') while True: # input data if self.input_ce_feat_and_label(): continue break logging.info('cv read feat ok') for thr in train_thread: self.input_queue.put((None, None, None)) while True: if self.input_queue.empty(): logging.info('cv is ok') break for thr in train_thread: thr.join() logging.info('join cv thread %s' % thr.name) tmp_label_error_rate = self.get_avergae_label_error_rate() self.kaldi_io_nstream.Reset() self.reset_acc() return tmp_label_error_rate def train_logic(self): self.kaldi_io_nstream = self.kaldi_io_nstream_train train_thread = [] #first start train thread for i in range(self.num_threads): #self.acc_label_error_rate.append(1.0) train_thread.append( threading.Thread(group=None, target=self.train_function, args=(i, self.run_ops[i], 'thread_hubo_' + str(i)), name='thread_hubo_' + str(i))) for thr in train_thread: thr.start() logging.info('start train thread ok.\n') all_lab_err_rate = [] for i in range(5): all_lab_err_rate.append(1.1) while True: # save model if self.num_batch_total % 1000 == 0: while True: #print('wait 
save mode')
                    time.sleep(0.5)
                    if self.input_queue.empty():
                        checkpoint_path = os.path.join(self.tf_async_model_prefix,
                                                       str(self.num_batch_total) + '_model' + '.ckpt')
                        logging.info('save model: ' + checkpoint_path + ' --- learn_rate: '
                                     + str(self.sess.run(self.learning_rate_var)))
                        self.saver.save(self.sess, checkpoint_path)
                        if self.num_batch_total == 0:
                            break
                        curr_lab_err_rate = self.get_avergae_label_error_rate()
                        all_lab_err_rate.sort()
                        for i in range(len(all_lab_err_rate)):
                            if curr_lab_err_rate < all_lab_err_rate[i]:
                                all_lab_err_rate[len(all_lab_err_rate) - 1] = curr_lab_err_rate
                                break
                            if i == len(all_lab_err_rate) - 1:
                                # decay_learning_rate is a method of this class
                                self.decay_learning_rate(0.5)
                                all_lab_err_rate[len(all_lab_err_rate) - 1] = curr_lab_err_rate
                                break
                        break
            # input data
            if self.input_ce_feat_and_label():
                continue
            break
        time.sleep(1)
        logging.info('read feat ok')
        ''' end train '''
        for thr in train_thread:
            self.input_queue.put((None, None, None))
        while True:
            if self.input_queue.empty():
                logging.info('train is end')
                checkpoint_path = os.path.join(self.tf_async_model_prefix,
                                               str(self.num_batch_total) + '_model' + '.ckpt')
                self.saver.save(self.sess, checkpoint_path + '.final')
                break
        ''' train is end '''
        for thr in train_thread:
            thr.join()
            logging.info('join thread %s' % thr.name)
        tmp_label_error_rate = self.get_avergae_label_error_rate()
        self.kaldi_io_nstream.Reset()
        self.reset_acc()
        return tmp_label_error_rate

    def decay_learning_rate(self, lr_decay_factor):
        learning_rate_decay_op = self.learning_rate_var.assign(
            tf.multiply(self.learning_rate_var, lr_decay_factor))
        self.sess.run(learning_rate_decay_op)
        logging.info('learn_rate decay to ' + str(self.sess.run(self.learning_rate_var)))
        logging.info('lr_decay_factor is ' + str(lr_decay_factor))
        # return learning_rate_decay_op

    def get_avergae_label_error_rate(self):
        average_label_error_rate = 0.0
        for i in range(self.num_threads):
            average_label_error_rate += self.acc_label_error_rate[i]
        average_label_error_rate /= self.num_threads
        logging.info("average label error rate : %f" % average_label_error_rate)
        return average_label_error_rate

    def reset_acc(self):
        for i in range(len(self.acc_label_error_rate)):
            self.acc_label_error_rate[i] = 1.1
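# input_feat_and_label() above (and InputFeat in the CTC class below) convert
# the label lists with sparse_tuple_from() before queueing them. That helper is
# defined elsewhere in the repository; this is only a sketch of the conventional
# (indices, values, dense_shape) tuple such a function returns for feeding a
# tf.sparse_placeholder, assuming `sequences` is a list of per-utterance label
# lists of varying length (sparse_tuple_from_sketch is an illustrative name):
import numpy as np

def sparse_tuple_from_sketch(sequences, dtype=np.int32):
    indices, values = [], []
    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)
    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    max_len = indices[:, 1].max() + 1 if len(indices) else 0
    dense_shape = np.asarray([len(sequences), max_len], dtype=np.int64)
    return indices, values, dense_shape

# e.g. sparse_tuple_from_sketch([[1, 2, 3], [4, 5]]) gives indices of shape
# (5, 2), values [1 2 3 4 5] and dense_shape [2 3], which can be fed directly
# to a sparse label placeholder in a CTC graph.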
path = '/search/speech/hubo/git/tf-code-acoustics/chain_source_7300/'
#path = './'
conf_dict = {
    'batch_size': 64,
    'skip_offset': 0,
    'skip_frame': 3,
    'shuffle': False,
    'queue_cache': 10,
    'io_thread_num': 5
}
feat_trans_file = '../conf/final.feature_transform'
feat_trans = FeatureTransform()
feat_trans.LoadTransform(feat_trans_file)
logging.basicConfig(filename='test.log')
logging.getLogger().setLevel('INFO')
io_read = KaldiDataReadParallel()
io_read.Initialize(conf_dict, scp_file=path + 'cegs.1.scp',
                   #io_read.Initialize(conf_dict, scp_file=path+'scp',
                   #io_read.Initialize(conf_dict, scp_file=path+'cegs.all.scp_0',
                   feature_transform=feat_trans, criterion='chain')
batch_info = 2000
start = time.time()
io_read.Reset(shuffle=False)
batch_num = 0
model = CommonModel(nnet_conf)
# Instantiate an optimizer.
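# The 'skip_frame': 3 / 'skip_offset': 0 entries above presumably make the
# reader subsample the feature stream, keeping every third frame starting at
# the given offset (the training loops pass skip_offset through to Reset).
# A tiny illustration of that kind of subsampling; skip_frames_sketch is not
# part of this code base:
import numpy as np

def skip_frames_sketch(feats, skip_frame=3, skip_offset=0):
    # feats: (num_frames, feat_dim) array; keep frames offset, offset+skip, ...
    return feats[skip_offset::skip_frame]

# e.g. for a 9-frame utterance, skip_frame=3 with skip_offset=0 keeps frames
# 0, 3, 6 and with skip_offset=1 keeps frames 1, 4, 7.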
class train_class(object): def __init__(self, conf_dict): self.print_trainable_variables = False self.use_normal = False self.use_sgd = True self.restore_training = True self.checkpoint_dir = None self.num_threads = 1 self.queue_cache = 100 self.feature_transfile = None # initial configuration parameter for key in self.__dict__: if key in conf_dict.keys(): self.__dict__[key] = conf_dict[key] # initial nnet configuration parameter self.nnet_conf = ProjConfig() self.nnet_conf.initial(conf_dict) self.kaldi_io_nstream = None if self.feature_transfile == None: logging.info('No feature_transfile,it must have.') sys.exit(1) feat_trans = FeatureTransform() feat_trans.LoadTransform(self.feature_transfile) # init train file self.kaldi_io_nstream_train = KaldiDataReadParallel() self.input_dim = self.kaldi_io_nstream_train.Initialize( conf_dict, scp_file=conf_dict['scp_file'], label=conf_dict['label'], feature_transform=feat_trans, criterion='ctc') # init cv file self.kaldi_io_nstream_cv = KaldiDataReadParallel() self.kaldi_io_nstream_cv.Initialize(conf_dict, scp_file=conf_dict['cv_scp'], label=conf_dict['cv_label'], feature_transform=feat_trans, criterion='ctc') self.num_batch_total = 0 self.tot_lab_err_rate = 0.0 self.tot_num_batch = 0 logging.info(self.nnet_conf.__repr__()) logging.info(self.kaldi_io_nstream_train.__repr__()) logging.info(self.kaldi_io_nstream_cv.__repr__()) self.input_queue = Queue.Queue(self.queue_cache) self.acc_label_error_rate = [] # record every thread label error rate self.all_lab_err_rate = [] # for adjust learn rate self.num_save = 0 for i in range(5): self.all_lab_err_rate.append(1.1) self.num_batch = [] for i in range(self.num_threads): self.acc_label_error_rate.append(1.0) self.num_batch.append(0) # get restore model number def get_num(self, str): return int(str.split('/')[-1].split('_')[0]) # construct train graph def construct_graph(self): with tf.Graph().as_default(): self.run_ops = [] self.X = tf.placeholder(tf.float32, [None, None, self.input_dim], name='feature') self.Y = tf.sparse_placeholder(tf.int32, name="labels") self.seq_len = tf.placeholder(tf.int32, [None], name='seq_len') self.learning_rate_var = tf.Variable(float( self.nnet_conf.learning_rate), trainable=False, name='learning_rate') if self.use_sgd: optimizer = tf.train.GradientDescentOptimizer( self.learning_rate_var) else: optimizer = tf.train.AdamOptimizer( learning_rate=self.learning_rate_var, beta1=0.9, beta2=0.999, epsilon=1e-08) for i in range(self.num_threads): with tf.device("/gpu:%d" % i): initializer = tf.random_uniform_initializer( -self.nnet_conf.init_scale, self.nnet_conf.init_scale) model = LSTM_Model(self.nnet_conf) mean_loss, ctc_loss, label_error_rate, decoded, softval = model.loss( self.X, self.Y, self.seq_len) if self.use_sgd and self.use_normal: tvars = tf.trainable_variables() grads, _ = tf.clip_by_global_norm( tf.gradients(mean_loss, tvars), self.nnet_conf.grad_clip) train_op = optimizer.apply_gradients( zip(grads, tvars), global_step=tf.contrib.framework. 
get_or_create_global_step()) else: train_op = optimizer.minimize(mean_loss) run_op = { 'train_op': train_op, 'mean_loss': mean_loss, 'ctc_loss': ctc_loss, 'label_error_rate': label_error_rate } # 'decoded':decoded, # 'softval':softval} self.run_ops.append(run_op) tf.get_variable_scope().reuse_variables() gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95) self.sess = tf.Session(config=tf.ConfigProto( intra_op_parallelism_threads=self.num_threads, allow_soft_placement=True, log_device_placement=False, gpu_options=gpu_options)) init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) tmp_variables = tf.trainable_variables() self.saver = tf.train.Saver(tmp_variables, max_to_keep=100) #self.saver = tf.train.Saver(max_to_keep=100, sharded = True) if self.restore_training: self.sess.run(init) ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: logging.info("restore training") self.saver.restore(self.sess, ckpt.model_checkpoint_path) self.num_batch_total = self.get_num( ckpt.model_checkpoint_path) if self.print_trainable_variables == True: print_trainable_variables( self.sess, ckpt.model_checkpoint_path + '.txt') sys.exit(0) logging.info('model:' + ckpt.model_checkpoint_path) logging.info('restore learn_rate:' + str(self.sess.run(self.learning_rate_var))) else: logging.info('No checkpoint file found') self.sess.run(init) logging.info('init learn_rate:' + str(self.sess.run(self.learning_rate_var))) else: self.sess.run(init) self.total_variables = np.sum([ np.prod(v.get_shape().as_list()) for v in tf.trainable_variables() ]) logging.info('total parameters : %d' % self.total_variables) def SaveTextModel(self): if self.print_trainable_variables == True: ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: print_trainable_variables(self.sess, ckpt.model_checkpoint_path + '.txt') def train_function(self, gpu_id, run_op, thread_name): total_acc_error_rate = 0.0 num_batch = 0 self.acc_label_error_rate[gpu_id] = 0.0 self.num_batch[gpu_id] = 0 while True: time1 = time.time() feat, sparse_label, length = self.get_feat_and_label() if feat is None: logging.info('train thread end : %s' % thread_name) break time2 = time.time() print('******time:', time2 - time1, thread_name) feed_dict = { self.X: feat, self.Y: sparse_label, self.seq_len: length } time3 = time.time() calculate_return = self.sess.run(run_op, feed_dict=feed_dict) time4 = time.time() print("thread_name: ", thread_name, num_batch, " time:", time2 - time1, time3 - time2, time4 - time3, time4 - time1) print('label_error_rate:', calculate_return['label_error_rate']) print('mean_loss:', calculate_return['mean_loss']) print('ctc_loss:', calculate_return['ctc_loss']) num_batch += 1 total_acc_error_rate += calculate_return['label_error_rate'] self.acc_label_error_rate[gpu_id] += calculate_return[ 'label_error_rate'] self.num_batch[gpu_id] += 1 def cv_function(self, gpu_id, run_op, thread_name): total_acc_error_rate = 0.0 num_batch = 0 self.acc_label_error_rate[gpu_id] = 0.0 self.num_batch[gpu_id] = 0 while True: feat, sparse_label, length = self.get_feat_and_label() if feat is None: logging.info('cv thread end : %s' % thread_name) break feed_dict = { self.X: feat, self.Y: sparse_label, self.seq_len: length } run_need_op = {'label_error_rate': run_op['label_error_rate']} #'softval':run_op['softval']} #'mean_loss':run_op['mean_loss'], #'ctc_loss':run_op['ctc_loss'], #'label_error_rate':run_op['label_error_rate']} 
calculate_return = self.sess.run(run_need_op, feed_dict=feed_dict) print('label_error_rate:', calculate_return['label_error_rate']) #print('softval:',calculate_return['softval']) num_batch += 1 total_acc_error_rate += calculate_return['label_error_rate'] self.acc_label_error_rate[gpu_id] += calculate_return[ 'label_error_rate'] self.num_batch[gpu_id] += 1 def get_feat_and_label(self): return self.input_queue.get() def input_feat_and_label(self): strat_io_time = time.time() feat, label, length = self.kaldi_io_nstream.LoadNextNstreams() end_io_time = time.time() # print('*************io time**********************',end_io_time-strat_io_time) if length is None: return False if len(label) != self.nnet_conf.batch_size: return False sparse_label = sparse_tuple_from(label) self.input_queue.put((feat, sparse_label, length)) self.num_batch_total += 1 print('total_batch_num**********', self.num_batch_total, '***********') return True def InputFeat(self, input_lock): while True: feat, label, length = self.kaldi_io_nstream.LoadNextNstreams() if length is None: break if len(label) != self.nnet_conf.batch_size: break sparse_label = sparse_tuple_from(label) input_lock.acquire() self.input_queue.put((feat, sparse_label, length)) self.num_batch_total += 1 if self.num_batch_total % 3000 == 0: self.SaveModel() self.AdjustLearnRate() print('total_batch_num**********', self.num_batch_total, '***********') input_lock.release() def ThreadInputFeatAndLab(self): input_thread = [] input_lock = threading.Lock() for i in range(2): input_thread.append( threading.Thread(group=None, target=self.InputFeat, args=(input_lock), name='read_thread' + str(i))) for thr in input_thread: thr.start() for thr in input_thread: thr.join() def SaveModel(self): while True: time.sleep(1.0) if self.input_queue.empty(): checkpoint_path = os.path.join( self.checkpoint_dir, str(self.num_batch_total) + '_model' + '.ckpt') logging.info('save model: ' + checkpoint_path + ' --- learn_rate: ' + str(self.sess.run(self.learning_rate_var))) self.saver.save(self.sess, checkpoint_path) break # if current label error rate less then previous five def AdjustLearnRate(self): curr_lab_err_rate = self.get_avergae_label_error_rate() logging.info("current label error rate : %f\n" % curr_lab_err_rate) all_lab_err_rate_len = len(self.all_lab_err_rate) #self.all_lab_err_rate.sort() for i in range(all_lab_err_rate_len): if curr_lab_err_rate < self.all_lab_err_rate[i]: break if i == len(self.all_lab_err_rate) - 1: self.decay_learning_rate(0.8) self.all_lab_err_rate[self.num_save % all_lab_err_rate_len] = curr_lab_err_rate self.num_save += 1 def cv_logic(self): self.kaldi_io_nstream = self.kaldi_io_nstream_cv train_thread = [] for i in range(self.num_threads): # self.acc_label_error_rate.append(1.0) train_thread.append( threading.Thread(group=None, target=self.cv_function, args=(i, self.run_ops[i], 'thread_hubo_' + str(i)), name='thread_hubo_' + str(i))) for thr in train_thread: thr.start() logging.info('cv thread start.') while True: if self.input_feat_and_label(): continue break logging.info('cv read feat ok') for thr in train_thread: self.input_queue.put((None, None, None)) while True: if self.input_queue.empty(): break logging.info('cv is end.') for thr in train_thread: thr.join() logging.info('join cv thread %s' % thr.name) tmp_label_error_rate = self.get_avergae_label_error_rate() self.kaldi_io_nstream.Reset() self.reset_acc() return tmp_label_error_rate def train_logic(self, shuffle=False, skip_offset=0): self.kaldi_io_nstream = self.kaldi_io_nstream_train 
train_thread = [] for i in range(self.num_threads): # self.acc_label_error_rate.append(1.0) train_thread.append( threading.Thread(group=None, target=self.train_function, args=(i, self.run_ops[i], 'thread_hubo_' + str(i)), name='thread_hubo_' + str(i))) for thr in train_thread: thr.start() logging.info('train thread start.') all_lab_err_rate = [] for i in range(5): all_lab_err_rate.append(1.1) tot_time = 0.0 while True: if self.num_batch_total % 3000 == 0: while True: #print('wait save mode') time.sleep(0.5) if self.input_queue.empty(): checkpoint_path = os.path.join( self.checkpoint_dir, str(self.num_batch_total) + '_model' + '.ckpt') logging.info( 'save model: ' + checkpoint_path + '--- learn_rate: ' + str(self.sess.run(self.learning_rate_var))) self.saver.save(self.sess, checkpoint_path) if self.num_batch_total == 0: break self.AdjustLearnRate() # adjust learn rate break s_1 = time.time() if self.input_feat_and_label(): e_1 = time.time() tot_time += e_1 - s_1 print("***self.input_feat_and_label time*****", e_1 - s_1) continue break print("total input time:", tot_time) time.sleep(1) logging.info('read feat ok') ''' end train ''' for thr in train_thread: self.input_queue.put((None, None, None)) while True: if self.input_queue.empty(): # logging.info('train is ok') checkpoint_path = os.path.join( self.checkpoint_dir, str(self.num_batch_total) + '_model' + '.ckpt') logging.info('save model: ' + checkpoint_path + '.final --- learn_rate: ' + str(self.sess.run(self.learning_rate_var))) self.saver.save(self.sess, checkpoint_path + '.final') break ''' train is end ''' logging.info('train is end.') for thr in train_thread: thr.join() logging.info('join thread %s' % thr.name) tmp_label_error_rate = self.get_avergae_label_error_rate() self.kaldi_io_nstream.Reset(shuffle=shuffle, skip_offset=skip_offset) self.reset_acc() return tmp_label_error_rate def decay_learning_rate(self, lr_decay_factor): learning_rate_decay_op = self.learning_rate_var.assign( tf.multiply(self.learning_rate_var, lr_decay_factor)) self.sess.run(learning_rate_decay_op) logging.info('learn_rate decay to ' + str(self.sess.run(self.learning_rate_var))) logging.info('lr_decay_factor is ' + str(lr_decay_factor)) def get_avergae_label_error_rate(self): tot_label_error_rate = 0.0 tot_num_batch = 0 for i in range(self.num_threads): tot_label_error_rate += self.acc_label_error_rate[i] tot_num_batch += self.num_batch[i] if tot_num_batch == 0: average_label_error_rate = 1.0 else: average_label_error_rate = tot_label_error_rate / tot_num_batch # logging.info("average label error rate : %f\n" % average_label_error_rate) self.tot_lab_err_rate += tot_label_error_rate self.tot_num_batch += tot_num_batch self.reset_acc(tot_reset=False) return average_label_error_rate def GetTotLabErrRate(self): return self.tot_lab_err_rate / self.tot_num_batch def reset_acc(self, tot_reset=True): for i in range(len(self.acc_label_error_rate)): self.acc_label_error_rate[i] = 0.0 self.num_batch[i] = 0 if tot_reset: self.tot_lab_err_rate = 0 self.tot_num_batch = 0 for i in range(5): self.all_lab_err_rate.append(1.1) self.num_save = 0
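# A minimal driver for the class above, sketched under the assumption that the
# real scripts build conf_dict from a config file; the paths and the epoch /
# halving schedule below are illustrative, and the ProjConfig keys (learning
# rate, layer sizes, ...) are omitted:
if __name__ == '__main__':
    conf_dict = {
        'feature_transfile': '../conf/final.feature_transform',
        'scp_file': 'train.scp', 'label': 'train.label',
        'cv_scp': 'cv.scp', 'cv_label': 'cv.label',
        'checkpoint_dir': 'exp/ctc_lstm', 'num_threads': 1,
        'queue_cache': 100,
    }
    trainer = train_class(conf_dict)
    trainer.construct_graph()
    prev_cv_err = trainer.cv_logic()
    for epoch in range(10):
        tr_err = trainer.train_logic(shuffle=True, skip_offset=epoch)
        cv_err = trainer.cv_logic()
        logging.info('epoch %d: train err %f cv err %f' % (epoch, tr_err, cv_err))
        if cv_err >= prev_cv_err:
            # halve the rate whenever the cv error stops improving
            trainer.decay_learning_rate(0.5)
        prev_cv_err = min(prev_cv_err, cv_err)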