    def smooth(self,
               vertices: np.ndarray,
               faces: np.ndarray,
               diffusion: ChunkGrid[float],
               original_mesh: Meshes,
               max_iteration=50):
        """Laplacian-smooth the mesh, keeping each vertex within a per-vertex
        maximum distance (from the diffusion grid) of its original position."""
        assert max_iteration >= 0
        change = True
        iteration = 0
        loss_mesh = original_mesh.clone()
        smooth_verts = loss_mesh.verts_packed().clone()
        # One-ring neighborhoods, their sizes, and the sum of reciprocal
        # neighbor valences per vertex.
        neighbors = self.compute_neighbors(vertices, faces)
        neighbor_len = torch.IntTensor(
            [len(neighbors[i]) for i in range(len(vertices))])
        neighbor_valences = torch.FloatTensor([
            sum(1 / neighbor_len[n] for n in neighbors[i])
            for i in range(len(vertices))
        ])
        # Damping factor d_i = 1 + (1/|N(i)|) * sum over neighbors j of 1/|N(j)|:
        # vertices in dense neighborhoods take smaller smoothing steps.
        d = 1 + 1 / neighbor_len * neighbor_valences

        # Per-vertex cap on how far a vertex may drift from its original
        # position, taken from the diffusion grid.
        difference_max = torch.as_tensor(diffusion.get_values(vertices) + 1)

        while change and iteration < max_iteration:
            iteration += 1
            change = False

            # Apply the Laplacian operator twice: the first pass replaces the
            # vertices by their Laplacian coordinates, the second computes the
            # Laplacian of that intermediate mesh.
            for i in range(2):
                with torch.no_grad():
                    L = loss_mesh.laplacian_packed()
                loss = L.mm(loss_mesh.verts_packed())
                if i == 0:
                    loss_mesh = Meshes([loss], loss_mesh.faces_list())

            # Damped update; only vertices whose result stays within their
            # allowed distance of the original position are moved.
            new_val = smooth_verts - (loss.T * (1 / d)).T
            differences = torch.linalg.norm(original_mesh.verts_packed() -
                                            new_val,
                                            dim=1)
            cond = differences < difference_max
            if torch.any(cond):
                smooth_verts[cond] = new_val[cond]
                change = True

            loss_mesh = Meshes([smooth_verts], original_mesh.faces_list())
        return smooth_verts
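
The `compute_neighbors` helper is not shown in this example. A minimal sketch of what it presumably does, judging from the call site (names, signature, and return type are guesses: a list mapping each vertex index to the indices of its one-ring neighbors):

import numpy as np

def compute_neighbors(vertices: np.ndarray, faces: np.ndarray):
    # Hypothetical helper: build a one-ring adjacency list
    # (vertex index -> indices of neighboring vertices) from the face array.
    neighbors = [set() for _ in range(len(vertices))]
    for a, b, c in faces:
        neighbors[a].update((b, c))
        neighbors[b].update((a, c))
        neighbors[c].update((a, b))
    return [sorted(n) for n in neighbors]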
Example #2
    def __call__(self, vertices, faces):
        ''' Currently renders only silhouettes.
            Input:
            vertices: BN * V * 3
            faces: BN * F * 3
        '''
        torch_mesh = Meshes(verts=vertices.to(self.device),
                            faces=faces.to(self.device))
        silhouette = self.silhouette_renderer(meshes_world=torch_mesh.clone(),
                                              R=self.R,
                                              T=self.T)

        return silhouette
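
This `__call__` assumes `self.silhouette_renderer`, `self.R`, and `self.T` were built elsewhere (e.g. in `__init__`). A minimal sketch of such a setup with PyTorch3D's stock silhouette pipeline; image size, blend parameters, and camera placement are assumed values:

import numpy as np
import torch
from pytorch3d.renderer import (
    BlendParams, FoVPerspectiveCameras, MeshRasterizer, MeshRenderer,
    RasterizationSettings, SoftSilhouetteShader, look_at_view_transform,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cameras = FoVPerspectiveCameras(device=device)
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
raster_settings = RasterizationSettings(
    image_size=256,  # assumed resolution
    blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
    faces_per_pixel=100,
)
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
    shader=SoftSilhouetteShader(blend_params=blend_params),
)
# Assumed camera placement; R, T follow PyTorch3D's look_at convention.
R, T = look_at_view_transform(dist=2.7, elev=10.0, azim=0.0, device=device)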
Example #3
textures = Textures(verts_rgb=verts_rgb.to(device))
faces = torch.tensor(np.int32(flamelayer.faces), dtype=torch.long).cuda()

my_mesh = Meshes(verts=[verts.to(device)],
                 faces=[faces.to(device)],
                 textures=textures)

camera_position = nn.Parameter(
    torch.from_numpy(np.array([1, 0.1, 0.1],
                              dtype=np.float32)).to(my_mesh.device))

R = look_at_rotation(camera_position[None, :], device=device)  # (1, 3, 3)
T = -torch.bmm(R.transpose(1, 2), camera_position[None, :, None])[:, :,
                                                                  0]  # (1, 3)
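# Note: PyTorch3D's world-to-view transform maps X_view = X_world @ R + T, so a
# camera centered at C uses T = -C @ R, computed here as -R^T C via bmm.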

silhouette = silhouette_renderer(meshes_world=my_mesh.clone(), R=R, T=T)
silhouette = silhouette.detach().cpu().numpy()[..., 3]

print(silhouette.shape)
plt.imshow(silhouette.squeeze())
plt.show()


class Model(nn.Module):
    def __init__(self, meshes, renderer, image_ref):
        super().__init__()
        self.meshes = meshes
        self.device = meshes.device
        self.renderer = renderer

        # Get the silhouette of the reference RGB image by finding all the non-zero values.
        # Assumed completion, following the standard PyTorch3D camera-position tutorial:
        image_ref = torch.from_numpy(
            (image_ref[..., :3].max(-1) != 0).astype(np.float32))
        self.register_buffer("image_ref", image_ref)
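
        # The original snippet is truncated here. A hypothetical continuation,
        # assuming the PyTorch3D camera-position tutorial pattern: register a
        # learnable camera position, then render and compare silhouettes in
        # forward(). The starting position below is a placeholder.
        self.camera_position = nn.Parameter(
            torch.tensor([3.0, 6.9, 2.5], device=self.device))

    def forward(self):
        # Derive R, T from the current camera position and render a silhouette.
        R = look_at_rotation(self.camera_position[None, :], device=self.device)
        T = -torch.bmm(R.transpose(1, 2),
                       self.camera_position[None, :, None])[:, :, 0]
        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)
        # Squared difference between the rendered alpha channel (silhouette)
        # and the reference silhouette.
        loss = torch.sum((image[..., 3] - self.image_ref) ** 2)
        return loss, image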
Example #4
def pred_synth(segpose,
               params: Params,
               mesh_type: str = "dolphin",
               device: str = "cuda"):

    if not params.pred_dir or not os.path.exists(params.pred_dir):
        raise FileNotFoundError(
            "Prediction directory has not been set or does not exist; set it via CLI args or params"
        )
    pred_folders = [
        join(params.pred_dir, f) for f in os.listdir(params.pred_dir)
    ]
    count = 1
    for p in sorted(pred_folders):
        try:
            print(p)
            manager = RenderManager.from_path(p)
            manager.rectify_paths(base_folder=params.pred_dir)
        except FileNotFoundError:
            continue
        # Run Silhouette Prediction Network
        logging.info(f"Starting mask predictions")
        mask_priors = []
        R_pred, T_pred = [], []
        q_loss, t_loss = 0, 0
        # Collect Translation stats
        R_gt, T_gt = manager._trajectory
        poses_gt = EvMaskPoseDataset.preprocess_poses(manager._trajectory)
        std_T, mean_T = torch.std_mean(T_gt)
        for idx in range(len(manager)):
            try:
                ev_frame = manager.get_event_frame(idx)
            except Exception as e:
                print(e)
                break
            mask_pred, pose_pred = predict_segpose(segpose, ev_frame,
                                                   params.threshold_conf,
                                                   params.img_size)
            # mask_pred = smooth_predicted_mask(mask_pred)
            manager.add_pred(idx, mask_pred, "silhouette")
            mask_priors.append(torch.from_numpy(mask_pred))

            # Make qexp a torch function
            # q_pred = qexp(pose_pred[:, 3:])
            # q_targ = qexp(poses_gt[idx, 3:].unsqueeze(0))
            # Should this be normalized?
            q_pred = pose_pred[:, 3:]
            q_targ = poses_gt[idx, 3:]

            q_pred_unit = q_pred / torch.norm(q_pred)
            q_targ_unit = q_targ / torch.norm(q_targ)
            # print("learnt: ", q_pred_unit, q_targ_unit)

            t_pred = pose_pred[:, :3] * std_T + mean_T
            t_targ = poses_gt[idx, :3] * std_T + mean_T
            T_pred.append(t_pred)

            q_loss += quaternion_angular_error(q_pred_unit, q_targ_unit)
            t_loss += t_error(t_pred, t_targ)

            r_pred = rc.quaternion_to_matrix(q_pred).unsqueeze(0)
            R_pred.append(r_pred.squeeze(0))

        q_loss_mean = q_loss / (idx + 1)
        t_loss_mean = t_loss / (idx + 1)

        # Convert R, T to world-to-view transforms (the PyTorch3D convention):

        R_pred_abs = torch.cat(R_pred)
        T_pred_abs = torch.cat(T_pred)
        # Take inverse of view-to-world (output of net) to obtain w2v
        wtv_trans = (get_world_to_view_transform(
            R=R_pred_abs, T=T_pred_abs).inverse().get_matrix())
        T_pred = wtv_trans[:, 3, :3]
        R_pred = wtv_trans[:, :3, :3]
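        # Alternative pose estimate: keep only the predicted camera centers and
        # derive rotations with look-at (the camera always faces the origin).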
        R_pred_test = look_at_rotation(T_pred_abs)
        T_pred_test = -torch.bmm(R_pred_test.transpose(1, 2),
                                 T_pred_abs[:, :, None])[:, :, 0]
        # Convert back to view-to-world to get absolute
        vtw_trans = (get_world_to_view_transform(
            R=R_pred_test, T=T_pred_test).inverse().get_matrix())
        T_pred_trans = vtw_trans[:, 3, :3]
        R_pred_trans = vtw_trans[:, :3, :3]

        # Calc pose error for this:
        q_loss_mean_test = 0
        t_loss_mean_test = 0
        for idx in range(len(R_pred_test)):
            q_pred_trans = rc.matrix_to_quaternion(R_pred_trans[idx]).squeeze()
            q_targ = poses_gt[idx, 3:]
            q_targ_unit = q_targ / torch.norm(q_targ)
            # print("look: ", q_test, q_targ)
            q_loss_mean_test += quaternion_angular_error(
                q_pred_trans, q_targ_unit)
            t_targ = poses_gt[idx, :3] * std_T + mean_T
            t_loss_mean_test += t_error(T_pred_trans[idx], t_targ)
        q_loss_mean_test /= idx + 1
        t_loss_mean_test /= idx + 1

        logging.info(
            f"Mean Translation Error (predicted poses): {t_loss_mean}; Mean Rotation Error: {q_loss_mean}"
        )
        logging.info(
            f"Mean Translation Error (look-at poses): {t_loss_mean_test}; Mean Rotation Error: {q_loss_mean_test}"
        )

        # Plot estimated cameras
        logging.info(f"Plotting pose map")
        idx = random.sample(range(len(R_gt)), k=2)
        pose_plot = plot_cams_from_poses(
            (R_gt[idx], T_gt[idx]), (R_pred[idx], T_pred[idx]), params.device)
        pose_plot_test = plot_cams_from_poses(
            (R_gt[idx], T_gt[idx]), (R_pred_test[idx], T_pred_test[idx]),
            params.device)
        manager.add_pose_plot(pose_plot, "rot+trans")
        manager.add_pose_plot(pose_plot_test, "trans")

        count += 1
        groundtruth_silhouettes = manager._images("silhouette") / 255.0
        print(groundtruth_silhouettes.shape)
        print(torch.stack(mask_priors).shape)
        seg_iou = neg_iou_loss(groundtruth_silhouettes,
                               torch.stack(mask_priors) / 255.0)
        print("Seg IoU", seg_iou)

        # RUN MESH DEFORMATION
        # Run it 3 times: w/ Rot+Trans - w/ Trans+LookAt - w/ GT Pose
        experiments = {
            "GT-Pose": [R_gt, T_gt],
            # "Rot+Trans": [R_pred, T_pred],
            # "Trans+LookAt": [R_pred_test, T_pred_test]
        }

        results = {}
        input_m = torch.stack(mask_priors)

        for experiment_name, (R, T) in experiments.items():

            logging.info(
                f"Input pred shape & max: {input_m.shape}, {input_m.max()}")
            # The MeshDeformation model returns silhouettes across all views by default.

            mesh_model = MeshDeformationModel(device=device, params=params)

            experiment_results = mesh_model.run_optimization(input_m, R, T)
            renders = mesh_model.render_final_mesh((R, T), "predict",
                                                   input_m.shape[-2:])

            mesh_silhouettes = renders["silhouettes"].squeeze(1)
            mesh_images = renders["images"].squeeze(1)
            for idx in range(len(mesh_silhouettes)):
                manager.add_pred(
                    idx,
                    mesh_silhouettes[idx].cpu().numpy(),
                    "silhouette",
                    destination=f"mesh_{experiment_name}",
                )
                manager.add_pred(
                    idx,
                    mesh_images[idx].cpu().numpy(),
                    "phong",
                    destination=f"mesh_{experiment_name}",
                )

            # Calculate chamfer loss:
            mesh_pred = mesh_model._final_mesh
            mesh_gt = None
            if mesh_type == "dolphin":
                path = "data/meshes/dolphin/dolphin.obj"
                mesh_gt = load_objs_as_meshes(
                    [path],
                    create_texture_atlas=False,
                    load_textures=True,
                    device=device,
                )
            # Shapenet Cars
            elif mesh_type == "shapenet":
                mesh_info = manager.metadata["mesh_info"]
                # The ShapeNet path must be relative here; a leading "/" would
                # make os.path.join discard params.gt_mesh_path.
                path = os.path.join(
                    params.gt_mesh_path,
                    f"ShapeNetCorev2/{mesh_info['synset_id']}/{mesh_info['mesh_id']}/models/model_normalized.obj"
                )
                # path = f"data/ShapeNetCorev2/{mesh_info['synset_id']}/{mesh_info['mesh_id']}/models/model_normalized.obj"
                try:
                    verts, faces, aux = load_obj(path,
                                                 load_textures=True,
                                                 create_texture_atlas=True)

                    mesh_gt = Meshes(
                        verts=[verts],
                        faces=[faces.verts_idx],
                        textures=TexturesAtlas(atlas=[aux.texture_atlas]),
                    ).to(device)
                except Exception:
                    mesh_gt = None
                    print("CANNOT COMPUTE CHAMFER LOSS")
            if mesh_gt is not None:
                mesh_pred_compute, mesh_gt_compute = scale_meshes(
                    mesh_pred.clone(), mesh_gt.clone())
                pcl_pred = sample_points_from_meshes(mesh_pred_compute,
                                                     num_samples=5000)
                pcl_gt = sample_points_from_meshes(mesh_gt_compute,
                                                   num_samples=5000)
                chamfer_loss = chamfer_distance(pcl_pred,
                                                pcl_gt,
                                                point_reduction="mean")
                print("CHAMFER LOSS: ", chamfer_loss)
                experiment_results["chamfer_loss"] = (
                    chamfer_loss[0].cpu().detach().numpy().tolist())

            mesh_iou = neg_iou_loss(groundtruth_silhouettes, mesh_silhouettes)

            experiment_results["mesh_iou"] = mesh_iou.cpu().numpy().tolist()

            results[experiment_name] = experiment_results

            manager.add_pred_mesh(mesh_pred, experiment_name)
        # logging.info(f"Input pred shape & max: {input_m.shape}, {input_m.max()}")
        # # The MeshDeformation model will return silhouettes across all view by default
        #
        #
        #
        # experiment_results = models["mesh"].run_optimization(input_m, R_gt, T_gt)
        # renders = models["mesh"].render_final_mesh(
        #     (R_gt, T_gt), "predict", input_m.shape[-2:]
        # )
        #
        # mesh_silhouettes = renders["silhouettes"].squeeze(1)
        # mesh_images = renders["images"].squeeze(1)
        # experiment_name = params.name
        # for idx in range(len(mesh_silhouettes)):
        #     manager.add_pred(
        #         idx,
        #         mesh_silhouettes[idx].cpu().numpy(),
        #         "silhouette",
        #         destination=f"mesh_{experiment_name}",
        #     )
        #     manager.add_pred(
        #         idx,
        #         mesh_images[idx].cpu().numpy(),
        #         "phong",
        #         destination=f"mesh_{experiment_name}",
        #     )
        #
        # # Calculate chamfer loss:
        # mesh_pred = models["mesh"]._final_mesh
        # if mesh_type == "dolphin":
        #     path = params.gt_mesh_path
        #     mesh_gt = load_objs_as_meshes(
        #         [path],
        #         create_texture_atlas=False,
        #         load_textures=True,
        #         device=device,
        #     )
        # # Shapenet Cars
        # elif mesh_type == "shapenet":
        #     mesh_info = manager.metadata["mesh_info"]
        #     path = params.gt_mesh_path
        #     try:
        #         verts, faces, aux = load_obj(
        #             path, load_textures=True, create_texture_atlas=True
        #         )
        #
        #         mesh_gt = Meshes(
        #             verts=[verts],
        #             faces=[faces.verts_idx],
        #             textures=TexturesAtlas(atlas=[aux.texture_atlas]),
        #         ).to(device)
        #     except:
        #         mesh_gt = None
        #         print("CANNOT COMPUTE CHAMFER LOSS")
        # if mesh_gt and params.is_real_data:
        #     mesh_pred_compute, mesh_gt_compute = scale_meshes(
        #         mesh_pred.clone(), mesh_gt.clone()
        #     )
        #     pcl_pred = sample_points_from_meshes(
        #         mesh_pred_compute, num_samples=5000
        #     )
        #     pcl_gt = sample_points_from_meshes(mesh_gt_compute, num_samples=5000)
        #     chamfer_loss = chamfer_distance(
        #         pcl_pred, pcl_gt, point_reduction="mean"
        #     )
        #     print("CHAMFER LOSS: ", chamfer_loss)
        #     experiment_results["chamfer_loss"] = (
        #         chamfer_loss[0].cpu().detach().numpy().tolist()
        #     )
        #
        # mesh_iou = neg_iou_loss_all(groundtruth_silhouettes, mesh_silhouettes)
        #
        # experiment_results["mesh_iou"] = mesh_iou.cpu().numpy().tolist()
        #
        # results[experiment_name] = experiment_results
        #
        # manager.add_pred_mesh(mesh_pred, experiment_name)

        seg_iou = neg_iou_loss_all(groundtruth_silhouettes, input_m / 255.0)
        gt_iou = neg_iou_loss_all(groundtruth_silhouettes,
                                  groundtruth_silhouettes)

        results["mesh_iou"] = mesh_iou.detach().cpu().numpy().tolist()
        results["seg_iou"] = seg_iou.detach().cpu().numpy().tolist()
        logging.info(f"Mesh IOU list & results: {mesh_iou}")
        logging.info(f"Seg IOU list & results: {seg_iou}")
        logging.info(f"GT IOU list & results: {gt_iou} ")

        # results["mean_iou"] = IOULoss().forward(groundtruth, mesh_silhouettes).detach().cpu().numpy().tolist()
        # results["mean_dice"] = DiceCoeffLoss().forward(groundtruth, mesh_silhouettes)

        manager.set_pred_results(results)
        manager.close()
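
`neg_iou_loss` and `neg_iou_loss_all` are project-specific helpers that are not shown here. A hypothetical sketch consistent with how they are called above (a per-view soft IoU, returned as 1 - IoU so lower is better):

import torch

def neg_iou_loss(predict: torch.Tensor, target: torch.Tensor,
                 eps: float = 1e-6) -> torch.Tensor:
    # predict, target: (N, H, W) silhouettes with values in [0, 1].
    dims = tuple(range(1, predict.dim()))
    intersect = (predict * target).sum(dims)
    union = (predict + target - predict * target).sum(dims) + eps
    return 1.0 - intersect / union  # shape (N,): one value per view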
Example #5
def fit_closure():
    # Assumed closure preamble for LBFGS (the original snippet is cut off
    # above this point): zero the gradients and compute the image loss.
    optimizer.zero_grad()
    loss = image_fit_loss(face_mesh_alt)
    loss_model, _ = model()
    print('loss - ', loss)
    print('loss model - ', loss_model)
    obj = loss
    if obj.requires_grad:
        obj.backward()
    return obj


optimizer.step(fit_closure)
loss = image_fit_loss(face_mesh_alt)
R = look_at_rotation(camera_translation[None, :], device=model.device)
T = -torch.bmm(R.transpose(1, 2), camera_translation[None, :,
                                                     None])[:, :, 0]  # (1, 3)
image = silhouette_renderer(meshes_world=face_mesh_alt.clone(), R=R, T=T)
image = image[..., 3].detach().squeeze().cpu().numpy()
image = img_as_ubyte(image)
writer.append_data(image)
print('long LBFGS')
plt.subplot(121)
plt.imshow(image)
plt.title("iter: %d, loss: %0.2f" % (1, loss.data))
plt.grid(False)
plt.axis("off")
plt.subplot(122)
plt.imshow(image_ref)
plt.grid(False)
plt.axis("off")
plt.show()
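
For context, PyTorch's LBFGS re-evaluates the objective several times per step, which is why `optimizer.step` is passed `fit_closure` above. A minimal sketch of the assumed optimizer setup (the starting translation, learning rate, and iteration count are placeholders):

import torch

# camera_translation is the learnable parameter referenced in the snippet;
# model is the module defined earlier in this example.
camera_translation = torch.nn.Parameter(
    torch.tensor([1.0, 0.1, 0.1], device=model.device))
optimizer = torch.optim.LBFGS([camera_translation], lr=0.5, max_iter=20,
                              line_search_fn="strong_wolfe")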