def __sample(self, iter):
    """Run one validation batch, write its volumes to disk and score it.

    A batch is drawn from ``self.validate_sampler``. The target image is fed
    to ``ph_target_image``, the atlas label to ``ph_warp_label`` and the
    target label to ``ph_gt_label``. The session then evaluates those
    placeholders together with ``self.predit``.

    The following are written to ``self.args.sample_dir``, each name prefixed
    with ``str(iter)``:
    - the target image;
    - the warped label;
    - the ground-truth label (argmax over the last axis);
    - the prediction.

    Returns:
        The value of ``calculate_binary_dice`` between the argmax'd ground
        truth and the prediction. It is also logged at DEBUG level.
    """
    (target_img_batch, target_lab_batch,
     _atlas_img_batch, atlas_lab_batch) = self.validate_sampler.next_sample()

    feed = {
        self.ph_target_image: target_img_batch,
        self.ph_warp_label: atlas_lab_batch,
        self.ph_gt_label: target_lab_batch,
    }
    fetches = [self.ph_target_image, self.ph_warp_label, self.ph_gt_label, self.predit]
    image, warped, ground_truth, prediction = self.sess.run(fetches, feed_dict=feed)

    out_dir = self.args.sample_dir
    prefix = str(iter)
    sitk_write_images(image, dir=out_dir, name=prefix + "target_img")
    sitk_write_images(warped, dir=out_dir, name=prefix + "warp_label")
    # The ground truth is collapsed to per-voxel class indices via argmax over the last axis.
    sitk_write_labs(np.argmax(ground_truth, axis=-1), dir=out_dir, name=prefix + "gt")
    sitk_write_labs(prediction, dir=out_dir, name=prefix + "pred")

    dice = calculate_binary_dice(np.argmax(ground_truth, axis=-1), prediction)
    self.logger.debug("acc:" + str(dice))
    return dice
def sample(self, iter):
    """Evaluate registration on one validation batch and write the results.

    A (fixed, moving) image/label batch is drawn from
    ``self.validate_sampler``. It is fed together with:
    - affine inputs from ``util.initial_transform_generator``;
    - a DDF input from ``util.init_ddf_generator``.

    The network is evaluated in a single session run. The following are
    written to ``self.args.sample_dir`` with an ``<iter>_`` prefix:
    - the input fixed label;
    - the input moving label;
    - the warped moving label;
    - the warped moving image.

    Returns:
        The ``calculate_binary_dice`` score between the warped moving label
        and the fed fixed label. It is logged at DEBUG level, together with
        the third value returned by ``neg_jac`` for the first DDF in the
        batch.
    """
    fix_imgs, fix_labs, mv_imgs, mv_labs = self.validate_sampler.next_sample()
    batch_size = self.args.batch_size
    train_feed = {
        self.ph_MV_image: mv_imgs,
        self.ph_FIX_image: fix_imgs,
        self.ph_MV_label: mv_labs,
        self.ph_FIX_label: fix_labs,
        self.ph_moving_affine: util.initial_transform_generator(batch_size),
        self.ph_fixed_affine: util.initial_transform_generator(batch_size),
        self.ph_random_ddf: util.init_ddf_generator(batch_size, self.image_size),
    }

    # Fetch every tensor in one session run. The graph is then executed once
    # instead of once per tensor, and all outputs come from the same forward
    # pass even if the graph holds stateful or random ops.
    (input_mv_label, input_fix_label, warp_mv_labs,
     warped_mv_imgs, ddf_mv_fix) = self.sess.run(
        [self.input_MV_label, self.input_FIX_label, self.warped_MV_label,
         self.warped_MV_image, self.ddf_MV_FIX],
        feed_dict=train_feed)

    out_dir = self.args.sample_dir
    sitk_write_labs(input_fix_label, None, out_dir, '%d_input_fix_lab' % (iter))
    sitk_write_labs(input_mv_label, None, out_dir, '%d_input_mv_lab' % (iter))
    sitk_write_labs(warp_mv_labs, None, out_dir, '%d_warp_mv_lab' % (iter))
    sitk_write_images(warped_mv_imgs, None, out_dir, '%d_warp_mv_img' % (iter))

    sim_warp_mv_fix = calculate_binary_dice(warp_mv_labs, fix_labs)
    # Only the first DDF of the batch is checked; neg_jac returns three
    # values and the third one is the count that gets logged.
    _, _, neg2 = neg_jac(ddf_mv_fix[0, ...])
    self.logger.debug("global_step %d: dice : warp_mv_fix=%f" % (iter, sim_warp_mv_fix))
    self.logger.debug("neg_jac %d" % (neg2))
    return sim_warp_mv_fix
