Example #1
0
    def test_hard_rgb_blend(self):
        """Check hard_rgb_blend: foreground color, background fill, alpha."""
        batch, height, width, faces_per_pixel = 5, 10, 10, 20
        top_faces = torch.randint(
            low=-1, high=100, size=(batch, height, width, faces_per_pixel)
        )
        barys = torch.ones((batch, height, width, faces_per_pixel, 3))
        fragments = Fragments(
            pix_to_face=top_faces,
            bary_coords=barys,
            zbuf=top_faces,  # placeholder, not used by hard blending
            dists=top_faces,  # placeholder, not used by hard blending
        )
        face_colors = torch.randn((batch, height, width, faces_per_pixel, 3))
        params = BlendParams(1e-4, 1e-4, (0.5, 0.5, 1))
        images = hard_rgb_blend(face_colors, fragments, params)

        # Foreground pixels must copy the color of the closest (k=0) face.
        foreground = top_faces[..., 0] >= 0
        self.assertClose(
            images[foreground][:, :3], face_colors[foreground][..., 0, :]
        )

        # Background pixels must be filled with the background color.
        for channel, expected in enumerate(params.background_color):
            self.assertTrue(
                images[~foreground][..., channel].eq(expected).all()
            )

        # Alpha is 1.0 on foreground pixels and 0.0 on background ones.
        self.assertClose(images[..., 3], foreground.float())
Example #2
0
 def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
     """Shade rasterized meshes by sampling textures and hard blending.

     Args:
         fragments: rasterizer output holding per-pixel face assignments.
         meshes: meshes providing the textures to sample.
         **kwargs: may contain "blend_params" to override self.blend_params.

     Returns:
         torch.Tensor of blended pixel colors (RGBA, alpha in channel 3).
     """
     # Allow a per-call override of the blend parameters.
     blend_params = kwargs.get("blend_params", self.blend_params)

     texels = meshes.sample_textures(fragments)
     images = hard_rgb_blend(texels, fragments, blend_params)
     return images  # (N, H, W, 4) RGBA image
Example #3
0
 def test_hard_rgb_blend(self):
     """All-foreground pixels must take the color of the closest (k=0) face."""
     batch, height, width, faces_per_pixel = 5, 10, 10, 20
     closest = torch.ones((batch, height, width, faces_per_pixel))
     barys = torch.ones((batch, height, width, faces_per_pixel, 3))
     fragments = Fragments(
         pix_to_face=closest,
         bary_coords=barys,
         zbuf=closest,  # placeholder, not used by hard blending
         dists=closest,  # placeholder, not used by hard blending
     )
     # Give every pixel the same per-face color stack.
     colors = barys.clone()
     per_face_colors = torch.randn((faces_per_pixel, 3))
     colors[..., :, :] = per_face_colors
     images = hard_rgb_blend(colors, fragments)
     # Expected: RGB taken from the k=0 face, alpha of 1 everywhere.
     expected = torch.ones((batch, height, width, 4))
     expected[..., :3] = per_face_colors[0, :]
     self.assertTrue(torch.allclose(images, expected))
Example #4
0
    def render_torch(self,
                     verts,
                     faces,
                     rgb,
                     bcg_color=(1., 1., 1.),
                     get_depth=False,
                     get_alpha=False):
        """Render per-vertex-colored meshes with hard RGB blending.

        Args:
            verts: batched vertex positions; first dim is the batch size.
            faces: face indices for the meshes.
            rgb: per-vertex colors, reshaped to (B, -1, 3).
            bcg_color: RGB background color for empty pixels.
            get_depth: if True, also return the per-pixel depth map.
            get_alpha: if True, also return a soft alpha (coverage) map.

        Returns:
            The rendered (B, 3, H, W) image, or a tuple
            (image[, depth][, alpha]) when extra outputs are requested.
        """
        b = verts.size(0)
        textures = TexturesVertex(verts_features=rgb.view(b, -1, 3))
        mesh = Meshes(verts=verts, faces=faces, textures=textures)

        fragments = self.rasterizer_torch(mesh)
        texels = mesh.sample_textures(fragments)
        # NOTE: the previous version built an unused Materials object here;
        # it was dead code and has been removed.
        blend_params = BlendParams(background_color=bcg_color)
        images = hard_rgb_blend(texels, fragments, blend_params)
        # Drop the alpha channel and move to channels-first (B, 3, H, W).
        images = images[..., :3].permute(0, 3, 1, 2)

        out = (images, )
        if get_depth:
            depth = fragments.zbuf[..., 0]
            # The rasterizer marks empty pixels with zbuf == -1; replace
            # them with a sentinel depth beyond the far plane.
            max_depth = self.max_depth + 0.5 * (self.max_depth -
                                                self.min_depth)
            depth = torch.where(depth == -1.0,
                                max_depth * torch.ones_like(depth),
                                depth)
            out = out + (depth, )
        if get_alpha:
            # Soft silhouette: sigmoid blending of all-ones colors gives a
            # per-pixel coverage value in the alpha channel.
            colors = torch.ones_like(fragments.bary_coords)
            blend_params = BlendParams(sigma=1e-2,
                                       gamma=1e-4,
                                       background_color=(1., 1., 1.))
            alpha = sigmoid_alpha_blend(colors, fragments,
                                        blend_params)[..., -1]
            out = out + (alpha, )
        if len(out) == 1:
            out = out[0]
        return out