def compute_masks(self, text_lengths, mel_lengths):
    """Build padding masks for the text inputs and, optionally, the mel outputs.

    Args:
        text_lengths: per-sample input lengths; both masks are moved to the
            device this tensor lives on.
        mel_lengths: per-sample output lengths, or ``None`` to skip the
            output mask.

    Returns:
        Tuple ``(input_mask, output_mask)``. ``output_mask`` is ``None`` when
        ``mel_lengths`` is ``None``.
    """
    target_device = text_lengths.device
    # B x T_in_max
    input_mask = sequence_mask(text_lengths).to(target_device)
    if mel_lengths is None:
        return input_mask, None

    # Round the longest output length up to the next multiple of the
    # decoder's ``r`` so the mask width is divisible by it.
    step = self.decoder.r
    longest = mel_lengths.max()
    if longest % step > 0:
        longest = longest + (step - (longest % step))
    output_mask = sequence_mask(mel_lengths, max_len=longest).to(target_device)
    return input_mask, output_mask
def _forward_encoder(self, x, x_lengths, g=None):
    """Embed the inputs, run the encoder and prepare the duration-predictor input.

    Args:
        x: input passed to ``self.emb`` (presumably token ids of shape [B, T]
            -- confirm against the caller).
        x_lengths: per-sample input lengths used to build the mask.
        g: optional speaker input. When the model has an ``emb_g`` layer it is
            looked up through that layer and normalized first.

    Returns:
        Tuple ``(o_en, o_en_dp, x_mask, g)`` where ``o_en_dp`` is ``o_en``
        with the speaker embedding concatenated when ``g`` is available,
        otherwise ``o_en`` itself.
    """
    if hasattr(self, "emb_g"):
        g = nn.functional.normalize(self.emb_g(g))
    has_speaker = g is not None
    if has_speaker:
        # add a trailing singleton axis
        g = g.unsqueeze(-1)

    # embed, then swap axis 1 with the last axis
    x_emb = self.emb(x).transpose(1, -1)

    # padding mask with a singleton axis at position 1, cast to x's dtype
    x_mask = sequence_mask(x_lengths, x.shape[1]).unsqueeze(1).to(x.dtype)

    o_en = self.encoder(x_emb, x_mask)

    # speaker conditioning for the duration predictor
    o_en_dp = self._concat_speaker_embedding(o_en, g) if has_speaker else o_en
    return o_en, o_en_dp, x_mask, g
def forward(self, x, target, length):
    """Length-masked binary cross-entropy on logits.

    Args:
        x: logits of size (batch, max_len).
        target: targets of size (batch, max_len); cast to ``x``'s dtype
            before the loss is evaluated.
        length: LongTensor of size (batch,) with the valid length of each
            sequence.

    Returns:
        loss: the summed BCE over the valid (unpadded) steps divided by the
            number of valid steps.
    """
    target.requires_grad = False
    # mask: (batch, max_len), truthy on valid steps
    mask = sequence_mask(sequence_length=length, max_len=target.size(1)).bool()
    # Evaluate the loss element-wise and drop padded steps afterwards.
    # Zeroing logits and targets instead (x * mask, target * mask) does not
    # remove padding from the loss: BCE-with-logits at (logit=0, target=0)
    # is log(2), so every padded step would add a constant to the sum.
    per_step = functional.binary_cross_entropy_with_logits(
        x, target.to(x.dtype), pos_weight=self.pos_weight, reduction='none')
    loss = per_step.masked_select(mask).sum()
    return loss / mask.sum()
