def plot_3d_no_gt(im_st, pred_st, res, thr=0.5, plot_size_xy=256):
    """Render a thresholded 3D prediction inside the lung outline (no ground truth).

    Both stacks are resampled with ``zoom`` so that the in-plane size becomes
    ``plot_size_xy`` and the z axis is scaled by the resolution ratio
    ``res[0] / res[1]``. The lung outline comes from ``lungmask3D`` followed by
    ``cv2.Canny`` on every ``[:, i, :]`` slice.

    Parameters
    ----------
    im_st : CT stack indexed ``im_st[z, y, x]``.
    pred_st : prediction stack with the same indexing. It is NOT modified.
    res : ``[z_res, xy_res]``.
    thr : prediction values >= thr are drawn (as 180), values below are set to 0.
    plot_size_xy : in-plane size of the rendered volume.
    """
    zoom_fac = [res[0] / res[1] / (im_st.shape[1] / plot_size_xy),
                plot_size_xy / im_st.shape[1],
                plot_size_xy / im_st.shape[1]]

    # Lung outline of the down-sampled CT volume.
    im_small = zoom(im_st, zoom=zoom_fac, order=1)
    lung = np.zeros(im_small.shape[:3])
    lung[:, :, :] = lungmask3D(im_small, morph=False)
    for i in range(im_small.shape[1]):
        lung[:, i, :] = cv2.Canny(np.uint8(maxminscale(lung[:, i, :])), 50, 100)

    # Threshold a copy so the caller's prediction array is left untouched.
    pred = pred_st.copy()
    pred[pred < thr] = 0
    pred[pred >= thr] = 180
    # Nearest-neighbour zoom keeps the two levels {0, 180}.
    pred = zoom(pred, zoom=zoom_fac, order=0)

    mlab.figure(1, size=(1000, 800))
    mlab.contour3d(pred, colormap='hot', opacity=1.0, vmin=0, vmax=255)
    mlab.contour3d(lung, colormap='Greys', opacity=0.1, vmin=0, vmax=150)
    mlab.title('Prediction')
    mlab.show()
def plot_lungmask(im_st):
    """Render the edge outline of the 3D lung mask of ``im_st`` with mayavi.

    ``im_st`` is indexed as ``im_st[z, y, x, 0]``. The mask returned by
    ``lungmask3D`` is reduced to its edges with ``cv2.Canny`` on every
    ``[:, i, :]`` slice. It is then drawn as a translucent blue contour, and
    the call blocks in ``mlab.show()``.
    """
    outline = np.zeros(im_st.shape[:3])
    print('get mask')
    outline[...] = lungmask3D(im_st[:, :, :, 0], morph=False)
    for col in range(im_st.shape[1]):
        scaled = np.uint8(maxminscale(outline[:, col, :]))
        outline[:, col, :] = cv2.Canny(scaled, 50, 100)
    print('plotting')
    mlab.contour3d(outline, colormap='Blues', opacity=0.1)
    mlab.show()
def plot_lung_and_tumor(im_st, pred_st, gt_st, res=None, thr=0.5):
    """Show prediction and ground truth inside the lung outline in two mayavi figures.

    Parameters
    ----------
    im_st : CT stack indexed ``im_st[z, y, x, 0]``.
    pred_st, gt_st : stacks indexed ``[z, y, x, c]``; only channel 1 is drawn.
        Neither array is modified.
    res : optional ``[z_res, xy_res]``. When it is None or empty, the z axis is
        stretched by ``im_st.shape[1] / im_st.shape[0]``; otherwise by
        ``res[0] / res[1] / 2``.
    thr : values >= thr are drawn (as 180); values below are set to 0. The same
        rule is applied to prediction and ground truth.
    """
    # Lung outline: lung mask reduced to edges on every [:, i, :] slice.
    lung = np.zeros(im_st.shape[:3])
    lung[:, :, :] = lungmask3D(im_st[:, :, :, 0], morph=False)
    for i in range(im_st.shape[1]):
        lung[:, i, :] = cv2.Canny(np.uint8(maxminscale(lung[:, i, :])), 50, 100)

    # Threshold copies of the plotted channel only, so the caller's arrays stay intact.
    pred = pred_st[:, :, :, 1].copy()
    pred[pred < thr] = 0
    pred[pred >= thr] = 180
    gt = gt_st[:, :, :, 1].copy()
    gt[gt < thr] = 0
    gt[gt >= thr] = 180

    # TODO - make zoom_fac dependent on resolution in z dir (true values)
    if res is None or len(res) == 0:
        zoom_fac = im_st.shape[1] / im_st.shape[0]
    else:
        zoom_fac = res[0] / res[1] / 2
    lung = zoom(lung, zoom=[zoom_fac, 1, 1])
    pred = zoom(pred, zoom=[zoom_fac, 1, 1])
    gt = zoom(gt, zoom=[zoom_fac, 1, 1])

    mlab.figure(1, size=(1000, 800))
    mlab.contour3d(pred, colormap='hot', opacity=1.0, vmin=0, vmax=255)
    mlab.contour3d(lung, colormap='Greys', opacity=0.1, vmin=0, vmax=150)
    mlab.title('Prediction')

    mlab.figure(2, size=(1000, 800))
    mlab.contour3d(gt, colormap='hot', opacity=1.0, vmin=0, vmax=255)
    mlab.contour3d(lung, colormap='Greys', opacity=0.1, vmin=0, vmax=150)
    mlab.title('Ground truth')
    mlab.show()
