예제 #1
0
    def _eval_predictions_precision(self):
        """Evaluate ``self._predictions`` on 6D pose and report precision.

        NOTE: eval precision instead of recall -- a ground-truth instance
        whose file has no prediction is skipped instead of being counted as
        a miss (following the DPOD paper).

        Side effects:
            * If the predictions cache pickle exists in ``self._output_dir``
              and ``self.use_cache`` is set, ``self._predictions`` is
              replaced by its content; otherwise ``self._predictions`` is
              dumped to that path.
            * Per-object errors and per-metric hit lists are dumped as
              pickles, and the summary table is logged and written to a
              ``.txt`` file in ``self._output_dir``.

        Returns:
            dict: always empty; results are only logged and written to disk.

        Raises:
            RuntimeError: if the cache is not used and ``self._predictions``
                does not exist (inference has not been run).
        """
        self._logger.info("Eval results ...")
        cfg = self.cfg
        method_name = f"{cfg.EXP_ID.replace('_', '-')}"
        cache_path = osp.join(self._output_dir,
                              f"{method_name}_{self.dataset_name}_preds.pkl")
        if osp.exists(cache_path) and self.use_cache:
            self._logger.info("load cached predictions")
            self._predictions = mmcv.load(cache_path)
        elif hasattr(self, "_predictions"):
            mmcv.dump(self._predictions, cache_path)
        else:
            raise RuntimeError("Please run inference first")

        precisions = OrderedDict()
        errors = OrderedDict()
        self.get_gts()

        error_names = ["ad", "re", "te", "proj"]
        # (k, frac) pairs shared by every metric family "<family>_<k>":
        # k is the rotation threshold (deg) and the projection threshold
        # (px); frac is the translation threshold (labelled "cm" in the
        # original, i.e. t presumably in meters) and the fraction of the
        # object diameter used for the AD metric.
        thresholds = ((2, 0.02), (5, 0.05), (10, 0.1))
        metric_families = ("ad", "rete", "re", "te", "proj")
        # Same names and order as the original hard-coded list:
        # ad_2, ad_5, ad_10, rete_2, ..., proj_10.
        metric_names = [
            f"{family}_{k}" for family in metric_families for k, _ in thresholds
        ]
        sym_objs = cfg.DATASETS.SYM_OBJS

        for obj_name in self.gts:
            if obj_name not in self._predictions:
                continue
            cur_label = self.obj_names.index(obj_name)
            obj_prec = precisions.setdefault(
                obj_name, OrderedDict((name, []) for name in metric_names))
            obj_errs = errors.setdefault(
                obj_name, OrderedDict((name, []) for name in error_names))

            # Per-object constants, looked up once instead of per sample.
            pts = self.models_3d[cur_label]["pts"]
            diameter = self.diameters[cur_label]
            obj_preds = self._predictions[obj_name]

            for file_name, gt_anno in self.gts[obj_name].items():
                # compute precision as in DPOD paper
                if file_name not in obj_preds:  # no pred found
                    # NOTE: just ignore undetected
                    continue
                R_pred = obj_preds[file_name]["R"]
                t_pred = obj_preds[file_name]["t"]
                R_gt = gt_anno["R"]
                t_gt = gt_anno["t"]

                t_error = te(t_pred, t_gt)
                if obj_name in sym_objs:
                    # Rotation and projection errors are measured against the
                    # symmetric variant of R_gt returned by get_closest_rot;
                    # the AD error uses adi with the raw R_gt.
                    R_gt_ref = get_closest_rot(
                        R_pred, R_gt, self._metadata.sym_infos[cur_label])
                    ad_error = adi(R_pred, t_pred, R_gt, t_gt, pts=pts)
                else:
                    R_gt_ref = R_gt
                    ad_error = add(R_pred, t_pred, R_gt, t_gt, pts=pts)
                r_error = re(R_pred, R_gt_ref)
                proj_2d_error = arp_2d(R_pred, t_pred, R_gt_ref, t_gt,
                                       pts=pts, K=gt_anno["K"])

                obj_errs["ad"].append(ad_error)
                obj_errs["re"].append(r_error)
                obj_errs["te"].append(t_error)
                obj_errs["proj"].append(proj_2d_error)

                for k, frac in thresholds:
                    obj_prec[f"ad_{k}"].append(
                        float(ad_error < frac * diameter))
                    # deg, cm
                    obj_prec[f"rete_{k}"].append(
                        float(r_error < k and t_error < frac))
                    obj_prec[f"re_{k}"].append(float(r_error < k))
                    obj_prec[f"te_{k}"].append(float(t_error < frac))
                    # px
                    obj_prec[f"proj_{k}"].append(float(proj_2d_error < k))

        # summarize
        obj_names = sorted(precisions)
        header = ["objects"] + obj_names + [f"Avg({len(obj_names)})"]
        big_tab = [header]
        for metric_name in metric_names:
            line = [metric_name]
            this_line_res = []
            for obj_name in obj_names:
                res = precisions[obj_name][metric_name]
                # An object without matched predictions counts as 0.
                mean_res = np.mean(res) if len(res) > 0 else 0.0
                line.append(f"{100 * mean_res:.2f}")
                this_line_res.append(mean_res)
            if obj_names:
                line.append(f"{100 * np.mean(this_line_res):.2f}")
            big_tab.append(line)

        for error_name in ["re", "te"]:
            line = [error_name]
            this_line_res = []
            for obj_name in obj_names:
                res = errors[obj_name][error_name]
                if len(res) > 0:
                    mean_res = np.mean(res)
                    line.append(f"{mean_res:.2f}")
                    this_line_res.append(mean_res)
                else:
                    line.append(float("nan"))
            if obj_names:
                # Average only over objects that actually have errors; a
                # single object without matches must not turn the whole
                # dataset average into nan.
                avg = np.mean(this_line_res) if this_line_res else float("nan")
                line.append(f"{avg:.2f}")
            big_tab.append(line)

        ### log big table
        self._logger.info("precisions")
        res_log_tab_str = tabulate(big_tab, tablefmt="plain")
        self._logger.info("\n{}".format(res_log_tab_str))
        errors_cache_path = osp.join(
            self._output_dir, f"{method_name}_{self.dataset_name}_errors.pkl")
        precisions_cache_path = osp.join(
            self._output_dir,
            f"{method_name}_{self.dataset_name}_precisions.pkl")
        self._logger.info(f"{errors_cache_path}")
        self._logger.info(f"{precisions_cache_path}")
        mmcv.dump(errors, errors_cache_path)
        mmcv.dump(precisions, precisions_cache_path)

        dump_tab_name = osp.join(
            self._output_dir,
            f"{method_name}_{self.dataset_name}_tab_precisions.txt")
        with open(dump_tab_name, "w") as f:
            f.write("{}\n".format(res_log_tab_str))
        if self._distributed:
            self._logger.warning(
                "\n The current evaluation on multi-gpu is not correct, run with single-gpu instead."
            )
        return {}
