def dump_class_masks(self, required_class_names: List[str], highlight_classname: str, condition, folder_prefix: str):
    """Save binary masks for every instance of one class, over all splits at once.

    Iterates over the 'train' and 'val' splits. An image is skipped unless
    every class in ``required_class_names`` appears in its label map. For the
    remaining images, each instance mask is assigned a class by majority vote
    (``get_instance_mask_class_votes``). Instances are skipped when they have
    fewer than ``MIN_REQ_PX`` label votes or when their class differs from
    ``highlight_classname``. Each surviving instance is written via
    ``save_binary_mask_double`` to
    ``temp_files/{folder_prefix}_{split}_2019_12_16/{fname_stem}_{instance_id}.png``.

    Args:
        required_class_names: class names that must all be present in an
            image's label map for that image to be processed.
        highlight_classname: class to highlight in pink; only instances whose
            majority-vote class equals this name are saved.
        condition: unused by this method; kept for interface compatibility.
        folder_prefix: prefix of the output directory name.

    Returns:
        None
    """
    for split in ['train', 'val']:
        ade20k_split_nickname = self.ade20k_split_nickname_dict[split]
        # glob returns files in arbitrary filesystem order; sort so the
        # processing order (and progress output) is reproducible.
        rgb_fpaths = sorted(
            glob.glob(f'{self.img_dir}/{ade20k_split_nickname}/*.jpg'))
        for i, rgb_fpath in enumerate(rgb_fpaths):
            print(f'On {i}/{len(rgb_fpaths)-1}')
            fname_stem = Path(rgb_fpath).stem
            # Internal consistency check between the glob result and the
            # precomputed stem -> path lookup.
            assert rgb_fpath == self.fname_to_rgbfpath_dict[fname_stem]

            rgb_img, label_img = self.get_img_pair(fname_stem)
            # A set gives O(1) membership tests for the required-class check.
            present_classnames = {
                self.id_to_classname_map[idx] for idx in np.unique(label_img)
            }
            if not all(req_name in present_classnames
                       for req_name in required_class_names):
                continue

            seg_rgb_fpath = self.fname_to_segrgbfpath_dict[fname_stem]
            instance_masks, instance_ids = get_ade20k_instance_label_masks(
                seg_rgb_fpath, rgb_img)

            for instance_mask, instance_id in zip(instance_masks, instance_ids):
                # Determine the instance's class by majority vote over the
                # label map pixels under the mask.
                label_votes, majority_vote = get_instance_mask_class_votes(
                    instance_mask, label_img)
                if label_votes.size < MIN_REQ_PX:
                    continue

                instance_classname = self.id_to_classname_map[majority_vote]
                if instance_classname != highlight_classname:
                    continue

                save_fname = f'{fname_stem}_{instance_id}.png'
                save_fpath = f'temp_files/{folder_prefix}_{split}_2019_12_16/{save_fname}'
                save_binary_mask_double(rgb_img, instance_mask, save_fpath,
                                        save_to_disk=True)
def get_segment_mask(self, seq_id: Optional[object], segmentid: int, fname_stem: str, split: str) -> Optional[np.ndarray]:
    """Return the instance mask with the given segment ID for one image.

    The instance masks are decoded from the image's segmentation RGB file
    (``get_ade20k_instance_label_masks``), and the one whose instance ID
    equals ``segmentid`` is selected. That mask is then sanity-checked: it
    must collect at least ``MIN_REQ_PX`` label votes from the label map.

    Args:
        seq_id: unused by this method; kept for interface compatibility.
        segmentid: integer representing the segment's unique ID.
        fname_stem: file name stem of the image.
        split: dataset split, i.e. 'train' or 'val' (unused by this method).

    Returns:
        segment_mask: the matching instance mask, or None if no instance
        with ``segmentid`` exists in the image.

    Raises:
        RuntimeError: if the matching segment has fewer than ``MIN_REQ_PX``
            label votes.
    """
    rgb_img, label_img = self.get_img_pair(fname_stem)
    seg_rgb_fpath = self.fname_to_segrgbfpath_dict[fname_stem]
    instance_masks, instance_ids = get_ade20k_instance_label_masks(
        seg_rgb_fpath, rgb_img)

    # First instance whose ID matches, or None if the ID is absent.
    segment_mask = next(
        (instance_mask
         for instance_mask, instance_id in zip(instance_masks, instance_ids)
         if instance_id == segmentid),
        None,
    )

    # This must run before the class-vote check below, which needs a real
    # mask; otherwise a missing ID would never reach the graceful
    # None-return path.
    if segment_mask is None:
        print('Specified segment ID does not exist.')
        return None

    # Test the instance's class. The deep copy guards the returned mask in
    # case the vote helper modifies its input (NOTE(review): confirm whether
    # it does).
    label_votes, _ = get_instance_mask_class_votes(
        copy.deepcopy(segment_mask), label_img)
    if label_votes.size < MIN_REQ_PX:
        # Raise instead of calling quit(): quit() is an interactive helper
        # injected by `site`, and it exits with status 0, which hides the error.
        raise RuntimeError(
            f'Segment {segmentid} in {fname_stem} has only '
            f'{label_votes.size} label votes (< MIN_REQ_PX={MIN_REQ_PX}).')

    return segment_mask
def test_get_instance_mask_class_votes1():
    """
    Given an instance mask covering the entire image, ensure we get the most
    likely category vote.

    Every pixel is labeled 0, so the majority vote must be 0, with exactly
    one vote per pixel (4 votes).
    """
    instance_mask = np.array([[1, 1], [1, 1]])
    label_img = np.array([[0, 0], [0, 0]])

    label_votes, majority_vote = get_instance_mask_class_votes(
        instance_mask, label_img)

    gt_majority_vote = 0
    gt_label_votes = np.array([0, 0, 0, 0])
    assert gt_majority_vote == majority_vote
    # np.array_equal also checks shape. np.allclose broadcasts, so a scalar 0
    # or an all-zero vote array of the wrong length would wrongly pass.
    assert np.array_equal(gt_label_votes, label_votes)
def test_get_instance_mask_class_votes2():
    """
    Given an instance mask covering part of the image, ensure we get the most
    likely category vote. Here, our instance mask is noisy.

    The masked pixels, read in row-major order, fall on labels
    8, 9, 9, 9, 9, 7. Class 9 therefore wins with 4 of the 6 votes, despite
    the stray 8 and 7 pixels.
    """
    instance_mask = np.array([
        [0, 0, 1, 0],
        [0, 1, 1, 0],
        [0, 1, 1, 0],
        [0, 1, 0, 0],
    ])
    label_img = np.array([
        [8, 8, 8, 8],
        [9, 9, 9, 9],
        [8, 9, 9, 9],
        [7, 7, 7, 7],
    ])

    label_votes, majority_vote = get_instance_mask_class_votes(
        instance_mask, label_img)

    gt_majority_vote = 9
    gt_label_votes = np.array([8, 9, 9, 9, 9, 7])
    assert gt_majority_vote == majority_vote
    # np.array_equal also checks shape; np.allclose would broadcast and could
    # accept a vote array of the wrong length.
    assert np.array_equal(gt_label_votes, label_votes)