Exemplo n.º 1
0
    def track(self, model: TrackerInterface, **kwargs):
        """Update the tracked metrics from the model's current state.

        Runs the ``BaseTracker`` tracking step first, then hands the
        semantic logits of the model output and the ``y`` field of the
        model labels to the parent class' ``_compute_metrics``.

        Args:
            model: Object providing ``get_output()`` and ``get_labels()``.
            **kwargs: Accepted but not used by this method.
        """
        BaseTracker.track(self, model)
        predictions: PanopticResults = model.get_output()
        targets: PanopticLabels = model.get_labels()

        # Only the semantic part of the output is tracked here.
        super()._compute_metrics(predictions.semantic_logits, targets.y)
Exemplo n.º 2
0
def test_epoch(
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    voting_runs=1,
    tracker_options=None,
):
    """Run the model over every test loader and print the tracker summary.

    For each loader in ``dataset.test_dataloaders`` the tracker is reset
    with the name of the loader's dataset, the loader is iterated
    ``voting_runs`` times (forward pass under ``torch.no_grad()``), and
    every batch is tracked. Once all runs are done the tracker is
    finalised and its summary is printed.

    Args:
        model: Model fed through ``set_input`` and run with ``forward``.
        dataset: Object exposing ``test_dataloaders``.
        device: Forwarded to ``model.set_input``.
        tracker: Metric tracker updated after each batch.
        checkpoint: Not used by this function.
        voting_runs: Number of passes made over each loader.
        tracker_options: Optional mapping of keyword arguments forwarded
            to ``tracker.track`` and ``tracker.finalise``. ``None``
            (the default) means no extra options.
    """
    # A ``None`` sentinel avoids sharing one mutable dict between calls.
    if tracker_options is None:
        tracker_options = {}

    for loader in dataset.test_dataloaders:
        stage_name = loader.dataset.name
        tracker.reset(stage_name)
        for _ in range(voting_runs):
            with Ctq(loader) as tq_test_loader:
                for data in tq_test_loader:
                    # Inference only: no gradients needed for the forward.
                    with torch.no_grad():
                        model.set_input(data, device)
                        model.forward()

                    tracker.track(model, **tracker_options)
                    tq_test_loader.set_postfix(**tracker.get_metrics(),
                                               color=COLORS.TEST_COLOR)

        tracker.finalise(**tracker_options)
        tracker.print_summary()
Exemplo n.º 3
0
    def track(self,
              model: TrackerInterface,
              data=None,
              iou_threshold=0.25,
              track_instances=True,
              min_cluster_points=10,
              **kwargs):
        """Track semantic and, when ``data`` is given, instance metrics.

        The semantic metrics are always updated from the model output and
        labels. Everything after that (object accuracy and the optional
        per-scan instance bookkeeping for AP) needs ``data`` and at least
        one extracted cluster; otherwise the method returns early.

        Args:
            model: Object providing ``get_output()`` and ``get_labels()``.
            data: Batch whose ``pos`` and ``batch`` attributes are read.
                Only packed batches (``pos`` with two dimensions) are
                supported. If falsy, only semantic metrics are tracked.
            iou_threshold: Stored on the tracker and forwarded to
                ``_compute_acc``.
            track_instances: When true, predicted and ground-truth
                instances are added to the AP meter.
            min_cluster_points: Forwarded to ``_extract_clusters``.
            **kwargs: Accepted but not used by this method.
        """
        self._iou_threshold = iou_threshold
        BaseTracker.track(self, model)
        results: PanopticResults = model.get_output()
        targets: PanopticLabels = model.get_labels()

        # Semantic metrics are tracked unconditionally.
        super()._compute_metrics(results.semantic_logits, targets.y)

        if not data:
            return
        assert data.pos.dim() == 2, "Only supports packed batches"

        # Object accuracy needs at least one cluster.
        found_clusters = PanopticTracker._extract_clusters(
            results, min_cluster_points)
        if not found_clusters:
            return

        # ``max`` over dim 1 yields (values, indices); the indices are used
        # as the predicted semantic labels.
        _, semantic_pred = results.semantic_logits.max(1)
        true_pos, false_pos, accuracy = self._compute_acc(
            found_clusters, semantic_pred, targets, data.batch,
            targets.num_instances, iou_threshold)
        self._pos.add(true_pos)
        self._neg.add(false_pos)
        self._acc_meter.add(accuracy)

        if not track_instances:
            return

        # Instances are collected per scan for the AP meter.
        predicted_instances = self._pred_instances_per_scan(
            found_clusters, semantic_pred, results.cluster_scores,
            data.batch, self._scan_id_offset)
        ground_truth_instances = self._gt_instances_per_scan(
            targets.instance_labels, targets.y, data.batch,
            self._scan_id_offset)
        self._ap_meter.add(predicted_instances, ground_truth_instances)

        # Advance the offset by (last batch index + 1) for the next call.
        self._scan_id_offset += data.batch[-1].item() + 1
