示例#1
0
    def test_constant_thresholding(self):
        """Check constant-mode thresholding for several input layouts."""

        # (input shape, expected output shape) pairs, in order:
        #   - batch with an explicit (unsqueezed) channel dimension
        #   - batch with the channel dimension squeezed away
        #   - single sample with the channel dimension squeezed away
        shape_cases = (
            ((2, 1, 10, 10), (2, 1, 10, 10)),
            ((2, 10, 10), (2, 1, 10, 10)),
            ((1, 10, 10), (1, 1, 10, 10)),
        )

        for input_shape, expected_shape in shape_cases:
            prob_maps = torch.randn(*input_shape)
            result = batch_thresholding(prob_maps,
                                        thresh_mode='constant',
                                        constant=0.5)

            self.assertEqual(result.shape, expected_shape)
            self.assertTrue(check_binary_map(result))
示例#2
0
def get_binary_confusion_matrix(binary_maps,
                                targets,
                                masks=None,
                                reduction='sum'):
    """Get binary confusion matrix (TP, FP, TN, FN).

    Args:
        binary_maps: Predicted binary maps. Converted with
            ``convert_to_ndarray`` and required to pass ``check_binary_map``.
        targets: Ground-truth binary maps with the same shape as
            ``binary_maps``.
        masks: Optional binary maps with the same shape as ``binary_maps``.
            When given, every confusion-matrix term is multiplied by it, so
            positions where the mask is 0 contribute nothing.
        reduction: ``'none'`` returns the element-wise maps; ``'sum'`` and
            ``'mean'`` reduce each map to a Python float.

    Returns:
        Tuple ``(true_pos, false_pos, true_neg, false_neg)``.

    Raises:
        AssertionError: If shapes differ or an input is not a binary map.
        NotImplementedError: If ``reduction`` is not a supported mode.
    """
    binary_maps = convert_to_ndarray(binary_maps)
    targets = convert_to_ndarray(targets)

    assert binary_maps.shape == targets.shape
    assert check_binary_map(binary_maps)
    assert check_binary_map(targets)

    if masks is not None:
        masks = convert_to_ndarray(masks)
        assert binary_maps.shape == masks.shape
        assert check_binary_map(masks)

    # Complement as `1 - x` rather than `-1 * (x - 1)`: the latter wraps
    # around (0 - 1 == 255 for uint8) when the maps have an unsigned
    # integer dtype, which corrupts the negative-class terms.
    targets_neg = 1 - targets
    inputs_neg = 1 - binary_maps

    # Order matters: TP, FP, TN, FN.
    terms = (
        targets * binary_maps,
        targets_neg * binary_maps,
        targets_neg * inputs_neg,
        targets * inputs_neg,
    )

    if masks is not None:
        terms = tuple(term * masks for term in terms)

    if reduction == 'none':
        true_pos, false_pos, true_neg, false_neg = terms
        return true_pos, false_pos, true_neg, false_neg

    if reduction == 'sum':
        reduce_fn = np.sum
    elif reduction == 'mean':
        reduce_fn = np.mean
    else:
        LOGGER.error('Invalid reduction mode: %s', reduction)
        raise NotImplementedError(
            'Invalid reduction mode: {}'.format(reduction))

    true_pos, false_pos, true_neg, false_neg = (
        float(reduce_fn(term)) for term in terms)

    return true_pos, false_pos, true_neg, false_neg
示例#3
0
def get_dice_coefficient(prob_maps, binary_maps, targets, masks=None,
                         reduction='mean', epsilon=1e-8):
    """Dice coefficient between binary predictions and targets.

    Each sample (first dimension) is flattened and scored as
    ``(2 * |pred * target| + epsilon) / (|pred| + |target| + epsilon)``.

    Args:
        prob_maps: Probability maps with the same shape as ``targets``.
            Only validated here (shape and ``check_probability_map``); the
            score itself is computed from ``binary_maps``.
        binary_maps: Binary predictions with the same shape as ``targets``.
        targets: Binary ground-truth maps.
        masks: Optional binary maps. When given, predictions and targets
            are multiplied by it before scoring.
        reduction: ``'none'`` returns one score per sample, ``'sum'`` and
            ``'mean'`` reduce over the samples.
        epsilon: Added to numerator and denominator so that an empty
            prediction paired with an empty target does not divide by zero.

    Returns:
        The per-sample scores, or their sum / mean depending on
        ``reduction``.

    Raises:
        AssertionError: If shapes differ or an input fails its map check.
        NotImplementedError: If ``reduction`` is not a supported mode.
    """

    assert prob_maps.shape == targets.shape
    assert binary_maps.shape == targets.shape
    assert check_probability_map(prob_maps)
    assert check_binary_map(binary_maps)
    assert check_binary_map(targets)

    # Flatten everything but the leading (sample) dimension.
    batch_size = prob_maps.shape[0]
    binary_maps = convert_to_ndarray(binary_maps.view(batch_size, -1))
    targets = convert_to_ndarray(targets.view(batch_size, -1))

    if masks is not None:
        masks = masks.view(batch_size, -1)
        assert binary_maps.shape == masks.shape
        assert check_binary_map(masks)
        masks = convert_to_ndarray(masks)
        binary_maps = binary_maps * masks
        targets = targets * masks

    intersection = (binary_maps * targets).sum(-1)
    binary_maps_norm = binary_maps.sum(-1)
    # The target size must come from `targets`; using the prediction size
    # twice would turn the score into intersection / |pred|.
    targets_norm = targets.sum(-1)

    dice_score = (2.0 * intersection + epsilon) / (
        binary_maps_norm + targets_norm + epsilon)

    if reduction == 'none':
        pass

    elif reduction == 'sum':
        dice_score = dice_score.sum()

    elif reduction == 'mean':
        dice_score = dice_score.mean()

    else:
        LOGGER.error('Invalid reduction mode: %s', reduction)
        raise NotImplementedError('Invalid reduction mode: {}'.format(
            reduction))

    return dice_score
