def get_area_under_roc_cruve(prob_maps, targets):
    """Compute the area under the ROC curve over all elements.

    Both inputs are validated, flattened and passed through
    ``convert_to_ndarray`` before a single score is computed with
    ``roc_auc_score``.

    Args:
        prob_maps: Predicted scores; must pass ``check_probability_map``.
        targets: Ground-truth labels; must pass ``check_binary_map``.

    Returns:
        The value returned by ``roc_auc_score``.

    Raises:
        AssertionError: If either validation check fails.
    """
    assert check_probability_map(prob_maps)
    assert check_binary_map(targets)

    flat_scores = convert_to_ndarray(prob_maps.flatten())
    flat_labels = convert_to_ndarray(targets.flatten())
    return roc_auc_score(y_true=flat_labels, y_score=flat_scores)
def get_average_precision_score(prob_maps, targets):
    """Compute the average precision, a summary of the PR curve.

    Both inputs are validated, flattened and passed through
    ``convert_to_ndarray`` before a single score is computed with
    ``average_precision_score``.

    Args:
        prob_maps: Predicted scores; must pass ``check_probability_map``.
        targets: Ground-truth labels; must pass ``check_binary_map``.

    Returns:
        The value returned by ``average_precision_score``.

    Raises:
        AssertionError: If either validation check fails.
    """
    assert check_probability_map(prob_maps)
    assert check_binary_map(targets)

    flat_scores = convert_to_ndarray(prob_maps.flatten())
    flat_labels = convert_to_ndarray(targets.flatten())
    return average_precision_score(y_true=flat_labels, y_score=flat_scores)
def get_pr_cruve(prob_maps, targets):
    """Compute the points of the precision-recall curve.

    Both inputs are validated, flattened and passed through
    ``convert_to_ndarray`` before being handed to
    ``precision_recall_curve``.

    Args:
        prob_maps: Predicted scores; must pass ``check_probability_map``.
        targets: Ground-truth labels; must pass ``check_binary_map``.

    Returns:
        The ``(precision, recall, thresholds)`` triple produced by
        ``precision_recall_curve``.

    Raises:
        AssertionError: If either validation check fails.
    """
    assert check_probability_map(prob_maps)
    assert check_binary_map(targets)

    flat_scores = convert_to_ndarray(prob_maps.flatten())
    flat_labels = convert_to_ndarray(targets.flatten())

    # The score argument was renamed from ``probas_pred`` to ``y_score`` in
    # scikit-learn (old keyword deprecated, then removed). Passing
    # (y_true, scores) positionally works on both sides of the rename.
    precision, recall, thresholds = precision_recall_curve(
        flat_labels, flat_scores)
    return precision, recall, thresholds
def get_roc_curve(prob_maps, targets):
    """Compute the points of the ROC curve (TPR against FPR).

    Both inputs are validated, flattened and passed through
    ``convert_to_ndarray`` before being handed to ``roc_curve``.

    Args:
        prob_maps: Predicted scores; must pass ``check_probability_map``.
        targets: Ground-truth labels; must pass ``check_binary_map``.

    Returns:
        The ``(false_positive_rate, true_positive_rate, thresholds)``
        triple produced by ``roc_curve``.

    Raises:
        AssertionError: If either validation check fails.
    """
    assert check_probability_map(prob_maps)
    assert check_binary_map(targets)

    flat_scores = convert_to_ndarray(prob_maps.flatten())
    flat_labels = convert_to_ndarray(targets.flatten())

    curve = roc_curve(y_true=flat_labels, y_score=flat_scores)
    false_pos_rate, true_pos_rate, cutoffs = curve
    return false_pos_rate, true_pos_rate, cutoffs
