Beispiel #1
0
    def train(self):
        """Train the network for ``self.args.iteration`` steps.

        Initializes the global variables, creates the summary writer and the
        checkpoint saver, then for every step draws a batch from
        ``self.train_sampler``, runs ``self.train_op`` and logs the dice
        between the predicted label and the fed target label.

        Steps with ``itr % print_freq == 1`` additionally log the learning
        rate, print the global step and dump a validation sample; steps with
        ``itr % save_freq == 1`` write a checkpoint.
        """
        self.is_train = True
        self.sess.run(tf.global_variables_initializer())
        self.writer = tf.summary.FileWriter(self.args.log_dir, self.sess.graph)
        self.saver = tf.train.Saver()
        for itr in range(self.args.iteration):
            # The third element of the sample is not fed to the graph.
            (target_img_batch, target_lab_batch, _, atlas_lab_batch,
             sim_batch) = self.train_sampler.next_sample()

            train_feed = {
                self.ph_target_image: target_img_batch,
                self.ph_target_label: target_lab_batch,
                self.ph_atlas_label: atlas_lab_batch,
                self.ph_gt_dicesim: sim_batch,
            }

            _, summary, pred_label, gt_label = self.sess.run(
                [
                    self.train_op, self.summary, self.predict_label,
                    self.ph_target_label
                ],
                feed_dict=train_feed)
            # Write each summary exactly once per step; a second add_summary
            # call with the same step would duplicate every event in the log.
            self.writer.add_summary(summary, global_step=itr)
            self.logger.debug(
                "step %d : dice=%f" %
                (itr, calculate_binary_dice(pred_label, gt_label)))
            # NOTE: the "== 1" test fires at itr = 1, freq + 1, 2 * freq + 1,
            # ... and therefore never fires when the frequency is 1.
            if np.mod(itr, self.args.print_freq) == 1:
                # One session call is enough to read both scalars.
                learning_rate, global_step = self.sess.run(
                    [self.learning_rate, self.global_step])
                self.logger.debug(learning_rate)
                print(global_step)
                self.__sample(itr)
            if np.mod(itr, self.args.save_freq) == 1:
                self.save(self.args.checkpoint_dir, itr)
Beispiel #2
0
    def __sample(self, iter):
        """Run one validation batch, dump it to disk and return its dice.

        Draws a batch from ``self.validate_sampler``, evaluates
        ``self.predit`` on it and writes the target image, the fed warp
        label, the arg-maxed ground truth and the prediction into
        ``self.args.sample_dir``; every file name is prefixed with ``iter``.

        Returns:
            The value of ``calculate_binary_dice`` between the arg-maxed
            ground truth and the prediction.
        """
        # The atlas image (third element) is not needed for the evaluation.
        target_imgs, target_labs, _, atlas_labs = \
            self.validate_sampler.next_sample()

        feed = {
            self.ph_target_image: target_imgs,
            self.ph_warp_label: atlas_labs,
            self.ph_gt_label: target_labs,
        }
        fetches = [
            self.ph_target_image,
            self.ph_warp_label,
            self.ph_gt_label,
            self.predit,
        ]
        target_img, warp_label, gt, pred = self.sess.run(fetches,
                                                         feed_dict=feed)

        out_dir = self.args.sample_dir
        prefix = str(iter)
        # The ground truth is reduced over its last axis with argmax before
        # it is written and compared.
        gt_label = np.argmax(gt, axis=-1)

        sitk_write_images(target_img, dir=out_dir, name=prefix + "target_img")
        sitk_write_images(warp_label, dir=out_dir, name=prefix + "warp_label")
        sitk_write_labs(gt_label, dir=out_dir, name=prefix + "gt")
        sitk_write_labs(pred, dir=out_dir, name=prefix + "pred")

        acc = calculate_binary_dice(gt_label, pred)
        self.logger.debug("acc:" + str(acc))
        return acc
