示例#1
0
def validate_albumentation_transform(tf: dict):
    """Validate a serialized albumentations transform.

    Args:
        tf: A dict produced by ``A.to_dict(transform)``, or ``None`` (which is
            accepted without validation).

    Returns:
        ``tf`` unchanged, so this can be used as a pass-through validator.

    Raises:
        ConfigError: If ``A.from_dict`` cannot deserialize ``tf``. The
            underlying exception is chained as ``__cause__`` so the real
            reason (unknown class name, bad argument, ...) is not lost.
    """
    if tf is not None:
        try:
            A.from_dict(tf)
        except Exception as e:
            raise ConfigError('The given serialization is invalid. Use '
                              'A.to_dict(transform) to serialize.') from e
    return tf
示例#2
0
def test_transform_pipeline_serialization_with_keypoints(seed, image, keypoints, keypoint_format, labels):
    """Round-trip a nested, keypoint-aware pipeline through to_dict/from_dict.

    The original and the rebuilt pipeline are run under the same ``random``
    seed and must yield identical images and keypoints.
    """
    flip_branch = A.Compose([
        A.RandomRotate90(),
        A.OneOf([A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5)]),
    ])
    color_branch = A.Compose([
        A.Rotate(p=0.5),
        A.OneOf([A.HueSaturationValue(p=0.5), A.RGBShift(p=0.7)], p=1),
    ])
    original = A.Compose(
        [
            A.OneOrOther(flip_branch, color_branch),
            A.HorizontalFlip(p=1),
            A.RandomBrightnessContrast(p=0.5),
        ],
        keypoint_params={'format': keypoint_format, 'label_fields': ['labels']},
    )
    restored = A.from_dict(A.to_dict(original))
    inputs = {'image': image, 'keypoints': keypoints, 'labels': labels}

    random.seed(seed)
    expected = original(**inputs)
    random.seed(seed)
    actual = restored(**inputs)

    for key in ('image', 'keypoints'):
        assert np.array_equal(expected[key], actual[key])
示例#3
0
def test_lambda_serialization(image, mask, bboxes, keypoints, seed, p):
    """A named ``A.Lambda`` must survive serialization.

    The lambda's functions cannot be serialized by value, so ``from_dict``
    resolves the transform by its name through ``lambda_transforms``.
    """
    def vflip_image(image, **kwargs):
        return F.vflip(image)

    def vflip_mask(mask, **kwargs):
        return F.vflip(mask)

    def vflip_bbox(bbox, **kwargs):
        return F.bbox_vflip(bbox, **kwargs)

    def vflip_keypoint(keypoint, **kwargs):
        return F.keypoint_vflip(keypoint, **kwargs)

    aug = A.Lambda(name='vflip', image=vflip_image, mask=vflip_mask, bbox=vflip_bbox, keypoint=vflip_keypoint, p=p)
    deserialized_aug = A.from_dict(A.to_dict(aug), lambda_transforms={'vflip': aug})

    targets = {'image': image, 'mask': mask, 'bboxes': bboxes, 'keypoints': keypoints}
    random.seed(seed)
    aug_data = aug(**targets)
    random.seed(seed)
    deserialized_aug_data = deserialized_aug(**targets)
    for name in targets:
        assert np.array_equal(aug_data[name], deserialized_aug_data[name])
示例#4
0
def _load_model(url):
    """Download a TorchScript model and unpack the metadata stored with it.

    The model's ``hparams`` attribute is parsed as JSON. ``categories`` is
    required; ``categories_values`` and ``preprocess`` are optional. Without
    ``preprocess`` the default pipeline from ``get_default_preprocess()`` is
    used.

    Returns:
        dict with keys ``model``, ``categories``, ``categories_values`` (only
        when it is present and not None) and ``preprocess``.
    """
    model = torch.jit.load(get_file_from_url(url, progress=True, unzip=False))
    hparams = json.loads(model.hparams)
    categories = hparams['categories']
    categories_values = hparams.get('categories_values')

    preprocess = (A.from_dict(hparams["preprocess"]) if "preprocess" in hparams
                  else get_default_preprocess())

    loaded = {'model': model, 'categories': categories}
    if categories_values is not None:
        loaded['categories_values'] = categories_values
    loaded['preprocess'] = preprocess
    return loaded
