def create_model(self, img_shape, num_class):
    """Build a vanilla U-Net segmentation model.

    Encoder: five double-conv stages (32 -> 512 filters), each followed by
    2x2 max pooling.  Decoder: four 2x2 up-sampling stages; at each stage the
    matching encoder feature map is centre-cropped (via ``get_crop_shape``)
    and concatenated as a skip connection, followed by two 3x3 convolutions.
    The result is zero-padded back to the input spatial size and projected to
    ``num_class`` channels with a 1x1 convolution (no activation: raw logits).

    :param img_shape: input image shape, e.g. (H, W, C) — channels last
    :param num_class: number of output channels / classes
    :return: a Keras ``Model`` mapping images to per-pixel class logits
    """
    concat_axis = 3  # channels-last: skip connections concatenate on channels
    inputs = layers.Input(shape = img_shape)

    # --- encoder ---
    conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same', name='conv1_1')(inputs)
    conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = layers.MaxPooling2D(pool_size=(2, 2))(conv4)

    # --- bottleneck ---
    conv5 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)

    # --- decoder: upsample, crop the skip to match, concatenate, double conv ---
    up_conv5 = layers.UpSampling2D(size=(2, 2))(conv5)
    ch, cw = self.get_crop_shape(conv4, up_conv5)
    crop_conv4 = layers.Cropping2D(cropping=(ch,cw))(conv4)
    up6 = layers.concatenate([up_conv5, crop_conv4], axis=concat_axis)
    conv6 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)

    up_conv6 = layers.UpSampling2D(size=(2, 2))(conv6)
    ch, cw = self.get_crop_shape(conv3, up_conv6)
    crop_conv3 = layers.Cropping2D(cropping=(ch,cw))(conv3)
    up7 = layers.concatenate([up_conv6, crop_conv3], axis=concat_axis)
    conv7 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)

    up_conv7 = layers.UpSampling2D(size=(2, 2))(conv7)
    ch, cw = self.get_crop_shape(conv2, up_conv7)
    crop_conv2 = layers.Cropping2D(cropping=(ch,cw))(conv2)
    up8 = layers.concatenate([up_conv7, crop_conv2], axis=concat_axis)
    conv8 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)

    up_conv8 = layers.UpSampling2D(size=(2, 2))(conv8)
    ch, cw = self.get_crop_shape(conv1, up_conv8)
    crop_conv1 = layers.Cropping2D(cropping=(ch,cw))(conv1)
    up9 = layers.concatenate([up_conv8, crop_conv1], axis=concat_axis)
    conv9 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)

    # Pad back to the exact input size, then 1x1 projection to class logits.
    ch, cw = self.get_crop_shape(inputs, conv9)
    conv9 = layers.ZeroPadding2D(padding=((ch[0], ch[1]), (cw[0], cw[1])))(conv9)
    conv10 = layers.Conv2D(num_class, (1, 1))(conv9)

    model = models.Model(inputs=inputs, outputs=conv10)
    return model
def build_unet(self):
    """Build a U-Net graph on ``self.model_input`` and store the prediction.

    Same encoder/decoder topology as ``create_model`` (five double-conv
    stages down, four upsample+crop+concat stages up), but the head is a
    3x3 convolution with 2 output channels and a sigmoid activation.

    Side effect: assigns the output tensor to ``self.img_pred``; returns
    nothing.  Reads ``self.model_input`` (a Keras input tensor).
    """
    # --- encoder ---
    conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same', name='conv1_1')(self.model_input)
    conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = layers.MaxPooling2D(pool_size=(2, 2))(conv4)

    # --- bottleneck ---
    conv5 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)

    # --- decoder with cropped skip connections (channels-last, axis=3) ---
    up_conv5 = layers.UpSampling2D(size=(2, 2))(conv5)
    ch, cw = self.get_crop_shape(conv4, up_conv5)
    crop_conv4 = layers.Cropping2D(cropping=(ch,cw))(conv4)
    up6 = layers.concatenate([up_conv5, crop_conv4], axis=3)
    conv6 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)

    up_conv6 = layers.UpSampling2D(size=(2, 2))(conv6)
    ch, cw = self.get_crop_shape(conv3, up_conv6)
    crop_conv3 = layers.Cropping2D(cropping=(ch,cw))(conv3)
    up7 = layers.concatenate([up_conv6, crop_conv3], axis=3)
    conv7 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)

    up_conv7 = layers.UpSampling2D(size=(2, 2))(conv7)
    ch, cw = self.get_crop_shape(conv2, up_conv7)
    crop_conv2 = layers.Cropping2D(cropping=(ch,cw))(conv2)
    up8 = layers.concatenate([up_conv7, crop_conv2], axis=3)
    conv8 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)

    up_conv8 = layers.UpSampling2D(size=(2, 2))(conv8)
    ch, cw = self.get_crop_shape(conv1, up_conv8)
    crop_conv1 = layers.Cropping2D(cropping=(ch,cw))(conv1)
    up9 = layers.concatenate([up_conv8, crop_conv1], axis=3)
    conv9 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)

    # Pad back to the input spatial size.
    ch, cw = self.get_crop_shape(self.model_input, conv9)
    conv9 = layers.ZeroPadding2D(padding=((ch[0], ch[1]), (cw[0], cw[1])))(conv9)
    # Head hard-codes 2 output channels with a per-pixel sigmoid.
    conv10 = layers.Conv2D(2, (3, 3),activation='sigmoid', padding='same')(conv9)
    self.img_pred=conv10
