def setup():
    """Build a golden-angle radial trajectory and matching NUFFT operators.

    Returns:
        A tuple ``(ktraj, nufft_ob, torch_forward, torch_backward)``:
        ``ktraj`` is a ``(2, nspokes * spokelength)`` NumPy array whose rows
        are the flattened ky and kx coordinates (spoke by spoke);
        ``nufft_ob`` is a ``KbNufftModule``; ``torch_forward`` and
        ``torch_backward`` are ``KbNufft`` / ``AdjKbNufft``. All three
        operators use a 200x200 image, a 400x400 grid and ``norm='ortho'``.
    """
    spokelength = 400
    nspokes = 10
    grid_size = (spokelength, spokelength)
    im_size = (200, 200)

    # Golden angle: 180 degrees divided by the golden ratio, in radians.
    golden_ratio = (1 + np.sqrt(5)) / 2
    ga = np.deg2rad(180 / golden_ratio)
    cos_ga = np.cos(ga)
    sin_ga = np.sin(ga)

    kx = np.zeros(shape=(spokelength, nspokes))
    ky = np.zeros(shape=(spokelength, nspokes))
    # The first spoke lies along ky from -pi to pi; every later spoke is the
    # previous one rotated by the golden angle.
    ky[:, 0] = np.linspace(-np.pi, np.pi, spokelength)
    for spoke in range(1, nspokes):
        prev_x = kx[:, spoke - 1]
        prev_y = ky[:, spoke - 1]
        kx[:, spoke] = cos_ga * prev_x - sin_ga * prev_y
        ky[:, spoke] = sin_ga * prev_x + cos_ga * prev_y

    # Transpose so flattening walks one spoke at a time, then stack as (ky, kx).
    ktraj = np.stack((ky.T.flatten(), kx.T.flatten()), axis=0)

    nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho')
    torch_forward = KbNufft(im_size=im_size, grid_size=grid_size, norm='ortho')
    torch_backward = AdjKbNufft(im_size=im_size, grid_size=grid_size, norm='ortho')
    return ktraj, nufft_ob, torch_forward, torch_backward
def test_3d_kbnufft_adjoint_backward():
    """Check autograd through AdjKbNufft against KbNufft in 3D.

    With ``x = A^H y`` (computed with precomputed sparse interpolation
    matrices), the gradient of ``sum(x**2) / 2`` with respect to ``y`` is
    compared with ``A x``. ``norm_tol`` is not assigned in this function, so
    it must be provided at module scope.
    """
    dtype = torch.double
    nslice, ncoil = 2, 4
    im_size = (11, 33, 24)
    klength = 112

    def _random_complex_stacked(shape):
        # Complex Gaussian data with real and imaginary parts stacked on axis 2.
        data = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
        return torch.tensor(np.stack((np.real(data), np.imag(data)), axis=2)).to(dtype)

    # This first x is replaced by the adjoint output below; it is still drawn
    # so the random sequence used for y is unchanged.
    x = _random_complex_stacked((nslice, ncoil) + im_size)
    y = _random_complex_stacked((nslice, ncoil, klength))
    ktraj = torch.randn(nslice, 3, klength).to(dtype)

    numpoints = (2, 4, 6)
    kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints)
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints)
    real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    y.requires_grad = True
    x = adjkbnufft_ob.forward(y, ktraj, interp_mats)
    ((x**2) / 2).sum().backward()
    y_grad = y.grad.clone().detach()

    y_hat = kbnufft_ob.forward(x.clone().detach(), ktraj, interp_mats)
    assert torch.norm(y_grad - y_hat) < norm_tol
def test_2d_kbnufft_backward(params_2d, testing_tol, testing_dtype, device_list):
    """Check autograd through KbNufft against AdjKbNufft in 2D on every device.

    With ``y = A x``, the gradient of ``sum(y**2) / 2`` with respect to ``x``
    must match ``A^H y`` within ``testing_tol``.
    """
    norm_tol = testing_tol
    dtype = testing_dtype
    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]
    x, y, ktraj = params_2d["x"], params_2d["y"], params_2d["ktraj"]

    for device in device_list:
        target = {"dtype": dtype, "device": device}
        x = x.detach().to(**target)
        y = y.detach().to(**target)
        ktraj = ktraj.detach().to(**target)
        kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints).to(**target)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(**target)

        x.requires_grad = True
        y = kbnufft_ob.forward(x, ktraj)
        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj)
        assert torch.norm(x_grad - x_hat) < norm_tol
def test_toeplitz_nufft_3d(params_3d, testing_tol, testing_dtype, device_list):
    """Compare the Toeplitz normal operator with ``A^H(A x)`` in 3D.

    The relative error between ``AdjKbNufft(KbNufft(x))`` and
    ``ToepNufft(x, kern)`` must stay below a fixed tolerance on every device.
    ``testing_tol`` is accepted for fixture compatibility but not used.
    """
    dtype = testing_dtype
    # this tolerance looks really bad, but toep struggles with random traj
    # for radial it's more like 1e-06
    norm_tol = 1e-1
    im_size = params_3d["im_size"]
    numpoints = params_3d["numpoints"]
    x, y, ktraj = params_3d["x"], params_3d["y"], params_3d["ktraj"]

    for device in device_list:
        x, y, ktraj = (
            t.detach().to(dtype=dtype, device=device) for t in (x, y, ktraj)
        )
        kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        toep_ob = ToepNufft()
        kern = calc_toep_kernel(adjkbnufft_ob, ktraj)

        normal_forw = adjkbnufft_ob(kbnufft_ob(x, ktraj), ktraj)
        toep_forw = toep_ob(x, kern)

        rel_err = torch.norm(normal_forw - toep_forw) / torch.norm(normal_forw)
        assert rel_err < norm_tol
def test_nufft_2d_adjoint():
    """Dot-product (adjoint) test for the 2D NUFFT pair.

    ``inner_product(y, A x)`` must match ``inner_product(A^H y, x)``.
    ``norm_tol`` is not assigned in this function, so it must be provided at
    module scope.
    """
    dtype = torch.double
    nslice, ncoil = 2, 4
    im_size = (33, 24)
    klength = 112

    def _random_complex_stacked(shape):
        # Complex Gaussian data with real and imaginary parts stacked on axis 2.
        data = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
        return torch.tensor(np.stack((np.real(data), np.imag(data)), axis=2)).to(dtype)

    x = _random_complex_stacked((nslice, ncoil) + im_size)
    y = _random_complex_stacked((nslice, ncoil, klength))
    ktraj = torch.randn(nslice, 2, klength).to(dtype)

    numpoints = (4, 6)
    kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints)
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints)

    x_forw = kbnufft_ob(x, ktraj)
    y_back = adjkbnufft_ob(y, ktraj)

    inprod1 = inner_product(y, x_forw, dim=2)
    inprod2 = inner_product(y_back, x, dim=2)
    assert torch.norm(inprod1 - inprod2) < norm_tol
def test_nufft_3d_adjoint():
    """Dot-product (adjoint) test for the 3D NUFFT pair with sparse interp mats.

    ``inner_product(y, A x)`` must match ``inner_product(A^H y, x)``.
    ``norm_tol`` is not assigned in this function, so it must be provided at
    module scope.
    """
    dtype = torch.double
    nslice, ncoil = 2, 4
    im_size = (11, 33, 24)
    klength = 112

    def _random_complex_stacked(shape):
        # Complex Gaussian data with real and imaginary parts stacked on axis 2.
        data = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
        return torch.tensor(np.stack((np.real(data), np.imag(data)), axis=2)).to(dtype)

    x = _random_complex_stacked((nslice, ncoil) + im_size)
    y = _random_complex_stacked((nslice, ncoil, klength))
    ktraj = torch.randn(nslice, 3, klength).to(dtype)

    numpoints = (2, 4, 6)
    kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints)
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints)
    real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    x_forw = kbnufft_ob(x, ktraj, interp_mats)
    y_back = adjkbnufft_ob(y, ktraj, interp_mats)

    mismatch = inner_product(y, x_forw, dim=2) - inner_product(y_back, x, dim=2)
    assert torch.norm(mismatch) < norm_tol