## fill regions surrounded by lung regions -> final ROI # to fill holes -> without filling unwanted larger region in the middle res = binary_fill_holes(mask).astype(int) # need to filter out the main airways, because elsewise they will affect the resulting algorithm res2 = remove_small_objects(label(res), min_size=1100) res2[res2 > 0] = 1 #res2 = res.copy() # first close boundaries to include juxta-vascular nodule candidates mask = morphology.dilation(res2, disk(7)) # 7 (17 for more difficult ones) # then fill inn nodules to include them, but not the much larger objects, ex. heart etc mask = remove_small_objects( label(cv2.bitwise_not(maxminscale(mask))), min_size=300 ) # can change min_size larger if want to include larger nodule candidates mask[mask != 0] = -1 mask = np.add(mask, np.ones(mask.shape)) # last erosion to complete the closing+, larger disk size because need to remove some of the lung boundaries filled_tmp = morphology.erosion(mask, disk(9)) # 9 (19 for more difficult ones) imgclean(filled_tmp, Figure=True) #image = filled_tmp*img_orig image = img_orig.copy() image[filled_tmp == 0] = np.amin(img_orig) #image = filled_tmp*img_orig
def _visible_entries(path):
    """List the entries of *path*, skipping macOS metadata entries starting with '.DS'."""
    return [name for name in os.listdir(path) if not name.startswith('.DS')]


def _collect_patient_files(start_path, only, exclude):
    """Map patient number -> DICOM paths and patient number -> annotation XML path.

    The tree is walked as ``start_path/<dir1>/<patient>/<study>/<series>/<file>``.
    The patient number is the third '-'-separated field of the patient folder
    name. Only series folders with more than 10 entries are used. Patients with
    fewer than two DICOM files or without an XML file are skipped, so a patient
    can never inherit another patient's annotation file.
    """
    dcm_files = {}
    xml_files = {}
    print()
    for dir1 in _visible_entries(start_path):
        path1 = os.path.join(start_path, dir1)
        for dir2 in _visible_entries(path1):
            pat_nr = int(dir2.split('-')[2])
            path2 = os.path.join(path1, dir2)
            if len(only) > 20:
                # move the cursor up so the progress line overwrites itself
                sys.stdout.write("\033[F")
            print('reading through files, at patient_nr:', pat_nr)
            if pat_nr in exclude or (only and pat_nr not in only):
                continue
            dcm_paths = []
            xml_path = None
            for dir3 in _visible_entries(path2):
                path3 = os.path.join(path2, dir3)
                for dir4 in _visible_entries(path3):
                    path4 = os.path.join(path3, dir4)
                    entries = os.listdir(path4)
                    if len(entries) <= 10:
                        continue
                    for name in entries:
                        if name.endswith('dcm'):
                            dcm_paths.append(os.path.join(path4, name))
                        if name.endswith('.xml'):
                            xml_path = os.path.join(path4, name)
            if len(dcm_paths) < 2 or xml_path is None:
                print('skipping patient_nr', pat_nr,
                      ': missing CT series or XML annotation')
                continue
            dcm_files[pat_nr] = dcm_paths
            xml_files[pat_nr] = xml_path
    return dcm_files, xml_files


