Example no. 1
0
    def tst_input_fn():
        """Test-time input_fn: unlabeled, unmasked, unshuffled, repeated forever."""
        return dataset(
            rgb_imgs,
            None,
            None,
            params['batch_size'],
            repeat=True,
            shuffle=False,
            dim_imgs=params['rgb_size'],
            dim_lbls=params['gi_size'],
        )
Example no. 2
0
  def trn_input_fn():
    """Training input_fn: shuffled pairs, repeated for `epochs_between_evals`."""
    pairs = load_pairs(params['data_dir'], params, 'trn')
    rgb_imgs = pairs['rgb_list']
    # gi_list holds single-element tuples; keep only the path strings.
    geo_imgs = [pair[0] for pair in pairs['gi_list']]
    msks = pairs['mask_list'] if params['use_mask'] else None

    return dataset(rgb_imgs, geo_imgs, msks,
                   params['batch_size'],
                   repeat=params['epochs_between_evals'],
                   shuffle=True,
                   dim_imgs=params['rgb_size'],
                   dim_lbls=params['gi_size'])
Example no. 3
0
  def val_input_fn():
    """Validation input_fn: shuffled pairs, repeated indefinitely."""
    pairs = load_pairs(params['data_dir'], params, 'val')
    rgb_imgs = pairs['rgb_list']
    # gi_list holds single-element tuples; keep only the path strings.
    geo_imgs = [pair[0] for pair in pairs['gi_list']]
    msks = pairs['mask_list'] if params['use_mask'] else None

    return dataset(rgb_imgs, geo_imgs, msks,
                   params['batch_size'],
                   repeat=True,
                   shuffle=True,
                   dim_imgs=params['rgb_size'],
                   dim_lbls=params['gi_size'])
