Ejemplo n.º 1
0
    def read(
        self,
        path: PathOrStr,
        include_textures: bool,
        device,
        path_manager: PathManager,
        **kwargs,
    ) -> Optional[Meshes]:
        """Read an OFF file into a ``Meshes`` object holding a single mesh.

        Args:
            path: file to read. It is only handled if it ends with one of
                ``self.known_suffixes``.
            include_textures: when False, any vertex or face colors stored in
                the file are ignored and the mesh is returned untextured.
                Previously this flag was accepted but had no effect.
            device: device on which all returned tensors are placed.
            path_manager: used to open ``path``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The loaded mesh, or ``None`` if ``path`` does not end with a known
            suffix.
        """
        if not endswith(path, self.known_suffixes):
            return None

        with _open_file(path, path_manager, "rb") as f:
            data = _load_off_stream(f)
        verts = torch.from_numpy(data["verts"]).to(device)
        if "faces" in data:
            faces = torch.from_numpy(data["faces"]).to(dtype=torch.int64,
                                                       device=device)
        else:
            # No faces in the file: use an empty (0, 3) index tensor.
            faces = torch.zeros((0, 3), dtype=torch.int64, device=device)

        textures = None
        if include_textures:
            if "verts_colors" in data:
                # Vertex colors take precedence over face colors.
                if "faces_colors" in data:
                    msg = "Faces colors ignored because vertex colors provided too."
                    warnings.warn(msg)
                verts_colors = torch.from_numpy(data["verts_colors"]).to(device)
                textures = TexturesVertex([verts_colors])
            elif "faces_colors" in data:
                faces_colors = torch.from_numpy(data["faces_colors"]).to(device)
                # One 1x1 atlas texel per face: (F, C) -> (F, 1, 1, C).
                textures = TexturesAtlas([faces_colors[:, None, None, :]])

        # verts/faces are already on ``device``; no further transfer needed.
        return Meshes(verts=[verts], faces=[faces], textures=textures)
Ejemplo n.º 2
0
def shapenet_models(params, index: int = 0):
    """Lazily yield the ShapeNet models of ``params.category`` as meshes.

    Every sub-directory of ``<params.shapenet_path>/<synset>`` is treated as a
    model id, and its ``models/model_normalized.obj`` is loaded with a texture
    atlas. Each success is yielded as ``(mesh, model_id)``. Failures are
    printed together with the model id and skipped.

    Note: the meshes are moved to ``device``. That name is not a parameter and
    must be provided by the enclosing module scope. ``index`` is currently
    unused.
    """
    obj_relpath = "models/model_normalized.obj"
    synset_id = params.synsets[params.category]

    model_ids = os.listdir(join(params.shapenet_path, synset_id))
    obj_paths = [
        join(params.shapenet_path, synset_id, model_id, obj_relpath)
        for model_id in model_ids
    ]

    for position, (model_id, obj_path) in enumerate(zip(model_ids, obj_paths)):
        try:
            vertices, face_info, extras = load_obj(
                obj_path, load_textures=True, create_texture_atlas=True
            )
            atlas_textures = TexturesAtlas(atlas=[extras.texture_atlas])
            mesh = Meshes(
                verts=[vertices],
                faces=[face_info.verts_idx],
                textures=atlas_textures,
            ).to(device)
            print(f"Adding mesh num {position}: {model_id} ")
            yield mesh, model_id
        except Exception as err:
            # Best-effort: report the broken model and move on to the next one.
            print(err, model_id)
