Example #1
    def forward(self, x, z, output_lengths=None):
        x = self.start(
            x
        )  # [B, in_dim, T//scale_factors] -> [B, self.decoder_dims[0], T//scale_factors]
        if output_lengths is not None:
            mask = get_mask_from_lengths(output_lengths).unsqueeze(1)
            x.masked_fill_(~mask, 0.0)

        for gblock in self.Gblocks:
            x = gblock(x, z, output_lengths=output_lengths)

        if output_lengths is not None:
            # rescale the lengths by the same factor the time axis was resized by
            scale_factor = x.shape[2] / output_lengths.max()
            if scale_factor != 1.0:
                output_lengths = (output_lengths.float() * scale_factor).long()
            mask = ~get_mask_from_lengths(output_lengths).unsqueeze(1)
            x.masked_fill_(mask, 0.0)
        x = self.end(x)  # [B, 1, T]

        x = x.tanh()
        if output_lengths is not None:
            x.masked_fill_(mask, 0.0)

        return x  # [B, 1, T]
Example #2
 def _make_masks(ilens, olens):
     """Make masks indicating non-padded part.
     Args:
         ilens (LongTensor or List): Batch of lengths (B,).
         olens (LongTensor or List): Batch of lengths (B,).
     Returns:
         Tensor: Mask tensor indicating non-padded part.
                 dtype=torch.uint8 in PyTorch < 1.2,
                 dtype=torch.bool in PyTorch >= 1.2.
     Examples:
         >>> ilens, olens = [5, 2], [8, 5]
         >>> _make_masks(ilens, olens)
         tensor([[[1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1],
                  [1, 1, 1, 1, 1]],
                 [[1, 1, 0, 0, 0],
                  [1, 1, 0, 0, 0],
                  [1, 1, 0, 0, 0],
                  [1, 1, 0, 0, 0],
                  [1, 1, 0, 0, 0],
                  [0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0]]], dtype=torch.uint8)
     """
     in_masks = get_mask_from_lengths(ilens)  # (B, T_in)
     out_masks = get_mask_from_lengths(olens)  # (B, T_out)
     return out_masks.unsqueeze(-1) & in_masks.unsqueeze(
         -2)  # (B, T_out, T_in)
Example #3
 def _make_masks(ilens, olens):
     """Make masks indicating non-padded part.
     Args:
         ilens (LongTensor or List): Batch of lengths (B,).
         olens (LongTensor or List): Batch of lengths (B,).
     Returns:
         Tensor: Mask tensor indicating non-padded part.
     """
     in_masks = get_mask_from_lengths(ilens)  # (B, T_in)
     out_masks = get_mask_from_lengths(olens)  # (B, T_out)
     return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)  # (B, T_out, T_in)
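Every example on this page relies on a `get_mask_from_lengths` helper that turns a batch of lengths into a boolean padding mask (True on non-padded positions). The repository's exact implementation is not shown here; a minimal sketch of what these snippets assume, including the optional `max_len` keyword some later examples pass, is:

import torch

def get_mask_from_lengths(lengths, max_len=None):
    # lengths: LongTensor [B] -> BoolTensor [B, max_len],
    # True inside each sequence, False on padding.
    if max_len is None:
        max_len = int(lengths.max().item())
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)

The 3-D `get_mask_3d(output_lengths, text_lengths)` helper used by some examples can be built from two such masks in exactly the way `_make_masks` above combines them.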
Example #4
 def forward(self, cond_inp, output_lengths, cond_lens=None):# [B, seq_len, dim], int, [B]
     batch_size, enc_T, enc_dim = cond_inp.shape
     
     # get Random Position Offset (this *might* allow better distance generalisation)
     #trandint = torch.randint(10000, (1,), device=cond_inp.device, dtype=cond_inp.dtype)
     
     # get Query from Positional Encoding
     dec_T_max = output_lengths.max().item()
     dec_pos_emb = torch.arange(0, dec_T_max, device=cond_inp.device, dtype=cond_inp.dtype)# + trandint        
     if hasattr(self, 'pos_embedding_q'):
         dec_pos_emb = self.pos_embedding_q(dec_pos_emb.clamp(0, self.pos_embedding_q_max-1).long())[None, ...].repeat(cond_inp.size(0), 1, 1)# [B, dec_T, enc_dim]
     elif hasattr(self, 'positional_embedding'):
         dec_pos_emb = self.positional_embedding(dec_pos_emb, bsz=cond_inp.size(0))# [B, dec_T, enc_dim]
     if not self.merged_pos_enc:
         dec_pos_emb = dec_pos_emb.repeat(1, 1, self.head_num)
     if output_lengths is not None:# masking for batches
         dec_mask = get_mask_from_lengths(output_lengths).unsqueeze(2)# [B, dec_T, 1]
         dec_pos_emb = dec_pos_emb * dec_mask# [B, dec_T, enc_dim] * [B, dec_T, 1] -> [B, dec_T, enc_dim]
     q = dec_pos_emb# [B, dec_T, enc_dim]
     
     # get Key/Value from Encoder Outputs
     k = v = cond_inp# [B, enc_T, enc_dim]
     # (optional) add position encoding to Encoder outputs
     if hasattr(self, 'enc_positional_embedding'):
         enc_pos_emb = torch.arange(0, enc_T, device=cond_inp.device, dtype=cond_inp.dtype)# + trandint
         if hasattr(self, 'pos_embedding_kv'):
             enc_pos_emb = self.pos_embedding_kv(enc_pos_emb.clamp(0, self.pos_embedding_kv_max-1).long())[None, ...].repeat(cond_inp.size(0), 1, 1)# [B, enc_T, enc_dim]
         elif hasattr(self, 'enc_positional_embedding'):
             enc_pos_emb = self.enc_positional_embedding(enc_pos_emb, bsz=cond_inp.size(0))# [B, enc_T, enc_dim]
         if self.pos_enc_k:
             k = k + enc_pos_emb
         if self.pos_enc_v:
             v = v + enc_pos_emb
     
     q = q.transpose(0, 1)# [B, dec_T, enc_dim] -> [dec_T, B, enc_dim]
     k = k.transpose(0, 1)# [B, enc_T, enc_dim] -> [enc_T, B, enc_dim]
     v = v.transpose(0, 1)# [B, enc_T, enc_dim] -> [enc_T, B, enc_dim]
     
     output = self.MH_Transformer(k, q,
         src_key_padding_mask=~get_mask_from_lengths(cond_lens).bool() if (cond_lens is not None) else None,
         tgt_key_padding_mask=~get_mask_from_lengths(output_lengths).bool(),
         memory_key_padding_mask=~get_mask_from_lengths(cond_lens).bool() if (cond_lens is not None) else None)# [dec_T, B, enc_dim], [B, dec_T, enc_T]
     
     output = output.transpose(0, 1)# [dec_T, B, enc_dim] -> [B, dec_T, enc_dim]
     output = output + self.o_residual_weights * dec_pos_emb
     attention_scores = get_mask_3d(output_lengths, cond_lens) if (cond_lens is not None) else None# [B, dec_T, enc_T]
     
     if output_lengths is not None:
         output = output * dec_mask# [B, dec_T, enc_dim] * [B, dec_T, 1]
     return output, attention_scores
Example #5
 def inference(self, text, speaker_ids, text_lengths=None, sigma=1.0):
     assert not torch.isnan(text).any(), 'text has NaN values.'
     embedded_text = self.embedding(text).transpose(1, 2) # [B, embed, sequence]
     assert not torch.isnan(embedded_text).any(), 'embedded_text has NaN values.'
     encoder_outputs = self.encoder.inference(embedded_text, speaker_ids=speaker_ids) # [B, enc_T, enc_dim]
     assert not torch.isnan(encoder_outputs).any(), 'encoder_outputs has NaN values.'
     
     # predict length of each input
     enc_out_mask = get_mask_from_lengths(text_lengths) if (text_lengths is not None) else None
     encoder_lengths = self.length_predictor(encoder_outputs, enc_out_mask)
     assert not torch.isnan(encoder_lengths).any(), 'encoder_lengths has NaN values.'
     
     # sum lengths (used to predict mel-spec length)
     encoder_lengths = encoder_lengths.clamp(1, 128)
     pred_output_lengths = encoder_lengths.sum((1,)).long()
     assert not torch.isnan(encoder_lengths).any(), 'encoder_lengths has NaN values.'
     assert not torch.isnan(pred_output_lengths).any(), 'pred_output_lengths has NaN values.'
     
     if self.speaker_embedding_dim:
         embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
         embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
         encoder_outputs = torch.cat((encoder_outputs, embedded_speakers), dim=2) # [batch, enc_T, enc_dim]
     
     # Positional Attention
     cond, attention_scores = self.positional_attention(encoder_outputs, pred_output_lengths, cond_lens=text_lengths)
     cond = cond.transpose(1, 2)
     assert not torch.isnan(cond).any(), 'cond has NaN values.'
     # [B, dec_T, enc_dim] -> [B, enc_dim, dec_T] # Masked Multi-head Attention
     
     # Decoder
     mel_outputs = self.decoder.infer(cond, sigma=sigma) # [B, dec_T, emb] -> [B, n_mel, dec_T] # Series of Flows
     assert not torch.isnan(mel_outputs).any(), 'mel_outputs has NaN values.'
     
     return self.mask_outputs(
         [mel_outputs, attention_scores, None, None, None])
Example #6
    def forward(self, model_output, targets):
        mel_target, gate_target, output_lengths, *_ = targets
        mel_target.requires_grad = False
        gate_target.requires_grad = False
        mel_out, mel_out_postnet, gate_out, _ = model_output
        gate_target = gate_target.view(-1, 1)
        gate_out = gate_out.view(-1, 1)

        # remove paddings before loss calc
        if self.masked_select:
            mask = get_mask_from_lengths(output_lengths)
            mask = mask.expand(mel_target.size(1), mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)
            mel_target = torch.masked_select(mel_target, mask)
            mel_out = torch.masked_select(mel_out, mask)
            mel_out_postnet = torch.masked_select(mel_out_postnet, mask)

        if self.loss_func == 'MSELoss':
            mel_loss = nn.MSELoss()(mel_out, mel_target) + \
                nn.MSELoss()(mel_out_postnet, mel_target)
        elif self.loss_func == 'SmoothL1Loss':
            mel_loss = nn.SmoothL1Loss()(mel_out, mel_target) + \
                nn.SmoothL1Loss()(mel_out_postnet, mel_target)

        gate_loss = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)(
            gate_out, gate_target)
        return mel_loss + gate_loss, gate_loss
Example #7
def get_attention_from_lengths(
        memory: Tensor,  # FloatTensor[B, enc_T, enc_dim]
        enc_durations: Tensor,  # FloatTensor[B, enc_T]
        text_lengths: Tensor  #  LongTensor[B]
):
    B, enc_T, mem_dim = memory.shape

    mask = get_mask_from_lengths(text_lengths)
    enc_durations.masked_fill_(~mask, 0.0)

    enc_durations = enc_durations.round()  #  [B, enc_T]
    dec_T = int(enc_durations.sum(dim=1).max().item())  # [B, enc_T] -> int

    attention_contexts = torch.zeros(B,
                                     dec_T,
                                     mem_dim,
                                     device=memory.device,
                                     dtype=memory.dtype)  # [B, dec_T, enc_dim]
    for i in range(B):
        mem_temp = []
        for j in range(int(text_lengths[i].item())):
            duration = int(enc_durations[i, j].item())

            # [B, enc_T, enc_dim] -> [1, enc_dim] -> [duration, enc_dim]
            mem_temp.append(memory[i, j:j + 1].repeat(duration, 1))
        mem_temp = torch.cat(
            mem_temp, dim=0)  # [[duration, enc_dim], ...] -> [dec_T, enc_dim]
        min_len = min(attention_contexts.shape[1], mem_temp.shape[0])
        attention_contexts[i, :min_len] = mem_temp[:min_len]

    return attention_contexts  # [B, dec_T, enc_dim]
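Example #7 above expands each encoder step by its (rounded) duration to build decoder-rate conditioning without an attention module. A quick shape check with dummy values (hypothetical inputs, assuming the `get_mask_from_lengths` sketch shown earlier):

import torch

memory        = torch.randn(2, 4, 8)            # [B=2, enc_T=4, enc_dim=8]
enc_durations = torch.tensor([[2., 3., 1., 9.],  # last entry is padding
                              [1., 1., 1., 1.]])
text_lengths  = torch.tensor([3, 4])

ctx = get_attention_from_lengths(memory, enc_durations, text_lengths)
print(ctx.shape)  # torch.Size([2, 6, 8]); dec_T = max(2+3+1, 1+1+1+1) = 6

The padded duration (9.) is zeroed by the length mask before summing, so it does not inflate dec_T; note that `enc_durations` is modified in place by `masked_fill_`.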
Example #8
    def forward(self, inputs):
        text, text_lengths, gt_mels, max_len, output_lengths, speaker_ids, torchmoji_hidden, preserve_decoder_states = inputs
        text_lengths, output_lengths = text_lengths.data, output_lengths.data

        assert not torch.isnan(text).any(), 'text has NaN values.'
        embedded_text = self.embedding(text).transpose(
            1, 2)  # [B, embed, sequence]
        assert not torch.isnan(
            embedded_text).any(), 'embedded_text has NaN values.'
        encoder_outputs = self.encoder(
            embedded_text, text_lengths,
            speaker_ids=speaker_ids)  # [B, enc_T, enc_dim]
        assert not torch.isnan(
            encoder_outputs).any(), 'encoder_outputs has NaN values.'

        # predict length of each input
        enc_out_mask = get_mask_from_lengths(text_lengths).unsqueeze(
            -1)  # [B, enc_T, 1]
        encoder_lengths = self.length_predictor(
            encoder_outputs, enc_out_mask)  # [B, enc_T, enc_dim]
        assert not torch.isnan(
            encoder_lengths).any(), 'encoder_lengths has NaN values.'

        # sum lengths (used to predict mel-spec length)
        encoder_lengths = encoder_lengths.clamp(1e-6, 4096)
        pred_output_lengths = encoder_lengths.sum((1, ))
        assert not torch.isnan(
            encoder_lengths).any(), 'encoder_lengths has NaN values.'
        assert not torch.isnan(
            pred_output_lengths).any(), 'pred_output_lengths has NaN values.'

        if self.speaker_embedding_dim:
            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
            embedded_speakers = embedded_speakers.repeat(
                1, encoder_outputs.size(1), 1)
            encoder_outputs = torch.cat((encoder_outputs, embedded_speakers),
                                        dim=2)  # [batch, enc_T, enc_dim]

        # Positional Attention
        cond, attention_scores = self.positional_attention(
            encoder_outputs, output_lengths, cond_lens=text_lengths)
        cond = cond.transpose(1, 2)
        assert not torch.isnan(cond).any(), 'cond has NaN values.'
        # [B, dec_T, enc_dim] -> [B, enc_dim, dec_T] # Masked Multi-head Attention

        # Decoder
        mel_outputs, log_s_sum, logdet_w_sum = self.decoder(
            gt_mels, cond
        )  # [B, n_mel, dec_T], [B, dec_T, enc_dim] -> [B, n_mel, dec_T], [B] # Series of Flows
        assert not torch.isnan(
            mel_outputs).any(), 'mel_outputs has NaN values.'
        assert not torch.isnan(log_s_sum).any(), 'log_s_sum has NaN values.'
        assert not torch.isnan(
            logdet_w_sum).any(), 'logdet_w_sum has NaN values.'

        return self.mask_outputs([
            mel_outputs, attention_scores, pred_output_lengths, log_s_sum,
            logdet_w_sum
        ], output_lengths)
Example #9
 def mask_outputs(self, outputs, output_lengths=None):
     if self.mask_padding and output_lengths is not None:
         mask = ~get_mask_from_lengths(output_lengths)
         mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
         mask = mask.permute(1, 0, 2)
         # [B, n_mel, steps]
         outputs[0].data.masked_fill_(mask, 0.0)  # [B, n_mel, T]
     return outputs
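Example #9's `mask_outputs` zeroes padded frames of the first output tensor. The expand/permute it uses is equivalent to broadcasting a [B, 1, T] padding mask over the mel channels; a standalone sketch of the same pattern (hypothetical shapes, assuming the mask helper sketched above):

import torch

mel     = torch.randn(2, 80, 10)           # [B, n_mel, T]
lengths = torch.tensor([10, 6])
pad     = ~get_mask_from_lengths(lengths)  # [B, T], True on padded frames
mel.masked_fill_(pad.unsqueeze(1), 0.0)    # broadcast over the n_mel axis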
Example #10
    def forward(self, model_output, targets):
        mel_target, gate_target, output_lengths, text_lengths, *_ = targets
        mel_out, attention_scores, pred_output_lengths, log_s_sum, logdet_w_sum = model_output
        batch_size, n_mel_channels, frames = mel_target.shape

        output_lengths_float = output_lengths.float()
        mel_out = mel_out.float()
        log_s_sum = log_s_sum.float()
        logdet_w_sum = logdet_w_sum.float()

        # Length Loss
        len_pred_loss = torch.nn.MSELoss()(pred_output_lengths.log(),
                                           output_lengths_float.log())

        # remove paddings before loss calc
        mask = get_mask_from_lengths(
            output_lengths)[:, None, :]  # [B, 1, T] BoolTensor
        mask = mask.expand(mask.size(0), mel_target.size(1),
                           mask.size(2))  # [B, n_mel, T] BoolTensor
        n_elems = (output_lengths_float.sum() * n_mel_channels)

        # Spectrogram Loss
        mel_out = torch.masked_select(mel_out, mask)
        loss_z = ((mel_out.pow(2).sum()) /
                  self.sigma2_2) / n_elems  # mean z (over all elements)

        loss_w = -logdet_w_sum.sum() / (n_mel_channels * frames)

        log_s_sum = log_s_sum.view(batch_size, -1, frames)
        log_s_sum = torch.masked_select(log_s_sum,
                                        mask[:, :log_s_sum.shape[1], :])
        loss_s = -log_s_sum.sum() / (n_elems)

        loss = loss_z + loss_w + loss_s + (len_pred_loss * 0.01)
        assert not torch.isnan(loss).any(), 'loss has NaN values.'

        # (optional) Guided Attention Loss
        if hasattr(self, 'guided_att'):
            att_loss = self.guided_att(attention_scores, text_lengths,
                                       output_lengths)
            loss = loss + att_loss
        else:
            att_loss = None

        # Min-Enc Attention Loss (unfinished / no-op in this snippet)
        # mask = get_mask_3d(output_lengths, text_lengths)
        # attention_scores.sum((1, ))  # [B, dec_T, enc_T]

        return loss, len_pred_loss, loss_z, loss_w, loss_s, att_loss
Example #11
    def forward(self, dec_inp, seq_lens=None):
        if self.word_emb is None:
            inp = dec_inp
            mask = get_mask_from_lengths(seq_lens).unsqueeze(2)
        else:
            inp = self.word_emb(dec_inp)
            # [bsz x L x 1]
            mask = (dec_inp != pad_idx).unsqueeze(2)

        pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
        pos_emb = self.pos_emb(pos_seq) * mask
        out = self.drop(inp + pos_emb)

        for layer in self.layers:
            out = layer(out, mask=mask)

        # out = self.drop(out)
        return out, mask
Example #12
 def forward(self, x, z=None, output_lengths=None):  # [B, in_dim, T]
     if hasattr(self, 'bn'):
         x = self.bn(x, z)
     x = self.act_func(x)
     if hasattr(self, 'scale'):
         if self.downsample:
             x = F.avg_pool1d(x, kernel_size=self.scale)  # [B, in_dim, T] -> [B, in_dim, T//scale]
         else:
             x = F.interpolate(
                 x, scale_factor=self.scale,
                 mode='linear')  # [B, in_dim, T]   -> [B, in_dim, x*T]
     if output_lengths is not None:
         scale_factor = x.shape[2] / output_lengths.max()  # rescale lengths to match the resized time axis
         if scale_factor != 1.0:
             output_lengths = (output_lengths.float() * scale_factor).long()
         mask = get_mask_from_lengths(output_lengths).unsqueeze(1)
         x.masked_fill_(~mask, 0.0)
     x = self.conv(x)  # [B, in_dim, x*T] -> [B, out_dim, x*T]
     return x  # [B, out_dim, x*T]
Example #13
    def forward(self, h, z, output_lengths=None):
        scaled_h = F.interpolate(
            h, scale_factor=self.scale, mode='linear'
        ) if self.scale != 1 else h  # [B, input_dim, T] -> [B, input_dim, x*T]
        if output_lengths is not None:
            scale_factor = scaled_h.shape[2] / output_lengths.max()  # rescale lengths to match the resized time axis
            if scale_factor != 1.0:
                output_lengths = (output_lengths.float() * scale_factor).long()
            mask = get_mask_from_lengths(output_lengths).unsqueeze(1)
            scaled_h.masked_fill_(~mask, 0.0)
        residual = self.skip_conv(
            scaled_h)  # [B, input_dim, x*T] -> [B, output_dim, x*T]

        for i, resblock in enumerate(
                self.resblocks):  # [B, input_dim, T] -> [B, output_dim, x*T]
            h = resblock(h, z, output_lengths)
            if i == self.res_block_id:
                h += residual
                residual = h

        return h + residual  # [B, output_dim, x*T]
Example #14
def glow_loss(z, log_s_sum, logdet_w_sum, output_lengths, sigma):
    dec_T = int(output_lengths.max().item())
    
    B = z.shape[0]
    z = z.view(z.shape[0], -1, dec_T).float()
    log_s_sum = log_s_sum.view(B, -1, dec_T)
    B, z_channels, dec_T = z.shape
    
    n_elems = (output_lengths.float().sum()*z_channels)
    
    # remove paddings before loss calc
    mask = get_mask_from_lengths(output_lengths)[:, None, :] # [B, 1, T] BoolTensor
    mask = mask.expand(B, z_channels, dec_T)# [B, z_channels, T] BoolTensor
    
    z = torch.masked_select(z, mask)
    loss_z = ((z.pow(2).sum()) / sigma)/n_elems # mean z (over all elements)
    
    log_s_sum = torch.masked_select(log_s_sum , mask[:, :log_s_sum.shape[1], :])
    loss_s = -log_s_sum.float().sum()/n_elems
    
    loss_w = -logdet_w_sum.float().sum()/(z_channels*dec_T)
    
    loss = loss_z+loss_w+loss_s
    return loss, loss_z, loss_w, loss_s
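A small shape check for Example #14's `glow_loss` makes the masking explicit. The values are random and only the shapes matter; `sigma` here stands in for the `2 * sigma**2` constant the other losses call `sigma2_2`:

import torch

B, z_channels, dec_T = 2, 8, 10
output_lengths = torch.tensor([10, 7])
z            = torch.randn(B, z_channels, dec_T)
log_s_sum    = torch.randn(B, z_channels // 2, dec_T)  # fewer channels than z is fine; the mask is sliced to match
logdet_w_sum = torch.randn(4)                          # one logdet per invertible 1x1 conv

loss, loss_z, loss_w, loss_s = glow_loss(z, log_s_sum, logdet_w_sum,
                                         output_lengths, sigma=2 * 1.0**2)

Only the first 10 and 7 frames of each item contribute to loss_z and loss_s; loss_w is normalised by the padded z_channels * dec_T instead.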
Example #15
    def forward(self, model, pred, gt, loss_scalars, resGAN=None, dbGAN=None, infGAN=None):
        loss_dict = {}
        file_losses = {}# dict of {"audiofile": {"spec_MSE": spec_MSE, "avg_prob": avg_prob, ...}, ...}
        
        B, n_mel, mel_T = gt['gt_mel'].shape
        tfB = B//(model.decoder.half_inference_mode+1)
        for i in range(tfB):
            current_time = time.time()
            if gt['audiopath'][i] not in file_losses:
                file_losses[gt['audiopath'][i]] = {'speaker_id_ext': gt['speaker_id_ext'][i], 'time': current_time}
        
        if True:
            pred_mel_postnet = pred['pred_mel_postnet']
            pred_mel         = pred['pred_mel']
            gt_mel           =   gt['gt_mel']
            mel_lengths      =   gt['mel_lengths']
            
            mask = get_mask_from_lengths(mel_lengths)
            mask = mask.expand(gt_mel.size(1), *mask.shape).permute(1, 0, 2)
            pred_mel_postnet.masked_fill_(~mask, 0.0)
            pred_mel        .masked_fill_(~mask, 0.0)
            
            with torch.no_grad():
                assert not torch.isnan(pred_mel).any(), 'mel has NaNs'
                assert not torch.isinf(pred_mel).any(), 'mel has Infs'
                assert not torch.isnan(pred_mel_postnet).any(), 'mel has NaNs'
                assert not torch.isinf(pred_mel_postnet).any(), 'mel has Infs'
            
            if model.decoder.half_inference_mode:
                pred_mel_postnet = pred_mel_postnet.chunk(2, dim=0)[0]
                pred_mel         = pred_mel        .chunk(2, dim=0)[0]
                gt_mel           = gt_mel          .chunk(2, dim=0)[0]
                mel_lengths      = mel_lengths     .chunk(2, dim=0)[0]
                mask             = mask            .chunk(2, dim=0)[0]
            B, n_mel, mel_T = gt_mel.shape
            
            teacher_force_till = loss_scalars.get('teacher_force_till',   0)
            p_teacher_forcing  = loss_scalars.get('p_teacher_forcing' , 1.0)
            if p_teacher_forcing == 0.0 and teacher_force_till > 1:
                gt_mel           = gt_mel          [:, :, :teacher_force_till]
                pred_mel         = pred_mel        [:, :, :teacher_force_till]
                pred_mel_postnet = pred_mel_postnet[:, :, :teacher_force_till]
                mel_lengths      = mel_lengths.clamp(max=teacher_force_till)
            
            # spectrogram / decoder loss
            pred_mel_selected = torch.masked_select(pred_mel, mask)
            gt_mel_selected   = torch.masked_select(gt_mel,   mask)
            spec_SE = nn.MSELoss(reduction='none')(pred_mel_selected, gt_mel_selected)
            loss_dict['spec_MSE'] = spec_SE.mean()
            
            losses = spec_SE.split([x*n_mel for x in mel_lengths.cpu()])
            for i in range(tfB):
                audiopath = gt['audiopath'][i]
                file_losses[audiopath]['spec_MSE'] = losses[i].mean().item()
            
            # postnet
            pred_mel_postnet_selected = torch.masked_select(pred_mel_postnet, mask)
            loss_dict['postnet_MSE'] = nn.MSELoss()(pred_mel_postnet_selected, gt_mel_selected)
            
            # squared by frame, mean postnet
            mask = mask.transpose(1, 2)[:, :, :1]# [B, mel_T, n_mel] -> [B, mel_T, 1]
            
            spec_AE = nn.L1Loss(reduction='none')(pred_mel, gt_mel).transpose(1, 2)# -> [B, mel_T, n_mel]
            spec_AE = spec_AE.masked_select(mask).view(mel_lengths.sum(), n_mel)   # -> [B* mel_T, n_mel]
            loss_dict['spec_MFSE'] = (spec_AE * spec_AE.mean(dim=1, keepdim=True)).mean()# multiply by frame means (similar to square op from MSE) and get the mean of the losses
            
            post_AE = nn.L1Loss(reduction='none')(pred_mel_postnet, gt_mel).transpose(1, 2)# -> [B, mel_T, n_mel]
            post_AE = post_AE.masked_select(mask).view(mel_lengths.sum(), n_mel)# -> [B*mel_T, n_mel]
            loss_dict['postnet_MFSE'] = (post_AE * post_AE.mean(dim=1, keepdim=True)).mean()# multiply by frame means (similar to square op from MSE) and get the mean of the losses
            del gt_mel, spec_AE, post_AE,#pred_mel_postnet, pred_mel
        
        if True: # gate/stop loss
            gate_target =   gt['gt_gate_logits'  ]
            gate_out    = pred['pred_gate_logits']
            if model.decoder.half_inference_mode:
                gate_target = gate_target.chunk(2, dim=0)[0]
                gate_out    = gate_out   .chunk(2, dim=0)[0]
            gate_target = gate_target.view(-1, 1)
            gate_out    =    gate_out.view(-1, 1)
            
            loss_dict['gate_loss'] = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)(gate_out, gate_target)
            del gate_target, gate_out
        
        if True: # SylpsNet loss
            syl_mu     = pred['pred_sylps_mu']
            syl_logvar = pred['pred_sylps_logvar']
            if model.decoder.half_inference_mode:
                syl_logvar  = syl_logvar.chunk(2, dim=0)[0]
                syl_mu      = syl_mu    .chunk(2, dim=0)[0]
            loss_dict['sylps_kld'] = -0.5 * (1 + syl_logvar - syl_logvar.exp() - syl_mu.pow(2)).sum()/B
            del syl_mu, syl_logvar
        
        if True: # Pred Sylps loss
            pred_sylps = pred['pred_sylps'].squeeze(1)# [B, 1] -> [B]
            sylps_target = gt['gt_sylps']
            if model.decoder.half_inference_mode:
                pred_sylps      = pred_sylps     .chunk(2, dim=0)[0]
                sylps_target    = sylps_target   .chunk(2, dim=0)[0]
            loss_dict['sylps_MAE'] =  nn.L1Loss()(pred_sylps, sylps_target)
            loss_dict['sylps_MSE'] = nn.MSELoss()(pred_sylps, sylps_target)
            del pred_sylps, sylps_target
        
        if True:# Diagonal Attention Guiding
            alignments     = pred['alignments']
            text_lengths   =   gt['text_lengths']
            output_lengths =   gt['mel_lengths']
            pres_prev_state=   gt['pres_prev_state']
            if model.decoder.half_inference_mode:
                alignments      = alignments     .chunk(2, dim=0)[0]
                text_lengths    = text_lengths   .chunk(2, dim=0)[0]
                output_lengths  = output_lengths .chunk(2, dim=0)[0]
                pres_prev_state = pres_prev_state.chunk(2, dim=0)[0]
            loss_dict['diag_att'] = self.guided_att(alignments[pres_prev_state==0.0],
                                                  text_lengths[pres_prev_state==0.0],
                                                output_lengths[pres_prev_state==0.0])
            del alignments, text_lengths, output_lengths, pres_prev_state
        
        if self.use_res_enc and resGAN is not None:# Residual Encoder KL Divergence Loss
            mu, logvar, mulogvar = pred['res_enc_pkg']
            
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss_dict['res_enc_kld'] = kl_loss
            
            # discriminator attempts to predict the letters and speakers using the residual latent space,
            #  the generator attempts to increase the discriminators loss so the latent space will lose speaker and dur info
            #   making it more likely the latent contains information relating to background noise conditions and
            #    other features more relavent to human interests.
            with torch.no_grad():
                gt_speakers   = gt['speaker_id_onehot'].float() # [B, n_speakers]
                gt_sym_durs   = get_class_durations(gt['text'], pred['alignments'].detach(), self.n_symbols)# [B, n_symbols]
            out = resGAN.discriminator(mulogvar)# learns to predict the speaker and
            B = out.shape[0]
            pred_sym_durs, pred_speakers = out.squeeze(-1).split([self.n_symbols, self.n_speakers], dim=1) # amount of 'a','b','c','.', etc sounds that are in the audio.
                                                                      # if there isn't a 'd' sound in the transcript, then d will be 0.0
                                                                      # if there are multiple 'a' sounds, their durations are summed.
            pred_speakers = torch.nn.functional.softmax(pred_speakers, dim=1)
            loss_dict['res_enc_gMSE'] = (nn.MSELoss(reduction='sum')(pred_sym_durs, gt_sym_durs.mean(dim=1, keepdim=True))*0.0001 + nn.MSELoss(reduction='sum')(pred_speakers, gt_speakers.mean(dim=1, keepdim=True)))/B
            
            resGAN.gt_speakers = gt_speakers
            resGAN.gt_sym_durs = gt_sym_durs
            
            del mu, logvar, kl_loss, gt_speakers, gt_sym_durs, pred_sym_durs, pred_speakers
        
        if 1 and model.training and self.use_dbGAN and dbGAN is not None:
            pred_mel_postnet = pred['pred_mel_postnet'].unsqueeze(1)# -> [tfB, 1, n_mel, mel_T]
            pred_mel         = pred['pred_mel']        .unsqueeze(1)# -> [tfB, 1, n_mel, mel_T]
            speaker_embed    = pred['speaker_embed']
            if model.decoder.half_inference_mode:
                pred_mel_postnet = pred_mel_postnet.chunk(2, dim=0)[0]
                pred_mel         = pred_mel        .chunk(2, dim=0)[0]
                speaker_embed    = speaker_embed   .chunk(2, dim=0)[0]
            B, _, n_mel, mel_T = pred_mel.shape
            mels = torch.cat((pred_mel, pred_mel_postnet), dim=0).float()# [2*B, 1, n_mel, mel_T]
            with torch.no_grad():
                assert not (torch.isnan(mels) | torch.isinf(mels)).any(), 'NaN or Inf value found in computation'
            
#            if False:
#                pred_fakeness = checkpoint(dbGAN.discriminator, mels, speaker_id.repeat(2)).squeeze(1)# -> [2*B, mel_T//?]
#            else:
            pred_fakeness = dbGAN.discriminator(mels, speaker_embed.repeat(2, 1)).squeeze(1)# -> [2*B, mel_T//?]
            pred_fakeness, postnet_fakeness = pred_fakeness.chunk(2, dim=0)# -> [B, mel_T//?], [B, mel_T//?]
            
            tfB, post_mel_T = pred_fakeness.shape
            real_label = torch.ones(tfB, post_mel_T, device=pred_mel.device, dtype=pred_mel.dtype)*-1.0# [B]
            loss_dict['dbGAN_gLoss'] = F.mse_loss(pred_fakeness, real_label)*0.5 + F.mse_loss(postnet_fakeness, real_label)*0.5
            with torch.no_grad():
                assert not torch.isnan(loss_dict['dbGAN_gLoss']), 'dbGAN loss is NaN'
                assert not torch.isinf(loss_dict['dbGAN_gLoss']), 'dbGAN loss is Inf'
            del mels, real_label, pred_fakeness, postnet_fakeness, pred_mel, pred_mel_postnet, speaker_embed
        
        if self.use_InfGAN and infGAN is not None and model.decoder.half_inference_mode:
            with torch.no_grad():
                pred_gate = pred['pred_gate_logits'].chunk(2, dim=0)[1].sigmoid()
                pred_gate[:, :5] = 0.0
                # Get inference alignment scores
                pred_mel_lengths = get_first_over_thresh(pred_gate, 0.5)
                pred_mel_lengths.clamp_(max=mel_T)
                pred['pred_mel_lengths'] = pred_mel_lengths
                mask = get_mask_from_lengths(pred_mel_lengths, max_len=mel_T).unsqueeze(1)# [B, 1, mel_T]
            
            tfB = pred_gate.shape[0]
            with freeze_grads(model.decoder.prenet):
                args = infGAN.merge_inputs(model, pred, gt, tfB, mask)# [B/2, mel_T, embed]
            
            if infGAN.training and infGAN.gradient_checkpoint:
                inf_infness = checkpoint(infGAN.discriminator, *args).squeeze(1)# -> [B/2, mel_T]
            else:
                inf_infness = infGAN.discriminator(*args).squeeze(1)# -> [B/2, mel_T]
            
            tf_label = torch.ones(tfB, device=pred_gate.device, dtype=pred_gate.dtype)[:, None].expand(tfB, mel_T)*-1.# [B/2]
            loss_dict['InfGAN_gLoss'] = 2.*F.mse_loss(inf_infness, tf_label)
        
        #################################################################
        ## Colate / Merge the Losses into a single tensor with scalars ##
        #################################################################
        loss_dict = self.colate_losses(loss_dict, loss_scalars)
        
        with torch.no_grad():# get Avg Max Attention and Diagonality Metrics
            
            atd = alignment_metric(pred['alignments'].detach().clone(), gt['text_lengths'], gt['mel_lengths'])
            diagonalitys, avg_prob, char_max_dur, char_min_dur, char_avg_dur, p_missing_enc = atd.values()
            
            loss_dict['diagonality']       = diagonalitys.mean()
            loss_dict['avg_max_attention'] = avg_prob.mean()
            
            for i in range(tfB):
                audiopath = gt['audiopath'][i]
                file_losses[audiopath]['avg_max_attention'] =      avg_prob[i].cpu().item()
                file_losses[audiopath]['att_diagonality']   =  diagonalitys[i].cpu().item()
                file_losses[audiopath]['p_missing_enc']     = p_missing_enc[i].cpu().item()
                file_losses[audiopath]['char_max_dur']      =  char_max_dur[i].cpu().item()
                file_losses[audiopath]['char_min_dur']      =  char_min_dur[i].cpu().item()
                file_losses[audiopath]['char_avg_dur']      =  char_avg_dur[i].cpu().item()
                
                if 0:
                    diagonality_path = f'{os.path.splitext(audiopath)[0]}_diag.pt'
                    torch.save(diagonalitys[i].detach().clone().cpu(), diagonality_path)
                    
                    avg_prob_path = f'{os.path.splitext(audiopath)[0]}_avgp.pt'
                    torch.save(    avg_prob[i].detach().clone().cpu(), avg_prob_path   )
            
            pred_gate = pred['pred_gate_logits'].sigmoid()
            pred_gate[:, :5] = 0.0
            # Get inference alignment scores
            pred_mel_lengths = get_first_over_thresh(pred_gate, 0.7)
            atd = alignment_metric(pred['alignments'].detach().clone(), gt['text_lengths'], pred_mel_lengths)
            atd = {k: v.cpu() for k, v in atd.items()}
            diagonalitys, avg_prob, char_max_dur, char_min_dur, char_avg_dur, p_missing_enc = atd.values()
            scores = []
            for i in range(tfB):
                # factors that make up score
                weighted_score = avg_prob[i].item() # general alignment quality
                diagonality_punishment = max( diagonalitys[i].item()-1.10, 0) * 0.25 # speaking each letter at a similar pace.
                max_dur_punishment     = max( char_max_dur[i].item()-60.0, 0) * 0.005# getting stuck on same letter for 0.5s
                min_dur_punishment     = max(0.00-char_min_dur[i].item(),  0) * 0.5  # skipping single enc outputs
                avg_dur_punishment     = max(3.60-char_avg_dur[i].item(),  0)        # skipping most enc outputs
                mis_dur_punishment     = max(p_missing_enc[i].item()-0.08, 0) if gt['text_lengths'][i] > 12 and gt['mel_lengths'][i] < gt['mel_lengths'].max()*0.75 else 0.0 # skipping some percent of the text
                
                weighted_score -= (diagonality_punishment+max_dur_punishment+min_dur_punishment+avg_dur_punishment+mis_dur_punishment)
                scores.append(weighted_score)
                file_losses[gt['audiopath'][i]]['att_score'] = weighted_score
            scores = torch.tensor(scores)
            scores[torch.isnan(scores)] = scores[~torch.isnan(scores)].mean()
            loss_dict['weighted_score'] = scores.to(pred['alignments'].device).mean()
        
        return loss_dict, file_losses
Example #16
 def forward(self, pred, gt, loss_scalars):
     
     loss_dict = {}
     file_losses = {}# dict of {"audiofile": {"spec_MSE": spec_MSE, "avg_prob": avg_prob, ...}, ...}
     
     B, n_mel, mel_T = gt['gt_mel'].shape
     for i in range(B):
         current_time = time.time()
         if gt['audiopath'][i] not in file_losses:
             file_losses[gt['audiopath'][i]] = {'speaker_id_ext': gt['speaker_id_ext'][i], 'time': current_time}
     
     if True:
         pred_mel_postnet = pred['pred_mel_postnet']
         pred_mel         = pred['pred_mel']
         gt_mel           =   gt['gt_mel']
         
         B, n_mel, mel_T = gt_mel.shape
         
         mask = get_mask_from_lengths(gt['mel_lengths'])
         mask = mask.expand(gt_mel.size(1), *mask.shape).permute(1, 0, 2)
         
         # spectrogram / decoder loss
         pred_mel = torch.masked_select(pred_mel, mask)
         gt_mel   = torch.masked_select(gt_mel, mask)
         spec_SE = nn.MSELoss(reduction='none')(pred_mel, gt_mel)
         loss_dict['spec_MSE'] = spec_SE.mean()
         
         losses = spec_SE.split([x*n_mel for x in gt['mel_lengths'].cpu()])
         for i in range(B):
             audiopath = gt['audiopath'][i]
             file_losses[audiopath]['spec_MSE'] = losses[i].mean().item()
         
         # postnet
         pred_mel_postnet.masked_fill_(~mask, 0.0)
         pred_mel_postnet = torch.masked_select(pred_mel_postnet, mask)
         loss_dict['postnet_MSE'] = nn.MSELoss()(pred_mel_postnet, gt_mel)
         
         # squared by frame, mean postnet
         mask = get_mask_from_lengths(gt['mel_lengths']).unsqueeze(-1)# -> [B, mel_T] -> [B, mel_T, 1]
         
         spec_AE = nn.L1Loss(reduction='none')(pred['pred_mel'], gt['gt_mel']).transpose(1, 2)# -> [B, mel_T, n_mel]
         spec_AE = spec_AE.masked_select(mask).view(gt['mel_lengths'].sum(), n_mel)# -> [B*mel_T, n_mel]
         loss_dict['spec_MFSE'] = (spec_AE * spec_AE.mean(dim=1, keepdim=True)).mean()# multiply by frame means (similar to square op from MSE) and get the mean of the losses
         
         post_AE = nn.L1Loss(reduction='none')(pred['pred_mel_postnet'], gt['gt_mel']).transpose(1, 2)# -> [B, mel_T, n_mel]
         post_AE = post_AE.masked_select(mask).view(gt['mel_lengths'].sum(), n_mel)# -> [B*mel_T, n_mel]
         loss_dict['postnet_MFSE'] = (post_AE * post_AE.mean(dim=1, keepdim=True)).mean()# multiply by frame means (similar to square op from MSE) and get the mean of the losses
     
     if True: # gate/stop loss
         gate_target =  gt['gt_gate_logits'].view(-1, 1)
         gate_out = pred['pred_gate_logits'].view(-1, 1)
         loss_dict['gate_loss'] = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)(gate_out, gate_target)
         del gate_target, gate_out
     
     if True: # SylpsNet loss
         syl_mu     = pred['pred_sylps_mu']
         syl_logvar = pred['pred_sylps_logvar']
         loss_dict['sylps_kld'] = -0.5 * (1 + syl_logvar - syl_logvar.exp() - syl_mu.pow(2)).sum()/B
         del syl_mu, syl_logvar
     
     if True: # Pred Sylps loss
         pred_sylps = pred['pred_sylps'].squeeze(1)# [B, 1] -> [B]
         sylps_target = gt['gt_sylps']
         loss_dict['sylps_MAE'] = nn.L1Loss()(pred_sylps, sylps_target)
         loss_dict['sylps_MSE'] = nn.MSELoss()(pred_sylps, sylps_target)
         del pred_sylps, sylps_target
     
     if True:# Diagonal Attention Guiding
         alignments     = pred['alignments']
         text_lengths   = gt['text_lengths']
         output_lengths = gt['mel_lengths']
         pres_prev_state= gt['pres_prev_state']
         loss_dict['diag_att'] = self.guided_att(alignments[pres_prev_state==0.0],
                                               text_lengths[pres_prev_state==0.0],
                                             output_lengths[pres_prev_state==0.0])
         del alignments, text_lengths, output_lengths
     
     #################################################################
     ## Colate / Merge the Losses into a single tensor with scalars ##
     #################################################################
     loss_dict = self.colate_losses(loss_dict, loss_scalars)
     
     with torch.no_grad():# get Avg Max Attention and Diagonality Metrics
         
         atd = alignment_metric(pred['alignments'], gt['text_lengths'], gt['mel_lengths'])
         diagonalitys, avg_prob, char_max_dur, char_min_dur, char_avg_dur, p_missing_enc = atd.values()
         
         loss_dict['diagonality']       = diagonalitys.mean()
         loss_dict['avg_max_attention'] = avg_prob.mean()
         
         for i in range(B):
             audiopath = gt['audiopath'][i]
             file_losses[audiopath]['avg_max_attention'] =      avg_prob[i].cpu().item()
             file_losses[audiopath]['att_diagonality'  ] =  diagonalitys[i].cpu().item()
             file_losses[audiopath]['p_missing_enc']     = p_missing_enc[i].cpu().item()
             file_losses[audiopath]['char_max_dur']      =  char_max_dur[i].cpu().item()
             file_losses[audiopath]['char_min_dur']      =  char_min_dur[i].cpu().item()
             file_losses[audiopath]['char_avg_dur']      =  char_avg_dur[i].cpu().item()
         
         pred_gate = pred['pred_gate_logits'].sigmoid()
         pred_gate[:, :5] = 0.0
         # Get inference alignment scores
         pred_mel_lengths = get_first_over_thresh(pred_gate, 0.7)
         atd = alignment_metric(pred['alignments'], gt['text_lengths'], pred_mel_lengths)
         atd = {k: v.cpu() for k, v in atd.items()}
         diagonalitys, avg_prob, char_max_dur, char_min_dur, char_avg_dur, p_missing_enc = atd.values()
         scores = []
         for i in range(B):
             # factors that make up score
             weighted_score = avg_prob[i].item() # general alignment quality
             diagonality_punishment = max( diagonalitys[i].item()-1.10, 0) * 0.25 # speaking each letter at a similar pace.
             max_dur_punishment     = max( char_max_dur[i].item()-60.0, 0) * 0.005# getting stuck on same letter for 0.5s
             min_dur_punishment     = max(0.00-char_min_dur[i].item(),  0) * 0.5  # skipping single enc outputs
             avg_dur_punishment     = max(3.60-char_avg_dur[i].item(),  0)        # skipping most enc outputs
             mis_dur_punishment     = max(p_missing_enc[i].item()-0.08, 0) if gt['text_lengths'][i] > 12 and gt['mel_lengths'][i] < gt['mel_lengths'].max()*0.75 else 0.0 # skipping some percent of the text
             
             weighted_score -= (diagonality_punishment+max_dur_punishment+min_dur_punishment+avg_dur_punishment+mis_dur_punishment)
             scores.append(weighted_score)
             file_losses[gt['audiopath'][i]]['att_score'] = weighted_score
         scores = torch.tensor(scores)
         scores[torch.isnan(scores)] = scores[~torch.isnan(scores)].mean()
         loss_dict['weighted_score'] = scores.to(pred['alignments'].device).mean()
     
     return loss_dict, file_losses
Example #17
 def forward(self, model_output, targets, loss_scalars):
     # loss scalars
     MelGlow_ls = loss_scalars['MelGlow_ls'] if loss_scalars['MelGlow_ls'] is not None else self.MelGlow_loss_scalar
     DurGlow_ls = loss_scalars['DurGlow_ls'] if loss_scalars['DurGlow_ls'] is not None else self.DurGlow_loss_scalar
     VarGlow_ls = loss_scalars['VarGlow_ls'] if loss_scalars['VarGlow_ls'] is not None else self.VarGlow_loss_scalar
     Sylps_ls   = loss_scalars['Sylps_ls'  ] if loss_scalars['Sylps_ls'  ] is not None else self.Sylps_loss_scalar
     
     # loss_func
     mel_target, text_lengths, output_lengths, perc_loudness_target, f0_target, energy_target, sylps_target, voiced_mask, char_f0, char_voiced, char_energy, *_ = targets
     B, n_mel, dec_T = mel_target.shape
     enc_T = text_lengths.max()
     output_lengths_float = output_lengths.float()
     
     loss_dict = {}
     
     # Decoder / MelGlow Loss
     if True:
         mel_z, log_s_sum, logdet_w_sum = model_output['melglow']
         
         # remove paddings before loss calc
         mask = get_mask_from_lengths(output_lengths)[:, None, :] # [B, 1, T] BoolTensor
         mask = mask.expand(mask.size(0), mel_target.size(1), mask.size(2))# [B, n_mel, T] BoolTensor
         
         n_elems = (output_lengths_float.sum() * n_mel)
         mel_z = torch.masked_select(mel_z, mask)
         dec_loss_z = ((mel_z.pow(2).sum()) / self.sigma2_2)/n_elems # mean z (over all elements)
         
         log_s_sum = log_s_sum.view(B, -1, dec_T)
         log_s_sum = torch.masked_select(log_s_sum , mask[:, :log_s_sum.shape[1], :])
         dec_loss_s = -log_s_sum.sum()/(n_elems)
         
         dec_loss_w = -logdet_w_sum.sum()/(n_mel*dec_T)
         
         dec_loss_d = dec_loss_z+dec_loss_w+dec_loss_s
         loss = dec_loss_d*MelGlow_ls
         del mel_z, log_s_sum, logdet_w_sum, mask, n_elems
         loss_dict["Decoder_Loss_Z"] = dec_loss_z
         loss_dict["Decoder_Loss_W"] = dec_loss_w
         loss_dict["Decoder_Loss_S"] = dec_loss_s
         loss_dict["Decoder_Loss_Total"] = dec_loss_d
         assert not (torch.isnan(loss) | torch.isinf(loss)).any(), 'Inf/NaN Loss at MelGlow Latents'
     
     # CVarGlow Loss
     if True:
         z, log_s_sum, logdet_w_sum = model_output['cvarglow']
         _ = glow_loss(z, log_s_sum, logdet_w_sum, text_lengths, self.dg_sigma2_2)
         cvar_loss_d, cvar_loss_z, cvar_loss_w, cvar_loss_s = _
         
         if self.DurGlow_loss_scalar:
             loss = loss + cvar_loss_d*DurGlow_ls
         del z, log_s_sum, logdet_w_sum
         loss_dict["CVar_Loss_Z"] = cvar_loss_z
         loss_dict["CVar_Loss_W"] = cvar_loss_w
         loss_dict["CVar_Loss_S"] = cvar_loss_s
         loss_dict["CVar_Loss_Total"] = cvar_loss_d
         assert not (torch.isnan(loss) | torch.isinf(loss)).any(), 'Inf/NaN Loss at CVarGlow Latents'
     
     # FramGlow Loss
     if True:
         z, log_s_sum, logdet_w_sum = model_output['varglow']
         z_channels = 6
         z = z.view(z.shape[0], z_channels, -1)
         
         # remove paddings before loss calc
         mask = get_mask_from_lengths(output_lengths)[:, None, :]#   [B, 1, T] BoolTensor
         mask = mask.expand(mask.size(0), z_channels, mask.size(2))# [B, z_channels, T] BoolTensor
         n_elems = (output_lengths_float.sum() * z_channels)
         
         z = torch.masked_select(z, mask)
         var_loss_z = ((z.pow(2).sum()) / self.sigma2_2)/n_elems # mean z (over all elements)
         
         log_s_sum = log_s_sum.view(B, -1, dec_T)
         log_s_sum = torch.masked_select(log_s_sum , mask[:, :log_s_sum.shape[1], :])
         var_loss_s = -log_s_sum.sum()/(n_elems)
         
         var_loss_w = -logdet_w_sum.sum()/(z_channels*dec_T)
         
         var_loss_d = var_loss_z+var_loss_w+var_loss_s
         loss = loss + var_loss_d*VarGlow_ls
         del z, log_s_sum, logdet_w_sum, mask, n_elems, z_channels
         loss_dict["Variance_Loss_Z"] = var_loss_z
         loss_dict["Variance_Loss_W"] = var_loss_w
         loss_dict["Variance_Loss_S"] = var_loss_s
         loss_dict["Variance_Loss_Total"] = var_loss_d
         assert not (torch.isnan(loss) | torch.isinf(loss)).any(), 'Inf/NaN Loss at VarGlow Latents'
     
     # Sylps Loss
     if True:
         enc_global_outputs, sylps = model_output['sylps']# [B, 2], [B]
         mu, logvar = enc_global_outputs.transpose(0, 1)[:2, :]# [2, B]
         
         loss_dict["zSylps_Loss"] = NormalLLLoss(mu, logvar, sylps)# [B], [B], [B] -> [B]
         loss = loss + loss_dict["zSylps_Loss"]*Sylps_ls
         del mu, logvar, enc_global_outputs, sylps
         assert not (torch.isnan(loss) | torch.isinf(loss)).any(), 'Inf/NaN Loss at Pred Sylps'
     
     # Perceived Loudness Loss
     if True:
         enc_global_outputs, perc_loudness = model_output['perc_loud']# [B, 2], [B]
         mu, logvar = enc_global_outputs.transpose(0, 1)[2:4, :]# [2, B]
         
         loss_dict["zPL_Loss"] = NormalLLLoss(mu, logvar, perc_loudness)# [B], [B], [B] -> [B]
         loss = loss + loss_dict["zPL_Loss"]*Sylps_ls
         del mu, logvar, enc_global_outputs, perc_loudness
         assert not (torch.isnan(loss) | torch.isinf(loss)).any(), 'Inf/NaN Loss at Pred Perceived Loudness'
     
     loss_dict["loss"] = loss
     return loss_dict
Example #18
    def inference(
            self,
            text: Tensor,  #  LongTensor[B, enc_T]
            speaker_ids: Tensor,  #  LongTensor[B]
            torchmoji_hidden: Tensor,  # FloatTensor[B, embed] 
            sylps: Optional[Tensor] = None,  # FloatTensor[B]        or None
            text_lengths: Optional[
                Tensor] = None,  #  LongTensor[B]        or None
            durations: Optional[
                Tensor] = None,  # FloatTensor[B, enc_T] or None
            perc_loudness: Optional[
                Tensor] = None,  # FloatTensor[B]        or None
            f0: Optional[Tensor] = None,  # FloatTensor[B, dec_T] or None
            energy: Optional[Tensor] = None,  # FloatTensor[B, dec_T] or None
            mel_sigma: float = 1.0,
            dur_sigma: float = 1.0,
            var_sigma: float = 1.0):
        assert not self.training, "model must be in eval() mode"

        # move Tensors to GPU (if not already there)
        text, speaker_ids, torchmoji_hidden, sylps, text_lengths, durations, perc_loudness, f0, energy = self.update_device(
            text, speaker_ids, torchmoji_hidden, sylps, text_lengths,
            durations, perc_loudness, f0, energy)
        B, enc_T = text.shape

        if text_lengths is None:
            text_lengths = torch.ones((B, )).to(text) * enc_T
        assert text_lengths is not None

        # NOTE: gt_mels and output_lengths are not defined in this inference
        # path, so this assumes self.mel_encoder is None (or melenc_ignore is set).
        melenc_outputs = self.mel_encoder(
            gt_mels, output_lengths, speaker_ids=speaker_ids) if (
                self.mel_encoder is not None
                and not self.melenc_ignore) else None  # [B, dec_T, melenc_dim]

        embedded_text = self.embedding(text).transpose(
            1, 2)  #    [B, embed, sequence]
        encoder_outputs, enc_global_outputs = self.encoder(
            embedded_text, text_lengths,
            speaker_ids=speaker_ids)  # [B, enc_T, enc_dim]
        if sylps is None:
            sylps = enc_global_outputs[:, 0:1]  # [B, 1]
        if perc_loudness is None:
            perc_loudness = enc_global_outputs[:, 2:3]  # [B, 1]

        assert sylps is not None  # needs to be updated with pred_sylps soon ^TM

        memory = [
            encoder_outputs,
        ]
        if self.speaker_embedding_dim:
            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
            embedded_speakers = embedded_speakers.repeat(1, enc_T, 1)
            memory.append(embedded_speakers)  # [B, enc_T, enc_dim]
        if sylps is not None:
            sylps = sylps[..., None]  # [B, 1] -> [B, 1, 1]
            sylps = sylps.repeat(1, enc_T, 1)
            memory.append(sylps)  # [B, enc_T, enc_dim]
        if perc_loudness is not None:
            perc_loudness = perc_loudness[..., None]  # [B, 1] -> [B, 1, 1]
            perc_loudness = perc_loudness.repeat(1, enc_T, 1)
            memory.append(perc_loudness)  # [B, enc_T, enc_dim]
        if torchmoji_hidden is not None:
            emotion_embed = torchmoji_hidden.unsqueeze(
                1)  # [B, C] -> [B, 1, C]
            emotion_embed = self.torchmoji_linear(
                emotion_embed)  # [B, 1, in_C] -> [B, 1, out_C]
            emotion_embed = emotion_embed.repeat(1, enc_T, 1)
            memory.append(emotion_embed)  #   [B, enc_T, enc_dim]
        memory = torch.cat(
            memory, dim=2
        )  # [[B, enc_T, enc_dim], [B, enc_T, speaker_dim]] -> [B, enc_T, enc_dim+speaker_dim]
        assert not (torch.isnan(memory)
                    | torch.isinf(memory)).any(), 'Inf/NaN Loss at memory'

        # CVarGlow
        mask = get_mask_from_lengths(text_lengths)  # [B, T]
        cvars = self.cvar_glow.infer(memory.transpose(1, 2), sigma=dur_sigma)
        #  memory.transpose(1, 2): [B, enc_dim, enc_T] -> cvars: [B, n_cvar_channels, enc_T]
        norm_char_f0 = cvars[:, 1:2]
        norm_char_energy = cvars[:, 2:3]
        char_voiced = cvars[:, 3:4]
        char_f0 = self.bn_cf0.inverse(norm_char_f0)
        char_energy = self.bn_cenergy.inverse(norm_char_energy)

        enc_durations = self.lbn_duration.inverse(
            cvars[:, :1], mask)  # [B, 8, enc_T] -> [B, 1, enc_T]
        memory = torch.cat((memory, cvars[:, 1:4].transpose(1, 2)),
                           dim=2)  # [B, enc_T, enc_dim] +cat+ [B, enc_T, 3]

        attention_contexts = get_attention_from_lengths(
            memory, enc_durations[:, 0, :], text_lengths)
        #                -> [B, dec_T, enc_dim]
        B, dec_T, enc_dim = attention_contexts.shape

        variances = self.var_glow.infer(attention_contexts.transpose(1, 2),
                                        sigma=var_sigma)
        variances = variances.chunk(2, dim=1)[0]  # [B, 3, dec_T]
        voiced_mask = variances[:, 0, :]
        f0 = self.bn_f0.inverse(variances[:, 1:2, :]).squeeze(1)
        energy = self.bn_energy.inverse(variances[:, 2:3, :]).squeeze(1)

        global_cond = None
        if self.melenc_enable:  # take all current info, and produce global cond tokens which can be randomly sampled from later
            global_cond = torch.randn(B, n_tokens)  # [B, n_tokens]; n_tokens is assumed to be defined as a model hyperparameter

        # Decoder
        cond = [attention_contexts.transpose(1, 2), variances]
        if global_cond is not None:
            cond.append(global_cond)
        cond = torch.cat(cond, dim=1)
        spect = self.decoder.infer(cond, sigma=mel_sigma)

        outputs = {
            "spect": spect,
            "char_durs": enc_durations,
            "char_voiced": char_voiced,
            "char_f0": char_f0,
            "char_energy": char_energy,
            "frame_voiced_mask": voiced_mask,
            "frame_f0": f0,
            "frame_energy": energy,
        }
        return outputs
Example #19
    def forward(self,
                cond_inp,
                output_lengths,
                cond_lens=None):  # [B, seq_len, dim], int, [B]
        batch_size, enc_T, enc_dim = cond_inp.shape

        # get Random Position Offset (this *might* allow better distance generalisation)
        #trandint = torch.randint(10000, (1,), device=cond_inp.device, dtype=cond_inp.dtype)

        # get Query from Positional Encoding
        dec_T_max = output_lengths.max().item()
        dec_pos_emb = torch.arange(0,
                                   dec_T_max,
                                   device=cond_inp.device,
                                   dtype=cond_inp.dtype)  # + trandint
        if hasattr(self, 'pos_embedding_q'):
            dec_pos_emb = self.pos_embedding_q(
                dec_pos_emb.clamp(
                    0, self.pos_embedding_q_max - 1).long())[None, ...].repeat(
                        cond_inp.size(0), 1, 1)  # [B, dec_T, enc_dim]
        elif hasattr(self, 'positional_embedding'):
            dec_pos_emb = self.positional_embedding(
                dec_pos_emb, bsz=cond_inp.size(0))  # [B, dec_T, enc_dim]
        if not self.merged_pos_enc:
            dec_pos_emb = dec_pos_emb.repeat(1, 1, self.head_num)
        if output_lengths is not None:  # masking for batches
            dec_mask = get_mask_from_lengths(output_lengths).unsqueeze(
                2)  # [B, dec_T, 1]
            dec_pos_emb = dec_pos_emb * dec_mask  # [B, dec_T, enc_dim] * [B, dec_T, 1] -> [B, dec_T, enc_dim]
        q = dec_pos_emb  # [B, dec_T, enc_dim]

        # get Key/Value from Encoder Outputs
        k = v = cond_inp  # [B, enc_T, enc_dim]
        # (optional) add position encoding to Encoder outputs
        if hasattr(self, 'enc_positional_embedding'):
            enc_pos_emb = torch.arange(0,
                                       enc_T,
                                       device=cond_inp.device,
                                       dtype=cond_inp.dtype)  # + trandint
            if hasattr(self, 'pos_embedding_kv'):
                enc_pos_emb = self.pos_embedding_kv(
                    enc_pos_emb.clamp(0, self.pos_embedding_kv_max -
                                      1).long())[None, ...].repeat(
                                          cond_inp.size(0), 1,
                                          1)  # [B, enc_T, enc_dim]
            elif hasattr(self, 'enc_positional_embedding'):
                enc_pos_emb = self.enc_positional_embedding(
                    enc_pos_emb, bsz=cond_inp.size(0))  # [B, enc_T, enc_dim]
            if self.pos_enc_k:
                k = k + enc_pos_emb
            if self.pos_enc_v:
                v = v + enc_pos_emb
        enc_mask = get_mask_from_lengths(cond_lens).unsqueeze(1).repeat(
            1, q.size(1), 1) if (cond_lens
                                 is not None) else None  # [B, dec_T, enc_T]

        if not self.pytorch_native_mha:
            output, attention_scores = self.multi_head_attention(
                q, k, v, mask=enc_mask
            )  # [B, dec_T, enc_dim], [B, n_head, dec_T, enc_T]
        else:
            q = q.transpose(0, 1)  # [B, dec_T, enc_dim] -> [dec_T, B, enc_dim]
            k = k.transpose(0, 1)  # [B, enc_T, enc_dim] -> [enc_T, B, enc_dim]
            v = v.transpose(0, 1)  # [B, enc_T, enc_dim] -> [enc_T, B, enc_dim]

            if cond_lens is not None:
                enc_mask = ~enc_mask[:, 0, :]  # [B, dec_T, enc_T] -> [B, enc_T]
                attn_mask = ~get_mask_3d(output_lengths, cond_lens).repeat_interleave(
                    self.head_num, 0)  # [B*n_head, dec_T, enc_T]
                attn_mask = attn_mask.float() * -35500.0
            else:
                enc_mask = attn_mask = None

            output, attention_scores = self.multi_head_attention(
                q, k, v, key_padding_mask=enc_mask,
                attn_mask=attn_mask)  # [dec_T, B, enc_dim], [B, dec_T, enc_T]

            output = output.transpose(
                0, 1)  # [dec_T, B, enc_dim] -> [B, dec_T, enc_dim]
            output = output + self.o_residual_weights * dec_pos_emb
            if cond_lens is not None:
                attention_scores = attention_scores * get_mask_3d(
                    output_lengths, cond_lens)  # [B, dec_T, enc_T]

            for self_att_layer, residual_weight in zip(
                    self.self_attention_layers, self.self_att_o_rws):
                q = output.transpose(
                    0, 1)  # [B, dec_T, enc_dim] -> [dec_T, B, enc_dim]

                output, att_sc = self_att_layer(
                    q, k, v, key_padding_mask=enc_mask, attn_mask=attn_mask
                )  # ..., [dec_T, B, enc_dim], [B, dec_T, enc_T])

                output = output.transpose(
                    0, 1)  # [dec_T, B, enc_dim] -> [B, dec_T, enc_dim]
                output = output * residual_weight + q.transpose(
                    0, 1)  # ([B, dec_T, enc_dim] * rw) + [B, dec_T, enc_dim]
                if cond_lens is not None:
                    att_sc = att_sc * get_mask_3d(output_lengths, cond_lens)
                attention_scores = attention_scores + att_sc

            attention_scores = attention_scores / (1 +
                                                   len(self.self_att_o_rws))

        if output_lengths is not None:
            output = output * dec_mask  # [B, dec_T, enc_dim] * [B, dec_T, 1]
        return output, attention_scores
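The masking helpers called throughout these snippets (get_mask_from_lengths, get_mask_3d) are defined elsewhere in the repository. A minimal sketch of what they are assumed to compute, inferred from how they are used above (True marks valid, non-padded positions):

import torch

def get_mask_from_lengths(lengths, max_len=None):
    # [B] lengths -> [B, max_len] BoolTensor, True at valid (non-padded) steps
    if max_len is None:
        max_len = int(lengths.max().item())
    ids = torch.arange(max_len, device=lengths.device)
    return ids[None, :] < lengths[:, None]

def get_mask_3d(out_lengths, in_lengths):
    # [B, dec_T, enc_T] BoolTensor, True where both the decoder frame and the encoder step are valid
    out_mask = get_mask_from_lengths(out_lengths)  # [B, dec_T]
    in_mask = get_mask_from_lengths(in_lengths)    # [B, enc_T]
    return out_mask.unsqueeze(-1) & in_mask.unsqueeze(-2)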
Example No. 20
0
 def forward(self, model, pred, gt, loss_scalars,):
     loss_dict = {}
     file_losses = {}# dict of {"audiofile": {"spec_MSE": spec_MSE, "avg_prob": avg_prob, ...}, ...}
     
     B, n_mel, mel_T = gt['gt_mel'].shape
     for i in range(B):
         current_time = time.time()
         if gt['audiopath'][i] not in file_losses:
             file_losses[gt['audiopath'][i]] = {'speaker_id_ext': gt['speaker_id_ext'][i], 'time': current_time}
     
     if True:
         pred_mel_postnet = pred['pred_mel_postnet']
         pred_mel         = pred['pred_mel']
         gt_mel           =   gt['gt_mel']
         mel_lengths      =   gt['mel_lengths']
         
         mask = get_mask_from_lengths(mel_lengths, max_len=gt_mel.size(2))
         mask = mask.expand(gt_mel.size(1), *mask.shape).permute(1, 0, 2)
         pred_mel_postnet.masked_fill_(~mask, 0.0)
         pred_mel        .masked_fill_(~mask, 0.0)
         
         with torch.no_grad():
             assert not torch.isnan(pred_mel).any(), 'mel has NaNs'
             assert not torch.isinf(pred_mel).any(), 'mel has Infs'
             assert not torch.isnan(pred_mel_postnet).any(), 'mel has NaNs'
             assert not torch.isinf(pred_mel_postnet).any(), 'mel has Infs'
         
         B, n_mel, mel_T = gt_mel.shape
         
         # spectrogram / decoder loss
         pred_mel_selected = torch.masked_select(pred_mel, mask)
         gt_mel_selected   = torch.masked_select(gt_mel,   mask)
         spec_SE = nn.MSELoss(reduction='none')(pred_mel_selected, gt_mel_selected)
         loss_dict['spec_MSE'] = spec_SE.mean()
         
         losses = spec_SE.split([x*n_mel for x in mel_lengths.cpu().tolist()])  # per-utterance chunks
         for i in range(B):
             audiopath = gt['audiopath'][i]
             file_losses[audiopath]['spec_MSE'] = losses[i].mean().item()
         
         # postnet
         pred_mel_postnet_selected = torch.masked_select(pred_mel_postnet, mask)
         loss_dict['postnet_MSE'] = nn.MSELoss()(pred_mel_postnet_selected, gt_mel_selected)
         
         # squared by frame, mean postnet
         mask = mask.transpose(1, 2)[:, :, :1]# [B, mel_T, n_mel] -> [B, mel_T, 1]
         
         spec_AE = nn.L1Loss(reduction='none')(pred_mel, gt_mel).transpose(1, 2)# -> [B, mel_T, n_mel]
         spec_AE = spec_AE.masked_select(mask).view(mel_lengths.sum(), n_mel)   # -> [B* mel_T, n_mel]
         loss_dict['spec_MFSE'] = (spec_AE * spec_AE.mean(dim=1, keepdim=True)).mean()# multiply by frame means (similar to square op from MSE) and get the mean of the losses
         
         post_AE = nn.L1Loss(reduction='none')(pred_mel_postnet, gt_mel).transpose(1, 2)# -> [B, mel_T, n_mel]
         post_AE = post_AE.masked_select(mask).view(mel_lengths.sum(), n_mel)# -> [B*mel_T, n_mel]
         loss_dict['postnet_MFSE'] = (post_AE * post_AE.mean(dim=1, keepdim=True)).mean()# multiply by frame means (similar to square op from MSE) and get the mean of the losses
         del gt_mel, spec_AE, post_AE,#pred_mel_postnet, pred_mel
     
     if True:
         # Code semantic loss.
         code_reconst = model(pred_mel_postnet, gt['speaker_embeds'], None)
         loss_dict['code_L1'] = F.l1_loss(pred['bottleneck_codes'], code_reconst)
     
     #################################################################
     ## Colate / Merge the Losses into a single tensor with scalars ##
     #################################################################
     loss_dict = self.colate_losses(loss_dict, loss_scalars)
     
     return loss_dict, file_losses
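The spec_MFSE / postnet_MFSE terms above weight each element's absolute error by the mean error of its frame, so frames that are badly reconstructed as a whole are penalised more than isolated outliers. A self-contained sketch of that reduction on dummy tensors (the function name is illustrative, and get_mask_from_lengths is the helper sketched earlier):

import torch

def frame_weighted_abs_error(pred, target, lengths):
    # pred/target: [B, n_mel, mel_T]; lengths: [B] valid frame counts per item
    ae = (pred - target).abs().transpose(1, 2)            # [B, mel_T, n_mel]
    mask = get_mask_from_lengths(lengths).unsqueeze(-1)   # [B, mel_T, 1]
    ae = ae.masked_select(mask).view(-1, pred.size(1))    # [sum(lengths), n_mel]
    # scale by the per-frame mean error (playing the role of the square in an MSE), then average
    return (ae * ae.mean(dim=1, keepdim=True)).mean()

pred, target = torch.randn(2, 80, 10), torch.randn(2, 80, 10)
lengths = torch.tensor([10, 7])
print(frame_weighted_abs_error(pred, target, lengths))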
Example No. 21
0
    def forward(self, inputs):
        text, gt_mels, speaker_ids, text_lengths, output_lengths,\
            alignments, torchmoji_hidden, perc_loudness, f0, energy,\
            sylps, voiced_mask, char_f0, char_voiced, char_energy = inputs

        # zero mean unit variance normalization of features
        with torch.no_grad():
            perc_loudness = self.bn_pl(
                perc_loudness.unsqueeze(1))  # [B] -> [B, 1]

            mask = get_mask_from_lengths(output_lengths)  # [B, dec_T]
            f0 = self.bn_f0(
                f0.unsqueeze(1),
                (voiced_mask & mask))  # [B, dec_T] -> [B, 1, dec_T]
            energy = self.bn_energy(energy.unsqueeze(1),
                                    mask)  # [B, dec_T] -> [B, 1, dec_T]

            mask = get_mask_from_lengths(text_lengths)  # [B, enc_T]
            char_f0 = self.bn_cf0(char_f0.unsqueeze(1), mask)  # [B, 1, enc_T]
            char_energy = self.bn_cenergy(char_energy.unsqueeze(1),
                                          mask)  # [B, 1, enc_T]
            char_voiced = char_voiced.unsqueeze(1)  # [B, 1, enc_T]

            mask = get_mask_from_lengths(text_lengths)  # [B, T]
            enc_durations = alignments.sum(dim=1).unsqueeze(
                1)  # [B, dec_T, enc_T] -> [B, enc_T] -> [B, 1, enc_T]
            ln_enc_durations = self.lbn_duration(enc_durations,
                                                 mask)  # [B, 1, enc_T] Norm

        embedded_text = self.embedding(text).transpose(
            1, 2)  #    [B, embed, sequence]
        encoder_outputs, enc_global_outputs = self.encoder(
            embedded_text, text_lengths,
            speaker_ids=speaker_ids)  # [B, enc_T, enc_dim]
        memory = [
            encoder_outputs,
        ]
        if self.speaker_embedding_dim:
            embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
            embedded_speakers = embedded_speakers.repeat(
                1, encoder_outputs.size(1), 1)
            memory.append(embedded_speakers)  # [B, enc_T, enc_dim]
        if sylps is not None:
            sylps = sylps[:, None, None]  # [B] -> [B, 1, 1]
            sylps = sylps.repeat(1, encoder_outputs.size(1), 1)
            memory.append(sylps)  # [B, enc_T, enc_dim]
        if perc_loudness is not None:
            perc_loudness = perc_loudness[..., None]  # [B, 1] -> [B, 1, 1]
            perc_loudness = perc_loudness.repeat(1, encoder_outputs.size(1), 1)
            memory.append(perc_loudness)  # [B, enc_T, enc_dim]
        if torchmoji_hidden is not None:
            emotion_embed = torchmoji_hidden.unsqueeze(
                1)  # [B, C] -> [B, 1, C]
            emotion_embed = self.torchmoji_linear(
                emotion_embed)  # [B, 1, in_C] -> [B, 1, out_C]
            emotion_embed = emotion_embed.repeat(1, encoder_outputs.size(1), 1)
            memory.append(emotion_embed)  #   [B, enc_T, enc_dim]
        memory = torch.cat(
            memory, dim=2
        )  # [[B, enc_T, enc_dim], [B, enc_T, speaker_dim]] -> [B, enc_T, enc_dim+speaker_dim]
        assert not (torch.isnan(memory)
                    | torch.isinf(memory)).any(), 'Inf/NaN Loss at memory'

        # CVarGlow
        cvar_gt = torch.cat(
            (ln_enc_durations, char_f0, char_energy, char_voiced),
            dim=1).repeat(1, 2, 1)  # [B, 4, enc_T] -> [B, 8, enc_T]
        cvar_z, cvar_log_s_sum, cvar_logdet_w_sum = self.cvar_glow(
            cvar_gt, memory.transpose(1, 2))
        #  ([B, enc_T], [B, enc_dim, enc_T])

        memory = torch.cat((memory, char_f0.transpose(
            1, 2), char_energy.transpose(1, 2), char_voiced.transpose(1, 2)),
                           dim=2)  # enc_dim += 3

        attention_contexts = alignments @ memory
        #             [B, dec_T, enc_T] @ [B, enc_T, enc_dim] -> [B, dec_T, enc_dim]

        # Variances Inpainter
        # cond -> attention_contexts
        # x/z  -> voiced_mask + f0 + energy

        var_gt = torch.cat((voiced_mask.to(f0.dtype).unsqueeze(1), f0, energy),
                           dim=1)
        var_gt = var_gt.repeat(1, 2, 1)
        variance_z, variance_log_s_sum, variance_logdet_w_sum = self.var_glow(
            var_gt, attention_contexts.transpose(1, 2))

        global_cond = None
        if self.melenc_enable:  # take all current info, and produce global cond tokens which can be randomly sampled from later
            melenc_input = torch.cat(
                (gt_mels, attention_contexts, voiced_mask.float(), f0, energy),
                dim=1)
            global_cond, mu, logvar = self.mel_encoder(
                melenc_input, output_lengths)  # [B, n_tokens]

        # Decoder
        cond = [
            attention_contexts.transpose(1, 2),
            voiced_mask.to(f0.dtype).unsqueeze(1), f0, energy
        ]
        if global_cond is not None:
            cond.append(global_cond)
        cond = torch.cat(cond, dim=1)
        z, log_s_sum, logdet_w_sum = self.decoder(gt_mels.clone(), cond)
        #   [B, n_mel, dec_T], [B, dec_T, enc_dim] # Series of Flows

        outputs = {
            "melglow": [z, log_s_sum, logdet_w_sum],
            "cvarglow": [cvar_z, cvar_log_s_sum, cvar_logdet_w_sum],
            "varglow": [variance_z, variance_log_s_sum, variance_logdet_w_sum],
            "sylps": [enc_global_outputs, sylps],
            "perc_loud": [enc_global_outputs, perc_loudness],
        }
        return outputs
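In the forward pass above, attention_contexts = alignments @ memory spreads encoder-side features over the decoder time axis with a batched matrix product, and alignments.sum(dim=1) gives the per-token durations that feed the duration flow. A toy illustration with a hard alignment (all values made up for the example):

import torch

B, dec_T, enc_T, enc_dim = 1, 4, 2, 3
alignments = torch.tensor([[[1., 0.],   # decoder frames 0-1 attend to encoder step 0
                            [1., 0.],
                            [0., 1.],   # decoder frames 2-3 attend to encoder step 1
                            [0., 1.]]])               # [B, dec_T, enc_T]
memory = torch.randn(B, enc_T, enc_dim)               # [B, enc_T, enc_dim]
attention_contexts = alignments @ memory              # [B, dec_T, enc_dim]
durations = alignments.sum(dim=1)                     # [B, enc_T] -> tensor([[2., 2.]])
assert attention_contexts.shape == (B, dec_T, enc_dim)
assert torch.equal(attention_contexts[0, 0], memory[0, 0])  # frame 0 is a copy of step 0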
Example No. 22
0
    def forward(self,
                model_output,
                targets,
                criterion_dict,
                iter,
                em_kl_weight=None,
                DiagonalGuidedAttention_scalar=None):
        self.em_kl_weight = self.em_kl_weight if em_kl_weight is None else em_kl_weight
        self.DiagonalGuidedAttention_scalar = self.DiagonalGuidedAttention_scalar if DiagonalGuidedAttention_scalar is None else DiagonalGuidedAttention_scalar
        amp, n_gpus, model, model_d, hparams, optimizer, optimizer_d, grad_clip_thresh = criterion_dict.values(
        )
        is_overflow = False
        grad_norm = 0.0

        mel_target, gate_target, output_lengths, text_lengths, emotion_id_target, emotion_onehot_target, sylps_target, preserve_decoder, *_ = targets
        mel_target.requires_grad = False
        gate_target.requires_grad = False
        mel_out, mel_out_postnet, gate_out, alignments, pred_sylps, syl_package, em_package, aux_em_package, gan_package, *_ = model_output
        gate_target = gate_target.view(-1, 1)
        gate_out = gate_out.view(-1, 1)

        Bsz, n_mel, dec_T = mel_target.shape

        unknown_id = self.n_classes
        supervised_mask = (emotion_id_target != unknown_id)  # [B] BoolTensor
        unsupervised_mask = ~supervised_mask  # [B] BoolTensor

        # remove paddings before loss calc
        if self.masked_select:
            mask = get_mask_from_lengths(output_lengths)
            mask = mask.expand(mel_target.size(1), mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            mel_target_not_masked = mel_target
            mel_target = torch.masked_select(mel_target, mask)
            if self.use_LL_Loss:
                mel_out, mel_logvar = mel_out.chunk(2, dim=1)
                if mel_out_postnet is not None:
                    mel_out_postnet, mel_logvar_postnet = mel_out_postnet.chunk(
                        2, dim=1)
                mel_logvar = torch.masked_select(mel_logvar, mask)
                mel_logvar_postnet = torch.masked_select(
                    mel_logvar_postnet, mask)

            mel_out_not_masked = mel_out
            mel_out = torch.masked_select(mel_out, mask)
            if mel_out_postnet is not None:
                mel_out_postnet_not_masked = mel_out_postnet
                mel_out_postnet = torch.masked_select(mel_out_postnet, mask)

        postnet_MSE = postnet_MAE = postnet_SMAE = postnet_LL = torch.tensor(
            0.)

        # spectrogram / decoder loss
        spec_MSE = nn.MSELoss()(mel_out, mel_target)
        spec_MAE = nn.L1Loss()(mel_out, mel_target)
        spec_SMAE = nn.SmoothL1Loss()(mel_out, mel_target)
        if mel_out_postnet is not None:
            postnet_MSE = nn.MSELoss()(mel_out_postnet, mel_target)
            postnet_MAE = nn.L1Loss()(mel_out_postnet, mel_target)
            postnet_SMAE = nn.SmoothL1Loss()(mel_out_postnet, mel_target)
        if self.use_LL_Loss:
            spec_LL = NormalLLLoss(mel_out, mel_logvar, mel_target)
            loss = (spec_LL * self.melout_LL_scalar)
            if mel_out_postnet is not None:
                postnet_LL = NormalLLLoss(mel_out_postnet, mel_logvar_postnet,
                                          mel_target)
            loss += (postnet_LL * self.postnet_LL_scalar)
        else:
            spec_LL = postnet_LL = torch.tensor(0.0, device=mel_out.device)
            loss = (spec_MSE * self.melout_MSE_scalar)
            loss += (spec_MAE * self.melout_MAE_scalar)
            loss += (spec_SMAE * self.melout_SMAE_scalar)
            loss += (postnet_MSE * self.postnet_MSE_scalar)
            loss += (postnet_MAE * self.postnet_MAE_scalar)
            loss += (postnet_SMAE * self.postnet_SMAE_scalar)

        if True:  # gate/stop loss
            gate_loss = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)(
                gate_out, gate_target)
            loss += gate_loss

        if True:  # SylpsNet loss
            sylzu, syl_mu, syl_logvar = syl_package
            sylKLD = -0.5 * (1 + syl_logvar - syl_logvar.exp() -
                             syl_mu.pow(2)).sum() / Bsz
            loss += (sylKLD * self.syl_KDL_weight)

        if True:  # Pred Sylps loss
            pred_sylps = pred_sylps.squeeze(1)  # [B, 1] -> [B]
            sylpsMSE = nn.MSELoss()(pred_sylps, sylps_target)
            sylpsMAE = nn.L1Loss()(pred_sylps, sylps_target)
            loss += (sylpsMSE * self.pred_sylpsMSE_weight)
            loss += (sylpsMAE * self.pred_sylpsMAE_weight)

        if True:  # EmotionNet loss
            zs, em_zu, em_mu, em_logvar, em_params = [
                x.squeeze(1) for x in em_package
            ]  # VAE-GST loss
            SupervisedLoss = ClassicationMAELoss = ClassicationMSELoss = ClassicationNCELoss = SupervisedKDL = UnsupervisedLoss = UnsupervisedKDL = torch.tensor(
                0)

            kl_scale = self.vae_kl_anneal_function(self.anneal_function,
                                                   self.lag, iter, self.k,
                                                   self.x0,
                                                   self.upper)  # outputs 0<s<1
            em_kl_weight = kl_scale * self.em_kl_weight

            if (sum(supervised_mask) > 0):  # if labeled data > 0:
                mu_labeled = em_mu[supervised_mask]
                logvar_labeled = em_logvar[supervised_mask]
                log_prob_labeled = zs[supervised_mask]
                y_onehot = emotion_onehot_target[supervised_mask]

                # -Elbo for labeled data (L(X,y))
                SupervisedLoss, SupervisedKDL = self._L(y_onehot,
                                                        mu_labeled,
                                                        logvar_labeled,
                                                        beta=em_kl_weight)
                loss += SupervisedLoss

                # Add MSE/MAE Loss
                prob_labeled = log_prob_labeled.exp()
                ClassicationMAELoss = nn.L1Loss(reduction='sum')(
                    prob_labeled, y_onehot) / Bsz
                loss += (ClassicationMAELoss * self.zsClassificationMAELoss)

                ClassicationMSELoss = nn.MSELoss(reduction='sum')(
                    prob_labeled, y_onehot) / Bsz
                loss += (ClassicationMSELoss * self.zsClassificationMSELoss)

                # Add auxiliary classification loss q(y|x) # negative cross entropy
                ClassicationNCELoss = -torch.sum(y_onehot * log_prob_labeled,
                                                 dim=1).mean()
                loss += (ClassicationNCELoss * self.zsClassificationNCELoss)

            if (sum(unsupervised_mask) > 0):  # if unlabeled data > 0:
                mu_unlabeled = em_mu[unsupervised_mask]
                logvar_unlabeled = em_logvar[unsupervised_mask]
                log_prob_unlabeled = zs[unsupervised_mask]

                # -Elbo for unlabeled data (U(x))
                UnsupervisedLoss, UnsupervisedKDL = self._U(log_prob_unlabeled,
                                                            mu_unlabeled,
                                                            logvar_unlabeled,
                                                            beta=em_kl_weight)
                loss += UnsupervisedLoss

        if True:  # AuxEmotionNet loss
            aux_zs, aux_em_mu, aux_em_logvar, aux_em_params = [
                x.squeeze(1) for x in aux_em_package
            ]
            PredDistMSE = PredDistMAE = AuxClassicationMAELoss = AuxClassicationMSELoss = AuxClassicationNCELoss = torch.tensor(
                0)

            # pred em_zu dist param Loss
            PredDistMSE = nn.MSELoss()(aux_em_params, em_params)
            PredDistMAE = nn.L1Loss()(aux_em_params, em_params)
            loss += (PredDistMSE * self.predzu_MSE_weight +
                     PredDistMAE * self.predzu_MAE_weight)

            # Aux Zs Classification Loss
            if (sum(supervised_mask) > 0):  # if labeled data > 0:
                log_prob_labeled = aux_zs[supervised_mask]
                prob_labeled = log_prob_labeled.exp()

                AuxClassicationMAELoss = nn.L1Loss(reduction='sum')(
                    prob_labeled, y_onehot) / Bsz
                loss += (AuxClassicationMAELoss *
                         self.auxClassificationMAELoss)

                AuxClassicationMSELoss = nn.MSELoss(reduction='sum')(
                    prob_labeled, y_onehot) / Bsz
                loss += (AuxClassicationMSELoss *
                         self.auxClassificationMSELoss)

                AuxClassicationNCELoss = -torch.sum(
                    y_onehot * log_prob_labeled, dim=1).mean()
                loss += (AuxClassicationNCELoss *
                         self.auxClassificationNCELoss)

        if True:  # Diagonal Attention Guiding
            AttentionLoss = self.guided_att(
                alignments[preserve_decoder == 0.0],
                text_lengths[preserve_decoder == 0.0],
                output_lengths[preserve_decoder == 0.0])
            loss += (AttentionLoss * self.DiagonalGuidedAttention_scalar)

        reduced_d_loss = reduced_avg_fakeness = avg_fakeness = 0.0
        GAN_Spect_MAE = adv_postnet_loss = torch.tensor(0.)
        if True and gan_package[0] is not None:
            real_labels = torch.zeros(mel_target_not_masked.shape[0],
                                      device=loss.device,
                                      dtype=loss.dtype)  # [B]
            fake_labels = torch.ones(mel_target_not_masked.shape[0],
                                     device=loss.device,
                                     dtype=loss.dtype)  # [B]

            mel_outputs_adv, speaker_embed, *_ = gan_package
            if self.masked_select:
                fill_mask = mel_target_not_masked == 0.0
                mel_outputs_adv = mel_outputs_adv.clone()
                mel_outputs_adv.masked_fill_(fill_mask, 0.0)
                mel_outputs_adv_masked = torch.masked_select(
                    mel_outputs_adv, mask)
                mel_out_not_masked = mel_out_not_masked.clone()
                mel_out_not_masked.masked_fill_(fill_mask, 0.0)
                if mel_out_postnet is not None:
                    mel_out_postnet_not_masked = mel_out_postnet_not_masked.clone(
                    )
                    mel_out_postnet_not_masked.masked_fill_(fill_mask, 0.0)

            # spectrograms [B, n_mel, dec_T]
            # mel_target_not_masked
            # mel_out_not_masked
            # mel_out_postnet_not_masked
            # mel_outputs_adv

            speaker_embed = speaker_embed.unsqueeze(2).repeat(
                1, 1, dec_T)  # [B, embed] -> [B, embed, dec_T]
            fake_pred_fakeness = model_d(
                mel_outputs_adv, speaker_embed.detach()
            )  # should speaker_embed be attached to the computational graph? Not sure atm
            avg_fakeness = fake_pred_fakeness.mean()  # metric for Tensorboard
            # Tacotron2 Optimizer / Loss
            if hparams.distributed_run:
                reduced_avg_fakeness = reduce_tensor(avg_fakeness.data, n_gpus).item()
            else:
                reduced_avg_fakeness = avg_fakeness.item()

            adv_postnet_loss = nn.BCELoss()(
                fake_pred_fakeness,
                real_labels)  # [B] -> [] calc loss to decrease fakeness of model
            GAN_Spect_MAE = nn.L1Loss()(mel_outputs_adv_masked, mel_target)
            if reduced_avg_fakeness > 0.4:
                loss += (adv_postnet_loss * self.adv_postnet_scalar)
                loss += (GAN_Spect_MAE *
                         (self.adv_postnet_scalar *
                          self.adv_postnet_reconstruction_weight))

        # Tacotron2 Optimizer / Loss
        if hparams.distributed_run:
            reduced_loss = reduce_tensor(loss.data, n_gpus).item()
            reduced_gate_loss = reduce_tensor(gate_loss.data, n_gpus).item()
        else:
            reduced_loss = loss.item()
            reduced_gate_loss = gate_loss.item()

        if optimizer is not None:
            if hparams.fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if hparams.fp16_run:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), grad_clip_thresh)
                is_overflow = math.isinf(grad_norm) or math.isnan(grad_norm)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), grad_clip_thresh)

            optimizer.step()

        # (optional) Discriminator Optimizer / Loss
        if True and gan_package[0] is not None:
            if optimizer_d is not None:
                optimizer_d.zero_grad()

            # spectrograms [B, n_mel, dec_T]
            # mel_target_not_masked
            # mel_out_not_masked
            # mel_out_postnet_not_masked
            # mel_outputs_adv

            fake_pred_fakeness = model_d(mel_outputs_adv.detach(),
                                         speaker_embed.detach())
            fake_d_loss = nn.BCELoss()(
                fake_pred_fakeness, fake_labels
            )  # [B] -> [] loss to increase the discriminator's predicted fakeness for fake samples

            real_pred_fakeness = model_d(mel_target_not_masked.detach(),
                                         speaker_embed.detach())
            real_d_loss = nn.BCELoss()(
                real_pred_fakeness, real_labels
            )  # [B] -> [] loss to decrease the discriminator's predicted fakeness for real samples

            if self.dis_postnet_scalar and mel_out_postnet is not None:
                fake_pred_fakeness = model_d(
                    mel_out_postnet_not_masked.detach(),
                    speaker_embed.detach())
                fake_d_loss += self.dis_postnet_scalar * nn.BCELoss()(
                    fake_pred_fakeness, fake_labels
                )  # [B] -> [] loss to increase the discriminator's predicted fakeness for fake samples

            if self.dis_spect_scalar:
                fake_pred_fakeness = model_d(mel_out_not_masked.detach(),
                                             speaker_embed.detach())
                fake_d_loss += self.dis_spect_scalar * nn.BCELoss()(
                    fake_pred_fakeness, fake_labels
                )  # [B] -> [] loss to increase the discriminator's predicted fakeness for fake samples

            d_loss = (real_d_loss + fake_d_loss) * (self.adv_postnet_scalar *
                                                    0.5)
            reduced_d_loss = reduce_tensor(
                d_loss.data,
                n_gpus).item() if hparams.distributed_run else d_loss.item()

            if optimizer_d is not None and reduced_avg_fakeness < 0.85:
                if hparams.fp16_run:
                    with amp.scale_loss(d_loss, optimizer_d) as scaled_loss:
                        scaled_loss.backward()
                else:
                    d_loss.backward()

                if hparams.fp16_run:
                    grad_norm_d = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer_d), grad_clip_thresh)
                    is_overflow = math.isinf(grad_norm_d) or math.isnan(
                        grad_norm_d)
                else:
                    grad_norm_d = torch.nn.utils.clip_grad_norm_(
                        model_d.parameters(), grad_clip_thresh)

                optimizer_d.step()

        with torch.no_grad():  # debug/fun
            S_Bsz = supervised_mask.sum().item()
            U_Bsz = unsupervised_mask.sum().item()
            ClassicationAccStr = 'N/A'
            Top1ClassificationAcc = 0.0
            if S_Bsz > 0:
                Top1ClassificationAcc = (torch.argmax(
                    log_prob_labeled.exp(), dim=1) == torch.argmax(
                        y_onehot,
                        dim=1)).float().sum().item() / S_Bsz  # top-1 accuracy
                self.AvgClassAcc = self.AvgClassAcc * 0.95 + Top1ClassificationAcc * 0.05
                ClassicationAccStr = round(Top1ClassificationAcc * 100, 2)

        print("            Total loss = ",
              loss.item(),
              '\n',
              "             Spect LLL = ",
              spec_LL.item(),
              '\n',
              "     Postnet Spect LLL = ",
              postnet_LL.item(),
              '\n',
              "             Spect MSE = ",
              spec_MSE.item(),
              '\n',
              "             Spect MAE = ",
              spec_MAE.item(),
              '\n',
              "            Spect SMAE = ",
              spec_SMAE.item(),
              '\n',
              "     Postnet Spect MSE = ",
              postnet_MSE.item(),
              '\n',
              "     Postnet Spect MAE = ",
              postnet_MAE.item(),
              '\n',
              "    Postnet Spect SMAE = ",
              postnet_SMAE.item(),
              '\n',
              "              Gate BCE = ",
              gate_loss.item(),
              '\n',
              "                sylKLD = ",
              sylKLD.item(),
              '\n',
              "              sylpsMSE = ",
              sylpsMSE.item(),
              '\n',
              "              sylpsMAE = ",
              sylpsMAE.item(),
              '\n',
              "        SupervisedLoss = ",
              SupervisedLoss.item(),
              '\n',
              "         SupervisedKDL = ",
              SupervisedKDL.item(),
              '\n',
              "      UnsupervisedLoss = ",
              UnsupervisedLoss.item(),
              '\n',
              "       UnsupervisedKDL = ",
              UnsupervisedKDL.item(),
              '\n',
              "   ClassicationMSELoss = ",
              ClassicationMSELoss.item(),
              '\n',
              "   ClassicationMAELoss = ",
              ClassicationMAELoss.item(),
              '\n',
              "   ClassicationNCELoss = ",
              ClassicationNCELoss.item(),
              '\n',
              "AuxClassicationMSELoss = ",
              AuxClassicationMSELoss.item(),
              '\n',
              "AuxClassicationMAELoss = ",
              AuxClassicationMAELoss.item(),
              '\n',
              "AuxClassicationNCELoss = ",
              AuxClassicationNCELoss.item(),
              '\n',
              "      Predicted Zu MSE = ",
              PredDistMSE.item(),
              '\n',
              "      Predicted Zu MAE = ",
              PredDistMAE.item(),
              '\n',
              "     DiagAttentionLoss = ",
              AttentionLoss.item(),
              '\n',
              "       PredAvgFakeness = ",
              reduced_avg_fakeness,
              '\n',
              "          GeneratorMAE = ",
              GAN_Spect_MAE.item(),
              '\n',
              "         GeneratorLoss = ",
              adv_postnet_loss.item(),
              '\n',
              "     DiscriminatorLoss = ",
              reduced_d_loss / self.adv_postnet_scalar,
              '\n',
              "      ClassicationAcc  = ",
              ClassicationAccStr,
              '%\n',
              "      AvgClassicatAcc  = ",
              round(self.AvgClassAcc * 100, 2),
              '%\n',
              "      Total Batch Size = ",
              Bsz,
              '\n',
              "      Super Batch Size = ",
              S_Bsz,
              '\n',
              "      UnSup Batch Size = ",
              U_Bsz,
              '\n',
              sep='')

        loss_terms = [
            [loss.item(), 1.0],
            [spec_MSE.item(), self.melout_MSE_scalar],
            [spec_MAE.item(), self.melout_MAE_scalar],
            [spec_SMAE.item(), self.melout_SMAE_scalar],
            [postnet_MSE.item(), self.postnet_MSE_scalar],
            [postnet_MAE.item(), self.postnet_MAE_scalar],
            [postnet_SMAE.item(), self.postnet_SMAE_scalar],
            [gate_loss.item(), 1.0],
            [sylKLD.item(), self.syl_KDL_weight],
            [sylpsMSE.item(), self.pred_sylpsMSE_weight],
            [sylpsMAE.item(), self.pred_sylpsMAE_weight],
            [SupervisedLoss.item(), 1.0],
            [SupervisedKDL.item(), em_kl_weight * 0.5],
            [UnsupervisedLoss.item(), 1.0],
            [UnsupervisedKDL.item(), em_kl_weight * 0.5],
            [ClassicationMSELoss.item(), self.zsClassificationMSELoss],
            [ClassicationMAELoss.item(), self.zsClassificationMAELoss],
            [ClassicationNCELoss.item(), self.zsClassificationNCELoss],
            [AuxClassicationMSELoss.item(), self.auxClassificationMSELoss],
            [AuxClassicationMAELoss.item(), self.auxClassificationMAELoss],
            [AuxClassicationNCELoss.item(), self.auxClassificationNCELoss],
            [PredDistMSE.item(), self.predzu_MSE_weight],
            [PredDistMAE.item(), self.predzu_MAE_weight],
            [Top1ClassificationAcc, 1.0],
            [reduced_avg_fakeness, 1.0],
            [adv_postnet_loss.item(), self.adv_postnet_scalar],
            [
                reduced_d_loss / self.adv_postnet_scalar,
                self.adv_postnet_scalar
            ],
        ]
        return loss, gate_loss, loss_terms, reduced_loss, reduced_gate_loss, grad_norm, is_overflow
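NormalLLLoss above is not shown in this snippet; judging from its call signature (predicted mean, predicted log-variance, target) it presumably computes a Gaussian negative log-likelihood. A minimal sketch of that reading (an assumption, not the original implementation), dropping the constant 0.5*log(2*pi) term:

import torch

def NormalLLLoss(mu, logvar, target):
    # element-wise NLL of target under N(mu, exp(logvar)), up to an additive constant
    nll = 0.5 * (logvar + (target - mu).pow(2) * torch.exp(-logvar))
    return nll.mean()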