def forward(self, x, target, length):
    """Compute an L1 loss over only the valid (non-padded) steps.

    Args:
        x: predictions, shape (batch, max_len, dim).
        target: ground-truth values, shape (batch, max_len, dim).
        length: valid length of each sample, shape (batch,).

    Returns:
        A scalar loss averaged over the masked region. With
        ``self.seq_len_norm`` every sample contributes equally
        regardless of its length.
    """
    # targets are constants; no gradients flow through them
    target.requires_grad = False
    # valid-step mask: (batch, max_len, 1)
    step_mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
    full_mask = step_mask.expand_as(x)
    if self.seq_len_norm:
        # per-step weights normalized per sample, then spread over batch and feature dims
        weights = (step_mask / step_mask.sum(dim=1, keepdim=True)).div(target.shape[0] * target.shape[2])
        raw = functional.l1_loss(x * full_mask, target * full_mask, reduction="none")
        return raw.mul(weights.to(raw.device)).sum()
    total = functional.l1_loss(x * full_mask, target * full_mask, reduction="sum")
    return total / full_mask.sum()
def forward(self, x, target, length):
    """Compute BCE-with-logits averaged over the valid steps only.

    Args:
        x: unnormalized logits, shape (batch, max_len).
        target: binary targets, shape (batch, max_len).
        length: valid length per sample, shape (batch,); if ``None``
            every element counts.

    Returns:
        A scalar: summed BCE divided by the number of valid items.
    """
    # targets are constants; no gradients flow through them
    target.requires_grad = False
    if length is not None:
        # zero out padded positions and count only the valid ones
        valid = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
        x = x * valid
        target = target * valid
        num_items = valid.sum()
    else:
        num_items = torch.numel(x)
    total = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
    return total / num_items
def decoder_inference(
    self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
):  # pylint: disable=dangerous-default-value
    """Run the decoder forward on a ground-truth spectrogram, then invert it.

    Shapes:
        - y: :math:`[B, T, C]`
        - y_lengths: :math:`B`
        - g: :math:`[B, C] or B`
    """
    y = y.transpose(1, 2)
    max_frames = y.size(2)
    g = self._speaker_embedding(aux_input)
    y_mask = sequence_mask(y_lengths, max_frames).unsqueeze(1).to(y.dtype)
    # forward pass: spectrogram -> latent
    z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
    # inverse pass: latent -> spectrogram reconstruction
    y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
    return {"model_outputs": y.transpose(1, 2), "logdet": logdet}
def seqeunce_mask_test():
    """Check that `sequence_mask` marks exactly the first `length` steps as valid."""
    lengths = T.randint(10, 15, (8,))
    mask = sequence_mask(lengths)
    for idx, length in enumerate(lengths.tolist()):
        # valid positions are ones, padded positions are zeros
        assert mask[idx, :length].sum() == length
        assert mask[idx, length:].sum() == 0
def _forward_mdn(self, o_en, y, y_lengths, x_mask):
    """Predict per-step Gaussian parameters and align them to the target with MAS."""
    mu, log_sigma = self.mdn_block(o_en)
    # decoder-side mask built from the target lengths
    y_mask = sequence_mask(y_lengths, None).unsqueeze(1).to(o_en.dtype)
    dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask)
    return dr_mas, mu, log_sigma, logp