def _parse_reader_contours(xml_path):
    """Return one ``{z_position: [[x, y], ...]}`` dict per reading session of *xml_path*.

    Only nodule ROIs with more than five child elements are used; smaller ROIs
    (the single-pixel nodules) are ignored. Edge points of all nodules of one
    reader that share a z-position are merged into one list.
    """
    ns = '{http://www.nih.gov}'
    readers = []
    for session in et.parse(xml_path).getroot().findall(ns + 'readingSession'):
        contours = {}
        for nodule in session.findall(ns + 'unblindedReadNodule'):
            for roi in nodule.findall(ns + 'roi'):
                if len(roi) <= 5:
                    continue
                z_node = roi.find(ns + 'imageZposition')
                for edge in roi.findall(ns + 'edgeMap'):
                    contours.setdefault(float(z_node.text), []).append(
                        [int(edge.find(ns + 'xCoord').text),
                         int(edge.find(ns + 'yCoord').text)])
        readers.append(contours)
    return readers


def _load_ct_and_ground_truth(dcm_paths, readers):
    """Read the CT slices sorted by z and build the consensus ground truth.

    Returns ``(input_im, out, res)``:
    * ``input_im`` has shape (n, 512, 512, 1);
    * ``out`` is the one-hot (n, 512, 512, 2) mask whose channel 1 is 1 where
      the filled reader masks (255 each) sum to >= 300, i.e. at least two readers;
    * ``res`` is ``[z_res, xy_res]``.
    """
    slices = []
    for path in dcm_paths:
        itkimage = sitk.ReadImage(path)
        z_pos = list(reversed(itkimage.GetOrigin()))[0]
        slices.append((z_pos, path, itkimage))
    slices.sort(key=lambda s: (s[0], s[1]))
    n_slices = len(slices)
    # distance between the two lowest slices
    z_res = np.abs(slices[0][0] - slices[1][0])

    input_im = np.zeros((n_slices, 512, 512, 1))
    # one filled mask per reader; sized for any number of reading sessions
    reader_masks = np.zeros((max(4, len(readers)), n_slices, 512, 512), dtype=np.uint8)
    xy_res = None
    for cnt, (z_pos, _, itkimage) in enumerate(slices):
        input_im[cnt, :, :, 0] = sitk.GetArrayFromImage(itkimage)[0, :, :]
        spacing = np.array(list(reversed(itkimage.GetSpacing())))
        xy_res = spacing[1]
        for reader, contours in enumerate(readers):
            points = contours.get(z_pos)
            if points is None:
                continue
            for x, y in points:
                reader_masks[reader, cnt, y, x] = 255
            reader_masks[reader, cnt] = object_filler(reader_masks[reader, cnt], (0, 0))

    # accumulate in int without materialising an int copy of every reader mask
    votes = reader_masks.sum(axis=0, dtype=int)
    out = np.zeros((n_slices, 512, 512, 2))
    out[:, :, :, 1] = votes >= 300
    out[:, :, :, 0] = 1 - out[:, :, :, 1]
    return input_im, out, [z_res, xy_res]


def _predict_2d(model, input_im, out, pred_only_gt):
    """Predict slice by slice with a 2D model; skipped slices stay zero."""
    pred_output = np.zeros(out.shape)
    for i in tqdm(range(input_im.shape[0])):
        if pred_only_gt and np.count_nonzero(out[i, :, :, 1]) == 0:
            continue
        pred_output[i] = model.predict(input_im[i:i + 1])
    return pred_output


