Example #1
0
    def forward(self,
                src_seq,
                src_len,
                mel_len=None,
                d_target=None,
                p_target=None,
                e_target=None,
                max_src_len=None,
                max_mel_len=None,
                d_control=1.0,
                p_control=1.0,
                e_control=1.0):
        """Synthesize mel-spectrograms from a phoneme sequence.

        Runs encoder -> variance adaptor -> decoder -> mel projection, with
        an optional postnet refinement; the *_control factors scale the
        adaptor's duration/pitch/energy predictions.
        """
        # Padding masks (True marks padded positions).
        src_mask = get_mask_from_lengths(src_len, max_src_len)
        if mel_len is not None:
            mel_mask = get_mask_from_lengths(mel_len, max_mel_len)
        else:
            mel_mask = None

        encoder_output = self.encoder(src_seq, src_mask)

        adaptor_args = (encoder_output, src_mask, mel_mask, d_target,
                        p_target, e_target, max_mel_len, d_control,
                        p_control, e_control)
        if d_target is not None:
            # Teacher-forced path: ground-truth durations fix the mel length.
            (variance_adaptor_output, d_prediction, p_prediction,
             e_prediction, _, _) = self.variance_adaptor(*adaptor_args)
        else:
            # Inference path: predicted durations yield mel_len / mel_mask.
            (variance_adaptor_output, d_prediction, p_prediction,
             e_prediction, mel_len, mel_mask) = self.variance_adaptor(*adaptor_args)

        decoder_output = self.decoder(variance_adaptor_output, mel_mask)
        mel_output = self.mel_linear(decoder_output)

        mel_output_postnet = (self.postnet(mel_output) + mel_output
                              if self.use_postnet else mel_output)

        return (mel_output, mel_output_postnet, d_prediction, p_prediction,
                e_prediction, src_mask, mel_mask, mel_len)
Example #2
0
    def forward(self, src_seq, src_len, mel_len=None, d_target=None, p_target=None, e_target=None, max_src_len=None, max_mel_len=None, speaker_ids=None):
        """FastSpeech2-style forward pass with optional speaker conditioning.

        When speaker embedding is enabled and ids are given, speaker identity
        is injected both before and after the variance adaptor.
        """
        src_mask = get_mask_from_lengths(src_len, max_src_len)
        if mel_len is not None:
            mel_mask = get_mask_from_lengths(mel_len, max_mel_len)
        else:
            mel_mask = None

        encoder_output = self.encoder(src_seq, src_mask)

        use_speaker = self.use_spk_embed and speaker_ids is not None
        if use_speaker:
            spk_embed = self.embed_speakers(speaker_ids)
            encoder_output = self.speaker_integrator(encoder_output, spk_embed)

        if d_target is not None:
            # Training: ground-truth durations define the mel frame count.
            (variance_adaptor_output, d_prediction, p_prediction,
             e_prediction, _, _) = self.variance_adaptor(
                encoder_output, src_mask, mel_mask, d_target, p_target,
                e_target, max_mel_len)
        else:
            # Inference: predicted durations produce mel_len / mel_mask.
            (variance_adaptor_output, d_prediction, p_prediction,
             e_prediction, mel_len, mel_mask) = self.variance_adaptor(
                encoder_output, src_mask, mel_mask, d_target, p_target,
                e_target, max_mel_len)

        if use_speaker:
            variance_adaptor_output = self.speaker_integrator(
                variance_adaptor_output, spk_embed)

        decoder_output = self.decoder(variance_adaptor_output, mel_mask)
        mel_output = self.mel_linear(decoder_output)

        if self.use_postnet:
            mel_output_postnet = self.postnet(mel_output) + mel_output
        else:
            mel_output_postnet = mel_output

        return mel_output, mel_output_postnet, d_prediction, p_prediction, e_prediction, src_mask, mel_mask, mel_len
    def forward(self,
                src_seq,
                src_len,
                mel_len=None,
                max_src_len=None,
                max_mel_len=None):
        """Run encoder, duration-based variance adaptor, and decoder.

        ``mel_len``/``max_mel_len`` are kept for interface compatibility with
        the sibling models; the decoder mask comes from the adaptor's
        predicted ``pred_mel_mask``, so the previously computed ``mel_mask``
        local was dead code and has been removed.

        Returns mel outputs (pre/post postnet), duration predictions, masks,
        encoder/decoder attention maps, and the alignment matrix ``W``.
        """
        src_mask = get_mask_from_lengths(src_len, max_src_len)

        # Trailing True requests the attention maps from encoder/decoder.
        encoder_output, enc_attns = self.encoder(src_seq, src_mask, True)

        variance_adaptor_output, d_prediction, W, pred_mel_mask = self.variance_adaptor(
            encoder_output, src_mask, max_mel_len)

        decoder_output, dec_attns, pred_mel_mask = self.decoder(
            variance_adaptor_output, pred_mel_mask, True)
        mel_output = self.mel_linear(decoder_output)

        if self.use_postnet:
            mel_output_postnet = self.postnet(mel_output) + mel_output
        else:
            mel_output_postnet = mel_output

        return mel_output, mel_output_postnet, d_prediction, src_mask, pred_mel_mask, enc_attns, dec_attns, W
