def __init__(self,
             # network structure related
             idim,
             odim,
             adim=384,
             aheads=4,
             elayers=6,
             eunits=1536,
             dlayers=6,
             dunits=1536,
             postnet_layers=5,
             postnet_chans=256,
             postnet_filts=5,
             positionwise_layer_type="conv1d",
             positionwise_conv_kernel_size=1,
             use_scaled_pos_enc=True,
             use_batch_norm=True,
             encoder_normalize_before=True,
             decoder_normalize_before=True,
             encoder_concat_after=False,
             decoder_concat_after=False,
             reduction_factor=1,
             # encoder / decoder
             conformer_pos_enc_layer_type="rel_pos",
             conformer_self_attn_layer_type="rel_selfattn",
             conformer_activation_type="swish",
             use_macaron_style_in_conformer=True,
             use_cnn_in_conformer=True,
             conformer_enc_kernel_size=7,
             conformer_dec_kernel_size=31,
             # duration predictor
             duration_predictor_layers=2,
             duration_predictor_chans=256,
             duration_predictor_kernel_size=3,
             # energy predictor
             energy_predictor_layers=2,
             energy_predictor_chans=256,
             energy_predictor_kernel_size=3,
             energy_predictor_dropout=0.5,
             energy_embed_kernel_size=1,
             energy_embed_dropout=0.0,
             stop_gradient_from_energy_predictor=True,
             # pitch predictor
             pitch_predictor_layers=5,
             pitch_predictor_chans=256,
             pitch_predictor_kernel_size=5,
             pitch_predictor_dropout=0.5,
             pitch_embed_kernel_size=1,
             pitch_embed_dropout=0.0,
             stop_gradient_from_pitch_predictor=True,
             # pretrained spk emb
             spk_embed_dim=None,
             # training related
             transformer_enc_dropout_rate=0.2,
             transformer_enc_positional_dropout_rate=0.2,
             transformer_enc_attn_dropout_rate=0.2,
             transformer_dec_dropout_rate=0.2,
             transformer_dec_positional_dropout_rate=0.2,
             transformer_dec_attn_dropout_rate=0.2,
             duration_predictor_dropout_rate=0.2,
             postnet_dropout_rate=0.5,
             init_type="kaiming_uniform",
             init_enc_alpha=1.0,
             init_dec_alpha=1.0,
             use_masking=False,
             use_weighted_masking=True,
             lang='en',
             weights_path=None):
    """Build a FastSpeech2-style model and load pretrained weights.

    Constructs a conformer text encoder, duration/pitch/energy variance
    predictors with their embedding convolutions, a length regulator, a
    conformer decoder, a linear mel projection and a postnet, then
    restores all parameters from a checkpoint on disk.

    Args:
        idim: size of the input symbol vocabulary (embedding table rows).
        odim: output feature dimension (e.g. number of mel bins).
        adim: attention/hidden dimension shared by encoder and decoder.
        weights_path: path to the checkpoint to load. Defaults to the
            previous hard-coded "Models/FastSpeech2_Elizabeth/best.pt"
            so existing callers are unaffected.

    NOTE(review): several arguments (positionwise_layer_type,
    conformer_pos_enc_layer_type, conformer_self_attn_layer_type,
    conformer_activation_type, init_type, init_enc_alpha, init_dec_alpha,
    use_masking, use_weighted_masking, lang, ...) are accepted but never
    referenced in this constructor body — kept for interface
    compatibility; confirm whether downstream code relies on them.
    """
    super().__init__()
    self.idim = idim
    self.odim = odim
    self.reduction_factor = reduction_factor
    self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
    self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.spk_embed_dim = spk_embed_dim
    # index 0 is reserved for padding in the symbol embedding
    self.padding_idx = 0

    encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                             embedding_dim=adim,
                                             padding_idx=self.padding_idx)
    self.encoder = Conformer(
        idim=idim,
        attention_dim=adim,
        attention_heads=aheads,
        linear_units=eunits,
        num_blocks=elayers,
        input_layer=encoder_input_layer,
        dropout_rate=transformer_enc_dropout_rate,
        positional_dropout_rate=transformer_enc_positional_dropout_rate,
        attention_dropout_rate=transformer_enc_attn_dropout_rate,
        normalize_before=encoder_normalize_before,
        concat_after=encoder_concat_after,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        macaron_style=use_macaron_style_in_conformer,
        use_cnn_module=use_cnn_in_conformer,
        cnn_module_kernel=conformer_enc_kernel_size)

    # projection fuses the speaker embedding back down to adim
    if self.spk_embed_dim is not None:
        self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)

    self.duration_predictor = DurationPredictor(
        idim=adim,
        n_layers=duration_predictor_layers,
        n_chans=duration_predictor_chans,
        kernel_size=duration_predictor_kernel_size,
        dropout_rate=duration_predictor_dropout_rate,
    )

    self.pitch_predictor = VariancePredictor(
        idim=adim,
        n_layers=pitch_predictor_layers,
        n_chans=pitch_predictor_chans,
        kernel_size=pitch_predictor_kernel_size,
        dropout_rate=pitch_predictor_dropout)
    # turns the scalar pitch track into an adim-sized embedding per frame
    self.pitch_embed = torch.nn.Sequential(
        torch.nn.Conv1d(in_channels=1,
                        out_channels=adim,
                        kernel_size=pitch_embed_kernel_size,
                        padding=(pitch_embed_kernel_size - 1) // 2),
        torch.nn.Dropout(pitch_embed_dropout))

    self.energy_predictor = VariancePredictor(
        idim=adim,
        n_layers=energy_predictor_layers,
        n_chans=energy_predictor_chans,
        kernel_size=energy_predictor_kernel_size,
        dropout_rate=energy_predictor_dropout)
    # turns the scalar energy track into an adim-sized embedding per frame
    self.energy_embed = torch.nn.Sequential(
        torch.nn.Conv1d(in_channels=1,
                        out_channels=adim,
                        kernel_size=energy_embed_kernel_size,
                        padding=(energy_embed_kernel_size - 1) // 2),
        torch.nn.Dropout(energy_embed_dropout))

    self.length_regulator = LengthRegulator()

    # decoder gets pre-built hidden states, hence input_layer=None / idim=0
    self.decoder = Conformer(
        idim=0,
        attention_dim=adim,
        attention_heads=aheads,
        linear_units=dunits,
        num_blocks=dlayers,
        input_layer=None,
        dropout_rate=transformer_dec_dropout_rate,
        positional_dropout_rate=transformer_dec_positional_dropout_rate,
        attention_dropout_rate=transformer_dec_attn_dropout_rate,
        normalize_before=decoder_normalize_before,
        concat_after=decoder_concat_after,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        macaron_style=use_macaron_style_in_conformer,
        use_cnn_module=use_cnn_in_conformer,
        cnn_module_kernel=conformer_dec_kernel_size)

    self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
    self.postnet = PostNet(idim=idim,
                           odim=odim,
                           n_layers=postnet_layers,
                           n_chans=postnet_chans,
                           n_filts=postnet_filts,
                           use_batch_norm=use_batch_norm,
                           dropout_rate=postnet_dropout_rate)

    # Restore pretrained weights. The path is now overridable instead of
    # hard-coded; the default reproduces the original behavior exactly.
    if weights_path is None:
        weights_path = os.path.join("Models", "FastSpeech2_Elizabeth", "best.pt")
    self.load_state_dict(torch.load(weights_path, map_location='cpu')["model"])
class FastSpeech2(torch.nn.Module, ABC):
    """FastSpeech2-style non-autoregressive text-to-spectrogram model.

    Pipeline (see ``_forward``): conformer encoder over input symbols,
    optional speaker-embedding fusion, duration/pitch/energy variance
    predictors, length regulation to frame resolution, conformer decoder,
    linear mel projection, and a residual postnet refinement.
    Pretrained weights are loaded from disk at construction time.
    """

    def __init__(self,
                 # network structure related
                 idim,
                 odim,
                 adim=384,
                 aheads=4,
                 elayers=6,
                 eunits=1536,
                 dlayers=6,
                 dunits=1536,
                 postnet_layers=5,
                 postnet_chans=256,
                 postnet_filts=5,
                 positionwise_layer_type="conv1d",
                 positionwise_conv_kernel_size=1,
                 use_scaled_pos_enc=True,
                 use_batch_norm=True,
                 encoder_normalize_before=True,
                 decoder_normalize_before=True,
                 encoder_concat_after=False,
                 decoder_concat_after=False,
                 reduction_factor=1,
                 # encoder / decoder
                 conformer_pos_enc_layer_type="rel_pos",
                 conformer_self_attn_layer_type="rel_selfattn",
                 conformer_activation_type="swish",
                 use_macaron_style_in_conformer=True,
                 use_cnn_in_conformer=True,
                 conformer_enc_kernel_size=7,
                 conformer_dec_kernel_size=31,
                 # duration predictor
                 duration_predictor_layers=2,
                 duration_predictor_chans=256,
                 duration_predictor_kernel_size=3,
                 # energy predictor
                 energy_predictor_layers=2,
                 energy_predictor_chans=256,
                 energy_predictor_kernel_size=3,
                 energy_predictor_dropout=0.5,
                 energy_embed_kernel_size=1,
                 energy_embed_dropout=0.0,
                 stop_gradient_from_energy_predictor=True,
                 # pitch predictor
                 pitch_predictor_layers=5,
                 pitch_predictor_chans=256,
                 pitch_predictor_kernel_size=5,
                 pitch_predictor_dropout=0.5,
                 pitch_embed_kernel_size=1,
                 pitch_embed_dropout=0.0,
                 stop_gradient_from_pitch_predictor=True,
                 # pretrained spk emb
                 spk_embed_dim=None,
                 # training related
                 transformer_enc_dropout_rate=0.2,
                 transformer_enc_positional_dropout_rate=0.2,
                 transformer_enc_attn_dropout_rate=0.2,
                 transformer_dec_dropout_rate=0.2,
                 transformer_dec_positional_dropout_rate=0.2,
                 transformer_dec_attn_dropout_rate=0.2,
                 duration_predictor_dropout_rate=0.2,
                 postnet_dropout_rate=0.5,
                 init_type="kaiming_uniform",
                 init_enc_alpha=1.0,
                 init_dec_alpha=1.0,
                 use_masking=False,
                 use_weighted_masking=True,
                 lang='en',
                 weights_path=None):
        """Build all submodules and load pretrained weights.

        Args:
            idim: size of the input symbol vocabulary.
            odim: output feature dimension (e.g. number of mel bins).
            adim: shared attention/hidden dimension.
            weights_path: checkpoint to load; defaults to the previous
                hard-coded "Models/FastSpeech2_Elizabeth/best.pt" so
                existing callers see identical behavior.

        NOTE(review): several arguments (positionwise_layer_type,
        conformer_*_type, init_*, use_masking, use_weighted_masking,
        lang, ...) are accepted but never referenced in this body —
        kept for interface compatibility.
        """
        super().__init__()
        self.idim = idim
        self.odim = odim
        self.reduction_factor = reduction_factor
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.spk_embed_dim = spk_embed_dim
        # index 0 is reserved for padding in the symbol embedding
        self.padding_idx = 0

        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=adim,
                                                 padding_idx=self.padding_idx)
        self.encoder = Conformer(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_enc_kernel_size)

        # projection fuses the speaker embedding back down to adim
        if self.spk_embed_dim is not None:
            self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)

        self.duration_predictor = DurationPredictor(
            idim=adim,
            n_layers=duration_predictor_layers,
            n_chans=duration_predictor_chans,
            kernel_size=duration_predictor_kernel_size,
            dropout_rate=duration_predictor_dropout_rate,
        )

        self.pitch_predictor = VariancePredictor(
            idim=adim,
            n_layers=pitch_predictor_layers,
            n_chans=pitch_predictor_chans,
            kernel_size=pitch_predictor_kernel_size,
            dropout_rate=pitch_predictor_dropout)
        # turns the scalar pitch track into an adim-sized embedding per frame
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1,
                            out_channels=adim,
                            kernel_size=pitch_embed_kernel_size,
                            padding=(pitch_embed_kernel_size - 1) // 2),
            torch.nn.Dropout(pitch_embed_dropout))

        self.energy_predictor = VariancePredictor(
            idim=adim,
            n_layers=energy_predictor_layers,
            n_chans=energy_predictor_chans,
            kernel_size=energy_predictor_kernel_size,
            dropout_rate=energy_predictor_dropout)
        # turns the scalar energy track into an adim-sized embedding per frame
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1,
                            out_channels=adim,
                            kernel_size=energy_embed_kernel_size,
                            padding=(energy_embed_kernel_size - 1) // 2),
            torch.nn.Dropout(energy_embed_dropout))

        self.length_regulator = LengthRegulator()

        # decoder gets pre-built hidden states, hence input_layer=None / idim=0
        self.decoder = Conformer(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_dec_kernel_size)

        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
        self.postnet = PostNet(idim=idim,
                               odim=odim,
                               n_layers=postnet_layers,
                               n_chans=postnet_chans,
                               n_filts=postnet_filts,
                               use_batch_norm=use_batch_norm,
                               dropout_rate=postnet_dropout_rate)

        # Restore pretrained weights. The path is now overridable instead of
        # hard-coded; the default reproduces the original behavior exactly.
        if weights_path is None:
            weights_path = os.path.join("Models", "FastSpeech2_Elizabeth", "best.pt")
        self.load_state_dict(torch.load(weights_path, map_location='cpu')["model"])

    def _forward(self,
                 xs,
                 ilens,
                 ys=None,
                 olens=None,
                 ds=None,
                 ps=None,
                 es=None,
                 speaker_embeddings=None,
                 is_inference=False,
                 alpha=1.0):
        """Shared training/inference forward pass.

        Args:
            xs: padded batch of input symbol ids (batch, max_ilen).
            ilens: per-item input lengths.
            ys, olens: padded targets and their lengths (training only;
                ``ys`` is accepted but not read in this body).
            ds, ps, es: ground-truth durations / pitch / energy used to
                condition the decoder during training.
            speaker_embeddings: optional pretrained speaker embeddings.
            is_inference: switch between predicted and ground-truth
                durations/pitch/energy.
            alpha: duration scaling factor at inference (speech speed).

        Returns:
            (before_outs, after_outs, d_outs, p_outs, e_outs)
        """
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)  # (B, T_text, adim)

        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, speaker_embeddings)

        d_masks = make_pad_mask(ilens).to(xs.device)

        # Optionally detach so predictor losses don't backprop into the
        # encoder (stabilizes training of the shared representation).
        if self.stop_gradient_from_pitch_predictor:
            p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1))
        if self.stop_gradient_from_energy_predictor:
            e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1))

        if is_inference:
            # use predicted variances; alpha rescales durations
            d_outs = self.duration_predictor.inference(hs, d_masks)
            p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2)
            hs = hs + e_embs + p_embs
            hs = self.length_regulator(hs, d_outs, alpha)
        else:
            # teacher-force with ground-truth pitch/energy/durations
            d_outs = self.duration_predictor(hs, d_masks)
            p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
            hs = hs + e_embs + p_embs
            hs = self.length_regulator(hs, ds)

        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                # decoder operates on reduced frame rate
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None

        zs, _ = self.decoder(hs, h_masks)  # (B, T_feats, adim)
        # unfold reduction factor back into the time axis
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
        # postnet adds a residual refinement on top of the coarse output
        after_outs = before_outs + self.postnet(
            before_outs.transpose(1, 2)).transpose(1, 2)
        return before_outs, after_outs, d_outs, p_outs, e_outs

    def forward(self, text, speaker_embedding=None, alpha=1.0):
        """Synthesize features for a single unbatched text sequence.

        Puts the module in eval mode, batches the input, and returns the
        postnet-refined output for the single item.
        """
        self.eval()  # inference-only entry point: disable dropout etc.
        x = text
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs = x.unsqueeze(0)
        if speaker_embedding is not None:
            speaker_embedding = speaker_embedding.unsqueeze(0)
        _, outs, *_ = self._forward(xs,
                                    ilens,
                                    None,
                                    speaker_embeddings=speaker_embedding,
                                    is_inference=True,
                                    alpha=alpha)
        return outs[0]

    def _integrate_with_spk_embed(self, hs, speaker_embeddings):
        """Concatenate an L2-normalized speaker embedding to every encoder
        frame and project back to the hidden dimension."""
        speaker_embeddings = F.normalize(speaker_embeddings).unsqueeze(
            1).expand(-1, hs.size(1), -1)
        hs = self.projection(torch.cat([hs, speaker_embeddings], dim=-1))
        return hs

    def _source_mask(self, ilens):
        """Build a (B, 1, T) non-padding attention mask from lengths."""
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)