def save_zero_filled(data_dir, out_dir, which_challenge):
    """Compute zero-filled reconstructions for every ``.h5`` file in ``data_dir``.

    For each volume the stored k-space is inverse Fourier transformed
    (``fastmri.ifft2c``), center-cropped to the ``reconSpace`` matrix size from
    the ISMRMRD header, converted with ``fastmri.complex_abs`` and, for
    multicoil data, coil-combined with ``fastmri.rss`` along ``dim=1``.

    Args:
        data_dir: ``pathlib.Path`` directory holding the HDF5 volumes.
        out_dir: destination passed to ``fastmri.save_reconstructions``.
        which_challenge: ``"multicoil"`` enables the root-sum-of-squares
            combination; any other value skips it.
    """
    reconstructions = {}

    # Only open HDF5 volumes: iterdir() would also yield sub-directories and
    # stray files that h5py cannot open. Sorting makes the order deterministic.
    for fname in sorted(data_dir.glob("*.h5")):
        with h5py.File(fname, "r") as hf:
            enc = ismrmrd.xsd.CreateFromDocument(hf["ismrmrd_header"][()]).encoding[0]
            masked_kspace = transforms.to_tensor(hf["kspace"][()])

            # extract target image width, height from ismrmrd header
            crop_size = (enc.reconSpace.matrixSize.x, enc.reconSpace.matrixSize.y)

            # inverse Fourier Transform to get zero filled solution
            image = fastmri.ifft2c(masked_kspace)

            # check for FLAIR 203: if the image is narrower than the requested
            # crop, fall back to a square crop of the available size
            if image.shape[-2] < crop_size[1]:
                crop_size = (image.shape[-2], image.shape[-2])

            # crop input image
            image = transforms.complex_center_crop(image, crop_size)

            # absolute value
            image = fastmri.complex_abs(image)

            # apply Root-Sum-of-Squares if multicoil data
            if which_challenge == "multicoil":
                image = fastmri.rss(image, dim=1)

        reconstructions[fname.name] = image

    fastmri.save_reconstructions(reconstructions, out_dir)
def save_zero_filled(data_dir, out_dir, which_challenge):
    """Compute zero-filled reconstructions for every ``.h5`` file in ``data_dir``.

    For each volume the stored k-space is inverse Fourier transformed
    (``fastmri.ifft2c``), center-cropped to the reconstruction matrix size
    (``encoding/reconSpace/matrixSize``) from the ISMRMRD header, converted with
    ``fastmri.complex_abs`` and, for multicoil data, coil-combined with
    ``fastmri.rss`` along ``dim=1``.

    Args:
        data_dir: ``pathlib.Path`` directory holding the HDF5 volumes.
        out_dir: destination passed to ``fastmri.save_reconstructions``.
        which_challenge: ``"multicoil"`` enables the root-sum-of-squares
            combination; any other value skips it.
    """
    reconstructions = {}

    # sorted() gives a deterministic processing order
    for fname in tqdm(sorted(data_dir.glob("*.h5"))):
        with h5py.File(fname, "r") as hf:
            et_root = etree.fromstring(hf["ismrmrd_header"][()])
            masked_kspace = transforms.to_tensor(hf["kspace"][()])

            # extract target image width, height from ismrmrd header.
            # The target size is the *reconstruction* matrix; encodedSpace is
            # the (possibly oversampled) acquisition matrix and would yield
            # images that do not match the targets.
            rec = ["encoding", "reconSpace", "matrixSize"]
            crop_size = (
                int(et_query(et_root, rec + ["x"])),
                int(et_query(et_root, rec + ["y"])),
            )

            # inverse Fourier Transform to get zero filled solution
            image = fastmri.ifft2c(masked_kspace)

            # check for FLAIR 203: if the image is narrower than the requested
            # crop, fall back to a square crop of the available size
            if image.shape[-2] < crop_size[1]:
                crop_size = (image.shape[-2], image.shape[-2])

            # crop input image
            image = transforms.complex_center_crop(image, crop_size)

            # absolute value
            image = fastmri.complex_abs(image)

            # apply Root-Sum-of-Squares if multicoil data
            if which_challenge == "multicoil":
                image = fastmri.rss(image, dim=1)

        reconstructions[fname.name] = image

    fastmri.save_reconstructions(reconstructions, out_dir)