def forward(self, x, target, length):
    """Length-masked mean squared error.

    Args:
        x: predictions of size (batch, max_len, dim).
        target: targets of size (batch, max_len, dim).
        length: LongTensor of size (batch,) with the valid length of each
            sequence.

    Returns:
        loss: with ``self.seq_len_norm`` set, the squared errors are weighted
            per sequence by ``1 / valid_length`` and additionally divided by
            ``batch * dim``; otherwise the summed squared error is divided by
            the number of unmasked elements.
    """
    target.requires_grad = False
    # (batch, max_len, 1)
    step_mask = sequence_mask(
        sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()

    if not self.seq_len_norm:
        elem_mask = step_mask.expand_as(x)
        total = functional.mse_loss(
            x * elem_mask, target * elem_mask, reduction='sum')
        return total / elem_mask.sum()

    # per-sequence normalization: each valid step weighs 1 / sequence length
    step_weights = step_mask / step_mask.sum(dim=1, keepdim=True)
    elem_weights = step_weights.div(target.shape[0] * target.shape[2])
    elem_mask = step_mask.expand_as(x)
    sq_err = functional.mse_loss(
        x * elem_mask, target * elem_mask, reduction='none')
    return sq_err.mul(elem_weights.to(sq_err.device)).sum()
def test_encoder():
    """Check the output shape of ``Encoder`` for two encoder types.

    Covers ``residual_conv_bn`` with its default parameters and
    ``transformer`` with explicit parameters.
    """
    batch, in_channels, steps, out_channels = 8, 14, 37, 11
    features = torch.rand(batch, in_channels, steps).to(device)
    lengths = torch.randint(31, 37, (batch, )).long().to(device)
    # make sure one sample spans the full length
    lengths[-1] = steps
    mask = sequence_mask(lengths, features.size(2)).unsqueeze(1).to(device)
    expected_shape = [batch, out_channels, steps]

    # residual bn conv encoder
    conv_encoder = Encoder(out_channels=out_channels,
                           in_hidden_channels=in_channels,
                           encoder_type='residual_conv_bn').to(device)
    result = conv_encoder(features, mask)
    assert list(result.shape) == expected_shape

    # transformer encoder
    transformer_params = {
        'hidden_channels_ffn': 768,
        'num_heads': 2,
        "kernel_size": 3,
        "dropout_p": 0.1,
        "num_layers": 6,
        "rel_attn_window_size": 4,
        "input_length": None,
    }
    transformer_encoder = Encoder(out_channels=out_channels,
                                  in_hidden_channels=in_channels,
                                  encoder_type='transformer',
                                  encoder_params=transformer_params).to(device)
    result = transformer_encoder(features, mask)
    assert list(result.shape) == expected_shape
def decoder_inference(self, y, y_lengths=None, g=None):
    """Run the decoder forward on ``y`` and then in reverse on the result.

    Shapes:
        y: [B, C, T]
        y_lengths: B
        g: [B, C] or B

    Returns:
        Tuple ``(y, logdet)`` as returned by the reverse decoder pass.
    """
    max_frames = y.size(2)

    # normalize the speaker embedding; look it up first unless an external
    # embedding is configured
    if g is not None:
        speaker = g if self.external_speaker_embedding_dim else self.emb_g(g)
        g = F.normalize(speaker).unsqueeze(-1)  # [b, h, 1]

    y_mask = sequence_mask(y_lengths, max_frames).unsqueeze(1).to(y.dtype)

    # forward decoder pass, then reverse it to predict
    z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
    y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
    return y, logdet
def inference(self, x, x_lengths, g=None):
    """Synthesize outputs from ``x`` using the predicted durations.

    Args:
        x: encoder input.
        x_lengths: per-sample input lengths.
        g: optional speaker input; normalized directly when
            ``self.speaker_embedding_dim`` is set, otherwise looked up through
            ``self.emb_g`` first.

    Returns:
        Tuple ``(y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur)``.
    """
    if g is not None:
        speaker = g if self.speaker_embedding_dim else self.emb_g(g)
        g = F.normalize(speaker).unsqueeze(-1)

    # encoder pass
    o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)

    # durations from the log-duration prediction, rounded up to whole steps
    w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
    w_ceil = torch.ceil(w)
    # total output length per sample, at least 1
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()

    # masks; no explicit maximum length is passed for the output mask
    y_mask = sequence_mask(y_lengths, None).unsqueeze(1).to(x_mask.dtype)
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)

    # alignment path built from the rounded durations
    attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
    y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
        attn, o_mean, o_log_scale, x_mask)

    # sample the latent around the aligned mean
    noise = torch.randn_like(y_mean)
    z = (y_mean + torch.exp(y_log_scale) * noise *
         self.inference_noise_scale) * y_mask

    # reverse decoder pass
    y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
    attn = attn.squeeze(1).permute(0, 2, 1)
    return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