Ejemplo n.º 3
0
def load_objs_as_meshes(
    files: list,
    device=None,
    load_textures: bool = True,
    create_texture_atlas: bool = False,
    texture_atlas_size: int = 4,
    texture_wrap: Optional[str] = "repeat",
    path_manager: Optional[PathManager] = None,
):
    """
    Load every .obj in ``files`` with ``load_obj`` and return the result as a
    Meshes object. Only the first texture image of each mesh is used, so this
    suits meshes textured by a single image. Material colors and normals are
    discarded.

    Args:
        files: A list of file-like objects (with methods read, readline, tell,
            and seek), pathlib paths or strings containing file names.
        device: Desired device of returned Meshes. Default:
            uses the current device for the default tensor type.
        load_textures: Boolean indicating whether material files are loaded
        create_texture_atlas, texture_atlas_size, texture_wrap: as for load_obj.
        path_manager: optionally a PathManager object to interpret paths.

    Returns:
        The single Meshes object when ``files`` has one entry, otherwise the
        batch produced by ``join_meshes_as_batch``.
    """
    loaded = []
    for source in files:
        verts, faces, aux = load_obj(
            source,
            load_textures=load_textures,
            create_texture_atlas=create_texture_atlas,
            texture_atlas_size=texture_atlas_size,
            texture_wrap=texture_wrap,
            path_manager=path_manager,
        )

        textures = None
        if create_texture_atlas:
            # Per-face atlas built by load_obj.
            textures = TexturesAtlas(atlas=[aux.texture_atlas.to(device)])
        elif aux.texture_images:
            # UV mapping with the first (and assumed only) texture image.
            uv_coords = aux.verts_uvs.to(device)  # (V, 2)
            uv_faces = faces.textures_idx.to(device)  # (F, 3)
            first_map = next(iter(aux.texture_images.values()))
            textures = TexturesUV(
                verts_uvs=[uv_coords],
                faces_uvs=[uv_faces],
                maps=first_map.to(device)[None],
            )

        loaded.append(
            Meshes(
                verts=[verts.to(device)],
                faces=[faces.verts_idx.to(device)],
                textures=textures,
            )
        )

    return loaded[0] if len(loaded) == 1 else join_meshes_as_batch(loaded)