def forward(
    self,
    mel_slice,
    mel_slice_hat,
    z_p,
    logs_q,
    m_p,
    logs_p,
    z_len,
    scores_disc_fake,
    feats_disc_fake,
    feats_disc_real,
    loss_duration,
    use_speaker_encoder_as_loss=False,
    gt_spk_emb=None,
    syn_spk_emb=None,
):
    """Compute the weighted sum of the generator-side losses (KL, feature-matching,
    adversarial, mel L1, duration, and optionally speaker-consistency).

    Shapes:
        - mel_slice : :math:`[B, 1, T]`
        - mel_slice_hat: :math:`[B, 1, T]`
        - z_p: :math:`[B, C, T]`
        - logs_q: :math:`[B, C, T]`
        - m_p: :math:`[B, C, T]`
        - logs_p: :math:`[B, C, T]`
        - z_len: :math:`[B]`
        - scores_disc_fake[i]: :math:`[B, C]`
        - feats_disc_fake[i][j]: :math:`[B, C, T', P]`
        - feats_disc_real[i][j]: :math:`[B, C, T', P]`

    Returns:
        dict with each individual loss and the combined "loss".
    """
    loss = 0.0
    return_dict = {}
    # mask of valid latent frames
    z_mask = sequence_mask(z_len).float()
    # compute losses, each scaled by its configured alpha weight
    loss_kl = (
        self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1))
        * self.kl_loss_alpha
    )
    loss_feat = (
        self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake)
        * self.feat_loss_alpha
    )
    loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
    loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
    loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
    loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
    if use_speaker_encoder_as_loss:
        # optional speaker-consistency term between ground-truth and synthesized embeddings
        loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha
        loss = loss + loss_se
        return_dict["loss_spk_encoder"] = loss_se
    # pass losses to the dict
    return_dict["loss_gen"] = loss_gen
    return_dict["loss_kl"] = loss_kl
    return_dict["loss_feat"] = loss_feat
    return_dict["loss_mel"] = loss_mel
    return_dict["loss_duration"] = loss_duration
    return_dict["loss"] = loss
    return return_dict
def generate_attn(dr, x_mask, y_mask=None):
    """Generate an attention alignment map from durations.

    Shapes
        - dr: :math:`(B, T_{en})`
        - x_mask: :math:`(B, T_{en})`
        - y_mask: :math:`(B, T_{de})`
    """
    # compute decode mask from the durations
    if y_mask is None:
        y_lengths = dr.sum(1).long()
        # guarantee at least one output frame per sample
        y_lengths[y_lengths < 1] = 1
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
    return attn
def forward(self, x, y, length=None):
    """Compute a masked Smooth-L1 (Huber) loss averaged over valid steps.

    Args:
        x: predictions.
        y: targets.
        length: valid length per sample, shape (batch,). If ``None``,
            no masking is applied and the loss is averaged over all
            elements.

    Shapes:
        x: B x T x 1
        y: B x T x 1
        length: B
    """
    if length is None:
        # Fix: the default `length=None` previously crashed because the
        # mask was built unconditionally; fall back to a plain average.
        return torch.nn.functional.smooth_l1_loss(x, y, reduction="sum") / torch.numel(y)
    # mask: (batch, max_len, 1), broadcast over the trailing feature dim
    mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float()
    return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum()
