def train():
    """Train a WaveRNN model, maintaining an EMA (exponential moving average)
    shadow copy of the weights and checkpointing both models each epoch.

    Relies on module globals: HPARAMS, ARGS, WaveRNN, load_checkpoint,
    init_ema, get_loader, train_step, clone_as_averaged_model, save_checkpoint.
    """
    torch.cuda.set_device(0)
    iteration = 0
    model = WaveRNN(HPARAMS)
    model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=HPARAMS.lr)
    if ARGS.checkpoint:
        # Derive the EMA checkpoint path: either the given path already is an
        # EMA checkpoint, or the EMA twin lives next to it with an
        # 'ema_model_' filename prefix.
        if os.path.basename(ARGS.checkpoint).startswith('ema_model'):
            ema_checkpoint = ARGS.checkpoint
        else:
            ema_checkpoint = 'ema_model_' + os.path.basename(ARGS.checkpoint)
            ema_checkpoint = os.path.join(os.path.dirname(ARGS.checkpoint), ema_checkpoint)
        # Initialise EMA from the ema checkpoint.
        logging.info('Initialising ema model {}'.format(ema_checkpoint))
        ema_model = WaveRNN(HPARAMS).cuda()
        ema_base_model, _ = load_checkpoint(ema_checkpoint, ema_model)
        ema = init_ema(ema_base_model, HPARAMS.ema_rate)
        # Initialise vanilla model
        logging.info('Loading checkpoint {}'.format(ARGS.checkpoint))
        model, iteration, optimizer = load_checkpoint(ARGS.checkpoint, model, optimizer)
    else:
        # Initialise EMA from scratch.
        ema = init_ema(model, HPARAMS.ema_rate)
    criterion = nn.NLLLoss(reduction='sum').cuda()
    train_loader, test_loader = get_loader(ARGS.data, 'train', HPARAMS), get_loader(ARGS.data, 'valid', HPARAMS)
    whole_loader = get_loader(ARGS.data, 'valid', HPARAMS, whole=True)
    model = nn.DataParallel(model)
    # Resume epoch counting from the restored iteration count.
    epoch_offset = max(0, int(iteration / len(train_loader)))
    for _ in range(epoch_offset, ARGS.epochs):
        iteration = train_step(
            train_loader, test_loader, whole_loader, model, optimizer, criterion, iteration, ema=ema
        )
        # NOTE(review): checkpoint saving is placed inside the epoch loop; the
        # collapsed source formatting makes the loop extent ambiguous — confirm.
        averaged_model = clone_as_averaged_model(model, ema)
        # Save the raw (DataParallel-unwrapped) model ...
        save_checkpoint(
            {
                'state_dict': model.module.state_dict(),
                'iteration': iteration,
                'dataset': ARGS.data,
                'optimizer': optimizer.state_dict(),
            },
            iteration,
            'checkpoints/{}/lastmodel.pth'.format(ARGS.expName),
            ARGS.expName,
        )
        # ... and the EMA-averaged model alongside it.
        save_checkpoint(
            {
                'state_dict': averaged_model.state_dict(),
                'iteration': iteration,
                'dataset': ARGS.data,
                'optimizer': optimizer.state_dict(),
            },
            iteration,
            'checkpoints/{}/ema_model_lastmodel.pth'.format(ARGS.expName),
            ARGS.expName,
        )
def test_subscale_vs_standard_inference_partity():
    """Subscale inference must match standard inference on teacher-forced input."""
    params = create_hparams()
    net = WaveRNN(params, debug=True).cuda()
    n_frames = 100
    mel = torch.rand(1, params.feat_dims, n_frames).cuda()
    wav = torch.rand(1, n_frames * params.hop_length).cuda()
    _, _, x_standard = net.inference(mel, gt=wav)
    _, _, x_subscale = net.subscale_inference(mel, gt=wav)
    assert abs(x_standard - x_subscale).mean() < 1e-6
def clone_as_averaged_model(model, ema):
    """Build a fresh WaveRNN on the GPU whose parameters are the EMA shadow
    values, falling back to the live model's weights for any parameter the
    EMA does not track.
    """
    clone = WaveRNN(HPARAMS)
    clone.cuda()
    # Start from the live weights (model is DataParallel-wrapped, hence .module).
    clone.load_state_dict(model.module.state_dict())
    # Overwrite tracked parameters with their EMA shadow copies.
    for param_name, param in clone.named_parameters():
        if param_name in ema.shadow:
            param.data = ema.shadow[param_name].clone()
    return clone