Ejemplo n.º 4
0
    def test_join_meshes_as_batch(self):
        """
        Test that join_meshes_as_batch and load_objs_as_meshes are consistent
        with single meshes.

        Covers the following cases:
        - UV-textured meshes (cow.obj) and a mismatched texture-map size.
        - Untextured meshes.
        - TexturesVertex and TexturesAtlas meshes.
        - A mixed cow/teapot batch loaded without textures.
        - Joining with include_textures=False.
        - The ValueError raised when texture types differ.

        Several steps hard-code "cuda:0", so the test needs a CUDA device.
        """
        def check_triple(mesh, mesh3):
            """
            Verify that mesh3 is three copies of mesh.

            Compares padded verts/faces and, when ``mesh`` is textured, the
            padded tensors specific to its texture class.
            """
            def check_item(x, y):
                # Either both are None, or y equals x repeated 3 times on dim 0.
                self.assertEqual(x is None, y is None)
                if x is not None:
                    self.assertClose(torch.cat([x, x, x]), y)

            check_item(mesh.verts_padded(), mesh3.verts_padded())
            check_item(mesh.faces_padded(), mesh3.faces_padded())

            if mesh.textures is not None:
                if isinstance(mesh.textures, TexturesUV):
                    check_item(
                        mesh.textures.faces_uvs_padded(),
                        mesh3.textures.faces_uvs_padded(),
                    )
                    check_item(
                        mesh.textures.verts_uvs_padded(),
                        mesh3.textures.verts_uvs_padded(),
                    )
                    check_item(mesh.textures.maps_padded(),
                               mesh3.textures.maps_padded())
                elif isinstance(mesh.textures, TexturesVertex):
                    check_item(
                        mesh.textures.verts_features_padded(),
                        mesh3.textures.verts_features_padded(),
                    )
                elif isinstance(mesh.textures, TexturesAtlas):
                    check_item(mesh.textures.atlas_padded(),
                               mesh3.textures.atlas_padded())

        # Test data lives under docs/tutorials/data, two levels above this file.
        DATA_DIR = Path(
            __file__).resolve().parent.parent / "docs/tutorials/data"
        obj_filename = DATA_DIR / "cow_mesh/cow.obj"

        # Loading three copies at once must match one copy repeated.
        mesh = load_objs_as_meshes([obj_filename])
        mesh3 = load_objs_as_meshes([obj_filename, obj_filename, obj_filename])
        check_triple(mesh, mesh3)
        self.assertTupleEqual(mesh.textures.maps_padded().shape,
                              (1, 1024, 1024, 3))

        # Try mismatched texture map sizes, which needs a call to interpolate()
        # Concatenating on dim 1 gives a 2048x1024 map. Only "does not raise"
        # is checked here, and the join runs on cuda:0.
        mesh2048 = mesh.clone()
        maps = mesh.textures.maps_padded()
        mesh2048.textures._maps_padded = torch.cat([maps, maps], dim=1)
        join_meshes_as_batch([mesh.to("cuda:0"), mesh2048.to("cuda:0")])

        # Same consistency check with textures disabled; no textures expected.
        mesh_notex = load_objs_as_meshes([obj_filename], load_textures=False)
        mesh3_notex = load_objs_as_meshes(
            [obj_filename, obj_filename, obj_filename], load_textures=False)
        check_triple(mesh_notex, mesh3_notex)
        self.assertIsNone(mesh_notex.textures)

        # meshes with vertex texture, join into a batch.
        verts = torch.randn((4, 3), dtype=torch.float32)
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
        vert_tex = torch.ones_like(verts)
        rgb_tex = TexturesVertex(verts_features=[vert_tex])
        mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=rgb_tex)
        mesh_rgb3 = join_meshes_as_batch([mesh_rgb, mesh_rgb, mesh_rgb])
        check_triple(mesh_rgb, mesh_rgb3)

        # meshes with texture atlas, join into a batch.
        # NOTE(review): the atlas is created on cuda:0 while verts/faces stay
        # on CPU; confirm this mixed-device mesh is intended.
        device = "cuda:0"
        atlas = torch.rand((2, 4, 4, 3), dtype=torch.float32, device=device)
        atlas_tex = TexturesAtlas(atlas=[atlas])
        mesh_atlas = Meshes(verts=[verts], faces=[faces], textures=atlas_tex)
        mesh_atlas3 = join_meshes_as_batch(
            [mesh_atlas, mesh_atlas, mesh_atlas])
        check_triple(mesh_atlas, mesh_atlas3)

        # Test load multiple meshes with textures into a batch.
        # A heterogeneous batch (cow + teapot) must keep each mesh's geometry.
        teapot_obj = DATA_DIR / "teapot.obj"
        mesh_teapot = load_objs_as_meshes([teapot_obj])
        teapot_verts, teapot_faces = mesh_teapot.get_mesh_verts_faces(0)
        mix_mesh = load_objs_as_meshes([obj_filename, teapot_obj],
                                       load_textures=False)
        self.assertEqual(len(mix_mesh), 2)
        self.assertClose(mix_mesh.verts_list()[0], mesh.verts_list()[0])
        self.assertClose(mix_mesh.faces_list()[0], mesh.faces_list()[0])
        self.assertClose(mix_mesh.verts_list()[1], teapot_verts)
        self.assertClose(mix_mesh.faces_list()[1], teapot_faces)

        # Joining a batch with a single mesh while dropping textures: the
        # first three entries must equal the untextured cow.
        cow3_tea = join_meshes_as_batch([mesh3, mesh_teapot],
                                        include_textures=False)
        self.assertEqual(len(cow3_tea), 4)
        check_triple(mesh_notex, cow3_tea[:3])
        self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
        self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])

        # Check error raised if all meshes in the batch don't have the same texture type
        with self.assertRaisesRegex(ValueError, "same type of texture"):
            join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
