示例#1
0
    def _nufft(self, freq_domain_data, iflag=1, eps=1E-7):
        """
        Run a 3D type-1 NUFFT of the data at the rotated grid coordinates.

        :param freq_domain_data: array of frequency-domain samples; it is
            flattened in Fortran order before being handed to finufftpy
        :param iflag: forwarded to finufftpy.nufft3d1 (see finufftpy doc)
        :param eps: precision forwarded to finufftpy.nufft3d1
            (see finufftpy doc)
        :return: complex128 array reshaped (Fortran order) to self.im_shape
        :raises ImportError: if the name ``finufft`` (defined outside this
            method) is falsy
        """
        if not finufft:
            raise ImportError('finufftpy not available')

        # Only the first item returned by _rotate_coordinates is used here.
        rotated = self._rotate_coordinates()[0]
        coords_x, coords_y, coords_z = rotated[0], rotated[1], rotated[2]

        # Flat output buffer that finufftpy fills in place.
        result = np.zeros(len(coords_x), dtype=np.complex128, order='F')

        flat_data = freq_domain_data.flatten(order='F')
        flat_data = np.asfortranarray(flat_data)

        n_modes_x = self.im_shape[0]
        n_modes_y = self.im_shape[1]
        n_modes_z = self.im_shape[2]

        finufft_options = {
            'debug': 0,
            'spread_debug': 0,
            'spread_sort': 2,
            'fftw': 0,
            'modeord': 0,
            'chkbnds': 0,
            # upsampling at 1.25 saves time at low precisions
            'upsampfac': 1.25,
        }

        finufftpy.nufft3d1(
            coords_x, coords_y, coords_z,
            flat_data,
            iflag, eps,
            n_modes_x, n_modes_y, n_modes_z,
            result,
            **finufft_options)

        return result.reshape(self.im_shape, order='F')
示例#2
0
def anufft3(vol_f, fourier_pts, sz):
    """
    Call finufftpy.nufft3d1 on ``vol_f`` at the points ``fourier_pts``.

    The transform is invoked with isign=+1 and eps=1e-15, writing into a
    freshly allocated Fortran-ordered complex128 array of shape ``sz``.

    :param vol_f: array of values passed to finufftpy as the strengths;
        copied first if it is not C-contiguous
    :param fourier_pts: 2D array with three rows (x, y and z coordinates);
        copied first if it is not C-contiguous
    :param sz: sequence of three output sizes
    :return: a copy of the array filled by finufftpy.nufft3d1
    :raises ValueError: if ``sz`` does not have length 3 or ``fourier_pts``
        is not a 2D array with three rows
    """
    if len(sz) != 3:
        raise ValueError('sz must be 3')

    pts_shape = fourier_pts.shape
    if len(pts_shape) != 2 or pts_shape[0] != 3:
        raise ValueError('fourier_pts must be 2D with shape 3x_')

    # ndarray.copy() defaults to C order, so a copy is C-contiguous.
    points = fourier_pts if fourier_pts.flags.c_contiguous else fourier_pts.copy()
    strengths = vol_f if vol_f.flags.c_contiguous else vol_f.copy()

    n_x, n_y, n_z = sz
    out = np.empty(sz, dtype='complex128', order='F')
    finufftpy.nufft3d1(points[0], points[1], points[2], strengths,
                       1, 1e-15, n_x, n_y, n_z, out)
    return out.copy()
