def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): """ Shapes: - x: :math:`[B, T_seq]` - d_vectors: :math:`[B, C, 1]` - speaker_ids: :math:`[B]` """ sid, g = self._set_cond_input(aux_input) x_lengths = torch.tensor(x.shape[1:2]).to(x.device) x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths) if self.num_speakers > 0 and (sid is not None): g = self.emb_g(sid).unsqueeze(-1) if self.args.use_sdp: logw = self.duration_predictor( x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp) else: logw = self.duration_predictor(x, x_mask, g=g) w = torch.exp(logw) * x_mask * self.length_scale w_ceil = torch.ceil(w) y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype) attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2)) m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2) logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2) z_p = (m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale) z = self.flow(z_p, y_mask, g=g, reverse=True) o = self.waveform_decoder((z * y_mask)[:, :, :self.max_inference_len], g=g) outputs = { "model_outputs": o, "alignments": attn.squeeze(1), "z": z, "z_p": z_p, "m_p": m_p, "logs_p": logs_p, } return outputs
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}):  # pylint: disable=dangerous-default-value
    x_lengths = aux_input["x_lengths"]
    g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None

    if g is not None:
        if self.d_vector_dim:
            g = F.normalize(g).unsqueeze(-1)
        else:
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

    # embedding pass
    o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
    # compute output durations
    w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
    w_ceil = torch.ceil(w)
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
    y_max_length = None
    # compute masks
    y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    # compute attention mask
    attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
    y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)

    z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask
    # decoder pass
    y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
    attn = attn.squeeze(1).permute(0, 2, 1)
    outputs = {
        "model_outputs": y.transpose(1, 2),
        "logdet": logdet,
        "y_mean": y_mean.transpose(1, 2),
        "y_log_scale": y_log_scale.transpose(1, 2),
        "alignments": attn,
        "durations_log": o_dur_log.transpose(1, 2),
        "total_durations_log": o_attn_dur.transpose(1, 2),
    }
    return outputs
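# A small numeric sketch (made-up values) of the duration de-scaling above:
# the encoder predicts log durations, `exp(.) - 1` maps them back to frame
# counts, and `length_scale` uniformly stretches or compresses output length.
import torch

o_dur_log = torch.tensor([[[0.0, 1.0, 2.0]]])  # [B, 1, T_enc] predicted log durations
x_mask = torch.ones_like(o_dur_log)

for length_scale in (0.8, 1.0, 1.2):  # <1.0 speaks faster, >1.0 slower
    w = (torch.exp(o_dur_log) - 1) * x_mask * length_scale
    w_ceil = torch.ceil(w)
    y_length = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
    print(length_scale, w_ceil.squeeze().tolist(), y_length.item())
# note: a token with w_ceil == 0 is assigned no decoder frames at all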
def generate_attn(dr, x_mask, y_mask=None):
    """Generate an attention mask from the durations.

    Shapes:
        - dr: :math:`(B, T_{en})`
        - x_mask: :math:`(B, T_{en})`
        - y_mask: :math:`(B, T_{de})`
    """
    # compute decode mask from the durations
    if y_mask is None:
        y_lengths = dr.sum(1).long()
        y_lengths[y_lengths < 1] = 1
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
    attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
    return attn
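# `generate_path` itself is not part of this excerpt. Below is a minimal
# cumsum-based sketch that matches the shapes used above ([B, T_en] durations,
# [B, T_en, T_de] mask); `sequence_mask` is re-implemented locally so the
# snippet stands alone. Treat it as illustrative, not the library's code.
import torch
import torch.nn.functional as F


def sequence_mask(lengths, max_len=None):
    # [N, max_len] boolean mask, True for positions < lengths[n]
    if max_len is None:
        max_len = int(lengths.max())
    return torch.arange(max_len, device=lengths.device)[None, :] < lengths[:, None]


def generate_path(durations, mask):
    b, t_en, t_de = mask.shape
    # cum_dur[b, t] = index of the first frame *after* token t's span
    cum_dur = torch.cumsum(durations, dim=1)
    path = sequence_mask(cum_dur.view(b * t_en), t_de).to(mask.dtype).view(b, t_en, t_de)
    # subtract the previous token's mask so each row keeps only its own frames
    path = path - F.pad(path, (0, 0, 1, 0))[:, :-1]
    return path * mask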
import torch as T


def generate_path_test():
    durations = T.randint(1, 4, (10, 21))
    x_length = T.randint(18, 22, (10,))
    x_mask = sequence_mask(x_length).unsqueeze(1).long()
    durations = durations * x_mask.squeeze(1)
    y_length = durations.sum(1)
    y_mask = sequence_mask(y_length).unsqueeze(1).long()
    attn_mask = (T.unsqueeze(x_mask, -1) * T.unsqueeze(y_mask, 2)).squeeze(1).long()
    print(attn_mask.shape)
    path = generate_path(durations, attn_mask)
    assert path.shape == (10, 21, durations.sum(1).max().item())
    # each token's frames must form a single contiguous run of ones
    for b in range(durations.shape[0]):
        current_idx = 0
        for t in range(durations.shape[1]):
            assert all(path[b, t, current_idx : current_idx + durations[b, t].item()] == 1.0)
            assert all(path[b, t, :current_idx] == 0.0)
            assert all(path[b, t, current_idx + durations[b, t].item() :] == 0.0)
            current_idx += durations[b, t].item()
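# A vectorized complement to the loop-based test above (reusing the
# `sequence_mask`/`generate_path` sketched earlier; illustrative only):
# row sums of the path recover the durations, and every valid decoder frame
# is claimed by exactly one token. `max_len` is pinned to 21 so the mask
# width always matches the duration matrix.
import torch as T

durations = T.randint(1, 4, (10, 21))
x_length = T.randint(18, 22, (10,))
x_mask = sequence_mask(x_length, 21).unsqueeze(1).long()
durations = durations * x_mask.squeeze(1)
y_mask = sequence_mask(durations.sum(1)).unsqueeze(1).long()
attn_mask = (T.unsqueeze(x_mask, -1) * T.unsqueeze(y_mask, 2)).squeeze(1).long()
path = generate_path(durations, attn_mask)

assert T.equal(path.sum(2), durations)          # one frame per duration unit
assert T.equal(path.sum(1), y_mask.squeeze(1))  # one owner token per valid frame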