def __init_model(self):
    """Build a Transducer from ``self.config`` and restore its checkpointed weights.

    Loads encoder/decoder/joint state dicts from the checkpoint path in
    ``self.config.training.load_model``, moves the model to GPU, switches it
    to eval mode, and stores it on ``self.model``.
    """
    net = Transducer(self.config.model)
    ckpt = torch.load(self.config.training.load_model)
    # The checkpoint stores each sub-module's weights under its own key;
    # restore them one by one.
    for part in ('encoder', 'decoder', 'joint'):
        getattr(net, part).load_state_dict(ckpt[part])
    net = net.cuda()
    net.eval()
    print("已加载模型")
    self.model = net
def init_model():
    """Load the streaming Transducer model and its index->token vocabulary.

    Reads ``config/joint_streaming.yaml``, builds the model, restores the
    encoder/decoder/joint weights from ``config.training.load_model``, moves
    the model to GPU in eval mode, and parses ``config.data.vocab``.

    Returns:
        tuple: ``(model, vocab)`` — the eval-mode CUDA model and a dict
        mapping ``int`` index -> token string.
    """
    # FIX: the original opened the config file without ever closing it;
    # a context manager releases the handle deterministically.
    with open("config/joint_streaming.yaml", encoding='utf-8') as config_file:
        config = AttrDict(yaml.load(config_file, Loader=yaml.FullLoader))

    model = Transducer(config.model)
    checkpoint = torch.load(config.training.load_model)
    model.encoder.load_state_dict(checkpoint['encoder'])
    model.decoder.load_state_dict(checkpoint['decoder'])
    model.joint.load_state_dict(checkpoint['joint'])
    model = model.cuda()
    model.eval()
    print("已加载模型")

    vocab = {}
    with open(config.data.vocab, "r") as f:
        for line in f:
            # Each vocab line is "<word> <index>"; store the reverse mapping
            # (index -> word) so decoded ids can be turned back into text.
            parts = line.strip().split()
            word = parts[0]
            index = int(parts[1])
            vocab[index] = word
    print("已加载词典")
    return model, vocab
