Ejemplo n.º 1
0
    def outputs(self, text, adj_matrix, melspec, text_lengths, mel_lengths):
        """Teacher-forced forward pass of a dual-encoder Transformer TTS model.

        Encodes `text` with a self-attention stack (Encoder1) and a second
        encoder conditioned on `adj_matrix` (Encoder2), fuses both memories
        through `self.linear`, then decodes against the one-frame
        right-shifted `melspec`.

        Args:
            text: [B, L] token ids.
            adj_matrix: adjacency tensor for Encoder2 (reshaped below).
            melspec: [B, n_mel, T] target mel spectrogram (teacher forcing).
            text_lengths: [B] valid token counts per example.
            mel_lengths: [B] valid frame counts per example.

        Returns:
            (mel_out, mel_out_post, dec_alignments, enc_dec_alignments,
            gate_out).  NOTE(review): `enc_alignments` is collected but not
            returned — confirm that is intentional.
        """
        ### Size ###
        # B = batch, L = token count, T = mel frame count.
        B, L, T = text.size(0), text.size(1), melspec.size(2)
        # Duplicate the adjacency along dim 1, then flatten it to
        # [-1, L_max, L_max * 3 * 2]; the exact layout is Encoder2's
        # contract — TODO confirm against Encoder2's implementation.
        adj_matrix = torch.cat([adj_matrix, adj_matrix], dim=1)
        adj_matrix = adj_matrix.transpose(1, 2).reshape(
            -1,
            text_lengths.max().item(),
            text_lengths.max().item() * 3 * 2)

        ### Prepare Encoder Input ###
        # Embed, go time-major [L, B, C], add positional encoding, dropout.
        encoder_input = self.Embedding(text).transpose(0, 1)
        encoder_input += self.pe[:L].unsqueeze(1)
        encoder_input = self.dropout(encoder_input)

        ### Transformer Encoder ###
        memory1 = encoder_input
        enc_alignments = []
        text_mask = get_mask_from_lengths(text_lengths)
        for layer in self.Encoder1:
            memory1, enc_align = layer(memory1, src_key_padding_mask=text_mask)
            enc_alignments.append(enc_align.unsqueeze(1))
        enc_alignments = torch.cat(enc_alignments, 1)

        # Second encoder consumes the same embedded input plus the adjacency;
        # the two memories are fused back to model width by self.linear.
        memory2 = self.Encoder2(encoder_input, adj_matrix)
        memory = self.linear(torch.cat([memory1, memory2], dim=-1))

        ### Prepare Decoder Input ###
        # Shift mels right by one frame so step t only sees frames < t.
        mel_input = F.pad(melspec, (1, -1)).transpose(1, 2)
        decoder_input = self.Prenet_D(mel_input).transpose(0, 1)
        decoder_input += self.pe[:T].unsqueeze(1)
        decoder_input = self.dropout(decoder_input)

        ### Prepare Masks ###
        mel_mask = get_mask_from_lengths(mel_lengths)
        # Causal mask: -inf above the diagonal, 0 on/below it.
        diag_mask = torch.triu(melspec.new_ones(T, T)).transpose(0, 1)
        diag_mask[diag_mask == 0] = -float('inf')
        diag_mask[diag_mask == 1] = 0

        ### Decoding ###
        tgt = decoder_input
        dec_alignments, enc_dec_alignments = [], []
        for layer in self.Decoder:
            tgt, dec_align, enc_dec_align = layer(
                tgt,
                memory,
                tgt_mask=diag_mask,
                tgt_key_padding_mask=mel_mask,
                memory_key_padding_mask=text_mask)
            dec_alignments.append(dec_align.unsqueeze(1))
            enc_dec_alignments.append(enc_dec_align.unsqueeze(1))
        dec_alignments = torch.cat(dec_alignments, 1)
        enc_dec_alignments = torch.cat(enc_dec_alignments, 1)

        ### Projection + PostNet ###
        mel_out = self.Projection(tgt.transpose(0, 1)).transpose(1, 2)
        # Residual postnet refinement of the projected mels.
        mel_out_post = self.Postnet(mel_out) + mel_out

        # Per-frame stop-token logits.
        gate_out = self.Stop(mel_out.transpose(1, 2)).squeeze(-1)

        return mel_out, mel_out_post, dec_alignments, enc_dec_alignments, gate_out
Ejemplo n.º 2
0
    def forward(self, pred, target, lengths):
        """Compute masked FastSpeech2-style training losses.

        Args:
            pred: (mel_out, duration_out, pitch_out, energy_out) tensors.
            target: (mel_target, duration_target, pitch_target, energy_target).
            lengths: (text_lengths, mel_lengths) used to build padding masks.

        Returns:
            (mel_loss, duration_loss, pitch_loss, energy_loss).
        """
        mel_out, duration_out, pitch_out, energy_out = pred
        mel_target, duration_target, pitch_target, energy_target = target
        text_lengths, mel_lengths = lengths

        # Shape sanity checks (NOTE: asserts are stripped under `python -O`).
        assert (mel_out.shape == mel_target.shape)
        assert (duration_out.shape == duration_target.shape)
        # Fixed: the original compared pitch_out against itself (always true),
        # so the pitch target shape was never actually validated.
        assert (pitch_out.shape == pitch_target.shape)
        assert (energy_out.shape == energy_target.shape)

        # True at valid (non-padded) positions; pitch/energy share the mel mask.
        mel_mask = ~get_mask_from_lengths(
            mel_lengths)  # same for pitch and energy
        duration_mask = ~get_mask_from_lengths(text_lengths)
        # [B, 1, T] — broadcasts across the channel axis; assumes pitch/energy
        # tensors are broadcast-compatible with this shape (TODO confirm).
        frame_mask = mel_mask.unsqueeze(1)

        mel_target = mel_target.masked_select(frame_mask)
        mel_out = mel_out.masked_select(frame_mask)

        duration_target = duration_target.masked_select(duration_mask)
        duration_out = duration_out.masked_select(duration_mask)

        pitch_target = pitch_target.masked_select(frame_mask)
        pitch_out = pitch_out.masked_select(frame_mask)

        energy_target = energy_target.masked_select(frame_mask)
        energy_out = energy_out.masked_select(frame_mask)

        mel_loss = nn.L1Loss()(mel_out, mel_target)
        duration_loss = nn.MSELoss()(duration_out, duration_target)
        pitch_loss = nn.MSELoss()(pitch_out, pitch_target)
        energy_loss = nn.MSELoss()(energy_out, energy_target)

        return mel_loss, duration_loss, pitch_loss, energy_loss
Ejemplo n.º 3
0
    def outputs(self, text, alignments, text_lengths, mel_lengths):
        """FastSpeech-style synthesis pass.

        Returns (mel_out, duration_out, durations): the predicted mel
        spectrogram, the predicted token durations, and the durations
        extracted from the supplied `alignments`.
        """
        # Sequence sizes: token count and mel frame count.
        token_len, frame_len = text.size(1), alignments.size(1)

        # Embed tokens, switch to time-major, add scaled positional encoding.
        enc_states = self.Embedding(text).transpose(0, 1)
        enc_states = self.dropout(
            enc_states + self.alpha1 * self.pe[:token_len].unsqueeze(1))

        # Padding masks for both domains.
        text_mask = get_mask_from_lengths(text_lengths)
        mel_mask = get_mask_from_lengths(mel_lengths)

        # Encoder FFT stack (per-layer attention maps are discarded).
        for enc_layer in self.Encoder:
            enc_states, _ = enc_layer(enc_states,
                                      src_key_padding_mask=text_mask)

        # Expand encoder states to frame rate with the length regulator.
        durations = self.align2duration(alignments, mel_mask)
        frame_states = self.LR(enc_states, durations)
        frame_states = self.dropout(
            frame_states + self.alpha2 * self.pe[:frame_len].unsqueeze(1))

        # Decoder FFT stack.
        for dec_layer in self.Decoder:
            frame_states, _ = dec_layer(frame_states,
                                        src_key_padding_mask=mel_mask)

        # Project to mel channels and predict durations from the
        # un-expanded encoder states.
        mel_out = self.Projection(
            frame_states.transpose(0, 1)).transpose(1, 2)
        duration_out = self.Duration(enc_states.permute(1, 2, 0))

        return mel_out, duration_out, durations