Beispiel #3
0
    def train(self):
        """Train the network for ``self.args.iteration`` steps.

        Initializes the global variables, creates the summary writer and the
        checkpoint saver, then for every step draws a batch from
        ``self.train_sampler``, runs ``self.train_op`` and prints the dice
        between the arg-maxed ground truth and the prediction.

        Steps with ``itr % print_freq == 1`` additionally dump a validation
        sample; steps with ``itr % save_freq == 1`` write a checkpoint.
        """
        self.is_train = True

        self.sess.run(tf.global_variables_initializer())
        self.writer = tf.summary.FileWriter(self.args.log_dir, self.sess.graph)
        self.saver = tf.train.Saver()
        for itr in range(self.args.iteration):
            # The atlas image (third element) is not fed to the graph.
            target_img_batch, target_lab_batch, _, atlas_lab_batch = \
                self.train_sampler.next_sample()

            train_feed = {
                self.ph_target_image: target_img_batch,
                self.ph_warp_label: atlas_lab_batch,
                self.ph_gt_label: target_lab_batch,
            }

            _, summary, gt, predict = self.sess.run(
                [self.train_op, self.summary, self.ph_gt_label, self.predit],
                feed_dict=train_feed)
            # Write each summary exactly once per step; a second add_summary
            # call with the same step would duplicate every event in the log.
            self.writer.add_summary(summary, global_step=itr)
            print(
                "working  itr: %d dice = %f" %
                (itr, calculate_binary_dice(np.argmax(gt, axis=-1), predict)))
            # NOTE: the "== 1" test fires at itr = 1, freq + 1, 2 * freq + 1,
            # ... and therefore never fires when the frequency is 1.
            if np.mod(itr, self.args.print_freq) == 1:
                self.__sample(itr)
            if np.mod(itr, self.args.save_freq) == 1:
                self.save(self.args.checkpoint_dir, itr)
Beispiel #4
0
    def gen_one_batch(self, genSample, img_fix, lab_fix, img_mv, lab_mv,
                      output_dir):
        """Register one moving/fixed pair, write the results and score them.

        Loads the pair through ``genSample.get_batch_data_V2``, evaluates the
        network once and writes the fixed image/label and the warped moving
        image/label of the first batch element into ``output_dir``, using the
        fixed image file as the reference passed to the writers.

        Args:
            genSample: Sampler object providing ``get_batch_data_V2``.
            img_fix: Path of the fixed image.
            lab_fix: Path of the fixed label.
            img_mv: Path of the moving image.
            lab_mv: Path of the moving label.
            output_dir: Directory the result files are written to.

        Returns:
            Tuple ``(ds, hd)``: the results of ``calculate_binary_dice`` and
            ``calculate_binary_hd`` between the warped moving label and the
            fixed label.
        """
        fix_imgs, fix_labs, mv_imgs, mv_labs = genSample.get_batch_data_V2(
            [img_mv], [img_fix], [lab_mv], [lab_fix])
        feed = self.create_feed_dict(fix_imgs, fix_labs, mv_imgs, mv_labs)

        # Fetch everything in a single forward pass so the displacement field
        # is the one that produced the warped outputs, and no second
        # evaluation of the graph is needed.
        (input_fix_label, warp_mv_label, input_fix_img, warp_mv_img,
         ddf_mv_fix) = self.sess.run(
             [
                 self.input_FIX_label,
                 self.warped_MV_label,
                 self.input_FIX_image,
                 self.warped_MV_image,
                 self.ddf_MV_FIX,
             ],
             feed_dict=feed)

        param = sitk.ReadImage(img_fix)
        fix_name = get_name_wo_suffix(img_fix)
        mv_name = get_name_wo_suffix(img_mv)

        sitk_write_image(input_fix_img[0, ...], param, output_dir, fix_name)
        sitk_write_lab(input_fix_label[0, ...], param, output_dir,
                       fix_name.replace('image', 'label'))

        sitk_write_image(warp_mv_img[0, ...], param, output_dir, mv_name)
        sitk_write_lab(warp_mv_label[0, ...], param, output_dir,
                       mv_name.replace('image', 'label'))

        # Only the third value returned by neg_jac is reported.
        _, _, neg2 = neg_jac(ddf_mv_fix[0, ...])
        self.logger.debug("neg_jac  %d" % (neg2))

        ds = calculate_binary_dice(warp_mv_label, input_fix_label)
        hd = calculate_binary_hd(warp_mv_label,
                                 input_fix_label,
                                 spacing=param.GetSpacing())
        return ds, hd
