def test_incorrect_inputs(preds, target, num_classes, is_multiclass):
    """Check that an invalid input combination is rejected with ``ValueError``.

    All arguments are forwarded unchanged to ``_input_format_classification``
    together with the module-level ``THRESHOLD``.
    """
    call_kwargs = dict(
        preds=preds,
        target=target,
        threshold=THRESHOLD,
        num_classes=num_classes,
        is_multiclass=is_multiclass,
    )

    with pytest.raises(ValueError):
        _input_format_classification(**call_kwargs)
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
    """Check the formatted predictions, targets and detected mode on valid inputs.

    Every check is run for two spellings of the expected mode: ``exp_mode`` as
    given, and the ``DataType`` member that compares equal to it. For each
    spelling the first batch of ``inputs`` is checked, followed by a
    single-sample slice of that batch (batch size 1).

    Args:
        inputs: object exposing ``preds`` and ``target``; only index ``0`` of
            each is used.
        num_classes: forwarded to ``_input_format_classification``.
        is_multiclass: forwarded to ``_input_format_classification``.
        top_k: forwarded to ``_input_format_classification``.
        exp_mode: expected mode, comparable with ``DataType`` members.
        post_preds: callable producing the expected predictions from the raw ones.
        post_target: callable producing the expected targets from the raw ones.
    """

    def __get_data_type_enum(str_exp_mode):
        # Iterate the members directly (assumes ``DataType`` is an Enum class,
        # as its name-based subscripting suggests). Walking ``dir(DataType)``
        # also yields non-member attribute names, for which ``DataType[name]``
        # raises ``KeyError``; that only worked while a member matched first.
        return next(member for member in DataType if member == str_exp_mode)

    full_preds, full_target = inputs.preds[0], inputs.target[0]
    batches = (
        (full_preds, full_target),
        # Test that things work when batch_size = 1
        (full_preds[[0], ...], full_target[[0], ...]),
    )

    for expected_mode in (exp_mode, __get_data_type_enum(exp_mode)):
        for batch_preds, batch_target in batches:
            preds_out, target_out, mode = _input_format_classification(
                preds=batch_preds,
                target=batch_target,
                threshold=THRESHOLD,
                num_classes=num_classes,
                is_multiclass=is_multiclass,
                top_k=top_k,
            )

            assert mode == expected_mode
            assert torch.equal(preds_out, post_preds(batch_preds).int())
            assert torch.equal(target_out, post_target(batch_target).int())
