Example No. 1
    def forward(self, cond_inp, output_lengths, cond_lens=None):  # [B, seq_len, dim], int, [B]
        batch_size, enc_T, enc_dim = cond_inp.shape

        # get Random Position Offset (this *might* allow better distance generalisation)
        #trandint = torch.randint(10000, (1,), device=cond_inp.device, dtype=cond_inp.dtype)

        # get Query from Positional Encoding
        dec_T_max = output_lengths.max().item()
        dec_pos_emb = torch.arange(0, dec_T_max, device=cond_inp.device, dtype=cond_inp.dtype)  # + trandint
        if hasattr(self, 'pos_embedding_q'):
            dec_pos_emb = self.pos_embedding_q(dec_pos_emb.clamp(0, self.pos_embedding_q_max - 1).long())[None, ...].repeat(cond_inp.size(0), 1, 1)  # [B, dec_T, enc_dim]
        elif hasattr(self, 'positional_embedding'):
            dec_pos_emb = self.positional_embedding(dec_pos_emb, bsz=cond_inp.size(0))  # [B, dec_T, enc_dim]
        if not self.merged_pos_enc:
            dec_pos_emb = dec_pos_emb.repeat(1, 1, self.head_num)
        if output_lengths is not None:  # masking for batches
            dec_mask = get_mask_from_lengths(output_lengths).unsqueeze(2)  # [B, dec_T, 1]
            dec_pos_emb = dec_pos_emb * dec_mask  # [B, dec_T, enc_dim] * [B, dec_T, 1] -> [B, dec_T, enc_dim]
        q = dec_pos_emb  # [B, dec_T, enc_dim]

        # get Key/Value from Encoder Outputs
        k = v = cond_inp  # [B, enc_T, enc_dim]
        # (optional) add position encoding to Encoder outputs
        if hasattr(self, 'enc_positional_embedding'):
            enc_pos_emb = torch.arange(0, enc_T, device=cond_inp.device, dtype=cond_inp.dtype)  # + trandint
            if hasattr(self, 'pos_embedding_kv'):
                enc_pos_emb = self.pos_embedding_kv(enc_pos_emb.clamp(0, self.pos_embedding_kv_max - 1).long())[None, ...].repeat(cond_inp.size(0), 1, 1)  # [B, enc_T, enc_dim]
            elif hasattr(self, 'enc_positional_embedding'):
                enc_pos_emb = self.enc_positional_embedding(enc_pos_emb, bsz=cond_inp.size(0))  # [B, enc_T, enc_dim]
            if self.pos_enc_k:
                k = k + enc_pos_emb
            if self.pos_enc_v:
                v = v + enc_pos_emb

        q = q.transpose(0, 1)  # [B, dec_T, enc_dim] -> [dec_T, B, enc_dim]
        k = k.transpose(0, 1)  # [B, enc_T, enc_dim] -> [enc_T, B, enc_dim]
        v = v.transpose(0, 1)  # [B, enc_T, enc_dim] -> [enc_T, B, enc_dim]

        # MH_Transformer is called as (src, tgt): encoder outputs are the source, positional queries the target
        output = self.MH_Transformer(k, q,
            src_key_padding_mask=~get_mask_from_lengths(cond_lens).bool() if (cond_lens is not None) else None,
            tgt_key_padding_mask=~get_mask_from_lengths(output_lengths).bool(),
            memory_key_padding_mask=~get_mask_from_lengths(cond_lens).bool() if (cond_lens is not None) else None)  # [dec_T, B, enc_dim]

        output = output.transpose(0, 1)  # [dec_T, B, enc_dim] -> [B, dec_T, enc_dim]
        output = output + self.o_residual_weights * dec_pos_emb
        # the Transformer call above does not return attention weights, so the 3D length mask is returned in their place
        attention_scores = get_mask_3d(output_lengths, cond_lens) if (cond_lens is not None) else None  # [B, dec_T, enc_T]

        if output_lengths is not None:
            output = output * dec_mask  # [B, dec_T, enc_dim] * [B, dec_T, 1]
        return output, attention_scores
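All five examples call a get_mask_from_lengths helper that is not defined in these snippets. From its usage (a [B, T] BoolTensor that is True on valid positions and False on padding), a minimal sketch of the assumed behaviour looks like the following; the name and signature are taken from the calls above, the body is an assumption:

import torch

def get_mask_from_lengths(lengths, max_len=None):
    # lengths: [B] tensor of valid sequence lengths.
    # Returns a [B, max_len] BoolTensor: True where t < lengths[b], False on padding.
    if max_len is None:
        max_len = int(lengths.max().item())
    ids = torch.arange(max_len, device=lengths.device)  # [max_len]
    return ids[None, :] < lengths[:, None]               # [B, max_len]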
Example No. 2
    def forward(self, model_output, targets):
        mel_target, gate_target, output_lengths, text_lengths, *_ = targets
        mel_out, attention_scores, pred_output_lengths, log_s_sum, logdet_w_sum = model_output
        batch_size, n_mel_channels, frames = mel_target.shape

        output_lengths_float = output_lengths.float()
        mel_out = mel_out.float()
        log_s_sum = log_s_sum.float()
        logdet_w_sum = logdet_w_sum.float()

        # Length Loss
        len_pred_loss = torch.nn.MSELoss()(pred_output_lengths.log(),
                                           output_lengths_float.log())

        # remove paddings before loss calc
        mask = get_mask_from_lengths(
            output_lengths)[:, None, :]  # [B, 1, T] BoolTensor
        mask = mask.expand(mask.size(0), mel_target.size(1),
                           mask.size(2))  # [B, n_mel, T] BoolTensor
        n_elems = (output_lengths_float.sum() * n_mel_channels)

        # Spectrogram Loss
        mel_out = torch.masked_select(mel_out, mask)
        loss_z = ((mel_out.pow(2).sum()) /
                  self.sigma2_2) / n_elems  # mean z (over all elements)

        loss_w = -logdet_w_sum.sum() / (n_mel_channels * frames)

        log_s_sum = log_s_sum.view(batch_size, -1, frames)
        log_s_sum = torch.masked_select(log_s_sum,
                                        mask[:, :log_s_sum.shape[1], :])
        loss_s = -log_s_sum.sum() / (n_elems)

        loss = loss_z + loss_w + loss_s + (len_pred_loss * 0.01)
        assert not torch.isnan(loss).any(), 'loss has NaN values.'

        # (optional) Guided Attention Loss
        if hasattr(self, 'guided_att'):
            att_loss = self.guided_att(attention_scores, text_lengths,
                                       output_lengths)
            loss = loss + att_loss
        else:
            att_loss = None

        # (unfinished) Min-Enc Attention Loss, kept as a commented-out placeholder
        #mask = get_mask_3d(output_lengths, text_lengths)  # [B, dec_T, enc_T]
        #attention_scores.sum(1)

        return loss, len_pred_loss, loss_z, loss_w, loss_s, att_loss
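get_mask_3d is also left undefined. Its callers multiply losses and attention maps by it and, in Example No. 5, apply the ~ operator to it, which suggests a [B, dec_T, enc_T] BoolTensor that is True where both the decoder frame and the encoder step lie inside their respective lengths. A sketch under that assumption, reusing the get_mask_from_lengths sketch above:

def get_mask_3d(dec_lengths, enc_lengths):
    # dec_lengths: [B] decoder (frame) lengths, enc_lengths: [B] encoder (text) lengths.
    # Returns a [B, dec_T, enc_T] BoolTensor, True where both positions are valid.
    dec_mask = get_mask_from_lengths(dec_lengths)  # [B, dec_T]
    enc_mask = get_mask_from_lengths(enc_lengths)  # [B, enc_T]
    return dec_mask[:, :, None] & enc_mask[:, None, :]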
Example No. 3
    def forward(self,
                encoder_outputs,
                encoder_lengths,
                output_lengths,
                cond_lens=None,
                attention_override=None):
        if attention_override is None:
            B, enc_T, enc_dim = encoder_outputs.shape  # [Batch Size, Text Length, Encoder Dimension]
            dec_T = output_lengths.max().item()  # Length of Spectrogram

            #encoder_lengths = encoder_lengths# [B, enc_T]
            #encoder_outputs = encoder_outputs# [B, enc_T, enc_dim]

            start_pos = torch.zeros(B,
                                    device=encoder_outputs.device,
                                    dtype=encoder_outputs.dtype)  # [B]
            attention_pos = torch.arange(dec_T,
                                         device=encoder_outputs.device,
                                         dtype=encoder_outputs.dtype).expand(
                                             B, dec_T)  # [B, dec_T]
            attention = torch.zeros(
                B,
                dec_T,
                enc_T,
                device=encoder_outputs.device,
                dtype=encoder_outputs.dtype)  # [B, dec_T, enc_T]
            for enc_inx in range(encoder_lengths.shape[1]):
                dur = encoder_lengths[:, enc_inx]  # [B]
                end_pos = start_pos + dur  # [B]
                if cond_lens is not None:  # if last char, extend till end of decoder sequence
                    mask = (cond_lens == (enc_inx + 1))  # [B]
                    if mask.any():
                        end_pos.masked_fill_(mask, dec_T)

                att = (attention_pos >= start_pos.unsqueeze(-1).repeat(
                    1, dec_T)) & (attention_pos < end_pos.unsqueeze(-1).repeat(
                        1, dec_T))
                attention[:, :, enc_inx][
                    att] = 1.  # assign weight 1. to the frames covered by this encoder step's duration

                start_pos = start_pos + dur  # [B]
            if cond_lens is not None:
                attention = attention * get_mask_3d(output_lengths, cond_lens)
        else:
            attention = attention_override
        return attention.matmul(
            encoder_outputs
        )  # [B, dec_T, enc_T] @ [B, enc_T, enc_dim] -> [B, dec_T, enc_dim]
Example No. 4
    def parse_encoder_outputs(self, encoder_outputs, durations, output_lengths,
                              text_lengths):
        """
        Acts as Monotonic Attention for Encoder Outputs.
        
        [B, enc_T, enc_dim] x [B, enc_T, durations] -> [B, dec_T, enc_dim]
        """
        B, enc_T, enc_dim = encoder_outputs.shape  # [Batch Size, Text Length, Encoder Dimension]
        dec_T = output_lengths.max().item()  # Length of Features

        start_pos = torch.zeros(B,
                                device=encoder_outputs.device,
                                dtype=encoder_outputs.dtype)  # [B]
        attention_pos = torch.arange(dec_T,
                                     device=encoder_outputs.device,
                                     dtype=encoder_outputs.dtype).expand(
                                         B, dec_T)  # [B, dec_T]
        attention = torch.zeros(
            B,
            dec_T,
            enc_T,
            device=encoder_outputs.device,
            dtype=encoder_outputs.dtype)  # [B, dec_T, enc_T]
        for enc_inx in range(durations.shape[1]):
            dur = durations[:, enc_inx]  # [B]
            end_pos = start_pos + dur  # [B]
            if text_lengths is not None:  # if last char, extend till end of decoder sequence
                mask = (text_lengths == (enc_inx + 1))  # [B]
                if mask.any():
                    end_pos.masked_fill_(mask, dec_T)

            att = (attention_pos >= start_pos.unsqueeze(-1).repeat(1, dec_T)
                   ) & (attention_pos < end_pos.unsqueeze(-1).repeat(1, dec_T))
            attention[:, :, enc_inx][
                att] = 1.  # assign weight 1. to the frames covered by this encoder step's duration

            start_pos = start_pos + dur  # [B]
        if text_lengths is not None:
            attention = attention * get_mask_3d(output_lengths, text_lengths)
        return attention.matmul(encoder_outputs)
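To make the duration-to-alignment expansion in Examples No. 3 and No. 4 concrete: each encoder step is marked active for exactly its predicted number of output frames, and the resulting hard attention matrix then gathers encoder outputs by matrix multiplication. A self-contained toy run of the same loop (hypothetical values: batch size 1, three encoder steps, durations 2 + 3 + 1 = 6 frames):

import torch

encoder_outputs = torch.arange(6, dtype=torch.float).view(1, 3, 2)  # [B=1, enc_T=3, enc_dim=2]
durations = torch.tensor([[2., 3., 1.]])                             # [B, enc_T]
dec_T = int(durations.sum(1).max().item())                           # 6 output frames

attention = torch.zeros(1, dec_T, 3)                                 # [B, dec_T, enc_T]
attention_pos = torch.arange(dec_T, dtype=torch.float).expand(1, dec_T)  # [B, dec_T]
start_pos = torch.zeros(1)
for enc_inx in range(3):
    end_pos = start_pos + durations[:, enc_inx]
    att = (attention_pos >= start_pos[:, None]) & (attention_pos < end_pos[:, None])
    attention[:, :, enc_inx][att] = 1.
    start_pos = end_pos

expanded = attention.matmul(encoder_outputs)  # [B, dec_T, enc_dim]
# Frames 0-1 copy encoder step 0, frames 2-4 copy step 1, frame 5 copies step 2.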
Example No. 5
    def forward(self,
                cond_inp,
                output_lengths,
                cond_lens=None):  # [B, seq_len, dim], int, [B]
        batch_size, enc_T, enc_dim = cond_inp.shape

        # get Random Position Offset (this *might* allow better distance generalisation)
        #trandint = torch.randint(10000, (1,), device=cond_inp.device, dtype=cond_inp.dtype)

        # get Query from Positional Encoding
        dec_T_max = output_lengths.max().item()
        dec_pos_emb = torch.arange(0,
                                   dec_T_max,
                                   device=cond_inp.device,
                                   dtype=cond_inp.dtype)  # + trandint
        if hasattr(self, 'pos_embedding_q'):
            dec_pos_emb = self.pos_embedding_q(
                dec_pos_emb.clamp(
                    0, self.pos_embedding_q_max - 1).long())[None, ...].repeat(
                        cond_inp.size(0), 1, 1)  # [B, dec_T, enc_dim]
        elif hasattr(self, 'positional_embedding'):
            dec_pos_emb = self.positional_embedding(
                dec_pos_emb, bsz=cond_inp.size(0))  # [B, dec_T, enc_dim]
        if not self.merged_pos_enc:
            dec_pos_emb = dec_pos_emb.repeat(1, 1, self.head_num)
        if output_lengths is not None:  # masking for batches
            dec_mask = get_mask_from_lengths(output_lengths).unsqueeze(
                2)  # [B, dec_T, 1]
            dec_pos_emb = dec_pos_emb * dec_mask  # [B, dec_T, enc_dim] * [B, dec_T, 1] -> [B, dec_T, enc_dim]
        q = dec_pos_emb  # [B, dec_T, enc_dim]

        # get Key/Value from Encoder Outputs
        k = v = cond_inp  # [B, enc_T, enc_dim]
        # (optional) add position encoding to Encoder outputs
        if hasattr(self, 'enc_positional_embedding'):
            enc_pos_emb = torch.arange(0,
                                       enc_T,
                                       device=cond_inp.device,
                                       dtype=cond_inp.dtype)  # + trandint
            if hasattr(self, 'pos_embedding_kv'):
                enc_pos_emb = self.pos_embedding_kv(
                    enc_pos_emb.clamp(0, self.pos_embedding_kv_max -
                                      1).long())[None, ...].repeat(
                                          cond_inp.size(0), 1,
                                          1)  # [B, enc_T, enc_dim]
            elif hasattr(self, 'enc_positional_embedding'):
                enc_pos_emb = self.enc_positional_embedding(
                    enc_pos_emb, bsz=cond_inp.size(0))  # [B, enc_T, enc_dim]
            if self.pos_enc_k:
                k = k + enc_pos_emb
            if self.pos_enc_v:
                v = v + enc_pos_emb
        enc_mask = get_mask_from_lengths(cond_lens).unsqueeze(1).repeat(
            1, q.size(1), 1) if (cond_lens
                                 is not None) else None  # [B, dec_T, enc_T]

        if not self.pytorch_native_mha:
            output, attention_scores = self.multi_head_attention(
                q, k, v, mask=enc_mask
            )  # [B, dec_T, enc_dim], [B, n_head, dec_T, enc_T]
        else:
            q = q.transpose(0, 1)  # [B, dec_T, enc_dim] -> [dec_T, B, enc_dim]
            k = k.transpose(0, 1)  # [B, enc_T, enc_dim] -> [enc_T, B, enc_dim]
            v = v.transpose(0, 1)  # [B, enc_T, enc_dim] -> [enc_T, B, enc_dim]

            enc_mask = ~enc_mask[:, 0, :] if (
                cond_lens
                is not None) else None  # [B, dec_T, enc_T] -> [B, enc_T]
            attn_mask = ~get_mask_3d(
                output_lengths,
                cond_lens).repeat_interleave(self.head_num, 0) if (
                    cond_lens is not None) else None  #[B*n_head, dec_T, enc_T]
            # -35500.0 is a large negative additive bias that stays finite in float16 (max magnitude 65504)
            attn_mask = attn_mask.float() * -35500.0 if (cond_lens
                                                         is not None) else None

            output, attention_scores = self.multi_head_attention(
                q, k, v, key_padding_mask=enc_mask,
                attn_mask=attn_mask)  # [dec_T, B, enc_dim], [B, dec_T, enc_T]

            output = output.transpose(
                0, 1)  # [dec_T, B, enc_dim] -> [B, dec_T, enc_dim]
            output = output + self.o_residual_weights * dec_pos_emb
            attention_scores = attention_scores * get_mask_3d(
                output_lengths, cond_lens) if (
                    cond_lens is not None) else attention_scores
            # attention_scores: [B, dec_T, enc_T]

            for self_att_layer, residual_weight in zip(
                    self.self_attention_layers, self.self_att_o_rws):
                q = output.transpose(
                    0, 1)  # [B, dec_T, enc_dim] -> [dec_T, B, enc_dim]

                output, att_sc = self_att_layer(
                    q, k, v, key_padding_mask=enc_mask, attn_mask=attn_mask
                )  # [dec_T, B, enc_dim], [B, dec_T, enc_T]

                output = output.transpose(
                    0, 1)  # [dec_T, B, enc_dim] -> [B, dec_T, enc_dim]
                output = output * residual_weight + q.transpose(
                    0, 1)  # ([B, dec_T, enc_dim] * rw) + [B, dec_T, enc_dim]
                att_sc = att_sc * get_mask_3d(output_lengths, cond_lens) if (
                    cond_lens is not None) else att_sc
                attention_scores = attention_scores + att_sc

            attention_scores = attention_scores / (1 +
                                                   len(self.self_att_o_rws))

        if output_lengths is not None:
            output = output * dec_mask  # [B, dec_T, enc_dim] * [B, dec_T, 1]
        return output, attention_scores
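The pytorch_native_mha branch in Example No. 5 follows the torch.nn.MultiheadAttention conventions: inputs are [T, B, dim], key_padding_mask is a [B, enc_T] BoolTensor with True marking positions to ignore, and attn_mask is an additive float mask of shape [B*num_heads, dec_T, enc_T] (here a large negative value rather than -inf, so it stays finite under float16). A minimal standalone sketch of that call pattern with hypothetical sizes:

import torch

B, dec_T, enc_T, dim, n_head = 2, 5, 7, 16, 4
mha = torch.nn.MultiheadAttention(dim, n_head)        # expects [T, B, dim] inputs by default

q = torch.randn(dec_T, B, dim)                        # decoder-side positional queries
k = v = torch.randn(enc_T, B, dim)                    # encoder outputs as key/value

key_padding_mask = torch.zeros(B, enc_T, dtype=torch.bool)
key_padding_mask[:, 5:] = True                        # True = ignore these encoder positions

attn_mask = torch.zeros(B * n_head, dec_T, enc_T)
attn_mask[:, :, 5:] = -35500.0                        # additive bias, large negative but float16-safe

out, weights = mha(q, k, v, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
# out: [dec_T, B, dim]; weights: [B, dec_T, enc_T], averaged over heads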