def feature_fusion_module(self, input, name):
    """Fuse a high-resolution and a low-resolution feature map (BiSeNet-style FFM).

    The small map is 2x upsampled, concatenated with the big map on the
    channel axis, and mixed by a 3x3 conv.  A squeeze-and-excitation-like
    branch (global average pool -> two 1x1 convs -> sigmoid) produces
    per-channel weights; the output is ``conv_1 + sigmoid * conv_1``.

    :param input: pair ``[big_feature_map, small_feature_map]`` (NHWC)
    :param name: prefix for all op/layer names in this module
    :return: fused feature map with 1024 channels
    """
    input_big = input[0]
    input_small = input[1]
    up_sampled_input = keras_ly.UpSampling2D(size=(2, 2), name=name + '_upsample')(input_small)
    concat_1 = tf.concat(axis=3, values=[input_big, up_sampled_input], name=name + '_concat')
    conv_1 = keras_ly.Conv2D(1024, [3, 3], padding='SAME', name=name + '_conv1')(concat_1)
    # Global average pool over H and W.  FIX: use `keepdims` — the old
    # `keep_dims` alias is deprecated and removed in TF 2.x.
    global_pool = tf.reduce_mean(conv_1, [1, 2], keepdims=True)
    conv_2 = keras_ly.Conv2D(1024, [1, 1], padding='SAME', name=name + '_conv2')(global_pool)
    conv_3 = keras_ly.Conv2D(1024, [1, 1], padding='SAME', name=name + '_conv3')(conv_2)
    sigmoid = tf.sigmoid(conv_3, name=name + '_sigmoid')
    mul = tf.multiply(sigmoid, conv_1, name=name + '_multiply')
    add_out = tf.add(conv_1, mul, name=name + '_add_out')
    return add_out
