def evaluate_test(model, data_loader, vis_preds=False):
    """
    This function evaluates the model on the dataset defined by data_loader.
    The metrics reported are described in Table 2 of our paper.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    deprocess = imagenet_deprocess(rescale_image=False)
    device = torch.device("cuda:0")
    # evaluation
    class_names = {
        "02828884": "bench",
        "03001627": "chair",
        "03636649": "lamp",
        "03691459": "speaker",
        "04090263": "firearm",
        "04379243": "table",
        "04530566": "watercraft",
        "02691156": "plane",
        "02933112": "cabinet",
        "02958343": "car",
        "03211117": "monitor",
        "04256520": "couch",
        "04401088": "cellphone",
    }
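    # (the keys above are ShapeNet synset IDs; they are matched against the
    # sid prefix of each id_str below)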

    num_instances = {i: 0 for i in class_names}
    chamfer = {i: 0 for i in class_names}
    normal = {i: 0 for i in class_names}
    f1_01 = {i: 0 for i in class_names}
    f1_03 = {i: 0 for i in class_names}
    f1_05 = {i: 0 for i in class_names}

    num_batch_evaluated = 0
    for batch in data_loader:
        batch = data_loader.postprocess(batch, device)
        imgs, meshes_gt, _, _, _, id_strs, _imgs = batch

        # NOTE: _imgs contains all of the other images belonging to this model.
        # We have to select the next-best-view from that list of images.

        sids = [id_str.split("-")[0] for id_str in id_strs]
        for sid in sids:
            num_instances[sid] += 1

        with inference_context(model):
            voxel_scores, meshes_pred = model(imgs)

            # TODO: Render masks from predicted mesh for each view

            cur_metrics = compare_meshes(meshes_pred[-1], meshes_gt, reduce=False)
            cur_metrics["verts_per_mesh"] = meshes_pred[-1].num_verts_per_mesh().cpu()
            cur_metrics["faces_per_mesh"] = meshes_pred[-1].num_faces_per_mesh().cpu()

            for i, sid in enumerate(sids):
                chamfer[sid] += cur_metrics["Chamfer-L2"][i].item()
                normal[sid] += cur_metrics["AbsNormalConsistency"][i].item()
                f1_01[sid] += cur_metrics["F1@%f" % 0.1][i].item()
                f1_03[sid] += cur_metrics["F1@%f" % 0.3][i].item()
                f1_05[sid] += cur_metrics["F1@%f" % 0.5][i].item()

                if vis_preds:
                    img = image_to_numpy(deprocess(imgs[i]))
                    vis_utils.visualize_prediction(id_strs[i], img,
                                                   meshes_pred[-1][i],
                                                   "/tmp/output")

            num_batch_evaluated += 1
            logger.info("Evaluated %d / %d batches" %
                        (num_batch_evaluated, len(data_loader)))

    vis_utils.print_instances_class_histogram(
        num_instances,
        class_names,
        {
            "chamfer": chamfer,
            "normal": normal,
            "f1_01": f1_01,
            "f1_03": f1_03,
            "f1_05": f1_05
        },
    )
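

# A minimal usage sketch for evaluate_test. The helper names `build_model` and
# `build_data_loader` are assumptions for illustration, not part of this file;
# any loader exposing the postprocess() method used above would work:
#
#     model = build_model(cfg).to(torch.device("cuda:0"))
#     test_loader = build_data_loader(cfg, "MeshVox", "test")
#     evaluate_test(model, test_loader, vis_preds=False)

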
def evaluate_split(model,
                   loader,
                   max_predictions=-1,
                   num_predictions_keep=10,
                   prefix="",
                   store_predictions=False):
    """
    This function is used to report validation performance during training.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    device = torch.device("cuda:0")
    num_predictions = 0
    num_predictions_kept = 0
    predictions = defaultdict(list)
    metrics = defaultdict(list)
    deprocess = imagenet_deprocess(rescale_image=False)
    for batch in loader:
        batch = loader.postprocess(batch, device)
        imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch
        voxel_scores, meshes_pred = model(imgs)

        # Only compute metrics for the final predicted meshes, not intermediates
        cur_metrics = compare_meshes(meshes_pred[-1], meshes_gt)
        if cur_metrics is None:
            continue
        for k, v in cur_metrics.items():
            metrics[k].append(v)

        # Store input images and predicted meshes
        if store_predictions:
            N = imgs.shape[0]
            for i in range(N):
                if num_predictions_kept >= num_predictions_keep:
                    break
                num_predictions_kept += 1

                img = image_to_numpy(deprocess(imgs[i]))
                predictions["%simg_input" % prefix].append(img)
                for level, cur_meshes_pred in enumerate(meshes_pred):
                    verts, faces = cur_meshes_pred.get_mesh(i)
                    verts_key = "%sverts_pred_%d" % (prefix, level)
                    faces_key = "%sfaces_pred_%d" % (prefix, level)
                    predictions[verts_key].append(verts.cpu().numpy())
                    predictions[faces_key].append(faces.cpu().numpy())

        num_predictions += len(meshes_gt)
        logger.info("Evaluated %d predictions so far" % num_predictions)
        if 0 < max_predictions <= num_predictions:
            break

    # Average numeric metrics, and concatenate images
    metrics = {"%s%s" % (prefix, k): np.mean(v) for k, v in metrics.items()}
    if store_predictions:
        img_key = "%simg_input" % prefix
        predictions[img_key] = np.stack(predictions[img_key], axis=0)

    return metrics, predictions
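

# A minimal sketch of consuming evaluate_split's outputs (illustrative only;
# the key names follow the prefix/level scheme built above):
#
#     metrics, predictions = evaluate_split(
#         model, val_loader, prefix="val_", store_predictions=True
#     )
#     imgs = predictions["val_img_input"]          # stacked input images
#     verts0 = predictions["val_verts_pred_0"][0]  # first mesh, refinement level 0

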
def evaluate_test_p2m(model, data_loader):
    """
    This function evaluates the model on the dataset defined by data_loader.
    The metrics reported are described in Table 1 of our paper, following
    previously reported approaches (like Pixel2Mesh - p2m), where meshes are
    rescaled by a factor of 0.57. See the paper for more details.
    """
    assert comm.is_main_process()
    device = torch.device("cuda:0")
    # evaluation
    class_names = {
        "02828884": "bench",
        "03001627": "chair",
        "03636649": "lamp",
        "03691459": "speaker",
        "04090263": "firearm",
        "04379243": "table",
        "04530566": "watercraft",
        "02691156": "plane",
        "02933112": "cabinet",
        "02958343": "car",
        "03211117": "monitor",
        "04256520": "couch",
        "04401088": "cellphone",
    }

    num_instances = {i: 0 for i in class_names}
    chamfer = {i: 0 for i in class_names}
    normal = {i: 0 for i in class_names}
    f1_1e_4 = {i: 0 for i in class_names}
    f1_2e_4 = {i: 0 for i in class_names}

    num_batch_evaluated = 0
    for batch in data_loader:
        batch = data_loader.postprocess(batch, device)
        imgs, meshes_gt, _, _, _, id_strs = batch
        sids = [id_str.split("-")[0] for id_str in id_strs]
        for sid in sids:
            num_instances[sid] += 1

        with inference_context(model):
            voxel_scores, meshes_pred = model(imgs)
            # NOTE that for the F1 thresholds we take the square roots of 1e-4 & 2e-4,
            # as `compare_meshes` returns the Euclidean distance (L2) between two
            # point clouds. In Pixel2Mesh, the squared L2 (L2^2) is computed instead,
            # i.e. (L2^2 < τ) <=> (L2 < sqrt(τ))
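            # For example (pure arithmetic, nothing repo-specific): sqrt(1e-4) = 0.01
            # and sqrt(2e-4) ~= 0.014142, which are exactly the thresholds passed to
            # compare_meshes below.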
            cur_metrics = compare_meshes(meshes_pred[-1],
                                         meshes_gt,
                                         scale=0.57,
                                         thresholds=[0.01, 0.014142],
                                         reduce=False)
            cur_metrics["verts_per_mesh"] = meshes_pred[-1].num_verts_per_mesh(
            ).cpu()
            cur_metrics["faces_per_mesh"] = meshes_pred[-1].num_faces_per_mesh(
            ).cpu()

            for i, sid in enumerate(sids):
                chamfer[sid] += cur_metrics["Chamfer-L2"][i].item()
                normal[sid] += cur_metrics["AbsNormalConsistency"][i].item()
                f1_1e_4[sid] += cur_metrics["F1@%f" % 0.01][i].item()
                f1_2e_4[sid] += cur_metrics["F1@%f" % 0.014142][i].item()

            num_batch_evaluated += 1
            logger.info("Evaluated %d / %d batches" %
                        (num_batch_evaluated, len(data_loader)))

    vis_utils.print_instances_class_histogram_p2m(
        num_instances,
        class_names,
        {
            "chamfer": chamfer,
            "normal": normal,
            "f1_1e_4": f1_1e_4,
            "f1_2e_4": f1_2e_4
        },
    )


# Example 4
def evaluate_for_pix3d(
    predictions,
    dataset,
    metadata,
    filter_iou,
    mesh_models=None,
    iou_thresh=0.5,
    mask_thresh=0.5,
    device=None,
    vis_preds=False,
):
    from PIL import Image

    if device is None:
        device = torch.device("cpu")

    # F1 at the 0.3 threshold, as reported by compare_meshes ("F1@%f" % 0.3)
    F1_TARGET = "F1@0.300000"

    # classes
    cat_ids = sorted(dataset.getCatIds())
    reverse_id_mapping = {
        v: k
        for k, v in metadata.thing_dataset_id_to_contiguous_id.items()
    }

    # initialize tensors to record box, mask & mesh AP, and the number of gt positives
    box_apscores, box_aplabels = {}, {}
    mask_apscores, mask_aplabels = {}, {}
    mesh_apscores, mesh_aplabels = {}, {}
    npos = {}
    for cat_id in cat_ids:
        box_apscores[cat_id] = [torch.tensor([], dtype=torch.float32, device=device)]
        box_aplabels[cat_id] = [torch.tensor([], dtype=torch.uint8, device=device)]
        mask_apscores[cat_id] = [torch.tensor([], dtype=torch.float32, device=device)]
        mask_aplabels[cat_id] = [torch.tensor([], dtype=torch.uint8, device=device)]
        mesh_apscores[cat_id] = [torch.tensor([], dtype=torch.float32, device=device)]
        mesh_aplabels[cat_id] = [torch.tensor([], dtype=torch.uint8, device=device)]
        npos[cat_id] = 0.0
    box_covered = []
    mask_covered = []
    mesh_covered = []

    # number of gt positive instances per class
    for gt_ann in dataset.dataset["annotations"]:
        gt_label = gt_ann["category_id"]
        # examples with imgfiles = {img/table/1749.jpg, img/table/0045.png}
        # have a mismatch between images and masks, so we ignore them
        image_file_name = dataset.loadImgs([gt_ann["image_id"]])[0]["file_name"]
        if image_file_name in ["img/table/1749.jpg", "img/table/0045.png"]:
            continue
        npos[gt_label] += 1.0

    for prediction in predictions:

        original_id = prediction["image_id"]
        img_info = dataset.loadImgs([original_id])[0]
        image_width = img_info["width"]
        image_height = img_info["height"]
        image_size = [image_height, image_width]
        image_file_name = img_info["file_name"]
        # examples with imgfiles = {img/table/1749.jpg, img/table/0045.png}
        # have a mismatch between images and masks, so we ignore them
        if image_file_name in ["img/table/1749.jpg", "img/table/0045.png"]:
            continue

        if "instances" not in prediction:
            continue

        num_img_preds = len(prediction["instances"])
        if num_img_preds == 0:
            continue

        # predictions
        scores = prediction["instances"].scores
        boxes = prediction["instances"].pred_boxes.to(device)
        labels = prediction["instances"].pred_classes
        masks_rles = prediction["instances"].pred_masks_rle
        if hasattr(prediction["instances"], "pred_meshes"):
            meshes = prediction["instances"].pred_meshes  # predicted meshes
            verts = [mesh[0] for mesh in meshes]
            faces = [mesh[1] for mesh in meshes]
            meshes = Meshes(verts=verts, faces=faces).to(device)
        else:
            meshes = ico_sphere(4, device)
            meshes = meshes.extend(num_img_preds).to(device)
        if hasattr(prediction["instances"], "pred_dz"):
            pred_dz = prediction["instances"].pred_dz
            heights = boxes.tensor[:, 3] - boxes.tensor[:, 1]
            # NOTE see appendix for derivation of pred dz
            pred_dz = pred_dz[:, 0] * heights.cpu()
        else:
            raise ValueError("Z range of box not predicted")
        assert prediction["instances"].image_size[0] == image_height
        assert prediction["instances"].image_size[1] == image_width

        # ground truth
        # annotations corresponding to original_id (aka coco image_id)
        gt_ann_ids = dataset.getAnnIds(imgIds=[original_id])
        # note that pix3d has one annotation per image
        assert len(gt_ann_ids) == 1
        gt_anns = dataset.loadAnns(gt_ann_ids)[0]
        assert gt_anns["image_id"] == original_id

        # get original ground truth mask, box, label & mesh
        maskfile = os.path.join(metadata.image_root, gt_anns["segmentation"])
        with PathManager.open(maskfile, "rb") as f:
            gt_mask = torch.tensor(
                np.asarray(Image.open(f), dtype=np.float32) / 255.0)
        assert gt_mask.shape[0] == image_height and gt_mask.shape[1] == image_width

        gt_mask = (gt_mask > 0).to(dtype=torch.uint8)  # binarize mask
        gt_mask_rle = [
            mask_util.encode(np.array(gt_mask[:, :, None], order="F"))[0]
        ]
        gt_box = np.array(gt_anns["bbox"]).reshape(-1, 4)  # xywh from coco
        gt_box = BoxMode.convert(gt_box, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
        gt_label = gt_anns["category_id"]
        faux_gt_targets = Boxes(
            torch.tensor(gt_box, dtype=torch.float32, device=device))

        # load gt mesh and extrinsics/intrinsics
        gt_R = torch.tensor(gt_anns["rot_mat"]).to(device)
        gt_t = torch.tensor(gt_anns["trans_mat"]).to(device)
        gt_K = torch.tensor(gt_anns["K"]).to(device)
        if mesh_models is not None:
            modeltype = gt_anns["model"]
            gt_verts, gt_faces = (
                mesh_models[modeltype][0].clone(),
                mesh_models[modeltype][1].clone(),
            )
            gt_verts = gt_verts.to(device)
            gt_faces = gt_faces.to(device)
        else:
            # load from disk
            raise NotImplementedError
        gt_verts = shape_utils.transform_verts(gt_verts, gt_R, gt_t)
        gt_zrange = torch.stack([gt_verts[:, 2].min(), gt_verts[:, 2].max()])
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces])

        # box iou
        boxiou = pairwise_iou(boxes, faux_gt_targets)

        # filter predictions with iou > filter_iou
        valid_pred_ids = boxiou > filter_iou

        # mask iou
        miou = mask_util.iou(masks_rles, gt_mask_rle, [0])

        # # gt zrange (zrange stores min_z and max_z)
        # # zranges = torch.stack([gt_zrange] * len(meshes), dim=0)

        # predicted zrange (= pred_dz)
        assert hasattr(prediction["instances"], "pred_dz")
        # It's impossible to predict the center location in Z (=tc)
        # from the image. See appendix for more.
        tc = (gt_zrange[1] + gt_zrange[0]) / 2.0
        # Given a center location (tc) and a focal_length,
        # pred_dz = pred_dz * box_h * tc / focal_length
        # See appendix for more.
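        # Concretely (an illustrative restatement, not from the paper): the stack
        # below computes z_near/far = tc -/+ tc * pred_dz / (2 * focal_length),
        # with focal_length = gt_K[0].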
        zranges = torch.stack(
            [
                torch.stack([
                    tc - tc * pred_dz[i] / 2.0 / gt_K[0],
                    tc + tc * pred_dz[i] / 2.0 / gt_K[0]
                ]) for i in range(len(meshes))
            ],
            dim=0,
        )

        gt_Ks = gt_K.view(1, 3).expand(len(meshes), 3)
        meshes = transform_meshes_to_camera_coord_system(
            meshes, boxes.tensor, zranges, gt_Ks, image_size)

        if vis_preds:
            vis_utils.visualize_predictions(
                original_id,
                image_file_name,
                scores,
                labels,
                boxes.tensor,
                masks_rles,
                meshes,
                metadata,
                "/tmp/output",
            )

        shape_metrics = compare_meshes(meshes, gt_mesh, reduce=False)

        # sort predictions in descending order
        scores_sorted, idx_sorted = torch.sort(scores, descending=True)

        for pred_id in range(num_img_preds):
            # remember we only evaluate the preds that overlap the ground truth
            # by more than filter_iou
            if valid_pred_ids[idx_sorted[pred_id], 0] == 0:
                continue
            # map to dataset category id
            pred_label = reverse_id_mapping[labels[idx_sorted[pred_id]].item()]
            pred_miou = miou[idx_sorted[pred_id]].item()
            pred_biou = boxiou[idx_sorted[pred_id]].item()
            pred_score = scores[idx_sorted[pred_id]].view(1).to(device)
            # note that compare_meshes returns F1 in percent (x100)
            pred_f1 = shape_metrics[F1_TARGET][idx_sorted[pred_id]].item() / 100.0

            # mask
            tpfp = torch.tensor([0], dtype=torch.uint8, device=device)
            if ((pred_label == gt_label) and (pred_miou > iou_thresh)
                    and (original_id not in mask_covered)):
                tpfp[0] = 1
                mask_covered.append(original_id)
            mask_apscores[pred_label].append(pred_score)
            mask_aplabels[pred_label].append(tpfp)

            # box
            tpfp = torch.tensor([0], dtype=torch.uint8, device=device)
            if ((pred_label == gt_label) and (pred_biou > iou_thresh)
                    and (original_id not in box_covered)):
                tpfp[0] = 1
                box_covered.append(original_id)
            box_apscores[pred_label].append(pred_score)
            box_aplabels[pred_label].append(tpfp)

            # mesh
            tpfp = torch.tensor([0], dtype=torch.uint8, device=device)
            if ((pred_label == gt_label) and (pred_f1 > iou_thresh)
                    and (original_id not in mesh_covered)):
                tpfp[0] = 1
                mesh_covered.append(original_id)
            mesh_apscores[pred_label].append(pred_score)
            mesh_aplabels[pred_label].append(tpfp)

    # sanity check (left disabled): npos should cover every annotation
    # assert npos.sum() == len(dataset.dataset["annotations"])
    pix3d_metrics = {}
    boxap, maskap, meshap = 0.0, 0.0, 0.0
    valid = 0.0
    for cat_id in cat_ids:
        cat_name = dataset.loadCats([cat_id])[0]["name"]
        if npos[cat_id] == 0:
            continue
        valid += 1

        cat_box_ap = VOCap.compute_ap(torch.cat(box_apscores[cat_id]),
                                      torch.cat(box_aplabels[cat_id]),
                                      npos[cat_id])
        boxap += cat_box_ap
        pix3d_metrics["box_ap@%.1f - %s" % (iou_thresh, cat_name)] = cat_box_ap

        cat_mask_ap = VOCap.compute_ap(torch.cat(mask_apscores[cat_id]),
                                       torch.cat(mask_aplabels[cat_id]),
                                       npos[cat_id])
        maskap += cat_mask_ap
        pix3d_metrics["mask_ap@%.1f - %s" %
                      (iou_thresh, cat_name)] = cat_mask_ap

        cat_mesh_ap = VOCap.compute_ap(torch.cat(mesh_apscores[cat_id]),
                                       torch.cat(mesh_aplabels[cat_id]),
                                       npos[cat_id])
        meshap += cat_mesh_ap
        pix3d_metrics["mesh_ap@%.1f - %s" %
                      (iou_thresh, cat_name)] = cat_mesh_ap

    pix3d_metrics["box_ap@%.1f" % iou_thresh] = boxap / valid
    pix3d_metrics["mask_ap@%.1f" % iou_thresh] = maskap / valid
    pix3d_metrics["mesh_ap@%.1f" % iou_thresh] = meshap / valid

    # print test ground truth
    vis_utils.print_instances_class_histogram(
        [npos[cat_id] for cat_id in cat_ids],  # number of instances
        [dataset.loadCats([cat_id])[0]["name"]
         for cat_id in cat_ids],  # class names
        pix3d_metrics,
    )

    return pix3d_metrics
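

# A minimal sketch of the VOC-style AP recipe applied per category above (an
# illustration of the standard computation; VOCap.compute_ap's exact
# implementation may differ):
#
#     scores = torch.cat(box_apscores[cat_id])
#     labels = torch.cat(box_aplabels[cat_id]).float()
#     order = scores.argsort(descending=True)
#     tp = labels[order].cumsum(0)
#     fp = (1.0 - labels[order]).cumsum(0)
#     precision = tp / (tp + fp)
#     recall = tp / npos[cat_id]
#     # AP is the area under this precision-recall curve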


# Example 5
def evaluate_test(model, data_loader, vis_preds=False):
    """
    This function evaluates the model on the dataset defined by data_loader.
    The metrics reported are described in Table 2 of our paper.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    deprocess = imagenet_deprocess(rescale_image=False)
    device = torch.device("cuda:0")
    # evaluation
    class_names = {
        "02828884": "bench",
        "03001627": "chair",
        "03636649": "lamp",
        "03691459": "speaker",
        "04090263": "firearm",
        "04379243": "table",
        "04530566": "watercraft",
        "02691156": "plane",
        "02933112": "cabinet",
        "02958343": "car",
        "03211117": "monitor",
        "04256520": "couch",
        "04401088": "cellphone",
    }

    num_instances = {i: 0 for i in class_names}
    chamfer = {i: 0 for i in class_names}
    normal = {i: 0 for i in class_names}
    f1_01 = {i: 0 for i in class_names}
    f1_03 = {i: 0 for i in class_names}
    f1_05 = {i: 0 for i in class_names}

    num_batch_evaluated = 0
    for batch in data_loader:
        batch = data_loader.postprocess(batch, device)
        sids = [id_str.split("-")[0] for id_str in batch["id_strs"]]
        for sid in sids:
            num_instances[sid] += 1

        with inference_context(model):
            model_kwargs = {}
            module = model.module if hasattr(model, "module") else model
            if isinstance(module, VoxMeshMultiViewHead):
                model_kwargs["intrinsics"] = batch["intrinsics"]
                model_kwargs["extrinsics"] = batch["extrinsics"]
            if isinstance(module, VoxMeshDepthHead):
                model_kwargs["masks"] = batch["masks"]

            model_outputs = model(batch["imgs"], **model_kwargs)
            voxel_scores = model_outputs["voxel_scores"]
            meshes_pred = model_outputs["meshes_pred"]

            cur_metrics = compare_meshes(meshes_pred[-1], batch["meshes"], reduce=False)
            cur_metrics["verts_per_mesh"] = meshes_pred[-1].num_verts_per_mesh().cpu()
            cur_metrics["faces_per_mesh"] = meshes_pred[-1].num_faces_per_mesh().cpu()

            for i, sid in enumerate(sids):
                chamfer[sid] += cur_metrics["Chamfer-L2"][i].item()
                normal[sid] += cur_metrics["AbsNormalConsistency"][i].item()
                f1_01[sid] += cur_metrics["F1@%f" % 0.1][i].item()
                f1_03[sid] += cur_metrics["F1@%f" % 0.3][i].item()
                f1_05[sid] += cur_metrics["F1@%f" % 0.5][i].item()

                if vis_preds:
                    img = image_to_numpy(deprocess(batch["imgs"][i]))
                    vis_utils.visualize_prediction(
                        batch["id_strs"][i], img, meshes_pred[-1][i], "/tmp/output"
                    )

            num_batch_evaluated += 1
            logger.info("Evaluated %d / %d batches" % (num_batch_evaluated, len(data_loader)))

    vis_utils.print_instances_class_histogram(
        num_instances,
        class_names,
        {"chamfer": chamfer, "normal": normal, "f1_01": f1_01, "f1_03": f1_03, "f1_05": f1_05},
    )


# Example 6
def evaluate_vox(model, loader, prediction_dir=None, max_predictions=-1):
    """
    This function is used to report validation performance of the voxel head output.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    if prediction_dir is not None:
        for prefix in ["merged", "vox_0", "vox_1", "vox_2"]:
            output_dir = os.path.join(prediction_dir, prefix, "predict", "0")
            os.makedirs(output_dir, exist_ok=True)

    device = torch.device("cuda:0")
    metrics = defaultdict(list)
    deprocess = imagenet_deprocess(rescale_image=False)
    for batch_idx, batch in tqdm.tqdm(enumerate(loader)):
        # NOTE: here max_predictions effectively caps the number of batches
        if max_predictions >= 1 and batch_idx > max_predictions:
            break
        batch = loader.postprocess(batch, device)
        model_kwargs = {}
        module = model.module if hasattr(model, "module") else model
        if isinstance(module, VoxMeshMultiViewHead):
            model_kwargs["intrinsics"] = batch["intrinsics"]
            model_kwargs["extrinsics"] = batch["extrinsics"]
        if isinstance(module, VoxDepthHead):
            model_kwargs["masks"] = batch["masks"]
            if module.cfg.MODEL.USE_GT_DEPTH:
                model_kwargs["depths"] = batch["depths"]
        model_outputs = model(batch["imgs"], **model_kwargs)
        voxel_scores = model_outputs["voxel_scores"]
        transformed_voxel_scores = model_outputs["transformed_voxel_scores"]
        merged_voxel_scores = model_outputs.get(
            "merged_voxel_scores", None
        )

        # NOTE that for the F1 thresholds we take the square roots of 1e-4 & 2e-4,
        # as `compare_meshes` returns the Euclidean distance (L2) between two
        # point clouds. In Pixel2Mesh, the squared L2 (L2^2) is computed instead,
        # i.e. (L2^2 < τ) <=> (L2 < sqrt(τ))
        if "meshes_pred" in model_outputs:
            meshes_pred = model_outputs["meshes_pred"]
            cur_metrics = compare_meshes(
                meshes_pred[-1], batch["meshes"],
                scale=0.57, thresholds=[0.01, 0.014142]
            )
            for k, v in cur_metrics.items():
                metrics["final_" + k].append(v)

        voxel_losses = MeshLoss.voxel_loss(
            voxel_scores, merged_voxel_scores, batch["voxels"]
        )
        # negate the loss so that, like the other metrics, higher is better
        for k, v in voxel_losses.items():
            metrics[k].append(-v.detach().item())

        # save meshes
        if prediction_dir is not None:
            # cubify all the voxel scores
            merged_vox_mesh = cubify(
                merged_voxel_scores, module.voxel_size, module.cubify_threshold
            )
            # transformed_vox_mesh = [cubify(
            #     i, module.voxel_size, module.cubify_threshold
            # ) for i in transformed_voxel_scores]
            vox_meshes = {
                "merged": merged_vox_mesh,
                # **{
                #     "vox_%d" % i: mesh
                #     for i, mesh in enumerate(transformed_vox_mesh)
                # }
            }

            gt_mesh = batch["meshes"].scale_verts(0.57)
            gt_points = sample_points_from_meshes(
                gt_mesh, 9000, return_normals=False
            )
            gt_points = gt_points.cpu().detach().numpy()

            for mesh_idx in range(len(batch["id_strs"])):
                label, label_appendix = batch["id_strs"][mesh_idx].split("-")[:2]
                for prefix, vox_mesh in vox_meshes.items():
                    output_dir = os.path.join(
                        prediction_dir, prefix, "predict", "0"
                    )
                    pred_filename = os.path.join(
                        output_dir,
                        "{}_{}_predict.xyz".format(label, label_appendix)
                    )
                    gt_filename = os.path.join(
                        output_dir,
                        "{}_{}_ground.xyz".format(label, label_appendix)
                    )

                    pred_mesh = vox_mesh[mesh_idx].scale_verts(0.57)
                    pred_points = sample_points_from_meshes(
                        pred_mesh, 6466, return_normals=False
                    )
                    pred_points = pred_points.squeeze(0).cpu().detach().numpy()

                    np.savetxt(pred_filename, pred_points)
                    np.savetxt(gt_filename, gt_points[mesh_idx])

            # compute the accuracy of each cubified voxel mesh
            for prefix, vox_mesh in vox_meshes.items():
                vox_mesh_metrics = compare_meshes(
                    vox_mesh, batch["meshes"],
                    scale=0.57, thresholds=[0.01, 0.014142]
                )

                if vox_mesh_metrics is None:
                    continue
                for k, v in vox_mesh_metrics.items():
                    metrics[prefix + "_" + k].append(v)

    # Average numeric metrics, and concatenate images
    metrics = {k: np.mean(v) for k, v in metrics.items()}
    return metrics
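

# A minimal usage sketch for evaluate_vox (hypothetical `build_model` /
# `build_data_loader` helpers are assumptions; prediction_dir, when given,
# dumps .xyz point clouds of the cubified voxel meshes next to ground-truth
# samples):
#
#     metrics = evaluate_vox(model, val_loader,
#                            prediction_dir="/tmp/vox_preds", max_predictions=50)
#     logger.info(metrics)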


# Example 7
def evaluate_split(
    model, loader, max_predictions=-1, num_predictions_keep=10, prefix="", store_predictions=False
):
    """
    This function is used to report validation performance during training.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    device = torch.device("cuda:0")
    num_predictions = 0
    num_predictions_kept = 0
    predictions = defaultdict(list)
    metrics = defaultdict(list)
    deprocess = imagenet_deprocess(rescale_image=False)
    for batch in loader:
        batch = loader.postprocess(batch, device)
        model_kwargs = {}
        module = model.module if hasattr(model, "module") else model
        if isinstance(module, VoxMeshMultiViewHead):
            model_kwargs["intrinsics"] = batch["intrinsics"]
            model_kwargs["extrinsics"] = batch["extrinsics"]
        if isinstance(module, VoxMeshDepthHead):
            model_kwargs["masks"] = batch["masks"]
            if module.cfg.MODEL.USE_GT_DEPTH:
                model_kwargs["depths"] = batch["depths"]
        model_outputs = model(batch["imgs"], **model_kwargs)
        meshes_pred = model_outputs["meshes_pred"]
        voxel_scores = model_outputs["voxel_scores"]
        merged_voxel_scores = model_outputs.get(
            "merged_voxel_scores", None
        )

        # Only compute metrics for the final predicted meshes, not intermediates
        cur_metrics = compare_meshes(meshes_pred[-1], batch["meshes"])
        if cur_metrics is None:
            continue
        for k, v in cur_metrics.items():
            metrics[k].append(v)

        voxel_losses = MeshLoss.voxel_loss(
            voxel_scores, merged_voxel_scores, batch["voxels"]
        )
        # negate the loss so that, like the other metrics, higher is better
        for k, v in voxel_losses.items():
            metrics[k].append(-v.item())

        # Store input images and predicted meshes
        if store_predictions:
            N = batch["imgs"].shape[0]
            for i in range(N):
                if num_predictions_kept >= num_predictions_keep:
                    break
                num_predictions_kept += 1

                img = image_to_numpy(deprocess(batch["imgs"][i]))
                predictions["%simg_input" % prefix].append(img)
                for level, cur_meshes_pred in enumerate(meshes_pred):
                    verts, faces = cur_meshes_pred.get_mesh(i)
                    verts_key = "%sverts_pred_%d" % (prefix, level)
                    faces_key = "%sfaces_pred_%d" % (prefix, level)
                    predictions[verts_key].append(verts.cpu().numpy())
                    predictions[faces_key].append(faces.cpu().numpy())

        num_predictions += len(batch["meshes"])
        logger.info("Evaluated %d predictions so far" % num_predictions)
        if 0 < max_predictions <= num_predictions:
            break

    # Average numeric metrics, and concatenate images
    metrics = {"%s%s" % (prefix, k): np.mean(v) for k, v in metrics.items()}
    if store_predictions:
        img_key = "%simg_input" % prefix
        predictions[img_key] = np.stack(predictions[img_key], axis=0)

    return metrics, predictions
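

# A minimal usage sketch (hypothetical `ddp_model` / `val_loader` are
# assumptions): evaluate_split unwraps DistributedDataParallel itself, so a
# wrapped or a bare model may be passed:
#
#     metrics, _ = evaluate_split(ddp_model, val_loader, max_predictions=100,
#                                 prefix="val_")
#     for name, value in metrics.items():
#         logger.info("%s: %.4f", name, value)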