def _compute(self):
        """Check the "tpr" confusion-matrix metric with rank-specific inputs.

        Rank 0 and rank 1 each pass their own hard-coded prediction/label
        pair to ``update``; every rank then calls ``compute`` and asserts the
        result is within 1e-4 of 0.7.
        """
        rank = dist.get_rank()
        # Use one CUDA device per rank when CUDA is available, else the CPU.
        device = f"cuda:{rank}" if torch.cuda.is_available() else "cpu"

        def _noop_step(engine, batch):
            # The engine is never run; it only serves as an attach target.
            return None

        def _on_device(values):
            return torch.tensor(values, device=device)

        confusion = ConfusionMatrix(include_background=True, metric_name="tpr")
        confusion.attach(Engine(_noop_step), "confusion_matrix")

        if rank == 0:
            # Literal shape (2, 3, 2, 2); predictions are identical to labels.
            matching = [
                [
                    [[0.0, 1.0], [0.0, 0.0]],
                    [[0.0, 0.0], [1.0, 1.0]],
                    [[1.0, 0.0], [0.0, 0.0]],
                ],
                [
                    [[0.0, 0.0], [0.0, 1.0]],
                    [[1.0, 0.0], [0.0, 0.0]],
                    [[0.0, 1.0], [1.0, 0.0]],
                ],
            ]
            confusion.update([_on_device(matching), _on_device(matching)])

        if rank == 1:
            # Literal shape (1, 3, 2, 2); every label entry is 1.0.
            predicted = [
                [
                    [[0.0, 1.0], [1.0, 0.0]],
                    [[1.0, 0.0], [1.0, 1.0]],
                    [[0.0, 1.0], [0.0, 0.0]],
                ]
            ]
            all_ones = [
                [
                    [[1.0, 1.0], [1.0, 1.0]],
                    [[1.0, 1.0], [1.0, 1.0]],
                    [[1.0, 1.0], [1.0, 1.0]],
                ]
            ]
            confusion.update([_on_device(predicted), _on_device(all_ones)])

        # NOTE(review): presumably compute() combines the updates made on both
        # ranks before averaging - confirm against the metric implementation.
        np.testing.assert_allclose(
            confusion.compute(), 0.7, rtol=1e-04, atol=1e-04)
Ejemplo n.º 2
0
    def test_compute_seg(self, input_params, expected_avg):
        """Accumulate ``data_1`` and ``data_2`` and check the computed metric.

        Args:
            input_params: keyword arguments forwarded to ``ConfusionMatrix``.
            expected_avg: value the computed metric must equal to 4 places.
        """
        confusion = ConfusionMatrix(**input_params)

        def _noop_step(engine, batch):
            # The engine is never run; it only serves as an attach target.
            return None

        confusion.attach(Engine(_noop_step), "confusion_matrix")

        # Feed both fixtures, in order, as [prediction, label] pairs.
        for sample in (data_1, data_2):
            confusion.update([sample["y_pred"], sample["y"]])

        self.assertAlmostEqual(confusion.compute(), expected_avg, places=4)