Example #1
0
def test_2d_kbnufft_backward():
    """Check autograd through 2D KbNufft against an explicit AdjKbNufft.

    With ``y = A x`` and loss ``sum(y**2) / 2``, the test asserts that the
    autograd gradient with respect to ``x`` matches ``AdjKbNufft`` applied
    to ``y``.

    ``norm_tol`` is not defined in this function; presumably it is a
    module-level tolerance -- confirm in the enclosing test module.
    """
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (33, 24)
    klength = 112

    # Complex image data stored as stacked (real, imag) along axis 2:
    # shape (nslice, ncoil, 2) + im_size.
    x = np.random.normal(size=(nslice, ncoil) + im_size) + \
        1j*np.random.normal(size=(nslice, ncoil) + im_size)
    x = torch.tensor(np.stack((np.real(x), np.imag(x)), axis=2)).to(dtype)

    # Random trajectory: 2 coordinates per k-space sample, one set per slice.
    # (No random k-space input is needed: y is produced by the forward op.)
    ktraj = torch.randn(*(nslice, 2, klength)).to(dtype)

    kbnufft_ob = KbNufft(im_size=im_size, numpoints=(4, 6))
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=(4, 6))

    x.requires_grad = True
    y = kbnufft_ob.forward(x, ktraj)

    ((y**2) / 2).sum().backward()
    x_grad = x.grad.clone().detach()

    # Explicit adjoint of the (detached) forward output.
    x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj)

    assert torch.norm(x_grad - x_hat) < norm_tol