def attention_block(x, gating, inter_shape):
    """Additive (grid) attention gate on the spatial dimensions.

    Downsamples ``x``, sums it with a projection of ``gating``, squashes the
    result to a single-channel sigmoid map, upsamples that map back to the
    resolution of ``x``, and uses it to re-weight ``x`` spatially.

    :param x: skip-connection feature map to be gated
    :param gating: gating signal from the coarser decoder level
    :param inter_shape: channel count of the intermediate projections
    :return: batch-normalized, attention-weighted feature map
    """
    x_shape = K.int_shape(x)
    g_shape = K.int_shape(gating)

    # Stride-2 projection of x so it matches the gating resolution.
    x_down = layers.Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same')(x)  # 16
    x_down_shape = K.int_shape(x_down)

    # Project the gate and transpose-convolve it up to x_down's grid.
    g_proj = layers.Conv2D(inter_shape, (1, 1), padding='same')(gating)
    g_up = layers.Conv2DTranspose(
        inter_shape, (3, 3),
        strides=(x_down_shape[1] // g_shape[1], x_down_shape[2] // g_shape[2]),
        padding='same')(g_proj)  # 16

    # Additive attention -> single-channel sigmoid coefficient map.
    attn = layers.Activation('relu')(layers.add([g_up, x_down]))
    attn = layers.Activation('sigmoid')(layers.Conv2D(1, (1, 1), padding='same')(attn))
    attn_shape = K.int_shape(attn)

    # Bring the coefficients back to x's resolution and channel count.
    attn_full = layers.UpSampling2D(
        size=(x_shape[1] // attn_shape[1], x_shape[2] // attn_shape[2]))(attn)  # 32
    attn_full = expend_as(attn_full, x_shape[3])

    weighted = layers.multiply([attn_full, x])
    projected = layers.Conv2D(x_shape[3], (1, 1), padding='same')(weighted)
    return layers.BatchNormalization()(projected)
def feature_fusion_module_new(self, input, name, num_features):
    """Fuse two feature maps with a parameterized channel count.

    Like ``feature_fusion_module`` but: upsampling is applied only when the
    big map's spatial size actually exceeds the small map's, the fused conv
    is followed by fused batch-norm + ReLU, and the channel count is a
    parameter instead of the hard-coded 1024.

    :param input: pair ``[big_feature_map, small_feature_map]`` (NHWC)
    :param name: prefix for all op/layer names in this module
    :param num_features: output channel count of every conv in the module
    :return: fused feature map with ``num_features`` channels
    """
    input_big = input[0]
    input_small = input[1]
    # TF1-style static-shape comparison (`Dimension.value`).
    b_shape = input_big.get_shape()
    s_shape = input_small.get_shape()
    if(b_shape[1].value > s_shape[1].value):
        up_sampled_input = keras_ly.UpSampling2D(size=(2, 2), name=name+'_upsample')(input_small)
    else:
        up_sampled_input = input_small
    concat_1 = tf.concat(axis=3, values=[input_big, up_sampled_input], name=name+'_concat')
    conv_1 = keras_ly.Conv2D(num_features, [3, 3], padding='SAME', name=name+'_conv1')(concat_1)
    conv_1_bn_relu = tf.nn.relu(slim.batch_norm(conv_1, fused=True))
    # Global average pool over H and W.  FIX: use `keepdims` — the old
    # `keep_dims` alias is deprecated and removed in TF 2.x.
    global_pool = tf.reduce_mean(conv_1_bn_relu, [1, 2], keepdims=True)
    conv_2 = keras_ly.Conv2D(num_features, [1, 1], padding='SAME', name=name+'_conv2')(global_pool)
    conv_3 = keras_ly.Conv2D(num_features, [1, 1], padding='SAME', name=name+'_conv3')(conv_2)
    sigmoid = tf.sigmoid(conv_3, name=name+'_sigmoid')
    mul = tf.multiply(sigmoid, conv_1_bn_relu, name=name+'_multiply')  # sigmoid * conv_1
    add_out = tf.add(conv_1_bn_relu, mul, name=name+'_add_out')  # conv_1 + mul
    return add_out
def layer_upsampling_2D(x, filter_size = (2,2), name = 'upsampling_2D'):
    """Wrap Keras' ``UpSampling2D`` layer inside a TF variable scope.

    :param x: 4-D input tensor to upsample
    :param filter_size: (rows, cols) upsampling factors
    :param name: variable-scope name grouping the op in the graph
    :return: the upsampled tensor
    """
    with tf.variable_scope(name):
        # Returning inside the `with` still exits the scope cleanly.
        return layers.UpSampling2D(filter_size)(x)
def up_block(self, act, bn, f, name):
    """Decoder up-sampling block: upsample, concat skip, conv path + shortcut.

    :param act: activation tensor from the previous (coarser) decoder level
    :param bn: skip-connection tensor from the matching encoder level
    :param f: base filter count for this level
    :param name: suffix used in all layer names of this block
    :return: ReLU activation of the shortcut output
    """
    x = layers.UpSampling2D(
        size=(2,2),
        name='upsample_{}'.format(name))(act)
    # axis=1 concat — assumes channels-first data format; TODO confirm.
    temp = layers.concatenate([bn, x], axis=1)
    temp = self.conv_bn_relu(temp, (3, 3), (1, 1), 2*f, 'layer2_{}'.format(name))
    # BUG FIX: the original passed the conv output tensor as
    # BatchNormalization's first positional argument (`axis`) and never
    # applied the layer to anything.  Instantiate with keyword args and
    # call the layer on the conv output instead.
    temp = layers.BatchNormalization(
        momentum=0.99, name='layer3_bn_{}'.format(name))(
        self.conv(temp, (3, 3), (1, 1), f, 'layer3_{}'.format(name)))
    # NOTE(review): `temp` is not used below — the conv path is currently a
    # dangling branch; possibly `self.shortcut(temp, bn)` was intended.
    # TODO confirm before rewiring; behavior is kept as-is here.
    #bn = layers.add([bn,x])
    bn = self.shortcut(x, bn)
    act = layers.Activation('relu')(bn)
    return act
def attention_block(x, gating, inter_shape, name):
    """Self-gated attention: attention mechanism on the spatial dimensions.

    Named variant of the attention gate: identical topology to the unnamed
    ``attention_block`` above, but the final coefficient-upsampling layer is
    given ``name + '_weight'`` so the attention map can be retrieved later.

    :param x: input feature map (the skip connection to be gated)
    :param gating: gate signal, feature map from the lower layer
    :param inter_shape: intermediate channel number
    :param name: name prefix of the attention layer, for output retrieval
    :return: attention-weighted (spatial) feature map, batch-normalized
    """
    shape_x = K.int_shape(x)
    shape_g = K.int_shape(gating)

    # Stride-2 projection of x down to the gating resolution.
    theta_x = layers.Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same')(x)  # 16
    shape_theta_x = K.int_shape(theta_x)

    phi_g = layers.Conv2D(inter_shape, (1, 1), padding='same')(gating)
    upsample_g = layers.Conv2DTranspose(
        inter_shape, (3, 3),
        strides=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2]),
        padding='same')(phi_g)  # 16
    # upsample_g = layers.UpSampling2D(size=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2]),
    #                                  data_format="channels_last")(phi_g)

    # Additive attention -> single-channel sigmoid coefficients.
    concat_xg = layers.add([upsample_g, theta_x])
    act_xg = layers.Activation('relu')(concat_xg)
    psi = layers.Conv2D(1, (1, 1), padding='same')(act_xg)
    sigmoid_xg = layers.Activation('sigmoid')(psi)
    shape_sigmoid = K.int_shape(sigmoid_xg)

    # Upsample coefficients to x's resolution; this layer carries the
    # retrievable name for visualizing the attention weights.
    upsample_psi = layers.UpSampling2D(
        size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2]),
        name=name + '_weight')(sigmoid_xg)  # 32
    upsample_psi = expend_as(upsample_psi, shape_x[3])

    y = layers.multiply([upsample_psi, x])
    result = layers.Conv2D(shape_x[3], (1, 1), padding='same')(y)
    result_bn = layers.BatchNormalization()(result)
    return result_bn
