def model_creator(in_shape, class_nb, weight_path):
    """Build the attention-gated classifier head on top of PTModel and
    load previously trained weights.

    Args:
        in_shape: input tensor shape, e.g. ``(H, W, C)``.
        class_nb: number of output classes (softmax width).
        weight_path: path to a weights file for the assembled model.

    Returns:
        A Keras ``Model`` with weights loaded from *weight_path*
        (not compiled for training).
    """
    model_in = Input(in_shape)
    backbone = PTModel(input_shape=in_shape, include_top=False,
                       weights=None)
    backbone.trainable = False
    feat_depth = backbone.get_output_shape_at(0)[-1]
    features = BatchNormalization()(backbone(model_in))

    # Attention head: squeeze the feature channels down to a single
    # sigmoid spatial mask via a stack of 1x1 convolutions.
    attn = Dropout(0.5)(features)
    for width in (64, 16, 8):
        attn = Conv2D(width, kernel_size=(1, 1), padding='same',
                      activation='relu')(attn)
    attn = Conv2D(1, kernel_size=(1, 1), padding='valid',
                  activation='sigmoid')(attn)

    # Broadcast the 1-channel mask across all feature channels using a
    # frozen all-ones 1x1 convolution.
    broadcast = Conv2D(feat_depth, kernel_size=(1, 1), padding='same',
                       activation='linear', use_bias=False,
                       weights=[np.ones((1, 1, 1, feat_depth))],
                       name='outcnn')
    broadcast.trainable = False
    attn = broadcast(attn)

    masked = multiply([attn, features])
    gap_feat = GlobalAveragePooling2D()(masked)
    gap_attn = GlobalAveragePooling2D()(attn)
    # Divide by the mean mask activation so attention does not shrink
    # the pooled magnitudes.
    rescaled = Lambda(lambda t: t[0] / t[1],
                      name='RescaleGAP')([gap_feat, gap_attn])

    hidden = Dense(128, activation='relu')(Dropout(0.25)(rescaled))
    out = Dense(class_nb, activation='softmax')(Dropout(0.25)(hidden))

    retina_model = Model(inputs=[model_in], outputs=[out])
    retina_model.load_weights(weight_path)
    return retina_model
# --- Esempio n. 2 (Example 2) -- stray scrape separator; score: 0 ---
def get_model(X, y):
    """Build and compile the attention-gated classifier on ImageNet
    pretrained PTModel features.

    Args:
        X: training array; ``X.shape[1:]`` fixes the model input shape.
        y: one-hot label array; ``y.shape[-1]`` fixes the output width.

    Returns:
        A Keras ``Model`` compiled with Adam, categorical cross-entropy,
        and top-1 / top-2 categorical-accuracy metrics.
    """
    model_in = Input(X.shape[1:])
    backbone = PTModel(input_shape=X.shape[1:],
                       include_top=False,
                       weights='imagenet')
    backbone.trainable = False
    feat_depth = backbone.get_output_shape_at(0)[-1]
    features = BatchNormalization()(backbone(model_in))

    # Attention mechanism: learn a sigmoid mask that turns pixels in the
    # global average pool on and off.
    attn = Dropout(0.5)(features)
    for width in (64, 16, 8):
        attn = Conv2D(width,
                      kernel_size=(1, 1),
                      padding='same',
                      activation='relu')(attn)
    attn = Conv2D(1,
                  kernel_size=(1, 1),
                  padding='valid',
                  activation='sigmoid')(attn)

    # Fan the single-channel mask out to every feature channel with a
    # frozen all-ones 1x1 convolution.
    fan_out = Conv2D(feat_depth,
                     kernel_size=(1, 1),
                     padding='same',
                     activation='linear',
                     use_bias=False,
                     weights=[np.ones((1, 1, 1, feat_depth))])
    fan_out.trainable = False
    attn = fan_out(attn)

    masked = multiply([attn, features])
    gap_feat = GlobalAveragePooling2D()(masked)
    gap_attn = GlobalAveragePooling2D()(attn)
    # Rescale to account for values suppressed by the attention mask.
    rescaled = Lambda(lambda t: t[0] / t[1],
                      name='RescaleGAP')([gap_feat, gap_attn])

    hidden = Dense(128, activation='relu')(Dropout(0.25)(rescaled))
    out = Dense(y.shape[-1], activation='softmax')(Dropout(0.25)(hidden))

    retina_model = Model(inputs=[model_in], outputs=[out])

    def top_2_accuracy(in_gt, in_pred):
        # Counts a prediction as correct if the true class is in the top 2.
        return top_k_categorical_accuracy(in_gt, in_pred, k=2)

    retina_model.compile(optimizer='adam',
                         loss='categorical_crossentropy',
                         metrics=['categorical_accuracy', top_2_accuracy])

    return retina_model