Example #1
    def add_decoder_subnetwork(self, subnet_unit: ReturnnNetwork):

        # target embedding
        if self.embed_dropout:
            # TODO: this is not a good approach. To load a checkpoint from a model trained
            # without embedding dropout, the variable name target_embed would have to be
            # remapped to target_embed0 in order to load target_embed0/W.
            subnet_unit.add_linear_layer('target_embed0',
                                         'output',
                                         n_out=self.embed_dim,
                                         initial_output=0,
                                         with_bias=False)
            subnet_unit.add_dropout_layer('target_embed',
                                          'target_embed0',
                                          dropout=self.embed_dropout,
                                          dropout_noise_shape={'*': None})
        else:
            subnet_unit.add_linear_layer('target_embed',
                                         'output',
                                         n_out=self.embed_dim,
                                         initial_output=0,
                                         with_bias=False)

        subnet_unit.add_compare_layer('end', source='output',
                                      value=0)  # sentence end token

        # ------ attention location-awareness ------ #

        # conv-based
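        # location-aware feedback: pad the previous attention weights on both sides,
        # run a 1D convolution over the time axis and project to the attention key dim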
        if self.loc_conv_att_filter_size:
            assert self.loc_conv_att_num_channels
            pad_left = subnet_unit.add_pad_layer(
                'feedback_pad_left',
                'prev:att_weights',
                axes='s:0',
                padding=((self.loc_conv_att_filter_size - 1) // 2, 0),
                value=0)
            pad_right = subnet_unit.add_pad_layer(
                'feedback_pad_right',
                pad_left,
                axes='s:0',
                padding=(0, (self.loc_conv_att_filter_size - 1) // 2),
                value=0)
            loc_att_conv = subnet_unit.add_conv_layer(
                'loc_att_conv',
                pad_right,
                activation=None,
                with_bias=False,
                filter_size=(self.loc_conv_att_filter_size, ),
                padding='valid',
                n_out=self.loc_conv_att_num_channels,
                l2=self.l2)
            subnet_unit.add_linear_layer('weight_feedback',
                                         loc_att_conv,
                                         activation=None,
                                         with_bias=False,
                                         n_out=self.enc_key_dim)
        else:
            # additive
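            # accumulate attention weights over decoder steps, scaled by the encoder-side
            # inverse fertility: accum = prev_accum + 0.5 * att_weights * inv_fertility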
            subnet_unit.add_eval_layer(
                'accum_att_weights', [
                    "prev:accum_att_weights", "att_weights",
                    "base:inv_fertility"
                ],
                eval='source(0) + source(1) * source(2) * 0.5',
                out_type={
                    "dim": self.att_num_heads,
                    "shape": (None, self.att_num_heads)
                })
            subnet_unit.add_linear_layer('weight_feedback',
                                         'prev:accum_att_weights',
                                         n_out=self.enc_key_dim,
                                         with_bias=False)

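        # additive (MLP) attention energies: project the decoder state s, add the encoder
        # context and the location feedback, apply tanh and map to one energy per head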
        subnet_unit.add_linear_layer('s_transformed',
                                     's',
                                     n_out=self.enc_key_dim,
                                     with_bias=False)
        subnet_unit.add_combine_layer(
            'energy_in', ['base:enc_ctx', 'weight_feedback', 's_transformed'],
            kind='add',
            n_out=self.enc_key_dim)
        subnet_unit.add_activation_layer('energy_tanh',
                                         'energy_in',
                                         activation='tanh')
        subnet_unit.add_linear_layer('energy',
                                     'energy_tanh',
                                     n_out=self.att_num_heads,
                                     with_bias=False)

        if self.att_dropout:
            subnet_unit.add_softmax_over_spatial_layer('att_weights0',
                                                       'energy')
            subnet_unit.add_dropout_layer('att_weights',
                                          'att_weights0',
                                          dropout=self.att_dropout,
                                          dropout_noise_shape={'*': None})
        else:
            if self.relax_att_scale:
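                # smooth the attention distribution towards uniform:
                # att_weights = (1 - relax_att_scale) * softmax(energy) + relax_att_scale / T_enc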
                subnet_unit.add_softmax_over_spatial_layer(
                    'att_weights0', 'energy')
                subnet_unit.add_length_layer('encoder_len',
                                             'base:encoder',
                                             dtype='float32')  # [B]
                subnet_unit.add_eval_layer('scaled_encoder_len',
                                           source=['encoder_len'],
                                           eval='{} / source(0)'.format(
                                               self.relax_att_scale))
                subnet_unit.add_eval_layer(
                    'att_weights',
                    source=['att_weights0', 'scaled_encoder_len'],
                    eval='{} * source(0) + source(1)'.format(
                        1 - self.relax_att_scale))
            else:
                subnet_unit.add_softmax_over_spatial_layer(
                    'att_weights', 'energy')

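        # attention context: weighted sum over the encoder value vectors, then merge the
        # attention-head and feature dims into a single feature dim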
        subnet_unit.add_generic_att_layer('att0',
                                          weights='att_weights',
                                          base='base:enc_value')
        subnet_unit.add_merge_dims_layer('att', 'att0', axes='except_batch')

        # LM-like component, same as in https://arxiv.org/pdf/2001.07263.pdf
        lstm_lm_component = None
        if self.add_lstm_lm:
            lstm_lm_component = subnet_unit.add_rnn_cell_layer(
                'lm_like_s',
                'prev:target_embed',
                n_out=self.lstm_lm_dim,
                l2=self.l2)

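        # assemble the decoder LSTM inputs: label feedback (or the LM-like state if
        # enabled) plus the attention context of the previous decoder step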
        lstm_inputs = []
        if lstm_lm_component:
            lstm_inputs += [lstm_lm_component]
        else:
            lstm_inputs += ['prev:target_embed']
        lstm_inputs += ['prev:att']

        if self.dec_state_no_label_ctx:
            lstm_inputs = ['prev:att']  # no label feedback

        # LSTM decoder
        if self.dec_zoneout:
            subnet_unit.add_rnn_cell_layer('s',
                                           lstm_inputs,
                                           n_out=self.dec_lstm_num_units,
                                           unit='zoneoutlstm',
                                           unit_opts={
                                               'zoneout_factor_cell': 0.15,
                                               'zoneout_factor_output': 0.05
                                           })
        else:
            if self.rec_weight_dropout:
                # a rec layer with unit NativeLSTM2 is required to use rec_weight_dropout
                subnet_unit.add_rec_layer(
                    's',
                    lstm_inputs,
                    n_out=self.dec_lstm_num_units,
                    l2=self.l2,
                    rec_weight_dropout=self.rec_weight_dropout,
                    unit='NativeLSTM2')
            else:
                subnet_unit.add_rnn_cell_layer('s',
                                               lstm_inputs,
                                               n_out=self.dec_lstm_num_units,
                                               l2=self.l2)

        # AM softmax output layer
        if self.dec_state_no_label_ctx and self.add_lstm_lm:
            if self.add_no_label_ctx_s_to_output:
                subnet_unit.add_linear_layer(
                    'readout_in',
                    ["lm_like_s", "s", "prev:target_embed", "att"],
                    n_out=self.dec_output_num_units)
            else:
                subnet_unit.add_linear_layer(
                    'readout_in', ["lm_like_s", "prev:target_embed", "att"],
                    n_out=self.dec_output_num_units)
        else:
            subnet_unit.add_linear_layer('readout_in',
                                         ["s", "prev:target_embed", "att"],
                                         n_out=self.dec_output_num_units)

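        # optionally reduce the readout features (reduce_out, i.e. a maxout-style max over
        # groups of units) before the output softmax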
        if self.reduceout:
            subnet_unit.add_reduceout_layer('readout', 'readout_in')
        else:
            subnet_unit.add_copy_layer('readout', 'readout_in')

        if self.local_fusion_opts:
            output_prob = subnet_unit.add_softmax_layer('output_prob',
                                                        'readout',
                                                        l2=self.l2,
                                                        target=self.target,
                                                        dropout=self.dropout)
            self._add_local_fusion(subnet_unit, am_output_prob=output_prob)
        elif self.mwer:
            # MWER training only, so the CE loss is disabled here
            output_prob = subnet_unit.add_softmax_layer('output_prob',
                                                        'readout',
                                                        l2=self.l2,
                                                        target=self.target,
                                                        dropout=self.dropout)
        else:
            ce_loss_opts = {'label_smoothing': self.label_smoothing}
            if self.ce_loss_scale:
                ce_loss_opts['scale'] = self.ce_loss_scale
            output_prob = subnet_unit.add_softmax_layer('output_prob',
                                                        'readout',
                                                        l2=self.l2,
                                                        loss='ce',
                                                        loss_opts=ce_loss_opts,
                                                        target=self.target,
                                                        dropout=self.dropout)

        # disable the softmax bias so it does not have to be loaded from a checkpoint
        if self.remove_softmax_bias:
            subnet_unit['output_prob']['with_bias'] = False

        # for prior LM estimation
        prior_output_prob = None
        if self.prior_lm_opts:
            prior_output_prob = self._create_prior_net(
                subnet_unit, self.prior_lm_opts
            )  # this requires preload_from_files in the config

        # Beam search
        # only shallow fusion is supported for now
        if self.ext_lm_opts:
            self._add_external_LM(subnet_unit, output_prob, prior_output_prob)
        else:
            if self.coverage_term_scale:
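                # add a coverage bonus in log space; the choice layer must then treat its
                # input as log-probabilities (input_type='log_prob')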
                output_prob = subnet_unit.add_eval_layer(
                    'combo_output_prob',
                    eval='safe_log(source(0)) + {} * source(1)'.format(
                        self.coverage_term_scale),
                    source=['output_prob', 'accum_coverage'])
                input_type = 'log_prob'
            else:
                output_prob = 'output_prob'
                input_type = None

            if self.length_norm:
                subnet_unit.add_choice_layer('output',
                                             output_prob,
                                             target=self.target,
                                             beam_size=self.beam_size,
                                             initial_output=0,
                                             input_type=input_type)
            else:
                subnet_unit.add_choice_layer(
                    'output',
                    output_prob,
                    target=self.target,
                    beam_size=self.beam_size,
                    initial_output=0,
                    length_normalization=self.length_norm,
                    input_type=input_type)

        if self.ilmt_opts:
            self._create_ilmt_net(subnet_unit)

        # recurrent subnetwork
        dec_output = self.network.add_subnet_rec_layer(
            'output',
            unit=subnet_unit.get_net(),
            target=self.target,
            source=self.source)

        return dec_output
Example #2
    def add_decoder_subnetwork(self, subnet_unit: ReturnnNetwork):

        subnet_unit.add_compare_layer('end', source='output',
                                      value=0)  # sentence end token

        # target embedding
        subnet_unit.add_linear_layer(
            'target_embed0',
            'output',
            n_out=self.embed_dim,
            initial_output=0,
            with_bias=False,
            l2=self.l2,
            forward_weights_init=self.embed_weight_init)

        subnet_unit.add_dropout_layer('target_embed',
                                      'target_embed0',
                                      dropout=self.embed_dropout,
                                      dropout_noise_shape={'*': None})

        # attention
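        # the attention layers (energies, softmax over time, context vector) are built by a
        # separate AttentionMechanism helper and merged into the recurrent subnetwork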
        att = AttentionMechanism(
            enc_key_dim=self.enc_key_dim,
            att_num_heads=self.att_num_heads,
            att_dropout=self.att_dropout,
            l2=self.l2,
            loc_filter_size=self.loc_conv_att_filter_size,
            loc_num_channels=self.loc_conv_att_num_channels)
        subnet_unit.update(att.create())

        # LM-like component, same as in https://arxiv.org/pdf/2001.07263.pdf
        lstm_lm_component = None
        if self.add_lstm_lm:
            lstm_lm_component = subnet_unit.add_rnn_cell_layer(
                'lm_like_s',
                'prev:target_embed',
                n_out=self.lstm_lm_dim,
                l2=self.l2)

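        # decoder LSTM inputs: previous label embedding (or the LM-like state if enabled)
        # and the attention context of the previous step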
        lstm_inputs = []
        if lstm_lm_component:
            lstm_inputs += [lstm_lm_component]
        else:
            lstm_inputs += ['prev:target_embed']
        lstm_inputs += ['prev:att']

        # LSTM decoder (or decoder state)
        if self.dec_zoneout:
            subnet_unit.add_rnn_cell_layer('s',
                                           lstm_inputs,
                                           n_out=self.dec_lstm_num_units,
                                           l2=self.l2,
                                           weights_init=self.lstm_weights_init,
                                           unit='zoneoutlstm',
                                           unit_opts={
                                               'zoneout_factor_cell': 0.15,
                                               'zoneout_factor_output': 0.05
                                           })
        else:
            if self.rec_weight_dropout:
                # a rec layer with unit NativeLSTM2 is required to use rec_weight_dropout
                subnet_unit.add_rec_layer(
                    's',
                    lstm_inputs,
                    n_out=self.dec_lstm_num_units,
                    l2=self.l2,
                    unit='NativeLSTM2',
                    rec_weight_dropout=self.rec_weight_dropout,
                    weights_init=self.lstm_weights_init)
            else:
                subnet_unit.add_rnn_cell_layer(
                    's',
                    lstm_inputs,
                    n_out=self.dec_lstm_num_units,
                    l2=self.l2,
                    weights_init=self.lstm_weights_init)

        # ASR softmax output layer
        subnet_unit.add_linear_layer('readout_in',
                                     ["s", "prev:target_embed", "att"],
                                     n_out=self.dec_output_num_units,
                                     l2=self.l2)

        if self.reduceout:
            subnet_unit.add_reduceout_layer('readout', 'readout_in')
        else:
            subnet_unit.add_copy_layer('readout', 'readout_in')

        output_prob = subnet_unit.add_softmax_layer(
            'output_prob',
            'readout',
            l2=self.l2,
            loss='ce',
            loss_opts={'label_smoothing': self.label_smoothing},
            target=self.target,
            dropout=self.dropout)

        subnet_unit.add_choice_layer('output',
                                     output_prob,
                                     target=self.target,
                                     beam_size=self.beam_size,
                                     initial_output=0)

        # recurrent subnetwork
        dec_output = self.network.add_subnet_rec_layer(
            'output',
            unit=subnet_unit.get_net(),
            target=self.target,
            source=self.source)

        return dec_output
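
A minimal usage sketch (hypothetical driver code, not part of either example; the decoder instance name and its attributes are assumptions, while ReturnnNetwork and add_decoder_subnetwork are taken from the examples above):

subnet_unit = ReturnnNetwork()  # empty helper network that becomes the rec-layer unit
dec_output = decoder.add_decoder_subnetwork(subnet_unit)  # 'decoder' is a hypothetical instance holding embed_dim, target, beam_size, ...
# dec_output refers to the recurrent 'output' layer that was added to decoder.network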