def build_model():
    """Instantiate the network selected by ``hparams.builder``.

    The factory is looked up by name on ``builder`` and called with keyword
    arguments read from the module-level ``hparams``.

    Returns:
        Whatever the selected builder factory returns.

    Raises:
        RuntimeError: if ``is_mulaw_quantize(hparams.input_type)`` is true
            and ``hparams.out_channels`` differs from
            ``hparams.quantize_channels``.
    """
    mulaw_mismatch = (
        is_mulaw_quantize(hparams.input_type)
        and hparams.out_channels != hparams.quantize_channels
    )
    if mulaw_mismatch:
        raise RuntimeError(
            "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'"
        )

    # Resolve the factory first, then collect its arguments.
    factory = getattr(builder, hparams.builder)

    # Hyperparameters forwarded to the factory under their own names.
    forwarded = (
        "out_channels",
        "layers",
        "stacks",
        "residual_channels",
        "gate_channels",
        "skip_out_channels",
        "cin_channels",
        "gin_channels",
        "weight_normalization",
        "n_speakers",
        "dropout",
        "kernel_size",
        "upsample_conditional_features",
        "upsample_scales",
        "freq_axis_kernel_size",
    )
    factory_kwargs = {field: getattr(hparams, field) for field in forwarded}
    factory_kwargs["scalar_input"] = is_scalar_input(hparams.input_type)
    return factory(**factory_kwargs)
def build_model():
    """Instantiate the network selected by ``hparams.builder``.

    Emits a warning when upsampling of conditional features is requested
    while ``hparams.cin_channels`` is negative, then calls the factory found
    on ``builder`` with keyword arguments read from the module-level
    ``hparams`` (including ``legacy``).

    Returns:
        Whatever the selected builder factory returns.

    Raises:
        RuntimeError: if ``is_mulaw_quantize(hparams.input_type)`` is true
            and ``hparams.out_channels`` differs from
            ``hparams.quantize_channels``.
    """
    mulaw_mismatch = (
        is_mulaw_quantize(hparams.input_type)
        and hparams.out_channels != hparams.quantize_channels
    )
    if mulaw_mismatch:
        raise RuntimeError(
            "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'"
        )

    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        warn(
            "Upsample conv layers were specified while local conditioning disabled. "
            "Notice that upsample conv layers will never be used."
        )

    # Resolve the factory first, then collect its arguments.
    factory = getattr(builder, hparams.builder)

    # Hyperparameters forwarded to the factory under their own names.
    forwarded = (
        "out_channels",
        "layers",
        "stacks",
        "residual_channels",
        "gate_channels",
        "skip_out_channels",
        "cin_channels",
        "gin_channels",
        "weight_normalization",
        "n_speakers",
        "dropout",
        "kernel_size",
        "upsample_conditional_features",
        "upsample_scales",
        "freq_axis_kernel_size",
    )
    factory_kwargs = {field: getattr(hparams, field) for field in forwarded}
    factory_kwargs["scalar_input"] = is_scalar_input(hparams.input_type)
    factory_kwargs["legacy"] = hparams.legacy
    return factory(**factory_kwargs)
def build_model():
    """Build a ``WaveNet`` from the module-level ``wavenet_hparams``.

    Note that ``cin_channels`` and ``cin_pad`` are written into the object
    obtained from ``wavenet_hparams.upsample_params`` itself (no copy is
    taken) before it is handed to ``WaveNet``.

    Returns:
        The constructed ``WaveNet`` instance.

    Raises:
        RuntimeError: if ``is_mulaw_quantize(wavenet_hparams.input_type)`` is
            true and ``out_channels`` differs from ``quantize_channels``.
    """
    hp = wavenet_hparams

    if is_mulaw_quantize(hp.input_type) and hp.out_channels != hp.quantize_channels:
        raise RuntimeError(
            "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'"
        )

    if hp.upsample_conditional_features and hp.cin_channels < 0:
        warn(
            "Upsample conv layers were specified while local conditioning disabled. "
            "Notice that upsample conv layers will never be used."
        )

    # Inject the conditioning settings into the upsampling configuration.
    upsample_params = hp.upsample_params
    for key in ("cin_channels", "cin_pad"):
        upsample_params[key] = getattr(hp, key)

    # Hyperparameters passed to WaveNet under their own names.
    forwarded = (
        "out_channels",
        "layers",
        "stacks",
        "residual_channels",
        "gate_channels",
        "skip_out_channels",
        "cin_channels",
        "gin_channels",
        "n_speakers",
        "dropout",
        "kernel_size",
        "cin_pad",
        "upsample_conditional_features",
    )
    wavenet_kwargs = {field: getattr(hp, field) for field in forwarded}
    wavenet_kwargs["upsample_params"] = upsample_params
    wavenet_kwargs["scalar_input"] = is_scalar_input(hp.input_type)
    wavenet_kwargs["output_distribution"] = hp.output_distribution
    return WaveNet(**wavenet_kwargs)
def get_model():
    """Build a ``WaveNet`` from the module-level ``hparams``.

    No consistency checks are performed here. ``cin_channels`` and
    ``cin_pad`` are written into the object obtained from
    ``hparams.upsample_params`` itself (no copy is taken) before it is handed
    to ``WaveNet``.

    Returns:
        The constructed ``WaveNet`` instance.
    """
    # Inject the conditioning settings into the upsampling configuration.
    upsample_params = hparams.upsample_params
    for key in ("cin_channels", "cin_pad"):
        upsample_params[key] = getattr(hparams, key)

    # Hyperparameters passed to WaveNet under their own names.
    forwarded = (
        "out_channels",
        "layers",
        "stacks",
        "residual_channels",
        "gate_channels",
        "skip_out_channels",
        "cin_channels",
        "gin_channels",
        "n_speakers",
        "dropout",
        "kernel_size",
        "cin_pad",
        "upsample_conditional_features",
    )
    wavenet_kwargs = {field: getattr(hparams, field) for field in forwarded}
    wavenet_kwargs["upsample_params"] = upsample_params
    wavenet_kwargs["scalar_input"] = is_scalar_input(hparams.input_type)
    wavenet_kwargs["output_distribution"] = hparams.output_distribution
    return WaveNet(**wavenet_kwargs)
def build_vqvae_model():
    """Build a ``VQVAE`` wrapping a ``WaveNet``, from module-level ``hparams``.

    The ``WaveNet`` is always created with ``use_speaker_embedding=True``.
    ``cin_channels`` and ``cin_pad`` are written into the object obtained from
    ``hparams.upsample_params`` itself (no copy is taken).

    Returns:
        The constructed ``VQVAE`` instance.

    Raises:
        RuntimeError: if ``is_mulaw_quantize(hparams.input_type)`` is true
            and ``hparams.out_channels`` differs from
            ``hparams.quantize_channels``.
    """
    if is_mulaw_quantize(hparams.input_type) and (
        hparams.out_channels != hparams.quantize_channels
    ):
        raise RuntimeError(
            "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'"
        )

    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        warn(
            "Upsample conv layers were specified while local conditioning disabled. "
            "Notice that upsample conv layers will never be used."
        )

    # Inject the conditioning settings into the upsampling configuration.
    upsample_params = hparams.upsample_params
    for key in ("cin_channels", "cin_pad"):
        upsample_params[key] = getattr(hparams, key)

    # Hyperparameters passed to WaveNet under their own names.
    forwarded = (
        "out_channels",
        "layers",
        "stacks",
        "residual_channels",
        "gate_channels",
        "skip_out_channels",
        "cin_channels",
        "gin_channels",
        "n_speakers",
        "dropout",
        "kernel_size",
        "cin_pad",
        "upsample_conditional_features",
    )
    wavenet_kwargs = {field: getattr(hparams, field) for field in forwarded}
    wavenet_kwargs.update(
        upsample_params=upsample_params,
        scalar_input=is_scalar_input(hparams.input_type),
        output_distribution=hparams.output_distribution,
        use_speaker_embedding=True,
    )
    decoder = WaveNet(**wavenet_kwargs)

    # K1 is only forwarded when enabled and different from K.
    secondary_k = (
        hparams.K1 if (hparams.use_K1 and hparams.K1 != hparams.K) else None
    )
    # With post_conv the hidden width is fixed at 64; otherwise it follows
    # cin_channels.
    hidden = 64 if hparams.post_conv else hparams.cin_channels

    return VQVAE(
        wavenet=decoder,
        c_in=39,
        hid=hidden,
        frame_rate=hparams.frame_rate,
        use_time_jitter=hparams.time_jitter,
        K=hparams.K,
        ema=hparams.ema,
        sliced=hparams.sliced,
        ins_norm=hparams.ins_norm,
        post_conv=hparams.post_conv,
        adain=hparams.adain,
        dropout=hparams.vq_drop,
        drop_dim=hparams.drop_dim,
        K1=secondary_k,
        num_slices=hparams.num_slices,
    )
