Example #1
0
def get_model(imgs, is_training, weight_decay=0.0, bn_decay=None):
    """Build an image -> colored-volume encoder/decoder network.

    A 2D convolutional encoder compresses ``imgs`` into a 512-d latent
    vector. That vector is expanded into a (4, 4, 4, 128) feature volume
    and upsampled by 3D transposed convolutions to 32^3. Three heads are
    predicted from these shared features:

    * a directly regressed color volume (sigmoid, 3 channels);
    * a per-voxel confidence (sigmoid, 1 channel);
    * a 2-channel "flow" (sigmoid), which ``tf_utils.Sampler`` uses to pick
      colors out of ``imgs``.

    The final color blends the regressed and sampled colors, weighted by the
    confidence.

    Args:
      imgs: input image batch. The inline shape comments assume
        (batch_size, 128, 128, C). The encoder must reduce the spatial size
        to 1x1 for the reshape to [batch_size, 512] to succeed, and the
        static batch dimension must be known. C is presumably 3, since
        colors sampled from ``imgs`` are blended with a 3-channel volume;
        TODO confirm.
      is_training: forwarded to every batch-norm layer (presumably a
        boolean tensor/placeholder; confirm against tf_utils).
      weight_decay: weight decay forwarded to every layer.
      bn_decay: batch-norm decay forwarded to every layer.

    Returns:
      Tuple ``(net_reg_clr, net_conf, net_flow, net_blended_clr)``:
        net_reg_clr: regressed color, (batch_size, 64, 64, 64, 3).
        net_conf: blend confidence in [0, 1], (batch_size, 64, 64, 64, 1).
        net_flow: sampling flow in [0, 1], (batch_size, 64, 64, 64, 2).
        net_blended_clr: ``net_reg_clr * net_conf
          + sampled_color * (1 - net_conf)``.
    """
    batch_size = imgs.get_shape()[0].value

    # Every layer uses batch norm and weight decay with identical settings.
    # Scope names below (including the irregular "deconv_reg_clr2" and
    # "conv_conf2") are kept verbatim; they presumably determine variable
    # names, so renaming them would break existing checkpoints.
    layer_args = {
        'bn': True,
        'is_training': is_training,
        'bn_decay': bn_decay,
        'weight_decay': weight_decay,
    }

    ########
    with tf.variable_scope('Encoding'):
        # (batch_size, 64, 64, 64)
        net = tf_utils.conv2d(imgs, 64, [7, 7], padding='SAME', stride=[2, 2],
                              scope='conv1', activation_fn=tf.nn.elu,
                              **layer_args)
        # (batch_size, 32, 32, 64)
        net = tf_utils.conv2d(net, 64, [5, 5], padding='SAME', stride=[2, 2],
                              scope='conv2', activation_fn=tf.nn.elu,
                              **layer_args)
        # (batch_size, 16, 16, 128)
        net = tf_utils.conv2d(net, 128, [5, 5], padding='SAME', stride=[2, 2],
                              scope='conv3', activation_fn=tf.nn.elu,
                              **layer_args)
        # (batch_size, 8, 8, 128)
        net = tf_utils.conv2d(net, 128, [3, 3], padding='SAME', stride=[2, 2],
                              scope='conv4', activation_fn=tf.nn.elu,
                              **layer_args)
        # (batch_size, 4, 4, 256)
        net = tf_utils.conv2d(net, 256, [3, 3], padding='SAME', stride=[2, 2],
                              scope='conv5', activation_fn=tf.nn.elu,
                              **layer_args)
        # (batch_size, 1, 1, 512): the VALID 4x4 kernel collapses the 4x4 map.
        net = tf_utils.conv2d(net, 512, [4, 4], padding='VALID', stride=[1, 1],
                              scope='conv6', activation_fn=tf.nn.elu,
                              **layer_args)

    ########
    with tf.variable_scope('Latent_variable'):
        net = tf.reshape(net, [batch_size, 512])
        net = tf_utils.fully_connected(net, 512, scope="fc1",
                                       activation_fn=tf.nn.elu, **layer_args)
        net = tf_utils.fully_connected(net, 128 * 4 * 4 * 4, scope="fc2",
                                       activation_fn=tf.nn.elu, **layer_args)
        # Seed volume for the 3D decoder.
        net = tf.reshape(net, [batch_size, 4, 4, 4, 128])

    ########
    with tf.variable_scope('Decoding'):
        # (batch_size, 8, 8, 8, 64)
        net = tf_utils.conv3d_transpose(net, 64, [3, 3, 3], scope="deconv1",
                                        stride=[2, 2, 2], padding='SAME',
                                        activation_fn=tf.nn.elu, **layer_args)
        # (batch_size, 16, 16, 16, 32)
        net = tf_utils.conv3d_transpose(net, 32, [3, 3, 3], scope="deconv2",
                                        stride=[2, 2, 2], padding='SAME',
                                        activation_fn=tf.nn.elu, **layer_args)
        # (batch_size, 32, 32, 32, 32) -- shared input to all three heads.
        net = tf_utils.conv3d_transpose(net, 32, [3, 3, 3], scope="deconv3",
                                        stride=[2, 2, 2], padding='SAME',
                                        activation_fn=tf.nn.elu, **layer_args)

        ##################
        ## regressed color
        ##################
        # (batch_size, 64, 64, 64, 16)
        net_reg_clr = tf_utils.conv3d_transpose(net, 16, [3, 3, 3],
                                                scope="deconv_reg_clr1",
                                                stride=[2, 2, 2],
                                                padding='SAME',
                                                activation_fn=tf.nn.elu,
                                                **layer_args)
        # (batch_size, 64, 64, 64, 3); sigmoid keeps colors in [0, 1].
        net_reg_clr = tf_utils.conv3d(net_reg_clr, 3, [3, 3, 3],
                                      scope="deconv_reg_clr2",
                                      stride=[1, 1, 1], padding='SAME',
                                      activation_fn=tf.sigmoid, **layer_args)

        ##############
        ### confidence
        ##############
        # (batch_size, 64, 64, 64, 16)
        net_conf = tf_utils.conv3d_transpose(net, 16, [3, 3, 3],
                                             scope="deconv_conf1",
                                             stride=[2, 2, 2], padding='SAME',
                                             activation_fn=tf.nn.elu,
                                             **layer_args)
        # (batch_size, 64, 64, 64, 1); sigmoid keeps the blend weight in [0, 1].
        net_conf = tf_utils.conv3d(net_conf, 1, [3, 3, 3], scope="conv_conf2",
                                   stride=[1, 1, 1], padding='SAME',
                                   activation_fn=tf.sigmoid, **layer_args)

        ##############
        ### flow color
        ##############
        # (batch_size, 64, 64, 64, 2); sigmoid bounds the flow to [0, 1]
        # (presumably normalized image coordinates for Sampler -- confirm).
        net_flow = tf_utils.conv3d_transpose(net, 2, [3, 3, 3],
                                             scope="deconv_flow",
                                             stride=[2, 2, 2], padding='SAME',
                                             activation_fn=tf.sigmoid,
                                             **layer_args)
        # (batch_size, 64, 64, 64, 3): colors sampled from imgs via net_flow.
        net_flow_clr = tf_utils.Sampler(net_flow, imgs)

        #################
        ### blended color
        #################
        # Convex combination (net_conf is in [0, 1]): high confidence favors
        # the regressed color, low confidence favors the sampled color.
        net_blended_clr = (net_reg_clr * net_conf
                           + net_flow_clr * (1.0 - net_conf))

    return net_reg_clr, net_conf, net_flow, net_blended_clr
Example #2
0
def get_model(imgs, is_training, weight_decay=0.0, bn_decay=None):
    """Build an image -> occupancy-volume encoder/decoder network.

    A 2D convolutional encoder compresses ``imgs`` into a 512-d latent
    vector. That vector is expanded into a (4, 4, 4, 128) feature volume,
    and 3D transposed convolutions upsample it to a single-channel 64^3
    volume.

    Args:
      imgs: (batch_size, im_dim, im_dim, 3) image batch. The inline shape
        comments assume im_dim == 128. The encoder must reduce the spatial
        size to 1x1 for the reshape to [batch_size, 512] to succeed, and the
        static batch dimension must be known.
      is_training: a boolean placeholder, forwarded to every batch-norm layer.
      weight_decay: weight decay forwarded to every layer.
      bn_decay: batch-norm decay forwarded to every layer.

    Returns:
      shape: (batch_size, vol_dim, vol_dim, vol_dim, 1), where vol_dim is 64
        for this layer stack. The last layer applies no activation, so the
        values are unbounded (presumably logits; confirm against the loss).
    """
    batch_size = imgs.get_shape()[0].value

    # Every layer uses batch norm and weight decay with identical settings.
    # Scope names below are kept verbatim; they presumably determine variable
    # names, so renaming them would break existing checkpoints.
    layer_args = {
        'bn': True,
        'is_training': is_training,
        'bn_decay': bn_decay,
        'weight_decay': weight_decay,
    }

    ########
    with tf.variable_scope('Encoding'):
        # (batch_size, 64, 64, 64)
        net = tf_utils.conv2d(imgs,
                              64, [7, 7],
                              padding='SAME',
                              stride=[2, 2],
                              scope='conv1',
                              activation_fn=tf.nn.elu,
                              **layer_args)
        # (batch_size, 32, 32, 64)
        net = tf_utils.conv2d(net,
                              64, [5, 5],
                              padding='SAME',
                              stride=[2, 2],
                              scope='conv2',
                              activation_fn=tf.nn.elu,
                              **layer_args)
        # (batch_size, 16, 16, 128)
        net = tf_utils.conv2d(net,
                              128, [5, 5],
                              padding='SAME',
                              stride=[2, 2],
                              scope='conv3',
                              activation_fn=tf.nn.elu,
                              **layer_args)
        # (batch_size, 8, 8, 128)
        net = tf_utils.conv2d(net,
                              128, [3, 3],
                              padding='SAME',
                              stride=[2, 2],
                              scope='conv4',
                              activation_fn=tf.nn.elu,
                              **layer_args)
        # (batch_size, 4, 4, 256)
        net = tf_utils.conv2d(net,
                              256, [3, 3],
                              padding='SAME',
                              stride=[2, 2],
                              scope='conv5',
                              activation_fn=tf.nn.elu,
                              **layer_args)
        # (batch_size, 1, 1, 512): the VALID 4x4 kernel collapses the 4x4 map.
        net = tf_utils.conv2d(net,
                              512, [4, 4],
                              padding='VALID',
                              stride=[1, 1],
                              scope='conv6',
                              activation_fn=tf.nn.elu,
                              **layer_args)

    ########
    with tf.variable_scope('Latent_variable'):
        net = tf.reshape(net, [batch_size, 512])
        net = tf_utils.fully_connected(net,
                                       512,
                                       scope="fc1",
                                       activation_fn=tf.nn.elu,
                                       **layer_args)
        net = tf_utils.fully_connected(net,
                                       128 * 4 * 4 * 4,
                                       scope="fc2",
                                       activation_fn=tf.nn.elu,
                                       **layer_args)
        # Seed volume for the 3D decoder.
        net = tf.reshape(net, [batch_size, 4, 4, 4, 128])

    ########
    with tf.variable_scope('Decoding'):
        # (batch_size, 8, 8, 8, 64)
        net = tf_utils.conv3d_transpose(net,
                                        64, [3, 3, 3],
                                        scope="deconv1",
                                        stride=[2, 2, 2],
                                        padding='SAME',
                                        activation_fn=tf.nn.elu,
                                        **layer_args)
        # (batch_size, 16, 16, 16, 32)
        net = tf_utils.conv3d_transpose(net,
                                        32, [3, 3, 3],
                                        scope="deconv2",
                                        stride=[2, 2, 2],
                                        padding='SAME',
                                        activation_fn=tf.nn.elu,
                                        **layer_args)
        # (batch_size, 32, 32, 32, 32)
        net = tf_utils.conv3d_transpose(net,
                                        32, [3, 3, 3],
                                        scope="deconv3",
                                        stride=[2, 2, 2],
                                        padding='SAME',
                                        activation_fn=tf.nn.elu,
                                        **layer_args)
        # (batch_size, 64, 64, 64, 24)
        net = tf_utils.conv3d_transpose(net,
                                        24, [3, 3, 3],
                                        scope="deconv4",
                                        stride=[2, 2, 2],
                                        padding='SAME',
                                        activation_fn=tf.nn.elu,
                                        **layer_args)
        # (batch_size, 64, 64, 64, 1); no activation -> raw per-voxel scores.
        net = tf_utils.conv3d(net,
                              1, [3, 3, 3],
                              scope="deconv5",
                              stride=[1, 1, 1],
                              padding='SAME',
                              activation_fn=None,
                              **layer_args)

    return net