def dilated_residual_2d(inputs, filters, kernel_size=3, rate_mult=2, dropout=0,
                        repeat=1, symmetric=True, **kwargs):
    """Stack `repeat` residual units built from dilated 2D conv blocks.

    Each unit:
      1. applies ``conv_block_2d`` with ``filters`` channels, ``kernel_size``
         and the current dilation rate (rounded to an int);
      2. applies a second ``conv_block_2d`` (its default kernel size) that
         maps back to the unit input's last-axis size, with ``dropout``;
      3. adds the unit input back in (residual connection);
      4. optionally applies ``layers.Symmetrize2D``.
    The dilation rate starts at 1.0 and is multiplied by ``rate_mult`` after
    every unit.

    Args:
      inputs: input tensor; its last-axis size sets the residual width.
      filters: number of channels of the dilated convolution.
      kernel_size: kernel size of the dilated convolution.
      rate_mult: multiplicative growth of the dilation rate per unit.
      dropout: dropout rate, used only by the second conv block of a unit.
      repeat: number of residual units (0 returns ``inputs`` unchanged).
      symmetric: if True, symmetrize after every residual add.
      **kwargs: forwarded unchanged to both ``conv_block_2d`` calls.

    Returns:
      The output tensor of the last residual unit.
    """
    x = inputs
    rate = 1.0

    for _ in range(repeat):
        shortcut = x

        # dilated convolution, BN gamma initialised to ones
        dilated = conv_block_2d(
            x,
            filters=filters,
            kernel_size=kernel_size,
            dilation_rate=int(np.round(rate)),
            bn_gamma='ones',
            **kwargs)

        # project back to the shortcut's channel count; with batch_norm
        # enabled via kwargs, the zero-initialised gamma presumably makes
        # each unit start close to an identity mapping
        projected = conv_block_2d(
            dilated,
            filters=shortcut.shape[-1],
            dropout=dropout,
            bn_gamma='zeros',
            **kwargs)

        x = tf.keras.layers.Add()([shortcut, projected])

        if symmetric:
            x = layers.Symmetrize2D()(x)

        rate *= rate_mult

    return x
def symmetrize_2d(inputs, **kwargs):
    """Apply a fresh ``layers.Symmetrize2D`` layer to ``inputs``.

    Extra keyword arguments are accepted for interface compatibility with
    the other block builders and are ignored.
    """
    symmetrize = layers.Symmetrize2D()
    return symmetrize(inputs)
def conv_block_2d(inputs, filters=128, activation='relu', conv_type='standard',
                  kernel_size=1, strides=1, dilation_rate=1, l2_scale=0,
                  dropout=0, pool_size=1, batch_norm=False, bn_momentum=0.99,
                  bn_gamma='ones', bn_type='standard', symmetric=False):
    """Construct a single 2D convolution block.

    Stages, in order (optional stages run only when enabled):
    activation -> convolution -> batch norm -> dropout -> max pool ->
    symmetrize.

    Args:
      inputs: input tensor.
      filters: number of output channels of the convolution.
      activation: activation name passed to ``layers.activate``; it is
        applied BEFORE the convolution (pre-activation ordering).
      conv_type: 'separable' selects ``SeparableConv2D``; any other value
        selects ``Conv2D``.
      kernel_size: convolution kernel size.
      strides: convolution strides.
      dilation_rate: convolution dilation rate.
      l2_scale: L2 penalty factor applied to the convolution weights
        (for separable convolutions, to both depthwise and pointwise kernels).
      dropout: dropout rate; the Dropout layer is skipped when <= 0.
      pool_size: max-pool size ('same' padding); skipped when <= 1.
      batch_norm: whether to apply batch normalization after the conv.
      bn_momentum: batch-norm momentum.
      bn_gamma: batch-norm ``gamma_initializer`` (e.g. 'ones' or 'zeros').
      bn_type: 'sync' selects ``SyncBatchNormalization``; any other value
        selects ``BatchNormalization``.
      symmetric: if True, apply ``layers.Symmetrize2D`` at the end.

    Returns:
      The output tensor of the block.
    """
    # flow through variable current
    current = inputs

    # activation (applied before the convolution)
    current = layers.activate(current, activation)

    # weight initializer / regularizer for every convolution kernel
    initializer = 'he_normal'
    regularizer = tf.keras.regularizers.l2(l2_scale)

    # choose convolution type and the matching weight arguments
    if conv_type == 'separable':
        # SeparableConv2D has no single `kernel`: its weights are the
        # depthwise and pointwise kernels, which take their own
        # initializer/regularizer arguments. Passing kernel_* here would
        # leave both kernels at their default init with no L2 penalty
        # (and may be rejected as unknown kwargs by Keras 3).
        conv_layer = tf.keras.layers.SeparableConv2D
        weight_kwargs = dict(
            depthwise_initializer=initializer,
            pointwise_initializer=initializer,
            depthwise_regularizer=regularizer,
            pointwise_regularizer=regularizer)
    else:
        conv_layer = tf.keras.layers.Conv2D
        weight_kwargs = dict(
            kernel_initializer=initializer,
            kernel_regularizer=regularizer)

    # convolution (no bias; batch norm, if enabled, supplies the shift)
    current = conv_layer(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        use_bias=False,
        dilation_rate=dilation_rate,
        **weight_kwargs)(current)

    # batch norm
    if batch_norm:
        if bn_type == 'sync':
            bn_layer = tf.keras.layers.experimental.SyncBatchNormalization
        else:
            bn_layer = tf.keras.layers.BatchNormalization
        current = bn_layer(momentum=bn_momentum,
                           gamma_initializer=bn_gamma)(current)

    # dropout
    if dropout > 0:
        current = tf.keras.layers.Dropout(rate=dropout)(current)

    # pool
    if pool_size > 1:
        current = tf.keras.layers.MaxPool2D(pool_size=pool_size,
                                            padding='same')(current)

    # symmetric
    if symmetric:
        current = layers.Symmetrize2D()(current)

    return current