Example #1

# Shared imports for the snippets below; module paths are assumed from the
# pre-1.0 torchkbnufft API these examples are written against.
import time

import torch

from torchkbnufft import (AdjKbNufft, AdjMriSenseNufft, KbNufft,
                          MriSenseNufft, ToepNufft, ToepSenseNufft)
from torchkbnufft.math import inner_product
from torchkbnufft.nufft.sparse_interp_mat import precomp_sparse_mats
from torchkbnufft.nufft.toep_functions import calc_toep_kernel


def test_toeplitz_nufft_3d(params_3d, testing_tol, testing_dtype, device_list):
    dtype = testing_dtype
    # This tolerance looks loose, but the Toeplitz approach struggles with
    # random trajectories; for radial trajectories it is closer to 1e-6.
    norm_tol = 1e-1

    im_size = params_3d["im_size"]
    numpoints = params_3d["numpoints"]

    x = params_3d["x"]
    ktraj = params_3d["ktraj"]

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        toep_ob = ToepNufft()

        kern = calc_toep_kernel(adjkbnufft_ob, ktraj)

        normal_forw = adjkbnufft_ob(kbnufft_ob(x, ktraj), ktraj)

        toep_forw = toep_ob(x, kern)

        diff = torch.norm(normal_forw - toep_forw) / torch.norm(normal_forw)

        assert diff < norm_tol
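

# What the test above exploits: the normal operator adjoint(forward(x)) can be
# applied as an FFT-domain multiplication with a kernel on a twice-oversampled
# grid. Below is a minimal single-coil sketch of that embedding, written with
# complex tensors rather than the stacked real/imaginary layout used in these
# snippets; `toep_apply` and its crop convention are illustrative assumptions,
# not the torchkbnufft implementation.
def toep_apply(x, kern):
    # x: complex image of shape (Nx, Ny); kern: complex FFT-domain kernel of
    # shape (2*Nx, 2*Ny), e.g. the FFT of the embedded point spread function
    Nx, Ny = x.shape
    x_pad = torch.zeros(2 * Nx, 2 * Ny, dtype=x.dtype, device=x.device)
    x_pad[:Nx, :Ny] = x
    # multiply on the oversampled grid in frequency space, then crop back
    y = torch.fft.ifftn(torch.fft.fftn(x_pad) * kern)
    return y[:Nx, :Ny]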


def test_toepnufft_2d_adjoint(params_2d, testing_tol, testing_dtype, device_list):
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]

    x = params_2d["x"]
    y = torch.randn(x.shape)
    ktraj = params_2d["ktraj"]

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        toep_ob = ToepNufft().to(dtype=dtype, device=device)

        kern = calc_toep_kernel(adjkbnufft_ob, ktraj)

        x_forw = toep_ob(x, kern)
        y_back = toep_ob(y, kern)

        inprod1 = inner_product(y, x_forw, dim=2)
        inprod2 = inner_product(y_back, x, dim=2)

        assert torch.norm(inprod1 - inprod2) < norm_tol
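

# The adjoint test above relies on torchkbnufft's inner_product helper. For
# reference, a compatible sketch for the stacked real/imaginary layout,
# assuming `dim` indexes the size-2 complex axis and the product is summed
# over all remaining axes; `inner_product_sketch` is an illustrative name,
# not the library function.
def inner_product_sketch(t1, t2, dim=0):
    # <t1, t2> = sum(conj(t1) * t2), with complex values stored along a
    # size-2 (real, imag) axis at position `dim`
    re1, im1 = t1.select(dim, 0), t1.select(dim, 1)
    re2, im2 = t2.select(dim, 0), t2.select(dim, 1)
    real_part = (re1 * re2 + im1 * im2).sum()
    imag_part = (re1 * im2 - im1 * re2).sum()
    return torch.stack([real_part, imag_part])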


class Dyn2DRadEncObj(torch.nn.Module):

    def __init__(self, im_size, ktraj, dcomp, csmap, norm='ortho'):
        """
        inputs:
        - im_size ... the shape of the object to be imaged
        - ktraj   ... the k-space trajectories (torch tensor)
        - dcomp   ... the density compensation (torch tensor)
        - csmap   ... the coil-sensitivity maps (torch tensor)

        N.B. the tensor shapes are the following:

            image   (1,1,2,Nx,Ny,Nt),
            ktraj   (1,2,Nrad,Nt),
            csmap   (1,Nc,2,Nx,Ny),
            dcomp   (1,1,1,Nrad,Nt),
            kdata   (1,Nc,2,Nrad,Nt),

        where
        - Nx,Ny,Nt = im_size
        - Nrad = 2*Ny*n_spokes
        - n_spokes = the number of spokes per 2D frame
        - Nc = the number of receiver coils
        """
        super(Dyn2DRadEncObj, self).__init__()

        dtype = torch.float

        Nx, Ny, Nt = im_size
        self.im_size = im_size
        self.spokelength = im_size[1] * 2
        self.Nrad = ktraj.shape[2]
        self.ktraj_tensor = ktraj
        self.dcomp_tensor = dcomp

        # parameters for constructing the operators (2x-oversampled NUFFT grid)
        grid_size = (self.spokelength, self.spokelength)

        # single-coil / multi-coil NUFFT operators
        if csmap is not None:
            self.NUFFT = MriSenseNufft(smap=csmap,
                                       im_size=im_size[:2],
                                       grid_size=grid_size,
                                       norm=norm).to(dtype)
            self.AdjNUFFT = AdjMriSenseNufft(smap=csmap,
                                             im_size=im_size[:2],
                                             grid_size=grid_size,
                                             norm=norm).to(dtype)
            self.ToepNUFFT = ToepSenseNufft(csmap).to(dtype)
        else:
            self.NUFFT = KbNufft(im_size=im_size[:2],
                                 grid_size=grid_size,
                                 norm=norm).to(dtype)
            self.AdjNUFFT = AdjKbNufft(im_size=im_size[:2],
                                       grid_size=grid_size,
                                       norm=norm).to(dtype)
            self.ToepNUFFT = ToepNufft().to(dtype)

        # calculate the Toeplitz kernels
        # for E^{\dagger} \circ E, where E^{\dagger} = E^H W is the
        # density-compensated adjoint (W given by dcomp)
        self.AdagA_toep_kernel_list = [
            calc_toep_kernel(self.AdjNUFFT,
                             ktraj[..., kt],
                             weights=dcomp[..., kt]) for kt in range(Nt)
        ]

        # for the plain (uncompensated) normal operator E^H \circ E
        self.AHA_toep_kernel_list = [
            calc_toep_kernel(self.AdjNUFFT, ktraj[..., kt]) for kt in range(Nt)
        ]
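
    # Hypothetical sketch, not part of the original snippet: the per-frame
    # kernel lists above would be consumed frame by frame, assuming x follows
    # the (1,1,2,Nx,Ny,Nt) layout documented in __init__.
    def apply_AdagA_toeplitz(self, x):
        # apply the density-compensated normal operator to every temporal
        # frame using the precomputed Toeplitz kernels
        y = torch.zeros_like(x)
        for kt, kern in enumerate(self.AdagA_toep_kernel_list):
            y[..., kt] = self.ToepNUFFT(x[..., kt], kern)
        return y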

Example #4

def profile_torchkbnufft(image,
                         ktraj,
                         smap,
                         im_size,
                         device,
                         sparse_mats_flag=False,
                         use_toep=False):
    # run double precision on the CPU and single precision on the GPU; these
    # match the precisions used in reference implementations
    if device == torch.device("cpu"):
        dtype = torch.double
        if use_toep:
            num_nuffts = 20
        else:
            num_nuffts = 5
    else:
        dtype = torch.float
        if use_toep:
            num_nuffts = 50
        else:
            num_nuffts = 20
    cpudevice = torch.device("cpu")

    image = image.to(dtype=dtype)
    ktraj = ktraj.to(dtype=dtype)
    smap = smap.to(dtype=dtype)

    kbsense_ob = MriSenseNufft(smap=smap, im_size=im_size).to(dtype=dtype,
                                                              device=device)
    adjkbsense_ob = AdjMriSenseNufft(smap=smap,
                                     im_size=im_size).to(dtype=dtype,
                                                         device=device)

    adjkbnufft_ob = AdjKbNufft(im_size=im_size).to(dtype=dtype, device=device)

    # precompute toeplitz kernel if using toeplitz
    if use_toep:
        print("using toeplitz for forward/backward")
        kern = calc_toep_kernel(adjkbsense_ob, ktraj)
        toep_ob = ToepSenseNufft(smap=smap).to(dtype=dtype, device=device)

    # precompute the sparse interpolation matrices
    if sparse_mats_flag:
        print("using sparse interpolation matrices")
        real_mat, imag_mat = precomp_sparse_mats(ktraj, adjkbnufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }
    else:
        print("not using sparse interpolation matrices")
        interp_mats = None

    if use_toep:
        # warm-up computation
        for _ in range(num_nuffts):
            x = toep_ob(image.to(device=device),
                        kern.to(device=device)).to(cpudevice)
        # run the speed tests
        if device == torch.device("cuda"):
            torch.cuda.reset_max_memory_allocated()
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_nuffts):
            x = toep_ob(image.to(device=device), kern.to(device=device))
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
            max_mem = torch.cuda.max_memory_allocated()
            print("GPU forward max memory: {} GB".format(max_mem / 1e9))
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_nuffts
        print("toeplitz forward/backward average time: {}".format(avg_time))
    else:
        # warm-up computation
        for _ in range(num_nuffts):
            y = kbsense_ob(image.to(device=device), ktraj.to(device=device),
                           interp_mats).to(cpudevice)

        # run the forward speed tests
        if device == torch.device("cuda"):
            torch.cuda.reset_max_memory_allocated()
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_nuffts):
            y = kbsense_ob(image.to(device=device), ktraj.to(device=device),
                           interp_mats)
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
            max_mem = torch.cuda.max_memory_allocated()
            print("GPU forward max memory: {} GB".format(max_mem / 1e9))
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_nuffts
        print("forward average time: {}".format(avg_time))

        # warm-up computation
        for _ in range(num_nuffts):
            x = adjkbsense_ob(y.to(device), ktraj.to(device), interp_mats)

        # run the adjoint speed tests
        if device == torch.device("cuda"):
            torch.cuda.reset_max_memory_allocated()
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_nuffts):
            x = adjkbsense_ob(y.to(device), ktraj.to(device), interp_mats)
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
            max_mem = torch.cuda.max_memory_allocated()
            print("GPU adjoint max memory: {} GB".format(max_mem / 1e9))
        end_time = time.perf_counter()
        avg_time = (end_time - start_time) / num_nuffts
        print("backward average time: {}".format(avg_time))