def test_decoder():
    """Smoke-test each decoder variant and check the expected output shape."""
    input_dummy = torch.rand(8, 128, 37).to(device)
    input_lengths = torch.randint(31, 37, (8,)).long().to(device)
    # force one full-length sample so the mask covers max_len
    input_lengths[-1] = 37
    input_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
    # residual bn conv decoder
    layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # transformer decoder
    layer = Decoder(
        out_channels=11,
        in_hidden_channels=128,
        decoder_type="relative_position_transformer",
        decoder_params={
            "hidden_channels_ffn": 128,
            "num_heads": 2,
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 8,
            "rel_attn_window_size": 4,
            "input_length": None,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # wavenet decoder
    layer = Decoder(
        out_channels=11,
        in_hidden_channels=128,
        decoder_type="wavenet",
        decoder_params={
            "num_blocks": 12,
            "hidden_channels": 192,
            "kernel_size": 5,
            "dilation_rate": 1,
            "num_layers": 4,
            "dropout_p": 0.05,
        },
    ).to(device)
    # NOTE(review): no shape assertion after the wavenet pass — confirm whether intended
    output = layer(input_dummy, input_mask)
    # FFTransformer decoder
    layer = Decoder(
        out_channels=11,
        in_hidden_channels=128,
        decoder_type="fftransformer",
        decoder_params={
            "hidden_channels_ffn": 31,
            "num_heads": 2,
            "dropout_p": 0.1,
            "num_layers": 2,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
def forward(
    self, x, x_lengths, y, y_lengths, aux_input={"d_vectors": None}, phase=None
):  # pylint: disable=unused-argument
    """Phase-dependent training forward pass.

    Shapes:
        - x: :math:`[B, T_max]`
        - x_lengths: :math:`[B]`
        - y_lengths: :math:`[B]`
        - dr: :math:`[B, T_max]`
        - g: :math:`[B, C]`
    """
    y = y.transpose(1, 2)
    g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
    o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
    if phase == 0:
        # train encoder and MDN
        o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
        dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
        attn = self.generate_attn(dr_mas, x_mask, y_mask)
    elif phase == 1:
        # train decoder (encoder/MDN outputs detached)
        o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
        dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
        o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g)
    elif phase == 2:
        # train the whole except duration predictor
        o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
        dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
        o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
    elif phase == 3:
        # train duration predictor
        o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
        o_dr_log = self.duration_predictor(x, x_mask)
        dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
        o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
        o_dr_log = o_dr_log.squeeze(1)
    else:
        # default: train everything, duration predictor fed a detached input
        o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
        o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
        dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
        o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
        o_dr_log = o_dr_log.squeeze(1)
    dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
    # NOTE(review): in phase 0 `o_de` stays None, so `o_de.transpose` below would
    # fail — confirm phase 0 callers never read "model_outputs".
    outputs = {
        "model_outputs": o_de.transpose(1, 2),
        "alignments": attn,
        "durations_log": o_dr_log,
        "durations_mas_log": dr_mas_log,
        "mu": mu,
        "log_sigma": log_sigma,
        "logp": logp,
    }
    return outputs
def generate_path_test():
    """Check that `generate_path` expands durations into a monotonic alignment path."""
    durations = T.randint(1, 4, (10, 21))
    x_length = T.randint(18, 22, (10,))
    x_mask = sequence_mask(x_length).unsqueeze(1).long()
    # zero out durations of padded input steps
    durations = durations * x_mask.squeeze(1)
    y_length = durations.sum(1)
    y_mask = sequence_mask(y_length).unsqueeze(1).long()
    attn_mask = (T.unsqueeze(x_mask, -1) * T.unsqueeze(y_mask, 2)).squeeze(1).long()
    print(attn_mask.shape)
    path = generate_path(durations, attn_mask)
    assert path.shape == (10, 21, durations.sum(1).max().item())
    for b in range(durations.shape[0]):
        current_idx = 0
        for t in range(durations.shape[1]):
            # each input step must own a contiguous run of output frames
            assert all(path[b, t, current_idx : current_idx + durations[b, t].item()] == 1.0)
            assert all(path[b, t, :current_idx] == 0.0)
            assert all(path[b, t, current_idx + durations[b, t].item() :] == 0.0)
            current_idx += durations[b, t].item()
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
    """Inference pass: text -> durations -> aligned prior -> flow -> waveform.

    Shapes:
        - x: :math:`[B, T_seq]`
        - d_vectors: :math:`[B, C, 1]`
        - speaker_ids: :math:`[B]`
    """
    sid, g = self._set_cond_input(aux_input)
    x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
    x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths)
    if self.num_speakers > 0 and (sid is not None):
        g = self.emb_g(sid).unsqueeze(-1)
    # predict log-durations (stochastic or deterministic predictor)
    if self.args.use_sdp:
        logw = self.duration_predictor(
            x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp
        )
    else:
        logw = self.duration_predictor(x, x_mask, g=g)
    w = torch.exp(logw) * x_mask * self.length_scale
    w_ceil = torch.ceil(w)
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
    # Fix: add the channel dim so y_mask is [B, 1, T_dec]; without it the
    # `attn_mask` product and `z * y_mask` below broadcast incorrectly.
    y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1)
    attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
    attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
    # expand the prior statistics with the predicted alignment
    m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
    logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
    z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
    z = self.flow(z_p, y_mask, g=g, reverse=True)
    o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)
    outputs = {
        "model_outputs": o,
        "alignments": attn.squeeze(1),
        "z": z,
        "z_p": z_p,
        "m_p": m_p,
        "logs_p": logs_p,
    }
    return outputs
