Exemplo n.º 1
0
    def __init__(self, config: GlowTTSConfig):
        """Build the GlowTTS encoder and decoder from a configuration object.

        Args:
            config (GlowTTSConfig): Model configuration. Every field is also
                copied onto the instance as an attribute of the same name.
        """
        super().__init__()

        # Mirror each config field on the instance so the hyper-parameters can
        # be read as plain attributes below.
        self.config = config
        for field_name in config:
            setattr(self, field_name, config[field_name])

        # The second and third values returned by `get_characters` replace the
        # stored config and set the number of input symbols.
        _, self.config, self.num_chars = self.get_characters(config)
        self.decoder_output_dim = config.out_channels

        self.init_multispeaker(config)

        # Width of the speaker conditioning: 0 unless there is more than one
        # speaker, then the d-vector width when set, otherwise 256.
        # TODO: make the 256 fallback adjustable
        self.c_in_channels = 0
        if self.num_speakers > 1:
            self.c_in_channels = self.d_vector_dim if self.d_vector_dim else 256

        self.run_data_dep_init = config.data_dep_init_steps > 0

        encoder_options = dict(
            out_channels=self.out_channels,
            hidden_channels=self.hidden_channels_enc,
            hidden_channels_dp=self.hidden_channels_dp,
            encoder_type=self.encoder_type,
            encoder_params=self.encoder_params,
            mean_only=self.mean_only,
            use_prenet=self.use_encoder_prenet,
            dropout_p_dp=self.dropout_p_dp,
            c_in_channels=self.c_in_channels,
        )
        self.encoder = Encoder(self.num_chars, **encoder_options)

        decoder_options = dict(
            dropout_p=self.dropout_p_dec,
            num_splits=self.num_splits,
            num_squeeze=self.num_squeeze,
            sigmoid_scale=self.sigmoid_scale,
            c_in_channels=self.c_in_channels,
        )
        self.decoder = Decoder(
            self.out_channels,
            self.hidden_channels_dec,
            self.kernel_size_dec,
            self.dilation_rate,
            self.num_flow_blocks_dec,
            self.num_block_layers,
            **decoder_options,
        )
Exemplo n.º 2
0
    def __init__(
        self,
        config: GlowTTSConfig,
        ap: "AudioProcessor" = None,
        tokenizer: "TTSTokenizer" = None,
        speaker_manager: SpeakerManager = None,
    ):
        """Build the GlowTTS encoder and decoder from a configuration object.

        Args:
            config (GlowTTSConfig): Model configuration. Every field is also
                copied onto the instance as an attribute of the same name.
            ap (AudioProcessor, optional): Forwarded to the parent constructor.
            tokenizer (TTSTokenizer, optional): Forwarded to the parent constructor.
            speaker_manager (SpeakerManager, optional): Forwarded to the parent
                constructor.
        """
        super().__init__(config, ap, tokenizer, speaker_manager)

        # Mirror each config field on the instance so the hyper-parameters can
        # be read as plain attributes below.
        self.config = config
        for field_name in config:
            setattr(self, field_name, config[field_name])

        self.decoder_output_dim = config.out_channels

        # Set up the multi-speaker layers when the config asks for them.
        self.init_multispeaker(config)

        self.run_data_dep_init = config.data_dep_init_steps > 0

        encoder_options = dict(
            out_channels=self.out_channels,
            hidden_channels=self.hidden_channels_enc,
            hidden_channels_dp=self.hidden_channels_dp,
            encoder_type=self.encoder_type,
            encoder_params=self.encoder_params,
            mean_only=self.mean_only,
            use_prenet=self.use_encoder_prenet,
            dropout_p_dp=self.dropout_p_dp,
            c_in_channels=self.c_in_channels,
        )
        self.encoder = Encoder(self.num_chars, **encoder_options)

        decoder_options = dict(
            dropout_p=self.dropout_p_dec,
            num_splits=self.num_splits,
            num_squeeze=self.num_squeeze,
            sigmoid_scale=self.sigmoid_scale,
            c_in_channels=self.c_in_channels,
        )
        self.decoder = Decoder(
            self.out_channels,
            self.hidden_channels_dec,
            self.kernel_size_dec,
            self.dilation_rate,
            self.num_flow_blocks_dec,
            self.num_block_layers,
            **decoder_options,
        )
Exemplo n.º 3
0
    def __init__(
        self,
        num_chars,
        hidden_channels_enc,
        hidden_channels_dec,
        use_encoder_prenet,
        hidden_channels_dp,
        out_channels,
        num_flow_blocks_dec=12,
        inference_noise_scale=0.33,
        kernel_size_dec=5,
        dilation_rate=5,
        num_block_layers=4,
        dropout_p_dp=0.1,
        dropout_p_dec=0.05,
        num_speakers=0,
        c_in_channels=0,
        num_splits=4,
        num_squeeze=1,
        sigmoid_scale=False,
        mean_only=False,
        encoder_type="transformer",
        encoder_params=None,
        speaker_embedding_dim=None,
    ):
        """Store the hyper-parameters and build the encoder, decoder and, for
        multi-speaker models without external embeddings, a speaker embedding
        table."""
        super().__init__()

        # encoder / duration predictor settings
        self.num_chars = num_chars
        self.hidden_channels_enc = hidden_channels_enc
        self.hidden_channels_dp = hidden_channels_dp
        self.use_encoder_prenet = use_encoder_prenet
        self.mean_only = mean_only

        # decoder (flow) settings
        self.out_channels = out_channels
        self.hidden_channels_dec = hidden_channels_dec
        self.kernel_size_dec = kernel_size_dec
        self.dilation_rate = dilation_rate
        self.num_flow_blocks_dec = num_flow_blocks_dec
        self.num_block_layers = num_block_layers
        self.dropout_p_dec = dropout_p_dec
        self.num_splits = num_splits
        self.num_squeeze = num_squeeze
        self.sigmoid_scale = sigmoid_scale

        # speaker settings
        self.num_speakers = num_speakers
        self.c_in_channels = c_in_channels
        self.speaker_embedding_dim = speaker_embedding_dim

        # inference-time constants
        self.inference_noise_scale = inference_noise_scale
        self.noise_scale = 0.33  # noise variance applied to the random z vector at inference
        self.length_scale = 1.0  # duration scaler; larger values give slower speech

        # For multi-speaker models the conditioning width follows the external
        # embedding size when one is given; otherwise an unset (0) width falls
        # back to 256.
        if num_speakers > 1:
            if self.speaker_embedding_dim:
                self.c_in_channels = self.speaker_embedding_dim
            elif self.c_in_channels == 0:
                # TODO: make this adjustable
                self.c_in_channels = 256

        encoder_options = dict(
            out_channels=out_channels,
            hidden_channels=hidden_channels_enc,
            hidden_channels_dp=hidden_channels_dp,
            encoder_type=encoder_type,
            encoder_params=encoder_params,
            mean_only=mean_only,
            use_prenet=use_encoder_prenet,
            dropout_p_dp=dropout_p_dp,
            c_in_channels=self.c_in_channels,
        )
        self.encoder = Encoder(num_chars, **encoder_options)

        decoder_options = dict(
            dropout_p=dropout_p_dec,
            num_splits=num_splits,
            num_squeeze=num_squeeze,
            sigmoid_scale=sigmoid_scale,
            c_in_channels=self.c_in_channels,
        )
        self.decoder = Decoder(
            out_channels,
            hidden_channels_dec,
            kernel_size_dec,
            dilation_rate,
            num_flow_blocks_dec,
            num_block_layers,
            **decoder_options,
        )

        # A learned speaker table is only needed when there are several
        # speakers and no external embedding vectors are supplied.
        needs_speaker_table = num_speakers > 1 and not speaker_embedding_dim
        if needs_speaker_table:
            self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
