def test_batched_eye(self):
    # check all combinations of dtype, dimensionality and batch size
    for dtype in [torch.float, torch.long]:
        for ndim in range(10):
            for batchsize in range(10):
                with self.subTest(batchsize=batchsize, ndim=ndim, dtype=dtype):
                    batched_eye = get_batched_eye(batchsize=batchsize,
                                                  ndim=ndim, dtype=dtype)
                    # shape must be [N, ndim, ndim] with the requested dtype
                    self.assertTupleEqual(batched_eye.size(),
                                          (batchsize, ndim, ndim))
                    self.assertEqual(dtype, batched_eye.dtype)

                    # every sample must equal the non-batched identity
                    non_batched_eye = torch.eye(ndim, dtype=dtype)
                    for _eye in batched_eye:
                        self.assertTrue(
                            torch.allclose(_eye, non_batched_eye, atol=1e-6))
def create_translation(offset: AffineParamType,
                       batchsize: int,
                       ndim: int,
                       device: Optional[Union[torch.device, str]] = None,
                       dtype: Optional[Union[torch.dtype, str]] = None,
                       image_transform: bool = True) -> torch.Tensor:
    """
    Formats the given translation parameters to a homogeneous
    transformation matrix

    Args:
        offset: the translation offset(s). Supported are:
            * a single parameter (as float or int), which will be
              replicated for all dimensions and batch samples
            * a parameter per sample, which will be replicated for
              all dimensions
            * a parameter per dimension, which will be replicated for
              all batch samples
            * a parameter per sample per dimension
            * None will be treated as a translation offset of 0
        batchsize: the number of samples per batch
        ndim: the dimensionality of the transform
        device: the device to put the resulting tensor to.
            Defaults to the torch default device
        dtype: the dtype of the resulting tensor.
            Defaults to the torch default dtype
        image_transform: inverts the translation matrix to match the
            expected behavior when applied to an image, e.g. an
            offset > 0 should move the image in the positive direction
            of an axis, but the grid in the negative direction

    Returns:
        torch.Tensor: the homogeneous transformation matrix
            [N, NDIM + 1, NDIM + 1], where N is the batch size and
            NDIM is the number of spatial dimensions
    """
    if offset is None:
        offset = 0

    offset = expand_scalar_param(offset, batchsize, ndim).to(
        device=device, dtype=dtype)
    eye_batch = get_batched_eye(batchsize=batchsize, ndim=ndim,
                                device=device, dtype=dtype)
    # append the offsets as the last column of each identity matrix
    translation_matrix = torch.stack([
        torch.cat([eye, o.view(-1, 1)], dim=1)
        for eye, o in zip(eye_batch, offset)
    ])

    if image_transform:
        # negate the offset column so the grid moves opposite to the image
        translation_matrix[..., -1] = -translation_matrix[..., -1]

    return matrix_to_homogeneous(translation_matrix)
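# A minimal usage sketch (not part of the library API): the hypothetical
# ``_sketch_create_translation`` below only calls ``create_translation`` as
# defined above and checks the output shape documented in its docstring.
# The effect of ``image_transform`` is noted in comments rather than
# asserted, since it depends on the helper functions.
def _sketch_create_translation():
    # a scalar offset is replicated for all samples and dimensions:
    # batchsize=2, ndim=3 yields a homogeneous (2, 4, 4) matrix
    matrix = create_translation(offset=1., batchsize=2, ndim=3)
    assert matrix.size() == (2, 4, 4)
    # with image_transform=True (the default), the stored offsets are
    # negated, so the resampling grid moves opposite to the image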
def create_scale(scale: AffineParamType,
                 batchsize: int,
                 ndim: int,
                 device: Optional[Union[torch.device, str]] = None,
                 dtype: Optional[Union[torch.dtype, str]] = None,
                 image_transform: bool = True) -> torch.Tensor:
    """
    Formats the given scale parameters to a homogeneous transformation
    matrix

    Args:
        scale: the scale factor(s). Supported are:
            * a single parameter (as float or int), which will be
              replicated for all dimensions and batch samples
            * a parameter per sample, which will be replicated for
              all dimensions
            * a parameter per dimension, which will be replicated for
              all batch samples
            * a parameter per sample per dimension
            * None will be treated as a scaling factor of 1
        batchsize: the number of samples per batch
        ndim: the dimensionality of the transform
        device: the device to put the resulting tensor to.
            Defaults to the torch default device
        dtype: the dtype of the resulting tensor.
            Defaults to the torch default dtype
        image_transform: inverts the scale matrix to match the expected
            behavior when applied to an image, e.g. a scale > 1
            increases the size of an image but decreases the size of
            a grid

    Returns:
        torch.Tensor: the homogeneous transformation matrix
            [N, NDIM + 1, NDIM + 1], where N is the batch size and
            NDIM is the number of spatial dimensions
    """
    if scale is None:
        scale = 1

    scale = expand_scalar_param(scale, batchsize, ndim).to(
        device=device, dtype=dtype)

    if image_transform:
        # the grid is scaled by the reciprocal of the image scale factor
        scale = 1 / scale

    # place the (possibly inverted) factors on the diagonal of each sample
    scale_matrix = torch.stack([
        eye * s for eye, s in zip(
            get_batched_eye(batchsize=batchsize, ndim=ndim,
                            device=device, dtype=dtype),
            scale)
    ])

    return matrix_to_homogeneous(scale_matrix)
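# A minimal usage sketch (not part of the library API): the hypothetical
# ``_sketch_create_scale`` below illustrates that with the default
# ``image_transform=True`` the factors are inverted before being placed
# on the diagonal, so an image scale of 2 corresponds to a grid scale of
# 0.5. Only the documented output shape is asserted; the diagonal values
# are expectations that depend on the helper functions.
def _sketch_create_scale():
    # scale=2 on a 2D batch of 2 samples -> homogeneous (2, 3, 3) matrices
    grid_matrix = create_scale(scale=2., batchsize=2, ndim=2)
    assert grid_matrix.size() == (2, 3, 3)
    # expected: diagonal entries of 0.5 (the inverted factor) in grid_matrix

    # without the image semantics the factor should be kept as-is,
    # i.e. diagonal entries of 2
    plain_matrix = create_scale(scale=2., batchsize=2, ndim=2,
                                image_transform=False)
    assert plain_matrix.size() == (2, 3, 3)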