def dense_blk(name, l, ch, ksize, count, split=1, padding='valid'):
    """Build a stack of `count` densely connected units under scope `name`.

    Each unit runs BNReLU -> Conv2D('conv1') -> Conv2D('conv2') on the
    running tensor and concatenates the result onto it along axis 1.
    A final BNReLU ('blk_bna') is applied once after all units.

    NOTE(review): this file defines another `dense_blk` further down; if both
    live at module level the later definition replaces this one.

    Args:
        name: variable-scope name wrapping the whole block.
        l: input tensor. Axes 2 and 3 are used as the croppable dimensions
            and axis 1 as the concat axis (presumably NCHW -- confirm
            against the caller's data_format).
        ch: pair `(ch[0], ch[1])` of output channels for conv1 / conv2.
        ksize: pair `(ksize[0], ksize[1])` of kernel sizes for conv1 / conv2.
        count: number of units to stack.
        split: forwarded to the second convolution's `split` argument.
        padding: padding mode forwarded to both convolutions. With 'valid'
            the running tensor is cropped to the new unit's size before the
            concat.

    Returns:
        The tensor produced by the trailing BNReLU.
    """
    with tf.variable_scope(name):
        for unit_idx in range(count):
            with tf.variable_scope('blk/' + str(unit_idx)):
                feat = BNReLU('preact_bna', l)
                feat = Conv2D(
                    'conv1', feat, ch[0], ksize[0],
                    padding=padding, activation=BNReLU)
                feat = Conv2D(
                    'conv2', feat, ch[1], ksize[1],
                    padding=padding, split=split)

                if padding == 'valid':
                    # Unpadded convolutions shrink axes 2/3, so trim the
                    # running tensor by the same margin before joining.
                    feat_dims = feat.get_shape().as_list()
                    prev_dims = l.get_shape().as_list()
                    margin = (prev_dims[2] - feat_dims[2],
                              prev_dims[3] - feat_dims[3])
                    l = crop_op(l, margin)

                l = tf.concat([l, feat], axis=1)

        l = BNReLU('blk_bna', l)
    return l
