Exemplo n.º 1
0
    def __init__(self, enc_in_size, dec_in_size, n_class,
                 enc_fnn_sizes=None, enc_fnn_act='tanh', enc_fnn_do=0.25,
                 enc_rnn_sizes=None, enc_rnn_cfgs=None, enc_rnn_do=0.25,
                 downsampling=None,
                 dec_emb_size=64, dec_emb_do=0.0,
                 dec_rnn_sizes=None, dec_rnn_cfgs=None, dec_rnn_do=0.25,
                 dec_cfg=None,
                 att_cfg=None,
                 ):
        """Build an encoder(FNN+RNN) / attention / decoder(RNN) ASR model.

        Args:
            enc_in_size: encoder input feature dimension.
            dec_in_size: decoder input vocabulary size (embedding rows).
            n_class: number of output classes for the pre-softmax projection.
            enc_fnn_sizes: encoder feedforward hidden sizes (default [512]).
            enc_fnn_act: activation name for the encoder FNN stack.
            enc_fnn_do: encoder FNN dropout (scalar broadcast per layer).
            enc_rnn_sizes: encoder RNN hidden sizes (default [256, 256, 256]).
            enc_rnn_cfgs: encoder RNN config (default bidirectional LSTM).
            enc_rnn_do: encoder RNN dropout (scalar broadcast per layer).
            downsampling: per-encoder-RNN-layer downsampling config, or None.
            dec_emb_size: decoder embedding size.
            dec_emb_do: decoder embedding dropout.
            dec_rnn_sizes: decoder RNN hidden sizes (default [512, 512]).
            dec_rnn_cfgs: decoder RNN config (default LSTM).
            dec_rnn_do: decoder RNN dropout.
            dec_cfg: decoder type config (default standard_decoder; only the
                standard decoder is instantiated — see TODO below).
            att_cfg: attention config (default MLP attention).
        """
        # Materialize container defaults inside the body: the original used
        # mutable default arguments, and the decoder rnn cfg dicts are
        # rewritten below ('stateful_*cell'), which risks mutating a default
        # shared across instances.
        enc_fnn_sizes = [512] if enc_fnn_sizes is None else enc_fnn_sizes
        enc_rnn_sizes = ([256, 256, 256] if enc_rnn_sizes is None
                         else enc_rnn_sizes)
        enc_rnn_cfgs = ({"type": "lstm", "bi": True} if enc_rnn_cfgs is None
                        else enc_rnn_cfgs)
        dec_rnn_sizes = [512, 512] if dec_rnn_sizes is None else dec_rnn_sizes
        dec_rnn_cfgs = ({"type": "lstm"} if dec_rnn_cfgs is None
                        else dec_rnn_cfgs)
        dec_cfg = ({"type": "standard_decoder"} if dec_cfg is None
                   else dec_cfg)
        att_cfg = {"type": "mlp"} if att_cfg is None else att_cfg

        super().__init__()

        self.enc_in_size = enc_in_size
        self.dec_in_size = dec_in_size
        self.n_class = n_class
        self.enc_fnn_sizes = enc_fnn_sizes
        self.enc_fnn_act = enc_fnn_act
        # list_parser presumably broadcasts a scalar config into a per-layer
        # list of the given length — TODO confirm against ConfigParser.
        self.enc_fnn_do = ConfigParser.list_parser(enc_fnn_do,
                                                   len(enc_fnn_sizes))
        self.enc_rnn_sizes = enc_rnn_sizes
        self.enc_rnn_cfgs = enc_rnn_cfgs
        self.enc_rnn_do = ConfigParser.list_parser(enc_rnn_do,
                                                   len(enc_rnn_sizes))
        self.downsampling = ConfigParser.list_parser(downsampling,
                                                     len(enc_rnn_sizes))

        self.dec_emb_size = dec_emb_size
        self.dec_emb_do = dec_emb_do
        self.dec_rnn_sizes = dec_rnn_sizes
        self.dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                     len(dec_rnn_sizes))
        self.dec_rnn_do = ConfigParser.list_parser(dec_rnn_do,
                                                   len(dec_rnn_sizes))
        self.dec_cfg = dec_cfg
        self.att_cfg = att_cfg

        # modules #
        # init encoder: feedforward stack #
        self.enc_fnn = nn.ModuleList()
        prev_size = enc_in_size
        for fnn_size in enc_fnn_sizes:
            self.enc_fnn.append(nn.Linear(prev_size, fnn_size))
            prev_size = fnn_size

        # init encoder: (bi)RNN stack #
        self.enc_rnn = nn.ModuleList()
        _enc_rnn_cfgs = ConfigParser.list_parser(enc_rnn_cfgs,
                                                 len(enc_rnn_sizes))
        for ii in range(len(enc_rnn_sizes)):
            _rnn_cfg = {
                'type': _enc_rnn_cfgs[ii]['type'],
                # args order presumed: input size, hidden size, num layers,
                # bias, batch_first, dropout, bidirectional — TODO confirm
                # against generator_rnn.
                'args': [prev_size, enc_rnn_sizes[ii], 1, True, True, 0,
                         _enc_rnn_cfgs[ii]['bi']],
            }
            self.enc_rnn.append(generator_rnn(_rnn_cfg))
            # bidirectional layers emit twice the hidden size #
            prev_size = enc_rnn_sizes[ii] * (2 if _enc_rnn_cfgs[ii]['bi']
                                             else 1)
        final_enc_size = prev_size

        # init decoder #
        self.dec_emb = nn.Embedding(self.dec_in_size, dec_emb_size,
                                    padding_idx=None)
        # rewrite decoder cell types to their stateful variants #
        _dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                 len(dec_rnn_sizes))
        for ii in range(len(dec_rnn_sizes)):
            _type = _dec_rnn_cfgs[ii]['type']
            if re.match('stateful.*cell', _type) is None:
                _dec_rnn_cfgs[ii]['type'] = 'stateful_{}cell'.format(_type)
        # TODO : dec_cfg #
        self.dec = decoder.StandardDecoder(att_cfg, final_enc_size,
                                           dec_emb_size, dec_rnn_sizes,
                                           _dec_rnn_cfgs, dec_rnn_do)
        self.pre_softmax = nn.Linear(self.dec.output_size, n_class)