import tensorflow as tf
from tensorflow.contrib.keras import layers
import numpy as np
# from tensorflow.contrib import layers

if __name__ == '__main__':
    # Smoke test: verify that UpSampling2D doubles each spatial dimension
    # by nearest-neighbor repetition on a random 3x3 single-channel input.
    print(tf.VERSION)
    print(tf.keras.__version__)
    model = tf.keras.Sequential()
    # Add a 2x2 UpSampling2D layer (the model's only layer).
    model.add(layers.UpSampling2D(size=(2, 2)))
    data = np.random.random((1, 3, 3, 1))
    # print(data)
    print(data.reshape((3,3)))
    # labels = np.random.random((1000, 10))
    print('--------\n')
    print('--------\n')
    print('--------\n')
    result = model.predict(data, batch_size=1)
    # Expected output shape: (1, 6, 6, 1).
    print(result.shape)
    print(result.reshape((6,6)))
def Attention_ResUNet_PA(dropout_rate=0.0, batch_norm=True):
    '''
    Residual UNet construction, with attention gate
    convolution: 3*3 SAME padding
    pooling: 2*2 VALID padding
    upsampling: 3*3 VALID padding
    final convolution: 1*1
    :param dropout_rate: FLAG & RATE of dropout.
            if < 0 dropout cancelled, if > 0 set as the rate
    :param batch_norm: flag of if batch_norm used,
            if True batch normalization
    :return: model
    '''
    # input data
    # dimension of the image depth
    inputs = layers.Input((INPUT_SIZE, INPUT_SIZE, INPUT_CHANNEL), dtype=tf.float32)
    axis = 3  # channels-last concatenation axis

    # Downsampling layers
    # DownRes 1, double residual convolution + pooling
    conv_128 = double_conv_layer(inputs, FILTER_SIZE, FILTER_NUM, dropout_rate, batch_norm)
    pool_64 = layers.MaxPooling2D(pool_size=(2, 2))(conv_128)
    # DownRes 2
    conv_64 = double_conv_layer(pool_64, FILTER_SIZE, 2 * FILTER_NUM, dropout_rate, batch_norm)
    pool_32 = layers.MaxPooling2D(pool_size=(2, 2))(conv_64)
    # DownRes 3
    conv_32 = double_conv_layer(pool_32, FILTER_SIZE, 4 * FILTER_NUM, dropout_rate, batch_norm)
    pool_16 = layers.MaxPooling2D(pool_size=(2, 2))(conv_32)
    # DownRes 4
    conv_16 = double_conv_layer(pool_16, FILTER_SIZE, 8 * FILTER_NUM, dropout_rate, batch_norm)
    pool_8 = layers.MaxPooling2D(pool_size=(2, 2))(conv_16)
    # DownRes 5, convolution only
    conv_8 = double_conv_layer(pool_8, FILTER_SIZE, 16 * FILTER_NUM, dropout_rate, batch_norm)

    # Upsampling layers
    # UpRes 6, attention gated concatenation + upsampling + double residual convolution
    # channel attention block
    # NOTE(review): the SE block and the spatial attention block at each level
    # share the same name string (e.g. 'att_16') — potential layer-name
    # collision; TODO confirm this is intentional.
    se_conv_16 = SE_block(conv_16, out_dim=8 * FILTER_NUM, ratio=SE_RATIO, name='att_16')
    # spatial attention block
    gating_16 = gating_signal(conv_8, 8 * FILTER_NUM, batch_norm)
    att_16 = attention_block(se_conv_16, gating_16, 8 * FILTER_NUM, name='att_16')
    # attention re-weight & concatenate
    up_16 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(conv_8)
    up_16 = layers.concatenate([up_16, att_16], axis=axis)
    up_conv_16 = double_conv_layer(up_16, FILTER_SIZE, 8 * FILTER_NUM, dropout_rate, batch_norm)
    # UpRes 7
    # channel attention block
    se_conv_32 = SE_block(conv_32, out_dim=4 * FILTER_NUM, ratio=SE_RATIO, name='att_32')
    # spatial attention block
    gating_32 = gating_signal(up_conv_16, 4 * FILTER_NUM, batch_norm)
    att_32 = attention_block(se_conv_32, gating_32, 4 * FILTER_NUM, name='att_32')
    # attention re-weight & concatenate
    up_32 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_16)
    up_32 = layers.concatenate([up_32, att_32], axis=axis)
    up_conv_32 = double_conv_layer(up_32, FILTER_SIZE, 4 * FILTER_NUM, dropout_rate, batch_norm)
    # UpRes 8
    # channel attention block
    se_conv_64 = SE_block(conv_64, out_dim=2 * FILTER_NUM, ratio=SE_RATIO, name='att_64')
    # spatial attention block
    gating_64 = gating_signal(up_conv_32, 2 * FILTER_NUM, batch_norm)
    att_64 = attention_block(se_conv_64, gating_64, 2 * FILTER_NUM, name='att_64')
    # attention re-weight & concatenate
    up_64 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_32)
    up_64 = layers.concatenate([up_64, att_64], axis=axis)
    up_conv_64 = double_conv_layer(up_64, FILTER_SIZE, 2 * FILTER_NUM, dropout_rate, batch_norm)
    # UpRes 9
    # channel attention block
    se_conv_128 = SE_block(conv_128, out_dim=FILTER_NUM, ratio=SE_RATIO, name='att_128')
    # spatial attention block
    gating_128 = gating_signal(up_conv_64, FILTER_NUM, batch_norm)
    # attention re-weight & concatenate
    att_128 = attention_block(se_conv_128, gating_128, FILTER_NUM, name='att_128')
    up_128 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_64)
    up_128 = layers.concatenate([up_128, att_128], axis=axis)
    up_conv_128 = double_conv_layer(up_128, FILTER_SIZE, FILTER_NUM, dropout_rate, batch_norm)

    # 1*1 convolutional layer + batch normalization + final activation.
    # NOTE(review): the activation here is 'relu', while the sibling UNet_PA
    # ends in 'sigmoid' — TODO confirm 'relu' is intended for this model.
    conv_final = layers.Conv2D(OUTPUT_MASK_CHANNEL, kernel_size=(1, 1))(up_conv_128)
    conv_final = layers.BatchNormalization(axis=axis)(conv_final)
    conv_final = layers.Activation('relu')(conv_final)

    # Model integration
    model = models.Model(inputs, conv_final, name="AttentionSEResUNet")
    return model
