def test_top_k(k: int, preds: torch.Tensor, target: torch.Tensor, reduce: str, expected: torch.Tensor):
    """Check that ``top_k`` is honoured by both ``StatScores`` implementations.

    The module-based metric and the functional ``stat_scores`` are each run with
    ``num_classes=3`` and compared for exact equality against ``expected.T``.
    """
    wanted = expected.T

    # Module API: accumulate a single batch, then compute.
    metric = StatScores(top_k=k, reduce=reduce, num_classes=3)
    metric.update(preds, target)
    module_result = metric.compute()
    assert torch.equal(module_result, wanted)

    # Functional API on the same inputs.
    functional_result = stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3)
    assert torch.equal(functional_result, wanted)
def __init__(
    self,
    n_channel: int = 1,
    n_class: int = 2,
    learning_rate: float = 1e-4,
    class_weight: List[float] = None,
    backbone: Union[str, nn.Module] = "simple-cnn",
    backbone_output_size: int = 0,
    n_hidden: int = 512,
    dropout: float = 0.2,
    lr_scheduler: bool = False,
    lr_scheduler_warmup_steps: int = 100,
    lr_scheduler_total_steps: int = 0,
    **kwargs,
):
    """Build the backbone, classifier head, loss function and metrics.

    Args:
        n_channel: Number of input channels, forwarded to ``get_backbone`` as
            ``channels`` when ``backbone`` is a name.
        n_class: Number of target classes; sizes the classifier output and the
            macro-averaged test metrics.
        learning_rate: Stored in ``self.hparams``; not used directly here.
        class_weight: Optional per-class weights for ``nn.CrossEntropyLoss``.
            Converted to a float tensor.
        backbone: Either a backbone name resolved via ``get_backbone`` or a
            pre-built ``nn.Module`` used as-is.
        backbone_output_size: Feature size produced by the backbone. Ignored
            (overwritten) when ``backbone`` is a name. Required and must be
            positive when ``backbone`` is an ``nn.Module``.
        n_hidden: Hidden size passed to ``Classifier``.
        dropout: Dropout rate passed to both ``get_backbone`` and ``Classifier``.
        lr_scheduler: Stored in ``self.hparams``; not used directly here.
        lr_scheduler_warmup_steps: Stored in ``self.hparams``; not used directly here.
        lr_scheduler_total_steps: Stored in ``self.hparams``; not used directly here.
        **kwargs: Extra keyword arguments forwarded to ``get_backbone``.

    Raises:
        ValueError: If ``backbone`` is a module and ``backbone_output_size`` is
            not a positive integer.
    """
    super().__init__()
    # Record all constructor arguments in ``self.hparams``, including the
    # optimiser/scheduler settings that are consumed elsewhere.
    self.save_hyperparameters()

    if isinstance(backbone, str):
        # Named backbone: the factory also reports its output feature size.
        self.backbone, backbone_output_size = get_backbone(
            backbone,
            channels=n_channel,
            dropout=dropout,
            **kwargs,
        )
    else:
        # Pre-built module: its output size cannot be inferred here, so the
        # caller must provide it. Without this branch ``self.backbone`` would
        # never be assigned.
        if backbone_output_size <= 0:
            raise ValueError(
                "backbone_output_size must be a positive integer when backbone is an nn.Module, "
                f"got {backbone_output_size}"
            )
        self.backbone = backbone

    self.classifier = Classifier(backbone_output_size, n_class, n_hidden, dropout)

    if class_weight is not None:
        class_weight = torch.tensor(class_weight, dtype=torch.float)
    self.loss_fn = nn.CrossEntropyLoss(weight=class_weight)

    self.train_accuracy = Accuracy()
    self.val_accuracy = Accuracy()
    self.test_metrics = MetricCollection([
        Accuracy(),
        F1(num_classes=self.hparams.n_class, average="macro"),
        Recall(num_classes=self.hparams.n_class, average="macro"),  # balanced acc.
        # For two classes, StatScores is configured as binary (num_classes=1,
        # multiclass=False); presumably this yields stats for the positive class only.
        StatScores(
            num_classes=self.hparams.n_class if self.hparams.n_class > 2 else 1,
            reduce="micro",
            multiclass=self.hparams.n_class > 2,
        ),
    ])
def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
    """Test that invalid parameter combinations raise ``ValueError``.

    Covered cases: invalid ``reduce`` and ``mdmc_reduce`` values, not setting
    ``num_classes`` when ``reduce='macro'``, not setting ``mdmc_reduce`` when
    inputs are multi-dimensional multi-class, setting ``ignore_index`` when
    inputs are binary, and setting ``ignore_index`` to a value higher than the
    number of classes.

    Both the functional ``stat_scores`` and the ``StatScores`` module are checked.
    """
    preds, target = inputs.preds[0], inputs.target[0]

    with pytest.raises(ValueError):
        stat_scores(
            preds,
            target,
            reduce=reduce,
            mdmc_reduce=mdmc_reduce,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

    # Some combinations may only be rejected when the metric sees data, so both
    # construction and the call are inside the ``raises`` block.
    with pytest.raises(ValueError):
        sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index)
        sts(preds, target)
def test_wrong_threshold():
    """Constructing ``StatScores`` with ``threshold=1.5`` must raise ``ValueError``."""
    bad_threshold = 1.5
    with pytest.raises(ValueError):
        StatScores(threshold=bad_threshold)