Exemplo n.º 2
0
    def __init__(self,
                 enc_in_size,
                 dec_in_size,
                 dec_out_size,
                 dec_rnn_sizes=None,
                 dec_rnn_cfgs=None,
                 dec_rnn_do=0.25,
                 dec_cfg=None,
                 att_cfg=None):
        """Build an RNN-encoder / attention RNN-decoder regression model.

        Args:
            enc_in_size: encoder input feature dimension.
            dec_in_size: decoder (conditional) input feature dimension.
            dec_out_size: regression output dimension.
            dec_rnn_sizes: decoder RNN hidden sizes (default [512, 512]).
            dec_rnn_cfgs: decoder RNN config (default LSTM).
            dec_rnn_do: decoder dropout (scalar broadcast per layer).
            dec_cfg: decoder type config (default standard_decoder).
            att_cfg: attention config (default MLP attention).
        """
        # Materialize container defaults in the body: the original used
        # mutable default arguments, and the dec rnn cfg dicts are rewritten
        # below, which risks mutating a default shared across instances.
        dec_rnn_sizes = [512, 512] if dec_rnn_sizes is None else dec_rnn_sizes
        dec_rnn_cfgs = ({'type': 'lstm'} if dec_rnn_cfgs is None
                        else dec_rnn_cfgs)
        dec_cfg = ({'type': 'standard_decoder'} if dec_cfg is None
                   else dec_cfg)
        att_cfg = {'type': 'mlp'} if att_cfg is None else att_cfg

        super().__init__()
        self.enc_in_size = enc_in_size
        self.dec_in_size = dec_in_size
        self.dec_out_size = dec_out_size

        self.dec_rnn_sizes = dec_rnn_sizes
        self.dec_rnn_cfgs = dec_rnn_cfgs
        self.dec_rnn_do = ConfigParser.list_parser(dec_rnn_do,
                                                   len(dec_rnn_sizes))
        self.dec_cfg = dec_cfg
        self.att_cfg = att_cfg

        # init encoder (fixed config: no dropout, take-last downsampling
        # every 2 steps) #
        self.enc_lyr = encoder.StandardRNNEncoder(enc_in_size,
                                                  do=0.0,
                                                  downsampling={
                                                      'type': 'last',
                                                      'step': 2
                                                  })

        # init decoder: rewrite cell types to their stateful variants #
        _dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                 len(dec_rnn_sizes))
        for ii in range(len(dec_rnn_sizes)):
            _type = _dec_rnn_cfgs[ii]['type']
            if re.match('stateful.*cell', _type) is None:
                _dec_rnn_cfgs[ii]['type'] = 'stateful_{}cell'.format(_type)

        self.dec_lyr = decoder.StandardDecoder(att_cfg,
                                               self.enc_lyr.out_features,
                                               dec_in_size, dec_rnn_sizes,
                                               _dec_rnn_cfgs, self.dec_rnn_do)

        # init decoder regression head #
        self.dec_core_reg_lyr = nn.Linear(self.dec_lyr.out_features,
                                          dec_out_size)
Exemplo n.º 3
0
    def __init__(
            self,
            enc_in_size,
            dec_in_size,
            dec_out_size,
            enc_emb_size=256,
            enc_emb_do=0.0,
            enc_prenet_size=None,
            enc_prenet_do=None,
            enc_prenet_fn='leaky_relu',
            dec_prenet_size=None,
            dec_prenet_do=None,
            dec_prenet_fn='leaky_relu',
            dec_rnn_sizes=None,
            dec_rnn_cfgs=None,
            dec_rnn_do=0.0,
            dec_cfg=None,
            att_cfg=None,
            dec_core_gen_size=None,
            dec_core_gen_fn='leaky_relu',
            dec_core_gen_do=0.0,
            # CBHG #
            enc_cbhg_cfg=None,
            # FRAME ENDING #
            dec_bern_end_size=None,
            dec_bern_end_fn='LeakyReLU',
            dec_bern_end_do=0.0,
            # OPTIONAL #
            dec_in_range=None):
        r"""Build a Tacotron-style text-to-spectrogram model.

        Args:
            enc_in_size : size of vocab
            dec_in_size : input (mel) dim size
            dec_out_size : output (mel) dim size (usually same as dec_in_size)
            dec_in_range :
                pair of integer [x, y] \in [0, dec_in_size],
                all dims outside this pair will be masked as 0
                in Tacotron paper, they only use last time-step instead of
                all group
        """
        # Materialize container defaults in the body (avoid shared mutable
        # default arguments; the dec rnn cfg dicts are rewritten below).
        enc_prenet_size = ([256, 128] if enc_prenet_size is None
                           else enc_prenet_size)
        enc_prenet_do = [0.5, 0.5] if enc_prenet_do is None else enc_prenet_do
        dec_prenet_size = ([256, 128] if dec_prenet_size is None
                           else dec_prenet_size)
        dec_prenet_do = [0.5, 0.5] if dec_prenet_do is None else dec_prenet_do
        dec_rnn_sizes = [256, 256] if dec_rnn_sizes is None else dec_rnn_sizes
        dec_rnn_cfgs = ({"type": "lstm"} if dec_rnn_cfgs is None
                        else dec_rnn_cfgs)
        dec_cfg = ({"type": "standard_decoder"} if dec_cfg is None
                   else dec_cfg)
        att_cfg = {"type": "mlp"} if att_cfg is None else att_cfg
        dec_core_gen_size = ([512] if dec_core_gen_size is None
                             else dec_core_gen_size)
        enc_cbhg_cfg = {} if enc_cbhg_cfg is None else enc_cbhg_cfg
        dec_bern_end_size = ([256] if dec_bern_end_size is None
                             else dec_bern_end_size)

        super().__init__()
        self.enc_in_size = enc_in_size
        self.dec_in_size = dec_in_size
        self.dec_out_size = dec_out_size  # mel spec dim size
        self.enc_emb_size = enc_emb_size
        self.enc_emb_do = enc_emb_do
        self.enc_prenet_size = enc_prenet_size
        self.enc_prenet_do = ConfigParser.list_parser(enc_prenet_do,
                                                      len(enc_prenet_size))
        self.enc_prenet_fn = enc_prenet_fn
        self.dec_prenet_size = dec_prenet_size
        self.dec_prenet_do = ConfigParser.list_parser(dec_prenet_do,
                                                      len(dec_prenet_size))
        self.dec_prenet_fn = dec_prenet_fn
        self.dec_rnn_sizes = dec_rnn_sizes
        self.dec_rnn_cfgs = dec_rnn_cfgs
        self.dec_rnn_do = dec_rnn_do
        self.dec_core_gen_size = dec_core_gen_size
        self.dec_core_gen_fn = dec_core_gen_fn
        self.dec_core_gen_do = ConfigParser.list_parser(
            dec_core_gen_do, len(dec_core_gen_size))
        self.dec_cfg = dec_cfg
        self.att_cfg = att_cfg

        # FRAME ENDING #
        self.dec_bern_end_size = dec_bern_end_size
        self.dec_bern_end_fn = dec_bern_end_fn
        # FIX: the original dropped the length argument here, unlike every
        # other list_parser call in this file.
        self.dec_bern_end_do = ConfigParser.list_parser(
            dec_bern_end_do, len(dec_bern_end_size))

        # OPTIONAL #
        self.dec_in_range = dec_in_range
        if self.dec_in_range is not None:
            assert isinstance(self.dec_in_range, (list, tuple)) \
                    and len(self.dec_in_range) == 2

        # CBHG config #
        self.enc_cbhg_cfg = ConfigParser.item_parser(enc_cbhg_cfg)

        self.enc_emb_lyr = nn.Embedding(enc_in_size, enc_emb_size)
        # init enc prenet #
        self.enc_prenet_lyr = nn.ModuleList()
        prev_size = enc_emb_size
        for ii in range(len(self.enc_prenet_size)):
            self.enc_prenet_lyr.append(
                nn.Linear(prev_size, enc_prenet_size[ii]))
            prev_size = enc_prenet_size[ii]
        # init enc middle #
        self.enc_core_lyr = cbhg.CBHG1d(prev_size, **enc_cbhg_cfg)
        # init dec prenet #
        self.dec_prenet_lyr = nn.ModuleList()
        # when dec_in_range is set, only that slice of the input is fed in #
        prev_size = dec_in_size if self.dec_in_range is None else (
            (self.dec_in_range[-1] or 0) - (self.dec_in_range[-2] or 0))
        for ii in range(len(self.dec_prenet_size)):
            self.dec_prenet_lyr.append(
                nn.Linear(prev_size, dec_prenet_size[ii]))
            prev_size = dec_prenet_size[ii]

        # init dec rnn: rewrite cell types to their stateful variants #
        _dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                 len(dec_rnn_sizes))
        for ii in range(len(dec_rnn_sizes)):
            _type = _dec_rnn_cfgs[ii]['type']
            if re.match('stateful.*cell', _type) is None:
                _dec_rnn_cfgs[ii]['type'] = 'stateful_{}cell'.format(_type)
        # TODO : dec_cfg #
        final_enc_size = self.enc_core_lyr.output_size
        # FIX: pass the rewritten _dec_rnn_cfgs (the original passed the raw
        # dec_rnn_cfgs, leaving the stateful-cell rewrite above unused;
        # every sibling model in this file passes _dec_rnn_cfgs).
        self.dec_att_lyr = decoder.StandardDecoder(att_cfg, final_enc_size,
                                                   prev_size, dec_rnn_sizes,
                                                   _dec_rnn_cfgs, dec_rnn_do)

        # init decoder layer melspec generator #
        prev_size = self.dec_att_lyr.output_size
        self.dec_core_gen_lyr = nn.ModuleList()
        for ii in range(len(self.dec_core_gen_size)):
            self.dec_core_gen_lyr.append(
                nn.Linear(prev_size, self.dec_core_gen_size[ii]))
            prev_size = self.dec_core_gen_size[ii]
        self.dec_core_gen_lyr.append(nn.Linear(prev_size, self.dec_out_size))

        # init decoder frame ending predictor #
        # p(t=STOP | dec_hid[t], y[t]) #
        _tmp = []
        prev_size = self.dec_att_lyr.output_size + self.dec_out_size
        for ii in range(len(dec_bern_end_size)):
            _tmp.append(nn.Linear(prev_size, self.dec_bern_end_size[ii]))
            _tmp.append(generator_act_module(self.dec_bern_end_fn))
            _tmp.append(nn.Dropout(p=self.dec_bern_end_do[ii]))
            prev_size = self.dec_bern_end_size[ii]
        _tmp.append(nn.Linear(prev_size, 1))
        # output is logit, not transformed into sigmoid #
        self.dec_bernoulli_end_lyr = nn.Sequential(*_tmp)
