def batch(self, data):
    """Attach the measurement operator for one batch and cache its adjoint.

    Reads ``data['maps']``, ``data['masks']``, ``data['out']`` and the shape
    of ``data['imgs']``. The sigpy / non-Cartesian switches are read directly
    from ``self.use_sigpy`` and ``self.noncart`` (not from ``self.hparams``).

    Side effects:
        self.A: a ``MultiChannelMRI`` operator built with ``l2lam=0.``.
        self.x_adj: ``self.A.adjoint`` applied to ``data['out']``.
    """
    map_tensor = data['maps']
    mask_tensor = data['masks']
    measurements = data['out']
    operator = MultiChannelMRI(
        map_tensor,
        mask_tensor,
        l2lam=0.,
        img_shape=data['imgs'].shape,
        use_sigpy=self.use_sigpy,
        noncart=self.noncart,
    )
    self.A = operator
    self.x_adj = operator.adjoint(measurements)
def batch(self, data):
    """Attach the (multiband or multichannel) operator for one batch.

    When ``self.hparams.multiband`` is truthy, a ``MultiBandMRI`` operator is
    built from ``data['phi']`` (whose shape is also passed as ``ksp_shape``).
    Otherwise a ``MultiChannelMRI`` operator is used. Both are built with
    ``l2lam=0.`` and the shape of ``data['imgs']``.

    Side effects:
        self.A: the selected operator.
        self.x_adj: ``self.A.adjoint`` applied to ``data['out']``.
    """
    map_tensor = data['maps']
    mask_tensor = data['masks']
    measurements = data['out']
    # NJM: support for multiband
    if self.hparams.multiband:
        phi = data['phi']
        operator_cls = MultiBandMRI
        positional = (map_tensor, mask_tensor, phi)
        keywords = {'l2lam': 0., 'ksp_shape': phi.shape}
    else:
        operator_cls = MultiChannelMRI
        positional = (map_tensor, mask_tensor)
        keywords = {'l2lam': 0.}
    # Options shared by both operator types, appended in the original keyword order.
    keywords.update(
        img_shape=data['imgs'].shape,
        use_sigpy=self.hparams.use_sigpy,
        noncart=self.hparams.noncart,
    )
    self.A = operator_cls(*positional, **keywords)
    self.x_adj = self.A.adjoint(measurements)
class MoDLReconOneUnroll(torch.nn.Module):
    """One MoDL unroll: a learned denoiser followed by a data-consistency step.

    ``batch`` must be called first to attach the operator for the current
    batch. For a single-channel operator the data-consistency step is solved
    in closed form in the Fourier domain (via ``fft_forw`` / ``fft_adj``);
    otherwise it is solved with ``ConjGrad``.
    """

    def __init__(self, denoiser, l2lam, hparams):
        super().__init__()
        self.l2lam = l2lam
        # CG iterations used by the most recent forward(); 0 on the closed-form path.
        self.num_cg = None
        # Adjoint image of the current batch; forward() refuses to run until batch() sets it.
        self.x_adj = None
        self.hparams = hparams
        self.denoiser = denoiser

    def batch(self, data):
        """Build ``self.A`` for ``data`` and cache ``self.x_adj`` (and ``self.inp`` when single-channel)."""
        maps = data['maps']
        masks = data['masks']
        inp = data['out']
        self.A = MultiChannelMRI(
            maps, masks, l2lam=0., img_shape=data['imgs'].shape,
            use_sigpy=self.hparams.use_sigpy, noncart=self.hparams.noncart)

        data_is_adjoint = self.hparams.adjoint_data
        self.x_adj = inp if data_is_adjoint else self.A.adjoint(inp)

        if not self.A.single_channel:
            return
        # The closed-form path needs the measurements in the Fourier domain, squeezed along dim 1.
        if data_is_adjoint:
            # 'out' already holds an adjoint image, so transform it forward again.
            self.inp = fft_forw(maps.squeeze(1) * self.x_adj)
        else:
            self.inp = inp.squeeze(1)

    def forward(self, x):
        """Denoise ``x``, then apply data consistency; returns the updated estimate."""
        assert self.x_adj is not None, "x_adj not computed!"
        r = self.denoiser(x)

        if not self.A.single_channel:
            solver = ConjGrad(
                self.x_adj + self.l2lam * r, self.A.normal, l2lam=self.l2lam,
                max_iter=self.hparams.cg_max_iter, eps=self.hparams.cg_eps,
                verbose=False)
            updated = solver.forward(x)
            self.num_cg = solver.num_cg
            return updated

        # multiply with maps because they might not be all-ones, and they include the fftmod term
        maps = self.A.maps.squeeze(1)
        r_ft = fft_forw(r * maps)
        # Where the mask is nonzero, blend measurement and denoised estimate;
        # elsewhere keep the denoised estimate unchanged.
        blended = (self.inp + self.l2lam * r_ft) / (1 + self.l2lam)
        mask_mag = abs(self.A.mask)
        x_ft = blended * (mask_mag != 0) + r_ft * (mask_mag == 0)
        updated = torch.conj(maps) * fft_adj(x_ft)
        self.num_cg = 0
        return updated

    def get_metadata(self):
        """Return per-step diagnostics."""
        return dict(num_cg=self.num_cg)
