def get_3d_rsunet():
    """Build and compile the 3D encoder/decoder network made of residual modules.

    The input has shape ``(cm.slices_3d, cm.img_rows_3d, cm.img_cols_3d, 1)``.
    Four ``residual_module_3d`` stages (called with 32, 64, 128, 256 --
    presumably the filter counts) are each followed by 2x2x2 max pooling,
    then a fifth stage (256) is applied. The decoder upsamples by 2x2x2,
    concatenates with the matching encoder stage output on the last axis and
    applies another residual module (128, 64, 32, 16). A final 1x1x1
    ``Conv3D`` with 3 filters and a sigmoid activation produces the output.

    Returns:
        The model compiled with ``Adam(lr=1.0e-6)``, the weighted categorical
        cross-entropy from ``lf`` (unit weights) and the
        ``categorical_accuracy`` metric.
    """
    input_shape = (cm.slices_3d, cm.img_rows_3d, cm.img_cols_3d, 1)
    inputs = Input(input_shape)

    # Encoder: remember every stage output so the decoder can concatenate it.
    encoder_outputs = []
    tensor = inputs
    for width in (32, 64, 128, 256):
        tensor = residual_module_3d(tensor, width, 'prev')
        encoder_outputs.append(tensor)
        tensor = MaxPooling3D(pool_size=(2, 2, 2))(tensor)

    # Stage applied after the last pooling.
    tensor = residual_module_3d(tensor, 256, 'prev')

    # Decoder: walk the stored encoder outputs from deepest to shallowest.
    decoder_plan = zip((128, 64, 32, 16),
                       ('prev', 'prev', 'prev', 'last'),
                       reversed(encoder_outputs))
    for width, module_mode, skip in decoder_plan:
        upsampled = UpSampling3D(size=(2, 2, 2))(tensor)
        merged = merge([upsampled, skip], mode='concat', concat_axis=-1)
        tensor = residual_module_3d(merged, width, module_mode)

    head = Conv3D(filters=3,
                  kernel_size=(1, 1, 1),
                  strides=(1, 1, 1),
                  activation='sigmoid')
    conv10 = head(tensor)

    model = Model(input=inputs, output=conv10)

    class_weights = np.array([1, 1, 1])
    weighted_loss = lf.weighted_categorical_crossentropy_loss(class_weights)
    optimizer = Adam(lr=1.0e-6)
    model.compile(optimizer=optimizer,
                  loss=weighted_loss,
                  metrics=["categorical_accuracy"])

    return model
# Example no. 2
# 0
def get_3d_crfrnn_model_def():
    """Build and compile the explicitly named 3D encoder/decoder network.

    Every layer gets a fixed ``layer_no_<k>...`` name. The encoder applies
    pairs of 3x3x3 ReLU convolutions (32, 64, 128, 256 filters), each pair
    followed by 2x2x2 max pooling, then a 512-filter pair. The decoder
    upsamples by 2x2x2, concatenates with the matching encoder output on the
    last axis and applies another convolution pair (256, 128, 64, 32). A
    1x1x1 sigmoid ``Conv3D`` with 3 filters is the model output.

    Returns:
        The model (named ``'crfrnn_UNet'``) compiled with ``Adam(lr=1.0e-5)``,
        the weighted categorical cross-entropy from ``lf`` with weights
        ``[1.0, 100.0, 100.0]`` and the ``categorical_accuracy`` metric.
    """
    channels = 1
    slices = cm.slices_3d
    height = cm.img_rows_3d
    weight = cm.img_cols_3d

    def conv_relu(tensor, n_filters, layer_name):
        # One 3x3x3 ReLU convolution with the file's legacy 'border_mode' kwarg.
        layer = Conv3D(filters=n_filters,
                       kernel_size=(3, 3, 3),
                       strides=(1, 1, 1),
                       activation='relu',
                       border_mode='same',
                       name=layer_name)
        return layer(tensor)

    def upsample_and_concat(tensor, skip, up_name, merge_name):
        # Upsample first, then concatenate with the stored encoder output.
        upsampled = UpSampling3D(size=(2, 2, 2), name=up_name)(tensor)
        return merge([upsampled, skip],
                     mode='concat',
                     concat_axis=-1,
                     name=merge_name)

    # Input
    inputs = Input((slices, height, weight, channels),
                   name='layer_no_0_input')

    # Encoder
    conv1 = conv_relu(inputs, 32, 'layer_no_1_conv')
    conv1 = conv_relu(conv1, 32, 'layer_no_2_conv')
    pool1 = MaxPooling3D(pool_size=(2, 2, 2), name='layer_no_3')(conv1)

    conv2 = conv_relu(pool1, 64, 'layer_no_4_conv')
    conv2 = conv_relu(conv2, 64, 'layer_no_5_conv')
    pool2 = MaxPooling3D(pool_size=(2, 2, 2), name='layer_no_6')(conv2)

    conv3 = conv_relu(pool2, 128, 'layer_no_7_conv')
    conv3 = conv_relu(conv3, 128, 'layer_no_8_conv')
    pool3 = MaxPooling3D(pool_size=(2, 2, 2), name='layer_no_9')(conv3)

    conv4 = conv_relu(pool3, 256, 'layer_no_10_conv')
    conv4 = conv_relu(conv4, 256, 'layer_no_11_conv')
    pool4 = MaxPooling3D(pool_size=(2, 2, 2), name='layer_no_12')(conv4)

    conv5 = conv_relu(pool4, 512, 'layer_no_13_conv')
    conv5 = conv_relu(conv5, 512, 'layer_no_14_conv')

    # Decoder
    up6 = upsample_and_concat(conv5, conv4, 'layer_no_15', 'layer_no_16')
    conv6 = conv_relu(up6, 256, 'layer_no_17_conv')
    conv6 = conv_relu(conv6, 256, 'layer_no_18_conv')

    up7 = upsample_and_concat(conv6, conv3, 'layer_no_19', 'layer_no_20')
    conv7 = conv_relu(up7, 128, 'layer_no_21_conv')
    conv7 = conv_relu(conv7, 128, 'layer_no_22_conv')

    up8 = upsample_and_concat(conv7, conv2, 'layer_no_23', 'layer_no_24')
    conv8 = conv_relu(up8, 64, 'layer_no_25_conv')
    conv8 = conv_relu(conv8, 64, 'layer_no_26_conv')

    up9 = upsample_and_concat(conv8, conv1, 'layer_no_27', 'layer_no_28')
    conv9 = conv_relu(up9, 32, 'layer_no_29_conv')
    conv9 = conv_relu(conv9, 32, 'layer_no_30_last')

    conv10 = Conv3D(filters=3,
                    kernel_size=(1, 1, 1),
                    strides=(1, 1, 1),
                    activation='sigmoid',
                    name='layer_no_31_output')(conv9)

    # NOTE(review): the CRF-RNN layer is instantiated and applied here, but
    # its output is not passed to Model() below, so it is not part of the
    # returned graph. Confirm whether that is intentional.
    crf_layer = CrfRnnLayer_3d(image_dims=(slices, height, weight),
                               num_classes=3,
                               theta_alpha=160.,
                               theta_beta=3.,
                               theta_gamma=3.,
                               num_iterations=1,
                               name='crfrnn')
    output = crf_layer([conv10, inputs])

    class_weights = np.array([1.0, 100.0, 100.0])
    weighted_loss = lf.weighted_categorical_crossentropy_loss(class_weights)

    # Build the model
    model = Model(inputs, conv10, name='crfrnn_UNet')

    optimizer = Adam(lr=1.0e-5)
    model.compile(optimizer=optimizer,
                  loss=weighted_loss,
                  metrics=["categorical_accuracy"])

    return model
