def build_model(inputs, lstm_init_state):
    """Assemble the unrolled CNN -> LSTM -> FC -> SE(3) graph.

    Every stage is placed on "/gpu:0" and built inside its own variable scope
    opened with ``reuse=tf.AUTO_REUSE``, so calling this function again reuses
    the same weights instead of creating new ones.

    Args:
        inputs: tensor handed to ``cnn_over_timesteps``. The CNN output is
            assumed to be rank 5, since the reshape below indexes dims 0-4.
            What each axis means is not visible here -- TODO confirm.
        lstm_init_state: tensor unstacked along its first axis into a tuple
            and used as the initial state of ``cudnn_lstm_unrolled``
            (presumably the (h, c) pair -- verify against caller).

    Returns:
        Tuple ``(fc_outputs, se3_outputs, lstm_states)``.
    """
    with tf.device("/gpu:0"):
        # CNN block: keep dims 0 and 1, fold dims 2-4 into a single axis
        # before the features reach the LSTM.
        with tf.variable_scope("cnn_unrolled", reuse=tf.AUTO_REUSE):
            features = cnn_over_timesteps(inputs)
            dims = features.shape
            flat_size = dims[2] * dims[3] * dims[4]
            features = tf.reshape(features, [dims[0], dims[1], flat_size])

        # RNN block.
        with tf.variable_scope("rnn_unrolled", reuse=tf.AUTO_REUSE):
            initial_state = tuple(tf.unstack(lstm_init_state))
            lstm_outputs, lstm_states = cudnn_lstm_unrolled(
                features, initial_state)

        # Fully connected head, mapped over axis 0 of the LSTM outputs.
        with tf.variable_scope("fc_unrolled", reuse=tf.AUTO_REUSE):
            fc_outputs = tools.static_map_fn(fc_model, lstm_outputs, axis=0)

        # SE(3) composition, mapped over axis 1 of the FC outputs. The FC
        # outputs here are [x, y, z, yaw, pitch, roll, 6 x covars].
        with tf.variable_scope("se3_unrolled", reuse=tf.AUTO_REUSE):
            se3_outputs = tools.static_map_fn(
                se3_comp_over_timesteps, fc_outputs, axis=1)

    return fc_outputs, se3_outputs, lstm_states
def fc_layer(inputs, fc_model_fn=fc_model):
    """Map ``fc_model_fn`` over axis 0 of ``inputs``.

    The mapping is done by ``tools.static_map_fn`` inside the "fc_layer"
    variable scope, opened with ``reuse=tf.AUTO_REUSE`` so repeated calls
    share weights.

    Args:
        inputs: tensor passed to ``tools.static_map_fn`` with ``axis=0``.
        fc_model_fn: callable applied by ``tools.static_map_fn``; defaults
            to ``fc_model`` (bound when this function is defined).

    Returns:
        Whatever ``tools.static_map_fn`` returns.
    """
    with tf.variable_scope("fc_layer", reuse=tf.AUTO_REUSE):
        return tools.static_map_fn(fc_model_fn, inputs, axis=0)