Exemplo n.º 1
0
    def __sample(self, iter):
        """Evaluate one validation batch, dump it to disk and score it.

        A batch is drawn from ``self.validate_sampler``. The target image, the
        atlas label (fed as the warp label) and the target label are fed to the
        session and fetched back together with ``self.predit``. Every array is
        written into ``self.args.sample_dir`` using ``str(iter)`` as a filename
        prefix.

        Args:
            iter: Step index; only used to prefix the output filenames.

        Returns:
            ``calculate_binary_dice`` of the arg-maxed (last axis) ground truth
            against the prediction. The value is also logged at DEBUG level.
        """
        (target_imgs, target_labs, _atlas_imgs,
         atlas_labs) = self.validate_sampler.next_sample()

        feed = {
            self.ph_target_image: target_imgs,
            self.ph_warp_label: atlas_labs,
            self.ph_gt_label: target_labs,
        }
        fetches = [
            self.ph_target_image,
            self.ph_warp_label,
            self.ph_gt_label,
            self.predit,
        ]
        target_img, warp_label, gt, pred = self.sess.run(fetches,
                                                         feed_dict=feed)

        # Each entry is (writer, array, filename suffix). The ground truth is
        # collapsed with argmax over its last axis before being written as a
        # label map.
        dumps = (
            (sitk_write_images, target_img, "target_img"),
            (sitk_write_images, warp_label, "warp_label"),
            (sitk_write_labs, np.argmax(gt, axis=-1), "gt"),
            (sitk_write_labs, pred, "pred"),
        )
        for writer, array, suffix in dumps:
            writer(array, dir=self.args.sample_dir, name=str(iter) + suffix)

        acc = calculate_binary_dice(np.argmax(gt, axis=-1), pred)
        self.logger.debug("acc:" + str(acc))
        return acc
Exemplo n.º 2
0
    def sample(self, iter):
        """Run one validation batch through the registration graph and dump it.

        A (fixed, moving) batch is drawn from ``self.validate_sampler`` and fed
        together with values from ``util.initial_transform_generator`` (moving
        and fixed affine) and ``util.init_ddf_generator`` (random DDF). The
        input labels, the warped moving label and the warped moving image are
        written into ``self.args.sample_dir`` with an ``iter`` prefix.

        Args:
            iter: Step index; used as filename prefix and in log messages.

        Returns:
            ``calculate_binary_dice(warped moving labels, fixed labels)``.
            That Dice value is logged at DEBUG level. So is the third value
            returned by ``neg_jac`` for the first DDF in the batch.
        """
        fix_imgs, fix_labs, mv_imgs, mv_labs = self.validate_sampler.next_sample()
        batch_size = self.args.batch_size
        # Insertion order keeps the generator calls in their original order
        # (moving affine, fixed affine, random DDF).
        trainFeed = {
            self.ph_MV_image: mv_imgs,
            self.ph_FIX_image: fix_imgs,
            self.ph_MV_label: mv_labs,
            self.ph_FIX_label: fix_labs,
            self.ph_moving_affine: util.initial_transform_generator(batch_size),
            self.ph_fixed_affine: util.initial_transform_generator(batch_size),
            self.ph_random_ddf: util.init_ddf_generator(batch_size,
                                                        self.image_size),
        }

        # Fetch everything in a single sess.run. The shared subgraph is then
        # evaluated once instead of once per call. It also guarantees that the
        # dumped labels, image and DDF come from the same evaluation.
        (input_mv_label, input_fix_label, warp_mv_labs, warped_mv_imgs,
         ddf_mv_fix) = self.sess.run(
             [
                 self.input_MV_label,
                 self.input_FIX_label,
                 self.warped_MV_label,
                 self.warped_MV_image,
                 self.ddf_MV_FIX,
             ],
             feed_dict=trainFeed)

        out_dir = self.args.sample_dir
        sitk_write_labs(input_fix_label, None, out_dir,
                        '%d_input_fix_lab' % (iter))
        sitk_write_labs(input_mv_label, None, out_dir,
                        '%d_input_mv_lab' % (iter))
        # No threshold is applied to the warped labels before they are
        # written and scored below.
        sitk_write_labs(warp_mv_labs, None, out_dir,
                        '%d_warp_mv_lab' % (iter))
        sitk_write_images(warped_mv_imgs, None, out_dir,
                          '%d_warp_mv_img' % (iter))

        sim_warp_mv_fix = calculate_binary_dice(warp_mv_labs, fix_labs)
        # neg_jac yields three values; only the third is reported.
        _, _, neg2 = neg_jac(ddf_mv_fix[0, ...])
        self.logger.debug("global_step %d: dice :  warp_mv_fix=%f", iter,
                          sim_warp_mv_fix)
        self.logger.debug("neg_jac %d", neg2)
        return sim_warp_mv_fix