# Example no. 3
# 0
def get_3d_rsunet_Gerda(opti):
    """Build and compile the 3D network assembled from the file's block helpers.

    The graph chains ``Conv3D`` -> ``BN_ReLU`` -> ``conv_block`` ->
    ``residual_module_3d`` stages on the way down, keeps three ``BN_ReLU``
    outputs, and feeds them back through ``concat_block`` on the way up. A
    final 1x1x1 sigmoid ``Conv3D`` with 3 filters is the model output.

    Args:
        opti: Optimizer handed unchanged to ``model.compile``.

    Returns:
        The model compiled with ``opti``, the weighted categorical
        cross-entropy from ``lf`` with weights ``[0.1, 10, 10]`` and the
        ``categorical_accuracy`` metric.
    """
    input_shape = (cm.slices_3d, cm.img_rows_3d, cm.img_cols_3d, 1)
    inputs = Input(input_shape)

    # --- Downward path -----------------------------------------------------
    stem = Conv3D(filters=16,
                  kernel_size=(3, 3, 3),
                  strides=(1, 1, 1),
                  padding='same')(inputs)
    skip_a = BN_ReLU(stem)
    tensor = conv_block(skip_a, 32, 2, 2, 2, 'same')

    # This first residual module is the only one called with (3, 3).
    tensor = residual_module_3d(tensor, 32, 3, 3, 'same')

    skip_b = BN_ReLU(tensor)
    tensor = conv_block(skip_b, 64, 2, 2, 2, 'same')
    for _ in range(2):
        tensor = residual_module_3d(tensor, 64, 3, 1, 'same')

    skip_c = BN_ReLU(tensor)
    tensor = conv_block(skip_c, 128, 2, 2, 2, 'same')
    for _ in range(2):
        tensor = residual_module_3d(tensor, 128, 3, 1, 'same')

    # --- Upward path -------------------------------------------------------
    tensor = BN_ReLU(tensor)
    tensor = concat_block(tensor, skip_c, 32, 2, 2, 'same')
    for _ in range(2):
        tensor = residual_module_3d(tensor, 64, 3, 1, 'same')

    tensor = BN_ReLU(tensor)
    tensor = concat_block(tensor, skip_b, 16, 2, 2, 'same')
    tensor = residual_module_3d(tensor, 32, 3, 1, 'same')

    tensor = BN_ReLU(tensor)
    tensor = concat_block(tensor, skip_a, 8, 2, 2, 'same')

    tensor = conv_block(tensor, 16, 3, 1, 1, 'same')
    tensor = BN_ReLU(tensor)

    head = Conv3D(filters=3,
                  kernel_size=(1, 1, 1),
                  strides=(1, 1, 1),
                  activation='sigmoid')
    conv10 = head(tensor)

    model = Model(input=inputs, output=conv10)

    class_weights = np.array([0.1, 10, 10])
    weighted_loss = lf.weighted_categorical_crossentropy_loss(class_weights)
    model.compile(optimizer=opti,
                  loss=weighted_loss,
                  metrics=["categorical_accuracy"])

    return model
def _stamp_pixel_spacing(dcm_path, image_pixel_space, pixel_space):
    """Re-open the DICOM file at ``dcm_path`` and overwrite its spacing attributes."""
    dicom_temp = dicom.read_file(dcm_path)
    dicom_temp.ImagerPixelSpacing = image_pixel_space
    dicom_temp.PixelSpacing = pixel_space
    dicom_temp.save_as(dcm_path)


def _save_flipped_nifti(volume, affine, nii_path):
    """Save ``volume`` as NIfTI after reversing axes 1 and 2 and swapping axes 0 and 2."""
    nii_temp = nib.Nifti1Image(
        np.swapaxes(np.uint16(volume[:, ::-1, ::-1]), 0, 2), affine)
    nib.save(nii_temp, nii_path)


def _save_overlay_figure(fig_num, channel, imgs_origin, imgs_true,
                         imgs_predict, imgs_predict_threshold, out_path):
    """Save a 3-row grid of overlays for every 40th slice.

    Row 1 overlays the ground truth, row 2 the raw prediction and row 3 the
    thresholded prediction of ``channel`` on top of the original image.
    """
    steps = 40
    plt_row = 3
    plt_col = int(len(imgs_origin) / steps)

    plt.figure(fig_num, figsize=(25, 12))

    for slice_idx in range(0, len(imgs_origin), steps):
        # 1-based column index. The previous formula mapped slice 0 and slice
        # `steps` to the same subplot and left the last column empty.
        plt_num = slice_idx // steps + 1
        if plt_num > plt_col:
            continue

        title = 'slice=' + str(slice_idx)
        overlays = (imgs_true, imgs_predict, imgs_predict_threshold)
        for row_idx, overlay in enumerate(overlays):
            ax = plt.subplot(plt_row, plt_col, plt_num + row_idx * plt_col)
            plt.title(title)
            ax.imshow(imgs_origin[slice_idx, :, :, 0],
                      cmap='gray',
                      alpha=1.0)
            ax.imshow(overlay[slice_idx, :, :, channel],
                      cmap='viridis',
                      alpha=0.5)

    plt.subplots_adjust(left=0.0,
                        bottom=0.05,
                        right=1.0,
                        top=0.95,
                        hspace=0.3,
                        wspace=0.3)
    plt.savefig(out_path)
    plt.close()