def VanillaUnet(num_class, img_shape):
    """Build a vanilla U-Net (module-level twin of ``create_model``).

    Five double-conv encoder stages (32 -> 512 filters) with 2x2 max pooling,
    four upsample+crop+concat decoder stages, zero-padding back to the input
    size, and a 1x1 projection to ``num_class`` channels (raw logits).

    :param num_class: number of output channels / classes
    :param img_shape: input image shape, e.g. (H, W, C) — channels last
    :return: a Keras ``Model`` mapping images to per-pixel class logits
    """
    concat_axis = 3  # channels-last: skip connections concatenate on channels
    # input
    inputs = layers.Input(shape=img_shape)
    # Unet convolution block 1
    conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same', name='conv1_1')(inputs)
    conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)
    # Unet convolution block 2
    conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)
    # Unet convolution block 3
    conv3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3)
    # Unet convolution block 4
    conv4 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = layers.MaxPooling2D(pool_size=(2, 2))(conv4)
    # Unet convolution block 5 (bottleneck)
    conv5 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)
    # Unet up-sampling block 1; Concatenation with crop_conv4
    up_conv5 = layers.UpSampling2D(size=(2, 2))(conv5)
    ch, cw = get_crop_shape(conv4, up_conv5)
    crop_conv4 = layers.Cropping2D(cropping=(ch, cw))(conv4)
    up6 = layers.concatenate([up_conv5, crop_conv4], axis=concat_axis)
    conv6 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)
    # Unet up-sampling block 2; Concatenation with crop_conv3
    up_conv6 = layers.UpSampling2D(size=(2, 2))(conv6)
    ch, cw = get_crop_shape(conv3, up_conv6)
    crop_conv3 = layers.Cropping2D(cropping=(ch, cw))(conv3)
    up7 = layers.concatenate([up_conv6, crop_conv3], axis=concat_axis)
    conv7 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)
    # Unet up-sampling block 3; Concatenation with crop_conv2
    up_conv7 = layers.UpSampling2D(size=(2, 2))(conv7)
    ch, cw = get_crop_shape(conv2, up_conv7)
    crop_conv2 = layers.Cropping2D(cropping=(ch, cw))(conv2)
    up8 = layers.concatenate([up_conv7, crop_conv2], axis=concat_axis)
    conv8 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
    # Unet up-sampling block 4; Concatenation with crop_conv1
    up_conv8 = layers.UpSampling2D(size=(2, 2))(conv8)
    ch, cw = get_crop_shape(conv1, up_conv8)
    crop_conv1 = layers.Cropping2D(cropping=(ch, cw))(conv1)
    up9 = layers.concatenate([up_conv8, crop_conv1], axis=concat_axis)
    conv9 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)
    # Pad back to the exact input size, then 1x1 projection to class logits.
    ch, cw = get_crop_shape(inputs, conv9)
    conv9 = layers.ZeroPadding2D(padding=((ch[0], ch[1]), (cw[0], cw[1])))(conv9)
    conv10 = layers.Conv2D(num_class, (1, 1))(conv9)
    model = models.Model(inputs=inputs, outputs=conv10)
    return model
