Example #1
import pytest

# The arguments (metric class, functional metric, and the invalid parameter
# combination plus expected error message) are supplied by a
# @pytest.mark.parametrize decorator that is elided here; `precision_recall`
# and the `_input_binary` fixture come from the test module's imports.
def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes,
                      ignore_index, match_str):
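    # The class-based metric should reject the invalid combination at construction time.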
    with pytest.raises(ValueError, match=match_str):
        metric(
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

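    # The functional counterpart should raise the same error when called on real inputs.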
    with pytest.raises(ValueError, match=match_str):
        fn_metric(
            _input_binary.preds[0],
            _input_binary.target[0],
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

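    # precision_recall shares the same argument validation, so it should raise as well.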
    with pytest.raises(ValueError, match=match_str):
        precision_recall(
            _input_binary.preds[0],
            _input_binary.target[0],
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )
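
For context, a test like this is driven by a @pytest.mark.parametrize decorator listing invalid argument combinations and the error messages they should trigger. A minimal, hypothetical sketch of such a decorator (the combination and match string below are illustrative, not the original suite's):

import pytest
from torchmetrics import Precision
from torchmetrics.functional import precision

# Illustrative only: an unsupported `average` value should raise a ValueError
# mentioning "average" in both the class-based and functional APIs.
@pytest.mark.parametrize(
    "metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str",
    [(Precision, precision, "wrong", None, None, None, "average")],
)
def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes,
                      ignore_index, match_str):
    ...  # body as in the example above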
Example #2
import torch

# `average` is supplied by a @pytest.mark.parametrize decorator (elided here);
# NUM_CLASSES and the `_input_mcls_prob` fixture come from the test module.
# The imports below assume the torchmetrics 0.x functional API:
from torchmetrics.functional import precision, precision_recall, recall


def test_precision_recall_joint(average):
    """A simple test of the joint precision_recall metric.

    No need to test this thoroughly, as it is just a combination of precision and recall,
    which are both already tested thoroughly on their own.
    """

    precision_result = precision(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )
    recall_result = recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

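    # The joint call should return the same values as a (precision, recall) tuple.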
    prec_recall_result = precision_recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    assert torch.equal(precision_result, prec_recall_result[0])
    assert torch.equal(recall_result, prec_recall_result[1])
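
As a standalone illustration, assuming the torchmetrics 0.x functional API in which precision_recall returns a (precision, recall) tuple:

import torch
from torchmetrics.functional import precision, precision_recall, recall

preds = torch.tensor([0, 2, 1, 1])
target = torch.tensor([0, 1, 1, 2])

prec, rec = precision_recall(preds, target, average="macro", num_classes=3)
# Each element matches the corresponding standalone metric exactly.
assert torch.equal(prec, precision(preds, target, average="macro", num_classes=3))
assert torch.equal(rec, recall(preds, target, average="macro", num_classes=3))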
Example #3
    # Assumes module-level imports: numpy as np, torch, and the torchmetrics 0.x
    # functional metrics `accuracy`, `f1`, and `precision_recall`.
    def batch_step(self, batch):
        """Shared step logic used by both training and validation."""
        data, target = batch
        # `cutmix` is the Beta-distribution alpha; `cutmix_prob` is the probability
        # of applying CutMix to a given batch.
        if self.training and self.hparams.cutmix > 0 and torch.rand(1) < self.hparams.cutmix_prob:
            lam = np.random.beta(self.hparams.cutmix, self.hparams.cutmix)
            rand_index = torch.randperm(data.size()[0]).to(data.device)
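            # Pair each image with a random partner; keep both label sets for loss mixing.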
            target_a = target
            target_b = target[rand_index]
            # Sample the patch bounding box. Dims 2 and 3 are used symmetrically,
            # so the w/h naming relative to NCHW's (H, W) order does not affect correctness.
            _, _, w, h = data.size()
            cut_rat = np.sqrt(1.0 - lam)
            cut_w, cut_h = int(w*cut_rat), int(h*cut_rat)  # Box size
            cx, cy = np.random.randint(w), np.random.randint(h)  # Box center
            bbx1 = np.clip(cx - cut_w // 2, 0, w)
            bbx2 = np.clip(cx + cut_w // 2, 0, w)
            bby1 = np.clip(cy - cut_h // 2, 0, h)
            bby2 = np.clip(cy + cut_h // 2, 0, h)
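            # Paste the sampled patch from the partner images into the batch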
            data[:, :, bbx1:bbx2, bby1:bby2] = data[rand_index, :, bbx1:bbx2, bby1:bby2]
            # Recompute lam from the actual (clipped) patch area so the loss weights
            # reflect the true pixel mixing ratio
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (w * h))
            logits = self.model(data)
            loss = self.criterion(logits, target_a) * lam + self.criterion(logits, target_b) * (1.0 - lam)
        else:
            logits = self.model(data)
            loss = self.criterion(logits, target)

        pred = torch.argmax(logits, dim=1)
        acc = accuracy(pred, target)
        avg_precision, avg_recall = precision_recall(pred, target, num_classes=self.hparams.num_classes,
                                                     average="macro", mdmc_average="global")
        weighted_f1 = f1(pred, target, num_classes=self.hparams.num_classes,
                         threshold=0.5, average="weighted")
        metrics = {
            "loss": loss,  # attached to computation graph, not necessary in validation, but I'm to lazy to fix
            "accuracy": acc,
            "average_precision": avg_precision,
            "average_recall": avg_recall,
            "weighted_f1": weighted_f1,
        }
        return metrics
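
The patch sampling above follows the original CutMix recipe (Yun et al., 2019). A minimal standalone sketch of the same logic, factored into a helper; the rand_bbox name echoes the reference implementation and is illustrative here:

import numpy as np
import torch

def rand_bbox(w: int, h: int, lam: float):
    """Sample a patch whose area is roughly (1 - lam) of the image."""
    cut_rat = np.sqrt(1.0 - lam)
    cut_w, cut_h = int(w * cut_rat), int(h * cut_rat)
    cx, cy = np.random.randint(w), np.random.randint(h)  # patch center
    # Clipping at the borders can shrink the patch, so lam must be recomputed
    # from the final box area.
    bbx1, bbx2 = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)
    bby1, bby2 = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
    return bbx1, bby1, bbx2, bby2

data = torch.randn(8, 3, 32, 32)
lam = float(np.random.beta(1.0, 1.0))  # alpha = 1.0, illustrative
bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(2), data.size(3), lam)
rand_index = torch.randperm(data.size(0))
data[:, :, bbx1:bbx2, bby1:bby2] = data[rand_index, :, bbx1:bbx2, bby1:bby2]
lam = 1 - (bbx2 - bbx1) * (bby2 - bby1) / (data.size(2) * data.size(3))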