Example #1
0
def residual_mlp_layer(x_flat,
                       intermediate_size,
                       initializer_range=0.02,
                       hidden_dropout_prob=0.1):
    """Pre-layer-norm MLP block with a residual connection.

    Computes
        layer_norm(x_flat + dropout(dense_out(gelu(dense_in(layer_norm(x_flat))))))
    where the two layer norms are named 'mlp_ln0' and 'mlp_ln1'.

    :param x_flat: The attention output. It should be
        [batch_size*seq_length, hidden_size]; rank 2 is enforced by
        get_shape_list.
    :param intermediate_size: width of the hidden projection. This is
        typically hidden_size * 4.
    :param initializer_range: passed to create_initializer for both dense
        kernels.
    :param hidden_dropout_prob: dropout probability applied to the output
        projection before the residual add.
    :return: the layer-normalized sum of `x_flat` and the MLP output, whose
        last dimension is hidden_size.

    Note: in the original GPT we would return layer_norm(x_norm + h1) rather
    than layer_norm(x + h1). Here the residual is added to the raw input
    `x_flat`, not to the normalized one.
    """
    # Only the feature width is needed; the flattened batch*seq size is not.
    _, hidden_size = get_shape_list(x_flat, expected_rank=2)
    x_norm = layer_norm(x_flat, name='mlp_ln0')

    intermediate_output = tf.layers.dense(
        x_norm,
        intermediate_size,
        activation=gelu,
        kernel_initializer=create_initializer(initializer_range),
        name='intermediate',
    )

    # Project back down to the input width so the residual add lines up.
    output_for_residual = tf.layers.dense(
        intermediate_output,
        hidden_size,
        name='output',
        kernel_initializer=create_initializer(initializer_range))
    output_for_residual = dropout(output_for_residual, hidden_dropout_prob)

    layer_output = layer_norm(x_flat + output_for_residual, name='mlp_ln1')
    return layer_output
Example #2
0
    def _step(self, x_h, x_z, x_r, x_m, h_tm1):
        '''One GRU time step over precomputed input projections.

        x_h: input projection for the candidate state at time t
        x_z: input projection for the update gate at time t
        x_r: input projection for the reset gate at time t
        x_m: mask of x_t; where it is 0 the previous state is carried over
        h_tm1: previous state
        Returns the masked new state h_t.
        '''
        use_ln = self.with_layernorm

        # update gate
        update_pre = x_z + T.dot(h_tm1, self.W_hz) + self.b_z
        if use_ln:
            update_pre = layer_norm(update_pre, self.W_z_lnb, self.W_z_lns)
        update = T.nnet.sigmoid(update_pre)

        # reset gate
        reset_pre = x_r + T.dot(h_tm1, self.W_hr) + self.b_r
        if use_ln:
            reset_pre = layer_norm(reset_pre, self.W_r_lnb, self.W_r_lns)
        reset = T.nnet.sigmoid(reset_pre)

        # candidate state; the reset gate scales the recurrent contribution
        cand_pre = x_h + reset * T.dot(h_tm1, self.W_hh) + self.b_h
        if use_ln:
            cand_pre = layer_norm(cand_pre, self.W_h_lnb, self.W_h_lns)
        candidate = T.tanh(cand_pre)

        blended = (1. - update) * h_tm1 + update * candidate

        keep = x_m[:, None]
        return keep * blended + (1. - keep) * h_tm1
