Exemple #1
0
def slice_by_label():
    """Crop label/image volumes around the label and save one middle slice each.

    For every label volume matched by the hard-coded glob pattern:

    * the label is passed through ``sitkResample3DV2`` with ``[1, 1, 1]``
      (a common spacing, per the original author's note) so that crops have
      a consistent size,
    * a bounding box with ``padding=10`` is computed from the label array,
    * the label is cropped to that box, resized in-plane to 256x256 (the
      last dimension is kept), and its middle slice along the last axis is
      written to the label output directory,
    * every image volume whose file name contains the third
      ``_``-separated token of the label file name is resampled, cropped
      with the same bounding box, resized the same way, and the slice at the
      same index is written to the image output directory.

    Both output directories are created up front.

    :return: None
    """
    label_pattern = "../datasets/myo_data/train25_myops_gd_convert/*.nii.gz"
    output_lab_dir = "../datasets/myo_data/train25_myops_gd_crop/"
    output_img_dir = "../datasets/myo_data/train25_myops_crop/"
    mkdir_if_not_exist(output_lab_dir)
    # The image directory was previously never created explicitly.
    mkdir_if_not_exist(output_img_dir)

    for lab_path in glob.glob(label_pattern):
        lab_name = os.path.basename(lab_path)

        # Resample first so the crop size is consistent across cases.
        lab = sitk.ReadImage(lab_path)
        lab = sitkResample3DV2(lab, sitk.sitkNearestNeighbor, [1, 1, 1])
        bbox = get_bounding_boxV2(sitk.GetArrayFromImage(lab), padding=10)

        crop_lab = crop_by_bbox(lab, bbox)
        crop_lab = sitkResize3DV2(crop_lab,
                                  [256, 256, crop_lab.GetSize()[-1]],
                                  sitk.sitkNearestNeighbor)

        # The same slice index (taken from the cropped label) is used for
        # the label and for every matching image.
        mid_slice = crop_lab.GetSize()[-1] // 2
        sitk_write_image(crop_lab[:, :, mid_slice],
                         dir=output_lab_dir,
                         name=lab_name)

        # Images are matched to the label through the third "_" token of
        # the label file name.
        case_token = lab_name.split("_")[2]
        img_paths = glob.glob(
            "../datasets/myo_data/train25_convert/*%s*.nii.gz" % case_token)
        for img_path in img_paths:
            img = sitk.ReadImage(img_path)
            img = sitkResample3DV2(img, sitk.sitkLinear, [1, 1, 1])
            crop_img = crop_by_bbox(img, bbox)
            crop_img = sitkResize3DV2(
                crop_img, [256, 256, crop_img.GetSize()[-1]], sitk.sitkLinear)
            sitk_write_image(crop_img[:, :, mid_slice],
                             dir=output_img_dir,
                             name=os.path.basename(img_path))
