import numpy as np
import torch
import torch.nn as nn

from wavenet_vocoder import WaveNet
from wavenet_vocoder.util import is_scalar_input


class WaveNetWrapper(nn.Module):
    """A wrapper around r9y9's WaveNet implementation to integrate it seamlessly into the framework."""

    IDENTIFIER = "r9y9WaveNet"

    def __init__(self, dim_in, dim_out, hparams):
        super().__init__()

        self.len_in_out_multiplier = hparams.len_in_out_multiplier

        # Use the wavenet_vocoder builder to create the model.
        self.model = WaveNet(out_channels=hparams.out_channels,
                             layers=hparams.layers,
                             stacks=hparams.stacks,
                             residual_channels=hparams.residual_channels,
                             gate_channels=hparams.gate_channels,
                             skip_out_channels=hparams.skip_out_channels,
                             kernel_size=hparams.kernel_size,
                             dropout=hparams.dropout,
                             weight_normalization=hparams.weight_normalization,
                             cin_channels=hparams.cin_channels,
                             gin_channels=hparams.gin_channels,
                             n_speakers=hparams.n_speakers,
                             upsample_conditional_features=hparams.upsample_conditional_features,
                             upsample_scales=hparams.upsample_scales,
                             freq_axis_kernel_size=hparams.freq_axis_kernel_size,
                             scalar_input=is_scalar_input(hparams.input_type),
                             use_speaker_embedding=hparams.use_speaker_embedding)

    def forward(self, inputs, hidden, seq_lengths_inputs, max_length_inputs,
                target=None, seq_lengths_target=None):
        if target is not None:  # During training and testing with teacher forcing.
            output = self.model(target, c=inputs, g=None, softmax=False)
            # output = self.model(target, c=inputs[:, :, :target.shape[2]], g=None, softmax=False)
            # Output shape is B x C x T. Don't permute here because CrossEntropyLoss expects B x C x T.
        else:  # During inference.
            with torch.no_grad():
                self.model.make_generation_fast_()

                assert len(seq_lengths_inputs) == 1, "Batch synthesis is not supported yet."
                num_frames_to_gen = seq_lengths_inputs[0] * self.len_in_out_multiplier
                output = self.model.incremental_forward(c=inputs, T=num_frames_to_gen,
                                                        softmax=True, quantize=True)
                # Output shape is B x C x T.

        return output, None

    def set_gpu_flag(self, use_gpu):
        self.use_gpu = use_gpu

    def init_hidden(self, batch_size=1):
        return None

    def parameters(self):
        return self.model.parameters()
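
# Minimal usage sketch for the hparams-based wrapper above; it is not part of the original
# framework. The SimpleNamespace is a hypothetical stand-in for the framework's
# hyper-parameter container, the field values mirror the defaults of the Config class
# further down, and len_in_out_multiplier=1 is an assumed value (waveform samples
# generated per conditioning frame). The wrapper class is captured as a default argument
# so the sketch keeps referring to this variant even if the name is redefined later.
def _example_hparams_usage(wrapper_cls=WaveNetWrapper):
    from types import SimpleNamespace

    hparams = SimpleNamespace(
        len_in_out_multiplier=1,  # Assumption: conditioning features already at sample rate.
        out_channels=256, layers=24, stacks=4,
        residual_channels=512, gate_channels=512, skip_out_channels=256,
        kernel_size=3, dropout=0.05, weight_normalization=True,
        cin_channels=80, gin_channels=-1, n_speakers=1,
        upsample_conditional_features=False, upsample_scales=[5, 4, 2],
        freq_axis_kernel_size=3, input_type="mulaw-quantize",
        use_speaker_embedding=False)

    model = wrapper_cls(dim_in=80, dim_out=256, hparams=hparams)
    print("Trainable parameters:", sum(p.numel() for p in model.parameters()))
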
class WaveNetWrapper(nn.Module):
    """A wrapper around r9y9's WaveNet implementation to integrate it seamlessly into the framework."""

    IDENTIFIER = "r9y9WaveNet"

    class Config:
        INPUT_TYPE_MULAW = "mulaw-quantize"
        INPUT_TYPE_RAW = "raw"

        def __init__(self,
                     cin_channels=80,
                     dropout=0.05,
                     freq_axis_kernel_size=3,
                     gate_channels=512,
                     gin_channels=-1,
                     hinge_regularizer=True,  # Only used in MoL prediction (INPUT_TYPE_RAW).
                     kernel_size=3,
                     layers=24,
                     log_scale_min=float(np.log(1e-14)),  # Only used in INPUT_TYPE_RAW.
                     n_speakers=1,
                     out_channels=256,  # Use num_mixtures * 3 (pi, mean, log_scale) for INPUT_TYPE_RAW.
                     residual_channels=512,
                     scalar_input=is_scalar_input(INPUT_TYPE_MULAW),
                     skip_out_channels=256,
                     stacks=4,
                     upsample_conditional_features=False,
                     upsample_scales=[5, 4, 2],
                     use_speaker_embedding=False,
                     weight_normalization=True,
                     legacy=False):
            self.cin_channels = cin_channels
            self.dropout = dropout
            self.freq_axis_kernel_size = freq_axis_kernel_size
            self.gate_channels = gate_channels
            self.gin_channels = gin_channels
            self.hinge_regularizer = hinge_regularizer
            self.kernel_size = kernel_size
            self.layers = layers
            self.log_scale_min = log_scale_min
            self.n_speakers = n_speakers
            self.out_channels = out_channels
            self.residual_channels = residual_channels
            self.scalar_input = scalar_input
            self.skip_out_channels = skip_out_channels
            self.stacks = stacks
            self.upsample_conditional_features = upsample_conditional_features
            self.upsample_scales = upsample_scales
            self.use_speaker_embedding = use_speaker_embedding
            self.weight_normalization = weight_normalization
            self.legacy = legacy

        def create_model(self):
            return WaveNetWrapper(self)

    def __init__(self, config):
        super().__init__()

        # The inference branch of forward() needs this factor (generated samples per
        # conditioning frame); fall back to 1 if the config does not define it.
        self.len_in_out_multiplier = getattr(config, "len_in_out_multiplier", 1)

        # Use the wavenet_vocoder builder to create the model.
        self.model = WaveNet(
            out_channels=config.out_channels,
            layers=config.layers,
            stacks=config.stacks,
            residual_channels=config.residual_channels,
            gate_channels=config.gate_channels,
            skip_out_channels=config.skip_out_channels,
            kernel_size=config.kernel_size,
            dropout=config.dropout,
            weight_normalization=config.weight_normalization,
            cin_channels=config.cin_channels,
            gin_channels=config.gin_channels,
            n_speakers=config.n_speakers,
            upsample_conditional_features=config.upsample_conditional_features,
            upsample_scales=config.upsample_scales,
            freq_axis_kernel_size=config.freq_axis_kernel_size,
            scalar_input=config.scalar_input,
            use_speaker_embedding=config.use_speaker_embedding,
            legacy=config.legacy
        )
        self.has_weight_norm = True
        # self.__deepcopy__ = MethodType(__deepcopy__, self)

    def forward(self, input_, target, seq_lengths, *_):
        if target is not None:  # During training and testing with teacher forcing.
            assert self.has_weight_norm, \
                "Model has been used for generation and weight norm was removed, " \
                "cannot continue training. Remove the make_generation_fast_() call " \
                "to continue training after generation."
            output = self.model(target, c=input_, g=None, softmax=False)
            # output = self.model(target, c=input_[:, :, :target.shape[2]], g=None, softmax=False)
            # Output shape is B x C x T. Don't permute here because CrossEntropyLoss expects B x C x T.
        else:  # During inference.
            with torch.no_grad():
                self.model.make_generation_fast_()  # After this call training cannot be continued.
                self.has_weight_norm = False

                assert len(seq_lengths) == 1, "Batch synthesis is not supported."
                num_frames_to_gen = seq_lengths[0] * self.len_in_out_multiplier
                output = self.model.incremental_forward(
                    c=input_, T=num_frames_to_gen, softmax=True, quantize=True)
                # output = self.model.incremental_forward(
                #     c=input_[:, :, :1000], T=torch.tensor(1000), softmax=True, quantize=True)
                # Output shape is B x C x T.

        return output, None

    def set_gpu_flag(self, use_gpu):
        self.use_gpu = use_gpu

    def init_hidden(self, batch_size=1):
        return None

    def parameters(self):
        return self.model.parameters()
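
# Minimal usage sketch for the Config-based wrapper above; it is not part of the original
# framework, and all concrete sizes are illustrative only. It builds the default mu-law
# quantized model via Config.create_model() and runs one teacher-forced forward pass on
# random tensors, following the B x C x T convention used in forward().
def _example_config_usage():
    import torch.nn.functional as F

    config = WaveNetWrapper.Config()
    model = config.create_model()

    batch_size, num_samples = 1, 400
    # Local conditioning, e.g. acoustic features; with upsample_conditional_features=False
    # it must already be at sample rate, i.e. one conditioning vector per waveform sample.
    conditioning = torch.randn(batch_size, config.cin_channels, num_samples)
    # Teacher-forcing input: one-hot encoded mu-law quantized waveform, B x out_channels x T.
    quantized = torch.randint(0, config.out_channels, (batch_size, num_samples))
    target = F.one_hot(quantized, config.out_channels).transpose(1, 2).float()

    output, _ = model(conditioning, target, torch.tensor([num_samples]))
    print(output.shape)  # B x out_channels x T logits, suitable for CrossEntropyLoss.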