def get_grid_dict(patch_size: int = 32) -> Dict[str, torch.Tensor]:
    r"""Build the cartesian and polar coordinate maps of a square patch grid.

    A normalized meshgrid of ``patch_size`` by ``patch_size`` is created, its
    two coordinate channels are split apart, and the pair is converted to
    polar form with ``cart2pol``.

    Args:
        patch_size: side length passed as both ``height`` and ``width`` to
            ``create_meshgrid``.

    Returns:
        A dictionary with the keys ``'x'``, ``'y'``, ``'rho'`` and ``'phi'``,
        in that order. ``'x'`` and ``'y'`` are the two channels of the first
        grid element; ``'rho'`` and ``'phi'`` are the two values returned by
        ``cart2pol(x, y)``.
    """
    mesh = create_meshgrid(
        height=patch_size,
        width=patch_size,
        normalized_coordinates=True,
    )
    # Drop the leading dimension and separate the two coordinate channels
    # stored along the last dimension.
    first = mesh[0]
    x_coords = first[:, :, 0]
    y_coords = first[:, :, 1]

    radius, angle = cart2pol(x_coords, y_coords)
    return {'x': x_coords, 'y': y_coords, 'rho': radius, 'phi': angle}
def forward(self, x: torch.Tensor) -> torch.Tensor:
    r"""Convert the negated gradients of ``x`` to polar form.

    Args:
        x: input tensor. It must be a ``torch.Tensor`` with exactly four
            dimensions. The error message states ``Bx1xHxW``, but only the
            number of dimensions is verified here.

    Returns:
        The two tensors returned by ``cart2pol`` for the gradient components,
        concatenated along dimension 1.

    Raises:
        TypeError: if ``x`` is not a ``torch.Tensor``.
        ValueError: if ``x`` does not have four dimensions.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(x)}")
    if x.dim() != 4:
        raise ValueError(
            f"Invalid input shape, we expect Bx1xHxW. Got: {x.shape}")

    # The sign of the 'diff' gradient is flipped inline; an earlier version
    # used a lambda for this, which could not be scripted by TorchScript.
    neg_grads = -self.grad(x)

    # Dimension 2 of the gradient tensor indexes the two components.
    grad_x = neg_grads[:, :, 0, :, :]
    grad_y = neg_grads[:, :, 1, :, :]

    rho, phi = cart2pol(grad_x, grad_y, self.eps)
    return torch.cat([rho, phi], dim=1)
def __init__(self, patch_size: int = 32, relative: bool = False) -> None:
    r"""Set up the theta kernel and the precomputed grid angle buffer.

    Args:
        patch_size: side length of the patch; used for the kernel and for
            the meshgrid from which the ``phi`` buffer is derived.
        relative: flag stored on the instance as ``self.relative``; it is not
            otherwise used inside this constructor.
    """
    super().__init__()
    self.patch_size = patch_size
    self.relative = relative
    self.eps = 1e-8

    # Von Mises kernel applied to gradient orientations, configured with the
    # 'theta' coefficient set.
    self.kernel = VonMisesKernel(
        patch_size=patch_size,
        coeffs=COEFFS['theta'],
    )

    # Polar angle of every position of the normalized patch grid, kept as a
    # buffer (for the relative-gradient mode, per the original comment).
    mesh = create_meshgrid(
        height=patch_size,
        width=patch_size,
        normalized_coordinates=True,
    )
    grid_x = mesh[:, :, :, 0]
    grid_y = mesh[:, :, :, 1]
    _, grid_phi = cart2pol(grid_x, grid_y)
    self.register_buffer('phi', grid_phi)