Exemplo n.º 3
0
    def __sample(self, itr):
        """Dump one validation batch together with the predicted similarity.

        A batch is drawn from ``self.validate_sampler``. The target/atlas
        images and labels and the similarity batch are fed to the session. The
        target image, atlas image and similarity are fetched back together with
        ``self.prob``. The sampled labels and all fetched arrays are written
        into ``self.args.sample_dir`` using ``str(itr)`` as a filename prefix.

        Args:
            itr: Step index; only used to prefix the output filenames.

        Returns:
            None.
        """
        (target_imgs, target_labs, atlas_imgs, atlas_labs,
         sims) = self.validate_sampler.next_sample()

        feed = {
            self.ph_target_image: target_imgs,
            self.ph_target_label: target_labs,
            self.ph_atlas_image: atlas_imgs,
            self.ph_atlas_label: atlas_labs,
            self.ph_sim: sims,
        }
        fetched = self.sess.run(
            [self.ph_target_image, self.ph_atlas_image, self.ph_sim,
             self.prob],
            feed_dict=feed)
        target_img, atlas_img, gt_sim, pred = fetched

        # (writer, array, filename suffix) in the order the files are written.
        dumps = (
            (sitk_write_labs, target_labs, "target_lab"),
            (sitk_write_labs, atlas_labs, "atlas_lab"),
            (sitk_write_images, target_img, "target_img"),
            (sitk_write_images, atlas_img, "atlas_img"),
            (sitk_write_images, gt_sim, "gt_sim"),
            (sitk_write_images, pred, "pred_sim"),
        )
        prefix = str(itr)
        for writer, array, suffix in dumps:
            writer(array, dir=self.args.sample_dir, name=prefix + suffix)
Exemplo n.º 4
0
    def __sample(self, itr):
        """Dump one validation batch with predicted label and similarity.

        A batch is drawn from ``self.validate_sampler``; its atlas image is not
        used. The target image/label, atlas label and ground-truth Dice
        similarity are fed to the session. They are fetched back together with
        ``self.predict_label`` and ``self.predict_sim``, and every array is
        written into ``self.args.sample_dir`` with ``str(itr)`` as prefix.

        Args:
            itr: Step index; only used to prefix the output filenames.

        Returns:
            ``calculate_binary_dice(target_lab, pred_lab, 0.2)``. The meaning
            of the third argument is not visible here (presumably a
            threshold; confirm). The Dice of the target label against the raw
            atlas label is logged alongside it for comparison.
        """
        (target_imgs, target_labs, _unused_atlas_imgs, atlas_labs,
         dice_sims) = self.validate_sampler.next_sample()

        feed = {
            self.ph_target_image: target_imgs,
            self.ph_target_label: target_labs,
            self.ph_atlas_label: atlas_labs,
            self.ph_gt_dicesim: dice_sims,
        }
        fetches = [
            self.ph_target_image,
            self.ph_atlas_label,
            self.ph_gt_dicesim,
            self.ph_target_label,
            self.predict_label,
            self.predict_sim,
        ]
        (target_img, atlas_label, gt, target_lab, pred_lab,
         pred_sim) = self.sess.run(fetches, feed_dict=feed)

        # (writer, array, filename suffix) in the order the files are written.
        dumps = (
            (sitk_write_images, target_img, "target_img"),
            (sitk_write_labs, atlas_label, "atlas_label"),
            (sitk_write_images, gt, "gt_sim"),
            (sitk_write_images, pred_sim, "pred_sim"),
            (sitk_write_labs, target_lab, "target_label"),
            (sitk_write_labs, pred_lab, "pred_label"),
        )
        prefix = str(itr)
        for writer, array, suffix in dumps:
            writer(array, dir=self.args.sample_dir, name=prefix + suffix)

        baseline = calculate_binary_dice(target_lab, atlas_label)
        score = calculate_binary_dice(target_lab, pred_lab, 0.2)
        self.logger.debug("pre_acc %f -> acc %f" % (baseline, score))
        return score
Exemplo n.º 5
0
    def sample(self, iter):
        """Run one validation batch through the registration graph and dump it.

        A (fixed, moving) batch is drawn from ``self.validate_sampler`` and fed
        via ``self.create_feed_dict(..., is_aug=False)``. The graph is evaluated
        once. The following are written into ``self.args.sample_dir`` with an
        ``iter`` prefix:

        - the input labels, binarised at 0.5;
        - the warped moving label, binarised at 0.5;
        - the input, warped and fed images.

        Args:
            iter: Step index; used as filename prefix.

        Returns:
            ``(dice, hd)``. ``dice`` is ``calculate_binary_dice`` between the
            binarised warped moving labels and the binarised fed fixed labels.
            ``hd`` is the constant placeholder ``999``, because Hausdorff
            distance is not computed.
        """
        fix_imgs, fix_labs, mv_imgs, mv_labs = self.validate_sampler.next_sample()
        trainFeed = self.create_feed_dict(fix_imgs,
                                          fix_labs,
                                          mv_imgs,
                                          mv_labs,
                                          is_aug=False)

        # A single sess.run evaluates the shared subgraph once instead of once
        # per call. It also guarantees that every dumped array comes from the
        # same evaluation. Fetching the placeholders returns exactly what
        # create_feed_dict fed.
        (input_mv_label, input_fix_label, warp_mv_labs, fix_labs,
         input_fix_imgs, input_mv_imgs, warped_mv_imgs, fix_imgs,
         mv_imgs) = self.sess.run(
             [
                 self.input_MV_label,
                 self.input_FIX_label,
                 self.warped_MV_label,
                 self.ph_FIX_label,
                 self.input_FIX_image,
                 self.input_MV_image,
                 self.warped_MV_image,
                 self.ph_FIX_image,
                 self.ph_MV_image,
             ],
             feed_dict=trainFeed)

        out_dir = self.args.sample_dir

        # The thresholded result must be kept. astype(np.uint16) alone
        # truncates, so a soft label value such as 0.7 would be written as 0.
        input_mv_label = np.where(input_mv_label > 0.5, 1, 0)
        input_fix_label = np.where(input_fix_label > 0.5, 1, 0)
        sitk_write_images(input_fix_label.astype(np.uint16), None, out_dir,
                          '%d_input_fix_lab' % (iter))
        sitk_write_images(input_mv_label.astype(np.uint16), None, out_dir,
                          '%d_input_mv_lab' % (iter))

        warp_mv_labs = np.where(warp_mv_labs > 0.5, 1, 0)
        sitk_write_images(warp_mv_labs.astype(np.uint16), None, out_dir,
                          '%d_warp_mv_lab' % (iter))

        fix_labs = np.where(fix_labs > 0.5, 1, 0)

        sitk_write_images(input_fix_imgs, None, out_dir,
                          '%d_input_fix_img' % (iter))
        sitk_write_images(input_mv_imgs, None, out_dir,
                          '%d_input_mv_img' % (iter))
        sitk_write_images(warped_mv_imgs, None, out_dir,
                          '%d_warp_mv_img' % (iter))
        sitk_write_images(mv_imgs, None, out_dir, '%d_mv_img' % (iter))
        sitk_write_images(fix_imgs, None, out_dir, '%d_fix_img' % (iter))

        sim_warp_mv_fix = calculate_binary_dice(warp_mv_labs, fix_labs)
        # Hausdorff distance is not computed; 999 is returned as a placeholder.
        hd_warp_mv_fix = 999
        return sim_warp_mv_fix, hd_warp_mv_fix