Exemplo n.º 1
0
    def calc_peak_length_dice_pytorch(HP, y_pred, y_true, max_angle_error=(0.9,), max_length_error=0.1):
        '''
        Calculate, per bundle, the F1 (Dice) score between the ground-truth tract mask and the set of
        voxels whose predicted peak matches the ground-truth peak in both direction and length.

        A voxel counts as a positive prediction when
          - |cos(angle(pred, true))| > max_angle_error[0], and
          - |len(pred) - len(true)| < max_length_error * len(true).
        The ground-truth mask contains every voxel whose ground-truth peak has non-zero length.

        Where the ground-truth peak is zero, the cosine evaluates to 0 and the length test can never
        pass, so predictions outside the tract never count as positives.

        :param HP: hyperparameter object; HP.CLASSES is passed to ExpUtils.get_bundle_names
        :param y_pred: predicted peaks; permuted with (0, 2, 3, 1) so that the channel axis
                       (3 consecutive channels per bundle) becomes the last axis
                       (presumably [bs, 3 * n_bundles, x, y] on input -- TODO confirm)
        :param y_true: ground-truth peaks, same layout as y_pred
        :param max_angle_error: sequence of cosine thresholds: 0.7 -> angle error of ~46° or less;
                                0.9 -> angle error of ~26° or less.
                                Only the FIRST value is evaluated by this function.
        :param max_length_error: maximum allowed length error, relative to the ground-truth length
        :return: dict mapping bundle name -> score returned by PytorchUtils.f1_score_binary
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Absolute cosine of the angle between two arrays of vectors, along the last dimension.

            Cosine -> angle: 1 -> 0°, 0.9 -> ~26°, 0.7 -> ~46°, 0 -> 90°.
            The 1e-7 term avoids division by zero; if either vector is zero the result is 0.

            return: tensor with one dimension less than the inputs
            '''
            return torch.abs(einsum('abcd,abcd->abc', a, b) / (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        angle_threshold = max_angle_error[0]  # single threshold only
        score_per_bundle = {}
        # [1:] skips the first name (presumably the background class -- TODO confirm)
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            channels = slice(idx * 3, idx * 3 + 3)  # 3 peak components per bundle
            y_pred_bund = y_pred[:, :, :, channels].contiguous()
            y_true_bund = y_true[:, :, :, channels].contiguous()

            angles = angle_last_dim(y_pred_bund, y_true_bund)

            lengths_pred = torch.norm(y_pred_bund, 2., -1)
            lengths_true = torch.norm(y_true_bund, 2, -1)
            lengths_binary = torch.abs(lengths_pred - lengths_true) < (max_length_error * lengths_true)

            # Ground-truth mask from the vector length, not the sum of components:
            # the component sum is <= 0 for valid peaks such as (-1, 0, 0).
            gt_binary = (lengths_true > 0).view(-1)

            angles_binary = angles > angle_threshold
            combined = (lengths_binary & angles_binary).view(-1)

            score_per_bundle[bundle] = PytorchUtils.f1_score_binary(gt_binary, combined)
        return score_per_bundle
Exemplo n.º 2
0
    def calc_peak_length_dice_pytorch(HP,
                                      y_pred,
                                      y_true,
                                      max_angle_error=(0.9,),
                                      max_length_error=0.1):
        '''
        Calculate, per bundle, the F1 (Dice) score between the ground-truth
        tract mask and the set of voxels whose predicted peak matches the
        ground-truth peak in both direction and length.

        A voxel counts as a positive prediction when
          - |cos(angle(pred, true))| > max_angle_error[0], and
          - |len(pred) - len(true)| < max_length_error * len(true).
        The ground-truth mask contains every voxel whose ground-truth peak has
        non-zero length.

        Where the ground-truth peak is zero, the cosine evaluates to 0 and the
        length test can never pass, so predictions outside the tract never
        count as positives.

        :param HP: hyperparameter object; HP.CLASSES is passed to
                   ExpUtils.get_bundle_names
        :param y_pred: predicted peaks; permuted with (0, 2, 3, 1) so that the
                       channel axis (3 consecutive channels per bundle) becomes
                       the last axis (presumably [bs, 3 * n_bundles, x, y] on
                       input -- TODO confirm)
        :param y_true: ground-truth peaks, same layout as y_pred
        :param max_angle_error: sequence of cosine thresholds:
                                0.7 -> angle error of ~46° or less;
                                0.9 -> angle error of ~26° or less.
                                Only the FIRST value is evaluated here.
        :param max_length_error: maximum allowed length error, relative to the
                                 ground-truth length
        :return: dict mapping bundle name -> score returned by
                 PytorchUtils.f1_score_binary
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Absolute cosine of the angle between two arrays of vectors, along
            the last dimension.

            Cosine -> angle: 1 -> 0°, 0.9 -> ~26°, 0.7 -> ~46°, 0 -> 90°.
            The 1e-7 term avoids division by zero; if either vector is zero
            the result is 0.

            return: tensor with one dimension less than the inputs
            '''
            return torch.abs(
                einsum('abcd,abcd->abc', a, b) /
                (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        angle_threshold = max_angle_error[0]  # single threshold only
        score_per_bundle = {}
        # [1:] skips the first name (presumably background -- TODO confirm)
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            channels = slice(idx * 3, idx * 3 + 3)  # 3 components per bundle
            y_pred_bund = y_pred[:, :, :, channels].contiguous()
            y_true_bund = y_true[:, :, :, channels].contiguous()

            angles = angle_last_dim(y_pred_bund, y_true_bund)

            lengths_pred = torch.norm(y_pred_bund, 2., -1)
            lengths_true = torch.norm(y_true_bund, 2, -1)
            lengths_binary = torch.abs(lengths_pred - lengths_true) < (
                max_length_error * lengths_true)

            # Ground-truth mask from the vector length, not the sum of
            # components: the sum is <= 0 for valid peaks such as (-1, 0, 0).
            gt_binary = (lengths_true > 0).view(-1)

            angles_binary = angles > angle_threshold
            combined = (lengths_binary & angles_binary).view(-1)

            score_per_bundle[bundle] = PytorchUtils.f1_score_binary(
                gt_binary, combined)
        return score_per_bundle
Exemplo n.º 3
0
    def calc_peak_dice_pytorch(HP, y_pred, y_true, max_angle_error=(0.9,)):
        '''
        Calculate, per bundle, the F1 (Dice) score between two voxel masks:
          - ground truth: voxels whose ground-truth peak has non-zero length
          - prediction:   voxels where |cos(angle(pred, true))| exceeds the threshold

        Behaviour notes:
          -> a predicted peak of 0 inside the tract gives cosine 0, i.e. a false negative
          -> where the ground-truth peak is 0 the cosine is also 0, so predictions outside
             the tract never count as positives (no false positives from outside the tract)
          -> the predicted peak length is ignored (apart from the 1e-7 stabiliser), so very
             small peaks with the right direction are not penalised
          => the score can be high even if the predicted peaks inside the tract are almost 0

        :param HP: hyperparameter object; HP.CLASSES is passed to ExpUtils.get_bundle_names
        :param y_pred: predicted peaks; permuted with (0, 2, 3, 1) so that the channel axis
                       (3 consecutive channels per bundle) becomes the last axis
                       (presumably [bs, 3 * n_bundles, x, y] on input -- TODO confirm)
        :param y_true: ground-truth peaks, same layout as y_pred
        :param max_angle_error: sequence of cosine thresholds: 0.7 -> angle error of ~46° or less;
                                0.9 -> angle error of ~26° or less.
        :return: dict bundle name -> score if exactly one threshold is given, otherwise
                 dict bundle name -> list of scores (one per threshold, in order)
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Absolute cosine of the angle between two arrays of vectors, along the last dimension.

            Cosine -> angle: 1 -> 0°, 0.9 -> ~26°, 0.7 -> ~46°, 0 -> 90°.
            The 1e-7 term avoids division by zero; if either vector is zero the result is 0.

            return: tensor with one dimension less than the inputs
            '''
            return torch.abs(einsum('abcd,abcd->abc', a, b) / (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        # One threshold -> scalar score per bundle; otherwise a list of scores per bundle.
        single_threshold = len(max_angle_error) == 1

        score_per_bundle = {}
        # [1:] skips the first name (presumably the background class -- TODO confirm)
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            channels = slice(idx * 3, idx * 3 + 3)  # 3 peak components per bundle
            y_pred_bund = y_pred[:, :, :, channels].contiguous()
            y_true_bund = y_true[:, :, :, channels].contiguous()

            angles = angle_last_dim(y_pred_bund, y_true_bund)
            # Ground-truth mask from the vector length, not the sum of components:
            # the component sum is <= 0 for valid peaks such as (-1, 0, 0).
            gt_binary = (torch.norm(y_true_bund, 2, -1) > 0).view(-1)

            scores = [PytorchUtils.f1_score_binary(gt_binary, (angles > threshold).view(-1))
                      for threshold in max_angle_error]
            score_per_bundle[bundle] = scores[0] if single_threshold else scores

        return score_per_bundle
Exemplo n.º 4
0
    def calc_peak_dice_pytorch(HP, y_pred, y_true, max_angle_error=(0.9,)):
        '''
        Calculate, per bundle, the F1 (Dice) score between two voxel masks:
          - ground truth: voxels whose ground-truth peak has non-zero length
          - prediction:   voxels where |cos(angle(pred, true))| exceeds the
                          threshold

        Behaviour notes:
          -> a predicted peak of 0 inside the tract gives cosine 0, i.e. a
             false negative
          -> where the ground-truth peak is 0 the cosine is also 0, so
             predictions outside the tract never count as positives (no false
             positives from outside the tract)
          -> the predicted peak length is ignored (apart from the 1e-7
             stabiliser), so very small peaks with the right direction are not
             penalised
          => the score can be high even if the predicted peaks inside the
             tract are almost 0

        :param HP: hyperparameter object; HP.CLASSES is passed to
                   ExpUtils.get_bundle_names
        :param y_pred: predicted peaks; permuted with (0, 2, 3, 1) so that the
                       channel axis (3 consecutive channels per bundle) becomes
                       the last axis (presumably [bs, 3 * n_bundles, x, y] on
                       input -- TODO confirm)
        :param y_true: ground-truth peaks, same layout as y_pred
        :param max_angle_error: sequence of cosine thresholds:
                                0.7 -> angle error of ~46° or less;
                                0.9 -> angle error of ~26° or less.
        :return: dict bundle name -> score if exactly one threshold is given,
                 otherwise dict bundle name -> list of scores (one per
                 threshold, in order)
        '''
        import torch
        from tractseg.libs.PytorchEinsum import einsum
        from tractseg.libs.PytorchUtils import PytorchUtils

        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)

        def angle_last_dim(a, b):
            '''
            Absolute cosine of the angle between two arrays of vectors, along
            the last dimension.

            Cosine -> angle: 1 -> 0°, 0.9 -> ~26°, 0.7 -> ~46°, 0 -> 90°.
            The 1e-7 term avoids division by zero; if either vector is zero
            the result is 0.

            return: tensor with one dimension less than the inputs
            '''
            return torch.abs(
                einsum('abcd,abcd->abc', a, b) /
                (torch.norm(a, 2., -1) * torch.norm(b, 2, -1) + 1e-7))

        # One threshold -> scalar score per bundle; otherwise a list of scores.
        single_threshold = len(max_angle_error) == 1

        score_per_bundle = {}
        # [1:] skips the first name (presumably background -- TODO confirm)
        bundles = ExpUtils.get_bundle_names(HP.CLASSES)[1:]
        for idx, bundle in enumerate(bundles):
            channels = slice(idx * 3, idx * 3 + 3)  # 3 components per bundle
            y_pred_bund = y_pred[:, :, :, channels].contiguous()
            y_true_bund = y_true[:, :, :, channels].contiguous()

            angles = angle_last_dim(y_pred_bund, y_true_bund)
            # Ground-truth mask from the vector length, not the sum of
            # components: the sum is <= 0 for valid peaks such as (-1, 0, 0).
            gt_binary = (torch.norm(y_true_bund, 2, -1) > 0).view(-1)

            scores = [
                PytorchUtils.f1_score_binary(gt_binary,
                                             (angles > threshold).view(-1))
                for threshold in max_angle_error
            ]
            score_per_bundle[bundle] = (scores[0] if single_threshold
                                        else scores)

        return score_per_bundle