Example #3
0
    def _update_coverage(self, cov_tm1, probs, c, h_tm1, fertility=None):
        '''Update the source-side coverage after attending at time t.

        cov_tm1:    coverage at time (t-1)
        probs:      attention probabilities at time t
        c:          source annotations
        h_tm1:      previous decoder state
        fertility:  fertility of individual source word; required when
                    self.coverage_type == 'linguistic'

        The 'linguistic' coverage type accumulates probs / fertility onto
        cov_tm1. Any other coverage_type uses a gated (GRU-style) update
        driven by cov_tm1, probs, c and h_tm1, optionally with layer
        normalization. Returns the new coverage.
        '''
        # Use `==`, not `is`. String identity is an interning detail, so an
        # equal coverage_type built at runtime (e.g. parsed from a config)
        # would fail an identity test.
        if self.coverage_type == 'linguistic':
            # Test for None explicitly instead of relying on the truthiness
            # of a symbolic variable.
            assert fertility is not None, 'fertility should be given for linguistic coverage'
            fertility_probs = probs/fertility
            # Add a trailing (coverage) axis and mark it non-broadcastable.
            cov = T.unbroadcast(fertility_probs.dimshuffle(0,1,'x'), 2)

            # accumulation
            cov = cov_tm1 + cov
        else:
            # TODO: the T.dot(c, W_cov_c*) terms do not depend on the time
            # step and could be precomputed in advance to save computation.
            extend_probs = probs.dimshuffle(0,1,'x')

            if self.with_layernorm:
                z = layer_norm((T.dot(cov_tm1, self.W_cov_z) +
                                T.dot(extend_probs, self.W_cov_pz) +
                                T.dot(c, self.W_cov_cz) +
                                T.dot(h_tm1, self.W_cov_hz) + self.b_cov_z),
                               self.W_cov_z_lnb, self.W_cov_z_lns)
                z = T.nnet.sigmoid(z)
                r = layer_norm((T.dot(cov_tm1, self.W_cov_r) +
                                T.dot(extend_probs, self.W_cov_pr) +
                                T.dot(c, self.W_cov_cr) +
                                T.dot(h_tm1, self.W_cov_hr) + self.b_cov_r),
                               self.W_cov_r_lnb, self.W_cov_r_lns)
                r = T.nnet.sigmoid(r)
                cov = layer_norm((r * T.dot(cov_tm1, self.W_cov_h) +
                                  T.dot(extend_probs, self.W_cov_ph) +
                                  T.dot(c, self.W_cov_ch) +
                                  T.dot(h_tm1, self.W_cov_hh) + self.b_cov_h),
                                 self.W_cov_h_lnb, self.W_cov_h_lns)
                cov = T.tanh(cov)
            else:
                z = T.nnet.sigmoid(T.dot(cov_tm1, self.W_cov_z) +
                                   T.dot(extend_probs, self.W_cov_pz) +
                                   T.dot(c, self.W_cov_cz) +
                                   T.dot(h_tm1, self.W_cov_hz) + self.b_cov_z)
                r = T.nnet.sigmoid(T.dot(cov_tm1, self.W_cov_r) +
                                   T.dot(extend_probs, self.W_cov_pr) +
                                   T.dot(c, self.W_cov_cr) +
                                   T.dot(h_tm1, self.W_cov_hr) + self.b_cov_r)
                cov = T.tanh(r * T.dot(cov_tm1, self.W_cov_h) +
                             T.dot(extend_probs, self.W_cov_ph) +
                             T.dot(c, self.W_cov_ch) +
                             T.dot(h_tm1, self.W_cov_hh) + self.b_cov_h)

            # GRU-style interpolation between old coverage and the candidate
            cov = (1-z) * cov_tm1 + z * cov

        return cov
Example #4
0
def multihead_attn(queries,
                   keys,
                   num_units,
                   num_heads,
                   dropout_rate,
                   is_training,
                   restricted=False):
    """Multi-head scaled dot-product attention with residual + layer norm.

    Args:
      queries: A 3d tensor with shape of [N, T_q, C_q]
      keys: A 3d tensor with shape of [N, T_k, C_k]. It is used for both the
        keys and the values.
      num_units: projection width C. If None, C_q (the static last dimension
        of `queries`) is used.
      num_heads: number of heads; C is split evenly across them (tf.split).
      dropout_rate: dropout rate applied to the attention weights.
      is_training: whether dropout is active.
      restricted: if True, apply a causal (lower-triangular) mask so query
        position i only attends to key positions j <= i.

    Returns:
      A [N, T_q, C] tensor, layer_norm(attention_output + queries). The
      residual add requires C == C_q.
    """
    if num_units is None:
        num_units = queries.get_shape().as_list()[-1]

    Q = tf.layers.dense(queries, num_units)  # (N, T_q, C)
    K = tf.layers.dense(keys, num_units)  # (N, T_k, C)
    V = tf.layers.dense(keys, num_units)  # (N, T_k, C)

    # Fold the heads into the batch dimension.
    Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, C/h)
    K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)
    V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)

    align = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)
    align = align / (K_.get_shape().as_list()[-1]**0.5)  # scale

    if restricted:
        # Replace scores for key positions j > i with -inf so the softmax
        # assigns them zero weight.
        paddings = tf.fill(tf.shape(align), float('-inf'))
        # band_part(x, -1, 0) keeps the lower triangle including the
        # diagonal. Deriving the mask from `align` also works when the
        # static time dimensions are unknown or T_q != T_k.
        lower_tri = tf.linalg.band_part(tf.ones_like(align[0]), -1,
                                        0)  # (T_q, T_k)
        masks = tf.tile(tf.expand_dims(lower_tri, 0),
                        [tf.shape(align)[0], 1, 1])  # (h*N, T_q, T_k)
        align = tf.where(tf.equal(masks, 0), paddings,
                         align)  # (h*N, T_q, T_k)

    align = tf.nn.softmax(align)  # (h*N, T_q, T_k)

    align = tf.layers.dropout(align, dropout_rate,
                              training=is_training)  # (h*N, T_q, T_k)

    # Weighted sum
    outputs = tf.matmul(align, V_)  # (h*N, T_q, C/h)
    # Restore shape
    outputs = tf.concat(tf.split(outputs, num_heads, axis=0),
                        axis=2)  # (N, T_q, C)
    # Residual connection
    outputs += queries  # (N, T_q, C)
    # Normalize
    outputs = layer_norm(outputs)  # (N, T_q, C)
    return outputs
Example #5
0
    def _step_context(self, x_t, x_m, h_tm1, cz, cr, ch):
        '''Single GRU step whose gates also see a fixed context.

        x_t:   input at time t (projected here by W_xz / W_xr / W_xh)
        x_m:   mask of x_t; where it is 0 the previous state is carried over
        h_tm1: previous state
        cz, cr, ch: context terms for the update gate, reset gate and
               candidate state; each is multiplied by W_cz / W_cr / W_ch here
        Returns the masked new state h_t.
        '''
        # NOTE(review): the GRU apply() in this file passes
        # cz = T.dot(context, self.W_cz) (likewise cr, ch), and this step
        # multiplies by the same W_c* again. Confirm the double projection
        # is intended.
        ln = self.with_layernorm

        update_pre = (T.dot(x_t, self.W_xz) +
                      T.dot(h_tm1, self.W_hz) +
                      T.dot(cz, self.W_cz) + self.b_z)
        if ln:
            update_pre = layer_norm(update_pre, self.W_z_lnb, self.W_z_lns)
        update = T.nnet.sigmoid(update_pre)

        reset_pre = (T.dot(x_t, self.W_xr) +
                     T.dot(h_tm1, self.W_hr) +
                     T.dot(cr, self.W_cr) + self.b_r)
        if ln:
            reset_pre = layer_norm(reset_pre, self.W_r_lnb, self.W_r_lns)
        reset = T.nnet.sigmoid(reset_pre)

        cand_pre = (T.dot(x_t, self.W_xh) +
                    reset * T.dot(h_tm1, self.W_hh) +
                    T.dot(ch, self.W_ch) + self.b_h)
        if ln:
            cand_pre = layer_norm(cand_pre, self.W_h_lnb, self.W_h_lns)
        candidate = T.tanh(cand_pre)

        new_state = (1 - update) * h_tm1 + update * candidate

        step_mask = x_m[:, None]
        return step_mask * new_state + (1. - step_mask) * h_tm1
Example #6
0
    def apply(self, state_below, mask_below=None, init_state=None, context=None):
        '''Run the GRU over a sequence with theano.scan.

        state_below: n_steps * batch_size * dim, or n_steps * dim (treated as
                     batch_size 1 and reshaped to 3-D)
        mask_below:  per-step mask; defaults to ones of shape (n_steps, 1)
        init_state:  initial state; defaults to tanh(context . W_c_init) when
                     a context is used, otherwise zeros (batch_size, n_hids)
        context:     fixed context; required when self.with_context
        Returns the scan output (the state of every step), which is also
        stored in self.output.
        '''
        n_steps = state_below.shape[0]
        if state_below.ndim == 3:
            batch_size = state_below.shape[1]
        else:
            batch_size = 1
            state_below = state_below.reshape((n_steps, batch_size, state_below.shape[1]))

        if mask_below is None:
            mask_below = T.alloc(numpy.float32(1.), n_steps, 1)

        if self.with_context:
            # Explicit None check. The truthiness of a symbolic variable is
            # not a reliable "was it given" test.
            assert context is not None, 'context must be provided when with_context is set'

            if init_state is None:
                init_state = T.tanh(T.dot(context, self.W_c_init))

            # Context projections are time-invariant, so they are computed
            # once and passed to scan as non-sequences.
            c_z = T.dot(context, self.W_cz)
            c_r = T.dot(context, self.W_cr)
            c_h = T.dot(context, self.W_ch)
            if self.with_layernorm:
                c_h = layer_norm(c_h, self.W_ch_lnb, self.W_ch_lns)
                c_z = layer_norm(c_z, self.W_cz_lnb, self.W_cz_lns)
                c_r = layer_norm(c_r, self.W_cr_lnb, self.W_cr_lns)

            non_sequences = [c_z, c_r, c_h]

            rval, updates = theano.scan(self._step_context,
                                        sequences=[state_below, mask_below],
                                        non_sequences=non_sequences,
                                        outputs_info=[init_state],
                                        name=_p(self.pname, 'layers'),
                                        n_steps=n_steps)
        else:
            if init_state is None:
                init_state = T.alloc(numpy.float32(0.), batch_size, self.n_hids)

            # Input projections for all steps at once, outside the recurrence.
            state_below_xh = T.dot(state_below, self.W_xh)
            state_below_xz = T.dot(state_below, self.W_xz)
            state_below_xr = T.dot(state_below, self.W_xr)

            if self.with_layernorm:
                state_below_xh = layer_norm(state_below_xh, self.W_xh_lnb, self.W_xh_lns)
                state_below_xz = layer_norm(state_below_xz, self.W_xz_lnb, self.W_xz_lns)
                state_below_xr = layer_norm(state_below_xr, self.W_xr_lnb, self.W_xr_lns)

            sequences = [state_below_xh, state_below_xz, state_below_xr, mask_below]

            rval, updates = theano.scan(self._step,
                                        sequences=sequences,
                                        outputs_info=[init_state],
                                        name=_p(self.pname, 'layers'),
                                        n_steps=n_steps)
        self.output = rval

        return self.output
