def remove_space(l):
    """Drop the space token from a label sequence."""
    labelUtil = LabelUtil()
    space_index = labelUtil.get_space_index()
    return [token for token in l if token != space_index]

def update(self, labels, preds):
    check_label_shapes(labels, preds)
    if self.is_logging:
        log = LogUtil().getlogger()
    labelUtil = LabelUtil()
    self.batch_loss = 0.
    host_name = socket.gethostname()
    for label, pred in zip(labels, preds):
        label = label.asnumpy()
        pred = pred.asnumpy()
        batch_per_gpu = int(self.batch_size) // int(self.num_gpu)
        seq_length = len(pred) // batch_per_gpu
        for i in range(batch_per_gpu):
            l = remove_blank(label[i])
            p = []
            probs = []
            for k in range(seq_length):
                p.append(np.argmax(pred[k * batch_per_gpu + i]))
                probs.append(pred[k * batch_per_gpu + i])
            p = pred_best(p)
            l_distance = levenshtein_distance(l, p)
            self.total_n_label += len(l)
            self.total_l_dist += l_distance
            this_cer = float(l_distance) / float(len(l))
            # only log utterances whose CER is noticeably bad
            if self.is_logging and this_cer > 0.4:
                log.info("%s label: %s " % (host_name, labelUtil.convert_num_to_word(l)))
                log.info("%s pred : %s , cer: %f (distance: %d/ label length: %d)" %
                         (host_name, labelUtil.convert_num_to_word(p), this_cer,
                          l_distance, len(l)))
            self.num_inst += 1
            self.sum_metric += this_cer
            # epoch-end CTC loss computation is disabled in this variant
            self.total_ctc_loss += 0  # self.batch_loss

def update(self, labels, preds):
    check_label_shapes(labels, preds)
    if self.is_logging:
        log = LogUtil().getlogger()
    labelUtil = LabelUtil.getInstance()
    self.batch_loss = 0.
    for label, pred in zip(labels, preds):
        label = label.asnumpy()
        pred = pred.asnumpy()
        batch_per_gpu = int(self.batch_size) // int(self.num_gpu)
        for i in range(batch_per_gpu):
            l = remove_blank(label[i])
            p = []
            for k in range(int(self.seq_length)):
                p.append(np.argmax(pred[k * batch_per_gpu + i]))
            p = pred_best(p)
            l_distance = levenshtein_distance(l, p)
            self.total_n_label += len(l)
            self.total_l_dist += l_distance
            this_cer = float(l_distance) / float(len(l))
            if self.is_logging:
                log.info("label: %s " % labelUtil.convert_num_to_word(l))
                log.info("pred : %s , cer: %f (distance: %d/ label length: %d)" %
                         (labelUtil.convert_num_to_word(p), this_cer, l_distance, len(l)))
            self.num_inst += 1
            self.sum_metric += this_cer
            if self.is_epoch_end:
                loss = ctc_loss(l, pred, i, int(self.seq_length),
                                int(self.batch_size), int(self.num_gpu))
                self.batch_loss += loss
                if self.is_logging:
                    log.info("loss: %f " % loss)
    self.total_ctc_loss += self.batch_loss

def remove_space(l):
    labelUtil = LabelUtil.getInstance()
    space_index = labelUtil.get_space_index()
    return [token for token in l if token != space_index]

def update(self, labels, preds):
    check_label_shapes(labels, preds)
    log = LogUtil().getlogger()
    labelUtil = LabelUtil.getInstance()
    for label, pred in zip(labels, preds):
        label = label.asnumpy()
        pred = pred.asnumpy()
        batch_per_gpu = int(self.batch_size) // int(self.num_gpu)
        for i in range(batch_per_gpu):
            l = remove_blank(label[i])
            p = []
            for k in range(int(self.seq_length)):
                p.append(np.argmax(pred[k * batch_per_gpu + i]))
            p = pred_best(p)
            l_distance = levenshtein_distance(l, p)
            self.total_n_label += len(l)
            self.total_l_dist += l_distance
            this_cer = float(l_distance) / float(len(l))
            log.info("label: %s " % labelUtil.convert_num_to_word(l))
            log.info("pred : %s , cer: %f (distance: %d/ label length: %d)" %
                     (labelUtil.convert_num_to_word(p), this_cer, l_distance, len(l)))
            self.num_inst += 1
            self.sum_metric += this_cer
            if self.is_epoch_end:
                loss = ctc_loss(l, pred, i, int(self.seq_length),
                                int(self.batch_size), int(self.num_gpu))
                self.total_ctc_loss += loss
                log.info("loss: %f " % loss)

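The update() variants above lean on small helpers (remove_blank, pred_best, levenshtein_distance) defined elsewhere in the repository. As a reading aid, here is a minimal sketch of what they plausibly do; treating index 0 as the CTC blank/padding token is an assumption, not the repository's confirmed convention.

# Hedged sketch of the CTC metric helpers; assumes blank/padding index 0.
def remove_blank(l, blank=0):
    # Drop blank/padding tokens from a label row.
    return [int(t) for t in l if int(t) != blank]

def pred_best(p, blank=0):
    # Standard CTC greedy collapse: merge consecutive repeats, then drop blanks.
    collapsed = []
    prev = None
    for t in p:
        if t != prev and t != blank:
            collapsed.append(t)
        prev = t
    return collapsed

def levenshtein_distance(a, b):
    # Classic single-row dynamic-programming edit distance.
    dp = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, cb in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (ca != cb))
    return dp[len(b)]
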
def prepare_minibatch(self, audio_paths, texts, overwrite=False,
                      is_bi_graphemes=False, seq_length=-1,
                      save_feature_as_csvfile=False):
    """ Featurize a minibatch of audio, zero pad them and return a dictionary
    Params:
        audio_paths (list(str)): List of paths to audio files
        texts (list(str)): List of texts corresponding to the audio files
    Returns:
        dict: See below for contents
    """
    assert len(audio_paths) == len(texts), \
        "Inputs and outputs to the network must be of the same number"
    # Features is a list of (timesteps, feature_dim) arrays.
    # Calculate the features for each audio clip, as the log of the
    # Fourier Transform of the audio.
    features = [self.featurize(a, overwrite=overwrite,
                               save_feature_as_csvfile=save_feature_as_csvfile)
                for a in audio_paths]
    input_lengths = [f.shape[0] for f in features]
    feature_dim = features[0].shape[1]
    mb_size = len(features)
    # Pad all the inputs so that they are all the same length
    if seq_length == -1:
        x = np.zeros((mb_size, self.max_seq_length, feature_dim))
    else:
        x = np.zeros((mb_size, seq_length, feature_dim))
    y = np.zeros((mb_size, self.max_label_length))
    labelUtil = LabelUtil.getInstance()
    label_lengths = []
    for i in range(mb_size):
        feat = features[i]
        feat = self.normalize(feat)  # Center using means and std
        x[i, :feat.shape[0], :] = feat
        if is_bi_graphemes:
            label = generate_bi_graphemes_label(texts[i])
            label = labelUtil.convert_bi_graphemes_to_num(label)
            y[i, :len(label)] = label
        else:
            label = labelUtil.convert_word_to_num(texts[i])
            y[i, :len(texts[i])] = label
        label_lengths.append(len(label))
    return {
        'x': x,                          # 0-padded features of shape (mb_size, timesteps, feat_dim)
        'y': y,                          # 0-padded labels (integer sequences)
        'texts': texts,                  # list(str) Original texts
        'input_lengths': input_lengths,  # list(int) Length of each input
        'label_lengths': label_lengths,  # list(int) Length of each label
    }

def get_scorer(alpha=1., beta=1.):
    try:
        from swig_wrapper import Scorer
        labelUtil = LabelUtil()
        vocab_list = [chars.encode("utf-8") for chars in labelUtil.byList]
        log.info("vocab_list len is %d" % len(vocab_list))
        _ext_scorer = Scorer(alpha, beta, args.config.get('common', 'kenlm'), vocab_list)
        lm_char_based = _ext_scorer.is_character_based()
        lm_max_order = _ext_scorer.get_max_order()
        lm_dict_size = _ext_scorer.get_dict_size()
        log.info("language model: "
                 "is_character_based = %d," % lm_char_based +
                 " max_order = %d," % lm_max_order +
                 " dict_size = %d" % lm_dict_size)
        return _ext_scorer
    except ImportError:
        # Fall back to plain KenLM scoring when the swig wrapper is unavailable.
        import kenlm
        km = kenlm.Model(args.config.get('common', 'kenlm'))
        return km.score

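Note that the two possible return values differ in kind: the swig Scorer is an object handed to the beam-search decoder, while the KenLM fallback is a bare scoring function. A hedged usage sketch; the alpha/beta values are illustrative, and the callable() test is a heuristic for telling the two apart.

scorer = get_scorer(alpha=0.26, beta=0.1)
if callable(scorer):
    # KenLM fallback: km.score returns a log10 language-model probability.
    print(scorer("hello world"))
else:
    # swig Scorer object: hand it to ctc_beam_decode(scorer, beam_size, vocab, probs)
    # as the beam-decoding update() variant below does.
    pass
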
def load_data(args):
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception('mode must be one of the following: train, predict, load')
    batch_size = args.config.getint('common', 'batch_size')
    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')
    overwrite_bi_graphemes_dictionary = args.config.getboolean('train', 'overwrite_bi_graphemes_dictionary')
    max_duration = args.config.getfloat('data', 'max_duration')
    language = args.config.get('data', 'language')
    log = LogUtil().getlogger()
    labelUtil = LabelUtil.getInstance()

    if mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json, max_duration=max_duration)
        datagen.load_validation_data(val_json, max_duration=max_duration)
        if is_bi_graphemes:
            if not os.path.isfile("resources/unicodemap_en_baidu_bi_graphemes.csv") \
                    or overwrite_bi_graphemes_dictionary:
                load_labelutil(labelUtil=labelUtil, is_bi_graphemes=False, language=language)
                generate_bi_graphemes_dictionary(datagen.train_texts + datagen.val_texts)
        load_labelutil(labelUtil=labelUtil, is_bi_graphemes=is_bi_graphemes, language=language)
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
        if mode == "train":
            if overwrite_meta_files:
                log.info("Generate mean and std from samples")
                normalize_target_k = args.config.getint('train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                log.info("Read mean and std from meta files")
                datagen.get_meta_from_file(
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
        elif mode == "load":
            # get feat_mean and feat_std to normalize dataset
            datagen.get_meta_from_file(
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
    elif mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json, max_duration=max_duration)
        load_labelutil(labelUtil, is_bi_graphemes, language="en")
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm and (mode == 'train' or mode == 'load'):
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by duration in ascending order to implement SortaGrad
    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = datagen.get_max_label_length(partition="train",
                                                        is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = datagen.get_max_label_length(partition="test",
                                                        is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    sort_by_duration = (mode == "train")
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')
    save_feature_as_csvfile = args.config.getboolean('train', 'save_feature_as_csvfile')
    if is_bucketing:
        buckets = json.loads(args.config.get('arch', 'buckets'))
        data_loaded = BucketSTTIter(partition="train",
                                    count=datagen.count,
                                    datagen=datagen,
                                    batch_size=batch_size,
                                    num_label=max_label_length,
                                    init_states=init_states,
                                    seq_length=max_t_count,
                                    width=whcs.width,
                                    height=whcs.height,
                                    sort_by_duration=sort_by_duration,
                                    is_bi_graphemes=is_bi_graphemes,
                                    buckets=buckets,
                                    save_feature_as_csvfile=save_feature_as_csvfile)
    else:
        data_loaded = STTIter(partition="train",
                              count=datagen.count,
                              datagen=datagen,
                              batch_size=batch_size,
                              num_label=max_label_length,
                              init_states=init_states,
                              seq_length=max_t_count,
                              width=whcs.width,
                              height=whcs.height,
                              sort_by_duration=sort_by_duration,
                              is_bi_graphemes=is_bi_graphemes,
                              save_feature_as_csvfile=save_feature_as_csvfile)

    if mode == 'train' or mode == 'load':
        if is_bucketing:
            validation_loaded = BucketSTTIter(partition="validation",
                                              count=datagen.val_count,
                                              datagen=datagen,
                                              batch_size=batch_size,
                                              num_label=max_label_length,
                                              init_states=init_states,
                                              seq_length=max_t_count,
                                              width=whcs.width,
                                              height=whcs.height,
                                              sort_by_duration=False,
                                              is_bi_graphemes=is_bi_graphemes,
                                              buckets=buckets,
                                              save_feature_as_csvfile=save_feature_as_csvfile)
        else:
            validation_loaded = STTIter(partition="validation",
                                        count=datagen.val_count,
                                        datagen=datagen,
                                        batch_size=batch_size,
                                        num_label=max_label_length,
                                        init_states=init_states,
                                        seq_length=max_t_count,
                                        width=whcs.width,
                                        height=whcs.height,
                                        sort_by_duration=False,
                                        is_bi_graphemes=is_bi_graphemes,
                                        save_feature_as_csvfile=save_feature_as_csvfile)
        return data_loaded, validation_loaded, args
    elif mode == 'predict':
        return data_loaded, args

def load_data(args):
    mode = args.config.get('common', 'mode')
    batch_size = args.config.getint('common', 'batch_size')
    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')

    if mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json)
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
    elif mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json)
        # build the bi-graphemes dictionary if requested
        language = args.config.get('data', 'language')
        if overwrite_meta_files and is_bi_graphemes:
            generate_bi_graphemes_dictionary(datagen.train_texts)
        labelUtil = LabelUtil.getInstance()
        if language == "en":
            if is_bi_graphemes:
                try:
                    labelUtil.load_unicode_set("resources/unicodemap_en_baidu_bi_graphemes.csv")
                except Exception:
                    raise Exception("There is no resources/unicodemap_en_baidu_bi_graphemes.csv. "
                                    "Please set overwrite_meta_files in the train section to True.")
            else:
                labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")
        else:
            raise Exception("Error: Language Type: %s" % language)
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
        if mode == "train":
            if overwrite_meta_files:
                normalize_target_k = args.config.getint('train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                datagen.get_meta_from_file(
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
            datagen.load_validation_data(val_json)
        elif mode == "load":
            # get feat_mean and feat_std to normalize dataset
            datagen.get_meta_from_file(
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
            datagen.load_validation_data(val_json)
    else:
        raise Exception('Define mode in the cfg file first. '
                        'train, predict, or load can be the candidate for the mode.')

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm:
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by duration in ascending order to implement SortaGrad
    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = datagen.get_max_label_length(partition="train",
                                                        is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = datagen.get_max_label_length(partition="test",
                                                        is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    sort_by_duration = (mode == "train")
    data_loaded = STTIter(partition="train",
                          count=datagen.count,
                          datagen=datagen,
                          batch_size=batch_size,
                          num_label=max_label_length,
                          init_states=init_states,
                          seq_length=max_t_count,
                          width=whcs.width,
                          height=whcs.height,
                          sort_by_duration=sort_by_duration,
                          is_bi_graphemes=is_bi_graphemes)
    if mode == 'predict':
        return data_loaded, args
    validation_loaded = STTIter(partition="validation",
                                count=datagen.val_count,
                                datagen=datagen,
                                batch_size=batch_size,
                                num_label=max_label_length,
                                init_states=init_states,
                                seq_length=max_t_count,
                                width=whcs.width,
                                height=whcs.height,
                                sort_by_duration=False,
                                is_bi_graphemes=is_bi_graphemes)
    return data_loaded, validation_loaded, args

class Net(object):
    def __init__(self, args):
        self.args = args
        # set parameters from data section (common)
        self.mode = self.args.config.get('common', 'mode')
        # get meta file where character-to-number conversions are defined
        self.contexts = parse_contexts(self.args)
        self.num_gpu = len(self.contexts)
        self.batch_size = self.args.config.getint('common', 'batch_size')
        # the number of gpus must be a positive divisor of the batch size for data parallelism
        self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
        self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')
        # log current config
        self.config_logger = ConfigLogger(log)
        self.config_logger(args.config)
        save_dir = 'checkpoints'
        model_name = self.args.config.get('common', 'prefix')
        max_freq = self.args.config.getint('data', 'max_freq')
        self.datagen = DataGenerator(save_dir=save_dir, model_name=model_name, max_freq=max_freq)
        self.datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
        self.buckets = json.loads(self.args.config.get('arch', 'buckets'))
        default_bucket_key = self.buckets[-1]
        self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
        self.args.config.set('arch', 'max_label_length', str(100))
        self.labelUtil = LabelUtil()
        is_bi_graphemes = self.args.config.getboolean('common', 'is_bi_graphemes')
        load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
        self.args.config.set('arch', 'n_classes', str(self.labelUtil.get_count()))
        self.max_t_count = self.args.config.getint('arch', 'max_t_count')
        # load model
        self.model_loaded, self.model_num_epoch, self.model_path = load_model(self.args)
        symbol, self.arg_params, self.aux_params = mx.model.load_checkpoint(
            self.model_path, self.model_num_epoch)
        s_mod = mx.mod.BucketingModule(sym_gen=self.model_loaded,
                                       context=self.contexts,
                                       default_bucket_key=default_bucket_key)
        from importlib import import_module
        prepare_data_template = import_module(self.args.config.get('arch', 'arch_file'))
        self.init_states = prepare_data_template.prepare_data(self.args)
        self.width = self.args.config.getint('data', 'width')
        self.height = self.args.config.getint('data', 'height')
        s_mod.bind(data_shapes=[('data', (self.batch_size, default_bucket_key,
                                          self.width * self.height))] + self.init_states,
                   for_training=False)
        s_mod.set_params(self.arg_params, self.aux_params,
                         allow_extra=True, allow_missing=True)
        # pre-bind every bucket so inference can switch buckets without rebinding
        for bucket in self.buckets:
            provide_data = [('data', (self.batch_size, bucket,
                                      self.width * self.height))] + self.init_states
            s_mod.switch_bucket(bucket_key=bucket, data_shapes=provide_data)
        self.model = s_mod
        try:
            from swig_wrapper import Scorer
            vocab_list = [chars.encode("utf-8") for chars in self.labelUtil.byList]
            log.info("vocab_list len is %d" % len(vocab_list))
            _ext_scorer = Scorer(0.26, 0.1,
                                 self.args.config.get('common', 'kenlm'), vocab_list)
            lm_char_based = _ext_scorer.is_character_based()
            lm_max_order = _ext_scorer.get_max_order()
            lm_dict_size = _ext_scorer.get_dict_size()
            log.info("language model: "
                     "is_character_based = %d," % lm_char_based +
                     " max_order = %d," % lm_max_order +
                     " dict_size = %d" % lm_dict_size)
            self.scorer = _ext_scorer
        except ImportError:
            import kenlm
            km = kenlm.Model(self.args.config.get('common', 'kenlm'))
            self.scorer = km.score

    def getTrans(self, wav_file):
        res = spectrogram_from_file(wav_file, noise_percent=0)
        buck = bisect.bisect_left(self.buckets, len(res))
        bucket_key = self.buckets[buck]
        res = self.datagen.normalize(res)
        d = np.zeros((self.batch_size, bucket_key, res.shape[1]))
        d[0, :res.shape[0], :] = res
        init_state_arrays = [mx.nd.zeros(x[1]) for x in self.init_states]
        model_loaded = self.model
        provide_data = [('data', (self.batch_size, bucket_key,
                                  self.width * self.height))] + self.init_states
        data_batch = mx.io.DataBatch([mx.nd.array(d)] + init_state_arrays,
                                     label=None,
                                     bucket_key=bucket_key,
                                     provide_data=provide_data,
                                     provide_label=None)
        st = time.time()
        model_loaded.forward(data_batch, is_train=False)
        probs = model_loaded.get_outputs()[0].asnumpy()
        log.info("forward cost %.3f" % (time.time() - st))
        st = time.time()
        res = ctc_greedy_decode(probs, self.labelUtil.byList)
        log.info("greedy decode cost %.3f, result is:\n%s" % (time.time() - st, res))
        beam_size = 5
        from stt_metric import ctc_beam_decode
        st = time.time()
        results = ctc_beam_decode(scorer=self.scorer,
                                  beam_size=beam_size,
                                  vocab=self.labelUtil.byList,
                                  probs=probs)
        log.info("beam decode cost %.3f, result is:\n%s" %
                 (time.time() - st, "\n".join(results)))
        return "greedy:\n" + res + "\nbeam:\n" + "\n".join(results)

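getTrans() relies on ctc_greedy_decode from the repository's metric module. Below is a minimal sketch of the idea, assuming probs has shape (timesteps, n_classes), index 0 is the CTC blank, and vocab lists the non-blank symbols in index order; the real helper may differ (batched input, a different blank position), so it is named as a sketch rather than a drop-in replacement.

import numpy as np

def ctc_greedy_decode_sketch(probs, vocab, blank=0):
    # Best path: per-frame argmax over class probabilities.
    best_path = np.argmax(probs, axis=1)
    # Collapse repeats, then drop blanks, as in standard CTC greedy decoding.
    out = []
    prev = blank
    for idx in best_path:
        if idx != prev and idx != blank:
            out.append(vocab[idx - 1])  # assumed: vocab offset by the blank at index 0
        prev = idx
    return "".join(out)
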
def load_data(args):
    mode = args.config.get('common', 'mode')
    batch_size = args.config.getint('common', 'batch_size')
    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')
    language = args.config.get('data', 'language')

    labelUtil = LabelUtil.getInstance()
    if language == "en":
        if is_bi_graphemes:
            try:
                labelUtil.load_unicode_set("resources/unicodemap_en_baidu_bi_graphemes.csv")
            except Exception:
                raise Exception("There is no resources/unicodemap_en_baidu_bi_graphemes.csv. "
                                "Please set overwrite_meta_files in the train section to True.")
        else:
            labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")
    else:
        raise Exception("Error: Language Type: %s" % language)
    args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

    if mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json)
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
    elif mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json)
        # build the bi-graphemes dictionary if requested
        if overwrite_meta_files and is_bi_graphemes:
            generate_bi_graphemes_dictionary(datagen.train_texts)
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
        if mode == "train":
            if overwrite_meta_files:
                normalize_target_k = args.config.getint('train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                datagen.get_meta_from_file(
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
            datagen.load_validation_data(val_json)
        elif mode == "load":
            # get feat_mean and feat_std to normalize dataset
            datagen.get_meta_from_file(
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
            datagen.load_validation_data(val_json)
    else:
        raise Exception('Define mode in the cfg file first. '
                        'train, predict, or load can be the candidate for the mode.')

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm:
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by duration in ascending order to implement SortaGrad
    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = datagen.get_max_label_length(partition="train",
                                                        is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = datagen.get_max_label_length(partition="test",
                                                        is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    sort_by_duration = (mode == "train")
    data_loaded = STTIter(partition="train",
                          count=datagen.count,
                          datagen=datagen,
                          batch_size=batch_size,
                          num_label=max_label_length,
                          init_states=init_states,
                          seq_length=max_t_count,
                          width=whcs.width,
                          height=whcs.height,
                          sort_by_duration=sort_by_duration,
                          is_bi_graphemes=is_bi_graphemes)
    if mode == 'predict':
        return data_loaded, args
    validation_loaded = STTIter(partition="validation",
                                count=datagen.val_count,
                                datagen=datagen,
                                batch_size=batch_size,
                                num_label=max_label_length,
                                init_states=init_states,
                                seq_length=max_t_count,
                                width=whcs.width,
                                height=whcs.height,
                                sort_by_duration=False,
                                is_bi_graphemes=is_bi_graphemes)
    return data_loaded, validation_loaded, args

def prepare_minibatch_fbank(self, audio_paths, texts, overwrite=False,
                            is_bi_graphemes=False, seq_length=-1,
                            save_feature_as_csvfile=False, language="en",
                            zh_type="zi", noise_percent=0.4):
    """ Featurize a minibatch of audio, zero pad them and return a dictionary
    Params:
        audio_paths (list(str)): List of paths to audio files
        texts (list(str)): List of texts corresponding to the audio files
    Returns:
        dict: See below for contents
    """
    assert len(audio_paths) == len(texts), \
        "Inputs and outputs to the network must be of the same number"
    # Features is a list of (channel(3), timesteps, feature_dim(41)) arrays:
    # the filterbank features for each audio clip
    features = [self.featurize_fbank(a, overwrite=overwrite,
                                     save_feature_as_csvfile=save_feature_as_csvfile,
                                     noise_percent=noise_percent,
                                     seq_length=seq_length)
                for a in audio_paths]
    input_lengths = [f.shape[1] for f in features]
    channel, timesteps, feature_dim = features[0].shape
    mb_size = len(features)
    # Pad all the inputs so that they are all the same length
    if seq_length == -1:
        x = np.zeros((mb_size, channel, self.max_seq_length, feature_dim))
    else:
        x = np.zeros((mb_size, channel, seq_length, feature_dim))
    y = np.zeros((mb_size, self.max_label_length))
    labelUtil = LabelUtil()
    label_lengths = []
    for i in range(mb_size):
        feat = features[i]
        feat = self.normalize_fbank(feat)  # Center using means and std
        x[i, :, :feat.shape[1], :] = feat  # zero-padding; pad with noise instead?
        if language == "en" and is_bi_graphemes:
            label = generate_bi_graphemes_label(texts[i])
            label = labelUtil.convert_bi_graphemes_to_num(label)
            y[i, :len(label)] = label
        elif language == "en" and not is_bi_graphemes:
            label = labelUtil.convert_word_to_num(texts[i])
            y[i, :len(texts[i])] = label
        elif language == "zh" and zh_type == "phone":
            label = generate_phone_label(texts[i])
            label = labelUtil.convert_bi_graphemes_to_num(label)
            y[i, :len(label)] = label
        elif language == "zh" and zh_type == "py":
            label = generate_py_label(texts[i])
            label = labelUtil.convert_bi_graphemes_to_num(label)
            y[i, :len(label)] = label
        elif language == "zh" and zh_type == "zi":
            label = generate_zi_label(texts[i])
            label = labelUtil.convert_bi_graphemes_to_num(label)
            y[i, :len(label)] = label
        else:
            # guard: an unsupported combination previously fell through and
            # raised a NameError on `label` below
            raise ValueError("unsupported language/zh_type: %s/%s" % (language, zh_type))
        label_lengths.append(len(label))
    return {
        'x': x,                          # 0-padded features of shape (mb_size, channel, timesteps, feat_dim)
        'y': y,                          # 0-padded labels (integer sequences)
        'texts': texts,                  # list(str) Original texts
        'input_lengths': input_lengths,  # list(int) Length of each input
        'label_lengths': label_lengths,  # list(int) Length of each label
    }

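One way a caller might sanity-check the returned minibatch before computing CTC loss; `datagen`, `audio_paths`, and `texts` are assumed to exist as in the surrounding code, and the downsampling factor is an illustrative guess at the conv front-end's total time stride, not a value taken from this repository.

batch = datagen.prepare_minibatch_fbank(audio_paths, texts, language="zh", zh_type="zi")
downsample = 4  # assumed total time stride of the convolutional front-end
for in_len, lab_len in zip(batch['input_lengths'], batch['label_lengths']):
    # CTC needs at least one output frame per label (plus room for blanks).
    assert in_len // downsample >= lab_len, "utterance too short for its transcript"
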
def __init__(self):
    if len(sys.argv) <= 1:
        raise Exception('cfg file path must be provided. '
                        'ex) python main.py --configfile examplecfg.cfg')
    self.args = parse_args(sys.argv[1])  # set parameters from cfg file
    # random seed for shuffling the data list
    self.random_seed = self.args.config.getint('common', 'random_seed')
    # mx.random.seed gives the seed for parameter initialization
    self.mx_random_seed = self.args.config.getint('common', 'mx_random_seed')
    if self.random_seed != -1:
        np.random.seed(self.random_seed)
    if self.mx_random_seed != -1:
        mx.random.seed(self.mx_random_seed)
    else:
        mx.random.seed(hash(datetime.now()))
    # set log file name
    self.log_filename = self.args.config.get('common', 'log_filename')
    self.log = LogUtil(filename=self.log_filename).getlogger()
    # set parameters from data section (common)
    self.mode = self.args.config.get('common', 'mode')
    # get meta file where character-to-number conversions are defined
    self.contexts = parse_contexts(self.args)
    self.num_gpu = len(self.contexts)
    self.batch_size = self.args.config.getint('common', 'batch_size')
    # the number of gpus must be a positive divisor of the batch size for data parallelism
    self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
    self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')
    # log current config
    self.config_logger = ConfigLogger(self.log)
    self.config_logger(self.args.config)
    default_bucket_key = 1600
    self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
    self.args.config.set('arch', 'max_label_length', str(100))
    self.labelUtil = LabelUtil()
    is_bi_graphemes = self.args.config.getboolean('common', 'is_bi_graphemes')
    load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
    self.args.config.set('arch', 'n_classes', str(self.labelUtil.get_count()))
    self.max_t_count = self.args.config.getint('arch', 'max_t_count')
    # load model
    self.model_loaded, self.model_num_epoch, self.model_path = load_model(self.args)
    self.model = STTBucketingModule(sym_gen=self.model_loaded,
                                    default_bucket_key=default_bucket_key,
                                    context=self.contexts)
    from importlib import import_module
    prepare_data_template = import_module(self.args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(self.args)
    width = self.args.config.getint('data', 'width')
    height = self.args.config.getint('data', 'height')
    self.model.bind(data_shapes=[('data', (self.batch_size, default_bucket_key,
                                           width * height))] + init_states,
                    label_shapes=[('label',
                                   (self.batch_size,
                                    self.args.config.getint('arch', 'max_label_length')))],
                    for_training=True)
    _, self.arg_params, self.aux_params = mx.model.load_checkpoint(
        self.model_path, self.model_num_epoch)
    self.model.set_params(self.arg_params, self.aux_params,
                          allow_extra=True, allow_missing=True)
    try:
        from swig_wrapper import Scorer
        vocab_list = [chars.encode("utf-8") for chars in self.labelUtil.byList]
        self.log.info("vocab_list len is %d" % len(vocab_list))
        _ext_scorer = Scorer(0.26, 0.1,
                             self.args.config.get('common', 'kenlm'), vocab_list)
        lm_char_based = _ext_scorer.is_character_based()
        lm_max_order = _ext_scorer.get_max_order()
        lm_dict_size = _ext_scorer.get_dict_size()
        self.log.info("language model: "
                      "is_character_based = %d," % lm_char_based +
                      " max_order = %d," % lm_max_order +
                      " dict_size = %d" % lm_dict_size)
        self.scorer = _ext_scorer
    except ImportError:
        import kenlm
        km = kenlm.Model(self.args.config.get('common', 'kenlm'))
        self.scorer = km.score

def load_data(args, wav_file):
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception('mode must be one of the following: train, predict, load')
    batch_size = args.config.getint('common', 'batch_size')
    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')
    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')
    overwrite_bi_graphemes_dictionary = args.config.getboolean('train', 'overwrite_bi_graphemes_dictionary')
    max_duration = args.config.getfloat('data', 'max_duration')
    max_freq = args.config.getint('data', 'max_freq')
    language = args.config.get('data', 'language')
    log = LogUtil().getlogger()
    labelUtil = LabelUtil()
    # build a one-utterance "dataset" around the given wav file
    datagen = DataGenerator(save_dir=save_dir, model_name=model_name, max_freq=max_freq)
    datagen.train_audio_paths = [wav_file]
    datagen.train_durations = [get_duration_wave(wav_file)]
    datagen.train_texts = ["1 1"]  # dummy transcript; only the audio is used
    datagen.count = 1
    load_labelutil(labelUtil, is_bi_graphemes, language="zh")
    args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
    datagen.get_meta_from_file(
        np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
        np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm and (mode == 'train' or mode == 'load'):
        raise Warning('batch size 1 is too small for is_batchnorm')
    max_t_count = datagen.get_max_seq_length(partition="test")
    max_label_length = datagen.get_max_label_length(partition="test",
                                                    is_bi_graphemes=is_bi_graphemes)
    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)
    sort_by_duration = (mode == "train")
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')
    save_feature_as_csvfile = args.config.getboolean('train', 'save_feature_as_csvfile')
    if is_bucketing:
        buckets = json.loads(args.config.get('arch', 'buckets'))
        data_loaded = BucketSTTIter(partition="train",
                                    count=datagen.count,
                                    datagen=datagen,
                                    batch_size=batch_size,
                                    num_label=max_label_length,
                                    init_states=init_states,
                                    seq_length=max_t_count,
                                    width=whcs.width,
                                    height=whcs.height,
                                    sort_by_duration=sort_by_duration,
                                    is_bi_graphemes=is_bi_graphemes,
                                    buckets=buckets,
                                    save_feature_as_csvfile=save_feature_as_csvfile)
    else:
        data_loaded = STTIter(partition="train",
                              count=datagen.count,
                              datagen=datagen,
                              batch_size=batch_size,
                              num_label=max_label_length,
                              init_states=init_states,
                              seq_length=max_t_count,
                              width=whcs.width,
                              height=whcs.height,
                              sort_by_duration=sort_by_duration,
                              is_bi_graphemes=is_bi_graphemes,
                              save_feature_as_csvfile=save_feature_as_csvfile)
    return data_loaded, args

def update(self, labels, preds):
    check_label_shapes(labels, preds)
    if self.is_logging:
        log = LogUtil().getlogger()
    labelUtil = LabelUtil()
    self.batch_loss = 0.
    host_name = socket.gethostname()
    for label, pred in zip(labels, preds):
        label = label.asnumpy()
        pred = pred.asnumpy()
        batch_per_gpu = int(self.batch_size) // int(self.num_gpu)
        seq_length = len(pred) // batch_per_gpu
        for i in range(batch_per_gpu):
            l = remove_blank(label[i])
            probs = []
            for k in range(seq_length):
                probs.append(pred[k * batch_per_gpu + i])
            probs = np.array(probs)
            st = time.time()
            beam_size = 20
            results = ctc_beam_decode(self.scorer, beam_size, labelUtil.byList, probs)
            log.info("decode by ctc_beam cost %.2f result: %s" %
                     (time.time() - st, "\n".join(results)))
            res_str1 = ctc_greedy_decode(probs, labelUtil.byList)
            log.info("decode by pred_best: %s" % res_str1)
            # (a TensorFlow ctc_beam_search_decoder comparison used to live here;
            # it is disabled in this variant)
            l_distance = editdistance.eval(
                labelUtil.convert_num_to_word(l).split(" "), res_str1)
            l_distance_beam_cpp = editdistance.eval(
                labelUtil.convert_num_to_word(l).split(" "), results[0])
            self.total_n_label += len(l)
            self.total_l_dist_beam_cpp += l_distance_beam_cpp
            self.total_l_dist += l_distance
            this_cer = float(l_distance) / float(len(l))
            if self.is_logging:
                log.info("%s label: %s " % (host_name, labelUtil.convert_num_to_word(l)))
                log.info("%s pred : %s , cer: %f (distance: %d/ label length: %d)" %
                         (host_name, res_str1, this_cer, l_distance, len(l)))
                log.info("%s predc: %s , cer: %f (distance: %d/ label length: %d)" %
                         (host_name, " ".join(results[0]),
                          float(l_distance_beam_cpp) / len(l),
                          l_distance_beam_cpp, len(l)))
            self.total_ctc_loss += self.batch_loss
            self.placeholder = res_str1 + "\n" + "\n".join(results)

mx.random.seed(hash(datetime.now()))
# set parameters from cfg file
args = parse_args(sys.argv[1])
log_filename = args.config.get('common', 'log_filename')
log = LogUtil(filename=log_filename).getlogger()

# set parameters from data section (common)
mode = args.config.get('common', 'mode')
if mode not in ['train', 'predict', 'load']:
    raise Exception('Define mode in the cfg file first. '
                    'train, predict, or load can be the candidate for the mode.')

# get meta file where character-to-number conversions are defined
language = args.config.get('data', 'language')
labelUtil = LabelUtil.getInstance()
if language == "en":
    labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")
else:
    raise Exception("Error: Language Type: %s" % language)
args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

contexts = parse_contexts(args)
num_gpu = len(contexts)
batch_size = args.config.getint('common', 'batch_size')
# check that the number of gpus is a positive divisor of the batch size
if batch_size % num_gpu != 0:
    raise Exception('num_gpu should be positive divisor of batch_size')

if mode == "predict":

def __init__(self):
    if len(sys.argv) <= 1:
        raise Exception('cfg file path must be provided. '
                        'ex) python main.py --configfile examplecfg.cfg')
    self.args = parse_args(sys.argv[1])  # set parameters from cfg file
    # random seed for shuffling the data list
    self.random_seed = self.args.config.getint('common', 'random_seed')
    # mx.random.seed gives the seed for parameter initialization
    self.mx_random_seed = self.args.config.getint('common', 'mx_random_seed')
    if self.random_seed != -1:
        np.random.seed(self.random_seed)
    if self.mx_random_seed != -1:
        mx.random.seed(self.mx_random_seed)
    else:
        mx.random.seed(hash(datetime.now()))
    # set log file name
    self.log_filename = self.args.config.get('common', 'log_filename')
    self.log = LogUtil(filename=self.log_filename).getlogger()
    # set parameters from data section (common)
    self.mode = self.args.config.get('common', 'mode')
    save_dir = 'checkpoints'
    model_name = self.args.config.get('common', 'prefix')
    max_freq = self.args.config.getint('data', 'max_freq')
    self.datagen = DataGenerator(save_dir=save_dir, model_name=model_name, max_freq=max_freq)
    self.datagen.get_meta_from_file(
        np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
        np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
    self.buckets = json.loads(self.args.config.get('arch', 'buckets'))
    # get meta file where character-to-number conversions are defined
    self.contexts = parse_contexts(self.args)
    self.num_gpu = len(self.contexts)
    self.batch_size = self.args.config.getint('common', 'batch_size')
    # the number of gpus must be a positive divisor of the batch size for data parallelism
    self.is_batchnorm = self.args.config.getboolean('arch', 'is_batchnorm')
    self.is_bucketing = self.args.config.getboolean('arch', 'is_bucketing')
    # log current config
    self.config_logger = ConfigLogger(self.log)
    self.config_logger(self.args.config)
    default_bucket_key = 1600
    self.args.config.set('arch', 'max_t_count', str(default_bucket_key))
    self.args.config.set('arch', 'max_label_length', str(95))
    self.labelUtil = LabelUtil()
    is_bi_graphemes = self.args.config.getboolean('common', 'is_bi_graphemes')
    load_labelutil(self.labelUtil, is_bi_graphemes, language="zh")
    self.args.config.set('arch', 'n_classes', str(self.labelUtil.get_count()))
    self.max_t_count = self.args.config.getint('arch', 'max_t_count')
    # load model
    self.model_loaded, self.model_num_epoch, self.model_path = load_model(self.args)
    from importlib import import_module
    prepare_data_template = import_module(self.args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(self.args)
    width = self.args.config.getint('data', 'width')
    height = self.args.config.getint('data', 'height')
    # save each bucket's symbol and collect the init state names
    for bucket in self.buckets:
        net, init_state_names, ll = self.model_loaded(bucket)
        net.save('checkpoints/%s-symbol.json' % bucket)
    input_shapes = dict([('data', (self.batch_size, default_bucket_key, width * height))] +
                        init_states + [('label', (1, 18))])
    symbol, self.arg_params, self.aux_params = mx.model.load_checkpoint(
        self.model_path, self.model_num_epoch)
    # bind a plain executor on the softmax output instead of a BucketingModule
    all_layers = symbol.get_internals()
    concat = all_layers['concat36457_output']
    sm = mx.sym.SoftmaxOutput(data=concat, name='softmax')
    self.executor = sm.simple_bind(ctx=mx.cpu(), **input_shapes)
    for key in self.executor.arg_dict.keys():
        if key in self.arg_params:
            self.arg_params[key].copyto(self.executor.arg_dict[key])
    init_state_names.remove('data')
    init_state_names.sort()
    self.states_dict = dict(zip(init_state_names, self.executor.outputs[1:]))
    self.input_arr = mx.nd.zeros((self.batch_size, default_bucket_key, width * height))
    try:
        from swig_wrapper import Scorer
        vocab_list = [chars.encode("utf-8") for chars in self.labelUtil.byList]
        self.log.info("vocab_list len is %d" % len(vocab_list))
        _ext_scorer = Scorer(0.26, 0.1,
                             self.args.config.get('common', 'kenlm'), vocab_list)
        lm_char_based = _ext_scorer.is_character_based()
        lm_max_order = _ext_scorer.get_max_order()
        lm_dict_size = _ext_scorer.get_dict_size()
        self.log.info("language model: "
                      "is_character_based = %d," % lm_char_based +
                      " max_order = %d," % lm_max_order +
                      " dict_size = %d" % lm_dict_size)
        self.eval_metric = EvalSTTMetric(batch_size=self.batch_size,
                                         num_gpu=self.num_gpu,
                                         is_logging=True,
                                         scorer=_ext_scorer)
    except ImportError:
        import kenlm
        km = kenlm.Model(self.args.config.get('common', 'kenlm'))
        self.eval_metric = EvalSTTMetric(batch_size=self.batch_size,
                                         num_gpu=self.num_gpu,
                                         is_logging=True,
                                         scorer=km.score)