def UNet_PA(dropout_rate=0.0, batch_norm=True):
    '''
    UNet construction
    convolution: 3*3 SAME padding
    pooling: 2*2 VALID padding
    upsampling: 3*3 VALID padding
    final convolution: 1*1
    :param dropout_rate: FLAG & RATE of dropout.
            if < 0 dropout cancelled, if > 0 set as the rate
    :param batch_norm: flag of if batch_norm used,
            if True batch normalization
    :return: UNet model for PACT recons
    '''
    # input data
    # dimension of the image depth
    inputs = layers.Input((INPUT_SIZE, INPUT_SIZE, INPUT_CHANNEL))
    axis = 3  # channels-last concatenation axis

    # NOTE(review): throughout this function `double_conv_layer` is called as
    # (x, k * FILTER_SIZE, INPUT_SIZE, ...), whereas Attention_ResUNet_PA
    # calls it as (x, FILTER_SIZE, k * FILTER_NUM, ...).  The multiplier here
    # scales FILTER_SIZE and INPUT_SIZE sits in the filter-count slot — this
    # looks like swapped arguments; TODO confirm against double_conv_layer.

    # Subsampling layers
    # double layer 1, convolution + pooling
    conv_128 = double_conv_layer(inputs, FILTER_SIZE, INPUT_SIZE, dropout_rate, batch_norm)
    pool_64 = layers.MaxPooling2D(pool_size=(2, 2))(conv_128)
    # double layer 2
    conv_64 = double_conv_layer(pool_64, 2 * FILTER_SIZE, INPUT_SIZE, dropout_rate, batch_norm)
    pool_32 = layers.MaxPooling2D(pool_size=(2, 2))(conv_64)
    # double layer 3
    conv_32 = double_conv_layer(pool_32, 4 * FILTER_SIZE, INPUT_SIZE, dropout_rate, batch_norm)
    pool_16 = layers.MaxPooling2D(pool_size=(2, 2))(conv_32)
    # double layer 4
    conv_16 = double_conv_layer(pool_16, 8 * FILTER_SIZE, INPUT_SIZE, dropout_rate, batch_norm)
    pool_8 = layers.MaxPooling2D(pool_size=(2, 2))(conv_16)
    # double layer 5, convolution only
    conv_8 = double_conv_layer(pool_8, 16 * FILTER_SIZE, INPUT_SIZE, dropout_rate, batch_norm)

    # Upsampling layers
    # double layer 6, upsampling + concatenation + convolution
    up_16 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(conv_8)
    up_16 = layers.concatenate([up_16, conv_16], axis=axis)
    up_conv_16 = double_conv_layer(up_16, 8 * FILTER_SIZE, INPUT_SIZE, dropout_rate, batch_norm)
    # double layer 7
    up_32 = layers.concatenate([
        layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_16),
        conv_32
    ], axis=axis)
    up_conv_32 = double_conv_layer(up_32, 4 * FILTER_SIZE, INPUT_SIZE, dropout_rate, batch_norm)
    # double layer 8
    up_64 = layers.concatenate([
        layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_32),
        conv_64
    ], axis=axis)
    up_conv_64 = double_conv_layer(up_64, 2 * FILTER_SIZE, INPUT_SIZE, dropout_rate, batch_norm)
    # double layer 9
    up_128 = layers.concatenate([
        layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE), data_format="channels_last")(up_conv_64),
        conv_128
    ], axis=axis)
    up_conv_128 = double_conv_layer(up_128, FILTER_SIZE, INPUT_SIZE, dropout_rate, batch_norm)

    # 1*1 convolutional layer
    # valid padding
    # batch normalization
    # sigmoid nonlinear activation
    conv_final = layers.Conv2D(OUTPUT_MASK_CHANNEL, kernel_size=(1, 1))(up_conv_128)
    conv_final = layers.BatchNormalization(axis=axis)(conv_final)
    conv_final = layers.Activation('sigmoid')(conv_final)

    # Model integration
    model = models.Model(inputs, conv_final, name="UNet")
    return model