def build_model():
    """Build the auto-encoder variant selected by ``hparams.name``.

    Supported names are ``'inae'``, ``'inae1'`` and ``'new_inae'``. A
    ``WaveNet`` is built from the module-level ``hparams`` and wrapped in the
    matching model class. The ``WaveNet`` uses a speaker embedding for every
    variant except ``'new_inae'``.

    ``cin_channels`` and ``cin_pad`` are written into the object obtained from
    ``hparams.upsample_params`` itself (no copy is taken).

    Returns:
        The constructed ``INAE``, ``INAE1`` or ``NewINAE`` instance.

    Raises:
        RuntimeError: if ``is_mulaw_quantize(hparams.input_type)`` is true
            and ``hparams.out_channels`` differs from
            ``hparams.quantize_channels``.
        ValueError: if ``hparams.name`` is not one of the supported names.
            (Previously this fell through to ``return model`` with ``model``
            unbound and surfaced as an ``UnboundLocalError``.)
    """
    if is_mulaw_quantize(hparams.input_type):
        if hparams.out_channels != hparams.quantize_channels:
            raise RuntimeError(
                "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'"
            )

    # Fail early, before the WaveNet is constructed, on an unknown variant.
    name = hparams.name
    supported_names = ("inae", "inae1", "new_inae")
    if name not in supported_names:
        raise ValueError(
            "Unsupported hparams.name {!r}; expected one of {}".format(
                name, ", ".join(supported_names)
            )
        )

    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        s = "Upsample conv layers were specified while local conditioning disabled. "
        s += "Notice that upsample conv layers will never be used."
        warn(s)

    upsample_params = hparams.upsample_params
    upsample_params["cin_channels"] = hparams.cin_channels
    upsample_params["cin_pad"] = hparams.cin_pad

    # Only the 'new_inae' variant runs without a speaker embedding.
    use_speaker_embedding = name != "new_inae"

    wavenet = WaveNet(
        out_channels=hparams.out_channels,
        layers=hparams.layers,
        stacks=hparams.stacks,
        residual_channels=hparams.residual_channels,
        gate_channels=hparams.gate_channels,
        skip_out_channels=hparams.skip_out_channels,
        cin_channels=hparams.cin_channels,
        gin_channels=hparams.gin_channels,
        n_speakers=hparams.n_speakers,
        dropout=hparams.dropout,
        kernel_size=hparams.kernel_size,
        cin_pad=hparams.cin_pad,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_params=upsample_params,
        scalar_input=is_scalar_input(hparams.input_type),
        output_distribution=hparams.output_distribution,
        use_speaker_embedding=use_speaker_embedding,
    )

    # The class names are resolved lazily per branch, so a variant that is
    # not imported only matters when it is actually selected.
    if name == "inae":
        model = INAE(wavenet=wavenet, c_in=39, hid=64,
                     frame_rate=hparams.frame_rate, adain=hparams.adain)
    elif name == "inae1":
        model = INAE1(wavenet=wavenet, c_in=39, hid=64,
                      frame_rate=hparams.frame_rate, adain=hparams.adain)
    else:  # name == "new_inae", guaranteed by the validation above
        model = NewINAE(wavenet=wavenet, c_in=39, hid=64,
                        frame_rate=hparams.frame_rate)
    return model