def _predict_3d(model, model_config, input_im, out, pred_only_gt, maxmin_input):
    """Resize to 256x256, predict depth-sized chunks with a 3D model and stitch them.

    Returns the resized input, the nearest-neighbour-resized ground truth and
    the prediction. The prediction's first axis is zero-padded up to a multiple
    of the model depth ``model_config[1]``.
    """
    n_slices = input_im.shape[0]
    inp = np.zeros((n_slices, 256, 256, 1))
    gt = np.zeros((n_slices, 256, 256, 2))
    for i in range(n_slices):
        ct_slice = input_im[i, :, :, 0]
        if maxmin_input:
            ct_slice = maxminscale(ct_slice)
        inp[i, :, :, 0] = cv2.resize(ct_slice, (256, 256))
        for ch in range(2):
            gt[i, :, :, ch] = cv2.resize(out[i, :, :, ch], (256, 256),
                                         interpolation=cv2.INTER_NEAREST)

    depth = model_config[1]
    chunks = int(np.ceil(n_slices / depth))
    # zero-pad the stack to chunks * depth slices and split it into model inputs
    padded = np.zeros((chunks * depth, model_config[2], model_config[3], model_config[4]))
    padded[:n_slices] = inp
    batches = padded.reshape(chunks, depth, model_config[2], model_config[3], model_config[4])

    pred_5d = np.zeros((chunks, depth, model_config[2], model_config[3], 2))
    for i in tqdm(range(chunks)):
        if pred_only_gt and np.count_nonzero(gt[depth * i:depth * (i + 1), :, :, 1]) == 0:
            continue
        pred_5d[i] = model.predict(batches[i:i + 1])
    pred_output = pred_5d.reshape(chunks * depth, model_config[2], model_config[3], 2)
    return inp, gt, pred_output


def predict_gen(model_path, start_path, pred_only_gt=False, maxmin_input=False,
                one_zero_scale=False, only=None, exclude=None, predict=True):
    """Yield CT volume, consensus ground truth and (optionally) prediction per patient.

    Parameters
    ----------
    model_path : Keras model file, loaded with ``load_model(..., compile=False)``.
        The rank of its first layer's ``batch_input_shape`` selects 2D (4) or
        3D (5) prediction.
    start_path : root directory of the scans (see ``_collect_patient_files``).
    pred_only_gt : only predict slices (2D) / chunks (3D) with non-empty ground
        truth; the other predictions stay zero.
    maxmin_input : 3D only; min-max scale each slice before resizing.
    one_zero_scale : currently unused; kept for backward compatibility.
    only : if non-empty, only these patient numbers are read.
    exclude : patient numbers to skip.
    predict : if False, no prediction is run.

    Yields
    ------
    ``(input_im, out, pred_output, pat, res)`` when ``predict`` is True, else
    ``(input_im, out, pat, res)``. For 3D models ``input_im``/``out`` are resized
    to 256x256 and ``pred_output`` is padded to a multiple of the model depth.
    ``res`` is ``[z_res, xy_res]``.

    Raises
    ------
    ValueError
        On the first iteration, if ``predict`` is True and the model input rank is
        neither 4 nor 5.
    """
    only = [] if only is None else only
    exclude = [] if exclude is None else exclude
    model = load_model(model_path, compile=False)
    model_config = model.get_config()['layers'][0]['config']['batch_input_shape']
    model_dim = len(model_config)
    print('model:', model_path.split('/')[-1], ', model dimensions:', model_dim - 2)
    if predict and model_dim not in (4, 5):
        raise ValueError('unsupported model input rank: %d' % model_dim)

    dcm_files, xml_files = _collect_patient_files(start_path, only, exclude)
    for pat, dcm_paths in dcm_files.items():
        readers = _parse_reader_contours(xml_files[pat])
        input_im, out, res = _load_ct_and_ground_truth(dcm_paths, readers)
        if not predict:
            yield input_im, out, pat, res
            continue
        if model_dim == 4:
            pred_output = _predict_2d(model, input_im, out, pred_only_gt)
        else:
            input_im, out, pred_output = _predict_3d(
                model, model_config, input_im, out, pred_only_gt, maxmin_input)
        yield input_im, out, pred_output, pat, res
gt_vals.append(i) t.toc() print(gt_vals) # get some image num = int(input("select slice containing tumor: ")) img = data[num, :, :] #num = 224 # slice number #img = data[num,:,:] # 84 is a problem! lung segmentation not robust enough img[img <= -1024] = -1024 # better visual contrast, BUT AFFECTS PERFORMANCE OF FCM!!! img[img >= 1024] = 1024 img = maxminscale(img) # OBS #imgclean(img, Figure=True) # keep original maxminscaled image -> becuase img is going to be altered img_orig = img.copy() # blur img img = np.uint8(img) img = cv2.medianBlur(img, 5) # get dimensions of image row_size, col_size = img.shape # specify window for k-means to work on, such that you get the lung, and not the rest middle = img[int(col_size / 5):int(col_size / 5 * 4),
# Get the ground-truth volume and list the slices that contain a nodule.
# NOTE(review): `.value` is the legacy h5py Dataset accessor (removed in h5py 3.0);
# if `output` is an h5py dataset, `output[:, :, :, 0]` is the portable form — confirm.
gt_data = output.value[:, :, :, 0]
# A slice whose ground-truth channel is not constant contains an annotation.
gt_vals = [i for i in range(gt_data.shape[0])
           if len(np.unique(gt_data[i, :, :])) > 1]
