def main():
    """
    Command-line entry point for the acoustic model.

    In prediction mode the parsed arguments are validated with ``check_input_args`` and a
    single audio file is decoded, either greedily (one result printed) or with beam search
    (the first ``args.top`` results printed). In every other mode an ``Acoustic`` object is
    built for that mode, its model summary is printed and ``run()`` is called without arguments.

    :return: False when the prediction arguments fail validation, otherwise None
    """
    args = input_args()
    is_prediction = args.mode == M_PREDICTION

    if not is_prediction:
        model = Acoustic(CONFIG_PATH, operation=args.mode)
        print(model.model.summary())
        model.run()
        return None

    if not check_input_args(args):
        return False

    audio_path = args.audio
    model = Acoustic(CONFIG_PATH, operation=M_PREDICTION)

    if args.predict_mode == M_PREDICT_GREEDY:
        decoded = model.run(audio=audio_path, greedy=True)
        print(Color.red("input: {}".format(audio_path)))
        print(Color.yellow("Greedy predict: {}".format(decoded)))
    elif args.predict_mode == M_PREDICT_BEAM:
        decoded = model.run(audio=audio_path,
                            beam_width=args.beam,
                            top_paths=args.top,
                            greedy=False)
        print(Color.yellow("input: {}".format(audio_path)))
        for rank in range(args.top):
            print(Color.yellow("Beam predict: {}.{}".format(rank, decoded[rank])))

    # Blank line closing the prediction output.
    print()
    return None
def writer_file(file, results, mode='a', encoding='utf-8'):
    """
    Write a string, or a list/tuple of strings, to a file, one entry per line.

    :param file: (str, mandatory) file name, or path + file name
    :param results: (str or list or tuple, mandatory) the text to write, e.g. "hello word"
                    or ['hello world']; every element of a list/tuple must be a str
    :param mode: (str, optional, default='a') open mode; the default 'a' appends
    :param encoding: (str, optional, default='utf-8') text encoding of the file
    :return: None
    :raises ParameterError: if ``results`` is not a str/list/tuple, or a list/tuple element
                            is not a str
    """
    accepted_types = (str, list, tuple)
    if not isinstance(results, accepted_types):
        raise ParameterError(
            "result parameter must be is {}, but actually get {}".format(accepted_types, type(results)))

    # Treat a lone string as a one-element sequence so both cases share one code path.
    entries = [results] if isinstance(results, str) else results
    for entry in entries:
        if not isinstance(entry, str):
            raise ParameterError(
                "results parameter elements must be is str, but actually get {}, elements:{}".format(
                    type(entry), entry))

    # Called before the file is opened (presumably validates/prepares the path - confirm).
    Writer.check_path(file)

    with open(file, mode, encoding=encoding) as out:
        for entry in entries:
            out.write(entry + '\n')

    summary = "over!File: {}, encoding: {}".format(Color.red(file), Color.red(encoding))
    print("\n" + summary)
def check_input_args(args):
    """
    Validate the command-line arguments used for prediction.

    A problem is reported by printing a message; nothing is raised.

    :param args: (obj, mandatory) parsed arguments; reads ``audio``, ``top`` and ``beam``
    :return: (bool) True when ``args.audio`` names an existing file and ``args.top`` does
             not exceed ``args.beam``, otherwise False
    """
    # Truthiness instead of `is ''`: identity comparison against a literal is
    # implementation-dependent, and this also rejects None before os.path.isfile sees it.
    if not args.audio:
        print(Color.red("Please input audio file !"))
        return False

    if not os.path.isfile(args.audio):
        print(FileNotFoundException(args.audio))
        return False

    # More result paths cannot be requested than the beam holds.
    if args.top > args.beam:
        print(Color.red("requested top:{} than the beam:{}".format(args.top, args.beam)))
        return False

    return True
def init_transverse(conf, input_dim, output_vocab_size):
    """
    Build a new AcousticTransverseNet from the configuration.

    :param conf: (obj, mandatory) configuration object; hyper-parameters are read from
                 ``conf.MODEL.TRANSVERSE``
    :param input_dim: (int, mandatory) input dimension
    :param output_vocab_size: (int, mandatory) size of the output vocabulary
    :return: (AcousticTransverseNet) the newly constructed network
    """
    # Every one of these is forwarded under the same keyword name it has in the config.
    option_names = (
        'dn_hidden_size',
        'dilated_conv_depth',
        'width_conv_depth',
        'multi_dilated_rate',
        'dilated_conv_filters',
        'width_conv_filters',
        'dropout_rate',
        'l1',
        'l2',
        'activation',
        'learning_rate',
        'warmup_steps',
        'optimizer_beta_1',
        'optimizer_beta_2',
        'optimizer_epsilon',
        'ckpt_dir',
        'ckpt_max_to_keep',
    )
    section = conf.MODEL.TRANSVERSE
    options = {name: getattr(section, name) for name in option_names}

    net = AcousticTransverseNet(input_dim=input_dim,
                                output_vocab_size=output_vocab_size,
                                **options)

    print(Color.green('Initialization transverse from scratch'))
    return net
