def _build_attention_model(in_shape, class_nb, backbone_weights, upsample_name=None):
    """Build the attention-pooled classifier graph shared by both public builders.

    ``model_creator`` loads weight files into this graph, presumably ones
    produced by training ``get_model``'s graph. Keeping the topology in one
    place guarantees the two never drift apart.

    Args:
        in_shape: Input shape without the batch dimension.
        class_nb: Number of units in the final softmax layer.
        backbone_weights: Forwarded as ``weights=`` to ``PTModel``
            (e.g. ``'imagenet'`` or ``None``).
        upsample_name: Optional layer name for the frozen 1x1 conv that fans
            the attention map out to every feature channel; ``None`` lets
            Keras auto-generate it.

    Returns:
        An uncompiled ``Model`` mapping ``in_lay`` to softmax class scores.
    """
    in_lay = Input(in_shape)
    # NOTE(review): PTModel is defined elsewhere in the file; presumably a
    # keras.applications backbone constructor -- confirm.
    base_pretrained_model = PTModel(input_shape=in_shape, include_top=False,
                                    weights=backbone_weights)
    base_pretrained_model.trainable = False
    # Equivalent to the legacy `get_output_shape_at(0)[-1]`: the backbone has
    # a single inbound node at this point, so its output shape is unambiguous.
    pt_depth = base_pretrained_model.output_shape[-1]
    pt_features = base_pretrained_model(in_lay)
    bn_features = BatchNormalization()(pt_features)

    # Attention mechanism: a stack of 1x1 convs ending in one sigmoid channel
    # that turns spatial positions in the global average pool on or off.
    attn_layer = Conv2D(64, kernel_size=(1, 1), padding='same',
                        activation='relu')(Dropout(0.5)(bn_features))
    attn_layer = Conv2D(16, kernel_size=(1, 1), padding='same',
                        activation='relu')(attn_layer)
    attn_layer = Conv2D(8, kernel_size=(1, 1), padding='same',
                        activation='relu')(attn_layer)
    attn_layer = Conv2D(1, kernel_size=(1, 1), padding='valid',
                        activation='sigmoid')(attn_layer)

    # Fan the single-channel mask out to all pt_depth channels with a frozen,
    # bias-free, all-ones 1x1 conv. `kernel_initializer='ones'` produces the
    # same (1, 1, 1, pt_depth) kernel as the legacy
    # `weights=[np.ones((1, 1, 1, pt_depth))]` constructor argument.
    up_c2 = Conv2D(pt_depth, kernel_size=(1, 1), padding='same',
                   activation='linear', use_bias=False,
                   kernel_initializer='ones', name=upsample_name)
    up_c2.trainable = False
    attn_layer = up_c2(attn_layer)

    mask_features = multiply([attn_layer, bn_features])
    gap_features = GlobalAveragePooling2D()(mask_features)
    gap_mask = GlobalAveragePooling2D()(attn_layer)
    # Divide by the pooled mask to account for positions the attention map
    # switched off, so the result is a mask-weighted mean of the features.
    gap = Lambda(lambda x: x[0] / x[1], name='RescaleGAP')([gap_features, gap_mask])
    gap_dr = Dropout(0.25)(gap)
    dr_steps = Dropout(0.25)(Dense(128, activation='relu')(gap_dr))
    out_layer = Dense(class_nb, activation='softmax')(dr_steps)
    return Model(inputs=[in_lay], outputs=[out_layer])


def model_creator(in_shape, class_nb, weight_path):
    """Rebuild the attention classifier and load saved weights into it.

    Args:
        in_shape: Input shape without the batch dimension.
        class_nb: Number of classes in the softmax output.
        weight_path: Weight file passed to ``Model.load_weights``.

    Returns:
        The uncompiled ``Model`` with weights loaded from ``weight_path``.
    """
    # The backbone is built with weights=None, presumably because
    # load_weights below supplies the trained values.
    retina_model = _build_attention_model(in_shape, class_nb,
                                          backbone_weights=None,
                                          upsample_name='outcnn')
    retina_model.load_weights(weight_path)
    return retina_model


def get_model(X, y):
    """Build and compile the attention classifier for training.

    Args:
        X: Training inputs; only ``X.shape[1:]`` is used, as the input shape.
        y: Training targets; ``y.shape[-1]`` sets the number of output
            classes (presumably one-hot labels, matching the
            ``categorical_crossentropy`` loss -- confirm against callers).

    Returns:
        A ``Model`` with an ImageNet-initialised, frozen backbone, compiled
        with Adam, categorical cross-entropy, and categorical and top-2
        accuracy metrics.
    """
    retina_model = _build_attention_model(X.shape[1:], y.shape[-1],
                                          backbone_weights='imagenet')

    def top_2_accuracy(in_gt, in_pred):
        # Top-k categorical accuracy with k=2.
        return top_k_categorical_accuracy(in_gt, in_pred, k=2)

    retina_model.compile(optimizer='adam', loss='categorical_crossentropy',
                         metrics=['categorical_accuracy', top_2_accuracy])
    return retina_model