def __init__(
    self,
    threshold: float = 0.5,
    top_k: Optional[int] = None,
    reduce: str = "micro",
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    mdmc_reduce: Optional[str] = None,
    multiclass: Optional[bool] = None,
    compute_on_step: bool = True,
    dist_sync_on_step: bool = False,
    process_group: Optional[Any] = None,
    dist_sync_fn: Optional[Callable] = None,
    is_multiclass: Optional[bool] = None,  # todo: deprecated, remove in v0.4
):
    """Store the configuration, validate it and register the ``tp``/``fp``/``tn``/``fn`` states.

    Args:
        threshold: Cut-off used to binarize probability/logit predictions. Must lie in the
            open interval ``(0, 1)``.
        top_k: Number of highest-scoring predictions to consider (stored as given).
        reduce: One of ``"micro"``, ``"macro"`` or ``"samples"``.
        num_classes: Number of classes. Required (and must be ``>= 1``) when
            ``reduce="macro"``.
        ignore_index: Class index to ignore. When ``num_classes`` is set, it must lie in
            ``[0, num_classes)`` and ``num_classes`` must not be 1.
        mdmc_reduce: One of ``None``, ``"samplewise"`` or ``"global"``.
        multiclass: Lets the caller treat inputs as a different type than the one inferred.
        compute_on_step, dist_sync_on_step, process_group, dist_sync_fn: Forwarded unchanged
            to the parent ``__init__``.
        is_multiclass: Deprecated alias of ``multiclass``, resolved through
            ``_deprecation_warn_arg_is_multiclass`` (remove in v0.4).

    Raises:
        ValueError: If any of the argument validations described above fails.
    """
    multiclass = _deprecation_warn_arg_is_multiclass(is_multiclass, multiclass)

    super().__init__(
        compute_on_step=compute_on_step,
        dist_sync_on_step=dist_sync_on_step,
        process_group=process_group,
        dist_sync_fn=dist_sync_fn,
    )

    self.reduce = reduce
    self.mdmc_reduce = mdmc_reduce
    self.num_classes = num_classes
    self.threshold = threshold
    self.multiclass = multiclass
    self.ignore_index = ignore_index
    self.top_k = top_k

    if not 0 < threshold < 1:
        raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")

    if reduce not in ["micro", "macro", "samples"]:
        raise ValueError(f"The `reduce` {reduce} is not valid.")

    if mdmc_reduce not in [None, "samplewise", "global"]:
        raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.")

    if reduce == "macro" and (not num_classes or num_classes < 1):
        raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.")

    if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1):
        raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")

    if mdmc_reduce != "samplewise" and reduce != "samples":
        # ``reduce`` was validated above and "samples" is excluded here, so it is either
        # "micro" (scalar counters) or "macro" (one counter per class).
        zeros_shape = [] if reduce == "micro" else (num_classes, )
        default = lambda: torch.zeros(zeros_shape, dtype=torch.long)  # noqa: E731
        reduce_fn = "sum"
    else:
        # Per-sample statistics: each state starts as an empty list and is registered without
        # a reduction function (presumably gathered/concatenated on sync; confirm in base class).
        default = lambda: []  # noqa: E731
        reduce_fn = None

    for s in ("tp", "fp", "tn", "fn"):
        # ``default()`` is called per state so that no two states share the same tensor/list object.
        self.add_state(s, default=default(), dist_reduce_fx=reduce_fn)
