def apply_trans(  # type: ignore
    cls,
    input: torch.Tensor,
    label: Optional[torch.Tensor],
    module: nn.Module,
    param: ParamItem,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply a transformation with respect to the parameters.

    Dispatch on the type of ``module``:

    * ``MixAugmentationBase``: called with ``label`` and ``params``; both the
      input and the label are replaced by the module's outputs.
    * ``_AugmentationBase``: called with ``params`` only; the label is passed
      through untouched.
    * ``ImageSequential``: ``apply_inverse_func`` and ``return_label`` are
      temporarily overridden for the duration of the call and then restored.
      ``AugmentationSequential`` additionally receives
      ``data_keys=[cls.data_key]``.
    * anything else: called as a plain module on the input (or on the first
      element when the input is a tuple/list).

    Args:
        input: the input tensor.
        label: the optional label tensor.
        module: any torch Module but only kornia augmentation modules will count to apply transformations.
        param: the corresponding parameters to the module.

    Returns:
        The ``(input, label)`` pair after ``module`` has been applied.

    Raises:
        AssertionError: if ``module`` is none of the augmentation types above
            and ``param.data`` is not ``None``.
    """
    if isinstance(module, (MixAugmentationBase,)):
        input, label = module(input, label=label, params=param.data)
    elif isinstance(module, (_AugmentationBase,)):
        input = module(input, params=param.data)
    elif isinstance(module, kornia.augmentation.ImageSequential):
        temp = module.apply_inverse_func
        temp2 = module.return_label
        module.apply_inverse_func = InputApplyInverse
        module.return_label = True
        try:
            if isinstance(module, kornia.augmentation.AugmentationSequential):
                input, label = module(
                    input, label=label, params=param.data, data_keys=[cls.data_key]
                )
            else:
                input, label = module(input, label=label, params=param.data)
        finally:
            # Restore the caller's settings even when the forward pass raises,
            # so a failed call does not leave the module in the overridden state.
            module.apply_inverse_func = temp
            module.return_label = temp2
    else:
        if param.data is not None:
            raise AssertionError(
                f"Non-augmentaion operation {param.name} require empty parameters. Got {param}."
            )
        # In case of return_transform = True
        if isinstance(input, (tuple, list)):
            input = (module(input[0]), input[1])
        else:
            input = module(input)
    return input, label
def apply_trans(
    cls,
    input: torch.Tensor,
    label: Optional[torch.Tensor],
    module: nn.Module,
    param: Optional[ParamItem] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply a transformation with respect to the parameters.

    Only ``GeometricAugmentationBase2D`` / ``RandomErasing`` modules and
    ``ImageSequential`` containers whose ``is_intensity_only()`` is false
    modify the input; every other module leaves it unchanged. The label is
    always returned as given.

    Args:
        input: the input tensor.
        label: the optional label tensor.
        module: any torch Module but only kornia augmentation modules will count to apply transformations.
        param: the corresponding parameters to the module. May be ``None``,
            in which case the module is called with ``params=None``.

    Returns:
        The ``(input, label)`` pair; ``label`` is passed through untouched.
    """
    _param = None if param is None else param.data  # type: ignore

    if isinstance(module, (GeometricAugmentationBase2D, RandomErasing)):
        if _param is not None:
            # Shallow-copy so that overriding 'values' below does not write
            # into the parameter dict owned by the caller.
            _param = cast(Dict[str, torch.Tensor], _param).copy()
            # TODO: Parametrize value to pad with across the board for different keys
            if 'values' in _param:
                _param['values'] = torch.zeros_like(_param['values'])  # Always pad with zeros
        input = module(input, params=_param)
    elif isinstance(module, kornia.augmentation.ImageSequential) and not module.is_intensity_only():
        _param = cast(List[ParamItem], _param)
        temp = module.apply_inverse_func
        module.apply_inverse_func = MaskApplyInverse
        try:
            geo_param: List[ParamItem] = _get_geometric_only_param(module, _param)
            input = cls.make_input_only_sequential(module)(input, label=None, params=geo_param)
        finally:
            # Restore the original hook even if the call above raises.
            module.apply_inverse_func = temp
    # Any other module: no need to update anything.
    return input, label
def inverse(
    cls,
    input: torch.Tensor,
    module: nn.Module,
    param: Optional[ParamItem] = None,
) -> torch.Tensor:
    """Inverse a transformation with respect to the parameters.

    ``GeometricAugmentationBase2D`` modules and ``ImageSequential`` containers
    have their ``inverse`` method called with ``param.data`` (or ``None`` when
    ``param`` is ``None``). Any other module leaves the input unchanged.

    Args:
        input: the input tensor.
        module: any torch Module but only kornia augmentation modules will count to apply transformations.
        param: the corresponding parameters to the module.

    Returns:
        The inverted input, or the input itself for unsupported modules.
    """
    data = None if param is None else param.data

    if isinstance(module, GeometricAugmentationBase2D):
        input = module.inverse(input, None if data is None else cast(Dict, data))
    elif isinstance(module, kornia.augmentation.container.ImageSequential):
        temp = module.apply_inverse_func
        module.apply_inverse_func = MaskApplyInverse
        try:
            input = module.inverse(input, None if data is None else cast(List, data))
        finally:
            # Restore the original hook even if the inverse call raises.
            module.apply_inverse_func = temp
    return input
def apply_trans(
    cls,
    input: torch.Tensor,
    label: Optional[torch.Tensor],
    module: nn.Module,
    param: Optional[ParamItem] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply a transformation with respect to the parameters.

    Only ``GeometricAugmentationBase2D`` modules and ``ImageSequential``
    containers whose ``is_intensity_only()`` is false modify the input; every
    other module leaves it unchanged. The label is always returned as given.

    Args:
        input: the input tensor.
        label: the optional label tensor.
        module: any torch Module but only kornia augmentation modules will count to apply transformations.
        param: the corresponding parameters to the module.

    Returns:
        The ``(input, label)`` pair; ``label`` is passed through untouched.
    """
    _param = None if param is None else param.data  # type: ignore

    if isinstance(module, GeometricAugmentationBase2D):
        _param = cast(Dict[str, torch.Tensor], _param)
        input = module(input, params=_param, return_transform=False)
    elif (
        isinstance(module, kornia.augmentation.container.ImageSequential)
        and not module.is_intensity_only()
    ):
        _param = cast(List[ParamItem], _param)
        temp = module.apply_inverse_func
        module.apply_inverse_func = MaskApplyInverse
        try:
            geo_param: List[ParamItem] = _get_geometric_only_param(module, _param)
            input = cls.make_input_only_sequential(module)(input, label=None, params=geo_param)
        finally:
            # Restore the original hook even if the call above raises.
            module.apply_inverse_func = temp
    # Any other module: no need to update anything.
    return input, label