Пример #1
0
    def _process_mesh(self, mesh, transforms, R=None, t=None):
        """Map a mesh into the camera frame and replay the image transforms on it.

        The vertices are passed through ``shape_utils.transform_verts`` with
        ``R`` and ``t``. Then every transform in ``transforms.transforms`` is
        applied to them in order:

        * ``HFlipTransform`` negates the x coordinate,
        * ``ResizeTransform`` is applied through its ``apply_coords``,
        * ``NoOpTransform`` leaves the vertices unchanged.

        Args:
            mesh: ``(verts, faces)`` pair; only ``verts`` is transformed.
            transforms: object whose ``.transforms`` attribute lists the image
                transforms that were applied.
            R: forwarded unchanged to ``shape_utils.transform_verts``.
            t: forwarded unchanged to ``shape_utils.transform_verts``.

        Returns:
            ``(verts, faces)``: the transformed vertices and the original faces.

        Raises:
            ValueError: if any transform is not one of the three supported
                types. All transforms are checked before any is applied.
        """
        verts, faces = mesh
        # transform vertices to camera coordinate system
        verts = shape_utils.transform_verts(verts, R, t)

        supported = (T.HFlipTransform, T.NoOpTransform, T.ResizeTransform)
        # Explicit raise rather than ``assert``: an assert is stripped under
        # ``python -O``, which would let an unsupported transform be detected
        # only after earlier transforms had already modified ``verts``.
        for tfm in transforms.transforms:
            if not isinstance(tfm, supported):
                raise ValueError("Transform {} not recognized".format(tfm))

        # ``tfm`` (not ``t``) so the translation argument is not shadowed.
        for tfm in transforms.transforms:
            if isinstance(tfm, T.HFlipTransform):
                # in-place write into the array returned by transform_verts
                verts[:, 0] = -verts[:, 0]
            elif isinstance(tfm, T.ResizeTransform):
                verts = tfm.apply_coords(verts)
            # NoOpTransform: nothing to do.
        return verts, faces
Пример #2
0
 def _process_dz(self, mesh, transforms, focal_length=1.0, R=None, t=None):
     """Compute the normalized depth-extent (dz) target of a mesh.

     The vertices are first passed through ``shape_utils.transform_verts``
     with ``R`` and ``t``. The z-extent ``max_z - min_z`` is then divided by
     the z-center ``(max_z + min_z) / 2`` and multiplied by
     ``focal_length``. Finally, each ``ResizeTransform`` in
     ``transforms.transforms`` scales the result by ``new_h / h``.
     Flips and no-ops do not change dz.

     Args:
         mesh: ``(verts, faces)`` pair; only ``verts`` is used.
         transforms: object whose ``.transforms`` attribute lists the image
             transforms that were applied.
         focal_length: multiplier applied to the normalized extent.
         R: forwarded unchanged to ``shape_utils.transform_verts``.
         t: forwarded unchanged to ``shape_utils.transform_verts``.

     Returns:
         The scaled dz value.

     Raises:
         ValueError: if any transform is not HFlip, NoOp or Resize.
     """
     verts, _faces = mesh
     # transform vertices to camera coordinate system
     verts = shape_utils.transform_verts(verts, R, t)

     supported = (T.HFlipTransform, T.NoOpTransform, T.ResizeTransform)
     # Explicit raise rather than ``assert`` (stripped under ``python -O``),
     # matching the ValueError used by the sibling mesh/voxel processors.
     for tfm in transforms.transforms:
         if not isinstance(tfm, supported):
             raise ValueError("Transform {} not recognized".format(tfm))

     z = verts[:, 2]
     z_min, z_max = z.min(), z.max()
     z_center = (z_max + z_min) / 2.0
     # NOTE(review): a z-center of 0 would divide by zero here; assumed not
     # to happen for objects placed in front of the camera -- confirm.
     dz = (z_max - z_min) / z_center
     dz = dz * focal_length

     # ``tfm`` (not ``t``) so the translation argument is not shadowed.
     for tfm in transforms.transforms:
         # NOTE normalize the dz by the height scaling of the image.
         # This is necessary s.t. z-regression targets log(dz/roi_h)
         # are invariant to the scaling of the roi_h
         if isinstance(tfm, T.ResizeTransform):
             dz = dz * (tfm.new_h * 1.0 / tfm.h)
     return dz