def init_or_restore_audio_feature(dataset, file, f_type, frame_length, frame_shift, mfcc_dim):
    """
    Restore the audio feature extractor from ``file``, or create, fit and save a new one.

    :param dataset: (list, mandatory) data set the new extractor is fitted on
    :param file: (str, mandatory) file the extractor is restored from or saved to
    :param f_type: (str, mandatory) kind of audio feature, "spectrogram" or "mfcc"
    :param frame_length: (int, mandatory) frame length
    :param frame_shift: (int, mandatory) frame shift
    :param mfcc_dim: (int, mandatory) mfcc feature dimension
    :return: (AudioFeatures) the restored or newly fitted extractor
    """
    if not os.path.isfile(file):
        # Nothing saved yet: build from the given settings, fit, then persist for next time.
        extractor = AudioFeatures(f_type=f_type,
                                  frame_length=frame_length,
                                  frame_shift=frame_shift,
                                  mfcc_dim=mfcc_dim)
        extractor.fit(train_data=dataset)
        message = "Audio feature: Not found file:{} Initializing from scratch".format(file)
        extractor.save(file)
    else:
        extractor = AudioFeatures.load(file)
        message = "Audio feature: restore from : {}".format(file)

    print(Color.green(message))
    return extractor
def init_transformer(input_vocab_size, output_vocab_size, padding_index, conf):
    """
    Build a new LanguageTransformer from the configuration.

    :param input_vocab_size: (int, mandatory) size of the input vocabulary
    :param output_vocab_size: (int, mandatory) size of the output vocabulary
    :param padding_index: (int, mandatory) index of the padding token
    :param conf: (obj, mandatory) configuration object; hyper-parameters are read from
                 ``conf.MODEL.TRANSFORMER``
    :return: (LanguageTransformer) the newly constructed transformer
    """
    cfg = conf.MODEL.TRANSFORMER

    # Settings taken verbatim from the configuration section.
    from_config = dict(
        d_model=cfg.d_model,
        heads_num=cfg.heads_num,
        forward_hidden=cfg.forward_hidden,
        num_layers=cfg.num_layers,
        input_max_positional=cfg.input_max_positional,
        target_max_positional=cfg.target_max_positional,
        dropout_rate=cfg.dropout_rate,
        ckpt_dir=cfg.ckpt_dir,
        ckpt_max_to_keep=cfg.ckpt_max_to_keep,
        lr_warmup_steps=cfg.lr_warmup_steps,
        optimizer_beta_1=cfg.optimizer_beta_1,
        optimizer_beta_2=cfg.optimizer_beta_2,
        optimizer_epsilon=cfg.optimizer_epsilon,
        pred_max_length=cfg.pred_max_length,
    )

    model = LanguageTransformer(input_vocab_size=input_vocab_size,
                                output_vocab_size=output_vocab_size,
                                padding_index=padding_index,
                                **from_config)

    print(Color.green('Initialization transformer from scratch'))
    return model
def check_input_sentence(sent):
    """
    Check that a pinyin sentence was supplied and is not empty after ``regular`` is applied.

    :param sent: (str, mandatory) pinyin sentence, e.g. 'lv4 shi4 yang2'
    :return: (bool) True when the sentence is usable, otherwise False (a message is printed)
    """
    if sent is None:
        print(Color.red("Please input the prediction sentence !"))
        return False

    normalized = regular(sent)
    if len(normalized) != 0:
        return True

    warning = "Please input the prediction sentence ! sentence cannot be empty!"
    print(Color.red(warning))
    return False
def init_or_restore_chinese_dict(dataset, file):
    """
    Restore the chinese dictionary from ``file``, or fit a new one on ``dataset`` and save it.

    :param dataset: (list, mandatory) data set the new dictionary is fitted on
    :param file: (str, mandatory) file the dictionary is restored from or saved to
    :return: (WordDict) the dictionary object
    """
    if not os.path.isfile(file):
        # No saved dictionary: build it from the data and persist it.
        word_dict = WordDict()
        word_dict.fit(dataset=dataset)
        word_dict.save(file)
        message = 'Initialize chinese dict from dataset'
    else:
        word_dict = WordDict.load(file)
        message = 'Restore chinese dict from file: {}'.format(file)

    print(Color.green(message))
    return word_dict
def init_or_restore_pinyin_dict(dataset, file):
    """
    Restore the pinyin dictionary from ``file``, or fit a new one on ``dataset`` and save it.

    :param dataset: (list, mandatory) data set the new dictionary is fitted on
    :param file: (str, mandatory) file the pinyin dictionary is restored from or saved to
    :return: (PinYin) the dictionary object
    """
    if not os.path.isfile(file):
        # No saved dictionary: build it from the data and persist it.
        pinyin_dict = PinYin()
        pinyin_dict.fit(dataset=dataset)
        pinyin_dict.save(file)
        message = 'Initialize pinyin dict form dataset'
    else:
        pinyin_dict = PinYin.load(file)
        message = 'Restore pinyin dict from file:{}'.format(file)

    print(Color.green(message))
    return pinyin_dict