def test_from_float_serialization(float_image):
    """``FromFloat(dtype="uint8")`` must behave the same after a dict round-trip."""
    aug = A.FromFloat(p=1, dtype="uint8")
    restored = A.from_dict(A.to_dict(aug))
    expected, actual = (t(image=float_image)["image"] for t in (aug, restored))
    assert np.array_equal(expected, actual)
 def _get_aug(self, arg):
     """Load a serialized augmentation from the JSON file at ``arg``.

     The result is wrapped in an ``A.Compose`` with ``p=1`` whose
     ``additional_targets`` map ``image1`` .. ``image{N-1}`` (``N`` read from
     ``self.N``) to the ``image`` target type.
     """
     with open(arg) as f:
         augs = A.from_dict(json.load(f))
     # NOTE(review): `augs` is a single transform, not a list; Compose will
     # iterate it (presumably flattening a deserialized Compose) — confirm
     # this is intended rather than `[augs]`.
     target = {f'image{i}': 'image' for i in range(1, self.N)}
     return A.Compose(augs, p=1, additional_targets=target)
示例#7
0
def test_transform_pipeline_serialization(seed, image, mask):
    """Round-trip a nested resize/crop pipeline and compare its outputs.

    The pipeline is serialized with ``A.to_dict`` and rebuilt with
    ``A.from_dict``. Both versions run under the same ``random`` seed, and
    their image and mask outputs must match exactly.
    """
    first_branch = A.Compose([
        A.Resize(1024, 1024),
        A.RandomSizedCrop(min_max_height=(256, 1024), height=512, width=512, p=1),
        A.OneOf([
            A.RandomSizedCrop(min_max_height=(256, 512), height=384, width=384, p=0.5),
            A.RandomSizedCrop(min_max_height=(256, 512), height=512, width=512, p=0.5),
        ]),
    ])
    second_branch = A.Compose([
        A.Resize(1024, 1024),
        A.RandomSizedCrop(min_max_height=(256, 1025), height=256, width=256, p=1),
        A.OneOf([A.HueSaturationValue(p=0.5), A.RGBShift(p=0.7)], p=1),
    ])
    pipeline = A.Compose([
        A.OneOrOther(first_branch, second_branch),
        A.HorizontalFlip(p=1),
        A.RandomBrightnessContrast(p=0.5),
    ])
    rebuilt = A.from_dict(A.to_dict(pipeline))

    results = []
    for transform in (pipeline, rebuilt):
        random.seed(seed)
        results.append(transform(image=image, mask=mask))
    original_out, rebuilt_out = results
    for key in ('image', 'mask'):
        assert np.array_equal(original_out[key], rebuilt_out[key])
def test_lambda_serialization(image, mask, albumentations_bboxes, keypoints, seed, p):
    """A named ``A.Lambda`` must round-trip when resolved via ``lambda_transforms``."""
    def vflip_image(image, **kwargs):
        return F.vflip(image)

    def vflip_mask(mask, **kwargs):
        return F.vflip(mask)

    def vflip_bbox(bbox, **kwargs):
        return F.bbox_vflip(bbox, **kwargs)

    def vflip_keypoint(keypoint, **kwargs):
        return F.keypoint_vflip(keypoint, **kwargs)

    aug = A.Lambda(name="vflip", image=vflip_image, mask=vflip_mask, bbox=vflip_bbox, keypoint=vflip_keypoint, p=p)
    serialized = A.to_dict(aug)
    restored = A.from_dict(serialized, lambda_transforms={"vflip": aug})

    def run(transform):
        # Re-seed before every run so both pipelines draw identical params.
        set_seed(seed)
        return transform(image=image, mask=mask, bboxes=albumentations_bboxes, keypoints=keypoints)

    expected, actual = run(aug), run(restored)
    for key in ("image", "mask", "bboxes", "keypoints"):
        assert np.array_equal(expected[key], actual[key])
示例#9
0
    def plot_batch(self, x: Tensor, y, output_path: str, z=None):
        """Plot a whole batch in a roughly square grid using plot_xyz.

        Args:
            x: batch of images in (N, c, h, w) layout; permuted to
                (N, h, w, c) before plotting
            y: ground truth labels, indexed per sample
            output_path: local path where to save plot image
            z: optional predicted labels, indexed per sample

        An empty batch produces no plot.
        """
        batch_sz = x.shape[0]
        if batch_sz == 0:
            # GridSpec rejects a 0x0 grid; return before creating a figure so
            # none is leaked.
            return
        ncols = nrows = math.ceil(math.sqrt(batch_sz))
        fig = plt.figure(constrained_layout=True,
                         figsize=(3 * ncols, 3 * nrows))
        grid = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)

        # (N, c, h, w) --> (N, h, w, c)
        x = x.permute(0, 2, 3, 1)

        # apply transform, if given. Albumentations transforms take a single
        # (h, w, c) image, so apply it to every sample separately instead of
        # feeding the whole 4-D batch array as one "image".
        if self.cfg.data.plot_options.transform is not None:
            import numpy as np
            tf = A.from_dict(self.cfg.data.plot_options.transform)
            imgs = [tf(image=img)['image'] for img in x.numpy()]
            x = torch.from_numpy(np.stack(imgs))

        for i in range(batch_sz):
            ax = fig.add_subplot(grid[i])
            if z is None:
                self.plot_xyz(ax, x[i], y[i])
            else:
                self.plot_xyz(ax, x[i], y[i], z=z[i])

        make_dir(output_path, use_dirname=True)
        plt.savefig(output_path)
        plt.close()
示例#10
0
 def load_model(checkpoint):
     """Rebuild a model, its preprocessing and its test transforms from a checkpoint.

     Args:
         checkpoint: mapping with ``model_params`` (kwargs for ``get_model``),
             ``state_dict`` and ``test_transforms`` (a serialized
             albumentations transform).

     Returns:
         ``(model, preprocessing, transforms)``.
     """
     # Copy before dropping 'weights' so the caller's checkpoint is not
     # mutated (it may be reused or saved again afterwards).
     model_params = dict(checkpoint['model_params'])
     # NOTE(review): presumably 'weights' selects pretrained weights, which
     # are superseded by the state dict loaded below — confirm.
     model_params.pop('weights', None)
     model, preprocessing = get_model(**model_params)
     model.load_state_dict(checkpoint['state_dict'])
     transforms = A.from_dict(checkpoint['test_transforms'])
     return model, preprocessing, transforms
def test_transform_pipeline_serialization_with_bboxes(seed, image, bboxes,
                                                      bbox_format, labels):
    """Round-trip a nested bbox-aware pipeline; outputs must match under one seed."""
    branches = (
        A.Compose([A.RandomRotate90(),
                   A.OneOf([A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5)])]),
        A.Compose([A.Rotate(p=0.5),
                   A.OneOf([A.HueSaturationValue(p=0.5), A.RGBShift(p=0.7)], p=1)]),
    )
    aug = A.Compose(
        [A.OneOrOther(*branches), A.HorizontalFlip(p=1), A.RandomBrightnessContrast(p=0.5)],
        bbox_params={"format": bbox_format, "label_fields": ["labels"]},
    )
    deserialized_aug = A.from_dict(A.to_dict(aug))

    outputs = []
    for pipeline in (aug, deserialized_aug):
        set_seed(seed)
        outputs.append(pipeline(image=image, bboxes=bboxes, labels=labels))
    for key in ("image", "bboxes"):
        assert np.array_equal(outputs[0][key], outputs[1][key])
def test_subtract_is_serializable(image):
    """The custom ``Subtract`` transform must round-trip through to_dict/from_dict."""
    original = Subtract(subtract_value=10)
    restored = A.from_dict(A.to_dict(original))
    results = []
    for transform in (original, restored):
        set_seed(42)
        results.append(transform.apply(image))
    assert np.array_equal(*results)
示例#13
0
    def plot_batch(self,
                   x: torch.Tensor,
                   y: Union[torch.Tensor, np.ndarray],
                   output_path: str,
                   z: Optional[torch.Tensor] = None,
                   batch_limit: Optional[int] = None) -> None:
        """Render a batch as a grid with one row per sample, via plot_xyz.

        Args:
            x: batch of images in (N, c, h, w) layout
            y: ground truth labels
            output_path: local path where to save plot image
            z: optional predicted labels
            batch_limit: optional limit on (rendered) batch size

        Nothing is drawn or saved when the (limited) batch size is zero.
        """
        n_total, _, _, _ = x.shape
        batch_sz = n_total if batch_limit is None else min(n_total, batch_limit)
        if batch_sz == 0:
            return

        # One column per channel display group, one for the labels and, when
        # predictions are given, one more for them.
        nrows = batch_sz
        ncols = len(self.cfg.data.channel_display_groups) + (1 if z is None else 2)

        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False,
                                 constrained_layout=True,
                                 figsize=(3 * ncols, 3 * nrows))
        assert axes.shape == (nrows, ncols)

        # (N, c, h, w) --> (N, h, w, c)
        x = x.permute(0, 2, 3, 1)

        # Optional display-only transform, applied image by image.
        plot_tf = self.cfg.data.plot_options.transform
        if plot_tf is not None:
            tf = A.from_dict(plot_tf)
            x = torch.from_numpy(np.stack([tf(image=img)['image'] for img in x.numpy()]))

        for i in range(batch_sz):
            extra = {} if z is None else {'z': z[i]}
            self.plot_xyz((fig, axes[i]), x[i], y[i], **extra)

        make_dir(output_path, use_dirname=True)
        plt.savefig(output_path, bbox_inches='tight')
        plt.close()
示例#14
0
def test_augmentations_for_keypoints_serialization(augmentation_cls, params, p, seed, image, keypoints, always_apply):
    """Each parametrized augmentation must give identical keypoint output after a round-trip."""
    aug = augmentation_cls(p=p, always_apply=always_apply, **params)
    restored = A.from_dict(A.to_dict(aug))

    def apply(transform):
        random.seed(seed)
        return transform(image=image, keypoints=keypoints)

    expected = apply(aug)
    actual = apply(restored)
    for key in ('image', 'keypoints'):
        assert np.array_equal(expected[key], actual[key])
示例#15
0
    def get_data_transforms(self) -> Tuple[A.BasicTransform, A.BasicTransform]:
        """Get albumentations transform objects for data augmentation.

        The base transform resizes to ``cfg.data.img_sz`` x ``cfg.data.img_sz``
        and then applies ``cfg.data.base_transform`` when one is configured.
        If ``cfg.data.aug_transform`` is set, it is deserialized and run
        before the base transform. Otherwise the augmentation pipeline is
        assembled from the names in ``cfg.data.augmentors``; unknown names
        are logged and skipped. Both pipelines use ``self.get_bbox_params()``.

        Returns:
           1st tuple arg: a transform that doesn't do any data augmentation
           2nd tuple arg: a transform with data augmentation
        """
        cfg = self.cfg
        bbox_params = self.get_bbox_params()
        base_tfs = [A.Resize(cfg.data.img_sz, cfg.data.img_sz)]
        if cfg.data.base_transform is not None:
            base_tfs.append(A.from_dict(cfg.data.base_transform))
        base_transform = A.Compose(base_tfs, bbox_params=bbox_params)

        if cfg.data.aug_transform is not None:
            aug_transform = A.from_dict(cfg.data.aug_transform)
            aug_transform = A.Compose([aug_transform, base_transform],
                                      bbox_params=bbox_params)
            return base_transform, aug_transform

        augmentors_dict = {
            'Blur': A.Blur(),
            'RandomRotate90': A.RandomRotate90(),
            'HorizontalFlip': A.HorizontalFlip(),
            'VerticalFlip': A.VerticalFlip(),
            'GaussianBlur': A.GaussianBlur(),
            'GaussNoise': A.GaussNoise(),
            'RGBShift': A.RGBShift(),
            'ToGray': A.ToGray()
        }
        aug_transforms = []
        for augmentor in cfg.data.augmentors:
            try:
                aug_transforms.append(augmentors_dict[augmentor])
            except KeyError as e:
                # Build the message from adjacent literals: a backslash line
                # continuation inside a string literal would embed the next
                # line's indentation in every logged warning.
                log.warning(
                    '{0} is an unknown augmentor. Continuing without {0}. '
                    'Known augmentors are: {1}'.format(
                        e, list(augmentors_dict.keys())))
        aug_transforms.append(base_transform)
        aug_transform = A.Compose(aug_transforms, bbox_params=bbox_params)

        return base_transform, aug_transform
def test_augmentations_serialization(augmentation_cls, params, p, seed, image, mask, always_apply):
    """Image and mask outputs must be identical before and after serialization."""
    aug = augmentation_cls(p=p, always_apply=always_apply, **params)
    restored = A.from_dict(A.to_dict(aug))

    outputs = []
    for transform in (aug, restored):
        set_seed(seed)
        outputs.append(transform(image=image, mask=mask))
    before, after = outputs
    for key in ("image", "mask"):
        assert np.array_equal(before[key], after[key])
示例#17
0
def test_image_only_crop_around_bbox_augmentation_serialization(p, seed, image, always_apply):
    """RandomCropNearBBox must round-trip; the crop box is given as ``cropping_bbox``."""
    aug = A.RandomCropNearBBox(p=p, always_apply=always_apply, max_part_shift=0.15)
    restored = A.from_dict(A.to_dict(aug))
    annotations = {'image': image, 'cropping_bbox': [-59, 77, 177, 231]}

    images = []
    for transform in (aug, restored):
        random.seed(seed)
        images.append(transform(**annotations)['image'])
    assert np.array_equal(images[0], images[1])
def test_additional_targets_for_image_only_serialization(augmentation_cls, params, image, seed):
    """An extra ``image2`` target declared on Compose must survive serialization."""
    aug = A.Compose([augmentation_cls(always_apply=True, **params)], additional_targets={"image2": "image"})
    inputs = {"image": image, "image2": image.copy()}
    restored = A.from_dict(A.to_dict(aug))

    set_seed(seed)
    expected = aug(**inputs)
    set_seed(seed)
    actual = restored(**inputs)
    for target in inputs:
        assert np.array_equal(expected[target], actual[target])
def test_augmentations_for_bboxes_serialization(
    augmentation_cls, params, p, seed, image, albumentations_bboxes, always_apply
):
    """Image and bbox outputs must match before and after serialization."""
    aug = augmentation_cls(p=p, always_apply=always_apply, **params)
    restored = A.from_dict(A.to_dict(aug))

    def run(transform):
        set_seed(seed)
        return transform(image=image, bboxes=albumentations_bboxes)

    expected, actual = run(aug), run(restored)
    for key in ("image", "bboxes"):
        assert np.array_equal(expected[key], actual[key])
示例#20
0
def test_imgaug_augmentations_for_bboxes_serialization(augmentation_cls, params, p, seed, image, bboxes, always_apply):
    """imgaug-backed augmentations must round-trip; both RNGs are seeded per run."""
    aug = augmentation_cls(p=p, always_apply=always_apply, **params)
    restored = A.from_dict(A.to_dict(aug))

    def reseed():
        # imgaug keeps its own RNG, so seed it alongside `random`.
        random.seed(seed)
        ia.seed(seed)

    reseed()
    expected = aug(image=image, bboxes=bboxes)
    reseed()
    actual = restored(image=image, bboxes=bboxes)
    for key in ('image', 'bboxes'):
        assert np.array_equal(expected[key], actual[key])
def test_augmentations_serialization_with_call_params(
    augmentation_cls, params, call_params, p, seed, image, always_apply
):
    """Augmentations needing extra call-time params must round-trip identically."""
    aug = augmentation_cls(p=p, always_apply=always_apply, **params)
    annotations = dict(image=image)
    annotations.update(call_params)
    restored = A.from_dict(A.to_dict(aug))

    images = []
    for transform in (aug, restored):
        set_seed(seed)
        images.append(transform(**annotations)["image"])
    assert np.array_equal(images[0], images[1])
示例#22
0
    def __init__(
        self,
        data_root: AnyPath,
        initial_cropping_rectangle: Optional[CroppingRectangle] = None,
        random_cropping_size: Optional[Tuple[int, int]] = None,
        num_slices_per_axis: Optional[int] = 1,
        batch_size: int = 1,
        train_subset: Optional[str] = None,
        val_subset: Optional[str] = None,
        test_subset: Optional[str] = None,
        user_albumentation_train: Optional[Union[dict, DictConfig, Any]] = None,
        class_selector: Optional[List[str]] = None,
        num_samples_limit_train: Optional[int] = None,
    ) -> None:
        """Store the data configuration and build train/val/test transforms.

        ``user_albumentation_train`` may be a serialized albumentations
        transform (``dict`` or omegaconf ``DictConfig``), which is
        deserialized here, or any other object, which is stored as-is.
        Datasets are left as ``None``.
        """

        super().__init__()

        self.data_root = Path(data_root)
        self.initial_cropping_rectangle = initial_cropping_rectangle
        self.random_cropping_size = random_cropping_size
        self.num_slices_per_axis = num_slices_per_axis
        self.batch_size = batch_size

        self.train_subset = train_subset
        self.val_subset = val_subset
        self.test_subset = test_subset

        self.num_samples_limit_train = num_samples_limit_train

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        if isinstance(user_albumentation_train, (dict, DictConfig)):
            if isinstance(user_albumentation_train, DictConfig):
                # Convert to plain dicts/lists first: omegaconf's ListConfig is
                # not a `list` subclass, so albumentations' list/tuple checks on
                # nested arguments (e.g. ranges) may reject it.
                from omegaconf import OmegaConf
                user_albumentation_train = OmegaConf.to_container(
                    user_albumentation_train, resolve=True)
            # Try to parse user_albumentation_train into an albumentation.
            self.user_albumentation_train = albumentations.from_dict(user_albumentation_train)
        else:
            self.user_albumentation_train = user_albumentation_train

        self.train_transforms = self.get_transforms(train=True)
        self.val_transforms = self.get_transforms(train=False)
        self.test_transforms = self.get_transforms(train=False)

        self.map_label_to_class_name = None
        self.num_classes = None

        self.class_selector = class_selector
示例#23
0
    def __init__(self, config):
        """Load the TorchScript model selected by ``config["architecture"]``.

        ``device`` and ``decision_threshold`` are read from ``self.config``.
        The preprocessing pipeline and the category names come from the JSON
        ``hparams`` string stored on the loaded model.
        """
        super().__init__(config)
        url = model_urls[config["architecture"]]
        self.model = torch.jit.load(get_file_from_url(url, progress=True, unzip=False))
        self.device = torch.device(self.config["device"])
        self.model.to(self.device)

        hparams = json.loads(self.model.hparams)
        self.preprocess = A.from_dict(hparams["preprocess"])
        self.threshold = self.config["decision_threshold"]
        self.categories = hparams["categories"]
示例#24
0
def test_template_transform_serialization(image, template, seed, p):
    """TemplateTransform round-trips when its templates are resolved via ``lambda_transforms``."""
    template_transform = A.TemplateTransform(name="template", templates=template, p=p)
    aug = A.Compose([A.Flip(), template_transform, A.Blur()])
    deserialized_aug = A.from_dict(A.to_dict(aug), lambda_transforms={"template": template_transform})

    images = []
    for pipeline in (aug, deserialized_aug):
        set_seed(seed)
        images.append(pipeline(image=image)["image"])
    assert np.array_equal(images[0], images[1])
示例#25
0
 def get_preprocessing_transforms(self):
     """Build albumentations transforms from ``self.cfg.data.preprocessing``.

     Each config entry maps transform class names (under
     ``albumentations.augmentations.transforms``) to constructor kwargs.
     Returns the instantiated transforms as a flat list, in config order;
     a falsy config gives an empty list.
     """
     preprocessing_config = self.cfg.data.preprocessing
     if not preprocessing_config:
         return []
     prefix = "albumentations.augmentations.transforms."
     return [
         A.from_dict({"transform": {"__class_fullname__": prefix + name, **kwargs}})
         for entry in preprocessing_config
         for name, kwargs in entry.items()
     ]
示例#26
0
def test_augmentations_serialization_with_custom_parameters(
    augmentation_cls,
    params,
    p,
    seed,
    image,
    mask,
    always_apply,
):
    """Custom-parametrized augmentations must give identical output after a round-trip."""
    aug = augmentation_cls(p=p, always_apply=always_apply, **params)
    restored = A.from_dict(A.to_dict(aug))

    def run(transform):
        random.seed(seed)
        return transform(image=image, mask=mask)

    expected = run(aug)
    actual = run(restored)
    for key in ('image', 'mask'):
        assert np.array_equal(expected[key], actual[key])
示例#27
0
 def get_preprocessing_transforms(self):
     """Build albumentations transforms from ``self.data_cfg.preprocessing``.

     The OmegaConf config is resolved into plain containers first. Each entry
     maps transform class names (under
     ``albumentations.augmentations.transforms``) to constructor kwargs. The
     instantiated transforms are returned as a flat list, in config order.
     """
     raw_config = self.data_cfg.preprocessing
     if not raw_config:
         return []
     prefix = "albumentations.augmentations.transforms."
     transforms = []
     for entry in OmegaConf.to_container(raw_config, resolve=True):
         transforms.extend(
             A.from_dict({"transform": {"__class_fullname__": prefix + name, **kwargs}})
             for name, kwargs in entry.items())
     return transforms
示例#28
0
    def __init__(
        self,
        root_folder,
        image_normalization,
        ground_truth_normalization,
        augmenters,
        mode,
    ):
        """Initialize the dataset and resolve normalizers and augmenters.

        The normalization attributes (presumably set by the base class from
        the same-named arguments) are names of functions in ``normalizer``
        and are replaced by those functions. In ``"train"`` mode, a non-empty
        ``self.augmenters`` is deserialized with ``from_dict``.
        """
        super().__init__(
            root_folder,
            image_normalization,
            ground_truth_normalization,
            augmenters,
            mode,
        )

        self.image_normalization = getattr(normalizer,
                                           self.image_normalization)
        self.ground_truth_normalization = getattr(
            normalizer, self.ground_truth_normalization)

        # Compare by value: `is` tests object identity, which for strings
        # depends on interning and is not guaranteed to match equal values.
        if mode == "train" and len(self.augmenters) > 0:
            self.augmenters = from_dict(self.augmenters)