def __init__(
        self,
        cin_channels=80,
        dropout=0.05,
        freq_axis_kernel_size=3,
        gate_channels=512,
        gin_channels=-1,
        hinge_regularizer=True,  # Only used in MoL prediction (INPUT_TYPE_RAW).
        kernel_size=3,
        layers=24,
        log_scale_min=float(np.log(1e-14)),  # Only used in INPUT_TYPE_RAW.
        n_speakers=1,
        out_channels=256,  # Use num_mixtures * 3 (pi, mean, log_scale) for INPUT_TYPE_RAW.
        residual_channels=512,
        scalar_input=is_scalar_input(INPUT_TYPE_MULAW),
        skip_out_channels=256,
        stacks=4,
        upsample_conditional_features=False,
        upsample_scales=None,
        use_speaker_embedding=False,
        weight_normalization=True,
        legacy=False):
    """Store the given hyperparameters as attributes of the same name.

    Every argument is copied verbatim onto ``self``; no validation is done.

    Args:
        upsample_scales: Upsampling scales. ``None`` (the default) selects
            ``[5, 4, 2]``. A fresh list is created per instance, so
            default-constructed instances no longer share one mutable list.
        hinge_regularizer: Only used in MoL prediction (INPUT_TYPE_RAW).
        log_scale_min: Only used in INPUT_TYPE_RAW.
        out_channels: Use ``num_mixtures * 3`` (pi, mean, log_scale) for
            INPUT_TYPE_RAW.

    All remaining arguments are stored unchanged.
    """
    # Avoid the shared-mutable-default pitfall: build the default list here
    # rather than in the signature.
    if upsample_scales is None:
        upsample_scales = [5, 4, 2]

    self.cin_channels = cin_channels
    self.dropout = dropout
    self.freq_axis_kernel_size = freq_axis_kernel_size
    self.gate_channels = gate_channels
    self.gin_channels = gin_channels
    self.hinge_regularizer = hinge_regularizer
    self.kernel_size = kernel_size
    self.layers = layers
    self.log_scale_min = log_scale_min
    self.n_speakers = n_speakers
    self.out_channels = out_channels
    self.residual_channels = residual_channels
    self.scalar_input = scalar_input
    self.skip_out_channels = skip_out_channels
    self.stacks = stacks
    self.upsample_conditional_features = upsample_conditional_features
    self.upsample_scales = upsample_scales
    self.use_speaker_embedding = use_speaker_embedding
    self.weight_normalization = weight_normalization
    self.legacy = legacy
def build_catae_model():
    """Build a ``CatWavAE`` wrapping a ``WaveNet``, from module-level ``hparams``.

    The ``WaveNet`` is always created with ``use_speaker_embedding=True``.
    ``cin_channels`` and ``cin_pad`` are written into the object obtained from
    ``hparams.upsample_params`` itself (no copy is taken).

    Returns:
        The constructed ``CatWavAE`` instance.

    Raises:
        RuntimeError: if ``is_mulaw_quantize(hparams.input_type)`` is true
            and ``hparams.out_channels`` differs from
            ``hparams.quantize_channels``.
    """
    if is_mulaw_quantize(hparams.input_type) and (
        hparams.out_channels != hparams.quantize_channels
    ):
        raise RuntimeError(
            "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'"
        )

    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        warn(
            "Upsample conv layers were specified while local conditioning disabled. "
            "Notice that upsample conv layers will never be used."
        )

    # Inject the conditioning settings into the upsampling configuration.
    upsample_params = hparams.upsample_params
    for key in ("cin_channels", "cin_pad"):
        upsample_params[key] = getattr(hparams, key)

    # Hyperparameters passed to WaveNet under their own names.
    forwarded = (
        "out_channels",
        "layers",
        "stacks",
        "residual_channels",
        "gate_channels",
        "skip_out_channels",
        "cin_channels",
        "gin_channels",
        "n_speakers",
        "dropout",
        "kernel_size",
        "cin_pad",
        "upsample_conditional_features",
    )
    wavenet_kwargs = {field: getattr(hparams, field) for field in forwarded}
    wavenet_kwargs.update(
        upsample_params=upsample_params,
        scalar_input=is_scalar_input(hparams.input_type),
        output_distribution=hparams.output_distribution,
        use_speaker_embedding=True,
    )
    decoder = WaveNet(**wavenet_kwargs)

    return CatWavAE(
        wavenet=decoder,
        c_in=39,
        hid=hparams.cin_channels,
        tau=0.1,
        k=hparams.K,
        frame_rate=hparams.frame_rate,
        hard=hparams.hard,
        slices=hparams.num_slices,
    )
