Example #1
    def validation_epoch_end(self, val_logs):
        # assert val_logs[0]["output"].ndim == 3
        device = val_logs[0]["device"]

        # run the visualizations
        self._visualize(
            val_outputs=[x["output_T2"].numpy() for x in val_logs],
            val_targets=[x["target_im_T2"].numpy() for x in val_logs],
        )

        # aggregate losses
        losses = []
        outputs = defaultdict(list)
        targets = defaultdict(list)

        for val_log in val_logs:
            losses.append(val_log["val_loss"])
            for i, (fname, slice_ind) in enumerate(
                    zip(val_log["fname"], val_log["slice"])):
                # need to check for duplicate slices
                if int(slice_ind) not in [s for (s, _) in outputs[int(fname)]]:
                    outputs[int(fname)].append(
                        (int(slice_ind), val_log["output_T2"][i]))
                    targets[int(fname)].append(
                        (int(slice_ind), val_log["target_im_T2"][i]))

        # handle aggregation for distributed case with pytorch_lightning metrics
        metrics = dict(val_loss=0, nmse=0, ssim=0, psnr=0)
        for fname in outputs:
            # stack slices in order, then drop the channel dim: (S, 1, H, W) -> (S, H, W)
            output = torch.stack(
                [out for _, out in sorted(outputs[fname])]).numpy()
            target = torch.stack(
                [tgt for _, tgt in sorted(targets[fname])]).numpy()
            output = output[:, 0, :, :]
            target = target[:, 0, :, :]
            metrics["nmse"] = metrics["nmse"] + evaluate.nmse(target, output)
            metrics["ssim"] = metrics["ssim"] + evaluate.ssim(target, output)
            metrics["psnr"] = metrics["psnr"] + evaluate.psnr(target, output)

        # currently ddp reduction requires everything on CUDA, so we'll do this manually
        metrics["nmse"] = self.NMSE(torch.tensor(metrics["nmse"]).to(device))
        metrics["ssim"] = self.SSIM(torch.tensor(metrics["ssim"]).to(device))
        metrics["psnr"] = self.PSNR(torch.tensor(metrics["psnr"]).to(device))
        metrics["val_loss"] = self.ValLoss(
            torch.sum(torch.stack(losses)).to(device))

        num_examples = torch.tensor(len(outputs)).to(device)
        tot_examples = self.TotExamples(num_examples)

        log_metrics = {
            f"metrics/{metric}": values / tot_examples
            for metric, values in metrics.items()
        }
        metrics = {
            metric: values / tot_examples
            for metric, values in metrics.items()
        }
        print(tot_examples, device, metrics)
        return dict(log=log_metrics, **metrics)
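
The example above (and the one that follows) relies on an `evaluate` module for the image-quality metrics, whose implementation is not shown. Below is a minimal sketch of what such helpers might look like, assuming NumPy inputs of shape (num_slices, H, W); the function names mirror the calls above, but the bodies are illustrative only.

import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity


def nmse(gt, pred):
    # normalized mean squared error: ||gt - pred||^2 / ||gt||^2
    return float(np.linalg.norm(gt - pred) ** 2 / np.linalg.norm(gt) ** 2)


def psnr(gt, pred, maxval=None):
    # peak signal-to-noise ratio in dB, using the target's max as the data range
    maxval = gt.max() if maxval is None else maxval
    return peak_signal_noise_ratio(gt, pred, data_range=maxval)


def ssim(gt, pred, maxval=None):
    # mean structural similarity, averaged over the slice dimension
    maxval = gt.max() if maxval is None else maxval
    return float(np.mean([
        structural_similarity(gt[i], pred[i], data_range=maxval)
        for i in range(gt.shape[0])
    ]))
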
Example #2
    def validation_step_end(self, val_logs):
        # check inputs
        for k in (
                "batch_idx",
                "fname",
                "slice_num",
                "max_value",
                "output",
                "target",
                "val_loss",
        ):
            if k not in val_logs.keys():
                raise RuntimeError(
                    f"Expected key {k} in dict returned by validation_step.")
        if val_logs["output"].ndim == 2:
            val_logs["output"] = val_logs["output"].unsqueeze(0)
        elif val_logs["output"].ndim != 3:
            raise RuntimeError("Unexpected output size from validation_step.")
        if val_logs["target"].ndim == 2:
            val_logs["target"] = val_logs["target"].unsqueeze(0)
        elif val_logs["target"].ndim != 3:
            raise RuntimeError("Unexpected target size from validation_step.")

        # pick a set of images to log if we don't have one already
        if self.val_log_indices is None:
            self.val_log_indices = list(
                np.random.permutation(len(
                    self.trainer.val_dataloaders[0]))[:self.num_log_images])

        # log images to tensorboard
        if isinstance(val_logs["batch_idx"], int):
            batch_indices = [val_logs["batch_idx"]]
        else:
            batch_indices = val_logs["batch_idx"]
        for i, batch_idx in enumerate(batch_indices):
            if batch_idx in self.val_log_indices:
                key = f"val_images_idx_{batch_idx}"
                target = val_logs["target"][i].unsqueeze(0)
                output = val_logs["output"][i].unsqueeze(0)
                error = torch.abs(target - output)
                output = output / output.max()
                target = target / target.max()
                error = error / error.max()
                self.logger.experiment.add_image(f"{key}/target",
                                                 target,
                                                 global_step=self.global_step)
                self.logger.experiment.add_image(f"{key}/reconstruction",
                                                 output,
                                                 global_step=self.global_step)
                self.logger.experiment.add_image(f"{key}/error",
                                                 error,
                                                 global_step=self.global_step)

        # compute evaluation metrics
        nmse_vals = defaultdict(dict)
        ssim_vals = defaultdict(dict)
        psnr_vals = defaultdict(dict)
        for i, fname in enumerate(val_logs["fname"]):
            slice_num = int(val_logs["slice_num"][i].cpu())
            maxval = val_logs["max_value"][i].cpu().numpy()
            output = val_logs["output"][i].cpu().numpy()
            target = val_logs["target"][i].cpu().numpy()

            nmse_vals[fname][slice_num] = torch.tensor(
                evaluate.nmse(target, output)).view(1)
            ssim_vals[fname][slice_num] = torch.tensor(
                evaluate.ssim(target, output, maxval=maxval)).view(1)
            psnr_vals[fname][slice_num] = torch.tensor(
                evaluate.psnr(target, output)).view(1)

        return {
            "val_loss": val_logs["val_loss"],
            "nmse_vals": nmse_vals,
            "ssim_vals": ssim_vals,
            "psnr_vals": psnr_vals,
        }
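
The per-file dictionaries returned above are meant to be reduced at the end of the validation epoch: metric values are first averaged over the slices of each volume and then over volumes. That companion hook is not part of this example; the sketch below shows one plausible aggregation, assuming the return format of `validation_step_end` above and Lightning's `self.log` API.

    def validation_epoch_end(self, val_logs):
        # merge the per-slice metric dicts produced by validation_step_end
        losses = []
        nmse_vals = defaultdict(dict)
        ssim_vals = defaultdict(dict)
        psnr_vals = defaultdict(dict)
        for val_log in val_logs:
            losses.append(val_log["val_loss"].view(-1))
            for fname, slices in val_log["nmse_vals"].items():
                nmse_vals[fname].update(slices)
            for fname, slices in val_log["ssim_vals"].items():
                ssim_vals[fname].update(slices)
            for fname, slices in val_log["psnr_vals"].items():
                psnr_vals[fname].update(slices)

        # average over slices within each volume, then over volumes
        metrics = dict(nmse=0.0, ssim=0.0, psnr=0.0)
        for fname in nmse_vals:
            metrics["nmse"] += torch.mean(
                torch.cat([v.view(-1) for v in nmse_vals[fname].values()]))
            metrics["ssim"] += torch.mean(
                torch.cat([v.view(-1) for v in ssim_vals[fname].values()]))
            metrics["psnr"] += torch.mean(
                torch.cat([v.view(-1) for v in psnr_vals[fname].values()]))

        num_volumes = len(nmse_vals)
        self.log("val_loss", torch.mean(torch.cat(losses)), prog_bar=True)
        for metric, value in metrics.items():
            self.log(f"val_metrics/{metric}", value / num_volumes)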