def test_plot_gaussian_function():
    """Plot a 14x14 Gaussian-like map rendered around one random keypoint."""
    mean, spread = 0.5, 0.4
    # Draw a single keypoint scattered around `mean`. This spread is independent
    # of the sigma handed to gaussian_like_function below.
    keypoint = torch.randn(1, 1, 2, requires_grad=True) * spread + mean
    surface = MF.gaussian_like_function(keypoint, 14, 14, sigma=0.1)
    surface = surface.squeeze().detach().numpy()
    # For a 3-D view instead: plot_heightmap3d(surface)
    plot_heatmap2d(surface)
def test_gaussian_like_batch():
    """Print Gaussian-like maps rendered from the keypoints of ``heatmap_batch``."""
    # NOTE(review): heatmap_batch is not defined here; presumably a module-level
    # fixture elsewhere in this file — confirm.
    keypoints = knn.SpatialSoftmax()(heatmap_batch)
    rendered = MF.gaussian_like_function(keypoints, 5, 5)
    print('')
    print(rendered)
def test_bottlneck_grads():
    """Backprop through spatial softmax -> Gaussian map and print the gradients.

    Prints the gradient reaching the input heatmap, then the gradient(s) that
    reached the intermediate keypoint tensor(s).
    """
    heatmap = torch.rand(1, 1, 5, 5, requires_grad=True)
    h = heatmap.neg()
    spatial_softmax = knn.SpatialSoftmax()
    kp = spatial_softmax(h)
    # kp is produced by autograd ops, so it is a non-leaf tensor and its .grad
    # stays None unless retain_grad() is called before backward(). The result
    # may be a tuple of coordinate tensors or a single tensor, so handle both.
    # (Indexing a single tensor would give fresh views outside the loss graph,
    # and kp[1] would be out of range for a batch of one.)
    kp_parts = tuple(kp) if isinstance(kp, (tuple, list)) else (kp,)
    for part in kp_parts:
        part.retain_grad()
    gaussian = MF.gaussian_like_function(kp, 5, 5)
    loss = torch.sum(gaussian)
    loss.backward()
    print(heatmap.grad)
    print(*(part.grad for part in kp_parts))
def test_gaussian_function_grads():
    """Backprop a summed Gaussian-like map into separate x / y coordinate tensors."""
    x = torch.rand(1, 5, requires_grad=True)
    y = torch.rand(1, 5, requires_grad=True)
    # Negating makes the keypoints non-leaf tensors derived from the leaves x, y.
    kp = (x.neg(), y.neg())
    gaussian = MF.gaussian_like_function(kp, 5, 5)
    gaussian.sum().backward()
    print(gaussian)
    print(x, y)
    print(x.grad, y.grad)
def test_co_ords():
    """Locate a single hot pixel via spatial softmax and render its Gaussian."""
    height = width = 16
    hm = torch.zeros(1, 1, height, width)
    # One strong activation at row 0, last column (index 15).
    hm[0, 0, 0, width - 1] = 20.0
    k, _probs = MF.spacial_softmax(hm, probs=True)
    g = MF.gaussian_like_function(k, height, width)
    # Visual checks, currently disabled:
    #   plot_heightmap3d(hm[0, 0].detach().numpy())
    #   plot_heightmap3d(g[0, 0].detach().numpy(), k[0, 0])
    #   plot_single_channel(hm[0, 0])
    #   plot_single_channel(g[0, 0])
    # NOTE(review): the viewer is constructed but never used.
    d = UniImageViewer()
def test_align_kp_with_gaussian():
    """Overlay spatial-softmax keypoints on a heatmap and plot the matching Gaussian."""
    hm = heatmap()
    keypoints = knn.SpatialSoftmax()(hm)
    gauss = (
        MF.gaussian_like_function(keypoints, 5, 5, sigma=0.1)
        .squeeze()
        .detach()
        .numpy()
    )
    overlay = plot_keypoints_on_image(keypoints[0], hm[0])
    plt.imshow(overlay)
    plot_heatmap2d(gauss)
def test_gaussian_like():
    """Print the Gaussian-like map for a 5x5 heatmap with a single central peak."""
    grid = [[0] * 5 for _ in range(5)]
    grid[2][2] = 5  # lone peak in the centre cell
    heatmap = torch.tensor(grid).expand(1, 1, 5, 5).float()
    keypoints = knn.SpatialSoftmax()(heatmap)
    hm = MF.gaussian_like_function(keypoints, 5, 5)
    print('')
    print(hm)
def __getitem__(self, item):
    """Return a freshly randomised ``(pointmap, mask)`` pair; ``item`` is ignored."""
    # New random keypoints on every call, so the sample does not depend on `item`.
    keypoints = torch.rand(1, self.keypoints, 2)
    pointmap = MF.point_map(keypoints, self.height, self.width)
    gaussians = MF.gaussian_like_function(keypoints, self.height, self.width)
    # Collapse dim 1 (presumably one channel per keypoint — confirm) into a
    # single mask channel via the per-pixel maximum.
    mask = gaussians.max(dim=1, keepdim=True).values
    return pointmap.squeeze(0), mask.squeeze(0)