import math

import torch
from torch.nn import functional as F

# NOTE: `sequence_mask` and `maximum_path` are repository helpers; their exact
# import paths depend on the repo layout, so they are not imported here.


def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask):
    # find the max alignment path between the encoder distribution and y
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    log_p = self.compute_log_probs(mu, log_sigma, y)  # [B, T_en, T_dec]
    attn = maximum_path(log_p, attn_mask.squeeze(1)).unsqueeze(1)
    # per-token durations = number of decoder frames assigned by the hard path
    dr_mas = torch.sum(attn, -1)
    return dr_mas.squeeze(1), log_p
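# A minimal pure-PyTorch sketch of the dynamic program that `maximum_path`
# implements (monotonic alignment search from the Glow-TTS paper); the library
# version is vectorized/compiled but equivalent. `maximum_path_ref` is a
# hypothetical helper for a single unmasked example, written only to make the
# recursion explicit: Q[i, j] = logp[i, j] + max(Q[i, j - 1], Q[i - 1, j - 1]).
def maximum_path_ref(logp):
    # logp: [T_en, T_dec] log-likelihood of decoder frame j under token i
    t_en, t_dec = logp.shape
    neg_inf = float("-inf")
    Q = torch.full((t_en, t_dec), neg_inf)
    for j in range(t_dec):
        for i in range(min(j + 1, t_en)):  # frame j can reach at most token j
            if j == 0:
                Q[i, j] = logp[i, j]
            else:
                stay = Q[i, j - 1]
                move = Q[i - 1, j - 1] if i > 0 else neg_inf
                Q[i, j] = logp[i, j] + max(stay, move)
    # backtrack the best monotonic path from the last token / last frame
    path = torch.zeros(t_en, t_dec)
    i = t_en - 1
    for j in range(t_dec - 1, -1, -1):
        path[i, j] = 1.0
        if j > 0 and i > 0 and Q[i - 1, j - 1] > Q[i, j - 1]:
            i -= 1
    return path  # hard alignment; path.sum(-1) gives per-token durations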
def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
    """
    Shapes:
        x: [B, T]
        x_lengths: B
        y: [B, C, T]
        y_lengths: B
        g: [B, C] or B
    """
    y_max_length = y.size(2)
    # normalize speaker embeddings
    if g is not None:
        if self.speaker_embedding_dim:
            g = F.normalize(g).unsqueeze(-1)
        else:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
    # embedding pass
    o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
    # drop residual frames wrt num_squeeze and set y_lengths.
    y, y_lengths, y_max_length, attn = self.preprocess(
        y, y_lengths, y_max_length, None)
    # create masks
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                             1).to(x_mask.dtype)
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    # decoder pass
    z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
    # find the alignment path
    with torch.no_grad():
        o_scale = torch.exp(-2 * o_log_scale)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
                          [1]).unsqueeze(-1)  # [b, t, 1]
        logp2 = torch.matmul(o_scale.transpose(1, 2),
                             -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
                             z)  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
                          [1]).unsqueeze(-1)  # [b, t, 1]
        logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
    y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
        attn, o_mean, o_log_scale, x_mask)
    attn = attn.squeeze(1).permute(0, 2, 1)
    return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
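# A quick numeric check of the four-term decomposition used in `forward`,
# with illustrative random shapes (an assumption, not the model's real sizes).
# It verifies that sum_d log N(z_j; mu_i, sigma_i) splits into z-independent
# terms (logp1 + logp4) plus two matmuls (logp2, logp3), which builds the
# [b, t, t'] likelihood map without materializing a 4-D broadcast.
b, d, t_en, t_dec = 2, 8, 5, 11
o_mean = torch.randn(b, d, t_en)
o_log_scale = torch.randn(b, d, t_en)
z = torch.randn(b, d, t_dec)
o_scale = torch.exp(-2 * o_log_scale)  # 1 / sigma^2
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)
logp = logp1 + logp2 + logp3 + logp4  # [b, t_en, t_dec]
# direct evaluation of the diagonal Gaussian log-density for comparison
direct = (-0.5 * math.log(2 * math.pi) - o_log_scale.unsqueeze(-1)
          - 0.5 * (z.unsqueeze(2) - o_mean.unsqueeze(-1)) ** 2
          * o_scale.unsqueeze(-1)).sum(1)
assert torch.allclose(logp, direct, atol=1e-5)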
def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None,
                       g=None):
    """Inference that extracts MAS alignments from a ground-truth
    spectrogram, similar to teacher forcing in Tacotron.
    Proposed in: https://arxiv.org/abs/2104.05557

    Shapes:
        x: [B, T]
        x_lengths: B
        y: [B, C, T]
        y_lengths: B
        g: [B, C] or B
    """
    y_max_length = y.size(2)
    # normalize speaker embeddings
    if g is not None:
        if self.external_speaker_embedding_dim:
            g = F.normalize(g).unsqueeze(-1)
        else:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]
    # embedding pass
    o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
    # drop residual frames wrt num_squeeze and set y_lengths.
    y, y_lengths, y_max_length, attn = self.preprocess(
        y, y_lengths, y_max_length, None)
    # create masks
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
                             1).to(x_mask.dtype)
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    # decoder pass
    z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
    # find the alignment path between z and the encoder output
    o_scale = torch.exp(-2 * o_log_scale)
    logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
                      [1]).unsqueeze(-1)  # [b, t, 1]
    logp2 = torch.matmul(o_scale.transpose(1, 2),
                         -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
    logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
                         z)  # [b, t, d] x [b, d, t'] = [b, t, t']
    logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
                      [1]).unsqueeze(-1)  # [b, t, 1]
    logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
    attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
    y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
        attn, o_mean, o_log_scale, x_mask)
    attn = attn.squeeze(1).permute(0, 2, 1)
    # get the predicted aligned distribution
    z = y_mean * y_mask
    # run the decoder in reverse to synthesize from the aligned distribution
    y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
    return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
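# Sketch of the duration-based expansion that `compute_outputs` (not shown in
# this excerpt) presumably performs with the hard MAS alignment: each decoder
# frame copies the statistics of the one encoder token it is assigned to, and
# the resulting `y_mean` is what the inverted decoder synthesizes from above.
# Shapes and values are illustrative only.
o_mean = torch.tensor([[[1.0, 2.0, 3.0]]])  # [B=1, D=1, T_en=3]
attn = torch.tensor([[[1.0, 1.0, 0.0, 0.0, 0.0],
                      [0.0, 0.0, 1.0, 0.0, 0.0],
                      [0.0, 0.0, 0.0, 1.0, 1.0]]])  # [1, T_en=3, T_dec=5]
y_mean = torch.matmul(o_mean, attn)  # [1, 1, 5] -> [[1., 1., 2., 3., 3.]]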