Beispiel #5
0
    def sample(self, iter):
        """Evaluate one validation batch and dump all volumes to disk.

        Draws a batch from ``self.validate_sampler``, builds a feed without
        augmentation and writes the network inputs, the warped moving
        image/label and the raw fed images into ``self.args.sample_dir``.
        File names are prefixed with ``iter``. All label volumes are
        binarized at 0.5 before they are written or compared.

        Returns:
            Tuple ``(dice, hd)``. ``dice`` is ``calculate_binary_dice`` of the
            binarized warped moving label against the binarized fed fixed
            label. ``hd`` is the constant placeholder 999; no Hausdorff
            distance is computed here.
        """
        fix_imgs, fix_labs, mv_imgs, mv_labs = \
            self.validate_sampler.next_sample()
        feed = self.create_feed_dict(fix_imgs,
                                     fix_labs,
                                     mv_imgs,
                                     mv_labs,
                                     is_aug=False)

        # A single forward pass yields every tensor; the feed is identical
        # for all of them, so there is no reason to evaluate the graph once
        # per output.
        (input_mv_label, input_fix_label, warp_mv_labs, fix_labs,
         input_fix_imgs, input_mv_imgs, warped_mv_imgs, fix_imgs,
         mv_imgs) = self.sess.run(
             [
                 self.input_MV_label,
                 self.input_FIX_label,
                 self.warped_MV_label,
                 self.ph_FIX_label,
                 self.input_FIX_image,
                 self.input_MV_image,
                 self.warped_MV_image,
                 self.ph_FIX_image,
                 self.ph_MV_image,
             ],
             feed_dict=feed)

        # np.where returns a new array, so the result has to be assigned;
        # otherwise the labels would be written without being thresholded.
        input_mv_label = np.where(input_mv_label > 0.5, 1, 0)
        input_fix_label = np.where(input_fix_label > 0.5, 1, 0)
        warp_mv_labs = np.where(warp_mv_labs > 0.5, 1, 0)
        fix_labs = np.where(fix_labs > 0.5, 1, 0)

        sample_dir = self.args.sample_dir
        sitk_write_images(input_fix_label.astype(np.uint16), None, sample_dir,
                          '%d_input_fix_lab' % (iter))
        sitk_write_images(input_mv_label.astype(np.uint16), None, sample_dir,
                          '%d_input_mv_lab' % (iter))
        sitk_write_images(warp_mv_labs.astype(np.uint16), None, sample_dir,
                          '%d_warp_mv_lab' % (iter))

        sitk_write_images(input_fix_imgs, None, sample_dir,
                          '%d_input_fix_img' % (iter))
        sitk_write_images(input_mv_imgs, None, sample_dir,
                          '%d_input_mv_img' % (iter))
        sitk_write_images(warped_mv_imgs, None, sample_dir,
                          '%d_warp_mv_img' % (iter))

        sitk_write_images(mv_imgs, None, sample_dir, '%d_mv_img' % (iter))
        sitk_write_images(fix_imgs, None, sample_dir, '%d_fix_img' % (iter))

        sim_warp_mv_fix = calculate_binary_dice(warp_mv_labs, fix_labs)
        # Placeholder value: the Hausdorff distance is not evaluated here.
        hd_warp_mv_fix = 999
        return sim_warp_mv_fix, hd_warp_mv_fix
