def plot_3d_no_gt(im_st, pred_st, res, thr=0.5, plot_size_xy=256):
    """Show a thresholded prediction in 3D over a translucent lung outline.

    Both volumes are resampled so that the in-plane size becomes
    ``plot_size_xy`` and the slice axis is scaled by ``res[0] / res[1]``
    (presumably slice spacing vs. in-plane spacing -- confirm with callers).

    Args:
        im_st: image stack indexed as ``im_st[slice, row, col]`` (3D, since it
            is zoomed with three factors).
        pred_st: prediction stack with the same layout as ``im_st`` (assumed).
            It is NOT modified; thresholding happens on a copy.
        res: spacing sequence; only ``res[0]`` and ``res[1]`` are read.
        thr: values ``>= thr`` are drawn as foreground (180), all others as 0.
        plot_size_xy: target in-plane size of the plotted volume.
    """
    zoom_fac = [res[0] / res[1] / (im_st.shape[1] / plot_size_xy),
                plot_size_xy / im_st.shape[1],
                plot_size_xy / im_st.shape[1]]
    im_zoomed = zoom(im_st, zoom=zoom_fac, order=1)

    # Lung mask of the resampled stack, reduced to its edges plane by plane
    # (one Canny call per index along axis 1).
    mask = np.zeros(im_zoomed.shape)
    mask[...] = lungmask3D(im_zoomed, morph=False)
    for i in range(im_zoomed.shape[1]):
        mask[:, i, :] = cv2.Canny(np.uint8(maxminscale(mask[:, i, :])), 50, 100)

    # Threshold a copy in a single comparison: the caller's array is left
    # untouched, and zeros can no longer be re-promoted to 180 when thr <= 0.
    pred = np.where(pred_st >= thr, 180, 0).astype(pred_st.dtype)
    pred = zoom(pred, zoom=zoom_fac, order=0)

    mlab.figure(1, size=(1000, 800))
    mlab.contour3d(pred, colormap='hot', opacity=1.0, vmin=0, vmax=255)
    mlab.contour3d(mask, colormap='Greys', opacity=0.1, vmin=0, vmax=150)
    mlab.title('Prediction')
    mlab.show()
def plot_lungmask(im_st):
    """Show the outline of the 3D lung mask of an image stack with mayavi.

    Args:
        im_st: image stack indexed as ``im_st[slice, row, col, channel]``;
            only channel 0 is used (4D layout inferred from the indexing --
            confirm with callers).
    """
    # Only one mask channel is ever used, so a 3D buffer is enough.
    mask = np.zeros(im_st.shape[:3])
    print('get mask')
    mask[...] = lungmask3D(im_st[:, :, :, 0], morph=False)
    # Replace the filled mask by its edges, one Canny call per index along axis 1.
    for i in range(im_st.shape[1]):
        mask[:, i, :] = cv2.Canny(np.uint8(maxminscale(mask[:, i, :])), 50, 100)
    print('plotting')
    mlab.contour3d(mask, colormap='Blues', opacity=0.1)
    mlab.show()
def plot_lung_and_tumor(im_st, pred_st, gt_st, res=None, thr=0.5):
    """Plot prediction and ground truth in 3D, each over the lung outline.

    Opens mayavi figure 1 with the thresholded prediction and figure 2 with
    the thresholded ground truth, both drawn with a translucent lung outline.

    Args:
        im_st: image stack indexed as ``im_st[slice, row, col, channel]``;
            channel 0 is used for the lung mask (layout inferred from the
            indexing -- confirm with callers).
        pred_st: prediction stack; channel 1 is plotted. Not modified.
        gt_st: ground-truth stack; channel 1 is plotted. Not modified.
        res: optional spacing sequence. When truthy, the slice axis is scaled
            by ``res[0] / res[1] / 2``; when falsy (``None`` or ``[]``), by
            ``im_st.shape[1] / im_st.shape[0]``.
        thr: values ``>= thr`` become 180 and all others 0, for both the
            prediction and the ground truth.
    """
    # Lung mask from channel 0, reduced to its edges (one Canny call per
    # index along axis 1).
    mask = np.zeros(im_st.shape[:3])
    mask[...] = lungmask3D(im_st[:, :, :, 0], morph=False)
    for i in range(im_st.shape[1]):
        mask[:, i, :] = cv2.Canny(np.uint8(maxminscale(mask[:, i, :])), 50, 100)

    # Threshold copies of the plotted channel only, so the caller's arrays are
    # left untouched; ground truth is thresholded the same way as prediction.
    pred = np.where(pred_st[:, :, :, 1] >= thr, 180, 0).astype(pred_st.dtype)
    gt = np.where(gt_st[:, :, :, 1] >= thr, 180, 0).astype(gt_st.dtype)

    # TODO - make zoom_fac dependent on resolution in z dir (true values)
    if not res:
        zoom_fac = im_st.shape[1] / im_st.shape[0]
    else:
        # NOTE(review): only the slice axis is rescaled here, so the extra /2
        # halves the true aspect ratio -- confirm this is intended.
        zoom_fac = res[0] / res[1] / 2
    factors = [zoom_fac, 1, 1]
    mask = zoom(mask, zoom=factors)
    pred = zoom(pred, zoom=factors)
    gt = zoom(gt, zoom=factors)

    mlab.figure(1, size=(1000, 800))
    mlab.contour3d(pred, colormap='hot', opacity=1.0, vmin=0, vmax=255)
    mlab.contour3d(mask, colormap='Greys', opacity=0.1, vmin=0, vmax=150)
    mlab.title('Prediction')

    mlab.figure(2, size=(1000, 800))
    mlab.contour3d(gt, colormap='hot', opacity=1.0, vmin=0, vmax=255)
    mlab.contour3d(mask, colormap='Greys', opacity=0.1, vmin=0, vmax=150)
    mlab.title('Ground truth')
    mlab.show()
# Name stem of the created files.
name = 'pat_12'
# exist_ok avoids the exists()/mkdir() race and also creates missing parents.
os.makedirs(save_path, exist_ok=True)

# read dicom
images, res, offset = import_images(image_path)

# predict/segment with unet
unet_pred = pred_unet(images, unet_model_path)

# classify with VGG
VGG_pred = pred_VGG(images, unet_pred, res, VGG_model_path)

# ---- To CustusX ----
# save segmentation in stl format
pred_to_stl(unet_pred, res, save_path, name)

# save raw and mhd
pred_to_raw(unet_pred, save_path, name, res, offset)

# save lungmask as raw with mhd
mask = lungmask3D(images, morph=False)
# Clamp intensities to [-1024, 400] in place.
np.clip(images, -1024, 400, out=images)
lung = images.copy()
# Push everything outside the lung to the minimum, then rescale to [0, 1].
lung[mask == 0] = np.amin(lung)
lung = lung - np.amin(lung)
if np.amax(lung) > 0:
    lung = lung / np.amax(lung)
else:
    # Constant volume (e.g. empty lung mask): avoid 0/0 -> NaN.
    lung = np.zeros(lung.shape)
