def test_to_norm_affine(self,
                            affine,
                            src_size,
                            dst_size,
                            align_corners,
                            expected,
                            zero_centered=False):
        """Compare ``to_norm_affine`` output with ``expected`` on CPU and, if present, CUDA.

        The CPU result must match within ``atol=1e-6``. When a CUDA device is
        available, the computation is repeated on ``cuda:0`` with the looser
        ``atol=1e-5`` and ``rtol=_rtol``.
        """

        def _normalise_on(device_name, matrix):
            """Cast ``matrix`` to float32 on ``device_name``; return it with its normalised numpy form."""
            on_device = torch.as_tensor(matrix,
                                        device=torch.device(device_name),
                                        dtype=torch.float32)
            normalised = to_norm_affine(on_device, src_size, dst_size,
                                        align_corners, zero_centered)
            return on_device, normalised.detach().cpu().numpy()

        cpu_affine, cpu_result = _normalise_on("cpu:0", affine)
        np.testing.assert_allclose(cpu_result, expected, atol=1e-6)

        if not torch.cuda.is_available():
            return
        # The CUDA input is derived from the float32 CPU tensor built above.
        _, cuda_result = _normalise_on("cuda:0", cpu_affine)
        # NOTE(review): `_rtol` is not defined in this block; presumably a
        # module-level tolerance constant — confirm.
        np.testing.assert_allclose(cuda_result,
                                   expected,
                                   atol=1e-5,
                                   rtol=_rtol)
 def test_to_norm_affine_ill(self, affine, src_size, dst_size,
                             align_corners):
     """Check that ``to_norm_affine`` rejects ill-formed input.

     A raw (non-tensor) ``affine`` must raise ``TypeError``. The same values,
     converted to a float32 CPU tensor and combined with the given sizes, must
     raise ``ValueError``.
     """
     with self.assertRaises(TypeError):
         to_norm_affine(affine, src_size, dst_size, align_corners)
     # Convert outside the `assertRaises` block, so a ValueError raised by the
     # conversion itself cannot make the assertion pass vacuously.
     affine = torch.as_tensor(affine,
                              device=torch.device("cpu:0"),
                              dtype=torch.float32)
     with self.assertRaises(ValueError):
         to_norm_affine(affine, src_size, dst_size, align_corners)
Example #3
0
    def forward(self,
                src,
                theta,
                spatial_size: Optional[Union[Sequence[int], int]] = None):
        """
        Resample ``src`` with the affine transformation ``theta``.

        ``theta`` must be an affine transformation matrix with shape
        3x3 or Nx3x3 or Nx2x3 or 2x3 for spatial 2D transforms,
        4x4 or Nx4x4 or Nx3x4 or 3x4 for spatial 3D transforms,
        where `N` is the batch size. Before the computation, `theta` is moved to
        ``src``'s device and cast to ``src``'s floating dtype (float32 when ``src``
        is not floating point). This is needed because ``grid_sample`` requires
        the input and the sampling grid to agree.

        Args:
            src (array_like): image in spatial 2D or 3D (N, C, spatial_dims),
                where N is the batch dim, C is the number of channels.
            theta (array_like): Nx3x3, Nx2x3, 3x3, 2x3 for spatial 2D inputs,
                Nx4x4, Nx3x4, 3x4, 4x4 for spatial 3D inputs. When the batch dimension is omitted,
                `theta` will be repeated N times, N is the batch dim of `src`.
                The caller's tensor is never modified in place.
            spatial_size: output spatial shape, the full output shape will be
                `[N, C, *spatial_size]` where N and C are inferred from the `src`.
                Overrides ``self.spatial_size`` when given.

        Returns:
            The tensor produced by ``torch.nn.functional.grid_sample``, with shape
            ``[N, C, *spatial_size]`` (``src``'s shape when no spatial size is set).

        Raises:
            TypeError: When ``theta`` is not a ``torch.Tensor``.
            ValueError: When ``theta`` is not one of [Nxdxd, dxd].
            ValueError: When ``theta`` is not one of [Nx3x3, Nx4x4].
            TypeError: When ``src`` is not a ``torch.Tensor``.
            ValueError: When ``src`` spatially is not one of [2D, 3D].
            ValueError: When the spatial rank of ``theta`` differs from that of ``src``.
            ValueError: When affine and image batch dimension differ.

        """
        # validate `theta`
        if not torch.is_tensor(theta):
            raise TypeError(
                f"theta must be torch.Tensor but is {type(theta).__name__}.")
        if theta.ndim not in (2, 3):
            raise ValueError(f"theta must be Nxdxd or dxd, got {theta.shape}.")
        if theta.ndim == 2:
            theta = theta[None]  # adds a batch dim.
        theta = theta.clone()  # no in-place change of theta
        theta_shape = tuple(theta.shape[1:])
        if theta_shape in ((2, 3), (3, 4)):  # needs padding to dxd
            # Append the constant homogeneous row, one copy per batch item.
            last_row = [0, 0, 1] if theta_shape[0] == 2 else [0, 0, 0, 1]
            pad_affine = torch.tensor(last_row).repeat(theta.shape[0], 1,
                                                       1).to(theta)
            theta = torch.cat([theta, pad_affine], dim=1)
        if tuple(theta.shape[1:]) not in ((3, 3), (4, 4)):
            raise ValueError(
                f"theta must be Nx3x3 or Nx4x4, got {theta.shape}.")

        # validate `src`
        if not torch.is_tensor(src):
            raise TypeError(
                f"src must be torch.Tensor but is {type(src).__name__}.")
        sr = src.ndim - 2  # input spatial rank
        if sr not in (2, 3):
            raise ValueError(
                f"Unsupported src dimension: {sr}, available options are [2, 3]."
            )
        # A (d+1)x(d+1) matrix describes a d-D transform; it must match `src`.
        if theta.shape[-1] != sr + 1:
            raise ValueError(
                f"theta of shape {tuple(theta.shape)} describes a "
                f"{theta.shape[-1] - 1}D transform, but src has {sr} spatial dimensions."
            )

        # `grid_sample` needs the grid (built from `theta`) on the same device
        # and in the same floating dtype as `src`.
        float_dtype = src.dtype if torch.is_floating_point(
            src) else torch.float32
        theta = theta.to(device=src.device, dtype=float_dtype)

        # set output shape
        src_size = tuple(src.shape)
        dst_size = src_size  # default to the src shape
        if self.spatial_size is not None:
            dst_size = src_size[:2] + self.spatial_size
        if spatial_size is not None:
            dst_size = src_size[:2] + ensure_tuple(spatial_size)

        # reverse and normalise theta if needed
        if not self.normalized:
            theta = to_norm_affine(affine=theta,
                                   src_size=src_size[2:],
                                   dst_size=dst_size[2:],
                                   align_corners=self.align_corners)
        if self.reverse_indexing:
            # Flip the order of the spatial axes in both rows and columns.
            rev_idx = torch.arange(sr - 1, -1, -1, device=theta.device)
            theta[:, :sr] = theta[:, rev_idx]
            theta[:, :, :sr] = theta[:, :, rev_idx]
        if (theta.shape[0] == 1) and src_size[0] > 1:
            # adds a batch dim to `theta` in order to match `src`
            theta = theta.repeat(src_size[0], 1, 1)
        if theta.shape[0] != src_size[0]:
            raise ValueError(
                f"affine and image batch dimension must match, got affine={theta.shape[0]} image={src_size[0]}."
            )

        # `affine_grid` takes the top `sr` rows (Nx2x3 / Nx3x4), without the
        # homogeneous last row.
        grid = nn.functional.affine_grid(theta=theta[:, :sr],
                                         size=list(dst_size),
                                         align_corners=self.align_corners)
        dst = nn.functional.grid_sample(
            input=src.contiguous(),
            grid=grid,
            mode=self.mode.value,
            padding_mode=self.padding_mode.value,
            align_corners=self.align_corners,
        )
        return dst