def _evaluate_volume(vol_idx, num_vols, current_dir):
    """Predict, score, plot and export one test volume into ``current_dir``."""
    nb_class = 3
    # Resolved up front: it used to be assigned only near the end of the
    # function, so earlier uses raised UnboundLocalError.
    modelname = cm.modellist[0]

    print('-' * 30)
    print('Start post-evaluating test data %04d/%04d...' %
          (vol_idx + 1, num_vols))

    originVol, _, _, _ = dp.loadFile(originFile_list[vol_idx])
    maskAortaVol, _, _, _ = dp.loadFile(maskAortaFile_list[vol_idx])
    maskPulVol, _, _, _ = dp.loadFile(maskPulFile_list[vol_idx])

    ds = dicom.read_file(originFile_list[vol_idx])
    image_pixel_space = ds.ImagerPixelSpacing
    pixel_space = ds.PixelSpacing
    ds = None

    # Label map: 1 where the aorta mask is set, 2 where the pulmonary mask is
    # set; voxels claimed by both (sum > 2) are reset to 0.
    for j in range(len(maskAortaVol)):
        maskAortaVol[j] = np.where(maskAortaVol[j] != 0, 1, 0)
    for j in range(len(maskPulVol)):
        maskPulVol[j] = np.where(maskPulVol[j] != 0, 2, 0)
    maskVol = maskAortaVol + maskPulVol
    for j in range(len(maskVol)):
        maskVol[j] = np.where(maskVol[j] > 2, 0, maskVol[j])

    # Per-volume slice lists. They used to be read before being bound in this
    # scope, which raised UnboundLocalError.
    out_test_images = [originVol[k, :, :] for k in range(originVol.shape[0])]
    out_test_masks = [maskVol[k, :, :] for k in range(maskVol.shape[0])]
    total_slices = originVol.shape[0]

    del originVol, maskAortaVol, maskPulVol, maskVol

    outmasks_onehot = to_categorical(out_test_masks, num_classes=nb_class)
    final_test_images = np.ndarray([total_slices, 512, 512, 1],
                                   dtype=np.int16)
    final_test_masks = np.ndarray([total_slices, 512, 512, nb_class],
                                  dtype=np.int8)

    for k, img in enumerate(out_test_images):
        final_test_images[k, :, :, 0] = img
        final_test_masks[k, :, :, :] = outmasks_onehot[k]

    del outmasks_onehot, out_test_masks, out_test_images

    # Centered crop window inside the hard-coded 512x512 frame.
    row = cm.img_rows_3d
    col = cm.img_cols_3d
    row_1 = int((512 - row) / 2)
    row_2 = int(512 - (512 - row) / 2)
    col_1 = int((512 - col) / 2)
    col_2 = int(512 - (512 - col) / 2)
    slices = cm.slices_3d
    gaps = cm.gaps_3d

    final_images_crop = final_test_images[:, row_1:row_2, col_1:col_2, :]
    final_masks_crop = final_test_masks[:, row_1:row_2, col_1:col_2, :]

    # ---- Ground-truth exports (DICOM, NIfTI, mhd) --------------------------
    aorta_gt = final_test_masks[:, :, :, 1]
    pul_gt = final_test_masks[:, :, :, 2]

    sitk.WriteImage(sitk.GetImageFromArray(np.uint16(aorta_gt)),
                    current_dir + '/DICOM/masksAortaGroundTruth.dcm')
    sitk.WriteImage(sitk.GetImageFromArray(np.uint16(pul_gt)),
                    current_dir + '/DICOM/masksPulGroundTruth.dcm')

    _stamp_pixel_spacing(current_dir + '/DICOM/masksAortaGroundTruth.dcm',
                         image_pixel_space, pixel_space)
    _stamp_pixel_spacing(current_dir + '/DICOM/masksPulGroundTruth.dcm',
                         image_pixel_space, pixel_space)

    nii_space = np.eye(4)
    nii_space[0, 0] = image_pixel_space[0]
    nii_space[1, 1] = image_pixel_space[1]

    _save_flipped_nifti(aorta_gt, nii_space,
                        current_dir + '/NIFTY/masksAortaGroundTruth.nii')
    _save_flipped_nifti(pul_gt, nii_space,
                        current_dir + '/NIFTY/masksPulGroundTruth.nii')

    sitk.WriteImage(sitk.GetImageFromArray(np.uint16(aorta_gt)),
                    current_dir + '/mhd/masksAortaGroundTruth.mhd')
    sitk.WriteImage(sitk.GetImageFromArray(np.uint16(pul_gt)),
                    current_dir + '/mhd/masksPulGroundTruth.mhd')

    del aorta_gt, pul_gt

    # Blank full-frame canvas (default integer dtype) that receives the
    # predictions later; final_masks_crop still views the one-hot masks.
    final_test_masks = np.zeros(final_test_masks.shape, dtype=int)

    # ---- Sliding-window prediction -----------------------------------------
    # NOTE(review): the tail normalisation below seems to assume a patch that
    # ends on the last slice, which would need one more patch than this
    # count yields -- confirm before changing.
    num_patches = int((total_slices - slices) / gaps)

    test_image = np.ndarray([1, slices, row, col, 1], dtype=np.int16)

    # Must start at zero because patches are accumulated with "+=" below;
    # np.ndarray() hands back uninitialised memory.
    predicted_mask_volume = np.zeros([total_slices, row, col, nb_class],
                                     dtype=np.float32)

    # Other model builders can be substituted here.
    model = UNet_3D.get_3d_unet()

    using_start_end = 1
    start_slice = cm.start_slice
    end_slice = -1

    if use_existing:
        model.load_weights(modelname)

    for patch_idx in range(num_patches):
        count1 = patch_idx * gaps
        count2 = count1 + slices
        test_image[0] = final_images_crop[count1:count2]

        predicted_mask = model.predict(test_image)

        predicted_mask_volume[count1:count2] += predicted_mask[0, :, :, :, :]

    # Divide the accumulated sums by the number of patches assumed to cover
    # each slice: ramping up at the head, ramping down at the tail, and
    # slices / gaps everywhere in between.
    t = len(predicted_mask_volume)
    for off in range(0, slices, gaps):
        predicted_mask_volume[off:(off + gaps)] = (
            predicted_mask_volume[off:(off + gaps)] / (off / gaps + 1))

    for off in range(0, slices, gaps):
        predicted_mask_volume[(t - off - gaps):(t - off)] = (
            predicted_mask_volume[(t - off - gaps):(t - off)] /
            (off / gaps + 1))

    for k in range(slices, (t - slices)):
        predicted_mask_volume[k] = predicted_mask_volume[k] / (slices / gaps)

    # ---- Persist the arrays and read them back -----------------------------
    npy_dir = cm.workingPath.testingNPY_path
    np.save(npy_dir + 'testImages.npy', final_images_crop)
    np.save(npy_dir + 'testMasks.npy', final_masks_crop)
    np.save(npy_dir + 'masksTestPredicted.npy', predicted_mask_volume)

    del final_images_crop, final_masks_crop, predicted_mask_volume

    imgs_origin = np.load(npy_dir + 'testImages.npy').astype(np.int16)
    imgs_true = np.load(npy_dir + 'testMasks.npy').astype(np.int8)
    imgs_predict = np.load(npy_dir +
                           'masksTestPredicted.npy').astype(np.float32)
    imgs_predict_threshold = np.where(imgs_predict < (0.5), 0, 1)

    # ---- Dice scores --------------------------------------------------------
    if using_start_end == 1:
        window = slice(start_slice, end_slice)
    else:
        window = slice(None)

    aortaMean = lf.dice_coef_np(imgs_predict_threshold[window, :, :, 1],
                                imgs_true[window, :, :, 1])
    pulMean = lf.dice_coef_np(imgs_predict_threshold[window, :, :, 2],
                              imgs_true[window, :, :, 2])

    np.savetxt(current_dir + '/Plots/Aorta_Dice_mean.txt',
               np.array(aortaMean).reshape(1, ),
               fmt='%.5f')
    np.savetxt(current_dir + '/Plots/Pul_Dice_mean.txt',
               np.array(pulMean).reshape(1, ),
               fmt='%.5f')

    print('Model file:', modelname)
    print('-' * 30)
    print('Aorta Dice Coeff', aortaMean)
    print('Pul Dice Coeff', pulMean)
    print('-' * 30)

    # ---- Overlay figures ----------------------------------------------------
    # The first number found in the model file name is used as the epoch tag.
    epoch_tag = int(re.findall(r'\d+\.?\d*', modelname)[0])
    aorta_dice = float(np.array(aortaMean).reshape(1, )[0])
    pul_dice = float(np.array(pulMean).reshape(1, )[0])

    _save_overlay_figure(
        3, 1, imgs_origin, imgs_true, imgs_predict, imgs_predict_threshold,
        current_dir +
        '/Plots/epoch_Aorta_%02d_dice_%.3f.png' % (epoch_tag, aorta_dice))
    _save_overlay_figure(
        4, 2, imgs_origin, imgs_true, imgs_predict, imgs_predict_threshold,
        current_dir +
        '/Plots/epoch_Pul_%02d_dice_%.3f.png' % (epoch_tag, pul_dice))

    print('Images saved')

    # ---- Prediction exports (DICOM, NIfTI, mhd) -----------------------------
    # Views into the blank canvas: paste the thresholded crop back into the
    # full 512x512 frame.
    final_test_aorta_predicted_threshold = final_test_masks[:, :, :, 1]
    final_test_pul_predicted_threshold = final_test_masks[:, :, :, 2]

    final_test_aorta_predicted_threshold[:, row_1:row_2, col_1:col_2] = (
        imgs_predict_threshold[:, :, :, 1])
    final_test_pul_predicted_threshold[:, row_1:row_2, col_1:col_2] = (
        imgs_predict_threshold[:, :, :, 2])

    new_imgs_dcm = sitk.GetImageFromArray(np.uint16(final_test_images + 4000))
    new_imgs_aorta_predict_dcm = sitk.GetImageFromArray(
        np.uint16(final_test_aorta_predicted_threshold))
    new_imgs_pul_predict_dcm = sitk.GetImageFromArray(
        np.uint16(final_test_pul_predicted_threshold))

    sitk.WriteImage(new_imgs_dcm, current_dir + '/DICOM/OriginalImages.dcm')
    sitk.WriteImage(new_imgs_aorta_predict_dcm,
                    current_dir + '/DICOM/masksAortaPredicted.dcm')
    sitk.WriteImage(new_imgs_pul_predict_dcm,
                    current_dir + '/DICOM/masksPulPredicted.dcm')

    _stamp_pixel_spacing(current_dir + '/DICOM/OriginalImages.dcm',
                         image_pixel_space, pixel_space)

    _save_flipped_nifti(final_test_images, nii_space,
                        current_dir + '/NIFTY/OriginalImages.nii')

    _stamp_pixel_spacing(current_dir + '/DICOM/masksAortaPredicted.dcm',
                         image_pixel_space, pixel_space)
    _stamp_pixel_spacing(current_dir + '/DICOM/masksPulPredicted.dcm',
                         image_pixel_space, pixel_space)

    # NOTE(review): the NIfTI prediction masks receive the raw float
    # probabilities written into the integer canvas, so they are truncated
    # on assignment. The DICOM/mhd exports above use the thresholded masks.
    # Behaviour kept as found -- confirm which one is intended.
    final_test_aorta_predicted = final_test_masks[:, :, :, 1]
    final_test_pul_predicted = final_test_masks[:, :, :, 2]

    final_test_aorta_predicted[:, row_1:row_2,
                               col_1:col_2] = imgs_predict[:, :, :, 1]
    final_test_pul_predicted[:, row_1:row_2,
                             col_1:col_2] = imgs_predict[:, :, :, 2]

    _save_flipped_nifti(final_test_aorta_predicted, nii_space,
                        current_dir + '/NIFTY/masksAortaPredicted.nii')
    _save_flipped_nifti(final_test_pul_predicted, nii_space,
                        current_dir + '/NIFTY/masksPulPredicted.nii')

    sitk.WriteImage(new_imgs_dcm, current_dir + '/mhd/imagesPredicted.mhd')
    sitk.WriteImage(new_imgs_aorta_predict_dcm,
                    current_dir + '/mhd/masksAortaPredicted.mhd')
    sitk.WriteImage(new_imgs_pul_predict_dcm,
                    current_dir + '/mhd/masksPulPredicted.mhd')

    print('DICOM saved')


