def get_grad_img(polar_patch):
    """Return the Sobel gradient magnitude of a Gaussian-smoothed polar patch.

    The patch is blurred (sigma=5) and differentiated along both axes with a
    zero-padded Sobel operator. The magnitude is cropped twice with
    croppatch: first around (128, 128) with size arguments 127, then around
    (127, 127 - 2) with size arguments 128, i.e. shifted two columns left.
    Finally the first and last rows are overwritten with their neighbours.
    """
    x_offset = 2
    smoothed = gaussian_filter(polar_patch, sigma=5)
    grad_y = ndimage.sobel(smoothed, axis=0, mode='constant')
    grad_x = ndimage.sobel(smoothed, axis=1, mode='constant')
    magnitude = np.hypot(grad_x, grad_y)

    inner = croppatch(magnitude, 128, 128, 127, 127)
    magnitude = croppatch(inner, 127, 127 - x_offset, 128, 128)

    # Replace the border rows by their inner neighbours.
    magnitude[0] = magnitude[1]
    magnitude[-1] = magnitude[-2]
    return magnitude
def multi_min_dist_map(min_dist_cnn, dicomslice, octy, octx, patchheight, patchwidth):
    """Fuse min-distance CNN predictions from a 5x5 grid of shifted crops.

    Crops of ``dicomslice`` are taken around (octy, octx) shifted by
    -20, -10, 0, 10, 20 pixels along each axis, upsampled x4, normalised by
    their own maximum and predicted as one batch. Each prediction is shifted
    back by 4x its crop offset so all maps line up, normalised by its own
    maximum, and the element-wise maximum over the 25 aligned maps is
    returned.
    """
    step = 10
    radius = 2
    grid_side = 2 * radius + 1

    crops = []
    for dy in np.arange(-radius, radius + 1):
        for dx in np.arange(-radius, radius + 1):
            crop = croppatch(dicomslice, octy + dy * step, octx + dx * step,
                             patchheight, patchwidth)
            crop = cv2.resize(crop, (0, 0), fx=4, fy=4)
            crops.append(crop / np.max(crop))
    crops = np.array(crops)
    predictions = min_dist_cnn.predict(crops[:, :, :, None])[:, :, :, 0]

    aligned = []
    for idx, prediction in enumerate(predictions):
        row, col = divmod(idx, grid_side)
        dy = row - radius
        dx = col - radius
        # Undo the crop offset; predictions live in the x4 upsampled space.
        realigned = croppatch(prediction, 256 - dy * step * 4,
                              256 - dx * step * 4, 256, 256)
        aligned.append(realigned / np.max(realigned))
    return np.max(np.array(aligned), axis=0)
def multi_min_dist_pred_component(min_dist_cnn, dicomslice, bb, patchheight, patchwidth, kernelmask=None):
    """Find connected candidate regions of the min-distance map inside ``bb``.

    Channel 0 of the box-restricted NMS map from
    multi_min_dist_pred_withinbb is weighted by ``kernelmask`` (default:
    ``gen_2d_gaussion(map size, 0.5)``) and passed to find_con_region.
    Channel 2 of the resulting visualisation is replaced by the image patch
    around (bb.x, bb.y), upsampled x4 and multiplied by 255.

    Returns:
        (component_bbs, nms_dist_map_p) from find_con_region, with channel 2
        of nms_dist_map_p overwritten as described.
    """
    # NOTE(review): the prediction always uses a 64 patch size while the
    # channel-2 image uses patchheight/patchwidth; the two only line up when
    # callers pass 64 — confirm before using other sizes.
    _unused_cts, withinbb_map = multi_min_dist_pred_withinbb(min_dist_cnn, dicomslice, bb, 64, 64)
    dist_map = withinbb_map[:, :, 0]
    mask = gen_2d_gaussion(dist_map.shape[0], 0.5) if kernelmask is None else kernelmask
    component_bbs, vis_map = find_con_region(dist_map * mask, fig=1)

    center_patch = croppatch(dicomslice, bb.y, bb.x, patchheight, patchwidth)
    vis_map[:, :, 2] = cv2.resize(center_patch, (0, 0), fx=4, fy=4) * 255
    return component_bbs, vis_map
def multi_min_dist_pred(min_dist_cnn, dicomslice, octy, octx, patchheight, patchwidth, kernelmask=None):
    """Predict candidate vessel centres around (octx, octy).

    The fused map from multi_min_dist_map is weighted by ``kernelmask``
    (default: ``gen_2d_gaussion(map size, 0.5)``) and passed to
    find_nms_center. Channel 2 of the returned NMS visualisation is replaced
    by the image patch around the centre, upsampled x4 and multiplied by 255.

    Returns:
        (cts, nms_dist_map_p): cts is a list of integer [x, y] centres parsed
        from the 'x-y' keys returned by find_nms_center.
    """
    dist_map = multi_min_dist_map(min_dist_cnn, dicomslice, octy, octx, patchheight, patchwidth)
    if kernelmask is None:
        kernelmask = gen_2d_gaussion(dist_map.shape[0], 0.5)
    dist_pred_ct, nms_dist_map_p = find_nms_center(dist_map * kernelmask, fig=1)

    center_patch = croppatch(dicomslice, octy, octx, patchheight, patchwidth)
    nms_dist_map_p[:, :, 2] = cv2.resize(center_patch, (0, 0), fx=4, fy=4) * 255

    cts = []
    for key in list(dist_pred_ct.keys()):
        fields = key.split('-')
        cts.append([int(fields[0]), int(fields[1])])
    if not cts:
        print('no ct')
    return cts, nms_dist_map_p
