Example #1
0
def main():
    """Command-line entry point: train a Tacotron TTS model.

    Parses CLI options, builds the DataLoader and Tacotron model, then runs
    a (truncated-BPTT) training loop with Adam on summed L1 losses over the
    mel- and linear-spectrogram decoder outputs.  Loss curves are printed /
    plotted periodically and checkpoints are written to ``args.save_dir``.
    """
    parser = argparse.ArgumentParser(description='training script')
    # data load
    parser.add_argument('--data', type=str, default='blizzard', help='blizzard / etri')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--text_limit', type=int, default=1000, help='maximum length of text to include in training set')
    parser.add_argument('--wave_limit', type=int, default=1400, help='maximum length of spectrogram to include in training set')
    parser.add_argument('--trunc_size', type=int, default=700, help='used for truncated-BPTT when memory is not enough.')
    parser.add_argument('--shuffle_data', type=int, default=1, help='whether to shuffle data loader')
    parser.add_argument('--load_queue_size', type=int, default=8, help='maximum number of batches to load on the memory')
    parser.add_argument('--n_workers', type=int, default=2, help='number of workers used in data loader')
    # model
    parser.add_argument('--charvec_dim', type=int, default=256, help='character embedding dimension')
    parser.add_argument('--hidden_size', type=int, default=128, help='hidden state size')
    parser.add_argument('--dec_out_size', type=int, default=80, help='decoder output size')
    parser.add_argument('--post_out_size', type=int, default=1025, help='should be n_fft / 2 + 1(check n_fft from "input_specL" ')
    parser.add_argument('--num_filters', type=int, default=16, help='number of filters in filter bank of CBHG')
    parser.add_argument('--r_factor', type=int, default=5, help='reduction factor(# of multiple output)')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout probability')
    # optimization
    parser.add_argument('--max_epochs', type=int, default=100000, help='maximum epoch to train')
    parser.add_argument('--grad_clip', type=float, default=1, help='gradient clipping')
    parser.add_argument('--learning_rate', type=float, default=1e-3, help='2e-3 from Ito, I used to use 5e-4')
    parser.add_argument('--lr_decay_every', type=int, default=25000, help='decay learning rate every...')
    parser.add_argument('--lr_decay_factor', type=float, default=0.5, help='decay learning rate by this factor')
    parser.add_argument('--teacher_forcing_ratio', type=float, default=1, help='value between 0~1, use this for scheduled sampling')
    # loading
    parser.add_argument('--init_from', type=str, default='', help='load parameters from...')
    parser.add_argument('--resume', type=int, default=0, help='1 for resume from saved epoch')
    # misc
    parser.add_argument('--exp_no', type=int, default=0, help='experiment number (used in output file names)')
    parser.add_argument('--print_every', type=int, default=-1, help='-1: default to once per epoch')
    parser.add_argument('--plot_every', type=int, default=-1, help='-1: default to print_every')
    parser.add_argument('--save_every', type=int, default=-1, help='-1: default to every 10 epochs')
    parser.add_argument('--save_dir', type=str, default='checkpoint', help='checkpoint/plot output directory')
    parser.add_argument('--pinned_memory', type=int, default=1, help='1 to use pinned memory')
    parser.add_argument('--gpu', type=int, nargs='+', help='index of gpu machines to run')
    # debug
    parser.add_argument('--debug', type=int, default=0, help='1 for debug mode')
    args = parser.parse_args()

    torch.manual_seed(0)

    # set dataset option (paths are machine-specific)
    if args.data == 'blizzard':
        args.dir_bin = '/home/lyg0722/TTS_corpus/blizzard/segmented/bin/'
    elif args.data == 'etri':
        args.dir_bin = '/data2/lyg0722/TTS_corpus/etri/bin/'
    else:
        print('no dataset')
        return

    if args.gpu is None:
        args.use_gpu = False
        args.gpu = []
    else:
        args.use_gpu = True
        torch.cuda.manual_seed(0)
        torch.cuda.set_device(args.gpu[0])      # first listed GPU is the primary device

    loader = DataLoader(args)

    # set misc options
    args.vocab_size = loader.get_num_vocab()
    if args.print_every == -1:
        args.print_every = loader.iter_per_epoch
    if args.plot_every == -1:
        args.plot_every = args.print_every
    if args.save_every == -1:
        args.save_every = loader.iter_per_epoch * 10    # save every 10 epoch by default

    model = Tacotron(args)
    # NOTE(review): --lr_decay_every / --lr_decay_factor are parsed but never
    # applied anywhere in this loop -- confirm whether decay was intended.
    model_optim = optim.Adam(model.parameters(), args.learning_rate)
    criterion_mel = nn.L1Loss(size_average=False)   # summed L1; normalized manually below
    criterion_lin = nn.L1Loss(size_average=False)

    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every
    start_epoch = 0
    n_iter = 1            # renamed from 'iter', which shadowed the builtin

    if args.init_from:
        checkpoint = torch.load(args.init_from, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        if args.resume != 0:
            start_epoch = checkpoint['epoch']
            plot_losses = checkpoint['plot_losses']
            # Bug fix: restore Adam's internal statistics on resume.  The
            # checkpoint saves them under 'optimizer' but they were never
            # reloaded, silently resetting momentum/variance estimates.
            if 'optimizer' in checkpoint:
                model_optim.load_state_dict(checkpoint['optimizer'])
        print('loaded checkpoint %s (epoch %d)' % (args.init_from, start_epoch))

    model = model.train()
    if args.use_gpu:
        # NOTE(review): the batch tensors below are never moved to GPU here;
        # presumably DataLoader already yields CUDA tensors -- confirm.
        model = model.cuda()
        criterion_mel = criterion_mel.cuda()
        criterion_lin = criterion_lin.cuda()

    # Bug fix: prev_h must exist before the first iteration.  The loop body
    # always calls loader.mask_prev_h(prev_h), but prev_h was only assigned
    # when loader.is_subbatch_end was True (UnboundLocalError otherwise).
    prev_h = (None, None, None)

    total_iters = args.max_epochs * loader.iter_per_epoch
    print('Start training... (1 epoch = %s iters)' % (loader.iter_per_epoch))
    while n_iter < total_iters + 1:
        if loader.is_subbatch_end:
            prev_h = (None, None, None)             # set prev_h = h_0 when new sentences are loaded
        enc_input, target_mel, target_lin, wave_lengths, text_lengths = loader.next_batch('train')

        max_wave_len = max(wave_lengths)

        enc_input = Variable(enc_input, requires_grad=False)
        target_mel = Variable(target_mel, requires_grad=False)
        target_lin = Variable(target_lin, requires_grad=False)

        prev_h = loader.mask_prev_h(prev_h)

        model_optim.zero_grad()
        # teacher forcing: decoder input is the target shifted right by one frame
        pred_mel, pred_lin, prev_h = model(enc_input, target_mel[:, :-1], wave_lengths, text_lengths, prev_h)

        # normalize the summed L1 by (time x batch x feature) to get a mean loss
        loss_mel = criterion_mel(pred_mel, target_mel[:, 1:])\
                        .div(max_wave_len * args.batch_size * args.dec_out_size)
        loss_linear = criterion_lin(pred_lin, target_lin[:, 1:])\
                        .div(max_wave_len * args.batch_size * args.post_out_size)
        loss = torch.sum(loss_mel + loss_linear)

        loss.backward()
        nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)         # gradient clipping
        model_optim.step()

        print_loss_total += loss.data[0]
        plot_loss_total += loss.data[0]

        if n_iter % args.print_every == 0:
            print_loss_avg = print_loss_total / args.print_every
            print_loss_total = 0
            # Bug fix: progress is iterations over TOTAL iterations; the
            # original divided by max_epochs, overstating progress by a
            # factor of iter_per_epoch.
            progress = n_iter / total_iters
            print('%s (%d %d%%) %.4f' % (timeSince(start, progress),
                                         n_iter, progress * 100, print_loss_avg))
        if n_iter % args.plot_every == 0:
            plot_loss_avg = plot_loss_total / args.plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

            save_name = '%s/%dth_exp_loss.png' % (args.save_dir, args.exp_no)
            savePlot(plot_losses, save_name)

        if n_iter % args.save_every == 0:
            epoch = start_epoch + n_iter // loader.iter_per_epoch
            save_name = '%s/%d_%dth.t7' % (args.save_dir, args.exp_no, epoch)
            state = {
                'epoch': epoch,
                'args': args,
                'state_dict': model.state_dict(),
                'optimizer': model_optim.state_dict(),
                'plot_losses': plot_losses
            }
            torch.save(state, save_name)
            print('model saved to', save_name)
            # if is_best:               # TODO: implement saving best model.
            #     shutil.copyfile(save_name, '%s/%d_best.t7' % (args.save_dir, args.exp_no))

        n_iter += 1
