def create_model(self, img_shape, num_class):
    """Assemble a U-Net (four down/up levels plus a bottleneck) as a Keras model.

    Args:
        img_shape: Shape handed to ``layers.Input`` (excluding the batch axis).
        num_class: Channel count of the final 1x1 convolution.

    Returns:
        A ``models.Model`` mapping the input tensor to the raw (no activation)
        ``num_class``-channel output of the last convolution.
    """
    channel_axis = 3  # skip connections are concatenated along axis 3

    def double_conv(tensor, width, **first_kwargs):
        # Two 3x3 ReLU convolutions with 'same' padding; extra kwargs go to the
        # first one only (used to keep the explicit 'conv1_1' layer name).
        tensor = layers.Conv2D(width, (3, 3), activation='relu',
                               padding='same', **first_kwargs)(tensor)
        return layers.Conv2D(width, (3, 3), activation='relu',
                             padding='same')(tensor)

    inputs = layers.Input(shape=img_shape)

    # Contracting path: keep each level's pre-pool features for the skips.
    skips = []
    x = inputs
    for level, width in enumerate((32, 64, 128, 256)):
        extra = {'name': 'conv1_1'} if level == 0 else {}
        x = double_conv(x, width, **extra)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck.
    x = double_conv(x, 512)

    # Expanding path: upsample, crop the matching skip by the offsets that
    # get_crop_shape reports, concatenate, then convolve.
    for skip, width in zip(reversed(skips), (256, 128, 64, 32)):
        upsampled = layers.UpSampling2D(size=(2, 2))(x)
        ch, cw = self.get_crop_shape(skip, upsampled)
        trimmed = layers.Cropping2D(cropping=(ch, cw))(skip)
        x = layers.concatenate([upsampled, trimmed], axis=channel_axis)
        x = double_conv(x, width)

    # Zero-pad the decoder output by the offsets get_crop_shape reports
    # between it and the input (presumably to restore the input size).
    ch, cw = self.get_crop_shape(inputs, x)
    x = layers.ZeroPadding2D(padding=((ch[0], ch[1]), (cw[0], cw[1])))(x)
    outputs = layers.Conv2D(num_class, (1, 1))(x)
    return models.Model(inputs=inputs, outputs=outputs)
def build_unet(self, num_class=2):
    """Build the U-Net graph on ``self.model_input`` and store its prediction.

    The output tensor is assigned to ``self.img_pred``; nothing is returned.

    Args:
        num_class: Number of output channels of the final 3x3 sigmoid
            convolution. Defaults to 2, the value that used to be hard-coded
            here, so existing ``build_unet()`` calls behave exactly as before.
    """
    concat_axis = 3  # skip connections are concatenated along axis 3
    inputs = self.model_input

    def double_conv(tensor, width, **first_kwargs):
        # Two 3x3 ReLU convolutions with 'same' padding; extra kwargs apply to
        # the first one only (keeps the explicit 'conv1_1' layer name).
        tensor = layers.Conv2D(width, (3, 3), activation='relu',
                               padding='same', **first_kwargs)(tensor)
        return layers.Conv2D(width, (3, 3), activation='relu',
                             padding='same')(tensor)

    # Contracting path; layer creation order matches the original, so Keras
    # auto-generated layer names (and name-based weight loading) are unchanged.
    skips = []
    x = inputs
    for level, width in enumerate((32, 64, 128, 256)):
        extra = {'name': 'conv1_1'} if level == 0 else {}
        x = double_conv(x, width, **extra)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck.
    x = double_conv(x, 512)

    # Expanding path: upsample, crop the matching skip, concatenate, convolve.
    for skip, width in zip(reversed(skips), (256, 128, 64, 32)):
        upsampled = layers.UpSampling2D(size=(2, 2))(x)
        ch, cw = self.get_crop_shape(skip, upsampled)
        cropped = layers.Cropping2D(cropping=(ch, cw))(skip)
        x = layers.concatenate([upsampled, cropped], axis=concat_axis)
        x = double_conv(x, width)

    # Zero-pad the decoder output by the offsets get_crop_shape reports
    # against the model input (presumably to restore the input size).
    ch, cw = self.get_crop_shape(inputs, x)
    x = layers.ZeroPadding2D(padding=((ch[0], ch[1]), (cw[0], cw[1])))(x)
    self.img_pred = layers.Conv2D(num_class, (3, 3), activation='sigmoid',
                                  padding='same')(x)
def VanillaUnet(num_class, img_shape):
    """Return a plain U-Net as a Keras ``Model``.

    Args:
        num_class: Channel count of the final 1x1 convolution.
        img_shape: Shape handed to ``layers.Input`` (excluding the batch axis).

    Returns:
        ``models.Model`` from the input tensor to the raw (no activation)
        ``num_class``-channel output.
    """
    merge_axis = 3
    conv_opts = dict(activation='relu', padding='same')
    encoder_widths = (32, 64, 128, 256)

    net_input = layers.Input(shape=img_shape)

    # Unet convolution blocks 1-4: conv, conv, 2x2 max-pool.
    # Block 1's first convolution carries the explicit name 'conv1_1'.
    skip_tensors = []
    tensor = net_input
    for block_no, filters in enumerate(encoder_widths, start=1):
        if block_no == 1:
            tensor = layers.Conv2D(filters, (3, 3), name='conv1_1',
                                   **conv_opts)(tensor)
        else:
            tensor = layers.Conv2D(filters, (3, 3), **conv_opts)(tensor)
        tensor = layers.Conv2D(filters, (3, 3), **conv_opts)(tensor)
        skip_tensors.append(tensor)
        tensor = layers.MaxPooling2D(pool_size=(2, 2))(tensor)

    # Unet convolution block 5 (bottleneck): two convs, no pooling.
    for _ in range(2):
        tensor = layers.Conv2D(512, (3, 3), **conv_opts)(tensor)

    # Unet up-sampling blocks 1-4: each upsamples, concatenates with the
    # correspondingly cropped encoder output, then applies two convs.
    for filters, skip in zip(reversed(encoder_widths), reversed(skip_tensors)):
        raised = layers.UpSampling2D(size=(2, 2))(tensor)
        ch, cw = get_crop_shape(skip, raised)
        skip_cropped = layers.Cropping2D(cropping=(ch, cw))(skip)
        tensor = layers.concatenate([raised, skip_cropped], axis=merge_axis)
        tensor = layers.Conv2D(filters, (3, 3), **conv_opts)(tensor)
        tensor = layers.Conv2D(filters, (3, 3), **conv_opts)(tensor)

    # Zero-pad by the offsets get_crop_shape reports against the input,
    # then project to num_class channels.
    ch, cw = get_crop_shape(net_input, tensor)
    padding = ((ch[0], ch[1]), (cw[0], cw[1]))
    tensor = layers.ZeroPadding2D(padding=padding)(tensor)
    logits = layers.Conv2D(num_class, (1, 1))(tensor)
    return models.Model(inputs=net_input, outputs=logits)
def create_model(self, img_shape, num_class, *, dilation_mode='cascade',
                 dilation_depth=3):
    """Build a U-Net whose bottleneck is a stack of dilated convolutions.

    The four-level encoder/decoder is that of a plain U-Net. The bottleneck
    replaces the usual two 512-filter convolutions with ``dilation_depth``
    512-filter 3x3 ReLU convolutions using dilation rates 1, 2, 4, ...; their
    outputs are summed.

    Args:
        img_shape: Shape handed to ``layers.Input`` (excluding the batch axis).
        num_class: Channel count of the final 1x1 convolution.
        dilation_mode: ``'cascade'`` (default) chains the dilated convs so each
            consumes the previous one's output; ``'parallel'`` applies every
            dilated conv to the pooled encoder output ("Atrous Spatial
            Pyramid Pooling" layout). Previously this was a hard-coded local,
            which made the parallel branch unreachable.
        dilation_depth: Number of dilated convolutions (default 3, the value
            previously hard-coded); must be at least 1.

    Returns:
        ``models.Model`` from the input tensor to the raw (no activation)
        ``num_class``-channel output.

    Raises:
        ValueError: If ``dilation_mode`` is not ``'cascade'``/``'parallel'``
            or ``dilation_depth`` is less than 1.
    """
    if dilation_mode not in ('cascade', 'parallel'):
        raise ValueError(
            "dilation_mode must be 'cascade' or 'parallel', got {!r}".format(
                dilation_mode))
    if dilation_depth < 1:
        raise ValueError(
            "dilation_depth must be >= 1, got {!r}".format(dilation_depth))

    concat_axis = 3  # skip connections are concatenated along axis 3

    def double_conv(tensor, width, **first_kwargs):
        # Two 3x3 ReLU convolutions with 'same' padding; extra kwargs apply to
        # the first one only (keeps the explicit 'conv1_1' layer name).
        tensor = layers.Conv2D(width, (3, 3), activation='relu',
                               padding='same', **first_kwargs)(tensor)
        return layers.Conv2D(width, (3, 3), activation='relu',
                             padding='same')(tensor)

    inputs = layers.Input(shape=img_shape)

    # Contracting path; layer creation order matches the original, so Keras
    # auto-generated layer names are unchanged with the default arguments.
    skips = []
    x = inputs
    for level, width in enumerate((32, 64, 128, 256)):
        extra = {'name': 'conv1_1'} if level == 0 else {}
        x = double_conv(x, width, **extra)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Dilated bottleneck.
    dilated_layers = []
    feed = x
    for i in range(dilation_depth):
        out = layers.Conv2D(512, (3, 3), activation='relu', padding='same',
                            dilation_rate=2 ** i)(feed)
        dilated_layers.append(out)
        if dilation_mode == 'cascade':
            feed = out  # chain: the next conv consumes this one's output
    # Some Keras versions reject a merge layer over fewer than two inputs, so
    # a single branch is used directly instead of being passed to add().
    if len(dilated_layers) == 1:
        x = dilated_layers[0]
    else:
        x = layers.add(dilated_layers)

    # Expanding path: upsample, crop the matching skip, concatenate, convolve.
    for skip, width in zip(reversed(skips), (256, 128, 64, 32)):
        upsampled = layers.UpSampling2D(size=(2, 2))(x)
        ch, cw = self.get_crop_shape(skip, upsampled)
        cropped = layers.Cropping2D(cropping=(ch, cw))(skip)
        x = layers.concatenate([upsampled, cropped], axis=concat_axis)
        x = double_conv(x, width)

    # Zero-pad by the offsets get_crop_shape reports against the input
    # (presumably to restore the input size), then project to num_class.
    ch, cw = self.get_crop_shape(inputs, x)
    x = layers.ZeroPadding2D(padding=((ch[0], ch[1]), (cw[0], cw[1])))(x)
    outputs = layers.Conv2D(num_class, (1, 1))(x)
    return models.Model(inputs=inputs, outputs=outputs)