def __init__(
    self,
    num_cascades: int = 12,
    pools: int = 4,
    chans: int = 18,
    sens_pools: int = 4,
    sens_chans: int = 8,
    lr: float = 0.0003,
    lr_step_size: int = 40,
    lr_gamma: float = 0.1,
    weight_decay: float = 0.0,
    **kwargs,
):
    """Set up the variational network, its loss, and the training settings.

    Args:
        num_cascades: Number of cascades (i.e., layers) for variational
            network.
        pools: Number of downsampling and upsampling layers for cascade
            U-Net.
        chans: Number of channels for cascade U-Net.
        sens_pools: Number of downsampling and upsampling layers for
            sensitivity map U-Net.
        sens_chans: Number of channels for sensitivity map U-Net.
        lr: Learning rate.
        lr_step_size: Learning rate step size.
        lr_gamma: Learning rate gamma decay.
        weight_decay: Parameter for penalizing weights norm.
        **kwargs: Forwarded unchanged to the parent class initializer.

    Note:
        Earlier documentation of this constructor also described
        ``num_sense_lines``: number of low-frequency lines to use for
        sensitivity map computation, must be even or ``None``. Default
        ``None`` will automatically compute the number from masks. Default
        behaviour may cause some slices to use more low-frequency lines
        than others, when used in conjunction with e.g. the
        EquispacedMaskFunc defaults. To prevent this, either set
        ``num_sense_lines``, or set ``skip_low_freqs`` and
        ``skip_around_low_freqs`` to ``True`` in the EquispacedMaskFunc.
        Setting this value may lead to undesired behaviour when training
        on multiple accelerations simultaneously.

        ``num_sense_lines`` is NOT a named parameter of this constructor
        and is not passed to ``VarNet`` here; if supplied, it only travels
        through ``**kwargs`` to the parent initializer.
        TODO(review): confirm whether it should be wired through.
    """
    super().__init__(**kwargs)
    # Called before any extra locals exist so only the constructor
    # arguments are visible to it.
    self.save_hyperparameters()

    # Architecture settings.
    self.num_cascades, self.pools, self.chans = num_cascades, pools, chans
    self.sens_pools, self.sens_chans = sens_pools, sens_chans

    # Optimisation settings.
    self.lr, self.lr_step_size, self.lr_gamma = lr, lr_step_size, lr_gamma
    self.weight_decay = weight_decay

    # Build the network from the values just stored on the instance.
    varnet_config = {
        "num_cascades": self.num_cascades,
        "sens_chans": self.sens_chans,
        "sens_pools": self.sens_pools,
        "chans": self.chans,
        "pools": self.pools,
    }
    self.varnet = VarNet(**varnet_config)

    self.loss = fastmri.SSIMLoss()
def __init__(
    self,
    num_cascades: int = 12,
    pools: int = 4,
    chans: int = 18,
    sens_pools: int = 4,
    sens_chans: int = 8,
    lr: float = 0.0003,
    lr_step_size: int = 40,
    lr_gamma: float = 0.1,
    weight_decay: float = 0.0,
    **kwargs,
):
    """Set up the variational network, its loss, and the training settings.

    Args:
        num_cascades: Number of cascades (i.e., layers) for variational
            network.
        pools: Number of downsampling and upsampling layers for cascade
            U-Net.
        chans: Number of channels for cascade U-Net.
        sens_pools: Number of downsampling and upsampling layers for
            sensitivity map U-Net.
        sens_chans: Number of channels for sensitivity map U-Net.
        lr: Learning rate.
        lr_step_size: Learning rate step size.
        lr_gamma: Learning rate gamma decay.
        weight_decay: Parameter for penalizing weights norm.
        **kwargs: Forwarded unchanged to the parent class initializer.
    """
    super().__init__(**kwargs)
    # Called before any extra locals exist so only the constructor
    # arguments are visible to it.
    self.save_hyperparameters()

    # Architecture settings.
    self.num_cascades, self.pools, self.chans = num_cascades, pools, chans
    self.sens_pools, self.sens_chans = sens_pools, sens_chans

    # Optimisation settings.
    self.lr, self.lr_step_size, self.lr_gamma = lr, lr_step_size, lr_gamma
    self.weight_decay = weight_decay

    # Build the network from the values just stored on the instance.
    varnet_config = {
        "num_cascades": self.num_cascades,
        "sens_chans": self.sens_chans,
        "sens_pools": self.sens_pools,
        "chans": self.chans,
        "pools": self.pools,
    }
    self.varnet = VarNet(**varnet_config)

    self.loss = fastmri.SSIMLoss()