示例#4
0
def get_area_under_roc_cruve(prob_maps, targets):
    """Compute the area under the ROC curve over all map elements.

    Both inputs are validated, flattened and converted with
    ``convert_to_ndarray`` before being handed to ``roc_auc_score``.
    """
    assert check_probability_map(prob_maps)
    assert check_binary_map(targets)

    scores = convert_to_ndarray(prob_maps.flatten())
    labels = convert_to_ndarray(targets.flatten())

    return roc_auc_score(y_true=labels, y_score=scores)
示例#5
0
def get_average_precision_score(prob_maps, targets):
    """Compute the average precision, a summary of the PR curve.

    Both inputs are validated, flattened and converted with
    ``convert_to_ndarray`` before being handed to
    ``average_precision_score``.
    """
    assert check_probability_map(prob_maps)
    assert check_binary_map(targets)

    scores = convert_to_ndarray(prob_maps.flatten())
    labels = convert_to_ndarray(targets.flatten())

    return average_precision_score(y_true=labels, y_score=scores)
示例#6
0
def get_pr_cruve(prob_maps, targets):
    """Get precision, recall and thresholds for precision - recall curve.

    Args:
        prob_maps: Probability maps; must pass ``check_probability_map``.
        targets: Binary ground-truth maps; must pass ``check_binary_map``.

    Returns:
        Tuple ``(precision, recall, thresholds)`` as returned by
        ``precision_recall_curve`` on the flattened inputs.
    """

    assert check_probability_map(prob_maps)
    assert check_binary_map(targets)

    prob_maps = convert_to_ndarray(prob_maps.flatten())
    targets = convert_to_ndarray(targets.flatten())

    # Pass (y_true, scores) positionally: the score keyword was renamed
    # from `probas_pred` to `y_score` in scikit-learn 1.5, so either
    # keyword breaks on some release while the positional form works on all.
    precision, recall, thresholds = precision_recall_curve(
        targets, prob_maps)
    return precision, recall, thresholds
示例#7
0
def get_roc_curve(prob_maps, targets):
    """Get the points of the ROC curve (TPR - FPR curve).

    Returns the false positive rates, the true positive rates and the
    corresponding thresholds, as produced by ``roc_curve`` on the flattened
    inputs.
    """
    assert check_probability_map(prob_maps)
    assert check_binary_map(targets)

    scores = convert_to_ndarray(prob_maps.flatten())
    labels = convert_to_ndarray(targets.flatten())

    false_pos_rate, true_pos_rate, cutoffs = roc_curve(y_true=labels,
                                                       y_score=scores)
    return false_pos_rate, true_pos_rate, cutoffs
示例#8
0
    def test_skimage_thresholdings(self):
        """Run each skimage thresholding mode and check the output is binary."""
        # Extra keyword arguments forwarded to batch_thresholding, keyed by
        # the thresh_mode they belong to.
        kwargs_by_mode = {
            'otsu': dict(nbins=256, return_all=False),
            'isodata': dict(nbins=256, return_all=False),
            'li': dict(tolerance=None),
            'mean': dict(),
            'triangle': dict(nbins=256),
            'yen': dict(nbins=256),
            'niblack': dict(window_size=15, k=0.2),
            'sauvola': dict(window_size=15, k=0.2, r=None),
            'local': dict(block_size=11,
                          method='gaussian',
                          offset=0,
                          mode='reflect',
                          param=None,
                          cval=0),
        }

        # One shared random batch is thresholded by every mode.
        prob_maps = torch.randn(2, 1, 10, 10)

        for thresh_mode, extra_kwargs in kwargs_by_mode.items():
            result = batch_thresholding(prob_maps,
                                        thresh_mode=thresh_mode,
                                        **extra_kwargs)

            self.assertTrue(check_binary_map(result))
def get_pos_neg_count(dataset: torch.utils.data.Dataset, target_key: str):
    """Count positive and negative target elements over a dataset.

    Args:
        dataset: Iterable of samples; each sample is indexed with
            ``target_key`` to obtain its target map.
        target_key: Key of the target inside a sample.

    Returns:
        Tuple ``(pos_count, neg_count)`` of NumPy values: the sum of all
        target elements, and the number of remaining elements. Both are
        zero for an empty dataset.

    Raises:
        AssertionError: If a target does not pass ``check_binary_map``.
    """
    pos_count = 0
    neg_count = 0

    for sample in dataset:
        target = sample[target_key]

        # PIL images and arrays are converted so that .sum() / .numel()
        # below operate on a tensor.
        if isinstance(target, (Image.Image, np.ndarray)):
            target = TF.to_tensor(target)

        assert check_binary_map(target)

        positives = target.sum()
        pos_count += positives
        neg_count += target.numel() - positives

    # With an empty dataset the counters are still the plain int 0, which
    # has no .numpy(); wrap them instead of raising AttributeError.
    if isinstance(pos_count, torch.Tensor):
        pos_count = pos_count.numpy()
    else:
        pos_count = np.asarray(pos_count)

    if isinstance(neg_count, torch.Tensor):
        neg_count = neg_count.numpy()
    else:
        neg_count = np.asarray(neg_count)

    return pos_count, neg_count