def _check_new_img_size(curr_img_size, matrix: torch.Tensor,
                        zero_border: bool = False) -> torch.Tensor:
    """
    Calculates the image size so that the whole image content fits the image.
    The resulting size will be the maximum size of the batch, so that the
    images can remain batched.

    Args:
        curr_img_size: the size of the current image. If int, it will be
            used as size for all image dimensions
        matrix: a batch of affine matrices with shape [N, NDIM, NDIM+1]
        zero_border: whether or not to have a fixed image border at zero

    Returns:
        torch.Tensor: the new image size
    """
    n_dim = matrix.size(-1) - 1
    if check_scalar(curr_img_size):
        curr_img_size = [curr_img_size] * n_dim

    # transform the corner points of the image box with each matrix in the batch
    possible_points = unit_box(n_dim, torch.tensor(curr_img_size)).to(matrix)
    transformed_edges = affine_point_transform(
        possible_points[None].expand(
            matrix.size(0), *[-1 for _ in possible_points.shape]).clone(),
        matrix)

    if zero_border:
        substr = 0
    else:
        substr = transformed_edges.min(1)[0]

    # per-dimension extent of each transformed box, maximized over the batch
    return (transformed_edges.max(1)[0] - substr).max(0)[0]
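# Usage sketch (not part of the library): calling _check_new_img_size with a
# batch containing an identity and a 90-degree rotation matrix. It assumes the
# helpers used above (check_scalar, unit_box, affine_point_transform) are in
# scope, e.g. imported from rising.
import torch

identity = torch.tensor([[1., 0., 0.],
                         [0., 1., 0.]])
rot90 = torch.tensor([[0., -1., 0.],
                      [1., 0., 0.]])
matrix_batch = torch.stack([identity, rot90])  # shape [2, 2, 3]

# the identity needs (100, 50), the rotation needs (50, 100); the batched
# result is the per-dimension maximum, i.e. approximately (100, 100)
new_size = _check_new_img_size((100, 50), matrix_batch)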
def expand_scalar_param(param: AffineParamType, batchsize: int, ndim: int) -> Tensor:
    """
    Bring affine params to shape (batchsize, ndim)

    Args:
        param: affine parameter
        batchsize: size of batch
        ndim: number of spatial dimensions

    Returns:
        torch.Tensor: affine params in correct shape
    """
    if check_scalar(param):
        # single scalar: broadcast to every batch element and dimension
        return torch.tensor([[param] * ndim] * batchsize).float()

    if not torch.is_tensor(param):
        param = torch.tensor(param)
    else:
        param = param.clone()

    if not param.ndimension() == 2:
        if param.shape[0] == ndim:
            # one scalar per dim
            param = param.reshape(1, -1).expand(batchsize, ndim)
        elif param.shape[0] == batchsize:
            # one scalar per batch element
            param = param.reshape(-1, 1).expand(batchsize, ndim)
        else:
            raise ValueError(
                "Unknown param for expanding. "
                f"Found {param} for batchsize {batchsize} and ndim {ndim}")

    assert all([i == j for i, j in zip(param.shape, (batchsize, ndim))]), \
        (f"Affine params need to have shape (batchsize, ndim) "
         f"({(batchsize, ndim)}) but found {param.shape}")

    return param.float()
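# Usage sketch (not part of the library): expand_scalar_param broadcasts a
# scalar to shape (batchsize, ndim) and expands per-dim or per-batch vectors.
import torch

scale = expand_scalar_param(2.0, batchsize=4, ndim=3)
# -> tensor of shape (4, 3) filled with 2.0

per_dim = expand_scalar_param(torch.tensor([1.0, 2.0]), batchsize=3, ndim=2)
# one value per spatial dimension, repeated for every batch element -> shape (3, 2)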
def test_scalar_check(self):
    expectations = [True, True, False, False, False, True, False]
    inputs = [0.0, 1, None, "123", [1, 1], torch.tensor(1), torch.tensor([1, 2])]

    for inp, exp in zip(inputs, expectations):
        with self.subTest(input=inp, expectation=exp):
            self.assertEqual(check_scalar(inp), exp)
def assemble_matrix(self, **data) -> torch.Tensor:
    """
    Handles the matrix assembly and calculates the scale factors for resizing

    Args:
        **data: the data to be transformed. Will be used to determine
            batchsize, dimensionality, dtype and device

    Returns:
        torch.Tensor: the (batched) transformation matrix
    """
    curr_img_size = data[self.keys[0]].shape[2:]

    # normalize ``self.size`` (a scalar tensor or a sequence of tensors)
    # to plain python ints for ``self.output_size``
    output_size = self.size
    if torch.is_tensor(output_size):
        self.output_size = int(output_size.item())
    else:
        self.output_size = tuple(int(t.item()) for t in output_size)

    if check_scalar(output_size):
        output_size = [output_size] * len(curr_img_size)

    # per-dimension scale factor: target size / current size
    self.scale = [float(output_size[i]) / float(curr_img_size[i])
                  for i in range(len(curr_img_size))]

    matrix = super().assemble_matrix(**data)
    return matrix
def assemble_matrix(self, **data) -> torch.Tensor:
    """
    Handles the matrix assembly and calculates the scale factors for resizing

    Args:
        **data: the data to be transformed. Will be used to determine
            batchsize, dimensionality, dtype and device

    Returns:
        torch.Tensor: the (batched) transformation matrix
    """
    curr_img_size = data[self.keys[0]].shape[2:]
    was_scalar = check_scalar(self.output_size)

    if was_scalar:
        self.output_size = [self.output_size] * len(curr_img_size)

    # per-dimension scale factor: target size / current size
    self.scale = [self.output_size[i] / curr_img_size[i]
                  for i in range(len(curr_img_size))]

    matrix = super().assemble_matrix(**data)

    # restore the user-facing scalar value
    if was_scalar:
        self.output_size = self.output_size[0]

    return matrix
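# Minimal sketch of the scale-factor computation performed above, outside the
# transform class (the sizes here are made up for illustration): a scalar
# output size is broadcast per spatial dimension and each scale factor is
# output size / current size.
curr_img_size = (64, 128)
output_size = [32] * len(curr_img_size)
scale = [output_size[i] / curr_img_size[i] for i in range(len(curr_img_size))]
# scale == [0.5, 0.25]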
def affine_image_transform(
        image_batch: torch.Tensor,
        matrix_batch: torch.Tensor,
        output_size: Optional[tuple] = None,
        adjust_size: bool = False,
        interpolation_mode: str = 'bilinear',
        padding_mode: str = 'zeros',
        align_corners: bool = False,
        reverse_order: bool = False,
) -> torch.Tensor:
    """
    Performs an affine transformation on a batch of images

    Args:
        image_batch: the batch to transform. Should have shape of
            [N, C, *spatial_dims]
        matrix_batch: a batch of affine matrices with shape [N, NDIM, NDIM+1]
        output_size: if given, this will be the resulting image size.
            Defaults to ``None``
        adjust_size: if True, the resulting image size will be calculated
            dynamically to ensure that the whole image fits.
        interpolation_mode: interpolation mode to calculate output values
            'bilinear' | 'nearest'. Default: 'bilinear'
        padding_mode: padding mode for outside grid values
            'zeros' | 'border' | 'reflection'. Default: 'zeros'
        align_corners: Geometrically, we consider the pixels of the input as
            squares rather than points. If set to True, the extrema (-1 and 1)
            are considered as referring to the center points of the input's
            corner pixels. If set to False, they are instead considered as
            referring to the corner points of the input's corner pixels,
            making the sampling more resolution agnostic.
        reverse_order: if True, the coordinate order of the matrix is
            reverted before sampling

    Returns:
        torch.Tensor: transformed image

    Warnings:
        When align_corners = True, the grid positions depend on the pixel
        size relative to the input image size, and so the locations sampled
        by grid_sample() will differ for the same input given at different
        resolutions (that is, after being upsampled or downsampled).

    Notes:
        :attr:`output_size` and :attr:`adjust_size` are mutually exclusive.
        If none of them is set, the resulting image will have the same size
        as the input image.
    """
    # add batch dimension if necessary
    if len(matrix_batch.shape) < 3:
        matrix_batch = matrix_batch[None, ...].expand(
            image_batch.size(0), -1, -1).clone()

    image_size = image_batch.shape[2:]

    if output_size is not None:
        if check_scalar(output_size):
            output_size = tuple([output_size] * matrix_batch.size(-2))

        if adjust_size:
            warnings.warn(
                "Adjust size is mutually exclusive with a "
                "given output size.", UserWarning)

        new_size = output_size
    elif adjust_size:
        new_size = tuple([
            int(tmp.item())
            for tmp in _check_new_img_size(image_size, matrix_batch)
        ])
    else:
        new_size = image_size

    # prepend batch and channel dimensions for affine_grid
    if len(image_size) < len(image_batch.shape):
        missing_dims = len(image_batch.shape) - len(image_size)
        new_size = (*image_batch.shape[:missing_dims], *new_size)

    matrix_batch = matrix_batch.to(image_batch)

    if reverse_order:
        matrix_batch = matrix_revert_coordinate_order(matrix_batch)

    grid = torch.nn.functional.affine_grid(matrix_batch, size=new_size,
                                           align_corners=align_corners)

    return torch.nn.functional.grid_sample(image_batch, grid,
                                           mode=interpolation_mode,
                                           padding_mode=padding_mode,
                                           align_corners=align_corners)
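# Usage sketch (not part of the library): rotating a batch of 2D images by 45
# degrees with affine_image_transform. The matrices follow the [N, NDIM, NDIM+1]
# layout expected by torch.nn.functional.affine_grid (normalized coordinates).
import math
import torch

images = torch.rand(2, 1, 64, 64)  # [N, C, H, W]
angle = math.radians(45)
rot = torch.tensor([[math.cos(angle), -math.sin(angle), 0.],
                    [math.sin(angle), math.cos(angle), 0.]])
matrices = rot[None].repeat(2, 1, 1)  # [N, 2, 3]

rotated = affine_image_transform(images, matrices)  # same spatial size as input
# adjust_size=True grows the output so the rotated content fits completely
enlarged = affine_image_transform(images, matrices, adjust_size=True)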