Exemplo n.º 1
0
    def test_batched_eye(self):
        """Check get_batched_eye against torch.eye for many dtypes, sizes and batches."""
        # Enumerate every (dtype, ndim, batchsize) combination up front, in the
        # same order as dtype -> ndim -> batchsize nested loops would.
        cases = (
            (dtype, ndim, batchsize)
            for dtype in (torch.float, torch.long)
            for ndim in range(10)
            for batchsize in range(10)
        )
        for dtype, ndim, batchsize in cases:
            # one subTest per combination so a single failure does not hide others
            with self.subTest(batchsize=batchsize, ndim=ndim, dtype=dtype):
                result = get_batched_eye(batchsize=batchsize, ndim=ndim, dtype=dtype)

                # the batch must have shape (batchsize, ndim, ndim) and the requested dtype
                self.assertTupleEqual(result.size(), (batchsize, ndim, ndim))
                self.assertEqual(dtype, result.dtype)

                # every entry of the batch must equal the unbatched identity
                reference = torch.eye(ndim, dtype=dtype)
                for sample in result:
                    self.assertTrue(torch.allclose(sample, reference, atol=1e-6))
Exemplo n.º 2
0
def create_translation(offset: AffineParamType,
                       batchsize: int,
                       ndim: int,
                       device: Optional[Union[torch.device, str]] = None,
                       dtype: Optional[Union[torch.dtype, str]] = None,
                       image_transform: bool = True) -> torch.Tensor:
    """
    Formats the given translation parameters to a homogeneous transformation
    matrix

    Args:
        offset: the translation offset(s). Supported are:
            * a single parameter (as float or int), which will be replicated
            for all dimensions and batch samples
            * a parameter per sample, which will be
            replicated for all dimensions
            * a parameter per dimension, which will be replicated for all
            batch samples
            * a parameter per sample per dimension
            * None will be treated as a translation offset of 0
        batchsize: the number of samples per batch (may be 0)
        ndim: the dimensionality of the transform
        device: the device to put the resulting tensor to.
            Defaults to the torch default device
        dtype: the dtype of the resulting tensor.
            Defaults to the torch default dtype
        image_transform: bool
            inverts (negates) the translation offsets to match expected
            behavior when applied to an image, e.g. translation > 0 should
            move the image in the positive direction of an axis but the grid
            in the negative direction

    Returns:
        torch.Tensor: the homogeneous transformation matrix [N, NDIM + 1, NDIM + 1],
            N is the batch size and NDIM is the number of spatial dimensions
    """
    if offset is None:
        offset = 0
    # NOTE(review): expand_scalar_param is expected to yield batchsize * ndim
    # values (one offset per sample per dimension) — confirm against helper.
    offset = expand_scalar_param(offset, batchsize, ndim).to(device=device,
                                                             dtype=dtype)
    eye_batch = get_batched_eye(batchsize=batchsize,
                                ndim=ndim,
                                device=device,
                                dtype=dtype)
    # Append each sample's offsets as the last column of its identity matrix
    # in one vectorized call: [N, NDIM, NDIM] + [N, NDIM, 1] -> [N, NDIM, NDIM + 1].
    # Unlike stacking a per-sample list, this also works for batchsize == 0,
    # where torch.stack would reject the empty sequence.
    translation_matrix = torch.cat(
        [eye_batch, offset.reshape(batchsize, ndim, 1)], dim=-1)
    if image_transform:
        # negate after concatenation so it operates on the promoted dtype;
        # the last column holds only the offsets
        translation_matrix[..., -1] = -translation_matrix[..., -1]
    return matrix_to_homogeneous(translation_matrix)
Exemplo n.º 3
0
def create_scale(scale: AffineParamType,
                 batchsize: int,
                 ndim: int,
                 device: Optional[Union[torch.device, str]] = None,
                 dtype: Optional[Union[torch.dtype, str]] = None,
                 image_transform: bool = True) -> torch.Tensor:
    """
    Formats the given scale parameters to a homogeneous transformation matrix

    Args:
        scale : the scale factor(s). Supported are:
            * a single parameter (as float or int), which will be replicated
            for all dimensions and batch samples
            * a parameter per sample, which will be
            replicated for all dimensions
            * a parameter per dimension, which will be replicated for all
            batch samples
            * a parameter per sample per dimension
            * None will be treated as a scaling factor of 1
        batchsize: the number of samples per batch (may be 0)
        ndim: the dimensionality of the transform
        device: the device to put the resulting tensor to.
            Defaults to the torch default device
        dtype: the dtype of the resulting tensor.
            Defaults to the torch default dtype
        image_transform: inverts the scale factors (1 / scale) to match
            expected behavior when applied to an image, e.g. scale > 1
            increases the size of an image but decreases the size of a grid

    Returns:
        torch.Tensor: the homogeneous transformation matrix
            [N, NDIM + 1, NDIM + 1], N is the batch size and NDIM is the
            number of spatial dimensions
    """
    if scale is None:
        scale = 1

    # NOTE(review): expand_scalar_param is expected to yield batchsize * ndim
    # values (one factor per sample per dimension) — confirm against helper.
    scale = expand_scalar_param(scale, batchsize, ndim).to(device=device,
                                                           dtype=dtype)
    if image_transform:
        scale = 1 / scale

    eye_batch = get_batched_eye(batchsize=batchsize, ndim=ndim,
                                device=device, dtype=dtype)
    # Multiply column j of sample n's identity by scale[n, j], which puts the
    # factors on the diagonal. Broadcasting over the whole batch replaces the
    # per-sample stack and also works for batchsize == 0, where torch.stack
    # would reject the empty sequence.
    scale_matrix = eye_batch * scale.reshape(batchsize, 1, ndim)
    return matrix_to_homogeneous(scale_matrix)