def __init__(
    self,
    spatial_size: Sequence[int],
    rand_size: Sequence[int],
    pad: int = 0,
    field_mode: Union[InterpolateMode, str] = InterpolateMode.AREA,
    align_corners: Optional[bool] = None,
    prob: float = 0.1,
    def_range: Union[Sequence[float], float] = 1.0,
    grid_dtype=torch.float32,
    grid_mode: Union[GridSampleMode, str] = GridSampleMode.NEAREST,
    grid_padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
    grid_align_corners: Optional[bool] = False,
    device: Optional[torch.device] = None,
):
    """Set up the smooth random field and the coordinate grid it is applied to.

    Args:
        spatial_size: forwarded to ``SmoothField``; also the per-axis size of
            the stored coordinate grid (unless it is None, in which case the
            field's own spatial shape is used).
        rand_size: forwarded to ``SmoothField``; its length is also used as the
            field's channel count.
        pad: forwarded to ``SmoothField``.
        field_mode: forwarded to ``SmoothField`` as ``mode``.
        align_corners: forwarded to ``SmoothField``.
        prob: forwarded to the parent initialiser (presumably the probability
            of applying the transform -- confirm against the base class).
        def_range: a single number ``r`` is turned into ``(-r, r)``; a pair is
            reordered into ``(min, max)``. The result is stored as
            ``self.def_range`` and passed to ``SmoothField`` as ``low``/``high``.
        grid_dtype: dtype of the stored coordinate grid.
        grid_mode: stored on the instance; not used in this initialiser.
        grid_padding_mode: stored on the instance; not used in this initialiser.
        grid_align_corners: stored on the instance; not used in this initialiser.
        device: device for ``SmoothField`` and for the stored coordinate grid.

    Raises:
        ValueError: if ``def_range`` is not an int/float and its length is not 2.
    """
    super().__init__(prob)

    self.device = device
    self.grid_dtype = grid_dtype
    self.grid_mode = grid_mode
    self.grid_padding_mode = grid_padding_mode
    self.grid_align_corners = grid_align_corners

    # Normalise def_range into an ordered (low, high) pair.
    if not isinstance(def_range, (int, float)):
        if len(def_range) != 2:
            raise ValueError("Argument `def_range` should be a number or pair of numbers.")
        low, high = min(def_range), max(def_range)
    else:
        low, high = -def_range, def_range
    self.def_range = (low, high)

    self.sfield = SmoothField(
        spatial_size=spatial_size,
        rand_size=rand_size,
        pad=pad,
        low=low,
        high=high,
        channels=len(rand_size),
        mode=field_mode,
        align_corners=align_corners,
        device=device,
    )

    # Evenly spaced coordinates over [-1, 1] on every spatial axis, stacked to
    # shape (1, num_spatial_dims, *grid_space). Presumably this is the identity
    # grid in normalised grid-sampling coordinates -- confirm with the caller.
    if spatial_size is None:
        grid_space = self.sfield.field.shape[2:]
    else:
        grid_space = spatial_size
    axes = [torch.linspace(-1, 1, n) for n in grid_space]
    coords = torch.stack(meshgrid_ij(*axes))
    self.grid = coords[None].to(self.device, self.grid_dtype)
def make_grid(shape, dtype=None, device=None, requires_grad=True):
    """Return an index grid for ``shape`` with a leading batch axis of size 1.

    Each axis ``i`` contributes ``torch.arange(float(shape[i]))``, i.e. the
    values ``0, 1, ..., shape[i] - 1``. The per-axis coordinate tensors are
    combined with ``meshgrid_ij`` and stacked along the LAST dimension.

    Args:
        shape: iterable of axis sizes, each convertible with ``float()``.
        dtype: dtype for the ranges. When None, ``torch.arange`` with a float
            end uses torch's default floating dtype.
        device: device on which the ranges are created.
        requires_grad: passed to ``torch.arange`` for every per-axis range.

    Returns:
        Tensor of shape ``(1, *shape, len(shape))``.
    """
    axes = tuple(
        torch.arange(float(n), dtype=dtype, device=device, requires_grad=requires_grad)
        for n in shape
    )
    return torch.stack(meshgrid_ij(*axes), dim=-1).unsqueeze(0)
def get_reference_grid(image_size: Union[Tuple[int, ...], List[int]]) -> torch.Tensor:
    """Return a float grid of voxel indices for an image of size ``image_size``.

    Args:
        image_size: one int per spatial dimension, of any length (for example
            ``(H, W)`` or ``(H, W, D)``). The previous annotation ``Tuple[int]``
            only admitted 1-tuples.

    Returns:
        ``torch.float`` (float32) tensor of shape
        ``(len(image_size), *image_size)``. Slice ``i`` holds the integer index
        along spatial axis ``i``, as produced by ``meshgrid_ij`` (matrix-style
        "ij" indexing, per its name).
    """
    mesh_points = [torch.arange(0, dim) for dim in image_size]
    grid = torch.stack(meshgrid_ij(*mesh_points), dim=0)  # (spatial_dims, ...)
    return grid.to(dtype=torch.float)
def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor:
    """Return a batched grid of voxel indices matching the shape of ``ddf``.

    Args:
        ddf: tensor whose axis 0 is the batch and whose axes ``2:`` are the
            spatial dimensions. Only its shape, dtype and device are read.
            Axis 1 is ignored here (presumably one channel per spatial dim --
            confirm with the caller).

    Returns:
        Tensor of shape ``(batch, num_spatial_dims, *ddf.shape[2:])`` with the
        same dtype and device as ``ddf``. ``grid[b, i]`` holds the integer index
        along spatial axis ``i``, as produced by ``meshgrid_ij``. Every batch
        entry is an independent contiguous copy, and an empty batch yields an
        empty tensor.
    """
    spatial_shape = ddf.shape[2:]
    # Create the ranges directly on ddf's device. This avoids building the
    # (batch-replicated) grid on the host and then copying it across.
    mesh_points = [torch.arange(0, dim, device=ddf.device) for dim in spatial_shape]
    grid = torch.stack(meshgrid_ij(*mesh_points), dim=0)  # (spatial_dims, ...)
    # Cast once, before replicating, to match ddf's dtype (and device).
    grid = grid.to(ddf)
    # repeat() copies like stacking a list of references would, but it also
    # accepts a batch of 0, where torch.stack([]) raises.
    grid = grid.unsqueeze(0).repeat(ddf.shape[0], *([1] * grid.dim()))  # (batch, spatial_dims, ...)
    return grid