def forward(self, x, z, output_lengths=None):
    x = self.start(x)  # [B, in_dim, T//scale_factors] -> [B, self.decoder_dims[0], T//scale_factors]
    if output_lengths is not None:
        mask = get_mask_from_lengths(output_lengths).unsqueeze(1)
        x.masked_fill_(~mask, 0.0)
    for gblock in self.Gblocks:
        x = gblock(x, z, output_lengths=output_lengths)
    if output_lengths is not None:
        # rescale lengths to match the (possibly upsampled) time dimension
        scale_factor = x.shape[2] / output_lengths.max()
        if scale_factor != 1.0:
            output_lengths = (output_lengths.float() * scale_factor).long()
        mask = ~get_mask_from_lengths(output_lengths).unsqueeze(1)
        x.masked_fill_(mask, 0.0)
    x = self.end(x)  # [B, 1, T]
    x = x.tanh()
    if output_lengths is not None:
        x.masked_fill_(mask, 0.0)
    return x  # [B, 1, T]
def _make_masks(ilens, olens):
    """Make masks indicating non-padded part.

    Args:
        ilens (LongTensor or List): Batch of lengths (B,).
        olens (LongTensor or List): Batch of lengths (B,).

    Returns:
        Tensor: Mask tensor indicating non-padded part.
            dtype=torch.uint8 in PyTorch 1.1 and earlier,
            dtype=torch.bool in PyTorch 1.2 and later.

    Examples:
        >>> ilens, olens = [5, 2], [8, 5]
        >>> _make_masks(ilens, olens)
        tensor([[[1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 1]],
                [[1, 1, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [1, 1, 0, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0],
                 [0, 0, 0, 0, 0]]], dtype=torch.uint8)
    """
    in_masks = get_mask_from_lengths(ilens)    # (B, T_in)
    out_masks = get_mask_from_lengths(olens)   # (B, T_out)
    return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2)  # (B, T_out, T_in)
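# Every snippet in this file leans on get_mask_from_lengths. For reference, a
# minimal sketch of the assumed helper (True marks valid frames, False marks
# padding), consistent with the masked_fill_(~mask, ...) usage above:
import torch

def get_mask_from_lengths(lengths, max_len=None):
    # lengths: LongTensor [B]; returns BoolTensor [B, max_len],
    # True at valid (non-padded) positions.
    lengths = torch.as_tensor(lengths)
    if max_len is None:
        max_len = int(lengths.max().item())
    ids = torch.arange(max_len, device=lengths.device)
    return ids[None, :] < lengths[:, None]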
def forward(self, cond_inp, output_lengths, cond_lens=None):  # [B, seq_len, dim], int, [B]
    batch_size, enc_T, enc_dim = cond_inp.shape

    # get Random Position Offset (this *might* allow better distance generalisation)
    #trandint = torch.randint(10000, (1,), device=cond_inp.device, dtype=cond_inp.dtype)

    # get Query from Positional Encoding
    dec_T_max = output_lengths.max().item()
    dec_pos_emb = torch.arange(0, dec_T_max, device=cond_inp.device, dtype=cond_inp.dtype)  # + trandint
    if hasattr(self, 'pos_embedding_q'):
        dec_pos_emb = self.pos_embedding_q(
            dec_pos_emb.clamp(0, self.pos_embedding_q_max - 1).long()
        )[None, ...].repeat(cond_inp.size(0), 1, 1)  # [B, dec_T, enc_dim]
    elif hasattr(self, 'positional_embedding'):
        dec_pos_emb = self.positional_embedding(dec_pos_emb, bsz=cond_inp.size(0))  # [B, dec_T, enc_dim]
        if not self.merged_pos_enc:
            dec_pos_emb = dec_pos_emb.repeat(1, 1, self.head_num)
    if output_lengths is not None:  # masking for batches
        dec_mask = get_mask_from_lengths(output_lengths).unsqueeze(2)  # [B, dec_T, 1]
        dec_pos_emb = dec_pos_emb * dec_mask  # [B, dec_T, enc_dim] * [B, dec_T, 1] -> [B, dec_T, enc_dim]
    q = dec_pos_emb  # [B, dec_T, enc_dim]

    # get Key/Value from Encoder Outputs
    k = v = cond_inp  # [B, enc_T, enc_dim]

    # (optional) add position encoding to Encoder outputs
    if hasattr(self, 'enc_positional_embedding'):
        enc_pos_emb = torch.arange(0, enc_T, device=cond_inp.device, dtype=cond_inp.dtype)  # + trandint
        if hasattr(self, 'pos_embedding_kv'):
            enc_pos_emb = self.pos_embedding_kv(
                enc_pos_emb.clamp(0, self.pos_embedding_kv_max - 1).long()
            )[None, ...].repeat(cond_inp.size(0), 1, 1)  # [B, enc_T, enc_dim]
        elif hasattr(self, 'enc_positional_embedding'):
            enc_pos_emb = self.enc_positional_embedding(enc_pos_emb, bsz=cond_inp.size(0))  # [B, enc_T, enc_dim]
        if self.pos_enc_k:
            k = k + enc_pos_emb
        if self.pos_enc_v:
            v = v + enc_pos_emb

    q = q.transpose(0, 1)  # [B, dec_T, enc_dim] -> [dec_T, B, enc_dim]
    k = k.transpose(0, 1)  # [B, enc_T, enc_dim] -> [enc_T, B, enc_dim]
    v = v.transpose(0, 1)  # [B, enc_T, enc_dim] -> [enc_T, B, enc_dim]

    output = self.MH_Transformer(
        k, q,  # src (encoder outputs), tgt (positional-embedding queries)
        src_key_padding_mask=~get_mask_from_lengths(cond_lens).bool() if (cond_lens is not None) else None,
        tgt_key_padding_mask=~get_mask_from_lengths(output_lengths).bool(),
        memory_key_padding_mask=~get_mask_from_lengths(cond_lens).bool() if (cond_lens is not None) else None,
    )  # [dec_T, B, enc_dim]

    output = output.transpose(0, 1)  # [dec_T, B, enc_dim] -> [B, dec_T, enc_dim]
    output = output + self.o_residual_weights * dec_pos_emb
    # nn.Transformer does not expose attention weights, so the 3D pad mask is
    # returned in their place.
    attention_scores = get_mask_3d(output_lengths, cond_lens) if (cond_lens is not None) else None  # [B, dec_T, enc_T]
    if output_lengths is not None:
        output = output * dec_mask  # [B, dec_T, enc_dim] * [B, dec_T, 1]
    return output, attention_scores
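# The MH_Transformer call above follows the stock torch.nn.Transformer
# convention: time-first [T, B, E] tensors, with True marking padded positions
# in the *_key_padding_mask arguments. A minimal sketch of that call pattern
# (dimensions are illustrative):
import torch
import torch.nn as nn

model = nn.Transformer(d_model=8, nhead=2, num_encoder_layers=1, num_decoder_layers=1)
src = torch.randn(7, 3, 8)  # [enc_T, B, E] (the key/value sequence)
tgt = torch.randn(5, 3, 8)  # [dec_T, B, E] (the positional-embedding queries)
src_pad = torch.zeros(3, 7, dtype=torch.bool)
src_pad[:, 5:] = True       # last two encoder steps are padding
out = model(src, tgt, src_key_padding_mask=src_pad, memory_key_padding_mask=src_pad)
print(out.shape)            # torch.Size([5, 3, 8])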
def inference(self, text, speaker_ids, text_lengths=None, sigma=1.0):
    assert not torch.isnan(text).any(), 'text has NaN values.'
    embedded_text = self.embedding(text).transpose(1, 2)  # [B, embed, sequence]
    assert not torch.isnan(embedded_text).any(), 'embedded_text has NaN values.'
    encoder_outputs = self.encoder.inference(embedded_text, speaker_ids=speaker_ids)  # [B, enc_T, enc_dim]
    assert not torch.isnan(encoder_outputs).any(), 'encoder_outputs has NaN values.'

    # predict length of each input
    enc_out_mask = get_mask_from_lengths(text_lengths) if (text_lengths is not None) else None
    encoder_lengths = self.length_predictor(encoder_outputs, enc_out_mask)
    assert not torch.isnan(encoder_lengths).any(), 'encoder_lengths has NaN values.'

    # sum lengths (used to predict mel-spec length)
    encoder_lengths = encoder_lengths.clamp(1, 128)
    pred_output_lengths = encoder_lengths.sum((1,)).long()
    assert not torch.isnan(encoder_lengths).any(), 'encoder_lengths has NaN values.'
    assert not torch.isnan(pred_output_lengths).any(), 'pred_output_lengths has NaN values.'

    if self.speaker_embedding_dim:
        embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
        embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
        encoder_outputs = torch.cat((encoder_outputs, embedded_speakers), dim=2)  # [batch, enc_T, enc_dim]

    # Positional Attention (Masked Multi-head Attention)
    cond, attention_scores = self.positional_attention(encoder_outputs, pred_output_lengths, cond_lens=text_lengths)
    cond = cond.transpose(1, 2)  # [B, dec_T, enc_dim] -> [B, enc_dim, dec_T]
    assert not torch.isnan(cond).any(), 'cond has NaN values.'

    # Decoder (series of flows)
    mel_outputs = self.decoder.infer(cond, sigma=sigma)  # [B, enc_dim, dec_T] -> [B, n_mel, dec_T]
    assert not torch.isnan(mel_outputs).any(), 'mel_outputs has NaN values.'

    return self.mask_outputs([mel_outputs, attention_scores, None, None, None])
def forward(self, model_output, targets):
    mel_target, gate_target, output_lengths, *_ = targets
    mel_target.requires_grad = False
    gate_target.requires_grad = False
    mel_out, mel_out_postnet, gate_out, _ = model_output
    gate_target = gate_target.view(-1, 1)
    gate_out = gate_out.view(-1, 1)

    # remove paddings before loss calc
    if self.masked_select:
        mask = get_mask_from_lengths(output_lengths)
        mask = mask.expand(mel_target.size(1), mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)
        mel_target = torch.masked_select(mel_target, mask)
        mel_out = torch.masked_select(mel_out, mask)
        mel_out_postnet = torch.masked_select(mel_out_postnet, mask)

    if self.loss_func == 'MSELoss':
        mel_loss = nn.MSELoss()(mel_out, mel_target) + \
                   nn.MSELoss()(mel_out_postnet, mel_target)
    elif self.loss_func == 'SmoothL1Loss':
        mel_loss = nn.SmoothL1Loss()(mel_out, mel_target) + \
                   nn.SmoothL1Loss()(mel_out_postnet, mel_target)

    gate_loss = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)(gate_out, gate_target)
    return mel_loss + gate_loss, gate_loss
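# The expand/permute sequence above (used in several of these loss functions)
# just broadcasts a [B, T] length mask across the mel channels; a small sketch
# showing it is equivalent to the more direct unsqueeze(1).expand form:
import torch

B, n_mel, T = 2, 80, 7
lengths = torch.tensor([7, 4])
mask = torch.arange(T)[None, :] < lengths[:, None]   # [B, T]
a = mask.expand(n_mel, B, T).permute(1, 0, 2)        # [B, n_mel, T]
b = mask.unsqueeze(1).expand(B, n_mel, T)            # same mask, one step
assert torch.equal(a, b)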
def get_attention_from_lengths(
        memory: Tensor,         # FloatTensor[B, enc_T, enc_dim]
        enc_durations: Tensor,  # FloatTensor[B, enc_T]
        text_lengths: Tensor):  # LongTensor[B]
    B, enc_T, mem_dim = memory.shape

    mask = get_mask_from_lengths(text_lengths)
    enc_durations.masked_fill_(~mask, 0.0)
    enc_durations = enc_durations.round()  # [B, enc_T]
    dec_T = int(enc_durations.sum(dim=1).max().item())  # [B, enc_T] -> int

    attention_contexts = torch.zeros(B, dec_T, mem_dim, device=memory.device, dtype=memory.dtype)  # [B, dec_T, enc_dim]
    for i in range(B):
        mem_temp = []
        for j in range(int(text_lengths[i].item())):
            duration = int(enc_durations[i, j].item())
            # [B, enc_T, enc_dim] -> [1, enc_dim] -> [duration, enc_dim]
            mem_temp.append(memory[i, j:j + 1].repeat(duration, 1))
        mem_temp = torch.cat(mem_temp, dim=0)  # [[duration, enc_dim], ...] -> [dec_T, enc_dim]
        min_len = min(attention_contexts.shape[1], mem_temp.shape[0])
        attention_contexts[i, :min_len] = mem_temp[:min_len]
    return attention_contexts  # [B, dec_T, enc_dim]
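# The inner repeat/cat loop above is a duration-based length regulator (each
# encoder step is copied `duration` times). A sketch of an equivalent
# formulation with torch.repeat_interleave, which avoids building the
# intermediate list per utterance:
import torch

def expand_by_durations(memory, durations, text_lengths):
    # memory: [B, enc_T, C]; durations: [B, enc_T]; text_lengths: [B]
    B, enc_T, C = memory.shape
    durs = durations.round().long()
    valid = torch.arange(enc_T, device=durs.device)[None, :] < text_lengths[:, None]
    durs = durs * valid                       # zero durations on padding
    dec_T = int(durs.sum(dim=1).max().item())
    out = memory.new_zeros(B, dec_T, C)
    for i in range(B):                        # ragged outputs still need a loop
        expanded = memory[i].repeat_interleave(durs[i], dim=0)[:dec_T]
        out[i, :expanded.size(0)] = expanded
    return out                                # [B, dec_T, C]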
def forward(self, inputs):
    text, text_lengths, gt_mels, max_len, output_lengths, speaker_ids, torchmoji_hidden, preserve_decoder_states = inputs
    text_lengths, output_lengths = text_lengths.data, output_lengths.data
    assert not torch.isnan(text).any(), 'text has NaN values.'

    embedded_text = self.embedding(text).transpose(1, 2)  # [B, embed, sequence]
    assert not torch.isnan(embedded_text).any(), 'embedded_text has NaN values.'
    encoder_outputs = self.encoder(embedded_text, text_lengths, speaker_ids=speaker_ids)  # [B, enc_T, enc_dim]
    assert not torch.isnan(encoder_outputs).any(), 'encoder_outputs has NaN values.'

    # predict length of each input
    enc_out_mask = get_mask_from_lengths(text_lengths).unsqueeze(-1)  # [B, enc_T, 1]
    encoder_lengths = self.length_predictor(encoder_outputs, enc_out_mask)  # [B, enc_T]
    assert not torch.isnan(encoder_lengths).any(), 'encoder_lengths has NaN values.'

    # sum lengths (used to predict mel-spec length)
    encoder_lengths = encoder_lengths.clamp(1e-6, 4096)
    pred_output_lengths = encoder_lengths.sum((1,))
    assert not torch.isnan(encoder_lengths).any(), 'encoder_lengths has NaN values.'
    assert not torch.isnan(pred_output_lengths).any(), 'pred_output_lengths has NaN values.'

    if self.speaker_embedding_dim:
        embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
        embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
        encoder_outputs = torch.cat((encoder_outputs, embedded_speakers), dim=2)  # [batch, enc_T, enc_dim]

    # Positional Attention (Masked Multi-head Attention)
    cond, attention_scores = self.positional_attention(encoder_outputs, output_lengths, cond_lens=text_lengths)
    cond = cond.transpose(1, 2)  # [B, dec_T, enc_dim] -> [B, enc_dim, dec_T]
    assert not torch.isnan(cond).any(), 'cond has NaN values.'

    # Decoder (series of flows)
    mel_outputs, log_s_sum, logdet_w_sum = self.decoder(gt_mels, cond)  # [B, n_mel, dec_T], [B, enc_dim, dec_T] -> [B, n_mel, dec_T], [B]
    assert not torch.isnan(mel_outputs).any(), 'mel_outputs has NaN values.'
    assert not torch.isnan(log_s_sum).any(), 'log_s_sum has NaN values.'
    assert not torch.isnan(logdet_w_sum).any(), 'logdet_w_sum has NaN values.'

    return self.mask_outputs([
        mel_outputs, attention_scores, pred_output_lengths, log_s_sum, logdet_w_sum
    ], output_lengths)
def mask_outputs(self, outputs, output_lengths=None):
    if self.mask_padding and output_lengths is not None:
        mask = ~get_mask_from_lengths(output_lengths)
        mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)  # [B, n_mel, steps]
        outputs[0].data.masked_fill_(mask, 0.0)  # [B, n_mel, T]
    return outputs
def forward(self, model_output, targets):
    mel_target, gate_target, output_lengths, text_lengths, *_ = targets
    mel_out, attention_scores, pred_output_lengths, log_s_sum, logdet_w_sum = model_output
    batch_size, n_mel_channels, frames = mel_target.shape

    output_lengths_float = output_lengths.float()
    mel_out = mel_out.float()
    log_s_sum = log_s_sum.float()
    logdet_w_sum = logdet_w_sum.float()

    # Length Loss
    len_pred_loss = torch.nn.MSELoss()(pred_output_lengths.log(), output_lengths_float.log())

    # remove paddings before loss calc
    mask = get_mask_from_lengths(output_lengths)[:, None, :]  # [B, 1, T] BoolTensor
    mask = mask.expand(mask.size(0), mel_target.size(1), mask.size(2))  # [B, n_mel, T] BoolTensor
    n_elems = (output_lengths_float.sum() * n_mel_channels)

    # Spectrogram Loss
    mel_out = torch.masked_select(mel_out, mask)
    loss_z = ((mel_out.pow(2).sum()) / self.sigma2_2) / n_elems  # mean z (over all elements)

    loss_w = -logdet_w_sum.sum() / (n_mel_channels * frames)

    log_s_sum = log_s_sum.view(batch_size, -1, frames)
    log_s_sum = torch.masked_select(log_s_sum, mask[:, :log_s_sum.shape[1], :])
    loss_s = -log_s_sum.sum() / n_elems

    loss = loss_z + loss_w + loss_s + (len_pred_loss * 0.01)
    assert not torch.isnan(loss).any(), 'loss has NaN values.'

    # (optional) Guided Attention Loss
    if hasattr(self, 'guided_att'):
        att_loss = self.guided_att(attention_scores, text_lengths, output_lengths)
        loss = loss + att_loss
    else:
        att_loss = None

    if True:  # Min-Enc Attention Loss
        # NOTE: incomplete; the computed mask and attention sum are currently unused.
        mask = get_mask_3d(output_lengths, text_lengths)
        attention_scores.sum((1,))  # [B, dec_T, enc_T] mask
    return loss, len_pred_loss, loss_z, loss_w, loss_s, att_loss
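# self.guided_att is not defined in this file; a sketch of a diagonal
# guided-attention penalty in the style of Tachibana et al. (2017), which
# matches how it is called here (alignments [B, dec_T, enc_T] plus both length
# vectors), assuming that is the variant used:
import torch

def guided_attention_loss(att, text_lengths, mel_lengths, g=0.2):
    # att: [B, dec_T, enc_T] soft alignment; penalize mass far from the diagonal.
    B, dec_T, enc_T = att.shape
    t = torch.arange(dec_T, device=att.device).float()[None, :, None]  # decoder steps
    n = torch.arange(enc_T, device=att.device).float()[None, None, :]  # encoder steps
    N = text_lengths.float()[:, None, None]
    T = mel_lengths.float()[:, None, None]
    W = 1.0 - torch.exp(-((n / N - t / T) ** 2) / (2 * g * g))
    mask = (t < T) & (n < N)
    return (att * W).masked_select(mask).mean()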
def forward(self, dec_inp, seq_lens=None):
    if self.word_emb is None:
        inp = dec_inp
        mask = get_mask_from_lengths(seq_lens).unsqueeze(2)
    else:
        inp = self.word_emb(dec_inp)  # [bsz, L, 1]
        mask = (dec_inp != pad_idx).unsqueeze(2)

    pos_seq = torch.arange(inp.size(1), device=inp.device, dtype=inp.dtype)
    pos_emb = self.pos_emb(pos_seq) * mask

    out = self.drop(inp + pos_emb)
    for layer in self.layers:
        out = layer(out, mask=mask)
    # out = self.drop(out)
    return out, mask
def forward(self, x, z=None, output_lengths=None):  # [B, in_dim, T]
    if hasattr(self, 'bn'):
        x = self.bn(x, z)
    x = self.act_func(x)
    if hasattr(self, 'scale'):
        if self.downsample:
            x = F.avg_pool1d(x, kernel_size=self.scale)
        else:
            x = F.interpolate(x, scale_factor=self.scale, mode='linear')  # [B, in_dim, T] -> [B, in_dim, x*T]
        if output_lengths is not None:
            # rescale lengths to the new time resolution
            scale_factor = x.shape[2] / output_lengths.max()
            if scale_factor != 1.0:
                output_lengths = (output_lengths.float() * scale_factor).long()
            mask = get_mask_from_lengths(output_lengths).unsqueeze(1)
            x.masked_fill_(~mask, 0.0)
    x = self.conv(x)  # [B, in_dim, x*T] -> [B, out_dim, x*T]
    return x  # [B, out_dim, x*T]
def forward(self, h, z, output_lengths=None):
    scaled_h = F.interpolate(h, scale_factor=self.scale, mode='linear') if self.scale != 1 else h  # [B, input_dim, T] -> [B, input_dim, x*T]
    if output_lengths is not None:
        scale_factor = scaled_h.shape[2] / output_lengths.max()
        if scale_factor != 1.0:
            output_lengths = (output_lengths.float() * scale_factor).long()
        mask = get_mask_from_lengths(output_lengths).unsqueeze(1)
        scaled_h.masked_fill_(~mask, 0.0)
    residual = self.skip_conv(scaled_h)  # [B, input_dim, x*T] -> [B, output_dim, x*T]
    for i, resblock in enumerate(self.resblocks):  # [B, input_dim, T] -> [B, output_dim, x*T]
        h = resblock(h, z, output_lengths)
        if i == self.res_block_id:
            h += residual
            residual = h
    return h + residual  # [B, output_dim, x*T]
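# All of the upsampling blocks above rescale output_lengths after
# F.interpolate so the padding mask tracks the new time resolution; the
# divisor is the longest valid length. A small standalone check of that
# arithmetic:
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 10)                  # [B, C, T]
lengths = torch.tensor([10, 7])
y = F.interpolate(x, scale_factor=2, mode='linear', align_corners=False)
scale_factor = y.shape[2] / lengths.max()  # 20 / 10 = 2.0
new_lengths = (lengths.float() * scale_factor).long()
print(y.shape, new_lengths)                # torch.Size([2, 4, 20]) tensor([20, 14])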
def glow_loss(z, log_s_sum, logdet_w_sum, output_lengths, sigma):
    dec_T = output_lengths.max()
    B = z.shape[0]
    z = z.view(z.shape[0], -1, dec_T).float()
    log_s_sum = log_s_sum.view(B, -1, dec_T)
    B, z_channels, dec_T = z.shape
    n_elems = (output_lengths.float().sum() * z_channels)

    # remove paddings before loss calc
    mask = get_mask_from_lengths(output_lengths)[:, None, :]  # [B, 1, T] BoolTensor
    mask = mask.expand(B, z_channels, dec_T)  # [B, z_channels, T] BoolTensor

    z = torch.masked_select(z, mask)
    # `sigma` is expected to already hold 2*sigma**2 (callers pass e.g. self.dg_sigma2_2)
    loss_z = ((z.pow(2).sum()) / sigma) / n_elems  # mean z (over all elements)

    log_s_sum = torch.masked_select(log_s_sum, mask[:, :log_s_sum.shape[1], :])
    loss_s = -log_s_sum.float().sum() / n_elems

    loss_w = -logdet_w_sum.float().sum() / (z_channels * dec_T)

    loss = loss_z + loss_w + loss_s
    return loss, loss_z, loss_w, loss_s
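# For reference, the three terms in glow_loss are the standard normalizing-flow
# negative log-likelihood under a zero-mean Gaussian prior,
#   -log p(x) = ||z||^2 / (2*sigma^2) - sum(log|s|) - sum(log|det W|) + const,
# with the masking and per-term normalizers being implementation details.
# A minimal unmasked sketch:
import torch

def flow_nll(z, log_s_sum, logdet_w_sum, sigma2_2):
    # sigma2_2 = 2 * sigma**2, matching the callers in this file
    n = z.numel()
    loss_z = z.pow(2).sum() / sigma2_2 / n
    loss_s = -log_s_sum.sum() / n
    loss_w = -logdet_w_sum.sum() / n
    return loss_z + loss_s + loss_w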
def forward(self, model, pred, gt, loss_scalars, resGAN=None, dbGAN=None, infGAN=None):
    loss_dict = {}
    file_losses = {}  # dict of {"audiofile": {"spec_MSE": spec_MSE, "avg_prob": avg_prob, ...}, ...}

    B, n_mel, mel_T = gt['gt_mel'].shape
    tfB = B // (model.decoder.half_inference_mode + 1)

    for i in range(tfB):
        current_time = time.time()
        if gt['audiopath'][i] not in file_losses:
            file_losses[gt['audiopath'][i]] = {'speaker_id_ext': gt['speaker_id_ext'][i], 'time': current_time}

    if True:
        pred_mel_postnet = pred['pred_mel_postnet']
        pred_mel = pred['pred_mel']
        gt_mel = gt['gt_mel']
        mel_lengths = gt['mel_lengths']

        mask = get_mask_from_lengths(mel_lengths)
        mask = mask.expand(gt_mel.size(1), *mask.shape).permute(1, 0, 2)
        pred_mel_postnet.masked_fill_(~mask, 0.0)
        pred_mel.masked_fill_(~mask, 0.0)

        with torch.no_grad():
            assert not torch.isnan(pred_mel).any(), 'mel has NaNs'
            assert not torch.isinf(pred_mel).any(), 'mel has Infs'
            assert not torch.isnan(pred_mel_postnet).any(), 'mel has NaNs'
            assert not torch.isinf(pred_mel_postnet).any(), 'mel has Infs'

        if model.decoder.half_inference_mode:
            pred_mel_postnet = pred_mel_postnet.chunk(2, dim=0)[0]
            pred_mel = pred_mel.chunk(2, dim=0)[0]
            gt_mel = gt_mel.chunk(2, dim=0)[0]
            mel_lengths = mel_lengths.chunk(2, dim=0)[0]
            mask = mask.chunk(2, dim=0)[0]
            B, n_mel, mel_T = gt_mel.shape

        teacher_force_till = loss_scalars.get('teacher_force_till', 0)
        p_teacher_forcing = loss_scalars.get('p_teacher_forcing', 1.0)
        if p_teacher_forcing == 0.0 and teacher_force_till > 1:
            gt_mel = gt_mel[:, :, :teacher_force_till]
            pred_mel = pred_mel[:, :, :teacher_force_till]
            pred_mel_postnet = pred_mel_postnet[:, :, :teacher_force_till]
            mel_lengths = mel_lengths.clamp(max=teacher_force_till)

        # spectrogram / decoder loss
        pred_mel_selected = torch.masked_select(pred_mel, mask)
        gt_mel_selected = torch.masked_select(gt_mel, mask)
        spec_SE = nn.MSELoss(reduction='none')(pred_mel_selected, gt_mel_selected)
        loss_dict['spec_MSE'] = spec_SE.mean()

        losses = spec_SE.split([x * n_mel for x in mel_lengths.cpu()])
        for i in range(tfB):
            audiopath = gt['audiopath'][i]
            file_losses[audiopath]['spec_MSE'] = losses[i].mean().item()

        # postnet
        pred_mel_postnet_selected = torch.masked_select(pred_mel_postnet, mask)
        loss_dict['postnet_MSE'] = nn.MSELoss()(pred_mel_postnet_selected, gt_mel_selected)

        # squared by frame, mean postnet
        mask = mask.transpose(1, 2)[:, :, :1]  # [B, mel_T, n_mel] -> [B, mel_T, 1]
        spec_AE = nn.L1Loss(reduction='none')(pred_mel, gt_mel).transpose(1, 2)  # -> [B, mel_T, n_mel]
        spec_AE = spec_AE.masked_select(mask).view(mel_lengths.sum(), n_mel)  # -> [B*mel_T, n_mel]
        loss_dict['spec_MFSE'] = (spec_AE * spec_AE.mean(dim=1, keepdim=True)).mean()  # multiply by frame means (similar to the square op in MSE) and take the mean

        post_AE = nn.L1Loss(reduction='none')(pred_mel_postnet, gt_mel).transpose(1, 2)  # -> [B, mel_T, n_mel]
        post_AE = post_AE.masked_select(mask).view(mel_lengths.sum(), n_mel)  # -> [B*mel_T, n_mel]
        loss_dict['postnet_MFSE'] = (post_AE * post_AE.mean(dim=1, keepdim=True)).mean()
        del gt_mel, spec_AE, post_AE  # pred_mel_postnet, pred_mel

    if True:  # gate/stop loss
        gate_target = gt['gt_gate_logits']
        gate_out = pred['pred_gate_logits']
        if model.decoder.half_inference_mode:
            gate_target = gate_target.chunk(2, dim=0)[0]
            gate_out = gate_out.chunk(2, dim=0)[0]
        gate_target = gate_target.view(-1, 1)
        gate_out = gate_out.view(-1, 1)
        loss_dict['gate_loss'] = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)(gate_out, gate_target)
        del gate_target, gate_out

    if True:  # SylpsNet loss
        syl_mu = pred['pred_sylps_mu']
        syl_logvar = pred['pred_sylps_logvar']
        if model.decoder.half_inference_mode:
            syl_logvar = syl_logvar.chunk(2, dim=0)[0]
            syl_mu = syl_mu.chunk(2, dim=0)[0]
        loss_dict['sylps_kld'] = -0.5 * (1 + syl_logvar - syl_logvar.exp() - syl_mu.pow(2)).sum() / B
        del syl_mu, syl_logvar

    if True:  # Pred Sylps loss
        pred_sylps = pred['pred_sylps'].squeeze(1)  # [B, 1] -> [B]
        sylps_target = gt['gt_sylps']
        if model.decoder.half_inference_mode:
            pred_sylps = pred_sylps.chunk(2, dim=0)[0]
            sylps_target = sylps_target.chunk(2, dim=0)[0]
        loss_dict['sylps_MAE'] = nn.L1Loss()(pred_sylps, sylps_target)
        loss_dict['sylps_MSE'] = nn.MSELoss()(pred_sylps, sylps_target)
        del pred_sylps, sylps_target

    if True:  # Diagonal Attention Guiding
        alignments = pred['alignments']
        text_lengths = gt['text_lengths']
        output_lengths = gt['mel_lengths']
        pres_prev_state = gt['pres_prev_state']
        if model.decoder.half_inference_mode:
            alignments = alignments.chunk(2, dim=0)[0]
            text_lengths = text_lengths.chunk(2, dim=0)[0]
            output_lengths = output_lengths.chunk(2, dim=0)[0]
            pres_prev_state = pres_prev_state.chunk(2, dim=0)[0]
        loss_dict['diag_att'] = self.guided_att(alignments[pres_prev_state == 0.0],
                                                text_lengths[pres_prev_state == 0.0],
                                                output_lengths[pres_prev_state == 0.0])
        del alignments, text_lengths, output_lengths, pres_prev_state

    if self.use_res_enc and resGAN is not None:  # Residual Encoder KL Divergence Loss
        mu, logvar, mulogvar = pred['res_enc_pkg']
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss_dict['res_enc_kld'] = kl_loss

        # The discriminator attempts to predict the letters and speakers from the
        # residual latent space; the generator attempts to increase the
        # discriminator's loss so the latent space loses speaker and duration
        # info, making it more likely the latent contains information about
        # background noise conditions and other features more relevant to human
        # interests.
        with torch.no_grad():
            gt_speakers = gt['speaker_id_onehot'].float()  # [B, n_speakers]
            gt_sym_durs = get_class_durations(gt['text'], pred['alignments'].detach(), self.n_symbols)  # [B, n_symbols]
        out = resGAN.discriminator(mulogvar)  # learns to predict the speaker and symbol durations
        B = out.shape[0]
        # pred_sym_durs holds the amount of 'a', 'b', 'c', '.', etc. sounds in the
        # audio. If there isn't a 'd' sound in the transcript, then d will be 0.0;
        # if there are multiple 'a' sounds, their durations are summed.
        pred_sym_durs, pred_speakers = out.squeeze(-1).split([self.n_symbols, self.n_speakers], dim=1)
        pred_speakers = torch.nn.functional.softmax(pred_speakers, dim=1)
        loss_dict['res_enc_gMSE'] = (nn.MSELoss(reduction='sum')(pred_sym_durs, gt_sym_durs.mean(dim=1, keepdim=True)) * 0.0001
                                     + nn.MSELoss(reduction='sum')(pred_speakers, gt_speakers.mean(dim=1, keepdim=True))) / B
        resGAN.gt_speakers = gt_speakers
        resGAN.gt_sym_durs = gt_sym_durs
        del mu, logvar, kl_loss, gt_speakers, gt_sym_durs, pred_sym_durs, pred_speakers

    if 1 and model.training and self.use_dbGAN and dbGAN is not None:
        pred_mel_postnet = pred['pred_mel_postnet'].unsqueeze(1)  # -> [tfB, 1, n_mel, mel_T]
        pred_mel = pred['pred_mel'].unsqueeze(1)                  # -> [tfB, 1, n_mel, mel_T]
        speaker_embed = pred['speaker_embed']
        if model.decoder.half_inference_mode:
            pred_mel_postnet = pred_mel_postnet.chunk(2, dim=0)[0]
            pred_mel = pred_mel.chunk(2, dim=0)[0]
            speaker_embed = speaker_embed.chunk(2, dim=0)[0]
        B, _, n_mel, mel_T = pred_mel.shape
        mels = torch.cat((pred_mel, pred_mel_postnet), dim=0).float()  # [2*B, 1, n_mel, mel_T]
        with torch.no_grad():
            assert not (torch.isnan(mels) | torch.isinf(mels)).any(), 'NaN or Inf value found in computation'
        # if False:
        #     pred_fakeness = checkpoint(dbGAN.discriminator, mels, speaker_id.repeat(2)).squeeze(1)  # -> [2*B, mel_T//?]
        # else:
        pred_fakeness = dbGAN.discriminator(mels, speaker_embed.repeat(2, 1)).squeeze(1)  # -> [2*B, mel_T//?]
        pred_fakeness, postnet_fakeness = pred_fakeness.chunk(2, dim=0)  # -> [B, mel_T//?], [B, mel_T//?]
        tfB, post_mel_T = pred_fakeness.shape
        real_label = torch.ones(tfB, post_mel_T, device=pred_mel.device, dtype=pred_mel.dtype) * -1.0  # [B]
        loss_dict['dbGAN_gLoss'] = F.mse_loss(pred_fakeness, real_label) * 0.5 + F.mse_loss(postnet_fakeness, real_label) * 0.5
        with torch.no_grad():
            assert not torch.isnan(loss_dict['dbGAN_gLoss']), 'dbGAN loss is NaN'
            assert not torch.isinf(loss_dict['dbGAN_gLoss']), 'dbGAN loss is Inf'
        del mels, real_label, pred_fakeness, postnet_fakeness, pred_mel, pred_mel_postnet, speaker_embed

    if self.use_InfGAN and infGAN is not None and model.decoder.half_inference_mode:
        with torch.no_grad():
            pred_gate = pred['pred_gate_logits'].chunk(2, dim=0)[1].sigmoid()
            pred_gate[:, :5] = 0.0
            # Get inference alignment scores
            pred_mel_lengths = get_first_over_thresh(pred_gate, 0.5)
            pred_mel_lengths.clamp_(max=mel_T)
            pred['pred_mel_lengths'] = pred_mel_lengths
            mask = get_mask_from_lengths(pred_mel_lengths, max_len=mel_T).unsqueeze(1)  # [B, 1, mel_T]
        tfB = pred_gate.shape[0]
        with freeze_grads(model.decoder.prenet):
            args = infGAN.merge_inputs(model, pred, gt, tfB, mask)  # [B/2, mel_T, embed]
            if infGAN.training and infGAN.gradient_checkpoint:
                inf_infness = checkpoint(infGAN.discriminator, *args).squeeze(1)  # -> [B/2, mel_T]
            else:
                inf_infness = infGAN.discriminator(*args).squeeze(1)  # -> [B/2, mel_T]
        tf_label = torch.ones(tfB, device=pred_gate.device, dtype=pred_gate.dtype)[:, None].expand(tfB, mel_T) * -1.  # [B/2]
        loss_dict['InfGAN_gLoss'] = 2. * F.mse_loss(inf_infness, tf_label)

    ##################################################################
    ## Collate / Merge the Losses into a single tensor with scalars ##
    ##################################################################
    loss_dict = self.colate_losses(loss_dict, loss_scalars)

    with torch.no_grad():  # get Avg Max Attention and Diagonality Metrics
        atd = alignment_metric(pred['alignments'].detach().clone(),
                               gt['text_lengths'], gt['mel_lengths'])
        diagonalitys, avg_prob, char_max_dur, char_min_dur, char_avg_dur, p_missing_enc = atd.values()
        loss_dict['diagonality'] = diagonalitys.mean()
        loss_dict['avg_max_attention'] = avg_prob.mean()

        for i in range(tfB):
            audiopath = gt['audiopath'][i]
            file_losses[audiopath]['avg_max_attention'] = avg_prob[i].cpu().item()
            file_losses[audiopath]['att_diagonality'] = diagonalitys[i].cpu().item()
            file_losses[audiopath]['p_missing_enc'] = p_missing_enc[i].cpu().item()
            file_losses[audiopath]['char_max_dur'] = char_max_dur[i].cpu().item()
            file_losses[audiopath]['char_min_dur'] = char_min_dur[i].cpu().item()
            file_losses[audiopath]['char_avg_dur'] = char_avg_dur[i].cpu().item()
            if 0:
                diagonality_path = f'{os.path.splitext(audiopath)[0]}_diag.pt'
                torch.save(diagonalitys[i].detach().clone().cpu(), diagonality_path)
                avg_prob_path = f'{os.path.splitext(audiopath)[0]}_avgp.pt'
                torch.save(avg_prob[i].detach().clone().cpu(), avg_prob_path)

        pred_gate = pred['pred_gate_logits'].sigmoid()
        pred_gate[:, :5] = 0.0

        # Get inference alignment scores
        pred_mel_lengths = get_first_over_thresh(pred_gate, 0.7)
        atd = alignment_metric(pred['alignments'].detach().clone(),
                               gt['text_lengths'], pred_mel_lengths)
        atd = {k: v.cpu() for k, v in atd.items()}
        diagonalitys, avg_prob, char_max_dur, char_min_dur, char_avg_dur, p_missing_enc = atd.values()

        scores = []
        for i in range(tfB):
            audiopath = gt['audiopath'][i]
            # factors that make up the score
            weighted_score = avg_prob[i].item()  # general alignment quality
            diagonality_punishment = max(diagonalitys[i].item() - 1.10, 0) * 0.25   # speaking each letter at a similar pace
            max_dur_punishment = max(char_max_dur[i].item() - 60.0, 0) * 0.005      # getting stuck on the same letter for 0.5s
            min_dur_punishment = max(0.00 - char_min_dur[i].item(), 0) * 0.5        # skipping single enc outputs
            avg_dur_punishment = max(3.60 - char_avg_dur[i].item(), 0)              # skipping most enc outputs
            mis_dur_punishment = max(p_missing_enc[i].item() - 0.08, 0) if gt['text_lengths'][i] > 12 and gt['mel_lengths'][i] < gt['mel_lengths'].max() * 0.75 else 0.0  # skipping some percent of the text
            weighted_score -= (diagonality_punishment + max_dur_punishment + min_dur_punishment + avg_dur_punishment + mis_dur_punishment)
            scores.append(weighted_score)
            file_losses[audiopath]['att_score'] = weighted_score
        scores = torch.tensor(scores)
        scores[torch.isnan(scores)] = scores[~torch.isnan(scores)].mean()
        loss_dict['weighted_score'] = scores.to(pred['alignments'].device).mean()

    return loss_dict, file_losses
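# get_first_over_thresh is assumed to return, per item, the index of the first
# gate frame whose probability clears the threshold (falling back to the full
# length when none does). A sketch consistent with its usage above:
import torch

def get_first_over_thresh(x, thresh):
    # x: [B, T] gate probabilities -> LongTensor [B] of predicted lengths
    over = x > thresh
    over[:, -1] = True                     # fall back to max length if never over
    return over.float().argmax(dim=1) + 1  # first True index, as a length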
def forward(self, pred, gt, loss_scalars):
    loss_dict = {}
    file_losses = {}  # dict of {"audiofile": {"spec_MSE": spec_MSE, "avg_prob": avg_prob, ...}, ...}

    B, n_mel, mel_T = gt['gt_mel'].shape
    for i in range(B):
        current_time = time.time()
        if gt['audiopath'][i] not in file_losses:
            file_losses[gt['audiopath'][i]] = {'speaker_id_ext': gt['speaker_id_ext'][i], 'time': current_time}

    if True:
        pred_mel_postnet = pred['pred_mel_postnet']
        pred_mel = pred['pred_mel']
        gt_mel = gt['gt_mel']
        B, n_mel, mel_T = gt_mel.shape

        mask = get_mask_from_lengths(gt['mel_lengths'])
        mask = mask.expand(gt_mel.size(1), *mask.shape).permute(1, 0, 2)

        # spectrogram / decoder loss
        pred_mel = torch.masked_select(pred_mel, mask)
        gt_mel = torch.masked_select(gt_mel, mask)
        spec_SE = nn.MSELoss(reduction='none')(pred_mel, gt_mel)
        loss_dict['spec_MSE'] = spec_SE.mean()

        losses = spec_SE.split([x * n_mel for x in gt['mel_lengths'].cpu()])
        for i in range(B):
            audiopath = gt['audiopath'][i]
            file_losses[audiopath]['spec_MSE'] = losses[i].mean().item()

        # postnet
        pred_mel_postnet.masked_fill_(~mask, 0.0)
        pred_mel_postnet = torch.masked_select(pred_mel_postnet, mask)
        loss_dict['postnet_MSE'] = nn.MSELoss()(pred_mel_postnet, gt_mel)

        # squared by frame, mean postnet
        mask = get_mask_from_lengths(gt['mel_lengths']).unsqueeze(-1)  # [B, mel_T] -> [B, mel_T, 1]
        spec_AE = nn.L1Loss(reduction='none')(pred['pred_mel'], gt['gt_mel']).transpose(1, 2)  # -> [B, mel_T, n_mel]
        spec_AE = spec_AE.masked_select(mask).view(gt['mel_lengths'].sum(), n_mel)  # -> [B*mel_T, n_mel]
        loss_dict['spec_MFSE'] = (spec_AE * spec_AE.mean(dim=1, keepdim=True)).mean()  # multiply by frame means (similar to the square op in MSE) and take the mean

        post_AE = nn.L1Loss(reduction='none')(pred['pred_mel_postnet'], gt['gt_mel']).transpose(1, 2)  # -> [B, mel_T, n_mel]
        post_AE = post_AE.masked_select(mask).view(gt['mel_lengths'].sum(), n_mel)  # -> [B*mel_T, n_mel]
        loss_dict['postnet_MFSE'] = (post_AE * post_AE.mean(dim=1, keepdim=True)).mean()

    if True:  # gate/stop loss
        gate_target = gt['gt_gate_logits'].view(-1, 1)
        gate_out = pred['pred_gate_logits'].view(-1, 1)
        loss_dict['gate_loss'] = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)(gate_out, gate_target)
        del gate_target, gate_out

    if True:  # SylpsNet loss
        syl_mu = pred['pred_sylps_mu']
        syl_logvar = pred['pred_sylps_logvar']
        loss_dict['sylps_kld'] = -0.5 * (1 + syl_logvar - syl_logvar.exp() - syl_mu.pow(2)).sum() / B
        del syl_mu, syl_logvar

    if True:  # Pred Sylps loss
        pred_sylps = pred['pred_sylps'].squeeze(1)  # [B, 1] -> [B]
        sylps_target = gt['gt_sylps']
        loss_dict['sylps_MAE'] = nn.L1Loss()(pred_sylps, sylps_target)
        loss_dict['sylps_MSE'] = nn.MSELoss()(pred_sylps, sylps_target)
        del pred_sylps, sylps_target

    if True:  # Diagonal Attention Guiding
        alignments = pred['alignments']
        text_lengths = gt['text_lengths']
        output_lengths = gt['mel_lengths']
        pres_prev_state = gt['pres_prev_state']
        loss_dict['diag_att'] = self.guided_att(alignments[pres_prev_state == 0.0],
                                                text_lengths[pres_prev_state == 0.0],
                                                output_lengths[pres_prev_state == 0.0])
        del alignments, text_lengths, output_lengths

    ##################################################################
    ## Collate / Merge the Losses into a single tensor with scalars ##
    ##################################################################
    loss_dict = self.colate_losses(loss_dict, loss_scalars)

    with torch.no_grad():  # get Avg Max Attention and Diagonality Metrics
        atd = alignment_metric(pred['alignments'], gt['text_lengths'], gt['mel_lengths'])
        diagonalitys, avg_prob, char_max_dur, char_min_dur, char_avg_dur, p_missing_enc = atd.values()
        loss_dict['diagonality'] = diagonalitys.mean()
        loss_dict['avg_max_attention'] = avg_prob.mean()

        for i in range(B):
            audiopath = gt['audiopath'][i]
            file_losses[audiopath]['avg_max_attention'] = avg_prob[i].cpu().item()
            file_losses[audiopath]['att_diagonality'] = diagonalitys[i].cpu().item()
            file_losses[audiopath]['p_missing_enc'] = p_missing_enc[i].cpu().item()
            file_losses[audiopath]['char_max_dur'] = char_max_dur[i].cpu().item()
            file_losses[audiopath]['char_min_dur'] = char_min_dur[i].cpu().item()
            file_losses[audiopath]['char_avg_dur'] = char_avg_dur[i].cpu().item()

        pred_gate = pred['pred_gate_logits'].sigmoid()
        pred_gate[:, :5] = 0.0

        # Get inference alignment scores
        pred_mel_lengths = get_first_over_thresh(pred_gate, 0.7)
        atd = alignment_metric(pred['alignments'], gt['text_lengths'], pred_mel_lengths)
        atd = {k: v.cpu() for k, v in atd.items()}
        diagonalitys, avg_prob, char_max_dur, char_min_dur, char_avg_dur, p_missing_enc = atd.values()

        scores = []
        for i in range(B):
            audiopath = gt['audiopath'][i]
            # factors that make up the score
            weighted_score = avg_prob[i].item()  # general alignment quality
            diagonality_punishment = max(diagonalitys[i].item() - 1.10, 0) * 0.25   # speaking each letter at a similar pace
            max_dur_punishment = max(char_max_dur[i].item() - 60.0, 0) * 0.005      # getting stuck on the same letter for 0.5s
            min_dur_punishment = max(0.00 - char_min_dur[i].item(), 0) * 0.5        # skipping single enc outputs
            avg_dur_punishment = max(3.60 - char_avg_dur[i].item(), 0)              # skipping most enc outputs
            mis_dur_punishment = max(p_missing_enc[i].item() - 0.08, 0) if gt['text_lengths'][i] > 12 and gt['mel_lengths'][i] < gt['mel_lengths'].max() * 0.75 else 0.0  # skipping some percent of the text
            weighted_score -= (diagonality_punishment + max_dur_punishment + min_dur_punishment + avg_dur_punishment + mis_dur_punishment)
            scores.append(weighted_score)
            file_losses[audiopath]['att_score'] = weighted_score
        scores = torch.tensor(scores)
        scores[torch.isnan(scores)] = scores[~torch.isnan(scores)].mean()
        loss_dict['weighted_score'] = scores.to(pred['alignments'].device).mean()

    return loss_dict, file_losses
def forward(self, model_output, targets, loss_scalars):
    # loss scalars
    MelGlow_ls = loss_scalars['MelGlow_ls'] if loss_scalars['MelGlow_ls'] is not None else self.MelGlow_loss_scalar
    DurGlow_ls = loss_scalars['DurGlow_ls'] if loss_scalars['DurGlow_ls'] is not None else self.DurGlow_loss_scalar
    VarGlow_ls = loss_scalars['VarGlow_ls'] if loss_scalars['VarGlow_ls'] is not None else self.VarGlow_loss_scalar
    Sylps_ls   = loss_scalars['Sylps_ls']   if loss_scalars['Sylps_ls']   is not None else self.Sylps_loss_scalar

    # loss_func
    mel_target, text_lengths, output_lengths, perc_loudness_target, f0_target, energy_target, sylps_target, voiced_mask, char_f0, char_voiced, char_energy, *_ = targets
    B, n_mel, dec_T = mel_target.shape
    enc_T = text_lengths.max()
    output_lengths_float = output_lengths.float()
    loss_dict = {}

    # Decoder / MelGlow Loss
    if True:
        mel_z, log_s_sum, logdet_w_sum = model_output['melglow']
        # remove paddings before loss calc
        mask = get_mask_from_lengths(output_lengths)[:, None, :]  # [B, 1, T] BoolTensor
        mask = mask.expand(mask.size(0), mel_target.size(1), mask.size(2))  # [B, n_mel, T] BoolTensor
        n_elems = (output_lengths_float.sum() * n_mel)

        mel_z = torch.masked_select(mel_z, mask)
        dec_loss_z = ((mel_z.pow(2).sum()) / self.sigma2_2) / n_elems  # mean z (over all elements)

        log_s_sum = log_s_sum.view(B, -1, dec_T)
        log_s_sum = torch.masked_select(log_s_sum, mask[:, :log_s_sum.shape[1], :])
        dec_loss_s = -log_s_sum.sum() / n_elems

        dec_loss_w = -logdet_w_sum.sum() / (n_mel * dec_T)

        dec_loss_d = dec_loss_z + dec_loss_w + dec_loss_s
        loss = dec_loss_d * MelGlow_ls
        del mel_z, log_s_sum, logdet_w_sum, mask, n_elems
        loss_dict["Decoder_Loss_Z"] = dec_loss_z
        loss_dict["Decoder_Loss_W"] = dec_loss_w
        loss_dict["Decoder_Loss_S"] = dec_loss_s
        loss_dict["Decoder_Loss_Total"] = dec_loss_d
        assert not (torch.isnan(loss) | torch.isinf(loss)).any(), 'Inf/NaN Loss at MelGlow Latents'

    # CVarGlow Loss
    if True:
        z, log_s_sum, logdet_w_sum = model_output['cvarglow']
        _ = glow_loss(z, log_s_sum, logdet_w_sum, text_lengths, self.dg_sigma2_2)
        cvar_loss_d, cvar_loss_z, cvar_loss_w, cvar_loss_s = _
        if self.DurGlow_loss_scalar:
            loss = loss + cvar_loss_d * DurGlow_ls
        del z, log_s_sum, logdet_w_sum
        loss_dict["CVar_Loss_Z"] = cvar_loss_z
        loss_dict["CVar_Loss_W"] = cvar_loss_w
        loss_dict["CVar_Loss_S"] = cvar_loss_s
        loss_dict["CVar_Loss_Total"] = cvar_loss_d
        assert not (torch.isnan(loss) | torch.isinf(loss)).any(), 'Inf/NaN Loss at CVarGlow Latents'

    # FrameGlow / VarGlow Loss
    if True:
        z, log_s_sum, logdet_w_sum = model_output['varglow']
        z_channels = 6
        z = z.view(z.shape[0], z_channels, -1)
        # remove paddings before loss calc
        mask = get_mask_from_lengths(output_lengths)[:, None, :]  # [B, 1, T] BoolTensor
        mask = mask.expand(mask.size(0), z_channels, mask.size(2))  # [B, z_channels, T] BoolTensor
        n_elems = (output_lengths_float.sum() * z_channels)

        z = torch.masked_select(z, mask)
        var_loss_z = ((z.pow(2).sum()) / self.sigma2_2) / n_elems  # mean z (over all elements)

        log_s_sum = log_s_sum.view(B, -1, dec_T)
        log_s_sum = torch.masked_select(log_s_sum, mask[:, :log_s_sum.shape[1], :])
        var_loss_s = -log_s_sum.sum() / n_elems

        var_loss_w = -logdet_w_sum.sum() / (z_channels * dec_T)

        var_loss_d = var_loss_z + var_loss_w + var_loss_s
        loss = loss + var_loss_d * VarGlow_ls
        del z, log_s_sum, logdet_w_sum, mask, n_elems, z_channels
        loss_dict["Variance_Loss_Z"] = var_loss_z
        loss_dict["Variance_Loss_W"] = var_loss_w
        loss_dict["Variance_Loss_S"] = var_loss_s
        loss_dict["Variance_Loss_Total"] = var_loss_d
        assert not (torch.isnan(loss) | torch.isinf(loss)).any(), 'Inf/NaN Loss at VarGlow Latents'

    # Sylps Loss
    if True:
        enc_global_outputs, sylps = model_output['sylps']  # [B, 2], [B]
        mu, logvar = enc_global_outputs.transpose(0, 1)[:2, :]  # [2, B]
        loss_dict["zSylps_Loss"] = NormalLLLoss(mu, logvar, sylps)  # [B], [B], [B] -> []
        loss = loss + loss_dict["zSylps_Loss"] * Sylps_ls
        del mu, logvar, enc_global_outputs, sylps
        assert not (torch.isnan(loss) | torch.isinf(loss)).any(), 'Inf/NaN Loss at Pred Sylps'

    # Perceived Loudness Loss
    if True:
        enc_global_outputs, perc_loudness = model_output['perc_loud']  # [B, 2], [B]
        mu, logvar = enc_global_outputs.transpose(0, 1)[2:4, :]  # [2, B]
        loss_dict["zPL_Loss"] = NormalLLLoss(mu, logvar, perc_loudness)  # [B], [B], [B] -> []
        loss = loss + loss_dict["zPL_Loss"] * Sylps_ls
        del mu, logvar, enc_global_outputs, perc_loudness
        assert not (torch.isnan(loss) | torch.isinf(loss)).any(), 'Inf/NaN Loss at Pred Perceived Loudness'

    loss_dict["loss"] = loss
    return loss_dict
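# NormalLLLoss is used above as an assumed negative Gaussian log-likelihood of
# the target under a predicted mean and log-variance; a minimal sketch
# (additive constants dropped):
import torch

def NormalLLLoss(mu, logvar, target):
    # -log N(target; mu, exp(logvar)) per element, averaged
    nll = 0.5 * (logvar + (target - mu).pow(2) / logvar.exp())
    return nll.mean()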
def inference(
        self,
        text: Tensor,                            # LongTensor[B, enc_T]
        speaker_ids: Tensor,                     # LongTensor[B]
        torchmoji_hidden: Tensor,                # FloatTensor[B, embed]
        sylps: Optional[Tensor] = None,          # FloatTensor[B] or None
        text_lengths: Optional[Tensor] = None,   # LongTensor[B] or None
        durations: Optional[Tensor] = None,      # FloatTensor[B, enc_T] or None
        perc_loudness: Optional[Tensor] = None,  # FloatTensor[B] or None
        f0: Optional[Tensor] = None,             # FloatTensor[B, dec_T] or None
        energy: Optional[Tensor] = None,         # FloatTensor[B, dec_T] or None
        mel_sigma: float = 1.0,
        dur_sigma: float = 1.0,
        var_sigma: float = 1.0):
    assert not self.training, "model must be in eval() mode"

    # move Tensors to GPU (if not already there)
    text, speaker_ids, torchmoji_hidden, sylps, text_lengths, durations, perc_loudness, f0, energy = self.update_device(
        text, speaker_ids, torchmoji_hidden, sylps, text_lengths, durations, perc_loudness, f0, energy)

    B, enc_T = text.shape
    if text_lengths is None:
        text_lengths = torch.ones((B,)).to(text) * enc_T
    assert text_lengths is not None

    # NOTE: gt_mels/output_lengths do not exist at inference; this expression is
    # only safe while self.mel_encoder is None or self.melenc_ignore is set.
    melenc_outputs = self.mel_encoder(gt_mels, output_lengths, speaker_ids=speaker_ids) if (self.mel_encoder is not None and not self.melenc_ignore) else None  # [B, dec_T, melenc_dim]

    embedded_text = self.embedding(text).transpose(1, 2)  # [B, embed, sequence]
    encoder_outputs, enc_global_outputs = self.encoder(embedded_text, text_lengths, speaker_ids=speaker_ids)  # [B, enc_T, enc_dim]
    if sylps is None:
        sylps = enc_global_outputs[:, 0:1]  # [B, 1]
    if perc_loudness is None:
        perc_loudness = enc_global_outputs[:, 2:3]  # [B, 1]
    assert sylps is not None  # needs to be updated with pred_sylps soon ^TM

    memory = [encoder_outputs, ]
    if self.speaker_embedding_dim:
        embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
        embedded_speakers = embedded_speakers.repeat(1, enc_T, 1)
        memory.append(embedded_speakers)  # [B, enc_T, enc_dim]
    if sylps is not None:
        sylps = sylps[..., None]  # [B, 1] -> [B, 1, 1]
        sylps = sylps.repeat(1, enc_T, 1)
        memory.append(sylps)  # [B, enc_T, enc_dim]
    if perc_loudness is not None:
        perc_loudness = perc_loudness[..., None]  # [B, 1] -> [B, 1, 1]
        perc_loudness = perc_loudness.repeat(1, enc_T, 1)
        memory.append(perc_loudness)  # [B, enc_T, enc_dim]
    if torchmoji_hidden is not None:
        emotion_embed = torchmoji_hidden.unsqueeze(1)  # [B, C] -> [B, 1, C]
        emotion_embed = self.torchmoji_linear(emotion_embed)  # [B, 1, in_C] -> [B, 1, out_C]
        emotion_embed = emotion_embed.repeat(1, enc_T, 1)
        memory.append(emotion_embed)  # [B, enc_T, enc_dim]
    memory = torch.cat(memory, dim=2)  # [[B, enc_T, enc_dim], [B, enc_T, speaker_dim]] -> [B, enc_T, enc_dim+speaker_dim]
    assert not (torch.isnan(memory) | torch.isinf(memory)).any(), 'Inf/NaN Loss at memory'

    # CVarGlow
    mask = get_mask_from_lengths(text_lengths)  # [B, T]
    cvars = self.cvar_glow.infer(memory.transpose(1, 2), sigma=dur_sigma)  # ([B, enc_dim, enc_T], )
    norm_char_f0 = cvars[:, 1:2]
    norm_char_energy = cvars[:, 2:3]
    char_voiced = cvars[:, 3:4]
    char_f0 = self.bn_cf0.inverse(norm_char_f0)
    char_energy = self.bn_cenergy.inverse(norm_char_energy)
    enc_durations = self.lbn_duration.inverse(cvars[:, :1], mask)  # [B, 8, enc_T] -> [B, 1, enc_T]

    memory = torch.cat((memory, cvars[:, 1:4].transpose(1, 2)), dim=2)  # [B, enc_T, enc_dim] +cat+ [B, enc_T, 3]
    attention_contexts = get_attention_from_lengths(memory, enc_durations[:, 0, :], text_lengths)  # -> [B, dec_T, enc_dim]
    B, dec_T, enc_dim = attention_contexts.shape

    variances = self.var_glow.infer(attention_contexts.transpose(1, 2), sigma=var_sigma)
    variances = variances.chunk(2, dim=1)[0]  # [B, 3, dec_T]
    voiced_mask = variances[:, 0, :]
    f0 = self.bn_f0.inverse(variances[:, 1:2, :]).squeeze(1)
    energy = self.bn_energy.inverse(variances[:, 2:3, :]).squeeze(1)

    global_cond = None
    if self.melenc_enable:
        # take all current info, and produce global cond tokens which can be randomly sampled from later
        # NOTE: n_tokens is assumed to be defined elsewhere (e.g. an attribute).
        global_cond = torch.randn(B, n_tokens)  # [B, n_tokens]

    # Decoder
    cond = [attention_contexts.transpose(1, 2), variances]
    if global_cond is not None:
        cond.append(global_cond)
    cond = torch.cat(cond, dim=1)
    spect = self.decoder.infer(cond, sigma=mel_sigma)

    outputs = {
        "spect": spect,
        "char_durs": enc_durations,
        "char_voiced": char_voiced,
        "char_f0": char_f0,
        "char_energy": char_energy,
        "frame_voiced_mask": voiced_mask,
        "frame_f0": f0,
        "frame_energy": energy,
    }
    return outputs
def forward(self, cond_inp, output_lengths, cond_lens=None):  # [B, seq_len, dim], int, [B]
    batch_size, enc_T, enc_dim = cond_inp.shape

    # get Random Position Offset (this *might* allow better distance generalisation)
    #trandint = torch.randint(10000, (1,), device=cond_inp.device, dtype=cond_inp.dtype)

    # get Query from Positional Encoding
    dec_T_max = output_lengths.max().item()
    dec_pos_emb = torch.arange(0, dec_T_max, device=cond_inp.device, dtype=cond_inp.dtype)  # + trandint
    if hasattr(self, 'pos_embedding_q'):
        dec_pos_emb = self.pos_embedding_q(
            dec_pos_emb.clamp(0, self.pos_embedding_q_max - 1).long()
        )[None, ...].repeat(cond_inp.size(0), 1, 1)  # [B, dec_T, enc_dim]
    elif hasattr(self, 'positional_embedding'):
        dec_pos_emb = self.positional_embedding(dec_pos_emb, bsz=cond_inp.size(0))  # [B, dec_T, enc_dim]
        if not self.merged_pos_enc:
            dec_pos_emb = dec_pos_emb.repeat(1, 1, self.head_num)
    if output_lengths is not None:  # masking for batches
        dec_mask = get_mask_from_lengths(output_lengths).unsqueeze(2)  # [B, dec_T, 1]
        dec_pos_emb = dec_pos_emb * dec_mask
    q = dec_pos_emb  # [B, dec_T, enc_dim]

    # get Key/Value from Encoder Outputs
    k = v = cond_inp  # [B, enc_T, enc_dim]

    # (optional) add position encoding to Encoder outputs
    if hasattr(self, 'enc_positional_embedding'):
        enc_pos_emb = torch.arange(0, enc_T, device=cond_inp.device, dtype=cond_inp.dtype)  # + trandint
        if hasattr(self, 'pos_embedding_kv'):
            enc_pos_emb = self.pos_embedding_kv(
                enc_pos_emb.clamp(0, self.pos_embedding_kv_max - 1).long()
            )[None, ...].repeat(cond_inp.size(0), 1, 1)  # [B, enc_T, enc_dim]
        elif hasattr(self, 'enc_positional_embedding'):
            enc_pos_emb = self.enc_positional_embedding(enc_pos_emb, bsz=cond_inp.size(0))  # [B, enc_T, enc_dim]
        if self.pos_enc_k:
            k = k + enc_pos_emb
        if self.pos_enc_v:
            v = v + enc_pos_emb

    enc_mask = get_mask_from_lengths(cond_lens).unsqueeze(1).repeat(1, q.size(1), 1) if (cond_lens is not None) else None  # [B, dec_T, enc_T]
    if not self.pytorch_native_mha:
        output, attention_scores = self.multi_head_attention(q, k, v, mask=enc_mask)  # [B, dec_T, enc_dim], [B, n_head, dec_T, enc_T]
    else:
        q = q.transpose(0, 1)  # [B, dec_T, enc_dim] -> [dec_T, B, enc_dim]
        k = k.transpose(0, 1)  # [B, enc_T, enc_dim] -> [enc_T, B, enc_dim]
        v = v.transpose(0, 1)  # [B, enc_T, enc_dim] -> [enc_T, B, enc_dim]
        enc_mask = ~enc_mask[:, 0, :] if (cond_lens is not None) else None  # [B, dec_T, enc_T] -> [B, enc_T]
        attn_mask = ~get_mask_3d(output_lengths, cond_lens).repeat_interleave(self.head_num, 0) if (cond_lens is not None) else None  # [B*n_head, dec_T, enc_T]
        # additive float mask; -35500.0 is a large negative value that stays finite in fp16
        attn_mask = attn_mask.float() * -35500.0 if (cond_lens is not None) else None
        output, attention_scores = self.multi_head_attention(q, k, v, key_padding_mask=enc_mask, attn_mask=attn_mask)  # [dec_T, B, enc_dim], [B, dec_T, enc_T]
        output = output.transpose(0, 1)  # [dec_T, B, enc_dim] -> [B, dec_T, enc_dim]

    output = output + self.o_residual_weights * dec_pos_emb
    attention_scores = attention_scores * get_mask_3d(output_lengths, cond_lens) if (cond_lens is not None) else attention_scores  # [B, dec_T, enc_T]

    for self_att_layer, residual_weight in zip(self.self_attention_layers, self.self_att_o_rws):
        q = output.transpose(0, 1)  # [B, dec_T, enc_dim] -> [dec_T, B, enc_dim]
        output, att_sc = self_att_layer(q, k, v, key_padding_mask=enc_mask, attn_mask=attn_mask)  # [dec_T, B, enc_dim], [B, dec_T, enc_T]
        output = output.transpose(0, 1)  # [dec_T, B, enc_dim] -> [B, dec_T, enc_dim]
        output = output * residual_weight + q.transpose(0, 1)  # ([B, dec_T, enc_dim] * rw) + [B, dec_T, enc_dim]
        att_sc = att_sc * get_mask_3d(output_lengths, cond_lens) if (cond_lens is not None) else att_sc
        attention_scores = attention_scores + att_sc
    attention_scores = attention_scores / (1 + len(self.self_att_o_rws))

    if output_lengths is not None:
        output = output * dec_mask  # [B, dec_T, enc_dim] * [B, dec_T, 1]
    return output, attention_scores
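# The native-MHA path above relies on two separate masks: key_padding_mask
# (True at PAD keys) and an additive float attn_mask where blocked pairs get
# -35500.0, a large negative that stays finite in fp16. A minimal sketch of
# that call pattern with torch.nn.MultiheadAttention:
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
q = torch.randn(5, 3, 8)      # [dec_T, B, E] (time-first)
k = v = torch.randn(7, 3, 8)  # [enc_T, B, E]
key_padding_mask = torch.zeros(3, 7, dtype=torch.bool)
key_padding_mask[:, 5:] = True                 # PAD keys
attn_mask = torch.zeros(3 * 2, 5, 7)           # [B*n_head, dec_T, enc_T]
attn_mask[..., 5:] = -35500.0                  # additive, fp16-safe "-inf"
out, att = mha(q, k, v, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
print(out.shape, att.shape)                    # [5, 3, 8], [3, 5, 7]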
def forward(self, model, pred, gt, loss_scalars):
    loss_dict = {}
    file_losses = {}  # dict of {"audiofile": {"spec_MSE": spec_MSE, "avg_prob": avg_prob, ...}, ...}

    B, n_mel, mel_T = gt['gt_mel'].shape
    for i in range(B):
        current_time = time.time()
        if gt['audiopath'][i] not in file_losses:
            file_losses[gt['audiopath'][i]] = {'speaker_id_ext': gt['speaker_id_ext'][i], 'time': current_time}

    if True:
        pred_mel_postnet = pred['pred_mel_postnet']
        pred_mel = pred['pred_mel']
        gt_mel = gt['gt_mel']
        mel_lengths = gt['mel_lengths']

        mask = get_mask_from_lengths(mel_lengths, max_len=gt_mel.size(2))
        mask = mask.expand(gt_mel.size(1), *mask.shape).permute(1, 0, 2)
        pred_mel_postnet.masked_fill_(~mask, 0.0)
        pred_mel.masked_fill_(~mask, 0.0)

        with torch.no_grad():
            assert not torch.isnan(pred_mel).any(), 'mel has NaNs'
            assert not torch.isinf(pred_mel).any(), 'mel has Infs'
            assert not torch.isnan(pred_mel_postnet).any(), 'mel has NaNs'
            assert not torch.isinf(pred_mel_postnet).any(), 'mel has Infs'

        B, n_mel, mel_T = gt_mel.shape

        # spectrogram / decoder loss
        pred_mel_selected = torch.masked_select(pred_mel, mask)
        gt_mel_selected = torch.masked_select(gt_mel, mask)
        spec_SE = nn.MSELoss(reduction='none')(pred_mel_selected, gt_mel_selected)
        loss_dict['spec_MSE'] = spec_SE.mean()

        losses = spec_SE.split([x * n_mel for x in mel_lengths.cpu()])
        for i in range(B):
            audiopath = gt['audiopath'][i]
            file_losses[audiopath]['spec_MSE'] = losses[i].mean().item()

        # postnet
        pred_mel_postnet_selected = torch.masked_select(pred_mel_postnet, mask)
        loss_dict['postnet_MSE'] = nn.MSELoss()(pred_mel_postnet_selected, gt_mel_selected)

        # squared by frame, mean postnet
        mask = mask.transpose(1, 2)[:, :, :1]  # [B, mel_T, n_mel] -> [B, mel_T, 1]
        spec_AE = nn.L1Loss(reduction='none')(pred_mel, gt_mel).transpose(1, 2)  # -> [B, mel_T, n_mel]
        spec_AE = spec_AE.masked_select(mask).view(mel_lengths.sum(), n_mel)  # -> [B*mel_T, n_mel]
        loss_dict['spec_MFSE'] = (spec_AE * spec_AE.mean(dim=1, keepdim=True)).mean()  # multiply by frame means (similar to the square op in MSE) and take the mean

        post_AE = nn.L1Loss(reduction='none')(pred_mel_postnet, gt_mel).transpose(1, 2)  # -> [B, mel_T, n_mel]
        post_AE = post_AE.masked_select(mask).view(mel_lengths.sum(), n_mel)  # -> [B*mel_T, n_mel]
        loss_dict['postnet_MFSE'] = (post_AE * post_AE.mean(dim=1, keepdim=True)).mean()
        del gt_mel, spec_AE, post_AE  # pred_mel_postnet, pred_mel

    if True:  # Code semantic loss.
        code_reconst = model(pred_mel_postnet, gt['speaker_embeds'], None)
        loss_dict['code_L1'] = F.l1_loss(pred['bottleneck_codes'], code_reconst)

    ##################################################################
    ## Collate / Merge the Losses into a single tensor with scalars ##
    ##################################################################
    loss_dict = self.colate_losses(loss_dict, loss_scalars)

    return loss_dict, file_losses
def forward(self, inputs):
    text, gt_mels, speaker_ids, text_lengths, output_lengths,\
        alignments, torchmoji_hidden, perc_loudness, f0, energy,\
        sylps, voiced_mask, char_f0, char_voiced, char_energy = inputs

    # zero mean unit variance normalization of features
    with torch.no_grad():
        perc_loudness = self.bn_pl(perc_loudness.unsqueeze(1))  # [B] -> [B, 1]

        mask = get_mask_from_lengths(output_lengths)  # [B, dec_T]
        f0 = self.bn_f0(f0.unsqueeze(1), (voiced_mask & mask))  # [B, dec_T] -> [B, 1, dec_T]
        energy = self.bn_energy(energy.unsqueeze(1), mask)  # [B, dec_T] -> [B, 1, dec_T]

        mask = get_mask_from_lengths(text_lengths)  # [B, enc_T]
        char_f0 = self.bn_cf0(char_f0.unsqueeze(1), mask)  # [B, 1, enc_T]
        char_energy = self.bn_cenergy(char_energy.unsqueeze(1), mask)  # [B, 1, enc_T]
        char_voiced = char_voiced.unsqueeze(1)  # [B, 1, enc_T]

        mask = get_mask_from_lengths(text_lengths)  # [B, T]
        enc_durations = alignments.sum(dim=1).unsqueeze(1)  # [B, dec_T, enc_T] -> [B, enc_T] -> [B, 1, enc_T]
        ln_enc_durations = self.lbn_duration(enc_durations, mask)  # [B, 1, enc_T] Norm

    embedded_text = self.embedding(text).transpose(1, 2)  # [B, embed, sequence]
    encoder_outputs, enc_global_outputs = self.encoder(embedded_text, text_lengths, speaker_ids=speaker_ids)  # [B, enc_T, enc_dim]

    memory = [encoder_outputs, ]
    if self.speaker_embedding_dim:
        embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
        embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
        memory.append(embedded_speakers)  # [B, enc_T, enc_dim]
    if sylps is not None:
        sylps = sylps[:, None, None]  # [B] -> [B, 1, 1]
        sylps = sylps.repeat(1, encoder_outputs.size(1), 1)
        memory.append(sylps)  # [B, enc_T, enc_dim]
    if perc_loudness is not None:
        perc_loudness = perc_loudness[..., None]  # [B, 1] -> [B, 1, 1]
        perc_loudness = perc_loudness.repeat(1, encoder_outputs.size(1), 1)
        memory.append(perc_loudness)  # [B, enc_T, enc_dim]
    if torchmoji_hidden is not None:
        emotion_embed = torchmoji_hidden.unsqueeze(1)  # [B, C] -> [B, 1, C]
        emotion_embed = self.torchmoji_linear(emotion_embed)  # [B, 1, in_C] -> [B, 1, out_C]
        emotion_embed = emotion_embed.repeat(1, encoder_outputs.size(1), 1)
        memory.append(emotion_embed)  # [B, enc_T, enc_dim]
    memory = torch.cat(memory, dim=2)  # [[B, enc_T, enc_dim], [B, enc_T, speaker_dim]] -> [B, enc_T, enc_dim+speaker_dim]
    assert not (torch.isnan(memory) | torch.isinf(memory)).any(), 'Inf/NaN Loss at memory'

    # CVarGlow
    cvar_gt = torch.cat((ln_enc_durations, char_f0, char_energy, char_voiced), dim=1).repeat(1, 2, 1)  # [B, 4, enc_T] -> [B, 8, enc_T]
    cvar_z, cvar_log_s_sum, cvar_logdet_w_sum = self.cvar_glow(cvar_gt, memory.transpose(1, 2))  # ([B, enc_T], [B, enc_dim, enc_T])

    memory = torch.cat((memory, char_f0.transpose(1, 2), char_energy.transpose(1, 2), char_voiced.transpose(1, 2)), dim=2)  # enc_dim += 3
    attention_contexts = alignments @ memory  # [B, dec_T, enc_T] @ [B, enc_T, enc_dim] -> [B, dec_T, enc_dim]

    # Variances Inpainter
    #   cond -> attention_contexts
    #   x/z  -> voiced_mask + f0 + energy
    var_gt = torch.cat((voiced_mask.to(f0.dtype).unsqueeze(1), f0, energy), dim=1)
    var_gt = var_gt.repeat(1, 2, 1)
    variance_z, variance_log_s_sum, variance_logdet_w_sum = self.var_glow(var_gt, attention_contexts.transpose(1, 2))

    global_cond = None
    if self.melenc_enable:
        # take all current info, and produce global cond tokens which can be randomly sampled from later
        # NOTE: this path looks unfinished; attention_contexts is [B, dec_T, enc_dim]
        # and voiced_mask is [B, dec_T], so both would need reshaping before this cat.
        melenc_input = torch.cat((gt_mels, attention_contexts, voiced_mask.float(), f0, energy), dim=1)
        global_cond, mu, logvar = self.mel_encoder(melenc_input, output_lengths)  # [B, n_tokens]

    # Decoder
    cond = [attention_contexts.transpose(1, 2), voiced_mask.to(f0.dtype).unsqueeze(1), f0, energy]
    if global_cond is not None:
        cond.append(global_cond)
    cond = torch.cat(cond, dim=1)
    z, log_s_sum, logdet_w_sum = self.decoder(gt_mels.clone(), cond)  # [B, n_mel, dec_T], [B, enc_dim, dec_T] -> series of flows

    outputs = {
        "melglow":   [z, log_s_sum, logdet_w_sum],
        "cvarglow":  [cvar_z, cvar_log_s_sum, cvar_logdet_w_sum],
        "varglow":   [variance_z, variance_log_s_sum, variance_logdet_w_sum],
        "sylps":     [enc_global_outputs, sylps],
        "perc_loud": [enc_global_outputs, perc_loudness],
    }
    return outputs
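# self.bn_pl / bn_f0 / bn_cenergy etc. are assumed to be masked normalization
# layers with an inverse() used at inference (the inference path above calls
# self.bn_f0.inverse). A simplified sketch of such a layer; the
# running-statistics update rule here is illustrative, not the repo's actual
# implementation:
import torch
import torch.nn as nn

class MaskedNorm1d(nn.Module):
    # Normalizes [B, 1, T] features using statistics from valid frames only;
    # inverse() undoes the normalization at inference time.
    def __init__(self, momentum=0.1):
        super().__init__()
        self.register_buffer('mean', torch.zeros(1))
        self.register_buffer('std', torch.ones(1))
        self.momentum = momentum

    def forward(self, x, mask=None):
        m = torch.ones_like(x, dtype=torch.bool) if mask is None else \
            (mask if mask.dim() == 3 else mask.unsqueeze(1))
        vals = x.masked_select(m)
        if self.training:  # EMA update of running stats over valid frames
            self.mean.lerp_(vals.mean(), self.momentum)
            self.std.lerp_(vals.std(), self.momentum)
        return torch.where(m, (x - self.mean) / self.std, x)

    def inverse(self, x, mask=None):
        return x * self.std + self.mean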
def forward(self, model_output, targets, criterion_dict, iter, em_kl_weight=None, DiagonalGuidedAttention_scalar=None): self.em_kl_weight = self.em_kl_weight if em_kl_weight is None else em_kl_weight self.DiagonalGuidedAttention_scalar = self.DiagonalGuidedAttention_scalar if DiagonalGuidedAttention_scalar is None else DiagonalGuidedAttention_scalar amp, n_gpus, model, model_d, hparams, optimizer, optimizer_d, grad_clip_thresh = criterion_dict.values( ) is_overflow = False grad_norm = 0.0 mel_target, gate_target, output_lengths, text_lengths, emotion_id_target, emotion_onehot_target, sylps_target, preserve_decoder, *_ = targets mel_target.requires_grad = False gate_target.requires_grad = False mel_out, mel_out_postnet, gate_out, alignments, pred_sylps, syl_package, em_package, aux_em_package, gan_package, *_ = model_output gate_target = gate_target.view(-1, 1) gate_out = gate_out.view(-1, 1) Bsz, n_mel, dec_T = mel_target.shape unknown_id = self.n_classes supervised_mask = (emotion_id_target != unknown_id) # [B] BoolTensor unsupervised_mask = ~supervised_mask # [B] BoolTensor # remove paddings before loss calc if self.masked_select: mask = get_mask_from_lengths(output_lengths) mask = mask.expand(mel_target.size(1), mask.size(0), mask.size(1)) mask = mask.permute(1, 0, 2) mel_target_not_masked = mel_target mel_target = torch.masked_select(mel_target, mask) if self.use_LL_Loss: mel_out, mel_logvar = mel_out.chunk(2, dim=1) if mel_out_postnet is not None: mel_out_postnet, mel_logvar_postnet = mel_out_postnet.chunk( 2, dim=1) mel_logvar = torch.masked_select(mel_logvar, mask) mel_logvar_postnet = torch.masked_select( mel_logvar_postnet, mask) mel_out_not_masked = mel_out mel_out = torch.masked_select(mel_out, mask) if mel_out_postnet is not None: mel_out_postnet_not_masked = mel_out_postnet mel_out_postnet = torch.masked_select(mel_out_postnet, mask) postnet_MSE = postnet_MAE = postnet_SMAE = postnet_LL = torch.tensor( 0.) 
# spectrogram / decoder loss spec_MSE = nn.MSELoss()(mel_out, mel_target) spec_MAE = nn.L1Loss()(mel_out, mel_target) spec_SMAE = nn.SmoothL1Loss()(mel_out, mel_target) if mel_out_postnet is not None: postnet_MSE = nn.MSELoss()(mel_out_postnet, mel_target) postnet_MAE = nn.L1Loss()(mel_out_postnet, mel_target) postnet_SMAE = nn.SmoothL1Loss()(mel_out_postnet, mel_target) if self.use_LL_Loss: spec_LL = NormalLLLoss(mel_out, mel_logvar, mel_target) loss = (spec_LL * self.melout_LL_scalar) if mel_out_postnet is not None: postnet_LL = NormalLLLoss(mel_out_postnet, mel_logvar_postnet, mel_target) loss += (postnet_LL * self.postnet_LL_scalar) else: spec_LL = postnet_LL = torch.tensor(0.0, device=mel_out.device) loss = (spec_MSE * self.melout_MSE_scalar) loss += (spec_MAE * self.melout_MAE_scalar) loss += (spec_SMAE * self.melout_SMAE_scalar) loss += (postnet_MSE * self.postnet_MSE_scalar) loss += (postnet_MAE * self.postnet_MAE_scalar) loss += (postnet_SMAE * self.postnet_SMAE_scalar) if True: # gate/stop loss gate_loss = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)( gate_out, gate_target) loss += gate_loss if True: # SylpsNet loss sylzu, syl_mu, syl_logvar = syl_package sylKLD = -0.5 * (1 + syl_logvar - syl_logvar.exp() - syl_mu.pow(2)).sum() / Bsz loss += (sylKLD * self.syl_KDL_weight) if True: # Pred Sylps loss pred_sylps = pred_sylps.squeeze(1) # [B, 1] -> [B] sylpsMSE = nn.MSELoss()(pred_sylps, sylps_target) sylpsMAE = nn.L1Loss()(pred_sylps, sylps_target) loss += (sylpsMSE * self.pred_sylpsMSE_weight) loss += (sylpsMAE * self.pred_sylpsMAE_weight) if True: # EmotionNet loss zs, em_zu, em_mu, em_logvar, em_params = [ x.squeeze(1) for x in em_package ] # VAE-GST loss SupervisedLoss = ClassicationMAELoss = ClassicationMSELoss = ClassicationNCELoss = SupervisedKDL = UnsupervisedLoss = UnsupervisedKDL = torch.tensor( 0) kl_scale = self.vae_kl_anneal_function(self.anneal_function, self.lag, iter, self.k, self.x0, self.upper) # outputs 0<s<1 em_kl_weight = kl_scale * self.em_kl_weight if (sum(supervised_mask) > 0): # if labeled data > 0: mu_labeled = em_mu[supervised_mask] logvar_labeled = em_logvar[supervised_mask] log_prob_labeled = zs[supervised_mask] y_onehot = emotion_onehot_target[supervised_mask] # -Elbo for labeled data (L(X,y)) SupervisedLoss, SupervisedKDL = self._L(y_onehot, mu_labeled, logvar_labeled, beta=em_kl_weight) loss += SupervisedLoss # Add MSE/MAE Loss prob_labeled = log_prob_labeled.exp() ClassicationMAELoss = nn.L1Loss(reduction='sum')( prob_labeled, y_onehot) / Bsz loss += (ClassicationMAELoss * self.zsClassificationMAELoss) ClassicationMSELoss = nn.MSELoss(reduction='sum')( prob_labeled, y_onehot) / Bsz loss += (ClassicationMSELoss * self.zsClassificationMSELoss) # Add auxiliary classification loss q(y|x) # negative cross entropy ClassicationNCELoss = -torch.sum(y_onehot * log_prob_labeled, dim=1).mean() loss += (ClassicationNCELoss * self.zsClassificationNCELoss) if (sum(unsupervised_mask) > 0): # if unlabeled data > 0: mu_unlabeled = em_mu[unsupervised_mask] logvar_unlabeled = em_logvar[unsupervised_mask] log_prob_unlabeled = zs[unsupervised_mask] # -Elbo for unlabeled data (U(x)) UnsupervisedLoss, UnsupervisedKDL = self._U(log_prob_unlabeled, mu_unlabeled, logvar_unlabeled, beta=em_kl_weight) loss += UnsupervisedLoss if True: # AuxEmotionNet loss aux_zs, aux_em_mu, aux_em_logvar, aux_em_params = [ x.squeeze(1) for x in aux_em_package ] PredDistMSE = PredDistMAE = AuxClassicationMAELoss = AuxClassicationMSELoss = AuxClassicationNCELoss = torch.tensor( 0) # pred 
        # predicted em_zu distribution parameter loss
        PredDistMSE = nn.MSELoss()(aux_em_params, em_params)
        PredDistMAE = nn.L1Loss()(aux_em_params, em_params)
        loss += (PredDistMSE * self.predzu_MSE_weight + PredDistMAE * self.predzu_MAE_weight)
        # aux zs classification loss
        if supervised_mask.any():  # if the batch contains labeled data
            log_prob_labeled = aux_zs[supervised_mask]
            prob_labeled = log_prob_labeled.exp()
            AuxClassificationMAELoss = nn.L1Loss(reduction='sum')(prob_labeled, y_onehot) / Bsz
            loss += (AuxClassificationMAELoss * self.auxClassificationMAELoss)
            AuxClassificationMSELoss = nn.MSELoss(reduction='sum')(prob_labeled, y_onehot) / Bsz
            loss += (AuxClassificationMSELoss * self.auxClassificationMSELoss)
            AuxClassificationNCELoss = -torch.sum(y_onehot * log_prob_labeled, dim=1).mean()
            loss += (AuxClassificationNCELoss * self.auxClassificationNCELoss)
    
    if True:  # Diagonal Attention Guiding
        # only items without preserved decoder state are guided
        AttentionLoss = self.guided_att(alignments[preserve_decoder == 0.0],
                                        text_lengths[preserve_decoder == 0.0],
                                        output_lengths[preserve_decoder == 0.0])
        loss += (AttentionLoss * self.DiagonalGuidedAttention_scalar)
    
    reduced_d_loss = reduced_avg_fakeness = avg_fakeness = 0.0
    GAN_Spect_MAE = adv_postnet_loss = torch.tensor(0.)
    if gan_package[0] is not None:
        real_labels = torch.zeros(mel_target_not_masked.shape[0], device=loss.device, dtype=loss.dtype)  # [B]
        fake_labels = torch.ones(mel_target_not_masked.shape[0], device=loss.device, dtype=loss.dtype)   # [B]
        mel_outputs_adv, speaker_embed, *_ = gan_package
        if self.masked_select:
            fill_mask = mel_target_not_masked == 0.0
            mel_outputs_adv = mel_outputs_adv.clone()
            mel_outputs_adv.masked_fill_(fill_mask, 0.0)
            mel_outputs_adv_masked = torch.masked_select(mel_outputs_adv, mask)
            mel_out_not_masked = mel_out_not_masked.clone()
            mel_out_not_masked.masked_fill_(fill_mask, 0.0)
            if mel_out_postnet is not None:
                mel_out_postnet_not_masked = mel_out_postnet_not_masked.clone()
                mel_out_postnet_not_masked.masked_fill_(fill_mask, 0.0)
        else:
            mel_outputs_adv_masked = mel_outputs_adv  # fallback so the MAE below is still defined
        # spectrograms [B, n_mel, dec_T]:
        #   mel_target_not_masked
        #   mel_out_not_masked
        #   mel_out_postnet_not_masked
        #   mel_outputs_adv
        speaker_embed = speaker_embed.unsqueeze(2).repeat(1, 1, dec_T)  # [B, embed] -> [B, embed, dec_T]
        fake_pred_fakeness = model_d(mel_outputs_adv, speaker_embed.detach())  # should speaker_embed stay attached to the computational graph? Not sure atm
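        # Note the inverted label convention: real_labels are 0.0 and fake_labels
        # are 1.0, so model_d predicts "fakeness" rather than realness. The
        # generator loss below therefore targets real_labels (fakeness -> 0),
        # while the discriminator loss later targets fake_labels for generated
        # spectrograms.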
        avg_fakeness = fake_pred_fakeness.mean()  # metric for Tensorboard
        reduced_avg_fakeness = reduce_tensor(avg_fakeness.data, n_gpus).item() if hparams.distributed_run else avg_fakeness.item()
        adv_postnet_loss = nn.BCELoss()(fake_pred_fakeness, real_labels)  # [B] -> []; train the generator to decrease the predicted fakeness of its samples
        GAN_Spect_MAE = nn.L1Loss()(mel_outputs_adv_masked, mel_target)
        if reduced_avg_fakeness > 0.4:  # only backprop the adversarial terms while the discriminator is ahead
            loss += (adv_postnet_loss * self.adv_postnet_scalar)
            loss += (GAN_Spect_MAE * (self.adv_postnet_scalar * self.adv_postnet_reconstruction_weight))
    
    # Tacotron2 optimizer / loss
    if hparams.distributed_run:
        reduced_loss = reduce_tensor(loss.data, n_gpus).item()
        reduced_gate_loss = reduce_tensor(gate_loss.data, n_gpus).item()
    else:
        reduced_loss = loss.item()
        reduced_gate_loss = gate_loss.item()
    
    if optimizer is not None:
        if hparams.fp16_run:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        if hparams.fp16_run:
            grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), grad_clip_thresh)
            is_overflow = math.isinf(grad_norm) or math.isnan(grad_norm)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)
        optimizer.step()
    
    # (optional) discriminator optimizer / loss
    if gan_package[0] is not None:
        if optimizer_d is not None:
            optimizer_d.zero_grad()
        # spectrograms [B, n_mel, dec_T]:
        #   mel_target_not_masked
        #   mel_out_not_masked
        #   mel_out_postnet_not_masked
        #   mel_outputs_adv
        fake_pred_fakeness = model_d(mel_outputs_adv.detach(), speaker_embed.detach())
        fake_d_loss = nn.BCELoss()(fake_pred_fakeness, fake_labels)  # [B] -> []; increase the discriminated fakeness of fake samples
        real_pred_fakeness = model_d(mel_target_not_masked.detach(), speaker_embed.detach())
        real_d_loss = nn.BCELoss()(real_pred_fakeness, real_labels)  # [B] -> []; decrease the discriminated fakeness of real samples
        if self.dis_postnet_scalar and mel_out_postnet is not None:
            fake_pred_fakeness = model_d(mel_out_postnet_not_masked.detach(), speaker_embed.detach())
            fake_d_loss += self.dis_postnet_scalar * nn.BCELoss()(fake_pred_fakeness, fake_labels)  # increase the discriminated fakeness of postnet outputs
        if self.dis_spect_scalar:
            fake_pred_fakeness = model_d(mel_out_not_masked.detach(), speaker_embed.detach())
            fake_d_loss += self.dis_spect_scalar * nn.BCELoss()(fake_pred_fakeness, fake_labels)  # increase the discriminated fakeness of decoder outputs
        d_loss = (real_d_loss + fake_d_loss) * (self.adv_postnet_scalar * 0.5)
        reduced_d_loss = reduce_tensor(d_loss.data, n_gpus).item() if hparams.distributed_run else d_loss.item()
        if optimizer_d is not None and reduced_avg_fakeness < 0.85:  # skip the update once the discriminator is too strong
            if hparams.fp16_run:
                with amp.scale_loss(d_loss, optimizer_d) as scaled_loss:
                    scaled_loss.backward()
            else:
                d_loss.backward()
            if hparams.fp16_run:
                grad_norm_d = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_d), grad_clip_thresh)
                is_overflow = math.isinf(grad_norm_d) or math.isnan(grad_norm_d)
            else:
                grad_norm_d = torch.nn.utils.clip_grad_norm_(model_d.parameters(), grad_clip_thresh)
            optimizer_d.step()
    
    with torch.no_grad():  # debug/fun
        S_Bsz = supervised_mask.sum().item()
        U_Bsz = unsupervised_mask.sum().item()
        ClassificationAccStr = 'N/A'
        Top1ClassificationAcc = 0.0
        if S_Bsz > 0:
            # top-1 accuracy over the labeled part of the batch
            Top1ClassificationAcc = (torch.argmax(log_prob_labeled.exp(), dim=1) == torch.argmax(y_onehot, dim=1)).float().sum().item() / S_Bsz
            self.AvgClassAcc = self.AvgClassAcc * 0.95 + Top1ClassificationAcc * 0.05
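            # AvgClassAcc is an exponential moving average (decay 0.95, i.e.
            # roughly a 20-batch window), so the logged accuracy is less noisy
            # than the per-batch top-1 value.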
            ClassificationAccStr = round(Top1ClassificationAcc * 100, 2)
        
        print("            Total loss = ", loss.item(), '\n',
              "             Spect LLL = ", spec_LL.item(), '\n',
              "     Postnet Spect LLL = ", postnet_LL.item(), '\n',
              "             Spect MSE = ", spec_MSE.item(), '\n',
              "             Spect MAE = ", spec_MAE.item(), '\n',
              "            Spect SMAE = ", spec_SMAE.item(), '\n',
              "     Postnet Spect MSE = ", postnet_MSE.item(), '\n',
              "     Postnet Spect MAE = ", postnet_MAE.item(), '\n',
              "    Postnet Spect SMAE = ", postnet_SMAE.item(), '\n',
              "              Gate BCE = ", gate_loss.item(), '\n',
              "                sylKLD = ", sylKLD.item(), '\n',
              "              sylpsMSE = ", sylpsMSE.item(), '\n',
              "              sylpsMAE = ", sylpsMAE.item(), '\n',
              "        SupervisedLoss = ", SupervisedLoss.item(), '\n',
              "         SupervisedKDL = ", SupervisedKDL.item(), '\n',
              "      UnsupervisedLoss = ", UnsupervisedLoss.item(), '\n',
              "       UnsupervisedKDL = ", UnsupervisedKDL.item(), '\n',
              " ClassificationMSELoss = ", ClassificationMSELoss.item(), '\n',
              " ClassificationMAELoss = ", ClassificationMAELoss.item(), '\n',
              " ClassificationNCELoss = ", ClassificationNCELoss.item(), '\n',
              "AuxClassificationMSELoss = ", AuxClassificationMSELoss.item(), '\n',
              "AuxClassificationMAELoss = ", AuxClassificationMAELoss.item(), '\n',
              "AuxClassificationNCELoss = ", AuxClassificationNCELoss.item(), '\n',
              "      Predicted Zu MSE = ", PredDistMSE.item(), '\n',
              "      Predicted Zu MAE = ", PredDistMAE.item(), '\n',
              "     DiagAttentionLoss = ", AttentionLoss.item(), '\n',
              "       PredAvgFakeness = ", reduced_avg_fakeness, '\n',
              "          GeneratorMAE = ", GAN_Spect_MAE.item(), '\n',
              "         GeneratorLoss = ", adv_postnet_loss.item(), '\n',
              "     DiscriminatorLoss = ", reduced_d_loss / self.adv_postnet_scalar, '\n',
              "     ClassificationAcc = ", ClassificationAccStr, '%\n',
              "  AvgClassificationAcc = ", round(self.AvgClassAcc * 100, 2), '%\n',
              "      Total Batch Size = ", Bsz, '\n',
              "      Super Batch Size = ", S_Bsz, '\n',
              "      UnSup Batch Size = ", U_Bsz, '\n', sep='')
        
        # [value, weight] pairs for logging
        loss_terms = [
            [loss.item(), 1.0],
            [spec_MSE.item(), self.melout_MSE_scalar],
            [spec_MAE.item(), self.melout_MAE_scalar],
            [spec_SMAE.item(), self.melout_SMAE_scalar],
            [postnet_MSE.item(), self.postnet_MSE_scalar],
            [postnet_MAE.item(), self.postnet_MAE_scalar],
            [postnet_SMAE.item(), self.postnet_SMAE_scalar],
            [gate_loss.item(), 1.0],
            [sylKLD.item(), self.syl_KDL_weight],
            [sylpsMSE.item(), self.pred_sylpsMSE_weight],
            [sylpsMAE.item(), self.pred_sylpsMAE_weight],
            [SupervisedLoss.item(), 1.0],
            [SupervisedKDL.item(), em_kl_weight * 0.5],
            [UnsupervisedLoss.item(), 1.0],
            [UnsupervisedKDL.item(), em_kl_weight * 0.5],
            [ClassificationMSELoss.item(), self.zsClassificationMSELoss],
            [ClassificationMAELoss.item(), self.zsClassificationMAELoss],
            [ClassificationNCELoss.item(), self.zsClassificationNCELoss],
            [AuxClassificationMSELoss.item(), self.auxClassificationMSELoss],
            [AuxClassificationMAELoss.item(), self.auxClassificationMAELoss],
            [AuxClassificationNCELoss.item(), self.auxClassificationNCELoss],
            [PredDistMSE.item(), self.predzu_MSE_weight],
            [PredDistMAE.item(), self.predzu_MAE_weight],
            [Top1ClassificationAcc, 1.0],
            [reduced_avg_fakeness, 1.0],
            [adv_postnet_loss.item(), self.adv_postnet_scalar],
            [reduced_d_loss / self.adv_postnet_scalar, self.adv_postnet_scalar],
        ]
    return loss, gate_loss, loss_terms, reduced_loss, reduced_gate_loss, grad_norm, is_overflow
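
# Neither NormalLLLoss nor vae_kl_anneal_function is defined in this file. The
# sketches below are minimal stand-ins matching the call sites above; they are
# assumptions about the intended behaviour, not the repo's actual implementations.
import math


def NormalLLLoss(mu, logvar, target):
    """Gaussian negative log-likelihood, averaged over all (masked) elements.

    Matches the call NormalLLLoss(mel_out, mel_logvar, mel_target) above, with
    mu/logvar/target already flattened to matching shapes by masked_select.
    The sign is chosen so the term is minimized; the constant log(2*pi)/2 is
    omitted since it does not affect the gradients.
    """
    return (0.5 * (logvar + (target - mu).pow(2) / logvar.exp())).mean()


def vae_kl_anneal_function(anneal_function, lag, step, k, x0, upper):
    """KL-weight schedule returning a value in (0, upper]; assumed to follow
    the common logistic/linear annealing of Bowman et al. (2016)."""
    if anneal_function == 'logistic':  # smooth S-curve centered on step x0
        return upper / (1 + math.exp(-k * (step - x0)))
    elif anneal_function == 'linear':  # hold at 0 for `lag` steps, then ramp up
        return min(upper, step / x0) if step > lag else 0.0
    else:  # 'constant'
        return upper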