def _accuracy_update(
    preds: torch.Tensor,
    target: torch.Tensor,
    threshold: float,
    top_k: Optional[int],
    subset_accuracy: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Count correct items and the total they are measured against.

    The inputs are first passed through ``_input_format_classification``; how
    the counts are computed then depends on the detected mode and on
    ``subset_accuracy``.

    Args:
        preds: predictions, forwarded to ``_input_format_classification``.
        target: ground truth, forwarded to ``_input_format_classification``.
        threshold: forwarded to ``_input_format_classification``.
        top_k: forwarded to ``_input_format_classification``; must be falsy
            for multi-label inputs.
        subset_accuracy: when truthy, multi-label and multi-dim multi-class
            inputs are scored per sample instead of per element.

    Returns:
        Tuple ``(correct, total)``.

    Raises:
        ValueError: if the inputs are multi-label and ``top_k`` is truthy.
    """
    preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)

    multilabel = mode == DataType.MULTILABEL
    multidim = mode == DataType.MULTIDIM_MULTICLASS

    if multilabel and top_k:
        raise ValueError(
            "You can not use the `top_k` parameter to calculate accuracy for multi-label inputs."
        )

    if mode == DataType.BINARY or (multilabel and subset_accuracy):
        # A sample counts only when every entry along dim 1 matches.
        hits = (preds == target).all(dim=1).sum()
        count = torch.tensor(target.shape[0], device=target.device)
    elif multilabel:
        # Only reachable without subset accuracy: every element is scored.
        hits = (preds == target).sum()
        count = torch.tensor(target.numel(), device=target.device)
    elif mode == DataType.MULTICLASS or (multidim and not subset_accuracy):
        hits = (preds * target).sum()
        count = target.sum()
    elif multidim:
        # Only reachable with subset accuracy: a sample counts when its
        # number of hits equals the size of the last dimension.
        hits_per_sample = (preds * target).sum(dim=(1, 2))
        hits = (hits_per_sample == target.shape[2]).sum()
        count = torch.tensor(target.shape[0], device=target.device)

    return hits, count
def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_index, top_k, mdmc_reduce=None):
    """Compute stat scores from ``multilabel_confusion_matrix`` output.

    The inputs are formatted with ``_input_format_classification`` and converted
    to NumPy. Each confusion-matrix row is flattened to four values, reordered
    with the column permutation ``[3, 1, 0, 2]``, and extended with a fifth
    column holding the sum of columns 3 and 0 of the reordered row.

    Args:
        preds: predictions, forwarded to ``_input_format_classification``.
        target: ground truth, forwarded to ``_input_format_classification``.
        reduce: ``"micro"``, ``"macro"`` and ``"samples"`` get special handling.
        num_classes: forwarded to ``_input_format_classification``.
        is_multiclass: forwarded to ``_input_format_classification``.
        ignore_index: column dropped before scoring (non-macro) or row filled
            with ``-1`` afterwards (macro); ``None`` disables both.
        top_k: forwarded to ``_input_format_classification``.
        mdmc_reduce: accepted but not used by this function.

    Returns:
        NumPy array of scores; one-dimensional when ``reduce == "micro"``.
    """
    preds, target, _ = _input_format_classification(
        preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
    )
    np_preds = preds.numpy()
    np_target = target.numpy()

    n_cols = preds.shape[1]
    single_col = n_cols == 1
    per_sample = reduce == "samples"
    has_ignore = ignore_index is not None

    if reduce != "macro" and has_ignore and n_cols > 1:
        np_preds = np.delete(np_preds, ignore_index, 1)
        np_target = np.delete(np_target, ignore_index, 1)

    if single_col and per_sample:
        np_target = np_target.T
        np_preds = np_preds.T

    confusion = multilabel_confusion_matrix(np_target, np_preds, samplewise=per_sample and not single_col)

    # With a single column and no sample-wise reduction only entry 1 is kept.
    if single_col and not per_sample:
        confusion = confusion[[1]]
    stats = confusion.reshape(-1, 4)[:, [3, 1, 0, 2]]

    if reduce == "micro":
        stats = stats.sum(axis=0, keepdims=True)

    extra_col = stats[:, [3]] + stats[:, [0]]
    stats = np.concatenate([stats, extra_col], 1)

    if reduce == "micro":
        stats = stats[0]

    if reduce == "macro" and has_ignore and n_cols:
        stats[ignore_index, :] = -1

    return stats
def _sk_stat_scores_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, is_multiclass, ignore_index, top_k):
    """Compute stat scores for three-dimensional formatted inputs via ``_sk_stat_scores``.

    After formatting with ``_input_format_classification`` the inputs are
    handled according to ``mdmc_reduce``:

    * ``"global"``: dims 1 and 2 are swapped and the result is flattened to two
      dimensions before a single call to ``_sk_stat_scores``.
    * ``"samplewise"``: ``_sk_stat_scores`` is called once per entry along
      dim 0 (on the transposed slice) and the results are stacked along a new
      leading axis.

    Any other ``mdmc_reduce`` value yields ``None``.
    """
    preds, target, _ = _input_format_classification(
        preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
    )

    def _score(pred_2d, target_2d):
        # Inputs are already formatted, so no num_classes / is_multiclass hints.
        return _sk_stat_scores(pred_2d, target_2d, reduce, None, False, ignore_index, top_k)

    if mdmc_reduce == "global":
        flat_preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
        flat_target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])
        return _score(flat_preds, flat_target)

    if mdmc_reduce == "samplewise":
        per_sample = [
            np.expand_dims(_score(preds[idx, ...].T, target[idx, ...].T), 0)
            for idx in range(preds.shape[0])
        ]
        return np.concatenate(per_sample)

    return None
def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average):
    """Compute a precision/recall-style score for three-dimensional formatted inputs.

    After formatting with ``_input_format_classification`` the inputs are
    handled according to ``mdmc_average``:

    * ``"global"``: dims 1 and 2 are swapped and the result is flattened to two
      dimensions before a single call to ``_sk_prec_recall``.
    * ``"samplewise"``: ``_sk_prec_recall`` is called once per entry along
      dim 0 (on the transposed slice) and the per-sample results are averaged
      over that axis.

    Any other ``mdmc_average`` value yields ``None``.
    """
    preds, target, _ = _input_format_classification(
        preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
    )

    def _score(pred_2d, target_2d):
        return _sk_prec_recall(pred_2d, target_2d, sk_fn, num_classes, average, False, ignore_index)

    if mdmc_average == "global":
        flat_preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
        flat_target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])
        return _score(flat_preds, flat_target)

    if mdmc_average == "samplewise":
        per_sample = [
            np.expand_dims(_score(preds[idx, ...].T, target[idx, ...].T), 0)
            for idx in range(preds.shape[0])
        ]
        return np.concatenate(per_sample).mean(axis=0)

    return None
def _sk_hamming_loss(preds, target):
    """Compute the hamming loss with ``sk_hamming_loss`` on formatted inputs.

    The inputs are formatted with ``_input_format_classification`` (using the
    module-level ``THRESHOLD``), converted to NumPy and flattened so that
    everything after the first dimension becomes a single axis.
    """
    formatted_preds, formatted_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD)

    np_preds = formatted_preds.numpy()
    np_target = formatted_target.numpy()

    flat_preds = np_preds.reshape(np_preds.shape[0], -1)
    flat_target = np_target.reshape(np_target.shape[0], -1)

    return sk_hamming_loss(y_true=flat_target, y_pred=flat_preds)
def test_threshold():
    """Check how probabilities just below, at and just above 0.5 are formatted.

    With ``threshold=0.5`` the value slightly below the threshold is expected
    to become 0, while the value equal to it and the one slightly above are
    expected to become 1.
    """
    eps = 1e-5
    target = T([1, 1, 1]).int()
    preds_probs = T([0.5 - eps, 0.5, 0.5 + eps])

    preds_probs_out, _, _ = _input_format_classification(preds_probs, target, threshold=0.5)

    expected = torch.tensor([0, 1, 1], dtype=torch.int)
    actual = preds_probs_out.squeeze().int()
    assert torch.equal(expected, actual)
def _stat_scores_update(
    preds: torch.Tensor,
    target: torch.Tensor,
    reduce: str = "micro",
    mdmc_reduce: Optional[str] = None,
    num_classes: Optional[int] = None,
    top_k: Optional[int] = None,
    threshold: float = 0.5,
    is_multiclass: Optional[bool] = None,
    ignore_index: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Format the inputs and compute the stat scores via ``_stat_scores``.

    Args:
        preds: predictions, forwarded to ``_input_format_classification``.
        target: ground truth, forwarded to ``_input_format_classification``.
        reduce: forwarded to ``_stat_scores``; ``"macro"`` also changes how
            ``ignore_index`` is applied.
        mdmc_reduce: required when the formatted predictions are
            three-dimensional; ``"global"`` flattens them to two dimensions.
        num_classes: forwarded to ``_input_format_classification``.
        top_k: forwarded to ``_input_format_classification``.
        threshold: forwarded to ``_input_format_classification``.
        is_multiclass: forwarded to ``_input_format_classification``.
        ignore_index: column of the formatted inputs to ignore. For
            ``reduce != "macro"`` the column is removed before scoring; for
            ``reduce == "macro"`` the corresponding scores are set to ``-1``.

    Returns:
        Tuple ``(tp, fp, tn, fn)`` as returned by ``_stat_scores``.

    Raises:
        ValueError: if ``ignore_index`` is outside ``[0, preds.shape[1])``, if
            ``ignore_index`` is given while the formatted inputs have a single
            column, or if the formatted inputs are three-dimensional and
            ``mdmc_reduce`` is not set.
    """
    preds, target, _ = _input_format_classification(
        preds, target, threshold=threshold, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k
    )

    if ignore_index is not None and not 0 <= ignore_index < preds.shape[1]:
        # Report dim 1, the dimension the index is validated against; the
        # message previously printed dim 0 as the class count.
        raise ValueError(
            f"The `ignore_index` {ignore_index} is not valid for inputs with {preds.shape[1]} classes"
        )

    if ignore_index is not None and preds.shape[1] == 1:
        raise ValueError("You can not use `ignore_index` with binary data.")

    if preds.ndim == 3:
        if not mdmc_reduce:
            raise ValueError(
                "When your inputs are multi-dimensional multi-class, you have to set the `mdmc_reduce` parameter"
            )
        if mdmc_reduce == "global":
            # Swap dims 1 and 2, then flatten to two dimensions.
            preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
            target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])

    # Delete what is in ignore_index, if applicable (and classes don't matter):
    if ignore_index is not None and reduce != "macro":
        preds = _del_column(preds, ignore_index)
        target = _del_column(target, ignore_index)

    tp, fp, tn, fn = _stat_scores(preds, target, reduce=reduce)

    # Take care of ignore_index
    if ignore_index is not None and reduce == "macro":
        tp[..., ignore_index] = -1
        fp[..., ignore_index] = -1
        tn[..., ignore_index] = -1
        fn[..., ignore_index] = -1

    return tp, fp, tn, fn