Exemplo n.º 4
0
    def __init__(
            self,
            enc_in_size,
            dec_in_size,
            dec_out_size,
            enc_emb_size=256,
            enc_emb_do=0.0,
            enc_conv_sizes=None,
            enc_conv_filter=None,
            enc_conv_do=0.25,
            enc_conv_fn='LeakyReLU',
            enc_rnn_sizes=None,
            enc_rnn_cfgs=None,
            enc_rnn_do=0.2,
            dec_prenet_size=None,
            dec_prenet_fn='leaky_relu',
            dec_prenet_do=0.25,
            dec_rnn_sizes=None,
            dec_rnn_cfgs=None,
            dec_rnn_do=0.2,
            dec_proj_size=None,
            dec_proj_fn='leaky_relu',
            dec_proj_do=0.25,
            dec_bern_end_size=None,
            dec_bern_end_do=0.0,
            dec_bern_end_fn='LeakyReLU',
            dec_cfg=None,
            att_cfg=None,  # default: location-sensitive (mlp_history) attention
            # OPTIONAL #
            dec_in_range=None,
            use_bn=False,  # Tacotron V2 default activate BatchNorm
            use_ln=False,  # Use layer-normalization on feedforward
    ):
        r"""
        Tacotron V2

        Decoder generates 2 outputs mel + linear spec, use main for
        conditional input next step
        Args:
            enc_in_size : size of vocab
            dec_in_size : input (mel) dim size
            dec_out_size : output (mel/linear) dim size (usually same as
                dec_in_size)
            dec_in_range :
                pair of integer [x, y] \in [0, dec_in_size],
                all dims outside this pair will be masked as 0
                in Tacotron paper, they only use last time-step instead of
                all group
        """
        # Materialize container defaults in the body (avoid shared mutable
        # default arguments; the dec rnn cfg dicts are rewritten below).
        enc_conv_sizes = [5, 5, 5] if enc_conv_sizes is None else enc_conv_sizes
        enc_conv_filter = ([256, 256, 256] if enc_conv_filter is None
                           else enc_conv_filter)
        enc_rnn_sizes = [256] if enc_rnn_sizes is None else enc_rnn_sizes
        enc_rnn_cfgs = ({"type": "lstm", "bi": True} if enc_rnn_cfgs is None
                        else enc_rnn_cfgs)
        dec_prenet_size = ([256, 256] if dec_prenet_size is None
                           else dec_prenet_size)
        dec_rnn_sizes = [512, 512] if dec_rnn_sizes is None else dec_rnn_sizes
        dec_rnn_cfgs = ({"type": "lstm"} if dec_rnn_cfgs is None
                        else dec_rnn_cfgs)
        dec_proj_size = [512, 512] if dec_proj_size is None else dec_proj_size
        dec_bern_end_size = ([256] if dec_bern_end_size is None
                             else dec_bern_end_size)
        dec_cfg = ({"type": "standard_decoder"} if dec_cfg is None
                   else dec_cfg)
        if att_cfg is None:
            att_cfg = {
                "type": "mlp_history",
                "kwargs": {
                    "history_conv_ksize": [2, 4, 8]
                }
            }

        super().__init__()
        self.enc_in_size = enc_in_size
        self.dec_in_size = dec_in_size
        self.dec_out_size = dec_out_size  # output projection -> mel/linear spec #

        self.enc_emb_size = enc_emb_size
        self.enc_emb_do = enc_emb_do

        self.enc_conv_sizes = enc_conv_sizes
        self.enc_conv_filter = enc_conv_filter
        self.enc_conv_do = ConfigParser.list_parser(enc_conv_do,
                                                    len(enc_conv_sizes))
        self.enc_conv_fn = enc_conv_fn

        self.enc_rnn_sizes = enc_rnn_sizes
        self.enc_rnn_do = ConfigParser.list_parser(enc_rnn_do,
                                                   len(enc_rnn_sizes))
        self.enc_rnn_cfgs = ConfigParser.list_parser(enc_rnn_cfgs,
                                                     len(enc_rnn_sizes))

        self.dec_prenet_size = dec_prenet_size
        self.dec_prenet_do = ConfigParser.list_parser(dec_prenet_do,
                                                      len(dec_prenet_size))
        self.dec_prenet_fn = dec_prenet_fn
        self.dec_rnn_sizes = dec_rnn_sizes
        self.dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                     len(dec_rnn_sizes))
        self.dec_rnn_do = ConfigParser.list_parser(dec_rnn_do,
                                                   len(dec_rnn_sizes))

        self.dec_proj_size = dec_proj_size
        self.dec_proj_fn = dec_proj_fn
        self.dec_proj_do = ConfigParser.list_parser(dec_proj_do,
                                                    len(dec_proj_size))

        self.dec_bern_end_size = dec_bern_end_size
        self.dec_bern_end_do = ConfigParser.list_parser(
            dec_bern_end_do, len(dec_bern_end_size))
        self.dec_bern_end_fn = dec_bern_end_fn

        self.dec_cfg = dec_cfg
        self.att_cfg = att_cfg
        self.use_bn = use_bn
        self.use_ln = use_ln
        if use_ln:
            raise ValueError("Layer Normalization is not supported yet!")

        # OPTIONAL #
        self.dec_in_range = dec_in_range
        if self.dec_in_range is not None:
            assert isinstance(self.dec_in_range, (list, tuple)) \
                    and len(self.dec_in_range) == 2
        ### FINISH ###

        # init emb layer
        self.enc_emb_lyr = nn.Embedding(enc_in_size, enc_emb_size)

        # init enc conv #
        _tmp = []
        prev_size = enc_emb_size
        for ii in range(len(self.enc_conv_sizes)):
            _tmp.append(
                Conv1dEv(prev_size,
                         self.enc_conv_filter[ii],
                         self.enc_conv_sizes[ii],
                         padding='same'))
            _tmp.append(generator_act_module(self.enc_conv_fn))
            if self.use_bn:
                _tmp.append(nn.BatchNorm1d(self.enc_conv_filter[ii]))
            _tmp.append(nn.Dropout(p=self.enc_conv_do[ii]))
            prev_size = self.enc_conv_filter[ii]
        self.enc_conv_lyr = nn.Sequential(*_tmp)

        # init enc rnn #
        self.enc_rnn_lyr = nn.ModuleList()
        _enc_rnn_cfgs = ConfigParser.list_parser(enc_rnn_cfgs,
                                                 len(enc_rnn_sizes))
        for ii in range(len(self.enc_rnn_sizes)):
            _rnn_cfg = {}
            _rnn_cfg['type'] = _enc_rnn_cfgs[ii]['type']
            _rnn_cfg['args'] = [
                prev_size, enc_rnn_sizes[ii], 1, True, True, 0,
                _enc_rnn_cfgs[ii]['bi']
            ]
            self.enc_rnn_lyr.append(generator_rnn(_rnn_cfg))
            # FIX: bidirectional layers emit twice the hidden size; the
            # original forgot the x2 here (it does account for it in
            # final_enc_size below, and the sibling models in this file all
            # apply it per layer). Only affects stacks deeper than 1 layer.
            prev_size = enc_rnn_sizes[ii] * (2 if _enc_rnn_cfgs[ii]['bi']
                                             else 1)

        # init dec prenet #
        _tmp = []
        # when dec_in_range is set, only that slice of the input is fed in #
        prev_size = dec_in_size if self.dec_in_range is None else (
            (self.dec_in_range[-1] or 0) - (self.dec_in_range[-2] or 0))
        for ii in range(len(self.dec_prenet_size)):
            _tmp.append(nn.Linear(prev_size, self.dec_prenet_size[ii]))
            prev_size = self.dec_prenet_size[ii]

        self.dec_prenet_lyr = nn.ModuleList(_tmp)

        # init dec rnn: rewrite cell types to their stateful variants #
        _dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                 len(dec_rnn_sizes))
        for ii in range(len(dec_rnn_sizes)):
            _type = _dec_rnn_cfgs[ii]['type']
            if re.match('stateful.*cell', _type) is None:
                _dec_rnn_cfgs[ii]['type'] = 'stateful_{}cell'.format(_type)

        final_enc_size = self.enc_rnn_lyr[-1].hidden_size * (
            2 if self.enc_rnn_lyr[-1].bidirectional else 1)
        assert 'type' in dec_cfg, "decoder type need to be defined"
        if dec_cfg['type'] == 'standard_decoder':
            _tmp_dec_cfg = dict(dec_cfg)
            del _tmp_dec_cfg['type']  #
            # FIX: pass the rewritten _dec_rnn_cfgs (the original passed the
            # raw dec_rnn_cfgs, leaving the stateful-cell rewrite unused).
            self.dec_att_lyr = decoder.StandardDecoder(att_cfg=att_cfg,
                                                       ctx_size=final_enc_size,
                                                       in_size=prev_size,
                                                       rnn_sizes=dec_rnn_sizes,
                                                       rnn_cfgs=_dec_rnn_cfgs,
                                                       rnn_do=dec_rnn_do,
                                                       **_tmp_dec_cfg)
        else:
            # FIX: fail loudly here instead of an AttributeError below when
            # self.dec_att_lyr was never assigned.
            raise ValueError(
                "unsupported decoder type: {}".format(dec_cfg['type']))

        # init dec lin proj -> mel/linear-spec
        prev_size = self.dec_att_lyr.out_features
        _tmp = []
        for ii in range(len(self.dec_proj_size)):
            _tmp.append(nn.Linear(prev_size, self.dec_proj_size[ii]))
            prev_size = self.dec_proj_size[ii]
        _tmp.append(nn.Linear(prev_size, self.dec_out_size))
        self.dec_proj_lyr = nn.ModuleList(_tmp)

        # init dec bern end layer (stop-token predictor; outputs a logit) #
        _tmp = []
        prev_size = self.dec_out_size + self.dec_att_lyr.out_features + (
            self.enc_rnn_lyr[-1].hidden_size *
            (2 if self.enc_rnn_lyr[-1].bidirectional else 1))
        for ii in range(len(self.dec_bern_end_size)):
            _tmp.append(nn.Linear(prev_size, self.dec_bern_end_size[ii]))
            _tmp.append(generator_act_module(dec_bern_end_fn))
            _tmp.append(nn.Dropout(self.dec_bern_end_do[ii]))
            prev_size = self.dec_bern_end_size[ii]
        _tmp.append(nn.Linear(prev_size, 1))
        self.dec_bern_end_lyr = nn.Sequential(*_tmp)
Exemplo n.º 5
0
    def __init__(
        self,
        enc_in_size,
        dec_in_size,
        n_class,
        enc_fnn_sizes=None,
        enc_fnn_act='tanh',
        enc_fnn_do=0.25,
        enc_cnn_channels=256,
        enc_cnn_ksize=None,
        enc_cnn_do=0.25,
        enc_cnn_strides=None,
        enc_cnn_act='leaky_relu',
        dec_emb_size=64,
        dec_emb_do=0.0,
        dec_rnn_sizes=None,
        dec_rnn_cfgs=None,
        dec_rnn_do=0.25,
        dec_cfg=None,
        att_cfg=None,
    ):
        """Build an encoder(FNN+CNN) / attention / decoder(RNN) ASR model.

        Args:
            enc_in_size: encoder input feature dimension.
            dec_in_size: decoder input vocabulary size (embedding rows).
            n_class: number of output classes for the pre-softmax projection.
            enc_fnn_sizes: encoder feedforward hidden sizes (default [512]).
            enc_fnn_act / enc_fnn_do: encoder FNN activation and dropout.
            enc_cnn_channels: conv output channels (shared by every layer).
            enc_cnn_ksize: per-layer time-axis kernel sizes (default
                [5, 5, 5, 5]).
            enc_cnn_do: conv dropout (scalar broadcast per layer).
            enc_cnn_strides: per-layer time-axis strides (default
                [1, 1, 1, 1]).
            enc_cnn_act: conv activation name.
            dec_emb_size / dec_emb_do: decoder embedding size and dropout.
            dec_rnn_sizes / dec_rnn_cfgs / dec_rnn_do: decoder RNN stack
                (defaults [512, 512], LSTM, 0.25).
            dec_cfg: decoder type config (default standard_decoder).
            att_cfg: attention config (default MLP attention).
        """
        # Materialize container defaults in the body (avoid shared mutable
        # default arguments; the dec rnn cfg dicts are rewritten below).
        enc_fnn_sizes = [512] if enc_fnn_sizes is None else enc_fnn_sizes
        enc_cnn_ksize = [5, 5, 5, 5] if enc_cnn_ksize is None else enc_cnn_ksize
        enc_cnn_strides = ([1, 1, 1, 1] if enc_cnn_strides is None
                           else enc_cnn_strides)
        dec_rnn_sizes = [512, 512] if dec_rnn_sizes is None else dec_rnn_sizes
        dec_rnn_cfgs = ({"type": "lstm"} if dec_rnn_cfgs is None
                        else dec_rnn_cfgs)
        dec_cfg = ({"type": "standard_decoder"} if dec_cfg is None
                   else dec_cfg)
        att_cfg = {"type": "mlp"} if att_cfg is None else att_cfg

        super().__init__()

        self.enc_in_size = enc_in_size
        self.dec_in_size = dec_in_size
        self.n_class = n_class
        self.enc_fnn_sizes = enc_fnn_sizes
        self.enc_fnn_act = enc_fnn_act
        self.enc_fnn_do = ConfigParser.list_parser(enc_fnn_do,
                                                   len(enc_fnn_sizes))
        self.enc_cnn_channels = enc_cnn_channels  # use same size for highway #
        self.enc_cnn_ksize = enc_cnn_ksize
        self.enc_cnn_strides = enc_cnn_strides
        self.enc_cnn_do = ConfigParser.list_parser(enc_cnn_do,
                                                   len(enc_cnn_ksize))
        self.enc_cnn_act = enc_cnn_act

        self.dec_emb_size = dec_emb_size
        self.dec_emb_do = dec_emb_do
        self.dec_rnn_sizes = dec_rnn_sizes
        self.dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                     len(dec_rnn_sizes))
        self.dec_rnn_do = ConfigParser.list_parser(dec_rnn_do,
                                                   len(dec_rnn_sizes))
        self.dec_cfg = dec_cfg
        self.att_cfg = att_cfg

        # modules #
        # init encoder: feedforward stack #
        self.enc_fnn = nn.ModuleList()
        prev_size = enc_in_size
        for ii in range(len(enc_fnn_sizes)):
            self.enc_fnn.append(nn.Linear(prev_size, enc_fnn_sizes[ii]))
            prev_size = enc_fnn_sizes[ii]

        # init encoder: conv stack over (batch x ndim x seq x 1) #
        self.enc_cnn = nn.ModuleList()
        self.use_pad1 = []
        for ii in range(len(enc_cnn_ksize)):
            self.enc_cnn.append(
                nn.Conv2d(prev_size,
                          enc_cnn_channels,
                          kernel_size=(self.enc_cnn_ksize[ii], 1),
                          stride=(self.enc_cnn_strides[ii], 1),
                          padding=((self.enc_cnn_ksize[ii] - 1) // 2, 0)))
            # even kernels need one extra pad step to keep 'same' length #
            self.use_pad1.append(self.enc_cnn_ksize[ii] % 2 == 0)
            prev_size = enc_cnn_channels

        final_enc_size = prev_size
        # init position embedding function #
        self.pos_emb = nn.Linear(1, final_enc_size)

        # init decoder #
        self.dec_emb = nn.Embedding(self.dec_in_size,
                                    dec_emb_size,
                                    padding_idx=None)
        # rewrite decoder cell types to their stateful variants #
        _dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                 len(dec_rnn_sizes))
        for ii in range(len(dec_rnn_sizes)):
            _type = _dec_rnn_cfgs[ii]['type']
            if re.match('stateful.*cell', _type) is None:
                _dec_rnn_cfgs[ii]['type'] = 'stateful_{}cell'.format(_type)
        # TODO : dec_cfg #
        self.dec = decoder.StandardDecoder(att_cfg, final_enc_size,
                                           dec_emb_size, dec_rnn_sizes,
                                           _dec_rnn_cfgs, dec_rnn_do)
        self.pre_softmax = nn.Linear(self.dec.output_size, n_class)
Exemplo n.º 6
0
    def __init__(self,
                 enc_in_size,
                 dec_in_size,
                 n_class,
                 enc_fnn_sizes=None,
                 enc_fnn_act='tanh',
                 enc_fnn_do=0.25,
                 enc_rnn_sizes=None,
                 enc_rnn_cfgs=None,
                 enc_rnn_do=0.25,
                 downsampling=None,
                 dec_emb_size=64,
                 dec_emb_do=0.0,
                 dec_rnn_sizes=None,
                 dec_rnn_cfgs=None,
                 dec_rnn_do=0.25,
                 dec_cfg=None,
                 att_cfg=None,
                 enc_prior_cfg=None,
                 dec_prior_cfg={
                     'type': 'normal',
                     'kwargs': {
                         'mu': 0,
                         'sigma': 1.0
                     }
                 }):
        """Build an attention enc-dec ASR model with an optional Bayesian
        output layer.

        Args:
            enc_in_size: encoder input feature dimension.
            dec_in_size: decoder input vocabulary size (embedding rows).
            n_class: number of output classes for the pre-softmax projection.
            enc_fnn_sizes / enc_fnn_act / enc_fnn_do: encoder FNN stack.
            enc_rnn_sizes / enc_rnn_cfgs / enc_rnn_do: encoder RNN stack
                (defaults [256, 256, 256], bidirectional LSTM, 0.25).
            downsampling: per-encoder-RNN-layer downsampling config, or None.
            dec_emb_size / dec_emb_do: decoder embedding size and dropout.
            dec_rnn_sizes / dec_rnn_cfgs / dec_rnn_do: decoder RNN stack.
            dec_cfg / att_cfg: decoder and attention configs.
            enc_prior_cfg: encoder prior config (unused here beyond storage).
            dec_prior_cfg: prior for the Bayesian output layer; None selects
                a plain nn.Linear output layer instead.
        """
        # Materialize container defaults in the body (avoid shared mutable
        # default arguments; the dec rnn cfg dicts are rewritten below).
        # NOTE(review): dec_prior_cfg keeps its dict default because None is
        # meaningful there (plain Linear head); the dict is only read here,
        # but confirm generator_bayes_rv does not mutate its argument.
        enc_fnn_sizes = [512] if enc_fnn_sizes is None else enc_fnn_sizes
        enc_rnn_sizes = ([256, 256, 256] if enc_rnn_sizes is None
                         else enc_rnn_sizes)
        enc_rnn_cfgs = ({"type": "lstm", "bi": True} if enc_rnn_cfgs is None
                        else enc_rnn_cfgs)
        dec_rnn_sizes = [512, 512] if dec_rnn_sizes is None else dec_rnn_sizes
        dec_rnn_cfgs = ({"type": "lstm"} if dec_rnn_cfgs is None
                        else dec_rnn_cfgs)
        dec_cfg = ({"type": "standard_decoder"} if dec_cfg is None
                   else dec_cfg)
        att_cfg = {"type": "mlp"} if att_cfg is None else att_cfg

        super().__init__()

        self.enc_in_size = enc_in_size
        self.dec_in_size = dec_in_size
        self.n_class = n_class
        self.enc_fnn_sizes = enc_fnn_sizes
        self.enc_fnn_act = enc_fnn_act
        self.enc_fnn_do = ConfigParser.list_parser(enc_fnn_do,
                                                   len(enc_fnn_sizes))
        self.enc_rnn_sizes = enc_rnn_sizes
        self.enc_rnn_cfgs = enc_rnn_cfgs
        self.enc_rnn_do = ConfigParser.list_parser(enc_rnn_do,
                                                   len(enc_rnn_sizes))
        self.downsampling = ConfigParser.list_parser(downsampling,
                                                     len(enc_rnn_sizes))

        self.dec_emb_size = dec_emb_size
        self.dec_emb_do = dec_emb_do
        self.dec_rnn_sizes = dec_rnn_sizes
        self.dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                     len(dec_rnn_sizes))
        self.dec_rnn_do = ConfigParser.list_parser(dec_rnn_do,
                                                   len(dec_rnn_sizes))
        self.dec_cfg = dec_cfg
        self.att_cfg = att_cfg

        self.enc_prior_cfg = enc_prior_cfg
        self.dec_prior_cfg = dec_prior_cfg

        # modules #
        # init encoder #
        # TODO : add bayesian encoder RNN #
        self.enc_fnn = nn.ModuleList()
        prev_size = enc_in_size
        for ii in range(len(enc_fnn_sizes)):
            self.enc_fnn.append(nn.Linear(prev_size, enc_fnn_sizes[ii]))
            prev_size = enc_fnn_sizes[ii]

        self.enc_rnn = nn.ModuleList()
        _enc_rnn_cfgs = ConfigParser.list_parser(enc_rnn_cfgs,
                                                 len(enc_rnn_sizes))
        for ii in range(len(enc_rnn_sizes)):
            _rnn_cfg = {}
            _rnn_cfg['type'] = _enc_rnn_cfgs[ii]['type']
            _rnn_cfg['args'] = [
                prev_size, enc_rnn_sizes[ii], 1, True, True, 0,
                _enc_rnn_cfgs[ii]['bi']
            ]
            self.enc_rnn.append(generator_rnn(_rnn_cfg))
            # bidirectional layers emit twice the hidden size #
            prev_size = enc_rnn_sizes[ii] * (2 if _enc_rnn_cfgs[ii]['bi']
                                             else 1)
        final_enc_size = prev_size
        # init decoder #
        # TODO : add bayesian decoder RNN #
        self.dec_emb = nn.Embedding(self.dec_in_size,
                                    dec_emb_size,
                                    padding_idx=None)
        # rewrite decoder cell types to their stateful variants #
        _dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                 len(dec_rnn_sizes))
        for ii in range(len(dec_rnn_sizes)):
            _type = _dec_rnn_cfgs[ii]['type']
            if re.match('stateful.*cell', _type) is None:
                _dec_rnn_cfgs[ii]['type'] = 'stateful_{}cell'.format(_type)
        # TODO : dec_cfg #
        self.dec = decoder.StandardDecoder(att_cfg, final_enc_size,
                                           dec_emb_size, dec_rnn_sizes,
                                           _dec_rnn_cfgs, dec_rnn_do)

        # output head: Bayesian linear layer when a prior is configured #
        if self.dec_prior_cfg is None:
            self.pre_softmax = nn.Linear(self.dec.output_size, n_class)
        else:
            self.pre_softmax = LinearBayes(
                self.dec.output_size,
                n_class,
                posterior_w=NormalRV(0, 0.05),
                posterior_b=NormalRV(0, 0.05),
                prior_w=generator_bayes_rv(self.dec_prior_cfg),
                prior_b=generator_bayes_rv(self.dec_prior_cfg))