def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
                         label_ids, label_weights, prev_bplayers=None):
    """Get loss and log probs for the masked LM.

    Args:
      bert_config: config providing hidden_size, hidden_act,
        initializer_range and vocab_size.
      input_tensor: encoder output (presumably the sequence output), gathered
        at `positions`.
      output_weights: the embedding table reused as the output projection
        (multiplied with transpose_b=True, so its rows index the vocabulary).
      positions: indices of the predicted (masked) tokens.
      label_ids: target token ids; flattened to 1-D.
      label_weights: 1.0 for real predictions and 0.0 for padding; flattened.
      prev_bplayers: forwarded to BPLayer as `backward_layers`.

    Returns:
      (loss, per_example_loss, log_probs, loss_bplayer)
    """
    gathered = gather_indexes(input_tensor, positions)

    with tf.variable_scope("cls/predictions") as prediction_scope:
        # One extra non-linear transform before the output layer; this
        # matrix is not used after pre-training.
        with tf.variable_scope("transform"):
            act_fn = utils.get_activation(bert_config.hidden_act)
            kernel_init = utils.create_initializer(
                bert_config.initializer_range)
            hidden = tf.layers.dense(gathered,
                                     units=bert_config.hidden_size,
                                     activation=act_fn,
                                     kernel_initializer=kernel_init)
            hidden = utils.layer_norm(hidden)

        # Output weights are shared with the input embeddings; only the
        # per-token bias belongs to the output layer.
        output_bias = tf.get_variable(
            "output_bias",
            shape=[bert_config.vocab_size],
            initializer=tf.zeros_initializer())
        logits = tf.nn.bias_add(
            tf.matmul(hidden, output_weights, transpose_b=True), output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        flat_label_ids = tf.reshape(label_ids, [-1])
        flat_label_weights = tf.reshape(label_weights, [-1])
        one_hot_labels = tf.one_hot(flat_label_ids,
                                    depth=bert_config.vocab_size,
                                    dtype=tf.float32)

        # `positions` may be zero-padded; padded slots have weight 0.0 and
        # so drop out of the weighted mean. The 1e-5 guards against a
        # zero denominator.
        per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels,
                                          axis=[-1])
        weighted_sum = tf.reduce_sum(flat_label_weights * per_example_loss)
        total_weight = tf.reduce_sum(flat_label_weights) + 1e-5
        loss = weighted_sum / total_weight
        loss_bplayer = BPLayer(loss, prediction_scope,
                               backward_layers=prev_bplayers)

    return loss, per_example_loss, log_probs, loss_bplayer
