예제 #1
0
파일: augment.py 프로젝트: Russ76/kornia
 def inverse_bbox(
     self,
     input: torch.Tensor,
     module: nn.Module,
     param: Optional[Dict[str, torch.Tensor]] = None,
     mode: str = "xyxy",
 ) -> torch.Tensor:
     """Map bounding boxes back through the inverse of ``module``'s transform.

     Only modules that are instances of ``GeometricAugmentationBase2D`` alter
     the boxes; for any other module the boxes are returned untouched.

     Args:
         input: Bounding boxes to transform.
         module: The augmentation whose transformation is being undone.
         param: Parameters forwarded to ``module.get_transformation_matrix``.
         mode: Box format flag forwarded to ``transform_boxes``
             (defaults to ``"xyxy"``).

     Returns:
         The inversely transformed boxes, or ``input`` unchanged when
         ``module`` is not a geometric 2D augmentation.
     """
     if not isinstance(module, GeometricAugmentationBase2D):
         return input

     forward_matrix = module.get_transformation_matrix(input, param)
     inverse_matrix = module.compute_inverse_transformation(forward_matrix)
     # Align the matrix with the boxes' device and dtype before applying it.
     inverse_matrix = torch.as_tensor(
         inverse_matrix, device=input.device, dtype=input.dtype)
     return transform_boxes(inverse_matrix, input, mode)
예제 #2
0
 def apply_to_bbox(
     self,
     input: torch.Tensor,
     module: nn.Module,
     param: Optional[Dict[str, torch.Tensor]] = None,
     mode: str = "xyxy",
 ) -> torch.Tensor:
     """Apply ``module``'s transformation matrix to bounding boxes.

     Only modules that are instances of ``GeometricAugmentationBase2D`` alter
     the boxes; for any other module the boxes are returned untouched.

     Args:
         input: Bounding boxes to transform.
         module: The augmentation whose transformation is applied.
         param: Parameters forwarded to ``module.get_transformation_matrix``.
             Must not be ``None`` for a geometric 2D augmentation.
         mode: Box format flag forwarded to ``transform_boxes``
             (defaults to ``"xyxy"``).

     Returns:
         The transformed boxes, or ``input`` unchanged when ``module`` is not
         a geometric 2D augmentation.

     Raises:
         ValueError: If ``module`` is a ``GeometricAugmentationBase2D`` and
             ``param`` is ``None``.
     """
     if not isinstance(module, GeometricAugmentationBase2D):
         # Non-geometric modules leave the boxes as they are.
         return input

     if param is None:
         raise ValueError(
             f"Transformation matrix for {module} has not been computed.")

     matrix = module.get_transformation_matrix(input, param)
     return transform_boxes(matrix, input, mode)