Exemple #2
0
def rotate_img_to_same_direction(imgs, labs, output_imgs_dir, output_labs_dir):
    """Resample each image/label pair onto its own rotation reference.

    For every ``(image path, label path)`` pair, a reference volume is
    obtained from ``get_rotate_ref_img`` for the label and for the image
    separately. Each moving volume is then resampled onto its reference
    using a geometry-centred ``Euler3DTransform`` initial transform
    (nearest-neighbour for the label, linear for the image) and written to
    the matching output directory under its original base name.

    :param imgs: iterable of image file paths
    :param labs: iterable of label file paths, paired with ``imgs`` by position
    :param output_imgs_dir: directory receiving the resampled images
    :param output_labs_dir: directory receiving the resampled labels
    :return: None
    """

    def _resample_onto(moving, reference, interpolator):
        # Build the geometry-centred initial transform and apply it; the
        # output keeps the moving volume's pixel type, default value 0.
        transform = sitk.CenteredTransformInitializer(
            reference,
            moving,
            sitk.Euler3DTransform(),
            sitk.CenteredTransformInitializerFilter.GEOMETRY)
        return sitk.Resample(moving, reference, transform, interpolator, 0,
                             moving.GetPixelID())

    for img_path, lab_path in zip(imgs, labs):
        moving_img = sitk.ReadImage(img_path)
        moving_lab = sitk.ReadImage(lab_path)

        lab_reference = get_rotate_ref_img(moving_lab)
        img_reference = get_rotate_ref_img(moving_img)

        aligned_lab = _resample_onto(moving_lab, lab_reference,
                                     sitk.sitkNearestNeighbor)
        aligned_img = _resample_onto(moving_img, img_reference,
                                     sitk.sitkLinear)

        sitk_write_image(aligned_img,
                         dir=output_imgs_dir,
                         name=os.path.basename(img_path))
        sitk_write_image(aligned_lab,
                         dir=output_labs_dir,
                         name=os.path.basename(lab_path))
    def gen_one_batch(self, genSample, img_fix, lab_fix, img_mv, lab_mv,
                      output_dir):
        """Run the network on one fixed/moving pair, save outputs, return metrics.

        The fixed image/label and the warped moving image/label (first batch
        element of each) are written to ``output_dir``, using the fixed image
        read from disk as the parameter image. Label file names are derived
        from the image names by replacing ``'image'`` with ``'label'``. The
        third value returned by ``neg_jac`` for the MV->FIX displacement
        field is logged at debug level.

        :param genSample: sampler providing ``get_batch_data_V2``
        :param img_fix: path of the fixed image
        :param lab_fix: path of the fixed label
        :param img_mv: path of the moving image
        :param lab_mv: path of the moving label
        :param output_dir: directory receiving the written volumes
        :return: ``(ds, hd)`` from ``calculate_binary_dice`` and
            ``calculate_binary_hd`` on the warped moving vs. fixed label
        """
        batch = genSample.get_batch_data_V2(
            [img_mv], [img_fix], [lab_mv], [lab_fix])
        fix_imgs, fix_labs, mv_imgs, mv_labs = batch
        feed = self.create_feed_dict(fix_imgs, fix_labs, mv_imgs, mv_labs)

        fetches = [
            self.input_MV_label,
            self.input_FIX_label,
            self.warped_MV_label,
            self.input_MV_image,
            self.input_FIX_image,
            self.warped_MV_image,
        ]
        # The moving inputs are fetched but not used further here.
        (_mv_label, fix_label, warped_label,
         _mv_img, fix_img, warped_img) = self.sess.run(fetches,
                                                       feed_dict=feed)

        reference = sitk.ReadImage(img_fix)

        fix_name = get_name_wo_suffix(img_fix)
        sitk_write_image(fix_img[0, ...], reference, output_dir, fix_name)
        sitk_write_lab(fix_label[0, ...], reference, output_dir,
                       fix_name.replace('image', 'label'))

        mv_name = get_name_wo_suffix(img_mv)
        sitk_write_image(warped_img[0, ...], reference, output_dir, mv_name)
        sitk_write_lab(warped_label[0, ...], reference, output_dir,
                       mv_name.replace('image', 'label'))

        # Second graph evaluation: displacement field for the Jacobian check.
        displacement = self.sess.run(self.ddf_MV_FIX, feed_dict=feed)
        neg_count = neg_jac(displacement[0, ...])[2]
        self.logger.debug("neg_jac  %d" % (neg_count))

        dice = calculate_binary_dice(warped_label, fix_label)
        hausdorff = calculate_binary_hd(warped_label,
                                        fix_label,
                                        spacing=reference.GetSpacing())
        return dice, hausdorff
Exemple #4
0
def generate_3dCT(image_dirs, lab_dirs, type='ct'):
    """Convert paired image-series / PNG-label directories into volumes.

    Each image directory is read with ``sitk_read_dico_series`` and each
    label directory with ``read_png_seriesV2``. The image is written below
    the module-level ``output`` path in a ``<type>-image`` sub-folder and
    the label in a ``<type>-label`` sub-folder, using the image as the
    label's parameter image. Output names come from the second-to-last
    backslash-separated component of the respective input directory.

    :param image_dirs: iterable of image series directories
    :param lab_dirs: iterable of label directories, paired by position
    :param type: modality tag inserted into folder and file names
    :return: None
    """
    for image_dir, label_dir in zip(image_dirs, lab_dirs):
        volume = sitk_read_dico_series(image_dir)
        label = read_png_seriesV2(label_dir)

        image_folder = output + "//%s-image//" % (type)
        image_name = image_dir.split('\\')[-2] + "_%s_image" % (type)
        sitk_write_image(volume, dir=image_folder, name=image_name)

        label_folder = output + "//%s-label//" % (type)
        label_name = label_dir.split('\\')[-2] + "_%s_label" % (type)
        sitk_write_lab(label,
                       parameter_img=volume,
                       dir=label_folder,
                       name=label_name)
