Beispiel #1
0
def _sk_spec(preds,
             target,
             reduce,
             num_classes,
             multiclass,
             ignore_index,
             top_k=None,
             mdmc_reduce=None,
             stats=None):
    """Compute reference specificity, ``tn / (tn + fp)``.

    Args:
        preds: Predictions. Here only ``preds.shape[1]`` is read. ``preds`` is
            also forwarded to ``_sk_stats_score`` when ``stats`` is not given.
        target: Ground truth, forwarded to ``_sk_stats_score``.
        reduce: Averaging method, passed to ``_reduce_stat_scores`` as
            ``average``. ``"weighted"`` weights by ``tn + fp``.
        num_classes: Forwarded to ``_sk_stats_score``.
        multiclass: Forwarded to ``_sk_stats_score``.
        ignore_index: Forwarded to ``_sk_stats_score``. For ``None``/``"none"``
            reduction with ``preds.shape[1] > 1``, a NaN is inserted at this
            position of the result.
        top_k: Forwarded to ``_sk_stats_score``.
        mdmc_reduce: Passed to ``_reduce_stat_scores`` as ``mdmc_average``.
        stats: Optional precomputed ``(fp, tn)`` pair. When falsy, the pair is
            computed with ``_sk_stats_score``.

    Returns:
        The reduced specificity as a tensor.
    """
    if not stats:
        stats = _sk_stats_score(preds, target, reduce, num_classes, multiclass,
                                ignore_index, top_k)
    false_pos, true_neg = stats
    false_pos = tensor(false_pos)
    true_neg = tensor(true_neg)

    spec = _reduce_stat_scores(
        numerator=true_neg,
        denominator=true_neg + false_pos,
        weights=(true_neg + false_pos) if reduce == "weighted" else None,
        average=reduce,
        mdmc_average=mdmc_reduce,
    )

    per_class_output = reduce in (None, "none")
    if per_class_output and ignore_index is not None and preds.shape[1] > 1:
        # Presumably the ignored class is missing from the per-class result.
        # A NaN placeholder restores positional alignment (TODO confirm
        # against _sk_stats_score). np.insert without ``axis`` flattens.
        spec = tensor(np.insert(spec.numpy(), ignore_index, math.nan))
    return spec
Beispiel #2
0
def _specificity_compute(
    fp: Tensor,
    tn: Tensor,
    average: str,
    mdmc_average: Optional[str],
) -> Tensor:
    """Reduce statistics into specificity, ``tn / (tn + fp)``.

    Args:
        fp: False positives.
        tn: True negatives.
        average: Averaging method passed through to ``_reduce_stat_scores``.
            For ``"weighted"``, entries are weighted by ``tn + fp``.
        mdmc_average: Multi-dimensional averaging method, passed through
            unchanged.

    Returns:
        The reduced specificity as returned by ``_reduce_stat_scores``.
    """
    actual_negatives = tn + fp
    if average == "weighted":
        # Weight each entry by its number of actually-negative samples.
        class_weights = tn + fp
    else:
        class_weights = None
    return _reduce_stat_scores(
        numerator=tn,
        denominator=actual_negatives,
        weights=class_weights,
        average=average,
        mdmc_average=mdmc_average,
    )
Beispiel #3
0
def _precision_compute(
    tp: torch.Tensor,
    fp: torch.Tensor,
    tn: torch.Tensor,
    fn: torch.Tensor,
    average: str,
    mdmc_average: Optional[str],
) -> torch.Tensor:
    """Reduce statistics into precision, ``tp / (tp + fp)``.

    Args:
        tp: True positives.
        fp: False positives.
        tn: True negatives. Not used by this computation.
        fn: False negatives. Used only to build the ``"weighted"`` weights.
        average: Averaging method passed through to ``_reduce_stat_scores``.
            For ``"weighted"``, entries are weighted by ``tp + fn``.
        mdmc_average: Multi-dimensional averaging method, passed through
            unchanged.

    Returns:
        The reduced precision as returned by ``_reduce_stat_scores``.
    """
    predicted_positives = tp + fp
    if average == "weighted":
        # tp + fn counts the samples whose true label is positive (support).
        class_weights = tp + fn
    else:
        class_weights = None
    return _reduce_stat_scores(
        numerator=tp,
        denominator=predicted_positives,
        weights=class_weights,
        average=average,
        mdmc_average=mdmc_average,
    )