def _cbam_refine(x, ratio, kernel_size):
    """Re-weight `x` with channel attention followed by spatial attention.

    Must be called inside the variable scope that should own the attention
    weights. The two FullyConnected layers are invoked twice under the same
    names (once for the average-pooled and once for the max-pooled
    descriptor), so the enclosing scope has to allow variable reuse.

    Args:
        x: 4-D tensor; axis 1 is treated as the channel axis and axes 2/3
            are pooled over (presumably NCHW -- confirm against the caller).
        ratio: reduction factor for the hidden width of the channel MLP.
        kernel_size: kernel size of the spatial-attention convolution.

    Returns:
        `x` multiplied by the channel gate and then by the spatial gate.

    Raises:
        ValueError: if the size of axis 1 of `x` is not statically known.
    """
    kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
    bias_initializer = tf.constant_initializer(value=0.0)

    channel = x.get_shape().as_list()[1]
    if channel is None:
        raise ValueError(
            'CBAM requires a statically known size for axis 1 of its input.')
    # Never let the bottleneck collapse to zero units when channel < ratio.
    hidden = max(channel // ratio, 1)

    def shared_mlp(pooled, hidden_name, out_name):
        # Two-layer bottleneck MLP. The layer names are fixed so that every
        # call resolves to the same weights through variable reuse.
        assert pooled.get_shape()[1:] == (channel, 1, 1)
        squeezed = FullyConnected(
            'fullyconnected1', pooled,
            units=hidden,
            activation=tf.nn.relu,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer)
        squeezed = tf.reshape(squeezed, [-1, hidden, 1, 1], name=hidden_name)
        assert squeezed.get_shape()[1:] == (hidden, 1, 1)
        expanded = FullyConnected(
            'fullyconnectedB', squeezed,
            units=channel,
            activation=None,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer)
        expanded = tf.reshape(expanded, [-1, channel, 1, 1], name=out_name)
        assert expanded.get_shape()[1:] == (channel, 1, 1)
        return expanded

    # Channel attention: gate = sigmoid(MLP(avg-pooled) + MLP(max-pooled)).
    avg_pool = tf.reduce_mean(x, axis=[2, 3], keepdims=True)
    avg_pool = shared_mlp(avg_pool, 'reshape_avgpoolB', 'reshape_avgpoolC')
    max_pool = tf.reduce_max(x, axis=[2, 3], keepdims=True)
    max_pool = shared_mlp(max_pool, 'reshape_maxpoolA', 'reshape_maxpoolB')
    scale = tf.sigmoid(avg_pool + max_pool, 'sigmoid')
    channel_map = tf.math.multiply(x, scale, name='merge_Channel')

    # Spatial attention: pool over axis 1, then a single-output convolution.
    avg_pool_spatial = tf.reduce_mean(channel_map, axis=[1], keepdims=True)
    assert avg_pool_spatial.get_shape()[1] == 1
    max_pool_spatial = tf.reduce_max(channel_map, axis=[1], keepdims=True)
    assert max_pool_spatial.get_shape()[1] == 1
    concat = tf.concat([avg_pool_spatial, max_pool_spatial], 1)
    assert concat.get_shape()[1] == 2
    concat = tp.models.Conv2D(
        'spatialconvolution', concat, 1, kernel_size,
        kernel_initializer=kernel_initializer,
        use_bias=False)
    # Holds only if the convolution keeps its output channels on axis 1
    # (presumably set through an enclosing argscope -- confirm).
    assert concat.get_shape()[1] == 1
    concat = tf.sigmoid(concat, 'sigmoid')
    return concat * channel_map


def dense_blk(name, l, ch, ksize, count, split=1, padding='valid',
              cbam_ratio=8, cbam_ksize=7):
    """Build a stack of `count` residual units with CBAM attention.

    Each unit runs BNReLU -> Conv2D('conv1') -> Conv2D('conv2') on the
    running tensor, refines the result with channel and spatial attention,
    and adds it to a `ch[1]`-channel projection ('conv2forl') of the running
    tensor. A final BNReLU ('blk_bna') is applied once after all units.

    NOTE(review): this file also contains an earlier `dense_blk` that
    concatenates instead of adding; if both live at module level, this later
    definition is the one callers get.

    Args:
        name: variable-scope name wrapping the whole block.
        l: input tensor. Axes 2 and 3 are used as the croppable dimensions
            and axis 1 as the channel axis (presumably NCHW -- confirm).
        ch: pair `(ch[0], ch[1])` of output channels for conv1 / conv2.
        ksize: pair `(ksize[0], ksize[1])` of kernel sizes for conv1 / conv2;
            `ksize[0]` is also the kernel size of the projection conv.
        count: number of units to stack.
        split: forwarded to the `split` argument of conv2 and the projection.
        padding: padding mode for conv1 / conv2. With 'valid' the running
            tensor is cropped to the new unit's size first.
        cbam_ratio: reduction factor of the channel-attention MLP.
        cbam_ksize: kernel size of the spatial-attention convolution.

    Returns:
        The tensor produced by the trailing BNReLU.
    """
    with tf.variable_scope(name):
        for i in range(count):
            with tf.variable_scope('blk/' + str(i)):
                x = BNReLU('preact_bna', l)
                x = Conv2D('conv1', x, ch[0], ksize[0],
                           padding=padding, activation=BNReLU)
                x = Conv2D('conv2', x, ch[1], ksize[1],
                           padding=padding, split=split)

                if padding == 'valid':
                    # Unpadded convolutions shrink axes 2/3, so trim the
                    # running tensor by the same margin.
                    x_shape = x.get_shape().as_list()
                    l_shape = l.get_shape().as_list()
                    l = crop_op(l, (l_shape[2] - x_shape[2],
                                    l_shape[3] - x_shape[3]))

                # Project the running tensor to ch[1] channels so it can be
                # summed with the attention output.
                # NOTE(review): applied for every padding mode because the
                # add below needs matching channels -- confirm intent.
                l = Conv2D('conv2forl', l, ch[1], ksize[0],
                           padding='same', split=split)

                with tf.variable_scope('CBAM', reuse=tf.AUTO_REUSE):
                    # AUTO_REUSE lets the avg- and max-pooled paths share
                    # one set of MLP weights.
                    l = l + _cbam_refine(x, cbam_ratio, cbam_ksize)

        l = BNReLU('blk_bna', l)
    return l