Example #2
0
def test_2d_kbnufft_backward(params_2d, testing_tol, testing_dtype,
                             device_list):
    """Check autograd through 2D KbNufft against AdjKbNufft on each device.

    With ``y = A x`` and loss ``sum(y**2) / 2``, the test asserts that the
    autograd gradient with respect to ``x`` matches ``AdjKbNufft`` applied
    to ``y``, within ``testing_tol``.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_2d["im_size"]
    numpoints = params_2d["numpoints"]

    # The fixture's "y" entry is not needed: y is the forward output below.
    x = params_2d["x"]
    ktraj = params_2d["ktraj"]

    for device in device_list:
        # detach() yields a new tensor, so x.grad starts out empty on every
        # iteration once requires_grad is re-enabled below.
        x = x.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size,
                             numpoints=numpoints).to(dtype=dtype,
                                                     device=device)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)

        x.requires_grad = True
        y = kbnufft_ob.forward(x, ktraj)

        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj)

        assert torch.norm(x_grad - x_hat) < norm_tol
Example #3
0
def test_3d_kbnufft_adjoint_backward():
    """Check autograd through 3D AdjKbNufft against an explicit KbNufft.

    With ``x = A^H y`` and loss ``sum(x**2) / 2``, the test asserts that the
    autograd gradient with respect to ``y`` matches ``KbNufft`` applied to
    ``x``. Both operators use the same precomputed sparse interpolation
    matrices.

    ``norm_tol`` is not defined in this function; presumably it is a
    module-level tolerance -- confirm in the enclosing test module.
    """
    dtype = torch.double

    nslice = 2
    ncoil = 4
    im_size = (11, 33, 24)
    klength = 112

    # Complex k-space data stored as stacked (real, imag) along axis 2:
    # shape (nslice, ncoil, 2, klength). No random image input is needed:
    # x is produced by the adjoint op below.
    y = np.random.normal(size=(nslice, ncoil, klength)) + \
        1j*np.random.normal(size=(nslice, ncoil, klength))
    y = torch.tensor(np.stack((np.real(y), np.imag(y)), axis=2)).to(dtype)

    # Random trajectory: 3 coordinates per k-space sample, one set per slice.
    ktraj = torch.randn(*(nslice, 3, klength)).to(dtype)

    kbnufft_ob = KbNufft(im_size=im_size, numpoints=(2, 4, 6))
    adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=(2, 4, 6))

    real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    y.requires_grad = True
    x = adjkbnufft_ob.forward(y, ktraj, interp_mats)

    ((x**2) / 2).sum().backward()
    y_grad = y.grad.clone().detach()

    # Explicit forward op applied to the (detached) adjoint output.
    y_hat = kbnufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

    assert torch.norm(y_grad - y_hat) < norm_tol
Example #4
0
def setup():
    """Build a golden-angle radial trajectory and matching NUFFT operators.

    Returns:
        ``(ktraj, nufft_ob, torch_forward, torch_backward)`` where ``ktraj``
        is a numpy array of shape ``(2, spokelength * nspokes)`` whose rows
        are ``(ky, kx)`` laid out spoke by spoke, ``nufft_ob`` is a
        ``KbNufftModule``, and the last two are ``KbNufft`` / ``AdjKbNufft``.
        All three operators share the same ``im_size`` and ``grid_size`` and
        use ``norm='ortho'``.
    """
    spokelength = 400
    nspokes = 10
    grid_size = (spokelength, spokelength)
    im_size = (200, 200)

    # Golden-angle increment: 180 degrees divided by the golden ratio.
    golden_ratio = (1 + np.sqrt(5)) / 2
    ga = np.deg2rad(180 / golden_ratio)
    cos_ga = np.cos(ga)
    sin_ga = np.sin(ga)

    kx = np.zeros(shape=(spokelength, nspokes))
    ky = np.zeros(shape=(spokelength, nspokes))

    # Spoke 0 lies along ky; each later spoke is the previous spoke rotated
    # by the golden angle.
    ky[:, 0] = np.linspace(-np.pi, np.pi, spokelength)
    for spoke in range(1, nspokes):
        prev_kx = kx[:, spoke - 1]
        prev_ky = ky[:, spoke - 1]
        kx[:, spoke] = cos_ga * prev_kx - sin_ga * prev_ky
        ky[:, spoke] = sin_ga * prev_kx + cos_ga * prev_ky

    # Transpose so flattening walks one spoke at a time.
    ktraj = np.stack((ky.T.flatten(), kx.T.flatten()), axis=0)

    shared_kwargs = dict(im_size=im_size, grid_size=grid_size, norm='ortho')
    nufft_ob = KbNufftModule(**shared_kwargs)
    torch_forward = KbNufft(**shared_kwargs)
    torch_backward = AdjKbNufft(**shared_kwargs)
    return ktraj, nufft_ob, torch_forward, torch_backward
Example #5
0
def test_toeplitz_nufft_3d(params_3d, testing_tol, testing_dtype, device_list):
    """Compare the Toeplitz normal operator with an explicit A^H A (3D).

    On each device, ``AdjKbNufft(KbNufft(x))`` is compared with
    ``ToepNufft`` applied with a kernel from ``calc_toep_kernel``. The
    relative error must stay below a fixed tolerance (see below);
    ``testing_tol`` is accepted but intentionally not used.
    """
    dtype = testing_dtype
    # this tolerance looks really bad, but toep struggles with random traj
    # for radial it's more like 1e-06
    norm_tol = 1e-1

    im_size = params_3d["im_size"]
    numpoints = params_3d["numpoints"]

    # Only the image-domain input and the trajectory are needed here.
    x = params_3d["x"]
    ktraj = params_3d["ktraj"]

    for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        adjkbnufft_ob = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        toep_ob = ToepNufft()

        kern = calc_toep_kernel(adjkbnufft_ob, ktraj)

        normal_forw = adjkbnufft_ob(kbnufft_ob(x, ktraj), ktraj)

        toep_forw = toep_ob(x, kern)

        # Relative error of the Toeplitz result w.r.t. the explicit A^H A x.
        diff = torch.norm(normal_forw - toep_forw) / torch.norm(normal_forw)

        assert diff < norm_tol
def test_nufft_2d_adjoint():
    """Dot-product adjointness test for the 2D KbNufft / AdjKbNufft pair.

    For random ``x`` and ``y``, the test asserts that
    ``inner_product(y, A x, dim=2)`` matches
    ``inner_product(A^H y, x, dim=2)``.

    ``norm_tol`` is not defined in this function; presumably it is a
    module-level tolerance -- confirm in the enclosing test module.
    """
    dtype = torch.double
    nslice, ncoil = 2, 4
    im_size = (33, 24)
    klength = 112

    def random_stacked_complex(shape):
        # Complex Gaussian samples stored as (real, imag) stacked on axis 2.
        data = np.random.normal(size=shape) + 1j*np.random.normal(size=shape)
        stacked = np.stack((np.real(data), np.imag(data)), axis=2)
        return torch.tensor(stacked).to(dtype)

    x = random_stacked_complex((nslice, ncoil) + im_size)
    y = random_stacked_complex((nslice, ncoil, klength))
    ktraj = torch.randn(*(nslice, 2, klength)).to(dtype)

    forward_op = KbNufft(im_size=im_size, numpoints=(4, 6))
    adjoint_op = AdjKbNufft(im_size=im_size, numpoints=(4, 6))

    forward_x = forward_op(x, ktraj)
    adjoint_y = adjoint_op(y, ktraj)

    lhs = inner_product(y, forward_x, dim=2)
    rhs = inner_product(adjoint_y, x, dim=2)

    assert torch.norm(lhs - rhs) < norm_tol
def test_nufft_3d_adjoint():
    """Dot-product adjointness test for the 3D KbNufft / AdjKbNufft pair.

    Both operators use the same precomputed sparse interpolation matrices.
    For random ``x`` and ``y``, the test asserts that
    ``inner_product(y, A x, dim=2)`` matches
    ``inner_product(A^H y, x, dim=2)``.

    ``norm_tol`` is not defined in this function; presumably it is a
    module-level tolerance -- confirm in the enclosing test module.
    """
    dtype = torch.double
    nslice, ncoil = 2, 4
    im_size = (11, 33, 24)
    klength = 112

    def random_stacked_complex(shape):
        # Complex Gaussian samples stored as (real, imag) stacked on axis 2.
        data = np.random.normal(size=shape) + 1j*np.random.normal(size=shape)
        stacked = np.stack((np.real(data), np.imag(data)), axis=2)
        return torch.tensor(stacked).to(dtype)

    x = random_stacked_complex((nslice, ncoil) + im_size)
    y = random_stacked_complex((nslice, ncoil, klength))
    ktraj = torch.randn(*(nslice, 3, klength)).to(dtype)

    forward_op = KbNufft(im_size=im_size, numpoints=(2, 4, 6))
    adjoint_op = AdjKbNufft(im_size=im_size, numpoints=(2, 4, 6))

    real_mat, imag_mat = precomp_sparse_mats(ktraj, forward_op)
    interp_mats = {'real_interp_mats': real_mat, 'imag_interp_mats': imag_mat}

    forward_x = forward_op(x, ktraj, interp_mats)
    adjoint_y = adjoint_op(y, ktraj, interp_mats)

    lhs = inner_product(y, forward_x, dim=2)
    rhs = inner_product(adjoint_y, x, dim=2)

    assert torch.norm(lhs - rhs) < norm_tol
Example #8
0
def test_nufft_2d_adjoint(params_2d, testing_tol, testing_dtype, device_list):
    """Dot-product adjointness test for 2D KbNufft / AdjKbNufft per device.

    On every device in ``device_list``, the test asserts that
    ``inner_product(y, A x, dim=2)`` matches
    ``inner_product(A^H y, x, dim=2)`` to within ``testing_tol``.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size, numpoints = params_2d["im_size"], params_2d["numpoints"]
    x, y, ktraj = params_2d["x"], params_2d["y"], params_2d["ktraj"]

    for device in device_list:
        x, y, ktraj = (
            tensor.detach().to(dtype=dtype, device=device)
            for tensor in (x, y, ktraj)
        )

        forward_op = KbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        adjoint_op = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )

        forward_x = forward_op(x, ktraj)
        adjoint_y = adjoint_op(y, ktraj)

        lhs = inner_product(y, forward_x, dim=2)
        rhs = inner_product(adjoint_y, x, dim=2)

        assert torch.norm(lhs - rhs) < norm_tol
Example #9
0
    def initialize_plan(self, ratio=2):
        """Create and store the KbNufft operator for ``self.grid``.

        The NUFFT grid is ``ratio`` times ``self.grid.shape_2d`` along each
        of the two axes. The operator uses ``norm='ortho'``, is cast to
        ``self.dtype`` and is stored as ``self.nufft_ob``.

        Args:
            ratio: per-axis oversampling factor for the NUFFT grid.
        """
        rows = self.grid.shape_2d[0]
        cols = self.grid.shape_2d[1]
        oversampled = (ratio * rows, ratio * cols)

        operator = KbNufft(
            im_size=self.grid.shape_2d,
            grid_size=oversampled,
            norm='ortho'
        )
        self.nufft_ob = operator.to(self.dtype)
def test_3d_kbnufft_backward(params_3d, testing_tol, testing_dtype,
                             device_list):
    """Check autograd through 3D KbNufft against AdjKbNufft on each device.

    With ``y = A x`` and loss ``sum(y**2) / 2``, the test asserts that the
    autograd gradient with respect to ``x`` matches ``AdjKbNufft`` applied
    to ``y``, within ``testing_tol``. Both operators use the same
    precomputed sparse interpolation matrices.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_3d['im_size']
    numpoints = params_3d['numpoints']

    # The fixture's 'y' entry is not needed: y is the forward output below.
    x = params_3d['x']
    ktraj = params_3d['ktraj']

    for device in device_list:
        # detach() yields a new tensor, so x.grad starts out empty on every
        # iteration once requires_grad is re-enabled below.
        x = x.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size,
                             numpoints=numpoints).to(dtype=dtype,
                                                     device=device)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        x.requires_grad = True
        y = kbnufft_ob.forward(x, ktraj, interp_mats)

        ((y**2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad - x_hat) < norm_tol
def test_3d_kbnufft_adjoint_backward(params_3d, testing_tol, testing_dtype,
                                     device_list):
    """Check autograd through 3D AdjKbNufft against KbNufft on each device.

    With ``x = A^H y`` and loss ``sum(x**2) / 2``, the test asserts that the
    autograd gradient with respect to ``y`` matches ``KbNufft`` applied to
    ``x``, within ``testing_tol``. Both operators use the same precomputed
    sparse interpolation matrices.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size = params_3d["im_size"]
    numpoints = params_3d["numpoints"]

    # The fixture's "x" entry is not needed: x is the adjoint output below.
    y = params_3d["y"]
    ktraj = params_3d["ktraj"]

    for device in device_list:
        # detach() yields a new tensor, so y.grad starts out empty on every
        # iteration once requires_grad is re-enabled below.
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(im_size=im_size,
                             numpoints=numpoints).to(dtype=dtype,
                                                     device=device)
        adjkbnufft_ob = AdjKbNufft(im_size=im_size,
                                   numpoints=numpoints).to(dtype=dtype,
                                                           device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }

        y.requires_grad = True
        x = adjkbnufft_ob.forward(y, ktraj, interp_mats)

        ((x**2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        y_hat = kbnufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

        assert torch.norm(y_grad - y_hat) < norm_tol
Example #12
0
    def __init__(self, uv_wavelengths, grid):
        """Set up a KbNufft transformer for the given baselines and grid.

        Args:
            uv_wavelengths: 2D array of baselines whose column 0 and column 1
                are used as the two k-space coordinates (presumably u and v
                in wavelengths -- confirm with the caller).
            grid: object providing ``pixel_scales`` (converted from arcsec
                to radians below) and ``shape_2d`` (the NUFFT image size).

        Attributes set:
            dtype: ``torch.float``.
            uv_wavelengths: k-space trajectory tensor of shape
                ``(1, 2, n_vis)``. Its rows are input column 1 then input
                column 0, each multiplied by ``2 * pi * pixel_scale_rad``.
            grid: the ``grid`` argument.
            nufft_ob: ``KbNufft`` on a 2x oversampled grid with
                ``norm='ortho'``, cast to ``dtype``.
            shift: per-visibility phase factor computed from the unscaled
                baselines.
        """
        self.dtype = torch.float

        def to_ktraj_tensor(uv):
            # Assumes uv is (n_vis, 2): column 0 -> row 0, column 1 -> row 1,
            # giving a (1, 2, n_vis) tensor of self.dtype.
            stacked = np.stack(
                (
                    np.transpose(uv[:, 0]).flatten(),
                    np.transpose(uv[:, 1]).flatten()
                ),
                axis=0
            )
            return torch.tensor(stacked).to(self.dtype).unsqueeze(0)

        # Scale baselines to NUFFT k-space coordinates:
        #   uv / (1 / (2 * pixel_scale_rad)) * pi == 2 * pi * pixel_scale_rad * uv
        # and swap the columns. Previously this scaled array was computed
        # and then discarded, because the stacking step read the raw
        # ``uv_wavelengths`` argument instead.
        inv_scale = 1.0 / (2.0 * grid.pixel_scales[0] * units.arcsec.to(units.rad))
        scaled_uv = np.array([
            uv_wavelengths[:, 1] / inv_scale * np.pi,
            uv_wavelengths[:, 0] / inv_scale * np.pi
        ]).T
        self.uv_wavelengths = to_ktraj_tensor(scaled_uv)

        self.grid = grid

        ratio = 2
        grid_size = (
            ratio * self.grid.shape_2d[0],
            ratio * self.grid.shape_2d[1]
        )

        self.nufft_ob = KbNufft(
            im_size=self.grid.shape_2d,
            grid_size=grid_size,
            norm='ortho'
        ).to(self.dtype)

        # Phase factor exp(-2*pi*i * (pixel_scale_rad / 2) * (col1 + col0)).
        # It is evaluated on the *unscaled* baselines, held as a
        # (1, 2, n_vis) tensor, exactly as the code did before the
        # trajectory fix above.
        raw_uv = to_ktraj_tensor(uv_wavelengths)
        half_pixel_rad = self.grid.pixel_scales[0]/2.0 * units.arcsec.to(units.rad)
        self.shift = np.exp(
            -2.0
            * np.pi
            * 1j
            * (
                half_pixel_rad * raw_uv[:, 1]
                + half_pixel_rad * raw_uv[:, 0]
            )
        )
Example #13
0
def test_kb_matching(testing_tol):
    """Check that all NUFFT-family classes build the same Kaiser-Bessel tables.

    For every (kbwidth, order, image size) combination, the ``table`` of
    ``KbNufft``, ``KbInterpBack``, ``KbInterpForw``, ``MriSenseNufft`` and
    ``AdjMriSenseNufft`` is compared entry by entry with ``AdjKbNufft``'s
    table, to within ``testing_tol``.
    """
    norm_tol = testing_tol

    def check_tables(reference, candidate):
        # Compare entry-wise, indexing the candidate by the reference's
        # positions.
        for position, entry in enumerate(reference):
            assert np.linalg.norm(entry - candidate[position]) < norm_tol

    im_szs = [(256, 256), (10, 256, 256)]
    kbwidths = [2.34, 5]
    orders = [0, 2]

    for kbwidth in kbwidths:
        for order in orders:
            for im_sz in im_szs:
                smap = torch.randn(*((1,) + im_sz))
                kb_kwargs = dict(order=order, kbwidth=kbwidth)

                reference = AdjKbNufft(im_sz, **kb_kwargs).table

                for plain_cls in (KbNufft, KbInterpBack, KbInterpForw):
                    check_tables(reference, plain_cls(im_sz, **kb_kwargs).table)

                for sense_cls in (MriSenseNufft, AdjMriSenseNufft):
                    check_tables(
                        reference, sense_cls(smap, im_sz, **kb_kwargs).table
                    )
def test_nufft_3d_adjoint(params_3d, testing_tol, testing_dtype, device_list):
    """Dot-product adjointness test for 3D KbNufft / AdjKbNufft per device.

    Both operators use the same precomputed sparse interpolation matrices.
    On every device in ``device_list``, the test asserts that
    ``inner_product(y, A x, dim=2)`` matches
    ``inner_product(A^H y, x, dim=2)`` to within ``testing_tol``.
    """
    dtype = testing_dtype
    norm_tol = testing_tol

    im_size, numpoints = params_3d["im_size"], params_3d["numpoints"]
    x, y, ktraj = params_3d["x"], params_3d["y"], params_3d["ktraj"]

    for device in device_list:
        x, y, ktraj = (
            tensor.detach().to(dtype=dtype, device=device)
            for tensor in (x, y, ktraj)
        )

        forward_op = KbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )
        adjoint_op = AdjKbNufft(im_size=im_size, numpoints=numpoints).to(
            dtype=dtype, device=device
        )

        real_mat, imag_mat = precomp_sparse_mats(ktraj, forward_op)
        interp_mats = {
            "real_interp_mats": real_mat,
            "imag_interp_mats": imag_mat
        }

        forward_x = forward_op(x, ktraj, interp_mats)
        adjoint_y = adjoint_op(y, ktraj, interp_mats)

        lhs = inner_product(y, forward_x, dim=2)
        rhs = inner_product(adjoint_y, x, dim=2)

        assert torch.norm(lhs - rhs) < norm_tol
def test_2d_init_inputs():
    """Construct every NUFFT-family object with 2D scalar and tuple inputs.

    All object initializations contain assertions, so this test fails if
    any constructor rejects either the scalar form or the per-dimension
    tuple form of ``numpoints``, ``table_oversamp``, ``kbwidth`` and
    ``order``.
    """
    im_sz = (256, 256)
    grid_sz = (512, 512)
    n_shift = (128, 128)
    norm = "None"

    kb_param_sets = (
        # test 2d scalar inputs
        dict(numpoints=6, table_oversamp=2 ** 10, kbwidth=2.34, order=0),
        # test 2d tuple inputs
        dict(
            numpoints=(6, 6),
            table_oversamp=(2 ** 10, 2 ** 10),
            kbwidth=(2.34, 2.34),
            order=(0, 0),
        ),
    )

    for kb_params in kb_param_sets:
        smap = torch.randn(*((1,) + im_sz))

        interp_kwargs = dict(
            im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, **kb_params
        )
        nufft_kwargs = dict(interp_kwargs, norm=norm)

        # Interpolators take no norm argument.
        for interp_cls in (KbInterpForw, KbInterpBack):
            interp_cls(**interp_kwargs)

        for nufft_cls in (KbNufft, AdjKbNufft):
            nufft_cls(**nufft_kwargs)

        for sense_cls in (MriSenseNufft, AdjMriSenseNufft):
            sense_cls(smap=smap, **nufft_kwargs)
Example #16
0
    #     uv_wavelengths[:, 1] / (1.0 / (2.0 * grid.pixel_scales[0] * units.arcsec.to(units.rad))) * np.pi,
    #     uv_wavelengths[:, 0] / (1.0 / (2.0 * grid.pixel_scales[0] * units.arcsec.to(units.rad))) * np.pi
    # ]).T
    #
    # uv_wavelengths_temp = np.stack(
    #     (
    #         np.transpose(uv_wavelengths_temp[:, 0]).flatten(),
    #         np.transpose(uv_wavelengths_temp[:, 1]).flatten()
    #     ),
    #     axis=0
    # )
    #
    # ktraj = torch.tensor(uv_wavelengths_temp).to(dtype).unsqueeze(0)


    nufft_ob = KbNufft(im_size=im_size, grid_size=grid_size, norm='ortho').to(dtype)
    adjnufft_ob = AdjKbNufft(im_size=im_size, grid_size=grid_size, norm='ortho').to(dtype)

    # calculate k-space data
    #print(image.shape, ktraj.shape);exit()
    kdata = nufft_ob(_image, ktraj)
    #print(kdata[0, 0, 0, :])
    #exit()

    # real_visibilities__from__kbnufft_transformer = kdata[0, 0, 0, :]
    # real_visibilities__from__kbnufft_transformer *= shift.real
    # imag_visibilities__from__kbnufft_transformer = kdata[0, 0, 1, :]
    # imag_visibilities__from__kbnufft_transformer *= shift.imag

    # image_blurry = adjnufft_ob(kdata, ktraj)
    #
    def __init__(self, im_size, ktraj, dcomp, csmap, norm='ortho'):
        """Build the NUFFT operators and per-frame Toeplitz kernels.

        Args:
            im_size: the shape of the object to be imaged, ``(Nx, Ny, Nt)``.
            ktraj: the k-space trajectories (torch tensor).
            dcomp: the density compensation (torch tensor).
            csmap: the coil-sensitivity maps (torch tensor), or ``None`` to
                build single-coil operators.
            norm: normalization passed to every NUFFT constructor.

        N.B. the shapes for the tensors are the following:

            image   (1,1,2,Nx,Ny,Nt),
            ktraj   (1,2,Nrad,Nt),
            csmap   (1,Nc,2,Nx,Ny),
            dcomp   (1,1,1,Nrad,Nt),
            kdata   (1,Nc,2,Nrad,Nt),

        where
        - Nx,Ny,Nt = im_size
        - Nrad = 2*Ny*n_spokes
        - n_spokes = the number of spokes per 2D frame
        - Nc = the number of receiver coils
        """
        super(Dyn2DRadEncObj, self).__init__()
        dtype = torch.float

        # Unpacking enforces that im_size has exactly three entries; only
        # the number of frames Nt is used below.
        _, _, Nt = im_size
        self.im_size = im_size
        # Spoke length 2*Ny, consistent with Nrad = 2*Ny*n_spokes above.
        self.spokelength = im_size[1] * 2
        self.Nrad = ktraj.shape[2]
        self.ktraj_tensor = ktraj
        self.dcomp_tensor = dcomp

        # Square NUFFT grid with one spoke length along each axis.
        grid_size = (self.spokelength, self.spokelength)

        # Multi-coil (sensitivity-weighted) or single-coil NUFFTs.
        if csmap is not None:
            self.NUFFT = MriSenseNufft(smap=csmap,
                                       im_size=im_size[:2],
                                       grid_size=grid_size,
                                       norm=norm).to(dtype)
            self.AdjNUFFT = AdjMriSenseNufft(smap=csmap,
                                             im_size=im_size[:2],
                                             grid_size=grid_size,
                                             norm=norm).to(dtype)
            self.ToepNUFFT = ToepSenseNufft(csmap).to(dtype)
        else:
            self.NUFFT = KbNufft(im_size=im_size[:2],
                                 grid_size=grid_size,
                                 norm=norm).to(dtype)
            self.AdjNUFFT = AdjKbNufft(im_size=im_size[:2],
                                       grid_size=grid_size,
                                       norm=norm).to(dtype)
            self.ToepNUFFT = ToepNufft().to(dtype)

        # One Toeplitz kernel per frame (last axis of ktraj / dcomp).
        # for E^{\dagger} \circ E (density-compensation weighted)
        self.AdagA_toep_kernel_list = [
            calc_toep_kernel(self.AdjNUFFT,
                             ktraj[..., kt],
                             weights=dcomp[..., kt]) for kt in range(Nt)
        ]

        # for E^H \circ E (unweighted)
        self.AHA_toep_kernel_list = [
            calc_toep_kernel(self.AdjNUFFT, ktraj[..., kt]) for kt in range(Nt)
        ]
Example #18
0
        zbl * 1e-6, (prior_fwhm, prior_fwhm, 0, prior_fwhm, prior_fwhm))

    simim = prior.copy()

    # simim = eh.image.load_fits(gt_path)
    # simim = simim.regrid_image(fov, npix)
    simim.ra = obs.ra
    simim.dec = obs.dec
    simim.rf = obs.rf

    save_path = args.save_path
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # define the eht observation function
    nufft_ob = KbNufft(im_size=(npix, npix), numpoints=3)
    ktraj_vis, pulsefac_vis_torch, cphase_ind_list, cphase_sign_list, camp_ind_list = Obs_params_torch(
        obs, simim)
    eht_obs_torch = eht_observation_pytorch(npix, nufft_ob, ktraj_vis,
                                            pulsefac_vis_torch,
                                            cphase_ind_list, cphase_sign_list,
                                            camp_ind_list, device)

    if args.model_form == 'realnvp':
        n_flow = args.n_flow
        affine = True
        img_generator = realnvpfc_model.RealNVP(npix * npix,
                                                n_flow,
                                                affine=affine).to(device)
        # img_generator.load_state_dict(torch.load(save_path+'/generativemodel_'+args.model_form+'_res{}flow{}logdet{}_closure_fluxcentermemtsv'.format(npix, n_flow, args.logdet)))
    elif args.model_form == 'glow':