def forward(self, x, x_lengths, g=None):
    """Encode ``x`` and predict the mean, log-scale and log-durations.

    Args:
        x: input passed to the embedding layer.
        x_lengths: per-sample input lengths.
        g: optional conditioning tensor; it is expanded along the last axis
            to the sequence length and concatenated on axis 1 to the
            duration-predictor input.

    Returns:
        Tuple ``(x_m, x_logs, logw, x_mask)``. ``x_logs`` is all zeros when
        ``self.mean_only`` is set.
    """
    # embed and scale: [B, T, D] -> [B, D, T]
    x = (self.emb(x) * math.sqrt(self.hidden_channels)).transpose(1, -1)

    # input sequence mask
    x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)

    # optional pre-conv layers, only for these encoder types
    if self.encoder_type in ['transformer', 'time-depth-separable'] and self.use_prenet:
        x = self.pre(x, x_mask)

    x = self.encoder(x, x_mask)

    # the duration predictor sees a detached copy of the encoder output
    x_dp = x.detach()
    if g is not None:
        g_exp = g.expand(-1, -1, x.size(-1))
        x_dp = torch.cat([x_dp, g_exp], 1)

    # final projections
    x_m = self.proj_m(x) * x_mask
    x_logs = torch.zeros_like(x_m) if self.mean_only else self.proj_s(x) * x_mask

    logw = self.duration_predictor(x_dp, x_mask)
    return x_m, x_logs, logw, x_mask
def _forward_mdn(self, o_en, y, y_lengths, x_mask):
    """Run the MDN block on the encoder output and compute the alignment path.

    Returns:
        Tuple ``(dr_mas, mu, log_sigma, logp)`` where ``mu`` and ``log_sigma``
        come from ``self.mdn_block`` and ``dr_mas`` and ``logp`` from
        ``self.compute_align_path``.
    """
    mu, log_sigma = self.mdn_block(o_en)
    # output mask, with the maximum length left to sequence_mask
    y_mask = sequence_mask(y_lengths, None).unsqueeze(1).to(o_en.dtype)
    dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask)
    return dr_mas, mu, log_sigma, logp
