Example #1
import torch
import torch.nn as nn

from wavenet_vocoder import WaveNet
from wavenet_vocoder.util import is_scalar_input


class WaveNetWrapper(nn.Module):
    """A wrapper around r9y9's WaveNet implementation to integrate it seamlessly into the framework."""
    IDENTIFIER = "r9y9WaveNet"

    def __init__(self, dim_in, dim_out, hparams):
        super().__init__()

        self.len_in_out_multiplier = hparams.len_in_out_multiplier

        # Instantiate the model from the wavenet_vocoder package.
        self.model = WaveNet(out_channels=hparams.out_channels,
                             layers=hparams.layers,
                             stacks=hparams.stacks,
                             residual_channels=hparams.residual_channels,
                             gate_channels=hparams.gate_channels,
                             skip_out_channels=hparams.skip_out_channels,
                             kernel_size=hparams.kernel_size,
                             dropout=hparams.dropout,
                             weight_normalization=hparams.weight_normalization,
                             cin_channels=hparams.cin_channels,
                             gin_channels=hparams.gin_channels,
                             n_speakers=hparams.n_speakers,
                             upsample_conditional_features=hparams.upsample_conditional_features,
                             upsample_scales=hparams.upsample_scales,
                             freq_axis_kernel_size=hparams.freq_axis_kernel_size,
                             scalar_input=is_scalar_input(hparams.input_type),
                             use_speaker_embedding=hparams.use_speaker_embedding,
                             )

    def forward(self, inputs, hidden, seq_lengths_inputs, max_length_inputs, target=None, seq_lengths_target=None):

        if target is not None:  # During training and testing with teacher forcing.
            output = self.model(target, c=inputs, g=None, softmax=False)
            # output = self.model(target, c=inputs[:, :, :target.shape[2]], g=None, softmax=False)
            # Output shape is B x C x T. Don't permute here: CrossEntropyLoss
            # expects the class dimension second, i.e. logits of shape B x C x T.
        else:  # During inference.
            with torch.no_grad():
                self.model.make_generation_fast_()
                assert len(seq_lengths_inputs) == 1, "Batch synthesis is not supported yet."
                num_frames_to_gen = seq_lengths_inputs[0] * self.len_in_out_multiplier
                output = self.model.incremental_forward(c=inputs, T=num_frames_to_gen, softmax=True, quantize=True)
                # Output shape is B x C x T.

        return output, None

    def set_gpu_flag(self, use_gpu):
        self.use_gpu = use_gpu

    def init_hidden(self, batch_size=1):
        return None

    def parameters(self):
        return self.model.parameters()
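
A minimal smoke test of this wrapper might look as follows. The SimpleNamespace stands in for the framework's hyperparameter container; every concrete value below is an assumption (the defaults mirror the Config of the second example), and the random tensors merely exercise the expected B x C x T shapes.

from types import SimpleNamespace

# Hypothetical hyperparameter container; any object exposing these attributes
# works. The values are illustrative defaults, not the framework's settings.
hparams = SimpleNamespace(
    len_in_out_multiplier=1,
    out_channels=256,
    layers=24,
    stacks=4,
    residual_channels=512,
    gate_channels=512,
    skip_out_channels=256,
    kernel_size=3,
    dropout=0.05,
    weight_normalization=True,
    cin_channels=80,
    gin_channels=-1,
    n_speakers=1,
    upsample_conditional_features=False,
    upsample_scales=[5, 4, 2],
    freq_axis_kernel_size=3,
    input_type="mulaw-quantize",
    use_speaker_embedding=False,
)

model = WaveNetWrapper(dim_in=80, dim_out=256, hparams=hparams)

# Teacher-forced training pass: conditioning features are B x cin_channels x T,
# the target is one-hot mu-law audio as B x out_channels x T. Random tensors
# are used only to exercise the shapes.
conditioning = torch.randn(2, 80, 100)
target = torch.randn(2, 256, 100)
output, _ = model(conditioning, None, [100, 100], 100, target=target)
print(output.shape)  # expected: torch.Size([2, 256, 100])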
Example #2
import numpy as np
import torch
import torch.nn as nn

from wavenet_vocoder import WaveNet
from wavenet_vocoder.util import is_scalar_input


class WaveNetWrapper(nn.Module):
    """A wrapper around r9y9's WaveNet implementation to integrate it seamlessly into the framework."""
    IDENTIFIER = "r9y9WaveNet"

    class Config:
        INPUT_TYPE_MULAW = "mulaw-quantize"
        INPUT_TYPE_RAW = "raw"

        def __init__(
                self,
                cin_channels=80,
                dropout=0.05,
                freq_axis_kernel_size=3,
                gate_channels=512,
                gin_channels=-1,
                hinge_regularizer=True,  # Only used in MoL prediction (INPUT_TYPE_RAW).
                kernel_size=3,
                layers=24,
                log_scale_min=float(np.log(1e-14)),  # Only used in INPUT_TYPE_RAW.
                n_speakers=1,
                out_channels=256,  # Use num_mixtures * 3 (pi, mean, log_scale) for INPUT_TYPE_RAW.
                residual_channels=512,
                scalar_input=is_scalar_input(INPUT_TYPE_MULAW),
                skip_out_channels=256,
                stacks=4,
                upsample_conditional_features=False,
                upsample_scales=[5, 4, 2],
                use_speaker_embedding=False,
                weight_normalization=True,
                legacy=False):

            self.cin_channels = cin_channels
            self.dropout = dropout
            self.freq_axis_kernel_size = freq_axis_kernel_size
            self.gate_channels = gate_channels
            self.gin_channels = gin_channels
            self.hinge_regularizer = hinge_regularizer
            self.kernel_size = kernel_size
            self.layers = layers
            self.log_scale_min = log_scale_min
            self.n_speakers = n_speakers
            self.out_channels = out_channels
            self.residual_channels = residual_channels
            self.scalar_input = scalar_input
            self.skip_out_channels = skip_out_channels
            self.stacks = stacks
            self.upsample_conditional_features = upsample_conditional_features
            self.upsample_scales = upsample_scales
            self.use_speaker_embedding = use_speaker_embedding
            self.weight_normalization = weight_normalization
            self.legacy = legacy

        def create_model(self):
            return WaveNetWrapper(self)

    def __init__(self, config):
        super().__init__()

        # forward() reads len_in_out_multiplier during inference, but the Config
        # above does not define it; fall back to 1 (one generated sample per
        # conditioning frame) when it is absent.
        self.len_in_out_multiplier = getattr(config, "len_in_out_multiplier", 1)

        # Instantiate the model from the wavenet_vocoder package.
        self.model = WaveNet(
            out_channels=config.out_channels,
            layers=config.layers,
            stacks=config.stacks,
            residual_channels=config.residual_channels,
            gate_channels=config.gate_channels,
            skip_out_channels=config.skip_out_channels,
            kernel_size=config.kernel_size,
            dropout=config.dropout,
            weight_normalization=config.weight_normalization,
            cin_channels=config.cin_channels,
            gin_channels=config.gin_channels,
            n_speakers=config.n_speakers,
            upsample_conditional_features=config.upsample_conditional_features,
            upsample_scales=config.upsample_scales,
            freq_axis_kernel_size=config.freq_axis_kernel_size,
            scalar_input=config.scalar_input,
            use_speaker_embedding=config.use_speaker_embedding,
            legacy=config.legacy
        )

        self.has_weight_norm = True
        # self.__deepcopy__ = MethodType(__deepcopy__, self)

    def forward(self, input_, target, seq_lengths, *_):

        if target is not None:  # During training and testing with teacher forcing.
            assert self.has_weight_norm, "Model has been used for generation " \
                "and weight norm was removed, cannot continue training. Remove"\
                " the make_generation_fast_() call to continue training after" \
                " generation."
            output = self.model(target, c=input_, g=None, softmax=False)
            # output = self.model(target, c=input_[:, :, :target.shape[2]], g=None, softmax=False)
            # Output shape is B x C x T. Don't permute here: CrossEntropyLoss
            # expects the class dimension second, i.e. logits of shape B x C x T.
        else:  # During inference.
            with torch.no_grad():
                self.model.make_generation_fast_()  # After calling this the training cannot be continued.
                self.has_weight_norm = False
                assert len(seq_lengths) == 1, "Batch synthesis is not supported."
                num_frames_to_gen = seq_lengths[0] * self.len_in_out_multiplier
                output = self.model.incremental_forward(
                    c=input_, T=num_frames_to_gen, softmax=True, quantize=True)
                # output = self.model.incremental_forward(
                #     c=input_[:, :, :1000], T=torch.tensor(1000), softmax=True, quantize=True)

        # Output shape is B x C x T.
        return output, None

    def set_gpu_flag(self, use_gpu):
        self.use_gpu = use_gpu

    def init_hidden(self, batch_size=1):
        return None

    def parameters(self):
        return self.model.parameters()
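
A corresponding usage sketch for the Config-based variant. The channel sizes and random tensors are illustrative only; the shapes follow the B x C x T convention documented in forward().

config = WaveNetWrapper.Config(cin_channels=80, out_channels=256)
model = config.create_model()

# Teacher-forced training pass (random data only to exercise the shapes):
# conditioning is B x cin_channels x T, the target is B x out_channels x T.
conditioning = torch.randn(1, 80, 200)
target = torch.randn(1, 256, 200)
output, _ = model(conditioning, target, torch.tensor([200]))

# Inference pass. Note that this calls make_generation_fast_(), which removes
# weight normalization, so the same instance cannot be trained afterwards.
generated, _ = model(conditioning, None, torch.tensor([200]))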