Exemple #5
0
def resize_img_lab(imgs, labs, output_img, output_lab, zscore):
    """Resize paired images and labels to 96x96x96 and write them out.

    Labels are resized with nearest-neighbour interpolation and images with
    linear interpolation. When ``zscore`` is truthy, the resized image is
    additionally passed through ``sitk.RescaleIntensity`` and then
    ``sitk.Normalize`` before being written.

    :param imgs: iterable of image file paths
    :param labs: iterable of label file paths, paired with ``imgs`` by position
    :param output_img: directory receiving the resized images
    :param output_lab: directory receiving the resized labels
    :param zscore: whether to rescale and normalize the image intensities
    :return: None
    """
    target_size = [96, 96, 96]
    for img_path, lab_path in zip(imgs, labs):
        image = sitk.ReadImage(img_path)
        label = sitk.ReadImage(lab_path)

        small_label = sitkResize3DV2(label, target_size,
                                     sitk.sitkNearestNeighbor)
        sitk_write_image(small_label,
                         dir=output_lab,
                         name=os.path.basename(lab_path))

        small_image = sitkResize3DV2(image, target_size, sitk.sitkLinear)
        if zscore:
            small_image = sitk.Normalize(sitk.RescaleIntensity(small_image))
        sitk_write_image(small_image,
                         dir=output_img,
                         name=os.path.basename(img_path))
Exemple #6
0
def generate_3dMR(image_dirs, lab_dirs, type='mr'):
    """Convert paired image-series / PNG-label directories into volumes.

    Each image directory is read with ``sitk_read_dico_series`` and each
    label directory with ``read_png_series``. The label is binarised: values
    in the closed range [55, 70] become 1, everything else 0. The image is
    written below the module-level ``output`` path in a ``<type>-image``
    sub-folder and the binary label in a ``<type>-label`` sub-folder, using
    the image as the label's parameter image. The image name is taken from
    the fourth-from-last backslash-separated path component of the image
    directory, the label name from the third-from-last component of the
    label directory.

    :param image_dirs: iterable of image series directories
    :param lab_dirs: iterable of label directories, paired by position
    :param type: modality tag inserted into folder and file names
    :return: None
    """
    for image_dir, label_dir in zip(image_dirs, lab_dirs):
        volume = sitk_read_dico_series(image_dir)
        raw_label = read_png_series(label_dir)

        image_folder = output + "//%s-image//" % (type)
        image_name = image_dir.split('\\')[-4] + "_%s_image" % (type)
        sitk_write_image(volume, dir=image_folder, name=image_name)

        # Keep only the grey-value band [55, 70] as foreground.
        at_least_55 = np.where(raw_label >= 55, 1, 0)
        at_most_70 = np.where(raw_label <= 70, 1, 0)
        binary_label = at_least_55 * at_most_70

        label_folder = output + "//%s-label//" % (type)
        label_name = label_dir.split('\\')[-3] + "_%s_label" % (type)
        sitk_write_lab(binary_label,
                       parameter_img=volume,
                       dir=label_folder,
                       name=label_name)
Exemple #7
0
def de_rotate(refs, tgts, output_dir):
    """Resample each target label back onto its reference volume.

    For every ``(reference path, target path)`` pair, the reference is first
    passed through ``recast_pixel_val(label, reference)``. The label is then
    resampled onto the reference with a geometry-centred
    ``Euler3DTransform`` initial transform and nearest-neighbour
    interpolation (default value 0, label pixel type kept), and written to
    ``output_dir`` under the target's base name.

    :param refs: iterable of reference volume paths
    :param tgts: iterable of label paths, paired with ``refs`` by position
    :param output_dir: directory receiving the resampled labels
    :return: None
    """
    for ref_path, lab_path in zip(refs, tgts):
        reference = sitk.ReadImage(ref_path)
        label = sitk.ReadImage(lab_path)
        reference = recast_pixel_val(label, reference)

        transform = sitk.CenteredTransformInitializer(
            reference,
            label,
            sitk.Euler3DTransform(),
            sitk.CenteredTransformInitializerFilter.GEOMETRY)
        restored = sitk.Resample(label, reference, transform,
                                 sitk.sitkNearestNeighbor, 0,
                                 label.GetPixelID())

        out_name = os.path.basename(lab_path)
        sitk_write_image(restored, dir=output_dir, name=out_name)
