def __init__(self, args):
    """Build the fbank front-end, the E2E model and the CMVN bookkeeping.

    Args:
        args: option object. ``fbank_dim``, ``train_dataset_len`` and
            ``num_utt_cmvn`` are read from it, and ``args.idim`` is
            overwritten with ``args.fbank_dim`` before the E2E model is
            constructed.
    """
    super(E2EModel, self).__init__()
    self.opt = args

    # The front-end is created first; idim is only rewritten afterwards so
    # that the E2E model sees the fbank dimension as its input size.
    self.fbank_model = FbankModel(args)
    feat_dim = args.fbank_dim
    args.idim = feat_dim
    self.e2e = E2E(args)

    # Zero-initialised float32 buffers (presumably running sum, squared sum
    # and the resulting CMVN statistics -- confirm against compute_cmvn).
    self.sum = np.zeros((1, feat_dim), dtype=np.float32)
    self.sum_sq = np.zeros((1, feat_dim), dtype=np.float32)
    self.fbank_cmvn = np.zeros((2, feat_dim), dtype=np.float32)

    # Never use more utterances than the training set contains.
    self.cmvn_num = min(args.train_dataset_len, args.num_utt_cmvn)
    self.cmvn_processed_num = 0
    self.frame_count = 0
    self.pbar = ProgressBar().start()
    print(">> compute fbank_cmvn using {} utterance ".format(self.cmvn_num))
