コード例 #1
0
ファイル: sequence_summary.py プロジェクト: fireae/neural_sp
    def __init__(self,
                 input_dim,
                 n_units,
                 n_layers,
                 bottleneck_dim,
                 dropout,
                 param_init=0.1):
        """Assemble the stack of bias-free linear layers and initialize it.

        Layer layout, in registration order:

        * one ``input_dim -> n_units`` layer,
        * one layer for every index in ``1 .. n_layers - 2``; each maps
          ``n_units -> n_units`` except the one at index ``n_layers - 2``,
          which maps ``n_units -> bottleneck_dim``,
        * one ``bottleneck_dim -> input_dim`` layer.

        Args:
            input_dim: input size of the first layer and output size of the
                last layer.
            n_units: hidden size of the intermediate layers.
            n_layers: total number of layers; stored as ``self.n_layers``.
            bottleneck_dim: output size of the last intermediate layer.
            dropout: forwarded to every ``LinearND`` as ``dropout``.
            param_init: forwarded to ``self.reset_parameters``.
        """
        super(SequenceSummaryNetwork, self).__init__()

        self.n_layers = n_layers

        # NOTE(review): with n_layers <= 2 the loop below is empty, so the
        # first layer emits n_units while the last one expects
        # bottleneck_dim -- confirm callers always use n_layers >= 3 (or
        # n_units == bottleneck_dim).
        self.ssn = nn.ModuleList()
        self.ssn.append(
            LinearND(input_dim, n_units, bias=False, dropout=dropout))

        bottleneck_idx = n_layers - 2
        for idx in range(1, n_layers - 1):
            out_dim = n_units
            if idx == bottleneck_idx:
                out_dim = bottleneck_dim
            self.ssn.append(
                LinearND(n_units, out_dim, bias=False, dropout=dropout))

        self.ssn.append(
            LinearND(bottleneck_dim, input_dim, bias=False, dropout=dropout))

        # Initialize parameters
        self.reset_parameters(param_init)
コード例 #2
0
    def __init__(self,
                 key_dim,
                 query_dim,
                 attn_type,
                 attn_dim,
                 dropout=0,
                 n_heads=4):
        """Create the projection layers of a multi-head attention module.

        Args:
            key_dim: input size of the key/value projections and output size
                of the final ``w_out`` projection.
            query_dim: input size of the query projection.
            attn_type: ``'scaled_dot'`` or ``'add'``. The ``'add'`` variant
                uses a bias on the key projection and an extra ``v`` layer
                mapping ``attn_dim -> n_heads``.
            attn_dim: total attention size; must be divisible by ``n_heads``.
            dropout: probability given to the attention dropout module.
            n_heads: number of heads.

        Raises:
            NotImplementedError: if ``attn_type`` is not one of the two
                supported values.
        """
        super(MultiheadAttentionMechanism, self).__init__()

        self.attn_type = attn_type
        assert attn_dim % n_heads == 0
        self.d_k = attn_dim // n_heads  # per-head size
        self.n_heads = n_heads
        # Start as None; nothing in this method fills them.
        self.key = self.value = self.mask = None

        # attention dropout applied AFTER the softmax layer
        self.attn_dropout = nn.Dropout(p=dropout)

        if attn_type not in ('scaled_dot', 'add'):
            raise NotImplementedError(attn_type)
        is_additive = attn_type == 'add'

        # Only the additive variant puts a bias on the key projection.
        self.w_key = LinearND(key_dim, attn_dim, bias=is_additive)
        self.w_value = LinearND(key_dim, attn_dim, bias=False)
        self.w_query = LinearND(query_dim, attn_dim, bias=False)
        if is_additive:
            self.v = LinearND(attn_dim, n_heads, bias=False)

        self.w_out = LinearND(attn_dim, key_dim)
コード例 #3
0
    def __init__(self, key_dim, query_dim, attn_dim, dropout=0, n_heads=4):
        """Create the projection layers of a multi-head attention module.

        Args:
            key_dim: input size of the key/value projections and output size
                of the final ``w_out`` projection.
            query_dim: input size of the query projection.
            attn_dim: total attention size; must be divisible by ``n_heads``.
            dropout: probability given to the attention dropout module.
            n_heads: number of heads.

        Raises:
            ValueError: if ``attn_dim`` is not a multiple of ``n_heads``.
        """
        super(MultiheadAttentionMechanism, self).__init__()

        # Without this check the floor division below silently truncates and
        # n_heads * d_k no longer equals attn_dim.
        if attn_dim % n_heads != 0:
            raise ValueError(
                'attn_dim (%d) must be divisible by n_heads (%d).' %
                (attn_dim, n_heads))

        self.d_k = attn_dim // n_heads  # per-head size
        self.n_heads = n_heads
        # Start as None; nothing in this method fills them.
        self.key = None
        self.value = None
        self.mask = None

        # attention dropout applied AFTER the softmax layer
        self.attn_dropout = nn.Dropout(p=dropout)

        self.w_key = LinearND(key_dim, attn_dim, bias=False)
        self.w_value = LinearND(key_dim, attn_dim, bias=False)
        self.w_query = LinearND(query_dim, attn_dim, bias=False)
        self.w_out = LinearND(attn_dim, key_dim)
コード例 #4
0
    def __init__(self,
                 input_dim,
                 in_channel,
                 channels,
                 kernel_sizes,
                 strides,
                 poolings,
                 dropout,
                 batch_norm=False,
                 residual=False,
                 bottleneck_dim=0,
                 param_init=0.1):
        """Stack one ``Conv2LBlock`` per entry of ``channels``.

        Args:
            input_dim: input feature size; must be divisible by
                ``in_channel``.
            in_channel: number of input channels.
            channels: output channels of each block (non-empty).
            kernel_sizes: per-block kernel sizes, same length as ``channels``.
            strides: per-block strides, same length as ``channels``.
            poolings: per-block poolings, same length as ``channels``.
            dropout: forwarded to every block.
            batch_norm: forwarded to every block.
            residual: forwarded to every block and stored on ``self``.
            bottleneck_dim: when > 0, a ``LinearND`` bridge to this size is
                created and it becomes the reported output size.
            param_init: forwarded to ``self.reset_parameters``.
        """
        super(ConvEncoder, self).__init__()
        logger = logging.getLogger("training")

        self.in_channel = in_channel
        assert input_dim % in_channel == 0
        self.input_freq = input_dim // in_channel
        self.residual = residual
        self.bridge = None

        assert len(channels) > 0
        n_blocks = len(channels)
        assert n_blocks == len(kernel_sizes) == len(strides) == len(poolings)

        self.layers = nn.ModuleList()
        cur_ch = in_channel
        cur_freq = self.input_freq
        for idx, out_ch in enumerate(channels):
            conv_block = Conv2LBlock(input_dim=cur_freq,
                                     in_channel=cur_ch,
                                     out_channel=out_ch,
                                     kernel_size=kernel_sizes[idx],
                                     stride=strides[idx],
                                     pooling=poolings[idx],
                                     dropout=dropout,
                                     batch_norm=batch_norm,
                                     residual=residual)
            self.layers.append(conv_block)
            # NOTE(review): the next block's frequency size is read back from
            # block.input_dim -- presumably the block rewrites that attribute
            # to its output size; confirm against Conv2LBlock.
            cur_freq = conv_block.input_dim
            cur_ch = out_ch

        out_dim = int(cur_ch * cur_freq)
        if bottleneck_dim > 0:
            self.bridge = LinearND(out_dim, bottleneck_dim)
            out_dim = bottleneck_dim
        self._output_dim = out_dim

        # Initialize parameters
        self.reset_parameters(param_init)
コード例 #5
0
    def __init__(self,
                 eos,
                 blank,
                 enc_n_units,
                 vocab,
                 dropout=0.0,
                 lsm_prob=0.0,
                 fc_list=None,
                 param_init=0.1):
        """Build the CTC output layer(s) and the warp-ctc loss object.

        Args:
            eos: stored as ``self.eos``.
            blank: stored as ``self.blank``.
            enc_n_units: input size of the first output layer.
            vocab: output size of the last layer; stored as ``self.vocab``.
            dropout: dropout for the intermediate fully-connected layers
                (the final projection uses 0).
            lsm_prob: stored as ``self.lsm_prob``.
            fc_list: sizes of optional fully-connected layers placed before
                the final projection. ``None`` (the default) or an empty
                sequence means a single ``enc_n_units -> vocab`` layer.
            param_init: accepted but not used in this method.
        """
        super(CTC, self).__init__()
        logger = logging.getLogger('training')

        # A None sentinel replaces the former mutable ``[]`` default, which
        # was a single list object shared by every call.
        if fc_list is None:
            fc_list = []

        self.eos = eos
        self.blank = blank
        self.vocab = vocab
        self.lsm_prob = lsm_prob

        self.space = -1
        # TODO(hirofumi): fix layer

        # Fully-connected layers before the softmax
        if len(fc_list) > 0:
            fc_layers = OrderedDict()
            in_dim = enc_n_units
            for i, out_dim in enumerate(fc_list):
                fc_layers['fc' + str(i)] = LinearND(in_dim,
                                                    out_dim,
                                                    dropout=dropout)
                in_dim = out_dim
            # Here in_dim == fc_list[-1]; project to the vocabulary.
            fc_layers['fc' + str(len(fc_list))] = LinearND(in_dim,
                                                           vocab,
                                                           dropout=0)
            self.output = nn.Sequential(fc_layers)
        else:
            self.output = LinearND(enc_n_units, vocab)

        # Imported here so the dependency is only needed when a CTC module is
        # actually constructed.
        import warpctc_pytorch
        self.warpctc_loss = warpctc_pytorch.CTCLoss(size_average=True)
コード例 #6
0
    def __init__(self,
                 input_dim,
                 in_channel,
                 channels,
                 kernel_sizes,
                 dropout,
                 bottleneck_dim=0):
        """Build a sequential stack of TDS blocks.

        For every entry of ``channels`` one ``TDSBlock`` is added. It is
        preceded by a ``SubsampelBlock`` whenever the channel count differs
        from that of the previous stage.

        Args:
            input_dim: input feature size; must be divisible by
                ``in_channel``.
            in_channel: number of input channels.
            channels: channel count of each stage (non-empty).
            kernel_sizes: per-stage kernel sizes, same length as
                ``channels``; only element ``[0]`` of each entry is used.
            dropout: forwarded to every block.
            bottleneck_dim: when > 0, a ``LinearND`` bridge to this size is
                created and it becomes the reported output size.
        """
        super(TDSEncoder, self).__init__()
        logger = logging.getLogger("training")

        self.in_channel = in_channel
        assert input_dim % in_channel == 0
        self.input_freq = input_dim // in_channel
        self.bridge = None

        assert len(channels) > 0
        assert len(channels) == len(kernel_sizes)

        blocks = OrderedDict()
        prev_ch = in_channel
        freq = self.input_freq  # never updated inside the loop
        for idx, ch in enumerate(channels):
            # subsample
            if prev_ch != ch:
                blocks['subsample%d' % idx] = SubsampelBlock(
                    in_channel=prev_ch,
                    out_channel=ch,
                    in_freq=freq,
                    dropout=dropout)

            # Conv
            blocks['tds%d_block%d' % (ch, idx)] = TDSBlock(
                channel=ch,
                kernel_size=kernel_sizes[idx][0],
                in_freq=freq,
                dropout=dropout)

            prev_ch = ch

        self._output_dim = int(prev_ch * freq)

        if bottleneck_dim > 0:
            self.bridge = LinearND(self._output_dim, bottleneck_dim)
            self._output_dim = bottleneck_dim

        self.layers = nn.Sequential(blocks)

        # Initialize parameters
        self.reset_parameters()
