# Assumed imports: the original file does not show them. The code targets
# legacy Keras 2.x (multi_gpu_model and Adam(lr=...) were removed in later
# releases), and `ops` is assumed to alias keras_contrib.layers, which
# provides InstanceNormalization.
from keras.models import Model
from keras.layers import (Input, Conv2D, Conv2DTranspose, MaxPooling2D,
                          Activation, concatenate)
from keras.optimizers import Adam
from keras.utils import multi_gpu_model
from keras_contrib import layers as ops


def fun(inputs):
    """Apply instance normalization to a tensor (convenience wrapper).

    Note: unet2d_fullIN below calls ops.InstanceNormalization() directly
    rather than going through this wrapper.
    """
    return ops.InstanceNormalization()(inputs)
def unet2d_fullIN(lossfunc, lr, input_dim, feature_num, metric='mse',
                  multigpu=False, include_top=True):
    """Initialize a 2D U-Net with instance normalization after every convolution.

    Uses the Keras functional API to construct a U-Net. The net is fully
    convolutional, so it can be trained and tested on variable-size input
    (thus the x-y input dimensions are left undefined).

    inputs--
        lossfunc: loss function passed to model.compile
        lr: float; learning rate for the Adam optimizer
        input_dim: int; number of feature channels in the input
        feature_num: int; number of output features
        metric: str or callable; metric reported during training (default 'mse')
        multigpu: bool; if True, wrap the model with multi_gpu_model on 2 GPUs
        include_top: bool; if True, output the final 1x1 sigmoid layer;
            otherwise output the last 32-channel decoder features

    outputs--
        model: compiled Keras model object
    """
    inputs = Input((None, None, input_dim))

    # Encoder: four stages of (conv -> IN -> relu) x2, each followed by
    # 2x2 max pooling; filter count doubles at every stage.
    conv1 = Conv2D(32, (3, 3), padding='same')(inputs)
    conv1 = Activation('relu')(ops.InstanceNormalization()(conv1))
    conv1 = Conv2D(32, (3, 3), padding='same')(conv1)
    conv1 = Activation('relu')(ops.InstanceNormalization()(conv1))
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), padding='same')(pool1)
    conv2 = Activation('relu')(ops.InstanceNormalization()(conv2))
    conv2 = Conv2D(64, (3, 3), padding='same')(conv2)
    conv2 = Activation('relu')(ops.InstanceNormalization()(conv2))
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), padding='same')(pool2)
    conv3 = Activation('relu')(ops.InstanceNormalization()(conv3))
    conv3 = Conv2D(128, (3, 3), padding='same')(conv3)
    conv3 = Activation('relu')(ops.InstanceNormalization()(conv3))
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), padding='same')(pool3)
    conv4 = Activation('relu')(ops.InstanceNormalization()(conv4))
    conv4 = Conv2D(256, (3, 3), padding='same')(conv4)
    conv4 = Activation('relu')(ops.InstanceNormalization()(conv4))
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # Bottleneck at 1/16 of the input resolution.
    conv5 = Conv2D(512, (3, 3), padding='same')(pool4)
    conv5 = Activation('relu')(ops.InstanceNormalization()(conv5))
    conv5 = Conv2D(512, (3, 3), padding='same')(conv5)
    conv5 = Activation('relu')(ops.InstanceNormalization()(conv5))

    # Decoder: transposed-conv upsampling, concatenation with the matching
    # encoder stage (skip connection), then (conv -> IN -> relu) x2.
    up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2),
                                       padding='same')(conv5), conv4], axis=3)
    conv6 = Conv2D(256, (3, 3), padding='same')(up6)
    conv6 = Activation('relu')(ops.InstanceNormalization()(conv6))
    conv6 = Conv2D(256, (3, 3), padding='same')(conv6)
    conv6 = Activation('relu')(ops.InstanceNormalization()(conv6))

    up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2),
                                       padding='same')(conv6), conv3], axis=3)
    conv7 = Conv2D(128, (3, 3), padding='same')(up7)
    conv7 = Activation('relu')(ops.InstanceNormalization()(conv7))
    conv7 = Conv2D(128, (3, 3), padding='same')(conv7)
    conv7 = Activation('relu')(ops.InstanceNormalization()(conv7))

    up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2),
                                       padding='same')(conv7), conv2], axis=3)
    conv8 = Conv2D(64, (3, 3), padding='same')(up8)
    conv8 = Activation('relu')(ops.InstanceNormalization()(conv8))
    conv8 = Conv2D(64, (3, 3), padding='same')(conv8)
    conv8 = Activation('relu')(ops.InstanceNormalization()(conv8))

    up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2),
                                       padding='same')(conv8), conv1], axis=3)
    conv9 = Conv2D(32, (3, 3), padding='same')(up9)
    conv9 = Activation('relu')(ops.InstanceNormalization()(conv9))
    conv9 = Conv2D(32, (3, 3), padding='same')(conv9)
    conv9 = Activation('relu')(ops.InstanceNormalization()(conv9))

    # Final 1x1 convolution maps decoder features to feature_num channels.
    conv10 = Conv2D(feature_num, (1, 1), activation='sigmoid')(conv9)

    if include_top:
        model = Model(inputs=[inputs], outputs=[conv10])
    else:
        model = Model(inputs=[inputs], outputs=[conv9])

    if multigpu:
        model = multi_gpu_model(model, gpus=2)

    model.compile(optimizer=Adam(lr=lr), loss=lossfunc, metrics=[metric])
    return model
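

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module; hyperparameters below are
# illustrative placeholders). Builds the network and runs a dummy forward
# pass. Although the spatial dimensions are declared as None, in practice
# the input height and width must each be divisible by 16: the encoder
# max-pools four times by a factor of 2, and the skip-connection
# concatenations require the upsampled and encoder feature maps to match.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    # 'mse' loss, lr=1e-4, 1 input channel, 2 output features: placeholders.
    model = unet2d_fullIN(lossfunc='mse', lr=1e-4, input_dim=1, feature_num=2)

    dummy = np.random.rand(1, 64, 64, 1).astype('float32')  # 64 = 16 * 4
    pred = model.predict(dummy)
    print(pred.shape)  # expected: (1, 64, 64, 2)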