Ejemplo n.º 1
0
    def __init__(self, conf_dict):
        '''Set up the network config, train/cv data readers and train state.

        Args:
            conf_dict: configuration mapping. Keys read unconditionally:
                'feature_transfile', 'scp_file', 'label', 'cv_scp',
                'cv_label', 'checkpoint_dir', 'num_threads', 'queue_cache'.
                Optional keys and their defaults:
                'print_trainable_variables' (False), 'use_normal' (False),
                'use_sgd' (True), 'restore_training' (False).
        '''
        self.nnet_conf = ProjConfig()
        self.nnet_conf.initial(conf_dict)

        self.kaldi_io_nstream = None
        feat_trans = FeatureTransform()
        feat_trans.LoadTransform(conf_dict['feature_transfile'])

        # init train file; its Initialize() return value is the input dim.
        self.kaldi_io_nstream_train = KaldiDataReadParallel()
        self.input_dim = self.kaldi_io_nstream_train.Initialize(
            conf_dict,
            scp_file=conf_dict['scp_file'],
            label=conf_dict['label'],
            feature_transform=feat_trans,
            criterion='ce')

        # init cv file (shares the same feature transform)
        self.kaldi_io_nstream_cv = KaldiDataReadParallel()
        self.kaldi_io_nstream_cv.Initialize(conf_dict,
                                            scp_file=conf_dict['cv_scp'],
                                            label=conf_dict['cv_label'],
                                            feature_transform=feat_trans,
                                            criterion='ce')

        self.num_batch_total = 0
        self.num_frames_total = 0

        logging.info(repr(self.nnet_conf))
        logging.info(repr(self.kaldi_io_nstream_train))
        logging.info(repr(self.kaldi_io_nstream_cv))

        # dict.get() replaces dict.has_key(), which does not exist in
        # Python 3; the defaults are unchanged.
        self.print_trainable_variables = conf_dict.get(
            'print_trainable_variables', False)
        self.tf_async_model_prefix = conf_dict['checkpoint_dir']
        self.num_threads = conf_dict['num_threads']
        self.queue_cache = conf_dict['queue_cache']
        self.input_queue = Queue.Queue(self.queue_cache)
        # One accumulated label-error-rate slot per thread, seeded with 1.1.
        self.acc_label_error_rate = [1.1] * self.num_threads
        self.use_normal = conf_dict.get('use_normal', False)
        self.use_sgd = conf_dict.get('use_sgd', True)
        self.restore_training = conf_dict.get('restore_training', False)
Ejemplo n.º 2
0
                )


if __name__ == '__main__':
    # Manual smoke test: open a chain-criterion reader on a cegs scp file
    # and wrap its batches in a generator.
    # NOTE(review): the first `path` assignment is immediately overwritten;
    # only the second value is used.
    path = '/search/speech/hubo/git/tf-code-acoustics/chain_source_7300/8-cegs-scp/'
    path = '/search/speech/hubo/git/tf-code-acoustics/chain_source_7300/'
    conf_dict = { 'batch_size' :64,
            'skip_offset': 0,
            'skip_frame':3,
            'shuffle': False,
            'queue_cache':2,
            'io_thread_num':5}
    feat_trans_file = '../conf/final.feature_transform'
    feat_trans = FeatureTransform()
    feat_trans.LoadTransform(feat_trans_file)
    io_read = KaldiDataReadParallel()
    #io_read.Initialize(conf_dict, scp_file=path+'cegs.all.scp_0',
    io_read.Initialize(conf_dict, scp_file=path+'cegs.1.scp',
            feature_transform = feat_trans, criterion = 'chain')

    io_read.Reset(shuffle = False)

    def Gen():
        """Yield batches from io_read until GetInput() reports no more data.

        Each yielded tuple is inputs[0], inputs[1], inputs[2] followed by the
        five values unpacked from inputs[3] (indexs, in_labels, weights,
        statesinfo, num_states).  Iteration stops when inputs[0] is None.
        NOTE(review): Gen is not consumed within this snippet; presumably
        the code that follows feeds it to a dataset/training loop.
        """
        while True:
            inputs = io_read.GetInput()
            if inputs[0] is not None:
                indexs, in_labels, weights, statesinfo, num_states = inputs[3]
                yield(inputs[0],inputs[1],inputs[2],indexs, in_labels, weights, statesinfo, num_states)
            else:
                print("-----end io----")
                break
Ejemplo n.º 3
0
if __name__ == '__main__':
    # Manual smoke test: read chain-criterion batches from ./scp and load the
    # 'den' FST as sparse matrices.  All paths are relative to the CWD.
    path = './'
    conf_dict = {
        'batch_size': 64,
        'skip_offset': 0,
        'shuffle': False,
        'queue_cache': 2,
        'io_thread_num': 2
    }
    feat_trans_file = '../../conf/final.feature_transform'
    feat_trans = FeatureTransform()
    feat_trans.LoadTransform(feat_trans_file)
    # Log INFO and above to test.log.
    logging.basicConfig(filename='test.log')
    logging.getLogger().setLevel('INFO')
    io_read = KaldiDataReadParallel()
    io_read.Initialize(conf_dict,
                       scp_file=path + 'scp',
                       feature_transform=feat_trans,
                       criterion='chain')

    start = time.time()
    io_read.Reset(shuffle=False)
    batch_num = 0
    # Presumably the chain denominator graph; converted into sparse-matrix
    # form (seven return values) by Fst2SparseMatrix.
    den_fst = '../../chain_source/den.fst'
    den_indexs, den_in_labels, den_weights, den_statesinfo, den_num_states, den_start_state, laststatesuperfinal = Fst2SparseMatrix(
        den_fst)
    # NOTE(review): the values below are not used within this snippet;
    # presumably consumed by the (not shown) chain-loss code that follows.
    leaky_hmm_coefficient = 0.1
    l2_regularize = 0.00005
    xent_regularize = 0.0
    delete_laststatesuperfinal = True
Ejemplo n.º 4
0
    def __init__(self, conf_dict):
        '''Read configuration and initialise the train data reader.

        Every attribute below whose name contains ``_cf`` exactly once is a
        default that ``conf_dict['<name without _cf>']`` overrides.  String
        values are passed through ``eval`` unless the key is in the
        module-level ``strset``.

        Exits the process with status 1 when no feature transform file is
        configured.
        '''
        # configure parameters (defaults)
        self.conf_dict = conf_dict
        self.print_trainable_variables_cf = False
        self.use_normal_cf = False
        self.use_sgd_cf = True
        self.restore_training_cf = True
        self.checkpoint_dir_cf = None
        self.num_threads_cf = 1
        self.queue_cache_cf = 100
        self.task_index_cf = -1
        self.grad_clip_cf = 5.0
        self.feature_transfile_cf = None
        self.learning_rate_cf = 0.1
        self.learning_rate_decay_steps_cf = 100000
        self.learning_rate_decay_rate_cf = 0.96
        self.batch_size_cf = 16
        self.num_frames_batch_cf = 20

        self.steps_per_checkpoint_cf = 1000

        self.criterion_cf = 'ctc'
        # Override defaults from conf_dict.  Iterate over a snapshot of the
        # attribute names because the loop body writes into self.__dict__.
        for attr in list(self.__dict__):
            if len(attr.split('_cf')) != 2:
                continue
            key = attr.split('_cf')[0]
            if key in conf_dict:
                if key in strset or type(conf_dict[key]) is not str:
                    self.__dict__[attr] = conf_dict[key]
                else:
                    # NOTE(review): eval() executes arbitrary expressions
                    # from the configuration; only use trusted config files.
                    print('***************', key)
                    self.__dict__[attr] = eval(conf_dict[key])

        if self.feature_transfile_cf is None:
            # Fatal configuration error: log at ERROR so the reason is
            # visible under the default WARNING root level before exiting.
            logging.error('No feature_transfile,it must have.')
            sys.exit(1)
        feat_trans = FeatureTransform()
        feat_trans.LoadTransform(self.feature_transfile_cf)

        # init train file; Initialize() returns the input dimension.
        self.kaldi_io_nstream_train = KaldiDataReadParallel()
        self.input_dim = self.kaldi_io_nstream_train.Initialize(conf_dict,
                scp_file=conf_dict['tr_scp'], label=conf_dict['tr_label'],
                feature_transform=feat_trans, criterion=self.criterion_cf)
        # The cv reader is created but not initialised in this setup.
        self.kaldi_io_nstream_cv = KaldiDataReadParallel()

        self.num_batch_total = 0
        self.tot_lab_err_rate = 0.0
        self.tot_num_batch = 0.0

        logging.info(repr(self.kaldi_io_nstream_train))

        # Initial input queue, bounded by queue_cache.
        self.input_queue = Queue.Queue(self.queue_cache_cf)

        # History of the last five label error rates, seeded with 1.1.
        self.all_lab_err_rate = [1.1] * 5
        self.num_save = 0
        # Per-thread accumulated error rate and batch count.
        self.acc_label_error_rate = [1.0] * self.num_threads_cf
        self.num_batch = [0] * self.num_threads_cf
Ejemplo n.º 5
0
class TrainClass(object):
    '''Distributed trainer that feeds Kaldi data into a TensorFlow graph.

    Construct with a configuration dict, call ``ConstructGraph(device,
    server)`` once, then ``TrainLogic`` once per pass over the data.
    '''
    def __init__(self, conf_dict):
        '''Read configuration and initialise the train data reader.

        Every attribute below whose name contains ``_cf`` exactly once is a
        default that ``conf_dict['<name without _cf>']`` overrides.  String
        values are passed through ``eval`` unless the key is in the
        module-level ``strset``.

        Exits the process with status 1 when no feature transform file is
        configured.
        '''
        # configure parameters (defaults)
        self.conf_dict = conf_dict
        self.print_trainable_variables_cf = False
        self.use_normal_cf = False
        self.use_sgd_cf = True
        self.restore_training_cf = True
        self.checkpoint_dir_cf = None
        self.num_threads_cf = 1
        self.queue_cache_cf = 100
        self.task_index_cf = -1
        self.grad_clip_cf = 5.0
        self.feature_transfile_cf = None
        self.learning_rate_cf = 0.1
        self.learning_rate_decay_steps_cf = 100000
        self.learning_rate_decay_rate_cf = 0.96
        self.batch_size_cf = 16
        self.num_frames_batch_cf = 20

        self.steps_per_checkpoint_cf = 1000

        self.criterion_cf = 'ctc'
        # Override defaults from conf_dict.  Iterate over a snapshot of the
        # attribute names because the loop body writes into self.__dict__.
        for attr in list(self.__dict__):
            if len(attr.split('_cf')) != 2:
                continue
            key = attr.split('_cf')[0]
            if key in conf_dict:
                if key in strset or type(conf_dict[key]) is not str:
                    self.__dict__[attr] = conf_dict[key]
                else:
                    # NOTE(review): eval() executes arbitrary expressions
                    # from the configuration; only use trusted config files.
                    print('***************', key)
                    self.__dict__[attr] = eval(conf_dict[key])

        if self.feature_transfile_cf is None:
            # Fatal configuration error: log at ERROR so the reason is
            # visible under the default WARNING root level before exiting.
            logging.error('No feature_transfile,it must have.')
            sys.exit(1)
        feat_trans = FeatureTransform()
        feat_trans.LoadTransform(self.feature_transfile_cf)

        # init train file; Initialize() returns the input dimension.
        self.kaldi_io_nstream_train = KaldiDataReadParallel()
        self.input_dim = self.kaldi_io_nstream_train.Initialize(conf_dict,
                scp_file=conf_dict['tr_scp'], label=conf_dict['tr_label'],
                feature_transform=feat_trans, criterion=self.criterion_cf)
        # The cv reader is created but not initialised in this setup.
        self.kaldi_io_nstream_cv = KaldiDataReadParallel()

        self.num_batch_total = 0
        self.tot_lab_err_rate = 0.0
        self.tot_num_batch = 0.0

        logging.info(repr(self.kaldi_io_nstream_train))

        # Initial input queue, bounded by queue_cache.
        self.input_queue = Queue.Queue(self.queue_cache_cf)

        # Host-side multiplier applied to the scheduled learning rate.  It is
        # fed into the graph on every run and lowered by DecayLearningRate;
        # it is not stored in checkpoints.
        self.lr_scale = 1.0

        # History of the last five label error rates (see AdjustLearnRate),
        # seeded with 1.1.
        self.all_lab_err_rate = [1.1] * 5
        self.num_save = 0
        # Per-thread accumulated error rate and batch count.
        self.acc_label_error_rate = [1.0] * self.num_threads_cf
        self.num_batch = [0] * self.num_threads_cf

    # multi computers construct train graph
    def ConstructGraph(self, device, server):
        '''Build placeholders, model, optimizer and open ``self.sess``.

        Args:
            device: device specification passed to ``tf.device``.
            server: object whose ``target`` is the session master.

        Raises:
            ValueError: if ``criterion_cf`` contains neither 'ctc' nor 'ce'.
        '''
        with tf.device(device):
            if 'cnn' in self.criterion_cf:
                self.X = tf.placeholder(tf.float32, [None, self.input_dim[0], self.input_dim[1], 1],
                        name='feature')
            else:
                self.X = tf.placeholder(tf.float32, [None, self.batch_size_cf, self.input_dim],
                        name='feature')

            if 'ctc' in self.criterion_cf:
                self.Y = tf.sparse_placeholder(tf.int32, name="labels")
            elif 'whole' in self.criterion_cf:
                self.Y = tf.placeholder(tf.int32, [self.batch_size_cf, None], name="labels")
            elif 'ce' in self.criterion_cf:
                self.Y = tf.placeholder(tf.int32, [self.batch_size_cf, self.num_frames_batch_cf], name="labels")

            self.seq_len = tf.placeholder(tf.int32, [None], name='seq_len')

            # init global_step and learning rate decay criterion
            global_step = tf.train.get_or_create_global_step()
            self.learning_rate_var_tf = self._BuildLearningRate(global_step)

            if self.use_sgd_cf:
                optimizer = tf.train.GradientDescentOptimizer(self.learning_rate_var_tf)
            else:
                optimizer = tf.train.AdamOptimizer(learning_rate=
                        self.learning_rate_var_tf, beta1=0.9, beta2=0.999, epsilon=1e-08)
            nnet_model = LstmModel(self.conf_dict)

            rnn_state_zero_op = None
            rnn_keep_state_op = None
            if 'ctc' in self.criterion_cf:
                mean_loss, loss, label_error_rate, _ = nnet_model.CtcLoss(
                        self.X, self.Y, self.seq_len)
            elif 'ce' in self.criterion_cf:
                (mean_loss, loss, label_error_rate,
                 rnn_keep_state_op, rnn_state_zero_op) = nnet_model.CeLoss(
                        self.X, self.Y, self.seq_len)
            else:
                # Without a loss the optimizer below cannot be built.
                raise ValueError('No train criterion: %s' % self.criterion_cf)

            if self.use_sgd_cf and self.use_normal_cf:
                # SGD with global-norm gradient clipping.
                tvars = tf.trainable_variables()
                grads, _ = tf.clip_by_global_norm(tf.gradients(
                    mean_loss, tvars), self.grad_clip_cf)
                train_op = optimizer.apply_gradients(
                        zip(grads, tvars), global_step=global_step)
            else:
                train_op = optimizer.minimize(mean_loss, global_step=global_step)

            # set run operation
            self.run_ops = {'train_op': train_op,
                    'mean_loss': mean_loss,
                    'loss': loss,
                    'label_error_rate': label_error_rate,
                    'rnn_keep_state_op': rnn_keep_state_op,
                    'rnn_state_zero_op': rnn_state_zero_op}

            # set initial parameter
            self.init_para = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

            self.total_variables = np.sum([np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()])
            logging.info('total parameters : %d' % self.total_variables)

            # set gpu option
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)

            # session config
            sess_config = tf.ConfigProto(intra_op_parallelism_threads=self.num_threads_cf,
                    inter_op_parallelism_threads=self.num_threads_cf,
                    allow_soft_placement=True,
                    log_device_placement=False, gpu_options=gpu_options)

            # add saver hook; checkpoints every steps_per_checkpoint steps.
            self.saver = tf.train.Saver(max_to_keep=50, sharded=True, allow_empty=True)
            scaffold = tf.train.Scaffold(saver=self.saver)
            self.sess = tf.train.MonitoredTrainingSession(
                    master=server.target,
                    is_chief=(self.task_index_cf == 0),
                    checkpoint_dir=self.checkpoint_dir_cf,
                    scaffold=scaffold,
                    hooks=None,
                    chief_only_hooks=None,
                    save_checkpoint_secs=None,
                    save_summaries_steps=self.steps_per_checkpoint_cf,
                    save_summaries_secs=None,
                    config=sess_config,
                    stop_grace_period_secs=120,
                    log_step_count_steps=100,
                    max_wait_secs=7200,
                    save_checkpoint_steps=self.steps_per_checkpoint_cf)

    def _BuildLearningRate(self, global_step, schedule='exponential_decay'):
        '''Return the learning-rate tensor used by the optimizer.

        The scheduled rate is multiplied by ``self.lr_scale_tf``, a scalar
        placeholder defaulting to 1.0.  Feeding it (see ``_LrFeed``) lets
        ``DecayLearningRate`` lower the rate without adding ops to the graph,
        which MonitoredTrainingSession finalizes.

        Args:
            global_step: the global step variable.
            schedule: 'exponential_decay' (default), 'piecewise_constant'
                or 'inverse_time_decay'.

        Raises:
            ValueError: for an unknown ``schedule``.
        '''
        if schedule == 'exponential_decay':
            scheduled = tf.train.exponential_decay(
                    float(self.learning_rate_cf), global_step,
                    self.learning_rate_decay_steps_cf,
                    self.learning_rate_decay_rate_cf,
                    staircase=True,
                    name='learning_rate_exponential_decay')
        elif schedule == 'piecewise_constant':
            boundaries = [100000, 110000]
            values = [1.0, 0.5, 0.1]
            scheduled = tf.train.piecewise_constant(global_step, boundaries, values)
        elif schedule == 'inverse_time_decay':
            # learning_rate / (1 + decay_rate * floor(global_step / decay_steps))
            scheduled = tf.train.inverse_time_decay(
                    float(self.learning_rate_cf), global_step,
                    self.learning_rate_decay_steps_cf,
                    self.learning_rate_decay_rate_cf,
                    staircase=True,
                    name='learning_rate_inverse_time_decay')
        else:
            raise ValueError('Unknown learning rate schedule: %s' % schedule)
        self.lr_scale_tf = tf.placeholder_with_default(
                1.0, shape=[], name='learning_rate_scale')
        return tf.multiply(scheduled, self.lr_scale_tf, name='learning_rate')

    def _LrFeed(self):
        '''Feed entries that apply the host-side learning-rate scale.'''
        return {self.lr_scale_tf: self.lr_scale}

    def _CurrentLearningRate(self):
        '''Evaluate the effective (scheduled * scale) learning rate.'''
        return self.sess.run(self.learning_rate_var_tf, feed_dict=self._LrFeed())

    def _LogInterval(self):
        '''Batches between progress logs; clamped to >= 1 so the modulo
        test cannot divide by zero when steps_per_checkpoint < 50.'''
        return max(1, int(self.steps_per_checkpoint_cf / 50))

    def SaveTextModel(self):
        '''Dump the latest checkpoint's variables to "<ckpt path>.txt" when
        print_trainable_variables is enabled.'''
        if self.print_trainable_variables_cf == True:
            ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir_cf)
            if ckpt and ckpt.model_checkpoint_path:
                print_trainable_variables(self.sess, ckpt.model_checkpoint_path+'.txt')

    def InputFeat(self, input_lock):
        '''Producer loop: read batches from ``self.kaldi_io_nstream`` into
        ``self.input_queue``.

        Stops when ``LoadBatch`` returns ``length is None`` and then enqueues
        a ``(None, None, None)`` sentinel for the consumer.  CTC labels are
        converted with ``sparse_tuple_from`` before being queued.
        '''
        while True:
            # "with" releases the lock on the break path as well.
            with input_lock:
                feat, label, length = self.kaldi_io_nstream.LoadBatch()
                if length is None:
                    break
                print(np.shape(feat), np.shape(label), np.shape(length))
                sys.stdout.flush()
                if 'ctc' in self.criterion_cf:
                    label = sparse_tuple_from(label)
                self.input_queue.put((feat, label, length))
                self.num_batch_total += 1
                print('total_batch_num**********', self.num_batch_total, '***********')
        self.input_queue.put((None, None, None))

    def ThreadInputFeatAndLab(self):
        '''Start the (single) reader thread and return the thread list.'''
        input_lock = threading.Lock()
        input_thread = [threading.Thread(group=None, target=self.InputFeat,
                args=(input_lock,), name='read_thread' + str(i))
                for i in range(1)]

        for thr in input_thread:
            logging.info('ThreadInputFeatAndLab start')
            thr.start()

        return input_thread

    def SaveModel(self):
        '''Wait (polling every second) until the input queue is empty, then
        save a checkpoint named "<num_batch_total>_model.ckpt".'''
        while True:
            time.sleep(1.0)
            if self.input_queue.empty():
                checkpoint_path = os.path.join(self.checkpoint_dir_cf,
                        str(self.num_batch_total)+'_model'+'.ckpt')
                logging.info('save model: '+checkpoint_path+
                        ' --- learn_rate: ' +
                        str(self._CurrentLearningRate()))
                self.saver.save(self.sess, checkpoint_path)
                break

    def AdjustLearnRate(self, curr_lab_err_rate=None):
        '''Decay the learning rate by 0.8 unless the current label error
        rate is lower than at least one entry of the stored history.

        Args:
            curr_lab_err_rate: precomputed rate; when None it is obtained
                from ``GetAverageLabelErrorRate`` (which also updates the
                running totals).
        '''
        if curr_lab_err_rate is None:
            curr_lab_err_rate = self.GetAverageLabelErrorRate()
        logging.info("current label error rate : %f" % curr_lab_err_rate)
        history = self.all_lab_err_rate
        if not history:
            return
        if not any(curr_lab_err_rate < rate for rate in history):
            self.DecayLearningRate(0.8)
            logging.info('learn_rate decay to '+str(self._CurrentLearningRate()))
        # Overwrite the oldest slot (circular buffer).
        history[self.num_save % len(history)] = curr_lab_err_rate
        self.num_save += 1

    def DecayLearningRate(self, lr_decay_factor):
        '''Multiply the effective learning rate by ``lr_decay_factor``.

        The scheduled rate is a Tensor, so it cannot be assigned to; the
        factor is applied through the fed ``lr_scale`` instead.
        '''
        self.lr_scale *= lr_decay_factor
        logging.info('learn_rate decay to '+str(self._CurrentLearningRate()))
        logging.info('lr_decay_factor is '+str(lr_decay_factor))

    # get restore model number
    def GetNum(self, str):
        '''Return N from a checkpoint path ".../N_model.ckpt".'''
        return int(str.split('/')[-1].split('_')[0])

    # train_loss selects train (True) or cv (False).
    def TrainLogic(self, device, shuffle=False, train_loss=True, skip_offset=0):
        '''Run one pass over the train or cv data and return its average
        label error rate.

        Raises:
            ValueError: in train mode when the criterion names no loss.
        '''
        is_train = (train_loss == True)
        whole_criterion = ('ctc' in self.criterion_cf or
                           'whole' in self.criterion_cf)
        slice_criterion = (not whole_criterion) and 'ce' in self.criterion_cf
        if is_train:
            logging.info('TrainLogic train start.')
            logging.info('Start global step is %d---learn_rate is %f' % (self.sess.run(tf.train.get_or_create_global_step()), self._CurrentLearningRate()))
            self.kaldi_io_nstream = self.kaldi_io_nstream_train
            # set run operation
            if not (whole_criterion or slice_criterion):
                raise ValueError('No train criterion.')
            run_op = {'train_op': self.run_ops['train_op'],
                    'label_error_rate': self.run_ops['label_error_rate'],
                    'mean_loss': self.run_ops['mean_loss'],
                    'loss': self.run_ops['loss']}
        else:
            logging.info('TrainLogic cv start.')
            self.kaldi_io_nstream = self.kaldi_io_nstream_cv
            run_op = {'label_error_rate': self.run_ops['label_error_rate'],
                    'mean_loss': self.run_ops['mean_loss']}
        if slice_criterion:
            # SliceTrainFunction resets and carries RNN state in both modes.
            run_op['rnn_keep_state_op'] = self.run_ops['rnn_keep_state_op']
            run_op['rnn_state_zero_op'] = self.run_ops['rnn_state_zero_op']

        # reset io and start input thread
        self.kaldi_io_nstream.Reset(shuffle=shuffle, skip_offset=skip_offset)
        threadinput = self.ThreadInputFeatAndLab()
        time.sleep(3)
        with tf.device(device):
            if whole_criterion:
                self.WholeTrainFunction(0, run_op, 'train_ctc_thread_hubo')
            elif slice_criterion:
                self.SliceTrainFunction(0, run_op, 'train_ce_thread_hubo')

        tmp_label_error_rate = self.GetAverageLabelErrorRate()
        logging.info("current averagelabel error rate : %f" % tmp_label_error_rate)
        logging.info('learn_rate is '+str(self._CurrentLearningRate()))

        if is_train:
            # Reuse the rate computed above so the running totals are not
            # accumulated a second time.
            self.AdjustLearnRate(tmp_label_error_rate)
            logging.info('TrainLogic train end.')
        else:
            logging.info('TrainLogic cv end.')

        # End input thread
        for thr in threadinput:
            thr.join()

        self.ResetAccuracy()
        return tmp_label_error_rate

    def WholeTrainFunction(self, gpu_id, run_op, thread_name):
        '''Consume whole-utterance batches from the queue, running
        ``run_op`` once per batch until the ``None`` sentinel arrives.'''
        logging.info('******start WholeTrainFunction******')
        log_interval = self._LogInterval()
        total_curr_error_rate = 0.0
        num_batch = 0
        self.acc_label_error_rate[gpu_id] = 0.0
        self.num_batch[gpu_id] = 0

        while True:
            time1 = time.time()
            feat, label, length = self.GetFeatAndLabel()
            if feat is None:
                logging.info('train thread end : %s' % thread_name)
                break
            time2 = time.time()

            feed_dict = {self.X: feat, self.Y: label, self.seq_len: length}
            feed_dict.update(self._LrFeed())
            time3 = time.time()
            calculate_return = self.sess.run(run_op, feed_dict=feed_dict)
            time4 = time.time()

            print("thread_name: ", thread_name, num_batch, " time:", time2-time1, time3-time2, time4-time3, time4-time1)
            print('label_error_rate:', calculate_return['label_error_rate'])
            print('mean_loss:', calculate_return['mean_loss'])

            num_batch += 1
            total_curr_error_rate += calculate_return['label_error_rate']
            self.acc_label_error_rate[gpu_id] += calculate_return['label_error_rate']
            self.num_batch[gpu_id] += 1
            if self.num_batch[gpu_id] % log_interval == 0:
                logging.info("Batch: %d current averagelabel error rate : %f" % (log_interval, total_curr_error_rate / log_interval))
                total_curr_error_rate = 0.0
                logging.info("Batch: %d current total averagelabel error rate : %f" % (self.num_batch[gpu_id], self.acc_label_error_rate[gpu_id] / self.num_batch[gpu_id]))
        logging.info('******end TrainFunction******')

    def SliceTrainFunction(self, gpu_id, run_op, thread_name):
        '''Consume sliced batches: reset RNN state per queue item, then run
        every op in ``run_op`` except the state reset once per slice.'''
        logging.info('******start TrainFunction******')
        log_interval = self._LogInterval()
        zero_state_op = run_op.get('rnn_state_zero_op')
        # Loop-invariant fetches for each slice.
        step_ops = dict((k, v) for k, v in run_op.items()
                        if k != 'rnn_state_zero_op' and v is not None)
        num_batch = 0
        self.acc_label_error_rate[gpu_id] = 0.0
        self.num_batch[gpu_id] = 0

        while True:
            time1 = time.time()
            feat, label, length = self.GetFeatAndLabel()
            if feat is None:
                logging.info('train thread end : %s' % thread_name)
                break
            time2 = time.time()
            if zero_state_op is not None:
                self.sess.run(zero_state_op)
            for i in range(len(feat)):
                time3 = time.time()
                feed_dict = {self.X: feat[i], self.Y: label[i], self.seq_len: length[i]}
                feed_dict.update(self._LrFeed())
                time4 = time.time()
                calculate_return = self.sess.run(step_ops, feed_dict=feed_dict)
                time5 = time.time()
                print("thread_name: ", thread_name, num_batch, " time:", time4-time3, time5-time4)
                print('label_error_rate:', calculate_return['label_error_rate'])
                print('mean_loss:', calculate_return['mean_loss'])

                print("thread_name: ", thread_name, num_batch, " time:", time2-time1, time3-time2, time4-time3, time4-time1)
                num_batch += 1
                self.acc_label_error_rate[gpu_id] += calculate_return['label_error_rate']
                self.num_batch[gpu_id] += 1
                if self.num_batch[gpu_id] % log_interval == 0:
                    logging.info("Batch: %d current averagelabel error rate : %f" % (self.num_batch[gpu_id], self.acc_label_error_rate[gpu_id] / self.num_batch[gpu_id]))
        logging.info('******end TrainFunction******')

    def GetFeatAndLabel(self):
        '''Block until the next (feat, label, length) item is available.'''
        return self.input_queue.get()

    def GetAverageLabelErrorRate(self):
        '''Return the mean error rate over all threads' batches (1.0 when no
        batch was processed) and add the sums to the running totals.'''
        tot_label_error_rate = sum(self.acc_label_error_rate[:self.num_threads_cf], 0.0)
        tot_num_batch = sum(self.num_batch[:self.num_threads_cf], 0)
        if tot_num_batch == 0:
            average_label_error_rate = 1.0
        else:
            average_label_error_rate = tot_label_error_rate / tot_num_batch
        self.tot_lab_err_rate += tot_label_error_rate
        self.tot_num_batch += tot_num_batch
        return average_label_error_rate

    def GetTotLabErrRate(self):
        '''Return the error rate over the running totals; 1.0 when empty
        (e.g. right after ResetAccuracy) instead of dividing by zero.'''
        if not self.tot_num_batch:
            return 1.0
        return self.tot_lab_err_rate / self.tot_num_batch

    def ResetAccuracy(self, tot_reset=True):
        '''Zero the per-thread accumulators; with ``tot_reset`` also clear
        the running totals and restore the five-entry rate history.'''
        for i in range(len(self.acc_label_error_rate)):
            self.acc_label_error_rate[i] = 0.0
            self.num_batch[i] = 0

        if tot_reset:
            self.tot_lab_err_rate = 0
            self.tot_num_batch = 0
            # Reset (not extend) the history so it stays five entries long.
            self.all_lab_err_rate = [1.1] * 5
            self.num_save = 0