Example no. 4
0
def prepare_data():
    """Classify the dataset's sentiment and dump the results as JSON, CSV and TSV.

    Reads texts from `dataset()`, scores them with `senta.sentiment_classify`
    (both presumably defined at module level — confirm), and writes
    'data.json', 'data.csv' and 'data.txt' in the working directory.
    """
    test_text = dataset()

    results = senta.sentiment_classify(data={"text": test_text})

    # Keep only the fields we export.
    results = [
        {'text': item['text'], 'sentiment': item['sentiment_key']} for item in results
    ]

    print("data length is {}".format(len(results)))

    # 'w' instead of 'w+': the files are only written, never read back.
    with open('data.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False)

    # newline='' is required by the csv module; without it every row is
    # followed by a spurious blank line on Windows.
    with open('data.csv', 'w', encoding='utf-8', newline='') as f:
        writer = csv.DictWriter(f, results[0].keys())
        writer.writeheader()
        writer.writerows(results)

    with open('data.txt', 'w', encoding='utf-8') as f:
        f.writelines('{}\t{}\n'.format(*item.values()) for item in results)
Example no. 5
0
def adversarial_train(train_step,
                      train_pattern,
                      cfg,
                      spec_dir=None,
                      resume_checkpoints=None,
                      current_time=None):
    """Adversarially train a TTS generator against a discriminator.

    Args:
        train_step: 'train_text2mel' (melSyn generator + melDisc) or
            'train_ssrn' (SSRN generator + linDisc).
        train_pattern: e.g. 'universal' or 'conditional'; 'conditional'
            enables speaker conditioning in melSyn.
        cfg: configuration dict (paths, model dims, Adam hyperparameters,
            batch size, validation cadence, ...).
        spec_dir: optional spectrogram directory forwarded to `dataset`.
        resume_checkpoints: path of a checkpoint to resume the generator
            from; the discriminator is always freshly initialized.
        current_time: timestamp string used to name the checkpoint dir.

    NOTE(review): if train_step is neither recognized value, `model` and
    `disc` are never bound and the code below raises NameError — verify
    callers only pass the two supported values.
    """

    checkpoints_dir = cfg['SRC_ROOT_DIR'] + 'checkpoints/'
    current_save_dir = checkpoints_dir + train_pattern + '/adversarial/' + current_time
    fig_dir = current_save_dir + '/fig/'
    if not os.path.exists(current_save_dir):
        os.system('mkdir -p ' + current_save_dir)

    # Model import is deferred so the dropout variant can be selected by config.
    if cfg['APPLY_DROPOUT']:
        from models.TTSModel_dropout import melSyn, SSRN
    else:
        from models.TTSModel import melSyn, SSRN

    from models.discriminator import melDisc, linDisc

    train_dataset = dataset(cfg=cfg,
                            mode='train',
                            pattern=train_pattern,
                            step=train_step,
                            spec_dir=spec_dir)

    validate_dataset = dataset(cfg=cfg,
                               mode='validate',
                               pattern=train_pattern,
                               step=train_step,
                               spec_dir=spec_dir)

    if train_step == 'train_text2mel':
        # subtract 1 because we merge "'" and '"'.
        model = melSyn(vocab_len=len(cfg['VOCABULARY']) - 1,
                       condition=(train_pattern == 'conditional'),
                       spkemb_dim=cfg['SPK_EMB_DIM'],
                       textemb_dim=cfg['TEXT_EMB_DIM'],
                       freq_bins=cfg['COARSE_MELSPEC']['FREQ_BINS'],
                       hidden_dim=cfg['HIDDEN_DIM'])

        disc = melDisc(freq_bins=cfg['COARSE_MELSPEC']['FREQ_BINS'],
                       disc_dim=cfg['DISC_DIM'])

        if cfg['MULTI_GPU']:
            model = torch.nn.DataParallel(model)
            disc = torch.nn.DataParallel(disc)

    if train_step == 'train_ssrn':
        model = SSRN(freq_bins=cfg['COARSE_MELSPEC']['FREQ_BINS'],
                     output_bins=(1 + cfg['STFT']['FFT_LENGTH'] // 2),
                     ssrn_dim=cfg['SSRN_DIM'])

        disc = linDisc(freq_bins=(1 + cfg['STFT']['FFT_LENGTH'] // 2),
                       disc_dim=cfg['DISC_DIM'])

        if cfg['MULTI_GPU']:
            model = torch.nn.DataParallel(model)
            disc = torch.nn.DataParallel(disc)

    # If train from scratch, initialize recursively.
    if resume_checkpoints is None:
        model.apply(init_weights)
        disc.apply(init_weights)
        epoch = 0
        iteration = 0
        print('CUDA available: ', torch.cuda.is_available())
        model.to(device)
        disc.to(device)
        opt_syn = optim.Adam(model.parameters(), cfg['ADAM']['ALPHA'],
                             (cfg['ADAM']['BETA_1'], cfg['ADAM']['BETA_2']),
                             cfg['ADAM']['EPSILON'])
        opt_disc = optim.Adam(disc.parameters(), cfg['ADAM']['ALPHA'],
                              (cfg['ADAM']['BETA_1'], cfg['ADAM']['BETA_2']),
                              cfg['ADAM']['EPSILON'])
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.3, patience=1, verbose=True, min_lr=1e-8)
        loss_val_log_syn = []
        loss_val_log_syn_onlyfromD = []
        loss_val_log_disc = []
        loss_train_log_syn = []
        loss_train_log_syn_onlyfromD = []
        loss_train_log_disc = []
        loss_train_smooth_log_syn = []
        loss_train_smooth_log_syn_onlyfromD = []
        loss_train_smooth_log_disc = []
    else:
        # Resume: only the generator is restored; disc restarts from scratch.
        print('CUDA available: ', torch.cuda.is_available())
        checkpoint = torch.load(resume_checkpoints)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        disc.apply(init_weights)
        disc.to(device)
        opt_syn = optim.Adam(model.parameters(), cfg['ADAM']['ALPHA'],
                             (cfg['ADAM']['BETA_1'], cfg['ADAM']['BETA_2']),
                             cfg['ADAM']['EPSILON'])
        opt_disc = optim.Adam(disc.parameters(), cfg['ADAM']['ALPHA'],
                              (cfg['ADAM']['BETA_1'], cfg['ADAM']['BETA_2']),
                              cfg['ADAM']['EPSILON'])
        # NOTE(review): this function saves optimizer state under
        # 'opt_state_dict_syn:'/'opt_state_dict_disc:' (see torch.save below),
        # not 'optimizer_state_dict' — resuming from a checkpoint produced by
        # this very function would raise KeyError here. Verify the intended
        # checkpoint source.
        opt_syn.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        iteration = checkpoint['iteration']
        loss_val_log_syn = checkpoint['loss_val_log']
        loss_val_log_syn_onlyfromD = []
        loss_val_log_disc = []
        loss_train_log_syn = []
        loss_train_log_syn_onlyfromD = []
        loss_train_log_disc = []
        loss_train_smooth_log_syn = []
        loss_train_smooth_log_syn_onlyfromD = []
        loss_train_smooth_log_disc = []

    train_loader = DataLoader(train_dataset,
                              cfg['BATCH_SIZE'],
                              shuffle=True,
                              num_workers=4,
                              collate_fn=collate_pad_3 if train_step
                              == 'train_text2mel' else collate_pad_2)
    validate_loader = DataLoader(validate_dataset,
                                 batch_size=8,
                                 shuffle=True,
                                 num_workers=2,
                                 collate_fn=collate_pad_3 if train_step
                                 == 'train_text2mel' else collate_pad_2)

    # guided attention weights.
    gaw = guided_attention_mat(cfg['MAX_TEXT_LEN'], cfg['MAX_FRAME_NUM'])

    # check model's device
    print('Model G Device:', next(model.parameters()).device)
    print('Model D Device:', next(disc.parameters()).device)

    # Running loss accumulators, reset after each validation window.
    loss_iter_G = 0
    loss_iter_G_onlyfromD = 0
    loss_iter_D = 0

    while epoch < cfg['MAX_EPOCHS']:
        print('Epoch ', epoch + 1)
        print('*******************')
        loader_len = len(train_loader)

        for i, sp in enumerate(train_loader):
            # Training
            start_iter = time.time()
            opt_syn.zero_grad()
            opt_disc.zero_grad()

            # Train G on every (RATIO + 1)-th global iteration, D otherwise.
            train_target = 'D' if iteration % (cfg['RATIO'] + 1) else 'G'
            print('Iteration {}/{} for epoch {}, training {}'.format(
                str(i + 1), str(loader_len), str(epoch + 1), train_target))
            print('Global iteration ', iteration + 1)

            if train_step == 'train_text2mel':
                mel_gt = sp['data_0'].to(device)
                text_id = sp['data_1'].to(device)
                spk_emb = sp['data_2'].to(device)

                # Teacher forcing: feed the target shifted right by one frame.
                spec_inputs = torch.cat(
                    (torch.zeros_like(mel_gt[:, :, :1]), mel_gt[:, :, :-1]),
                    dim=-1)
                pred_mel_prob, att_mat = model(spec_inputs, text_id, spk_emb)

                # Discriminator scores a fixed slice 1:9 of the last axis
                # (assumed to be time frames, matching the dim=-1 shift above
                # — TODO confirm).
                disc_gt = disc(mel_gt[:, :, 1:9])
                disc_syn = disc(pred_mel_prob[:, :, 1:9])

                if train_target == 'G':
                    loss_l1 = torch.mean(torch.abs(mel_gt - pred_mel_prob))
                    loss_bin_div = torch.mean(
                        -mel_gt * torch.log(pred_mel_prob + 1e-8) -
                        (1 - mel_gt) * torch.log(1 - pred_mel_prob + 1e-8))
                    att_aug = F.pad(
                        att_mat, (0, cfg['MAX_FRAME_NUM'] - att_mat.size()[-1],
                                  0, cfg['MAX_TEXT_LEN'] - att_mat.size()[-2]),
                        value=-1)
                    # Here gaw will broadcast along axis 0.
                    loss_att = torch.sum(
                        torch.ne(att_aug, -1).float() * att_aug *
                        gaw) / torch.sum(torch.ne(att_aug, -1).float())
                    loss_disc = torch.mean(-torch.log(disc_syn + 1e-8))

                    # Rescale the adversarial term to match the magnitude of
                    # the reconstruction losses; .item() detaches the scale
                    # factor from the autograd graph.
                    loss = loss_l1 + loss_bin_div + loss_att + (
                        loss_l1.item() + loss_bin_div.item() +
                        loss_att.item()) / (loss_disc.item()) * loss_disc
                    loss_iter_G += loss.item()
                    loss_iter_G_onlyfromD += loss_disc.item()
                    loss_train_log_syn.append(loss.item())
                    loss_train_log_syn_onlyfromD.append(loss_disc.item())
                    print('L1:{}, BD:{}, ATT:{}, DISC:{}, ALL{}'.format(
                        str(loss_l1.item()), str(loss_bin_div.item()),
                        str(loss_att.item()), str(loss_disc.item()),
                        str(loss.item())))

                if train_target == 'D':
                    # Standard non-saturating GAN discriminator loss.
                    loss = torch.mean(-torch.log(disc_gt + 1e-8) -
                                      torch.log(1 - disc_syn + 1e-8))
                    loss_iter_D += loss.item()
                    loss_train_log_disc.append(loss.item())
                    print('DISC: ', loss.item())

            if train_step == 'train_ssrn':
                mel_gt = sp['data_0'].to(device)
                lin_gt = sp['data_1'].to(device)

                pred_lin_prob = model(mel_gt)

                disc_gt = disc(lin_gt[:, :, 1:33])
                disc_syn = disc(pred_lin_prob[:, :, 1:33])

                if train_target == 'G':
                    loss_l1 = torch.mean(torch.abs(lin_gt - pred_lin_prob))
                    loss_bin_div = torch.mean(
                        -lin_gt * torch.log(pred_lin_prob + 1e-8) -
                        (1 - lin_gt) * torch.log(1 - pred_lin_prob + 1e-8))
                    loss_disc = torch.mean(-torch.log(disc_syn + 1e-8))

                    # Same magnitude-matching rescale as the text2mel branch.
                    loss = loss_l1 + loss_bin_div + (loss_l1.item(
                    ) + loss_bin_div.item()) / (loss_disc.item()) * loss_disc
                    loss_iter_G += loss.item()
                    loss_iter_G_onlyfromD += loss_disc.item()
                    loss_train_log_syn.append(loss.item())
                    loss_train_log_syn_onlyfromD.append(loss_disc.item())
                    print('L1:{}, BD:{}, DISC:{}, ALL:{}'.format(
                        str(loss_l1.item()), str(loss_bin_div.item()),
                        str(loss_disc.item()), str(loss.item())))

                if train_target == 'D':
                    loss = torch.mean(-torch.log(disc_gt + 1e-8) -
                                      torch.log(1 - disc_syn + 1e-8))
                    loss_iter_D += loss.item()
                    loss_train_log_disc.append(loss.item())
                    print('DISC: ', loss.item())

            # Only the optimizer of the currently-trained network steps.
            loss.backward()
            if train_target == 'G':
                opt_syn.step()
            if train_target == 'D':
                opt_disc.step()

            print('\n')
            if (iteration % cfg['VAL_EVERY_ITER'] == 0) and iteration > 0:
                print('No.{} VALIDATION'.format(
                    str(iteration // cfg['VAL_EVERY_ITER'])))
                print(
                    'Generator average training loss: ', loss_iter_G /
                    (cfg['VAL_EVERY_ITER'] // (cfg['RATIO'] + 1)))
                print(
                    'Discriminator average training loss: ',
                    loss_iter_D / (cfg['VAL_EVERY_ITER'] //
                                   (cfg['RATIO'] + 1) * cfg['RATIO']))
                loss_train_smooth_log_syn.append(loss_iter_G /
                                                 (cfg['VAL_EVERY_ITER'] //
                                                  (cfg['RATIO'] + 1)))
                loss_train_smooth_log_syn_onlyfromD.append(
                    loss_iter_G_onlyfromD / (cfg['VAL_EVERY_ITER'] //
                                             (cfg['RATIO'] + 1)))
                loss_train_smooth_log_disc.append(
                    loss_iter_D / (cfg['VAL_EVERY_ITER'] //
                                   (cfg['RATIO'] + 1) * cfg['RATIO']))
                loss_iter_G = 0
                loss_iter_G_onlyfromD = 0
                loss_iter_D = 0

                # model.eval()
                disc.eval()
                loss_val_syn, loss_val_syn_onlyfromD, loss_val_disc, loss_train_syn, loss_train_disc = validate(
                    validate_loader, train_loader, gaw, cfg, model, disc,
                    train_step)
                loss_val_log_syn.append(loss_val_syn)
                loss_val_log_syn_onlyfromD.append(loss_val_syn_onlyfromD)
                loss_val_log_disc.append(loss_val_disc)
                # model.train()
                disc.train()

                ## How to decide current best model?
                # "Best" = latest validation loss is the minimum so far.
                if loss_val_log_syn.index(
                        min(loss_val_log_syn)) == len(loss_val_log_syn) - 1:
                    print('Current Best Model!')
                    torch.save(
                        {
                            'epoch':
                            epoch + 1,
                            'iteration':
                            iteration + 1,
                            'model_state_dict':
                            model.module.state_dict()
                            if cfg['MULTI_GPU'] else model.state_dict(),
                            'disc_state_dict':
                            disc.module.state_dict()
                            if cfg['MULTI_GPU'] else disc.state_dict(),
                            # NOTE(review): the two keys below carry a stray
                            # trailing ':' inside the string — confirm before
                            # anything reads them back.
                            'opt_state_dict_syn:':
                            opt_syn.state_dict(),
                            'opt_state_dict_disc:':
                            opt_disc.state_dict(),
                            'loss_val_log_syn':
                            loss_val_log_syn,
                            'loss_val_log_syn_onlyfromD':
                            loss_val_log_syn_onlyfromD,
                            'loss_val_log_disc':
                            loss_val_log_disc,
                            'loss_train_log_syn':
                            loss_train_log_syn,
                            'loss_train_log_syn_onlyfromD':
                            loss_train_log_syn_onlyfromD,
                            'loss_train_log_disc':
                            loss_train_log_disc,
                            'loss_train_smooth_log_syn':
                            loss_train_smooth_log_syn,
                            'loss_train_smooth_log_syn_onlyfromD':
                            loss_train_smooth_log_syn_onlyfromD,
                            'loss_train_smooth_log_disc':
                            loss_train_smooth_log_disc
                        }, current_save_dir +
                        '/{}_best_model.tar.pth'.format(train_step[6:]))

                print(
                    'Generator validation loss of No.{} VALIDATION: {} on val set. {} on train set.'
                    .format(str(iteration // cfg['VAL_EVERY_ITER']),
                            str(loss_val_syn), str(loss_train_syn)))
                print(
                    'Discriminator validation loss of No.{} VALIDATION: {} on val set. {} on train set.'
                    .format(str(iteration // cfg['VAL_EVERY_ITER']),
                            str(loss_val_disc), str(loss_train_disc)))

                # Periodic (non-best) checkpoint, one per validation window.
                torch.save(
                    {
                        'epoch':
                        epoch + 1,
                        'iteration':
                        iteration + 1,
                        'model_state_dict':
                        model.module.state_dict()
                        if cfg['MULTI_GPU'] else model.state_dict(),
                        'disc_state_dict':
                        disc.module.state_dict()
                        if cfg['MULTI_GPU'] else disc.state_dict(),
                        'opt_state_dict_syn:':
                        opt_syn.state_dict(),
                        'opt_state_dict_disc:':
                        opt_disc.state_dict(),
                        'loss_val_log_syn':
                        loss_val_log_syn,
                        'loss_val_log_syn_onlyfromD':
                        loss_val_log_syn_onlyfromD,
                        'loss_val_log_disc':
                        loss_val_log_disc,
                        'loss_train_log_syn':
                        loss_train_log_syn,
                        'loss_train_log_syn_onlyfromD':
                        loss_train_log_syn_onlyfromD,
                        'loss_train_log_disc':
                        loss_train_log_disc,
                        'loss_train_smooth_log_syn':
                        loss_train_smooth_log_syn,
                        'loss_train_smooth_log_syn_onlyfromD':
                        loss_train_smooth_log_syn_onlyfromD,
                        'loss_train_smooth_log_disc':
                        loss_train_smooth_log_disc
                    }, current_save_dir + '/{}_iteration_{}.tar.pth'.format(
                        train_step[6:], str(iteration + 1)))
                print('At iteration {} {} modelsaved at {}.'.format(
                    str(iteration + 1), train_step[6:], current_save_dir))

                if train_step == 'train_text2mel':
                    plot_attention(att=att_mat[0, :, :],
                                   iters=iteration + 1,
                                   fig_dir=fig_dir)

                if cfg['PLOT_CURVE']:
                    losses = {
                        'v_s': loss_val_log_syn,
                        'v_s_o': loss_val_log_syn_onlyfromD,
                        'v_d': loss_val_log_disc,
                        't_s': loss_train_log_syn,
                        't_s_o': loss_train_log_syn_onlyfromD,
                        't_d': loss_train_log_disc,
                        't_s_s': loss_train_smooth_log_syn,
                        't_s_s_o': loss_train_smooth_log_syn_onlyfromD,
                        't_s_d': loss_train_smooth_log_disc
                    }
                    plot_loss(losses, iteration + 1, fig_dir)

            end_iter = time.time()
            iteration += 1
            print('Time elapsed {}s.'.format(str(end_iter - start_iter)))

        epoch += 1
Example no. 6
0
def ordinary_train(train_step,
                   train_pattern,
                   cfg,
                   spec_dir=None,
                   resume_checkpoints=None,
                   current_time=None):
    """Train melSyn or SSRN with plain (non-adversarial) reconstruction losses.

    Args:
        train_step: 'train_text2mel' or 'train_ssrn'.
        train_pattern: 'universal' or 'conditional'. 'ubm-finetune' should be
            handled in another function.
        cfg: configuration dict (paths, model dims, Adam hyperparameters,
            batch size, validation cadence, ...).
        spec_dir: optional spectrogram directory forwarded to `dataset`.
        resume_checkpoints: directory of checkpoints to resume from.
        current_time: timestamp string used to name the checkpoint dir.
    """
    checkpoints_dir = cfg['SRC_ROOT_DIR'] + 'checkpoints/'
    current_save_dir = checkpoints_dir + train_pattern + '/not_adversarial/' + current_time
    fig_dir = current_save_dir + '/fig/'
    if not os.path.exists(current_save_dir):
        os.system('mkdir -p ' + current_save_dir)

    # Model import is deferred so the dropout variant can be selected by config.
    if cfg['APPLY_DROPOUT']:
        from models.TTSModel_dropout import melSyn, SSRN
    else:
        from models.TTSModel import melSyn, SSRN

    train_dataset = dataset(cfg=cfg,
                            mode='train',
                            pattern=train_pattern,
                            step=train_step,
                            spec_dir=spec_dir)

    validate_dataset = dataset(cfg=cfg,
                               mode='validate',
                               pattern=train_pattern,
                               step=train_step,
                               spec_dir=spec_dir)

    if train_step == 'train_text2mel':
        # subtract 1 because we merge "'" and '"'.
        model = melSyn(vocab_len=len(cfg['VOCABULARY']) - 1,
                       condition=(train_pattern == 'conditional'),
                       spkemb_dim=cfg['SPK_EMB_DIM'],
                       textemb_dim=cfg['TEXT_EMB_DIM'],
                       freq_bins=cfg['COARSE_MELSPEC']['FREQ_BINS'],
                       hidden_dim=cfg['HIDDEN_DIM'])
        if cfg['MULTI_GPU']:
            model = torch.nn.DataParallel(model)

    if train_step == 'train_ssrn':
        model = SSRN(freq_bins=cfg['COARSE_MELSPEC']['FREQ_BINS'],
                     output_bins=(1 + cfg['STFT']['FFT_LENGTH'] // 2),
                     ssrn_dim=cfg['SSRN_DIM'])
        if cfg['MULTI_GPU']:
            model = torch.nn.DataParallel(model)

    # If train from scratch, initialize recursively.
    if resume_checkpoints is None:
        model.apply(init_weights)
        epoch = 0
        iteration = 0
        print('CUDA available: ', torch.cuda.is_available())
        model.to(device)
        optimizer = optim.Adam(model.parameters(), cfg['ADAM']['ALPHA'],
                               (cfg['ADAM']['BETA_1'], cfg['ADAM']['BETA_2']),
                               cfg['ADAM']['EPSILON'])
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.3, patience=1, verbose=True, min_lr=1e-8)
        loss_val_log = []
    else:
        # load checkpoint.
        print('CUDA available: ', torch.cuda.is_available())
        checkpoint = torch.load(resume_checkpoints)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        optimizer = optim.Adam(model.parameters(), cfg['ADAM']['ALPHA'],
                               (cfg['ADAM']['BETA_1'], cfg['ADAM']['BETA_2']),
                               cfg['ADAM']['EPSILON'])

        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        epoch = checkpoint['epoch']
        iteration = checkpoint['iteration']
        loss_val_log = checkpoint['loss_val_log']

    train_loader = DataLoader(train_dataset,
                              cfg['BATCH_SIZE'],
                              shuffle=True,
                              num_workers=4,
                              collate_fn=collate_pad_3 if train_step
                              == 'train_text2mel' else collate_pad_2)
    validate_loader = DataLoader(validate_dataset,
                                 batch_size=8,
                                 shuffle=True,
                                 num_workers=2,
                                 collate_fn=collate_pad_3 if train_step
                                 == 'train_text2mel' else collate_pad_2)

    # guided attention weights.
    gaw = guided_attention_mat(cfg['MAX_TEXT_LEN'], cfg['MAX_FRAME_NUM'])

    # check model's device
    print('Model device:', next(model.parameters()).device)
    # Running training-loss accumulator, reset after each validation window.
    loss_iter = 0

    while epoch < cfg['MAX_EPOCHS']:
        #start_epoch = time.time()
        print('Epoch ', epoch + 1)
        print('*******************')
        # loss_epoch = 0
        loader_len = len(train_loader)

        for i, sp in enumerate(train_loader):
            # Training.
            start_iter = time.time()
            optimizer.zero_grad()

            if train_step == 'train_text2mel':
                mel_gt = sp['data_0'].to(device)
                text_id = sp['data_1'].to(device)
                spk_emb = sp['data_2'].to(device)

                # Teacher forcing: feed the target shifted right by one frame.
                spec_inputs = torch.cat(
                    (torch.zeros_like(mel_gt[:, :, :1]), mel_gt[:, :, :-1]),
                    dim=-1)
                pred_mel_prob, att_mat = model(spec_inputs, text_id, spk_emb)

                # Loss.
                loss_l1 = torch.mean(torch.abs(mel_gt - pred_mel_prob))
                loss_bin_div = torch.mean(
                    -mel_gt * torch.log(pred_mel_prob + 1e-8) -
                    (1 - mel_gt) * torch.log(1 - pred_mel_prob + 1e-8))
                # Pad the attention matrix to max size with -1 so padded cells
                # can be masked out of the guided-attention loss below.
                att_aug = F.pad(att_mat,
                                (0, cfg['MAX_FRAME_NUM'] - att_mat.size()[-1],
                                 0, cfg['MAX_TEXT_LEN'] - att_mat.size()[-2]),
                                value=-1)
                # Here gaw will broadcast along axis 0.
                loss_att = torch.sum(
                    torch.ne(att_aug, -1).float() * att_aug * gaw) / torch.sum(
                        torch.ne(att_aug, -1).float())

                loss = loss_l1 + loss_bin_div + loss_att
                loss.backward()
                optimizer.step()
                loss_iter += loss.item()
                print(
                    'Iteration {}/{}'.format(str(i + 1), str(loader_len)),
                    ' for epoch {}, loss: {} {} {} {}'.format(
                        str(epoch + 1), str(loss_l1.item()),
                        str(loss_bin_div.item()), str(loss_att.item()),
                        str(loss.item())),
                    'global iteration {}'.format(str(iteration + 1)))

            if train_step == 'train_ssrn':
                mel_gt = sp['data_0'].to(device)
                lin_gt = sp['data_1'].to(device)

                pred_lin_prob = model(mel_gt)

                # Loss.
                loss_l1 = torch.mean(torch.abs(lin_gt - pred_lin_prob))
                loss_bin_div = torch.mean(
                    -lin_gt * torch.log(pred_lin_prob + 1e-8) -
                    (1 - lin_gt) * torch.log(1 - pred_lin_prob + 1e-8))

                loss = loss_l1 + loss_bin_div
                loss.backward()
                optimizer.step()
                loss_iter += loss.item()
                print(
                    'Iteration {}/{}'.format(str(i + 1), str(loader_len)),
                    ' for epoch {}, loss: {} {} {}'.format(
                        str(epoch + 1), str(loss_l1.item()),
                        str(loss_bin_div.item()), str(loss.item())),
                    'global iteration {}'.format(str(iteration + 1)))

            if (iteration % cfg['VAL_EVERY_ITER'] == 0) and iteration > 0:
                print('\n')
                print('No.{} VALIDATION'.format(
                    str(iteration // cfg['VAL_EVERY_ITER'])))
                print('Average training loss: ',
                      loss_iter / cfg['VAL_EVERY_ITER'])
                loss_iter = 0

                model.eval()
                loss_val, loss_val_train = validate(loader=validate_loader,
                                                    trainloader=train_loader,
                                                    gaw=gaw,
                                                    cfg=cfg,
                                                    model=model,
                                                    train_step=train_step)
                loss_val_log.append(loss_val)
                model.train()

                # "Best" = latest validation loss is the minimum so far.
                if loss_val_log.index(
                        min(loss_val_log)) == len(loss_val_log) - 1:
                    print('Current Best Model!')
                    torch.save(
                        {
                            'epoch':
                            epoch + 1,
                            'iteration':
                            iteration + 1,
                            'model_state_dict':
                            model.module.state_dict()
                            if cfg['MULTI_GPU'] else model.state_dict(),
                            'optimizer_state_dict':
                            optimizer.state_dict(),
                            'loss_val_log':
                            loss_val_log
                        }, current_save_dir +
                        '/{}.tar.pth'.format(train_step[6:] + '_best_model'))

                print(
                    'Validation loss of No.{} validation: {} on validation set. {} on train set.'
                    .format(str(iteration // cfg['VAL_EVERY_ITER']),
                            str(loss_val), str(loss_val_train)))

                # Periodic (non-best) checkpoint, one per validation window.
                torch.save(
                    {
                        'epoch':
                        epoch + 1,
                        'iteration':
                        iteration + 1,
                        'model_state_dict':
                        model.module.state_dict()
                        if cfg['MULTI_GPU'] else model.state_dict(),
                        'optimizer_state_dict':
                        optimizer.state_dict(),
                        'loss_val_log':
                        loss_val_log
                    }, current_save_dir + '/{}_iteration_{}.tar.pth'.format(
                        train_step[6:], str(iteration + 1)))
                print(
                    'At iteration ', iteration + 1,
                    '{} model saved at {}'.format(train_step[6:],
                                                  current_save_dir))

                if train_step == 'train_text2mel':
                    plot_attention(att=att_mat[0, :, :],
                                   iters=iteration + 1,
                                   fig_dir=fig_dir)

            end_iter = time.time()
            iteration += 1
            print('Time elapsed {}s'.format(str(end_iter - start_iter)))

        epoch += 1
Example no. 7
0
# Constructor arguments for the CNN classifier built below.
kwargs = {
    'batch_size': args.batch_size,
    'network': cnn.feat_2D,
    'observation_dim': obs_dim,
    'optimizer': tf.train.AdamOptimizer,
    "num_labels": 2,
    "image_ch_dim": img_ch
}

# Dataset construction options (`args` presumably comes from an argparse
# namespace defined elsewhere — confirm).
ds_args = dict()
ds_args['design'] = args.design
ds_args['performance'] = design_performance[args.design]
ds_args['balance'] = args.balance
ds_args['alpha'] = args.alpha
ds_args['nofeat'] = args.nofeat
ds = dataset.dataset(ds_args)

classifier = Classifier(**kwargs)

# Consumed with next() below, so nextBatch returns an iterator/generator.
training_batch = ds.nextBatch(args.batch_size)

if args.load_weights:
    classifier.load_weights(args.load_weights)

bestAcc = 0

for epoch in range(args.epochs):
    trainingLoss = 0
    # Confusion-matrix counters: true/false positives and negatives.
    tp, tn, fp, fn = 0, 0, 0, 0
    for _ in range(args.updates_per_epoch):
        x, label = next(training_batch)
Example no. 8
0
def adversarial_train(train_step,
                      train_pattern,
                      cfg,
                      spec_dir=None,
                      checkpoints_dir=None,
                      resume_checkpoints=None,
                      is_parallel=False,
                      dropout=False,
                      current_time=None):
    """Adversarially train one stage of the TTS pipeline.

    Alternates discriminator ("critic") and generator updates: the critic
    loss ``mean(disc_syn - disc_gt)`` plus weight clipping
    (``disc.apply(clip_weights)``) follows the Wasserstein-GAN
    weight-clipping scheme, and ``wd_log`` tracks the negated critic loss
    as a Wasserstein-distance estimate. Models, optimizers and loss logs
    are checkpointed every ``cfg['VAL_EVERY_ITER']`` iterations.

    Args:
        train_step: ``'train_text2mel'`` (mel synthesis + attention) or
            ``'train_ssrn'`` (spectrogram super-resolution); selects the
            generator/discriminator pair and the loss terms.
        train_pattern: e.g. ``'conditional'``/``'universal'``;
            ``'conditional'`` enables speaker conditioning in melSyn.
        cfg: configuration dict (data paths, STFT/mel params, optimizer
            and training-loop hyper-parameters).
        spec_dir: optional directory of precomputed spectrograms.
        checkpoints_dir: root directory under which checkpoints are saved.
        resume_checkpoints: path to a checkpoint to resume from.
            NOTE(review): resuming is NOT implemented -- the ``else``
            branch below is a bare ``pass``, so any non-None value leads
            to a NameError on ``epoch``/``opt_syn``/... later on.
        is_parallel: wrap the models in ``torch.nn.DataParallel``.
        dropout: import the dropout variants of the TTS models.
        current_time: timestamp string used to name the save directory.
    """

    current_save_dir = checkpoints_dir + train_pattern + '/adversarial/' + current_time
    fig_dir = current_save_dir + '/fig/'
    if not os.path.exists(current_save_dir):
        os.system('mkdir -p ' + current_save_dir)

    # The dropout variants are a drop-in replacement for the base models.
    if dropout:
        from models.TTSModel_dropout import melSyn, SSRN
    else:
        from models.TTSModel import melSyn, SSRN

    from models.discriminator import melDisc, linDisc

    train_dataset = dataset(
        root_dir=cfg['DATA_ROOT_DIR'],
        spkemb_dir=cfg['SPK_EMB_DIR'],
        vocabulary=cfg['VOCABULARY'],
        n_mels=cfg['COARSE_MELSPEC']['FREQ_BINS'],
        time_frame_reduction=cfg['COARSE_MELSPEC']['REDUCTION'],
        preemphasis=cfg['PREEMPH'],
        n_fft=cfg['STFT']['FFT_LENGTH'],
        hop=cfg['STFT']['HOP_LENGTH'],
        norm_power=cfg['NORM_POWER']['ANALYSIS'],
        mode='train',
        pattern=train_pattern,
        step=train_step,
        spec_dir=spec_dir)

    # Same configuration as the training set, but mode='validate'.
    # NOTE(review): validate_loader is built below but the validation pass
    # itself is commented out; the loader is currently unused.
    validate_dataset = dataset(
        root_dir=cfg['DATA_ROOT_DIR'],
        spkemb_dir=cfg['SPK_EMB_DIR'],
        vocabulary=cfg['VOCABULARY'],
        n_mels=cfg['COARSE_MELSPEC']['FREQ_BINS'],
        time_frame_reduction=cfg['COARSE_MELSPEC']['REDUCTION'],
        preemphasis=cfg['PREEMPH'],
        n_fft=cfg['STFT']['FFT_LENGTH'],
        hop=cfg['STFT']['HOP_LENGTH'],
        norm_power=cfg['NORM_POWER']['ANALYSIS'],
        mode='validate',
        pattern=train_pattern,
        step=train_step,
        spec_dir=spec_dir)

    if train_step == 'train_text2mel':
        # subtract 1 because we merge "'" and '"'.
        model = melSyn(vocab_len=len(cfg['VOCABULARY']) - 1,
                       condition=(train_pattern == 'conditional'),
                       spkemb_dim=cfg['SPK_EMB_DIM'],
                       textemb_dim=cfg['TEXT_EMB_DIM'],
                       freq_bins=cfg['COARSE_MELSPEC']['FREQ_BINS'],
                       hidden_dim=cfg['HIDDEN_DIM'])

        disc = melDisc(freq_bins=cfg['COARSE_MELSPEC']['FREQ_BINS'],
                       disc_dim=cfg['DISC_DIM'])

        if is_parallel:
            model = torch.nn.DataParallel(model)
            disc = torch.nn.DataParallel(disc)

    # NOTE(review): if train_step is neither recognized value, `model` and
    # `disc` are never bound and the code below raises NameError.
    if train_step == 'train_ssrn':
        # SSRN upsamples coarse mels to linear spectrogram bins.
        model = SSRN(freq_bins=cfg['COARSE_MELSPEC']['FREQ_BINS'],
                     output_bins=(1 + cfg['STFT']['FFT_LENGTH'] // 2),
                     ssrn_dim=cfg['SSRN_DIM'])

        disc = linDisc(freq_bins=(1 + cfg['STFT']['FFT_LENGTH'] // 2),
                       disc_dim=cfg['DISC_DIM'])

        if is_parallel:
            model = torch.nn.DataParallel(model)
            disc = torch.nn.DataParallel(disc)

    # If train from scratch, initialize recursively.
    if resume_checkpoints is None:
        model.apply(init_weights)
        disc.apply(init_weights)
        epoch = 0
        iteration = 0
        print('CUDA available: ', torch.cuda.is_available())
        model.to(device)
        disc.to(device)
        # RMSprop reuses the Adam learning rate from the config.
        opt_syn = optim.RMSprop(model.parameters(),
                                lr=cfg['ADAM']['ALPHA'],
                                eps=1e-6)
        opt_disc = optim.RMSprop(disc.parameters(),
                                 lr=cfg['ADAM']['ALPHA'],
                                 eps=1e-6)
        # opt_syn = optim.Adam(model.parameters(), cfg['ADAM']['ALPHA'], (cfg['ADAM']['BETA_1'], cfg['ADAM']['BETA_2']), cfg['ADAM']['EPSILON'])
        # opt_disc = optim.Adam(disc.parameters(), cfg['ADAM']['ALPHA'], (cfg['ADAM']['BETA_1'], cfg['ADAM']['BETA_2']), cfg['ADAM']['EPSILON'])
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.3, patience=1, verbose=True, min_lr=1e-8)
        # loss_val_log_syn = []
        # loss_val_log_syn_onlyfromD = []
        # loss_val_log_disc = []
        wd_log = []
        loss_train_log_syn = []
        loss_train_log_syn_onlyfromD = []
        loss_train_log_disc = []
        loss_train_smooth_log_syn = []
        loss_train_smooth_log_syn_onlyfromD = []
        loss_train_smooth_log_disc = []
    else:
        # NOTE(review): checkpoint resuming is unimplemented; calling with
        # resume_checkpoints set fails below (epoch/opt_syn/... undefined).
        pass

    # text2mel batches carry 3 tensors (mel, text, spkemb); ssrn batches
    # carry 2 (mel, linear) -- hence the different collate functions.
    train_loader = DataLoader(train_dataset,
                              cfg['BATCH_SIZE'],
                              shuffle=True,
                              num_workers=4,
                              collate_fn=collate_pad_3 if train_step
                              == 'train_text2mel' else collate_pad_2)
    validate_loader = DataLoader(validate_dataset,
                                 batch_size=8,
                                 shuffle=True,
                                 num_workers=2,
                                 collate_fn=collate_pad_3 if train_step
                                 == 'train_text2mel' else collate_pad_2)

    # guided attention weights.
    gaw = guided_attention_mat(cfg['MAX_TEXT_LEN'], cfg['MAX_FRAME_NUM'])

    # check model's device
    print('Model G Device:', next(model.parameters()).device)
    print('Model D Device:', next(disc.parameters()).device)

    # Running loss accumulators, reset every cfg['VAL_EVERY_ITER'] iterations.
    loss_iter_G = 0
    loss_iter_G_onlyfromD = 0
    loss_iter_D = 0

    while epoch < cfg['MAX_EPOCHS']:
        print('Epoch ', epoch + 1)
        print('*******************')
        loader_len = len(train_loader)

        for i, sp in enumerate(train_loader):
            # Training
            start_iter = time.time()
            opt_syn.zero_grad()
            opt_disc.zero_grad()

            # One generator update every (RATIO + 1) iterations; the other
            # RATIO iterations update the critic (WGAN n_critic schedule).
            train_target = 'D' if iteration % (cfg['RATIO'] + 1) else 'G'
            print('Iteration {}/{} for epoch {}, training {}'.format(
                str(i + 1), str(loader_len), str(epoch + 1), train_target))
            print('Global iteration ', iteration + 1)

            if train_step == 'train_text2mel':
                mel_gt = sp['data_0'].to(device)
                text_id = sp['data_1'].to(device)
                spk_emb = sp['data_2'].to(device)
                B, C, T = mel_gt.shape

                # Teacher forcing: feed ground-truth mels shifted right by
                # one frame (zero frame prepended) as the decoder input.
                spec_inputs = torch.cat(
                    (torch.zeros_like(mel_gt[:, :, :1]), mel_gt[:, :, :-1]),
                    dim=-1)
                pred_mel_prob, att_mat = model(spec_inputs, text_id, spk_emb)

                if train_target == 'G':
                    # The critic only scores a fixed 8-frame slice (1:9).
                    disc_syn = disc(pred_mel_prob[:, :, 1:9])

                    loss_l1 = torch.mean(torch.abs(mel_gt - pred_mel_prob))
                    # Binary divergence (element-wise cross-entropy) between
                    # ground truth and predicted mel "probabilities".
                    loss_bin_div = torch.mean(
                        -mel_gt * torch.log(pred_mel_prob + 1e-8) -
                        (1 - mel_gt) * torch.log(1 - pred_mel_prob + 1e-8))
                    # Pad attention to the max grid with -1 so padded cells
                    # can be masked out of the guided-attention loss.
                    att_aug = F.pad(
                        att_mat, (0, cfg['MAX_FRAME_NUM'] - att_mat.size()[-1],
                                  0, cfg['MAX_TEXT_LEN'] - att_mat.size()[-2]),
                        value=-1)
                    # Here gaw will broadcast along axis 0.
                    loss_att = torch.sum(
                        torch.ne(att_aug, -1).float() * att_aug *
                        gaw) / torch.sum(torch.ne(att_aug, -1).float())
                    # Generator adversarial term: maximize critic score.
                    loss_disc = torch.mean(-disc_syn)

                    loss = loss_l1 + loss_bin_div + loss_att + loss_disc
                    loss_iter_G += loss.item()
                    loss_iter_G_onlyfromD += loss_disc.item()
                    loss_train_log_syn.append(loss.item())
                    loss_train_log_syn_onlyfromD.append(loss_disc.item())
                    loss.backward()
                    opt_syn.step()
                    # NOTE(review): format string has typos ('BD{}', 'ALL{}'
                    # missing colons); left untouched -- it is runtime output.
                    print('L1:{}, BD{}, ATT:{}, DISC:{}, ALL{}'.format(
                        str(loss_l1.item()), str(loss_bin_div.item()),
                        str(loss_att.item()), str(loss_disc.item()),
                        str(loss.item())))

                if train_target == 'D':
                    # coeff = torch.stack(8*[torch.stack(C*[torch.rand(B)],dim=1)],dim=2).to(device)
                    # input_mid = coeff*mel_gt[:, :, 1:9].detach() + (1-coeff)*pred_mel_prob[:, :, 1:9].detach()
                    # input_mid.requires_grad = True
                    # output_mid = disc(input_mid)
                    # # output_mid.backward(retain_graph=True, create_graph=True)
                    # gradients = torch.autograd.grad(outputs=output_mid, inputs=input_mid,  grad_outputs=torch.ones(output_mid.size()).to(device), retain_graph=True, create_graph=True)[0]
                    # loss_gp = torch.mean(cfg['LAMBDA']*(torch.norm(gradients,p=2,dim=(1,2))-1)**2)
                    # # opt_disc.zero_grad()
                    # loss_gp.backward()

                    # Detach so critic gradients do not reach the generator.
                    input_gt = mel_gt[:, :, 1:9].detach()
                    input_syn = pred_mel_prob[:, :, 1:9].detach()
                    disc_gt = disc(input_gt)
                    disc_syn = disc(input_syn)
                    # Wasserstein critic loss; clip weights to keep the
                    # critic (approximately) Lipschitz.
                    loss_D = torch.mean(disc_syn - disc_gt)
                    loss_D.backward()
                    opt_disc.step()
                    disc.apply(clip_weights)
                    # for p in disc.parameters():
                    # 	p.data.clamp_(-0.01, 0.01)

                    loss = loss_D.item()
                    loss_iter_D += loss
                    loss_train_log_disc.append(loss)
                    # Negated critic loss approximates the Wasserstein distance.
                    wd_log.append(-loss)
                    print('DISC:{}, WD:{}'.format(str(loss), str(-loss)))

            if train_step == 'train_ssrn':
                mel_gt = sp['data_0'].to(device)
                lin_gt = sp['data_1'].to(device)
                B, C, T = lin_gt.shape

                pred_lin_prob = model(mel_gt)

                if train_target == 'G':
                    # Critic scores a fixed 32-frame slice of the linear spec.
                    disc_syn = disc(pred_lin_prob[:, :, 1:33])

                    loss_l1 = torch.mean(torch.abs(lin_gt - pred_lin_prob))
                    loss_bin_div = torch.mean(
                        -lin_gt * torch.log(pred_lin_prob + 1e-8) -
                        (1 - lin_gt) * torch.log(1 - pred_lin_prob + 1e-8))
                    loss_disc = torch.mean(-disc_syn)

                    # No attention loss for SSRN (no attention mechanism).
                    loss = loss_l1 + loss_bin_div + loss_disc
                    loss_iter_G += loss.item()
                    loss_iter_G_onlyfromD += loss_disc.item()
                    loss_train_log_syn.append(loss.item())
                    loss_train_log_syn_onlyfromD.append(loss_disc.item())
                    loss.backward()
                    opt_syn.step()
                    print('L1:{}, BD:{}, DISC:{}, ALL:{}'.format(
                        str(loss_l1.item()), str(loss_bin_div.item()),
                        str(loss_disc.item()), str(loss.item())))

                if train_target == 'D':
                    # coeff = torch.stack(32*[torch.stack(C*[torch.rand(B)],dim=1)],dim=2).to(device)
                    # input_mid = coeff*lin_gt[:, :, 1:33].detach() + (1-coeff)*pred_lin_prob[:, :, 1:33].detach()
                    # input_mid.requires_grad = True
                    # output_mid = disc(input_mid)
                    # # output_mid.backward(retain_graph=True, create_graph=True)
                    # gradients = torch.autograd.grad(outputs=output_mid, inputs=input_mid, grad_outputs=torch.ones(output_mid.size()).to(device), retain_graph=True, create_graph=True)[0]
                    # loss_gp = torch.mean(cfg['LAMBDA']*(torch.norm(gradients,p=2,dim=(1,2))-1)**2)
                    # # opt_disc.zero_grad()
                    # loss_gp.backward()

                    input_gt = lin_gt[:, :, 1:33].detach()
                    input_syn = pred_lin_prob[:, :, 1:33].detach()
                    disc_gt = disc(input_gt)
                    disc_syn = disc(input_syn)
                    loss_D = torch.mean(disc_syn - disc_gt)
                    loss_D.backward()
                    opt_disc.step()
                    disc.apply(clip_weights)
                    # for p in disc.parameters():
                    # 	p.data.clamp_(-0.01, 0.01)

                    loss = loss_D.item()
                    loss_iter_D += loss
                    loss_train_log_disc.append(loss)
                    wd_log.append(-loss)
                    print('DISC:{}, WD:{}'.format(str(loss), str(-loss)))

            print('\n')
            # Periodic reporting + checkpointing (actual validation pass is
            # commented out below).
            if (iteration % cfg['VAL_EVERY_ITER'] == 0) and iteration > 0:
                # print('No.{} VALIDATION'.format(str(iteration//cfg['VAL_EVERY_ITER'])))
                # Averages divide by how many G (resp. D) updates occurred
                # in the last VAL_EVERY_ITER iterations under the schedule.
                print(
                    'Generator average training loss: ', loss_iter_G /
                    (cfg['VAL_EVERY_ITER'] // (cfg['RATIO'] + 1)))
                print(
                    'Discriminator average training loss: ',
                    loss_iter_D / (cfg['VAL_EVERY_ITER'] //
                                   (cfg['RATIO'] + 1) * cfg['RATIO']))
                loss_train_smooth_log_syn.append(loss_iter_G /
                                                 (cfg['VAL_EVERY_ITER'] //
                                                  (cfg['RATIO'] + 1)))
                loss_train_smooth_log_syn_onlyfromD.append(
                    loss_iter_G_onlyfromD / (cfg['VAL_EVERY_ITER'] //
                                             (cfg['RATIO'] + 1)))
                loss_train_smooth_log_disc.append(
                    loss_iter_D / (cfg['VAL_EVERY_ITER'] //
                                   (cfg['RATIO'] + 1) * cfg['RATIO']))
                loss_iter_G = 0
                loss_iter_G_onlyfromD = 0
                loss_iter_D = 0

                # model.eval()
                # disc.eval()
                # loss_val_syn, loss_val_syn_onlyfromD, loss_val_disc, loss_train_syn, loss_train_disc = validate(validate_loader, train_loader, gaw, cfg, model, disc, train_step)
                # loss_val_log_syn.append(loss_val_syn)
                # loss_val_log_syn_onlyfromD.append(loss_val_syn_onlyfromD)
                # loss_val_log_disc.append(loss_val_disc)
                # model.train()
                # disc.train()

                # ## How to decide current best model?
                # if loss_val_log_syn.index(min(loss_val_log_syn)) == len(loss_val_log_syn)-1:
                # 	print('Current Best Model!')
                # 	torch.save({'epoch': epoch+1,
                # 			    'iteration': iteration+1,
                # 			    'model_state_dict': model.module.state_dict() if is_parallel else model.state_dict(),
                # 			    'disc_state_dict': disc.module.state_dict() if is_parallel else disc.state_dict(),
                # 			    'opt_state_dict_syn:': opt_syn.state_dict(),
                # 			    'opt_state_dict_disc:': opt_disc.state_dict(),
                # 			    'loss_val_log_syn': loss_val_log_syn,
                # 			    'loss_val_log_syn_onlyfromD': loss_val_log_syn_onlyfromD,
                # 			    'loss_val_log_disc': loss_val_log_disc,
                # 			    'loss_train_log_syn': loss_train_log_syn,
                # 			    'loss_train_log_syn_onlyfromD': loss_train_log_syn_onlyfromD,
                # 			    'loss_train_log_disc': loss_train_log_disc,
                # 			    'loss_train_smooth_log_syn': loss_train_smooth_log_syn,
                # 			    'loss_train_smooth_log_syn_onlyfromD': loss_train_smooth_log_syn_onlyfromD,
                # 			    'loss_train_smooth_log_disc': loss_train_smooth_log_disc}, current_save_dir+'/{}_best_model.tar.pth'.format(train_step[6:]))

                # print('Generator validation loss of No.{} VALIDATION: {} on val set. {} on train set.'.format(str(iteration//cfg['VAL_EVERY_ITER']), str(loss_val_syn), str(loss_train_syn)))
                # print('Discriminator validation loss of No.{} VALIDATION: {} on val set. {} on train set.'.format(str(iteration//cfg['VAL_EVERY_ITER']), str(loss_val_disc), str(loss_train_disc)))

                # Unwrap DataParallel so checkpoints load without it too.
                # NOTE(review): the 'opt_state_dict_syn:'/'opt_state_dict_disc:'
                # keys contain a stray trailing ':'; any resume code must use
                # the same (typo'd) keys.
                torch.save(
                    {
                        'epoch':
                        epoch + 1,
                        'iteration':
                        iteration + 1,
                        'model_state_dict':
                        model.module.state_dict()
                        if is_parallel else model.state_dict(),
                        'disc_state_dict':
                        disc.module.state_dict()
                        if is_parallel else disc.state_dict(),
                        'opt_state_dict_syn:':
                        opt_syn.state_dict(),
                        'opt_state_dict_disc:':
                        opt_disc.state_dict(),
                        # 'loss_val_log_syn': loss_val_log_syn,
                        # 'loss_val_log_syn_onlyfromD': loss_val_log_syn_onlyfromD,
                        # 'loss_val_log_disc': loss_val_log_disc,
                        'wd_log':
                        wd_log,
                        'loss_train_log_syn':
                        loss_train_log_syn,
                        'loss_train_log_syn_onlyfromD':
                        loss_train_log_syn_onlyfromD,
                        'loss_train_log_disc':
                        loss_train_log_disc,
                        'loss_train_smooth_log_syn':
                        loss_train_smooth_log_syn,
                        'loss_train_smooth_log_syn_onlyfromD':
                        loss_train_smooth_log_syn_onlyfromD,
                        'loss_train_smooth_log_disc':
                        loss_train_smooth_log_disc
                    },
                    current_save_dir + '/{}_iteration_{}.tar.pth'.format(
                        train_step[6:], str(iteration + 1)))
                print('At iteration {} {} modelsaved at {}.'.format(
                    str(iteration + 1), train_step[6:], current_save_dir))

                # Snapshot the attention alignment for text2mel runs.
                if train_step == 'train_text2mel':
                    plot_attention(att=att_mat[0, :, :],
                                   iters=iteration + 1,
                                   fig_dir=fig_dir)

                # Abbreviated keys: wd = Wasserstein distance, t_s = train
                # synth, *_o = adversarial-only term, t_s_s* = smoothed.
                losses = {
                    'wd': wd_log,
                    't_s': loss_train_log_syn,
                    't_s_o': loss_train_log_syn_onlyfromD,
                    't_d': loss_train_log_disc,
                    't_s_s': loss_train_smooth_log_syn,
                    't_s_s_o': loss_train_smooth_log_syn_onlyfromD,
                    't_s_d': loss_train_smooth_log_disc
                }
                plot_loss(losses, iteration + 1, fig_dir)

            end_iter = time.time()
            iteration += 1
            print('Time elapsed {}s.'.format(str(end_iter - start_iter)))

        epoch += 1
Ejemplo n.º 9
0
def synthesize(pattern, cfg, spec_dir, current_time=None):
    """
    Args:
    --pattern: 'universal' or 'conditional'.
    --cfg: configuration file.
    --spec_dir: None or Directory of saved spectrograms.
    --model_text2mel: Trained model of text2mel Network.
    --model_ssrn: Trained model of ssrn Network.
    """

    sample_dir = cfg['SRC_ROOT_DIR'] + 'samples/' + current_time + '/'
    fig_dir = sample_dir + 'fig/'

    if not os.path.exists(fig_dir):
        os.system('mkdir -p '+fig_dir)

    if cfg['APPLY_DROPOUT']:
        from models.TTSModel_dropout import melSyn, SSRN
    else:
        from models.TTSModel import melSyn, SSRN

    synthesize_dataset = dataset(cfg=cfg, mode='synthesize', pattern=pattern, step='synthesize', spec_dir=spec_dir)

    m1 = melSyn(vocab_len=len(cfg['VOCABULARY'])-1,
                condition = (pattern == 'conditional'),
                spkemb_dim=cfg['SPK_EMB_DIM'],
                textemb_dim=cfg['TEXT_EMB_DIM'],
                freq_bins=cfg['COARSE_MELSPEC']['FREQ_BINS'],
                hidden_dim=cfg['HIDDEN_DIM'])

    m2 = SSRN(freq_bins=cfg['COARSE_MELSPEC']['FREQ_BINS'],
              output_bins=(1+cfg['STFT']['FFT_LENGTH']//2),
              ssrn_dim=cfg['SSRN_DIM'])

    if cfg['MULTI_GPU']:
        m1 = torch.nn.DataParallel(m1)
        m2 = torch.nn.DataParallel(m2)

    print('CUDA available: ', torch.cuda.is_available())
    ckp1 = torch.load(cfg['INFERENCE_TEXT2MEL_MODEL'])
    ckp2 = torch.load(cfg['INFERENCE_SSRN_MODEL'])
    m1.load_state_dict(ckp1['model_state_dict'])
    m2.load_state_dict(ckp2['model_state_dict'])
    m1.to(device)
    m1.eval()    
    m2.to(device)
    m2.eval()

    synthesize_loader = DataLoader(synthesize_dataset, batch_size=8, shuffle=False, num_workers=2, collate_fn=collate_pad_4)

    gaw = guided_attention_mat(cfg['MAX_TEXT_LEN'], cfg['MAX_FRAME_NUM'])

    loss_avg_t2m = 0
    loss_avg_ssrn = 0

    with torch.no_grad():
        for i, sp in enumerate(synthesize_loader):
            mel_gt = sp['data_0'].to(device)
            text_id = sp['data_1'].to(device)
            spk_emb = sp['data_2'].to(device)
            lin_gt = sp['data_3'].to(device)

            d1, d2, d3 = mel_gt.shape
            init_frame = torch.zeros((d1, d2, 1)).to(device)
            Y, A, prev_maxatt, K, V = m1(melspec=init_frame, textid=text_id, spkemb=spk_emb, pma=torch.zeros((d1,)).long().to(device))
            inputs = torch.cat((init_frame, Y), dim=-1)
            for frame in range(d3-1):
                Y, A, prev_maxatt = m1(melspec=inputs, textid=None, spkemb=spk_emb, K=K, V=V, A_last=A, pma=prev_maxatt)
                inputs = torch.cat((inputs, Y[:, :, -1:]), dim=-1)

            plot_attention(att=A[0, :, :], idx=i+1, fig_dir=fig_dir)

            loss_l1_t2m = torch.mean(torch.abs(mel_gt-Y))
            loss_bin_div_t2m = torch.mean(-mel_gt*torch.log(Y+1e-8)-(1-mel_gt)*torch.log(1-Y+1e-8))
            A_aug = F.pad(A, (0, cfg['MAX_FRAME_NUM']-A.size()[-1], 0, cfg['MAX_TEXT_LEN']-A.size()[-2]), value=-1)
            loss_att = torch.sum(torch.ne(A_aug, -1).float()*A_aug*gaw) / torch.sum(torch.ne(A_aug, -1).float())

            loss_t2m = loss_l1_t2m + loss_bin_div_t2m + loss_att
            loss_avg_t2m += loss_t2m.item()
            print('syn set text2mel loss: {} {} {} {}'.format(str(loss_l1_t2m.item()), str(loss_bin_div_t2m.item()), str(loss_att.item()), str(loss_t2m.item())))

            pred_lin_prob = m2(Y)
            loss_l1_ssrn = torch.mean(torch.abs(lin_gt-pred_lin_prob))
            loss_bin_div_ssrn = torch.mean(-lin_gt*torch.log(pred_lin_prob+1e-8)-(1-lin_gt)*torch.log(1-pred_lin_prob+1e-8))

            loss_ssrn = loss_l1_ssrn + loss_bin_div_ssrn
            loss_avg_ssrn += loss_ssrn.item()
            print('syn set ssrn loss: {} {} {}'.format(str(loss_l1_ssrn.item()), str(loss_bin_div_ssrn.item()), str(loss_ssrn.item())))

            if torch.cuda.is_available():
                pred_lin_prob = pred_lin_prob.cpu()
            pred_lin = pred_lin_prob.numpy()

            if cfg['LOG_FEATURE']:
                pred_lin = pred_lin*cfg['MAX_DB'] - cfg['MAX_DB'] + cfg['REF_DB']
                pred_lin = np.power(10, 0.05*pred_lin)
    
            for k in range(d1):
                # print(pred_lin[k, :, :])
                # print(np.max(np.max(pred_lin[k, :, :])))
                if not cfg['LOG_FEATURE']:
                    pred_lin[k, :, :] = pred_lin[k, :, :]/np.max(pred_lin[k, :, :])
                spec = pred_lin[k, :, :]**(cfg['NORM_POWER']['RECONSTRUCTION']/cfg['NORM_POWER']['ANALYSIS'])
                time_signal = librosa.core.griffinlim(S=spec, n_iter=64, hop_length=cfg['STFT']['HOP_LENGTH'], win_length=cfg['STFT']['FFT_LENGTH'])
                time_signal = signal.lfilter([1], [1, -cfg['PREEMPH']], time_signal)
                #print(time_signal, np.max(time_signal))
                librosa.output.write_wav(sample_dir+'S{}_B{}.wav'.format(str(k+1), str(i+1)), time_signal if cfg['LOG_FEATURE'] else time_signal/np.max(time_signal)*0.75, cfg['SAMPLING_RATE'])