def forward(
    self,
    waveform,
    waveform_hat,
    z_p,
    logs_q,
    m_p,
    logs_p,
    z_len,
    scores_disc_fake,
    feats_disc_fake,
    feats_disc_real,
    loss_duration,
):
    """Compute the weighted sum of the generator-side losses from raw waveforms
    (mel spectrograms are computed internally via ``self.stft``).

    Shapes:
        - waveform : :math:`[B, 1, T]`
        - waveform_hat: :math:`[B, 1, T]`
        - z_p: :math:`[B, C, T]`
        - logs_q: :math:`[B, C, T]`
        - m_p: :math:`[B, C, T]`
        - logs_p: :math:`[B, C, T]`
        - z_len: :math:`[B]`
        - scores_disc_fake[i]: :math:`[B, C]`
        - feats_disc_fake[i][j]: :math:`[B, C, T', P]`
        - feats_disc_real[i][j]: :math:`[B, C, T', P]`

    Returns:
        dict with each individual loss and the combined "loss".
    """
    loss = 0.0
    return_dict = {}
    # mask of valid latent frames
    z_mask = sequence_mask(z_len).float()
    # compute mel spectrograms from the waveforms
    mel = self.stft(waveform)
    mel_hat = self.stft(waveform_hat)
    # compute losses, each scaled by its configured alpha weight
    loss_feat = self.feature_loss(feats_disc_fake, feats_disc_real) * self.feat_loss_alpha
    loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
    loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
    loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
    loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
    loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
    # pass losses to the dict
    return_dict["loss_gen"] = loss_gen
    return_dict["loss_kl"] = loss_kl
    return_dict["loss_feat"] = loss_feat
    return_dict["loss_mel"] = loss_mel
    return_dict["loss_duration"] = loss_duration
    return_dict["loss"] = loss
    return return_dict
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
    """Expand encoder outputs by the durations and run the decoder.

    Returns the decoder output and the duration-derived attention map
    (transposed to [B, T_de, T_en]).
    """
    y_mask = sequence_mask(y_lengths, None).unsqueeze(1).to(o_en_dp.dtype)
    # stretch encoder frames onto the decoder time axis
    o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
    if hasattr(self, "pos_encoder"):
        # inject positional information after expansion
        o_en_ex = self.pos_encoder(o_en_ex, y_mask)
    if g is not None:
        o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
    o_de = self.decoder(o_en_ex, y_mask, g=g)
    return o_de, attn.transpose(1, 2)
def inference(
    self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}
):  # pylint: disable=dangerous-default-value
    """Inference pass: predict durations, build the alignment and decode.

    Shapes (assumed from siblings — TODO confirm):
        - x: :math:`[B, T]`
        - x_lengths: :math:`[B]`
    """
    x_lengths = aux_input["x_lengths"]
    g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
    # norm speaker embeddings
    if g is not None:
        if self.d_vector_dim:
            g = F.normalize(g).unsqueeze(-1)
        else:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
    # embedding pass
    o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
    # compute output durations
    w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
    w_ceil = torch.ceil(w)
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
    y_max_length = None
    # compute masks
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    # compute attention mask
    attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
    y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
    # sample the latent from the predicted distribution
    z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask
    # decoder pass (inverse flow: latent -> spectrogram)
    y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
    attn = attn.squeeze(1).permute(0, 2, 1)
    outputs = {
        "model_outputs": y.transpose(1, 2),
        "logdet": logdet,
        "y_mean": y_mean.transpose(1, 2),
        "y_log_scale": y_log_scale.transpose(1, 2),
        "alignments": attn,
        "durations_log": o_dur_log.transpose(1, 2),
        "total_durations_log": o_attn_dur.transpose(1, 2),
    }
    return outputs
def forward(self, x, x_lengths, g=None):
    """Encode the input into posterior statistics and a reparameterized sample.

    Shapes:
        - x: :math:`[B, C, T]`
        - x_lengths: :math:`[B, 1]`
        - g: :math:`[B, C, 1]`
    """
    x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)
    hidden = self.pre(x) * x_mask
    hidden = self.enc(hidden, x_mask, g=g)
    stats = self.proj(hidden) * x_mask
    mean, log_scale = torch.split(stats, self.out_channels, dim=1)
    # reparameterization trick: z = mean + eps * sigma
    z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
    return z, mean, log_scale, x_mask
def forward(self, x, x_lengths):
    """Embed, encode and project the input text sequence.

    Shapes:
        - x: :math:`[B, T]`
        - x_length: :math:`[B]`
    """
    # scaled character embeddings: [b, t, h] -> [b, h, t]
    emb = self.emb(x) * math.sqrt(self.hidden_channels)
    emb = torch.transpose(emb, 1, -1)
    x_mask = sequence_mask(x_lengths, emb.size(2)).unsqueeze(1).to(emb.dtype)
    encoded = self.encoder(emb * x_mask, x_mask)
    stats = self.proj(encoded) * x_mask
    m, logs = torch.split(stats, self.out_channels, dim=1)
    return encoded, m, logs, x_mask
def generate_attn(dr, x_mask, y_mask=None):
    """Generate an attention map from durations.

    Shapes
        - dr: :math:`(B, T_{en})`
        - x_mask: :math:`(B, T_{en})`
        - y_mask: :math:`(B, T_{de})`
    """
    if y_mask is None:
        # derive output lengths from the durations; clamp to >= 1 frame
        out_lens = dr.sum(1).long()
        out_lens[out_lens < 1] = 1
        y_mask = sequence_mask(out_lens, None).unsqueeze(1).to(dr.dtype)
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
    return generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
def forward(self, y_hat, y, length=None):
    """Compute ``1 - similarity`` between prediction and target, optionally masked.

    Args:
        y_hat (tensor): model prediction values.
        y (tensor): target values.
        length (tensor): length of each sample in a batch, or ``None`` for no masking.

    Shapes:
        y_hat: B x T X D
        y: B x T x D
        length: B

    Returns:
        loss: An average loss value in range [0, 1] masked by the length.
    """
    if length is not None:
        # zero out the padded frames before comparing
        m = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float().to(y_hat.device)
        y_hat = y_hat * m
        y = y * m
    return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1))
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
    """Model's inference pass.

    Args:
        x (torch.LongTensor): Input character sequence.
        aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`.

    Shapes:
        - x: [B, T_max]
        - x_lengths: [B]
        - g: [B, C]
    """
    g = self._set_speaker_input(aux_input)
    # the whole sequence is valid at inference time
    x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
    x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
    # encoder pass
    o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
    # duration predictor pass
    o_dr_log = self.duration_predictor(o_en, x_mask)
    o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
    y_lengths = o_dr.sum(1)
    # pitch predictor pass
    o_pitch = None
    if self.args.use_pitch:
        o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
        o_en = o_en + o_pitch_emb
    # decoder pass
    o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
    outputs = {
        "model_outputs": o_de,
        "alignments": attn,
        "pitch": o_pitch,
        "durations_log": o_dr_log,
    }
    return outputs
