def EBlock(self, last, channels, resblocks=1, bottleneck=True, kernel=(1, 3),
           stride=(1, 2), format=DATA_FORMAT, activation=ACTIVATION,
           normalizer=None, regularizer=None, collections=None):
    """Encoder block with optional 1x1 bottleneck and a dense skip connection.

    The input is kept aside as ``skip``. The main path applies, in order:
    ``activation`` (pre-activation), an optional 1x1 convolution to reduce
    channels, a ``kernel``/``stride`` convolution, ``resblocks`` calls to
    ``self.ResBlock``, and ``layers.SEUnit``. ``skip`` is average-pooled by
    ``stride`` so its spatial size matches the main path, and then the two are
    concatenated along the channel axis.

    Args:
        last: input tensor laid out according to ``format``.
        channels: number of channels produced by the main-path convolutions.
        resblocks: number of residual blocks to stack.
        bottleneck: if true and the input has more channels than
            ``channels``, apply a 1x1 convolution before the strided one.
        kernel: kernel size of the strided convolution.
        stride: stride of the strided convolution; an int or a 2-sequence.
        format: 'NCHW' selects channels-first; any other value is treated
            as channels-last for the channel axis.
        activation: activation function, or None to skip it.
        normalizer: normalizer forwarded to ``self.ResBlock``.
        regularizer: weights regularizer for the convolutions.
        collections: variable collections for created variables.

    Returns:
        The concatenation ``[skip, main_path]`` along the channel axis.
        Its channel count is the input channels plus the channels that
        ``layers.SEUnit`` outputs (presumably ``channels``; not verified
        here).
    """
    initializer = tf.initializers.variance_scaling(
        1.0, 'fan_in', 'normal', self.random_seed, self.dtype)
    channel_axis = -3 if format == 'NCHW' else -1
    # Normalize stride to a 2-element list so the pooling window can be
    # built by list concatenation (an int or tuple would otherwise fail).
    if isinstance(stride, (list, tuple)):
        strides = list(stride)
    else:
        strides = [stride, stride]
    skip = last
    in_channels = skip.get_shape()[channel_axis]
    # pre-activation
    if activation:
        last = activation(last)
    # Optional 1x1 bottleneck: only reduces, never expands, the channel count.
    # Keyword arguments are used so an enclosing slim.arg_scope cannot clash
    # with positionally supplied values.
    if bottleneck and in_channels > channels:
        last = slim.conv2d(
            last, channels, [1, 1], stride=[1, 1], padding='SAME',
            data_format=format, rate=1, activation_fn=activation,
            normalizer_fn=None, weights_initializer=initializer,
            weights_regularizer=regularizer,
            variables_collections=collections)
    # strided convolution (linear; no activation or normalizer)
    last = slim.conv2d(
        last, channels, kernel, stride=strides, padding='SAME',
        data_format=format, rate=1, activation_fn=None, normalizer_fn=None,
        weights_initializer=initializer, weights_regularizer=regularizer,
        variables_collections=collections)
    # residual blocks
    for i in range(resblocks):
        with tf.variable_scope('ResBlock_{}'.format(i)):
            last = self.ResBlock(
                last, channels, format=format, activation=activation,
                normalizer=normalizer, regularizer=regularizer,
                collections=collections)
    # dense connection
    with tf.variable_scope('DenseConnection'):
        last = layers.SEUnit(last, channels, format, collections)
        # Downsample the skip path to the main path's spatial size. A unit
        # stride needs no pooling (a 1x1 stride-1 pool would be an identity).
        if strides != [1, 1]:
            if format == 'NCHW':
                pool_stride = [1, 1] + strides
            else:
                pool_stride = [1] + strides + [1]
            skip = tf.nn.avg_pool(skip, pool_stride, pool_stride, 'SAME',
                                  format)
        last = tf.concat([skip, last], channel_axis)
    return last
