def validation_epoch_end(self, val_logs):
    """Aggregate per-batch validation logs into volume-level metrics.

    Slices are grouped by volume (``int(fname)``), de-duplicated by slice
    index, stacked in slice order, and scored with ``evaluate.nmse``,
    ``evaluate.ssim`` and ``evaluate.psnr``. The summed scores are passed
    through ``self.NMSE`` / ``self.SSIM`` / ``self.PSNR`` / ``self.ValLoss``
    and divided by the count returned from ``self.TotExamples``.

    Args:
        val_logs: list of per-batch dicts. Keys read here: ``"device"``
            (first entry only), ``"output_T2"``, ``"target_im_T2"``,
            ``"val_loss"``, ``"fname"`` and ``"slice"``.

    Returns:
        dict with a ``"log"`` entry mapping ``"metrics/<name>"`` to the
        averaged value, plus the same averaged values under their bare names
        (``val_loss``, ``nmse``, ``ssim``, ``psnr``).
    """
    device = val_logs[0]["device"]

    # run the visualizations
    self._visualize(
        val_outputs=[x["output_T2"].numpy() for x in val_logs],
        val_targets=[x["target_im_T2"].numpy() for x in val_logs],
    )

    # Group slices by volume: volume id -> {slice index -> (output, target)}.
    # Keying on the normalized int slice index makes duplicate detection O(1)
    # and independent of the raw index type; the first occurrence wins.
    losses = []
    volumes = defaultdict(dict)
    for val_log in val_logs:
        losses.append(val_log["val_loss"])
        for i, (fname, slice_ind) in enumerate(
                zip(val_log["fname"], val_log["slice"])):
            slices = volumes[int(fname)]
            slice_idx = int(slice_ind)
            if slice_idx not in slices:
                slices[slice_idx] = (val_log["output_T2"][i],
                                     val_log["target_im_T2"][i])

    # handle aggregation for distributed case with pytorch_lightning metrics
    metrics = dict(val_loss=0, nmse=0, ssim=0, psnr=0)
    for slices in volumes.values():
        # Sort on the (unique) slice index only, so tensors are never compared.
        ordered = [slices[idx] for idx in sorted(slices)]
        output = torch.stack([out for out, _ in ordered]).numpy()
        target = torch.stack([tgt for _, tgt in ordered]).numpy()
        # Stacked arrays are 4-D (original author noted e.g. (2, 1, 256, 256));
        # only channel 0 of each slice is scored.
        output = output[:, 0, :, :]
        target = target[:, 0, :, :]
        metrics["nmse"] = metrics["nmse"] + evaluate.nmse(target, output)
        metrics["ssim"] = metrics["ssim"] + evaluate.ssim(target, output)
        metrics["psnr"] = metrics["psnr"] + evaluate.psnr(target, output)

    # currently ddp reduction requires everything on CUDA, so we'll do this manually
    metrics["nmse"] = self.NMSE(torch.tensor(metrics["nmse"]).to(device))
    metrics["ssim"] = self.SSIM(torch.tensor(metrics["ssim"]).to(device))
    metrics["psnr"] = self.PSNR(torch.tensor(metrics["psnr"]).to(device))
    # NOTE(review): val_loss is summed over batches but divided by the number
    # of volumes below, so it is not a per-batch mean — confirm intended.
    metrics["val_loss"] = self.ValLoss(
        torch.sum(torch.stack(losses)).to(device))

    num_examples = torch.tensor(len(volumes)).to(device)
    tot_examples = self.TotExamples(num_examples)

    log_metrics = {
        f"metrics/{metric}": values / tot_examples
        for metric, values in metrics.items()
    }
    metrics = {
        metric: values / tot_examples for metric, values in metrics.items()
    }
    print(tot_examples, device, metrics)
    return dict(log=log_metrics, **metrics)
def validation_step_end(self, val_logs):
    """Check a validation-step result, log sample images and score each slice.

    ``val_logs["output"]`` and ``val_logs["target"]`` are promoted from 2-D to
    3-D in place (a batch of one) when needed. Images whose batch index is in
    ``self.val_log_indices`` are written to ``self.logger.experiment`` as
    target / reconstruction / error images scaled by their maximum.

    Args:
        val_logs: dict returned by ``validation_step``; must contain
            ``batch_idx``, ``fname``, ``slice_num``, ``max_value``,
            ``output``, ``target`` and ``val_loss``.

    Returns:
        dict with ``val_loss`` passed through, and ``nmse_vals``,
        ``ssim_vals``, ``psnr_vals`` as nested ``{fname: {slice_num: t}}``
        mappings where each ``t`` is a 1-element tensor.

    Raises:
        RuntimeError: if a required key is missing, or ``output`` / ``target``
            is neither 2-D nor 3-D.
    """
    # check inputs
    for k in (
            "batch_idx",
            "fname",
            "slice_num",
            "max_value",
            "output",
            "target",
            "val_loss",
    ):
        if k not in val_logs:
            raise RuntimeError(
                f"Expected key {k} in dict returned by validation_step.")
    if val_logs["output"].ndim == 2:
        val_logs["output"] = val_logs["output"].unsqueeze(0)
    elif val_logs["output"].ndim != 3:
        raise RuntimeError("Unexpected output size from validation_step.")
    if val_logs["target"].ndim == 2:
        val_logs["target"] = val_logs["target"].unsqueeze(0)
    elif val_logs["target"].ndim != 3:
        raise RuntimeError("Unexpected target size from validation_step.")

    # pick a set of images to log if we don't have one already
    if self.val_log_indices is None:
        self.val_log_indices = list(
            np.random.permutation(len(
                self.trainer.val_dataloaders[0]))[:self.num_log_images])

    def _scale_to_max(img):
        # Divide by the maximum for display; an all-zero image (e.g. the error
        # map of a perfect reconstruction) is returned unchanged instead of
        # becoming 0/0 = NaN.
        peak = img.max()
        return img / peak if peak != 0 else img

    # log images to tensorboard
    if isinstance(val_logs["batch_idx"], int):
        batch_indices = [val_logs["batch_idx"]]
    else:
        batch_indices = val_logs["batch_idx"]
    for i, batch_idx in enumerate(batch_indices):
        if batch_idx in self.val_log_indices:
            key = f"val_images_idx_{batch_idx}"
            target = val_logs["target"][i].unsqueeze(0)
            output = val_logs["output"][i].unsqueeze(0)
            # error is computed before scaling so it reflects raw intensities
            error = torch.abs(target - output)
            output = _scale_to_max(output)
            target = _scale_to_max(target)
            error = _scale_to_max(error)
            self.logger.experiment.add_image(f"{key}/target",
                                             target,
                                             global_step=self.global_step)
            self.logger.experiment.add_image(f"{key}/reconstruction",
                                             output,
                                             global_step=self.global_step)
            self.logger.experiment.add_image(f"{key}/error",
                                             error,
                                             global_step=self.global_step)

    # compute evaluation metrics
    nmse_vals = defaultdict(dict)
    ssim_vals = defaultdict(dict)
    psnr_vals = defaultdict(dict)
    for i, fname in enumerate(val_logs["fname"]):
        slice_num = int(val_logs["slice_num"][i].cpu())
        maxval = val_logs["max_value"][i].cpu().numpy()
        output = val_logs["output"][i].cpu().numpy()
        target = val_logs["target"][i].cpu().numpy()
        nmse_vals[fname][slice_num] = torch.tensor(
            evaluate.nmse(target, output)).view(1)
        ssim_vals[fname][slice_num] = torch.tensor(
            evaluate.ssim(target, output, maxval=maxval)).view(1)
        # NOTE(review): ssim receives maxval but psnr does not — confirm
        # whether evaluate.psnr accepts/should use it.
        psnr_vals[fname][slice_num] = torch.tensor(
            evaluate.psnr(target, output)).view(1)

    return {
        "val_loss": val_logs["val_loss"],
        "nmse_vals": nmse_vals,
        "ssim_vals": ssim_vals,
        "psnr_vals": psnr_vals,
    }