print(gt_vals)

# Choose the slice to work on: the entered number indexes the first axis of `data`.
num = input("choose slice: ")
img = np.asarray(data[int(num), :, :])

# 8-bit, median-blurred copy of the slice.
org_img = maxminscale(img)
org_img = np.uint8(org_img)
org_img = cv2.medianBlur(org_img, 11)

# Min-max scaled copy of the slice (1 * img copies before scaling).
img_orig = maxminscale(1 * img)

row_size = img.shape[0]
col_size = img.shape[1]
mean = np.mean(img)
std = np.std(img)
def test2d3d(im_st, pred_st, res, thr=0.1, plot_size_xy=150, VGG_model_path=''):
    """Interactive 2D/3D viewer and editor for a predicted nodule segmentation.

    A mayavi window shows the down-sampled prediction in 3D. A matplotlib
    window shows one CT slice with the thresholded prediction overlaid, and
    offers these controls:

    * 'threshold' / 'slice' sliders, up/down arrow keys and the scroll wheel;
    * 'lung'     - also draw the lung outline in 3D;
    * 'Remove'   - click a connected component on the slice to delete it;
    * 'Add'      - click a seed point to grow a new segment with ``level_set3D``;
    * 'Classify' - colour the 3D components by the output of ``pred_VGG``.

    ESC closes both windows.

    Parameters
    ----------
    im_st : CT stack indexed ``[z, y, x]``.
    pred_st : prediction stack, presumably the same shape as ``im_st`` (it is
        masked with arrays derived from ``im_st``). It is edited IN PLACE by the
        Remove/Add tools.
    res : ``[z_res, xy_res]``; sets the anisotropic zoom of the 3D view.
    thr : threshold used to label connected components for the Remove tool.
    plot_size_xy : in-plane size of the 3D rendering.
    VGG_model_path : model file passed to ``pred_VGG`` when classifying.
    """
    im_ct = im_st.copy()
    zoom_fac = [res[0]/res[1]/(im_st.shape[1]/plot_size_xy),
                plot_size_xy/im_st.shape[1],
                plot_size_xy/im_st.shape[1]]

    # Lung outline for the 3D view: lung mask of the down-sampled stack,
    # reduced to edges with Canny on every [:, i, :] slice.
    im_small = zoom(im_st, zoom=zoom_fac, order=1)
    mask = np.zeros(im_small.shape[:3])
    mask[:, :, :] = lungmask3D(im_small, morph=False)
    for i in range(im_small.shape[1]):
        mask[:, i, :] = cv2.Canny(np.uint8(maxminscale(mask[:, i, :])), 50, 100)

    def _binarize(volume, threshold):
        """Return a copy of *volume* with values < threshold set to 0 and the rest to 1."""
        binary = volume.copy()
        binary[binary < threshold] = 0
        binary[binary >= threshold] = 1
        return binary

    # Down-sampled prediction for the 3D view; edited in place by Remove/Add.
    pred_st_tmp = zoom(pred_st, zoom=zoom_fac, order=0)
    # Connected components at `thr` decide what a Remove click deletes.
    label_3d = label(_binarize(pred_st_tmp, thr))
    label_2d = label(_binarize(pred_st, thr))
    print('Number of nodules found: ' + str(len(np.unique(label_2d)) - 1))

    # Default level-set settings for the Add tool: lambda1, lambda2, smoothing, iterations, rad.
    level_set_defaults = (1, 4, 1, 100, 3)
    # 9-entry OrRd colour table for the Classify view.
    norm = plt.Normalize()
    colors = np.array(plt.cm.OrRd(norm(list(range(1, 10)))))[:, :3]

    mlab.figure(1, size=(1000, 800))
    mlab.contour3d(pred_st_tmp, colormap='hot', opacity=1.0, vmin=0, vmax=1)

    # ---------------------------------------------------------------- drawing
    def _draw_3d(threshold):
        """Redraw the 3D scene: thresholded prediction (+ lung outline if 'lung' is ticked)."""
        tmp_3d = _binarize(pred_st_tmp, threshold)
        mlab.clf()
        mlab.contour3d(tmp_3d, colormap='hot', opacity=1.0, vmin=0, vmax=1)
        if button1.get_status()[0]:
            mlab.contour3d(mask, colormap='Greys', opacity=0.1, vmin=0, vmax=150)

    def _draw_slice(threshold, title=None):
        """Redraw the current CT slice with the thresholded prediction overlaid."""
        idx = int(slider2.val)
        ax.clear()
        ax.imshow(im_ct[idx, :, :], cmap='gray')
        ax.imshow(_binarize(pred_st[idx], threshold), cmap='gnuplot', vmin=0, vmax=2, alpha=0.3)
        if title is not None:
            ax.set_title(title)
        ax.set_axis_off()
        f.suptitle('slice ' + str(idx))
        f.canvas.draw_idle()

    def _draw_classes():
        """Redraw the 3D scene with each pred_VGG class in its own OrRd colour."""
        pred_class = pred_VGG(im_ct, pred_st, res, VGG_model_path)
        pred_class = zoom(pred_class, zoom=zoom_fac, order=0)
        mlab.clf()
        print(np.unique(pred_class)[1:])
        for cls in np.unique(pred_class):
            if cls == 0:
                continue
            tmp = pred_class.copy()
            tmp[pred_class != cls] = 0
            # Map the class onto the 9-entry table; clip so the top class
            # (9/5 * 5 = 9) does not index past the last entry.
            color_idx = min(int(round((9 / 5) * cls)), len(colors) - 1)
            mlab.contour3d(tmp, colormap='OrRd', color=tuple(colors[color_idx, :]),
                           vmin=1, vmax=5)
            mlab.scalarbar(orientation='vertical', nb_labels=9, label_fmt='%.1f')
        if button1.get_status()[0]:
            mlab.contour3d(mask, colormap='Greys', opacity=0.1, vmin=0, vmax=150)
        mlab.orientation_axes(xlabel='z', ylabel='y', zlabel='x')

    # ------------------------------------------------------- slider / keys
    def new_thr(event):
        _draw_3d(event)
        _draw_slice(event)

    def images(event):
        _draw_slice(slider1.val, title='predicted')

    def _step_slice(delta):
        target = slider2.val + delta
        if 0 <= target <= im_ct.shape[0] - 1:
            slider2.set_val(target)

    def on_key(event):
        if event.key == 'up':
            _step_slice(1)
        elif event.key == 'down':
            _step_slice(-1)
        elif event.key == 'escape':
            # Exit app when ESC is pressed
            plt.close(f)
            mlab.close()

    def on_scroll(event):
        if event.button == 'up':
            _step_slice(1)
        elif event.button == 'down':
            _step_slice(-1)

    # ------------------------------------------------------------- checkboxes
    def show_lung(event):
        if button4.get_status()[0]:
            _draw_classes()
        else:
            _draw_3d(slider1.val)

    def classify(event):
        if button4.get_status()[0]:
            _draw_classes()
        else:
            _draw_3d(slider1.val)
            _draw_slice(slider1.val)

    def remove_mode(event):
        # Remove and Add are mutually exclusive: ticking Remove unticks Add.
        # set_active toggles the box and re-fires add_mode, which then finds
        # Add unticked and does nothing.
        if button2.get_status()[0] and button3.get_status()[0]:
            button3.set_active(0)

    def add_mode(event):
        # Ticking Add unticks Remove (see remove_mode).
        if button3.get_status()[0] and button2.get_status()[0]:
            button2.set_active(0)

    # ---------------------------------------------------------- click tools
    def _clicked_voxel(event, checkbox):
        """Return (coords in im_ct, coords in the zoomed volume) for a click on the
        slice image while *checkbox* is ticked; otherwise None."""
        if event.inaxes is not ax or event.xdata is None or event.ydata is None:
            return None
        if not checkbox.get_status()[0]:
            return None
        row, col = int(event.ydata), int(event.xdata)
        coords_orig = (int(slider2.val), row, col)
        coords_zoom = (np.array([slider2.val, row, col]) * np.array(zoom_fac)).astype(int)
        # Scaling and truncating the click can land one index past the edge of
        # the zoomed volume; clamp it.
        coords_zoom = np.minimum(coords_zoom, np.array(label_3d.shape) - 1)
        print(coords_zoom, coords_orig)
        return coords_orig, tuple(coords_zoom)

    def remove(event):
        clicked = _clicked_voxel(event, button2)
        if clicked is None:
            return
        coords_orig, coords_zoom = clicked
        # Label 0 is background: clicking it must not wipe all sub-threshold voxels.
        val_3d = label_3d[coords_zoom]
        if val_3d != 0:
            pred_st_tmp[label_3d == val_3d] = 0
        val_2d = label_2d[coords_orig]
        if val_2d != 0:
            pred_st[label_2d == val_2d] = 0
        _draw_3d(slider1.val)
        _draw_slice(slider1.val)
        # NOTE(review): only referenced locally, so this cursor may be
        # garbage-collected when the handler returns.
        cursor = Cursor(ax, useblit=True, color='orange', linewidth=0.5)

    def _advanced_settings_window():
        """Open the level-set settings window.

        Returns the figure, the five sliders (lambda1, lambda2, smoothing,
        iterations, radius) and the two buttons. The caller keeps the buttons
        referenced so that they stay responsive.
        """
        figs, axx = plt.subplots(num='Advanced settings')

        def adv_exit(event):
            print(1)
            plt.close(figs)

        def adv_start(event):
            print(2)
            plt.close(figs)

        figs.canvas.mpl_connect('key_press_event', adv_exit)
        axx.axis('off')
        slider_specs = [
            ([0.15, 0.02, 0.5, 0.05], 'lambda1', 0.1, 4, 0.1),
            ([0.15, 0.10, 0.5, 0.05], 'lambda2', 0.1, 4, 0.1),
            ([0.15, 0.18, 0.5, 0.05], 'smoothing', 0, 4, 1),
            ([0.15, 0.26, 0.5, 0.05], 'iterations', 1, 1000, 1),
            ([0.15, 0.34, 0.5, 0.05], 'radius', 0.5, 5, 0.1),
        ]
        sliders = [Slider(plt.axes(rect), name, lo, hi, dragging=True, valstep=step)
                   for rect, name, lo, hi, step in slider_specs]
        ax_b1 = plt.axes([0.85, 0.15, 0.07, 0.08])
        ax_b2 = plt.axes([0.85, 0.05, 0.07, 0.08])
        but_as1 = Button(ax_b1, 'exit', color='beige', hovercolor='beige')
        but_as2 = Button(ax_b2, 'start', color='beige', hovercolor='beige')
        textstr = "Press ENTER in terminal to start segmentation. \n Shouldn't be necessairy to change settings below, but can be tuned if \n resulting segmentation is not ideal. \n Especially if small nodule: try setting lambda1 <= lambda2 \n Or if very nonhomogeneous nodule: try setting lambda1 > lambda2."
        props = dict(boxstyle='round', facecolor='wheat')
        # NOTE(review): positioned with the MAIN axes' transform (ax.transAxes), so the
        # text placement depends on the main figure's geometry; kept as tuned.
        axx.text(-0.18, 0.25, textstr, transform=ax.transAxes, fontsize=12,
                 verticalalignment='top', bbox=props)
        for slider, value in zip(sliders, level_set_defaults):
            slider.set_val(value)
        but_as1.on_clicked(adv_exit)
        but_as2.on_clicked(adv_start)
        figs.canvas.draw_idle()
        return figs, sliders, (but_as1, but_as2)

    def add(event):
        clicked = _clicked_voxel(event, button3)
        if clicked is None:
            return None
        coords_orig, _ = clicked
        figs, sliders, _buttons = _advanced_settings_window()
        figs.show()
        # to start segmentation
        input('Press enter to start segmentation: ')
        plt.close(figs)
        lambda1, lambda2, smoothing, iterations, rad = (s.val for s in sliders)

        # apply level set to grow in 3D from single seed point
        seg_tmp = level_set3D(im_ct, coords_orig, list(reversed(res)),
                              lambda1=lambda1, lambda2=lambda2,
                              smoothing=int(smoothing), iterations=int(iterations),
                              rad=rad)
        if seg_tmp is None:
            print('No nodule was segmented. Try changing parameters...')
            return None

        # Interpolation turned {0,1} into [0,1]; threshold back to binary.
        seg_bin = _binarize(seg_tmp, 0.5)
        pred_st[seg_bin == 1] = 1
        # A fresh label lets the Remove tool delete the new segment later.
        label_2d[seg_bin == 1] = label_2d.max() + 1
        seg_3d = _binarize(zoom(seg_bin, zoom=zoom_fac, order=1), 0.5)
        pred_st_tmp[seg_3d == 1] = 1
        label_3d[seg_3d == 1] = label_3d.max() + 1

        _draw_3d(slider1.val)
        _draw_slice(slider1.val)
        # NOTE(review): only referenced locally, so this cursor may be
        # garbage-collected when the handler returns.
        cursor = Cursor(ax, useblit=True, color='orange', linewidth=0.5)
        return None

    # ----------------------------------------------------------- main window
    f, ax = plt.subplots(1, 1, figsize=(12, 12))
    f.canvas.mpl_connect('key_press_event', on_key)
    f.canvas.mpl_connect('scroll_event', on_scroll)
    # Connected once; each tool checks its own checkbox before acting.
    f.canvas.mpl_connect('button_press_event', remove)
    f.canvas.mpl_connect('button_press_event', add)

    b1ax = plt.axes([0.05, 0.2, 0.15, 0.11])
    b1ax.set_axis_off()
    b2ax = plt.axes([0.05, 0.35, 0.15, 0.11])
    b2ax.set_axis_off()
    b3ax = plt.axes([0.05, 0.5, 0.15, 0.11])
    b3ax.set_axis_off()
    b4ax = plt.axes([0.05, 0.65, 0.15, 0.11])
    b4ax.set_axis_off()
    s1ax = plt.axes([0.25, 0.08, 0.5, 0.03])
    s2ax = plt.axes([0.25, 0.02, 0.5, 0.03])

    slider1 = Slider(s1ax, 'threshold', 0.1, 1.0, dragging=True, valstep=0.05)
    slider2 = Slider(s2ax, 'slice', 0.0, im_ct.shape[0] - 1, valstep=1)
    button1 = CheckButtons(b1ax, ['lung'], [False])
    button2 = CheckButtons(b2ax, ['Remove'], [False])
    button3 = CheckButtons(b3ax, ['Add'], [False])
    button4 = CheckButtons(b4ax, ['Classify'], [False])

    slider1.set_val(0.5)
    slider2.set_val(0)
    f.subplots_adjust(bottom=0.15)

    ax.imshow(im_ct[int(slider2.val), :, :], cmap='gray')
    ax.imshow(pred_st[int(slider2.val), :, :], cmap='gnuplot', vmin=0, vmax=2, alpha=0.3)
    cursor = Cursor(ax, useblit=True, color='orange', linewidth=0.5)
    ax.set_title('predicted')
    ax.set_axis_off()
    f.suptitle('slice ' + str(int(slider2.val)))

    button1.on_clicked(show_lung)
    button2.on_clicked(remove_mode)
    button3.on_clicked(add_mode)
    button4.on_clicked(classify)
    slider1.on_changed(new_thr)
    slider2.on_changed(images)
    plt.show()
