def __call__(self, inputs, state, scope=None):
    """Batch-normalized LSTM cell: BN applied to input-to-hidden and hidden-to-hidden pre-activations and to the new cell state."""
    with tf.variable_scope(scope or type(self).__name__, reuse=self._reuse):
        c, h = state
        input_size = inputs.get_shape().as_list()[1]

        W_xh = tf.get_variable('W_xh', [input_size, 4 * self._num_units],
                               initializer=orthogonal_initializer())
        W_hh = tf.get_variable('W_hh', [self._num_units, 4 * self._num_units],
                               initializer=bn_lstm_identity_initializer(0.95))
        bias = tf.get_variable('bias', [4 * self._num_units])

        xh = tf.matmul(inputs, W_xh)
        hh = tf.matmul(h, W_hh)

        bn_xh = batch_norm(xh, 'xh', self._is_training)
        bn_hh = batch_norm(hh, 'hh', self._is_training)

        hidden = bn_xh + bn_hh + bias

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(value=hidden, num_or_size_splits=4, axis=1)

        new_c = (c * sigmoid(f + self._forget_bias) +
                 sigmoid(i) * self._activation(j))
        bn_new_c = batch_norm(new_c, 'c', self._is_training)

        new_h = self._activation(bn_new_c) * sigmoid(o)
        new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)

        return new_h, new_state
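# Usage sketch for the batch-normalized LSTM cell above. Assumptions (not in the
# source): the __call__ belongs to a class named BNLSTMCell implementing the
# TF 1.x RNNCell interface, with a (num_units, is_training) constructor; both
# names are hypothetical. batch_norm is expected to switch between batch and
# population statistics based on is_training.
import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, 50, 128])     # [batch, time, features]
is_training = tf.placeholder(tf.bool, [])
cell = BNLSTMCell(256, is_training=is_training)          # hypothetical constructor
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)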
def __call__(self, inputs, state, scope=None): """LSTM cell with layer normalization and recurrent dropout.""" with vs.variable_scope(scope or "layer_norm_basic_lstm_cell"): c, h = state args = array_ops.concat([inputs, h], 1) concat = self._linear(args) i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) if self._layer_norm: i = self._norm(i, "input") j = self._norm(j, "transform") f = self._norm(f, "forget") o = self._norm(o, "output") g = self._activation(j) if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) new_c = (c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g) if self._layer_norm: new_c = self._norm(new_c, "state") new_h = self._activation(new_c) * math_ops.sigmoid(o) new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h) return new_h, new_state
def __call__(self, x, states_prev, scope=None): """Long short-term memory cell (LSTM).""" with vs.variable_scope(scope or self._names["scope"]): x_shape = x.get_shape().with_rank(2) if not x_shape[1].value: raise ValueError("Expecting x_shape[1] to be set: %s" % str(x_shape)) if len(states_prev) != 2: raise ValueError( "Expecting states_prev to be a tuple with length 2.") input_size = x_shape[1].value w = vs.get_variable( self._names["W"], [input_size + self._num_units, self._num_units * 4]) b = vs.get_variable(self._names["b"], [w.get_shape().with_rank(2)[1].value], initializer=init_ops.constant_initializer(0.0)) if self._use_peephole: wci = vs.get_variable(self._names["wci"], [self._num_units]) wco = vs.get_variable(self._names["wco"], [self._num_units]) wcf = vs.get_variable(self._names["wcf"], [self._num_units]) else: wci = wco = wcf = array_ops.zeros([self._num_units]) (cs_prev, h_prev) = states_prev (_, cs, _, _, _, _, h) = _lstm_block_cell(x, cs_prev, h_prev, w, b, wci=wci, wco=wco, wcf=wcf, forget_bias=self._forget_bias, use_peephole=self._use_peephole) new_state = core_rnn_cell.LSTMStateTuple(cs, h) return h, new_state
def __call__(self, inputs, initial_state=None, dtype=None, sequence_length=None,
             scope=None):
    """Run this LSTM on inputs, starting from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]` or a
        list of `time_len` tensors of shape `[batch_size, input_size]`.
      initial_state: a tuple `(initial_cell_state, initial_output)` with tensors
        of shape `[batch_size, self._num_units]`. If this is not provided, the
        cell is expected to create a zero initial state of type `dtype`.
      dtype: The data type for the initial state and expected output. Required
        if `initial_state` is not provided or RNN state has a heterogeneous dtype.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) of size `[batch_size]`, with values
        in `[0, time_len)`. Defaults to `time_len` for each element.
      scope: `VariableScope` for the created subgraph; defaults to class name.

    Returns:
      A pair containing:

      - Output: A `3-D` tensor of shape `[time_len, batch_size, output_size]`
        or a list of `time_len` tensors of shape `[batch_size, output_size]`,
        to match the type of the `inputs`.
      - Final state: a tuple `(cell_state, output)` matching `initial_state`.

    Raises:
      ValueError: in case of shape mismatches
    """
    with vs.variable_scope(scope or "lstm_block_wrapper"):
        is_list = isinstance(inputs, list)
        if is_list:
            inputs = array_ops.stack(inputs)
        inputs_shape = inputs.get_shape().with_rank(3)
        if not inputs_shape[2]:
            raise ValueError("Expecting inputs_shape[2] to be set: %s" % inputs_shape)
        batch_size = inputs_shape[1].value
        if batch_size is None:
            batch_size = array_ops.shape(inputs)[1]
        time_len = inputs_shape[0].value
        if time_len is None:
            time_len = array_ops.shape(inputs)[0]

        # Provide default values for initial_state and dtype
        if initial_state is None:
            if dtype is None:
                raise ValueError("Either initial_state or dtype needs to be specified")
            z = array_ops.zeros(
                array_ops.stack([batch_size, self.num_units]), dtype=dtype)
            initial_state = z, z
        else:
            if len(initial_state) != 2:
                raise ValueError(
                    "Expecting initial_state to be a tuple with length 2 or None")
            if dtype is None:
                dtype = initial_state[0].dtype

        # create the actual cell
        if sequence_length is not None:
            sequence_length = ops.convert_to_tensor(sequence_length)
        initial_cell_state, initial_output = initial_state  # pylint: disable=unpacking-non-sequence
        cell_states, outputs = self._call_cell(inputs, initial_cell_state,
                                               initial_output, dtype,
                                               sequence_length)

        if sequence_length is not None:
            # Mask out the part beyond sequence_length
            mask = array_ops.transpose(
                array_ops.sequence_mask(sequence_length, time_len, dtype=dtype),
                [1, 0])
            mask = array_ops.tile(
                array_ops.expand_dims(mask, [-1]), [1, 1, self.num_units])
            outputs *= mask
            # Prepend initial states to cell_states and outputs for indexing to
            # work correctly, since we want to access the last valid state at
            # sequence_length - 1, which can even be -1, corresponding to the
            # initial state.
            mod_cell_states = array_ops.concat(
                [array_ops.expand_dims(initial_cell_state, [0]), cell_states], 0)
            mod_outputs = array_ops.concat(
                [array_ops.expand_dims(initial_output, [0]), outputs], 0)
            final_cell_state = self._gather_states(mod_cell_states,
                                                   sequence_length, batch_size)
            final_output = self._gather_states(mod_outputs, sequence_length,
                                               batch_size)
        else:
            # No sequence_lengths used: final state is the last state
            final_cell_state = cell_states[-1]
            final_output = outputs[-1]

        if is_list:
            # Input was a list, so return a list
            outputs = array_ops.unstack(outputs)

        final_state = core_rnn_cell.LSTMStateTuple(final_cell_state, final_output)
        return outputs, final_state
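# Usage sketch for the time-major wrapper above. Assumptions (not in the source):
# the __call__ belongs to a fused-LSTM wrapper class, here called FusedLSTM, whose
# constructor takes num_units; the class name is hypothetical. Inputs are
# time-major, and sequence_length masks the outputs past each sequence's end.
import tensorflow as tf

time_len, batch_size, input_size = 20, 32, 64
inputs = tf.placeholder(tf.float32, [time_len, batch_size, input_size])
sequence_length = tf.placeholder(tf.int32, [batch_size])
lstm = FusedLSTM(num_units=128)                       # hypothetical constructor
outputs, (final_c, final_h) = lstm(inputs, dtype=tf.float32,
                                   sequence_length=sequence_length)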
def state_size(self):
    return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
def __call__(self, inputs, state, scope=None): """Long short-term memory cell (LSTM).""" with tf.variable_scope(scope or type(self).__name__): # "BasicLSTMCell" # Parameters of gates are concatenated into one multiply for efficiency. if self._state_is_tuple: c, h = state else: c, h = array_ops.split(1, 2, state) concat = _linear([inputs, h], 4 * self._num_units, True, 0., self.weights_init, self.trainable, self.restore, self.reuse) # i = input_gate, j = new_input, f = forget_gate, o = output_gate i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) # apply batch normalization to inner state and gates if self.batch_norm == True: i = batch_normalization(i, gamma=0.1, trainable=self.trainable, restore=self.restore, reuse=self.reuse) j = batch_normalization(j, gamma=0.1, trainable=self.trainable, restore=self.restore, reuse=self.reuse) f = batch_normalization(f, gamma=0.1, trainable=self.trainable, restore=self.restore, reuse=self.reuse) o = batch_normalization(o, gamma=0.1, trainable=self.trainable, restore=self.restore, reuse=self.reuse) new_c = (c * self._inner_activation(f + self._forget_bias) + self._inner_activation(i) * self._activation(j)) # hidden-to-hidden batch normalizaiton if self.batch_norm == True: batch_norm_new_c = batch_normalization( new_c, gamma=0.1, trainable=self.trainable, restore=self.restore, reuse=self.reuse) new_h = self._activation( batch_norm_new_c) * self._inner_activation(o) else: new_h = self._activation(new_c) * self._inner_activation(o) if self._state_is_tuple: new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h) else: new_state = array_ops.concat([new_c, new_h], 1) # Retrieve RNN Variables with tf.variable_scope('Linear', reuse=True): self.W = tf.get_variable('Matrix') self.b = tf.get_variable('Bias') return new_h, new_state
def state_size(self):
    return (core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)