Ejemplo n.º 1
0
    def __call__(self, results):
        """Create a mask for the configured mode and attach it to ``results``.

        The mask source depends on ``self.mask_mode``:

        - ``'bbox'``: ``random_bbox(**self.mask_config)`` followed by
          ``bbox2mask``; the sampled bbox is also stored under
          ``results['mask_bbox']``.
        - ``'irregular'``: ``get_irregular_mask(**self.mask_config)``.
        - ``'set'``: ``self._get_random_mask_from_set()``.
        - ``'ff'``: ``brush_stroke_mask(**self.mask_config)``.
        - ``'file'``: ``self._get_mask_from_file(results['mask_path'])``.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation. Must contain ``'mask_path'`` when the
                mode is ``'file'``.

        Returns:
            dict: The same ``results`` dict, updated in place with a
                ``'mask'`` entry (and ``'mask_bbox'`` in ``'bbox'`` mode).

        Raises:
            NotImplementedError: If ``self.mask_mode`` is none of the modes
                listed above.
        """
        mode = self.mask_mode

        # Reject unknown modes up front so the dispatch below only has to
        # deal with supported ones.
        if mode not in ('bbox', 'irregular', 'set', 'ff', 'file'):
            raise NotImplementedError(
                f'Mask mode {mode} has not been implemented.')

        if mode == 'bbox':
            sampled_bbox = random_bbox(**self.mask_config)
            generated = bbox2mask(self.mask_config['img_shape'], sampled_bbox)
            # Record the bbox only after the mask was built successfully.
            results['mask_bbox'] = sampled_bbox
        elif mode == 'irregular':
            generated = get_irregular_mask(**self.mask_config)
        elif mode == 'set':
            generated = self._get_random_mask_from_set()
        elif mode == 'ff':
            generated = brush_stroke_mask(**self.mask_config)
        else:  # mode == 'file'
            generated = self._get_mask_from_file(results['mask_path'])

        results['mask'] = generated
        return results
Ejemplo n.º 2
0
def test_irregular_mask():
    """Exercise ``get_irregular_mask`` on a 256x256 image shape.

    Checks that default masks have shape ``(256, 256, 1)``, dtype
    ``np.uint8``, contain only zeros and ones, and cover between 15% and 50%
    of the pixels; that passing a dict for ``brush_width``, ``length_range``
    or ``num_vertexes`` raises ``TypeError``; and that passing the scalar
    ``10`` for each of them still yields a mask of the expected shape.
    """
    height, width = 256, 256
    img_shape = (height, width)
    expected_shape = (height, width, 1)
    num_pixels = img_shape[0] * img_shape[1]
    tunable_args = ('brush_width', 'length_range', 'num_vertexes')

    # Masks are random, so sample several of them with default settings.
    for _ in range(10):
        mask = get_irregular_mask(img_shape)
        assert mask.shape == expected_shape

        coverage = np.sum(mask) / num_pixels
        assert 0.15 < coverage < 0.50

        # Every pixel must be exactly 0 or exactly 1.
        num_zeros = np.count_nonzero(mask == 0)
        num_ones = np.count_nonzero(mask == 1)
        assert num_zeros + num_ones == 256 * 256
        assert mask.dtype == np.uint8

    # A dict is not an accepted type for any of the tunable arguments.
    for arg_name in tunable_args:
        with pytest.raises(TypeError):
            mask = get_irregular_mask(img_shape, **{arg_name: dict()})

    # A plain scalar is accepted for each of the tunable arguments.
    for arg_name in tunable_args:
        mask = get_irregular_mask(img_shape, **{arg_name: 10})
        assert mask.shape == expected_shape