def save_checkpoint(device, model, global_step, global_test_step,
                    checkpoint_dir, epoch, ema=None):
    """Write a training checkpoint, plus an averaged one when ``ema`` is given.

    The regular checkpoint holds ``model.decode_model.state_dict()`` and is
    named ``<hparams.name>_checkpoint_step<9-digit step>.pth.tar``. When
    ``ema`` is not ``None``, a second file
    ``checkpoint_step<9-digit step>_ema.pth`` is written holding the state of
    a fresh ``WaveNet`` processed by ``clone_as_averaged_model``.

    Args:
        device: Target for the averaged model; only used when ``ema`` is set.
        model: Object exposing ``optimizer`` and ``decode_model``.
        global_step: Stored in the checkpoint and used in the file name.
        global_test_step: Stored in the checkpoint.
        checkpoint_dir: Directory the files are written to.
        epoch: Stored in the checkpoint as ``"global_epoch"``.
        ema: Optional averaging object passed to ``clone_as_averaged_model``.
    """
    # The optimizer state is optional and shared by both checkpoint files.
    if hparams.save_optimizer_state:
        optimizer_state = model.optimizer.state_dict()
    else:
        optimizer_state = None

    def _payload(model_state):
        # Common checkpoint layout; only the "model" entry differs.
        return {
            "model": model_state,
            "optimizer": optimizer_state,
            "global_step": global_step,
            "global_epoch": epoch,
            "global_test_step": global_test_step,
        }

    file_name = hparams.name + "_checkpoint_step{:09d}.pth.tar".format(global_step)
    checkpoint_path = join(checkpoint_dir, file_name)
    torch.save(_payload(model.decode_model.state_dict()), checkpoint_path)
    print("Saved checkpoint:", checkpoint_path)

    if ema is None:
        return

    # Build a fresh network to receive the averaged parameters.
    averaged_model = WaveNet(scalar_input=is_scalar_input(hparams.input_type))
    averaged_model = torch.nn.DataParallel(averaged_model).to(device)
    averaged_model = clone_as_averaged_model(averaged_model, model, ema)

    ema_file_name = "checkpoint_step{:09d}_ema.pth".format(global_step)
    checkpoint_path = join(checkpoint_dir, ema_file_name)
    torch.save(_payload(averaged_model.state_dict()), checkpoint_path)
    print("Saved averaged checkpoint:", checkpoint_path)