def test_encoder():
    """Smoke-test each encoder variant and check the expected output shape."""
    input_dummy = torch.rand(8, 14, 37).to(device)
    input_lengths = torch.randint(31, 37, (8,)).long().to(device)
    # force one full-length sample so the mask covers max_len
    input_lengths[-1] = 37
    input_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
    # relative positional transformer encoder
    layer = Encoder(
        out_channels=11,
        in_hidden_channels=14,
        encoder_type="relative_position_transformer",
        encoder_params={
            "hidden_channels_ffn": 768,
            "num_heads": 2,
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 6,
            "rel_attn_window_size": 4,
            "input_length": None,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # residual conv bn encoder
    layer = Encoder(
        out_channels=11,
        in_hidden_channels=14,
        encoder_type="residual_conv_bn",
        encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
    # FFTransformer encoder
    layer = Encoder(
        out_channels=14,
        in_hidden_channels=14,
        encoder_type="fftransformer",
        encoder_params={"hidden_channels_ffn": 31, "num_heads": 2, "num_layers": 2, "dropout_p": 0.1},
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 14, 37]
def _forward_decoder(
    self,
    o_en: torch.FloatTensor,
    dr: torch.IntTensor,
    x_mask: torch.FloatTensor,
    y_lengths: torch.IntTensor,
    g: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Decoding forward pass.

    1. Compute the decoder output mask
    2. Expand encoder output with the durations.
    3. Apply position encoding.
    4. Add speaker embeddings if multi-speaker mode.
    5. Run the decoder.

    Args:
        o_en (torch.FloatTensor): Encoder output.
        dr (torch.IntTensor): Ground truth durations or alignment network durations.
        x_mask (torch.IntTensor): Input sequence mask.
        y_lengths (torch.IntTensor): Output sequence lengths.
        g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.

    Returns:
        Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
    """
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
    # expand o_en with durations
    o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
    # positional encoding
    if hasattr(self, "pos_encoder"):
        o_en_ex = self.pos_encoder(o_en_ex, y_mask)
    # speaker embedding
    if g is not None:
        o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
    # decoder pass
    o_de = self.decoder(o_en_ex, y_mask, g=g)
    return o_de.transpose(1, 2), attn.transpose(1, 2)
def forward(self, x, x_lengths, g=None): """ Shapes: - x: :math:`[B, C, T]` - x_lengths: :math:`[B]` - g (optional): :math:`[B, 1, T]` """ # embedding layer # [B ,T, D] x = self.emb(x) * math.sqrt(self.hidden_channels) # [B, D, T] x = torch.transpose(x, 1, -1) # compute input sequence mask x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # prenet if hasattr(self, "prenet") and self.use_prenet: x = self.prenet(x, x_mask) # encoder x = self.encoder(x, x_mask) # postnet if hasattr(self, "postnet"): x = self.postnet(x) * x_mask # set duration predictor input if g is not None: g_exp = g.expand(-1, -1, x.size(-1)) x_dp = torch.cat([x.detach(), g_exp], 1) else: x_dp = x.detach() # final projection layer x_m = self.proj_m(x) * x_mask if not self.mean_only: x_logs = self.proj_s(x) * x_mask else: x_logs = torch.zeros_like(x_m) # duration predictor logw = self.duration_predictor(x_dp, x_mask) return x_m, x_logs, logw, x_mask
def forward(self, x, x_lengths, lang_emb=None):
    """Embed, optionally language-condition, encode and project the text sequence.

    Shapes:
        - x: :math:`[B, T]`
        - x_length: :math:`[B]`
    """
    emb = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
    if lang_emb is not None:
        # append the language embedding to every time step
        emb = torch.cat((emb, lang_emb.transpose(2, 1).expand(emb.size(0), emb.size(1), -1)), dim=-1)
    emb = torch.transpose(emb, 1, -1)  # [b, h, t]
    x_mask = sequence_mask(x_lengths, emb.size(2)).unsqueeze(1).to(emb.dtype)
    encoded = self.encoder(emb * x_mask, x_mask)
    stats = self.proj(encoded) * x_mask
    m, logs = torch.split(stats, self.out_channels, dim=1)
    return encoded, m, logs, x_mask
def decoder_inference(
    self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
):  # pylint: disable=dangerous-default-value
    """Run the decoder forward on a ground-truth spectrogram, then invert it.

    Shapes:
        - y: :math:`[B, T, C]`
        - y_lengths: :math:`B`
        - g: :math:`[B, C] or B`
    """
    y = y.transpose(1, 2)
    y_max_length = y.size(2)
    g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
    # norm speaker embeddings
    if g is not None:
        if self.external_d_vector_dim:
            g = F.normalize(g).unsqueeze(-1)
        else:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)
    # decoder pass: spectrogram -> latent
    z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
    # reverse decoder and predict: latent -> spectrogram
    y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
    outputs = {}
    outputs["model_outputs"] = y.transpose(1, 2)
    outputs["logdet"] = logdet
    return outputs
def _forward_encoder(self, x, x_lengths, g=None):
    """Run the text encoder and prepare the duration-predictor input."""
    if hasattr(self, "emb_g"):
        g = nn.functional.normalize(self.speaker_embedding(g))  # [B, C, 1]
    if g is not None:
        g = g.unsqueeze(-1)
    # character embeddings: [B, T, C] -> [B, C, T]
    x_emb = self.emb(x)
    x_emb = torch.transpose(x_emb, 1, -1)
    # input sequence mask
    x_mask = sequence_mask(x_lengths, x.shape[1]).unsqueeze(1).to(x.dtype)
    o_en = self.encoder(x_emb, x_mask)
    # duration predictor sees the encoder output, speaker-conditioned if available
    o_en_dp = self._concat_speaker_embedding(o_en, g) if g is not None else o_en
    return o_en, o_en_dp, x_mask, g
def _make_masks(ilens, olens):
    """Build a joint (output, input) attention mask from the two length vectors."""
    enc_mask = sequence_mask(ilens)
    dec_mask = sequence_mask(olens)
    # outer product of the two boolean masks
    return dec_mask.unsqueeze(-1) & enc_mask.unsqueeze(-2)
def forward(
    self,
    x: torch.LongTensor,
    x_lengths: torch.LongTensor,
    y_lengths: torch.LongTensor,
    y: torch.FloatTensor = None,
    dr: torch.IntTensor = None,
    pitch: torch.FloatTensor = None,
    aux_input: Dict = {"d_vectors": None, "speaker_ids": None},  # pylint: disable=unused-argument
) -> Dict:
    """Model's forward pass.

    Args:
        x (torch.LongTensor): Input character sequences.
        x_lengths (torch.LongTensor): Input sequence lengths.
        y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None.
        y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None.
        dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None.
        pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None.
        aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`.

    Shapes:
        - x: :math:`[B, T_max]`
        - x_lengths: :math:`[B]`
        - y_lengths: :math:`[B]`
        - y: :math:`[B, T_max2]`
        - dr: :math:`[B, T_max]`
        - g: :math:`[B, C]`
        - pitch: :math:`[B, 1, T]`
    """
    g = self._set_speaker_input(aux_input)
    # compute sequence masks
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()
    x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()
    # encoder pass
    o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
    # duration predictor pass
    if self.args.detach_duration_predictor:
        # stop duration-loss gradients from reaching the encoder
        o_dr_log = self.duration_predictor(o_en.detach(), x_mask)
    else:
        o_dr_log = self.duration_predictor(o_en, x_mask)
    o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
    # generate attn mask from predicted durations
    o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
    # aligner
    o_alignment_dur = None
    alignment_soft = None
    alignment_logprob = None
    alignment_mas = None
    if self.use_aligner:
        # aligner-derived durations replace the externally provided `dr`
        o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner(
            x_emb, y, x_mask, y_mask
        )
        alignment_soft = alignment_soft.transpose(1, 2)
        alignment_mas = alignment_mas.transpose(1, 2)
        dr = o_alignment_dur
    # pitch predictor pass
    o_pitch = None
    avg_pitch = None
    if self.args.use_pitch:
        o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
        o_en = o_en + o_pitch_emb
    # decoder pass
    o_de, attn = self._forward_decoder(
        o_en, dr, x_mask, y_lengths, g=None
    )  # TODO: maybe pass speaker embedding (g) too
    outputs = {
        "model_outputs": o_de,  # [B, T, C]
        "durations_log": o_dr_log.squeeze(1),  # [B, T]
        "durations": o_dr.squeeze(1),  # [B, T]
        "attn_durations": o_attn,  # for visualization [B, T_en, T_de']
        "pitch_avg": o_pitch,
        "pitch_avg_gt": avg_pitch,
        "alignments": attn,  # [B, T_de, T_en]
        "alignment_soft": alignment_soft,
        "alignment_mas": alignment_mas,
        "o_alignment_dur": o_alignment_dur,
        "alignment_logprob": alignment_logprob,
        "x_mask": x_mask,
        "y_mask": y_mask,
    }
    return outputs
def inference_with_MAS(
    self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
):  # pylint: disable=dangerous-default-value
    """
    It's similar to the teacher forcing in Tacotron.
    It was proposed in: https://arxiv.org/abs/2104.05557

    Shapes:
        - x: :math:`[B, T]`
        - x_lenghts: :math:`B`
        - y: :math:`[B, T, C]`
        - y_lengths: :math:`B`
        - g: :math:`[B, C] or B`
    """
    y = y.transpose(1, 2)
    y_max_length = y.size(2)
    # norm speaker embeddings
    g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
    if self.use_speaker_embedding or self.use_d_vector_file:
        # NOTE(review): the branch conditions look inverted relative to the usual
        # emb_g-for-speaker-id convention — confirm which input `g` carries here.
        if not self.use_d_vector_file:
            g = F.normalize(g).unsqueeze(-1)
        else:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
    # embedding pass
    o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
    # drop residual frames wrt num_squeeze and set y_lengths.
    y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
    # create masks
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    # decoder pass
    z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
    # find the alignment path between z and encoder output
    o_scale = torch.exp(-2 * o_log_scale)
    logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
    logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
    logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
    logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
    logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
    attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
    y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
    attn = attn.squeeze(1).permute(0, 2, 1)
    # get predicted aligned distribution
    z = y_mean * y_mask
    # reverse the decoder and predict using the aligned distribution
    y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
    outputs = {
        "model_outputs": z.transpose(1, 2),
        "logdet": logdet,
        "y_mean": y_mean.transpose(1, 2),
        "y_log_scale": y_log_scale.transpose(1, 2),
        "alignments": attn,
        "durations_log": o_dur_log.transpose(1, 2),
        "total_durations_log": o_attn_dur.transpose(1, 2),
    }
    return outputs
def forward(
    self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
):  # pylint: disable=dangerous-default-value
    """Training forward pass: encode text, flow-encode the target and align both
    with monotonic alignment search.

    Shapes:
        - x: :math:`[B, T]`
        - x_lenghts::math:`B`
        - y: :math:`[B, T, C]`
        - y_lengths::math:`B`
        - g: :math:`[B, C] or B`
    """
    # [B, T, C] -> [B, C, T]
    y = y.transpose(1, 2)
    y_max_length = y.size(2)
    # norm speaker embeddings
    g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
    if self.use_speaker_embedding or self.use_d_vector_file:
        # NOTE(review): the branch conditions look inverted relative to the usual
        # emb_g-for-speaker-id convention — confirm which input `g` carries here.
        if not self.use_d_vector_file:
            g = F.normalize(g).unsqueeze(-1)
        else:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
    # embedding pass
    o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
    # drop residual frames wrt num_squeeze and set y_lengths.
    y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
    # create masks
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
    # [B, 1, T_en, T_de]
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    # decoder pass
    z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
    # find the alignment path (no gradients through the search)
    with torch.no_grad():
        o_scale = torch.exp(-2 * o_log_scale)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
        logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
    y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
    attn = attn.squeeze(1).permute(0, 2, 1)
    outputs = {
        "z": z.transpose(1, 2),
        "logdet": logdet,
        "y_mean": y_mean.transpose(1, 2),
        "y_log_scale": y_log_scale.transpose(1, 2),
        "alignments": attn,
        "durations_log": o_dur_log.transpose(1, 2),
        "total_durations_log": o_attn_dur.transpose(1, 2),
    }
    return outputs