Ejemplo n.º 5
0
    def test_save_load_icosphere(self):
        # Test that saving a mesh as an off file and loading it results in the
        # same data on the correct device, for all permitted types of textures.
        # Standard test is for random colors, but also check totally white,
        # because there's a difference in OFF semantics between "1.0" color
        # (= full) and "1" (= 1/255 color).
        # Requires a CUDA device: loads are requested on cuda:0.
        sphere = ico_sphere(0)
        io = IO()
        device = torch.device("cuda:0")

        # One 1x1 atlas texel per face, random and all-white variants.
        atlas_padded = torch.rand(1, sphere.faces_list()[0].shape[0], 1, 1, 3)
        atlas = TexturesAtlas(atlas_padded)

        atlas_padded_white = torch.ones(1,
                                        sphere.faces_list()[0].shape[0], 1, 1,
                                        3)
        atlas_white = TexturesAtlas(atlas_padded_white)

        # One RGB color per vertex, random and all-white variants.
        verts_colors_padded = torch.rand(1, sphere.verts_list()[0].shape[0], 3)
        vertex_texture = TexturesVertex(verts_colors_padded)

        verts_colors_padded_white = torch.ones(1,
                                               sphere.verts_list()[0].shape[0],
                                               3)
        vertex_texture_white = TexturesVertex(verts_colors_padded_white)

        # No colors case
        with NamedTemporaryFile(mode="w", suffix=".off") as f:
            io.save_mesh(sphere, f.name)
            f.flush()
            mesh1 = io.load_mesh(f.name, device=device)
        self.assertEqual(mesh1.device, device)
        mesh1 = mesh1.cpu()
        self.assertClose(mesh1.verts_padded(), sphere.verts_padded())
        self.assertClose(mesh1.faces_padded(), sphere.faces_padded())
        self.assertIsNone(mesh1.textures)

        # Atlas case
        # (face colors round-trip through text, hence atol=1e-4)
        sphere.textures = atlas
        with NamedTemporaryFile(mode="w", suffix=".off") as f:
            io.save_mesh(sphere, f.name)
            f.flush()
            mesh2 = io.load_mesh(f.name, device=device)

        self.assertEqual(mesh2.device, device)
        mesh2 = mesh2.cpu()
        self.assertClose(mesh2.verts_padded(), sphere.verts_padded())
        self.assertClose(mesh2.faces_padded(), sphere.faces_padded())
        self.assertClose(mesh2.textures.atlas_padded(),
                         atlas_padded,
                         atol=1e-4)

        # White atlas case
        # (loaded without a device argument, so the result is compared as-is)
        sphere.textures = atlas_white
        with NamedTemporaryFile(mode="w", suffix=".off") as f:
            io.save_mesh(sphere, f.name)
            f.flush()
            mesh3 = io.load_mesh(f.name)

        self.assertClose(mesh3.textures.atlas_padded(),
                         atlas_padded_white,
                         atol=1e-4)

        # TexturesVertex case
        sphere.textures = vertex_texture
        with NamedTemporaryFile(mode="w", suffix=".off") as f:
            io.save_mesh(sphere, f.name)
            f.flush()
            mesh4 = io.load_mesh(f.name, device=device)

        self.assertEqual(mesh4.device, device)
        mesh4 = mesh4.cpu()
        self.assertClose(mesh4.verts_padded(), sphere.verts_padded())
        self.assertClose(mesh4.faces_padded(), sphere.faces_padded())
        self.assertClose(mesh4.textures.verts_features_padded(),
                         verts_colors_padded,
                         atol=1e-4)

        # white TexturesVertex case
        sphere.textures = vertex_texture_white
        with NamedTemporaryFile(mode="w", suffix=".off") as f:
            io.save_mesh(sphere, f.name)
            f.flush()
            mesh5 = io.load_mesh(f.name)

        self.assertClose(mesh5.textures.verts_features_padded(),
                         verts_colors_padded_white,
                         atol=1e-4)
