def test_track_basic(self):
    """Track one batch from MockModel and check the resulting train metrics.

    A single ``track`` call is made with ``min_cluster_points=0`` and
    ``iou_threshold=0.25``. Afterwards ``train_Iacc`` and ``train_pos`` are
    expected to be 1 and ``train_neg`` to be 0.
    """
    tracker = PanopticTracker(MockDataset())
    mock_model = MockModel()
    # NOTE(review): pos holds a single point while batch has three entries;
    # presumably the tracker takes predictions/labels from MockModel rather
    # than from these tensors -- confirm against the tracker implementation.
    batch_data = Data(
        pos=torch.tensor([[1, 2]]),
        batch=torch.tensor([0, 0, 0]),
    )
    tracker.track(
        mock_model,
        data=batch_data,
        min_cluster_points=0,
        iou_threshold=0.25,
    )
    metrics = tracker.get_metrics()
    expected_metrics = {"train_Iacc": 1, "train_pos": 1, "train_neg": 0}
    for metric_name, expected_value in expected_metrics.items():
        self.assertAlmostEqual(metrics[metric_name], expected_value)
# Example #2
# 0
    def get_tracker(self, wandb_log: bool, tensorboard_log: bool):
        """Factory method for the tracker.

        Arguments:
            wandb_log -- Log using Weights & Biases; forwarded to the tracker
                as ``wandb_log``
            tensorboard_log -- Log using TensorBoard; forwarded to the tracker
                as ``use_tensorboard``

        Returns:
            [PanopticTracker] -- tracker constructed with ``self`` as its
            first positional argument
        """
        return PanopticTracker(
            self,
            wandb_log=wandb_log,
            use_tensorboard=tensorboard_log,
        )
 def test_track_finalise(self):
     """Track one batch with instance tracking, finalise, and check mAP.

     ``track`` and ``finalise`` are both called with ``track_instances=True``
     and ``iou_threshold=0.25``. The resulting ``train_map`` metric is then
     expected to be 1.
     """
     tracker = PanopticTracker(MockDataset())
     mock_model = MockModel()
     # NOTE(review): pos holds a single point while batch has three entries;
     # presumably the tracker takes predictions/labels from MockModel rather
     # than from these tensors -- confirm against the tracker implementation.
     batch_data = Data(
         pos=torch.tensor([[1, 2]]),
         batch=torch.tensor([0, 0, 0]),
     )
     tracker.track(
         mock_model,
         data=batch_data,
         min_cluster_points=0,
         iou_threshold=0.25,
         track_instances=True,
     )
     # finalise() receives the same instance-tracking flag and IoU threshold
     # that were given to track().
     tracker.finalise(track_instances=True, iou_threshold=0.25)
     self.assertAlmostEqual(tracker.get_metrics()["train_map"], 1)