def ResBlock(self, last, channels, kernel=(1, 3), stride=(1, 1), biases=True,
             format=DATA_FORMAT, dilate=1, activation=ACTIVATION,
             normalizer=None, regularizer=None, collections=None):
    """Pre-activation residual block with an SE unit before the skip add.

    The main path is ``normalizer`` -> ``activation`` (both optional), then
    a convolution that has activation/normalizer, then a linear convolution,
    then ``layers.SEUnit``. The result is added element-wise to the input.

    Because of ``last += skip``, the main-path output must match the input
    shape. In practice ``channels`` should equal the input channel count and
    ``stride`` should be 1 (both convolutions use ``stride``).

    Args:
        last: input tensor laid out according to ``format``.
        channels: number of output channels of both convolutions.
        kernel: kernel size of both convolutions.
        stride: stride of both convolutions.
        biases: if true, create zero-initialized biases; if false, no biases.
        format: data format string passed to the convolutions and SEUnit.
        dilate: dilation rate of both convolutions.
        activation: activation function, or None to skip it.
        normalizer: normalizer function, or None to skip it.
        regularizer: weights regularizer for the convolutions.
        collections: variable collections for created variables.

    Returns:
        ``SEUnit(main_path) + last``.
    """
    # Kept separate from the boolean ``biases`` flag instead of shadowing it.
    biases_initializer = tf.initializers.zeros(self.dtype) if biases else None
    initializer = tf.initializers.variance_scaling(
        1.0, 'fan_in', 'normal', self.random_seed, self.dtype)
    skip = last
    # pre-activation
    if normalizer:
        last = normalizer(last)
    if activation:
        last = activation(last)
    # Convolutions. Keyword arguments are used so an enclosing slim.arg_scope
    # cannot clash with positionally supplied values.
    last = slim.conv2d(
        last, channels, kernel, stride=stride, padding='SAME',
        data_format=format, rate=dilate, activation_fn=activation,
        normalizer_fn=normalizer, normalizer_params=None,
        weights_initializer=initializer, weights_regularizer=regularizer,
        biases_initializer=biases_initializer,
        variables_collections=collections)
    last = slim.conv2d(
        last, channels, kernel, stride=stride, padding='SAME',
        data_format=format, rate=dilate, activation_fn=None,
        normalizer_fn=None, normalizer_params=None,
        weights_initializer=initializer, weights_regularizer=regularizer,
        biases_initializer=biases_initializer,
        variables_collections=collections)
    # residual connection
    last = layers.SEUnit(last, channels, format, collections)
    last += skip
    return last
def EBlock(self, last, channels, resblocks=1, kernel=(4, 4), stride=(2, 2),
           format=DATA_FORMAT, activation=ACTIVATION, normalizer=None,
           regularizer=None, collections=None):
    """Encoder block: strided convolution, residual blocks, dense skip.

    The input is kept aside as ``skip``. The main path applies, in order:
    ``activation`` (pre-activation), a ``kernel``/``stride`` convolution,
    ``resblocks`` calls to ``self.ResBlock``, and ``layers.SEUnit``.
    ``skip`` is average-pooled by ``stride`` so its spatial size matches the
    main path, and then the two are concatenated along the channel axis.

    Args:
        last: input tensor laid out according to ``format``.
        channels: number of channels produced by the main-path convolution.
        resblocks: number of residual blocks to stack.
        kernel: kernel size of the strided convolution.
        stride: stride of the strided convolution; an int or a 2-sequence.
        format: 'NCHW' selects channels-first; any other value is treated
            as channels-last for the channel axis.
        activation: activation function, or None to skip it.
        normalizer: normalizer forwarded to ``self.ResBlock``.
        regularizer: weights regularizer for the convolution.
        collections: variable collections for created variables.

    Returns:
        The concatenation ``[skip, main_path]`` along the channel axis.
    """
    initializer = tf.initializers.variance_scaling(
        1.0, 'fan_in', 'normal', self.random_seed, self.dtype)
    channel_axis = -3 if format == 'NCHW' else -1
    # Normalize stride to a 2-element list so the pooling window can be
    # built by list concatenation (an int or tuple would otherwise fail).
    if isinstance(stride, (list, tuple)):
        strides = list(stride)
    else:
        strides = [stride, stride]
    skip = last
    # pre-activation
    if activation:
        last = activation(last)
    # Strided convolution (linear). Keyword arguments are used so an
    # enclosing slim.arg_scope cannot clash with positional values.
    last = slim.conv2d(
        last, channels, kernel, stride=strides, padding='SAME',
        data_format=format, rate=1, activation_fn=None, normalizer_fn=None,
        weights_initializer=initializer, weights_regularizer=regularizer,
        variables_collections=collections)
    # residual blocks
    for i in range(resblocks):
        with tf.variable_scope('ResBlock_{}'.format(i)):
            last = self.ResBlock(
                last, channels, format=format, activation=activation,
                normalizer=normalizer, regularizer=regularizer,
                collections=collections)
    # dense connection
    with tf.variable_scope('DenseConnection'):
        last = layers.SEUnit(last, channels, format, collections)
        # Downsample the skip path to the main path's spatial size. A unit
        # stride needs no pooling (a 1x1 stride-1 pool would be an identity).
        if strides != [1, 1]:
            if format == 'NCHW':
                pool_stride = [1, 1] + strides
            else:
                pool_stride = [1] + strides + [1]
            skip = tf.nn.avg_pool(skip, pool_stride, pool_stride, 'SAME',
                                  format)
        last = tf.concat([skip, last], channel_axis)
    return last