def plot_3d_no_gt(im_st, pred_st, res, thr = 0.5, plot_size_xy = 256):
    """Render a thresholded prediction in 3D together with a lung outline.

    Used when no ground truth is available: opens one mayavi figure with the
    binarised prediction and a translucent lung outline.

    Args:
        im_st: CT stack indexed ``im_st[slice, row, col]``; used to compute
            the lung outline via ``lungmask3D``.
        pred_st: prediction stack; presumably the same layout as ``im_st``
            (TODO confirm). It is NOT modified: a binarised copy (values
            ``>= thr`` -> 180, others -> 0) is plotted.
        res: spacing pair; the slice axis is stretched by ``res[0] / res[1]``
            relative to the in-plane axes (``predict_gen`` yields
            ``[z_res, xy_res]``).
        thr: threshold applied to ``pred_st``.
        plot_size_xy: in-plane size both volumes are resampled to before plotting.
    """
    # same in-plane scale for rows and cols; the slice axis additionally gets res[0]/res[1]
    zoom_fac = [res[0]/res[1]/(im_st.shape[1]/plot_size_xy), plot_size_xy/im_st.shape[1], plot_size_xy/im_st.shape[1]]
    im_st = zoom(im_st.copy(), zoom = zoom_fac, order = 1)

    # lung outline: 3D lung mask reduced to Canny edges on every (slice, col) plane
    mask = np.zeros((im_st.shape[0], im_st.shape[1], im_st.shape[2]))
    mask[:, :, :] = lungmask3D(im_st[:, :, :], morph = False)
    for i in range(im_st.shape[1]):
        mask[:, i, :] = cv2.Canny(np.uint8(maxminscale(mask[:, i, :])), 50, 100)

    # binarise a copy (np.array always copies) so the caller's pred_st is left untouched
    pred = np.array(pred_st)
    pred[pred < thr] = 0
    pred[pred >= thr] = 180
    # nearest-neighbour resampling keeps the prediction binary
    pred = zoom(pred, zoom = zoom_fac, order = 0)

    mlab.figure(1, size = (1000,800))
    mlab.contour3d(pred, colormap = 'hot', opacity = 1.0, vmin = 0, vmax = 255)
    mlab.contour3d(mask, colormap = 'Greys', opacity = 0.1, vmin = 0, vmax = 150)
    mlab.title('Prediction')

    mlab.show()
def plot_lungmask(im_st):
    """Show the outline of the 3D lung mask of a CT stack in a mayavi window.

    The mask is computed by ``lungmask3D`` on channel 0 of ``im_st`` (indexed
    ``im_st[slice, row, col, channel]``). It is then reduced to its Canny edges
    on every (slice, col) plane and drawn as a translucent blue contour.
    """
    depth, rows, cols = im_st.shape[:3]
    mask = np.zeros((depth, rows, cols, 2))
    print('get mask')

    # view into channel 1: every write below lands in `mask`
    outline = mask[..., 1]
    outline[...] = lungmask3D(im_st[:, :, :, 0], morph = False)
    for row in range(rows):
        edges = cv2.Canny(np.uint8(maxminscale(outline[:, row, :])), 50, 100)
        outline[:, row, :] = edges

    print('plotting')

    mlab.contour3d(outline, colormap = 'Blues', opacity = 0.1)
    mlab.show()
def plot_lung_and_tumor(im_st, pred_st, gt_st, res = None, thr = 0.5):
    """Show prediction and ground truth as 3D contours, each with a lung outline.

    Two mayavi figures are opened: figure 1 shows the thresholded prediction,
    figure 2 the ground truth. Each is drawn together with a translucent lung outline.

    Args:
        im_st: CT stack indexed ``im_st[slice, row, col, channel]``; channel 0
            is passed to ``lungmask3D``.
        pred_st: prediction stack; only channel 1 (``pred_st[..., 1]``) is
            plotted. Values ``>= thr`` become 180, the rest 0. The caller's
            array is not modified.
        gt_st: ground-truth stack; only channel 1 is plotted. Values ``>= thr``
            become 180; lower values are left as they are. The caller's array
            is not modified.
        res: optional spacing pair. When given, the slice axis is stretched by
            ``res[0] / res[1] / 2``; otherwise by ``shape[1] / shape[0]``.
            ``None`` (default) and an empty sequence behave the same.
        thr: threshold applied to prediction and ground truth.
    """
    n_slices, n_rows, n_cols = im_st.shape[0], im_st.shape[1], im_st.shape[2]

    # lung outline: 3D lung mask reduced to Canny edges on every (slice, col) plane
    mask = np.zeros((n_slices, n_rows, n_cols))
    mask[:, :, :] = lungmask3D(im_st[:, :, :, 0], morph = False)
    for i in range(n_rows):
        mask[:, i, :] = cv2.Canny(np.uint8(maxminscale(mask[:, i, :])), 50, 100)

    # threshold COPIES of the foreground channel so the caller's arrays stay untouched
    pred = np.array(pred_st[:, :, :, 1])
    pred[pred < thr] = 0
    pred[pred >= thr] = 180
    gt = np.array(gt_st[:, :, :, 1])
    gt[gt >= thr] = 180

    # TODO - make zoom_fac dependent on resolution in z dir (true values)
    # explicit None/length check also works when res is a NumPy array
    if res is None or len(res) == 0:
        zoom_fac = im_st.shape[1] / im_st.shape[0]
    else:
        zoom_fac = res[0] / res[1] / 2

    mask = zoom(mask, zoom = [zoom_fac, 1, 1])
    pred = zoom(pred, zoom = [zoom_fac, 1, 1])
    gt = zoom(gt, zoom = [zoom_fac, 1, 1])

    mlab.figure(1, size = (1000,800))
    mlab.contour3d(pred, colormap = 'hot', opacity = 1.0, vmin = 0, vmax = 255)
    mlab.contour3d(mask, colormap = 'Greys', opacity = 0.1, vmin = 0, vmax = 150)
    mlab.title('Prediction')

    mlab.figure(2, size = (1000,800))
    mlab.contour3d(gt, colormap = 'hot', opacity = 1.0, vmin = 0, vmax = 255)
    mlab.contour3d(mask, colormap = 'Greys', opacity = 0.1, vmin = 0, vmax = 150)
    mlab.title('Ground truth')

    mlab.show()