def nifty_evaluation(vol_list):
    """Run prediction, scoring, plotting and export for every directory in ``vol_list``.

    For each entry the function redirects ``print`` output to
    ``<dir>/logs_post.txt``, loads the volume and its two masks, writes the
    ground truth as DICOM/NIfTI/mhd, predicts overlapping patches with
    ``UNet_3D.get_3d_unet()``, stores Dice scores and overlay figures under
    ``<dir>/Plots`` and finally exports the predictions.

    Besides ``cm`` this relies on the module-level names ``originFile_list``,
    ``maskAortaFile_list``, ``maskPulFile_list`` and ``use_existing``, which
    are not defined in this function -- presumably set by the caller module;
    verify.

    Args:
        vol_list: Sequence of output directory paths, index-aligned with the
            module-level file lists.

    Returns:
        None.
    """
    num_vols = len(vol_list)

    for vol_idx, current_dir in enumerate(vol_list):
        # Show runtime:
        starttime = datetime.datetime.now()

        stdout_backup = sys.stdout
        log_file = open(current_dir + "/logs_post.txt", "w")
        sys.stdout = log_file
        try:
            _evaluate_volume(vol_idx, num_vols, current_dir)

            endtime = datetime.datetime.now()
            print('-' * 30)
            print('running time:', endtime - starttime)
        finally:
            # Always give stdout back and close the log, even when a step
            # above raises.
            sys.stdout = stdout_backup
            log_file.close()
