Ejemplo n.º 1
0
    def __init__(self, conf_dict):
        """Build the CE training driver from a configuration dict.

        Sets up the nnet configuration, the feature transform, the
        train/cv Kaldi data readers, the inter-thread input queue and
        per-thread label-error-rate accumulators.

        Args:
            conf_dict: configuration dict.  Required keys:
                'feature_transfile', 'scp_file', 'label', 'cv_scp',
                'cv_label', 'checkpoint_dir', 'num_threads',
                'queue_cache'.  Optional keys (with defaults):
                'print_trainable_variables' (False), 'use_normal'
                (False), 'use_sgd' (True), 'restore_training' (False).
        """
        self.nnet_conf = ProjConfig()
        self.nnet_conf.initial(conf_dict)

        self.kaldi_io_nstream = None
        feat_trans = FeatureTransform()
        feat_trans.LoadTransform(conf_dict['feature_transfile'])

        # Training-set reader; Initialize returns the input feature dim.
        self.kaldi_io_nstream_train = KaldiDataReadParallel()
        self.input_dim = self.kaldi_io_nstream_train.Initialize(
            conf_dict,
            scp_file=conf_dict['scp_file'],
            label=conf_dict['label'],
            feature_transform=feat_trans,
            criterion='ce')

        # Cross-validation reader (same transform and criterion).
        self.kaldi_io_nstream_cv = KaldiDataReadParallel()
        self.kaldi_io_nstream_cv.Initialize(conf_dict,
                                            scp_file=conf_dict['cv_scp'],
                                            label=conf_dict['cv_label'],
                                            feature_transform=feat_trans,
                                            criterion='ce')

        self.num_batch_total = 0
        self.num_frames_total = 0

        logging.info(repr(self.nnet_conf))
        logging.info(repr(self.kaldi_io_nstream_train))
        logging.info(repr(self.kaldi_io_nstream_cv))

        # dict.get replaces the Python-2-only dict.has_key() branches
        # while keeping the exact same defaults.
        self.print_trainable_variables = conf_dict.get(
            'print_trainable_variables', False)
        self.tf_async_model_prefix = conf_dict['checkpoint_dir']
        self.num_threads = conf_dict['num_threads']
        self.queue_cache = conf_dict['queue_cache']
        self.input_queue = Queue.Queue(self.queue_cache)
        # Seed each thread's error rate above 100% so the first real
        # batch always counts as an improvement.
        self.acc_label_error_rate = [1.1] * self.num_threads
        self.use_normal = conf_dict.get('use_normal', False)
        self.use_sgd = conf_dict.get('use_sgd', True)
        self.restore_training = conf_dict.get('restore_training', False)
Ejemplo n.º 2
0
                pri += key + ':\t' + 'too big not print' + '\n'
        pri += '}'
        return pri


if __name__ == '__main__':
    conf_dict = {
        'batch_size': 100,
        'skip_frame': 3,
        'skip_offset': 0,
        'do_skip_lab': True,
        'shuffle': False
    }
    path = '/search/speech/hubo/git/tf-code-acoustics/train-data'
    feat_trans_file = '../conf/final.feature_transform'
    feat_trans = FeatureTransform()
    feat_trans.LoadTransform(feat_trans_file)

    logging.basicConfig(filename='test.log')
    logging.getLogger().setLevel('INFO')
    io_read = KaldiDataReadParallel()
    io_read.Initialize(conf_dict,
                       scp_file=path + '/abc.scp',
                       label=path + '/merge_sort_cv.labels',
                       feature_transform=feat_trans,
                       criterion='whole')

    #label = path+'/sort_tr.labels.4026.ce',
    start = time.time()
    while True:
        #feat_mat, label, length = io_read.LoadNextNstreams()
