def _input_format_classification_one_hot( num_classes: int, preds: Tensor, target: Tensor, threshold: float = 0.5, multilabel: bool = False, ) -> Tuple[Tensor, Tensor]: """Convert preds and target tensors into one hot spare label tensors Args: num_classes: number of classes preds: either tensor with labels, tensor with probabilities/logits or multilabel tensor target: tensor with ground true labels threshold: float used for thresholding multilabel input multilabel: boolean flag indicating if input is multilabel Raises: ValueError: If ``preds`` and ``target`` don't have the same number of dimensions or one additional dimension for ``preds``. Returns: preds: one hot tensor of shape [num_classes, -1] with predicted labels target: one hot tensors of shape [num_classes, -1] with true labels """ if preds.ndim not in (target.ndim, target.ndim + 1): raise ValueError( "preds and target must have same number of dimensions, or one additional dimension for preds" ) if preds.ndim == target.ndim + 1: # multi class probabilities preds = torch.argmax(preds, dim=1) if preds.ndim == target.ndim and preds.dtype in ( torch.long, torch.int) and num_classes > 1 and not multilabel: # multi-class preds = to_onehot(preds, num_classes=num_classes) target = to_onehot(target, num_classes=num_classes) elif preds.ndim == target.ndim and preds.is_floating_point(): # binary or multilabel probabilities preds = (preds >= threshold).long() # transpose class as first dim and reshape if preds.ndim > 1: preds = preds.transpose(1, 0) target = target.transpose(1, 0) return preds.reshape(num_classes, -1), target.reshape(num_classes, -1)
def update(self, preds: torch.Tensor, targets: torch.Tensor):  # type: ignore
    """Accumulate positive-count statistics for sign-based binary classification.

    Predictions and targets are binarized by sign (``>= 0`` maps to the
    positive class). When ``self.exclude_neutral`` is set, entries whose
    target is exactly zero are dropped before binarization.
    """
    if self.exclude_neutral:
        keep = targets != 0
        binary_preds = (preds[keep] >= 0).int()
        binary_targets = (targets[keep] >= 0).int()
    else:
        binary_preds = (preds >= 0).int()
        binary_targets = (targets >= 0).int()

    # One-hot over the two classes, class dim first, flattened to [2, -1].
    pp = to_onehot(binary_preds, num_classes=2).transpose(1, 0).reshape(2, -1)
    tt = to_onehot(binary_targets, num_classes=2).transpose(1, 0).reshape(2, -1)

    # Fold this batch's counts into the running per-class totals.
    self.true_positives += torch.sum(pp * tt, dim=1)
    self.predicted_positives += torch.sum(pp, dim=1)
    self.actual_positives += torch.sum(tt, dim=1)
def test_onehot():
    """Verify ``to_onehot`` against a hand-built encoding, with and without ``num_classes``."""
    labels = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])

    # Row 0 holds classes 0-4, row 1 holds classes 5-9, so the expected
    # one-hot blocks are identity/zero stacks.
    first_block = torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)])
    second_block = torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
    expected = torch.stack([first_block, second_block])

    assert labels.shape == (2, 5)
    assert expected.shape == (2, 10, 5)

    onehot_classes = to_onehot(labels, num_classes=10)
    onehot_no_classes = to_onehot(labels)

    # Explicit and inferred class counts must agree.
    assert torch.allclose(onehot_classes, onehot_no_classes)

    assert onehot_classes.shape == expected.shape
    assert onehot_no_classes.shape == expected.shape

    assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes)
    assert torch.allclose(expected.to(onehot_classes), onehot_classes)
def _hinge_update(
    preds: Tensor,
    target: Tensor,
    squared: bool = False,
    multiclass_mode: Optional[Union[str, MulticlassMode]] = None,
) -> Tuple[Tensor, Tensor]:
    """Return the per-observation hinge-loss sums and the number of observations.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        squared: If True, compute the squared hinge loss; otherwise the regular hinge loss.
        multiclass_mode: Which approach to use for multi-class inputs (no effect for binary).
            ``None`` (default), ``MulticlassMode.CRAMMER_SINGER`` or ``"crammer-singer"``
            uses the Crammer-Singer multi-class hinge loss; ``MulticlassMode.ONE_VS_ALL``
            or ``"one-vs-all"`` computes the loss in a one-vs-all fashion.
    """
    preds, target = _input_squeeze(preds, target)
    mode = _check_shape_and_type_consistency_hinge(preds, target)

    if mode == DataType.MULTICLASS:
        # Boolean one-hot mask selecting each sample's true class column.
        target = to_onehot(target, max(2, preds.shape[1])).bool()

    use_crammer_singer = multiclass_mode is None or multiclass_mode == MulticlassMode.CRAMMER_SINGER
    if mode == DataType.MULTICLASS and use_crammer_singer:
        # Margin = true-class score minus the best competing (wrong) class score.
        margin = preds[target] - torch.max(preds[~target].view(preds.shape[0], -1), dim=1)[0]
    elif mode == DataType.BINARY or multiclass_mode == MulticlassMode.ONE_VS_ALL:
        # Signed margin: positive targets keep their score, negatives flip sign.
        target = target.bool()
        margin = torch.zeros_like(preds)
        margin[target] = preds[target]
        margin[~target] = -preds[~target]
    else:
        raise ValueError(
            "The `multiclass_mode` should be either None / 'crammer-singer' / MulticlassMode.CRAMMER_SINGER"
            "(default) or 'one-vs-all' / MulticlassMode.ONE_VS_ALL,"
            f" got {multiclass_mode}.")

    measures = torch.clamp(1 - margin, 0)
    if squared:
        measures = measures.pow(2)

    total = tensor(target.shape[0], device=target.device)
    return measures.sum(dim=0), total