def test_inference_forward_parity():
    """Check that inference() reproduces train_mode_generate() output on whole
    validation segments (debug mode, teacher-forced with gt=x)."""
    hparams = create_hparams()
    model = WaveRNN(hparams, debug=True).cuda()
    model.train()
    data_path = '../data/short_sens/'
    whole_segments = get_loader(data_path, 'valid', hparams, whole=True)
    for i, (x, m, _) in enumerate(whole_segments):
        x, m = x.cuda(), m.cuda()
        forward_output, f_context, f_x = model.train_mode_generate(x, m)
        inference_output, i_cont_dict, i_x = model.inference(m, gt=x)
        # Outputs should agree to numerical tolerance.
        assert (abs(i_x - f_x).mean() < 1e-6)
# NOTE(review): the triple-quote below opens a string/comment block whose
# closing delimiter is not visible in this chunk — confirm its pair exists.
'''
coarse_classes, fine_classes = split_signal_PJ(sample) # In[26]: plot(coarse_classes[73000:73100]) # In[27]: plot(fine_classes[73000:73100]) # ### Train Model # In[28]: model = WaveRNN().cuda() # In[29]: coarse_classes, fine_classes = split_signal(sample) # In[30]: batch_size = 128 # 8gb gpu coarse_classes = coarse_classes[:len(coarse_classes) // batch_size * batch_size] fine_classes = fine_classes[:len(fine_classes) // batch_size * batch_size] coarse_classes = np.reshape(coarse_classes, (batch_size, -1)) fine_classes = np.reshape(fine_classes, (batch_size, -1)) # In[31]:
def __init__(self):
    """Set up the TTS pipeline: parse (hard-coded) arguments, load hparams,
    build the vocoder (WaveRNN or Griffin-Lim) and the TTS model
    (Tacotron or Tacotron2), restore their weights, and print a summary table.

    NOTE(review): method of a class whose header is outside this chunk.
    """
    # Parse Arguments
    parser = argparse.ArgumentParser(description='TTS')
    self.args = parser.parse_args()
    # All options are overridden with fixed defaults after parsing.
    self.args.vocoder = 'wavernn'
    self.args.hp_file = 'hparams.py'
    self.args.voc_weights = False
    self.args.tts_weights = False
    self.args.save_attn = False
    self.args.batched = True
    self.args.target = None
    self.args.overlap = None
    self.args.force_cpu = False
    #================ vocoder ================#
    if self.args.vocoder in ['griffinlim', 'gl']:
        self.args.vocoder = 'griffinlim'
    elif self.args.vocoder in ['wavernn', 'wr']:
        self.args.vocoder = 'wavernn'
    else:
        # NOTE(review): argparse.ArgumentError expects (argument, message);
        # raising it with one string argument fails — consider ValueError.
        raise argparse.ArgumentError('Must provide a valid vocoder type!')
    hp.configure(self.args.hp_file)  # Load hparams from file
    # set defaults for any arguments that depend on hparams
    if self.args.vocoder == 'wavernn':
        if self.args.target is None:
            self.args.target = hp.voc_target
        if self.args.overlap is None:
            self.args.overlap = hp.voc_overlap
        if self.args.batched is None:
            self.args.batched = hp.voc_gen_batched
    #================ others ================#
    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
    print("hello")
    print(paths.base)
    # Prefer CUDA unless forced onto the CPU.
    if not self.args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print('Using device:', device)
    # === Wavernn === #
    if self.args.vocoder == 'wavernn':
        print('\nInitialising WaveRNN Model...\n')
        self.voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                                 fc_dims=hp.voc_fc_dims,
                                 bits=hp.bits,
                                 pad=hp.voc_pad,
                                 upsample_factors=hp.voc_upsample_factors,
                                 feat_dims=hp.num_mels,
                                 compute_dims=hp.voc_compute_dims,
                                 res_out_dims=hp.voc_res_out_dims,
                                 res_blocks=hp.voc_res_blocks,
                                 hop_length=hp.hop_length,
                                 sample_rate=hp.sample_rate,
                                 mode=hp.voc_mode).to(device)
        # Explicit weights path wins over the latest checkpoint on disk.
        voc_load_path = self.args.voc_weights if self.args.voc_weights else paths.voc_latest_weights
        #print(paths.voc_latest_weights)
        self.voc_model.load(voc_load_path)
    # === Tacotron === #
    if hp.tts_model == 'tacotron':
        print('\nInitialising Tacotron Model...\n')
        self.tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
                                  num_chars=len(symbols),
                                  encoder_dims=hp.tts_encoder_dims,
                                  decoder_dims=hp.tts_decoder_dims,
                                  n_mels=hp.num_mels,
                                  fft_bins=hp.num_mels,
                                  postnet_dims=hp.tts_postnet_dims,
                                  encoder_K=hp.tts_encoder_K,
                                  lstm_dims=hp.tts_lstm_dims,
                                  postnet_K=hp.tts_postnet_K,
                                  num_highways=hp.tts_num_highways,
                                  dropout=hp.tts_dropout,
                                  stop_threshold=hp.tts_stop_threshold).to(device)
        tts_load_path = self.args.tts_weights if self.args.tts_weights else paths.tts_latest_weights
        self.tts_model.load(tts_load_path)
    # === Tacotron2 === #
    elif hp.tts_model == 'tacotron2':
        print('\nInitializing Tacotron2 Model...\n')
        self.tts_model = Tacotron2().to(device)
        tts_load_path = self.args.tts_weights if self.args.tts_weights else paths.tts_latest_weights
        self.tts_model.load(tts_load_path)
    # === Infomation === #
    # Print a summary table of the loaded models and generation settings.
    if hp.tts_model == 'tacotron':
        if self.args.vocoder == 'wavernn':
            voc_k = self.voc_model.get_step() // 1000
            tts_k = self.tts_model.get_step() // 1000
            simple_table([
                ('Tacotron', str(tts_k) + 'k'),
                ('r', self.tts_model.r),
                ('Vocoder Type', 'WaveRNN'),
                ('WaveRNN', str(voc_k) + 'k'),
                ('Generation Mode', 'Batched' if self.args.batched else 'Unbatched'),
                ('Target Samples', self.args.target if self.args.batched else 'N/A'),
                ('Overlap Samples', self.args.overlap if self.args.batched else 'N/A')
            ])
        elif self.args.vocoder == 'griffinlim':
            tts_k = self.tts_model.get_step() // 1000
            # NOTE(review): self.args.iters is never assigned above — this
            # branch would raise AttributeError; confirm.
            simple_table([('Tacotron', str(tts_k) + 'k'),
                          ('r', self.tts_model.r),
                          ('Vocoder Type', 'Griffin-Lim'),
                          ('GL Iters', self.args.iters)])
    elif hp.tts_model == 'tacotron2':
        if self.args.vocoder == 'wavernn':
            voc_k = self.voc_model.get_step() // 1000
            tts_k = self.tts_model.get_step() // 1000
            simple_table([
                ('Tacotron2', str(tts_k) + 'k'),
                ('Vocoder Type', 'WaveRNN'),
                ('WaveRNN', str(voc_k) + 'k'),
                ('Generation Mode', 'Batched' if self.args.batched else 'Unbatched'),
                ('Target Samples', self.args.target if self.args.batched else 'N/A'),
                ('Overlap Samples', self.args.overlap if self.args.batched else 'N/A')
            ])
        elif self.args.vocoder == 'griffinlim':
            tts_k = self.tts_model.get_step() // 1000
            simple_table([('Tacotron2', str(tts_k) + 'k'),
                          ('Vocoder Type', 'Griffin-Lim'),
                          ('GL Iters', self.args.iters)])
class TaiwaneseTacotron():
    """Taiwanese TTS pipeline: Tacotron/Tacotron2 acoustic model plus a
    WaveRNN or Griffin-Lim vocoder, with synthesis entry point generate()."""

    def __init__(self):
        """Parse (hard-coded) arguments, load hparams, build and restore the
        vocoder and TTS models, and print a summary table."""
        # Parse Arguments
        parser = argparse.ArgumentParser(description='TTS')
        self.args = parser.parse_args()
        # All options are overridden with fixed defaults after parsing.
        self.args.vocoder = 'wavernn'
        self.args.hp_file = 'hparams.py'
        self.args.voc_weights = False
        self.args.tts_weights = False
        self.args.save_attn = False
        self.args.batched = True
        self.args.target = None
        self.args.overlap = None
        self.args.force_cpu = False
        #================ vocoder ================#
        if self.args.vocoder in ['griffinlim', 'gl']:
            self.args.vocoder = 'griffinlim'
        elif self.args.vocoder in ['wavernn', 'wr']:
            self.args.vocoder = 'wavernn'
        else:
            # NOTE(review): argparse.ArgumentError expects (argument, message);
            # raising it with one string argument fails — consider ValueError.
            raise argparse.ArgumentError('Must provide a valid vocoder type!')
        hp.configure(self.args.hp_file)  # Load hparams from file
        # set defaults for any arguments that depend on hparams
        if self.args.vocoder == 'wavernn':
            if self.args.target is None:
                self.args.target = hp.voc_target
            if self.args.overlap is None:
                self.args.overlap = hp.voc_overlap
            if self.args.batched is None:
                self.args.batched = hp.voc_gen_batched
        #================ others ================#
        paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
        print("hello")
        print(paths.base)
        # Prefer CUDA unless forced onto the CPU.
        if not self.args.force_cpu and torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
        print('Using device:', device)
        # === Wavernn === #
        if self.args.vocoder == 'wavernn':
            print('\nInitialising WaveRNN Model...\n')
            self.voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                                     fc_dims=hp.voc_fc_dims,
                                     bits=hp.bits,
                                     pad=hp.voc_pad,
                                     upsample_factors=hp.voc_upsample_factors,
                                     feat_dims=hp.num_mels,
                                     compute_dims=hp.voc_compute_dims,
                                     res_out_dims=hp.voc_res_out_dims,
                                     res_blocks=hp.voc_res_blocks,
                                     hop_length=hp.hop_length,
                                     sample_rate=hp.sample_rate,
                                     mode=hp.voc_mode).to(device)
            # Explicit weights path wins over the latest checkpoint on disk.
            voc_load_path = self.args.voc_weights if self.args.voc_weights else paths.voc_latest_weights
            #print(paths.voc_latest_weights)
            self.voc_model.load(voc_load_path)
        # === Tacotron === #
        if hp.tts_model == 'tacotron':
            print('\nInitialising Tacotron Model...\n')
            self.tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
                                      num_chars=len(symbols),
                                      encoder_dims=hp.tts_encoder_dims,
                                      decoder_dims=hp.tts_decoder_dims,
                                      n_mels=hp.num_mels,
                                      fft_bins=hp.num_mels,
                                      postnet_dims=hp.tts_postnet_dims,
                                      encoder_K=hp.tts_encoder_K,
                                      lstm_dims=hp.tts_lstm_dims,
                                      postnet_K=hp.tts_postnet_K,
                                      num_highways=hp.tts_num_highways,
                                      dropout=hp.tts_dropout,
                                      stop_threshold=hp.tts_stop_threshold).to(device)
            tts_load_path = self.args.tts_weights if self.args.tts_weights else paths.tts_latest_weights
            self.tts_model.load(tts_load_path)
        # === Tacotron2 === #
        elif hp.tts_model == 'tacotron2':
            print('\nInitializing Tacotron2 Model...\n')
            self.tts_model = Tacotron2().to(device)
            tts_load_path = self.args.tts_weights if self.args.tts_weights else paths.tts_latest_weights
            self.tts_model.load(tts_load_path)
        # === Infomation === #
        # Print a summary table of the loaded models and generation settings.
        if hp.tts_model == 'tacotron':
            if self.args.vocoder == 'wavernn':
                voc_k = self.voc_model.get_step() // 1000
                tts_k = self.tts_model.get_step() // 1000
                simple_table([
                    ('Tacotron', str(tts_k) + 'k'),
                    ('r', self.tts_model.r),
                    ('Vocoder Type', 'WaveRNN'),
                    ('WaveRNN', str(voc_k) + 'k'),
                    ('Generation Mode', 'Batched' if self.args.batched else 'Unbatched'),
                    ('Target Samples', self.args.target if self.args.batched else 'N/A'),
                    ('Overlap Samples', self.args.overlap if self.args.batched else 'N/A')
                ])
            elif self.args.vocoder == 'griffinlim':
                tts_k = self.tts_model.get_step() // 1000
                # NOTE(review): self.args.iters is never assigned above — this
                # branch would raise AttributeError; confirm.
                simple_table([('Tacotron', str(tts_k) + 'k'),
                              ('r', self.tts_model.r),
                              ('Vocoder Type', 'Griffin-Lim'),
                              ('GL Iters', self.args.iters)])
        elif hp.tts_model == 'tacotron2':
            if self.args.vocoder == 'wavernn':
                voc_k = self.voc_model.get_step() // 1000
                tts_k = self.tts_model.get_step() // 1000
                simple_table([
                    ('Tacotron2', str(tts_k) + 'k'),
                    ('Vocoder Type', 'WaveRNN'),
                    ('WaveRNN', str(voc_k) + 'k'),
                    ('Generation Mode', 'Batched' if self.args.batched else 'Unbatched'),
                    ('Target Samples', self.args.target if self.args.batched else 'N/A'),
                    ('Overlap Samples', self.args.overlap if self.args.batched else 'N/A')
                ])
            elif self.args.vocoder == 'griffinlim':
                tts_k = self.tts_model.get_step() // 1000
                simple_table([('Tacotron2', str(tts_k) + 'k'),
                              ('Vocoder Type', 'Griffin-Lim'),
                              ('GL Iters', self.args.iters)])

    def generate(self, 華, input_text):
        """Synthesize `input_text` (cleaned to a symbol sequence) with the
        configured TTS model; 華 is forwarded for output naming."""
        inputs = [text_to_sequence(input_text.strip(), ['basic_cleaners'])]
        if hp.tts_model == 'tacotron2':
            self.gen_tacotron2(華, inputs)
        elif hp.tts_model == 'tacotron':
            self.gen_tacotron(華, inputs)
        else:
            # NOTE(review): tts_model_type is not defined in this scope —
            # this branch would raise NameError; confirm intended variable.
            print(f"Wrong tts model type {{{tts_model_type}}}")
        print('\n\nDone.\n')

    # custom function
    def gen_tacotron2(self, 華, inputs):
        """Run Tacotron2 inference on each input sequence and vocode the
        postnet mel output to a wav under output/."""
        for i, x in enumerate(inputs, 1):
            print(f'\n| Generating {i}/{len(inputs)}')
            print(x)
            x = np.array(x)[None, :]
            x = torch.autograd.Variable(torch.from_numpy(x)).cuda().long()
            self.tts_model.eval()
            mel_outputs, mel_outputs_postnet, _, alignments = self.tts_model.inference(x)
            # v_type is computed but never used below — dead assignment.
            if self.args.vocoder == 'griffinlim':
                v_type = self.args.vocoder
            elif self.args.vocoder == 'wavernn' and self.args.batched:
                v_type = 'wavernn_batched'
            else:
                v_type = 'wavernn_unbatched'
            # == define output name == #
            if len(華) == 0:
                # NOTE(review): input_text is a free variable here (it is a
                # parameter of generate(), not of this method) — this branch
                # would raise NameError; confirm.
                output_name = re.split(r'\,|\.|\!|\?| ', input_text)[0]
            elif 1 <= len(華) <= 9:
                output_name = 華[:-1]
            elif 9 < len(華):
                output_name = 華[:8]
            print(output_name)
            save_path = "output/{}.wav".format(output_name)
            ##
            if self.args.vocoder == 'wavernn':
                m = mel_outputs_postnet
                self.voc_model.generate(m, save_path, self.args.batched, hp.voc_target, hp.voc_overlap, hp.mu_law)
            elif self.args.vocoder == 'griffinlim':
                m = torch.squeeze(mel_outputs_postnet).detach().cpu().numpy()
                wav = reconstruct_waveform(m, n_iter=self.args.iters)
                save_wav(wav, save_path)

    # custom function
    def gen_tacotron(self, 華, inputs):
        """Run Tacotron generation on each input sequence, rescale the mel to
        [0, 1], and vocode to a wav under output/."""
        for i, x in enumerate(inputs, 1):
            print(f'\n| Generating {i}/{len(inputs)}')
            _, m, attention = self.tts_model.generate(x)
            # Fix mel spectrogram scaling to be from 0 to 1
            m = (m + 4) / 8
            np.clip(m, 0, 1, out=m)
            # v_type is computed but never used below — dead assignment.
            if self.args.vocoder == 'griffinlim':
                v_type = self.args.vocoder
            elif self.args.vocoder == 'wavernn' and self.args.batched:
                v_type = 'wavernn_batched'
            else:
                v_type = 'wavernn_unbatched'
            # == define output name == #
            if len(華) == 0:
                # NOTE(review): input_text is a free variable here as well —
                # would raise NameError; confirm.
                output_name = re.split(r'\,|\.|\!|\?| ', input_text)[0]
            elif 1 <= len(華) <= 9:
                output_name = 華[:-1]
            elif 9 < len(華):
                output_name = 華[:8]
            print(output_name)
            save_path = "output/{}.wav".format(output_name)
            ##
            if self.args.vocoder == 'wavernn':
                m = torch.tensor(m).unsqueeze(0)
                self.voc_model.generate(m, save_path, self.args.batched, hp.voc_target, hp.voc_overlap, hp.mu_law)
            elif self.args.vocoder == 'griffinlim':
                wav = reconstruct_waveform(m, n_iter=self.args.iters)
                save_wav(wav, save_path)
def load_model():
    """Construct a WaveRNN on the GPU and restore its weights from
    ARGS.checkpoint, returning the restored model."""
    net = WaveRNN(HPARAMS).cuda()
    restored, _ = load_checkpoint(ARGS.checkpoint, net)
    return restored