コード例 #1
0
def main():
    """Decode the recognition set with the enhancement + fbank + ASR chain.

    Restores the enhancement, fbank and ASR models from hard-coded checkpoint
    paths, optionally builds a language model selected by ``opt.lmtype``,
    runs ``asr_model.recognize`` on every batch yielded by ``recog_loader``
    and writes the collected results as JSON to ``opt.result_label``.

    ``opt`` and ``recog_loader`` are not defined in this function; they must
    exist at module level.

    Raises:
        Exception: if ``opt.resume`` is falsy, or if the fbank CMVN file does
            not exist.
    """
    # Checkpoint locations are hard-coded for this experiment.
    enhance_model_path = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/enhance_fbank_train_table_2/model.loss.best'
    asr_mode_path = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/asr_clean_train_table3/model.acc.best'
    # The fbank front-end is restored from the same file as the ASR model.
    feat_model_path = asr_mode_path

    if not opt.resume:
        raise Exception("no checkpoint found at {}".format(opt.resume))
    enhance_model = EnhanceModel.load_model(enhance_model_path,
                                            'enhance_state_dict', opt)
    feat_model = FbankModel.load_model(feat_model_path, 'fbank_state_dict',
                                       opt)
    asr_model = E2E.load_model(asr_mode_path, 'asr_state_dict', opt)

    def cpu_loader(storage, location):
        """``map_location`` hook for ``torch.load``: return storage as-is."""
        return storage

    # Build the optional language model(s); both names are always bound.
    if opt.lmtype == 'rnnlm':
        # Character-level RNNLM.
        if opt.rnnlm:
            rnnlm = lm.ClassifierWithState(
                lm.RNNLM(len(opt.char_list), 300, 650))
            # NOTE: the configured path is overridden by a hard-coded one.
            opt.rnnlm = "/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/rnnlm_train_shi/rnnlm.model.best"
            rnnlm.load_state_dict(
                torch.load(opt.rnnlm, map_location=cpu_loader))
            if len(opt.gpu_ids) > 0:
                rnnlm = rnnlm.cuda()
            print('load RNNLM from {}'.format(opt.rnnlm))
            rnnlm.eval()
        else:
            rnnlm = None

        # Optional word-level RNNLM, combined with the character LM if any.
        if opt.word_rnnlm:
            if not opt.word_dict:
                logging.error(
                    'word dictionary file is not specified for the word RNNLM.'
                )
                sys.exit(1)

            word_dict = load_labeldict(opt.word_dict)
            char_dict = {x: i for i, x in enumerate(opt.char_list)}
            word_rnnlm = lm.ClassifierWithState(lm.RNNLM(len(word_dict), 650))
            word_rnnlm.load_state_dict(
                torch.load(opt.word_rnnlm, map_location=cpu_loader))
            word_rnnlm.eval()

            if rnnlm is not None:
                rnnlm = lm.ClassifierWithState(
                    extlm.MultiLevelLM(word_rnnlm.predictor, rnnlm.predictor,
                                       word_dict, char_dict))
            else:
                rnnlm = lm.ClassifierWithState(
                    extlm.LookAheadWordLM(word_rnnlm.predictor, word_dict,
                                          char_dict))
        fstlm = None

    elif opt.lmtype == 'fsrnnlm':
        if opt.rnnlm:
            rnnlm = lm.ClassifierWithState(
                fsrnn.FSRNNLM(len(opt.char_list), 300, opt.fast_layers,
                              opt.fast_cell_size, opt.slow_cell_size,
                              opt.zoneout_keep_h, opt.zoneout_keep_c))
            rnnlm.load_state_dict(
                torch.load(opt.rnnlm, map_location=cpu_loader))
            if len(opt.gpu_ids) > 0:
                rnnlm = rnnlm.cuda()
            print('load fsrnn from {}'.format(opt.rnnlm))
            rnnlm.eval()
        else:
            rnnlm = None
            print('not load fsrnn from {}'.format(opt.rnnlm))
        fstlm = None

    elif opt.lmtype == 'fstlm':
        if opt.fstlm_path:
            fstlm = NgramFstLM(opt.fstlm_path, opt.nn_char_map_file, 20)
        else:
            fstlm = None
        rnnlm = None
    else:
        rnnlm = None
        fstlm = None

    # CMVN statistics handed to the fbank model on every call.
    fbank_cmvn_file = '/usr/home/wudamu/Documents/Robust_e2e_gan-master/checkpoints/decode_asr_train_table3_1/decode_clean/fbank_cmvn.npy'
    if not os.path.exists(fbank_cmvn_file):
        raise Exception("no found at {}".format(fbank_cmvn_file))
    fbank_cmvn = torch.FloatTensor(np.load(fbank_cmvn_file))

    # Decoding only: gradients are disabled for the rest of the process.
    torch.set_grad_enabled(False)
    new_json = {}
    for data in recog_loader:
        utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
        # Only the first utterance id of the batch is used as the result key.
        name = utt_ids[0]
        print(name)

        enhance_outputs = enhance_model(inputs, log_inputs, input_sizes)
        enhance_feat = feat_model(enhance_outputs, fbank_cmvn)

        nbest_hyps = asr_model.recognize(enhance_feat,
                                         opt,
                                         opt.char_list,
                                         rnnlm=rnnlm,
                                         fstlm=fstlm)
        # Take the 1-best hypothesis and drop its first token (sos).
        y_hat = nbest_hyps[0]['yseq'][1:]
        print(y_hat)
        y_true = targets

        # Map token ids back to symbols and readable text.
        seq_hat = [opt.char_list[int(idx)] for idx in y_hat]
        seq_true = [opt.char_list[int(idx)] for idx in y_true]
        seq_hat_text = "".join(seq_hat).replace('<space>', ' ')
        seq_true_text = "".join(seq_true).replace('<space>', ' ')
        # Pass the text as a logging argument: concatenating it into the
        # format string breaks as soon as a transcript contains '%'.
        logging.info("groundtruth[%s]: %s", name, seq_true_text)
        logging.info("prediction [%s]: %s", name, seq_hat_text)

        new_json[name] = dict()
        new_json[name]['utt2spk'] = spk_ids[0]

        # Reference and recognised sequences for this utterance.
        logging.debug("dump token id")
        out_dic = dict()
        out_dic['name'] = 'target1'
        out_dic['text'] = seq_true_text
        out_dic['token'] = " ".join(seq_true)
        out_dic['tokenid'] = " ".join([str(int(idx)) for idx in y_true])

        # TODO(karita) make consistent to chainer as idx[0] not idx
        out_dic['rec_tokenid'] = " ".join([str(int(idx)) for idx in y_hat])
        out_dic['rec_token'] = " ".join(seq_hat)
        out_dic['rec_text'] = seq_hat_text

        new_json[name]['output'] = [out_dic]
        # TODO(nelson): Modify this part when saving more than 1 hyp is enabled
        # Add n-best recognition results with scores.
        if opt.beam_size > 1 and len(nbest_hyps) > 1:
            for rank, hyp in enumerate(nbest_hyps):
                tag = '[' + '{:05d}'.format(rank) + ']'
                hyp_ids = hyp['yseq'][1:]
                hyp_tokens = [opt.char_list[int(idx)] for idx in hyp_ids]
                hyp_text = "".join(hyp_tokens).replace('<space>', ' ')
                new_json[name]['rec_tokenid' + tag] = " ".join(
                    [str(idx) for idx in hyp_ids])
                new_json[name]['rec_token' + tag] = " ".join(hyp_tokens)
                new_json[name]['rec_text' + tag] = hyp_text
                new_json[name]['score' + tag] = float(hyp['score'])

    # TODO(watanabe) fix character coding problems when saving it
    with open(opt.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_json
            }, indent=4, sort_keys=True).encode('utf_8'))
