Beispiel #1
0
def calc_peak_length_dice_pytorch(classes,
                                  y_pred,
                                  y_true,
                                  max_angle_error=(0.9,),
                                  max_length_error=0.1):
    """
    Calculate a per-bundle peak Dice (F1) score that also checks the peak length.

    A voxel counts as a correct prediction only if BOTH
      - the value returned by ``pytorch_utils.angle_last_dim`` for the predicted and
        the groundtruth peak is greater than ``max_angle_error[0]`` (direction close
        enough), and
      - the predicted peak length differs from the groundtruth peak length by less
        than ``max_length_error * groundtruth length``.
    The groundtruth mask contains all voxels whose groundtruth peak components sum
    to > 0. The score is ``pytorch_utils.f1_score_binary`` between the two masks.

    Args:
        classes: passed to ``dataset_specific_utils.get_bundle_names``; the first
            returned name is skipped (presumably a background entry - TODO confirm).
        y_pred: prediction with the channel axis at position 1, i.e.
            [bs, bundles*3, x, y] (2D, 4 dims) or [bs, bundles*3, x, y, z] (3D).
            Every bundle occupies 3 consecutive channels.
        y_true: groundtruth with the same layout as ``y_pred``.
        max_angle_error: sequence of thresholds; only the FIRST element is used.
        max_length_error: allowed relative length error (0.1 -> 10 %).

    Returns:
        dict mapping bundle name -> result of ``pytorch_utils.f1_score_binary``.
    """
    import torch
    from tractseg.libs import pytorch_utils

    # Move the channel axis last so each bundle's peak vector is the last dimension.
    if len(y_pred.shape) == 4:  # 2D: [bs, c, x, y] -> [bs, x, y, c]
        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)
    else:  # 3D: [bs, c, x, y, z] -> [bs, x, y, z, c]
        y_true = y_true.permute(0, 2, 3, 4, 1)
        y_pred = y_pred.permute(0, 2, 3, 4, 1)

    angle_threshold = max_angle_error[0]

    score_per_bundle = {}
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    for idx, bundle in enumerate(bundles):
        start = idx * 3
        y_pred_bund = y_pred[..., start:start + 3].contiguous()  # [bs, x, y, (z), 3]
        y_true_bund = y_true[..., start:start + 3].contiguous()  # [bs, x, y, (z), 3]

        angles = pytorch_utils.angle_last_dim(y_pred_bund, y_true_bund)

        lengths_pred = torch.norm(y_pred_bund, p=2, dim=-1)
        lengths_true = torch.norm(y_true_bund, p=2, dim=-1)
        # Strict "<": where the groundtruth length is 0 the tolerance is 0, so such
        # voxels can never be counted as correctly predicted.
        lengths_binary = (torch.abs(lengths_pred - lengths_true)
                          < max_length_error * lengths_true).view(-1)

        gt_binary = (y_true_bund.sum(dim=-1) > 0).view(-1)  # [bs*x*y(*z)]

        angles_binary = (angles > angle_threshold).view(-1)

        # A voxel is a hit only if both direction and length are close enough.
        combined = lengths_binary & angles_binary

        score_per_bundle[bundle] = pytorch_utils.f1_score_binary(gt_binary, combined)
    return score_per_bundle
Beispiel #2
0
def calc_peak_dice_pytorch(classes, y_pred, y_true, max_angle_error=(0.9,)):
    """
    Calculate the peak Dice (F1) score for every bundle.

    For every voxel, compare predicted and groundtruth peak with
    ``pytorch_utils.angle_last_dim``. Keep the voxels where that value is GREATER
    than the threshold, i.e. where the angle error is small enough.
    From the groundtruth, generate a binary mask by selecting all voxels whose peak
    components sum to > 0. The score is ``pytorch_utils.f1_score_binary`` between
    these 2 masks.

    -> Penalty on peaks outside of the tract or if the predicted peak is 0.
    -> No penalty on very, very small peaks with the right direction -> bad.
    => Peak Dice can be high even if peaks inside the tract are almost missing
       (almost 0).

    Args:
        classes: passed to ``dataset_specific_utils.get_bundle_names``; the first
            returned name is skipped (presumably a background entry - TODO confirm).
        y_pred: prediction with the channel axis at position 1, i.e.
            [bs, bundles*3, x, y] (2D, 4 dims) or [bs, bundles*3, x, y, z] (3D).
            Every bundle occupies 3 consecutive channels.
        y_true: groundtruth with the same layout as ``y_pred``.
        max_angle_error: sequence of thresholds compared against the output of
            ``angle_last_dim``; higher means stricter. If that output is the
            (absolute) cosine of the angle (TODO confirm), 0.7 -> angle error of
            ~45° or less and 0.9 -> ~26° or less. Can contain several values to
            compute the score for several thresholds.

    Returns:
        dict mapping bundle name -> score. If ``max_angle_error`` has exactly one
        element, the value is a single ``f1_score_binary`` result. Otherwise it is a
        list with one result per threshold, in the given order.
    """
    from tractseg.libs import pytorch_utils

    # Move the channel axis last so each bundle's peak vector is the last dimension.
    if len(y_pred.shape) == 4:  # 2D: [bs, c, x, y] -> [bs, x, y, c]
        y_true = y_true.permute(0, 2, 3, 1)
        y_pred = y_pred.permute(0, 2, 3, 1)
    else:  # 3D: [bs, c, x, y, z] -> [bs, x, y, z, c]
        y_true = y_true.permute(0, 2, 3, 4, 1)
        y_pred = y_pred.permute(0, 2, 3, 4, 1)

    # A single threshold yields a plain score per bundle (historical return format);
    # several thresholds yield a list of scores.
    single_threshold = len(max_angle_error) == 1

    score_per_bundle = {}
    bundles = dataset_specific_utils.get_bundle_names(classes)[1:]
    for idx, bundle in enumerate(bundles):
        start = idx * 3
        # "..." keeps all leading axes, so the same slice works for 2D and 3D.
        y_pred_bund = y_pred[..., start:start + 3].contiguous()  # [bs, x, y, (z), 3]
        y_true_bund = y_true[..., start:start + 3].contiguous()  # [bs, x, y, (z), 3]

        # Angle and groundtruth mask do not depend on the threshold: compute once.
        angles = pytorch_utils.angle_last_dim(y_pred_bund, y_true_bund)
        gt_binary = (y_true_bund.sum(dim=-1) > 0).view(-1)  # [bs*x*y(*z)]

        scores = []
        for threshold in max_angle_error:
            angles_binary = (angles > threshold).view(-1)
            scores.append(pytorch_utils.f1_score_binary(gt_binary, angles_binary))

        score_per_bundle[bundle] = scores[0] if single_threshold else scores

    return score_per_bundle