示例#29
0
def main():
    """Streamlit entry point of the albumentations demo.

    Shows an image from the directory given on the command line and lets the
    user choose an interface mode in the sidebar:

    * "Simple" / "Professional": pick transforms and their parameters.
    * "Custom": as above, with parameters tied to ``./my_json_files/<name>.json``.
    * "LoadMyFile": upload a JSON pipeline produced by ``A.to_dict`` and
      preview it repeatedly.
    """
    local_css(
        "/home/pasonatech/workspace/albumentations_forked/albumentations-demo/src/custom_css.css"
    )

    # get CLI params: the path to images and image width
    path_to_images, width_original = get_arguments()

    if not os.path.isdir(path_to_images):
        st.title("There is no directory: " + path_to_images)
    else:
        # select interface type
        interface_type = st.sidebar.radio(
            "Select the interface mode",
            ["Simple", "Professional", "Custom", "LoadMyFile"])

        # pick css
        if interface_type == "LoadMyFile":
            local_css(
                "/home/pasonatech/workspace/albumentations_forked/albumentations-demo/src/custom_loadmy_css.css"
            )

        if interface_type == "Custom":
            json_file_name = st.sidebar.text_input(
                "Insert Json File Name", "aug_file")  # text_area same format
            json_file_name = os.path.join("./my_json_files",
                                          f"{json_file_name}" + '.json')

        # select image
        status, image = select_image(path_to_images, interface_type)
        if status == 1:
            st.title("Can't load image")
        elif status == 2:
            # `elif`, not a second `if`: otherwise a failed load (status 1)
            # fell through to the `else` branch and used the missing image.
            st.title("Please, upload the image")
        else:
            # image was loaded successfully
            placeholder_params = get_placeholder_params(image)

            # load the config
            augmentations = load_augmentations_config(
                placeholder_params, "configs/augmentations.json")

            # Mode strings are compared with ==/!=; `is` checks identity,
            # which for equal strings is an interning detail.
            if interface_type != "LoadMyFile":
                # get the list of transformations names
                transform_names = select_transformations(
                    augmentations, interface_type)

            if interface_type == "Custom":
                transforms = get_transormations_params_custom(
                    transform_names, augmentations, json_file_name)

            elif interface_type == "LoadMyFile":

                f_name = st.sidebar.file_uploader("Select your json file",
                                                  type="json")

                view_times = 0
                if f_name:
                    j_text = StringIO.read(f_name)
                    j_data = json.loads(j_text)

                    image_replace = st.empty()
                    st.image(image,
                             caption="Original image",
                             width=width_original)
                    if st.sidebar.button("Play Preview"):
                        view_times = 1
                    stop_btn = st.sidebar.button("STOP Preview")
                    if stop_btn:
                        view_times = 0
                    try:
                        transform = A.from_dict(j_data)
                        display_value = True
                    except KeyError:
                        st.error(
                            "Please, confirm your augmentations structure.")
                        st.error(
                            "Supports only albumentations augmentation generated 'A.to_dict()'."
                        )
                        display_value = False

                    # Re-render a fresh augmentation once per second until an
                    # error resets view_times. NOTE(review): presumably the
                    # loop is otherwise ended by Streamlit's script rerun on
                    # the next widget interaction — confirm.
                    while (view_times == 1):

                        try:
                            aug_img_obj = transform(image=image)
                            aug_img = aug_img_obj['image']

                            image_replace.image(
                                aug_img,
                                caption="Transformed image",
                                width=width_original,
                            )
                        except IOError:
                            st.error("Confirm your json file path.")
                            view_times = 0
                        except UnboundLocalError:
                            # `transform` was never assigned because
                            # A.from_dict raised above.
                            st.error(
                                "Your json file seems incompatible to run this task. "
                            )
                            view_times = 0
                        except ValueError as e:
                            image_replace.error(
                                e)  # replaces error log in same field

                        time.sleep(1)
                    if stop_btn is True:
                        st.info(
                            "Preview Stopped. Press Play Preview button to resume previewing."
                        )
                    if display_value:
                        if st.sidebar.checkbox(
                                "Display Augmentation Parameters"):
                            onetine_data_loader(j_data)

                    transforms = []
                else:
                    st.header("WELCOME")
                    st.header("Please upload a JSON File")

            else:
                # get parameters for each transform
                transforms = get_transormations_params(transform_names,
                                                       augmentations)

            if interface_type != "LoadMyFile":
                try:
                    # apply the transformation to the image
                    data = A.ReplayCompose(transforms)(image=image)
                    error = 0
                except ValueError:
                    error = 1
                    st.title(
                        "The error has occurred. Most probably you have passed wrong set of parameters. \
                    Check transforms that change the shape of image.")

                # proceed only if everything is ok
                if error == 0:
                    augmented_image = data["image"]
                    # show title
                    st.title("Demo of Albumentations")

                    # show the images
                    width_transformed = int(width_original / image.shape[1] *
                                            augmented_image.shape[1])

                    st.image(image,
                             caption="Original image",
                             width=width_original)
                    st.image(
                        augmented_image,
                        caption="Transformed image",
                        width=width_transformed,
                    )

                    # comment about refreshing
                    st.write("*Press 'R' to refresh*")

                    # random values used to get transformations
                    show_random_params(data, interface_type)

                    for transform in transforms:
                        show_docstring(transform)
                        st.code(str(transform))
                    show_credentials()

                # adding google analytics pixel
                # only when deployed online. don't collect statistics of local usage
                if "GA" in os.environ:
                    st.image(os.environ["GA"])
                    st.markdown(
                        ("[Privacy policy]" +
                         ("(https://htmlpreview.github.io/?" +
                          "https://github.com/IliaLarchenko/" +
                          "albumentations-demo/blob/deploy/docs/privacy.html)")
                         ))
示例#30
0
 def _get_aug(self, arg):
     """Deserialize an albumentations transform from the JSON file at ``arg``."""
     with open(arg) as fh:
         serialized = json.load(fh)
     return A.from_dict(serialized)