def __call__(self, results):
    """Generate a mask according to ``self.mask_mode`` and add it to ``results``.

    Supported modes:

    - ``'bbox'``: sample a box with ``random_bbox(**self.mask_config)``,
      convert it with ``bbox2mask(self.mask_config['img_shape'], box)`` and
      also record the box under ``results['mask_bbox']``.
    - ``'irregular'``: ``get_irregular_mask(**self.mask_config)``.
    - ``'set'``: ``self._get_random_mask_from_set()``.
    - ``'ff'``: ``brush_stroke_mask(**self.mask_config)``.
    - ``'file'``: ``self._get_mask_from_file(results['mask_path'])``.

    Args:
        results (dict): A dict containing the necessary information and data
            for augmentation. In ``'file'`` mode it must provide
            ``'mask_path'`` (a missing key raises ``KeyError``).

    Returns:
        dict: The same ``results`` object, updated in place with a ``'mask'``
        entry (plus ``'mask_bbox'`` in ``'bbox'`` mode).

    Raises:
        NotImplementedError: If ``self.mask_mode`` is not one of the modes
            listed above.
    """

    def _build_mask():
        # Each branch returns the generated mask. Only the bbox branch also
        # writes the sampled box into ``results``, and it does so after the
        # mask has been built successfully.
        if self.mask_mode == 'bbox':
            box = random_bbox(**self.mask_config)
            box_mask = bbox2mask(self.mask_config['img_shape'], box)
            results['mask_bbox'] = box
            return box_mask
        if self.mask_mode == 'irregular':
            return get_irregular_mask(**self.mask_config)
        if self.mask_mode == 'set':
            return self._get_random_mask_from_set()
        if self.mask_mode == 'ff':
            return brush_stroke_mask(**self.mask_config)
        if self.mask_mode == 'file':
            return self._get_mask_from_file(results['mask_path'])
        raise NotImplementedError(
            f'Mask mode {self.mask_mode} has not been implemented.')

    results['mask'] = _build_mask()
    return results
def test_irregular_mask():
    """Check output shape, coverage, dtype and argument validation of
    ``get_irregular_mask``."""
    img_shape = (256, 256)
    height, width = img_shape
    total_pixels = height * width
    expected_shape = (height, width, 1)

    # Default masks: binary uint8 arrays covering between 15% and 50% of the
    # image.
    for _ in range(10):
        mask = get_irregular_mask(img_shape)
        assert mask.shape == expected_shape
        coverage = np.sum(mask) / total_pixels
        assert 0.15 < coverage < 0.50
        num_zeros = np.sum((mask == 0).astype(np.uint8))
        num_ones = np.sum((mask == 1).astype(np.uint8))
        # Every pixel must be either 0 or 1.
        assert num_zeros + num_ones == total_pixels
        assert mask.dtype == np.uint8

    # Passing a dict for any of these options must raise TypeError.
    invalid_kwargs = (
        dict(brush_width=dict()),
        dict(length_range=dict()),
        dict(num_vertexes=dict()),
    )
    for kwargs in invalid_kwargs:
        with pytest.raises(TypeError):
            get_irregular_mask(img_shape, **kwargs)

    # A plain integer is accepted for each option and keeps the output shape.
    for option in ('brush_width', 'length_range', 'num_vertexes'):
        mask = get_irregular_mask(img_shape, **{option: 10})
        assert mask.shape == expected_shape