def run_inference(challenge, state_dict_file, data_path, output_path, device):
    """Run a pretrained U-Net over a dataset and save per-volume reconstructions.

    Args:
        challenge: key into ``MODEL_FNAMES``; names containing ``"_mc"`` are run
            as multicoil, all others as singlecoil.
        state_dict_file: path to a saved ``state_dict``. If ``None``, the default
            file for ``challenge`` is used, downloaded from ``UNET_FOLDER`` first
            if it does not exist locally.
        data_path: root directory handed to ``SliceDataset``.
        output_path: ``pathlib.Path``; results are written to
            ``output_path / "reconstructions"``.
        device: torch device the model runs on.
    """
    model = Unet(in_chans=1, out_chans=1, chans=256, num_pool_layers=4, drop_prob=0.0)

    # download the state_dict if we don't have it
    if state_dict_file is None:
        if not Path(MODEL_FNAMES[challenge]).exists():
            url_root = UNET_FOLDER
            download_model(url_root + MODEL_FNAMES[challenge], MODEL_FNAMES[challenge])
        state_dict_file = MODEL_FNAMES[challenge]

    # Deserialize onto the CPU: weights saved from a GPU run would otherwise
    # fail to load on a CPU-only machine. The model is moved to `device` below.
    model.load_state_dict(torch.load(state_dict_file, map_location="cpu"))
    model = model.eval()

    # data loader setup
    which_challenge = "multicoil" if "_mc" in challenge else "singlecoil"
    data_transform = T.UnetDataTransform(which_challenge=which_challenge)
    dataset = SliceDataset(
        root=data_path,
        transform=data_transform,
        challenge=which_challenge,
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_unet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs: stack each volume's slices in slice order
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time-start_time}")
def save_outputs(outputs, output_path):
    """Saves reconstruction outputs to output_path.

    Args:
        outputs: iterable of ``(fname, slice_num, pred)`` triples.
        output_path: destination handed to ``fastmri.save_reconstructions``.

    Predictions are grouped per file, ordered by ``(slice_num, pred)`` and
    stacked with ``np.stack`` into one array per file.
    """
    by_file = defaultdict(list)
    for file_name, slice_idx, prediction in outputs:
        by_file[file_name].append((slice_idx, prediction))

    stacked = {}
    for file_name, pairs in by_file.items():
        pairs.sort()  # same tuple ordering as sorted(pairs)
        stacked[file_name] = np.stack([prediction for _, prediction in pairs])

    fastmri.save_reconstructions(stacked, output_path)
