from tests.classification.inputs import _input_multiclass_prob as _mc_prob
from tests.classification.inputs import _input_multidim_multiclass as _mdmc
from tests.classification.inputs import _input_multidim_multiclass_prob as _mdmc_prob
from tests.classification.inputs import _input_multilabel as _ml
from tests.classification.inputs import _input_multilabel_multidim as _mlmd
from tests.classification.inputs import _input_multilabel_multidim_prob as _mlmd_prob
from tests.classification.inputs import _input_multilabel_prob as _ml_prob
from tests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, THRESHOLD
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.data import select_topk, to_onehot
from torchmetrics.utilities.enums import DataType

torch.manual_seed(42)


def _normalize_dim2(probs):
    """Return ``probs`` rescaled so that every slice along dim 2 sums to one.

    Out-of-place equivalent of ``probs /= probs.sum(dim=2, keepdim=True)``;
    ``keepdim=True`` keeps the reduced axis so the division broadcasts.
    """
    totals = probs.sum(dim=2, keepdim=True)
    return probs / totals


# Some additional inputs to test on

# Multilabel probability predictions cast to float16; targets are reused as-is.
_ml_prob_half = Input(_ml_prob.preds.half(), _ml_prob.target)

# Multiclass probabilities over exactly two classes (dim 2 has size 2), with
# targets drawn from {0, 1}.
# NOTE: keep the rand -> randint order below; each draw advances the global
# RNG seeded above, so reordering would change the generated data.
_mc_prob_2cls_preds = _normalize_dim2(rand(NUM_BATCHES, BATCH_SIZE, 2))
_mc_prob_2cls = Input(
    _mc_prob_2cls_preds,
    randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
)

# Multidim-multiclass probabilities with two trailing EXTRA_DIM axes,
# normalized over dim 2 (size NUM_CLASSES). Targets only take values {0, 1}.
_mdmc_prob_many_dims_preds = _normalize_dim2(
    rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM, EXTRA_DIM)
)
_mdmc_prob_many_dims = Input(
    _mdmc_prob_many_dims_preds,
    randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, EXTRA_DIM)),
)
import numpy as np import pytest import torch from sklearn.metrics import hinge_loss as sk_hinge from sklearn.preprocessing import OneHotEncoder from tests.classification.inputs import Input from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester from torchmetrics import Hinge from torchmetrics.functional import hinge from torchmetrics.functional.classification.hinge import MulticlassMode torch.manual_seed(42) _input_binary = Input(preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) _input_binary_single = Input(preds=torch.randn((NUM_BATCHES, 1)), target=torch.randint(high=2, size=(NUM_BATCHES, 1))) _input_multiclass = Input(preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))) def _sk_hinge(preds, target, squared, multiclass_mode): sk_preds, sk_target = preds.numpy(), target.numpy()