def test_decoder():
    """Check the output shape of ``Decoder`` for every decoder type exercised here.

    Each variant is built with ``out_channels=11`` and fed the same
    ``[8, 128, 37]`` input, so each is expected to return ``[8, 11, 37]``.
    """
    input_dummy = torch.rand(8, 128, 37).to(device)
    input_lengths = torch.randint(31, 37, (8,)).long().to(device)
    # make sure one sample spans the full length
    input_lengths[-1] = 37
    input_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)

    # residual bn conv decoder
    layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]

    # transformer decoder
    layer = Decoder(
        out_channels=11,
        in_hidden_channels=128,
        decoder_type="relative_position_transformer",
        decoder_params={
            "hidden_channels_ffn": 128,
            "num_heads": 2,
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 8,
            "rel_attn_window_size": 4,
            "input_length": None,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]

    # wavenet decoder
    layer = Decoder(
        out_channels=11,
        in_hidden_channels=128,
        decoder_type="wavenet",
        decoder_params={
            "num_blocks": 12,
            "hidden_channels": 192,
            "kernel_size": 5,
            "dilation_rate": 1,
            "num_layers": 4,
            "dropout_p": 0.05,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    # previously this result was overwritten without ever being checked
    assert list(output.shape) == [8, 11, 37]

    # FFTransformer decoder
    layer = Decoder(
        out_channels=11,
        in_hidden_channels=128,
        decoder_type="fftransformer",
        decoder_params={
            "hidden_channels_ffn": 31,
            "num_heads": 2,
            "dropout_p": 0.1,
            "num_layers": 2,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == [8, 11, 37]
def forward(self, x, y, length=None):
    """Smooth L1 loss averaged over the valid (unpadded) steps.

    Shapes:
        x: B x T
        y: B x T
        length: B

    Args:
        x: predictions.
        y: targets.
        length: valid length of each sequence. When ``None`` (the default)
            no masking is applied and the loss is the plain mean over all
            elements.

    Returns:
        Scalar loss tensor.
    """
    if length is None:
        # Without lengths every step is valid; the masked formula below
        # reduces to the ordinary mean in that case.
        return torch.nn.functional.smooth_l1_loss(x, y, reduction='mean')
    mask = sequence_mask(sequence_length=length, max_len=y.size(1)).float()
    # Zeroing both sides makes padded steps contribute exactly 0 to the sum.
    return torch.nn.functional.smooth_l1_loss(
        x * mask, y * mask, reduction='sum') / mask.sum()
def test_duration_predictor():
    """Check that ``DurationPredictor`` maps [8, 128, 27] inputs to [8, 1, 27]."""
    batch, channels, steps = 8, 128, 27
    features = torch.rand(batch, channels, steps).to(device)
    lengths = torch.randint(20, 27, (batch, )).long().to(device)
    # make sure one sample spans the full length
    lengths[-1] = steps

    mask = sequence_mask(lengths, features.size(2)).unsqueeze(1).to(device)

    predictor = DurationPredictor(hidden_channels=channels).to(device)
    durations = predictor(features, mask)
    assert list(durations.shape) == [batch, 1, steps]
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
    """Expand the encoder output by the durations and run the decoder.

    Args:
        o_en: encoder output to expand.
        o_en_dp: only its dtype is used here, for the output mask.
        dr: durations passed to ``self.expand_encoder_outputs``.
        x_mask: input mask.
        y_lengths: per-sample output lengths.
        g: optional speaker embedding; summed onto the expanded encoder
            output and also passed to the decoder.

    Returns:
        Tuple ``(o_de, attn)`` with ``attn`` transposed on axes 1 and 2.
    """
    y_mask = sequence_mask(y_lengths, None).unsqueeze(1).to(o_en_dp.dtype)

    # expand o_en with durations
    expanded, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)

    # positional encoding, when the model has one
    if hasattr(self, "pos_encoder"):
        expanded = self.pos_encoder(expanded, y_mask)

    # speaker embedding
    if g is not None:
        expanded = self._sum_speaker_embedding(expanded, g)

    o_de = self.decoder(expanded, y_mask, g=g)
    return o_de, attn.transpose(1, 2)
def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
    """Training pass: encode ``x``, decode ``y`` to ``z`` and align the two.

    The incoming ``attn`` value is not read; it is replaced by the result of
    ``self.preprocess`` and then by the computed alignment path.

    Shapes:
        x: [B, T]
        x_lenghts: B
        y: [B, C, T]
        y_lengths: B
        g: [B, C] or B

    Returns:
        Tuple ``(z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur)``.
    """
    y_max_length = y.size(2)

    # normalize the speaker embedding; look it up first unless an external
    # embedding dimension is configured
    if g is not None:
        speaker = g if self.speaker_embedding_dim else self.emb_g(g)
        g = F.normalize(speaker).unsqueeze(-1)  # [b, h, 1]

    # encoder pass
    o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)

    # drop residual frames wrt num_squeeze and update y_lengths
    y, y_lengths, y_max_length, attn = self.preprocess(
        y, y_lengths, y_max_length, None)

    # masks
    y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)

    # forward decoder pass
    z, logdet = self.decoder(y, y_mask, g=g, reverse=False)

    # alignment search, computed without gradients
    with torch.no_grad():
        inv_sq_scale = torch.exp(-2 * o_log_scale)
        # [b, t, 1]
        const_term = torch.sum(
            -0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
        # [b, t, d] x [b, d, t'] = [b, t, t']
        z_sq_term = torch.matmul(inv_sq_scale.transpose(1, 2), -0.5 * (z**2))
        # [b, t, d] x [b, d, t'] = [b, t, t']
        cross_term = torch.matmul((o_mean * inv_sq_scale).transpose(1, 2), z)
        # [b, t, 1]
        mean_sq_term = torch.sum(
            -0.5 * (o_mean**2) * inv_sq_scale, [1]).unsqueeze(-1)
        # [b, t, t']
        logp = const_term + z_sq_term + cross_term + mean_sq_term
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

    y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
        attn, o_mean, o_log_scale, x_mask)
    attn = attn.squeeze(1).permute(0, 2, 1)
    return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
def generate_path(duration, mask):
    """Turn per-step durations into an alignment path.

    Shapes:
        duration: [b, t_x]
        mask: [b, t_x, t_y]

    Returns:
        Tensor of shape [b, t_x, t_y] with ``mask``'s dtype, multiplied by
        ``mask``.
    """
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)

    # one mask row per (batch, input step), built from the cumulative duration
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # subtract the previous input step's row (shifted down by one along t_x)
    # so each row keeps only what its own duration added
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    return path * mask
def test_encoder():
    """Check the output shape of ``Encoder`` for three encoder types.

    Covers ``relative_position_transformer``, ``residual_conv_bn`` and
    ``fftransformer``; the last one is built with 14 output channels.
    """
    batch, in_channels, steps = 8, 14, 37
    features = torch.rand(batch, in_channels, steps).to(device)
    lengths = torch.randint(31, 37, (batch, )).long().to(device)
    # make sure one sample spans the full length
    lengths[-1] = steps
    mask = sequence_mask(lengths, features.size(2)).unsqueeze(1).to(device)

    # relative positional transformer encoder
    rel_pos_params = {
        'hidden_channels_ffn': 768,
        'num_heads': 2,
        "kernel_size": 3,
        "dropout_p": 0.1,
        "num_layers": 6,
        "rel_attn_window_size": 4,
        "input_length": None,
    }
    encoder = Encoder(out_channels=11,
                      in_hidden_channels=in_channels,
                      encoder_type='relative_position_transformer',
                      encoder_params=rel_pos_params).to(device)
    result = encoder(features, mask)
    assert list(result.shape) == [8, 11, 37]

    # residual conv bn encoder
    conv_params = {
        "kernel_size": 4,
        "dilations": 4 * [1, 2, 4] + [1],
        "num_conv_blocks": 2,
        "num_res_blocks": 13,
    }
    encoder = Encoder(out_channels=11,
                      in_hidden_channels=in_channels,
                      encoder_type='residual_conv_bn',
                      encoder_params=conv_params).to(device)
    result = encoder(features, mask)
    assert list(result.shape) == [8, 11, 37]

    # FFTransformer encoder
    fft_params = {
        "hidden_channels_ffn": 31,
        "num_heads": 2,
        "num_layers": 2,
        "dropout_p": 0.1,
    }
    encoder = Encoder(out_channels=14,
                      in_hidden_channels=in_channels,
                      encoder_type='fftransformer',
                      encoder_params=fft_params).to(device)
    result = encoder(features, mask)
    assert list(result.shape) == [8, 14, 37]
def forward(self, y_hat, y, length=None):
    """Return ``1 - self.loss_func(y_hat, y)``, optionally masked by length.

    Args:
        y_hat (tensor): model prediction values.
        y (tensor): target values.
        length (tensor): length of each sample in a batch. When given, both
            tensors are multiplied by the sequence mask before scoring.

    Shapes:
        y_hat: B x T X D
        y: B x T x D
        length: B

    Returns:
        loss: one minus the value returned by ``self.loss_func``.
    """
    prediction, reference = y_hat, y
    if length is not None:
        frame_mask = sequence_mask(sequence_length=length, max_len=y.size(1))
        frame_mask = frame_mask.unsqueeze(2).float().to(y_hat.device)
        prediction = y_hat * frame_mask
        reference = y * frame_mask
    # loss_func receives both tensors with a singleton axis inserted at 1
    score = self.loss_func(prediction.unsqueeze(1), reference.unsqueeze(1))
    return 1 - score
def forward(self, x, x_lengths, g=None):
    """Encode ``x`` and predict the mean, log-scale and log-durations.

    Shapes:
        x: input to the embedding layer (presumably token ids [B, T] --
            TODO confirm; the embedded tensor is [B, T, D] before the
            transpose)
        x_lengths: [B]
        g (optional): expanded along its last axis to T and concatenated on
            axis 1 to the duration-predictor input

    Returns:
        Tuple ``(x_m, x_logs, logw, x_mask)``. ``x_logs`` is all zeros when
        ``self.mean_only`` is set.
    """
    # embed and scale: [B, T, D] -> [B, D, T]
    x = (self.emb(x) * math.sqrt(self.hidden_channels)).transpose(1, -1)

    # input sequence mask
    x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)

    # prenet, only when the module exists and is enabled
    if hasattr(self, 'prenet') and self.use_prenet:
        x = self.prenet(x, x_mask)

    x = self.encoder(x, x_mask)

    # postnet, only when the module exists
    if hasattr(self, 'postnet'):
        x = self.postnet(x) * x_mask

    # the duration predictor sees a detached copy of the encoder output
    x_dp = x.detach()
    if g is not None:
        g_exp = g.expand(-1, -1, x.size(-1))
        x_dp = torch.cat([x_dp, g_exp], 1)

    # final projections
    x_m = self.proj_m(x) * x_mask
    x_logs = torch.zeros_like(x_m) if self.mean_only else self.proj_s(x) * x_mask

    logw = self.duration_predictor(x_dp, x_mask)
    return x_m, x_logs, logw, x_mask