コード例 #7
0
ファイル: gated_conv.py プロジェクト: fireae/neural_sp
    def __init__(self,
                 input_dim,
                 in_channel,
                 channels,
                 kernel_sizes,
                 dropout,
                 bottleneck_dim=0,
                 param_init=0.1):
        """Build a sequential stack of weight-normalized GLU blocks.

        Args:
            input_dim: input feature size; must be divisible by
                ``in_channel``. Also the input size of the first GLU block.
            in_channel: number of input channels.
            channels: output size of each GLU block (non-empty).
            kernel_sizes: per-block kernel sizes, same length as
                ``channels``; only element ``[0]`` of each entry is used.
            dropout: accepted but not used in this method (see note below).
            bottleneck_dim: when > 0, a ``LinearND`` bridge to this size is
                created and it becomes the reported output size.
            param_init: forwarded to ``self.reset_parameters``.
        """
        super(GatedConvEncoder, self).__init__()
        logger = logging.getLogger("training")

        self.in_channel = in_channel
        assert input_dim % in_channel == 0
        self.input_freq = input_dim // in_channel
        self.bridge = None

        assert len(channels) > 0
        assert len(channels) == len(kernel_sizes)

        # NOTE(review): every GLU block is built with a fixed dropout of 0.2;
        # the ``dropout`` argument is ignored. Confirm whether that is
        # intentional before wiring the argument through.
        blocks = OrderedDict()
        feat_dim = input_dim
        for idx, out_dim in enumerate(channels):
            blocks['conv%d' % idx] = GLUBlock(kernel_sizes[idx][0],
                                              feat_dim,
                                              out_dim,
                                              weight_norm=True,
                                              dropout=0.2)
            feat_dim = out_dim

        # weight normalization + GLU for the last fully-connected layer
        last_fc = nn.Linear(feat_dim, feat_dim * 2)
        self.fc_glu = nn.utils.weight_norm(last_fc, name='weight', dim=0)

        self._output_dim = int(feat_dim)

        if bottleneck_dim > 0:
            self.bridge = LinearND(self._output_dim, bottleneck_dim)
            self._output_dim = bottleneck_dim

        self.layers = nn.Sequential(blocks)

        # Initialize parameters
        self.reset_parameters(param_init)
コード例 #8
0
ファイル: gated_convlm.py プロジェクト: xdcesc/neural_sp
    def __init__(self, args):
        """Build a gated-convolution language model from ``args``.

        Attributes read from ``args``: ``emb_dim``, ``n_units``,
        ``n_layers``, ``tie_embedding``, ``backward``, ``vocab``,
        ``dropout_emb``, ``dropout_hidden``, ``dropout_out``, ``lm_type``,
        ``adaptive_softmax``, ``param_init`` and ``param_init_dist``.

        The architecture is selected by ``args.lm_type`` with the
        ``'gated_conv_'`` prefix removed: ``'small'``, ``'8'``, ``'13'``,
        ``'14'`` or ``'14B'``.

        Raises:
            NotImplementedError: for sizes ``'8B'`` and ``'9'`` (no
                message) and for any unknown size (message is the size).
            ValueError: when weights are tied and
                ``args.n_units != args.emb_dim``.
        """
        super(ModelBase, self).__init__()

        self.emb_dim = args.emb_dim
        self.n_units = args.n_units
        self.n_layers = args.n_layers
        self.tie_embedding = args.tie_embedding
        self.backward = args.backward

        self.vocab = args.vocab
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for cache
        self.cache_theta = 0.2  # smoothing parameter
        self.cache_lambda = 0.2  # cache weight
        self.cache_ids = []
        self.cache_keys = []
        self.cache_attn = []

        self.embed = Embedding(vocab=self.vocab,
                               emb_dim=args.emb_dim,
                               dropout=args.dropout_emb,
                               ignore_index=self.pad)

        layers = OrderedDict()

        def add_glu(name, kernel_size, in_dim, out_dim, **extra):
            # Register one GLU block under `name`; `extra` carries the
            # optional bottleneck keyword exactly as GLUBlock spells it.
            layers[name] = GLUBlock(kernel_size,
                                    in_dim,
                                    out_dim,
                                    dropout=args.dropout_hidden,
                                    **extra)

        model_size = args.lm_type.replace('gated_conv_', '')

        if model_size == 'small':
            # conv1-1 .. conv5-1, all with a 300-dim bottleneck
            for n in range(1, 6):
                add_glu('conv%d-1' % n,
                        4,
                        args.emb_dim if n == 1 else 600,
                        600,
                        bottlececk_dim=300)
            last_dim = 600

        elif model_size == '8':
            add_glu('conv1-1', 4, args.emb_dim, 900)
            for i in range(1, 8):
                add_glu('conv2-%d' % i, 4, 900, 900)
            last_dim = 900

        elif model_size in ('8B', '9'):
            raise NotImplementedError

        elif model_size == '13':
            add_glu('conv1-1', 4, args.emb_dim, 1268)
            for i in range(1, 13):
                add_glu('conv2-%d' % i, 4, 1268, 1268)
            last_dim = 1268

        elif model_size == '14':
            for i in range(1, 4):
                add_glu('conv1-%d' % i,
                        6,
                        args.emb_dim if i == 1 else 850,
                        850)
            add_glu('conv2-1', 1, 850, 850)
            for i in range(1, 5):
                add_glu('conv3-%d' % i, 5, 850, 850)
            add_glu('conv4-1', 1, 850, 850)
            for i in range(1, 4):
                add_glu('conv5-%d' % i, 4, 850, 850)
            add_glu('conv6-1', 4, 850, 1024)
            add_glu('conv7-1', 4, 1024, 2048)
            last_dim = 2048

        elif model_size == '14B':
            add_glu('conv1-1', 5, args.emb_dim, 512)
            for i in range(1, 4):
                add_glu('conv2-%d' % i, 5, 512, 512, bottlececk_dim=128)
            for i in range(1, 4):
                add_glu('conv3-%d' % i,
                        5,
                        512 if i == 1 else 1024,
                        1024,
                        bottlececk_dim=512)
            for i in range(1, 7):
                add_glu('conv4-%d' % i,
                        5,
                        1024 if i == 1 else 2048,
                        2048,
                        bottlececk_dim=1024)
            add_glu('conv5-1', 5, 2048, 4096, bottlececk_dim=1024)
            last_dim = 4096

        else:
            raise NotImplementedError(model_size)

        self.layers = nn.Sequential(layers)

        if args.adaptive_softmax:
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                last_dim,
                self.vocab,
                # cutoffs=[self.vocab // 10, 3 * self.vocab // 10],
                cutoffs=[self.vocab // 25, self.vocab // 5],
                div_value=4.0)
            self.output = None
        else:
            self.adaptive_softmax = None
            self.output = LinearND(last_dim,
                                   self.vocab,
                                   dropout=args.dropout_out)
            # NOTE: include bias even when tying weights

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if args.tie_embedding:
                # NOTE(review): the output layer's input size is last_dim,
                # yet the guard compares n_units with emb_dim -- confirm
                # this is the intended condition.
                if args.n_units != args.emb_dim:
                    raise ValueError(
                        'When using the tied flag, n_units must be equal to emb_dim.'
                    )
                self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters(args.param_init, dist=args.param_init_dist)

        # Initialize bias vectors with zero
        self.reset_parameters(0, dist='constant', keys=['bias'])
コード例 #9
0
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 attn_type,
                 attn_n_heads,
                 n_layers,
                 d_model,
                 d_ff,
                 pe_type,
                 tie_embedding,
                 vocab,
                 dropout=0.0,
                 dropout_emb=0.0,
                 dropout_att=0.0,
                 lsm_prob=0.0,
                 layer_norm_eps=1e-6,
                 ctc_weight=0.0,
                 ctc_fc_list=None,
                 backward=False,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 adaptive_softmax=False):
        """Build the Transformer decoder and, optionally, a CTC branch.

        Args:
            eos, unk, pad, blank: special token indices, stored on ``self``;
                ``pad`` is the embedding's ``ignore_index`` and ``blank`` is
                passed to the CTC decoders.
            enc_n_units: stored as ``self.enc_n_units``.
            attn_type, attn_n_heads, d_ff, dropout_att: forwarded to every
                ``TransformerDecoderBlock``.
            n_layers: number of decoder blocks.
            d_model: model size (embedding, blocks, layer norm, outputs).
            pe_type: when truthy, a ``PositionalEncoding`` is created.
            tie_embedding: share the output projection weight with the
                embedding (ignored when ``adaptive_softmax`` is set).
            vocab: vocabulary size.
            dropout: dropout for decoder blocks and intermediate CTC layers.
            dropout_emb: dropout for the positional encoding.
            lsm_prob: stored as ``self.lsm_prob``.
            layer_norm_eps: epsilon for layer normalization.
            ctc_weight: the CTC branch is built when > 0; the attention
                decoder is built when ``ctc_weight < global_weight``.
            ctc_fc_list: sizes of optional fully-connected layers before the
                CTC projection. ``None`` (the default) means no extra layers.
            backward, global_weight, mtl_per_batch: stored on ``self``.
            adaptive_softmax: use ``nn.AdaptiveLogSoftmaxWithLoss`` instead
                of a linear output layer.
        """
        super(TransformerDecoder, self).__init__()

        # A None sentinel replaces the former mutable ``[]`` default; that
        # single list was stored on every instance built without the
        # argument, so mutating one instance's list affected all of them.
        if ctc_fc_list is None:
            ctc_fc_list = []

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        self.enc_n_units = enc_n_units
        self.d_model = d_model
        self.n_layers = n_layers
        self.pe_type = pe_type
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.ctc_fc_list = ctc_fc_list
        self.backward = backward
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        if ctc_weight > 0:
            # Fully-connected layers for CTC
            if len(ctc_fc_list) > 0:
                fc_layers = OrderedDict()
                in_dim = d_model
                for i, out_dim in enumerate(ctc_fc_list):
                    fc_layers['fc' + str(i)] = LinearND(in_dim, out_dim, dropout=dropout)
                    in_dim = out_dim
                # Here in_dim == ctc_fc_list[-1]; project to the vocabulary.
                fc_layers['fc' + str(len(ctc_fc_list))] = LinearND(in_dim, vocab, dropout=0)
                self.output_ctc = nn.Sequential(fc_layers)
            else:
                self.output_ctc = LinearND(d_model, vocab)
            self.decode_ctc_greedy = GreedyDecoder(blank=blank)
            self.decode_ctc_beam = BeamSearchDecoder(blank=blank)
            # Imported here so the dependency is only needed when the CTC
            # branch is actually built.
            import warpctc_pytorch
            self.warpctc_loss = warpctc_pytorch.CTCLoss(size_average=True)

        if ctc_weight < global_weight:
            self.layers = nn.ModuleList(
                [TransformerDecoderBlock(d_model, d_ff, attn_type, attn_n_heads,
                                         dropout, dropout_att, layer_norm_eps)
                 for _ in range(n_layers)])

            self.embed = Embedding(vocab, d_model,
                                   dropout=0,  # NOTE: do not apply dropout here
                                   ignore_index=pad)
            if pe_type:
                self.pos_emb_out = PositionalEncoding(d_model, dropout_emb, pe_type)

            if adaptive_softmax:
                # Use the ``vocab`` argument directly: this method never
                # assigns ``self.vocab``, so reading it here could fail.
                self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                    d_model, vocab,
                    cutoffs=[round(vocab / 15), 3 * round(vocab / 15)],
                    # alternative: cutoffs=[vocab // 25, 3 * vocab // 5],
                    div_value=4.0)
                self.output = None
            else:
                self.adaptive_softmax = None
                self.output = LinearND(d_model, vocab)

                # Optionally tie weights as in:
                # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
                # https://arxiv.org/abs/1608.05859
                # and
                # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
                # https://arxiv.org/abs/1611.01462
                if tie_embedding:
                    self.output.fc.weight = self.embed.embed.weight

            self.norm_top = nn.LayerNorm(d_model, eps=layer_norm_eps)

        # Initialize parameters
        self.reset_parameters()
コード例 #10
0
    def __init__(self,
                 input_dim,
                 rnn_type,
                 n_units,
                 n_projs,
                 n_layers,
                 dropout_in,
                 dropout,
                 subsample,
                 subsample_type='drop',
                 n_stacks=1,
                 n_splices=1,
                 last_proj_dim=0,
                 conv_in_channel=1,
                 conv_channels=0,
                 conv_kernel_sizes=[],
                 conv_strides=[],
                 conv_poolings=[],
                 conv_batch_norm=False,
                 conv_residual=False,
                 conv_bottleneck_dim=0,
                 residual=False,
                 n_layers_sub1=0,
                 n_layers_sub2=0,
                 nin=False,
                 task_specific_layer=False,
                 param_init=0.1):

        super(RNNEncoder, self).__init__()

        logger = logging.getLogger("training")

        if len(subsample) > 0 and len(subsample) != n_layers:
            raise ValueError('subsample must be the same size as n_layers.')
        if n_layers_sub1 < 0 or (n_layers_sub1 > 1
                                 and n_layers < n_layers_sub1):
            raise ValueError('Set n_layers_sub1 between 1 to n_layers.')
        if n_layers_sub2 < 0 or (n_layers_sub2 > 1
                                 and n_layers_sub1 < n_layers_sub2):
            raise ValueError('Set n_layers_sub2 between 1 to n_layers_sub1.')

        self.rnn_type = rnn_type
        self.bidirectional = True if rnn_type in [
            'blstm', 'bgru', 'conv_blstm', 'conv_bgru'
        ] else False
        self.n_units = n_units
        self.n_dirs = 2 if self.bidirectional else 1
        self.n_projs = n_projs
        self.n_layers = n_layers

        # Setting for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # Setting for subsampling
        self.subsample = subsample
        self.subsample_type = subsample_type

        # Setting for bridge layers
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # Setting for residual connections
        self.residual = residual
        if residual:
            assert np.prod(subsample) == 1

        # Setting for the NiN (Network in Network)
        self.nin = nin

        # Dropout for input-hidden connection
        self.dropout_in = nn.Dropout(p=dropout_in)

        # Setting for CNNs before RNNs
        if conv_channels and rnn_type not in ['blstm', 'lstm', 'bgru', 'gru']:
            channels = [int(c) for c in conv_channels.split('_')
                        ] if len(conv_channels) > 0 else []
            kernel_sizes = [[
                int(c.split(',')[0].replace('(', '')),
                int(c.split(',')[1].replace(')', ''))
            ] for c in conv_kernel_sizes.split('_')
                            ] if len(conv_kernel_sizes) > 0 else []
            if rnn_type in ['tds', 'gated_conv']:
                strides = []
                poolings = []
            else:
                strides = [[
                    int(c.split(',')[0].replace('(', '')),
                    int(c.split(',')[1].replace(')', ''))
                ] for c in conv_strides.split('_')
                           ] if len(conv_strides) > 0 else []
                poolings = [[
                    int(c.split(',')[0].replace('(', '')),
                    int(c.split(',')[1].replace(')', ''))
                ] for c in conv_poolings.split('_')
                            ] if len(conv_poolings) > 0 else []
            if 'conv_' in rnn_type:
                self.subsample = [1] * self.n_layers
                logger.warning(
                    'Subsampling is automatically ignored because CNN layers are used before RNN layers.'
                )
        else:
            channels = []
            kernel_sizes = []
            strides = []
            poolings = []

        if len(channels) > 0:
            if rnn_type == 'tds':
                self.conv = TDSEncoder(input_dim=input_dim * n_stacks,
                                       in_channel=conv_in_channel,
                                       channels=channels,
                                       kernel_sizes=kernel_sizes,
                                       dropout=dropout,
                                       bottleneck_dim=last_proj_dim)
            elif rnn_type == 'gated_conv':
                self.conv = GatedConvEncoder(input_dim=input_dim * n_stacks,
                                             in_channel=conv_in_channel,
                                             channels=channels,
                                             kernel_sizes=kernel_sizes,
                                             dropout=dropout,
                                             bottleneck_dim=last_proj_dim,
                                             param_init=param_init)
            else:
                assert n_stacks == 1 and n_splices == 1
                self.conv = ConvEncoder(input_dim,
                                        in_channel=conv_in_channel,
                                        channels=channels,
                                        kernel_sizes=kernel_sizes,
                                        strides=strides,
                                        poolings=poolings,
                                        dropout=0,
                                        batch_norm=conv_batch_norm,
                                        residual=conv_residual,
                                        bottleneck_dim=conv_bottleneck_dim,
                                        param_init=param_init)
            self._output_dim = self.conv.output_dim
        else:
            self._output_dim = input_dim * n_splices * n_stacks
            self.conv = None

        if rnn_type not in ['conv', 'tds', 'gated_conv']:
            # Fast implementation without processes between each layer
            self.fast_impl = False
            if np.prod(
                    self.subsample
            ) == 1 and self.n_projs == 0 and not residual and n_layers_sub1 == 0 and not nin:
                self.fast_impl = True
                if 'lstm' in rnn_type:
                    rnn = nn.LSTM
                elif 'gru' in rnn_type:
                    rnn = nn.GRU
                else:
                    raise ValueError(
                        'rnn_type must be "(conv_)(b)lstm" or "(conv_)(b)gru".'
                    )

                self.rnn = rnn(self._output_dim,
                               n_units,
                               n_layers,
                               bias=True,
                               batch_first=True,
                               dropout=dropout,
                               bidirectional=self.bidirectional)
                # NOTE: pytorch introduces a dropout layer on the outputs of each layer EXCEPT the last layer
                self._output_dim = n_units * self.n_dirs
                self.dropout_top = nn.Dropout(p=dropout)
            else:
                self.rnn = nn.ModuleList()
                self.dropout = nn.ModuleList()
                if self.n_projs > 0:
                    self.proj = nn.ModuleList()
                if subsample_type == 'max_pool' and np.prod(
                        self.subsample) > 1:
                    self.max_pool = nn.ModuleList()
                    for l in range(n_layers):
                        if self.subsample[l] > 1:
                            self.max_pool += [
                                nn.MaxPool2d((1, 1),
                                             stride=(self.subsample[l], 1),
                                             ceil_mode=True)
                            ]
                        else:
                            self.max_pool += [None]
                if subsample_type == 'concat' and np.prod(self.subsample) > 1:
                    self.concat_proj = nn.ModuleList()
                    self.concat_bn = nn.ModuleList()
                    for l in range(n_layers):
                        if self.subsample[l] > 1:
                            self.concat_proj += [
                                LinearND(
                                    n_units * self.n_dirs * self.subsample[l],
                                    n_units * self.n_dirs)
                            ]
                            self.concat_bn += [
                                nn.BatchNorm1d(n_units * self.n_dirs)
                            ]
                        else:
                            self.concat_proj += [None]
                            self.concat_bn += [None]
                if nin:
                    self.nin_conv = nn.ModuleList()
                    self.nin_bn = nn.ModuleList()

                for l in range(n_layers):
                    if 'lstm' in rnn_type:
                        rnn_i = nn.LSTM
                    elif 'gru' in rnn_type:
                        rnn_i = nn.GRU
                    else:
                        raise ValueError(
                            'rnn_type must be "(conv_)(b)lstm" or "(conv_)(b)gru".'
                        )

                    self.rnn += [
                        rnn_i(self._output_dim,
                              n_units,
                              1,
                              bias=True,
                              batch_first=True,
                              dropout=0,
                              bidirectional=self.bidirectional)
                    ]
                    self.dropout += [nn.Dropout(p=dropout)]
                    self._output_dim = n_units * self.n_dirs

                    # Projection layer
                    if n_projs > 0 and l != n_layers - 1:
                        self.proj += [LinearND(n_units * self.n_dirs, n_projs)]
                        self._output_dim = n_projs

                    # Task specific layer
                    if l == n_layers_sub1 - 1 and task_specific_layer:
                        self.rnn_sub1 = rnn_i(self._output_dim,
                                              n_units,
                                              1,
                                              bias=True,
                                              batch_first=True,
                                              dropout=0,
                                              bidirectional=self.bidirectional)
                        self.dropout_sub1 = nn.Dropout(p=dropout)
                        if last_proj_dim != self.output_dim:
                            self.bridge_sub1 = LinearND(n_units,
                                                        last_proj_dim,
                                                        dropout=dropout)
                    if l == n_layers_sub2 - 1 and task_specific_layer:
                        self.rnn_sub2 = rnn_i(self._output_dim,
                                              n_units,
                                              1,
                                              bias=True,
                                              batch_first=True,
                                              dropout=0,
                                              bidirectional=self.bidirectional)
                        self.dropout_sub2 = nn.Dropout(p=dropout)
                        if last_proj_dim != self.output_dim:
                            self.bridge_sub2 = LinearND(n_units,
                                                        last_proj_dim,
                                                        dropout=dropout)

                    # Network in network (1*1 conv + batch normalization + ReLU)
                    # NOTE: exclude the last layer
                    if nin and l != n_layers - 1:
                        self.nin_conv += [
                            nn.Conv2d(in_channels=self._output_dim,
                                      out_channels=self._output_dim,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0)
                        ]
                        self.nin_bn += [nn.BatchNorm2d(self._output_dim)]
                        if n_layers_sub1 > 0 or n_layers_sub2 > 0:
                            assert task_specific_layer

                if last_proj_dim != self.output_dim:
                    self.bridge = LinearND(self._output_dim,
                                           last_proj_dim,
                                           dropout=dropout)
                    self._output_dim = last_proj_dim

        # Initialize parameters
        self.reset_parameters(param_init)
コード例 #11
0
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 rnn_type,
                 n_units,
                 n_projs,
                 n_layers,
                 residual,
                 bottleneck_dim,
                 emb_dim,
                 vocab,
                 tie_embedding=False,
                 dropout=0.0,
                 dropout_emb=0.0,
                 lsm_prob=0.0,
                 ctc_weight=0.0,
                 ctc_lsm_prob=0.0,
                 ctc_fc_list=[],
                 lm_init=None,
                 lmobj_weight=0.0,
                 share_lm_softmax=False,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 param_init=0.1,
                 start_pointing=False,
                 end_pointing=True):
        """Build the RNN transducer decoder.

        Args:
            eos: end-of-sentence index (stored; also passed to the CTC module)
            unk: unknown-token index (stored)
            pad: padding index (stored; used as `ignore_index` of the embedding)
            blank: blank index (stored; also passed to the CTC module)
            enc_n_units: size of the encoder outputs fed to `w_enc` and CTC
            rnn_type: 'lstm_transducer' or 'gru_transducer'
            n_units: hidden size of each prediction-network RNN layer
            n_projs: size of the per-layer projections (0 disables them)
            n_layers: number of prediction-network RNN layers
            residual: stored; a truthy value (or `n_projs > 0`) selects the
                layer-by-layer implementation instead of one stacked RNN
            bottleneck_dim: output size of `w_enc`/`w_dec` and input size of
                the output layer
            emb_dim: embedding size
            vocab: vocabulary size
            tie_embedding: accepted but not used by this constructor
            dropout: dropout probability for the RNN layers and CTC
            dropout_emb: dropout probability passed to the embedding
            lsm_prob: stored as `self.lsm_prob`
            ctc_weight: a CTC module is built when this is positive; the
                transducer itself is built when it is below `global_weight`
            ctc_lsm_prob: passed to CTC as `lsm_prob`
            ctc_fc_list: passed to CTC as `fc_list` (never mutated here)
            lm_init: optional pre-trained LM whose parameters with matching
                name and size are reused (parameters with 'output' in the
                name are skipped)
            lmobj_weight: a positive value adds the LM-objective output layer
            share_lm_softmax: reuse `self.output` for the LM objective
                instead of creating a separate layer
            global_weight: see `ctc_weight`
            mtl_per_batch: stored as `self.mtl_per_batch`
            param_init: passed to `reset_parameters` and CTC
            start_pointing: stored as `self.start_pointing`
            end_pointing: stored as `self.end_pointing`

        Raises:
            AssertionError: on an unsupported `rnn_type`, or when `lm_init`
                does not match the configuration of this decoder

        """

        super(RNNTransducer, self).__init__()
        logger = logging.getLogger('training')

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        self.vocab = vocab
        self.rnn_type = rnn_type
        assert rnn_type in ['lstm_transducer', 'gru_transducer']
        self.enc_n_units = enc_n_units
        self.dec_n_units = n_units
        self.n_projs = n_projs
        self.n_layers = n_layers
        self.residual = residual
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.lmobj_weight = lmobj_weight
        self.share_lm_softmax = share_lm_softmax
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        # VAD
        self.start_pointing = start_pointing
        self.end_pointing = end_pointing

        # for cache
        self.prev_spk = ''
        self.lmstate_final = None
        self.state_cache = OrderedDict()

        if ctc_weight > 0:
            self.ctc = CTC(eos=eos,
                           blank=blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=param_init)

        if ctc_weight < global_weight:
            import warprnnt_pytorch
            self.warprnnt_loss = warprnnt_pytorch.RNNTLoss()

            # for MTL with LM objective
            # NOTE: the dedicated softmax is still created at this point so
            # that the module/parameter registration order is unchanged; the
            # shared variant is bound below, once `self.output` exists.
            if lmobj_weight > 0 and not share_lm_softmax:
                self.output_lmobj = LinearND(n_units, vocab)

            # Prediction network
            self.fast_impl = False
            rnn = nn.LSTM if rnn_type == 'lstm_transducer' else nn.GRU
            if n_projs == 0 and not residual:
                self.fast_impl = True
                self.rnn = rnn(emb_dim,
                               n_units,
                               n_layers,
                               bias=True,
                               batch_first=True,
                               dropout=dropout,
                               bidirectional=False)
                # NOTE: pytorch introduces a dropout layer on the outputs of each layer EXCEPT the last layer
                dec_idim = n_units
                self.dropout_top = nn.Dropout(p=dropout)
            else:
                self.rnn = nn.ModuleList()
                self.dropout = nn.ModuleList(
                    [nn.Dropout(p=dropout) for _ in range(n_layers)])
                if n_projs > 0:
                    # Each projection consumes the output of an RNN layer,
                    # which always has n_units features.
                    self.proj = nn.ModuleList(
                        [LinearND(n_units, n_projs) for _ in range(n_layers)])
                dec_idim = emb_dim
                for l in range(n_layers):
                    self.rnn += [
                        rnn(dec_idim,
                            n_units,
                            1,
                            bias=True,
                            batch_first=True,
                            dropout=0,
                            bidirectional=False)
                    ]
                    # input size of the next layer (and of w_dec after the loop)
                    dec_idim = n_projs if n_projs > 0 else n_units

            self.embed = Embedding(vocab,
                                   emb_dim,
                                   dropout=dropout_emb,
                                   ignore_index=pad)

            self.w_enc = LinearND(enc_n_units, bottleneck_dim, bias=True)
            self.w_dec = LinearND(dec_idim, bottleneck_dim, bias=False)
            self.output = LinearND(bottleneck_dim, vocab)

            # Bind the shared LM-objective softmax only now that
            # `self.output` has been created.
            if lmobj_weight > 0 and share_lm_softmax:
                self.output_lmobj = self.output  # share parameters

        # Initialize parameters
        self.reset_parameters(param_init)

        # prediction network initialization with pre-trained LM
        if lm_init is not None:
            assert lm_init.vocab == vocab
            assert lm_init.n_units == n_units
            assert lm_init.n_projs == n_projs
            assert lm_init.n_layers == n_layers
            assert lm_init.residual == residual

            param_dict = dict(lm_init.named_parameters())
            for n, p in self.named_parameters():
                if n in param_dict and p.size() == param_dict[n].size():
                    if 'output' in n:
                        continue
                    # NOTE(review): this rebinds p.data to the LM tensor
                    # instead of copying it -- confirm that is intended.
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s', n)
コード例 #12
0
    def __init__(self, args, save_path=None):
        """Build the RNN language model from parsed options.

        Args:
            args: option namespace. Attributes read here: `emb_dim`,
                `lm_type`, `n_units`, `n_projs`, `n_layers`, `residual`,
                `use_glu`, `n_units_null_context`, `vocab`, `dropout_emb`,
                `dropout_hidden`, `dropout_out`, `adaptive_softmax`,
                `tie_embedding`, `param_init` and `rec_weight_orthogonal`.
            save_path: stored as `self.save_path`

        Raises:
            AssertionError: if `args.lm_type` is neither 'lstm' nor 'gru'
            ValueError: if `args.tie_embedding` is set but the input size of
                the output layer differs from `args.emb_dim`

        """

        super(LMBase, self).__init__()
        logger = logging.getLogger('training')
        logger.info(self.__class__.__name__)

        self.save_path = save_path

        self.emb_dim = args.emb_dim
        self.rnn_type = args.lm_type
        assert args.lm_type in ['lstm', 'gru']
        self.n_units = args.n_units
        self.n_projs = args.n_projs
        self.n_layers = args.n_layers
        self.residual = args.residual
        self.use_glu = args.use_glu
        self.n_units_cv = args.n_units_null_context

        self.vocab = args.vocab
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for cache
        self.cache_theta = 0.2  # smoothing parameter
        self.cache_lambda = 0.2  # cache weight
        self.cache_ids = []
        self.cache_keys = []
        self.cache_attn = []

        self.embed = Embedding(vocab=self.vocab,
                               emb_dim=args.emb_dim,
                               dropout=args.dropout_emb,
                               ignore_index=self.pad)

        self.fast_impl = False
        rnn = nn.LSTM if args.lm_type == 'lstm' else nn.GRU
        if args.n_projs == 0 and not args.residual:
            # A single stacked RNN is enough when no per-layer projection or
            # residual handling is requested.
            self.fast_impl = True
            self.rnn = rnn(args.emb_dim + args.n_units_null_context,
                           args.n_units,
                           args.n_layers,
                           bias=True,
                           batch_first=True,
                           dropout=args.dropout_hidden,
                           bidirectional=False)
            # NOTE: pytorch introduces a dropout layer on the outputs of each layer EXCEPT the last layer
            rnn_idim = args.n_units
            self.dropout_top = nn.Dropout(p=args.dropout_hidden)
        else:
            self.rnn = nn.ModuleList()
            self.dropout = nn.ModuleList([
                nn.Dropout(p=args.dropout_hidden) for _ in range(args.n_layers)
            ])
            if args.n_projs > 0:
                self.proj = nn.ModuleList([
                    LinearND(args.n_units, args.n_projs)
                    for _ in range(args.n_layers)
                ])
            rnn_idim = args.emb_dim + args.n_units_null_context
            for l in range(args.n_layers):
                self.rnn += [
                    rnn(rnn_idim,
                        args.n_units,
                        1,
                        bias=True,
                        batch_first=True,
                        dropout=0,
                        bidirectional=False)
                ]
                # input size of the next layer; after the loop this is the
                # feature size handed to the GLU / output layers
                rnn_idim = args.n_units
                if args.n_projs > 0:
                    rnn_idim = args.n_projs

        if self.use_glu:
            self.fc_glu = LinearND(rnn_idim,
                                   rnn_idim * 2,
                                   dropout=args.dropout_hidden)

        if args.adaptive_softmax:
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                rnn_idim,
                self.vocab,
                # cutoffs=[self.vocab // 10, 3 * self.vocab // 10],
                cutoffs=[self.vocab // 25, self.vocab // 5],
                div_value=4.0)
            self.output = None
        else:
            self.adaptive_softmax = None
            self.output = LinearND(rnn_idim,
                                   self.vocab,
                                   dropout=args.dropout_out)
            # NOTE: include bias even when tying weights

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if args.tie_embedding:
                # The output layer takes rnn_idim features (n_projs when
                # projections are enabled, n_units otherwise), so that is the
                # size that has to agree with the embedding.
                if rnn_idim != args.emb_dim:
                    raise ValueError(
                        'When using the tied flag, the input size of the '
                        'output layer (n_projs if n_projs > 0, otherwise '
                        'n_units) must be equal to emb_dim.')
                self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters(args.param_init)

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init,
                                  dist='orthogonal',
                                  keys=['rnn', 'weight'])