Exemplo n.º 7
0
    def __init__(
        self,
        enc_in_size,
        dec_in_size,
        n_class,
        enc_cnn_sizes=None,
        enc_cnn_act='leaky_relu',
        enc_cnn_stride=None,
        enc_cnn_do=0.0,
        enc_cnn_filter=256,
        enc_cnn_gated=None,
        use_bn=False,
        enc_nin_filter=None,
        enc_rnn_sizes=None,
        enc_rnn_cfgs=None,
        enc_rnn_do=0.25,
        downsampling=None,
        dec_emb_size=64,
        dec_emb_do=0.0,
        dec_rnn_sizes=None,
        dec_rnn_cfgs=None,
        dec_rnn_do=0.25,
        dec_cfg=None,
        att_cfg=None,
    ):
        """Build an encoder(raw-waveform CNN + RNN) / attention /
        decoder(RNN) ASR model.

        Args:
            enc_in_size: encoder input feature dimension.
            dec_in_size: decoder input vocabulary size (embedding rows).
            n_class: number of output classes for the pre-softmax projection.
            enc_cnn_sizes: per-layer time-axis kernel sizes
                (default [80, 25, 10, 5]).
            enc_cnn_act: conv activation name.
            enc_cnn_stride: per-layer time-axis strides (default [4, 2, 1, 1]).
            enc_cnn_do: conv dropout (scalar broadcast per layer).
            enc_cnn_filter: conv channels (scalar broadcast per layer).
            enc_cnn_gated: per-layer flag selecting gated conv units
                (default all False).
            use_bn: enable batch norm after conv layers.
            enc_nin_filter: 1x1 network-in-network channels (default
                [128, 128]).
            enc_rnn_sizes / enc_rnn_cfgs / enc_rnn_do: encoder RNN stack
                (defaults [256, 256, 256], bidirectional LSTM, 0.25).
            downsampling: per-encoder-RNN-layer downsampling config, or None.
            dec_emb_size / dec_emb_do: decoder embedding size and dropout.
            dec_rnn_sizes / dec_rnn_cfgs / dec_rnn_do: decoder RNN stack.
            dec_cfg / att_cfg: decoder and attention configs.
        """
        # Materialize container defaults in the body (avoid shared mutable
        # default arguments; the dec rnn cfg dicts are rewritten below).
        enc_cnn_sizes = ([80, 25, 10, 5] if enc_cnn_sizes is None
                         else enc_cnn_sizes)
        enc_cnn_stride = ([4, 2, 1, 1] if enc_cnn_stride is None
                          else enc_cnn_stride)
        enc_cnn_gated = ([False, False, False, False] if enc_cnn_gated is None
                         else enc_cnn_gated)
        enc_nin_filter = ([128, 128] if enc_nin_filter is None
                          else enc_nin_filter)
        enc_rnn_sizes = ([256, 256, 256] if enc_rnn_sizes is None
                         else enc_rnn_sizes)
        enc_rnn_cfgs = ({"type": "lstm", "bi": True} if enc_rnn_cfgs is None
                        else enc_rnn_cfgs)
        dec_rnn_sizes = [512, 512] if dec_rnn_sizes is None else dec_rnn_sizes
        dec_rnn_cfgs = ({"type": "lstm"} if dec_rnn_cfgs is None
                        else dec_rnn_cfgs)
        dec_cfg = ({"type": "standard_decoder"} if dec_cfg is None
                   else dec_cfg)
        att_cfg = {"type": "mlp"} if att_cfg is None else att_cfg

        super().__init__()

        self.enc_in_size = enc_in_size
        self.dec_in_size = dec_in_size
        self.n_class = n_class
        self.enc_cnn_sizes = enc_cnn_sizes
        self.enc_cnn_act = enc_cnn_act
        self.enc_cnn_gated = ConfigParser.list_parser(enc_cnn_gated,
                                                      len(enc_cnn_sizes))
        self.enc_cnn_stride = enc_cnn_stride
        self.enc_cnn_filter = ConfigParser.list_parser(enc_cnn_filter,
                                                       len(enc_cnn_sizes))
        self.enc_cnn_do = ConfigParser.list_parser(enc_cnn_do,
                                                   len(enc_cnn_sizes))
        self.use_bn = use_bn
        self.enc_nin_filter = enc_nin_filter

        self.enc_rnn_sizes = enc_rnn_sizes  # kernel size #
        self.enc_rnn_cfgs = enc_rnn_cfgs
        self.enc_rnn_do = ConfigParser.list_parser(enc_rnn_do,
                                                   len(enc_rnn_sizes))

        self.downsampling = ConfigParser.list_parser(downsampling,
                                                     len(enc_rnn_sizes))

        self.dec_emb_size = dec_emb_size
        self.dec_emb_do = dec_emb_do
        self.dec_rnn_sizes = dec_rnn_sizes
        self.dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                     len(dec_rnn_sizes))
        self.dec_rnn_do = ConfigParser.list_parser(dec_rnn_do,
                                                   len(dec_rnn_sizes))
        self.dec_cfg = dec_cfg
        self.att_cfg = att_cfg

        # modules #
        # init encoder: (optionally gated) conv stack over raw input #
        self.enc_cnn = nn.ModuleList()
        self.enc_cnn_bn = nn.ModuleList()
        prev_ch = 1
        for ii in range(len(enc_cnn_sizes)):
            if self.enc_cnn_gated[ii]:
                _cnn_lyr = GatedConv2dLinearUnit
            else:
                _cnn_lyr = Conv2dEv
            self.enc_cnn.append(
                _cnn_lyr(prev_ch,
                         self.enc_cnn_filter[ii], (self.enc_cnn_sizes[ii], 1),
                         stride=(self.enc_cnn_stride[ii], 1),
                         padding='valid',
                         dilation=1))
            self.enc_cnn_bn.append(nn.BatchNorm2d(self.enc_cnn_filter[ii]))
            prev_ch = self.enc_cnn_filter[ii]

        # 1x1 network-in-network conv stack #
        self.enc_nin = nn.ModuleList()
        for ii in range(len(enc_nin_filter)):
            # append in place; the original rebound self.enc_nin to append's
            # return value, which is fragile.
            self.enc_nin.append(
                nn.Conv2d(prev_ch, enc_nin_filter[ii], [1, 1]))
            prev_ch = enc_nin_filter[ii]
        self.enc_raw_enc = nn.ModuleList(
            [self.enc_cnn, self.enc_cnn_bn, self.enc_nin])
        prev_size = prev_ch  # global pooling after conv #
        self.enc_rnn = nn.ModuleList()
        _enc_rnn_cfgs = ConfigParser.list_parser(enc_rnn_cfgs,
                                                 len(enc_rnn_sizes))
        for ii in range(len(enc_rnn_sizes)):
            _rnn_cfg = {}
            _rnn_cfg['type'] = _enc_rnn_cfgs[ii]['type']
            _rnn_cfg['args'] = [
                prev_size, enc_rnn_sizes[ii], 1, True, True, 0,
                _enc_rnn_cfgs[ii]['bi']
            ]
            self.enc_rnn.append(generator_rnn(_rnn_cfg))
            # bidirectional layers emit twice the hidden size #
            prev_size = enc_rnn_sizes[ii] * (2 if _enc_rnn_cfgs[ii]['bi']
                                             else 1)
        final_enc_size = prev_size
        # init decoder #
        self.dec_emb = nn.Embedding(self.dec_in_size,
                                    dec_emb_size,
                                    padding_idx=None)
        # rewrite decoder cell types to their stateful variants #
        _dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                 len(dec_rnn_sizes))
        for ii in range(len(dec_rnn_sizes)):
            _type = _dec_rnn_cfgs[ii]['type']
            if re.match('stateful.*cell', _type) is None:
                _dec_rnn_cfgs[ii]['type'] = 'stateful_{}cell'.format(_type)
        # TODO : dec_cfg #
        self.dec = decoder.StandardDecoder(att_cfg, final_enc_size,
                                           dec_emb_size, dec_rnn_sizes,
                                           _dec_rnn_cfgs, dec_rnn_do)
        self.pre_softmax = nn.Linear(self.dec.output_size, n_class)
