def _lowpassfilter(size: Tuple[int, int], cutoff: float, n: int) -> torch.Tensor:
    r"""Constructs a low-pass Butterworth filter.

    Args:
        size: Tuple with height and width of filter to construct.
        cutoff: Cutoff frequency of the filter in (0, 0.5].
        n: Filter order. Higher `n` means sharper transition.
            Note that `n` is doubled so that it is always an even integer.

    Returns:
        Filter with values ``f = 1 / (1 + (w / cutoff) ^ (2n))``, where ``w`` is the
        radius of each grid point returned by ``get_meshgrid(size)``.

    Raises:
        AssertionError: if ``cutoff`` is outside (0, 0.5] or ``n`` is not an integer >= 1.

    Note:
        The frequency origin of the returned filter is at the corners.
    """
    assert 0 < cutoff <= 0.5, "Cutoff frequency must be between 0 and 0.5"
    # Any positive integer order gives a valid Butterworth filter, so n == 1 is accepted,
    # in agreement with the assertion message.
    assert n >= 1 and int(n) == n, "n must be an integer >= 1"

    grid_x, grid_y = get_meshgrid(size)

    # A matrix with every pixel = radius relative to centre.
    radius = torch.sqrt(grid_x ** 2 + grid_y ** 2)

    # Shift so that the zero frequency ends up at the corners.
    return ifftshift(1. / (1.0 + (radius / cutoff) ** (2 * n)))
def sdsp(x: torch.Tensor, data_range: Union[int, float] = 255, omega_0: float = 0.021, sigma_f: float = 1.34,
         sigma_d: float = 145., sigma_c: float = 0.001) -> torch.Tensor:
    r"""SDSP algorithm for salient region detection from a given image.

    Supports only colour images with RGB channel order.

    Args:
        x: Tensor. Shape :math:`(N, 3, H, W)`.
        data_range: Maximum value range of images (usually 1.0 or 255).
        omega_0: coefficient for log Gabor filter
        sigma_f: coefficient for log Gabor filter
        sigma_d: coefficient for the central areas, which have a bias towards attention
        sigma_c: coefficient for the warm colors, which have a bias towards attention

    Returns:
        torch.Tensor: Visual saliency map of shape :math:`(N, 1, H, W)`, min-max normalised
        over the spatial dimensions of each sample.
    """
    # Rescale to a [0, 255] range and work at a fixed 256x256 resolution.
    x = x / data_range * 255
    size = x.size()
    size_to_use = (256, 256)
    x = interpolate(input=x, size=size_to_use, mode='bilinear', align_corners=False)

    x_lab = rgb2lab(x, data_range=255)

    # Frequency prior: filter every Lab channel with a log Gabor filter in the Fourier
    # domain and take the magnitude of the real responses across channels.
    lg = _log_gabor(size_to_use, omega_0, sigma_f).to(x_lab)
    if hasattr(torch, 'rfft'):
        # Legacy FFT API (removed in PyTorch 1.8): real and imaginary parts are stored in
        # a trailing dimension of size 2, so the filter gets a matching broadcast dim.
        x_fft = torch.rfft(x_lab, 2, onesided=False)
        x_ifft_real = torch.ifft(x_fft * lg.view(1, 1, *size_to_use, 1), 2)[..., 0]
    else:
        # torch.fft module API: complex tensors. The default ("backward") normalisation
        # matches the legacy calls: unscaled forward, 1 / (H * W) on the inverse.
        x_fft = torch.fft.fft2(x_lab)
        x_ifft_real = torch.fft.ifft2(x_fft * lg.view(1, 1, *size_to_use)).real
    s_f = x_ifft_real.pow(2).sum(dim=1, keepdim=True).sqrt()

    # Location prior: a Gaussian of the (scaled, shifted) meshgrid coordinates.
    coordinates = torch.stack(get_meshgrid(size_to_use), dim=0).to(x)
    coordinates = coordinates * size_to_use[0] + 1
    s_d = torch.exp(-torch.sum(coordinates ** 2, dim=0) / sigma_d ** 2).view(1, 1, *size_to_use)

    # Colour prior: min-max normalise each Lab channel spatially, then use channels 1:
    # (presumably a* and b* -- depends on rgb2lab's channel order).
    eps = torch.finfo(x_lab.dtype).eps
    min_x = x_lab.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values
    max_x = x_lab.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values
    normalized = (x_lab - min_x) / (max_x - min_x + eps)

    norm = normalized[:, 1:].pow(2).sum(dim=1, keepdim=True)
    s_c = 1 - torch.exp(-norm / sigma_c ** 2)

    # Combine the three priors, return to the input resolution and min-max normalise.
    vs_m = s_f * s_d * s_c
    vs_m = interpolate(vs_m, size[-2:], mode='bilinear', align_corners=True)
    min_vs_m = vs_m.min(dim=-1, keepdim=True).values.min(dim=-2, keepdim=True).values
    max_vs_m = vs_m.max(dim=-1, keepdim=True).values.max(dim=-2, keepdim=True).values
    return (vs_m - min_vs_m) / (max_vs_m - min_vs_m + eps)
def _log_gabor(size: Tuple[int, int], omega_0: float, sigma_f: float) -> torch.Tensor:
    r"""Build a log Gabor filter in the frequency domain.

    Args:
        size: Size of the requested log Gabor filter.
        omega_0: Center frequency of the filter.
        sigma_f: Bandwidth of the filter.

    Returns:
        Log Gabor filter. Points whose meshgrid radius exceeds 0.5 and the [0, 0]
        entry (after ``ifftshift``) have zero response.
    """
    grid_x, grid_y = get_meshgrid(size)
    dist = (grid_x ** 2 + grid_y ** 2).sqrt()

    # Radii beyond 0.5 are set to 0; log(0) = -inf then drives their response to exp(-inf) = 0.
    dist = torch.where(dist <= 0.5, dist, torch.zeros_like(dist))
    dist = ifftshift(dist)

    # Temporarily make the [0, 0] entry 1 so its logarithm is finite; it is zeroed below.
    dist[0, 0] = 1

    log_ratio = torch.log(dist / omega_0)
    response = torch.exp(-log_ratio ** 2 / (2 * sigma_f ** 2))
    response[0, 0] = 0
    return response
def _construct_filters(x: torch.Tensor, scales: int = 4, orientations: int = 4, min_length: int = 6,
                       mult: int = 2, sigma_f: float = 0.55, delta_theta: float = 1.2, k: float = 2.0):
    """Creates a stack of filters used for computation of phase congruency maps.

    Args:
        x: Tensor with shape (N, 1, H, W). Only its spatial size, dtype and device are used.
        scales: Number of wavelets
        orientations: Number of filter orientations
        min_length: Wavelength of smallest scale filter
        mult: Scaling factor between successive filters
        sigma_f: Ratio of the standard deviation of the Gaussian
            describing the log Gabor filter's transfer function
            in the frequency domain to the filter center frequency.
        delta_theta: Ratio of angular interval between filter orientations
            and the standard deviation of the angular Gaussian function
            used to construct filters in the freq. plane.
        k: No of standard deviations of the noise energy beyond the mean
            at which we set the noise threshold point, below which phase
            congruency values get penalized. Not used when building the filters;
            kept for interface compatibility.

    Returns:
        Tensor of shape (1, orientations * scales, H, W) with the same dtype and device
        as ``x``. Channel ``o * scales + s`` is the product of the angular component for
        orientation ``o`` and the radial component for scale ``s``. The zero frequency
        is at the top-left corner.
    """
    _, _, H, W = x.shape

    # Standard deviation of the angular Gaussian function used to construct
    # filters in the freq. plane.
    theta_sigma = math.pi / (orientations * delta_theta)

    # Pre-compute some stuff to speed up filter construction.
    grid_x, grid_y = get_meshgrid((H, W))
    radius = torch.sqrt(grid_x ** 2 + grid_y ** 2)
    theta = torch.atan2(-grid_y, grid_x)

    # Quadrant shift radius and theta so that filters are constructed with 0 frequency at the corners.
    # Get rid of the 0 radius value at the 0 frequency point (now at top-left corner)
    # so that taking the log of the radius will not cause trouble.
    radius = ifftshift(radius)
    theta = ifftshift(theta)
    radius[0, 0] = 1

    sintheta = torch.sin(theta)
    costheta = torch.cos(theta)

    # Filters are constructed in terms of two components.
    # 1) The radial component, which controls the frequency band that the filter responds to
    # 2) The angular component, which controls the orientation that the filter responds to.
    # The two components are multiplied together to construct the overall filter.

    # First construct a low-pass filter that is as large as possible, yet falls
    # away to zero at the boundaries. All log Gabor filters are multiplied by
    # this to ensure no extra frequencies at the 'corners' of the FFT are
    # incorporated, as this seems to upset the normalisation process when
    # calculating phase congruency.
    lp = _lowpassfilter(size=(H, W), cutoff=.45, n=15)

    # Radial components: one log Gabor filter per scale, band-limited by `lp`.
    log_gabor_denominator = 2 * math.log(sigma_f) ** 2
    log_gabor = []
    for s in range(scales):
        wavelength = min_length * mult ** s
        omega_0 = 1.0 / wavelength
        log_gabor.append(torch.exp((-torch.log(radius / omega_0) ** 2) / log_gabor_denominator) * lp)
    log_gabor = torch.stack(log_gabor)
    # Zero response at the zero-frequency (top-left) point for every scale.
    log_gabor[:, 0, 0] = 0

    # Angular components: one per orientation.
    spread = []
    for o in range(orientations):
        angl = o * math.pi / orientations
        # For each point in the filter matrix calculate the angular distance from
        # the specified filter orientation. To overcome the angular wrap-around
        # problem sine difference and cosine difference values are first computed
        # and then the atan2 function is used to determine angular distance.
        ds = sintheta * math.cos(angl) - costheta * math.sin(angl)  # Difference in sine.
        dc = costheta * math.cos(angl) + sintheta * math.sin(angl)  # Difference in cosine.
        dtheta = torch.abs(torch.atan2(ds, dc))
        spread.append(torch.exp((-dtheta ** 2) / (2 * theta_sigma ** 2)))
    spread = torch.stack(spread)

    # Pair every orientation with every scale (orientation-major order), add a batch
    # dimension and match the dtype/device of `x`.
    filters = (spread.repeat_interleave(scales, dim=0) *
               log_gabor.repeat(orientations, 1, 1)).unsqueeze(0).to(x)
    return filters