Ejemplo n.º 4
0
 def outputs(self, text, durations, text_lengths, mel_lengths):
     """FastSpeech2-style forward pass driven by ground-truth durations.

     Returns (mel_out, duration_out, durations, pitch_out, energy_out):
     the predicted mel spectrogram, the predicted token durations, the
     input durations passed through unchanged, and the frame-level pitch
     and energy predictions.
     """
     # L = token count, T = longest mel length in the batch.
     L = text.size(1)
     T = mel_lengths.max().item()

     # Token embedding -> time-major with scaled positional encoding.
     enc_states = self.Embedding(text).transpose(0, 1)
     enc_states = self.dropout(
         enc_states + self.alpha1 * self.pe[:L].unsqueeze(1))

     text_mask = get_mask_from_lengths(text_lengths)
     mel_mask = get_mask_from_lengths(mel_lengths)

     # Encoder FFT stack.
     for enc_layer in self.Encoder:
         enc_states, _ = enc_layer(enc_states,
                                   src_key_padding_mask=text_mask)

     # Duration predictor consumes the un-expanded encoder output.
     duration_out = self.Duration(enc_states.permute(1, 2, 0))

     # Length-regulate to frame rate using the provided durations.
     frame_states = self.LR(enc_states, durations)

     # Frame-level pitch and energy predictions.
     pitch_out = self.Pitch(frame_states.permute(1, 2, 0))
     energy_out = self.Energy(frame_states.permute(1, 2, 0))

     # Quantize pitch/energy to one-hot embeddings and add them
     # (time-major) to the expanded states to form the decoder input.
     pitch_one_hot = pitch_to_one_hot(pitch_out)
     energy_one_hot = energy_to_one_hot(energy_out)
     frame_states = (frame_states
                     + pitch_one_hot.transpose(1, 0)
                     + energy_one_hot.transpose(1, 0))

     frame_states = self.dropout(
         frame_states + self.alpha2 * self.pe[:T].unsqueeze(1))

     # Decoder FFT stack.
     for dec_layer in self.Decoder:
         frame_states, _ = dec_layer(frame_states,
                                     src_key_padding_mask=mel_mask)

     # Project to mel channels: [B, n_mel, T].
     mel_out = self.Projection(
         frame_states.transpose(0, 1)).transpose(1, 2)

     return mel_out, duration_out, durations, pitch_out, energy_out
