def main():
    """Train a Tacotron TTS model.

    Parses command-line options, builds the data loader and model, then runs
    the training loop with truncated-BPTT style sub-batches, periodically
    printing/plotting the loss and checkpointing the model.
    """
    parser = argparse.ArgumentParser(description='training script')
    # data load
    parser.add_argument('--data', type=str, default='blizzard', help='blizzard / nancy')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--text_limit', type=int, default=1000, help='maximum length of text to include in training set')
    parser.add_argument('--wave_limit', type=int, default=1400, help='maximum length of spectrogram to include in training set')
    parser.add_argument('--trunc_size', type=int, default=700, help='used for truncated-BPTT when memory is not enough.')
    parser.add_argument('--shuffle_data', type=int, default=1, help='whether to shuffle data loader')
    parser.add_argument('--load_queue_size', type=int, default=8, help='maximum number of batches to load on the memory')
    parser.add_argument('--n_workers', type=int, default=2, help='number of workers used in data loader')
    # model
    parser.add_argument('--charvec_dim', type=int, default=256, help='')
    parser.add_argument('--hidden_size', type=int, default=128, help='')
    parser.add_argument('--dec_out_size', type=int, default=80, help='decoder output size')
    parser.add_argument('--post_out_size', type=int, default=1025, help='should be n_fft / 2 + 1(check n_fft from "input_specL" ')
    parser.add_argument('--num_filters', type=int, default=16, help='number of filters in filter bank of CBHG')
    parser.add_argument('--r_factor', type=int, default=5, help='reduction factor(# of multiple output)')
    parser.add_argument('--dropout', type=float, default=0.5, help='')
    # optimization
    parser.add_argument('--max_epochs', type=int, default=100000, help='maximum epoch to train')
    parser.add_argument('--grad_clip', type=float, default=1, help='gradient clipping')
    parser.add_argument('--learning_rate', type=float, default=1e-3, help='2e-3 from Ito, I used to use 5e-4')
    parser.add_argument('--lr_decay_every', type=int, default=25000, help='decay learning rate every...')
    parser.add_argument('--lr_decay_factor', type=float, default=0.5, help='decay learning rate by this factor')
    parser.add_argument('--teacher_forcing_ratio', type=float, default=1, help='value between 0~1, use this for scheduled sampling')
    # loading
    parser.add_argument('--init_from', type=str, default='', help='load parameters from...')
    parser.add_argument('--resume', type=int, default=0, help='1 for resume from saved epoch')
    # misc
    parser.add_argument('--exp_no', type=int, default=0, help='')
    parser.add_argument('--print_every', type=int, default=-1, help='')
    parser.add_argument('--plot_every', type=int, default=-1, help='')
    parser.add_argument('--save_every', type=int, default=-1, help='')
    parser.add_argument('--save_dir', type=str, default='checkpoint', help='')
    parser.add_argument('--pinned_memory', type=int, default=1, help='1 to use pinned memory')
    parser.add_argument('--gpu', type=int, nargs='+', help='index of gpu machines to run')
    # debug
    parser.add_argument('--debug', type=int, default=0, help='1 for debug mode')
    args = parser.parse_args()

    torch.manual_seed(0)

    # set dataset option
    if args.data == 'blizzard':
        args.dir_bin = '/home/lyg0722/TTS_corpus/blizzard/segmented/bin/'
    elif args.data == 'etri':
        args.dir_bin = '/data2/lyg0722/TTS_corpus/etri/bin/'
    else:
        print('no dataset')
        return

    if args.gpu is None:
        args.use_gpu = False
        args.gpu = []
    else:
        args.use_gpu = True
        torch.cuda.manual_seed(0)
        torch.cuda.set_device(args.gpu[0])

    loader = DataLoader(args)

    # set misc options (defaults of -1 mean "derive from dataset size")
    args.vocab_size = loader.get_num_vocab()
    if args.print_every == -1:
        args.print_every = loader.iter_per_epoch
    if args.plot_every == -1:
        args.plot_every = args.print_every
    if args.save_every == -1:
        args.save_every = loader.iter_per_epoch * 10  # save every 10 epoch by default

    model = Tacotron(args)
    model_optim = optim.Adam(model.parameters(), args.learning_rate)
    criterion_mel = nn.L1Loss(size_average=False)
    criterion_lin = nn.L1Loss(size_average=False)

    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every
    start_epoch = 0
    # NOTE: renamed the loop counter from `iter` — it shadowed the builtin.
    cur_iter = 1

    if args.init_from:
        checkpoint = torch.load(args.init_from, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        if args.resume != 0:
            start_epoch = checkpoint['epoch']
            plot_losses = checkpoint['plot_losses']
        print('loaded checkpoint %s (epoch %d)' % (args.init_from, start_epoch))

    model = model.train()
    if args.use_gpu:
        model = model.cuda()
        criterion_mel = criterion_mel.cuda()
        criterion_lin = criterion_lin.cuda()

    print('Start training... (1 epoch = %s iters)' % (loader.iter_per_epoch))
    # Total number of iterations; also used for correct progress reporting.
    max_iters = args.max_epochs * loader.iter_per_epoch
    while cur_iter < max_iters + 1:
        if loader.is_subbatch_end:
            prev_h = (None, None, None)  # set prev_h = h_0 when new sentences are loaded
        enc_input, target_mel, target_lin, wave_lengths, text_lengths = loader.next_batch('train')
        max_wave_len = max(wave_lengths)
        enc_input = Variable(enc_input, requires_grad=False)
        target_mel = Variable(target_mel, requires_grad=False)
        target_lin = Variable(target_lin, requires_grad=False)
        prev_h = loader.mask_prev_h(prev_h)

        model_optim.zero_grad()
        # Teacher forcing: feed target_mel shifted by one frame into the decoder.
        pred_mel, pred_lin, prev_h = model(enc_input, target_mel[:, :-1],
                                           wave_lengths, text_lengths, prev_h)

        # Normalize the summed L1 losses by the number of predicted elements.
        loss_mel = criterion_mel(pred_mel, target_mel[:, 1:]) \
            .div(max_wave_len * args.batch_size * args.dec_out_size)
        loss_linear = criterion_lin(pred_lin, target_lin[:, 1:]) \
            .div(max_wave_len * args.batch_size * args.post_out_size)
        loss = torch.sum(loss_mel + loss_linear)
        loss.backward()
        nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)  # gradient clipping
        model_optim.step()

        print_loss_total += loss.data[0]
        plot_loss_total += loss.data[0]

        if cur_iter % args.print_every == 0:
            print_loss_avg = print_loss_total / args.print_every
            print_loss_total = 0
            # Fix: progress fraction is over total iterations, not epochs —
            # the old `iter / args.max_epochs` overstated progress by a
            # factor of iter_per_epoch.
            print('%s (%d %d%%) %.4f' % (timeSince(start, cur_iter / max_iters),
                                         cur_iter, cur_iter / max_iters * 100,
                                         print_loss_avg))

        if cur_iter % args.plot_every == 0:
            plot_loss_avg = plot_loss_total / args.plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
            save_name = '%s/%dth_exp_loss.png' % (args.save_dir, args.exp_no)
            savePlot(plot_losses, save_name)

        if cur_iter % args.save_every == 0:
            epoch = start_epoch + cur_iter // loader.iter_per_epoch
            save_name = '%s/%d_%dth.t7' % (args.save_dir, args.exp_no, epoch)
            state = {
                'epoch': epoch,
                'args': args,
                'state_dict': model.state_dict(),
                'optimizer': model_optim.state_dict(),
                'plot_losses': plot_losses,
            }
            torch.save(state, save_name)
            print('model saved to', save_name)
            # TODO: track the best validation loss and also save a
            # '%d_best.t7' copy of the checkpoint when it improves.

        cur_iter += 1
