def run_inference(challenge, state_dict_file, data_path, output_path, device):
    """Run pretrained U-Net inference over a slice dataset and save reconstructions.

    Args:
        challenge: Key into ``MODEL_FNAMES``; contains ``"_mc"`` for multicoil.
        state_dict_file: Path to model weights, or ``None`` to download the
            pretrained file for ``challenge``.
        data_path: Root directory of the slice dataset.
        output_path: Directory under which ``reconstructions`` is written.
        device: Torch device to run the model on.
    """
    model = Unet(in_chans=1, out_chans=1, chans=256, num_pool_layers=4, drop_prob=0.0)

    # download the state_dict if we don't have it
    if state_dict_file is None:
        if not Path(MODEL_FNAMES[challenge]).exists():
            url_root = UNET_FOLDER
            download_model(url_root + MODEL_FNAMES[challenge], MODEL_FNAMES[challenge])

        state_dict_file = MODEL_FNAMES[challenge]

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    model.load_state_dict(torch.load(state_dict_file, map_location=device))
    model = model.eval()

    # data loader setup; resolve the coil mode once instead of branching twice
    which_challenge = "multicoil" if "_mc" in challenge else "singlecoil"
    data_transform = T.UnetDataTransform(which_challenge=which_challenge)
    dataset = SliceDataset(
        root=data_path,
        transform=data_transform,
        challenge=which_challenge,
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_unet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs: order each volume's slices by slice index before stacking
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time-start_time}")
def _create_data_loader(self, data_transform, data_partition, sample_rate=None):
    """Build a DataLoader over one partition of the challenge dataset.

    Falls back to ``self.sample_rate`` when no (truthy) ``sample_rate`` is
    given; drops the last incomplete batch only while training.
    """
    rate = sample_rate or self.sample_rate
    dataset = SliceDataset(
        root=self.data_path / f"{self.challenge}_{data_partition}",
        transform=data_transform,
        sample_rate=rate,
        challenge=self.challenge,
        mode=data_partition,
    )

    training = data_partition == "train"

    # ensure that entire volumes go to the same GPU in the ddp setting;
    # a plain DistributedSampler is fine for training
    if self.use_ddp:
        sampler = DistributedSampler(dataset) if training else VolumeSampler(dataset)
    else:
        sampler = None

    return DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        pin_memory=False,
        drop_last=training,
        sampler=sampler,
    )
def prepare_data(self):
    """Instantiate each dataset split once so the rank-0 DDP process fills
    the dataset cache before workers start. A no-op when caching is off."""
    if not self.use_dataset_cache_file:
        return

    # the test split may live at an explicitly configured path
    if self.test_path is not None:
        test_root = self.test_path
    else:
        test_root = self.data_path / f"{self.challenge}_test"

    splits = [
        (self.data_path / f"{self.challenge}_train", self.train_transform),
        (self.data_path / f"{self.challenge}_val", self.val_transform),
        (test_root, self.test_transform),
    ]

    for idx, (split_root, split_transform) in enumerate(splits):
        # sub-sampling applies only to the training split (index 0)
        _ = SliceDataset(
            root=split_root,
            transform=split_transform,
            sample_rate=self.sample_rate if idx == 0 else 1.0,
            volume_sample_rate=self.volume_sample_rate if idx == 0 else None,
            challenge=self.challenge,
            use_dataset_cache=self.use_dataset_cache_file,
        )
def _create_data_loader(
    self,
    data_transform: Callable,
    data_partition: str,
    sample_rate: Optional[float] = None,
) -> torch.utils.data.DataLoader:
    """Create a DataLoader for one partition, optionally merging train+val.

    Non-train partitions always use a sample rate of 1.0; the train partition
    uses ``self.sample_rate`` unless an explicit override is passed.
    """
    is_train = data_partition == "train"
    if is_train:
        sample_rate = self.sample_rate if sample_rate is None else sample_rate
    else:
        sample_rate = 1.0

    # if desired, combine train and val together for the train split
    dataset: Union[SliceDataset, CombinedSliceDataset]
    if is_train and self.combine_train_val:
        roots = [
            self.data_path + f"/{self.challenge}_train",
            self.data_path + f"/{self.challenge}_val",
        ]
        dataset = CombinedSliceDataset(
            roots=roots,
            transforms=[data_transform, data_transform],
            challenges=[self.challenge, self.challenge],
            sample_rates=[sample_rate, sample_rate],
            use_dataset_cache=self.use_dataset_cache_file,
        )
    else:
        # test/challenge data may live at an explicitly configured path
        if data_partition in ("test", "challenge") and self.test_path is not None:
            root = self.test_path
        else:
            root = self.data_path + f"/{self.challenge}_{data_partition}"

        dataset = SliceDataset(
            root=root,
            transform=data_transform,
            sample_rate=sample_rate,
            challenge=self.challenge,
            use_dataset_cache=self.use_dataset_cache_file,
        )

    # ensure that entire volumes go to the same GPU in the ddp setting
    sampler = None
    if self.distributed_sampler:
        if is_train:
            sampler = torch.utils.data.DistributedSampler(dataset)
        else:
            sampler = fastmri.data.VolumeSampler(dataset)

    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        worker_init_fn=worker_init_fn,
        sampler=sampler,
    )
def run_inference(checkpoint, data_path, output_path):
    """Reconstruct every multicoil slice with a VarNet checkpoint and save
    the results under ``output_path / "reconstructions"``."""
    varnet = VarNet()

    # the lightning checkpoint prefixes the module's weights with "varnet."
    ckpt_state = torch.load(checkpoint)["state_dict"]
    state_dict = {
        key[len("varnet."):]: value
        for key, value in ckpt_state.items()
        if "varnet" in key
    }
    varnet.load_state_dict(state_dict)
    varnet = varnet.eval()

    dataloader = torch.utils.data.DataLoader(
        SliceDataset(
            root=data_path,
            transform=DataTransform(),
            challenge="multicoil",
        ),
        num_workers=4,
    )

    start_time = time.perf_counter()
    outputs = defaultdict(list)

    for batch in tqdm(dataloader, desc="Running inference..."):
        masked_kspace, mask, _, fname, slice_num, _, crop_size = batch
        # batch size is always 1 for varnet, so unwrap the singleton entries
        crop_size = crop_size[0]
        fname = fname[0]

        with torch.no_grad():
            try:
                device = torch.device("cuda")
                output = run_model(masked_kspace, mask, varnet, fname, device)
            except RuntimeError:
                # per-batch fallback: retry this slice on CPU (e.g. CUDA OOM)
                print("running on cpu")
                device = torch.device("cpu")
                output = run_model(masked_kspace, mask, varnet, fname, device)

            output = T.center_crop(output, crop_size)[0]

        outputs[fname].append((slice_num, output))

    # stack each volume's slices in slice-index order
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"elapsed time for {len(dataloader)} slices: {end_time-start_time}")
def run_inference(challenge, state_dict_file, data_path, output_path, mask_func, use_sens_net, device):
    """Run pretrained VarNet inference over a multicoil dataset and save reconstructions.

    Args:
        challenge: Key into ``MODEL_FNAMES`` selecting the pretrained weights.
        state_dict_file: Path to model weights, or ``None`` to download them.
        data_path: Root directory of the slice dataset.
        output_path: Directory under which ``reconstructions`` is written.
        mask_func: Optional mask function for the data transform; ``None``
            uses the transform's default behavior.
        use_sens_net: Whether the model estimates sensitivity maps itself.
        device: Torch device to run the model on.
    """
    model = VarNet(num_cascades=12, pools=4, chans=18, use_sens_net=use_sens_net, sens_pools=4, sens_chans=8)

    # download the state_dict if we don't have it
    if state_dict_file is None:
        if not Path(MODEL_FNAMES[challenge]).exists():
            url_root = VARNET_FOLDER
            download_model(url_root + MODEL_FNAMES[challenge], MODEL_FNAMES[challenge])

        state_dict_file = MODEL_FNAMES[challenge]

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    model.load_state_dict(torch.load(state_dict_file, map_location=device))
    model = model.eval()

    # data loader setup
    if mask_func is None:
        data_transform = T.VarNetDataTransform()
    else:
        data_transform = T.VarNetDataTransform(mask_func=mask_func)

    dataset = SliceDataset(
        root=data_path, transform=data_transform, challenge="multicoil"
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_varnet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs: order each volume's slices by slice index before stacking
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time - start_time}")
def run_inference(challenge, state_dict_file, data_path, output_path, device, mask, in_chans, out_chans, chans):
    """Run U-Net / NestedUnet inference with weights loaded from one of several
    checkpoint formats, then save per-volume reconstructions.

    NOTE(review): this function reads module-level ``args`` (``unet_module``,
    ``fine_tuned_model``) rather than taking them as parameters — presumably
    populated by the script's argument parser before this is called; confirm
    at the call site.
    """
    # pick the architecture requested via args.unet_module
    if args.unet_module == "unet":
        model = Unet(in_chans=in_chans, out_chans=out_chans, chans=chans, num_pool_layers=4, drop_prob=0.0)
    elif args.unet_module == "nestedunet":
        model = NestedUnet(in_chans=in_chans, out_chans=out_chans, chans=chans, num_pool_layers=4, drop_prob=0.0)

    # remap checkpoint keys onto the freshly constructed model's state_dict;
    # each branch handles a different on-disk checkpoint layout
    pretrained_dict = torch.load(state_dict_file, map_location=device)
    model_dict = model.state_dict()
    if args.fine_tuned_model:
        if 'state_dict' in pretrained_dict.keys():
            # load from .ckpt: lightning checkpoint, weights nested under
            # "state_dict" with a "unet." prefix
            model_dict = {
                k: pretrained_dict["state_dict"][f"unet.{k}"]
                for k, _ in model_dict.items()
            }
        elif 'unet' in pretrained_dict.keys():
            # load from .torch: flat dict keyed with a "unet." prefix
            model_dict = {
                k: pretrained_dict["unet." + k] for k, v in model_dict.items()
            }
        else:
            # load from .pt: keys already match the model's state_dict
            model_dict = {
                k: pretrained_dict[k] for k, v in model_dict.items()
            }
    else:
        # pretraining checkpoint: trunk weights are nested several levels deep
        # under "classy_state_dict"; the block prefix depends on the architecture
        if args.unet_module == "unet":
            model_dict = {
                k: pretrained_dict["classy_state_dict"]["base_model"]["model"]
                ["trunk"]["_feature_blocks.unetblock." + k]
                for k, _ in model_dict.items()
            }
        elif args.unet_module == "nestedunet":
            model_dict = {
                k: pretrained_dict["classy_state_dict"]["base_model"]["model"]
                ["trunk"]["_feature_blocks.nublock." + k]
                for k, v in model_dict.items()
            }

    model.load_state_dict(model_dict)
    model = model.eval()

    # data loader setup
    if "_mc" in challenge:
        data_transform = T.UnetDataTransform(which_challenge="multicoil", mask_func=mask)
    else:
        data_transform = T.UnetDataTransform(which_challenge="singlecoil", mask_func=mask)

    if "_mc" in challenge:
        dataset = SliceDataset(
            root=data_path,
            transform=data_transform,
            challenge="multicoil",
        )
    else:
        dataset = SliceDataset(
            root=data_path,
            transform=data_transform,
            challenge="singlecoil",
        )

    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_unet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs: order each volume's slices by slice index before stacking
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time-start_time}")
help="Number of processes. Set to 0 to disable multiprocessing.", ) return parser if __name__ == "__main__": args = create_arg_parser().parse_args() if args.split in ("train", "val"): mask = create_mask_for_mask_type( args.mask_type, args.center_fractions, args.accelerations, ) else: mask = None args.reg_wt = None # need this global for multiprocessing dataset = SliceDataset( root=args.data_path / f"{args.challenge}_{args.split}", transform=DataTransform(split=args.split, mask_func=mask, reg_wt=args.reg_wt), challenge=args.challenge, sample_rate=args.sample_rate, ) run_bart(args)