Ejemplo n.º 5
0
    def forward(self, text, melspec, align, text_lengths, mel_lengths, criterion, stage):
        """Stage-dispatched training forward pass.

        stage 0: MDN alignment loss only.
        stage 1: L1 mel loss using the externally supplied `align`.
        stage 2: MDN loss plus L1 mel loss with a Viterbi-decoded alignment.
        stage 3: duration-predictor MSE loss in the log domain.

        Inputs are first truncated to the longest valid length in the batch.
        """
        text = text[:,:text_lengths.max().item()]
        melspec = melspec[:,:,:mel_lengths.max().item()]
        
        if stage==0:
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mu_sigma = self.get_mu_sigma(hidden_states)
            mdn_loss, _ = criterion(mu_sigma, melspec, text_lengths, mel_lengths)
            return mdn_loss
        
        elif stage==1:
            # Crop the provided alignment to the truncated text/mel sizes.
            align = align[:, :text_lengths.max().item(), :mel_lengths.max().item()]
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mel_out = self.get_melspec(hidden_states, align, mel_lengths)
            
            # Compare only valid (non-padded) frames.
            mel_mask = ~get_mask_from_lengths(mel_lengths)
            melspec = melspec.masked_select(mel_mask.unsqueeze(1))
            mel_out = mel_out.masked_select(mel_mask.unsqueeze(1))
            fft_loss = nn.L1Loss()(mel_out, melspec)
            
            return fft_loss
        
        elif stage==2:
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mu_sigma = self.get_mu_sigma(hidden_states)
            mdn_loss, log_prob_matrix = criterion(mu_sigma, melspec, text_lengths, mel_lengths)
            
            # Hard alignment decoded from the MDN log-probability matrix.
            align = self.viterbi(log_prob_matrix, text_lengths, mel_lengths) # B, T
            mel_out = self.get_melspec(hidden_states, align, mel_lengths)
            
            mel_mask = ~get_mask_from_lengths(mel_lengths)
            melspec = melspec.masked_select(mel_mask.unsqueeze(1))
            mel_out = mel_out.masked_select(mel_mask.unsqueeze(1))
            fft_loss = nn.L1Loss()(mel_out, melspec)
            
            return mdn_loss + fft_loss
        
        elif stage==3:
            align = align[:, :text_lengths.max().item(), :mel_lengths.max().item()]
            duration_out = self.get_duration(text, text_lengths) # gradient cut
            # Target duration of a token = number of frames aligned to it.
            duration_target = align.sum(-1)
            
            duration_mask = ~get_mask_from_lengths(text_lengths)
            duration_target = duration_target.masked_select(duration_mask)
            duration_out = duration_out.masked_select(duration_mask)
            # NOTE(review): log() is undefined for zero durations — this
            # assumes every valid token receives >= 1 frame; confirm upstream.
            duration_loss = nn.MSELoss()(torch.log(duration_out), torch.log(duration_target))

            return duration_loss
Ejemplo n.º 6
0
    def outputs(self, text, melspec, text_lengths, mel_lengths):
        """Teacher-forced Transformer TTS pass.

        Returns (mel_out, mel_out_post, enc_alignments, dec_alignments,
        enc_dec_alignments, gate_out).
        """
        # L = token count, T = mel frame count.
        L, T = text.size(1), melspec.size(2)

        # --- Encoder input: embedding + scaled positional encoding ---
        enc_in = self.Embedding(text).transpose(0, 1)
        enc_in = self.dropout(enc_in + self.alpha1 * self.pe[:L].unsqueeze(1))

        # --- Decoder input: mels shifted right one frame, prenet, PE ---
        shifted_mels = F.pad(melspec, (1, -1)).transpose(1, 2)
        dec_in = self.Prenet_D(shifted_mels).transpose(0, 1)
        dec_in = self.dropout(dec_in + self.alpha2 * self.pe[:T].unsqueeze(1))

        # --- Masks: padding masks plus a causal lower-triangular mask ---
        text_mask = get_mask_from_lengths(text_lengths)
        mel_mask = get_mask_from_lengths(mel_lengths)
        causal = melspec.new_ones(T, T).triu().t()
        causal[causal == 0] = -float('inf')
        causal[causal == 1] = 0

        # --- Encoder stack, collecting per-layer self-attention maps ---
        memory = enc_in
        enc_alignments = []
        for enc_layer in self.Encoder:
            memory, enc_att = enc_layer(memory, src_key_padding_mask=text_mask)
            enc_alignments.append(enc_att.unsqueeze(1))
        enc_alignments = torch.cat(enc_alignments, 1)

        # --- Decoder stack, collecting self- and cross-attention maps ---
        tgt = dec_in
        dec_alignments, enc_dec_alignments = [], []
        for dec_layer in self.Decoder:
            tgt, dec_att, cross_att = dec_layer(tgt,
                                                memory,
                                                tgt_mask=causal,
                                                tgt_key_padding_mask=mel_mask,
                                                memory_key_padding_mask=text_mask)
            dec_alignments.append(dec_att.unsqueeze(1))
            enc_dec_alignments.append(cross_att.unsqueeze(1))
        dec_alignments = torch.cat(dec_alignments, 1)
        enc_dec_alignments = torch.cat(enc_dec_alignments, 1)

        # --- Mel projection, residual postnet, stop-token logits ---
        mel_out = self.Projection(tgt.transpose(0, 1)).transpose(1, 2)
        mel_out_post = self.Postnet(mel_out) + mel_out
        gate_out = self.Stop(mel_out.transpose(1, 2)).squeeze(-1)

        return mel_out, mel_out_post, enc_alignments, dec_alignments, enc_dec_alignments, gate_out
