def validation_step(self, batch, batch_idx: int):
        """Run one validation batch: compute loss and image-quality metrics.

        Args:
            batch: dict with "image" (model input) and "label" (target) tensors.
            batch_idx: index of this batch within the validation epoch.

        Returns:
            dict with MAE (0-255-scaled, background zeroed), MAE_mask
            (restricted to non-background voxels), MSE, SSIM and PSNR.
        """
        inputs, targets = batch["image"], batch["label"]

        logits = self(inputs)
        loss = self.criterion(logits.view(-1), targets.view(-1))

        # Log one sample image every `log_mod` epochs, first batch only.
        if self.current_epoch % self.hparams.log_mod == 0 and batch_idx == 0:
            log_all_info(
                module=self,
                target=targets[0],
                preb=logits[0],
                loss=loss,
                batch_idx=batch_idx,
                state="val",
            )
        self.log("val_loss", loss, sync_dist=True, on_step=True, on_epoch=True)

        targets = targets.cpu().detach().numpy().squeeze()
        predicts = logits.cpu().detach().numpy().squeeze()

        # Best-effort dump of the first few batches for offline inspection.
        if batch_idx <= 5:
            np.savez(f"{batch_idx}.npz", target=targets, predict=predicts)

        mse = np.square(np.subtract(targets, predicts)).mean()
        # Hoisted: both SSIM and PSNR use the same prediction value range.
        data_range = predicts.max() - predicts.min()
        ssim_ = ssim(targets, predicts, data_range=data_range)
        psnr_ = psnr(targets, predicts, data_range=data_range)

        # Voxels equal to the corner value are treated as background
        # -- assumes the corner is always background; TODO confirm.
        brain_mask = targets == targets[0][0][0]

        # Clip to [-clip_min, clip_max], shift to be non-negative, then
        # rescale to an integer 0-255-style range for MAE computation.
        pred_clip = np.clip(predicts, -self.clip_min, self.clip_max) - min(
            -self.clip_min, np.min(predicts))
        targ_clip = np.clip(targets, -self.clip_min, self.clip_max) - min(
            -self.clip_min, np.min(targets))
        pred_255 = np.floor(256 * (pred_clip /
                                   (self.clip_min + self.clip_max)))
        targ_255 = np.floor(256 * (targ_clip /
                                   (self.clip_min + self.clip_max)))
        pred_255[brain_mask] = 0
        targ_255[brain_mask] = 0

        diff_255 = np.absolute(pred_255.ravel() - targ_255.ravel())
        mae = np.mean(diff_255)

        # BUG FIX: the masked MAE was previously computed as a bare
        # `np.mean(...)` expression whose result was discarded; assign it
        # and report it, matching test_step and the intent of the old
        # commented-out return statement.
        diff_255_mask = np.absolute(pred_255[~brain_mask].ravel() -
                                    targ_255[~brain_mask].ravel())
        mae_mask = np.mean(diff_255_mask)

        return {
            "MAE": mae,
            "MAE_mask": mae_mask,
            "MSE": mse,
            "SSIM": ssim_,
            "PSNR": psnr_,
        }
    def test_step(self, batch, batch_idx: int):
        """Evaluate one test batch and report MAE, masked MAE and MSE."""
        images, labels = batch["image"], batch["label"]
        preds = self(images)
        loss = self.criterion(preds.view(-1), labels.view(-1))

        # Log example volumes for the first 16 batches.
        # NOTE(review): `state` is "val" even though this is test_step —
        # looks like a copy-paste; confirm the intended label.
        if batch_idx <= 15:
            log_all_info(
                module=self,
                target=labels,
                preb=preds,
                loss=loss,
                batch_idx=batch_idx,
                state="val",
            )

        labels_np = labels.cpu().detach().numpy().squeeze()
        preds_np = preds.cpu().detach().numpy().squeeze()

        mse = np.square(np.subtract(labels_np, preds_np)).mean()

        # Dump the first few batches for offline inspection.
        if batch_idx <= 5:
            np.savez(f"{batch_idx}.npz", target=labels_np, predict=preds_np)

        # Voxels equal to the corner value are treated as background.
        background = labels_np == labels_np[0][0][0]

        # Clip, shift non-negative, and rescale to a 0-255-style range.
        lo, hi = self.clip_min, self.clip_max
        shifted_pred = np.clip(preds_np, -lo, hi) - min(-lo, np.min(preds_np))
        shifted_targ = np.clip(labels_np, -lo, hi) - min(-lo, np.min(labels_np))
        scaled_pred = np.floor(256 * (shifted_pred / (lo + hi)))
        scaled_targ = np.floor(256 * (shifted_targ / (lo + hi)))
        scaled_pred[background] = 0
        scaled_targ[background] = 0

        mae = np.mean(np.absolute(scaled_pred.ravel() - scaled_targ.ravel()))
        mae_mask = np.mean(
            np.absolute(scaled_pred[~background].ravel() -
                        scaled_targ[~background].ravel()))

        return {"MAE": mae, "MAE_mask": mae_mask, "MSE": mse}
