Example #1
    def __init__(self,
                 input_previous_word: bool,
                 attention_num_hidden: int,
                 attention_coverage_type: Optional[str] = None,
                 attention_coverage_num_hidden: int = 1,
                 prefix='',
                 layer_normalization: bool = False) -> None:
        dynamic_source_num_hidden = 1 if attention_coverage_type is None else attention_coverage_num_hidden
        super().__init__(input_previous_word=input_previous_word,
                         dynamic_source_num_hidden=dynamic_source_num_hidden)
        self.prefix = prefix
        self.attention_num_hidden = attention_num_hidden
        # input (encoder) to hidden
        self.att_e2h_weight = mx.sym.Variable("%satt_e2h_weight" % prefix)
        # input (query) to hidden
        self.att_q2h_weight = mx.sym.Variable("%satt_q2h_weight" % prefix)
        # hidden to score
        self.att_h2s_weight = mx.sym.Variable("%satt_h2s_weight" % prefix)
        # dynamic source (coverage) weights and settings
        # input (coverage) to hidden
        self.att_c2h_weight = mx.sym.Variable(
            "%satt_c2h_weight" % prefix) if attention_coverage_type else None
        self.coverage = sockeye.coverage.get_coverage(
            attention_coverage_type, dynamic_source_num_hidden,
            layer_normalization) if attention_coverage_type else None

        if layer_normalization:
            self._ln = LayerNormalization(num_hidden=attention_num_hidden,
                                          prefix="att_norm")
        else:
            self._ln = None
Example #2
 def __init__(self,
              num_hidden: int,
              prefix: str = 'lngru_',
              params: Optional[mx.rnn.RNNParams] = None,
              norm_scale: float = 1.0,
              norm_shift: float = 0.0) -> None:
     super(LayerNormGRUCell, self).__init__(num_hidden, prefix, params)
     self._iN = LayerNormalization(
         num_hidden=num_hidden * 3,
         prefix="%si2h" % self._prefix,
         scale=self.params.get('i2h_scale',
                               shape=(num_hidden * 3, ),
                               init=mx.init.Constant(value=norm_scale)),
         shift=self.params.get('i2h_shift',
                               shape=(num_hidden * 3, ),
                               init=mx.init.Constant(value=norm_shift)))
     self._hN = LayerNormalization(
         num_hidden=num_hidden * 3,
         prefix="%sh2h" % self._prefix,
         scale=self.params.get('h2h_scale',
                               shape=(num_hidden * 3, ),
                               init=mx.init.Constant(value=norm_scale)),
         shift=self.params.get('h2h_shift',
                               shape=(num_hidden * 3, ),
                               init=mx.init.Constant(value=norm_shift)))
     self._shape_fix = None
Example #3
    def __init__(self,
                 config: RecurrentDecoderConfig,
                 attention: attentions.Attention,
                 lexicon: Optional[lexicons.Lexicon] = None,
                 prefix=C.DECODER_PREFIX) -> None:
        # TODO: implement variant without input feeding
        self.rnn_config = config.rnn_config
        self.target_vocab_size = config.vocab_size
        self.num_target_embed = config.num_embed
        self.attention = attention
        self.weight_tying = config.weight_tying
        self.context_gating = config.context_gating
        self.layer_norm = config.layer_normalization
        self.lexicon = lexicon
        self.prefix = prefix

        self.num_hidden = self.rnn_config.num_hidden

        if self.context_gating:
            self.gate_w = mx.sym.Variable("%sgate_weight" % prefix)
            self.gate_b = mx.sym.Variable("%sgate_bias" % prefix)
            self.mapped_rnn_output_w = mx.sym.Variable(
                "%smapped_rnn_output_weight" % prefix)
            self.mapped_rnn_output_b = mx.sym.Variable(
                "%smapped_rnn_output_bias" % prefix)
            self.mapped_context_w = mx.sym.Variable("%smapped_context_weight" %
                                                    prefix)
            self.mapped_context_b = mx.sym.Variable("%smapped_context_bias" %
                                                    prefix)

        # Stacked RNN
        self.rnn = rnn.get_stacked_rnn(self.rnn_config, self.prefix)
        # RNN init state parameters
        self._create_layer_parameters()

        # Hidden state parameters
        self.hidden_w = mx.sym.Variable("%shidden_weight" % prefix)
        self.hidden_b = mx.sym.Variable("%shidden_bias" % prefix)
        self.hidden_norm = LayerNormalization(
            self.num_hidden, prefix="%shidden_norm" %
            prefix) if self.layer_norm else None
        # Embedding & output parameters
        self.embedding = encoder.Embedding(self.num_target_embed,
                                           self.target_vocab_size,
                                           prefix=C.TARGET_EMBEDDING_PREFIX,
                                           dropout=0.)  # TODO dropout?
        if self.weight_tying:
            check_condition(
                self.num_hidden == self.num_target_embed,
                "Weight tying requires target embedding size and rnn_num_hidden to be equal"
            )
            self.cls_w = self.embedding.embed_weight
        else:
            self.cls_w = mx.sym.Variable("%scls_weight" % prefix)
        self.cls_b = mx.sym.Variable("%scls_bias" % prefix)
Example #4
 def __init__(self,
              num_hidden: int,
              prefix: str = 'lnlstm_',
              params: Optional[mx.rnn.RNNParams] = None,
              forget_bias: float = 1.0,
              norm_scale: float = 1.0,
              norm_shift: float = 0.0) -> None:
     super(LayerNormLSTMCell, self).__init__(num_hidden, prefix, params, forget_bias)
     self._iN = LayerNormalization(prefix="%si2h" % self._prefix,
                                   scale=self.params.get('i2h_scale', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_scale)),
                                   shift=self.params.get('i2h_shift', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_shift)))
     self._hN = LayerNormalization(prefix="%sh2h" % self._prefix,
                                   scale=self.params.get('h2h_scale', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_scale)),
                                   shift=self.params.get('h2h_shift', shape=(num_hidden * 4,), init=mx.init.Constant(value=norm_shift)))
     self._cN = LayerNormalization(prefix="%sc" % self._prefix,
                                   scale=self.params.get('c_scale', shape=(num_hidden,), init=mx.init.Constant(value=norm_scale)),
                                   shift=self.params.get('c_shift', shape=(num_hidden,), init=mx.init.Constant(value=norm_shift)))
Example #5
 def __init__(self,
              coverage_num_hidden: int,
              activation: str,
              layer_normalization: bool) -> None:
     super().__init__()
     self.activation = activation
     self.num_hidden = coverage_num_hidden
     # input (encoder) to hidden
     self.cov_e2h_weight = mx.sym.Variable("%se2h_weight" % self.prefix)
     # decoder to hidden
     self.cov_dec2h_weight = mx.sym.Variable("%si2h_weight" % self.prefix)
     # previous coverage to hidden
     self.cov_prev2h_weight = mx.sym.Variable("%sprev2h_weight" % self.prefix)
     # attention scores to hidden
     self.cov_a2h_weight = mx.sym.Variable("%sa2h_weight" % self.prefix)
     # optional layer normalization
     self.layer_norm = None
      if layer_normalization and self.num_hidden > 1:
          # layer normalization only makes sense for coverage vectors with more than one unit
          self.layer_norm = LayerNormalization(self.num_hidden,
                                               prefix="%snorm" % self.prefix)
Example #6
 def __init__(self,
              num_hidden: int,
              prefix: str = 'lnggru_',
              params: Optional[mx.rnn.RNNParams] = None,
              norm_scale: float = 1.0,
              norm_shift: float = 0.0) -> None:
     super(LayerNormPerGateGRUCell, self).__init__(num_hidden, prefix, params)
     self._norm_layers = list()  # type: List[LayerNormalization]
     for name in ['r', 'z', 'o']:
          # scale (gain) is initialized with norm_scale, shift (bias) with norm_shift
          scale = self.params.get('%s_scale' % name, init=mx.init.Constant(value=norm_scale))
          shift = self.params.get('%s_shift' % name, init=mx.init.Constant(value=norm_shift))
         self._norm_layers.append(LayerNormalization(prefix="%s%s" % (self._prefix, name), scale=scale, shift=shift))
Example #7
 def __init__(self,
              num_hidden: int,
              prefix: str = 'lnglstm_',
              params: Optional[mx.rnn.RNNParams] = None,
              forget_bias: float = 1.0,
              norm_scale: float = 1.0,
              norm_shift: float = 0.0) -> None:
     super(LayerNormPerGateLSTMCell, self).__init__(num_hidden, prefix, params, forget_bias)
     self._norm_layers = list()  # type: List[LayerNormalization]
     for name in ['i', 'f', 'c', 'o', 's']:
          # scale (gain) is initialized with norm_scale; shift (bias) with norm_shift,
          # except for the forget gate, whose bias starts at forget_bias
          scale = self.params.get('%s_scale' % name,
                                  init=mx.init.Constant(value=norm_scale))
          shift = self.params.get('%s_shift' % name,
                                  init=mx.init.Constant(value=norm_shift if name != "f" else forget_bias))
         self._norm_layers.append(
             LayerNormalization(prefix="%s%s" % (self._prefix, name), scale=scale, shift=shift))
Example #8
 def _create_layer_parameters(self):
     """
     Creates parameters for encoder last state transformation into decoder layer initial states.
     """
     self.init_ws, self.init_bs = [], []
     self.init_norms = []
     for state_idx, (_, init_num_hidden) in enumerate(self.rnn.state_shape):
         self.init_ws.append(
             mx.sym.Variable("%senc2decinit_%d_weight" %
                             (self.prefix, state_idx)))
         self.init_bs.append(
             mx.sym.Variable("%senc2decinit_%d_bias" %
                             (self.prefix, state_idx)))
         if self.layer_norm:
             self.init_norms.append(
                 LayerNormalization(num_hidden=init_num_hidden,
                                    prefix="%senc2decinit_%d_norm" %
                                    (self.prefix, state_idx)))
Example #9
class LayerNormGRUCell(mx.rnn.GRUCell):
    """
    Gated Recurrent Unit (GRU) network cell with layer normalization across gates.
    Based on Jimmy Lei Ba et al: Layer Normalization (https://arxiv.org/pdf/1607.06450.pdf)

    :param num_hidden: number of RNN hidden units. Number of units in output symbol.
    :param prefix: prefix for name of layers (and name of weight if params is None).
    :param params: RNNParams or None. Container for weight sharing between cells. Created if None.
    :param norm_scale: scale/gain for layer normalization.
    :param norm_shift: shift/bias after layer normalization.
    """
    def __init__(self,
                 num_hidden: int,
                 prefix: str = 'lngru_',
                 params: Optional[mx.rnn.RNNParams] = None,
                 norm_scale: float = 1.0,
                 norm_shift: float = 0.0) -> None:
        super(LayerNormGRUCell, self).__init__(num_hidden, prefix, params)
        self._iN = LayerNormalization(
            num_hidden=num_hidden * 3,
            prefix="%si2h" % self._prefix,
            scale=self.params.get('i2h_scale',
                                  shape=(num_hidden * 3, ),
                                  init=mx.init.Constant(value=norm_scale)),
            shift=self.params.get('i2h_shift',
                                  shape=(num_hidden * 3, ),
                                  init=mx.init.Constant(value=norm_shift)))
        self._hN = LayerNormalization(
            num_hidden=num_hidden * 3,
            prefix="%sh2h" % self._prefix,
            scale=self.params.get('h2h_scale',
                                  shape=(num_hidden * 3, ),
                                  init=mx.init.Constant(value=norm_scale)),
            shift=self.params.get('h2h_shift',
                                  shape=(num_hidden * 3, ),
                                  init=mx.init.Constant(value=norm_shift)))
        self._shape_fix = None

    def __call__(self, inputs, states):
        self._counter += 1

        seq_idx = self._counter
        name = '%st%d_' % (self._prefix, seq_idx)
        prev_state_h = states[0]

        i2h = mx.sym.FullyConnected(data=inputs,
                                    weight=self._iW,
                                    bias=self._iB,
                                    num_hidden=self._num_hidden * 3,
                                    name="%s_i2h" % name)
        h2h = mx.sym.FullyConnected(data=prev_state_h,
                                    weight=self._hW,
                                    bias=self._hB,
                                    num_hidden=self._num_hidden * 3,
                                    name="%s_h2h" % name)
        if self._counter == 0:
            self._shape_fix = mx.sym.zeros_like(i2h)
        else:
            assert self._shape_fix is not None

        i2h = self._iN.normalize(i2h)
        h2h = self._hN.normalize(self._shape_fix + h2h)

        i2h_r, i2h_z, i2h = mx.sym.split(i2h,
                                         num_outputs=3,
                                         name="%s_i2h_slice" % name)
        h2h_r, h2h_z, h2h = mx.sym.split(h2h,
                                         num_outputs=3,
                                         name="%s_h2h_slice" % name)

        reset_gate = mx.sym.Activation(i2h_r + h2h_r,
                                       act_type="sigmoid",
                                       name="%s_r_act" % name)
        update_gate = mx.sym.Activation(i2h_z + h2h_z,
                                        act_type="sigmoid",
                                        name="%s_z_act" % name)

        next_h_tmp = mx.sym.Activation(i2h + reset_gate * h2h,
                                       act_type="tanh",
                                       name="%s_h_act" % name)

        next_h = mx.sym._internal._plus((1. - update_gate) * next_h_tmp,
                                        update_gate * prev_state_h,
                                        name='%sout' % name)

        return next_h, [next_h]
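
A minimal usage sketch for the cell above (an illustration, not part of the original listing): it assumes LayerNormGRUCell and its LayerNormalization dependency are importable, and unrolls the cell with MXNet's standard RNN-cell API.

import mxnet as mx

# Build a layer-normalized GRU and unroll it for 10 time steps.
# 'data' is assumed to be batch-major (NTC): (batch_size, seq_len, input_dim).
cell = LayerNormGRUCell(num_hidden=64, prefix='lngru_')
data = mx.sym.Variable('data')
outputs, states = cell.unroll(length=10, inputs=data, layout='NTC', merge_outputs=True)
# outputs: (batch_size, 10, 64); states: [final hidden state]
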
Example #10
class LayerNormLSTMCell(mx.rnn.LSTMCell):
    """
    Long-Short Term Memory (LSTM) network cell with layer normalization across gates.
    Based on Jimmy Lei Ba et al: Layer Normalization (https://arxiv.org/pdf/1607.06450.pdf)

    :param num_hidden: number of RNN hidden units. Number of units in output symbol.
    :param prefix: prefix for name of layers (and name of weight if params is None).
    :param params: RNNParams or None. Container for weight sharing between cells. Created if None.
    :param forget_bias: bias added to forget gate, default 1.0. Jozefowicz et al. 2015 recommends setting this to 1.0.
    :param norm_scale: scale/gain for layer normalization.
    :param norm_shift: shift/bias after layer normalization.
    """
    def __init__(self,
                 num_hidden: int,
                 prefix: str = 'lnlstm_',
                 params: Optional[mx.rnn.RNNParams] = None,
                 forget_bias: float = 1.0,
                 norm_scale: float = 1.0,
                 norm_shift: float = 0.0) -> None:
        super(LayerNormLSTMCell, self).__init__(num_hidden, prefix, params,
                                                forget_bias)
        self._iN = LayerNormalization(
            num_hidden=num_hidden * 4,
            prefix="%si2h" % self._prefix,
            scale=self.params.get('i2h_scale',
                                  shape=(num_hidden * 4, ),
                                  init=mx.init.Constant(value=norm_scale)),
            shift=self.params.get('i2h_shift',
                                  shape=(num_hidden * 4, ),
                                  init=mx.init.Constant(value=norm_shift)))
        self._hN = LayerNormalization(
            num_hidden=num_hidden * 4,
            prefix="%sh2h" % self._prefix,
            scale=self.params.get('h2h_scale',
                                  shape=(num_hidden * 4, ),
                                  init=mx.init.Constant(value=norm_scale)),
            shift=self.params.get('h2h_shift',
                                  shape=(num_hidden * 4, ),
                                  init=mx.init.Constant(value=norm_shift)))
        self._cN = LayerNormalization(
            num_hidden=num_hidden,
            prefix="%sc" % self._prefix,
            scale=self.params.get('c_scale',
                                  shape=(num_hidden, ),
                                  init=mx.init.Constant(value=norm_scale)),
            shift=self.params.get('c_shift',
                                  shape=(num_hidden, ),
                                  init=mx.init.Constant(value=norm_shift)))
        self._shape_fix = None

    def __call__(self, inputs, states):
        self._counter += 1
        name = '%st%d_' % (self._prefix, self._counter)
        i2h = mx.sym.FullyConnected(data=inputs,
                                    weight=self._iW,
                                    bias=self._iB,
                                    num_hidden=self._num_hidden * 4,
                                    name='%si2h' % name)
        if self._counter == 0:
            self._shape_fix = mx.sym.zeros_like(i2h)
        else:
            assert self._shape_fix is not None
        h2h = mx.sym.FullyConnected(data=states[0],
                                    weight=self._hW,
                                    bias=self._hB,
                                    num_hidden=self._num_hidden * 4,
                                    name='%sh2h' % name)
        gates = self._iN.normalize(i2h) + self._hN.normalize(self._shape_fix +
                                                             h2h)
        in_gate, forget_gate, in_transform, out_gate = mx.sym.split(
            gates, num_outputs=4, axis=1, name="%sslice" % name)
        in_gate = mx.sym.Activation(in_gate,
                                    act_type="sigmoid",
                                    name='%si' % name)
        forget_gate = mx.sym.Activation(forget_gate,
                                        act_type="sigmoid",
                                        name='%sf' % name)
        in_transform = mx.sym.Activation(in_transform,
                                         act_type="tanh",
                                         name='%sc' % name)
        out_gate = mx.sym.Activation(out_gate,
                                     act_type="sigmoid",
                                     name='%so' % name)
        next_c = mx.sym._internal._plus(forget_gate * states[1],
                                        in_gate * in_transform,
                                        name='%sstate' % name)
        next_h = mx.sym._internal._mul(out_gate,
                                       mx.sym.Activation(
                                           self._cN.normalize(next_c),
                                           act_type="tanh"),
                                       name='%sout' % name)
        return next_h, [next_h, next_c]
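
For reference, the LayerNormalization.normalize calls used by these cells presumably apply the transformation of Ba et al. (2016). The function below is only a standalone sketch of that computation in plain MXNet symbols, not sockeye's implementation; the sockeye class additionally owns the 'scale'/'shift' parameters created in the constructors above.

import mxnet as mx

def layer_norm(x, gamma, beta, eps=1e-06):
    # Normalize each row of x to zero mean and unit variance over its last axis,
    # then apply a learned elementwise scale (gamma) and shift (beta).
    mean = mx.sym.mean(x, axis=-1, keepdims=True)
    var = mx.sym.mean(mx.sym.square(mx.sym.broadcast_sub(x, mean)), axis=-1, keepdims=True)
    normed = mx.sym.broadcast_div(mx.sym.broadcast_sub(x, mean), mx.sym.sqrt(var + eps))
    return mx.sym.broadcast_add(mx.sym.broadcast_mul(normed, gamma), beta)
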
Example #11
class RecurrentDecoder(Decoder):
    """
    Class to generate the decoder part of the computation graph in sequence-to-sequence models.
    The architecture is based on Luong et al, 2015: Effective Approaches to Attention-based Neural Machine Translation.

    :param config: Configuration for recurrent decoder.
    :param attention: Attention model.
    :param lexicon: Optional Lexicon.
    :param prefix: Decoder symbol prefix.
    """
    def __init__(self,
                 config: RecurrentDecoderConfig,
                 attention: attentions.Attention,
                 lexicon: Optional[lexicons.Lexicon] = None,
                 prefix=C.DECODER_PREFIX) -> None:
        # TODO: implement variant without input feeding
        self.rnn_config = config.rnn_config
        self.target_vocab_size = config.vocab_size
        self.num_target_embed = config.num_embed
        self.attention = attention
        self.weight_tying = config.weight_tying
        self.context_gating = config.context_gating
        self.layer_norm = config.layer_normalization
        self.lexicon = lexicon
        self.prefix = prefix

        self.num_hidden = self.rnn_config.num_hidden

        if self.context_gating:
            self.gate_w = mx.sym.Variable("%sgate_weight" % prefix)
            self.gate_b = mx.sym.Variable("%sgate_bias" % prefix)
            self.mapped_rnn_output_w = mx.sym.Variable(
                "%smapped_rnn_output_weight" % prefix)
            self.mapped_rnn_output_b = mx.sym.Variable(
                "%smapped_rnn_output_bias" % prefix)
            self.mapped_context_w = mx.sym.Variable("%smapped_context_weight" %
                                                    prefix)
            self.mapped_context_b = mx.sym.Variable("%smapped_context_bias" %
                                                    prefix)

        # Stacked RNN
        self.rnn = rnn.get_stacked_rnn(self.rnn_config, self.prefix)
        # RNN init state parameters
        self._create_layer_parameters()

        # Hidden state parameters
        self.hidden_w = mx.sym.Variable("%shidden_weight" % prefix)
        self.hidden_b = mx.sym.Variable("%shidden_bias" % prefix)
        self.hidden_norm = LayerNormalization(
            self.num_hidden, prefix="%shidden_norm" %
            prefix) if self.layer_norm else None
        # Embedding & output parameters
        self.embedding = encoder.Embedding(self.num_target_embed,
                                           self.target_vocab_size,
                                           prefix=C.TARGET_EMBEDDING_PREFIX,
                                           dropout=0.)  # TODO dropout?
        if self.weight_tying:
            check_condition(
                self.num_hidden == self.num_target_embed,
                "Weight tying requires target embedding size and rnn_num_hidden to be equal"
            )
            self.cls_w = self.embedding.embed_weight
        else:
            self.cls_w = mx.sym.Variable("%scls_weight" % prefix)
        self.cls_b = mx.sym.Variable("%scls_bias" % prefix)

    def get_num_hidden(self) -> int:
        """
        Returns the representation size of this decoder.

        :return: Number of hidden units.
        """
        return self.num_hidden

    def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]:
        """
        Returns a list of RNNCells used by this decoder.
        """
        return [self.rnn]

    def _create_layer_parameters(self):
        """
        Creates parameters for encoder last state transformation into decoder layer initial states.
        """
        self.init_ws, self.init_bs = [], []
        self.init_norms = []
        for state_idx, (_, init_num_hidden) in enumerate(self.rnn.state_shape):
            self.init_ws.append(
                mx.sym.Variable("%senc2decinit_%d_weight" %
                                (self.prefix, state_idx)))
            self.init_bs.append(
                mx.sym.Variable("%senc2decinit_%d_bias" %
                                (self.prefix, state_idx)))
            if self.layer_norm:
                self.init_norms.append(
                    LayerNormalization(num_hidden=init_num_hidden,
                                       prefix="%senc2decinit_%d_norm" %
                                       (self.prefix, state_idx)))

    def create_layer_input_variables(self, batch_size: int) \
            -> Tuple[List[mx.sym.Symbol], List[mx.io.DataDesc], List[str]]:
        """
        Creates RNN layer state variables. Used for inference.
        Returns nested list of layer_states variables, flat list of layer shapes (for module binding),
        and a flat list of layer names (for BucketingModule's data names)

        :param batch_size: Batch size.
        """
        layer_states, layer_shapes, layer_names = [], [], []
        for state_idx, (_, init_num_hidden) in enumerate(self.rnn.state_shape):
            name = "%senc2decinit_%d" % (self.prefix, state_idx)
            layer_states.append(mx.sym.Variable(name))
            layer_shapes.append(
                mx.io.DataDesc(name=name,
                               shape=(batch_size, init_num_hidden),
                               layout=C.BATCH_MAJOR))
            layer_names.append(name)
        return layer_states, layer_shapes, layer_names

    def compute_init_states(self, source_encoded: mx.sym.Symbol,
                            source_length: mx.sym.Symbol) -> DecoderState:
        """
        Computes the initial decoder hidden state and one initial state per RNN layer.
        Layer initial states are computed from the last encoder state through a single non-linear FC layer.

        :param source_encoded: Concatenated encoder states. Shape: (source_seq_len, batch_size, encoder_num_hidden).
        :param source_length: Lengths of source sequences. Shape: (batch_size,).
        :return: Decoder state.
        """
        # initial decoder hidden state
        hidden = mx.sym.tile(data=mx.sym.expand_dims(data=source_length * 0,
                                                     axis=1),
                             reps=(1, self.num_hidden))
        # initial states for each layer
        layer_states = []
        for state_idx, (_, init_num_hidden) in enumerate(self.rnn.state_shape):
            init = mx.sym.FullyConnected(
                data=mx.sym.SequenceLast(data=source_encoded,
                                         sequence_length=source_length,
                                         use_sequence_length=True),
                num_hidden=init_num_hidden,
                weight=self.init_ws[state_idx],
                bias=self.init_bs[state_idx],
                name="%senc2decinit_%d" % (self.prefix, state_idx))
            if self.layer_norm:
                init = self.init_norms[state_idx].normalize(init)
            init = mx.sym.Activation(data=init,
                                     act_type="tanh",
                                     name="%senc2dec_inittanh_%d" %
                                     (self.prefix, state_idx))
            layer_states.append(init)
        return DecoderState(hidden, layer_states)

    def _step(
            self,
            word_vec_prev: mx.sym.Symbol,
            state: DecoderState,
            attention_func: Callable,
            attention_state: attentions.AttentionState,
            seq_idx: int = 0
    ) -> Tuple[DecoderState, attentions.AttentionState]:
        """
        Performs single-time step in the RNN, given previous word vector, previous hidden state, attention function,
        and RNN layer states.

        :param word_vec_prev: Embedding of previous target word. Shape: (batch_size, num_target_embed).
        :param state: Decoder state consisting of hidden and layer states.
        :param attention_func: Attention function to produce context vector.
        :param attention_state: Previous attention state.
        :param seq_idx: Decoder time step.
        :return: (new decoder state, updated attention state).
        """
        # (1) RNN step
        # concat previous word embedding and previous hidden state
        rnn_input = mx.sym.concat(word_vec_prev,
                                  state.hidden,
                                  dim=1,
                                  name="%sconcat_target_context_t%d" %
                                  (self.prefix, seq_idx))
        # rnn_output: (batch_size, rnn_num_hidden)
        # next_layer_states: num_layers * [batch_size, rnn_num_hidden]
        rnn_output, layer_states = self.rnn(rnn_input, state.layer_states)

        # (2) Attention step
        attention_input = self.attention.make_input(seq_idx, word_vec_prev,
                                                    rnn_output)
        attention_state = attention_func(attention_input, attention_state)

        # (3) Combine context with hidden state
        if self.context_gating:
            # context: (batch_size, encoder_num_hidden)
            # gate: (batch_size, rnn_num_hidden)
            gate = mx.sym.FullyConnected(data=mx.sym.concat(
                word_vec_prev, rnn_output, attention_state.context, dim=1),
                                         num_hidden=self.num_hidden,
                                         weight=self.gate_w,
                                         bias=self.gate_b)
            gate = mx.sym.Activation(data=gate,
                                     act_type="sigmoid",
                                     name="%sgate_activation_t%d" %
                                     (self.prefix, seq_idx))

            # mapped_rnn_output: (batch_size, rnn_num_hidden)
            mapped_rnn_output = mx.sym.FullyConnected(
                data=rnn_output,
                num_hidden=self.num_hidden,
                weight=self.mapped_rnn_output_w,
                bias=self.mapped_rnn_output_b,
                name="%smapped_rnn_output_fc_t%d" % (self.prefix, seq_idx))
            # mapped_context: (batch_size, rnn_num_hidden)
            mapped_context = mx.sym.FullyConnected(
                data=attention_state.context,
                num_hidden=self.num_hidden,
                weight=self.mapped_context_w,
                bias=self.mapped_context_b,
                name="%smapped_context_fc_t%d" % (self.prefix, seq_idx))

            # hidden: (batch_size, rnn_num_hidden)
            hidden = mx.sym.Activation(
                data=gate * mapped_rnn_output + (1 - gate) * mapped_context,
                act_type="tanh",
                name="%snext_hidden_t%d" % (self.prefix, seq_idx))

        else:
            # hidden: (batch_size, rnn_num_hidden)
            hidden = mx.sym.FullyConnected(
                data=mx.sym.concat(rnn_output, attention_state.context, dim=1),
                # use same number of hidden states as RNN
                num_hidden=self.num_hidden,
                weight=self.hidden_w,
                bias=self.hidden_b)

            if self.layer_norm:
                hidden = self.hidden_norm.normalize(hidden)

            # hidden: (batch_size, rnn_num_hidden)
            hidden = mx.sym.Activation(data=hidden,
                                       act_type="tanh",
                                       name="%snext_hidden_t%d" %
                                       (self.prefix, seq_idx))

        return DecoderState(hidden, layer_states), attention_state

    def decode(
            self,
            source_encoded: mx.sym.Symbol,
            source_seq_len: int,
            source_length: mx.sym.Symbol,
            target: mx.sym.Symbol,
            target_seq_len: int,
            source_lexicon: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol:
        """
        Returns decoder logits with batch size and target sequence length collapsed into a single dimension.

        :param source_encoded: Concatenated encoder states. Shape: (source_seq_len, batch_size, encoder_num_hidden).
        :param source_seq_len: Maximum source sequence length.
        :param source_length: Lengths of source sequences. Shape: (batch_size,).
        :param target: Target sequence. Shape: (batch_size, target_seq_len).
        :param target_seq_len: Maximum target sequence length.
        :param source_lexicon: Lexical biases for current sentence.
               Shape: (batch_size, target_vocab_size, source_seq_len)
        :return: Logits of next-word predictions for target sequence.
                 Shape: (batch_size * target_seq_len, target_vocab_size)
        """
        # process encoder states
        source_encoded_batch_major = mx.sym.swapaxes(
            source_encoded, dim1=0, dim2=1, name='source_encoded_batch_major')

        # embed and slice target words
        # target_embed: (batch_size, target_seq_len, num_target_embed)
        target_embed, _, _ = self.embedding.encode(target, None,
                                                   target_seq_len)
        # target_embed: target_seq_len * (batch_size, num_target_embed)
        target_embed = mx.sym.split(data=target_embed,
                                    num_outputs=target_seq_len,
                                    axis=1,
                                    squeeze_axis=True)

        # get recurrent attention function conditioned on source
        attention_func = self.attention.on(source_encoded_batch_major,
                                           source_length, source_seq_len)
        attention_state = self.attention.get_initial_state(
            source_length, source_seq_len)

        # initialize decoder states
        # hidden: (batch_size, rnn_num_hidden)
        # layer_states: num_layers * (batch_size, state_num_hidden)
        state = self.compute_init_states(source_encoded, source_length)

        # hidden_all: target_seq_len * (batch_size, 1, rnn_num_hidden)
        hidden_all = []

        # TODO: possible alternative: feed back the context vector instead of the hidden (see lamtram)

        lexical_biases = []

        self.rnn.reset()
        # TODO remove this once mxnet.rnn.SequentialRNNCell.reset() invokes recursive calls on layer cells
        for cell in self.rnn._cells:
            cell.reset()

        for seq_idx in range(target_seq_len):
            # hidden: (batch_size, rnn_num_hidden)
            state, attention_state = self._step(target_embed[seq_idx], state,
                                                attention_func,
                                                attention_state, seq_idx)

            # hidden_expanded: (batch_size, 1, rnn_num_hidden)
            hidden_all.append(mx.sym.expand_dims(data=state.hidden, axis=1))

            if source_lexicon is not None:
                assert self.lexicon is not None, "a lexicon must be available when source_lexicon is given"
                lexical_biases.append(
                    self.lexicon.calculate_lex_bias(source_lexicon,
                                                    attention_state.probs))

        # concatenate along time axis
        # hidden_concat: (batch_size, target_seq_len, rnn_num_hidden)
        hidden_concat = mx.sym.concat(*hidden_all,
                                      dim=1,
                                      name="%shidden_concat" % self.prefix)
        # hidden_concat: (batch_size * target_seq_len, rnn_num_hidden)
        hidden_concat = mx.sym.reshape(data=hidden_concat,
                                       shape=(-1, self.num_hidden))

        # logits: (batch_size * target_seq_len, target_vocab_size)
        logits = mx.sym.FullyConnected(data=hidden_concat,
                                       num_hidden=self.target_vocab_size,
                                       weight=self.cls_w,
                                       bias=self.cls_b,
                                       name=C.LOGITS_NAME)

        if source_lexicon is not None:
            # lexical_biases_concat: (batch_size, target_seq_len, target_vocab_size)
            lexical_biases_concat = mx.sym.concat(*lexical_biases,
                                                  dim=1,
                                                  name='lex_bias_concat')
            # lexical_biases_concat: (batch_size * target_seq_len, target_vocab_size)
            lexical_biases_concat = mx.sym.reshape(
                data=lexical_biases_concat, shape=(-1, self.target_vocab_size))
            logits = mx.sym.broadcast_add(lhs=logits,
                                          rhs=lexical_biases_concat,
                                          name='%s_plus_lex_bias' %
                                          C.LOGITS_NAME)

        return logits

    def predict(
        self,
        word_id_prev: mx.sym.Symbol,
        state_prev: DecoderState,
        attention_func: Callable,
        attention_state_prev: attentions.AttentionState,
        source_lexicon: Optional[mx.sym.Symbol] = None,
        softmax_temperature: Optional[float] = None
    ) -> Tuple[mx.sym.Symbol, DecoderState, attentions.AttentionState]:
        """
        Given previous word id, attention function, previous hidden state and RNN layer states,
        returns Softmax predictions (not a loss symbol), next hidden state, and next layer
        states. Used for inference.

        :param word_id_prev: Previous target word id. Shape: (1,).
        :param state_prev: Previous decoder state consisting of hidden and layer states.
        :param attention_func: Attention function to produce context vector.
        :param attention_state_prev: Previous attention state.
        :param source_lexicon: Lexical biases for current sentence.
               Shape: (batch_size, target_vocab_size, source_seq_len).
        :param softmax_temperature: Optional parameter to control steepness of softmax distribution.
        :return: (predicted next-word distribution, decoder state, attention state).
        """
        # target side embedding
        word_vec_prev, _, _ = self.embedding.encode(word_id_prev, None, 1)

        # state.hidden: (batch_size, rnn_num_hidden)
        # attention_state.dynamic_source: (batch_size, source_seq_len, coverage_num_hidden)
        # attention_state.probs: (batch_size, source_seq_len)
        state, attention_state = self._step(word_vec_prev, state_prev,
                                            attention_func,
                                            attention_state_prev)

        # logits: (batch_size, target_vocab_size)
        logits = mx.sym.FullyConnected(data=state.hidden,
                                       num_hidden=self.target_vocab_size,
                                       weight=self.cls_w,
                                       bias=self.cls_b,
                                       name=C.LOGITS_NAME)

        if source_lexicon is not None:
            assert self.lexicon is not None
            # lex_bias: (batch_size, 1, target_vocab_size)
            lex_bias = self.lexicon.calculate_lex_bias(source_lexicon,
                                                       attention_state.probs)
            # lex_bias: (batch_size, target_vocab_size)
            lex_bias = mx.sym.reshape(data=lex_bias,
                                      shape=(-1, self.target_vocab_size))
            logits = mx.sym.broadcast_add(lhs=logits,
                                          rhs=lex_bias,
                                          name='%s_plus_lex_bias' %
                                          C.LOGITS_NAME)

        if softmax_temperature is not None:
            logits /= softmax_temperature

        softmax_out = mx.sym.softmax(data=logits, name=C.SOFTMAX_NAME)
        return softmax_out, state, attention_state
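
For orientation, the context-gating branch of _step above combines the previous target embedding :math:`e`, the RNN output :math:`s` and the attention context :math:`c` as

    :math:`g = \sigma(\mathbf{W}_g [e; s; c] + b_g)`

    :math:`hidden = \tanh(g \odot (\mathbf{W}_s s + b_s) + (1 - g) \odot (\mathbf{W}_c c + b_c))`

whereas the default branch computes :math:`hidden = \tanh(\mathbf{W}_h [s; c] + b_h)`, with optional layer normalization applied before the tanh.
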
Example #12
class MlpAttention(Attention):
    """
    Attention computed through a one-layer MLP with num_hidden units [Luong et al, 2015].

    :math:`score(h_t, h_s) = \\mathbf{W}_a tanh(\\mathbf{W}_c [h_t, h_s] + b)`

    :math:`a = softmax(score(*, h_s))`

    Optionally, if attention_coverage_type is not None, attention uses dynamic source encoding ('coverage' mechanism)
    as in Tu et al. (2016): Modeling Coverage for Neural Machine Translation.

    :math:`score(h_t, h_s) = \\mathbf{W}_a tanh(\\mathbf{W}_c [h_t, h_s, c_s] + b)`

    :math:`c_s` is the decoder time-step dependent source encoding which is updated using the current
    decoder state.

    :param input_previous_word: Feed the previous target embedding into the attention mechanism.
    :param attention_num_hidden: Number of hidden units.
    :param attention_coverage_type: The type of update for the dynamic source encoding.
           If None, no dynamic source encoding is done.
    :param attention_coverage_num_hidden: Number of hidden units for coverage attention.
    :param prefix: Layer name prefix.
    :param layer_normalization: If true, normalizes hidden layer outputs before tanh activation.
    """
    def __init__(self,
                 input_previous_word: bool,
                 attention_num_hidden: int,
                 attention_coverage_type: Optional[str] = None,
                 attention_coverage_num_hidden: int = 1,
                 prefix='',
                 layer_normalization: bool = False) -> None:
        dynamic_source_num_hidden = 1 if attention_coverage_type is None else attention_coverage_num_hidden
        super().__init__(input_previous_word=input_previous_word,
                         dynamic_source_num_hidden=dynamic_source_num_hidden)
        self.prefix = prefix
        self.attention_num_hidden = attention_num_hidden
        # input (encoder) to hidden
        self.att_e2h_weight = mx.sym.Variable("%satt_e2h_weight" % prefix)
        # input (query) to hidden
        self.att_q2h_weight = mx.sym.Variable("%satt_q2h_weight" % prefix)
        # hidden to score
        self.att_h2s_weight = mx.sym.Variable("%satt_h2s_weight" % prefix)
        # dynamic source (coverage) weights and settings
        # input (coverage) to hidden
        self.att_c2h_weight = mx.sym.Variable(
            "%satt_c2h_weight" % prefix) if attention_coverage_type else None
        self.coverage = sockeye.coverage.get_coverage(
            attention_coverage_type, dynamic_source_num_hidden,
            layer_normalization) if attention_coverage_type else None

        if layer_normalization:
            self._ln = LayerNormalization(num_hidden=attention_num_hidden,
                                          prefix="att_norm")
        else:
            self._ln = None

    def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol,
           source_seq_len: int) -> Callable:
        """
        Returns callable to be used for recurrent attention in a sequence decoder.
        The callable is a recurrent function of the form:
        AttentionState = attend(AttentionInput, AttentionState).

        :param source: Shape: (batch_size, seq_len, encoder_num_hidden).
        :param source_length: Shape: (batch_size,).
        :param source_seq_len: Maximum length of source sequences.
        :return: Attention callable.
        """

        coverage_func = self.coverage.on(
            source, source_length, source_seq_len) if self.coverage else None

        # (batch_size * seq_len, attention_num_hidden)
        source_hidden = mx.sym.FullyConnected(
            data=mx.sym.reshape(data=source,
                                shape=(-3, -1),
                                name="%satt_flat_source" % self.prefix),
            weight=self.att_e2h_weight,
            num_hidden=self.attention_num_hidden,
            no_bias=True,
            name="%satt_source_hidden_fc" % self.prefix)

        # (batch_size, seq_len, attention_num_hidden)
        source_hidden = mx.sym.reshape(
            source_hidden,
            shape=(-1, source_seq_len, self.attention_num_hidden),
            name="%satt_source_hidden" % self.prefix)

        def attend(att_input: AttentionInput,
                   att_state: AttentionState) -> AttentionState:
            """
            Returns updated attention state given attention input and current attention state.

            :param att_input: Attention input as returned by make_input().
            :param att_state: Current attention state
            :return: Updated attention state.
            """

            # (batch_size, attention_num_hidden)
            query_hidden = mx.sym.FullyConnected(
                data=att_input.query,
                weight=self.att_q2h_weight,
                num_hidden=self.attention_num_hidden,
                no_bias=True,
                name="%satt_query_hidden" % self.prefix)

            # (batch_size, 1, attention_num_hidden)
            query_hidden = mx.sym.expand_dims(
                data=query_hidden,
                axis=1,
                name="%satt_query_hidden_expanded" % self.prefix)

            attention_hidden_lhs = source_hidden
            if self.coverage:
                # (batch_size * seq_len, attention_num_hidden)
                dynamic_hidden = mx.sym.FullyConnected(
                    data=mx.sym.reshape(data=att_state.dynamic_source,
                                        shape=(-3, -1),
                                        name="%satt_flat_dynamic_source" %
                                        self.prefix),
                    weight=self.att_c2h_weight,
                    num_hidden=self.attention_num_hidden,
                    no_bias=True,
                    name="%satt_dynamic_source_hidden_fc" % self.prefix)

                # (batch_size, seq_len, attention_num_hidden)
                dynamic_hidden = mx.sym.reshape(
                    dynamic_hidden,
                    shape=(-1, source_seq_len, self.attention_num_hidden),
                    name="%satt_dynamic_source_hidden" % self.prefix)

                # (batch_size, seq_len, attention_num_hidden)
                attention_hidden_lhs = dynamic_hidden + source_hidden

            # (batch_size, seq_len, attention_num_hidden)
            attention_hidden = mx.sym.broadcast_add(
                lhs=attention_hidden_lhs,
                rhs=query_hidden,
                name="%satt_query_plus_input" % self.prefix)

            # (batch_size * seq_len, attention_num_hidden)
            attention_hidden = mx.sym.reshape(
                data=attention_hidden,
                shape=(-3, -1),
                name="%satt_query_plus_input_before_fc" % self.prefix)

            if self._ln is not None:
                attention_hidden = self._ln.normalize(attention_hidden)

            # (batch_size * seq_len, attention_num_hidden)
            attention_hidden = mx.sym.Activation(attention_hidden,
                                                 act_type="tanh",
                                                 name="%satt_hidden" %
                                                 self.prefix)

            # (batch_size * seq_len, 1)
            attention_scores = mx.sym.FullyConnected(
                data=attention_hidden,
                weight=self.att_h2s_weight,
                num_hidden=1,
                no_bias=True,
                name="%sraw_att_score_fc" % self.prefix)

            # (batch_size, seq_len, 1)
            attention_scores = mx.sym.reshape(attention_scores,
                                              shape=(-1, source_seq_len, 1),
                                              name="%sraw_att_score" %
                                              self.prefix)

            context, attention_probs = get_context_and_attention_probs(
                source, source_length, attention_scores)

            dynamic_source = att_state.dynamic_source
            if self.coverage:
                # update dynamic source encoding
                # Note: this is a slight change to the Tu et al, 2016 paper: input to the coverage update
                # is the attention input query, not the previous decoder state.
                dynamic_source = coverage_func(
                    prev_hidden=att_input.query,
                    attention_prob_scores=attention_probs,
                    prev_coverage=att_state.dynamic_source)

            return AttentionState(context=context,
                                  probs=attention_probs,
                                  dynamic_source=dynamic_source)

        return attend
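
get_context_and_attention_probs is not part of this listing. As an illustration only, assuming it performs a length-masked softmax over the attention scores followed by a probability-weighted sum of encoder states, the computation can be sketched in NumPy as follows (function and variable names here are hypothetical):

import numpy as np

def masked_context(values, lengths, scores):
    # values: (batch, seq_len, enc_hidden); lengths: (batch,); scores: (batch, seq_len, 1)
    batch, seq_len, _ = values.shape
    mask = np.arange(seq_len)[None, :] < lengths[:, None]      # True for real source positions
    scores = np.where(mask[:, :, None], scores, -np.inf)       # ignore padded positions
    probs = np.exp(scores - scores.max(axis=1, keepdims=True))
    probs = probs / probs.sum(axis=1, keepdims=True)           # softmax over source positions
    context = (probs * values).sum(axis=1)                     # (batch, enc_hidden)
    return context, probs[:, :, 0]                             # probs: (batch, seq_len)
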
Example #13
class ActivationCoverage(Coverage):
    """
    Implements a coverage mechanism whose updates are performed by a Perceptron with
    configurable activation function.

    :param coverage_num_hidden: Number of hidden units for coverage vectors.
    :param activation: Type of activation for Perceptron.
    :param layer_normalization: If true, applies layer normalization before non-linear activation.
    :param prefix: Layer name prefix.
    """
    def __init__(self,
                 coverage_num_hidden: int,
                 activation: str,
                 layer_normalization: bool,
                 prefix="cov_") -> None:
        self.prefix = prefix
        self.activation = activation
        self.num_hidden = coverage_num_hidden
        # input (encoder) to hidden
        self.cov_e2h_weight = mx.sym.Variable("%se2h_weight" % self.prefix)
        # decoder to hidden
        self.cov_dec2h_weight = mx.sym.Variable("%si2h_weight" % self.prefix)
        # previous coverage to hidden
        self.cov_prev2h_weight = mx.sym.Variable("%sprev2h_weight" %
                                                 self.prefix)
        # attention scores to hidden
        self.cov_a2h_weight = mx.sym.Variable("%sa2h_weight" % self.prefix)
        # optional layer normalization
        self.layer_norm = None
        if layer_normalization and self.num_hidden > 1:
            # layer normalization only makes sense for coverage vectors with more than one unit
            self.layer_norm = LayerNormalization(self.num_hidden,
                                                 prefix="%snorm" % prefix)

    def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol,
           source_seq_len: int) -> Callable:
        """
        Returns callable to be used for updating coverage vectors in a sequence decoder.

        :param source: Shape: (batch_size, seq_len, encoder_num_hidden).
        :param source_length: Shape: (batch_size,).
        :param source_seq_len: Maximum length of source sequences.
        :return: Coverage callable.
        """

        # (batch_size * seq_len, coverage_hidden_num)
        source_hidden = mx.sym.FullyConnected(
            data=mx.sym.reshape(data=source,
                                shape=(-3, -1),
                                name="%sflat_source" % self.prefix),
            weight=self.cov_e2h_weight,
            no_bias=True,
            num_hidden=self.num_hidden,
            name="%ssource_hidden_fc" % self.prefix)

        # (batch_size, seq_len, coverage_hidden_num)
        source_hidden = mx.sym.reshape(source_hidden,
                                       shape=(-1, source_seq_len,
                                              self.num_hidden),
                                       name="%ssource_hidden" % self.prefix)

        def update_coverage(prev_hidden: mx.sym.Symbol,
                            attention_prob_scores: mx.sym.Symbol,
                            prev_coverage: mx.sym.Symbol):
            """
            :param prev_hidden: Previous hidden decoder state. Shape: (batch_size, decoder_num_hidden).
            :param attention_prob_scores: Current attention probabilities. Shape: (batch_size, source_seq_len).
            :param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden).
            :return: Updated coverage matrix. Shape: (batch_size, source_seq_len, coverage_num_hidden).
            """

            # (batch_size * seq_len, coverage_hidden_num)
            coverage_hidden = mx.sym.FullyConnected(
                data=mx.sym.reshape(data=prev_coverage,
                                    shape=(-3, -1),
                                    name="%sflat_previous" % self.prefix),
                weight=self.cov_prev2h_weight,
                no_bias=True,
                num_hidden=self.num_hidden,
                name="%sprevious_hidden_fc" % self.prefix)

            # (batch_size, source_seq_len, coverage_hidden_num)
            coverage_hidden = mx.sym.reshape(
                coverage_hidden,
                shape=(-1, source_seq_len, self.num_hidden),
                name="%sprevious_hidden" % self.prefix)

            # (batch_size, source_seq_len, 1)
            attention_prob_score = mx.sym.expand_dims(attention_prob_scores,
                                                      axis=2)

            # (batch_size * source_seq_len, coverage_num_hidden)
            attention_hidden = mx.sym.FullyConnected(
                data=mx.sym.reshape(attention_prob_score,
                                    shape=(-3, 0),
                                    name="%sreshape_att_probs" % self.prefix),
                weight=self.cov_a2h_weight,
                no_bias=True,
                num_hidden=self.num_hidden,
                name="%sattention_fc" % self.prefix)

            # (batch_size, source_seq_len, coverage_num_hidden)
            attention_hidden = mx.sym.reshape(
                attention_hidden,
                shape=(-1, source_seq_len, self.num_hidden),
                name="%sreshape_att" % self.prefix)

            # (batch_size, coverage_num_hidden)
            prev_hidden = mx.sym.FullyConnected(data=prev_hidden,
                                                weight=self.cov_dec2h_weight,
                                                no_bias=True,
                                                num_hidden=self.num_hidden,
                                                name="%sdecoder_hidden")

            # (batch_size, 1, coverage_num_hidden)
            prev_hidden = mx.sym.expand_dims(
                data=prev_hidden,
                axis=1,
                name="%sinput_decoder_hidden_expanded" % self.prefix)

            # (batch_size, source_seq_len, coverage_num_hidden)
            intermediate = mx.sym.broadcast_add(lhs=source_hidden,
                                                rhs=prev_hidden,
                                                name="%ssource_plus_hidden" %
                                                self.prefix)

            # (batch_size, source_seq_len, coverage_num_hidden)
            updated_coverage = intermediate + attention_hidden + coverage_hidden

            if self.layer_norm is not None:
                updated_coverage = self.layer_norm.normalize(updated_coverage)

            # (batch_size, seq_len, coverage_num_hidden)
            coverage = mx.sym.Activation(data=updated_coverage,
                                         act_type=self.activation,
                                         name="%sactivation" % self.prefix)

            return mask_coverage(coverage, source_length)

        return update_coverage
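
In equation form, update_coverage computes, for each source position :math:`i` at decoder step :math:`t` (with :math:`h_i` the encoder state, :math:`s` the prev_hidden query, :math:`a_{i,t}` the attention probability and :math:`f` the configured activation):

    :math:`c_{i,t} = f(\mathbf{W}_e h_i + \mathbf{W}_d s + \mathbf{W}_p c_{i,t-1} + \mathbf{W}_a a_{i,t})`

Optional layer normalization is applied before :math:`f`, and positions beyond the true source length are masked out by mask_coverage.
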
Example #14
    def __init__(self,
                 num_hidden: int,
                 attention: sockeye.attention.Attention,
                 target_vocab_size: int,
                 num_target_embed: int,
                 num_layers=1,
                 prefix=C.DECODER_PREFIX,
                 weight_tying=False,
                 dropout=0.0,
                 cell_type: str = C.LSTM_TYPE,
                 residual: bool = False,
                 forget_bias: float = 0.0,
                 lexicon: Optional[sockeye.lexicon.Lexicon] = None,
                 context_gating: bool = False,
                 layer_normalization: bool = False) -> None:
        # TODO: implement variant without input feeding
        self.num_layers = num_layers
        self.prefix = prefix
        self.dropout = dropout
        self.num_hidden = num_hidden
        self.attention = attention
        self.target_vocab_size = target_vocab_size
        self.num_target_embed = num_target_embed
        self.context_gating = context_gating
        if self.context_gating:
            self.gate_w = mx.sym.Variable("%sgate_weight" % prefix)
            self.gate_b = mx.sym.Variable("%sgate_bias" % prefix)
            self.mapped_rnn_output_w = mx.sym.Variable(
                "%smapped_rnn_output_weight" % prefix)
            self.mapped_rnn_output_b = mx.sym.Variable(
                "%smapped_rnn_output_bias" % prefix)
            self.mapped_context_w = mx.sym.Variable("%smapped_context_weight" %
                                                    prefix)
            self.mapped_context_b = mx.sym.Variable("%smapped_context_bias" %
                                                    prefix)
        self.layer_norm = layer_normalization

        # Decoder stacked RNN
        self.rnn = sockeye.rnn.get_stacked_rnn(cell_type, num_hidden,
                                               num_layers, dropout, prefix,
                                               residual, forget_bias)

        # Decoder parameters
        # RNN init state parameters
        self._create_layer_parameters()

        # Hidden state parameters
        self.hidden_w = mx.sym.Variable("%shidden_weight" % prefix)
        self.hidden_b = mx.sym.Variable("%shidden_bias" % prefix)
        self.hidden_norm = LayerNormalization(
            self.num_hidden, prefix="%shidden_norm" %
            prefix) if self.layer_norm else None
        # Embedding & output parameters
        self.embedding = sockeye.encoder.Embedding(
            self.num_target_embed,
            self.target_vocab_size,
            prefix=C.TARGET_EMBEDDING_PREFIX,
            dropout=0.)  # TODO dropout?
        if weight_tying:
            check_condition(
                self.num_hidden == self.num_target_embed,
                "Weight tying requires target embedding size and rnn_num_hidden to be equal"
            )
            self.cls_w = self.embedding.embed_weight
        else:
            self.cls_w = mx.sym.Variable("%scls_weight" % prefix)
        self.cls_b = mx.sym.Variable("%scls_bias" % prefix)

        self.lexicon = lexicon