def do_metric_reduction(
    f: torch.Tensor,
    reduction: Union[MetricReduction, str] = MetricReduction.MEAN,
):
    """
    Reduce per-example, per-class metric scores while ignoring NaN entries.

    NaN entries (e.g. produced when the ground truth for a class is missing)
    are excluded from every mean, and contribute 0 to every sum. The input
    tensor is never modified; NaNs are zeroed on a copy.

    Args:
        f: a tensor that contains the calculated metric scores per batch and per class.
            The first two dims should be batch and class.
        reduction: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}
            Define the mode to reduce computation result of 1 batch data. Defaults to ``"mean"``.

    Returns:
        A tuple ``(f, not_nans)``:

        - ``f``: the reduced scores. With ``"none"`` this is the input with NaNs replaced by 0.
        - ``not_nans``: the count of valid (non-NaN) entries that fed each reduced value.
          With ``"none"`` it is the float validity mask itself (1 = valid, 0 = NaN).
          With ``"mean"`` it is the number of batch items that had at least one valid class.

    Raises:
        ValueError: When ``reduction`` is not one of
            ["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel", "none"].

    """
    # Some elements might be NaN (if ground truth y was missing (zeros)); we need to account for it.
    nans = torch.isnan(f)
    not_nans = (~nans).float()
    # Zero the NaNs out-of-place so the caller's tensor keeps its NaN markers.
    f = f.masked_fill(nans, 0)

    # Fallback value used wherever a mean would otherwise divide by zero valid entries.
    t_zero = torch.zeros(1, device=f.device, dtype=f.dtype)
    # Normalise string inputs to the MetricReduction member they name.
    reduction = MetricReduction(reduction)

    if reduction == MetricReduction.MEAN:
        # 2 steps: first mean by channel (accounting for nans), then by batch.
        not_nans = not_nans.sum(dim=1)
        f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average

        # A batch item only counts towards the batch mean if it had any valid channel.
        not_nans = (not_nans > 0).float().sum(dim=0)
        f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average
    elif reduction == MetricReduction.SUM:
        not_nans = not_nans.sum(dim=[0, 1])
        f = torch.sum(f, dim=[0, 1])  # sum over the batch and channel dims
    elif reduction == MetricReduction.MEAN_BATCH:
        not_nans = not_nans.sum(dim=0)
        f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero)  # batch average
    elif reduction == MetricReduction.SUM_BATCH:
        not_nans = not_nans.sum(dim=0)
        f = f.sum(dim=0)  # the batch sum
    elif reduction == MetricReduction.MEAN_CHANNEL:
        not_nans = not_nans.sum(dim=1)
        f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero)  # channel average
    elif reduction == MetricReduction.SUM_CHANNEL:
        not_nans = not_nans.sum(dim=1)
        f = f.sum(dim=1)  # the channel sum
    elif reduction != MetricReduction.NONE:
        # Defensive: only reachable if MetricReduction has a member not handled above.
        raise ValueError(
            f"Unsupported reduction: {reduction}, available options are "
            '["mean", "sum", "mean_batch", "sum_batch", "mean_channel", "sum_channel", "none"].'
        )
    return f, not_nans
def __init__(
    self,
    include_background: bool = True,
    to_onehot_y: bool = False,
    mutually_exclusive: bool = False,
    sigmoid: bool = False,
    other_act: Optional[Callable] = None,
    logit_thresh: float = 0.5,
    reduction: Union[MetricReduction, str] = MetricReduction.MEAN,
) -> None:
    """
    Record the metric configuration on the instance.

    Every argument is stored as an attribute of the same name. ``reduction``
    is first passed through ``MetricReduction(...)`` so the stored value is
    always a ``MetricReduction``. ``not_nans`` starts as ``None``.

    Raises:
        ValueError: When ``sigmoid`` is truthy and ``other_act`` is also given.

    """
    super().__init__()

    # ``sigmoid`` and a custom ``other_act`` cannot both be requested.
    if other_act is not None and sigmoid:
        raise ValueError(
            "Incompatible values: ``sigmoid=True`` and ``other_act is not None``."
        )

    # Activation-related settings.
    self.sigmoid = sigmoid
    self.other_act = other_act
    self.logit_thresh = logit_thresh

    # Label / channel handling settings.
    self.include_background = include_background
    self.to_onehot_y = to_onehot_y
    self.mutually_exclusive = mutually_exclusive

    # Reduction mode, normalised from a possible string.
    self.reduction: MetricReduction = MetricReduction(reduction)

    # Keeps track of the valid (non-NaN) elements in the batch; nothing recorded yet.
    self.not_nans: Optional[torch.Tensor] = None
def __init__(
    self,
    include_background: bool = True,
    to_onehot_y: bool = False,
    mutually_exclusive: bool = False,
    sigmoid: bool = False,
    logit_thresh: float = 0.5,
    reduction: Union[MetricReduction, str] = MetricReduction.MEAN,
):
    """
    Record the metric configuration on the instance.

    Every argument is stored as an attribute of the same name. ``reduction``
    is first passed through ``MetricReduction(...)`` so the stored value is
    always a ``MetricReduction``. ``not_nans`` starts as ``None``.

    """
    super().__init__()

    # Activation-related settings.
    self.sigmoid = sigmoid
    self.logit_thresh = logit_thresh

    # Label / channel handling settings.
    self.include_background = include_background
    self.to_onehot_y = to_onehot_y
    self.mutually_exclusive = mutually_exclusive

    # Reduction mode, normalised from a possible string.
    self.reduction: MetricReduction = MetricReduction(reduction)

    # Keeps track of the valid elements in the batch; nothing recorded yet.
    self.not_nans = None