import os
from abc import ABC

import torch
import torch.nn.functional as F

# NOTE: the building blocks used below (Encoder, Decoder, EncoderPrenet,
# DecoderPrenet, PostNet, PositionalEncoding, ScaledPositionalEncoding,
# MultiHeadedAttention, GuidedMultiHeadAttentionLoss, TransformerLoss,
# subsequent_mask, make_non_pad_mask) are assumed to be importable from the
# surrounding project; add the imports that match your package layout.


class Transformer(torch.nn.Module, ABC):

    def __init__(self,
                 # network structure related
                 idim,
                 odim,
                 embed_dim=0,
                 eprenet_conv_layers=0,
                 eprenet_conv_chans=0,
                 eprenet_conv_filts=0,
                 dprenet_layers=2,
                 dprenet_units=256,
                 elayers=6,
                 eunits=1024,
                 adim=512,
                 aheads=4,
                 dlayers=6,
                 dunits=1024,
                 postnet_layers=5,
                 postnet_chans=256,
                 postnet_filts=5,
                 positionwise_layer_type="conv1d",
                 positionwise_conv_kernel_size=1,
                 use_scaled_pos_enc=True,
                 use_batch_norm=True,
                 encoder_normalize_before=True,
                 decoder_normalize_before=True,
                 encoder_concat_after=True,  # True according to https://github.com/soobinseo/Transformer-TTS
                 decoder_concat_after=True,  # True according to https://github.com/soobinseo/Transformer-TTS
                 reduction_factor=1,
                 spk_embed_dim=None,
                 spk_embed_integration_type="concat",
                 # training related
                 transformer_enc_dropout_rate=0.1,
                 transformer_enc_positional_dropout_rate=0.1,
                 transformer_enc_attn_dropout_rate=0.1,
                 transformer_dec_dropout_rate=0.1,
                 transformer_dec_positional_dropout_rate=0.1,
                 transformer_dec_attn_dropout_rate=0.1,
                 transformer_enc_dec_attn_dropout_rate=0.1,
                 eprenet_dropout_rate=0.0,
                 dprenet_dropout_rate=0.5,
                 postnet_dropout_rate=0.5,
                 init_type="xavier_uniform",  # since we have little to no asymmetric activations, this seems to work better than kaiming
                 init_enc_alpha=1.0,
                 use_masking=False,  # either this or weighted masking, not both
                 use_weighted_masking=True,  # if there are severely different sized samples in one batch
                 bce_pos_weight=7.0,  # scaling the loss of the stop token prediction
                 loss_type="L1",
                 use_guided_attn_loss=True,
                 num_heads_applied_guided_attn=2,
                 num_layers_applied_guided_attn=2,
                 modules_applied_guided_attn=("encoder-decoder",),
                 guided_attn_loss_sigma=0.4,  # standard deviation from the diagonal that is allowed
                 guided_attn_loss_lambda=25.0):
        super().__init__()
        self.idim = idim
        self.odim = odim
        self.eos = idim - 1  # the last input ID is reserved for <eos>
        self.spk_embed_dim = spk_embed_dim
        self.reduction_factor = reduction_factor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.use_guided_attn_loss = use_guided_attn_loss
        if self.use_guided_attn_loss:
            # -1 means that guided attention is applied to all layers/heads
            if num_layers_applied_guided_attn == -1:
                self.num_layers_applied_guided_attn = elayers
            else:
                self.num_layers_applied_guided_attn = num_layers_applied_guided_attn
            if num_heads_applied_guided_attn == -1:
                self.num_heads_applied_guided_attn = aheads
            else:
                self.num_heads_applied_guided_attn = num_heads_applied_guided_attn
            self.modules_applied_guided_attn = modules_applied_guided_attn
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = spk_embed_integration_type
        self.padding_idx = 0
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

        # define the transformer encoder and its input layer
        if eprenet_conv_layers != 0:
            encoder_input_layer = torch.nn.Sequential(
                EncoderPrenet(idim=idim,
                              embed_dim=embed_dim,
                              elayers=0,
                              econv_layers=eprenet_conv_layers,
                              econv_chans=eprenet_conv_chans,
                              econv_filts=eprenet_conv_filts,
                              use_batch_norm=use_batch_norm,
                              dropout_rate=eprenet_dropout_rate,
                              padding_idx=self.padding_idx),
                torch.nn.Linear(eprenet_conv_chans, adim))
        else:
            encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                     embedding_dim=adim,
                                                     padding_idx=self.padding_idx)
        self.encoder = Encoder(idim=idim,
                               attention_dim=adim,
                               attention_heads=aheads,
                               linear_units=eunits,
                               num_blocks=elayers,
                               input_layer=encoder_input_layer,
                               dropout_rate=transformer_enc_dropout_rate,
                               positional_dropout_rate=transformer_enc_positional_dropout_rate,
                               attention_dropout_rate=transformer_enc_attn_dropout_rate,
                               pos_enc_class=pos_enc_class,
                               normalize_before=encoder_normalize_before,
                               concat_after=encoder_concat_after,
                               positionwise_layer_type=positionwise_layer_type,
                               positionwise_conv_kernel_size=positionwise_conv_kernel_size)

        # optional projection for integrating a speaker embedding
        if self.spk_embed_dim is not None:
            self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)

        # define the transformer decoder and its prenet input layer
        decoder_input_layer = torch.nn.Sequential(
            DecoderPrenet(idim=odim,
                          n_layers=dprenet_layers,
                          n_units=dprenet_units,
                          dropout_rate=dprenet_dropout_rate),
            torch.nn.Linear(dprenet_units, adim))
        self.decoder = Decoder(odim=odim,
                               attention_dim=adim,
                               attention_heads=aheads,
                               linear_units=dunits,
                               num_blocks=dlayers,
                               dropout_rate=transformer_dec_dropout_rate,
                               positional_dropout_rate=transformer_dec_positional_dropout_rate,
                               self_attention_dropout_rate=transformer_dec_attn_dropout_rate,
                               src_attention_dropout_rate=transformer_enc_dec_attn_dropout_rate,
                               input_layer=decoder_input_layer,
                               use_output_layer=False,
                               pos_enc_class=pos_enc_class,
                               normalize_before=decoder_normalize_before,
                               concat_after=decoder_concat_after)

        # projections for the spectrogram frames and the stop token probability
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
        self.prob_out = torch.nn.Linear(adim, reduction_factor)

        # define the postnet that refines the decoder output residually
        self.postnet = PostNet(idim=idim,
                               odim=odim,
                               n_layers=postnet_layers,
                               n_chans=postnet_chans,
                               n_filts=postnet_filts,
                               use_batch_norm=use_batch_norm,
                               dropout_rate=postnet_dropout_rate)

        # define the loss functions
        self.criterion = TransformerLoss(use_masking=use_masking,
                                         use_weighted_masking=use_weighted_masking,
                                         bce_pos_weight=bce_pos_weight)
        if self.use_guided_attn_loss:
            self.attn_criterion = GuidedMultiHeadAttentionLoss(sigma=guided_attn_loss_sigma,
                                                               alpha=guided_attn_loss_lambda)

        # load pretrained weights
        self.load_state_dict(
            torch.load(os.path.join("Models", "TransformerTTS_Eva", "best.pt"),
                       map_location='cpu')["model"])

    def forward(self, text, speaker_embedding=None):
        """Autoregressively generate a spectrogram from a sequence of input IDs (inference only)."""
        self.eval()
        x = text
        xs = x.unsqueeze(0)
        hs, _ = self.encoder(xs, None)
        if self.spk_embed_dim is not None:
            speaker_embeddings = speaker_embedding.unsqueeze(0)
            hs = self._integrate_with_spk_embed(hs, speaker_embeddings)
        # maximum and minimum number of decoding steps, relative to the input length
        maxlen = int(hs.size(1) * 10.0 / self.reduction_factor)
        minlen = int(hs.size(1) * 0.0 / self.reduction_factor)
        idx = 0
        ys = hs.new_zeros(1, 1, self.odim)  # all-zero initial frame
        outs, probs = [], []
        z_cache = self.decoder.init_state(x)
        while True:
            idx += 1
            y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device)
            z, z_cache = self.decoder.forward_one_step(ys, y_masks, hs, cache=z_cache)
            outs += [self.feat_out(z).view(self.reduction_factor, self.odim)]
            probs += [torch.sigmoid(self.prob_out(z))[0]]
            ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)), dim=1)
            # collect the encoder-decoder attention weights of the current step
            # (kept for potential attention plotting; not returned)
            att_ws_ = []
            for name, m in self.named_modules():
                if isinstance(m, MultiHeadedAttention) and "src" in name:
                    att_ws_ += [m.attn[0, :, -1].unsqueeze(1)]
            if idx == 1:
                att_ws = att_ws_
            else:
                att_ws = [torch.cat([att_w, att_w_], dim=1)
                          for att_w, att_w_ in zip(att_ws, att_ws_)]
            # stop once any stop token probability exceeds 0.5 or maxlen is reached
            if int(sum(probs[-1] >= 0.5)) > 0 or idx >= maxlen:
                if idx < minlen:
                    continue
                outs = torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)
                if self.postnet is not None:
                    outs = outs + self.postnet(outs)  # residual refinement
                outs = outs.transpose(2, 1).squeeze(0)
                break
        return outs

    @staticmethod
    def _add_first_frame_and_remove_last_frame(ys):
        # shift the target frames one step to the right to get the decoder inputs
        return torch.cat([ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1)

    def _source_mask(self, ilens):
        # mask out the padded positions of the encoder inputs
        x_masks = make_non_pad_mask(ilens).to(ilens.device)
        return x_masks.unsqueeze(-2)
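    @staticmethod
    def _make_guided_attention_mask(ilen, olen, sigma):
        # Illustrative sketch only, not called anywhere in this class: a guided
        # attention penalty mask as in Tachibana et al. (2017), which
        # GuidedMultiHeadAttentionLoss is assumed to construct internally from
        # guided_attn_loss_sigma. W[t, n] = 1 - exp(-(n/N - t/T)^2 / (2 * sigma^2)),
        # so attention weights far from the diagonal are penalized, scaled by
        # guided_attn_loss_lambda.
        grid_y, grid_x = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
        grid_y = grid_y.float() / olen
        grid_x = grid_x.float() / ilen
        return 1.0 - torch.exp(-(grid_y - grid_x) ** 2 / (2 * (sigma ** 2)))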
    def _target_mask(self, olens):
        # combine the padding mask with a causal mask for the decoder
        y_masks = make_non_pad_mask(olens).to(olens.device)
        s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0)
        return y_masks.unsqueeze(-2) & s_masks

    def _integrate_with_spk_embed(self, hs, speaker_embeddings):
        # concatenate the normalized speaker embedding to every encoder state
        # and project the result back down to the attention dimension
        speaker_embeddings = F.normalize(speaker_embeddings).unsqueeze(1).expand(-1, hs.size(1), -1)
        hs = self.projection(torch.cat([hs, speaker_embeddings], dim=-1))
        return hs
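
# A minimal usage sketch (assumption, not part of the original module): idim
# and odim below are placeholders, and instantiation requires the checkpoint
# at Models/TransformerTTS_Eva/best.pt to exist, since __init__ loads it
# eagerly. forward() takes an unbatched LongTensor of input IDs (with
# idim - 1 reserved for <eos>) and returns an unbatched spectrogram of shape
# (frames, odim).
if __name__ == "__main__":
    model = Transformer(idim=133, odim=80)  # placeholder dimensions
    dummy_ids = torch.randint(low=1, high=132, size=(25,))  # fake input IDs
    with torch.no_grad():
        mel = model(dummy_ids)
    print(mel.shape)  # expected: (frames, 80)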