示例#1
0
    def validate(self):
        """Evaluate the (restored) model on every fixed/moving validation pair.

        Runs the TF variable initializer, creates ``self.saver`` and then tries
        to restore weights from ``self.args.checkpoint_dir``; a failed restore is
        only reported and evaluation continues with the initialized weights.

        For each fixed image of ``AbdomenSampler(self.args, 'validate')`` it
        clears/creates ``<sample_dir>/atlas_target/<fixed name w/o suffix>`` and
        calls ``gen_one_batch`` against every moving image, collecting the two
        values it returns. The mean/std of both collections is printed at the end.
        """
        self.is_train = False
        init_op = tf.global_variables_initializer()
        self.saver = tf.train.Saver()
        self.sess.run(init_op)
        # Initialize first, then attempt a restore (presumably self.load
        # overwrites the freshly initialized values on success).
        if self.load(self.args.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        genSample = AbdomenSampler(self.args, 'validate')
        outputdir = self.args.sample_dir + "/atlas_target/"
        # One entry per (fixed, moving) pair, in the order gen_one_batch returns
        # them. NOTE(review): reading these as "similarity before (bfds) / after
        # (ds) registration" comes from the names only -- confirm in gen_one_batch.
        ds_all = []
        bf_ds_all = []
        for img_fix, lab_fix in zip(genSample.img_fix, genSample.lab_fix):
            output_dir = outputdir + get_name_wo_suffix(img_fix)
            mk_or_cleardir(output_dir)
            print(output_dir)
            for img_mv, lab_mv in zip(genSample.img_mv, genSample.lab_mv):
                bfds, ds = self.gen_one_batch(genSample, img_fix, lab_fix,
                                              img_mv, lab_mv, output_dir)
                ds_all.append(ds)
                bf_ds_all.append(bfds)
        self.logger.debug("%s total_sim %s->%s " % (Get_Name_By_Index(
            self.args.component), self.args.Tatlas, self.args.Ttarget))
        print_mean_and_std(bf_ds_all)
        print_mean_and_std(ds_all)
示例#2
0
def prepare_unsupervised_reg_data(args):
    """Build the cropped-ROI dataset under ``args.dataset_dir`` exactly once.

    If ``args.dataset_dir`` already exists nothing happens. Otherwise the
    directory is created/cleared, and for every modality parsed from
    ``args.mode`` the sorted image and label NIfTI files of
    ``../../dataset/<task>/<modality>-image|label/`` are cropped into
    ``<dataset_dir>/<modality>/`` by ``crop_ROI_data_by_label``.
    """
    if os.path.exists(args.dataset_dir):
        return
    mk_or_cleardir(args.dataset_dir)
    for modality in parse_arg_list(args.mode):
        image_files = sort_glob("../../dataset/%s/%s-image/*.nii.gz" % (args.task, modality))
        label_files = sort_glob("../../dataset/%s/%s-label/*.nii.gz" % (args.task, modality))
        destination = args.dataset_dir + "/%s/" % modality
        crop_ROI_data_by_label(image_files, label_files, destination, args.component)
示例#3
0
 def generator_sample_atlaswise(self, genSample, outputdir):
     """Run ``gen_test`` for every fixed image against each moving image.

     Results for moving image ``m`` go to ``outputdir + <m without suffix>``;
     that directory is cleared/created with ``mk_or_cleardir`` and printed
     before its fixed images are processed.
     """
     moving_pairs = zip(genSample.img_mv, genSample.lab_mv)
     for moving_img, moving_lab in moving_pairs:
         moving_dir = outputdir + get_name_wo_suffix(moving_img)
         mk_or_cleardir(moving_dir)
         print(moving_dir)
         fixed_pairs = zip(genSample.img_fix, genSample.lab_fix)
         for fixed_img, fixed_lab in fixed_pairs:
             self.gen_test(genSample, fixed_img, fixed_lab,
                           moving_img, moving_lab, moving_dir)
示例#4
0
def prepare_crossvalidation_reg_data(args):
    """Build the atlas/target cropped-ROI dataset under ``args.dataset_dir`` once.

    If ``args.dataset_dir`` already exists nothing happens. Otherwise the
    directory is created/cleared, then:

    * images/labels of modality ``args.Tatlas`` are cropped into
      ``<dataset_dir>/atlas/``;
    * images/labels of modality ``args.Ttarget`` are cropped into
      ``<dataset_dir>/target/``.

    Source files come from ``../../dataset/<task>/<modality>-image|label/``.
    """
    if os.path.exists(args.dataset_dir):
        return
    mk_or_cleardir(args.dataset_dir)
    # Both roles use the same glob+crop; only the modality and subfolder differ.
    for modality, subdir in ((args.Tatlas, "atlas"), (args.Ttarget, "target")):
        imgs = sort_glob("../../dataset/%s/%s-image/*.nii.gz" % (args.task, modality))
        labs = sort_glob("../../dataset/%s/%s-label/*.nii.gz" % (args.task, modality))
        crop_ROI_data_by_label(imgs, labs, args.dataset_dir + "/%s/" % subdir, args.component)
示例#5
0
def split(test_imgs_mv, test_lab_mv, train_imgs_mv, train_lab_mv, n_test=8):
    """Move the first ``n_test`` image/label pairs from the train dirs to the test dirs.

    Both test directories are cleared/created with ``mk_or_cleardir`` first.
    Source files are the ``*.nii.gz`` files of each train directory, paired by
    sorted order (the i-th image goes with the i-th label). The selection is
    deterministic: the first ``n_test`` pairs in sorted order.

    Args:
        test_imgs_mv: destination directory for the images.
        test_lab_mv: destination directory for the labels.
        train_imgs_mv: source directory of the images.
        train_lab_mv: source directory of the labels.
        n_test: number of pairs to move (default 8, the value previously
            hard-coded). Fewer are moved if fewer exist.

    Raises:
        ValueError: if the two source directories contain a different number
            of files, since sorted-order pairing would then be wrong. Raised
            before any file is moved.
    """
    mk_or_cleardir(test_imgs_mv)
    mk_or_cleardir(test_lab_mv)
    file_img_mv = sorted(glob.glob(train_imgs_mv + "/*.nii.gz"))
    file_lab_mv = sorted(glob.glob(train_lab_mv + "/*.nii.gz"))
    # Validate up front: a mismatch used to surface as an IndexError midway
    # through the moves, leaving the train/test split half-done.
    if len(file_img_mv) != len(file_lab_mv):
        raise ValueError(
            "image/label count mismatch: %d images in %r, %d labels in %r"
            % (len(file_img_mv), train_imgs_mv, len(file_lab_mv), train_lab_mv))
    for img_path, lab_path in zip(file_img_mv[:n_test], file_lab_mv[:n_test]):
        shutil.move(img_path, test_imgs_mv)
        shutil.move(lab_path, test_lab_mv)
示例#6
0
def sitk_write_lab(input_, parameter_img=None, dir=None, name=''):
    """Write a label volume to ``<dir>/<name>.nii.gz``; do nothing if ``dir`` is None.

    A non-``sitk.Image`` input is binarized (values ``> 0.5`` become 1, the
    rest 0), cast to ``np.uint16`` and converted with
    ``sitk.GetImageFromArray``; an ``sitk.Image`` input is written unchanged.
    If ``parameter_img`` is given, its image information is copied onto the
    output with ``CopyInformation``. A missing ``dir`` is created via
    ``mk_or_cleardir``.
    """
    if dir is None:
        return
    if not os.path.exists(dir):
        mk_or_cleardir(dir)
    if isinstance(input_, sitk.Image):
        img = input_
    else:
        binary = np.where(input_ > 0.5, 1, 0).astype(np.uint16)
        img = sitk.GetImageFromArray(binary)
    if parameter_img is not None:
        img.CopyInformation(parameter_img)
    target_path = os.path.join(dir, name + '.nii.gz')
    sitk.WriteImage(img, target_path)
示例#7
0
 def generator_sample_targetwise(self, genSample, outputdir):
     """Run ``gen_one_batch`` for every moving image against each fixed image.

     Each fixed image gets the directory ``outputdir + <fixed name w/o suffix>``,
     which is cleared/created and printed first. The two values returned by
     ``gen_one_batch`` for every pair are collected and written with
     ``outpu2excel`` to ``self.args.res_excel`` under the keys
     ``<MOLD_ID>_DS`` and ``<MOLD_ID>_HD``.
     NOTE(review): this assumes gen_one_batch returns (ds, hd) in that order --
     confirm against its implementation.
     """
     ds_values, hd_values = [], []
     for fixed_img, fixed_lab in zip(genSample.img_fix, genSample.lab_fix):
         fixed_dir = outputdir + get_name_wo_suffix(fixed_img)
         mk_or_cleardir(fixed_dir)
         print(fixed_dir)
         for moving_img, moving_lab in zip(genSample.img_mv, genSample.lab_mv):
             ds, hd = self.gen_one_batch(genSample, fixed_img, fixed_lab,
                                         moving_img, moving_lab, fixed_dir)
             ds_values.append(ds)
             hd_values.append(hd)
     excel_path = self.args.res_excel
     outpu2excel(excel_path, self.args.MOLD_ID + "_DS", ds_values)
     outpu2excel(excel_path, self.args.MOLD_ID + "_HD", hd_values)
示例#8
0
def sitk_write_images(input_, parameter_img=None, dir=None, name=''):
    """Write each volume of a batch to ``<dir>/<name>_<idx>.nii.gz``.

    Nothing is written when ``dir`` is None. ``input_`` is iterated along its
    first axis. An element that is already an ``sitk.Image`` is written as-is.
    Any other element (e.g. a numpy sub-array) is converted with
    ``sitk.GetImageFromArray``. If ``parameter_img`` is given, its image
    information is copied onto every output with ``CopyInformation``. A
    missing ``dir`` is created via ``mk_or_cleardir``.

    Args:
        input_: batch of volumes -- a numpy array with the batch on axis 0,
            or a sequence of arrays and/or ``sitk.Image`` objects.
        parameter_img: optional reference image whose information is copied.
        dir: output directory; nothing is written when None.
        name: filename prefix; the element index is appended.
    """
    if dir is None:
        return
    if not os.path.exists(dir):
        mk_or_cleardir(dir)
    # Check the type per element, not per batch: an sitk.Image has no .shape,
    # so a whole-batch check could never select the sitk branch. Per-element
    # checking also lets callers pass a list of sitk.Image volumes.
    for idx, volume in enumerate(input_):
        if isinstance(volume, sitk.Image):
            img = volume
        else:
            img = sitk.GetImageFromArray(volume)
        if parameter_img is not None:
            img.CopyInformation(parameter_img)
        sitk.WriteImage(img, os.path.join(dir, name + '_%s.nii.gz' % idx))
示例#9
0
def main(_):
    """Entry point: run the phase selected by the module-level ``args.phase``.

    Calls ``globel_setup()``, opens a TF session with soft device placement
    and GPU memory growth, then dispatches:

    * ``'train'``    -- ensure and clear ``args.log_dir``, prepare CHAOS or
      MMWHS working data depending on ``args.task``, then train the model.
    * ``'validate'`` -- build the model and run ``validate()``.
    * ``'gen'``      -- clear ``args.sample_dir`` and run ``generate_4_fusion()``.
    * ``'summary'``  -- run ``cross_validate(args)``.

    Any other value prints ``"undefined phase"``. Phases that were only
    sketched as disabled code (test, post, trainSim, testSim, fusion) are not
    wired up.

    Args:
        _: unused (presumably the argv passed by a tf.app.run-style
           launcher -- verify).
    """
    globel_setup()
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session(config=tfconfig) as sess:

        def run_train():
            mkdir_if_not_exist(args.log_dir)
            mk_or_cleardir(args.log_dir)
            if args.task == 'CHAOS':
                prepare_chaos_reg_working_data(args)
            else:
                prepare_mmwhs_reg_working_data(args)
            LabAttentionReg(sess, args).train()

        def run_validate():
            LabAttentionReg(sess, args).validate()

        def run_gen():
            # Per the original author's notes: this generates data under
            # ../datasets/sim_ct_mr_** and produces both the validation and
            # the training data in one pass (not verifiable from here).
            mk_or_cleardir(args.sample_dir)
            LabAttentionReg(sess, args).generate_4_fusion()

        def run_summary():
            cross_validate(args)

        phase_handlers = {
            'train': run_train,
            'validate': run_validate,
            'gen': run_gen,
            'summary': run_summary,
        }
        handler = phase_handlers.get(args.phase)
        if handler is None:
            print("undefined phase")
        else:
            handler()
示例#10
0
def prepare_3dUnet_ROI_data(args):
    """Generate 3D U-Net ROI data into ``args.dataset_dir`` unless it already exists.

    When the directory is absent it is created/cleared via ``mk_or_cleardir``
    and filled by ``generator_ROI_data_for_3DUnet(args)``.
    """
    already_built = os.path.exists(args.dataset_dir)
    if already_built:
        return
    mk_or_cleardir(args.dataset_dir)
    generator_ROI_data_for_3DUnet(args)
示例#11
0
def prepare_mmwhs_reg_working_data(args):
    """Generate ROI registration data into ``args.dataset_dir`` unless it exists.

    When the directory is absent it is created/cleared via ``mk_or_cleardir``
    and filled by ``generator_ROI_data(args)``.
    """
    already_built = os.path.exists(args.dataset_dir)
    if already_built:
        return
    mk_or_cleardir(args.dataset_dir)
    generator_ROI_data(args)
示例#12
0
 def gen_warp_atlas(self, dir, genSample, i, img_fix, lab_fix, is_aug=True):
     """Warp every moving image onto one fixed image and save the results ranked by loss.

     For each (moving image, moving label) pair of ``genSample``:

     * a batch is built with ``genSample.get_batch_data_V2``;
     * the batch is run through the graph to fetch the fixed/warped images and labels;
     * labels are binarized at 0.5;
     * the Dice between the fixed and the warped label is recorded;
     * a loss is computed as
       ``conditional_entropy_label_over_image(fixed image, warped label)``.

     Outputs are written to ``<gen_dir>/<dir>/target_<i>_<fixed name>``, which
     is cleared first:

     * every warped image/label as ``<k>_mv_img`` / ``<k>_mv_lab``, where ``k``
       is the pair's index in ``genSample``, written in ascending-loss order;
     * the first pair's fixed image/label as ``0_fix_img`` / ``0_fix_lab``.

     All outputs take their image information from ``sitk.ReadImage(img_fix)``.
     Finally the per-pair Dice list is logged, followed by the mean Dice over
     all pairs and the mean over the 5 lowest-loss pairs.

     Args:
         dir: subdirectory name under ``self.args.gen_dir``.
         genSample: sampler exposing ``img_mv``, ``lab_mv`` and ``get_batch_data_V2``.
         i: index embedded in the output directory name.
         img_fix: path of the fixed image.
         lab_fix: path of the fixed label.
         is_aug: forwarded to ``create_feed_dict``.

     Raises:
         ValueError: if ``genSample`` has no moving image/label pairs. It is
             raised before the output directory is cleared.
     """
     moving_pairs = list(zip(genSample.img_mv, genSample.lab_mv))
     if not moving_pairs:
         raise ValueError("gen_warp_atlas: genSample has no moving image/label pairs")
     output_dir = self.args.gen_dir + "/" + dir + "/target_" + str(
         i) + "_" + get_name_wo_suffix(img_fix)
     mk_or_cleardir(output_dir)
     # The reference geometry depends only on img_fix: read it once, not per pair.
     param = sitk.ReadImage(img_fix)
     # Only the first pair's fixed image/label is written, so keep just that one.
     first_fix_img = None
     first_fix_label = None
     warp_mv_imgs = []
     warp_mv_labels = []
     losses = []
     sims = []
     for img_mv, lab_mv in moving_pairs:
         fix_imgs, fix_labs, mv_imgs, mv_labs = genSample.get_batch_data_V2(
             [img_mv], [img_fix], [lab_mv], [lab_fix])
         trainFeed = self.create_feed_dict(fix_imgs, fix_labs, mv_imgs,
                                           mv_labs, is_aug)
         input_mv_label, \
         input_fix_label, \
         warp_mv_label, \
         input_mv_img, \
         input_fix_img, \
         warp_mv_img = self.sess.run([self.input_MV_label,
                                      self.input_FIX_label,
                                      self.warped_MV_label,
                                      self.input_MV_image,
                                      self.input_FIX_image,
                                      self.warped_MV_image], feed_dict=trainFeed)
         input_fix_label = np.where(input_fix_label > 0.5, 1, 0)
         warp_mv_label = np.where(warp_mv_label > 0.5, 1, 0)
         sims.append(
             calculate_binary_dice(input_fix_label[0, ...],
                                   warp_mv_label[0, ...]))
         if first_fix_img is None:
             first_fix_img = input_fix_img[0, ...]
             first_fix_label = input_fix_label[0, ...]
         warp_mv_imgs.append(warp_mv_img[0, ...])
         warp_mv_labels.append(warp_mv_label[0, ...])
         losses.append(
             conditional_entropy_label_over_image(
                 np.squeeze(input_fix_img[0, ...]),
                 np.squeeze(warp_mv_label[0, ...])))
     # Ascending loss: indexs[0] is the lowest-loss pair.
     indexs = np.argsort(losses)
     for ind in indexs:
         sitk_write_image(warp_mv_imgs[ind],
                          param,
                          dir=output_dir,
                          name=str(ind) + "_mv_img")
         sitk_write_lab(warp_mv_labels[ind],
                        param,
                        dir=output_dir,
                        name=str(ind) + "_mv_lab")
     sitk_write_image(first_fix_img,
                      param,
                      dir=output_dir,
                      name=str(0) + "_fix_img")
     sitk_write_lab(first_fix_label,
                    param,
                    dir=output_dir,
                    name=str(0) + "_fix_lab")
     self.logger.debug(sims)
     self.logger.debug("%s %f -> %f" %
                       (output_dir, np.mean(sims),
                        np.mean([sims[ind] for ind in indexs[:5]])))