def model_test(use_existing):
    """Evaluate the 3D U-Net on the aorta test set and export the results.

    Every test volume and its aorta mask are loaded, the masks are binarised,
    and all slices are stacked and centre-cropped to
    ``cm.img_rows_3d`` x ``cm.img_cols_3d``.  The network then predicts
    overlapping slabs of ``cm.slices_3d`` slices taken every ``cm.gaps_3d``
    slices; the slab outputs are summed per slice, thresholded, and compared
    with the ground truth through ``lf.dice_coef_np``.

    Files written to ``cm.workingPath.testingSet_path``: ``testImages.npy``,
    ``testMasks.npy``, ``masksTestPredicted.npy``, ``dicemean.txt``, an
    overview figure ``epoch_XX_dice_Y.YYY.png`` and
    ``masksTestPredicted.dcm``.

    Args:
        use_existing: When truthy, the weights file ``cm.modellist[0]`` is
            loaded into the freshly built model before predicting.
    """
    print('-' * 30)
    print('Loading test data...')

    # Loading test data:
    filename = cm.filename
    modelname = cm.modellist[0]
    originFile_list = sorted(
        glob(cm.workingPath.originTestingSet_path + filename))
    maskFile_list = sorted(glob(cm.workingPath.aortaTestingSet_path +
                                filename))

    out_test_images = []
    out_test_masks = []

    for i in range(len(originFile_list)):
        # loadFile also returns slice count, width and height; unused here.
        originTestVol, _, _, _ = loadFile(originFile_list[i])
        maskTestVol, _, _, _ = loadFile(maskFile_list[i])

        # Binarise the ground truth: every non-zero label becomes 1.
        for j in range(len(maskTestVol)):
            maskTestVol[j] = np.where(maskTestVol[j] != 0, 1, 0)

        out_test_images.extend(originTestVol)
        out_test_masks.extend(maskTestVol)

    num_test_images = len(out_test_images)

    # Both buffers are completely overwritten by the loop below, so leaving
    # them uninitialised is safe here.
    final_test_images = np.empty([num_test_images, 512, 512], dtype=np.int16)
    final_test_masks = np.empty([num_test_images, 512, 512], dtype=np.int8)

    for i in range(num_test_images):
        final_test_images[i] = out_test_images[i]
        final_test_masks[i] = out_test_masks[i]
    final_test_images = np.expand_dims(final_test_images, axis=-1)
    final_test_masks = np.expand_dims(final_test_masks, axis=-1)

    # Centre-crop window inside the 512 x 512 slices.
    row = cm.img_rows_3d
    col = cm.img_cols_3d
    row_1 = int((512 - row) / 2)
    row_2 = int(512 - (512 - row) / 2)
    col_1 = int((512 - col) / 2)
    col_2 = int(512 - (512 - col) / 2)
    slices = cm.slices_3d
    gaps = cm.gaps_3d

    final_images_crop = final_test_images[:, row_1:row_2, col_1:col_2, :]
    final_masks_crop = final_test_masks[:, row_1:row_2, col_1:col_2, :]

    num_patches = int((num_test_images - slices) / gaps)

    test_image = np.empty([1, slices, row, col, 1], dtype=np.int16)

    # Must start at zero: slab predictions are accumulated with ``+=`` below.
    # (``np.ndarray(shape)`` would hand back uninitialised memory.)
    predicted_mask_volume = np.zeros([num_test_images, row, col],
                                     dtype=np.float32)

    # model = nw.get_3D_unet()
    model = UNet_3D.get_3d_unet()
    # model = nw.get_3D_unet_drop_1()
    # model = nw.get_3D_unet_BN()

    using_start_end = 1
    start_slice = cm.start_slice
    end_slice = -1

    if use_existing:
        model.load_weights(modelname)

    for i in range(num_patches):
        count1 = i * gaps
        count2 = count1 + slices
        test_image[0] = final_images_crop[count1:count2]

        predicted_mask = model.predict(test_image)

        # Dump the layer activations for one hard-coded slab only.
        if i == 88:
            vs.visualize_activation_in_layer(model, test_image)

        predicted_mask_volume[count1:count2] += predicted_mask[0, :, :, :, 0]

    predicted_mask_volume = np.expand_dims(predicted_mask_volume, axis=-1)
    np.save(cm.workingPath.testingSet_path + 'testImages.npy',
            final_images_crop)
    np.save(cm.workingPath.testingSet_path + 'testMasks.npy', final_masks_crop)
    np.save(cm.workingPath.testingSet_path + 'masksTestPredicted.npy',
            predicted_mask_volume)

    imgs_origin = np.load(cm.workingPath.testingSet_path +
                          'testImages.npy').astype(np.int16)
    imgs_true = np.load(cm.workingPath.testingSet_path +
                        'testMasks.npy').astype(np.int8)
    imgs_predict = np.load(cm.workingPath.testingSet_path +
                           'masksTestPredicted.npy').astype(np.float32)
    imgs_predict_threshold = np.load(cm.workingPath.testingSet_path +
                                     'masksTestPredicted.npy').astype(
                                         np.float32)

    imgs_origin = np.squeeze(imgs_origin, axis=-1)
    imgs_true = np.squeeze(imgs_true, axis=-1)
    imgs_predict = np.squeeze(imgs_predict, axis=-1)
    imgs_predict_threshold = np.squeeze(imgs_predict_threshold, axis=-1)

    # The threshold is applied to the *summed* (not averaged) slab outputs.
    imgs_predict_threshold = np.where(imgs_predict_threshold < (10), 0, 1)

    if using_start_end == 1:
        mean = lf.dice_coef_np(imgs_predict_threshold[start_slice:end_slice],
                               imgs_true[start_slice:end_slice])
    else:
        mean = lf.dice_coef_np(imgs_predict_threshold, imgs_true)

    np.savetxt(cm.workingPath.testingSet_path + 'dicemean.txt',
               np.array(mean).reshape(1, ),
               fmt='%.5f')

    print('Model file:', modelname)
    print('Total Dice Coeff', mean)
    print('-' * 30)

    # Draw the subplots of figures:

    color1 = 'gray'  # colormap of the CT slice
    color2 = 'viridis'  # colormap of the overlaid mask

    transparent1 = 1.0
    transparent2 = 0.5

    # One column every `steps` slices; rows are truth / raw prediction /
    # thresholded prediction.
    steps = 40
    slice_indices = range(0, len(imgs_origin), steps)
    plt_row = 3
    plt_col = int(len(imgs_origin) / steps)

    plt.figure(1, figsize=(25, 12))

    for i in slice_indices:
        # NOTE(review): slice 0 and slice `steps` both resolve to column 1,
        # so the second draws over the first -- confirm this is intended.
        plt_num = 1 if i == 0 else int(i / steps)

        if plt_num <= plt_col:

            plt.figure(1)
            title = 'slice=' + str(i)

            ax1 = plt.subplot(plt_row, plt_col, plt_num)
            plt.title(title)
            ax1.imshow(imgs_origin[i, :, :], cmap=color1, alpha=transparent1)
            ax1.imshow(imgs_true[i, :, :], cmap=color2, alpha=transparent2)

            ax2 = plt.subplot(plt_row, plt_col, plt_num + plt_col)
            plt.title(title)
            ax2.imshow(imgs_origin[i, :, :], cmap=color1, alpha=transparent1)
            ax2.imshow(imgs_predict[i, :, :], cmap=color2, alpha=transparent2)

            ax3 = plt.subplot(plt_row, plt_col, plt_num + 2 * plt_col)
            plt.title(title)
            ax3.imshow(imgs_origin[i, :, :], cmap=color1, alpha=transparent1)
            ax3.imshow(imgs_predict_threshold[i, :, :],
                       cmap=color2,
                       alpha=transparent2)

    modelname = cm.modellist[0]

    # The first number found in the weights file name is used as the epoch.
    imageName = re.findall(r'\d+\.?\d*', modelname)
    epoch_num = int(imageName[0])
    accuracy = float(
        np.loadtxt(cm.workingPath.testingSet_path + 'dicemean.txt', float))

    saveName = 'epoch_%02d_dice_%.3f.png' % (epoch_num, accuracy)

    plt.subplots_adjust(left=0.0,
                        bottom=0.05,
                        right=1.0,
                        top=0.95,
                        hspace=0.3,
                        wspace=0.3)
    plt.savefig(cm.workingPath.testingSet_path + saveName)
    # plt.show()

    print('Images saved')

    # Save npy as dcm files:

    # Paste the cropped prediction into an all-zero full-size volume.  The
    # buffer must not alias the ground-truth masks, otherwise the true labels
    # outside the crop window would end up in the "predicted" file.
    final_test_predicted_threshold = np.zeros([num_test_images, 512, 512],
                                              dtype=np.uint16)
    final_test_predicted_threshold[:, row_1:row_2, col_1:
                                   col_2] = imgs_predict_threshold[:, :, :]

    new_imgs_predict_dcm = sitk.GetImageFromArray(
        final_test_predicted_threshold)

    sitk.WriteImage(new_imgs_predict_dcm,
                    cm.workingPath.testingSet_path + 'masksTestPredicted.dcm')

    # Re-use the header of the first ground-truth mask file and replace only
    # its pixel data with the prediction that was just written.
    ds1 = dicom.read_file(maskFile_list[0])
    ds2 = dicom.read_file(cm.workingPath.testingSet_path +
                          'masksTestPredicted.dcm')
    ds1.PixelData = ds2.PixelData
    ds1.save_as(cm.workingPath.testingSet_path + 'masksTestPredicted.dcm')

    print('DICOM saved')