def _hinge_update(
    preds: Tensor,
    target: Tensor,
    squared: bool = False,
    multiclass_mode: Optional[Union[str, MulticlassMode]] = None,
) -> Tuple[Tensor, Tensor]:
    """Return the per-observation hinge-loss sums and the number of observations.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        squared: If True, compute the squared hinge loss; otherwise the regular hinge loss.
        multiclass_mode: ``None``/``"crammer-singer"`` for Crammer-Singer multi-class hinge
            loss (default), ``"one-vs-all"`` for the one-vs-all variant; ignored for binary.
    """
    # Squeeze away singleton dimensions, but preserve a leading batch dim of size 1.
    if preds.shape[0] == 1:
        preds = preds.squeeze().unsqueeze(0)
        target = target.squeeze().unsqueeze(0)
    else:
        preds = preds.squeeze()
        target = target.squeeze()

    mode = _check_shape_and_type_consistency_hinge(preds, target)

    if mode == DataType.MULTICLASS:
        # Boolean one-hot mask selecting each sample's true class column.
        target = to_onehot(target, max(2, preds.shape[1])).bool()

    use_crammer_singer = multiclass_mode is None or multiclass_mode == MulticlassMode.CRAMMER_SINGER
    if mode == DataType.MULTICLASS and use_crammer_singer:
        # Margin = true-class score minus the best competing (wrong) class score.
        margin = preds[target] - torch.max(preds[~target].view(preds.shape[0], -1), dim=1)[0]
    elif mode == DataType.BINARY or multiclass_mode == MulticlassMode.ONE_VS_ALL:
        # Signed margin: positive targets keep their score, negatives flip sign.
        target = target.bool()
        margin = torch.zeros_like(preds)
        margin[target] = preds[target]
        margin[~target] = -preds[~target]
    else:
        raise ValueError(
            "The `multiclass_mode` should be either None / 'crammer-singer' / MulticlassMode.CRAMMER_SINGER"
            "(default) or 'one-vs-all' / MulticlassMode.ONE_VS_ALL,"
            f" got {multiclass_mode}.")

    measures = torch.clamp(1 - margin, 0)
    if squared:
        measures = measures.pow(2)

    total = tensor(target.shape[0], device=target.device)
    return measures.sum(dim=0), total
def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
    """Accumulate TP/FP/FN counts at every configured threshold.

    Args
        preds: (n_samples, n_classes) tensor
        target: (n_samples, n_classes) tensor
    """
    # Binary 1-d inputs are treated as a single class column.
    if len(preds.shape) == len(target.shape) == 1:
        preds = preds.reshape(-1, 1)
        target = target.reshape(-1, 1)

    # Class-index targets: expand to one-hot so they align with preds.
    if len(preds.shape) == len(target.shape) + 1:
        target = to_onehot(target, num_classes=self.num_classes)

    target = target == 1

    # One threshold per iteration keeps the boolean intermediates small.
    for idx in range(self.num_thresholds):
        above = preds >= self.thresholds[idx]
        self.TPs[:, idx] += (target & above).sum(dim=0)
        self.FPs[:, idx] += (~target & above).sum(dim=0)
        self.FNs[:, idx] += (target & ~above).sum(dim=0)
def _onehot2(x):
    """Return a two-class one-hot encoding of ``x``."""
    return to_onehot(x, num_classes=2)
def _onehot(x):
    """Return a one-hot encoding of ``x`` over the module-level ``NUM_CLASSES`` classes."""
    return to_onehot(x, num_classes=NUM_CLASSES)
) if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k: preds = (preds >= threshold).int() num_classes = num_classes if not multiclass else 2 if case == DataType.MULTILABEL and top_k: preds = select_topk(preds, top_k) if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or multiclass: if preds.is_floating_point(): num_classes = preds.shape[1] preds = select_topk(preds, top_k or 1) else: num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1 preds = to_onehot(preds, max(2, num_classes)) target = to_onehot(target, max(2, num_classes)) if multiclass is False: preds, target = preds[:, 1, ...], target[:, 1, ...] if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and multiclass is not False) or multiclass: target = target.reshape(target.shape[0], target.shape[1], -1) preds = preds.reshape(preds.shape[0], preds.shape[1], -1) else: target = target.reshape(target.shape[0], -1) preds = preds.reshape(preds.shape[0], -1) # Some operations above create an extra dimension for MC/binary case - this removes it if preds.ndim > 2: