def bn_feature_net_2D(receptive_field=61,
                      input_shape=(256, 256, 1),
                      inputs=None,
                      n_features=3,
                      n_channels=1,
                      reg=1e-5,
                      n_conv_filters=64,
                      n_dense_filters=200,
                      VGG_mode=False,
                      init='he_normal',
                      norm_method='std',
                      location=False,
                      dilated=False,
                      padding=False,
                      padding_mode='reflect',
                      multires=False,
                      include_top=True):
    """Creates a 2D featurenet.

    The network is a stack of Conv2D -> BatchNormalization -> ReLU blocks.
    After every second block the feature map is max-pooled (dilated
    max-pooled when ``dilated`` is True) and the remaining receptive field
    is halved. The loop stops once the remaining receptive field is 4 or
    less; a final Conv2D whose kernel equals that remainder is followed by
    BatchNormalization and ReLU.

    Args:
        receptive_field (int): the receptive field of the neural network.
        input_shape (tuple): If no input tensor, create one with this shape.
            Only honored when ``dilated`` is True; otherwise it is replaced
            by a ``receptive_field`` x ``receptive_field`` patch with
            ``n_channels`` channels.
        inputs (tensor): optional input tensor
        n_features (int): Number of output features
        n_channels (int): number of input channels (only used when
            ``dilated`` is False).
        reg (float): L2 regularization value applied to the kernels.
        n_conv_filters (int): number of convolutional filters
        n_dense_filters (int): number of dense filters
        VGG_mode (bool): If True, double the number of convolutional
            filters after every pooling step.
        init (str): Method for initializing weights.
        norm_method (str): ImageNormalization mode to use
        location (bool): Whether to include location data
        dilated (bool): Whether to use dilated pooling. Forces ``padding``
            to True.
        padding (bool): Whether to use padding.
        padding_mode (str): Type of padding, one of 'reflect' or 'zero'
        multires (bool): Enables multi-resolution mode: the output of every
            pooling step and the final feature map, each center-cropped to
            the final feature map's spatial size, are concatenated along
            the channel axis.
        include_top (bool): Whether to include the final layers of the
            model (two TensorProduct layers, Flatten when not dilated, and
            Softmax).

    Returns:
        tensorflow.keras.Model: 2D FeatureNet

    Raises:
        ValueError: if padding is enabled and ``padding_mode`` is neither
            'reflect' nor 'zero'.
    """
    # Create layers list (x) to store all of the layers.
    # We need to use the functional API to enable the multiresolution mode
    x = []

    win = (receptive_field - 1) // 2

    if dilated:
        padding = True

    # Previously an unknown mode silently produced an unpadded network.
    if padding and padding_mode not in ('reflect', 'zero'):
        raise ValueError('padding_mode must be one of "reflect" or "zero", '
                         'got {!r}'.format(padding_mode))

    if K.image_data_format() == 'channels_first':
        channel_axis = 1
        row_axis = 2
        col_axis = 3
        if not dilated:
            input_shape = (n_channels, receptive_field, receptive_field)
    else:
        row_axis = 1
        col_axis = 2
        channel_axis = -1
        if not dilated:
            input_shape = (receptive_field, receptive_field, n_channels)

    if inputs is not None:
        if not K.is_keras_tensor(inputs):
            img_input = Input(tensor=inputs, shape=input_shape)
        else:
            img_input = inputs
        x.append(img_input)
    else:
        x.append(Input(shape=input_shape))

    x.append(ImageNormalization2D(norm_method=norm_method,
                                  filter_size=receptive_field)(x[-1]))

    if padding:
        if padding_mode == 'reflect':
            x.append(ReflectionPadding2D(padding=(win, win))(x[-1]))
        elif padding_mode == 'zero':
            x.append(ZeroPadding2D(padding=(win, win))(x[-1]))

    if location:
        # Location features are concatenated with the (normalized, possibly
        # padded) input along the channel axis.
        x.append(Location2D(in_shape=tuple(x[-1].shape.as_list()[1:]))(x[-1]))
        x.append(Concatenate(axis=channel_axis)([x[-2], x[-1]]))

    # Indices into x of the layers to merge in multires mode.
    layers_to_concat = []

    rf_counter = receptive_field
    block_counter = 0
    d = 1

    while rf_counter > 4:
        # An odd remainder uses a 4x4 kernel (shrinks by 3), an even one a
        # 3x3 kernel (shrinks by 2), so the remainder is always even after
        # a convolution and halves exactly when pooled.
        filter_size = 3 if rf_counter % 2 == 0 else 4
        x.append(Conv2D(n_conv_filters, filter_size, dilation_rate=d,
                        kernel_initializer=init, padding='valid',
                        kernel_regularizer=l2(reg))(x[-1]))
        x.append(BatchNormalization(axis=channel_axis)(x[-1]))
        x.append(Activation('relu')(x[-1]))

        block_counter += 1
        rf_counter -= filter_size - 1

        if block_counter % 2 == 0:
            if dilated:
                x.append(DilatedMaxPool2D(dilation_rate=d,
                                          pool_size=(2, 2))(x[-1]))
                d *= 2
            else:
                x.append(MaxPool2D(pool_size=(2, 2))(x[-1]))

            if VGG_mode:
                n_conv_filters *= 2

            rf_counter = rf_counter // 2

            if multires:
                layers_to_concat.append(len(x) - 1)

    if multires:
        # Always include the final feature map. When the loop ends on a
        # pooling step it is already the last recorded layer; when it ends
        # on an unpooled conv block, leaving it out would disconnect that
        # block from the model output.
        final_idx = len(x) - 1
        if not layers_to_concat or layers_to_concat[-1] != final_idx:
            layers_to_concat.append(final_idx)

        # With a single resolution there is nothing to merge (and an empty
        # or one-element Concatenate is rejected by some Keras versions).
        if len(layers_to_concat) > 1:
            target_shape = x[-1].get_shape().as_list()
            cropped = []
            for idx in layers_to_concat:
                source_shape = x[idx].get_shape().as_list()
                cropping = []
                for axis in (row_axis, col_axis):
                    # Center crop; an odd leftover pixel is taken from the end.
                    diff = int(source_shape[axis] - target_shape[axis])
                    cropping.append((diff // 2, diff - diff // 2))
                cropped.append(Cropping2D(cropping=tuple(cropping))(x[idx]))
            x.append(Concatenate(axis=channel_axis)(cropped))

    x.append(Conv2D(n_dense_filters, (rf_counter, rf_counter),
                    dilation_rate=d, kernel_initializer=init,
                    padding='valid', kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    if include_top:
        x.append(TensorProduct(n_dense_filters, kernel_initializer=init,
                               kernel_regularizer=l2(reg))(x[-1]))
        x.append(BatchNormalization(axis=channel_axis)(x[-1]))
        x.append(Activation('relu')(x[-1]))

        x.append(TensorProduct(n_features, kernel_initializer=init,
                               kernel_regularizer=l2(reg))(x[-1]))

        if not dilated:
            x.append(Flatten()(x[-1]))

        x.append(Softmax(axis=channel_axis)(x[-1]))

    # A user-supplied tensor may not itself be a graph input; trace back to
    # the real source inputs for the Model.
    if inputs is not None:
        real_inputs = keras_utils.get_source_inputs(x[0])
    else:
        real_inputs = x[0]

    model = Model(inputs=real_inputs, outputs=x[-1])

    return model
def bn_feature_net_2D(receptive_field=61,
                      input_shape=(256, 256, 1),
                      n_features=3,
                      n_channels=1,
                      reg=1e-5,
                      n_conv_filters=64,
                      n_dense_filters=200,
                      VGG_mode=False,
                      init='he_normal',
                      norm_method='std',
                      location=False,
                      dilated=False,
                      padding=False,
                      padding_mode='reflect',
                      multires=False,
                      include_top=True):
    """Creates a 2D featurenet (variant that always creates its own Input).

    The network is a stack of Conv2D -> BatchNormalization -> ReLU blocks.
    After every second block the feature map is max-pooled (dilated
    max-pooled when ``dilated`` is True) and the remaining receptive field
    is halved. The loop stops once the remaining receptive field is 4 or
    less; a final Conv2D whose kernel equals that remainder is followed by
    two TensorProd2D layers.

    Args:
        receptive_field (int): the receptive field of the neural network.
        input_shape (tuple): Shape of the created Input. Only honored when
            ``dilated`` is True; otherwise it is replaced by a
            ``receptive_field`` x ``receptive_field`` patch with
            ``n_channels`` channels.
        n_features (int): Number of output features
        n_channels (int): number of input channels (only used when
            ``dilated`` is False).
        reg (float): L2 regularization value applied to the kernels.
        n_conv_filters (int): number of convolutional filters
        n_dense_filters (int): number of dense filters
        VGG_mode (bool): If True, double the number of convolutional
            filters after every pooling step.
        init (str): Method for initializing weights.
        norm_method (str): ImageNormalization mode to use
        location (bool): Whether to include location data
        dilated (bool): Whether to use dilated pooling. Forces ``padding``
            to True.
        padding (bool): Whether to use padding.
        padding_mode (str): Type of padding, one of 'reflect' or 'zero'
        multires (bool): Enables multi-resolution mode: the output of every
            pooling step and the final feature map, each center-cropped to
            the final feature map's spatial size, are concatenated along
            the channel axis.
        include_top (bool): Whether to apply the final Softmax. The
            TensorProd2D layers (and Flatten when not dilated) are built
            regardless.

    Returns:
        tensorflow.keras.Model: 2D FeatureNet

    Raises:
        ValueError: if padding is enabled and ``padding_mode`` is neither
            'reflect' nor 'zero'.
    """
    # Create layers list (x) to store all of the layers.
    # We need to use the functional API to enable the multiresolution mode
    x = []

    win = (receptive_field - 1) // 2

    if dilated:
        padding = True

    # Previously an unknown mode silently produced an unpadded network.
    if padding and padding_mode not in ('reflect', 'zero'):
        raise ValueError('padding_mode must be one of "reflect" or "zero", '
                         'got {!r}'.format(padding_mode))

    if K.image_data_format() == 'channels_first':
        channel_axis = 1
        row_axis = 2
        col_axis = 3
        if not dilated:
            input_shape = (n_channels, receptive_field, receptive_field)
    else:
        row_axis = 1
        col_axis = 2
        channel_axis = -1
        if not dilated:
            input_shape = (receptive_field, receptive_field, n_channels)

    x.append(Input(shape=input_shape))
    x.append(ImageNormalization2D(norm_method=norm_method,
                                  filter_size=receptive_field)(x[-1]))

    if padding:
        if padding_mode == 'reflect':
            x.append(ReflectionPadding2D(padding=(win, win))(x[-1]))
        elif padding_mode == 'zero':
            x.append(ZeroPadding2D(padding=(win, win))(x[-1]))

    if location:
        # Location features are concatenated with the (normalized, possibly
        # padded) input along the channel axis.
        x.append(Location(in_shape=tuple(x[-1].shape.as_list()[1:]))(x[-1]))
        x.append(Concatenate(axis=channel_axis)([x[-2], x[-1]]))

    # Indices into x of the layers to merge in multires mode.
    layers_to_concat = []

    rf_counter = receptive_field
    block_counter = 0
    d = 1

    while rf_counter > 4:
        # An odd remainder uses a 4x4 kernel (shrinks by 3), an even one a
        # 3x3 kernel (shrinks by 2), so the remainder is always even after
        # a convolution and halves exactly when pooled.
        filter_size = 3 if rf_counter % 2 == 0 else 4
        x.append(Conv2D(n_conv_filters, (filter_size, filter_size),
                        dilation_rate=d, kernel_initializer=init,
                        padding='valid', kernel_regularizer=l2(reg))(x[-1]))
        x.append(BatchNormalization(axis=channel_axis)(x[-1]))
        x.append(Activation('relu')(x[-1]))

        block_counter += 1
        rf_counter -= filter_size - 1

        if block_counter % 2 == 0:
            if dilated:
                x.append(DilatedMaxPool2D(dilation_rate=d,
                                          pool_size=(2, 2))(x[-1]))
                d *= 2
            else:
                x.append(MaxPool2D(pool_size=(2, 2))(x[-1]))

            if VGG_mode:
                n_conv_filters *= 2

            rf_counter = rf_counter // 2

            if multires:
                layers_to_concat.append(len(x) - 1)

    if multires:
        # Always include the final feature map. When the loop ends on a
        # pooling step it is already the last recorded layer; when it ends
        # on an unpooled conv block, leaving it out would disconnect that
        # block from the model output.
        final_idx = len(x) - 1
        if not layers_to_concat or layers_to_concat[-1] != final_idx:
            layers_to_concat.append(final_idx)

        # With a single resolution there is nothing to merge (and an empty
        # or one-element Concatenate is rejected by some Keras versions).
        if len(layers_to_concat) > 1:
            target_shape = x[-1].get_shape().as_list()
            cropped = []
            for idx in layers_to_concat:
                source_shape = x[idx].get_shape().as_list()
                cropping = []
                for axis in (row_axis, col_axis):
                    # Center crop; an odd leftover pixel is taken from the end.
                    diff = int(source_shape[axis] - target_shape[axis])
                    cropping.append((diff // 2, diff - diff // 2))
                cropped.append(Cropping2D(cropping=tuple(cropping))(x[idx]))
            x.append(Concatenate(axis=channel_axis)(cropped))

    x.append(Conv2D(n_dense_filters, (rf_counter, rf_counter),
                    dilation_rate=d, kernel_initializer=init,
                    padding='valid', kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(TensorProd2D(n_dense_filters, n_dense_filters,
                          kernel_initializer=init,
                          kernel_regularizer=l2(reg))(x[-1]))
    x.append(BatchNormalization(axis=channel_axis)(x[-1]))
    x.append(Activation('relu')(x[-1]))

    x.append(TensorProd2D(n_dense_filters, n_features,
                          kernel_initializer=init,
                          kernel_regularizer=l2(reg))(x[-1]))

    if not dilated:
        x.append(Flatten()(x[-1]))

    if include_top:
        x.append(Softmax(axis=channel_axis)(x[-1]))

    model = Model(inputs=x[0], outputs=x[-1])

    return model