Example #8
0
def attention_func(input_tensor, attention_mask, hidden_size,
                   hidden_dropout_prob, num_attention_heads,
                   attention_head_size, attention_probs_dropout_prob,
                   initializer_range, batch_size, seq_length):
    """Self-attention sub-layer.

    The sub-layer applies self-attention, projects the result to
    `hidden_size`, applies dropout, adds the residual `input_tensor` and
    layer-normalizes.

    Returns:
      (attention_output, scope), where `scope` is the "attention" variable
      scope the layer was built in.
    """
    with tf.variable_scope("attention") as scope:
        with tf.variable_scope("self"):
            self_attended = attention_layer(
                from_tensor=input_tensor,
                to_tensor=input_tensor,
                attention_mask=attention_mask,
                num_attention_heads=num_attention_heads,
                size_per_head=attention_head_size,
                attention_probs_dropout_prob=attention_probs_dropout_prob,
                initializer_range=initializer_range,
                do_return_2d_tensor=True,
                batch_size=batch_size,
                from_seq_length=seq_length,
                to_seq_length=seq_length)
        # Only the self-attention head exists here. Heads from other
        # sequences would be concatenated to it before the projection.

        # Linear projection to `hidden_size`, then the residual with the
        # layer input.
        with tf.variable_scope("output"):
            projected = tf.layers.dense(
                self_attended,
                hidden_size,
                kernel_initializer=utils.create_initializer(initializer_range))
            projected = utils.dropout(projected, hidden_dropout_prob)
            normalized = utils.layer_norm(projected + input_tensor)
    return normalized, scope
Example #9
0
def multihead_attn(queries, keys, num_units, num_heads, dropout_rate,
                   is_training):
    """Multi-head scaled dot-product attention with residual + layer norm.

    Args:
      queries: A 3d tensor with shape of [N, T_q, C_q]
      keys: A 3d tensor with shape of [N, T_k, C_k]. It is used for both the
        keys and the values.
      num_units: projection width C. If None, C_q (the static last dimension
        of `queries`) is used.
      num_heads: number of heads; C is split evenly across them (tf.split).
      dropout_rate: dropout rate applied to the attention weights.
      is_training: whether dropout is active.

    Returns:
      A [N, T_q, C] tensor, layer_norm(attention_output + queries). The
      residual add requires C == C_q.
    """
    if num_units is None:
        num_units = queries.get_shape().as_list()[-1]

    Q = tf.layers.dense(queries, num_units)  # (N, T_q, C)
    K = tf.layers.dense(keys, num_units)  # (N, T_k, C)
    V = tf.layers.dense(keys, num_units)  # (N, T_k, C)

    # Fold the heads into the batch dimension.
    Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)  # (h*N, T_q, C/h)
    K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)
    V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)  # (h*N, T_k, C/h)

    outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)
    outputs = outputs / (K_.get_shape().as_list()[-1]**0.5)  # scale

    outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)

    outputs = tf.layers.dropout(outputs, dropout_rate,
                                training=is_training)  # (h*N, T_q, T_k)

    # Weighted sum
    outputs = tf.matmul(outputs, V_)  # (h*N, T_q, C/h)
    # Restore shape
    outputs = tf.concat(tf.split(outputs, num_heads, axis=0),
                        axis=2)  # (N, T_q, C)
    # Residual connection
    outputs += queries  # (N, T_q, C)
    # Normalize
    outputs = layer_norm(outputs)  # (N, T_q, C)
    return outputs
Example #10
0
def feedforward_func(input_tensor,
                     intermediate_size,
                     initializer_range,
                     hidden_size,
                     hidden_dropout_prob,
                     intermediate_act_fn=utils.gelu):
    """Position-wise feed-forward sub-layer with residual and layer norm.

    The sub-layer works in these steps:
      1. Project `input_tensor` up to `intermediate_size` using
         `intermediate_act_fn`.
      2. Project back down to `hidden_size` with no activation.
      3. Apply dropout.
      4. Add `input_tensor` and layer-normalize.

    Returns:
      (layer_output, scope), where `scope` is the "feedforward" variable
      scope.
    """
    with tf.variable_scope("feedforward") as scope:
        # The activation is applied only to the intermediate (up) projection.
        up = tf.layers.dense(
            input_tensor,
            intermediate_size,
            activation=intermediate_act_fn,
            kernel_initializer=utils.create_initializer(initializer_range),
            name="intermediate_dense")
        down = tf.layers.dense(
            up,
            hidden_size,
            kernel_initializer=utils.create_initializer(initializer_range),
            name="intermediate_output")
        down = utils.dropout(down, hidden_dropout_prob)
        result = utils.layer_norm(down + input_tensor)
    return result, scope
