def _build_network(self):
    """Build the VGG16 base network selected by ``config.model_type``.

    The innermost slim arg scope is stored on ``self.arg_scope`` and the
    two values returned by ``vgg.basenet`` are stored on ``self.net`` and
    ``self.end_points``.

    Raises:
        ValueError: if ``config.model_type`` is neither
            ``MODEL_TYPE_vgg16`` nor ``MODEL_TYPE_vgg16_no_dilation``.
    """
    import config

    # The two supported variants differ only in the ``dilation`` keyword
    # handed to ``vgg.basenet``; everything else is shared below.
    if config.model_type == MODEL_TYPE_vgg16:
        variant_kwargs = {}
    elif config.model_type == MODEL_TYPE_vgg16_no_dilation:
        variant_kwargs = {'dilation': False}
    else:
        raise ValueError('model_type not supported:%s' % (config.model_type))

    # Imported only after the model type has been validated.
    from nets import vgg

    conv_defaults = {
        'activation_fn': tf.nn.relu,
        'weights_regularizer': slim.l2_regularizer(config.weight_decay),
        'weights_initializer': tf.contrib.layers.xavier_initializer(),
        'biases_initializer': tf.zeros_initializer(),
    }
    padded_ops = [slim.conv2d, slim.max_pool2d]

    with slim.arg_scope([slim.conv2d], **conv_defaults):
        with slim.arg_scope(padded_ops, padding='SAME') as scope:
            self.arg_scope = scope
            self.net, self.end_points = vgg.basenet(
                inputs=self.inputs, pooling='MAX', **variant_kwargs)
def _fuse_score_maps(end_points, num_outputs):
    """Fuse the fc7/conv5_3/conv4_3/conv3_3 end points into one map.

    Each end point is projected to ``num_outputs`` channels with a 1x1
    convolution; the running sum is passed through ``unpool`` before every
    finer layer is added, and a final 1x1 convolution produces the result.

    Must be called inside the caller's ``slim.arg_scope`` so the conv
    defaults apply.
    """
    # NOTE: keep the creation order of the slim.conv2d calls stable.
    # The layers are unnamed, so slim presumably numbers their scopes in
    # call order -- reordering would rename variables (confirm before
    # touching if existing checkpoints must keep loading).
    fused = (slim.conv2d(end_points['fc7'], num_outputs, 1) +
             slim.conv2d(end_points['conv5_3'], num_outputs, 1))
    for layer_name in ('conv4_3', 'conv3_3'):
        fused = unpool(fused) + slim.conv2d(
            end_points[layer_name], num_outputs, 1)
    return slim.conv2d(fused, num_outputs, 1)


def model_vgg(images, weight_decay=1e-5, is_training=True):
    """Define the model on top of the VGG base network (``vgg.basenet``).

    Args:
        images: input image batch; passed through
            ``mean_image_subtraction`` before the backbone.
        weight_decay: L2 regularisation weight, used both for the backbone
            arg scope and for the fusion convolutions.
        is_training: forwarded to the batch-norm layers of the fusion
            convolutions.

    Returns:
        A ``(pixel_cls, link_cls)`` tuple: a 2-channel map and a
        16-channel map, both fused from the ``fc7``, ``conv5_3``,
        ``conv4_3`` and ``conv3_3`` end points.
    """
    images = mean_image_subtraction(images)

    # NOTE(review): the backbone is VGG but is built under the ResNet arg
    # scope; kept as-is because changing it would alter the network.
    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
        _, end_points = vgg.basenet(images, scope='vgg_16')

    # Pass the end-point tensors themselves (the original passed the bound
    # ``values`` method object, which carries no graph information).
    with tf.variable_scope('feature_fusion',
                           values=list(end_points.values())):
        batch_norm_params = {
            'decay': 0.997,
            'epsilon': 1e-5,
            'scale': True,
            'is_training': is_training,
        }
        with slim.arg_scope(
                [slim.conv2d],
                activation_fn=tf.nn.relu,
                normalizer_fn=slim.batch_norm,
                normalizer_params=batch_norm_params,
                weights_regularizer=slim.l2_regularizer(weight_decay)):
            # Pixel branch first, then link branch (same order as before).
            pixel_cls = _fuse_score_maps(end_points, 2)
            print('pixel_shape:{}'.format(pixel_cls.shape))

            link_cls = _fuse_score_maps(end_points, 16)
            print('link_shape:{}'.format(link_cls.shape))

    return pixel_cls, link_cls