def evaluate_test(model, data_loader, vis_preds=False):
    """
    This function evaluates the model on the dataset defined by data_loader.
    The metrics reported are described in Table 2 of our paper.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    deprocess = imagenet_deprocess(rescale_image=False)
    device = torch.device("cuda:0")
    # evaluation
    class_names = {
        "02828884": "bench",
        "03001627": "chair",
        "03636649": "lamp",
        "03691459": "speaker",
        "04090263": "firearm",
        "04379243": "table",
        "04530566": "watercraft",
        "02691156": "plane",
        "02933112": "cabinet",
        "02958343": "car",
        "03211117": "monitor",
        "04256520": "couch",
        "04401088": "cellphone",
    }

    num_instances = {i: 0 for i in class_names}
    chamfer = {i: 0 for i in class_names}
    normal = {i: 0 for i in class_names}
    f1_01 = {i: 0 for i in class_names}
    f1_03 = {i: 0 for i in class_names}
    f1_05 = {i: 0 for i in class_names}

    num_batch_evaluated = 0
    for batch in data_loader:
        batch = data_loader.postprocess(batch, device)
        imgs, meshes_gt, _, _, _, id_strs, _imgs = batch
        # NOTE: _imgs contains all of the other images belonging to this model.
        # We have to select the next-best-view from that list of images.
        sids = [id_str.split("-")[0] for id_str in id_strs]
        for sid in sids:
            num_instances[sid] += 1

        with inference_context(model):
            voxel_scores, meshes_pred = model(imgs)

        # TODO: Render masks from predicted mesh for each view
        cur_metrics = compare_meshes(meshes_pred[-1], meshes_gt, reduce=False)
        cur_metrics["verts_per_mesh"] = meshes_pred[-1].num_verts_per_mesh().cpu()
        cur_metrics["faces_per_mesh"] = meshes_pred[-1].num_faces_per_mesh().cpu()

        for i, sid in enumerate(sids):
            chamfer[sid] += cur_metrics["Chamfer-L2"][i].item()
            normal[sid] += cur_metrics["AbsNormalConsistency"][i].item()
            f1_01[sid] += cur_metrics["F1@%f" % 0.1][i].item()
            f1_03[sid] += cur_metrics["F1@%f" % 0.3][i].item()
            f1_05[sid] += cur_metrics["F1@%f" % 0.5][i].item()

            if vis_preds:
                img = image_to_numpy(deprocess(imgs[i]))
                vis_utils.visualize_prediction(
                    id_strs[i], img, meshes_pred[-1][i], "/tmp/output"
                )

        num_batch_evaluated += 1
        logger.info(
            "Evaluated %d / %d batches" % (num_batch_evaluated, len(data_loader))
        )

    vis_utils.print_instances_class_histogram(
        num_instances,
        class_names,
        {
            "chamfer": chamfer,
            "normal": normal,
            "f1_01": f1_01,
            "f1_03": f1_03,
            "f1_05": f1_05,
        },
    )

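# Usage sketch (illustrative only): how this entry point might be driven. The
# builder names below are assumptions about the surrounding repo, not confirmed
# APIs; `evaluate_test` itself only needs a model and a data_loader that
# exposes `postprocess` and yields batches in the tuple layout unpacked above.
#
#   model = build_model(cfg)                 # hypothetical builder
#   test_loader = build_data_loader(...)     # hypothetical; must expose .postprocess
#   evaluate_test(model.cuda(), test_loader, vis_preds=True)
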
def evaluate_split(
    model, loader, max_predictions=-1, num_predictions_keep=10,
    prefix="", store_predictions=False
):
    """
    This function is used to report validation performance during training.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    device = torch.device("cuda:0")
    num_predictions = 0
    num_predictions_kept = 0
    predictions = defaultdict(list)
    metrics = defaultdict(list)
    deprocess = imagenet_deprocess(rescale_image=False)
    for batch in loader:
        batch = loader.postprocess(batch, device)
        imgs, meshes_gt, points_gt, normals_gt, voxels_gt = batch

        voxel_scores, meshes_pred = model(imgs)

        # Only compute metrics for the final predicted meshes, not intermediates
        cur_metrics = compare_meshes(meshes_pred[-1], meshes_gt)
        if cur_metrics is None:
            continue
        for k, v in cur_metrics.items():
            metrics[k].append(v)

        # Store input images and predicted meshes
        if store_predictions:
            N = imgs.shape[0]
            for i in range(N):
                if num_predictions_kept >= num_predictions_keep:
                    break
                num_predictions_kept += 1

                img = image_to_numpy(deprocess(imgs[i]))
                predictions["%simg_input" % prefix].append(img)
                for level, cur_meshes_pred in enumerate(meshes_pred):
                    verts, faces = cur_meshes_pred.get_mesh(i)
                    verts_key = "%sverts_pred_%d" % (prefix, level)
                    faces_key = "%sfaces_pred_%d" % (prefix, level)
                    predictions[verts_key].append(verts.cpu().numpy())
                    predictions[faces_key].append(faces.cpu().numpy())

        num_predictions += len(meshes_gt)
        logger.info("Evaluated %d predictions so far" % num_predictions)
        if 0 < max_predictions <= num_predictions:
            break

    # Average numeric metrics, and concatenate images
    metrics = {"%s%s" % (prefix, k): np.mean(v) for k, v in metrics.items()}
    if store_predictions:
        img_key = "%simg_input" % prefix
        predictions[img_key] = np.stack(predictions[img_key], axis=0)

    return metrics, predictions

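# Sketch of consuming the return values during a training loop (illustrative;
# the validation-loader name is an assumption). Metrics arrive already
# averaged and prefixed, so they can be logged directly:
#
#   metrics, preds = evaluate_split(
#       model, val_loader, max_predictions=1000,
#       prefix="val_", store_predictions=True,
#   )
#   for name, value in metrics.items():
#       logger.info("%s: %.4f" % (name, value))
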
def evaluate_test_p2m(model, data_loader):
    """
    This function evaluates the model on the dataset defined by data_loader.
    The metrics reported are described in Table 1 of our paper, following
    previously reported approaches (like Pixel2Mesh - p2m), where meshes are
    rescaled by a factor of 0.57. See the paper for more details.
    """
    assert comm.is_main_process()
    device = torch.device("cuda:0")
    # evaluation
    class_names = {
        "02828884": "bench",
        "03001627": "chair",
        "03636649": "lamp",
        "03691459": "speaker",
        "04090263": "firearm",
        "04379243": "table",
        "04530566": "watercraft",
        "02691156": "plane",
        "02933112": "cabinet",
        "02958343": "car",
        "03211117": "monitor",
        "04256520": "couch",
        "04401088": "cellphone",
    }

    num_instances = {i: 0 for i in class_names}
    chamfer = {i: 0 for i in class_names}
    normal = {i: 0 for i in class_names}
    f1_1e_4 = {i: 0 for i in class_names}
    f1_2e_4 = {i: 0 for i in class_names}

    num_batch_evaluated = 0
    for batch in data_loader:
        batch = data_loader.postprocess(batch, device)
        imgs, meshes_gt, _, _, _, id_strs = batch
        sids = [id_str.split("-")[0] for id_str in id_strs]
        for sid in sids:
            num_instances[sid] += 1

        with inference_context(model):
            voxel_scores, meshes_pred = model(imgs)

        # NOTE that for the F1 thresholds we take the square root of 1e-4 & 2e-4
        # as `compare_meshes` returns the euclidean distance (L2) of two pointclouds.
        # In Pixel2Mesh, the squared L2 (L2^2) is computed instead.
        # i.e. (L2^2 < τ) <=> (L2 < sqrt(τ))
        cur_metrics = compare_meshes(
            meshes_pred[-1], meshes_gt, scale=0.57,
            thresholds=[0.01, 0.014142], reduce=False
        )
        cur_metrics["verts_per_mesh"] = meshes_pred[-1].num_verts_per_mesh().cpu()
        cur_metrics["faces_per_mesh"] = meshes_pred[-1].num_faces_per_mesh().cpu()

        for i, sid in enumerate(sids):
            chamfer[sid] += cur_metrics["Chamfer-L2"][i].item()
            normal[sid] += cur_metrics["AbsNormalConsistency"][i].item()
            f1_1e_4[sid] += cur_metrics["F1@%f" % 0.01][i].item()
            f1_2e_4[sid] += cur_metrics["F1@%f" % 0.014142][i].item()

        num_batch_evaluated += 1
        logger.info(
            "Evaluated %d / %d batches" % (num_batch_evaluated, len(data_loader))
        )

    vis_utils.print_instances_class_histogram_p2m(
        num_instances,
        class_names,
        {
            "chamfer": chamfer,
            "normal": normal,
            "f1_1e_4": f1_1e_4,
            "f1_2e_4": f1_2e_4,
        },
    )

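def _p2m_equivalent_l2_thresholds(squared_thresholds=(1e-4, 2e-4)):
    """
    Worked example of the threshold conversion noted in evaluate_test_p2m
    (an illustrative helper; nothing in the evaluation calls it). Pixel2Mesh
    thresholds the *squared* L2 distance at tau, while `compare_meshes`
    thresholds the plain L2 distance, so the equivalent thresholds are
    sqrt(tau): sqrt(1e-4) = 0.01 and sqrt(2e-4) ~= 0.014142.
    """
    import math
    return [math.sqrt(t) for t in squared_thresholds]
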
def evaluate_for_pix3d(
    predictions,
    dataset,
    metadata,
    filter_iou,
    mesh_models=None,
    iou_thresh=0.5,
    mask_thresh=0.5,
    device=None,
    vis_preds=False,
):
    from PIL import Image

    if device is None:
        device = torch.device("cpu")

    F1_TARGET = "F1@0.300000"

    # classes
    cat_ids = sorted(dataset.getCatIds())
    reverse_id_mapping = {
        v: k for k, v in metadata.thing_dataset_id_to_contiguous_id.items()
    }

    # initialize tensors to record box & mask AP, number of gt positives
    box_apscores, box_aplabels = {}, {}
    mask_apscores, mask_aplabels = {}, {}
    mesh_apscores, mesh_aplabels = {}, {}
    npos = {}
    for cat_id in cat_ids:
        box_apscores[cat_id] = [torch.tensor([], dtype=torch.float32, device=device)]
        box_aplabels[cat_id] = [torch.tensor([], dtype=torch.uint8, device=device)]
        mask_apscores[cat_id] = [torch.tensor([], dtype=torch.float32, device=device)]
        mask_aplabels[cat_id] = [torch.tensor([], dtype=torch.uint8, device=device)]
        mesh_apscores[cat_id] = [torch.tensor([], dtype=torch.float32, device=device)]
        mesh_aplabels[cat_id] = [torch.tensor([], dtype=torch.uint8, device=device)]
        npos[cat_id] = 0.0
    box_covered = []
    mask_covered = []
    mesh_covered = []

    # number of gt positive instances per class
    for gt_ann in dataset.dataset["annotations"]:
        gt_label = gt_ann["category_id"]
        # examples with imgfiles = {img/table/1749.jpg, img/table/0045.png}
        # have a mismatch between images and masks. Thus, ignore
        image_file_name = dataset.loadImgs([gt_ann["image_id"]])[0]["file_name"]
        if image_file_name in ["img/table/1749.jpg", "img/table/0045.png"]:
            continue
        npos[gt_label] += 1.0

    for prediction in predictions:
        original_id = prediction["image_id"]
        image_width = dataset.loadImgs([original_id])[0]["width"]
        image_height = dataset.loadImgs([original_id])[0]["height"]
        image_size = [image_height, image_width]
        image_file_name = dataset.loadImgs([original_id])[0]["file_name"]
        # examples with imgfiles = {img/table/1749.jpg, img/table/0045.png}
        # have a mismatch between images and masks. Thus, ignore
        if image_file_name in ["img/table/1749.jpg", "img/table/0045.png"]:
            continue

        if "instances" not in prediction:
            continue

        num_img_preds = len(prediction["instances"])
        if num_img_preds == 0:
            continue

        # predictions
        scores = prediction["instances"].scores
        boxes = prediction["instances"].pred_boxes.to(device)
        labels = prediction["instances"].pred_classes
        masks_rles = prediction["instances"].pred_masks_rle
        if hasattr(prediction["instances"], "pred_meshes"):
            meshes = prediction["instances"].pred_meshes  # predicted meshes
            verts = [mesh[0] for mesh in meshes]
            faces = [mesh[1] for mesh in meshes]
            meshes = Meshes(verts=verts, faces=faces).to(device)
        else:
            meshes = ico_sphere(4, device)
            meshes = meshes.extend(num_img_preds).to(device)
        if hasattr(prediction["instances"], "pred_dz"):
            pred_dz = prediction["instances"].pred_dz
            heights = boxes.tensor[:, 3] - boxes.tensor[:, 1]
            # NOTE see appendix for derivation of pred dz
            pred_dz = pred_dz[:, 0] * heights.cpu()
        else:
            raise ValueError("Z range of box not predicted")

        assert prediction["instances"].image_size[0] == image_height
        assert prediction["instances"].image_size[1] == image_width

        # ground truth
        # annotations corresponding to original_id (aka coco image_id)
        gt_ann_ids = dataset.getAnnIds(imgIds=[original_id])
        assert len(gt_ann_ids) == 1  # note that pix3d has one annotation per image
        gt_anns = dataset.loadAnns(gt_ann_ids)[0]
        assert gt_anns["image_id"] == original_id

        # get original ground truth mask, box, label & mesh
        maskfile = os.path.join(metadata.image_root, gt_anns["segmentation"])
        with PathManager.open(maskfile, "rb") as f:
            gt_mask = torch.tensor(
                np.asarray(Image.open(f), dtype=np.float32) / 255.0
            )
        assert gt_mask.shape[0] == image_height and gt_mask.shape[1] == image_width

        gt_mask = (gt_mask > 0).to(dtype=torch.uint8)  # binarize mask
        gt_mask_rle = [mask_util.encode(np.array(gt_mask[:, :, None], order="F"))[0]]
        gt_box = np.array(gt_anns["bbox"]).reshape(-1, 4)  # xywh from coco
        gt_box = BoxMode.convert(gt_box, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
        gt_label = gt_anns["category_id"]
        faux_gt_targets = Boxes(
            torch.tensor(gt_box, dtype=torch.float32, device=device)
        )

        # load gt mesh and extrinsics/intrinsics
        gt_R = torch.tensor(gt_anns["rot_mat"]).to(device)
        gt_t = torch.tensor(gt_anns["trans_mat"]).to(device)
        gt_K = torch.tensor(gt_anns["K"]).to(device)
        if mesh_models is not None:
            modeltype = gt_anns["model"]
            gt_verts, gt_faces = (
                mesh_models[modeltype][0].clone(),
                mesh_models[modeltype][1].clone(),
            )
            gt_verts = gt_verts.to(device)
            gt_faces = gt_faces.to(device)
        else:
            # load from disk
            raise NotImplementedError
        gt_verts = shape_utils.transform_verts(gt_verts, gt_R, gt_t)
        gt_zrange = torch.stack([gt_verts[:, 2].min(), gt_verts[:, 2].max()])
        gt_mesh = Meshes(verts=[gt_verts], faces=[gt_faces])

        # box iou
        boxiou = pairwise_iou(boxes, faux_gt_targets)

        # filter predictions with iou > filter_iou
        valid_pred_ids = boxiou > filter_iou

        # mask iou
        miou = mask_util.iou(masks_rles, gt_mask_rle, [0])

        # # gt zrange (zrange stores min_z and max_z)
        # # zranges = torch.stack([gt_zrange] * len(meshes), dim=0)

        # predicted zrange (= pred_dz)
        assert hasattr(prediction["instances"], "pred_dz")
        # It's impossible to predict the center location in Z (=tc)
        # from the image. See appendix for more.
        tc = (gt_zrange[1] + gt_zrange[0]) / 2.0
        # Given a center location (tc) and a focal_length,
        # pred_dz = pred_dz * box_h * tc / focal_length
        # See appendix for more.
        zranges = torch.stack(
            [
                torch.stack(
                    [
                        tc - tc * pred_dz[i] / 2.0 / gt_K[0],
                        tc + tc * pred_dz[i] / 2.0 / gt_K[0],
                    ]
                )
                for i in range(len(meshes))
            ],
            dim=0,
        )

        gt_Ks = gt_K.view(1, 3).expand(len(meshes), 3)
        meshes = transform_meshes_to_camera_coord_system(
            meshes, boxes.tensor, zranges, gt_Ks, image_size
        )

        if vis_preds:
            vis_utils.visualize_predictions(
                original_id,
                image_file_name,
                scores,
                labels,
                boxes.tensor,
                masks_rles,
                meshes,
                metadata,
                "/tmp/output",
            )

        shape_metrics = compare_meshes(meshes, gt_mesh, reduce=False)

        # sort predictions in descending order
        scores_sorted, idx_sorted = torch.sort(scores, descending=True)

        for pred_id in range(num_img_preds):
            # remember we only evaluate the preds that have overlap more than
            # filter_iou with the ground truth
            if valid_pred_ids[idx_sorted[pred_id], 0] == 0:
                continue
            # map to dataset category id
            pred_label = reverse_id_mapping[labels[idx_sorted[pred_id]].item()]
            pred_miou = miou[idx_sorted[pred_id]].item()
            pred_biou = boxiou[idx_sorted[pred_id]].item()
            pred_score = scores[idx_sorted[pred_id]].view(1).to(device)
            # note that metrics returns f1 in % (=x100)
            pred_f1 = shape_metrics[F1_TARGET][idx_sorted[pred_id]].item() / 100.0

            # mask
            tpfp = torch.tensor([0], dtype=torch.uint8, device=device)
            if (
                (pred_label == gt_label)
                and (pred_miou > iou_thresh)
                and (original_id not in mask_covered)
            ):
                tpfp[0] = 1
                mask_covered.append(original_id)
            mask_apscores[pred_label].append(pred_score)
            mask_aplabels[pred_label].append(tpfp)

            # box
            tpfp = torch.tensor([0], dtype=torch.uint8, device=device)
            if (
                (pred_label == gt_label)
                and (pred_biou > iou_thresh)
                and (original_id not in box_covered)
            ):
                tpfp[0] = 1
                box_covered.append(original_id)
            box_apscores[pred_label].append(pred_score)
            box_aplabels[pred_label].append(tpfp)

            # mesh
            tpfp = torch.tensor([0], dtype=torch.uint8, device=device)
            if (
                (pred_label == gt_label)
                and (pred_f1 > iou_thresh)
                and (original_id not in mesh_covered)
            ):
                tpfp[0] = 1
                mesh_covered.append(original_id)
            mesh_apscores[pred_label].append(pred_score)
            mesh_aplabels[pred_label].append(tpfp)

    # check things for eval
    # assert npos.sum() == len(dataset.dataset["annotations"])

    # convert to tensors
    pix3d_metrics = {}
    boxap, maskap, meshap = 0.0, 0.0, 0.0
    valid = 0.0
    for cat_id in cat_ids:
        cat_name = dataset.loadCats([cat_id])[0]["name"]
        if npos[cat_id] == 0:
            continue
        valid += 1

        cat_box_ap = VOCap.compute_ap(
            torch.cat(box_apscores[cat_id]),
            torch.cat(box_aplabels[cat_id]),
            npos[cat_id],
        )
        boxap += cat_box_ap
        pix3d_metrics["box_ap@%.1f - %s" % (iou_thresh, cat_name)] = cat_box_ap

        cat_mask_ap = VOCap.compute_ap(
            torch.cat(mask_apscores[cat_id]),
            torch.cat(mask_aplabels[cat_id]),
            npos[cat_id],
        )
        maskap += cat_mask_ap
        pix3d_metrics["mask_ap@%.1f - %s" % (iou_thresh, cat_name)] = cat_mask_ap

        cat_mesh_ap = VOCap.compute_ap(
            torch.cat(mesh_apscores[cat_id]),
            torch.cat(mesh_aplabels[cat_id]),
            npos[cat_id],
        )
        meshap += cat_mesh_ap
        pix3d_metrics["mesh_ap@%.1f - %s" % (iou_thresh, cat_name)] = cat_mesh_ap

    pix3d_metrics["box_ap@%.1f" % iou_thresh] = boxap / valid
    pix3d_metrics["mask_ap@%.1f" % iou_thresh] = maskap / valid
    pix3d_metrics["mesh_ap@%.1f" % iou_thresh] = meshap / valid

    # print test ground truth
    vis_utils.print_instances_class_histogram(
        [npos[cat_id] for cat_id in cat_ids],  # number of instances
        [dataset.loadCats([cat_id])[0]["name"] for cat_id in cat_ids],  # class names
        pix3d_metrics,
    )

    return pix3d_metrics

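def _reference_voc_ap(scores, tpfp, npos):
    """
    Reference sketch of the VOC-style average precision that
    `VOCap.compute_ap` is used for above: sort detections by score,
    accumulate true/false positives, and integrate the interpolated
    precision-recall curve. This is the textbook algorithm written out
    for illustration; it is not the repo's implementation and is not
    called by the evaluation.
    """
    import numpy as np

    order = np.argsort(-np.asarray(scores, dtype=np.float64))
    tp = np.asarray(tpfp, dtype=np.float64)[order]
    tp_cum = np.cumsum(tp)
    fp_cum = np.cumsum(1.0 - tp)
    recall = tp_cum / max(float(npos), 1.0)
    precision = tp_cum / np.maximum(tp_cum + fp_cum, 1e-12)

    # make precision monotonically non-increasing, then integrate over recall
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])
    changed = np.where(mrec[1:] != mrec[:-1])[0]
    return float(np.sum((mrec[changed + 1] - mrec[changed]) * mpre[changed + 1]))
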
def evaluate_test(model, data_loader, vis_preds=False):
    """
    This function evaluates the model on the dataset defined by data_loader.
    The metrics reported are described in Table 2 of our paper.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    deprocess = imagenet_deprocess(rescale_image=False)
    device = torch.device("cuda:0")
    # evaluation
    class_names = {
        "02828884": "bench",
        "03001627": "chair",
        "03636649": "lamp",
        "03691459": "speaker",
        "04090263": "firearm",
        "04379243": "table",
        "04530566": "watercraft",
        "02691156": "plane",
        "02933112": "cabinet",
        "02958343": "car",
        "03211117": "monitor",
        "04256520": "couch",
        "04401088": "cellphone",
    }

    num_instances = {i: 0 for i in class_names}
    chamfer = {i: 0 for i in class_names}
    normal = {i: 0 for i in class_names}
    f1_01 = {i: 0 for i in class_names}
    f1_03 = {i: 0 for i in class_names}
    f1_05 = {i: 0 for i in class_names}

    num_batch_evaluated = 0
    for batch in data_loader:
        batch = data_loader.postprocess(batch, device)
        sids = [id_str.split("-")[0] for id_str in batch["id_strs"]]
        for sid in sids:
            num_instances[sid] += 1

        with inference_context(model):
            model_kwargs = {}
            module = model.module if hasattr(model, "module") else model
            if isinstance(module, VoxMeshMultiViewHead):
                model_kwargs["intrinsics"] = batch["intrinsics"]
                model_kwargs["extrinsics"] = batch["extrinsics"]
            if isinstance(module, VoxMeshDepthHead):
                model_kwargs["masks"] = batch["masks"]
            model_outputs = model(batch["imgs"], **model_kwargs)
            voxel_scores = model_outputs["voxel_scores"]
            meshes_pred = model_outputs["meshes_pred"]

        cur_metrics = compare_meshes(meshes_pred[-1], batch["meshes"], reduce=False)
        cur_metrics["verts_per_mesh"] = meshes_pred[-1].num_verts_per_mesh().cpu()
        cur_metrics["faces_per_mesh"] = meshes_pred[-1].num_faces_per_mesh().cpu()

        for i, sid in enumerate(sids):
            chamfer[sid] += cur_metrics["Chamfer-L2"][i].item()
            normal[sid] += cur_metrics["AbsNormalConsistency"][i].item()
            f1_01[sid] += cur_metrics["F1@%f" % 0.1][i].item()
            f1_03[sid] += cur_metrics["F1@%f" % 0.3][i].item()
            f1_05[sid] += cur_metrics["F1@%f" % 0.5][i].item()

            if vis_preds:
                img = image_to_numpy(deprocess(batch["imgs"][i]))
                vis_utils.visualize_prediction(
                    batch["id_strs"][i], img, meshes_pred[-1][i], "/tmp/output"
                )

        num_batch_evaluated += 1
        logger.info(
            "Evaluated %d / %d batches" % (num_batch_evaluated, len(data_loader))
        )

    vis_utils.print_instances_class_histogram(
        num_instances,
        class_names,
        {
            "chamfer": chamfer,
            "normal": normal,
            "f1_01": f1_01,
            "f1_03": f1_03,
            "f1_05": f1_05,
        },
    )

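# For reference, the dict contract this multi-view variant assumes from
# loader.postprocess. Field names are taken from the code above; tensor
# shapes are indicative only, not confirmed by the source:
#
#   batch = {
#       "imgs":       input view(s), e.g. (N, V, C, H, W),
#       "meshes":     ground-truth Meshes,
#       "intrinsics": per-view camera intrinsics,   # VoxMeshMultiViewHead only
#       "extrinsics": per-view camera extrinsics,   # VoxMeshMultiViewHead only
#       "masks":      foreground masks,             # VoxMeshDepthHead only
#       "id_strs":    "<synset>-<model>-..." identifiers,
#   }
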
def evaluate_vox(model, loader, prediction_dir=None, max_predictions=-1):
    """
    This function is used to report validation performance of the voxel head output.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    if prediction_dir is not None:
        for prefix in ["merged", "vox_0", "vox_1", "vox_2"]:
            output_dir = os.path.join(prediction_dir, prefix, "predict", "0")
            os.makedirs(output_dir, exist_ok=True)

    device = torch.device("cuda:0")
    metrics = defaultdict(list)
    deprocess = imagenet_deprocess(rescale_image=False)
    for batch_idx, batch in tqdm.tqdm(enumerate(loader)):
        if max_predictions >= 1 and batch_idx > max_predictions:
            break
        batch = loader.postprocess(batch, device)

        model_kwargs = {}
        module = model.module if hasattr(model, "module") else model
        if isinstance(module, VoxMeshMultiViewHead):
            model_kwargs["intrinsics"] = batch["intrinsics"]
            model_kwargs["extrinsics"] = batch["extrinsics"]
        if isinstance(module, VoxDepthHead):
            model_kwargs["masks"] = batch["masks"]
            if module.cfg.MODEL.USE_GT_DEPTH:
                model_kwargs["depths"] = batch["depths"]
        model_outputs = model(batch["imgs"], **model_kwargs)
        voxel_scores = model_outputs["voxel_scores"]
        transformed_voxel_scores = model_outputs["transformed_voxel_scores"]
        merged_voxel_scores = model_outputs.get("merged_voxel_scores", None)

        # NOTE that for the F1 thresholds we take the square root of 1e-4 & 2e-4
        # as `compare_meshes` returns the euclidean distance (L2) of two pointclouds.
        # In Pixel2Mesh, the squared L2 (L2^2) is computed instead.
        # i.e. (L2^2 < τ) <=> (L2 < sqrt(τ))
        if "meshes_pred" in model_outputs:
            meshes_pred = model_outputs["meshes_pred"]
            cur_metrics = compare_meshes(
                meshes_pred[-1], batch["meshes"],
                scale=0.57, thresholds=[0.01, 0.014142]
            )
            for k, v in cur_metrics.items():
                metrics["final_" + k].append(v)

        voxel_losses = MeshLoss.voxel_loss(
            voxel_scores, merged_voxel_scores, batch["voxels"]
        )
        # negate the losses to report them as metrics (higher is better)
        for k, v in voxel_losses.items():
            metrics[k].append(-v.detach().item())

        # save meshes
        if prediction_dir is not None:
            # cubify all the voxel scores
            merged_vox_mesh = cubify(
                merged_voxel_scores, module.voxel_size, module.cubify_threshold
            )
            # transformed_vox_mesh = [cubify(
            #     i, module.voxel_size, module.cubify_threshold
            # ) for i in transformed_voxel_scores]
            vox_meshes = {
                "merged": merged_vox_mesh,
                # **{
                #     "vox_%d" % i: mesh
                #     for i, mesh in enumerate(transformed_vox_mesh)
                # }
            }
            gt_mesh = batch["meshes"].scale_verts(0.57)
            gt_points = sample_points_from_meshes(
                gt_mesh, 9000, return_normals=False
            )
            gt_points = gt_points.cpu().detach().numpy()
            for mesh_idx in range(len(batch["id_strs"])):
                label, label_appendix = batch["id_strs"][mesh_idx].split("-")[:2]
                for prefix, vox_mesh in vox_meshes.items():
                    output_dir = os.path.join(
                        prediction_dir, prefix, "predict", "0"
                    )
                    pred_filename = os.path.join(
                        output_dir,
                        "{}_{}_predict.xyz".format(label, label_appendix)
                    )
                    gt_filename = os.path.join(
                        output_dir,
                        "{}_{}_ground.xyz".format(label, label_appendix)
                    )
                    pred_mesh = vox_mesh[mesh_idx].scale_verts(0.57)
                    pred_points = sample_points_from_meshes(
                        pred_mesh, 6466, return_normals=False
                    )
                    pred_points = pred_points.squeeze(0).cpu().detach().numpy()
                    np.savetxt(pred_filename, pred_points)
                    np.savetxt(gt_filename, gt_points[mesh_idx])

            # find accuracy of each cubified voxel mesh
            for prefix, vox_mesh in vox_meshes.items():
                vox_mesh_metrics = compare_meshes(
                    vox_mesh, batch["meshes"],
                    scale=0.57, thresholds=[0.01, 0.014142]
                )
                if vox_mesh_metrics is None:
                    continue
                for k, v in vox_mesh_metrics.items():
                    metrics[prefix + "_" + k].append(v)

    # Average numeric metrics
    metrics = {k: np.mean(v) for k, v in metrics.items()}
    return metrics

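# Usage sketch (paths and loader are hypothetical): run the voxel-head
# evaluation and dump point clouds for external tooling. Note that only the
# "merged" cubified mesh is exported by the current code; the per-view
# "vox_*" directories are created but their export is commented out above.
#
#   metrics = evaluate_vox(
#       model, val_loader,
#       prediction_dir="/tmp/vox_preds",   # writes <prefix>/predict/0/*.xyz
#       max_predictions=50,
#   )
#   logger.info("merged F1: %s" % metrics.get("merged_F1@%f" % 0.01))
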
def evaluate_split(
    model, loader, max_predictions=-1, num_predictions_keep=10,
    prefix="", store_predictions=False
):
    """
    This function is used to report validation performance during training.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    device = torch.device("cuda:0")
    num_predictions = 0
    num_predictions_kept = 0
    predictions = defaultdict(list)
    metrics = defaultdict(list)
    deprocess = imagenet_deprocess(rescale_image=False)
    for batch in loader:
        batch = loader.postprocess(batch, device)

        model_kwargs = {}
        module = model.module if hasattr(model, "module") else model
        if isinstance(module, VoxMeshMultiViewHead):
            model_kwargs["intrinsics"] = batch["intrinsics"]
            model_kwargs["extrinsics"] = batch["extrinsics"]
        if isinstance(module, VoxMeshDepthHead):
            model_kwargs["masks"] = batch["masks"]
            if module.cfg.MODEL.USE_GT_DEPTH:
                model_kwargs["depths"] = batch["depths"]
        model_outputs = model(batch["imgs"], **model_kwargs)
        meshes_pred = model_outputs["meshes_pred"]
        voxel_scores = model_outputs["voxel_scores"]
        merged_voxel_scores = model_outputs.get("merged_voxel_scores", None)

        # Only compute metrics for the final predicted meshes, not intermediates
        cur_metrics = compare_meshes(meshes_pred[-1], batch["meshes"])
        if cur_metrics is None:
            continue
        for k, v in cur_metrics.items():
            metrics[k].append(v)

        voxel_losses = MeshLoss.voxel_loss(
            voxel_scores, merged_voxel_scores, batch["voxels"]
        )
        # negate the losses to report them as metrics (higher is better)
        for k, v in voxel_losses.items():
            metrics[k].append(-v.item())

        # Store input images and predicted meshes
        if store_predictions:
            N = batch["imgs"].shape[0]
            for i in range(N):
                if num_predictions_kept >= num_predictions_keep:
                    break
                num_predictions_kept += 1

                img = image_to_numpy(deprocess(batch["imgs"][i]))
                predictions["%simg_input" % prefix].append(img)
                for level, cur_meshes_pred in enumerate(meshes_pred):
                    verts, faces = cur_meshes_pred.get_mesh(i)
                    verts_key = "%sverts_pred_%d" % (prefix, level)
                    faces_key = "%sfaces_pred_%d" % (prefix, level)
                    predictions[verts_key].append(verts.cpu().numpy())
                    predictions[faces_key].append(faces.cpu().numpy())

        num_predictions += len(batch["meshes"])
        logger.info("Evaluated %d predictions so far" % num_predictions)
        if 0 < max_predictions <= num_predictions:
            break

    # Average numeric metrics, and concatenate images
    metrics = {"%s%s" % (prefix, k): np.mean(v) for k, v in metrics.items()}
    if store_predictions:
        img_key = "%simg_input" % prefix
        predictions[img_key] = np.stack(predictions[img_key], axis=0)

    return metrics, predictions

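# Sketch (illustrative) of consuming the stored predictions returned above,
# e.g. dumping the final refinement stage to .obj files. `save_obj` is
# pytorch3d.io.save_obj; everything else follows the key scheme built in
# evaluate_split ("<prefix>verts_pred_<level>" / "<prefix>faces_pred_<level>").
#
#   from pytorch3d.io import save_obj
#   metrics, preds = evaluate_split(
#       model, val_loader, num_predictions_keep=5,
#       prefix="val_", store_predictions=True,
#   )
#   levels = sorted(
#       int(k.rsplit("_", 1)[1])
#       for k in preds if k.startswith("val_verts_pred_")
#   )
#   last = levels[-1]
#   for i, (v, f) in enumerate(zip(
#       preds["val_verts_pred_%d" % last], preds["val_faces_pred_%d" % last]
#   )):
#       save_obj("/tmp/pred_%d.obj" % i, torch.from_numpy(v), torch.from_numpy(f))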