def lstm_layer(inputs,
               num_units,
               lengths=None,
               stack_size=1,
               use_cudnn=False,
               rnn_dropout_drop_amt=0,
               bidirectional=True):
  """Create an LSTM layer using the specified backend."""
  if use_cudnn:
    tf.logging.warning('cuDNN LSTM no longer supported. Using regular LSTM.')
  if not bidirectional:
    raise ValueError('Only bidirectional LSTMs are supported.')
  assert rnn_dropout_drop_amt == 0

  # CudnnCompatibleLSTMCell under the original 'cudnn_lstm' scope keeps
  # variable names compatible with checkpoints trained with the cuDNN
  # implementation.
  cells_fw = [
      contrib_cudnn_rnn.CudnnCompatibleLSTMCell(num_units)
      for _ in range(stack_size)
  ]
  cells_bw = [
      contrib_cudnn_rnn.CudnnCompatibleLSTMCell(num_units)
      for _ in range(stack_size)
  ]
  with tf.variable_scope('cudnn_lstm'):
    (outputs, unused_state_f,
     unused_state_b) = contrib_rnn.stack_bidirectional_dynamic_rnn(
         cells_fw,
         cells_bw,
         inputs,
         dtype=tf.float32,
         sequence_length=lengths,
         parallel_iterations=1)
  return outputs
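# Usage sketch (not from the source): the aliases below are the usual
# tf.contrib-era TF 1.x imports this snippet appears to assume; shapes and
# sizes are illustrative.
import tensorflow.compat.v1 as tf
from tensorflow.contrib import cudnn_rnn as contrib_cudnn_rnn
from tensorflow.contrib import rnn as contrib_rnn

batch_inputs = tf.placeholder(tf.float32, [None, None, 64])  # [batch, time, depth]
batch_lengths = tf.placeholder(tf.int32, [None])             # [batch]
layer_outputs = lstm_layer(
    batch_inputs, num_units=128, lengths=batch_lengths, stack_size=2)
# layer_outputs: [batch, time, 2 * num_units], forward and backward outputs
# concatenated on the last axis.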
def _createStackBidirectionalDynamicRNN(self,
                                        use_gpu,
                                        use_shape,
                                        use_state_tuple,
                                        initial_states_fw=None,
                                        initial_states_bw=None,
                                        scope=None):
  self.layers = [2, 3]
  input_size = 5
  batch_size = 2
  max_length = 8

  initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
  sequence_length = tf.placeholder(tf.int64)

  self.cells_fw = [
      rnn_cell.LSTMCell(  # pylint:disable=g-complex-comprehension
          num_units,
          input_size,
          initializer=initializer,
          state_is_tuple=False) for num_units in self.layers
  ]
  self.cells_bw = [
      rnn_cell.LSTMCell(  # pylint:disable=g-complex-comprehension
          num_units,
          input_size,
          initializer=initializer,
          state_is_tuple=False) for num_units in self.layers
  ]

  # A list of max_length references to a single shared placeholder.
  inputs = max_length * [
      tf.placeholder(
          tf.float32,
          shape=(batch_size, input_size) if use_shape else (None, input_size))
  ]
  inputs_c = tf.stack(inputs)
  inputs_c = tf.transpose(inputs_c, [1, 0, 2])
  outputs, st_fw, st_bw = contrib_rnn.stack_bidirectional_dynamic_rnn(
      self.cells_fw,
      self.cells_bw,
      inputs_c,
      initial_states_fw=initial_states_fw,
      initial_states_bw=initial_states_bw,
      dtype=tf.float32,
      sequence_length=sequence_length,
      scope=scope)

  # `outputs` has shape (batch_size, max_length, 2 * layers[-1]).
  output_shape = [None, max_length, 2 * self.layers[-1]]
  if use_shape:
    output_shape[0] = batch_size
  self.assertAllEqual(outputs.get_shape().as_list(), output_shape)

  input_value = np.random.randn(batch_size, input_size)
  return input_value, inputs, outputs, st_fw, st_bw, sequence_length
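# Sketch (assumption): how a test case might drive the helper above. Because
# `inputs` holds max_length references to one placeholder, feeding inputs[0]
# feeds every timestep with the same value.
with self.session(use_gpu=False) as sess:
  input_value, inputs, outputs, st_fw, st_bw, sequence_length = (
      self._createStackBidirectionalDynamicRNN(
          use_gpu=False, use_shape=True, use_state_tuple=False))
  sess.run(tf.global_variables_initializer())
  out, s_fw, s_bw = sess.run(
      [outputs, st_fw, st_bw],
      feed_dict={inputs[0]: input_value, sequence_length: [2, 3]})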
def _preprocess_controls(self, c_input, length):
  cells_fw, cells_bw = self._control_preprocessing_cells

  outputs, _, _ = contrib_rnn.stack_bidirectional_dynamic_rnn(
      cells_fw,
      cells_bw,
      c_input,
      sequence_length=length,
      time_major=False,
      dtype=tf.float32,
      scope='control_preprocessing')
  return outputs
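# Sketch (assumption): `self._control_preprocessing_cells` would be a
# (cells_fw, cells_bw) pair of equal-length cell lists, one cell per stacked
# layer; the layer size here is illustrative.
self._control_preprocessing_cells = (
    [tf.nn.rnn_cell.LSTMCell(64)],  # forward cells
    [tf.nn.rnn_cell.LSTMCell(64)],  # backward cells
)
# The returned outputs then have shape [batch, time, 2 * 64].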
def encode(self, sequence, sequence_length):
  cells_fw, cells_bw = self._cells

  _, states_fw, states_bw = contrib_rnn.stack_bidirectional_dynamic_rnn(
      cells_fw,
      cells_bw,
      sequence,
      sequence_length=sequence_length,
      time_major=False,
      dtype=tf.float32,
      scope=self._name_or_scope)
  # Note we access the outputs (h) from the states since the backward
  # outputs are reversed to the input order in the returned outputs.
  last_h_fw = states_fw[-1][-1].h
  last_h_bw = states_bw[-1][-1].h

  return tf.concat([last_h_fw, last_h_bw], 1)
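# Sketch (assumption, mirroring the indexing above): for states_fw[-1][-1].h
# to resolve, each stack entry must itself be a MultiRNNCell, so that
# states_fw[-1] is the top entry's state tuple, [-1] selects its last
# sub-layer's LSTMStateTuple, and .h is that layer's output. Sizes are
# illustrative.
self._cells = (
    [tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.LSTMCell(256)])
     for _ in range(2)],  # forward stack
    [tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.LSTMCell(256)])
     for _ in range(2)],  # backward stack
)
# encode() then returns a [batch, 2 * 256] vector: the final forward and
# backward top-layer hidden states, concatenated.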
def lstm_layer(inputs,
               num_units,
               bidirectional,
               is_training,
               lengths=None,
               stack_size=1,
               dropout_keep_prob=1):
  """Create an LSTM layer using the specified backend."""
  # Forward cells, with dropout applied only at training time.
  cells_fw = []
  for _ in range(stack_size):
    cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
    cell = tf.nn.rnn_cell.DropoutWrapper(
        cell, output_keep_prob=dropout_keep_prob if is_training else 1.0)
    cells_fw.append(cell)

  if bidirectional:
    cells_bw = []
    for _ in range(stack_size):
      cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
      cell = tf.nn.rnn_cell.DropoutWrapper(
          cell, output_keep_prob=dropout_keep_prob if is_training else 1.0)
      cells_bw.append(cell)
    with tf.variable_scope('lstm'):
      (outputs, unused_state_f,
       unused_state_b) = contrib_rnn.stack_bidirectional_dynamic_rnn(
           cells_fw,
           cells_bw,
           inputs,
           dtype=tf.float32,
           sequence_length=lengths,
           parallel_iterations=1)
    return outputs
  else:
    with tf.variable_scope('lstm'):
      outputs, unused_state = tf.nn.dynamic_rnn(
          cell=tf.nn.rnn_cell.MultiRNNCell(cells_fw),
          inputs=inputs,
          dtype=tf.float32,
          sequence_length=lengths,
          parallel_iterations=1)
    return outputs
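# Usage sketch (assumption): distinct outer scopes keep the two calls from
# colliding on the shared 'lstm' variable scope; shapes and sizes are
# illustrative.
spec = tf.placeholder(tf.float32, [None, None, 229])  # [batch, time, depth]
lens = tf.placeholder(tf.int32, [None])
with tf.variable_scope('bi'):
  bi_out = lstm_layer(spec, num_units=128, bidirectional=True,
                      is_training=True, lengths=lens, stack_size=2,
                      dropout_keep_prob=0.75)  # [batch, time, 2 * 128]
with tf.variable_scope('uni'):
  uni_out = lstm_layer(spec, num_units=128, bidirectional=False,
                       is_training=False, lengths=lens)  # [batch, time, 128]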