コード例 #2
0
def main():
    """Train the enhancement model to match clean features from the fbank model.

    Options come from ``fake_opt.Enhance_fbank_train()``. The mixed input is
    passed through ``enhance_model`` and ``feat_model`` and compared with
    ``feat_model`` applied to the clean input, using the loss selected by
    ``opt.enhance_loss_type``. A ``latest`` checkpoint is written every
    ``opt.print_freq`` iterations; every ``opt.validate_freq`` iterations the
    validation set is evaluated and a second checkpoint is written, named
    ``model.loss.best`` when the validation loss did not get worse.

    Raises:
        ValueError: if ``opt.opt_type`` or ``opt.enhance_loss_type`` is not
            one of the supported values.
    """
    opt = fake_opt.Enhance_fbank_train()
    device = torch.device("cuda:{}".format(opt.gpu_ids[0]) if len(opt.gpu_ids)
                          > 0 and torch.cuda.is_available() else "cpu")
    print(device)
    visualizer = Visualizer(opt)
    logging = visualizer.get_logger()
    loss_report = visualizer.add_plot_report(['train/loss', 'val/loss'],
                                             'loss.png')

    # data
    logging.info("Building dataset.")
    train_dataset = MixSequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'train_new'),
        os.path.join(opt.dict_dir, 'train/vocab'),
    )
    val_dataset = MixSequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'dev_new'),
        os.path.join(opt.dict_dir, 'train/vocab'),
    )
    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)
    train_loader = MixSequentialDataLoader(train_dataset,
                                           num_workers=opt.num_workers,
                                           batch_sampler=train_sampler)
    val_loader = MixSequentialDataLoader(val_dataset,
                                         batch_size=int(opt.batch_size / 2),
                                         num_workers=opt.num_workers,
                                         shuffle=False)
    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)
    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")

    # Training state; overwritten below when a checkpoint is resumed.
    lr = opt.lr
    eps = opt.eps
    iters = opt.iters
    best_loss = opt.best_loss
    start_epoch = opt.start_epoch
    model_path = None
    if opt.enhance_resume:
        model_path = opt.enhance_resume
        if os.path.isfile(model_path):
            package = torch.load(model_path,
                                 map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)
            best_loss = package.get('best_loss', float('inf'))
            start_epoch = int(package.get('epoch', 0))
            iters = int(package.get('iters', 0))
            loss_report = package.get('loss_report', loss_report)
            visualizer.set_plot_report(loss_report, 'loss.png')
            print('package found at {} and start_epoch {} iters {}'.format(
                model_path, start_epoch, iters))
        else:
            print("no checkpoint found at {}".format(model_path))
            # Do not hand a non-existent file to load_model; fall back to the
            # same None path that is used when no resume was requested.
            model_path = None
    enhance_model = EnhanceModel.load_model(model_path, 'enhance_state_dict',
                                            opt)
    feat_model = FbankModel.load_model(model_path, 'fbank_state_dict', opt)

    # Setup an optimizer. Use the (possibly resumed) eps / lr rather than the
    # freshly parsed option values so that a resumed run continues with the
    # hyper-parameters stored in the checkpoint.
    enhance_parameters = filter(lambda p: p.requires_grad,
                                enhance_model.parameters())
    if opt.opt_type == 'adadelta':
        enhance_optimizer = torch.optim.Adadelta(enhance_parameters,
                                                 rho=0.95,
                                                 eps=eps)
    elif opt.opt_type == 'adam':
        enhance_optimizer = torch.optim.Adam(enhance_parameters,
                                             lr=lr,
                                             betas=(opt.beta1, 0.999))
    else:
        raise ValueError('unsupported opt_type: {}'.format(opt.opt_type))

    # Resolve the loss once instead of re-testing the option on every batch.
    loss_functions = {
        'L2': F.mse_loss,
        'L1': F.l1_loss,
        'smooth_L1': F.smooth_l1_loss,
    }
    if opt.enhance_loss_type not in loss_functions:
        raise ValueError('unsupported enhance_loss_type: {}'.format(
            opt.enhance_loss_type))
    loss_fn = loss_functions[opt.enhance_loss_type]

    def _prepare_unet_inputs(clean_inputs, mix_inputs, mix_log_inputs):
        """Crop dim 1 to a multiple of 4 and add a dim to the log input (U-Net only)."""
        if opt.enhance_type not in ('unet_128', 'unet_256'):
            return clean_inputs, mix_inputs, mix_log_inputs
        t_step = int(clean_inputs.size(1))
        if t_step % 4 != 0:
            t_step = t_step - t_step % 4
            clean_inputs = clean_inputs[:, :t_step, :]
            mix_inputs = mix_inputs[:, :t_step, :]
            mix_log_inputs = mix_log_inputs[:, :t_step, :]
        return clean_inputs, mix_inputs, mix_log_inputs.unsqueeze(1)

    def _feature_loss(clean_inputs, mix_inputs, mix_log_inputs, input_sizes):
        """Loss between enhanced-mix features and (detached) clean features."""
        enhance_out = enhance_model(mix_inputs, mix_log_inputs, input_sizes)
        enhance_feat = feat_model(enhance_out)
        clean_feat = feat_model(clean_inputs)
        return loss_fn(enhance_feat, clean_feat.detach())

    def _log10_spec(tensor):
        """Return log10 of the tensor as a numpy array, floored at 1e-7."""
        spec = tensor.data.cpu().numpy()
        spec[spec <= 1e-7] = 1e-7
        return np.log10(spec)

    def _checkpoint_state():
        """Snapshot of the training state, read at call time."""
        return {
            'enhance_state_dict': enhance_model.state_dict(),
            'opt': opt,
            'epoch': epoch,
            'iters': iters,
            'eps': eps,
            'lr': lr,
            'best_loss': best_loss,
            'loss_report': loss_report
        }

    # Training
    enhance_model.train()
    feat_model.train()
    for epoch in range(start_epoch, opt.epochs):
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)
        for data in train_loader:
            utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs, mix_log_inputs, cos_angles, targets, input_sizes, target_sizes = data
            clean_inputs, mix_inputs, mix_log_inputs = _prepare_unet_inputs(
                clean_inputs, mix_inputs, mix_log_inputs)
            loss = _feature_loss(clean_inputs, mix_inputs, mix_log_inputs,
                                 input_sizes)
            enhance_optimizer.zero_grad()  # Clear the parameter gradients
            loss.backward()
            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(
                enhance_model.parameters(), opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                enhance_optimizer.step()

            iters += 1
            errors = {'train/loss': loss.item()}
            visualizer.set_current_errors(errors)
            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                utils.save_checkpoint(_checkpoint_state(),
                                      opt.exp_path,
                                      filename='latest')

            if iters % opt.validate_freq == 0:
                enhance_model.eval()
                feat_model.eval()
                num_saved_specgram = 0
                # no_grad restores the previous grad mode on exit, including
                # when validation raises.
                with torch.no_grad():
                    for data in tqdm(val_loader):
                        utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs, mix_log_inputs, cos_angles, targets, input_sizes, target_sizes = data
                        clean_inputs, mix_inputs, mix_log_inputs = _prepare_unet_inputs(
                            clean_inputs, mix_inputs, mix_log_inputs)
                        loss = _feature_loss(clean_inputs, mix_inputs,
                                             mix_log_inputs, input_sizes)
                        errors = {'val/loss': loss.item()}
                        visualizer.set_current_errors(errors)

                        # Plot up to opt.num_saved_specgram utterances per
                        # validation pass.
                        if opt.num_saved_specgram <= 0:
                            continue
                        if num_saved_specgram >= opt.num_saved_specgram:
                            continue
                        enhanced_outs = enhance_model.calculate_all_specgram(
                            mix_inputs, mix_log_inputs, input_sizes)
                        for x in range(len(utt_ids)):
                            enhanced_out = _log10_spec(enhanced_outs[x])
                            clean_input = _log10_spec(clean_inputs[x])
                            mix_input = _log10_spec(mix_inputs[x])
                            file_name = "{}_ep{}_it{}.png".format(
                                utt_ids[x], epoch, iters)
                            input_size = int(input_sizes[x])
                            visualizer.plot_specgram(clean_input, mix_input,
                                                     enhanced_out, input_size,
                                                     file_name)
                            num_saved_specgram += 1
                            if num_saved_specgram >= opt.num_saved_specgram:
                                break
                enhance_model.train()
                feat_model.train()

                visualizer.print_epoch_errors(epoch, iters)
                loss_report = visualizer.plot_epoch_errors(
                    epoch, iters, 'loss.png')
                train_loss = visualizer.get_current_errors('train/loss')
                val_loss = visualizer.get_current_errors('val/loss')
                filename = None
                if val_loss > best_loss:
                    # Validation got worse: decay eps, keep the default name.
                    print('val_loss {} > best_loss {}'.format(
                        val_loss, best_loss))
                    eps = utils.adadelta_eps_decay(enhance_optimizer,
                                                   opt.eps_decay)
                else:
                    filename = 'model.loss.best'
                best_loss = min(val_loss, best_loss)
                print('best_loss {}'.format(best_loss))

                utils.save_checkpoint(_checkpoint_state(),
                                      opt.exp_path,
                                      filename=filename)
                visualizer.reset()
コード例 #3
0
def main():
    """Train the E2E ASR model together with its fbank front-end.

    Options come from ``TrainOptions().parse()``. The function builds the
    train/dev loaders, optionally resumes from ``opt.resume``, loads or
    computes the fbank CMVN statistics, then trains on the weighted CTC and
    attention loss (``opt.mtlalpha``). A ``latest`` checkpoint is written
    every ``opt.print_freq`` iterations; every ``opt.validate_freq``
    iterations the dev set is evaluated and another checkpoint is written,
    named ``model.acc.best`` / ``model.loss.best`` when the selected
    criterion did not get worse.

    Raises:
        ValueError: if ``opt.opt_type`` is neither ``'adadelta'`` nor
            ``'adam'``.
        RuntimeError: if no CMVN statistics could be computed from the
            training data.
    """
    opt = TrainOptions().parse()
    device = torch.device("cuda:{}".format(opt.gpu_ids[0]) if len(opt.gpu_ids)
                          > 0 and torch.cuda.is_available() else "cpu")

    visualizer = Visualizer(opt)
    logging = visualizer.get_logger()
    acc_report = visualizer.add_plot_report(['train/acc', 'val/acc'],
                                            'acc.png')
    loss_report = visualizer.add_plot_report(['train/loss', 'val/loss'],
                                             'loss.png')

    # data
    logging.info("Building dataset.")
    # The train/dev directories and the dictionary file are the inputs.
    train_dataset = SequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'train'),
        os.path.join(opt.dict_dir, 'train_units.txt'),
    )
    val_dataset = SequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'dev'),
        os.path.join(opt.dict_dir, 'train_units.txt'),
    )

    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)

    train_loader = SequentialDataLoader(train_dataset,
                                        num_workers=opt.num_workers,
                                        batch_sampler=train_sampler)
    val_loader = SequentialDataLoader(val_dataset,
                                      batch_size=int(opt.batch_size / 2),
                                      num_workers=opt.num_workers,
                                      shuffle=False)

    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)

    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")

    # Setup a model
    asr_model = E2E(opt)
    fbank_model = FbankModel(opt)
    lr = opt.lr
    eps = opt.eps
    iters = opt.iters
    start_epoch = opt.start_epoch
    best_loss = opt.best_loss
    best_acc = opt.best_acc

    # Resume from a checkpoint when one was requested.
    if opt.resume:
        model_path = os.path.join(opt.works_dir, opt.resume)
        if os.path.isfile(model_path):
            package = torch.load(model_path,
                                 map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)
            best_loss = package.get('best_loss', float('inf'))
            best_acc = package.get('best_acc', 0)
            start_epoch = int(package.get('epoch', 0))
            iters = int(package.get('iters', 0))

            acc_report = package.get('acc_report', acc_report)
            loss_report = package.get('loss_report', loss_report)
            visualizer.set_plot_report(acc_report, 'acc.png')
            visualizer.set_plot_report(loss_report, 'loss.png')

            asr_model = E2E.load_model(model_path, 'asr_state_dict')
            fbank_model = FbankModel.load_model(model_path, 'fbank_state_dict')
            logging.info('Loading model {} and iters {}'.format(
                model_path, iters))
        else:
            print("no checkpoint found at {}".format(model_path))
    # convert to cuda
    asr_model.cuda()
    fbank_model.cuda()
    print(asr_model)
    print(fbank_model)

    # Setup an optimizer over the trainable parameters of both models.
    parameters = filter(
        lambda p: p.requires_grad,
        itertools.chain(asr_model.parameters(), fbank_model.parameters()))
    if opt.opt_type == 'adadelta':
        optimizer = torch.optim.Adadelta(parameters, rho=0.95, eps=eps)
    elif opt.opt_type == 'adam':
        optimizer = torch.optim.Adam(parameters,
                                     lr=lr,
                                     betas=(opt.beta1, 0.999))
    else:
        raise ValueError('unsupported opt_type: {}'.format(opt.opt_type))

    def _checkpoint_state():
        """Snapshot of the training state, read at call time.

        The fbank weights are stored as well because the resume path loads
        'fbank_state_dict' and the optimizer updates the fbank parameters.
        The local ``eps`` / ``lr`` are saved (not ``opt.eps`` / ``opt.lr``)
        so that values restored from a checkpoint survive the next save.
        """
        return {
            'asr_state_dict': asr_model.state_dict(),
            'fbank_state_dict': fbank_model.state_dict(),
            'opt': opt,
            'epoch': epoch,
            'iters': iters,
            'eps': eps,
            'lr': lr,
            'best_loss': best_loss,
            'best_acc': best_acc,
            'acc_report': acc_report,
            'loss_report': loss_report
        }

    asr_model.train()
    fbank_model.train()
    # NOTE: the ramp-up is configured with *epoch* options but is updated
    # with the iteration counter.
    sample_rampup = utils.ScheSampleRampup(opt.sche_samp_start_epoch,
                                           opt.sche_samp_final_epoch,
                                           opt.sche_samp_final_rate)
    sche_samp_rate = sample_rampup.update(iters)

    # Load the fbank CMVN statistics, or compute and cache them.
    fbank_cmvn_file = os.path.join(opt.exp_path, 'fbank_cmvn.npy')
    if os.path.exists(fbank_cmvn_file):
        fbank_cmvn = np.load(fbank_cmvn_file)
    else:
        fbank_cmvn = None
        for data in train_loader:
            utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
            fbank_cmvn = fbank_model.compute_cmvn(inputs, input_sizes)
            # Original stop condition: compute_cmvn returned a result.
            if fbank_cmvn is not None:
                np.save(fbank_cmvn_file, fbank_cmvn)
                print('save fbank_cmvn to {}'.format(fbank_cmvn_file))
                break
            # Additional stop condition based on the model's own counters:
            # once enough data has been processed, call compute_cmvn one
            # last time to obtain the statistics.
            if fbank_model.cmvn_processed_num >= fbank_model.cmvn_num:
                fbank_cmvn = fbank_model.compute_cmvn(inputs, input_sizes)
                np.save(fbank_cmvn_file, fbank_cmvn)
                print('save fbank_cmvn to {}'.format(fbank_cmvn_file))
                break
        if fbank_cmvn is None:
            # Fail with a clear message instead of a NameError / TypeError
            # at the tensor conversion below.
            raise RuntimeError(
                'could not compute fbank cmvn from the training data')
    print(fbank_model.cmvn_processed_num)
    fbank_cmvn = torch.FloatTensor(fbank_cmvn)

    # Training
    for epoch in range(start_epoch, opt.epochs):
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)
        for data in train_loader:
            utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
            fbank_features = fbank_model(inputs, fbank_cmvn)

            # asr_model is unpacked into three values here (ctc loss,
            # attention loss, accuracy); an earlier version unpacked four.
            loss_ctc, loss_att, acc = asr_model(fbank_features, targets,
                                                input_sizes, target_sizes,
                                                sche_samp_rate)
            loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att

            optimizer.zero_grad()  # Clear the parameter gradients
            loss.backward()  # compute backwards

            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(asr_model.parameters(),
                                                       opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                optimizer.step()

            iters += 1
            errors = {
                'train/loss': loss.item(),
                'train/loss_ctc': loss_ctc.item(),
                'train/acc': acc,
                'train/loss_att': loss_att.item()
            }
            visualizer.set_current_errors(errors)
            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                utils.save_checkpoint(_checkpoint_state(),
                                      opt.exp_path,
                                      filename='latest')

            if iters % opt.validate_freq == 0:
                sche_samp_rate = sample_rampup.update(iters)
                print("iters {} sche_samp_rate {}".format(
                    iters, sche_samp_rate))
                asr_model.eval()
                fbank_model.eval()
                num_saved_attention = 0
                # no_grad restores the previous grad mode on exit, including
                # when validation raises.
                with torch.no_grad():
                    for data in tqdm(val_loader):
                        utt_ids, spk_ids, inputs, log_inputs, targets, input_sizes, target_sizes = data
                        fbank_features = fbank_model(inputs, fbank_cmvn)
                        loss_ctc, loss_att, acc = asr_model(
                            fbank_features, targets, input_sizes,
                            target_sizes, 0.0)
                        loss = opt.mtlalpha * loss_ctc + (
                            1 - opt.mtlalpha) * loss_att
                        errors = {
                            'val/loss': loss.item(),
                            'val/loss_ctc': loss_ctc.item(),
                            'val/acc': acc,
                            'val/loss_att': loss_att.item()
                        }
                        visualizer.set_current_errors(errors)

                        # Plot up to opt.num_save_attention attention maps
                        # per validation pass (skipped when mtlalpha is 1.0).
                        if opt.num_save_attention <= 0 or opt.mtlalpha == 1.0:
                            continue
                        if num_saved_attention >= opt.num_save_attention:
                            continue
                        att_ws = asr_model.calculate_all_attentions(
                            fbank_features, targets, input_sizes,
                            target_sizes)
                        for x in range(len(utt_ids)):
                            file_name = "{}_ep{}_it{}.png".format(
                                utt_ids[x], epoch, iters)
                            dec_len = int(target_sizes[x])
                            enc_len = int(input_sizes[x])
                            visualizer.plot_attention(att_ws[x], dec_len,
                                                      enc_len, file_name)
                            num_saved_attention += 1
                            if num_saved_attention >= opt.num_save_attention:
                                break
                asr_model.train()
                fbank_model.train()

                visualizer.print_epoch_errors(epoch, iters)
                acc_report = visualizer.plot_epoch_errors(
                    epoch, iters, 'acc.png')
                loss_report = visualizer.plot_epoch_errors(
                    epoch, iters, 'loss.png')
                val_loss = visualizer.get_current_errors('val/loss')
                val_acc = visualizer.get_current_errors('val/acc')
                filename = None
                # String comparison must use != (not identity), and the
                # second branch must read the same options object.
                if opt.criterion == 'acc' and opt.mtl_mode != 'ctc':
                    if val_acc < best_acc:
                        logging.info('val_acc {} < best_acc {}'.format(
                            val_acc, best_acc))
                        eps = utils.adadelta_eps_decay(
                            optimizer, opt.eps_decay)
                        opt.eps = eps
                    else:
                        filename = 'model.acc.best'
                    best_acc = max(best_acc, val_acc)
                    logging.info('best_acc {}'.format(best_acc))
                elif opt.criterion == 'loss':
                    if val_loss > best_loss:
                        logging.info('val_loss {} > best_loss {}'.format(
                            val_loss, best_loss))
                        eps = utils.adadelta_eps_decay(
                            optimizer, opt.eps_decay)
                        opt.eps = eps
                    else:
                        filename = 'model.loss.best'
                    best_loss = min(val_loss, best_loss)
                    logging.info('best_loss {}'.format(best_loss))
                utils.save_checkpoint(_checkpoint_state(),
                                      opt.exp_path,
                                      filename=filename)
                visualizer.reset()
コード例 #4
0
def main():
    """Jointly train the enhancement front-end and the shared ASR model.

    Parses ``TrainOptions``, builds the train/dev ``MixSequentialDataset``
    loaders, restores (or builds) the enhancement, fbank, ASR and optional
    GAN models, and then runs the training loop.  Every ``opt.print_freq``
    iterations a ``latest`` checkpoint is written.  Every
    ``opt.validate_freq`` iterations the dev set is evaluated, a
    ``model.acc.best`` / ``model.loss.best`` checkpoint is saved when the
    metric selected by ``opt.criterion`` did not get worse, and the
    enhancement CMVN statistics are recomputed.
    """
    opt = TrainOptions().parse()
    use_cuda = len(opt.gpu_ids) > 0 and torch.cuda.is_available()
    device = torch.device("cuda:{}".format(opt.gpu_ids[0]) if use_cuda else "cpu")

    visualizer = Visualizer(opt)
    # NOTE: this local intentionally shadows the stdlib ``logging`` module name.
    logging = visualizer.get_logger()
    acc_report = visualizer.add_plot_report(['train/acc', 'val/acc'], 'acc.png')
    loss_report = visualizer.add_plot_report(
        ['train/loss', 'val/loss', 'train/enhance_loss', 'val/enhance_loss'], 'loss.png')

    # data
    logging.info("Building dataset.")
    train_dataset = MixSequentialDataset(
        opt, os.path.join(opt.dataroot, 'train'), os.path.join(opt.dict_dir, 'train_units.txt'),)
    val_dataset = MixSequentialDataset(
        opt, os.path.join(opt.dataroot, 'dev'), os.path.join(opt.dict_dir, 'train_units.txt'),)
    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)
    train_loader = MixSequentialDataLoader(
        train_dataset, num_workers=opt.num_workers, batch_sampler=train_sampler)
    val_loader = MixSequentialDataLoader(
        val_dataset, batch_size=int(opt.batch_size / 2), num_workers=opt.num_workers, shuffle=False)
    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)
    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")

    # Setup an model
    lr = opt.lr
    eps = opt.eps
    iters = opt.iters
    best_acc = opt.best_acc
    best_loss = opt.best_loss
    start_epoch = opt.start_epoch

    # The models start as None so that a resume name pointing at a missing
    # file falls back to the joint/fresh load below instead of leaving the
    # variable unbound (which used to surface as a NameError later on).
    enhance_model = None
    enhance_model_path = None
    if opt.enhance_resume:
        enhance_model_path = os.path.join(opt.works_dir, opt.enhance_resume)
        if os.path.isfile(enhance_model_path):
            enhance_model = EnhanceModel.load_model(enhance_model_path, 'enhance_state_dict', opt)
        else:
            print("no checkpoint found at {}".format(enhance_model_path))

    asr_model = None
    asr_model_path = None
    if opt.asr_resume:
        asr_model_path = os.path.join(opt.works_dir, opt.asr_resume)
        if os.path.isfile(asr_model_path):
            asr_model = ShareE2E.load_model(asr_model_path, 'asr_state_dict', opt)
        else:
            print("no checkpoint found at {}".format(asr_model_path))

    joint_model_path = None
    if opt.joint_resume:
        joint_model_path = os.path.join(opt.works_dir, opt.joint_resume)
        if os.path.isfile(joint_model_path):
            # Load onto CPU storage; only the training counters are read here.
            package = torch.load(joint_model_path, map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)
            best_acc = package.get('best_acc', 0)
            best_loss = package.get('best_loss', float('inf'))
            start_epoch = int(package.get('epoch', 0))
            iters = int(package.get('iters', 0)) - 1
            print('joint_model_path {} and iters {}'.format(joint_model_path, iters))
            ##loss_report = package.get('loss_report', loss_report)
            ##visualizer.set_plot_report(loss_report, 'loss.png')
        else:
            print("no checkpoint found at {}".format(joint_model_path))

    # A joint checkpoint takes precedence over the individual ones; without
    # any usable individual checkpoint the model is built via load_model too.
    if joint_model_path is not None or enhance_model is None:
        enhance_model = EnhanceModel.load_model(joint_model_path, 'enhance_state_dict', opt)
    if joint_model_path is not None or asr_model is None:
        asr_model = ShareE2E.load_model(joint_model_path, 'asr_state_dict', opt)
    feat_model = FbankModel.load_model(joint_model_path, 'fbank_state_dict', opt)
    if opt.isGAN:
        gan_model = GANModel.load_model(joint_model_path, 'gan_state_dict', opt)
    ##set_requires_grad([enhance_model], False)

    # Setup an optimizer (only parameters that require gradients are optimized)
    enhance_parameters = filter(lambda p: p.requires_grad, enhance_model.parameters())
    asr_parameters = filter(lambda p: p.requires_grad, asr_model.parameters())
    if opt.isGAN:
        gan_parameters = filter(lambda p: p.requires_grad, gan_model.parameters())
    if opt.opt_type == 'adadelta':
        enhance_optimizer = torch.optim.Adadelta(enhance_parameters, rho=0.95, eps=eps)
        asr_optimizer = torch.optim.Adadelta(asr_parameters, rho=0.95, eps=eps)
        if opt.isGAN:
            gan_optimizer = torch.optim.Adadelta(gan_parameters, rho=0.95, eps=eps)
    elif opt.opt_type == 'adam':
        enhance_optimizer = torch.optim.Adam(enhance_parameters, lr=lr, betas=(opt.beta1, 0.999))
        asr_optimizer = torch.optim.Adam(asr_parameters, lr=lr, betas=(opt.beta1, 0.999))
        if opt.isGAN:
            gan_optimizer = torch.optim.Adam(gan_parameters, lr=lr, betas=(opt.beta1, 0.999))
    if opt.isGAN:
        criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan).to(device)

    # Training
    enhance_cmvn = compute_cmvn_epoch(opt, train_loader, enhance_model, feat_model)
    sample_rampup = utils.ScheSampleRampup(
        opt.sche_samp_start_iter, opt.sche_samp_final_iter, opt.sche_samp_final_rate)
    sche_samp_rate = sample_rampup.update(iters)

    enhance_model.train()
    feat_model.train()
    asr_model.train()
    for epoch in range(start_epoch, opt.epochs):
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)
        for data in train_loader:
            utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs, mix_log_inputs, cos_angles, targets, input_sizes, target_sizes = data
            enhance_out = enhance_model(mix_inputs, mix_log_inputs, input_sizes)
            enhance_feat = feat_model(enhance_out)
            clean_feat = feat_model(clean_inputs)
            mix_feat = feat_model(mix_inputs)
            # The clean features are detached so they act as a fixed target.
            if opt.enhance_loss_type == 'L2':
                enhance_loss = F.mse_loss(enhance_feat, clean_feat.detach())
            elif opt.enhance_loss_type == 'L1':
                enhance_loss = F.l1_loss(enhance_feat, clean_feat.detach())
            elif opt.enhance_loss_type == 'smooth_L1':
                enhance_loss = F.smooth_l1_loss(enhance_feat, clean_feat.detach())
            enhance_loss = opt.enhance_loss_lambda * enhance_loss

            loss_ctc, loss_att, acc, clean_context, mix_context = asr_model(
                clean_feat, enhance_feat, targets, input_sizes, target_sizes,
                sche_samp_rate, enhance_cmvn)
            coral_loss = opt.coral_loss_lambda * CORAL(clean_context, mix_context)
            asr_loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
            loss = asr_loss + enhance_loss + coral_loss

            if opt.isGAN:
                # Generator step: freeze the discriminator and ask it to
                # label the enhanced features as real.
                set_requires_grad([gan_model], False)
                if opt.netD_type == 'pixel':
                    fake_AB = torch.cat((mix_feat, enhance_feat), 2)
                else:
                    fake_AB = enhance_feat
                gan_loss = opt.gan_loss_lambda * criterionGAN(gan_model(fake_AB, enhance_cmvn), True)
                loss += gan_loss

            enhance_optimizer.zero_grad()
            asr_optimizer.zero_grad()  # Clear the parameter gradients
            loss.backward()
            # compute the gradient norm to check if it is normal or not
            # (only the ASR parameters are clipped here)
            grad_norm = torch.nn.utils.clip_grad_norm_(asr_model.parameters(), opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                enhance_optimizer.step()
                asr_optimizer.step()

            if opt.isGAN:
                # Discriminator step on detached real/fake features.
                set_requires_grad([gan_model], True)
                gan_optimizer.zero_grad()
                if opt.netD_type == 'pixel':
                    fake_AB = torch.cat((mix_feat, enhance_feat), 2)
                    real_AB = torch.cat((mix_feat, clean_feat), 2)
                else:
                    fake_AB = enhance_feat
                    real_AB = clean_feat
                loss_D_real = criterionGAN(gan_model(real_AB.detach(), enhance_cmvn), True)
                loss_D_fake = criterionGAN(gan_model(fake_AB.detach(), enhance_cmvn), False)
                loss_D = (loss_D_real + loss_D_fake) * 0.5
                loss_D.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(gan_model.parameters(), opt.grad_clip)
                if math.isnan(grad_norm):
                    logging.warning('grad norm is nan. Do not update model.')
                else:
                    gan_optimizer.step()

            iters += 1
            errors = {'train/loss': loss.item(), 'train/loss_ctc': loss_ctc.item(),
                      'train/acc': acc, 'train/loss_att': loss_att.item(),
                      'train/enhance_loss': enhance_loss.item(),
                      'train/coral_loss': coral_loss.item()}
            if opt.isGAN:
                errors['train/loss_D'] = loss_D.item()
                # gan_loss already carries opt.gan_loss_lambda (applied when it
                # was computed above); scaling it again would report lambda^2.
                errors['train/gan_loss'] = gan_loss.item()

            visualizer.set_current_errors(errors)
            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                state = {'asr_state_dict': asr_model.state_dict(),
                         'fbank_state_dict': feat_model.state_dict(),
                         'enhance_state_dict': enhance_model.state_dict(),
                         'opt': opt, 'epoch': epoch, 'iters': iters,
                         'eps': opt.eps, 'lr': opt.lr,
                         'best_loss': best_loss, 'best_acc': best_acc,
                         'acc_report': acc_report, 'loss_report': loss_report}
                if opt.isGAN:
                    state['gan_state_dict'] = gan_model.state_dict()
                filename = 'latest'
                utils.save_checkpoint(state, opt.exp_path, filename=filename)

            if iters % opt.validate_freq == 0:
                sche_samp_rate = sample_rampup.update(iters)
                print("iters {} sche_samp_rate {}".format(iters, sche_samp_rate))
                enhance_model.eval()
                feat_model.eval()
                asr_model.eval()
                torch.set_grad_enabled(False)
                num_saved_attention = 0
                for i, (data) in tqdm(enumerate(val_loader, start=0)):
                    utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs, mix_log_inputs, cos_angles, targets, input_sizes, target_sizes = data
                    enhance_out = enhance_model(mix_inputs, mix_log_inputs, input_sizes)
                    enhance_feat = feat_model(enhance_out)
                    clean_feat = feat_model(clean_inputs)
                    mix_feat = feat_model(mix_inputs)
                    if opt.enhance_loss_type == 'L2':
                        enhance_loss = F.mse_loss(enhance_feat, clean_feat.detach())
                    elif opt.enhance_loss_type == 'L1':
                        enhance_loss = F.l1_loss(enhance_feat, clean_feat.detach())
                    elif opt.enhance_loss_type == 'smooth_L1':
                        enhance_loss = F.smooth_l1_loss(enhance_feat, clean_feat.detach())
                    if opt.isGAN:
                        set_requires_grad([gan_model], False)
                        if opt.netD_type == 'pixel':
                            fake_AB = torch.cat((mix_feat, enhance_feat), 2)
                        else:
                            fake_AB = enhance_feat
                        gan_loss = criterionGAN(gan_model(fake_AB, enhance_cmvn), True)
                        enhance_loss += opt.gan_loss_lambda * gan_loss

                    # Scheduled sampling is disabled (rate 0.0) during validation.
                    loss_ctc, loss_att, acc, clean_context, mix_context = asr_model(
                        clean_feat, enhance_feat, targets, input_sizes, target_sizes,
                        0.0, enhance_cmvn)

                    asr_loss = opt.mtlalpha * loss_ctc + (1 - opt.mtlalpha) * loss_att
                    enhance_loss = opt.enhance_loss_lambda * enhance_loss
                    loss = asr_loss + enhance_loss
                    errors = {'val/loss': loss.item(), 'val/loss_ctc': loss_ctc.item(),
                              'val/acc': acc, 'val/loss_att': loss_att.item(),
                              'val/enhance_loss': enhance_loss.item()}
                    if opt.isGAN:
                        errors['val/gan_loss'] = opt.gan_loss_lambda * gan_loss.item()
                    visualizer.set_current_errors(errors)

                    if opt.num_save_attention > 0 and opt.mtlalpha != 1.0:
                        if num_saved_attention < opt.num_save_attention:
                            att_ws = asr_model.calculate_all_attentions(
                                enhance_feat, targets, input_sizes, target_sizes, enhance_cmvn)
                            for x in range(len(utt_ids)):
                                att_w = att_ws[x]
                                utt_id = utt_ids[x]
                                file_name = "{}_ep{}_it{}.png".format(utt_id, epoch, iters)
                                dec_len = int(target_sizes[x])
                                enc_len = int(input_sizes[x])
                                visualizer.plot_attention(att_w, dec_len, enc_len, file_name)
                                num_saved_attention += 1
                                if num_saved_attention >= opt.num_save_attention:
                                    break
                enhance_model.train()
                feat_model.train()
                asr_model.train()
                torch.set_grad_enabled(True)

                visualizer.print_epoch_errors(epoch, iters)
                acc_report = visualizer.plot_epoch_errors(epoch, iters, 'acc.png')
                loss_report = visualizer.plot_epoch_errors(epoch, iters, 'loss.png')
                val_loss = visualizer.get_current_errors('val/loss')
                val_acc = visualizer.get_current_errors('val/acc')
                # filename stays None unless the tracked metric did not get worse.
                filename = None
                # Strings are compared by value (!=), not identity.
                if opt.criterion == 'acc' and opt.mtl_mode != 'ctc':
                    if val_acc < best_acc:
                        logging.info('val_acc {} < best_acc {}'.format(val_acc, best_acc))
                        opt.eps = utils.adadelta_eps_decay(asr_optimizer, opt.eps_decay)
                    else:
                        filename = 'model.acc.best'
                    best_acc = max(best_acc, val_acc)
                    logging.info('best_acc {}'.format(best_acc))
                elif opt.criterion == 'loss':
                    if val_loss > best_loss:
                        logging.info('val_loss {} > best_loss {}'.format(val_loss, best_loss))
                        opt.eps = utils.adadelta_eps_decay(asr_optimizer, opt.eps_decay)
                    else:
                        filename = 'model.loss.best'
                    best_loss = min(val_loss, best_loss)
                    logging.info('best_loss {}'.format(best_loss))
                state = {'asr_state_dict': asr_model.state_dict(),
                         'fbank_state_dict': feat_model.state_dict(),
                         'enhance_state_dict': enhance_model.state_dict(),
                         'opt': opt, 'epoch': epoch, 'iters': iters,
                         'eps': opt.eps, 'lr': opt.lr,
                         'best_loss': best_loss, 'best_acc': best_acc,
                         'acc_report': acc_report, 'loss_report': loss_report}
                if opt.isGAN:
                    state['gan_state_dict'] = gan_model.state_dict()
                utils.save_checkpoint(state, opt.exp_path, filename=filename)
                visualizer.reset()
                # Refresh the CMVN statistics with the updated enhancement model.
                enhance_cmvn = compute_cmvn_epoch(opt, train_loader, enhance_model, feat_model)
コード例 #5
0
def main():
    """Train the enhancement model adversarially against a GAN discriminator.

    Parses ``TrainOptions``, builds the train/dev ``MixSequentialDataset``
    loaders, restores (or builds) the enhancement, GAN and feature models,
    and alternates a generator step (enhancement loss plus weighted GAN
    loss) with a discriminator step for every training batch.  Every
    ``opt.print_freq`` iterations a ``latest`` checkpoint is written.  Every
    ``opt.validate_freq`` iterations the dev loss is computed, optional
    spectrogram plots are saved, ``model.loss.best`` is written when the dev
    loss did not get worse, and the enhancement CMVN statistics are
    recomputed.

    Raises:
        NotImplementedError: if ``opt.enhance_opt_type`` is neither
            ``'gan_fft'`` nor ``'gan_fbank'``.
    """
    opt = TrainOptions().parse()
    device = torch.device("cuda:{}".format(opt.gpu_ids[0]) if len(opt.gpu_ids)
                          > 0 and torch.cuda.is_available() else "cpu")

    visualizer = Visualizer(opt)
    # NOTE: this local intentionally shadows the stdlib ``logging`` module name.
    logging = visualizer.get_logger()
    loss_report = visualizer.add_plot_report([
        'train/loss', 'val/loss', 'train/gan_loss', 'train/enhance_loss',
        'train/loss_D'
    ], 'loss.png')

    # data
    logging.info("Building dataset.")
    train_dataset = MixSequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'train'),
        os.path.join(opt.dict_dir, 'train_units.txt'),
    )
    val_dataset = MixSequentialDataset(
        opt,
        os.path.join(opt.dataroot, 'dev'),
        os.path.join(opt.dict_dir, 'train_units.txt'),
    )
    train_sampler = BucketingSampler(train_dataset, batch_size=opt.batch_size)
    train_loader = MixSequentialDataLoader(train_dataset,
                                           num_workers=opt.num_workers,
                                           batch_sampler=train_sampler)
    val_loader = MixSequentialDataLoader(val_dataset,
                                         batch_size=int(opt.batch_size / 2),
                                         num_workers=opt.num_workers,
                                         shuffle=False)
    opt.idim = train_dataset.get_feat_size()
    opt.odim = train_dataset.get_num_classes()
    opt.char_list = train_dataset.get_char_list()
    opt.train_dataset_len = len(train_dataset)
    logging.info('#input dims : ' + str(opt.idim))
    logging.info('#output dims: ' + str(opt.odim))
    logging.info("Dataset ready!")

    # Setup an model
    lr = opt.lr
    eps = opt.eps
    iters = opt.iters
    best_loss = opt.best_loss
    start_epoch = opt.start_epoch
    model_path = None
    # NOTE(review): this code originally read the misspelled attribute
    # ``opt.enhace_resume``.  Both spellings are accepted here (the original
    # one first) so the script works whichever name the options class
    # defines -- confirm against TrainOptions and drop the unused spelling.
    resume_name = (getattr(opt, 'enhace_resume', None)
                   or getattr(opt, 'enhance_resume', None))
    if resume_name:
        model_path = os.path.join(opt.works_dir, resume_name)
        if os.path.isfile(model_path):
            # Load onto CPU storage; only training counters/reports are read here.
            package = torch.load(model_path,
                                 map_location=lambda storage, loc: storage)
            lr = package.get('lr', opt.lr)
            eps = package.get('eps', opt.eps)
            best_loss = package.get('best_loss', float('inf'))
            start_epoch = int(package.get('epoch', 0))
            iters = int(package.get('iters', 0))
            loss_report = package.get('loss_report', loss_report)
            visualizer.set_plot_report(loss_report, 'loss.png')
        else:
            print("no checkpoint found at {}".format(model_path))
    enhance_model = EnhanceModel.load_model(model_path, 'enhance_state_dict',
                                            opt)
    gan_model = GANModel.load_model(model_path, 'gan_state_dict', opt)
    if opt.enhance_opt_type == 'gan_fft':
        feat_model = FFTModel.load_model(model_path, 'fft_state_dict', opt)
    elif opt.enhance_opt_type == 'gan_fbank':
        feat_model = FbankModel.load_model(model_path, 'fbank_state_dict', opt)
    else:
        # Read the value from ``opt``; a bare ``enhance_opt_type`` is undefined
        # and would mask this error with a NameError.
        raise NotImplementedError('enhance_opt_type [%s] is not recognized' %
                                  opt.enhance_opt_type)

    # Setup an optimizer (only parameters that require gradients are optimized)
    enhance_parameters = filter(
        lambda p: p.requires_grad,
        itertools.chain(enhance_model.parameters(), feat_model.parameters()))
    gan_parameters = filter(lambda p: p.requires_grad, gan_model.parameters())
    if opt.opt_type == 'adadelta':
        enhance_optimizer = torch.optim.Adadelta(enhance_parameters,
                                                 rho=0.95,
                                                 eps=eps)
        gan_optimizer = torch.optim.Adadelta(gan_parameters, rho=0.95, eps=eps)
    elif opt.opt_type == 'adam':
        enhance_optimizer = torch.optim.Adam(enhance_parameters,
                                             lr=lr,
                                             betas=(opt.beta1, 0.999))
        gan_optimizer = torch.optim.Adam(gan_parameters,
                                         lr=lr,
                                         betas=(opt.beta1, 0.999))
    criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan).to(device)

    # Training
    enhance_cmvn = compute_cmvn_epoch(opt, train_loader, enhance_model,
                                      feat_model)
    enhance_model.train()
    feat_model.train()
    gan_model.train()
    for epoch in range(start_epoch, opt.epochs):
        if epoch > opt.shuffle_epoch:
            print("Shuffling batches for the following epochs")
            train_sampler.shuffle(epoch)
        # NOTE: ``start`` only offsets the enumerate counter; it does not skip
        # any batches of the loader when resuming.
        for i, (data) in enumerate(train_loader,
                                   start=(iters % len(train_dataset))):
            utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs, mix_log_inputs, cos_angles, targets, input_sizes, target_sizes = data
            enhance_loss, enhance_out = enhance_model(clean_inputs, mix_inputs,
                                                      mix_log_inputs,
                                                      cos_angles, input_sizes)
            enhance_feat = feat_model(enhance_out, enhance_cmvn)
            clean_feat = feat_model(clean_inputs, enhance_cmvn)
            # Generator step: freeze the discriminator and ask it to label the
            # enhanced features as real.
            set_requires_grad([gan_model], False)
            gan_loss = criterionGAN(gan_model(enhance_feat), True)
            enhance_optimizer.zero_grad()
            loss = enhance_loss + opt.gan_loss_lambda * gan_loss
            loss.backward()
            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(
                enhance_model.parameters(), opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                enhance_optimizer.step()

            # Discriminator step on detached real/fake features.
            set_requires_grad([gan_model], True)
            gan_optimizer.zero_grad()
            loss_D_real = criterionGAN(gan_model(clean_feat.detach()), True)
            loss_D_fake = criterionGAN(gan_model(enhance_feat.detach()), False)
            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(gan_model.parameters(),
                                                       opt.grad_clip)
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                gan_optimizer.step()

            iters += 1
            errors = {
                'train/loss': loss.item(),
                'train/gan_loss': gan_loss.item(),
                'train/enhance_loss': enhance_loss.item(),
                'train/loss_D': loss_D.item()
            }
            visualizer.set_current_errors(errors)
            if iters % opt.print_freq == 0:
                visualizer.print_current_errors(epoch, iters)
                state = {
                    'enhance_state_dict': enhance_model.state_dict(),
                    'gan_state_dict': gan_model.state_dict(),
                    'opt': opt,
                    'epoch': epoch,
                    'iters': iters,
                    'eps': opt.eps,
                    'lr': opt.lr,
                    'best_loss': best_loss,
                    'loss_report': loss_report
                }
                # The feature model is stored under a key matching its type.
                if opt.enhance_opt_type == 'gan_fft':
                    state['fft_state_dict'] = feat_model.state_dict()
                elif opt.enhance_opt_type == 'gan_fbank':
                    state['fbank_state_dict'] = feat_model.state_dict()
                filename = 'latest'
                utils.save_checkpoint(state, opt.exp_path, filename=filename)

            if iters % opt.validate_freq == 0:
                enhance_model.eval()
                feat_model.eval()
                gan_model.eval()
                torch.set_grad_enabled(False)
                num_saved_specgram = 0
                for i, (data) in tqdm(enumerate(val_loader, start=0)):
                    utt_ids, spk_ids, clean_inputs, clean_log_inputs, mix_inputs, mix_log_inputs, cos_angles, targets, input_sizes, target_sizes = data
                    loss, enhance_out = enhance_model(clean_inputs, mix_inputs,
                                                      mix_log_inputs,
                                                      cos_angles, input_sizes)
                    errors = {'val/loss': loss.item()}
                    visualizer.set_current_errors(errors)

                    if opt.num_saved_specgram > 0:
                        if num_saved_specgram < opt.num_saved_specgram:
                            enhanced_outs = enhance_model.calculate_all_specgram(
                                mix_inputs, mix_log_inputs, input_sizes)
                            for x in range(len(utt_ids)):
                                # Values are floored at 1e-7 before log10 so
                                # the logarithm stays finite.
                                enhanced_out = enhanced_outs[x].data.cpu(
                                ).numpy()
                                enhanced_out[enhanced_out <= 1e-7] = 1e-7
                                enhanced_out = np.log10(enhanced_out)
                                clean_input = clean_inputs[x].data.cpu().numpy(
                                )
                                clean_input[clean_input <= 1e-7] = 1e-7
                                clean_input = np.log10(clean_input)
                                mix_input = mix_inputs[x].data.cpu().numpy()
                                mix_input[mix_input <= 1e-7] = 1e-7
                                mix_input = np.log10(mix_input)
                                utt_id = utt_ids[x]
                                file_name = "{}_ep{}_it{}.png".format(
                                    utt_id, epoch, iters)
                                input_size = int(input_sizes[x])
                                visualizer.plot_specgram(
                                    clean_input, mix_input, enhanced_out,
                                    input_size, file_name)
                                num_saved_specgram += 1
                                if num_saved_specgram >= opt.num_saved_specgram:
                                    break
                enhance_model.train()
                feat_model.train()
                gan_model.train()
                torch.set_grad_enabled(True)

                visualizer.print_epoch_errors(epoch, iters)
                loss_report = visualizer.plot_epoch_errors(
                    epoch, iters, 'loss.png')
                # train_loss is only needed by the disabled per-validation
                # checkpoint name further down.
                train_loss = visualizer.get_current_errors('train/loss')
                val_loss = visualizer.get_current_errors('val/loss')
                # filename stays None unless the dev loss did not get worse.
                filename = None
                if val_loss > best_loss:
                    print('val_loss {} > best_loss {}'.format(
                        val_loss, best_loss))
                    # Decay eps on the optimizer that exists in this script;
                    # the bare name ``optimizer`` was never defined.
                    opt.eps = utils.adadelta_eps_decay(enhance_optimizer,
                                                       opt.eps_decay)
                else:
                    filename = 'model.loss.best'
                best_loss = min(val_loss, best_loss)
                print('best_loss {}'.format(best_loss))

                state = {
                    'enhance_state_dict': enhance_model.state_dict(),
                    'gan_state_dict': gan_model.state_dict(),
                    'opt': opt,
                    'epoch': epoch,
                    'iters': iters,
                    'eps': opt.eps,
                    'lr': opt.lr,
                    'best_loss': best_loss,
                    'loss_report': loss_report
                }
                if opt.enhance_opt_type == 'gan_fft':
                    state['fft_state_dict'] = feat_model.state_dict()
                elif opt.enhance_opt_type == 'gan_fbank':
                    state['fbank_state_dict'] = feat_model.state_dict()
                ##filename='epoch-{}_iters-{}_loss-{:.6f}-{:.6f}.pth'.format(epoch, iters, train_loss, val_loss)
                utils.save_checkpoint(state, opt.exp_path, filename=filename)
                visualizer.reset()
                # Refresh the CMVN statistics with the updated enhancement model.
                enhance_cmvn = compute_cmvn_epoch(opt, train_loader,
                                                  enhance_model, feat_model)