示例#1
0
    def __get_unet(self):
        """Build and compile the 2D U-Net segmentation model.

        Inputs have shape ``(img_rows, img_cols, 4)`` (batch dimension
        implicit). Five ``__pool_layer`` stages form the encoder (the pooled
        output of the fifth stage is discarded). Four ``__unpool_block`` stages
        form the decoder, each also receiving the corresponding encoder
        feature map. A 1x1 ``Conv2D`` with a 4-way softmax produces the model
        output.

        The model is compiled with Adam (lr=1e-3), the loss returned by
        ``metrics.keras_dice_coef_loss()`` and the wt/tc/et dice metrics.
        Side effects: prints ``model.summary()`` and writes a diagram to
        ``<config.results_path>/model.png`` via ``plot_model``.

        Returns:
            The compiled Keras ``Model``.
        """
        inputs = Input((self.img_rows, self.img_cols, 4))  # (b, rows, cols, 4)
        filters = 16  # base width; doubled at every encoder stage

        # Encoder. Per-stage filter counts are 16, 32, 64, 128, 256.
        # Variable names presumably assume a 224x224 input with the spatial
        # size halving per stage -- TODO confirm against the configured size.
        conv224, p112 = self.__pool_layer(inputs, filters=filters, block_num=1)
        conv112, p56 = self.__pool_layer(p112, filters=filters * 2, block_num=2)
        conv56, p28 = self.__pool_layer(p56, filters=filters * 4, block_num=3)
        conv28, p14 = self.__pool_layer(p28, filters=filters * 8, block_num=4)
        # Bottleneck: only the pre-pooling features are used.
        conv14, _ = self.__pool_layer(p14, filters=filters * 16, block_num=5)

        # Decoder: each block takes the previous output plus the encoder
        # feature map from the mirrored stage (deepest first).
        u28 = self.__unpool_block(pooled=conv14, prepooled=conv28, block_num=1)
        u56 = self.__unpool_block(pooled=u28, prepooled=conv56, block_num=2)
        u112 = self.__unpool_block(pooled=u56, prepooled=conv112, block_num=3)
        u224 = self.__unpool_block(pooled=u112, prepooled=conv224, block_num=4)

        # The raw softmax is the model output. Masking with
        # metrics.compute_mask is disabled for this 2D model, so no mask layer
        # is built (an unconnected layer would only add dead graph ops).
        predictions = Conv2D(filters=4, kernel_size=1, activation='softmax', name='predictions')(u224)

        model = Model(inputs=inputs, outputs=predictions)
        model.compile(optimizer=Adam(lr=1e-3), loss=metrics.keras_dice_coef_loss(),
                      metrics=[metrics.wt_dice,
                               metrics.tc_dice,
                               metrics.et_dice])

        model.summary()
        plot_model(model, self.config.results_path + '/model.png', show_shapes=True, show_layer_names=True)
        return model
示例#2
0
    def __get_unet(self):
        """Assemble and compile the volumetric (3D) U-Net.

        Input tensors have shape ``(img_rows, img_cols, img_depth, 4)`` with the
        batch dimension implicit. Five ``__pool_layer`` stages form the encoder
        (the fifth stage's pooled output is unused). Four ``__unpool_block``
        stages form the decoder, consuming the first four encoder feature maps
        deepest-first. A ``Conv3D`` with kernel size 1 and 4-way softmax
        produces the predictions. Those predictions are multiplied by
        ``metrics.compute_mask(inputs)`` to form the model output.

        Compiled with Adam (lr=1e-3), ``metrics.keras_dice_coef_loss()`` as the
        loss, and ``metrics.category_dice_score(k)`` for k = 1, 2, 3 as metrics.

        Returns:
            The compiled Keras ``Model``.
        """
        inputs = Input((self.img_rows, self.img_cols, self.img_depth, 4))
        base_filters = 8

        # Encoder stages 1-4: filter count doubles each stage (8, 16, 32, 64);
        # the pre-pooling features are kept as skip connections.
        skips = []
        x = inputs
        for stage in range(4):
            features, x = self.__pool_layer(x, filters=base_filters * 2 ** stage, block_num=stage + 1)
            skips.append(features)

        # Stage 5 (bottleneck, 128 filters): its pooled output is not needed.
        x, _ = self.__pool_layer(x, filters=base_filters * 16, block_num=5)

        # Decoder: pair each unpool block with the skips, deepest first.
        for block_num, skip in enumerate(reversed(skips), start=1):
            x = self.__unpool_block(pooled=x, prepooled=skip, block_num=block_num)

        predictions = Conv3D(filters=4, kernel_size=1, activation='softmax', name='predictions')(x)
        mask = Lambda(metrics.compute_mask)(inputs)
        masked_predictions = Lambda(lambda pair: pair[0] * pair[1])([mask, predictions])

        model = Model(inputs=inputs, outputs=masked_predictions)
        optimizer = Adam(lr=1e-3)
        loss = metrics.keras_dice_coef_loss()
        dice_metrics = [metrics.category_dice_score(label) for label in (1, 2, 3)]
        model.compile(optimizer=optimizer, loss=loss, metrics=dice_metrics)

        return model
示例#3
0
              (scores_crf[0], scores_crf[1], scores_crf[2]))


# For testing: runs only when this file is executed directly as a script.
# Loads a saved model and builds evaluation data generators.
if __name__ == '__main__':
    import config as configuration
    from read_data import BRATSReader
    from evaluation_generator import EvalGenerator
    from keras.models import load_model
    import metrics
    import keras

    config = configuration.Config()

    # Super hacky way to load weights and architecture. Absolutely not ok to run training or keras metrics on this.
    # The assignments below attach the project's loss/metric callables to the
    # keras.losses / keras.metrics modules under the names presumably recorded
    # in the saved model, so load_model can resolve them by name.
    # NOTE(review): load_model(path, custom_objects={...}) is the supported
    # way to do this without mutating the keras modules.
    keras.losses.keras_dice_coef_loss_fn = metrics.keras_dice_coef_loss()
    # `hard_dice` is mapped to wt_dice -- presumably a legacy metric name
    # stored in older checkpoints; confirm.
    keras.metrics.hard_dice = metrics.wt_dice
    keras.metrics.wt_dice = metrics.wt_dice
    keras.metrics.et_dice = metrics.et_dice
    keras.metrics.tc_dice = metrics.tc_dice
    # One model loaded from "unet.hdf5"; multi_modals is left empty here.
    # NOTE(review): how these two lists are consumed is not visible in this
    # chunk -- confirm.
    axial_modals = [load_model("unet.hdf5")]
    multi_modals = []

    # Reader over both HGG and LGG cases.
    brats = BRATSReader(use_hgg=True, use_lgg=True)
    # print(brats.get_mean_dev(.15, 't1ce'))
    # Case-id split controlled by config.brats_val_split.
    train_ids, val_ids, test_ids = brats.get_case_ids(config.brats_val_split)

    height, width, slices = brats.get_dims()
    #train_datagen = EvalGenerator(brats, train_ids, dim=(height, width, 4))
    # Only the validation generator is built. The trailing 4 presumably matches
    # the model's 4 input channels -- confirm.
    val_datagen = EvalGenerator(brats, val_ids, dim=(height, width, 4))