Ejemplo n.º 6
0
class TrainClass(object):
    '''
    '''
    def __init__(self, conf_dict):
        """Set default hyper-parameters, override them from ``conf_dict`` and
        open the training data reader.

        Every attribute whose name ends in ``_cf`` is a configurable option.
        If ``conf_dict`` contains the same key without the ``_cf`` suffix, that
        value replaces the default set below. String values are passed through
        ``eval`` unless the key is listed in the module-level ``strset``.

        Args:
            conf_dict: configuration mapping. Must contain 'tr_scp',
                'tr_label' and 'lat_scp_file', and must supply
                'feature_transfile' (the process exits otherwise).
        """
        # configure parameter defaults
        self.conf_dict = conf_dict
        self.print_trainable_variables_cf = False
        self.use_normal_cf = False
        self.use_sgd_cf = True
        self.optimizer_cf = 'GD'
        self.use_sync_cf = False
        self.use_clip_cf = False
        self.restore_training_cf = True
        self.checkpoint_dir_cf = None
        self.num_threads_cf = 1
        self.queue_cache_cf = 100
        self.task_index_cf = -1
        self.grad_clip_cf = 5.0
        self.feature_transfile_cf = None
        self.learning_rate_cf = 0.1
        self.learning_rate_decay_steps_cf = 100000
        self.learning_rate_decay_rate_cf = 0.96
        self.l2_scale_cf = 0.00005
        self.batch_size_cf = 16
        self.num_frames_batch_cf = 20
        self.reset_global_step_cf = True

        self.steps_per_checkpoint_cf = 1000

        self.criterion_cf = 'ctc'
        self.silence_phones = []

        # Override the *_cf defaults from conf_dict. Iterate over a snapshot of
        # the attribute names so the loop never walks a dict it is writing to.
        for attr in list(self.__dict__):
            parts = attr.split('_cf')
            if len(parts) != 2:
                continue
            key = parts[0]
            if key not in conf_dict:
                continue
            value = conf_dict[key]
            if key in strset or type(value) is not str:
                self.__dict__[attr] = value
            else:
                # NOTE(review): eval() executes arbitrary code taken from the
                # configuration; only load trusted configs (consider
                # ast.literal_eval if values are plain literals).
                print('***************', key)
                self.__dict__[attr] = eval(value)

        if self.feature_transfile_cf is None:
            logging.error('No feature_transfile,it must have.')
            sys.exit(1)
        feat_trans = FeatureTransform()
        feat_trans.LoadTransform(self.feature_transfile_cf)

        # init train file
        self.kaldi_io_nstream_train = KaldiDataReadParallel()
        self.input_dim = self.kaldi_io_nstream_train.Initialize(
            self.conf_dict,
            scp_file=conf_dict['tr_scp'],
            label=conf_dict['tr_label'],
            lat_scp_file=conf_dict['lat_scp_file'],
            feature_transform=feat_trans,
            criterion=self.criterion_cf)
        # NOTE(review): the CV reader is created but Initialize() is never
        # called on it, yet TrainLogic(train_loss=False) uses it — confirm a
        # CV pass is really unused before relying on it.
        self.kaldi_io_nstream_cv = KaldiDataReadParallel()

        self.num_batch_total = 0
        self.tot_lab_err_rate = 0.0
        self.tot_num_batch = 0.0

        logging.info(self.kaldi_io_nstream_train.__repr__())

        # Per-thread accumulators for label error rate and batch count.
        self.acc_label_error_rate = [1.0] * self.num_threads_cf
        self.num_batch = [0] * self.num_threads_cf
        # History of recent label error rates used by AdjustLearnRate; the
        # 1.1 entries are placeholders larger than any real error rate.
        self.all_lab_err_rate = [1.1] * 5
        self.num_save = 0

    # multi computers construct train graph
    def ConstructGraph(self, device, server):
        """Build the training graph on ``device`` and open the training session.

        Creates:

        - the input placeholders required by ``criterion_cf``;
        - the global step and learning-rate schedule;
        - the optimizer;
        - the model loss and train op (collected in ``self.run_ops``);
        - a saver;
        - ``self.sess``, a ``tf.train.MonitoredTrainingSession`` attached to
          ``server.target``.

        The process exits with status 1 for an unknown optimizer or criterion.

        Args:
            device: device specification passed to ``tf.device``.
            server: object exposing ``target``; presumably a
                ``tf.train.Server`` — TODO confirm.
        """
        with tf.device(device):
            # ---- feature / label placeholders ----
            if 'cnn' in self.criterion_cf:
                self.X = tf.placeholder(
                    tf.float32,
                    [None, self.input_dim[0], self.input_dim[1], 1],
                    name='feature')
            else:
                self.X = tf.placeholder(
                    tf.float32, [None, self.batch_size_cf, self.input_dim],
                    name='feature')

            if 'ctc' in self.criterion_cf:
                self.Y = tf.sparse_placeholder(tf.int32, name="labels")
            elif 'whole' in self.criterion_cf:
                self.Y = tf.placeholder(tf.int32, [self.batch_size_cf, None],
                                        name="labels")
            elif 'ce' in self.criterion_cf:
                self.Y = tf.placeholder(
                    tf.int32, [self.batch_size_cf, self.num_frames_batch_cf],
                    name="labels")
            elif 'chain' in self.criterion_cf:
                # For chain training Y carries per-frame deriv weights.
                self.Y = tf.placeholder(tf.float32, [self.batch_size_cf, None],
                                        name="labels")

            # ---- lattice (sequence criteria) / FST (chain) placeholders ----
            if 'mmi' in self.criterion_cf or 'smbr' in self.criterion_cf or 'mpfe' in self.criterion_cf:
                self.indexs = tf.placeholder(tf.int32,
                                             [self.batch_size_cf, None, 2],
                                             name="indexs")
                self.pdf_values = tf.placeholder(tf.int32,
                                                 [self.batch_size_cf, None],
                                                 name="pdf_values")
                self.lm_ws = tf.placeholder(tf.float32,
                                            [self.batch_size_cf, None],
                                            name="lm_ws")
                self.am_ws = tf.placeholder(tf.float32,
                                            [self.batch_size_cf, None],
                                            name="am_ws")
                self.statesinfo = tf.placeholder(tf.int32,
                                                 [self.batch_size_cf, None, 2],
                                                 name="statesinfo")
                self.num_states = tf.placeholder(tf.int32,
                                                 [self.batch_size_cf],
                                                 name="num_states")
                self.lattice = [
                    self.indexs, self.pdf_values, self.lm_ws, self.am_ws,
                    self.statesinfo, self.num_states
                ]
            elif 'chain' in self.criterion_cf:
                self.indexs = tf.placeholder(tf.int32,
                                             [self.batch_size_cf, None, 2],
                                             name="indexs")
                self.in_labels = tf.placeholder(tf.int32,
                                                [self.batch_size_cf, None],
                                                name="in_labels")
                self.weights = tf.placeholder(tf.float32,
                                              [self.batch_size_cf, None],
                                              name="weights")
                self.statesinfo = tf.placeholder(tf.int32,
                                                 [self.batch_size_cf, None, 2],
                                                 name="statesinfo")
                self.num_states = tf.placeholder(tf.int32,
                                                 [self.batch_size_cf],
                                                 name="num_states")
                self.length = tf.placeholder(tf.int32, [None], name="length")
                self.fst = [
                    self.indexs, self.in_labels, self.weights, self.statesinfo,
                    self.num_states
                ]

            self.seq_len = tf.placeholder(tf.int32, [None], name='seq_len')

            # ---- global step and learning-rate schedule ----
            self.global_step = tf.Variable(0,
                                           trainable=False,
                                           name='global_step',
                                           dtype=tf.int64)
            # Hard-coded schedule selector; exactly one branch is taken.
            exponential_decay = True
            piecewise_constant = False
            inverse_time_decay = False
            if exponential_decay:
                self.learning_rate_var_tf = tf.train.exponential_decay(
                    float(self.learning_rate_cf),
                    self.global_step,
                    self.learning_rate_decay_steps_cf,
                    self.learning_rate_decay_rate_cf,
                    staircase=True,
                    name='learning_rate_exponential_decay')
            elif piecewise_constant:
                boundaries = [100000, 110000]
                values = [1.0, 0.5, 0.1]
                self.learning_rate_var_tf = tf.train.piecewise_constant(
                    self.global_step, boundaries, values)
            elif inverse_time_decay:
                # decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step))
                self.learning_rate_var_tf = tf.train.inverse_time_decay(
                    float(self.learning_rate_cf),
                    self.global_step,
                    self.learning_rate_decay_steps_cf,
                    self.learning_rate_decay_rate_cf,
                    staircase=True,
                    name='learning_rate_inverse_time_decay')
            else:
                self.learning_rate_var_tf = tf.convert_to_tensor(
                    self.learning_rate_cf, name="learning_rate")

            # ---- optimizer ----
            # Every branch must bind `optimizer`; previously the Adadelta,
            # AdagradDA and Adagrad branches discarded the object they built,
            # causing an UnboundLocalError further down.
            if self.optimizer_cf == 'GD':
                optimizer = tf.train.GradientDescentOptimizer(
                    self.learning_rate_var_tf)
            elif self.optimizer_cf == 'Adam':
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=self.learning_rate_var_tf,
                    beta1=0.9,
                    beta2=0.999,
                    epsilon=1e-08,
                    use_locking=True)
            elif self.optimizer_cf == 'Adadelta':
                optimizer = tf.train.AdadeltaOptimizer(
                    learning_rate=self.learning_rate_var_tf,
                    rho=0.95,
                    epsilon=1e-08,
                    use_locking=False,
                    name='Adadelta')
            elif self.optimizer_cf == 'AdagradDA':
                optimizer = tf.train.AdagradDAOptimizer(
                    learning_rate=self.learning_rate_var_tf,
                    global_step=self.global_step,
                    initial_gradient_squared_accumulator_value=0.1,
                    l1_regularization_strength=0.0,
                    l2_regularization_strength=0.0,
                    use_locking=False,
                    name='AdagradDA')
            elif self.optimizer_cf == 'Adagrad':
                optimizer = tf.train.AdagradOptimizer(
                    learning_rate=self.learning_rate_var_tf,
                    initial_accumulator_value=0.1,
                    use_locking=False,
                    name='Adagrad')
            else:
                logging.error("no this opyimizer.")
                sys.exit(1)

            # sync train
            if self.use_sync_cf:
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=32,
                    total_num_replicas=32,
                    use_locking=True)
                sync_replicas_hook = [
                    optimizer.make_session_run_hook(
                        is_chief=(self.task_index_cf == 0))
                ]
                logging.info("******use synchronization train******")
            else:
                sync_replicas_hook = None
                logging.info("******use asynchronization train******")
            nnet_model = LstmModel(self.conf_dict)

            # ---- loss for the configured criterion ----
            mean_loss = None
            loss = None
            rnn_state_zero_op = None
            rnn_keep_state_op = None
            if 'ctc' in self.criterion_cf:
                ctc_mean_loss, ctc_loss, label_error_rate, _ = nnet_model.CtcLoss(
                    self.X, self.Y, self.seq_len)
                mean_loss = ctc_mean_loss
                loss = ctc_loss
            elif 'ce' in self.criterion_cf:
                ce_mean_loss, ce_loss, label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.CeLoss(
                    self.X, self.Y, self.seq_len)
                mean_loss = ce_mean_loss
                loss = ce_loss
            elif 'mmi' in self.criterion_cf:
                mmi_mean_loss, mmi_loss, label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.MmiLoss(
                    self.X,
                    self.Y,
                    self.seq_len,
                    self.indexs,
                    self.pdf_values,
                    self.lm_ws,
                    self.am_ws,
                    self.statesinfo,
                    self.num_states,
                    old_acoustic_scale=0.0,
                    acoustic_scale=0.083,
                    time_major=True,
                    drop_frames=True)
                mean_loss = mmi_mean_loss
                loss = mmi_loss
            elif 'smbr' in self.criterion_cf or 'mpfe' in self.criterion_cf:
                if 'smbr' in self.criterion_cf:
                    criterion = 'smbr'
                else:
                    criterion = 'mpfe'
                pdf_to_phone = self.kaldi_io_nstream_train.pdf_to_phone
                log_priors = self.kaldi_io_nstream_train.pdf_prior

                mpe_mean_loss, mpe_loss, label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.MpeLoss(
                    self.X,
                    self.Y,
                    self.seq_len,
                    self.indexs,
                    self.pdf_values,
                    self.lm_ws,
                    self.am_ws,
                    self.statesinfo,
                    self.num_states,
                    log_priors=log_priors,
                    silence_phones=self.silence_phones,
                    pdf_to_phone=pdf_to_phone,
                    one_silence_class=True,
                    criterion=criterion,
                    old_acoustic_scale=0.0,
                    acoustic_scale=0.083,
                    time_major=True)
                mean_loss = mpe_mean_loss
                loss = mpe_loss
            elif 'chain' in self.criterion_cf:
                den_indexs, den_in_labels, den_weights, den_statesinfo, den_num_states, den_start_state, laststatesuperfinal = Fst2SparseMatrix(
                    self.conf_dict['den_fst'])
                label_dim = self.conf_dict['label_dim']
                delete_laststatesuperfinal = True
                l2_regularize = 0.00005
                leaky_hmm_coefficient = 0.1
                xent_regularize = 0.0
                # The denominator FST arrays are passed on as flat Python lists.
                den_indexs = np.reshape(den_indexs, [-1]).tolist()
                den_in_labels = np.reshape(den_in_labels, [-1]).tolist()
                den_weights = np.reshape(den_weights, [-1]).tolist()
                den_statesinfo = np.reshape(den_statesinfo, [-1]).tolist()
                if 'xent' in self.criterion_cf:
                    xent_regularize = 0.025
                    chain_mean_loss, chain_loss, label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.ChainXentLoss(
                        self.X, self.Y, self.indexs, self.in_labels,
                        self.weights, self.statesinfo, self.num_states,
                        self.length, label_dim, den_indexs, den_in_labels,
                        den_weights, den_statesinfo, den_num_states,
                        den_start_state, delete_laststatesuperfinal,
                        l2_regularize, leaky_hmm_coefficient, xent_regularize)
                else:
                    chain_mean_loss, chain_loss, label_error_rate, rnn_keep_state_op, rnn_state_zero_op = nnet_model.ChainLoss(
                        self.X, self.Y, self.indexs, self.in_labels,
                        self.weights, self.statesinfo, self.num_states,
                        self.length, label_dim, den_indexs, den_in_labels,
                        den_weights, den_statesinfo, den_num_states,
                        den_start_state, delete_laststatesuperfinal,
                        l2_regularize, leaky_hmm_coefficient, xent_regularize)
                mean_loss = chain_mean_loss
                loss = chain_loss
            else:
                logging.info("no criterion.")
                sys.exit(1)

            # ---- train op (optional L2 penalty and gradient clipping) ----
            if self.use_sgd_cf:
                tvars = tf.trainable_variables()
                if self.use_normal_cf:
                    apply_l2_regu = tf.add_n([tf.nn.l2_loss(v) for v in tvars
                                              ]) * self.l2_scale_cf
                    mean_loss = mean_loss + apply_l2_regu

                if self.use_clip_cf:
                    grads, gradient_norms = tf.clip_by_global_norm(
                        tf.gradients(mean_loss, tvars),
                        self.grad_clip_cf,
                        use_norm=None)
                    train_op = optimizer.apply_gradients(
                        zip(grads, tvars), global_step=self.global_step)
                else:
                    train_op = optimizer.minimize(mean_loss,
                                                  global_step=self.global_step)
            else:
                train_op = optimizer.minimize(mean_loss,
                                              global_step=self.global_step)

            # set run operation
            self.run_ops = {
                'train_op': train_op,
                'mean_loss': mean_loss,
                'loss': loss,
                'label_error_rate': label_error_rate,
                'rnn_keep_state_op': rnn_keep_state_op,
                'rnn_state_zero_op': rnn_state_zero_op
            }

            # set initial parameter
            self.init_para = tf.group(tf.global_variables_initializer(),
                                      tf.local_variables_initializer())

            self.total_variables = np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])
            logging.info('total parameters : %d' % self.total_variables)

            # ---- session configuration ----
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
            # Thread counts are fixed here. An earlier config built from
            # num_threads_cf was immediately overwritten by this one and has
            # been removed as dead code.
            sess_config = tf.ConfigProto(intra_op_parallelism_threads=2,
                                         inter_op_parallelism_threads=5,
                                         allow_soft_placement=True,
                                         log_device_placement=False,
                                         gpu_options=gpu_options)
            if self.reset_global_step_cf and self.task_index_cf == 0:
                non_train_variables = [self.global_step] + tf.local_variables()
            else:
                non_train_variables = tf.local_variables()

            # NOTE(review): these two look swapped relative to the documented
            # roles in tf.train.Scaffold. There, ready_for_local_init_op reports
            # uninitialized variables and local_init_op initializes them.
            # Presumably this is deliberate: running the initializer as the
            # readiness check re-initializes non_train_variables, which resets
            # global_step on the chief after a restore. Left unchanged; confirm
            # intent before "fixing".
            ready_for_local_init_op = tf.variables_initializer(
                non_train_variables)
            local_init_op = tf.report_uninitialized_variables(
                var_list=non_train_variables)

            # add saver hook
            self.saver = tf.train.Saver(max_to_keep=50,
                                        sharded=False,
                                        allow_empty=False)
            scaffold = tf.train.Scaffold(
                init_op=None,
                init_feed_dict=None,
                init_fn=None,
                ready_op=None,
                ready_for_local_init_op=ready_for_local_init_op,
                local_init_op=local_init_op,
                summary_op=None,
                saver=self.saver,
                copy_from_scaffold=None)

            self.sess = tf.train.MonitoredTrainingSession(
                master=server.target,
                is_chief=(self.task_index_cf == 0),
                checkpoint_dir=self.checkpoint_dir_cf,
                scaffold=scaffold,
                hooks=sync_replicas_hook,
                chief_only_hooks=None,
                save_checkpoint_secs=None,
                save_summaries_steps=self.steps_per_checkpoint_cf,
                save_summaries_secs=None,
                config=sess_config,
                stop_grace_period_secs=120,
                log_step_count_steps=100,
                max_wait_secs=600,
                save_checkpoint_steps=self.steps_per_checkpoint_cf)
        return

    def SaveTextModel(self):
        if self.print_trainable_variables_cf == True:
            ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir_cf)
            if ckpt and ckpt.model_checkpoint_path:
                print_trainable_variables(self.sess,
                                          ckpt.model_checkpoint_path + '.txt')