# Beispiel #3
# 0
    def validation_step(self, batch, batch_idx: int):
        """Validate on 2D slices: log the loss and return MAE metrics."""
        images, labels = batch
        # Drop singleton dimensions so the model sees 2D slices.
        images = torch.squeeze(images)
        labels = torch.squeeze(labels)

        preds = self(images)
        loss = self.criterion(preds.view(-1), labels.view(-1))

        # Log one example every 15 epochs, first batch only.
        if batch_idx == 0 and self.current_epoch % 15 == 0:
            log_all_info(
                module=self,
                target=labels,
                preb=preds,
                loss=loss,
                batch_idx=batch_idx,
                state="val",
            )
        self.log("val_loss", loss, sync_dist=True, on_step=True, on_epoch=True)

        labels_np = labels.cpu().detach().numpy().squeeze()
        preds_np = preds.cpu().detach().numpy().squeeze()

        # Voxels equal to the corner value are treated as background.
        background = labels_np == labels_np[0][0][0]

        # Clip, shift non-negative, and rescale to a 0-255-style range.
        lo, hi = self.clip_min, self.clip_max
        shifted_pred = np.clip(preds_np, -lo, hi) - min(-lo, np.min(preds_np))
        shifted_targ = np.clip(labels_np, -lo, hi) - min(-lo, np.min(labels_np))
        scaled_pred = np.floor(256 * (shifted_pred / (lo + hi)))
        scaled_targ = np.floor(256 * (shifted_targ / (lo + hi)))
        scaled_pred[background] = 0
        scaled_targ[background] = 0

        mae = np.mean(np.absolute(scaled_pred.ravel() - scaled_targ.ravel()))
        mae_mask = np.mean(
            np.absolute(scaled_pred[~background].ravel() -
                        scaled_targ[~background].ravel()))

        return {"MAE": mae, "MAE_mask": mae_mask}
    def training_step(self, batch, batch_idx: int):
        """One training step: forward pass, loss, periodic sample logging."""
        images, labels = batch["image"], batch["label"]
        preds = self(images)
        loss = self.criterion(preds.view(-1), labels.view(-1))

        # Log the first sample of the first batch every `log_mod` epochs.
        if batch_idx == 0 and self.current_epoch % self.hparams.log_mod == 0:
            log_all_info(
                module=self,
                target=labels[0],
                preb=preds[0],
                loss=loss,
                batch_idx=batch_idx,
                state="train",
            )
        self.log(
            "train_loss", loss, sync_dist=True, on_step=True, on_epoch=True)
        return {"loss": loss}
# Beispiel #5
# 0
    def training_step(self, batch, batch_idx: int):
        """One training step on 2D slices: forward, loss, periodic logging."""
        images, labels = batch
        # Drop singleton dimensions so the model trains on 2D slices.
        images = torch.squeeze(images)
        labels = torch.squeeze(labels)

        preds = self(images)
        loss = self.criterion(preds.view(-1), labels.view(-1))

        # Log one example every 15 epochs, first batch only.
        if batch_idx == 0 and self.current_epoch % 15 == 0:
            log_all_info(
                module=self,
                target=labels,
                preb=preds,
                loss=loss,
                batch_idx=batch_idx,
                state="train",
            )

        self.log(
            "train_loss", loss, sync_dist=True, on_step=True, on_epoch=True)
        return {"loss": loss}