Beispiel #1
0
    def _build_network(self):
        """Build the VGG base network selected by ``config.model_type``.

        On success the captured slim arg scope is stored in
        ``self.arg_scope`` and the two values returned by ``vgg.basenet``
        are stored in ``self.net`` and ``self.end_points``.

        Raises:
            ValueError: if ``config.model_type`` is neither
                ``MODEL_TYPE_vgg16`` nor ``MODEL_TYPE_vgg16_no_dilation``.
        """
        import config

        # The two supported variants differ only in whether ``dilation=False``
        # is forwarded to ``vgg.basenet``; everything else is shared.
        if config.model_type == MODEL_TYPE_vgg16:
            variant_kwargs = {}
        elif config.model_type == MODEL_TYPE_vgg16_no_dilation:
            variant_kwargs = {'dilation': False}
        else:
            raise ValueError('model_type not supported:%s' %
                             (config.model_type))

        # Imported lazily, only once a supported model type is confirmed.
        from nets import vgg

        # Outer scope: conv2d defaults (ReLU, L2 weight decay, Xavier weights,
        # zero biases). Inner scope: SAME padding for conv2d and max_pool2d;
        # the inner scope object is what gets kept on ``self.arg_scope``.
        with slim.arg_scope(
                [slim.conv2d],
                activation_fn=tf.nn.relu,
                weights_regularizer=slim.l2_regularizer(config.weight_decay),
                weights_initializer=tf.contrib.layers.xavier_initializer(),
                biases_initializer=tf.zeros_initializer()), \
            slim.arg_scope(
                [slim.conv2d, slim.max_pool2d], padding='SAME') as scope:
            self.arg_scope = scope
            self.net, self.end_points = vgg.basenet(
                inputs=self.inputs, pooling='MAX', **variant_kwargs)
def model_vgg(images, weight_decay=1e-5, is_training=True):
    '''
    Build the VGG-based pixel/link prediction heads.

    The images are passed through ``mean_image_subtraction`` and then
    ``vgg.basenet`` (scope ``vgg_16``). The resulting end points ``fc7``,
    ``conv5_3``, ``conv4_3`` and ``conv3_3`` are fused into two prediction
    maps inside the ``feature_fusion`` variable scope.

    Args:
        images: input image batch handed to ``mean_image_subtraction``.
        weight_decay: L2 regularization weight used for the backbone arg
            scope and for the fusion convolutions.
        is_training: forwarded to the batch-norm layers of the fusion
            convolutions.

    Returns:
        A ``(pixel_cls, link_cls)`` tuple: a 2-channel map and a 16-channel
        map produced by the two fusion heads.
    '''
    images = mean_image_subtraction(images)

    # NOTE(review): the VGG backbone is built under the ResNet arg scope and
    # ``is_training`` is not forwarded to it. Kept as-is because changing
    # either may alter the variables that get created -- confirm intent.
    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
        _, end_points = vgg.basenet(images, scope='vgg_16')

    # Pass the end-point tensors themselves as ``values``; the bound method
    # ``end_points.values`` is not a tensor and carries no graph information.
    with tf.variable_scope('feature_fusion',
                           values=list(end_points.values())):
        batch_norm_params = {
            'decay': 0.997,
            'epsilon': 1e-5,
            'scale': True,
            'is_training': is_training
        }
        with slim.arg_scope(
            [slim.conv2d],
                activation_fn=tf.nn.relu,
                normalizer_fn=slim.batch_norm,
                normalizer_params=batch_norm_params,
                weights_regularizer=slim.l2_regularizer(weight_decay)):

            # The pixel head must be built completely before the link head
            # so the auto-generated layer scope names keep their order.
            pixel_cls = _fuse_score_maps(end_points, 2)
            print('pixel_shape:{}'.format(pixel_cls.shape))

            link_cls = _fuse_score_maps(end_points, 16)
            print('link_shape:{}'.format(link_cls.shape))

    return pixel_cls, link_cls


def _fuse_score_maps(end_points, num_outputs):
    '''
    Fuse fc7/conv5_3/conv4_3/conv3_3 into one ``num_outputs``-channel map.

    Must be called inside the slim arg scope set up by ``model_vgg`` so the
    1x1 convolutions pick up its activation, normalizer and regularizer.
    '''
    # fc7 and conv5_3 are added directly, without any unpooling in between.
    fused = (slim.conv2d(end_points['fc7'], num_outputs, 1) +
             slim.conv2d(end_points['conv5_3'], num_outputs, 1))
    # Each earlier stage is added after running ``unpool`` on the running map.
    for stage in ('conv4_3', 'conv3_3'):
        fused = unpool(fused) + slim.conv2d(end_points[stage], num_outputs, 1)
    return slim.conv2d(fused, num_outputs, 1)