def load_file(self, data_dir, file_name, num_channels=3):
    """Read ``data_dir/file_name`` into ``self.image_arr``.

    The first attempt uses ``num_channels``. If it fails for any reason, one
    retry is made with the alternate channel count (3 when ``num_channels``
    is 1, otherwise 1). If the retry also fails, the file name and the
    retry's error are printed and nothing is raised; ``self.image_arr`` is
    then left as it was before the call.

    Args:
        data_dir: directory containing the image.
        file_name: image file name inside ``data_dir``; also stored on self.
        num_channels: channel count requested on the first attempt.
    """
    def read_with(channel_count):
        # Single expression keeps the original evaluation order of the loader call.
        return imgutil.get_image_as_array(os.path.join(self.data_dir, self.file_name),
                                          channels=channel_count)

    try:
        self.data_dir = data_dir
        self.file_name = file_name
        self.image_arr = read_with(num_channels)
    except Exception:
        alternate_channels = 3 if num_channels == 1 else 1
        try:
            self.image_arr = read_with(alternate_channels)
        except Exception as retry_error:
            print('Error Loading file: ' + self.file_name)
            print(str(retry_error))
def load_mask(self, mask_dir=None, fget_mask=None, erode=False, channels=1):
    """Load the mask belonging to ``self.file_name`` into ``self.mask``.

    The mask file name is ``fget_mask(self.file_name)``, read from
    ``mask_dir`` with ``channels`` channels. When ``erode`` is true the mask
    is then eroded with ``fu.get_chosen_mask_erode_kernel()`` for 5
    iterations. Any failure is printed (never raised). If only the erosion
    fails, ``self.mask`` keeps the un-eroded mask.

    Args:
        mask_dir: directory containing mask files.
        fget_mask: callable mapping an image file name to its mask file name.
        erode: whether to erode the loaded mask.
        channels: channel count passed to the image loader.
    """
    try:
        mask_name = fget_mask(self.file_name)
        read_image = imgutil.get_image_as_array
        self.mask = read_image(os.path.join(mask_dir, mask_name), channels)
        if erode:
            self.mask = cv2.erode(self.mask,
                                  kernel=fu.get_chosen_mask_erode_kernel(),
                                  iterations=5)
    except Exception as load_error:
        message = 'Fail to load mask: ' + str(load_error)
        print(message)
def _get_image_obj(self, img_file=None):
    """Load ``img_file`` into an ``Image`` and attach U-Net-derived data to ``extra``.

    The image is loaded from ``self.image_dir``. A mask and a ground truth are
    also loaded when ``self.mask_getter`` / ``self.truth_getter`` are set.
    ``working_arr`` is set to channel 1 of the image, then CLAHE and the mask
    are applied. The matching U-Net output (``self.unet_dir`` + stem +
    ``self.input_image_ext``) is read with one channel, and these entries are
    stored in ``img_obj.extra``:

    - ``'unet'``: the raw U-Net output array.
    - ``'indices'``: index tuples of every pixel whose U-Net value is in [sup, res].
    - ``'fill_in'``: zeros shaped/typed like ``working_arr``, 1 where U-Net > res.
    - ``'mid_pix'``: U-Net values kept inside [sup, res], 0 elsewhere.
    - ``'gt_mid'``: ground truth zeroed outside [sup, res]; only set when a
      ground truth is actually available on the image object.
    - ``'seed'``: skeleton of the thresholded, cleaned U-Net output, restricted
      to a regular grid of rows/columns and scaled to 0/255.

    Args:
        img_file: image file name inside ``self.image_dir``.

    Returns:
        The populated ``Image`` object.
    """
    img_obj = Image()
    img_obj.load_file(data_dir=self.image_dir, file_name=img_file)
    if self.mask_getter is not None:
        img_obj.load_mask(mask_dir=self.mask_dir, fget_mask=self.mask_getter)
    if self.truth_getter is not None:
        img_obj.load_ground_truth(gt_dir=self.truth_dir, fget_ground_truth=self.truth_getter)

    # Work on channel 1 of the loaded image, contrast-enhanced and masked.
    img_obj.working_arr = img_obj.image_arr[:, :, 1]
    img_obj.apply_clahe()
    img_obj.apply_mask()

    # U-Net values in [sup, res] form the "middle" band processed below;
    # values above `res` are marked directly in 'fill_in'.
    sup, res = 20, 235
    # NOTE(review): the U-Net file is matched on the text before the FIRST '.'
    # of the image name, so "a.b.png" looks up "a<ext>" — confirm intended.
    unet = iu.get_image_as_array(
        self.unet_dir + sep + img_obj.file_name.split('.')[0] + self.input_image_ext, 1)
    in_band = (unet >= sup) & (unet <= res)
    outside_band = (unet < sup) | (unet > res)

    extra = img_obj.extra
    extra['unet'] = unet
    extra['indices'] = list(zip(*np.where(in_band)))

    fill_in = np.zeros_like(img_obj.working_arr)
    fill_in[unet > res] = 1
    extra['fill_in'] = fill_in

    mid_pix = unet.copy()
    mid_pix[outside_band] = 0
    extra['mid_pix'] = mid_pix

    # Ground truth is only loaded when a truth getter is configured, and its
    # loader reports failures by printing, so it may be absent here. The old
    # code crashed with AttributeError in that case.
    ground_truth = getattr(img_obj, 'ground_truth', None)
    if ground_truth is not None:
        gt_mid = ground_truth.copy()
        gt_mid[outside_band] = 0
        extra['gt_mid'] = gt_mid

    # <PREP1> Segment with a low threshold: U-Net values above `sup` -> 255, rest -> 0.
    raw_estimate = unet.copy()
    raw_estimate[unet > sup] = 255
    raw_estimate[unet <= sup] = 0

    # <PREP2> Clear up small components (components less than 10px).
    raw_estimate = iu.remove_connected_comp(raw_estimate.squeeze(), 10)

    # <PREP3> Skeletonize; 255 is first mapped to 1 so foreground is 0/1.
    seed = raw_estimate.copy()
    seed[seed == 255] = 1
    seed = skeletonize(seed).astype(np.uint8)

    # <PREP4> Grid mask: ones on every `stride`-th row and every `stride`-th
    # column, selecting a few pixels from which to reconstruct the vessels.
    # max(1, ...) avoids a zero slice step (ValueError) for tiny patch shapes.
    # NOTE(review): columns also use patch_shape[0]; patch_shape[1] may be
    # intended for non-square patches — confirm.
    stride = max(1, int(0.6 * self.patch_shape[0]))
    grid_mask = np.zeros_like(seed)
    grid_mask[::stride] = 1
    grid_mask[:, ::stride] = 1

    # <PREP5> Keep only skeleton pixels on the grid, scaled to 0/255.
    extra['seed'] = seed * grid_mask * 255
    return img_obj