#    def InputFeat(self, input_lock):
#        while True:
#            input_lock.acquire()
#            '''
#            if 'ctc' in self.criterion_cf or 'whole' in self.criterion_cf:
#                if 'cnn' in self.criterion_cf:
#                    feat,label,length = self.kaldi_io_nstream.CnnLoadNextNstreams()
#                else:
#                    feat,label,length = self.kaldi_io_nstream.WholeLoadNextNstreams()
#                if length is None:
#                    break
#                print(np.shape(feat),np.shape(label), np.shape(length))
#                if len(label) != self.batch_size_cf:
#                    break
#                if 'ctc' in self.criterion_cf:
#                    sparse_label = sparse_tuple_from(label)
#                    self.input_queue.put((feat,sparse_label,length))
#                else:
#                    self.input_queue.put((feat,label,length))
#
#            elif 'ce' in self.criterion_cf:
#                if 'cnn' in self.criterion_cf:
#                    feat_array, label_array, length_array = self.kaldi_io_nstream.CnnSliceLoadNextNstreams()
#                else:
#                    feat_array, label_array, length_array = self.kaldi_io_nstream.SliceLoadNextNstreams()
#                if length_array is None:
#                    break
#                print(np.shape(feat_array),np.shape(label_array), np.shape(length_array))
#                if len(label_array[0]) != self.batch_size_cf:
#                    break
#                self.input_queue.put((feat_array, label_array, length_array))
#            '''
#            feat,label,length = self.kaldi_io_nstream.LoadBatch()
#            if length is None:
#                break
#            print(np.shape(feat),np.shape(label), np.shape(length))
#            sys.stdout.flush()
#            if 'ctc' in self.criterion_cf:
#                sparse_label = sparse_tuple_from(label)
#                self.input_queue.put((feat,sparse_label,length))
#            else:
#                self.input_queue.put((feat,label,length))
#            self.num_batch_total += 1
##            if self.num_batch_total % 3000 == 0:
##                self.SaveModel()
##                self.AdjustLearnRate()
#            print('total_batch_num**********',self.num_batch_total,'***********')
#            input_lock.release()
#        self.input_queue.put((None, None, None))
#
#    def ThreadInputFeatAndLab(self):
#        input_thread = []
#        input_lock = threading.Lock()
#        for i in range(1):
#            input_thread.append(threading.Thread(group=None, target=self.InputFeat,
#                args=(input_lock,),name='read_thread'+str(i)))
#
#        for thr in input_thread:
#            logging.info('ThreadInputFeatAndLab start')
#            thr.start()
#
#        return input_thread

#    def SaveModel(self):
#        while True:
#            time.sleep(1.0)
#            if self.input_queue.empty():
#                checkpoint_path = os.path.join(self.checkpoint_dir_cf,
#                        str(self.num_batch_total)+'_model'+'.ckpt')
#                logging.info('save model: '+checkpoint_path+
#                        ' --- learn_rate: ' +
#                        str(self.sess.run(self.learning_rate_var_tf)))
#                self.saver.save(self.sess, checkpoint_path)
#                break

# Decay the learning rate when the current label error rate is not lower than
# any of the previously recorded rates (five history slots initially).

    def AdjustLearnRate(self):
        curr_lab_err_rate = self.GetAverageLabelErrorRate()
        logging.info("current label error rate : %f" % curr_lab_err_rate)
        all_lab_err_rate_len = len(self.all_lab_err_rate)
        for i in range(all_lab_err_rate_len):
            if curr_lab_err_rate < self.all_lab_err_rate[i]:
                break
            if i == len(self.all_lab_err_rate) - 1:
                self.DecayLearningRate(0.8)
                logging.info('learn_rate decay to ' +
                             str(self.sess.run(self.learning_rate_var_tf)))
        self.all_lab_err_rate[self.num_save %
                              all_lab_err_rate_len] = curr_lab_err_rate
        self.num_save += 1

    def DecayLearningRate(self, lr_decay_factor):
        learning_rate_decay_op = self.learning_rate_var_tf.assign(
            tf.multiply(self.learning_rate_var_tf, lr_decay_factor))
        self.sess.run(learning_rate_decay_op)
        logging.info('learn_rate decay to ' +
                     str(self.sess.run(self.learning_rate_var_tf)))
        logging.info('lr_decay_factor is ' + str(lr_decay_factor))

    # get restore model number
    def GetNum(self, str):
        return int(str.split('/')[-1].split('_')[0])

    # TrainLogic runs one pass over the data; train_loss selects a training
    # pass (True) or a cross-validation pass (False).
    def TrainLogic(self,
                   device,
                   shuffle=False,
                   train_loss=True,
                   skip_offset=0):
        """Run one training or cross-validation pass.

        Args:
            device: device specification passed to ``tf.device`` for the loop.
            shuffle: forwarded to the reader's ``Reset``.
            train_loss: True for a training pass, which runs ``train_op`` and
                may decay the learning rate via AdjustLearnRate. False for a
                CV pass, which only fetches label error rate and mean loss.
            skip_offset: forwarded to the reader's ``Reset``.

        Returns:
            The pass's average label error rate from GetAverageLabelErrorRate.
        """
        if train_loss:
            logging.info('TrainLogic train start.')
            # Read the step variable built in ConstructGraph directly; calling
            # get_or_create_global_step() here could try to add ops to the
            # monitored session's finalized graph.
            logging.info('Start global step is %d---learn_rate is %f' %
                         (self.sess.run(self.global_step),
                          self.sess.run(self.learning_rate_var_tf)))
            self.kaldi_io_nstream = self.kaldi_io_nstream_train
            # set run operation
            if 'ctc' in self.criterion_cf or 'whole' in self.criterion_cf:
                run_op = {
                    'train_op': self.run_ops['train_op'],
                    'label_error_rate': self.run_ops['label_error_rate'],
                    'mean_loss': self.run_ops['mean_loss'],
                    'loss': self.run_ops['loss']
                }
            elif 'ce' in self.criterion_cf:
                run_op = {
                    'train_op': self.run_ops['train_op'],
                    'label_error_rate': self.run_ops['label_error_rate'],
                    'mean_loss': self.run_ops['mean_loss'],
                    'loss': self.run_ops['loss'],
                    'rnn_keep_state_op': self.run_ops['rnn_keep_state_op'],
                    'rnn_state_zero_op': self.run_ops['rnn_state_zero_op']
                }
            elif 'chain' in self.criterion_cf:
                run_op = {
                    'train_op': self.run_ops['train_op'],
                    'mean_loss': self.run_ops['mean_loss'],
                    'loss': self.run_ops['loss']
                }
            else:
                # Previously `assert 'No train criterion.'`, which can never
                # fail (a non-empty string is truthy) and left run_op unbound.
                logging.error('No train criterion.')
                sys.exit(1)
        else:
            logging.info('TrainLogic cv start.')
            self.kaldi_io_nstream = self.kaldi_io_nstream_cv
            run_op = {
                'label_error_rate': self.run_ops['label_error_rate'],
                'mean_loss': self.run_ops['mean_loss']
            }
        # reset io and start input
        self.kaldi_io_nstream.Reset(shuffle=shuffle, skip_offset=skip_offset)
        # NOTE(review): presumably gives the reader time to start producing
        # batches after Reset — confirm whether this sleep is still needed.
        time.sleep(3)
        with tf.device(device):
            if 'ctc' in self.criterion_cf or 'whole' in self.criterion_cf or 'chain' in self.criterion_cf:
                self.WholeTrainFunction(0, run_op, 'train_ctc_thread_hubo')
            elif 'ce' in self.criterion_cf:
                self.SliceTrainFunction(0, run_op, 'train_ce_thread_hubo')
            else:
                logging.info('no criterion in train')
                sys.exit(1)

        tmp_label_error_rate = self.GetAverageLabelErrorRate()
        logging.info("current averagelabel error rate : %f" %
                     tmp_label_error_rate)
        logging.info('learn_rate is ' +
                     str(self.sess.run(self.learning_rate_var_tf)))

        if train_loss:
            self.AdjustLearnRate()
            logging.info('TrainLogic train end.')
        else:
            logging.info('TrainLogic cv end.')

        # End input
        self.kaldi_io_nstream.JoinInput()

        self.ResetAccuracy()
        return tmp_label_error_rate

    def WholeTrainFunction(self, gpu_id, run_op, thread_name):
        """Run `run_op` on whole-utterance batches until the input is exhausted.

        Used for the 'ctc' / 'whole' / 'chain' criteria. The feed dict depends
        on ``self.criterion_cf``. The lattice criteria ('mmi', 'smbr', 'mpfe')
        and 'chain' also feed the entries of ``lat_list``.

        Args:
            gpu_id: slot in ``self.acc_label_error_rate`` / ``self.num_batch``.
                It is reset here and then accumulated once per batch.
            run_op: fetch dict passed to ``sess.run``. 'mean_loss' must be
                fetched. 'label_error_rate' and 'loss' are reported only when
                present; the evaluation fetch dict has no 'loss'.
            thread_name: label used in progress output.
        """
        logging.info('******start WholeTrainFunction******')
        total_curr_error_rate = 0.0
        num_batch = 0
        self.acc_label_error_rate[gpu_id] = 0.0
        self.num_batch[gpu_id] = 0
        total_curr_mean_loss = 0.0
        total_mean_loss = 0.0
        # Progress is logged every steps_per_checkpoint_cf / 50 batches. Clamp
        # to 1 so a small setting cannot cause a modulo/division by zero.
        log_interval = max(1, int(self.steps_per_checkpoint_cf / 50))

        while True:
            time1 = time.time()
            feat, label, length, lat_list = self.GetFeatAndLabel()
            if feat is None:
                logging.info('train thread end : %s' % thread_name)
                break
            time2 = time.time()

            if 'mmi' in self.criterion_cf or 'smbr' in self.criterion_cf or 'mpfe' in self.criterion_cf:
                feed_dict = {
                    self.X: feat,
                    self.Y: label,
                    self.seq_len: length,
                    self.indexs: lat_list[0],
                    self.pdf_values: lat_list[1],
                    self.lm_ws: lat_list[2],
                    self.am_ws: lat_list[3],
                    self.statesinfo: lat_list[4],
                    self.num_states: lat_list[5]
                }
            elif 'chain' in self.criterion_cf:
                # length is the valid length (an int), fed as a 1-element list.
                # For chain training self.Y carries the deriv_weights.
                feed_dict = {
                    self.X: feat,
                    self.Y: label,
                    self.length: [length],
                    self.indexs: lat_list[0],
                    self.in_labels: lat_list[1],
                    self.weights: lat_list[2],
                    self.statesinfo: lat_list[3],
                    self.num_states: lat_list[4]
                }
            else:
                feed_dict = {self.X: feat, self.Y: label, self.seq_len: length}

            time3 = time.time()
            calculate_return = self.sess.run(run_op, feed_dict=feed_dict)
            time4 = time.time()

            print("thread_name: ", thread_name, num_batch, " time:",
                  time2 - time1, time3 - time2, time4 - time3, time4 - time1)
            if 'label_error_rate' in calculate_return:
                print('label_error_rate:',
                      calculate_return['label_error_rate'])
                total_curr_error_rate += calculate_return['label_error_rate']
                self.acc_label_error_rate[gpu_id] += calculate_return[
                    'label_error_rate']
            print('mean_loss:', calculate_return['mean_loss'])
            # 'loss' is only fetched in training mode.
            if 'loss' in calculate_return:
                print('loss:', calculate_return['loss'])

            mean_loss = calculate_return['mean_loss']
            # Some graphs return the mean loss wrapped in a list.
            if isinstance(mean_loss, list):
                mean_loss = mean_loss[0]
            total_curr_mean_loss += mean_loss
            total_mean_loss += mean_loss

            num_batch += 1

            self.num_batch[gpu_id] += 1
            if self.num_batch[gpu_id] % log_interval == 0:
                logging.info(
                    "Batch: %d current averagelabel error rate : %s, mean loss : %s"
                    % (log_interval,
                       str(total_curr_error_rate / log_interval),
                       str(total_curr_mean_loss / log_interval)))
                total_curr_error_rate = 0.0
                total_curr_mean_loss = 0.0
                logging.info(
                    "Batch: %d current total averagelabel error rate : %s,  mean loss : %s"
                    % (self.num_batch[gpu_id],
                       str(self.acc_label_error_rate[gpu_id] /
                           self.num_batch[gpu_id]),
                       str(total_mean_loss / self.num_batch[gpu_id])))
        logging.info('******end TrainFunction******')

    def SliceTrainFunction(self, gpu_id, run_op, thread_name):
        """Run truncated-BPTT training on sliced batches until input is exhausted.

        Each item from ``GetFeatAndLabel`` holds lists of slices (``feat[i]``,
        ``label[i]``, ``length[i]``). The recurrent state is zeroed once per
        item and then carried across its slices by fetching
        'rnn_keep_state_op' together with the train op.

        Args:
            gpu_id: slot in ``self.acc_label_error_rate`` / ``self.num_batch``.
                It is reset here and then accumulated once per slice.
            run_op: fetch dict. It must contain 'train_op', 'mean_loss',
                'loss', 'rnn_keep_state_op', 'rnn_state_zero_op' and
                'label_error_rate'.
            thread_name: label used in progress output.
        """
        logging.info('******start SliceTrainFunction******')
        total_curr_error_rate = 0.0
        num_batch = 0
        self.acc_label_error_rate[gpu_id] = 0.0
        self.num_batch[gpu_id] = 0
        total_curr_mean_loss = 0.0
        total_mean_loss = 0.0
        # Progress is logged every steps_per_checkpoint_cf / 50 slices. Clamp
        # to 1 so a small setting cannot cause a modulo/division by zero.
        log_interval = max(1, int(self.steps_per_checkpoint_cf / 50))

        num_sentence = 0
        while True:
            time1 = time.time()
            feat, label, length, lat_list = self.GetFeatAndLabel()
            if feat is None:
                logging.info('train thread end : %s' % thread_name)
                break
            time2 = time.time()
            # Start every batch from a zero recurrent state.
            self.sess.run(run_op['rnn_state_zero_op'])
            # NOTE(review): an evaluation-only fetch dict without these keys
            # will raise KeyError here; callers must pass the full train dict.
            run_need_op = {
                'train_op': run_op['train_op'],
                'mean_loss': run_op['mean_loss'],
                'loss': run_op['loss'],
                'rnn_keep_state_op': run_op['rnn_keep_state_op'],
                'label_error_rate': run_op['label_error_rate']
            }
            for i in range(len(feat)):
                print('************input info**********:',
                      np.shape(feat[i]),
                      np.shape(label[i]),
                      length[i],
                      flush=True)
                time3 = time.time()
                feed_dict = {
                    self.X: feat[i],
                    self.Y: label[i],
                    self.seq_len: length[i]
                }
                time4 = time.time()
                calculate_return = self.sess.run(run_need_op,
                                                 feed_dict=feed_dict)
                time5 = time.time()
                print("thread_name: ", thread_name, num_batch, " time:",
                      time4 - time3, time5 - time4)
                print('label_error_rate:',
                      calculate_return['label_error_rate'])
                print('mean_loss:', calculate_return['mean_loss'])

                total_curr_mean_loss += calculate_return['mean_loss']
                total_mean_loss += calculate_return['mean_loss']

                num_batch += 1
                total_curr_error_rate += calculate_return['label_error_rate']
                self.acc_label_error_rate[gpu_id] += calculate_return[
                    'label_error_rate']
                self.num_batch[gpu_id] += 1
                if self.num_batch[gpu_id] % log_interval == 0:
                    logging.info(
                        "Batch: %d current averagelabel error rate : %f, mean loss : %f"
                        % (log_interval,
                           total_curr_error_rate / log_interval,
                           total_curr_mean_loss / log_interval))
                    total_curr_error_rate = 0.0
                    total_curr_mean_loss = 0.0
                    logging.info(
                        "Batch: %d current total averagelabel error rate : %f, mean loss : %f"
                        % (self.num_batch[gpu_id],
                           self.acc_label_error_rate[gpu_id] /
                           self.num_batch[gpu_id],
                           total_mean_loss / self.num_batch[gpu_id]))
            # Measured after the slice loop, so it is defined even when the
            # batch has no slices.
            batch_end = time.time()
            # print batch sentence time info
            print("******thread_name: ", thread_name, num_sentence, "io time:",
                  time2 - time1, "calculation time:", batch_end - time1)
            num_sentence += 1

        logging.info('******end SliceTrainFunction******')

    def GetFeatAndLabel(self):
        """Fetch the next input item from the currently selected Kaldi reader.

        Returns whatever ``self.kaldi_io_nstream.GetInput()`` yields. The
        training loops unpack it as ``(feat, label, length, lat_list)`` and
        treat ``feat is None`` as end of data.
        """
        reader = self.kaldi_io_nstream
        next_item = reader.GetInput()
        return next_item

    def GetAverageLabelErrorRate(self):
        """Return the mean per-batch label error rate across worker slots.

        Sums the first ``self.num_threads_cf`` entries of the per-worker
        accumulators and adds those sums to the running totals
        (``tot_lab_err_rate`` / ``tot_num_batch``). Returns 1.0 when no batch
        has been recorded. The per-worker accumulators are not reset here.
        """
        workers = range(self.num_threads_cf)
        err_sum = sum((self.acc_label_error_rate[w] for w in workers), 0.0)
        batch_sum = sum(self.num_batch[w] for w in workers)
        self.tot_lab_err_rate += err_sum
        self.tot_num_batch += batch_sum
        #self.ResetAccuracy(tot_reset = False)
        if batch_sum == 0:
            return 1.0
        return err_sum / batch_sum

    def GetTotLabErrRate(self):
        """Return the label error rate averaged over all accumulated batches.

        Returns 1.0 when no batch has been accumulated yet, for example right
        after ``ResetAccuracy()``. This matches the "no data" value of
        ``GetAverageLabelErrorRate`` and avoids a ZeroDivisionError.
        """
        if not self.tot_num_batch:
            return 1.0
        return self.tot_lab_err_rate / self.tot_num_batch

    def ResetAccuracy(self, tot_reset=True):
        """Clear the per-worker error/batch accumulators.

        Args:
            tot_reset: when true, also clear the running totals, restore the
                error-rate history to five 1.1 sentinels and reset
                ``num_save``. The history is replaced rather than appended
                to, so it keeps a fixed length across epochs.
        """
        for i in range(len(self.acc_label_error_rate)):
            self.acc_label_error_rate[i] = 0.0
            self.num_batch[i] = 0

        if tot_reset:
            self.tot_lab_err_rate = 0
            self.tot_num_batch = 0
            # 1.1 is worse than any real error rate, so the first recorded
            # values always count as improvements.
            self.all_lab_err_rate = [1.1] * 5
            self.num_save = 0