def restore_transformer(file, ckpt_dir):
    """
    Restore a LanguageTransformer from its configuration file and checkpoint directory.

    Any exception raised by ``LanguageTransformer.restore`` is caught, reported on stdout
    and turned into a ``False`` return value.

    :param file: (str, mandatory) transformer configuration file
    :param ckpt_dir: (str, mandatory) directory holding the checkpoints
    :return: (LanguageTransformer or bool) the restored transformer, or False on failure
    """
    try:
        return LanguageTransformer.restore(file, ckpt_dir)
    except Exception as error:
        failure = "restore transformer model failed!config file: {}, checkpoint: {}".format(
            file, ckpt_dir)
        print(Color.red(failure))
        print(error)
        return False
def load_dataset(source_data_path, data_path):
    """
    Load the chinese / pinyin sentences referenced by the wav files of a directory.

    For every entry of ``data_path`` whose name ends in 'wav', the sibling label file
    ``<name>.trn`` is read. It must yield exactly one entry: a path whose last '/'-separated
    component is looked up inside ``source_data_path``. That file is read with
    ``DataUtils.read_text_data(..., handle_func=handle)``; item 0 of the result is used as
    the chinese words and item 1 as the pinyin syllables. A sentence is skipped when both
    its chinese text and its pinyin text have been seen before.

    :param source_data_path: (str, mandatory) directory holding the actual label data files
    :param data_path: (str, mandatory) directory holding the wav files and their .trn files
    :return: (tuple of list) ``(chinese_data, pinyin_data)``: per sentence, the list of
             single chinese characters and the list of pinyin syllables
    :raises AssertionError: if a label file is missing, does not contain exactly one entry,
                            points to a missing data file, a pinyin syllable does not end
                            in a digit, or the character and syllable counts differ
    """
    # Sets give O(1) duplicate checks; only membership is ever asked of them.
    seen_chinese_sent = set()
    seen_pinyin_sent = set()
    chinese_data = []
    pinyin_data = []

    data_file = os.listdir(data_path)
    known_files = set(data_file)

    for file in tqdm(data_file):
        # Only audio files drive the loading; everything else is ignored.
        if not file.endswith('wav'):
            continue

        # The label file is named after the audio file. The checks below are raised
        # explicitly (still as AssertionError) so they are not stripped under `python -O`.
        label_file = file + '.trn'
        if label_file not in known_files:
            raise AssertionError(FileNotFoundException(label_file))

        label_file_path = os.path.join(data_path, label_file)

        # The label file holds the location of the real label data.
        label_entries = DataUtils.read_text_data(label_file_path,
                                                 show_progress_bar=False)
        if len(label_entries) != 1:
            raise AssertionError(
                'Get redundant data path: {}, label_file_path:{}'.format(label_entries, label_file_path))

        # Re-anchor the referenced file name under source_data_path.
        label_data_file_path = os.path.join(source_data_path,
                                            label_entries[0].split('/')[-1])
        if not os.path.isfile(label_data_file_path):
            raise AssertionError(FileNotFoundException(label_data_file_path))

        # Read the label data: chinese text and pinyin.
        text_data = DataUtils.read_text_data(label_data_file_path,
                                             handle_func=handle,
                                             show_progress_bar=False)
        chinese = text_data[0]
        pinyin = text_data[1]

        # Skip the sentence when both its chinese and its pinyin form were seen before.
        chinese_sent = ''.join(chinese)
        pinyin_sent = ' '.join(pinyin)
        if chinese_sent in seen_chinese_sent and pinyin_sent in seen_pinyin_sent:
            continue
        seen_chinese_sent.add(chinese_sent)
        seen_pinyin_sent.add(pinyin_sent)

        # Every pinyin syllable must end in a digit. Slicing keeps an empty syllable
        # from raising IndexError instead of the validation error.
        for py in pinyin:
            if not py[-1:].isdigit():
                raise AssertionError(
                    "the last character:{} of Pinyin is not a number! "
                    "pinyin_str:{}, "
                    "label_data_file_path:{}, "
                    "label_file_path:{}".format(py, pinyin, label_data_file_path, label_file_path))

        # Split multi-character words into single characters; the joined sentence
        # already contains exactly those characters in order.
        chinese = list(chinese_sent)

        # Each chinese character must have exactly one pinyin syllable.
        if len(chinese) != len(pinyin):
            raise AssertionError(
                "the number of pinyin:{} and chinese:{} is different, "
                "chinese:{}, pinyin:{}, file:{}".format(len(pinyin), len(chinese), chinese, pinyin,
                                                        label_data_file_path))

        chinese_data.append(chinese)
        pinyin_data.append(pinyin)

    print(Color.red("load to {} of data!".format(len(chinese_data))))
    return chinese_data, pinyin_data