Exemplo n.º 4
0
class GlowTTS(nn.Module):
    """Glow TTS model from https://arxiv.org/abs/2005.11129

    Args:
        num_chars (int): number of embedding characters.
        hidden_channels_enc (int): number of embedding and encoder channels.
        hidden_channels_dec (int): number of decoder channels.
        use_encoder_prenet (bool): enable/disable prenet for encoder. Prenet modules are hard-coded for each alternative encoder.
        hidden_channels_dp (int): number of duration predictor channels.
        out_channels (int): number of output channels. It should be equal to the number of spectrogram filter.
        num_flow_blocks_dec (int): number of decoder blocks.
        inference_noise_scale (float): multiplier applied to the sampled noise in :meth:`inference`.
        kernel_size_dec (int): decoder kernel size.
        dilation_rate (int): rate to increase dilation by each layer in a decoder block.
        num_block_layers (int): number of decoder layers in each decoder block.
        dropout_p_dp (float): dropout rate handed to the encoder as ``dropout_p_dp``.
        dropout_p_dec (float): dropout rate for decoder.
        num_speakers (int): number of speakers to define the size of speaker embedding layer.
        c_in_channels (int): number of speaker embedding channels. For multi-speaker models it is set to 256 when
            left at 0 and no ``speaker_embedding_dim`` is given, and to ``speaker_embedding_dim`` when that is given.
        num_splits (int): number of split levels in inversible conv1x1 operation.
        num_squeeze (int): number of squeeze levels. When squeezing channels increases and time steps reduces by the factor 'num_squeeze'.
        sigmoid_scale (bool): enable/disable sigmoid scaling in decoder.
        mean_only (bool): if True, encoder only computes mean value and uses constant variance for each time step.
        encoder_type (str): encoder module type.
        encoder_params (dict): encoder module parameters.
        speaker_embedding_dim (int): channels of external speaker embedding vectors.
    """

    def __init__(
        self,
        num_chars,
        hidden_channels_enc,
        hidden_channels_dec,
        use_encoder_prenet,
        hidden_channels_dp,
        out_channels,
        num_flow_blocks_dec=12,
        inference_noise_scale=0.33,
        kernel_size_dec=5,
        dilation_rate=5,
        num_block_layers=4,
        dropout_p_dp=0.1,
        dropout_p_dec=0.05,
        num_speakers=0,
        c_in_channels=0,
        num_splits=4,
        num_squeeze=1,
        sigmoid_scale=False,
        mean_only=False,
        encoder_type="transformer",
        encoder_params=None,
        speaker_embedding_dim=None,
    ):

        super().__init__()
        self.num_chars = num_chars
        self.hidden_channels_dp = hidden_channels_dp
        self.hidden_channels_enc = hidden_channels_enc
        self.hidden_channels_dec = hidden_channels_dec
        self.out_channels = out_channels
        self.num_flow_blocks_dec = num_flow_blocks_dec
        self.kernel_size_dec = kernel_size_dec
        self.dilation_rate = dilation_rate
        self.num_block_layers = num_block_layers
        self.dropout_p_dec = dropout_p_dec
        self.num_speakers = num_speakers
        self.c_in_channels = c_in_channels
        self.num_splits = num_splits
        self.num_squeeze = num_squeeze
        self.sigmoid_scale = sigmoid_scale
        self.mean_only = mean_only
        self.use_encoder_prenet = use_encoder_prenet
        self.inference_noise_scale = inference_noise_scale

        # model constants.
        self.noise_scale = 0.33  # defines the noise variance applied to the random z vector at inference.
        self.length_scale = 1.0  # scaler for the duration predictor. The larger it is, the slower the speech.
        self.speaker_embedding_dim = speaker_embedding_dim

        # if is a multispeaker and c_in_channels is 0, set to 256
        if num_speakers > 1:
            if self.c_in_channels == 0 and not self.speaker_embedding_dim:
                # TODO: make this adjustable
                self.c_in_channels = 256
            elif self.speaker_embedding_dim:
                self.c_in_channels = self.speaker_embedding_dim

        self.encoder = Encoder(
            num_chars,
            out_channels=out_channels,
            hidden_channels=hidden_channels_enc,
            hidden_channels_dp=hidden_channels_dp,
            encoder_type=encoder_type,
            encoder_params=encoder_params,
            mean_only=mean_only,
            use_prenet=use_encoder_prenet,
            dropout_p_dp=dropout_p_dp,
            c_in_channels=self.c_in_channels,
        )

        self.decoder = Decoder(
            out_channels,
            hidden_channels_dec,
            kernel_size_dec,
            dilation_rate,
            num_flow_blocks_dec,
            num_block_layers,
            dropout_p=dropout_p_dec,
            num_splits=num_splits,
            num_squeeze=num_squeeze,
            sigmoid_scale=sigmoid_scale,
            c_in_channels=self.c_in_channels,
        )

        if num_speakers > 1 and not speaker_embedding_dim:
            # speaker embedding layer
            self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

    @staticmethod
    def compute_outputs(attn, o_mean, o_log_scale, x_mask):
        """Expand the encoder statistics along the given alignment.

        Returns:
            tuple: ``(y_mean, y_log_scale, o_attn_dur)`` where the first two are the
            aligned mean and log-scale and the last is ``log(1 + duration)`` masked by ``x_mask``.
        """
        # compute final values with the computed alignment
        y_mean = torch.matmul(
            attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
                1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        y_log_scale = torch.matmul(
            attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(
                1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        # compute total duration with adjustment
        o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
        return y_mean, y_log_scale, o_attn_dur

    def _speaker_conditioning(self, g):
        """Return the L2-normalized speaker tensor ``[b, h, 1]``, or ``None`` when ``g`` is ``None``.

        With ``speaker_embedding_dim`` set, ``g`` is used as the embedding itself;
        otherwise ``g`` is looked up in the learned ``emb_g`` table.
        """
        if g is None:
            return None
        if self.speaker_embedding_dim:
            return F.normalize(g).unsqueeze(-1)
        return F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

    @staticmethod
    def _monotonic_alignment(o_mean, o_log_scale, z, attn_mask):
        """Score every (text step, frame) pair and return the best path as a detached ``[b, 1, t, t']`` tensor."""
        o_scale = torch.exp(-2 * o_log_scale)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
                          [1]).unsqueeze(-1)  # [b, t, 1]
        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
                             (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
                             z)  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
                          [1]).unsqueeze(-1)  # [b, t, 1]
        logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
        return maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

    def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
        """Training pass: encode the text, run the decoder forward and align the two.

        ``y`` is required in practice: its time dimension is read unconditionally.
        The ``attn`` argument is accepted but not used; the alignment is always recomputed.

        Shapes:
            x: [B, T]
            x_lenghts: B
            y: [B, C, T]
            y_lengths: B
            g: [B, C] or B
        """
        y_max_length = y.size(2)
        # norm speaker embeddings
        g = self._speaker_conditioning(g)

        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                              x_lengths,
                                                              g=g)
        # drop residual frames wrt num_squeeze and set y_lengths.
        y, y_lengths, y_max_length, attn = self.preprocess(
            y, y_lengths, y_max_length, None)
        # create masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # decoder pass
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
        # find the alignment path; no gradient flows through the search
        with torch.no_grad():
            attn = self._monotonic_alignment(o_mean, o_log_scale, z,
                                             attn_mask)
        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)
        attn = attn.squeeze(1).permute(0, 2, 1)
        return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur

    @torch.no_grad()
    def inference_with_MAS(self,
                           x,
                           x_lengths,
                           y=None,
                           y_lengths=None,
                           attn=None,
                           g=None):
        """
        It's similar to the teacher forcing in Tacotron.
        It was proposed in: https://arxiv.org/abs/2104.05557

        ``y`` is required in practice: its time dimension is read unconditionally.

        Shapes:
            x: [B, T]
            x_lenghts: B
            y: [B, C, T]
            y_lengths: B
            g: [B, C] or B
        """
        y_max_length = y.size(2)
        # norm speaker embeddings (reads `speaker_embedding_dim`, the attribute
        # that `__init__` actually sets)
        g = self._speaker_conditioning(g)

        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                              x_lengths,
                                                              g=g)
        # drop residual frames wrt num_squeeze and set y_lengths.
        y, y_lengths, y_max_length, attn = self.preprocess(
            y, y_lengths, y_max_length, None)
        # create masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # decoder pass
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
        # find the alignment path between z and encoder output
        attn = self._monotonic_alignment(o_mean, o_log_scale, z, attn_mask)

        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)
        attn = attn.squeeze(1).permute(0, 2, 1)

        # get predicted aligned distribution
        z = y_mean * y_mask

        # reverse the decoder and predict using the aligned distribution
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)

        return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur

    @torch.no_grad()
    def decoder_inference(self, y, y_lengths=None, g=None):
        """Run the decoder forward and then in reverse on ``y``.

        Shapes:
            y: [B, C, T]
            y_lengths: B
            g: [B, C] or B
        """
        y_max_length = y.size(2)
        # norm speaker embeddings (reads `speaker_embedding_dim`, the attribute
        # that `__init__` actually sets)
        g = self._speaker_conditioning(g)

        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(y.dtype)

        # decoder pass
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)

        # reverse decoder and predict
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)

        return y, logdet

    @torch.no_grad()
    def inference(self, x, x_lengths, g=None):
        """Synthesize from text using the predicted durations.

        Returns the same 7-tuple layout as :meth:`forward`, with the decoder
        output ``y`` in the first position.
        """
        g = self._speaker_conditioning(g)

        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                              x_lengths,
                                                              g=g)
        # compute output durations
        w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
        w_ceil = torch.ceil(w)
        # every item gets at least one output frame
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = None
        # compute masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # build the alignment from the rounded-up durations
        attn = generate_path(w_ceil.squeeze(1),
                             attn_mask.squeeze(1)).unsqueeze(1)
        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)

        z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
             self.inference_noise_scale) * y_mask
        # decoder pass
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
        attn = attn.squeeze(1).permute(0, 2, 1)
        return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur

    def preprocess(self, y, y_lengths, y_max_length, attn=None):
        """Trim ``y`` (and ``attn``) so the time length is a multiple of ``num_squeeze`` and round ``y_lengths`` down to match."""
        if y_max_length is not None:
            y_max_length = (y_max_length //
                            self.num_squeeze) * self.num_squeeze
            y = y[:, :, :y_max_length]
            if attn is not None:
                attn = attn[:, :, :, :y_max_length]
        y_lengths = (y_lengths // self.num_squeeze) * self.num_squeeze
        return y, y_lengths, y_max_length, attn

    def store_inverse(self):
        """Delegate to the decoder's ``store_inverse``."""
        self.decoder.store_inverse()

    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
        """Load the ``"model"`` state dict from ``checkpoint_path`` onto CPU.

        When ``eval`` is true the model is also switched to eval mode and the
        decoder inverse is stored. ``config`` is unused.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            self.store_inverse()
            assert not self.training
Exemplo n.º 5
0
class GlowTTS(BaseTTS):
    """GlowTTS model.

    Paper::
        https://arxiv.org/abs/2005.11129

    Paper abstract::
        Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate
        mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained
        without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS,
        a flow-based generative model for parallel TTS that does not require any external aligner. By combining the
        properties of flows and dynamic programming, the proposed model searches for the most probable monotonic
        alignment between text and the latent representation of speech on its own. We demonstrate that enforcing hard
        monotonic alignments enables robust TTS, which generalizes to long utterances, and employing generative flows
        enables fast, diverse, and controllable speech synthesis. Glow-TTS obtains an order-of-magnitude speed-up over
        the autoregressive model, Tacotron 2, at synthesis with comparable speech quality. We further show that our
        model can be easily extended to a multi-speaker setting.

    Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.

    Examples:
        >>> from TTS.tts.configs import GlowTTSConfig
        >>> from TTS.tts.models.glow_tts import GlowTTS
        >>> config = GlowTTSConfig()
        >>> model = GlowTTS(config)

    """
    def __init__(self, config: GlowTTSConfig):
        """Build the GlowTTS encoder and decoder from a configuration object.

        Args:
            config (GlowTTSConfig): Model configuration. Every field is also
                copied onto the instance as an attribute of the same name.
        """
        super().__init__()

        # Mirror each config field on the instance so the hyper-parameters can
        # be read as plain attributes below.
        self.config = config
        for field_name in config:
            setattr(self, field_name, config[field_name])

        # The second and third values returned by `get_characters` replace the
        # stored config and set the number of input symbols.
        _, self.config, self.num_chars = self.get_characters(config)
        self.decoder_output_dim = config.out_channels

        # NOTE(review): this runs while `self.c_in_channels` still holds the
        # config value; the attribute is overwritten just below. Confirm the
        # width `init_multispeaker` uses for `emb_g` matches the final one.
        self.init_multispeaker(config)

        # Width of the speaker conditioning: 0 unless there is more than one
        # speaker, then the d-vector width when set, otherwise 256.
        # TODO: make the 256 fallback adjustable
        self.c_in_channels = 0
        if self.num_speakers > 1:
            self.c_in_channels = self.d_vector_dim if self.d_vector_dim else 256

        self.run_data_dep_init = config.data_dep_init_steps > 0

        encoder_options = dict(
            out_channels=self.out_channels,
            hidden_channels=self.hidden_channels_enc,
            hidden_channels_dp=self.hidden_channels_dp,
            encoder_type=self.encoder_type,
            encoder_params=self.encoder_params,
            mean_only=self.mean_only,
            use_prenet=self.use_encoder_prenet,
            dropout_p_dp=self.dropout_p_dp,
            c_in_channels=self.c_in_channels,
        )
        self.encoder = Encoder(self.num_chars, **encoder_options)

        decoder_options = dict(
            dropout_p=self.dropout_p_dec,
            num_splits=self.num_splits,
            num_squeeze=self.num_squeeze,
            sigmoid_scale=self.sigmoid_scale,
            c_in_channels=self.c_in_channels,
        )
        self.decoder = Decoder(
            self.out_channels,
            self.hidden_channels_dec,
            self.kernel_size_dec,
            self.dilation_rate,
            self.num_flow_blocks_dec,
            self.num_block_layers,
            **decoder_options,
        )

    def init_multispeaker(self, config: "Coqpit", data: list = None) -> None:
        """Set up the multi-speaker parts of the model.

        Speakers are conditioned either through a learned embedding table or
        through external `d_vectors` read from a file, depending on the config.
        Override this method if a model needs different behaviour.

        Args:
            config (Coqpit): Model configuration.
            data (List, optional): Dataset items handed to the speaker manager
                factory. Defaults to None.
        """
        # The speaker manager is the source of the speaker count.
        manager = get_speaker_manager(config, data=data)
        self.speaker_manager = manager
        self.num_speakers = manager.num_speakers

        # The external d-vector width only applies when d-vectors come from a file.
        self.external_d_vector_dim = config.d_vector_dim if config.use_d_vector_file else 0

        # A learned table is created only for embedding mode without a d-vector file.
        learns_embedding = config.use_speaker_embedding and not config.use_d_vector_file
        if not learns_embedding:
            return
        self.embedded_speaker_dim = self.c_in_channels
        self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
        nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

    @staticmethod
    def compute_outputs(attn, o_mean, o_log_scale, x_mask):
        """Compute and format the mode outputs with the given alignment map"""
        y_mean = torch.matmul(
            attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
                1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        y_log_scale = torch.matmul(
            attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(
                1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        # compute total duration with adjustment
        o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
        return y_mean, y_log_scale, o_attn_dur

    def unlock_act_norm_layers(self):
        """Enable data-dependent initialization on every decoder flow that supports it."""
        for flow in self.decoder.flows:
            # flows without a usable `set_ddi` are left alone
            if not getattr(flow, "set_ddi", False):
                continue
            flow.set_ddi(True)

    def lock_act_norm_layers(self):
        """Disable data-dependent initialization on every decoder flow that supports it."""
        for flow in self.decoder.flows:
            # flows without a usable `set_ddi` are left alone
            if not getattr(flow, "set_ddi", False):
                continue
            flow.set_ddi(False)

    def forward(self,
                x,
                x_lengths,
                y,
                y_lengths=None,
                aux_input={
                    "d_vectors": None,
                    "speaker_ids": None
                }):  # pylint: disable=dangerous-default-value
        """Training pass: encode the text, run the decoder forward and align the two.

        Speaker conditioning is taken from ``aux_input``: the ``"d_vectors"``
        entry when ``use_d_vector_file`` is set, otherwise the ``"speaker_ids"``
        entry looked up in ``emb_g`` when ``use_speaker_embedding`` is set.

        Shapes:
            - x: :math:`[B, T]`
            - x_lenghts::math:`B`
            - y: :math:`[B, T, C]`
            - y_lengths::math:`B`
            - g: :math:`[B, C] or B`

        Returns:
            dict: ``z``, ``logdet``, ``y_mean``, ``y_log_scale``, ``alignments``,
            ``durations_log`` and ``total_durations_log``.

        Raises:
            ValueError: If speaker embeddings are enabled but ``aux_input``
                carries no ``"speaker_ids"``.
        """
        # [B, T, C] -> [B, C, T]
        y = y.transpose(1, 2)
        y_max_length = y.size(2)
        # norm speaker embeddings
        has_aux = aux_input is not None
        g = aux_input[
            "d_vectors"] if has_aux and "d_vectors" in aux_input else None
        if self.use_d_vector_file:
            # external d-vectors are used directly; no `emb_g` exists in this mode
            g = F.normalize(g).unsqueeze(-1)  # [b, h, 1]
        elif self.use_speaker_embedding:
            # learned embeddings are indexed by speaker id, not by d-vector
            speaker_ids = aux_input[
                "speaker_ids"] if has_aux and "speaker_ids" in aux_input else None
            if speaker_ids is None:
                raise ValueError(
                    "`aux_input['speaker_ids']` is required when `use_speaker_embedding` is enabled."
                )
            g = F.normalize(self.emb_g(speaker_ids)).unsqueeze(-1)  # [b, h, 1]
        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                              x_lengths,
                                                              g=g)
        # drop residual frames wrt num_squeeze and set y_lengths.
        y, y_lengths, y_max_length, attn = self.preprocess(
            y, y_lengths, y_max_length, None)
        # create masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(x_mask.dtype)
        # [B, 1, T_en, T_de]
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # decoder pass
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
        # find the alignment path; no gradient flows through the search
        with torch.no_grad():
            o_scale = torch.exp(-2 * o_log_scale)
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
                              [1]).unsqueeze(-1)  # [b, t, 1]
            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
                                 (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
            logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
                                 z)  # [b, t, d] x [b, d, t'] = [b, t, t']
            logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
                              [1]).unsqueeze(-1)  # [b, t, 1]
            logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
            attn = maximum_path(logp,
                                attn_mask.squeeze(1)).unsqueeze(1).detach()
        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)
        attn = attn.squeeze(1).permute(0, 2, 1)
        outputs = {
            "z": z.transpose(1, 2),
            "logdet": logdet,
            "y_mean": y_mean.transpose(1, 2),
            "y_log_scale": y_log_scale.transpose(1, 2),
            "alignments": attn,
            "durations_log": o_dur_log.transpose(1, 2),
            "total_durations_log": o_attn_dur.transpose(1, 2),
        }
        return outputs

    @torch.no_grad()
    def inference_with_MAS(self,
                           x,
                           x_lengths,
                           y=None,
                           y_lengths=None,
                           aux_input={
                               "d_vectors": None,
                               "speaker_ids": None
                           }):  # pylint: disable=dangerous-default-value
        """
        It's similar to the teacher forcing in Tacotron.
        It was proposed in: https://arxiv.org/abs/2104.05557

        ``y`` is required in practice: it is transposed unconditionally.
        Speaker conditioning is taken from ``aux_input``: the ``"d_vectors"``
        entry when ``use_d_vector_file`` is set, otherwise the ``"speaker_ids"``
        entry looked up in ``emb_g`` when ``use_speaker_embedding`` is set.

        Shapes:
            - x: :math:`[B, T]`
            - x_lenghts: :math:`B`
            - y: :math:`[B, T, C]`
            - y_lengths: :math:`B`
            - g: :math:`[B, C] or B`

        Returns:
            dict: ``model_outputs`` (the reverse-decoder output), ``logdet``,
            ``y_mean``, ``y_log_scale``, ``alignments``, ``durations_log`` and
            ``total_durations_log``.

        Raises:
            ValueError: If speaker embeddings are enabled but ``aux_input``
                carries no ``"speaker_ids"``.
        """
        y = y.transpose(1, 2)
        y_max_length = y.size(2)
        # norm speaker embeddings
        has_aux = aux_input is not None
        g = aux_input[
            "d_vectors"] if has_aux and "d_vectors" in aux_input else None
        if self.use_d_vector_file:
            # external d-vectors are used directly; no `emb_g` exists in this mode
            g = F.normalize(g).unsqueeze(-1)  # [b, h, 1]
        elif self.use_speaker_embedding:
            # learned embeddings are indexed by speaker id, not by d-vector
            speaker_ids = aux_input[
                "speaker_ids"] if has_aux and "speaker_ids" in aux_input else None
            if speaker_ids is None:
                raise ValueError(
                    "`aux_input['speaker_ids']` is required when `use_speaker_embedding` is enabled."
                )
            g = F.normalize(self.emb_g(speaker_ids)).unsqueeze(-1)  # [b, h, 1]
        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                              x_lengths,
                                                              g=g)
        # drop residual frames wrt num_squeeze and set y_lengths.
        y, y_lengths, y_max_length, attn = self.preprocess(
            y, y_lengths, y_max_length, None)
        # create masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # decoder pass
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
        # find the alignment path between z and encoder output
        o_scale = torch.exp(-2 * o_log_scale)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
                          [1]).unsqueeze(-1)  # [b, t, 1]
        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
                             (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
                             z)  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
                          [1]).unsqueeze(-1)  # [b, t, 1]
        logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)
        attn = attn.squeeze(1).permute(0, 2, 1)

        # get predicted aligned distribution
        z = y_mean * y_mask

        # reverse the decoder and predict using the aligned distribution
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
        outputs = {
            # return the decoder prediction, not the latent it was computed from
            "model_outputs": y.transpose(1, 2),
            "logdet": logdet,
            "y_mean": y_mean.transpose(1, 2),
            "y_log_scale": y_log_scale.transpose(1, 2),
            "alignments": attn,
            "durations_log": o_dur_log.transpose(1, 2),
            "total_durations_log": o_attn_dur.transpose(1, 2),
        }
        return outputs

    @torch.no_grad()
    def decoder_inference(self, y, y_lengths=None, aux_input=None):
        """Pass a spectrogram batch through the decoder and straight back again.

        The input is first mapped to the latent space (``reverse=False``) and the
        resulting latents are then decoded again (``reverse=True``).

        Args:
            y: Input frames; transposed with ``transpose(1, 2)`` before decoding.
            y_lengths: Frame lengths forwarded to ``sequence_mask`` to build the mask.
                NOTE(review): the default ``None`` is forwarded unchanged — confirm
                ``sequence_mask`` accepts it.
            aux_input (dict, optional): May carry ``"d_vectors"``. ``None`` or a
                missing / ``None`` entry disables speaker conditioning.
                ``"speaker_ids"`` is not read by this method.

        Returns:
            dict: ``"model_outputs"`` (decoded frames, transposed back) and
            ``"logdet"`` (the value returned by the reverse decoder pass).

        Shapes:
            - y: :math:`[B, T, C]`
            - y_lengths: :math:`B`
            - g: :math:`[B, C] or B`
        """
        y = y.transpose(1, 2)
        y_max_length = y.size(2)
        # `aux_input` defaults to None (not a shared dict literal), so guard the lookup
        g = aux_input.get("d_vectors") if aux_input is not None else None
        # norm speaker embeddings
        if g is not None:
            # NOTE(review): `inference` branches on `self.d_vector_dim` for the same
            # decision; confirm `external_d_vector_dim` is the intended attribute.
            if self.external_d_vector_dim:
                g = F.normalize(g).unsqueeze(-1)
            else:
                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(y.dtype)

        # decoder pass (its logdet is discarded; the reverse pass overwrites it)
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)

        # reverse decoder and predict
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)

        return {"model_outputs": y.transpose(1, 2), "logdet": logdet}

    @torch.no_grad()
    def inference(self, x, aux_input=None):
        """Generate frames from token ids using the durations predicted by the encoder.

        Args:
            x: Token id batch handed to the encoder.
            aux_input (dict, optional): Optional entries ``"x_lengths"`` (forwarded
                to the encoder) and ``"d_vectors"`` (speaker conditioning). ``None``
                is treated like an empty dict. ``"speaker_ids"`` is not read here.

        Returns:
            dict: ``model_outputs``, ``logdet``, ``y_mean``, ``y_log_scale``,
            ``alignments``, ``durations_log`` and ``total_durations_log``.
        """
        # Normalise a missing `aux_input` to an empty dict so that both lookups
        # below are safe; previously only the "d_vectors" lookup tolerated None.
        if aux_input is None:
            aux_input = {}
        x_lengths = aux_input.get("x_lengths")
        g = aux_input.get("d_vectors")

        if g is not None:
            if self.d_vector_dim:
                g = F.normalize(g).unsqueeze(-1)
            else:
                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
                                                              x_lengths,
                                                              g=g)
        # compute output durations
        w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
        w_ceil = torch.ceil(w)
        # every item gets at least one output frame
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = None
        # compute masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                                 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # expand the rounded durations into an alignment path
        attn = generate_path(w_ceil.squeeze(1),
                             attn_mask.squeeze(1)).unsqueeze(1)
        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)

        # sample the latent around the aligned mean, scaled by the inference noise
        z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
             self.inference_noise_scale) * y_mask
        # decoder pass
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
        attn = attn.squeeze(1).permute(0, 2, 1)
        outputs = {
            "model_outputs": y.transpose(1, 2),
            "logdet": logdet,
            "y_mean": y_mean.transpose(1, 2),
            "y_log_scale": y_log_scale.transpose(1, 2),
            "alignments": attn,
            "durations_log": o_dur_log.transpose(1, 2),
            "total_durations_log": o_attn_dur.transpose(1, 2),
        }
        return outputs

    def train_step(self, batch: dict, criterion: nn.Module):
        """Run one training step: a forward pass followed by the loss computation.

        While ``self.run_data_dep_init`` is set and the module is in training mode,
        the step only performs a gradient-free forward pass between
        ``unlock_act_norm_layers()`` and ``lock_act_norm_layers()`` and yields
        ``(None, None)``.

        Args:
            batch (dict): Must provide ``text_input``, ``text_lengths``,
                ``mel_input``, ``mel_lengths``, ``d_vectors`` and ``speaker_ids``.
            criterion (nn.Module): Called with the float-cast model outputs and the
                mel / text lengths.

        Returns:
            tuple: ``(outputs, loss_dict)``; both are ``None`` for an
            initialization step.
        """
        tokens = batch["text_input"]
        token_lengths = batch["text_lengths"]
        mels = batch["mel_input"]
        mel_lengths = batch["mel_lengths"]
        speaker_inputs = {
            "d_vectors": batch["d_vectors"],
            "speaker_ids": batch["speaker_ids"],
        }

        if self.run_data_dep_init and self.training:
            # data-dependent initialization of the activation norm layers:
            # the forward result itself is not needed
            self.unlock_act_norm_layers()
            with torch.no_grad():
                self.forward(tokens,
                             token_lengths,
                             mels,
                             mel_lengths,
                             aux_input=speaker_inputs)
            self.lock_act_norm_layers()
            return None, None

        # regular training step
        outputs = self.forward(tokens,
                               token_lengths,
                               mels,
                               mel_lengths,
                               aux_input=speaker_inputs)

        # keep the criterion out of mixed precision
        with autocast(enabled=False):
            loss_dict = criterion(
                outputs["z"].float(),
                outputs["y_mean"].float(),
                outputs["y_log_scale"].float(),
                outputs["logdet"].float(),
                mel_lengths,
                outputs["durations_log"].float(),
                outputs["total_durations_log"].float(),
                token_lengths,
            )
        return outputs, loss_dict

    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
        """Build the figures and the audio sample logged for a training step.

        Inference is run on the first item of the batch only; its spectrogram is
        plotted next to the ground truth and the training alignment, and is also
        inverted to a waveform with ``ap.inv_melspectrogram``.

        Args:
            ap (AudioProcessor): Used for plotting and spectrogram inversion.
            batch (dict): Must provide ``text_input``, ``text_lengths``,
                ``mel_input``, ``d_vectors`` and ``speaker_ids``.
            outputs (dict): Training outputs; only ``"alignments"`` is read.

        Returns:
            tuple: ``(figures, {"audio": waveform})`` where ``figures`` has the keys
            ``"prediction"``, ``"ground_truth"`` and ``"alignment"``.
        """
        alignments = outputs["alignments"]
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]

        # Only the first sample is synthesized, so the speaker conditioning has to
        # be cut down to the same single-item batch as the text input.
        if d_vectors is not None:
            d_vectors = d_vectors[:1]
        if speaker_ids is not None:
            speaker_ids = speaker_ids[:1]

        # model runs reverse flow to predict spectrograms
        pred_outputs = self.inference(
            text_input[:1],
            aux_input={
                "x_lengths": text_lengths[:1],
                "d_vectors": d_vectors,
                "speaker_ids": speaker_ids
            },
        )
        model_outputs = pred_outputs["model_outputs"]

        pred_spec = model_outputs[0].data.cpu().numpy()
        gt_spec = mel_input[0].data.cpu().numpy()
        align_img = alignments[0].data.cpu().numpy()

        figures = {
            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
            "alignment": plot_alignment(align_img, output_fig=False),
        }

        # Sample audio
        train_audio = ap.inv_melspectrogram(pred_spec.T)
        return figures, {"audio": train_audio}

    @torch.no_grad()
    def eval_step(self, batch: dict, criterion: nn.Module):
        """Run an evaluation step by delegating to `train_step`.

        Gradient tracking is disabled by the `torch.no_grad()` decorator.

        Args:
            batch (dict): Forwarded unchanged to `train_step`.
            criterion (nn.Module): Forwarded unchanged to `train_step`.

        Returns:
            Whatever `train_step` returns for this batch.
        """
        return self.train_step(batch, criterion)

    def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
        """Produce the evaluation logs by reusing `train_log` with the same arguments."""
        return self.train_log(ap, batch, outputs)

    @torch.no_grad()
    def test_run(self, ap):
        """Synthesize every sentence in `config.test_sentences` for logging.

        Each sentence goes through `synthesis` with Griffin-Lim enabled and
        silence trimming disabled; the waveform, the predicted spectrogram plot
        and the alignment plot are collected under keys prefixed with the
        sentence index.

        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        print(" | > Synthesizing test sentences.")
        figures, audios = {}, {}
        sentences = self.config.test_sentences
        aux_inputs = self.get_aux_input()

        if len(sentences) == 0:
            print(" | [!] No test sentences provided.")
            return figures, audios

        for idx, sentence in enumerate(sentences):
            # run on CUDA exactly when the model parameters live there
            on_cuda = "cuda" in str(next(self.parameters()).device)
            result = synthesis(
                self,
                sentence,
                self.config,
                on_cuda,
                ap,
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                use_griffin_lim=True,
                do_trim_silence=False,
            )

            audios[f"{idx}-audio"] = result["wav"]
            figures[f"{idx}-prediction"] = plot_spectrogram(
                result["outputs"]["model_outputs"], ap, output_fig=False)
            figures[f"{idx}-alignment"] = plot_alignment(
                result["alignments"], output_fig=False)
        return figures, audios

    def preprocess(self, y, y_lengths, y_max_length, attn=None):
        """Round lengths down to a multiple of `self.num_squeeze` and crop to match.

        Args:
            y: Frames, cropped along the third axis when `y_max_length` is given.
            y_lengths: Per-item lengths; always rounded down.
            y_max_length: Maximum length or ``None`` (then nothing is cropped).
            attn: Optional alignment, cropped along its fourth axis together with `y`.

        Returns:
            tuple: ``(y, y_lengths, y_max_length, attn)`` after rounding / cropping.
        """
        factor = self.num_squeeze
        rounded_lengths = (y_lengths // factor) * factor

        if y_max_length is None:
            # no maximum given: only the lengths change
            return y, rounded_lengths, y_max_length, attn

        y_max_length = (y_max_length // factor) * factor
        y = y[:, :, :y_max_length]
        if attn is not None:
            attn = attn[:, :, :, :y_max_length]
        return y, rounded_lengths, y_max_length, attn

    def store_inverse(self):
        """Forward the call to `self.decoder.store_inverse()`."""
        self.decoder.store_inverse()

    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
        """Restore the model weights from a checkpoint file.

        Args:
            config: Not used by this method.
            checkpoint_path: Location handed to `load_fsspec`; the checkpoint is
                mapped to the CPU and its ``"model"`` entry is loaded.
            eval (bool): When true, also switch to eval mode and call
                `store_inverse()`.
        """
        checkpoint = load_fsspec(checkpoint_path,
                                 map_location=torch.device("cpu"))
        self.load_state_dict(checkpoint["model"])
        if not eval:
            return
        # inference setup: eval mode first, then the decoder inverse
        self.eval()
        self.store_inverse()
        assert not self.training

    def get_criterion(self):
        """Instantiate and return a `GlowTTSLoss` object."""
        # the loss class is imported at call time rather than at module import
        from TTS.tts.layers.losses import GlowTTSLoss  # pylint: disable=import-outside-toplevel

        return GlowTTSLoss()

    def on_train_step_start(self, trainer):
        """Decide before every training step whether data-dependent initialization runs.

        The flag `run_data_dep_init` is on exactly while `trainer.total_steps_done`
        is below `self.data_dep_init_steps`.
        """
        steps_done = trainer.total_steps_done
        still_initializing = steps_done < self.data_dep_init_steps
        self.run_data_dep_init = still_initializing
Exemplo n.º 6
0
    def __init__(self,
                 num_chars,
                 hidden_channels,
                 filter_channels,
                 filter_channels_dp,
                 out_channels,
                 kernel_size=3,
                 num_heads=2,
                 num_layers_enc=6,
                 dropout_p=0.1,
                 num_flow_blocks_dec=12,
                 kernel_size_dec=5,
                 dilation_rate=5,
                 num_block_layers=4,
                 dropout_p_dec=0.,
                 num_speakers=0,
                 c_in_channels=0,
                 num_splits=4,
                 num_sqz=1,
                 sigmoid_scale=False,
                 rel_attn_window_size=None,
                 input_length=None,
                 mean_only=False,
                 hidden_channels_enc=None,
                 hidden_channels_dec=None,
                 use_encoder_prenet=False,
                 encoder_type="transformer"):
        """Create the encoder, the decoder and, if needed, the speaker embedding.

        Every argument except `encoder_type` is mirrored as an attribute of the
        same name. `noise_scale` and `length_scale` start at 0.66 and 1.0.

        Notes:
            - The encoder is built with `hidden_channels`; `hidden_channels_enc`
              is only stored.
            - The decoder uses `hidden_channels_dec`, falling back to
              `hidden_channels` when it is falsy.
            - `rel_attn_window_size` and `input_length` are only stored.
            - `emb_g` (speaker embedding of width `c_in_channels`) exists only
              when `num_speakers > 1`.
        """
        super().__init__()

        # mirror the constructor arguments on the instance
        mirrored = {
            "num_chars": num_chars,
            "hidden_channels": hidden_channels,
            "filter_channels": filter_channels,
            "filter_channels_dp": filter_channels_dp,
            "out_channels": out_channels,
            "kernel_size": kernel_size,
            "num_heads": num_heads,
            "num_layers_enc": num_layers_enc,
            "dropout_p": dropout_p,
            "num_flow_blocks_dec": num_flow_blocks_dec,
            "kernel_size_dec": kernel_size_dec,
            "dilation_rate": dilation_rate,
            "num_block_layers": num_block_layers,
            "dropout_p_dec": dropout_p_dec,
            "num_speakers": num_speakers,
            "c_in_channels": c_in_channels,
            "num_splits": num_splits,
            "num_sqz": num_sqz,
            "sigmoid_scale": sigmoid_scale,
            "rel_attn_window_size": rel_attn_window_size,
            "input_length": input_length,
            "mean_only": mean_only,
            "hidden_channels_enc": hidden_channels_enc,
            "hidden_channels_dec": hidden_channels_dec,
            "use_encoder_prenet": use_encoder_prenet,
        }
        for attr_name, attr_value in mirrored.items():
            setattr(self, attr_name, attr_value)

        # inference-time knobs, not exposed through the constructor
        self.noise_scale = 0.66
        self.length_scale = 1.

        encoder_options = {
            "out_channels": out_channels,
            "hidden_channels": hidden_channels,
            "filter_channels": filter_channels,
            "filter_channels_dp": filter_channels_dp,
            "encoder_type": encoder_type,
            "num_heads": num_heads,
            "num_layers": num_layers_enc,
            "kernel_size": kernel_size,
            "dropout_p": dropout_p,
            "mean_only": mean_only,
            "use_prenet": use_encoder_prenet,
            "c_in_channels": c_in_channels,
        }
        self.encoder = Encoder(num_chars, **encoder_options)

        decoder_width = hidden_channels_dec or hidden_channels
        self.decoder = Decoder(out_channels,
                               decoder_width,
                               kernel_size_dec,
                               dilation_rate,
                               num_flow_blocks_dec,
                               num_block_layers,
                               dropout_p=dropout_p_dec,
                               num_splits=num_splits,
                               num_sqz=num_sqz,
                               sigmoid_scale=sigmoid_scale,
                               c_in_channels=c_in_channels)

        if num_speakers > 1:
            # speaker lookup table with a small uniform initialization
            speaker_table = nn.Embedding(num_speakers, c_in_channels)
            nn.init.uniform_(speaker_table.weight, -0.1, 0.1)
            self.emb_g = speaker_table
Exemplo n.º 7
0
class GlowTts(nn.Module):
    """Glow TTS models from https://arxiv.org/abs/2005.11129"""
    def __init__(self,
                 num_chars,
                 hidden_channels,
                 filter_channels,
                 filter_channels_dp,
                 out_channels,
                 kernel_size=3,
                 num_heads=2,
                 num_layers_enc=6,
                 dropout_p=0.1,
                 num_flow_blocks_dec=12,
                 kernel_size_dec=5,
                 dilation_rate=5,
                 num_block_layers=4,
                 dropout_p_dec=0.,
                 num_speakers=0,
                 c_in_channels=0,
                 num_splits=4,
                 num_sqz=1,
                 sigmoid_scale=False,
                 rel_attn_window_size=None,
                 input_length=None,
                 mean_only=False,
                 hidden_channels_enc=None,
                 hidden_channels_dec=None,
                 use_encoder_prenet=False,
                 encoder_type="transformer"):
        """Set up the text encoder, the flow decoder and the optional speaker table.

        All arguments but `encoder_type` are kept as same-named attributes;
        `noise_scale` (0.66) and `length_scale` (1.0) are fixed initial values.

        Notes:
            - `hidden_channels_enc` is stored only; the encoder receives
              `hidden_channels`.
            - The decoder width is `hidden_channels_dec or hidden_channels`.
            - `rel_attn_window_size` and `input_length` are stored only.
            - `emb_g` is created only for `num_speakers > 1`, with embedding
              width `c_in_channels`.
        """
        super().__init__()

        # keep the hyper-parameters accessible as attributes
        hyper_params = {
            "num_chars": num_chars,
            "hidden_channels": hidden_channels,
            "filter_channels": filter_channels,
            "filter_channels_dp": filter_channels_dp,
            "out_channels": out_channels,
            "kernel_size": kernel_size,
            "num_heads": num_heads,
            "num_layers_enc": num_layers_enc,
            "dropout_p": dropout_p,
            "num_flow_blocks_dec": num_flow_blocks_dec,
            "kernel_size_dec": kernel_size_dec,
            "dilation_rate": dilation_rate,
            "num_block_layers": num_block_layers,
            "dropout_p_dec": dropout_p_dec,
            "num_speakers": num_speakers,
            "c_in_channels": c_in_channels,
            "num_splits": num_splits,
            "num_sqz": num_sqz,
            "sigmoid_scale": sigmoid_scale,
            "rel_attn_window_size": rel_attn_window_size,
            "input_length": input_length,
            "mean_only": mean_only,
            "hidden_channels_enc": hidden_channels_enc,
            "hidden_channels_dec": hidden_channels_dec,
            "use_encoder_prenet": use_encoder_prenet,
        }
        for key, value in hyper_params.items():
            setattr(self, key, value)

        # sampling noise and duration scaling used by `inference`
        self.noise_scale = 0.66
        self.length_scale = 1.

        encoder_kwargs = {
            "out_channels": out_channels,
            "hidden_channels": hidden_channels,
            "filter_channels": filter_channels,
            "filter_channels_dp": filter_channels_dp,
            "encoder_type": encoder_type,
            "num_heads": num_heads,
            "num_layers": num_layers_enc,
            "kernel_size": kernel_size,
            "dropout_p": dropout_p,
            "mean_only": mean_only,
            "use_prenet": use_encoder_prenet,
            "c_in_channels": c_in_channels,
        }
        self.encoder = Encoder(num_chars, **encoder_kwargs)

        decoder_hidden = hidden_channels_dec or hidden_channels
        self.decoder = Decoder(out_channels,
                               decoder_hidden,
                               kernel_size_dec,
                               dilation_rate,
                               num_flow_blocks_dec,
                               num_block_layers,
                               dropout_p=dropout_p_dec,
                               num_splits=num_splits,
                               num_sqz=num_sqz,
                               sigmoid_scale=sigmoid_scale,
                               c_in_channels=c_in_channels)

        if num_speakers > 1:
            # multi-speaker: learnable embedding, uniformly initialized in [-0.1, 0.1]
            embedding = nn.Embedding(num_speakers, c_in_channels)
            nn.init.uniform_(embedding.weight, -0.1, 0.1)
            self.emb_g = embedding

    @staticmethod
    def compute_outputs(attn, o_mean, o_log_scale, x_mask):
        # compute final values with the computed alignment
        y_mean = torch.matmul(
            attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
                1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        y_log_scale = torch.matmul(
            attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(
                1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        # compute total duration with adjustment
        o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
        return y_mean, y_log_scale, o_attn_dur

    def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
        """Training pass: encode the text, flow the frames to latents and align them.

        The incoming `attn` value is not used; it is replaced by the path found
        with `maximum_path`. When `g` is given it is looked up in `emb_g`.

        Shapes:
            x: B x T
            x_lenghts: B
            y: B x C x T
            y_lengths: B

        Returns:
            tuple: ``(z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur)``.
        """
        y_max_length = y.size(2)

        # speaker conditioning: normalized embedding with a trailing singleton axis
        if g is not None:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)

        # text encoder
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)

        # crop frames / lengths to a multiple of the squeeze factor
        y, y_lengths, y_max_length, attn = self.preprocess(
            y, y_lengths, y_max_length, None)

        # masks for the output axis and for the (input, output) grid
        y_mask = sequence_mask(y_lengths,
                               y_max_length).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

        # frames -> latents
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)

        # Gaussian log-likelihood of every latent frame under every input
        # position, split into four terms; no gradient flows through the search
        with torch.no_grad():
            o_scale = torch.exp(-2 * o_log_scale)
            const_term = torch.sum(
                -0.5 * math.log(2 * math.pi) - o_log_scale,
                [1]).unsqueeze(-1)  # [b, t, 1]
            # [b, t, d] x [b, d, t'] = [b, t, t']
            square_term = torch.matmul(o_scale.transpose(1, 2),
                                       -0.5 * (z**2))
            # [b, t, d] x [b, d, t'] = [b, t, t']
            cross_term = torch.matmul((o_mean * o_scale).transpose(1, 2), z)
            mean_term = torch.sum(-0.5 * (o_mean**2) * o_scale,
                                  [1]).unsqueeze(-1)  # [b, t, 1]
            logp = const_term + square_term + cross_term + mean_term  # [b, t, t']
            best_path = maximum_path(logp, attn_mask.squeeze(1))
            attn = best_path.unsqueeze(1).detach()

        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)
        attn = attn.squeeze(1).permute(0, 2, 1)
        return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur

    @torch.no_grad()
    def inference(self, x, x_lengths, g=None):
        """Synthesize frames from text using the predicted durations.

        Durations are ``exp(o_dur_log) - 1``, masked, scaled by
        `self.length_scale` and rounded up. The latent is sampled around the
        aligned mean with noise scaled by `self.noise_scale`, then decoded in
        reverse.

        Returns:
            tuple: ``(y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur)``.
        """
        if g is not None:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)

        # text encoder
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)

        # frames per input position, rounded up
        durations = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
        durations_ceil = torch.ceil(durations)
        # at least one output frame per item
        y_lengths = torch.clamp_min(torch.sum(durations_ceil, [1, 2]), 1).long()

        # masks; the maximum length is left for `sequence_mask` to determine
        y_mask = sequence_mask(y_lengths, None).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

        # expand the durations into an alignment path
        attn = generate_path(durations_ceil.squeeze(1),
                             attn_mask.squeeze(1)).unsqueeze(1)
        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
            attn, o_mean, o_log_scale, x_mask)

        # sample the latent and run the decoder in reverse
        noise = torch.randn_like(y_mean)
        z = (y_mean + torch.exp(y_log_scale) * noise * self.noise_scale) * y_mask
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)

        attn = attn.squeeze(1).permute(0, 2, 1)
        return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur

    def preprocess(self, y, y_lengths, y_max_length, attn=None):
        if y_max_length is not None:
            y_max_length = (y_max_length // self.num_sqz) * self.num_sqz
            y = y[:, :, :y_max_length]
            if attn is not None:
                attn = attn[:, :, :, :y_max_length]
        y_lengths = (y_lengths // self.num_sqz) * self.num_sqz
        return y, y_lengths, y_max_length, attn

    def store_inverse(self):
        """Forward the call to `self.decoder.store_inverse()`."""
        self.decoder.store_inverse()