def run_inference(checkpoint, data_path, output_path):
    """Run a VarNet from a training checkpoint over multicoil data and save results.

    Args:
        checkpoint: path to a checkpoint whose ``"state_dict"`` entry holds the
            VarNet weights under a ``"varnet."`` key prefix.
        data_path: root directory handed to ``SliceDataset``.
        output_path: ``pathlib.Path``; results are written to
            ``output_path / "reconstructions"``.

    Each batch is first attempted on CUDA; on ``RuntimeError`` it is retried on
    the CPU.
    """
    varnet = VarNet()
    # map_location="cpu": the loop below can fall back to the CPU, so the
    # checkpoint must also be loadable on a machine without a GPU.
    load_state_dict = torch.load(checkpoint, map_location="cpu")["state_dict"]
    # keep only the VarNet weights and strip their "varnet." prefix; a prefix
    # test (not a substring test) matches exactly the keys being sliced
    prefix = "varnet."
    state_dict = {
        k[len(prefix):]: v for k, v in load_state_dict.items() if k.startswith(prefix)
    }
    varnet.load_state_dict(state_dict)
    varnet = varnet.eval()

    data_transform = DataTransform()

    dataset = SliceDataset(
        root=data_path,
        transform=data_transform,
        challenge="multicoil",
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    start_time = time.perf_counter()
    outputs = defaultdict(list)

    for batch in tqdm(dataloader, desc="Running inference..."):
        masked_kspace, mask, _, fname, slice_num, _, crop_size = batch
        crop_size = crop_size[0]  # always have a batch size of 1 for varnet
        fname = fname[0]  # always have batch size of 1 for varnet

        with torch.no_grad():
            try:
                device = torch.device("cuda")
                output = run_model(masked_kspace, mask, varnet, fname, device)
            except RuntimeError:
                print("running on cpu")
                device = torch.device("cpu")
                output = run_model(masked_kspace, mask, varnet, fname, device)

            output = T.center_crop(output, crop_size)[0]

        outputs[fname].append((slice_num, output))

    # stack each volume's slices in slice order
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"elapsed time for {len(dataloader)} slices: {end_time-start_time}")
def test_epoch_end(self, test_logs):
    """Stack per-slice test outputs by volume and save them under ``bicubic``.

    Args:
        test_logs: iterable of dicts with parallel ``"fname"``, ``"slice"`` and
            ``"output"`` sequences.

    Returns:
        An empty dict.
    """
    # {fname: {slice_index: output}}: keying by integer slice index means a
    # slice logged twice (e.g. by a distributed sampler that pads its shards)
    # overwrites itself instead of forcing `sorted` to compare output tensors.
    outputs = defaultdict(dict)
    for log in test_logs:
        for i, (fname, slice_num) in enumerate(zip(log["fname"], log["slice"])):
            outputs[fname][int(slice_num)] = log["output"][i]

    # stack each volume's slices in slice order
    for fname in outputs:
        outputs[fname] = np.stack(
            [out for _, out in sorted(outputs[fname].items())]
        )

    fastmri.save_reconstructions(outputs, self.exp_dir / self.exp_name / "bicubic")

    return dict()
def run_inference(challenge, state_dict_file, data_path, output_path, mask_func, use_sens_net, device):
    """Run a pretrained VarNet over multicoil data and save per-volume reconstructions.

    Args:
        challenge: key into ``MODEL_FNAMES`` selecting the default weights.
        state_dict_file: path to a saved ``state_dict``. If ``None``, the default
            file for ``challenge`` is used, downloaded from ``VARNET_FOLDER`` first
            if it does not exist locally.
        data_path: root directory handed to ``SliceDataset``.
        output_path: ``pathlib.Path``; results are written to
            ``output_path / "reconstructions"``.
        mask_func: passed to ``T.VarNetDataTransform`` when not ``None``.
        use_sens_net: forwarded to the ``VarNet`` constructor.
        device: torch device the model runs on.
    """
    model = VarNet(
        num_cascades=12,
        pools=4,
        chans=18,
        use_sens_net=use_sens_net,
        sens_pools=4,
        sens_chans=8,
    )

    # download the state_dict if we don't have it
    if state_dict_file is None:
        if not Path(MODEL_FNAMES[challenge]).exists():
            url_root = VARNET_FOLDER
            download_model(url_root + MODEL_FNAMES[challenge], MODEL_FNAMES[challenge])
        state_dict_file = MODEL_FNAMES[challenge]

    # Deserialize onto the CPU: weights saved from a GPU run would otherwise
    # fail to load on a CPU-only machine. The model is moved to `device` below.
    model.load_state_dict(torch.load(state_dict_file, map_location="cpu"))
    model = model.eval()

    # data loader setup
    if mask_func is None:
        data_transform = T.VarNetDataTransform()
    else:
        data_transform = T.VarNetDataTransform(mask_func=mask_func)
    dataset = SliceDataset(
        root=data_path, transform=data_transform, challenge="multicoil"
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_varnet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs: stack each volume's slices in slice order
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time - start_time}")
def test_epoch_end(self, test_logs):
    """Gather test-step outputs per volume, stack them by slice and save them.

    Args:
        test_logs: iterable of dicts with parallel ``"fname"``, ``"slice"`` and
            ``"output"`` sequences; each slice entry is converted with
            ``int(slice.cpu())``.

    Files are written with ``fastmri.save_reconstructions`` into
    ``<trainer.default_root_dir>/reconstructions`` when a ``trainer`` attribute
    exists, otherwise into ``<cwd>/reconstructions``.
    """
    # {fname: {slice_index: output}} -- dict keys collapse slices that were
    # logged more than once (duplicate slices in ddp mode)
    per_volume = defaultdict(dict)
    for log in test_logs:
        for idx, (name, slc) in enumerate(zip(log["fname"], log["slice"])):
            per_volume[name][int(slc.cpu())] = log["output"][idx]

    # replace each volume's slice dict with a stacked array in slice order
    for name in per_volume:
        slices_for_volume = per_volume[name]
        per_volume[name] = np.stack(
            [slices_for_volume[key] for key in sorted(slices_for_volume)]
        )

    # pull the default_root_dir if we have a trainer, otherwise save to cwd
    if hasattr(self, "trainer"):
        base_dir = pathlib.Path(self.trainer.default_root_dir)
    else:
        base_dir = pathlib.Path.cwd()
    save_path = base_dir / "reconstructions"

    self.print(f"Saving reconstructions to {save_path}")
    fastmri.save_reconstructions(per_volume, save_path)