Example #11
0
def embed(input_ids,
          vocab_size,
          embedding_size,
          position_offset=0,
          initializer_range=0.02,
          max_position_embeddings=512,
          use_one_hot_embeddings=True):
    """Word and position embeddings, followed by layer normalization.

    :param input_ids: int Tensor of shape [batch_size, seq_length].
    :param vocab_size: number of words in vocab
    :param embedding_size: dimensionality of the embedding
    :param position_offset: aka number of cached tokens. The sequence uses
        positions [position_offset, position_offset + seq_length).
    :param initializer_range: float. Range of the weight initialization.
    :param max_position_embeddings: int. Size of the learned position
        table. position_offset + seq_length must not exceed it.
    :param use_one_hot_embeddings: if True, look up word embeddings with a
        one-hot matmul; otherwise use tf.nn.embedding_lookup. One-hot is
        probably what you want.
    :return: tuple (embedded, embedding_table). `embedded` is the
        layer-normalized [batch_size, seq_length, embedding_size] tensor;
        `embedding_table` is the [vocab_size, embedding_size] word table.
    """
    (batch_size, seq_length) = get_shape_list(input_ids, expected_rank=2)

    embedding_table = tf.compat.v1.get_variable(
        name='word_embed',
        shape=[vocab_size, embedding_size],
        initializer=create_initializer(initializer_range),
    )

    # An id >= vocab_size would otherwise give an all-zero one-hot row, so
    # check the ids before the lookup.
    assert_op = tf.compat.v1.assert_less_equal(tf.reduce_max(input_ids),
                                               vocab_size - 1)
    with tf.control_dependencies([assert_op]):
        if use_one_hot_embeddings:
            flat_input_ids = tf.reshape(input_ids, [-1])
            one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
            output_flat = tf.matmul(one_hot_input_ids, embedding_table)
        else:
            output_flat = tf.nn.embedding_lookup(embedding_table, input_ids)

        embedded_input = tf.reshape(output_flat,
                                    [batch_size, seq_length, embedding_size])

    # Every position used must exist in the position table, including the
    # cached-token offset. tf.one_hot below would otherwise silently yield a
    # zero position embedding for any position >= max_position_embeddings.
    assert_op = tf.compat.v1.assert_less_equal(seq_length + position_offset,
                                               max_position_embeddings)

    with tf.control_dependencies([assert_op]):
        full_position_embeddings = tf.compat.v1.get_variable(
            name='pos_embed',
            shape=[max_position_embeddings, embedding_size],
            initializer=create_initializer(initializer_range),
        )
        # The position table is a learned variable created for the (long)
        # maximum length `max_position_embeddings`. The actual sequence may
        # be shorter, which allows faster training on short-sequence tasks.
        #
        # `full_position_embeddings` is therefore an embedding table for
        # positions [0, 1, ..., max_position_embeddings-1]. With no offset,
        # the current sequence uses positions [0, ..., seq_length-1], so a
        # slice is enough.
        if position_offset == 0:
            embedded_input += tf.slice(full_position_embeddings, [0, 0],
                                       [seq_length, -1])[None]
        else:
            # Positions start at `position_offset`. Rows are selected with a
            # one-hot matmul instead of slicing the table at an offset.
            flat_pos_ids = (tf.range(seq_length, dtype=tf.int32) +
                            position_offset)
            one_hot_pos_ids = tf.one_hot(flat_pos_ids,
                                         depth=max_position_embeddings)

            # [seq_length, max_position_embeddings] x
            # [max_position_embeddings, embedding_size]
            seq_embeds = tf.matmul(one_hot_pos_ids, full_position_embeddings)
            embedded_input += seq_embeds[None]

    return layer_norm(embedded_input, name='embed_norm'), embedding_table
