def validation_step(self, batch, batch_idx):
    """Run one validation batch: compute, log, and return the loss.

    The loss is ``cross_entropy + (1 - dice)``. Both the loss and the Dice score
    are logged through ``self.log`` and through ``self.logger.experiment``.
    Per-step figures are produced with ``self._log_step_figures``.

    Args:
        batch: ``(x, y)`` pair. ``y`` is reduced with ``argmax(axis=1)`` to
            integer class labels (presumably one-hot masks with the class axis
            at dim 1 — TODO confirm).
        batch_idx: index of the batch, forwarded to ``_log_step_figures``.

    Returns:
        The combined cross-entropy + (1 - Dice) loss for this batch.
    """
    stage = "Validation"
    x, y = batch
    preds = self(x)
    y_2d = y.argmax(axis=1).long()

    # Compute the Dice score once; it feeds both the loss and the logged metric.
    score = dice_score(pred=preds, target=y_2d)
    # NOTE(review): if dice_score argmaxes `preds` internally, the Dice term adds
    # a value to the loss but contributes no gradient — confirm.
    loss = F.cross_entropy(input=preds, target=y_2d) + (1 - score)

    self.log("Valid Loss", loss, on_step=True)
    self.logger.experiment.log_metric(f"Loss {stage}", loss, step=self.global_step)
    self.log("Valid Score", score, on_step=True)
    self.logger.experiment.log_metric(f"Dice score {stage}", score, step=self.global_step)

    self._log_step_figures(x, y_2d, preds, batch_idx)
    return loss
def test_step(self, batch, batch_idx):
    """Run one test batch: log loss/Dice and accumulate predictions and targets.

    Per-pixel class predictions (``preds.argmax(dim=1)``) and integer targets
    are concatenated onto ``self.preds`` / ``self.targets`` in batch order.
    ``self.preds`` being ``None`` marks the first batch.

    Args:
        batch: ``(x, y)`` pair. ``y`` is reduced with ``argmax(axis=1)`` to
            integer class labels (presumably one-hot masks — TODO confirm).
        batch_idx: index of the batch (unused here).

    Returns:
        The combined cross-entropy + (1 - Dice) loss for this batch.
    """
    stage = "Testing"
    x, y = batch
    preds = self(x)
    y_2d = y.argmax(axis=1).long()

    # Compute the Dice score once; it feeds both the loss and the logged metric.
    score = dice_score(pred=preds, target=y_2d)
    # NOTE(review): if dice_score argmaxes `preds` internally, the Dice term
    # contributes no gradient — confirm (harmless at test time).
    loss = F.cross_entropy(input=preds, target=y_2d) + (1 - score)

    # Log under test-specific keys. Previously these reused the "Valid ..."
    # keys, which mixed test results into the validation metric names.
    self.log("Test Loss", loss, on_step=True)
    self.logger.experiment.log_metric(f"Loss {stage}", loss, step=self.global_step)
    self.log("Test Score", score, on_step=True)
    self.logger.experiment.log_metric(f"Dice score {stage}", score, step=self.global_step)

    batch_preds = preds.argmax(dim=1)
    if self.preds is None:
        self.preds = batch_preds
        self.targets = y_2d
    else:
        # Append (not prepend) so the accumulated tensors follow batch order.
        self.preds = torch.cat((self.preds, batch_preds), dim=0)
        self.targets = torch.cat((self.targets, y_2d), dim=0)
    return loss
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute the Dice coefficient between ``pred`` and ``target``.

    This is a thin wrapper around ``dice_score``. It forwards this module's
    configuration: background inclusion, the NaN and no-foreground fallback
    scores, and the reduction mode.

    Args:
        pred: predicted probability for each label.
        target: ground-truth labels.

    Returns:
        torch.Tensor: the Dice coefficient, reduced according to ``self.reduction``.
    """
    settings = {
        "bg": self.include_background,
        "nan_score": self.nan_score,
        "no_fg_score": self.no_fg_score,
        "reduction": self.reduction,
    }
    return dice_score(pred=pred, target=target, **settings)
def training_step(self, batch, batch_idx):
    """Compute and log the training loss for one batch.

    The loss is the sum of a cross-entropy term and a ``1 - dice`` term. It is
    logged to ``self.logger.experiment`` and via ``self.log``.

    Args:
        batch: ``(inputs, onehot)`` pair. ``onehot`` is reduced with
            ``argmax(axis=1)`` to integer labels (presumably one-hot masks — TODO confirm).
        batch_idx: index of the batch (unused here).

    Returns:
        The combined loss tensor, which Lightning backpropagates.
    """
    stage = "Training"
    inputs, onehot = batch
    logits = self(inputs)
    labels = onehot.argmax(axis=1).long()

    ce_term = F.cross_entropy(input=logits, target=labels)
    # NOTE(review): if dice_score argmaxes `logits` internally, this term
    # shifts the loss value but provides no gradient — confirm.
    dice_term = 1 - dice_score(pred=logits, target=labels)
    loss = ce_term + dice_term

    self.logger.experiment.log_metric(f"Loss {stage}", loss, step=self.global_step)
    self.log("Train Loss", loss, on_step=True)
    return loss
def test_dice_score(pred, target, expected):
    """Check ``dice_score`` on ``pred``/``target`` against ``expected``.

    The comparison uses a floating-point tolerance rather than exact equality.
    The score is a float tensor, so values such as 2/3 are not exactly
    representable. Using ``allclose`` also works when the score tensor has more
    than one element, whereas a bare ``assert tensor == x`` raises in that case.
    """
    score = dice_score(torch.tensor(pred), torch.tensor(target))
    expected_t = torch.as_tensor(expected, dtype=score.dtype, device=score.device)
    assert torch.allclose(score, expected_t), (
        f"dice_score returned {score}, expected {expected}"
    )