Ejemplo n.º 7
0
class TrainClass(object):
    '''Distributed CTC trainer driven by a tf.train.Supervisor session.

    Every attribute whose name ends in ``_cf`` is a configuration option. It
    starts with the default assigned in ``__init__`` and is overridden by
    ``conf_dict[<name without _cf>]`` when that key is present.

    Data flow: one reader thread (``InputFeat``) pushes
    ``(feat, sparse_label, length)`` batches into ``self.input_queue``,
    followed by ``(None, None, None)``. ``TrainFunction`` consumes them.
    '''
    def __init__(self, conf_dict):
        # configuration parameters (defaults; overridden from conf_dict below)
        self.conf_dict = conf_dict
        self.print_trainable_variables_cf = False
        self.use_normal_cf = False
        self.use_sgd_cf = True
        self.restore_training_cf = True
        self.checkpoint_dir_cf = None
        self.num_threads_cf = 1
        self.queue_cache_cf = 100
        self.task_index_cf = -1
        self.grad_clip_cf = 5.0
        self.feature_transfile_cf = None
        self.learning_rate_cf = 0.001
        self.batch_size_cf = 10

        # 'foo_cf' <- conf_dict['foo'] for every option present in conf_dict.
        # Only values of existing keys are reassigned, so iterating
        # self.__dict__ while doing so is safe.
        for attr in self.__dict__:
            parts = attr.split('_cf')
            if len(parts) != 2:
                continue
            key = parts[0]
            if key in conf_dict:
                self.__dict__[attr] = conf_dict[key]

        if self.feature_transfile_cf is None:
            logging.info('No feature_transfile,it must have.')
            sys.exit(1)
        feat_trans = FeatureTransform()
        feat_trans.LoadTransform(self.feature_transfile_cf)

        # init train file
        self.kaldi_io_nstream_train = KaldiDataReadParallel()
        self.input_dim = self.kaldi_io_nstream_train.Initialize(
            conf_dict,
            scp_file=conf_dict['tr_scp'], label=conf_dict['tr_label'],
            feature_transform=feat_trans, criterion='ctc')
        # init cv file
        self.kaldi_io_nstream_cv = KaldiDataReadParallel()
        self.kaldi_io_nstream_cv.Initialize(
            conf_dict,
            scp_file=conf_dict['cv_scp'], label=conf_dict['cv_label'],
            feature_transform=feat_trans, criterion='ctc')
        # Reader currently in use (train or cv); selected in TrainLogic.
        self.kaldi_io_nstream = None

        self.num_batch_total = 0
        self.tot_lab_err_rate = 0.0
        self.tot_num_batch = 0

        logging.info(self.kaldi_io_nstream_train.__repr__())
        logging.info(self.kaldi_io_nstream_cv.__repr__())

        # Bounded queue between the reader thread and the training loop.
        self.input_queue = Queue.Queue(self.queue_cache_cf)

        # Error rates of the last five AdjustLearnRate() calls (ring buffer).
        self.all_lab_err_rate = [1.1] * 5
        self.num_save = 0
        # Per-worker accumulators: summed label error rate and batch count.
        self.acc_label_error_rate = [1.0] * self.num_threads_cf
        self.num_batch = [0] * self.num_threads_cf

    # multi computers construct train graph
    def ConstructGraph(self, device, server):
        '''Build the CTC training graph on `device` and open a session.

        Args:
            device: TF device string the graph is placed on.
            server: tf.train.Server whose target the session connects to.

        Sets the X / Y / seq_len placeholders, self.run_ops, self.saver and
        self.sess. tf.train.Supervisor lets the chief (task_index_cf == 0)
        initialize or restore from checkpoint_dir_cf while other workers wait.
        '''
        with tf.device(device):
            self.X = tf.placeholder(tf.float32, [None, None, self.input_dim],
                    name='feature')
            self.Y = tf.sparse_placeholder(tf.int32, name="labels")
            self.seq_len = tf.placeholder(tf.int32, [None], name='seq_len')
            self.learning_rate_var_tf = tf.Variable(float(self.learning_rate_cf),
                    trainable=False, name='learning_rate')
            # Op used by decay_learning_rate(). It must be created here
            # because the Supervisor below finalizes the graph.
            self.lr_decay_factor_tf = tf.placeholder(tf.float32, [],
                    name='lr_decay_factor')
            self.lr_decay_op = tf.assign(self.learning_rate_var_tf,
                    self.learning_rate_var_tf * self.lr_decay_factor_tf)
            if self.use_sgd_cf:
                optimizer = tf.train.GradientDescentOptimizer(self.learning_rate_var_tf)
            else:
                optimizer = tf.train.AdamOptimizer(learning_rate=
                        self.learning_rate_var_tf, beta1=0.9, beta2=0.999, epsilon=1e-08)
            nnet_model = LstmModel(self.conf_dict)

            ctc_mean_loss, ctc_loss, label_error_rate, decoded = nnet_model.CtcLoss(self.X, self.Y, self.seq_len)

            if self.use_sgd_cf and self.use_normal_cf:
                # Clip gradients by global norm before applying them.
                tvars = tf.trainable_variables()
                grads, _ = tf.clip_by_global_norm(tf.gradients(
                    ctc_mean_loss, tvars), self.grad_clip_cf)
                train_op = optimizer.apply_gradients(
                        zip(grads, tvars),
                        global_step=tf.contrib.framework.get_or_create_global_step())
            else:
                train_op = optimizer.minimize(ctc_mean_loss)
            # set run operation
            self.run_ops = {'train_op': train_op,
                    'ctc_mean_loss': ctc_mean_loss,
                    'ctc_loss': ctc_loss,
                    'label_error_rate': label_error_rate}

            # set initial parameter
            self.init_para = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

            tmp_variables = tf.trainable_variables()
            self.saver = tf.train.Saver(tmp_variables, max_to_keep=100)

            self.total_variables = np.sum([np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()])
            logging.info('total parameters : %d' % self.total_variables)

            # set gpu option
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)

            # session config
            sess_config = tf.ConfigProto(intra_op_parallelism_threads=self.num_threads_cf,
                    inter_op_parallelism_threads=self.num_threads_cf,
                    allow_soft_placement=True,
                    log_device_placement=False, gpu_options=gpu_options)
            global_step = tf.contrib.framework.get_or_create_global_step()

            sv = tf.train.Supervisor(is_chief=(self.task_index_cf == 0),
                    global_step=global_step,
                    init_op=self.init_para,
                    logdir=self.checkpoint_dir_cf,
                    saver=self.saver,
                    save_model_secs=3600,
                    checkpoint_basename='model.ckpt')

            self.sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)

    def SaveTextModel(self):
        '''Dump the latest checkpoint's trainable variables as text, if enabled.'''
        if self.print_trainable_variables_cf == True:
            ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir_cf)
            if ckpt and ckpt.model_checkpoint_path:
                print_trainable_variables(self.sess, ckpt.model_checkpoint_path + '.txt')

    def InputFeat(self, input_lock):
        '''Reader-thread body: move batches from the current reader to the queue.

        Stops at the end of the data or at the first batch smaller than
        batch_size_cf, then enqueues (None, None, None) as the end marker.
        Every 3000 batches it also saves a checkpoint and adjusts the
        learning rate.
        '''
        while True:
            # `with` releases the lock on every path, including `break`.
            with input_lock:
                feat, label, length = self.kaldi_io_nstream.LoadNextNstreams()
                if length is None or len(label) != self.batch_size_cf:
                    break
                sparse_label = sparse_tuple_from(label)
                self.input_queue.put((feat, sparse_label, length))
                self.num_batch_total += 1
                if self.num_batch_total % 3000 == 0:
                    self.SaveModel()
                    self.AdjustLearnRate()
                print('total_batch_num**********', self.num_batch_total, '***********')
        self.input_queue.put((None, None, None))

    def ThreadInputFeatAndLab(self):
        '''Start the reader thread and return it in a list.'''
        input_thread = []
        input_lock = threading.Lock()
        for i in range(1):
            input_thread.append(threading.Thread(group=None, target=self.InputFeat,
                args=(input_lock,), name='read_thread' + str(i)))

        for thr in input_thread:
            logging.info('ThreadInputFeatAndLab start')
            thr.start()

        return input_thread

    def SaveModel(self):
        '''Save a checkpoint once the consumer has drained the input queue.'''
        while True:
            time.sleep(1.0)
            if self.input_queue.empty():
                checkpoint_path = os.path.join(self.checkpoint_dir_cf,
                        str(self.num_batch_total) + '_model' + '.ckpt')
                logging.info('save model: ' + checkpoint_path +
                        ' --- learn_rate: ' +
                        str(self.sess.run(self.learning_rate_var_tf)))
                self.saver.save(self.sess, checkpoint_path)
                break

    def decay_learning_rate(self, decay):
        '''Multiply the learning-rate variable by `decay` in the session.'''
        self.sess.run(self.lr_decay_op,
                feed_dict={self.lr_decay_factor_tf: decay})
        logging.info('decay learn_rate to ' +
                str(self.sess.run(self.learning_rate_var_tf)))

    # if current label error rate less then previous five
    def AdjustLearnRate(self):
        '''Decay the learning rate by 0.8 unless the current error rate beats
        at least one of the last five recorded rates, then record it.'''
        curr_lab_err_rate = self.GetAvergaeLabelErrorRate()
        logging.info("current label error rate : %f\n" % curr_lab_err_rate)
        history = self.all_lab_err_rate
        if history and all(curr_lab_err_rate >= prev for prev in history):
            self.decay_learning_rate(0.8)
        history[self.num_save % len(history)] = curr_lab_err_rate
        self.num_save += 1

    # train_loss is a open train or cv .
    def TrainLogic(self, device, shuffle=False, train_loss=True, skip_offset=0):
        '''Run one pass over the train data (train_loss=True) or the cv data.

        Returns the pass's average label error rate. Afterwards the reader is
        reset with `shuffle` / `skip_offset` and the accumulators are cleared.
        '''
        if train_loss == True:
            self.kaldi_io_nstream = self.kaldi_io_nstream_train
            run_op = self.run_ops
        else:
            self.kaldi_io_nstream = self.kaldi_io_nstream_cv
            # Evaluation only: fetch metrics, never 'train_op'.
            run_op = {'label_error_rate': self.run_ops['label_error_rate'],
                    'ctc_mean_loss': self.run_ops['ctc_mean_loss']}

        threadinput = self.ThreadInputFeatAndLab()
        time.sleep(3)
        logging.info('train start.')
        with tf.device(device):
            self.TrainFunction(0, run_op, 'train_thread_hubo')

        logging.info('train end.')
        threadinput[0].join()

        tmp_label_error_rate = self.GetAvergaeLabelErrorRate()
        self.kaldi_io_nstream.Reset(shuffle=shuffle, skip_offset=skip_offset)
        self.ResetAccuracy()
        return tmp_label_error_rate

    def TrainFunction(self, gpu_id, run_op, thread_name):
        '''Consume queued batches until the (None, None, None) end marker.

        Accumulates label error rate and batch count into slot `gpu_id`.
        '''
        logging.info('******start TrainFunction******')
        num_batch = 0
        self.acc_label_error_rate[gpu_id] = 0.0
        self.num_batch[gpu_id] = 0

        while True:
            time1 = time.time()
            feat, sparse_label, length = self.GetFeatAndLabel()
            if feat is None:
                logging.info('train thread end : %s' % thread_name)
                break
            time2 = time.time()

            feed_dict = {self.X: feat, self.Y: sparse_label, self.seq_len: length}
            time3 = time.time()
            calculate_return = self.sess.run(run_op, feed_dict=feed_dict)
            time4 = time.time()

            print("thread_name: ", thread_name, num_batch, " time:", time2 - time1, time3 - time2, time4 - time3, time4 - time1)
            print('label_error_rate:', calculate_return['label_error_rate'])
            print('ctc_mean_loss:', calculate_return['ctc_mean_loss'])

            num_batch += 1
            self.acc_label_error_rate[gpu_id] += calculate_return['label_error_rate']
            self.num_batch[gpu_id] += 1
        logging.info('******end TrainFunction******')

    def GetFeatAndLabel(self):
        '''Block until the reader thread supplies the next queued batch.'''
        return self.input_queue.get()

    def GetAvergaeLabelErrorRate(self):
        '''Return the mean per-batch label error rate since the last call.

        Adds the per-worker sums to the running totals and then clears the
        per-worker accumulators. Returns 1.0 when no batch was recorded.
        '''
        tot_label_error_rate = 0.0
        tot_num_batch = 0
        for i in range(self.num_threads_cf):
            tot_label_error_rate += self.acc_label_error_rate[i]
            tot_num_batch += self.num_batch[i]
        if tot_num_batch == 0:
            average_label_error_rate = 1.0
        else:
            average_label_error_rate = tot_label_error_rate / tot_num_batch
        self.tot_lab_err_rate += tot_label_error_rate
        self.tot_num_batch += tot_num_batch
        self.ResetAccuracy(tot_reset=False)
        return average_label_error_rate

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    GetAverageLabelErrorRate = GetAvergaeLabelErrorRate

    def GetTotLabErrRate(self):
        '''Return the error rate over all accumulated batches (1.0 if none).'''
        if not self.tot_num_batch:
            return 1.0
        return self.tot_lab_err_rate / self.tot_num_batch

    def ResetAccuracy(self, tot_reset=True):
        '''Clear per-worker accumulators. If `tot_reset`, also clear the running
        totals and restore the five-entry error history to 1.1 sentinels.'''
        for i in range(len(self.acc_label_error_rate)):
            self.acc_label_error_rate[i] = 0.0
            self.num_batch[i] = 0

        if tot_reset:
            self.tot_lab_err_rate = 0
            self.tot_num_batch = 0
            # Replace (not extend) so the history keeps a fixed length of five.
            self.all_lab_err_rate = [1.1] * 5
            self.num_save = 0
