def __init__(self,
             num_hidden: int,
             prefix: str = 'lnlstm_',
             params: Optional[mx.rnn.RNNParams] = None,
             forget_bias: float = 1.0,
             norm_scale: float = 1.0,
             norm_shift: float = 0.0) -> None:
    super(LayerNormLSTMCell, self).__init__(num_hidden, prefix, params, forget_bias)
    self._iN = LayerNormalization(prefix="%si2h" % self._prefix,
                                  scale=self.params.get('i2h_scale', shape=(num_hidden * 4,),
                                                        init=mx.init.Constant(value=norm_scale)),
                                  shift=self.params.get('i2h_shift', shape=(num_hidden * 4,),
                                                        init=mx.init.Constant(value=norm_shift)))
    self._hN = LayerNormalization(prefix="%sh2h" % self._prefix,
                                  scale=self.params.get('h2h_scale', shape=(num_hidden * 4,),
                                                        init=mx.init.Constant(value=norm_scale)),
                                  shift=self.params.get('h2h_shift', shape=(num_hidden * 4,),
                                                        init=mx.init.Constant(value=norm_shift)))
    self._cN = LayerNormalization(prefix="%sc" % self._prefix,
                                  scale=self.params.get('c_scale', shape=(num_hidden,),
                                                        init=mx.init.Constant(value=norm_scale)),
                                  shift=self.params.get('c_shift', shape=(num_hidden,),
                                                        init=mx.init.Constant(value=norm_shift)))
def __init__(self,
             coverage_num_hidden: int,
             activation: str,
             layer_normalization: bool) -> None:
    super().__init__()
    self.activation = activation
    self.num_hidden = coverage_num_hidden
    # input (encoder) to hidden
    self.cov_e2h_weight = mx.sym.Variable("%se2h_weight" % self.prefix)
    # decoder to hidden
    self.cov_dec2h_weight = mx.sym.Variable("%si2h_weight" % self.prefix)
    # previous coverage to hidden
    self.cov_prev2h_weight = mx.sym.Variable("%sprev2h_weight" % self.prefix)
    # attention scores to hidden
    self.cov_a2h_weight = mx.sym.Variable("%sa2h_weight" % self.prefix)
    # optional layer normalization (only meaningful for more than one hidden unit)
    self.layer_norm = None
    if layer_normalization and self.num_hidden > 1:
        self.layer_norm = LayerNormalization(self.num_hidden, prefix="%snorm" % self.prefix)
def __init__(self,
             num_hidden: int,
             prefix: str = 'lnggru_',
             params: Optional[mx.rnn.RNNParams] = None,
             norm_scale: float = 1.0,
             norm_shift: float = 0.0) -> None:
    super(LayerNormPerGateGRUCell, self).__init__(num_hidden, prefix, params)
    self._norm_layers = list()  # type: List[LayerNormalization]
    for name in ['r', 'z', 'o']:
        scale = self.params.get('%s_scale' % name, init=mx.init.Constant(value=norm_scale))
        shift = self.params.get('%s_shift' % name, init=mx.init.Constant(value=norm_shift))
        self._norm_layers.append(
            LayerNormalization(prefix="%s%s" % (self._prefix, name), scale=scale, shift=shift))
def __init__(self,
             num_hidden: int,
             prefix: str = 'lnglstm_',
             params: Optional[mx.rnn.RNNParams] = None,
             forget_bias: float = 1.0,
             norm_scale: float = 1.0,
             norm_shift: float = 0.0) -> None:
    super(LayerNormPerGateLSTMCell, self).__init__(num_hidden, prefix, params, forget_bias)
    self._norm_layers = list()  # type: List[LayerNormalization]
    for name in ['i', 'f', 'c', 'o', 's']:
        scale = self.params.get('%s_scale' % name, init=mx.init.Constant(value=norm_scale))
        shift = self.params.get('%s_shift' % name,
                                init=mx.init.Constant(value=norm_shift if name != "f" else forget_bias))
        self._norm_layers.append(
            LayerNormalization(prefix="%s%s" % (self._prefix, name), scale=scale, shift=shift))
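# Illustrative sketch (not part of the Sockeye code above): per-gate layer normalization as used by
# LayerNormPerGateGRUCell / LayerNormPerGateLSTMCell. Instead of normalizing the fused 3*num_hidden
# (or 4*num_hidden) pre-activation at once, each gate's pre-activation is normalized separately with
# its own learned scale (gain) and shift (bias). NumPy only; all names and values are hypothetical.
import numpy as np

def layer_norm(x, scale, shift, eps=1e-5):
    # normalize the last axis to zero mean / unit variance, then rescale and shift
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return scale * (x - mean) / np.sqrt(var + eps) + shift

num_hidden = 4
pre_activations = {name: np.random.randn(2, num_hidden) for name in ['r', 'z', 'o']}  # GRU gates
normed = {name: layer_norm(x, scale=np.ones(num_hidden), shift=np.zeros(num_hidden))
          for name, x in pre_activations.items()}  # one (scale, shift) pair per gate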
class LayerNormGRUCell(mx.rnn.GRUCell):
    """
    Gated Recurrent Unit (GRU) network cell with layer normalization across gates.
    Based on Jimmy Lei Ba et al: Layer Normalization (https://arxiv.org/pdf/1607.06450.pdf)

    :param num_hidden: number of RNN hidden units. Number of units in output symbol.
    :param prefix: prefix for name of layers (and name of weight if params is None).
    :param params: RNNParams or None. Container for weight sharing between cells. Created if None.
    :param norm_scale: scale/gain for layer normalization.
    :param norm_shift: shift/bias after layer normalization.
    """

    def __init__(self,
                 num_hidden: int,
                 prefix: str = 'lngru_',
                 params: Optional[mx.rnn.RNNParams] = None,
                 norm_scale: float = 1.0,
                 norm_shift: float = 0.0) -> None:
        super(LayerNormGRUCell, self).__init__(num_hidden, prefix, params)
        self._iN = LayerNormalization(num_hidden=num_hidden * 3,
                                      prefix="%si2h" % self._prefix,
                                      scale=self.params.get('i2h_scale', shape=(num_hidden * 3,),
                                                            init=mx.init.Constant(value=norm_scale)),
                                      shift=self.params.get('i2h_shift', shape=(num_hidden * 3,),
                                                            init=mx.init.Constant(value=norm_shift)))
        self._hN = LayerNormalization(num_hidden=num_hidden * 3,
                                      prefix="%sh2h" % self._prefix,
                                      scale=self.params.get('h2h_scale', shape=(num_hidden * 3,),
                                                            init=mx.init.Constant(value=norm_scale)),
                                      shift=self.params.get('h2h_shift', shape=(num_hidden * 3,),
                                                            init=mx.init.Constant(value=norm_shift)))
        self._shape_fix = None

    def __call__(self, inputs, states):
        self._counter += 1
        seq_idx = self._counter
        name = '%st%d_' % (self._prefix, seq_idx)
        prev_state_h = states[0]

        i2h = mx.sym.FullyConnected(data=inputs,
                                    weight=self._iW,
                                    bias=self._iB,
                                    num_hidden=self._num_hidden * 3,
                                    name="%s_i2h" % name)
        h2h = mx.sym.FullyConnected(data=prev_state_h,
                                    weight=self._hW,
                                    bias=self._hB,
                                    num_hidden=self._num_hidden * 3,
                                    name="%s_h2h" % name)

        if self._counter == 0:
            self._shape_fix = mx.sym.zeros_like(i2h)
        else:
            assert self._shape_fix is not None

        i2h = self._iN.normalize(i2h)
        h2h = self._hN.normalize(self._shape_fix + h2h)

        i2h_r, i2h_z, i2h = mx.sym.split(i2h, num_outputs=3, name="%s_i2h_slice" % name)
        h2h_r, h2h_z, h2h = mx.sym.split(h2h, num_outputs=3, name="%s_h2h_slice" % name)

        reset_gate = mx.sym.Activation(i2h_r + h2h_r, act_type="sigmoid", name="%s_r_act" % name)
        update_gate = mx.sym.Activation(i2h_z + h2h_z, act_type="sigmoid", name="%s_z_act" % name)

        next_h_tmp = mx.sym.Activation(i2h + reset_gate * h2h, act_type="tanh", name="%s_h_act" % name)

        next_h = mx.sym._internal._plus((1. - update_gate) * next_h_tmp, update_gate * prev_state_h,
                                        name='%sout' % name)

        return next_h, [next_h]
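# Illustrative sketch (not part of the Sockeye code above): the computation a single
# LayerNormGRUCell step performs, written with NumPy for clarity. The fused i2h and h2h
# projections are layer-normalized as a whole before being split into the r/z/h parts.
# All names and shapes here are hypothetical.
import numpy as np

def layer_norm(x, scale, shift, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return scale * (x - mean) / np.sqrt(var + eps) + shift

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def ln_gru_step(x, h_prev, W_i, b_i, W_h, b_h, scale_i, shift_i, scale_h, shift_h):
    i2h = layer_norm(x @ W_i + b_i, scale_i, shift_i)        # (batch, 3 * num_hidden)
    h2h = layer_norm(h_prev @ W_h + b_h, scale_h, shift_h)   # (batch, 3 * num_hidden)
    i2h_r, i2h_z, i2h_h = np.split(i2h, 3, axis=-1)
    h2h_r, h2h_z, h2h_h = np.split(h2h, 3, axis=-1)
    r = sigmoid(i2h_r + h2h_r)                               # reset gate
    z = sigmoid(i2h_z + h2h_z)                               # update gate
    h_tilde = np.tanh(i2h_h + r * h2h_h)                     # candidate state
    return (1.0 - z) * h_tilde + z * h_prev                  # next hidden state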
class LayerNormLSTMCell(mx.rnn.LSTMCell):
    """
    Long-Short Term Memory (LSTM) network cell with layer normalization across gates.
    Based on Jimmy Lei Ba et al: Layer Normalization (https://arxiv.org/pdf/1607.06450.pdf)

    :param num_hidden: number of RNN hidden units. Number of units in output symbol.
    :param prefix: prefix for name of layers (and name of weight if params is None).
    :param params: RNNParams or None. Container for weight sharing between cells. Created if None.
    :param forget_bias: bias added to forget gate, default 1.0. Jozefowicz et al. 2015 recommends setting this to 1.0.
    :param norm_scale: scale/gain for layer normalization.
    :param norm_shift: shift/bias after layer normalization.
    """

    def __init__(self,
                 num_hidden: int,
                 prefix: str = 'lnlstm_',
                 params: Optional[mx.rnn.RNNParams] = None,
                 forget_bias: float = 1.0,
                 norm_scale: float = 1.0,
                 norm_shift: float = 0.0) -> None:
        super(LayerNormLSTMCell, self).__init__(num_hidden, prefix, params, forget_bias)
        self._iN = LayerNormalization(num_hidden=num_hidden * 4,
                                      prefix="%si2h" % self._prefix,
                                      scale=self.params.get('i2h_scale', shape=(num_hidden * 4,),
                                                            init=mx.init.Constant(value=norm_scale)),
                                      shift=self.params.get('i2h_shift', shape=(num_hidden * 4,),
                                                            init=mx.init.Constant(value=norm_shift)))
        self._hN = LayerNormalization(num_hidden=num_hidden * 4,
                                      prefix="%sh2h" % self._prefix,
                                      scale=self.params.get('h2h_scale', shape=(num_hidden * 4,),
                                                            init=mx.init.Constant(value=norm_scale)),
                                      shift=self.params.get('h2h_shift', shape=(num_hidden * 4,),
                                                            init=mx.init.Constant(value=norm_shift)))
        self._cN = LayerNormalization(num_hidden=num_hidden,
                                      prefix="%sc" % self._prefix,
                                      scale=self.params.get('c_scale', shape=(num_hidden,),
                                                            init=mx.init.Constant(value=norm_scale)),
                                      shift=self.params.get('c_shift', shape=(num_hidden,),
                                                            init=mx.init.Constant(value=norm_shift)))
        self._shape_fix = None

    def __call__(self, inputs, states):
        self._counter += 1
        name = '%st%d_' % (self._prefix, self._counter)

        i2h = mx.sym.FullyConnected(data=inputs,
                                    weight=self._iW,
                                    bias=self._iB,
                                    num_hidden=self._num_hidden * 4,
                                    name='%si2h' % name)
        if self._counter == 0:
            self._shape_fix = mx.sym.zeros_like(i2h)
        else:
            assert self._shape_fix is not None
        h2h = mx.sym.FullyConnected(data=states[0],
                                    weight=self._hW,
                                    bias=self._hB,
                                    num_hidden=self._num_hidden * 4,
                                    name='%sh2h' % name)

        gates = self._iN.normalize(i2h) + self._hN.normalize(self._shape_fix + h2h)

        in_gate, forget_gate, in_transform, out_gate = mx.sym.split(gates, num_outputs=4, axis=1,
                                                                    name="%sslice" % name)

        in_gate = mx.sym.Activation(in_gate, act_type="sigmoid", name='%si' % name)
        forget_gate = mx.sym.Activation(forget_gate, act_type="sigmoid", name='%sf' % name)
        in_transform = mx.sym.Activation(in_transform, act_type="tanh", name='%sc' % name)
        out_gate = mx.sym.Activation(out_gate, act_type="sigmoid", name='%so' % name)

        next_c = mx.sym._internal._plus(forget_gate * states[1], in_gate * in_transform,
                                        name='%sstate' % name)
        next_h = mx.sym._internal._mul(out_gate,
                                       mx.sym.Activation(self._cN.normalize(next_c), act_type="tanh"),
                                       name='%sout' % name)

        return next_h, [next_h, next_c]
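# Illustrative sketch (not part of the Sockeye code above): the math of one LayerNormLSTMCell
# step in NumPy. The fused input and hidden projections are each layer-normalized, summed,
# split into the i/f/c/o parts, and the new cell state is normalized again before the output
# gate is applied. Names are hypothetical.
import numpy as np

def layer_norm(x, scale, shift, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return scale * (x - mean) / np.sqrt(var + eps) + shift

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def ln_lstm_step(x, h_prev, c_prev, W_i, b_i, W_h, b_h, ln_params):
    # ln_params maps "i2h", "h2h", "c" to (scale, shift) pairs
    gates = (layer_norm(x @ W_i + b_i, *ln_params["i2h"]) +
             layer_norm(h_prev @ W_h + b_h, *ln_params["h2h"]))   # (batch, 4 * num_hidden)
    i, f, g, o = np.split(gates, 4, axis=-1)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)                   # input / forget / output gates
    g = np.tanh(g)                                                 # candidate cell input
    c_next = f * c_prev + i * g
    h_next = o * np.tanh(layer_norm(c_next, *ln_params["c"]))      # cell state is normalized too
    return h_next, c_next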
class RecurrentDecoder(Decoder):
    """
    Class to generate the decoder part of the computation graph in sequence-to-sequence models.
    The architecture is based on Luong et al., 2015: Effective Approaches to Attention-based Neural Machine Translation.

    :param config: Configuration for recurrent decoder.
    :param attention: Attention model.
    :param lexicon: Optional Lexicon.
    :param prefix: Decoder symbol prefix.
    """

    def __init__(self,
                 config: RecurrentDecoderConfig,
                 attention: attentions.Attention,
                 lexicon: Optional[lexicons.Lexicon] = None,
                 prefix=C.DECODER_PREFIX) -> None:
        # TODO: implement variant without input feeding
        self.rnn_config = config.rnn_config
        self.target_vocab_size = config.vocab_size
        self.num_target_embed = config.num_embed
        self.attention = attention
        self.weight_tying = config.weight_tying
        self.context_gating = config.context_gating
        self.layer_norm = config.layer_normalization
        self.lexicon = lexicon
        self.prefix = prefix

        self.num_hidden = self.rnn_config.num_hidden

        if self.context_gating:
            self.gate_w = mx.sym.Variable("%sgate_weight" % prefix)
            self.gate_b = mx.sym.Variable("%sgate_bias" % prefix)
            self.mapped_rnn_output_w = mx.sym.Variable("%smapped_rnn_output_weight" % prefix)
            self.mapped_rnn_output_b = mx.sym.Variable("%smapped_rnn_output_bias" % prefix)
            self.mapped_context_w = mx.sym.Variable("%smapped_context_weight" % prefix)
            self.mapped_context_b = mx.sym.Variable("%smapped_context_bias" % prefix)

        # Stacked RNN
        self.rnn = rnn.get_stacked_rnn(self.rnn_config, self.prefix)
        # RNN init state parameters
        self._create_layer_parameters()

        # Hidden state parameters
        self.hidden_w = mx.sym.Variable("%shidden_weight" % prefix)
        self.hidden_b = mx.sym.Variable("%shidden_bias" % prefix)
        self.hidden_norm = LayerNormalization(self.num_hidden,
                                              prefix="%shidden_norm" % prefix) if self.layer_norm else None

        # Embedding & output parameters
        self.embedding = encoder.Embedding(self.num_target_embed,
                                           self.target_vocab_size,
                                           prefix=C.TARGET_EMBEDDING_PREFIX,
                                           dropout=0.)  # TODO dropout?
        if self.weight_tying:
            check_condition(self.num_hidden == self.num_target_embed,
                            "Weight tying requires target embedding size and rnn_num_hidden to be equal")
            self.cls_w = self.embedding.embed_weight
        else:
            self.cls_w = mx.sym.Variable("%scls_weight" % prefix)
        self.cls_b = mx.sym.Variable("%scls_bias" % prefix)

    def get_num_hidden(self) -> int:
        """
        Returns the representation size of this decoder.

        :return: Number of hidden units.
        """
        return self.num_hidden

    def get_rnn_cells(self) -> List[mx.rnn.BaseRNNCell]:
        """
        Returns a list of RNNCells used by this decoder.
        """
        return [self.rnn]

    def _create_layer_parameters(self):
        """
        Creates parameters for the transformation of the last encoder state into the decoder layer initial states.
        """
        self.init_ws, self.init_bs = [], []
        self.init_norms = []
        for state_idx, (_, init_num_hidden) in enumerate(self.rnn.state_shape):
            self.init_ws.append(mx.sym.Variable("%senc2decinit_%d_weight" % (self.prefix, state_idx)))
            self.init_bs.append(mx.sym.Variable("%senc2decinit_%d_bias" % (self.prefix, state_idx)))
            if self.layer_norm:
                self.init_norms.append(LayerNormalization(num_hidden=init_num_hidden,
                                                          prefix="%senc2decinit_%d_norm" % (self.prefix, state_idx)))

    def create_layer_input_variables(self, batch_size: int) \
            -> Tuple[List[mx.sym.Symbol], List[mx.io.DataDesc], List[str]]:
        """
        Creates RNN layer state variables. Used for inference.
        Returns a list of layer_states variables, a flat list of layer shapes (for module binding),
        and a flat list of layer names (for BucketingModule's data names).

        :param batch_size: Batch size.
        """
        layer_states, layer_shapes, layer_names = [], [], []
        for state_idx, (_, init_num_hidden) in enumerate(self.rnn.state_shape):
            name = "%senc2decinit_%d" % (self.prefix, state_idx)
            layer_states.append(mx.sym.Variable(name))
            layer_shapes.append(mx.io.DataDesc(name=name,
                                               shape=(batch_size, init_num_hidden),
                                               layout=C.BATCH_MAJOR))
            layer_names.append(name)
        return layer_states, layer_shapes, layer_names

    def compute_init_states(self,
                            source_encoded: mx.sym.Symbol,
                            source_length: mx.sym.Symbol) -> DecoderState:
        """
        Computes the initial states of the decoder: the initial hidden state and one state for each RNN layer.
        Init states for the RNN layers are computed with a single non-linear FC layer that takes the last
        encoder state as input.

        :param source_encoded: Concatenated encoder states. Shape: (source_seq_len, batch_size, encoder_num_hidden).
        :param source_length: Lengths of source sequences. Shape: (batch_size,).
        :return: Decoder state.
        """
        # initial decoder hidden state
        hidden = mx.sym.tile(data=mx.sym.expand_dims(data=source_length * 0, axis=1),
                             reps=(1, self.num_hidden))
        # initial states for each layer
        layer_states = []
        for state_idx, (_, init_num_hidden) in enumerate(self.rnn.state_shape):
            init = mx.sym.FullyConnected(data=mx.sym.SequenceLast(data=source_encoded,
                                                                  sequence_length=source_length,
                                                                  use_sequence_length=True),
                                         num_hidden=init_num_hidden,
                                         weight=self.init_ws[state_idx],
                                         bias=self.init_bs[state_idx],
                                         name="%senc2decinit_%d" % (self.prefix, state_idx))
            if self.layer_norm:
                init = self.init_norms[state_idx].normalize(init)
            init = mx.sym.Activation(data=init, act_type="tanh",
                                     name="%senc2dec_inittanh_%d" % (self.prefix, state_idx))
            layer_states.append(init)
        return DecoderState(hidden, layer_states)

    def _step(self,
              word_vec_prev: mx.sym.Symbol,
              state: DecoderState,
              attention_func: Callable,
              attention_state: attentions.AttentionState,
              seq_idx: int = 0) -> Tuple[DecoderState, attentions.AttentionState]:
        """
        Performs a single time step in the RNN, given the previous word vector, the previous hidden state,
        the attention function, and the RNN layer states.

        :param word_vec_prev: Embedding of previous target word. Shape: (batch_size, num_target_embed).
        :param state: Decoder state consisting of hidden and layer states.
        :param attention_func: Attention function to produce context vector.
        :param attention_state: Previous attention state.
        :param seq_idx: Decoder time step.
        :return: (new decoder state, updated attention state).
        """
        # (1) RNN step
        # concat previous word embedding and previous hidden state
        rnn_input = mx.sym.concat(word_vec_prev, state.hidden, dim=1,
                                  name="%sconcat_target_context_t%d" % (self.prefix, seq_idx))
        # rnn_output: (batch_size, rnn_num_hidden)
        # next_layer_states: num_layers * [batch_size, rnn_num_hidden]
        rnn_output, layer_states = self.rnn(rnn_input, state.layer_states)

        # (2) Attention step
        attention_input = self.attention.make_input(seq_idx, word_vec_prev, rnn_output)
        attention_state = attention_func(attention_input, attention_state)

        # (3) Combine context with hidden state
        if self.context_gating:
            # context: (batch_size, encoder_num_hidden)
            # gate: (batch_size, rnn_num_hidden)
            gate = mx.sym.FullyConnected(data=mx.sym.concat(word_vec_prev, rnn_output,
                                                            attention_state.context, dim=1),
                                         num_hidden=self.num_hidden,
                                         weight=self.gate_w,
                                         bias=self.gate_b)
            gate = mx.sym.Activation(data=gate, act_type="sigmoid",
                                     name="%sgate_activation_t%d" % (self.prefix, seq_idx))

            # mapped_rnn_output: (batch_size, rnn_num_hidden)
            mapped_rnn_output = mx.sym.FullyConnected(data=rnn_output,
                                                      num_hidden=self.num_hidden,
                                                      weight=self.mapped_rnn_output_w,
                                                      bias=self.mapped_rnn_output_b,
                                                      name="%smapped_rnn_output_fc_t%d" % (self.prefix, seq_idx))
            # mapped_context: (batch_size, rnn_num_hidden)
            mapped_context = mx.sym.FullyConnected(data=attention_state.context,
                                                   num_hidden=self.num_hidden,
                                                   weight=self.mapped_context_w,
                                                   bias=self.mapped_context_b,
                                                   name="%smapped_context_fc_t%d" % (self.prefix, seq_idx))
            # hidden: (batch_size, rnn_num_hidden)
            hidden = mx.sym.Activation(data=gate * mapped_rnn_output + (1 - gate) * mapped_context,
                                       act_type="tanh",
                                       name="%snext_hidden_t%d" % (self.prefix, seq_idx))
        else:
            # hidden: (batch_size, rnn_num_hidden)
            hidden = mx.sym.FullyConnected(data=mx.sym.concat(rnn_output, attention_state.context, dim=1),
                                           # use the same number of hidden units as the RNN
                                           num_hidden=self.num_hidden,
                                           weight=self.hidden_w,
                                           bias=self.hidden_b)
            if self.layer_norm:
                hidden = self.hidden_norm.normalize(hidden)

            # hidden: (batch_size, rnn_num_hidden)
            hidden = mx.sym.Activation(data=hidden, act_type="tanh",
                                       name="%snext_hidden_t%d" % (self.prefix, seq_idx))

        return DecoderState(hidden, layer_states), attention_state

    def decode(self,
               source_encoded: mx.sym.Symbol,
               source_seq_len: int,
               source_length: mx.sym.Symbol,
               target: mx.sym.Symbol,
               target_seq_len: int,
               source_lexicon: Optional[mx.sym.Symbol] = None) -> mx.sym.Symbol:
        """
        Returns decoder logits with batch size and target sequence length collapsed into a single dimension.

        :param source_encoded: Concatenated encoder states. Shape: (source_seq_len, batch_size, encoder_num_hidden).
        :param source_seq_len: Maximum source sequence length.
        :param source_length: Lengths of source sequences. Shape: (batch_size,).
        :param target: Target sequence. Shape: (batch_size, target_seq_len).
        :param target_seq_len: Maximum target sequence length.
        :param source_lexicon: Lexical biases for current sentence.
               Shape: (batch_size, target_vocab_size, source_seq_len).
        :return: Logits of next-word predictions for target sequence.
                 Shape: (batch_size * target_seq_len, target_vocab_size).
        """
        # process encoder states
        source_encoded_batch_major = mx.sym.swapaxes(source_encoded, dim1=0, dim2=1,
                                                     name='source_encoded_batch_major')

        # embed and slice target words
        # target_embed: (batch_size, target_seq_len, num_target_embed)
        target_embed, _, _ = self.embedding.encode(target, None, target_seq_len)
        # target_embed: target_seq_len * (batch_size, num_target_embed)
        target_embed = mx.sym.split(data=target_embed, num_outputs=target_seq_len, axis=1, squeeze_axis=True)

        # get recurrent attention function conditioned on source
        attention_func = self.attention.on(source_encoded_batch_major, source_length, source_seq_len)
        attention_state = self.attention.get_initial_state(source_length, source_seq_len)

        # initialize decoder states
        # hidden: (batch_size, rnn_num_hidden)
        # layer_states: List[(batch_size, state_num_hidden)]
        state = self.compute_init_states(source_encoded, source_length)

        # hidden_all: target_seq_len * (batch_size, 1, rnn_num_hidden)
        hidden_all = []

        # TODO: possible alternative: feed back the context vector instead of the hidden (see lamtram)

        lexical_biases = []

        self.rnn.reset()
        # TODO remove this once mxnet.rnn.SequentialRNNCell.reset() invokes recursive calls on layer cells
        for cell in self.rnn._cells:
            cell.reset()

        for seq_idx in range(target_seq_len):
            # hidden: (batch_size, rnn_num_hidden)
            state, attention_state = self._step(target_embed[seq_idx],
                                                state,
                                                attention_func,
                                                attention_state,
                                                seq_idx)

            # hidden_expanded: (batch_size, 1, rnn_num_hidden)
            hidden_all.append(mx.sym.expand_dims(data=state.hidden, axis=1))

            if source_lexicon is not None:
                assert self.lexicon is not None, "source_lexicon should only be provided if a lexicon is available"
                lexical_biases.append(self.lexicon.calculate_lex_bias(source_lexicon, attention_state.probs))

        # concatenate along time axis
        # hidden_concat: (batch_size, target_seq_len, rnn_num_hidden)
        hidden_concat = mx.sym.concat(*hidden_all, dim=1, name="%shidden_concat" % self.prefix)
        # hidden_concat: (batch_size * target_seq_len, rnn_num_hidden)
        hidden_concat = mx.sym.reshape(data=hidden_concat, shape=(-1, self.num_hidden))

        # logits: (batch_size * target_seq_len, target_vocab_size)
        logits = mx.sym.FullyConnected(data=hidden_concat,
                                       num_hidden=self.target_vocab_size,
                                       weight=self.cls_w,
                                       bias=self.cls_b,
                                       name=C.LOGITS_NAME)

        if source_lexicon is not None:
            # lexical_biases_concat: (batch_size, target_seq_len, target_vocab_size)
            lexical_biases_concat = mx.sym.concat(*lexical_biases, dim=1, name='lex_bias_concat')
            # lexical_biases_concat: (batch_size * target_seq_len, target_vocab_size)
            lexical_biases_concat = mx.sym.reshape(data=lexical_biases_concat,
                                                   shape=(-1, self.target_vocab_size))
            logits = mx.sym.broadcast_add(lhs=logits, rhs=lexical_biases_concat,
                                          name='%s_plus_lex_bias' % C.LOGITS_NAME)

        return logits

    def predict(self,
                word_id_prev: mx.sym.Symbol,
                state_prev: DecoderState,
                attention_func: Callable,
                attention_state_prev: attentions.AttentionState,
                source_lexicon: Optional[mx.sym.Symbol] = None,
                softmax_temperature: Optional[float] = None) \
            -> Tuple[mx.sym.Symbol, DecoderState, attentions.AttentionState]:
        """
        Given the previous word id, the attention function, the previous hidden state, and the RNN layer states,
        returns softmax predictions (not a loss symbol), the next hidden state, and the next layer states.
        Used for inference.

        :param word_id_prev: Previous target word id. Shape: (1,).
        :param state_prev: Previous decoder state consisting of hidden and layer states.
        :param attention_func: Attention function to produce context vector.
        :param attention_state_prev: Previous attention state.
        :param source_lexicon: Lexical biases for current sentence.
               Shape: (batch_size, target_vocab_size, source_seq_len).
        :param softmax_temperature: Optional parameter to control steepness of softmax distribution.
        :return: (predicted next-word distribution, decoder state, attention state).
        """
        # target side embedding
        word_vec_prev, _, _ = self.embedding.encode(word_id_prev, None, 1)

        # state.hidden: (batch_size, rnn_num_hidden)
        # attention_state.dynamic_source: (batch_size, source_seq_len, coverage_num_hidden)
        # attention_state.probs: (batch_size, source_seq_len)
        state, attention_state = self._step(word_vec_prev,
                                            state_prev,
                                            attention_func,
                                            attention_state_prev)

        # logits: (batch_size, target_vocab_size)
        logits = mx.sym.FullyConnected(data=state.hidden,
                                       num_hidden=self.target_vocab_size,
                                       weight=self.cls_w,
                                       bias=self.cls_b,
                                       name=C.LOGITS_NAME)

        if source_lexicon is not None:
            assert self.lexicon is not None
            # lex_bias: (batch_size, 1, target_vocab_size)
            lex_bias = self.lexicon.calculate_lex_bias(source_lexicon, attention_state.probs)
            # lex_bias: (batch_size, target_vocab_size)
            lex_bias = mx.sym.reshape(data=lex_bias, shape=(-1, self.target_vocab_size))
            logits = mx.sym.broadcast_add(lhs=logits, rhs=lex_bias,
                                          name='%s_plus_lex_bias' % C.LOGITS_NAME)

        if softmax_temperature is not None:
            logits /= softmax_temperature

        softmax_out = mx.sym.softmax(data=logits, name=C.SOFTMAX_NAME)
        return softmax_out, state, attention_state
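# Illustrative sketch (not part of the Sockeye code above): what the context-gating branch of
# RecurrentDecoder._step computes, in NumPy. A sigmoid gate computed from the previous word
# embedding, the RNN output, and the attention context interpolates between the mapped RNN
# output and the mapped context. All weight names are hypothetical.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def context_gated_hidden(word_vec_prev, rnn_output, context,
                         W_gate, b_gate, W_rnn, b_rnn, W_ctx, b_ctx):
    gate_input = np.concatenate([word_vec_prev, rnn_output, context], axis=1)
    gate = sigmoid(gate_input @ W_gate + b_gate)            # (batch, rnn_num_hidden)
    mapped_rnn_output = rnn_output @ W_rnn + b_rnn          # (batch, rnn_num_hidden)
    mapped_context = context @ W_ctx + b_ctx                # (batch, rnn_num_hidden)
    return np.tanh(gate * mapped_rnn_output + (1 - gate) * mapped_context)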
class MlpAttention(Attention):
    """
    Attention computed through a one-layer MLP with num_hidden units [Luong et al, 2015].

    :math:`score(h_t, h_s) = \\mathbf{W}_a tanh(\\mathbf{W}_c [h_t, h_s] + b)`

    :math:`a = softmax(score(*, h_s))`

    Optionally, if attention_coverage_type is not None, attention uses dynamic source encoding ('coverage' mechanism)
    as in Tu et al. (2016): Modeling Coverage for Neural Machine Translation.

    :math:`score(h_t, h_s) = \\mathbf{W}_a tanh(\\mathbf{W}_c [h_t, h_s, c_s] + b)`

    :math:`c_s` is the decoder time-step dependent source encoding which is updated using the current decoder state.

    :param input_previous_word: Feed the previous target embedding into the attention mechanism.
    :param attention_num_hidden: Number of hidden units.
    :param attention_coverage_type: The type of update for the dynamic source encoding.
           If None, no dynamic source encoding is done.
    :param attention_coverage_num_hidden: Number of hidden units for coverage attention.
    :param prefix: Layer name prefix.
    :param layer_normalization: If true, normalizes hidden layer outputs before tanh activation.
    """

    def __init__(self,
                 input_previous_word: bool,
                 attention_num_hidden: int,
                 attention_coverage_type: Optional[str] = None,
                 attention_coverage_num_hidden: int = 1,
                 prefix='',
                 layer_normalization: bool = False) -> None:
        dynamic_source_num_hidden = 1 if attention_coverage_type is None else attention_coverage_num_hidden
        super().__init__(input_previous_word=input_previous_word,
                         dynamic_source_num_hidden=dynamic_source_num_hidden)
        self.prefix = prefix
        self.attention_num_hidden = attention_num_hidden
        # input (encoder) to hidden
        self.att_e2h_weight = mx.sym.Variable("%satt_e2h_weight" % prefix)
        # input (query) to hidden
        self.att_q2h_weight = mx.sym.Variable("%satt_q2h_weight" % prefix)
        # hidden to score
        self.att_h2s_weight = mx.sym.Variable("%satt_h2s_weight" % prefix)
        # dynamic source (coverage) weights and settings
        # input (coverage) to hidden
        self.att_c2h_weight = mx.sym.Variable("%satt_c2h_weight" % prefix) if attention_coverage_type else None
        self.coverage = sockeye.coverage.get_coverage(attention_coverage_type,
                                                      dynamic_source_num_hidden,
                                                      layer_normalization) if attention_coverage_type else None
        if layer_normalization:
            self._ln = LayerNormalization(num_hidden=attention_num_hidden, prefix="att_norm")
        else:
            self._ln = None

    def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable:
        """
        Returns callable to be used for recurrent attention in a sequence decoder.
        The callable is a recurrent function of the form:
        AttentionState = attend(AttentionInput, AttentionState).

        :param source: Shape: (batch_size, seq_len, encoder_num_hidden).
        :param source_length: Shape: (batch_size,).
        :param source_seq_len: Maximum length of source sequences.
        :return: Attention callable.
        """
        coverage_func = self.coverage.on(source, source_length, source_seq_len) if self.coverage else None

        # (batch_size * seq_len, attention_num_hidden)
        source_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(data=source,
                                                                  shape=(-3, -1),
                                                                  name="%satt_flat_source" % self.prefix),
                                              weight=self.att_e2h_weight,
                                              num_hidden=self.attention_num_hidden,
                                              no_bias=True,
                                              name="%satt_source_hidden_fc" % self.prefix)
        # (batch_size, seq_len, attention_num_hidden)
        source_hidden = mx.sym.reshape(source_hidden,
                                       shape=(-1, source_seq_len, self.attention_num_hidden),
                                       name="%satt_source_hidden" % self.prefix)

        def attend(att_input: AttentionInput, att_state: AttentionState) -> AttentionState:
            """
            Returns updated attention state given attention input and current attention state.

            :param att_input: Attention input as returned by make_input().
            :param att_state: Current attention state.
            :return: Updated attention state.
            """
            # (batch_size, attention_num_hidden)
            query_hidden = mx.sym.FullyConnected(data=att_input.query,
                                                 weight=self.att_q2h_weight,
                                                 num_hidden=self.attention_num_hidden,
                                                 no_bias=True,
                                                 name="%satt_query_hidden" % self.prefix)

            # (batch_size, 1, attention_num_hidden)
            query_hidden = mx.sym.expand_dims(data=query_hidden,
                                              axis=1,
                                              name="%satt_query_hidden_expanded" % self.prefix)

            attention_hidden_lhs = source_hidden
            if self.coverage:
                # (batch_size * seq_len, attention_num_hidden)
                dynamic_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(data=att_state.dynamic_source,
                                                                           shape=(-3, -1),
                                                                           name="%satt_flat_dynamic_source" % self.prefix),
                                                       weight=self.att_c2h_weight,
                                                       num_hidden=self.attention_num_hidden,
                                                       no_bias=True,
                                                       name="%satt_dynamic_source_hidden_fc" % self.prefix)

                # (batch_size, seq_len, attention_num_hidden)
                dynamic_hidden = mx.sym.reshape(dynamic_hidden,
                                                shape=(-1, source_seq_len, self.attention_num_hidden),
                                                name="%satt_dynamic_source_hidden" % self.prefix)

                # (batch_size, seq_len, attention_num_hidden)
                attention_hidden_lhs = dynamic_hidden + source_hidden

            # (batch_size, seq_len, attention_num_hidden)
            attention_hidden = mx.sym.broadcast_add(lhs=attention_hidden_lhs, rhs=query_hidden,
                                                    name="%satt_query_plus_input" % self.prefix)

            # (batch_size * seq_len, attention_num_hidden)
            attention_hidden = mx.sym.reshape(data=attention_hidden,
                                              shape=(-3, -1),
                                              name="%satt_query_plus_input_before_fc" % self.prefix)

            if self._ln is not None:
                attention_hidden = self._ln.normalize(attention_hidden)

            # (batch_size * seq_len, attention_num_hidden)
            attention_hidden = mx.sym.Activation(attention_hidden, act_type="tanh",
                                                 name="%satt_hidden" % self.prefix)

            # (batch_size * seq_len, 1)
            attention_scores = mx.sym.FullyConnected(data=attention_hidden,
                                                     weight=self.att_h2s_weight,
                                                     num_hidden=1,
                                                     no_bias=True,
                                                     name="%sraw_att_score_fc" % self.prefix)

            # (batch_size, seq_len, 1)
            attention_scores = mx.sym.reshape(attention_scores,
                                              shape=(-1, source_seq_len, 1),
                                              name="%sraw_att_score_fc" % self.prefix)

            context, attention_probs = get_context_and_attention_probs(source, source_length, attention_scores)

            dynamic_source = att_state.dynamic_source
            if self.coverage:
                # update dynamic source encoding
                # Note: this is a slight change to the Tu et al, 2016 paper: input to the coverage update
                # is the attention input query, not the previous decoder state.
                dynamic_source = coverage_func(prev_hidden=att_input.query,
                                               attention_prob_scores=attention_probs,
                                               prev_coverage=att_state.dynamic_source)

            return AttentionState(context=context,
                                  probs=attention_probs,
                                  dynamic_source=dynamic_source)

        return attend
class ActivationCoverage(Coverage):
    """
    Implements a coverage mechanism whose updates are performed by a perceptron with a configurable
    activation function.

    :param coverage_num_hidden: Number of hidden units for coverage vectors.
    :param activation: Type of activation for perceptron.
    :param layer_normalization: If true, applies layer normalization before non-linear activation.
    :param prefix: Layer name prefix.
    """

    def __init__(self,
                 coverage_num_hidden: int,
                 activation: str,
                 layer_normalization: bool,
                 prefix="cov_") -> None:
        self.prefix = prefix
        self.activation = activation
        self.num_hidden = coverage_num_hidden
        # input (encoder) to hidden
        self.cov_e2h_weight = mx.sym.Variable("%se2h_weight" % self.prefix)
        # decoder to hidden
        self.cov_dec2h_weight = mx.sym.Variable("%si2h_weight" % self.prefix)
        # previous coverage to hidden
        self.cov_prev2h_weight = mx.sym.Variable("%sprev2h_weight" % self.prefix)
        # attention scores to hidden
        self.cov_a2h_weight = mx.sym.Variable("%sa2h_weight" % self.prefix)
        # optional layer normalization (only meaningful for more than one hidden unit)
        self.layer_norm = None
        if layer_normalization and self.num_hidden > 1:
            self.layer_norm = LayerNormalization(self.num_hidden, prefix="%snorm" % prefix)

    def on(self, source: mx.sym.Symbol, source_length: mx.sym.Symbol, source_seq_len: int) -> Callable:
        """
        Returns callable to be used for updating coverage vectors in a sequence decoder.

        :param source: Shape: (batch_size, seq_len, encoder_num_hidden).
        :param source_length: Shape: (batch_size,).
        :param source_seq_len: Maximum length of source sequences.
        :return: Coverage callable.
        """

        # (batch_size * seq_len, coverage_hidden_num)
        source_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(data=source,
                                                                  shape=(-3, -1),
                                                                  name="%sflat_source" % self.prefix),
                                              weight=self.cov_e2h_weight,
                                              no_bias=True,
                                              num_hidden=self.num_hidden,
                                              name="%ssource_hidden_fc" % self.prefix)

        # (batch_size, seq_len, coverage_hidden_num)
        source_hidden = mx.sym.reshape(source_hidden,
                                       shape=(-1, source_seq_len, self.num_hidden),
                                       name="%ssource_hidden" % self.prefix)

        def update_coverage(prev_hidden: mx.sym.Symbol,
                            attention_prob_scores: mx.sym.Symbol,
                            prev_coverage: mx.sym.Symbol):
            """
            :param prev_hidden: Previous hidden decoder state. Shape: (batch_size, decoder_num_hidden).
            :param attention_prob_scores: Current attention scores. Shape: (batch_size, source_seq_len).
            :param prev_coverage: Shape: (batch_size, source_seq_len, coverage_num_hidden).
            :return: Updated coverage matrix. Shape: (batch_size, source_seq_len, coverage_num_hidden).
            """

            # (batch_size * seq_len, coverage_hidden_num)
            coverage_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(data=prev_coverage,
                                                                        shape=(-3, -1),
                                                                        name="%sflat_previous" % self.prefix),
                                                    weight=self.cov_prev2h_weight,
                                                    no_bias=True,
                                                    num_hidden=self.num_hidden,
                                                    name="%sprevious_hidden_fc" % self.prefix)

            # (batch_size, source_seq_len, coverage_hidden_num)
            coverage_hidden = mx.sym.reshape(coverage_hidden,
                                             shape=(-1, source_seq_len, self.num_hidden),
                                             name="%sprevious_hidden" % self.prefix)

            # (batch_size, source_seq_len, 1)
            attention_prob_score = mx.sym.expand_dims(attention_prob_scores, axis=2)

            # (batch_size * source_seq_len, coverage_num_hidden)
            attention_hidden = mx.sym.FullyConnected(data=mx.sym.reshape(attention_prob_score,
                                                                         shape=(-3, 0),
                                                                         name="%sreshape_att_probs" % self.prefix),
                                                     weight=self.cov_a2h_weight,
                                                     no_bias=True,
                                                     num_hidden=self.num_hidden,
                                                     name="%sattention_fc" % self.prefix)

            # (batch_size, source_seq_len, coverage_num_hidden)
            attention_hidden = mx.sym.reshape(attention_hidden,
                                              shape=(-1, source_seq_len, self.num_hidden),
                                              name="%sreshape_att" % self.prefix)

            # (batch_size, coverage_num_hidden)
            prev_hidden = mx.sym.FullyConnected(data=prev_hidden,
                                                weight=self.cov_dec2h_weight,
                                                no_bias=True,
                                                num_hidden=self.num_hidden,
                                                name="%sdecoder_hidden" % self.prefix)

            # (batch_size, 1, coverage_num_hidden)
            prev_hidden = mx.sym.expand_dims(data=prev_hidden, axis=1,
                                             name="%sinput_decoder_hidden_expanded" % self.prefix)

            # (batch_size, source_seq_len, coverage_num_hidden)
            intermediate = mx.sym.broadcast_add(lhs=source_hidden, rhs=prev_hidden,
                                                name="%ssource_plus_hidden" % self.prefix)

            # (batch_size, source_seq_len, coverage_num_hidden)
            updated_coverage = intermediate + attention_hidden + coverage_hidden

            if self.layer_norm is not None:
                updated_coverage = self.layer_norm.normalize(updated_coverage)

            # (batch_size, seq_len, coverage_num_hidden)
            coverage = mx.sym.Activation(data=updated_coverage,
                                         act_type=self.activation,
                                         name="%sactivation" % self.prefix)

            return mask_coverage(coverage, source_length)

        return update_coverage
def __init__(self,
             num_hidden: int,
             attention: sockeye.attention.Attention,
             target_vocab_size: int,
             num_target_embed: int,
             num_layers=1,
             prefix=C.DECODER_PREFIX,
             weight_tying=False,
             dropout=0.0,
             cell_type: str = C.LSTM_TYPE,
             residual: bool = False,
             forget_bias: float = 0.0,
             lexicon: Optional[sockeye.lexicon.Lexicon] = None,
             context_gating: bool = False,
             layer_normalization: bool = False) -> None:
    # TODO: implement variant without input feeding
    self.num_layers = num_layers
    self.prefix = prefix
    self.dropout = dropout
    self.num_hidden = num_hidden
    self.attention = attention
    self.target_vocab_size = target_vocab_size
    self.num_target_embed = num_target_embed
    self.context_gating = context_gating
    if self.context_gating:
        self.gate_w = mx.sym.Variable("%sgate_weight" % prefix)
        self.gate_b = mx.sym.Variable("%sgate_bias" % prefix)
        self.mapped_rnn_output_w = mx.sym.Variable("%smapped_rnn_output_weight" % prefix)
        self.mapped_rnn_output_b = mx.sym.Variable("%smapped_rnn_output_bias" % prefix)
        self.mapped_context_w = mx.sym.Variable("%smapped_context_weight" % prefix)
        self.mapped_context_b = mx.sym.Variable("%smapped_context_bias" % prefix)
    self.layer_norm = layer_normalization

    # Decoder stacked RNN
    self.rnn = sockeye.rnn.get_stacked_rnn(cell_type, num_hidden, num_layers, dropout,
                                           prefix, residual, forget_bias)

    # Decoder parameters
    # RNN init state parameters
    self._create_layer_parameters()

    # Hidden state parameters
    self.hidden_w = mx.sym.Variable("%shidden_weight" % prefix)
    self.hidden_b = mx.sym.Variable("%shidden_bias" % prefix)
    self.hidden_norm = LayerNormalization(self.num_hidden,
                                          prefix="%shidden_norm" % prefix) if self.layer_norm else None

    # Embedding & output parameters
    self.embedding = sockeye.encoder.Embedding(self.num_target_embed,
                                               self.target_vocab_size,
                                               prefix=C.TARGET_EMBEDDING_PREFIX,
                                               dropout=0.)  # TODO dropout?
    if weight_tying:
        check_condition(self.num_hidden == self.num_target_embed,
                        "Weight tying requires target embedding size and rnn_num_hidden to be equal")
        self.cls_w = self.embedding.embed_weight
    else:
        self.cls_w = mx.sym.Variable("%scls_weight" % prefix)
    self.cls_b = mx.sym.Variable("%scls_bias" % prefix)
    self.lexicon = lexicon