Example #12
0
    def apply(self, state_below, mask_below=None, init_state=None,
              init_context=None, c=None, c_mask=None, one_step=False,
              # added by Zhaopeng Tu, 2016-04-29
              cov_before=None, fertility=None):
        '''Run the attentional decoder, either one step or a full scan.

        state_below:  target-side inputs, n_steps * batch_size/1 * embedding
        mask_below:   target mask; defaults to ones of shape
                      (state_below.shape[0], 1) (sampling / beam search)
        init_state:   initial decoder state. If None it is built with
                      self.create_init_state(init_context); it is required
                      when one_step is True.
        init_context: passed to self.create_init_state
        c:            source annotations; c.shape[0] is the source length
        c_mask:       source mask passed on to _step_attention
        one_step:     if True, return a single _step_attention result
                      instead of running theano.scan
        cov_before:   previous coverage (one_step only)
        fertility:    source fertility (one_step only; during scan it is
                      computed here for linguistic coverage)

        For one_step, returns the _step_attention results. Otherwise returns
        the scan outputs (h, ctx, probs, plus cov with coverage), which are
        also stored in self.output.
        '''
        # Computed unconditionally; this matches the length of the default
        # mask below. Previously it was only set for 3-D inputs, so scan hit
        # an UnboundLocalError otherwise.
        n_steps = state_below.shape[0]

        # mask
        if mask_below is None:  # sampling or beamsearch
            mask_below = T.alloc(numpy.float32(1.), state_below.shape[0], 1)

        if one_step:
            # Explicit None check; do not rely on the truthiness of a
            # symbolic variable.
            assert init_state is not None, 'previous state must be provided'

        if init_state is None:
            init_state = self.create_init_state(init_context)

        state_below_xh = T.dot(state_below, self.W_xh)
        state_below_xz = T.dot(state_below, self.W_xz)
        state_below_xr = T.dot(state_below, self.W_xr)

        # for attention model: source-side projection, computed once
        p_from_c = T.dot(c, self.A_cp).reshape((c.shape[0], c.shape[1], self.n_hids))
        if self.with_layernorm:
            p_from_c = layer_norm(p_from_c, self.c_lnb, self.c_lns)

        if one_step:
            return self._step_attention(state_below_xh, state_below_xz, state_below_xr,
                                        mask_below, init_state, c, c_mask, p_from_c,
                                        # added by Zhaopeng Tu, 2016-06-08
                                        cov_tm1=cov_before, fertility=fertility)

        sequences = [state_below_xh, state_below_xz, state_below_xr, mask_below]
        # decoder hidden state, then ctx and probs (outputs only, no taps)
        outputs_info = [init_state, None, None]
        non_sequences = [c, c_mask, p_from_c]
        if self.with_coverage:
            # initialization for coverage
            init_cov = T.unbroadcast(T.zeros((c.shape[0], c.shape[1], self.coverage_dim), dtype='float32'), 2)
            outputs_info.append(init_cov)

            # fertility is not constructed outside when training.
            # Compare with `==`: `is` tests string identity, not equality.
            if self.coverage_type == 'linguistic':
                fertility = self._get_fertility(c)
            else:
                fertility = T.zeros((c.shape[0], c.shape[1]), dtype='float32')
            non_sequences.append(fertility)

        # scan passes: sequences, then recurrent outputs (h_tm1[, cov_tm1]),
        # then non-sequences.
        if self.with_coverage:
            def fn(x_h, x_z, x_r, x_m, h_tm1, cov_tm1, src, src_m, p_src, fert):
                return self._step_attention(x_h, x_z, x_r, x_m, h_tm1, src, src_m, p_src,
                                            cov_tm1=cov_tm1, fertility=fert)
        else:
            def fn(x_h, x_z, x_r, x_m, h_tm1, src, src_m, p_src):
                return self._step_attention(x_h, x_z, x_r, x_m, h_tm1, src, src_m, p_src)

        rval, updates = theano.scan(fn,
                                    sequences=sequences,
                                    non_sequences=non_sequences,
                                    outputs_info=outputs_info,
                                    name=_p(self.pname, 'layers'),
                                    n_steps=n_steps)

        self.output = rval

        return self.output