Ejemplo n.º 3
0
    def __init__(self, conf_dict):
        """Build the CTC training driver from a configuration dict.

        Every attribute ending in '_cf' is a configuration slot: after
        the defaults are set, any matching key (name without the '_cf'
        suffix) found in conf_dict overrides the default.  String
        values for non-string slots are eval'ed (e.g. '1e-4' -> float).

        Args:
            conf_dict: configuration dict.  Required keys: 'tr_scp',
                'tr_label' and 'feature_transfile' (exits if the
                latter is missing); all '_cf' slots below are optional
                overrides.
        """
        # Configuration parameters and their defaults.
        self.conf_dict = conf_dict
        self.print_trainable_variables_cf = False
        self.use_normal_cf = False
        self.use_sgd_cf = True
        self.restore_training_cf = True
        self.checkpoint_dir_cf = None
        self.num_threads_cf = 1
        self.queue_cache_cf = 100
        self.task_index_cf = -1
        self.grad_clip_cf = 5.0
        self.feature_transfile_cf = None
        self.learning_rate_cf = 0.1
        self.learning_rate_decay_steps_cf = 100000
        self.learning_rate_decay_rate_cf = 0.96
        self.batch_size_cf = 16
        self.num_frames_batch_cf = 20

        self.steps_per_checkpoint_cf = 1000

        self.criterion_cf = 'ctc'
        # Override defaults from conf_dict.  Iterate over a snapshot so
        # value assignments cannot interact with dict iteration.
        for attr in list(self.__dict__):
            parts = attr.split('_cf')
            if len(parts) != 2:
                continue
            key = parts[0]
            if key not in conf_dict:
                continue
            value = conf_dict[key]
            if key in strset or type(value) is not str:
                self.__dict__[attr] = value
            else:
                print('***************', key)
                # NOTE(review): eval() on config strings is unsafe for
                # untrusted input; kept because configs carry numeric
                # literals as strings — confirm configs are trusted.
                self.__dict__[attr] = eval(value)

        if self.feature_transfile_cf is None:
            logging.info('No feature_transfile,it must have.')
            sys.exit(1)
        feat_trans = FeatureTransform()
        feat_trans.LoadTransform(self.feature_transfile_cf)

        # Training-set reader; Initialize returns the input feature dim.
        self.kaldi_io_nstream_train = KaldiDataReadParallel()
        self.input_dim = self.kaldi_io_nstream_train.Initialize(
            conf_dict,
            scp_file=conf_dict['tr_scp'],
            label=conf_dict['tr_label'],
            feature_transform=feat_trans,
            criterion=self.criterion_cf)
        # CV reader is created but (intentionally) not initialized here.
        self.kaldi_io_nstream_cv = KaldiDataReadParallel()

        self.num_batch_total = 0
        self.tot_lab_err_rate = 0.0
        self.tot_num_batch = 0.0

        logging.info(repr(self.kaldi_io_nstream_train))

        # Bounded queue feeding batches to the training threads.
        self.input_queue = Queue.Queue(self.queue_cache_cf)

        # Per-thread error-rate accumulators; all_lab_err_rate keeps the
        # last 5 rates (seeded above 100%) for learning-rate decisions.
        self.num_save = 0
        self.all_lab_err_rate = [1.1] * 5
        self.acc_label_error_rate = [1.0] * self.num_threads_cf
        self.num_batch = [0] * self.num_threads_cf

        return
    def __init__(self, conf_dict):
        """Build the CTC training driver from a configuration dict.

        Defaults below are overridden by any conf_dict key of the same
        name; then the nnet config, feature transform, train/cv Kaldi
        readers, input queue and error-rate accumulators are set up.

        Args:
            conf_dict: configuration dict.  Required keys: 'scp_file',
                'label', 'cv_scp', 'cv_label' and 'feature_transfile'
                (exits if the latter stays None).  Optional overrides:
                'print_trainable_variables', 'use_normal', 'use_sgd',
                'restore_training', 'checkpoint_dir', 'num_threads',
                'queue_cache'.
        """
        # Configuration parameters and their defaults.
        self.print_trainable_variables = False
        self.use_normal = False
        self.use_sgd = True
        self.restore_training = True

        self.checkpoint_dir = None
        self.num_threads = 1
        self.queue_cache = 100

        self.feature_transfile = None
        # Override defaults from conf_dict.  Iterate over a snapshot so
        # value assignments cannot interact with dict iteration.
        for key in list(self.__dict__):
            if key in conf_dict:
                self.__dict__[key] = conf_dict[key]

        # Neural-net configuration.
        self.nnet_conf = ProjConfig()
        self.nnet_conf.initial(conf_dict)

        self.kaldi_io_nstream = None

        if self.feature_transfile is None:
            logging.info('No feature_transfile,it must have.')
            sys.exit(1)
        feat_trans = FeatureTransform()
        feat_trans.LoadTransform(self.feature_transfile)

        # Training-set reader; Initialize returns the input feature dim.
        self.kaldi_io_nstream_train = KaldiDataReadParallel()
        self.input_dim = self.kaldi_io_nstream_train.Initialize(
            conf_dict,
            scp_file=conf_dict['scp_file'],
            label=conf_dict['label'],
            feature_transform=feat_trans,
            criterion='ctc')

        # Cross-validation reader (same transform and criterion).
        self.kaldi_io_nstream_cv = KaldiDataReadParallel()
        self.kaldi_io_nstream_cv.Initialize(conf_dict,
                                            scp_file=conf_dict['cv_scp'],
                                            label=conf_dict['cv_label'],
                                            feature_transform=feat_trans,
                                            criterion='ctc')

        self.num_batch_total = 0
        self.tot_lab_err_rate = 0.0
        self.tot_num_batch = 0

        logging.info(repr(self.nnet_conf))
        logging.info(repr(self.kaldi_io_nstream_train))
        logging.info(repr(self.kaldi_io_nstream_cv))

        # Bounded queue feeding batches to the training threads.
        self.input_queue = Queue.Queue(self.queue_cache)

        self.num_save = 0
        # Last 5 label error rates (seeded above 100%) used to adjust
        # the learning rate.
        self.all_lab_err_rate = [1.1] * 5
        # Per-thread label error rate and batch counters.
        self.acc_label_error_rate = [1.0] * self.num_threads
        self.num_batch = [0] * self.num_threads