Ejemplo n.º 8
0
class train_class(object):
    def __init__(self, conf_dict):
        """Set up the train/cv readers and read training options.

        Args:
            conf_dict: configuration dict. Required keys read here are
                'feature_transfile', 'scp_file', 'label', 'cv_scp',
                'cv_label', 'checkpoint_dir', 'num_threads' and
                'queue_cache'; ProjConfig.initial may read more. Optional keys
                and the defaults used when absent:
                'print_trainable_variables' (False), 'use_normal' (False),
                'use_sgd' (True), 'restore_training' (False).
        """
        self.nnet_conf = ProjConfig()
        self.nnet_conf.initial(conf_dict)

        # Reader currently in use; set to the train or cv reader before
        # loading (e.g. cv_logic selects the cv reader).
        self.kaldi_io_nstream = None
        feat_trans = FeatureTransform()
        feat_trans.LoadTransform(conf_dict['feature_transfile'])
        # init train file
        self.kaldi_io_nstream_train = KaldiDataReadParallel()
        self.input_dim = self.kaldi_io_nstream_train.Initialize(
            conf_dict,
            scp_file=conf_dict['scp_file'],
            label=conf_dict['label'],
            feature_transform=feat_trans,
            criterion='ce')

        # init cv file
        self.kaldi_io_nstream_cv = KaldiDataReadParallel()
        self.kaldi_io_nstream_cv.Initialize(conf_dict,
                                            scp_file=conf_dict['cv_scp'],
                                            label=conf_dict['cv_label'],
                                            feature_transform=feat_trans,
                                            criterion='ce')

        self.num_batch_total = 0
        self.num_frames_total = 0

        logging.info(self.nnet_conf.__repr__())
        logging.info(self.kaldi_io_nstream_train.__repr__())
        logging.info(self.kaldi_io_nstream_cv.__repr__())

        # dict.get works on Python 2 and 3; dict.has_key is Python 2 only.
        self.print_trainable_variables = conf_dict.get(
            'print_trainable_variables', False)
        self.tf_async_model_prefix = conf_dict['checkpoint_dir']
        self.num_threads = conf_dict['num_threads']
        self.queue_cache = conf_dict['queue_cache']
        self.input_queue = Queue.Queue(self.queue_cache)
        # One error-rate slot per worker thread. 1.1 is a "no result yet"
        # value that is worse than any real error rate.
        self.acc_label_error_rate = [1.1] * self.num_threads
        self.use_normal = conf_dict.get('use_normal', False)
        self.use_sgd = conf_dict.get('use_sgd', True)
        self.restore_training = conf_dict.get('restore_training', False)

    def get_num(self, str):
        """Return the integer prefix of a checkpoint path's file name.

        The file name is cut at its first '_', e.g.
        '/ckpt/48434_model.ckpt' -> 48434. A file name that does not start
        with digits followed by '_' (such as 'model_48434.ckpt.final') makes
        int() raise ValueError. The parameter keeps its historical name
        `str`, which shadows the builtin inside this method.
        """
        file_name = str.rsplit('/', 1)[-1]
        prefix, _, _ = file_name.partition('_')
        return int(prefix)
    def construct_graph(self):
        """Build the CE training graph and create, then initialize or restore, the session.

        One tower is built per worker thread on "/gpu:<i>". All towers read
        the same placeholders, and variables are reused after the first tower.
        Each tower's fetch dict is appended to ``self.run_ops``.

        When ``self.restore_training`` is set and a checkpoint exists under
        ``self.tf_async_model_prefix``, the trainable variables are restored
        and ``self.num_batch_total`` is parsed from the checkpoint file name
        with ``get_num``. If ``self.print_trainable_variables`` is also set,
        the variables are dumped to '<checkpoint>.txt' and the process exits.
        """
        with tf.Graph().as_default():
            self.run_ops = []
            #self.X = tf.placeholder(tf.float32, [None, None, self.input_dim], name='feature')
            print(self.nnet_conf.num_frames_batch, self.nnet_conf.batch_size,
                  self.input_dim)
            # Fixed-shape input: [num_frames_batch, batch_size, input_dim]
            # (frames first; presumably time-major, TODO confirm in LSTM_Model).
            self.X = tf.placeholder(tf.float32, [
                self.nnet_conf.num_frames_batch, self.nnet_conf.batch_size,
                self.input_dim
            ],
                                    name='feature')
            #self.Y = tf.sparse_placeholder(tf.int32, name="labels")
            # Dense frame labels: [batch_size, num_frames_batch].
            self.Y = tf.placeholder(
                tf.int32,
                [self.nnet_conf.batch_size, self.nnet_conf.num_frames_batch],
                name="labels")
            self.seq_len = tf.placeholder(tf.int32, [None], name='seq_len')

            # NOTE(review): trainable=False means the Saver below (built over
            # trainable variables only) neither saves nor restores this value.
            self.learning_rate_var = tf.Variable(float(
                self.nnet_conf.learning_rate),
                                                 trainable=False,
                                                 name='learning_rate')
            if self.use_sgd:
                optimizer = tf.train.GradientDescentOptimizer(
                    self.learning_rate_var)
            else:
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=self.learning_rate_var,
                    beta1=0.9,
                    beta2=0.999,
                    epsilon=1e-08)

            # One tower per worker thread, all sharing one optimizer instance.
            for i in range(self.num_threads):
                with tf.device("/gpu:%d" % i):
                    # NOTE(review): `initializer` is never passed to a variable
                    # scope or the model, so it has no effect as written.
                    initializer = tf.random_uniform_initializer(
                        -self.nnet_conf.init_scale, self.nnet_conf.init_scale)
                    model = LSTM_Model(self.nnet_conf)
                    mean_loss, ce_loss, rnn_keep_state_op, rnn_state_zero_op, label_error_rate, softval = model.ce_train(
                        self.X, self.Y, self.seq_len)
                    if self.use_sgd and self.use_normal:
                        # SGD with global-norm gradient clipping.
                        tvars = tf.trainable_variables()
                        grads, _ = tf.clip_by_global_norm(
                            tf.gradients(mean_loss, tvars),
                            self.nnet_conf.grad_clip)
                        train_op = optimizer.apply_gradients(
                            zip(grads, tvars),
                            global_step=tf.contrib.framework.
                            get_or_create_global_step())
                    else:
                        train_op = optimizer.minimize(mean_loss)

                    run_op = {
                        'train_op': train_op,
                        'mean_loss': mean_loss,
                        'ce_loss': ce_loss,
                        'rnn_keep_state_op': rnn_keep_state_op,
                        'rnn_state_zero_op': rnn_state_zero_op,
                        'label_error_rate': label_error_rate,
                        'softval': softval
                    }
                    self.run_ops.append(run_op)
                    # Later towers reuse the variables created by the first.
                    tf.get_variable_scope().reuse_variables()

            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
            self.sess = tf.Session(config=tf.ConfigProto(
                intra_op_parallelism_threads=self.num_threads,
                allow_soft_placement=True,
                log_device_placement=False,
                gpu_options=gpu_options))
            init = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())
            tmp_variables = tf.trainable_variables()
            self.saver = tf.train.Saver(tmp_variables, max_to_keep=100)
            #self.saver = tf.train.Saver(max_to_keep=100)
            if self.restore_training:
                # Initialize everything first; restore then overwrites the
                # trainable variables covered by the Saver.
                self.sess.run(init)
                ckpt = tf.train.get_checkpoint_state(
                    self.tf_async_model_prefix)
                if ckpt and ckpt.model_checkpoint_path:
                    logging.info("restore training")
                    self.saver.restore(self.sess, ckpt.model_checkpoint_path)
                    # Resume the batch counter from the checkpoint file name.
                    self.num_batch_total = self.get_num(
                        ckpt.model_checkpoint_path)
                    if self.print_trainable_variables == True:
                        # Dump-only mode: write variables as text and exit.
                        print_trainable_variables(
                            self.sess, ckpt.model_checkpoint_path + '.txt')
                        sys.exit(0)

                    logging.info('model:' + ckpt.model_checkpoint_path)
                    logging.info('restore learn_rate:' +
                                 str(self.sess.run(self.learning_rate_var)))
                    #print('*******************',self.num_batch_total)
                    #time.sleep(3)
                    #model_48434.ckpt.final
                    #print("ckpt.model_checkpoint_path",ckpt.model_checkpoint_path)
                    #print("self.tf_async_model_prefix",self.tf_async_model_prefix)
                    #self.saver.restore(self.sess, self.tf_async_model_prefix)

                else:
                    logging.info('No checkpoint file found')
                    self.sess.run(init)
                    logging.info('init learn_rate:' +
                                 str(self.sess.run(self.learning_rate_var)))
            else:
                self.sess.run(init)

            self.total_variables = np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])
            logging.info('total parameters : %d' % self.total_variables)

    def train_function(self, gpu_id, run_op, thread_name):
        """Train on queued sliced batches until the end-of-input marker arrives.

        Each queued item is ``(feat, label, length)``, where each element is a
        list of slices. The recurrent state is zeroed once per item and then
        carried across its slices via 'rnn_keep_state_op'.

        Args:
            gpu_id: slot in ``self.acc_label_error_rate`` to update.
            run_op: one tower's fetch dict (see construct_graph).
            thread_name: label used in progress output.

        ``self.acc_label_error_rate[gpu_id]`` is set to the mean label error
        rate over all slices seen so far. It is left unchanged when no slice
        was processed, instead of raising ZeroDivisionError.
        """
        total_acc_error_rate = 0.0
        num_batch = 0
        num_bptt = 0
        # The fetch dict does not change between slices; build it once.
        run_need_op = {
            'train_op': run_op['train_op'],
            'mean_loss': run_op['mean_loss'],
            'ce_loss': run_op['ce_loss'],
            'rnn_keep_state_op': run_op['rnn_keep_state_op'],
            'label_error_rate': run_op['label_error_rate']
        }
        while True:
            time1 = time.time()
            feat, label, length = self.get_feat_and_label()
            if feat is None:
                logging.info('train thread ok: %s' % thread_name)
                break
            time2 = time.time()
            print('******time:', time2 - time1, thread_name)
            # Start every batch from a zero recurrent state.
            self.sess.run(run_op['rnn_state_zero_op'])
            for i in range(len(feat)):
                feed_dict = {
                    self.X: feat[i],
                    self.Y: label[i],
                    self.seq_len: length[i]
                }
                time3 = time.time()
                calculate_return = self.sess.run(run_need_op,
                                                 feed_dict=feed_dict)
                print('mean_loss:', calculate_return['mean_loss'])
                time4 = time.time()
                print(num_batch, " time:", time4 - time3)
                print('label_error_rate:',
                      calculate_return['label_error_rate'])
                total_acc_error_rate += calculate_return['label_error_rate']
                num_bptt += 1
            time5 = time.time()

            print(num_batch, " time:", time2 - time1, time5 - time2)

            num_batch += 1
            if num_bptt:
                self.acc_label_error_rate[gpu_id] = total_acc_error_rate / num_bptt

    def cv_function(self, gpu_id, run_op, thread_name):
        """Evaluate queued sliced batches (no training) until end of input.

        Works like train_function but never fetches 'train_op'. The
        recurrent state is zeroed per batch and carried across its slices.

        Args:
            gpu_id: slot in ``self.acc_label_error_rate`` to update.
            run_op: one tower's fetch dict (see construct_graph).
            thread_name: label used in progress output.

        ``self.acc_label_error_rate[gpu_id]`` is set to the mean label error
        rate over all slices seen so far. It is left unchanged when no slice
        was processed, instead of raising ZeroDivisionError.
        """
        total_acc_error_rate = 0.0
        num_batch = 0
        num_bptt = 0
        # The fetch dict does not change between slices; build it once.
        run_need_op = {
            'mean_loss': run_op['mean_loss'],
            'ce_loss': run_op['ce_loss'],
            'rnn_keep_state_op': run_op['rnn_keep_state_op'],
            'label_error_rate': run_op['label_error_rate']
        }
        while True:
            feat, label, length = self.get_feat_and_label()
            if feat is None:
                logging.info('cv ok : %s\n' % thread_name)
                break
            # Start every batch from a zero recurrent state.
            self.sess.run(run_op['rnn_state_zero_op'])
            for i in range(len(feat)):
                feed_dict = {
                    self.X: feat[i],
                    self.Y: label[i],
                    self.seq_len: length[i]
                }
                calculate_return = self.sess.run(run_need_op,
                                                 feed_dict=feed_dict)
                total_acc_error_rate += calculate_return['label_error_rate']
                print('label_error_rate:',
                      calculate_return['label_error_rate'])
                print('mean_loss:', calculate_return['mean_loss'])
                print('ce_loss', calculate_return['ce_loss'])
                num_bptt += 1
            num_batch += 1
            if num_bptt:
                self.acc_label_error_rate[gpu_id] = total_acc_error_rate / num_bptt

    def get_feat_and_label(self):
        """Take the next item from ``self.input_queue``.

        Items are queued by input_feat_and_label / input_ce_feat_and_label.
        The worker loops stop when the returned ``feat`` is None.
        """
        queue_ = self.input_queue
        item = queue_.get()
        return item

    def input_feat_and_label(self):
        """Queue one whole-utterance batch with sparse labels.

        Returns False and queues nothing when the reader is exhausted
        (``length is None``) or yields a batch whose label count differs from
        ``nnet_conf.batch_size``. Otherwise it converts the labels with
        sparse_tuple_from, enqueues ``(feat, sparse_label, length)``, updates
        the batch and frame counters, and returns True.
        """
        feat, label, length = self.kaldi_io_nstream.LoadNextNstreams()
        exhausted = length is None
        if exhausted or len(label) != self.nnet_conf.batch_size:
            return False
        self.input_queue.put((feat, sparse_tuple_from(label), length))
        self.num_batch_total += 1
        self.num_frames_total += sum(length)
        print('total_batch_num**********', self.num_batch_total, '***********')
        return True

    def input_ce_feat_and_label(self):
        """Load one sliced CE batch and enqueue it unchanged.

        ``SliceLoadNextNstreams`` returns parallel lists of slices; the batch
        size is checked against the first label slice.

        Returns:
            False at end of data (``length_array is None``) or on a batch whose
            first label slice is not ``nnet_conf.batch_size`` long; otherwise
            True after enqueuing and adding every slice length to
            ``num_frames_total``.
        """
        loaded = self.kaldi_io_nstream.SliceLoadNextNstreams()
        feat_array, label_array, length_array = loaded
        if length_array is None:
            return False
        if len(label_array[0]) != self.nnet_conf.batch_size:
            return False
        self.input_queue.put((feat_array, label_array, length_array))
        self.num_batch_total += 1
        self.num_frames_total += sum(
            frames for slice_lens in length_array for frames in slice_lens)
        print('total_batch_num**********', self.num_batch_total, '***********')
        return True

    def cv_logic(self):
        """Evaluate on the CV reader and return the average label error rate.

        Starts one ``cv_function`` worker per tower and feeds every CV batch
        through the shared queue. It then puts one stop sentinel per worker,
        waits for the queue to drain (polling with a short sleep rather than
        busy-spinning) and joins the workers. Finally it resets the CV reader
        and the per-thread accumulators.

        Returns:
            The value of ``get_avergae_label_error_rate()`` taken before the
            accumulators are reset.
        """
        self.kaldi_io_nstream = self.kaldi_io_nstream_cv
        cv_threads = [
            threading.Thread(group=None,
                             target=self.cv_function,
                             args=(i, self.run_ops[i],
                                   'thread_hubo_' + str(i)),
                             name='thread_hubo_' + str(i))
            for i in range(self.num_threads)
        ]
        for thr in cv_threads:
            thr.start()

        logging.info('start cv thread.')
        while self.input_ce_feat_and_label():
            pass
        logging.info('cv read feat ok')

        # One sentinel per worker so that every thread leaves its loop.
        for _ in cv_threads:
            self.input_queue.put((None, None, None))

        # Wait for the workers to consume everything without burning a core.
        while not self.input_queue.empty():
            time.sleep(0.1)
        logging.info('cv is ok')

        for thr in cv_threads:
            thr.join()
            logging.info('join cv thread %s' % thr.name)

        tmp_label_error_rate = self.get_avergae_label_error_rate()
        self.kaldi_io_nstream.Reset()
        self.reset_acc()
        return tmp_label_error_rate

    def train_logic(self):
        """Run one training epoch and return the average label error rate.

        Starts one ``train_function`` worker per tower and feeds batches from
        ``kaldi_io_nstream_train`` into the shared queue. Every 1000 batches
        (including before the first one) it waits for the queue to drain and
        saves a checkpoint. At each non-initial checkpoint it compares the
        current average label error rate with a 5-entry history initialised to
        1.1. The learning rate is halved when the current rate is not lower
        than any entry. Either way, the current rate then replaces the largest
        entry.

        After the reader is exhausted, the workers are stopped with one
        sentinel each and a ``.final`` checkpoint is saved once the queue has
        drained. The workers are then joined, and the reader and accumulators
        are reset.

        Returns:
            The value of ``get_avergae_label_error_rate()`` taken before the
            accumulators are reset.
        """
        self.kaldi_io_nstream = self.kaldi_io_nstream_train
        train_thread = [
            threading.Thread(group=None,
                             target=self.train_function,
                             args=(i, self.run_ops[i],
                                   'thread_hubo_' + str(i)),
                             name='thread_hubo_' + str(i))
            for i in range(self.num_threads)
        ]
        for thr in train_thread:
            thr.start()

        logging.info('start train thread ok.\n')

        # History of label error rates used to decide on learning-rate decay.
        all_lab_err_rate = [1.1] * 5

        while True:
            # save model
            if self.num_batch_total % 1000 == 0:
                # Let the workers consume every queued batch first.
                time.sleep(0.5)
                while not self.input_queue.empty():
                    time.sleep(0.5)
                checkpoint_path = os.path.join(
                    self.tf_async_model_prefix,
                    str(self.num_batch_total) + '_model' + '.ckpt')
                logging.info('save model: ' + checkpoint_path +
                             ' --- learn_rate: ' +
                             str(self.sess.run(self.learning_rate_var)))
                self.saver.save(self.sess, checkpoint_path)

                if self.num_batch_total != 0:
                    curr_lab_err_rate = self.get_avergae_label_error_rate()
                    all_lab_err_rate.sort()
                    # Decay when the current rate beats no entry of the
                    # history (a NaN rate also triggers decay).
                    if not any(curr_lab_err_rate < past
                               for past in all_lab_err_rate):
                        self.decay_learning_rate(0.5)
                    all_lab_err_rate[-1] = curr_lab_err_rate
            # input data
            if not self.input_ce_feat_and_label():
                break
        time.sleep(1)
        logging.info('read feat ok')

        # End of training data: one stop sentinel per worker.
        for _ in train_thread:
            self.input_queue.put((None, None, None))

        # Poll with a short sleep instead of busy-spinning on empty().
        while not self.input_queue.empty():
            time.sleep(0.1)
        logging.info('train is end')
        checkpoint_path = os.path.join(
            self.tf_async_model_prefix,
            str(self.num_batch_total) + '_model' + '.ckpt')
        self.saver.save(self.sess, checkpoint_path + '.final')

        for thr in train_thread:
            thr.join()
            logging.info('join thread %s' % thr.name)

        tmp_label_error_rate = self.get_avergae_label_error_rate()
        self.kaldi_io_nstream.Reset()
        self.reset_acc()
        return tmp_label_error_rate

    def decay_learning_rate(self, lr_decay_factor):
        """Multiply ``self.learning_rate_var`` by ``lr_decay_factor`` and log it.

        NOTE(review): a new multiply/assign op pair is added to the graph on
        every call; acceptable for occasional decays, but confirm this is not
        called in a tight loop.
        """
        decayed_value = tf.multiply(self.learning_rate_var, lr_decay_factor)
        assign_op = self.learning_rate_var.assign(decayed_value)
        self.sess.run(assign_op)
        new_learning_rate = self.sess.run(self.learning_rate_var)
        logging.info('learn_rate decay to ' + str(new_learning_rate))
        logging.info('lr_decay_factor is ' + str(lr_decay_factor))