Example #13
0
    def _step_attention(self, x_h, x_z, x_r, x_m, h_tm1, c, c_m, p_from_c, cov_tm1=None, fertility=None):
        '''One decoder step with attention, optional coverage and context gate.

        x_h: input projection for the candidate state
        x_z: update of input
        x_r: reset of input
        x_m: mask of x_t
        h_tm1: previous state
        c: source annotations (source positions along axis 0)
        c_m: source mask multiplied into the attention energies; skipped
             when None
        p_from_c: precomputed source-side attention projection
        # added by Zhaopeng Tu, 2016-04-29
        cov_tm1:  coverage at time (t-1)
        fertility:  fertility of individual source word

        The step runs as follows:
          1. A first GRU combines h_tm1 with the last target word to give h1.
          2. Attention over c uses h1 (and the coverage, if enabled) to
             produce probs and ctx.
          3. The coverage is updated, if enabled.
          4. The context gate is applied, if enabled.
          5. A second GRU combines h1 with ctx to give h_t, which is then
             masked by x_m.

        Returns [h_t, ctx, probs], plus cov when self.with_coverage.
        '''

        # for attention model
        source_len = c.shape[0]
        target_num = h_tm1.shape[0]

        # commented by Zhaopeng Tu, 2016-04-29
        # here h1 combines previous hidden state and lastly generated word with GRU
        # note that this is different from the paper
        if self.with_layernorm:
            z1 = layer_norm((T.dot(h_tm1, self.W_n1_z) + x_z + self.b_n1_z),
                            self.W_n1_z_lnb, self.W_n1_z_lns)
            z1 = T.nnet.sigmoid(z1)
            r1 = layer_norm((T.dot(h_tm1, self.W_n1_r) + x_r + self.b_n1_r),
                            self.W_n1_r_lnb, self.W_n1_r_lns)
            r1 = T.nnet.sigmoid(r1)
            h1 = layer_norm((r1 * T.dot(h_tm1, self.W_n1_h) + x_h + self.b_n1_h),
                            self.W_n1_h_lnb, self.W_n1_h_lns)
            h1 = T.tanh(h1)
        else:
            z1 = T.nnet.sigmoid(T.dot(h_tm1, self.W_n1_z) + x_z + self.b_n1_z)
            r1 = T.nnet.sigmoid(T.dot(h_tm1, self.W_n1_r) + x_r + self.b_n1_r)
            h1 = T.tanh(r1 * T.dot(h_tm1, self.W_n1_h) + x_h + self.b_n1_h)

        # In this step the update gate weights the *previous* state.
        h1 = z1 * h_tm1 + (1. - z1) * h1
        h1 = x_m[:, None] * h1 + (1. - x_m)[:, None] * h_tm1

        p_from_h = ReplicateLayer(T.dot(h1, self.B_hp), source_len)
        p = p_from_h + p_from_c + self.b_tt

        # added by Zhaopeng Tu, 2016-04-29
        if self.with_coverage:
            p_from_cov = T.dot(cov_tm1, self.C_covp)
            p += p_from_cov

        # unnormalized attention weights, shaped (source_len, target_num)
        energy = T.exp(T.dot(T.tanh(p), self.D_pe) + self.c_tt).reshape((source_len, target_num))
        # Explicit None check. A symbolic variable's truthiness is not a
        # reliable "mask given" test; some variables raise TypeError on bool().
        if c_m is not None:
            energy *= c_m

        # normalize over source positions (axis 0)
        normalizer = energy.sum(axis=0, keepdims=True)
        probs = energy / normalizer

        # context vector: attention-weighted sum of source annotations
        ctx = (c * probs.dimshuffle(0, 1, 'x')).sum(axis=0)

        # added by Zhaopeng Tu, 2016-04-29
        # update coverage after producing attention probabilities at time t
        if self.with_coverage:
            cov = self._update_coverage(cov_tm1, probs, c, h_tm1, fertility)

        # commented by Zhaopeng Tu, 2016-04-29
        # this is even more consistent with our context gate
        # h1 corresponds to target context, while ctx corresponds to source context
        # added by Zhaopeng Tu, 2016-05-30
        if self.with_context_gate:
            gate = T.nnet.sigmoid(T.dot(h1, self.W_ctx_h) +
                                  T.dot(ctx, self.W_ctx_c) + self.b_ctx)

            # we directly scale h1, since it used in computing both can_h_t and h_t
            h1 = h1 * (1.-gate)
        else:
            gate = 1.

        # modified by Zhaopeng Tu, 2017-11-28
        if self.with_layernorm:
            # NOTE(review): unlike the non-layer-norm branch, the ctx terms
            # here are not multiplied by `gate`, so the context gate only
            # scales h1. This is left unchanged to keep trained models
            # compatible; confirm whether it is intended.
            z_t = layer_norm((T.dot(h1, self.W_hz) + T.dot(ctx, self.W_cz) + self.b_z),
                             self.W_hz_lnb, self.W_hz_lns)
            z_t = T.nnet.sigmoid(z_t)
            r_t = layer_norm((T.dot(h1, self.W_hr) + T.dot(ctx, self.W_cr) + self.b_r),
                             self.W_hr_lnb, self.W_hr_lns)
            r_t = T.nnet.sigmoid(r_t)
            h_t = layer_norm((r_t * T.dot(h1, self.W_hh) + T.dot(ctx, self.W_ch) + self.b_h),
                             self.W_hh_lnb, self.W_hh_lns)
            h_t = T.tanh(h_t)
        else:
            z_t = T.nnet.sigmoid(T.dot(h1, self.W_hz) + gate * T.dot(ctx, self.W_cz) + self.b_z)
            r_t = T.nnet.sigmoid(T.dot(h1, self.W_hr) + gate * T.dot(ctx, self.W_cr) + self.b_r)
            h_t = T.tanh(r_t * T.dot(h1, self.W_hh) + gate * T.dot(ctx, self.W_ch) + self.b_h)

        h_t = z_t * h1 + (1. - z_t) * h_t
        # NOTE(review): masked steps fall back to h1. With the context gate,
        # h1 has already been scaled by (1 - gate), so it is not exactly
        # h_tm1 there. Confirm this is acceptable for padded positions.
        h_t = x_m[:, None] * h_t + (1. - x_m[:, None]) * h1

        results = [h_t, ctx, probs]

        if self.with_coverage:
            results += [cov]

        return results