def test_2d_kbnufft_backward():
    """Check autograd through KbNufft against AdjKbNufft in 2D.

    With ``y = A x``, the gradient of ``sum(y**2) / 2`` with respect to ``x``
    is compared with ``A^H y``. ``norm_tol`` is not assigned in this function,
    so it must be provided at module scope.
    """
    dtype = torch.double
    nslice, ncoil = 2, 4
    im_size = (33, 24)
    klength = 112

    def _random_complex_stacked(shape):
        # Complex Gaussian data with real and imaginary parts stacked on axis 2.
        data = np.random.normal(size=shape) + 1j * np.random.normal(size=shape)
        return torch.tensor(np.stack((np.real(data), np.imag(data)), axis=2)).to(dtype)

    x = _random_complex_stacked((nslice, ncoil) + im_size)
    # This y is replaced by the forward output below; it is still drawn so the
    # random sequence used for ktraj is unchanged.
    y = _random_complex_stacked((nslice, ncoil, klength))
    ktraj = torch.randn(nslice, 2, klength).to(dtype)

    numpoints = (4, 6)
    kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints)
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints)

    x.requires_grad = True
    y = kbnufft_ob.forward(x, ktraj)
    ((y**2) / 2).sum().backward()
    x_grad = x.grad.clone().detach()

    x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj)
    assert torch.norm(x_grad - x_hat) < norm_tol
def test_nufft_2d_adjoint(params_2d, testing_tol, testing_dtype, device_list):
    """Dot-product (adjoint) test for the 2D NUFFT pair on every device.

    ``inner_product(y, A x)`` must match ``inner_product(A^H y, x)`` within
    ``testing_tol``.
    """
    norm_tol = testing_tol
    dtype = testing_dtype
    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]
    x, y, ktraj = params_2d["x"], params_2d["y"], params_2d["ktraj"]

    for device in device_list:
        target = {"dtype": dtype, "device": device}
        x = x.detach().to(**target)
        y = y.detach().to(**target)
        ktraj = ktraj.detach().to(**target)
        kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints).to(**target)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(**target)

        x_forw = kbnufft_ob(x, ktraj)
        y_back = adjkbnufft_ob(y, ktraj)

        mismatch = inner_product(y, x_forw, dim=2) - inner_product(y_back, x, dim=2)
        assert torch.norm(mismatch) < norm_tol
def initialize_plan(self, ratio=2):
    """Create the forward NUFFT plan ``self.nufft_ob`` for ``self.grid``.

    Args:
        ratio: Factor applied to each of the two entries of
            ``self.grid.shape_2d`` to obtain the NUFFT ``grid_size``.
    """
    grid_size = tuple(ratio * self.grid.shape_2d[axis] for axis in (0, 1))
    plan = KbNufft(im_size=self.grid.shape_2d, grid_size=grid_size, norm='ortho')
    self.nufft_ob = plan.to(self.dtype)
def __init__(self, uv_wavelengths, grid):
    """Set up a KbNufft transformer for the given uv coverage and image grid.

    Args:
        uv_wavelengths: Array indexed as ``uv_wavelengths[:, 0]`` and
            ``uv_wavelengths[:, 1]`` (referred to below as u and v).
            NOTE(review): presumably shape ``(N, 2)`` -- confirm with callers.
        grid: Image grid providing ``pixel_scales`` (in arcsec; converted to
            radians with ``units.arcsec.to(units.rad)``) and ``shape_2d``.

    Sets:
        self.dtype: ``torch.float``.
        self.uv_wavelengths: ``(1, 2, N)`` tensor with rows (v, u) scaled so
            that a value of ``1 / (2 * pixel_scale)`` maps to ``pi``.
        self.grid: the ``grid`` argument.
        self.nufft_ob: ``KbNufft`` plan on a 2x larger grid, ``norm='ortho'``.
        self.shift: half-pixel phase factor evaluated on the *unscaled* u, v.
    """
    self.dtype = torch.float
    self.grid = grid

    # uv values equal to 1 / (2 * pixel_scale) are mapped onto pi.
    uv_max = 1.0 / (2.0 * grid.pixel_scales[0] * units.arcsec.to(units.rad))

    # The trajectory must be built from the *scaled* coordinates (rows v, u).
    # Stacking the raw argument here would silently discard the scaling.
    ktraj = np.stack(
        (
            uv_wavelengths[:, 1] / uv_max * np.pi,
            uv_wavelengths[:, 0] / uv_max * np.pi,
        ),
        axis=0,
    )
    self.uv_wavelengths = torch.tensor(ktraj).to(self.dtype).unsqueeze(0)

    # NOTE: The plan need only be initialized once
    ratio = 2
    grid_size = (
        ratio * self.grid.shape_2d[0],
        ratio * self.grid.shape_2d[1],
    )
    self.nufft_ob = KbNufft(
        im_size=self.grid.shape_2d, grid_size=grid_size, norm='ortho'
    ).to(self.dtype)

    # The half-pixel phase shift uses the unscaled uv values (rows u, v, held
    # in a (1, 2, N) tensor of self.dtype), not the pi-scaled trajectory.
    uv_raw = torch.tensor(
        np.stack((uv_wavelengths[:, 0], uv_wavelengths[:, 1]), axis=0)
    ).to(self.dtype).unsqueeze(0)
    half_pixel = self.grid.pixel_scales[0] / 2.0 * units.arcsec.to(units.rad)
    self.shift = np.exp(
        -2.0 * np.pi * 1j * (half_pixel * uv_raw[:, 1] + half_pixel * uv_raw[:, 0])
    )
