Example #1
File: lstm_bn.py Project: TPLink32/nlp
    def __call__(self, inputs, state, scope=None):
        with tf.variable_scope(scope or type(self).__name__, reuse=self._reuse):
            c, h = state
            input_size = inputs.get_shape().as_list()[1]

            W_xh = tf.get_variable('W_xh',
                                   [input_size, 4 * self._num_units],
                                   initializer=orthogonal_initializer())
            W_hh = tf.get_variable('W_hh',
                                   [self._num_units, 4 * self._num_units],
                                   initializer=bn_lstm_identity_initializer(0.95))
            bias = tf.get_variable('bias', [4 * self._num_units])

            xh = tf.matmul(inputs, W_xh)
            hh = tf.matmul(h, W_hh)

            bn_xh = batch_norm(xh, 'xh', self._is_training)
            bn_hh = batch_norm(hh, 'hh', self._is_training)

            hidden = bn_xh + bn_hh + bias

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = array_ops.split(value=hidden, num_or_size_splits=4, axis=1)

            new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
            bn_new_c = batch_norm(new_c, 'c', self._is_training)
            new_h = self._activation(bn_new_c) * sigmoid(o)
            new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)

            return new_h, new_state
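This cell applies recurrent batch normalization: the input-to-hidden and hidden-to-hidden pre-activations are normalized separately, and the new cell state is normalized again before the output nonlinearity. Below is a minimal NumPy sketch of one step for orientation; the `batch_norm` stand-in uses batch statistics only and is not the project's helper (which also keeps moving averages and learned per-feature parameters).

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def batch_norm(x, gamma=0.1, eps=1e-5):
    # Normalize each feature over the batch dimension.
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps)

def bn_lstm_step(x, c, h, W_xh, W_hh, bias, forget_bias=1.0):
    # BN is applied separately to x @ W_xh and h @ W_hh; one bias is shared.
    hidden = batch_norm(x @ W_xh) + batch_norm(h @ W_hh) + bias
    i, j, f, o = np.split(hidden, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(batch_norm(new_c)) * sigmoid(o)  # BN again on the new cell state
    return new_h, new_c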
Example #2
    def __call__(self, inputs, state, scope=None):
        """LSTM cell with layer normalization and recurrent dropout."""

        with vs.variable_scope(scope or "layer_norm_basic_lstm_cell"):
            c, h = state
            args = array_ops.concat([inputs, h], 1)
            concat = self._linear(args)

            i, j, f, o = array_ops.split(value=concat,
                                         num_or_size_splits=4,
                                         axis=1)
            if self._layer_norm:
                i = self._norm(i, "input")
                j = self._norm(j, "transform")
                f = self._norm(f, "forget")
                o = self._norm(o, "output")

            g = self._activation(j)
            if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
                g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)

            new_c = (c * math_ops.sigmoid(f + self._forget_bias) +
                     math_ops.sigmoid(i) * g)
            if self._layer_norm:
                new_c = self._norm(new_c, "state")
            new_h = self._activation(new_c) * math_ops.sigmoid(o)

            new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
            return new_h, new_state
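For reference, layer normalization (what `self._norm` applies to each gate above) normalizes every example over its features instead of over the batch, and the recurrent dropout touches only the candidate input `g`, so the gates and the cell state are never dropped. A rough NumPy sketch of those two pieces, with the learned gain and bias of the real `_norm` reduced to scalars:

import numpy as np

def layer_norm(x, gain=1.0, bias=0.0, eps=1e-5):
    # Normalize each row (one example) over its features, then rescale.
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return gain * (x - mean) / np.sqrt(var + eps) + bias

def recurrent_dropout(g, keep_prob, rng):
    # Inverted dropout on the candidate g only, mirroring the cell above.
    mask = (rng.random(g.shape) < keep_prob).astype(g.dtype)
    return g * mask / keep_prob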
Example #3
    def __call__(self, x, states_prev, scope=None):
        """Long short-term memory cell (LSTM)."""
        with vs.variable_scope(scope or self._names["scope"]):
            x_shape = x.get_shape().with_rank(2)
            if not x_shape[1].value:
                raise ValueError("Expecting x_shape[1] to be set: %s" %
                                 str(x_shape))
            if len(states_prev) != 2:
                raise ValueError(
                    "Expecting states_prev to be a tuple with length 2.")
            input_size = x_shape[1].value
            w = vs.get_variable(
                self._names["W"],
                [input_size + self._num_units, self._num_units * 4])
            b = vs.get_variable(self._names["b"],
                                [w.get_shape().with_rank(2)[1].value],
                                initializer=init_ops.constant_initializer(0.0))
            if self._use_peephole:
                wci = vs.get_variable(self._names["wci"], [self._num_units])
                wco = vs.get_variable(self._names["wco"], [self._num_units])
                wcf = vs.get_variable(self._names["wcf"], [self._num_units])
            else:
                wci = wco = wcf = array_ops.zeros([self._num_units])
            (cs_prev, h_prev) = states_prev
            (_, cs, _, _, _, _,
             h) = _lstm_block_cell(x,
                                   cs_prev,
                                   h_prev,
                                   w,
                                   b,
                                   wci=wci,
                                   wco=wco,
                                   wcf=wcf,
                                   forget_bias=self._forget_bias,
                                   use_peephole=self._use_peephole)

            new_state = core_rnn_cell.LSTMStateTuple(cs, h)
            return h, new_state
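`_lstm_block_cell` fuses the whole step into one op over the concatenated `[x, h_prev]`, and with `use_peephole=True` adds diagonal peephole terms from the cell state into the gates. The NumPy sketch below only illustrates that idea; the gate ordering and other kernel details here are assumptions, not the fused op's actual layout.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_peephole_step(x, cs_prev, h_prev, w, b, wci, wcf, wco,
                       forget_bias=1.0, use_peephole=True):
    # One matmul over [x, h_prev], matching the [input_size + num_units,
    # 4 * num_units] weight shape above, then optional peephole terms.
    z = np.concatenate([x, h_prev], axis=1) @ w + b
    i, ci, f, o = np.split(z, 4, axis=1)    # assumed gate ordering
    if use_peephole:
        i = i + cs_prev * wci               # peephole into the input gate
        f = f + cs_prev * wcf               # peephole into the forget gate
    cs = sigmoid(f + forget_bias) * cs_prev + sigmoid(i) * np.tanh(ci)
    if use_peephole:
        o = o + cs * wco                    # peephole into the output gate
    h = sigmoid(o) * np.tanh(cs)
    return cs, h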
Example #4
    def __call__(self,
                 inputs,
                 initial_state=None,
                 dtype=None,
                 sequence_length=None,
                 scope=None):
        """Run this LSTM on inputs, starting from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`
        or a list of `time_len` tensors of shape `[batch_size, input_size]`.
      initial_state: a tuple `(initial_cell_state, initial_output)` with tensors
        of shape `[batch_size, self._num_units]`. If this is not provided, the
        cell is expected to create a zero initial state of type `dtype`.
      dtype: The data type for the initial state and expected output. Required
        if `initial_state` is not provided or RNN state has a heterogeneous
        dtype.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) of size `[batch_size]`, with values in
        `[0, time_len)`. Defaults to `time_len` for each element.
      scope: `VariableScope` for the created subgraph; defaults to class name.

    Returns:
      A pair containing:

      - Output: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
        or a list of time_len tensors of shape `[batch_size, output_size]`,
        to match the type of the `inputs`.
      - Final state: a tuple `(cell_state, output)` matching `initial_state`.

    Raises:
      ValueError: in case of shape mismatches
    """
        with vs.variable_scope(scope or "lstm_block_wrapper"):
            is_list = isinstance(inputs, list)
            if is_list:
                inputs = array_ops.stack(inputs)
            inputs_shape = inputs.get_shape().with_rank(3)
            if not inputs_shape[2]:
                raise ValueError("Expecting inputs_shape[2] to be set: %s" %
                                 inputs_shape)
            batch_size = inputs_shape[1].value
            if batch_size is None:
                batch_size = array_ops.shape(inputs)[1]
            time_len = inputs_shape[0].value
            if time_len is None:
                time_len = array_ops.shape(inputs)[0]

            # Provide default values for initial_state and dtype
            if initial_state is None:
                if dtype is None:
                    raise ValueError(
                        "Either initial_state or dtype needs to be specified")
                z = array_ops.zeros(array_ops.stack(
                    [batch_size, self.num_units]),
                                    dtype=dtype)
                initial_state = z, z
            else:
                if len(initial_state) != 2:
                    raise ValueError(
                        "Expecting initial_state to be a tuple with length 2 or None"
                    )
                if dtype is None:
                    dtype = initial_state[0].dtype

            # create the actual cell
            if sequence_length is not None:
                sequence_length = ops.convert_to_tensor(sequence_length)
            initial_cell_state, initial_output = initial_state  # pylint: disable=unpacking-non-sequence
            cell_states, outputs = self._call_cell(inputs, initial_cell_state,
                                                   initial_output, dtype,
                                                   sequence_length)

            if sequence_length is not None:
                # Mask out the part beyond sequence_length
                mask = array_ops.transpose(
                    array_ops.sequence_mask(sequence_length,
                                            time_len,
                                            dtype=dtype), [1, 0])
                mask = array_ops.tile(array_ops.expand_dims(mask, [-1]),
                                      [1, 1, self.num_units])
                outputs *= mask
                # Prepend initial states to cell_states and outputs for indexing to work
                # correctly, since we want to access the last valid state at
                # sequence_length - 1, which can even be -1, corresponding to the
                # initial state.
                mod_cell_states = array_ops.concat([
                    array_ops.expand_dims(initial_cell_state, [0]), cell_states
                ], 0)
                mod_outputs = array_ops.concat(
                    [array_ops.expand_dims(initial_output, [0]), outputs], 0)
                final_cell_state = self._gather_states(mod_cell_states,
                                                       sequence_length,
                                                       batch_size)
                final_output = self._gather_states(mod_outputs,
                                                   sequence_length, batch_size)
            else:
                # No sequence_lengths used: final state is the last state
                final_cell_state = cell_states[-1]
                final_output = outputs[-1]

            if is_list:
                # Input was a list, so return a list
                outputs = array_ops.unstack(outputs)

            final_state = core_rnn_cell.LSTMStateTuple(final_cell_state,
                                                       final_output)
            return outputs, final_state
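The subtle part of this wrapper is the `sequence_length` handling: outputs past each sequence's end are zeroed, and the final state is taken from step `sequence_length - 1`, with the initial state prepended so that a length of 0 maps back to it. A small NumPy sketch of that bookkeeping, assuming the time-major layout used above (helper name is illustrative):

import numpy as np

def mask_and_final_state(outputs, cell_states, initial_c, initial_h, seq_len):
    # outputs, cell_states: [time_len, batch, units]; seq_len: [batch] ints.
    time_len, batch, units = outputs.shape
    # Zero out every step at or beyond each sequence's length.
    steps = np.arange(time_len)[:, None]                       # [time_len, 1]
    mask = (steps < seq_len[None, :]).astype(outputs.dtype)    # [time_len, batch]
    outputs = outputs * mask[..., None]
    # Prepend the initial state so index seq_len picks the last valid step
    # (and index 0, i.e. an empty sequence, picks the initial state).
    mod_c = np.concatenate([initial_c[None], cell_states], axis=0)
    mod_h = np.concatenate([initial_h[None], outputs], axis=0)
    rows = np.arange(batch)
    final_c = mod_c[seq_len, rows]
    final_h = mod_h[seq_len, rows]
    return outputs, final_c, final_h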
Example #5
    def state_size(self):
        return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
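`LSTMStateTuple` is essentially a named `(c, h)` pair, so this `state_size` advertises one size for the cell state and one for the hidden output. A local stand-in with `collections.namedtuple` shows how it behaves (the real class additionally checks that `c` and `h` share a dtype):

import collections

# Stand-in for core_rnn_cell.LSTMStateTuple, for illustration only.
LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))

num_units = 128
state_size = LSTMStateTuple(num_units, num_units)
c_size, h_size = state_size           # unpacks like a plain tuple
print(state_size.c, state_size.h)     # 128 128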
Example #6
    def __call__(self, inputs, state, scope=None):
        """Long short-term memory cell (LSTM)."""
        with tf.variable_scope(scope or type(self).__name__):  # "BasicLSTMCell"
            # Parameters of gates are concatenated into one multiply for efficiency.
            if self._state_is_tuple:
                c, h = state
            else:
                c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
            concat = _linear([inputs, h], 4 * self._num_units, True, 0.,
                             self.weights_init, self.trainable, self.restore,
                             self.reuse)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = array_ops.split(value=concat,
                                         num_or_size_splits=4,
                                         axis=1)

            # apply batch normalization to inner state and gates
            if self.batch_norm:
                i = batch_normalization(i,
                                        gamma=0.1,
                                        trainable=self.trainable,
                                        restore=self.restore,
                                        reuse=self.reuse)
                j = batch_normalization(j,
                                        gamma=0.1,
                                        trainable=self.trainable,
                                        restore=self.restore,
                                        reuse=self.reuse)
                f = batch_normalization(f,
                                        gamma=0.1,
                                        trainable=self.trainable,
                                        restore=self.restore,
                                        reuse=self.reuse)
                o = batch_normalization(o,
                                        gamma=0.1,
                                        trainable=self.trainable,
                                        restore=self.restore,
                                        reuse=self.reuse)

            new_c = (c * self._inner_activation(f + self._forget_bias) +
                     self._inner_activation(i) * self._activation(j))

            # hidden-to-hidden batch normalization
            if self.batch_norm:
                batch_norm_new_c = batch_normalization(
                    new_c,
                    gamma=0.1,
                    trainable=self.trainable,
                    restore=self.restore,
                    reuse=self.reuse)
                new_h = self._activation(
                    batch_norm_new_c) * self._inner_activation(o)
            else:
                new_h = self._activation(new_c) * self._inner_activation(o)

            if self._state_is_tuple:
                new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
            else:
                new_state = array_ops.concat([new_c, new_h], 1)

            # Retrieve RNN Variables
            with tf.variable_scope('Linear', reuse=True):
                self.W = tf.get_variable('Matrix')
                self.b = tf.get_variable('Bias')

            return new_h, new_state
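When `state_is_tuple` is false, this cell packs `c` and `h` into a single tensor along axis 1, which is also why the `state_size` in the next example reports `2 * num_units` in that mode. A rough NumPy sketch of the two state representations (helper names are illustrative):

import numpy as np

def pack_state(c, h, state_is_tuple):
    # Tuple states keep c and h separate; legacy cells concatenate on axis 1.
    return (c, h) if state_is_tuple else np.concatenate([c, h], axis=1)

def unpack_state(state, num_units, state_is_tuple):
    if state_is_tuple:
        c, h = state
    else:
        c, h = state[:, :num_units], state[:, num_units:]
    return c, h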
Example #7
    def state_size(self):
        return (core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
                if self._state_is_tuple else 2 * self._num_units)