Ejemplo n.º 1
0
 def test_sensitivity_empty(self):
     """sensitivity: should cope with empty BasePairs on either side"""
     # Both arguments empty: the expected value is 1.
     no_reference = BasePairs([])
     no_prediction = BasePairs([])
     self.assertFloatEqual(sensitivity(no_reference, no_prediction), 1)

     predicted = BasePairs([
         (6, 1), (10, 11), (3, 12), (13, 20), (14, 19), (15, 18),
     ])
     # Empty first argument against a non-empty second argument.
     # NOTE(review): presumably the order is sensitivity(reference,
     # prediction), which makes this the "empty reference" case -- confirm.
     self.assertFloatEqual(sensitivity(BasePairs([]), predicted), 0)
     # Non-empty first argument against an empty second argument.
     self.assertFloatEqual(sensitivity(predicted, BasePairs([])), 0)
Ejemplo n.º 2
0
    def test_sensitivity_dupl(self):
        """sensitivity: duplicates, reversed pairs and None entries give 0.25"""
        # The reference mixes a reversed repeat ((2, 5) / (5, 2)) with
        # entries that contain None.
        reference = BasePairs([
            (1, 6), (2, 5), (3, 10), (7, None), (None, None), (5, 2), (4, 9),
        ])
        expected = 0.25

        # Prediction without None entries or repeats.
        plain = BasePairs([(6, 1), (10, 11), (3, 12)])
        self.assertFloatEqual(sensitivity(reference, plain), expected)

        # Same prediction plus None-containing entries and (1, 6), the
        # reversed form of (6, 1); the expected score must not change.
        noisy = BasePairs([
            (6, 1), (10, 11), (3, 12), (20, None), (None, None), (1, 6),
        ])
        self.assertFloatEqual(sensitivity(reference, noisy), expected)
Ejemplo n.º 3
0
 def test_sensitivity_general(self):
     """sensitivity: should work in general"""
     # Written as 1.0 / 3 so the expected value is one third even where
     # `/` on two ints is floor division (Python 2 without
     # `from __future__ import division`); the value is unchanged elsewhere.
     one_third = 1.0 / 3

     ref = BasePairs([(1, 6), (2, 5), (3, 10)])
     pred = BasePairs([(6, 1), (10, 11), (3, 12)])
     # one good prediction: (6, 1) is the reversed form of reference (1, 6)
     self.assertFloatEqual(sensitivity(ref, pred), one_third)
     # over-prediction not penalized: extra predicted pairs leave the
     # expected value unchanged
     pred = BasePairs([(6, 1), (10, 11), (3, 12), (13, 20), (14, 19),
                       (15, 18)])
     self.assertFloatEqual(sensitivity(ref, pred), one_third)
Ejemplo n.º 4
0
 def test_sensitivity(self):
     """sensitivity: result must match the compare_ct.pm value"""
     # Reference value for the self.true / self.predicted fixtures.
     expected = 0.4
     self.assertEqual(sensitivity(self.true, self.predicted), expected)
Ejemplo n.º 5
0
def _segmentation_metrics(truth, mask):
    """Return (jaccard, sensitivity, specificity, accuracy, dice) of *mask* scored against *truth*."""
    return (
        fjaccard(truth, mask),
        utils.sensitivity(truth, mask),
        utils.specificity(truth, mask),
        utils.accuracy(truth, mask),
        utils.dice_coeff(truth, mask),
    )