def __sample(self, itr):
    """Write one validation batch of the similarity model to disk.

    A five-part batch (target image/label, atlas image/label, similarity
    map) is drawn from ``self.validate_sampler`` and fed to the matching
    placeholders. The session evaluates the image and similarity
    placeholders together with ``self.prob``.

    Everything is written to ``self.args.sample_dir``, prefixed with
    ``str(itr)``:
    - the raw target and atlas labels, via ``sitk_write_labs``;
    - the target image, atlas image, ground-truth similarity and predicted
      similarity, via ``sitk_write_images``.

    Returns nothing.
    """
    (target_img_batch, target_lab_batch, atlas_img_batch,
     atlas_lab_batch, sim_batch) = self.validate_sampler.next_sample()

    feed = {
        self.ph_target_image: target_img_batch,
        self.ph_target_label: target_lab_batch,
        self.ph_atlas_image: atlas_img_batch,
        self.ph_atlas_label: atlas_lab_batch,
        self.ph_sim: sim_batch,
    }
    fetched = self.sess.run(
        [self.ph_target_image, self.ph_atlas_image, self.ph_sim, self.prob],
        feed_dict=feed)
    target_img, atlas_img, gt_sim, pred = fetched

    out_dir = self.args.sample_dir
    tag = str(itr)
    # The labels are written straight from the sampler output, not from the session fetch.
    label_outputs = (
        (target_lab_batch, "target_lab"),
        (atlas_lab_batch, "atlas_lab"),
    )
    for volume, suffix in label_outputs:
        sitk_write_labs(volume, dir=out_dir, name=tag + suffix)

    image_outputs = (
        (target_img, "target_img"),
        (atlas_img, "atlas_img"),
        (gt_sim, "gt_sim"),
        (pred, "pred_sim"),
    )
    for volume, suffix in image_outputs:
        sitk_write_images(volume, dir=out_dir, name=tag + suffix)
def __sample(self, itr):
    """Evaluate the label/similarity model on one validation batch.

    A five-part batch is drawn from ``self.validate_sampler``; the atlas
    image is not fed. The session evaluates the fed placeholders plus
    ``self.predict_label`` and ``self.predict_sim``. Every fetched volume is
    written to ``self.args.sample_dir``, prefixed with ``str(itr)``.

    Returns:
        ``calculate_binary_dice(target_lab, pred_lab, 0.2)``. The baseline
        score ``calculate_binary_dice(target_lab, atlas_label)`` is logged
        alongside it at DEBUG level.
    """
    (target_img_batch, target_lab_batch, _atlas_img_batch,
     atlas_lab_batch, sim_batch) = self.validate_sampler.next_sample()

    feed = {
        self.ph_target_image: target_img_batch,
        self.ph_target_label: target_lab_batch,
        self.ph_atlas_label: atlas_lab_batch,
        self.ph_gt_dicesim: sim_batch,
    }
    fetch_list = [
        self.ph_target_image,
        self.ph_atlas_label,
        self.ph_gt_dicesim,
        self.ph_target_label,
        self.predict_label,
        self.predict_sim,
    ]
    (target_img, atlas_label, gt, target_lab,
     pred_lab, pred_sim) = self.sess.run(fetch_list, feed_dict=feed)

    out_dir = self.args.sample_dir
    tag = str(itr)
    # (writer, volume, filename suffix), emitted in this exact order.
    outputs = (
        (sitk_write_images, target_img, "target_img"),
        (sitk_write_labs, atlas_label, "atlas_label"),
        (sitk_write_images, gt, "gt_sim"),
        (sitk_write_images, pred_sim, "pred_sim"),
        (sitk_write_labs, target_lab, "target_label"),
        (sitk_write_labs, pred_lab, "pred_label"),
    )
    for writer, volume, suffix in outputs:
        writer(volume, dir=out_dir, name=tag + suffix)

    # Baseline: overlap of the unregistered atlas label with the target.
    baseline_dice = calculate_binary_dice(target_lab, atlas_label)
    # NOTE(review): 0.2 is presumably a binarization threshold for pred_lab.
    # Confirm against calculate_binary_dice.
    dice = calculate_binary_dice(target_lab, pred_lab, 0.2)
    self.logger.debug("pre_acc %f -> acc %f" % (baseline_dice, dice))
    return dice
