Example #1
    def forward(self, memory, decoder_inputs, memory_lengths, f0s):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs used for teacher forcing, i.e. mel-spectrograms
        memory_lengths: Encoder output lengths for attention masking.
        f0s: f0 (pitch) contours aligned with the decoder inputs

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """

        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)
        if isinstance(f0s, torch.Tensor):
            # audio features: append a dummy end frame, project the f0
            # contour through its own prenet, and reshape to (time, batch, dim)
            f0_dummy = self.get_end_f0(f0s)
            f0s = torch.cat((f0s, f0_dummy), dim=2)
            f0s = F.relu(self.prenet_f0(f0s))
            f0s = f0s.permute(2, 0, 1)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            # teacher forcing: use the ground-truth frame with probability
            # p_teacher_forcing, otherwise feed back the previous prediction
            if len(mel_outputs) == 0 or np.random.uniform(
                    0.0, 1.0) <= self.p_teacher_forcing:
                if isinstance(f0s, torch.Tensor):
                    decoder_input = torch.cat(
                        (decoder_inputs[len(mel_outputs)],
                         f0s[len(mel_outputs)]),
                        dim=1)
                else:
                    decoder_input = decoder_inputs[len(mel_outputs)]
            else:
                if isinstance(f0s, torch.Tensor):
                    decoder_input = torch.cat(
                        (self.prenet(mel_outputs[-1]), f0s[len(mel_outputs)]),
                        dim=1)
                else:
                    decoder_input = self.prenet(mel_outputs[-1])

            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments
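
Both examples call get_mask_from_lengths, which is not shown here. The sketch below is a minimal Tacotron2-style implementation inferred from how the function is used above (a boolean mask that is True on valid time steps); treat the exact signature and behaviour as an assumption.

import torch

def get_mask_from_lengths(lengths):
    # lengths: 1-D tensor of per-sequence lengths, shape (batch,)
    max_len = int(torch.max(lengths).item())
    ids = torch.arange(0, max_len, device=lengths.device)
    mask = ids < lengths.unsqueeze(1)  # (batch, max_len), True on valid steps
    return mask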
Example #2
    def parse_output(self, outputs, output_lengths=None):
        # Zero out frames beyond each sequence's true length so padding does
        # not contribute to the loss, and push the gate energies toward "stop"
        # on padded steps.
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)  # mel outputs (pre-postnet)
            outputs[1].data.masked_fill_(mask, 0.0)  # mel outputs (post-postnet)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return outputs
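
To see how the masking in parse_output lines up shape-wise, here is a small toy example reusing the get_mask_from_lengths sketch above; the batch size, channel count, and lengths are made up purely for illustration.

import torch

B, M, T = 2, 3, 4                              # batch, mel channels, frames (assumed toy sizes)
output_lengths = torch.tensor([4, 2])          # second item is padding after frame 2

mask = ~get_mask_from_lengths(output_lengths)  # (B, T), True on padded frames
mask = mask.expand(M, B, T).permute(1, 0, 2)   # (B, M, T), matches a mel tensor

mel = torch.randn(B, M, T)
mel.masked_fill_(mask, 0.0)                    # padded frames zeroed in place
print(mel[1, :, 2:])                           # all zeros beyond length 2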