def apply_by_key(
    self,
    input: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    module: nn.Module,
    param: Optional[Dict[str, torch.Tensor]] = None,
    dcate: Union[str, int, DataKey] = DataKey.INPUT,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    if DataKey.get(dcate) in [DataKey.INPUT]:
        return self.apply_to_input(input, module, param)
    if DataKey.get(dcate) in [DataKey.MASK]:
        if isinstance(input, (tuple,)):
            return (self.apply_to_mask(input[0], module, param), *input[1:])
        return self.apply_to_mask(input, module, param)
    if DataKey.get(dcate) in [DataKey.BBOX, DataKey.BBOX_XYXY]:
        if isinstance(input, (tuple,)):
            return (self.apply_to_bbox(input[0], module, param, mode='xyxy'), *input[1:])
        return self.apply_to_bbox(input, module, param, mode='xyxy')
    if DataKey.get(dcate) in [DataKey.BBOX_XYHW]:
        if isinstance(input, (tuple,)):
            return (self.apply_to_bbox(input[0], module, param, mode='xyhw'), *input[1:])
        return self.apply_to_bbox(input, module, param, mode='xyhw')
    if DataKey.get(dcate) in [DataKey.KEYPOINTS]:
        if isinstance(input, (tuple,)):
            return (self.apply_to_keypoints(input[0], module, param), *input[1:])
        return self.apply_to_keypoints(input, module, param)
    raise NotImplementedError(f"input type of {dcate} is not implemented.")
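# Hedged note on the `isinstance(input, (tuple,))` branches above (a sketch of the
# assumed data flow, not confirmed by this snippet alone): when a container is built
# with `return_transform=True`, image data travels as a (tensor, transformation_matrix)
# pair, so only element 0 is transformed and trailing elements pass through untouched.
# Shapes below are illustrative only.
import torch

pair = (torch.rand(1, 3, 8, 8), torch.eye(3).unsqueeze(0))  # (image, B x 3 x 3 matrix)
payload, extras = pair[0], pair[1:]  # mirrors `input[0]` and `*input[1:]` above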
def _get_func_by_key(cls, dcate: Union[str, int, DataKey]) -> Type[ApplyInverseInterface]:
    if DataKey.get(dcate) == DataKey.INPUT:
        return InputApplyInverse
    if DataKey.get(dcate) in [DataKey.MASK]:
        return MaskApplyInverse
    if DataKey.get(dcate) in [DataKey.BBOX, DataKey.BBOX_XYXY]:
        return BBoxXYXYApplyInverse
    if DataKey.get(dcate) in [DataKey.BBOX_XYHW]:
        return BBoxXYWHApplyInverse
    if DataKey.get(dcate) in [DataKey.KEYPOINTS]:
        return KeypointsApplyInverse
    raise NotImplementedError(f"input type of {dcate} is not implemented.")
def _get_func_by_key(cls, dcate: Union[str, int, DataKey]) -> Type[ApplyInverseInterface]:
    if DataKey.get(dcate) == DataKey.INPUT:
        return InputApplyInverse
    if DataKey.get(dcate) == DataKey.MASK:
        return MaskApplyInverse
    if DataKey.get(dcate) in [DataKey.BBOX, DataKey.BBOX_XYXY, DataKey.BBOX_XYWH]:
        # We are converting to (B, 4, 2) internally for all formats.
        return BBoxApplyInverse
    if DataKey.get(dcate) in [DataKey.KEYPOINTS]:
        return KeypointsApplyInverse
    raise NotImplementedError(f"input type of {dcate} is not implemented.")
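# Illustrative alternative (not the library's implementation): the if-chain in
# `_get_func_by_key` is a dispatch-by-enum. A minimal, self-contained sketch of the
# same lookup as a registry dict, which keeps the NotImplementedError fallback in
# one place; all names here (`_Key`, `_REGISTRY`, `_register`) are hypothetical.
from enum import Enum
from typing import Callable, Dict


class _Key(Enum):  # hypothetical stand-in for DataKey
    INPUT = 0
    MASK = 1
    BBOX = 2


_REGISTRY: Dict[_Key, Callable[..., object]] = {}


def _register(key: _Key) -> Callable:
    def deco(fn: Callable) -> Callable:
        _REGISTRY[key] = fn
        return fn
    return deco


@_register(_Key.BBOX)
def _handle_bbox(x: object) -> object:
    return x  # placeholder transform


def get_func_by_key(key: _Key) -> Callable[..., object]:
    try:
        return _REGISTRY[key]
    except KeyError:
        raise NotImplementedError(f"input type of {key} is not implemented.")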
def __init__(
    self,
    *args: Union[_AugmentationBase, ImageSequential],
    data_keys: List[Union[str, int, DataKey]] = [DataKey.INPUT],
    same_on_batch: Optional[bool] = None,
    return_transform: Optional[bool] = None,
    keepdim: Optional[bool] = None,
    random_apply: Union[int, bool, Tuple[int, int]] = False,
) -> None:
    super(AugmentationSequential, self).__init__(
        *args,
        same_on_batch=same_on_batch,
        return_transform=return_transform,
        keepdim=keepdim,
        random_apply=random_apply,
    )
    self.data_keys = [DataKey.get(inp) for inp in data_keys]
    assert all(
        in_type in DataKey for in_type in self.data_keys
    ), f"`data_keys` must be in {DataKey}. Got {data_keys}."
    if self.data_keys[0] != DataKey.INPUT:
        raise NotImplementedError(f"The first input must be {DataKey.INPUT}.")
    for arg in args:
        if isinstance(arg, PatchSequential) and not arg.is_intensity_only():
            warnings.warn(
                "Geometric transformation detected in PatchSequential, which would break bbox, mask."
            )
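# Hedged usage sketch for the constructor above, assuming the public Kornia API
# (`kornia.augmentation.AugmentationSequential`, `RandomAffine`); tensor shapes are
# illustrative. `data_keys` entries are resolved through `DataKey.get`, so strings,
# ints, or DataKey members are all accepted, and the first key must be INPUT as
# enforced above.
import torch
import kornia.augmentation as K

aug = K.AugmentationSequential(
    K.RandomAffine(degrees=30.0, p=1.0),
    data_keys=["input", "mask", "bbox_xyxy"],
)
image = torch.rand(1, 3, 32, 32)                  # (B, C, H, W)
mask = torch.ones(1, 1, 32, 32)                   # same spatial size as the image
boxes = torch.tensor([[[4.0, 4.0, 20.0, 20.0]]])  # (B, N, 4) in xyxy
out_image, out_mask, out_boxes = aug(image, mask, boxes)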
def inverse_by_key(
    self,
    input: torch.Tensor,
    module: nn.Module,
    param: Optional[Dict[str, torch.Tensor]] = None,
    dcate: Union[str, int, DataKey] = DataKey.INPUT,
) -> torch.Tensor:
    if DataKey.get(dcate) in [DataKey.INPUT, DataKey.MASK]:
        return self.inverse_input(input, module, param)
    if DataKey.get(dcate) in [DataKey.BBOX, DataKey.BBOX_XYXY]:
        return self.inverse_bbox(input, module, param, mode='xyxy')
    if DataKey.get(dcate) in [DataKey.BBOX_XYHW]:
        return self.inverse_bbox(input, module, param, mode='xyhw')
    if DataKey.get(dcate) in [DataKey.KEYPOINTS]:
        return self.inverse_keypoints(input, module, param)
    raise NotImplementedError(f"input type of {dcate} is not implemented.")
def apply_by_key(
    self,
    input: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    module_name: str,
    module: Optional[nn.Module] = None,
    param: Optional[ParamItem] = None,
    dcate: Union[str, int, DataKey] = DataKey.INPUT,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    if param is not None:
        assert module_name == param.name
    if module is None:
        # TODO (jian): double check why typing is crashing
        module = self.get_submodule(module_name)  # type: ignore
    if DataKey.get(dcate) in [DataKey.INPUT]:
        return self.apply_to_input(input, module_name, module, param)
    if DataKey.get(dcate) in [DataKey.MASK]:
        if isinstance(input, (tuple,)):
            return (self.apply_to_mask(input[0], module, param), *input[1:])
        return self.apply_to_mask(input, module, param)
    if DataKey.get(dcate) in [DataKey.BBOX, DataKey.BBOX_XYXY]:
        if isinstance(input, (tuple,)):
            return (self.apply_to_bbox(input[0], module, param, mode='xyxy'), *input[1:])
        return self.apply_to_bbox(input, module, param, mode='xyxy')
    if DataKey.get(dcate) in [DataKey.BBOX_XYHW]:
        if isinstance(input, (tuple,)):
            return (self.apply_to_bbox(input[0], module, param, mode='xyhw'), *input[1:])
        return self.apply_to_bbox(input, module, param, mode='xyhw')
    if DataKey.get(dcate) in [DataKey.KEYPOINTS]:
        if isinstance(input, (tuple,)):
            return (self.apply_to_keypoints(input[0], module, param), *input[1:])
        return self.apply_to_keypoints(input, module, param)
    raise NotImplementedError(f"input type of {dcate} is not implemented.")
def apply_by_key(
    self,
    input: TensorWithTransformMat,
    label: Optional[torch.Tensor],
    module: Optional[nn.Module],
    param: ParamItem,
    dcate: Union[str, int, DataKey] = DataKey.INPUT,
) -> Tuple[TensorWithTransformMat, Optional[torch.Tensor]]:
    if module is None:
        module = self.get_submodule(param.name)
    if DataKey.get(dcate) in [DataKey.INPUT]:
        return self.apply_to_input(input, label, module=module, param=param)
    if DataKey.get(dcate) in [DataKey.MASK]:
        if isinstance(input, (tuple,)):
            return (self.apply_to_mask(input[0], module, param), *input[1:]), None
        return self.apply_to_mask(input, module, param), None
    if DataKey.get(dcate) in [DataKey.BBOX, DataKey.BBOX_XYXY]:
        if isinstance(input, (tuple,)):
            return (self.apply_to_bbox(input[0], module, param, mode='xyxy'), *input[1:]), None
        return self.apply_to_bbox(input, module, param, mode='xyxy'), None
    if DataKey.get(dcate) in [DataKey.BBOX_XYHW]:
        if isinstance(input, (tuple,)):
            return (self.apply_to_bbox(input[0], module, param, mode='xyhw'), *input[1:]), None
        return self.apply_to_bbox(input, module, param, mode='xyhw'), None
    if DataKey.get(dcate) in [DataKey.KEYPOINTS]:
        if isinstance(input, (tuple,)):
            return (self.apply_to_keypoints(input[0], module, param), *input[1:]), None
        return self.apply_to_keypoints(input, module, param), None
    raise NotImplementedError(f"input type of {dcate} is not implemented.")
def _arguments_preproc(self, *args: Tensor, data_keys: List[DataKey]):
    inp: List[Any] = []
    for arg, dcate in zip(args, data_keys):
        if DataKey.get(dcate) in [DataKey.INPUT, DataKey.MASK, DataKey.KEYPOINTS]:
            inp.append(arg)
        elif DataKey.get(dcate) in [DataKey.BBOX, DataKey.BBOX_XYXY, DataKey.BBOX_XYWH]:
            if DataKey.get(dcate) in [DataKey.BBOX]:
                mode = "vertices_plus"
            elif DataKey.get(dcate) in [DataKey.BBOX_XYXY]:
                mode = "xyxy"
            elif DataKey.get(dcate) in [DataKey.BBOX_XYWH]:
                mode = "xywh"
            else:
                raise ValueError(f"Unsupported mode `{DataKey.get(dcate).name}`.")
            inp.append(Boxes.from_tensor(arg, mode=mode))  # type: ignore
        else:
            raise NotImplementedError(f"input type of {dcate} is not implemented.")
    return inp
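# Hedged sketch of the normalization `_arguments_preproc` performs: every bbox format
# is lifted into one internal vertices representation of shape (B, N, 4, 2), which is
# why the newer `_get_func_by_key` needs only a single BBoxApplyInverse. Assumes
# `Boxes` is importable from `kornia.geometry.boxes`; the shape printed is what the
# comments in this file lead us to expect.
import torch
from kornia.geometry.boxes import Boxes

xyxy = torch.tensor([[[1.0, 1.0, 5.0, 5.0]]])  # (B=1, N=1, 4) in xyxy
boxes = Boxes.from_tensor(xyxy, mode="xyxy")
print(boxes.data.shape)  # expected: torch.Size([1, 1, 4, 2])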
def __init__(
    self,
    *args: Union[_AugmentationBase, ImageSequential],
    data_keys: List[Union[str, int, DataKey]] = [DataKey.INPUT],
    same_on_batch: Optional[bool] = None,
    return_transform: Optional[bool] = None,
    keepdim: Optional[bool] = None,
    random_apply: Union[int, bool, Tuple[int, int]] = False,
    random_apply_weights: Optional[List[float]] = None,
) -> None:
    super().__init__(
        *args,
        same_on_batch=same_on_batch,
        return_transform=return_transform,
        keepdim=keepdim,
        random_apply=random_apply,
        random_apply_weights=random_apply_weights,
    )
    self.data_keys = [DataKey.get(inp) for inp in data_keys]
    if not all(in_type in DataKey for in_type in self.data_keys):
        raise AssertionError(f"`data_keys` must be in {DataKey}. Got {data_keys}.")
    if self.data_keys[0] != DataKey.INPUT:
        raise NotImplementedError(f"The first input must be {DataKey.INPUT}.")
    self.contains_video_sequential: bool = False
    self.contains_3d_augmentation: bool = False
    for arg in args:
        if isinstance(arg, PatchSequential) and not arg.is_intensity_only():
            warnings.warn(
                "Geometric transformation detected in PatchSequential, which would break bbox, mask."
            )
        if isinstance(arg, VideoSequential):
            self.contains_video_sequential = True
        # NOTE: only images are supported for 3D augmentations.
        if isinstance(arg, AugmentationBase3D):
            self.contains_3d_augmentation = True
    self._transform_matrix: Optional[Tensor] = None
def __init__(
    self,
    *args: _AugmentationBase,
    data_keys: List[Union[str, int, DataKey]] = [DataKey.INPUT],
    same_on_batch: Optional[bool] = None,
    return_transform: Optional[bool] = None,
    keepdim: Optional[bool] = None,
) -> None:
    super(AugmentationSequential, self).__init__(
        *args, same_on_batch=same_on_batch, return_transform=return_transform, keepdim=keepdim
    )
    self.data_keys = [DataKey.get(inp) for inp in data_keys]
    assert all(
        in_type in DataKey for in_type in self.data_keys
    ), f"`data_keys` must be in {DataKey}. Got {data_keys}."
    if self.data_keys[0] != DataKey.INPUT:
        raise NotImplementedError(f"The first input must be {DataKey.INPUT}.")
def forward(  # type: ignore
    self,
    *args: TensorWithTransformMat,
    label: Optional[torch.Tensor] = None,
    params: Optional[List[ParamItem]] = None,
    data_keys: Optional[List[Union[str, int, DataKey]]] = None,
) -> Union[
    TensorWithTransformMat,
    Tuple[TensorWithTransformMat, Optional[torch.Tensor]],
    List[TensorWithTransformMat],
    Tuple[List[TensorWithTransformMat], Optional[torch.Tensor]],
]:
    """Compute multiple tensors simultaneously according to ``self.data_keys``."""
    if data_keys is None:
        data_keys = cast(List[Union[str, int, DataKey]], self.data_keys)
    else:
        data_keys = [DataKey.get(inp) for inp in data_keys]
    if len(args) != len(data_keys):
        raise AssertionError(
            f"The number of inputs must align with the number of data_keys. Got {len(args)} and {len(data_keys)}."
        )
    if params is None:
        if DataKey.INPUT in data_keys:
            _input = args[data_keys.index(DataKey.INPUT)]
            if isinstance(_input, (tuple, list)):
                inp = _input[0]
            else:
                inp = _input
            if self.contains_video_sequential:
                _, out_shape = self.autofill_dim(inp, dim_range=(3, 5))
            else:
                _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
            params = self.forward_parameters(out_shape)
        else:
            raise ValueError("`params` must be provided whilst INPUT is not in data_keys.")
    outputs: List[TensorWithTransformMat] = [None] * len(data_keys)  # type: ignore
    if DataKey.INPUT in data_keys:
        idx = data_keys.index(DataKey.INPUT)
        out = super().forward(args[idx], label, params=params)
        if self.return_label:
            input, label = cast(Tuple[TensorWithTransformMat, torch.Tensor], out)
        else:
            input = cast(TensorWithTransformMat, out)
        outputs[idx] = input
    self.return_label = label is not None or self.contains_label_operations(params)
    for idx, (input, dcate, out) in enumerate(zip(args, data_keys, outputs)):
        if out is not None:
            continue
        for param in params:
            module = self.get_submodule(param.name)
            if dcate == DataKey.INPUT:
                input, label = self.apply_to_input(input, label, module=module, param=param)
            elif isinstance(module, IntensityAugmentationBase2D) and dcate in DataKey:
                pass  # Do nothing
            elif isinstance(module, ImageSequential) and module.is_intensity_only() and dcate in DataKey:
                pass  # Do nothing
            elif isinstance(module, VideoSequential) and dcate not in [DataKey.INPUT, DataKey.MASK]:
                batch_size: int = input.size(0)
                input = input.view(-1, *input.shape[2:])
                input, label = ApplyInverse.apply_by_key(input, label, module, param, dcate)
                input = input.view(batch_size, -1, *input.shape[1:])
            elif isinstance(module, PatchSequential):
                raise NotImplementedError("Geometric involved PatchSequential is not supported.")
            elif isinstance(module, (GeometricAugmentationBase2D, ImageSequential)) and dcate in DataKey:
                input, label = ApplyInverse.apply_by_key(input, label, module, param, dcate)
            elif isinstance(module, (SequentialBase,)):
                raise ValueError(f"Unsupported Sequential {module}.")
            else:
                raise NotImplementedError(f"data_key {dcate} is not implemented for {module}.")
        outputs[idx] = input
    return self.__packup_output__(outputs, label)
def forward(  # type: ignore
    self,
    *args: Tensor,
    label: Optional[Tensor] = None,
    params: Optional[List[ParamItem]] = None,
    data_keys: Optional[List[Union[str, int, DataKey]]] = None,
) -> Union[Tensor, Tuple[Tensor, Optional[Tensor]], List[Tensor], Tuple[List[Tensor], Optional[Tensor]]]:
    """Compute multiple tensors simultaneously according to ``self.data_keys``."""
    _data_keys: List[DataKey]
    if data_keys is None:
        _data_keys = self.data_keys
    else:
        _data_keys = [DataKey.get(inp) for inp in data_keys]
        self.data_keys = _data_keys
    self._validate_args_datakeys(*args, data_keys=_data_keys)
    args = self._arguments_preproc(*args, data_keys=_data_keys)
    if params is None:
        # Image data must exist if params is not provided.
        if DataKey.INPUT in _data_keys:
            _input = args[_data_keys.index(DataKey.INPUT)]
            inp = _input
            if isinstance(inp, (tuple, list)):
                raise ValueError(f"`INPUT` should be a tensor but `{type(inp)}` received.")
            # A video input shall be BCDHW while an image input shall be BCHW.
            if self.contains_video_sequential or self.contains_3d_augmentation:
                _, out_shape = self.autofill_dim(inp, dim_range=(3, 5))
            else:
                _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
            params = self.forward_parameters(out_shape)
        else:
            raise ValueError("`params` must be provided whilst INPUT is not in data_keys.")
    outputs: List[Tensor] = [None] * len(_data_keys)  # type: ignore
    # Forward the first image data to freeze the parameters.
    if DataKey.INPUT in _data_keys:
        idx = _data_keys.index(DataKey.INPUT)
        _inp = args[idx]
        _out = super().forward(_inp, label, params=params)  # type: ignore
        self._transform_matrix = self.get_transformation_matrix(_inp, params=params)
        if self.return_label:
            _input, label = cast(Tuple[Tensor, Tensor], _out)
        else:
            _input = cast(Tensor, _out)
        outputs[idx] = _input
    self.return_label = self.return_label or label is not None or self.contains_label_operations(params)
    for idx, (arg, dcate, out) in enumerate(zip(args, _data_keys, outputs)):
        if out is not None:
            continue
        # Using tensors straight-away
        if isinstance(arg, (Boxes,)):
            input = arg.data  # all boxes are in (B, N, 4, 2) format now.
        else:
            input = arg
        for param in params:
            module = self.get_submodule(param.name)
            if dcate == DataKey.INPUT:
                input, label = self.apply_to_input(input, label, module=module, param=param)
            elif isinstance(module, IntensityAugmentationBase2D) and dcate in DataKey \
                    and not isinstance(module, RandomErasing):
                pass  # Do nothing
            elif isinstance(module, ImageSequential) and module.is_intensity_only() and dcate in DataKey:
                pass  # Do nothing
            elif isinstance(module, VideoSequential) and dcate not in [DataKey.INPUT, DataKey.MASK]:
                batch_size: int = input.size(0)
                input = input.view(-1, *input.shape[2:])
                input, label = ApplyInverse.apply_by_key(input, label, module, param, dcate)
                input = input.view(batch_size, -1, *input.shape[1:])
            elif isinstance(module, PatchSequential):
                raise NotImplementedError("Geometric involved PatchSequential is not supported.")
            elif isinstance(module, (GeometricAugmentationBase2D, ImageSequential, RandomErasing)) \
                    and dcate in DataKey:
                input, label = ApplyInverse.apply_by_key(input, label, module, param, dcate)
            elif isinstance(module, (SequentialBase,)):
                raise ValueError(f"Unsupported Sequential {module}.")
            else:
                raise NotImplementedError(f"data_key {dcate} is not implemented for {module}.")
        if isinstance(arg, (Boxes,)):
            arg._data = input
            outputs[idx] = arg.to_tensor()
        else:
            outputs[idx] = input
    return self.__packup_output__(outputs, label)
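# Hedged replay sketch (reusing the `aug`, `image`, `mask`, `boxes` names from the
# earlier constructor example): `forward` freezes parameters from the INPUT branch and
# records them on the container, so passing `params=aug._params` should reproduce the
# same random transform on a second call. `_params` is assumed here from the `inverse`
# method below, which reads `self._params`.
_ = aug(image, mask, boxes)                             # samples fresh parameters
replayed = aug(image, mask, boxes, params=aug._params)  # replays the same ones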
def inverse(  # type: ignore
    self,
    *args: Tensor,
    params: Optional[List[ParamItem]] = None,
    data_keys: Optional[List[Union[str, int, DataKey]]] = None,
) -> Union[Tensor, List[Tensor]]:
    """Reverse the transformation applied.

    The number of input tensors must align with the number of ``data_keys``.
    If ``data_keys`` is not set, ``self.data_keys`` is used by default.
    """
    if data_keys is None:
        data_keys = cast(List[Union[str, int, DataKey]], self.data_keys)
    _data_keys: List[DataKey] = [DataKey.get(inp) for inp in data_keys]
    if len(args) != len(data_keys):
        raise AssertionError(
            "The number of inputs must align with the number of data_keys, "
            f"Got {len(args)} and {len(data_keys)}."
        )
    args = self._arguments_preproc(*args, data_keys=_data_keys)
    if params is None:
        if self._params is None:
            raise ValueError(
                "No parameters available for inverting, please run a forward pass first "
                "or pass valid params into this function."
            )
        params = self._params
    outputs: List[Tensor] = [None] * len(data_keys)  # type: ignore
    for idx, (arg, dcate) in enumerate(zip(args, data_keys)):
        if dcate == DataKey.INPUT and isinstance(arg, (tuple, list)):
            input, _ = arg  # ignore the transformation matrix whilst inverting
        # Using tensors straight-away
        elif isinstance(arg, (Boxes,)):
            input = arg.data  # all boxes are in (B, N, 4, 2) format now.
        else:
            input = arg
        for (name, module), param in zip_longest(
            list(self.get_forward_sequence(params))[::-1], params[::-1]
        ):
            if isinstance(module, (_AugmentationBase, ImageSequential)):
                param = params[name] if name in params else param
            else:
                param = None
            if isinstance(module, IntensityAugmentationBase2D) and dcate in DataKey \
                    and not isinstance(module, RandomErasing):
                pass  # Do nothing
            elif isinstance(module, ImageSequential) and module.is_intensity_only() and dcate in DataKey:
                pass  # Do nothing
            elif isinstance(module, VideoSequential) and dcate not in [DataKey.INPUT, DataKey.MASK]:
                batch_size: int = input.size(0)
                input = input.view(-1, *input.shape[2:])
                input = ApplyInverse.inverse_by_key(input, module, param, dcate)
                input = input.view(batch_size, -1, *input.shape[1:])
            elif isinstance(module, PatchSequential):
                raise NotImplementedError("Geometric involved PatchSequential is not supported.")
            elif isinstance(module, (GeometricAugmentationBase2D, ImageSequential, RandomErasing)) \
                    and dcate in DataKey:
                input = ApplyInverse.inverse_by_key(input, module, param, dcate)
            elif isinstance(module, (SequentialBase,)):
                raise ValueError(f"Unsupported Sequential {module}.")
            else:
                raise NotImplementedError(f"data_key {dcate} is not implemented for {module}.")
        if isinstance(arg, (Boxes,)):
            arg._data = input
            outputs[idx] = arg.to_tensor()
        else:
            outputs[idx] = input
    if len(outputs) == 1 and isinstance(outputs, (tuple, list)):
        return outputs[0]
    return outputs
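# Hedged round-trip sketch for `inverse` (same assumed names as the earlier examples):
# undo the geometric part of the most recent forward call, using the `self._params`
# recorded there. Only geometric modules are inverted; intensity-only ones are skipped,
# as the branches above show.
out_image, out_mask, out_boxes = aug(image, mask, boxes)
inv_image, inv_mask, inv_boxes = aug.inverse(out_image, out_mask, out_boxes)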