-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
107 lines (93 loc) · 3.98 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import numpy as np
import vtk
from vtk.util.numpy_support import vtk_to_numpy, numpy_to_vtk
import SimpleITK as sitk
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "vtk_utils"))
from vtk_utils import *
import csv
def extract_surface(poly):
    """Run a connectivity analysis over a surface mesh and return the result.

    The mesh is passed through ``vtkPolyDataConnectivityFilter`` configured to
    keep all regions, with region colouring enabled.

    Args:
        poly: input ``vtkPolyData`` surface.

    Returns:
        The poly data produced by the connectivity filter.
    """
    region_filter = vtk.vtkPolyDataConnectivityFilter()
    region_filter.SetInputData(poly)
    # Keep every connected region; colouring tags them so they can be told
    # apart downstream (evaluate_poly_distances thresholds on 'RegionId',
    # presumably the array this option produces -- confirm against VTK docs).
    region_filter.SetExtractionModeToAllRegions()
    region_filter.ColorRegionsOn()
    region_filter.Update()
    return region_filter.GetOutput()
def surface_distance(p_surf, g_surf):
    """Compute unsigned distances between two surfaces with VTK.

    ``g_surf`` is wired to input port 0 and ``p_surf`` to input port 1 of a
    ``vtkDistancePolyDataFilter`` with signed distances switched off.
    NOTE(review): the filter presumably evaluates the distance at the points
    of port 0 (``g_surf``) towards port 1 (``p_surf``) -- confirm in VTK docs.

    Args:
        p_surf: ``vtkPolyData`` connected to input port 1.
        g_surf: ``vtkPolyData`` connected to input port 0.

    Returns:
        Tuple ``(distance, output)``: the ``'Distance'`` point-data array of
        the filter output converted to numpy, and the filter output itself.
    """
    distance_filter = vtk.vtkDistancePolyDataFilter()
    distance_filter.SetInputData(0, g_surf)
    distance_filter.SetInputData(1, p_surf)
    distance_filter.SignedDistanceOff()
    distance_filter.Update()

    annotated = distance_filter.GetOutput()
    point_distances = vtk_to_numpy(annotated.GetPointData().GetArray('Distance'))
    return point_distances, annotated
def evaluate_poly_distances(poly, gt, NUM):
    """Compute ASSD and Hausdorff distances between predicted and reference meshes.

    Metrics are computed for each of the ``NUM`` structures and for the whole
    mesh. Structure ``k`` (1-based) is selected by thresholding the
    ``'Scalars_'`` cell array at ``k``; when that yields no points in the
    prediction, the ``'RegionId'`` point array is thresholded at ``k - 1``
    instead.

    Args:
        poly: predicted surface (``vtkPolyData``).
        gt: reference surface (``vtkPolyData``).
        NUM: number of structures to evaluate individually.

    Returns:
        Tuple ``(assd_list, haus_list, poly_dist)``. Both lists have
        ``NUM + 1`` entries with the whole-mesh value at index 0 followed by
        the per-structure values. ``poly_dist`` is the result of
        ``appendPolyData`` over the per-structure prediction-to-reference
        distance outputs.
    """

    def _symmetric_stats(pred_surf, gt_surf):
        # Prediction -> reference first, then reference -> prediction.
        to_gt, to_gt_poly = surface_distance(gt_surf, pred_surf)
        to_pred, _ = surface_distance(pred_surf, gt_surf)
        mean_dist = (np.mean(to_gt) + np.mean(to_pred)) / 2
        worst_dist = max(np.max(to_gt), np.max(to_pred))
        return mean_dist, worst_dist, to_gt_poly

    poly = extract_surface(poly)

    class_assd, class_haus, class_meshes = [], [], []
    for label in range(1, NUM + 1):
        pred_part = thresholdPolyData(poly, 'Scalars_', (label, label), 'cell')
        if pred_part.GetNumberOfPoints() == 0:
            # No cell labels on the prediction: fall back to connectivity ids.
            print("Mesh based methods.")
            pred_part = thresholdPolyData(poly, 'RegionId', (label - 1, label - 1), 'point')
        gt_part = thresholdPolyData(gt, 'Scalars_', (label, label), 'cell')
        print("DEBUG: ", pred_part.GetNumberOfPoints(), gt_part.GetNumberOfPoints())

        part_assd, part_haus, part_mesh = _symmetric_stats(pred_part, gt_part)
        class_assd.append(part_assd)
        class_haus.append(part_haus)
        class_meshes.append(part_mesh)
    poly_dist = appendPolyData(class_meshes)

    # Whole-mesh metrics go in front of the per-structure ones.
    whole_assd, whole_haus, _ = _symmetric_stats(poly, gt)
    assd_list = [whole_assd] + class_assd
    haus_list = [whole_haus] + class_haus
    print(assd_list)
    print(haus_list)
    return assd_list, haus_list, poly_dist