def main():
    """Train a Transducer model driven by a YAML config.

    Command-line arguments:
        -config: path to the training YAML config (default config/thchs30.yaml).
        -log:    log file name placed inside the experiment directory.
        -mode:   'retrain' (fresh optimizer) or 'continue' (resume optimizer
                 state and epoch from the loaded full checkpoint).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-config', type=str, default='config/thchs30.yaml')
    parser.add_argument('-log', type=str, default='train.log')
    parser.add_argument('-mode', type=str, default='retrain')
    opt = parser.parse_args()

    # FIX: close the config file handle (the original leaked it).
    with open(opt.config) as configfile:
        config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader))

    exp_name = os.path.join('egs', config.data.name, 'exp', config.training.save_model)
    if not os.path.isdir(exp_name):
        os.makedirs(exp_name)
    logger = init_logger(os.path.join(exp_name, opt.log))

    # Snapshot the config next to the logs so the run is reproducible.
    shutil.copyfile(opt.config, os.path.join(exp_name, 'config.yaml'))
    logger.info('Save config info.')

    num_workers = config.training.num_gpu * 2
    train_dataset = AudioDataset(config.data, 'train')
    training_data = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size * config.training.num_gpu,
        shuffle=config.data.shuffle,
        num_workers=num_workers)
    logger.info('Load Train Set!')

    dev_dataset = AudioDataset(config.data, 'dev')
    validate_data = torch.utils.data.DataLoader(
        dev_dataset,
        batch_size=config.data.batch_size * config.training.num_gpu,
        shuffle=False,
        num_workers=num_workers)
    logger.info('Load Dev Set!')

    if config.training.num_gpu > 0:
        torch.cuda.manual_seed(config.training.seed)
        torch.backends.cudnn.deterministic = True
    else:
        torch.manual_seed(config.training.seed)
    logger.info('Set random seed: %d' % config.training.seed)

    model = Transducer(config.model)

    # `checkpoint` is reused below for 'continue' mode; keep it bound even
    # when no weights are loaded so the mode check can fail cleanly.
    checkpoint = None
    if config.training.load_model:
        checkpoint = torch.load(config.training.load_model)
        model.encoder.load_state_dict(checkpoint['encoder'])
        model.decoder.load_state_dict(checkpoint['decoder'])
        model.joint.load_state_dict(checkpoint['joint'])
        logger.info('Loaded model from %s' % config.training.load_model)
    elif config.training.load_encoder or config.training.load_decoder:
        if config.training.load_encoder:
            checkpoint = torch.load(config.training.load_encoder)
            model.encoder.load_state_dict(checkpoint['encoder'])
            logger.info('Loaded encoder from %s' % config.training.load_encoder)
        if config.training.load_decoder:
            checkpoint = torch.load(config.training.load_decoder)
            model.decoder.load_state_dict(checkpoint['decoder'])
            logger.info('Loaded decoder from %s' % config.training.load_decoder)

    if config.training.num_gpu > 0:
        model = model.cuda()
        if config.training.num_gpu > 1:
            device_ids = list(range(config.training.num_gpu))
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        logger.info('Loaded the model to %d GPUs' % config.training.num_gpu)

    n_params, enc, dec = count_parameters(model)
    logger.info('# the number of parameters in the whole model: %d' % n_params)
    logger.info('# the number of parameters in the Encoder: %d' % enc)
    logger.info('# the number of parameters in the Decoder: %d' % dec)
    logger.info('# the number of parameters in the JointNet: %d' % (n_params - dec - enc))

    optimizer = Optimizer(model.parameters(), config.optim)
    logger.info('Created a %s optimizer.' % config.optim.type)

    if opt.mode == 'continue':
        # BUG FIX: the original referenced `checkpoint` unconditionally here,
        # raising NameError when config.training.load_model was unset (and a
        # KeyError when only a partial encoder/decoder checkpoint was loaded).
        if checkpoint is None or 'optimizer' not in checkpoint:
            raise ValueError(
                "mode 'continue' requires config.training.load_model to point "
                "to a full checkpoint containing optimizer state")
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        logger.info('Load Optimizer State!')
    else:
        start_epoch = 0

    # create a visualizer
    if config.training.visualization:
        visualizer = SummaryWriter(os.path.join(exp_name, 'log'))
        logger.info('Created a visualizer.')
    else:
        visualizer = None

    for epoch in range(start_epoch, config.training.epochs):
        train(epoch, config, model, training_data, optimizer, logger, visualizer)

        if config.training.eval_or_not:
            _ = eval(epoch, config, model, validate_data, logger, visualizer)

        save_name = os.path.join(
            exp_name, '%s.epoch%d.chkpt' % (config.training.save_model, epoch))
        save_model(model, optimizer, config, save_name)
        logger.info('Epoch %d model has been saved.' % epoch)

        if epoch >= config.optim.begin_to_adjust_lr:
            optimizer.decay_lr()
            # early stop
            if optimizer.lr < 1e-6:
                logger.info('The learning rate is too low to train.')
                break
            logger.info('Epoch %d update learning rate: %.6f' % (epoch, optimizer.lr))

    logger.info('The training process is OVER!')
# @Contact : [email protected] # @Time : 2020/8/5 下午3:11 import os import yaml import torch from tt.model import Transducer from tt.utils import AttrDict, read_wave_from_file, get_feature, concat_frame, subsampling, context_mask os.chdir('../') WAVE_OUTPUT_FILENAME = 'audio/5_1812_20170628135834.wav' # 加载模型 config_file = open("config/joint_streaming.yaml") config = AttrDict(yaml.load(config_file, Loader=yaml.FullLoader)) model = Transducer(config.model) checkpoint = torch.load(config.training.load_model) model.encoder.load_state_dict(checkpoint['encoder']) model.decoder.load_state_dict(checkpoint['decoder']) model.joint.load_state_dict(checkpoint['joint']) print('加载模型') model.eval() # 获取音频特征 audio, fr = read_wave_from_file(WAVE_OUTPUT_FILENAME) feature = get_feature(audio, fr) feature = concat_frame(feature, 3, 0) feature = subsampling(feature, 3) feature = torch.from_numpy(feature) feature = torch.unsqueeze(feature, 0)
def eval():
    """Decode with the given args.

    Args:
        args (namespace): The program arguments.
    """
    # NOTE(review): this shadows the builtin `eval`; consider renaming to
    # e.g. `decode` in a follow-up (callers would need updating).
    parser = get_parser()
    args = parser.parse_args()
    # load yaml config file
    # NOTE(review): `configfile` is never closed — consider a `with` block.
    configfile = open(args.config)
    config = AttrDict(yaml.load(configfile, Loader=yaml.FullLoader))
    # read json data
    with open(config.data.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    # read idim and odim from json data: taken from the first utterance's
    # input/output shape metadata.
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["input"][0]["shape"][-1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    # model, train_args = load_trained_model(args.model)
    from tt.model import Transducer
    model = Transducer(idim, odim, args)
    # Only encoder/decoder weights are restored here; the joint network keeps
    # its freshly-initialized weights — TODO confirm this is intentional.
    checkpoint = torch.load(args.model)
    model.encoder.load_state_dict(checkpoint['encoder'])
    model.decoder.load_state_dict(checkpoint['decoder'])
    model.recog_args = args
    logging.info(
        " Total parameter of the model = "
        + str(sum(p.numel() for p in model.parameters()))
    )
    rnnlm = None
    if config.data.vocab is not None:
        # Build the character list from the vocab file; the first column of
        # each line is the token, with <blank> prepended and <eos> appended.
        with open(config.data.vocab, "rb") as f:
            dictionary = f.readlines()
        char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
        char_list.insert(0, "<blank>")
        char_list.append("<eos>")
        args.char_list = char_list
    # gpu
    # NOTE(review): the condition reads config.ngpu but the range uses
    # args.ngpu — verify these two settings are kept in sync.
    if config.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info("gpu id: " + str(gpu_id))
        model.cuda()
        if rnnlm:
            rnnlm.cuda()
    new_js = {}
    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_args={"train": False},
    )
    # model.eval()  ## training: false
    avg_cer = []
    with torch.no_grad():  ## training: false
        # Decode every utterance in the validation set and accumulate its CER.
        for idx, name in enumerate(valid_json.keys(), 1):
            logging.info("(%d/%d) decoding " + name, idx, len(valid_json.keys()))
            batch = [(name, valid_json[name])]
            feat = load_inputs_and_targets(batch)
            feat = (
                feat[0][0]
            )
            nbest_hyps = model.recognize(
                feat, args, args.char_list, rnnlm
            )
            new_js[name] = add_results_to_json(
                valid_json[name], nbest_hyps, args.char_list
            )
            # Character error rate: edit distance between space-stripped
            # hypothesis and reference, as a percentage of reference length.
            hyp_chars = new_js[name]['output'][0]['rec_text'].replace(" ", "")
            ref_chars = valid_json[name]['output'][0]['text'].replace(" ", "")
            char_eds = editdistance.eval(hyp_chars, ref_chars)
            cer = float(char_eds / len(ref_chars)) * 100
            avg_cer.append(cer)
            logging.info("{} cer: {}".format(name, cer))
    logging.info('avg_cer:{}'.format(mean(np.array(avg_cer))))
    # Persist all per-utterance decoding results as UTF-8 JSON.
    with open('result.txt', "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )