def run_inference(checkpoint, data_path, output_path):
    """Reconstruct a multicoil dataset with a VarNet loaded from a checkpoint.

    Args:
        checkpoint: Path to a Lightning-style checkpoint whose ``state_dict``
            holds the VarNet weights under keys prefixed with ``"varnet."``.
        data_path: Root directory of the multicoil slice dataset.
        output_path: Directory (``pathlib.Path``) under which
            ``reconstructions/`` is written.
    """
    varnet = VarNet()

    # The checkpoint stores the full training module; keep only the VarNet
    # sub-module weights and strip their "varnet." prefix.
    # map_location="cpu" keeps loading robust on hosts without a GPU.
    ckpt_state_dict = torch.load(checkpoint, map_location="cpu")["state_dict"]
    prefix = "varnet."
    state_dict = {
        # startswith (not substring containment) so an unrelated key that
        # merely contains "varnet" is never truncated by the prefix slice.
        k[len(prefix):]: v
        for k, v in ckpt_state_dict.items()
        if k.startswith(prefix)
    }
    varnet.load_state_dict(state_dict)
    varnet = varnet.eval()

    # data loader setup
    data_transform = DataTransform()
    dataset = SliceDataset(
        root=data_path,
        transform=data_transform,
        challenge="multicoil",
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)

    for batch in tqdm(dataloader, desc="Running inference..."):
        masked_kspace, mask, _, fname, slice_num, _, crop_size = batch
        crop_size = crop_size[0]  # always have a batch size of 1 for varnet
        fname = fname[0]  # always have batch size of 1 for varnet

        with torch.no_grad():
            try:
                device = torch.device("cuda")
                output = run_model(masked_kspace, mask, varnet, fname, device)
            except RuntimeError:
                # Fall back to CPU on any CUDA RuntimeError (no GPU, OOM, ...)
                # so a single volume never aborts the whole run.
                print("running on cpu")
                device = torch.device("cpu")
                output = run_model(masked_kspace, mask, varnet, fname, device)

            output = T.center_crop(output, crop_size)[0]

        outputs[fname].append((slice_num, output))

    # save outputs: one volume per file name, slices stacked in order
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"elapsed time for {len(dataloader)} slices: {end_time-start_time}")
def run_inference(challenge, state_dict_file, data_path, output_path, device):
    """Run pretrained VarNet inference, downloading weights when needed.

    Args:
        challenge: Key into ``MODEL_FNAMES`` selecting the pretrained weights.
        state_dict_file: Path to a saved state dict, or ``None`` to download
            (or reuse) the published weights for ``challenge``.
        data_path: Root directory of the multicoil slice dataset.
        output_path: Directory (``pathlib.Path``) under which
            ``reconstructions/`` is written.
        device: Torch device the model is moved to for inference.
    """
    model = VarNet(num_cascades=12, pools=4, chans=18, sens_pools=4, sens_chans=8)

    # download the state_dict if we don't have it
    if state_dict_file is None:
        if not Path(MODEL_FNAMES[challenge]).exists():
            url_root = VARNET_FOLDER
            download_model(url_root + MODEL_FNAMES[challenge], MODEL_FNAMES[challenge])

        state_dict_file = MODEL_FNAMES[challenge]

    # map_location="cpu" so loading succeeds even when the checkpoint was
    # serialized with CUDA tensors and this host has no GPU; the model is
    # moved to `device` explicitly before the inference loop.
    model.load_state_dict(torch.load(state_dict_file, map_location="cpu"))
    model = model.eval()

    # data loader setup
    data_transform = T.VarNetDataTransform()
    dataset = SliceDataset(root=data_path, transform=data_transform, challenge="multicoil")
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_varnet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs: one volume per file name, slices stacked in order
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time-start_time}")