예제 #1
0
def nufft_adjoint(input, coord, oshape, oversamp=1.25, width=4.0, n=128, device='cuda'):
    """Adjoint NUFFT: grid samples, inverse-FFT, crop, rescale and apodize.

    Args:
        input: Non-uniform samples handed to ``interp.gridding``.
        coord: Sample coordinates; ``coord.shape[-1]`` is taken as the number
            of transformed dimensions.
        oshape: Target (cropped) shape, before the final permute.
        oversamp: Grid oversampling factor.
        width: Kaiser-Bessel kernel width.
        n: Forwarded to ``_get_kaiser_bessel_kernel`` (presumably the size of
            the tabulated kernel -- TODO confirm).
        device: Torch device used for intermediate tensors.

    Returns:
        The apodized image permuted with ``(0, 2, 3, 1)``, i.e. dimension 1
        moved last. The permutes require a 4-D tensor -- presumably
        (batch, real/imag, H, W); TODO confirm against callers.
    """
    spatial_dims = coord.shape[-1]
    # Kaiser-Bessel shape parameter derived from the width and oversampling.
    kb_arg = (width / oversamp) * (oversamp - 0.5)
    beta = numpy.pi * (kb_arg ** 2 - 0.8) ** 0.5

    oshape = list(oshape)
    grid_shape = _get_oversamp_shape(oshape, spatial_dims, oversamp)

    # Spread the samples onto the oversampled Cartesian grid.
    scaled_coord = _scale_coord(coord, oshape, oversamp, device)
    kb_table = _get_kaiser_bessel_kernel(n, width, beta, scaled_coord.dtype, device)
    grid = interp.gridding(input, grid_shape, width, kb_table, scaled_coord, device)

    # transforms.ifft2 is fed the grid with dimension 1 moved to the end;
    # the second permute restores the original layout.
    image = transforms.ifft2(grid.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    # Crop to the requested shape, then rescale: N_oversampled / sqrt(N_target).
    image = util.resize(image, oshape, device=device)
    image *= util.prod(grid_shape[-spatial_dims:]) / (util.prod(oshape[-spatial_dims:]) ** 0.5)

    # Apodize (presumably divides out the gridding kernel's roll-off).
    image = _apodize(image, spatial_dims, oversamp, width, beta, device)
    return image.permute(0, 2, 3, 1)
예제 #2
0
def nufft_adjoint(input, coord, out_shape, oversamp=1.25, width=4.0, n=128, device='cuda'):
    """Adjoint NUFFT (5-D layout): grid samples, inverse-FFT, crop, apodize.

    Args:
        input: Non-uniform samples handed to ``interp.gridding``.
        coord: Sample coordinates; ``coord.shape[-1]`` is taken as the number
            of transformed dimensions.
        out_shape: Target (cropped) output shape.
        oversamp: Grid oversampling factor.
        width: Kaiser-Bessel kernel width.
        n: Forwarded to ``_get_kaiser_bessel_kernel`` (presumably the size of
            the tabulated kernel -- TODO confirm).
        device: Torch device used for intermediate tensors.

    Returns:
        The apodized, cropped image. The permutes require a 5-D tensor;
        dimension 2 is presumably the real/imag axis -- TODO confirm.
    """
    dims = coord.shape[-1]
    # Kaiser-Bessel shape parameter derived from the width and oversampling.
    kb_arg = (width / oversamp) * (oversamp - 0.5)
    beta = numpy.pi * (kb_arg ** 2 - 0.8) ** 0.5

    out_shape = list(out_shape)
    os_shape = _get_oversamp_shape(out_shape, dims, oversamp)

    # Private copies of both shapes are what the helpers below receive.
    # NOTE(review): kept defensively -- whether the helpers mutate their
    # shape argument is not visible from here.
    target = out_shape.copy()
    oversampled = os_shape.copy()

    # Spread the samples onto the oversampled Cartesian grid.
    scaled_coord = _scale_coord(coord, target, oversamp, device)
    kb_table = _get_kaiser_bessel_kernel(n, width, beta, scaled_coord.dtype, device)
    grid = interp.gridding(input, oversampled, width, kb_table, scaled_coord, device)

    # ifft2_regular is fed dimension 2 moved last; the inverse permute
    # restores the original layout.
    grid = transforms.ifft2_regular(grid.permute(0, 1, 3, 4, 2)).permute(0, 1, 4, 2, 3)

    # Crop, then rescale by N_oversampled / sqrt(N_target).
    image = util.resize(grid, target, device=device)
    gain = util.prod(oversampled[-dims:]) / (util.prod(target[-dims:]) ** 0.5)
    image = image * gain

    # Apodize (presumably divides out the gridding kernel's roll-off).
    return _apodize(image, dims, oversamp, width, beta, device)
예제 #3
0
def gridding(input, shape, width, kernel, coord, device):
    """Spread non-uniform samples onto a Cartesian grid of the given shape.

    Everything in ``shape`` before its last ``coord.shape[-1]`` entries is
    flattened into one batch axis, as are all leading axes of ``coord``. A
    zero grid is allocated and filled by ``_gridding2``.

    Args:
        input: Sample values; must be reshapeable to ``(batch, npts)``.
        shape: Output shape; the trailing ``coord.shape[-1]`` entries are the
            grid dimensions, the rest are batch dimensions.
        width: Interpolation kernel width, forwarded to ``_gridding2``.
        kernel: Kernel table, forwarded to ``_gridding2``.
        coord: Sample coordinates with the spatial dimension last.
        device: Torch device on which the grid is allocated.

    Returns:
        The gridded tensor reshaped to ``shape``.
    """
    n_dims = coord.shape[-1]
    n_batch = util.prod(shape[:-n_dims])
    n_points = util.prod(coord.shape[:-1])

    flat_values = input.reshape([n_batch, n_points])
    flat_coord = coord.reshape([n_points, n_dims])
    grid_dims = [n_batch] + list(shape[-n_dims:])
    grid = torch.zeros(grid_dims, dtype=flat_values.dtype, device=device)

    grid = _gridding2(grid, flat_values, width, kernel, flat_coord)
    return grid.reshape(shape)
예제 #4
0
def interpolate(input, width, kernel, coord, device):
    """Sample a Cartesian grid at non-uniform coordinates.

    The leading (batch) axes of ``input`` and of ``coord`` are each flattened.
    A zero output of shape ``(batch, npts)`` is allocated and filled by
    ``_interpolate2``.

    Args:
        input: Grid tensor whose last ``coord.shape[-1]`` axes are spatial.
        width: Interpolation kernel width, forwarded to ``_interpolate2``.
        kernel: Kernel table, forwarded to ``_interpolate2``.
        coord: Sample coordinates with the spatial dimension last.
        device: Torch device on which the output is allocated.

    Returns:
        Samples shaped ``input.shape[:-ndim] + coord.shape[:-1]``.
    """
    n_dims = coord.shape[-1]
    lead_shape = input.shape[:-n_dims]
    sample_shape = coord.shape[:-1]
    n_batch = util.prod(lead_shape)
    n_points = util.prod(sample_shape)

    grid = input.reshape([n_batch] + list(input.shape[-n_dims:]))
    flat_coord = coord.reshape([n_points, n_dims])
    samples = torch.zeros([n_batch, n_points], dtype=grid.dtype, device=device)

    samples = _interpolate2(samples, grid, width, kernel, flat_coord)
    return samples.reshape(lead_shape + sample_shape)
예제 #5
0
def nufft(input,
          coord,
          ndim=2,
          oversamp=1.25,
          width=4.0,
          n=128,
          device='cuda'):
    """Forward NUFFT: apodize, zero-pad, FFT, then interpolate at ``coord``.

    Args:
        input: Image tensor; its last ``ndim`` dimensions are transformed.
        coord: Sample coordinates, scaled by ``_scale_coord`` against
            ``input.shape``. ``ndim`` is passed explicitly rather than read
            from ``coord.shape[-1]``. NOTE(review): presumably the two must
            agree -- confirm.
        ndim: Number of transformed dimensions; only 2 and 3 are supported.
        oversamp: Grid oversampling factor.
        width: Kaiser-Bessel kernel width.
        n: Forwarded to ``_get_kaiser_bessel_kernel`` (presumably the size of
            the tabulated kernel -- TODO confirm).
        device: Torch device used for intermediate tensors.

    Returns:
        The result of ``interp.interpolate`` on the oversampled spectrum.

    Raises:
        ValueError: If ``ndim`` is not 2 or 3. Previously such values silently
            skipped the FFT and interpolated the un-transformed image.
    """
    # Validate before doing any work. Only rfft2/rfft3 exist below, and
    # skipping the FFT would yield silently wrong samples.
    if ndim not in (2, 3):
        raise ValueError(f"nufft supports ndim of 2 or 3, got {ndim!r}")

    # Kaiser-Bessel shape parameter derived from the width and oversampling.
    beta = numpy.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5
    os_shape = _get_oversamp_shape(input.shape, ndim, oversamp)

    # Work on a copy so the caller's tensor is left untouched, in case
    # _apodize operates in place.
    output = input.clone()

    # Apodize.
    output = _apodize(output, ndim, oversamp, width, beta, device)

    # Normalize by sqrt of the un-padded size, then zero-pad to the
    # oversampled grid.
    output = output / util.prod(input.shape[-ndim:])**0.5
    output = util.resize(output, os_shape, device=device)

    # FFT (ndim is guaranteed to be 2 or 3 here).
    if ndim == 2:
        output = transforms.rfft2(output)
    else:
        output = transforms.rfft3(output)

    # Interpolate the spectrum at the scaled sample coordinates.
    coord = _scale_coord(coord, input.shape, oversamp, device)
    kernel = _get_kaiser_bessel_kernel(n, width, beta, coord.dtype, device)
    output = interp.interpolate(output, width, kernel, coord, ndim, device)

    return output
예제 #6
0
파일: nufft.py 프로젝트: 3d-flat/3dflat
def nufft_adjoint(input,
                  coord,
                  oshape,
                  ndim=2,
                  oversamp=1.25,
                  width=4.0,
                  n=128,
                  device='cuda'):
    """Adjoint NUFFT: grid samples, inverse-FFT, crop, apodize, select channel.

    Args:
        input: Non-uniform samples handed to ``interp.gridding``.
        coord: Sample coordinates. ``ndim`` is passed explicitly rather than
            read from ``coord.shape[-1]``. NOTE(review): presumably the two
            must agree -- confirm.
        oshape: Target output shape. Its dimension 1 is overridden to 2 for
            the intermediate computation, so it needs at least two entries.
        ndim: Number of transformed dimensions; only 2 and 3 are supported.
        oversamp: Grid oversampling factor.
        width: Kaiser-Bessel kernel width.
        n: Forwarded to ``_get_kaiser_bessel_kernel`` (presumably the size of
            the tabulated kernel -- TODO confirm).
        device: Torch device used for intermediate tensors.

    Returns:
        The apodized image with only index 0 of dimension 1 kept. That is the
        dimension forced to size 2; presumably index 0 is the real part --
        TODO confirm.

    Raises:
        ValueError: If ``ndim`` is not 2 or 3. Previously such values silently
            skipped the IFFT and the final channel selection.
    """
    # Validate before doing any work. Only ifft2/ifft3 exist below, and
    # skipping the IFFT would yield silently wrong output.
    if ndim not in (2, 3):
        raise ValueError(
            f"nufft_adjoint supports ndim of 2 or 3, got {ndim!r}")

    # Kaiser-Bessel shape parameter derived from the width and oversampling.
    beta = numpy.pi * (((width / oversamp) * (oversamp - 0.5))**2 - 0.8)**0.5
    oshape = list(oshape)

    os_shape = _get_oversamp_shape(oshape, ndim, oversamp)

    # Gridding. The working shapes force dimension 1 to size 2; the
    # caller's oshape is left unchanged.
    oshape2 = oshape.copy()
    oshape2[1] = 2
    os_shape2 = os_shape.copy()
    os_shape2[1] = 2
    coord = _scale_coord(coord, oshape2, oversamp, device)
    kernel = _get_kaiser_bessel_kernel(n, width, beta, coord.dtype, device)
    output = interp.gridding(input, os_shape2, width, kernel, coord, ndim,
                             device)

    # IFFT (ndim is guaranteed to be 2 or 3 here).
    if ndim == 2:
        output = transforms.ifft2(output)
    else:
        output = transforms.ifft3(output)

    # Crop, then rescale by N_oversampled / sqrt(N_target). Only the trailing
    # ndim entries are used, so the dimension-1 override does not matter here.
    output = util.resize(output, oshape2, device=device)
    a = util.prod(os_shape[-ndim:]) / util.prod(oshape[-ndim:])**0.5
    output = output * a

    # Apodize (presumably divides out the gridding kernel's roll-off).
    output = _apodize(output, ndim, oversamp, width, beta, device)

    # Keep index 0 of dimension 1. For either supported ndim this matches
    # the former output[:, 0, :, :] / output[:, 0, :, :, :] slices.
    return output[:, 0]