def validation_epoch_end(self, val_logs):
    """Aggregate validation logs into per-volume metrics and save the inputs.

    Each entry of ``val_logs`` is read for ``"val_loss"``, ``"fname"``,
    ``"slice"``, ``"output"``, ``"target"`` and ``"input"``; ``"device"`` is
    taken from the first entry. Slices are grouped per volume (``int(fname)``)
    and only the first occurrence of a repeated slice is kept. NMSE/SSIM/PSNR
    are summed over volumes, passed through ``self.NMSE``/``self.SSIM``/
    ``self.PSNR``/``self.ValLoss``, and divided by the count returned from
    ``self.TotExamples``.

    The stacked per-volume inputs are saved with ``fastmri.save_reconstructions``
    to ``self.exp_dir / self.exp_name / "bicubic"``.

    Returns:
        dict with a ``"log"`` entry (metrics under ``"metrics/<name>"``) plus
        the averaged metrics as top-level keys.
    """
    device = val_logs[0]["device"]

    # run the visualizations
    self._visualize_val(
        val_outputs=[x["output"].numpy() for x in val_logs],
        val_targets=[x["target"].numpy() for x in val_logs],
        val_inputs=[x["input"].numpy() for x in val_logs],
    )

    # aggregate losses; per-volume dicts keyed by slice index make the
    # duplicate-slice check O(1) instead of a linear scan per slice
    losses = []
    outputs = defaultdict(dict)
    targets = defaultdict(dict)
    inputs = defaultdict(dict)
    for val_log in val_logs:
        losses.append(val_log["val_loss"])
        for i, (fname, slice_ind) in enumerate(zip(val_log["fname"], val_log["slice"])):
            volume_id = int(fname)
            slice_id = int(slice_ind)
            # need to check for duplicate slices; keep the first occurrence
            if slice_id in outputs[volume_id]:
                continue
            outputs[volume_id][slice_id] = val_log["output"][i]
            targets[volume_id][slice_id] = val_log["target"][i]
            inputs[volume_id][slice_id] = val_log["input"][i]

    # handle aggregation for distributed case with pytorch_lightning metrics
    metrics = dict(val_loss=0, nmse=0, ssim=0, psnr=0)
    # NOTE(review): save_reconstructions is given one stacked array per volume
    # with str filenames; the previous code passed the raw int-keyed lists of
    # (slice, tensor) tuples, which cannot be written as a single dataset.
    stacked_inputs = {}
    for volume_id, volume_outputs in outputs.items():
        slice_ids = sorted(volume_outputs)
        output = torch.stack([volume_outputs[s] for s in slice_ids]).numpy()
        target = torch.stack([targets[volume_id][s] for s in slice_ids]).numpy()
        stacked_inputs[str(volume_id)] = torch.stack(
            [inputs[volume_id][s] for s in slice_ids]
        ).numpy()
        metrics["nmse"] = metrics["nmse"] + evaluate.nmse(target, output)
        metrics["ssim"] = metrics["ssim"] + evaluate.ssim(target, output)
        metrics["psnr"] = metrics["psnr"] + evaluate.psnr(target, output)

    # currently ddp reduction requires everything on CUDA, so we'll do this manually
    metrics["nmse"] = self.NMSE(torch.tensor(metrics["nmse"]).to(device))
    metrics["ssim"] = self.SSIM(torch.tensor(metrics["ssim"]).to(device))
    metrics["psnr"] = self.PSNR(torch.tensor(metrics["psnr"]).to(device))
    metrics["val_loss"] = self.ValLoss(torch.sum(torch.stack(losses)).to(device))

    num_examples = torch.tensor(len(outputs)).to(device)
    tot_examples = self.TotExamples(num_examples)

    log_metrics = {
        f"metrics/{metric}": values / tot_examples
        for metric, values in metrics.items()
    }
    metrics = {metric: values / tot_examples for metric, values in metrics.items()}
    print(tot_examples, device, metrics)

    fastmri.save_reconstructions(
        stacked_inputs, self.exp_dir / self.exp_name / "bicubic"
    )
    return dict(log=log_metrics, **metrics)