def build_model(hparams_json=None):
    """Build a ``WaveNet`` from a JSON hyperparameter file or module ``hparams``.

    Args:
        hparams_json: Optional path to a JSON file. When given, its contents
            are passed as keyword arguments to ``HParams`` and the result is
            used. When ``None``, the module-level ``hparams`` is used.

    Returns:
        The constructed ``WaveNet`` instance.

    Raises:
        RuntimeError: if ``is_mulaw_quantize(input_type)`` is true and
            ``out_channels`` differs from ``quantize_channels``.

    ``cin_channels`` and ``cin_pad`` are written into the object obtained from
    the hyperparameters' ``upsample_params`` itself (no copy is taken).
    """
    # Bind to a separate local name. Assigning to ``hparams`` here would make
    # it local to the function, so the ``hparams_json=None`` path would raise
    # UnboundLocalError instead of falling back to the module-level object.
    if hparams_json is not None:
        with open(hparams_json, 'r') as jf:
            hp = HParams(**json.load(jf))
    else:
        hp = hparams

    if is_mulaw_quantize(hp.input_type):
        if hp.out_channels != hp.quantize_channels:
            raise RuntimeError(
                "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'")

    if hp.upsample_conditional_features and hp.cin_channels < 0:
        s = "Upsample conv layers were specified while local conditioning disabled. "
        s += "Notice that upsample conv layers will never be used."
        warn(s)

    upsample_params = hp.upsample_params
    upsample_params["cin_channels"] = hp.cin_channels
    upsample_params["cin_pad"] = hp.cin_pad

    # A speaker embedding is used only when gin_channels is positive.
    use_speaker_embedding = True if hp.gin_channels > 0 else False

    model = WaveNet(
        out_channels=hp.out_channels,
        layers=hp.layers,
        stacks=hp.stacks,
        residual_channels=hp.residual_channels,
        gate_channels=hp.gate_channels,
        skip_out_channels=hp.skip_out_channels,
        cin_channels=hp.cin_channels,
        gin_channels=hp.gin_channels,
        n_speakers=hp.n_speakers,
        dropout=hp.dropout,
        kernel_size=hp.kernel_size,
        cin_pad=hp.cin_pad,
        upsample_conditional_features=hp.upsample_conditional_features,
        upsample_net=hp.upsample_net,
        upsample_params=upsample_params,
        scalar_input=is_scalar_input(hp.input_type),
        use_speaker_embedding=use_speaker_embedding,
        output_distribution=hp.output_distribution,
    )
    return model
def __init__(self, dim_in, dim_out, hparams):
    """Create the wrapped ``WaveNet`` from ``hparams``.

    Args:
        dim_in: Accepted but not read in this constructor.
        dim_out: Accepted but not read in this constructor.
        hparams: Hyperparameter object supplying ``len_in_out_multiplier``
            and every ``WaveNet`` constructor argument used below.
    """
    super().__init__()

    self.len_in_out_multiplier = hparams.len_in_out_multiplier

    # Hyperparameters passed to WaveNet under their own names.
    forwarded = (
        "out_channels",
        "layers",
        "stacks",
        "residual_channels",
        "gate_channels",
        "skip_out_channels",
        "kernel_size",
        "dropout",
        "weight_normalization",
        "cin_channels",
        "gin_channels",
        "n_speakers",
        "upsample_conditional_features",
        "upsample_scales",
        "freq_axis_kernel_size",
    )
    wavenet_kwargs = {field: getattr(hparams, field) for field in forwarded}
    wavenet_kwargs["scalar_input"] = is_scalar_input(hparams.input_type)
    wavenet_kwargs["use_speaker_embedding"] = hparams.use_speaker_embedding

    # The network itself is held as a sub-component of this wrapper.
    self.model = WaveNet(**wavenet_kwargs)
# Training setup: build the network, wrap it with its loss, create the
# learning-rate schedule and optimizer, and assemble the one-step train cell.
# `upsample_params`, `step_size_per_epoch` and `args` are defined before this
# fragment, outside the visible span.
model = WaveNet(
    out_channels=hparams.out_channels,
    layers=hparams.layers,
    stacks=hparams.stacks,
    residual_channels=hparams.residual_channels,
    gate_channels=hparams.gate_channels,
    skip_out_channels=hparams.skip_out_channels,
    cin_channels=hparams.cin_channels,
    gin_channels=hparams.gin_channels,
    n_speakers=hparams.n_speakers,
    dropout=hparams.dropout,
    kernel_size=hparams.kernel_size,
    cin_pad=hparams.cin_pad,
    upsample_conditional_features=hparams.upsample_conditional_features,
    upsample_params=upsample_params,
    scalar_input=is_scalar_input(hparams.input_type),
    output_distribution=hparams.output_distribution,
)
# Wrap the network together with the loss computation.
loss_net = NetWithLossClass(model, hparams)
# Learning-rate schedule derived from the base rate, the epoch count and the
# number of steps per epoch, then converted to a Tensor for the optimizer.
lr = get_lr(hparams.optimizer_params["lr"], hparams.nepochs, step_size_per_epoch)
lr = Tensor(lr)
# Optionally initialise the network from a saved parameter file.
# NOTE(review): the guard tests `args.checkpoint`, but the file actually
# loaded is `args.pre_trained_model_path` — confirm these two options are
# meant to be different.
if args.checkpoint != '':
    param_dict = load_checkpoint(args.pre_trained_model_path)
    load_param_into_net(model, param_dict)
    print('Successfully loading the pre-trained model')