class CGSenseRecon(Recon):
    """CG-SENSE reconstruction solved with conjugate gradient.

    The only trainable parameter is ``l2lam``, initialised from
    ``hparams.l2lam_init`` and passed to ``ConjGrad``.
    """

    def __init__(self, hparams):
        super(CGSenseRecon, self).__init__(hparams)
        self.l2lam = torch.nn.Parameter(torch.tensor(hparams.l2lam_init))
        self.A = None
        # Set by batch() / forward(). Initialised here so get_metadata() is
        # safe before the first forward pass (previously an AttributeError).
        self.x_adj = None
        self.num_cg = None

    def batch(self, data):
        """Build the operator for ``data`` and cache the CG right-hand side ``self.x_adj``.

        When ``hparams.adjoint_data`` is set, ``data['out']`` is used as-is.
        Otherwise ``self.A.adjoint`` is applied to it.
        """
        maps = data['maps']
        masks = data['masks']
        inp = data['out']
        self.A = MultiChannelMRI(maps, masks, l2lam=0., img_shape=data['imgs'].shape,
                                 use_sigpy=self.hparams.use_sigpy, noncart=self.hparams.noncart)
        if self.hparams.adjoint_data:
            self.x_adj = inp
        else:
            self.x_adj = self.A.adjoint(inp)

    def forward(self, y):
        """Run CG on ``self.A.normal`` with right-hand side ``self.x_adj``.

        ``y`` is not used; the data enters through ``batch``. Records the
        number of CG iterations in ``self.num_cg``.
        """
        cg_op = ConjGrad(self.x_adj, self.A.normal, l2lam=self.l2lam,
                         max_iter=self.hparams.cg_max_iter, eps=self.hparams.cg_eps,
                         verbose=False)
        # Zero initial guess with the same shape/dtype/device as x_adj.
        x_out = cg_op.forward(self.x_adj * 0)
        self.num_cg = cg_op.num_cg
        return x_out

    def get_metadata(self):
        """Return ``{'num_cg': ...}`` (None before the first forward pass)."""
        return {
            'num_cg': self.num_cg,
        }
class UNetRecon(Recon):
    """Reconstruction that passes the adjoint of the measurements through a U-Net."""

    def __init__(self, hparams):
        super(UNetRecon, self).__init__(hparams)
        if self.hparams.network == 'UNet':
            self.network = UNet(batch_norm=self.hparams.batch_norm, l2lam=self.hparams.l2lam_init)
        else:
            # Fail at construction instead of with an AttributeError on
            # ``self.network`` during the first forward pass.
            raise ValueError(
                "UNetRecon only supports network='UNet', got {!r}".format(self.hparams.network))

    def batch(self, data):
        """Build ``self.A`` for ``data`` and cache ``self.x_adj = self.A.adjoint(data['out'])``."""
        maps = data['maps']
        masks = data['masks']
        inp = data['out']
        self.A = MultiChannelMRI(maps, masks, l2lam=self.hparams.l2lam_init,
                                 img_shape=data['imgs'].shape,
                                 use_sigpy=self.hparams.use_sigpy, noncart=self.hparams.noncart)
        self.x_adj = self.A.adjoint(inp)

    def forward(self, y):
        """Apply the network to the cached adjoint image; ``y`` is unused."""
        return self.network(self.x_adj)

    def get_metadata(self):
        """No per-step diagnostics for this model."""
        return {}
class ResNetRecon(Recon):
    """Apply a ResNet-style network to the adjoint reconstruction of each batch.

    The network's input/output channel count is
    ``2 * prod(self.D.shape[1:-2])`` for 2D or ``2 * prod(self.D.shape[1:-3])``
    for 3D. ``self.D`` is set up by the ``Recon`` base class (presumably the
    dataset; verify).
    """

    def __init__(self, hparams):
        super(ResNetRecon, self).__init__(hparams)
        copy_shape = np.array(self.D.shape)
        if hparams.num_spatial_dimensions == 2:
            num_channels = 2 * np.prod(copy_shape[1:-2])
        elif hparams.num_spatial_dimensions == 3:
            num_channels = 2 * np.prod(copy_shape[1:-3])
        else:
            raise ValueError('only 2D or 3D number of spatial dimensions are supported!')
        # np.prod returns a NumPy integer; give the network constructors a plain int.
        self.in_channels = int(num_channels)

        if self.hparams.network == 'ResNet5Block':
            # NOTE(review): the original carried a bare "FIX ALSO" marker here; its intent is unclear.
            self.network = ResNet5Block(num_filters_start=self.in_channels,
                                        num_filters_end=self.in_channels,
                                        num_filters=self.hparams.latent_channels,
                                        filter_size=7, batch_norm=self.hparams.batch_norm)
        elif self.hparams.network == 'ResNet':
            self.network = ResNet(in_channels=self.in_channels,
                                  latent_channels=self.hparams.latent_channels,
                                  num_blocks=self.hparams.num_blocks, kernel_size=7,
                                  batch_norm=self.hparams.batch_norm)
        else:
            # Previously left self.network unset, deferring the failure to forward().
            raise ValueError(
                "ResNetRecon does not support network={!r}; expected 'ResNet5Block' or 'ResNet'".format(
                    self.hparams.network))

    def batch(self, data):
        """Build ``self.A`` for ``data`` and cache ``self.x_adj = self.A.adjoint(data['out'])``."""
        maps = data['maps']
        masks = data['masks']
        inp = data['out']
        self.A = MultiChannelMRI(maps, masks, l2lam=0., img_shape=data['imgs'].shape,
                                 use_sigpy=self.hparams.use_sigpy, noncart=self.hparams.noncart)
        self.x_adj = self.A.adjoint(inp)

    def forward(self, y):
        """Apply the network to the cached adjoint image; ``y`` is unused."""
        return self.network(self.x_adj)

    def get_metadata(self):
        """No per-step diagnostics for this model."""
        return {}
def batch(self, data):
    """Attach the operator for one batch and derive the per-example radius ``self.eps``.

    Side effects (all computed without autograd tracking):
        self.A: ``MultiChannelMRI`` built with ``l2lam=0.``.
        self.x_adj: ``self.A.adjoint(data['out'])``.
        self.eps: ``sqrt(self.A.maps.shape[1] * mask_sum) * self.hparams.stdev``,
            where ``mask_sum`` is the sum of ``self.A.mask`` over every
            dimension after the first.
    """
    map_tensor = data['maps']
    mask_tensor = data['masks']
    measurements = data['out']
    with torch.no_grad():
        self.A = MultiChannelMRI(map_tensor, mask_tensor, l2lam=0.,
                                 img_shape=data['imgs'].shape,
                                 use_sigpy=self.hparams.use_sigpy,
                                 noncart=self.hparams.noncart)
        self.x_adj = self.A.adjoint(measurements)
        n_maps = self.A.maps.shape[1]
        flat_mask = self.A.mask.reshape((self.A.mask.shape[0], -1))
        mask_sum = torch.sum(flat_mask, dim=1)
        self.eps = (n_maps * mask_sum).sqrt() * self.hparams.stdev
def batch(self, data):
    """Attach the operator for one batch; also cache ``self.inp`` for single-channel data.

    ``self.x_adj`` is ``data['out']`` itself when ``hparams.adjoint_data`` is
    set. Otherwise it is ``self.A.adjoint(data['out'])``. For a single-channel
    operator, ``self.inp`` holds the measurements squeezed along dim 1. They
    are recomputed with ``fft_forw`` when the data was supplied as an adjoint
    image.
    """
    maps = data['maps']
    masks = data['masks']
    inp = data['out']
    self.A = MultiChannelMRI(maps, masks, l2lam=0., img_shape=data['imgs'].shape,
                             use_sigpy=self.hparams.use_sigpy,
                             noncart=self.hparams.noncart)

    already_adjoint = self.hparams.adjoint_data
    self.x_adj = inp if already_adjoint else self.A.adjoint(inp)

    if not self.A.single_channel:
        return
    if already_adjoint:
        self.inp = fft_forw(maps.squeeze(1) * self.x_adj)
    else:
        self.inp = inp.squeeze(1)
def batch(self, data):
    """Attach the operator for one batch and derive the per-example radius ``self.eps``.

    Side effects (all computed without autograd tracking):
        self.A: ``MultiChannelMRI`` built with ``l2lam=0.``.
        self.x_adj: ``self.A.adjoint(data['out'])``.
        self.eps: ``sqrt(self.A.maps.shape[1] * mask_sum) * self.hparams.stdev``,
            where ``mask_sum`` sums ``self.A.mask`` over every dimension after
            the first.

    Raises:
        NotImplementedError: if the operator is single-channel.
    """
    maps = data['maps']
    masks = data['masks']
    inp = data['out']
    with torch.no_grad():
        self.A = MultiChannelMRI(maps, masks, l2lam=0., img_shape=data['imgs'].shape,
                                 use_sigpy=self.hparams.use_sigpy,
                                 noncart=self.hparams.noncart)
        self.x_adj = self.A.adjoint(inp)

        # FIXME: look at modl for single channel support
        # Explicit raise instead of ``assert`` so the guard survives ``python -O``.
        # A truthiness test (not ``is False``) also accepts NumPy/torch booleans.
        if self.A.single_channel:
            raise NotImplementedError('single channel support not yet implemented!')

        self.eps = (self.A.maps.shape[1] * torch.sum(self.A.mask.reshape(
            (self.A.mask.shape[0], -1)), dim=1)).sqrt() * self.hparams.stdev
class MoDLReconOneUnroll(torch.nn.Module):
    """One MoDL unroll: learned denoiser followed by a CG data-consistency solve.

    ``batch`` must be called first to attach the operator and the adjoint
    image for the current batch.
    """

    def __init__(self, denoiser, l2lam, hparams):
        super().__init__()
        self.l2lam = l2lam
        # CG iterations used by the most recent forward().
        self.num_cg = None
        # Adjoint image of the current batch; forward() refuses to run until batch() sets it.
        self.x_adj = None
        self.hparams = hparams
        self.denoiser = denoiser

    def batch(self, data):
        """Build ``self.A`` and cache ``self.x_adj`` (``data['out']`` as-is when ``adjoint_data``)."""
        maps = data['maps']
        masks = data['masks']
        inp = data['out']
        self.A = MultiChannelMRI(
            maps, masks, l2lam=0., img_shape=data['imgs'].shape,
            use_sigpy=self.hparams.use_sigpy, noncart=self.hparams.noncart)
        self.x_adj = inp if self.hparams.adjoint_data else self.A.adjoint(inp)

    def forward(self, x):
        """Denoise ``x`` and solve the regularised normal equations starting from ``x``."""
        assert self.x_adj is not None, "x_adj not computed!"
        denoised = self.denoiser(x)
        rhs = self.x_adj + self.l2lam * denoised
        solver = ConjGrad(rhs, self.A.normal, l2lam=self.l2lam,
                          max_iter=self.hparams.cg_max_iter,
                          eps=self.hparams.cg_eps, verbose=False)
        result = solver.forward(x)
        self.num_cg = solver.num_cg
        return result

    def get_metadata(self):
        """Return per-step diagnostics."""
        return dict(num_cg=self.num_cg)
class ResNetRecon(Recon):
    """Apply a ResNet-style denoiser to the adjoint reconstruction of each batch.

    The sigpy / non-Cartesian switches are read from ``self.use_sigpy`` and
    ``self.noncart``, which are presumably set by the ``Recon`` base class
    (verify).
    """

    def __init__(self, args):
        super(ResNetRecon, self).__init__(args)
        if args.network == 'ResNet5Block':
            self.denoiser = ResNet5Block(num_filters=args.latent_channels, filter_size=7,
                                         batch_norm=args.batch_norm)
        elif args.network == 'ResNet':
            self.denoiser = ResNet(latent_channels=args.latent_channels,
                                   num_blocks=args.num_blocks, kernel_size=7,
                                   batch_norm=args.batch_norm)
        else:
            # Previously left self.denoiser unset, deferring the failure to forward().
            raise ValueError(
                "ResNetRecon does not support network={!r}; expected 'ResNet5Block' or 'ResNet'".format(
                    args.network))

    def batch(self, data):
        """Build ``self.A`` for ``data`` and cache ``self.x_adj = self.A.adjoint(data['out'])``."""
        maps = data['maps']
        masks = data['masks']
        inp = data['out']
        self.A = MultiChannelMRI(maps, masks, l2lam=0., img_shape=data['imgs'].shape,
                                 use_sigpy=self.use_sigpy, noncart=self.noncart)
        self.x_adj = self.A.adjoint(inp)

    def forward(self, y):
        """Apply the denoiser to the cached adjoint image; ``y`` is unused."""
        return self.denoiser(self.x_adj)

    def get_metadata(self):
        """No per-step diagnostics for this model."""
        return {}
class DeepBasisPursuitRecon(Recon):
    """Deep Basis Pursuit: unrolled ADMM with a learned denoiser.

    Each of ``self.num_unrolls`` unrolls applies the denoiser. It then runs up
    to ``self.num_admm`` ADMM iterations whose z-update projects
    ``A x + u - y`` with ``opt.l2ball_proj_batch``. Several settings
    (``num_unrolls``, ``stdev``, ``cg_max_iter``, ``eps``, ``use_sigpy``,
    ``noncart``) are read from ``self`` and presumably set by the ``Recon``
    base class (verify).
    """

    def __init__(self, args):
        super(DeepBasisPursuitRecon, self).__init__(args)
        self.l2lam = torch.nn.Parameter(torch.tensor(args.l2lam_init))
        self.num_admm = args.num_admm
        if args.network == 'ResNet5Block':
            self.denoiser = ResNet5Block(num_filters=args.latent_channels, filter_size=7,
                                         batch_norm=args.batch_norm)
        elif args.network == 'ResNet':
            self.denoiser = ResNet(latent_channels=args.latent_channels,
                                   num_blocks=args.num_blocks, kernel_size=7,
                                   batch_norm=args.batch_norm)
        else:
            # Previously left self.denoiser unset, deferring the failure to forward().
            raise ValueError(
                "DeepBasisPursuitRecon does not support network={!r}; "
                "expected 'ResNet5Block' or 'ResNet'".format(args.network))
        self.debug_level = 0

    def batch(self, data):
        """Build ``self.A`` for ``data`` and cache ``self.x_adj = self.A.adjoint(data['out'])``."""
        maps = data['maps']
        masks = data['masks']
        inp = data['out']
        self.A = MultiChannelMRI(maps, masks, l2lam=0., img_shape=data['imgs'].shape,
                                 use_sigpy=self.use_sigpy, noncart=self.noncart)
        self.x_adj = self.A.adjoint(inp)

    def forward(self, y):
        """Run the unrolled ADMM starting from ``self.A.adjoint(y)``; returns the final image."""
        # Radius used by the l2-ball projection, derived from the operator maps/mask and self.stdev.
        eps = opt.ip_batch(self.A.maps.shape[1] * self.A.mask.sum((1, 2))).sqrt() * self.stdev
        x = self.A.adjoint(y)
        z = self.A(x)
        z_old = z
        u = z.new_zeros(z.shape)

        x.requires_grad = False
        z.requires_grad = False
        z_old.requires_grad = False
        u.requires_grad = False

        self.num_cg = np.zeros((self.num_unrolls, self.num_admm,))

        for i in range(self.num_unrolls):
            r = self.denoiser(x)

            for j in range(self.num_admm):
                rhs = self.l2lam * self.A.adjoint(z - u) + r
                fun = lambda xx: self.l2lam * self.A.normal(xx) + xx
                # NOTE(review): self.eps is the CG tolerance here (distinct from the local
                # l2-ball radius ``eps``); presumably set by the Recon base class.
                cg_op = ConjGrad(rhs, fun, max_iter=self.cg_max_iter, eps=self.eps, verbose=False)
                x = cg_op.forward(x)
                n_cg = cg_op.num_cg
                self.num_cg[i, j] = n_cg

                Ax_plus_u = self.A(x) + u
                z_old = z
                z = y + opt.l2ball_proj_batch(Ax_plus_u - y, eps)
                u = Ax_plus_u - z

                # check ADMM convergence
                Ax = self.A(x)
                r_norm = opt.ip_batch(Ax - z).sqrt()
                s_norm = opt.ip_batch(self.l2lam * self.A.adjoint(z - z_old)).sqrt()
                if (r_norm + s_norm).max() < 1E-2:
                    if self.debug_level > 0:
                        # The log previously referenced an undefined name ``a`` (NameError);
                        # report the ADMM iteration index instead.
                        tqdm.tqdm.write('stopping early, a={}'.format(j))
                    break
        return x

    def get_metadata(self):
        """Return the flattened per-(unroll, ADMM iteration) CG counts."""
        return {
            'num_cg': self.num_cg.ravel(),
        }
