def __init__(self, args, src_dict, embed_speaker):
    super().__init__(src_dict)
    self.args = args
    self.padding_idx = src_dict.pad()
    self.n_frames_per_step = args.n_frames_per_step
    self.out_dim = args.output_frame_dim * args.n_frames_per_step

    # Optional speaker conditioning: project [encoder state; speaker embedding]
    # back down to the encoder embedding dimension.
    self.embed_speaker = embed_speaker
    self.spk_emb_proj = None
    if embed_speaker is not None:
        self.spk_emb_proj = nn.Linear(
            args.encoder_embed_dim + args.speaker_embed_dim,
            args.encoder_embed_dim,
        )

    self.dropout_module = FairseqDropout(
        p=args.dropout, module_name=self.__class__.__name__
    )
    self.embed_tokens = Embedding(
        len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
    )

    # Sinusoidal positions with learnable scaling factors for the
    # encoder-side and decoder-side position embeddings.
    self.embed_positions = PositionalEmbedding(
        args.max_source_positions, args.encoder_embed_dim, self.padding_idx
    )
    self.pos_emb_alpha = nn.Parameter(torch.ones(1))
    self.dec_pos_emb_alpha = nn.Parameter(torch.ones(1))

    # Feed-forward Transformer (FFT) blocks on the phoneme side.
    self.encoder_fft_layers = nn.ModuleList(
        FFTLayer(
            args.encoder_embed_dim,
            args.encoder_attention_heads,
            args.fft_hidden_dim,
            args.fft_kernel_size,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
        )
        for _ in range(args.encoder_layers)
    )

    # Predicts duration/pitch/energy and expands states to frame length.
    self.var_adaptor = VarianceAdaptor(args)

    # FFT blocks on the mel-frame side.
    self.decoder_fft_layers = nn.ModuleList(
        FFTLayer(
            args.decoder_embed_dim,
            args.decoder_attention_heads,
            args.fft_hidden_dim,
            args.fft_kernel_size,
            dropout=args.dropout,
            attention_dropout=args.attention_dropout,
        )
        for _ in range(args.decoder_layers)
    )

    self.out_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)

    # Optional convolutional postnet refining the predicted spectrogram.
    self.postnet = None
    if args.add_postnet:
        self.postnet = Postnet(
            self.out_dim,
            args.postnet_conv_dim,
            args.postnet_conv_kernel_size,
            args.postnet_layers,
            args.postnet_dropout,
        )

    self.apply(model_init)
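
# The snippet below is a minimal, self-contained sketch (not the fairseq
# implementation) of how the `spk_emb_proj` layer defined above is typically
# used: the utterance-level speaker embedding is broadcast along the time
# axis, concatenated to the encoder states, and projected back to
# `encoder_embed_dim`. The names `fuse_speaker_embedding`, `enc_out`, and
# `spk_emb` are illustrative only.
import torch
import torch.nn as nn


def fuse_speaker_embedding(
    enc_out: torch.Tensor,      # (batch, time, encoder_embed_dim)
    spk_emb: torch.Tensor,      # (batch, speaker_embed_dim)
    spk_emb_proj: nn.Linear,    # Linear(enc_dim + spk_dim -> enc_dim)
) -> torch.Tensor:
    bsz, seq_len, _ = enc_out.size()
    # Repeat the speaker vector at every time step.
    spk = spk_emb.unsqueeze(1).expand(bsz, seq_len, -1)
    # Concatenate along the feature dimension and project back down.
    return spk_emb_proj(torch.cat([enc_out, spk], dim=2))


# Example usage with toy dimensions:
#   proj = nn.Linear(256 + 64, 256)
#   out = fuse_speaker_embedding(torch.randn(2, 10, 256), torch.randn(2, 64), proj)
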
def __init__(self, args, src_dict, padding_idx=1):
    super().__init__(None)
    # Cache for the causal (future) attention mask, grown lazily.
    self._future_mask = torch.empty(0)

    self.args = args
    self.padding_idx = src_dict.pad() if src_dict else padding_idx
    self.n_frames_per_step = args.n_frames_per_step
    self.out_dim = args.output_frame_dim * args.n_frames_per_step

    self.dropout_module = FairseqDropout(
        args.dropout, module_name=self.__class__.__name__
    )
    self.embed_positions = PositionalEmbedding(
        args.max_target_positions, args.decoder_embed_dim, self.padding_idx
    )
    self.pos_emb_alpha = nn.Parameter(torch.ones(1))

    # Prenet maps previous output frames into the decoder embedding space.
    self.prenet = nn.Sequential(
        Prenet(
            self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout
        ),
        nn.Linear(args.prenet_dim, args.decoder_embed_dim),
    )

    self.n_transformer_layers = args.decoder_transformer_layers
    self.transformer_layers = nn.ModuleList(
        TransformerDecoderLayer(args) for _ in range(self.n_transformer_layers)
    )
    if args.decoder_normalize_before:
        self.layer_norm = LayerNorm(args.decoder_embed_dim)
    else:
        self.layer_norm = None

    # Output heads: spectrogram frames and end-of-sequence probability.
    self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)
    self.eos_proj = nn.Linear(args.decoder_embed_dim, 1)

    self.postnet = Postnet(
        self.out_dim,
        args.postnet_conv_dim,
        args.postnet_conv_kernel_size,
        args.postnet_layers,
        args.postnet_dropout,
    )

    # Optional CTC head over the source vocabulary, used only when a
    # CTC loss term is enabled.
    self.ctc_proj = None
    if getattr(args, "ctc_weight", 0.0) > 0.0:
        self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))

    self.apply(decoder_init)
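
# A minimal sketch, under the assumption that `self._future_mask` above caches
# a causal attention mask for the autoregressive decoder (this is not
# fairseq's exact code). Each decoder step may attend only to itself and
# earlier frames; the mask is rebuilt only when a longer sequence or a
# different device is seen, and the caller stores the returned tensor as the
# new cache. `buffered_future_mask` is an illustrative name.
import torch


def buffered_future_mask(cached_mask: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Return a (T, T) additive causal mask for input x of shape (T, B, C)."""
    t = x.size(0)
    if cached_mask.size(0) < t or cached_mask.device != x.device:
        # Upper-triangular -inf above the diagonal blocks attention to the future.
        cached_mask = torch.triu(
            torch.full((t, t), float("-inf"), device=x.device), diagonal=1
        )
    return cached_mask[:t, :t]


# Example: a 4-step mask where position i may attend to positions <= i.
#   mask = buffered_future_mask(torch.empty(0), torch.randn(4, 2, 8))
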