Example #1
import tensorflow as tf
from tensorflow.contrib import rnn

# Note: rnn_placeholders() and flatten_nested() are helpers assumed to be
# provided by the surrounding project (see the sketch after Example #2).


def lstm_network(x,
                 lstm_sequence_length,
                 lstm_class=rnn.BasicLSTMCell,
                 lstm_layers=(256,),
                 name='lstm',
                 reuse=False,
                 **kwargs):
    """
    Stage2 network: from features to flattened LSTM output.
    Defines [multi-layered] dynamic [possibly shared] LSTM network.

    Returns:
         batch-wise flattened output tensor;
         lstm initial state tensor;
         lstm state output tensor;
         lstm flattened feed placeholders as tuple.
    """
    with tf.variable_scope(name, reuse=reuse):
        # Flatten, add action/reward and expand with a fake [time] batch dim to feed the LSTM bank (left disabled):
        #x = tf.concat([x, a_r] ,axis=-1)
        #x = tf.concat([batch_flatten(x), a_r], axis=-1)
        #x = tf.expand_dims(x, [0])

        # Define LSTM layers:
        lstm = []
        for size in lstm_layers:
            lstm += [lstm_class(size, state_is_tuple=True)]

        lstm = rnn.MultiRNNCell(lstm, state_is_tuple=True)
        # Get time_dimension as [1]-shaped tensor (currently unused):
        step_size = tf.expand_dims(tf.shape(x)[1], [0])

        lstm_init_state = lstm.zero_state(1, dtype=tf.float32)

        # Build feed placeholders mirroring the nested LSTM zero-state structure,
        # so the recurrent state can be passed in across session runs:
        lstm_state_pl = rnn_placeholders(lstm.zero_state(1, dtype=tf.float32))
        lstm_state_pl_flatten = flatten_nested(lstm_state_pl)

        lstm_outputs, lstm_state_out = tf.nn.dynamic_rnn(
            lstm,
            x,
            initial_state=lstm_state_pl,
            sequence_length=lstm_sequence_length,
            time_major=False)
        #x_out = tf.reshape(lstm_outputs, [-1, lstm_layers[-1]])
        x_out = lstm_outputs
    return x_out, lstm_init_state, lstm_state_out, lstm_state_pl_flatten
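
A minimal usage sketch for the function above. The input shape, feature size and the session-feed pattern here are illustrative assumptions, not part of the original example; flatten_nested is the same helper the function itself relies on:

import numpy as np
import tensorflow as tf

# Hypothetical input: fake batch of 1, variable number of time steps, 64 features.
x_in = tf.placeholder(tf.float32, [1, None, 64], name='features')
seq_len = tf.placeholder(tf.int32, [None], name='seq_len')

x_out, init_state, state_out, state_pl_flat = lstm_network(
    x_in, seq_len, lstm_layers=(256,), name='lstm_demo')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Start from the zero state, then carry the recurrent state across calls:
    state = sess.run(init_state)
    feed = {x_in: np.zeros([1, 5, 64], np.float32), seq_len: [5]}
    feed.update(dict(zip(state_pl_flat, flatten_nested(state))))
    out_value, state = sess.run([x_out, state_out], feed_dict=feed)

The point of the flattened placeholders is the last feed line: the numpy state returned by the previous run is flattened and zipped onto the placeholder tuple, so the LSTM resumes from where it left off.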
Example #2
import tensorflow as tf
from tensorflow.contrib import rnn

# Note: rnn_placeholders() and flatten_nested() are helpers assumed to be
# provided by the surrounding project (see the sketch after this example).


def lstm_network(
        x,
        lstm_sequence_length,
        lstm_class=rnn.BasicLSTMCell,
        lstm_layers=(256,),
        static=False,
        name='lstm',
        reuse=False,
        **kwargs
    ):
    """
    Stage2 network: from features to flattened LSTM output.
    Defines [multi-layered] dynamic [possibly shared] LSTM network.

    Returns:
         batch-wise flattened output tensor;
         lstm initial state tensor;
         lstm state output tensor;
         lstm flattened feed placeholders as tuple.
    """
    with tf.variable_scope(name, reuse=reuse):
        # Prepare rnn type:
        if static:
            rnn_net = tf.nn.static_rnn
            # Remove the time dimension (assumed to be of size one) and wrap the input
            # in a list, as static_rnn expects a list of per-step tensors:
            x = [x[:, 0, :]]

        else:
            rnn_net = tf.nn.dynamic_rnn
        # Define LSTM layers:
        lstm = []
        for size in lstm_layers:
            lstm += [lstm_class(size)]  # state_is_tuple=True argument omitted here

        lstm = rnn.MultiRNNCell(lstm, state_is_tuple=True)
        # Get time_dimension as [1]-shaped tensor (currently unused):
        step_size = tf.expand_dims(tf.shape(x)[1], [0])

        lstm_init_state = lstm.zero_state(1, dtype=tf.float32)

        lstm_state_pl = rnn_placeholders(lstm.zero_state(1, dtype=tf.float32))
        lstm_state_pl_flatten = flatten_nested(lstm_state_pl)

        lstm_outputs, lstm_state_out = rnn_net(
            cell=lstm,
            inputs=x,
            initial_state=lstm_state_pl,
            sequence_length=lstm_sequence_length,
        )

        # Unwrap and expand: static_rnn returns a list with a single output,
        # so take it and restore the [time] dimension; dynamic_rnn output is used as-is:
        if static:
            x_out = lstm_outputs[0][:, None, :]
        else:
            x_out = lstm_outputs
        state_out = lstm_state_out
    return x_out, lstm_init_state, state_out, lstm_state_pl_flatten
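
Both examples rely on two helpers, rnn_placeholders() and flatten_nested(), that are not shown on this page. A minimal sketch of the assumed behaviour (state-shaped feed placeholders and flattening of the nested state structure), written against TF 1.x, could look like this:

import tensorflow as tf
from tensorflow.python.util import nest


def rnn_placeholders(state):
    """Recursively replace every tensor in a (possibly nested) RNN state
    structure with a float32 placeholder of the same static shape."""
    if isinstance(state, tf.contrib.rnn.LSTMStateTuple):
        return tf.contrib.rnn.LSTMStateTuple(
            tf.placeholder(tf.float32, state.c.shape),
            tf.placeholder(tf.float32, state.h.shape),
        )
    elif isinstance(state, tuple):
        return tuple(rnn_placeholders(substate) for substate in state)
    else:
        return tf.placeholder(tf.float32, state.shape)


def flatten_nested(nested):
    """Flatten an arbitrarily nested structure (tuples, LSTMStateTuples, ...)
    into a flat tuple, preserving left-to-right order."""
    return tuple(nest.flatten(nested))

Because LSTMStateTuple is a namedtuple, nest.flatten also handles the numpy-valued state returned by session.run(), which is what makes the feed pattern shown after Example #1 work.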