def main():
    """Decode ``recog_loader`` with enhancement -> fbank -> ASR and dump JSON.

    The function relies on names defined outside this block: ``opt`` (the
    parsed options) and ``recog_loader`` (iterable of batches). Checkpoint,
    RNNLM and CMVN locations are hard-coded paths of the experiment.

    For every batch the first utterance id is decoded, the 1-best (and, when
    ``opt.beam_size > 1``, the n-best) hypotheses are stored, and the whole
    result is written to ``opt.result_label`` as UTF-8 JSON.

    Raises:
        Exception: if ``opt.resume`` is falsy or the fbank CMVN file is missing.
        SystemExit: if a word RNNLM is requested without ``opt.word_dict``.
    """
    # Setup a model (checkpoint locations are fixed for this experiment).
    enhance_model_path = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/enhance_fbank_train_table_2/model.loss.best'
    asr_mode_path = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/asr_clean_train_table3/model.acc.best'
    # The fbank front-end weights are read from the ASR checkpoint.
    feat_model_path = asr_mode_path

    if opt.resume:
        enhance_model = EnhanceModel.load_model(enhance_model_path, 'enhance_state_dict', opt)
        feat_model = FbankModel.load_model(feat_model_path, 'fbank_state_dict', opt)
        asr_model = E2E.load_model(asr_mode_path, 'asr_state_dict', opt)
    else:
        raise Exception("no checkpoint found at {}".format(opt.resume))

    def cpu_loader(storage, location):
        """Keep every loaded storage where it is (used as map_location)."""
        return storage

    if opt.lmtype == 'rnnlm':
        # read rnnlm
        if opt.rnnlm:
            rnnlm = lm.ClassifierWithState(
                lm.RNNLM(len(opt.char_list), 300, 650))
            # The configured path is overridden by a fixed checkpoint.
            opt.rnnlm = "/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/rnnlm_train_shi/rnnlm.model.best"
            rnnlm.load_state_dict(
                torch.load(opt.rnnlm, map_location=cpu_loader))
            if len(opt.gpu_ids) > 0:
                rnnlm = rnnlm.cuda()
            print('load RNNLM from {}'.format(opt.rnnlm))
            rnnlm.eval()
        else:
            rnnlm = None

        if opt.word_rnnlm:
            if not opt.word_dict:
                logging.error(
                    'word dictionary file is not specified for the word RNNLM.'
                )
                sys.exit(1)

            word_dict = load_labeldict(opt.word_dict)
            char_dict = {x: i for i, x in enumerate(opt.char_list)}
            word_rnnlm = lm.ClassifierWithState(lm.RNNLM(len(word_dict), 650))
            word_rnnlm.load_state_dict(
                torch.load(opt.word_rnnlm, map_location=cpu_loader))
            word_rnnlm.eval()

            # Combine with the character LM when one was loaded, otherwise
            # fall back to the look-ahead word LM alone.
            if rnnlm is not None:
                rnnlm = lm.ClassifierWithState(
                    extlm.MultiLevelLM(word_rnnlm.predictor, rnnlm.predictor,
                                       word_dict, char_dict))
            else:
                rnnlm = lm.ClassifierWithState(
                    extlm.LookAheadWordLM(word_rnnlm.predictor, word_dict,
                                          char_dict))
        fstlm = None
    elif opt.lmtype == 'fsrnnlm':
        if opt.rnnlm:
            rnnlm = lm.ClassifierWithState(
                fsrnn.FSRNNLM(len(opt.char_list), 300, opt.fast_layers,
                              opt.fast_cell_size, opt.slow_cell_size,
                              opt.zoneout_keep_h, opt.zoneout_keep_c))
            rnnlm.load_state_dict(
                torch.load(opt.rnnlm, map_location=cpu_loader))
            if len(opt.gpu_ids) > 0:
                rnnlm = rnnlm.cuda()
            print('load fsrnn from {}'.format(opt.rnnlm))
            rnnlm.eval()
        else:
            rnnlm = None
            print('not load fsrnn from {}'.format(opt.rnnlm))
        fstlm = None
    elif opt.lmtype == 'fstlm':
        if opt.fstlm_path:
            fstlm = NgramFstLM(opt.fstlm_path, opt.nn_char_map_file, 20)
        else:
            fstlm = None
        rnnlm = None
    else:
        rnnlm = None
        fstlm = None

    fbank_cmvn_file = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/decode_asr_train_table3_1/decode_clean/fbank_cmvn.npy'
    if os.path.exists(fbank_cmvn_file):
        fbank_cmvn = np.load(fbank_cmvn_file)
        fbank_cmvn = torch.FloatTensor(fbank_cmvn)
    else:
        raise Exception("no found at {}".format(fbank_cmvn_file))

    # Decoding only: no gradients are needed from here on.
    torch.set_grad_enabled(False)
    new_json = {}
    for data in recog_loader:
        utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
        # Only the first utterance of the batch is reported.
        name = utt_ids[0]
        print(name)

        enhance_outputs = enhance_model(inputs, log_inputs, input_sizes)
        enhance_feat = feat_model(enhance_outputs, fbank_cmvn)
        nbest_hyps = asr_model.recognize(enhance_feat, opt, opt.char_list,
                                         rnnlm=rnnlm, fstlm=fstlm)

        # get 1best and remove sos
        y_hat = nbest_hyps[0]['yseq'][1:]
        print(y_hat)
        # assumes iterating `targets` yields the reference token ids of this
        # utterance -- TODO confirm against the loader's collate function
        y_true = targets

        # print out decoding result
        seq_hat = [opt.char_list[int(idx)] for idx in y_hat]
        seq_true = [opt.char_list[int(idx)] for idx in y_true]
        seq_hat_text = "".join(seq_hat).replace('<space>', ' ')
        seq_true_text = "".join(seq_true).replace('<space>', ' ')
        # The text is passed as an argument, not concatenated into the
        # format string, so a literal '%' in a transcript cannot break logging.
        logging.info("groundtruth[%s]: %s", name, seq_true_text)
        logging.info("prediction [%s]: %s", name, seq_hat_text)

        # copy old json info
        new_json[name] = dict()
        new_json[name]['utt2spk'] = spk_ids[0]

        # added recognition results to json
        logging.debug("dump token id")
        out_dic = dict()
        out_dic['name'] = 'target1'
        out_dic['text'] = seq_true_text
        out_dic['token'] = " ".join(seq_true)
        out_dic['tokenid'] = " ".join([str(int(idx)) for idx in y_true])

        # TODO(karita) make consistent to chainer as idx[0] not idx
        out_dic['rec_tokenid'] = " ".join([str(int(idx)) for idx in y_hat])
        out_dic['rec_token'] = " ".join(seq_hat)
        out_dic['rec_text'] = seq_hat_text

        new_json[name]['output'] = [out_dic]
        # TODO(nelson): Modify this part when saving more than 1 hyp is enabled
        # add n-best recognition results with scores
        if opt.beam_size > 1 and len(nbest_hyps) > 1:
            for rank, hyp in enumerate(nbest_hyps):
                suffix = '[' + '{:05d}'.format(rank) + ']'
                y_hat = hyp['yseq'][1:]
                seq_hat = [opt.char_list[int(idx)] for idx in y_hat]
                seq_hat_text = "".join(seq_hat).replace('<space>', ' ')
                new_json[name]['rec_tokenid' + suffix] = " ".join([str(idx) for idx in y_hat])
                new_json[name]['rec_token' + suffix] = " ".join(seq_hat)
                new_json[name]['rec_text' + suffix] = seq_hat_text
                new_json[name]['score' + suffix] = float(hyp['score'])

    # TODO(watanabe) fix character coding problems when saving it
    with open(opt.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_json
            }, indent=4, sort_keys=True).encode('utf_8'))