def run_all(self, Dirs=None, fget_mask=None, fget_gt=None, params_combination=None, save_images=False):
    """Segment every image in ``Dirs['images']`` for each parameter set.

    Images are processed in sorted order so the CSV rows are reproducible.
    For each image file, the method:

    - loads it and stores channel 1, after CLAHE, in ``res['orig']``;
    - replaces ``working_arr`` with the pre-segmented image
      ``Dirs['segmented']/<file_name>.png``;
    - loads the eroded mask and the ground truth, applies the mask, and builds
      the lattice graph;
    - paints every pixel outside the mask with 255;
    - for each entry of ``params_combination``, generates a skeleton and
      delegates to ``self._run``.

    ``<self.out_dir>/segmentation_result.csv`` is opened as ``self.writer``,
    and its header row is written here. Rows are presumably appended by
    ``self._run``. The file is closed when the method returns or raises.

    Args:
        Dirs: mapping with keys 'images', 'segmented', 'mask' and 'truth'.
        fget_mask: maps an image file name to its mask file name.
        fget_gt: maps an image file name to its ground-truth file name.
        params_combination: iterable of parameter dicts, each with at least
            'sk_threshold'. ``None`` (default) means no parameter sets.
        save_images: forwarded to ``self._run``.
    """
    # None instead of a shared mutable [] default.
    if params_combination is None:
        params_combination = []
    os.makedirs(self.out_dir, exist_ok=True)
    # Bound to self.writer so self._run can append rows; the `with` closes the
    # file even if processing an image raises (previously it leaked).
    with open(os.path.join(self.out_dir, "segmentation_result.csv"), 'w') as self.writer:
        self.writer.write('ITR,FILE_NAME,F1,PRECISION,RECALL,ACCURACY,'
                          'SK_THRESHOLD,'
                          'ALPHA,'
                          'ORIG_CONTRIB,'
                          'SEG_THRESHOLD\n')
        for file_name in sorted(os.listdir(Dirs['images'])):
            print('File: ' + file_name)
            img_obj = SegmentedImage()
            img_obj.load_file(data_dir=Dirs['images'], file_name=file_name)
            img_obj.working_arr = img_obj.image_arr[:, :, 1]
            img_obj.apply_clahe()
            img_obj.res['orig'] = img_obj.working_arr

            img_obj.working_arr = imgutils.get_image_as_array(
                os.path.join(Dirs['segmented'], file_name + '.png'), channels=1)
            img_obj.load_mask(mask_dir=Dirs['mask'], fget_mask=fget_mask, erode=True)
            img_obj.load_ground_truth(gt_dir=Dirs['truth'], fget_ground_truth=fget_gt)
            img_obj.apply_mask()
            img_obj.generate_lattice_graph()

            # Paint every pixel where the mask is 0 with 255. This is a
            # vectorized equivalent of the former per-pixel double loop over
            # arr's rows/columns. Assumes a 2-D mask at least as large as
            # working_arr — TODO confirm.
            arr = img_obj.working_arr.copy()
            arr[img_obj.mask[:arr.shape[0], :arr.shape[1]] == 0] = 255
            img_obj.working_arr = arr

            for params in params_combination:
                img_obj.generate_skeleton(threshold=int(params['sk_threshold']))
                self._run(img_obj=img_obj, params=params, save_images=save_images)
def load_ground_truth(self, gt_dir=None, fget_ground_truth=None, channels=1):
    """Load the ground truth belonging to ``self.file_name`` into ``self.ground_truth``.

    The ground-truth file name is ``fget_ground_truth(self.file_name)``, read
    from ``gt_dir`` with ``channels`` channels. Any failure is printed, not
    raised, and ``self.ground_truth`` is then left unchanged.

    Args:
        gt_dir: directory containing ground-truth files.
        fget_ground_truth: callable mapping an image file name to its
            ground-truth file name.
        channels: channel count passed to the image loader.
    """
    try:
        truth_name = fget_ground_truth(self.file_name)
        read_image = imgutil.get_image_as_array
        self.ground_truth = read_image(os.path.join(gt_dir, truth_name), channels)
    except Exception as load_error:
        message = 'Fail to load ground truth: ' + str(load_error)
        print(message)