Example #4
0
    def forward(self,
                src_seq,
                src_len,
                mel_len=None,
                d_target=None,
                p_target=None,
                e_target=None,
                max_src_len=None,
                max_mel_len=None):
        """FastSpeech2-style forward: encoder -> variance adaptor -> decoder."""
        # mask = [False, ..., False, True, ..., True]: the first src_len
        # entries are False, the rest (up to max_src_len) are True; True
        # marks padding (unused positions).
        src_mask = get_mask_from_lengths(src_len, max_src_len)
        mel_mask = get_mask_from_lengths(
            mel_len, max_mel_len) if mel_len is not None else None

        # encoder_output: e.g. <16, 110, 256> = (batch, src_len, hidden)
        encoder_output = self.encoder(src_seq, src_mask)
        if d_target is not None:
            # Training: ground-truth durations; mel_len/mel_mask stay as given.
            variance_adaptor_output, d_prediction, p_prediction, e_prediction, _, _ = self.variance_adaptor(
                encoder_output, src_mask, mel_mask, d_target, p_target,
                e_target, max_mel_len)
        else:
            # Inference: the adaptor predicts durations and returns
            # mel_len / mel_mask.
            variance_adaptor_output, d_prediction, p_prediction, e_prediction, mel_len, mel_mask = self.variance_adaptor(
                encoder_output, src_mask, mel_mask, d_target, p_target,
                e_target, max_mel_len)
        # variance_adaptor_output, decoder_output: e.g. <16, 965, 256>
        decoder_output = self.decoder(variance_adaptor_output, mel_mask)
        # mel_output: e.g. <16, 965, 80>
        mel_output = self.mel_linear(decoder_output)

        if self.use_postnet:
            mel_output_postnet = self.postnet(mel_output) + mel_output
        else:
            mel_output_postnet = mel_output

        return mel_output, mel_output_postnet, d_prediction, p_prediction, e_prediction, src_mask, mel_mask, mel_len
    def forward(self, embedded_text, text_lengths, mels, mels_lengths):
        """Compute a text-aligned prosody (style) embedding via reference attention."""
        # Encode the reference mel into a prosody sequence, then bottleneck it.
        embedded_prosody, _ = self.encoder(mels)
        embedded_prosody = self.encoder_bottleneck(embedded_prosody)

        # Split the bottleneck output into attention key / value halves,
        # each [N, Ty, prosody_embedding_dim].
        key, value = torch.split(
            embedded_prosody, self.prosody_embedding_dim, dim=-1)

        # Outer product of the two length masks gives the
        # [N, seq_len, ref_len] attention mask.
        text_mask = get_mask_from_lengths(text_lengths).float().unsqueeze(-1)
        mels_mask = get_mask_from_lengths(mels_lengths).float().unsqueeze(-1)
        attn_mask = torch.bmm(text_mask, mels_mask.transpose(-2, -1))

        style_embed, alignments = self.ref_attn(embedded_text, key, value,
                                                attn_mask)

        # ReLU forces the prosody-embedding values into [0, inf).
        style_embed = F.relu(style_embed)

        return style_embed, alignments
Example #6
0
    def forward(self,
                src_seq,
                ref_mel,
                src_len,
                mel_len=None,
                d_target=None,
                p_target=None,
                e_target=None,
                max_src_len=None,
                max_mel_len=None):
        """Synthesize acoustic features for mel or WORLD vocoder targets.

        ``ref_mel`` is accepted for interface compatibility with the GST
        variant but is currently unused (the style-embedding path was
        disabled; the dead ``encoder_output = encoder_output`` no-op and the
        commented-out GST code have been removed).

        Returns WORLD (ap, sp, sp_postnet, ...) outputs when
        ``hp.vocoder == 'WORLD'``; otherwise mel outputs.
        """
        src_mask = get_mask_from_lengths(src_len, max_src_len)
        mel_mask = get_mask_from_lengths(
            mel_len, max_mel_len) if mel_len is not None else None
        if hp.vocoder == 'WORLD':
            # Separate masks for the aperiodicity and spectral-envelope
            # streams (currently built identically to mel_mask).
            ap_mask = get_mask_from_lengths(
                mel_len, max_mel_len) if mel_len is not None else None
            sp_mask = get_mask_from_lengths(
                mel_len, max_mel_len) if mel_len is not None else None

        encoder_output = self.encoder(src_seq, src_mask)

        variance_adaptor_output, d_prediction, p_prediction, e_prediction, mel_len, mel_mask = self.variance_adaptor(
            encoder_output, src_mask, mel_mask, d_target, p_target, e_target,
            max_mel_len)

        decoder_output = self.decoder(variance_adaptor_output, mel_mask)

        if hp.vocoder == 'WORLD':
            ap_output = self.ap_linear(decoder_output)
            sp_output = self.sp_linear(decoder_output)

            if self.use_postnet:
                sp_output_postnet = self.postnet(sp_output) + sp_output
            else:
                sp_output_postnet = sp_output

            return ap_output, sp_output, sp_output_postnet, d_prediction, p_prediction, e_prediction, src_mask, ap_mask, sp_mask
        else:
            mel_output = self.mel_linear(decoder_output)

            if self.use_postnet:
                mel_output_postnet = self.postnet(mel_output) + mel_output
            else:
                mel_output_postnet = mel_output

            return mel_output, mel_output_postnet, d_prediction, p_prediction, e_prediction, src_mask, mel_mask, mel_len
