def _get_transform(**params) -> Callable:
    """Recursively build a transform from configuration parameters.

    Two layouts of ``params`` are handled:

    * Key-value layout: ``params`` contains a truthy ``"_key_value"`` entry.
      The flag is removed, and every remaining ``key -> nested_params`` pair
      is built recursively. Each result is wrapped in an ``Augmentor`` whose
      ``dict_key``, ``input_key`` and ``output_key`` are all set to ``key``,
      and the wrappers are combined with ``AugmentorCompose``.
    * Plain layout: if ``params`` has a ``"transforms"`` entry, each of its
      elements is built recursively and the entry is replaced by the list of
      built transforms. The (possibly updated) ``params`` are then passed to
      ``TRANSFORMS.get_from_params``.

    Args:
        **params: transform configuration; ``"_key_value"`` (default
            ``False``) selects the layout and is always removed before the
            remaining parameters are used.

    Returns:
        Callable: the constructed transform.
    """
    is_key_value = params.pop("_key_value", False)

    if is_key_value:
        # Phase 1: build every nested transform before wrapping any of them.
        built_by_key = {}
        for name, nested_params in params.items():
            built_by_key[name] = ConfigExperiment._get_transform(  # noqa: WPS437
                **nested_params
            )

        # Phase 2: wrap each transform so it reads and writes the same key.
        augmentors = {
            name: Augmentor(
                dict_key=name,
                augment_fn=built,
                input_key=name,
                output_key=name,
            )
            for name, built in built_by_key.items()
        }
        return AugmentorCompose(augmentors)

    if "transforms" in params:
        # Replace nested configs with the already-instantiated transforms.
        params["transforms"] = [
            ConfigExperiment._get_transform(**nested_params)  # noqa: WPS437
            for nested_params in params["transforms"]
        ]

    return TRANSFORMS.get_from_params(**params)
def __init__(
    self,
    transform: Sequence[Union[dict, nn.Module]],
    input_key: Union[str, int] = "image",
    additional_input_key: Optional[str] = None,
    output_key: Optional[Union[str, int]] = None,
    additional_output_key: Optional[str] = None,
) -> None:
    """Constructor method for the :class:`BatchTransformCallback` callback.

    Args:
        transform (Sequence[Union[dict, nn.Module]]): define
            augmentations to apply on a batch

            If a sequence of transforms passed, then each element
            should be either ``kornia.augmentation.AugmentationBase2D``,
            ``kornia.augmentation.AugmentationBase3D``, or ``nn.Module``
            compatible with kornia interface.

            If a sequence of params (``dict``) passed, then each
            element of the sequence must contain ``'transform'`` key with
            an augmentation name as a value. Please note that in this case
            to use custom augmentation you should add it to the
            `TRANSFORMS` registry first.
        input_key (Union[str, int]): key in batch dict mapping to
            transform, e.g. `'image'`
        additional_input_key (Optional[str]): key of an additional target
            in batch dict mapping to transform, e.g. `'mask'`
        output_key (Optional[Union[str, int]]): key to use to store the
            result of the transform, defaults to `input_key` if not
            provided (i.e. if ``None``)
        additional_output_key (Optional[str]): key to use to store the
            result of additional target transformation, defaults to
            `additional_input_key` if not provided (i.e. if ``None``)

    Raises:
        AssertionError: if any resulting transform is not an ``nn.Module``.
    """
    super().__init__(order=CallbackOrder.Internal, node=CallbackNode.all)

    self.input_key = input_key
    self.additional_input = additional_input_key
    # Pick the input handler once, depending on whether an additional
    # target is configured.
    self._process_input = (
        self._process_input_tuple
        if self.additional_input is not None
        else self._process_input_tensor
    )

    # Compare against ``None`` instead of relying on truthiness: a falsy
    # but valid key (e.g. integer ``0``) must not be silently replaced by
    # the fallback key.
    self.output_key = output_key if output_key is not None else input_key
    self.additional_output = (
        additional_output_key
        if additional_output_key is not None
        else self.additional_input
    )
    self._process_output = (
        self._process_output_tuple
        if self.additional_output is not None
        else self._process_output_tensor
    )

    # Modules are used as-is; dict entries are instantiated through the
    # ``TRANSFORMS`` registry.
    transforms: Sequence[nn.Module] = [
        item
        if isinstance(item, nn.Module)
        else TRANSFORMS.get_from_params(**item)
        for item in transform
    ]
    assert all(
        isinstance(t, nn.Module) for t in transforms
    ), "`nn.Module` should be a base class for transforms"

    self.transform = nn.Sequential(*transforms)