# Optimise only the trainable parameters, with a fixed loss scale of 1024.
weights = model.trainable_params()
optimizer = Adam(weights, learning_rate=lr, loss_scale=1024.)
train_net = TrainOneStepCell(loss_net, optimizer)
def build_model(hparams, name=None):
    """Build the teacher or one of the student networks from ``hparams``.

    Args:
        hparams: Hyperparameter object read for all constructor arguments.
        name: Which model to build: ``"teacher"`` (``builder.wavenet``),
            ``"parallel"`` (``builder.student_wavenet``) or ``"clari"``
            (``builder.clari_wavenet``). Must not be ``None``.

    Returns:
        Whatever the selected ``builder`` factory returns.

    Raises:
        AssertionError: if ``name`` is ``None`` (when assertions are enabled).
        RuntimeError: if ``is_mulaw_quantize(hparams.input_type)`` is true
            and ``hparams.out_channels`` differs from
            ``hparams.quantize_channels``.
        Exception: with message ``"No such model"`` for any other ``name``.
    """
    assert name is not None

    mulaw_mismatch = (
        is_mulaw_quantize(hparams.input_type)
        and hparams.out_channels != hparams.quantize_channels
    )
    if mulaw_mismatch:
        raise RuntimeError(
            "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'"
        )

    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        warn(
            "Upsample conv layers were specified while local conditioning disabled. "
            "Notice that upsample conv layers will never be used."
        )

    # Reject unknown names before reading any model hyperparameters.
    if name not in ("teacher", "parallel", "clari"):
        raise Exception("No such model")

    # Arguments common to the teacher and both student variants.
    shared_kwargs = dict(
        cin_channels=hparams.cin_channels,
        gin_channels=hparams.gin_channels,
        weight_normalization=hparams.weight_normalization,
        n_speakers=hparams.n_speakers,
        dropout=hparams.dropout,
        kernel_size=hparams.kernel_size,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_scales=hparams.upsample_scales,
        freq_axis_kernel_size=hparams.freq_axis_kernel_size,
        scalar_input=is_scalar_input(hparams.input_type),
    )

    if name == "teacher":
        return builder.wavenet(
            out_channels=hparams.out_channels,
            layers=hparams.layers,
            stacks=hparams.stacks,
            residual_channels=hparams.residual_channels,
            gate_channels=hparams.gate_channels,
            skip_out_channels=hparams.skip_out_channels,
            **shared_kwargs
        )

    # Both student variants read the ``student_*`` sizes and the IAF layout.
    student_kwargs = dict(
        out_channels=hparams.student_out_channels,
        layers=hparams.student_layers,
        stacks=hparams.student_stacks,
        residual_channels=hparams.student_residual_channels,
        iaf_layer_sizes=hparams.iaf_layer_sizes,
        gate_channels=hparams.student_gate_channels,
        **shared_kwargs
    )

    if name == "parallel":
        return builder.student_wavenet(**student_kwargs)

    # Remaining case is "clari", which takes two extra options.
    return builder.clari_wavenet(
        use_skip=hparams.use_skip,
        iaf_shift=hparams.iaf_shift,
        **student_kwargs
    )