Exemple #8
0
def de_crop(refs, tgts, output_dir, structure):
    """Paste cropped labels back into the grid of their original images.

    For every ``(image path, label path)`` pair, the label is passed through
    ``resample_segmentations(img, label)`` and pasted via
    ``paste_roi_image`` into an empty volume that has the image's size and
    geometry and the label's pixel type. In the result, voxels equal to 1
    are replaced by ``structure`` and the volume is written as uint16 data
    to ``output_dir`` under the label's base name.

    :param refs: iterable of original image paths
    :param tgts: iterable of cropped label paths, paired by position
    :param output_dir: directory receiving the restored labels
    :param structure: value written in place of label value 1
    :return: None
    """
    for img_path, lab_path in zip(refs, tgts):
        image = sitk.ReadImage(img_path)
        cropped = sitk.ReadImage(lab_path)
        cropped = resample_segmentations(image, cropped)

        # Empty canvas sharing the original image's geometry.
        canvas = sitk.Image(image.GetSize(), cropped.GetPixelIDValue())
        canvas.CopyInformation(image)
        pasted = paste_roi_image(canvas, cropped)

        # Map the binary label back to the structure value (e.g. 205).
        voxels = sitk.GetArrayFromImage(pasted).astype(np.uint16)
        voxels = np.where(voxels == 1, structure, voxels)

        restored = sitk.GetImageFromArray(voxels)
        restored.CopyInformation(pasted)
        sitk_write_image(restored,
                         dir=output_dir,
                         name=os.path.basename(lab_path))
Exemple #9
0
def crop_img_lab(input_imgs, input_labs, crop_imgs_dir, crop_lab_fix, id):
    """Crop paired images and labels to the bounding box of one label id.

    The label is passed through ``sitkResample3DV2`` with ``[1, 1, 1]``
    (nearest neighbour) and a bounding box with ``padding=10`` is computed
    for ``id`` via ``get_bounding_box_by_id``. The image is resampled the
    same way with linear interpolation and cropped with the same box. Both
    crops are written under the base names of their source files.

    :param input_imgs: iterable of image file paths
    :param input_labs: iterable of label file paths, paired by position
    :param crop_imgs_dir: directory receiving the cropped images
    :param crop_lab_fix: directory receiving the cropped labels
    :param id: label id forwarded to ``get_bounding_box_by_id``
    :return: None
    """
    unit = [1, 1, 1]
    for img_path, lab_path in zip(input_imgs, input_labs):
        # Resample first so the crop covers a consistent physical extent.
        label = sitkResample3DV2(sitk.ReadImage(lab_path),
                                 sitk.sitkNearestNeighbor, unit)
        label_array = sitk.GetArrayFromImage(label)
        bbox = get_bounding_box_by_id(label_array, padding=10, id=id)

        cropped_label = crop_by_bbox(label, bbox)
        sitk_write_image(cropped_label,
                         dir=crop_lab_fix,
                         name=os.path.basename(lab_path))

        image = sitkResample3DV2(sitk.ReadImage(img_path), sitk.sitkLinear,
                                 unit)
        cropped_image = crop_by_bbox(image, bbox)
        sitk_write_image(cropped_image,
                         dir=crop_imgs_dir,
                         name=os.path.basename(img_path))