Ejemplo n.º 4
0
## fill regions surrounded by lung regions -> final ROI
# to fill holes -> without filling unwanted larger region in the middle
# NOTE(review): this fragment relies on `mask` (presumably a binary 2D lung mask) and
# `img_orig` being defined by earlier code that is not shown here.
res = binary_fill_holes(mask).astype(int)

# need to filter out the main airways, because otherwise they will affect the resulting algorithm
# (labelled components smaller than 1100 pixels are dropped, the rest are re-binarised to 1)
res2 = remove_small_objects(label(res), min_size=1100)
res2[res2 > 0] = 1
#res2 = res.copy()

# first close boundaries to include juxta-vascular nodule candidates
mask = morphology.dilation(res2, disk(7))  # 7 (17 for more difficult ones)

# then fill in nodules to include them, but not the much larger objects, e.g. heart etc.
# The inverted mask is labelled and only its components of at least 300 pixels are kept.
# The two lines after this statement map kept components -> 0 and everything else -> 1,
# so holes smaller than 300 pixels end up filled.
# NOTE(review): assumes maxminscale returns a 0..255 array cv2.bitwise_not can invert — confirm.
mask = remove_small_objects(
    label(cv2.bitwise_not(maxminscale(mask))), min_size=300
)  # can change min_size larger if want to include larger nodule candidates
mask[mask != 0] = -1
mask = np.add(mask, np.ones(mask.shape))

# last erosion to complete the closing+, larger disk size because need to remove some of the lung boundaries
filled_tmp = morphology.erosion(mask,
                                disk(9))  # 9 (19 for more difficult ones)

imgclean(filled_tmp, Figure=True)

#image = filled_tmp*img_orig
# pixels outside the ROI are set to the minimum intensity of the original image
image = img_orig.copy()
image[filled_tmp == 0] = np.amin(img_orig)

#image = filled_tmp*img_orig
def _parse_lidc_xml(xml_path):
    """Parse reader annotations from an annotation XML file.

    One dict per ``readingSession`` is appended to each returned list. Every
    dict maps an ``imageZposition`` (as float) to a list of ``[x, y]`` integer
    pixel coordinates.

    Returns:
        tuple ``(nod_list, nod_small_list, non_nod_list)``:
        ``nod_list`` holds the ``edgeMap`` points of ``unblindedReadNodule``
        ROIs with more than 5 child elements. ``nod_small_list`` holds those
        of ROIs with at most 5 child elements (treated as single-pixel
        nodules). ``non_nod_list`` holds the ``locus`` points of
        ``nonNodule`` entries.
    """
    ns = '{http://www.nih.gov}'
    root = et.parse(xml_path).getroot()

    nod_list, nod_small_list, non_nod_list = [], [], []
    for session in root.findall(ns + 'readingSession'):
        large, small, non_nod = {}, {}, {}

        # real nodules (tumors)
        for nodule in session.findall(ns + 'unblindedReadNodule'):
            for roi in nodule.findall(ns + 'roi'):
                # ROIs with <= 5 children are treated as single-pixel nodules
                target = small if len(roi) <= 5 else large
                z_value = roi.find(ns + 'imageZposition')
                for edge_map in roi.findall(ns + 'edgeMap'):
                    x_value = edge_map.find(ns + 'xCoord')
                    y_value = edge_map.find(ns + 'yCoord')
                    target.setdefault(float(z_value.text), []).append(
                        [int(x_value.text), int(y_value.text)])

        # non-nodules
        for non_nodule in session.findall(ns + 'nonNodule'):
            z_value = non_nodule.find(ns + 'imageZposition')
            for locus in non_nodule.findall(ns + 'locus'):
                x_value = locus.find(ns + 'xCoord')
                y_value = locus.find(ns + 'yCoord')
                non_nod.setdefault(float(z_value.text), []).append(
                    [int(x_value.text), int(y_value.text)])

        nod_list.append(large)
        nod_small_list.append(small)
        non_nod_list.append(non_nod)

    return nod_list, nod_small_list, non_nod_list