Ejemplo n.º 6
0
def _load_gt_mesh(mesh_type, manager, params, device):
    """Load the ground-truth mesh used for the chamfer loss, or return None.

    "dolphin" loads a fixed OBJ path. "shapenet" builds the path from
    ``manager.metadata["mesh_info"]`` under ``params.gt_mesh_path``. Any other
    ``mesh_type``, or a ShapeNet mesh that fails to load, gives ``None``.
    """
    if mesh_type == "dolphin":
        return load_objs_as_meshes(
            ["data/meshes/dolphin/dolphin.obj"],
            create_texture_atlas=False,
            load_textures=True,
            device=device,
        )
    if mesh_type == "shapenet":
        mesh_info = manager.metadata["mesh_info"]
        # The second component must be relative: os.path.join discards every
        # earlier component when it meets an absolute path. The previous
        # leading "/" silently dropped params.gt_mesh_path.
        path = os.path.join(
            params.gt_mesh_path,
            f"ShapeNetCorev2/{mesh_info['synset_id']}/{mesh_info['mesh_id']}/models/model_normalized.obj"
        )
        try:
            verts, faces, aux = load_obj(path,
                                         load_textures=True,
                                         create_texture_atlas=True)
            return Meshes(
                verts=[verts],
                faces=[faces.verts_idx],
                textures=TexturesAtlas(atlas=[aux.texture_atlas]),
            ).to(device)
        except Exception as err:  # best-effort: chamfer loss is optional
            logging.warning("Could not load ground-truth mesh %s: %s", path,
                            err)
            print("CANNOT COMPUTE CHAMFER LOSS")
            return None
    return None