def load_aug_patch(self, slicei, type):
    """Load the augmented batch of kind ``type`` for slice ``slicei``.

    Args:
        slicei: slice index; returns None if ``self.valid_slicei`` rejects it.
        type: one of
            'polar_patch' / 'cart_patch' / 'polar_cont' - load the stored
                ``.npy`` batch whose relative name is kept under
                ``self.caselist['slices'][slicei]['augpatch']``;
            'cart_label' / 'polar_label' - build the vessel-wall label for
                every augmentation offset from ``self.load_aug_off`` by
                cropping ``self.load_cart_vw`` (polar_label additionally
                applies topolar).
            Any other value prints 'unknown type' and returns None.

    Returns:
        float64 ndarray for stored batches, a list of labels for the label
        kinds, or None.

    Raises:
        FileNotFoundError: if the stored batch file does not exist.
    """
    if not self.valid_slicei(slicei):
        return None

    if type in ('cart_label', 'polar_label'):
        aug_offs = self.load_aug_off(slicei)
        cart_label_vw = self.load_cart_vw(slicei)
        label_batch = []
        for aug_off in aug_offs:
            ctx = aug_off[0]
            cty = aug_off[1]
            label = croppatch(cart_label_vw, 256 + cty, 256 + ctx, 256, 256)
            if type == 'polar_label':
                label = topolar(label, 256, 256)
            label_batch.append(label)
        return label_batch

    # Stored batch kinds -> key of the relative filename in the case list.
    batch_name_keys = {
        'polar_patch': 'polarpatchbatchname',
        'cart_patch': 'cartpatchbatchname',
        'polar_cont': 'polarcoutbatchname',
    }
    if type not in batch_name_keys:
        print('unknown type')
        return None

    aug_filename = self.targetprefix + \
        self.caselist['slices'][slicei]['augpatch'][batch_name_keys[type]]
    if not os.path.exists(aug_filename):
        print('no aug polar patch file', aug_filename)
        raise FileNotFoundError(aug_filename)
    # np.float was only an alias of the builtin float (float64) and was
    # removed in NumPy 1.24; use the explicit dtype.
    return np.load(aug_filename).astype(np.float64)
