def load_data(args):
    mode = args.config.get('common', 'mode')
    batch_size = args.config.getint('common', 'batch_size')

    # feature geometry: width/height/channel/stride of the input spectrogram
    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')

    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')

    if mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json)
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
    else:
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json)
        datagen.load_validation_data(val_json)
        if mode == "train":
            normalize_target_k = args.config.getint('train', 'normalize_target_k')
            datagen.sample_normalize(normalize_target_k, True)
        elif mode == "load":
            # load feats_mean and feats_std to normalize the dataset
            datagen.get_meta_from_file(
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm:
        # batch statistics are meaningless over a single sample
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by duration in ascending order to implement SortaGrad
    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = datagen.get_max_label_length(partition="train")
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = datagen.get_max_label_length(partition="test")
    else:
        raise Exception('Set mode in the cfg file first; it must be one of train, predict, or load.')

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))

    # the architecture module named in the cfg provides prepare_data()
    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)

    if mode == "train":
        sort_by_duration = True
        shuffle = False
    else:
        sort_by_duration = False
        shuffle = True

    data_loaded = STTIter(partition="train",
                          count=datagen.count,
                          datagen=datagen,
                          batch_size=batch_size,
                          num_label=max_label_length,
                          init_states=init_states,
                          seq_length=max_t_count,
                          width=whcs.width,
                          height=whcs.height,
                          sort_by_duration=sort_by_duration,
                          shuffle=shuffle)

    if mode == 'predict':
        return data_loaded, args
    else:
        validation_loaded = STTIter(partition="validation",
                                    count=datagen.val_count,
                                    datagen=datagen,
                                    batch_size=batch_size,
                                    num_label=max_label_length,
                                    init_states=init_states,
                                    seq_length=max_t_count,
                                    width=whcs.width,
                                    height=whcs.height,
                                    sort_by_duration=True,
                                    shuffle=False)
        return data_loaded, validation_loaded, args

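# A minimal driver sketch for the function above. It assumes `args` carries a
# ConfigParser-style object on `args.config` with the [common]/[data]/[train]/
# [arch] sections read by load_data; the cfg file name 'deepspeech.cfg' and the
# use of Python 3's configparser are illustrative assumptions, not part of the
# original source.
import argparse
import configparser

def _example_load_data(cfg_path='deepspeech.cfg'):
    config = configparser.ConfigParser()
    config.read(cfg_path)  # must define the common/data/train/arch sections
    args = argparse.Namespace(config=config)
    if config.get('common', 'mode') == 'predict':
        # predict mode returns only the test iterator
        data_test, args = load_data(args)
        return data_test, args
    # train/load modes also return a validation iterator
    data_train, data_val, args = load_data(args)
    return data_train, data_val, args
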
def load_data(args):
    mode = args.config.get('common', 'mode')
    batch_size = args.config.getint('common', 'batch_size')

    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')

    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')

    if mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json)
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
    elif mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json)

        # bi-graphemes: regenerate the dictionary from the training transcripts
        language = args.config.get('data', 'language')
        if overwrite_meta_files and is_bi_graphemes:
            generate_bi_graphemes_dictionary(datagen.train_texts)

        labelUtil = LabelUtil.getInstance()
        if language == "en":
            if is_bi_graphemes:
                try:
                    labelUtil.load_unicode_set(
                        "resources/unicodemap_en_baidu_bi_graphemes.csv")
                except IOError:
                    raise Exception(
                        "There is no resources/unicodemap_en_baidu_bi_graphemes.csv. "
                        "Please set overwrite_meta_files to True in the train section.")
            else:
                labelUtil.load_unicode_set("resources/unicodemap_en_baidu.csv")
        else:
            raise Exception("Error: Language Type: %s" % language)
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

        if mode == "train":
            if overwrite_meta_files:
                normalize_target_k = args.config.getint('train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                datagen.get_meta_from_file(
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
            datagen.load_validation_data(val_json)
        elif mode == "load":
            # load feats_mean and feats_std to normalize the dataset
            datagen.get_meta_from_file(
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
            datagen.load_validation_data(val_json)
    else:
        raise Exception('Set mode in the cfg file first; it must be one of train, predict, or load.')

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm:
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by duration in ascending order to implement SortaGrad
    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = datagen.get_max_label_length(partition="train",
                                                        is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = datagen.get_max_label_length(partition="test",
                                                        is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))

    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)

    # SortaGrad: sort only the training pass by duration
    sort_by_duration = (mode == "train")

    data_loaded = STTIter(partition="train",
                          count=datagen.count,
                          datagen=datagen,
                          batch_size=batch_size,
                          num_label=max_label_length,
                          init_states=init_states,
                          seq_length=max_t_count,
                          width=whcs.width,
                          height=whcs.height,
                          sort_by_duration=sort_by_duration,
                          is_bi_graphemes=is_bi_graphemes)

    if mode == 'predict':
        return data_loaded, args
    else:
        validation_loaded = STTIter(partition="validation",
                                    count=datagen.val_count,
                                    datagen=datagen,
                                    batch_size=batch_size,
                                    num_label=max_label_length,
                                    init_states=init_states,
                                    seq_length=max_t_count,
                                    width=whcs.width,
                                    height=whcs.height,
                                    sort_by_duration=False,
                                    is_bi_graphemes=is_bi_graphemes)
        return data_loaded, validation_loaded, args

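# Conceptual sketch of the bi-grapheme labeling used above: transcripts are
# mapped to pairs of characters instead of single characters, which roughly
# halves the CTC label sequence length. The helper below is a hypothetical
# illustration only; it is not the project's generate_bi_graphemes_dictionary,
# and how the real code handles spaces and odd-length words is an assumption
# not shown here.
def _bi_graphemes_example(word):
    # 'speech' -> ['sp', 'ee', 'ch']; a trailing single character is kept as-is
    return [word[i:i + 2] for i in range(0, len(word), 2)]
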
def load_data(args):
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception('mode must be one of the following: train, predict, load')

    batch_size = args.config.getint('common', 'batch_size')

    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')

    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')
    overwrite_bi_graphemes_dictionary = args.config.getboolean(
        'train', 'overwrite_bi_graphemes_dictionary')
    max_duration = args.config.getfloat('data', 'max_duration')
    language = args.config.get('data', 'language')

    log = LogUtil().getlogger()
    labelUtil = LabelUtil.getInstance()

    if mode == "train" or mode == "load":
        data_json = args.config.get('data', 'train_json')
        val_json = args.config.get('data', 'val_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(data_json, max_duration=max_duration)
        datagen.load_validation_data(val_json, max_duration=max_duration)
        if is_bi_graphemes:
            # (re)build the bi-grapheme dictionary from train + validation transcripts
            if not os.path.isfile("resources/unicodemap_en_baidu_bi_graphemes.csv") \
                    or overwrite_bi_graphemes_dictionary:
                load_labelutil(labelUtil=labelUtil, is_bi_graphemes=False, language=language)
                generate_bi_graphemes_dictionary(datagen.train_texts + datagen.val_texts)
        load_labelutil(labelUtil=labelUtil, is_bi_graphemes=is_bi_graphemes, language=language)
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))

        if mode == "train":
            if overwrite_meta_files:
                log.info("Generate mean and std from samples")
                normalize_target_k = args.config.getint('train', 'normalize_target_k')
                datagen.sample_normalize(normalize_target_k, True)
            else:
                log.info("Read mean and std from meta files")
                datagen.get_meta_from_file(
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                    np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
        elif mode == "load":
            # load feats_mean and feats_std to normalize the dataset
            datagen.get_meta_from_file(
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
                np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))
    elif mode == 'predict':
        test_json = args.config.get('data', 'test_json')
        datagen = DataGenerator(save_dir=save_dir, model_name=model_name)
        datagen.load_train_data(test_json, max_duration=max_duration)
        load_labelutil(labelUtil, is_bi_graphemes, language="en")
        args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
        datagen.get_meta_from_file(
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
            np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm and (mode == 'train' or mode == 'load'):
        raise Warning('batch size 1 is too small for is_batchnorm')

    # sort file paths by duration in ascending order to implement SortaGrad
    if mode == "train" or mode == "load":
        max_t_count = datagen.get_max_seq_length(partition="train")
        max_label_length = datagen.get_max_label_length(partition="train",
                                                        is_bi_graphemes=is_bi_graphemes)
    elif mode == "predict":
        max_t_count = datagen.get_max_seq_length(partition="test")
        max_label_length = datagen.get_max_label_length(partition="test",
                                                        is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))

    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)

    sort_by_duration = (mode == "train")
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')
    save_feature_as_csvfile = args.config.getboolean('train', 'save_feature_as_csvfile')

    if is_bucketing:
        buckets = json.loads(args.config.get('arch', 'buckets'))
        data_loaded = BucketSTTIter(partition="train",
                                    count=datagen.count,
                                    datagen=datagen,
                                    batch_size=batch_size,
                                    num_label=max_label_length,
                                    init_states=init_states,
                                    seq_length=max_t_count,
                                    width=whcs.width,
                                    height=whcs.height,
                                    sort_by_duration=sort_by_duration,
                                    is_bi_graphemes=is_bi_graphemes,
                                    buckets=buckets,
                                    save_feature_as_csvfile=save_feature_as_csvfile)
    else:
        data_loaded = STTIter(partition="train",
                              count=datagen.count,
                              datagen=datagen,
                              batch_size=batch_size,
                              num_label=max_label_length,
                              init_states=init_states,
                              seq_length=max_t_count,
                              width=whcs.width,
                              height=whcs.height,
                              sort_by_duration=sort_by_duration,
                              is_bi_graphemes=is_bi_graphemes,
                              save_feature_as_csvfile=save_feature_as_csvfile)

    if mode == 'train' or mode == 'load':
        if is_bucketing:
            validation_loaded = BucketSTTIter(partition="validation",
                                              count=datagen.val_count,
                                              datagen=datagen,
                                              batch_size=batch_size,
                                              num_label=max_label_length,
                                              init_states=init_states,
                                              seq_length=max_t_count,
                                              width=whcs.width,
                                              height=whcs.height,
                                              sort_by_duration=False,
                                              is_bi_graphemes=is_bi_graphemes,
                                              buckets=buckets,
                                              save_feature_as_csvfile=save_feature_as_csvfile)
        else:
            validation_loaded = STTIter(partition="validation",
                                        count=datagen.val_count,
                                        datagen=datagen,
                                        batch_size=batch_size,
                                        num_label=max_label_length,
                                        init_states=init_states,
                                        seq_length=max_t_count,
                                        width=whcs.width,
                                        height=whcs.height,
                                        sort_by_duration=False,
                                        is_bi_graphemes=is_bi_graphemes,
                                        save_feature_as_csvfile=save_feature_as_csvfile)
        return data_loaded, validation_loaded, args
    elif mode == 'predict':
        return data_loaded, args

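# The bucketing branch above parses its bucket boundaries with
# json.loads(args.config.get('arch', 'buckets')), so the cfg option value must
# be a valid JSON list, e.g.
#
#   [arch]
#   buckets = [200, 300, 400, 500, 600]
#
# The concrete numbers are illustrative assumptions; each entry is the maximum
# sequence length admitted into that bucket.
import json

def _example_buckets(raw='[200, 300, 400, 500, 600]'):
    # json.loads turns the raw cfg string into a Python list of ints
    return json.loads(raw)
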
def load_data(args, wav_file):
    mode = args.config.get('common', 'mode')
    if mode not in ['train', 'predict', 'load']:
        raise Exception('mode must be one of the following: train, predict, load')

    batch_size = args.config.getint('common', 'batch_size')

    whcs = WHCS()
    whcs.width = args.config.getint('data', 'width')
    whcs.height = args.config.getint('data', 'height')
    whcs.channel = args.config.getint('data', 'channel')
    whcs.stride = args.config.getint('data', 'stride')

    save_dir = 'checkpoints'
    model_name = args.config.get('common', 'prefix')
    is_bi_graphemes = args.config.getboolean('common', 'is_bi_graphemes')
    overwrite_meta_files = args.config.getboolean('train', 'overwrite_meta_files')
    overwrite_bi_graphemes_dictionary = args.config.getboolean(
        'train', 'overwrite_bi_graphemes_dictionary')
    max_duration = args.config.getfloat('data', 'max_duration')
    max_freq = args.config.getint('data', 'max_freq')
    language = args.config.get('data', 'language')

    log = LogUtil().getlogger()
    labelUtil = LabelUtil()

    # feed a single wav file instead of a test json
    # test_json = "resources/d.json"
    datagen = DataGenerator(save_dir=save_dir, model_name=model_name, max_freq=max_freq)
    datagen.train_audio_paths = [wav_file]
    datagen.train_durations = [get_duration_wave(wav_file)]
    datagen.train_texts = ["1 1"]  # dummy transcript; only the audio is used for prediction
    datagen.count = 1
    # datagen.load_train_data(test_json, max_duration=max_duration)
    load_labelutil(labelUtil, is_bi_graphemes, language="zh")
    args.config.set('arch', 'n_classes', str(labelUtil.get_count()))
    datagen.get_meta_from_file(
        np.loadtxt(generate_file_path(save_dir, model_name, 'feats_mean')),
        np.loadtxt(generate_file_path(save_dir, model_name, 'feats_std')))

    is_batchnorm = args.config.getboolean('arch', 'is_batchnorm')
    if batch_size == 1 and is_batchnorm and (mode == 'train' or mode == 'load'):
        raise Warning('batch size 1 is too small for is_batchnorm')

    max_t_count = datagen.get_max_seq_length(partition="test")
    max_label_length = datagen.get_max_label_length(partition="test",
                                                    is_bi_graphemes=is_bi_graphemes)

    args.config.set('arch', 'max_t_count', str(max_t_count))
    args.config.set('arch', 'max_label_length', str(max_label_length))

    from importlib import import_module
    prepare_data_template = import_module(args.config.get('arch', 'arch_file'))
    init_states = prepare_data_template.prepare_data(args)

    sort_by_duration = (mode == "train")
    is_bucketing = args.config.getboolean('arch', 'is_bucketing')
    save_feature_as_csvfile = args.config.getboolean('train', 'save_feature_as_csvfile')

    if is_bucketing:
        buckets = json.loads(args.config.get('arch', 'buckets'))
        data_loaded = BucketSTTIter(partition="train",
                                    count=datagen.count,
                                    datagen=datagen,
                                    batch_size=batch_size,
                                    num_label=max_label_length,
                                    init_states=init_states,
                                    seq_length=max_t_count,
                                    width=whcs.width,
                                    height=whcs.height,
                                    sort_by_duration=sort_by_duration,
                                    is_bi_graphemes=is_bi_graphemes,
                                    buckets=buckets,
                                    save_feature_as_csvfile=save_feature_as_csvfile)
    else:
        data_loaded = STTIter(partition="train",
                              count=datagen.count,
                              datagen=datagen,
                              batch_size=batch_size,
                              num_label=max_label_length,
                              init_states=init_states,
                              seq_length=max_t_count,
                              width=whcs.width,
                              height=whcs.height,
                              sort_by_duration=sort_by_duration,
                              is_bi_graphemes=is_bi_graphemes,
                              save_feature_as_csvfile=save_feature_as_csvfile)

    return data_loaded, args

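# `get_duration_wave` is called above but not defined in this file; a minimal
# stand-in using the stdlib wave module would look like the sketch below
# (assumptions: the input is an uncompressed PCM wav file, and the name
# _get_duration_wave_sketch is hypothetical to avoid shadowing the project's
# own helper).
import wave

def _get_duration_wave_sketch(wav_file):
    # duration in seconds = number of frames / sampling rate
    with wave.open(wav_file, 'rb') as wav:
        return wav.getnframes() / float(wav.getframerate())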