class DeepBasisPursuitRecon(Recon):
    """Deep Basis Pursuit: unrolled ADMM with a learned denoiser.

    ``batch`` builds the operator ``self.A`` and the per-example radius
    ``self.eps``. ``forward`` alternates the denoiser with up to
    ``hparams.num_admm`` ADMM iterations per unroll. Each z-update projects
    ``A x + u - y`` with ``opt.l2ball_proj_batch(..., self.eps)``.
    """

    def __init__(self, hparams):
        super(DeepBasisPursuitRecon, self).__init__(hparams)
        self.l2lam = torch.nn.Parameter(torch.tensor(hparams.l2lam_init))
        self.num_admm = hparams.num_admm

        # Channel count: 2 * prod(D.shape[1:-2]) (2D) or 2 * prod(D.shape[1:-3]) (3D).
        # self.D is set up by the Recon base class (presumably the dataset; verify).
        copy_shape = np.array(self.D.shape)
        if hparams.num_spatial_dimensions == 2:
            num_channels = 2 * np.prod(copy_shape[1:-2])
        elif hparams.num_spatial_dimensions == 3:
            num_channels = 2 * np.prod(copy_shape[1:-3])
        else:
            raise ValueError(
                'only 2D or 3D number of spatial dimensions are supported!')
        # np.prod returns a NumPy integer; give the network constructors a plain int.
        self.in_channels = int(num_channels)

        if hparams.network == 'ResNet5Block':
            self.denoiser = ResNet5Block(num_filters_start=self.in_channels,
                                         num_filters_end=self.in_channels,
                                         num_filters=hparams.latent_channels,
                                         filter_size=7, batch_norm=hparams.batch_norm)
        elif hparams.network == 'ResNet':
            self.denoiser = ResNet(in_channels=self.in_channels,
                                   latent_channels=hparams.latent_channels,
                                   num_blocks=hparams.num_blocks, kernel_size=7,
                                   batch_norm=hparams.batch_norm)
        else:
            # Previously left self.denoiser unset, deferring the failure to forward().
            raise ValueError(
                "DeepBasisPursuitRecon does not support network={!r}; "
                "expected 'ResNet5Block' or 'ResNet'".format(hparams.network))

        self.debug_level = 0
        # Diagnostics reported by get_metadata(); initialised so it is safe to
        # call before the first forward() (mean_eps/num_cg used to be unset).
        self.mean_residual_norm = 0
        self.mean_eps = 0
        self.num_cg = np.zeros((hparams.num_unrolls, hparams.num_admm))

    @staticmethod
    def _zdot_norm(v):
        """Return ``sqrt(real(opt.zdot_single_batch(v)))`` (presumably a per-example l2 norm)."""
        return torch.real(opt.zdot_single_batch(v.contiguous())).sqrt()

    def batch(self, data):
        """Build ``self.A``, ``self.x_adj`` and ``self.eps`` for ``data``.

        Raises:
            NotImplementedError: if the operator is single-channel.
        """
        maps = data['maps']
        masks = data['masks']
        inp = data['out']
        with torch.no_grad():
            self.A = MultiChannelMRI(maps, masks, l2lam=0., img_shape=data['imgs'].shape,
                                     use_sigpy=self.hparams.use_sigpy,
                                     noncart=self.hparams.noncart)
            self.x_adj = self.A.adjoint(inp)

            # FIXME: look at modl for single channel support
            # Explicit raise (survives ``python -O``). A truthiness test also
            # accepts NumPy/torch booleans, which ``is False`` rejected.
            if self.A.single_channel:
                raise NotImplementedError('single channel support not yet implemented!')

            # sqrt(maps.shape[1] * sum of the mask over all non-batch dims) * stdev, per example.
            self.eps = (self.A.maps.shape[1] * torch.sum(self.A.mask.reshape(
                (self.A.mask.shape[0], -1)), dim=1)).sqrt() * self.hparams.stdev

    def forward(self, y):
        """Run the unrolled ADMM starting from ``self.A.adjoint(y)``; returns the final image."""
        with torch.no_grad():
            self.mean_eps = torch.mean(self.eps)

        x = self.A.adjoint(y)
        z = self.A(x)
        z_old = z
        u = z.new_zeros(z.shape)

        x.requires_grad = False
        z.requires_grad = False
        z_old.requires_grad = False
        u.requires_grad = False

        # Operator for the x-update CG solve: xx -> l2lam * A.normal(xx) + xx.
        def reg_normal(xx):
            return self.l2lam * self.A.normal(xx) + xx

        self.num_cg = np.zeros((self.hparams.num_unrolls, self.hparams.num_admm))
        Ax = None  # bound by the first ADMM iteration; stays None if no iteration runs

        for i in range(self.hparams.num_unrolls):
            r = self.denoiser(x)

            for j in range(self.hparams.num_admm):
                rhs = self.l2lam * self.A.adjoint(z - u) + r
                cg_op = ConjGrad(rhs, reg_normal, max_iter=self.hparams.cg_max_iter,
                                 eps=self.hparams.cg_eps, verbose=False)
                x = cg_op.forward(x)
                self.num_cg[i, j] = cg_op.num_cg

                Ax_plus_u = self.A(x) + u
                z_old = z
                z = y + opt.l2ball_proj_batch(Ax_plus_u - y, self.eps)
                u = Ax_plus_u - z

                # check ADMM convergence (primal residual r, dual residual s)
                with torch.no_grad():
                    Ax = self.A(x)
                    r_norm = self._zdot_norm(Ax - z)
                    s_norm = self._zdot_norm(self.l2lam * self.A.adjoint(z - z_old))
                    if (r_norm + s_norm).max() < 1E-2:
                        if self.debug_level > 0:
                            # Previously formatted an undefined name ``a`` (NameError).
                            tqdm.tqdm.write('stopping early, a={}'.format(j))
                        break

        # NOTE(review): the collapsed source is ambiguous about whether this ran inside
        # the ADMM loop; here it is computed once from the last A(x).
        if Ax is not None:
            self.mean_residual_norm = torch.mean(self._zdot_norm(y - Ax))
        return x

    def get_metadata(self):
        """Return CG counts, the mean data residual norm and the mean l2-ball radius."""
        return {
            'num_cg': self.num_cg.ravel(),
            'mean_residual_norm': self.mean_residual_norm,
            'mean_eps': self.mean_eps,
        }