Пример #1
0
def test_top_k(k: int, preds: torch.Tensor, target: torch.Tensor, reduce: str, expected: torch.Tensor):
    """Check that ``top_k`` is honoured by both ``StatScores`` and ``stat_scores``.

    The module-based metric (via ``update``/``compute``) and the functional
    API are each run with ``num_classes=3`` and compared for exact equality
    against the transpose of ``expected``.
    """
    reference = expected.T

    metric = StatScores(top_k=k, reduce=reduce, num_classes=3)
    metric.update(preds, target)
    module_result = metric.compute()
    assert torch.equal(module_result, reference)

    functional_result = stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3)
    assert torch.equal(functional_result, reference)
Пример #2
0
    def __init__(
        self,
        n_channel: int = 1,
        n_class: int = 2,
        learning_rate: float = 1e-4,
        class_weight: List[float] = None,
        backbone: Union[str, nn.Module] = "simple-cnn",
        backbone_output_size: int = 0,
        n_hidden: int = 512,
        dropout: float = 0.2,
        lr_scheduler: bool = False,
        lr_scheduler_warmup_steps: int = 100,
        lr_scheduler_total_steps: int = 0,
        **kwargs,
    ):
        """Build the backbone, classification head, loss function and metrics.

        Args:
            n_channel: Number of input channels, forwarded to ``get_backbone``
                when ``backbone`` is given by name.
            n_class: Number of target classes.
            learning_rate: Recorded by ``save_hyperparameters()``; not used
                in this constructor.
            class_weight: Optional per-class weights; converted to a float
                tensor and passed as ``weight`` to ``nn.CrossEntropyLoss``.
            backbone: Either a name resolved through ``get_backbone`` or an
                already-constructed ``nn.Module``.
            backbone_output_size: Feature size produced by the backbone. When
                ``backbone`` is a name, this is replaced by the size returned
                from ``get_backbone``. When ``backbone`` is a module, it must
                be supplied and be positive.
            n_hidden: Hidden size passed to ``Classifier``.
            dropout: Dropout rate passed to ``Classifier`` (and to
                ``get_backbone`` when ``backbone`` is a name).
            lr_scheduler, lr_scheduler_warmup_steps, lr_scheduler_total_steps:
                Recorded by ``save_hyperparameters()``; not used in this
                constructor.
            **kwargs: Extra keyword arguments forwarded to ``get_backbone``.
                They are only used when ``backbone`` is a name.

        Raises:
            ValueError: If ``backbone`` is a module and
                ``backbone_output_size`` is not positive.
        """
        super().__init__()

        # Makes the constructor arguments available as ``self.hparams``
        # (read below as ``self.hparams.n_class``).
        self.save_hyperparameters()

        if isinstance(backbone, str):
            # The factory reports the feature size of the backbone it builds,
            # which overrides the ``backbone_output_size`` argument.
            self.backbone, backbone_output_size = get_backbone(
                backbone,
                channels=n_channel,
                dropout=dropout,
                **kwargs,
            )
        else:
            # A pre-built module was supplied. Its output size cannot be
            # discovered here, so the caller must provide it; a size of 0
            # would produce a classifier head with no inputs.
            if backbone_output_size <= 0:
                raise ValueError(
                    "backbone_output_size must be a positive integer when "
                    "backbone is an nn.Module instance"
                )
            self.backbone = backbone

        self.classifier = Classifier(backbone_output_size, n_class, n_hidden,
                                     dropout)

        if class_weight is not None:
            class_weight = torch.tensor(class_weight, dtype=torch.float)
        self.loss_fn = nn.CrossEntropyLoss(weight=class_weight)

        self.train_accuracy = Accuracy()
        self.val_accuracy = Accuracy()

        n_classes = self.hparams.n_class
        is_multiclass = n_classes > 2
        self.test_metrics = MetricCollection([
            Accuracy(),
            F1(num_classes=n_classes, average="macro"),
            # Macro-averaged recall is the balanced accuracy.
            Recall(num_classes=n_classes, average="macro"),
            # Binary problems are scored as a single class with
            # ``multiclass=False``; otherwise all ``n_classes`` are used.
            StatScores(
                num_classes=n_classes if is_multiclass else 1,
                reduce="micro",
                multiclass=is_multiclass,
            ),
        ])
Пример #3
0
def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
    """Test that invalid parameter combinations raise ``ValueError``.

    The covered cases are:

    - invalid ``reduce`` and ``mdmc_reduce`` values;
    - not setting ``num_classes`` when ``reduce='macro'``;
    - not setting ``mdmc_reduce`` when inputs are multi-dimensional
      multi-class;
    - setting ``ignore_index`` when inputs are binary;
    - setting ``ignore_index`` to a value higher than the number of classes.

    Both the functional ``stat_scores`` and the ``StatScores`` module are
    checked.
    """
    with pytest.raises(ValueError):
        stat_scores(
            inputs.preds[0],
            inputs.target[0],
            reduce=reduce,
            mdmc_reduce=mdmc_reduce,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

    # The error may be raised either by the constructor or by the first call;
    # both statements are inside the ``raises`` block, so either satisfies it.
    with pytest.raises(ValueError):
        sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index)
        sts(inputs.preds[0], inputs.target[0])
Пример #4
0
def test_wrong_threshold():
    """Constructing ``StatScores`` with ``threshold=1.5`` must raise ``ValueError``."""
    pytest.raises(ValueError, StatScores, threshold=1.5)