def run_inference(checkpoint, data_path, output_path):
    """Run a trained VarNet over every slice in ``data_path`` and save the result.

    The ``"state_dict"`` entry of ``checkpoint`` is read, the entries whose key
    starts with ``"varnet."`` are kept (with that prefix removed) and loaded
    into a freshly constructed ``VarNet``.  Each slice is first attempted on
    CUDA; if that raises ``RuntimeError`` the slice is re-run on the CPU.
    Outputs are centre-cropped, grouped per file name, ordered by slice
    number, stacked, and handed to ``fastmri.save_reconstructions``.

    Args:
        checkpoint: File passed to ``torch.load``; the loaded object must
            contain a ``"state_dict"`` mapping.
        data_path: Root directory given to ``SliceDataset``.
        output_path: Base output location; it must support the ``/`` operator
            (e.g. ``pathlib.Path``) because results are written to
            ``output_path / "reconstructions"``.
    """
    varnet = VarNet()

    # Deserialize onto the CPU: the model itself lives on the CPU here, and
    # this keeps CUDA-saved checkpoints loadable on hosts without a GPU.
    checkpoint_state = torch.load(checkpoint, map_location="cpu")["state_dict"]

    # Keep only the entries that really carry the "varnet." prefix; a plain
    # substring test would mis-slice keys that merely contain "varnet".
    prefix = "varnet."
    state_dict = {
        key[len(prefix):]: value
        for key, value in checkpoint_state.items()
        if key.startswith(prefix)
    }
    varnet.load_state_dict(state_dict)
    varnet = varnet.eval()

    data_transform = DataTransform()
    dataset = SliceDataset(
        root=data_path,
        transform=data_transform,
        challenge="multicoil",
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    start_time = time.perf_counter()
    outputs = defaultdict(list)

    for batch in tqdm(dataloader, desc="Running inference..."):
        masked_kspace, mask, _, fname, slice_num, _, crop_size = batch
        crop_size = crop_size[0]  # always have a batch size of 1 for varnet
        fname = fname[0]  # always have batch size of 1 for varnet

        with torch.no_grad():
            try:
                device = torch.device("cuda")
                output = run_model(masked_kspace, mask, varnet, fname, device)
            except RuntimeError:
                # Best-effort fallback: retry the same slice on the CPU.
                print("running on cpu")
                device = torch.device("cpu")
                output = run_model(masked_kspace, mask, varnet, fname, device)

            output = T.center_crop(output, crop_size)[0]

        outputs[fname].append((slice_num, output))

    # Order the slices of each volume by slice number before stacking.
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()
    print(f"elapsed time for {len(dataloader)} slices: {end_time-start_time}")
def test_varnet_num_sense_lines(
    shape, chans, center_fractions, accelerations, mask_center
):
    """Check VarNet's low-frequency bookkeeping and its output shape.

    A randomly masked input is built with a fixed seed.  When ``mask_center``
    is ``True`` the sensitivity network must report the same number of
    low-frequency lines as ``apply_mask`` did, and the mask must be all ones
    over the band ``[pad, pad + net_low_freqs)``.  Finally the model is run
    with ``num_low_frequencies=4`` and the output's trailing dimensions are
    compared with dimensions 2 and 3 of the input.
    """
    sampler = RandomMaskFunc(center_fractions, accelerations)
    kspace = create_input(shape)
    masked_kspace, mask, num_low_freqs = transforms.apply_mask(
        kspace, sampler, seed=123
    )

    model_config = dict(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
        mask_center=mask_center,
    )
    model = VarNet(**model_config)

    if mask_center is True:
        pad, net_low_freqs = model.sens_net.get_pad_and_num_low_freqs(
            mask, num_low_freqs
        )
        assert net_low_freqs == num_low_freqs

        # The centre band of the mask must be fully sampled.
        band_start = int(pad)
        band_stop = int(pad + net_low_freqs)
        center_band = mask.squeeze()[band_start:band_stop].to(torch.int8)
        fully_sampled = torch.ones([int(net_low_freqs)], dtype=torch.int8)
        assert torch.allclose(center_band, fully_sampled)

    reconstruction = model(masked_kspace, mask.byte(), num_low_frequencies=4)

    expected_shape = kspace.shape[2:4]
    assert reconstruction.shape[1:] == expected_shape
def test_varnet(shape, chans, center_fractions, accelerations, mask_center):
    """Run VarNet on a batch masked one sample at a time and check its shape.

    Every sample of the input is masked separately (always with seed 123),
    the masked samples and their masks are concatenated back into batches,
    and the model output's trailing dimensions are compared with dimensions
    2 and 3 of the input.
    """
    sampler = RandomMaskFunc(center_fractions, accelerations)
    kspace = create_input(shape)

    # Mask each sample on its own, keeping a leading batch dimension of 1.
    per_sample = [
        transforms.apply_mask(kspace[idx:idx + 1], sampler, seed=123)
        for idx in range(kspace.shape[0])
    ]
    masked_kspace = torch.cat([masked for masked, _, _ in per_sample])
    mask = torch.cat([sample_mask for _, sample_mask, _ in per_sample])

    model_config = dict(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
        mask_center=mask_center,
    )
    model = VarNet(**model_config)

    reconstruction = model(masked_kspace, mask.byte())

    expected_shape = kspace.shape[2:4]
    assert reconstruction.shape[1:] == expected_shape
def __init__(
    self,
    num_cascades: int = 12,
    pools: int = 4,
    chans: int = 18,
    sens_pools: int = 4,
    sens_chans: int = 8,
    lr: float = 0.0003,
    lr_step_size: int = 40,
    lr_gamma: float = 0.1,
    weight_decay: float = 0.0,
    **kwargs,
):
    """
    Args:
        num_cascades: Number of cascades (i.e., layers) for variational
            network.
        pools: Number of downsampling and upsampling layers for cascade
            U-Net.
        chans: Number of channels for cascade U-Net.
        sens_pools: Number of downsampling and upsampling layers for
            sensitivity map U-Net.
        sens_chans: Number of channels for sensitivity map U-Net.
        lr: Learning rate.
        lr_step_size: Learning rate step size.
        lr_gamma: Learning rate gamma decay.
        weight_decay: Parameter for penalizing weights norm.
        **kwargs: Forwarded unchanged to the parent class constructor.

    Note:
        This docstring previously listed a ``num_sense_lines`` argument
        (number of low-frequency lines used for sensitivity map
        computation).  This constructor does not declare that parameter
        and does not pass it to ``VarNet``; if supplied, it is simply part
        of ``**kwargs`` and is handed to the parent constructor.
    """
    super().__init__(**kwargs)
    # NOTE(review): presumably records the constructor arguments
    # (Lightning-style) - keep it ahead of any other work; confirm.
    self.save_hyperparameters()

    self.num_cascades = num_cascades
    self.pools = pools
    self.chans = chans
    self.sens_pools = sens_pools
    self.sens_chans = sens_chans
    self.lr = lr
    self.lr_step_size = lr_step_size
    self.lr_gamma = lr_gamma
    self.weight_decay = weight_decay

    # The network is built from the values stored on the instance above.
    self.varnet = VarNet(
        num_cascades=self.num_cascades,
        sens_chans=self.sens_chans,
        sens_pools=self.sens_pools,
        chans=self.chans,
        pools=self.pools,
    )

    self.loss = fastmri.SSIMLoss()
def test_varnet_scripting():
    """A small VarNet must be compilable with ``torch.jit.script``."""
    model_config = {
        "num_cascades": 4,
        "pools": 2,
        "chans": 8,
        "sens_pools": 2,
        "sens_chans": 4,
    }
    scripted = torch.jit.script(VarNet(**model_config))

    assert scripted is not None
def test_varnet(shape, out_chans, chans, center_fractions, accelerations):
    """Run a small VarNet on randomly masked input and check its output shape.

    ``out_chans`` and ``chans`` are accepted but not used by the body; the
    model is always built with ``chans=4``.
    """
    sampler = RandomMaskFunc(center_fractions, accelerations)
    kspace = create_input(shape)
    masked_kspace, mask = transforms.apply_mask(kspace, sampler, seed=123)

    model = VarNet(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=4,
        pools=2,
    )
    reconstruction = model(masked_kspace, mask.byte())

    # Output dimensions after the first must equal input dimensions 2 and 3.
    expected_shape = kspace.shape[2:4]
    assert reconstruction.shape[1:] == expected_shape
def __init__(
    self,
    num_cascades: int = 12,
    pools: int = 4,
    chans: int = 18,
    sens_pools: int = 4,
    sens_chans: int = 8,
    lr: float = 0.0003,
    lr_step_size: int = 40,
    lr_gamma: float = 0.1,
    weight_decay: float = 0.0,
    **kwargs,
):
    """
    Build the VarNet model and its SSIM loss, and store the settings.

    Args:
        num_cascades: Number of cascades (i.e., layers) of the variational
            network.
        pools: Number of down-/up-sampling layers of the cascade U-Net.
        chans: Number of channels of the cascade U-Net.
        sens_pools: Number of down-/up-sampling layers of the sensitivity
            map U-Net.
        sens_chans: Number of channels of the sensitivity map U-Net.
        lr: Learning rate.
        lr_step_size: Learning rate step size.
        lr_gamma: Learning rate gamma decay.
        weight_decay: Parameter for penalizing weights norm.
        **kwargs: Passed through to the parent class constructor.
    """
    super().__init__(**kwargs)
    self.save_hyperparameters()

    # Architecture settings.
    self.num_cascades, self.pools, self.chans = num_cascades, pools, chans
    self.sens_pools, self.sens_chans = sens_pools, sens_chans

    # Optimisation settings.
    self.lr, self.lr_step_size, self.lr_gamma = lr, lr_step_size, lr_gamma
    self.weight_decay = weight_decay

    self.varnet = VarNet(
        num_cascades=num_cascades,
        sens_chans=sens_chans,
        sens_pools=sens_pools,
        chans=chans,
        pools=pools,
    )
    self.loss = fastmri.SSIMLoss()
def run_inference(challenge, state_dict_file, data_path, output_path, device):
    """Reconstruct every slice in ``data_path`` with a VarNet and save the result.

    Args:
        challenge: Key into ``MODEL_FNAMES`` selecting the weights file that
            is used (and downloaded if absent) when ``state_dict_file`` is
            ``None``.
        state_dict_file: File with the model state dict, or ``None`` to use
            ``MODEL_FNAMES[challenge]``.
        data_path: Root directory given to ``SliceDataset``.
        output_path: Base output location; it must support the ``/`` operator
            (e.g. ``pathlib.Path``) because results are written to
            ``output_path / "reconstructions"``.
        device: Device the model is moved to and that is passed to
            ``run_varnet_model``.
    """
    model = VarNet(num_cascades=12, pools=4, chans=18, sens_pools=4, sens_chans=8)

    # download the state_dict if we don't have it
    if state_dict_file is None:
        default_fname = MODEL_FNAMES[challenge]
        if not Path(default_fname).exists():
            url_root = VARNET_FOLDER
            download_model(url_root + default_fname, default_fname)

        state_dict_file = default_fname

    # Deserialize onto the CPU first: the model is still on the CPU at this
    # point and is only moved to ``device`` below, so CUDA-saved weights stay
    # loadable on hosts without a GPU.
    model.load_state_dict(torch.load(state_dict_file, map_location="cpu"))
    model = model.eval()

    # data loader setup
    data_transform = T.VarNetDataTransform()
    dataset = SliceDataset(
        root=data_path, transform=data_transform, challenge="multicoil"
    )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_varnet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs: order each volume's slices by slice number, then stack
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time-start_time}")
def __init__(
    self,
    num_cascades=12,
    pools=4,
    chans=18,
    sens_pools=4,
    sens_chans=8,
    mask_type="equispaced",
    center_fractions=None,
    accelerations=None,
    lr=0.0003,
    lr_step_size=40,
    lr_gamma=0.1,
    weight_decay=0.0,
    **kwargs,
):
    """
    Args:
        num_cascades (int, default=12): Number of cascades (i.e., layers)
            for variational network.
        pools (int, default=4): Number of downsampling and upsampling
            layers for cascade U-Net.
        chans (int, default=18): Number of channels for cascade U-Net.
        sens_pools (int, default=4): Number of downsampling and upsampling
            layers for sensitivity map U-Net.
        sens_chans (int, default=8): Number of channels for sensitivity
            map U-Net.
        mask_type (str, default="equispaced"): Type of mask from ("random",
            "equispaced").
        center_fractions (list, default=None): Fraction of all samples to
            take from center (i.e., list of floats). ``None`` means
            ``[0.08]``.
        accelerations (list, default=None): List of accelerations to apply
            (i.e., list of ints). ``None`` means ``[4]``.
        lr (float, default=0.0003): Learning rate.
        lr_step_size (int, default=40): Learning rate step size.
        lr_gamma (float, default=0.1): Learning rate gamma decay.
        weight_decay (float, default=0): Parameter for penalizing weights
            norm.
        **kwargs: Forwarded unchanged to the parent class constructor.

    Raises:
        NotImplementedError: If ``self.batch_size`` is not 1 after the
            parent constructor has run.
    """
    super().__init__(**kwargs)

    # ``batch_size`` is not set in this method; it is read right after the
    # parent constructor runs (presumably set there - confirm).
    if self.batch_size != 1:
        raise NotImplementedError(
            f"Only batch_size=1 allowed for {self.__class__.__name__}"
        )

    # Resolve the list defaults here so every instance gets its own list
    # instead of sharing one mutable default object across all instances.
    if center_fractions is None:
        center_fractions = [0.08]
    if accelerations is None:
        accelerations = [4]

    self.num_cascades = num_cascades
    self.pools = pools
    self.chans = chans
    self.sens_pools = sens_pools
    self.sens_chans = sens_chans
    self.mask_type = mask_type
    self.center_fractions = center_fractions
    self.accelerations = accelerations
    self.lr = lr
    self.lr_step_size = lr_step_size
    self.lr_gamma = lr_gamma
    self.weight_decay = weight_decay

    self.varnet = VarNet(
        num_cascades=self.num_cascades,
        sens_chans=self.sens_chans,
        sens_pools=self.sens_pools,
        chans=self.chans,
        pools=self.pools,
    )

    self.loss = fastmri.SSIMLoss()