Example #1
class FeedForwardTransformer(torch.nn.Module):
    """Feed Forward Transformer for TTS a.k.a. FastSpeech.
    This is a module of FastSpeech, feed-forward Transformer with duration predictor described in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_, which does not require any auto-regressive
    processing during inference, resulting in fast decoding compared with auto-regressive Transformer.
    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf
    """
    def __init__(self, idim: int, odim: int, hp: Dict):
        """Initialize feed-forward Transformer module.
        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
        """
        # initialize base classes
        assert check_argument_types()
        torch.nn.Module.__init__(self)

        # store hyperparameters
        self.idim = idim
        self.odim = odim

        self.use_scaled_pos_enc = hp.model.use_scaled_pos_enc
        self.use_masking = hp.model.use_masking

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=hp.model.adim,
                                                 padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=hp.model.adim,
            attention_heads=hp.model.aheads,
            linear_units=hp.model.eunits,
            num_blocks=hp.model.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=0.2,
            positional_dropout_rate=0.2,
            attention_dropout_rate=0.2,
            pos_enc_class=pos_enc_class,
            normalize_before=hp.model.encoder_normalize_before,
            concat_after=hp.model.encoder_concat_after,
            positionwise_layer_type=hp.model.positionwise_layer_type,
            positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size,
        )

        self.duration_predictor = DurationPredictor(
            idim=hp.model.adim,
            n_layers=hp.model.duration_predictor_layers,
            n_chans=hp.model.duration_predictor_chans,
            kernel_size=hp.model.duration_predictor_kernel_size,
            dropout_rate=hp.model.duration_predictor_dropout_rate,
        )

        self.energy_predictor = EnergyPredictor(
            idim=hp.model.adim,
            n_layers=hp.model.duration_predictor_layers,
            n_chans=hp.model.duration_predictor_chans,
            kernel_size=hp.model.duration_predictor_kernel_size,
            dropout_rate=hp.model.duration_predictor_dropout_rate,
            min=hp.data.e_min,
            max=hp.data.e_max,
        )
        self.energy_embed = torch.nn.Linear(hp.model.adim, hp.model.adim)

        self.pitch_predictor = PitchPredictor(
            idim=hp.model.adim,
            n_layers=hp.model.duration_predictor_layers,
            n_chans=hp.model.duration_predictor_chans,
            kernel_size=hp.model.duration_predictor_kernel_size,
            dropout_rate=hp.model.duration_predictor_dropout_rate,
            min=hp.data.p_min,
            max=hp.data.p_max,
        )
        self.pitch_embed = torch.nn.Linear(hp.model.adim, hp.model.adim)

        # define length regulator
        self.length_regulator = LengthRegulator()

        # ----- AdaSpeech-specific modules -----

        self.utterance_encoder = UtteranceEncoder(idim=hp.audio.n_mels)

        self.phoneme_level_encoder = PhonemeLevelEncoder(idim=hp.audio.n_mels)

        self.phoneme_level_predictor = PhonemeLevelPredictor(
            idim=hp.model.adim)

        self.phone_level_embed = torch.nn.Linear(hp.model.phn_latent_dim,
                                                 hp.model.adim)

        self.acoustic_criterion = AcousticPredictorLoss()

        # define decoder
        # NOTE: we reuse the Encoder class as the decoder because FastSpeech's decoder has the same architecture as its encoder
        self.decoder = Encoder(
            idim=hp.model.adim,
            attention_dim=hp.model.ddim,
            attention_heads=hp.model.aheads,
            linear_units=hp.model.dunits,
            num_blocks=hp.model.dlayers,
            input_layer="linear",
            dropout_rate=0.2,
            positional_dropout_rate=0.2,
            attention_dropout_rate=0.2,
            pos_enc_class=pos_enc_class,
            normalize_before=hp.model.decoder_normalize_before,
            concat_after=hp.model.decoder_concat_after,
            positionwise_layer_type=hp.model.positionwise_layer_type,
            positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size,
        )

        # define postnet
        self.postnet = (None if hp.model.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=hp.model.postnet_layers,
            n_chans=hp.model.postnet_chans,
            n_filts=hp.model.postnet_filts,
            use_batch_norm=hp.model.use_batch_norm,
            dropout_rate=hp.model.postnet_dropout_rate,
        ))

        # define final projection
        self.feat_out = torch.nn.Linear(hp.model.ddim,
                                        odim * hp.model.reduction_factor)

        # initialize parameters
        self._reset_parameters(
            init_type=hp.model.transformer_init,
            init_enc_alpha=hp.model.initial_encoder_alpha,
            init_dec_alpha=hp.model.initial_decoder_alpha,
        )

        # define criterions
        self.duration_criterion = DurationPredictorLoss()
        self.energy_criterion = EnergyPredictorLoss()
        self.pitch_criterion = PitchPredictorLoss()
        self.criterion = torch.nn.L1Loss(reduction="mean")
        self.use_weighted_masking = hp.model.use_weighted_masking

    def _forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: torch.Tensor = None,
        olens: torch.Tensor = None,
        ds: torch.Tensor = None,
        es: torch.Tensor = None,
        ps: torch.Tensor = None,
        is_inference: bool = False,
        phn_level_predictor: bool = False,
        avg_mel: torch.Tensor = None,
    ) -> Sequence[torch.Tensor]:
        # forward encoder
        x_masks = self._source_mask(
            ilens)  # (B, Tmax, Tmax) -> torch.Size([32, 121, 121])

        hs, _ = self.encoder(
            xs, x_masks)  # (B, Tmax, adim) -> torch.Size([32, 121, 256])

        # AdaSpeech-specific conditioning

        uttr = self.utterance_encoder(ys.transpose(1, 2)).transpose(1, 2)
        hs = hs + uttr.repeat(1, hs.size(1), 1)
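        # the single utterance-level vector (B, 1, adim) has been tiled across all
        # Tmax positions, conditioning every phoneme on the reference utterance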

        phn = None
        ys_phn = None

        if phn_level_predictor:
            if is_inference:
                ys_phn = self.phoneme_level_predictor(hs.transpose(
                    1, 2))  # (B, Tmax, 4)
                hs = hs + self.phone_level_embed(ys_phn)
            else:
                with torch.no_grad():
                    ys_phn = self.phoneme_level_encoder(
                        avg_mel.transpose(1, 2))  # (B, Tmax, 4)

                phn = self.phoneme_level_predictor(hs.transpose(
                    1, 2))  # (B, Tmax, 4)
                hs = hs + self.phone_level_embed(
                    ys_phn.detach())  # (B, Tmax, 256)

        else:
            ys_phn = self.phoneme_level_encoder(avg_mel.transpose(
                1, 2))  # (B, Tmax, 4)
            hs = hs + self.phone_level_embed(ys_phn)  # (B, Tmax, 256)

        # forward duration predictor and length regulator
        d_masks = make_pad_mask(ilens).to(xs.device)

        if is_inference:
            d_outs = self.duration_predictor.inference(hs,
                                                       d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, d_outs, ilens)  # (B, Lmax, adim)
            one_hot_energy = self.energy_predictor.inference(
                hs)  # (B, Lmax, adim)
            one_hot_pitch = self.pitch_predictor.inference(
                hs)  # (B, Lmax, adim)
        else:
            with torch.no_grad():

                one_hot_energy = self.energy_predictor.to_one_hot(
                    es)  # (B, Lmax, adim)   torch.Size([32, 868, 256])

                one_hot_pitch = self.pitch_predictor.to_one_hot(
                    ps)  # (B, Lmax, adim)   torch.Size([32, 868, 256])

            mel_masks = make_pad_mask(olens).to(xs.device)

            d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)

            hs = self.length_regulator(hs, ds, ilens)  # (B, Lmax, adim)

            e_outs = self.energy_predictor(hs, mel_masks)

            p_outs = self.pitch_predictor(hs, mel_masks)

        hs = hs + self.pitch_embed(one_hot_pitch)  # (B, Lmax, adim)
        hs = hs + self.energy_embed(one_hot_energy)  # (B, Lmax, adim)
        # forward decoder
        if olens is not None:
            h_masks = self._source_mask(olens)
        else:
            h_masks = None

        zs, _ = self.decoder(hs, h_masks)  # (B, Lmax, adim)

        before_outs = self.feat_out(zs).view(zs.size(0), -1,
                                             self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(before_outs.transpose(
                1, 2)).transpose(1, 2)

        if is_inference:
            return before_outs, after_outs, d_outs, one_hot_energy, one_hot_pitch
        else:
            return before_outs, after_outs, d_outs, e_outs, p_outs, phn, ys_phn

    def forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: torch.Tensor,
        olens: torch.Tensor,
        ds: torch.Tensor,
        es: torch.Tensor,
        ps: torch.Tensor,
        avg_mel: torch.Tensor = None,
        phn_level_predictor: bool = False
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Calculate forward propagation.
        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
        Returns:
            Tensor: Loss value.
        """
        # remove unnecessary padded part (for multi-gpus)
        xs = xs[:, :max(ilens)]  # torch.Size([32, 121]) -> [B, Tmax]
        ys = ys[:, :max(olens)]  # torch.Size([32, 868, 80]) -> [B, Lmax, odim]

        # forward propagation
        before_outs, after_outs, d_outs, e_outs, p_outs, phn, ys_phn = self._forward(
            xs,
            ilens,
            ys,
            olens,
            ds,
            es,
            ps,
            is_inference=False,
            avg_mel=avg_mel,
            phn_level_predictor=phn_level_predictor)

        # apply mask to remove padded part
        if self.use_masking:
            in_masks = make_non_pad_mask(ilens).to(xs.device)
            d_outs = d_outs.masked_select(in_masks)
            ds = ds.masked_select(in_masks)
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            mel_masks = make_non_pad_mask(olens).to(ys.device)
            before_outs = before_outs.masked_select(out_masks)
            es = es.masked_select(mel_masks)
            ps = ps.masked_select(mel_masks)
            e_outs = e_outs.masked_select(mel_masks)
            p_outs = p_outs.masked_select(mel_masks)
            after_outs = (after_outs.masked_select(out_masks)
                          if after_outs is not None else None)
            ys = ys.masked_select(out_masks)
            if phn is not None and ys_phn is not None:
                phn = phn.masked_select(in_masks.unsqueeze(-1))
                ys_phn = ys_phn.masked_select(in_masks.unsqueeze(-1))

        acoustic_loss = 0

        if phn_level_predictor:
            acoustic_loss = self.acoustic_criterion(ys_phn, phn)

        # calculate loss
        before_loss = self.criterion(before_outs, ys)
        after_loss = 0
        l1_loss = before_loss
        if after_outs is not None:
            after_loss = self.criterion(after_outs, ys)
            l1_loss = before_loss + after_loss
        duration_loss = self.duration_criterion(d_outs, ds)
        energy_loss = self.energy_criterion(e_outs, es)
        pitch_loss = self.pitch_criterion(p_outs, ps)

        # make weighted mask and apply it
        if self.use_weighted_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
            out_weights = out_masks.float() / out_masks.sum(
                dim=1, keepdim=True).float()
            out_weights /= ys.size(0) * ys.size(2)
            duration_masks = make_non_pad_mask(ilens).to(ys.device)
            duration_weights = (
                duration_masks.float() /
                duration_masks.sum(dim=1, keepdim=True).float())
            duration_weights /= ds.size(0)

            # apply weight
            l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
            duration_loss = (duration_loss.mul(duration_weights).masked_select(
                duration_masks).sum())

        loss = l1_loss + duration_loss + energy_loss + pitch_loss + acoustic_loss
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"before_loss": before_loss.item()},
            {"after_loss": after_loss.item() if after_outs is not None else 0.0},
            {"duration_loss": duration_loss.item()},
            {"energy_loss": energy_loss.item()},
            {"pitch_loss": pitch_loss.item()},
            {"acoustic_loss": acoustic_loss.item() if phn_level_predictor else 0.0},
            {"loss": loss.item()},
        ]

        # self.reporter.report(report_keys)

        return loss, report_keys

    def inference(self,
                  x: torch.Tensor,
                  ref_mel: torch.Tensor = None,
                  avg_mel: torch.Tensor = None,
                  phn_level_predictor: bool = True) -> torch.Tensor:
        """Generate the sequence of features given the sequences of characters.
        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace): Dummy for compatibility.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
        Returns:
            Tensor: Output sequence of features (1, L, odim).
            None: Dummy for compatibility.
            None: Dummy for compatibility.
        """
        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs = x.unsqueeze(0)

        if ref_mel is not None:
            ref_mel = ref_mel.unsqueeze(0)
        if avg_mel is not None:
            avg_mel = avg_mel.unsqueeze(0)

        # inference (_forward returns five values when is_inference=True)
        _, outs, _, _, _ = self._forward(
            xs,
            ilens=ilens,
            ys=ref_mel,
            avg_mel=avg_mel,
            is_inference=True,
            phn_level_predictor=phn_level_predictor)  # (1, L, odim)
        return outs[0]

    def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
        """Make masks for self-attention.
        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)
        """
        x_masks = make_non_pad_mask(ilens).to(
            device=next(self.parameters()).device)
        return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)

    def _reset_parameters(self,
                          init_type: str,
                          init_enc_alpha: float = 1.0,
                          init_dec_alpha: float = 1.0):
        # initialize parameters
        initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        if self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)
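
For reference, a minimal usage sketch (not part of the original project). The hyperparameter values below are placeholders inferred from the hp.model.* / hp.data.* / hp.audio.* attributes the constructor reads, the attribute-style config container is an assumption (the real project presumably loads a config file), and the tensor shapes follow my reading of inference.

    import torch
    from omegaconf import OmegaConf  # assumption: any attribute-access config container works

    # placeholder hyperparameters mirroring every attribute __init__ reads
    hp = OmegaConf.create({
        "model": {
            "adim": 256, "aheads": 2, "eunits": 1024, "elayers": 4,
            "ddim": 256, "dunits": 1024, "dlayers": 4,
            "use_scaled_pos_enc": True, "use_masking": True,
            "use_weighted_masking": False,
            "encoder_normalize_before": False, "encoder_concat_after": False,
            "decoder_normalize_before": False, "decoder_concat_after": False,
            "positionwise_layer_type": "conv1d", "positionwise_conv_kernel_size": 3,
            "duration_predictor_layers": 2, "duration_predictor_chans": 256,
            "duration_predictor_kernel_size": 3, "duration_predictor_dropout_rate": 0.1,
            "phn_latent_dim": 4, "reduction_factor": 1,
            "postnet_layers": 5, "postnet_chans": 256, "postnet_filts": 5,
            "use_batch_norm": True, "postnet_dropout_rate": 0.5,
            "transformer_init": "pytorch",
            "initial_encoder_alpha": 1.0, "initial_decoder_alpha": 1.0,
        },
        "data": {"e_min": 0.0, "e_max": 400.0, "p_min": 80.0, "p_max": 400.0},
        "audio": {"n_mels": 80},
    })

    model = FeedForwardTransformer(idim=75, odim=80, hp=hp)
    model.eval()

    x = torch.randint(1, 75, (50,))   # (T,) phoneme ids; index 0 is reserved for padding
    ref_mel = torch.randn(120, 80)    # reference mel for the utterance encoder
    avg_mel = torch.randn(50, 80)     # phoneme-averaged mel, one frame per phoneme
    with torch.no_grad():
        mel = model.inference(x, ref_mel=ref_mel, avg_mel=avg_mel)  # (L, odim)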
Example #2
class FeedForwardTransformer(torch.nn.Module):
    def __init__(self, idim: int, odim: int, hp: Dict):
        """Initialize feed-forward Transformer module.
        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
        """
        # initialize base classes
        assert check_argument_types()
        torch.nn.Module.__init__(self)

        # store hyperparameters
        self.idim = idim
        self.odim = odim

        self.use_scaled_pos_enc = hp.model.use_scaled_pos_enc
        self.use_masking = hp.model.use_masking

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=hp.model.adim,
                                                 padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=hp.model.adim,
            attention_heads=hp.model.aheads,
            linear_units=hp.model.eunits,
            num_blocks=hp.model.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=0.2,
            positional_dropout_rate=0.2,
            attention_dropout_rate=0.2,
            pos_enc_class=pos_enc_class,
            normalize_before=hp.model.encoder_normalize_before,
            concat_after=hp.model.encoder_concat_after,
            positionwise_layer_type=hp.model.positionwise_layer_type,
            positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size,
        )

        self.duration_predictor = DurationPredictor(
            idim=hp.model.adim,
            n_layers=hp.model.duration_predictor_layers,
            n_chans=hp.model.duration_predictor_chans,
            kernel_size=hp.model.duration_predictor_kernel_size,
            dropout_rate=hp.model.duration_predictor_dropout_rate,
        )

        self.energy_predictor = EnergyPredictor(
            idim=hp.model.adim,
            n_layers=hp.model.duration_predictor_layers,
            n_chans=hp.model.duration_predictor_chans,
            kernel_size=hp.model.duration_predictor_kernel_size,
            dropout_rate=hp.model.duration_predictor_dropout_rate,
            min=hp.data.e_min,
            max=hp.data.e_max,
        )
        self.energy_embed = torch.nn.Linear(hp.model.adim, hp.model.adim)

        self.pitch_predictor = PitchPredictor(
            idim=hp.model.adim,
            n_layers=hp.model.duration_predictor_layers,
            n_chans=hp.model.duration_predictor_chans,
            kernel_size=hp.model.duration_predictor_kernel_size,
            dropout_rate=hp.model.duration_predictor_dropout_rate,
            min=hp.data.p_min,
            max=hp.data.p_max,
        )
        self.pitch_embed = torch.nn.Linear(hp.model.adim, hp.model.adim)

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we reuse the Encoder class as the decoder because FastSpeech's decoder has the same architecture as its encoder
        self.decoder = Encoder(
            idim=hp.model.adim,
            attention_dim=hp.model.adim,
            attention_heads=hp.model.aheads,
            linear_units=hp.model.dunits,
            num_blocks=hp.model.dlayers,
            input_layer=None,
            dropout_rate=0.2,
            positional_dropout_rate=0.2,
            attention_dropout_rate=0.2,
            pos_enc_class=pos_enc_class,
            normalize_before=hp.model.decoder_normalize_before,
            concat_after=hp.model.decoder_concat_after,
            positionwise_layer_type=hp.model.positionwise_layer_type,
            positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size,
        )

        # define postnet
        self.postnet = (None if hp.model.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=hp.model.postnet_layers,
            n_chans=hp.model.postnet_chans,
            n_filts=hp.model.postnet_filts,
            use_batch_norm=hp.model.use_batch_norm,
            dropout_rate=hp.model.postnet_dropout_rate,
        ))

        # define final projection
        self.feat_out = torch.nn.Linear(hp.model.adim,
                                        odim * hp.model.reduction_factor)

        # initialize parameters
        self._reset_parameters(
            init_type=hp.model.transformer_init,
            init_enc_alpha=hp.model.initial_encoder_alpha,
            init_dec_alpha=hp.model.initial_decoder_alpha,
        )

        # define criterions
        self.duration_criterion = DurationPredictorLoss()
        self.energy_criterion = EnergyPredictorLoss()
        self.pitch_criterion = PitchPredictorLoss()
        self.criterion = torch.nn.L1Loss(reduction="mean")
        self.use_weighted_masking = hp.model.use_weighted_masking

    def _forward(self, xs: torch.Tensor, ilens: torch.Tensor):
        # forward encoder
        x_masks = self._source_mask(
            ilens)  # (B, Tmax, Tmax) -> torch.Size([32, 121, 121])

        hs, _ = self.encoder(
            xs, x_masks)  # (B, Tmax, adim) -> torch.Size([32, 121, 256])
        # print("ys :", ys.shape)

        # forward duration predictor and length regulator
        d_masks = make_pad_mask_script(ilens).to(xs.device)

        d_outs = self.duration_predictor.inference(hs, d_masks)  # (B, Tmax)
        hs = self.length_regulator(hs, d_outs, ilens)  # (B, Lmax, adim)
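        # hs is now frame-level: each phoneme vector has been repeated according
        # to its predicted duration by the length regulator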

        one_hot_energy = self.energy_predictor.inference(hs)  # (B, Lmax, adim)

        one_hot_pitch = self.pitch_predictor.inference(hs)  # (B, Lmax, adim)

        hs = hs + self.pitch_embed(one_hot_pitch)  # (B, Lmax, adim)
        hs = hs + self.energy_embed(one_hot_energy)  # (B, Lmax, adim)

        # forward decoder
        # NOTE: no target lengths are available at inference time, so the decoder
        # runs without a padding mask (olens could be recovered from the length
        # regulator if a mask were needed)

        zs, _ = self.decoder(hs, None)  # (B, Lmax, adim)

        before_outs = self.feat_out(zs).view(zs.size(0), -1,
                                             self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose(1, 2)).transpose(1, 2)
        return after_outs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Generate the sequence of features given the sequences of characters.
        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace): Dummy for compatibility.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
        Returns:
            Tensor: Output sequence of features (1, L, odim).
            None: Dummy for compatibility.
            None: Dummy for compatibility.
        """
        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs = x.unsqueeze(0)

        # inference
        outs = self._forward(xs, ilens)  # (L, odim)

        return outs[0]

    def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
        """Make masks for self-attention.
        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1]],
                    [[1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=torch.uint8)
        """
        x_masks = make_non_pad_mask_script(ilens)
        return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)

    def _reset_parameters(self,
                          init_type: str,
                          init_enc_alpha: float = 1.0,
                          init_dec_alpha: float = 1.0):
        # initialize parameters
        initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        if self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)
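
Example #2 reduces the model to a tensor-only forward with no optional arguments, swaps in the scripted mask helpers (make_pad_mask_script, make_non_pad_mask_script) and runs the decoder without a mask, which suggests it was written for TorchScript export. A hedged sketch of that export path, assuming the same placeholder hp as the sketch above and placeholder file names:

    import torch

    model = FeedForwardTransformer(idim=75, odim=80, hp=hp)
    model.load_state_dict(torch.load("checkpoint.pth", map_location="cpu"))  # placeholder path
    model.eval()

    # scripting succeeds only if every submodule (Encoder, LengthRegulator, the
    # predictors, ...) is itself TorchScript-compatible
    scripted = torch.jit.script(model)
    scripted.save("fastspeech_scripted.pt")

    mel = scripted(torch.randint(1, 75, (50,)))  # forward(x) -> (L, odim)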