def validation_step(self, batch, batch_idx):
    """Run the model on one validation batch and package the results.

    File names are strings, so each one is mapped to a stable integer (its
    SHA-256 digest reduced modulo 10**12). This lets the names travel as a
    ``torch.long`` tensor on the same device as the model output.

    Returns:
        dict with keys "fname" (hashed names), "slice", "output", "target"
        (both center-cropped to a common size) and "val_loss".
    """
    masked_kspace, mask, target, fname, slice_num, max_value, _ = batch
    prediction = self.forward(masked_kspace, mask)
    cropped_target, cropped_prediction = T.center_crop_to_smallest(target, prediction)

    # hash strings to int so pytorch can concat them
    name_ids = [
        int.from_bytes(hashlib.sha256(name.encode("utf-8")).digest(), "big") % 10**12
        for name in fname
    ]
    hashed_names = torch.tensor(name_ids, dtype=torch.long, device=cropped_prediction.device)

    # unsqueeze(1) inserts a channel axis for the loss call.
    val_loss = self.loss(
        cropped_prediction.unsqueeze(1),
        cropped_target.unsqueeze(1),
        data_range=max_value,
    )
    return dict(
        fname=hashed_names,
        slice=slice_num,
        output=cropped_prediction,
        target=cropped_target,
        val_loss=val_loss,
    )
def training_step(self, batch, batch_idx):
    """Compute the training loss for one batch.

    Returns:
        dict holding the loss tensor under "loss" and its Python scalar value
        under ["log"]["train_loss"].
    """
    kspace, sampling_mask, ground_truth, _, _, data_range, _ = batch
    reconstruction = self(kspace, sampling_mask)
    ground_truth, reconstruction = T.center_crop_to_smallest(ground_truth, reconstruction)

    # unsqueeze(1) inserts a channel axis for the loss call.
    train_loss = self.loss(
        reconstruction.unsqueeze(1),
        ground_truth.unsqueeze(1),
        data_range=data_range,
    )
    log_entry = {"train_loss": train_loss.item()}
    return {"loss": train_loss, "log": log_entry}
def training_step(self, batch, batch_idx):
    """Compute, log ("train_loss") and return the training loss for one batch."""
    kspace, sampling_mask, ground_truth, _, _, data_range, _ = batch
    reconstruction = self(kspace, sampling_mask)
    ground_truth, reconstruction = transforms.center_crop_to_smallest(
        ground_truth, reconstruction
    )

    # unsqueeze(1) inserts a channel axis for the loss call.
    train_loss = self.loss(
        reconstruction.unsqueeze(1),
        ground_truth.unsqueeze(1),
        data_range=data_range,
    )
    self.log("train_loss", train_loss)
    return train_loss
def training_step(self, batch, batch_idx):
    """Compute, log ("train_loss") and return the training loss for one batch.

    ``batch`` is accessed by attribute (masked_kspace, mask,
    num_low_frequencies, target, max_value).
    """
    reconstruction = self(
        batch.masked_kspace,
        batch.mask,
        batch.num_low_frequencies,
    )
    ground_truth, reconstruction = transforms.center_crop_to_smallest(
        batch.target, reconstruction
    )

    # unsqueeze(1) inserts a channel axis for the loss call.
    train_loss = self.loss(
        reconstruction.unsqueeze(1),
        ground_truth.unsqueeze(1),
        data_range=batch.max_value,
    )
    self.log("train_loss", train_loss)
    return train_loss
def training_loss(self, batch):
    """Compute the L1 training loss for one batch.

    The prediction and target come from ``self.predict(batch)`` and are
    center-cropped to a common size before the loss is taken.

    Returns:
        Tuple ``(loss, output, target)`` where ``loss`` is
        ``F.l1_loss(output, target)`` and ``output``/``target`` are the
        cropped tensors.

    Raises:
        FloatingPointError: if ``self.args.nan_detection`` is enabled and the
            output contains NaN ("nan encountered") or infinite
            ("inf encountered") values. It subclasses ``Exception``, so
            existing ``except Exception`` handlers still catch it.
    """
    output, target = self.predict(batch)
    output, target = transforms.center_crop_to_smallest(output, target)
    if self.args.nan_detection:
        # Dump the offending tensor before aborting to aid debugging.
        if torch.any(torch.isnan(output)):
            print(output)
            raise FloatingPointError("nan encountered")
        if torch.any(torch.isinf(output)):
            print(output)
            raise FloatingPointError("inf encountered")
    loss = F.l1_loss(output, target)
    return loss, output, target