def pred_synth(segpose,
               params: Params,
               mesh_type: str = "dolphin",
               device: str = "cuda"):
    """Evaluate pose/segmentation predictions and mesh deformation per folder.

    For every (sorted) entry of ``params.pred_dir`` that ``RenderManager`` can
    open, the function performs these steps:

    1. Predicts a mask and a pose for each event frame with
       ``predict_segpose`` and stores the masks.
    2. Logs rotation/translation errors for the predicted poses and for a
       look-at variant, and stores pose plots for both.
    3. Runs ``MeshDeformationModel`` for each configured experiment (currently
       only the ground-truth pose) and stores the rendered silhouettes/images.
       It also computes the chamfer loss against the ground-truth mesh when
       one can be loaded, and the mesh IoU.
    4. Writes the results with ``manager.set_pred_results`` and closes the
       manager.

    Folders raising ``FileNotFoundError`` on load, or yielding no readable
    event frames, are skipped.

    Args:
        segpose: network passed to ``predict_segpose``.
        params: reads ``pred_dir``, ``threshold_conf``, ``img_size``,
            ``device`` (for pose plots) and ``gt_mesh_path``.
        mesh_type: "dolphin" or "shapenet"; selects the ground-truth mesh.
            Other values skip the chamfer loss.
        device: device for the mesh deformation model and ground-truth mesh.

    Raises:
        FileNotFoundError: if ``params.pred_dir`` is unset or does not exist.
    """
    # Either condition alone makes the directory unusable, so use ``or``.
    if not params.pred_dir or not os.path.exists(params.pred_dir):
        raise FileNotFoundError(
            "Prediction directory has not been set or the file does not exist, please set using cli args or params"
        )
    pred_folders = [
        join(params.pred_dir, f) for f in os.listdir(params.pred_dir)
    ]
    for p in sorted(pred_folders):
        try:
            print(p)
            manager = RenderManager.from_path(p)
            manager.rectify_paths(base_folder=params.pred_dir)
        except FileNotFoundError:
            continue

        # Run Silhouette Prediction Network
        logging.info("Starting mask predictions")
        mask_priors = []
        R_pred, T_pred = [], []
        q_loss, t_loss = 0, 0
        # Collect Translation stats
        R_gt, T_gt = manager._trajectory
        poses_gt = EvMaskPoseDataset.preprocess_poses(manager._trajectory)
        std_T, mean_T = torch.std_mean(T_gt)
        for idx in range(len(manager)):
            try:
                ev_frame = manager.get_event_frame(idx)
            except Exception as e:
                # Stop at the first event frame that cannot be read.
                print(e)
                break
            mask_pred, pose_pred = predict_segpose(segpose, ev_frame,
                                                   params.threshold_conf,
                                                   params.img_size)
            manager.add_pred(idx, mask_pred, "silhouette")
            mask_priors.append(torch.from_numpy(mask_pred))

            # TODO: decide whether raw quaternions should be normalized (or
            # passed through qexp) before quaternion_to_matrix below.
            q_pred = pose_pred[:, 3:]
            q_targ = poses_gt[idx, 3:]

            q_pred_unit = q_pred / torch.norm(q_pred)
            q_targ_unit = q_targ / torch.norm(q_targ)

            # Translations are de-standardized with the std/mean of T_gt.
            t_pred = pose_pred[:, :3] * std_T + mean_T
            t_targ = poses_gt[idx, :3] * std_T + mean_T
            T_pred.append(t_pred)

            q_loss += quaternion_angular_error(q_pred_unit, q_targ_unit)
            t_loss += t_error(t_pred, t_targ)

            r_pred = rc.quaternion_to_matrix(q_pred).unsqueeze(0)
            R_pred.append(r_pred.squeeze(0))

        # Average over the frames actually processed. Previously this divided
        # by ``idx + 1``, which also counted the frame that broke the loop and
        # raised when no frame was processed at all.
        n_frames = len(R_pred)
        if n_frames == 0:
            logging.warning("No event frames could be processed in %s; "
                            "skipping", p)
            manager.close()
            continue
        q_loss_mean = q_loss / n_frames
        t_loss_mean = t_loss / n_frames

        # Convert R,T to world-to-view transforms (PyTorch3D convention).
        R_pred_abs = torch.cat(R_pred)
        T_pred_abs = torch.cat(T_pred)
        # Take inverse of view-to-world (output of net) to obtain w2v
        wtv_trans = (get_world_to_view_transform(
            R=R_pred_abs, T=T_pred_abs).inverse().get_matrix())
        T_pred = wtv_trans[:, 3, :3]
        R_pred = wtv_trans[:, :3, :3]
        # Alternative: rotation from look_at on the predicted translations.
        R_pred_test = look_at_rotation(T_pred_abs)
        T_pred_test = -torch.bmm(R_pred_test.transpose(1, 2),
                                 T_pred_abs[:, :, None])[:, :, 0]
        # Convert back to view-to-world to get absolute
        vtw_trans = (get_world_to_view_transform(
            R=R_pred_test, T=T_pred_test).inverse().get_matrix())
        T_pred_trans = vtw_trans[:, 3, :3]
        R_pred_trans = vtw_trans[:, :3, :3]

        # Pose error of the look-at variant.
        n_test = len(R_pred_test)
        q_loss_mean_test = 0
        t_loss_mean_test = 0
        for idx in range(n_test):
            q_pred_trans = rc.matrix_to_quaternion(R_pred_trans[idx]).squeeze()
            q_targ = poses_gt[idx, 3:]
            q_targ_unit = q_targ / torch.norm(q_targ)
            q_loss_mean_test += quaternion_angular_error(
                q_pred_trans, q_targ_unit)
            t_targ = poses_gt[idx, :3] * std_T + mean_T
            t_loss_mean_test += t_error(T_pred_trans[idx], t_targ)
        q_loss_mean_test /= n_test
        t_loss_mean_test /= n_test

        logging.info(
            f"Mean Translation Error: {t_loss_mean}; Mean Rotation Error: {q_loss_mean}"
        )
        logging.info(
            f"Mean Translation Error: {t_loss_mean_test}; Mean Rotation Error: {q_loss_mean_test}"
        )

        # Plot estimated cameras for (up to) two random frames that exist in
        # both the ground truth and the predictions.
        logging.info("Plotting pose map")
        n_plot = min(len(R_gt), len(R_pred))
        plot_idx = random.sample(range(n_plot), k=min(2, n_plot))
        pose_plot = plot_cams_from_poses(
            (R_gt[plot_idx], T_gt[plot_idx]),
            (R_pred[plot_idx], T_pred[plot_idx]), params.device)
        pose_plot_test = plot_cams_from_poses(
            (R_gt[plot_idx], T_gt[plot_idx]),
            (R_pred_test[plot_idx], T_pred_test[plot_idx]), params.device)
        manager.add_pose_plot(pose_plot, "rot+trans")
        manager.add_pose_plot(pose_plot_test, "trans")

        groundtruth_silhouettes = manager._images("silhouette") / 255.0
        print(groundtruth_silhouettes.shape)
        print(torch.stack((mask_priors)).shape)
        seg_iou = neg_iou_loss(groundtruth_silhouettes,
                               torch.stack((mask_priors)) / 255.0)
        print("Seg IoU", seg_iou)

        # RUN MESH DEFORMATION
        # Possible experiments: w/ Rot+Trans - w/ Trans+LookAt - w/ GT Pose
        experiments = {
            "GT-Pose": [R_gt, T_gt],
            # "Rot+Trans": [R_pred, T_pred],
            # "Trans+LookAt": [R_pred_test, T_pred_test]
        }

        results = {}
        input_m = torch.stack((mask_priors))

        for experiment_name, (R, T) in experiments.items():
            logging.info(
                f"Input pred shape & max: {input_m.shape}, {input_m.max()}")
            # The MeshDeformation model will return silhouettes across all view by default
            mesh_model = MeshDeformationModel(device=device, params=params)

            experiment_results = mesh_model.run_optimization(input_m, R, T)
            renders = mesh_model.render_final_mesh((R, T), "predict",
                                                   input_m.shape[-2:])

            mesh_silhouettes = renders["silhouettes"].squeeze(1)
            mesh_images = renders["images"].squeeze(1)
            for idx in range(len(mesh_silhouettes)):
                manager.add_pred(
                    idx,
                    mesh_silhouettes[idx].cpu().numpy(),
                    "silhouette",
                    destination=f"mesh_{experiment_name}",
                )
                manager.add_pred(
                    idx,
                    mesh_images[idx].cpu().numpy(),
                    "phong",
                    destination=f"mesh_{experiment_name}",
                )

            # Calculate chamfer loss:
            mesh_pred = mesh_model._final_mesh
            mesh_gt = _load_gt_mesh(mesh_type, manager, params, device)
            if mesh_gt is not None:
                mesh_pred_compute, mesh_gt_compute = scale_meshes(
                    mesh_pred.clone(), mesh_gt.clone())
                pcl_pred = sample_points_from_meshes(mesh_pred_compute,
                                                     num_samples=5000)
                pcl_gt = sample_points_from_meshes(mesh_gt_compute,
                                                   num_samples=5000)
                chamfer_loss = chamfer_distance(pcl_pred,
                                                pcl_gt,
                                                point_reduction="mean")
                print("CHAMFER LOSS: ", chamfer_loss)
                experiment_results["chamfer_loss"] = (
                    chamfer_loss[0].cpu().detach().numpy().tolist())

            mesh_iou = neg_iou_loss(groundtruth_silhouettes, mesh_silhouettes)

            experiment_results["mesh_iou"] = mesh_iou.cpu().numpy().tolist()

            results[experiment_name] = experiment_results

            manager.add_pred_mesh(mesh_pred, experiment_name)

        seg_iou = neg_iou_loss_all(groundtruth_silhouettes, input_m / 255.0)
        gt_iou = neg_iou_loss_all(groundtruth_silhouettes,
                                  groundtruth_silhouettes)

        # mesh_iou here is the value from the last experiment run above.
        results["mesh_iou"] = mesh_iou.detach().cpu().numpy().tolist()
        results["seg_iou"] = seg_iou.detach().cpu().numpy().tolist()
        logging.info(f"Mesh IOU list & results: {mesh_iou}")
        logging.info(f"Seg IOU list & results: {seg_iou}")
        logging.info(f"GT IOU list & results: {gt_iou} ")

        manager.set_pred_results(results)
        manager.close()
