def get_model(imgs, is_training, weight_decay=0.0, bn_decay=None):
    """Build the color network: 2D image encoder -> 3D color volume with three heads.

    The image batch is encoded to a 512-d vector, projected to a 4x4x4x128
    seed volume, upsampled to a 32^3 feature volume, and then fed to three
    heads (regressed color, confidence, flow) that are blended per voxel.

    NOTE(review): this module defines ``get_model`` a second time further
    down; that later definition rebinds the name, so this one is not reachable
    as ``get_model`` after import. Consider giving one of them a distinct name.

    Args:
        imgs: input image batch. Presumably (batch_size, H, W, 3); the shape
            comments below correspond to H = W = 128. The batch size must be
            static because it is read from the static shape and used in
            ``tf.reshape``. conv5 must produce a 4x4 map so that conv6
            (4x4 kernel, VALID) yields 1x1x512 for the reshape to
            (batch_size, 512).
        is_training: forwarded as ``is_training`` to every tf_utils layer
            (all layers use batch normalization).
        weight_decay: weight decay forwarded to every layer.
        bn_decay: batch-norm decay forwarded to every layer.

    Returns:
        Tuple ``(net_reg_clr, net_conf, net_flow, net_blended_clr)``:
            net_reg_clr: (batch_size, 64, 64, 64, 3), sigmoid output.
            net_conf: (batch_size, 64, 64, 64, 1), sigmoid output used as the
                per-voxel blending weight.
            net_flow: (batch_size, 64, 64, 64, 2), sigmoid output passed to
                ``tf_utils.Sampler`` together with ``imgs``.
            net_blended_clr: ``net_reg_clr * net_conf
                + Sampler(net_flow, imgs) * (1 - net_conf)``.
    """
    batch_size = imgs.get_shape()[0].value
    # Keyword arguments shared by every layer: batch norm and weight decay everywhere.
    common = dict(bn=True, bn_decay=bn_decay, is_training=is_training,
                  weight_decay=weight_decay)

    ########
    with tf.variable_scope('Encoding'):
        # (num_outputs, kernel, stride, padding, scope)    -> output shape
        encoder_layers = (
            (64, [7, 7], [2, 2], 'SAME', 'conv1'),     # (batch_size, 64, 64, 64)
            (64, [5, 5], [2, 2], 'SAME', 'conv2'),     # (batch_size, 32, 32, 64)
            (128, [5, 5], [2, 2], 'SAME', 'conv3'),    # (batch_size, 16, 16, 128)
            (128, [3, 3], [2, 2], 'SAME', 'conv4'),    # (batch_size, 8, 8, 128)
            (256, [3, 3], [2, 2], 'SAME', 'conv5'),    # (batch_size, 4, 4, 256)
            (512, [4, 4], [1, 1], 'VALID', 'conv6'),   # (batch_size, 1, 1, 512)
        )
        net = imgs
        for num_outputs, kernel, stride, padding, scope in encoder_layers:
            net = tf_utils.conv2d(net, num_outputs, kernel, padding=padding,
                                  stride=stride, scope=scope,
                                  activation_fn=tf.nn.elu, **common)

    ########
    with tf.variable_scope('Latent_variable'):
        # Requires conv6 to have produced a 1x1 spatial map.
        net = tf.reshape(net, [batch_size, 512])
        net = tf_utils.fully_connected(net, 512, scope='fc1',
                                       activation_fn=tf.nn.elu, **common)
        net = tf_utils.fully_connected(net, 128 * 4 * 4 * 4, scope='fc2',
                                       activation_fn=tf.nn.elu, **common)
        # Seed volume for the 3D decoder.
        net = tf.reshape(net, [batch_size, 4, 4, 4, 128])

    ########
    with tf.variable_scope('Decoding'):
        # Shared trunk: each stride-2 transposed conv doubles the volume side.
        trunk_layers = (
            (64, 'deconv1'),   # (batch_size, 8, 8, 8, 64)
            (32, 'deconv2'),   # (batch_size, 16, 16, 16, 32)
            (32, 'deconv3'),   # (batch_size, 32, 32, 32, 32)
        )
        for num_outputs, scope in trunk_layers:
            net = tf_utils.conv3d_transpose(net, num_outputs, [3, 3, 3], scope=scope,
                                            stride=[2, 2, 2], padding='SAME',
                                            activation_fn=tf.nn.elu, **common)

        # NOTE(review): the original indentation was lost; the heads below are
        # assumed to live inside the 'Decoding' variable scope. Confirm against
        # existing checkpoints, since this determines their variable names.

        ##################
        ## regressed color
        ##################
        # (batch_size, 64, 64, 64, 16)
        net_reg_clr = tf_utils.conv3d_transpose(net, 16, [3, 3, 3],
                                                scope='deconv_reg_clr1',
                                                stride=[2, 2, 2], padding='SAME',
                                                activation_fn=tf.nn.elu, **common)
        # (batch_size, 64, 64, 64, 3); sigmoid bounds values to (0, 1)
        net_reg_clr = tf_utils.conv3d(net_reg_clr, 3, [3, 3, 3],
                                      scope='deconv_reg_clr2',
                                      stride=[1, 1, 1], padding='SAME',
                                      activation_fn=tf.sigmoid, **common)

        ##############
        ### confidence
        ##############
        # (batch_size, 64, 64, 64, 16)
        net_conf = tf_utils.conv3d_transpose(net, 16, [3, 3, 3],
                                             scope='deconv_conf1',
                                             stride=[2, 2, 2], padding='SAME',
                                             activation_fn=tf.nn.elu, **common)
        # (batch_size, 64, 64, 64, 1); sigmoid gives a blending weight in (0, 1)
        net_conf = tf_utils.conv3d(net_conf, 1, [3, 3, 3],
                                   scope='conv_conf2',
                                   stride=[1, 1, 1], padding='SAME',
                                   activation_fn=tf.sigmoid, **common)

        ##############
        ### flow color
        ##############
        # (batch_size, 64, 64, 64, 2); sigmoid bounds values to (0, 1)
        net_flow = tf_utils.conv3d_transpose(net, 2, [3, 3, 3],
                                             scope='deconv_flow',
                                             stride=[2, 2, 2], padding='SAME',
                                             activation_fn=tf.sigmoid, **common)
        # (batch_size, 64, 64, 64, 3). Presumably samples colors from imgs at the
        # 2-channel coordinates in net_flow; semantics live in tf_utils.Sampler
        # (TODO confirm).
        net_flow_clr = tf_utils.Sampler(net_flow, imgs)

        #################
        ### blended color
        #################
        # The 1-channel confidence broadcasts over the 3 color channels.
        net_blended_clr = net_reg_clr * net_conf + net_flow_clr * (1.0 - net_conf)

    return net_reg_clr, net_conf, net_flow, net_blended_clr
def get_model(imgs, is_training, weight_decay=0.0, bn_decay=None):
    """Build an image-to-volume network with a single-channel 3D output.

    The image batch is encoded to a 512-d vector, projected to a 4x4x4x128
    seed volume, and upsampled by stride-2 transposed 3D convolutions to a
    64^3 volume, followed by a final 3D convolution to one channel.

    Args:
        imgs: (batch_size, im_dim, im_dim, 3) image batch. The batch size must
            be static because it is read from the static shape and used in
            ``tf.reshape``. The shape comments below correspond to
            im_dim = 128. conv5 must produce a 4x4 map so that conv6
            (4x4 kernel, VALID) yields 1x1x512 for the reshape to
            (batch_size, 512).
        is_training: forwarded as ``is_training`` to every tf_utils layer
            (all layers use batch normalization).
        weight_decay: weight decay forwarded to every layer.
        bn_decay: batch-norm decay forwarded to every layer.

    Returns:
        (batch_size, vol_dim, vol_dim, vol_dim, 1) tensor with vol_dim = 64.
        The last layer uses ``activation_fn=None``, so values are not squashed
        to any range.
    """
    batch_size = imgs.get_shape()[0].value
    # Keyword arguments shared by every layer: batch norm and weight decay everywhere.
    common = dict(bn=True, bn_decay=bn_decay, is_training=is_training,
                  weight_decay=weight_decay)

    ########
    with tf.variable_scope('Encoding'):
        # (num_outputs, kernel, stride, padding, scope)    -> output shape
        encoder_layers = (
            (64, [7, 7], [2, 2], 'SAME', 'conv1'),     # (batch_size, 64, 64, 64)
            (64, [5, 5], [2, 2], 'SAME', 'conv2'),     # (batch_size, 32, 32, 64)
            (128, [5, 5], [2, 2], 'SAME', 'conv3'),    # (batch_size, 16, 16, 128)
            (128, [3, 3], [2, 2], 'SAME', 'conv4'),    # (batch_size, 8, 8, 128)
            (256, [3, 3], [2, 2], 'SAME', 'conv5'),    # (batch_size, 4, 4, 256)
            (512, [4, 4], [1, 1], 'VALID', 'conv6'),   # (batch_size, 1, 1, 512)
        )
        net = imgs
        for num_outputs, kernel, stride, padding, scope in encoder_layers:
            net = tf_utils.conv2d(net, num_outputs, kernel, padding=padding,
                                  stride=stride, scope=scope,
                                  activation_fn=tf.nn.elu, **common)

    ########
    with tf.variable_scope('Latent_variable'):
        # Requires conv6 to have produced a 1x1 spatial map.
        net = tf.reshape(net, [batch_size, 512])
        net = tf_utils.fully_connected(net, 512, scope='fc1',
                                       activation_fn=tf.nn.elu, **common)
        net = tf_utils.fully_connected(net, 128 * 4 * 4 * 4, scope='fc2',
                                       activation_fn=tf.nn.elu, **common)
        # Seed volume for the 3D decoder.
        net = tf.reshape(net, [batch_size, 4, 4, 4, 128])

    ########
    with tf.variable_scope('Decoding'):
        # Each stride-2 transposed conv doubles the volume side.
        decoder_layers = (
            (64, 'deconv1'),   # (batch_size, 8, 8, 8, 64)
            (32, 'deconv2'),   # (batch_size, 16, 16, 16, 32)
            (32, 'deconv3'),   # (batch_size, 32, 32, 32, 32)
            (24, 'deconv4'),   # (batch_size, 64, 64, 64, 24)
        )
        for num_outputs, scope in decoder_layers:
            net = tf_utils.conv3d_transpose(net, num_outputs, [3, 3, 3], scope=scope,
                                            stride=[2, 2, 2], padding='SAME',
                                            activation_fn=tf.nn.elu, **common)
        # (batch_size, 64, 64, 64, 1); no activation on the output layer.
        net = tf_utils.conv3d(net, 1, [3, 3, 3], scope='deconv5',
                              stride=[1, 1, 1], padding='SAME',
                              activation_fn=None, **common)

    return net