def run_on_image(self, images, id_str):
    deprocess = imagenet_deprocess(rescale_image=False)
    voxel_scores, meshes_pred = self.predictor(images)

    img = image_to_numpy(deprocess(images[0][0]))
    vis_utils.visualize_prediction(id_str, img, meshes_pred[-1][0],
                                   self.output_dir)
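Every example below calls `imagenet_deprocess(rescale_image=False)` to undo the ImageNet input normalization before visualization. A minimal sketch of what such a transform does, assuming the standard ImageNet mean/std statistics (`deprocess_sketch` is a hypothetical stand-in, not the meshrcnn implementation):

import torch

IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def deprocess_sketch(img, rescale_image=False):
    # img: (3, H, W) tensor that was normalized with ImageNet statistics
    img = img * IMAGENET_STD + IMAGENET_MEAN  # undo the normalization
    if rescale_image:
        # stretch to [0, 1] for display
        img = (img - img.min()) / (img.max() - img.min()).clamp(min=1e-8)
    return img.clamp(0.0, 1.0)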
Example #2
def save_images(imgs, file_prefix):
    """
    Args:
    - imgs: tensor of shape (B, V, C, H, W)
    - file_prefix: prefix to use in the filename to distinguish batches
    """
    transform = imagenet_deprocess(rescale_image=False)
    for batch_idx in range(imgs.shape[0]):
        for view_idx in range(imgs.shape[1]):
            img = imgs[batch_idx, view_idx]
            img = transform(img) * 255
            img = img.type(torch.uint8).cpu().detach() \
                     .permute(1, 2, 0).numpy()
            # white background
            img[img == 0] = 255
            filename = "/tmp/image_{}_{}_{}.png" \
                            .format(file_prefix, batch_idx, view_idx)
            cv2.imwrite(filename, img)
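A quick usage sketch with hypothetical shapes. Note that `cv2.imwrite` expects BGR channel order, so the saved colors come out channel-swapped unless the array is converted (e.g. with `cv2.cvtColor`) first:

import torch

# Hypothetical batch: 2 samples, 3 views, 3-channel 137x137 normalized images
imgs = torch.randn(2, 3, 3, 137, 137)
save_images(imgs, "demo")  # writes /tmp/image_demo_{batch}_{view}.png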
Example #3
    def run_on_image(self, image, id_str, gt_verts, gt_faces):
        deprocess = imagenet_deprocess(rescale_image=False)

        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(image.to(self.device))

        sid, mid, iid = id_str.split('-')
        iid = int(iid)

        #Transform vertex space
        metadata_path = os.path.join('./datasets/shapenet/ShapeNetV1processed',
                                     sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RTs = metadata["extrinsics"].to(self.device)
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)

        mesh = meshes_pred[-1][0]
        # For some strange reason all classes (except the vehicle class) require
        # a 90-degree rotation about the y-axis for the GT mesh
        invRT = torch.inverse(RTs[iid].mm(rot_y_90))
        invRT_no_rot = torch.inverse(RTs[iid])
        mesh._verts_list[0] = project_verts(mesh._verts_list[0], invRT)

        #Get look at view extrinsics
        render_metadata_path = os.path.join(
            'datasets/shapenet/ShapeNetRenderingExtrinsics', sid, mid,
            'rendering_metadata.pt')
        render_metadata = torch.load(render_metadata_path)
        render_RTs = render_metadata['extrinsics'].to(self.device)

        verts, faces = mesh.get_mesh_verts_faces(0)
        verts_rgb = torch.ones_like(verts)[None]
        textures = Textures(verts_rgb=verts_rgb.to(self.device))
        mesh.textures = textures

        plt.figure(figsize=(10, 10))

        #Silhouette Renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )

        gt_verts = gt_verts.to(self.device)
        gt_faces = gt_faces.to(self.device)
        verts_rgb = torch.ones_like(gt_verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        #Invert without the rotation for the vehicle class
        if sid == '02958343':
            gt_verts = project_verts(gt_verts, invRT_no_rot.to(self.device))
        else:
            gt_verts = project_verts(gt_verts, invRT.to(self.device))
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces], textures=textures)

        probability_map = 0.01 * torch.ones((1, 24))
        viewgrid = torch.zeros(
            (1, 24, render_image_size, render_image_size)).to(self.device)
        fig = plt.figure(1)
        ax_pred = [fig.add_subplot(5, 5, i + 1) for i in range(24)]
        #fig = plt.figure(2)
        #ax_gt = [fig.add_subplot(5,5,i+1) for i in range(24)]

        for i in range(len(render_RTs)):
            if i == iid:  #Don't include current view
                continue

            R = render_RTs[i][:3, :3].unsqueeze(0)
            T = render_RTs[i][:3, 3].unsqueeze(0)
            cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)

            silhouette_renderer = MeshRenderer(
                rasterizer=MeshRasterizer(cameras=cameras,
                                          raster_settings=raster_settings),
                shader=SoftSilhouetteShader(blend_params=blend_params))

            ref_image = (silhouette_renderer(meshes_world=gt_mesh, R=R, T=T) >
                         0).float()
            silhouette_image = (silhouette_renderer(
                meshes_world=mesh, R=R, T=T) > 0).float()

            # MSE Loss between both silhouettes
            silh_loss = torch.sum(
                (silhouette_image[0, :, :, 3] - ref_image[0, :, :, 3])**2)
            probability_map[0, i] = silh_loss.detach()

            viewgrid[0, i] = silhouette_image[..., -1]

            #ax_gt[i].imshow(ref_image[0,:,:,3].cpu().numpy())
            #ax_gt[i].set_title(i)

            ax_pred[i].imshow(silhouette_image[0, :, :, 3].cpu().numpy())
            ax_pred[i].set_title(i)

        img = image_to_numpy(deprocess(image[0]))
        #ax_gt[iid].imshow(img)
        ax_pred[iid].imshow(img)
        #fig = plt.figure(3)
        #ax = fig.add_subplot(111)
        #ax.imshow(img)

        pred_prob_map = self.loss_predictor(viewgrid)
        print('Highest actual loss: {}'.format(torch.argmax(probability_map)))
        print('Highest predicted loss: {}'.format(torch.argmax(pred_prob_map)))
        plt.show()
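The examples above move vertices between coordinate frames with `project_verts`. A minimal sketch of such a helper, assuming it applies a 4x4 transform to an (N, 3) vertex tensor in homogeneous coordinates (the real meshrcnn version also handles batched inputs; `project_verts_sketch` is a hypothetical name):

import torch

def project_verts_sketch(verts, P, eps=1e-8):
    # verts: (N, 3) vertex positions, P: (4, 4) transform
    N = verts.shape[0]
    ones = torch.ones(N, 1, dtype=verts.dtype, device=verts.device)
    verts_hom = torch.cat([verts, ones], dim=1)  # (N, 4) homogeneous coords
    verts_out_hom = verts_hom.mm(P.t())          # apply the transform
    w = verts_out_hom[:, 3:].clamp(min=eps)      # perspective divide
    return verts_out_hom[:, :3] / w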
Example #4
    def run_on_image(self, imgs, id_strs, meshes_gt, render_RTs, RTs):
        deprocess = imagenet_deprocess(rescale_image=False)

        #Silhouette Renderer
        render_image_size = 256
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=render_image_size,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )

        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)
        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(imgs)

        B, _, H, W = imgs.shape
        probability_map = 0.01 * torch.ones((B, 24)).to(self.device)
        viewgrid = torch.zeros(
            (B, 24, render_image_size, render_image_size)).to(self.device)  # batch size x 24 x H x W
        _meshes_pred = meshes_pred[-1]

        for b, (cur_gt_mesh, cur_pred_mesh) in enumerate(zip(meshes_gt, _meshes_pred)):
            sid = id_strs[b].split('-')[0]

            RT = RTs[b]

            # For some strange reason all classes (except the vehicle class) require
            # a 90-degree rotation about the y-axis for the GT mesh
            invRT = torch.inverse(RT.mm(rot_y_90))
            invRT_no_rot = torch.inverse(RT)

            cur_pred_mesh._verts_list[0] = project_verts(cur_pred_mesh._verts_list[0], invRT)
            #Invert without the rotation for the vehicle class
            if sid == '02958343':
                cur_gt_mesh._verts_list[0] = project_verts(
                    cur_gt_mesh._verts_list[0], invRT_no_rot)
            else:
                cur_gt_mesh._verts_list[0] = project_verts(
                    cur_gt_mesh._verts_list[0], invRT)

            '''
            plt.figure(figsize=(10, 10))

            fig = plt.figure(1)
            ax_pred = [fig.add_subplot(5,5,i+1) for i in range(24)]
            fig = plt.figure(2)
            ax_gt = [fig.add_subplot(5,5,i+1) for i in range(24)]
            '''

            for iid in range(len(render_RTs[b])):  # iterate over the 24 candidate views

                R = render_RTs[b][iid][:3, :3].unsqueeze(0)
                T = render_RTs[b][iid][:3, 3].unsqueeze(0)
                cameras = OpenGLPerspectiveCameras(device=self.device, R=R, T=T)
                silhouette_renderer = MeshRenderer(
                    rasterizer=MeshRasterizer(cameras=cameras,
                                              raster_settings=raster_settings),
                    shader=SoftSilhouetteShader(blend_params=blend_params))

                ref_image = (silhouette_renderer(
                    meshes_world=cur_gt_mesh, R=R, T=T) > 0).float()
                silhouette_image = (silhouette_renderer(
                    meshes_world=cur_pred_mesh, R=R, T=T) > 0).float()

                # Add image silhouette to viewgrid
                viewgrid[b, iid] = silhouette_image[..., -1]

                # MSE Loss between both silhouettes
                silh_loss = torch.sum((silhouette_image[0, :, :, 3] - ref_image[0, :, :, 3]) ** 2)
                probability_map[b, iid] = silh_loss.detach()

                '''
                ax_gt[iid].imshow(ref_image[0,:,:,3].cpu().numpy())
                ax_gt[iid].set_title(iid)
                ax_pred[iid].imshow(silhouette_image[0,:,:,3].cpu().numpy())
                ax_pred[iid].set_title(iid)
                '''

            '''
            img = image_to_numpy(deprocess(imgs[b]))
            ax_gt[iid].imshow(img)
            ax_pred[iid].imshow(img)
            fig = plt.figure(3)
            ax = fig.add_subplot(111)
            ax.imshow(img)

            #plt.show()
            '''

        probability_map = probability_map / torch.max(probability_map, dim=1)[0].unsqueeze(1)  # Normalize per sample
        probability_map = torch.nn.functional.softmax(probability_map, dim=1)  # Softmax across the 24 views
        pred_prob_map = self.loss_predictor(viewgrid)

        gt_max = torch.argmax(probability_map, dim=1)
        pred_max = torch.argmax(pred_prob_map, dim=1)

        #print('--'*30)
        #print('Item: {}'.format(id_str))
        #print('Highest actual loss: {}'.format(gt_max))
        #print('Highest predicted loss: {}'.format(pred_max))

        #print('GT prob map: {}'.format(probability_map.squeeze()))
        #print('Pred prob map: {}'.format(pred_prob_map.squeeze()))

        correct = torch.sum(pred_max == gt_max).item()
        total   = len(pred_prob_map)

        return correct, total
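The normalize-then-softmax step above turns the per-view silhouette errors into a distribution over the 24 candidate views, so `argmax` picks the view where the prediction disagrees most with the ground truth. A tiny standalone check of that step on dummy data:

import torch

losses = torch.tensor([[4.0, 1.0, 2.0, 8.0]])       # (B=1, V=4) silhouette errors
pmap = losses / losses.max(dim=1, keepdim=True)[0]  # normalize by per-row max
pmap = torch.nn.functional.softmax(pmap, dim=1)     # distribution over views
print(pmap.argmax(dim=1))                           # tensor([3]): the worst view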
Example #5
def evaluate_split(model,
                   loader,
                   max_predictions=-1,
                   num_predictions_keep=10,
                   prefix="",
                   store_predictions=False):
    """
    This function is used to report validation performance during training.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    device = torch.device("cuda:0")
    num_predictions = 0
    num_predictions_kept = 0
    predictions = defaultdict(list)
    metrics = defaultdict(list)
    deprocess = imagenet_deprocess(rescale_image=False)
    for batch in loader:
        batch = loader.postprocess(batch, device)
        imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch
        voxel_scores, meshes_pred = model(imgs)

        # Only compute metrics for the final predicted meshes, not intermediates
        cur_metrics = compare_meshes(meshes_pred[-1], meshes_gt)
        if cur_metrics is None:
            continue
        for k, v in cur_metrics.items():
            metrics[k].append(v)

        # Store input images and predicted meshes
        if store_predictions:
            N = imgs.shape[0]
            for i in range(N):
                if num_predictions_kept >= num_predictions_keep:
                    break
                num_predictions_kept += 1

                img = image_to_numpy(deprocess(imgs[i]))
                predictions["%simg_input" % prefix].append(img)
                for level, cur_meshes_pred in enumerate(meshes_pred):
                    verts, faces = cur_meshes_pred.get_mesh(i)
                    verts_key = "%sverts_pred_%d" % (prefix, level)
                    faces_key = "%sfaces_pred_%d" % (prefix, level)
                    predictions[verts_key].append(verts.cpu().numpy())
                    predictions[faces_key].append(faces.cpu().numpy())

        num_predictions += len(meshes_gt)
        logger.info("Evaluated %d predictions so far" % num_predictions)
        if 0 < max_predictions <= num_predictions:
            break

    # Average numeric metrics, and concatenate images
    metrics = {"%s%s" % (prefix, k): np.mean(v) for k, v in metrics.items()}
    if store_predictions:
        img_key = "%simg_input" % prefix
        predictions[img_key] = np.stack(predictions[img_key], axis=0)

    return metrics, predictions
Example #6
def evaluate_test(model, data_loader, vis_preds=False):
    """
    This function evaluates the model on the dataset defined by data_loader.
    The metrics reported are described in Table 2 of our paper.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    deprocess = imagenet_deprocess(rescale_image=False)
    device = torch.device("cuda:0")
    # evaluation
    class_names = {
        "02828884": "bench",
        "03001627": "chair",
        "03636649": "lamp",
        "03691459": "speaker",
        "04090263": "firearm",
        "04379243": "table",
        "04530566": "watercraft",
        "02691156": "plane",
        "02933112": "cabinet",
        "02958343": "car",
        "03211117": "monitor",
        "04256520": "couch",
        "04401088": "cellphone",
    }

    num_instances = {i: 0 for i in class_names}
    chamfer = {i: 0 for i in class_names}
    normal = {i: 0 for i in class_names}
    f1_01 = {i: 0 for i in class_names}
    f1_03 = {i: 0 for i in class_names}
    f1_05 = {i: 0 for i in class_names}

    num_batch_evaluated = 0
    for batch in data_loader:
        batch = data_loader.postprocess(batch, device)
        imgs, meshes_gt, _, _, _, id_strs, _imgs = batch

        # NOTE: _imgs contains all of the other images belonging to this model.
        # We have to select the next-best-view from that list of images.

        sids = [id_str.split("-")[0] for id_str in id_strs]
        for sid in sids:
            num_instances[sid] += 1

        with inference_context(model):
            voxel_scores, meshes_pred = model(imgs)

            #TODO: Render masks from predicted mesh for each view

            cur_metrics = compare_meshes(meshes_pred[-1],
                                         meshes_gt,
                                         reduce=False)
            cur_metrics["verts_per_mesh"] = meshes_pred[-1].num_verts_per_mesh(
            ).cpu()
            cur_metrics["faces_per_mesh"] = meshes_pred[-1].num_faces_per_mesh(
            ).cpu()

            for i, sid in enumerate(sids):
                chamfer[sid] += cur_metrics["Chamfer-L2"][i].item()
                normal[sid] += cur_metrics["AbsNormalConsistency"][i].item()
                f1_01[sid] += cur_metrics["F1@%f" % 0.1][i].item()
                f1_03[sid] += cur_metrics["F1@%f" % 0.3][i].item()
                f1_05[sid] += cur_metrics["F1@%f" % 0.5][i].item()

                if vis_preds:
                    img = image_to_numpy(deprocess(imgs[i]))
                    vis_utils.visualize_prediction(id_strs[i], img,
                                                   meshes_pred[-1][i],
                                                   "/tmp/output")

            num_batch_evaluated += 1
            logger.info("Evaluated %d / %d batches" %
                        (num_batch_evaluated, len(data_loader)))

    vis_utils.print_instances_class_histogram(
        num_instances,
        class_names,
        {
            "chamfer": chamfer,
            "normal": normal,
            "f1_01": f1_01,
            "f1_03": f1_03,
            "f1_05": f1_05
        },
    )
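The `chamfer`, `normal`, and `f1_*` dicts above accumulate per-instance sums keyed by synset id; turning them into per-class averages is a division by `num_instances`. A hedged sketch (`per_class_means` is a hypothetical helper, presumably close to what `print_instances_class_histogram` does internally):

def per_class_means(sums, counts):
    # sums and counts are dicts keyed by ShapeNet synset id
    return {sid: sums[sid] / counts[sid] for sid in sums if counts[sid] > 0}

# e.g. per_class_means(chamfer, num_instances)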
Example #7
    def run_on_image(self, image, id_str, gt_verts, gt_faces):
        deprocess = imagenet_deprocess(rescale_image=False)

        with torch.no_grad():
            voxel_scores, meshes_pred = self.predictor(image)

        sid, mid, iid = id_str.split('-')
        iid = int(iid)

        #Transform vertex space
        metadata_path = os.path.join('./datasets/shapenet/ShapeNetV1processed',
                                     sid, mid, "metadata.pt")
        metadata = torch.load(metadata_path)
        K = metadata["intrinsic"]
        RTs = metadata["extrinsics"]
        rot_y_90 = torch.tensor([[0, 0, 1, 0], [0, 1, 0, 0], [-1, 0, 0, 0],
                                 [0, 0, 0, 1]]).to(RTs)

        mesh = meshes_pred[-1][0]
        # For some strange reason all classes (except the vehicle class) require
        # a 90-degree rotation about the y-axis for the GT mesh
        invRT = torch.inverse(RTs[iid].mm(rot_y_90))
        invRT_no_rot = torch.inverse(RTs[iid])
        mesh._verts_list[0] = project_verts(mesh._verts_list[0], invRT.cpu())

        #Get look at view extrinsics
        render_metadata_path = os.path.join(
            'datasets/shapenet/ShapeNetRenderingExtrinsics', sid, mid,
            'rendering_metadata.pt')
        render_metadata = torch.load(render_metadata_path)
        render_RTs = render_metadata['extrinsics']

        plt.figure(figsize=(10, 10))
        R = render_RTs[iid][:3, :3].unsqueeze(0)
        T = render_RTs[iid][:3, 3].unsqueeze(0)
        cameras = OpenGLPerspectiveCameras(R=R, T=T)

        #Phong Renderer
        lights = PointLights(location=[[0.0, 0.0, -3.0]])
        raster_settings = RasterizationSettings(image_size=256,
                                                blur_radius=0.0,
                                                faces_per_pixel=1,
                                                bin_size=0)
        phong_renderer = MeshRenderer(rasterizer=MeshRasterizer(
            cameras=cameras, raster_settings=raster_settings),
                                      shader=HardPhongShader(lights=lights))

        #Silhouette Renderer
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=256,
            blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
            faces_per_pixel=50,
        )
        silhouette_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras,
                                      raster_settings=raster_settings),
            shader=SoftSilhouetteShader(blend_params=blend_params))

        verts, faces = mesh.get_mesh_verts_faces(0)
        verts_rgb = torch.ones_like(verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        mesh.textures = textures

        verts_rgb = torch.ones_like(gt_verts)[None]
        textures = Textures(verts_rgb=verts_rgb)
        #Invert without the rotation for the vehicle class
        if sid == '02958343':
            gt_verts = project_verts(gt_verts, invRT_no_rot.cpu())
        else:
            gt_verts = project_verts(gt_verts, invRT.cpu())
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces], textures=textures)

        img = image_to_numpy(deprocess(image[0]))
        mesh_image = phong_renderer(meshes_world=mesh, R=R, T=T)
        gt_silh_image = (silhouette_renderer(meshes_world=gt_mesh, R=R, T=T) >
                         0).float()
        silhouette_image = (silhouette_renderer(meshes_world=mesh, R=R, T=T) >
                            0).float()

        plt.subplot(2, 2, 1)
        plt.imshow(img)
        plt.title('input image')
        plt.subplot(2, 2, 2)
        plt.imshow(mesh_image[0, ..., :3].cpu().numpy())
        plt.title('rendered mesh')
        plt.subplot(2, 2, 3)
        plt.imshow(gt_silh_image[0, ..., 3].cpu().numpy())
        plt.title('silhouette of gt mesh')
        plt.subplot(2, 2, 4)
        plt.imshow(silhouette_image[0, ..., 3].cpu().numpy())
        plt.title('silhouette of rendered mesh')

        plt.show()
        #plt.savefig('./output_demo/figures/'+id_str+'.png')

        vis_utils.visualize_prediction(id_str, img, mesh, self.output_dir)
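The silhouette renderer's `blur_radius = np.log(1. / 1e-4 - 1.) * blend_params.sigma` inverts the sigmoid falloff used by the soft rasterizer: a face stops influencing a pixel once its sigmoid weight drops below 1e-4. A quick numerical check of that identity:

import numpy as np

sigma = 1e-4
blur_radius = np.log(1.0 / 1e-4 - 1.0) * sigma
# at distance blur_radius, the face's sigmoid weight has decayed to 1e-4
print(1.0 / (1.0 + np.exp(blur_radius / sigma)))  # ~1e-4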
Example #8
def evaluate_split_depth(
    model, loader, max_predictions=-1, num_predictions_keep=10,
    prefix="", store_predictions=False
):
    """
    This function is used to report validation performance during training.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    device = torch.device("cuda:0")
    num_predictions = 0
    num_predictions_kept = 0
    total_l1_err = 0.0
    num_pixels = 0.0
    predictions = defaultdict(list)
    metrics = defaultdict(list)
    deprocess = imagenet_deprocess(rescale_image=False)
    for batch in loader:
        batch = {
            k: (v.to(device) if isinstance(v, torch.Tensor) else v)
            for k, v in batch.items()
        }
        model_kwargs = {}
        model_kwargs["extrinsics"] = batch["extrinsics"]
        model_outputs = model(batch["imgs"], **model_kwargs)
        pred_depths = model_outputs["depths"]
        loss = adaptive_berhu_loss(
            batch["depths"], pred_depths, batch["masks"]
        ).item()

        depth_gt = interpolate_multi_view_tensor(
            batch["depths"], pred_depths.shape[-2:]
        )
        mask = interpolate_multi_view_tensor(
            batch["masks"], pred_depths.shape[-2:]
        )
        masked_pred_depths = pred_depths * mask
        total_l1_err += torch.sum(torch.abs(masked_pred_depths - depth_gt)).item()
        num_pixels += torch.sum(mask).item()
        cur_metrics = {"depth_loss": loss, "negative_depth_loss": -loss}

        if cur_metrics is None:
            continue
        for k, v in cur_metrics.items():
            metrics[k].append(v)

        # Store input images and predicted meshes
        if store_predictions:
            N = batch["imgs"].shape[0]
            for i in range(N):
                if num_predictions_kept >= num_predictions_keep:
                    break
                num_predictions_kept += 1

                img = image_to_numpy(deprocess(batch["imgs"][i]))
                depth = image_to_numpy(batch["depths"][i].unsqueeze(0))
                pred_depth = image_to_numpy(pred_depths[i].unsqueeze(0))
                predictions["%simg_input" % prefix].append(img)
                predictions["%sdepth_input" % prefix].append(depth)
                predictions["%sdepth_pred" % prefix].append(pred_depth)

        num_predictions += len(batch["imgs"])
        logger.info(
            "Evaluated %d predictions so far: avg err: %f"
            % (num_predictions, total_l1_err / num_pixels)
        )
        if 0 < max_predictions <= num_predictions:
            break

    # Average numeric metrics, and concatenate images
    metrics = {"%s%s" % (prefix, k): np.mean(v) for k, v in metrics.items()}
    if store_predictions:
        keys = ["%simg_input", "%sdepth_input", "%sdepth_pred"]
        keys = [i % prefix for i in keys]
        predictions = {
            k: np.stack(predictions[k], axis=0) for k, v in predictions.items()
        }

    return metrics, predictions
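`adaptive_berhu_loss` itself is not shown in these examples. A minimal sketch of a berHu (reverse Huber) depth loss with an adaptive threshold in the style of Laina et al., assuming the `(gt, pred, mask)` argument order from the call above; the project's actual implementation may differ:

import torch

def berhu_loss_sketch(depth_gt, depth_pred, mask, c_factor=0.2):
    # mask: 1 where depth supervision is valid, 0 elsewhere
    err = (depth_pred - depth_gt).abs() * mask
    c = c_factor * err.max()                  # adaptive threshold
    quad = (err ** 2 + c ** 2) / (2 * c + 1e-8)
    loss = torch.where(err <= c, err, quad)   # L1 below c, quadratic above
    return loss.sum() / mask.sum().clamp(min=1)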
Example #9
def evaluate_test(model, data_loader, vis_preds=False):
    """
    This function evaluates the model on the dataset defined by data_loader.
    The metrics reported are described in Table 2 of our paper.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    deprocess = imagenet_deprocess(rescale_image=False)
    device = torch.device("cuda:0")
    # evaluation
    class_names = {
        "02828884": "bench",
        "03001627": "chair",
        "03636649": "lamp",
        "03691459": "speaker",
        "04090263": "firearm",
        "04379243": "table",
        "04530566": "watercraft",
        "02691156": "plane",
        "02933112": "cabinet",
        "02958343": "car",
        "03211117": "monitor",
        "04256520": "couch",
        "04401088": "cellphone",
    }

    num_instances = {i: 0 for i in class_names}
    chamfer = {i: 0 for i in class_names}
    normal = {i: 0 for i in class_names}
    f1_01 = {i: 0 for i in class_names}
    f1_03 = {i: 0 for i in class_names}
    f1_05 = {i: 0 for i in class_names}

    num_batch_evaluated = 0
    for batch in data_loader:
        batch = data_loader.postprocess(batch, device)
        sids = [id_str.split("-")[0] for id_str in batch["id_strs"]]
        for sid in sids:
            num_instances[sid] += 1

        with inference_context(model):
            model_kwargs = {}
            module = model.module if hasattr(model, "module") else model
            if isinstance(module, VoxMeshMultiViewHead):
                model_kwargs["intrinsics"] = batch["intrinsics"]
                model_kwargs["extrinsics"] = batch["extrinsics"]
            if isinstance(module, VoxMeshDepthHead):
                model_kwargs["masks"] = batch["masks"]

            model_outputs = model(batch["imgs"], **model_kwargs)
            voxel_scores = model_outputs["voxel_scores"]
            meshes_pred = model_outputs["meshes_pred"]

            cur_metrics = compare_meshes(meshes_pred[-1], batch["meshes"], reduce=False)
            cur_metrics["verts_per_mesh"] = meshes_pred[-1].num_verts_per_mesh().cpu()
            cur_metrics["faces_per_mesh"] = meshes_pred[-1].num_faces_per_mesh().cpu()

            for i, sid in enumerate(sids):
                chamfer[sid] += cur_metrics["Chamfer-L2"][i].item()
                normal[sid] += cur_metrics["AbsNormalConsistency"][i].item()
                f1_01[sid] += cur_metrics["F1@%f" % 0.1][i].item()
                f1_03[sid] += cur_metrics["F1@%f" % 0.3][i].item()
                f1_05[sid] += cur_metrics["F1@%f" % 0.5][i].item()

                if vis_preds:
                    img = image_to_numpy(deprocess(batch["imgs"][i]))
                    vis_utils.visualize_prediction(
                        batch["id_strs"][i], img, meshes_pred[-1][i], "/tmp/output"
                    )

            num_batch_evaluated += 1
            logger.info("Evaluated %d / %d batches" % (num_batch_evaluated, len(data_loader)))

    vis_utils.print_instances_class_histogram(
        num_instances,
        class_names,
        {"chamfer": chamfer, "normal": normal, "f1_01": f1_01, "f1_03": f1_03, "f1_05": f1_05},
    )
Example #10
def evaluate_vox(model, loader, prediction_dir=None, max_predictions=-1):
    """
    This function is used to report validation performance of voxel head output
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    if prediction_dir is not None:
        for prefix in ["merged", "vox_0", "vox_1", "vox_2"]:
            output_dir = os.path.join(
                prediction_dir, prefix, "predict", "0"
            )
            os.makedirs(output_dir, exist_ok=True)

    device = torch.device("cuda:0")
    metrics = defaultdict(list)
    deprocess = imagenet_deprocess(rescale_image=False)
    for batch_idx, batch in tqdm.tqdm(enumerate(loader)):
        if max_predictions >= 1 and batch_idx > max_predictions:
            break
        batch = loader.postprocess(batch, device)
        model_kwargs = {}
        module = model.module if hasattr(model, "module") else model
        if isinstance(module, VoxMeshMultiViewHead):
            model_kwargs["intrinsics"] = batch["intrinsics"]
            model_kwargs["extrinsics"] = batch["extrinsics"]
        if isinstance(module, VoxDepthHead):
            model_kwargs["masks"] = batch["masks"]
            if module.cfg.MODEL.USE_GT_DEPTH:
                model_kwargs["depths"] = batch["depths"]
        model_outputs = model(batch["imgs"], **model_kwargs)
        voxel_scores = model_outputs["voxel_scores"]
        transformed_voxel_scores = model_outputs["transformed_voxel_scores"]
        merged_voxel_scores = model_outputs.get(
            "merged_voxel_scores", None
        )

        # NOTE that for the F1 thresholds we take the square roots of 1e-4 and 2e-4,
        # as `compare_meshes` returns the Euclidean distance (L2) between two point clouds.
        # In Pixel2Mesh, the squared L2 (L2^2) is computed instead,
        # i.e. (L2^2 < τ) <=> (L2 < sqrt(τ))
        if "meshes_pred" in model_outputs:
            meshes_pred = model_outputs["meshes_pred"]
            cur_metrics = compare_meshes(
                meshes_pred[-1], batch["meshes"],
                scale=0.57, thresholds=[0.01, 0.014142]
            )
            for k, v in cur_metrics.items():
                metrics["final_" + k].append(v)

        voxel_losses = MeshLoss.voxel_loss(
            voxel_scores, merged_voxel_scores, batch["voxels"]
        )
        # to get metric negate loss
        for k, v in voxel_losses.items():
            metrics[k].append(-v.detach().item())

        # save meshes
        if prediction_dir is not None:
            # cubify all the voxel scores
            merged_vox_mesh = cubify(
                merged_voxel_scores, module.voxel_size, module.cubify_threshold
            )
            # transformed_vox_mesh = [cubify(
            #     i, module.voxel_size, module.cubify_threshold
            # ) for i in transformed_voxel_scores]
            vox_meshes = {
                "merged": merged_vox_mesh,
                # **{
                #     "vox_%d" % i: mesh
                #     for i, mesh in enumerate(transformed_vox_mesh)
                # }
            }

            gt_mesh = batch["meshes"].scale_verts(0.57)
            gt_points = sample_points_from_meshes(
                gt_mesh, 9000, return_normals=False
            )
            gt_points = gt_points.cpu().detach().numpy()

            for mesh_idx in range(len(batch["id_strs"])):
                label, label_appendix = batch["id_strs"][mesh_idx].split("-")[:2]
                for prefix, vox_mesh in vox_meshes.items():
                    output_dir = os.path.join(
                        prediction_dir, prefix, "predict", "0"
                    )
                    pred_filename = os.path.join(
                        output_dir,
                        "{}_{}_predict.xyz".format(label, label_appendix)
                    )
                    gt_filename = os.path.join(
                        output_dir,
                        "{}_{}_ground.xyz".format(label, label_appendix)
                    )

                    pred_mesh = vox_mesh[mesh_idx].scale_verts(0.57)
                    pred_points = sample_points_from_meshes(
                        pred_mesh, 6466, return_normals=False
                    )
                    pred_points = pred_points.squeeze(0).cpu() \
                                             .detach().numpy()

                    np.savetxt(pred_filename, pred_points)
                    np.savetxt(gt_filename, gt_points[mesh_idx])

            # find accuracy of each cubified voxel meshes
            for prefix, vox_mesh in vox_meshes.items():
                vox_mesh_metrics = compare_meshes(
                    vox_mesh, batch["meshes"],
                    scale=0.57, thresholds=[0.01, 0.014142]
                )

                if vox_mesh_metrics is None:
                    continue
                for k, v in vox_mesh_metrics.items():
                    metrics[prefix + "_" + k].append(v)

    # Average numeric metrics, and concatenate images
    metrics = {k: np.mean(v) for k, v in metrics.items()}
    return metrics
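The threshold comment inside `evaluate_vox` is easy to sanity-check: 0.01 and 0.014142 are sqrt(1e-4) and sqrt(2e-4), so comparing the L2 distance against them is equivalent to comparing the squared L2 against 1e-4 and 2e-4:

import math

print(math.sqrt(1e-4))  # 0.01
print(math.sqrt(2e-4))  # 0.014142...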
Example #11
def evaluate_split(
    model, loader, max_predictions=-1, num_predictions_keep=10, prefix="", store_predictions=False
):
    """
    This function is used to report validation performance during training.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    device = torch.device("cuda:0")
    num_predictions = 0
    num_predictions_kept = 0
    predictions = defaultdict(list)
    metrics = defaultdict(list)
    deprocess = imagenet_deprocess(rescale_image=False)
    for batch in loader:
        batch = loader.postprocess(batch, device)
        model_kwargs = {}
        module = model.module if hasattr(model, "module") else model
        if isinstance(module, VoxMeshMultiViewHead):
            model_kwargs["intrinsics"] = batch["intrinsics"]
            model_kwargs["extrinsics"] = batch["extrinsics"]
        if isinstance(module, VoxMeshDepthHead):
            model_kwargs["masks"] = batch["masks"]
            if module.cfg.MODEL.USE_GT_DEPTH:
                model_kwargs["depths"] = batch["depths"]
        model_outputs = model(batch["imgs"], **model_kwargs)
        meshes_pred = model_outputs["meshes_pred"]
        voxel_scores = model_outputs["voxel_scores"]
        merged_voxel_scores = model_outputs.get(
            "merged_voxel_scores", None
        )

        # Only compute metrics for the final predicted meshes, not intermediates
        cur_metrics = compare_meshes(meshes_pred[-1], batch["meshes"])
        if cur_metrics is None:
            continue
        for k, v in cur_metrics.items():
            metrics[k].append(v)

        voxel_losses = MeshLoss.voxel_loss(
            voxel_scores, merged_voxel_scores, batch["voxels"]
        )
        # to get metric negate loss
        for k, v in voxel_losses.items():
            metrics[k].append(-v.item())

        # Store input images and predicted meshes
        if store_predictions:
            N = batch["imgs"].shape[0]
            for i in range(N):
                if num_predictions_kept >= num_predictions_keep:
                    break
                num_predictions_kept += 1

                img = image_to_numpy(deprocess(batch["imgs"][i]))
                predictions["%simg_input" % prefix].append(img)
                for level, cur_meshes_pred in enumerate(meshes_pred):
                    verts, faces = cur_meshes_pred.get_mesh(i)
                    verts_key = "%sverts_pred_%d" % (prefix, level)
                    faces_key = "%sfaces_pred_%d" % (prefix, level)
                    predictions[verts_key].append(verts.cpu().numpy())
                    predictions[faces_key].append(faces.cpu().numpy())

        num_predictions += len(batch["meshes"])
        logger.info("Evaluated %d predictions so far" % num_predictions)
        if 0 < max_predictions <= num_predictions:
            break

    # Average numeric metrics, and concatenate images
    metrics = {"%s%s" % (prefix, k): np.mean(v) for k, v in metrics.items()}
    if store_predictions:
        img_key = "%simg_input" % prefix
        predictions[img_key] = np.stack(predictions[img_key], axis=0)

    return metrics, predictions