def run_inference(challenge, state_dict_file, data_path, output_path, device, mask, in_chans, out_chans, chans):
    """Load a U-Net / NestedUnet from a checkpoint, run it over a dataset and save results.

    NOTE: this function reads the module-level ``args`` namespace:
    ``args.unet_module`` (``"unet"`` or ``"nestedunet"``) picks the
    architecture, and ``args.fine_tuned_model`` picks the checkpoint layout.

    Args:
        challenge: names containing ``"_mc"`` run as multicoil, others as singlecoil.
        state_dict_file: checkpoint path, loaded with ``map_location=device``.
        data_path: root directory handed to ``SliceDataset``.
        output_path: ``pathlib.Path``; results go to ``output_path / "reconstructions"``.
        device: torch device the model runs on.
        mask: mask function passed to ``T.UnetDataTransform`` as ``mask_func``.
        in_chans, out_chans, chans: forwarded to the model constructor.

    Raises:
        ValueError: if ``args.unet_module`` is not ``"unet"`` or ``"nestedunet"``
            (previously this surfaced later as a NameError on ``model``).
    """
    if args.unet_module == "unet":
        model_cls, trunk_prefix = Unet, "_feature_blocks.unetblock."
    elif args.unet_module == "nestedunet":
        model_cls, trunk_prefix = NestedUnet, "_feature_blocks.nublock."
    else:
        raise ValueError(f"Unsupported unet_module: {args.unet_module!r}")
    model = model_cls(
        in_chans=in_chans,
        out_chans=out_chans,
        chans=chans,
        num_pool_layers=4,
        drop_prob=0.0,
    )

    pretrained_dict = torch.load(state_dict_file, map_location=device)
    model_keys = list(model.state_dict())
    if args.fine_tuned_model:
        if "state_dict" in pretrained_dict:
            # load from .ckpt: weights nested under "state_dict" with a "unet." prefix
            source = pretrained_dict["state_dict"]
            model_dict = {k: source[f"unet.{k}"] for k in model_keys}
        elif any(key.startswith("unet.") for key in pretrained_dict):
            # load from .torch: flat dict with "unet."-prefixed keys. The old
            # test looked for a literal "unet" key, which such files lack.
            model_dict = {k: pretrained_dict[f"unet.{k}"] for k in model_keys}
        else:
            # load from .pt: keys match the model's own state_dict
            model_dict = {k: pretrained_dict[k] for k in model_keys}
    else:
        # pretrained (non fine-tuned) checkpoint: weights live in the trunk's
        # feature blocks under an architecture-specific prefix
        trunk = pretrained_dict["classy_state_dict"]["base_model"]["model"]["trunk"]
        model_dict = {k: trunk[trunk_prefix + k] for k in model_keys}

    model.load_state_dict(model_dict)
    model = model.eval()

    # data loader setup
    which_challenge = "multicoil" if "_mc" in challenge else "singlecoil"
    data_transform = T.UnetDataTransform(which_challenge=which_challenge, mask_func=mask)
    dataset = SliceDataset(
        root=data_path,
        transform=data_transform,
        challenge=which_challenge,
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_unet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs: stack each volume's slices in slice order
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time-start_time}")