def validation_step(self, batch, batch_idx):
    """Validate one batch for a model whose forward pass predicts k-space.

    The predicted k-space is converted to an image with ``fastmri.ifft2c``
    followed by ``fastmri.complex_abs``. That image and the target are
    center-cropped to a common size and returned.

    NOTE(review): ``val_loss`` compares the predicted k-space with the
    *masked input* k-space, not with the target. Confirm this is intended.
    """
    kspace_in, sampling_mask, target, fname, slice_num, max_value, _ = batch
    predicted_kspace = self(kspace_in, sampling_mask)
    image = fastmri.complex_abs(fastmri.ifft2c(predicted_kspace))
    cropped_target, cropped_image = transforms.center_crop_to_smallest(target, image)
    val_loss = self.loss(predicted_kspace, kspace_in)
    return dict(
        batch_idx=batch_idx,
        fname=fname,
        slice_num=slice_num,
        max_value=max_value,
        output=cropped_image,
        target=cropped_target,
        val_loss=val_loss,
    )
def validation_step(self, batch, batch_idx):
    """Run one validation batch and return what epoch-end metrics need.

    The prediction from ``self.forward`` and the target are center-cropped to
    a common size via ``transforms.center_crop_to_smallest`` before scoring.
    """
    masked_kspace, mask, target, fname, slice_num, max_value, _ = batch
    prediction = self.forward(masked_kspace, mask)
    cropped_target, cropped_prediction = transforms.center_crop_to_smallest(
        target, prediction
    )

    # unsqueeze(1) inserts a channel axis for the loss call.
    val_loss = self.loss(
        cropped_prediction.unsqueeze(1),
        cropped_target.unsqueeze(1),
        data_range=max_value,
    )
    return dict(
        batch_idx=batch_idx,
        fname=fname,
        slice_num=slice_num,
        max_value=max_value,
        output=cropped_prediction,
        target=cropped_target,
        val_loss=val_loss,
    )
def validation_step(self, batch, batch_idx):
    """Run one validation batch (attribute-style ``batch``) and package results.

    The prediction from ``self.forward`` and ``batch.target`` are
    center-cropped to a common size before the loss is computed.
    """
    prediction = self.forward(
        batch.masked_kspace,
        batch.mask,
        batch.num_low_frequencies,
    )
    cropped_target, cropped_prediction = transforms.center_crop_to_smallest(
        batch.target, prediction
    )

    # unsqueeze(1) inserts a channel axis for the loss call.
    val_loss = self.loss(
        cropped_prediction.unsqueeze(1),
        cropped_target.unsqueeze(1),
        data_range=batch.max_value,
    )
    return dict(
        batch_idx=batch_idx,
        fname=batch.fname,
        slice_num=batch.slice_num,
        max_value=batch.max_value,
        output=cropped_prediction,
        target=cropped_target,
        val_loss=val_loss,
    )
def compute_stats(self, epoch, loader, setname):
    """ This is separate from stats mainly for distributed support

    Evaluate the model over ``self.dev_loader`` and return volume-level metrics.

    For every batch, predictions and targets from ``self.predict`` are passed
    through ``self.unnorm`` and collected per file name together with their
    slice index. Afterwards each file's slices are stacked in slice order,
    center-cropped to a common size, and scored with ``evaluate.ssim``,
    ``evaluate.nmse`` and ``evaluate.psnr``.

    Args:
        epoch: Current epoch. ``epoch + batch_idx / n_batches`` is reported to
            ``start_of_test_batch_hook`` as progress.
        loader: NOTE(review): unused. The body always iterates
            ``self.dev_loader``. Confirm whether callers expect this loader
            to be used instead.
        setname: NOTE(review): unused.

    Returns:
        dict with 'NMSE', 'PSNR', 'SSIM' (means over volumes), 'SSIM_var',
        'SSIM_min', plus one mean-SSIM entry per key of
        ``self.dev_data.system_acquisitions`` (0 for keys with no volumes).
        If no volumes were evaluated, the metric lists are padded with a
        single 0 so the summary statistics are 0.
    """
    args = self.args
    self.model.eval()
    ndevbatches = len(self.dev_loader)
    logging.info(f"Evaluating {ndevbatches} batches ...")

    # fname -> list of (slice index, per-slice array)
    recons, gts = defaultdict(list), defaultdict(list)
    acquisition_machine_by_fname = dict()

    with torch.no_grad():
        for batch_idx, batch in enumerate(self.dev_loader):
            progress = epoch + batch_idx / ndevbatches
            logging_epoch = batch_idx % args.log_interval == 0
            logging_epoch_info = batch_idx % (2 * args.log_interval) == 0
            log = logging.info if logging_epoch_info else logging.debug

            self.start_of_test_batch_hook(progress, logging_epoch)

            batch = self.preprocess_data(batch)
            output, target = self.predict(batch)
            output = self.unnorm(output, batch)
            target = self.unnorm(target, batch)
            fname, slice_indices = batch.fname, batch.slice

            for i in range(output.shape[0]):
                slice_cpu = slice_indices[i].item()
                recons[fname[i]].append(
                    (slice_cpu, output[i].float().cpu().numpy()))
                gts[fname[i]].append(
                    (slice_cpu, target[i].float().cpu().numpy()))

                acquisition_type = batch.attrs_dict['acquisition'][i]
                machine_type = batch.attrs_dict['system'][i]
                acquisition_machine_by_fname[
                    fname[i]] = machine_type + '_' + acquisition_type

            if logging_epoch or batch_idx == ndevbatches - 1:
                gpu_memory_gb = torch.cuda.memory_allocated() / 1000000000
                host_memory_gb = utils.host_memory_usage_in_gb()
                log(f"Evaluated {batch_idx+1} of {ndevbatches} (GPU Mem: {gpu_memory_gb:2.3f}gb Host Mem: {host_memory_gb:2.3f}gb)"
                    )
                sys.stdout.flush()

            if self.args.debug_epoch_stats:
                break
            del output, target, batch

    logging.debug("Finished evaluating")
    self.end_of_test_epoch_hook()

    def _stack_by_slice(slice_items):
        # Order by slice index only. A plain tuple sort would fall back to
        # comparing the numpy arrays whenever two entries share a slice index,
        # which raises ValueError. sorted() is stable, so ties keep insertion order.
        ordered = sorted(slice_items, key=lambda item: item[0])
        return np.stack([arr for _, arr in ordered])

    recons = {
        fname: _stack_by_slice(slice_preds)
        for fname, slice_preds in recons.items()
    }
    gts = {
        fname: _stack_by_slice(slice_gts)
        for fname, slice_gts in gts.items()
    }

    nmse, psnr, ssims = [], [], []
    ssim_for_acquisition_machine = defaultdict(list)
    # Snapshot the keys: entries are deleted from recons/gts inside the loop.
    recon_keys = list(recons.keys())
    for fname in recon_keys:
        pred_or, gt_or = recons[fname].squeeze(1), gts[fname].squeeze(1)
        pred, gt = transforms.center_crop_to_smallest(pred_or, gt_or)
        del pred_or, gt_or

        ssim = evaluate.ssim(gt, pred)
        acquisition_machine = acquisition_machine_by_fname[fname]
        ssim_for_acquisition_machine[acquisition_machine].append(ssim)
        ssims.append(ssim)
        nmse.append(evaluate.nmse(gt, pred))
        psnr.append(evaluate.psnr(gt, pred))
        del gt, pred
        # Free each volume as soon as it has been scored.
        del recons[fname], gts[fname]

    if len(nmse) == 0:
        # Nothing was evaluated: pad with zeros so the summary statistics
        # below are defined, and skip the worst-volume lookup (recon_keys is
        # empty, so indexing it would raise IndexError).
        nmse.append(0)
        ssims.append(0)
        psnr.append(0)
    else:
        min_vol_ssim = np.argmin(ssims)
        min_vol = str(recon_keys[min_vol_ssim])
        logging.info(f"Min vol ssims: {min_vol}")
    sys.stdout.flush()

    del recons, gts

    acquisition_machine_losses = dict.fromkeys(
        self.dev_data.system_acquisitions, 0)
    for key, value in ssim_for_acquisition_machine.items():
        acquisition_machine_losses[key] = np.mean(value)

    losses = {
        'NMSE': np.mean(nmse),
        'PSNR': np.mean(psnr),
        'SSIM': np.mean(ssims),
        'SSIM_var': np.var(ssims),
        'SSIM_min': np.min(ssims),
        **acquisition_machine_losses,
    }
    return losses
def test_center_crop_to_smallest(x_shape, y_shape, target_shape):
    """Cropping two inputs yields equal shapes that match ``target_shape``."""
    cropped_x, cropped_y = transforms.center_crop_to_smallest(
        create_input(x_shape), create_input(y_shape)
    )
    assert cropped_x.shape == cropped_y.shape
    assert list(cropped_x.shape) == target_shape