def dump_class_masks(self, required_class_names: List[str],
                         highlight_classname: str, condition,
                         folder_prefix: str):
        """Save binary masks of every `highlight_classname` instance, for all splits at once.

        Iterates over the 'train' and 'val' splits. An image is processed only if
        every class in `required_class_names` appears somewhere in its label map.
        For each processed image, every instance whose majority-vote class equals
        `highlight_classname` and that has at least MIN_REQ_PX voting pixels is
        written to disk via `save_binary_mask_double`.

        Args:
        -   required_class_names: class names that must all be present in an
                image's label map for that image to be considered.
        -   highlight_classname: class whose instances are exported (original
                note: highlighted in pink).
        -   condition: not used by this method.
        -   folder_prefix: prefix of the output directory; masks are written to
                `temp_files/{folder_prefix}_{split}_2019_12_16/{stem}_{instance_id}.png`.

        Returns:
        -   None
        """
        # Set lookup instead of repeated list scans; subset test is equivalent
        # to "every required name is present".
        required = set(required_class_names)
        for split in ['train', 'val']:
            ade20k_split_nickname = self.ade20k_split_nickname_dict[split]
            rgb_fpaths = glob.glob(
                f'{self.img_dir}/{ade20k_split_nickname}/*.jpg')
            # The output directory depends only on the split, so build it once.
            save_dir = f'temp_files/{folder_prefix}_{split}_2019_12_16'

            for i, rgb_fpath in enumerate(rgb_fpaths):
                print(f'On {i}/{len(rgb_fpaths)-1}')
                fname_stem = Path(rgb_fpath).stem
                # Sanity check: the precomputed lookup must agree with glob.
                assert rgb_fpath == self.fname_to_rgbfpath_dict[fname_stem]

                rgb_img, label_img = self.get_img_pair(fname_stem)

                present_classnames = {
                    self.id_to_classname_map[idx]
                    for idx in np.unique(label_img)
                }
                # Skip images that are missing any required class.
                if not required.issubset(present_classnames):
                    continue

                seg_rgb_fpath = self.fname_to_segrgbfpath_dict[fname_stem]
                instance_masks, instance_ids = get_ade20k_instance_label_masks(
                    seg_rgb_fpath, rgb_img)
                for instance_mask, instance_id in zip(instance_masks,
                                                      instance_ids):
                    # Classify the instance by the labels under its mask
                    # (label_votes holds one label per masked pixel).
                    label_votes, majority_vote = get_instance_mask_class_votes(
                        instance_mask, label_img)
                    # Ignore instances with too few voting pixels.
                    if label_votes.size < MIN_REQ_PX:
                        continue

                    instance_classname = self.id_to_classname_map[
                        majority_vote]
                    if instance_classname != highlight_classname:
                        continue

                    save_fpath = f'{save_dir}/{fname_stem}_{instance_id}.png'
                    save_binary_mask_double(rgb_img,
                                            instance_mask,
                                            save_fpath,
                                            save_to_disk=True)
    def get_segment_mask(self, seq_id: None, segmentid: int, fname_stem: str,
                         split: str) -> Optional[np.ndarray]:
        """Return the binary mask of a single instance (segment) of an image.

        Args:
        -   seq_id: not used by this method.
        -   segmentid: integer representing segment unique ID
        -   fname_stem: file-name stem identifying the image
        -   split: dataset split, i.e. 'train' or 'val' (not used by this method)

        Returns:
        -   segment_mask: the instance mask whose instance ID equals
                `segmentid`, or None if the image has no such instance.

        Raises:
        -   SystemExit: with status 1 if the segment has fewer than
                MIN_REQ_PX voting pixels.
        """
        rgb_img, label_img = self.get_img_pair(fname_stem)

        seg_rgb_fpath = self.fname_to_segrgbfpath_dict[fname_stem]
        instance_masks, instance_ids = get_ade20k_instance_label_masks(
            seg_rgb_fpath, rgb_img)
        segment_mask = next(
            (mask for mask, inst_id in zip(instance_masks, instance_ids)
             if inst_id == segmentid),
            None)

        # Must run before the vote: previously the vote helper was handed a
        # None mask first, so a missing segment never reached this branch.
        if segment_mask is None:
            print('Specified segment ID does not exist.')
            return None

        # Test the instance's class. The copy presumably guards against the
        # helper mutating the mask we return.
        label_votes, _ = get_instance_mask_class_votes(
            copy.deepcopy(segment_mask), label_img)
        if label_votes.size < MIN_REQ_PX:
            print('Big problem here! quitting...')
            # quit() is meant for the interactive interpreter and exits with
            # status 0; report the failure with a nonzero status instead.
            raise SystemExit(1)

        return segment_mask
# Example #3
def test_get_instance_mask_class_votes1():
    """
    Given instance masks covering entire image, ensure get most likely
    category vote.
    """
    instance_mask = np.array([[1, 1], [1, 1]])
    label_img = np.array([[0, 0], [0, 0]])
    label_votes, majority_vote = get_instance_mask_class_votes(
        instance_mask, label_img)
    gt_majority_vote = 0
    # One vote per masked pixel: all four pixels carry label 0.
    gt_label_votes = np.array([0, 0, 0, 0])
    assert gt_majority_vote == majority_vote
    # Exact comparison: labels are integers, and np.allclose broadcasts, so a
    # wrongly collapsed result such as np.array([0]) would have passed.
    assert np.array_equal(gt_label_votes, label_votes)
# Example #4
def test_get_instance_mask_class_votes2():
    """
    Given instance masks covering partial image, ensure get most likely
    category vote. Here, our instance mask is noisy.
    """
    instance_mask = np.array([[0, 0, 1, 0], [0, 1, 1, 0], [0, 1, 1, 0],
                              [0, 1, 0, 0]])
    label_img = np.array([[8, 8, 8, 8], [9, 9, 9, 9], [8, 9, 9, 9],
                          [7, 7, 7, 7]])
    label_votes, majority_vote = get_instance_mask_class_votes(
        instance_mask, label_img)
    gt_majority_vote = 9
    # Labels under the masked pixels in row-major order:
    # (0,2)->8, (1,1),(1,2),(2,1),(2,2)->9, (3,1)->7.
    gt_label_votes = np.array([8, 9, 9, 9, 9, 7])
    assert gt_majority_vote == majority_vote
    # Exact comparison: votes are integer labels, so no float tolerance and
    # no broadcasting should be involved.
    assert np.array_equal(gt_label_votes, label_votes)