pred_to_raw(lung, save_path, name + '_lung', res, offset)
def test2d3d(im_st, pred_st, res, thr=0.1, plot_size_xy=150, VGG_model_path=''):
    """Interactive 2D/3D viewer and editor for a predicted nodule volume.

    Opens a matplotlib window showing one CT slice at a time with the
    thresholded prediction overlaid, and a mayavi scene with a downsampled 3D
    contour of the prediction. Controls:

    * ``threshold`` slider: display threshold for both views.
    * ``slice`` slider, up/down arrow keys, mouse wheel: slice navigation.
    * ``lung``: toggle the lung outline in the 3D view.
    * ``Remove``: click in the slice view to delete the connected component
      under the cursor.
    * ``Add``: click a seed point; after ENTER in the terminal a segment is
      grown with ``level_set3D`` and added to the prediction.
    * ``Classify``: colour components in 3D by the output of ``pred_VGG``.
    * ``Esc``: close the current matplotlib figure and the mayavi scene.

    Args:
        im_st: CT stack indexed as ``im_st[slice, row, col]`` (assumed 3D --
            confirm with callers).
        pred_st: prediction stack, assumed to have the same shape as
            ``im_st``. NOTE: edited in place by the Remove/Add tools, so the
            caller sees the edits.
        res: spacing sequence; ``res[0] / res[1]`` scales the slice axis of
            the 3D view, and ``reversed(res)`` is passed to ``level_set3D``.
        thr: threshold used once, up front, to label the connected components
            that the Remove tool deletes and to count nodules.
        plot_size_xy: in-plane size of the downsampled 3D view.
        VGG_model_path: model path forwarded to ``pred_VGG``.
    """
    im_ct = im_st.copy()
    # The in-plane axes shrink to plot_size_xy; the slice axis gets the same
    # shrink times the spacing ratio res[0] / res[1].
    zoom_fac = [res[0]/res[1]/(im_st.shape[1]/plot_size_xy),
                plot_size_xy/im_st.shape[1],
                plot_size_xy/im_st.shape[1]]

    def _binarize(arr, threshold):
        """Return a copy of arr with values < threshold set to 0, the rest to 1."""
        out = arr.copy()
        out[out < threshold] = 0
        out[out >= threshold] = 1
        return out

    # Lung outline for the 3D view, computed on the downsampled stack.
    im_zoomed = zoom(im_st, zoom=zoom_fac, order=1)
    mask = np.zeros(im_zoomed.shape)
    mask[...] = lungmask3D(im_zoomed, morph=False)
    for i in range(im_zoomed.shape[1]):
        mask[:, i, :] = cv2.Canny(np.uint8(maxminscale(mask[:, i, :])), 50, 100)

    # Downsampled prediction for the 3D view (order=0 keeps the values).
    pred_st_tmp = zoom(pred_st, zoom=zoom_fac, order=0)

    # Connected components at `thr`; Remove deletes whole components.
    label_3d = label(_binarize(pred_st_tmp, thr))
    label_2d = label(_binarize(pred_st, thr))
    print('Number of nodules found: ' + str(len(np.unique(label_2d))-1))

    mlab.figure(1, size=(1000, 800))
    mlab.contour3d(pred_st_tmp, colormap='hot', opacity=1.0, vmin=0, vmax=1)

    # ---------------------------------------------------------------- drawing
    def _draw_3d(threshold):
        """Redraw the 3D scene: thresholded prediction (+ lung if checked)."""
        mlab.clf()
        mlab.contour3d(_binarize(pred_st_tmp, threshold), colormap='hot',
                       opacity=1.0, vmin=0, vmax=1)
        if button1.get_status()[0]:
            mlab.contour3d(mask, colormap='Greys', opacity=0.1, vmin=0, vmax=150)

    def _draw_slice(threshold, title=False, with_cursor=False):
        """Redraw the current slice with the thresholded prediction overlaid."""
        idx = int(slider2.val)
        ax.clear()
        ax.imshow(im_ct[idx, :, :], cmap='gray')
        ax.imshow(_binarize(pred_st[idx], threshold), cmap='gnuplot',
                  vmin=0, vmax=2, alpha=0.3)
        if with_cursor:
            # NOTE(review): this Cursor is not stored, so it may be
            # garbage-collected once the handler returns; confirm the
            # crosshair survives the redraw.
            Cursor(ax, useblit=True, color='orange', linewidth=0.5)
        if title:
            ax.set_title('predicted')
        ax.set_axis_off()
        f.suptitle('slice ' + str(idx))
        f.canvas.draw_idle()

    def _draw_classification():
        """Colour each class returned by pred_VGG in the 3D scene."""
        pred_class = pred_VGG(im_ct, pred_st, res, VGG_model_path)
        pred_class = zoom(pred_class, zoom=zoom_fac, order=0)
        mlab.clf()
        classes = np.unique(pred_class)
        print(classes[1:])
        for value in classes:
            if value == 0:
                continue
            tmp = pred_class.copy()
            tmp[pred_class != value] = 0
            # (9/5)*value reaches 9 for class 5, one past the 9-row palette;
            # clamp to the valid index range.
            color_idx = min(max(int(round((9/5)*value)), 0), len(colors) - 1)
            mlab.contour3d(tmp, colormap='OrRd', color=tuple(colors[color_idx, :]),
                           vmin=1, vmax=5)
        mlab.scalarbar(orientation='vertical', nb_labels=9, label_fmt='%.1f')
        if button1.get_status()[0]:
            mlab.contour3d(mask, colormap='Greys', opacity=0.1, vmin=0, vmax=150)
        mlab.orientation_axes(xlabel='z', ylabel='y', zlabel='x')

    # ------------------------------------------------------------ navigation
    def _step_slice(delta):
        """Move the slice slider by delta, staying within [0, n_slices - 1]."""
        new_val = slider2.val + delta
        if 0 <= new_val <= im_ct.shape[0] - 1:
            slider2.set_val(new_val)

    def on_key_step(event):
        if event.key == "up":
            _step_slice(1)
        elif event.key == "down":
            _step_slice(-1)

    def on_scroll(event):
        if event.button == 'up':
            _step_slice(1)
        elif event.button == 'down':
            _step_slice(-1)

    def on_threshold(event):
        """Threshold slider callback; event is the new threshold."""
        _draw_3d(event)
        _draw_slice(event)

    def on_slice(event):
        """Slice slider callback (the slider value is already updated)."""
        _draw_slice(slider1.val, title=True)

    # --------------------------------------------------------------- editing
    def _clicked_pixel(event):
        """Return (slice, row, col) for a click inside the slice view, else None."""
        if event.inaxes is not ax or event.xdata is None or event.ydata is None:
            return None
        return int(slider2.val), int(event.ydata), int(event.xdata)

    def _zoomed_coords(coords):
        """Map full-resolution (slice, row, col) onto label_3d, clipped to its bounds."""
        scaled = (np.array(coords) * np.array(zoom_fac)).astype(int)
        # Flooring after zooming can land one past the last downsampled index.
        return np.clip(scaled, 0, np.array(label_3d.shape) - 1)

    def remove(event):
        """Remove mode: delete the connected component under the click."""
        coords_orig = _clicked_pixel(event)
        if coords_orig is None or not button2.get_status()[0]:
            return
        coords_zoom = _zoomed_coords(coords_orig)
        print(coords_zoom, coords_orig)
        # Label 0 is background; clearing it would wipe every sub-threshold voxel.
        val_3d = label_3d[tuple(coords_zoom)]
        if val_3d != 0:
            pred_st_tmp[label_3d == val_3d] = 0
        val_2d = label_2d[coords_orig]
        if val_2d != 0:
            pred_st[label_2d == val_2d] = 0
        threshold = slider1.val
        _draw_3d(threshold)
        _draw_slice(threshold, with_cursor=True)

    def add(event):
        """Add mode: grow a new segment from the clicked seed with level_set3D.

        Shows an 'Advanced settings' window and blocks on input() in the
        terminal until ENTER is pressed.
        """
        coords_orig = _clicked_pixel(event)
        if coords_orig is None or not button3.get_status()[0]:
            return None

        # default settings for level set
        lambda1 = 1
        lambda2 = 4
        smoothing = 1
        iterations = 100
        rad = 3

        def adv_exit(event):
            print(1)
            plt.close(figs)

        def adv_start(event):
            print(2)
            plt.close(figs)

        def adv_window():
            """Build the settings window; return it, its axis and the five sliders."""
            figs, axx = plt.subplots(num='Advanced settings')
            figs.canvas.mpl_connect('key_press_event', adv_exit)
            axx.axis('off')
            ax_as1 = plt.axes([0.15, 0.02, 0.5, 0.05])
            slider_as1 = Slider(ax_as1, 'lambda1', 0.1, 4, dragging=True, valstep=0.1)
            ax_as2 = plt.axes([0.15, 0.10, 0.5, 0.05])
            slider_as2 = Slider(ax_as2, 'lambda2', 0.1, 4, dragging=True, valstep=0.1)
            ax_as3 = plt.axes([0.15, 0.18, 0.5, 0.05])
            slider_as3 = Slider(ax_as3, 'smoothing', 0, 4, dragging=True, valstep=1)
            ax_as4 = plt.axes([0.15, 0.26, 0.5, 0.05])
            slider_as4 = Slider(ax_as4, 'iterations', 1, 1000, dragging=True, valstep=1)
            ax_as5 = plt.axes([0.15, 0.34, 0.5, 0.05])
            slider_as5 = Slider(ax_as5, 'radius', 0.5, 5, dragging=True, valstep=0.1)
            ax_b1 = plt.axes([0.85, 0.15, 0.07, 0.08])
            ax_b2 = plt.axes([0.85, 0.05, 0.07, 0.08])
            but_as1 = Button(ax_b1, 'exit', color='beige', hovercolor='beige')
            but_as2 = Button(ax_b2, 'start', color='beige', hovercolor='beige')
            textstr = "Press ENTER in terminal to start segmentation. \n Shouldn't be necessairy to change settings below, but can be tuned if \n resulting segmentation is not ideal. \n Especially if small nodule: try setting lambda1 <= lambda2 \n Or if very nonhomogeneous nodule: try setting lambda1 > lambda2."
            props = dict(boxstyle='round', facecolor='wheat')
            # NOTE(review): positioned with the main view's ax.transAxes, not
            # axx.transAxes -- kept as-is; confirm the intended placement.
            axx.text(-0.18, 0.25, textstr, transform=ax.transAxes, fontsize=12,
                     verticalalignment='top', bbox=props)
            slider_as1.set_val(lambda1)
            slider_as2.set_val(lambda2)
            slider_as3.set_val(smoothing)
            slider_as4.set_val(iterations)
            slider_as5.set_val(rad)
            but_as1.on_clicked(adv_exit)
            but_as2.on_clicked(adv_start)
            figs.canvas.draw_idle()
            return figs, axx, slider_as1, slider_as2, slider_as3, slider_as4, slider_as5

        figs, axx, slider_as1, slider_as2, slider_as3, slider_as4, slider_as5 = adv_window()
        figs.show()
        # to start segmentation
        input('Press enter to start segmentation: ')
        plt.close(figs)

        coords_zoom = _zoomed_coords(coords_orig)
        print(coords_zoom, coords_orig)

        # apply level set to grow in 3D from single seed point
        seg_tmp = level_set3D(im_ct, coords_orig, list(reversed(res)),
                              lambda1=slider_as1.val, lambda2=slider_as2.val,
                              smoothing=int(slider_as3.val),
                              iterations=int(slider_as4.val), rad=slider_as5.val)
        if seg_tmp is None:
            print('No nodule was segmented. Try changing parameters...')
            return None

        # Because of interpolation the result is in [0, 1]; re-binarize it.
        seg_tmp[seg_tmp < 0.5] = 0
        seg_tmp[seg_tmp >= 0.5] = 1
        pred_st[seg_tmp == 1] = 1
        # Give the new segment a fresh label so Remove can delete it later.
        # max()+1 cannot collide with an existing label (len(unique) could).
        label_2d[seg_tmp == 1] = label_2d.max() + 1

        seg_tmp_3d = zoom(seg_tmp, zoom=zoom_fac, order=1)
        seg_tmp_3d[seg_tmp_3d < 0.5] = 0
        seg_tmp_3d[seg_tmp_3d >= 0.5] = 1
        pred_st_tmp[seg_tmp_3d == 1] = 1
        label_3d[seg_tmp_3d == 1] = label_3d.max() + 1

        threshold = slider1.val
        _draw_3d(threshold)
        _draw_slice(threshold, with_cursor=True)

    def remove_mode(event):
        """Switching Remove on switches Add off.

        Only act when both boxes are on: set_active() re-fires add_mode, which
        must then see Add already off and leave Remove alone.
        """
        if button2.get_status()[0] and button3.get_status()[0]:
            button3.set_active(0)

    def add_mode(event):
        """Switching Add on switches Remove off (see remove_mode)."""
        if button3.get_status()[0] and button2.get_status()[0]:
            button2.set_active(0)

    def show_lung(event):
        """'lung' checkbox: redraw the 3D scene (classification view if active)."""
        if button4.get_status()[0]:
            _draw_classification()
        else:
            _draw_3d(slider1.val)

    def classify(event):
        """'Classify' checkbox: show classes, or go back to the plain views."""
        if button4.get_status()[0]:
            _draw_classification()
        else:
            threshold = slider1.val
            _draw_3d(threshold)
            _draw_slice(threshold)

    def on_escape(event):
        """Exit the app when ESC is pressed."""
        if event.key == "escape":
            plt.close()
            mlab.close()

    # colormap for the 3D malignancy plot of the classify event (9 colours)
    dz = list(range(1, 10))
    norm = plt.Normalize()
    colors = np.array(plt.cm.OrRd(norm(dz)))[:, :3]

    f, ax = plt.subplots(1, 1, figsize=(12, 12))
    f.canvas.mpl_connect('key_press_event', on_key_step)
    f.canvas.mpl_connect('scroll_event', on_scroll)
    f.canvas.mpl_connect('key_press_event', on_escape)
    # Connect the editing tools exactly once; each handler checks its own
    # checkbox, so toggling a mode never stacks duplicate handlers.
    f.canvas.mpl_connect('button_press_event', remove)
    f.canvas.mpl_connect('button_press_event', add)

    b1ax = plt.axes([0.05, 0.2, 0.15, 0.11])
    b1ax.set_axis_off()
    b2ax = plt.axes([0.05, 0.35, 0.15, 0.11])
    b2ax.set_axis_off()
    b3ax = plt.axes([0.05, 0.5, 0.15, 0.11])
    b3ax.set_axis_off()
    b4ax = plt.axes([0.05, 0.65, 0.15, 0.11])
    b4ax.set_axis_off()
    s1ax = plt.axes([0.25, 0.08, 0.5, 0.03])
    s2ax = plt.axes([0.25, 0.02, 0.5, 0.03])

    slider1 = Slider(s1ax, 'threshold', 0.1, 1.0, dragging=True, valstep=0.05)
    slider2 = Slider(s2ax, 'slice', 0.0, im_ct.shape[0]-1, valstep=1)
    button1 = CheckButtons(b1ax, ['lung'], [False])
    button2 = CheckButtons(b2ax, ['Remove'], [False])
    button3 = CheckButtons(b3ax, ['Add'], [False])
    button4 = CheckButtons(b4ax, ['Classify'], [False])

    slider1.set_val(0.5)
    slider2.set_val(0)
    f.subplots_adjust(bottom=0.15)

    ax.imshow(im_ct[int(slider2.val), :, :], cmap='gray')
    ax.imshow(pred_st[int(slider2.val), :, :], cmap='gnuplot', vmin=0, vmax=2, alpha=0.3)
    # Kept in a local so the widget stays alive while plt.show() runs.
    cursor = Cursor(ax, useblit=True, color='orange', linewidth=0.5)
    ax.set_title('predicted')
    ax.set_axis_off()
    f.suptitle('slice ' + str(int(slider2.val)))

    button1.on_clicked(show_lung)
    button2.on_clicked(remove_mode)
    button3.on_clicked(add_mode)
    button4.on_clicked(classify)
    slider1.on_changed(on_threshold)
    slider2.on_changed(on_slice)
    plt.show()