Esempio n. 1
0
	def forward(self, memory, decoder_inputs, memory_lengths):
		''' Decoder forward pass for training (teacher forcing).

		PARAMS
		------
		memory: Encoder outputs
		decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
		memory_lengths: Encoder output lengths for attention masking.

		RETURNS
		-------
		mel_outputs: mel outputs from the decoder
		gate_outputs: gate outputs from the decoder
		alignments: sequence of attention weights from the decoder
		'''
		# Prepend the "go" frame so decoder step t is fed parsed target step t-1
		# (step 0 is fed the go frame).
		go_frame = self.get_go_frame(memory).unsqueeze(0)
		decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
		decoder_inputs = torch.cat((go_frame, decoder_inputs), dim=0)
		decoder_inputs = self.prenet(decoder_inputs)

		# The result of get_mask_from_lengths is inverted before being passed
		# on; presumably it marks valid positions as True, so this yields a
		# padding mask — TODO confirm against its definition.
		self.initialize_decoder_states(
			memory, mask=~get_mask_from_lengths(memory_lengths))

		mel_outputs, gate_outputs, alignments = [], [], []
		# One decoder step per target step: the last prenet output has no
		# successor to predict, hence the -1.
		n_steps = decoder_inputs.size(0) - 1
		for step in range(n_steps):
			mel_output, gate_output, attention_weights = self.decode(
				decoder_inputs[step])
			mel_outputs.append(mel_output.squeeze(1))
			# squeeze(1) rather than a bare squeeze(): with a batch of one, a
			# bare squeeze() would also drop the batch axis and produce a 0-d
			# tensor. Assumes gate_output is (B, 1) — TODO confirm in decode().
			gate_outputs.append(gate_output.squeeze(1))
			alignments.append(attention_weights)
		mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
			mel_outputs, gate_outputs, alignments)

		return mel_outputs, gate_outputs, alignments
Esempio n. 2
0
	def parse_output(self, outputs, output_lengths=None):
		'''Fill padded frames of the model outputs in place.

		PARAMS
		------
		outputs: indexable of tensors. Entries [0] and [1] are masked as
			(B, num_mels, T) and entry [2] as gate energies
			(B, T//n_frames_per_step) — shapes taken from the original
			comments, TODO confirm.
		output_lengths: per-item output lengths. Masking is skipped when this
			is None or when self.mask_padding is falsy.

		RETURNS
		-------
		outputs: the same object that was passed in (modified in place).
		'''
		if not (self.mask_padding and output_lengths is not None):
			return outputs

		# (B, T). Inverted so True marks the positions to fill; presumably
		# get_mask_from_lengths returns True for valid frames — TODO confirm.
		pad_mask = ~get_mask_from_lengths(output_lengths, True)
		# Repeat the mask across the mel-channel axis: (B, T) -> (B, num_mels, T).
		mel_pad_mask = pad_mask.unsqueeze(1).expand(-1, self.num_mels, -1)

		# Writing through .data keeps these in-place fills out of autograd.
		outputs[0].data.masked_fill_(mel_pad_mask, 0.0)
		outputs[1].data.masked_fill_(mel_pad_mask, 0.0)

		# Gate energies have one entry per decoder step, i.e. every
		# n_frames_per_step-th frame. A strided view selects exactly those
		# frames without building an index tensor (and without shadowing the
		# builtin `slice`). Padded steps get a large positive energy
		# (presumably a confident "stop" logit).
		gate_pad_mask = pad_mask[:, ::self.n_frames_per_step]
		outputs[2].data.masked_fill_(gate_pad_mask, 1e3)

		return outputs
Esempio n. 3
0
    def forward(self, memory, decoder_inputs, ss_prob, memory_lengths):
        ''' Decoder forward pass for training, with optional scheduled sampling.

        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        ss_prob: Scheduled-sampling probability. When hps.Scheduled_Sampling
            is enabled, from the second decoder step on, each batch element
            independently receives the prenet-processed previous prediction
            instead of the ground-truth input with this probability.
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        '''
        # Prepend the "go" frame so decoder step t is fed parsed target step
        # t-1 (step 0 is fed the go frame).
        go_frame = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((go_frame, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        # The mask is inverted before being passed on; presumably
        # get_mask_from_lengths marks valid positions as True — TODO confirm.
        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        batch_size = memory.shape[0]
        mel_outputs, gate_outputs, alignments = [], [], []
        n_steps = decoder_inputs.size(0) - 1
        for step in range(n_steps):
            teacher_input = decoder_inputs[step]
            # Step 0 has no previous prediction, so it is always teacher-forced.
            if hps.Scheduled_Sampling and step > 0:
                sample_prob = memory.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < ss_prob
                # The previous prediction is detached, so no gradient flows
                # back into the earlier step through the sampled input.
                pre_pred = self.prenet(mel_output.detach())
                # Reduce once instead of twice per step.
                n_sampled = int(sample_mask.sum())
                if n_sampled == 0:
                    decoder_input = teacher_input
                elif n_sampled == batch_size:
                    decoder_input = pre_pred
                else:
                    # Mixed batch: copy the teacher input, then overwrite the
                    # sampled rows with the model's own prediction.
                    sample_ind = sample_mask.nonzero().view(-1)
                    decoder_input = teacher_input.clone()
                    decoder_input.index_copy_(
                        0, sample_ind, pre_pred.index_select(0, sample_ind))
            else:
                decoder_input = teacher_input

            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs.append(mel_output.squeeze(1))
            # squeeze(1) rather than a bare squeeze(): with a batch of one, a
            # bare squeeze() would also drop the batch axis. Assumes
            # gate_output is (B, 1) — TODO confirm in decode().
            gate_outputs.append(gate_output.squeeze(1))
            alignments.append(attention_weights)
        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments