Beispiel #1
0
    def __init__(
            self,
            enc_in_size,
            dec_in_size,
            dec_out_size,
            enc_emb_size=256,
            enc_emb_do=0.0,
            enc_prenet_size=[256, 128],
            enc_prenet_do=[0.5, 0.5],
            enc_prenet_fn='leaky_relu',
            dec_prenet_size=[256, 128],
            dec_prenet_do=[0.5, 0.5],
            dec_prenet_fn='leaky_relu',
            dec_rnn_sizes=[256, 256],
            dec_rnn_cfgs={"type": "lstm"},
            dec_rnn_do=0.0,
            dec_cfg={"type": "standard_decoder"},
            att_cfg={"type": "mlp"},
            dec_core_gen_size=[512],
            dec_core_gen_fn='leaky_relu',
            dec_core_gen_do=0.0,
            # CBHG #
            enc_cbhg_cfg={},
            # FRAME ENDING #
            dec_bern_end_size=[256],
            dec_bern_end_fn='LeakyReLU',
            dec_bern_end_do=0.0,
            # OPTIONAL #
            dec_in_range=None):
        """Build the text encoder, attention decoder, mel generator and
        frame-ending (stop) predictor.

        Args:
            enc_in_size : size of vocab (rows of the input embedding)
            dec_in_size : input (mel) dim size
            dec_out_size : output (mel) dim size (usually same as dec_in_size)
            enc_emb_size : embedding size of the encoder input tokens
            enc_emb_do : embedding dropout (stored; not used in __init__)
            enc_prenet_size : hidden sizes of the encoder prenet Linear stack
            enc_prenet_do : prenet dropout, a scalar or one value per layer
            enc_prenet_fn : prenet activation name (stored; not used here)
            dec_prenet_size, dec_prenet_do, dec_prenet_fn :
                same as the enc_prenet_* arguments, for the decoder prenet
            dec_rnn_sizes : hidden size of each decoder RNN layer
            dec_rnn_cfgs : RNN config dict, or one dict per layer, each with
                a 'type' key; types that do not already match
                'stateful.*cell' are rewritten to 'stateful_<type>cell'
                (on a copy -- the argument itself is left untouched)
            dec_rnn_do : RNN dropout forwarded to the decoder
            dec_cfg : decoder config (stored; a StandardDecoder is always
                built for now, see the TODO below)
            att_cfg : attention config forwarded to the decoder
            dec_core_gen_size : hidden sizes of the Linear stack mapping the
                decoder output to dec_out_size
            dec_core_gen_fn, dec_core_gen_do : activation name and dropout
                for that stack (stored; not used in __init__)
            enc_cbhg_cfg : keyword arguments for the encoder CBHG1d,
                parsed with ConfigParser.item_parser
            dec_bern_end_size : hidden sizes of the frame-ending MLP, which
                ends in a single logit
            dec_bern_end_fn : activation name for the frame-ending MLP
            dec_bern_end_do : frame-ending dropout, scalar or one per layer
            dec_in_range :
                optional pair [start, stop] selecting which decoder input
                dims feed the prenet, using Python slice semantics (either
                end may be None; negative values count from dec_in_size).
                In the Tacotron paper, they only use the last time-step
                instead of the whole group.
        """
        super().__init__()

        def _linear_chain(in_size, out_sizes):
            # nn.Linear stack in_size -> out_sizes[0] -> ... -> out_sizes[-1];
            # returns the ModuleList and the width of its final output
            # (in_size itself when out_sizes is empty).
            layers = nn.ModuleList()
            for out_size in out_sizes:
                layers.append(nn.Linear(in_size, out_size))
                in_size = out_size
            return layers, in_size

        self.enc_in_size = enc_in_size
        self.dec_in_size = dec_in_size
        self.dec_out_size = dec_out_size  # mel spec dim size
        self.enc_emb_size = enc_emb_size
        self.enc_emb_do = enc_emb_do
        self.enc_prenet_size = enc_prenet_size
        self.enc_prenet_do = ConfigParser.list_parser(enc_prenet_do,
                                                      len(enc_prenet_size))
        self.enc_prenet_fn = enc_prenet_fn
        self.dec_prenet_size = dec_prenet_size
        self.dec_prenet_do = ConfigParser.list_parser(dec_prenet_do,
                                                      len(dec_prenet_size))
        self.dec_prenet_fn = dec_prenet_fn
        self.dec_rnn_sizes = dec_rnn_sizes
        self.dec_rnn_cfgs = dec_rnn_cfgs
        self.dec_rnn_do = dec_rnn_do
        self.dec_core_gen_size = dec_core_gen_size
        self.dec_core_gen_fn = dec_core_gen_fn
        self.dec_core_gen_do = ConfigParser.list_parser(
            dec_core_gen_do, len(dec_core_gen_size))
        self.dec_cfg = dec_cfg
        self.att_cfg = att_cfg

        # FRAME ENDING #
        self.dec_bern_end_size = dec_bern_end_size
        self.dec_bern_end_fn = dec_bern_end_fn
        # one dropout rate per hidden layer, indexed per layer below
        self.dec_bern_end_do = ConfigParser.list_parser(
            dec_bern_end_do, len(dec_bern_end_size))

        # OPTIONAL #
        self.dec_in_range = dec_in_range
        if self.dec_in_range is not None:
            assert isinstance(self.dec_in_range, (list, tuple)) \
                    and len(self.dec_in_range) == 2

        # CBHG config #
        self.enc_cbhg_cfg = ConfigParser.item_parser(enc_cbhg_cfg)

        self.enc_emb_lyr = nn.Embedding(enc_in_size, enc_emb_size)
        # init enc prenet #
        self.enc_prenet_lyr, prev_size = _linear_chain(
            enc_emb_size, self.enc_prenet_size)
        # init enc middle #
        # splat the parsed config, not the raw argument
        self.enc_core_lyr = cbhg.CBHG1d(prev_size, **self.enc_cbhg_cfg)
        # init dec prenet #
        if self.dec_in_range is None:
            dec_prenet_in_size = dec_in_size
        else:
            # number of dims kept by input[..., start:stop] on a
            # dec_in_size-wide input (assumes forward slices this way --
            # TODO confirm); handles None and negative bounds
            dec_prenet_in_size = len(
                range(dec_in_size)[slice(*self.dec_in_range)])
            if dec_prenet_in_size <= 0:
                raise ValueError(
                    'dec_in_range {} selects no dims out of {}'.format(
                        self.dec_in_range, dec_in_size))
        self.dec_prenet_lyr, prev_size = _linear_chain(
            dec_prenet_in_size, self.dec_prenet_size)

        # init dec rnn #
        # copy each per-layer config so rewriting 'type' never mutates the
        # caller's dict or the shared mutable default argument
        _dec_rnn_cfgs = [
            dict(cfg) for cfg in ConfigParser.list_parser(
                dec_rnn_cfgs, len(dec_rnn_sizes))
        ]
        for cfg in _dec_rnn_cfgs:
            # e.g. 'lstm' -> 'stateful_lstmcell' (presumably because the
            # decoder is stepped one frame at a time)
            if re.match('stateful.*cell', cfg['type']) is None:
                cfg['type'] = 'stateful_{}cell'.format(cfg['type'])
        # TODO : dec_cfg #
        final_enc_size = self.enc_core_lyr.output_size
        # pass the normalized configs, not the raw argument
        self.dec_att_lyr = decoder.StandardDecoder(att_cfg, final_enc_size,
                                                   prev_size, dec_rnn_sizes,
                                                   _dec_rnn_cfgs, dec_rnn_do)

        # init decoder layer melspec generator #
        self.dec_core_gen_lyr, prev_size = _linear_chain(
            self.dec_att_lyr.output_size, self.dec_core_gen_size)
        self.dec_core_gen_lyr.append(nn.Linear(prev_size, self.dec_out_size))

        # init decoder frame ending predictor #
        # p(t=STOP | dec_hid[t], y[t]) #
        # input is the decoder output concatenated with the generated frame
        bern_end_layers = []
        prev_size = self.dec_att_lyr.output_size + self.dec_out_size
        for ii, hid_size in enumerate(self.dec_bern_end_size):
            bern_end_layers.append(nn.Linear(prev_size, hid_size))
            bern_end_layers.append(generator_act_module(self.dec_bern_end_fn))
            bern_end_layers.append(nn.Dropout(p=self.dec_bern_end_do[ii]))
            prev_size = hid_size
        bern_end_layers.append(nn.Linear(prev_size, 1))
        # output is logit, not transformed into sigmoid #
        self.dec_bernoulli_end_lyr = nn.Sequential(*bern_end_layers)
Beispiel #2
0
    def __init__(
            self,
            enc_in_size,
            dec_in_size,
            dec_out_size,
            dec_out_post_size,
            enc_emb_size=256,
            enc_emb_do=0.0,
            enc_prenet_size=[256, 128],
            enc_prenet_do=[0.5, 0.5],
            enc_prenet_fn='leaky_relu',
            dec_prenet_size=[256, 128],
            dec_prenet_do=[0.5, 0.5],
            dec_prenet_fn='leaky_relu',
            dec_rnn_sizes=[256, 256],
            dec_rnn_cfgs={"type": "lstm"},
            dec_rnn_do=0.0,
            dec_cfg={"type": "standard_decoder"},
            att_cfg={"type": "mlp"},
            # CBHG #
            enc_cbhg_cfg={},
            dec_postnet_cbhg_cfg={},
            # OPTIONAL #
            dec_in_range=None):
        """Build the text encoder, attention decoder, first (mel) regression
        layer, CBHG postnet and second (linear spectrogram) regression layer.

        Args:
            enc_in_size : size of vocab (rows of the input embedding)
            dec_in_size : input (mel) dim size
            dec_out_size : output (mel) dim size (usually same as dec_in_size)
            dec_out_post_size : output (linear) dim size
            enc_emb_size : embedding size of the encoder input tokens
            enc_emb_do : embedding dropout (stored; not used in __init__)
            enc_prenet_size : hidden sizes of the encoder prenet Linear stack
            enc_prenet_do : prenet dropout, a scalar or one value per layer
            enc_prenet_fn : prenet activation name (stored; not used here)
            dec_prenet_size, dec_prenet_do, dec_prenet_fn :
                same as the enc_prenet_* arguments, for the decoder prenet
            dec_rnn_sizes : hidden size of each decoder RNN layer
            dec_rnn_cfgs : RNN config dict, or one dict per layer, each with
                a 'type' key; types that do not already match
                'stateful.*cell' are rewritten to 'stateful_<type>cell'
                (on a copy -- the argument itself is left untouched)
            dec_rnn_do : RNN dropout forwarded to the decoder
            dec_cfg : decoder config (stored; a StandardDecoder is always
                built for now, see the TODO below)
            att_cfg : attention config forwarded to the decoder
            enc_cbhg_cfg : keyword arguments for the encoder CBHG1d,
                parsed with ConfigParser.item_parser
            dec_postnet_cbhg_cfg : keyword arguments for the postnet CBHG1d,
                parsed with ConfigParser.item_parser; must not contain
                'conv_proj_filter', which is fixed here
            dec_in_range :
                optional pair [start, stop] selecting which decoder input
                dims feed the prenet, using Python slice semantics (either
                end may be None; negative values count from dec_in_size).
                In the Tacotron paper, they only use the last time-step
                instead of the whole group.
        """
        super().__init__()

        def _linear_chain(in_size, out_sizes):
            # nn.Linear stack in_size -> out_sizes[0] -> ... -> out_sizes[-1];
            # returns the ModuleList and the width of its final output
            # (in_size itself when out_sizes is empty).
            layers = nn.ModuleList()
            for out_size in out_sizes:
                layers.append(nn.Linear(in_size, out_size))
                in_size = out_size
            return layers, in_size

        self.enc_in_size = enc_in_size
        self.dec_in_size = dec_in_size
        self.dec_out_size = dec_out_size  # first output -> mel spec #
        self.dec_out_post_size = dec_out_post_size  # second output -> raw spec #
        self.enc_emb_size = enc_emb_size
        self.enc_emb_do = enc_emb_do
        self.enc_prenet_size = enc_prenet_size
        self.enc_prenet_do = ConfigParser.list_parser(enc_prenet_do,
                                                      len(enc_prenet_size))
        self.enc_prenet_fn = enc_prenet_fn
        self.dec_prenet_size = dec_prenet_size
        self.dec_prenet_do = ConfigParser.list_parser(dec_prenet_do,
                                                      len(dec_prenet_size))
        self.dec_prenet_fn = dec_prenet_fn
        self.dec_rnn_sizes = dec_rnn_sizes
        self.dec_rnn_cfgs = dec_rnn_cfgs
        self.dec_rnn_do = dec_rnn_do
        self.dec_cfg = dec_cfg
        self.att_cfg = att_cfg

        # OPTIONAL #
        self.dec_in_range = dec_in_range
        if self.dec_in_range is not None:
            assert isinstance(self.dec_in_range, (list, tuple)) \
                    and len(self.dec_in_range) == 2

        # CBHG config #
        self.enc_cbhg_cfg = ConfigParser.item_parser(enc_cbhg_cfg)
        self.dec_postnet_cbhg_cfg = ConfigParser.item_parser(
            dec_postnet_cbhg_cfg)

        self.enc_emb_lyr = nn.Embedding(enc_in_size, enc_emb_size)
        # init enc prenet #
        self.enc_prenet_lyr, prev_size = _linear_chain(
            enc_emb_size, self.enc_prenet_size)
        # init enc middle #
        # splat the parsed config, not the raw argument
        self.enc_core_lyr = cbhg.CBHG1d(prev_size, **self.enc_cbhg_cfg)
        # init dec prenet #
        if self.dec_in_range is None:
            dec_prenet_in_size = dec_in_size
        else:
            # number of dims kept by input[..., start:stop] on a
            # dec_in_size-wide input (assumes forward slices this way --
            # TODO confirm); handles None and negative bounds
            dec_prenet_in_size = len(
                range(dec_in_size)[slice(*self.dec_in_range)])
            if dec_prenet_in_size <= 0:
                raise ValueError(
                    'dec_in_range {} selects no dims out of {}'.format(
                        self.dec_in_range, dec_in_size))
        self.dec_prenet_lyr, prev_size = _linear_chain(
            dec_prenet_in_size, self.dec_prenet_size)

        # init dec rnn #
        # copy each per-layer config so rewriting 'type' never mutates the
        # caller's dict or the shared mutable default argument
        _dec_rnn_cfgs = [
            dict(cfg) for cfg in ConfigParser.list_parser(
                dec_rnn_cfgs, len(dec_rnn_sizes))
        ]
        for cfg in _dec_rnn_cfgs:
            # e.g. 'lstm' -> 'stateful_lstmcell' (presumably because the
            # decoder is stepped one frame at a time)
            if re.match('stateful.*cell', cfg['type']) is None:
                cfg['type'] = 'stateful_{}cell'.format(cfg['type'])
        # TODO : dec_cfg #
        final_enc_size = self.enc_core_lyr.output_size
        # pass the normalized configs, not the raw argument
        self.dec_att_lyr = decoder.StandardDecoder(att_cfg, final_enc_size,
                                                   prev_size, dec_rnn_sizes,
                                                   _dec_rnn_cfgs, dec_rnn_do)

        # init dec regression melspec #
        self.dec_first_reg_lyr = nn.Linear(self.dec_att_lyr.output_size,
                                           self.dec_out_size)
        # init dec postnet #
        # the last conv projection is pinned to dec_out_size (presumably so
        # the CBHG residual connection matches its input width -- confirm)
        self.dec_postnet_lyr = cbhg.CBHG1d(
            self.dec_out_size,
            conv_proj_filter=[256, self.dec_out_size],
            **self.dec_postnet_cbhg_cfg)
        # init dec regression rawspec #
        self.dec_second_reg_lyr = nn.Linear(self.dec_postnet_lyr.output_size,
                                            self.dec_out_post_size)