예제 #2
0
                                e = [pose_error.mssd(R_e, t_e, R_g, t_g, models[obj_id]["pts"], models_sym[obj_id])]

                        elif p["error_type"] == "mspd":
                            e = [pose_error.mspd(R_e, t_e, R_g, t_g, K, models[obj_id]["pts"], models_sym[obj_id])]

                        elif p["error_type"] in ["ad", "add", "adi"]:
                            if not spheres_overlap:
                                # Infinite error if the bounding spheres do not overlap. With
                                # typically used values of the correctness threshold for the AD
                                # error (e.g. k*diameter, where k = 0.1), such pose estimates
                                # would be considered incorrect anyway.
                                e = [float("inf")]
                            else:
                                if p["error_type"] == "ad":
                                    if obj_id in dp_model["symmetric_obj_ids"]:
                                        e = [pose_error.adi(R_e, t_e, R_g, t_g, models[obj_id]["pts"])]
                                    else:
                                        e = [pose_error.add(R_e, t_e, R_g, t_g, models[obj_id]["pts"])]

                                elif p["error_type"] == "add":
                                    e = [pose_error.add(R_e, t_e, R_g, t_g, models[obj_id]["pts"])]

                                else:  # 'adi'
                                    e = [pose_error.adi(R_e, t_e, R_g, t_g, models[obj_id]["pts"])]

                        ################################
                        elif p["error_type"] in ["ABSad", "ABSadd", "ABSadi", "AUCad", "AUCadd", "AUCadi"]:
                            if p["error_type"] in ["ABSad", "AUCad"]:
                                if obj_id in dp_model["symmetric_obj_ids"]:
                                    e = [pose_error.adi(R_e, t_e, R_g, t_g, models[obj_id]["pts"]) / 10]  # mm to cm
                                else: