Code example #1
    def _create_ff_module(self, subnet_unit: ReturnnNetwork, prefix, source):
        ff_ln = subnet_unit.add_layer_norm_layer('{}_ff_ln'.format(prefix),
                                                 source)

        ff1 = subnet_unit.add_linear_layer('{}_ff_conv1'.format(prefix),
                                           ff_ln,
                                           activation=self.ff_act,
                                           forward_weights_init=self.ff_init,
                                           n_out=self.ff_dim,
                                           with_bias=True,
                                           l2=self.l2)

        ff2 = subnet_unit.add_linear_layer('{}_ff_conv2'.format(prefix),
                                           ff1,
                                           activation=None,
                                           forward_weights_init=self.ff_init,
                                           n_out=self.enc_value_dim,
                                           dropout=self.dropout,
                                           with_bias=True,
                                           l2=self.l2)

        drop = subnet_unit.add_dropout_layer('{}_ff_drop'.format(prefix),
                                             ff2,
                                             dropout=self.dropout)

        out = subnet_unit.add_combine_layer('{}_ff_out'.format(prefix),
                                            [drop, source],
                                            kind='add',
                                            n_out=self.enc_value_dim)
        return out
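
A minimal NumPy sketch of what this helper wires up (layer norm, two linear layers with an activation in between, dropout, and a residual add) may help as a reference. It is independent of RETURNN; the ReLU activation and inverted dropout here are illustrative stand-ins for self.ff_act and the layer options:

import numpy as np

def layer_norm(x, eps=1e-6):
    # normalize over the feature axis
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def ff_module(x, w1, b1, w2, b2, dropout=0.1, rng=None):
    # pre-LN feed-forward block with residual: x + Drop(W2 act(W1 LN(x)))
    h = layer_norm(x)
    h = np.maximum(h @ w1 + b1, 0.0)  # activation (ReLU as an example)
    h = h @ w2 + b2
    if rng is not None:  # inverted dropout, training mode only
        mask = rng.random(h.shape) >= dropout
        h = h * mask / (1.0 - dropout)
    return x + h  # residual connection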
Code example #2
File: rnn_decoder.py Project: rwth-i6/i6_experiments
    def _add_density_ratio(self, subnet_unit: ReturnnNetwork, lm_subnet,
                           lm_model):
        subnet_unit.add_subnetwork('density_ratio_output',
                                   'prev:output',
                                   subnetwork_net=lm_subnet,
                                   load_on_init=lm_model)
        lm_output_prob = subnet_unit.add_activation_layer(
            'density_ratio_output_prob',
            'density_ratio_output',
            activation='softmax',
            target=self.target)
        return lm_output_prob
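
This helper only builds the source-domain LM branch used by density-ratio fusion: it runs a pretrained LM subnetwork on the previous output label and exposes its softmax; during search its log-probability gets subtracted from a shallow-fusion score. A hedged NumPy sketch of the resulting per-label score (the scale values are hypothetical, and an eps-clipped log stands in for RETURNN's safe_log):

import numpy as np

def density_ratio_score(am_prob, ext_lm_prob, src_lm_prob,
                        lm_scale=0.3, src_lm_scale=0.3, eps=1e-10):
    # log p_AM + lm_scale * log p_target_LM - src_lm_scale * log p_source_LM
    return (np.log(am_prob + eps)
            + lm_scale * np.log(ext_lm_prob + eps)
            - src_lm_scale * np.log(src_lm_prob + eps))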
Code example #3
    def _create_decoder_block(self, subnet_unit: ReturnnNetwork, source, i):
        prefix = 'transformer_decoder_%02i' % i
        masked_mhsa = self._create_masked_mhsa(subnet_unit, prefix, source)
        mhsa = self._create_mhsa(subnet_unit, prefix, masked_mhsa)
        ff = self._create_ff_module(subnet_unit, prefix, mhsa)
        out = subnet_unit.add_copy_layer(prefix, ff)
        return out
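
Each helper appends its layers under the given prefix and returns the name of the block's output layer, so blocks chain naturally. A plausible caller (not part of the excerpt; num_layers and the 'target_embed' input are assumptions for illustration):

    def _create_decoder(self, subnet_unit, num_layers=6):
        # stack num_layers decoder blocks, feeding each output into the next
        x = 'target_embed'
        for i in range(1, num_layers + 1):
            x = self._create_decoder_block(subnet_unit, x, i)
        return x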
Code example #4
    def _create_masked_mhsa(self, subnet_unit: ReturnnNetwork, prefix, source):
        prefix = '{}_self_att'.format(prefix)

        ln = subnet_unit.add_layer_norm_layer('{}_ln'.format(prefix), source)

        ln_rel_pos_enc = None
        if self.pos_enc == 'rel':
            ln_rel_pos_enc = subnet_unit.add_relative_pos_encoding_layer(
                '{}_ln_rel_pos_enc'.format(prefix),
                ln,
                n_out=self.enc_key_per_head_dim,
                forward_weights_init=self.ff_init,
                clipping=self.rel_pos_clipping)

        att = subnet_unit.add_self_att_layer(
            '{}_att'.format(prefix),
            ln,
            num_heads=self.att_num_heads,
            total_key_dim=self.enc_key_dim,
            n_out=self.enc_value_dim,
            attention_left_only=True,
            att_dropout=self.att_dropout,
            forward_weights_init=self.mhsa_init,
            l2=self.l2,
            key_shift=ln_rel_pos_enc)

        linear = subnet_unit.add_linear_layer(
            '{}_linear'.format(prefix),
            att,
            activation=None,
            with_bias=False,
            n_out=self.enc_value_dim,
            forward_weights_init=self.mhsa_init_out,
            l2=self.l2)

        drop = subnet_unit.add_dropout_layer('{}_drop'.format(prefix),
                                             linear,
                                             dropout=self.dropout)

        out = subnet_unit.add_combine_layer('{}_out'.format(prefix),
                                            [drop, source],
                                            kind='add',
                                            n_out=self.enc_value_dim)

        return out
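
attention_left_only=True makes the self-attention causal, i.e. position t may only attend to positions <= t. A minimal single-head NumPy sketch of that masking (no batching, no per-head split, no relative positional encoding):

import numpy as np

def masked_self_attention(x, wq, wk, wv):
    # x: [T, d_model]; causal single-head scaled dot-product attention
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(k.shape[-1])         # [T, T]
    t = scores.shape[0]
    causal = np.tril(np.ones((t, t), dtype=bool))   # allow only j <= i
    scores = np.where(causal, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v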
Code example #5
File: rnn_decoder.py Project: rwth-i6/i6_experiments
    def create_network(self):
        subnet_unit = ReturnnNetwork()

        dec_output = self.add_decoder_subnetwork(subnet_unit)

        # Add to Encoder network

        if hasattr(self.base_model,
                   'enc_proj_dim') and self.base_model.enc_proj_dim:
            self.base_model.network.add_copy_layer('enc_ctx', 'encoder_proj')
            self.base_model.network.add_split_dim_layer(
                'enc_value',
                'encoder_proj',
                dims=(self.att_num_heads,
                      self.enc_value_dim // self.att_num_heads))
        else:
            self.base_model.network.add_linear_layer('enc_ctx',
                                                     'encoder',
                                                     with_bias=True,
                                                     n_out=self.enc_key_dim,
                                                     l2=self.base_model.l2)
            self.base_model.network.add_split_dim_layer(
                'enc_value',
                'encoder',
                dims=(self.att_num_heads,
                      self.enc_value_dim // self.att_num_heads))

        self.base_model.network.add_linear_layer('inv_fertility',
                                                 'encoder',
                                                 activation='sigmoid',
                                                 n_out=self.att_num_heads,
                                                 with_bias=False)

        decision_layer_name = self.base_model.network.add_decide_layer(
            'decision', dec_output, target=self.target)
        self.decision_layer_name = decision_layer_name

        return dec_output
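
The add_split_dim_layer call reshapes the flat encoder output into per-head value vectors of size enc_value_dim // att_num_heads. A tiny NumPy sketch of the equivalent reshape (the concrete sizes are assumed for illustration):

import numpy as np

T, enc_value_dim, att_num_heads = 10, 2048, 8  # assumed example sizes
encoder = np.zeros((T, enc_value_dim))
# split the feature axis into (heads, dim per head), as in add_split_dim_layer
enc_value = encoder.reshape(T, att_num_heads, enc_value_dim // att_num_heads)
assert enc_value.shape == (10, 8, 256)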
Code example #6
File: rnn_decoder.py Project: rwth-i6/i6_experiments
class RNNDecoder:
    """
  Represents an RNN LSTM attention-based decoder.

  Related:
    * Single headed attention based sequence-to-sequence model for state-of-the-art results on Switchboard
      ref: https://arxiv.org/abs/2001.07263
  """
    def __init__(self,
                 base_model,
                 source=None,
                 dropout=0.3,
                 label_smoothing=0.1,
                 target='bpe',
                 beam_size=12,
                 embed_dim=621,
                 embed_dropout=0.,
                 dec_lstm_num_units=1000,
                 dec_output_num_units=1000,
                 l2=None,
                 att_dropout=None,
                 rec_weight_dropout=None,
                 dec_zoneout=False,
                 ff_init=None,
                 add_lstm_lm=False,
                 lstm_lm_dim=1000,
                 loc_conv_att_filter_size=None,
                 loc_conv_att_num_channels=None,
                 reduceout=True,
                 att_num_heads=1,
                 embed_weight_init=None,
                 lstm_weights_init=None):
        """
    :param base_model: base/encoder model instance
    :param str source: input to decoder subnetwork
    :param float dropout: Dropout applied to the softmax input
    :param float label_smoothing: label smoothing value applied to softmax
    :param str target: target data key name
    :param int beam_size: value of the beam size
    :param int embed_dim: target embedding dimension
    :param float|None embed_dropout: dropout to be applied on the target embedding
    :param int dec_lstm_num_units: the number of hidden units for the decoder LSTM
    :param int dec_output_num_units: the number of hidden dimensions for the last layer before softmax
    :param float|None l2: weight decay with l2 norm
    :param float|None att_dropout: dropout applied to attention weights
    :param float|None rec_weight_dropout: dropout applied to weight parameters
    :param bool dec_zoneout: if set, zoneout LSTM cell is used in the decoder instead of nativelstm2
    :param str|None ff_init: feed-forward weights initialization
    :param bool add_lstm_lm: add a separate LSTM layer that acts as an LM-like model,
      same as here: https://arxiv.org/abs/2001.07263
    :param int lstm_lm_dim: hidden dimension of the LM-like LSTM layer
    :param int|None loc_conv_att_filter_size:
    :param int|None loc_conv_att_num_channels:
    :param bool reduceout: if set to True, maxout layer is used
    :param int att_num_heads: number of attention heads
    """

        self.base_model = base_model

        self.source = source

        self.dropout = dropout
        self.label_smoothing = label_smoothing

        self.enc_key_dim = base_model.enc_key_dim
        self.enc_value_dim = base_model.enc_value_dim
        self.att_num_heads = att_num_heads

        self.target = target

        self.beam_size = beam_size

        self.embed_dim = embed_dim
        self.embed_dropout = embed_dropout

        self.dec_lstm_num_units = dec_lstm_num_units
        self.dec_output_num_units = dec_output_num_units

        self.ff_init = ff_init

        self.decision_layer_name = None  # this is set in the end-point config

        self.l2 = l2
        self.att_dropout = att_dropout
        self.rec_weight_dropout = rec_weight_dropout
        self.dec_zoneout = dec_zoneout

        self.add_lstm_lm = add_lstm_lm
        self.lstm_lm_dim = lstm_lm_dim

        self.loc_conv_att_filter_size = loc_conv_att_filter_size
        self.loc_conv_att_num_channels = loc_conv_att_num_channels

        self.embed_weight_init = embed_weight_init
        self.lstm_weights_init = lstm_weights_init

        self.reduceout = reduceout

        self.network = ReturnnNetwork()
        self.subnet_unit = ReturnnNetwork()
        self.dec_output = None

    def add_decoder_subnetwork(self, subnet_unit: ReturnnNetwork):

        subnet_unit.add_compare_layer('end', source='output',
                                      value=0)  # sentence end token

        # target embedding
        subnet_unit.add_linear_layer(
            'target_embed0',
            'output',
            n_out=self.embed_dim,
            initial_output=0,
            with_bias=False,
            l2=self.l2,
            forward_weights_init=self.embed_weight_init)

        subnet_unit.add_dropout_layer('target_embed',
                                      'target_embed0',
                                      dropout=self.embed_dropout,
                                      dropout_noise_shape={'*': None})

        # attention
        att = AttentionMechanism(
            enc_key_dim=self.enc_key_dim,
            att_num_heads=self.att_num_heads,
            att_dropout=self.att_dropout,
            l2=self.l2,
            loc_filter_size=self.loc_conv_att_filter_size,
            loc_num_channels=self.loc_conv_att_num_channels)
        subnet_unit.update(att.create())

        # LM-like component same as here https://arxiv.org/pdf/2001.07263.pdf
        lstm_lm_component = None
        if self.add_lstm_lm:
            lstm_lm_component = subnet_unit.add_rnn_cell_layer(
                'lm_like_s',
                'prev:target_embed',
                n_out=self.lstm_lm_dim,
                l2=self.l2)

        lstm_inputs = []
        if lstm_lm_component:
            lstm_inputs += [lstm_lm_component]
        else:
            lstm_inputs += ['prev:target_embed']
        lstm_inputs += ['prev:att']

        # LSTM decoder (or decoder state)
        if self.dec_zoneout:
            subnet_unit.add_rnn_cell_layer('s',
                                           lstm_inputs,
                                           n_out=self.dec_lstm_num_units,
                                           l2=self.l2,
                                           weights_init=self.lstm_weights_init,
                                           unit='zoneoutlstm',
                                           unit_opts={
                                               'zoneout_factor_cell': 0.15,
                                               'zoneout_factor_output': 0.05
                                           })
        else:
            if self.rec_weight_dropout:
                # a rec layer with unit nativelstm2 is required to use rec_weight_dropout
                subnet_unit.add_rec_layer(
                    's',
                    lstm_inputs,
                    n_out=self.dec_lstm_num_units,
                    l2=self.l2,
                    unit='NativeLSTM2',
                    rec_weight_dropout=self.rec_weight_dropout,
                    weights_init=self.lstm_weights_init)
            else:
                subnet_unit.add_rnn_cell_layer(
                    's',
                    lstm_inputs,
                    n_out=self.dec_lstm_num_units,
                    l2=self.l2,
                    weights_init=self.lstm_weights_init)

        # ASR softmax output layer
        subnet_unit.add_linear_layer('readout_in',
                                     ["s", "prev:target_embed", "att"],
                                     n_out=self.dec_output_num_units,
                                     l2=self.l2)

        if self.reduceout:
            subnet_unit.add_reduceout_layer('readout', 'readout_in')
        else:
            subnet_unit.add_copy_layer('readout', 'readout_in')

        output_prob = subnet_unit.add_softmax_layer(
            'output_prob',
            'readout',
            l2=self.l2,
            loss='ce',
            loss_opts={'label_smoothing': self.label_smoothing},
            target=self.target,
            dropout=self.dropout)

        subnet_unit.add_choice_layer('output',
                                     output_prob,
                                     target=self.target,
                                     beam_size=self.beam_size,
                                     initial_output=0)

        # recurrent subnetwork
        dec_output = self.network.add_subnet_rec_layer(
            'output',
            unit=subnet_unit.get_net(),
            target=self.target,
            source=self.source)

        return dec_output

    def create_network(self):
        self.dec_output = self.add_decoder_subnetwork(self.subnet_unit)

        # Add to Base/Encoder network

        if hasattr(self.base_model,
                   'enc_proj_dim') and self.base_model.enc_proj_dim:
            self.base_model.network.add_copy_layer('enc_ctx', 'encoder_proj')
            self.base_model.network.add_split_dim_layer(
                'enc_value',
                'encoder_proj',
                dims=(self.att_num_heads,
                      self.enc_value_dim // self.att_num_heads))
        else:
            self.base_model.network.add_linear_layer('enc_ctx',
                                                     'encoder',
                                                     with_bias=True,
                                                     n_out=self.enc_key_dim,
                                                     l2=self.base_model.l2)
            self.base_model.network.add_split_dim_layer(
                'enc_value',
                'encoder',
                dims=(self.att_num_heads,
                      self.enc_value_dim // self.att_num_heads))

        self.base_model.network.add_linear_layer('inv_fertility',
                                                 'encoder',
                                                 activation='sigmoid',
                                                 n_out=self.att_num_heads,
                                                 with_bias=False)

        decision_layer_name = self.base_model.network.add_decide_layer(
            'decision', self.dec_output, target=self.target)
        self.decision_layer_name = decision_layer_name

        return self.dec_output
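
With reduceout=True, add_reduceout_layer applies a maxout reduction to readout_in before the softmax. A NumPy sketch of that reduction, assuming the usual split into two pieces per output unit (the actual number of pieces is decided by the layer options):

import numpy as np

def reduce_out_maxout(readout_in, num_pieces=2):
    # maxout over groups of num_pieces units: [..., D] -> [..., D // num_pieces]
    d = readout_in.shape[-1]
    assert d % num_pieces == 0
    grouped = readout_in.reshape(*readout_in.shape[:-1], d // num_pieces, num_pieces)
    return grouped.max(axis=-1)

x = np.random.randn(4, 1000)  # e.g. dec_output_num_units = 1000
assert reduce_out_maxout(x).shape == (4, 500)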
Code example #7
File: rnn_decoder.py Project: rwth-i6/i6_experiments
    def _add_local_fusion(self, subnet: ReturnnNetwork, am_output_prob):
        prefix_name = self.local_fusion_opts.get('prefix', 'local_fusion')
        with_label_smoothing = self.local_fusion_opts.get(
            'with_label_smoothing', False)

        if self.local_fusion_opts['lm_type'] == 'n_gram':
            lm_output_prob = subnet.add_kenlm_layer(
                '{}_lm_output_prob'.format(prefix_name),
                **self.local_fusion_opts['kenlm_opts'])
        else:
            lm_subnet = self.local_fusion_opts['lm_subnet']
            lm_model = self.local_fusion_opts['lm_model']
            vocab_size = self.local_fusion_opts['vocab_size']

            # make sure all layers in LM subnet are not trainable
            def make_non_trainable(d):
                for v in d.values():  # layers
                    assert isinstance(v, dict)
                    v.update({'trainable': False})

            # Add LM subnetwork.
            lm_subnet_copy = copy.deepcopy(lm_subnet)
            make_non_trainable(lm_subnet_copy)
            lm_subnet_name = '{}_lm_output'.format(prefix_name)
            subnet.add_subnetwork(lm_subnet_name, ['prev:output'],
                                  subnetwork_net=lm_subnet_copy,
                                  load_on_init=lm_model,
                                  trainable=False,
                                  n_out=vocab_size)
            lm_output_prob = subnet.add_activation_layer(
                '{}_lm_output_prob'.format(prefix_name),
                lm_subnet_name,
                activation='softmax',
                target=self.target)  # not in log-space

        # define the new loss criterion
        eval_str = "self.network.get_config().typed_value('fusion_eval0_norm')(safe_log(source(0)), safe_log(source(1)))"
        if self.local_fusion_opts['lm_type'] == 'n_gram':
            eval_str = "self.network.get_config().typed_value('fusion_eval0_norm')(safe_log(source(0)), source(1))"
        combo_output_log_prob = subnet.add_eval_layer(
            'combo_output_log_prob', [am_output_prob, lm_output_prob],
            eval=eval_str)

        # local fusion criterion, Eq. (8) in the paper
        if with_label_smoothing:
            subnet.add_eval_layer(
                'combo_output_prob',
                combo_output_log_prob,
                eval="tf.exp(source(0))",
                target=self.target,
                loss='ce',
                loss_opts={'label_smoothing': self.label_smoothing})
        else:
            subnet.add_eval_layer('combo_output_prob',
                                  combo_output_log_prob,
                                  eval="tf.exp(source(0))",
                                  target=self.target,
                                  loss='ce')

        subnet.add_choice_layer('output',
                                combo_output_log_prob,
                                target=self.target,
                                beam_size=self.beam_size,
                                initial_output=0,
                                input_type='log_prob')
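
The eval strings refer to a function 'fusion_eval0_norm' that the RETURNN config is expected to provide via typed_value(); its definition is not part of this excerpt. One plausible sketch, following the renormalized log-linear combination used for local fusion (the scale values are assumptions, not taken from the source):

import tensorflow as tf

def fusion_eval0_norm(am_log_prob, lm_log_prob, am_scale=1.0, lm_scale=0.3):
    # log-linear combination, renormalized over the vocabulary axis
    combined = am_scale * am_log_prob + lm_scale * lm_log_prob
    return combined - tf.reduce_logsumexp(combined, axis=-1, keepdims=True)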
Code example #8
File: rnn_decoder.py Project: rwth-i6/i6_experiments
    def _add_external_LM(self,
                         subnet_unit: ReturnnNetwork,
                         am_output_prob,
                         prior_output_prob=None):
        ext_lm_scale = self.ext_lm_opts[
            'lm_scale'] if not self.trained_scales else 'lm_scale'

        is_recurrent = self.ext_lm_opts.get('is_recurrent', False)

        log_lm_prob = False  # whether the LM prob is already in log-space

        if 'gram_lm' in self.ext_lm_opts['name']:
            log_lm_prob = True  # already in log-space
            lm_output_prob = subnet_unit.add_kenlm_layer(
                'lm_output_prob', **self.ext_lm_opts['kenlm_opts'])
        elif is_recurrent:
            ext_lm_subnet = self.ext_lm_opts['lm_subnet']
            assert isinstance(ext_lm_subnet, dict)

            lm_output_prob = self.ext_lm_opts['lm_output_prob_name']
            ext_lm_subnet[lm_output_prob]['target'] = self.target
            ext_lm_subnet[lm_output_prob][
                'loss'] = None  # TODO: is this needed?
            subnet_unit.update(ext_lm_subnet)  # just append
        else:
            ext_lm_subnet = self.ext_lm_opts['lm_subnet']
            assert isinstance(ext_lm_subnet, dict)

            ext_lm_model = self.ext_lm_opts['lm_model']
            subnet_unit.add_subnetwork('lm_output',
                                       'prev:output',
                                       subnetwork_net=ext_lm_subnet,
                                       load_on_init=ext_lm_model)
            lm_output_prob = subnet_unit.add_activation_layer(
                'lm_output_prob',
                'lm_output',
                activation='softmax',
                target=self.target)

        fusion_str = 'safe_log(source(0)) + {} * '.format(ext_lm_scale)
        if log_lm_prob:
            fusion_str += 'source(1)'
        else:
            fusion_str += 'safe_log(source(1))'

        fusion_source = [am_output_prob, lm_output_prob]
        if prior_output_prob:
            fusion_source += [prior_output_prob]
            prior_scale = self.prior_lm_opts[
                'scale'] if not self.trained_scales else 'prior_scale'
            fusion_str += ' - {} * safe_log(source(2))'.format(prior_scale)

        if self.coverage_term_scale:
            fusion_str += ' + {} * source({})'.format(self.coverage_term_scale,
                                                      len(fusion_source))
            fusion_source += ['accum_coverage']

        if self.trained_scales:
            fusion_str = 'source(0) * safe_log(source(1)) + source(2) * safe_log(source(3))'
            fusion_source = [
                'am_scale', am_output_prob, 'lm_scale', lm_output_prob
            ]
            if prior_output_prob:
                fusion_str += ' - source(4) * safe_log(source(5))'
                fusion_source += ['prior_scale', prior_output_prob]

        subnet_unit.add_eval_layer('combo_output_prob',
                                   source=fusion_source,
                                   eval=fusion_str)
        subnet_unit.add_choice_layer('output',
                                     'combo_output_prob',
                                     target=self.target,
                                     beam_size=self.beam_size,
                                     initial_output=0,
                                     input_type='log_prob')
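
Putting fusion_str together, the default (non-trained-scales) search score is log-linear shallow fusion with an optional prior subtraction. A NumPy sketch of the per-label score (scales are hypothetical; an eps-clipped log stands in for safe_log):

import numpy as np

def shallow_fusion_score(am_prob, lm_prob, lm_scale=0.3,
                         prior_prob=None, prior_scale=0.1, eps=1e-10):
    # log p_AM + lm_scale * log p_LM [- prior_scale * log p_prior]
    score = np.log(am_prob + eps) + lm_scale * np.log(lm_prob + eps)
    if prior_prob is not None:
        score -= prior_scale * np.log(prior_prob + eps)
    return score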
Code example #9
File: rnn_decoder.py Project: rwth-i6/i6_experiments
class RNNDecoder:
    """
  Represents an RNN LSTM attention-based decoder.

  Related:
    * Single headed attention based sequence-to-sequence model for state-of-the-art results on Switchboard
      ref: https://arxiv.org/abs/2001.07263
  """
    def __init__(self,
                 base_model,
                 source=None,
                 dropout=0.3,
                 label_smoothing=0.1,
                 target='bpe',
                 length_norm=True,
                 beam_size=12,
                 embed_dim=621,
                 embed_dropout=0.,
                 dec_lstm_num_units=1000,
                 dec_output_num_units=1000,
                 l2=None,
                 att_dropout=None,
                 rec_weight_dropout=None,
                 dec_zoneout=False,
                 ext_lm_opts=None,
                 prior_lm_opts=None,
                 local_fusion_opts=None,
                 ff_init=None,
                 add_lstm_lm=False,
                 lstm_lm_dim=1000,
                 loc_conv_att_filter_size=None,
                 loc_conv_att_num_channels=None,
                 density_ratio_opts=None,
                 mwer=False,
                 reduceout=True,
                 att_num_heads=1,
                 embed_weight=False,
                 coverage_term_scale=None,
                 ilmt_opts=None,
                 trained_scales=False,
                 remove_softmax_bias=False,
                 relax_att_scale=None,
                 ce_loss_scale=None,
                 dec_state_no_label_ctx=False,
                 add_no_label_ctx_s_to_output=False):
        """
    :param base_model: base/encoder model instance
    :param str source: input to decoder subnetwork
    :param float dropout: Dropout applied to the softmax input
    :param float label_smoothing: label smoothing value applied to softmax
    :param int att_num_heads: number of attention heads
    :param str target: target data key name
    :param int beam_size: value of the beam size
    :param int embed_dim: target embedding dimension
    :param float|None embed_dropout: dropout to be applied on the target embedding
    :param int dec_lstm_num_units: the number of hidden units for the decoder LSTM
    :param int dec_output_num_units: the number of hidden dimensions for the last layer before softmax
    :param float|None l2: weight decay with l2 norm
    :param float|None att_dropout: dropout applied to attention weights
    :param float|None rec_weight_dropout: dropout applied to weight parameters
    :param bool dec_zoneout: if set, zoneout LSTM cell is used in the decoder instead of nativelstm2
    :param dict[str]|None ext_lm_opts: external LM opts such as subnetwork, lm_model, scale, etc
    :param dict[str]|None prior_lm_opts: internal/prior LM opts (see _create_prior_net)
    :param dict[str]|None local_fusion_opts: dict containing LM subnetwork, AM scale, and LM scale
      paper: https://arxiv.org/abs/2005.10049
    :param str|None ff_init: feed-forward weights initialization
    :param bool add_lstm_lm: add a separate LSTM layer that acts as an LM-like model,
      same as here: https://arxiv.org/abs/2001.07263
    """

        self.base_model = base_model

        self.source = source
        self.dropout = dropout
        self.label_smoothing = label_smoothing

        self.enc_key_dim = base_model.enc_key_dim
        self.enc_value_dim = base_model.enc_value_dim
        self.att_num_heads = att_num_heads

        self.target = target
        self.length_norm = length_norm
        self.beam_size = beam_size

        self.embed_dim = embed_dim
        self.embed_dropout = embed_dropout
        self.dec_lstm_num_units = dec_lstm_num_units
        self.dec_output_num_units = dec_output_num_units

        self.ff_init = ff_init

        self.decision_layer_name = None  # this is set in the end-point config

        self.l2 = l2
        self.att_dropout = att_dropout
        self.rec_weight_dropout = rec_weight_dropout
        self.dec_zoneout = dec_zoneout

        self.ext_lm_opts = ext_lm_opts
        self.prior_lm_opts = prior_lm_opts

        self.local_fusion_opts = local_fusion_opts

        self.density_ratio_opts = density_ratio_opts

        self.add_lstm_lm = add_lstm_lm
        self.lstm_lm_dim = lstm_lm_dim

        self.loc_conv_att_filter_size = loc_conv_att_filter_size
        self.loc_conv_att_num_channels = loc_conv_att_num_channels

        self.mwer = mwer

        self.reduceout = reduceout

        self.embed_weight = embed_weight

        self.coverage_term_scale = coverage_term_scale

        self.ilmt_opts = ilmt_opts

        self.trained_scales = trained_scales

        self.remove_softmax_bias = remove_softmax_bias

        self.relax_att_scale = relax_att_scale

        self.ce_loss_scale = ce_loss_scale

        self.dec_state_no_label_ctx = dec_state_no_label_ctx
        self.add_no_label_ctx_s_to_output = add_no_label_ctx_s_to_output

        self.network = ReturnnNetwork()

    def _create_prior_net(self, subnet_unit: ReturnnNetwork, opts):
        prior_type = opts.get('type', 'zero')

        # fixed_ctx_vec_variants = ['zero', 'avg', 'train_avg_ctx', 'train_avg_enc', 'avg_zero', 'trained_vec']

        if prior_type == 'zero':  # set att context vector to zero
            prior_att_input = subnet_unit.add_eval_layer(
                'zero_att', 'att', eval='tf.zeros_like(source(0))')
        elif prior_type == 'avg':  # during search per utterance
            self.base_model.network.add_reduce_layer('encoder_mean',
                                                     'encoder',
                                                     mode='mean',
                                                     axes=['t'])  # [B, enc-dim]
            prior_att_input = 'base:encoder_mean'
        elif prior_type == 'train_avg_ctx':  # average all context vectors over training data
            prior_att_input = subnet_unit.add_constant_layer(
                'train_avg_ctx',
                value=opts['data'],
                with_batch_dim=True,
                dtype='float32')
        elif prior_type == 'train_avg_enc':  # average all encoder states over training data
            prior_att_input = subnet_unit.add_constant_layer(
                'train_avg_enc',
                value=opts['data'],
                with_batch_dim=True,
                dtype='float32')
        elif prior_type == 'mini_lstm':  # train a mini LM-like LSTM and use that as prior
            # example: lstmdim_100-l2_5e-05-recwd_0.0
            n_out = 50
            l2 = 0.0
            recwd = 0.0
            if opts.get('prefix_name', None):
                segs = opts['prefix_name'].split('-')
                for arg in segs:
                    name, val = arg.split('_', 1)
                    if name == 'lstmdim':
                        n_out = int(val)
                    elif name == 'l2':
                        l2 = float(val)
                    elif name == 'recwd':
                        recwd = float(val)

            mini_lstm_inputs = opts.get('mini_lstm_inp',
                                        'prev:target_embed').split('+')
            if len(mini_lstm_inputs) == 1:
                mini_lstm_inputs = mini_lstm_inputs[0]

            subnet_unit.add_rec_layer('mini_att_lstm',
                                      mini_lstm_inputs,
                                      n_out=n_out,
                                      l2=l2,
                                      rec_weight_dropout=recwd)
            prior_att_input = subnet_unit.add_linear_layer('mini_att',
                                                           'mini_att_lstm',
                                                           activation=None,
                                                           n_out=2048,
                                                           l2=0.0001)
        elif prior_type == 'adaptive_ctx_vec':  # \hat{c}_i = FF(h_i)
            num_layers = opts.get('num_layers', 3)
            dim = opts.get('dim', 512)
            act = opts.get('act', 'relu')
            x = 's'
            for i in range(num_layers):
                x = subnet_unit.add_linear_layer('adaptive_att_%d' % i,
                                                 x,
                                                 n_out=dim,
                                                 **opts.get('att_opts', {}))
                x = subnet_unit.add_activation_layer('adaptive_att_%d_%s' %
                                                     (i, act),
                                                     x,
                                                     activation=act)
            prior_att_input = subnet_unit.add_linear_layer('adaptive_att',
                                                           x,
                                                           n_out=2048,
                                                           **opts.get(
                                                               'att_opts', {}))
        elif prior_type == 'trained_vec':
            prior_att_input = subnet_unit.add_variable_layer(
                'trained_vec_att_var', shape=[2048], L2=0.0001)
        elif prior_type == 'avg_zero':
            self.base_model.network.add_reduce_layer('encoder_mean',
                                                     'encoder',
                                                     mode='mean',
                                                     axes=['t'])  # [B, enc-dim]
            subnet_unit.add_eval_layer('zero_att',
                                       'att',
                                       eval='tf.zeros_like(source(0))')
            return
        elif prior_type == 'density_ratio':
            assert 'lm_subnet' in opts and 'lm_model' in opts
            return self._add_density_ratio(subnet_unit,
                                           lm_subnet=opts['lm_subnet'],
                                           lm_model=opts['lm_model'])
        else:
            raise ValueError(
                '{} prior type is not supported'.format(prior_type))

        if prior_type != 'mini_lstm':
            is_first_frame = subnet_unit.add_compare_layer('is_first_frame',
                                                           source=':i',
                                                           kind='equal',
                                                           value=0)
            zero_att = subnet_unit.add_eval_layer(
                'zero_att', 'att', eval='tf.zeros_like(source(0))')
            prev_att = subnet_unit.add_switch_layer('prev_att',
                                                    condition=is_first_frame,
                                                    true_from=zero_att,
                                                    false_from=prior_att_input)
        else:
            prev_att = 'prev:' + prior_att_input

        assert prev_att is not None

        key_names = ['s', 'readout_in', 'readout', 'output_prob']
        for key_name in key_names:
            d = copy.deepcopy(subnet_unit[key_name])
            # update attention input
            new_sources = []
            from_list = d['from']
            if isinstance(from_list, str):
                from_list = [from_list]
            assert isinstance(from_list, list)
            for src in from_list:
                if 'att' in src:
                    if src.split(':')[0] == 'prev':
                        assert prev_att not in new_sources
                        new_sources += [prev_att]
                    else:
                        new_sources += [prior_att_input]
                elif src in key_names:
                    new_sources += [('prev:' if 'prev' in src else '') +
                                    'prior_{}'.format(src.split(':')[-1])]
                else:
                    new_sources += [src]
            d['from'] = new_sources
            subnet_unit['prior_{}'.format(key_name)] = d
        return 'prior_output_prob'

    def _create_ilmt_net(self, subnet_unit: ReturnnNetwork):
        self._create_prior_net(subnet_unit, self.ilmt_opts)  # add prior layers
        subnet_unit['prior_output_prob']['loss_opts'].update(
            {'scale': self.ilmt_opts['scale']})

        # remove label smoothing
        if 'label_smoothing' in subnet_unit['prior_output_prob'][
                'loss_opts'] and self.ilmt_opts.get('no_ilmt_lbs', False):
            subnet_unit['prior_output_prob']['loss_opts'][
                'label_smoothing'] = 0

        if 'label_smoothing' in subnet_unit['prior_output_prob'][
                'loss_opts'] and self.ilmt_opts.get('no_asr_lbs', False):
            subnet_unit['output_prob']['loss_opts']['label_smoothing'] = 0

        reuse_params_mapping = {
            'prior_s': {
                'lstm_cell/kernel': 'output/rec/s/rec/lstm_cell/kernel',
                'lstm_cell/bias': 'output/rec/s/rec/lstm_cell/bias'
            }
        }
        for name in ['prior_readout_in', 'prior_readout', 'prior_output_prob']:
            reuse_params_mapping[name] = {
                'W': 'output/rec/{}/W'.format(name[len('prior_'):]),
                'b': 'output/rec/{}/b'.format(name[len('prior_'):])
            }

        if self.ilmt_opts.get('share_params', False):
            from recipe.crnn.config import CodeWrapper
            layer_names = [
                'prior_s', 'prior_readout_in', 'prior_readout',
                'prior_output_prob'
            ]
            for layer in layer_names:
                value = copy.deepcopy(subnet_unit[layer])
                mapping = reuse_params_mapping[layer]
                value['reuse_params'] = {'map': {}}
                for k, v in mapping.items():
                    value['reuse_params']['map'][k] = {
                        'custom':
                        CodeWrapper(
                            "lambda **_kwargs: get_var('{}', _kwargs['shape'])"
                            .format(v))
                    }
                if layer == 'prior_s':
                    #value['reuse_params'] = {'auto_create_missing': True, 'reuse_layer': 's'}
                    value['reuse_params']['auto_create_missing'] = True
                subnet_unit[layer] = value

    def _add_external_LM(self,
                         subnet_unit: ReturnnNetwork,
                         am_output_prob,
                         prior_output_prob=None):
        ext_lm_scale = self.ext_lm_opts[
            'lm_scale'] if not self.trained_scales else 'lm_scale'

        is_recurrent = self.ext_lm_opts.get('is_recurrent', False)

        log_lm_prob = False  # whether the LM prob is already in log-space

        if 'gram_lm' in self.ext_lm_opts['name']:
            log_lm_prob = True  # already in log-space
            lm_output_prob = subnet_unit.add_kenlm_layer(
                'lm_output_prob', **self.ext_lm_opts['kenlm_opts'])
        elif is_recurrent:
            ext_lm_subnet = self.ext_lm_opts['lm_subnet']
            assert isinstance(ext_lm_subnet, dict)

            lm_output_prob = self.ext_lm_opts['lm_output_prob_name']
            ext_lm_subnet[lm_output_prob]['target'] = self.target
            ext_lm_subnet[lm_output_prob][
                'loss'] = None  # TODO: is this needed?
            subnet_unit.update(ext_lm_subnet)  # just append
        else:
            ext_lm_subnet = self.ext_lm_opts['lm_subnet']
            assert isinstance(ext_lm_subnet, dict)

            ext_lm_model = self.ext_lm_opts['lm_model']
            subnet_unit.add_subnetwork('lm_output',
                                       'prev:output',
                                       subnetwork_net=ext_lm_subnet,
                                       load_on_init=ext_lm_model)
            lm_output_prob = subnet_unit.add_activation_layer(
                'lm_output_prob',
                'lm_output',
                activation='softmax',
                target=self.target)

        fusion_str = 'safe_log(source(0)) + {} * '.format(ext_lm_scale)
        if log_lm_prob:
            fusion_str += 'source(1)'
        else:
            fusion_str += 'safe_log(source(1))'

        fusion_source = [am_output_prob, lm_output_prob]
        if prior_output_prob:
            fusion_source += [prior_output_prob]
            prior_scale = self.prior_lm_opts[
                'scale'] if not self.trained_scales else 'prior_scale'
            fusion_str += ' - {} * safe_log(source(2))'.format(prior_scale)

        if self.coverage_term_scale:
            fusion_str += ' + {} * source({})'.format(self.coverage_term_scale,
                                                      len(fusion_source))
            fusion_source += ['accum_coverage']

        if self.trained_scales:
            fusion_str = 'source(0) * safe_log(source(1)) + source(2) * safe_log(source(3))'
            fusion_source = [
                'am_scale', am_output_prob, 'lm_scale', lm_output_prob
            ]
            if prior_output_prob:
                fusion_str += ' - source(4) * safe_log(source(5))'
                fusion_source += ['prior_scale', prior_output_prob]

        subnet_unit.add_eval_layer('combo_output_prob',
                                   source=fusion_source,
                                   eval=fusion_str)
        subnet_unit.add_choice_layer('output',
                                     'combo_output_prob',
                                     target=self.target,
                                     beam_size=self.beam_size,
                                     initial_output=0,
                                     input_type='log_prob')

    def _add_density_ratio(self, subnet_unit: ReturnnNetwork, lm_subnet,
                           lm_model):
        subnet_unit.add_subnetwork('density_ratio_output',
                                   'prev:output',
                                   subnetwork_net=lm_subnet,
                                   load_on_init=lm_model)
        lm_output_prob = subnet_unit.add_activation_layer(
            'density_ratio_output_prob',
            'density_ratio_output',
            activation='softmax',
            target=self.target)
        return lm_output_prob

    def _add_local_fusion(self, subnet: ReturnnNetwork, am_output_prob):
        prefix_name = self.local_fusion_opts.get('prefix', 'local_fusion')
        with_label_smoothing = self.local_fusion_opts.get(
            'with_label_smoothing', False)

        if self.local_fusion_opts['lm_type'] == 'n_gram':
            lm_output_prob = subnet.add_kenlm_layer(
                '{}_lm_output_prob'.format(prefix_name),
                **self.local_fusion_opts['kenlm_opts'])
        else:
            lm_subnet = self.local_fusion_opts['lm_subnet']
            lm_model = self.local_fusion_opts['lm_model']
            vocab_size = self.local_fusion_opts['vocab_size']

            # make sure all layers in LM subnet are not trainable
            def make_non_trainable(d):
                for v in d.values():  # layers
                    assert isinstance(v, dict)
                    v.update({'trainable': False})

            # Add LM subnetwork.
            lm_subnet_copy = copy.deepcopy(lm_subnet)
            make_non_trainable(lm_subnet_copy)
            lm_subnet_name = '{}_lm_output'.format(prefix_name)
            subnet.add_subnetwork(lm_subnet_name, ['prev:output'],
                                  subnetwork_net=lm_subnet_copy,
                                  load_on_init=lm_model,
                                  trainable=False,
                                  n_out=vocab_size)
            lm_output_prob = subnet.add_activation_layer(
                '{}_lm_output_prob'.format(prefix_name),
                lm_subnet_name,
                activation='softmax',
                target=self.target)  # not in log-space

        # define the new loss criterion
        eval_str = "self.network.get_config().typed_value('fusion_eval0_norm')(safe_log(source(0)), safe_log(source(1)))"
        if self.local_fusion_opts['lm_type'] == 'n_gram':
            eval_str = "self.network.get_config().typed_value('fusion_eval0_norm')(safe_log(source(0)), source(1))"
        combo_output_log_prob = subnet.add_eval_layer(
            'combo_output_log_prob', [am_output_prob, lm_output_prob],
            eval=eval_str)

        # local fusion criterion, Eq. (8) in the paper
        if with_label_smoothing:
            subnet.add_eval_layer(
                'combo_output_prob',
                combo_output_log_prob,
                eval="tf.exp(source(0))",
                target=self.target,
                loss='ce',
                loss_opts={'label_smoothing': self.label_smoothing})
        else:
            subnet.add_eval_layer('combo_output_prob',
                                  combo_output_log_prob,
                                  eval="tf.exp(source(0))",
                                  target=self.target,
                                  loss='ce')

        subnet.add_choice_layer('output',
                                combo_output_log_prob,
                                target=self.target,
                                beam_size=self.beam_size,
                                initial_output=0,
                                input_type='log_prob')

    def add_decoder_subnetwork(self, subnet_unit: ReturnnNetwork):

        # target embedding
        if self.embed_dropout:
            # TODO: this is not a good approach. If we want to load a checkpoint from a
            # model trained without embedding dropout, we would need to remap the variable
            # name target_embed to target_embed0 in order to load target_embed0/W.
            subnet_unit.add_linear_layer('target_embed0',
                                         'output',
                                         n_out=self.embed_dim,
                                         initial_output=0,
                                         with_bias=False)
            subnet_unit.add_dropout_layer('target_embed',
                                          'target_embed0',
                                          dropout=self.embed_dropout,
                                          dropout_noise_shape={'*': None})
        else:
            subnet_unit.add_linear_layer('target_embed',
                                         'output',
                                         n_out=self.embed_dim,
                                         initial_output=0,
                                         with_bias=False)

        subnet_unit.add_compare_layer('end', source='output',
                                      value=0)  # sentence end token

        # ------ attention location-awareness ------ #

        # conv-based
        if self.loc_conv_att_filter_size:
            assert self.loc_conv_att_num_channels
            pad_left = subnet_unit.add_pad_layer(
                'feedback_pad_left',
                'prev:att_weights',
                axes='s:0',
                padding=((self.loc_conv_att_filter_size - 1) // 2, 0),
                value=0)
            pad_right = subnet_unit.add_pad_layer(
                'feedback_pad_right',
                pad_left,
                axes='s:0',
                padding=(0, (self.loc_conv_att_filter_size - 1) // 2),
                value=0)
            loc_att_conv = subnet_unit.add_conv_layer(
                'loc_att_conv',
                pad_right,
                activation=None,
                with_bias=False,
                filter_size=(self.loc_conv_att_filter_size, ),
                padding='valid',
                n_out=self.loc_conv_att_num_channels,
                l2=self.l2)
            subnet_unit.add_linear_layer('weight_feedback',
                                         loc_att_conv,
                                         activation=None,
                                         with_bias=False,
                                         n_out=self.enc_key_dim)
        else:
            # additive: accumulate attention weights with fertility feedback
            subnet_unit.add_eval_layer(
                'accum_att_weights', [
                    "prev:accum_att_weights", "att_weights",
                    "base:inv_fertility"
                ],
                eval='source(0) + source(1) * source(2) * 0.5',
                out_type={
                    "dim": self.att_num_heads,
                    "shape": (None, self.att_num_heads)
                })
            subnet_unit.add_linear_layer('weight_feedback',
                                         'prev:accum_att_weights',
                                         n_out=self.enc_key_dim,
                                         with_bias=False)

        subnet_unit.add_linear_layer('s_transformed',
                                     's',
                                     n_out=self.enc_key_dim,
                                     with_bias=False)
        subnet_unit.add_combine_layer(
            'energy_in', ['base:enc_ctx', 'weight_feedback', 's_transformed'],
            kind='add',
            n_out=self.enc_key_dim)
        subnet_unit.add_activation_layer('energy_tanh',
                                         'energy_in',
                                         activation='tanh')
        subnet_unit.add_linear_layer('energy',
                                     'energy_tanh',
                                     n_out=self.att_num_heads,
                                     with_bias=False)

        if self.att_dropout:
            subnet_unit.add_softmax_over_spatial_layer('att_weights0',
                                                       'energy')
            subnet_unit.add_dropout_layer('att_weights',
                                          'att_weights0',
                                          dropout=self.att_dropout,
                                          dropout_noise_shape={'*': None})
        else:
            if self.relax_att_scale:
                subnet_unit.add_softmax_over_spatial_layer(
                    'att_weights0', 'energy')
                subnet_unit.add_length_layer('encoder_len',
                                             'base:encoder',
                                             dtype='float32')  # [B]
                subnet_unit.add_eval_layer('scaled_encoder_len',
                                           source=['encoder_len'],
                                           eval='{} / source(0)'.format(
                                               self.relax_att_scale))
                subnet_unit.add_eval_layer(
                    'att_weights',
                    source=['att_weights0', 'scaled_encoder_len'],
                    eval='{} * source(0) + source(1)'.format(
                        1 - self.relax_att_scale))
            else:
                subnet_unit.add_softmax_over_spatial_layer(
                    'att_weights', 'energy')

        subnet_unit.add_generic_att_layer('att0',
                                          weights='att_weights',
                                          base='base:enc_value')
        subnet_unit.add_merge_dims_layer('att', 'att0', axes='except_batch')

        # LM-like component same as here https://arxiv.org/pdf/2001.07263.pdf
        lstm_lm_component = None
        if self.add_lstm_lm:
            lstm_lm_component = subnet_unit.add_rnn_cell_layer(
                'lm_like_s',
                'prev:target_embed',
                n_out=self.lstm_lm_dim,
                l2=self.l2)

        lstm_inputs = []
        if lstm_lm_component:
            lstm_inputs += [lstm_lm_component]
        else:
            lstm_inputs += ['prev:target_embed']
        lstm_inputs += ['prev:att']

        if self.dec_state_no_label_ctx:
            lstm_inputs = ['prev:att']  # no label feedback

        # LSTM decoder
        if self.dec_zoneout:
            subnet_unit.add_rnn_cell_layer('s',
                                           lstm_inputs,
                                           n_out=self.dec_lstm_num_units,
                                           unit='zoneoutlstm',
                                           unit_opts={
                                               'zoneout_factor_cell': 0.15,
                                               'zoneout_factor_output': 0.05
                                           })
        else:
            if self.rec_weight_dropout:
                # a rec layer with unit nativelstm2 is required to use rec_weight_dropout
                subnet_unit.add_rec_layer(
                    's',
                    lstm_inputs,
                    n_out=self.dec_lstm_num_units,
                    l2=self.l2,
                    rec_weight_dropout=self.rec_weight_dropout,
                    unit='NativeLSTM2')
            else:
                subnet_unit.add_rnn_cell_layer('s',
                                               lstm_inputs,
                                               n_out=self.dec_lstm_num_units,
                                               l2=self.l2)

        # AM softmax output layer
        if self.dec_state_no_label_ctx and self.add_lstm_lm:
            readout_in_sources = ["lm_like_s", "prev:target_embed", "att"]
            if self.add_no_label_ctx_s_to_output:
                # also feed the label-context-free decoder state into the readout
                readout_in_sources = ["lm_like_s", "s", "prev:target_embed", "att"]
            subnet_unit.add_linear_layer('readout_in',
                                         readout_in_sources,
                                         n_out=self.dec_output_num_units)
        else:
            subnet_unit.add_linear_layer('readout_in',
                                         ["s", "prev:target_embed", "att"],
                                         n_out=self.dec_output_num_units)

        if self.reduceout:
            subnet_unit.add_reduceout_layer('readout', 'readout_in')
        else:
            subnet_unit.add_copy_layer('readout', 'readout_in')

        if self.local_fusion_opts:
            output_prob = subnet_unit.add_softmax_layer('output_prob',
                                                        'readout',
                                                        l2=self.l2,
                                                        target=self.target,
                                                        dropout=self.dropout)
            self._add_local_fusion(subnet_unit, am_output_prob=output_prob)
        elif self.mwer:
            # only MWER so CE is disabled
            output_prob = subnet_unit.add_softmax_layer('output_prob',
                                                        'readout',
                                                        l2=self.l2,
                                                        target=self.target,
                                                        dropout=self.dropout)
        else:
            ce_loss_opts = {'label_smoothing': self.label_smoothing}
            if self.ce_loss_scale:
                ce_loss_opts['scale'] = self.ce_loss_scale
            output_prob = subnet_unit.add_softmax_layer('output_prob',
                                                        'readout',
                                                        l2=self.l2,
                                                        loss='ce',
                                                        loss_opts=ce_loss_opts,
                                                        target=self.target,
                                                        dropout=self.dropout)

        # do not load the bias
        if self.remove_softmax_bias:
            subnet_unit['output_prob']['with_bias'] = False

        # for prior LM estimation
        prior_output_prob = None
        if self.prior_lm_opts:
            prior_output_prob = self._create_prior_net(
                subnet_unit, self.prior_lm_opts
            )  # this requires preload_from_files in the config

        # Beam search
        # only support shallow fusion for now
        if self.ext_lm_opts:
            self._add_external_LM(subnet_unit, output_prob, prior_output_prob)
        else:
            if self.coverage_term_scale:
                output_prob = subnet_unit.add_eval_layer(
                    'combo_output_prob',
                    eval='safe_log(source(0)) + {} * source(1)'.format(
                        self.coverage_term_scale),
                    source=['output_prob', 'accum_coverage'])
                input_type = 'log_prob'
            else:
                output_prob = 'output_prob'
                input_type = None

            if self.length_norm:
                subnet_unit.add_choice_layer('output',
                                             output_prob,
                                             target=self.target,
                                             beam_size=self.beam_size,
                                             initial_output=0,
                                             input_type=input_type)
            else:
                subnet_unit.add_choice_layer(
                    'output',
                    output_prob,
                    target=self.target,
                    beam_size=self.beam_size,
                    initial_output=0,
                    length_normalization=self.length_norm,
                    input_type=input_type)

        if self.ilmt_opts:
            self._create_ilmt_net(subnet_unit)

        # recurrent subnetwork
        dec_output = self.network.add_subnet_rec_layer(
            'output',
            unit=subnet_unit.get_net(),
            target=self.target,
            source=self.source)

        return dec_output

    def create_network(self):
        subnet_unit = ReturnnNetwork()

        dec_output = self.add_decoder_subnetwork(subnet_unit)

        # Add to Encoder network

        if hasattr(self.base_model,
                   'enc_proj_dim') and self.base_model.enc_proj_dim:
            self.base_model.network.add_copy_layer('enc_ctx', 'encoder_proj')
            self.base_model.network.add_split_dim_layer(
                'enc_value',
                'encoder_proj',
                dims=(self.att_num_heads,
                      self.enc_value_dim // self.att_num_heads))
        else:
            self.base_model.network.add_linear_layer('enc_ctx',
                                                     'encoder',
                                                     with_bias=True,
                                                     n_out=self.enc_key_dim,
                                                     l2=self.base_model.l2)
            self.base_model.network.add_split_dim_layer(
                'enc_value',
                'encoder',
                dims=(self.att_num_heads,
                      self.enc_value_dim // self.att_num_heads))

        self.base_model.network.add_linear_layer('inv_fertility',
                                                 'encoder',
                                                 activation='sigmoid',
                                                 n_out=self.att_num_heads,
                                                 with_bias=False)

        decision_layer_name = self.base_model.network.add_decide_layer(
            'decision', dec_output, target=self.target)
        self.decision_layer_name = decision_layer_name

        return dec_output
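
A hypothetical usage sketch of the decoder above (the class name RNNDecoder is
an assumption based on the file name rnn_decoder.py; argument values are
illustrative):

# Hypothetical wiring of encoder and decoder (illustrative values).
encoder = RNNEncoder(input='data', enc_layers=6, enc_key_dim=1024,
                     enc_value_dim=2048, target='bpe')
encoder.create_network()
decoder = RNNDecoder(base_model=encoder, target='bpe', beam_size=12)
decoder.create_network()
# the rec layer lives in decoder.network; enc_ctx/enc_value/inv_fertility and
# the 'decision' layer were added to encoder.network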
Code example #10
File: rnn_decoder.py  Project: rwth-i6/i6_experiments
    def __init__(self,
                 base_model,
                 source=None,
                 dropout=0.3,
                 label_smoothing=0.1,
                 target='bpe',
                 length_norm=True,
                 beam_size=12,
                 embed_dim=621,
                 embed_dropout=0.,
                 dec_lstm_num_units=1000,
                 dec_output_num_units=1000,
                 l2=None,
                 att_dropout=None,
                 rec_weight_dropout=None,
                 dec_zoneout=False,
                 ext_lm_opts=None,
                 prior_lm_opts=None,
                 local_fusion_opts=None,
                 ff_init=None,
                 add_lstm_lm=False,
                 lstm_lm_dim=1000,
                 loc_conv_att_filter_size=None,
                 loc_conv_att_num_channels=None,
                 density_ratio_opts=None,
                 mwer=False,
                 reduceout=True,
                 att_num_heads=1,
                 embed_weight=False,
                 coverage_term_scale=None,
                 ilmt_opts=None,
                 trained_scales=False,
                 remove_softmax_bias=False,
                 relax_att_scale=None,
                 ce_loss_scale=None,
                 dec_state_no_label_ctx=False,
                 add_no_label_ctx_s_to_output=False):
        """
    :param base_model: base/encoder model instance
    :param str source: input to decoder subnetwork
    :param float dropout: Dropout applied to the softmax input
    :param float label_smoothing: label smoothing value applied to softmax
    :param int att_num_heads: number of attention heads
    :param str target: target data key name
    :param int beam_size: value of the beam size
    :param int embed_dim: target embedding dimension
    :param float|None embed_dropout: dropout to be applied on the target embedding
    :param int dec_lstm_num_units: the number of hidden units for the decoder LSTM
    :param int dec_output_num_units: the number of hidden dimensions for the last layer before softmax
    :param float|None l2: weight decay with l2 norm
    :param float|None att_dropout: dropout applied to attention weights
    :param float|None rec_weight_dropout: dropout applied to weight parameters
    :param bool dec_zoneout: if set, zoneout LSTM cell is used in the decoder instead of nativelstm2
    :param dict[str]|None ext_lm_opts: external LM opts such as subnetwork, lm_model, scale, etc
    :param dict[str]|None prior_lm_opts: prior LM options used for internal LM estimation
    :param dict[str]|None local_fusion_opts: dict containing LM subnetwork, AM scale, and LM scale
      paper: https://arxiv.org/abs/2005.10049
    :param str|None ff_init: feed-forward weights initialization
    :param bool add_lstm_lm: add separate LSTM layer that acts as LM-like model
      same as here: https://arxiv.org/abs/2001.07263
    """

        self.base_model = base_model

        self.source = source
        self.dropout = dropout
        self.label_smoothing = label_smoothing

        self.enc_key_dim = base_model.enc_key_dim
        self.enc_value_dim = base_model.enc_value_dim
        self.att_num_heads = att_num_heads

        self.target = target
        self.length_norm = length_norm
        self.beam_size = beam_size

        self.embed_dim = embed_dim
        self.embed_dropout = embed_dropout
        self.dec_lstm_num_units = dec_lstm_num_units
        self.dec_output_num_units = dec_output_num_units

        self.ff_init = ff_init

        self.decision_layer_name = None  # this is set in the end-point config

        self.l2 = l2
        self.att_dropout = att_dropout
        self.rec_weight_dropout = rec_weight_dropout
        self.dec_zoneout = dec_zoneout

        self.ext_lm_opts = ext_lm_opts
        self.prior_lm_opts = prior_lm_opts

        self.local_fusion_opts = local_fusion_opts

        self.density_ratio_opts = density_ratio_opts

        self.add_lstm_lm = add_lstm_lm
        self.lstm_lm_dim = lstm_lm_dim

        self.loc_conv_att_filter_size = loc_conv_att_filter_size
        self.loc_conv_att_num_channels = loc_conv_att_num_channels

        self.mwer = mwer

        self.reduceout = reduceout

        self.embed_weight = embed_weight

        self.coverage_term_scale = coverage_term_scale

        self.ilmt_opts = ilmt_opts

        self.trained_scales = trained_scales

        self.remove_softmax_bias = remove_softmax_bias

        self.relax_att_scale = relax_att_scale

        self.ce_loss_scale = ce_loss_scale

        self.dec_state_no_label_ctx = dec_state_no_label_ctx
        self.add_no_label_ctx_s_to_output = add_no_label_ctx_s_to_output

        self.network = ReturnnNetwork()
Code example #11
File: rnn_decoder.py  Project: rwth-i6/i6_experiments
    def _create_prior_net(self, subnet_unit: ReturnnNetwork, opts):
        prior_type = opts.get('type', 'zero')

        # fixed_ctx_vec_variants = ['zero', 'avg', 'train_avg_ctx', 'train_avg_enc', 'avg_zero', 'trained_vec']

        if prior_type == 'zero':  # set att context vector to zero
            prior_att_input = subnet_unit.add_eval_layer(
                'zero_att', 'att', eval='tf.zeros_like(source(0))')
        elif prior_type == 'avg':  # per-utterance average, computed during search
            self.base_model.network.add_reduce_layer(
                'encoder_mean', 'encoder', mode='mean',
                axes=['t'])  # [B, enc-dim]
            prior_att_input = 'base:encoder_mean'
        elif prior_type == 'train_avg_ctx':  # average all context vectors over training data
            prior_att_input = subnet_unit.add_constant_layer(
                'train_avg_ctx',
                value=opts['data'],
                with_batch_dim=True,
                dtype='float32')
        elif prior_type == 'train_avg_enc':  # average all encoder states over training data
            prior_att_input = subnet_unit.add_constant_layer(
                'train_avg_enc',
                value=opts['data'],
                with_batch_dim=True,
                dtype='float32')
        elif prior_type == 'mini_lstm':  # train a mini LM-like LSTM and use that as prior
            # example: lstmdim_100-l2_5e-05-recwd_0.0
            n_out = 50
            l2 = 0.0
            recwd = 0.0
            if opts.get('prefix_name', None):
                segs = opts['prefix_name'].split('-')
                for arg in segs:
                    name, val = arg.split('_', 1)
                    if name == 'lstmdim':
                        n_out = int(val)
                    elif name == 'l2':
                        l2 = float(val)
                    elif name == 'recwd':
                        recwd = float(val)

            mini_lstm_inputs = opts.get('mini_lstm_inp',
                                        'prev:target_embed').split('+')
            if len(mini_lstm_inputs) == 1:
                mini_lstm_inputs = mini_lstm_inputs[0]

            subnet_unit.add_rec_layer('mini_att_lstm',
                                      mini_lstm_inputs,
                                      n_out=n_out,
                                      l2=l2,
                                      rec_weight_dropout=recwd)
            prior_att_input = subnet_unit.add_linear_layer('mini_att',
                                                           'mini_att_lstm',
                                                           activation=None,
                                                           n_out=2048,
                                                           l2=0.0001)
        elif prior_type == 'adaptive_ctx_vec':  # \hat{c}_i = FF(h_i)
            num_layers = opts.get('num_layers', 3)
            dim = opts.get('dim', 512)
            act = opts.get('act', 'relu')
            x = 's'
            for i in range(num_layers):
                x = subnet_unit.add_linear_layer('adaptive_att_%d' % i,
                                                 x,
                                                 n_out=dim,
                                                 **opts.get('att_opts', {}))
                x = subnet_unit.add_activation_layer('adaptive_att_%d_%s' %
                                                     (i, act),
                                                     x,
                                                     activation=act)
            prior_att_input = subnet_unit.add_linear_layer('adaptive_att',
                                                           x,
                                                           n_out=2048,
                                                           **opts.get(
                                                               'att_opts', {}))
        elif prior_type == 'trained_vec':
            prior_att_input = subnet_unit.add_variable_layer(
                'trained_vec_att_var', shape=[2048], L2=0.0001)
        elif prior_type == 'avg_zero':
            self.base_model.network.add_reduce_layer(
                'encoder_mean', 'encoder', mode='mean',
                axes=['t'])  # [B, enc-dim]
            subnet_unit.add_eval_layer('zero_att',
                                       'att',
                                       eval='tf.zeros_like(source(0))')
            return
        elif prior_type == 'density_ratio':
            assert 'lm_subnet' in opts and 'lm_model' in opts
            return self._add_density_ratio(subnet_unit,
                                           lm_subnet=opts['lm_subnet'],
                                           lm_model=opts['lm_model'])
        else:
            raise ValueError(
                '{} prior type is not supported'.format(prior_type))

        if prior_type != 'mini_lstm':
            is_first_frame = subnet_unit.add_compare_layer('is_first_frame',
                                                           source=':i',
                                                           kind='equal',
                                                           value=0)
            zero_att = subnet_unit.add_eval_layer(
                'zero_att', 'att', eval='tf.zeros_like(source(0))')
            prev_att = subnet_unit.add_switch_layer('prev_att',
                                                    condition=is_first_frame,
                                                    true_from=zero_att,
                                                    false_from=prior_att_input)
        else:
            prev_att = 'prev:' + prior_att_input

        assert prev_att is not None

        key_names = ['s', 'readout_in', 'readout', 'output_prob']
        for key_name in key_names:
            d = copy.deepcopy(subnet_unit[key_name])
            # update attention input
            new_sources = []
            from_list = d['from']
            if isinstance(from_list, str):
                from_list = [from_list]
            assert isinstance(from_list, list)
            for src in from_list:
                if 'att' in src:
                    if src.split(':')[0] == 'prev':
                        assert prev_att not in new_sources
                        new_sources += [prev_att]
                    else:
                        new_sources += [prior_att_input]
                elif src in key_names:
                    new_sources += [('prev:' if 'prev' in src else '') +
                                    'prior_{}'.format(src.split(':')[-1])]
                else:
                    new_sources += [src]
            d['from'] = new_sources
            subnet_unit['prior_{}'.format(key_name)] = d
        return 'prior_output_prob'
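
A worked example of the source-rewrite loop above (illustrative, for the
'mini_lstm' prior type): a layer that reads ['s', 'prev:target_embed', 'att']
is cloned with its sources remapped as follows.

key_names = ['s', 'readout_in', 'readout', 'output_prob']
prior_att_input, prev_att = 'mini_att', 'prev:mini_att'
new_sources = []
for src in ['s', 'prev:target_embed', 'att']:
    if 'att' in src:
        new_sources.append(prev_att if src.split(':')[0] == 'prev' else prior_att_input)
    elif src in key_names:
        new_sources.append(('prev:' if 'prev' in src else '') + 'prior_' + src.split(':')[-1])
    else:
        new_sources.append(src)
assert new_sources == ['prior_s', 'prev:target_embed', 'mini_att']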
Code example #12
class RNNEncoder:
    """
  Represents RNN LSTM Attention-based Encoder
  """
    def __init__(self,
                 input='data',
                 enc_layers=6,
                 bidirectional=True,
                 residual_lstm=False,
                 residual_proj_dim=None,
                 specaug=True,
                 with_conv=True,
                 dropout=0.3,
                 pool_sizes='3_2',
                 lstm_dim=None,
                 enc_key_dim=1024,
                 enc_value_dim=2048,
                 att_num_heads=1,
                 target='bpe',
                 l2=None,
                 rec_weight_dropout=None,
                 with_ctc=False,
                 ctc_dropout=0.,
                 ctc_l2=0.,
                 ctc_opts=None,
                 enc_proj_dim=None,
                 ctc_loss_scale=None,
                 conv_time_pooling=None):
        """
    :param str input: (layer) name of the network input
    :param int enc_layers: the number of encoder layers
    :param bool bidirectional: If set, bidirectional LSTMs are used
    :param bool specaug: If True, SpecAugment is used
    :param bool with_conv: if True, conv layers are applied initially
    :param float dropout: Dropout applied on the input of multiple layers
    :param str|int|List[int]|None pool_sizes: a list of pool sizes between LSTM layers
    :param int enc_key_dim: attention key dimension
    :param int enc_value_dim: attention value dimension
    :param int att_num_heads: number of attention heads
    :param str target: target data key name
    :param float|None l2: weight decay with l2 norm
    :param float|None rec_weight_dropout: dropout applied to the hidden-to-hidden LSTM weight matrices
    :param bool with_ctc: if set, CTC is used
    :param float ctc_dropout: dropout applied on input to ctc
    :param float ctc_l2: L2 applied to the weight matrix of CTC softmax
    :param dict[str] ctc_opts: options for CTC
    """

        self.input = input
        self.enc_layers = enc_layers

        if pool_sizes is not None:
            if isinstance(pool_sizes, str):
                # e.g. '3_2' -> [3, 2, 1, ..., 1]; see the worked example after this constructor
                pool_sizes = list(map(
                    int, pool_sizes.split('_'))) + [1] * (enc_layers - 3)
            elif isinstance(pool_sizes, int):
                pool_sizes = [pool_sizes] * (self.enc_layers - 1)

            assert isinstance(pool_sizes, list), 'pool_sizes must be a list'
            assert all(isinstance(e, int) for e in pool_sizes), \
                'pool_sizes must only contain integers'
            assert len(pool_sizes) < enc_layers

        self.pool_sizes = pool_sizes

        if conv_time_pooling is None:
            self.conv_time_pooling = [1, 1]
        else:
            self.conv_time_pooling = list(
                map(int, conv_time_pooling.split('_')))

        self.bidirectional = bidirectional

        self.residual_lstm = residual_lstm
        self.residual_proj_dim = residual_proj_dim

        self.specaug = specaug
        self.with_conv = with_conv
        self.dropout = dropout

        self.enc_key_dim = enc_key_dim
        self.enc_value_dim = enc_value_dim
        self.att_num_heads = att_num_heads
        self.enc_key_per_head_dim = enc_key_dim // att_num_heads
        self.enc_val_per_head_dim = enc_value_dim // att_num_heads
        self.lstm_dim = lstm_dim
        if lstm_dim is None:
            self.lstm_dim = enc_value_dim // 2

        self.target = target

        self.l2 = l2
        self.rec_weight_dropout = rec_weight_dropout

        self.with_ctc = with_ctc
        self.ctc_dropout = ctc_dropout
        self.ctc_l2 = ctc_l2
        self.ctc_loss_scale = ctc_loss_scale
        self.ctc_opts = ctc_opts
        if self.ctc_opts is None:
            self.ctc_opts = {}

        self.enc_proj_dim = enc_proj_dim

        self.network = ReturnnNetwork()
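
    # Worked example of the pool_sizes parsing in __init__ above
    # (illustrative values; runs at class-definition time, affects nothing):
    _pool_sizes_example = list(map(int, '3_2'.split('_'))) + [1] * (6 - 3)
    assert _pool_sizes_example == [3, 2, 1, 1, 1]  # 6x total time downsampling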

    def create_network(self):
        data = self.input
        if self.specaug:
            data = self.network.add_eval_layer(
                'source',
                data,
                eval=
                "self.network.get_config().typed_value('transform')(source(0, as_data=True), network=self.network)"
            )

        lstm_input = data
        if self.with_conv:
            lstm_input = self.network.add_conv_block(
                'conv_merged',
                data,
                hwpc_sizes=[((3, 3), (self.conv_time_pooling[0], 2), 32),
                            ((3, 3), (self.conv_time_pooling[1], 2), 32)],
                l2=self.l2,
                activation=None)

        if self.residual_lstm:
            last_lstm_layer = self.network.add_residual_lstm_layers(
                lstm_input,
                self.enc_layers,
                self.lstm_dim,
                self.dropout,
                self.l2,
                self.rec_weight_dropout,
                self.pool_sizes,
                residual_proj_dim=self.residual_proj_dim,
                batch_norm=True)
        else:
            last_lstm_layer = self.network.add_lstm_layers(
                lstm_input, self.enc_layers, self.lstm_dim, self.dropout,
                self.l2, self.rec_weight_dropout, self.pool_sizes,
                self.bidirectional)

        encoder = self.network.add_copy_layer('encoder', last_lstm_layer)
        if self.enc_proj_dim:
            encoder = self.network.add_linear_layer('encoder_proj',
                                                    encoder,
                                                    n_out=self.enc_proj_dim,
                                                    l2=self.l2,
                                                    dropout=self.dropout)

        if self.with_ctc:
            default_ctc_loss_opts = {
                "beam_width": 1,
                "ctc_opts": {
                    "ignore_longer_outputs_than_inputs": True
                }
            }
            default_ctc_loss_opts.update(self.ctc_opts)
            if self.ctc_loss_scale:
                default_ctc_loss_opts['scale'] = self.ctc_loss_scale
            self.network.add_softmax_layer('ctc',
                                           encoder,
                                           l2=self.ctc_l2,
                                           target=self.target,
                                           loss='ctc',
                                           dropout=self.ctc_dropout,
                                           loss_opts=default_ctc_loss_opts)

        return encoder
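
To make the CTC loss-opts merging in create_network above concrete, a worked
example ('some_option' is a hypothetical placeholder, not a real CTC option).
Note that dict.update is shallow: a ctc_opts dict that itself contains a
'ctc_opts' key replaces the nested dict wholesale rather than merging it.

default_ctc_loss_opts = {
    "beam_width": 1,
    "ctc_opts": {"ignore_longer_outputs_than_inputs": True},
}
default_ctc_loss_opts.update({"ctc_opts": {"some_option": True}})
assert default_ctc_loss_opts == {"beam_width": 1, "ctc_opts": {"some_option": True}}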
Code example #14
class TransformerDecoder:
    """
  Represents standard Transformer decoder

  * Attention Is All You Need
  * Ref: https://arxiv.org/abs/1706.03762
  """
    def __init__(self,
                 base_model,
                 target='bpe',
                 dec_layers=6,
                 beam_size=12,
                 ff_init=None,
                 ff_dim=2048,
                 ff_act='relu',
                 att_num_heads=8,
                 dropout=0.1,
                 att_dropout=0.0,
                 softmax_dropout=0.0,
                 embed_dropout=0.1,
                 l2=0.0,
                 embed_pos_enc=False,
                 apply_embed_weight=False,
                 label_smoothing=0.1,
                 mhsa_init=None,
                 mhsa_out_init=None,
                 pos_enc='rel',
                 rel_pos_clipping=16):

        self.base_model = base_model
        self.enc_value_dim = base_model.enc_value_dim
        self.enc_key_dim = base_model.enc_key_dim
        self.enc_att_num_heads = base_model.att_num_heads
        self.enc_key_per_head_dim = base_model.enc_key_per_head_dim
        self.enc_val_per_head_dim = base_model.enc_val_per_head_dim

        self.att_num_heads = att_num_heads

        self.target = target
        self.dec_layers = dec_layers
        self.beam_size = beam_size

        self.ff_init = ff_init
        self.ff_dim = ff_dim
        self.ff_act = ff_act

        self.mhsa_init = mhsa_init
        self.mhsa_init_out = mhsa_out_init

        self.pos_enc = pos_enc
        self.rel_pos_clipping = rel_pos_clipping

        self.dropout = dropout
        self.softmax_dropout = softmax_dropout
        self.att_dropout = att_dropout
        self.label_smoothing = label_smoothing

        self.l2 = l2

        self.embed_dropout = embed_dropout
        self.embed_pos_enc = embed_pos_enc

        self.embed_weight = None

        if apply_embed_weight:
            self.embed_weight = self.enc_value_dim**0.5

        self.decision_layer_name = None

        self.network = ReturnnNetwork()
        self.subnet_unit = ReturnnNetwork()
        self.output_prob = None

    def _create_masked_mhsa(self, subnet_unit: ReturnnNetwork, prefix, source):
        prefix = '{}_self_att'.format(prefix)

        ln = subnet_unit.add_layer_norm_layer('{}_ln'.format(prefix), source)

        ln_rel_pos_enc = None
        if self.pos_enc == 'rel':
            ln_rel_pos_enc = self.subnet_unit.add_relative_pos_encoding_layer(
                '{}_ln_rel_pos_enc'.format(prefix),
                ln,
                n_out=self.enc_key_per_head_dim,
                forward_weights_init=self.ff_init,
                clipping=self.rel_pos_clipping)

        att = subnet_unit.add_self_att_layer(
            '{}_att'.format(prefix),
            ln,
            num_heads=self.att_num_heads,
            total_key_dim=self.enc_key_dim,
            n_out=self.enc_value_dim,
            attention_left_only=True,
            att_dropout=self.att_dropout,
            forward_weights_init=self.mhsa_init,
            l2=self.l2,
            key_shift=ln_rel_pos_enc if ln_rel_pos_enc is not None else None)

        linear = subnet_unit.add_linear_layer(
            '{}_linear'.format(prefix),
            att,
            activation=None,
            with_bias=False,
            n_out=self.enc_value_dim,
            forward_weights_init=self.mhsa_init_out,
            l2=self.l2)

        drop = subnet_unit.add_dropout_layer('{}_drop'.format(prefix),
                                             linear,
                                             dropout=self.dropout)

        out = subnet_unit.add_combine_layer('{}_out'.format(prefix),
                                            [drop, source],
                                            kind='add',
                                            n_out=self.enc_value_dim)

        return out

    def _create_mhsa(self, subnet_unit: ReturnnNetwork, prefix, source):
        ln = subnet_unit.add_layer_norm_layer('{}_att_ln'.format(prefix),
                                              source)

        att_query0 = subnet_unit.add_linear_layer(
            '{}_att_query0'.format(prefix),
            ln,
            with_bias=False,
            n_out=self.enc_value_dim,
            forward_weights_init=self.mhsa_init,
            l2=self.l2)

        # (B, H, D/H)
        att_query = subnet_unit.add_split_dim_layer(
            '{}_att_query'.format(prefix),
            att_query0,
            axis='F',
            dims=(self.enc_att_num_heads, self.enc_key_per_head_dim))

        # --------------- Add to the encoder network --------------- #
        att_key0 = self.base_model.network.add_linear_layer(
            '{}_att_key0'.format(prefix),
            'encoder',
            with_bias=False,
            n_out=self.enc_key_dim,
            forward_weights_init=self.mhsa_init,
            l2=self.l2)

        # (B, enc-T, H, D/H)
        att_key = self.base_model.network.add_split_dim_layer(
            '{}_att_key'.format(prefix),
            att_key0,
            axis='F',
            dims=(self.enc_att_num_heads, self.enc_key_per_head_dim))

        att_value0 = self.base_model.network.add_linear_layer(
            '{}_att_value0'.format(prefix),
            'encoder',
            with_bias=False,
            n_out=self.enc_value_dim,
            forward_weights_init=self.mhsa_init,
            l2=self.l2)

        # (B, enc-T, H, D'/H)
        att_value = self.base_model.network.add_split_dim_layer(
            '{}_att_value'.format(prefix),
            att_value0,
            axis='F',
            dims=(self.enc_att_num_heads, self.enc_val_per_head_dim))
        # ----------------------------------------------------------- #

        # (B, H, enc-T, 1)
        att_energy = subnet_unit.add_dot_layer(
            '{}_att_energy'.format(prefix),
            source=['base:' + att_key, att_query],
            red1=-1,
            red2=-1,
            var1='T',
            var2='T?')

        att_weights = subnet_unit.add_softmax_over_spatial_layer(
            '{}_att_weights'.format(prefix),
            att_energy,
            energy_factor=self.enc_key_per_head_dim**-0.5)

        att_weights_drop = subnet_unit.add_dropout_layer(
            '{}_att_weights_drop'.format(prefix),
            att_weights,
            dropout=self.att_dropout,
            dropout_noise_shape={"*": None})

        # (B, H, V)
        att0 = subnet_unit.add_generic_att_layer('{}_att0'.format(prefix),
                                                 weights=att_weights_drop,
                                                 base='base:' + att_value)

        att = subnet_unit.add_merge_dims_layer(
            '{}_att'.format(prefix), att0,
            axes='static')  # (B, H*V) except_batch

        # output projection
        att_linear = subnet_unit.add_linear_layer(
            '{}_att_linear'.format(prefix),
            att,
            with_bias=False,
            n_out=self.enc_value_dim,
            forward_weights_init=self.mhsa_init_out,
            l2=self.l2)

        att_drop = subnet_unit.add_dropout_layer('{}_att_drop'.format(prefix),
                                                 att_linear,
                                                 dropout=self.dropout)

        out = subnet_unit.add_combine_layer('{}_att_out'.format(prefix),
                                            [att_drop, source],
                                            kind='add',
                                            n_out=self.enc_value_dim)
        return out

    def _create_ff_module(self, subnet_unit: ReturnnNetwork, prefix, source):
        ff_ln = subnet_unit.add_layer_norm_layer('{}_ff_ln'.format(prefix),
                                                 source)

        ff1 = subnet_unit.add_linear_layer('{}_ff_conv1'.format(prefix),
                                           ff_ln,
                                           activation=self.ff_act,
                                           forward_weights_init=self.ff_init,
                                           n_out=self.ff_dim,
                                           with_bias=True,
                                           l2=self.l2)

        ff2 = subnet_unit.add_linear_layer('{}_ff_conv2'.format(prefix),
                                           ff1,
                                           activation=None,
                                           forward_weights_init=self.ff_init,
                                           n_out=self.enc_value_dim,
                                           dropout=self.dropout,
                                           with_bias=True,
                                           l2=self.l2)

        drop = subnet_unit.add_dropout_layer('{}_ff_drop'.format(prefix),
                                             ff2,
                                             dropout=self.dropout)

        out = subnet_unit.add_combine_layer('{}_ff_out'.format(prefix),
                                            [drop, source],
                                            kind='add',
                                            n_out=self.enc_value_dim)
        return out

    def _create_decoder_block(self, subnet_unit: ReturnnNetwork, source, i):
        prefix = 'transformer_decoder_%02i' % i
        masked_mhsa = self._create_masked_mhsa(subnet_unit, prefix, source)
        mhsa = self._create_mhsa(subnet_unit, prefix, masked_mhsa)
        ff = self._create_ff_module(subnet_unit, prefix, mhsa)
        out = subnet_unit.add_copy_layer(prefix, ff)
        return out

    def _create_decoder(self, subnet_unit: ReturnnNetwork):

        self.output_prob = subnet_unit.add_softmax_layer(
            'output_prob',
            'decoder',
            loss='ce',
            loss_opts={'label_smoothing': self.label_smoothing},
            target=self.target,
            dropout=self.softmax_dropout,
            forward_weights_init=self.ff_init,
            l2=self.l2)

        output = subnet_unit.add_choice_layer('output',
                                              self.output_prob,
                                              target=self.target,
                                              beam_size=self.beam_size,
                                              initial_output=0)
        subnet_unit.add_compare_layer('end', output, value=0)

        target_embed_raw = subnet_unit.add_linear_layer(
            'target_embed_raw',
            'prev:' + output,
            with_bias=False,
            n_out=self.enc_value_dim,
            forward_weights_init=self.ff_init,
            l2=self.l2)

        if self.embed_weight:
            target_embed_raw = subnet_unit.add_eval_layer(
                'target_embed_weighted',
                target_embed_raw,
                eval='source(0) * %f' % self.embed_weight)

        if self.embed_pos_enc:
            target_embed_raw = subnet_unit.add_pos_encoding_layer(
                'target_embed_pos_enc', target_embed_raw)

        target_embed = subnet_unit.add_dropout_layer(
            'target_embed',
            target_embed_raw,
            dropout=self.embed_dropout,
            dropout_noise_shape={"*": None})

        x = target_embed
        for i in range(1, self.dec_layers + 1):
            x = self._create_decoder_block(subnet_unit, x, i)
        subnet_unit.add_layer_norm_layer('decoder', x)

        dec_output = self.network.add_subnet_rec_layer(
            'output', unit=subnet_unit.get_net(), target=self.target)

        return dec_output

    def create_network(self):
        dec_output = self._create_decoder(self.subnet_unit)

        # decision layer that selects the best hypothesis from the beam
        decision_layer_name = self.base_model.network.add_decide_layer(
            'decision', dec_output, target=self.target)
        self.decision_layer_name = decision_layer_name

        return dec_output
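
A hypothetical usage sketch of the Transformer decoder above (illustrative
values; the ConformerEncoder class defined later in this file stands in for
any base model exposing enc_value_dim, enc_key_dim and att_num_heads):

# Hypothetical wiring (illustrative values).
encoder = ConformerEncoder(input='data', enc_key_dim=256, att_num_heads=4,
                           target='bpe')
encoder.create_network()
decoder = TransformerDecoder(base_model=encoder, target='bpe',
                             dec_layers=6, att_num_heads=8, ff_dim=2048)
decoder.create_network()
print(decoder.decision_layer_name)  # name of the added decide layer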
Code example #15
File: rnn_decoder.py  Project: rwth-i6/i6_experiments
    def __init__(self,
                 base_model,
                 source=None,
                 dropout=0.3,
                 label_smoothing=0.1,
                 target='bpe',
                 beam_size=12,
                 embed_dim=621,
                 embed_dropout=0.,
                 dec_lstm_num_units=1000,
                 dec_output_num_units=1000,
                 l2=None,
                 att_dropout=None,
                 rec_weight_dropout=None,
                 dec_zoneout=False,
                 ff_init=None,
                 add_lstm_lm=False,
                 lstm_lm_dim=1000,
                 loc_conv_att_filter_size=None,
                 loc_conv_att_num_channels=None,
                 reduceout=True,
                 att_num_heads=1,
                 embed_weight_init=None,
                 lstm_weights_init=None):
        """
    :param base_model: base/encoder model instance
    :param str source: input to decoder subnetwork
    :param float dropout: Dropout applied to the softmax input
    :param float label_smoothing: label smoothing value applied to softmax
    :param str target: target data key name
    :param int beam_size: value of the beam size
    :param int embed_dim: target embedding dimension
    :param float|None embed_dropout: dropout to be applied on the target embedding
    :param int dec_lstm_num_units: the number of hidden units for the decoder LSTM
    :param int dec_output_num_units: the number of hidden dimensions for the last layer before softmax
    :param float|None l2: weight decay with l2 norm
    :param float|None att_dropout: dropout applied to attention weights
    :param float|None rec_weight_dropout: dropout applied to weight parameters
    :param bool dec_zoneout: if set, zoneout LSTM cell is used in the decoder instead of nativelstm2
    :param str|None ff_init: feed-forward weights initialization
    :param bool add_lstm_lm: add separate LSTM layer that acts as LM-like model
      same as here: https://arxiv.org/abs/2001.07263
    :param int lstm_lm_dim: dimension of the LM-like LSTM layer
    :param int|None loc_conv_att_filter_size:
    :param int|None loc_conv_att_num_channels:
    :param bool reduceout: if set to True, maxout layer is used
    :param int att_num_heads: number of attention heads
    """

        self.base_model = base_model

        self.source = source

        self.dropout = dropout
        self.label_smoothing = label_smoothing

        self.enc_key_dim = base_model.enc_key_dim
        self.enc_value_dim = base_model.enc_value_dim
        self.att_num_heads = att_num_heads

        self.target = target

        self.beam_size = beam_size

        self.embed_dim = embed_dim
        self.embed_dropout = embed_dropout

        self.dec_lstm_num_units = dec_lstm_num_units
        self.dec_output_num_units = dec_output_num_units

        self.ff_init = ff_init

        self.decision_layer_name = None  # this is set in the end-point config

        self.l2 = l2
        self.att_dropout = att_dropout
        self.rec_weight_dropout = rec_weight_dropout
        self.dec_zoneout = dec_zoneout

        self.add_lstm_lm = add_lstm_lm
        self.lstm_lm_dim = lstm_lm_dim

        self.loc_conv_att_filter_size = loc_conv_att_filter_size
        self.loc_conv_att_num_channels = loc_conv_att_num_channels

        self.embed_weight_init = embed_weight_init
        self.lstm_weights_init = lstm_weights_init

        self.reduceout = reduceout

        self.network = ReturnnNetwork()
        self.subnet_unit = ReturnnNetwork()
        self.dec_output = None
Code example #18
class ConformerEncoder:
    """
  Represents Conformer Encoder Architecture

  * Conformer: Convolution-augmented Transformer for Speech Recognition
  * Ref: https://arxiv.org/abs/2005.08100
  """
    def __init__(self,
                 input='data',
                 input_layer='conv',
                 num_blocks=16,
                 conv_kernel_size=32,
                 specaug=True,
                 pos_enc='rel',
                 activation='swish',
                 block_final_norm=True,
                 ff_dim=512,
                 ff_bias=True,
                 ctc_loss_scale=None,
                 dropout=0.1,
                 att_dropout=0.1,
                 enc_key_dim=256,
                 att_num_heads=4,
                 target='bpe',
                 l2=0.0,
                 lstm_dropout=0.1,
                 rec_weight_dropout=0.,
                 with_ctc=False,
                 native_ctc=False,
                 ctc_dropout=0.,
                 ctc_l2=0.,
                 ctc_opts=None,
                 subsample=None,
                 start_conv_init=None,
                 conv_module_init=None,
                 mhsa_init=None,
                 mhsa_out_init=None,
                 ff_init=None,
                 rel_pos_clipping=16,
                 dropout_in=0.1,
                 stoc_layers_prob=0.0,
                 batch_norm_opts=None,
                 pytorch_bn_opts=False,
                 use_ln=False,
                 pooling_str=None,
                 self_att_l2=0.0,
                 sandwich_conv=False):
        """
    :param str input: input layer name
    :param str input_layer: type of input layer which does subsampling
    :param int num_blocks: number of Conformer blocks
    :param int conv_kernel_size: kernel size for conv layers in Convolution module
    :param bool|None specaug: If true, then SpecAugment is applied
    :param str|None activation: activation used to sandwich modules
    :param bool block_final_norm: if True, apply layer norm at the end of each conformer block
    :param int|None ff_dim: dimension of the first linear layer in FF module
    :param str|None ff_init: FF layers initialization
    :param bool|None ff_bias: If true, then bias is used for the FF layers
    :param float dropout: general dropout
    :param float att_dropout: dropout applied to attention weights
    :param int enc_key_dim: encoder key dimension, also denoted as d_model, or d_key
    :param int att_num_heads: the number of attention heads
    :param str target: target labels key name
    :param float l2: add L2 regularization for trainable weights parameters
    :param float lstm_dropout: dropout applied to the input of the LSTMs in case they are used
    :param float rec_weight_dropout: dropout applied to the hidden-to-hidden weight matrices of the LSTM in case used
    :param bool with_ctc: if true, CTC loss is used
    :param bool native_ctc: if true, use returnn native ctc implementation instead of TF implementation
    :param float ctc_dropout: dropout applied on input to ctc
    :param float ctc_l2: L2 applied to the weight matrix of CTC softmax
    :param dict[str] ctc_opts: options for CTC
    """

        self.input = input
        self.input_layer = input_layer

        self.num_blocks = num_blocks
        self.conv_kernel_size = conv_kernel_size

        self.pos_enc = pos_enc
        self.rel_pos_clipping = rel_pos_clipping

        self.ff_bias = ff_bias

        self.specaug = specaug

        self.activation = activation

        self.block_final_norm = block_final_norm

        self.dropout = dropout
        self.att_dropout = att_dropout
        self.lstm_dropout = lstm_dropout

        self.dropout_in = dropout_in

        # key and value dimensions are the same
        self.enc_key_dim = enc_key_dim
        self.enc_value_dim = enc_key_dim
        self.att_num_heads = att_num_heads
        self.enc_key_per_head_dim = enc_key_dim // att_num_heads
        self.enc_val_per_head_dim = enc_key_dim // att_num_heads

        self.ff_dim = ff_dim
        if self.ff_dim is None:
            self.ff_dim = 2 * self.enc_key_dim

        self.target = target

        self.l2 = l2
        self.self_att_l2 = self_att_l2
        self.rec_weight_dropout = rec_weight_dropout

        if batch_norm_opts is None:
            batch_norm_opts = {}

        if pytorch_bn_opts:
            batch_norm_opts['momentum'] = 0.1
            batch_norm_opts['epsilon'] = 1e-3
            batch_norm_opts['update_sample_only_in_training'] = True
            batch_norm_opts['delay_sample_update'] = True

        self.batch_norm_opts = batch_norm_opts

        self.with_ctc = with_ctc
        self.native_ctc = native_ctc
        self.ctc_dropout = ctc_dropout
        self.ctc_loss_scale = ctc_loss_scale
        self.ctc_l2 = ctc_l2
        self.ctc_opts = ctc_opts
        if not self.ctc_opts:
            self.ctc_opts = {}

        self.start_conv_init = start_conv_init
        self.conv_module_init = conv_module_init
        self.mhsa_init = mhsa_init
        self.mhsa_out_init = mhsa_out_init
        self.ff_init = ff_init

        self.sandwich_conv = sandwich_conv

        # per-block subsampling factors (applied via max-pooling)
        self.subsample = subsample
        self.subsample_list = [1] * num_blocks
        if subsample:
            for idx, s in enumerate(map(int,
                                        subsample.split('_')[:num_blocks])):
                self.subsample_list[idx] = s

        self.network = ReturnnNetwork()

        self.stoc_layers_prob = stoc_layers_prob
        if stoc_layers_prob:
            # this is only used to define the shape for the dropout mask (it needs source)
            self.mask_var = self.network.add_variable_layer('mask_var',
                                                            shape=(1, ),
                                                            init=1)

        self.use_ln = use_ln

        self.pooling_str = pooling_str

    def _get_stoc_layer_dropout(self, layer_index):
        """
    Returns the probability to drop a layer
      p_l = l / L * (1 - p)  where p is a hyperparameter

    :param int layer_index: index of layer
    :rtype: float
    """
        return layer_index / self.num_blocks * (1 - self.stoc_layers_prob)
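
    # Worked example of the drop probability above (illustrative numbers):
    # with num_blocks=16 and stoc_layers_prob=0.9, block 4 is dropped with
    # probability 4 / 16 * (1 - 0.9) = 0.025, and its surviving residual
    # branch is scaled by 1 / (1 - 0.025) ~= 1.026 in _add_stoc_res_layer.
    assert abs(4 / 16 * (1 - 0.9) - 0.025) < 1e-12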

    def _add_stoc_res_layer(self, prefix_name, f_x, x, layer_index):
        """
    Add stochastic layer to the network. The layer output will be masked and scaled:
      M * F(x) * (1 / (1 - p_l))

    :param prefix_name: prefix name for layer
    :param f_x: module output, e.g. self-attention or FF
    :param x: input
    :param int layer_index: index of layer
    :rtype: list[str]
    """
        stoc_layer_drop = self._get_stoc_layer_dropout(layer_index)
        # note the parentheses: 1 / 1 - p would evaluate to 1.0 - p
        stoc_scale = 1 / (1 - stoc_layer_drop)
        mask = self.network.add_dropout_layer(
            'stoc_layer{}_mask'.format(layer_index), self.mask_var,
            stoc_layer_drop)
        masked_and_scaled_out = self.network.add_eval_layer(
            '{}_scaled_mask_layer'.format(prefix_name), [mask, f_x],
            eval='source(0) * source(1) * {}'.format(stoc_scale))
        return [masked_and_scaled_out, x]

    def _create_ff_module(self, prefix_name, i, source, layer_index):
        """
    Add Feed Forward Module:
      LN -> FFN -> Swish -> Dropout -> FFN -> Dropout

    :param str prefix_name: some prefix name
    :param int i: FF module index
    :param str source: name of source layer
    :param int layer_index: index of layer
    :return: last layer name of this module
    :rtype: str
    """
        prefix_name = prefix_name + '_ffmod_{}'.format(i)

        ln = self.network.add_layer_norm_layer('{}_ln'.format(prefix_name),
                                               source)

        ff1 = self.network.add_linear_layer('{}_ff1'.format(prefix_name),
                                            ln,
                                            n_out=self.ff_dim,
                                            l2=self.l2,
                                            forward_weights_init=self.ff_init,
                                            with_bias=self.ff_bias)

        swish_act = self.network.add_activation_layer(
            '{}_swish'.format(prefix_name), ff1, activation=self.activation)

        drop1 = self.network.add_dropout_layer('{}_drop1'.format(prefix_name),
                                               swish_act,
                                               dropout=self.dropout)

        ff2 = self.network.add_linear_layer('{}_ff2'.format(prefix_name),
                                            drop1,
                                            n_out=self.enc_key_dim,
                                            l2=self.l2,
                                            forward_weights_init=self.ff_init,
                                            with_bias=self.ff_bias)

        drop2 = self.network.add_dropout_layer('{}_drop2'.format(prefix_name),
                                               ff2,
                                               dropout=self.dropout)

        half_step_ff = self.network.add_eval_layer(
            '{}_half_step'.format(prefix_name), drop2, eval='0.5 * source(0)')

        res_inputs = [half_step_ff, source]

        if self.stoc_layers_prob:
            res_inputs = self._add_stoc_res_layer(prefix_name,
                                                  f_x=half_step_ff,
                                                  x=source,
                                                  layer_index=layer_index)

        ff_module_res = self.network.add_combine_layer(
            '{}_res'.format(prefix_name),
            kind='add',
            source=res_inputs,
            n_out=self.enc_key_dim)

        return ff_module_res

    def _create_mhsa_module(self, prefix_name, source, layer_index):
        """
    Add Multi-Headed Selft-Attention Module:
      LN + MHSA + Dropout

    :param str prefix: some prefix name
    :param str source: name of source layer
    :param int layer_index: index of layer
    :return: last layer name of this module
    :rtype: str
    """
        prefix_name = '{}_self_att'.format(prefix_name)
        ln = self.network.add_layer_norm_layer('{}_ln'.format(prefix_name),
                                               source)
        ln_rel_pos_enc = None

        if self.pos_enc == 'rel':
            ln_rel_pos_enc = self.network.add_relative_pos_encoding_layer(
                '{}_ln_rel_pos_enc'.format(prefix_name),
                ln,
                n_out=self.enc_key_per_head_dim,
                forward_weights_init=self.ff_init,
                clipping=self.rel_pos_clipping)

        mhsa = self.network.add_self_att_layer(
            '{}'.format(prefix_name),
            ln,
            n_out=self.enc_value_dim,
            num_heads=self.att_num_heads,
            total_key_dim=self.enc_key_dim,
            att_dropout=self.att_dropout,
            forward_weights_init=self.mhsa_init,
            key_shift=ln_rel_pos_enc if ln_rel_pos_enc is not None else None,
            l2=self.self_att_l2)

        mhsa_linear = self.network.add_linear_layer(
            '{}_linear'.format(prefix_name),
            mhsa,
            n_out=self.enc_key_dim,
            l2=self.l2,
            forward_weights_init=self.mhsa_out_init,
            with_bias=False)

        drop = self.network.add_dropout_layer('{}_dropout'.format(prefix_name),
                                              mhsa_linear,
                                              dropout=self.dropout)

        res_inputs = [drop, source]

        if self.stoc_layers_prob:
            res_inputs = self._add_stoc_res_layer(prefix_name,
                                                  f_x=drop,
                                                  x=source,
                                                  layer_index=layer_index)

        mhsa_res = self.network.add_combine_layer('{}_res'.format(prefix_name),
                                                  kind='add',
                                                  source=res_inputs,
                                                  n_out=self.enc_value_dim)
        return mhsa_res
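
The rel_pos_clipping option follows the clipped relative-position scheme of Shaw et al. (2018); a minimal sketch of the clipping itself (illustrative, not the RETURNN layer):

def clipped_relative_position(query_pos, key_pos, clipping=16):
    # distances beyond +/- clipping all share the boundary embedding
    return max(-clipping, min(clipping, key_pos - query_pos))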

    def _create_convolution_module(self,
                                   prefix_name,
                                   source,
                                   layer_index,
                                   half_step=False):
        """
    Add Convolution Module:
      LN + point-wise-conv + GLU + depth-wise-conv + BN + Swish + point-wise-conv + Dropout

    :param str prefix_name: some prefix name
    :param str source: name of source layer
    :param int layer_index: index of layer
    :return: last layer name of this module
    :rtype: str
    """
        prefix_name = '{}_conv_mod'.format(prefix_name)

        ln = self.network.add_layer_norm_layer('{}_ln'.format(prefix_name),
                                               source)

        pointwise_conv1 = self.network.add_linear_layer(
            '{}_pointwise_conv1'.format(prefix_name),
            ln,
            n_out=2 * self.enc_key_dim,
            activation=None,
            l2=self.l2,
            with_bias=self.ff_bias,
            forward_weights_init=self.conv_module_init)

        glu_act = self.network.add_gating_layer('{}_glu'.format(prefix_name),
                                                pointwise_conv1)

        depthwise_conv = self.network.add_conv_layer(
            '{}_depthwise_conv2'.format(prefix_name),
            glu_act,
            n_out=self.enc_key_dim,
            filter_size=(self.conv_kernel_size, ),
            groups=self.enc_key_dim,
            l2=self.l2,
            forward_weights_init=self.conv_module_init)

        if self.use_ln:  # optionally replace batch norm with layer norm
            bn = self.network.add_layer_norm_layer(
                '{}_layer_norm'.format(prefix_name), depthwise_conv)
        else:
            bn = self.network.add_batch_norm_layer('{}_bn'.format(prefix_name),
                                                   depthwise_conv,
                                                   opts=self.batch_norm_opts)

        swish_act = self.network.add_activation_layer(
            '{}_swish'.format(prefix_name), bn, activation='swish')

        pointwise_conv2 = self.network.add_linear_layer(
            '{}_pointwise_conv2'.format(prefix_name),
            swish_act,
            n_out=self.enc_key_dim,
            activation=None,
            l2=self.l2,
            with_bias=self.ff_bias,
            forward_weights_init=self.conv_module_init)

        drop = self.network.add_dropout_layer('{}_drop'.format(prefix_name),
                                              pointwise_conv2,
                                              dropout=self.dropout)

        if half_step:
            drop = self.network.add_eval_layer(
                '{}_half_step'.format(prefix_name),
                drop,
                eval='0.5 * source(0)')

        res_inputs = [drop, source]

        if self.stoc_layers_prob:
            res_inputs = self._add_stoc_res_layer(prefix_name,
                                                  f_x=drop,
                                                  x=source,
                                                  layer_index=layer_index)

        res = self.network.add_combine_layer('{}_res'.format(prefix_name),
                                             kind='add',
                                             source=res_inputs,
                                             n_out=self.enc_key_dim)
        return res
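
A minimal NumPy sketch of the gating (GLU) step after the first pointwise convolution; the sigmoid-gated split is the standard GLU formulation, not code from this repository:

import numpy as np

def glu(x):
    """x: [..., 2*D] -> [..., D]; one half of the features gates the other."""
    a, b = np.split(x, 2, axis=-1)
    return a * (1.0 / (1.0 + np.exp(-b)))  # a * sigmoid(b)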

    def _create_conformer_block(self, i, source):
        """
    Add the whole Conformer block:
      x1 = x0 + 1/2 * FFN(x0)             (FFN module 1)
      x2 = x1 + MHSA(x1)                  (MHSA)
      x3 = x2 + Conv(x2)                  (Conv module)
      x4 = LayerNorm(x3 + 1/2 * FFN(x3))  (FFN module 2)

    :param int i: layer index
    :param str source: name of source layer
    :return: last layer name of this module
    :rtype: str
    """
        prefix_name = 'conformer_block_%02i' % i
        ff_module1 = self._create_ff_module(prefix_name, 1, source, i)

        mhsa_input = ff_module1
        if self.sandwich_conv:
            conv_module1 = self._create_convolution_module(prefix_name +
                                                           '_sandwich',
                                                           ff_module1,
                                                           i,
                                                           half_step=True)
            mhsa_input = conv_module1
        mhsa = self._create_mhsa_module(prefix_name, mhsa_input, i)
        conv_module = self._create_convolution_module(
            prefix_name, mhsa, i, half_step=self.sandwich_conv)

        ff_module2 = self._create_ff_module(prefix_name, 2, conv_module, i)
        res = ff_module2
        if self.block_final_norm:
            res = self.network.add_layer_norm_layer(
                '{}_ln'.format(prefix_name), res)
        if self.subsample:
            assert 0 <= i - 1 < len(self.subsample_list)
            subsample_factor = self.subsample_list[i - 1]
            if subsample_factor > 1:
                res = self.network.add_pool_layer(
                    res + '_pool{}'.format(i),
                    res,
                    pool_size=(subsample_factor, ))
        res = self.network.add_copy_layer(prefix_name, res)
        return res
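
Putting the four modules together, a self-contained sketch of the residual wiring described in the docstring above (the module callables are assumed placeholders):

def conformer_block(x, ffn1, mhsa, conv, ffn2, layer_norm):
    x = x + 0.5 * ffn1(x)                 # Macaron-style half-step FFN module 1
    x = x + mhsa(x)                       # multi-head self-attention module
    x = x + conv(x)                       # convolution module
    return layer_norm(x + 0.5 * ffn2(x))  # half-step FFN module 2 + final LayerNorm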

    def create_network(self):
        """
    ConvSubsampling/LSTM -> Linear -> Dropout -> [Conformer Blocks] x N
    """
        data = self.input
        if self.specaug:
            data = self.network.add_eval_layer(
                'source',
                data,
                eval=
                "self.network.get_config().typed_value('transform')(source(0, as_data=True), network=self.network)"
            )

        subsampled_input = None
        if self.input_layer is None:
            subsampled_input = data
        elif 'lstm' in self.input_layer:
            sample_factor = int(self.input_layer.split('-')[1])
            pool_sizes = None
            if sample_factor == 2:
                pool_sizes = [2, 1]
            elif sample_factor == 4:
                pool_sizes = [2, 2]
            elif sample_factor == 6:
                pool_sizes = [3, 2]
            # add 2 LSTM layers with max pooling to subsample and encode positional information
            subsampled_input = self.network.add_lstm_layers(
                data,
                num_layers=2,
                lstm_dim=self.enc_key_dim,
                dropout=self.lstm_dropout,
                bidirectional=True,
                rec_weight_dropout=self.rec_weight_dropout,
                l2=self.l2,
                pool_sizes=pool_sizes)
        elif self.input_layer == 'conv':
            # subsample by 4
            subsampled_input = self.network.add_conv_block(
                'conv_merged',
                data,
                hwpc_sizes=[((3, 3), (2, 2), self.enc_key_dim),
                            ((3, 3), (2, 2), self.enc_key_dim)],
                l2=self.l2,
                activation='relu',
                init=self.start_conv_init)
        elif self.input_layer == 'vgg':
            subsampled_input = self.network.add_conv_block(
                'vgg_conv_merged',
                data,
                hwpc_sizes=[((3, 3), (2, 2), 32), ((3, 3), (2, 2), 64)],
                l2=self.l2,
                activation='relu',
                init=self.start_conv_init)
        elif self.input_layer == 'neural_sp_conv':
            subsampled_input = self.network.add_conv_block(
                'conv_merged',
                data,
                hwpc_sizes=[((3, 3), (1, 1), 32), ((3, 3), (2, 2), 32)],
                l2=self.l2,
                activation='relu',
                init=self.start_conv_init)

        assert subsampled_input is not None

        source_linear = self.network.add_linear_layer(
            'source_linear',
            subsampled_input,
            n_out=self.enc_key_dim,
            l2=self.l2,
            forward_weights_init=self.ff_init,
            with_bias=False)

        # add positional encoding
        if self.pos_enc == 'abs':
            source_linear = self.network.add_pos_encoding_layer(
                '{}_abs_pos_enc'.format(subsampled_input), source_linear)

        if self.dropout_in:
            source_linear = self.network.add_dropout_layer(
                'source_dropout', source_linear, dropout=self.dropout_in)

        conformer_block_src = source_linear
        for i in range(1, self.num_blocks + 1):
            conformer_block_src = self._create_conformer_block(
                i, conformer_block_src)

        encoder = self.network.add_copy_layer('encoder', conformer_block_src)

        if self.with_ctc:
            default_ctc_loss_opts = {'beam_width': 1}
            if self.native_ctc:
                default_ctc_loss_opts['use_native'] = True
            else:
                self.ctc_opts.update(
                    {"ignore_longer_outputs_than_inputs": True})  # always enabled
            if self.ctc_opts:
                default_ctc_loss_opts['ctc_opts'] = self.ctc_opts
            self.network.add_softmax_layer('ctc',
                                           encoder,
                                           l2=self.ctc_l2,
                                           target=self.target,
                                           loss='ctc',
                                           dropout=self.ctc_dropout,
                                           loss_opts=default_ctc_loss_opts,
                                           loss_scale=self.ctc_loss_scale)
        return encoder
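
For orientation, a hedged usage sketch of the method above; the class name ConformerEncoder and the constructor arguments shown are assumptions, since the class signature is cut off in this excerpt:

encoder = ConformerEncoder(input='data', input_layer='lstm-6')  # hypothetical args
encoder.create_network()                  # builds the frontend, N blocks and the optional CTC branch
network_dict = encoder.network.get_net()  # plain dict for the RETURNN config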
Code example #19
0
File: rnn_decoder.py Project: rwth-i6/i6_experiments
    def add_decoder_subnetwork(self, subnet_unit: ReturnnNetwork):

        # target embedding
        if self.embed_dropout:
            # TODO: this is not a clean approach: to load a checkpoint from a model trained
            # without embedding dropout, the variable name target_embed would need to be
            # remapped to target_embed0 so that target_embed0/W can be loaded
            subnet_unit.add_linear_layer('target_embed0',
                                         'output',
                                         n_out=self.embed_dim,
                                         initial_output=0,
                                         with_bias=False)
            subnet_unit.add_dropout_layer('target_embed',
                                          'target_embed0',
                                          dropout=self.embed_dropout,
                                          dropout_noise_shape={'*': None})
        else:
            subnet_unit.add_linear_layer('target_embed',
                                         'output',
                                         n_out=self.embed_dim,
                                         initial_output=0,
                                         with_bias=False)

        subnet_unit.add_compare_layer('end', source='output',
                                      value=0)  # sentence end token

        # ------ location-aware attention ------ #

        # conv-based
        if self.loc_conv_att_filter_size:
            assert self.loc_conv_att_num_channels
            pad_left = subnet_unit.add_pad_layer(
                'feedback_pad_left',
                'prev:att_weights',
                axes='s:0',
                padding=((self.loc_conv_att_filter_size - 1) // 2, 0),
                value=0)
            pad_right = subnet_unit.add_pad_layer(
                'feedback_pad_right',
                pad_left,
                axes='s:0',
                padding=(0, (self.loc_conv_att_filter_size - 1) // 2),
                value=0)
            loc_att_conv = subnet_unit.add_conv_layer(
                'loc_att_conv',
                pad_right,
                activation=None,
                with_bias=False,
                filter_size=(self.loc_conv_att_filter_size, ),
                padding='valid',
                n_out=self.loc_conv_att_num_channels,
                l2=self.l2)
            subnet_unit.add_linear_layer('weight_feedback',
                                         loc_att_conv,
                                         activation=None,
                                         with_bias=False,
                                         n_out=self.enc_key_dim)
        else:
            # accumulated attention weights (fertility-based feedback)
            subnet_unit.add_eval_layer(
                'accum_att_weights', [
                    "prev:accum_att_weights", "att_weights",
                    "base:inv_fertility"
                ],
                eval='source(0) + source(1) * source(2) * 0.5',
                out_type={
                    "dim": self.att_num_heads,
                    "shape": (None, self.att_num_heads)
                })
            subnet_unit.add_linear_layer('weight_feedback',
                                         'prev:accum_att_weights',
                                         n_out=self.enc_key_dim,
                                         with_bias=False)

        subnet_unit.add_linear_layer('s_transformed',
                                     's',
                                     n_out=self.enc_key_dim,
                                     with_bias=False)
        subnet_unit.add_combine_layer(
            'energy_in', ['base:enc_ctx', 'weight_feedback', 's_transformed'],
            kind='add',
            n_out=self.enc_key_dim)
        subnet_unit.add_activation_layer('energy_tanh',
                                         'energy_in',
                                         activation='tanh')
        subnet_unit.add_linear_layer('energy',
                                     'energy_tanh',
                                     n_out=self.att_num_heads,
                                     with_bias=False)

        if self.att_dropout:
            subnet_unit.add_softmax_over_spatial_layer('att_weights0',
                                                       'energy')
            subnet_unit.add_dropout_layer('att_weights',
                                          'att_weights0',
                                          dropout=self.att_dropout,
                                          dropout_noise_shape={'*': None})
        else:
            if self.relax_att_scale:
                subnet_unit.add_softmax_over_spatial_layer(
                    'att_weights0', 'energy')
                subnet_unit.add_length_layer('encoder_len',
                                             'base:encoder',
                                             dtype='float32')  # [B]
                subnet_unit.add_eval_layer('scaled_encoder_len',
                                           source=['encoder_len'],
                                           eval='{} / source(0)'.format(
                                               self.relax_att_scale))
                subnet_unit.add_eval_layer(
                    'att_weights',
                    source=['att_weights0', 'scaled_encoder_len'],
                    eval='{} * source(0) + source(1)'.format(
                        1 - self.relax_att_scale))
            else:
                subnet_unit.add_softmax_over_spatial_layer(
                    'att_weights', 'energy')

        subnet_unit.add_generic_att_layer('att0',
                                          weights='att_weights',
                                          base='base:enc_value')
        subnet_unit.add_merge_dims_layer('att', 'att0', axes='except_batch')

        # LM-like component, same as in https://arxiv.org/pdf/2001.07263.pdf
        lstm_lm_component = None
        if self.add_lstm_lm:
            lstm_lm_component = subnet_unit.add_rnn_cell_layer(
                'lm_like_s',
                'prev:target_embed',
                n_out=self.lstm_lm_dim,
                l2=self.l2)

        lstm_inputs = []
        if lstm_lm_component:
            lstm_inputs += [lstm_lm_component]
        else:
            lstm_inputs += ['prev:target_embed']
        lstm_inputs += ['prev:att']

        if self.dec_state_no_label_ctx:
            lstm_inputs = ['prev:att']  # no label feedback

        # LSTM decoder
        if self.dec_zoneout:
            subnet_unit.add_rnn_cell_layer('s',
                                           lstm_inputs,
                                           n_out=self.dec_lstm_num_units,
                                           unit='zoneoutlstm',
                                           unit_opts={
                                               'zoneout_factor_cell': 0.15,
                                               'zoneout_factor_output': 0.05
                                           })
        else:
            if self.rec_weight_dropout:
                # a rec layer with unit NativeLSTM2 is required to use rec_weight_dropout
                subnet_unit.add_rec_layer(
                    's',
                    lstm_inputs,
                    n_out=self.dec_lstm_num_units,
                    l2=self.l2,
                    rec_weight_dropout=self.rec_weight_dropout,
                    unit='NativeLSTM2')
            else:
                subnet_unit.add_rnn_cell_layer('s',
                                               lstm_inputs,
                                               n_out=self.dec_lstm_num_units,
                                               l2=self.l2)

        # AM softmax output layer
        if self.dec_state_no_label_ctx and self.add_lstm_lm:
            readout_sources = ["lm_like_s", "prev:target_embed", "att"]
            if self.add_no_label_ctx_s_to_output:
                readout_sources = ["lm_like_s", "s", "prev:target_embed", "att"]
            subnet_unit.add_linear_layer('readout_in',
                                         readout_sources,
                                         n_out=self.dec_output_num_units)
        else:
            subnet_unit.add_linear_layer('readout_in',
                                         ["s", "prev:target_embed", "att"],
                                         n_out=self.dec_output_num_units)

        if self.reduceout:
            subnet_unit.add_reduceout_layer('readout', 'readout_in')
        else:
            subnet_unit.add_copy_layer('readout', 'readout_in')

        if self.local_fusion_opts:
            output_prob = subnet_unit.add_softmax_layer('output_prob',
                                                        'readout',
                                                        l2=self.l2,
                                                        target=self.target,
                                                        dropout=self.dropout)
            self._add_local_fusion(subnet_unit, am_output_prob=output_prob)
        elif self.mwer:
            # MWER-only training, so the CE loss is disabled
            output_prob = subnet_unit.add_softmax_layer('output_prob',
                                                        'readout',
                                                        l2=self.l2,
                                                        target=self.target,
                                                        dropout=self.dropout)
        else:
            ce_loss_opts = {'label_smoothing': self.label_smoothing}
            if self.ce_loss_scale:
                ce_loss_opts['scale'] = self.ce_loss_scale
            output_prob = subnet_unit.add_softmax_layer('output_prob',
                                                        'readout',
                                                        l2=self.l2,
                                                        loss='ce',
                                                        loss_opts=ce_loss_opts,
                                                        target=self.target,
                                                        dropout=self.dropout)

        # do not load the bias
        if self.remove_softmax_bias:
            subnet_unit['output_prob']['with_bias'] = False

        # for prior LM estimation
        prior_output_prob = None
        if self.prior_lm_opts:
            prior_output_prob = self._create_prior_net(
                subnet_unit, self.prior_lm_opts
            )  # this requires preload_from_files in the config

        # Beam search
        # only shallow fusion is supported for now
        if self.ext_lm_opts:
            self._add_external_LM(subnet_unit, output_prob, prior_output_prob)
        else:
            if self.coverage_term_scale:
                output_prob = subnet_unit.add_eval_layer(
                    'combo_output_prob',
                    eval='safe_log(source(0)) + {} * source(1)'.format(
                        self.coverage_term_scale),
                    source=['output_prob', 'accum_coverage'])
                input_type = 'log_prob'
            else:
                output_prob = 'output_prob'
                input_type = None

            if self.length_norm:
                subnet_unit.add_choice_layer('output',
                                             output_prob,
                                             target=self.target,
                                             beam_size=self.beam_size,
                                             initial_output=0,
                                             input_type=input_type)
            else:
                subnet_unit.add_choice_layer(
                    'output',
                    output_prob,
                    target=self.target,
                    beam_size=self.beam_size,
                    initial_output=0,
                    length_normalization=self.length_norm,
                    input_type=input_type)

        if self.ilmt_opts:
            self._create_ilmt_net(subnet_unit)

        # recurrent subnetwork
        dec_output = self.network.add_subnet_rec_layer(
            'output',
            unit=subnet_unit.get_net(),
            target=self.target,
            source=self.source)

        return dec_output
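
To make the energy computation in the subnetwork above concrete, here is a minimal NumPy sketch of single-head additive (MLP) attention; names and shapes are illustrative, not the RETURNN layers:

import numpy as np

def additive_attention(enc_ctx, weight_feedback, s_transformed, v):
    """enc_ctx: [T, D], weight_feedback: [T, D], s_transformed: [D], v: [D] -> weights [T]."""
    energy = np.tanh(enc_ctx + weight_feedback + s_transformed) @ v  # [T]
    weights = np.exp(energy - energy.max())
    return weights / weights.sum()  # softmax over the encoder time axis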
Code example #20
0
    def __init__(self,
                 input='data',
                 enc_layers=6,
                 bidirectional=True,
                 residual_lstm=False,
                 residual_proj_dim=None,
                 specaug=True,
                 with_conv=True,
                 dropout=0.3,
                 pool_sizes='3_2',
                 lstm_dim=None,
                 enc_key_dim=1024,
                 enc_value_dim=2048,
                 att_num_heads=1,
                 target='bpe',
                 l2=None,
                 rec_weight_dropout=None,
                 with_ctc=False,
                 ctc_dropout=0.,
                 ctc_l2=0.,
                 ctc_opts=None,
                 enc_proj_dim=None,
                 ctc_loss_scale=None,
                 conv_time_pooling=None):
        """
    :param str input: (layer) name of the network input
    :param int enc_layers: the number of encoder layers
    :param bool bidirectional: If set, bidirectional LSTMs are used
    :param bool specaug: If True, SpecAugment is used
    :param bool with_conv: if True, conv layers are applied initially
    :param float dropout: Dropout applied on the input of multiple layers
    :param str|int|List[int]|None pool_sizes: a list of pool sizes between LSTM layers
    :param int enc_key_dim: attention key dimension
    :param int enc_value_dim: attention value dimension
    :param int att_num_heads: number of attention heads
    :param str target: target data key name
    :param float|None l2: weight decay with l2 norm
    :param float|None rec_weight_dropout: dropout applied to the hidden-to-hidden LSTM weight matrices
    :param bool with_ctc: if set, CTC is used
    :param float ctc_dropout: dropout applied on input to ctc
    :param float ctc_l2: L2 applied to the weight matrix of CTC softmax
    :param dict[str] ctc_opts: options for CTC
    """

        self.input = input
        self.enc_layers = enc_layers

        if pool_sizes is not None:
            if isinstance(pool_sizes, str):
                pool_sizes = list(map(
                    int, pool_sizes.split('_'))) + [1] * (enc_layers - 3)
            elif isinstance(pool_sizes, int):
                pool_sizes = [pool_sizes] * (self.enc_layers - 1)

            assert isinstance(pool_sizes, list), 'pool_sizes must be a list'
            assert all(isinstance(e, int) for e in pool_sizes), \
                'pool_sizes must only contain integers'
            assert len(pool_sizes) < enc_layers

        self.pool_sizes = pool_sizes

        if conv_time_pooling is None:
            self.conv_time_pooling = [1, 1]
        else:
            self.conv_time_pooling = list(
                map(int, conv_time_pooling.split('_')))

        self.bidirectional = bidirectional

        self.residual_lstm = residual_lstm
        self.residual_proj_dim = residual_proj_dim

        self.specaug = specaug
        self.with_conv = with_conv
        self.dropout = dropout

        self.enc_key_dim = enc_key_dim
        self.enc_value_dim = enc_value_dim
        self.att_num_heads = att_num_heads
        self.enc_key_per_head_dim = enc_key_dim // att_num_heads
        self.enc_val_per_head_dim = enc_value_dim // att_num_heads
        self.lstm_dim = lstm_dim
        if lstm_dim is None:
            self.lstm_dim = enc_value_dim // 2

        self.target = target

        self.l2 = l2
        self.rec_weight_dropout = rec_weight_dropout

        self.with_ctc = with_ctc
        self.ctc_dropout = ctc_dropout
        self.ctc_l2 = ctc_l2
        self.ctc_loss_scale = ctc_loss_scale
        self.ctc_opts = ctc_opts
        if self.ctc_opts is None:
            self.ctc_opts = {}

        self.enc_proj_dim = enc_proj_dim

        self.network = ReturnnNetwork()
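
A standalone sketch of the pool_sizes parsing in the constructor above; a string such as '3_2' means 3x then 2x max-pooling, padded with 1s for the remaining layer transitions:

def parse_pool_sizes(pool_sizes, enc_layers):
    # note: the string branch assumes exactly two pooling factors, as in '3_2'
    if isinstance(pool_sizes, str):
        pool_sizes = list(map(int, pool_sizes.split('_'))) + [1] * (enc_layers - 3)
    elif isinstance(pool_sizes, int):
        pool_sizes = [pool_sizes] * (enc_layers - 1)
    return pool_sizes

assert parse_pool_sizes('3_2', 6) == [3, 2, 1, 1, 1]  # 3 * 2 = 6x time reduction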
Code example #21
0
File: rnn_decoder.py Project: rwth-i6/i6_experiments
    def add_decoder_subnetwork(self, subnet_unit: ReturnnNetwork):

        subnet_unit.add_compare_layer('end', source='output',
                                      value=0)  # sentence end token

        # target embedding
        subnet_unit.add_linear_layer(
            'target_embed0',
            'output',
            n_out=self.embed_dim,
            initial_output=0,
            with_bias=False,
            l2=self.l2,
            forward_weights_init=self.embed_weight_init)

        subnet_unit.add_dropout_layer('target_embed',
                                      'target_embed0',
                                      dropout=self.embed_dropout,
                                      dropout_noise_shape={'*': None})

        # attention
        att = AttentionMechanism(
            enc_key_dim=self.enc_key_dim,
            att_num_heads=self.att_num_heads,
            att_dropout=self.att_dropout,
            l2=self.l2,
            loc_filter_size=self.loc_conv_att_filter_size,
            loc_num_channels=self.loc_conv_att_num_channels)
        subnet_unit.update(att.create())

        # LM-like component, same as in https://arxiv.org/pdf/2001.07263.pdf
        lstm_lm_component = None
        if self.add_lstm_lm:
            lstm_lm_component = subnet_unit.add_rnn_cell_layer(
                'lm_like_s',
                'prev:target_embed',
                n_out=self.lstm_lm_dim,
                l2=self.l2)

        lstm_inputs = []
        if lstm_lm_component:
            lstm_inputs += [lstm_lm_component]
        else:
            lstm_inputs += ['prev:target_embed']
        lstm_inputs += ['prev:att']

        # LSTM decoder (or decoder state)
        if self.dec_zoneout:
            subnet_unit.add_rnn_cell_layer('s',
                                           lstm_inputs,
                                           n_out=self.dec_lstm_num_units,
                                           l2=self.l2,
                                           weights_init=self.lstm_weights_init,
                                           unit='zoneoutlstm',
                                           unit_opts={
                                               'zoneout_factor_cell': 0.15,
                                               'zoneout_factor_output': 0.05
                                           })
        else:
            if self.rec_weight_dropout:
                # a rec layer with unit NativeLSTM2 is required to use rec_weight_dropout
                subnet_unit.add_rec_layer(
                    's',
                    lstm_inputs,
                    n_out=self.dec_lstm_num_units,
                    l2=self.l2,
                    unit='NativeLSTM2',
                    rec_weight_dropout=self.rec_weight_dropout,
                    weights_init=self.lstm_weights_init)
            else:
                subnet_unit.add_rnn_cell_layer(
                    's',
                    lstm_inputs,
                    n_out=self.dec_lstm_num_units,
                    l2=self.l2,
                    weights_init=self.lstm_weights_init)

        # ASR softmax output layer
        subnet_unit.add_linear_layer('readout_in',
                                     ["s", "prev:target_embed", "att"],
                                     n_out=self.dec_output_num_units,
                                     l2=self.l2)

        if self.reduceout:
            subnet_unit.add_reduceout_layer('readout', 'readout_in')
        else:
            subnet_unit.add_copy_layer('readout', 'readout_in')

        output_prob = subnet_unit.add_softmax_layer(
            'output_prob',
            'readout',
            l2=self.l2,
            loss='ce',
            loss_opts={'label_smoothing': self.label_smoothing},
            target=self.target,
            dropout=self.dropout)

        subnet_unit.add_choice_layer('output',
                                     output_prob,
                                     target=self.target,
                                     beam_size=self.beam_size,
                                     initial_output=0)

        # recurrent subnetwork
        dec_output = self.network.add_subnet_rec_layer(
            'output',
            unit=subnet_unit.get_net(),
            target=self.target,
            source=self.source)

        return dec_output
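
A small sketch of one common formulation of the label_smoothing=0.1 target used in the CE loss above (RETURNN's exact variant may differ in how it distributes the smoothed mass):

import numpy as np

def smoothed_target(true_idx, vocab_size, smoothing=0.1):
    t = np.full(vocab_size, smoothing / vocab_size)  # uniform mass over the vocabulary
    t[true_idx] += 1.0 - smoothing                   # remaining mass on the true label
    return t

assert abs(smoothed_target(3, 10).sum() - 1.0) < 1e-9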
Code example #22
0
    def __init__(self,
                 base_model,
                 target='bpe',
                 dec_layers=6,
                 beam_size=12,
                 ff_init=None,
                 ff_dim=2048,
                 ff_act='relu',
                 att_num_heads=8,
                 dropout=0.1,
                 att_dropout=0.0,
                 softmax_dropout=0.0,
                 embed_dropout=0.1,
                 l2=0.0,
                 embed_pos_enc=False,
                 apply_embed_weight=False,
                 label_smoothing=0.1,
                 mhsa_init=None,
                 mhsa_out_init=None,
                 pos_enc='rel',
                 rel_pos_clipping=16):

        self.base_model = base_model
        self.enc_value_dim = base_model.enc_value_dim
        self.enc_key_dim = base_model.enc_key_dim
        self.enc_att_num_heads = base_model.att_num_heads
        self.enc_key_per_head_dim = base_model.enc_key_per_head_dim
        self.enc_val_per_head_dim = base_model.enc_val_per_head_dim

        self.att_num_heads = att_num_heads

        self.target = target
        self.dec_layers = dec_layers
        self.beam_size = beam_size

        self.ff_init = ff_init
        self.ff_dim = ff_dim
        self.ff_act = ff_act

        self.mhsa_init = mhsa_init
        self.mhsa_out_init = mhsa_out_init

        self.pos_enc = pos_enc
        self.rel_pos_clipping = rel_pos_clipping

        self.dropout = dropout
        self.softmax_dropout = softmax_dropout
        self.att_dropout = att_dropout
        self.label_smoothing = label_smoothing

        self.l2 = l2

        self.embed_dropout = embed_dropout
        self.embed_pos_enc = embed_pos_enc

        self.embed_weight = None

        if apply_embed_weight:
            self.embed_weight = self.enc_value_dim**0.5

        self.decision_layer_name = None

        self.network = ReturnnNetwork()
        self.subnet_unit = ReturnnNetwork()
        self.output_prob = None
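
The apply_embed_weight flag corresponds to the sqrt(d_model) embedding scaling from the original Transformer; a one-line illustration with an assumed model dimension:

enc_value_dim = 512                  # illustrative model dimension
embed_weight = enc_value_dim ** 0.5  # ~22.6; multiplies the target embedding output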