def dice_score(pred, true):
    """Compute Dice coefficients of a predicted label map against a reference.

    Args:
        pred: array-like of predicted labels; cast to integers.
        true: array-like of reference labels with the same shape as ``pred``;
            cast to integers.

    Returns:
        A list with one entry per unique value in ``true`` (sorted ascending).
        Index 0 holds the overall score: among elements that are non-zero in
        either input, twice the number whose labels agree, divided by the
        total count of non-zero elements in both inputs. Index ``i >= 1``
        holds the Dice score of the i-th unique reference label. The smallest
        unique label is never scored on its own (assumed to be background).

    Note:
        When neither input has a non-zero element the overall score is a 0/0
        division and evaluates to NaN.
    """
    # The builtin ``int`` replaces the ``np.int`` alias removed in NumPy 1.24.
    pred = np.asarray(pred).astype(int)
    true = np.asarray(true).astype(int)
    classes = np.unique(true)

    dice_out = [None] * len(classes)
    for idx, label in enumerate(classes[1:], start=1):
        pred_c = pred == label
        true_c = true == label
        overlap = np.sum(pred_c & true_c)
        dice_out[idx] = overlap * 2.0 / (np.sum(pred_c) + np.sum(true_c))

    # Overall score is restricted to the union of both foregrounds.
    foreground = (pred > 0) | (true > 0)
    agreeing = np.sum((pred == true)[foreground])
    dice_out[0] = agreeing * 2. / (np.sum(pred > 0) + np.sum(true > 0))
    return dice_out
def jaccard_score(pred, true):
    """Compute Jaccard indices (IoU) of a predicted label map against a reference.

    Args:
        pred: array-like of predicted labels; cast to integers.
        true: array-like of reference labels with the same shape as ``pred``;
            cast to integers.

    Returns:
        A list with one entry per unique value in ``true`` (sorted ascending).
        Index 0 holds the overall score: among elements that are non-zero in
        either input, the number whose labels agree, divided by the total
        count of non-zero elements in both inputs minus that number. Index
        ``i >= 1`` holds the Jaccard index of the i-th unique reference label.
        The smallest unique label is never scored on its own (assumed to be
        background).

    Note:
        When neither input has a non-zero element the overall score is a 0/0
        division and evaluates to NaN.
    """
    # The builtin ``int`` replaces the ``np.int`` alias removed in NumPy 1.24.
    pred = np.asarray(pred).astype(int)
    true = np.asarray(true).astype(int)
    classes = np.unique(true)

    jac_out = [None] * len(classes)
    for idx, label in enumerate(classes[1:], start=1):
        pred_c = pred == label
        true_c = true == label
        overlap = np.sum(pred_c & true_c)
        # |A or B| = |A| + |B| - |A and B|
        jac_out[idx] = overlap / (np.sum(pred_c) + np.sum(true_c) - overlap)

    # Overall score is restricted to the union of both foregrounds.
    foreground = (pred > 0) | (true > 0)
    agreeing = np.sum((pred == true)[foreground])
    jac_out[0] = agreeing / (np.sum(pred > 0) + np.sum(true > 0) - agreeing)
    return jac_out
def evaluate_segmentation_accuracy(pred_vtk, gt_vtk):
    """Compute Dice and Jaccard scores between two labelled VTK data sets.

    The point-data scalars of each input are read as numpy arrays and their
    label values are remapped to consecutive indices ``0..k-1`` in ascending
    order of the original value. The remapped arrays are then scored with
    ``dice_score`` and ``jaccard_score``.

    The two inputs are remapped independently of each other, so the i-th
    smallest label in the prediction is compared against the i-th smallest
    label in the reference. NOTE(review): this lines up only when both inputs
    contain the same set of structures -- verify against callers.

    Args:
        pred_vtk: VTK data set whose point-data scalars hold predicted labels.
        gt_vtk: VTK data set whose point-data scalars hold reference labels.

    Returns:
        Tuple ``(dice_list, jac_list)`` as returned by ``dice_score`` and
        ``jaccard_score``.
    """
    pred_raw = vtk_to_numpy(pred_vtk.GetPointData().GetScalars())
    gt_raw = vtk_to_numpy(gt_vtk.GetPointData().GetScalars())

    # Look up each value's rank among the sorted unique values. This builds a
    # new array instead of overwriting the one returned by vtk_to_numpy
    # (which may share memory with the VTK object), and it cannot merge
    # labels the way a sequential in-place rewrite does when a freshly
    # assigned index equals an original value that is still to be processed.
    pred_py = np.searchsorted(np.unique(pred_raw), pred_raw)
    gt_py = np.searchsorted(np.unique(gt_raw), gt_raw)

    print("Pred, gt seg ids: ", np.unique(pred_py), np.unique(gt_py))
    dice_list = dice_score(pred_py, gt_py)
    jac_list = jaccard_score(pred_py, gt_py)
    print("dice: ", dice_list)
    print("jaccard: ", jac_list)
    return dice_list, jac_list