def main():
    """Train the attention/CTC E2E ASR model on pre-computed fbank features.

    Options come from ``TrainOptions().parse()``. The loop alternates between
    training steps, periodic 'latest' checkpoints (every ``opt.print_freq``
    iterations) and validation passes (every ``opt.validate_freq``
    iterations) that also select the best checkpoint according to
    ``opt.criterion``.

    Raises:
        ValueError: if ``opt.opt_type`` is neither 'adadelta' nor 'adam'.
    """
    opt = TrainOptions().parse()

    visualizer = Visualizer(opt)
    # NOTE: this local shadows the stdlib ``logging`` module inside main().
    logging = visualizer.get_logger()
    acc_report = visualizer.add_plot_report(['train/acc', 'val/acc'], 'acc.png')
    loss_report = visualizer.add_plot_report(['train/loss', 'val/loss'], 'loss.png')

    # data
    logging.info("Building dataset.")
    # The train directory and the dict directory are the inputs.
    train_dataset = SequentialDataset(opt, os.path.join(opt.dataroot, 'train'), os.path.join(opt.dict_dir, 'train_units.txt'),)
    val_dataset = SequentialDataset(opt, os.path.join(opt.dataroot, 'dev'), os.path.join(opt.dict_dir, 'train_units.txt'),)
    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)
    train_loader = SequentialDataLoader(train_dataset, num_workers=opt.num_workers, batch_sampler=train_sampler)
    val_loader = SequentialDataLoader(val_dataset, batch_size=int(opt.batch_size/2), num_workers=opt.num_workers, shuffle=False)
    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)
    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")

    # Setup a model
    asr_model = E2E(opt)
    lr = opt.lr                    # default=0.005
    eps = opt.eps                  # default=1e-8
    iters = opt.iters              # default=0
    start_epoch = opt.start_epoch  # default=0
    best_loss = opt.best_loss      # default=float('inf')
    best_acc = opt.best_acc        # default=0
    if opt.resume:
        # A checkpoint to resume from was requested.
        model_path = os.path.join(opt.works_dir, opt.resume)
        if os.path.isfile(model_path):
            package = torch.load(model_path, map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)
            best_loss = package.get('best_loss', float('inf'))
            best_acc = package.get('best_acc', 0)
            start_epoch = int(package.get('epoch', 0))
            iters = int(package.get('iters', 0))

            acc_report = package.get('acc_report', acc_report)
            loss_report = package.get('loss_report', loss_report)
            visualizer.set_plot_report(acc_report, 'acc.png')
            visualizer.set_plot_report(loss_report, 'loss.png')

            asr_model = E2E.load_model(model_path, 'asr_state_dict')
            logging.info('Loading model {} and iters {}'.format(model_path, iters))
        else:
            print("no checkpoint found at {}".format(model_path))

    # convert to cuda
    asr_model.cuda()
    print(asr_model)

    # Setup an optimizer (only parameters that require gradients).
    parameters = filter(lambda p: p.requires_grad, itertools.chain(asr_model.parameters()))
    if opt.opt_type == 'adadelta':
        optimizer = torch.optim.Adadelta(parameters, rho=0.95, eps=eps)
    elif opt.opt_type == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=lr, betas=(opt.beta1, 0.999))
    else:
        # Fail early instead of hitting an unbound `optimizer` mid-training.
        raise ValueError("unsupported opt_type: {}".format(opt.opt_type))

    asr_model.train()
    sample_rampup = utils.ScheSampleRampup(opt.sche_samp_start_iter, opt.sche_samp_final_iter, opt.sche_samp_final_rate)
    sche_samp_rate = sample_rampup.update(iters)

    for epoch in range(start_epoch, opt.epochs):
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)

        # The loop index was never used (an enumerate start offset only
        # relabels the index, it does not skip batches), so iterate directly.
        for data in train_loader:
            utt_ids, spk_ids, fbank_features, targets, input_sizes, target_sizes = data
            loss_ctc, loss_att, acc, context = asr_model(fbank_features, targets, input_sizes, target_sizes, sche_samp_rate)
            # Multi-task objective: mtlalpha weights CTC against attention.
            loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
            optimizer.zero_grad()  # Clear the parameter gradients
            loss.backward()  # compute backwards
            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(asr_model.parameters(), opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                optimizer.step()

            iters += 1
            errors = {'train/loss': loss.item(), 'train/loss_ctc': loss_ctc.item(),
                      'train/acc': acc, 'train/loss_att': loss_att.item()}
            visualizer.set_current_errors(errors)

            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                state = {'asr_state_dict': asr_model.state_dict(),
                         'opt': opt, 'epoch': epoch, 'iters': iters,
                         'eps': opt.eps, 'lr': opt.lr,
                         'best_loss': best_loss, 'best_acc': best_acc,
                         'acc_report': acc_report, 'loss_report': loss_report}
                filename = 'latest'
                utils.save_checkpoint(state, opt.exp_path, filename=filename)

            if iters % opt.validate_freq == 0:
                sche_samp_rate = sample_rampup.update(iters)
                print("iters {} sche_samp_rate {}".format(iters, sche_samp_rate))
                asr_model.eval()
                torch.set_grad_enabled(False)
                num_saved_attention = 0
                for data in tqdm(val_loader):
                    utt_ids, spk_ids, fbank_features, targets, input_sizes, target_sizes = data
                    # Scheduled sampling is switched off (rate 0.0) for validation.
                    loss_ctc, loss_att, acc, context = asr_model(fbank_features, targets, input_sizes, target_sizes, 0.0)
                    loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
                    errors = {'val/loss': loss.item(), 'val/loss_ctc': loss_ctc.item(),
                              'val/acc': acc, 'val/loss_att': loss_att.item()}
                    visualizer.set_current_errors(errors)

                    # Attention plots only make sense when the attention
                    # branch contributes to the loss (mtlalpha != 1).
                    if opt.num_save_attention > 0 and opt.mtlalpha != 1.0:
                        if num_saved_attention < opt.num_save_attention:
                            att_ws = asr_model.calculate_all_attentions(fbank_features, targets, input_sizes, target_sizes)
                            for x in range(len(utt_ids)):
                                att_w = att_ws[x]
                                utt_id = utt_ids[x]
                                file_name = "{}_ep{}_it{}.png".format(utt_id, epoch, iters)
                                dec_len = int(target_sizes[x])
                                enc_len = int(input_sizes[x])
                                visualizer.plot_attention(att_w, dec_len, enc_len, file_name)
                                num_saved_attention += 1
                                if num_saved_attention >= opt.num_save_attention:
                                    break
                asr_model.train()
                torch.set_grad_enabled(True)

                visualizer.print_epoch_errors(epoch, iters)
                acc_report = visualizer.plot_epoch_errors(epoch, iters, 'acc.png')
                loss_report = visualizer.plot_epoch_errors(epoch, iters, 'loss.png')
                val_loss = visualizer.get_current_errors('val/loss')
                val_acc = visualizer.get_current_errors('val/acc')

                # filename stays None unless this validation is a new best.
                filename = None
                if opt.criterion == 'acc' and opt.mtl_mode != 'ctc':
                    if val_acc < best_acc:
                        # No improvement: decay the optimizer epsilon.
                        logging.info('val_acc {} < best_acc {}'.format(val_acc, best_acc))
                        opt.eps = utils.adadelta_eps_decay(optimizer, opt.eps_decay)
                    else:
                        filename = 'model.acc.best'
                    best_acc = max(best_acc, val_acc)
                    logging.info('best_acc {}'.format(best_acc))
                elif opt.criterion == 'loss':
                    if val_loss > best_loss:
                        logging.info('val_loss {} > best_loss {}'.format(val_loss, best_loss))
                        opt.eps = utils.adadelta_eps_decay(optimizer, opt.eps_decay)
                    else:
                        filename = 'model.loss.best'
                    best_loss = min(val_loss, best_loss)
                    logging.info('best_loss {}'.format(best_loss))

                state = {'asr_state_dict': asr_model.state_dict(),
                         'opt': opt, 'epoch': epoch, 'iters': iters,
                         'eps': opt.eps, 'lr': opt.lr,
                         'best_loss': best_loss, 'best_acc': best_acc,
                         'acc_report': acc_report, 'loss_report': loss_report}
                utils.save_checkpoint(state, opt.exp_path, filename=filename)
                visualizer.reset()
def main():
    """Jointly train the enhancement front-end and the E2E ASR model.

    Options come from ``fake_opt.JointTrain()``. Pre-trained enhancement and
    ASR checkpoints are loaded from hard-coded experiment paths, optionally
    together with a GAN discriminator (``opt.isGAN``). Each training step
    combines the ASR loss, the enhancement regression loss and, when
    enabled, the generator-side GAN loss; the discriminator is then updated
    separately. Every ``opt.validate_freq`` iterations a validation pass
    selects the best checkpoint and the enhancement CMVN is recomputed.

    Raises:
        ValueError: if ``opt.opt_type`` is neither 'adadelta' nor 'adam'.
    """
    opt = fake_opt.JointTrain()
    device = torch.device("cuda:{}".format(opt.gpu_ids[0]) if len(opt.gpu_ids) > 0 and torch.cuda.is_available() else "cpu")

    visualizer = Visualizer(opt)
    # NOTE: this local shadows the stdlib ``logging`` module inside main().
    logging = visualizer.get_logger()
    acc_report = visualizer.add_plot_report(['train/acc', 'val/acc'], 'acc.png')
    loss_report = visualizer.add_plot_report(
        ['train/loss', 'val/loss', 'train/enhance_loss', 'val/enhance_loss'],
        'loss.png')

    # data
    logging.info("Building dataset.")
    train_dataset = MixSequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'train_new'),
        os.path.join(opt.dict_dir, 'train/vocab'),
    )
    val_dataset = MixSequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'dev_new'),
        os.path.join(opt.dict_dir, 'train/vocab'),
    )
    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)
    train_loader = MixSequentialDataLoader(train_dataset,
                                           num_workers=opt.num_workers,
                                           batch_sampler=train_sampler)
    val_loader = MixSequentialDataLoader(val_dataset,
                                         batch_size=int(opt.batch_size / 2),
                                         num_workers=opt.num_workers,
                                         shuffle=False)
    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)
    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")

    # Setup a model
    lr = opt.lr
    eps = opt.eps
    iters = opt.iters
    best_acc = opt.best_acc
    best_loss = opt.best_loss
    start_epoch = opt.start_epoch

    # Bound up front: it is read under `opt.isGAN` even when the fallback
    # enhancement load below is skipped (previously a possible NameError).
    enhance_model_path_with_gan = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/enhance_gan_train_both_enhance_cmvn/model.loss.best'

    enhance_model_path = None
    if opt.enhance_resume:
        enhance_model_path = "/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/enhance_fbank_train_table_2/model.loss.best"
        if os.path.isfile(enhance_model_path):
            enhance_model = EnhanceModel.load_model(enhance_model_path, 'enhance_state_dict', opt)
        else:
            print("no checkpoint found at {}".format(enhance_model_path))

    asr_model_path = None
    if opt.asr_resume:
        asr_model_path = "/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/asr_mix_train_table3_1/model.acc.best"
        if os.path.isfile(asr_model_path):
            asr_model = E2E.load_model(asr_model_path, 'asr_state_dict', opt)
        else:
            print("no checkpoint found at {}".format(asr_model_path))

    joint_model_path = None
    if opt.joint_resume:
        joint_model_path = os.path.join(opt.works_dir, opt.joint_resume)
        if os.path.isfile(joint_model_path):
            package = torch.load(joint_model_path, map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)
            best_acc = package.get('best_acc', 0)
            best_loss = package.get('best_loss', float('inf'))
            start_epoch = int(package.get('epoch', 0))
            iters = int(package.get('iters', 0)) - 1
            print('joint_model_path {} and iters {}'.format(
                joint_model_path, iters))
        else:
            print("no checkpoint found at {}".format(joint_model_path))

    # Fallback / override loads: used when resuming a joint run or when the
    # corresponding stand-alone resume option was not given.
    if joint_model_path is not None or enhance_model_path is None:
        enhance_model = EnhanceModel.load_model(enhance_model_path_with_gan, 'enhance_state_dict', opt)
    if joint_model_path is not None or asr_model_path is None:
        asr_model_path = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/asr_train/model.acc.best'
        asr_model = E2E.load_model(asr_model_path, 'asr_state_dict', opt)
    feat_model = FbankModel.load_model(joint_model_path, 'fbank_state_dict', opt)
    if opt.isGAN:
        gan_model = GANModel.load_model(enhance_model_path_with_gan, 'gan_state_dict', opt)

    # Setup an optimizer (one per sub-model, trainable parameters only).
    enhance_parameters = filter(lambda p: p.requires_grad, enhance_model.parameters())
    asr_parameters = filter(lambda p: p.requires_grad, asr_model.parameters())
    if opt.isGAN:
        gan_parameters = filter(lambda p: p.requires_grad, gan_model.parameters())
    if opt.opt_type == 'adadelta':
        enhance_optimizer = torch.optim.Adadelta(enhance_parameters, rho=0.95, eps=eps)
        asr_optimizer = torch.optim.Adadelta(asr_parameters, rho=0.95, eps=eps)
        if opt.isGAN:
            gan_optimizer = torch.optim.Adadelta(gan_parameters, rho=0.95, eps=eps)
    elif opt.opt_type == 'adam':
        enhance_optimizer = torch.optim.Adam(enhance_parameters, lr=lr, betas=(opt.beta1, 0.999))
        asr_optimizer = torch.optim.Adam(asr_parameters, lr=lr, betas=(opt.beta1, 0.999))
        if opt.isGAN:
            gan_optimizer = torch.optim.Adam(gan_parameters, lr=lr, betas=(opt.beta1, 0.999))
    else:
        # Fail early instead of hitting unbound optimizers mid-training.
        raise ValueError("unsupported opt_type: {}".format(opt.opt_type))
    if opt.isGAN:
        criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan).to(device)

    # Training
    # No pre-computed enhancement CMVN is configured, so it is computed
    # from the training data below.
    enhance_cmvn_path = None
    if enhance_cmvn_path:
        enhance_cmvn = np.load(enhance_cmvn_path)
        enhance_cmvn = torch.FloatTensor(enhance_cmvn)
    else:
        enhance_cmvn = compute_cmvn_epoch(opt, train_loader, enhance_model, feat_model)
    sample_rampup = utils.ScheSampleRampup(opt.sche_samp_start_iter, opt.sche_samp_final_iter, opt.sche_samp_final_rate)
    sche_samp_rate = sample_rampup.update(iters)

    fbank_cmvn_file = os.path.join(opt.exp_path, 'fbank_cmvn.npy')
    fbank_cmvn = np.load(fbank_cmvn_file)
    fbank_cmvn = torch.FloatTensor(fbank_cmvn)

    enhance_model.train()
    feat_model.train()
    asr_model.train()
    for epoch in range(start_epoch, opt.epochs):
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)

        for data in train_loader:
            (utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs,
             mix_log_inputs, cos_angles, targets, input_sizes,
             target_sizes) = data
            enhance_out = enhance_model(mix_inputs, mix_log_inputs, input_sizes)
            # Features without CMVN feed the enhancement regression loss.
            enhance_feat = feat_model(enhance_out)
            clean_feat = feat_model(clean_inputs)
            mix_feat = feat_model(mix_inputs)
            if opt.enhance_loss_type == 'L2':
                enhance_loss = F.mse_loss(enhance_feat, clean_feat.detach())
            elif opt.enhance_loss_type == 'L1':
                enhance_loss = F.l1_loss(enhance_feat, clean_feat.detach())
            elif opt.enhance_loss_type == 'smooth_L1':
                enhance_loss = F.smooth_l1_loss(enhance_feat, clean_feat.detach())
            enhance_loss = opt.enhance_loss_lambda * enhance_loss

            # CMVN-normalised features feed the recogniser / discriminator.
            enhance_feature = feat_model(enhance_out, enhance_cmvn)
            clean_feature = feat_model(clean_inputs, fbank_cmvn)
            loss_ctc, loss_att, acc = asr_model(enhance_feature, targets, input_sizes, target_sizes, sche_samp_rate)
            coral_loss = 0
            asr_loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
            loss = asr_loss + enhance_loss + coral_loss

            if opt.isGAN:
                # Generator step: the discriminator is frozen.
                set_requires_grad([gan_model], False)
                if opt.netD_type == 'pixel':
                    fake_AB = torch.cat((mix_feat, enhance_feat), 2)
                else:
                    fake_AB = enhance_feature
                gan_loss = opt.gan_loss_lambda * criterionGAN(
                    gan_model(fake_AB), True)
                loss += gan_loss

            # NOTE(review): the nesting of this call was not recoverable from
            # the collapsed source; it is kept at loop level (applied whether
            # or not GAN training is on) -- confirm against the original file.
            set_requires_grad([enhance_model], False)
            enhance_optimizer.zero_grad()
            asr_optimizer.zero_grad()  # Clear the parameter gradients
            loss.backward()
            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(asr_model.parameters(), opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                enhance_optimizer.step()
                asr_optimizer.step()

            if opt.isGAN:
                # Discriminator step on detached real / fake pairs.
                set_requires_grad([gan_model], True)
                gan_optimizer.zero_grad()
                if opt.netD_type == 'pixel':
                    fake_AB = torch.cat((mix_feat, enhance_feat), 2)
                    real_AB = torch.cat((mix_feat, clean_feat), 2)
                else:
                    fake_AB = enhance_feature
                    real_AB = clean_feature
                loss_D_real = criterionGAN(gan_model(real_AB.detach()), True)
                loss_D_fake = criterionGAN(gan_model(fake_AB.detach()), False)
                loss_D = (loss_D_real + loss_D_fake) * 0.5
                loss_D.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    gan_model.parameters(), opt.grad_clip)
                if math.isnan(grad_norm):
                    logging.warning('grad norm is nan. Do not update model.')
                else:
                    gan_optimizer.step()

            iters += 1
            errors = {
                'train/loss': loss.item(),
                'train/loss_ctc': loss_ctc.item(),
                'train/acc': acc,
                'train/loss_att': loss_att.item(),
                'train/enhance_loss': enhance_loss.item()
            }
            if opt.isGAN:
                errors['train/loss_D'] = loss_D.item()
                # gan_loss already carries gan_loss_lambda; multiplying again
                # would log lambda**2 times the raw loss, unlike val/gan_loss.
                errors['train/gan_loss'] = gan_loss.item()

            visualizer.set_current_errors(errors)
            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                state = {
                    'asr_state_dict': asr_model.state_dict(),
                    'fbank_state_dict': feat_model.state_dict(),
                    'enhance_state_dict': enhance_model.state_dict(),
                    'opt': opt,
                    'epoch': epoch,
                    'iters': iters,
                    'eps': opt.eps,
                    'lr': opt.lr,
                    'best_loss': best_loss,
                    'best_acc': best_acc,
                    'acc_report': acc_report,
                    'loss_report': loss_report
                }
                if opt.isGAN:
                    state['gan_state_dict'] = gan_model.state_dict()
                filename = 'latest'
                utils.save_checkpoint(state, opt.exp_path, filename=filename)

            if iters % opt.validate_freq == 0:
                sche_samp_rate = sample_rampup.update(iters)
                print("iters {} sche_samp_rate {}".format(
                    iters, sche_samp_rate))
                enhance_model.eval()
                feat_model.eval()
                asr_model.eval()
                torch.set_grad_enabled(False)
                num_saved_attention = 0
                for data in tqdm(val_loader):
                    (utt_ids, spk_ids, clean_inputs, clean_log_inputs,
                     mix_inputs, mix_log_inputs, cos_angles, targets,
                     input_sizes, target_sizes) = data
                    enhance_out = enhance_model(mix_inputs, mix_log_inputs, input_sizes)
                    enhance_feat = feat_model(enhance_out)
                    clean_feat = feat_model(clean_inputs)
                    mix_feat = feat_model(mix_inputs)
                    clean_feat_val = feat_model(clean_inputs, fbank_cmvn)
                    enhance_feat_val = feat_model(enhance_out, enhance_cmvn)
                    if opt.enhance_loss_type == 'L2':
                        enhance_loss = F.mse_loss(enhance_feat, clean_feat.detach())
                    elif opt.enhance_loss_type == 'L1':
                        enhance_loss = F.l1_loss(enhance_feat, clean_feat.detach())
                    elif opt.enhance_loss_type == 'smooth_L1':
                        enhance_loss = F.smooth_l1_loss(
                            enhance_feat, clean_feat.detach())
                    if opt.isGAN:
                        set_requires_grad([gan_model], False)
                        if opt.netD_type == 'pixel':
                            fake_AB = torch.cat((mix_feat, enhance_feat), 2)
                        else:
                            fake_AB = enhance_feat_val
                        gan_loss = criterionGAN(gan_model(fake_AB), True)
                        enhance_loss += opt.gan_loss_lambda * gan_loss

                    loss_ctc, loss_att, acc = asr_model(
                        enhance_feat_val, targets, input_sizes, target_sizes,
                        sche_samp_rate)
                    asr_loss = opt.mtlalpha * loss_ctc + (
                        1 - opt.mtlalpha) * loss_att
                    enhance_loss = opt.enhance_loss_lambda * enhance_loss
                    loss = asr_loss + enhance_loss
                    errors = {
                        'val/loss': loss.item(),
                        'val/loss_ctc': loss_ctc.item(),
                        'val/acc': acc,
                        'val/loss_att': loss_att.item(),
                        'val/enhance_loss': enhance_loss.item()
                    }
                    if opt.isGAN:
                        errors['val/gan_loss'] = opt.gan_loss_lambda * gan_loss.item()
                    visualizer.set_current_errors(errors)

                    # Attention plots only make sense when the attention
                    # branch contributes to the loss (mtlalpha != 1).
                    if opt.num_save_attention > 0 and opt.mtlalpha != 1.0:
                        if num_saved_attention < opt.num_save_attention:
                            att_ws = asr_model.calculate_all_attentions(
                                enhance_feat_val, targets, input_sizes,
                                target_sizes)
                            for x in range(len(utt_ids)):
                                att_w = att_ws[x]
                                utt_id = utt_ids[x]
                                file_name = "{}_ep{}_it{}.png".format(
                                    utt_id, epoch, iters)
                                dec_len = int(target_sizes[x])
                                enc_len = int(input_sizes[x])
                                visualizer.plot_attention(
                                    att_w, dec_len, enc_len, file_name)
                                num_saved_attention += 1
                                if num_saved_attention >= opt.num_save_attention:
                                    break
                enhance_model.train()
                feat_model.train()
                asr_model.train()
                torch.set_grad_enabled(True)

                visualizer.print_epoch_errors(epoch, iters)
                acc_report = visualizer.plot_epoch_errors(
                    epoch, iters, 'acc.png')
                loss_report = visualizer.plot_epoch_errors(
                    epoch, iters, 'loss.png')
                val_loss = visualizer.get_current_errors('val/loss')
                val_acc = visualizer.get_current_errors('val/acc')

                # filename stays None unless this validation is a new best.
                filename = None
                if opt.criterion == 'acc' and opt.mtl_mode != 'ctc':
                    if val_acc < best_acc:
                        # No improvement: decay the ASR optimizer epsilon.
                        logging.info('val_acc {} < best_acc {}'.format(
                            val_acc, best_acc))
                        opt.eps = utils.adadelta_eps_decay(
                            asr_optimizer, opt.eps_decay)
                    else:
                        filename = 'model.acc.best'
                    best_acc = max(best_acc, val_acc)
                    logging.info('best_acc {}'.format(best_acc))
                elif opt.criterion == 'loss':
                    if val_loss > best_loss:
                        logging.info('val_loss {} > best_loss {}'.format(
                            val_loss, best_loss))
                        opt.eps = utils.adadelta_eps_decay(
                            asr_optimizer, opt.eps_decay)
                    else:
                        filename = 'model.loss.best'
                    best_loss = min(val_loss, best_loss)
                    logging.info('best_loss {}'.format(best_loss))

                state = {
                    'asr_state_dict': asr_model.state_dict(),
                    'fbank_state_dict': feat_model.state_dict(),
                    'enhance_state_dict': enhance_model.state_dict(),
                    'opt': opt,
                    'epoch': epoch,
                    'iters': iters,
                    'eps': opt.eps,
                    'lr': opt.lr,
                    'best_loss': best_loss,
                    'best_acc': best_acc,
                    'acc_report': acc_report,
                    'loss_report': loss_report
                }
                if opt.isGAN:
                    state['gan_state_dict'] = gan_model.state_dict()
                utils.save_checkpoint(state, opt.exp_path, filename=filename)
                visualizer.reset()
                # The enhancement model changed, so refresh its CMVN statistics.
                enhance_cmvn = compute_cmvn_epoch(opt, train_loader, enhance_model, feat_model)