Exemple #10
0
 def sample_network(self, itr):
     """Evaluate one validation sample, save the volumes and return its score.

     A sample is drawn from ``self.valid_sampler``; the atlas labels and the
     target label are fed to the graph. The network output and the ground
     truth are both reduced with ``argmax`` over the last axis, and the
     first batch element of prediction, ground truth and target image is
     written to ``self.args.sample_dir`` with names prefixed by ``itr``.

     :param itr: iteration counter used as file-name prefix
     :return: value of ``dc`` between the predicted and ground-truth labels
     """
     sample = self.valid_sampler.next_sample()
     target_img, target_lab, atlas_imgses, atlas_labses = sample

     feed = {self.ph_atlas: atlas_labses, self.ph_gt: target_lab}
     fetches = [self.summary, self.output, self.ph_gt]
     # The summary is fetched but not used here.
     _summary, logits, truth = self.sess.run(fetches, feed_dict=feed)

     sample_dir = self.args.sample_dir
     prefix = str(itr)

     prediction = np.argmax(logits, axis=-1)
     sitk_write_lab(prediction[0, ...], dir=sample_dir, name=prefix + "pred")

     truth = np.argmax(truth, axis=-1)
     sitk_write_lab(truth[0, ...], dir=sample_dir, name=prefix + "gt")

     sitk_write_image(np.squeeze(target_img[0, ...]),
                      dir=sample_dir,
                      name=prefix + "img")

     score = dc(prediction[0, ...], truth[0, ...])
     print("dc:%f" % (score))
     return score
Exemple #11
0
    def gen_one_batch(self, genSample, img_fix, lab_fix, img_mv, lab_mv,
                      output_dir):
        """Run the network on one fixed/moving pair, save outputs, return Dice.

        Written to ``output_dir`` (first batch element of each, no parameter
        image): the fixed image/label as ``target_*``, the warped moving
        image/label as ``atlas_*``, the unwarped moving image under its own
        name, and the restored moving and fixed images with ``'image'``
        replaced by ``'restore'`` in their names. One of the fetched theta
        tensors is printed.

        :param genSample: sampler providing ``get_batch_data_V2``
        :param img_fix: path of the fixed image
        :param lab_fix: path of the fixed label
        :param img_mv: path of the moving image
        :param lab_mv: path of the moving label
        :param output_dir: directory receiving the written volumes
        :return: ``(bf_ds, ds)`` - ``calculate_binary_dice`` of the moving
            label vs. the fixed label before warping, and after warping
        """
        batch = genSample.get_batch_data_V2(
            [img_mv], [img_fix], [lab_mv], [lab_fix])
        fix_imgs, fix_labs, mv_imgs, mv_labs = batch
        feed = self.create_feed_dict(fix_imgs, fix_labs, mv_imgs, mv_labs)

        fetches = [
            self.input_MV_label,
            self.input_FIX_label,
            self.warped_MV_label,
            self.input_MV_image,
            self.input_FIX_image,
            self.warped_MV_image,
            self.theta_bw,
            self.theta_fw,
        ]
        # NOTE(review): the last two results are unpacked as (fw, bw) while
        # the fetch order is (theta_bw, theta_fw), so `fw` actually holds
        # theta_bw. Kept as-is; confirm which one is meant to be printed.
        (mv_label, fix_label, warped_label,
         mv_img, fix_img, warped_img,
         fw, bw) = self.sess.run(fetches, feed_dict=feed)

        # No parameter image: volumes are written without copied geometry.
        param = None

        fix_img_name = get_name_wo_suffix(img_fix)
        sitk_write_image(fix_img[0, ...], param, output_dir,
                         fix_img_name.replace('image', 'target_image'))
        fix_lab_name = get_name_wo_suffix(lab_fix)
        sitk_write_lab(fix_label[0, ...], param, output_dir,
                       fix_lab_name.replace('label', 'target_label'))

        mv_img_name = get_name_wo_suffix(img_mv)
        sitk_write_image(warped_img[0, ...], param, output_dir,
                         mv_img_name.replace('image', 'atlas_image'))
        mv_lab_name = get_name_wo_suffix(lab_mv)
        sitk_write_lab(warped_label[0, ...], param, output_dir,
                       mv_lab_name.replace('label', 'atlas_label'))

        # Second graph evaluation for the restored images.
        restored_mv, restored_fix = self.sess.run(
            [self.restore_MV_image, self.restore_FIX_image], feed_dict=feed)
        sitk_write_image(mv_img[0, ...], param, output_dir, mv_img_name)
        sitk_write_image(restored_mv[0, ...], param, output_dir,
                         mv_img_name.replace('image', 'restore'))
        sitk_write_image(restored_fix[0, ...], param, output_dir,
                         fix_img_name.replace('image', 'restore'))

        print(fw)

        dice_after = calculate_binary_dice(warped_label, fix_label)
        dice_before = calculate_binary_dice(mv_label, fix_label)
        return dice_before, dice_after
