def forward(self, dec_inp, seq_lens: Optional[torch.Tensor] = None,
            conditioning: Optional[torch.Tensor] = None):
    if not self.embed_input:
        inp = dec_inp
        assert seq_lens is not None
        mask = mask_from_lens(seq_lens).unsqueeze(2)
    else:
        inp = self.word_emb(dec_inp)
        # [bsz x L x 1]
        mask = (dec_inp != self.padding_idx).unsqueeze(2)

    pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
    pos_emb = self.pos_emb(pos_seq) * mask
    if conditioning is not None:
        out = self.drop(inp + pos_emb + conditioning)
    else:
        out = self.drop(inp + pos_emb)

    for layer in self.layers:
        out = layer(out, mask=mask)

    # out = self.drop(out)
    return out, mask

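# `mask_from_lens` is used throughout this section but defined elsewhere.
# A minimal sketch of its assumed semantics (True at valid positions,
# False at padding), consistent with every call site above and below;
# the name `mask_from_lens_sketch` is ours, not from the source:
def mask_from_lens_sketch(lens, max_len=None):
    # lens: [B] tensor of sequence lengths
    if max_len is None:
        max_len = int(lens.max())
    ids = torch.arange(max_len, device=lens.device)
    # [B, max_len] bool mask; mask[b, t] is True iff t < lens[b]
    return ids < lens.unsqueeze(1)
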
def parse_output(self, outputs, output_lengths):
    # type: (List[Tensor], Tensor) -> List[Tensor]
    if self.mask_padding and output_lengths is not None:
        mask = ~mask_from_lens(output_lengths)
        mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)

        outputs[0].masked_fill_(mask, 0.0)
        outputs[1].masked_fill_(mask, 0.0)
        outputs[2].masked_fill_(mask[:, 0, :], 1e3)  # gate energies

    return outputs

def parse_output(self, outputs, output_lengths=None):
    if self.mask_padding and output_lengths is not None:
        mask = ~mask_from_lens(output_lengths)
        mel_mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
        mel_mask = mel_mask.permute(1, 0, 2)

        if outputs[0] is not None:
            float_mask = (~mask).float().unsqueeze(1)
            outputs[0] = outputs[0] * float_mask
        outputs[1].data.masked_fill_(mel_mask, 0.0)
        outputs[2].data.masked_fill_(mel_mask, 0.0)
        outputs[3].data.masked_fill_(mel_mask[:, 0, :], 1e3)  # gate energies

    return outputs

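# Why 1e3 for the gate energies above: the gate outputs are pre-sigmoid
# logits, and filling padded frames with +1e3 pins sigmoid(gate) to ~1.0
# there, so the stop token reads as "finished" throughout the padding:
#     >>> torch.sigmoid(torch.tensor([0.0, 1e3]))
#     tensor([0.5000, 1.0000])
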
def forward(self, model_out, targets, is_training=True, meta_agg='mean'):
    mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred = model_out
    mel_tgt, dur_tgt, dur_lens, pitch_tgt = targets
    mel_tgt.requires_grad = False
    # (B,H,T) => (B,T,H)
    mel_tgt = mel_tgt.transpose(1, 2)

    dur_mask = mask_from_lens(dur_lens, max_len=dur_tgt.size(1))
    log_dur_tgt = torch.log(dur_tgt.float() + 1)
    loss_fn = F.mse_loss
    dur_pred_loss = loss_fn(log_dur_pred, log_dur_tgt, reduction='none')
    dur_pred_loss = (dur_pred_loss * dur_mask).sum() / dur_mask.sum()

    ldiff = mel_tgt.size(1) - mel_out.size(1)
    mel_out = F.pad(mel_out, (0, 0, 0, ldiff, 0, 0), value=0.0)
    # treat exactly-zero target bins as padding when averaging the mel loss
    mel_mask = mel_tgt.ne(0).float()
    mel_loss = loss_fn(mel_out, mel_tgt, reduction='none')
    mel_loss = (mel_loss * mel_mask).sum() / mel_mask.sum()

    ldiff = pitch_tgt.size(1) - pitch_pred.size(1)
    pitch_pred = F.pad(pitch_pred, (0, ldiff, 0, 0), value=0.0)
    pitch_loss = F.mse_loss(pitch_tgt, pitch_pred, reduction='none')
    pitch_loss = (pitch_loss * dur_mask).sum() / dur_mask.sum()

    loss = (mel_loss
            + pitch_loss * self.pitch_predictor_loss_scale
            + dur_pred_loss * self.dur_predictor_loss_scale)

    meta = {
        'loss': loss.clone().detach(),
        'mel_loss': mel_loss.clone().detach(),
        'duration_predictor_loss': dur_pred_loss.clone().detach(),
        'pitch_loss': pitch_loss.clone().detach(),
        'dur_error': (torch.abs(dur_pred - dur_tgt).sum()
                      / dur_mask.sum()).detach(),
    }

    assert meta_agg in ('sum', 'mean')
    if meta_agg == 'sum':
        bsz = mel_out.size(0)
        meta = {k: v * bsz for k, v in meta.items()}
    return loss, meta

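# The duration term above regresses log(dur + 1); at inference the model
# inverts this with exp(log_dur_pred) - 1 and clamps (see the FastPitch
# forward pass below). Round-trip sketch of that relationship:
#     >>> dur = torch.tensor([0., 3., 7.])
#     >>> torch.allclose(torch.exp(torch.log(dur + 1)) - 1, dur)
#     True
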
def forward(self, memory, decoder_inputs, memory_lengths):
    """ Decoder forward pass for training

    PARAMS
    ------
    memory: Encoder outputs
    decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
    memory_lengths: Encoder output lengths for attention masking

    RETURNS
    -------
    mel_outputs: mel outputs from the decoder
    gate_outputs: gate outputs from the decoder
    alignments: sequence of attention weights from the decoder
    """
    decoder_input = self.get_go_frame(memory).unsqueeze(0)
    decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
    decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
    decoder_inputs = self.prenet(decoder_inputs)

    mask = ~mask_from_lens(memory_lengths)
    (attention_hidden,
     attention_cell,
     decoder_hidden,
     decoder_cell,
     attention_weights,
     attention_weights_cum,
     attention_context,
     processed_memory) = self.initialize_decoder_states(memory)

    mel_outputs, gate_outputs, alignments = [], [], []
    while len(mel_outputs) < decoder_inputs.size(0) - 1:
        decoder_input = decoder_inputs[len(mel_outputs)]
        (mel_output,
         gate_output,
         attention_hidden,
         attention_cell,
         decoder_hidden,
         decoder_cell,
         attention_weights,
         attention_weights_cum,
         attention_context) = self.decode(decoder_input,
                                          attention_hidden,
                                          attention_cell,
                                          decoder_hidden,
                                          decoder_cell,
                                          attention_weights,
                                          attention_weights_cum,
                                          attention_context,
                                          memory,
                                          processed_memory,
                                          mask)

        mel_outputs += [mel_output.squeeze(1)]
        gate_outputs += [gate_output.squeeze()]
        alignments += [attention_weights]

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
        torch.stack(mel_outputs),
        torch.stack(gate_outputs),
        torch.stack(alignments))
    return mel_outputs, gate_outputs, alignments

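# NOTE: get_go_frame() above supplies the first decoder input (typically an
# all-zeros frame in Tacotron 2-style models), and the teacher-forced
# targets are shifted right by one, so the decoder at step t consumes
# ground-truth frame t-1 while predicting frame t.
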
def forward(self, dec_inp, seq_lens=None):
    if self.word_emb is None:
        inp = dec_inp
        mask = mask_from_lens(seq_lens).unsqueeze(2)
    else:
        inp = self.word_emb(dec_inp)
        # [bsz x L x 1]
        mask = (dec_inp != self.padding_idx).unsqueeze(2)

    pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
    pos_emb = self.pos_emb(pos_seq) * mask
    out = self.drop(inp + pos_emb)

    for layer in self.layers:
        out = layer(out, mask=mask)

    # out = self.drop(out)
    return out, mask

def forward(self, model_out, targets, is_training=True, meta_agg='mean'):
    (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred, pitch_tgt,
     energy_pred, energy_tgt, attn_soft, attn_hard, attn_dur,
     attn_logprob) = model_out
    (mel_tgt, in_lens, out_lens) = targets
    dur_tgt = attn_dur
    dur_lens = in_lens

    mel_tgt.requires_grad = False
    # (B,H,T) => (B,T,H)
    mel_tgt = mel_tgt.transpose(1, 2)

    dur_mask = mask_from_lens(dur_lens, max_len=dur_tgt.size(1))
    log_dur_tgt = torch.log(dur_tgt.float() + 1)
    loss_fn = F.mse_loss
    dur_pred_loss = loss_fn(log_dur_pred, log_dur_tgt, reduction='none')
    dur_pred_loss = (dur_pred_loss * dur_mask).sum() / dur_mask.sum()

    ldiff = mel_tgt.size(1) - mel_out.size(1)
    mel_out = F.pad(mel_out, (0, 0, 0, ldiff, 0, 0), value=0.0)
    mel_mask = mel_tgt.ne(0).float()
    mel_loss = loss_fn(mel_out, mel_tgt, reduction='none')
    mel_loss = (mel_loss * mel_mask).sum() / mel_mask.sum()

    ldiff = pitch_tgt.size(2) - pitch_pred.size(2)
    pitch_pred = F.pad(pitch_pred, (0, ldiff, 0, 0, 0, 0), value=0.0)
    pitch_loss = F.mse_loss(pitch_tgt, pitch_pred, reduction='none')
    pitch_loss = (pitch_loss * dur_mask.unsqueeze(1)).sum() / dur_mask.sum()

    if energy_pred is not None:
        energy_pred = F.pad(energy_pred, (0, ldiff, 0, 0), value=0.0)
        energy_loss = F.mse_loss(energy_tgt, energy_pred, reduction='none')
        energy_loss = (energy_loss * dur_mask).sum() / dur_mask.sum()
    else:
        energy_loss = 0

    # Attention loss
    attn_loss = self.attn_ctc_loss(attn_logprob, in_lens, out_lens)

    loss = (mel_loss
            + dur_pred_loss * self.dur_predictor_loss_scale
            + pitch_loss * self.pitch_predictor_loss_scale
            + energy_loss * self.energy_predictor_loss_scale
            + attn_loss * self.attn_loss_scale)

    meta = {
        'loss': loss.clone().detach(),
        'mel_loss': mel_loss.clone().detach(),
        'duration_predictor_loss': dur_pred_loss.clone().detach(),
        'pitch_loss': pitch_loss.clone().detach(),
        'attn_loss': attn_loss.clone().detach(),
        'dur_error': (torch.abs(dur_pred - dur_tgt).sum()
                      / dur_mask.sum()).detach(),
    }
    if energy_pred is not None:
        meta['energy_loss'] = energy_loss.clone().detach()

    assert meta_agg in ('sum', 'mean')
    if meta_agg == 'sum':
        bsz = mel_out.size(0)
        meta = {k: v * bsz for k, v in meta.items()}
    return loss, meta

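# `self.attn_ctc_loss` is defined elsewhere. A heavily simplified sketch of
# the usual recipe (per "One TTS Alignment To Rule Them All"): treat each
# mel frame's attention over text positions as a CTC emission and use the
# identity sequence [1..text_len] as the target, which rewards monotonic
# alignments that visit every symbol. This is our approximation, not the
# source implementation:
def attn_ctc_loss_sketch(attn_logprob, in_lens, out_lens, blank_logprob=-1.0):
    # attn_logprob: [B, 1, T_mel, T_text] unnormalized attention scores
    ctc = torch.nn.CTCLoss(blank=0, zero_infinity=True)
    # prepend a "blank" class along the text dimension
    padded = F.pad(attn_logprob, (1, 0), value=blank_logprob)
    cost = 0.0
    for b in range(padded.size(0)):
        t_text, t_mel = int(in_lens[b]), int(out_lens[b])
        logprob = padded[b, 0, :t_mel, :t_text + 1].log_softmax(dim=-1)
        target = torch.arange(1, t_text + 1).unsqueeze(0)  # [1, T_text]
        cost = cost + ctc(logprob.unsqueeze(1),            # [T_mel, 1, C]
                          target,
                          input_lengths=torch.tensor([t_mel]),
                          target_lengths=torch.tensor([t_text]))
    return cost / padded.size(0)
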
def forward(self, inputs, use_gt_pitch=True, pace=1.0, max_duration=75):
    (inputs, input_lens, mel_tgt, mel_lens, pitch_dense, energy_dense,
     speaker, attn_prior, audiopaths) = inputs

    mel_max_len = mel_tgt.size(2)

    # Calculate speaker embedding
    if self.speaker_emb is None:
        spk_emb = 0
    else:
        spk_emb = self.speaker_emb(speaker).unsqueeze(1)
        spk_emb.mul_(self.speaker_emb_weight)

    # Input FFT
    enc_out, enc_mask = self.encoder(inputs, conditioning=spk_emb)

    # Alignment
    text_emb = self.encoder.word_emb(inputs)

    # make sure to do the alignments before folding
    attn_mask = mask_from_lens(input_lens)[..., None] == 0
    # attn_mask should be 1 for unused timesteps in the text_enc_w_spkvec tensor
    attn_soft, attn_logprob = self.attention(
        mel_tgt, text_emb.permute(0, 2, 1), mel_lens, attn_mask,
        key_lens=input_lens, keys_encoded=enc_out, attn_prior=attn_prior)

    attn_hard = self.binarize_attention_parallel(
        attn_soft, input_lens, mel_lens)

    # Viterbi --> durations
    attn_hard_dur = attn_hard.sum(2)[:, 0, :]
    dur_tgt = attn_hard_dur
    assert torch.all(torch.eq(dur_tgt.sum(dim=1), mel_lens))

    # Predict durations
    log_dur_pred = self.duration_predictor(enc_out, enc_mask).squeeze(-1)
    dur_pred = torch.clamp(torch.exp(log_dur_pred) - 1, 0, max_duration)

    # Predict pitch
    pitch_pred = self.pitch_predictor(enc_out, enc_mask).permute(0, 2, 1)

    # Average pitch over characters
    pitch_tgt = average_pitch(pitch_dense, dur_tgt)

    if use_gt_pitch and pitch_tgt is not None:
        pitch_emb = self.pitch_emb(pitch_tgt)
    else:
        pitch_emb = self.pitch_emb(pitch_pred)
    enc_out = enc_out + pitch_emb.transpose(1, 2)

    # Predict energy
    if self.energy_conditioning:
        energy_pred = self.energy_predictor(enc_out, enc_mask).squeeze(-1)

        # Average energy over characters
        energy_tgt = average_pitch(energy_dense.unsqueeze(1), dur_tgt)
        energy_tgt = torch.log(1.0 + energy_tgt)

        energy_emb = self.energy_emb(energy_tgt)
        energy_tgt = energy_tgt.squeeze(1)
        enc_out = enc_out + energy_emb.transpose(1, 2)
    else:
        energy_pred = None
        energy_tgt = None

    len_regulated, dec_lens = regulate_len(
        dur_tgt, enc_out, pace, mel_max_len)

    # Output FFT
    dec_out, dec_mask = self.decoder(len_regulated, dec_lens)
    mel_out = self.proj(dec_out)
    return (mel_out, dec_mask, dur_pred, log_dur_pred, pitch_pred,
            pitch_tgt, energy_pred, energy_tgt, attn_soft, attn_hard,
            attn_hard_dur, attn_logprob)

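# `regulate_len` and `average_pitch` are imported helpers. Loop-based
# sketches of their assumed semantics follow; the real implementations are
# presumably vectorized, and these names are ours:
def regulate_len_sketch(durations, enc_out, pace=1.0, mel_max_len=None):
    # durations: [B, L] frames per symbol; enc_out: [B, L, H]
    # Repeat each encoder frame by its (pace-scaled) duration.
    reps = torch.round(durations.float() / pace).long()
    dec_lens = reps.sum(dim=1)
    out = enc_out.new_zeros(enc_out.size(0), int(dec_lens.max()),
                            enc_out.size(2))
    for b in range(enc_out.size(0)):
        rep = torch.repeat_interleave(enc_out[b], reps[b], dim=0)
        out[b, :rep.size(0)] = rep
    if mel_max_len is not None:
        out = out[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
    return out, dec_lens


def average_pitch_sketch(pitch_dense, durs):
    # pitch_dense: [B, F, T_mel] frame-level pitch; durs: [B, L]
    # Average the voiced (nonzero) frames inside each symbol's span.
    bsz, n_formants, _ = pitch_dense.shape
    ends = durs.long().cumsum(dim=1)
    starts = F.pad(ends[:, :-1], (1, 0))
    out = pitch_dense.new_zeros(bsz, n_formants, durs.size(1))
    for b in range(bsz):
        for i in range(durs.size(1)):
            seg = pitch_dense[b, :, starts[b, i]:ends[b, i]]
            voiced = (seg != 0).sum(dim=1).clamp(min=1)
            out[b, :, i] = seg.sum(dim=1) / voiced
    return out
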
def infer(self, memory, memory_lengths):
    """ Decoder inference

    PARAMS
    ------
    memory: Encoder outputs
    memory_lengths: Encoder output lengths for attention masking

    RETURNS
    -------
    mel_outputs: mel outputs from the decoder
    gate_outputs: gate outputs from the decoder
    alignments: sequence of attention weights from the decoder
    mel_lengths: lengths of the decoded mel-spectrograms
    """
    decoder_input = self.get_go_frame(memory)

    mask = ~mask_from_lens(memory_lengths)
    (attention_hidden,
     attention_cell,
     decoder_hidden,
     decoder_cell,
     attention_weights,
     attention_weights_cum,
     attention_context,
     processed_memory) = self.initialize_decoder_states(memory)

    mel_lengths = torch.zeros([memory.size(0)], dtype=torch.int32).cuda()
    not_finished = torch.ones([memory.size(0)], dtype=torch.int32).cuda()

    mel_outputs, gate_outputs, alignments = (
        torch.zeros(1), torch.zeros(1), torch.zeros(1))
    first_iter = True
    while True:
        decoder_input = self.prenet(decoder_input)
        (mel_output,
         gate_output,
         attention_hidden,
         attention_cell,
         decoder_hidden,
         decoder_cell,
         attention_weights,
         attention_weights_cum,
         attention_context) = self.decode(decoder_input,
                                          attention_hidden,
                                          attention_cell,
                                          decoder_hidden,
                                          decoder_cell,
                                          attention_weights,
                                          attention_weights_cum,
                                          attention_context,
                                          memory,
                                          processed_memory,
                                          mask)

        if first_iter:
            mel_outputs = mel_output.unsqueeze(0)
            gate_outputs = gate_output
            alignments = attention_weights
            first_iter = False
        else:
            mel_outputs = torch.cat(
                (mel_outputs, mel_output.unsqueeze(0)), dim=0)
            gate_outputs = torch.cat((gate_outputs, gate_output), dim=0)
            alignments = torch.cat((alignments, attention_weights), dim=0)

        dec = torch.le(torch.sigmoid(gate_output),
                       self.gate_threshold).to(torch.int32).squeeze(1)

        not_finished = not_finished * dec
        mel_lengths += not_finished

        if self.early_stopping and torch.sum(not_finished) == 0:
            break
        if len(mel_outputs) == self.max_decoder_steps:
            print("Warning! Reached max decoder steps")
            break

        decoder_input = mel_output

    # NOTE(Adrian): This makes it consistent with training-time dims
    # (ML x B) x L --> ML x B x L
    mel_len, bsz, _ = mel_outputs.size()
    alignments = alignments.view(mel_len, bsz, -1)

    mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
        mel_outputs, gate_outputs, alignments)
    return mel_outputs, gate_outputs, alignments, mel_lengths

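# The not_finished / mel_lengths bookkeeping in infer() counts, per batch
# item, how many frames were emitted before its gate first crossed the
# threshold. Tiny worked example of that bookkeeping with two sequences,
# where stop = 1 means the gate fired at that step:
#     >>> not_finished = torch.ones(2, dtype=torch.int32)
#     >>> mel_lengths = torch.zeros(2, dtype=torch.int32)
#     >>> for stop in ([0, 0], [1, 0], [1, 1]):
#     ...     dec = 1 - torch.tensor(stop, dtype=torch.int32)
#     ...     not_finished = not_finished * dec
#     ...     mel_lengths += not_finished
#     >>> mel_lengths
#     tensor([1, 2], dtype=torch.int32)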