예제 #1
0
    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log its loss, Dice score and figures.

        The loss is the sum of the cross-entropy and ``1 - dice_score`` between
        the model output and the class-index target. The Dice score is computed
        once and reused for both the loss term and the logged metric.

        Args:
            batch: pair ``(x, y)``. ``x`` is passed to the model; ``y`` is
                reduced to class indices with ``argmax(axis=1)``.
            batch_idx: index of the batch, forwarded to
                ``self._log_step_figures``.

        Returns:
            The combined loss for this batch.
        """
        stage = "Validation"
        x, y = batch
        preds = self(x)
        y_2d = y.argmax(axis=1).long()

        ce_loss = F.cross_entropy(input=preds, target=y_2d)
        # Single Dice evaluation: the same value feeds the loss and the metric.
        score = dice_score(pred=preds, target=y_2d)
        loss = ce_loss + (1 - score)

        self.log("Valid Loss", loss, on_step=True)
        self.logger.experiment.log_metric(f"Loss {stage}",
                                          loss,
                                          step=self.global_step)

        self.log("Valid Score", score, on_step=True)
        self.logger.experiment.log_metric(f"Dice score {stage}",
                                          score,
                                          step=self.global_step)

        # Figures are logged for every validation batch.
        self._log_step_figures(x, y_2d, preds, batch_idx)

        return loss
예제 #2
0
    def test_step(self, batch, batch_idx):
        """Evaluate one test batch, log its metrics and accumulate predictions.

        The loss is the sum of the cross-entropy and ``1 - dice_score`` between
        the model output and the class-index target. The Dice score is computed
        once and reused for both the loss term and the logged metric.

        Metrics are logged under "Test Loss" / "Test Score" so they do not
        overwrite the "Valid Loss" / "Valid Score" entries written by the
        validation step.

        Side effects:
            ``self.preds`` and ``self.targets`` accumulate the arg-max
            predictions and class-index targets of every test batch. New
            batches are prepended, so both tensors are in reverse batch order
            but remain aligned with each other.

        Args:
            batch: pair ``(x, y)``. ``x`` is passed to the model; ``y`` is
                reduced to class indices with ``argmax(axis=1)``.
            batch_idx: index of the batch (unused here).
        """
        stage = "Testing"
        x, y = batch
        preds = self(x)
        y_2d = y.argmax(axis=1).long()

        ce_loss = F.cross_entropy(input=preds, target=y_2d)
        # Single Dice evaluation: the same value feeds the loss and the metric.
        score = dice_score(pred=preds, target=y_2d)
        loss = ce_loss + (1 - score)

        self.log("Test Loss", loss, on_step=True)
        self.logger.experiment.log_metric(f"Loss {stage}",
                                          loss,
                                          step=self.global_step)

        self.log("Test Score", score, on_step=True)
        self.logger.experiment.log_metric(f"Dice score {stage}",
                                          score,
                                          step=self.global_step)

        batch_preds = preds.argmax(dim=1)
        if self.preds is not None:
            # Prepend the current batch to what has been collected so far.
            self.preds = torch.cat((batch_preds, self.preds), dim=0)
            self.targets = torch.cat((y_2d, self.targets), dim=0)
        else:
            # First batch: start the accumulators.
            self.preds = batch_preds
            self.targets = y_2d
예제 #3
0
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute the dice coefficient by delegating to ``dice_score``.

        The options stored on this instance (``include_background``,
        ``nan_score``, ``no_fg_score`` and ``reduction``) are forwarded
        unchanged as keyword arguments.

        Args:
            pred: predicted probability for each label
            target: groundtruth labels

        Return:
            torch.Tensor: the calculated dice coefficient
        """
        # Gather the instance-level configuration in one place.
        options = {
            "bg": self.include_background,
            "nan_score": self.nan_score,
            "no_fg_score": self.no_fg_score,
            "reduction": self.reduction,
        }
        return dice_score(pred=pred, target=target, **options)
예제 #4
0
    def training_step(self, batch, batch_idx):
        """Run one training batch and return the loss to optimise.

        The loss is the sum of the cross-entropy and ``1 - dice_score`` between
        the model output and the class-index target. It is reported both to
        the experiment logger (keyed by the current global step) and through
        ``self.log``.

        Args:
            batch: pair of model input and target; the target is reduced to
                class indices with ``argmax(axis=1)``.
            batch_idx: index of the batch (unused here).

        Returns:
            The combined loss for this batch.
        """
        stage = "Training"
        inputs, raw_target = batch

        outputs = self(inputs)
        class_idx = raw_target.argmax(axis=1).long()

        ce_term = F.cross_entropy(input=outputs, target=class_idx)
        dice_term = 1 - dice_score(pred=outputs, target=class_idx)
        total_loss = ce_term + dice_term

        self.logger.experiment.log_metric(
            f"Loss {stage}", total_loss, step=self.global_step
        )
        self.log("Train Loss", total_loss, on_step=True)

        # Per-step figure logging (self._log_step_figures) is not done here.
        return total_loss
def test_dice_score(pred, target, expected):
    """Check that ``dice_score`` of ``pred`` vs ``target`` matches ``expected``.

    Args:
        pred: nested list (or other ``torch.tensor``-compatible data) of
            predictions.
        target: nested list (or other ``torch.tensor``-compatible data) of
            ground-truth labels.
        expected: expected Dice score as a Python number.
    """
    score = dice_score(torch.tensor(pred), torch.tensor(target))
    # Compare with a tolerance: the score is floating point, so exact equality
    # is fragile for values such as 2 / 3.
    expected_t = torch.as_tensor(expected, dtype=score.dtype, device=score.device)
    assert torch.isclose(score, expected_t), f"dice_score={score}, expected={expected}"