Exemple #12
0
    def fusion_one_target(self, itr):
        """Fuse the atlas labels of one validation target, weighted by predicted similarity.

        A sample is drawn from ``self.validate_sampler``. The target label
        and image are written to ``self.args.validate_dir``. For each atlas
        (indexed along the second-to-last axis of the atlas label batch) the
        network predicts a similarity map; the atlas label and that map are
        written to the same directory. For every unique value of the target
        label, the similarity of the atlases carrying that value is summed
        per voxel, and the fused label takes the value with the largest
        sum. The fused label is written next to the target label file with
        ``'label'`` replaced by ``'net_fusion_label'``.

        :param itr: iteration counter (not used in the body)
        :return: ``(ds, hd)`` from ``calculate_binary_dice`` and
            ``calculate_binary_hd`` of the fused label vs. the target label
        """
        (target_img_batch, target_lab_batch, atlas_img_batch,
         atlas_lab_batch, sim_batch, p_fix_img,
         p_fix_lab) = self.validate_sampler.next_sample_4_fusion()

        predicted_sims = []
        validate_dir = self.args.validate_dir
        fix_lab_path = p_fix_lab[0]

        param = sitk.ReadImage(p_fix_img[0])
        sitk_write_lab(np.squeeze(target_lab_batch.astype(np.uint8)),
                       param,
                       dir=validate_dir,
                       name=get_name_wo_suffix(fix_lab_path))
        sitk_write_image(np.squeeze(target_img_batch),
                         param,
                         dir=validate_dir,
                         name=get_name_wo_suffix(p_fix_img[0]))

        fetches = [
            self.ph_target_image, self.ph_atlas_label, self.ph_gt_dicesim,
            self.ph_target_label, self.predict_label, self.predict_sim
        ]
        for atlas_idx in range(atlas_lab_batch.shape[-2]):
            atlas_slice = atlas_lab_batch[..., atlas_idx, :]
            feed = {
                self.ph_target_image: target_img_batch,
                self.ph_target_label: target_lab_batch,
                self.ph_atlas_label: atlas_slice,
                self.ph_gt_dicesim: sim_batch[..., atlas_idx, :]
            }
            # Only the echoed target label and the predicted similarity are
            # used afterwards; the remaining fetches are kept unchanged.
            (target_img, atlas_label, gt, target_lab, pred_lab,
             pred_sim) = self.sess.run(fetches, feed_dict=feed)
            predicted_sims.append(pred_sim)

            suffix = str(atlas_idx)
            sitk_write_lab(np.squeeze(atlas_lab_batch[..., atlas_idx, :]),
                           param,
                           dir=validate_dir,
                           name=get_name_wo_suffix(
                               fix_lab_path.replace('label',
                                                    'label_' + suffix)))
            sitk_write_image(np.squeeze(pred_sim),
                             param,
                             dir=validate_dir,
                             name=get_name_wo_suffix(fix_lab_path).replace(
                                 'label', 'sim_' + suffix))

        sim_stack = np.stack(predicted_sims, -1)
        label_values = np.unique(target_lab.astype(np.uint8))

        # One accumulator volume per label value; the atlas axis is summed out.
        votes = np.zeros((len(label_values), ) + np.squeeze(target_lab).shape)
        atlas_labels = np.squeeze(atlas_lab_batch)
        weights = np.squeeze(sim_stack)
        for row, value in enumerate(label_values):
            carries_value = (atlas_labels == value).astype(np.int16)
            votes[row] = np.sum(carries_value * weights, axis=-1)

        fusion_label = label_values[np.argmax(votes, axis=0)]
        dice = calculate_binary_dice(fusion_label, target_lab_batch)
        hausdorff = calculate_binary_hd(fusion_label,
                                        target_lab_batch,
                                        spacing=param.GetSpacing())

        sitk_write_lab(np.squeeze(fusion_label).astype(np.uint8),
                       param,
                       dir=os.path.dirname(fix_lab_path),
                       name=get_name_wo_suffix(fix_lab_path).replace(
                           'label', 'net_fusion_label'))
        return dice, hausdorff
 def gen_warp_atlas(self, dir, genSample, i, img_fix, lab_fix, is_aug=True):
     """Warp every moving atlas onto one fixed target and save the results.

     The output directory ``<gen_dir>/<dir>/target_<i>_<fixed name>`` is
     created or cleared. Each ``(img_mv, lab_mv)`` pair of ``genSample`` is
     run through the network against the fixed image/label; fixed and
     warped labels are binarised at 0.5. For every atlas the Dice score and
     ``conditional_entropy_label_over_image`` of the fixed image and the
     warped label are recorded. Warped images/labels are then written in
     ascending order of that entropy value (file names carry the atlas
     index), followed by the fixed image/label of the first atlas run. The
     Dice scores, their mean, and the mean over the five lowest-entropy
     atlases are logged at debug level.

     The fixed image is read from disk once and reused as the parameter
     image for every written volume.

     :param dir: sub-directory name below ``self.args.gen_dir``
     :param genSample: sampler providing ``img_mv``, ``lab_mv`` and
         ``get_batch_data_V2``
     :param i: target index used in the output directory name
     :param img_fix: path of the fixed image
     :param lab_fix: path of the fixed label
     :param is_aug: forwarded to ``self.create_feed_dict``
     :return: None
     :raises IndexError: if ``genSample`` yields no moving atlas
     """
     output_dir = self.args.gen_dir + "/" + dir + "/target_" + str(
         i) + "_" + get_name_wo_suffix(img_fix)
     mk_or_cleardir(output_dir)

     # Loop-invariant: every atlas shares the same fixed image.
     param = sitk.ReadImage(img_fix)

     input_fix_imgs = []
     input_fix_labels = []
     warp_mv_imgs = []
     warp_mv_labels = []
     losses = []
     sims = []
     fetches = [
         self.input_MV_label,
         self.input_FIX_label,
         self.warped_MV_label,
         self.input_MV_image,
         self.input_FIX_image,
         self.warped_MV_image,
     ]
     for img_mv, lab_mv in zip(genSample.img_mv, genSample.lab_mv):
         fix_imgs, fix_labs, mv_imgs, mv_labs = genSample.get_batch_data_V2(
             [img_mv], [img_fix], [lab_mv], [lab_fix])
         feed = self.create_feed_dict(fix_imgs, fix_labs, mv_imgs, mv_labs,
                                      is_aug)
         # The moving inputs are fetched but not used further here.
         (_input_mv_label, input_fix_label, warp_mv_label,
          _input_mv_img, input_fix_img,
          warp_mv_img) = self.sess.run(fetches, feed_dict=feed)

         # Binarise the label maps at 0.5.
         input_fix_label = np.where(input_fix_label > 0.5, 1, 0)
         warp_mv_label = np.where(warp_mv_label > 0.5, 1, 0)

         sims.append(
             calculate_binary_dice(input_fix_label[0, ...],
                                   warp_mv_label[0, ...]))
         input_fix_imgs.append(input_fix_img[0, ...])
         input_fix_labels.append(input_fix_label[0, ...])
         warp_mv_imgs.append(warp_mv_img[0, ...])
         warp_mv_labels.append(warp_mv_label[0, ...])
         losses.append(
             conditional_entropy_label_over_image(
                 np.squeeze(input_fix_img[0, ...]),
                 np.squeeze(warp_mv_label[0, ...])))

     # Write atlases from lowest to highest conditional entropy.
     indexs = np.argsort(losses)
     for ind in indexs:
         sitk_write_image(warp_mv_imgs[ind],
                          param,
                          dir=output_dir,
                          name=str(ind) + "_mv_img")
         sitk_write_lab(warp_mv_labels[ind],
                        param,
                        dir=output_dir,
                        name=str(ind) + "_mv_lab")

     sitk_write_image(input_fix_imgs[0],
                      param,
                      dir=output_dir,
                      name=str(0) + "_fix_img")
     sitk_write_lab(input_fix_labels[0],
                    param,
                    dir=output_dir,
                    name=str(0) + "_fix_lab")

     self.logger.debug(sims)
     self.logger.debug("%s %f -> %f" %
                       (output_dir, np.mean(sims),
                        np.mean([sims[ind] for ind in indexs[:5]])))