Пример #3
0
    def _process_voxel(self, voxel, transforms, R=None, t=None):
        """Read a voxel as vertices, map it into the camera frame, replay transforms.

        The voxel is read with ``shape_utils.read_voxel``. The resulting
        vertices are passed through ``shape_utils.transform_verts`` with ``R``
        and ``t``. Every transform in ``transforms.transforms`` is then applied
        in order.

        This does not support generic transforms in ``T``. ``apply_coords``
        is only known to work for voxels with the following transforms:

        * ``HFlipTransform`` negates the x coordinate,
        * ``ResizeTransform`` is applied through its ``apply_coords``,
        * ``NoOpTransform`` leaves the vertices unchanged.

        Args:
            voxel: value passed to ``shape_utils.read_voxel``.
            transforms: object whose ``.transforms`` attribute lists the image
                transforms that were applied.
            R: forwarded unchanged to ``shape_utils.transform_verts``.
            t: forwarded unchanged to ``shape_utils.transform_verts``.

        Returns:
            The transformed vertices.

        Raises:
            ValueError: if any transform is not one of the three supported
                types. All transforms are checked before any is applied.
        """
        # read voxel
        verts = shape_utils.read_voxel(voxel)
        # transform vertices to camera coordinate system
        verts = shape_utils.transform_verts(verts, R, t)

        supported = (T.HFlipTransform, T.NoOpTransform, T.ResizeTransform)
        # Explicit raise rather than ``assert``: an assert is stripped under
        # ``python -O``, which would let an unsupported transform be detected
        # only after earlier transforms had already modified ``verts``.
        for tfm in transforms.transforms:
            if not isinstance(tfm, supported):
                raise ValueError("Transform {} not recognized".format(tfm))

        # ``tfm`` (not ``t``) so the translation argument is not shadowed.
        for tfm in transforms.transforms:
            if isinstance(tfm, T.HFlipTransform):
                # in-place write into the array returned by transform_verts
                verts[:, 0] = -verts[:, 0]
            elif isinstance(tfm, T.ResizeTransform):
                verts = tfm.apply_coords(verts)
            # NoOpTransform: nothing to do.
        return verts
Пример #4
0
def evaluate_for_pix3d(
    predictions,
    dataset,
    metadata,
    filter_iou,
    mesh_models=None,
    iou_thresh=0.5,
    mask_thresh=0.5,
    device=None,
    vis_preds=False,
):
    """Compute per-category and mean box / mask / mesh AP on Pix3D.

    Each image has exactly one ground-truth annotation. Predictions are
    considered only if their box IoU with it exceeds ``filter_iou``.
    Predictions are visited in descending score order. A prediction is a
    true positive for a metric when all of the following hold:

    * its label matches the ground truth,
    * its metric value exceeds ``iou_thresh`` (mask IoU, box IoU, or mesh
      F1 at 0.3 divided by 100),
    * no earlier prediction on the same image has already matched for that
      metric.

    Args:
        predictions: iterable of dicts with an ``"image_id"`` key and,
            optionally, an ``"instances"`` object. The instances object
            provides ``scores``, ``pred_boxes``, ``pred_classes``,
            ``pred_masks_rle``, ``pred_dz`` and optionally ``pred_meshes``.
        dataset: COCO-style API object (``getCatIds``, ``loadImgs``,
            ``getAnnIds``, ``loadAnns``, ``loadCats``, ``dataset``).
        metadata: provides ``thing_dataset_id_to_contiguous_id`` and
            ``image_root``.
        filter_iou: predictions with box IoU <= this value are skipped.
        mesh_models: mapping from ``gt_anns["model"]`` to ``(verts, faces)``.
            Loading from disk is not implemented.
        iou_thresh: threshold applied to mask IoU, box IoU and mesh F1.
        mask_thresh: not used by this function; kept for API compatibility.
        device: torch device for bookkeeping tensors (default: CPU).
        vis_preds: if True, write visualizations to ``/tmp/output``.

    Returns:
        dict with ``"<box|mask|mesh>_ap@<thr> - <category>"`` entries per
        category with positives, followed by ``"<box|mask|mesh>_ap@<thr>"``
        means over those categories.

    Raises:
        ValueError: if evaluated instances lack ``pred_dz``.
        NotImplementedError: if ``mesh_models`` is None when a prediction
            has to be evaluated.
    """
    from PIL import Image

    if device is None:
        device = torch.device("cpu")

    # Key of the mesh F1 score (distance threshold 0.3) in compare_meshes'
    # output. NOTE(review): assumes compare_meshes names F1 entries
    # "F1@%f" % thresh (i.e. "F1@0.300000"), as meshrcnn's metrics do -- confirm.
    F1_TARGET = "F1@%f" % 0.3

    # Images whose masks do not match the image; ignored for both the
    # positive counts and the predictions.
    ignored_images = {"img/table/1749.jpg", "img/table/0045.png"}

    # classes
    cat_ids = sorted(dataset.getCatIds())
    cat_names = {
        cat_id: dataset.loadCats([cat_id])[0]["name"] for cat_id in cat_ids
    }
    # contiguous (model) class id -> dataset category id
    reverse_id_mapping = {
        v: k
        for k, v in metadata.thing_dataset_id_to_contiguous_id.items()
    }

    # Per metric and category: lists of score / true-positive-flag tensors.
    # Each list starts with an empty tensor so torch.cat works even for a
    # category that receives no predictions.
    metric_names = ("box", "mask", "mesh")
    apscores = {
        name: {
            cat_id: [torch.tensor([], dtype=torch.float32, device=device)]
            for cat_id in cat_ids
        }
        for name in metric_names
    }
    aplabels = {
        name: {
            cat_id: [torch.tensor([], dtype=torch.uint8, device=device)]
            for cat_id in cat_ids
        }
        for name in metric_names
    }
    # Per metric: ids of images whose gt instance has already been matched
    # (a set, so membership checks are O(1)).
    covered = {name: set() for name in metric_names}
    npos = {cat_id: 0.0 for cat_id in cat_ids}

    # number of gt positive instances per class
    for gt_ann in dataset.dataset["annotations"]:
        image_file_name = dataset.loadImgs([gt_ann["image_id"]])[0]["file_name"]
        if image_file_name in ignored_images:
            continue
        npos[gt_ann["category_id"]] += 1.0

    for prediction in predictions:
        original_id = prediction["image_id"]
        img_info = dataset.loadImgs([original_id])[0]
        image_width = img_info["width"]
        image_height = img_info["height"]
        image_size = [image_height, image_width]
        image_file_name = img_info["file_name"]
        if image_file_name in ignored_images:
            continue

        if "instances" not in prediction:
            continue
        instances = prediction["instances"]
        num_img_preds = len(instances)
        if num_img_preds == 0:
            continue

        # predictions
        scores = instances.scores
        boxes = instances.pred_boxes.to(device)
        labels = instances.pred_classes
        masks_rles = instances.pred_masks_rle
        if hasattr(instances, "pred_meshes"):
            pred_meshes = instances.pred_meshes
            meshes = Meshes(
                verts=[mesh[0] for mesh in pred_meshes],
                faces=[mesh[1] for mesh in pred_meshes],
            ).to(device)
        else:
            # no predicted meshes: use an ico_sphere(4) template per detection
            meshes = ico_sphere(4, device).extend(num_img_preds).to(device)
        if not hasattr(instances, "pred_dz"):
            raise ValueError("Z range of box not predicted")
        heights = boxes.tensor[:, 3] - boxes.tensor[:, 1]
        # NOTE see appendix for derivation of pred dz
        pred_dz = instances.pred_dz[:, 0] * heights.cpu()
        assert instances.image_size[0] == image_height
        assert instances.image_size[1] == image_width

        # ground truth: pix3d has exactly one annotation per image
        gt_ann_ids = dataset.getAnnIds(imgIds=[original_id])
        assert len(gt_ann_ids) == 1
        gt_anns = dataset.loadAnns(gt_ann_ids)[0]
        assert gt_anns["image_id"] == original_id

        # get original ground truth mask, box, label & mesh
        maskfile = os.path.join(metadata.image_root, gt_anns["segmentation"])
        with PathManager.open(maskfile, "rb") as f:
            gt_mask = torch.tensor(
                np.asarray(Image.open(f), dtype=np.float32) / 255.0)
        assert (gt_mask.shape[0] == image_height
                and gt_mask.shape[1] == image_width)

        gt_mask = (gt_mask > 0).to(dtype=torch.uint8)  # binarize mask
        gt_mask_rle = [
            mask_util.encode(np.array(gt_mask[:, :, None], order="F"))[0]
        ]
        gt_box = np.array(gt_anns["bbox"]).reshape(-1, 4)  # xywh from coco
        gt_box = BoxMode.convert(gt_box, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
        gt_label = gt_anns["category_id"]
        faux_gt_targets = Boxes(
            torch.tensor(gt_box, dtype=torch.float32, device=device))

        # load gt mesh and extrinsics/intrinsics
        gt_R = torch.tensor(gt_anns["rot_mat"]).to(device)
        gt_t = torch.tensor(gt_anns["trans_mat"]).to(device)
        gt_K = torch.tensor(gt_anns["K"]).to(device)
        if mesh_models is None:
            # loading gt meshes from disk is not supported
            raise NotImplementedError
        modeltype = gt_anns["model"]
        gt_verts = mesh_models[modeltype][0].clone().to(device)
        gt_faces = mesh_models[modeltype][1].clone().to(device)
        gt_verts = shape_utils.transform_verts(gt_verts, gt_R, gt_t)
        gt_zrange = torch.stack([gt_verts[:, 2].min(), gt_verts[:, 2].max()])
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces])

        # box iou; only predictions with iou > filter_iou are evaluated
        boxiou = pairwise_iou(boxes, faux_gt_targets)
        valid_pred_ids = boxiou > filter_iou

        # mask iou
        miou = mask_util.iou(masks_rles, gt_mask_rle, [0])

        # It's impossible to predict the center location in Z (=tc) from the
        # image, so it is taken from the gt mesh. See appendix for more.
        tc = (gt_zrange[1] + gt_zrange[0]) / 2.0
        # Given tc and the focal length gt_K[0], the predicted z range is
        # tc -/+ tc * pred_dz / 2 / f. See appendix for more.
        zranges = torch.stack(
            [
                torch.stack([
                    tc - tc * pred_dz[i] / 2.0 / gt_K[0],
                    tc + tc * pred_dz[i] / 2.0 / gt_K[0],
                ])
                for i in range(len(meshes))
            ],
            dim=0,
        )

        gt_Ks = gt_K.view(1, 3).expand(len(meshes), 3)
        meshes = transform_meshes_to_camera_coord_system(
            meshes, boxes.tensor, zranges, gt_Ks, image_size)

        if vis_preds:
            vis_utils.visualize_predictions(
                original_id,
                image_file_name,
                scores,
                labels,
                boxes.tensor,
                masks_rles,
                meshes,
                metadata,
                "/tmp/output",
            )

        shape_metrics = compare_meshes(meshes, gt_mesh, reduce=False)

        # Visit predictions in descending score order so the highest-scoring
        # match is the one that claims the gt instance.
        _, idx_sorted = torch.sort(scores, descending=True)

        for idx in idx_sorted:
            if valid_pred_ids[idx, 0] == 0:
                continue
            # map to dataset category id
            pred_label = reverse_id_mapping[labels[idx].item()]
            pred_score = scores[idx].view(1).to(device)
            # order matches the original bookkeeping: mask, box, mesh.
            # compare_meshes reports F1 in % (=x100).
            metric_values = {
                "mask": miou[idx].item(),
                "box": boxiou[idx].item(),
                "mesh": shape_metrics[F1_TARGET][idx].item() / 100.0,
            }
            for name, value in metric_values.items():
                tpfp = torch.tensor([0], dtype=torch.uint8, device=device)
                if (pred_label == gt_label and value > iou_thresh
                        and original_id not in covered[name]):
                    tpfp[0] = 1
                    covered[name].add(original_id)
                apscores[name][pred_label].append(pred_score)
                aplabels[name][pred_label].append(tpfp)

    # per-category AP (box, mask, mesh) for categories with gt positives
    pix3d_metrics = {}
    ap_sums = {name: 0.0 for name in metric_names}
    valid = 0.0
    for cat_id in cat_ids:
        if npos[cat_id] == 0:
            continue
        valid += 1
        for name in metric_names:
            cat_ap = VOCap.compute_ap(
                torch.cat(apscores[name][cat_id]),
                torch.cat(aplabels[name][cat_id]),
                npos[cat_id],
            )
            ap_sums[name] += cat_ap
            pix3d_metrics["%s_ap@%.1f - %s" %
                          (name, iou_thresh, cat_names[cat_id])] = cat_ap

    # NOTE(review): raises ZeroDivisionError if no category has gt positives.
    for name in metric_names:
        pix3d_metrics["%s_ap@%.1f" % (name, iou_thresh)] = ap_sums[name] / valid

    # print test ground truth
    vis_utils.print_instances_class_histogram(
        [npos[cat_id] for cat_id in cat_ids],  # number of instances
        [cat_names[cat_id] for cat_id in cat_ids],  # class names
        pix3d_metrics,
    )

    return pix3d_metrics