def test_3d_kbnufft_backward(params_3d, testing_tol, testing_dtype, device_list):
    """Check autograd through KbNufft against AdjKbNufft in 3D on every device.

    Uses precomputed sparse interpolation matrices. With ``y = A x``, the
    gradient of ``sum(y**2) / 2`` with respect to ``x`` must match ``A^H y``
    within ``testing_tol``.
    """
    dtype, norm_tol = testing_dtype, testing_tol
    im_size, numpoints = params_3d['im_size'], params_3d['numpoints']
    x, y, ktraj = params_3d['x'], params_3d['y'], params_3d['ktraj']

    for device in device_list:
        target = {'dtype': dtype, 'device': device}
        x = x.detach().to(**target)
        y = y.detach().to(**target)
        ktraj = ktraj.detach().to(**target)
        kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints).to(**target)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(**target)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

        x.requires_grad = True
        y = kbnufft_ob.forward(x, ktraj, interp_mats)
        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj, interp_mats)
        assert torch.norm(x_grad - x_hat) < norm_tol
def test_3d_kbnufft_adjoint_backward(params_3d, testing_tol, testing_dtype, device_list):
    """Check autograd through AdjKbNufft against KbNufft in 3D on every device.

    Uses precomputed sparse interpolation matrices. With ``x = A^H y``, the
    gradient of ``sum(x**2) / 2`` with respect to ``y`` must match ``A x``
    within ``testing_tol``.
    """
    dtype, norm_tol = testing_dtype, testing_tol
    im_size, numpoints = params_3d["im_size"], params_3d["numpoints"]
    x, y, ktraj = params_3d["x"], params_3d["y"], params_3d["ktraj"]

    for device in device_list:
        target = {"dtype": dtype, "device": device}
        x = x.detach().to(**target)
        y = y.detach().to(**target)
        ktraj = ktraj.detach().to(**target)
        kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints).to(**target)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(**target)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {"real_interp_mats": real_mat, "imag_interp_mats": imag_mat}

        y.requires_grad = True
        x = adjkbnufft_ob.forward(y, ktraj, interp_mats)
        ((x**2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        y_hat = kbnufft_ob.forward(x.clone().detach(), ktraj, interp_mats)
        assert torch.norm(y_grad - y_hat) < norm_tol
def test_kb_matching(testing_tol):
    """Check that all NUFFT / interpolation classes build the same KB tables.

    For every combination of kbwidth, order and image size, the ``table`` of
    ``AdjKbNufft`` is the reference. The tables of KbNufft, KbInterpBack,
    KbInterpForw, MriSenseNufft and AdjMriSenseNufft are compared entry by
    entry against it within ``testing_tol``.
    """
    norm_tol = testing_tol

    def check_tables(table1, table2):
        for idx, reference in enumerate(table1):
            assert np.linalg.norm(reference - table2[idx]) < norm_tol

    im_szs = [(256, 256), (10, 256, 256)]
    kbwidths = [2.34, 5]
    orders = [0, 2]

    for kbwidth in kbwidths:
        for order in orders:
            for im_sz in im_szs:
                smap = torch.randn(*((1,) + im_sz))
                base_table = AdjKbNufft(im_sz, order=order, kbwidth=kbwidth).table
                # Each candidate is built and checked in turn.
                candidates = (
                    (KbNufft, (im_sz,)),
                    (KbInterpBack, (im_sz,)),
                    (KbInterpForw, (im_sz,)),
                    (MriSenseNufft, (smap, im_sz)),
                    (AdjMriSenseNufft, (smap, im_sz)),
                )
                for cls, args in candidates:
                    cur_table = cls(*args, order=order, kbwidth=kbwidth).table
                    check_tables(base_table, cur_table)
def test_nufft_3d_adjoint(params_3d, testing_tol, testing_dtype, device_list):
    """Dot-product (adjoint) test for the 3D NUFFT pair on every device.

    Uses precomputed sparse interpolation matrices. ``inner_product(y, A x)``
    must match ``inner_product(A^H y, x)`` within ``testing_tol``.
    """
    dtype, norm_tol = testing_dtype, testing_tol
    im_size, numpoints = params_3d["im_size"], params_3d["numpoints"]
    x, y, ktraj = params_3d["x"], params_3d["y"], params_3d["ktraj"]

    for device in device_list:
        target = {"dtype": dtype, "device": device}
        x = x.detach().to(**target)
        y = y.detach().to(**target)
        ktraj = ktraj.detach().to(**target)
        kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints).to(**target)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(**target)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {"real_interp_mats": real_mat, "imag_interp_mats": imag_mat}

        x_forw = kbnufft_ob(x, ktraj, interp_mats)
        y_back = adjkbnufft_ob(y, ktraj, interp_mats)

        mismatch = inner_product(y, x_forw, dim=2) - inner_product(y_back, x, dim=2)
        assert torch.norm(mismatch) < norm_tol
def test_2d_init_inputs():
    """Construct every NUFFT / interpolation class with 2D arguments.

    All object initializations have assertions, so this fails if any
    dimensions don't match. The classes are built twice: first with scalar
    parameters, then with one value per dimension (tuples).
    """
    im_sz = (256, 256)
    grid_sz = (512, 512)
    n_shift = (128, 128)
    norm = "None"

    param_sets = (
        # scalar inputs
        dict(numpoints=6, table_oversamp=2 ** 10, kbwidth=2.34, order=0),
        # tuple inputs
        dict(
            numpoints=(6, 6),
            table_oversamp=(2 ** 10, 2 ** 10),
            kbwidth=(2.34, 2.34),
            order=(0, 0),
        ),
    )

    for params in param_sets:
        smap = torch.randn(*((1,) + im_sz))
        common = dict(im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, **params)
        KbInterpForw(**common)
        KbInterpBack(**common)
        KbNufft(**common, norm=norm)
        AdjKbNufft(**common, norm=norm)
        MriSenseNufft(smap=smap, **common, norm=norm)
        AdjMriSenseNufft(smap=smap, **common, norm=norm)
def __init__(self, im_size, ktraj, dcomp, csmap, norm='ortho'):
    """Build the encoding operators for dynamic 2D radial imaging.

    Inputs:
        - im_size ... the shape of the object to be imaged, (Nx, Ny, Nt)
        - ktraj ... the k-space trajectories (torch tensor)
        - dcomp ... the density compensation (torch tensor)
        - csmap ... the coil-sensitivity maps (torch tensor); if None,
          single-coil KbNufft/AdjKbNufft/ToepNufft operators are used
          instead of the SENSE ones
        - norm ... normalization passed to the forward/adjoint NUFFTs

    N.B. the shapes for the tensors are the following:
        image (1,1,2,Nx,Ny,Nt), ktraj (1,2,Nrad,Nt), csmap (1,Nc,2,Nx,Ny),
        dcomp (1,1,1,Nrad,Nt), kdata (1,Nc,2,Nrad,Nt),
    where
        - Nx,Ny,Nt = im_size
        - Nrad = 2*Ny*n_spokes
        - n_spokes = the number of spokes per 2D frame
        - Nc = the number of receiver coils

    Toeplitz kernels are precomputed for every frame, both with and without
    the density compensation weights.
    """
    super(Dyn2DRadEncObj, self).__init__()

    dtype = torch.float
    # Unpacking also enforces that im_size has exactly three entries.
    _nx, _ny, Nt = im_size

    self.im_size = im_size
    self.spokelength = im_size[1] * 2
    self.Nrad = ktraj.shape[2]
    self.ktraj_tensor = ktraj
    self.dcomp_tensor = dcomp

    # parameters for constructing the operators: a square grid whose side
    # is twice im_size[1]
    grid_size = (self.spokelength, self.spokelength)

    # single-coil/multi-coil NUFFTs
    if csmap is not None:
        self.NUFFT = MriSenseNufft(smap=csmap, im_size=im_size[:2],
                                   grid_size=grid_size, norm=norm).to(dtype)
        self.AdjNUFFT = AdjMriSenseNufft(smap=csmap, im_size=im_size[:2],
                                         grid_size=grid_size, norm=norm).to(dtype)
        self.ToepNUFFT = ToepSenseNufft(csmap).to(dtype)
    else:
        self.NUFFT = KbNufft(im_size=im_size[:2], grid_size=grid_size,
                             norm=norm).to(dtype)
        self.AdjNUFFT = AdjKbNufft(im_size=im_size[:2], grid_size=grid_size,
                                   norm=norm).to(dtype)
        self.ToepNUFFT = ToepNufft().to(dtype)

    # calculate Toeplitz kernels, one per frame
    # for E^{\dagger} \circ E (density-compensated)
    self.AdagA_toep_kernel_list = [
        calc_toep_kernel(self.AdjNUFFT, ktraj[..., kt], weights=dcomp[..., kt])
        for kt in range(Nt)
    ]
    # for E^H \circ E
    self.AHA_toep_kernel_list = [
        calc_toep_kernel(self.AdjNUFFT, ktraj[..., kt])
        for kt in range(Nt)
    ]
zbl * 1e-6, (prior_fwhm, prior_fwhm, 0, prior_fwhm, prior_fwhm)) simim = prior.copy() # simim = eh.image.load_fits(gt_path) # simim = simim.regrid_image(fov, npix) simim.ra = obs.ra simim.dec = obs.dec simim.rf = obs.rf save_path = args.save_path if not os.path.exists(save_path): os.makedirs(save_path) # define the eht observation function nufft_ob = KbNufft(im_size=(npix, npix), numpoints=3) ktraj_vis, pulsefac_vis_torch, cphase_ind_list, cphase_sign_list, camp_ind_list = Obs_params_torch( obs, simim) eht_obs_torch = eht_observation_pytorch(npix, nufft_ob, ktraj_vis, pulsefac_vis_torch, cphase_ind_list, cphase_sign_list, camp_ind_list, device) if args.model_form == 'realnvp': n_flow = args.n_flow affine = True img_generator = realnvpfc_model.RealNVP(npix * npix, n_flow, affine=affine).to(device) # img_generator.load_state_dict(torch.load(save_path+'/generativemodel_'+args.model_form+'_res{}flow{}logdet{}_closure_fluxcentermemtsv'.format(npix, n_flow, args.logdet))) elif args.model_form == 'glow':
# uv_wavelengths[:, 1] / (1.0 / (2.0 * grid.pixel_scales[0] * units.arcsec.to(units.rad))) * np.pi, # uv_wavelengths[:, 0] / (1.0 / (2.0 * grid.pixel_scales[0] * units.arcsec.to(units.rad))) * np.pi # ]).T # # uv_wavelengths_temp = np.stack( # ( # np.transpose(uv_wavelengths_temp[:, 0]).flatten(), # np.transpose(uv_wavelengths_temp[:, 1]).flatten() # ), # axis=0 # ) # # ktraj = torch.tensor(uv_wavelengths_temp).to(dtype).unsqueeze(0) nufft_ob = KbNufft(im_size=im_size, grid_size=grid_size, norm='ortho').to(dtype) adjnufft_ob = AdjKbNufft(im_size=im_size, grid_size=grid_size, norm='ortho').to(dtype) # calculate k-space data #print(image.shape, ktraj.shape);exit() kdata = nufft_ob(_image, ktraj) #print(kdata[0, 0, 0, :]) #exit() # real_visibilities__from__kbnufft_transformer = kdata[0, 0, 0, :] # real_visibilities__from__kbnufft_transformer *= shift.real # imag_visibilities__from__kbnufft_transformer = kdata[0, 0, 1, :] # imag_visibilities__from__kbnufft_transformer *= shift.imag # image_blurry = adjnufft_ob(kdata, ktraj) #