-
Notifications
You must be signed in to change notification settings - Fork 2
/
predict_all_frames.py
executable file
·97 lines (77 loc) · 3.21 KB
/
predict_all_frames.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/python
import sys
import irtk
import heartdetector
import numpy as np
from glob import glob
import os
import argparse
from time import time
def get_ED_ES(patient_id, root="Images"):
    """Return the end-diastole / end-systole frame indices of a patient.

    Reads ``<root>/<patient_id>/<patient_id>_ED_ES_time.txt``. For every
    non-blank line, the last whitespace-separated token is parsed as an
    integer and decremented by 1. The frame numbers in the file are
    presumably 1-based, so this yields 0-based indices into the frame list.
    Values are returned in file order; the (disabled) caller unpacks them
    as ``ED, ES = get_ED_ES(...)``.

    :param patient_id: patient identifier used for both directory and file name.
    :param root: directory holding one sub-directory per patient
        (default ``"Images"``, relative to the working directory).
    :returns: list of ints, one per non-blank line.
    :raises IOError/OSError: if the file cannot be opened.
    :raises ValueError: if a line's last token is not an integer.
    """
    path = os.path.join(root, patient_id, patient_id + "_ED_ES_time.txt")
    res = []
    # Text mode, so split() works on str under both Python 2 and 3.
    # The with-block guarantees the handle is closed.
    with open(path) as fh:
        for line in fh:
            fields = line.split()  # any whitespace run, trailing newline included
            if not fields:
                continue  # tolerate blank / trailing empty lines
            res.append(int(fields[-1]) - 1)
    return res
# Command-line interface of the script. patient_id and --forest are
# required. --frame and --all are accepted but are not read anywhere in this
# script at present: the frame-selection logic that used them is disabled,
# so every frame is processed.
parser = argparse.ArgumentParser(
    description="""Left ventricle segmentation using Autocontext Random Forests.
Code for the MICCAI 2014 Segmentation challenge.""" )
parser.add_argument( 'patient_id', type=str,
                     help="patient identifier; frames are read from "
                          "denoised/<patient_id>_frame*.nii.gz" )
parser.add_argument( '--forest', type=str, required=True,
                     help="name of the trained HeartDetector forest to load" )
parser.add_argument( '--frame', type=int, default=None,
                     help="currently unused (frame selection is disabled; "
                          "all frames are processed)" )
parser.add_argument( '--all', action="store_true", default=False,
                     help="currently unused (all frames are always processed)" )
parser.add_argument( '--time', action="store_true", default=False,
                     help="timing mode: suppress progress output and request "
                          "only the final prediction (return_all=False)" )
parser.add_argument( '--nb_autocontext', type=int, default=4,
                     help="number of autocontext iterations passed to "
                          "heartdetector.predict (default: %(default)s)" )
parser.add_argument( '--debug', action="store_true", default=False,
                     help="forwarded to heartdetector.predict(debug=...)" )
args = parser.parse_args()
# Echo the parsed options unless running in timing mode.
# print(x) with a single argument behaves identically on Python 2 and 3.
if not args.time:
    print(args)

# Wall-clock timer for the whole run (model loading + all predictions).
start = time()

detector = heartdetector.HeartDetector( name=args.forest )
detector.load()
if not args.time:
    print(detector)

# Output directory for this patient's predictions.
output_dir = os.path.join("predictions", args.patient_id)
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# All denoised frames of the patient. "_seg" files can match the same
# pattern; they are skipped wherever the frames are used.
all_frames = sorted(glob("denoised/"+args.patient_id+"_frame*.nii.gz"))
if not all_frames:
    # Fail with a clear message instead of an IndexError on all_frames[0].
    sys.exit("No frames found matching denoised/" + args.patient_id
             + "_frame*.nii.gz")

# Build a mask of the voxels whose summed intensity over all (non-"_seg")
# frames is positive. Assuming non-negative intensities, this is the union
# of the non-zero voxels of every frame. The first file only supplies the
# image header (geometry) for the accumulator.
tmp = irtk.imread(all_frames[0])
mask = irtk.zeros(tmp.get_header(),dtype='float32')
for f in all_frames:
    if "_seg" in f:
        continue
    mask += irtk.imread(f)
mask = (mask > 0).astype('uint8')
# NOTE(review): frame selection is disabled, so --frame, --all and
# get_ED_ES() currently have no effect and every frame is processed.
# The disabled logic was:
#   if args.frame is not None:
#       all_frames = [all_frames[args.frame]]
#   elif not args.all:
#       ED, ES = get_ED_ES(args.patient_id)
#       all_frames = [all_frames[ED], all_frames[ES]]
# all_frames may also contain "_seg" files, so these indices would not map
# to frame numbers if segmentations are present -- fix that before
# re-enabling it.


def _save_prediction(out_dir, frame_path, iteration, proba):
    """Write one autocontext iteration's output for a frame.

    Two files are written to ``out_dir``:
    ``iter<iteration>_<frame name>`` holds the probability map, and
    ``iter<iteration>_<frame stem>_hard.nii.gz`` holds the hard segmentation
    produced by ``detector.groups[0].hard_thresholding``.
    """
    name = os.path.basename(frame_path)
    stem = name[:-len('.nii.gz')]  # the glob guarantees the .nii.gz suffix
    prefix = os.path.join(out_dir, "iter" + str(iteration) + "_")
    irtk.imwrite(prefix + name, proba)
    # proba[1] is presumably the foreground-class channel and 0.0005 the
    # target resampling spacing -- TODO confirm against heartdetector/irtk.
    hard = detector.groups[0].hard_thresholding( proba[1].resample(0.0005),
                                                 smoothing=4.0*0.001 )
    irtk.imwrite(prefix + stem + "_hard.nii.gz", hard)


output_dir = os.path.join("predictions", args.patient_id)
for f in all_frames:
    if "_seg" in f:
        continue  # segmentation files share the glob pattern; not inputs
    if not args.time:
        print(f)
    all_proba = heartdetector.predict( detector,
                                       f,
                                       ga=0.0,
                                       nb_autocontext=args.nb_autocontext,
                                       mask=mask,
                                       debug=args.debug,
                                       return_all=not args.time )
    # predict() returns a single Image (labelled as the last iteration) or a
    # sequence with one probability map per autocontext iteration (1-based).
    if isinstance(all_proba, irtk.Image):
        results = [(args.nb_autocontext, all_proba)]
    else:
        results = enumerate(all_proba, start=1)
    for iteration, proba in results:
        _save_prediction(output_dir, f, iteration, proba)

# Total elapsed wall-clock time in seconds (always printed).
stop = time()
print(stop - start)