Example #1
    def test_to_norm_affine_ill(self, affine, src_size, dst_size, align_corners):
        with self.assertRaises(ValueError):
            to_norm_affine(affine, src_size, dst_size, align_corners)
        with self.assertRaises(ValueError):
            affine = torch.as_tensor(affine, device=torch.device("cpu:0"), dtype=torch.float32)
            to_norm_affine(affine, src_size, dst_size, align_corners)
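For orientation, a standalone sketch of the kind of ill-formed input this test guards against. The concrete trigger (a 2D affine paired with 3D spatial sizes) and the `monai.networks.utils` import path are assumptions, not taken from the snippet:

import torch

from monai.networks.utils import to_norm_affine  # assumed import path

# Assumed failure mode: the spatial rank of the affine (2) does not match
# the 3D src/dst sizes, which to_norm_affine should reject with ValueError.
affine = torch.eye(3)[None]  # batched 2D affine, shape (1, 3, 3)
try:
    to_norm_affine(affine, src_size=(10, 10, 10), dst_size=(10, 10, 10), align_corners=False)
except ValueError as exc:
    print(f"rejected as expected: {exc}")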
Example #2
    def test_to_norm_affine(self, affine, src_size, dst_size, align_corners, expected):
        affine = torch.as_tensor(affine, device=torch.device("cpu:0"), dtype=torch.float32)
        new_affine = to_norm_affine(affine, src_size, dst_size, align_corners)
        new_affine = new_affine.detach().cpu().numpy()
        np.testing.assert_allclose(new_affine, expected, atol=1e-6)

        if torch.cuda.is_available():
            affine = torch.as_tensor(affine, device=torch.device("cuda:0"), dtype=torch.float32)
            new_affine = to_norm_affine(affine, src_size, dst_size, align_corners)
            new_affine = new_affine.detach().cpu().numpy()
            np.testing.assert_allclose(new_affine, expected, atol=1e-4)
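As a quick cross-check of what the test asserts: `to_norm_affine` rescales a pixel-space affine into the normalized [-1, 1] coordinate space used by `affine_grid`, so an identity affine with matching source and destination sizes should come back unchanged. A minimal sketch (the import path is an assumption):

import torch

from monai.networks.utils import to_norm_affine  # assumed import path

# With src_size == dst_size, the pixel-to-normalized rescaling on both sides
# cancels out, so a pixel-space identity stays the identity.
eye = torch.eye(3)[None]  # batched 2D affine, shape (1, 3, 3)
new_affine = to_norm_affine(eye, src_size=(20, 20), dst_size=(20, 20), align_corners=True)
print(torch.allclose(new_affine, eye))  # expected: True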
Example #3
    def forward(self, src, theta, spatial_size=None):
        """
        ``theta`` must be an affine transformation matrix with shape
        3x3 or Nx3x3 or Nx2x3 or 2x3 for spatial 2D transforms,
        4x4 or Nx4x4 or Nx3x4 or 3x4 for spatial 3D transforms,
        where `N` is the batch size. `theta` will be converted into a float Tensor for the computation.

        Args:
            src (array_like): image in spatial 2D or 3D (N, C, spatial_dims),
                where N is the batch dim, C is the number of channels, and
                spatial_dims are the spatial dimensions.
            theta (array_like): Nx3x3, Nx2x3, 3x3, 2x3 for spatial 2D inputs,
                Nx4x4, Nx3x4, 3x4, 4x4 for spatial 3D inputs. When the batch dimension is omitted,
                `theta` will be repeated N times, where N is the batch dim of `src`.
            spatial_size (list or tuple of int): output spatial shape, the full output shape will be
                `[N, C, *spatial_size]` where N and C are inferred from the `src`.
        """
        # validate `theta`
        if not torch.is_tensor(theta) or not torch.is_tensor(src):
            raise TypeError(
                f"both src and theta must be torch Tensor, got {type(src).__name__}, {type(theta).__name__}."
            )
        if theta.ndim not in (2, 3):
            raise ValueError("affine must be Nxdxd or dxd.")
        if theta.ndim == 2:
            theta = theta[None]  # adds a batch dim.
        theta = theta.clone()  # no in-place change of theta
        theta_shape = tuple(theta.shape[1:])
        if theta_shape in ((2, 3), (3, 4)):  # needs padding to dxd
            pad_affine = torch.tensor([0, 0, 1] if theta_shape[0] == 2 else [0, 0, 0, 1])
            pad_affine = pad_affine.repeat(theta.shape[0], 1, 1).to(theta)
            pad_affine.requires_grad = False
            theta = torch.cat([theta, pad_affine], dim=1)
        if tuple(theta.shape[1:]) not in ((3, 3), (4, 4)):
            raise ValueError(f"affine must be Nx3x3 or Nx4x4, got: {theta.shape}.")

        # validate `src`
        sr = src.ndim - 2  # input spatial rank
        if sr not in (2, 3):
            raise ValueError("src must be spatially 2D or 3D.")

        # set output shape
        src_size = tuple(src.shape)
        dst_size = src_size  # default to the src shape
        if self.spatial_size is not None:
            dst_size = src_size[:2] + self.spatial_size
        if spatial_size is not None:
            dst_size = src_size[:2] + ensure_tuple(spatial_size)

        # reverse and normalise theta if needed
        if not self.normalized:
            theta = to_norm_affine(affine=theta,
                                   src_size=src_size[2:],
                                   dst_size=dst_size[2:],
                                   align_corners=self.align_corners)
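        # PyTorch's affine_grid/grid_sample order grid coordinates as
        # (x, y[, z]), the reverse of the image axes (h, w[, d]); flipping the
        # first `sr` rows and columns of `theta` below converts between the
        # two conventions.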
        if self.reverse_indexing:
            rev_idx = torch.as_tensor(range(sr - 1, -1, -1), device=src.device)
            theta[:, :sr] = theta[:, rev_idx]
            theta[:, :, :sr] = theta[:, :, rev_idx]
        if (theta.shape[0] == 1) and src_size[0] > 1:
            # adds a batch dim to `theta` in order to match `src`
            theta = theta.repeat(src_size[0], 1, 1)
        if theta.shape[0] != src_size[0]:
            raise ValueError(
                f"batch dimension of affine and image does not match, got affine: {theta.shape[0]} and image: {src_size[0]}."
            )

        grid = nn.functional.affine_grid(theta=theta[:, :sr],
                                         size=dst_size,
                                         align_corners=self.align_corners)
        dst = nn.functional.grid_sample(
            input=src.contiguous(),
            grid=grid,
            mode=self.mode,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
        )
        return dst
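For context, a minimal end-to-end sketch of calling this `forward`, assuming the surrounding class is MONAI's `AffineTransform` (the import path and the constructor keywords are assumptions based on the attributes referenced above):

import torch

from monai.networks.layers import AffineTransform  # assumed import path

src = torch.rand(2, 1, 32, 32)  # 2D image batch: (N, C, H, W)
# 2x3 pixel-space affine; the batch dim is omitted, so it is repeated N times
theta = torch.tensor([[1.0, 0.0, 2.0],
                      [0.0, 1.0, -1.0]])
xform = AffineTransform(normalized=False, align_corners=True)  # assumed kwargs
dst = xform(src, theta, spatial_size=(48, 48))
print(dst.shape)  # torch.Size([2, 1, 48, 48])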