Example #1
0
    def apply_trans(  # type: ignore
            cls, input: torch.Tensor, label: Optional[torch.Tensor],
            module: nn.Module,
            param: ParamItem) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply a transformation with respect to the parameters.

        Dispatches on the type of ``module``:

        * ``MixAugmentationBase``: called with ``label`` and ``param.data``;
          both the transformed input and label are returned.
        * ``_AugmentationBase``: called with ``param.data``; ``label`` is
          returned unchanged.
        * ``kornia.augmentation.ImageSequential``: temporarily configured with
          ``apply_inverse_func = InputApplyInverse`` and ``return_label = True``
          so the container returns the label along with the input. The
          container's original settings are restored afterwards, even if the
          call raises. ``AugmentationSequential`` additionally receives
          ``data_keys=[cls.data_key]``.
        * anything else: treated as a non-augmentation operation, which must
          carry no parameters; it is applied to the input only.

        Args:
            input: the input tensor.
            label: the optional label tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.

        Returns:
            A tuple of the transformed input and the (possibly transformed) label.

        Raises:
            AssertionError: if ``module`` is not an augmentation module but
                ``param.data`` is not ``None``.
        """
        if isinstance(module, (MixAugmentationBase, )):
            input, label = module(input, label=label, params=param.data)
        elif isinstance(module, (_AugmentationBase, )):
            input = module(input, params=param.data)
        elif isinstance(module, kornia.augmentation.ImageSequential):
            # Temporarily override the container's settings. Restoring them in
            # ``finally`` guarantees a failing call cannot leave the (shared)
            # module permanently reconfigured.
            saved_inverse_func = module.apply_inverse_func
            saved_return_label = module.return_label
            module.apply_inverse_func = InputApplyInverse
            module.return_label = True
            try:
                if isinstance(module, kornia.augmentation.AugmentationSequential):
                    input, label = module(input,
                                          label=label,
                                          params=param.data,
                                          data_keys=[cls.data_key])
                else:
                    input, label = module(input, label=label, params=param.data)
            finally:
                module.apply_inverse_func = saved_inverse_func
                module.return_label = saved_return_label
        else:
            if param.data is not None:
                raise AssertionError(
                    f"Non-augmentaion operation {param.name} require empty parameters. Got {param}."
                )
            # In case of return_transform = True the input arrives as a
            # tuple/list; only its first element is passed through the module
            # and the second element is carried along untouched.
            if isinstance(input, (tuple, list)):
                input = (module(input[0]), input[1])
            else:
                input = module(input)
        return input, label
Example #2
0
    def apply_trans(
        cls,
        input: torch.Tensor,
        label: Optional[torch.Tensor],
        module: nn.Module,
        param: Optional[ParamItem] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply a transformation with respect to the parameters.

        Only geometric operations are replayed; everything else is skipped:

        * ``GeometricAugmentationBase2D`` / ``RandomErasing``: called with a
          shallow copy of ``param.data`` in which any ``'values'`` entry is
          replaced by zeros, so padding/filling is always done with zeros.
          The caller's parameter dict is not modified. If ``param`` is
          ``None``, ``params=None`` is forwarded to the module.
        * ``ImageSequential`` that is not intensity-only: temporarily set to
          use ``MaskApplyInverse`` and replayed through an input-only copy of
          the container with only the geometric sub-parameters. The original
          ``apply_inverse_func`` is restored afterwards, even on error.
        * anything else: left untouched.

        Args:
            input: the input tensor.
            label: the optional label tensor (returned unchanged).
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.

        Returns:
            A tuple of the (possibly transformed) input and the unchanged label.
        """
        _param = param.data if param is not None else None

        if isinstance(module, (GeometricAugmentationBase2D, RandomErasing)):
            if _param is not None:
                # Copy so that zeroing 'values' does not alter the caller's params.
                _param = cast(Dict[str, torch.Tensor], _param).copy()
                # TODO: Parametrize value to pad with across the board for different keys
                if 'values' in _param:
                    _param['values'] = torch.zeros_like(
                        _param['values'])  # Always pad with zeros
            input = module(input, params=_param)
        elif isinstance(module, kornia.augmentation.ImageSequential
                        ) and not module.is_intensity_only():
            _param = cast(List[ParamItem], _param)
            saved_inverse_func = module.apply_inverse_func
            module.apply_inverse_func = MaskApplyInverse
            try:
                geo_param: List[ParamItem] = _get_geometric_only_param(
                    module, _param)
                input = cls.make_input_only_sequential(module)(
                    input, label=None, params=geo_param)
            finally:
                # Restore even on failure so the shared module is not left modified.
                module.apply_inverse_func = saved_inverse_func
        # Any other module (e.g. intensity-only) needs no update.
        return input, label
Example #3
0
    def inverse(cls,
                input: torch.Tensor,
                module: nn.Module,
                param: Optional[ParamItem] = None) -> torch.Tensor:
        """Inverse a transformation with respect to the parameters.

        * ``GeometricAugmentationBase2D``: ``module.inverse`` is called with
          ``param.data`` (or ``None`` when no parameters are given).
        * ``ImageSequential``: temporarily set to use ``MaskApplyInverse``
          while ``module.inverse`` runs. The original ``apply_inverse_func``
          is restored afterwards, even if the call raises.
        * anything else: the input is returned unchanged.

        Args:
            input: the input tensor.
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.

        Returns:
            The (possibly) inverse-transformed input.
        """
        if isinstance(module, GeometricAugmentationBase2D):
            input = module.inverse(
                input, None if param is None else cast(Dict, param.data))
        elif isinstance(module, kornia.augmentation.container.ImageSequential):
            saved_inverse_func = module.apply_inverse_func
            module.apply_inverse_func = MaskApplyInverse
            try:
                input = module.inverse(
                    input, None if param is None else cast(List, param.data))
            finally:
                # Restore even on failure so the shared module is not left modified.
                module.apply_inverse_func = saved_inverse_func
        return input
Example #4
0
    def apply_trans(
        cls,
        input: torch.Tensor,
        label: Optional[torch.Tensor],
        module: nn.Module,
        param: Optional[ParamItem] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply a transformation with respect to the parameters.

        Only geometric operations are replayed; everything else is skipped:

        * ``GeometricAugmentationBase2D``: called with ``param.data`` (or
          ``None``) and ``return_transform=False``.
        * ``ImageSequential`` that is not intensity-only: temporarily set to
          use ``MaskApplyInverse`` and replayed through an input-only copy of
          the container with only the geometric sub-parameters. The original
          ``apply_inverse_func`` is restored afterwards, even on error.
        * anything else: left untouched.

        Args:
            input: the input tensor.
            label: the optional label tensor (returned unchanged).
            module: any torch Module but only kornia augmentation modules will count
                to apply transformations.
            param: the corresponding parameters to the module.

        Returns:
            A tuple of the (possibly transformed) input and the unchanged label.
        """
        _param = param.data if param is not None else None

        if isinstance(module, GeometricAugmentationBase2D):
            _param = cast(Dict[str, torch.Tensor], _param)
            # return_transform=False: presumably so only the tensor (no
            # transformation matrix) comes back — confirm against the module API.
            input = module(input, params=_param, return_transform=False)
        elif isinstance(module, kornia.augmentation.container.ImageSequential
                        ) and not module.is_intensity_only():
            _param = cast(List[ParamItem], _param)
            saved_inverse_func = module.apply_inverse_func
            module.apply_inverse_func = MaskApplyInverse
            try:
                geo_param: List[ParamItem] = _get_geometric_only_param(
                    module, _param)
                input = cls.make_input_only_sequential(module)(
                    input, label=None, params=geo_param)
            finally:
                # Restore even on failure so the shared module is not left modified.
                module.apply_inverse_func = saved_inverse_func
        # Any other module (e.g. intensity-only) needs no update.
        return input, label