Ejemplo n.º 7
0
def load_objs_as_meshes(
    files: list,
    device=None,
    load_textures: bool = True,
    create_texture_atlas: bool = False,
    texture_atlas_size: int = 4,
    texture_wrap: Optional[str] = "repeat",
    path_manager: Optional[PathManager] = None,
):
    """
    Load meshes from a list of .obj files using the load_obj function, and
    return them as a Meshes object. See the load_obj function for more
    details. material_colors and normals are not stored.

    When ``create_texture_atlas`` is False and the file has texture images,
    one TexturesUV is built per texture image. Each covers the faces whose
    ``materials_idx`` equals that image's position. When there are several,
    they are combined with ``join_batch(...).join_scene()``. ``aux`` is
    forwarded to the Meshes constructor.

    Args:
        files: A list of file-like objects (with methods read, readline, tell,
            and seek), pathlib paths or strings containing file names.
        device: Desired device of returned Meshes. Default:
            uses the current device for the default tensor type.
        load_textures: Boolean indicating whether material files are loaded
        create_texture_atlas, texture_atlas_size, texture_wrap: as for load_obj.
        path_manager: optionally a PathManager object to interpret paths.

    Returns:
        The single Meshes object when ``files`` has one entry, otherwise the
        batch produced by ``join_meshes_as_batch``.
    """
    mesh_list = []
    for f_obj in files:
        verts, faces, aux = load_obj(
            f_obj,
            load_textures=load_textures,
            create_texture_atlas=create_texture_atlas,
            texture_atlas_size=texture_atlas_size,
            texture_wrap=texture_wrap,
            path_manager=path_manager,
        )
        tex = None
        if create_texture_atlas:
            # TexturesAtlas type
            tex = TexturesAtlas(atlas=[aux.texture_atlas.to(device)])
        else:
            # TexturesUV type, one per texture image
            tex_maps = aux.texture_images
            if tex_maps is not None and len(tex_maps) > 0:
                verts_uvs = aux.verts_uvs.to(device)  # (V, 2)
                faces_uvs = faces.textures_idx.to(device)  # (F, 3)
                # Keep the mask on the same device as the tensor it indexes.
                face_to_mat = faces.materials_idx.to(device)

                textures = []
                # NOTE(review): assumes the i-th texture image corresponds to
                # material index i in materials_idx, which holds only if
                # every material has a texture image. Confirm against load_obj.
                for mat_idx, (mat_name,
                              tex_map) in enumerate(tex_maps.items()):
                    image = tex_map.flip(0).unsqueeze(0).to(device)
                    faces_mask = face_to_mat == mat_idx

                    # Compact this material's UV indices to 0..K-1. This stays
                    # correct even when its UV indices are not one contiguous
                    # range following the previous material's, unlike
                    # subtracting a running offset.
                    used_uvs, tex_faces_uvs = faces_uvs[faces_mask].unique(
                        return_inverse=True)
                    tex_verts_uvs = verts_uvs[used_uvs]

                    textures.append(
                        TexturesUV(verts_uvs=[tex_verts_uvs],
                                   faces_uvs=[tex_faces_uvs],
                                   maps=image,
                                   mat_names=[mat_name]))

                tex = textures[0]
                if len(textures) > 1:
                    tex = tex.join_batch(textures[1:]).join_scene()

        mesh = Meshes(verts=[verts.to(device)],
                      faces=[faces.verts_idx.to(device)],
                      textures=tex,
                      aux=aux)
        mesh_list.append(mesh)
    if len(mesh_list) == 1:
        return mesh_list[0]
    return join_meshes_as_batch(mesh_list)