Example #1
def _check_new_img_size(curr_img_size,
                        matrix: torch.Tensor,
                        zero_border: bool = False) -> torch.Tensor:
    """
    Calculates the image size needed so that the whole transformed image
    content fits into it.
    The resulting size will be the maximum size of the batch, so that the
    images can remain batched.

    Args:
        curr_img_size: the size of the current image.
            If int, it will be used as size for all image dimensions
        matrix: a batch of affine matrices with shape [N, NDIM, NDIM+1]
        zero_border: whether or not to have a fixed image border at zero

    Returns:
        torch.Tensor: the new image size
    """
    n_dim = matrix.size(-1) - 1
    if check_scalar(curr_img_size):
        curr_img_size = [curr_img_size] * n_dim
    possible_points = unit_box(n_dim, torch.tensor(curr_img_size)).to(matrix)

    # repeat the unit-box corners once per batch element
    # (clone because expand returns a non-writable view)
    transformed_edges = affine_point_transform(
        possible_points[None].expand(
            matrix.size(0), *([-1] * possible_points.ndimension())).clone(),
        matrix)

    if zero_border:
        # keep the border fixed at zero instead of shifting the content
        substr = 0
    else:
        # shift by each sample's minimum so all corners get
        # non-negative coordinates
        substr = transformed_edges.min(1)[0]

    # extent of the transformed corners, maximized over the batch
    return (transformed_edges.max(1)[0] - substr).max(0)[0]
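
A usage sketch (hypothetical values; assumes _check_new_img_size and its helpers from the surrounding module are in scope): a 45-degree rotation of a square image needs roughly sqrt(2) times the original side length, and the returned size is the per-dimension maximum over the batch.

import math
import torch

theta = math.radians(45.0)
rotation = torch.tensor([[math.cos(theta), -math.sin(theta), 0.0],
                         [math.sin(theta),  math.cos(theta), 0.0]])
identity = torch.tensor([[1.0, 0.0, 0.0],
                         [0.0, 1.0, 0.0]])
matrix = torch.stack([identity, rotation])  # shape [2, 2, 3] = [N, NDIM, NDIM+1]

new_size = _check_new_img_size(100, matrix)
# the rotated 100x100 box spans about 141 pixels per side, so new_size
# is roughly tensor([141.4, 141.4]) despite the identity matrix in the batch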
Example #2
def expand_scalar_param(param: AffineParamType, batchsize: int,
                        ndim: int) -> Tensor:
    """
    Bring affine params to shape (batchsize, ndim)

    Args:
        param: affine parameter
        batchsize: size of batch
        ndim: number of spatial dimensions

    Returns:
        torch.Tensor: affine params in correct shape
    """
    if check_scalar(param):
        return torch.tensor([[param] * ndim] * batchsize).float()

    if not torch.is_tensor(param):
        param = torch.tensor(param)
    else:
        param = param.clone()

    if param.ndimension() != 2:
        if param.shape[0] == ndim:  # scalar per dim
            param = param.reshape(1, -1).expand(batchsize, ndim)
        elif param.shape[0] == batchsize:  # scalar per batch
            param = param.reshape(-1, 1).expand(batchsize, ndim)
        else:
            raise ValueError(
                "Unknown param for expanding. "
                f"Found {param} for batchsize {batchsize} and ndim {ndim}")
    assert all(i == j for i, j in zip(param.shape, (batchsize, ndim))), \
        (f"Affine params need to have shape (batchsize, ndim) = "
         f"{(batchsize, ndim)}, but found {param.shape}")
    return param.float()
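
A brief usage sketch (hypothetical values) showing the three accepted layouts:

import torch

# scalar: broadcast to every batch element and dimension
expand_scalar_param(0.5, batchsize=2, ndim=3)
# -> tensor([[0.5, 0.5, 0.5],
#            [0.5, 0.5, 0.5]])

# one value per spatial dimension, repeated over the batch
expand_scalar_param([1.0, 2.0, 3.0], batchsize=2, ndim=3)
# -> tensor([[1., 2., 3.],
#            [1., 2., 3.]])

# one value per batch element, repeated over the dimensions
expand_scalar_param(torch.tensor([1.0, 2.0]), batchsize=2, ndim=3)
# -> tensor([[1., 1., 1.],
#            [2., 2., 2.]])

Note that when batchsize == ndim the two non-scalar layouts are ambiguous; the per-dimension interpretation wins because its branch is checked first.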
Example #3
    def test_scalar_check(self):
        expectations = [True, True, False, False, False, True, False]
        inputs = [0.0, 1, None, "123", [1, 1], torch.tensor(1), torch.tensor([1, 2])]

        for inp, exp in zip(inputs, expectations):
            with self.subTest(input=inp, expectation=exp):
                self.assertEqual(check_scalar(inp), exp)
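
The test pins down the expected contract: Python numbers and zero-dimensional tensors count as scalars, while None, strings, sequences, and one-dimensional tensors do not. A minimal implementation consistent with these expectations could look as follows (a sketch, not necessarily the library's actual version):

import torch

def check_scalar(value) -> bool:
    # python numbers are scalars
    if isinstance(value, (int, float)):
        return True
    # zero-dimensional tensors hold a single value and count as scalars too
    return torch.is_tensor(value) and value.ndimension() == 0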
Example #4
    def assemble_matrix(self, **data) -> torch.Tensor:
        """
        Handles the matrix assembly and calculates the scale factors for
        resizing

        Args:
            **data: the data to be transformed. Will be used to determine
                batchsize, dimensionality, dtype and device

        Returns:
            torch.Tensor: the (batched) transformation matrix

        """
        curr_img_size = data[self.keys[0]].shape[2:]
        output_size = self.size

        if torch.is_tensor(output_size):
            self.output_size = int(output_size.item())
        elif check_scalar(output_size):
            # plain python scalar: keep it as a single int
            self.output_size = int(output_size)
        else:
            # sequence of ints or scalar tensors
            self.output_size = tuple(int(t) for t in output_size)

        if check_scalar(output_size):
            output_size = [output_size] * len(curr_img_size)

        self.scale = [float(output_size[i]) / float(curr_img_size[i])
                      for i in range(len(curr_img_size))]
        matrix = super().assemble_matrix(**data)
        return matrix
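
A small numeric walk-through of the scale computation, assuming a hypothetical instance with self.size = 25 applied to a batch of shape [1, 1, 50, 100]:

curr_img_size = (50, 100)                # data[self.keys[0]].shape[2:]
output_size = [25] * len(curr_img_size)  # scalar size broadcast to every dim
scale = [float(output_size[i]) / float(curr_img_size[i])
         for i in range(len(curr_img_size))]
assert scale == [0.5, 0.25]              # shrink height by 2x, width by 4x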
Example #5
    def assemble_matrix(self, **data) -> torch.Tensor:
        """
        Handles the matrix assembly and calculates the scale factors for
        resizing

        Args:
            **data: the data to be transformed. Will be used to determine
                batchsize, dimensionality, dtype and device

        Returns:
            torch.Tensor: the (batched) transformation matrix

        """
        curr_img_size = data[self.keys[0]].shape[2:]

        was_scalar = check_scalar(self.output_size)

        if was_scalar:
            self.output_size = [self.output_size] * len(curr_img_size)

        self.scale = [self.output_size[i] / curr_img_size[i]
                      for i in range(len(curr_img_size))]

        matrix = super().assemble_matrix(**data)

        if was_scalar:
            self.output_size = self.output_size[0]

        return matrix
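
The same walk-through in 3D, assuming a hypothetical instance with self.output_size = 32 applied to a volume batch of shape [1, 1, 8, 16, 32]; each output dimension is matched to the input dimension with the same index:

curr_img_size = (8, 16, 32)
output_size = [32] * len(curr_img_size)
scale = [output_size[i] / curr_img_size[i] for i in range(len(curr_img_size))]
assert scale == [4.0, 2.0, 1.0]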
Example #6
def affine_image_transform(
    image_batch: torch.Tensor,
    matrix_batch: torch.Tensor,
    output_size: Optional[tuple] = None,
    adjust_size: bool = False,
    interpolation_mode: str = 'bilinear',
    padding_mode: str = 'zeros',
    align_corners: bool = False,
    reverse_order: bool = False,
) -> torch.Tensor:
    """
    Performs an affine transformation on a batch of images

    Args:
        image_batch: the batch to transform. Should have shape
            [N, C, *spatial_dims], e.g. [N, C, H, W] for 2D images
        matrix_batch: a batch of affine matrices with shape [N, NDIM, NDIM+1]
        output_size: if given, this will be the resulting image size.
            Defaults to ``None``
        adjust_size: if True, the resulting image size will be calculated
            dynamically to ensure that the whole image fits.
        interpolation_mode: interpolation mode to calculate output values
            'bilinear' | 'nearest'. Default: 'bilinear'
        padding_mode: padding mode for outside grid values
            'zeros' | 'border' | 'reflection'. Default: 'zeros'
        align_corners: Geometrically, we consider the pixels of the input as
            squares rather than points. If set to True, the extrema (-1 and 1)
            are considered as referring to the center points of the input's
            corner pixels. If set to False, they are instead considered as
            referring to the corner points of the input's corner pixels,
            making the sampling more resolution agnostic.
        reverse_order: if True, the coordinate order of ``matrix_batch`` is
            reversed (via ``matrix_revert_coordinate_order``) before the
            sampling grid is built, to match PyTorch's coordinate convention

    Returns:
        torch.Tensor: transformed image

    Warnings:
        When align_corners = True, the grid positions depend on the pixel size
        relative to the input image size, and so the locations sampled by
        grid_sample() will differ for the same input given at different
        resolutions (that is, after being upsampled or downsampled).

    Notes:
        :attr:`output_size` and :attr:`adjust_size` are mutually exclusive.
        If neither of them is set, the resulting image will have the same size
        as the input image.
    """
    # add batch dimension if necessary
    if len(matrix_batch.shape) < 3:
        matrix_batch = matrix_batch[None, ...].expand(image_batch.size(0), -1,
                                                      -1).clone()

    image_size = image_batch.shape[2:]

    if output_size is not None:
        if check_scalar(output_size):
            output_size = tuple([output_size] * matrix_batch.size(-2))

        if adjust_size:
            warnings.warn(
                "Adjust size is mutually exclusive with a "
                "given output size.", UserWarning)
        new_size = output_size
    elif adjust_size:
        new_size = tuple([
            int(tmp.item())
            for tmp in _check_new_img_size(image_size, matrix_batch)
        ])
    else:
        new_size = image_size

    if len(image_size) < len(image_batch.shape):
        missing_dims = len(image_batch.shape) - len(image_size)
        new_size = (*image_batch.shape[:missing_dims], *new_size)

    matrix_batch = matrix_batch.to(image_batch)

    if reverse_order:
        matrix_batch = matrix_revert_coordinate_order(matrix_batch)

    grid = torch.nn.functional.affine_grid(matrix_batch,
                                           size=new_size,
                                           align_corners=align_corners)

    return torch.nn.functional.grid_sample(image_batch,
                                           grid,
                                           mode=interpolation_mode,
                                           padding_mode=padding_mode,
                                           align_corners=align_corners)
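
A usage sketch (hypothetical shapes; assumes affine_image_transform is in scope): rotating a batch by 30 degrees with adjust_size=True grows the output so that no corners are cropped.

import math
import torch

images = torch.rand(4, 3, 64, 64)  # [N, C, H, W]
theta = math.radians(30.0)
rotation = torch.tensor([[math.cos(theta), -math.sin(theta), 0.0],
                         [math.sin(theta),  math.cos(theta), 0.0]])

rotated = affine_image_transform(images, rotation, adjust_size=True)
# 64 * (cos 30deg + sin 30deg) is about 87.4, so the spatial size grows to
# roughly [4, 3, 87, 87] instead of staying at [4, 3, 64, 64]
print(rotated.shape)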