コード例 #1
0
def test_segmentation_statistics():
    """Check tp/fp/fn counts of ``get_segmentation_statistics`` on toy masks.

    Each mask has shape ``(1, 1, size, size)`` (batch, class, height, width as
    laid out below); the multiclass cases concatenate masks along
    ``class_dim=1`` and then along the batch dimension 0.
    """
    size = 4
    half_size = size // 2
    shape = (1, 1, size, size)
    # Expected counts are derived from ``size`` so the test stays valid when
    # ``size`` is changed (it must stay even for the half/quarter masks).
    n_full = float(size * size)
    n_half = n_full / 2
    n_quarter = n_full / 4

    def check(outputs, targets, expected_tp, expected_fp, expected_fn):
        # Compare each statistic separately so a failure names the culprit.
        tp, fp, fn = get_segmentation_statistics(outputs, targets, class_dim=1)
        assert torch.allclose(tp, torch.tensor(expected_tp)), (tp, expected_tp)
        assert torch.allclose(fp, torch.tensor(expected_fp)), (fp, expected_fp)
        assert torch.allclose(fn, torch.tensor(expected_fn)), (fn, expected_fn)

    empty = torch.zeros(shape)
    full = torch.ones(shape)
    left = torch.ones(shape)  # left half set
    left[:, :, :, half_size:] = 0
    right = torch.ones(shape)  # right half set
    right[:, :, :, :half_size] = 0
    top_left = torch.zeros(shape)  # top-left quarter set
    top_left[:, :, :half_size, :half_size] = 1

    # No overlap: empty prediction vs full target -> every target pixel is fn.
    check(empty, full, [0.0], [0.0], [n_full])
    # No overlap: disjoint halves -> predicted half is fp, target half is fn.
    check(left, right, [0.0], [n_half], [n_half])

    # Identical masks (both empty, both full, both half) -> no fp and no fn.
    check(empty, empty, [0.0], [0.0], [0.0])
    check(full, full, [n_full], [0.0], [0.0])
    check(left, left, [n_half], [0.0], [0.0])

    # Half overlap: prediction ``left`` covers target ``top_left`` plus one
    # extra quarter (IoU = tp / (tp + fp + fn) = 0.5).
    check(left, top_left, [n_quarter], [n_quarter], [0.0])

    # Multiclass: one channel per case, stacked along class_dim=1; the
    # statistics are expected per class, in channel order.
    a = torch.cat([empty, left, empty, full, left, top_left], dim=1)
    b = torch.cat([full, right, empty, full, left, left], dim=1)
    true_tp = [0.0, 0.0, 0.0, n_full, n_half, n_quarter]
    true_fp = [0.0, n_half, 0.0, 0.0, 0.0, 0.0]
    true_fn = [n_full, n_half, 0.0, 0.0, 0.0, n_quarter]
    check(a, b, true_tp, true_fp, true_fn)

    # Batched: repeating the sample along dim 0 multiplies every per-class
    # count by the batch size (counts are summed over the batch).
    n_batch = 3
    aaa = torch.cat([a] * n_batch, dim=0)
    bbb = torch.cat([b] * n_batch, dim=0)
    check(
        aaa,
        bbb,
        [n_batch * v for v in true_tp],
        [n_batch * v for v in true_fp],
        [n_batch * v for v in true_fn],
    )
コード例 #2
0
    def update(self, outputs: torch.Tensor,
               targets: torch.Tensor) -> torch.Tensor:
        """Accumulate per-class segmentation statistics and return batch metrics.

        The per-class true positive / false positive / false negative counts
        for this batch are added to ``self.statistics`` (a mapping from class
        index to a dict with keys ``"tp"``, ``"fp"``, ``"fn"``).
        ``self.metric_fn`` is then applied to this batch's counts only, not to
        the accumulated ones.

        Args:
            outputs: tensor of logits; ``self.class_dim`` and
                ``self.threshold`` are forwarded to
                ``get_segmentation_statistics``
            targets: tensor of targets

        Returns:
            metric for each class, computed from this batch's statistics
        """
        tp, fp, fn = get_segmentation_statistics(
            outputs=outputs.detach(),
            targets=targets.detach(),
            class_dim=self.class_dim,
            threshold=self.threshold,
        )

        for idx, (tp_class, fp_class, fn_class) in enumerate(zip(tp, fp, fn)):
            stats = self.statistics.get(idx)
            if stats is None:
                # Iterating a tensor yields views into it. Clone so the
                # accumulator owns its storage; otherwise the in-place ``+=``
                # below would write back into this batch's tp/fp/fn tensors.
                self.statistics[idx] = {
                    "tp": tp_class.clone(),
                    "fp": fp_class.clone(),
                    "fn": fn_class.clone(),
                }
            else:
                stats["tp"] += tp_class
                stats["fp"] += fp_class
                stats["fn"] += fn_class

        return self.metric_fn(tp, fp, fn)