Code example #1
File: tf_helpers.py  Project: afcarl/RnnLM
    def _attention(self, x, h_prev_summary, c_tape, h_tape):
        """
        :param x: batch_size * input_size
        :param h_prev_summary: batch_size * cell_size
        :param c_tape: batch_size * memory_size * cell_size
        :param h_tape: batch_size * memory_size * cell_size
        :return: attention weights a, weighted cell summary c_summary, and weighted hidden summary h_summary
        """
        input_size = x.get_shape().with_rank(2)[1]

        with vs.variable_scope("Attention"):
            # mask out empty slots
            mask = tf.sign(tf.reduce_max(tf.abs(h_tape), reduction_indices=2))

            # construct query for attention
            concat_w = rnn_cell._get_concat_variable("query_w", [input_size.value+self._num_units, self._num_units], x.dtype, 1)
            b = vs.get_variable("query_bias", shape=[self._num_units], initializer=array_ops.zeros_initializer, dtype=x.dtype)
            query = tf.nn.bias_add(math_ops.matmul(array_ops.concat(1, [x, h_prev_summary]), concat_w), b)
            query = array_ops.reshape(query, [-1, 1, 1, self._num_units])

            # get the weights for attention
            k = vs.get_variable("AttnW", [1, 1, self._num_units, self._num_units])
            v = vs.get_variable("AttnV", [self._num_units])
            hidden = array_ops.reshape(h_tape, [-1, self._attn_length, 1, self._num_units])
            memory = array_ops.reshape(c_tape, [-1, self._attn_length, 1, self._num_units])

            hidden_features = tf.nn.conv2d(hidden, k, [1, 1, 1, 1], "SAME")
            s = tf.reduce_sum(v * tf.tanh(hidden_features + query), [2, 3])
            a = tf.nn.softmax(s) * mask
            a = a / (tf.reduce_sum(a, reduction_indices=1, keep_dims=True) + 1e-12)

            h_summary = tf.reduce_sum(array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2])
            c_summary = tf.reduce_sum(array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * memory, [1, 2])

            return a, c_summary, h_summary
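
The attention weights above come from a masked softmax: empty memory slots (all-zero rows of h_tape) are zeroed out and the remaining weights are renormalized so they still sum to one. Below is a minimal standalone NumPy sketch of just that masking step; it is an illustration, not code from the project.

import numpy as np

def masked_softmax(scores, h_tape, eps=1e-12):
    """Softmax over attention scores with empty memory slots zeroed out.

    scores: batch_size x memory_size raw attention logits
    h_tape: batch_size x memory_size x cell_size memory; all-zero rows are empty slots
    """
    # 1.0 where a slot holds a non-zero vector, 0.0 where it is still empty
    mask = np.sign(np.max(np.abs(h_tape), axis=2))
    # standard softmax over the memory axis
    e = np.exp(scores - scores.max(axis=1, keepdims=True))
    a = e / e.sum(axis=1, keepdims=True)
    # zero out empty slots and renormalize so the weights sum to one again
    a = a * mask
    return a / (a.sum(axis=1, keepdims=True) + eps)

# toy usage: one example, four memory slots (first two still empty), cell size three
h_tape = np.zeros((1, 4, 3))
h_tape[0, 2:] = 1.0
print(masked_softmax(np.array([[0.1, 0.5, 0.2, 0.9]]), h_tape))
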
Code example #2
File: tf_helpers.py  Project: afcarl/RnnLM
    def _attention(self, x, h_prev_summary, h_tape):
        """
        :param x: batch_size * input_size
        :param h_prev_summary: batch_size * cell_size
        :param h_tape: batch_size * memory_size * cell_size
        :return: h_summary, the attention-weighted sum of h_tape
        """
        input_size = x.get_shape().with_rank(2)[1]

        with vs.variable_scope("Attention"):
            # construct query for attention
            concat_w = rnn_cell._get_concat_variable("query_w", [input_size.value+self._num_units, self._attn_size], x.dtype, 1)
            b = vs.get_variable("query_bias", shape=[self._attn_size], initializer=array_ops.zeros_initializer, dtype=x.dtype)
            query = tf.nn.bias_add(math_ops.matmul(array_ops.concat(1, [x, h_prev_summary]), concat_w), b)
            query = array_ops.reshape(query, [-1, 1, 1, self._attn_size])

            # get temporal feature
            mask = tf.sign(tf.reduce_max(tf.abs(h_tape), reduction_indices=2, keep_dims=True))
            t_ids = tf.ones_like(mask) * self._attn_indexes
            mask = tf.squeeze(mask)

            # get the weights for attention
            k = vs.get_variable("AttnW", [1, 1, self._num_units+1, self._attn_size])
            v = vs.get_variable("AttnV", [self._attn_size])
            hidden = tf.concat(2, [h_tape, t_ids])
            hidden = array_ops.reshape(hidden, [-1, self._attn_length, 1, self._num_units+1])

            hidden_features = tf.nn.conv2d(hidden, k, [1, 1, 1, 1], "SAME")

            # compute attention point
            s = tf.reduce_sum(v * tf.tanh(hidden_features + query), [2, 3])
            a = tf.nn.softmax(s) * mask
            a = a / (tf.reduce_sum(a, reduction_indices=1, keep_dims=True) + 1e-12)

            h_summary = tf.reduce_sum(array_ops.reshape(a, [-1, self._attn_length, 1]) * h_tape, reduction_indices=1)

            return h_summary
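
This variant appends a per-slot time index as an extra feature channel before the attention convolution. A small standalone NumPy sketch of that preprocessing follows, assuming self._attn_indexes is a memory_size x 1 column of slot positions (its exact shape is not shown in the snippet).

import numpy as np

def add_time_index(h_tape, attn_indexes):
    """Append a per-slot time index as an extra feature channel.

    h_tape: batch_size x memory_size x cell_size
    attn_indexes: memory_size x 1 column of slot positions (assumed shape)
    """
    # mirrors tf.sign(tf.reduce_max(tf.abs(h_tape), 2, keep_dims=True)) above
    mask = np.sign(np.max(np.abs(h_tape), axis=2, keepdims=True))   # batch x memory x 1
    t_ids = np.ones_like(mask) * attn_indexes                       # batch x memory x 1
    tagged = np.concatenate([h_tape, t_ids], axis=2)                # batch x memory x (cell_size + 1)
    return tagged, np.squeeze(mask, axis=2)                         # mask is reused for the masked softmax
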
Code example #3
File: lstm_bn.py  Project: xiaotret/DeepTriage
    def __call__(self, inputs, state, scope=None):
        """Run one step of LSTM.

        Args:
          inputs: input Tensor, 2D, batch x num_units.
          state: state Tensor, 2D, batch x state_size.
          scope: VariableScope for the created subgraph; defaults to "LSTMCell".

        Returns:
          A tuple containing:
          - A 2D, batch x output_dim, Tensor representing the output of the LSTM
            after reading "inputs" when previous state was "state".
            Here output_dim is:
               num_proj if num_proj was set,
               num_units otherwise.
          - A 2D, batch x state_size, Tensor representing the new state of LSTM
            after reading "inputs" when previous state was "state".
        Raises:
          ValueError: if an input_size was specified and the provided inputs have
            a different dimension.
        """
        num_proj = self._num_units if self._num_proj is None else self._num_proj

        c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
        m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

        dtype = inputs.dtype
        actual_input_size = inputs.get_shape().as_list()[1]
        if self._input_size and self._input_size != actual_input_size:
            raise ValueError(
                "Actual input size not same as specified: %d vs %d." %
                (actual_input_size, self._input_size))

        scope_name = scope or type(self).__name__
        with vs.variable_scope(scope_name,
                               initializer=self._initializer):  # "LSTMCell"
            if not self._bn:
                concat_w = _get_concat_variable(
                    "W", [actual_input_size + num_proj, 4 * self._num_units],
                    dtype, self._num_unit_shards)
            else:
                concat_w_i = _get_concat_variable(
                    "W_i", [actual_input_size, 4 * self._num_units], dtype,
                    self._num_unit_shards)
                concat_w_r = _get_concat_variable(
                    "W_r", [num_proj, 4 * self._num_units], dtype,
                    self._num_unit_shards)

            b = vs.get_variable("B",
                                shape=[4 * self._num_units],
                                initializer=array_ops.zeros_initializer,
                                dtype=dtype)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            if not self._bn:
                cell_inputs = array_ops.concat(1, [inputs, m_prev])
                lstm_matrix = nn_ops.bias_add(
                    math_ops.matmul(cell_inputs, concat_w), b)

            else:
                lstm_matrix_i = batch_norm(math_ops.matmul(inputs, concat_w_i),
                                           self._deterministic,
                                           shift=False,
                                           scope=scope_name + 'bn_i')
                if self._bn > 1:
                    lstm_matrix_r = batch_norm(math_ops.matmul(
                        m_prev, concat_w_r),
                                               self._deterministic,
                                               shift=False,
                                               scope=scope_name + 'bn_r')
                else:
                    lstm_matrix_r = math_ops.matmul(m_prev, concat_w_r)
                lstm_matrix = nn_ops.bias_add(
                    math_ops.add(lstm_matrix_i, lstm_matrix_r), b)

            i, j, f, o = array_ops.split(1, 4, lstm_matrix)

            # Diagonal connections
            if self._use_peepholes:
                w_f_diag = vs.get_variable("W_F_diag",
                                           shape=[self._num_units],
                                           dtype=dtype)
                w_i_diag = vs.get_variable("W_I_diag",
                                           shape=[self._num_units],
                                           dtype=dtype)
                w_o_diag = vs.get_variable("W_O_diag",
                                           shape=[self._num_units],
                                           dtype=dtype)

            if self._use_peepholes:
                c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) *
                     c_prev + sigmoid(i + w_i_diag * c_prev) * tanh(j))
            else:
                c = (sigmoid(f + self._forget_bias) * c_prev +
                     sigmoid(i) * tanh(j))

            if self._cell_clip is not None:
                c = clip_ops.clip_by_value(c, -self._cell_clip,
                                           self._cell_clip)

            if self._use_peepholes:
                if self._bn > 2:
                    m = sigmoid(o + w_o_diag * c) * tanh(
                        batch_norm(
                            c, self._deterministic, scope=scope_name + 'bn_m'))
                else:
                    m = sigmoid(o + w_o_diag * c) * tanh(c)
            else:
                if self._bn > 2:
                    m = sigmoid(o) * tanh(
                        batch_norm(
                            c, self._deterministic, scope=scope_name + 'bn_m'))
                else:
                    m = sigmoid(o) * tanh(c)

            if self._num_proj is not None:
                concat_w_proj = _get_concat_variable(
                    "W_P", [self._num_units, self._num_proj], dtype,
                    self._num_proj_shards)

                m = math_ops.matmul(m, concat_w_proj)

        if not self._return_gate:
            return m, array_ops.concat(1, [c, m])
        else:
            return m, array_ops.concat(1, [c, m]), (i, j, f, o)
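
After the matmul, the gate arithmetic is the standard LSTM update with the gates laid out as (i, j, f, o). Below is a minimal standalone NumPy sketch of one such step, without batch norm, peepholes, clipping, or projection; the forget_bias default of 1.0 is only the usual convention, not taken from this class.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(lstm_matrix, c_prev, forget_bias=1.0):
    """One LSTM state update from the pre-activation matrix.

    lstm_matrix: batch x 4*num_units, laid out as [i | j | f | o] like the split above
    c_prev: batch x num_units previous cell state
    forget_bias: assumed default of 1.0, the usual TensorFlow convention
    """
    i, j, f, o = np.split(lstm_matrix, 4, axis=1)
    c = sigmoid(f + forget_bias) * c_prev + sigmoid(i) * np.tanh(j)   # new cell state
    m = sigmoid(o) * np.tanh(c)                                       # new output / hidden state
    return m, c
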
Code example #4
File: tf_lstm.py  Project: ScartleRoy/TF_LSTM_seq_bn
    def __call__(self, inputs, state, scope=None):
        """Run one step of LSTM.

        Args:
          inputs: input Tensor, 2D, batch x num_units.
          state: if `state_is_tuple` is False, this must be a state Tensor,
            `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
            tuple of state Tensors, both `2-D`, with column sizes `c_state` and
            `m_state`.
          scope: VariableScope for the created subgraph; defaults to "LSTMCell".

        Returns:
          A tuple containing:
          - A `2-D, [batch x output_dim]`, Tensor representing the output of the
            LSTM after reading `inputs` when previous state was `state`.
            Here output_dim is:
               num_proj if num_proj was set,
               num_units otherwise.
          - Tensor(s) representing the new state of LSTM after reading `inputs` when
            the previous state was `state`.  Same type and shape(s) as `state`.

        Raises:
          ValueError: If input size cannot be inferred from inputs via
            static shape inference.
        """
        num_proj = self._num_units if self._num_proj is None else self._num_proj

        (c_prev, m_prev) = state

        dtype = inputs.dtype
        input_size = inputs.get_shape().with_rank(2)[1]
        if input_size.value is None:
            raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
        
        scope_name = scope or type(self).__name__
        with vs.variable_scope(scope_name, 
                               initializer=self._initializer):  # "LSTMCell"
            if self._bn:
                concat_w_i = _get_concat_variable(
                    "W_i", [input_size.value, 4 * self._num_units],
                    dtype, self._num_unit_shards)
                
                concat_w_r = _get_concat_variable(
                    "W_r", [num_proj, 4 * self._num_units],
                    dtype, self._num_unit_shards)

                b = vs.get_variable(
                    "B", shape=[4 * self._num_units],
                    initializer=array_ops.zeros_initializer, dtype=dtype)
            else:
                concat_w = _get_concat_variable(
                    "W", [input_size.value + num_proj, 4 * self._num_units],
                    dtype, self._num_unit_shards)

                b = vs.get_variable(
                    "B", shape=[4 * self._num_units],
                    initializer=array_ops.zeros_initializer, dtype=dtype)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            if self._bn:
                lstm_matrix_i = batch_norm(math_ops.matmul(inputs, concat_w_i), self._deterministic,
                                           shift=False, scope=scope_name+'bn_i')
                if self._bn > 1:
                    lstm_matrix_r = batch_norm(math_ops.matmul(m_prev, concat_w_r), self._deterministic,
                                               shift=False, scope=scope_name+'bn_r')
                else:
                    lstm_matrix_r = math_ops.matmul(m_prev, concat_w_r)
                
                lstm_matrix = nn_ops.bias_add(math_ops.add(lstm_matrix_i, lstm_matrix_r), b)

            else:
                cell_inputs = array_ops.concat(1, [inputs, m_prev])
                lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
            
            i, j, f, o = array_ops.split(1, 4, lstm_matrix)

            # Diagonal connections
            if self._use_peepholes:
                w_f_diag = vs.get_variable(
                    "W_F_diag", shape=[self._num_units], dtype=dtype)
                w_i_diag = vs.get_variable(
                    "W_I_diag", shape=[self._num_units], dtype=dtype)
                w_o_diag = vs.get_variable(
                    "W_O_diag", shape=[self._num_units], dtype=dtype)

            if self._use_peepholes:
                c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
                 sigmoid(i + w_i_diag * c_prev) * self._activation(j))
            else:
                c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
                 self._activation(j))

            if self._cell_clip is not None:
                # pylint: disable=invalid-unary-operand-type
                c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
                # pylint: enable=invalid-unary-operand-type

            if self._use_peepholes:
                if self._bn > 2:
                    m = sigmoid(o + w_o_diag * c) * self._activation(batch_norm(c, self._deterministic,
                                                                                scope=scope_name+'bn_m'))
                else:
                    m = sigmoid(o + w_o_diag * c) * self._activation(c)
            else:
                if self._bn > 2:
                    m = sigmoid(o) * self._activation(batch_norm(c, self._deterministic,
                                                                 scope=scope_name+'bn_m'))
                else:
                    m = sigmoid(o) * self._activation(c)

            if self._num_proj is not None:
                concat_w_proj = _get_concat_variable(
                    "W_P", [self._num_units, self._num_proj],
                    dtype, self._num_proj_shards)

                m = math_ops.matmul(m, concat_w_proj)

        new_state = LSTMStateTuple(c, m)
        
        if not self._return_gate:
            return m, new_state
        else:
            return m, new_state, (i, j, f, o)
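
In both this example and the previous one, the integer self._bn works as a level switch: any truthy value batch-normalizes the input transform, values above 1 also normalize the recurrent transform, and values above 2 additionally normalize the cell state before the output activation. A tiny helper that just restates that reading (illustrative only, not part of the project):

def bn_plan(bn_level):
    """Which batch-norm sites are active for a given self._bn value, as used above."""
    return {
        "input_transform": bool(bn_level),    # if self._bn:     -> bn_i on inputs . W_i
        "recurrent_transform": bn_level > 1,  # if self._bn > 1: -> bn_r on m_prev . W_r
        "cell_output": bn_level > 2,          # if self._bn > 2: -> bn_m on c before the activation
    }

print(bn_plan(2))  # {'input_transform': True, 'recurrent_transform': True, 'cell_output': False}
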
Code example #5
File: tf_lstm.py  Project: zhoukangg/TF_LSTM_seq_bn
    def __call__(self, inputs, state, scope=None):
        """Run one step of LSTM.

        Args:
          inputs: input Tensor, 2D, batch x num_units.
          state: if `state_is_tuple` is False, this must be a state Tensor,
            `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
            tuple of state Tensors, both `2-D`, with column sizes `c_state` and
            `m_state`.
          scope: VariableScope for the created subgraph; defaults to "LSTMCell".

        Returns:
          A tuple containing:
          - A `2-D, [batch x output_dim]`, Tensor representing the output of the
            LSTM after reading `inputs` when previous state was `state`.
            Here output_dim is:
               num_proj if num_proj was set,
               num_units otherwise.
          - Tensor(s) representing the new state of LSTM after reading `inputs` when
            the previous state was `state`.  Same type and shape(s) as `state`.

        Raises:
          ValueError: If input size cannot be inferred from inputs via
            static shape inference.
        """
        num_proj = self._num_units if self._num_proj is None else self._num_proj

        (c_prev, m_prev) = state

        dtype = inputs.dtype
        input_size = inputs.get_shape().with_rank(2)[1]
        if input_size.value is None:
            raise ValueError(
                "Could not infer input size from inputs.get_shape()[-1]")

        scope_name = scope or type(self).__name__
        with vs.variable_scope(scope_name,
                               initializer=self._initializer):  # "LSTMCell"
            if self._bn:
                concat_w_i = _get_concat_variable(
                    "W_i", [input_size.value, 4 * self._num_units], dtype,
                    self._num_unit_shards)

                concat_w_r = _get_concat_variable(
                    "W_r", [num_proj, 4 * self._num_units], dtype,
                    self._num_unit_shards)

                b = vs.get_variable("B",
                                    shape=[4 * self._num_units],
                                    initializer=array_ops.zeros_initializer,
                                    dtype=dtype)
            else:
                concat_w = _get_concat_variable(
                    "W", [input_size.value + num_proj, 4 * self._num_units],
                    dtype, self._num_unit_shards)

                b = vs.get_variable("B",
                                    shape=[4 * self._num_units],
                                    initializer=array_ops.zeros_initializer,
                                    dtype=dtype)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            if self._bn:
                lstm_matrix_i = batch_norm(math_ops.matmul(inputs, concat_w_i),
                                           self._deterministic,
                                           shift=False,
                                           scope=scope_name + 'bn_i')
                if self._bn > 1:
                    lstm_matrix_r = batch_norm(math_ops.matmul(
                        m_prev, concat_w_r),
                                               self._deterministic,
                                               shift=False,
                                               scope=scope_name + 'bn_r')
                else:
                    lstm_matrix_r = math_ops.matmul(m_prev, concat_w_r)

                lstm_matrix = nn_ops.bias_add(
                    math_ops.add(lstm_matrix_i, lstm_matrix_r), b)

            else:
                cell_inputs = array_ops.concat(1, [inputs, m_prev])
                lstm_matrix = nn_ops.bias_add(
                    math_ops.matmul(cell_inputs, concat_w), b)

            i, j, f, o = array_ops.split(1, 4, lstm_matrix)

            # Diagonal connections
            if self._use_peepholes:
                w_f_diag = vs.get_variable("W_F_diag",
                                           shape=[self._num_units],
                                           dtype=dtype)
                w_i_diag = vs.get_variable("W_I_diag",
                                           shape=[self._num_units],
                                           dtype=dtype)
                w_o_diag = vs.get_variable("W_O_diag",
                                           shape=[self._num_units],
                                           dtype=dtype)

            if self._use_peepholes:
                c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) *
                     c_prev +
                     sigmoid(i + w_i_diag * c_prev) * self._activation(j))
            else:
                c = (sigmoid(f + self._forget_bias) * c_prev +
                     sigmoid(i) * self._activation(j))

            if self._cell_clip is not None:
                # pylint: disable=invalid-unary-operand-type
                c = clip_ops.clip_by_value(c, -self._cell_clip,
                                           self._cell_clip)
                # pylint: enable=invalid-unary-operand-type

            if self._use_peepholes:
                if self._bn > 2:
                    m = sigmoid(o + w_o_diag * c) * self._activation(
                        batch_norm(
                            c, self._deterministic, scope=scope_name + 'bn_m'))
                else:
                    m = sigmoid(o + w_o_diag * c) * self._activation(c)
            else:
                if self._bn > 2:
                    m = sigmoid(o) * self._activation(
                        batch_norm(
                            c, self._deterministic, scope=scope_name + 'bn_m'))
                else:
                    m = sigmoid(o) * self._activation(c)

            if self._num_proj is not None:
                concat_w_proj = _get_concat_variable(
                    "W_P", [self._num_units, self._num_proj], dtype,
                    self._num_proj_shards)

                m = math_ops.matmul(m, concat_w_proj)

        new_state = LSTMStateTuple(c, m)

        if not self._return_gate:
            return m, new_state
        else:
            return m, new_state, (i, j, f, o)
Code example #6
File: tf_helpers.py  Project: afcarl/RnnLM
    def __call__(self, inputs, state, scope=None):
        """Run one step of MemoryLSTM.
        Args:
          inputs: input Tensor, 2D, batch x num_units.
          state: if `state_is_tuple` is False, this must be a state Tensor,
            `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
            tuple of state Tensors, both `2-D`, with column sizes `c_state` and
            `m_state`.
          scope: VariableScope for the created subgraph; defaults to "LSTMCell".

        Returns:
          A tuple containing:
          - A `2-D, [batch x output_dim]`, Tensor representing the output of the
            LSTM after reading `inputs` when previous state was `state`.
            Here output_dim is:
               num_proj if num_proj was set,
               num_units otherwise.
          - Tensor(s) representing the new state of LSTM after reading `inputs` when
            the previous state was `state`.  Same type and shape(s) as `state`.

        Raises:
          ValueError: If input size cannot be inferred from inputs via
            static shape inference.
        """
        (a, h_prev_summary, c_tape_prev, h_tape_prev) = state

        dtype = inputs.dtype
        input_size = inputs.get_shape().with_rank(2)[1]
        if input_size.value is None:
            raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

        with vs.variable_scope(scope or type(self).__name__, initializer=self._initializer):  # "LSTMCell"
            concat_w = rnn_cell._get_concat_variable(
                "W", [input_size.value + self._num_units, 4 * self._num_units], dtype, 1)

            b = vs.get_variable("Bias", shape=[4 * self._num_units], initializer=array_ops.zeros_initializer, dtype=dtype)

            # reshape tape to 3D
            c_tape_prev = array_ops.reshape(c_tape_prev, [-1, self._attn_length, self._num_units])
            h_tape_prev = array_ops.reshape(h_tape_prev, [-1, self._attn_length, self._num_units])

            a, new_c_summary, new_h_summary = self._attention(inputs, h_prev_summary, c_tape_prev, h_tape_prev)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            cell_inputs = array_ops.concat(1, [inputs, new_h_summary])
            lstm_matrix = tf.nn.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
            i, j, f, o = array_ops.split(1, 4, lstm_matrix)

            # Diagonal connections
            if self._use_peepholes:
                w_f_diag = vs.get_variable(
                    "W_F_diag", shape=[self._num_units], dtype=dtype)
                w_i_diag = vs.get_variable(
                    "W_I_diag", shape=[self._num_units], dtype=dtype)
                w_o_diag = vs.get_variable(
                    "W_O_diag", shape=[self._num_units], dtype=dtype)

            if self._use_peepholes:
                c = (sigmoid(f + self._forget_bias + w_f_diag * new_c_summary) * new_c_summary +
                     sigmoid(i + w_i_diag * new_c_summary) * self._activation(j))
            else:
                c = (sigmoid(f + self._forget_bias) * new_c_summary + sigmoid(i) *
                     self._activation(j))

            if self._cell_clip is not None:
                c = tf.clip_by_value(c, -self._cell_clip, self._cell_clip)

            if self._use_peepholes:
                h = sigmoid(o + w_o_diag * c) * self._activation(c)
            else:
                h = sigmoid(o) * self._activation(c)

            # remove old value
            new_h_tape = array_ops.slice(h_tape_prev, [0, 1, 0], [-1, -1, -1])
            new_c_tape = array_ops.slice(c_tape_prev, [0, 1, 0], [-1, -1, -1])

            # append the new c and h to the tape
            new_c_tape = array_ops.concat(1, [new_c_tape, array_ops.expand_dims(c, 1)])
            new_h_tape = array_ops.concat(1, [new_h_tape, array_ops.expand_dims(h, 1)])

            # flatten the tape to 2D
            new_c_tape = array_ops.reshape(new_c_tape, [-1, self._attn_length * self._num_units])
            new_h_tape = array_ops.reshape(new_h_tape, [-1, self._attn_length * self._num_units])

            new_state = (a, new_h_summary, new_c_tape, new_h_tape)

            return h, new_state
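
At the end of each step the cell slides its memory tape: the oldest slot is dropped, the new (c, h) pair is appended, and the tape is flattened back to 2-D so it fits in the flat state tuple. A minimal standalone NumPy sketch of that sliding-window update (illustration only):

import numpy as np

def update_tape(tape_flat, new_value, attn_length, num_units):
    """Slide the memory tape one slot and append the newest vector.

    tape_flat: batch x (attn_length * num_units), the flattened tape kept in the state
    new_value: batch x num_units, the freshly computed c or h
    """
    batch = tape_flat.shape[0]
    tape = tape_flat.reshape(batch, attn_length, num_units)   # unflatten to 3-D
    tape = np.concatenate([tape[:, 1:, :],                    # drop the oldest slot
                           new_value[:, None, :]], axis=1)    # append the newest
    return tape.reshape(batch, attn_length * num_units)       # flatten back to 2-D
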
Code example #7
    def __call__(self, inputs, state, scope=None):
        """Run one step of LSTM.

        Args:
          inputs: input Tensor, 2D, batch x num_units.
          state: state Tensor, 2D, batch x state_size.
          scope: VariableScope for the created subgraph; defaults to "LSTMCell".

        Returns:
          A tuple containing:
          - A 2D, batch x output_dim, Tensor representing the output of the LSTM
            after reading "inputs" when previous state was "state".
            Here output_dim is:
               num_proj if num_proj was set,
               num_units otherwise.
          - A 2D, batch x state_size, Tensor representing the new state of LSTM
            after reading "inputs" when previous state was "state".
        Raises:
          ValueError: if an input_size was specified and the provided inputs have
            a different dimension.
        """
        num_proj = self._num_units if self._num_proj is None else self._num_proj

        c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
        m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

        dtype = inputs.dtype
        actual_input_size = inputs.get_shape().as_list()[1]
        if self._input_size and self._input_size != actual_input_size:
            raise ValueError("Actual input size not same as specified: %d vs %d." %
                             (actual_input_size, self._input_size))

        scope_name = scope or type(self).__name__
        with vs.variable_scope(scope_name,
                               initializer=self._initializer):  # "LSTMCell"
            if not self._bn:
                concat_w = _get_concat_variable(
                    "W", [actual_input_size + num_proj, 4 * self._num_units],
                    dtype, self._num_unit_shards)
            else:
                concat_w_i = _get_concat_variable(
                    "W_i", [actual_input_size, 4 * self._num_units],
                    dtype, self._num_unit_shards)
                concat_w_r = _get_concat_variable(
                    "W_r", [num_proj, 4 * self._num_units],
                    dtype, self._num_unit_shards)

            b = vs.get_variable(
                "B", shape=[4 * self._num_units],
                initializer=array_ops.zeros_initializer, dtype=dtype)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            if not self._bn:
                cell_inputs = array_ops.concat(1, [inputs, m_prev])
                lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)

            else:
                lstm_matrix_i = batch_norm(math_ops.matmul(inputs, concat_w_i), self._deterministic,
                                           shift=False, scope=scope_name+'bn_i')
                if self._bn > 1:
                    lstm_matrix_r = batch_norm(math_ops.matmul(m_prev, concat_w_r), self._deterministic,
                                               shift=False, scope=scope_name+'bn_r')
                else:
                    lstm_matrix_r = math_ops.matmul(m_prev, concat_w_r)
                lstm_matrix = nn_ops.bias_add(math_ops.add(lstm_matrix_i, lstm_matrix_r), b)

            i, j, f, o = array_ops.split(1, 4, lstm_matrix)

            # Diagonal connections
            if self._use_peepholes:
                w_f_diag = vs.get_variable(
                    "W_F_diag", shape=[self._num_units], dtype=dtype)
                w_i_diag = vs.get_variable(
                    "W_I_diag", shape=[self._num_units], dtype=dtype)
                w_o_diag = vs.get_variable(
                    "W_O_diag", shape=[self._num_units], dtype=dtype)

            if self._use_peepholes:
                c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
                     sigmoid(i + w_i_diag * c_prev) * tanh(j))
            else:
                c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))

            if self._cell_clip is not None:
                c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)

            if self._use_peepholes:
                if self._bn > 2:
                    m = sigmoid(o + w_o_diag * c) * tanh(batch_norm(c, self._deterministic,
                                                                    scope=scope_name+'bn_m'))
                else:
                    m = sigmoid(o + w_o_diag * c) * tanh(c)
            else:
                if self._bn > 2:
                    m = sigmoid(o) * tanh(batch_norm(c, self._deterministic,
                                                     scope=scope_name+'bn_m'))
                else:
                    m = sigmoid(o) * tanh(c)

            if self._num_proj is not None:
                concat_w_proj = _get_concat_variable(
                    "W_P", [self._num_units, self._num_proj],
                    dtype, self._num_proj_shards)

                m = math_ops.matmul(m, concat_w_proj)

        if not self._return_gate:
            return m, array_ops.concat(1, [c, m])
        else:
            return m, array_ops.concat(1, [c, m]), (i, j, f, o)
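
This cell keeps its state as a single flat tensor rather than an LSTMStateTuple: c and m are concatenated on the way out and sliced apart again on the way in. A small standalone NumPy sketch of that packing and unpacking (illustration only):

import numpy as np

def pack_state(c, m):
    """Concatenate cell state and output into one flat state tensor, as returned above."""
    return np.concatenate([c, m], axis=1)              # batch x (num_units + num_proj)

def unpack_state(state, num_units):
    """Recover c_prev and m_prev from the flat state, mirroring the slices above."""
    return state[:, :num_units], state[:, num_units:]
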