Example #2
0
def main():
    """Command-line entry point: synthesize speech with a trained Tacotron.

    Restores model hyper-parameters and weights from the checkpoint given by
    ``--init_from``, then either generates audio for a free-text ``--caption``
    or re-synthesizes one dataset batch (``--batch_idx``).  Waveforms are
    reconstructed from the predicted linear spectrogram via Griffin-Lim and
    written, with attention plots, to ``--out_dir``.
    """
    parser = argparse.ArgumentParser(description='training script')
    # data load
    parser.add_argument('--data', type=str, default='blizzard', help='blizzard / etri')
    parser.add_argument('--batch_size', type=int, default=6, help='batch size')
    parser.add_argument('--text_limit', type=int, default=1500, help='maximum length of text to include in training set')
    parser.add_argument('--wave_limit', type=int, default=800, help='maximum length of spectrogram to include in training set')
    parser.add_argument('--shuffle_data', type=int, default=0, help='whether to shuffle data loader')
    parser.add_argument('--batch_idx', type=int, default=0, help='n-th batch of the dataset')
    parser.add_argument('--load_queue_size', type=int, default=1, help='maximum number of batches to load on the memory')
    parser.add_argument('--n_workers', type=int, default=1, help='number of workers used in data loader')
    # generation option
    parser.add_argument('--exp_no', type=int, default=0, help='experiment number (used in output file names)')
    parser.add_argument('--out_dir', type=str, default='generated', help='output directory')
    parser.add_argument('--init_from', type=str, default='', help='load parameters from...')
    parser.add_argument('--caption', type=str, default='', help='text to generate speech')
    parser.add_argument('--teacher_forcing_ratio', type=float, default=0, help='value between 0~1, use this for scheduled sampling')
    # audio related option
    parser.add_argument('--n_fft', type=int, default=2048, help='fft bin size')
    parser.add_argument('--sample_rate', type=int, default=16000, help='sampling rate')
    parser.add_argument('--frame_len_inMS', type=int, default=50, help='used to determine window size of fft')
    # Bug fix: the default 12.5 is not an int, and type=int rejected any
    # fractional shift given on the command line (int('12.5') raises).
    parser.add_argument('--frame_shift_inMS', type=float, default=12.5, help='used to determine stride in sfft')
    parser.add_argument('--num_recon_iters', type=int, default=50, help='# of iteration in griffin-lim recon')
    # misc
    parser.add_argument('--gpu', type=int, nargs='+', help='index of gpu machines to run')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    new_args = vars(parser.parse_args())

    # load and override some arguments
    # Bug fix: fail fast on an empty --init_from.  The original called
    # torch.load('') unconditionally (crashing), even though a later
    # `if args.init_from:` pretended the checkpoint was optional.
    if not new_args['init_from']:
        print('--init_from is required (model args are restored from the checkpoint)')
        return
    checkpoint = torch.load(new_args['init_from'], map_location=lambda storage, loc: storage)
    args = checkpoint['args']
    for key in new_args:                    # CLI values override checkpoint args
        args.__dict__[key] = new_args[key]

    torch.manual_seed(args.seed)

    # set dataset option (paths are machine-specific)
    if args.data == 'blizzard':
        args.dir_bin = '/data2/lyg0722/TTS_corpus/blizzard/segmented/bin/'
    elif args.data == 'etri':
        args.dir_bin = '/data2/lyg0722/TTS_corpus/etri/bin/'
    else:
        print('no dataset')
        return

    if args.gpu is None:
        args.use_gpu = False
        args.gpu = []
    else:
        args.use_gpu = True
        # Consistency fix: CUDA RNG now uses args.seed like the CPU RNG above
        # (the original hard-coded 0 here).
        torch.cuda.manual_seed(args.seed)
        torch.cuda.set_device(args.gpu[0])

    model = Tacotron(args)
    criterion_mel = nn.L1Loss(size_average=False)   # summed L1; normalized manually below
    criterion_lin = nn.L1Loss(size_average=False)

    # STFT window/hop sizes in samples, derived from lengths in milliseconds
    window_len = int(np.ceil(args.frame_len_inMS * args.sample_rate / 1000))
    hop_length = int(np.ceil(args.frame_shift_inMS * args.sample_rate / 1000))

    model.load_state_dict(checkpoint['state_dict'])
    print('loaded checkpoint %s' % (args.init_from))

    model = model.eval()

    if args.use_gpu:
        model = model.cuda()
        criterion_mel = criterion_mel.cuda()
        criterion_lin = criterion_lin.cuda()

    if args.caption:
        # --- free-text generation ------------------------------------------
        text_raw = args.caption

        if args.data == 'etri':
            text_raw = decompose_hangul(text_raw)       # For Korean dataset

        vocab_dict = torch.load(args.dir_bin + 'vocab.t7')

        enc_input = [vocab_dict[i] for i in text_raw]
        enc_input = enc_input + [0]                                   # null-padding at tail
        text_lengths = [len(enc_input)]
        enc_input = Variable(torch.LongTensor(enc_input).view(1, -1))

        dec_input = torch.Tensor(1, 1, args.dec_out_size).fill_(0)          # null-padding for start flag
        dec_input = Variable(dec_input)
        wave_lengths = [args.wave_limit]        # TODO: use <EOS> later...

        prev_h = (None, None, None)  # set prev_h = h_0 when new sentences are loaded

        if args.gpu:
            enc_input = enc_input.cuda()
            dec_input = dec_input.cuda()

        _, pred_lin, prev_h = model(enc_input, dec_input, wave_lengths, text_lengths, prev_h)

        # Griffin-Lim reconstruction from the predicted linear spectrogram
        wave = spectrogram2wav(
            pred_lin.data.view(-1, args.post_out_size).cpu().numpy(),
            n_fft=args.n_fft,
            win_length=window_len,
            hop_length=hop_length,
            num_iters=args.num_recon_iters
        )

        # write to file
        outpath1 = '%s/%s_%s.wav' % (args.out_dir, args.exp_no, args.caption)
        outpath2 = '%s/%s_%s.png' % (args.out_dir, args.exp_no, args.caption)
        librosa.output.write_wav(outpath1, wave, 16000)
        saveAttention(text_raw, torch.cat(model.attn_weights, dim=-1).squeeze(), outpath2)
    else:
        # --- re-synthesize one batch from the dataset -----------------------
        loader = DataLoader(args)
        args.vocab_size = loader.get_num_vocab()

        # Bug fix: prev_h must exist before loader.mask_prev_h is called; the
        # original only assigned it when loader.is_subbatch_end was True
        # (UnboundLocalError otherwise).
        prev_h = (None, None, None)

        for n_iter in range(1, loader.iter_per_epoch + 1):
            if loader.is_subbatch_end:
                prev_h = (None, None, None)  # set prev_h = h_0 when new sentences are loaded

            # skip ahead to the requested batch index
            for _ in range(args.batch_idx):
                loader.next_batch('train')

            enc_input, target_mel, target_lin, wave_lengths, text_lengths = loader.next_batch('train')
            enc_input = Variable(enc_input, volatile=True)      # inference only: no autograd graph
            target_mel = Variable(target_mel, volatile=True)
            target_lin = Variable(target_lin, volatile=True)

            prev_h = loader.mask_prev_h(prev_h)

            if args.gpu:
                enc_input = enc_input.cuda()
                target_mel = target_mel.cuda()
                target_lin = target_lin.cuda()

            pred_mel, pred_lin, prev_h = model(enc_input, target_mel[:, :-1], wave_lengths, text_lengths, prev_h)

            # normalize the summed L1 by (time x batch x feature) for reporting
            loss_mel = criterion_mel(pred_mel, target_mel[:, 1:]) \
                .div(max(wave_lengths) * args.batch_size * args.dec_out_size)
            loss_linear = criterion_lin(pred_lin, target_lin[:, 1:]) \
                .div(max(wave_lengths) * args.batch_size * args.post_out_size)
            loss = torch.sum(loss_mel + loss_linear)

            print('loss:', loss.data[0])

            attentions = torch.cat(model.attn_weights, dim=-1)

            # write each sample in the batch to file
            for n in range(enc_input.size(0)):
                wave = spectrogram2wav(
                    pred_lin.data[n].view(-1, args.post_out_size).cpu().numpy(),
                    n_fft=args.n_fft,
                    win_length=window_len,
                    hop_length=hop_length,
                    num_iters=args.num_recon_iters
                )
                outpath1 = '%s/%s_%s_%s.wav' % (args.out_dir, args.exp_no, n, args.caption)
                librosa.output.write_wav(outpath1, wave, 16000)
                outpath2 = '%s/%s_%s_%s.png' % (args.out_dir, args.exp_no, n, args.caption)
                saveAttention(None, attentions[n], outpath2)

            # showPlot(plot_losses)
            break               # only the first (requested) batch is generated