Beispiel #6
0
    def __sample(self, itr):
        """Run one validation batch, dump its volumes and return the dice.

        Draws a batch from ``self.validate_sampler``, evaluates the predicted
        label and predicted similarity, and writes the inputs and predictions
        into ``self.args.sample_dir`` with ``itr`` as file-name prefix. The
        dice of the atlas label against the target label ("pre_acc") and of
        the prediction against the target label ("acc") are logged.

        Returns:
            ``calculate_binary_dice(target_lab, pred_lab, 0.2)``.
        """
        # The atlas image (third element) is not fed to the graph.
        target_imgs, target_labs, _, atlas_labs, sims = \
            self.validate_sampler.next_sample()

        feed = {
            self.ph_target_image: target_imgs,
            self.ph_target_label: target_labs,
            self.ph_atlas_label: atlas_labs,
            self.ph_gt_dicesim: sims,
        }
        fetches = [
            self.ph_target_image,
            self.ph_atlas_label,
            self.ph_gt_dicesim,
            self.ph_target_label,
            self.predict_label,
            self.predict_sim,
        ]
        target_img, atlas_label, gt, target_lab, pred_lab, pred_sim = \
            self.sess.run(fetches, feed_dict=feed)

        # (writer, volume, file-name suffix), in the order they are written.
        outputs = (
            (sitk_write_images, target_img, "target_img"),
            (sitk_write_labs, atlas_label, "atlas_label"),
            (sitk_write_images, gt, "gt_sim"),
            (sitk_write_images, pred_sim, "pred_sim"),
            (sitk_write_labs, target_lab, "target_label"),
            (sitk_write_labs, pred_lab, "pred_label"),
        )
        for write, volume, suffix in outputs:
            write(volume, dir=self.args.sample_dir, name=str(itr) + suffix)

        pre_acc = calculate_binary_dice(target_lab, atlas_label)
        acc = calculate_binary_dice(target_lab, pred_lab, 0.2)
        self.logger.debug("pre_acc %f -> acc %f" % (pre_acc, acc))
        return acc
Beispiel #7
0
    def sample(self, iter):
        """Evaluate one validation batch, dump the results and return the dice.

        Draws a batch from ``self.validate_sampler``, feeds it together with
        the transforms and displacement field produced by the ``util``
        generators, and writes the input labels and the warped moving
        label/image into ``self.args.sample_dir`` (file names are prefixed
        with ``iter``).

        Returns:
            ``calculate_binary_dice`` of the warped moving label against the
            sampled fixed label.
        """
        fix_imgs, fix_labs, mv_imgs, mv_labs = \
            self.validate_sampler.next_sample()
        feed = {
            self.ph_MV_image: mv_imgs,
            self.ph_FIX_image: fix_imgs,
            self.ph_MV_label: mv_labs,
            self.ph_FIX_label: fix_labs,
            self.ph_moving_affine:
                util.initial_transform_generator(self.args.batch_size),
            self.ph_fixed_affine:
                util.initial_transform_generator(self.args.batch_size),
            self.ph_random_ddf:
                util.init_ddf_generator(self.args.batch_size,
                                        self.image_size),
        }

        # The feed is identical for every output, so one forward pass is
        # enough; separate runs would only repeat the same computation.
        (input_mv_label, input_fix_label, warp_mv_labs, warped_mv_imgs,
         ddf_mv_fix) = self.sess.run(
             [
                 self.input_MV_label,
                 self.input_FIX_label,
                 self.warped_MV_label,
                 self.warped_MV_image,
                 self.ddf_MV_FIX,
             ],
             feed_dict=feed)

        sample_dir = self.args.sample_dir
        sitk_write_labs(input_fix_label, None, sample_dir,
                        '%d_input_fix_lab' % (iter))
        sitk_write_labs(input_mv_label, None, sample_dir,
                        '%d_input_mv_lab' % (iter))
        sitk_write_labs(warp_mv_labs, None, sample_dir,
                        '%d_warp_mv_lab' % (iter))
        sitk_write_images(warped_mv_imgs, None, sample_dir,
                          '%d_warp_mv_img' % (iter))

        sim_warp_mv_fix = calculate_binary_dice(warp_mv_labs, fix_labs)
        # Only the third value returned by neg_jac is reported, and only for
        # the first element of the batch.
        _, _, neg2 = neg_jac(ddf_mv_fix[0, ...])
        self.logger.debug("global_step %d: dice :  warp_mv_fix=%f" %
                          (iter, sim_warp_mv_fix))
        self.logger.debug("neg_jac %d" % (neg2))
        return sim_warp_mv_fix