def get_binary_confusion_matrix(binary_maps, targets, masks=None,
                                reduction='sum'):
    """Compute the binary confusion matrix entries (TP, FP, TN, FN).

    Args:
        binary_maps: Binarized predictions; must pass ``check_binary_map``
            after conversion with ``convert_to_ndarray``.
        targets: Ground-truth labels with the same shape as
            ``binary_maps``; must pass ``check_binary_map``.
        masks: Optional binary map with the same shape. Every entry is
            multiplied by it, so positions where the mask is 0 contribute
            zero to all four counts.
        reduction: ``'none'`` returns the element-wise indicator arrays,
            ``'sum'`` returns their sums as floats, ``'mean'`` returns
            their means as floats. The mean is taken over all elements,
            including positions zeroed by ``masks``.

    Returns:
        Tuple ``(true_pos, false_pos, true_neg, false_neg)``.

    Raises:
        AssertionError: If shapes differ or a binary-map check fails.
        NotImplementedError: If ``reduction`` is not a supported mode.
    """
    binary_maps = convert_to_ndarray(binary_maps)
    targets = convert_to_ndarray(targets)
    assert binary_maps.shape == targets.shape
    assert check_binary_map(binary_maps)
    assert check_binary_map(targets)

    if masks is not None:
        masks = convert_to_ndarray(masks)
        assert binary_maps.shape == masks.shape
        assert check_binary_map(masks)

    # Complement written as ``1 - x`` instead of ``-1 * (x - 1)``: the
    # values are the same for signed, float and bool arrays, but the
    # intermediate ``x - 1`` wraps around for unsigned integer dtypes
    # (0 - 1 underflows), which would corrupt FP/TN/FN.
    targets_neg = 1 - targets
    inputs_neg = 1 - binary_maps

    entries = (
        targets * binary_maps,      # true positives
        targets_neg * binary_maps,  # false positives
        targets_neg * inputs_neg,   # true negatives
        targets * inputs_neg,       # false negatives
    )
    if masks is not None:
        entries = tuple(entry * masks for entry in entries)

    if reduction == 'none':
        return entries
    if reduction == 'sum':
        return tuple(float(np.sum(entry)) for entry in entries)
    if reduction == 'mean':
        return tuple(float(np.mean(entry)) for entry in entries)

    LOGGER.error('Invalid reduction mode: %s', reduction)
    raise NotImplementedError(
        'Invalid reduction mode: {}'.format(reduction))
def get_dice_coefficient(prob_maps, binary_maps, targets, masks=None,
                         reduction='mean', epsilon=1e-8):
    """Compute the Dice coefficient per sample between predictions and targets.

    For each sample (first dimension) the score is
    ``(2 * |pred * target| + epsilon) / (|pred| + |target| + epsilon)``,
    where ``|x|`` is the sum over all remaining, flattened dimensions.

    Args:
        prob_maps: Predicted probabilities. Only validated (shape and
            ``check_probability_map``) and used for the batch size; the
            score itself is computed from ``binary_maps``.
        binary_maps: Binarized predictions with the same shape as
            ``targets``; must pass ``check_binary_map``.
        targets: Ground-truth labels; must pass ``check_binary_map``.
        masks: Optional binary map. After flattening to
            ``(batch_size, -1)`` it must match the flattened predictions;
            predictions and targets are multiplied by it.
        reduction: ``'none'`` returns the per-sample scores, ``'sum'``
            their sum, ``'mean'`` their mean.
        epsilon: Added to numerator and denominator so that a sample with
            empty prediction and empty target does not divide by zero.

    Returns:
        Per-sample scores, or their sum/mean, depending on ``reduction``.

    Raises:
        AssertionError: If shapes differ or a validation check fails.
        NotImplementedError: If ``reduction`` is not a supported mode.
    """
    assert prob_maps.shape == targets.shape
    assert binary_maps.shape == targets.shape
    assert check_probability_map(prob_maps)
    assert check_binary_map(binary_maps)
    assert check_binary_map(targets)

    batch_size = prob_maps.shape[0]
    # ``reshape`` rather than ``view`` so non-contiguous inputs are
    # accepted as well; for contiguous inputs the result is the same.
    binary_maps = convert_to_ndarray(binary_maps.reshape(batch_size, -1))
    targets = convert_to_ndarray(targets.reshape(batch_size, -1))

    if masks is not None:
        masks = masks.reshape(batch_size, -1)
        assert binary_maps.shape == masks.shape
        assert check_binary_map(masks)
        masks = convert_to_ndarray(masks)
        binary_maps = binary_maps * masks
        targets = targets * masks

    intersection = (binary_maps * targets).sum(-1)
    binary_maps_norm = binary_maps.sum(-1)
    # The target norm must come from ``targets``; summing ``binary_maps``
    # here would make the denominator 2 * |pred| and skew the score.
    targets_norm = targets.sum(-1)

    dice_score = (2.0 * intersection + epsilon) / (
        binary_maps_norm + targets_norm + epsilon)

    if reduction == 'none':
        pass
    elif reduction == 'sum':
        dice_score = dice_score.sum()
    elif reduction == 'mean':
        dice_score = dice_score.mean()
    else:
        LOGGER.error('Invalid reduction mode: %s', reduction)
        raise NotImplementedError('Invalid reduction mode: {}'.format(
            reduction))

    return dice_score