  def _create_prior_net(self, subnet_unit: ReturnnNetwork):
    prior_att_input = self._add_prior_input(subnet_unit)

    # for the first frame in decoding, always use zeros instead of the prior attention input
    is_first_frame = subnet_unit.add_compare_layer('is_first_frame', source=':i', kind='equal', value=0)
    zero_att = subnet_unit.add_eval_layer('zero_att', 'att', eval='tf.zeros_like(source(0))')
    prev_att = subnet_unit.add_switch_layer(
      'prev_att', condition=is_first_frame, true_from=zero_att, false_from=prior_att_input)

    key_names = ['s', 'readout_in', 'readout', 'output_prob']
    for key_name in key_names:
      d = copy.deepcopy(subnet_unit[key_name])
      # update attention input
      new_sources = []
      from_list = d['from']
      if isinstance(from_list, str):
        from_list = [from_list]
      assert isinstance(from_list, list)
      for src in from_list:
        if 'att' in src:
          if src.split(':')[0] == 'prev':
            assert prev_att not in new_sources
            new_sources += [prev_att]  # switched based on decoder index
          else:
            new_sources += [prior_att_input]
        elif src in key_names:
          new_sources += [('prev:' if 'prev' in src else '') + 'prior_{}'.format(src.split(':')[-1])]
        else:
          new_sources += [src]
      d['from'] = new_sources
      subnet_unit['prior_{}'.format(key_name)] = d
    return 'prior_output_prob'
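# Note on _create_prior_net above: it clones the decoder layers 's', 'readout_in',
# 'readout' and 'output_prob' under a 'prior_' prefix and only swaps their attention
# inputs, so the internal LM shares all decoder parameters. A minimal sketch of the
# rewiring, with an illustrative layer dict that is not taken from the code above:
#
#   subnet_unit['readout_in'] = {
#       'class': 'linear', 'from': ['s', 'prev:target_embed', 'att'], ...}
#   # becomes
#   subnet_unit['prior_readout_in'] = {
#       'class': 'linear', 'from': ['prior_s', 'prev:target_embed', <prior_att_input>], ...}
#
# 'prev:att' sources are replaced by the 'prev_att' switch layer (zeros on the first
# frame, otherwise the prior attention input), and references to the other cloned
# layers are re-pointed to their 'prior_' counterparts.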
Example #6
    def create(self):
        out_net = ReturnnNetwork()

        pad_left = out_net.add_pad_layer('feedback_pad_left',
                                         'prev:att_weights',
                                         axes='s:0',
                                         padding=((self.filter_size - 1) // 2,
                                                  0),
                                         value=0)

        pad_right = out_net.add_pad_layer('feedback_pad_right',
                                          pad_left,
                                          axes='s:0',
                                          padding=(0, (self.filter_size - 1) //
                                                   2),
                                          value=0)

        loc_att_conv = out_net.add_conv_layer('loc_att_conv',
                                              pad_right,
                                              activation=None,
                                              with_bias=False,
                                              filter_size=(self.filter_size, ),
                                              padding='valid',
                                              n_out=self.num_channels,
                                              l2=self.l2)

        self.name = out_net.add_linear_layer('weight_feedback',
                                             loc_att_conv,
                                             activation=None,
                                             with_bias=False,
                                             n_out=self.enc_key_dim)

        return out_net.get_net()
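# A minimal usage sketch, assuming the create() above belongs to the ConvLocAwareness
# class that is instantiated further below; the argument values are illustrative,
# not taken from the code above.
weight_feedback = ConvLocAwareness(enc_key_dim=1024, filter_size=5, num_channels=32, l2=1e-4)
att_net = ReturnnNetwork()
att_net.update(weight_feedback.create())  # adds 'feedback_pad_*', 'loc_att_conv' and 'weight_feedback'
# weight_feedback.name now holds the layer name to feed into the attention energy computation.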
  def _add_prior_input(self, subnet_unit: ReturnnNetwork):
    prior_type = self.prior_lm_opts.get('type', None)
    assert prior_type is not None, 'prior_type not defined'

    if prior_type == 'mini_lstm':
      # add mini lstm layers
      subnet_unit.add_rec_layer(
        'mini_att_lstm', 'prev:' + self.prior_lm_opts.get('target_embed_name', 'target_embed'),
        n_out=self.prior_lm_opts.get('mini_lstm_dim', 50), l2=self.prior_lm_opts.get('l2', 0.0))
      prior_att_input = subnet_unit.add_linear_layer(
        'mini_att', 'mini_att_lstm', activation=None, n_out=512, l2=0.0001)
    elif prior_type == 'zero':
      prior_att_input = subnet_unit.add_eval_layer(
        'zero_att', 'transformer_decoder_01_att', eval='tf.zeros_like(source(0))')
    else:
      raise ValueError('unknown prior type: {!r}'.format(prior_type))

    return prior_att_input
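# In _add_prior_input above, 'mini_lstm' approximates the attention context with a
# small LSTM fed by the previous target embedding (its 'mini_att' projection stands in
# for the encoder attention), while 'zero' simply replaces the attention context with
# zeros; both are common ways to estimate the internal LM of an attention decoder.
# Note that the 'zero' branch hard-codes the 'transformer_decoder_01_att' layer name.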
Example #8
    def create(self):
        out_net = ReturnnNetwork()

        out_net.add_eval_layer(
            'accum_att_weights',
            ["prev:accum_att_weights", "att_weights", "base:inv_fertility"],
            eval='source(0) + source(1) * source(2) * 0.5',
            out_type={
                "dim": self.att_num_heads,
                "shape": (None, self.att_num_heads)
            })

        self.name = out_net.add_linear_layer('weight_feedback',
                                             'prev:accum_att_weights',
                                             n_out=self.enc_key_dim,
                                             with_bias=False)

        return out_net.get_net()
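# A minimal usage sketch, assuming the create() above belongs to the AdditiveLocAwareness
# class that is instantiated further below; the argument values are illustrative.
# 'accum_att_weights' accumulates 0.5 * att_weights * inv_fertility per decoder step,
# and the previous step's accumulated weights are projected to the key dimension as feedback.
weight_feedback = AdditiveLocAwareness(enc_key_dim=1024, att_num_heads=1)
att_net = ReturnnNetwork()
att_net.update(weight_feedback.create())  # adds 'accum_att_weights' and 'weight_feedback'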
  def _add_prior_input(self, subnet_unit: ReturnnNetwork):
    prior_type = self.prior_lm_opts['type']
    assert prior_type == 'mini_lstm'

    num_layers = self.prior_lm_opts['dec_layers']
    assert num_layers > 0

    variant = self.prior_lm_opts['mini_lstm_variant']
    assert variant in ['single', 'many']

    if variant == 'single':
      subnet_unit.add_rec_layer(
        'mini_att_lstm', 'prev:' + self.prior_lm_opts.get('target_embed_name', 'target_embed'),
        n_out=self.prior_lm_opts.get('mini_lstm_dim', 50), l2=self.prior_lm_opts.get('l2', 0.0))
    else:
      for i in range(1, num_layers + 1):
        subnet_unit.add_rec_layer(
          'mini_att_lstm_%02i' % i, 'prev:' + self.prior_lm_opts.get('target_embed_name', 'target_embed'),
          n_out=self.prior_lm_opts.get('mini_lstm_dim', 50), l2=self.prior_lm_opts.get('l2', 0.0))

    for i in range(1, num_layers + 1):
      subnet_unit.add_linear_layer(
        'mini_att_%02i' % i, 'mini_att_lstm_%02i' % i if variant == 'many' else 'mini_att_lstm', activation=None,
        n_out=512, l2=0.0001)
class TransformerLM:
    def __init__(self,
                 source='data:delayed',
                 target='data',
                 num_layers=6,
                 ff_dim=4096,
                 att_num_heads=8,
                 out_dim=1024,
                 qk_dim=1024,
                 v_dim=1024,
                 dropout=0.0,
                 att_dropout=0.0,
                 embed_dropout=0.0,
                 embed_dim=128,
                 emb_cpu_lookup=True,
                 forward_weights_init=None,
                 prefix_name=None,
                 use_as_ext_lm=False,
                 vocab_size=None):

        self.source = source
        self.target = target
        self.num_layers = num_layers

        self.ff_dim = ff_dim
        self.att_num_heads = att_num_heads
        self.out_dim = out_dim
        self.qk_dim = qk_dim
        self.v_dim = v_dim
        self.dropout = dropout
        self.embed_dropout = embed_dropout
        self.att_dropout = att_dropout
        self.embed_dim = embed_dim
        self.emb_cpu_lookup = emb_cpu_lookup

        # use this as default for now
        if forward_weights_init is None:
            forward_weights_init = "variance_scaling_initializer(mode='fan_in', distribution='uniform', scale=1.0)"
        self.forward_weights_init = forward_weights_init

        self.use_as_ext_lm = use_as_ext_lm
        self.vocab_size = vocab_size
        if not prefix_name:
            prefix_name = ''
        self.prefix_name = prefix_name

        self.network = ReturnnNetwork()

    def _create_ff_block(self, subnet_unit: ReturnnNetwork, source, prefix):
        prefix = '{}_ff'.format(prefix)
        ln = subnet_unit.add_layer_norm_layer('{}_laynorm'.format(prefix),
                                              source)
        conv1 = subnet_unit.add_linear_layer(
            '{}_conv1'.format(prefix),
            ln,
            with_bias=True,
            activation='relu',
            forward_weights_init=self.forward_weights_init,
            n_out=self.ff_dim)
        conv2 = subnet_unit.add_linear_layer(
            '{}_conv2'.format(prefix),
            conv1,
            with_bias=True,
            activation=None,
            forward_weights_init=self.forward_weights_init,
            n_out=self.out_dim)
        drop = subnet_unit.add_dropout_layer('{}_drop'.format(prefix),
                                             conv2,
                                             dropout=self.dropout)
        out = subnet_unit.add_combine_layer('{}_out'.format(prefix),
                                            [drop, source],
                                            kind='add',
                                            n_out=self.out_dim)
        return out

    def _create_masked_mhsa(self, subnet_unit: ReturnnNetwork, source, prefix):
        prefix = '{}_self_att'.format(prefix)
        ln = subnet_unit.add_layer_norm_layer('{}_laynorm'.format(prefix),
                                              source)
        att = subnet_unit.add_self_att_layer(
            '{}_att'.format(prefix),
            ln,
            forward_weights_init=self.forward_weights_init,
            att_dropout=self.att_dropout,
            attention_left_only=True,
            n_out=self.v_dim,
            num_heads=self.att_num_heads,
            total_key_dim=self.qk_dim)
        lin = subnet_unit.add_linear_layer(
            '{}_lin'.format(prefix),
            att,
            n_out=self.out_dim,
            with_bias=False,
            forward_weights_init=self.forward_weights_init)
        drop = subnet_unit.add_dropout_layer('{}_drop'.format(prefix),
                                             lin,
                                             dropout=self.dropout)
        out = subnet_unit.add_combine_layer('{}_out'.format(prefix),
                                            [drop, source],
                                            kind='add',
                                            n_out=self.out_dim)
        return out

    def _create_decoder_block(self, subnet_unit: ReturnnNetwork, source, i):
        prefix = self.prefix_name + ('dec_%i' % i)
        masked_mhsa = self._create_masked_mhsa(subnet_unit, source, prefix)
        ff = self._create_ff_block(subnet_unit, masked_mhsa, prefix)
        out = subnet_unit.add_copy_layer(prefix, ff)
        return out

    def create_network(self):
        subnet_unit = ReturnnNetwork()
        target_embed_raw = subnet_unit.add_linear_layer(
            '{}target_embed_raw'.format(self.prefix_name),
            self.source,
            forward_weights_init=self.forward_weights_init,
            n_out=self.embed_dim,
            with_bias=False,
            param_device='CPU' if self.emb_cpu_lookup else None)

        target_embed_with_pos = subnet_unit.add_pos_encoding_layer(
            '{}target_embed_with_pos'.format(self.prefix_name),
            target_embed_raw)

        target_embed = subnet_unit.add_dropout_layer(
            '{}target_embed'.format(self.prefix_name),
            target_embed_with_pos,
            dropout=self.embed_dropout)

        target_embed_lin = subnet_unit.add_linear_layer(
            '{}target_embed_lin'.format(self.prefix_name),
            target_embed,
            with_bias=False,
            forward_weights_init=self.forward_weights_init,
            n_out=self.out_dim)

        x = target_embed_lin
        for i in range(self.num_layers):
            x = self._create_decoder_block(subnet_unit, x, i)

        # final LN
        decoder = subnet_unit.add_layer_norm_layer(
            '{}decoder'.format(self.prefix_name), x)

        subnet_unit.add_softmax_layer(
            '{}output'.format(self.prefix_name),
            decoder,
            forward_weights_init=self.forward_weights_init,
            loss='ce',
            target=self.target,
            with_bias=True,
            dropout=self.dropout)

        if self.use_as_ext_lm:
            self.network = copy.deepcopy(subnet_unit)
        else:
            self.network.add_subnet_rec_layer('output',
                                              unit=subnet_unit.get_net(),
                                              target=self.target,
                                              source=self.source)

        return 'output'
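# A minimal usage sketch for the TransformerLM class above (values are illustrative;
# ReturnnNetwork and copy are assumed to be imported at module level):
lm = TransformerLM(source='data:delayed', target='data', num_layers=6,
                   ff_dim=4096, att_num_heads=8, out_dim=1024, qk_dim=1024, v_dim=1024,
                   embed_dim=128, dropout=0.1, use_as_ext_lm=False)
lm.create_network()                 # builds the 'output' rec layer over the target sequence
lm_net_dict = lm.network.get_net()  # RETURNN network dict to put into the config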
Example #12
    def create(self):
        out_net = ReturnnNetwork()

        out_net.add_linear_layer('s_transformed',
                                 's',
                                 n_out=self.enc_key_dim,
                                 with_bias=False,
                                 l2=self.l2)  # project query

        if self.loc_num_channels is not None:
            assert self.loc_filter_size is not None
            weight_feedback = ConvLocAwareness(
                enc_key_dim=self.enc_key_dim,
                filter_size=self.loc_filter_size,
                num_channels=self.loc_num_channels,
                l2=self.l2)
        else:
            # additive
            weight_feedback = AdditiveLocAwareness(
                enc_key_dim=self.enc_key_dim, att_num_heads=self.att_num_heads)

        out_net.update(weight_feedback.create())  # add att weight feedback

        out_net.add_combine_layer(
            'energy_in',
            ['base:enc_ctx', weight_feedback.name, 's_transformed'],
            kind='add',
            n_out=self.enc_key_dim)

        # compute energies
        out_net.add_activation_layer('energy_tanh',
                                     'energy_in',
                                     activation='tanh')
        energy = out_net.add_linear_layer('energy',
                                          'energy_tanh',
                                          n_out=self.att_num_heads,
                                          with_bias=False,
                                          l2=self.l2)

        if self.att_dropout:
            att_weights0 = out_net.add_softmax_over_spatial_layer(
                'att_weights0', energy)
            att_weights = out_net.add_dropout_layer(
                'att_weights',
                att_weights0,
                dropout=self.att_dropout,
                dropout_noise_shape={'*': None})
        else:
            att_weights = out_net.add_softmax_over_spatial_layer(
                'att_weights', energy)

        att0 = out_net.add_generic_att_layer('att0',
                                             weights=att_weights,
                                             base='base:enc_value')
        self.name = out_net.add_merge_dims_layer('att',
                                                 att0,
                                                 axes='except_batch')

        return out_net.get_net()
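# For reference, a minimal NumPy sketch of the additive (MLP) attention that the layers
# above assemble, assuming a single head, batch size 1 and no att_dropout; all names are
# local to this sketch, the real computation runs as RETURNN layers.
import numpy as np

def additive_attention(enc_ctx, enc_value, s_transformed, weight_feedback, v_energy):
    # enc_ctx, weight_feedback: (T, enc_key_dim); s_transformed, v_energy: (enc_key_dim,)
    # enc_value: (T, value_dim)
    energy_in = enc_ctx + weight_feedback + s_transformed   # 'energy_in': combine add
    energy = np.tanh(energy_in) @ v_energy                  # 'energy_tanh' -> 'energy'
    att_weights = np.exp(energy - energy.max())
    att_weights /= att_weights.sum()                         # 'att_weights': softmax over time
    att = att_weights @ enc_value                            # 'att0'/'att': weighted sum of values
    return att, att_weights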
  def _create_external_lm_net(self) -> dict:
    lm_net_out = ReturnnNetwork()

    ext_lm_subnet = self.ext_lm_opts['lm_subnet']
    ext_lm_scale = self.ext_lm_opts['lm_scale']

    assert isinstance(ext_lm_subnet, dict)
    is_recurrent = self.ext_lm_opts.get('is_recurrent', False)
    if is_recurrent:
      lm_output_prob = self.ext_lm_opts['lm_output_prob_name']
      ext_lm_subnet[lm_output_prob]['target'] = self.target
      lm_net_out.update(ext_lm_subnet)  # recurrent LM: merge its layers directly into this net
    else:
      ext_lm_model = self.ext_lm_opts['lm_model']
      lm_net_out.add_subnetwork(
        'lm_output', 'prev:output', subnetwork_net=ext_lm_subnet, load_on_init=ext_lm_model)
      lm_output_prob = lm_net_out.add_activation_layer(
        'lm_output_prob', 'lm_output', activation='softmax', target=self.target)

    fusion_str = 'safe_log(source(0)) + {} * safe_log(source(1))'.format(ext_lm_scale)  # shallow fusion
    fusion_source = [self.am_output_prob, lm_output_prob]

    if self.prior_lm_opts:

      if self.dec_type == 'lstm':
        ilm_decoder = LSTMILMDecoder(self.asr_decoder, self.prior_lm_opts)
      elif self.dec_type == 'transformer':
        ilm_decoder = TransformerMiniLSTMDecoder(self.asr_decoder, self.prior_lm_opts)
      else:
        raise ValueError('dec type: {} is not valid'.format(self.dec_type))

      ilm_decoder.create_network()  # add ILM
      fusion_str += ' - {} * safe_log(source(2))'.format(self.prior_lm_opts['scale'])
      fusion_source += [ilm_decoder.output_prob_name]

    lm_net_out.add_eval_layer('combo_output_prob', source=fusion_source, eval=fusion_str)
    lm_net_out.add_choice_layer(
      'output', 'combo_output_prob', target=self.target, beam_size=self.beam_size, initial_output=0,
      input_type='log_prob')

    return lm_net_out.get_net()
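# The 'combo_output_prob' eval above implements log-linear shallow fusion with an
# optional internal-LM (prior) correction. A minimal NumPy sketch of the combined
# score (scale values are illustrative; safe_log is approximated by np.log here):
import numpy as np

def combined_log_prob(am_prob, lm_prob, ilm_prob=None, lm_scale=0.3, prior_scale=0.2):
    score = np.log(am_prob) + lm_scale * np.log(lm_prob)   # shallow fusion
    if ilm_prob is not None:
        score -= prior_scale * np.log(ilm_prob)             # subtract internal LM estimate
    return score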