import tensorflow as tf
from tensorflow.contrib import rnn


def lstm_network(
        x,
        lstm_sequence_length,
        lstm_class=rnn.BasicLSTMCell,
        lstm_layers=(256,),
        static=False,
        name='lstm',
        reuse=False,
        **kwargs):
    """
    Stage2 network: from features to flattened LSTM output.
    Defines a [multi-layered] dynamic or static [possibly shared] LSTM network.

    Returns:
        batch-wise flattened output tensor;
        lstm initial state tensor;
        lstm state output tensor;
        lstm flattened feed placeholders as tuple.
    """
    with tf.variable_scope(name, reuse=reuse):
        # Choose the RNN flavour:
        if static:
            rnn_net = tf.nn.static_rnn
            # static_rnn expects a list of [batch, depth] tensors; remove the
            # time dimension (assumed to always be of size one) and wrap in a list:
            x = [x[:, 0, :]]
        else:
            rnn_net = tf.nn.dynamic_rnn

        # Stack the LSTM layers:
        lstm = [lstm_class(size, state_is_tuple=True) for size in lstm_layers]
        lstm = rnn.MultiRNNCell(lstm, state_is_tuple=True)

        # The zero state doubles as the initial state and as a template
        # for the state feed placeholders:
        lstm_init_state = lstm.zero_state(1, dtype=tf.float32)
        lstm_state_pl = rnn_placeholders(lstm.zero_state(1, dtype=tf.float32))
        lstm_state_pl_flatten = flatten_nested(lstm_state_pl)

        lstm_outputs, lstm_state_out = rnn_net(
            cell=lstm,
            inputs=x,
            initial_state=lstm_state_pl,
            sequence_length=lstm_sequence_length,
        )

        if static:
            # Unwrap the single static output and restore the time dimension:
            x_out = lstm_outputs[0][:, None, :]
        else:
            x_out = lstm_outputs

    return x_out, lstm_init_state, lstm_state_out, lstm_state_pl_flatten
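# lstm_network() references two helpers, rnn_placeholders() and
# flatten_nested(), that are defined elsewhere in the package. Minimal
# sketches of their assumed behaviour follow; the project's actual
# implementations may differ.
def rnn_placeholders(state):
    """Mirrors a nested RNN state with an identically structured nest of
    placeholders (sketch; assumes LSTMStateTuple or plain tensor leaves)."""
    if isinstance(state, rnn.LSTMStateTuple):
        c, h = state
        c = tf.placeholder(tf.float32, c.get_shape(), c.op.name + '_c_pl')
        h = tf.placeholder(tf.float32, h.get_shape(), h.op.name + '_h_pl')
        return rnn.LSTMStateTuple(c, h)

    elif isinstance(state, tf.Tensor):
        return tf.placeholder(tf.float32, state.get_shape(), state.op.name + '_pl')

    else:
        return tuple(rnn_placeholders(s) for s in state)


def flatten_nested(nested):
    """Flattens an arbitrarily nested tuple/list structure into a flat tuple,
    so nested state values can be zipped against flat placeholders."""
    if isinstance(nested, (tuple, list)):
        flat = []
        for item in nested:
            flat.extend(flatten_nested(item))
        return tuple(flat)

    return (nested,)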
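# Usage sketch: wiring the returned pieces together in a session, threading
# the LSTM state through successive run calls. The input depth (64), the
# random feature batches and the session setup are illustrative assumptions,
# not part of the original interface.
if __name__ == '__main__':
    import numpy as np

    # [batch, time, depth] feature input and per-sample sequence lengths:
    features = tf.placeholder(tf.float32, [1, None, 64], name='features')
    seq_len = tf.placeholder(tf.int32, [None], name='seq_len')

    x_out, init_state, state_out, state_pl_flatten = lstm_network(features, seq_len)

    # Hypothetical stream of three [1, time, depth] feature batches:
    feature_batches = [np.random.randn(1, 5, 64).astype(np.float32) for _ in range(3)]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Start from the zero state:
        state = sess.run(init_state)
        for batch in feature_batches:
            feed = {features: batch, seq_len: [batch.shape[1]]}
            # Zip flattened state values onto the flattened state placeholders:
            feed.update(dict(zip(state_pl_flatten, flatten_nested(state))))
            output, state = sess.run([x_out, state_out], feed)
            print('output shape:', output.shape)  # -> (1, 5, 256)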