# Setup of a level-set segmentation experiment on one saved PNG slice (presumably
# the DRLSE method from drlse_algo, given the R(phi) regularisation parameter below).
# NOTE(review): `imread` is not used here (plt.imread is), and scipy.misc.imread was
# removed in SciPy 1.2, so this import fails on current SciPy.
from scipy.misc import imread
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import scipy.ndimage.filters as filters
from skimage import measure
import drlse_algo as drlse
import numpy as np
from image_functions import maxminscale, imgclean, img2pixels
import cv2
#img = plt.imread('test_end.png')
# Load the image and keep only its first colour channel, min-max scaled.
img = plt.imread('Figure_1_org.png')
img = img[:, :, 0]
img = maxminscale(img)
org_img = 1 * img  # need this later!
# initial seed point -> to grow from
seed_point_full = (int(366.7933774834437), int(315.6754966887417))
# Crop a (2*reg) x (2*reg) window centred on the seed point.
reg = 25
img = img[seed_point_full[0] - reg:seed_point_full[0] + reg, seed_point_full[1] - reg:seed_point_full[1] + reg]
seed_point = (reg, reg)  # seed point for new smaller image 50x50 -> need to rename later.. np
# parameters
timestep = 1  # time step (1)
mu = 0.2 / timestep  # coefficient of the distance regularization term R(phi)
iter_inner = 4  # (4)