# In[9]: import fastmri from fastmri.data import transforms as T # In[10]: slice_kspace2 = T.to_tensor(slice_kspace) # Convert from numpy array to pytorch tensor slice_image = fastmri.ifft2c(slice_kspace2) # Apply Inverse Fourier Transform to get the complex image slice_image_abs = fastmri.complex_abs(slice_image) # Compute absolute value to get a real image # SSIM loss loss = fastmri.SSIMLoss() print(loss(slice_image_abs.unsqueeze(1), slice_image_abs.unsqueeze(1), data_range=slice_image_abs.max().reshape(-1))) # In[15]: show_coils(slice_image_abs, [0], cmap='gray') # As we can see, each coil in a multi-coil MRI scan focusses on a different region of the image. These coils can be combined into the full image using the Root-Sum-of-Squares (RSS) transform. # In[16]: slice_image_rss = fastmri.rss(slice_image_abs, dim=0)
def __init__(
    self,
    num_cascades=12,
    pools=4,
    chans=18,
    sens_pools=4,
    sens_chans=8,
    mask_type="equispaced",
    center_fractions=None,
    accelerations=None,
    lr=0.0003,
    lr_step_size=40,
    lr_gamma=0.1,
    weight_decay=0.0,
    **kwargs,
):
    """Set up the variational network, its loss, and the training settings.

    Args:
        num_cascades (int, default=12): Number of cascades (i.e., layers)
            for variational network.
        pools (int, default=4): Number of downsampling and upsampling
            layers for cascade U-Net.
        chans (int, default=18): Number of channels for cascade U-Net.
        sens_pools (int, default=4): Number of downsampling and
            upsampling layers for sensitivity map U-Net.
        sens_chans (int, default=8): Number of channels for sensitivity
            map U-Net.
        mask_type (str, default="equispaced"): Type of mask from
            ("random", "equispaced").
        center_fractions (list, default=None): Fraction of all samples to
            take from center (i.e., list of floats). ``None`` means
            ``[0.08]``.
        accelerations (list, default=None): List of accelerations to
            apply (i.e., list of ints). ``None`` means ``[4]``.
        lr (float, default=0.0003): Learning rate.
        lr_step_size (int, default=40): Learning rate step size.
        lr_gamma (float, default=0.1): Learning rate gamma decay.
        weight_decay (float, default=0): Parameter for penalizing weights
            norm.
        **kwargs: Forwarded unchanged to the parent class initializer.

    Raises:
        NotImplementedError: If ``self.batch_size`` is not 1 after the
            parent initializer has run.
    """
    super().__init__(**kwargs)

    # `batch_size` is not assigned in this method; it is expected to be
    # provided by the parent class (e.g. via **kwargs).
    if self.batch_size != 1:
        raise NotImplementedError(
            f"Only batch_size=1 allowed for {self.__class__.__name__}"
        )

    self.num_cascades = num_cascades
    self.pools = pools
    self.chans = chans
    self.sens_pools = sens_pools
    self.sens_chans = sens_chans
    self.mask_type = mask_type
    # Build the default lists per instance: a list literal in the signature
    # would be one shared object, so mutating it on one instance would
    # silently change the defaults of every other instance.
    self.center_fractions = [0.08] if center_fractions is None else center_fractions
    self.accelerations = [4] if accelerations is None else accelerations
    self.lr = lr
    self.lr_step_size = lr_step_size
    self.lr_gamma = lr_gamma
    self.weight_decay = weight_decay

    self.varnet = VarNet(
        num_cascades=self.num_cascades,
        sens_chans=self.sens_chans,
        sens_pools=self.sens_pools,
        chans=self.chans,
        pools=self.pools,
    )

    self.loss = fastmri.SSIMLoss()