コード例 #13
0
    def __init__(self,
                 input_dim,
                 attn_type,
                 attn_n_heads,
                 n_layers,
                 d_model,
                 d_ff,
                 pe_type,
                 dropout_in=0,
                 dropout=0,
                 dropout_att=0,
                 layer_norm_eps=1e-6,
                 n_stacks=1,
                 n_splices=1,
                 conv_in_channel=1,
                 conv_channels=0,
                 conv_kernel_sizes=[],
                 conv_strides=[],
                 conv_poolings=[],
                 conv_batch_norm=False,
                 conv_residual=False,
                 conv_bottleneck_dim=0):
        """Build the Transformer encoder, optionally fronted by a CNN.

        Args:
            input_dim: input feature size (scaled by `n_stacks`, and by
                `n_splices` on the non-CNN path)
            attn_type: attention type passed to every encoder block
            attn_n_heads: number of heads passed to every encoder block
            n_layers: number of self-attention blocks
            d_model: model size; also the final `_output_dim`
            d_ff: feed-forward size passed to every encoder block
            pe_type: positional-encoding type; a falsy value skips it
            dropout_in: dropout probability of the positional encoding
            dropout: dropout probability passed to every encoder block
            dropout_att: attention dropout passed to every encoder block
            layer_norm_eps: epsilon of all layer normalizations
            n_stacks: frame stacking factor (must be 1 with a CNN)
            n_splices: frame splicing factor (must be 1 with a CNN)
            conv_in_channel: input channels of the CNN
            conv_channels: '_'-separated channel sizes; falsy disables the CNN
            conv_kernel_sizes: '_'-separated "(a,b)" pairs
            conv_strides: '_'-separated "(a,b)" pairs
            conv_poolings: '_'-separated "(a,b)" pairs
            conv_batch_norm: passed to the CNN as `batch_norm`
            conv_residual: passed to the CNN as `residual`
            conv_bottleneck_dim: accepted but not used; the CNN bottleneck is
                always `d_model`

        """

        super(TransformerEncoder, self).__init__()

        self.d_model = d_model
        self.pe_type = pe_type

        def _to_pairs(spec):
            """Parse "(a,b)_(c,d)" into [[a, b], [c, d]] ([] when empty)."""
            if not len(spec) > 0:
                return []
            pairs = []
            for token in spec.split('_'):
                fields = token.split(',')
                pairs.append([
                    int(fields[0].replace('(', '')),
                    int(fields[1].replace(')', '')),
                ])
            return pairs

        # Setting for CNNs before the self-attention layers
        channels, kernel_sizes, strides, poolings = [], [], [], []
        if conv_channels:
            if len(conv_channels) > 0:
                channels = [int(c) for c in conv_channels.split('_')]
            kernel_sizes = _to_pairs(conv_kernel_sizes)
            strides = _to_pairs(conv_strides)
            poolings = _to_pairs(conv_poolings)

        if channels:
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim * n_stacks,
                                    in_channel=conv_in_channel,
                                    channels=channels,
                                    kernel_sizes=kernel_sizes,
                                    strides=strides,
                                    poolings=poolings,
                                    dropout=0,
                                    batch_norm=conv_batch_norm,
                                    residual=conv_residual,
                                    bottleneck_dim=d_model)
            self._output_dim = self.conv.output_dim
        else:
            self._output_dim = input_dim * n_splices * n_stacks
            self.conv = None

            # NOTE: do not apply dropout here
            self.embed_in = LinearND(self._output_dim, d_model, dropout=0)

        if pe_type:
            self.pos_emb_in = PositionalEncoding(d_model, dropout_in, pe_type)
        self.layer_norm_in = nn.LayerNorm(d_model, eps=layer_norm_eps)

        # Self-attention layers
        blocks = []
        for _ in range(n_layers):
            blocks.append(
                TransformerEncoderBlock(d_model, d_ff, attn_type,
                                        attn_n_heads, dropout, dropout_att,
                                        layer_norm_eps))
        self.layers = nn.ModuleList(blocks)
        self.layer_norm_top = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self._output_dim = d_model
コード例 #14
0
    def __init__(self,
                 key_dim,
                 query_dim,
                 attn_type,
                 attn_dim,
                 sharpening_factor=1,
                 sigmoid_smoothing=False,
                 conv_out_channels=10,
                 conv_kernel_size=100,
                 dropout=0):
        """Create the parameters of a single-head attention mechanism.

        Args:
            key_dim: size of the keys
            query_dim: size of the queries
            attn_type: one of 'no', 'add', 'location', 'dot', 'luong_dot',
                'luong_general', 'luong_concat'
            attn_dim: size of the attention space
            sharpening_factor: stored as `self.sharpening_factor`
            sigmoid_smoothing: stored as `self.sigmoid_smoothing`
            conv_out_channels: output channels of the location convolution
            conv_kernel_size: half width of the location convolution kernel
                (the kernel spans `2 * conv_kernel_size + 1` positions)
            dropout: probability of the attention dropout

        Raises:
            ValueError: if `attn_type` is not one of the supported types

        """

        super(AttentionMechanism, self).__init__()

        self.attn_type = attn_type
        self.attn_dim = attn_dim
        self.sharpening_factor = sharpening_factor
        self.sigmoid_smoothing = sigmoid_smoothing
        self.n_heads = 1
        self.key = None
        self.mask = None

        # attention dropout applied AFTER the softmax layer
        self.attn_dropout = nn.Dropout(p=dropout)

        if attn_type in ('no', 'luong_dot'):
            # 'no': sequence-to-sequence without attention (use the last
            #       state as a context vector)
            # 'luong_dot': no additional parameters
            pass

        elif attn_type in ('add', 'location'):
            self.w_key = LinearND(key_dim, attn_dim, bias=True)
            self.w_query = LinearND(query_dim, attn_dim, bias=False)
            if attn_type == 'location':
                # extra layers that take the location convolution features
                self.w_conv = LinearND(conv_out_channels,
                                       attn_dim,
                                       bias=False)
                kernel_width = conv_kernel_size * 2 + 1
                self.conv = nn.Conv2d(in_channels=1,
                                      out_channels=conv_out_channels,
                                      kernel_size=(1, kernel_width),
                                      stride=1,
                                      padding=(0, conv_kernel_size),
                                      bias=False)
            self.v = LinearND(attn_dim, 1, bias=False)

        elif attn_type == 'dot':
            self.w_key = LinearND(key_dim, attn_dim, bias=False)
            self.w_query = LinearND(query_dim, attn_dim, bias=False)

        elif attn_type == 'luong_general':
            self.w_key = LinearND(key_dim, query_dim, bias=False)

        elif attn_type == 'luong_concat':
            self.w = LinearND(key_dim + query_dim, attn_dim, bias=False)
            self.v = LinearND(attn_dim, 1, bias=False)

        else:
            raise ValueError(attn_type)
コード例 #15
0
ファイル: transformer.py プロジェクト: ChavesLiu/neural_sp
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 attn_type,
                 attn_n_heads,
                 n_layers,
                 d_model,
                 d_ff,
                 vocab,
                 tie_embedding=False,
                 pe_type='add',
                 layer_norm_eps=1e-6,
                 dropout=0.0,
                 dropout_emb=0.0,
                 dropout_att=0.0,
                 lsm_prob=0.0,
                 focal_loss_weight=0.0,
                 focal_loss_gamma=2.0,
                 ctc_weight=0.0,
                 ctc_lsm_prob=0.0,
                 ctc_fc_list=[],
                 backward=False,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 adaptive_softmax=False):
        """Build the Transformer decoder.

        Args:
            eos: end-of-sentence index (stored; also passed to the CTC module)
            unk: unknown-token index (stored)
            pad: padding index (stored; used as `ignore_index` of the embedding)
            blank: blank index (stored; also passed to the CTC module)
            enc_n_units: encoder output size (stored; also passed to CTC)
            attn_type: attention type passed to every decoder block
            attn_n_heads: number of heads passed to every decoder block
            n_layers: number of decoder blocks
            d_model: model size (embedding and output-layer input size)
            d_ff: feed-forward size passed to every decoder block
            vocab: vocabulary size
            tie_embedding: share the output layer weight with the embedding
                (only applied without adaptive softmax)
            pe_type: positional-encoding type
            layer_norm_eps: epsilon of the layer normalizations
            dropout: dropout probability of the decoder blocks and CTC
            dropout_emb: dropout probability of the positional encoding
            dropout_att: attention dropout of the decoder blocks
            lsm_prob: stored as `self.lsm_prob`
            focal_loss_weight: stored as `self.focal_loss_weight`
            focal_loss_gamma: stored as `self.focal_loss_gamma`
            ctc_weight: a CTC module is built when this is positive; the
                attention decoder is built when it is below `global_weight`
            ctc_lsm_prob: passed to CTC as `lsm_prob`
            ctc_fc_list: passed to CTC as `fc_list` (never mutated here)
            backward: stored as `self.bwd`
            global_weight: see `ctc_weight`
            mtl_per_batch: stored as `self.mtl_per_batch`
            adaptive_softmax: use `nn.AdaptiveLogSoftmaxWithLoss` instead of
                a plain output layer

        """

        super(TransformerDecoder, self).__init__()

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        # Must be set before the adaptive-softmax cutoffs below, which are
        # derived from self.vocab.
        self.vocab = vocab
        self.enc_n_units = enc_n_units
        self.d_model = d_model
        self.n_layers = n_layers
        self.attn_n_heads = attn_n_heads
        self.pe_type = pe_type
        self.lsm_prob = lsm_prob
        self.focal_loss_weight = focal_loss_weight
        self.focal_loss_gamma = focal_loss_gamma
        self.ctc_weight = ctc_weight
        self.bwd = backward
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        if ctc_weight > 0:
            self.ctc = CTC(eos=eos,
                           blank=blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=0.1)

        if ctc_weight < global_weight:
            self.embed = Embedding(
                vocab,
                d_model,
                dropout=0,  # NOTE: do not apply dropout here
                ignore_index=pad)
            self.pos_enc = PositionalEncoding(d_model, dropout_emb, pe_type)
            self.layers = nn.ModuleList([
                TransformerDecoderBlock(d_model, d_ff, attn_type, attn_n_heads,
                                        dropout, dropout_att, layer_norm_eps)
                for _ in range(n_layers)
            ])
            self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

            if adaptive_softmax:
                self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                    d_model,
                    vocab,
                    cutoffs=[
                        round(self.vocab / 15), 3 * round(self.vocab / 15)
                    ],
                    # cutoffs=[self.vocab // 25, 3 * self.vocab // 5],
                    div_value=4.0)
                self.output = None
            else:
                self.adaptive_softmax = None
                self.output = LinearND(d_model, vocab)

                # Optionally tie weights as in:
                # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
                # https://arxiv.org/abs/1608.05859
                # and
                # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
                # https://arxiv.org/abs/1611.01462
                if tie_embedding:
                    self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters()
コード例 #16
0
    def __init__(self, args):
        """Build the RNN language model from parsed options.

        Args:
            args: option namespace. Attributes read here: `emb_dim`,
                `lm_type`, `n_units`, `n_projs`, `n_layers`, `residual`,
                `use_glu`, `vocab`, `dropout_emb`, `dropout_hidden`,
                `dropout_out`, `adaptive_softmax`, `tie_embedding`,
                `param_init`, `param_init_dist` and `rec_weight_orthogonal`.

        Raises:
            AssertionError: if `args.lm_type` is neither 'lstm' nor 'gru'
            ValueError: if weights are tied while `n_units != emb_dim`

        """

        super(ModelBase, self).__init__()

        self.emb_dim = args.emb_dim
        self.rnn_type = args.lm_type
        assert args.lm_type in ['lstm', 'gru']
        self.n_units = args.n_units
        self.n_layers = args.n_layers
        self.residual = args.residual
        self.use_glu = args.use_glu

        self.vocab = args.vocab
        # NOTE: special token indices are reserved in advance
        self.eos = 2
        self.pad = 3

        # for cache
        self.cache_theta = 0.2  # smoothing parameter
        self.cache_lambda = 0.2  # cache weight
        self.cache_ids = []
        self.cache_keys = []
        self.cache_attn = []

        self.embed = Embedding(vocab=self.vocab,
                               emb_dim=args.emb_dim,
                               dropout=args.dropout_emb,
                               ignore_index=self.pad)

        # Settings shared by every recurrent layer built below
        common = dict(bias=True, batch_first=True, bidirectional=False)

        self.fast_impl = False
        if args.n_projs == 0 and not args.residual:
            # One stacked RNN covers all layers
            self.fast_impl = True
            if 'lstm' in args.lm_type:
                rnn_cls = nn.LSTM
            elif 'gru' in args.lm_type:
                rnn_cls = nn.GRU
            else:
                raise ValueError('rnn_type must be "(b)lstm" or "(b)gru".')

            self.rnn = rnn_cls(args.emb_dim,
                               args.n_units,
                               args.n_layers,
                               dropout=args.dropout_hidden,
                               **common)
            # NOTE: pytorch introduces a dropout layer on the outputs of each layer EXCEPT the last layer
            feat_dim = args.n_units
            self.dropout_top = nn.Dropout(p=args.dropout_hidden)
        else:
            # Layers are built one by one, with a projection after every
            # layer except the last when n_projs > 0
            self.rnn = nn.ModuleList()
            self.dropout = nn.ModuleList()
            if args.n_projs > 0:
                self.proj = nn.ModuleList()
            feat_dim = args.emb_dim
            for layer_idx in range(args.n_layers):
                layer_cls = getattr(nn, args.lm_type.upper())
                self.rnn.append(
                    layer_cls(feat_dim, args.n_units, 1, dropout=0, **common))
                self.dropout.append(nn.Dropout(p=args.dropout_hidden))
                feat_dim = args.n_units

                is_last = layer_idx == self.n_layers - 1
                if args.n_projs > 0 and not is_last:
                    self.proj.append(LinearND(feat_dim, args.n_projs))
                    feat_dim = args.n_projs

        if self.use_glu:
            self.fc_glu = LinearND(feat_dim,
                                   feat_dim * 2,
                                   dropout=args.dropout_hidden)

        if args.adaptive_softmax:
            # cutoffs=[self.vocab // 10, 3 * self.vocab // 10],
            cutoffs = [self.vocab // 25, self.vocab // 5]
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                feat_dim, self.vocab, cutoffs=cutoffs, div_value=4.0)
            self.output = None
        else:
            self.adaptive_softmax = None
            # NOTE: include bias even when tying weights
            self.output = LinearND(feat_dim,
                                   self.vocab,
                                   dropout=args.dropout_out)

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if args.tie_embedding:
                if args.n_units != args.emb_dim:
                    raise ValueError(
                        'When using the tied flag, n_units must be equal to emb_dim.'
                    )
                self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters(args.param_init, dist=args.param_init_dist)

        # Initialize bias vectors with zero
        self.reset_parameters(0, dist='constant', keys=['bias'])

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init,
                                  dist='orthogonal',
                                  keys=['rnn', 'weight'])
コード例 #17
0
    def __init__(self, args, save_path=None):
        """Build the Transformer language model from parsed options.

        Args:
            args: option namespace. Attributes read here: `d_model`, `d_ff`,
                `pe_type`, `n_layers`, `attn_n_heads`, `attn_type`,
                `tie_embedding`, `vocab`, `dropout_emb`, `dropout_hidden`,
                `dropout_att`, `layer_norm_eps` and `adaptive_softmax`.
            save_path: stored as `self.save_path`

        """

        super(LMBase, self).__init__()
        logger = logging.getLogger('training')
        logger.info(self.__class__.__name__)

        self.save_path = save_path

        self.d_model = args.d_model
        self.d_ff = args.d_ff
        self.pe_type = args.pe_type
        self.n_layers = args.n_layers
        self.n_heads = args.attn_n_heads
        self.tie_embedding = args.tie_embedding

        self.vocab = args.vocab
        # NOTE: special token indices are reserved in advance
        self.eos = 2
        self.pad = 3

        # self.lsm_prob = lsm_prob

        # for cache
        self.cache_theta = 0.2  # smoothing parameter
        self.cache_lambda = 0.2  # cache weight
        self.cache_ids = []
        self.cache_keys = []
        self.cache_attn = []

        # NOTE: no dropout inside the embedding itself
        self.embed = Embedding(vocab=self.vocab,
                               emb_dim=self.d_model,
                               dropout=0,
                               ignore_index=self.pad)
        self.pos_enc = PositionalEncoding(args.d_model, args.dropout_emb,
                                          args.pe_type)

        # Decoder blocks without source attention
        blocks = []
        for _ in range(self.n_layers):
            blocks.append(
                TransformerDecoderBlock(args.d_model,
                                        args.d_ff,
                                        args.attn_type,
                                        args.attn_n_heads,
                                        args.dropout_hidden,
                                        args.dropout_att,
                                        args.layer_norm_eps,
                                        src_attention=False))
        self.layers = nn.ModuleList(blocks)
        self.norm_out = nn.LayerNorm(args.d_model, eps=args.layer_norm_eps)

        if args.adaptive_softmax:
            # cutoffs=[self.vocab // 25, 3 * self.vocab // 5],
            cutoffs = [round(self.vocab / 15), 3 * round(self.vocab / 15)]
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                args.d_model, self.vocab, cutoffs=cutoffs, div_value=4.0)
            self.output = None
        else:
            self.adaptive_softmax = None
            self.output = LinearND(self.d_model, self.vocab)

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if args.tie_embedding:
                self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters()