def evaluate_test(model, test_ds, num_test_examples, cspace, epochs, save_model_path=None, type_train='',write_images=True, it=0):
    """Score a segmentation model on a test dataset and write the results to disk.

    For each of ``num_test_examples`` batches pulled from ``test_ds`` the
    first sample of the batch is predicted, thresholded into a binary mask
    and scored against its label (Jaccard, sensitivity, specificity,
    accuracy, Dice). The raw prediction is also refined with
    ``utils.dense_crf`` and the refined mask is scored the same way.

    Everything is written below
    ``pos_results/<type_train><cspace>/<epochs>/<it>/``:

    * ``predict/<j>.png`` -- channel 0 of the raw prediction of sample ``j``
    * ``<j>.png`` -- four-panel comparison figure (only if ``write_images``)
    * ``jaccard``, ``sensitivity``, ``specificity``, ``accuracy``, ``dice``
      and their ``crf_*`` counterparts -- per-sample values (``np.savetxt``)
    * ``score`` / ``crf_score`` -- text summaries of the mean metrics

    Args:
        model: Object whose ``predict(batch)`` returns per-sample predictions.
            Replaced by the model loaded from ``save_model_path`` when that
            argument is given.
        test_ds: Dataset exposing ``make_one_shot_iterator()``; each element
            must unpack into ``(batch_of_images, labels)``.
        num_test_examples: Number of elements to pull from the dataset. Must
            be positive, since the averages divide by it.
        cspace: Colour-space tag. ``'HSV'`` and ``'LAB'`` trigger a conversion
            of the input image to RGB before the CRF step; any other value
            uses the image unchanged. Also part of the output path.
        epochs: Used only to build the output path.
        save_model_path: Optional path passed to ``models.load_model``.
        type_train: Prefix prepended to ``cspace`` in the output path.
        write_images: Whether to save the per-sample comparison figure.
        it: Run/iteration index, used only to build the output path.

    Returns:
        Tuple ``(mjccard, score)``: the mean Jaccard index over all samples,
        and the mean of the Jaccard values where values below 0.65 are
        counted as 0.
    """
    if save_model_path is not None:
        model = models.load_model(
            save_model_path,
            custom_objects={
                'bce_dice_loss': bce_dice_loss,
                'dice_loss': dice_loss
            })

    # Probability cut-off that turns the prediction into a binary mask.
    mask_threshold = 0.55
    # Per-sample Jaccard values below this contribute 0 to `score`.
    score_threshold = 0.65

    # Output locations, built once. Plain concatenation is kept so that the
    # resulting paths are byte-identical to the ones used historically.
    out_dir = 'pos_results/' + type_train + cspace + '/' + str(epochs) + '/' + str(it) + '/'
    predict_dir = out_dir + 'predict/'

    mjccard = 0
    score = 0
    v_jaccard = np.zeros(num_test_examples)
    v_sensitivity = np.zeros(num_test_examples)
    v_specificity = np.zeros(num_test_examples)
    v_accuracy = np.zeros(num_test_examples)
    v_dice = np.zeros(num_test_examples)

    crf_jaccard = np.zeros(num_test_examples)
    crf_sensitivity = np.zeros(num_test_examples)
    crf_specificity = np.zeros(num_test_examples)
    crf_accuracy = np.zeros(num_test_examples)
    crf_dice = np.zeros(num_test_examples)

    data_aug_iter = test_ds.make_one_shot_iterator()
    next_element = data_aug_iter.get_next()
    if not os.path.exists(predict_dir):
        os.makedirs(predict_dir)

    for j in range(num_test_examples):
        # Running next element in our graph will produce a batch of images
        batch_of_imgs, label = tf.keras.backend.get_session().run(next_element)
        # Only the first sample of each batch is evaluated.
        img = batch_of_imgs[0]

        predicted_label = model.predict(batch_of_imgs)[0]
        probability = predicted_label[:, :, 0]
        mpimg.imsave(predict_dir + str(j) + '.png', probability)
        mask_pred = (probability > mask_threshold).astype(int)
        truth = label.astype(int)[0, :, :, 0]

        (v_jaccard[j], v_sensitivity[j], v_specificity[j],
         v_accuracy[j], v_dice[j]) = _segmentation_metrics(truth, mask_pred)
        score += v_jaccard[j] if v_jaccard[j] >= score_threshold else 0
        print(score)
        mjccard += v_jaccard[j]

        # The CRF works on an RGB image; only the first three channels are used.
        img_rgb = img[:, :, :3]

        # NOTE(review): each conversion call below builds new graph ops on
        # every iteration, so the default graph presumably grows with
        # num_test_examples -- consider building the op once outside the loop.
        if cspace == 'HSV':
            img_rgb = tf.keras.backend.get_session().run(tf.image.hsv_to_rgb(img_rgb))
        elif cspace == 'LAB':
            img_rgb = tf.keras.backend.get_session().run(Conv_img.lab_to_rgb(img_rgb))

        crf_mask = utils.dense_crf(np.array(img_rgb*255).astype(np.uint8), np.array(probability).astype(np.float32))

        (crf_jaccard[j], crf_sensitivity[j], crf_specificity[j],
         crf_accuracy[j], crf_dice[j]) = _segmentation_metrics(truth, crf_mask)

        if write_images:
            fig = plt.figure(figsize=(25, 25))

            plt.subplot(1, 4, 1)
            plt.imshow(img[:, :, :3])
            plt.title("Input image")

            plt.subplot(1, 4, 2)
            plt.imshow(truth)
            plt.title("Actual Mask")

            plt.subplot(1, 4, 3)
            plt.imshow(probability > mask_threshold)
            plt.title("Predicted Mask\n" +
                        "Jaccard = " + str(v_jaccard[j]) +
                        '\nSensitivity = ' + str(v_sensitivity[j]) +
                        '\nSpecificity = ' + str(v_specificity[j]) +
                        '\nAccuracy = ' + str(v_accuracy[j]) +
                        '\nDice = ' + str(v_dice[j]))

            plt.subplot(1, 4, 4)
            plt.imshow(crf_mask)
            plt.title("CRF Mask\n" +
                        "Jaccard = " + str(crf_jaccard[j]) +
                        '\nSensitivity = ' + str(crf_sensitivity[j]) +
                        '\nSpecificity = ' + str(crf_specificity[j]) +
                        '\nAccuracy = ' + str(crf_accuracy[j]) +
                        '\nDice = ' + str(crf_dice[j]))

            fig.savefig(out_dir + str(j) + '.png', bbox_inches='tight')
            # Close only the figure created here. The raw prediction PNG was
            # already written above, so it is not saved a second time.
            plt.close(fig)

    mjccard /= num_test_examples
    score /= num_test_examples

    raw_metrics = (
        ('jaccard', v_jaccard),
        ('sensitivity', v_sensitivity),
        ('specificity', v_specificity),
        ('accuracy', v_accuracy),
        ('dice', v_dice),
    )
    for file_name, values in raw_metrics:
        np.savetxt(out_dir + file_name, values)
    with open(out_dir + 'score', 'w') as f:
        # The 'Jaccars' label is kept verbatim: it is part of the file format.
        f.write('Score = ' + str(score) +
        '\nSensitivity = ' + str(np.mean(v_sensitivity)) +
        '\nSpecificity = ' + str(np.mean(v_specificity)) +
        '\nAccuracy = ' + str(np.mean(v_accuracy)) +
        '\nDice = ' + str(np.mean(v_dice)) +
        '\nJaccars = ' + str(np.mean(v_jaccard)))

    crf_metrics = (
        ('crf_jaccard', crf_jaccard),
        ('crf_sensitivity', crf_sensitivity),
        # NOTE(review): the doubled prefix looks accidental, but the name is
        # kept so existing consumers of this file keep working.
        ('crf_crf_specificity', crf_specificity),
        ('crf_accuracy', crf_accuracy),
        ('crf_dice', crf_dice),
    )
    for file_name, values in crf_metrics:
        np.savetxt(out_dir + file_name, values)
    with open(out_dir + 'crf_score', 'w') as f:
        f.write('Sensitivity = ' + str(np.mean(crf_sensitivity)) +
        '\nSpecificity = ' + str(np.mean(crf_specificity)) +
        '\nAccuracy = ' + str(np.mean(crf_accuracy)) +
        '\nDice = ' + str(np.mean(crf_dice)) +
        '\nJaccars = ' + str(np.mean(crf_jaccard)))

    print('Jccard = ' + str(mjccard))
    print('Score = ' + str(score))
    return mjccard, score