def C3D(input_data, num_classes, keep_prob=0.5, non_local=False):
    with tf.variable_scope('C3D'):
        with slim.arg_scope([slim.conv3d],
                            padding='SAME',
                            weights_regularizer=slim.l2_regularizer(0.0005),
                            activation_fn=tf.nn.relu,
                            kernel_size=[3, 3, 3],
                            stride=[1, 1, 1]
                            ):
            # Batch * 16 * 112 * 112 * 3
            net = slim.conv3d(input_data, 64, scope='conv1')
            net = slim.max_pool3d(net, kernel_size=[1, 2, 2], stride=[1, 2, 2], padding='SAME', scope='max_pool1')
            # net = NonLocalBlock(net, 64, scope='nonlocal_block_1')

            # Batch * 16 * 56 * 56 * 64
            # conv2 + pool2: standard C3D, matching the shape comments
            net = slim.conv3d(net, 128, scope='conv2')
            net = slim.max_pool3d(net, kernel_size=[2, 2, 2], stride=[2, 2, 2], padding='SAME', scope='max_pool2')
            if non_local:
                net = NonLocalBlock(net, 128, scope='nonlocal_block_2')

            # Batch * 8 * 28 * 28 * 128 
            # conv3a/conv3b + pool3: standard C3D, matching the shape comments
            net = slim.repeat(net, 2, slim.conv3d, 256, scope='conv3')
            net = slim.max_pool3d(net, kernel_size=[2, 2, 2], stride=[2, 2, 2], padding='SAME', scope='max_pool3')
            if non_local:
                net = NonLocalBlock(net, 256, scope='nonlocal_block_3')

            # Batch * 4 * 14 * 14 * 256
            # conv4a/conv4b + pool4: standard C3D, matching the shape comments
            net = slim.repeat(net, 2, slim.conv3d, 512, scope='conv4')
            net = slim.max_pool3d(net, kernel_size=[2, 2, 2], stride=[2, 2, 2], padding='SAME', scope='max_pool4')
            if non_local:
                net = NonLocalBlock(net, 512, scope='nonlocal_block_4')
            # Batch * 2 * 7 * 7 * 512
            net = slim.repeat(net, 2, slim.conv3d, 512, scope='conv5')
            net = slim.max_pool3d(net, kernel_size=[2, 2, 2], stride=[2, 2, 2], padding='SAME', scope='max_pool5')

            # Batch * 1 * 4 * 4 * 512
            net = tf.reshape(net, [-1, 512 * 4 * 4])
            net = slim.fully_connected(net, 4096, weights_regularizer=slim.l2_regularizer(0.0005), scope='fc6')
            net = slim.dropout(net, keep_prob, scope='dropout1')
            net = slim.fully_connected(net, 4096, weights_regularizer=slim.l2_regularizer(0.0005), scope='fc7')
            net = slim.dropout(net, keep_prob, scope='dropout2')
            out = slim.fully_connected(net, num_classes, weights_regularizer=slim.l2_regularizer(0.0005),
                                       activation_fn=None, scope='out')

            return out
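A minimal usage sketch, using the clip shape from the comments above; num_classes=101 is illustrative, and with non_local=False the NonLocalBlock helper (assumed to be defined elsewhere in the file) is never called:

import tensorflow as tf
import tensorflow.contrib.slim as slim

clips = tf.placeholder(tf.float32, [None, 16, 112, 112, 3], name='clips')
logits = C3D(clips, num_classes=101, keep_prob=0.5, non_local=False)  # Batch x 101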
Example #2
def inference_t3d(frames, feature_size, _dropout,
                  block_config=(6, 12, 24, 16)):
    global IS_TRAIN
    IS_TRAIN = True

    #dbg_op = tf.Print(IS_TRAIN, [IS_TRAIN], message="IS_TRAIN:")
    with slim.arg_scope([slim.conv3d],
                        weights_initializer=weights_initializer,
                        weights_regularizer=slim.l2_regularizer(1e-7)):
        #print('input_shape:', frames.shape.as_list())
        out = slim.conv3d(frames,
                          START_CHANNEL, [3, 7, 7],
                          stride=[1, 2, 2],
                          padding='SAME',
                          biases_initializer=None)
        #print('conv1_shape:', out.shape.as_list())
        out = tf.layers.batch_normalization(out, training=IS_TRAIN)
        out = tf.nn.relu(out)
        out = slim.max_pool3d(out,
                              kernel_size=[3, 3, 3],
                              stride=2,
                              padding='SAME')
        #print('max3d_shape:', out.shape.as_list())
        in_channels = START_CHANNEL

        for i, num_layers in enumerate(block_config):

            out = build_block(out, num_layers, in_channels, _dropout)
            #print('block_shape:', out.shape.as_list())
            in_channels = in_channels + GROWTH_RATE * num_layers
            if i != len(block_config) - 1:
                if i == 0:
                    out = TTL(out, (1, 3, 6))
                else:
                    out = TTL(out)
                #print('ttl3d_shape:', out.shape.as_list())
                in_channels = 128 * 3
                out = Transition(out, in_channels // 2)
                #print('trans_shape:', out.shape.as_list())
                in_channels = in_channels // 2

        out = tf.layers.batch_normalization(out, training=IS_TRAIN)
        out = tf.nn.relu(out)
        # Standard input shape is [BATCH, NUM_CLIP=16, HEIGHT=160, WIDTH=160, RGB=3],
        # which makes the kernel size of the avg_pool3d below equal to 5.
        # If you change the input size, change the kernel size of avg_pool3d accordingly.
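        # Hypothetical size trace for a 160x160 input (assuming each Transition
        # stage halves the spatial size): stem conv (stride 2) -> 80,
        # max_pool3d (stride 2) -> 40, then three Transition stages -> 20 -> 10 -> 5.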
        out = slim.avg_pool3d(out, kernel_size=[1, 5, 5])
        #print('avg3d_shape:', out.shape.as_list())
        out = tf.reshape(out, [out.get_shape().as_list()[0], -1])
        #print('fc_re_shape:', out.shape.as_list())
        out = slim.fully_connected(out, feature_size)
        return out
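A hedged call sketch; the clip shape follows the comment above, while feature_size=1024 and _dropout=0.5 are illustrative. Note that the final tf.reshape uses the static batch dimension, so the batch size must be fixed:

frames = tf.placeholder(tf.float32, [8, 16, 160, 160, 3], name='frames')
features = inference_t3d(frames, feature_size=1024, _dropout=0.5)  # 8 x 1024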
Example #3
def reduce_net(net, num_out=None, scope=None, reuse=None):
    """ Halve the spatial scale while doubling the feature size. """
    # return incept_resnet.conv_maxpool(net, scope, reuse)
    num = int(net.shape[-1].value)
    if num_out is None:
        num_out = num
    num2 = (num_out >> 1)
    num4 = (num2 >> 1)
    # num8 = (num4 >> 1)
    sc_current = 'reduce_net_{}'.format(num2)
    with tf.variable_scope(scope, sc_current, [net], reuse=reuse):
        with tf.variable_scope('branch0'):
            tower0 = slim.max_pool3d(net, 3, stride=2)
        with tf.variable_scope('branch1'):
            tower1 = slim.conv3d(net, num4, 1, stride=1)
            tower1 = slim.conv3d(tower1, num2, 3, stride=2)
        with tf.variable_scope('branch2'):
            tower2 = slim.conv3d(net, num4, 1, stride=1)
            tower2 = slim.conv3d(tower2, num4, 3, stride=1)
            tower2 = slim.conv3d(tower2, num2, 3, stride=2)
        net = tf.concat(axis=-1, values=[tower0, tower1, tower2])
    return net
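Channel bookkeeping: tower0 keeps the input's num channels while tower1 and tower2 each emit num2 = num_out // 2, so the concat has num + num_out channels; with the default num_out = num the feature size doubles, as the docstring says. A hypothetical shape check, assuming an enclosing arg_scope sets padding='SAME' for slim.conv3d and slim.max_pool3d (as the get_model examples below do):

x = tf.zeros([2, 16, 16, 16, 32])
y = reduce_net(x)  # -> [2, 8, 8, 8, 64]: scale halved, channels doubled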
Example #4
    def hourglass3d(net, n, scope=None, reuse=None):
        num = int(net.shape[-1].value)
        sc_current = 'hourglass3d_{}'.format(n)
        with tf.variable_scope(scope, sc_current, [net], reuse=reuse):
            upper0 = inresnet3d.resnet_k(net)

            lower0 = slim.max_pool3d(net, 3, stride=2)
            lower0 = inresnet3d.resnet_k(lower0)

            lower0 = slim.conv3d(lower0, num * 2, 1, stride=1)

            if n > 1:
                lower1 = inresnet3d.hourglass3d(lower0, n - 1)
            else:
                lower1 = lower0

            lower1 = slim.conv3d(lower1, num, 1, stride=1)

            lower2 = inresnet3d.resnet_k(lower1)
            upper1 = slim.conv3d_transpose(
                lower2, num, 3, stride=2)
            return upper0 + upper1
Example #5
    def get_model(self,
                  input_tensor,
                  is_training,
                  bn_decay,
                  regu_scale,
                  scope=None,
                  final_endpoint='stage_out'):
        """ input_tensor: BxHxWxC
            out_dim: Bx(Jx3), where J is number of joints
        """
        # batch_size = frames_tf.get_shape()[0].value
        end_points = {}
        self.end_point_list = []

        def add_and_check_final(name, net):
            end_points[name] = net
            return name == final_endpoint

        with tf.variable_scope(scope, self.name_desc, [input_tensor]):
            bn_epsilon = 0.001
            with \
                slim.arg_scope(
                    [slim.batch_norm],
                    is_training=is_training,
                    epsilon=bn_epsilon,
                    # # Make sure updates happen automatically
                    # updates_collections=None,
                    # Try zero_debias_moving_mean=True for improved stability.
                    # zero_debias_moving_mean=True,
                    decay=bn_decay), \
                slim.arg_scope(
                    [slim.dropout],
                    is_training=is_training), \
                slim.arg_scope(
                    [slim.fully_connected],
                    weights_regularizer=slim.l2_regularizer(regu_scale),
                    biases_regularizer=slim.l2_regularizer(regu_scale),
                    activation_fn=tf.nn.relu,
                    normalizer_fn=slim.batch_norm), \
                slim.arg_scope(
                    [slim.max_pool3d, slim.avg_pool3d],
                    stride=2, padding='SAME'), \
                slim.arg_scope(
                    [slim.conv3d_transpose],
                    stride=2, padding='SAME',
                    weights_regularizer=slim.l2_regularizer(regu_scale),
                    biases_regularizer=slim.l2_regularizer(regu_scale),
                    activation_fn=tf.nn.relu,
                    normalizer_fn=slim.batch_norm), \
                slim.arg_scope(
                    [slim.conv3d],
                    stride=1, padding='SAME',
                    weights_regularizer=slim.l2_regularizer(regu_scale),
                    biases_regularizer=slim.l2_regularizer(regu_scale),
                    activation_fn=tf.nn.relu,
                    normalizer_fn=slim.batch_norm):
                with tf.variable_scope('stage0'):
                    sc = 'stage0'
                    net = slim.conv3d(input_tensor, 16, 3, scope='conv0_3x3_1')
                    net = slim.conv3d(net,
                                      32,
                                      3,
                                      stride=2,
                                      scope='conv0_3x3_2')
                    net = slim.max_pool3d(net, 3, scope='maxpool0_3x3_1')
                    self.end_point_list.append(sc)
                    if add_and_check_final(sc, net):
                        return net, end_points
                with tf.variable_scope('stage1'):
                    sc = 'stage1'
                    net = slim.conv3d(net, 64, 3, scope='conv1_3x3_1')
                    net = slim.max_pool3d(net,
                                          3,
                                          stride=2,
                                          scope='maxpool1_3x3_2')
                    self.end_point_list.append(sc)
                    if add_and_check_final(sc, net):
                        return net, end_points
                # with tf.variable_scope('stage2'):
                #     sc = 'stage2'
                #     net = slim.conv3d(net, 64, 3, scope='conv2_3x3_1')
                #     net = slim.max_pool3d(
                #         net, 3, stride=2, scope='maxpool2_3x3_2')
                #     self.end_point_list.append(sc)
                #     if add_and_check_final(sc, net):
                #         return net, end_points
                with tf.variable_scope('stage8'):
                    sc = 'stage_out'
                    net = slim.max_pool3d(net,
                                          5,
                                          stride=3,
                                          padding='VALID',
                                          scope='maxpool8_5x5_3')
                    net = slim.conv3d(net, 128, 1, scope='reduce8')
                    net = slim.conv3d(net,
                                      256,
                                      net.get_shape()[1:4],
                                      padding='VALID',
                                      scope='fullconn8')
                    net = slim.flatten(net)
                    net = slim.dropout(net,
                                       0.5,
                                       is_training=is_training,
                                       scope='dropout8')
                    net = slim.fully_connected(net,
                                               self.out_dim,
                                               activation_fn=None,
                                               scope='output8')
                    # self.end_point_list.append(sc)
                    if add_and_check_final(sc, net):
                        return net, end_points

        raise ValueError('final_endpoint (%s) not recognized' % final_endpoint)
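A hedged sketch of the final_endpoint mechanism; the instance name pose_net, the voxel shape, and the hyper-parameters below are assumptions:

# Build the graph only up to 'stage1' and inspect the collected endpoints.
voxels = tf.placeholder(tf.float32, [4, 64, 64, 64, 1])
net, end_points = pose_net.get_model(voxels, is_training=True,
                                     bn_decay=0.997, regu_scale=1e-4,
                                     final_endpoint='stage1')
print(sorted(end_points.keys()))  # ['stage0', 'stage1']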
Example #6
    def model(self, video, mode, only_endpoints=False, final_endpoint=''):
        """Create the model graph.

        Args:
          video: a BxTxHxWxC video tensor.
          mode: string, 'train' or 'eval'.
          only_endpoints: whether to return only the endpoints.
          final_endpoint: specifies the endpoint to construct the network up to.
              If not specified, the entire network is constructed and returned.
              Only used if only_endpoints is True.

        Returns:
          loss, accuracy and logits, or endpoints.
        """
        self.is_training = (mode == 'train')
        is_training = self.is_training
        data_format = self.data_format

        endpoints = {}

        def add_and_check_endpoint(net, endpoint):
            endpoints[endpoint] = net
            return only_endpoints and final_endpoint == endpoint

        with slim.arg_scope([slim.conv2d], padding='SAME'):
            with tf.variable_scope('VidIncRes', 'VidIncRes', [video]):
                with slim.arg_scope([slim.batch_norm, slim.dropout],
                                    is_training=is_training):
                    net = video

                    conv_op = self.get_layer_type(self.spec.convop1)
                    net = conv_op(net,
                                  64, [self.spec.time1, 7, 7],
                                  strides=[2, 2, 2],
                                  scope='Conv2d_1a_7x7',
                                  dilation=self.spec.dilation)
                    net = batch_norm_relu(net,
                                          is_training,
                                          relu=True,
                                          data_format=data_format)
                    if add_and_check_endpoint(net, 'Conv2d_1a_7x7'):
                        return endpoints

                    net = slim.max_pool3d(net,
                                          [self.spec.max_pool1_time, 3, 3],
                                          stride=[2, 2, 2],
                                          scope='maxpool1',
                                          padding='SAME')
                    if add_and_check_endpoint(net, 'maxpool1'):
                        return endpoints

                    net = self.residual_block(
                        net=net,
                        filters=4 * 64,
                        layers=self.spec.blocks[0].layers,
                        scope='res_block_2',
                        data_format=data_format,
                        block=self.spec.blocks[0])
                    if add_and_check_endpoint(net, 'res_block_2'):
                        return endpoints
                    net = slim.max_pool3d(net,
                                          [self.spec.max_pool1_time, 2, 2],
                                          stride=[1, 2, 2],
                                          scope='maxpool2',
                                          padding='SAME')
                    if add_and_check_endpoint(net, 'maxpool2'):
                        return endpoints

                    net = self.residual_block(net,
                                              4 * 128,
                                              self.spec.blocks[1].layers,
                                              scope='res_block_3',
                                              data_format=data_format,
                                              block=self.spec.blocks[1])
                    if add_and_check_endpoint(net, 'res_block_3'):
                        return endpoints
                    net = slim.max_pool3d(net,
                                          [self.spec.max_pool3_time, 2, 2],
                                          stride=[1, 2, 2],
                                          scope='maxpool3',
                                          padding='SAME')
                    if add_and_check_endpoint(net, 'maxpool3'):
                        return endpoints

                    net = self.residual_block(
                        net,
                        filters=4 * 256,
                        layers=self.spec.blocks[2].layers,
                        scope='res_block_4',
                        data_format=data_format,
                        block=self.spec.blocks[2])
                    if add_and_check_endpoint(net, 'res_block_4'):
                        return endpoints
                    net = slim.max_pool3d(net,
                                          [self.spec.max_pool4_time, 2, 2],
                                          stride=[1, 2, 2],
                                          scope='maxpool4',
                                          padding='SAME')
                    if add_and_check_endpoint(net, 'maxpool4'):
                        return endpoints

                    net = self.residual_block(net,
                                              4 * 512,
                                              self.spec.blocks[3].layers,
                                              scope='res_block_5',
                                              data_format=data_format,
                                              block=self.spec.blocks[3])
                    if add_and_check_endpoint(net, 'res_block_5'):
                        return endpoints
                    # Adds one more endpoint denoting the last cell before logits.
                    if add_and_check_endpoint(net, 'LastCell'):
                        return endpoints

                    with tf.variable_scope('Logits'):
                        shape = net.get_shape().as_list()
                        s = shape[3]
                        t = shape[1] if data_format == 'channels_last' else shape[2]
                        pool_size = (min(t, 2), s, s)
                        net = slim.avg_pool3d(inputs=net,
                                              kernel_size=pool_size,
                                              stride=1,
                                              padding='VALID')
                        net = slim.dropout(net,
                                           self.dropout_keep_prob,
                                           scope='Dropout_0b',
                                           is_training=is_training)
                        net = slim.conv3d(
                            net,
                            self.num_classes,
                            kernel_size=1,
                            stride=1,
                            activation_fn=None,
                            normalizer_fn=None,
                            weights_initializer=initializers.variance_scaling_initializer(
                                factor=2.0, mode='FAN_IN', uniform=False))
                        # spatial-temporal pooling
                        logits = tf.reduce_mean(
                            net,
                            axis=([1, 2, 3] if data_format == 'channels_last'
                                  else [2, 3, 4]))
                        if add_and_check_endpoint(logits, 'Logits'):
                            return endpoints

        pred = tf.argmax(slim.softmax(logits), axis=1)
        if add_and_check_endpoint(pred, 'Predictions'):
            return endpoints

        if only_endpoints:
            return endpoints

        return logits
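A short usage sketch of only_endpoints / final_endpoint; the instance name vid_inc_res and the video shape are assumptions:

video = tf.placeholder(tf.float32, [4, 32, 224, 224, 3])
endpoints = vid_inc_res.model(video, mode='eval', only_endpoints=True,
                              final_endpoint='res_block_3')
mid_features = endpoints['res_block_3']  # everything up to res_block_3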
Example #7
    def block(self,
              net,
              total_filters,
              layers,
              scope,
              data_format,
              use_projection=False,
              use_time=True):
        """Block function."""

        axis = 4
        is_training = self.is_training
        filters = total_filters // len(layers)
        total_filters = filters * len(layers)

        branches = []
        with tf.variable_scope(scope):
            shortcut = net
            if use_projection:
                shortcut = conv3d(net,
                                  total_filters, [1, 1, 1],
                                  scope='shortcut')
                shortcut = batch_norm_relu(shortcut,
                                           is_training,
                                           relu=False,
                                           data_format=data_format)

            for i, layer in enumerate(layers):
                with tf.variable_scope('Branch_' + str(i)):
                    if layer.layer_type == search_proto.Layer.SINGLE_DEFAULT:
                        branch = conv3d(net,
                                        filters, [1, 1, 1],
                                        scope='Conv2d_0a_1x1',
                                        data_format=data_format)
                        branch = batch_norm_relu(branch,
                                                 is_training,
                                                 relu=False,
                                                 data_format=data_format)
                    elif layer.layer_type == search_proto.Layer.CONV:
                        branch = conv3d(net,
                                        filters // 2, [1, 1, 1],
                                        scope='Conv2d_0a_1x1',
                                        data_format=data_format)
                        branch = batch_norm_relu(branch,
                                                 is_training,
                                                 relu=True,
                                                 data_format=data_format)
                        conv_fn = self.get_layer_type(layer.conv_type)
                        branch = conv_fn(
                            branch,
                            filters, [layer.time if use_time else 1, 3, 3],
                            scope='Conv2d_0b_3x3',
                            data_format=data_format,
                            dilation=layer.dilation if use_time else 1)
                        branch = batch_norm_relu(branch,
                                                 is_training,
                                                 relu=False,
                                                 data_format=data_format)
                    elif layer.layer_type == search_proto.Layer.CONV2:
                        branch = conv3d(net,
                                        filters // 4, [1, 1, 1],
                                        scope='Conv2d_0a_1x1',
                                        data_format=data_format)
                        branch = batch_norm_relu(branch,
                                                 is_training,
                                                 relu=True,
                                                 data_format=data_format)
                        conv_fn = self.get_layer_type(layer.conv_type)
                        branch = conv_fn(
                            branch,
                            filters // 2,
                            [layer.time if use_time else 1, 3, 3],
                            scope='Conv2d_0b_3x3',
                            data_format=data_format,
                            dilation=layer.dilation if use_time else 1)
                        branch = batch_norm_relu(branch,
                                                 is_training,
                                                 relu=True,
                                                 data_format=data_format)
                        conv_fn = self.get_layer_type(layer.conv_type2)
                        branch = conv_fn(
                            branch,
                            filters, [layer.time2 if use_time else 1, 3, 3],
                            scope='Conv2d_0c_3x3',
                            data_format=data_format,
                            dilation=layer.dilation2 if use_time else 1)
                        branch = batch_norm_relu(branch,
                                                 is_training,
                                                 relu=False,
                                                 data_format=data_format)
                    elif layer.layer_type == search_proto.Layer.MAXPOOLCONV:
                        branch = slim.max_pool3d(net, [layer.time, 3, 3],
                                                 scope='MaxPool_0a_3x3',
                                                 stride=1,
                                                 padding='SAME')
                        branch = conv3d(branch,
                                        filters, [1, 1, 1],
                                        scope='Conv2d_0b_1x1',
                                        data_format=data_format)
                        branch = batch_norm_relu(branch,
                                                 is_training,
                                                 relu=False,
                                                 data_format=data_format)
                branches.append(branch)
            net = tf.concat(branches, axis=axis)

        return tf.nn.relu(net + shortcut)
Example #8
def maxpool3d(batch_input, depth, height, width, scope='maxpool3d'):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        return slim.max_pool3d(batch_input, [depth, height, width], 2, 'SAME')
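For example (input shape assumed):

x = tf.zeros([2, 8, 32, 32, 16])
y = maxpool3d(x, 2, 2, 2)  # 2x2x2 window, stride 2, 'SAME' -> [2, 4, 16, 16, 16]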
Example #9
    def get_model(self,
                  input_tensor,
                  is_training,
                  bn_decay,
                  regu_scale,
                  hg_repeat=2,
                  scope=None):
        """ input_tensor: BxHxWxDxC
            out_dim: BxHxWxDx(J*5), where J is number of joints
        """
        end_points = {}
        self.end_point_list = []
        final_endpoint = 'stage_out'
        num_joint = self.join_num
        num_feature = 128

        def add_and_check_final(name, net):
            end_points[name] = net
            return name == final_endpoint

        from tensorflow.contrib import slim
        from inresnet3d import inresnet3d
        # ~/anaconda2/lib/python2.7/site-packages/tensorflow/contrib/layers/
        with tf.variable_scope(scope, self.name_desc, [input_tensor]):
            bn_epsilon = 0.001
            with \
                slim.arg_scope(
                    [slim.batch_norm],
                    is_training=is_training,
                    epsilon=bn_epsilon,
                    # # Make sure updates happen automatically
                    # updates_collections=None,
                    # Try zero_debias_moving_mean=True for improved stability.
                    # zero_debias_moving_mean=True,
                    decay=bn_decay), \
                slim.arg_scope(
                    [slim.dropout],
                    is_training=is_training), \
                slim.arg_scope(
                    [slim.fully_connected],
                    weights_regularizer=slim.l2_regularizer(regu_scale),
                    biases_regularizer=slim.l2_regularizer(regu_scale),
                    activation_fn=tf.nn.relu,
                    normalizer_fn=slim.batch_norm), \
                slim.arg_scope(
                    [slim.max_pool3d, slim.avg_pool3d],
                    stride=2, padding='SAME'), \
                slim.arg_scope(
                    [slim.conv3d_transpose],
                    stride=2, padding='SAME',
                    weights_regularizer=slim.l2_regularizer(regu_scale),
                    biases_regularizer=slim.l2_regularizer(regu_scale),
                    activation_fn=tf.nn.relu,
                    normalizer_fn=slim.batch_norm), \
                slim.arg_scope(
                    [slim.conv3d],
                    stride=1, padding='SAME',
                    weights_regularizer=slim.l2_regularizer(regu_scale),
                    biases_regularizer=slim.l2_regularizer(regu_scale),
                    activation_fn=tf.nn.relu,
                    normalizer_fn=slim.batch_norm):
                with tf.variable_scope('stage64'):
                    sc = 'stage64'
                    net = slim.conv3d(input_tensor, 8, 3)
                    net = inresnet3d.conv_maxpool(net, scope=sc)
                    self.end_point_list.append(sc)
                    if add_and_check_final(sc, net):
                        return net, end_points
                    sc = 'stage32'
                    # net = inresnet3d.resnet_k(
                    #     net, scope='stage32_res')
                    net = inresnet3d.conv_maxpool(net, scope=sc)
                    net = slim.conv3d(net, num_feature, 1, scope='stage32_out')
                    self.end_point_list.append(sc)
                    if add_and_check_final(sc, net):
                        return net, end_points
                for hg in range(hg_repeat):
                    sc = 'hourglass_{}'.format(hg)
                    with tf.variable_scope(sc):
                        # branch0 = inresnet3d.hourglass3d(
                        #     net, 2, scope=sc + '_hg')
                        branch0 = inresnet3d.resnet_k(net, scope='_res')
                        branch_olm = slim.conv3d(
                            branch0,
                            num_joint,
                            1,
                            # normalizer_fn=None, activation_fn=tf.nn.relu)
                            normalizer_fn=None,
                            activation_fn=None)
                        branch_uom = slim.conv3d(
                            branch0,
                            num_joint * 3,
                            1,
                            # normalizer_fn=None, activation_fn=tf.nn.sigmoid)
                            normalizer_fn=None,
                            activation_fn=None)
                        net_maps = tf.concat([branch_olm, branch_uom], axis=-1)
                        self.end_point_list.append(sc)
                        if add_and_check_final(sc, net_maps):
                            return net_maps, end_points
                        branch1 = slim.conv3d(net_maps, num_feature, 1)
                        net = net + branch0 + branch1
                with tf.variable_scope('stage16'):
                    sc = 'stage16'
                    net = slim.max_pool3d(net, 3, scope=sc)
                    self.end_point_list.append(sc)
                    if add_and_check_final(sc, net):
                        return net, end_points
                with tf.variable_scope('stage8'):
                    sc = 'stage_out'
                    net = inresnet3d.pullout8(net,
                                              self.out_dim,
                                              is_training,
                                              scope=sc)
                    self.end_point_list.append(sc)
                    if add_and_check_final(sc, net):
                        return net, end_points
        raise ValueError('final_endpoint (%s) not recognized' % final_endpoint)
Example #10
def inception_v1_3d(inputs, keep_prob, num_classes):
    with tf.variable_scope('InceptionV1_3d'):
        with slim.arg_scope(
            [slim.conv3d, slim.fully_connected],
                weights_initializer=tf.truncated_normal_initializer(
                    stddev=0.001)):
            with slim.arg_scope([slim.conv3d, slim.max_pool3d],
                                stride=1,
                                padding='SAME'):
                with slim.arg_scope([slim.conv3d],
                                    normalizer_fn=slim.batch_norm):
                    with slim.arg_scope([slim.batch_norm, slim.dropout],
                                        is_training=True):

                        net = slim.conv3d(inputs,
                                          64, [7, 7, 7],
                                          stride=2,
                                          scope='Conv2d_1a_7x7')
                        net = slim.max_pool3d(net, [1, 3, 3],
                                              stride=[1, 2, 2],
                                              scope='MaxPool_2a_3x3')
                        net = slim.conv3d(net,
                                          64, [1, 1, 1],
                                          scope='Conv2d_2b_1x1')
                        net = slim.conv3d(net,
                                          192, [3, 3, 3],
                                          scope='Conv2d_2c_3x3')
                        net = slim.max_pool3d(net, [1, 3, 3],
                                              stride=[1, 2, 2],
                                              scope='MaxPool_3a_3x3')

                        with tf.variable_scope('Mixed_3b'):
                            with tf.variable_scope('Branch_0'):
                                branch_0 = slim.conv3d(net,
                                                       64, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                            with tf.variable_scope('Branch_1'):
                                branch_1 = slim.conv3d(net,
                                                       96, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_1 = slim.conv3d(branch_1,
                                                       128, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_2'):
                                branch_2 = slim.conv3d(net,
                                                       16, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_2 = slim.conv3d(branch_2,
                                                       32, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_3'):
                                branch_3 = slim.max_pool3d(
                                    net, [3, 3, 3], scope='MaxPool_0a_3x3')
                                branch_3 = slim.conv3d(branch_3,
                                                       32, [1, 1, 1],
                                                       scope='Conv2d_0b_1x1')
                            net = tf.concat(axis=4,
                                            values=[
                                                branch_0, branch_1, branch_2,
                                                branch_3
                                            ])

                        with tf.variable_scope('Mixed_3c'):
                            with tf.variable_scope('Branch_0'):
                                branch_0 = slim.conv3d(net,
                                                       128, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                            with tf.variable_scope('Branch_1'):
                                branch_1 = slim.conv3d(net,
                                                       128, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_1 = slim.conv3d(branch_1,
                                                       192, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_2'):
                                branch_2 = slim.conv3d(net,
                                                       32, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_2 = slim.conv3d(branch_2,
                                                       96, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_3'):
                                branch_3 = slim.max_pool3d(
                                    net, [3, 3, 3], scope='MaxPool_0a_3x3')
                                branch_3 = slim.conv3d(branch_3,
                                                       64, [1, 1, 1],
                                                       scope='Conv2d_0b_1x1')
                            net = tf.concat(axis=4,
                                            values=[
                                                branch_0, branch_1, branch_2,
                                                branch_3
                                            ])

                        net = slim.max_pool3d(net, [3, 3, 3],
                                              stride=2,
                                              scope='MaxPool_4a_3x3')

                        with tf.variable_scope('Mixed_4b'):
                            with tf.variable_scope('Branch_0'):
                                branch_0 = slim.conv3d(net,
                                                       192, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                            with tf.variable_scope('Branch_1'):
                                branch_1 = slim.conv3d(net,
                                                       96, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_1 = slim.conv3d(branch_1,
                                                       208, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_2'):
                                branch_2 = slim.conv3d(net,
                                                       16, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_2 = slim.conv3d(branch_2,
                                                       48, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_3'):
                                branch_3 = slim.max_pool3d(
                                    net, [3, 3, 3], scope='MaxPool_0a_3x3')
                                branch_3 = slim.conv3d(branch_3,
                                                       64, [1, 1, 1],
                                                       scope='Conv2d_0b_1x1')
                            net = tf.concat(axis=4,
                                            values=[
                                                branch_0, branch_1, branch_2,
                                                branch_3
                                            ])

                        with tf.variable_scope('Mixed_4c'):
                            with tf.variable_scope('Branch_0'):
                                branch_0 = slim.conv3d(net,
                                                       160, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                            with tf.variable_scope('Branch_1'):
                                branch_1 = slim.conv3d(net,
                                                       112, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_1 = slim.conv3d(branch_1,
                                                       224, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_2'):
                                branch_2 = slim.conv3d(net,
                                                       24, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_2 = slim.conv3d(branch_2,
                                                       64, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_3'):
                                branch_3 = slim.max_pool3d(
                                    net, [3, 3, 3], scope='MaxPool_0a_3x3')
                                branch_3 = slim.conv3d(branch_3,
                                                       64, [1, 1, 1],
                                                       scope='Conv2d_0b_1x1')
                            net = tf.concat(axis=4,
                                            values=[
                                                branch_0, branch_1, branch_2,
                                                branch_3
                                            ])

                        with tf.variable_scope('Mixed_4d'):
                            with tf.variable_scope('Branch_0'):
                                branch_0 = slim.conv3d(net,
                                                       128, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                            with tf.variable_scope('Branch_1'):
                                branch_1 = slim.conv3d(net,
                                                       128, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_1 = slim.conv3d(branch_1,
                                                       256, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_2'):
                                branch_2 = slim.conv3d(net,
                                                       24, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_2 = slim.conv3d(branch_2,
                                                       64, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_3'):
                                branch_3 = slim.max_pool3d(
                                    net, [3, 3, 3], scope='MaxPool_0a_3x3')
                                branch_3 = slim.conv3d(branch_3,
                                                       64, [1, 1, 1],
                                                       scope='Conv2d_0b_1x1')
                            net = tf.concat(axis=4,
                                            values=[
                                                branch_0, branch_1, branch_2,
                                                branch_3
                                            ])

                        with tf.variable_scope('Mixed_4e'):
                            with tf.variable_scope('Branch_0'):
                                branch_0 = slim.conv3d(net,
                                                       112, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                            with tf.variable_scope('Branch_1'):
                                branch_1 = slim.conv3d(net,
                                                       144, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_1 = slim.conv3d(branch_1,
                                                       288, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_2'):
                                branch_2 = slim.conv3d(net,
                                                       32, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_2 = slim.conv3d(branch_2,
                                                       64, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_3'):
                                branch_3 = slim.max_pool3d(
                                    net, [3, 3, 3], scope='MaxPool_0a_3x3')
                                branch_3 = slim.conv3d(branch_3,
                                                       64, [1, 1, 1],
                                                       scope='Conv2d_0b_1x1')
                            net = tf.concat(axis=4,
                                            values=[
                                                branch_0, branch_1, branch_2,
                                                branch_3
                                            ])

                        with tf.variable_scope('Mixed_4f'):
                            with tf.variable_scope('Branch_0'):
                                branch_0 = slim.conv3d(net,
                                                       256, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                            with tf.variable_scope('Branch_1'):
                                branch_1 = slim.conv3d(net,
                                                       160, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_1 = slim.conv3d(branch_1,
                                                       320, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_2'):
                                branch_2 = slim.conv3d(net,
                                                       32, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_2 = slim.conv3d(branch_2,
                                                       128, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_3'):
                                branch_3 = slim.max_pool3d(
                                    net, [3, 3, 3], scope='MaxPool_0a_3x3')
                                branch_3 = slim.conv3d(branch_3,
                                                       128, [1, 1, 1],
                                                       scope='Conv2d_0b_1x1')
                            net = tf.concat(axis=4,
                                            values=[
                                                branch_0, branch_1, branch_2,
                                                branch_3
                                            ])

                        net = slim.max_pool3d(net, [2, 2, 2],
                                              stride=2,
                                              scope='MaxPool_5a_2x2x2')

                        with tf.variable_scope('Mixed_5b'):
                            with tf.variable_scope('Branch_0'):
                                branch_0 = slim.conv3d(net,
                                                       256, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                            with tf.variable_scope('Branch_1'):
                                branch_1 = slim.conv3d(net,
                                                       160, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_1 = slim.conv3d(branch_1,
                                                       320, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_2'):
                                branch_2 = slim.conv3d(net,
                                                       32, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_2 = slim.conv3d(branch_2,
                                                       128, [3, 3, 3],
                                                       scope='Conv2d_0a_3x3')
                            with tf.variable_scope('Branch_3'):
                                branch_3 = slim.max_pool3d(
                                    net, [3, 3, 3], scope='MaxPool_0a_3x3')
                                branch_3 = slim.conv3d(branch_3,
                                                       128, [1, 1, 1],
                                                       scope='Conv2d_0b_1x1')
                            net = tf.concat(axis=4,
                                            values=[
                                                branch_0, branch_1, branch_2,
                                                branch_3
                                            ])

                        with tf.variable_scope('Mixed_5c'):
                            with tf.variable_scope('Branch_0'):
                                branch_0 = slim.conv3d(net,
                                                       384, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                            with tf.variable_scope('Branch_1'):
                                branch_1 = slim.conv3d(net,
                                                       192, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_1 = slim.conv3d(branch_1,
                                                       384, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_2'):
                                branch_2 = slim.conv3d(net,
                                                       48, [1, 1, 1],
                                                       scope='Conv2d_0a_1x1')
                                branch_2 = slim.conv3d(branch_2,
                                                       128, [3, 3, 3],
                                                       scope='Conv2d_0b_3x3')
                            with tf.variable_scope('Branch_3'):
                                branch_3 = slim.max_pool3d(
                                    net, [3, 3, 3], scope='MaxPool_0a_3x3')
                                branch_3 = slim.conv3d(branch_3,
                                                       128, [1, 1, 1],
                                                       scope='Conv2d_0b_1x1')
                            net = tf.concat(axis=4,
                                            values=[
                                                branch_0, branch_1, branch_2,
                                                branch_3
                                            ])

                        with tf.variable_scope('Logits'):
                            net = slim.avg_pool3d(net, [2, 7, 7],
                                                  stride=1,
                                                  scope='AvgPool_0a_7x7')
                            net = slim.dropout(net,
                                               keep_prob,
                                               scope='Dropout_0b')
                            logits = slim.conv3d(net,
                                                 num_classes, [1, 1, 1],
                                                 activation_fn=None,
                                                 normalizer_fn=None,
                                                 scope='Conv2d_0c_1x1')

                            logits = tf.squeeze(logits, [2, 3],
                                                name='SpatialSqueeze')

                            averaged_logits = tf.reduce_mean(logits, axis=1)

                            return averaged_logits
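A minimal call sketch; the 16-frame 224x224 clip size is an assumption chosen so the fixed [2, 7, 7] average-pool kernel matches the final 2x7x7 feature map, and num_classes=400 (e.g., Kinetics) is illustrative:

clips = tf.placeholder(tf.float32, [None, 16, 224, 224, 3])
averaged_logits = inception_v1_3d(clips, keep_prob=0.9, num_classes=400)  # Batch x 400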
Example #11
    def _find_logits(self):
        # [Batch, Height, Width, Depth, Channels]
        x = self._placeholder_inputs
        with tf.variable_scope(self.get_variable_scope()):
            with slim.arg_scope(self.arg_scope()):
                # [Batch, Depth, Height, Width, Channels]
                # x = tf.transpose(x, perm=[0,3,1,2,4])

                x = slim.conv3d(x,
                                50, [1, 1, 1],
                                activation_fn=tf.nn.relu,
                                scope='p1_c1')
                x = slim.conv3d(x,
                                50, [1, 1, 1],
                                activation_fn=tf.nn.relu,
                                scope='p1_c2')
                x = slim.conv3d(x,
                                32, [1, 1, 1],
                                activation_fn=tf.nn.relu,
                                scope='p1_c3')
                # x = slim.batch_norm(x)
                x = slim.conv3d(x,
                                16, [3, 3, 5],
                                activation_fn=tf.nn.relu,
                                scope='p2_c1')
                x = slim.conv3d(x,
                                16, [3, 3, 5],
                                activation_fn=tf.nn.relu,
                                scope='p2_c2')
                x = slim.conv3d(x,
                                16, [3, 3, 5],
                                activation_fn=tf.nn.relu,
                                scope='p2_c3')
                x = slim.max_pool3d(x, [2, 2, 2], scope='p1')
                x = slim.conv3d(x,
                                32, [3, 3, 5],
                                activation_fn=tf.nn.relu,
                                scope='p2_c4')
                x = slim.conv3d(x,
                                32, [3, 3, 5],
                                activation_fn=tf.nn.relu,
                                scope='p2_c5')
                x = slim.max_pool3d(x, [2, 2, 2], scope='p2')
                x = slim.conv3d(x,
                                32, [3, 3, 5],
                                activation_fn=tf.nn.relu,
                                scope='p2_c6')
                x = slim.conv3d(x,
                                64, [3, 3, 5],
                                activation_fn=tf.nn.relu,
                                scope='p3_c1')
                x = slim.max_pool3d(x, [2, 2, 2], scope='p3')
                x = slim.conv3d(x,
                                64, [3, 3, 5],
                                activation_fn=tf.nn.relu,
                                scope='p3_c2')
                x = slim.conv3d(x,
                                64, [3, 3, 5],
                                activation_fn=tf.nn.relu,
                                scope='p3_c3')

                # x should be [B, H, W, D, 8]

                flattened_features = tf.reshape(x, (self.batch_size, -1))

                fc_1 = tf.layers.dense(
                    flattened_features,
                    units=1024,
                    activation=tf.nn.relu,
                    kernel_initializer=tf.random_normal_initializer(
                        mean=0., stddev=self.weight_init_width),
                    bias_initializer=tf.random_normal_initializer(
                        mean=0., stddev=self.weight_init_width))
                fc_2 = tf.layers.dense(
                    fc_1,
                    units=1024,
                    activation=tf.nn.relu,
                    kernel_initializer=tf.random_normal_initializer(
                        mean=0., stddev=self.weight_init_width),
                    bias_initializer=tf.random_normal_initializer(
                        mean=0., stddev=self.weight_init_width))
                fc_3 = tf.layers.dense(
                    fc_2,
                    units=self.num_classes,
                    activation=None,
                    kernel_initializer=tf.random_normal_initializer(
                        mean=0., stddev=self.weight_init_width),
                    bias_initializer=tf.random_normal_initializer(
                        mean=0., stddev=self.weight_init_width))

                return fc_3
Example #12
    def _module_fn():
        """
        Function building the module
        """

        feature_layer = tf.placeholder(
            tf.float32,
            shape=[None, None, None, None, nchannels],
            name='input')
        obs_layer = tf.placeholder(tf.float32,
                                   shape=[None, None, None, None, n_y],
                                   name='observations')

        # Builds the neural network

        if pad == 0:
            d00 = slim.conv3d(feature_layer,
                              fsize,
                              5,
                              activation_fn=tf.nn.leaky_relu,
                              padding='same')
        elif pad == 2:
            d00 = slim.conv3d(feature_layer,
                              fsize,
                              5,
                              activation_fn=tf.nn.leaky_relu,
                              padding='valid')
        elif pad == 4:
            d00 = slim.conv3d(feature_layer,
                              fsize,
                              5,
                              activation_fn=tf.nn.leaky_relu,
                              padding='valid')
            d00 = slim.conv3d(d00,
                              fsize * 2,
                              5,
                              activation_fn=tf.nn.leaky_relu,
                              padding='valid')
        # downsample (restructured to double the filter size at each level)
        cfsize = fsize
        d1 = wide_resnet(d00, cfsize, activation_fn=tf.nn.leaky_relu)
        d2 = wide_resnet(d1, cfsize, activation_fn=tf.nn.leaky_relu)
        dd = [d2]
        for i in range(nsub):
            cfsize *= 2
            print(i, cfsize)
            dsub = slim.max_pool3d(dd[-1],
                                   kernel_size=3,
                                   stride=2,
                                   padding='SAME')
            d1 = wide_resnet(dsub, cfsize, activation_fn=tf.nn.leaky_relu)
            d2 = wide_resnet(d1, cfsize, activation_fn=tf.nn.leaky_relu)
            dd.append(d2)

        print(len(dd))
        #upsample
        usub = dd.pop()
        for i in range(nsub):
            u0 = dynamic_deconv3d('up%d' % i,
                                  usub,
                                  shape=[3, 3, 3, cfsize],
                                  activation=tf.identity)
            cfsize = cfsize // 2
            print(i, cfsize)
            u0 = slim.conv3d(u0,
                             cfsize,
                             1,
                             activation_fn=tf.identity,
                             padding='same')
            #u0 = slim.conv3d_transpose(usub, fsize, kernel_size=3, stride=2)
            uc = tf.concat([u0, dd.pop()], axis=-1)
            u1 = wide_resnet(uc, cfsize, activation_fn=tf.nn.leaky_relu)
            u2 = wide_resnet(u1, cfsize, activation_fn=tf.nn.leaky_relu)
            usub = u2

        print(len(dd))
        net = slim.conv3d(usub, 1, 3, activation_fn=tf.nn.tanh)

        # Define the probabilistic layer
        net = slim.conv3d(net, n_mixture * 3 * n_y, 1, activation_fn=None)
        cube_size = tf.shape(obs_layer)[1]
        net = tf.reshape(
            net, [-1, cube_size, cube_size, cube_size, n_y, n_mixture * 3])
        #         net = tf.reshape(net, [None, None, None, None, n_y, n_mixture*3])
        loc, unconstrained_scale, logits = tf.split(net,
                                                    num_or_size_splits=3,
                                                    axis=-1)
        scale = tf.nn.softplus(unconstrained_scale) + 1e-3

        # Form mixture of discretized logistic distributions. Note we shift the
        # logistic distribution by -0.5. This lets the quantization capture "rounding"
        # intervals, `(x-0.5, x+0.5]`, and not "ceiling" intervals, `(x-1, x]`.
        if distribution == 'logistic':
            discretized_logistic_dist = tfd.QuantizedDistribution(
                distribution=tfd.TransformedDistribution(
                    distribution=tfd.Logistic(loc=loc, scale=scale),
                    bijector=tfb.AffineScalar(shift=-0.5)),
                low=0.,
                high=2.**3 - 1)

            mixture_dist = tfd.MixtureSameFamily(
                mixture_distribution=tfd.Categorical(logits=logits),
                components_distribution=discretized_logistic_dist)

        elif distribution == 'normal':

            mixture_dist = tfd.MixtureSameFamily(
                mixture_distribution=tfd.Categorical(logits=logits),
                components_distribution=tfd.Normal(loc=loc, scale=scale))

        # Define a function for sampling, and a function for estimating the log likelihood
        #sample = tf.squeeze(mixture_dist.sample())
        sample = mixture_dist.sample()
        loglik = mixture_dist.log_prob(obs_layer)
        hub.add_signature(inputs={
            'features': feature_layer,
            'labels': obs_layer
        },
                          outputs={
                              'sample': sample,
                              'loglikelihood': loglik,
                              'loc': loc,
                              'scale': scale,
                              'logits': logits
                          })
Example #13
def C3D(input, dimensions, dropout=False, regularizer=True):

    if regularizer:
        regularizer = tf.contrib.layers.l2_regularizer(0.005)  # 0.005

    with tf.variable_scope('C3D'):
        with slim.arg_scope(
            [slim.conv3d],
                padding='SAME',
                weights_initializer=tf.random_normal_initializer(stddev=0.01)
                #weights_regularizer=slim.l2_regularizer(0.0005)
        ):
            conv_1 = slim.conv3d(input, 64, 3, 1, scope='conv_1')
            maxpool_1 = slim.max_pool3d(conv_1,
                                        kernel_size=[1, 2, 2],
                                        stride=[1, 2, 2],
                                        padding='SAME',
                                        scope='maxpool_1')
            conv_2 = slim.conv3d(maxpool_1, 128, 3, 1, scope='conv_2')
            maxpool_2 = slim.max_pool3d(conv_2, [2, 2, 2], [2, 2, 2],
                                        padding='SAME',
                                        scope='maxpool_2')
            conv_3 = slim.conv3d(maxpool_2, 256, 3, 1, scope='conv_3')
            maxpool_3 = slim.max_pool3d(conv_3, [2, 2, 2], [2, 2, 2],
                                        padding='SAME',
                                        scope='maxpool_3')
            conv_4 = slim.conv3d(maxpool_3, 256, 3, 1, scope='conv_4')
            maxpool_4 = slim.max_pool3d(conv_4, [2, 2, 2], [2, 2, 2],
                                        padding='SAME',
                                        scope='maxpool_4')
            conv_5 = slim.conv3d(maxpool_4, 256, 3, 1, scope='conv_5')
            maxpool_5 = slim.max_pool3d(conv_5, [2, 2, 2], [2, 2, 2],
                                        padding='SAME',
                                        scope='maxpool_5')

        pool_shape = maxpool_5.get_shape().as_list()
        nodes = pool_shape[1] * pool_shape[2] * pool_shape[3] * pool_shape[
            4]  # 1 x 4 x 4 x 256 = 4096
        reshaped = tf.reshape(
            maxpool_5,
            [pool_shape[0], nodes])  # pool_shape[0] is N, batch_size

        with tf.variable_scope('fc6'):
            fc6_weight = tf.get_variable(
                'weight', [nodes, 2048],
                initializer=tf.random_normal_initializer(stddev=0.005))
            if regularizer:
                tf.add_to_collection("weight_decay_loss",
                                     regularizer(fc6_weight))
            fc6_bias = tf.get_variable(
                'bias', [2048], initializer=tf.constant_initializer(1.0))
            fc6 = tf.nn.relu(tf.matmul(reshaped, fc6_weight) + fc6_bias)
            if dropout:
                fc6 = tf.nn.dropout(fc6, 0.5)

        with tf.variable_scope('fc7'):
            fc7_weight = tf.get_variable(
                'weight', [2048, 2048],
                initializer=tf.random_normal_initializer(stddev=0.005))
            if regularizer:
                tf.add_to_collection("weight_decay_loss",
                                     regularizer(fc7_weight))
            fc7_bias = tf.get_variable(
                'bias', [2048], initializer=tf.constant_initializer(1.0))
            fc7 = tf.nn.relu(tf.matmul(fc6, fc7_weight) + fc7_bias)
            if dropout:
                fc7 = tf.nn.dropout(fc7, 0.5)

        with tf.variable_scope('fc8'):  # fc8
            out_weight = tf.get_variable(
                'weight', [2048, dimensions],
                initializer=tf.random_normal_initializer(stddev=0.01))
            if regularizer:
                tf.add_to_collection("weight_decay_loss",
                                     regularizer(out_weight))
            out_bias = tf.get_variable(
                'bias', [dimensions], initializer=tf.constant_initializer(0.0))
            out = tf.matmul(fc7, out_weight) + out_bias  # logits for the loss: do not add a ReLU here

        return out
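Because reshaped indexes pool_shape[0], the batch size must be static. Below is
a minimal usage sketch, assuming TF 1.x with tf.contrib.slim imported as slim;
the 16x112x112x3 clip size matches the "1 x 4 x 4 x 256 = 4096" arithmetic
noted above, and the class count is an illustrative assumption.

clips = tf.placeholder(tf.float32, [8, 16, 112, 112, 3])  # static batch of 8
labels = tf.placeholder(tf.int64, [8])
logits = C3D(clips, dimensions=101, dropout=True)  # e.g. 101 action classes

# The L2 terms were accumulated in a custom collection above.
weight_decay = tf.add_n(tf.get_collection('weight_decay_loss'))
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)) + weight_decay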
Example #14
def create_3D_UNet(x, features_root=16, n_classes=2):

    net = OrderedDict()
    with slim.arg_scope(
        [slim.conv3d, slim.conv3d_transpose],
            weights_initializer=initializers.variance_scaling_initializer(
                factor=2.0, mode='FAN_IN', uniform=False),
            activation_fn=leaky_relu):

        net['encode/conv1_1'] = instance_norm(
            slim.conv3d(x, features_root, [3, 3, 3]))
        net['encode/conv1_2'] = instance_norm(
            slim.conv3d(net['encode/conv1_1'], features_root, [3, 3, 3]))
        net['encode/pool1'] = slim.max_pool3d(net['encode/conv1_2'],
                                              kernel_size=[1, 2, 2],
                                              stride=[1, 2, 2])

        net['encode/conv2_1'] = instance_norm(
            slim.conv3d(net['encode/pool1'], features_root * 2, [3, 3, 3]))
        net['encode/conv2_2'] = instance_norm(
            slim.conv3d(net['encode/conv2_1'], features_root * 2, [3, 3, 3]))
        net['encode/pool2'] = slim.max_pool3d(net['encode/conv2_2'],
                                              kernel_size=[2, 2, 2],
                                              stride=[2, 2, 2])

        net['encode/conv3_1'] = instance_norm(
            slim.conv3d(net['encode/pool2'], features_root * 4, [3, 3, 3]))
        net['encode/conv3_2'] = instance_norm(
            slim.conv3d(net['encode/conv3_1'], features_root * 4, [3, 3, 3]))
        net['encode/pool3'] = slim.max_pool3d(net['encode/conv3_2'], [2, 2, 2])

        net['encode/conv4_1'] = instance_norm(
            slim.conv3d(net['encode/pool3'], features_root * 8, [3, 3, 3]))
        net['encode/conv4_2'] = instance_norm(
            slim.conv3d(net['encode/conv4_1'], features_root * 8, [3, 3, 3]))
        net['encode/pool4'] = slim.max_pool3d(net['encode/conv4_2'], [2, 2, 2])

        net['encode/conv5_1'] = instance_norm(
            slim.conv3d(net['encode/pool4'], features_root * 16, [3, 3, 3]))
        net['encode/conv5_2'] = instance_norm(
            slim.conv3d(net['encode/conv5_1'], features_root * 16, [3, 3, 3]))

        net['decode/up_conv1'] = slim.conv3d_transpose(net['encode/conv5_2'],
                                                       features_root * 8,
                                                       [2, 2, 2],
                                                       stride=2,
                                                       activation_fn=None,
                                                       padding='VALID',
                                                       biases_initializer=None)
        net['decode/concat_c4_u1'] = tf.concat(
            [net['encode/conv4_2'], net['decode/up_conv1']], 4)
        net['decode/conv1_1'] = instance_norm(
            slim.conv3d(net['decode/concat_c4_u1'], features_root * 8,
                        [3, 3, 3]))
        net['decode/conv1_2'] = instance_norm(
            slim.conv3d(net['decode/conv1_1'], features_root * 8, [3, 3, 3]))

        net['decode/up_conv2'] = slim.conv3d_transpose(net['decode/conv1_2'],
                                                       features_root * 4,
                                                       [2, 2, 2],
                                                       stride=2,
                                                       activation_fn=None,
                                                       padding='VALID',
                                                       biases_initializer=None)

        net['decode/concat_c3_u2'] = tf.concat(
            [net['encode/conv3_2'], net['decode/up_conv2']], 4)
        net['decode/conv2_1'] = instance_norm(
            slim.conv3d(net['decode/concat_c3_u2'], features_root * 4,
                        [3, 3, 3]))
        net['decode/conv2_2'] = instance_norm(
            slim.conv3d(net['decode/conv2_1'], features_root * 4, [3, 3, 3]))

        net['decode/up_conv3'] = slim.conv3d_transpose(net['decode/conv2_2'],
                                                       features_root * 2,
                                                       kernel_size=[2, 2, 2],
                                                       stride=[2, 2, 2],
                                                       activation_fn=None,
                                                       padding='VALID',
                                                       biases_initializer=None)
        net['decode/concat_c2_u3'] = tf.concat(
            [net['encode/conv2_2'], net['decode/up_conv3']], 4)
        net['decode/conv3_1'] = instance_norm(
            slim.conv3d(net['decode/concat_c2_u3'], features_root * 2,
                        [3, 3, 3]))
        net['decode/conv3_2'] = instance_norm(
            slim.conv3d(net['decode/conv3_1'], features_root * 2, [3, 3, 3]))

        net['decode/up_conv4'] = slim.conv3d_transpose(net['decode/conv3_2'],
                                                       features_root,
                                                       [1, 2, 2],
                                                       stride=[1, 2, 2],
                                                       activation_fn=None,
                                                       padding='VALID',
                                                       biases_initializer=None)

        net['decode/concat_c1_u4'] = tf.concat(
            [net['encode/conv1_2'], net['decode/up_conv4']], 4)
        net['decode/conv4_1'] = instance_norm(
            slim.conv3d(net['decode/concat_c1_u4'], features_root, [3, 3, 3]))
        net['decode/conv4_2'] = instance_norm(
            slim.conv3d(net['decode/conv4_1'], features_root, [3, 3, 3]))

        net['out_map'] = instance_norm(
            slim.conv3d(net['decode/conv4_2'],
                        n_classes, [1, 1, 1],
                        activation_fn=None))

    return net
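instance_norm and leaky_relu are helpers assumed to be defined elsewhere in
this file. A sketch of building the U-Net and reading out the per-voxel class
probabilities; the volume shape is an assumption chosen so that depth survives
the three 2x poolings and height/width survive all four.

volume = tf.placeholder(tf.float32, [1, 32, 128, 128, 1])  # [B, D, H, W, C]
net = create_3D_UNet(volume, features_root=16, n_classes=2)
logits = net['out_map']                                    # [1, 32, 128, 128, 2]
probs = tf.nn.softmax(logits, axis=-1)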
def NonLocalBlock(input_x,
                  out_channels,
                  sub_sample=True,
                  is_bn=True,
                  scope='NonLocalBlock'):
    batchsize, clips, height, width, in_channels = input_x.get_shape().as_list()
    with tf.variable_scope(scope) as sc:
        with tf.variable_scope('g') as scope:
            g = slim.conv3d(input_x,
                            out_channels,
                            kernel_size=1,
                            stride=1,
                            scope='g')
            if sub_sample:
                g = slim.max_pool3d(g, [1, 2, 2],
                                    stride=[1, 2, 2],
                                    scope='g_max_pool')

        with tf.variable_scope('phi') as scope:
            phi = slim.conv3d(input_x,
                              out_channels,
                              kernel_size=1,
                              stride=1,
                              scope='phi')
            if sub_sample:
                phi = slim.max_pool3d(phi, [1, 2, 2],
                                      stride=[1, 2, 2],
                                      scope='phi_max_pool')

        with tf.variable_scope('theta') as scope:
            theta = slim.conv3d(input_x,
                                out_channels,
                                kernel_size=1,
                                stride=1,
                                scope='theta')
        # Flatten the spatio-temporal dims: [batch, clips*height*width, out_channels]
        g_x = tf.reshape(g, [batchsize, -1, out_channels])
        theta_x = tf.reshape(theta, [batchsize, -1, out_channels])
        phi_x = tf.reshape(phi, [batchsize, -1, out_channels])
        # Transpose phi for the pairwise affinity matmul below
        phi_x = tf.transpose(phi_x, [0, 2, 1])

        f = tf.matmul(theta_x, phi_x)
        f_softmax = tf.nn.softmax(f, -1)
        y = tf.matmul(f_softmax, g_x)
        y = tf.reshape(y, [batchsize, clips, height, width, out_channels])

        with tf.variable_scope('w') as scope:
            w_y = slim.conv3d(y,
                              in_channels,
                              kernel_size=1,
                              stride=1,
                              scope='w')
            if is_bn:
                w_y = slim.batch_norm(w_y)

        z = input_x + w_y

    return z
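Because the final reshape uses the statically known batchsize, clips, height,
and width, the block requires a fully defined input shape. A minimal sketch,
with illustrative shape values:

feat = tf.placeholder(tf.float32, [4, 8, 28, 28, 128])  # fully static [B, T, H, W, C]
feat_nl = NonLocalBlock(feat, out_channels=64, scope='nl_demo')
# The residual connection restores the input channel count, so feat_nl is
# again [4, 8, 28, 28, 128].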
Example #16
def model(input_tensor, weight_decay=1e-5, is_training=True):
    batch_norm_params = {
        'decay': 0.997,
        'epsilon': 1e-5,
        'scale': True,
        'is_training': is_training
    }
    with slim.arg_scope([slim.conv3d, slim.conv3d_transpose],
                        # normalizer_fn=slim.batch_norm,
                        activation_fn=tf.nn.relu,
                        # normalizer_params=batch_norm_params,
                        weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                        weights_regularizer=slim.l2_regularizer(0.005)):
        with slim.arg_scope([slim.batch_norm], is_training=is_training):
            # `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
            scale1_1 = slim.conv3d(input_tensor, kernel_size=[9, 9, 7], num_outputs=8)
            scale1_1 = slim.batch_norm(scale1_1, activation_fn=tf.nn.relu)
            scale1_2 = slim.conv3d(scale1_1, kernel_size=[9, 9, 7], num_outputs=8)
            scale1_2 = slim.batch_norm(scale1_2, activation_fn=tf.nn.relu)
            down_pooling1 = slim.max_pool3d(scale1_2, kernel_size=[2, 2, 2], stride=2, padding='SAME')

            scale2_1 = slim.conv3d(down_pooling1, kernel_size=[7, 7, 5], num_outputs=16)
            scale2_1 = slim.batch_norm(scale2_1, activation_fn=tf.nn.relu)
            scale2_2 = slim.conv3d(scale2_1, kernel_size=[7, 7, 5], num_outputs=32)
            scale2_2 = slim.batch_norm(scale2_2, activation_fn=tf.nn.relu)
            down_pooling2 = slim.max_pool3d(scale2_2, kernel_size=[2, 2, 2], stride=2, padding='SAME')

            scale3_1 = slim.conv3d(down_pooling2, kernel_size=[5, 5, 3], num_outputs=32)
            scale3_1 = slim.batch_norm(scale3_1, activation_fn=tf.nn.relu)
            scale3_2 = slim.conv3d(scale3_1, kernel_size=[1, 1, 1], num_outputs=32)
            scale3_2 = slim.batch_norm(scale3_2, activation_fn=tf.nn.relu)


            print(scale3_2)
            with tf.variable_scope('prediction_last'):
                # up_pooling1_1 = slim.conv3d_transpose(scale3_2, num_outputs=32, kernel_size=[3, 3, 3], stride=[2, 2, 2])
                # up_pooling1_2 = slim.conv3d(up_pooling1_1, num_outputs=2, kernel_size=[3, 3, 3], stride=[2, 2, 2])
                up_pooling1_1 = unpool(scale3_2)
                up_pooling1_2 = unpool(up_pooling1_1)
                pred_last = slim.conv3d(up_pooling1_2, kernel_size=[3, 3, 3], num_outputs=32)
                pred_last = slim.batch_norm(pred_last)
                pred_last = slim.conv3d(pred_last, kernel_size=[1, 1, 1], num_outputs=2, activation_fn=None)

            with tf.variable_scope('prediction_layer6'):
                # up_pooling2_1 = slim.conv3d_transpose(down_pooling2, num_outputs=32, kernel_size=[3, 3, 3],
                #                                       stride=[2, 2, 2])
                # up_pooling2_2 = slim.conv3d_transpose(up_pooling2_1, num_outputs=2, kernel_size=[3, 3, 3], stride=[2, 2, 2])
                up_pooling2_1 = unpool(down_pooling2)
                up_pooling2_2 = unpool(up_pooling2_1)
                pred_6 = slim.conv3d(up_pooling2_2, kernel_size=[3, 3, 3], num_outputs=32)
                pred_6 = slim.batch_norm(pred_6)
                pred_6 = slim.conv3d(pred_6, kernel_size=[1, 1, 1], num_outputs=2, activation_fn=None)

            with tf.variable_scope('prediction_layer3'):
                # up_pooling3_1 = slim.conv3d_transpose(down_pooling1, num_outputs=2, kernel_size=[3, 3, 3], stride=[2, 2, 2])
                up_pooling3_1 = unpool(down_pooling1)
                pred_3 = slim.conv3d(up_pooling3_1, kernel_size=[3, 3, 3], num_outputs=32)
                pred_3 = slim.batch_norm(pred_3)
                pred_3 = slim.conv3d(pred_3, kernel_size=[1, 1, 1], num_outputs=2, activation_fn=None)
            print(pred_last, pred_6, pred_3)

            return pred_last, pred_6, pred_3
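unpool is an upsampling helper assumed to be defined elsewhere, and assumed
here to double every spatial dimension. A sketch of wiring the three heads
into a deeply supervised loss, with illustrative shapes:

volume = tf.placeholder(tf.float32, [2, 64, 64, 32, 1])
labels = tf.placeholder(tf.int32, [2, 64, 64, 32])
pred_last, pred_6, pred_3 = model(volume, is_training=True)

# Deep supervision: every head predicts at full resolution.
loss = tf.add_n([
    tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=p))
    for p in (pred_last, pred_6, pred_3)
])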
Example #17
def unet3d(inputs):
    """
    unet3D model without softmax.
    """
    print(inputs.shape)
    conv1 = slim.repeat(inputs=inputs,
                        repetitions=2,
                        layer=slim.layers.conv3d,
                        num_outputs=64,
                        kernel_size=3,
                        activation_fn=tf.nn.relu,
                        normalizer_fn=slim.batch_norm)
    print(conv1.shape)
    pool1 = slim.max_pool3d(inputs=conv1, kernel_size=2)
    print(pool1.shape)

    conv2 = slim.repeat(inputs=pool1,
                        repetitions=2,
                        layer=slim.conv3d,
                        num_outputs=128,
                        kernel_size=3,
                        activation_fn=tf.nn.relu,
                        normalizer_fn=slim.batch_norm)
    print(conv2.shape)
    pool2 = slim.max_pool3d(inputs=conv2, kernel_size=2)
    print(pool2.shape)

    conv3 = slim.repeat(inputs=pool2,
                        repetitions=2,
                        layer=slim.conv3d,
                        num_outputs=256,
                        activation_fn=tf.nn.relu,
                        kernel_size=3,
                        normalizer_fn=slim.batch_norm)
    print(conv3.shape)
    pool3 = slim.max_pool3d(inputs=conv3, kernel_size=2)
    print(pool3.shape)

    conv4 = slim.repeat(inputs=pool3,
                        repetitions=2,
                        layer=slim.conv3d,
                        num_outputs=512,
                        activation_fn=tf.nn.relu,
                        kernel_size=3,
                        normalizer_fn=slim.batch_norm)
    print(conv4.shape)
    # pool4 = slim.max_pool3d(inputs=conv4, kernel_size=2)
    # print(pool4.shape)

    # conv5 = slim.repeat(inputs=pool4,
    #                     repetitions=2,
    #                     layer=slim.conv3d,
    #                     num_outputs=1024,
    #                     activation_fn=tf.nn.relu,
    #                     kernel_size=3,
    #                     normalizer_fn=slim.batch_norm)
    # print(conv5.shape)

    # upsampling1 = slim.conv3d_transpose(inputs=conv5,
    #                                     kernel_size=3,
    #                                     num_outputs=1024,
    #                                     stride=2,
    #                                     activation_fn=tf.nn.relu,
    #                                     normalizer_fn=slim.batch_norm)
    # print(upsampling1.shape)
    # upconv1 = slim.conv3d(inputs=upsampling1,
    #                       kernel_size=2,
    #                       num_outputs=512,
    #                       activation_fn=tf.nn.relu,
    #                       normalizer_fn=slim.batch_norm)
    # print(upconv1.shape)
    # concat1 = tf.concat([conv4, upconv1], 3)
    # print(concat1.shape)

    # conv4 = slim.repeat(inputs=concat1,
    #                     repetitions=2,
    #                     layer=slim.conv3d,
    #                     num_outputs=512,
    #                     activation_fn=tf.nn.relu,
    #                     kernel_size=3,
    #                     normalizer_fn=slim.batch_norm)
    # print(conv4.shape)

    upsampling2 = slim.conv3d_transpose(inputs=conv4,
                                        kernel_size=3,
                                        num_outputs=512,
                                        stride=2,
                                        activation_fn=tf.nn.relu,
                                        normalizer_fn=slim.batch_norm)
    print(upsampling2.shape)
    upconv2 = slim.conv3d(inputs=upsampling2,
                          kernel_size=2,
                          num_outputs=256,
                          activation_fn=tf.nn.relu,
                          normalizer_fn=slim.batch_norm)
    print(upconv2.shape)
    concat2 = tf.concat([conv3, upconv2], 4)
    print(concat2.shape)
    conv3 = slim.repeat(inputs=concat2,
                        repetitions=2,
                        layer=slim.conv3d,
                        num_outputs=256,
                        activation_fn=tf.nn.relu,
                        kernel_size=3,
                        normalizer_fn=slim.batch_norm)
    print(conv3.shape)

    upsampling3 = slim.conv3d_transpose(inputs=conv3,
                                        kernel_size=3,
                                        num_outputs=256,
                                        stride=2,
                                        activation_fn=tf.nn.relu,
                                        normalizer_fn=slim.batch_norm)
    print(upsampling3.shape)
    upconv3 = slim.conv3d(inputs=upsampling3,
                          kernel_size=2,
                          num_outputs=128,
                          activation_fn=tf.nn.relu,
                          normalizer_fn=slim.batch_norm)
    print(upconv3.shape)
    concat3 = tf.concat([conv2, upconv3], 4)
    print(concat3.shape)
    conv2 = slim.repeat(inputs=concat3,
                        repetitions=2,
                        layer=slim.conv3d,
                        num_outputs=128,
                        activation_fn=tf.nn.relu,
                        kernel_size=3,
                        normalizer_fn=slim.batch_norm)
    print(conv2.shape)

    upsampling4 = slim.conv3d_transpose(inputs=conv2,
                                        kernel_size=3,
                                        num_outputs=128,
                                        stride=2,
                                        activation_fn=tf.nn.relu,
                                        normalizer_fn=slim.batch_norm)
    print(upsampling4.shape)
    upconv4 = slim.conv3d(inputs=upsampling4,
                          kernel_size=2,
                          num_outputs=64,
                          activation_fn=tf.nn.relu,
                          normalizer_fn=slim.batch_norm)
    print(upconv4.shape)
    concat4 = tf.concat([conv1, upconv4], 4)
    print(concat4.shape)
    conv1 = slim.repeat(inputs=concat4,
                        repetitions=2,
                        layer=slim.conv3d,
                        num_outputs=64,
                        activation_fn=tf.nn.relu,
                        kernel_size=3,
                        normalizer_fn=slim.batch_norm)
    print(conv1.shape)

    output = slim.repeat(inputs=conv1,
                         repetitions=1,
                         layer=slim.conv3d,
                         num_outputs=4,
                         activation_fn=tf.identity,
                         kernel_size=1,
                         normalizer_fn=slim.batch_norm)
    print(output.shape)

    return output
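A minimal sketch of calling unet3d; the cube size is an assumption that must be
divisible by 8 so the three 2x poolings line up with the decoder concats.

scans = tf.placeholder(tf.float32, [1, 64, 64, 64, 1])
logits = unet3d(scans)                   # [1, 64, 64, 64, 4]
probs = tf.nn.softmax(logits, axis=-1)   # softmax applied outside, per the docstring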
Example #18
def unet_valid_sparese(vox_feat,
                       mask,
                       channels,
                       FLAGS,
                       trainable=True,
                       if_bn=False,
                       reuse=False,
                       is_training=True,
                       activation_fn=tf.nn.relu,
                       scope_name='unet_3d'):

    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()

        if if_bn:
            batch_normalizer_gen = slim.batch_norm
            batch_norm_params_gen = {
                'is_training': is_training,
                'decay': FLAGS.bn_decay
            }
        else:
            batch_normalizer_gen = None
            batch_norm_params_gen = None

        if FLAGS.if_l2Reg:
            weights_regularizer = slim.l2_regularizer(1e-5)
        else:
            weights_regularizer = None

        with slim.arg_scope(
            [slim.fully_connected, slim.conv3d, slim.conv3d_transpose],
                activation_fn=activation_fn,
                trainable=trainable,
                normalizer_fn=batch_normalizer_gen,
                normalizer_params=batch_norm_params_gen,
                weights_regularizer=weights_regularizer):

            mask_down1 = tf.stop_gradient(mask)
            # The visibility mask is tiled across channels and multiplied in
            # before each conv; tiling to 16 assumes vox_feat has 16 channels.
            net_down1 = slim.conv3d(vox_feat *
                                    tf.tile(mask_down1, [1, 1, 1, 1, 16]),
                                    16,
                                    kernel_size=4,
                                    stride=2,
                                    padding='SAME',
                                    scope='unet_conv1')
            mask_down2 = tf.stop_gradient(
                slim.max_pool3d(mask_down1,
                                kernel_size=4,
                                stride=2,
                                padding='SAME'))
            net_down2 = slim.conv3d(net_down1 *
                                    tf.tile(mask_down2, [1, 1, 1, 1, 16]),
                                    32,
                                    kernel_size=4,
                                    stride=2,
                                    padding='SAME',
                                    scope='unet_conv2')
            #net_down2 = slim.conv3d(net_down1 , 32, kernel_size=4, stride=2, padding='SAME', scope='unet_conv2')
            mask_down3 = tf.stop_gradient(
                slim.max_pool3d(mask_down2,
                                kernel_size=4,
                                stride=2,
                                padding='SAME'))
            net_down3 = slim.conv3d(net_down2 *
                                    tf.tile(mask_down3, [1, 1, 1, 1, 32]),
                                    64,
                                    kernel_size=4,
                                    stride=2,
                                    padding='SAME',
                                    scope='unet_conv3')
            #net_down3 = slim.conv3d(net_down2, 64, kernel_size=4, stride=2, padding='SAME', scope='unet_conv3')
            mask_down4 = tf.stop_gradient(
                slim.max_pool3d(mask_down3,
                                kernel_size=4,
                                stride=2,
                                padding='SAME'))
            net_down4 = slim.conv3d(net_down3 *
                                    tf.tile(mask_down4, [1, 1, 1, 1, 64]),
                                    128,
                                    kernel_size=4,
                                    stride=2,
                                    padding='SAME',
                                    scope='unet_conv4')
            #net_down4 = slim.conv3d(net_down3, 128, kernel_size=4, stride=2, padding='SAME', scope='unet_conv4')
            mask_down5 = tf.stop_gradient(
                slim.max_pool3d(mask_down4,
                                kernel_size=4,
                                stride=2,
                                padding='SAME'))
            net_down5 = slim.conv3d(net_down4 *
                                    tf.tile(mask_down5, [1, 1, 1, 1, 128]),
                                    256,
                                    kernel_size=4,
                                    stride=2,
                                    padding='SAME',
                                    scope='unet_conv5')
            #net_down5 = slim.conv3d(net_down4, 256, kernel_size=4, stride=2, padding='SAME', scope='unet_conv5')
            mask_down6 = tf.stop_gradient(
                slim.max_pool3d(mask_down5,
                                kernel_size=4,
                                stride=2,
                                padding='SAME'))

            net_up4 = slim.conv3d_transpose(net_down5*tf.tile(mask_down6, [1,1,1,1,256]), 128, kernel_size=4, stride=2, padding='SAME', \
                scope='unet_deconv4')
            #net_up4 = slim.conv3d_transpose(net_down5, 128, kernel_size=4, stride=2, padding='SAME', \
            #    scope='unet_deconv4')
            net_up4_ = tf.concat([net_up4, net_down4], axis=-1)
            net_up3 = slim.conv3d_transpose(net_up4_*tf.tile(mask_down5, [1,1,1,1,256]), 64, kernel_size=4, stride=2, padding='SAME', \
                scope='unet_deconv3')
            #net_up3 = slim.conv3d_transpose(net_up4_, 64, kernel_size=4, stride=2, padding='SAME', \
            #    scope='unet_deconv3')
            net_up3_ = tf.concat([net_up3, net_down3], axis=-1)
            net_up2 = slim.conv3d_transpose(net_up3_*tf.tile(mask_down4, [1,1,1,1,128]), 32, kernel_size=4, stride=2, padding='SAME', \
                scope='unet_deconv2')
            #net_up2 = slim.conv3d_transpose(net_up3_, 32, kernel_size=4, stride=2, padding='SAME', \
            #    scope='unet_deconv2')
            net_up2_ = tf.concat([net_up2, net_down2], axis=-1)
            net_up1 = slim.conv3d_transpose(net_up2_*tf.tile(mask_down3, [1,1,1,1,64]), 16, kernel_size=4, stride=2, padding='SAME', \
                scope='unet_deconv1')
            #net_up1 = slim.conv3d_transpose(net_up2_, 16, kernel_size=4, stride=2, padding='SAME', \
            #    scope='unet_deconv1')
            net_up1_ = tf.concat([net_up1, net_down1], axis=-1)
            #net_out_ = slim.conv3d(net_up1_, 1, kernel_size=4, stride=2, padding='SAME', \
            #    activation_fn=None, normalizer_fn=None, normalizer_params=None, scope='unet_deconv_out')
            ## heavy load
            net_up0 = slim.conv3d_transpose(net_up1_*tf.tile(mask_down2, [1,1,1,1,32]), channels, kernel_size=4, stride=2, padding='SAME', \
                scope='unet_deconv0')
            #net_up0 = slim.conv3d_transpose(net_up1_, channels, kernel_size=4, stride=2, padding='SAME', \
            #    scope='unet_deconv0')
            net_up0_ = tf.concat([net_up0, vox_feat], axis=-1)
            net_out_ = slim.conv3d(net_up0_, 1, kernel_size=3, stride=1, padding='SAME', \
                activation_fn=None, normalizer_fn=None, normalizer_params=None, scope='unet_deconv_out')

    return tf.nn.sigmoid(net_out_), net_out_
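A sketch of invoking the masked U-Net. FLAGS only needs bn_decay and if_l2Reg
here (the stand-in class below is hypothetical), and channels must be 16 for
the first tf.tile mask multiply to broadcast against vox_feat.

class _Flags(object):  # hypothetical stand-in for the real FLAGS object
    bn_decay = 0.997
    if_l2Reg = False

vox_feat = tf.placeholder(tf.float32, [2, 64, 64, 64, 16])
mask = tf.placeholder(tf.float32, [2, 64, 64, 64, 1])
occupancy_prob, occupancy_logits = unet_valid_sparese(
    vox_feat, mask, channels=16, FLAGS=_Flags(), if_bn=True)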