示例#3
0
def accuracy_speed_tests(num_nonuniform_points, num_uniform_points, eps):
    """
    Time each finufftpy transform on random data and spot-check its output.

    Covers types 1, 2 and 3 in 1-D, 2-D and 3-D plus the 2-D "many"
    variants.  For every transform, random points in [-pi, pi) and random
    complex data are drawn with np.random.rand, the finufftpy call is timed
    with time.time(), and the first ``num_samples`` outputs are paired with
    the same quantities evaluated by a direct exponential sum.  Each pair of
    arrays is handed to ``print_report(label, elapsed, direct, computed,
    npts)``.

    :param num_nonuniform_points: number of nonuniform points (converted with
        int()); also used as the number of type-3 target frequencies
    :param num_uniform_points: requested number of uniform modes; used as is
        in 1-D, and as ceil(sqrt(.)) / ceil(cbrt(.)) modes per dimension in
        2-D / 3-D
    :param eps: tolerance forwarded to every finufftpy call
    :return: None
    """

    def rand_angles(count):
        # uniform random draws on [-pi, pi)
        return np.random.rand(count)*2*math.pi-math.pi

    def rand_complex(*shape):
        # real part is drawn before the imaginary part
        return np.random.rand(*shape)+1j*np.random.rand(*shape)

    def mode_bounds(m):
        # (start, stop) of the mode indices for m modes
        return -np.floor(m/2), np.floor((m-1)/2+1)

    def mode_slice(m):
        return slice(*mode_bounds(m))

    def timed(transform, *args):
        # wall-clock seconds spent inside one finufftpy call
        start = time.time()
        transform(*args)
        return time.time()-start

    nj = int(num_nonuniform_points)
    nk = int(num_nonuniform_points)
    iflag = 1
    # number of outputs used for estimating accuracy; is small for speed
    num_samples = int(np.minimum(5, num_uniform_points*0.5+1))
    sample_ids = range(num_samples)

    print('Accuracy and speed tests for %d nonuniform points and eps=%g (error estimates use %d samples per run)' % (num_nonuniform_points,eps,num_samples))

    # buffers reused by every error estimate below
    direct = np.zeros(num_samples, dtype=np.complex128)
    computed = np.zeros(num_samples, dtype=np.complex128)

    # ---------------------------------------------------------------- 1-d
    ms = int(num_uniform_points)
    k = np.arange(*mode_bounds(ms))

    # 1d1
    xj = rand_angles(nj)
    cj = rand_complex(nj)
    fk = np.zeros([ms], dtype=np.complex128)
    elapsed = timed(finufftpy.nufft1d1, xj, cj, iflag, eps, ms, fk)
    for ii in sample_ids:
        direct[ii] = np.sum(cj * np.exp(1j*k[ii]*xj))
        computed[ii] = fk[ii]
    print_report('finufft1d1', elapsed, direct, computed, nj)

    # 1d2
    xj = rand_angles(nj)
    cj = np.zeros([nj], dtype=np.complex128)
    fk = rand_complex(ms)
    elapsed = timed(finufftpy.nufft1d2, xj, cj, iflag, eps, fk)
    for ii in sample_ids:
        direct[ii] = np.sum(fk * np.exp(1j*k*xj[ii]))
        computed[ii] = cj[ii]
    print_report('finufft1d2', elapsed, direct, computed, nj)

    # 1d3
    x = rand_angles(nj)
    c = rand_complex(nj)
    s = rand_angles(nk)
    f = np.zeros([nk], dtype=np.complex128)
    elapsed = timed(finufftpy.nufft1d3, x, c, iflag, eps, s, f)
    for ii in sample_ids:
        direct[ii] = np.sum(c * np.exp(1j*s[ii]*x))
        computed[ii] = f[ii]
    print_report('finufft1d3', elapsed, direct, computed, nj+nk)

    # ---------------------------------------------------------------- 2-d
    ms = int(np.ceil(np.sqrt(num_uniform_points)))
    mt = ms
    Ks, Kt = np.mgrid[mode_slice(ms), mode_slice(mt)]

    # 2d1
    xj = rand_angles(nj)
    yj = rand_angles(nj)
    cj = rand_complex(nj)
    fk = np.zeros([ms, mt], dtype=np.complex128, order='F')
    elapsed = timed(finufftpy.nufft2d1, xj, yj, cj, iflag, eps, ms, mt, fk)
    ks_flat, kt_flat, fk_flat = Ks.ravel(), Kt.ravel(), fk.ravel()
    for ii in sample_ids:
        direct[ii] = np.sum(cj * np.exp(1j*(ks_flat[ii]*xj+kt_flat[ii]*yj)))
        computed[ii] = fk_flat[ii]
    print_report('finufft2d1', elapsed, direct, computed, nj)

    # 2d1many (reuses xj, yj from 2d1)
    ndata = 5          # how many vectors to do
    dtest = ndata-1    # which of the ndata to test (in 0,..,ndata-1)
    cj = np.array(rand_complex(nj, ndata), order='F')
    fk = np.zeros([ms, mt, ndata], dtype=np.complex128, order='F')
    elapsed = timed(finufftpy.nufft2d1many, xj, yj, cj, iflag, eps, ms, mt, fk)
    # Fortran-order ravel is used for both the mode grids and the output
    ks_flat_f = Ks.ravel(order='F')
    kt_flat_f = Kt.ravel(order='F')
    fk_test_f = fk[:, :, dtest].ravel(order='F')
    for ii in sample_ids:
        direct[ii] = np.sum(cj[:,dtest] * np.exp(1j*(ks_flat_f[ii]*xj+kt_flat_f[ii]*yj)))
        computed[ii] = fk_test_f[ii]
    print_report('finufft2d1many', elapsed, direct, computed, ndata*nj)

    # 2d2
    xj = rand_angles(nj)
    yj = rand_angles(nj)
    cj = np.zeros([nj], dtype=np.complex128)
    fk = rand_complex(ms, mt)
    elapsed = timed(finufftpy.nufft2d2, xj, yj, cj, iflag, eps, fk)
    for ii in sample_ids:
        direct[ii] = np.sum(fk * np.exp(1j*(Ks*xj[ii]+Kt*yj[ii])))
        computed[ii] = cj[ii]
    print_report('finufft2d2', elapsed, direct, computed, nj)

    # 2d2many (same ndata and dtest as 2d1many; reuses xj, yj from 2d2)
    cj = np.zeros([nj, ndata], order='F', dtype=np.complex128)
    fk = np.array(rand_complex(ms, mt, ndata), order='F')
    elapsed = timed(finufftpy.nufft2d2many, xj, yj, cj, iflag, eps, fk)
    for ii in sample_ids:
        direct[ii] = np.sum(fk[:,:,dtest] * np.exp(1j*(Ks*xj[ii]+Kt*yj[ii])))
        computed[ii] = cj[ii, dtest]
    print_report('finufft2d2many', elapsed, direct, computed, ndata*nj)

    # 2d3
    x = rand_angles(nj)
    y = rand_angles(nj)
    c = rand_complex(nj)
    s = rand_angles(nk)
    t = rand_angles(nk)
    f = np.zeros([nk], dtype=np.complex128)
    elapsed = timed(finufftpy.nufft2d3, x, y, c, iflag, eps, s, t, f)
    for ii in sample_ids:
        direct[ii] = np.sum(c * np.exp(1j*(s[ii]*x+t[ii]*y)))
        computed[ii] = f[ii]
    print_report('finufft2d3', elapsed, direct, computed, nj+nk)

    # ---------------------------------------------------------------- 3-d
    ms = int(np.ceil(num_uniform_points**(1.0/3)))
    mt = ms
    mu = ms
    Ks, Kt, Ku = np.mgrid[mode_slice(ms), mode_slice(mt), mode_slice(mu)]

    # 3d1
    xj = rand_angles(nj)
    yj = rand_angles(nj)
    zj = rand_angles(nj)
    cj = rand_complex(nj)
    fk = np.zeros([ms, mt, mu], dtype=np.complex128, order='F')
    elapsed = timed(finufftpy.nufft3d1, xj, yj, zj, cj, iflag, eps, ms, mt, mu, fk)
    ks_flat, kt_flat, ku_flat = Ks.ravel(), Kt.ravel(), Ku.ravel()
    fk_flat = fk.ravel()
    for ii in sample_ids:
        direct[ii] = np.sum(cj * np.exp(1j*(ks_flat[ii]*xj+kt_flat[ii]*yj+ku_flat[ii]*zj)))
        computed[ii] = fk_flat[ii]
    print_report('finufft3d1', elapsed, direct, computed, nj)

    # 3d2
    xj = rand_angles(nj)
    yj = rand_angles(nj)
    zj = rand_angles(nj)
    cj = np.zeros([nj], dtype=np.complex128)
    fk = rand_complex(ms, mt, mu)
    elapsed = timed(finufftpy.nufft3d2, xj, yj, zj, cj, iflag, eps, fk)
    for ii in sample_ids:
        direct[ii] = np.sum(fk * np.exp(1j*(Ks*xj[ii]+Kt*yj[ii]+Ku*zj[ii])))
        computed[ii] = cj[ii]
    print_report('finufft3d2', elapsed, direct, computed, nj)

    # 3d3
    x = rand_angles(nj)
    y = rand_angles(nj)
    z = rand_angles(nj)
    c = rand_complex(nj)
    s = rand_angles(nk)
    t = rand_angles(nk)
    u = rand_angles(nk)
    f = np.zeros([nk], dtype=np.complex128)
    elapsed = timed(finufftpy.nufft3d3, x, y, z, c, iflag, eps, s, t, u, f)
    for ii in sample_ids:
        direct[ii] = np.sum(c * np.exp(1j*(s[ii]*x+t[ii]*y+u[ii]*z)))
        computed[ii] = f[ii]
    print_report('finufft3d3', elapsed, direct, computed, nj+nk)
示例#4
0
def accuracy_speed_tests(num_nonuniform_points, num_uniform_points, eps):
    """
    Time each finufftpy transform on random data and spot-check its output.

    Covers types 1, 2 and 3 in 1-D, 2-D and 3-D.  For every transform,
    random points in [-pi, pi) and random complex data are drawn with
    np.random.rand, the finufftpy call is timed with time.time(), and the
    first ``num_samples`` outputs are paired with the same quantities
    evaluated by a direct exponential sum.  Each pair of arrays is handed to
    ``print_report(label, elapsed, direct, computed, npts)``.

    :param num_nonuniform_points: number of nonuniform points (converted with
        int()); also used as the number of type-3 target frequencies
    :param num_uniform_points: requested number of uniform modes; used as is
        in 1-D, and as ceil(sqrt(.)) / ceil(cbrt(.)) modes per dimension in
        2-D / 3-D
    :param eps: tolerance forwarded to every finufftpy call
    :return: None
    """

    def draw_points(count):
        # uniform random draws on [-pi, pi)
        return np.random.rand(count) * 2 * math.pi - math.pi

    def draw_complex(*shape):
        # real part is drawn before the imaginary part
        return np.random.rand(*shape) + 1j * np.random.rand(*shape)

    def mode_limits(m):
        # (start, stop) of the mode indices for m modes
        return -np.floor(m / 2), np.floor((m - 1) / 2 + 1)

    def run_timed(transform, *args):
        # wall-clock seconds spent inside one finufftpy call
        began = time.time()
        transform(*args)
        return time.time() - began

    nj = int(num_nonuniform_points)
    nk = int(num_nonuniform_points)
    iflag = 1
    # number of outputs used for estimating accuracy
    num_samples = int(np.minimum(20, num_uniform_points * 0.5 + 1))
    picks = range(num_samples)

    print(
        'Accuracy and speed tests for %d nonuniform points and eps=%g (error estimates use %d samples per run)'
        % (num_nonuniform_points, eps, num_samples))

    # buffers reused by every error estimate below
    direct = np.zeros(num_samples, dtype=np.complex128)
    computed = np.zeros(num_samples, dtype=np.complex128)

    # ------------------------------------------------------------ 1-d
    ms = int(num_uniform_points)
    k = np.arange(*mode_limits(ms))

    # type 1
    xj = draw_points(nj)
    cj = draw_complex(nj)
    fk = np.zeros([ms], dtype=np.complex128)
    elapsed = run_timed(finufftpy.nufft1d1, xj, cj, iflag, eps, ms, fk)
    for ii in picks:
        direct[ii] = np.sum(cj * np.exp(1j * k[ii] * xj))
        computed[ii] = fk[ii]
    print_report('finufft1d1', elapsed, direct, computed, nj)

    # type 2
    xj = draw_points(nj)
    cj = np.zeros([nj], dtype=np.complex128)
    fk = draw_complex(ms)
    elapsed = run_timed(finufftpy.nufft1d2, xj, cj, iflag, eps, fk)
    for ii in picks:
        direct[ii] = np.sum(fk * np.exp(1j * k * xj[ii]))
        computed[ii] = cj[ii]
    print_report('finufft1d2', elapsed, direct, computed, nj)

    # type 3
    x = draw_points(nj)
    c = draw_complex(nj)
    s = draw_points(nk)
    f = np.zeros([nk], dtype=np.complex128)
    elapsed = run_timed(finufftpy.nufft1d3, x, c, iflag, eps, s, f)
    for ii in picks:
        direct[ii] = np.sum(c * np.exp(1j * s[ii] * x))
        computed[ii] = f[ii]
    print_report('finufft1d3', elapsed, direct, computed, nj + nk)

    # ------------------------------------------------------------ 2-d
    ms = int(np.ceil(np.sqrt(num_uniform_points)))
    mt = ms
    Ks, Kt = np.mgrid[slice(*mode_limits(ms)), slice(*mode_limits(mt))]

    # type 1
    xj = draw_points(nj)
    yj = draw_points(nj)
    cj = draw_complex(nj)
    fk = np.zeros([ms, mt], dtype=np.complex128, order='F')
    elapsed = run_timed(finufftpy.nufft2d1, xj, yj, cj, iflag, eps, ms, mt, fk)
    ks_flat, kt_flat, fk_flat = Ks.ravel(), Kt.ravel(), fk.ravel()
    for ii in picks:
        phase = ks_flat[ii] * xj + kt_flat[ii] * yj
        direct[ii] = np.sum(cj * np.exp(1j * phase))
        computed[ii] = fk_flat[ii]
    print_report('finufft2d1', elapsed, direct, computed, nj)

    # type 2
    xj = draw_points(nj)
    yj = draw_points(nj)
    cj = np.zeros([nj], dtype=np.complex128)
    fk = draw_complex(ms, mt)
    elapsed = run_timed(finufftpy.nufft2d2, xj, yj, cj, iflag, eps, fk)
    for ii in picks:
        phase = Ks * xj[ii] + Kt * yj[ii]
        direct[ii] = np.sum(fk * np.exp(1j * phase))
        computed[ii] = cj[ii]
    print_report('finufft2d2', elapsed, direct, computed, nj)

    # type 3
    x = draw_points(nj)
    y = draw_points(nj)
    c = draw_complex(nj)
    s = draw_points(nk)
    t = draw_points(nk)
    f = np.zeros([nk], dtype=np.complex128)
    elapsed = run_timed(finufftpy.nufft2d3, x, y, c, iflag, eps, s, t, f)
    for ii in picks:
        phase = s[ii] * x + t[ii] * y
        direct[ii] = np.sum(c * np.exp(1j * phase))
        computed[ii] = f[ii]
    print_report('finufft2d3', elapsed, direct, computed, nj + nk)

    # ------------------------------------------------------------ 3-d
    ms = int(np.ceil(num_uniform_points**(1.0 / 3)))
    mt = ms
    mu = ms
    Ks, Kt, Ku = np.mgrid[slice(*mode_limits(ms)),
                          slice(*mode_limits(mt)),
                          slice(*mode_limits(mu))]

    # type 1
    xj = draw_points(nj)
    yj = draw_points(nj)
    zj = draw_points(nj)
    cj = draw_complex(nj)
    fk = np.zeros([ms, mt, mu], dtype=np.complex128, order='F')
    elapsed = run_timed(finufftpy.nufft3d1, xj, yj, zj, cj, iflag, eps,
                        ms, mt, mu, fk)
    ks_flat, kt_flat, ku_flat = Ks.ravel(), Kt.ravel(), Ku.ravel()
    fk_flat = fk.ravel()
    for ii in picks:
        phase = ks_flat[ii] * xj + kt_flat[ii] * yj + ku_flat[ii] * zj
        direct[ii] = np.sum(cj * np.exp(1j * phase))
        computed[ii] = fk_flat[ii]
    print_report('finufft3d1', elapsed, direct, computed, nj)

    # type 2
    xj = draw_points(nj)
    yj = draw_points(nj)
    zj = draw_points(nj)
    cj = np.zeros([nj], dtype=np.complex128)
    fk = draw_complex(ms, mt, mu)
    elapsed = run_timed(finufftpy.nufft3d2, xj, yj, zj, cj, iflag, eps, fk)
    for ii in picks:
        phase = Ks * xj[ii] + Kt * yj[ii] + Ku * zj[ii]
        direct[ii] = np.sum(fk * np.exp(1j * phase))
        computed[ii] = cj[ii]
    print_report('finufft3d2', elapsed, direct, computed, nj)

    # type 3
    x = draw_points(nj)
    y = draw_points(nj)
    z = draw_points(nj)
    c = draw_complex(nj)
    s = draw_points(nk)
    t = draw_points(nk)
    u = draw_points(nk)
    f = np.zeros([nk], dtype=np.complex128)
    elapsed = run_timed(finufftpy.nufft3d3, x, y, z, c, iflag, eps, s, t, u, f)
    for ii in picks:
        phase = s[ii] * x + t[ii] * y + u[ii] * z
        direct[ii] = np.sum(c * np.exp(1j * phase))
        computed[ii] = f[ii]
    print_report('finufft3d3', elapsed, direct, computed, nj + nk)
示例#5
0
 def EwaldFarVel(self, ptsxyz, forces, nThr):
     """
     Compute the far field Ewald velocity with FINUFFT.
     Inputs: ptsxyz = the list of points in undeformed, Cartesian
     coordinates, forces = forces at those points (one column per
     component), nThr = thread count forwarded to FINUFFT as nThreads.
     The forces are transformed to the grid (type 1), multiplied by a
     factor in Fourier space and projected, then evaluated back at the
     points (type 2). The outputs of the FINUFFT calls are written into
     self._fxhat/_fyhat/_fzhat and self._ufarx/_ufary/_ufarz.
     Returns the real parts of the three velocity components divided by
     the domain volume, one column per component.
     See the FINUFFT documentation for more information on the calls.
     """
     domain = self._currentDomain
     # Coordinates in the transformed basis, wrapped into the periodic
     # cell and rescaled to [-pi,pi] (for FINUFFT)
     Lens = domain.getPeriodicLens()
     pts = domain.primecoords(ptsxyz)
     pts = 2 * pi * np.mod(pts, Lens) / Lens - pi

     # Forcing on the grid (FINUFFT type 1), one component at a time
     force_hats = (self._fxhat, self._fyhat, self._fzhat)
     for comp, fhat in enumerate(force_hats):
         fi.nufft3d1(pts[:, 0], pts[:, 1], pts[:, 2], forces[:, comp],
                     -1, fartol, self._nx, self._ny, self._nz, fhat,
                     modeord=1, nThreads=nThr)

     # Manipulation in Fourier space
     kxP, kyP, kzP = domain.primeWaveNumbersFromUnprimed(
         self._kx, self._ky, self._kz)
     k = np.sqrt(kxP * kxP + kyP * kyP + kzP * kzP)
     ksq = k * k
     xi_ratio = ksq / (4 * self._xi * self._xi)
     # Multiplication factor for the RPY tensor
     factor = 1.0 / (self._mu * k * k) * np.sinc(k * self._a / pi)**2
     # splitting function
     factor *= (1 + xi_ratio) * np.exp(-xi_ratio)
     # zero out 0 mode (k = 0 there, so the division above is not finite)
     factor[0, 0, 0] = 0
     uxhat = factor * self._fxhat
     uyhat = factor * self._fyhat
     uzhat = factor * self._fzhat

     # Project off so we get divergence free
     k_dot_u = kxP * uxhat + kyP * uyhat + kzP * uzhat
     projected = []
     for kcomp, ucomp in ((kxP, uxhat), (kyP, uyhat), (kzP, uzhat)):
         uproj = ucomp - k_dot_u * kcomp / ksq
         uproj[0, 0, 0] = 0
         projected.append(uproj)

     # Velocities at the points (FINUFFT type 2)
     far_vels = (self._ufarx, self._ufary, self._ufarz)
     for uproj, ufar in zip(projected, far_vels):
         fi.nufft3d2(pts[:, 0], pts[:, 1], pts[:, 2], ufar, 1, fartol,
                     uproj, modeord=1, nThreads=nThr)

     vol = domain.getVol()
     rows = [[np.real(ufar) / vol]
             for ufar in (self._ufarx, self._ufary, self._ufarz)]
     return np.concatenate(rows).T
    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            scaling_params: List[float],
            rotation_params: List[float],
            padding_values: List[float]
            ) -> torch.Tensor:
        """
        Rotate and scale a single-channel volume through a 3D type-1 NUFFT.

        The image is optionally oversampled, transformed with
        ``self._fft_im``, and its frequency samples are then evaluated by
        ``finufftpy.nufft3d1`` at grid coordinates that have been multiplied
        by the combined rotation/scaling matrix.  The magnitude of the
        result, divided by the number of voxels, replaces ``tensor[0]``.

        :param tensor: 4D tensor whose first dimension has length 1; it is
            modified in place
        :param scaling_params: three scaling factors; the grid is multiplied
            by their reciprocals
        :param rotation_params: three rotation angles in degrees
        :param padding_values: forwarded to ``self._oversample`` as
            ``padding_normal``; two values select the ``'random.normal'``
            padding mode, anything else ``'constant'``
        :return: the same ``tensor`` object, with ``tensor[0]`` overwritten
        :raises AssertionError: if ``tensor`` is not 4D with a leading
            dimension of 1
        """
        assert tensor.ndim == 4
        assert len(tensor) == 1

        from torchio.transforms.augmentation.intensity.random_motion_from_time_course import create_rotation_matrix_3d
        import math
        import finufftpy

        image = tensor[0]
        original_image_shape = image.shape
        if self.oversampling_pct > 0.0:
            # two padding values are taken as (mean, std)
            padd_mode = 'random.normal' if len(padding_values) == 2 else 'constant'
            image = self._oversample(image, self.oversampling_pct, padding_mode=padd_mode,
                                     padding_normal=padding_values)

        im_freq_domain = self._fft_im(image)
        im_shape = im_freq_domain.shape

        # Negated angles, with the second one flipped back, to get the same
        # result as sitk (per the original author's note).
        rrrot = -np.radians(rotation_params)
        rrrot[1] = -rrrot[1]
        rotation_matrix = create_rotation_matrix_3d(rrrot)
        # diag(1 / scaling_params): the grid is divided by the scaling factors
        scaling_matrix = np.eye(3) / np.array(scaling_params)
        affine_matrix = np.matmul(rotation_matrix, scaling_matrix)

        # Centered grid; each unrotated axis spans [-1, 1).
        center = [math.ceil((x - 1) / 2) for x in im_shape]
        [i1, i2, i3] = np.meshgrid(2*(np.arange(im_shape[0]) - center[0])/im_shape[0],
                                   2*(np.arange(im_shape[1]) - center[1])/im_shape[1],
                                   2*(np.arange(im_shape[2]) - center[2])/im_shape[2], indexing='ij')
        grid_coordinates = np.array([i1.flatten('F'), i2.flatten('F'), i3.flatten('F')])

        new_grid_coords = np.matmul(affine_matrix, grid_coordinates)

        # Multiply by pi so the unrotated grid spans [-pi, pi); one separate
        # array per axis is passed to finufftpy.
        coords_x, coords_y, coords_z = (new_grid_coords[i, :] * math.pi for i in (0, 1, 2))

        # initialize array for nufft output
        f = np.zeros([len(coords_x)], dtype=np.complex128, order='F')

        freq_domain_data_flat = np.asfortranarray(im_freq_domain.flatten(order='F'))
        iflag, eps = 1, 1E-7
        finufftpy.nufft3d1(coords_x, coords_y, coords_z, freq_domain_data_flat,
                           iflag, eps, im_shape[0], im_shape[1], im_shape[2],
                           f, debug=0, spread_debug=0, spread_sort=2, fftw=0, modeord=0,
                           chkbnds=0, upsampfac=1.25)  # upsampling at 1.25 saves time at low precisions
        im_out = f.reshape(im_shape, order='F')
        im_out = abs(im_out / im_out.size)

        # Oversampling enlarged the volume: crop back to the input shape.
        if im_shape[0] != original_image_shape[0]:
            im_out = self.crop_volume(im_out, original_image_shape)

        tensor[0] = torch.from_numpy(im_out)

        return tensor