Exemplo n.º 1
0
def test_bbox_mask():
    """Check ``random_bbox`` and ``bbox2mask`` on valid and invalid configs."""
    # Baseline settings used for sampling a random bbox mask.
    cfg = {
        'img_shape': (256, 256),
        'max_bbox_shape': 100,
        'max_bbox_delta': 10,
        'min_margin': 10,
    }
    height, width = cfg['img_shape']

    mask_bbox = bbox2mask(cfg['img_shape'], random_bbox(**cfg))
    assert mask_bbox.shape == (height, width, 1)
    # Zeros and ones together must account for every pixel of the mask.
    num_zeros = np.sum(mask_bbox == 0)
    num_ones = np.sum(mask_bbox == 1)
    assert num_zeros + num_ones == height * width
    assert mask_bbox.dtype == np.uint8

    # Each of these single-key overrides must make ``random_bbox`` raise
    # ``ValueError``.
    rejected_overrides = (
        ('max_bbox_shape', 300),
        ('max_bbox_delta', 300),
        ('max_bbox_shape', 254),
    )
    for key, value in rejected_overrides:
        with pytest.raises(ValueError):
            random_bbox(**{**cfg, key: value})

    # A delta of 1 is still accepted and yields a mask of the image size.
    small_delta_cfg = dict(cfg, max_bbox_delta=1)
    mask_bbox = bbox2mask(cfg['img_shape'], random_bbox(**small_delta_cfg))
    assert mask_bbox.shape == (height, width, 1)
Exemplo n.º 2
0
    def __call__(self, results):
        """Generate a mask for the configured mode and store it in ``results``.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation. In ``'file'`` mode it must provide
                the ``'mask_path'`` key.

        Returns:
            dict: The same dict, with the generated mask stored under
                ``'mask'`` and, in ``'bbox'`` mode, the sampled bbox stored
                under ``'mask_bbox'``.

        Raises:
            NotImplementedError: If ``self.mask_mode`` is not one of
                ``'bbox'``, ``'irregular'``, ``'set'``, ``'ff'`` or
                ``'file'``.
        """
        mode = self.mask_mode

        if mode == 'bbox':
            # The sampled bbox is exposed alongside the mask built from it.
            sampled_bbox = random_bbox(**self.mask_config)
            generated = bbox2mask(self.mask_config['img_shape'], sampled_bbox)
            results['mask_bbox'] = sampled_bbox
        elif mode == 'irregular':
            generated = get_irregular_mask(**self.mask_config)
        elif mode == 'set':
            generated = self._get_random_mask_from_set()
        elif mode == 'ff':
            generated = brush_stroke_mask(**self.mask_config)
        elif mode == 'file':
            generated = self._get_mask_from_file(results['mask_path'])
        else:
            raise NotImplementedError(
                f'Mask mode {self.mask_mode} has not been implemented.')

        results['mask'] = generated
        return results