def infer(self, memory, memory_lengths):
    """ Decoder inference
    PARAMS
    ------
    memory: Encoder outputs
    memory_lengths: Encoder output lengths, used for attention masking
        when the batch has more than one sample

    RETURNS
    -------
    mel_outputs: mel outputs from the decoder
    gate_outputs: gate outputs from the decoder
    alignments: sequence of attention weights from the decoder
    mel_lengths: per-sample count of frames generated before the stop
        gate fired
    """
    decoder_input = self.get_go_frame(memory)

    # A single-sample batch needs no attention mask.
    if memory.size(0) > 1:
        mask = ~get_mask_from_lengths(memory_lengths)
    else:
        mask = None

    self.initialize_decoder_states(memory, mask=mask)

    # FIX: allocate bookkeeping tensors on the same device as the input
    # instead of keying off torch.cuda.is_available(). The old code put
    # them on the default CUDA device whenever CUDA existed, which broke
    # (`not_finished * dec` mixed devices) when `memory` lived on the
    # CPU or on a non-default GPU.
    device = memory.device
    mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32,
                              device=device)
    not_finished = torch.ones([memory.size(0)], dtype=torch.int32,
                              device=device)

    mel_outputs, gate_outputs, alignments = [], [], []
    while True:
        decoder_input = self.prenet(decoder_input, inference=True)
        mel_output, gate_output, alignment = self.decode(decoder_input)

        # dec[i] == 1 while sample i's stop probability is still at or
        # below the gate threshold.
        dec = torch.le(torch.sigmoid(gate_output.data),
                       self.gate_threshold).to(torch.int32).squeeze(1)
        not_finished = not_finished * dec
        mel_lengths += not_finished

        # NOTE: the frame that fires the gate is intentionally not
        # appended -- the break happens before the append, consistent
        # with mel_lengths counting only unfinished steps.
        if self.early_stopping and torch.sum(not_finished) == 0:
            break

        mel_outputs += [mel_output.squeeze(1)]
        gate_outputs += [gate_output]
        alignments += [alignment]

        if len(mel_outputs) == self.max_decoder_steps:
            logging.warning("Reached max decoder steps %d.",
                            self.max_decoder_steps)
            break

        decoder_input = mel_output

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
        mel_outputs, gate_outputs, alignments)

    return mel_outputs, gate_outputs, alignments, mel_lengths
def _loss(
    self,
    mel_out,
    mel_out_postnet,
    gate_out,
    mel_target,
    gate_target,
    target_len,
    seq_len,
):
    """Masked Tacotron2 loss: MSE on both mel predictions plus
    BCE-with-logits on the stop gate.

    Predictions are first aligned to the reference time length
    (truncated or padded), then padded positions are overwritten
    in-place so they contribute nothing to the loss.

    NOTE(review): `seq_len` is accepted but never read here -- kept
    for signature compatibility; confirm with callers before removing.
    """
    # Targets are constants with respect to this loss.
    mel_target.requires_grad = False
    gate_target.requires_grad = False
    gate_target = gate_target.view(-1, 1)

    ref_len = mel_target.shape[2]
    pred_len = mel_out.shape[2]
    if ref_len < pred_len:
        # Prediction ran longer than the reference: truncate it.
        mel_out = mel_out.narrow(2, 0, ref_len)
        mel_out_postnet = mel_out_postnet.narrow(2, 0, ref_len)
        gate_out = gate_out.narrow(1, 0, ref_len).contiguous()
    elif ref_len > pred_len:
        # Prediction fell short: pad mels with the pad value and the
        # gate logits with a large positive value (sigmoid ~ 1).
        shortfall = ref_len - pred_len
        mel_out = pad(mel_out, (0, shortfall), value=self.pad_value)
        mel_out_postnet = pad(mel_out_postnet, (0, shortfall),
                              value=self.pad_value)
        gate_out = pad(gate_out, (0, shortfall), value=1e3)
        ref_len = mel_out.shape[2]

    # Broadcast the (batch, time) padding mask across the mel channels.
    mask = ~get_mask_from_lengths(target_len, max_len=ref_len)
    mask = mask.expand(mel_target.shape[1], mask.size(0), mask.size(1))
    mask = mask.permute(1, 0, 2)

    # In-place fills on .data keep the masking out of the autograd
    # graph; prediction and target agree at masked positions, so their
    # gradient contribution is zero.
    mel_out.data.masked_fill_(mask, self.pad_value)
    mel_out_postnet.data.masked_fill_(mask, self.pad_value)
    gate_out.data.masked_fill_(mask[:, 0, :], 1e3)
    gate_out = gate_out.view(-1, 1)

    mse = nn.MSELoss()
    mel_loss = mse(mel_out, mel_target) + mse(mel_out_postnet, mel_target)
    gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
    return mel_loss + gate_loss
def forward(self, memory, decoder_inputs, memory_lengths):
    """ Decoder forward pass for training
    PARAMS
    ------
    memory: Encoder outputs
    decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
    memory_lengths: Encoder output lengths for attention masking.

    RETURNS
    -------
    mel_outputs: mel outputs from the decoder
    gate_outputs: gate outputs from the decoder
    alignments: sequence of attention weights from the decoder
    """
    # Prepend the <go> frame so that step t consumes frame t-1.
    go_frame = self.get_go_frame(memory).unsqueeze(0)
    teacher_frames = self.parse_decoder_inputs(decoder_inputs)
    teacher_frames = torch.cat((go_frame, teacher_frames), dim=0)
    teacher_frames = self.prenet(teacher_frames)

    self.initialize_decoder_states(
        memory, mask=~get_mask_from_lengths(memory_lengths))

    mel_outputs, gate_outputs, alignments = [], [], []
    # T+1 inputs drive T decode steps; the last teacher frame is never
    # fed back in.
    for step in range(teacher_frames.size(0) - 1):
        mel_frame, gate_logit, attn_weights = self.decode(
            teacher_frames[step])
        mel_outputs.append(mel_frame.squeeze(1))
        # NOTE(review): .squeeze() (all dims) also collapses the batch
        # dim when batch size is 1, unlike infer()'s handling of the
        # gate. Preserved as-is -- confirm parse_decoder_outputs copes
        # with a 0-d gate at batch size 1.
        gate_outputs.append(gate_logit.squeeze())
        alignments.append(attn_weights)

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
        mel_outputs, gate_outputs, alignments)
    return mel_outputs, gate_outputs, alignments