コード例 #1
0
    def __call__(self, prediction: torch.Tensor,
                 target: torch.Tensor) -> torch.Tensor:
        """
        Compute the Dice loss of a binarized prediction against its labels.

        Args:
            prediction: network output; it must already be binarized.
                Dimensions - (Batch, Channels, Depth, Height, Width)
            target: labels, with the classes (channels) one-hot encoded.
                Dimensions - (Batch, Channels, Depth, Height, Width)

        Returns:
            torch.Tensor: the sum of ``1 - score`` over every score returned
                by ``F.dice`` (presumably one score per class — confirm
                against ``F.dice``).
        """
        dice_scores = F.dice(prediction, target, self.epsilon)
        # A perfect score (1) contributes zero loss; lower scores add more.
        per_score_loss = 1 - dice_scores
        return torch.sum(per_score_loss)
コード例 #2
0
ファイル: metrics.py プロジェクト: PiechaczekMyller/brats
    def __call__(self, prediction: torch.Tensor,
                 target: torch.Tensor) -> torch.Tensor:
        """
        Compute the DICE score of a prediction against one-hot targets.

        ``self.epsilon`` is forwarded unchanged to ``F.dice`` (presumably a
        smoothing term that avoids division by zero — confirm in ``F.dice``).

        Args:
            prediction: Network output.
                Dimensions - (Batch, Class, Depth, Height, Width)
            target: Target values. The classes should be one-hot encoded.
                Dimensions - (Batch, Class, Depth, Height, Width)
        Returns:
            torch.Tensor: the DICE score of each class, averaged across
                the whole batch.
        """
        return F.dice(prediction, target, self.epsilon)
コード例 #3
0
 def test_if_returns_expected_value_for_multiclass(self, images, target,
                                                   classes, result):
     """Tiling one input across ``classes`` channels gives the same score
     for every channel, each within 1e-3 of ``result``."""
     def _tile_over_classes(tensor):
         # Insert a channel axis, then copy it ``classes`` times along it.
         return tensor.unsqueeze(dim=CHANNEL_DIM).repeat(1, classes, 1)

     tiled_images = _tile_over_classes(images)
     tiled_target = _tile_over_classes(target)
     expected = [result] * classes
     scores = F.dice(tiled_images, tiled_target)
     assert np.all(np.isclose(scores, expected, atol=1.e-3))
コード例 #4
0
 def test_if_returns_correct_number_of_classes(self, input_shape, result):
     """``F.dice`` on all-one inputs of ``input_shape`` returns ``result``
     entries."""
     images = torch.ones(*input_shape)
     # Same shape, dtype and device as ``images``: an all-one target.
     target = torch.ones_like(images)
     scores = F.dice(images, target)
     assert len(scores) == result
コード例 #5
0
 def test_if_returns_expected_values_for_one_class(self, images, target,
                                                   result):
     """With a single channel added via ``unsqueeze``, the DICE score
     matches ``result`` within 1e-3."""
     single_class_images = images.unsqueeze(dim=CHANNEL_DIM)
     single_class_target = target.unsqueeze(dim=CHANNEL_DIM)
     score = F.dice(single_class_images, single_class_target)
     assert np.isclose(score, result, atol=1.e-3)
コード例 #6
0
 def test_if_returns_0_for_worst_fit(self):
     """An all-zero prediction against an all-one target scores ~0.

     ``np.all`` is needed because ``np.isclose`` returns one boolean per
     score, and a multi-element array is ambiguous in ``assert``. It keeps
     the check valid if ``BATCH_DIMS`` holds several channels.
     """
     images = torch.zeros(*BATCH_DIMS)
     target = torch.ones(*BATCH_DIMS)
     assert np.all(np.isclose(F.dice(images, target), 0, atol=1.e-4))
コード例 #7
0
 def test_if_returns_1_for_perfect_fit(self):
     """Identical all-one prediction and target score ~1.

     ``np.all`` is needed because ``np.isclose`` returns one boolean per
     score, and a multi-element array is ambiguous in ``assert``. It keeps
     the check valid if ``BATCH_DIMS`` holds several channels.
     """
     images = torch.ones(*BATCH_DIMS)
     target = torch.ones(*BATCH_DIMS)
     assert np.all(np.isclose(F.dice(images, target), 1, atol=1.e-4))