Example #1
0
    def _inference3D(self, x, is_training):
        """Inference part of the network (3D pose branch).

        Per the original notes: each view yields a 46x46x128 encoding (and
        maybe a 46x46x3 eye map); these are unprojected into a hand-centred
        volume of side 64, so the input is 64x64x64 x (8*128 = 1024).

        Pipeline: four strided encoder steps, a 1x1x1 bottleneck, an initial
        score-volume guess, three decoder steps (each emitting its own score
        volume) and a final upsampling + 1x1x1 prediction.

        Args:
            x: input feature volume (see shape note above; TODO confirm).
            is_training: forwarded as ``trainable`` to every layer built here
                and as ``is_training`` to the encoder/decoder helpers.

        Returns:
            Tuple ``(scorevolumes, xyz_vox_list, score_list, variables)``:
            the five score volumes, per-volume keypoint locations
            (``softargmax3D`` or ``argmax_3d``), per-volume scores (mean or
            max over axes 1-3), and the variables of the 'PoseNet3D' scope.
        """
        with tf.variable_scope('PoseNet3D') as scope:
            n_kp = self.config.num_kp

            # Encoder. The leading None mirrors the original bookkeeping
            # (it is never consumed: only three skips are popped below).
            skip_stack = [None]
            for width in (32, 64, 64, 64):  # voxel sizes: 32, 16, 8, 4
                x = self._enc3D_step(x, width,
                                     dim_red=True, is_training=is_training)
                skip_stack.append(x)
            del skip_stack[-1]  # deepest encoder output is not used as a skip

            # Bottleneck at the coarsest resolution.
            x = slim.conv3d(x, 64, kernel_size=[1, 1, 1],
                            trainable=is_training, activation_fn=tf.nn.relu)

            # Initial score-volume guess straight from the bottleneck.
            volume = slim.conv3d_transpose(x, n_kp, kernel_size=[32, 32, 32],
                                           trainable=is_training, stride=16,
                                           activation_fn=None)
            volumes = [volume]

            # Decoder: each step consumes one skip (deepest first) and emits
            # a new score volume.
            for width, kernel in zip((32, 32, 32), (16, 8, 4)):
                x, volume = self._dec3D_stop(x, skip_stack.pop(), volume,
                                             width, n_kp, kernel, is_training)
                volumes.append(volume)

            # Final upsampling and 1x1x1 prediction head.
            x = slim.conv3d_transpose(x, 64, kernel_size=[4, 4, 4],
                                      trainable=is_training, stride=2,
                                      activation_fn=tf.nn.relu)
            volumes.append(slim.conv3d(x, n_kp, kernel_size=[1, 1, 1],
                                       trainable=is_training,
                                       activation_fn=None))

            variables = tf.contrib.framework.get_variables(scope)

            if self.net_config.use_softargmax:
                coords = [softargmax3D(v, output_vox_space=True) for v in volumes]
                scores = [tf.reduce_mean(v, [1, 2, 3]) for v in volumes]
            else:
                coords = [argmax_3d(v) for v in volumes]
                scores = [tf.reduce_max(v, [1, 2, 3]) for v in volumes]

            return volumes, coords, scores, variables
def em_branch(input, prefix='em_branch_'):
    """Symmetric 3D convolutional encoder/decoder branch.

    A SAME-padded 3x3x3 conv (8 channels) is followed by three VALID 5x5x5
    convs (16 channels each) and three VALID 5x5x5 transposed convs
    (16 -> 8 -> 3 channels), all stride 1 with ``lrelu`` activations.
    Each stride-1 VALID conv trims 4 voxels from every spatial dimension and
    each stride-1 VALID transposed conv adds 4 back, so the output keeps the
    input's spatial size with 3 channels.

    Args:
        input: 5-D input tensor; per the original note it should be
            ``[batch_size, frame_count, height, width, 16]``. (The name
            shadows the builtin but is kept for keyword callers.)
        prefix: string prepended to every layer's variable scope name.

    Returns:
        The 3-channel output of the last transposed convolution.
    """
    padding_method = 'VALID'

    conv = slim.conv3d(input, 8, [3, 3, 3], rate=1, activation_fn=lrelu,
                       scope=prefix + 'g_conv1', padding='SAME')

    # Encoder: shrinking VALID convolutions.
    conv1 = slim.conv3d(conv, 16, [5, 5, 5], rate=1, activation_fn=lrelu,
                        scope=prefix + 's_conv1', padding=padding_method)
    conv2 = slim.conv3d(conv1, 16, [5, 5, 5], rate=1, activation_fn=lrelu,
                        scope=prefix + 's_conv2', padding=padding_method)
    conv3 = slim.conv3d(conv2, 16, [5, 5, 5], rate=1, activation_fn=lrelu,
                        scope=prefix + 's_conv3', padding=padding_method)

    # Decoder: VALID transposed convolutions restore the trimmed border.
    deconv1 = slim.conv3d_transpose(conv3, 16, [5, 5, 5], activation_fn=lrelu,
                                    scope=prefix + 's_deconv1', padding=padding_method)
    deconv2 = slim.conv3d_transpose(deconv1, 8, [5, 5, 5], activation_fn=lrelu,
                                    scope=prefix + 's_deconv2', padding=padding_method)
    deconv3 = slim.conv3d_transpose(deconv2, 3, [5, 5, 5], activation_fn=lrelu,
                                    scope=prefix + 's_deconv3', padding=padding_method)

    if DEBUG == 1:
        # print(<single string>) parses under both Python 2 and 3 and emits
        # the same text the old Python-2 print statements did.
        for label, tensor in (('conv', conv), ('conv1', conv1),
                              ('conv2', conv2), ('conv3', conv3),
                              ('deconv1', deconv1), ('deconv2', deconv2),
                              ('deconv3', deconv3)):
            print('%s.shape: %s' % (label, tensor.shape))
    return deconv3
Example #3
0
def deconv3d(input_, output_dim, ks=4, s=2, stddev=0.02, name="deconv3d"):
    """3D transposed convolution (no activation, no bias) inside scope `name`.

    Args:
        input_: input tensor.
        output_dim: number of output channels.
        ks: kernel size.
        s: stride.
        stddev: standard deviation of the truncated-normal weight
            initializer. Previously accepted but ignored; now applied,
            matching the non-compat ``deconv3d`` variant.
        name: variable scope name.

    Returns:
        Output tensor of ``slim.conv3d_transpose`` with SAME padding.
    """
    with tf.compat.v1.variable_scope(name):
        return slim.conv3d_transpose(
            input_,
            output_dim,
            ks,
            s,
            padding='SAME',
            activation_fn=None,
            weights_initializer=tf.compat.v1.truncated_normal_initializer(
                stddev=stddev),
            biases_initializer=None)
Example #4
0
def deconv3d(input_, output_dim, ks=4, s=2, stddev=0.02, name="deconv3d"):
    """Bias-free, activation-free 3D transposed convolution in scope `name`.

    Weights are drawn from a truncated normal with standard deviation
    `stddev`; padding is SAME, kernel size `ks`, stride `s`, and
    `output_dim` output channels.
    """
    weight_init = tf.truncated_normal_initializer(stddev=stddev)
    with tf.variable_scope(name):
        return slim.conv3d_transpose(input_,
                                     num_outputs=output_dim,
                                     kernel_size=ks,
                                     stride=s,
                                     padding='SAME',
                                     activation_fn=None,
                                     weights_initializer=weight_init,
                                     biases_initializer=None)
Example #5
0
    def _dec3D_stop(self, x, skip, scorevol, chan, num_pred_chan, kernel2fs, is_training):
        """One decoder step: upsample, fuse an optional skip, predict a volume.

        Args:
            x: current decoder features.
            skip: encoder features at the upsampled resolution, or None.
            scorevol: previous score volume. NOTE(review): unused here; the
                returned volume is a fresh prediction, not added to this one.
            chan: feature width for this step.
            num_pred_chan: channels of the predicted score volume.
            kernel2fs: kernel size of the prediction upsampler; its stride
                is ``kernel2fs // 2``.
            is_training: forwarded as ``trainable`` to every layer.

        Returns:
            Tuple ``(features, scorevol)`` for the next step.
        """
        upsampled = slim.conv3d_transpose(x, chan, kernel_size=[4, 4, 4],
                                          trainable=is_training, stride=2,
                                          activation_fn=tf.nn.relu)

        fused = upsampled
        if skip is not None:
            # 1x1x1 conv reduces the skip to `chan` channels before concat.
            reduced_skip = slim.conv3d(skip, chan, kernel_size=[1, 1, 1],
                                       trainable=is_training,
                                       activation_fn=tf.nn.relu)
            fused = tf.concat([upsampled, reduced_skip], -1)

        features = slim.conv3d(fused, chan, kernel_size=[3, 3, 3],
                               trainable=is_training, activation_fn=tf.nn.relu)

        # Upsample straight to full scale to predict from these features.
        cube = [kernel2fs, kernel2fs, kernel2fs]
        prediction = slim.conv3d_transpose(features, num_pred_chan,
                                           kernel_size=cube,
                                           stride=kernel2fs // 2,
                                           trainable=is_training,
                                           activation_fn=None)
        return features, prediction
Example #6
0
    def _create_decoder(self,
                        z_rgb,
                        trainable=True,
                        if_bn=False,
                        reuse=False,
                        scope_name='ae_decoder'):
        """Decode a latent volume into an occupancy volume.

        Seven ``slim.conv3d_transpose`` layers (four of them stride 2) map
        `z_rgb` to a single-channel volume.

        Args:
            z_rgb: input latent tensor.
            trainable: whether the decoder variables are trainable.
            if_bn: apply ``slim.batch_norm`` (using ``self.is_training`` and
                ``self.FLAGS.bn_decay``) after each hidden layer.
            reuse: reuse variables of an existing `scope_name` scope.
            scope_name: variable scope name.

        Returns:
            Tuple ``(sigmoid(logits), logits)``.
        """
        with tf.variable_scope(scope_name) as scope:
            if reuse:
                scope.reuse_variables()

            if if_bn:
                batch_normalizer_gen = slim.batch_norm
                batch_norm_params_gen = {
                    'is_training': self.is_training,
                    'decay': self.FLAGS.bn_decay
                }
            else:
                batch_normalizer_gen = None
                batch_norm_params_gen = None

            if self.FLAGS.if_l2Reg:
                weights_regularizer = slim.l2_regularizer(1e-5)
            else:
                weights_regularizer = None

            # The layers below are conv3d_transpose; previously only
            # slim.fully_connected was listed, so activation, BN, trainable
            # and regularizer defaults never reached any decoder layer.
            with slim.arg_scope([slim.fully_connected, slim.conv3d_transpose],
                                activation_fn=self.activation_fn,
                                trainable=trainable,
                                normalizer_fn=batch_normalizer_gen,
                                normalizer_params=batch_norm_params_gen,
                                weights_regularizer=weights_regularizer):

                net_up5 = slim.conv3d_transpose(z_rgb, 512, kernel_size=[4, 4, 4], stride=[2, 2, 2],
                                                padding='SAME', scope='ae_deconv6')
                net_up4 = slim.conv3d_transpose(net_up5, 512, kernel_size=[4, 4, 4], stride=[2, 2, 2],
                                                padding='SAME', scope='ae_deconv5')
                net_up3 = slim.conv3d_transpose(net_up4, 512, kernel_size=[4, 4, 4], stride=[2, 2, 2],
                                                padding='SAME', scope='ae_deconv4')
                net_up2 = slim.conv3d_transpose(net_up3, 256, kernel_size=[3, 3, 3], stride=[1, 1, 1],
                                                padding='SAME', scope='ae_deconv3')
                net_up1 = slim.conv3d_transpose(net_up2, 128, kernel_size=[4, 4, 4], stride=[2, 2, 2],
                                                padding='SAME', scope='ae_deconv2')
                net_up0 = slim.conv3d_transpose(net_up1, 64, kernel_size=[3, 3, 3], stride=[1, 1, 1],
                                                padding='SAME', scope='ae_deconv1')
                # Output logits: explicitly no activation and no normalization.
                net_out_ = slim.conv3d_transpose(net_up0, 1, kernel_size=[3, 3, 3], stride=[1, 1, 1],
                                                 padding='SAME', activation_fn=None, normalizer_fn=None,
                                                 normalizer_params=None, scope='ae_out')

        return tf.nn.sigmoid(net_out_), net_out_
Example #7
0
    def hourglass3d(net, n, scope=None, reuse=None):
        """Recursive 3D hourglass of depth `n` with a residual skip branch.

        The skip branch is ``inresnet3d.resnet_k(net)``. The inner branch
        max-pools, doubles channels, recurses while ``n > 1``, restores the
        channel count and upsamples with a stride-2 transposed conv. The two
        branches are summed.

        NOTE(review): the sum needs both branches to share a shape. That
        depends on the pooling/transposed-conv padding in effect (defaults or
        an enclosing arg_scope); confirm against the caller.
        """
        width = int(net.shape[-1].value)
        default_name = 'hourglass3d_{}'.format(n)
        with tf.variable_scope(scope, default_name, [net], reuse=reuse):
            skip_branch = inresnet3d.resnet_k(net)

            inner = slim.max_pool3d(net, 3, stride=2)
            inner = inresnet3d.resnet_k(inner)
            inner = slim.conv3d(inner, width * 2, 1, stride=1)
            if n > 1:
                inner = inresnet3d.hourglass3d(inner, n - 1)
            inner = slim.conv3d(inner, width, 1, stride=1)
            inner = inresnet3d.resnet_k(inner)

            return skip_branch + slim.conv3d_transpose(inner, width, 3, stride=2)
Example #8
0
def unet3d(inputs):
    """
    unet3D model without softmax.

    Four-level 3D U-Net.
    - Encoder: double 3x3x3 conv blocks of 64/128/256/512 channels, with a
      2x max-pool between the first three.
    - Decoder: three up-blocks (transposed conv, 2x2x2 conv, skip concat on
      axis 4, double conv).
    - Head: a 1x1x1 conv to 4 channels (identity activation, batch norm).

    The shape of every intermediate tensor is printed while the graph is
    built. Layers are created in the same order as before, so slim's
    auto-generated scope names are unchanged.

    Args:
        inputs: 5-D input tensor; channels-last is assumed because
            concatenation uses axis 4 — TODO confirm.

    Returns:
        The 4-channel output tensor (no softmax applied).
    """

    def _double_conv(x, num_outputs):
        # Two 3x3x3 conv + batch norm + ReLU layers in one slim.repeat scope.
        return slim.repeat(inputs=x,
                           repetitions=2,
                           layer=slim.conv3d,
                           num_outputs=num_outputs,
                           kernel_size=3,
                           activation_fn=tf.nn.relu,
                           normalizer_fn=slim.batch_norm)

    def _up_block(x, skip, num_outputs):
        # Upsample x 2x (2*num_outputs channels), reduce to num_outputs with
        # a 2x2x2 conv, concatenate the encoder skip, then double conv.
        upsampled = slim.conv3d_transpose(inputs=x,
                                          kernel_size=3,
                                          num_outputs=2 * num_outputs,
                                          stride=2,
                                          activation_fn=tf.nn.relu,
                                          normalizer_fn=slim.batch_norm)
        print(upsampled.shape)
        upconv = slim.conv3d(inputs=upsampled,
                             kernel_size=2,
                             num_outputs=num_outputs,
                             activation_fn=tf.nn.relu,
                             normalizer_fn=slim.batch_norm)
        print(upconv.shape)
        merged = tf.concat([skip, upconv], 4)
        print(merged.shape)
        out = _double_conv(merged, num_outputs)
        print(out.shape)
        return out

    print(inputs.shape)

    # Encoder: keep each pre-pool block output as a skip connection.
    skips = []
    net = inputs
    for depth in (64, 128, 256):
        net = _double_conv(net, depth)
        print(net.shape)
        skips.append(net)
        net = slim.max_pool3d(inputs=net, kernel_size=2)
        print(net.shape)

    # Bottom of the U.
    net = _double_conv(net, 512)
    print(net.shape)

    # Decoder: consume skips deepest-first.
    for skip, depth in zip(reversed(skips), (256, 128, 64)):
        net = _up_block(net, skip, depth)

    # slim.repeat with a single repetition is kept so the head's variable
    # scope names stay the same as before.
    output = slim.repeat(inputs=net,
                         repetitions=1,
                         layer=slim.conv3d,
                         num_outputs=4,
                         activation_fn=tf.identity,
                         kernel_size=1,
                         normalizer_fn=slim.batch_norm)
    print(output.shape)

    return output
Example #9
0
def create_3D_UNet(x, features_root=16, n_classes=2):
    """Build a 5-level 3D U-Net and return all named intermediate tensors.

    All convs and transposed convs get a He-style variance-scaling
    initializer and a ``leaky_relu`` activation unless overridden.
    - Each level has two 3x3x3 convs, each followed by ``instance_norm``.
    - The first pooling and the last upsampling act only on the last two
      spatial axes ([1, 2, 2]); the others act on all three.
    - Upsampling uses bias-free, activation-free VALID transposed convs
      whose stride equals the kernel.

    Args:
        x: 5-D input tensor (concatenation uses axis 4).
        features_root: channel count of the first level, doubled per level.
        n_classes: channels of the output map.

    Returns:
        OrderedDict from layer name (e.g. ``'encode/conv1_1'``) to tensor,
        in build order; the final logits are under ``'out_map'``.
    """

    def conv_pair(inp, depth):
        # Two 3x3x3 convs, each followed by instance_norm.
        first = instance_norm(slim.conv3d(inp, depth, [3, 3, 3]))
        second = instance_norm(slim.conv3d(first, depth, [3, 3, 3]))
        return first, second

    def upsample(inp, depth, factor):
        # Transposed conv with kernel == stride, no bias and no activation.
        return slim.conv3d_transpose(inp, depth, factor,
                                     stride=factor,
                                     activation_fn=None,
                                     padding='VALID',
                                     biases_initializer=None)

    f = features_root
    he_init = initializers.variance_scaling_initializer(
        factor=2.0, mode='FAN_IN', uniform=False)
    net = OrderedDict()

    with slim.arg_scope([slim.conv3d, slim.conv3d_transpose],
                        weights_initializer=he_init,
                        activation_fn=leaky_relu):
        # ---- encoder ----
        net['encode/conv1_1'], net['encode/conv1_2'] = conv_pair(x, f)
        net['encode/pool1'] = slim.max_pool3d(net['encode/conv1_2'],
                                              kernel_size=[1, 2, 2],
                                              stride=[1, 2, 2])

        net['encode/conv2_1'], net['encode/conv2_2'] = conv_pair(
            net['encode/pool1'], f * 2)
        net['encode/pool2'] = slim.max_pool3d(net['encode/conv2_2'],
                                              kernel_size=[2, 2, 2],
                                              stride=[2, 2, 2])

        net['encode/conv3_1'], net['encode/conv3_2'] = conv_pair(
            net['encode/pool2'], f * 4)
        net['encode/pool3'] = slim.max_pool3d(net['encode/conv3_2'], [2, 2, 2])

        net['encode/conv4_1'], net['encode/conv4_2'] = conv_pair(
            net['encode/pool3'], f * 8)
        net['encode/pool4'] = slim.max_pool3d(net['encode/conv4_2'], [2, 2, 2])

        net['encode/conv5_1'], net['encode/conv5_2'] = conv_pair(
            net['encode/pool4'], f * 16)

        # ---- decoder ----
        net['decode/up_conv1'] = upsample(net['encode/conv5_2'], f * 8, [2, 2, 2])
        net['decode/concat_c4_u1'] = tf.concat(
            [net['encode/conv4_2'], net['decode/up_conv1']], 4)
        net['decode/conv1_1'], net['decode/conv1_2'] = conv_pair(
            net['decode/concat_c4_u1'], f * 8)

        net['decode/up_conv2'] = upsample(net['decode/conv1_2'], f * 4, [2, 2, 2])
        net['decode/concat_c3_u2'] = tf.concat(
            [net['encode/conv3_2'], net['decode/up_conv2']], 4)
        net['decode/conv2_1'], net['decode/conv2_2'] = conv_pair(
            net['decode/concat_c3_u2'], f * 4)

        net['decode/up_conv3'] = upsample(net['decode/conv2_2'], f * 2, [2, 2, 2])
        net['decode/concat_c2_u3'] = tf.concat(
            [net['encode/conv2_2'], net['decode/up_conv3']], 4)
        net['decode/conv3_1'], net['decode/conv3_2'] = conv_pair(
            net['decode/concat_c2_u3'], f * 2)

        net['decode/up_conv4'] = upsample(net['decode/conv3_2'], f, [1, 2, 2])
        net['decode/concat_c1_u4'] = tf.concat(
            [net['encode/conv1_2'], net['decode/up_conv4']], 4)
        net['decode/conv4_1'], net['decode/conv4_2'] = conv_pair(
            net['decode/concat_c1_u4'], f)

        # Output logits: 1x1x1 conv without activation, then instance_norm.
        net['out_map'] = instance_norm(
            slim.conv3d(net['decode/conv4_2'], n_classes, [1, 1, 1],
                        activation_fn=None))

    return net
Example #10
0
def unet_valid_sparese(vox_feat,
                       mask,
                       channels,
                       FLAGS,
                       trainable=True,
                       if_bn=False,
                       reuse=False,
                       is_training=True,
                       activation_fn=tf.nn.relu,
                       scope_name='unet_3d'):
    """Masked 3D U-Net over a voxel feature grid.

    Before every encoder conv and decoder transposed conv, the features are
    multiplied by a tiled copy of `mask`, downsampled (max-pool 4, stride 2,
    SAME) to the matching resolution. The mask is wrapped in
    ``tf.stop_gradient``, so no gradient flows into it.
    - Encoder: five stride-2 convs (16/32/64/128/256 channels).
    - Decoder: five stride-2 transposed convs, concatenating the matching
      encoder output after the first four and `vox_feat` after the last.
    - Head: a 3x3x3 conv producing 1-channel logits.

    Args:
        vox_feat: 5-D voxel features. The first mask tile uses 16 channels,
            so this presumably has 16 channels — confirm with callers.
        mask: 5-D mask at `vox_feat`'s resolution; tiling implies 1 channel.
        channels: output channels of the last transposed conv.
        FLAGS: config object; reads ``bn_decay`` and ``if_l2Reg``.
        trainable: whether layer variables are trainable.
        if_bn: apply ``slim.batch_norm`` to the hidden layers.
        reuse: reuse variables of an existing `scope_name` scope.
        is_training: batch-norm training mode.
        activation_fn: default activation of hidden layers.
        scope_name: variable scope name.

    Returns:
        Tuple ``(sigmoid(logits), logits)``.
    """

    def _apply_mask(feat, m, depth):
        # Multiply features by the mask tiled to `depth` channels.
        return feat * tf.tile(m, [1, 1, 1, 1, depth])

    def _down_mask(m):
        # Downsample the mask to the next encoder resolution (no gradient).
        return tf.stop_gradient(
            slim.max_pool3d(m, kernel_size=4, stride=2, padding='SAME'))

    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()

        if if_bn:
            batch_normalizer_gen = slim.batch_norm
            batch_norm_params_gen = {
                'is_training': is_training,
                'decay': FLAGS.bn_decay
            }
        else:
            batch_normalizer_gen = None
            batch_norm_params_gen = None

        if FLAGS.if_l2Reg:
            weights_regularizer = slim.l2_regularizer(1e-5)
        else:
            weights_regularizer = None

        with slim.arg_scope(
                [slim.fully_connected, slim.conv3d, slim.conv3d_transpose],
                activation_fn=activation_fn,
                trainable=trainable,
                normalizer_fn=batch_normalizer_gen,
                normalizer_params=batch_norm_params_gen,
                weights_regularizer=weights_regularizer):

            # masks[i] is the mask at the input resolution of encoder level i+1
            # (masks[0] is full resolution; masks[5] is below the last level).
            masks = [tf.stop_gradient(mask)]
            downs = []
            feat = vox_feat
            # (level, mask tile channels == input channels, output channels)
            for level, in_depth, out_depth in ((1, 16, 16), (2, 16, 32),
                                               (3, 32, 64), (4, 64, 128),
                                               (5, 128, 256)):
                feat = slim.conv3d(_apply_mask(feat, masks[-1], in_depth),
                                   out_depth,
                                   kernel_size=4,
                                   stride=2,
                                   padding='SAME',
                                   scope='unet_conv%d' % level)
                downs.append(feat)
                masks.append(_down_mask(masks[-1]))

            # Decoder: (level, mask tile channels == input channels, output
            # channels). The output is concatenated with encoder level `level`.
            x = downs[-1]
            for level, in_depth, out_depth in ((4, 256, 128), (3, 256, 64),
                                               (2, 128, 32), (1, 64, 16)):
                up = slim.conv3d_transpose(
                    _apply_mask(x, masks[level + 1], in_depth),
                    out_depth,
                    kernel_size=4,
                    stride=2,
                    padding='SAME',
                    scope='unet_deconv%d' % level)
                x = tf.concat([up, downs[level - 1]], axis=-1)

            # Back to the input resolution, fused with the raw input features.
            net_up0 = slim.conv3d_transpose(_apply_mask(x, masks[1], 32),
                                            channels,
                                            kernel_size=4,
                                            stride=2,
                                            padding='SAME',
                                            scope='unet_deconv0')
            net_up0_ = tf.concat([net_up0, vox_feat], axis=-1)
            net_out_ = slim.conv3d(net_up0_, 1,
                                   kernel_size=3,
                                   stride=1,
                                   padding='SAME',
                                   activation_fn=None,
                                   normalizer_fn=None,
                                   normalizer_params=None,
                                   scope='unet_deconv_out')

    return tf.nn.sigmoid(net_out_), net_out_
Example #11
0
def voxel_net_3d_v2(inputs,
                    aux=None,
                    bn=True,
                    bn_trainmode='train',
                    freeze_decoder=False,
                    d0=16,
                    return_logits=False,
                    return_feats=False,
                    debug=False):
    """Build a 3D U-Net-style encoder/decoder and return per-voxel sigmoid outputs.

    The architecture is selected from the size of axis 1 of ``inputs``
    (128, 64 or 32). The encoder is a stack of ``slim.conv3d`` layers that
    downsamples to a 1x1x1 bottleneck (its last layer is a VALID conv). The
    decoder mirrors it with ``slim.conv3d_transpose`` layers; after each one
    the output is concatenated along axis 4 with the matching encoder
    activation (skip connection). A final 1x1x1 conv reduces the result to a
    single channel of logits.

    Args:
        inputs: 5-D tensor whose axis-1 size is 128, 64 or 32. Presumably laid
            out as (batch, D, H, W, C) -- TODO confirm with callers.
        aux: optional rank-2 tensor ``(batch, aux_dim)`` concatenated onto the
            1x1x1 bottleneck features.
        bn: if True, every conv in the arg_scope is followed by
            ``slim.batch_norm``.
        bn_trainmode: forwarded as batch norm's ``is_training``.
        freeze_decoder: if True, the decoder's transposed convs, the final
            1x1x1 conv and the batch-norm parameters are created with
            ``trainable=False``.
        d0: channel count of the first encoder layer; deeper layers use
            multiples of it.
        return_logits: also return the pre-sigmoid logits.
        return_feats: also return the decoder activation (after skip
            concatenation) whose axis-1 size is 32.
        debug: if True, emit histogram summaries for intermediate tensors.

    Returns:
        The sigmoid output alone, or a tuple ``(output, [logits], [feats])``
        containing only the requested extras, in that order.

    Raises:
        Exception: if the axis-1 size of ``inputs`` is unsupported.
    """
    decoder_trainable = not freeze_decoder
    input_size = list(inputs.get_shape())[1]
    if input_size == 128:
        arch = 'marr128'
    elif input_size == 64:
        arch = 'marr64'
    elif input_size == 32:
        arch = 'marr32'
    else:
        # Call form works on Python 2 and 3 (`raise E, msg` is Python-2-only).
        raise Exception('input size not supported')

    if aux is not None:
        assert tfutil.rank(aux) == 2
        aux_dim = int(tuple(aux.get_shape())[1])

    # NOTE(review): these params apply to every conv3d / conv3d_transpose in
    # the arg_scope below, so `freeze_decoder` also freezes the encoder's
    # batch-norm scale/offset -- confirm this is intended.
    normalizer_params = {
        'is_training': bn_trainmode,
        'decay': 0.9,
        'epsilon': 1e-5,
        'scale': True,
        'updates_collections': None,
        'trainable': decoder_trainable
    }

    with slim.arg_scope([slim.conv3d, slim.conv3d_transpose],
                        activation_fn=tf.nn.relu,
                        normalizer_fn=slim.batch_norm if bn else None,
                        normalizer_params=normalizer_params):

        if arch == 'marr128':
            # 128 -> 64 -> 32 -> 16 -> 8 -> 1
            dims = [d0, 2 * d0, 4 * d0, 8 * d0, 16 * d0]
            ksizes = [4, 4, 4, 4, 8]
            strides = [2, 2, 2, 2, 1]
            paddings = ['SAME'] * 4 + ['VALID']
        elif arch == 'marr64':
            # 64 -> 32 -> 16 -> 8 -> 4 -> 1
            dims = [d0, 2 * d0, 4 * d0, 8 * d0, 16 * d0]
            ksizes = [4, 4, 4, 4, 4]
            strides = [2, 2, 2, 2, 1]
            paddings = ['SAME'] * 4 + ['VALID']
        elif arch == 'marr32':
            # 32 -> 16 -> 8 -> 4 -> 1
            dims = [d0, 2 * d0, 4 * d0, 8 * d0]
            ksizes = [4, 4, 4, 4]
            strides = [2, 2, 2, 1]
            paddings = ['SAME'] * 3 + ['VALID']

        net = inputs

        if debug:
            summ.histogram('voxel_net_3d_input', net)

        # skipcons[k] is the input to encoder layer k; the decoder consumes
        # them in reverse order.
        skipcons = [net]
        for i, (dim, ksize, stride,
                padding) in enumerate(zip(dims, ksizes, strides, paddings)):
            net = slim.conv3d(net, dim, ksize, stride=stride, padding=padding)
            skipcons.append(net)
            if debug:
                summ.histogram('voxel_net_3d_enc_%d' % i, net)

        if aux is not None:
            # The bottleneck is 1x1x1 (final VALID conv), so aux is a single voxel.
            aux = tf.reshape(aux, (-1, 1, 1, 1, aux_dim))
            net = tf.concat([aux, net], axis=4)

        if arch == 'marr128':
            chans = [8 * d0, 4 * d0, 2 * d0, d0, 1]
            strides = [1, 2, 2, 2, 2]
            ksizes = [8, 4, 4, 4, 4]
            paddings = ['VALID'] + ['SAME'] * 4
        elif arch == 'marr64':
            chans = [8 * d0, 4 * d0, 2 * d0, d0, 1]
            strides = [1, 2, 2, 2, 2]
            ksizes = [4, 4, 4, 4, 4]
            paddings = ['VALID'] + ['SAME'] * 4
        elif arch == 'marr32':
            chans = [4 * d0, 2 * d0, d0, 1]
            strides = [1, 2, 2, 2]
            ksizes = [4, 4, 4, 4]
            paddings = ['VALID'] + ['SAME'] * 3

        skipcons.pop()  # the innermost activation is not used as a skip connection

        # Bound up front so `return_feats` never hits an unbound name.
        feats = None
        for i, (chan, stride, ksize,
                padding) in enumerate(zip(chans, strides, ksizes, paddings)):

            net = slim.conv3d_transpose(net,
                                        chan,
                                        ksize,
                                        stride=stride,
                                        padding=padding,
                                        trainable=decoder_trainable)
            # Concatenate the matching encoder activation (skip connection).
            net = tf.concat([net, skipcons.pop()], axis=4)

            if net.shape[1] == 32:
                feats = net

            if debug:
                summ.histogram('voxel_net_3d_dec_%d' % i, net)

        # One last 1x1x1 conv to get a single output channel (raw logits).
        net = slim.conv3d(net,
                          1,
                          1,
                          1,
                          padding='SAME',
                          activation_fn=None,
                          normalizer_fn=None,
                          trainable=decoder_trainable)
        if debug:
            summ.histogram('voxel_net_3d_logits', net)

    net_ = tf.nn.sigmoid(net)
    if debug:
        summ.histogram('voxel_net_3d_output', net_)

    rvals = [net_]
    if return_logits:
        rvals.append(net)
    if return_feats:
        rvals.append(feats)
    if len(rvals) == 1:
        return rvals[0]
    return tuple(rvals)
Example #12
0
def voxel_net_3d(inputs, aux=None, bn=True, outsize=128, d0=16):
    """Build a 3D encoder/decoder with skip connections over a voxel grid.

    The layer configuration is chosen by ``const.NET3DARCH`` ('3x3', 'marr',
    'marr_small' or 'marr_64'). Batch-norm mode is derived from
    ``const.mode``, ``const.rpvx_unsup`` and the ``const.force_batchnorm_*``
    overrides. Decoder layers are created with ``trainable=False`` when
    ``const.rpvx_unsup`` is set.

    Args:
        inputs: 5-D voxel tensor; presumably ``B x S x S x S x 25`` as the
            original note said -- TODO confirm with callers.
        aux: optional rank-2 tensor ``(batch, aux_dim)`` (category input)
            concatenated onto the bottleneck features.
        bn: if False, no batch norm is applied anywhere in the network.
        outsize: unused; kept for interface compatibility.
        d0: channel count of the first encoder layer; deeper layers use
            multiples of it (the '3x3' decoder uses fixed channel counts).

    Returns:
        Sigmoid of the final 1x1x1 conv output.

    Raises:
        Exception: if ``const.NET3DARCH`` is not a supported architecture.
    """
    if aux is not None:
        assert tfutil.rank(aux) == 2
        aux_dim = int(tuple(aux.get_shape())[1])

    bn_trainmode = ((const.mode != 'test') and (not const.rpvx_unsup))
    if const.force_batchnorm_trainmode:
        bn_trainmode = True
    if const.force_batchnorm_testmode:
        bn_trainmode = False

    normalizer_params = {
        'is_training': bn_trainmode,
        'decay': 0.9,
        'epsilon': 1e-5,
        'scale': True,
        'updates_collections': None
    }
    # The explicit normalizer_fn overrides below must honour `bn` as well,
    # otherwise bn=False would still add batch norm to the decoder.
    batch_norm_fn = slim.batch_norm if bn else None

    with slim.arg_scope([slim.conv3d, slim.conv3d_transpose],
                        activation_fn=tf.nn.relu,
                        normalizer_fn=batch_norm_fn,
                        normalizer_params=normalizer_params):

        # Encoder configuration.
        if const.NET3DARCH == '3x3':
            dims = [d0, 2 * d0, 4 * d0, 8 * d0, 16 * d0]
            ksizes = [3, 3, 3, 3, 3]
            strides = [2, 2, 2, 2, 2]
            paddings = ['SAME'] * 5
        elif const.NET3DARCH == 'marr':
            dims = [d0, 2 * d0, 4 * d0, 8 * d0, 16 * d0]
            ksizes = [4, 4, 4, 4, 8]
            strides = [2, 2, 2, 2, 1]
            paddings = ['SAME'] * 4 + ['VALID']
        elif const.NET3DARCH == 'marr_small':
            # 32 -> 16 -> 8 -> 4 -> 1
            dims = [d0, 2 * d0, 4 * d0, 8 * d0]
            ksizes = [4, 4, 4, 4]
            strides = [2, 2, 2, 1]
            paddings = ['SAME'] * 3 + ['VALID']
        elif const.NET3DARCH == 'marr_64':
            # 64 -> 32 -> 16 -> 8 -> 4 -> 1
            dims = [d0, 2 * d0, 4 * d0, 8 * d0, 16 * d0]
            ksizes = [4, 4, 4, 4, 4]
            strides = [2, 2, 2, 2, 1]
            paddings = ['SAME'] * 4 + ['VALID']
        else:
            raise Exception('unsupported network architecture')

        net = inputs
        # skipcons[k] is the input to encoder layer k; the decoder consumes
        # them in reverse order.
        skipcons = [net]
        for dim, ksize, stride, padding in zip(dims, ksizes, strides, paddings):
            net = slim.conv3d(net, dim, ksize, stride=stride, padding=padding)
            skipcons.append(net)

        if aux is not None:
            aux = tf.reshape(aux, (-1, 1, 1, 1, aux_dim))
            if const.NET3DARCH == '3x3':
                # Hardcoded: five stride-2 SAME convs leave a 4x4x4
                # bottleneck only for 128^3 inputs.
                aux = tf.tile(aux, (1, 4, 4, 4, 1))
            # Every 'marr*' encoder ends in a VALID conv with a 1x1x1 output,
            # so aux concatenates directly without tiling.
            net = tf.concat([aux, net], axis=4)

        # Decoder configuration; the last layer must have no activation.
        if const.NET3DARCH == '3x3':
            chans = [128, 64, 32, 16, 1]
            strides = [2, 2, 2, 2, 2]
            ksizes = [3, 3, 3, 3, 3]
            paddings = ['SAME'] * 5
            activation_fns = [tf.nn.relu] * 4 + [None]
        elif const.NET3DARCH == 'marr':
            chans = [8 * d0, 4 * d0, 2 * d0, d0, 1]
            strides = [1, 2, 2, 2, 2]
            ksizes = [8, 4, 4, 4, 4]
            paddings = ['VALID'] + ['SAME'] * 4
            activation_fns = [tf.nn.relu] * 4 + [None]
        elif const.NET3DARCH == 'marr_small':
            chans = [4 * d0, 2 * d0, d0, 1]
            strides = [1, 2, 2, 2]
            ksizes = [4, 4, 4, 4]
            paddings = ['VALID'] + ['SAME'] * 3
            activation_fns = [tf.nn.relu] * 3 + [None]
        elif const.NET3DARCH == 'marr_64':
            chans = [8 * d0, 4 * d0, 2 * d0, d0, 1]
            strides = [1, 2, 2, 2, 2]
            ksizes = [4, 4, 4, 4, 4]
            paddings = ['VALID'] + ['SAME'] * 4
            activation_fns = [tf.nn.relu] * 4 + [None]
        else:
            raise Exception('unsupported network architecture')

        decoder_trainable = not const.rpvx_unsup  # don't ruin the decoder by fine-tuning

        skipcons.pop()  # the innermost activation is not used as a skip connection

        last = len(chans) - 1
        for i, (chan, stride, ksize, padding, activation_fn) in enumerate(
                zip(chans, strides, ksizes, paddings, activation_fns)):
            # The last deconv gets no batch norm (and no activation).
            norm_fn = None if i == last else batch_norm_fn

            net = slim.conv3d_transpose(net,
                                        chan,
                                        ksize,
                                        stride=stride,
                                        padding=padding,
                                        activation_fn=activation_fn,
                                        normalizer_fn=norm_fn,
                                        trainable=decoder_trainable)

            # Concatenate the matching encoder activation (skip connection).
            net = tf.concat([net, skipcons.pop()], axis=4)

        # One last 1x1x1 conv to get a single output channel.
        net = slim.conv3d(net,
                          1,
                          1,
                          1,
                          padding='SAME',
                          activation_fn=None,
                          normalizer_fn=batch_norm_fn,
                          trainable=decoder_trainable)

    net = tf.nn.sigmoid(net)

    return net
Example #13
0
def voxel_net(inputs,
              aux=None,
              bn=True,
              outsize=128,
              built_in_transform=False):
    """Encode a 2D input into a latent vector and decode it into a voxel grid.

    A ``slim.conv2d`` encoder produces ``const.VOXNET_LATENT`` channels,
    optionally concatenated with ``aux``, followed by two fully connected
    layers. The result is reshaped to a 1x1x1 volume and decoded with
    ``slim.conv3d_transpose`` layers into an ``outsize``^3 grid, then passed
    through a sigmoid.

    Args:
        inputs: 4-D image tensor, presumably (batch, H, W, C) -- TODO confirm.
        aux: optional tensor reshaped to ``(const.BS, 1, 1, -1)`` and
            concatenated (axis 3) onto the encoder output.
        bn: if False, no batch norm is applied anywhere in the network.
        outsize: output grid size; 128 or 32.
        built_in_transform: if True, predict a softmax over ``const.V``
            candidate rotations and rotate the output voxels by the
            probability-weighted blend of their rotation matrices.

    Returns:
        The sigmoid voxel tensor (rotated when ``built_in_transform``).

    Raises:
        Exception: if ``outsize`` is not 128 or 32.
    """
    bn_trainmode = ((const.mode != 'test') and (not const.rpvx_unsup))
    if const.force_batchnorm_trainmode:
        bn_trainmode = True
    if const.force_batchnorm_testmode:
        bn_trainmode = False

    normalizer_params = {
        'is_training': bn_trainmode,
        'decay': 0.9,
        'epsilon': 1e-5,
        'scale': True,
        'updates_collections': None
    }
    # The explicit normalizer_fn overrides in the decoder must honour `bn`
    # too, otherwise bn=False would still add batch norm there.
    batch_norm_fn = slim.batch_norm if bn else None

    with slim.arg_scope(
        [slim.conv2d, slim.conv3d_transpose, slim.fully_connected],
            activation_fn=tf.nn.relu,
            normalizer_fn=batch_norm_fn,
            normalizer_params=normalizer_params):

        # Encoder.
        dims = [64, 128, 256, 512, const.VOXNET_LATENT]
        ksizes = [11, 5, 5, 5, 8]
        strides = [4, 2, 2, 2, 1]
        paddings = ['SAME'] * 4 + ['VALID']

        net = inputs
        for i, (dim, ksize, stride,
                padding) in enumerate(zip(dims, ksizes, strides, paddings)):
            if const.DEBUG_HISTS:
                tf.summary.histogram('encoder_%d' % i, net)
            net = slim.conv2d(net, dim, ksize, stride=stride, padding=padding)

        if aux is not None:
            aux = tf.reshape(aux, (const.BS, 1, 1, -1))
            net = tf.concat([aux, net], axis=3)

        # Two FC layers, as prescribed.
        for _ in range(2):
            net = slim.fully_connected(net, const.VOXNET_LATENT)

        if built_in_transform:
            tnet = slim.fully_connected(net, 128)
            tnet = slim.fully_connected(tnet, 128)
            # One logit per candidate rotation. This must equal const.V for
            # the reshape below (it was previously hardcoded to 20).
            # NOTE(review): activation_fn is inherited from the arg_scope
            # (relu), so these logits are rectified before the softmax --
            # confirm that is intended.
            pose_logits = slim.fully_connected(tnet, const.V, normalizer_fn=None)
            pose_ = tf.nn.softmax(pose_logits)

            angles = [20.0 * k for k in range(const.V)]
            # Build a real list: on Python 3 `map` returns a lazy iterator,
            # which tf.stack cannot consume.
            rot_mats = [
                tf.constant(utils.voxel.get_transform_matrix(angle),
                            dtype=tf.float32) for angle in angles
            ]
            rot_mats = tf.expand_dims(tf.stack(rot_mats, axis=0), axis=0)
            pose_ = tf.reshape(pose_, (const.BS, const.V, 1, 1))
            # Probability-weighted blend of the candidate rotation matrices.
            rot_mat = tf.reduce_sum(rot_mats * pose_, axis=1)

            print(rot_mat)

        # Latent vector -> 1x1x1 volume for the 3D decoder.
        net = tf.reshape(net, (const.BS, 1, 1, 1, -1))

        if outsize == 128:
            chans = [256, 128, 64, 32, 1]
            strides = [1] + [2] * 4
            ksizes = [8] + [4] * 4
            paddings = ['VALID'] + ['SAME'] * 4
            activation_fns = [tf.nn.relu] * 4 + [None]

        elif outsize == 32:
            chans = [256, 128, 64, 1]
            strides = [1, 2, 2, 2]
            ksizes = [4, 2, 2, 2]
            paddings = ['VALID'] + ['SAME'] * 3
            activation_fns = [tf.nn.relu] * 3 + [None]

        else:
            raise Exception('unsupported outsize %d' % outsize)

        decoder_trainable = not const.rpvx_unsup  # don't ruin the decoder by fine-tuning

        last = len(chans) - 1
        for i, (chan, stride, ksize, padding, activation_fn) in enumerate(
                zip(chans, strides, ksizes, paddings, activation_fns)):

            if const.DEBUG_HISTS:
                tf.summary.histogram('decoder_%d' % i, net)

            # The final single-channel layer has no activation and no batch norm.
            norm_fn = None if i == last else batch_norm_fn

            net = slim.conv3d_transpose(net,
                                        chan,
                                        ksize,
                                        stride=stride,
                                        padding=padding,
                                        activation_fn=activation_fn,
                                        normalizer_fn=norm_fn,
                                        trainable=decoder_trainable)

        if const.DEBUG_HISTS:
            tf.summary.histogram('before_sigmoid', net)

    net = tf.nn.sigmoid(net)
    if const.DEBUG_HISTS:
        tf.summary.histogram('after_sigmoid', net)

    if built_in_transform:
        net = voxel.rotate_voxel(net, rot_mat)

    return net