コード例 #1
0
ファイル: test_metrics.py プロジェクト: potipot/icevision
def test_efficientdet_metrics(
    fridge_efficientdet_model,
    fridge_efficientdet_records,
    metric,
    expected_output,
    request,
):
    """Check the text a metric emits for EfficientDet predictions.

    ``expected_output`` is the *name* of another fixture; its value is looked
    up through ``request.getfixturevalue``. The model runs in eval mode on a
    validation batch built from the fridge records. Its ``"detections"`` are
    converted with a 0.0 detection threshold and fed to ``metric``. Whatever
    ``CaptureStdout`` captures while ``metric.finalize()`` runs must equal the
    looked-up expected value.
    """
    model = fridge_efficientdet_model
    records = fridge_efficientdet_records
    expected = request.getfixturevalue(expected_output)

    model.eval()
    # The records half of the returned pair is not needed; the fixture
    # records are passed to the converter directly.
    batch, _ = efficientdet.build_valid_batch(records)
    detections = model(*batch)["detections"]

    metric.accumulate(
        efficientdet.convert_raw_predictions(
            batch=batch,
            raw_preds=detections,
            records=records,
            detection_threshold=0.0,
        )
    )

    with CaptureStdout() as captured:
        metric.finalize()

    assert captured == expected
コード例 #2
0
    def after_pred(self):
        """Split the learner's combined batch and, in eval, convert predictions.

        On entry ``self.learn.xb`` carries inputs at index 0 and targets at
        index 1 (presumably images and targets; confirm against the
        dataloader). Each is re-wrapped in a one-element list and stored as
        ``self.learn.xb`` / ``self.learn.yb``. When not training, the
        ``"detections"`` entry of ``self.pred`` is converted with a detection
        threshold of 0. The result is stored on ``self.learn.converted_preds``.
        """
        combined = self.learn.xb
        # Targets are assigned before inputs, preserving the original order.
        self.learn.yb = [combined[1]]
        self.learn.xb = [combined[0]]

        if self.training:
            return

        detections = self.pred["detections"]
        self.learn.converted_preds = efficientdet.convert_raw_predictions(
            detections, 0
        )
コード例 #3
0
ファイル: callbacks.py プロジェクト: potipot/icevision
    def after_pred(self):
        """Split the learner's combined batch and, in eval, convert predictions.

        On entry ``self.learn.xb`` carries inputs at index 0 and targets at
        index 1. Each is re-wrapped in a one-element list and stored as
        ``self.learn.xb`` / ``self.learn.yb``. When not training, the model's
        ``"detections"`` are converted together with the rebuilt batch and
        ``self.learn.records``, using a 0.0 detection threshold. The result is
        stored on ``self.learn.converted_preds``.
        """
        inputs_and_targets = self.learn.xb
        # Targets are assigned before inputs, preserving the original order.
        self.learn.yb = [inputs_and_targets[1]]
        self.learn.xb = [inputs_and_targets[0]]

        if self.training:
            return

        # NOTE(review): self.xb / self.yb are read through the callback itself;
        # presumably they resolve to the learner's freshly split lists — confirm.
        converted = efficientdet.convert_raw_predictions(
            batch=(*self.xb, *self.yb),
            raw_preds=self.pred["detections"],
            records=self.learn.records,
            detection_threshold=0.0,
        )
        self.learn.converted_preds = converted
コード例 #4
0
    def validation_step(self, batch, batch_idx):
        """Run one validation step: predict, accumulate metrics, log losses.

        Args:
            batch: a pair ``((xb, yb), records)``. ``xb``/``yb`` are passed to
                the model together, and ``records`` goes to metric
                accumulation.
            batch_idx: index of the batch within the epoch; unused here.

        The model is called on both inputs and targets under
        ``torch.no_grad()``. Its output must contain ``"detections"``, which
        are converted with a detection threshold of 0. Every output entry whose
        key contains ``"loss"`` is logged as ``valid/<key>``.
        """
        (xb, yb), records = batch

        with torch.no_grad():
            raw_preds = self(xb, yb)
            # NOTE(review): threshold 0 keeps every detection; presumably the
            # metrics apply their own cut-offs — confirm.
            preds = efficientdet.convert_raw_predictions(raw_preds["detections"], 0)

        self.accumulate_metrics(records, preds)

        # Losses are read straight from the model output rather than being
        # recomputed here, so no separate loss computation is needed.
        for key, value in raw_preds.items():
            if "loss" in key:
                self.log(f"valid/{key}", value)
コード例 #5
0
ファイル: test_metrics.py プロジェクト: potipot/icevision
def test_plot_confusion_matrix(
    fridge_efficientdet_model,
    fridge_efficientdet_records,
):
    """Smoke-test SimpleConfusionMatrix on EfficientDet predictions.

    The model runs in eval mode on a validation batch built from the fridge
    records. Its ``"detections"`` are converted with a 0.0 detection
    threshold. A fresh confusion matrix then accumulates, finalizes and plots
    them. There are no assertions: the test passes when none of these calls
    raises.
    """
    model = fridge_efficientdet_model
    records = fridge_efficientdet_records
    model.eval()

    # Only the batch is needed; the fixture records go to the converter.
    batch, _ = efficientdet.build_valid_batch(records)
    predictions = efficientdet.convert_raw_predictions(
        batch=batch,
        raw_preds=model(*batch)["detections"],
        records=records,
        detection_threshold=0.0,
    )

    matrix = SimpleConfusionMatrix()
    matrix.accumulate(predictions)
    matrix.finalize()
    matrix.plot()