コード例 #1
0
    def get_tracker(self, wandb_log: bool, tensorboard_log: bool):
        """
        Build the tracker matching this dataset's registration mode.

        Arguments:
            wandb_log - Log using Weights & Biases
            tensorboard_log - Log using tensorboard
        Returns:
            [BaseTracker] -- a PatchRegistrationTracker when ``self.is_patch``
            is truthy, otherwise a FragmentRegistrationTracker
        Raises:
            NotImplementedError -- when ``self.is_patch`` is falsy and
            ``self.is_end2end`` is truthy
        """
        # Patch datasets hand themselves to the tracker.
        if self.is_patch:
            return PatchRegistrationTracker(
                self, wandb_log=wandb_log, use_tensorboard=tensorboard_log
            )

        # End-to-end registration has no tracker yet.
        if self.is_end2end:
            raise NotImplementedError("implement end2end tracker")

        # Fragment tracker is configured from the dataset's own thresholds.
        return FragmentRegistrationTracker(
            num_points=self.num_points,
            tau_1=self.tau_1,
            tau_2=self.tau_2,
            rot_thresh=self.rot_thresh,
            trans_thresh=self.trans_thresh,
            wandb_log=wandb_log,
            use_tensorboard=tensorboard_log,
        )
コード例 #2
0
 def test_track_all(self):
     """Track four iterations without resetting in between and check the accumulated metrics."""
     fragment_tracker = FragmentRegistrationTracker(MockDataset(), stage="test", tau_2=0.83, num_points=100)
     mock_model = MockModel()
     fragment_tracker.reset("test")
     mock_model.iter = 0
     for _ in range(4):
         fragment_tracker.track(mock_model)
         mock_model.iter += 1

     # Expected averages over everything tracked above
     # (presumably 4 iterations x 4 samples = 16 -- confirm against MockModel).
     expected_hit_ratio = (4 * 1.0 + 4 * 0.9 + 4 * 0.8 + 0.9 + 0.84 + 0.8 + 0.7) / 16
     expected_feat_match_ratio = (4 * 1 + 4 * 1 + 4 * 0 + 2 * 1 + 2 * 0) / 16

     results = fragment_tracker.get_metrics()
     self.assertAlmostEqual(results["test_hit_ratio"], expected_hit_ratio)
     self.assertAlmostEqual(results["test_feat_match_ratio"], expected_feat_match_ratio)
コード例 #3
0
    def get_tracker(self, wandb_log: bool, tensorboard_log: bool):
        """
        Build the tracker matching this dataset's registration mode.

        Arguments:
            wandb_log - Log using Weights & Biases
            tensorboard_log - Log using tensorboard
        Returns:
            [BaseTracker] -- a PatchRegistrationTracker when ``self.is_patch``
            is truthy, otherwise a FragmentRegistrationTracker
        """
        # Both trackers take the same arguments, so only the class differs.
        tracker_cls = PatchRegistrationTracker if self.is_patch else FragmentRegistrationTracker
        return tracker_cls(self, wandb_log=wandb_log, use_tensorboard=tensorboard_log)
コード例 #4
0
 def test_track_batch(self):
     """Check the metrics after each tracked iteration, resetting the tracker in between."""
     fragment_tracker = FragmentRegistrationTracker(MockDataset(), stage="test", tau_2=0.83, num_points=100)
     mock_model = MockModel()
     # One (hit ratio, feature match ratio) pair per tracked iteration.
     expected_per_batch = [
         (1.0, 1.0),
         (0.9, 1.0),
         (0.8, 0.0),
         ((0.9 + 0.84 + 0.8 + 0.7) / 4, 0.5),
     ]
     for expected_hit_ratio, expected_feat_match_ratio in expected_per_batch:
         fragment_tracker.track(mock_model)
         batch_metrics = fragment_tracker.get_metrics()
         # the most important metrics in registration
         self.assertAlmostEqual(batch_metrics["test_hit_ratio"], expected_hit_ratio)
         self.assertAlmostEqual(batch_metrics["test_feat_match_ratio"], expected_feat_match_ratio)
         # Start the next iteration from a clean tracker state.
         fragment_tracker.reset("test")
         mock_model.iter += 1
コード例 #5
0
    def get_tracker(self, wandb_log: bool, tensorboard_log: bool):
        """
        Build the fragment registration tracker for this dataset.

        The dataset itself is passed to the tracker together with its
        ``tau_1``, ``rot_thresh`` and ``trans_thresh`` attributes.

        Arguments:
            wandb_log - Log using Weights & Biases
            tensorboard_log - Log using tensorboard
        Returns:
            [BaseTracker] -- a FragmentRegistrationTracker
        """
        tracker_cls = FragmentRegistrationTracker
        # Thresholds forwarded unchanged from the dataset to the tracker.
        registration_thresholds = {
            "tau_1": self.tau_1,
            "rot_thresh": self.rot_thresh,
            "trans_thresh": self.trans_thresh,
        }
        return tracker_cls(
            self,
            wandb_log=wandb_log,
            use_tensorboard=tensorboard_log,
            **registration_thresholds,
        )
コード例 #6
0
    def get_tracker(model, dataset, wandb_log: bool, tensorboard_log: bool):
        """
        Factory method for the tracker

        Arguments:
            model -- not used when selecting or building the tracker
            dataset -- object exposing ``is_patch``; passed to the tracker
            wandb_log - Log using Weights & Biases
            tensorboard_log - Log using tensorboard
        Returns:
            [BaseTracker] -- a PatchRegistrationTracker when
            ``dataset.is_patch`` is truthy, otherwise a
            FragmentRegistrationTracker
        """
        # Both trackers receive the same logging switches.
        logging_options = {"wandb_log": wandb_log, "use_tensorboard": tensorboard_log}
        if dataset.is_patch:
            return PatchRegistrationTracker(dataset, **logging_options)
        return FragmentRegistrationTracker(dataset, **logging_options)