Beispiel #8
0
    def gen_one_batch(self, genSample, img_fix, lab_fix, img_mv, lab_mv,
                      output_dir):
        """Register one moving/fixed pair, write the results and score them.

        Loads the pair through ``genSample.get_batch_data_V2``, evaluates the
        network once and writes, for the first batch element, the fixed
        image/label, the warped moving image/label, the input moving image
        and both restored images into ``output_dir``. The writers receive
        ``None`` as reference image. ``self.theta_fw`` is printed.

        Args:
            genSample: Sampler object providing ``get_batch_data_V2``.
            img_fix: Path of the fixed image.
            lab_fix: Path of the fixed label.
            img_mv: Path of the moving image.
            lab_mv: Path of the moving label.
            output_dir: Directory the result files are written to.

        Returns:
            Tuple ``(bf_ds, ds)``: ``calculate_binary_dice`` of the input
            moving label and of the warped moving label, each against the
            input fixed label.
        """
        fix_imgs, fix_labs, mv_imgs, mv_labs = genSample.get_batch_data_V2(
            [img_mv], [img_fix], [lab_mv], [lab_fix])
        feed = self.create_feed_dict(fix_imgs, fix_labs, mv_imgs, mv_labs)

        # Everything is fetched in one forward pass. The fetch list and the
        # unpacking targets are kept in the same order on purpose: ``fw`` must
        # receive ``self.theta_fw`` (not ``self.theta_bw``).
        (input_mv_label, input_fix_label, warp_mv_label, input_mv_img,
         input_fix_img, warp_mv_img, fw, restore_mv,
         restore_fix) = self.sess.run(
             [
                 self.input_MV_label,
                 self.input_FIX_label,
                 self.warped_MV_label,
                 self.input_MV_image,
                 self.input_FIX_image,
                 self.warped_MV_image,
                 self.theta_fw,
                 self.restore_MV_image,
                 self.restore_FIX_image,
             ],
             feed_dict=feed)

        # No reference image is passed to the writers.
        param = None
        sitk_write_image(
            input_fix_img[0, ...], param, output_dir,
            get_name_wo_suffix(img_fix).replace('image', 'target_image'))
        sitk_write_lab(
            input_fix_label[0, ...], param, output_dir,
            get_name_wo_suffix(lab_fix).replace('label', 'target_label'))

        sitk_write_image(
            warp_mv_img[0, ...], param, output_dir,
            get_name_wo_suffix(img_mv).replace('image', 'atlas_image'))
        sitk_write_lab(
            warp_mv_label[0, ...], param, output_dir,
            get_name_wo_suffix(lab_mv).replace('label', 'atlas_label'))

        sitk_write_image(input_mv_img[0, ...], param, output_dir,
                         get_name_wo_suffix(img_mv))
        sitk_write_image(
            restore_mv[0, ...], param, output_dir,
            get_name_wo_suffix(img_mv).replace('image', 'restore'))
        sitk_write_image(
            restore_fix[0, ...], param, output_dir,
            get_name_wo_suffix(img_fix).replace('image', 'restore'))

        print(fw)

        ds = calculate_binary_dice(warp_mv_label, input_fix_label)
        bf_ds = calculate_binary_dice(input_mv_label, input_fix_label)
        return bf_ds, ds
