Example #1
    def forward(self, padded_input, input_lengths):
        """
        Args:
            padded_input: N x T x D
            input_lengths: N

        Returns:
            enc_output: N x T x H
        """
        # Prepare masks
        non_pad_mask = sequence_mask(input_lengths).unsqueeze(-1)
        length = padded_input.size(1)
        slf_attn_mask = get_attn_pad_mask(input_lengths, length)

        # Forward
        enc_output = self.dropout(
            self.layer_norm_in(self.token_emb(padded_input)) +
            self.positional_encoding(padded_input))

        for enc_layer in self.layer_stack:
            enc_output = enc_layer(
                enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)

        return enc_output
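All of the examples on this page call a sequence_mask helper (and example #1 also get_attn_pad_mask), but neither is shown. Below is a minimal sketch consistent with how the snippets call them; the argument names follow example #3's keyword usage, and the actual helpers in each source repository may differ.

import torch

def sequence_mask(sequence_length, maxlen=None):
    # True at real positions, False at padding: N x T (bool).
    if maxlen is None:
        maxlen = int(sequence_length.max())
    steps = torch.arange(maxlen, device=sequence_length.device)
    return steps.unsqueeze(0) < sequence_length.unsqueeze(1)

def get_attn_pad_mask(sequence_length, maxlen):
    # N x T x T self-attention mask, True where a key position is
    # padding, broadcast over all query positions.
    pad = ~sequence_mask(sequence_length, maxlen)
    return pad.unsqueeze(1).expand(-1, maxlen, -1)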
Example #2
    def forward(self,
                encoder_padded_outputs,
                encoder_input_lengths,
                masking=True):
        """
        Args:
            encoder_padded_outputs: N x Ti x H
            encoder_input_lengths: N

        Returns:
            logits: N x Ti x V (zeroed at padded steps when masking=True)
            len_logits: N
        """
        # Project encoder outputs to vocabulary logits and zero out padding

        # Prepare masks
        mask = sequence_mask(encoder_input_lengths)  # B x T

        logits = self.tgt_word_prj(encoder_padded_outputs)
        if masking:
            # B x T x V
            mask = mask.view(mask.shape[0], mask.shape[1],
                             1).repeat(1, 1, self.dim_output)
            logits *= mask

        len_logits = encoder_input_lengths

        return logits, len_logits
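When masking=True, the snippet zeroes the logits at padded time steps by expanding the B x T mask to B x T x V. A self-contained illustration of that view/repeat step with made-up shapes:

import torch

N, T, V = 2, 5, 7
logits = torch.randn(N, T, V)
lengths = torch.tensor([5, 3])

# N x T bool mask, as sequence_mask(lengths) would return.
mask = torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)
mask = mask.view(N, T, 1).repeat(1, 1, V)  # N x T x V, like the snippet
logits = logits * mask                     # padded time steps become 0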
Example #3
    def forward(self, x, target, length):
        target.requires_grad = False
        mask = sequence_mask(sequence_length=length,
                             maxlen=target.size(1)).unsqueeze(2).float()
        mask = mask.expand_as(x)
        loss = F.l1_loss(x * mask, target * mask, reduction='sum')
        loss = loss / mask.sum()
        return loss
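This is an L1 loss averaged over valid elements only: absolute errors are summed where the mask is 1 and divided by the number of unmasked entries. A hypothetical self-contained version of the same logic with a sample call (the original class name is not shown in the snippet):

import torch
import torch.nn.functional as F

class MaskedL1Loss(torch.nn.Module):  # hypothetical name
    def forward(self, x, target, length):
        target.requires_grad = False
        maxlen = target.size(1)
        mask = torch.arange(maxlen).unsqueeze(0) < length.unsqueeze(1)
        mask = mask.unsqueeze(2).float().expand_as(x)
        loss = F.l1_loss(x * mask, target * mask, reduction='sum')
        return loss / mask.sum()

x = torch.randn(4, 100, 80, requires_grad=True)  # predictions: B x T x D
target = torch.randn(4, 100, 80)                 # references:  B x T x D
length = torch.tensor([100, 90, 75, 60])

loss = MaskedL1Loss()(x, target, length)
loss.backward()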
Example #4
    def forward(self, characters, text_lengths, mel_specs):
        B = characters.size(0)
        mask = sequence_mask(text_lengths).to(characters.device)

        inputs = self.embedding(characters)
        encoder_outputs = self.encoder(inputs)
        mel_outputs, alignments, stop_tokens = self.decoder(
            encoder_outputs, mel_specs, mask)
        # Multiple frames are grouped per decoder step, so restore the original frame layout
        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
        # Convert the mel spectrogram to a linear spectrogram
        linear_outputs = self.postnet(mel_outputs)
        linear_outputs = self.last_linear(linear_outputs)
        return mel_outputs, linear_outputs, alignments, stop_tokens
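The view(B, -1, self.mel_dim) call undoes the decoder's reduction factor r: each decoder step emits r stacked mel frames, so every output row holds r * mel_dim values. A hypothetical shape check assuming r = 2 and an 80-bin mel:

import torch

B, steps, r, mel_dim = 4, 50, 2, 80
fused = torch.randn(B, steps, r * mel_dim)  # r frames fused per decoder step
frames = fused.view(B, -1, mel_dim)         # B x (steps * r) x mel_dim
assert frames.shape == (B, steps * r, mel_dim)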
Example #5
    def forward(self, padded_input, input_lengths):
        """
        Args:
            padded_input: N x T x D
            input_lengths: N

        Returns:
            alphas: N x T (padded frames zeroed)
        """
        x, input_lengths = self.conv(padded_input, input_lengths)
        x = self.dropout(x)
        alphas = self.linear(x).squeeze(-1)
        alphas = torch.sigmoid(alphas)
        pad_mask = sequence_mask(input_lengths)

        return alphas * pad_mask
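Multiplying the sigmoid outputs by the padding mask forces the per-frame weights to zero on padded frames; this pattern matches CIF-style (continuous integrate-and-fire) accumulation weights, though that reading is an assumption. A small standalone check:

import torch

alphas = torch.sigmoid(torch.randn(2, 6))
# Bool padding mask, as sequence_mask(lengths) would return.
pad_mask = torch.arange(6).unsqueeze(0) < torch.tensor([6, 4]).unsqueeze(1)
weights = alphas * pad_mask         # the bool mask is promoted to float
assert (weights[1, 4:] == 0).all()  # frames past length 4 contribute nothing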
Example #6
    def forward(self, encoder_padded_outputs, encoder_input_lengths):
        """
        Args:
            encoder_padded_outputs: N x Ti x H
            encoder_input_lengths: N

        Returns:
            logits: N x Ti x V (zeroed at padded steps)
        """
        # Project encoder outputs to vocabulary logits and zero out padding

        # Prepare masks
        mask = sequence_mask(encoder_input_lengths)  # B x T

        # B x T x V
        mask = mask.view(mask.shape[0], mask.shape[1],
                         1).repeat(1, 1, self.dim_output)

        # before softmax
        logits = self.tgt_word_prj(encoder_padded_outputs) * mask

        return logits
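Example #6 is example #2 with the masking made unconditional and the lengths dropped from the return value. The explicit view/repeat can also be replaced by broadcasting, which avoids materializing the full B x T x V mask; a sketch of the equivalent step, not the original code:

import torch

B, T, V = 2, 5, 7
logits = torch.randn(B, T, V)
lengths = torch.tensor([5, 3])

mask = torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)  # B x T
logits = logits * mask.unsqueeze(-1)  # B x T x 1 broadcasts over V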