Example #7
0
    def forward(self, src_seq, src_len, mel_len=None, d_target=None, p_target=None, e_target=None, max_src_len=None, max_mel_len=None):
        """Synthesize acoustic features for mel or WORLD vocoder targets.

        The WORLD branch additionally returns the variance-adaptor and
        decoder hidden states for inspection. (The dead
        ``encoder_output = encoder_output`` no-op and commented-out debug
        prints / GST code have been removed.)
        """
        src_mask = get_mask_from_lengths(src_len, max_src_len)
        mel_mask = get_mask_from_lengths(mel_len, max_mel_len) if mel_len is not None else None
        if hp.vocoder == 'WORLD':
            # Masks for the aperiodicity / spectral-envelope output streams
            # (currently built identically to mel_mask).
            ap_mask = get_mask_from_lengths(mel_len, max_mel_len) if mel_len is not None else None
            sp_mask = get_mask_from_lengths(mel_len, max_mel_len) if mel_len is not None else None

        encoder_output = self.encoder(src_seq, src_mask)

        # NOTE: this adaptor variant also consumes the raw src_seq.
        variance_adaptor_output, d_prediction, p_prediction, e_prediction, mel_len, mel_mask = self.variance_adaptor(
            encoder_output, src_seq, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len)

        decoder_output = self.decoder(variance_adaptor_output, mel_mask)

        if hp.vocoder == 'WORLD':
            ap_output = self.ap_linear(decoder_output)
            sp_output = self.sp_linear(decoder_output)

            if self.use_postnet:
                sp_output_postnet = self.postnet(sp_output) + sp_output
            else:
                sp_output_postnet = sp_output

            return ap_output, sp_output, sp_output_postnet, d_prediction, p_prediction, e_prediction, src_mask, ap_mask, sp_mask, variance_adaptor_output, decoder_output

        else:
            mel_output = self.mel_linear(decoder_output)

            if self.use_postnet:
                mel_output_postnet = self.postnet(mel_output) + mel_output
            else:
                mel_output_postnet = mel_output

            return mel_output, mel_output_postnet, d_prediction, p_prediction, e_prediction, src_mask, mel_mask, mel_len
    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask, mel_iter, duration=None):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        mel_iter: current decoder step index, used to size the sliding window
        duration: optional total duration (frames) used for window sizing
        """
        if memory.size(0) == 1:  # batch size 1: synthesis/inference
            # Build a sliding attention window around the text position
            # expected at the current mel step.
            if duration is not None:  # was `!= None`; use identity test per PEP 8
                mel_per = round(duration / memory.size(1) + 0.5)
                A = 400
                B = duration - A + 100
                a = [memory.size(1), mel_iter // mel_per + A // 2]
                b = [memory.size(1), mel_iter // mel_per - A // 2]
            else:
                A = 500
                B = memory.size(1) * 14 // 3
                a = [memory.size(1), mel_iter // 6 + 250]
                b = [memory.size(1), mel_iter // 6 - 250]

            a = torch.tensor(a).cuda()
            b = torch.tensor(b).cuda()

            if memory.size(1) < A:  # short sentence: no windowing needed
                alignment = self.get_alignment_energies(
                    attention_hidden_state, processed_memory, attention_weights_cat, mel_iter)
            else:
                if mel_iter < A:
                    sw_mask = ~get_mask_from_lengths(a)[1:]
                elif mel_iter >= A and mel_iter < B:
                    sw_mask = get_mask_from_lengths2(a)[1:] ^ ~get_mask_from_lengths2(b)[1:]
                else:
                    sw_mask = get_mask_from_lengths2(b)[1:]

                alignment = self.get_alignment_energies(
                    attention_hidden_state, processed_memory, attention_weights_cat, mel_iter)
                # Forbid attention outside the sliding window.
                alignment.data.masked_fill_(sw_mask, self.score_mask_value)

        else:  # batch >= 2: training, no windowing
            alignment = self.get_alignment_energies(
                attention_hidden_state, processed_memory, attention_weights_cat, mel_iter)
        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights
Example #9
0
    def forward(self, memory, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """

        # Prepend the <go> frame and run all frames through the prenet.
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        # audio features
        #        f0_dummy = self.get_end_f0(f0s)
        #        f0s = torch.cat((f0s, f0_dummy), dim=2)
        #        f0s = F.relu(self.prenet_f0(f0s))
        #        f0s = f0s.permute(2, 0, 1)
        # All-zero stand-in for f0 features so a model pretrained with f0
        # conditioning can run without real f0 inputs.
        f0 = memory.new_zeros(
            memory.size(0),  # batch_size
            self.f0_size  # dummy f0 for using pretrained model
        )

        # Inverted mask: True marks padded encoder positions.
        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        # One decode step per target frame (excluding the <go> frame).
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            #            if len(mel_outputs) == 0 or np.random.uniform(0.0, 1.0) <= self.p_teacher_forcing:
            #                decoder_input = torch.cat((decoder_inputs[len(mel_outputs)],
            #                                           f0s[len(mel_outputs)]), dim=1)
            #            else:
            #                decoder_input = torch.cat((self.prenet(mel_outputs[-1]),
            #                                           f0s[len(mel_outputs)]), dim=1)
            # Scheduled sampling: teacher-force with probability
            # p_teacher_forcing, otherwise feed back the previous prediction.
            if len(mel_outputs) == 0 or np.random.uniform(
                    0.0, 1.0) < self.p_teacher_forcing:
                decoder_input = torch.cat(
                    (decoder_inputs[len(mel_outputs)], f0), dim=1)

            else:
                decoder_input = torch.cat((self.prenet(mel_outputs[-1]), f0),
                                          dim=1)

            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments
Example #10
0
    def forward(self, inputs):
        """Tacotron2-style forward pass with optional speaker conditioning.

        inputs: (text, text_lengths, mels, mel_lengths) training batch.
        Returns (mel_outputs, mel_outputs_postnet, gate_outputs, alignments);
        frames beyond each utterance's mel length are zeroed (gate set to 1e3).
        """
        text, text_lengths, mels, mel_lengths = inputs

        encoder_outputs = self.encoder(text, text_lengths)

        if self.speaker_encoder is not None:
            # Slice each mel into half-overlapping windows, embed the speaker
            # from them, then tile the embedding along the text axis and
            # concatenate onto the encoder outputs.
            fragments = mels.unfold(1, self.n_fragment_mel_windows, self.n_fragment_mel_windows // 2).transpose(2, 3)
            fragment_counts = mel_lengths // (self.n_fragment_mel_windows // 2) - 1
            fragments = torch.cat([f[:fc] for f, fc in zip(fragments, fragment_counts)])
            speaker_embeddings = self.speaker_encoder.inference(fragments, fragment_counts.tolist())
            encoder_outputs = torch.cat([encoder_outputs,
                                         speaker_embeddings.unsqueeze(1).repeat(1, encoder_outputs.size(1), 1)], dim=2)

        mel_outputs, gate_outputs, alignments = self.decoder(encoder_outputs, text_lengths, mels)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        if mel_lengths is not None:
            # True marks padded frames; broadcast over mel channels.
            mask = ~get_mask_from_lengths(mel_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 2, 0)

            mel_outputs.data.masked_fill_(mask, 0.0)
            mel_outputs_postnet.data.masked_fill_(mask, 0.0)
            gate_outputs.data.masked_fill_(mask[:, :, 0], 1e3)  # gate energies
        return mel_outputs, mel_outputs_postnet, gate_outputs, alignments
    def forward(self,
                x,
                src_mask,
                mel_mask=None,
                duration_target=None,
                pitch_target=None,
                energy_target=None,
                max_len=None):
        """Variance adaptor: add pitch/energy embeddings, then length-regulate."""
        log_duration_prediction = self.duration_predictor(x, src_mask)

        # Predict pitch and energy, embed each, and add onto the features.
        pitch_prediction = self.pitch_predictor(x, src_mask)
        pitch_embedding = self.pitch_embedding_producer(pitch_prediction.unsqueeze(2))

        energy_prediction = self.energy_predictor(x, src_mask)
        energy_embedding = self.energy_embedding_producer(energy_prediction.unsqueeze(2))

        x = x + pitch_embedding + energy_embedding

        if duration_target is None:
            # Inference: recover frame counts from the log-domain prediction.
            duration_rounded = torch.clamp(
                torch.round(torch.exp(log_duration_prediction) - hp.log_offset),
                min=0)
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            mel_mask = utils.get_mask_from_lengths(mel_len)
        else:
            # Training: expand with ground-truth durations.
            x, mel_len = self.length_regulator(x, duration_target, max_len)

        return x, log_duration_prediction, pitch_prediction, energy_prediction, mel_len, mel_mask
Example #12
0
    def forward(self, memory, decoder_inputs, memory_lengths):
        """Teacher-forced decoder pass.

        PARAMS
        ------
        memory: (B, seq_len, 512) encoder outputs
        decoder_inputs: (B, n_mel_channels, frames) ground-truth mels
        memory_lengths: (B,) encoder output lengths for attention masking

        RETURNS
        -------
        mel_outputs, alignments, gate_outputs
        """
        # memory : (B, Seq_len, 512) --> encoder outputs
        # decoder_inputs : (B, Mel_Channels : 80, frames)
        # memory lengths : (B)

        # Prepend the <go> frame and apply the prenet.
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        # print('go frames : ', decoder_input.size())
        # print('decoder inputs : ', decoder_inputs.size())
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        # print('decoder inputs : ', decoder_inputs.size())
        decoder_inputs = self.prenet(decoder_inputs)

        # NOTE(review): 'initailze' is a typo, but presumably matches the
        # (misspelled) method definition on this class — do not rename the
        # call site alone.
        self.initailze_decoder_states(memory,
                                      mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, alignments, gate_outputs = [], [], []

        # One decode step per target frame (excluding the <go> frame),
        # always teacher-forced.
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, attention_weights, gate_output = self.decode(decoder_input)
            # mel_output : (1, B, 240)
            mel_outputs.append(mel_output)
            alignments.append(attention_weights)
            gate_outputs.append(gate_output)

        # print('decoder prediction : ', len(mel_outputs))

        mel_outputs, alignments, gate_outputs = self.parse_decoder_outputs(mel_outputs, alignments, gate_outputs)
        # print('mel outputs', mel_outputs.size())
        return mel_outputs, alignments, gate_outputs
    def forward(self, memory, decoder_inputs, mel_length, memory_lengths):
        """Teacher-forced decoder pass.

        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        mel_length: mel lengths (not referenced in this body)
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        # Prepend the <go> frame and run everything through the prenet.
        go_frame = self.get_go_frame(memory).unsqueeze(0)
        teacher_frames = self.parse_decoder_inputs(decoder_inputs)
        teacher_frames = torch.cat((go_frame, teacher_frames), dim=0)
        teacher_frames = self.prenet(teacher_frames)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        n_steps = teacher_frames.size(0) - 1
        while len(mel_outputs) < n_steps:
            step_input = teacher_frames[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(
                step_input, memory_lengths)
            mel_outputs.append(mel_output.squeeze(1))
            gate_outputs.append(gate_output.squeeze())
            alignments.append(attention_weights)

        return self.parse_decoder_outputs(mel_outputs, gate_outputs,
                                          alignments)
Example #14
0
    def forward(self, model_output, targets):
        """Compute mel reconstruction loss plus weighted gate BCE loss.

        Parameters
        ----------
        model_output : (mel_out, mel_out_postnet, gate_out, ...) from the model.
        targets : (mel_target, gate_target, output_lengths, ...) ground truth.

        Returns
        -------
        (total_loss, gate_loss) where total_loss = mel_loss + gate_loss.

        Raises
        ------
        ValueError
            If ``self.loss_func`` is not a supported loss name. (Previously
            an unknown name fell through both branches and produced a
            confusing ``NameError`` on ``mel_loss``.)
        """
        mel_target, gate_target, output_lengths, *_ = targets
        # Targets are constants; never backpropagate into them.
        mel_target.requires_grad = False
        gate_target.requires_grad = False
        mel_out, mel_out_postnet, gate_out, _ = model_output
        gate_target = gate_target.view(-1, 1)
        gate_out = gate_out.view(-1, 1)

        # Remove padded frames before computing the mel loss.
        if self.masked_select:
            mask = get_mask_from_lengths(output_lengths)
            mask = mask.expand(mel_target.size(1), mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)
            mel_target = torch.masked_select(mel_target, mask)
            mel_out = torch.masked_select(mel_out, mask)
            mel_out_postnet = torch.masked_select(mel_out_postnet, mask)

        if self.loss_func == 'MSELoss':
            criterion = nn.MSELoss()
        elif self.loss_func == 'SmoothL1Loss':
            criterion = nn.SmoothL1Loss()
        else:
            raise ValueError(f"Unsupported loss_func: {self.loss_func!r}")
        mel_loss = criterion(mel_out, mel_target) + \
            criterion(mel_out_postnet, mel_target)

        gate_loss = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)(
            gate_out, gate_target)
        return mel_loss + gate_loss, gate_loss
Example #15
0
    def mask(self, x, predicted, cemb, input_lengths, output_lengths, max_c_len, max_mel_len):
        """Zero out padded positions in mel, predicted, and character tensors (in place)."""
        # Mel-side padding mask, broadcast over channels -> [B, T, n_mel].
        x_mask = ~utils.get_mask_from_lengths(output_lengths, max_mel_len)
        x_mask = x_mask.expand(hparams.n_mel_channels,
                               x_mask.size(0), x_mask.size(1)).permute(1, 2, 0)

        # Character-side padding mask -> [B, L, pre_gru_out_dim].
        c_mask = ~utils.get_mask_from_lengths(input_lengths, max_c_len)
        c_mask = c_mask.expand(hparams.pre_gru_out_dim,
                               c_mask.size(0), c_mask.size(1)).permute(1, 2, 0)

        x.data.masked_fill_(x_mask, 0.0)
        cemb.data.masked_fill_(c_mask, 0.0)
        predicted.data.masked_fill_(c_mask[:, :, 0], 0.0)

        return x, predicted, cemb
Example #16
0
    def hybrid_forward(self, F, memory, decoder_inputs, memory_lengths, *args, **kwargs):
        """ Decoder forward pass for training (MXNet Gluon HybridBlock).
        PARAMS
        ------
        F: nd/sym backend handle supplied by Gluon
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.
        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """

        #memory = memory.transpose((1, 0, 2))
        # Prepend the <go> frame and apply the prenet.
        decoder_input = self.get_go_frame(F, memory).expand_dims(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = F.concat(decoder_input, decoder_inputs, dim=0)
        decoder_inputs = self._prenet(decoder_inputs)

        # NOTE(review): the mask is NOT inverted here, unlike the commented
        # alternative below and the PyTorch counterparts — confirm which
        # polarity initialize_decoder_states expects.
        self.initialize_decoder_states(F, memory, mask=get_mask_from_lengths(F, memory_lengths))
        #self.initialize_decoder_states(F, memory, mask=~get_mask_from_lengths(F, memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        # One decode step per target frame (excluding the <go> frame),
        # always teacher-forced.
        while len(mel_outputs) < decoder_inputs.shape[0] - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(F, decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(mel_outputs, gate_outputs, alignments)
        return mel_outputs, gate_outputs, alignments
Example #17
0
    def parse_output(self, outputs, outputs_length=None):
        """Zero out landmark outputs at padded time steps (in place).

        NOTE(review): unlike the sibling ``parse_output`` implementations in
        this file, this one does not return ``outputs``; callers must rely on
        the in-place ``masked_fill_`` — confirm that is intended.
        """
        if outputs_length is not None:
            # After inversion, True marks padded positions.
            mask = ~get_mask_from_lengths(outputs_length)
            mask = mask.expand(self.n_landmarks_channels, mask.size(0),
                               mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs.masked_fill_(mask, 0.0)
Example #18
0
    def mask(self, mel_1, mel_2, length_mel, max_mel_len):
        """Zero out padded frames in a pair of mel tensors (in place)."""
        # True marks padded frames; broadcast over channels -> [B, T, n_mel].
        pad = ~utils.get_mask_from_lengths(length_mel, max_mel_len)
        pad = pad.expand(hp.n_mel_channels, pad.size(0), pad.size(1))
        pad = pad.permute(1, 2, 0)

        mel_1.data.masked_fill_(pad, 0.0)
        mel_2.data.masked_fill_(pad, 0.0)
        return mel_1, mel_2
Example #19
0
    def forward(self, text, text_len, is_mask=True):
        """Embed text and run the predictor, optionally with a length mask."""
        x = self.embed(text).transpose(1, 2)
        x_len = text_len
        mask = get_mask_from_lengths(x_len) if is_mask else None
        out = self.predictor(x, mask)
        return self.projection(out).squeeze(1)
    def parse_output(self, outputs, output_lengths=None):
        """Zero padded mel frames and force gate energies high there (in place)."""
        if self.mask_padding and output_lengths is not None:
            pad = ~get_mask_from_lengths(output_lengths)  # True on padding
            pad = pad.expand(self.n_mel_channels, pad.size(0), pad.size(1))
            pad = pad.permute(1, 0, 2)

            outputs[0].data.masked_fill_(pad, 0.0)
            outputs[1].data.masked_fill_(pad, 0.0)
            outputs[2].data.masked_fill_(pad[:, 0, :], 1e3)  # gate energies
        return outputs
Example #21
0
    def forward(self, memory, decoder_inputs, memory_lengths, f0s):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.
        f0s: f0 features for prosody conditioning

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """

        decoder_input = self.get_go_frame(memory).unsqueeze(
            0)  # initialize with zero frame
        decoder_inputs = self.parse_decoder_inputs(
            decoder_inputs)  # reflect n_frames_per_step
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(
            decoder_inputs)  # mel dim to hparams.prenet_dim

        # audio features: pad f0s with an end frame, embed, and put time first.
        f0_dummy = self.get_end_f0(f0s)
        f0s = torch.cat((f0s, f0_dummy), dim=2)
        f0s = F.relu(self.prenet_f0(f0s))
        f0s = f0s.permute(2, 0, 1)  # (T, B, C)

        self.initialize_decoder_states(  # initialize decoder member variables for dynamic tensor
            memory,
            mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(
                0) - 1:  # Decoder.forward() continues until g.t. mel length.
            # Scheduled sampling: teacher-force with probability
            # p_teacher_forcing, otherwise feed back the previous prediction.
            if len(mel_outputs) == 0 or np.random.uniform(
                    0.0, 1.0) <= self.p_teacher_forcing:
                decoder_input = torch.cat(
                    (decoder_inputs[len(mel_outputs)], f0s[len(mel_outputs)]),
                    dim=1)
            else:
                decoder_input = torch.cat(
                    (self.prenet(mel_outputs[-1]), f0s[len(mel_outputs)]),
                    dim=1)
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments
Example #22
0
    def forward(self, memory, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """

        # Prepend the <go> frame and apply the prenet.
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        # Inverted mask: True marks padded encoder positions.
        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []

        # scheduled sampling: history of the model's own predictions
        decoder_inputs_student = list()

        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            # scheduled sampling: with probability student_input_prob feed
            # the model's previous output instead of the ground-truth frame.
            use_student_input = False
            if len(decoder_inputs_student) > 0:
                prob = self.student_input_prob
                toss = random.random()
                use_student_input = toss < prob
            
            if use_student_input:
                decoder_input = self.prenet(decoder_inputs_student[-1])
            else:
                decoder_input = decoder_inputs[len(mel_outputs)]
            
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)

            decoder_inputs_student.append(mel_output)

            mel_outputs += [mel_output.squeeze(1)]
            # Replicate the gate value once per frame in the step.
            gate_outputs += [gate_output.squeeze()] * self.n_frames_per_step
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments
Example #23
0
    def forward(self, text, durs, is_mask=True):
        """Predict silence (uv) and body-value sequences from text + durations."""
        x = self.embed(text, durs).transpose(1, 2)
        x_len = durs.sum(-1)
        mask = get_mask_from_lengths(x_len) if is_mask else None
        out = self.predictor(x, mask)
        uv = self.sil_proj(out).squeeze(1)
        value = self.body_proj(out).squeeze(1)
        return uv, value
Example #24
0
    def forward(self, src_seq, mel_target, mel_aug, p_norm, e_input, src_len, mel_len, d_target=None, p_target=None, e_target=None, max_src_len=None, max_mel_len=None, speaker_embed=None, d_control=1.0, p_control=1.0, e_control=1.0):
        """Style-modeling TTS forward with paired clean and noisy decoding.

        Returns (clean, noisy) mel outputs before/after postnet, variance
        predictions, masks, mel lengths, and the duration/pitch/energy
        augmentation posteriors.
        """
        src_mask = get_mask_from_lengths(src_len, max_src_len)
        mel_mask = get_mask_from_lengths(mel_len, max_mel_len)

        # Style modeling
        if d_target is not None:
            # Training: ground-truth variances; mel_len/mel_mask unchanged.
            style_modeling_output, noise_encoding, d_prediction, p_prediction, e_prediction, _, _, (aug_posterior_d, aug_posterior_p, aug_posterior_e) = self.style_modeling(
                src_seq, speaker_embed, mel_target, mel_aug, p_norm, e_input, src_len, mel_len, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len, d_control, p_control, e_control)
        else:
            # Inference: predicted durations produce mel_len / mel_mask.
            style_modeling_output, noise_encoding, d_prediction, p_prediction, e_prediction, mel_len, mel_mask, (aug_posterior_d, aug_posterior_p, aug_posterior_e) = self.style_modeling(
                src_seq, speaker_embed, mel_target, mel_aug, p_norm, e_input, src_len, mel_len, src_mask, mel_mask, d_target, p_target, e_target, max_mel_len, d_control, p_control, e_control)

        # Clean decoding
        mel_output, mel_output_postnet = self.decode(style_modeling_output, mel_mask)

        # Noisy decoding; detach() stops gradients from the noisy branch
        # flowing back into the style-modeling output.
        mel_output_noisy, mel_output_postnet_noisy = self.decode(style_modeling_output.detach() + noise_encoding, mel_mask)

        return (mel_output, mel_output_noisy), (mel_output_postnet, mel_output_postnet_noisy), d_prediction, p_prediction, e_prediction, src_mask, mel_mask, mel_len, \
            (aug_posterior_d, aug_posterior_p, aug_posterior_e)
Example #25
0
def parse_output(outputs, output_lengths=None, n_mel_channels=80):
    if output_lengths is not None:
        max_len = torch.max(output_lengths).item()
        mask = ~get_mask_from_lengths(output_lengths, max_len)
        mask = mask.expand(n_mel_channels, mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)

        outputs[0].data.masked_fill_(mask, 0.0)
        outputs[1].data.masked_fill_(mask, 0.0)
        outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

    return outputs
Example #26
0
    def parse_output(self, outputs, text_lengths=None, output_lengths=None):
        if self.mask_padding and text_lengths is not None and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)

            mask1 = mask.expand(self.n_mel_channels, mask.size(0),
                                mask.size(1))
            mask1 = mask1.permute(1, 0, 2)
            outputs[0].data.masked_fill_(mask1, 0.0)
            outputs[1].data.masked_fill_(mask1, 0.0)

            mask2 = mask.expand(
                max(text_lengths).item() + 1, mask.size(0), mask.size(1))
            mask2 = mask2.permute(1, 2, 0)
            outputs[2].data.masked_fill_(mask2, 0.0)

            mask = ~get_mask_from_lengths(text_lengths)
            mask3 = mask.expand(self.frame_level_rnn_dim, mask.size(0),
                                mask.size(1))
            mask3 = mask3.permute(1, 2, 0)
            outputs[3].data.masked_fill_(mask3, 0.0)
            outputs[4].data.masked_fill_(mask3, 0.0)

        return outputs
Example #27
0
    def parse_output(self, outputs, output_lengths=None):

        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            print("Shapeo of mask: ", mask.shape)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            # outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
        return outputs
Example #28
0
    def forward(self, src_seq,speaker_emb, src_len, hz_seq = None,mel_len=None, d_target=None,  max_src_len=None, max_mel_len=None, d_control=1.0, p_control=1.0, e_control=1.0):
        src_mask = get_mask_from_lengths(src_len, max_src_len)
        mel_mask = get_mask_from_lengths(
            mel_len, max_mel_len) if mel_len is not None else None

        encoder_output = self.encoder(src_seq, src_mask,hz_seq=hz_seq)
        if d_target is not None:
            variance_adaptor_output, d_prediction,   _, _ = self.variance_adaptor(
                encoder_output, src_mask, mel_mask, d_target,   max_mel_len, d_control, p_control, e_control)
        else:
            variance_adaptor_output, d_prediction,   mel_len, mel_mask = self.variance_adaptor(
                encoder_output, src_mask, mel_mask, d_target,   max_mel_len, d_control, p_control, e_control)
            

        decoder_output = self.decoder(variance_adaptor_output, mel_mask,speaker_emb)
        mel_output = self.mel_linear(decoder_output)
        if self.use_postnet:
            unet_out = self.postnet(torch.unsqueeze(mel_output,1))
            mel_output_postnet = unet_out[:,0,:,:]+ mel_output
        else:
            mel_output_postnet = mel_output

        return mel_output, mel_output_postnet, d_prediction,  src_mask, mel_mask, mel_len
Example #29
0
    def parse_outputs(self, mel_outputs, mel_outputs_postnet, gate_outputs,
                      output_lengths):
        mask = ~get_mask_from_lengths(output_lengths, pad=True)
        mask = mask.expand(80, mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)
        # mask : (B, 80, Frames)

        mel_outputs.data.masked_fill_(mask, 0.0)
        mel_outputs_postnet.data.masked_fill_(mask, 0.0)

        # gate outputs : (B, Frames // 3)
        slice_mask = torch.arange(0, mask.size(2), 1)
        gate_outputs.data.masked_fill_(mask[:, 0, slice_mask], 1e3)
        return mel_outputs, mel_outputs_postnet, gate_outputs
Example #30
0
    def forward(self,
                x,
                src_mask,
                mel_mask=None,
                duration_target=None,
                pitch_target=None,
                energy_target=None,
                max_len=None):

        log_duration_prediction = self.duration_predictor(x, src_mask)
        if duration_target is not None:
            x, mel_len = self.length_regulator(x, duration_target, max_len)
        else:
            duration_rounded = torch.clamp(torch.round(
                torch.exp(log_duration_prediction) - hp.log_offset),
                                           min=0)
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            mel_mask = utils.get_mask_from_lengths(mel_len)

        pitch_prediction = self.pitch_predictor(x, mel_mask)
        if pitch_target is not None:
            src = torch.ceil((pitch_target - hp.f0_min) /
                             (hp.f0_max - hp.f0_min) * hp.n_bins).long()
            #             print(src)
            pitch_embedding = self.pitch_embedding(src)
        else:
            src = torch.ceil((pitch_prediction - hp.f0_min) /
                             (hp.f0_max - hp.f0_min) * hp.n_bins).long()

            pitch_embedding = self.pitch_embedding(src)

        energy_prediction = self.energy_predictor(x, mel_mask)
        if energy_target is not None:
            #             print(energy_target)
            src = torch.ceil(
                (energy_target - hp.energy_min) /
                (hp.energy_max - hp.energy_min) * hp.n_bins).long()
            #             print(src)
            energy_embedding = self.energy_embedding(src)
        else:
            src = torch.ceil(
                (energy_prediction - hp.energy_min) /
                (hp.energy_max - hp.energy_min) * hp.n_bins).long()

            energy_embedding = self.energy_embedding(src)

        x = x + pitch_embedding + energy_embedding

        return x, log_duration_prediction, pitch_prediction, energy_prediction, mel_len, mel_mask