def predict_gen(model_path,
                start_path,
                pred_only_gt=False,
                maxmin_input=False,
                one_zero_scale=False,
                only=None,
                exclude=None,
                predict=True):
    """Yield CT stacks, reader-consensus ground truth and (optionally) model predictions.

    Expected layout: ``start_path/<dir1>/<dir2>/<dir3>/<dir4>/<files>``. The
    patient number is the third ``-``-separated token of ``<dir2>`` (the
    layout looks like the LIDC-IDRI download, e.g. ``LIDC-IDRI-0001`` —
    verify). Only ``<dir4>`` directories with more than 10 files are used;
    their ``*dcm`` files are the slices and their ``.xml`` file holds the
    annotations.

    Args:
        model_path: Keras model file loaded with ``load_model``. Its first
            layer's ``batch_input_shape`` decides 2D (rank 4) vs 3D (rank 5).
        start_path: root directory of the data.
        pred_only_gt: only predict slices/chunks that contain ground truth;
            the others stay zero.
        maxmin_input: 3D only, apply ``maxminscale`` to each slice before resizing.
        one_zero_scale: unused; kept for backward compatibility.
        only: if non-empty, only these patient numbers are processed.
        exclude: patient numbers to skip.
        predict: if False, no prediction is run and no resizing is done.

    Yields:
        ``(input_im, out, pred_output, pat, res)`` if ``predict`` else
        ``(input_im, out, pat, res)``:

        - ``input_im``: slices sorted by z position, shape ``(n, 512, 512, 1)``.
          In the 3D path it is resized to the model's in-plane size.
        - ``out``: one-hot ground truth ``(..., 2)``. Channel 1 is set where
          the summed reader masks reach 300, i.e. at least two readers
          (255 each) marked the pixel.
        - ``pred_output``: model output. In the 3D path it is zero-padded to a
          multiple of the model depth.
        - ``pat``: patient number.
        - ``res``: ``[z_res, xy_res]`` spacing.

    Patients without an annotation XML or with fewer than 2 slices are
    skipped with a message.

    Raises:
        ValueError: if ``predict`` is True and the model input is neither
            rank 4 nor rank 5.
    """
    only = [] if only is None else only
    exclude = [] if exclude is None else exclude

    model = load_model(model_path, compile=False)
    model_config = model.get_config()['layers'][0]['config']['batch_input_shape']
    model_dim = len(model_config)
    print('model:', model_path.split('/')[-1], ', model dimensions:', model_dim - 2)
    if predict and model_dim not in (4, 5):
        raise ValueError('expected a model with rank-4 (2D) or rank-5 (3D) input, '
                         'got rank %d' % model_dim)

    s = '/'
    dcm_files = {}
    xml_files = {}
    print()

    for dir1 in os.listdir(start_path):
        if dir1.startswith('.DS'):
            continue
        path1 = start_path + s + dir1

        for dir2 in os.listdir(path1):
            if dir2.startswith('.DS'):
                continue
            pat_nr = int(dir2.split('-')[2])
            path2 = path1 + s + dir2

            if len(only) > 20:
                # move the cursor up so the progress line overwrites itself
                sys.stdout.write("\033[F")
                print('reading through files, at patient_nr:', pat_nr)

            if pat_nr in exclude:
                continue
            if only and pat_nr not in only:
                continue

            # reset per patient: a patient without an XML must never reuse the previous one
            dcm_file_path = []
            xml_file_path = None
            for dir3 in os.listdir(path2):
                if dir3.startswith('.DS'):
                    continue
                path3 = path2 + s + dir3
                for dir4 in os.listdir(path3):
                    if dir4.startswith('.DS'):
                        continue
                    path4 = path3 + s + dir4
                    files = os.listdir(path4)  # list once instead of once per file
                    # only series directories with more than 10 files are used
                    if len(files) <= 10:
                        continue
                    for file in files:
                        if file.endswith('dcm'):
                            dcm_file_path.append(path4 + s + file)
                        if file.endswith('.xml'):
                            xml_file_path = path4 + s + file

            dcm_files[pat_nr] = dcm_file_path
            xml_files[pat_nr] = xml_file_path

    for pat, dcm_paths in dcm_files.items():
        xml_path = xml_files[pat]
        if xml_path is None or len(dcm_paths) < 2:
            print('skipping patient', pat,
                  ': needs an annotation XML and at least 2 DICOM slices')
            continue

        # [z position, path] pairs sorted by z position
        slices = []
        for dcm_path in dcm_paths:
            origin = np.array(list(reversed(sitk.ReadImage(dcm_path).GetOrigin())))
            slices.append([origin[0], dcm_path])
        slices.sort()

        # slice spacing estimated from the two lowest slices
        z_res = np.abs(slices[0][0] - slices[1][0])

        nod_list = _parse_lidc_xml(xml_path)[0]

        n_slices = len(dcm_paths)
        # one mask per reading session; at least 4 so the 300 vote matches the original setup
        n_readers = max(4, len(nod_list))
        reader_masks = np.zeros((n_readers, n_slices, 512, 512), dtype=np.uint8)
        input_im = np.zeros((n_slices, 512, 512, 1))

        for cnt, (z_pos, dcm_path) in enumerate(slices):
            itkimage = sitk.ReadImage(dcm_path)
            input_im[cnt, :, :, 0] = sitk.GetArrayFromImage(itkimage)[0, :, :]
            spacing = np.array(list(reversed(itkimage.GetSpacing())))
            xy_res = spacing[1]

            # large nodules: draw each reader's contour on this slice, then fill it
            for reader, nodules in enumerate(nod_list):
                if z_pos not in nodules:
                    continue
                for x, y in nodules[z_pos]:
                    reader_masks[reader, cnt, y, x] = 255
                reader_masks[reader, cnt] = object_filler(reader_masks[reader, cnt], (0, 0))

        # consensus: summed masks >= 300 means at least two readers (255 each) marked the pixel
        votes = reader_masks.astype(int).sum(axis=0)
        out = np.zeros((n_slices, 512, 512, 2))
        out[:, :, :, 1] = votes >= 300
        out[:, :, :, 0] = 1 - out[:, :, :, 1]  # one-hot background channel

        res = [z_res, xy_res]

        if not predict:
            yield input_im, out, pat, res
            continue

        if model_dim == 4:
            # ----- predicting 2d Unet: one slice at a time -----
            im = np.zeros((1,) + input_im.shape[1:])
            pred_output = np.zeros(out.shape)
            for i in tqdm(range(input_im.shape[0])):
                if pred_only_gt and np.count_nonzero(out[i, :, :, 1]) == 0:
                    continue
                im[0] = input_im[i]
                pred_output[i] = model.predict(im)
        else:
            # ----- predicting 3d Unet: chunks of `depth` consecutive slices -----
            depth, height, width, channels = model_config[1:5]
            inp = np.zeros((input_im.shape[0], height, width, 1))
            output = np.zeros((input_im.shape[0], height, width, 2))
            for i in range(input_im.shape[0]):
                src = input_im[i, :, :, 0]
                if maxmin_input:
                    src = maxminscale(src)
                # cv2.resize takes dsize as (width, height)
                inp[i, :, :, 0] = cv2.resize(src, (width, height))
                for c in range(2):
                    output[i, :, :, c] = cv2.resize(out[i, :, :, c], (width, height),
                                                    interpolation=cv2.INTER_NEAREST)
            input_im = inp
            out = output

            chunks = int(np.ceil(input_im.shape[0] / depth))
            # zero-padded chunks: slice k goes to chunk k // depth, position k % depth
            im = np.zeros((chunks, depth, height, width, channels))
            for k in range(input_im.shape[0]):
                im[k // depth, k % depth] = input_im[k]

            pred_output_5d = np.zeros((chunks, depth, height, width, 2))
            for i in tqdm(range(chunks)):
                if pred_only_gt and np.count_nonzero(
                        out[depth * i:depth * (i + 1), :, :, 1]) == 0:
                    continue
                pred_output_5d[i] = model.predict(im[i:i + 1])

            # flatten the chunks back into one (padded) slice stack
            pred_output = pred_output_5d.reshape((chunks * depth, height, width, 2))

        yield input_im, out, pred_output, pat, res
Ejemplo n.º 6
0
        gt_vals.append(i)
t.toc()

print(gt_vals)

# get some image
num = int(input("select slice containing tumor: "))
img = data[num, :, :]
# NOTE(review): if `data` is a NumPy array, `img` is a view, so the clipping below also
# modifies `data[num]` in place — add .copy() here if that is not intended.

#num = 224 # slice number
#img = data[num,:,:] # 84 is a problem! lung segmentation not robust enough

# clip intensities to [-1024, 1024] (presumably Hounsfield units — confirm)
img[img <=
    -1024] = -1024  # better visual contrast, BUT AFFECTS PERFORMANCE OF FCM!!!
img[img >= 1024] = 1024
img = maxminscale(img)  # OBS

#imgclean(img, Figure=True)

# keep original maxminscaled image -> because img is going to be altered
img_orig = img.copy()

# blur img (8-bit conversion, then 5x5 median filter)
img = np.uint8(img)
img = cv2.medianBlur(img, 5)

# get dimensions of image
row_size, col_size = img.shape

# specify window for k-means to work on, such that you get the lung, and not the rest
middle = img[int(col_size / 5):int(col_size / 5 * 4),
Ejemplo n.º 7
0
# get ground truth data and display at which slice there is a nodule as a list
# NOTE(review): `output` and `data` come from earlier code not shown here. `.value` looks
# like the h5py Dataset accessor, which was removed in h5py 3.0 (use `output[()]`) — verify.
gt_data = output.value[:,:,:,0]
gt_vals = []
# a slice is listed when channel 0 holds more than one distinct value on it
for i in range(gt_data.shape[0]):
	if len(np.unique(gt_data[i,:,:])) > 1:
		gt_vals.append(i)
print(gt_vals)


# choose slice
# NOTE(review): the prompt says "patient", but the number indexes a slice of `data`.
num = input("choose patient: ")
img = data[int(num),:,:]


img = np.asarray(img)
# 8-bit, median-blurred (kernel 11) copy of the selected slice
org_img = maxminscale(img)
org_img = np.uint8(org_img)
org_img = cv2.medianBlur(org_img, 11)


## 2) get 8-bit image
#img = maxminscale(img)
#img = img.astype(np.uint8)
img_orig = 1*img  # multiplication creates a new array, so this is a copy
img_orig = maxminscale(img_orig)

row_size= img.shape[0]
col_size = img.shape[1]

mean = np.mean(img)
std = np.std(img)
def test2d3d(im_st, pred_st, res, thr = 0.1, plot_size_xy = 150, VGG_model_path = ''):
	"""Interactive 2D/3D viewer and editor for a predicted segmentation.

	Opens two windows:
	- a matplotlib window showing one slice of ``im_st`` with the thresholded
	  prediction overlaid;
	- a mayavi window with a 3D contour of the prediction, resampled by
	  ``zoom_fac``.

	Widgets:
	- 'threshold' slider: re-thresholds the 2D overlay and the 3D surface.
	- 'slice' slider (also up/down arrow keys and mouse scroll): changes the
	  displayed slice.
	- 'lung' checkbox: adds a translucent lung-outline surface in 3D.
	- 'Remove' checkbox: clicking a component sets it to 0 in ``pred_st``
	  and in the 3D volume.
	- 'Add' checkbox: clicking a seed point runs ``level_set3D`` from it and
	  adds the segment to ``pred_st`` and the 3D volume.
	- 'Classify' checkbox: draws 3D components coloured by the output of
	  ``pred_VGG``.
	- ESC closes both windows.

	Args:
		im_st: CT stack indexed as ``im_st[slice, row, col]``.
		pred_st: prediction stack; presumably the same shape as ``im_st`` —
			TODO confirm. NOTE: modified IN PLACE by the Remove/Add tools.
		res: spacing pair. The zoom factor uses ``res[0]/res[1]``, which matches
			``[z_res, xy_res]`` as yielded by ``predict_gen``.
		thr: threshold used only for the initial connected-component labelling
			and the printed nodule count (the threshold slider starts at 0.5).
		plot_size_xy: in-plane size the 3D rendering is resampled to.
		VGG_model_path: model path forwarded to ``pred_VGG``.
	"""
	
	t = TicToc()  # NOTE(review): `t` is not used anywhere below
	im_ct = im_st.copy()  # full-resolution copy used for the 2D slice view
	# per-axis resampling for the 3D view: in-plane size -> plot_size_xy, slice axis
	# additionally stretched by res[0]/res[1]
	zoom_fac = [res[0]/res[1]/(im_st.shape[1]/plot_size_xy), plot_size_xy/im_st.shape[1], plot_size_xy/im_st.shape[1]]

	im_st = zoom(im_st.copy(), zoom = zoom_fac, order = 1)
	mask = np.zeros((im_st.shape[0],im_st.shape[1], im_st.shape[2]))

	#for i in range(im_st.shape[0]):
	#	mask[i, :, :] = lungmask_pro(im_st[i, :, :], morph = False)
	# lung outline for the 3D view: lung mask reduced to Canny edges on every (slice, col) plane
	mask[:,:,:] = lungmask3D(im_st[:, :, :], morph = False)

	for i in range(im_st.shape[1]):
		mask[:, i, :] = cv2.Canny(np.uint8(maxminscale(mask[:, i, :])),50,100)
	
	# NOTE(review): this copy is immediately overwritten by the zoom on the next statement
	pred_st_tmp = pred_st.copy()

	# prediction resampled to the 3D-view grid (nearest neighbour)
	pred_st_tmp = zoom(pred_st, zoom = zoom_fac, order = 0)
	bin_pred_3d = pred_st_tmp.copy()
	bin_pred_3d[bin_pred_3d < thr] = 0
	bin_pred_3d[bin_pred_3d >= thr] = 1

	bin_pred_2d = pred_st.copy()
	bin_pred_2d[bin_pred_2d < thr] = 0
	bin_pred_2d[bin_pred_2d >= thr] = 1


	#labels 
	# connected components of the binarised prediction; used by `remove` to delete a clicked component
	label_3d = label(bin_pred_3d)
	label_2d = label(bin_pred_2d)
	print('Number of nodules found: ' + str(len(np.unique(label_2d))-1))

	mlab.figure(1, size = (1000,800))
	mlab.contour3d(pred_st_tmp, colormap = 'hot', opacity = 1.0, vmin = 0, vmax = 1)

	
	# threshold slider callback: re-threshold both the 3D surface and the current 2D overlay
	def new_thr(event):
		 
		tmp_thr = pred_st[int(slider2.val)].copy()
		tmp_3d = pred_st_tmp.copy()
		threshold = event

		tmp_thr[tmp_thr < threshold] = 0
		tmp_thr[tmp_thr >= threshold] = 1

		tmp_3d[tmp_3d < threshold] = 0
		tmp_3d[tmp_3d >= threshold] = 1

		mlab.clf()
		#mlab.figure(1, size = (1000,800))
		mlab.contour3d(tmp_3d, colormap = 'hot', opacity = 1.0, vmin = 0, vmax = 1)
		if button1.get_status()[0]:
			mlab.contour3d(mask, colormap = 'Greys', opacity = 0.1, vmin = 0, vmax = 150)
		#mlab.title('Prediction')

		ax.clear()
		ax.imshow(im_ct[int(slider2.val), :, :], cmap = 'gray')
		#ax.imshow(tmp_thr, cmap = 'inferno', vmin = 0, vmax = 1, alpha = 0.3)
		ax.imshow(tmp_thr, cmap = 'gnuplot', vmin = 0, vmax = 2, alpha = 0.3)
		#ax.set_title('predicted')
		ax.set_axis_off()

		f.suptitle('slice '+str(slider2.val))
		f.canvas.draw_idle()


	# slice slider callback: `event` is the new slice index
	def images(event):
		
		threshold = slider1.val
		tmp = pred_st[int(event)].copy()
		
		tmp[tmp < threshold] = 0
		tmp[tmp >= threshold] = 1

		ax.clear()
		ax.imshow(im_ct[int(event), :, :], cmap = 'gray')
		#ax.imshow(tmp, cmap = 'inferno', vmin = 0, vmax = 1, alpha = 0.3)
		ax.imshow(tmp, cmap = 'gnuplot', vmin = 0, vmax = 2, alpha = 0.3)
		ax.set_title('predicted')
		ax.set_axis_off()

		
		f.suptitle('slice '+str(int(slider2.val)))
		f.canvas.draw_idle()

	# arrow-key navigation; the bare `1` statements are no-ops at the stack ends
	def up_scroll_alt(event):
		if event.key == "up":
			if (slider2.val + 2 > im_ct.shape[0]):
				1
				#print("Whoops, end of stack", print(slider2.val))
			else:
				slider2.set_val(slider2.val + 1)
		

	def down_scroll_alt(event):
		if event.key == "down":
			if (slider2.val - 1 < 0):
				1
				#print("Whoops, end of stack", print(slider2.val))
			else:
				slider2.set_val(slider2.val - 1)


	# mouse-wheel navigation, same bounds as the arrow keys
	def up_scroll(event):
		if event.button == 'up':
			if (slider2.val + 2 > im_ct.shape[0]):
				1
				#print("Whoops, end of stack", print(slider2.val))
			else:
				slider2.set_val(slider2.val + 1)


	def down_scroll(event):
		if event.button == 'down':
			if (slider2.val - 1 < 0):
				1
				#print("Whoops, end of stack", print(slider2.val))
			else:
				slider2.set_val(slider2.val - 1)

		

	# 'lung' checkbox callback: redraw the 3D view (prediction or classification) with/without lung outline
	def show_lung(event):
		if not button4.get_status()[0]:
			tmp_3d = pred_st_tmp.copy()
			threshold = slider1.val

			tmp_3d[tmp_3d < threshold] = 0
			tmp_3d[tmp_3d >= threshold] = 1

			mlab.clf()
			mlab.contour3d(tmp_3d, colormap = 'hot', opacity = 1.0, vmin = 0, vmax = 1)
			if button1.get_status()[0]:
				mlab.contour3d(mask, colormap = 'Greys', opacity = 0.1, vmin = 0, vmax = 150)

		elif button4.get_status()[0]:
			pred_class = pred_VGG(im_ct, pred_st, res, VGG_model_path)	
			pred_class = zoom(pred_class.copy(), zoom = zoom_fac, order = 0)
			mlab.clf()
			print(np.unique(pred_class)[1:])

			for i in np.unique(pred_class):
				if i == 0:
					continue
				tmp = pred_class.copy()
				tmp[pred_class != i] = 0
				# NOTE(review): `colors` has 9 rows; for i == 5, round(9/5*5) == 9 is out of range —
				# verify the value range returned by pred_VGG.
				mlab.contour3d(tmp, colormap = 'OrRd', color = tuple(colors[int(round((9/5)*i)),:]), vmin = 1, vmax = 5)
				mlab.scalarbar(orientation = 'vertical', nb_labels = 9, label_fmt='%.1f')

			if button1.get_status()[0]:
				mlab.contour3d(mask, colormap = 'Greys', opacity = 0.1, vmin = 0, vmax = 150)

		mlab.orientation_axes(xlabel = 'z', ylabel = 'y', zlabel = 'x')


	# click handler in Remove mode: zero the connected component under the cursor in 2D and 3D
	def remove(event):
		# only add when seed point selected is at an axis
		# NOTE(review): with `or`, a click where only one coordinate is None would hit int(None);
		# `and` is probably intended.
		if (event.ydata != None) or (event.xdata != None):
			ix, iy = int(event.ydata), int(event.xdata)

			# if already in add mode -> switch to remove mode
			# if button3.get_status()[0]:
			# 	button3.set_active(0)

		#print(event)

		# NOTE(review): relies on the Axes repr starting with 'AxesSubplot' to tell the image axes
		# apart from widget axes; newer matplotlib versions may print 'Axes(...)' instead, which
		# would disable this tool — `event.inaxes is ax` would be more robust (verify).
		if (str(event.inaxes).split('(')[0] == 'AxesSubplot') and button2.get_status()[0]:
			# remove
			coords_zoom = (np.array([slider2.val, ix, iy])*np.array(zoom_fac)).astype(int)
			coords_orig = (int(slider2.val), int(ix), int(iy))
			print(coords_zoom,coords_orig)

			val_3d = label_3d[tuple(coords_zoom)]
			pred_st_tmp[label_3d == val_3d] = 0

			# mutates the caller's pred_st
			val_2d = label_2d[coords_orig]
			pred_st[label_2d == val_2d] = 0

			# re-plot
			tmp_thr = pred_st[int(slider2.val)].copy()
			tmp_3d = pred_st_tmp.copy()
			threshold = slider1.val

			tmp_thr[tmp_thr < threshold] = 0
			tmp_thr[tmp_thr >= threshold] = 1

			tmp_3d[tmp_3d < threshold] = 0
			tmp_3d[tmp_3d >= threshold] = 1

			mlab.clf()
			mlab.contour3d(tmp_3d, colormap = 'hot', opacity = 1.0, vmin = 0, vmax = 1)
			if button1.get_status()[0]:
				mlab.contour3d(mask, colormap = 'Greys', opacity = 0.1, vmin = 0, vmax = 150)

			ax.clear()
			ax.imshow(im_ct[int(slider2.val), :, :], cmap = 'gray')
			#ax.imshow(tmp_thr, cmap = 'inferno', vmin = 0, vmax = 1, alpha = 0.3)
			ax.imshow(tmp_thr, cmap = 'gnuplot', vmin = 0, vmax = 2, alpha = 0.3)
			# NOTE(review): matplotlib widgets need a live reference to stay responsive; this local
			# `cursor` is dropped when the handler returns.
			cursor = Cursor(ax, useblit=True, color='orange', linewidth=0.5)
			#ax.set_title('predicted')
			ax.set_axis_off()

			f.suptitle('slice '+str(slider2.val))
			f.canvas.draw_idle()



	# def add_window(event):
	# 	# only add when seed point selected is at an axis
	# 	if (event.ydata != None) or (event.xdata != None):
	# 		ix, iy = int(event.ydata), int(event.xdata)

	# 	if (str(event.inaxes).split('(')[0] == 'AxesSubplot') and button3.get_status()[0]:
	# 		figs.show()

	# click handler in Add mode: grow a 3D segment from the clicked seed with level_set3D
	def add(event):

		# key/button handlers of the advanced-settings window; `figs` is bound later in `add`
		def adv_exit(event):
			print(1)
			plt.close(figs)
			#plt.close()

		def adv_start(event):
			print(2)
			plt.close(figs)

		# build the advanced-settings window with sliders pre-set to the defaults assigned in `add`
		def adv_window():
			figs, axx = plt.subplots(num = 'Advanced settings')
			figs.canvas.mpl_connect('key_press_event', adv_exit)
			axx.axis('off')
			#bx1_as = plt.axes([0.05, 0.3, 0.15, 0.11])
			#bx1_as.set_axis_off()

			#button_as1 = CheckButtons(bx1_as, ['lambda1'], [1])
			axx.axis('off')
			ax_as1 = plt.axes([0.15, 0.02, 0.5, 0.05])
			slider_as1 = Slider(ax_as1, 'lambda1', 0.1, 4, dragging = True, valstep = 0.1)

			ax_as2 = plt.axes([0.15, 0.10, 0.5, 0.05])
			slider_as2 = Slider(ax_as2, 'lambda2', 0.1, 4, dragging = True, valstep = 0.1)

			ax_as3 = plt.axes([0.15, 0.18, 0.5, 0.05])
			slider_as3 = Slider(ax_as3, 'smoothing', 0, 4, dragging = True, valstep = 1)

			ax_as4 = plt.axes([0.15, 0.26, 0.5, 0.05])
			slider_as4 = Slider(ax_as4, 'iterations', 1, 1000, dragging = True, valstep = 1)

			ax_as5 = plt.axes([0.15, 0.34, 0.5, 0.05])
			slider_as5 = Slider(ax_as5, 'radius', 0.5, 5, dragging = True, valstep = 0.1)

			ax_b1 = plt.axes([0.85, 0.15, 0.07, 0.08])
			ax_b2 = plt.axes([0.85, 0.05, 0.07, 0.08])

			but_as1 = Button(ax_b1, 'exit', color = 'beige', hovercolor = 'beige')
			but_as2 = Button(ax_b2, 'start', color = 'beige', hovercolor = 'beige')

			#ax_textbox = plt.axes([0, 0.4, 0.5, 0.4])
			#axx.axis('off')
			textstr = "Press ENTER in terminal to start segmentation. \n Shouldn't be necessairy to change settings below, but can be tuned if \n resulting segmentation is not ideal. \n Especially if small nodule: try setting lambda1 <= lambda2 \n Or if very nonhomogeneous nodule: try setting lambda1 > lambda2."

			props = dict(boxstyle='round', facecolor='wheat')
			# NOTE(review): uses the main viewer's `ax.transAxes`, not `axx.transAxes` — verify intended.
			axx.text(-0.18, 0.25, textstr, transform=ax.transAxes, fontsize=12,
        verticalalignment='top', bbox=props)


			slider_as1.set_val(lambda1)
			slider_as2.set_val(lambda2)
			slider_as3.set_val(smoothing)
			slider_as4.set_val(iterations)
			slider_as5.set_val(rad)
			but_as1.on_clicked(adv_exit)
			but_as2.on_clicked(adv_start)

			figs.canvas.draw_idle()

			return figs, axx, slider_as1, slider_as2, slider_as3, slider_as4, slider_as5


		# only add when seed point selected is at an axis
		if (event.ydata != None) or (event.xdata != None):
			ix, iy = int(event.ydata), int(event.xdata)

		if (str(event.inaxes).split('(')[0] == 'AxesSubplot') and button3.get_status()[0]:

			# default settings for level set
			lambda1=1; lambda2=4; smoothing = 1; iterations = 100; rad = 3

			# advanced settings window pop-up
			figs, axx, slider_as1, slider_as2, slider_as3, slider_as4, slider_as5 = adv_window()
			figs.show()

			# to start segmentation (blocks until Enter is pressed in the terminal)
			input('Press enter to start segmentation: ')
			plt.close(figs)

			# add
			coords_zoom = (np.array([slider2.val, ix, iy])*np.array(zoom_fac)).astype(int)
			coords_orig = (int(slider2.val), int(ix), int(iy))
			print(coords_zoom,coords_orig)

			# apply level set to grow in 3D from single seed point
			seg_tmp = level_set3D(im_ct, coords_orig, list(reversed(res)), lambda1=slider_as1.val, lambda2=slider_as2.val, smoothing = int(slider_as3.val), iterations = int(slider_as4.val), rad = slider_as5.val)
			#seg_tmp = level_set3D(im_ct, coords_orig, list(reversed(res)), smoothing = int(slider_as3.val), iterations = int(slider_as4.val), rad = slider_as5.val, method = 'GAC', alpha = 150, sigma = 5, balloon = 1)

			# if no nodule was segmented, break; else continue
			if (seg_tmp is None):
				print('No nodule was segmented. Try changing parameters...')
				return None
			else:
				# because of interpolation, went from {0,1} -> [0,1]. Need to threshold to get binary segment
				seg_tmp[seg_tmp < 0.5] = 0
				seg_tmp[seg_tmp >= 0.5] = 1
				# mutates the caller's pred_st
				pred_st[seg_tmp == 1] = 1
				label_2d[seg_tmp == 1] = len(np.unique(label_2d)) # OBS: need to add new label for new segment, in order to remove it properly!

				seg_tmp_3d = zoom(seg_tmp.copy(), zoom = zoom_fac, order = 1)
				seg_tmp_3d[seg_tmp_3d < 0.5] = 0
				seg_tmp_3d[seg_tmp_3d >= 0.5] = 1
				pred_st_tmp[seg_tmp_3d == 1] = 1
				label_3d[seg_tmp_3d == 1] = len(np.unique(label_3d))

				# re-plot
				tmp_thr = pred_st[int(slider2.val)].copy()
				tmp_3d = pred_st_tmp.copy()
				threshold = slider1.val

				tmp_thr[tmp_thr < threshold] = 0
				tmp_thr[tmp_thr >= threshold] = 1

				tmp_3d[tmp_3d < threshold] = 0
				tmp_3d[tmp_3d >= threshold] = 1

				mlab.clf()
				mlab.contour3d(tmp_3d, colormap = 'hot', opacity = 1.0, vmin = 0, vmax = 1)
				if button1.get_status()[0]:
					mlab.contour3d(mask, colormap = 'Greys', opacity = 0.1, vmin = 0, vmax = 150)

				ax.clear()
				ax.imshow(im_ct[int(slider2.val), :, :], cmap = 'gray')
				#ax.imshow(tmp_thr, cmap = 'inferno', vmin = 0, vmax = 1, alpha = 0.3)
				ax.imshow(tmp_thr, cmap = 'gnuplot', vmin = 0, vmax = 2, alpha = 0.3)
				cursor = Cursor(ax, useblit=True, color='orange', linewidth=0.5)
				#ax.set_title('predicted')
				ax.set_axis_off()

				f.suptitle('slice '+str(slider2.val))
				f.canvas.draw_idle()


	# NOTE(review): each toggle of Remove/Add connects another click handler without
	# disconnecting the previous one, so handlers accumulate; the button2/button3 checks
	# inside `remove`/`add` keep only the active mode effective.
	def remove_mode(event):
		# if already in add mode -> switch to remove mode
		if button3.get_status()[0]:
			button3.set_active(0)

		f.canvas.mpl_connect('button_press_event', remove)


	def add_mode(event):
		# if already in remove mode -> switch to add mode
		if button2.get_status()[0]:
			button2.set_active(0)

		f.canvas.mpl_connect('button_press_event', add)


	# 'Classify' checkbox callback: draw pred_VGG classes in 3D, or restore the plain prediction view
	def classify(event):
		if button4.get_status()[0]:
			pred_class = pred_VGG(im_ct, pred_st, res, VGG_model_path)	
			pred_class = zoom(pred_class.copy(), zoom = zoom_fac, order = 0)
			mlab.clf()
			print(np.unique(pred_class)[1:])

			for i in np.unique(pred_class):
				if i == 0:
					continue
				tmp = pred_class.copy()
				tmp[pred_class != i] = 0
				mlab.contour3d(tmp, colormap = 'OrRd', color = tuple(colors[int(round((9/5)*i)),:]), vmin = 1, vmax = 5)
				mlab.scalarbar(orientation = 'vertical', nb_labels = 9, label_fmt='%.1f')

			if button1.get_status()[0]:
				mlab.contour3d(mask, colormap = 'Greys', opacity = 0.1, vmin = 0, vmax = 150)

			mlab.orientation_axes(xlabel = 'z', ylabel = 'y', zlabel = 'x')

		elif not button4.get_status()[0]:
			# re-plot without classify
			tmp_thr = pred_st[int(slider2.val)].copy()
			tmp_3d = pred_st_tmp.copy()
			threshold = slider1.val

			tmp_thr[tmp_thr < threshold] = 0
			tmp_thr[tmp_thr >= threshold] = 1

			tmp_3d[tmp_3d < threshold] = 0
			tmp_3d[tmp_3d >= threshold] = 1

			mlab.clf()
			mlab.contour3d(tmp_3d, colormap = 'hot', opacity = 1.0, vmin = 0, vmax = 1)
			if button1.get_status()[0]:
				mlab.contour3d(mask, colormap = 'Greys', opacity = 0.1, vmin = 0, vmax = 150)

			ax.clear()
			ax.imshow(im_ct[int(slider2.val), :, :], cmap = 'gray')
			#ax.imshow(tmp_thr, cmap = 'inferno', vmin = 0, vmax = 1, alpha = 0.3)
			ax.imshow(tmp_thr, cmap = 'gnuplot', vmin = 0, vmax = 2, alpha = 0.3)
			ax.set_axis_off()

			f.suptitle('slice '+str(slider2.val))
			f.canvas.draw_idle()




	# Exit app when ESC is pressed
	def quit(event):
		if event.key == "escape":
			plt.close()
			mlab.close()



	# default settings for level set for add-event
	lambda1=1; lambda2=4; smoothing = 1; iterations = 100; rad = 3


	# colormap for 3D-malignancy-plot for classify-event (9 RGB rows from OrRd)
	dz = list(range(1,10))
	norm = plt.Normalize()
	colors = plt.cm.OrRd(norm(dz))
	colors = np.array(colors)[:,:3]


	# plot to make simulator gator
	f, ax = plt.subplots(1,1, figsize = (12, 12))

	f.canvas.mpl_connect('key_press_event', up_scroll_alt)
	f.canvas.mpl_connect('key_press_event', down_scroll_alt)
	f.canvas.mpl_connect('scroll_event', up_scroll)
	f.canvas.mpl_connect('scroll_event', down_scroll)
	f.canvas.mpl_connect('key_press_event', quit)
	#f.canvas.mpl_connect('button_press_event', remove)
	#f.canvas.mpl_connect('button_press_event', add)

	# checkbox axes stacked on the left
	b1ax = plt.axes([0.05, 0.2, 0.15, 0.11])
	b1ax.set_axis_off()
	b2ax = plt.axes([0.05, 0.35, 0.15, 0.11])
	b2ax.set_axis_off()
	b3ax = plt.axes([0.05, 0.5, 0.15, 0.11])
	b3ax.set_axis_off()
	b4ax = plt.axes([0.05, 0.65, 0.15, 0.11])
	b4ax.set_axis_off()


	# slider axes along the bottom
	s1ax = plt.axes([0.25, 0.08, 0.5, 0.03])
	s2ax = plt.axes([0.25, 0.02, 0.5, 0.03])

	slider1 = Slider(s1ax, 'threshold', 0.1, 1.0, dragging = True, valstep = 0.05)
	slider2 = Slider(s2ax, 'slice', 0.0, im_ct.shape[0]-1, valstep = 1)

	button1 = CheckButtons(b1ax, ['lung'], [False])
	button2 = CheckButtons(b2ax, ['Remove'], [False])
	button3 = CheckButtons(b3ax, ['Add'], [False])
	button4 = CheckButtons(b4ax, ['Classify'], [False])

	# initial values are set before the on_changed callbacks are registered below
	slider1.set_val(0.5)
	slider2.set_val(0)
	f.subplots_adjust(bottom = 0.15)

	#ax.imshow(im_ct[int(slider2.val), :, :], cmap = 'gray')
	#ax.imshow(pred_st[int(slider2.val), :, :], cmap = 'inferno', vmin = 0, vmax = 1, alpha = 0.3)
	ax.imshow(im_ct[int(slider2.val), :, :], cmap = 'gray')
	ax.imshow(pred_st[int(slider2.val), :, :], cmap = 'gnuplot', vmin = 0, vmax = 2, alpha = 0.3)
	cursor = Cursor(ax, useblit=True, color='orange', linewidth=0.5)
	ax.set_title('predicted')
	ax.set_axis_off()

	f.suptitle('slice '+str(slider2.val))

	button1.on_clicked(show_lung)
	button2.on_clicked(remove_mode)
	button3.on_clicked(add_mode)
	button4.on_clicked(classify)
	slider1.on_changed(new_thr)
	slider2.on_changed(images)

	plt.show()
Ejemplo n.º 9
0
from scipy.misc import imread
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import scipy.ndimage.filters as filters
from skimage import measure
import drlse_algo as drlse
import numpy as np
from image_functions import maxminscale, imgclean, img2pixels
import cv2

# NOTE(review): the `scipy.misc.imread` import above appears unused in the visible code and
# was removed in SciPy 1.2 — consider dropping it.
#img = plt.imread('test_end.png')
img = plt.imread('Figure_1_org.png')

img = img[:, :, 0]  # keep only the first colour channel
img = maxminscale(img)
org_img = 1 * img  # need this later! (multiplication makes a copy)

# initial seed point -> to grow from (hard-coded (row, col) pixel position)
seed_point_full = (int(366.7933774834437), int(315.6754966887417))
reg = 25
# crop a (2*reg) x (2*reg) window centred on the seed point; assumes the seed lies at least
# `reg` pixels from the image border
img = img[seed_point_full[0] - reg:seed_point_full[0] + reg,
          seed_point_full[1] - reg:seed_point_full[1] + reg]

seed_point = (
    reg, reg
)  # seed point for new smaller image 50x50 -> need to rename later.. np

# parameters
timestep = 1  # time step (1)
mu = 0.2 / timestep  # coefficient of the distance regularization term R(phi)
iter_inner = 4  # (4)