def _sk_accuracy(preds, target, subset_accuracy):
    """Compute accuracy on formatted inputs, mostly through ``sk_accuracy``.

    The inputs are formatted with ``_input_format_classification`` (using the
    module-level ``THRESHOLD``) and converted to NumPy. Depending on the
    detected mode and ``subset_accuracy`` they are reshaped before being handed
    to ``sk_accuracy``; multi-dim multi-class inputs with subset accuracy are
    scored directly instead.
    """
    formatted_preds, formatted_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD)
    np_preds = formatted_preds.numpy()
    np_target = formatted_target.numpy()

    if mode == DataType.MULTIDIM_MULTICLASS:
        if subset_accuracy:
            # A sample is correct only if all of its entries match.
            return np.all(np_preds == np_target, axis=(1, 2)).mean()

        # Move axis 1 last, then merge the two leading axes.
        np_preds = np.transpose(np_preds, (0, 2, 1))
        np_target = np.transpose(np_target, (0, 2, 1))
        np_preds = np_preds.reshape(-1, np_preds.shape[2])
        np_target = np_target.reshape(-1, np_target.shape[2])
    elif mode == DataType.MULTILABEL and not subset_accuracy:
        np_preds = np_preds.reshape(-1)
        np_target = np_target.reshape(-1)

    return sk_accuracy(y_true=np_target, y_pred=np_preds)
