Exemplo n.º 1
0
 def apply_by_key(
     self,
     input: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
     module: nn.Module,
     param: Optional[Dict[str, torch.Tensor]] = None,
     dcate: Union[str, int, DataKey] = DataKey.INPUT,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """Route ``input`` to the ``apply_to_*`` handler matching ``dcate``.

     Args:
         input: a tensor, or a tuple whose first element is the value to
             transform.
         module: module handed to the selected handler.
         param: optional parameters handed to the selected handler.
         dcate: data key, resolved through ``DataKey.get``.

     Returns:
         The handler output. For a tuple ``input`` with a non-INPUT key, a
         tuple made of the transformed first element followed by the
         untouched ``input[1:]``.

     Raises:
         NotImplementedError: if the resolved key has no handler here.
     """
     key = DataKey.get(dcate)

     # INPUT receives the argument as-is, tuple or not.
     if key == DataKey.INPUT:
         return self.apply_to_input(input, module, param)

     # Pick the handler first so an unsupported key fails before touching input.
     if key == DataKey.MASK:
         transform, extra = self.apply_to_mask, {}
     elif key in (DataKey.BBOX, DataKey.BBOX_XYXY):
         transform, extra = self.apply_to_bbox, {"mode": 'xyxy'}
     elif key == DataKey.BBOX_XYHW:
         transform, extra = self.apply_to_bbox, {"mode": 'xyhw'}
     elif key == DataKey.KEYPOINTS:
         transform, extra = self.apply_to_keypoints, {}
     else:
         raise NotImplementedError(f"input type of {dcate} is not implemented.")

     if isinstance(input, tuple):
         # Only the leading element is transformed; the rest passes through.
         return (transform(input[0], module, param, **extra), *input[1:])
     return transform(input, module, param, **extra)
Exemplo n.º 2
0
 def _get_func_by_key(
         cls, dcate: Union[str, int,
                           DataKey]) -> Type[ApplyInverseInterface]:
     """Return the apply/inverse handler class associated with ``dcate``.

     Args:
         dcate: data key, resolved through ``DataKey.get``.

     Returns:
         The handler class selected for the resolved key.

     Raises:
         NotImplementedError: if the resolved key has no handler class.
     """
     key = DataKey.get(dcate)
     if key == DataKey.INPUT:
         handler = InputApplyInverse
     elif key == DataKey.MASK:
         handler = MaskApplyInverse
     elif key in (DataKey.BBOX, DataKey.BBOX_XYXY):
         # A plain BBOX key shares the XYXY handler.
         handler = BBoxXYXYApplyInverse
     elif key == DataKey.BBOX_XYHW:
         handler = BBoxXYWHApplyInverse
     elif key == DataKey.KEYPOINTS:
         handler = KeypointsApplyInverse
     else:
         raise NotImplementedError(f"input type of {dcate} is not implemented.")
     return handler
Exemplo n.º 3
0
 def _get_func_by_key(
         cls, dcate: Union[str, int,
                           DataKey]) -> Type[ApplyInverseInterface]:
     """Return the apply/inverse handler class associated with ``dcate``.

     Args:
         dcate: data key, resolved through ``DataKey.get``.

     Returns:
         The handler class selected for the resolved key.

     Raises:
         NotImplementedError: if the resolved key has no handler class.
     """
     key = DataKey.get(dcate)
     if key == DataKey.INPUT:
         handler = InputApplyInverse
     elif key == DataKey.MASK:
         handler = MaskApplyInverse
     elif key in (DataKey.BBOX, DataKey.BBOX_XYXY, DataKey.BBOX_XYWH):
         # We are converting to (B, 4, 2) internally for all formats.
         handler = BBoxApplyInverse
     elif key == DataKey.KEYPOINTS:
         handler = KeypointsApplyInverse
     else:
         raise NotImplementedError(f"input type of {dcate} is not implemented.")
     return handler
Exemplo n.º 4
0
    def __init__(
        self,
        *args: Union[_AugmentationBase, ImageSequential],
        data_keys: List[Union[str, int, DataKey]] = [DataKey.INPUT],
        same_on_batch: Optional[bool] = None,
        return_transform: Optional[bool] = None,
        keepdim: Optional[bool] = None,
        random_apply: Union[int, bool, Tuple[int, int]] = False,
    ) -> None:
        """Build the container and record the data keys it operates on.

        Args:
            *args: augmentation modules, forwarded to the parent constructor.
            data_keys: one key per expected input; each entry is resolved
                through ``DataKey.get``. The first key must be
                ``DataKey.INPUT``.
            same_on_batch: forwarded to the parent constructor.
            return_transform: forwarded to the parent constructor.
            keepdim: forwarded to the parent constructor.
            random_apply: forwarded to the parent constructor.

        Raises:
            AssertionError: if a resolved key is not a ``DataKey`` member.
            NotImplementedError: if the first key is not ``DataKey.INPUT``.
        """
        super(AugmentationSequential, self).__init__(
            *args,
            same_on_batch=same_on_batch,
            return_transform=return_transform,
            keepdim=keepdim,
            random_apply=random_apply,
        )

        # The default list object is never mutated: a fresh list is built here.
        self.data_keys = [DataKey.get(inp) for inp in data_keys]

        # Raise explicitly instead of using ``assert`` so that the validation
        # is not stripped when Python runs with ``-O``.
        if not all(in_type in DataKey for in_type in self.data_keys):
            raise AssertionError(
                f"`data_keys` must be in {DataKey}. Got {data_keys}.")

        if self.data_keys[0] != DataKey.INPUT:
            raise NotImplementedError(
                f"The first input must be {DataKey.INPUT}.")

        # Only warn: a PatchSequential that is not intensity-only is accepted
        # at construction time.
        for arg in args:
            if isinstance(arg,
                          PatchSequential) and not arg.is_intensity_only():
                warnings.warn(
                    "Geometric transformation detected in PatchSeqeuntial, which would break bbox, mask."
                )
Exemplo n.º 5
0
 def inverse_by_key(
     self,
     input: torch.Tensor,
     module: nn.Module,
     param: Optional[Dict[str, torch.Tensor]] = None,
     dcate: Union[str, int, DataKey] = DataKey.INPUT,
 ) -> torch.Tensor:
     """Route ``input`` to the ``inverse_*`` handler matching ``dcate``.

     Args:
         input: tensor to be processed by the selected inverse handler.
         module: module handed to the selected handler.
         param: optional parameters handed to the selected handler.
         dcate: data key, resolved through ``DataKey.get``.

     Returns:
         The output of the selected inverse handler.

     Raises:
         NotImplementedError: if the resolved key has no inverse handler.
     """
     key = DataKey.get(dcate)
     if key in (DataKey.INPUT, DataKey.MASK):
         # Masks go through the same inverse routine as the input.
         restored = self.inverse_input(input, module, param)
     elif key in (DataKey.BBOX, DataKey.BBOX_XYXY):
         restored = self.inverse_bbox(input, module, param, mode='xyxy')
     elif key == DataKey.BBOX_XYHW:
         restored = self.inverse_bbox(input, module, param, mode='xyhw')
     elif key == DataKey.KEYPOINTS:
         restored = self.inverse_keypoints(input, module, param)
     else:
         raise NotImplementedError(f"input type of {dcate} is not implemented.")
     return restored
Exemplo n.º 6
0
 def apply_by_key(
     self,
     input: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
     module_name: str,
     module: Optional[nn.Module] = None,
     param: Optional[ParamItem] = None,
     dcate: Union[str, int, DataKey] = DataKey.INPUT,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """Route ``input`` to the ``apply_to_*`` handler matching ``dcate``.

     Args:
         input: a tensor, or a tuple whose first element is the value to
             transform.
         module_name: name of the module; used to look the module up when
             ``module`` is not given, and checked against ``param.name``.
         module: module to apply; fetched with ``self.get_submodule`` when
             ``None``.
         param: optional parameter item handed to the selected handler.
         dcate: data key, resolved through ``DataKey.get``.

     Returns:
         The handler output. For a tuple ``input`` with a non-INPUT key, a
         tuple made of the transformed first element followed by the
         untouched ``input[1:]``.

     Raises:
         NotImplementedError: if the resolved key has no handler here.
     """
     # Sanity check: the parameters must belong to the named module.
     assert param is None or module_name == param.name
     if module is None:
         # TODO (jian): double check why typing is crashing
         module = self.get_submodule(module_name)  # type: ignore

     key = DataKey.get(dcate)

     # INPUT receives the argument as-is, along with the module name.
     if key == DataKey.INPUT:
         return self.apply_to_input(input, module_name, module, param)

     # Pick the handler first so an unsupported key fails before touching input.
     if key == DataKey.MASK:
         transform, extra = self.apply_to_mask, {}
     elif key in (DataKey.BBOX, DataKey.BBOX_XYXY):
         transform, extra = self.apply_to_bbox, {"mode": 'xyxy'}
     elif key == DataKey.BBOX_XYHW:
         transform, extra = self.apply_to_bbox, {"mode": 'xyhw'}
     elif key == DataKey.KEYPOINTS:
         transform, extra = self.apply_to_keypoints, {}
     else:
         raise NotImplementedError(f"input type of {dcate} is not implemented.")

     if isinstance(input, tuple):
         # Only the leading element is transformed; the rest passes through.
         return (transform(input[0], module, param, **extra), *input[1:])
     return transform(input, module, param, **extra)
Exemplo n.º 7
0
 def apply_by_key(
     self,
     input: TensorWithTransformMat,
     label: Optional[torch.Tensor],
     module: Optional[nn.Module],
     param: ParamItem,
     dcate: Union[str, int, DataKey] = DataKey.INPUT,
 ) -> Tuple[TensorWithTransformMat, Optional[torch.Tensor]]:
     """Route ``input`` to the ``apply_to_*`` handler matching ``dcate``.

     Args:
         input: a tensor, or a tuple whose first element is the value to
             transform.
         label: label passed to ``apply_to_input``; not used for other keys.
         module: module to apply; fetched with ``self.get_submodule`` using
             ``param.name`` when ``None``.
         param: parameter item handed to the selected handler.
         dcate: data key, resolved through ``DataKey.get``.

     Returns:
         For the INPUT key, whatever ``apply_to_input`` returns. For every
         other key, a ``(output, None)`` pair where ``output`` is the handler
         result (for tuple inputs: the transformed first element followed by
         the untouched ``input[1:]``).

     Raises:
         NotImplementedError: if the resolved key has no handler here.
     """
     if module is None:
         module = self.get_submodule(param.name)

     key = DataKey.get(dcate)

     # Only the INPUT path deals with the label.
     if key == DataKey.INPUT:
         return self.apply_to_input(input,
                                    label,
                                    module=module,
                                    param=param)

     # Pick the handler first so an unsupported key fails before touching input.
     if key == DataKey.MASK:
         transform, extra = self.apply_to_mask, {}
     elif key in (DataKey.BBOX, DataKey.BBOX_XYXY):
         transform, extra = self.apply_to_bbox, {"mode": 'xyxy'}
     elif key == DataKey.BBOX_XYHW:
         transform, extra = self.apply_to_bbox, {"mode": 'xyhw'}
     elif key == DataKey.KEYPOINTS:
         transform, extra = self.apply_to_keypoints, {}
     else:
         raise NotImplementedError(f"input type of {dcate} is not implemented.")

     if isinstance(input, tuple):
         # Only the leading element is transformed; the rest passes through.
         return (transform(input[0], module, param, **extra), *input[1:]), None
     return transform(input, module, param, **extra), None
Exemplo n.º 8
0
 def _arguments_preproc(self, *args: Tensor, data_keys: List[DataKey]):
     """Prepare positional inputs according to their data keys.

     INPUT, MASK and KEYPOINTS arguments are kept unchanged. Bounding-box
     arguments are wrapped with ``Boxes.from_tensor`` using the mode that
     corresponds to their key.

     Args:
         *args: inputs, paired positionally with ``data_keys`` via ``zip``.
         data_keys: key of each input, resolved through ``DataKey.get``.

     Returns:
         A list with one prepared entry per ``(arg, key)`` pair.

     Raises:
         NotImplementedError: if a key is neither pass-through nor a box key.
         ValueError: if a box key has no associated mode.
     """
     prepared: List[Any] = []
     for arg, dcate in zip(args, data_keys):
         key = DataKey.get(dcate)

         if key in (DataKey.INPUT, DataKey.MASK, DataKey.KEYPOINTS):
             prepared.append(arg)
             continue

         if key not in (DataKey.BBOX, DataKey.BBOX_XYXY, DataKey.BBOX_XYWH):
             raise NotImplementedError(
                 f"input type of {dcate} is not implemented.")

         # Map the box key onto the layout name given to ``Boxes.from_tensor``.
         if key == DataKey.BBOX:
             mode = "vertices_plus"
         elif key == DataKey.BBOX_XYXY:
             mode = "xyxy"
         elif key == DataKey.BBOX_XYWH:
             mode = "xywh"
         else:
             raise ValueError(f"Unsupported mode `{key.name}`.")
         prepared.append(Boxes.from_tensor(arg, mode=mode))  # type: ignore
     return prepared
Exemplo n.º 9
0
    def __init__(
        self,
        *args: Union[_AugmentationBase, ImageSequential],
        data_keys: List[Union[str, int, DataKey]] = [DataKey.INPUT],
        same_on_batch: Optional[bool] = None,
        return_transform: Optional[bool] = None,
        keepdim: Optional[bool] = None,
        random_apply: Union[int, bool, Tuple[int, int]] = False,
        random_apply_weights: Optional[List[float]] = None,
    ) -> None:
        """Build the container and record the data keys it operates on.

        Args:
            *args: augmentation modules, forwarded to the parent constructor.
            data_keys: one key per expected input; each entry is resolved
                through ``DataKey.get``. The first key must be
                ``DataKey.INPUT``.
            same_on_batch: forwarded to the parent constructor.
            return_transform: forwarded to the parent constructor.
            keepdim: forwarded to the parent constructor.
            random_apply: forwarded to the parent constructor.
            random_apply_weights: forwarded to the parent constructor.

        Raises:
            AssertionError: if a resolved key is not a ``DataKey`` member.
            NotImplementedError: if the first key is not ``DataKey.INPUT``.
        """
        super().__init__(
            *args,
            same_on_batch=same_on_batch,
            return_transform=return_transform,
            keepdim=keepdim,
            random_apply=random_apply,
            random_apply_weights=random_apply_weights,
        )

        self.data_keys = [DataKey.get(inp) for inp in data_keys]

        if any(in_type not in DataKey for in_type in self.data_keys):
            raise AssertionError(
                f"`data_keys` must be in {DataKey}. Got {data_keys}.")

        first_key = self.data_keys[0]
        if first_key != DataKey.INPUT:
            raise NotImplementedError(
                f"The first input must be {DataKey.INPUT}.")

        # Flags describing which kinds of modules were passed in.
        self.contains_video_sequential: bool = False
        self.contains_3d_augmentation: bool = False
        for augmentation in args:
            if isinstance(augmentation, PatchSequential):
                # Only warn: construction still succeeds in this case.
                if not augmentation.is_intensity_only():
                    warnings.warn(
                        "Geometric transformation detected in PatchSeqeuntial, which would break bbox, mask."
                    )
            if isinstance(augmentation, VideoSequential):
                self.contains_video_sequential = True
            # NOTE: only for images are supported for 3D.
            if isinstance(augmentation, AugmentationBase3D):
                self.contains_3d_augmentation = True
        self._transform_matrix: Optional[Tensor] = None
Exemplo n.º 10
0
    def __init__(
        self,
        *args: _AugmentationBase,
        data_keys: List[Union[str, int, DataKey]] = [DataKey.INPUT],
        same_on_batch: Optional[bool] = None,
        return_transform: Optional[bool] = None,
        keepdim: Optional[bool] = None,
    ) -> None:
        """Build the container and record the data keys it operates on.

        Args:
            *args: augmentation modules, forwarded to the parent constructor.
            data_keys: one key per expected input; each entry is resolved
                through ``DataKey.get``. The first key must be
                ``DataKey.INPUT``.
            same_on_batch: forwarded to the parent constructor.
            return_transform: forwarded to the parent constructor.
            keepdim: forwarded to the parent constructor.

        Raises:
            AssertionError: if a resolved key is not a ``DataKey`` member.
            NotImplementedError: if the first key is not ``DataKey.INPUT``.
        """
        super(AugmentationSequential,
              self).__init__(*args,
                             same_on_batch=same_on_batch,
                             return_transform=return_transform,
                             keepdim=keepdim)

        # The default list object is never mutated: a fresh list is built here.
        self.data_keys = [DataKey.get(inp) for inp in data_keys]

        # Raise explicitly instead of using ``assert`` so that the validation
        # is not stripped when Python runs with ``-O``.
        if not all(in_type in DataKey for in_type in self.data_keys):
            raise AssertionError(
                f"`data_keys` must be in {DataKey}. Got {data_keys}.")

        if self.data_keys[0] != DataKey.INPUT:
            raise NotImplementedError(
                f"The first input must be {DataKey.INPUT}.")
Exemplo n.º 11
0
    def forward(  # type: ignore
        self,
        *args: TensorWithTransformMat,
        label: Optional[torch.Tensor] = None,
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[List[Union[str, int, DataKey]]] = None,
    ) -> Union[TensorWithTransformMat, Tuple[TensorWithTransformMat,
                                             Optional[torch.Tensor]],
               List[TensorWithTransformMat], Tuple[
                   List[TensorWithTransformMat], Optional[torch.Tensor]], ]:
        """Compute multiple tensors simultaneously according to ``self.data_keys``.

        Args:
            *args: one input per data key, in the same order as the keys.
            label: optional label forwarded to the parent ``forward`` and to
                the per-key apply helpers.
            params: parameters to apply. When ``None`` they are generated
                from the shape of the ``DataKey.INPUT`` argument.
            data_keys: keys describing ``args``; defaults to
                ``self.data_keys``. Explicit keys are resolved through
                ``DataKey.get``.

        Returns:
            Whatever ``self.__packup_output__(outputs, label)`` returns,
            where ``outputs`` holds one entry per data key.

        Raises:
            AssertionError: if the number of inputs differs from the number
                of data keys.
            ValueError: if ``params`` is ``None`` and no INPUT key is
                present, or if an unsupported sequential module is met.
            NotImplementedError: for a ``PatchSequential`` reached on a
                non-INPUT key, or for an unsupported key/module pair.
        """
        if data_keys is None:
            data_keys = cast(List[Union[str, int, DataKey]], self.data_keys)
        else:
            data_keys = [DataKey.get(inp) for inp in data_keys]

        if len(args) != len(data_keys):
            raise AssertionError(
                f"The number of inputs must align with the number of data_keys. Got {len(args)} and {len(data_keys)}."
            )

        # Without explicit params, generate them from the INPUT argument.
        if params is None:
            if DataKey.INPUT in data_keys:
                _input = args[data_keys.index(DataKey.INPUT)]
                # A tuple/list INPUT contributes only its first element here.
                if isinstance(_input, (tuple, list)):
                    inp = _input[0]
                else:
                    inp = _input
                # A different ``dim_range`` is used when a VideoSequential is
                # part of the pipeline.
                if self.contains_video_sequential:
                    _, out_shape = self.autofill_dim(inp, dim_range=(3, 5))
                else:
                    _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
                params = self.forward_parameters(out_shape)
            else:
                raise ValueError(
                    "`params` must be provided whilst INPUT is not in data_keys."
                )

        # One output slot per key; ``None`` marks "not computed yet".
        outputs: List[TensorWithTransformMat] = [None] * len(
            data_keys)  # type: ignore
        # The first INPUT entry is processed by the parent ``forward``.
        if DataKey.INPUT in data_keys:
            idx = data_keys.index(DataKey.INPUT)
            out = super().forward(args[idx], label, params=params)
            # NOTE(review): ``self.return_label`` is read here but only
            # reassigned further down in this method, so this branch sees
            # whatever value it held at this point (possibly set by the
            # parent ``forward``) - confirm this ordering is intended.
            if self.return_label:
                input, label = cast(
                    Tuple[TensorWithTransformMat, torch.Tensor], out)
            else:
                input = cast(TensorWithTransformMat, out)
            outputs[idx] = input

        self.return_label = label is not None or self.contains_label_operations(
            params)

        # Apply every parameterised module to each remaining argument.
        for idx, (input, dcate, out) in enumerate(zip(args, data_keys,
                                                      outputs)):
            # Skip the slot already filled by the parent ``forward``.
            if out is not None:
                continue
            for param in params:
                module = self.get_submodule(param.name)
                # Branch order matters: the first matching case wins.
                if dcate == DataKey.INPUT:
                    input, label = self.apply_to_input(input,
                                                       label,
                                                       module=module,
                                                       param=param)
                elif isinstance(
                        module,
                        IntensityAugmentationBase2D) and dcate in DataKey:
                    pass  # Do nothing
                elif isinstance(module,
                                ImageSequential) and module.is_intensity_only(
                                ) and dcate in DataKey:
                    pass  # Do nothing
                elif isinstance(module, VideoSequential) and dcate not in [
                        DataKey.INPUT, DataKey.MASK
                ]:
                    # Merge the first two dimensions before applying the
                    # module, then split them again using the saved size.
                    batch_size: int = input.size(0)
                    input = input.view(-1, *input.shape[2:])
                    input, label = ApplyInverse.apply_by_key(
                        input, label, module, param, dcate)
                    input = input.view(batch_size, -1, *input.shape[1:])
                elif isinstance(module, PatchSequential):
                    raise NotImplementedError(
                        "Geometric involved PatchSequential is not supported.")
                elif isinstance(module, (
                        GeometricAugmentationBase2D,
                        ImageSequential,
                )) and dcate in DataKey:
                    input, label = ApplyInverse.apply_by_key(
                        input, label, module, param, dcate)
                elif isinstance(module, (SequentialBase, )):
                    raise ValueError(f"Unsupported Sequential {module}.")
                else:
                    raise NotImplementedError(
                        f"data_key {dcate} is not implemented for {module}.")
            outputs[idx] = input

        return self.__packup_output__(outputs, label)
Exemplo n.º 12
0
    def forward(  # type: ignore
        self,
        *args: Tensor,
        label: Optional[Tensor] = None,
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[List[Union[str, int, DataKey]]] = None,
    ) -> Union[Tensor, Tuple[Tensor, Optional[Tensor]], List[Tensor], Tuple[
            List[Tensor], Optional[Tensor]], ]:
        """Compute multiple tensors simultaneously according to ``self.data_keys``.

        Args:
            *args: one input per data key, in the same order as the keys.
            label: optional label forwarded to the parent ``forward`` and to
                the per-key apply helpers.
            params: parameters to apply. When ``None`` they are generated
                from the shape of the ``DataKey.INPUT`` argument.
            data_keys: keys describing ``args``; defaults to
                ``self.data_keys``. Explicit keys are resolved through
                ``DataKey.get`` and also stored on ``self.data_keys``.

        Returns:
            Whatever ``self.__packup_output__(outputs, label)`` returns,
            where ``outputs`` holds one entry per data key.

        Raises:
            ValueError: if ``params`` is ``None`` and no INPUT key is
                present, if the INPUT argument is a tuple/list while
                ``params`` is ``None``, or if an unsupported sequential
                module is met.
            NotImplementedError: for a ``PatchSequential`` reached on a
                non-INPUT key, or for an unsupported key/module pair.
        """
        _data_keys: List[DataKey]
        if data_keys is None:
            _data_keys = self.data_keys
        else:
            _data_keys = [DataKey.get(inp) for inp in data_keys]
            # Explicit keys replace the stored ones for subsequent calls.
            self.data_keys = _data_keys
        self._validate_args_datakeys(*args, data_keys=_data_keys)

        # Preprocessed arguments; entries may be ``Boxes`` (handled below).
        args = self._arguments_preproc(*args, data_keys=_data_keys)

        if params is None:
            # image data must exist if params is not provided.
            if DataKey.INPUT in _data_keys:
                _input = args[_data_keys.index(DataKey.INPUT)]
                inp = _input
                if isinstance(inp, (tuple, list)):
                    raise ValueError(
                        f"`INPUT` should be a tensor but `{type(inp)}` received."
                    )
                # A video input shall be BCDHW while an image input shall be BCHW
                if self.contains_video_sequential or self.contains_3d_augmentation:
                    _, out_shape = self.autofill_dim(inp, dim_range=(3, 5))
                else:
                    _, out_shape = self.autofill_dim(inp, dim_range=(2, 4))
                params = self.forward_parameters(out_shape)
            else:
                raise ValueError(
                    "`params` must be provided whilst INPUT is not in data_keys."
                )

        # One output slot per key; ``None`` marks "not computed yet".
        outputs: List[Tensor] = [None] * len(_data_keys)  # type: ignore
        # Forward the first image data to freeze the parameters.
        if DataKey.INPUT in _data_keys:
            idx = _data_keys.index(DataKey.INPUT)
            _inp = args[idx]
            _out = super().forward(_inp, label, params=params)  # type: ignore
            # Keep the transformation matrix computed for this input/params.
            self._transform_matrix = self.get_transformation_matrix(
                _inp, params=params)
            # NOTE(review): ``self.return_label`` is read before it is
            # reassigned below, so this branch sees the value it held at
            # this point - confirm this ordering is intended.
            if self.return_label:
                _input, label = cast(Tuple[Tensor, Tensor], _out)
            else:
                _input = cast(Tensor, _out)
            outputs[idx] = _input

        # Never reset here: a previously true ``return_label`` stays true.
        self.return_label = self.return_label or label is not None or self.contains_label_operations(
            params)

        # Apply every parameterised module to each remaining argument.
        for idx, (arg, dcate, out) in enumerate(zip(args, _data_keys,
                                                    outputs)):
            # Skip the slot already filled by the parent ``forward``.
            if out is not None:
                continue
            # Using tensors straight-away
            if isinstance(arg, (Boxes, )):
                input = arg.data  # all boxes are in (B, N, 4, 2) format now.
            else:
                input = arg

            for param in params:
                module = self.get_submodule(param.name)
                # Branch order matters: the first matching case wins.
                if dcate == DataKey.INPUT:
                    input, label = self.apply_to_input(input,
                                                       label,
                                                       module=module,
                                                       param=param)
                # RandomErasing is excluded here and handled by the
                # geometric branch further down.
                elif isinstance(module, IntensityAugmentationBase2D) and dcate in DataKey \
                        and not isinstance(module, RandomErasing):
                    pass  # Do nothing
                elif isinstance(module,
                                ImageSequential) and module.is_intensity_only(
                                ) and dcate in DataKey:
                    pass  # Do nothing
                elif isinstance(module, VideoSequential) and dcate not in [
                        DataKey.INPUT, DataKey.MASK
                ]:
                    # Merge the first two dimensions before applying the
                    # module, then split them again using the saved size.
                    batch_size: int = input.size(0)
                    input = input.view(-1, *input.shape[2:])
                    input, label = ApplyInverse.apply_by_key(
                        input, label, module, param, dcate)
                    input = input.view(batch_size, -1, *input.shape[1:])
                elif isinstance(module, PatchSequential):
                    raise NotImplementedError(
                        "Geometric involved PatchSequential is not supported.")
                elif isinstance(module, (GeometricAugmentationBase2D, ImageSequential, RandomErasing)) \
                        and dcate in DataKey:
                    input, label = ApplyInverse.apply_by_key(
                        input, label, module, param, dcate)
                elif isinstance(module, (SequentialBase, )):
                    raise ValueError(f"Unsupported Sequential {module}.")
                else:
                    raise NotImplementedError(
                        f"data_key {dcate} is not implemented for {module}.")

            # Write the result back into the Boxes object (mutated in place)
            # and convert it to a tensor for the output list.
            if isinstance(arg, (Boxes, )):
                arg._data = input
                outputs[idx] = arg.to_tensor()
            else:
                outputs[idx] = input

        return self.__packup_output__(outputs, label)
Exemplo n.º 13
0
    def inverse(  # type: ignore
        self,
        *args: Tensor,
        params: Optional[List[ParamItem]] = None,
        data_keys: Optional[List[Union[str, int, DataKey]]] = None,
    ) -> Union[Tensor, List[Tensor]]:
        """Reverse the transformation applied.

        Number of input tensors must align with the number of ``data_keys``. If ``data_keys`` is not set, use
        ``self.data_keys`` by default.

        Args:
            *args: one input per data key, in the same order as the keys.
            params: parameters to invert. When ``None``, ``self._params`` is
                used.
            data_keys: keys describing ``args``; entries are resolved through
                ``DataKey.get``, so strings and integers are accepted.

        Returns:
            The single inverted output when there is exactly one data key,
            otherwise the list of inverted outputs.

        Raises:
            AssertionError: if the number of inputs differs from the number
                of data keys.
            ValueError: if no parameters are available, or if an unsupported
                sequential module is met.
            NotImplementedError: for a ``PatchSequential`` that is not
                skipped as intensity-only, or for an unsupported key/module
                pair.
        """
        if data_keys is None:
            data_keys = cast(List[Union[str, int, DataKey]], self.data_keys)

        _data_keys: List[DataKey] = [DataKey.get(inp) for inp in data_keys]

        if len(args) != len(data_keys):
            raise AssertionError(
                "The number of inputs must align with the number of data_keys, "
                f"Got {len(args)} and {len(data_keys)}.")

        # Preprocessed arguments; entries may be ``Boxes`` (handled below).
        args = self._arguments_preproc(*args, data_keys=_data_keys)

        if params is None:
            if self._params is None:
                raise ValueError(
                    "No parameters available for inversing, please run a forward pass first "
                    "or passing valid params into this function.")
            params = self._params

        outputs: List[Tensor] = [None] * len(_data_keys)  # type: ignore
        # Iterate over the resolved keys: the comparisons and ``in DataKey``
        # checks below need ``DataKey`` members, not raw strings or integers.
        for idx, (arg, dcate) in enumerate(zip(args, _data_keys)):
            if dcate == DataKey.INPUT and isinstance(arg, (tuple, list)):
                input, _ = arg  # ignore the transformation matrix whilst inverse
            # Using tensors straight-away
            elif isinstance(arg, (Boxes, )):
                input = arg.data  # all boxes are in (B, N, 4, 2) format now.
            else:
                input = arg
            # Walk the forward sequence and the params in reverse order.
            for (name, module), param in zip_longest(
                    list(self.get_forward_sequence(params))[::-1],
                    params[::-1]):
                if isinstance(module, (_AugmentationBase, ImageSequential)):
                    # NOTE(review): ``params`` is annotated as a list, so
                    # ``name in params`` compares ``name`` with its items -
                    # confirm this lookup is intended.
                    param = params[name] if name in params else param
                else:
                    param = None
                # Branch order matters: the first matching case wins.
                if isinstance(module, IntensityAugmentationBase2D) and dcate in DataKey \
                        and not isinstance(module, RandomErasing):
                    pass  # Do nothing
                elif isinstance(module,
                                ImageSequential) and module.is_intensity_only(
                                ) and dcate in DataKey:
                    pass  # Do nothing
                elif isinstance(module, VideoSequential) and dcate not in [
                        DataKey.INPUT, DataKey.MASK
                ]:
                    # Merge the first two dimensions before inverting, then
                    # split them again using the saved size.
                    batch_size: int = input.size(0)
                    input = input.view(-1, *input.shape[2:])
                    input = ApplyInverse.inverse_by_key(
                        input, module, param, dcate)
                    input = input.view(batch_size, -1, *input.shape[1:])
                elif isinstance(module, PatchSequential):
                    raise NotImplementedError(
                        "Geometric involved PatchSequential is not supported.")
                elif isinstance(module, (GeometricAugmentationBase2D, ImageSequential, RandomErasing)) \
                        and dcate in DataKey:
                    input = ApplyInverse.inverse_by_key(
                        input, module, param, dcate)
                elif isinstance(module, (SequentialBase, )):
                    raise ValueError(f"Unsupported Sequential {module}.")
                else:
                    raise NotImplementedError(
                        f"data_key {dcate} is not implemented for {module}.")
            # Write the result back into the Boxes object (mutated in place)
            # and convert it to a tensor for the output list.
            if isinstance(arg, (Boxes, )):
                arg._data = input
                outputs[idx] = arg.to_tensor()
            else:
                outputs[idx] = input

        # A single output is returned bare rather than wrapped in a list.
        if len(outputs) == 1 and isinstance(outputs, (tuple, list)):
            return outputs[0]

        return outputs