Exemplo n.º 4
0
def train_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device: str,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Train the model for one pass over ``dataset.train_dataloader``.

    Each batch is fed to the model and optimised; the tracker is updated
    every tenth iteration and the progress bar shows the current metrics
    together with data-loading and iteration timings. After the loop the
    tracker is finalised, its metrics are published and handed to the
    checkpoint, and the learning rate is logged.

    Args:
        epoch: Current epoch, forwarded to the model, visualizer and
            tracker.
        model: Model being trained.
        dataset: Object exposing ``train_dataloader`` and ``batch_size``.
        device: Forwarded to ``model.set_input``.
        tracker: Metric tracker.
        checkpoint: Receives the published metrics via
            ``save_best_models_under_current_metrics``.
        visualizer: Saves the model's current visuals while active.
        debugging: Object optionally carrying ``early_break``,
            ``profiling`` and ``num_batches`` attributes.

    Returns:
        ``0`` when profiling stops the loop early (in which case the
        tracker is neither finalised nor published); ``None`` otherwise.
    """
    early_break = getattr(debugging, "early_break", False)
    profiling = getattr(debugging, "profiling", False)
    # The batch budget does not change during the loop: look it up once,
    # and only when profiling is actually enabled.
    max_profiled_batches = (getattr(debugging, "num_batches", 50)
                            if profiling else None)

    model.train()
    tracker.reset("train")
    visualizer.reset(epoch, "train")
    train_loader = dataset.train_dataloader

    iter_data_time = time.time()
    with Ctq(train_loader) as tq_train_loader:
        for i, data in enumerate(tq_train_loader):
            # Time spent waiting for the loader since the last iteration.
            t_data = time.time() - iter_data_time
            iter_start_time = time.time()
            model.set_input(data, device)
            model.optimize_parameters(epoch, dataset.batch_size)
            # Metrics are only tracked on every tenth iteration.
            if i % 10 == 0:
                tracker.track(model)

            tq_train_loader.set_postfix(**tracker.get_metrics(),
                                        data_loading=float(t_data),
                                        iteration=float(time.time() -
                                                        iter_start_time),
                                        color=COLORS.TRAIN_COLOR)

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            iter_data_time = time.time()

            if early_break:
                break

            if profiling and i > max_profiled_batches:
                return 0

    tracker.finalise()
    metrics = tracker.publish(epoch)
    checkpoint.save_best_models_under_current_metrics(model, metrics,
                                                      tracker.metric_func)
    # Lazy %-args: the message is only formatted if the record is emitted.
    log.info("Learning rate = %f", model.learning_rate)
Exemplo n.º 5
0
def test_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Evaluate the model on each test loader and publish the metrics.

    The model is switched to eval mode, then for every loader in
    ``dataset.test_dataloaders`` the tracker and visualizer are reset with
    the loader's dataset name, all batches are run without gradients and
    tracked, and finally the tracker is finalised, published, summarised
    and its metrics passed to the checkpoint.

    Args:
        epoch: Current epoch, forwarded to the visualizer and tracker.
        model: Model to evaluate.
        dataset: Object exposing ``test_dataloaders``.
        device: Forwarded to ``model.set_input``.
        tracker: Metric tracker.
        checkpoint: Receives the published metrics via
            ``save_best_models_under_current_metrics``.
        visualizer: Saves the model's current visuals while active.
        debugging: Object optionally carrying an ``early_break``
            attribute; when truthy, each loader stops after one batch.
    """
    stop_after_first_batch = getattr(debugging, "early_break", False)
    model.eval()

    for test_loader in dataset.test_dataloaders:
        split_name = test_loader.dataset.name
        tracker.reset(split_name)
        visualizer.reset(epoch, split_name)

        with Ctq(test_loader) as progress:
            for batch in progress:
                # Forward pass only; gradients are not needed here.
                with torch.no_grad():
                    model.set_input(batch, device)
                    model.forward()

                tracker.track(model)
                current_metrics = tracker.get_metrics()
                progress.set_postfix(**current_metrics,
                                     color=COLORS.TEST_COLOR)

                if visualizer.is_active:
                    visuals = model.get_current_visuals()
                    visualizer.save_visuals(visuals)

                if stop_after_first_batch:
                    break

        tracker.finalise()
        published = tracker.publish(epoch)
        tracker.print_summary()
        checkpoint.save_best_models_under_current_metrics(
            model, published, tracker.metric_func)
Exemplo n.º 6
0
def eval_epoch(
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    voting_runs=1,
    tracker_options=None,
):
    """Run the model over the validation loader and print the summary.

    The tracker is reset for the ``"val"`` stage, then
    ``dataset.val_dataloader`` is iterated ``voting_runs`` times with the
    forward pass under ``torch.no_grad()`` and every batch tracked. The
    tracker is finalised and its summary printed at the end.

    Args:
        model: Model fed through ``set_input`` and run with ``forward``.
        dataset: Object exposing ``val_dataloader``.
        device: Forwarded to ``model.set_input``.
        tracker: Metric tracker updated after each batch.
        checkpoint: Not used by this function.
        voting_runs: Number of passes made over the loader.
        tracker_options: Optional mapping of keyword arguments forwarded
            to ``tracker.track`` and ``tracker.finalise``. ``None``
            (the default) means no extra options.
    """
    # A ``None`` sentinel avoids sharing one mutable dict between calls.
    if tracker_options is None:
        tracker_options = {}

    tracker.reset("val")
    loader = dataset.val_dataloader
    for _ in range(voting_runs):
        with Ctq(loader) as tq_val_loader:
            for data in tq_val_loader:
                # Inference only: no gradients needed for the forward.
                with torch.no_grad():
                    model.set_input(data, device)
                    model.forward()

                tracker.track(model, **tracker_options)
                tq_val_loader.set_postfix(**tracker.get_metrics(),
                                          color=COLORS.VAL_COLOR)

    tracker.finalise(**tracker_options)
    tracker.print_summary()
Exemplo n.º 7
0
def eval_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Evaluate the model on the validation loader and publish metrics.

    The model is switched to eval mode, the tracker and visualizer are
    reset for the ``"val"`` stage, and every batch of
    ``dataset.val_dataloader`` is run without gradients and tracked. The
    tracker is then finalised, published, summarised and its metrics
    passed to the checkpoint.

    Args:
        epoch: Current epoch, forwarded to the visualizer and tracker.
        model: Model to evaluate.
        dataset: Object exposing ``val_dataloader``.
        device: Forwarded to ``model.set_input``.
        tracker: Metric tracker.
        checkpoint: Receives the published metrics via
            ``save_best_models_under_current_metrics``.
        visualizer: Saves the model's current visuals while active.
        debugging: Object optionally carrying an ``early_break``
            attribute; when truthy, the loop stops after one batch.
    """
    early_break = getattr(debugging, "early_break", False)

    model.eval()
    tracker.reset("val")
    visualizer.reset(epoch, "val")
    loader = dataset.val_dataloader
    with Ctq(loader) as tq_val_loader:
        for data in tq_val_loader:
            # Forward pass only; gradients are not needed here.
            with torch.no_grad():
                model.set_input(data, device)
                model.forward()

            tracker.track(model)
            tq_val_loader.set_postfix(**tracker.get_metrics(),
                                      color=COLORS.VAL_COLOR)

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            if early_break:
                break

    # Finalise the tracker before publishing, as the train and test loops
    # do; NOTE(review): assumes this tracker exposes ``finalise`` - confirm.
    tracker.finalise()
    metrics = tracker.publish(epoch)
    tracker.print_summary()
    checkpoint.save_best_models_under_current_metrics(model, metrics,
                                                      tracker.metric_func)