示例#1
0
    def _get_transform(**params) -> Callable:
        """Build a transform (or a per-key composition of transforms) from params.

        When the private ``_key_value`` flag is truthy, every remaining
        keyword argument maps a key to a nested params dict. Each nested
        dict is built recursively. The result is wrapped in an
        ``Augmentor`` that reads and writes that same key, and all
        wrappers are combined into one ``AugmentorCompose``.

        Otherwise the params describe a single registry entry. If they
        carry a ``transforms`` list, each element of that list is first
        built recursively. The params are then handed to
        ``TRANSFORMS.get_from_params``.

        Args:
            **params: transform description; see above for the two layouts.

        Returns:
            Callable: the constructed transform.
        """
        is_key_value = params.pop("_key_value", False)

        if is_key_value:
            # Phase 1: build every nested transform.
            built_by_key = {}
            for sub_key, sub_params in params.items():
                built_by_key[sub_key] = ConfigExperiment._get_transform(  # noqa: WPS437
                    **sub_params)

            # Phase 2: bind each built transform to its own key.
            augmentors = {}
            for sub_key, sub_transform in built_by_key.items():
                augmentors[sub_key] = Augmentor(
                    dict_key=sub_key,
                    augment_fn=sub_transform,
                    input_key=sub_key,
                    output_key=sub_key,
                )
            return AugmentorCompose(augmentors)

        if "transforms" in params:
            # Replace nested param dicts with the transforms they describe
            # before the outer transform is looked up in the registry.
            params["transforms"] = [
                ConfigExperiment._get_transform(**child)  # noqa: WPS437
                for child in params["transforms"]
            ]

        return TRANSFORMS.get_from_params(**params)
示例#2
0
    def __init__(
        self,
        transform: Sequence[Union[dict, nn.Module]],
        input_key: Union[str, int] = "image",
        additional_input_key: Optional[str] = None,
        output_key: Optional[Union[str, int]] = None,
        additional_output_key: Optional[str] = None,
    ) -> None:
        """Constructor method for the :class:`BatchTransformCallback` callback.

        Args:
            transform (Sequence[Union[dict, nn.Module]]): define
                augmentations to apply on a batch

                If a sequence of transforms passed, then each element
                should be either ``kornia.augmentation.AugmentationBase2D``,
                ``kornia.augmentation.AugmentationBase3D``, or ``nn.Module``
                compatible with kornia interface.

                If a sequence of params (``dict``) passed, then each
                element of the sequence must contain ``'transform'`` key with
                an augmentation name as a value. Please note that in this case
                to use custom augmentation you should add it to the
                `TRANSFORMS` registry first.
            input_key (Union[str, int]): key in batch dict
                mapping to transform, e.g. `'image'`
            additional_input_key (Optional[str]): key of an
                additional target in batch dict mapping to transform,
                e.g. `'mask'`
            output_key (Optional[Union[str, int]]): key to use to store the
                result of the transform. Defaults to `input_key` only when
                it is ``None``, so falsy keys such as ``0`` are respected.
            additional_output_key (Optional[str]): key to use to store
                the result of additional target transformation. Defaults to
                `additional_input_key` only when it is ``None``.

        Raises:
            AssertionError: if any resolved transform is not an
                ``nn.Module`` (only while assertions are enabled).
        """
        super().__init__(order=CallbackOrder.Internal, node=CallbackNode.all)

        self.input_key = input_key
        self.additional_input = additional_input_key
        # Choose the input handler once: tuple handling when an additional
        # target is configured, plain tensor handling otherwise.
        self._process_input = (
            self._process_input_tuple
            if self.additional_input is not None
            else self._process_input_tensor
        )

        # Compare against None explicitly. Keys may be ints, and ``or``
        # would silently replace a legitimate key of ``0`` with the default.
        self.output_key = output_key if output_key is not None else input_key
        self.additional_output = (
            additional_output_key
            if additional_output_key is not None
            else self.additional_input
        )
        self._process_output = (
            self._process_output_tuple
            if self.additional_output is not None
            else self._process_output_tensor
        )

        # Modules are used as-is; param dicts are resolved via the registry.
        transforms: Sequence[nn.Module] = [
            item if isinstance(item, nn.Module)
            else TRANSFORMS.get_from_params(**item)
            for item in transform
        ]
        assert all(
            isinstance(t, nn.Module) for t in transforms
        ), "`nn.Module` should be a base class for transforms"

        self.transform = nn.Sequential(*transforms)