def test_constant_thresholding(self):
    """Test thresholding with constant."""
    # test batch of probability maps with unsqueezed channel dimension
    prob_maps = torch.randn(2, 1, 10, 10)
    binary_maps = batch_thresholding(prob_maps,
                                     thresh_mode='constant',
                                     constant=0.5)
    self.assertEqual(binary_maps.shape, (2, 1, 10, 10))
    self.assertTrue(check_binary_map(binary_maps))

    # test batch of probability maps with squeezed channel dimension
    prob_maps = torch.randn(2, 10, 10)
    binary_maps = batch_thresholding(prob_maps,
                                     thresh_mode='constant',
                                     constant=0.5)
    self.assertEqual(binary_maps.shape, (2, 1, 10, 10))
    self.assertTrue(check_binary_map(binary_maps))

    # test single probability map with squeezed channel dimension
    prob_maps = torch.randn(1, 10, 10)
    binary_maps = batch_thresholding(prob_maps,
                                     thresh_mode='constant',
                                     constant=0.5)
    self.assertEqual(binary_maps.shape, (1, 1, 10, 10))
    self.assertTrue(check_binary_map(binary_maps))

def get_binary_confusion_matrix(binary_maps, targets, masks=None,
                                reduction='sum'):
    """Get binary confusion matrix (TP, FP, TN, FN)."""
    binary_maps = convert_to_ndarray(binary_maps)
    targets = convert_to_ndarray(targets)

    assert binary_maps.shape == targets.shape
    assert check_binary_map(binary_maps)
    assert check_binary_map(targets)

    if masks is not None:
        masks = convert_to_ndarray(masks)
        assert binary_maps.shape == masks.shape
        assert check_binary_map(masks)

    targets_neg = -1 * (targets - 1)
    inputs_neg = -1 * (binary_maps - 1)

    true_pos = targets * binary_maps
    false_pos = targets_neg * binary_maps
    true_neg = targets_neg * inputs_neg
    false_neg = targets * inputs_neg

    if masks is not None:
        true_pos = true_pos * masks
        false_pos = false_pos * masks
        true_neg = true_neg * masks
        false_neg = false_neg * masks

    if reduction == 'none':
        pass
    elif reduction == 'sum':
        true_pos = float(np.sum(true_pos))
        false_pos = float(np.sum(false_pos))
        true_neg = float(np.sum(true_neg))
        false_neg = float(np.sum(false_neg))
    elif reduction == 'mean':
        true_pos = float(np.mean(true_pos))
        false_pos = float(np.mean(false_pos))
        true_neg = float(np.mean(true_neg))
        false_neg = float(np.mean(false_neg))
    else:
        LOGGER.error('Invalid reduction mode: %s', reduction)
        raise NotImplementedError(
            'Invalid reduction mode: {}'.format(reduction))

    return true_pos, false_pos, true_neg, false_neg

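# A minimal usage sketch (the random tensors and derived metrics below are
# illustrative assumptions, not part of this module): turn the summed
# confusion-matrix counts into precision, recall, and accuracy.
import torch

preds = (torch.rand(4, 1, 32, 32) > 0.5).float()    # fake binary predictions
labels = (torch.rand(4, 1, 32, 32) > 0.5).float()   # fake binary targets

tp, fp, tn, fn = get_binary_confusion_matrix(preds, labels, reduction='sum')

eps = 1e-8  # guards against division by zero when a class is empty
precision = tp / (tp + fp + eps)
recall = tp / (tp + fn + eps)
accuracy = (tp + tn) / (tp + fp + tn + fn + eps)
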
def get_dice_coefficient(prob_maps, binary_maps, targets, masks=None,
                         reduction='mean', epsilon=1e-8):
    """Dice coefficient."""
    assert prob_maps.shape == targets.shape
    assert binary_maps.shape == targets.shape
    assert check_probability_map(prob_maps)
    assert check_binary_map(binary_maps)
    assert check_binary_map(targets)

    batch_size = prob_maps.shape[0]
    prob_maps = prob_maps.view(batch_size, -1)
    binary_maps = binary_maps.view(batch_size, -1)
    targets = targets.view(batch_size, -1)

    prob_maps = convert_to_ndarray(prob_maps)
    binary_maps = convert_to_ndarray(binary_maps)
    targets = convert_to_ndarray(targets)

    if masks is not None:
        masks = masks.view(batch_size, -1)
        assert binary_maps.shape == masks.shape
        assert check_binary_map(masks)
        masks = convert_to_ndarray(masks)
        prob_maps = prob_maps * masks
        binary_maps = binary_maps * masks
        targets = targets * masks

    intersection = (binary_maps * targets).sum(-1)
    binary_maps_norm = binary_maps.sum(-1)
    targets_norm = targets.sum(-1)
    dice_score = (2.0 * intersection + epsilon) / (
        binary_maps_norm + targets_norm + epsilon)

    if reduction == 'none':
        pass
    elif reduction == 'sum':
        dice_score = dice_score.sum()
    elif reduction == 'mean':
        dice_score = dice_score.mean()
    else:
        LOGGER.error('Invalid reduction mode: %s', reduction)
        raise NotImplementedError(
            'Invalid reduction mode: {}'.format(reduction))

    return dice_score

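# A minimal usage sketch (the 0.5 cutoff and tensor shapes are illustrative
# assumptions): score thresholded sigmoid outputs against binary targets.
import torch

logits = torch.randn(4, 1, 32, 32)
prob_maps = torch.sigmoid(logits)                    # values in [0, 1]
binary_maps = (prob_maps > 0.5).float()              # hard predictions
targets = (torch.rand(4, 1, 32, 32) > 0.5).float()   # fake binary targets

dice = get_dice_coefficient(prob_maps, binary_maps, targets, reduction='mean')
print(float(dice))  # 1.0 means perfect overlap, values near 0 mean almost none
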
def get_area_under_roc_cruve(prob_maps, targets):
    """Get area under ROC curve."""
    assert check_probability_map(prob_maps)
    assert check_binary_map(targets)

    prob_maps = convert_to_ndarray(prob_maps.flatten())
    targets = convert_to_ndarray(targets.flatten())

    auroc = roc_auc_score(y_score=prob_maps, y_true=targets)

    return auroc

def get_average_precision_score(prob_maps, targets):
    """Get average precision, which summarizes the PR curve."""
    assert check_probability_map(prob_maps)
    assert check_binary_map(targets)

    prob_maps = convert_to_ndarray(prob_maps.flatten())
    targets = convert_to_ndarray(targets.flatten())

    average_precision = average_precision_score(y_score=prob_maps,
                                                y_true=targets)

    return average_precision

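# A minimal usage sketch (tensor shapes are illustrative assumptions): report
# the two threshold-free summary metrics, AUROC and average precision, side
# by side for the same predictions.
import torch

prob_maps = torch.rand(4, 1, 32, 32)                 # fake probabilities
targets = (torch.rand(4, 1, 32, 32) > 0.5).float()   # fake binary targets

auroc = get_area_under_roc_cruve(prob_maps, targets)
ap = get_average_precision_score(prob_maps, targets)
print(f'AUROC: {auroc:.4f}, AP: {ap:.4f}')
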
def get_pr_cruve(prob_maps, targets):
    """Get precision, recall, and thresholds for the precision-recall curve."""
    assert check_probability_map(prob_maps)
    assert check_binary_map(targets)

    prob_maps = convert_to_ndarray(prob_maps.flatten())
    targets = convert_to_ndarray(targets.flatten())

    precision, recall, thresholds = precision_recall_curve(
        probas_pred=prob_maps, y_true=targets)

    return precision, recall, thresholds

def get_roc_curve(prob_maps, targets):
    """Get false positive rate, true positive rate, and corresponding
    thresholds for the ROC curve (TPR vs. FPR)."""
    assert check_probability_map(prob_maps)
    assert check_binary_map(targets)

    prob_maps = convert_to_ndarray(prob_maps.flatten())
    targets = convert_to_ndarray(targets.flatten())

    fpr, tpr, thresholds = roc_curve(y_score=prob_maps, y_true=targets)

    return fpr, tpr, thresholds

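# A minimal usage sketch (threshold selection via Youden's J statistic is an
# illustrative choice, not something this module prescribes): pick the cutoff
# that maximizes TPR - FPR on held-out predictions.
import numpy as np
import torch

prob_maps = torch.rand(4, 1, 32, 32)                 # fake probabilities
targets = (torch.rand(4, 1, 32, 32) > 0.5).float()   # fake binary targets

fpr, tpr, thresholds = get_roc_curve(prob_maps, targets)
best_threshold = thresholds[np.argmax(tpr - fpr)]
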
def test_skimage_thresholdings(self):
    """Test thresholding methods from skimage."""
    skimage_thresholding_methods = {
        'otsu': {'nbins': 256, 'return_all': False},
        'isodata': {'nbins': 256, 'return_all': False},
        'li': {'tolerance': None},
        'mean': {},
        'triangle': {'nbins': 256},
        'yen': {'nbins': 256},
        'niblack': {'window_size': 15, 'k': 0.2},
        'sauvola': {'window_size': 15, 'k': 0.2, 'r': None},
        'local': {
            'block_size': 11,
            'method': 'gaussian',
            'offset': 0,
            'mode': 'reflect',
            'param': None,
            'cval': 0
        },
    }

    prob_maps = torch.randn(2, 1, 10, 10)
    for method, kwargs in skimage_thresholding_methods.items():
        binary_maps = batch_thresholding(prob_maps,
                                         thresh_mode=method,
                                         **kwargs)
        self.assertTrue(check_binary_map(binary_maps))

def get_pos_neg_count(dataset: torch.utils.data.Dataset, target_key: str):
    """Get dataset statistics."""
    pos_count = 0
    neg_count = 0
    for sample in dataset:
        target = sample[target_key]
        if isinstance(target, (Image.Image, np.ndarray)):
            target = TF.to_tensor(target)
        assert check_binary_map(target)
        pos_count += target.sum()
        neg_count += (target.numel() - target.sum())

    pos_count = pos_count.numpy()
    neg_count = neg_count.numpy()

    return pos_count, neg_count

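# A minimal usage sketch (`train_dataset` and the 'mask' key are assumptions
# about how samples are stored, not names defined here): use the class counts
# to build a pos_weight for BCEWithLogitsLoss on an imbalanced dataset.
import torch

pos_count, neg_count = get_pos_neg_count(train_dataset, target_key='mask')
pos_weight = torch.tensor(neg_count / max(pos_count, 1.0))
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
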