Exemplo n.º 8
0
    def __init__(
            self,
            enc_in_size,
            dec_in_size,
            dec_out_size,
            dec_out_post_size,
            enc_emb_size=256,
            enc_emb_do=0.0,
            enc_prenet_size=[256, 128],
            enc_prenet_do=[0.5, 0.5],
            enc_prenet_fn='leaky_relu',
            dec_prenet_size=[256, 128],
            dec_prenet_do=[0.5, 0.5],
            dec_prenet_fn='leaky_relu',
            dec_rnn_sizes=[256, 256],
            dec_rnn_cfgs={"type": "lstm"},
            dec_rnn_do=0.0,
            dec_cfg={"type": "standard_decoder"},
            att_cfg={"type": "mlp"},
            # CBHG #
            enc_cbhg_cfg={},
            dec_postnet_cbhg_cfg={},
            # OPTIONAL #
            dec_in_range=None):
        """Build the Tacotron encoder/decoder module graph.

        Args:
            enc_in_size : size of vocab
            dec_in_size : input (mel) dim size
            dec_out_size : output (mel) dim size (usually same as dec_in_size)
            dec_out_post_size : output (linear) dim size
            dec_in_range :
                pair of integers [x, y] in [0, dec_in_size];
                all dims outside this pair will be masked as 0.
                In the Tacotron paper, only the last time-step is used
                instead of the whole group.
        """
        super(TACOTRON, self).__init__()
        self.enc_in_size = enc_in_size
        self.dec_in_size = dec_in_size
        self.dec_out_size = dec_out_size  # first output -> mel spec #
        self.dec_out_post_size = dec_out_post_size  # second output -> raw spec #
        self.enc_emb_size = enc_emb_size
        self.enc_emb_do = enc_emb_do
        self.enc_prenet_size = enc_prenet_size
        self.enc_prenet_do = ConfigParser.list_parser(enc_prenet_do,
                                                      len(enc_prenet_size))
        self.enc_prenet_fn = enc_prenet_fn
        self.dec_prenet_size = dec_prenet_size
        self.dec_prenet_do = ConfigParser.list_parser(dec_prenet_do,
                                                      len(dec_prenet_size))
        self.dec_prenet_fn = dec_prenet_fn
        self.dec_rnn_sizes = dec_rnn_sizes
        self.dec_rnn_cfgs = dec_rnn_cfgs
        self.dec_rnn_do = dec_rnn_do
        self.dec_cfg = dec_cfg
        self.att_cfg = att_cfg

        # OPTIONAL: restrict the decoder input to a slice of feature dims #
        self.dec_in_range = dec_in_range
        if self.dec_in_range is not None:
            assert isinstance(self.dec_in_range, (list, tuple)) \
                    and len(self.dec_in_range) == 2

        # CBHG config #
        self.enc_cbhg_cfg = ConfigParser.item_parser(enc_cbhg_cfg)
        self.dec_postnet_cbhg_cfg = ConfigParser.item_parser(
            dec_postnet_cbhg_cfg)

        self.enc_emb_lyr = nn.Embedding(enc_in_size, enc_emb_size)
        # init enc prenet #
        self.enc_prenet_lyr = nn.ModuleList()
        prev_size = enc_emb_size
        for ii in range(len(self.enc_prenet_size)):
            self.enc_prenet_lyr.append(
                nn.Linear(prev_size, enc_prenet_size[ii]))
            prev_size = enc_prenet_size[ii]
        # init enc middle #
        self.enc_core_lyr = cbhg.CBHG1d(prev_size, **enc_cbhg_cfg)
        # init dec prenet #
        self.dec_prenet_lyr = nn.ModuleList()
        # effective decoder input dim: full dec_in_size, or the width of
        # dec_in_range when a slice is given (None endpoints treated as 0)
        prev_size = dec_in_size if self.dec_in_range is None else (
            (self.dec_in_range[-1] or 0) - (self.dec_in_range[-2] or 0))
        for ii in range(len(self.dec_prenet_size)):
            self.dec_prenet_lyr.append(
                nn.Linear(prev_size, dec_prenet_size[ii]))
            prev_size = dec_prenet_size[ii]

        # init dec rnn: normalize every cfg type to its stateful-cell variant #
        _dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs,
                                                 len(dec_rnn_sizes))
        for ii in range(len(dec_rnn_sizes)):
            _type = _dec_rnn_cfgs[ii]['type']
            if re.match('stateful.*cell', _type) is None:
                _dec_rnn_cfgs[ii]['type'] = 'stateful_{}cell'.format(_type)
        # TODO : dec_cfg #
        final_enc_size = self.enc_core_lyr.output_size
        # BUGFIX: pass the normalized _dec_rnn_cfgs (stateful_*cell types) to
        # the decoder instead of the raw dec_rnn_cfgs, matching the other
        # model constructors in this file which all pass _dec_rnn_cfgs.
        self.dec_att_lyr = decoder.StandardDecoder(att_cfg, final_enc_size,
                                                   prev_size, dec_rnn_sizes,
                                                   _dec_rnn_cfgs, dec_rnn_do)

        # init dec regression melspec #
        self.dec_first_reg_lyr = nn.Linear(self.dec_att_lyr.output_size,
                                           self.dec_out_size)
        # init dec postnet #
        self.dec_postnet_lyr = cbhg.CBHG1d(
            self.dec_out_size,
            conv_proj_filter=[256, dec_out_size],
            **dec_postnet_cbhg_cfg)
        # init dec regression rawspec #
        self.dec_second_reg_lyr = nn.Linear(self.dec_postnet_lyr.output_size,
                                            self.dec_out_post_size)