#        return learning_rate_decay_op

    def get_avergae_label_error_rate(self):
        """Return (and log) the mean of the per-thread label error rates.

        Only the first ``self.num_threads`` entries of
        ``self.acc_label_error_rate`` are averaged.
        """
        thread_ids = range(self.num_threads)
        rate_sum = sum((self.acc_label_error_rate[t] for t in thread_ids), 0.0)
        average_label_error_rate = rate_sum / self.num_threads
        logging.info("average label error rate : %f" %
                     average_label_error_rate)
        return average_label_error_rate

    def reset_acc(self):
        """Set every per-thread label error rate back to the 1.1 sentinel."""
        rates = self.acc_label_error_rate
        # Slice assignment keeps the same list object (workers index into it).
        rates[:] = [1.1] * len(rates)
Ejemplo n.º 9
0
    # NOTE(review): stand-alone smoke-test fragment for KaldiDataReadParallel
    # with the 'chain' criterion; its enclosing header (a function or an
    # `if __name__ == '__main__':` guard) is not visible in this chunk.
    # Hard-coded, machine-specific data directory.
    path = '/search/speech/hubo/git/tf-code-acoustics/chain_source_7300/'
    #path = './'
    # Reader options handed to KaldiDataReadParallel.Initialize below.
    conf_dict = {
        'batch_size': 64,
        'skip_offset': 0,
        'skip_frame': 3,
        'shuffle': False,
        'queue_cache': 10,
        'io_thread_num': 5
    }
    # Feature transform loaded from file and passed to the reader.
    feat_trans_file = '../conf/final.feature_transform'
    feat_trans = FeatureTransform()
    feat_trans.LoadTransform(feat_trans_file)
    # Send INFO-level log records to test.log in the working directory.
    logging.basicConfig(filename='test.log')
    logging.getLogger().setLevel('INFO')
    io_read = KaldiDataReadParallel()
    io_read.Initialize(
        conf_dict,
        scp_file=path + 'cegs.1.scp',
        #io_read.Initialize(conf_dict, scp_file=path+'scp',
        #io_read.Initialize(conf_dict, scp_file=path+'cegs.all.scp_0',
        feature_transform=feat_trans,
        criterion='chain')

    # NOTE(review): batch_info, start and batch_num are assigned but not used
    # in the visible code; presumably a timing/read loop followed here.
    batch_info = 2000
    start = time.time()
    io_read.Reset(shuffle=False)
    batch_num = 0
    # NOTE(review): `nnet_conf` is not defined in this fragment, so this line
    # raises NameError unless an enclosing scope provides it — confirm.
    model = CommonModel(nnet_conf)

    # NOTE(review): no optimizer is created in the visible code after this
    # heading.
    # Instantiate an optimizer.
    def __init__(self, conf_dict):
        """Build the trainer from a flat configuration dict.

        The trainer options set below get defaults, and any key of the same
        name present in ``conf_dict`` overrides them. The whole dict is also
        passed to ``ProjConfig.initial`` and to the train and CV readers.

        Keys read directly: 'scp_file', 'label', 'cv_scp', 'cv_label'. If
        'feature_transfile' is missing, an error is logged and the process
        exits with status 1.
        """
        # Trainer options and their defaults (overridable from conf_dict).
        self.print_trainable_variables = False
        self.use_normal = False
        self.use_sgd = True
        self.restore_training = True

        self.checkpoint_dir = None
        self.num_threads = 1
        self.queue_cache = 100

        self.feature_transfile = None
        # Only names already defined above are taken from the config; other
        # config keys are left to ProjConfig and the readers.
        for key in list(self.__dict__):
            if key in conf_dict:
                setattr(self, key, conf_dict[key])

        # initial nnet configuration parameter
        self.nnet_conf = ProjConfig()
        self.nnet_conf.initial(conf_dict)

        # Points at the train or CV reader while an epoch is running.
        self.kaldi_io_nstream = None

        if self.feature_transfile is None:
            logging.error('No feature_transfile,it must have.')
            sys.exit(1)
        feat_trans = FeatureTransform()
        feat_trans.LoadTransform(self.feature_transfile)
        # init train file
        self.kaldi_io_nstream_train = KaldiDataReadParallel()
        self.input_dim = self.kaldi_io_nstream_train.Initialize(
            conf_dict,
            scp_file=conf_dict['scp_file'],
            label=conf_dict['label'],
            feature_transform=feat_trans,
            criterion='ctc')

        # init cv file
        self.kaldi_io_nstream_cv = KaldiDataReadParallel()
        self.kaldi_io_nstream_cv.Initialize(conf_dict,
                                            scp_file=conf_dict['cv_scp'],
                                            label=conf_dict['cv_label'],
                                            feature_transform=feat_trans,
                                            criterion='ctc')

        self.num_batch_total = 0
        # Epoch-level totals rolled up by get_avergae_label_error_rate.
        self.tot_lab_err_rate = 0.0
        self.tot_num_batch = 0

        logging.info(self.nnet_conf.__repr__())
        logging.info(self.kaldi_io_nstream_train.__repr__())
        logging.info(self.kaldi_io_nstream_cv.__repr__())

        self.input_queue = Queue.Queue(self.queue_cache)

        # Per-thread accumulated label error rate and batch count.
        self.acc_label_error_rate = [1.0] * self.num_threads
        self.num_batch = [0] * self.num_threads
        # Ring buffer of recent error rates used to adjust the learn rate.
        self.all_lab_err_rate = [1.1] * 5
        self.num_save = 0
