Example #1
0
def _log10_specgram(tensor, floor=1e-7):
    """Return ``log10`` of ``tensor`` as a NumPy array, flooring values at ``floor``.

    Values ``<= floor`` are raised to ``floor`` before the log so zeros do not
    turn into ``-inf``. ``np.maximum`` allocates a new array, so (unlike an
    in-place masked assignment on ``.numpy()``) it never writes back into a
    CPU tensor that shares memory with the NumPy view.
    """
    return np.log10(np.maximum(tensor.data.cpu().numpy(), floor))


def main():
    """Train the speech-enhancement model (``EnhanceModel``) on its own.

    All settings come from ``TrainOptions().parse()``. The function builds the
    mixed/clean train and dev loaders, optionally resumes from
    ``opt.works_dir/opt.enhace_resume``, and trains for ``opt.epochs`` epochs.
    A 'latest' checkpoint is written every ``opt.print_freq`` iterations.
    Every ``opt.validate_freq`` iterations it does the following:

    * evaluates the dev set;
    * saves up to ``opt.num_saved_specgram`` log-spectrogram plots;
    * writes a checkpoint ('model.loss.best' when the dev loss is not greater
      than the best seen so far).

    Raises:
        ValueError: if ``opt.opt_type`` is neither 'adadelta' nor 'adam'.
    """
    opt = TrainOptions().parse()

    visualizer = Visualizer(opt)
    logging = visualizer.get_logger()
    loss_report = visualizer.add_plot_report(['train/loss', 'val/loss'],
                                             'loss.png')

    # data
    logging.info("Building dataset.")
    dict_path = os.path.join(opt.dict_dir, 'train_units.txt')
    train_dataset = MixSequentialDataset(
        opt, os.path.join(opt.dataroot, 'train'), dict_path)
    val_dataset = MixSequentialDataset(
        opt, os.path.join(opt.dataroot, 'dev'), dict_path)
    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)
    train_loader = MixSequentialDataLoader(train_dataset,
                                           num_workers=opt.num_workers,
                                           batch_sampler=train_sampler)
    val_loader = MixSequentialDataLoader(val_dataset,
                                         batch_size=int(opt.batch_size / 2),
                                         num_workers=opt.num_workers,
                                         shuffle=False)
    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)
    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")

    lr = opt.lr
    eps = opt.eps
    iters = opt.iters
    best_loss = opt.best_loss
    start_epoch = opt.start_epoch
    model_path = None
    # NOTE(review): the option is spelled 'enhace_resume' here but
    # 'enhance_resume' in the joint-training script; confirm which name
    # TrainOptions actually defines.
    if opt.enhace_resume:
        model_path = os.path.join(opt.works_dir, opt.enhace_resume)
        if os.path.isfile(model_path):
            package = torch.load(model_path,
                                 map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)
            best_loss = package.get('best_loss', float('inf'))
            start_epoch = int(package.get('epoch', 0))
            iters = int(package.get('iters', 0))
            loss_report = package.get('loss_report', loss_report)
            visualizer.set_plot_report(loss_report, 'loss.png')
            # Fixed: was `print('...', format(...))`, which raised TypeError.
            print('package found at {} and start_epoch {} iters {}'.format(
                model_path, start_epoch, iters))
        else:
            print("no checkpoint found at {}".format(model_path))
    # model_path may be None or point to a missing file here; load_model is
    # presumably responsible for building a fresh model then — TODO confirm.
    enhance_model = EnhanceModel.load_model(model_path, 'enhance_state_dict',
                                            opt)

    # Setup an optimizer. Use the (possibly resumed) eps/lr, not the CLI ones.
    enhance_parameters = filter(lambda p: p.requires_grad,
                                enhance_model.parameters())
    if opt.opt_type == 'adadelta':
        enhance_optimizer = torch.optim.Adadelta(enhance_parameters,
                                                 rho=0.95,
                                                 eps=eps)
    elif opt.opt_type == 'adam':
        enhance_optimizer = torch.optim.Adam(enhance_parameters,
                                             lr=lr,
                                             betas=(opt.beta1, 0.999))
    else:
        raise ValueError('unsupported opt_type: {}'.format(opt.opt_type))

    def checkpoint_state():
        # Reads the enclosing variables at call time, so the dict always
        # reflects the current epoch/iteration/eps/best loss.
        return {
            'enhance_state_dict': enhance_model.state_dict(),
            'opt': opt,
            'epoch': epoch,
            'iters': iters,
            'eps': eps,
            'lr': lr,
            'best_loss': best_loss,
            'loss_report': loss_report,
        }

    # Training
    for epoch in range(start_epoch, opt.epochs):
        enhance_model.train()
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)
        for data in train_loader:
            (utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs,
             mix_log_inputs, cos_angles, targets, input_sizes,
             target_sizes) = data
            loss, enhance_out = enhance_model(clean_inputs, mix_inputs,
                                              mix_log_inputs, cos_angles,
                                              input_sizes)
            enhance_optimizer.zero_grad()  # Clear the parameter gradients
            loss.backward()
            # Clip, and skip the update when the gradient norm is NaN.
            grad_norm = torch.nn.utils.clip_grad_norm_(
                enhance_model.parameters(), opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                enhance_optimizer.step()

            iters += 1
            visualizer.set_current_errors({'train/loss': loss.item()})
            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                utils.save_checkpoint(checkpoint_state(), opt.exp_path,
                                      filename='latest')

            if iters % opt.validate_freq == 0:
                enhance_model.eval()
                num_saved_specgram = 0
                with torch.no_grad():
                    for data in tqdm(val_loader):
                        (utt_ids, spk_ids, clean_inputs, clean_log_inputs,
                         mix_inputs, mix_log_inputs, cos_angles, targets,
                         input_sizes, target_sizes) = data
                        loss, enhance_out = enhance_model(
                            clean_inputs, mix_inputs, mix_log_inputs,
                            cos_angles, input_sizes)
                        visualizer.set_current_errors(
                            {'val/loss': loss.item()})

                        # Save log10 spectrograms (clean / mixed / enhanced)
                        # until opt.num_saved_specgram plots exist.
                        if num_saved_specgram < opt.num_saved_specgram:
                            enhanced_outs = enhance_model.calculate_all_specgram(
                                mix_inputs, mix_log_inputs, input_sizes)
                            for x in range(len(utt_ids)):
                                file_name = "{}_ep{}_it{}.png".format(
                                    utt_ids[x], epoch, iters)
                                visualizer.plot_specgram(
                                    _log10_specgram(clean_inputs[x]),
                                    _log10_specgram(mix_inputs[x]),
                                    _log10_specgram(enhanced_outs[x]),
                                    int(input_sizes[x]), file_name)
                                num_saved_specgram += 1
                                if num_saved_specgram >= opt.num_saved_specgram:
                                    break
                enhance_model.train()

                visualizer.print_epoch_errors(epoch, iters)
                loss_report = visualizer.plot_epoch_errors(
                    epoch, iters, 'loss.png')
                val_loss = visualizer.get_current_errors('val/loss')
                filename = None
                if val_loss > best_loss:
                    print('val_loss {} > best_loss {}'.format(
                        val_loss, best_loss))
                    # Fixed: the original passed an undefined `optimizer`.
                    eps = utils.adadelta_eps_decay(enhance_optimizer,
                                                   opt.eps_decay)
                else:
                    filename = 'model.loss.best'
                best_loss = min(val_loss, best_loss)
                print('best_loss {}'.format(best_loss))

                utils.save_checkpoint(checkpoint_state(), opt.exp_path,
                                      filename=filename)
                visualizer.reset()
Example #2
0
def _feature_loss(loss_type, enhanced, clean):
    """Return the enhancement loss between ``enhanced`` and ``clean`` features.

    ``clean`` is detached so gradients only flow into ``enhanced``.

    Raises:
        ValueError: for a ``loss_type`` other than 'L2', 'L1' or 'smooth_L1'
            (previously the loss was simply left undefined).
    """
    target = clean.detach()
    if loss_type == 'L2':
        return F.mse_loss(enhanced, target)
    if loss_type == 'L1':
        return F.l1_loss(enhanced, target)
    if loss_type == 'smooth_L1':
        return F.smooth_l1_loss(enhanced, target)
    raise ValueError('unsupported enhance_loss_type: {}'.format(loss_type))


def _discriminator_input(netD_type, mix_feat, feat):
    """Build the discriminator input for ``feat``.

    A 'pixel' discriminator gets the mixture features concatenated with
    ``feat`` along dim 2; any other type gets ``feat`` alone.
    """
    if netD_type == 'pixel':
        return torch.cat((mix_feat, feat), 2)
    return feat


def main():
    """Jointly train the enhancement, fbank and ASR models (optionally with a GAN).

    All settings come from ``TrainOptions().parse()``. Each model can be
    resumed from its own checkpoint (``opt.enhance_resume`` / ``opt.asr_resume``)
    or all of them from one joint checkpoint (``opt.joint_resume``), which takes
    precedence. The training loss per step is::

        mtlalpha * ctc + (1 - mtlalpha) * att
        + enhance_loss_lambda * feature_loss + coral_loss_lambda * CORAL
        (+ gan_loss_lambda * generator GAN loss when opt.isGAN)

    When ``opt.isGAN`` is set, each step also includes one discriminator
    update. A 'latest' checkpoint is written every ``opt.print_freq``
    iterations. Every ``opt.validate_freq`` iterations it does the following:

    * evaluates the dev set and saves attention plots;
    * calls ``utils.adadelta_eps_decay`` on the ASR optimizer when the chosen
      criterion got worse;
    * writes a checkpoint.

    Raises:
        ValueError: for an unsupported ``opt.opt_type`` or
            ``opt.enhance_loss_type``.
    """
    opt = TrainOptions().parse()
    device = torch.device("cuda:{}".format(opt.gpu_ids[0])
                          if len(opt.gpu_ids) > 0 and torch.cuda.is_available()
                          else "cpu")

    visualizer = Visualizer(opt)
    logging = visualizer.get_logger()
    acc_report = visualizer.add_plot_report(['train/acc', 'val/acc'], 'acc.png')
    loss_report = visualizer.add_plot_report(
        ['train/loss', 'val/loss', 'train/enhance_loss', 'val/enhance_loss'],
        'loss.png')

    # data
    logging.info("Building dataset.")
    dict_path = os.path.join(opt.dict_dir, 'train_units.txt')
    train_dataset = MixSequentialDataset(
        opt, os.path.join(opt.dataroot, 'train'), dict_path)
    val_dataset = MixSequentialDataset(
        opt, os.path.join(opt.dataroot, 'dev'), dict_path)
    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)
    train_loader = MixSequentialDataLoader(train_dataset,
                                           num_workers=opt.num_workers,
                                           batch_sampler=train_sampler)
    val_loader = MixSequentialDataLoader(val_dataset,
                                         batch_size=int(opt.batch_size / 2),
                                         num_workers=opt.num_workers,
                                         shuffle=False)
    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)
    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")

    # Setup the models
    lr = opt.lr
    eps = opt.eps
    iters = opt.iters
    best_acc = opt.best_acc
    best_loss = opt.best_loss
    start_epoch = opt.start_epoch

    enhance_model = None
    if opt.enhance_resume:
        enhance_model_path = os.path.join(opt.works_dir, opt.enhance_resume)
        if os.path.isfile(enhance_model_path):
            enhance_model = EnhanceModel.load_model(
                enhance_model_path, 'enhance_state_dict', opt)
        else:
            print("no checkpoint found at {}".format(enhance_model_path))

    asr_model = None
    if opt.asr_resume:
        asr_model_path = os.path.join(opt.works_dir, opt.asr_resume)
        if os.path.isfile(asr_model_path):
            asr_model = ShareE2E.load_model(asr_model_path, 'asr_state_dict', opt)
        else:
            print("no checkpoint found at {}".format(asr_model_path))

    joint_model_path = None
    if opt.joint_resume:
        joint_model_path = os.path.join(opt.works_dir, opt.joint_resume)
        if os.path.isfile(joint_model_path):
            package = torch.load(joint_model_path,
                                 map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)
            best_acc = package.get('best_acc', 0)
            best_loss = package.get('best_loss', float('inf'))
            start_epoch = int(package.get('epoch', 0))
            # NOTE(review): one less than the saved count; intent unclear
            # from this file — confirm before changing.
            iters = int(package.get('iters', 0)) - 1
            print('joint_model_path {} and iters {}'.format(
                joint_model_path, iters))
        else:
            print("no checkpoint found at {}".format(joint_model_path))

    # A joint checkpoint overrides the per-model ones. Otherwise, build any
    # model not restored above; previously a missing per-model file left the
    # model unbound and crashed with NameError.
    if joint_model_path is not None or enhance_model is None:
        enhance_model = EnhanceModel.load_model(joint_model_path,
                                                'enhance_state_dict', opt)
    if joint_model_path is not None or asr_model is None:
        asr_model = ShareE2E.load_model(joint_model_path, 'asr_state_dict', opt)
    feat_model = FbankModel.load_model(joint_model_path, 'fbank_state_dict', opt)
    if opt.isGAN:
        gan_model = GANModel.load_model(joint_model_path, 'gan_state_dict', opt)

    # Setup the optimizers from the (possibly resumed) eps / lr.
    def make_optimizer(model):
        params = filter(lambda p: p.requires_grad, model.parameters())
        if opt.opt_type == 'adadelta':
            return torch.optim.Adadelta(params, rho=0.95, eps=eps)
        if opt.opt_type == 'adam':
            return torch.optim.Adam(params, lr=lr, betas=(opt.beta1, 0.999))
        raise ValueError('unsupported opt_type: {}'.format(opt.opt_type))

    enhance_optimizer = make_optimizer(enhance_model)
    asr_optimizer = make_optimizer(asr_model)
    if opt.isGAN:
        gan_optimizer = make_optimizer(gan_model)
        criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan).to(device)

    def checkpoint_state():
        # Reads the enclosing variables at call time. eps/lr are the values
        # actually in use (resumed or decayed), not the command-line ones.
        state = {'asr_state_dict': asr_model.state_dict(),
                 'fbank_state_dict': feat_model.state_dict(),
                 'enhance_state_dict': enhance_model.state_dict(),
                 'opt': opt, 'epoch': epoch, 'iters': iters,
                 'eps': eps, 'lr': lr,
                 'best_loss': best_loss, 'best_acc': best_acc,
                 'acc_report': acc_report, 'loss_report': loss_report}
        if opt.isGAN:
            state['gan_state_dict'] = gan_model.state_dict()
        return state

    # Training
    enhance_cmvn = compute_cmvn_epoch(opt, train_loader, enhance_model, feat_model)
    sample_rampup = utils.ScheSampleRampup(opt.sche_samp_start_iter,
                                           opt.sche_samp_final_iter,
                                           opt.sche_samp_final_rate)
    sche_samp_rate = sample_rampup.update(iters)

    enhance_model.train()
    feat_model.train()
    asr_model.train()
    for epoch in range(start_epoch, opt.epochs):
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)
        for data in train_loader:
            (utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs,
             mix_log_inputs, cos_angles, targets, input_sizes,
             target_sizes) = data
            enhance_out = enhance_model(mix_inputs, mix_log_inputs, input_sizes)
            enhance_feat = feat_model(enhance_out)
            clean_feat = feat_model(clean_inputs)
            mix_feat = feat_model(mix_inputs)
            enhance_loss = opt.enhance_loss_lambda * _feature_loss(
                opt.enhance_loss_type, enhance_feat, clean_feat)

            loss_ctc, loss_att, acc, clean_context, mix_context = asr_model(
                clean_feat, enhance_feat, targets, input_sizes, target_sizes,
                sche_samp_rate, enhance_cmvn)
            coral_loss = opt.coral_loss_lambda * CORAL(clean_context, mix_context)
            asr_loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
            loss = asr_loss + enhance_loss + coral_loss

            if opt.isGAN:
                # Generator term: D is frozen and the enhanced features are
                # scored against the "real" label.
                set_requires_grad([gan_model], False)
                fake_AB = _discriminator_input(opt.netD_type, mix_feat, enhance_feat)
                gan_loss = opt.gan_loss_lambda * criterionGAN(
                    gan_model(fake_AB, enhance_cmvn), True)
                loss = loss + gan_loss

            enhance_optimizer.zero_grad()
            asr_optimizer.zero_grad()  # Clear the parameter gradients
            loss.backward()
            # NOTE(review): only the ASR parameters are clipped / NaN-checked,
            # yet the enhancement optimizer steps too — confirm whether
            # enhance_model.parameters() should be included.
            grad_norm = torch.nn.utils.clip_grad_norm_(asr_model.parameters(),
                                                       opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                enhance_optimizer.step()
                asr_optimizer.step()

            if opt.isGAN:
                # Discriminator step on detached real/fake features.
                set_requires_grad([gan_model], True)
                gan_optimizer.zero_grad()
                fake_AB = _discriminator_input(opt.netD_type, mix_feat, enhance_feat)
                real_AB = _discriminator_input(opt.netD_type, mix_feat, clean_feat)
                loss_D_real = criterionGAN(gan_model(real_AB.detach(), enhance_cmvn), True)
                loss_D_fake = criterionGAN(gan_model(fake_AB.detach(), enhance_cmvn), False)
                loss_D = (loss_D_real + loss_D_fake) * 0.5
                loss_D.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(gan_model.parameters(),
                                                           opt.grad_clip)
                if math.isnan(grad_norm):
                    logging.warning('grad norm is nan. Do not update model.')
                else:
                    gan_optimizer.step()

            iters += 1
            errors = {'train/loss': loss.item(), 'train/loss_ctc': loss_ctc.item(),
                      'train/acc': acc, 'train/loss_att': loss_att.item(),
                      'train/enhance_loss': enhance_loss.item(),
                      'train/coral_loss': coral_loss.item()}
            if opt.isGAN:
                errors['train/loss_D'] = loss_D.item()
                # gan_loss already includes gan_loss_lambda (previously it
                # was applied a second time here).
                errors['train/gan_loss'] = gan_loss.item()
            visualizer.set_current_errors(errors)

            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                utils.save_checkpoint(checkpoint_state(), opt.exp_path,
                                      filename='latest')

            if iters % opt.validate_freq == 0:
                sche_samp_rate = sample_rampup.update(iters)
                print("iters {} sche_samp_rate {}".format(iters, sche_samp_rate))
                enhance_model.eval()
                feat_model.eval()
                asr_model.eval()
                num_saved_attention = 0
                with torch.no_grad():
                    for data in tqdm(val_loader):
                        (utt_ids, spk_ids, clean_inputs, clean_log_inputs,
                         mix_inputs, mix_log_inputs, cos_angles, targets,
                         input_sizes, target_sizes) = data
                        enhance_out = enhance_model(mix_inputs, mix_log_inputs,
                                                    input_sizes)
                        enhance_feat = feat_model(enhance_out)
                        clean_feat = feat_model(clean_inputs)
                        mix_feat = feat_model(mix_inputs)
                        enhance_loss = _feature_loss(opt.enhance_loss_type,
                                                     enhance_feat, clean_feat)
                        if opt.isGAN:
                            set_requires_grad([gan_model], False)
                            fake_AB = _discriminator_input(opt.netD_type,
                                                           mix_feat, enhance_feat)
                            gan_loss = criterionGAN(gan_model(fake_AB, enhance_cmvn), True)
                            # NOTE(review): unlike training, the GAN term is
                            # also scaled by enhance_loss_lambda below and
                            # CORAL is not included; confirm intended.
                            enhance_loss = enhance_loss + opt.gan_loss_lambda * gan_loss

                        loss_ctc, loss_att, acc, clean_context, mix_context = asr_model(
                            clean_feat, enhance_feat, targets, input_sizes,
                            target_sizes, 0.0, enhance_cmvn)
                        asr_loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
                        enhance_loss = opt.enhance_loss_lambda * enhance_loss
                        loss = asr_loss + enhance_loss
                        errors = {'val/loss': loss.item(), 'val/loss_ctc': loss_ctc.item(),
                                  'val/acc': acc, 'val/loss_att': loss_att.item(),
                                  'val/enhance_loss': enhance_loss.item()}
                        if opt.isGAN:
                            errors['val/gan_loss'] = opt.gan_loss_lambda * gan_loss.item()
                        visualizer.set_current_errors(errors)

                        if opt.num_save_attention > 0 and opt.mtlalpha != 1.0:
                            if num_saved_attention < opt.num_save_attention:
                                att_ws = asr_model.calculate_all_attentions(
                                    enhance_feat, targets, input_sizes,
                                    target_sizes, enhance_cmvn)
                                for x in range(len(utt_ids)):
                                    file_name = "{}_ep{}_it{}.png".format(
                                        utt_ids[x], epoch, iters)
                                    visualizer.plot_attention(
                                        att_ws[x], int(target_sizes[x]),
                                        int(input_sizes[x]), file_name)
                                    num_saved_attention += 1
                                    if num_saved_attention >= opt.num_save_attention:
                                        break
                enhance_model.train()
                feat_model.train()
                asr_model.train()

                visualizer.print_epoch_errors(epoch, iters)
                acc_report = visualizer.plot_epoch_errors(epoch, iters, 'acc.png')
                loss_report = visualizer.plot_epoch_errors(epoch, iters, 'loss.png')
                val_loss = visualizer.get_current_errors('val/loss')
                val_acc = visualizer.get_current_errors('val/acc')
                filename = None
                # `!=` rather than `is not`: string identity is not equality.
                if opt.criterion == 'acc' and opt.mtl_mode != 'ctc':
                    if val_acc < best_acc:
                        logging.info('val_acc {} < best_acc {}'.format(val_acc, best_acc))
                        eps = utils.adadelta_eps_decay(asr_optimizer, opt.eps_decay)
                        opt.eps = eps
                    else:
                        filename = 'model.acc.best'
                    best_acc = max(best_acc, val_acc)
                    logging.info('best_acc {}'.format(best_acc))
                elif opt.criterion == 'loss':  # was `args.criterion` (NameError)
                    if val_loss > best_loss:
                        logging.info('val_loss {} > best_loss {}'.format(val_loss, best_loss))
                        eps = utils.adadelta_eps_decay(asr_optimizer, opt.eps_decay)
                        opt.eps = eps
                    else:
                        filename = 'model.loss.best'
                    best_loss = min(val_loss, best_loss)
                    logging.info('best_loss {}'.format(best_loss))
                utils.save_checkpoint(checkpoint_state(), opt.exp_path,
                                      filename=filename)
                visualizer.reset()
                enhance_cmvn = compute_cmvn_epoch(opt, train_loader,
                                                  enhance_model, feat_model)
def main():
    
    opt = TrainOptions().parse()    
    device = torch.device("cuda:{}".format(opt.gpu_ids[0]) if len(opt.gpu_ids) > 0 and torch.cuda.is_available() else "cpu")
     
    visualizer = Visualizer(opt)  
    logging = visualizer.get_logger()
    acc_report = visualizer.add_plot_report(['train/acc', 'val/acc'], 'acc.png')
    loss_report = visualizer.add_plot_report(['train/loss', 'val/loss'], 'loss.png')
     
    # data
    logging.info("Building dataset.")
    # train目录 和 dict目录,作为输入
    train_dataset = SequentialDataset(opt, os.path.join(opt.dataroot, 'train'), os.path.join(opt.dict_dir, 'train_units.txt'),) 
    val_dataset = SequentialDataset(opt, os.path.join(opt.dataroot, 'dev'), os.path.join(opt.dict_dir, 'train_units.txt'),)    
    
    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size) 
    
    train_loader = SequentialDataLoader(train_dataset, num_workers=opt.num_workers, batch_sampler=train_sampler)
    val_loader = SequentialDataLoader(val_dataset, batch_size=int(opt.batch_size/2), num_workers=opt.num_workers, shuffle=False)
    
    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)
    
    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")
    
    # Setup a model
    asr_model = E2E(opt)
    ##fbank_model = FbankModel(opt)
    lr = opt.lr       # default=0.005
    eps = opt.eps     # default=1e-8
    iters = opt.iters # default=0
    start_epoch = opt.start_epoch     # default=0
    best_loss = opt.best_loss # default=float('inf')
    best_acc = opt.best_acc # default=0
    
    if opt.resume:
        # 如果有中继点
        model_path = os.path.join(opt.works_dir, opt.resume)
        if os.path.isfile(model_path):
            package = torch.load(model_path, map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)        
            best_loss = package.get('best_loss', float('inf'))
            best_acc = package.get('best_acc', 0)
            start_epoch = int(package.get('epoch', 0))   
            iters = int(package.get('iters', 0))
            
            acc_report = package.get('acc_report', acc_report)
            loss_report = package.get('loss_report', loss_report)
            visualizer.set_plot_report(acc_report, 'acc.png')
            visualizer.set_plot_report(loss_report, 'loss.png')
            
            asr_model = E2E.load_model(model_path, 'asr_state_dict') 
            ##fbank_model = FbankModel.load_model(model_path, 'fbank_state_dict') 
            logging.info('Loading model {} and iters {}'.format(model_path, iters))
        else:
            print("no checkpoint found at {}".format(model_path)) 
    # convert to cuda               
    asr_model.cuda()
    ##fbank_model.cuda()
    print(asr_model)
  
    # Setup an optimizer
    #parameters = filter(lambda p: p.requires_grad, itertools.chain(asr_model.parameters(), fbank_model.parameters()))
    parameters = filter(lambda p: p.requires_grad, itertools.chain(asr_model.parameters())) # ?
    if opt.opt_type == 'adadelta':
        optimizer = torch.optim.Adadelta(parameters, rho=0.95, eps=eps)
    elif opt.opt_type == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=lr, betas=(opt.beta1, 0.999))                       
           
    asr_model.train()
    #fbank_model.train()    
    sample_rampup = utils.ScheSampleRampup(opt.sche_samp_start_iter, opt.sche_samp_final_iter, opt.sche_samp_final_rate)  
    sche_samp_rate = sample_rampup.update(iters)
    
    '''fbank_cmvn_file = os.path.join(opt.exp_path, 'fbank_cmvn.npy')
    if os.path.exists(fbank_cmvn_file):
        fbank_cmvn = np.load(fbank_cmvn_file)
    else:
        for i, (data) in enumerate(train_loader, start=0):
            utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
            fbank_cmvn = fbank_model.compute_cmvn(inputs, input_sizes)
            if fbank_cmvn is not None:
                np.save(fbank_cmvn_file, fbank_cmvn)
                print('save fbank_cmvn to {}'.format(fbank_cmvn_file))
                break
    fbank_cmvn = torch.FloatTensor(fbank_cmvn)'''
                     
    for epoch in range(start_epoch, opt.epochs):
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)                    
        for i, (data) in enumerate(train_loader, start=(iters*opt.batch_size)%len(train_dataset)):
            #utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
            #fbank_features = fbank_model(inputs, fbank_cmvn)
            utt_ids, spk_ids, fbank_features, targets, input_sizes, target_sizes = data
            
            loss_ctc, loss_att, acc, context = asr_model(fbank_features, targets, input_sizes, target_sizes, sche_samp_rate) 
            loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
            
            optimizer.zero_grad()  # Clear the parameter gradients
            loss.backward()        # compute backwards
                      
            # compute the gradient norm to check if it is normal or not 'fbank_state_dict': fbank_model.state_dict(), 
            grad_norm = torch.nn.utils.clip_grad_norm_(asr_model.parameters(), opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                optimizer.step()
                
            iters += 1
            errors = {'train/loss': loss.item(), 'train/loss_ctc': loss_ctc.item(), 
                      'train/acc': acc, 'train/loss_att': loss_att.item()}
            visualizer.set_current_errors(errors)
            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                state = {'asr_state_dict': asr_model.state_dict(), 
                         'opt': opt, 'epoch': epoch, 'iters': iters, 
                         'eps': opt.eps, 'lr': opt.lr,                                    
                         'best_loss': best_loss, 'best_acc': best_acc, 
                         'acc_report': acc_report, 'loss_report': loss_report}
                filename='latest'
                utils.save_checkpoint(state, opt.exp_path, filename=filename)
                    
            if iters % opt.validate_freq == 0:
                sche_samp_rate = sample_rampup.update(iters)
                print("iters {} sche_samp_rate {}".format(iters, sche_samp_rate))  
                asr_model.eval()
                #fbank_model.eval() 
                torch.set_grad_enabled(False)
                num_saved_attention = 0
                for i, (data) in tqdm(enumerate(val_loader, start=0)):
                    #utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
                    #fbank_features = fbank_model(inputs, fbank_cmvn)
                    utt_ids, spk_ids, fbank_features, targets, input_sizes, target_sizes = data
                    loss_ctc, loss_att, acc, context = asr_model(fbank_features, targets, input_sizes, target_sizes, 0.0) 
                    loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att                            
                    errors = {'val/loss': loss.item(), 'val/loss_ctc': loss_ctc.item(), 
                              'val/acc': acc, 'val/loss_att': loss_att.item()}
                    visualizer.set_current_errors(errors)
                    
                    if opt.num_save_attention > 0 and opt.mtlalpha != 1.0:
                        if num_saved_attention < opt.num_save_attention:
                            att_ws = asr_model.calculate_all_attentions(fbank_features, targets, input_sizes, target_sizes)                            
                            for x in range(len(utt_ids)):
                                att_w = att_ws[x]
                                utt_id = utt_ids[x]
                                file_name = "{}_ep{}_it{}.png".format(utt_id, epoch, iters)
                                dec_len = int(target_sizes[x])
                                enc_len = int(input_sizes[x]) 
                                visualizer.plot_attention(att_w, dec_len, enc_len, file_name) 
                                num_saved_attention += 1
                                if num_saved_attention >= opt.num_save_attention:   
                                    break                                                       
                asr_model.train()
                #fbank_model.train() 
                torch.set_grad_enabled(True)
                
                visualizer.print_epoch_errors(epoch, iters)  
                acc_report = visualizer.plot_epoch_errors(epoch, iters, 'acc.png') 
                loss_report = visualizer.plot_epoch_errors(epoch, iters, 'loss.png') 
                val_loss = visualizer.get_current_errors('val/loss')
                val_acc = visualizer.get_current_errors('val/acc')  
                filename = None              
                if opt.criterion == 'acc' and opt.mtl_mode is not 'ctc':
                    if val_acc < best_acc:
                        logging.info('val_acc {} > best_acc {}'.format(val_acc, best_acc))
                        opt.eps = utils.adadelta_eps_decay(optimizer, opt.eps_decay)
                    else:
                        filename='model.acc.best'                    
                    best_acc = max(best_acc, val_acc)
                    logging.info('best_acc {}'.format(best_acc))  
                elif args.criterion == 'loss':
                    if val_loss > best_loss:
                        logging.info('val_loss {} > best_loss {}'.format(val_loss, best_loss))
                        opt.eps = utils.adadelta_eps_decay(optimizer, opt.eps_decay)
                    else:
                        filename='model.loss.best'    
                    best_loss = min(val_loss, best_loss)
                    logging.info('best_loss {}'.format(best_loss))                  
                state = {'asr_state_dict': asr_model.state_dict(), 
                         'opt': opt, 'epoch': epoch, 'iters': iters, 
                         'eps': opt.eps, 'lr': opt.lr,                                    
                         'best_loss': best_loss, 'best_acc': best_acc, 
                         'acc_report': acc_report, 'loss_report': loss_report}
                utils.save_checkpoint(state, opt.exp_path, filename=filename)    
                ##filename='epoch-{}_iters-{}_loss-{:.4f}_acc-{:.4f}.pth'.format(epoch, iters, val_loss, val_acc)
                ##utils.save_checkpoint(state, opt.exp_path, filename=filename)                  
                visualizer.reset() 
def _enhance_gan_checkpoint_state(opt, enhance_model, gan_model, feat_model,
                                  epoch, iters, best_loss, loss_report):
    """Return the checkpoint dict written by the enhancement-GAN trainer.

    The feature-model weights are stored under 'fft_state_dict' or
    'fbank_state_dict' depending on ``opt.enhance_opt_type``. These are the
    same keys ``main`` passes to ``FFTModel.load_model`` /
    ``FbankModel.load_model`` when resuming.
    """
    state = {
        'enhance_state_dict': enhance_model.state_dict(),
        'gan_state_dict': gan_model.state_dict(),
        'opt': opt,
        'epoch': epoch,
        'iters': iters,
        'eps': opt.eps,
        'lr': opt.lr,
        'best_loss': best_loss,
        'loss_report': loss_report
    }
    if opt.enhance_opt_type == 'gan_fft':
        state['fft_state_dict'] = feat_model.state_dict()
    elif opt.enhance_opt_type == 'gan_fbank':
        state['fbank_state_dict'] = feat_model.state_dict()
    return state


def _clip_grads_and_step(parameters, optimizer, grad_clip, logger):
    """Clip gradients of *parameters* to total norm *grad_clip*, then step.

    The step is skipped (with a warning) when the gradient norm returned by
    ``clip_grad_norm_`` is NaN, so a bad batch cannot corrupt the weights.
    """
    grad_norm = torch.nn.utils.clip_grad_norm_(parameters, grad_clip)
    if math.isnan(grad_norm):
        logger.warning('grad norm is nan. Do not update model.')
    else:
        optimizer.step()


def _floored_log10(tensor, floor=1e-7):
    """Return ``log10(max(tensor, floor))`` as a NumPy array for plotting.

    ``np.maximum`` allocates a new array. In-place masking on
    ``tensor.data.cpu().numpy()`` would instead write through to the tensor
    whenever it already lives on the CPU, because ``Tensor.numpy()`` shares
    storage with the tensor.
    """
    return np.log10(np.maximum(tensor.data.cpu().numpy(), floor))


def _enhance_gan_validate(opt, val_loader, enhance_model, visualizer, epoch,
                          iters):
    """Evaluate *enhance_model* on *val_loader* without gradient tracking.

    Records 'val/loss' for every batch via ``visualizer.set_current_errors``.
    Plots up to ``opt.num_saved_specgram`` clean/mixed/enhanced log10
    spectrograms named ``{utt_id}_ep{epoch}_it{iters}.png``. Switching
    models between eval and train mode is left to the caller.
    """
    num_saved_specgram = 0
    # no_grad restores the previous grad mode even if validation raises.
    with torch.no_grad():
        for data in tqdm(val_loader):
            (utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs,
             mix_log_inputs, cos_angles, targets, input_sizes,
             target_sizes) = data
            loss, _ = enhance_model(clean_inputs=clean_inputs,
                                    mix_inputs=mix_inputs,
                                    mix_log_inputs=mix_log_inputs,
                                    cos_angles=cos_angles,
                                    input_sizes=input_sizes)
            visualizer.set_current_errors({'val/loss': loss.item()})

            # Also covers opt.num_saved_specgram <= 0 (nothing to plot).
            if num_saved_specgram >= opt.num_saved_specgram:
                continue
            enhanced_outs = enhance_model.calculate_all_specgram(
                mix_inputs, mix_log_inputs, input_sizes)
            for x in range(len(utt_ids)):
                file_name = "{}_ep{}_it{}.png".format(utt_ids[x], epoch,
                                                      iters)
                visualizer.plot_specgram(_floored_log10(clean_inputs[x]),
                                         _floored_log10(mix_inputs[x]),
                                         _floored_log10(enhanced_outs[x]),
                                         int(input_sizes[x]), file_name)
                num_saved_specgram += 1
                if num_saved_specgram >= opt.num_saved_specgram:
                    break


def main():
    """Train the speech-enhancement model adversarially against a discriminator.

    Options come from ``fake_opt.Enhance_gan_train()``. Each training batch
    performs two updates:

    * generator: ``enhance_model`` and the trainable parameters of
      ``feat_model`` minimise ``enhance_loss + opt.gan_loss_lambda * gan_loss``.
      Here ``gan_loss`` is ``criterionGAN`` applied to the frozen
      discriminator's output on the enhanced features with target ``True``.
    * discriminator: ``gan_model`` minimises the mean of ``criterionGAN`` on
      detached clean features (target ``True``) and detached enhanced
      features (target ``False``).

    Every ``opt.print_freq`` iterations a ``'latest'`` checkpoint is saved.
    Every ``opt.validate_freq`` iterations the dev set is evaluated:

    * If the validation loss does not exceed the best so far, the checkpoint
      is saved as ``'model.loss.best'``.
    * Otherwise the enhance optimizer's eps is decayed via
      ``utils.adadelta_eps_decay`` and the checkpoint is saved with
      ``filename=None``.

    In both cases the CMVN statistics are then recomputed with
    ``compute_cmvn_epoch``. Training resumes from
    ``opt.works_dir/opt.enhance_resume`` when that file exists.

    Raises:
        NotImplementedError: if ``opt.enhance_opt_type`` or ``opt.opt_type``
            is not a supported value.
    """
    opt = fake_opt.Enhance_gan_train()

    device = torch.device("cuda:{}".format(opt.gpu_ids[0]) if len(opt.gpu_ids)
                          > 0 and torch.cuda.is_available() else "cpu")

    visualizer = Visualizer(opt)
    logger = visualizer.get_logger()
    loss_report = visualizer.add_plot_report([
        'train/loss', 'val/loss', 'train/gan_loss', 'train/enhance_loss',
        'train/loss_D'
    ], 'loss.png')

    # data
    logger.info("Building dataset.")
    vocab_path = os.path.join(opt.dict_dir, 'train/vocab')
    train_dataset = MixSequentialDataset(opt,
                                         os.path.join(opt.dataroot, 'train'),
                                         vocab_path)
    val_dataset = MixSequentialDataset(opt, os.path.join(opt.dataroot, 'dev'),
                                       vocab_path)
    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)
    train_loader = MixSequentialDataLoader(train_dataset,
                                           num_workers=opt.num_workers,
                                           batch_sampler=train_sampler)
    val_loader = MixSequentialDataLoader(val_dataset,
                                         batch_size=int(opt.batch_size / 2),
                                         num_workers=opt.num_workers,
                                         shuffle=False)
    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)
    logger.info('#input dims : ' + str(opt.idim))
    logger.info('#output dims: ' + str(opt.odim))
    logger.info("Dataset ready!")

    # Setup an model
    lr = opt.lr
    eps = opt.eps
    iters = opt.iters
    best_loss = opt.best_loss
    start_epoch = opt.start_epoch
    model_path = None
    if opt.enhance_resume:
        model_path = os.path.join(opt.works_dir, opt.enhance_resume)
        if os.path.isfile(model_path):
            package = torch.load(model_path,
                                 map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)
            best_loss = package.get('best_loss', float('inf'))
            start_epoch = int(package.get('epoch', 0))
            iters = int(package.get('iters', 0))
            loss_report = package.get('loss_report', loss_report)
            visualizer.set_plot_report(loss_report, 'loss.png')
            # Checkpoints record opt.lr / opt.eps, so keep them in sync with
            # the restored values the optimizers below are built from.
            opt.lr, opt.eps = lr, eps
        else:
            print("no checkpoint found at {}".format(model_path))
            # Start from scratch instead of handing a nonexistent path to
            # load_model. NOTE(review): assumes load_model treats None as
            # "no checkpoint" (model_path defaults to None above) — confirm.
            model_path = None
    enhance_model = EnhanceModel.load_model(model_path, 'enhance_state_dict',
                                            opt)
    gan_model = GANModel.load_model(model_path, 'gan_state_dict', opt)
    print(gan_model)
    if opt.enhance_opt_type == 'gan_fft':
        feat_model = FFTModel.load_model(model_path, 'fft_state_dict', opt)
    elif opt.enhance_opt_type == 'gan_fbank':
        feat_model = FbankModel.load_model(model_path, 'fbank_state_dict', opt)
    else:
        raise NotImplementedError('enhance_opt_type [%s] is not recognized' %
                                  opt.enhance_opt_type)

    # Setup an optimizer
    # The generator optimizer updates both the enhancement and feature models.
    enhance_parameters = filter(
        lambda p: p.requires_grad,
        itertools.chain(enhance_model.parameters(), feat_model.parameters()))
    gan_parameters = filter(lambda p: p.requires_grad, gan_model.parameters())
    if opt.opt_type == 'adadelta':
        enhance_optimizer = torch.optim.Adadelta(enhance_parameters,
                                                 rho=0.95,
                                                 eps=eps)
        gan_optimizer = torch.optim.Adadelta(gan_parameters, rho=0.95, eps=eps)
    elif opt.opt_type == 'adam':
        enhance_optimizer = torch.optim.Adam(enhance_parameters,
                                             lr=lr,
                                             betas=(opt.beta1, 0.999))
        gan_optimizer = torch.optim.Adam(gan_parameters,
                                         lr=lr,
                                         betas=(opt.beta1, 0.999))
    else:
        raise NotImplementedError('opt_type [%s] is not recognized' %
                                  opt.opt_type)
    criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan).to(device)

    # Training
    # NOTE(review): machine-specific default path. When the file is absent,
    # recompute the statistics instead of crashing in np.load.
    enhance_cmvn_path = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/enhance_gan_train_change_param/enhance_cmvn.npy'
    if enhance_cmvn_path and os.path.isfile(enhance_cmvn_path):
        enhance_cmvn = torch.FloatTensor(np.load(enhance_cmvn_path))
    else:
        enhance_cmvn = compute_cmvn_epoch(opt, train_loader, enhance_model,
                                          feat_model)
    enhance_model.train()
    feat_model.train()
    gan_model.train()
    for epoch in range(start_epoch, opt.epochs):
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)
        for data in train_loader:
            (utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs,
             mix_log_inputs, cos_angles, targets, input_sizes,
             target_sizes) = data

            # Generator update: discriminator frozen, enhanced features
            # pushed towards the "real" label.
            enhance_loss, enhance_out = enhance_model(
                clean_inputs=clean_inputs,
                mix_inputs=mix_inputs,
                mix_log_inputs=mix_log_inputs,
                cos_angles=cos_angles,
                input_sizes=input_sizes)
            enhance_feat = feat_model(enhance_out, enhance_cmvn)
            clean_feat = feat_model(clean_inputs, enhance_cmvn)
            set_requires_grad([gan_model], False)
            gan_loss = criterionGAN(gan_model(enhance_feat), True)
            enhance_optimizer.zero_grad()
            loss = enhance_loss + opt.gan_loss_lambda * gan_loss
            loss.backward()
            # Clip/NaN-check every parameter enhance_optimizer updates,
            # including the feature model's.
            _clip_grads_and_step(
                itertools.chain(enhance_model.parameters(),
                                feat_model.parameters()),
                enhance_optimizer, opt.grad_clip, logger)

            # Discriminator update on detached features (no generator grads).
            set_requires_grad([gan_model], True)
            gan_optimizer.zero_grad()
            loss_D_real = criterionGAN(gan_model(clean_feat.detach()), True)
            loss_D_fake = criterionGAN(gan_model(enhance_feat.detach()), False)
            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward()
            _clip_grads_and_step(gan_model.parameters(), gan_optimizer,
                                 opt.grad_clip, logger)

            iters += 1
            visualizer.set_current_errors({
                'train/loss': loss.item(),
                'train/gan_loss': gan_loss.item(),
                'train/enhance_loss': enhance_loss.item(),
                'train/loss_D': loss_D.item()
            })
            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                state = _enhance_gan_checkpoint_state(
                    opt, enhance_model, gan_model, feat_model, epoch, iters,
                    best_loss, loss_report)
                utils.save_checkpoint(state, opt.exp_path, filename='latest')

            if iters % opt.validate_freq == 0:
                for model in (enhance_model, feat_model, gan_model):
                    model.eval()
                _enhance_gan_validate(opt, val_loader, enhance_model,
                                      visualizer, epoch, iters)
                for model in (enhance_model, feat_model, gan_model):
                    model.train()

                visualizer.print_epoch_errors(epoch, iters)
                loss_report = visualizer.plot_epoch_errors(
                    epoch, iters, 'loss.png')
                val_loss = visualizer.get_current_errors('val/loss')
                filename = None
                if val_loss > best_loss:
                    print('val_loss {} > best_loss {}'.format(
                        val_loss, best_loss))
                    opt.eps = utils.adadelta_eps_decay(enhance_optimizer,
                                                       opt.eps_decay)
                else:
                    filename = 'model.loss.best'
                best_loss = min(val_loss, best_loss)
                print('best_loss {}'.format(best_loss))

                state = _enhance_gan_checkpoint_state(
                    opt, enhance_model, gan_model, feat_model, epoch, iters,
                    best_loss, loss_report)
                utils.save_checkpoint(state, opt.exp_path, filename=filename)
                visualizer.reset()
                enhance_cmvn = compute_cmvn_epoch(opt, train_loader,
                                                  enhance_model, feat_model)