def slice_by_label():
    """Crop MyoPS label volumes and their matching images, and save the middle slice.

    For every label file in ``train25_myops_gd_convert``:

    * the label is resampled to 1x1x1 spacing (nearest neighbour), cropped to
      the box returned by ``get_bounding_boxV2`` (padding 10), resized in-plane
      to 256x256 (depth kept) and its middle axial slice is written to
      ``train25_myops_gd_crop`` under the label's base name;
    * every image in ``train25_convert`` whose name contains the third
      ``_``-separated token of the label's base name is resampled to 1x1x1
      (linear), cropped with the *label's* box, resized the same way and its
      middle slice is written to ``train25_myops_crop``.

    :return: None
    """
    files = glob.glob("../datasets/myo_data/train25_myops_gd_convert/*.nii.gz")
    output_lab_dir = "../datasets/myo_data/train25_myops_gd_crop/"
    output_img_dir = "../datasets/myo_data/train25_myops_crop/"
    mkdir_if_not_exist(output_lab_dir)
    # The image output directory must exist as well before writing into it.
    mkdir_if_not_exist(output_img_dir)
    for i in files:
        lab = sitk.ReadImage(i)
        # Resample to a common spacing first so that crops have a consistent size.
        lab = sitkResample3DV2(lab, sitk.sitkNearestNeighbor, [1, 1, 1])
        bbox = get_bounding_boxV2(sitk.GetArrayFromImage(lab), padding=10)  # extend bbox
        crop_lab = crop_by_bbox(lab, bbox)
        crop_lab = sitkResize3DV2(crop_lab, [256, 256, crop_lab.GetSize()[-1]],
                                  sitk.sitkNearestNeighbor)
        sitk_write_image(crop_lab[:, :, crop_lab.GetSize()[-1] // 2],
                         dir=output_lab_dir, name=os.path.basename(i))
        img_file = glob.glob(
            "../datasets/myo_data/train25_convert/*%s*.nii.gz" % (os.path.basename(i).split("_")[2]))
        for j in img_file:
            img = sitk.ReadImage(j)
            img = sitkResample3DV2(img, sitk.sitkLinear, [1, 1, 1])
            crop_img = crop_by_bbox(img, bbox)
            crop_img = sitkResize3DV2(crop_img, [256, 256, crop_img.GetSize()[-1]],
                                      sitk.sitkLinear)
            # Index the image crop by its OWN depth; the label crop's depth is not
            # guaranteed to be identical (e.g. if the image extent clips the box).
            sitk_write_image(crop_img[:, :, crop_img.GetSize()[-1] // 2],
                             dir=output_img_dir, name=os.path.basename(j))
def rotate_img_to_same_direction(imgs, labs, output_imgs_dir, output_labs_dir):
    """Resample every image/label pair onto its own rotation reference grid.

    For each (image path, label path) pair, ``get_rotate_ref_img`` provides a
    reference for the label and one for the image. Each volume is resampled
    onto its reference with a geometry-centred ``Euler3DTransform``: the label
    uses nearest-neighbour interpolation and the image uses linear. Both use
    default value 0 and keep their original pixel type. The results are
    written to ``output_imgs_dir`` / ``output_labs_dir`` under the input base
    names.
    """
    for img_path, lab_path in zip(imgs, labs):
        moving_img = sitk.ReadImage(img_path)
        moving_lab = sitk.ReadImage(lab_path)
        lab_reference = get_rotate_ref_img(moving_lab)
        img_reference = get_rotate_ref_img(moving_img)

        # Pre-registration: align the geometric centres of reference and moving volume.
        lab_transform = sitk.CenteredTransformInitializer(
            lab_reference,
            moving_lab,
            sitk.Euler3DTransform(),
            sitk.CenteredTransformInitializerFilter.GEOMETRY,
        )
        lab_out = sitk.Resample(moving_lab, lab_reference, lab_transform,
                                sitk.sitkNearestNeighbor, 0, moving_lab.GetPixelID())

        img_transform = sitk.CenteredTransformInitializer(
            img_reference,
            moving_img,
            sitk.Euler3DTransform(),
            sitk.CenteredTransformInitializerFilter.GEOMETRY,
        )
        img_out = sitk.Resample(moving_img, img_reference, img_transform,
                                sitk.sitkLinear, 0, moving_img.GetPixelID())

        sitk_write_image(img_out, dir=output_imgs_dir, name=os.path.basename(img_path))
        sitk_write_image(lab_out, dir=output_labs_dir, name=os.path.basename(lab_path))
def gen_one_batch(self, genSample, img_fix, lab_fix, img_mv, lab_mv, output_dir):
    """Warp one moving image/label onto one fixed image/label and score the result.

    Builds a single-sample batch with ``genSample.get_batch_data_V2`` and runs
    the network once. It then writes the fixed image/label and the warped
    moving image/label to ``output_dir``, using the fixed image file as the
    geometry reference; label names replace 'image' with 'label'. Finally it
    logs the third value returned by ``neg_jac`` for the moving->fixed DDF.

    :param genSample: sampler providing ``get_batch_data_V2``.
    :param img_fix: path of the fixed image.
    :param lab_fix: path of the fixed label.
    :param img_mv: path of the moving image.
    :param lab_mv: path of the moving label.
    :param output_dir: directory the four volumes are written to.
    :return: (dice, hausdorff) between the warped moving label and the fixed label.
    """
    fix_imgs, fix_labs, mv_imgs, mv_labs = genSample.get_batch_data_V2(
        [img_mv], [img_fix], [lab_mv], [lab_fix])
    trainFeed = self.create_feed_dict(fix_imgs, fix_labs, mv_imgs, mv_labs)
    # Fetch the warped outputs and the DDF in ONE run. A second run re-executes
    # the network and only matches the warped label used for Dice/HD below if
    # the graph is fully deterministic.
    input_fix_label, warp_mv_label, input_fix_img, warp_mv_img, ddf_mv_fix = self.sess.run(
        [self.input_FIX_label, self.warped_MV_label, self.input_FIX_image,
         self.warped_MV_image, self.ddf_MV_FIX],
        feed_dict=trainFeed)

    param = sitk.ReadImage(img_fix)
    fix_name = get_name_wo_suffix(img_fix)
    mv_name = get_name_wo_suffix(img_mv)
    sitk_write_image(input_fix_img[0, ...], param, output_dir, fix_name)
    sitk_write_lab(input_fix_label[0, ...], param, output_dir, fix_name.replace('image', 'label'))
    sitk_write_image(warp_mv_img[0, ...], param, output_dir, mv_name)
    sitk_write_lab(warp_mv_label[0, ...], param, output_dir, mv_name.replace('image', 'label'))

    # NOTE(review): presumably the count of voxels with a negative Jacobian
    # determinant (folding) — confirm in neg_jac.
    _, _, neg2 = neg_jac(ddf_mv_fix[0, ...])
    self.logger.debug("neg_jac %d", neg2)

    ds = calculate_binary_dice(warp_mv_label, input_fix_label)
    hd = calculate_binary_hd(warp_mv_label, input_fix_label, spacing=param.GetSpacing())
    return ds, hd
def generate_3dCT(image_dirs, lab_dirs, type='ct', output_dir=None):
    """Convert paired DICOM image series and PNG label series into 3-D volumes.

    For each (image dir, label dir) pair, the DICOM series is read with
    ``sitk_read_dico_series`` and the PNG label series with
    ``read_png_seriesV2``.

    * The image is written to ``<root>//<type>-image//``.
    * The label is written to ``<root>//<type>-label//``, with the image as
      geometry reference.
    * File names are the second-to-last ``\\``-separated component of each
      directory path plus ``_<type>_image`` / ``_<type>_label``.
    * NOTE(review): this assumes Windows-style paths; whether they carry a
      trailing separator is unconfirmed.

    :param image_dirs: DICOM series directories.
    :param lab_dirs: matching PNG label directories.
    :param type: modality tag used in directory and file names.
    :param output_dir: root output directory. When None (default) the
        module-level ``output`` variable is used, as this function always did.
    """
    # Previously the root was taken implicitly from a module-level global.
    root = output if output_dir is None else output_dir
    for img_dir, lab_dir in zip(image_dirs, lab_dirs):
        img = sitk_read_dico_series(img_dir)
        lab = read_png_seriesV2(lab_dir)
        sitk_write_image(img, dir=root + "//%s-image//" % (type),
                         name=img_dir.split('\\')[-2] + "_%s_image" % (type))
        sitk_write_lab(lab, parameter_img=img, dir=root + "//%s-label//" % (type),
                       name=lab_dir.split('\\')[-2] + "_%s_label" % (type))
def resize_img_lab(imgs, labs, output_img, output_lab, zscore, size=(96, 96, 96)):
    """Resize image/label pairs to a fixed voxel grid and write them out.

    Labels are resized with nearest-neighbour interpolation and images with
    linear interpolation. When ``zscore`` is true, each resized image is first
    passed through ``sitk.RescaleIntensity`` and then ``sitk.Normalize``
    (zero mean, unit variance) before writing. Outputs keep the input base
    names.

    :param imgs: image file paths.
    :param labs: matching label file paths.
    :param output_img: output directory for images.
    :param output_lab: output directory for labels.
    :param zscore: whether to intensity-normalise the images.
    :param size: target grid size (x, y, z); defaults to the previously
        hard-coded 96x96x96.
    """
    for p_img, p_lab in zip(imgs, labs):
        img_obj = sitk.ReadImage(p_img)
        lab_obj = sitk.ReadImage(p_lab)
        resize_lab = sitkResize3DV2(lab_obj, list(size), sitk.sitkNearestNeighbor)
        sitk_write_image(resize_lab, dir=output_lab, name=os.path.basename(p_lab))
        resize_img = sitkResize3DV2(img_obj, list(size), sitk.sitkLinear)
        if zscore:
            resize_img = sitk.RescaleIntensity(resize_img)
            resize_img = sitk.Normalize(resize_img)
        sitk_write_image(resize_img, dir=output_img, name=os.path.basename(p_img))
def generate_3dMR(image_dirs, lab_dirs, type='mr', output_dir=None):
    """Convert paired MR DICOM series and PNG label series into 3-D volumes.

    For each (image dir, label dir) pair:

    * The DICOM series is read with ``sitk_read_dico_series`` and written to
      ``<root>//<type>-image//``. Its name is the 4th-from-last
      ``\\``-separated component of the image dir plus ``_<type>_image``.
    * The PNG label series (``read_png_series``) is binarised: values in
      [55, 70] become 1 and everything else 0. It is written to
      ``<root>//<type>-label//`` with the image as geometry reference. Its name
      is the 3rd-from-last component of the label dir plus ``_<type>_label``.

    :param image_dirs: DICOM series directories.
    :param lab_dirs: matching PNG label directories.
    :param type: modality tag used in directory and file names.
    :param output_dir: root output directory. When None (default) the
        module-level ``output`` variable is used, as this function always did.
    """
    # Previously the root was taken implicitly from a module-level global.
    root = output if output_dir is None else output_dir
    for img_dir, lab_dir in zip(image_dirs, lab_dirs):
        img = sitk_read_dico_series(img_dir)
        lab = read_png_series(lab_dir)
        sitk_write_image(img, dir=root + "//%s-image//" % (type),
                         name=img_dir.split('\\')[-4] + "_%s_image" % (type))
        # Keep the grey-value band [55, 70] (inclusive) as foreground.
        # NOTE(review): presumably the band of a single organ in the PNG ground
        # truth — confirm.
        lab = np.where((lab >= 55) & (lab <= 70), 1, 0)
        sitk_write_lab(lab, parameter_img=img, dir=root + "//%s-label//" % (type),
                       name=lab_dir.split('\\')[-3] + "_%s_label" % (type))
def de_rotate(refs, tgts, output_dir):
    """Resample each target label back onto the grid of its reference image.

    For each (reference path, target path) pair:

    * The reference is passed through ``recast_pixel_val(label, reference)``.
      NOTE(review): presumably this matches its pixel values/type to the
      label — confirm.
    * The label is resampled onto the reference with a geometry-centred
      ``Euler3DTransform``, using nearest-neighbour interpolation, default
      value 0 and the label's own pixel type.
    * The result is written to ``output_dir`` under the target's base name.
    """
    for ref_path, tgt_path in zip(refs, tgts):
        reference = sitk.ReadImage(ref_path)
        label = sitk.ReadImage(tgt_path)
        reference = recast_pixel_val(label, reference)
        centring = sitk.CenteredTransformInitializer(
            reference,
            label,
            sitk.Euler3DTransform(),
            sitk.CenteredTransformInitializerFilter.GEOMETRY,
        )
        restored = sitk.Resample(label, reference, centring,
                                 sitk.sitkNearestNeighbor, 0, label.GetPixelID())
        sitk_write_image(restored, dir=output_dir, name=os.path.basename(tgt_path))
def de_crop(refs, tgts, output_dir, structure):
    """Paste cropped binary labels back into the full grid of their reference images.

    For each (reference path, target path) pair:

    * The target label is passed through ``resample_segmentations(reference,
      label)``.
    * It is pasted with ``paste_roi_image`` into an empty image that has the
      reference's size and geometry and the label's pixel type.
    * The result is converted to ``uint16``, and voxels equal to 1 are
      replaced by ``structure`` (e.g. 205).
    * It is written to ``output_dir`` under the target's base name.
    """
    for ref_path, tgt_path in zip(refs, tgts):
        reference = sitk.ReadImage(ref_path)
        cropped = resample_segmentations(reference, sitk.ReadImage(tgt_path))

        canvas = sitk.Image(reference.GetSize(), cropped.GetPixelIDValue())
        canvas.CopyInformation(reference)
        pasted = paste_roi_image(canvas, cropped)

        # Convert the label back to 205 or whichever value the structure uses.
        voxels = sitk.GetArrayFromImage(pasted).astype(np.uint16)
        voxels = np.where(voxels == 1, structure, voxels)

        restored = sitk.GetImageFromArray(voxels)
        restored.CopyInformation(pasted)
        sitk_write_image(restored, dir=output_dir, name=os.path.basename(tgt_path))
def crop_img_lab(input_imgs, input_labs, crop_imgs_dir, crop_lab_fix, id):
    """Crop image/label pairs around one label id after resampling to 1x1x1 spacing.

    The label is resampled with nearest-neighbour interpolation and the image
    with linear. Both are cropped with the box returned by
    ``get_bounding_box_by_id`` (padding 10) on the resampled label; presumably
    this is the box of voxels equal to ``id`` — confirm. No resizing is
    applied. Crops are written to ``crop_lab_fix`` / ``crop_imgs_dir`` under
    the input base names.
    """
    for img_path, lab_path in zip(input_imgs, input_labs):
        # Resample to a common spacing first so crops share the same physical extent.
        label = sitkResample3DV2(sitk.ReadImage(lab_path), sitk.sitkNearestNeighbor, [1, 1, 1])
        box = get_bounding_box_by_id(sitk.GetArrayFromImage(label), padding=10, id=id)  # extend bbox
        sitk_write_image(crop_by_bbox(label, box), dir=crop_lab_fix,
                         name=os.path.basename(lab_path))

        image = sitkResample3DV2(sitk.ReadImage(img_path), sitk.sitkLinear, [1, 1, 1])
        sitk_write_image(crop_by_bbox(image, box), dir=crop_imgs_dir,
                         name=os.path.basename(img_path))
def sample_network(self, itr):
    """Run the network on one validation sample, dump volumes and return its Dice.

    Feeds the atlas labels and target label of ``self.valid_sampler.next_sample()``
    and takes the argmax over the last axis of both the output and the fed
    ground truth. It writes the prediction, the ground truth and the target
    image (first batch element) to ``self.args.sample_dir`` with ``itr``-prefixed
    names, prints the score and returns ``dc(prediction, ground truth)``.
    NOTE(review): ``dc`` is presumably a Dice coefficient — confirm.
    """
    target_img, target_lab, atlas_imgses, atlas_labses = self.valid_sampler.next_sample()
    fetches = [self.summary, self.output, self.ph_gt]
    feeds = {self.ph_atlas: atlas_labses, self.ph_gt: target_lab}
    _summary, prediction, ground_truth = self.sess.run(fetches, feed_dict=feeds)

    tag = str(itr)
    out_dir = self.args.sample_dir
    pred_lab = np.argmax(prediction, axis=-1)
    sitk_write_lab(pred_lab[0, ...], dir=out_dir, name=tag + "pred")
    gt_lab = np.argmax(ground_truth, axis=-1)
    sitk_write_lab(gt_lab[0, ...], dir=out_dir, name=tag + "gt")
    sitk_write_image(np.squeeze(target_img[0, ...]), dir=out_dir, name=tag + "img")

    dice = dc(pred_lab[0, ...], gt_lab[0, ...])
    print("dc:%f" % dice)
    return dice
def gen_one_batch(self, genSample, img_fix, lab_fix, img_mv, lab_mv, output_dir):
    """Register one atlas (moving) pair onto one target (fixed) pair and dump the volumes.

    Runs the network once on a single-sample batch. It writes to
    ``output_dir`` with no geometry reference (``param = None``):

    * the fixed image/label, renamed with ``target_image``/``target_label``;
    * the warped moving image/label, renamed with ``atlas_image``/``atlas_label``;
    * the unwarped moving image;
    * the ``restore`` outputs for the moving and fixed images.

    It also prints the fetched ``theta_fw`` value.

    :return: (Dice of the unwarped moving label vs the fixed label,
              Dice of the warped moving label vs the fixed label).
    """
    fix_imgs, fix_labs, mv_imgs, mv_labs = genSample.get_batch_data_V2(
        [img_mv], [img_fix], [lab_mv], [lab_fix])
    feed = self.create_feed_dict(fix_imgs, fix_labs, mv_imgs, mv_labs)
    # One run for every output so all volumes come from the same forward pass.
    # theta_fw is fetched into `fw` so the printed value matches its name.
    (input_mv_label, input_fix_label, warp_mv_label, input_mv_img,
     input_fix_img, warp_mv_img, fw, restore_mv, restore_fix) = self.sess.run(
        [self.input_MV_label, self.input_FIX_label, self.warped_MV_label,
         self.input_MV_image, self.input_FIX_image, self.warped_MV_image,
         self.theta_fw, self.restore_MV_image, self.restore_FIX_image],
        feed_dict=feed)

    # The fixed image is deliberately not re-read: volumes are written without
    # a geometry reference.
    param = None
    fix_img_name = get_name_wo_suffix(img_fix)
    mv_img_name = get_name_wo_suffix(img_mv)
    sitk_write_image(input_fix_img[0, ...], param, output_dir,
                     fix_img_name.replace('image', 'target_image'))
    sitk_write_lab(input_fix_label[0, ...], param, output_dir,
                   get_name_wo_suffix(lab_fix).replace('label', 'target_label'))
    sitk_write_image(warp_mv_img[0, ...], param, output_dir,
                     mv_img_name.replace('image', 'atlas_image'))
    sitk_write_lab(warp_mv_label[0, ...], param, output_dir,
                   get_name_wo_suffix(lab_mv).replace('label', 'atlas_label'))
    sitk_write_image(input_mv_img[0, ...], param, output_dir, mv_img_name)
    sitk_write_image(restore_mv[0, ...], param, output_dir,
                     mv_img_name.replace('image', 'restore'))
    sitk_write_image(restore_fix[0, ...], param, output_dir,
                     fix_img_name.replace('image', 'restore'))
    print(fw)

    ds = calculate_binary_dice(warp_mv_label, input_fix_label)
    bf_ds = calculate_binary_dice(input_mv_label, input_fix_label)
    return bf_ds, ds
def fusion_one_target(self, itr):
    """Fuse the atlas labels of one validation target using predicted similarity maps.

    For the next sample of ``self.validate_sampler.next_sample_4_fusion()``:

    1. The target image and label are written to ``self.args.validate_dir``.
    2. For every atlas ``i`` the network predicts a similarity map. Atlas
       label ``i`` (``label_<i>``) and the map (``sim_<i>``) are written to
       the same directory.
    3. The atlas labels are fused by similarity-weighted voting. For each
       candidate label (taken from the atlases), the similarities of the atlases
       voting for it are summed per voxel, and the label with the largest sum
       wins.
    4. The fused label is written NEXT TO the target label file, with
       ``label`` replaced by ``net_fusion_label`` in its name.

    :param itr: unused; kept for interface compatibility.
    :return: (dice, hausdorff) of the fused label against the target label.
    :raises ValueError: if the sample contains no atlases.
    """
    (target_img_batch, target_lab_batch, atlas_img_batch, atlas_lab_batch,
     sim_batch, p_fix_img, p_fix_lab) = self.validate_sampler.next_sample_4_fusion()
    param = sitk.ReadImage(p_fix_img[0])
    out_dir = self.args.validate_dir
    fix_lab_name = get_name_wo_suffix(p_fix_lab[0])
    sitk_write_lab(np.squeeze(target_lab_batch.astype(np.uint8)), param,
                   dir=out_dir, name=fix_lab_name)
    sitk_write_image(np.squeeze(target_img_batch), param,
                     dir=out_dir, name=get_name_wo_suffix(p_fix_img[0]))

    # Atlases are stacked on the second-to-last axis (see the [..., i, :] indexing).
    n_atlas = atlas_lab_batch.shape[-2]
    if n_atlas == 0:
        raise ValueError("fusion_one_target: sample contains no atlases")

    sims = []
    for i in range(n_atlas):
        feed = {
            self.ph_target_image: target_img_batch,
            self.ph_target_label: target_lab_batch,
            self.ph_atlas_label: atlas_lab_batch[..., i, :],
            self.ph_gt_dicesim: sim_batch[..., i, :],
        }
        pred_sim = self.sess.run(self.predict_sim, feed_dict=feed)
        sims.append(pred_sim)
        # Names are derived from the label's base name for both outputs.
        sitk_write_lab(np.squeeze(atlas_lab_batch[..., i, :]), param, dir=out_dir,
                       name=fix_lab_name.replace('label', 'label_' + str(i)))
        sitk_write_image(np.squeeze(pred_sim), param, dir=out_dir,
                         name=fix_lab_name.replace('label', 'sim_' + str(i)))

    # Reshape explicitly to (volume..., n_atlas). np.squeeze would also drop
    # the atlas axis when n_atlas == 1, making the sum below reduce a spatial axis.
    # NOTE(review): assumes batch size 1 and a single label channel, as the
    # original squeeze-based code did.
    vol_shape = np.squeeze(target_lab_batch).shape
    atlas_labels = np.reshape(atlas_lab_batch, vol_shape + (n_atlas,))
    weights = np.reshape(np.stack(sims, -1), vol_shape + (n_atlas,))

    # Candidate labels come from the atlases, not from the target ground truth.
    u_lab = np.unique(atlas_labels.astype(np.uint8))
    votes = np.zeros((len(u_lab),) + vol_shape)
    for k, lab in enumerate(u_lab):
        votes[k] = np.sum((atlas_labels == lab).astype(np.int16) * weights, axis=-1)
    fusion_label = u_lab[np.argmax(votes, axis=0)]

    # Compare volumes of identical shape (no implicit broadcasting against the batch array).
    target_gt = np.reshape(target_lab_batch, vol_shape)
    ds = calculate_binary_dice(fusion_label, target_gt)
    hd = calculate_binary_hd(fusion_label, target_gt, spacing=param.GetSpacing())
    # NOTE(review): this writes into the directory of the input label file,
    # not into validate_dir.
    sitk_write_lab(np.squeeze(fusion_label).astype(np.uint8), param,
                   dir=os.path.dirname(p_fix_lab[0]),
                   name=fix_lab_name.replace('label', 'net_fusion_label'))
    return ds, hd
def gen_warp_atlas(self, dir, genSample, i, img_fix, lab_fix, is_aug=True):
    """Warp every atlas of ``genSample`` onto one fixed image and save the results.

    Creates (or clears) ``<gen_dir>/<dir>/target_<i>_<fixed name>``. For every
    atlas (moving) image/label pair it:

    * runs the registration network;
    * binarises the fixed and warped labels at 0.5;
    * records the Dice between them, and the conditional entropy of the warped
      label given the fixed image (``conditional_entropy_label_over_image``).

    Each warped atlas image/label is written as ``<k>_mv_img`` / ``<k>_mv_lab``,
    where ``k`` is the atlas's original index. The fixed image/label from the
    first run is written once as ``0_fix_img`` / ``0_fix_lab``, using the fixed
    image file as geometry reference. The method logs all Dice values, their
    mean, and the mean Dice of the 5 atlases with the lowest conditional
    entropy.

    :raises ValueError: if ``genSample`` contains no atlases.
    """
    output_dir = self.args.gen_dir + "/" + dir + "/target_" + str(
        i) + "_" + get_name_wo_suffix(img_fix)
    mk_or_cleardir(output_dir)
    # The fixed image geometry is the same for every atlas: read it once.
    param = None
    first_fix_img = None
    first_fix_lab = None
    warp_mv_imgs = []
    warp_mv_labels = []
    losses = []
    sims = []
    for img_mv, lab_mv in zip(genSample.img_mv, genSample.lab_mv):
        fix_imgs, fix_labs, mv_imgs, mv_labs = genSample.get_batch_data_V2(
            [img_mv], [img_fix], [lab_mv], [lab_fix])
        trainFeed = self.create_feed_dict(fix_imgs, fix_labs, mv_imgs, mv_labs, is_aug)
        input_fix_label, warp_mv_label, input_fix_img, warp_mv_img = self.sess.run(
            [self.input_FIX_label, self.warped_MV_label,
             self.input_FIX_image, self.warped_MV_image],
            feed_dict=trainFeed)
        input_fix_label = np.where(input_fix_label > 0.5, 1, 0)
        warp_mv_label = np.where(warp_mv_label > 0.5, 1, 0)
        sims.append(calculate_binary_dice(input_fix_label[0, ...], warp_mv_label[0, ...]))
        if param is None:
            param = sitk.ReadImage(img_fix)
            # Only the first run's fixed volumes are ever written.
            first_fix_img = input_fix_img[0, ...]
            first_fix_lab = input_fix_label[0, ...]
        warp_mv_imgs.append(warp_mv_img[0, ...])
        warp_mv_labels.append(warp_mv_label[0, ...])
        losses.append(
            conditional_entropy_label_over_image(
                np.squeeze(input_fix_img[0, ...]), np.squeeze(warp_mv_label[0, ...])))

    if not sims:
        raise ValueError("gen_warp_atlas: genSample contains no atlases for %s" % img_fix)

    indexs = np.argsort(losses)
    # NOTE(review): files are named by the atlas's ORIGINAL index, so the
    # ranking only affects write order and the top-5 log line below.
    for ind in indexs:
        sitk_write_image(warp_mv_imgs[ind], param, dir=output_dir, name=str(ind) + "_mv_img")
        sitk_write_lab(warp_mv_labels[ind], param, dir=output_dir, name=str(ind) + "_mv_lab")
    sitk_write_image(first_fix_img, param, dir=output_dir, name=str(0) + "_fix_img")
    sitk_write_lab(first_fix_lab, param, dir=output_dir, name=str(0) + "_fix_lab")
    self.logger.debug(sims)
    self.logger.debug("%s %f -> %f", output_dir, np.mean(sims),
                      np.mean([sims[ind] for ind in indexs[:5]]))