Ejemplo n.º 7
0
    def mask_decoder_output(self, decoder_outputs, output_lengths=None):
        """Zero out decoder outputs at padded time steps.

        Args:
            decoder_outputs: tensor masked in place (via `*=`) along its
                time axis.
            output_lengths: per-example valid lengths; when None (or when
                `self.mask_padding` is off) the outputs pass through untouched.

        Returns:
            The (possibly masked) decoder_outputs tensor.
        """
        if self.mask_padding and output_lengths is not None:
            # Fixed idiom: the original negated the mask twice (~mask, then
            # ~mask again) — the keep-mask is the lengths mask itself.
            keep_mask = utl.get_mask_from_lengths(output_lengths)
            decoder_outputs *= keep_mask.float().unsqueeze(1)

        return decoder_outputs
Ejemplo n.º 8
0
    def forward(self, memory, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        # Prepend the initial <go> frame and run all teacher-forcing inputs
        # through the prenet in one shot.
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            # Fixed: the original appended to `gate_output` (the per-step
            # tensor) instead of the `gate_outputs` accumulator, so the gate
            # list reached parse_decoder_outputs empty.
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments
        )

        return mel_outputs, gate_outputs, alignments
Ejemplo n.º 9
0
def batch_diagonal_guide(text_lengths, mel_lengths, g=0.2):
    """Build the diagonal attention-guide matrix W of shape (B, M, T).

    Weights grow toward 1 far from the text/mel diagonal; length-masked
    positions are zeroed before the Gaussian penalty is applied, so they
    end up with weight 0.
    """
    device = text_lengths.device

    # Normalized positions along each axis, per example.
    text_pos = torch.arange(text_lengths.max(), dtype=torch.float32, device=device)
    text_pos = text_pos.unsqueeze(0) / text_lengths.unsqueeze(1)  # (B, T)

    mel_pos = torch.arange(mel_lengths.max(), dtype=torch.float32, device=device)
    mel_pos = mel_pos.unsqueeze(0) / mel_lengths.unsqueeze(1)  # (B, M)

    # Signed distance from the diagonal: (B, M, T).
    grid = text_pos.unsqueeze(1) - mel_pos.unsqueeze(2)

    # Apply the text- and mel-length masks in place.
    grid.transpose(2, 1)[~get_mask_from_lengths(text_lengths)] = 0.
    grid[~get_mask_from_lengths(mel_lengths)] = 0.

    return 1 - torch.exp(-grid ** 2 / (2 * g ** 2))
Ejemplo n.º 10
0
    def forward(self, x, lengths):
        """Apply the FFT layer stack to batch-major input `x`.

        Returns the transformed batch-major tensor and the per-layer
        attention maps stacked along dim 1.
        """
        pad_mask = get_mask_from_lengths(lengths)
        hidden = x.transpose(0, 1)  # batch-major -> time-major

        attn_maps = []
        for fft_layer in self.FFT_layers:
            hidden, attn = fft_layer(hidden, src_key_padding_mask=pad_mask)
            attn_maps.append(attn.unsqueeze(1))

        return hidden.transpose(0, 1), torch.cat(attn_maps, 1)