def intpath(pos1, pos2, dicomstack):
    """Sample intensities along the straight segment from ``pos1`` to ``pos2``.

    Positions are (x, y, z) and ``dicomstack`` is indexed ``[z][y][x]``. The
    two endpoints are read at their rounded voxels. In between, one sample
    is taken per unit of Euclidean distance: x and y are rounded to the
    nearest voxel and z is linearly interpolated between the two
    neighbouring slices (or read directly when z falls exactly on a slice).

    Returns:
        list of sampled values, starting with pos1's voxel and ending with
        pos2's voxel.
    """
    DEBUG = 0
    pos1 = np.array(pos1)
    pos2 = np.array(pos2)
    pos1int = [int(round(posi)) for posi in pos1]
    pos2int = [int(round(posi)) for posi in pos2]
    direction = pos2 - pos1
    dist = np.linalg.norm(direction)
    # Coincident endpoints: avoid the 0/0 direction; no intermediate samples
    # are taken in that case anyway.
    dirnorm = direction / dist if dist > 0 else np.zeros(direction.shape)

    intp = [dicomstack[pos1int[2]][pos1int[1]][pos1int[0]]]
    if DEBUG:
        pdsp = croppatch(dicomstack[pos1int[2]], pos1int[1], pos1int[0])
        pdsp[pdsp.shape[0] // 2] = np.max(pdsp)
        pdsp[:, pdsp.shape[1] // 2] = np.max(pdsp)
        plt.title('start')
        plt.imshow(pdsp)
        plt.show()

    for stepi in range(1, int(np.floor(dist))):
        cpos = pos1 + dirnorm * stepi
        cposint = np.array([int(np.round(cposi)) for cposi in cpos])
        if DEBUG:
            pdsp = croppatch(dicomstack[cposint[2]], cpos[1], cpos[0])
            pdsp[pdsp.shape[0] // 2] = np.max(pdsp)
            pdsp[:, pdsp.shape[1] // 2] = np.max(pdsp)
            plt.imshow(pdsp)
            plt.title(str(cposint[2]))
            plt.show()
        zceil = int(np.ceil(cpos[2]))
        zfloor = int(np.floor(cpos[2]))
        valfloor = dicomstack[zfloor][cposint[1]][cposint[0]]
        if zceil == zfloor:
            # z lies exactly on a slice: both weights below would be 0 and
            # the weighted average 0/0 = NaN.
            intp.append(valfloor)
            continue
        valceil = dicomstack[zceil][cposint[1]][cposint[0]]
        w1 = zceil - cpos[2]   # weight of the floor slice
        w2 = cpos[2] - zfloor  # weight of the ceil slice
        intp.append(valceil * w2 / (w1 + w2) + valfloor * w1 / (w1 + w2))

    intp.append(dicomstack[pos2int[2]][pos2int[1]][pos2int[0]])
    if DEBUG:
        pdsp = croppatch(dicomstack[pos2int[2]], pos2int[1], pos2int[0])
        pdsp[pdsp.shape[0] // 2] = np.max(pdsp)
        pdsp[:, pdsp.shape[1] // 2] = np.max(pdsp)
        plt.title('end')
        plt.imshow(pdsp)
        plt.show()
    return intp
def crop_dcm_stack(dicomstack, sid, cty, ctx, hps, RESCALE, depth=3):
    """Crop a multi-slice patch centred on (ctx, cty) around slice ``sid``.

    Slice ``sid`` and its ``depth // 2`` neighbours on each side are cropped
    with ``croppatch(..., cty, ctx, hps, hps)``, resized by ``RESCALE`` and
    stacked along the last axis (``sid`` in the middle). Neighbours that fall
    outside the stack keep a copy of the centre slice. The result is divided
    by its maximum.

    Note:
        ``depth`` is presumably odd; an even value would index past the last
        channel — TODO confirm with callers.
    """
    def _crop_rescaled(slice_idx):
        # Every slice uses the same crop size so all channels share a shape.
        patch = croppatch(dicomstack[slice_idx], cty, ctx, hps, hps)
        return cv2.resize(patch, (0, 0), fx=RESCALE, fy=RESCALE)

    cartstackrz = np.repeat(_crop_rescaled(sid)[:, :, None], depth, axis=2)
    nei = depth // 2
    for ni in range(1, nei + 1):
        imgslicep = sid - ni
        if imgslicep >= 0:
            cartstackrz[..., nei - ni] = _crop_rescaled(imgslicep)
        imgslicen = sid + ni
        if imgslicen < dicomstack.shape[0]:
            cartstackrz[..., nei + ni] = _crop_rescaled(imgslicen)
    return cartstackrz / np.max(cartstackrz)
def cont_ct_within_bb(min_dist_cnn, dicomstack, slicei, bb, seq, posstart):
    """Choose one vessel centre on slice ``slicei`` near bounding box ``bb``.

    On the last slice of ``seq`` (``slicei == seq[-1][0]``) candidates come
    from multi_min_dist_pred_withinbb; otherwise from multi_min_dist_pred
    around (bb.x, bb.y). Candidates in the same connected region are merged
    (mergects) and converted to DICOM coordinates (to_dcm_cord). With
    several candidates, the last slice takes the one with the highest value
    in channel 0 of the NMS map; other slices defer to ct_with_min_change.

    Returns:
        (x, y) of the chosen centre in DICOM coordinates.

    Raises:
        ValueError: if no candidate centre is found.
    """
    DEBUG = 0
    is_last_slice = slicei == seq[-1][0]
    if is_last_slice:
        cts, nms_dist_map_p = multi_min_dist_pred_withinbb(
            min_dist_cnn, dicomstack[slicei], bb, 64, 64)
    else:
        cts, nms_dist_map_p = multi_min_dist_pred(min_dist_cnn, dicomstack[slicei],
                                                  bb.y, bb.x, 64, 64)
    # merge multiple cts in the same connected region
    merge_cts = mergects(nms_dist_map_p[:, :, 0], cts)
    # convert cts to dicom coordinates
    cts_dcm = to_dcm_cord(merge_cts, bb)
    if DEBUG:
        plt.imshow(nms_dist_map_p)
        plt.show()

    if len(cts_dcm) == 0:
        print('no cts')
        # Previously this fell through to returning unbound bbx/bby
        # (UnboundLocalError); fail with an explicit message instead.
        raise ValueError('no centre candidate found on slice %s' % (slicei,))

    if len(cts_dcm) == 1:
        bbx = cts_dcm[0][0]
        bby = cts_dcm[0][1]
    elif is_last_slice:
        # last slice: pick the candidate with the strongest map response
        mapint = [
            nms_dist_map_p[int(round(ct[1])), int(round(ct[0])), 0]
            for ct in merge_cts
        ]
        print('mapint', mapint)
        ica = np.argmax(mapint)
        bbx = cts_dcm[ica][0]
        bby = cts_dcm[ica][1]
    else:
        # other slices: candidate chosen by ct_with_min_change
        ctm = ct_with_min_change(posstart, slicei, cts_dcm, dicomstack)
        bbx = ctm[0]
        bby = ctm[1]
    if DEBUG:
        plt.imshow(croppatch(dicomstack[slicei], bby, bbx))
        plt.show()
    return bbx, bby
def displayvw(self, figfilename=None):
    """Plot a cross-haired crop of ``self.plotvw(slicei)`` for each segmented slice.

    Slices whose ``segct`` or lumen contour is None are skipped. Each crop
    (croppatch with size arguments 64, 64) is centred on the mean of the
    slice's ``segct`` points, and a centre row and column are set to the
    crop maximum. Subplots are laid out on 2 rows. The figure is saved to
    ``figfilename`` when given, otherwise shown, then closed.
    """
    plt.figure(figsize=(18, 5))
    segcts = self.seg['segct']
    for slicei in range(len(segcts)):
        if segcts[slicei] is None or self.seg['cartcont']['lumen'][slicei] is None:
            continue
        centre = np.mean(segcts[slicei], axis=0)
        patch = croppatch(self.plotvw(slicei), centre[1], centre[0], 64, 64)
        peak = np.max(patch)
        patch[:, patch.shape[1] // 2] = peak
        patch[patch.shape[1] // 2] = peak
        n_cols = int(np.ceil(len(self.dicomstack) / 2))
        plt.subplot(2, n_cols, slicei + 1)
        plt.imshow(patch)
        plt.title(slicei)
    if figfilename is None:
        plt.show()
    else:
        plt.savefig(figfilename)
    plt.close()
def load_dcm_stack(dicompath, norm=0, seq='101', refscale=1, shift_x=0, shift_y=0):
    """Load the ``<exam>S<seq>I<n>.dcm`` slices of ``dicompath`` into one array.

    ``<exam>`` is the last '/'-separated component of ``dicompath``. Slice n
    is stored at index n-1 of a ``(max slice number, 512, 512)`` array;
    missing slices are reported and left as zeros. Each image is brought to
    512x512 (non-square images are zero-padded at the bottom/right to a
    square first), optionally resized by ``refscale``, optionally re-cropped
    with croppatch around (256 + shift_y, 256 + shift_x), and divided by its
    own maximum.

    Args:
        norm: currently unused.

    Returns:
        the stacked float array, or None if no matching file exists.
    """
    einame = dicompath.split('/')[-1]
    imgnamepattern = os.path.join(dicompath, einame + 'S' + seq + 'I*.dcm')
    targetimgs = glob.glob(imgnamepattern)
    if len(targetimgs) == 0:
        print('No dcm', imgnamepattern)
        return
    # File names end in 'I<slice number>.dcm'.
    slices = [int(i.split('I')[-1][:-4]) for i in targetimgs]
    min_slice = np.min(slices)
    max_slice = np.max(slices)
    print(slices, min_slice, max_slice)
    cartimgstack = np.zeros((max_slice, 512, 512))
    imgnamepattern = os.path.join(dicompath, einame + 'S' + seq + 'I%d.dcm')
    if min_slice != 1:
        print('min slice not 1')
    for slicei in range(min_slice, max_slice + 1):
        imgfilename = imgnamepattern % slicei
        if not os.path.exists(imgfilename):
            print('no slice img for', imgfilename)
            continue
        # dcmread replaces read_file, which pydicom 3.0 removed.
        RefDs = pydicom.dcmread(imgfilename).pixel_array
        if RefDs.shape != (512, 512):
            if RefDs.shape[0] != RefDs.shape[1]:
                # Zero-pad to a square (content at top-left) before resizing.
                padarr = np.zeros((max(RefDs.shape), max(RefDs.shape)))
                padarr[0:RefDs.shape[0], 0:RefDs.shape[1]] = RefDs
                RefDs = padarr
            RefDs = cv2.resize(RefDs, (512, 512))
        if refscale != 1:
            RefDs = cv2.resize(RefDs, (0, 0), fx=refscale, fy=refscale)
        if shift_y != 0 or shift_x != 0:
            RefDs = croppatch(RefDs, (shift_y + 256), (shift_x + 256), 256, 256)
        peak = np.max(RefDs)
        # An all-zero slice stays zero instead of becoming NaN (0/0).
        if peak != 0:
            cartimgstack[slicei - 1] = RefDs / peak
    return cartimgstack
def display_tracklet_dcm_patch(tracklet, dicomstack, figfilename=None):
    """Plot, per slice, an image crop around each tracklet box with a box outline.

    For every box ``bb`` in ``tracklet[slicei]`` the slice is cropped with
    croppatch around (bb.y, bb.x) using size arguments (bb.h * 2, bb.w * 2),
    and four lines (rows/columns at round(h/2) and 3x that) are set to the
    crop maximum. Subplots are laid out on 2 rows. The figure is saved to
    ``figfilename`` when given, otherwise shown, then closed.
    """
    plt.figure(figsize=(18, 5))
    for slicei in range(len(tracklet)):
        for bb in tracklet[slicei]:
            patch = croppatch(dicomstack[slicei], bb.y, bb.x, bb.h * 2, bb.w * 2)
            qw = int(round(bb.w / 2))
            qh = int(round(bb.h / 2))
            # Painting with the maximum does not change the maximum, so one
            # lookup serves all four edges.
            edge_value = np.max(patch)
            patch[qh, qw:qw * 3] = edge_value
            patch[qh:qh * 3, qw] = edge_value
            patch[qh * 3, qw:qw * 3] = edge_value
            patch[qh:qh * 3, qw * 3] = edge_value
            n_cols = int(np.ceil(len(dicomstack) / 2))
            plt.subplot(2, n_cols, slicei + 1)
            plt.imshow(patch)
            plt.title(slicei)
    if figfilename is None:
        plt.show()
    else:
        plt.savefig(figfilename)
    plt.close()
def gen_aug_patch(caseloader, slicei, dcmstack=None):
    """Generate, save and register augmented training patches for one slice.

    Candidate centres are the NMS centres of the slice's min-distance map
    (find_nms_center) plus up to five k-means cluster centres of the pixels
    whose NMS response exceeds 128. For each candidate centre this collects:
    the cartesian patch around it (max-normalised), its polar transform
    (max-normalised), and the lumen/wall contours re-expressed in polar
    coordinates around it. The three batches are saved as float16 ``.npy``
    files under ``DATADESKTOPdir + '/DVWIMAGES/'`` in the
    ``casepatch/<art>/<pjname>/{augcart,augpolar,augpolarcont}/``
    directories (created if missing). Their relative names and per-patch
    offsets are stored in ``caseloader.caselist['slices'][slicei]['augpatch']``,
    and the case list is re-pickled to ``caseloader.caselistname``.
    """
    DEBUG = 0
    N_KMEANS_CENTERS = 5
    ctx, cty = caseloader.loadct(slicei)
    if dcmstack is None:
        dcmstack = caseloader.loadstack(slicei, 'dcm')
    cartcont = caseloader.load_cart_cont(slicei)
    mindist = caseloader.load_cart_min_dist(slicei)
    dist_pred_ct_gt, nms_dist_map_p_gt = find_nms_center(mindist, fig=1)

    # Extra centres: k-means cluster centres of the strong-response pixels.
    validy, validx = np.where(nms_dist_map_p_gt[:, :, 0] > 128)
    xypos = [[validx[i], validy[i]] for i in range(len(validx))]
    if xypos:
        # KMeans raises when there are fewer samples than clusters.
        n_clusters = min(N_KMEANS_CENTERS, len(xypos))
        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(xypos)
        for kmcti in kmeans.cluster_centers_:
            kmcx = int(round(kmcti[0]))
            kmcy = int(round(kmcti[1]))
            if nms_dist_map_p_gt[kmcy, kmcx, 0] < 127:
                print('kmean cluster center not large enough', kmcti,
                      nms_dist_map_p_gt[kmcy, kmcx, 0])
            dist_pred_ct_gt['%d-%d' % (kmcx, kmcy)] = nms_dist_map_p_gt[kmcy, kmcx, 0] / 256

    targetdir = DATADESKTOPdir + '/DVWIMAGES/'
    case_subdir = '/casepatch/' + caseloader.art + '/' + caseloader.pjname + '/'
    batch_basename = caseloader.pi + os.path.basename(caseloader.dcmpath(slicei))[:-4] + \
        caseloader.side + '.npy'

    def _batch_filename(kind_dir):
        # Relative batch filename for one output kind; makes sure its
        # directory (and any missing parent) exists under targetdir.
        rel_dir = case_subdir + kind_dir + '/'
        os.makedirs(targetdir + rel_dir, exist_ok=True)
        return rel_dir + batch_basename

    new_cart_patch_batch_filename = _batch_filename('augcart')
    new_polar_patch_batch_filename = _batch_filename('augpolar')
    new_polar_cont_batch_filename = _batch_filename('augpolarcont')
    aug_patch_obj = {
        'cartpatchbatchname': new_cart_patch_batch_filename,
        'polarpatchbatchname': new_polar_patch_batch_filename,
        'polarcoutbatchname': new_polar_cont_batch_filename,
        'auginfo': [],
    }

    aug_cart_patch_batch = []
    aug_polar_patch_batch = []
    aug_polar_cont_batch = []
    for cti in dist_pred_ct_gt.keys():
        # Keys are 'x-y' in map coordinates; the map centre (256, 256)
        # corresponds to the slice centre (ctx, cty).
        cctx = int(round(int(cti.split('-')[0]) + ctx - 256))
        ccty = int(round(int(cti.split('-')[1]) + cty - 256))
        tx = cctx - ctx
        ty = ccty - cty

        new_cart_patch = croppatch(dcmstack, ccty, cctx, 256, 256)
        new_cart_patch = new_cart_patch / np.max(new_cart_patch)
        aug_cart_patch_batch.append(new_cart_patch)

        new_polar_patch = topolar(new_cart_patch, 256, 256)
        aug_polar_patch_batch.append(new_polar_patch / np.max(new_polar_patch))

        # rebase polar contours on the new centre
        new_polar_cont_lumen = tocordpolar(cartcont[0], cctx, ccty)
        new_polar_cont_wall = tocordpolar(cartcont[1], cctx, ccty)
        aug_polar_cont = np.zeros((256, 2))
        aug_polar_cont[:, 0] = new_polar_cont_lumen
        aug_polar_cont[:, 1] = new_polar_cont_wall
        aug_polar_cont_batch.append(aug_polar_cont)

        if DEBUG:
            # check polar ct matches with cart patch
            cart_patch_dsp = np.zeros(
                (new_cart_patch.shape[0], new_cart_patch.shape[1], 3))
            cart_patch_dsp[:, :, 0] = new_cart_patch[:, :, new_cart_patch.shape[2] // 2]
            polarbd = np.concatenate(
                [new_polar_cont_lumen[:, None], new_polar_cont_wall[:, None]], axis=1) * 256
            contourin, contourout = toctbd(polarbd, 256, 256)
            cart_vw_seg = plotct(512, contourin, contourout)
            plt.imshow(cart_vw_seg)
            cart_patch_dsp[:, :, 1] = cart_vw_seg
            plt.imshow(cart_patch_dsp)
            plt.show()

        aug_patch_obj['auginfo'].append({
            'ctofx': cctx,
            'ctofy': ccty,
            'transx': tx,
            'transy': ty,
            'augid': len(aug_cart_patch_batch) - 1,
        })

    np.save(targetdir + new_cart_patch_batch_filename,
            np.array(aug_cart_patch_batch, dtype=np.float16))
    np.save(targetdir + new_polar_patch_batch_filename,
            np.array(aug_polar_patch_batch, dtype=np.float16))
    np.save(targetdir + new_polar_cont_batch_filename,
            np.array(aug_polar_cont_batch, dtype=np.float16))
    caseloader.caselist['slices'][slicei]['augpatch'] = aug_patch_obj
    with open(caseloader.caselistname, 'wb') as fp:
        pickle.dump(caseloader.caselist, fp)
def multi_min_dist_pred_withinbb(min_dist_cnn, dicomslice, bb, patchheight, patchwidth):
    """Predict vessel centres inside ``bb`` that share the map centre's region.

    The fused map from multi_min_dist_map around (bb.x, bb.y) is cropped to
    the box (size scaled by 4 and enlarged by 1.2) and pasted into an
    all-zero map with fillpatch, then NMS centres are found. The NMS
    response is Otsu-thresholded and labelled with 4-connectivity. Only
    centres whose label equals the label at pixel (256, 256) are kept.
    Channel 2 of the returned visualisation holds the image patch around the
    box centre, upsampled x4 and rescaled to a maximum of 255.

    Returns:
        (cts, nms_dist_map_p): cts is a list of integer [x, y] map centres.
    """
    scale = 4
    enlarge = 1.2
    dist_map = multi_min_dist_map(min_dist_cnn, dicomslice, bb.y, bb.x, patchheight, patchwidth)
    box_region = croppatch(dist_map, dist_map.shape[1] / 2, dist_map.shape[0] / 2,
                           bb.h / 2 * scale * enlarge, bb.w / 2 * scale * enlarge)
    boxed_map = fillpatch(np.zeros(dist_map.shape), box_region)

    dist_pred_ct, nms_dist_map_p = find_nms_center(boxed_map, fig=1)

    # Label the connected regions of the thresholded NMS response.
    response = nms_dist_map_p[:, :, 0]
    response_u8 = (response / np.max(response) * 255).astype(np.uint8)
    _, binary = cv2.threshold(response_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    labels = cv2.connectedComponentsWithStats(binary, 4, cv2.CV_32S)[1]
    centre_label = labels[256, 256]

    cts = []
    for key in dist_pred_ct:
        fields = key.split('-')
        y_idx = int(fields[1])
        x_idx = int(fields[0])
        if labels[y_idx, x_idx] == centre_label:
            cts.append([x_idx, y_idx])

    centre_patch = cv2.resize(croppatch(dicomslice, bb.y, bb.x, patchheight, patchwidth),
                              (0, 0), fx=4, fy=4)
    nms_dist_map_p[:, :, 2] = centre_patch / np.max(centre_patch) * 255
    return cts, nms_dist_map_p
def polar_seg_slice(vwcnn, vwcfg, dicomstack, slicei, cts, SCALE=4, usegrad=False, ITERPRED=True, cartreg=False):
    """Segment lumen and outer-wall contours on slice ``slicei`` around ``cts``.

    With exactly one centre, the stack from ``crop_dcm_stack(..., 64, 1)``
    around it is max-normalised, converted with ``topolar`` and upsampled by
    ``SCALE`` before prediction:

    * ITERPRED False, cartreg True: ``vwcnn.predict`` on the cartesian stack
      (resized x2) gives the contours directly; no consistency is computed.
    * ITERPRED False, cartreg False: ``polar_pred_cont_cst`` predicts polar
      boundaries, converted with ``toctbd(polarbd / SCALE, ctx, cty)``.
    * ITERPRED True: the polar prediction is repeated (up to 20 times). Each
      time the centre moves by half of ``cal_polar_offset``'s offset
      (components with magnitude < 1 are zeroed; from the third iteration on
      the move is clamped to MAX_OFF_STEP, which is halved on sign changes).
      The loop stops once both offset components are < 1 or MAX_OFF_STEP
      drops below 0.1. The final attempt re-uses the centre with the smallest
      offset seen.

    With any other number of centres, each centre is segmented once, the
    contours are combined with ``mergecont`` and the centre becomes the mean
    of ``cts``.

    Returns:
        ([contourin, contourout], polarconsistency, recenters), where
        recenters is [[ctx, cty]] of the final centre.
    """
    DEBUG = 0
    contourin = None
    contourout = None
    polarconsistency = None
    recenters = None
    if DEBUG:
        print(slicei, 'cts', cts)
    if len(cts) == 1:
        ct = cts[0]
        ctx = ct[0]
        cty = ct[1]
        if ITERPRED == False:
            cartstack = crop_dcm_stack(dicomstack, slicei, cty, ctx, 64, 1)
            cartstack = cartstack / np.max(cartstack)
            polarimg = topolar(cartstack, 64, 64)
            polar_img_rz = cv2.resize(polarimg, (0, 0), fx=SCALE, fy=SCALE)
            if cartreg:
                # Cartesian regression: predicted boundary coordinates are
                # scaled by 256 and shifted by 256 before conversion to DICOM
                # coordinates around (ctx, cty).
                cart_stack_rz = cv2.resize(cartstack, (0, 0), fx=2, fy=2)
                contbd = vwcnn.predict(cart_stack_rz[None, :, :, :, None])[0] * 256
                contour_in_roi = contbd[:, :2] + 256
                contour_out_roi = contbd[:, 2:] + 256
                contourin = to_dcm_cord(contour_in_roi, BB(ctx, cty, 0, 0))
                contourout = to_dcm_cord(contour_out_roi, BB(ctx, cty, 0, 0))
                polarconsistency = None
            else:
                polarbd, polarsd = polar_pred_cont_cst(polar_img_rz[:, :, :, None], vwcfg, vwcnn, usegrad=usegrad)
                polarconsistency = 1 - np.mean(polarsd, axis=0)
                # contour positions [x, y] in original dicom space
                contourin, contourout = toctbd(polarbd / SCALE, ctx, cty)
                if DEBUG:
                    # NOTE(review): debug display overwrites contourin/contourout
                    # with patch-space contours.
                    contourin, contourout = toctbd(polarbd, 256, 256)
                    cartvw = plotct(512, contourin, contourout)
                    plt.subplot(1, 2, 1)
                    cartstackdsp = copy.copy(cartstack)
                    cartstackdsp[:, cartstackdsp.shape[1] // 2] = np.max(cartstackdsp)
                    cartstackdsp[cartstackdsp.shape[0] // 2] = np.max(cartstackdsp)
                    plt.imshow(cartstackdsp)
                    plt.subplot(1, 2, 2)
                    cartvwdsp = copy.copy(cartvw)
                    cartvwdsp[:, cartvwdsp.shape[1] // 2] = np.max(cartvwdsp)
                    cartvwdsp[cartvwdsp.shape[0] // 2] = np.max(cartvwdsp)
                    plt.imshow(cartvwdsp)
                    plt.show()
        else:
            # iterative prediction that refines the centre
            REPS = 20
            MOVEFACTOR = 0.5
            MAX_OFF_STEP = 1
            # best (smallest-offset) centre seen so far
            mincty = cty
            minctx = ctx
            mindiff = np.inf
            lastoffset = [0, 0]
            if DEBUG:
                print('init', ctx, cty)
            for repi in range(REPS):
                # Last attempt, or step size collapsed: fall back to the best centre.
                if repi == REPS - 1 or MAX_OFF_STEP < 0.1:
                    if DEBUG:
                        print('use min diff cts', minctx, mincty, 'with mindiff', mindiff)
                    cty = mincty
                    ctx = minctx
                cartstack = crop_dcm_stack(dicomstack, slicei, cty, ctx, 64, 1)
                cartstack = cartstack / np.max(cartstack)
                polarimg = topolar(cartstack, 64, 64)
                polar_img_rz = cv2.resize(polarimg, (0, 0), fx=SCALE, fy=SCALE)
                polarbd, polarsd = polar_pred_cont_cst(polar_img_rz[:, :, :, None], vwcfg, vwcnn, usegrad=usegrad)
                if DEBUG:
                    contourin, contourout = toctbd(polarbd, 256, 256)
                    cartvw = plotct(512, contourin, contourout)
                    plt.subplot(1, 2, 1)
                    cartstackdsp = copy.copy(cartstack)
                    cartstackdsp[:, cartstackdsp.shape[1] // 2] = np.max(cartstackdsp)
                    cartstackdsp[cartstackdsp.shape[0] // 2] = np.max(cartstackdsp)
                    plt.imshow(cartstackdsp)
                    plt.subplot(1, 2, 2)
                    cartvwdsp = copy.copy(cartvw)
                    cartvwdsp[:, cartvwdsp.shape[1] // 2] = np.max(cartvwdsp)
                    cartvwdsp[cartvwdsp.shape[0] // 2] = np.max(cartvwdsp)
                    plt.imshow(cartvwdsp)
                    plt.show()
                # offset in 512 dcm coordinates
                polar_cont_offset = cal_polar_offset(polarbd)
                # NOTE(review): `==` assumes cal_polar_offset returns a list;
                # a NumPy array would make this comparison ambiguous — confirm.
                if repi > 0 and MAX_OFF_STEP > 0.1 and lastoffset == polar_cont_offset:
                    # The centre did not change the offset: force the fallback
                    # to the best centre on the next iteration.
                    #MAX_OFF_STEP += 0.5/SCALE
                    MAX_OFF_STEP = 0.09
                    if DEBUG:
                        print('move ct no change', MAX_OFF_STEP, lastoffset, polar_cont_offset)
                    continue
                # sign pattern of the offset; a change means we overshot
                cofftype = [polar_cont_offset[0] > 0, polar_cont_offset[1] > 0]
                if repi == 0:
                    offtype = cofftype
                # NOTE(review): nesting of `offtype = cofftype` was reconstructed
                # from a flattened copy of this file — verify against history.
                if repi > 2 and cofftype != offtype:
                    MAX_OFF_STEP /= 2
                    if DEBUG:
                        print('reduce max move to', MAX_OFF_STEP)
                    offtype = cofftype
                polarconsistency = 1 - np.mean(polarsd, axis=0)
                #ccstl = polarconsistency[0]
                #ccstw = polarconsistency[1]
                # contour positions [x, y] in original dicom space
                contourin, contourout = toctbd(polarbd / SCALE, ctx, cty)
                cdif = np.max(abs(np.array(polar_cont_offset)))
                # converged (both offset components < 1) or step size collapsed
                if cdif < 1 or MAX_OFF_STEP < 0.1:
                    if cdif < 1:
                        print(polar_cont_offset)
                    print('==', slicei, 'Break tracklet ref', polar_cont_offset)
                    break
                if cdif < mindiff:
                    mindiff = cdif
                    mincty = cty
                    minctx = ctx
                # ignore sub-pixel offset components
                cofx = polar_cont_offset[0]
                cofy = polar_cont_offset[1]
                if abs(polar_cont_offset[0]) < 1:
                    cofx = 0
                if abs(polar_cont_offset[1]) < 1:
                    cofy = 0
                if repi < 2:
                    # first two moves are unclamped
                    ctx += cofx * MOVEFACTOR
                    cty += cofy * MOVEFACTOR
                else:
                    ctx = ctx + max(-MAX_OFF_STEP, min(MAX_OFF_STEP, cofx * MOVEFACTOR))
                    cty = cty + max(-MAX_OFF_STEP, min(MAX_OFF_STEP, cofy * MOVEFACTOR))
                print('repeat', repi, 'offset', polar_cont_offset, ctx, cty)
                lastoffset = polar_cont_offset
        recenters = [[ctx, cty]]
    else:
        # NOTE(review): len(cts) == 0 also lands here; np.mean of an empty
        # cts gives NaN — confirm callers never pass an empty list.
        print('multiple cts', cts)
        all_contour_in = []
        all_contour_out = []
        all_polarconsistency = []
        for ct in cts:
            ctx = ct[0]
            cty = ct[1]
            cartstack = crop_dcm_stack(dicomstack, slicei, cty, ctx, 64, 1)
            cartstack = cartstack / np.max(cartstack)
            polarimg = topolar(cartstack, 64, 64)
            polar_img_rz = cv2.resize(polarimg, (0, 0), fx=SCALE, fy=SCALE)
            # NOTE(review): usegrad is not forwarded here, unlike the
            # single-centre path — intentional? TODO confirm.
            polarbd, polarsd = polar_pred_cont_cst(polar_img_rz[:, :, :, None], vwcfg, vwcnn)
            polarconsistency_c = 1 - np.mean(polarsd, axis=0)
            contour_in_c, contour_out_c = toctbd(polarbd / SCALE, ctx, cty)
            if DEBUG:
                plt.subplot(1, 2, 1)
                plt.imshow(polarimg)
                plt.subplot(1, 2, 2)
                polarseg = plotpolar(256, polarbd)
                plt.imshow(polarseg)
                plt.show()
            all_contour_in.append(contour_in_c)
            all_contour_out.append(contour_out_c)
            all_polarconsistency.append(polarconsistency_c)
        contourin = mergecont(all_contour_in, cts)
        contourout = mergecont(all_contour_out, cts)
        ctx, cty = np.mean(cts, axis=0)
        if DEBUG:
            seg_vw_merge = plotct(512 * SCALE, contourin * SCALE, contourout * SCALE)
            predseg = croppatch(seg_vw_merge, cty * SCALE, ctx * SCALE, 64 * SCALE, 64 * SCALE)
            plt.imshow(predseg)
            plt.show()
        polarconsistency = np.mean(np.array(all_polarconsistency), axis=0)
        # no centre adjustment for multi-centre prediction
        recenters = [[ctx, cty]]
    return [contourin, contourout], polarconsistency, recenters
def data_generator(config, exams, aug):
    """Endlessly yield (x, y) batches of augmented polar patches and contours.

    For every valid slice of every exam, up to 10 samples are built: the
    first at the annotated centre, the rest at random offsets in [-20, 20]
    pixels. An offset is kept only if the slice's min-distance map is >= 0.5
    there, with at most 50 attempts per slice. Each sample consists of:
    the polar transform of the cartesian patch around the shifted centre
    (normalised to max 1), and the lumen/wall contours re-expressed in polar
    coordinates around that centre.

    Args:
        config: dict with 'batchsize', 'height', 'width', 'depth', 'channel',
            'patchheight' and, when ``aug`` is true, 'rottimes'.
        exams: iterable of exam identifiers accepted by CaseLoader; it is
            materialised to a list because it is traversed repeatedly.
        aug: if true, each full batch is yielded 'rottimes' times, passed
            through batch_polar_rot with distinct random offsets.

    Raises:
        ValueError: if ``exams`` is empty (the loop would otherwise spin
            forever without yielding).
    """
    exams = list(exams)
    if not exams:
        raise ValueError('data_generator: exams is empty')
    n_aug = 10
    max_tries = 50
    xarray = np.zeros([
        config['batchsize'], config['height'], config['width'],
        config['depth'], config['channel']
    ])
    yarray = np.zeros([config['batchsize'], config['patchheight'], 2])
    bi = 0
    while True:
        for ei in exams:
            caseloader = CaseLoader(ei)
            for slicei in caseloader.slices:
                if not caseloader.valid_slicei(slicei):
                    continue
                cartcont = caseloader.load_cart_cont(slicei)
                min_dist_map = caseloader.load_cart_min_dist(slicei)
                dcmstack = caseloader.loadstack(slicei, 'dcm')
                ctx, cty = caseloader.loadct(slicei)
                aug_polar_patch_batch = np.zeros((n_aug, 256, 256, 3))
                aug_polar_cont_batch = np.zeros((n_aug, 256, 2))
                offi = 0
                mi = 0
                while offi < n_aug and mi < max_tries:
                    mi += 1
                    if offi == 0:
                        ofx = 0
                        ofy = 0
                    else:
                        ofx = random.randint(-20, 20)
                        ofy = random.randint(-20, 20)
                    # Cartesian maps are indexed [row (y), col (x)].
                    if min_dist_map[256 + ofy, 256 + ofx] < 0.5:
                        if mi > 40:
                            print('mi', mi)
                        continue
                    cctx = ofx + ctx
                    ccty = ofy + cty
                    aug_polar_cont_batch[offi, :, 0] = tocordpolar(cartcont[0], cctx, ccty)
                    aug_polar_cont_batch[offi, :, 1] = tocordpolar(cartcont[1], cctx, ccty)
                    new_cart_patch = croppatch(dcmstack, ccty, cctx, 256, 256)
                    new_polar_patch = topolar(new_cart_patch, 256, 256)
                    aug_polar_patch_batch[offi] = new_polar_patch / np.max(new_polar_patch)
                    offi += 1

                # Only the first `offi` samples were filled; the rest are zeros.
                for augi in range(offi):
                    xarray[bi] = aug_polar_patch_batch[augi, :, :, :, None]
                    yarray[bi] = aug_polar_cont_batch[augi]
                    bi += 1
                    if bi == config['batchsize']:
                        bi = 0
                        if aug:
                            # presumably batch_polar_rot returns new arrays — TODO confirm
                            for rot in random.sample(range(config['height']), config['rottimes']):
                                yield (batch_polar_rot(xarray, rot), batch_polar_rot(yarray, rot))
                        else:
                            # xarray/yarray are refilled in place for the next
                            # batch; copy so queued batches are not overwritten.
                            yield (xarray.copy(), yarray.copy())
def bb_match_score(featuremodel, imgstack, pos1, pos2):
    """Score two positions with bb_match_score_patch on their image crops.

    Each position is (x, y, z); the crop is taken from slice ``imgstack[z]``
    with croppatch around (y, x) using size arguments 64, 64.
    """
    def _crop_at(pos):
        x, y, z = pos[0], pos[1], pos[2]
        return croppatch(imgstack[z], y, x, 64, 64)

    first_patch = _crop_at(pos1)
    second_patch = _crop_at(pos2)
    return bb_match_score_patch(featuremodel, first_patch, second_patch)