Beispiel #4
0
def _recall_compute(
    tp: Tensor,
    fp: Tensor,
    tn: Tensor,
    fn: Tensor,
    average: str,
    mdmc_average: Optional[str],
) -> Tensor:
    """Reduce statistics into recall, ``tp / (tp + fn)``.

    Args:
        tp: True positives, used as the numerator.
        fp: False positives. Not used by this computation.
        tn: True negatives. Not used by this computation.
        fn: False negatives.
        average: Averaging method passed through to ``_reduce_stat_scores``.
            For ``"weighted"``, entries are weighted by ``tp + fn``.
        mdmc_average: Multi-dimensional averaging method, passed through
            unchanged.

    Returns:
        The reduced recall as returned by ``_reduce_stat_scores``.
    """
    # todo: `fp` is unused
    # todo: `tn` is unused
    # (Both are presumably kept so the signature matches _precision_compute.)
    return _reduce_stat_scores(
        numerator=tp,
        denominator=tp + fn,
        weights=None if average != "weighted" else tp + fn,
        average=average,
        mdmc_average=mdmc_average,
    )
Beispiel #5
0
def _accuracy_compute(
    tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str, mdmc_average: str, mode: DataType
) -> Tensor:
    """Reduce statistics into accuracy.

    Two formulas are used:

    * Binary data with ``"micro"``/``"samples"`` averaging, or multilabel data:
      ``(tp + tn) / (tp + tn + fp + fn)``.
    * Every other case: ``tp / (tp + fn)``.

    Args:
        tp: True positives.
        fp: False positives.
        tn: True negatives.
        fn: False negatives.
        average: Averaging method passed through to ``_reduce_stat_scores``.
            For ``"weighted"``, entries are weighted by ``tp + fn``.
        mdmc_average: Multi-dimensional averaging method, passed through
            unchanged.
        mode: Input data type. Selects which formula is used.

    Returns:
        The reduced accuracy as returned by ``_reduce_stat_scores``.
    """
    binary_pooled = mode == DataType.BINARY and average in ("micro", "samples")
    if binary_pooled or mode == DataType.MULTILABEL:
        # Every entry is scored, so correct negatives count as well.
        numerator = tp + tn
        denominator = tp + tn + fp + fn
    else:
        # Score only the samples whose true label is positive.
        numerator = tp
        denominator = tp + fn

    weights = tp + fn if average == "weighted" else None
    return _reduce_stat_scores(
        numerator=numerator,
        denominator=denominator,
        weights=weights,
        average=average,
        mdmc_average=mdmc_average,
    )
Beispiel #6
0
def _fbeta_compute(
    tp: Tensor,
    fp: Tensor,
    tn: Tensor,
    fn: Tensor,
    beta: float,
    ignore_index: Optional[int],
    average: str,
    mdmc_average: Optional[str],
) -> Tensor:
    """Compute the F-beta score from statistics.

    The score is
    ``(1 + beta**2) * precision * recall / (beta**2 * precision + recall)``.
    This function builds the numerator and denominator. The division and
    averaging are delegated to ``_reduce_stat_scores``.

    Args:
        tp: True positives.
        fp: False positives.
        tn: True negatives. Not used by this computation.
        fn: False negatives.
        beta: Weighting factor. Larger values move the score toward recall.
        ignore_index: If given, and ``average`` is neither micro nor samples,
            the numerator and denominator at this index are set to ``-1``.
        average: Averaging method. ``"micro"`` without samplewise mdmc
            averaging pools the counts before computing precision and recall.
        mdmc_average: Multi-dimensional averaging method. It also decides
            which dimension ``ignore_index`` addresses.

    Returns:
        The reduced F-beta score as returned by ``_reduce_stat_scores``.
    """
    if average == "micro" and mdmc_average != MDMCAverageMethod.SAMPLEWISE:
        # Pool counts over entries with tp >= 0. Negative tp values are
        # presumably ignore markers set upstream (TODO confirm).
        keep = tp >= 0
        kept_tp, kept_fp, kept_fn = tp[keep], fp[keep], fn[keep]
        precision = _safe_divide(kept_tp.sum().float(), (kept_tp + kept_fp).sum())
        recall = _safe_divide(kept_tp.sum().float(), (kept_tp + kept_fn).sum())
    else:
        precision = _safe_divide(tp.float(), tp + fp)
        recall = _safe_divide(tp.float(), tp + fn)

    beta_sq = beta**2
    num = (1 + beta_sq) * precision * recall
    denom = beta_sq * precision + recall
    denom = denom.masked_fill(denom == 0.0, 1)  # avoid division by 0

    if ignore_index is not None and average not in (
        AverageMethod.MICRO.value,
        AverageMethod.SAMPLES.value,
    ):
        # -1 presumably tells _reduce_stat_scores to skip the entry
        # (TODO confirm).
        if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
            # Samplewise: mark along the last dimension.
            num[..., ignore_index] = -1
            denom[..., ignore_index] = -1
        else:
            # Otherwise: mark along the first dimension.
            num[ignore_index, ...] = -1
            denom[ignore_index, ...] = -1

    return _reduce_stat_scores(
        numerator=num,
        denominator=denom,
        weights=None if average != "weighted" else tp + fn,
        average=average,
        mdmc_average=mdmc_average,
    )