Code example #1
File: rnn_decoder.py Project: rwth-i6/i6_experiments
    def _create_prior_net(self, subnet_unit: ReturnnNetwork, opts):
        prior_type = opts.get('type', 'zero')

        # fixed_ctx_vec_variants = ['zero', 'avg', 'train_avg_ctx', 'train_avg_enc', 'avg_zero', 'trained_vec']

        if prior_type == 'zero':  # set att context vector to zero
            prior_att_input = subnet_unit.add_eval_layer(
                'zero_att', 'att', eval='tf.zeros_like(source(0))')
        elif prior_type == 'avg':  # average encoder states per utterance (computed during search)
            self.base_model.network.add_reduce_layer(
                'encoder_mean', 'encoder', mode='mean', axes=['t'])  # [B, enc-dim]
            prior_att_input = 'base:encoder_mean'
        elif prior_type == 'train_avg_ctx':  # average all context vectors over training data
            prior_att_input = subnet_unit.add_constant_layer(
                'train_avg_ctx',
                value=opts['data'],
                with_batch_dim=True,
                dtype='float32')
        elif prior_type == 'train_avg_enc':  # average all encoder states over training data
            prior_att_input = subnet_unit.add_constant_layer(
                'train_avg_enc',
                value=opts['data'],
                with_batch_dim=True,
                dtype='float32')
        elif prior_type == 'mini_lstm':  # train a mini LM-like LSTM and use that as prior
            # example prefix_name: lstmdim_100-l2_0.0001-recwd_0.0
            # (note: a value containing '-', e.g. 5e-05, would break the split('-') below)
            n_out = 50
            l2 = 0.0
            recwd = 0.0
            if opts.get('prefix_name', None):
                segs = opts['prefix_name'].split('-')
                for arg in segs:
                    name, val = arg.split('_', 1)
                    if name == 'lstmdim':
                        n_out = int(val)
                    elif name == 'l2':
                        l2 = float(val)
                    elif name == 'recwd':
                        recwd = float(val)

            mini_lstm_inputs = opts.get('mini_lstm_inp',
                                        'prev:target_embed').split('+')
            if len(mini_lstm_inputs) == 1:
                mini_lstm_inputs = mini_lstm_inputs[0]

            subnet_unit.add_rec_layer('mini_att_lstm',
                                      mini_lstm_inputs,
                                      n_out=n_out,
                                      l2=l2,
                                      rec_weight_dropout=recwd)
            prior_att_input = subnet_unit.add_linear_layer('mini_att',
                                                           'mini_att_lstm',
                                                           activation=None,
                                                           n_out=2048,
                                                           l2=0.0001)
        elif prior_type == 'adaptive_ctx_vec':  # \hat{c}_i = FF(h_i)
            num_layers = opts.get('num_layers', 3)
            dim = opts.get('dim', 512)
            act = opts.get('act', 'relu')
            x = 's'
            for i in range(num_layers):
                x = subnet_unit.add_linear_layer('adaptive_att_%d' % i,
                                                 x,
                                                 n_out=dim,
                                                 **opts.get('att_opts', {}))
                x = subnet_unit.add_activation_layer('adaptive_att_%d_%s' %
                                                     (i, act),
                                                     x,
                                                     activation=act)
            prior_att_input = subnet_unit.add_linear_layer(
                'adaptive_att', x, n_out=2048, **opts.get('att_opts', {}))
        elif prior_type == 'trained_vec':
            prior_att_input = subnet_unit.add_variable_layer(
                'trained_vec_att_var', shape=[2048], L2=0.0001)
        elif prior_type == 'avg_zero':
            self.base_model.network.add_reduce_layer(
                'encoder_mean', 'encoder', mode='mean', axes=['t'])  # [B, enc-dim]
            subnet_unit.add_eval_layer('zero_att',
                                       'att',
                                       eval='tf.zeros_like(source(0))')
            return
        elif prior_type == 'density_ratio':
            assert 'lm_subnet' in opts and 'lm_model' in opts
            return self._add_density_ratio(subnet_unit,
                                           lm_subnet=opts['lm_subnet'],
                                           lm_model=opts['lm_model'])
        else:
            raise ValueError(
                '{} prior type is not supported'.format(prior_type))

        if prior_type != 'mini_lstm':
            is_first_frame = subnet_unit.add_compare_layer('is_first_frame',
                                                           source=':i',
                                                           kind='equal',
                                                           value=0)
            zero_att = subnet_unit.add_eval_layer(
                'zero_att', 'att', eval='tf.zeros_like(source(0))')
            prev_att = subnet_unit.add_switch_layer('prev_att',
                                                    condition=is_first_frame,
                                                    true_from=zero_att,
                                                    false_from=prior_att_input)
        else:
            prev_att = 'prev:' + prior_att_input

        assert prev_att is not None

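        # Clone the decoder layers listed below under a 'prior_' prefix, rewiring
        # their attention inputs to the prior context vector chosen above.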
        key_names = ['s', 'readout_in', 'readout', 'output_prob']
        for key_name in key_names:
            d = copy.deepcopy(subnet_unit[key_name])
            # update attention input
            new_sources = []
            from_list = d['from']
            if isinstance(from_list, str):
                from_list = [from_list]
            assert isinstance(from_list, list)
            for src in from_list:
                if 'att' in src:
                    if src.split(':')[0] == 'prev':
                        assert prev_att not in new_sources
                        new_sources += [prev_att]
                    else:
                        new_sources += [prior_att_input]
                elif src in key_names:
                    new_sources += [('prev:' if 'prev' in src else '') +
                                    'prior_{}'.format(src.split(':')[-1])]
                else:
                    new_sources += [src]
            d['from'] = new_sources
            subnet_unit['prior_{}'.format(key_name)] = d
        return 'prior_output_prob'
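
All of the prior variants above are selected purely through the opts dict. A minimal sketch of what such a dict could look like for the 'mini_lstm' variant, with hypothetical values chosen to match the opts.get(...) lookups and the prefix_name parsing in the method:

# hypothetical example; keys follow the opts.get(...) calls in _create_prior_net
prior_lm_opts = {
    'type': 'mini_lstm',
    # parsed by the loop above into n_out=100, l2=0.0001, recwd=0.0
    'prefix_name': 'lstmdim_100-l2_0.0001-recwd_0.0',
    # '+'-separated inputs for the mini LSTM (default: 'prev:target_embed')
    'mini_lstm_inp': 'prev:target_embed',
}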
Code example #2
File: rnn_decoder.py Project: rwth-i6/i6_experiments
    def add_decoder_subnetwork(self, subnet_unit: ReturnnNetwork):

        # target embedding
        if self.embed_dropout:
            # TODO: This is not a good approach: to load a checkpoint from a model trained
            # without embedding dropout, we would need to remap the variable name
            # target_embed to target_embed0 so that target_embed0/W can be loaded.
            subnet_unit.add_linear_layer('target_embed0',
                                         'output',
                                         n_out=self.embed_dim,
                                         initial_output=0,
                                         with_bias=False)
            subnet_unit.add_dropout_layer('target_embed',
                                          'target_embed0',
                                          dropout=self.embed_dropout,
                                          dropout_noise_shape={'*': None})
        else:
            subnet_unit.add_linear_layer('target_embed',
                                         'output',
                                         n_out=self.embed_dim,
                                         initial_output=0,
                                         with_bias=False)

        subnet_unit.add_compare_layer('end', source='output',
                                      value=0)  # sentence end token

        # ------ attention location-awareness ------ #

        # conv-based
        if self.loc_conv_att_filter_size:
            assert self.loc_conv_att_num_channels
            pad_left = subnet_unit.add_pad_layer(
                'feedback_pad_left',
                'prev:att_weights',
                axes='s:0',
                padding=((self.loc_conv_att_filter_size - 1) // 2, 0),
                value=0)
            pad_right = subnet_unit.add_pad_layer(
                'feedback_pad_right',
                pad_left,
                axes='s:0',
                padding=(0, (self.loc_conv_att_filter_size - 1) // 2),
                value=0)
            loc_att_conv = subnet_unit.add_conv_layer(
                'loc_att_conv',
                pad_right,
                activation=None,
                with_bias=False,
                filter_size=(self.loc_conv_att_filter_size, ),
                padding='valid',
                n_out=self.loc_conv_att_num_channels,
                l2=self.l2)
            subnet_unit.add_linear_layer('weight_feedback',
                                         loc_att_conv,
                                         activation=None,
                                         with_bias=False,
                                         n_out=self.enc_key_dim)
        else:
            # additive
            subnet_unit.add_eval_layer(
                'accum_att_weights', [
                    "prev:accum_att_weights", "att_weights",
                    "base:inv_fertility"
                ],
                eval='source(0) + source(1) * source(2) * 0.5',
                out_type={
                    "dim": self.att_num_heads,
                    "shape": (None, self.att_num_heads)
                })
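            # i.e. accum_i = accum_{i-1} + 0.5 * att_weights_i * inv_fertility,
            # accumulating fertility-scaled attention weights for location feedback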
            subnet_unit.add_linear_layer('weight_feedback',
                                         'prev:accum_att_weights',
                                         n_out=self.enc_key_dim,
                                         with_bias=False)

        subnet_unit.add_linear_layer('s_transformed',
                                     's',
                                     n_out=self.enc_key_dim,
                                     with_bias=False)
        subnet_unit.add_combine_layer(
            'energy_in', ['base:enc_ctx', 'weight_feedback', 's_transformed'],
            kind='add',
            n_out=self.enc_key_dim)
        subnet_unit.add_activation_layer('energy_tanh',
                                         'energy_in',
                                         activation='tanh')
        subnet_unit.add_linear_layer('energy',
                                     'energy_tanh',
                                     n_out=self.att_num_heads,
                                     with_bias=False)

        if self.att_dropout:
            subnet_unit.add_softmax_over_spatial_layer('att_weights0',
                                                       'energy')
            subnet_unit.add_dropout_layer('att_weights',
                                          'att_weights0',
                                          dropout=self.att_dropout,
                                          dropout_noise_shape={'*': None})
        else:
            if self.relax_att_scale:
                subnet_unit.add_softmax_over_spatial_layer(
                    'att_weights0', 'energy')
                subnet_unit.add_length_layer('encoder_len',
                                             'base:encoder',
                                             dtype='float32')  # [B]
                subnet_unit.add_eval_layer('scaled_encoder_len',
                                           source=['encoder_len'],
                                           eval='{} / source(0)'.format(
                                               self.relax_att_scale))
                subnet_unit.add_eval_layer(
                    'att_weights',
                    source=['att_weights0', 'scaled_encoder_len'],
                    eval='{} * source(0) + source(1)'.format(
                        1 - self.relax_att_scale))
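                # i.e. att_weights = (1 - c) * softmax(energy) + c / T
                # with c = relax_att_scale and T the encoder length:
                # a uniform smoothing of the attention distribution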
            else:
                subnet_unit.add_softmax_over_spatial_layer(
                    'att_weights', 'energy')

        subnet_unit.add_generic_att_layer('att0',
                                          weights='att_weights',
                                          base='base:enc_value')
        subnet_unit.add_merge_dims_layer('att', 'att0', axes='except_batch')

        # LM-like component, as in https://arxiv.org/pdf/2001.07263.pdf
        lstm_lm_component = None
        if self.add_lstm_lm:
            lstm_lm_component = subnet_unit.add_rnn_cell_layer(
                'lm_like_s',
                'prev:target_embed',
                n_out=self.lstm_lm_dim,
                l2=self.l2)

        lstm_inputs = []
        if lstm_lm_component:
            lstm_inputs += [lstm_lm_component]
        else:
            lstm_inputs += ['prev:target_embed']
        lstm_inputs += ['prev:att']

        if self.dec_state_no_label_ctx:
            lstm_inputs = ['prev:att']  # no label feedback

        # LSTM decoder
        if self.dec_zoneout:
            subnet_unit.add_rnn_cell_layer('s',
                                           lstm_inputs,
                                           n_out=self.dec_lstm_num_units,
                                           unit='zoneoutlstm',
                                           unit_opts={
                                               'zoneout_factor_cell': 0.15,
                                               'zoneout_factor_output': 0.05
                                           })
        else:
            if self.rec_weight_dropout:
                # a rec layer with unit NativeLSTM2 is required to use rec_weight_dropout
                subnet_unit.add_rec_layer(
                    's',
                    lstm_inputs,
                    n_out=self.dec_lstm_num_units,
                    l2=self.l2,
                    rec_weight_dropout=self.rec_weight_dropout,
                    unit='NativeLSTM2')
            else:
                subnet_unit.add_rnn_cell_layer('s',
                                               lstm_inputs,
                                               n_out=self.dec_lstm_num_units,
                                               l2=self.l2)

        # AM softmax output layer
        if self.dec_state_no_label_ctx and self.add_lstm_lm:
            subnet_unit.add_linear_layer(
                'readout_in', ["lm_like_s", "prev:target_embed", "att"],
                n_out=self.dec_output_num_units)
            if self.add_no_label_ctx_s_to_output:
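                # note: this redefines 'readout_in' and overwrites the layer added just above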
                subnet_unit.add_linear_layer(
                    'readout_in',
                    ["lm_like_s", "s", "prev:target_embed", "att"],
                    n_out=self.dec_output_num_units)
        else:
            subnet_unit.add_linear_layer('readout_in',
                                         ["s", "prev:target_embed", "att"],
                                         n_out=self.dec_output_num_units)

        if self.reduceout:
            subnet_unit.add_reduceout_layer('readout', 'readout_in')
        else:
            subnet_unit.add_copy_layer('readout', 'readout_in')

        if self.local_fusion_opts:
            output_prob = subnet_unit.add_softmax_layer('output_prob',
                                                        'readout',
                                                        l2=self.l2,
                                                        target=self.target,
                                                        dropout=self.dropout)
            self._add_local_fusion(subnet_unit, am_output_prob=output_prob)
        elif self.mwer:
            # MWER training only, so the CE loss is disabled here
            output_prob = subnet_unit.add_softmax_layer('output_prob',
                                                        'readout',
                                                        l2=self.l2,
                                                        target=self.target,
                                                        dropout=self.dropout)
        else:
            ce_loss_opts = {'label_smoothing': self.label_smoothing}
            if self.ce_loss_scale:
                ce_loss_opts['scale'] = self.ce_loss_scale
            output_prob = subnet_unit.add_softmax_layer('output_prob',
                                                        'readout',
                                                        l2=self.l2,
                                                        loss='ce',
                                                        loss_opts=ce_loss_opts,
                                                        target=self.target,
                                                        dropout=self.dropout)

        # do not load the bias
        if self.remove_softmax_bias:
            subnet_unit['output_prob']['with_bias'] = False

        # for prior LM estimation
        prior_output_prob = None
        if self.prior_lm_opts:
            prior_output_prob = self._create_prior_net(
                subnet_unit, self.prior_lm_opts
            )  # this requires preload_from_files in the config

        # Beam search
        # only shallow fusion is supported for now
        if self.ext_lm_opts:
            self._add_external_LM(subnet_unit, output_prob, prior_output_prob)
        else:
            if self.coverage_term_scale:
                output_prob = subnet_unit.add_eval_layer(
                    'combo_output_prob',
                    eval='safe_log(source(0)) + {} * source(1)'.format(
                        self.coverage_term_scale),
                    source=['output_prob', 'accum_coverage'])
                input_type = 'log_prob'
            else:
                output_prob = 'output_prob'
                input_type = None

            if self.length_norm:
                subnet_unit.add_choice_layer('output',
                                             output_prob,
                                             target=self.target,
                                             beam_size=self.beam_size,
                                             initial_output=0,
                                             input_type=input_type)
            else:
                subnet_unit.add_choice_layer(
                    'output',
                    output_prob,
                    target=self.target,
                    beam_size=self.beam_size,
                    initial_output=0,
                    length_normalization=self.length_norm,
                    input_type=input_type)

        if self.ilmt_opts:
            self._create_ilmt_net(subnet_unit)

        # recurrent subnetwork
        dec_output = self.network.add_subnet_rec_layer(
            'output',
            unit=subnet_unit.get_net(),
            target=self.target,
            source=self.source)

        return dec_output
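
In the coverage-augmented search path above, 'combo_output_prob' combines the AM score and the accumulated coverage in log space. A standalone numpy rendering of that scoring, purely illustrative (safe_log below mimics RETURNN's clipped log; the real combination happens inside the eval layer):

import numpy as np

def combo_output_prob(output_prob, accum_coverage, coverage_term_scale):
    # safe_log(source(0)) + scale * source(1), as in the eval layer above
    safe_log = np.log(np.maximum(output_prob, 1e-20))  # avoid log(0)
    return safe_log + coverage_term_scale * accum_coverage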
Code example #3
File: rnn_decoder.py Project: rwth-i6/i6_experiments
    def add_decoder_subnetwork(self, subnet_unit: ReturnnNetwork):

        subnet_unit.add_compare_layer('end', source='output',
                                      value=0)  # sentence end token

        # target embedding
        subnet_unit.add_linear_layer(
            'target_embed0',
            'output',
            n_out=self.embed_dim,
            initial_output=0,
            with_bias=False,
            l2=self.l2,
            forward_weights_init=self.embed_weight_init)

        subnet_unit.add_dropout_layer('target_embed',
                                      'target_embed0',
                                      dropout=self.embed_dropout,
                                      dropout_noise_shape={'*': None})

        # attention
        att = AttentionMechanism(
            enc_key_dim=self.enc_key_dim,
            att_num_heads=self.att_num_heads,
            att_dropout=self.att_dropout,
            l2=self.l2,
            loc_filter_size=self.loc_conv_att_filter_size,
            loc_num_channels=self.loc_conv_att_num_channels)
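        # create() returns the attention layer dict (producing the 'att' context
        # vector used below); merge it into the recurrent subnetwork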
        subnet_unit.update(att.create())

        # LM-like component, as in https://arxiv.org/pdf/2001.07263.pdf
        lstm_lm_component = None
        if self.add_lstm_lm:
            lstm_lm_component = subnet_unit.add_rnn_cell_layer(
                'lm_like_s',
                'prev:target_embed',
                n_out=self.lstm_lm_dim,
                l2=self.l2)

        lstm_inputs = []
        if lstm_lm_component:
            lstm_inputs += [lstm_lm_component]
        else:
            lstm_inputs += ['prev:target_embed']
        lstm_inputs += ['prev:att']

        # LSTM decoder (or decoder state)
        if self.dec_zoneout:
            subnet_unit.add_rnn_cell_layer('s',
                                           lstm_inputs,
                                           n_out=self.dec_lstm_num_units,
                                           l2=self.l2,
                                           weights_init=self.lstm_weights_init,
                                           unit='zoneoutlstm',
                                           unit_opts={
                                               'zoneout_factor_cell': 0.15,
                                               'zoneout_factor_output': 0.05
                                           })
        else:
            if self.rec_weight_dropout:
                # a rec layer with unit NativeLSTM2 is required to use rec_weight_dropout
                subnet_unit.add_rec_layer(
                    's',
                    lstm_inputs,
                    n_out=self.dec_lstm_num_units,
                    l2=self.l2,
                    unit='NativeLSTM2',
                    rec_weight_dropout=self.rec_weight_dropout,
                    weights_init=self.lstm_weights_init)
            else:
                subnet_unit.add_rnn_cell_layer(
                    's',
                    lstm_inputs,
                    n_out=self.dec_lstm_num_units,
                    l2=self.l2,
                    weights_init=self.lstm_weights_init)

        # ASR softmax output layer
        subnet_unit.add_linear_layer('readout_in',
                                     ["s", "prev:target_embed", "att"],
                                     n_out=self.dec_output_num_units,
                                     l2=self.l2)

        if self.reduceout:
            subnet_unit.add_reduceout_layer('readout', 'readout_in')
        else:
            subnet_unit.add_copy_layer('readout', 'readout_in')

        output_prob = subnet_unit.add_softmax_layer(
            'output_prob',
            'readout',
            l2=self.l2,
            loss='ce',
            loss_opts={'label_smoothing': self.label_smoothing},
            target=self.target,
            dropout=self.dropout)

        subnet_unit.add_choice_layer('output',
                                     output_prob,
                                     target=self.target,
                                     beam_size=self.beam_size,
                                     initial_output=0)

        # recurrent subnetwork
        dec_output = self.network.add_subnet_rec_layer(
            'output',
            unit=subnet_unit.get_net(),
            target=self.target,
            source=self.source)

        return dec_output
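
Both decoder variants implement the same additive (MLP) attention: a tanh over the sum of encoder keys, location feedback, and the transformed decoder state, projected to per-head energies and normalized over the encoder time axis. A minimal single-head numpy sketch of that scoring; the function and argument names are illustrative, not part of the repo:

import numpy as np

def additive_attention(enc_ctx, weight_feedback, s_transformed, energy_proj):
    # enc_ctx, weight_feedback: [T, enc_key_dim]; s_transformed, energy_proj: [enc_key_dim]
    energy_in = enc_ctx + weight_feedback + s_transformed  # 'energy_in' (kind='add')
    energy = np.tanh(energy_in) @ energy_proj              # 'energy_tanh' -> 'energy', [T]
    e = np.exp(energy - energy.max())
    return e / e.sum()                                     # softmax over encoder time ('att_weights')

# the context vector is then the weighted sum over the encoder values ('att0'/'att'):
# att = att_weights @ enc_value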