Example #1
0
    def track(self, model: model_interface.TrackerInterface, full_res=False, data=None, **kwargs):
        """Add current model predictions (usually the result of a batch) to the tracking.

        Besides the base-class tracking, when ``full_res`` is True and the stage is not
        ``"train"``, the model outputs are accumulated as votes on a clone of
        ``self._dataset.test_data``. Each output row is added to the point given by the
        matching origin id stored under ``SaveOriginalPosId.KEY``, and a per-point
        prediction counter is incremented.

        Args:
            model: model whose ``get_output()`` is read (and ``get_input()`` when
                ``data`` is None).
            full_res: accumulate full-resolution votes when True.
            data: optional batch to read origin ids from instead of ``model.get_input()``.

        Raises:
            ValueError: if the test data has no labels (``y``), the inputs have no origin
                ids, or an origin id is not smaller than the number of test points.
        """
        super().track(model)

        # Train mode or low res, nothing special to do
        if self._stage == "train" or not full_res:
            return

        # Test mode: lazily build the full-resolution vote accumulators on first use.
        if self._test_area is None:
            self._test_area = self._dataset.test_data.clone()
            if self._test_area.y is None:
                raise ValueError("It seems that the test area data does not have labels (attribute y).")
            num_points = self._test_area.y.shape[0]
            self._test_area.prediction_count = torch.zeros(num_points, dtype=torch.int)
            self._test_area.votes = torch.zeros((num_points, self._num_classes), dtype=torch.float)
            # Rebind the result so this works whether or not `.to()` operates in place.
            self._test_area = self._test_area.to(model.device)

        # Gather origin ids and check that they fit with the test set
        inputs = data if data is not None else model.get_input()
        if inputs[SaveOriginalPosId.KEY] is None:
            raise ValueError("The inputs given to the model do not have a %s attribute." % SaveOriginalPosId.KEY)

        originids = inputs[SaveOriginalPosId.KEY]
        if originids.dim() == 2:
            originids = originids.flatten()
        if originids.numel() == 0:
            # Empty batch: no point to vote on (and `.max()` would fail on it).
            return
        if originids.max() >= self._test_area.pos.shape[0]:
            raise ValueError("Origin ids are larger than the number of points in the original point cloud.")

        votes = self._test_area.votes
        counts = self._test_area.prediction_count
        originids = originids.to(device=votes.device, dtype=torch.long)

        # Set predictions.
        # NOTE(review): assumes outputs has one row per origin id and `_num_classes`
        # columns — confirm against the model.
        outputs = model.get_output()
        # `votes[ids] += x` keeps only one write per duplicated id; index_add_ sums every
        # contribution, so points seen several times in one batch get all their votes.
        votes.index_add_(0, originids, outputs.to(device=votes.device, dtype=votes.dtype))
        counts.index_add_(0, originids, torch.ones_like(originids, dtype=counts.dtype))
Example #2
0
    def track(self, model: model_interface.TrackerInterface, **kwargs):
        """Add registration metrics for the current batch of (source, target) pairs.

        Outside the ``"train"`` stage, each pair ``b`` of the batch is evaluated:

        * ``T_gt`` is obtained with ``estimate_transfo`` on the points selected by the
          pair's slice of ``ind`` / ``ind_target``;
        * ``T_pred`` is obtained with ``fast_global_registration`` on feature matches
          (``get_matches``) over a random subset of at most ``self.num_points`` points
          per cloud;
        * the hit ratio, feature-match success (``hit ratio > tau_2``), translation and
          rotation errors, their thresholded successes and the scaled registration
          error are added to the corresponding meters.
        """
        super().track(model)
        if self._stage == "train":
            return

        batch_idx, batch_idx_target = model.get_batch()
        source_input, target_input = model.get_input()
        batch_xyz, batch_xyz_target = source_input.pos, target_input.pos
        batch_ind, batch_ind_target = source_input.ind, target_input.ind
        # Per-pair lengths of the concatenated `ind` / `ind_target` lists.
        batch_size_ind = source_input.size
        batch_feat, batch_feat_target = model.get_output()

        nb_batches = int(batch_idx.max()) + 1
        # `ind` values are global indices into the concatenated clouds; these running
        # totals of points seen so far make them local to each pair.
        cum_sum = 0
        cum_sum_target = 0
        begin = 0
        for b in range(nb_batches):
            end = begin + batch_size_ind[b].item()

            xyz = batch_xyz[batch_idx == b]
            xyz_target = batch_xyz_target[batch_idx_target == b]
            feat = batch_feat[batch_idx == b]
            feat_target = batch_feat_target[batch_idx_target == b]

            ind = batch_ind[begin:end] - cum_sum
            ind_target = batch_ind_target[begin:end] - cum_sum_target

            begin = end
            cum_sum += len(xyz)
            cum_sum_target += len(xyz_target)

            rand = torch.randperm(len(feat))[: self.num_points]
            rand_target = torch.randperm(len(feat_target))[: self.num_points]

            T_gt = estimate_transfo(xyz[ind], xyz_target[ind_target])

            matches_pred = get_matches(feat[rand], feat_target[rand_target])
            # Gather the matched points once; both the registration and the hit ratio use them.
            matched_xyz = xyz[rand][matches_pred[:, 0]]
            matched_xyz_target = xyz_target[rand_target][matches_pred[:, 1]]
            T_pred = fast_global_registration(matched_xyz, matched_xyz_target)

            hit_ratio = compute_hit_ratio(matched_xyz, matched_xyz_target, T_gt, self.tau_1)
            trans_error, rot_error = compute_transfo_error(T_pred, T_gt)
            sr_err = compute_scaled_registration_error(xyz, T_gt, T_pred)

            hit = hit_ratio.item()
            trans = trans_error.item()
            rot = rot_error.item()
            self._hit_ratio.add(hit)
            self._feat_match_ratio.add(float(hit > self.tau_2))
            self._trans_error.add(trans)
            self._rot_error.add(rot)
            self._rre.add(float(rot < self.rot_thresh))
            self._rte.add(float(trans < self.trans_thresh))
            self._sr_err.add(sr_err.item())