def test_in_out(self):  # pylint: disable=no-self-use
    """Exercise ``L1LossMasked`` with and without sequence-length normalization.

    For each setting: identical tensors give 0, ones-vs-zeros gives 1, and
    large offsets added on padded steps must not change either result.
    """
    shape = (4, 8, 128)

    def full_lengths():
        # every sequence uses all 8 steps
        return (T.ones(4) * 8).long()

    def ragged_lengths():
        # lengths 5, 6, 7, 8
        return (T.arange(5, 9)).long()

    def padding_offset(lengths):
        # 0 where the mask is 1, -100 where it is 0; broadcast over features
        return ((sequence_mask(lengths).float() - 1.0) * 100.0).unsqueeze(2)

    # ---- seq_len_norm = False ----
    criterion = L1LossMasked(seq_len_norm=False)

    # input == target
    result = criterion(T.ones(*shape).float(), T.ones(*shape).float(),
                       full_lengths())
    assert result.item() == 0.0

    # input != target
    result = criterion(T.ones(*shape).float(), T.zeros(*shape).float(),
                       full_lengths())
    assert result.item() == 1.0, "1.0 vs {}".format(result.item())

    # padded values of the input must not make any difference
    lengths = ragged_lengths()
    prediction = T.ones(*shape).float()
    reference = T.zeros(*shape).float()
    result = criterion(prediction + padding_offset(lengths), reference, lengths)
    assert result.item() == 1.0, "1.0 vs {}".format(result.item())

    prediction = T.rand(*shape).float()
    reference = prediction.detach()
    lengths = ragged_lengths()
    result = criterion(prediction + padding_offset(lengths), reference, lengths)
    assert result.item() == 0, "0 vs {}".format(result.item())

    # ---- seq_len_norm = True ----
    criterion = L1LossMasked(seq_len_norm=True)

    # input == target
    result = criterion(T.ones(*shape).float(), T.ones(*shape).float(),
                       full_lengths())
    assert result.item() == 0.0

    # input != target
    result = criterion(T.ones(*shape).float(), T.zeros(*shape).float(),
                       full_lengths())
    assert result.item() == 1.0, "1.0 vs {}".format(result.item())

    # padded values of the input must not make any difference
    lengths = ragged_lengths()
    prediction = T.ones(*shape).float()
    reference = T.zeros(*shape).float()
    result = criterion(prediction + padding_offset(lengths), reference, lengths)
    # the normalized variant is compared with a tolerance
    assert abs(result.item() - 1.0) < 1e-5, "1.0 vs {}".format(
        result.item())

    prediction = T.rand(*shape).float()
    reference = prediction.detach()
    lengths = ragged_lengths()
    result = criterion(prediction + padding_offset(lengths), reference, lengths)
    assert result.item() == 0, "0 vs {}".format(result.item())
def forward(self, x, x_lengths, y, y_lengths, phase=None, g=None):  # pylint: disable=unused-argument
    """Phase-dependent training pass.

    Shapes:
        x: [B, T_max]
        x_lengths: [B]
        y_lengths: [B]
        dr: [B, T_max]
        g: [B, C]

    Args:
        phase: selects which parts of the model run.
            0: encoder and MDN only; the alignment is derived from the MAS
               durations and no decoder output is produced.
            1: decoder on detached encoder outputs and detached durations.
            2: encoder, MDN and decoder, without the duration predictor.
            3 or any other value (including ``None``): as phase 2, plus the
               duration predictor on the detached ``o_en_dp``.

    Returns:
        Tuple ``(o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp)``;
        entries a phase does not compute are ``None``.
    """
    o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
    # every phase starts with the same encoder pass
    o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
    if phase == 0:  # train encoder and MDN
        dr_mas, mu, log_sigma, logp = self._forward_mdn(
            o_en, y, y_lengths, x_mask)
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
                                 1).to(o_en_dp.dtype)
        attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
    elif phase == 1:  # train decoder
        dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
        o_de, attn = self._forward_decoder(o_en.detach(),
                                           o_en_dp.detach(),
                                           dr_mas.detach(),
                                           x_mask,
                                           y_lengths,
                                           g=g)
    elif phase == 2:  # train the whole except duration predictor
        dr_mas, mu, log_sigma, logp = self._forward_mdn(
            o_en, y, y_lengths, x_mask)
        o_de, attn = self._forward_decoder(o_en,
                                           o_en_dp,
                                           dr_mas,
                                           x_mask,
                                           y_lengths,
                                           g=g)
    else:  # phase 3 (train duration predictor) and the default
        # The duration predictor consumes the detached encoder features.
        # Phase 3 used to pass the raw input ``x`` here, which is not the
        # tensor the default branch feeds it.
        o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
        dr_mas, mu, log_sigma, logp = self._forward_mdn(
            o_en, y, y_lengths, x_mask)
        o_de, attn = self._forward_decoder(o_en,
                                           o_en_dp,
                                           dr_mas,
                                           x_mask,
                                           y_lengths,
                                           g=g)
        o_dr_log = o_dr_log.squeeze(1)
    dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
    return o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp
def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
    """Inference that aligns the encoder output to the ground-truth ``y``.

    It's similar to the teacher forcing in Tacotron.
    It was proposed in: https://arxiv.org/abs/2104.05557

    The incoming ``attn`` value is not read; it is replaced by the result of
    ``self.preprocess`` and then by the computed alignment path.

    Shapes:
        x: [B, T]
        x_lenghts: B
        y: [B, C, T]
        y_lengths: B
        g: [B, C] or B

    Returns:
        Tuple ``(y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur)``.
    """
    y_max_length = y.size(2)

    # normalize the speaker embedding; look it up first unless an external
    # embedding is configured
    if g is not None:
        speaker = g if self.external_speaker_embedding_dim else self.emb_g(g)
        g = F.normalize(speaker).unsqueeze(-1)  # [b, h, 1]

    # encoder pass
    o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)

    # drop residual frames wrt num_squeeze and update y_lengths
    y, y_lengths, y_max_length, attn = self.preprocess(
        y, y_lengths, y_max_length, None)

    # masks
    y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
    attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)

    # forward decoder pass
    z, logdet = self.decoder(y, y_mask, g=g, reverse=False)

    # alignment path between z and the encoder output
    inv_sq_scale = torch.exp(-2 * o_log_scale)
    # [b, t, 1]
    const_term = torch.sum(
        -0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
    # [b, t, d] x [b, d, t'] = [b, t, t']
    z_sq_term = torch.matmul(inv_sq_scale.transpose(1, 2), -0.5 * (z**2))
    # [b, t, d] x [b, d, t'] = [b, t, t']
    cross_term = torch.matmul((o_mean * inv_sq_scale).transpose(1, 2), z)
    # [b, t, 1]
    mean_sq_term = torch.sum(
        -0.5 * (o_mean**2) * inv_sq_scale, [1]).unsqueeze(-1)
    # [b, t, t']
    logp = const_term + z_sq_term + cross_term + mean_sq_term
    attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

    y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
        attn, o_mean, o_log_scale, x_mask)
    attn = attn.squeeze(1).permute(0, 2, 1)

    # use the aligned mean as the latent
    z = y_mean * y_mask
    # reverse the decoder and predict from the aligned distribution
    y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
    return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
def _make_masks(ilens, olens):
    """Combine input and output sequence masks into one pairwise mask.

    Args:
        ilens: input lengths.
        olens: output lengths.

    Returns:
        The element-wise AND of the output mask (with a trailing singleton
        axis) and the input mask (with a singleton axis inserted before its
        last axis).
    """
    input_side = sequence_mask(ilens).unsqueeze(-2)
    output_side = sequence_mask(olens).unsqueeze(-1)
    return output_side & input_side