def sample(self, iter):
    """Evaluate registration on one validation batch and write the results.

    A (fixed, moving) batch is drawn from ``self.validate_sampler`` and a
    non-augmented feed is built with ``self.create_feed_dict``. All required
    tensors are fetched in a single session run.

    The following label maps are thresholded at 0.5 to {0, 1}:
    - the input fixed and moving labels;
    - the warped moving label;
    - the fed fixed label.

    The thresholded label maps are written as uint16, and the
    input/warped/raw images are written to ``self.args.sample_dir`` with an
    ``<iter>_`` prefix.

    Returns:
        ``(dice, hd)``. ``dice`` is ``calculate_binary_dice`` between the
        thresholded warped moving label and the thresholded fixed label.
        ``hd`` is the constant placeholder 999 because the Hausdorff
        distance computation is disabled.
    """
    fix_imgs, fix_labs, mv_imgs, mv_labs = self.validate_sampler.next_sample()
    train_feed = self.create_feed_dict(fix_imgs, fix_labs, mv_imgs, mv_labs, is_aug=False)

    # A single session run executes the graph once, so every output below
    # comes from the same forward pass.
    (input_mv_label, input_fix_label, warp_mv_labs, fix_labs,
     input_fix_imgs, input_mv_imgs, warped_mv_imgs,
     fix_imgs, mv_imgs) = self.sess.run(
        [self.input_MV_label, self.input_FIX_label, self.warped_MV_label,
         self.ph_FIX_label,
         self.input_FIX_image, self.input_MV_image, self.warped_MV_image,
         self.ph_FIX_image, self.ph_MV_image],
        feed_dict=train_feed)

    def _binarize(volume):
        # Hard-threshold a label map at 0.5 -> {0, 1}.
        return np.where(volume > 0.5, 1, 0)

    # The thresholded results must be assigned. Otherwise astype(np.uint16)
    # truncates the raw floats and any value below 1.0 is written as 0.
    input_mv_label = _binarize(input_mv_label)
    input_fix_label = _binarize(input_fix_label)
    warp_mv_labs = _binarize(warp_mv_labs)
    fix_labs = _binarize(fix_labs)

    out_dir = self.args.sample_dir
    sitk_write_images(input_fix_label.astype(np.uint16), None, out_dir, '%d_input_fix_lab' % (iter))
    sitk_write_images(input_mv_label.astype(np.uint16), None, out_dir, '%d_input_mv_lab' % (iter))
    sitk_write_images(warp_mv_labs.astype(np.uint16), None, out_dir, '%d_warp_mv_lab' % (iter))
    sitk_write_images(input_fix_imgs, None, out_dir, '%d_input_fix_img' % (iter))
    sitk_write_images(input_mv_imgs, None, out_dir, '%d_input_mv_img' % (iter))
    sitk_write_images(warped_mv_imgs, None, out_dir, '%d_warp_mv_img' % (iter))
    sitk_write_images(mv_imgs, None, out_dir, '%d_mv_img' % (iter))
    sitk_write_images(fix_imgs, None, out_dir, '%d_fix_img' % (iter))

    sim_warp_mv_fix = calculate_binary_dice(warp_mv_labs, fix_labs)
    # TODO: Hausdorff distance is disabled. 999 is a placeholder kept so
    # callers still receive a (dice, hd) pair.
    hd_warp_mv_fix = 999
    return sim_warp_mv_fix, hd_warp_mv_fix