class train_class(object):
    def __init__(self, conf_dict):
        """Build the trainer from a flat configuration dict.

        The trainer options set below get defaults, and any key of the same
        name present in ``conf_dict`` overrides them. The whole dict is also
        passed to ``ProjConfig.initial`` and to the train and CV readers.

        Keys read directly: 'scp_file', 'label', 'cv_scp', 'cv_label'. If
        'feature_transfile' is missing, an error is logged and the process
        exits with status 1.
        """
        # Trainer options and their defaults (overridable from conf_dict).
        self.print_trainable_variables = False
        self.use_normal = False
        self.use_sgd = True
        self.restore_training = True

        self.checkpoint_dir = None
        self.num_threads = 1
        self.queue_cache = 100

        self.feature_transfile = None
        # Only names already defined above are taken from the config; other
        # config keys are left to ProjConfig and the readers.
        for key in list(self.__dict__):
            if key in conf_dict:
                setattr(self, key, conf_dict[key])

        # initial nnet configuration parameter
        self.nnet_conf = ProjConfig()
        self.nnet_conf.initial(conf_dict)

        # Points at the train or CV reader while an epoch is running.
        self.kaldi_io_nstream = None

        if self.feature_transfile is None:
            logging.error('No feature_transfile,it must have.')
            sys.exit(1)
        feat_trans = FeatureTransform()
        feat_trans.LoadTransform(self.feature_transfile)
        # init train file
        self.kaldi_io_nstream_train = KaldiDataReadParallel()
        self.input_dim = self.kaldi_io_nstream_train.Initialize(
            conf_dict,
            scp_file=conf_dict['scp_file'],
            label=conf_dict['label'],
            feature_transform=feat_trans,
            criterion='ctc')

        # init cv file
        self.kaldi_io_nstream_cv = KaldiDataReadParallel()
        self.kaldi_io_nstream_cv.Initialize(conf_dict,
                                            scp_file=conf_dict['cv_scp'],
                                            label=conf_dict['cv_label'],
                                            feature_transform=feat_trans,
                                            criterion='ctc')

        self.num_batch_total = 0
        # Epoch-level totals rolled up by get_avergae_label_error_rate.
        self.tot_lab_err_rate = 0.0
        self.tot_num_batch = 0

        logging.info(self.nnet_conf.__repr__())
        logging.info(self.kaldi_io_nstream_train.__repr__())
        logging.info(self.kaldi_io_nstream_cv.__repr__())

        self.input_queue = Queue.Queue(self.queue_cache)

        # Per-thread accumulated label error rate and batch count.
        self.acc_label_error_rate = [1.0] * self.num_threads
        self.num_batch = [0] * self.num_threads
        # Ring buffer of recent error rates used to adjust the learn rate.
        self.all_lab_err_rate = [1.1] * 5
        self.num_save = 0

    # Parse the batch count from a checkpoint path like '<dir>/3000_model.ckpt'.
    def get_num(self, str):
        """Return the leading integer of a checkpoint file name.

        For example, ``'<dir>/3000_model.ckpt'`` yields ``3000``. The text
        after the last '/' and before the first '_' is parsed with ``int``.
        """
        # The parameter name shadows the builtin `str`; it is kept unchanged
        # for interface compatibility.
        file_name = str.rsplit('/', 1)[-1]
        number_text = file_name.partition('_')[0]
        return int(number_text)

    # construct train graph
    def construct_graph(self):
        """Build the multi-tower CTC training graph, session and saver.

        In a fresh default graph this creates:
        - shared placeholders ``self.X`` (features), ``self.Y`` (sparse
          labels) and ``self.seq_len``;
        - a non-trainable learning-rate variable and a single optimizer (SGD
          or Adam);
        - one ``LSTM_Model`` tower per worker thread on ``/gpu:<i>``. Each
          tower's fetch dict is appended to ``self.run_ops``.

        It then creates ``self.sess`` and a ``Saver`` over the trainable
        variables, initialises all variables, and, when ``restore_training``
        is set, restores the latest checkpoint in ``self.checkpoint_dir`` if
        one exists.
        """
        with tf.Graph().as_default():
            self.run_ops = []
            self.X = tf.placeholder(tf.float32, [None, None, self.input_dim],
                                    name='feature')
            self.Y = tf.sparse_placeholder(tf.int32, name="labels")
            self.seq_len = tf.placeholder(tf.int32, [None], name='seq_len')

            # Not trainable, so it is excluded from self.saver below and is
            # not restored from checkpoints.
            self.learning_rate_var = tf.Variable(float(
                self.nnet_conf.learning_rate),
                                                 trainable=False,
                                                 name='learning_rate')
            if self.use_sgd:
                optimizer = tf.train.GradientDescentOptimizer(
                    self.learning_rate_var)
            else:
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=self.learning_rate_var,
                    beta1=0.9,
                    beta2=0.999,
                    epsilon=1e-08)

            for i in range(self.num_threads):
                with tf.device("/gpu:%d" % i):
                    # NOTE(review): nnet_conf.init_scale is not applied to any
                    # variable in this method; confirm LSTM_Model handles
                    # initialisation itself.
                    model = LSTM_Model(self.nnet_conf)
                    mean_loss, ctc_loss, label_error_rate, _, _ = model.loss(
                        self.X, self.Y, self.seq_len)
                    if self.use_sgd and self.use_normal:
                        # SGD with gradients clipped by global norm.
                        tvars = tf.trainable_variables()
                        grads, _ = tf.clip_by_global_norm(
                            tf.gradients(mean_loss, tvars),
                            self.nnet_conf.grad_clip)
                        train_op = optimizer.apply_gradients(
                            zip(grads, tvars),
                            global_step=tf.contrib.framework.
                            get_or_create_global_step())
                    else:
                        train_op = optimizer.minimize(mean_loss)

                    self.run_ops.append({
                        'train_op': train_op,
                        'mean_loss': mean_loss,
                        'ctc_loss': ctc_loss,
                        'label_error_rate': label_error_rate
                    })
                    # Later towers reuse the variables of the current scope.
                    tf.get_variable_scope().reuse_variables()

            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
            self.sess = tf.Session(config=tf.ConfigProto(
                intra_op_parallelism_threads=self.num_threads,
                allow_soft_placement=True,
                log_device_placement=False,
                gpu_options=gpu_options))
            init = tf.group(tf.global_variables_initializer(),
                            tf.local_variables_initializer())

            tmp_variables = tf.trainable_variables()
            self.saver = tf.train.Saver(tmp_variables, max_to_keep=100)

            # Every path starts from initialised variables (run once); a
            # restored checkpoint then overwrites the saved ones.
            self.sess.run(init)
            if self.restore_training:
                ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    logging.info("restore training")
                    self.saver.restore(self.sess, ckpt.model_checkpoint_path)
                    self.num_batch_total = self.get_num(
                        ckpt.model_checkpoint_path)
                    if self.print_trainable_variables == True:
                        print_trainable_variables(
                            self.sess, ckpt.model_checkpoint_path + '.txt')
                        sys.exit(0)
                    logging.info('model:' + ckpt.model_checkpoint_path)
                    logging.info('restore learn_rate:' +
                                 str(self.sess.run(self.learning_rate_var)))
                else:
                    logging.info('No checkpoint file found')
                    logging.info('init learn_rate:' +
                                 str(self.sess.run(self.learning_rate_var)))

            self.total_variables = np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])
            logging.info('total parameters : %d' % self.total_variables)

    def SaveTextModel(self):
        if self.print_trainable_variables == True:
            ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print_trainable_variables(self.sess,
                                          ckpt.model_checkpoint_path + '.txt')

    def train_function(self, gpu_id, run_op, thread_name):
        total_acc_error_rate = 0.0
        num_batch = 0
        self.acc_label_error_rate[gpu_id] = 0.0
        self.num_batch[gpu_id] = 0
        while True:
            time1 = time.time()
            feat, sparse_label, length = self.get_feat_and_label()
            if feat is None:
                logging.info('train thread end : %s' % thread_name)
                break
            time2 = time.time()
            print('******time:', time2 - time1, thread_name)

            feed_dict = {
                self.X: feat,
                self.Y: sparse_label,
                self.seq_len: length
            }

            time3 = time.time()
            calculate_return = self.sess.run(run_op, feed_dict=feed_dict)
            time4 = time.time()

            print("thread_name: ", thread_name, num_batch, " time:",
                  time2 - time1, time3 - time2, time4 - time3, time4 - time1)
            print('label_error_rate:', calculate_return['label_error_rate'])
            print('mean_loss:', calculate_return['mean_loss'])
            print('ctc_loss:', calculate_return['ctc_loss'])

            num_batch += 1
            total_acc_error_rate += calculate_return['label_error_rate']
            self.acc_label_error_rate[gpu_id] += calculate_return[
                'label_error_rate']
            self.num_batch[gpu_id] += 1

    def cv_function(self, gpu_id, run_op, thread_name):
        total_acc_error_rate = 0.0
        num_batch = 0
        self.acc_label_error_rate[gpu_id] = 0.0
        self.num_batch[gpu_id] = 0
        while True:
            feat, sparse_label, length = self.get_feat_and_label()
            if feat is None:
                logging.info('cv thread end : %s' % thread_name)
                break
            feed_dict = {
                self.X: feat,
                self.Y: sparse_label,
                self.seq_len: length
            }
            run_need_op = {'label_error_rate': run_op['label_error_rate']}
            #'softval':run_op['softval']}
            #'mean_loss':run_op['mean_loss'],
            #'ctc_loss':run_op['ctc_loss'],
            #'label_error_rate':run_op['label_error_rate']}
            calculate_return = self.sess.run(run_need_op, feed_dict=feed_dict)
            print('label_error_rate:', calculate_return['label_error_rate'])
            #print('softval:',calculate_return['softval'])
            num_batch += 1
            total_acc_error_rate += calculate_return['label_error_rate']
            self.acc_label_error_rate[gpu_id] += calculate_return[
                'label_error_rate']
            self.num_batch[gpu_id] += 1

    def get_feat_and_label(self):
        return self.input_queue.get()

    def input_feat_and_label(self):
        strat_io_time = time.time()
        feat, label, length = self.kaldi_io_nstream.LoadNextNstreams()
        end_io_time = time.time()
        #        print('*************io time**********************',end_io_time-strat_io_time)
        if length is None:
            return False
        if len(label) != self.nnet_conf.batch_size:
            return False
        sparse_label = sparse_tuple_from(label)
        self.input_queue.put((feat, sparse_label, length))
        self.num_batch_total += 1
        print('total_batch_num**********', self.num_batch_total, '***********')
        return True

    def InputFeat(self, input_lock):
        while True:
            feat, label, length = self.kaldi_io_nstream.LoadNextNstreams()
            if length is None:
                break
            if len(label) != self.nnet_conf.batch_size:
                break
            sparse_label = sparse_tuple_from(label)
            input_lock.acquire()
            self.input_queue.put((feat, sparse_label, length))
            self.num_batch_total += 1
            if self.num_batch_total % 3000 == 0:
                self.SaveModel()
                self.AdjustLearnRate()
            print('total_batch_num**********', self.num_batch_total,
                  '***********')
            input_lock.release()

    def ThreadInputFeatAndLab(self):
        input_thread = []
        input_lock = threading.Lock()
        for i in range(2):
            input_thread.append(
                threading.Thread(group=None,
                                 target=self.InputFeat,
                                 args=(input_lock),
                                 name='read_thread' + str(i)))
        for thr in input_thread:
            thr.start()

        for thr in input_thread:
            thr.join()

    def SaveModel(self):
        """Wait for the input queue to drain, then write a numbered checkpoint.

        Sleeps one second, then keeps sleeping in one-second steps while the
        queue is non-empty. The checkpoint is then saved as
        ``<checkpoint_dir>/<num_batch_total>_model.ckpt`` and the current
        learning rate is logged.
        """
        time.sleep(1.0)
        while not self.input_queue.empty():
            time.sleep(1.0)
        file_name = str(self.num_batch_total) + '_model' + '.ckpt'
        checkpoint_path = os.path.join(self.checkpoint_dir, file_name)
        current_lr = self.sess.run(self.learning_rate_var)
        logging.info('save model: ' + checkpoint_path +
                     ' --- learn_rate: ' + str(current_lr))
        self.saver.save(self.sess, checkpoint_path)

    # Decay the learning rate (x0.8) unless the current label error rate is
    # lower than at least one rate recorded in self.all_lab_err_rate.
    def AdjustLearnRate(self):
        curr_lab_err_rate = self.get_avergae_label_error_rate()
        logging.info("current label error rate : %f\n" % curr_lab_err_rate)
        all_lab_err_rate_len = len(self.all_lab_err_rate)
        #self.all_lab_err_rate.sort()
        for i in range(all_lab_err_rate_len):
            if curr_lab_err_rate < self.all_lab_err_rate[i]:
                break
            if i == len(self.all_lab_err_rate) - 1:
                self.decay_learning_rate(0.8)
        self.all_lab_err_rate[self.num_save %
                              all_lab_err_rate_len] = curr_lab_err_rate
        self.num_save += 1

    def cv_logic(self):
        self.kaldi_io_nstream = self.kaldi_io_nstream_cv
        train_thread = []
        for i in range(self.num_threads):
            #            self.acc_label_error_rate.append(1.0)
            train_thread.append(
                threading.Thread(group=None,
                                 target=self.cv_function,
                                 args=(i, self.run_ops[i],
                                       'thread_hubo_' + str(i)),
                                 name='thread_hubo_' + str(i)))

        for thr in train_thread:
            thr.start()
        logging.info('cv thread start.')

        while True:
            if self.input_feat_and_label():
                continue
            break

        logging.info('cv read feat ok')
        for thr in train_thread:
            self.input_queue.put((None, None, None))

        while True:
            if self.input_queue.empty():
                break

        logging.info('cv is end.')
        for thr in train_thread:
            thr.join()
            logging.info('join cv thread %s' % thr.name)

        tmp_label_error_rate = self.get_avergae_label_error_rate()
        self.kaldi_io_nstream.Reset()
        self.reset_acc()
        return tmp_label_error_rate

    def train_logic(self, shuffle=False, skip_offset=0):
        """Run one training epoch and return the average label error rate.

        Starts one ``train_function`` worker per tower and feeds batches from
        ``kaldi_io_nstream_train``. Every 3000 batches (including before the
        first one) it waits for the queue to drain and saves a checkpoint; at
        every non-initial checkpoint it also calls ``AdjustLearnRate``. When
        the reader is exhausted it stops the workers with one sentinel each
        and saves a ``.final`` checkpoint once the queue has drained. It then
        joins the workers and resets the reader for the next epoch.

        Args:
            shuffle: forwarded to the reader's ``Reset``.
            skip_offset: forwarded to the reader's ``Reset``.

        Returns:
            The value of ``get_avergae_label_error_rate()`` taken before the
            accumulators are reset.
        """
        self.kaldi_io_nstream = self.kaldi_io_nstream_train
        train_thread = [
            threading.Thread(group=None,
                             target=self.train_function,
                             args=(i, self.run_ops[i],
                                   'thread_hubo_' + str(i)),
                             name='thread_hubo_' + str(i))
            for i in range(self.num_threads)
        ]
        for thr in train_thread:
            thr.start()

        logging.info('train thread start.')

        tot_time = 0.0
        while True:
            if self.num_batch_total % 3000 == 0:
                # Let the workers consume every queued batch first.
                time.sleep(0.5)
                while not self.input_queue.empty():
                    time.sleep(0.5)
                checkpoint_path = os.path.join(
                    self.checkpoint_dir,
                    str(self.num_batch_total) + '_model' + '.ckpt')
                logging.info('save model: ' + checkpoint_path +
                             '--- learn_rate: ' +
                             str(self.sess.run(self.learning_rate_var)))
                self.saver.save(self.sess, checkpoint_path)

                if self.num_batch_total != 0:
                    self.AdjustLearnRate()  # adjust learn rate

            s_1 = time.time()
            if not self.input_feat_and_label():
                break
            e_1 = time.time()
            tot_time += e_1 - s_1
            print("***self.input_feat_and_label time*****", e_1 - s_1)
        print("total input time:", tot_time)
        time.sleep(1)
        logging.info('read feat ok')

        # End of training data: one stop sentinel per worker.
        for _ in train_thread:
            self.input_queue.put((None, None, None))

        # Poll with a short sleep instead of busy-spinning on empty().
        while not self.input_queue.empty():
            time.sleep(0.1)
        checkpoint_path = os.path.join(
            self.checkpoint_dir,
            str(self.num_batch_total) + '_model' + '.ckpt')
        logging.info('save model: ' + checkpoint_path +
                     '.final --- learn_rate: ' +
                     str(self.sess.run(self.learning_rate_var)))
        self.saver.save(self.sess, checkpoint_path + '.final')

        logging.info('train is end.')
        for thr in train_thread:
            thr.join()
            logging.info('join thread %s' % thr.name)

        tmp_label_error_rate = self.get_avergae_label_error_rate()
        self.kaldi_io_nstream.Reset(shuffle=shuffle, skip_offset=skip_offset)
        self.reset_acc()
        return tmp_label_error_rate

    def decay_learning_rate(self, lr_decay_factor):
        """Multiply ``self.learning_rate_var`` by ``lr_decay_factor`` and log it.

        NOTE(review): a new multiply/assign op pair is added to the graph on
        every call; acceptable for occasional decays, but confirm this is not
        called in a tight loop.
        """
        decayed_value = tf.multiply(self.learning_rate_var, lr_decay_factor)
        assign_op = self.learning_rate_var.assign(decayed_value)
        self.sess.run(assign_op)
        new_learning_rate = self.sess.run(self.learning_rate_var)
        logging.info('learn_rate decay to ' + str(new_learning_rate))
        logging.info('lr_decay_factor is ' + str(lr_decay_factor))

    def get_avergae_label_error_rate(self):
        tot_label_error_rate = 0.0
        tot_num_batch = 0
        for i in range(self.num_threads):
            tot_label_error_rate += self.acc_label_error_rate[i]
            tot_num_batch += self.num_batch[i]
        if tot_num_batch == 0:
            average_label_error_rate = 1.0
        else:
            average_label_error_rate = tot_label_error_rate / tot_num_batch


#        logging.info("average label error rate : %f\n" % average_label_error_rate)
        self.tot_lab_err_rate += tot_label_error_rate
        self.tot_num_batch += tot_num_batch
        self.reset_acc(tot_reset=False)
        return average_label_error_rate

    def GetTotLabErrRate(self):
        return self.tot_lab_err_rate / self.tot_num_batch

    def reset_acc(self, tot_reset=True):
        for i in range(len(self.acc_label_error_rate)):
            self.acc_label_error_rate[i] = 0.0
            self.num_batch[i] = 0
        if tot_reset:
            self.tot_lab_err_rate = 0
            self.tot_num_batch = 0
            for i in range(5):
                self.all_lab_err_rate.append(1.1)
            self.num_save = 0