Exemplo n.º 9
0
    def __init__(
        self,
        enc_in_size,
        dec_in_size,
        dec_out_size,
        enc_fnn_sizes=[512],
        enc_fnn_act='LeakyReLU',
        enc_fnn_do=0.25,
        enc_rnn_sizes=[256, 256, 256],
        enc_rnn_cfgs={
            "type": "lstm",
            "bi": True
        },
        enc_rnn_do=0.25,
        downsampling=[False, True, True],
        dec_emb_size=256,
        dec_emb_do=0.25,
        dec_emb_tied_weight=True,
        # tying weight from char/word embedding with softmax layer
        dec_rnn_sizes=[512, 512],
        dec_rnn_cfgs={"type": "lstm"},
        dec_rnn_do=0.25,
        dec_cfg={"type": "standard_decoder"},
        att_cfg={"type": "mlp"},
        use_layernorm=False,
    ):
        """Build an attention-based encoder-decoder model.

        Encoder = FNN stack followed by (optionally bidirectional) RNN
        layers; decoder = embedding + attentional RNN decoder whose
        embedding weight may be tied to the pre-softmax projection.
        """
        super().__init__()

        # --- store hyperparameters, expanding scalars to per-layer lists ---
        n_enc_fnn = len(enc_fnn_sizes)
        n_enc_rnn = len(enc_rnn_sizes)
        n_dec_rnn = len(dec_rnn_sizes)

        self.enc_in_size = enc_in_size
        self.dec_in_size = dec_in_size
        self.dec_out_size = dec_out_size
        self.enc_fnn_sizes = enc_fnn_sizes
        self.enc_fnn_act = enc_fnn_act
        self.enc_fnn_do = ConfigParser.list_parser(enc_fnn_do, n_enc_fnn)
        self.enc_rnn_sizes = enc_rnn_sizes
        self.enc_rnn_cfgs = enc_rnn_cfgs
        self.enc_rnn_do = ConfigParser.list_parser(enc_rnn_do, n_enc_rnn)
        self.downsampling = ConfigParser.list_parser(downsampling, n_enc_rnn)

        self.dec_emb_size = dec_emb_size
        self.dec_emb_do = dec_emb_do
        self.dec_emb_tied_weight = dec_emb_tied_weight
        self.dec_rnn_sizes = dec_rnn_sizes
        self.dec_rnn_cfgs = ConfigParser.list_parser(dec_rnn_cfgs, n_dec_rnn)
        self.dec_rnn_do = ConfigParser.list_parser(dec_rnn_do, n_dec_rnn)
        self.dec_cfg = dec_cfg
        self.att_cfg = att_cfg

        self.use_layernorm = use_layernorm
        if self.use_layernorm == True:
            raise ValueError("LayerNorm is not implemented yet")

        # --- encoder: Linear -> (LayerNorm) -> activation -> dropout stack ---
        fnn_modules = []
        in_dim = enc_in_size
        for idx, out_dim in enumerate(enc_fnn_sizes):
            fnn_modules.append(nn.Linear(in_dim, out_dim))
            if use_layernorm:
                fnn_modules.append(LayerNorm(out_dim))
            fnn_modules.append(generator_act_module(enc_fnn_act))
            fnn_modules.append(nn.Dropout(p=self.enc_fnn_do[idx]))
            in_dim = out_dim
        self.enc_fnn_lyr = nn.Sequential(*fnn_modules)

        # --- encoder RNN layers (width doubles when bidirectional) ---
        self.enc_rnn_lyr = nn.ModuleList()
        enc_rnn_cfg_list = ConfigParser.list_parser(enc_rnn_cfgs, n_enc_rnn)
        for idx, hid_dim in enumerate(enc_rnn_sizes):
            layer_cfg = enc_rnn_cfg_list[idx]
            self.enc_rnn_lyr.append(
                generator_rnn({
                    'type':
                    layer_cfg['type'],
                    'args':
                    [in_dim, hid_dim, 1, True, True, 0, layer_cfg['bi']],
                }))
            in_dim = hid_dim * (2 if layer_cfg['bi'] else 1)
        final_enc_size = in_dim

        # --- decoder: embedding + attentional RNN decoder ---
        self.dec_emb_lyr = nn.Embedding(self.dec_in_size,
                                        dec_emb_size,
                                        padding_idx=None)
        dec_rnn_cfg_list = ConfigParser.list_parser(dec_rnn_cfgs, n_dec_rnn)
        # decoder cells must be the stateful variants; rewrite plain types
        for layer_cfg in dec_rnn_cfg_list:
            if re.match('stateful.*cell', layer_cfg['type']) is None:
                layer_cfg['type'] = 'stateful_{}cell'.format(layer_cfg['type'])
        # TODO : dec_cfg #
        assert 'type' in dec_cfg, "decoder type need to be defined"
        if dec_cfg['type'] != 'standard_decoder':
            raise NotImplementedError("decoder type {} is not found".format(
                dec_cfg['type']))
        # forward any extra decoder options besides 'type'
        extra_dec_cfg = {k: v for k, v in dec_cfg.items() if k != 'type'}
        self.dec_att_lyr = decoder.StandardDecoder(att_cfg=att_cfg,
                                                   ctx_size=final_enc_size,
                                                   in_size=dec_emb_size,
                                                   rnn_sizes=dec_rnn_sizes,
                                                   rnn_cfgs=dec_rnn_cfg_list,
                                                   rnn_do=dec_rnn_do,
                                                   **extra_dec_cfg)
        self.dec_presoftmax_lyr = nn.Linear(self.dec_att_lyr.output_size,
                                            dec_out_size)
        if dec_emb_tied_weight:
            # weight tying only makes sense when input/output vocabs and
            # embedding/projection dims line up
            assert (dec_out_size == dec_in_size
                    and self.dec_emb_lyr.embedding_dim ==
                    self.dec_presoftmax_lyr.in_features)
            self.dec_presoftmax_lyr.weight = self.dec_emb_lyr.weight