def create_model(self, x_shape, y_shape):
    """Build a two-encoder "paired cell inpainting" model.

    A source-cell encoder (conv1..conv5) and a smaller target-marker encoder
    (rfpconv1..rfpconv3) process the two inputs in parallel (each with two
    2x2 poolings, so their feature maps align spatially); their features are
    concatenated on the channel axis and decoded with convs + two 2x2
    upsamplings into a single-channel linear output named 'y_gfp'.

    :param x_shape: shape of the source cell image input `x_in`
    :param y_shape: shape of the target marker image input `y_rfp`
    :return: Keras Model with inputs [x_in, y_rfp] and one-channel output
    """
    # Specify inputs (size is given in opts.py file)
    x_in = layers.Input(shape=x_shape, name='x_in')
    y_rfp = layers.Input(shape=y_shape, name='y_rfp')

    # First two conv layers of source cell encoder
    conv1 = layers.Conv2D(96, (3, 3), activation='relu', padding='same', name='conv1_1')(x_in)
    conv1 = layers.BatchNormalization()(conv1)
    pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = layers.Conv2D(256, (3, 3), activation='relu', padding='same', name='conv2_1')(pool1)
    conv2 = layers.BatchNormalization()(conv2)
    pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)

    # First two conv layers of target marker encoder
    rfpconv1 = layers.Conv2D(16, (3, 3), activation='relu', padding='same', name='rfpconv1_1')(y_rfp)
    rfpconv1 = layers.BatchNormalization()(rfpconv1)
    rfppool1 = layers.MaxPooling2D(pool_size=(2, 2))(rfpconv1)
    rfpconv2 = layers.Conv2D(32, (3, 3), activation='relu', padding='same', name='rfpconv2_1')(rfppool1)
    rfpconv2 = layers.BatchNormalization()(rfpconv2)
    rfppool2 = layers.MaxPooling2D(pool_size=(2, 2))(rfpconv2)

    # Last three conv layers of source cell encoder
    conv3 = layers.Conv2D(384, (3, 3), activation='relu', padding='same', name='conv3_1')(pool2)
    conv3 = layers.BatchNormalization()(conv3)
    conv4 = layers.Conv2D(384, (3, 3), activation='relu', padding='same', name='conv4_1')(conv3)
    conv4 = layers.BatchNormalization()(conv4)
    conv5 = layers.Conv2D(256, (3, 3), activation='relu', padding='same', name='conv5_1')(conv4)
    conv5 = layers.BatchNormalization()(conv5)

    # Last conv layer of target marker encoder
    rfpconv3 = layers.Conv2D(32, (3, 3), activation='relu', padding='same', name='rfpconv3_1')(rfppool2)
    rfpconv3 = layers.BatchNormalization()(rfpconv3)

    # Concatenation layer: merge the two encoders on the channel axis
    conv5 = layers.Concatenate(axis=-1)([conv5, rfpconv3])

    # Decoder layers
    conv6 = layers.Conv2D(256, (3, 3), activation='relu', padding='same', name='conv6_1')(conv5)
    conv7 = layers.Conv2D(384, (3, 3), activation='relu', padding='same', name='conv7_1')(conv6)
    conv8 = layers.Conv2D(384, (3, 3), activation='relu', padding='same', name='conv8_1')(conv7)
    up_conv9 = layers.UpSampling2D(size=(2, 2))(conv8)
    conv9 = layers.Conv2D(256, (3, 3), activation='relu', padding='same', name='conv9_1')(up_conv9)
    up_conv10 = layers.UpSampling2D(size=(2, 2))(conv9)
    conv10 = layers.Conv2D(96, (3, 3), activation='relu', padding='same', name='conv10_1')(up_conv10)
    # Linear (no-activation) single-channel prediction head.
    conv10 = layers.Conv2D(1, (1, 1), activation=None, name='y_gfp')(conv10)

    # Paired cell inpainting output
    model = models.Model(inputs=[x_in, y_rfp], outputs=conv10)
    return model
def create_model(self, img_shape, num_class):
    """Build a U-Net whose bottleneck uses stacked dilated convolutions.

    Identical to the vanilla U-Net except the 512-filter bottleneck is a
    stack of ``depth`` dilated 3x3 convs (dilation 1, 2, 4, ...) whose
    outputs are summed.  `mode` is hard-coded to 'cascade' below, so the
    'parallel' (ASPP-style) branch is currently unreachable — it is kept as
    a configuration toggle.

    :param img_shape: input image shape, e.g. (H, W, C) — channels last
    :param num_class: number of output channels / classes
    :return: a Keras ``Model`` mapping images to per-pixel class logits
    """
    concat_axis = 3  # channels-last: skip connections concatenate on channels
    inputs = layers.Input(shape=img_shape)

    # --- encoder ---
    conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same', name='conv1_1')(inputs)
    conv1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = layers.MaxPooling2D(pool_size=(2, 2))(conv4)

    ## Use dilated convolution
    x = pool4
    depth = 3 #3 #6
    dilated_layers = []
    mode = 'cascade'  # hard-coded; 'parallel' branch below is unreachable
    if mode == 'cascade':
        # Sequential dilated convs (dilation 1, 2, 4, ...); sum all stages.
        for i in range(depth):
            x = layers.Conv2D(512, (3, 3),
                              activation='relu',
                              padding='same',
                              dilation_rate=2**i)(x)
            dilated_layers.append(x)
        conv5 = layers.add(dilated_layers)
    elif mode == 'parallel':  # "Atrous Spatial Pyramid Pooling"
        # Independent dilated convs applied to the same input; sum them.
        for i in range(depth):
            dilated_layers.append(
                layers.Conv2D(512, (3, 3),
                              activation='relu',
                              padding='same',
                              dilation_rate=2**i)(x))
        conv5 = layers.add(dilated_layers)
    #conv5 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
    #conv5 = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)

    # --- decoder: upsample, crop the skip to match, concatenate, double conv ---
    up_conv5 = layers.UpSampling2D(size=(2, 2))(conv5)
    ch, cw = self.get_crop_shape(conv4, up_conv5)
    crop_conv4 = layers.Cropping2D(cropping=(ch, cw))(conv4)
    up6 = layers.concatenate([up_conv5, crop_conv4], axis=concat_axis)
    conv6 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
    conv6 = layers.Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)

    up_conv6 = layers.UpSampling2D(size=(2, 2))(conv6)
    ch, cw = self.get_crop_shape(conv3, up_conv6)
    crop_conv3 = layers.Cropping2D(cropping=(ch, cw))(conv3)
    up7 = layers.concatenate([up_conv6, crop_conv3], axis=concat_axis)
    conv7 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
    conv7 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)

    up_conv7 = layers.UpSampling2D(size=(2, 2))(conv7)
    ch, cw = self.get_crop_shape(conv2, up_conv7)
    crop_conv2 = layers.Cropping2D(cropping=(ch, cw))(conv2)
    up8 = layers.concatenate([up_conv7, crop_conv2], axis=concat_axis)
    conv8 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
    conv8 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)

    up_conv8 = layers.UpSampling2D(size=(2, 2))(conv8)
    ch, cw = self.get_crop_shape(conv1, up_conv8)
    crop_conv1 = layers.Cropping2D(cropping=(ch, cw))(conv1)
    up9 = layers.concatenate([up_conv8, crop_conv1], axis=concat_axis)
    conv9 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
    conv9 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)

    # Pad back to the exact input size, then 1x1 projection to class logits.
    ch, cw = self.get_crop_shape(inputs, conv9)
    conv9 = layers.ZeroPadding2D(padding=((ch[0], ch[1]), (cw[0], cw[1])))(conv9)
    conv10 = layers.Conv2D(num_class, (1, 1))(conv9)

    model = models.Model(inputs=inputs, outputs=conv10)
    return model