Example #1
def build_train_loss(self):
    # Set up loss functions, each wrapped in DataParallel for multi-GPU training
    self.chamfer_dist = torch.nn.DataParallel(
        ChamferDistance().to(self.gpu_ids[0]), device_ids=self.gpu_ids)
    self.chamfer_dist_mean = torch.nn.DataParallel(
        ChamferDistanceMean().to(self.gpu_ids[0]), device_ids=self.gpu_ids)
    self.emd_dist = torch.nn.DataParallel(
        emd.emdModule().to(self.gpu_ids[0]), device_ids=self.gpu_ids)
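A minimal training-step sketch, not part of the example: it assumes `pred` and `gt` are (B, N, 3) CUDA tensors, `alpha` is a hypothetical weighting scalar, and it reuses the emdModule call signature seen in Example #4 below.

def train_step(self, pred, gt, alpha=1.0):
    # Hypothetical combined loss: mean Chamfer term plus a weighted EMD term.
    cd_loss = self.chamfer_dist_mean(pred, gt).mean()
    # eps/iters follow the validation settings used in Example #4.
    emd_cost, _ = self.emd_dist(pred, gt, eps=0.005, iters=50)
    emd_loss = torch.sqrt(emd_cost).mean()
    return cd_loss + alpha * emd_loss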
Example #2
def build_val_loss(self):
    # Set up validation loss functions
    self.chamfer_dist_mean = ChamferDistanceMean().cuda()
    self.emd_dist = emd.emdModule().cuda()
Example #3
def build_val_loss(self):
    # Set up loss functions
    self.chamfer_dist = ChamferDistance().cuda()
    self.chamfer_dist_mean = ChamferDistanceMean().cuda()
    self.emd_dist = emd.emdModule().cuda()
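A minimal sketch of how these validation losses might be invoked; this is an assumption, not part of the example. It assumes `pred` and `gt` are (B, N, 3) CUDA tensors, and the x1000 / x100 scaling mirrors the Metrics class in Example #4.

def val_step(self, pred, gt):
    # Mean Chamfer distance, reported in units of 1e-3 as in Example #4.
    cd = self.chamfer_dist_mean(pred, gt).item() * 1000
    # EMD with the validation settings from Example #4, in units of 1e-2.
    emd_cost, _ = self.emd_dist(pred, gt, eps=0.005, iters=50)
    emd_val = torch.sqrt(emd_cost).mean().item() * 100
    return cd, emd_val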
Example #4
import logging

import open3d
import torch

# ChamferDistanceMean and emd are project-specific extension modules; the
# commented import paths below are assumptions, not the repository's actual layout.
# from extensions.chamfer_dist import ChamferDistanceMean
# from extensions.emd import emd_module as emd

logger = logging.getLogger(__name__)


class Metrics(object):
    ITEMS = [
        {
            "name": "F-Score",
            "enabled": True,
            "eval_func": "cls._get_f_score",
            "is_greater_better": True,
            "init_value": 0,
        },
        {
            "name": "ChamferDistance",
            "enabled": True,
            "eval_func": "cls._get_chamfer_distance",
            "eval_object": ChamferDistanceMean(),
            "is_greater_better": False,
            "init_value": 32767,
        },
        {
            "name": "EMD",
            "enabled": True,
            "eval_func": "cls._get_emd",
            "eval_object": emd.emdModule(),
            "is_greater_better": False,
            "init_value": 32767,
        },
    ]

    @classmethod
    def get(cls, pred, gt):
        _items = cls.items()
        _values = [0] * len(_items)
        for i, item in enumerate(_items):
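            # Resolve the method named in ITEMS (e.g. "cls._get_f_score"); `cls` is in scope here.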
            eval_func = eval(item["eval_func"])
            _values[i] = eval_func(pred, gt)

        return _values

    @classmethod
    def items(cls):
        return [i for i in cls.ITEMS if i["enabled"]]

    @classmethod
    def names(cls):
        _items = cls.items()
        return [i["name"] for i in _items]

    @classmethod
    def _get_f_score(cls, pred, gt, th=0.01):
        """References: https://github.com/lmb-freiburg/what3d/blob/master/util.py"""
        pred = cls._get_open3d_ptcloud(pred)
        gt = cls._get_open3d_ptcloud(gt)

        dist1 = pred.compute_point_cloud_distance(gt)
        dist2 = gt.compute_point_cloud_distance(pred)

        recall = float(sum(d < th for d in dist2)) / float(len(dist2))
        precision = float(sum(d < th for d in dist1)) / float(len(dist1))
        if recall + precision:
            return 2 * recall * precision / (recall + precision)
        return 0

    @classmethod
    def _get_open3d_ptcloud(cls, tensor):
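        # Expects a single sample: squeeze (1, N, 3) down to the (N, 3) array Open3D needs.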
        tensor = tensor.squeeze().cpu().numpy()
        ptcloud = open3d.geometry.PointCloud()
        ptcloud.points = open3d.utility.Vector3dVector(tensor)

        return ptcloud

    @classmethod
    def _get_chamfer_distance(cls, pred, gt):
        chamfer_distance = cls.ITEMS[1]["eval_object"]
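        # Scale by 1e3; Chamfer distance is conventionally reported in units of 1e-3.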
        return chamfer_distance(pred, gt).item() * 1000

    @classmethod
    def _get_emd(cls, pred, gt):
        EMD = cls.ITEMS[2]["eval_object"]
        dist, _ = EMD(pred, gt, eps=0.005, iters=50)  # for val
        # dist, _ = EMD(pred, gt, 0.002, 10000) # final test ?
        emd = torch.sqrt(dist).mean(1).mean()
        return emd.item() * 100

    def __init__(self, metric_name, values):
        self._items = Metrics.items()
        self._values = [item["init_value"] for item in self._items]
        self.metric_name = metric_name

        if isinstance(values, dict):
            metric_indexes = {}
            for idx, item in enumerate(self._items):
                metric_indexes[item["name"]] = idx
            for k, v in values.items():
                if k not in metric_indexes:
                    logger.warning("Ignoring metric [Name=%s] because it is disabled." % k)
                    continue
                self._values[metric_indexes[k]] = v
        elif isinstance(values, list):
            self._values = values
        else:
            raise Exception("Unsupported value type: %s" % type(values))

    def state_dict(self):
        _dict = {}
        for i in range(len(self._items)):
            item = self._items[i]["name"]
            value = self._values[i]
            _dict[item] = value

        return _dict

    def __repr__(self):
        return str(self.state_dict())

    def better_than(self, other):
        if other is None:
            return True

        _index = -1
        for i, _item in enumerate(self._items):
            if _item["name"] == self.metric_name:
                _index = i
                break
        if _index == -1:
            raise Exception("Invalid metric name to compare.")

        _metric = self._items[_index]
        _value = self._values[_index]
        other_value = other._values[_index]
        if _metric["is_greater_better"]:
            return _value > other_value
        return _value < other_value