Beispiel #9
0
    def fusion_one_target(self, itr):
        """Fuse the atlas labels of one target, weighted by predicted similarity.

        Draws one target with its atlases from
        ``self.validate_sampler.next_sample_4_fusion``, predicts a similarity
        map for every atlas (atlases are indexed along the second-to-last
        axis of the atlas label batch) and, per position, picks the label
        value whose similarity-weighted vote summed over the atlases is
        largest. The target, every atlas label and every similarity map are
        written to ``self.args.validate_dir``; the fused label is written
        next to the fixed label file.

        Args:
            itr: Not used by this method.

        Returns:
            Tuple ``(ds, hd)``: ``calculate_binary_dice`` and
            ``calculate_binary_hd`` of the fused label against the target
            label batch.
        """
        (target_img_batch, target_lab_batch, atlas_img_batch, atlas_lab_batch,
         sim_batch, p_fix_img,
         p_fix_lab) = self.validate_sampler.next_sample_4_fusion()

        # The fixed image file serves as reference for every written volume.
        param = sitk.ReadImage(p_fix_img[0])
        validate_dir = self.args.validate_dir
        sitk_write_lab(np.squeeze(target_lab_batch.astype(np.uint8)),
                       param,
                       dir=validate_dir,
                       name=get_name_wo_suffix(p_fix_lab[0]))
        sitk_write_image(np.squeeze(target_img_batch),
                         param,
                         dir=validate_dir,
                         name=get_name_wo_suffix(p_fix_img[0]))

        fetches = [
            self.ph_target_image, self.ph_atlas_label, self.ph_gt_dicesim,
            self.ph_target_label, self.predict_label, self.predict_sim
        ]
        sim_maps = []
        atlas_count = atlas_lab_batch.shape[-2]
        for atlas_idx in range(atlas_count):
            atlas_lab = atlas_lab_batch[..., atlas_idx, :]
            feed = {
                self.ph_target_image: target_img_batch,
                self.ph_target_label: target_lab_batch,
                self.ph_atlas_label: atlas_lab,
                self.ph_gt_dicesim: sim_batch[..., atlas_idx, :],
            }
            # Only the echoed target label and the predicted similarity are
            # used below; the other fetched values are discarded.
            _, _, _, target_lab, _, pred_sim = self.sess.run(fetches,
                                                             feed_dict=feed)
            sim_maps.append(pred_sim)

            tag = str(atlas_idx)
            sitk_write_lab(np.squeeze(atlas_lab),
                           param,
                           dir=validate_dir,
                           name=get_name_wo_suffix(p_fix_lab[0].replace(
                               'label', 'label_' + tag)))
            sitk_write_image(np.squeeze(pred_sim),
                             param,
                             dir=validate_dir,
                             name=get_name_wo_suffix(p_fix_lab[0]).replace(
                                 'label', 'sim_' + tag))

        # One similarity map per atlas, stacked along a new last axis.
        weights = np.squeeze(np.stack(sim_maps, -1))
        atlas_labels = np.squeeze(atlas_lab_batch)

        label_values = np.unique(target_lab.astype(np.uint8))
        votes = np.zeros((len(label_values), ) + np.squeeze(target_lab).shape)
        for row, value in enumerate(label_values):
            agrees = (atlas_labels == value).astype(np.int16)
            votes[row] = np.sum(agrees * weights, axis=-1)
        fusion_label = label_values[np.argmax(votes, axis=0)]

        ds = calculate_binary_dice(fusion_label, target_lab_batch)
        hd = calculate_binary_hd(fusion_label,
                                 target_lab_batch,
                                 spacing=param.GetSpacing())

        sitk_write_lab(np.squeeze(fusion_label).astype(np.uint8),
                       param,
                       dir=os.path.dirname(p_fix_lab[0]),
                       name=get_name_wo_suffix(p_fix_lab[0]).replace(
                           'label', 'net_fusion_label'))
        return ds, hd
Beispiel #10
0
 def gen_warp_atlas(self, dir, genSample, i, img_fix, lab_fix, is_aug=True):
     """Warp every moving atlas of ``genSample`` onto one fixed target.

     For each ``(img_mv, lab_mv)`` pair in ``genSample`` the network is
     evaluated once; the warped image, the binarized warped label, the dice
     against the binarized fixed label and the conditional entropy of the
     warped label over the fixed image are collected. All warped atlases
     and the fixed image/label of the first iteration are written to
     ``<gen_dir>/<dir>/target_<i>_<fixed image name>``, which is created or
     cleared first.

     Args:
         dir: Sub-directory name below ``self.args.gen_dir``.
         genSample: Sampler providing ``img_mv``, ``lab_mv`` and
             ``get_batch_data_V2``.
         i: Index of the target, used in the output directory name.
         img_fix: Path of the fixed image.
         lab_fix: Path of the fixed label.
         is_aug: Forwarded to ``self.create_feed_dict``.
     """
     output_dir = "{}/{}/target_{}_{}".format(self.args.gen_dir, dir, i,
                                              get_name_wo_suffix(img_fix))
     mk_or_cleardir(output_dir)

     # The reference image is the same for every atlas, so it is read from
     # disk once instead of once per loop iteration.
     param = sitk.ReadImage(img_fix)

     input_fix_imgs = []
     input_fix_labels = []
     warp_mv_imgs = []
     warp_mv_labels = []
     losses = []
     sims = []
     fetches = [
         self.input_FIX_label,
         self.warped_MV_label,
         self.input_FIX_image,
         self.warped_MV_image,
     ]
     for img_mv, lab_mv in zip(genSample.img_mv, genSample.lab_mv):
         fix_imgs, fix_labs, mv_imgs, mv_labs = genSample.get_batch_data_V2(
             [img_mv], [img_fix], [lab_mv], [lab_fix])
         feed = self.create_feed_dict(fix_imgs, fix_labs, mv_imgs, mv_labs,
                                      is_aug)
         input_fix_label, warp_mv_label, input_fix_img, warp_mv_img = \
             self.sess.run(fetches, feed_dict=feed)

         # Binarize both labels at 0.5 before scoring and storing them.
         input_fix_label = np.where(input_fix_label > 0.5, 1, 0)
         warp_mv_label = np.where(warp_mv_label > 0.5, 1, 0)

         # Only the first element of each batch is kept.
         sims.append(
             calculate_binary_dice(input_fix_label[0, ...],
                                   warp_mv_label[0, ...]))
         input_fix_imgs.append(input_fix_img[0, ...])
         input_fix_labels.append(input_fix_label[0, ...])
         warp_mv_imgs.append(warp_mv_img[0, ...])
         warp_mv_labels.append(warp_mv_label[0, ...])
         losses.append(
             conditional_entropy_label_over_image(
                 np.squeeze(input_fix_img[0, ...]),
                 np.squeeze(warp_mv_label[0, ...])))

     # Ascending order: the atlases with the lowest entropy come first.
     indexs = np.argsort(losses)
     for ind in indexs:
         sitk_write_image(warp_mv_imgs[ind],
                          param,
                          dir=output_dir,
                          name=str(ind) + "_mv_img")
         sitk_write_lab(warp_mv_labels[ind],
                        param,
                        dir=output_dir,
                        name=str(ind) + "_mv_lab")
     sitk_write_image(input_fix_imgs[0],
                      param,
                      dir=output_dir,
                      name=str(0) + "_fix_img")
     sitk_write_lab(input_fix_labels[0],
                    param,
                    dir=output_dir,
                    name=str(0) + "_fix_lab")
     self.logger.debug(sims)
     # Mean dice over all atlases versus the five lowest-entropy atlases.
     self.logger.debug("%s %f -> %f" %
                       (output_dir, np.mean(sims),
                        np.mean([sims[ind] for ind in indexs[:5]])))