Code example #1
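
This constructor appears to come from fairseq's FastSpeech 2 text-to-speech encoder (`FastSpeech2Encoder`): it assembles token and positional embeddings, an encoder stack of FFT blocks, a variance adaptor, a decoder FFT stack, an output projection, and an optional postnet.
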
    def __init__(self, args, src_dict, embed_speaker):
        super().__init__(src_dict)
        self.args = args
        self.padding_idx = src_dict.pad()
        self.n_frames_per_step = args.n_frames_per_step
        self.out_dim = args.output_frame_dim * args.n_frames_per_step

        # Optional speaker embedding; when present, the concatenation of
        # encoder states and the speaker vector is projected back down to
        # encoder_embed_dim.
        self.embed_speaker = embed_speaker
        self.spk_emb_proj = None
        if embed_speaker is not None:
            self.spk_emb_proj = nn.Linear(
                args.encoder_embed_dim + args.speaker_embed_dim,
                args.encoder_embed_dim)

        # Shared dropout module and the source-token embedding table;
        # padding_idx keeps the pad embedding fixed at zero.
        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__)
        self.embed_tokens = Embedding(len(src_dict),
                                      args.encoder_embed_dim,
                                      padding_idx=self.padding_idx)

        # Positional embeddings with learnable scalar scales: one for the
        # encoder input, one for the decoder input after length regulation.
        self.embed_positions = PositionalEmbedding(args.max_source_positions,
                                                   args.encoder_embed_dim,
                                                   self.padding_idx)
        self.pos_emb_alpha = nn.Parameter(torch.ones(1))
        self.dec_pos_emb_alpha = nn.Parameter(torch.ones(1))

        # Encoder: a stack of feed-forward Transformer (FFT) blocks, i.e.
        # self-attention followed by 1D convolutions, as in FastSpeech 2.
        self.encoder_fft_layers = nn.ModuleList(
            FFTLayer(args.encoder_embed_dim,
                     args.encoder_attention_heads,
                     args.fft_hidden_dim,
                     args.fft_kernel_size,
                     dropout=args.dropout,
                     attention_dropout=args.attention_dropout)
            for _ in range(args.encoder_layers))

        # Variance adaptor predicts duration, pitch, and energy, and expands
        # the sequence to frame length (FastSpeech 2).
        self.var_adaptor = VarianceAdaptor(args)

        # Decoder: another FFT stack over the length-regulated sequence.
        self.decoder_fft_layers = nn.ModuleList(
            FFTLayer(args.decoder_embed_dim,
                     args.decoder_attention_heads,
                     args.fft_hidden_dim,
                     args.fft_kernel_size,
                     dropout=args.dropout,
                     attention_dropout=args.attention_dropout)
            for _ in range(args.decoder_layers))

        # Project decoder states to n_frames_per_step output frames at once.
        self.out_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)

        # Optional convolutional postnet refines the predicted spectrogram
        # with a residual correction.
        self.postnet = None
        if args.add_postnet:
            self.postnet = Postnet(self.out_dim, args.postnet_conv_dim,
                                   args.postnet_conv_kernel_size,
                                   args.postnet_layers, args.postnet_dropout)

        # Recursively initialize all submodule weights.
        self.apply(model_init)
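
A minimal instantiation sketch, assuming the constructor above is fairseq's `FastSpeech2Encoder`. Every hyperparameter value below is an illustrative assumption, not a library default, and `VarianceAdaptor(args)` will read further fields that this excerpt does not show.

from argparse import Namespace

# All values are illustrative assumptions for a small model.
args = Namespace(
    n_frames_per_step=1,
    output_frame_dim=80,       # e.g. an 80-bin mel spectrogram
    encoder_embed_dim=256,
    speaker_embed_dim=64,      # only read when embed_speaker is not None
    dropout=0.1,
    max_source_positions=1024,
    encoder_attention_heads=2,
    fft_hidden_dim=1024,
    fft_kernel_size=9,
    attention_dropout=0.0,
    encoder_layers=4,
    decoder_embed_dim=256,
    decoder_attention_heads=2,
    decoder_layers=4,
    add_postnet=False,
)
# encoder = FastSpeech2Encoder(args, src_dict, embed_speaker=None)
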
Code example #2
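
This constructor appears to come from fairseq's autoregressive TTS Transformer decoder (`TTSTransformerDecoder`): previous output frames pass through a prenet, a stack of Transformer decoder layers, and separate projections for the output frames and the end-of-sequence decision, followed by a postnet and an optional CTC head.
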
    def __init__(self, args, src_dict, padding_idx=1):
        super().__init__(None)
        # Lazily grown cache for the causal (future-masking) attention mask.
        self._future_mask = torch.empty(0)

        self.args = args
        self.padding_idx = src_dict.pad() if src_dict else padding_idx
        self.n_frames_per_step = args.n_frames_per_step
        self.out_dim = args.output_frame_dim * args.n_frames_per_step

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        # Target-side positional embeddings with a learnable scalar scale.
        self.embed_positions = PositionalEmbedding(
            args.max_target_positions, args.decoder_embed_dim, self.padding_idx
        )
        self.pos_emb_alpha = nn.Parameter(torch.ones(1))
        # Tacotron-style prenet compresses previous output frames, then a
        # linear layer maps them into the decoder embedding space.
        self.prenet = nn.Sequential(
            Prenet(
                self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout
            ),
            nn.Linear(args.prenet_dim, args.decoder_embed_dim),
        )

        # Stack of standard Transformer decoder layers; with pre-norm
        # (decoder_normalize_before) a final LayerNorm is applied on top.
        self.n_transformer_layers = args.decoder_transformer_layers
        self.transformer_layers = nn.ModuleList(
            TransformerDecoderLayer(args) for _ in range(self.n_transformer_layers)
        )
        if args.decoder_normalize_before:
            self.layer_norm = LayerNorm(args.decoder_embed_dim)
        else:
            self.layer_norm = None

        # Two heads: one predicts the output frames, the other a scalar
        # end-of-sequence logit per step.
        self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)
        self.eos_proj = nn.Linear(args.decoder_embed_dim, 1)

        # Convolutional postnet adds a residual refinement to the raw frames.
        self.postnet = Postnet(
            self.out_dim,
            args.postnet_conv_dim,
            args.postnet_conv_kernel_size,
            args.postnet_layers,
            args.postnet_dropout,
        )

        # Optional CTC projection over the source vocabulary, used only when
        # training with a CTC auxiliary loss (ctc_weight > 0).
        self.ctc_proj = None
        if getattr(args, "ctc_weight", 0.0) > 0.0:
            self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))

        # Recursively initialize all submodule weights.
        self.apply(decoder_init)
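
The `self._future_mask = torch.empty(0)` buffer exists so the decoder can cache a causal attention mask instead of rebuilding it at every step. A minimal sketch of the mask itself, in plain PyTorch (an illustration of the pattern, not the exact source):

import torch

def future_mask(dim: int) -> torch.Tensor:
    # -inf above the diagonal: position i may attend to positions <= i only.
    # A decoder caches the largest mask seen so far and slices [:dim, :dim].
    return torch.triu(torch.full((dim, dim), float("-inf")), diagonal=1)

print(future_mask(3))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])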