コード例 #18
0
ファイル: seq2seq.py プロジェクト: xdcesc/neural_sp
    def __init__(self, args):

        super(ModelBase, self).__init__()

        # for encoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.enc_type = args.enc_type
        self.enc_n_units = args.enc_n_units
        if args.enc_type in ['blstm', 'bgru']:
            self.enc_n_units *= 2
        self.bridge_layer = args.bridge_layer

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for attention layer
        self.attn_n_heads = args.attn_n_heads

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = 1 - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # Feature extraction
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim,
                                              n_units=512,
                                              n_layers=3,
                                              bottleneck_dim=100,
                                              dropout=0)

        # Encoder
        if args.enc_type == 'transformer':
            self.enc = TransformerEncoder(
                input_dim=args.input_dim
                if args.input_type == 'speech' else args.emb_dim,
                attn_type=args.transformer_attn_type,
                attn_n_heads=args.transformer_attn_n_heads,
                n_layers=args.transformer_enc_n_layers,
                d_model=args.d_model,
                d_ff=args.d_ff,
                # pe_type=args.pe_type,
                pe_type=False,
                dropout_in=args.dropout_in,
                dropout=args.dropout_enc,
                dropout_att=args.dropout_att,
                layer_norm_eps=args.layer_norm_eps,
                n_stacks=args.n_stacks,
                n_splices=args.n_splices,
                conv_in_channel=args.conv_in_channel,
                conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes,
                conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings,
                conv_batch_norm=args.conv_batch_norm,
                conv_residual=args.conv_residual,
                conv_bottleneck_dim=args.conv_bottleneck_dim)
        else:
            self.enc = RNNEncoder(
                input_dim=args.input_dim
                if args.input_type == 'speech' else args.emb_dim,
                rnn_type=args.enc_type,
                n_units=args.enc_n_units,
                n_projs=args.enc_n_projs,
                n_layers=args.enc_n_layers,
                n_layers_sub1=args.enc_n_layers_sub1,
                n_layers_sub2=args.enc_n_layers_sub2,
                dropout_in=args.dropout_in,
                dropout=args.dropout_enc,
                subsample=list(map(int, args.subsample.split('_'))) + [1] *
                (args.enc_n_layers - len(args.subsample.split('_'))),
                subsample_type=args.subsample_type,
                n_stacks=args.n_stacks,
                n_splices=args.n_splices,
                conv_in_channel=args.conv_in_channel,
                conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes,
                conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings,
                conv_batch_norm=args.conv_batch_norm,
                conv_residual=args.conv_residual,
                conv_bottleneck_dim=args.conv_bottleneck_dim,
                residual=args.enc_residual,
                nin=args.enc_nin,
                task_specific_layer=args.task_specific_layer)
            # NOTE: pure CNN/TDS encoders are also included

        if args.freeze_encoder:
            for p in self.enc.parameters():
                p.requires_grad = False

        # Bridge layer between the encoder and decoder
        self.is_bridge = False
        if (args.enc_type in ['conv', 'tds', 'gated_conv', 'transformer']
                and args.ctc_weight < 1
            ) or args.dec_type == 'transformer' or args.bridge_layer:
            self.bridge = LinearND(self.enc.output_dim,
                                   args.d_model if args.dec_type
                                   == 'transformer' else args.dec_n_units,
                                   dropout=args.dropout_enc)
            self.is_bridge = True
            if self.sub1_weight > 0:
                self.bridge_sub1 = LinearND(self.enc.output_dim,
                                            args.dec_n_units,
                                            dropout=args.dropout_enc)
            if self.sub2_weight > 0:
                self.bridge_sub2 = LinearND(self.enc.output_dim,
                                            args.dec_n_units,
                                            dropout=args.dropout_enc)
            self.enc_n_units = args.dec_n_units

        # main task
        directions = []
        if self.fwd_weight > 0 or self.ctc_weight > 0:
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            # Cold fusion
            if args.lm_fusion and dir == 'fwd':
                lm = RNNLM(args.lm_conf)
                lm, _ = load_checkpoint(lm, args.lm_fusion)
            else:
                args.lm_conf = False
                lm = None
            # TODO(hirofumi): cold fusion for backward RNNLM

            # Decoder
            if args.dec_type == 'transformer':
                dec = TransformerDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.transformer_attn_type,
                    attn_n_heads=args.transformer_attn_n_heads,
                    n_layers=args.transformer_dec_n_layers,
                    d_model=args.d_model,
                    d_ff=args.d_ff,
                    pe_type=args.pe_type,
                    tie_embedding=args.tie_embedding,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    lsm_prob=args.lsm_prob,
                    layer_norm_eps=args.layer_norm_eps,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
                    backward=(dir == 'bwd'),
                    global_weight=self.main_weight -
                    self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch)
            else:
                dec = RNNDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.attn_type,
                    attn_dim=args.attn_dim,
                    attn_sharpening_factor=args.attn_sharpening,
                    attn_sigmoid_smoothing=args.attn_sigmoid,
                    attn_conv_out_channels=args.attn_conv_n_channels,
                    attn_conv_kernel_size=args.attn_conv_width,
                    attn_n_heads=args.attn_n_heads,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    loop_type=args.dec_loop_type,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    tie_embedding=args.tie_embedding,
                    vocab=self.vocab,
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    ss_prob=args.ss_prob,
                    ss_type=args.ss_type,
                    lsm_prob=args.lsm_prob,
                    fl_weight=args.focal_loss_weight,
                    fl_gamma=args.focal_loss_gamma,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_fc_list=[
                        int(fc) for fc in args.ctc_fc_list.split('_')
                    ] if args.ctc_fc_list is not None
                    and len(args.ctc_fc_list) > 0 else [],
                    input_feeding=args.input_feeding,
                    backward=(dir == 'bwd'),
                    # lm=args.lm_conf,
                    lm=lm,  # TODO(hirofumi): load RNNLM in the model init.
                    lm_fusion_type=args.lm_fusion_type,
                    contextualize=args.contextualize,
                    lm_init=args.lm_init,
                    lmobj_weight=args.lmobj_weight,
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=self.main_weight -
                    self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch,
                    adaptive_softmax=args.adaptive_softmax)
            setattr(self, 'dec_' + dir, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                if args.dec_type == 'transformer':
                    raise NotImplementedError
                else:
                    dec_sub = RNNDecoder(
                        eos=self.eos,
                        unk=self.unk,
                        pad=self.pad,
                        blank=self.blank,
                        enc_n_units=self.enc_n_units,
                        attn_type=args.attn_type,
                        attn_dim=args.attn_dim,
                        attn_sharpening_factor=args.attn_sharpening,
                        attn_sigmoid_smoothing=args.attn_sigmoid,
                        attn_conv_out_channels=args.attn_conv_n_channels,
                        attn_conv_kernel_size=args.attn_conv_width,
                        attn_n_heads=1,
                        rnn_type=args.dec_type,
                        n_units=args.dec_n_units,
                        n_projs=args.dec_n_projs,
                        n_layers=args.dec_n_layers,
                        loop_type=args.dec_loop_type,
                        residual=args.dec_residual,
                        bottleneck_dim=args.dec_bottleneck_dim,
                        emb_dim=args.emb_dim,
                        tie_embedding=args.tie_embedding,
                        vocab=getattr(self, 'vocab_' + sub),
                        dropout=args.dropout_dec,
                        dropout_emb=args.dropout_emb,
                        dropout_att=args.dropout_att,
                        ss_prob=args.ss_prob,
                        ss_type=args.ss_type,
                        lsm_prob=args.lsm_prob,
                        fl_weight=args.focal_loss_weight,
                        fl_gamma=args.focal_loss_gamma,
                        ctc_weight=getattr(self, 'ctc_weight_' + sub),
                        ctc_fc_list=[
                            int(fc) for fc in getattr(args, 'ctc_fc_list_' +
                                                      sub).split('_')
                        ] if getattr(args, 'ctc_fc_list_' + sub) is not None
                        and len(getattr(args, 'ctc_fc_list_' + sub)) > 0 else
                        [],
                        input_feeding=args.input_feeding,
                        global_weight=getattr(self, sub + '_weight'),
                        mtl_per_batch=args.mtl_per_batch)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed_in = dec.embed
            else:
                self.embed_in = Embedding(vocab=args.vocab_sub1,
                                          emb_dim=args.emb_dim,
                                          dropout=args.dropout_emb,
                                          ignore_index=self.pad)

        # Initialize parameters in CNN layers
        self.reset_parameters(
            args.param_init,
            #   dist='xavier_uniform',
            #   dist='kaiming_uniform',
            dist='lecun',
            keys=['conv'],
            ignore_keys=['score'])

        # Initialize parameters in the encoder
        if args.enc_type == 'transformer':
            self.reset_parameters(args.param_init,
                                  dist='xavier_uniform',
                                  keys=['enc'],
                                  ignore_keys=['embed_in'])
            self.reset_parameters(args.d_model**-0.5,
                                  dist='normal',
                                  keys=['embed_in'])
        else:
            self.reset_parameters(args.param_init,
                                  dist=args.param_init_dist,
                                  keys=['enc'],
                                  ignore_keys=['conv'])

        # Initialize parameters in the decoder
        if args.dec_type == 'transformer':
            self.reset_parameters(args.param_init,
                                  dist='xavier_uniform',
                                  keys=['dec'],
                                  ignore_keys=['embed'])
            self.reset_parameters(args.d_model**-0.5,
                                  dist='normal',
                                  keys=['embed'])
        else:
            self.reset_parameters(args.param_init,
                                  dist=args.param_init_dist,
                                  keys=['dec'])

        # Initialize bias vectors with zero
        self.reset_parameters(0, dist='constant', keys=['bias'])

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init,
                                  dist='orthogonal',
                                  keys=['rnn', 'weight'])

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Initialize bias in gating with -1 for cold fusion
        if args.lm_fusion:
            self.reset_parameters(-1,
                                  dist='constant',
                                  keys=['linear_lm_gate.fc.bias'])

        if args.lm_fusion_type == 'deep' and args.lm_fusion:
            for n, p in self.named_parameters():
                if 'output' in n or 'output_bn' in n or 'linear' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
コード例 #19
0
    def __init__(self, args, save_path=None):
        """Build a gated-convolutional language model from parsed options.

        The network is: input embedding -> a stack of ``GLUBlock`` modules
        selected by ``args.lm_type`` -> either an adaptive softmax or a
        linear output layer (optionally sharing its weight with the
        embedding).

        Args:
            args: option namespace. Attributes read here are ``emb_dim``,
                ``n_units``, ``n_layers``, ``lsm_prob``, ``vocab``,
                ``dropout_in``, ``dropout_hidden``, ``lm_type``,
                ``adaptive_softmax`` and ``param_init``. ``kernel_size``
                and ``n_projs`` are read only for the ``custom`` size;
                ``dropout_out`` and ``tie_embedding`` only when the
                adaptive softmax is disabled.
            save_path: stored as ``self.save_path``; not used otherwise in
                this constructor.

        Raises:
            NotImplementedError: if ``args.lm_type`` (after removing the
                ``gated_conv_`` prefix) is not one of ``custom``, ``8``,
                ``8B``, ``9``, ``13``, ``14``, ``14B``.
            ValueError: if ``args.tie_embedding`` is set but the embedding
                size does not match ``n_units`` or the width of the last
                block.
        """

        super(LMBase, self).__init__()
        logger = logging.getLogger('training')
        logger.info(self.__class__.__name__)

        self.save_path = save_path

        self.emb_dim = args.emb_dim
        self.n_units = args.n_units
        self.n_layers = args.n_layers
        self.lsm_prob = args.lsm_prob

        self.vocab = args.vocab
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for cache
        self.cache_theta = 0.2  # smoothing parameter
        self.cache_lambda = 0.2  # cache weight
        self.cache_ids = []
        self.cache_keys = []
        self.cache_attn = []

        self.embed = Embedding(vocab=self.vocab,
                               emb_dim=args.emb_dim,
                               dropout=args.dropout_in,
                               ignore_index=self.pad)

        # 'gated_conv_14B' -> '14B', 'gated_conv_custom' -> 'custom', ...
        model_size = args.lm_type.replace('gated_conv_', '')

        # NOTE: the block names below become sub-module names of
        # `self.blocks`, so renaming or reordering them changes the keys of
        # saved checkpoints.
        # NOTE(review): `bottlececk_dim` (sic) is the keyword every call in
        # this constructor uses; presumably it matches GLUBlock's signature,
        # so do not "correct" the spelling without checking GLUBlock.
        # The positional arguments look like (kernel size, input width,
        # output width) -- confirm against GLUBlock.
        blocks = OrderedDict()
        if model_size == 'custom':
            blocks['conv1'] = GLUBlock(args.kernel_size,
                                       args.emb_dim,
                                       args.n_units,
                                       bottlececk_dim=args.n_projs,
                                       dropout=args.dropout_hidden)
            for layer_idx in range(args.n_layers - 1):
                blocks['conv%d' % (layer_idx + 2)] = GLUBlock(
                    args.kernel_size,
                    args.n_units,
                    args.n_units,
                    bottlececk_dim=args.n_projs,
                    dropout=args.dropout_hidden)
            last_dim = args.n_units

        elif model_size == '8':
            blocks['conv1'] = GLUBlock(4,
                                       args.emb_dim,
                                       900,
                                       dropout=args.dropout_hidden)
            for i in range(1, 8):
                blocks['conv2-%d' % i] = GLUBlock(4,
                                                  900,
                                                  900,
                                                  dropout=args.dropout_hidden)
            last_dim = 900

        elif model_size == '8B':
            blocks['conv1'] = GLUBlock(1,
                                       args.emb_dim,
                                       512,
                                       dropout=args.dropout_hidden)
            for i in range(1, 4):
                blocks['conv2-%d' % i] = GLUBlock(5,
                                                  512,
                                                  512,
                                                  bottlececk_dim=128,
                                                  dropout=args.dropout_hidden)
            for i in range(1, 4):
                blocks['conv3-%d' % i] = GLUBlock(5,
                                                  512,
                                                  512,
                                                  bottlececk_dim=256,
                                                  dropout=args.dropout_hidden)
            blocks['conv4'] = GLUBlock(1,
                                       512,
                                       2048,
                                       bottlececk_dim=1024,
                                       dropout=args.dropout_hidden)
            last_dim = 2048

        elif model_size == '9':
            blocks['conv1'] = GLUBlock(4,
                                       args.emb_dim,
                                       807,
                                       dropout=args.dropout_hidden)
            for i in range(1, 4):
                blocks['conv2-%d-1' % i] = GLUBlock(
                    4, 807, 807, dropout=args.dropout_hidden)
                blocks['conv2-%d-2' % i] = GLUBlock(
                    4, 807, 807, dropout=args.dropout_hidden)
            last_dim = 807

        elif model_size == '13':
            blocks['conv1'] = GLUBlock(4,
                                       args.emb_dim,
                                       1268,
                                       dropout=args.dropout_hidden)
            for i in range(1, 13):
                blocks['conv2-%d' % i] = GLUBlock(4,
                                                  1268,
                                                  1268,
                                                  dropout=args.dropout_hidden)
            last_dim = 1268

        elif model_size == '14':
            for i in range(1, 4):
                # Only the first block consumes the embedding output.
                blocks['conv1-%d' % i] = GLUBlock(
                    6,
                    args.emb_dim if i == 1 else 850,
                    850,
                    dropout=args.dropout_hidden)
            blocks['conv2'] = GLUBlock(1,
                                       850,
                                       850,
                                       dropout=args.dropout_hidden)
            for i in range(1, 5):
                blocks['conv3-%d' % i] = GLUBlock(5,
                                                  850,
                                                  850,
                                                  dropout=args.dropout_hidden)
            blocks['conv4'] = GLUBlock(1,
                                       850,
                                       850,
                                       dropout=args.dropout_hidden)
            for i in range(1, 4):
                blocks['conv5-%d' % i] = GLUBlock(4,
                                                  850,
                                                  850,
                                                  dropout=args.dropout_hidden)
            blocks['conv6'] = GLUBlock(4,
                                       850,
                                       1024,
                                       dropout=args.dropout_hidden)
            blocks['conv7'] = GLUBlock(4,
                                       1024,
                                       2048,
                                       dropout=args.dropout_hidden)
            last_dim = 2048

        elif model_size == '14B':
            blocks['conv1'] = GLUBlock(5,
                                       args.emb_dim,
                                       512,
                                       dropout=args.dropout_hidden)
            for i in range(1, 4):
                blocks['conv2-%d' % i] = GLUBlock(5,
                                                  512,
                                                  512,
                                                  bottlececk_dim=128,
                                                  dropout=args.dropout_hidden)
            for i in range(1, 4):
                # The first block of each group widens the representation.
                blocks['conv3-%d' % i] = GLUBlock(5,
                                                  512 if i == 1 else 1024,
                                                  1024,
                                                  bottlececk_dim=512,
                                                  dropout=args.dropout_hidden)
            for i in range(1, 7):
                blocks['conv4-%d' % i] = GLUBlock(5,
                                                  1024 if i == 1 else 2048,
                                                  2048,
                                                  bottlececk_dim=1024,
                                                  dropout=args.dropout_hidden)
            blocks['conv5'] = GLUBlock(5,
                                       2048,
                                       4096,
                                       bottlececk_dim=1024,
                                       dropout=args.dropout_hidden)
            last_dim = 4096

        else:
            raise NotImplementedError(model_size)

        self.blocks = nn.Sequential(blocks)

        if args.adaptive_softmax:
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                last_dim,
                self.vocab,
                # cutoffs=[self.vocab // 10, 3 * self.vocab // 10],
                cutoffs=[self.vocab // 25, self.vocab // 5],
                div_value=4.0)
            self.output = None
        else:
            self.adaptive_softmax = None
            self.output = LinearND(last_dim,
                                   self.vocab,
                                   dropout=args.dropout_out)
            # NOTE: include bias even when tying weights

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if args.tie_embedding:
                if args.n_units != args.emb_dim:
                    raise ValueError(
                        'When using the tied flag, n_units must be equal to emb_dim.'
                    )
                # The output layer is built on `last_dim`, which for every
                # preset other than 'custom' is a hard-coded width and not
                # `n_units`. Sharing the embedding matrix only makes sense
                # when that width equals `emb_dim`, so fail here instead of
                # at the first forward pass.
                if last_dim != args.emb_dim:
                    raise ValueError(
                        'When using the tied flag, the output dimension of '
                        'the last block (%d) must be equal to emb_dim (%d).'
                        % (last_dim, args.emb_dim))
                self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters(args.param_init)
コード例 #20
0
    def __init__(self,
                 input_dim,
                 attn_type,
                 attn_n_heads,
                 n_layers,
                 d_model,
                 d_ff,
                 pe_type='add',
                 layer_norm_eps=1e-6,
                 dropout_in=0,
                 dropout=0,
                 dropout_att=0,
                 last_proj_dim=0,
                 n_stacks=1,
                 n_splices=1,
                 conv_in_channel=1,
                 conv_channels=0,
                 conv_kernel_sizes=[],
                 conv_strides=[],
                 conv_poolings=[],
                 conv_batch_norm=False,
                 conv_residual=False,
                 conv_bottleneck_dim=0,
                 param_init=0.1):
        """Build the Transformer encoder stack.

        Layout: optional ``ConvEncoder`` front-end (or a linear input
        embedding when no CNN is configured) -> positional encoding ->
        ``n_layers`` ``TransformerEncoderBlock`` modules -> layer norm ->
        optional linear projection (``self.bridge``).

        Args:
            input_dim: per-frame input feature size.
            attn_type: forwarded to every ``TransformerEncoderBlock``.
            attn_n_heads: forwarded to every ``TransformerEncoderBlock``.
            n_layers: number of ``TransformerEncoderBlock`` modules.
            d_model: width of the Transformer stack and of ``norm_out``.
            d_ff: forwarded to every ``TransformerEncoderBlock``.
            pe_type: forwarded to ``PositionalEncoding``.
            layer_norm_eps: epsilon for the final ``nn.LayerNorm`` and the
                blocks.
            dropout_in: forwarded to ``PositionalEncoding``.
            dropout: forwarded to the blocks and to the output projection.
            dropout_att: forwarded to the blocks.
            last_proj_dim: if positive and different from ``d_model``, a
                linear projection to this size is appended and becomes the
                encoder's output size. ``0`` (default) disables it.
            n_stacks: multiplies the input size when no CNN is used; must
                be 1 when a CNN is configured.
            n_splices: same role and restriction as ``n_stacks``.
            conv_in_channel: forwarded to ``ConvEncoder``.
            conv_channels: underscore-separated channel sizes such as
                ``"32_32"``; any falsy value disables the CNN front-end.
            conv_kernel_sizes: underscore-separated pairs such as
                ``"(3,3)_(3,3)"``.
            conv_strides: same format as ``conv_kernel_sizes``.
            conv_poolings: same format as ``conv_kernel_sizes``.
            conv_batch_norm: forwarded to ``ConvEncoder``.
            conv_residual: forwarded to ``ConvEncoder``.
            conv_bottleneck_dim: accepted for signature compatibility but
                not used here; ``ConvEncoder`` is given
                ``bottleneck_dim=d_model`` instead.
            param_init: forwarded to ``ConvEncoder`` only;
                ``reset_parameters`` is called without arguments.
        """

        super(TransformerEncoder, self).__init__()

        self.d_model = d_model
        self.n_layers = n_layers
        self.attn_n_heads = attn_n_heads
        self.pe_type = pe_type

        def _parse_pairs(spec):
            """Turn '(a,b)_(c,d)' into [[a, b], [c, d]]; empty spec -> []."""
            if len(spec) == 0:
                return []
            pairs = []
            for item in spec.split('_'):
                fields = item.split(',')
                pairs.append([
                    int(fields[0].replace('(', '')),
                    int(fields[1].replace(')', ''))
                ])
            return pairs

        # Setting for the CNN front-end
        if conv_channels:
            channels = [int(c) for c in conv_channels.split('_')
                        ] if len(conv_channels) > 0 else []
            kernel_sizes = _parse_pairs(conv_kernel_sizes)
            strides = _parse_pairs(conv_strides)
            poolings = _parse_pairs(conv_poolings)
        else:
            # No CNN configured. The previous code logged "Subsampling is
            # automatically ignored because CNN layers are used before RNN
            # layers." here, i.e. exactly when no CNN is used (and this
            # encoder has no RNN layers), so the warning was dropped.
            channels = []
            kernel_sizes = []
            strides = []
            poolings = []

        if len(channels) > 0:
            assert n_stacks == 1 and n_splices == 1
            # NOTE(review): bottleneck_dim=d_model presumably makes the CNN
            # output d_model wide, since no input embedding is created on
            # this path -- confirm against ConvEncoder.
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=channels,
                                    kernel_sizes=kernel_sizes,
                                    strides=strides,
                                    poolings=poolings,
                                    dropout=0,
                                    batch_norm=conv_batch_norm,
                                    residual=conv_residual,
                                    bottleneck_dim=d_model,
                                    param_init=param_init)
            self._output_dim = self.conv.output_dim
        else:
            self._output_dim = input_dim * n_splices * n_stacks
            self.conv = None

            self.embed = LinearND(self._output_dim, d_model,
                                  dropout=0)  # NOTE: do not apply dropout here

        self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, d_ff, attn_type, attn_n_heads,
                                    dropout, dropout_att, layer_norm_eps)
            for _ in range(n_layers)
        ])
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

        # From here on the representation is d_model wide. The output
        # projection must therefore take d_model features as input, and it
        # is only created when a positive target size was requested (the
        # default last_proj_dim=0 used to create a zero-width projection).
        self._output_dim = d_model
        if last_proj_dim > 0 and last_proj_dim != d_model:
            self.bridge = LinearND(d_model,
                                   last_proj_dim,
                                   dropout=dropout)
            self._output_dim = last_proj_dim
        else:
            self.bridge = None

        # Initialize parameters
        self.reset_parameters()