# Example no. 6
# 0
def get_3d_wnet1():
    """Build and compile the two-branch 3D "W-Net" segmentation model.

    Two encoder branches with identical layout (but separate weights) read
    the same input of shape
    ``(cm.slices_3d, cm.img_rows_3d, cm.img_cols_3d, 1)``.  At every decoder
    level the up-sampled features and the skip connections of *both* branches
    are concatenated, and two parallel decoder branches are computed from that
    shared tensor.  The final 1x1x1 convolution emits 3 sigmoid channels.

    Returns:
        The Keras model, compiled with ``Adam(lr=1.0e-5)``, the loss returned
        by ``lf.weighted_categorical_crossentropy_loss`` with unit weights,
        and the ``categorical_accuracy`` metric.
    """

    def conv_pair(tensor, first_filters, second_filters):
        # Two stacked 3x3x3 ReLU convolutions with 'same' padding.
        for n_filters in (first_filters, second_filters):
            tensor = Conv3D(filters=n_filters,
                            kernel_size=(3, 3, 3),
                            strides=(1, 1, 1),
                            activation='relu',
                            border_mode='same')(tensor)
        return tensor

    def halve(tensor):
        # 2x2x2 max pooling.
        return MaxPooling3D(pool_size=(2, 2, 2))(tensor)

    def fuse(deep_a, skip_a, deep_b, skip_b):
        # Up-sample both deep tensors and concatenate them with both skips
        # along the channel axis.
        return merge([
            UpSampling3D(size=(2, 2, 2))(deep_a), skip_a,
            UpSampling3D(size=(2, 2, 2))(deep_b), skip_b
        ],
                     mode='concat',
                     concat_axis=-1)

    inputs = Input((cm.slices_3d, cm.img_rows_3d, cm.img_cols_3d, 1))

    # Encoder: four pooled levels per branch, skips kept for the decoder.
    branch_a = inputs
    branch_b = inputs
    skips_a = []
    skips_b = []
    for width in (16, 32, 64, 128):
        features_a = conv_pair(branch_a, width, width)
        branch_a = halve(features_a)
        features_b = conv_pair(branch_b, width, width)
        branch_b = halve(features_b)
        skips_a.append(features_a)
        skips_b.append(features_b)

    # Bottleneck (no pooling afterwards).
    features_a = conv_pair(branch_a, 256, 256)
    features_b = conv_pair(branch_b, 256, 256)

    # Decoder: both branches start from the same fused tensor at each level.
    for width in (256, 128, 64, 32):
        fused = fuse(features_a, skips_a.pop(), features_b, skips_b.pop())
        features_a = conv_pair(fused, width, width // 2)
        features_b = conv_pair(fused, width, width // 2)

    # Head: join the two decoder outputs and project to 3 channels.
    joined = merge([features_a, features_b], mode='concat', concat_axis=-1)
    outputs = Conv3D(filters=3,
                     kernel_size=(1, 1, 1),
                     strides=(1, 1, 1),
                     activation='sigmoid')(joined)

    model = Model(input=inputs, output=outputs)

    class_weights = np.array([1, 1, 1])
    weighted_loss = lf.weighted_categorical_crossentropy_loss(class_weights)
    model.compile(optimizer=Adam(lr=1.0e-5),
                  loss=weighted_loss,
                  metrics=["categorical_accuracy"])

    return model
# Example no. 7
# 0
def model_test(use_existing):
    """Evaluate the 3D CNN on every test CT volume, one volume at a time.

    For each file matching ``cm.filename`` in the origin testing set a result
    directory is created under ``cm.workingPath.testingResults_path``;
    everything printed while the volume is processed is redirected to
    ``logs.txt`` in that directory.  The prediction, scoring and export work
    for one volume is done by ``_model_test_vessel_volume``.

    Args:
        use_existing: When truthy, the weights file ``cm.modellist[0]`` is
            loaded into the model before predicting.
    """

    cm.mkdir(cm.workingPath.testingResults_path)
    cm.mkdir(cm.workingPath.testingNPY_path)

    # Loading test data:

    filename = cm.filename
    modelname = cm.modellist[0]

    # Single CT:
    originFile_list = sorted(
        glob(cm.workingPath.originTestingSet_path + filename))
    maskAortaFile_list = sorted(
        glob(cm.workingPath.aortaTestingSet_path + filename))
    maskPulFile_list = sorted(
        glob(cm.workingPath.pulTestingSet_path + filename))

    # Zahra CTs:
    # originFile_list = sorted(glob(cm.workingPath.originTestingSet_path + "vol*.dcm"))
    # maskAortaFile_list = sorted(glob(cm.workingPath.aortaTestingSet_path + "vol*.dcm"))
    # maskPulFile_list = sorted(glob(cm.workingPath.pulTestingSet_path + "vol*.dcm"))

    # Lidia CTs:
    # originFile_list = sorted(glob(cm.workingPath.originLidiaTestingSet_path + "vol*.dcm"))[61:]
    # maskAortaFile_list = sorted(glob(cm.workingPath.originLidiaTestingSet_path + "vol*.dcm"))[61:]
    # maskPulFile_list = sorted(glob(cm.workingPath.originLidiaTestingSet_path + "vol*.dcm"))[61:]

    # Abnormal CTs:
    # originFile_list = sorted(glob(cm.workingPath.originAbnormalTestingSet_path + "vol126*.dcm"))
    # maskAortaFile_list = sorted(glob(cm.workingPath.originAbnormalTestingSet_path + "vol126*.dcm"))
    # maskPulFile_list = sorted(glob(cm.workingPath.originAbnormalTestingSet_path + "vol126*.dcm"))

    for vol_idx in range(len(originFile_list)):

        # Show runtime:
        starttime = datetime.datetime.now()

        current_file = originFile_list[vol_idx].split('/')[-1]
        # The result directory is the file name without its last 17 chars.
        current_dir = cm.workingPath.testingResults_path + str(
            current_file[:-17])
        cm.mkdir(current_dir)
        cm.mkdir(current_dir + '/Plots/')
        cm.mkdir(current_dir + '/Surface_Distance/Aorta/')
        cm.mkdir(current_dir + '/Surface_Distance/Pul/')
        cm.mkdir(current_dir + '/DICOM/')
        cm.mkdir(current_dir + '/mhd/')

        stdout_backup = sys.stdout
        log_file = open(current_dir + "/logs.txt", "w")
        sys.stdout = log_file
        try:
            print('-' * 30)
            print('Loading test data %04d/%04d...' %
                  (vol_idx + 1, len(originFile_list)))

            _model_test_vessel_volume(originFile_list[vol_idx],
                                      maskAortaFile_list[vol_idx],
                                      maskPulFile_list[vol_idx], current_dir,
                                      modelname, use_existing)

            endtime = datetime.datetime.now()
            print('-' * 30)
            print('running time:', endtime - starttime)
        finally:
            # Always give stdout back and close the log, even when a volume
            # fails, so the error is visible and no file handle leaks.
            sys.stdout = stdout_backup
            log_file.close()


def _model_test_vessel_volume(origin_path, aorta_path, pul_path, current_dir,
                              modelname, use_existing):
    """Predict, score and export a single test CT volume.

    Args:
        origin_path: Path of the CT volume file.
        aorta_path: Path of the matching aorta mask file.
        pul_path: Path of the matching pulmonary mask file.
        current_dir: Existing result directory of this volume (with its
            ``Plots``, ``DICOM`` and ``mhd`` sub-directories).
        modelname: Weights file; its first number names the saved figure.
        use_existing: When truthy, ``modelname`` is loaded into the model.
    """
    # dp.loadFile also returns slice count, width and height; unused here.
    originVol, _, _, _ = dp.loadFile(origin_path)
    maskAortaVol, _, _, _ = dp.loadFile(aorta_path)
    maskPulVol, _, _, _ = dp.loadFile(pul_path)

    # Label aorta voxels 1 and pulmonary voxels 2.
    for j in range(len(maskAortaVol)):
        maskAortaVol[j] = np.where(maskAortaVol[j] != 0, 1, 0)
    for j in range(len(maskPulVol)):
        maskPulVol[j] = np.where(maskPulVol[j] != 0, 2, 0)

    maskVol = maskAortaVol + maskPulVol

    # Voxels claimed by both masks (sum > 2) are treated as background.
    for j in range(len(maskVol)):
        maskVol[j] = np.where(maskVol[j] > 2, 0, maskVol[j])

    # Make the Vessel class: aorta and pulmonary collapse into label 1.
    for j in range(len(maskVol)):
        maskVol[j] = np.where(maskVol[j] != 0, 1, 0)

    out_test_images = [originVol[k, :, :] for k in range(originVol.shape[0])]
    out_test_masks = [maskVol[k, :, :] for k in range(maskVol.shape[0])]
    num_slices = originVol.shape[0]

    # Drop the references that are no longer needed before allocating more.
    maskAortaVol = None
    maskPulVol = None
    maskVol = None
    originVol = None

    nb_class = 2
    outmasks_onehot = to_categorical(out_test_masks, num_classes=nb_class)
    # Both buffers are completely overwritten by the loop below.
    final_test_images = np.empty([num_slices, 512, 512, 1], dtype=np.int16)
    final_test_masks = np.empty([num_slices, 512, 512, nb_class],
                                dtype=np.int8)

    for k in range(len(out_test_images)):
        final_test_images[k, :, :, 0] = out_test_images[k]
        final_test_masks[k, :, :, :] = outmasks_onehot[k]

    outmasks_onehot = None
    out_test_masks = None
    out_test_images = None

    # Centre-crop window inside the 512 x 512 slices.
    row = cm.img_rows_3d
    col = cm.img_cols_3d
    row_1 = int((512 - row) / 2)
    row_2 = int(512 - (512 - row) / 2)
    col_1 = int((512 - col) / 2)
    col_2 = int(512 - (512 - col) / 2)
    slices = cm.slices_3d
    gaps = cm.gaps_3d

    final_images_crop = final_test_images[:, row_1:row_2, col_1:col_2, :]
    final_masks_crop = final_test_masks[:, row_1:row_2, col_1:col_2, :]

    # Export the ground truth (foreground channel of the one-hot masks).
    ground_truth_img = sitk.GetImageFromArray(
        np.uint16(final_test_masks[:, :, :, 1]))
    sitk.WriteImage(ground_truth_img,
                    current_dir + '/DICOM/masksAortaGroundTruth.dcm')
    sitk.WriteImage(ground_truth_img,
                    current_dir + '/mhd/masksAortaGroundTruth.mhd')
    ground_truth_img = None
    final_test_masks = None

    num_patches = int((num_slices - slices) / gaps)

    test_image = np.empty([1, slices, row, col, 1], dtype=np.int16)

    # Must start at zero: slab predictions are accumulated with ``+=`` below.
    # (``np.ndarray(shape)`` would hand back uninitialised memory.)
    predicted_mask_volume = np.zeros([num_slices, row, col, nb_class],
                                     dtype=np.float32)

    # model = DenseUNet_3D.get_3d_denseunet()
    # model = UNet_3D.get_3d_unet_bn()
    # model = RSUNet_3D.get_3d_rsunet(opti)
    # model = UNet_3D.get_3d_wnet(opti)
    # model = UNet_3D.get_3d_unet()
    model = CNN_3D.get_3d_cnn()
    # model = RSUNet_3D_Gerda.get_3d_rsunet_Gerdafeature(opti)

    using_start_end = 1
    start_slice = cm.start_slice
    end_slice = -1

    if use_existing:
        model.load_weights(modelname)

    for patch_idx in range(num_patches):
        count1 = patch_idx * gaps
        count2 = count1 + slices
        test_image[0] = final_images_crop[count1:count2]

        predicted_mask = model.predict(test_image)

        # Dump activations for one slab roughly 63 % into the volume.
        if patch_idx == int(num_patches * 0.63):
            vs.visualize_activation_in_layer_one_plot_add_weights(
                model, test_image, current_dir)

        predicted_mask_volume[count1:count2] += predicted_mask[0, :, :, :, :]

    # Turn the per-slice sums into averages.  The divisor ramps 1, 2, ... over
    # the first and the last `slices` slices (in steps of `gaps`) and is
    # slices / gaps for everything in between.
    t = len(predicted_mask_volume)
    for offset in range(0, slices, gaps):
        predicted_mask_volume[offset:(offset + gaps)] /= (offset / gaps + 1)

    for offset in range(0, slices, gaps):
        predicted_mask_volume[(t - offset - gaps):(t - offset)] /= (
            offset / gaps + 1)

    predicted_mask_volume[slices:(t - slices)] /= (slices / gaps)

    np.save(cm.workingPath.testingNPY_path + 'testImages.npy',
            final_images_crop)
    np.save(cm.workingPath.testingNPY_path + 'testMasks.npy',
            final_masks_crop)
    np.save(cm.workingPath.testingNPY_path + 'masksTestPredicted.npy',
            predicted_mask_volume)

    final_images_crop = None
    final_masks_crop = None
    predicted_mask_volume = None

    imgs_origin = np.load(cm.workingPath.testingNPY_path +
                          'testImages.npy').astype(np.int16)
    imgs_true = np.load(cm.workingPath.testingNPY_path +
                        'testMasks.npy').astype(np.int8)
    imgs_predict = np.load(cm.workingPath.testingNPY_path +
                           'masksTestPredicted.npy').astype(np.float32)
    imgs_predict_threshold = np.load(cm.workingPath.testingNPY_path +
                                     'masksTestPredicted.npy').astype(
                                         np.float32)

    ########## ROC curve aorta

    actual = imgs_true[:, :, :, 1].reshape(-1)
    predictions = imgs_predict[:, :, :, 1].reshape(-1)

    false_positive_rate_aorta, true_positive_rate_aorta, _ = roc_curve(
        actual, predictions, pos_label=1)
    roc_auc_aorta = auc(false_positive_rate_aorta, true_positive_rate_aorta)
    plt.figure(1, figsize=(6, 6))
    plt.title('ROC of Aorta')
    # The label has to be attached to the curve, otherwise the legend below
    # has nothing to display.
    plt.plot(false_positive_rate_aorta,
             true_positive_rate_aorta,
             'b',
             label='AUC = %0.2f' % roc_auc_aorta)
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], 'r--')
    plt.xlim([-0.0, 1.0])
    plt.ylim([-0.0, 1.0])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    # plt.show()
    saveName = '/Plots/ROC_Aorta_curve.png'
    plt.savefig(current_dir + saveName)
    plt.close()

    false_positive_rate_aorta = None
    true_positive_rate_aorta = None
    actual = None
    predictions = None

    imgs_predict_threshold = np.where(imgs_predict_threshold < 0.3, 0, 1)

    if using_start_end == 1:
        aortaMean = lf.dice_coef_np(
            imgs_predict_threshold[start_slice:end_slice, :, :, 1],
            imgs_true[start_slice:end_slice, :, :, 1])
    else:
        aortaMean = lf.dice_coef_np(imgs_predict_threshold[:, :, :, 1],
                                    imgs_true[:, :, :, 1])

    np.savetxt(current_dir + '/Plots/AortaDicemean.txt',
               np.array(aortaMean).reshape(1, ),
               fmt='%.5f')

    print('Model file:', modelname)
    print('-' * 30)
    print('Aorta Dice Coeff', aortaMean)
    print('-' * 30)

    # Draw the subplots of figures:

    color1 = 'gray'  # colormap of the CT slice
    color2 = 'viridis'  # colormap of the overlaid mask

    transparent1 = 1.0
    transparent2 = 0.5

    # One column every `steps` slices; rows are truth / raw prediction /
    # thresholded prediction.
    steps = 40
    slice_indices = range(0, len(imgs_origin), steps)
    plt_row = 3
    plt_col = int(len(imgs_origin) / steps)

    plt.figure(3, figsize=(25, 12))

    for s in slice_indices:
        # NOTE(review): slice 0 and slice `steps` both resolve to column 1,
        # so the second draws over the first -- confirm this is intended.
        plt_num = 1 if s == 0 else int(s / steps)

        if plt_num <= plt_col:
            title = 'slice=' + str(s)

            ax1 = plt.subplot(plt_row, plt_col, plt_num)
            plt.title(title)
            ax1.imshow(imgs_origin[s, :, :, 0],
                       cmap=color1,
                       alpha=transparent1)
            ax1.imshow(imgs_true[s, :, :, 1],
                       cmap=color2,
                       alpha=transparent2)

            ax2 = plt.subplot(plt_row, plt_col, plt_num + plt_col)
            plt.title(title)
            ax2.imshow(imgs_origin[s, :, :, 0],
                       cmap=color1,
                       alpha=transparent1)
            ax2.imshow(imgs_predict[s, :, :, 1],
                       cmap=color2,
                       alpha=transparent2)

            ax3 = plt.subplot(plt_row, plt_col, plt_num + 2 * plt_col)
            plt.title(title)
            ax3.imshow(imgs_origin[s, :, :, 0],
                       cmap=color1,
                       alpha=transparent1)
            ax3.imshow(imgs_predict_threshold[s, :, :, 1],
                       cmap=color2,
                       alpha=transparent2)

    modelname = cm.modellist[0]

    # The first number found in the weights file name is used as the epoch.
    imageName = re.findall(r'\d+\.?\d*', modelname)
    epoch_num = int(imageName[0])
    accuracy = float(
        np.loadtxt(current_dir + '/Plots/AortaDicemean.txt', float))

    saveName = '/Plots/epoch_Aorta_%02d_dice_%.3f.png' % (epoch_num, accuracy)

    plt.subplots_adjust(left=0.0,
                        bottom=0.05,
                        right=1.0,
                        top=0.95,
                        hspace=0.3,
                        wspace=0.3)
    plt.savefig(current_dir + saveName)
    plt.close()
    # plt.show()

    print('Images saved')
    # Save npy as dcm files:

    # Paste the cropped, thresholded prediction into an all-zero volume of the
    # original 512 x 512 slice size.
    final_test_aorta_predicted_threshold = np.zeros(
        final_test_images.shape[:3], dtype=np.uint16)
    final_test_aorta_predicted_threshold[:, row_1:row_2, col_1:
                                         col_2] = imgs_predict_threshold[:, :, :,
                                                                         1]

    new_imgs_dcm = sitk.GetImageFromArray(
        np.uint16(final_test_images + 4000))
    new_imgs_aorta_predict_dcm = sitk.GetImageFromArray(
        final_test_aorta_predicted_threshold)

    sitk.WriteImage(new_imgs_dcm, current_dir + '/DICOM/imagesPredicted.dcm')
    sitk.WriteImage(new_imgs_aorta_predict_dcm,
                    current_dir + '/DICOM/masksAortaPredicted.dcm')

    sitk.WriteImage(new_imgs_dcm, current_dir + '/mhd/imagesPredicted.mhd')
    sitk.WriteImage(new_imgs_aorta_predict_dcm,
                    current_dir + '/mhd/masksAortaPredicted.mhd')

    # mt.SegmentDist(current_dir + '/mhd/masksAortaPredicted.mhd',current_dir + '/mhd/masksAortaGroundTruth.mhd', current_dir + '/Surface_Distance/Aorta')
    # mt.SegmentDist(current_dir + '/mhd/masksPulPredicted.mhd',current_dir + '/mhd/masksPulGroundTruth.mhd', current_dir + '/Surface_Distance/Pul')

    print('DICOM saved')
    # All large arrays are locals of this helper and are released on return,
    # so no explicit clean-up is required before the next volume.