Ejemplo n.º 11
0
    def align2duration(self, alignments, mel_lengths):
        """Convert soft attention alignments into per-token durations.

        Args:
            alignments: [B, T_mel, T_text] attention weights (assumed layout
                based on the dims used below — TODO confirm).
            mel_lengths: per-example mel frame counts used to mask padding.

        Returns:
            durations: [B, T_text] — number of frames argmax-assigned to
            each text position.
        """
        # Fixed idiom: build the index range directly on the right
        # device/dtype instead of materializing a CPU arange and copying
        # it through Tensor.new_tensor.
        ids = torch.arange(alignments.size(2),
                           device=alignments.device,
                           dtype=alignments.dtype)
        # Index of the most-attended text position for every mel frame.
        max_ids = torch.max(alignments, dim=2)[1].unsqueeze(-1)
        mel_mask = get_mask_from_lengths(mel_lengths)
        # One-hot over text positions per frame; masked frames contribute 0.
        one_hot = 1.0 * (ids == max_ids)
        one_hot.masked_fill_(mel_mask.unsqueeze(2), 0)

        # Duration of a token = number of frames whose argmax chose it.
        durations = torch.sum(one_hot, dim=1)

        return durations
Ejemplo n.º 12
0
    def parse_output(self, outputs, output_lengths=None):
        """Mask padded frames in the model outputs, in place.

        `outputs` is [mels, mels_postnet, gate]; mels are zero-filled past
        each sequence's length and the gate energies at padded steps are
        forced to a large value (1e3). Returns `outputs`.
        """
        if self.mask_padding and output_lengths is not None:
            # Padding mask broadcast over the mel channels: [B, n_mel, T].
            pad = ~get_mask_from_lengths(output_lengths)
            pad = pad.expand(self.n_mel_channels, pad.size(0), pad.size(1))
            pad = pad.permute(1, 0, 2)

            outputs[0].data.masked_fill_(pad, 0.0)
            outputs[1].data.masked_fill_(pad, 0.0)
            outputs[2].data.masked_fill_(pad[:, 0, :], 1e3)  # gate energies

        return outputs
Ejemplo n.º 13
0
    def parse_output(self, outputs, output_lengths=None, text_lengths=None):
        """Mask padded positions in the model outputs, in place.

        Mels and postnet mels are zeroed, gate energies are pushed to 1e3,
        and — when text lengths are supplied — alignment entries outside the
        valid 2-D region are zeroed. Returns `outputs`.
        """
        if self.mask_padding and output_lengths is not None:
            pad = ~utl.get_mask_from_lengths(output_lengths)
            outputs.mels.data.masked_fill_(pad.unsqueeze(1), 0.0)
            outputs.mels_postnet.data.masked_fill_(pad.unsqueeze(1), 0.0)
            outputs.gate.data.masked_fill_(pad, 1e3)

            if text_lengths is not None:
                align_pad = ~utl.get_mask_3d(output_lengths, text_lengths)
                outputs.alignments.data.masked_fill_(align_pad, 0.0)

        return outputs
Ejemplo n.º 14
0
    def parse_output(self, outputs, output_lengths=None, text_lengths=None):
        """Mask padded frames in mels/postnet/gate (and alignments), in place."""
        if self.mask_padding and output_lengths is not None:
            pad = ~utl.get_mask_from_lengths(output_lengths)
            # Broadcast the [B, T] padding mask over the mel channels
            # to [B, n_mel, T].
            channel_pad = pad.expand(self.n_mel_channels, pad.size(0), pad.size(1))
            channel_pad = channel_pad.permute(1, 0, 2)

            outputs.mels.data.masked_fill_(channel_pad, 0.0)
            outputs.mels_postnet.data.masked_fill_(channel_pad, 0.0)
            outputs.gate.data.masked_fill_(channel_pad[:, 0, :], 1e3)
            if text_lengths is not None:
                outputs.alignments.data.masked_fill_(
                    ~utl.get_mask_3d(output_lengths, text_lengths), 0.0)

        return outputs
