def test_softmax_rgb_blend(self):
    """Check ``softmax_rgb_blend`` against ``softmax_blend_naive``.

    Builds dummy rasterization outputs simulating a cube in the centre of
    the image, surrounded by padded background pixels (``pix_to_face == -1``).
    It then checks that both blending implementations agree on the blended
    images and on the gradients w.r.t. ``dists`` and ``zbuf``.
    """
    N, S, K = 1, 8, 2

    # Background pixels get face index -1; the central (S/2 x S/2) window
    # [s:e, s:e] is filled with random non-negative face indices.
    pix_to_face = -torch.ones((N, S, S, K), dtype=torch.int64)
    h = S // 2
    pix_to_face_full = torch.randint(size=(N, h, h, K), low=0, high=100)
    s = S // 4
    e = int(0.75 * S)
    pix_to_face[:, s:e, s:e, :] = pix_to_face_full
    bary_coords = torch.ones((N, S, S, K, 3))

    # Pure +/-1 sign per entry, used to randomly flip the sign of the
    # distance: (-) means inside the triangle, (+) means outside.
    random_sign_flip = torch.ones((N, S, S, K))
    random_sign_flip[torch.rand((N, S, S, K)) > 0.5] = -1.0
    zbuf1 = torch.randn(size=(N, S, S, K))
    dists1 = torch.randn(size=(N, S, S, K)) * random_sign_flip

    # Independent leaf copies so each implementation accumulates its own
    # gradients, which are compared at the end.
    dists2 = dists1.clone()
    zbuf2 = zbuf1.clone()
    dists1.requires_grad = True
    dists2.requires_grad = True
    zbuf1.requires_grad = True
    zbuf2.requires_grad = True

    colors = torch.randn_like(bary_coords)
    fragments1 = Fragments(
        pix_to_face=pix_to_face,
        bary_coords=bary_coords,  # dummy
        zbuf=zbuf1,
        dists=dists1,
    )
    fragments2 = Fragments(
        pix_to_face=pix_to_face,
        bary_coords=bary_coords,  # dummy
        zbuf=zbuf2,
        dists=dists2,
    )

    blend_params = BlendParams(sigma=1e-1)
    images = softmax_rgb_blend(colors, fragments1, blend_params)
    images_naive = softmax_blend_naive(colors, fragments2, blend_params)
    self.assertTrue(torch.allclose(images, images_naive))

    # Check gradients. A tensor's ``.grad`` attribute always exists (it is
    # None until a backward pass fills it), so hasattr() would be vacuous;
    # assert that gradients were actually produced instead.
    images.sum().backward()
    self.assertIsNotNone(dists1.grad)
    self.assertIsNotNone(zbuf1.grad)
    images_naive.sum().backward()
    self.assertIsNotNone(dists2.grad)
    self.assertIsNotNone(zbuf2.grad)
    self.assertTrue(torch.allclose(dists1.grad, dists2.grad, atol=2e-5))
    self.assertTrue(torch.allclose(zbuf1.grad, zbuf2.grad, atol=2e-5))
def fn():
    """Run one forward and backward pass of ``softmax_rgb_blend``.

    ``colors``, ``fragments`` and ``blend_params`` are free variables,
    presumably bound by an enclosing benchmark setup not visible here
    (TODO confirm).
    """
    # test forward and backward pass
    images = softmax_rgb_blend(colors, fragments, blend_params)
    images.sum().backward()
    # Wait for queued CUDA work to finish (presumably so a timing harness
    # measures completed kernels). torch.cuda.synchronize() raises on builds
    # without CUDA, so only call it when CUDA is available.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
    """Softmax-blend the depth buffer and return one channel of the result.

    ``fragments.zbuf`` is replicated into three identical colour channels
    and blended with ``softmax_rgb_blend`` using ``self.blend_params``.

    Args:
        fragments: Rasterizer output. Only ``fragments.zbuf`` is read directly
            here (presumably shaped (N, H, W, K); TODO confirm), and the whole
            object is passed on to ``softmax_rgb_blend``.
        meshes: Unused by this method.
        **kwargs: Unused by this method.

    Returns:
        ``images[..., 2]``, the third channel of the blended image.
    """
    zbuf = fragments.zbuf
    # Stack the three copies along a new trailing axis so the colour channel
    # is last, which is the same values as stacking on dim 0 and permuting it last.
    colors = torch.stack([zbuf, zbuf, zbuf], dim=-1)
    images = softmax_rgb_blend(colors, fragments, self.blend_params)
    return images[..., 2]