def __init__(
    self,
    num_classes: Optional[int] = None,
    threshold: float = 0.5,
    average: Optional[str] = "micro",
    mdmc_average: Optional[str] = None,
    ignore_index: Optional[int] = None,
    top_k: Optional[int] = None,
    multiclass: Optional[bool] = None,
    compute_on_step: bool = True,
    dist_sync_on_step: bool = False,
    process_group: Optional[Any] = None,
    dist_sync_fn: Optional[Callable] = None,
    multilabel: Optional[bool] = None,  # todo: deprecated, remove in v0.4
    is_multiclass: Optional[bool] = None,  # todo: deprecated, remove in v0.4
):
    """Validate ``average`` and configure the parent statistics accumulator for precision.

    Args:
        num_classes: Number of classes. Required (and must be ``>= 1``) when ``average`` is
            ``"macro"``, ``"weighted"``, ``"none"`` or ``None``.
        threshold: Binarization threshold, forwarded to the parent.
        average: One of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"`` or
            ``None``. ``"weighted"``, ``"none"`` and ``None`` are accumulated with the parent's
            ``reduce="macro"``. The other values are passed through as ``reduce``.
        mdmc_average: Forwarded to the parent as ``mdmc_reduce``.
        ignore_index, top_k, multiclass: Forwarded to the parent unchanged.
        compute_on_step, dist_sync_on_step, process_group, dist_sync_fn: Forwarded to the
            parent unchanged.
        multilabel: Deprecated; only passed to ``_deprecation_warn_arg_multilabel``.
        is_multiclass: Deprecated alias of ``multiclass``, resolved through
            ``_deprecation_warn_arg_is_multiclass``.

    Raises:
        ValueError: If ``average`` is not an allowed value, if a per-class ``average`` is used
            without a valid ``num_classes``, or if the parent's validation fails.
    """
    _deprecation_warn_arg_multilabel(multilabel)
    multiclass = _deprecation_warn_arg_is_multiclass(is_multiclass, multiclass)

    allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
    if average not in allowed_average:
        raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

    # These averages are computed from per-class ("macro") statistics and therefore need the
    # number of classes. The parent would also reject this, but its message talks about
    # `reduce`, which the user never passed. Report the option actually given instead,
    # matching the functional `precision`.
    if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1):
        raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.")

    super().__init__(
        reduce="macro" if average in ["weighted", "none", None] else average,
        mdmc_reduce=mdmc_average,
        threshold=threshold,
        top_k=top_k,
        num_classes=num_classes,
        multiclass=multiclass,
        ignore_index=ignore_index,
        compute_on_step=compute_on_step,
        dist_sync_on_step=dist_sync_on_step,
        process_group=process_group,
        dist_sync_fn=dist_sync_fn,
    )

    # Keep the user-facing averaging mode; the parent only stores the derived ``reduce``.
    self.average = average
def precision(
    preds: Tensor,
    target: Tensor,
    average: Optional[str] = "micro",
    mdmc_average: Optional[str] = None,
    ignore_index: Optional[int] = None,
    num_classes: Optional[int] = None,
    threshold: float = 0.5,
    top_k: Optional[int] = None,
    multiclass: Optional[bool] = None,
    multilabel: Optional[bool] = None,  # todo: deprecated, remove in v0.4
    is_multiclass: Optional[bool] = None,  # todo: deprecated, remove in v0.4
) -> Tensor:
    r"""
    Computes `Precision <https://en.wikipedia.org/wiki/Precision_and_recall>`_:

    .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}

    Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and
    false positives respectively. With the use of the ``top_k`` parameter, this metric can
    generalize to Precision@K.

    The reduction method (how the precision scores are aggregated) is controlled by the
    ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
    multi-dimensional multi-class case.

    Accepts all inputs listed in :ref:`references/modules:input types`.

    Args:
        preds: Predictions from model (probabilities, logits or labels)
        target: Ground truth values
        average:
            Defines the reduction that is applied. Should be one of the following:

            - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes.
            - ``'macro'``: Calculate the metric for each class separately, and average the
              metrics across classes (with equal weights for each class).
            - ``'weighted'``: Calculate the metric for each class separately, and average the
              metrics across classes, weighting each class by its support (``tp + fn``).
            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
              the metric for every class.
            - ``'samples'``: Calculate the metric for each sample, and average the metrics
              across samples (with equal weights for each sample).

            .. note:: What is considered a sample in the multi-dimensional multi-class case
                depends on the value of ``mdmc_average``.

        mdmc_average:
            Defines how averaging is done for multi-dimensional multi-class inputs (on top of the
            ``average`` parameter). Should be one of the following:

            - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
              multi-class.

            - ``'samplewise'``: In this case, the statistics are computed separately for each
              sample on the ``N`` axis, and then averaged over samples.
              The computation for each sample is done by treating the flattened extra axes ``...``
              (see :ref:`references/modules:input types`) as the ``N`` dimension within the sample,
              and computing the metric for the sample based on that.

            - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
              (see :ref:`references/modules:input types`)
              are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
              were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.

        ignore_index:
            Integer specifying a target class to ignore. If given, this class index does not
            contribute to the returned score, regardless of reduction method. If an index is
            ignored, and ``average=None`` or ``'none'``, the score for the ignored class will be
            returned as ``nan``.
        num_classes:
            Number of classes. Necessary for ``'macro'``, ``'weighted'``, ``'none'`` and ``None``
            average methods.
        threshold:
            Threshold for transforming probability or logit predictions to binary (0,1) predictions,
            in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input
            being probabilities.
        top_k:
            Number of highest probability or logit score predictions considered to find the correct
            label, relevant only for (multi-dimensional) multi-class inputs. The default value
            (``None``) will be interpreted as 1 for these inputs.

            Should be left at default (``None``) for all other types of inputs.
        multiclass:
            Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be. See the parameter's
            :ref:`documentation section <references/modules:using the multiclass parameter>`
            for a more detailed explanation and examples.
        multilabel:
            .. deprecated:: 0.3
                Argument will not have any effect and will be removed in v0.4, please use
                ``multiclass`` instead.
        is_multiclass:
            .. deprecated:: 0.3
                Argument will be removed in v0.4, please use ``multiclass`` instead.

    Return:
        The shape of the returned tensor depends on the ``average`` parameter

        - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned
        - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
          of classes

    Raises:
        ValueError:
            If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``,
            ``"samples"``, ``"none"`` or ``None``.
        ValueError:
            If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``.
        ValueError:
            If ``average`` is one of ``"macro"``, ``"weighted"``, ``"none"`` or ``None`` and
            ``num_classes`` is not provided (or is smaller than 1).
        ValueError:
            If ``num_classes`` and ``ignore_index`` are both set and ``ignore_index`` is not in the
            range ``[0, num_classes)``, or ``num_classes`` is 1.

    Example:
        >>> from torchmetrics.functional import precision
        >>> preds  = torch.tensor([2, 0, 2, 1])
        >>> target = torch.tensor([1, 1, 2, 0])
        >>> precision(preds, target, average='macro', num_classes=3)
        tensor(0.1667)
        >>> precision(preds, target, average='micro')
        tensor(0.2500)

    """
    _deprecation_warn_arg_multilabel(multilabel)
    multiclass = _deprecation_warn_arg_is_multiclass(is_multiclass, multiclass)

    allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
    if average not in allowed_average:
        raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

    allowed_mdmc_average = [None, "samplewise", "global"]
    if mdmc_average not in allowed_mdmc_average:
        raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.")

    if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1):
        raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.")

    if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1):
        raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")

    # "weighted" and "none"/None are built from per-class statistics, so they are gathered with
    # reduce="macro"; the original `average` is handed to `_precision_compute` below.
    reduce = "macro" if average in ["weighted", "none", None] else average
    tp, fp, tn, fn = _stat_scores_update(
        preds,
        target,
        reduce=reduce,
        mdmc_reduce=mdmc_average,
        threshold=threshold,
        num_classes=num_classes,
        top_k=top_k,
        multiclass=multiclass,
        ignore_index=ignore_index,
    )

    return _precision_compute(tp, fp, tn, fn, average, mdmc_average)
def stat_scores(
    preds: Tensor,
    target: Tensor,
    reduce: str = "micro",
    mdmc_reduce: Optional[str] = None,
    num_classes: Optional[int] = None,
    top_k: Optional[int] = None,
    threshold: float = 0.5,
    multiclass: Optional[bool] = None,
    ignore_index: Optional[int] = None,
    is_multiclass: Optional[bool] = None,  # todo: deprecated, remove in v0.4
) -> Tensor:
    """Computes the number of true positives, false positives, true negatives, false negatives.
    Related to `Type I and Type II errors <https://en.wikipedia.org/wiki/Type_I_and_type_II_errors>`__
    and the `confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix#Table_of_confusion>`__.

    The reduction method (how the statistics are aggregated) is controlled by the
    ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the
    multi-dimensional multi-class case.

    Accepts all inputs listed in :ref:`references/modules:input types`.

    Args:
        preds: Predictions from model (probabilities, logits or labels)
        target: Ground truth values
        threshold:
            Threshold for transforming probability or logit predictions to binary (0,1) predictions,
            in the case of binary or multi-label inputs. Default value of 0.5 corresponds to input
            being probabilities.
        top_k:
            Number of highest probability or logit score predictions considered to find the correct
            label, relevant only for (multi-dimensional) multi-class inputs. The default value
            (``None``) will be interpreted as 1 for these inputs.

            Should be left at default (``None``) for all other types of inputs.
        reduce:
            Defines the reduction that is applied. Should be one of the following:

            - ``'micro'`` [default]: Counts the statistics by summing over all [sample, class]
              combinations (globally). Each statistic is represented by a single integer.
            - ``'macro'``: Counts the statistics for each class separately (over all samples).
              Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes``
              to be set.
            - ``'samples'``: Counts the statistics for each sample separately (over all classes).
              Each statistic is represented by a ``(N, )`` 1d tensor.

            .. note:: What is considered a sample in the multi-dimensional multi-class case
                depends on the value of ``mdmc_reduce``.

        num_classes:
            Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data.
        ignore_index:
            Specify a class (label) to ignore. If given, this class index does not contribute
            to the returned score, regardless of reduction method. If an index is ignored, and
            ``reduce='macro'``, the class statistics for the ignored class will all be returned
            as ``-1``.
        mdmc_reduce:
            Defines how the multi-dimensional multi-class inputs are handled. Should be
            one of the following:

            - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional
              multi-class (see :ref:`references/modules:input types` for the definition of input types).

            - ``'samplewise'``: In this case, the statistics are computed separately for each
              sample on the ``N`` axis, and then the outputs are concatenated together. In each
              sample the extra axes ``...`` are flattened to become the sub-sample axis, and
              statistics for each sample are computed by treating the sub-sample axis as the
              ``N`` axis for that sample.

            - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs are
              flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
              were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual.
        multiclass:
            Used only in certain special cases, where you want to treat inputs as a different type
            than what they appear to be. See the parameter's
            :ref:`documentation section <references/modules:using the multiclass parameter>`
            for a more detailed explanation and examples.
        is_multiclass:
            .. deprecated:: 0.3
                Argument will be removed in v0.4, please use ``multiclass`` instead.

    Return:
        The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds
        to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The
        shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional
        multi-class data) parameters:

        - If the data is not multi-dimensional multi-class, then

          - If ``reduce='micro'``, the shape will be ``(5, )``
          - If ``reduce='macro'``, the shape will be ``(C, 5)``,
            where ``C`` stands for the number of classes
          - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for
            the number of samples

        - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then

          - If ``reduce='micro'``, the shape will be ``(5, )``
          - If ``reduce='macro'``, the shape will be ``(C, 5)``
          - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for
            the product of sizes of all "extra" dimensions of the data (i.e. all dimensions
            except for ``C`` and ``N``)

        - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then

          - If ``reduce='micro'``, the shape will be ``(N, 5)``
          - If ``reduce='macro'``, the shape will be ``(N, C, 5)``
          - If ``reduce='samples'``, the shape will be ``(N, X, 5)``

    Raises:
        ValueError:
            If ``reduce`` is none of ``"micro"``, ``"macro"`` or ``"samples"``.
        ValueError:
            If ``mdmc_reduce`` is none of ``None``, ``"samplewise"``, ``"global"``.
        ValueError:
            If ``reduce`` is set to ``"macro"`` and ``num_classes`` is not provided
            (or is smaller than 1).
        ValueError:
            If ``num_classes`` and ``ignore_index`` are both set and ``ignore_index`` is not in the
            range ``[0, num_classes)``, or ``num_classes`` is 1.
        ValueError:
            If ``ignore_index`` is used with ``binary data``.
        ValueError:
            If inputs are ``multi-dimensional multi-class`` and ``mdmc_reduce`` is not provided.

    Example:
        >>> from torchmetrics.functional import stat_scores
        >>> preds  = torch.tensor([1, 0, 2, 1])
        >>> target = torch.tensor([1, 1, 2, 0])
        >>> stat_scores(preds, target, reduce='macro', num_classes=3)
        tensor([[0, 1, 2, 1, 1],
                [1, 1, 1, 1, 2],
                [1, 0, 3, 0, 1]])
        >>> stat_scores(preds, target, reduce='micro')
        tensor([2, 2, 6, 2, 4])

    """
    multiclass = _deprecation_warn_arg_is_multiclass(is_multiclass, multiclass)

    if reduce not in ["micro", "macro", "samples"]:
        raise ValueError(f"The `reduce` {reduce} is not valid.")

    if mdmc_reduce not in [None, "samplewise", "global"]:
        raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.")

    if reduce == "macro" and (not num_classes or num_classes < 1):
        raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.")

    # A single-class setup leaves nothing to compare against once a class is ignored, so
    # num_classes == 1 is rejected together with out-of-range indices.
    if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1):
        raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")

    tp, fp, tn, fn = _stat_scores_update(
        preds,
        target,
        reduce=reduce,
        mdmc_reduce=mdmc_reduce,
        top_k=top_k,
        threshold=threshold,
        num_classes=num_classes,
        multiclass=multiclass,
        ignore_index=ignore_index,
    )
    return _stat_scores_compute(tp, fp, tn, fn)