Ejemplo n.º 15
0
    def forward(self, memory, decoder_inputs, memory_lengths, p_teacher_forcing=1.0):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.
        p_teacher_forcing: probability of feeding the ground-truth frame at
            each step (1.0 = always teacher-force; below 1.0 the model's own
            previous output may be fed back instead).

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """

        # Prepend the initial <go> frame and run all teacher-forcing inputs
        # through the prenet.
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(
            memory, mask=~utl.get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments, decoder_outputs = [], [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            # Scheduled sampling: per step, pick the ground-truth frame or
            # the model's previous output (first step always uses the
            # go-frame input).
            if prob2bool(p_teacher_forcing) or len(mel_outputs) == 0:
                decoder_input = decoder_inputs[len(mel_outputs)]
            else:
                decoder_input = self.prenet(mel_outputs[-1])

            mel_output, gate_output, attention_weights, decoder_output = self.decode(decoder_input)

            mel_outputs.append(mel_output)
            gate_outputs.append(gate_output)
            alignments.append(attention_weights)

            # decoder_output may be None when the decoder does not expose
            # its hidden state; collect it only when present.
            if decoder_output is not None:
                decoder_outputs.append(decoder_output)

        mel_outputs, gate_outputs, alignments, decoder_outputs = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments, decoder_outputs)

        return mel_outputs, gate_outputs, alignments, decoder_outputs
Ejemplo n.º 16
0
    def forward(self, pred, target, guide):
        """Masked Transformer-TTS training losses.

        Returns (mel_loss, bce_loss, guide_loss): L1 on pre- and post-net
        mels against the target, weighted BCE on the stop gate, and the
        attention-guide loss.
        """
        mel_out, mel_out_post, gate_out = pred
        mel_target, gate_target = target
        alignments, text_lengths, mel_lengths = guide

        # Keep only valid (non-padded) positions.
        valid = ~get_mask_from_lengths(mel_lengths)
        frame_valid = valid.unsqueeze(1)

        mel_target = mel_target.masked_select(frame_valid)
        mel_out_post = mel_out_post.masked_select(frame_valid)
        mel_out = mel_out.masked_select(frame_valid)

        gate_target = gate_target.masked_select(valid)
        gate_out = gate_out.masked_select(valid)

        l1 = nn.L1Loss()
        mel_loss = l1(mel_out, mel_target) + l1(mel_out_post, mel_target)
        # Stop frames are rare, hence the positive-class weight of 5.
        bce_loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(5.0))(
            gate_out, gate_target)
        guide_loss = self.guide_loss(alignments, text_lengths, mel_lengths)

        return mel_loss, bce_loss, guide_loss
Ejemplo n.º 17
0
    def forward(self,
                text,
                melspec,
                align,
                text_lengths,
                mel_lengths,
                criterion,
                stage,
                log_viterbi=False,
                cpu_viterbi=False):
        """Stage-dispatched training forward pass (CTC/MDN variant).

        stage 0: CTC alignment loss from the acoustic model's log-probs.
        stage 1: L1 mel loss using the externally supplied `align`.
        stage 2: MDN loss plus L1 mel loss with a Viterbi-decoded alignment
                 (optionally run on CPU, optionally timed).
        stage 3: duration-predictor MSE loss in the log domain.

        Inputs are first truncated to the longest valid length in the batch.
        """
        text = text[:, :text_lengths.max().item()]
        melspec = melspec[:, :, :mel_lengths.max().item()]

        if stage == 0:
            # encoder_input = self.Prenet(text)
            # import pdb;pdb.set_trace()
            # hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            # hidden_states, _ = self.FFT_lower(encoder_input, mel_lengths)
            log_probs, hidden_states_spec, _ = self.get_am(
                melspec, mel_lengths, text)
            # mu_sigma = self.get_mu_sigma(hidden_states)
            # mdn_loss, log_prob_matrix = criterion(probs, hidden_states_spec, text_lengths, mel_lengths)
            # mdn_loss, _ = criterion(mu_sigma, melspec, text_lengths, mel_lengths)
            # import pdb;pdb.set_trace()
            # Halve the mel lengths — presumably the acoustic model
            # downsamples time by 2; TODO confirm against get_am.
            mel_lengths = torch.ceil(mel_lengths.float() / 2).long()
            # Normalize the CTC loss by the batch size (log_probs dim 1).
            mdn_loss = self.ctc_loss(log_probs, text, mel_lengths,
                                     text_lengths) / log_probs.size(1)
            return mdn_loss

        elif stage == 1:
            # Crop the provided alignment to the truncated text/mel sizes.
            align = align[:, :text_lengths.max().item(), :mel_lengths.max().
                          item()]
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            mel_out = self.get_melspec(hidden_states, align, mel_lengths)

            # Compare only valid (non-padded) frames.
            mel_mask = ~get_mask_from_lengths(mel_lengths)
            melspec = melspec.masked_select(mel_mask.unsqueeze(1))
            mel_out = mel_out.masked_select(mel_mask.unsqueeze(1))
            fft_loss = nn.L1Loss()(mel_out, melspec)

            return fft_loss

        elif stage == 2:
            encoder_input = self.Prenet(text)
            hidden_states, _ = self.FFT_lower(encoder_input, text_lengths)
            probs, hidden_states_spec = self.get_am(melspec, mel_lengths, text)

            # mu_sigma = self.get_mu_sigma(hidden_states)
            mdn_loss, log_prob_matrix = criterion(probs, hidden_states_spec,
                                                  text_lengths, mel_lengths)

            # Viterbi decoding of the hard alignment; optionally forced onto
            # the CPU, and optionally timed for diagnostics.
            before = datetime.now()
            if cpu_viterbi:
                align = self.viterbi_cpu(log_prob_matrix, text_lengths.cpu(),
                                         mel_lengths.cpu())  # B, T
            else:
                align = self.viterbi(log_prob_matrix, text_lengths,
                                     mel_lengths)  # B, T
            after = datetime.now()

            if log_viterbi:
                time_delta = after - before
                print(f'Viterbi took {time_delta.total_seconds()} secs')

            mel_out = self.get_melspec(hidden_states, align, mel_lengths)

            mel_mask = ~get_mask_from_lengths(mel_lengths)
            melspec = melspec.masked_select(mel_mask.unsqueeze(1))
            mel_out = mel_out.masked_select(mel_mask.unsqueeze(1))
            fft_loss = nn.L1Loss()(mel_out, melspec)

            return mdn_loss + fft_loss

        elif stage == 3:
            align = align[:, :text_lengths.max().item(), :mel_lengths.max().
                          item()]
            duration_out = self.get_duration(text,
                                             text_lengths)  # gradient cut
            # Target duration of a token = number of frames aligned to it.
            duration_target = align.sum(-1)

            duration_mask = ~get_mask_from_lengths(text_lengths)
            duration_target = duration_target.masked_select(duration_mask)
            duration_out = duration_out.masked_select(duration_mask)
            # NOTE(review): log() is undefined for zero durations — assumes
            # every valid token receives >= 1 frame; confirm upstream.
            duration_loss = nn.MSELoss()(torch.log(duration_out),
                                         torch.log(duration_target))

            return duration_loss
Ejemplo n.º 18
0
    def mask_decoder_output(self, decoder_outputs, output_lengths=None):
        """Zero the decoder outputs at padded time steps, in place."""
        if not self.mask_padding or output_lengths is None:
            return decoder_outputs

        pad = ~utl.get_mask_from_lengths(output_lengths)
        decoder_outputs.data.masked_fill_(pad.unsqueeze(1), 0.0)
        return decoder_outputs