def main():
    """Generate speech from a trained Tacotron checkpoint.

    Two modes:
      * ``--caption`` given: synthesize that single text end-to-end
        (encode -> decode -> Griffin-Lim reconstruction) and write one
        wav plus an attention plot.
      * no caption: pull one batch from the dataset (teacher-forced),
        report the loss, and write a wav + attention plot per sample.
    """
    # Fix: description said 'training script' (copy-paste from the trainer).
    parser = argparse.ArgumentParser(description='generation script')
    # data load
    parser.add_argument('--data', type=str, default='blizzard', help='blizzard / nancy')
    parser.add_argument('--batch_size', type=int, default=6, help='batch size')
    parser.add_argument('--text_limit', type=int, default=1500, help='maximum length of text to include in training set')
    parser.add_argument('--wave_limit', type=int, default=800, help='maximum length of spectrogram to include in training set')
    parser.add_argument('--shuffle_data', type=int, default=0, help='whether to shuffle data loader')
    parser.add_argument('--batch_idx', type=int, default=0, help='n-th batch of the dataset')
    parser.add_argument('--load_queue_size', type=int, default=1, help='maximum number of batches to load on the memory')
    parser.add_argument('--n_workers', type=int, default=1, help='number of workers used in data loader')
    # generation option
    parser.add_argument('--exp_no', type=int, default=0, help='')
    parser.add_argument('--out_dir', type=str, default='generated', help='')
    parser.add_argument('--init_from', type=str, default='', help='load parameters from...')
    parser.add_argument('--caption', type=str, default='', help='text to generate speech')
    parser.add_argument('--teacher_forcing_ratio', type=float, default=0, help='value between 0~1, use this for scheduled sampling')
    # audio related option
    parser.add_argument('--n_fft', type=int, default=2048, help='fft bin size')
    parser.add_argument('--sample_rate', type=int, default=16000, help='sampling rate')
    parser.add_argument('--frame_len_inMS', type=int, default=50, help='used to determine window size of fft')
    # Fix: was type=int with a float default (12.5); supplying the flag on
    # the command line would have silently truncated the frame shift.
    parser.add_argument('--frame_shift_inMS', type=float, default=12.5, help='used to determine stride in sfft')
    parser.add_argument('--num_recon_iters', type=int, default=50, help='# of iteration in griffin-lim recon')
    # misc
    parser.add_argument('--gpu', type=int, nargs='+', help='index of gpu machines to run')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    new_args = vars(parser.parse_args())

    # Load the checkpoint's saved args, then override them with whatever was
    # passed on this command line.
    checkpoint = torch.load(new_args['init_from'], map_location=lambda storage, loc: storage)
    args = checkpoint['args']
    for key in new_args:
        args.__dict__[key] = new_args[key]

    torch.manual_seed(args.seed)

    # set dataset option
    if args.data == 'blizzard':
        args.dir_bin = '/data2/lyg0722/TTS_corpus/blizzard/segmented/bin/'
    elif args.data == 'etri':
        args.dir_bin = '/data2/lyg0722/TTS_corpus/etri/bin/'
    else:
        print('no dataset')
        return

    if args.gpu is None:
        args.use_gpu = False
        args.gpu = []
    else:
        args.use_gpu = True
        # Fix: seed the CUDA RNG with args.seed for consistency with
        # torch.manual_seed above (was hardcoded to 0).
        torch.cuda.manual_seed(args.seed)
        torch.cuda.set_device(args.gpu[0])

    model = Tacotron(args)
    criterion_mel = nn.L1Loss(size_average=False)
    criterion_lin = nn.L1Loss(size_average=False)

    # STFT window/stride in samples, derived from the millisecond options.
    window_len = int(np.ceil(args.frame_len_inMS * args.sample_rate / 1000))
    hop_length = int(np.ceil(args.frame_shift_inMS * args.sample_rate / 1000))

    if args.init_from:
        model.load_state_dict(checkpoint['state_dict'])
        print('loaded checkpoint %s' % (args.init_from))

    model = model.eval()
    if args.use_gpu:
        model = model.cuda()
        criterion_mel = criterion_mel.cuda()
        criterion_lin = criterion_lin.cuda()

    if args.caption:
        # Free-running synthesis of a single caption.
        text_raw = args.caption
        if args.data == 'etri':
            text_raw = decompose_hangul(text_raw)  # For Korean dataset
        vocab_dict = torch.load(args.dir_bin + 'vocab.t7')
        enc_input = [vocab_dict[i] for i in text_raw]
        enc_input = enc_input + [0]  # null-padding at tail
        text_lengths = [len(enc_input)]
        enc_input = Variable(torch.LongTensor(enc_input).view(1, -1))
        dec_input = torch.Tensor(1, 1, args.dec_out_size).fill_(0)  # null-padding for start flag
        dec_input = Variable(dec_input)
        wave_lengths = [args.wave_limit]  # TODO: use <EOS> later...
        prev_h = (None, None, None)  # set prev_h = h_0 when new sentences are loaded

        if args.gpu:
            enc_input = enc_input.cuda()
            dec_input = dec_input.cuda()

        _, pred_lin, prev_h = model(enc_input, dec_input, wave_lengths, text_lengths, prev_h)

        # start generation (Griffin-Lim reconstruction from the linear spectrogram)
        wave = spectrogram2wav(
            pred_lin.data.view(-1, args.post_out_size).cpu().numpy(),
            n_fft=args.n_fft, win_length=window_len, hop_length=hop_length,
            num_iters=args.num_recon_iters)

        # write to file
        outpath1 = '%s/%s_%s.wav' % (args.out_dir, args.exp_no, args.caption)
        outpath2 = '%s/%s_%s.png' % (args.out_dir, args.exp_no, args.caption)
        # Fix: honor --sample_rate instead of a hardcoded 16000
        # (default is still 16000, so behavior is unchanged by default).
        librosa.output.write_wav(outpath1, wave, args.sample_rate)
        saveAttention(text_raw, torch.cat(model.attn_weights, dim=-1).squeeze(), outpath2)
    else:
        # Teacher-forced generation for one batch of the training set.
        loader = DataLoader(args)
        args.vocab_size = loader.get_num_vocab()

        # NOTE: renamed the loop counter from `iter` — it shadowed the builtin.
        for _step in range(1, loader.iter_per_epoch + 1):
            if loader.is_subbatch_end:
                prev_h = (None, None, None)  # set prev_h = h_0 when new sentences are loaded
            # Skip ahead to the requested batch index.
            for _skip in range(args.batch_idx):
                loader.next_batch('train')
            enc_input, target_mel, target_lin, wave_lengths, text_lengths = loader.next_batch('train')
            enc_input = Variable(enc_input, volatile=True)
            target_mel = Variable(target_mel, volatile=True)
            target_lin = Variable(target_lin, volatile=True)
            prev_h = loader.mask_prev_h(prev_h)

            if args.gpu:
                enc_input = enc_input.cuda()
                target_mel = target_mel.cuda()
                target_lin = target_lin.cuda()

            pred_mel, pred_lin, prev_h = model(enc_input, target_mel[:, :-1],
                                               wave_lengths, text_lengths, prev_h)

            loss_mel = criterion_mel(pred_mel, target_mel[:, 1:]) \
                .div(max(wave_lengths) * args.batch_size * args.dec_out_size)
            loss_linear = criterion_lin(pred_lin, target_lin[:, 1:]) \
                .div(max(wave_lengths) * args.batch_size * args.post_out_size)
            loss = torch.sum(loss_mel + loss_linear)
            print('loss:', loss.data[0])

            attentions = torch.cat(model.attn_weights, dim=-1)

            # write to file (one wav + attention plot per sample in the batch)
            for n in range(enc_input.size(0)):
                wave = spectrogram2wav(
                    pred_lin.data[n].view(-1, args.post_out_size).cpu().numpy(),
                    n_fft=args.n_fft, win_length=window_len, hop_length=hop_length,
                    num_iters=args.num_recon_iters)
                outpath1 = '%s/%s_%s_%s.wav' % (args.out_dir, args.exp_no, n, args.caption)
                # Fix: honor --sample_rate instead of a hardcoded 16000.
                librosa.output.write_wav(outpath1, wave, args.sample_rate)
                outpath2 = '%s/%s_%s_%s.png' % (args.out_dir, args.exp_no, n, args.caption)
                saveAttention(None, attentions[n], outpath2)
            # Only the first batch is generated.
            break