def _hamming_distance_update(
    preds: torch.Tensor,
    target: torch.Tensor,
    threshold: float = 0.5,
) -> Tuple[torch.Tensor, int]:
    """Count matching elements between formatted predictions and targets.

    Args:
        preds: predictions, forwarded to ``_input_format_classification``.
        target: ground truth, forwarded to ``_input_format_classification``.
        threshold: forwarded to ``_input_format_classification``.

    Returns:
        Tuple ``(correct, total)``: the number of positions where the formatted
        tensors are equal, and the number of elements in the formatted
        predictions.
    """
    formatted_preds, formatted_target, _ = _input_format_classification(preds, target, threshold=threshold)

    n_elements = formatted_preds.numel()
    n_matching = (formatted_preds == formatted_target).sum()

    return n_matching, n_elements
def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, str]:
    """Validate the inputs and flatten extra dimensions according to the data mode.

    ``_input_format_classification`` is used only to validate the inputs and to
    detect the mode; its formatted tensors are discarded and the original
    ``preds`` / ``target`` are reshaped instead.

    Args:
        preds: predictions.
        target: ground truth.

    Returns:
        Tuple ``(preds, target, mode)`` where ``mode`` is the value returned by
        ``_input_format_classification``.
    """
    # use _input_format_classification for validating the input and get the mode of data
    _, _, mode = _input_format_classification(preds, target)

    # Compare against the DataType members, as the rest of the code does,
    # rather than hand-typed strings that can drift from the enum's values.
    if mode == DataType.MULTIDIM_MULTICLASS:
        # Keep dim 1 as the trailing axis and merge every other dimension.
        n_classes = preds.shape[1]
        preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
        target = target.flatten()

    if mode == DataType.MULTILABEL and preds.ndim > 2:
        n_classes = preds.shape[1]
        preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
        target = target.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)

    return preds, target, mode
def _confusion_matrix_update(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    threshold: float = 0.5,
) -> torch.Tensor:
    """Build a ``num_classes x num_classes`` count matrix from the inputs.

    The inputs are formatted with ``_input_format_classification``. Unless the
    detected mode is binary or multi-label, both tensors are reduced with
    ``argmax`` over dim 1. Each (target, prediction) pair is then encoded as
    ``target * num_classes + prediction`` and the codes are counted.

    Args:
        preds: predictions.
        target: ground truth.
        num_classes: side length of the returned matrix.
        threshold: passed positionally to ``_input_format_classification``.

    Returns:
        Tensor of shape ``(num_classes, num_classes)``; rows are indexed by the
        target value and columns by the predicted value.
    """
    preds, target, mode = _input_format_classification(preds, target, threshold)

    already_labels = mode in (DataType.BINARY, DataType.MULTILABEL)
    if not already_labels:
        preds = preds.argmax(dim=1)
        target = target.argmax(dim=1)

    pair_codes = target.view(-1) * num_classes + preds.view(-1)
    pair_codes = pair_codes.to(torch.long)

    counts = torch.bincount(pair_codes, minlength=num_classes**2)
    return counts.reshape(num_classes, num_classes)
def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None):
    """Evaluate ``sk_fn`` on formatted inputs, optionally skipping one label.

    Args:
        preds: predictions, forwarded to ``_input_format_classification``.
        target: ground truth, forwarded to ``_input_format_classification``.
        sk_fn: scoring callable invoked as
            ``sk_fn(target, preds, average=..., zero_division=0, labels=...)``.
        num_classes: number of labels; a value of ``1`` forces
            ``average="binary"``.
        average: averaging mode; the string ``"none"`` is translated to ``None``.
        is_multiclass: forwarded to ``_input_format_classification``.
        ignore_index: label left out of ``labels``; when no averaging is done,
            ``nan`` is inserted at this position of the result.
        mdmc_average: accepted but not used by this function.

    Returns:
        Whatever ``sk_fn`` returns, with a ``nan`` placeholder inserted for the
        ignored label when applicable.
    """
    if average == "none":
        average = None
    if num_classes == 1:
        average = "binary"

    # Every label except the ignored one; nothing is dropped when
    # ``ignore_index`` is not among the labels.
    labels = [label for label in range(num_classes) if label != ignore_index]

    formatted_preds, formatted_target, _ = _input_format_classification(
        preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
    )
    np_preds = formatted_preds.numpy()
    np_target = formatted_target.numpy()

    scores = sk_fn(np_target, np_preds, average=average, zero_division=0, labels=labels)

    label_was_dropped = len(labels) != num_classes
    if label_was_dropped and not average:
        scores = np.insert(scores, ignore_index, np.nan)

    return scores
def test_incorrect_threshold(threshold):
    """Check that an invalid ``threshold`` is rejected with ``ValueError``.

    Seven random predictions and seven random integer targets below 2 are
    passed to ``_input_format_classification``.
    """
    sample_shape = (7, )
    preds = rand(size=sample_shape)
    target = randint(high=2, size=sample_shape)

    with pytest.raises(ValueError):
        _input_format_classification(preds, target, threshold=threshold)