class KECLEFT:
    '''
    Class to calculate power spectra up to one loop in "expanded LPT,"
    i.e. wherein the long displacement A_{ij} are expanded but the bias basis is Lagrangian.

    This is like CLEFT, but with exponent "E"xpanded and separated into powers of k.

    All bias tables are formatted as 1, b1, b1^2, b2, b1b2, b2^2, bs, b1bs, b2bs, bs^2, b3, b1b3 (ncol = 12).
    With third_order=False the table stops at bs^2 (ncol = 10); with shear=False
    as well it stops at b2^2 (ncol = 6).

    Typical use: construct with a linear power spectrum (k, p) and call
    make_ptable(). make_ptable recomputes every contribution from the current
    power spectrum, so update_power_spectrum(k, p) followed by make_ptable()
    tabulates the new spectrum.
    '''
    def __init__(self,
                 k,
                 p,
                 one_loop=True,
                 shear=True,
                 third_order=True,
                 cutoff=20,
                 jn=5,
                 N=4000,
                 threads=1,
                 extrap_min=-4,
                 extrap_max=3,
                 import_wisdom=False,
                 wisdom_file='./zelda_wisdom.npy'):
        '''
        Parameters
        ----------
        k, p : input linear power spectrum; interpolated onto self.kint with loginterp.
        one_loop, shear, third_order : which groups of terms to include.
        cutoff : the interpolated spectrum is multiplied by exp(-(k/cutoff)^2).
        jn : number of spherical-Bessel orders (L=jn) for SphericalBesselTransform;
            compute_p_k3 and compute_p_k4 loop over l < jn.
        N : number of points in the log-spaced k and q grids.
        threads, import_wisdom, wisdom_file : forwarded to SphericalBesselTransform.
        extrap_min, extrap_max : log10 range of the k grid (the q grid spans the
            reciprocal range).
        '''

        self.N = N
        self.extrap_max = extrap_max
        self.extrap_min = extrap_min

        self.cutoff = cutoff
        self.kint = np.logspace(extrap_min, extrap_max, self.N)
        self.qint = np.logspace(-extrap_max, -extrap_min, self.N)

        self.one_loop = one_loop
        self.shear = shear
        self.third_order = third_order

        self.update_power_spectrum(k, p)

        self.pktable = None
        if self.third_order:
            self.num_power_components = 12
        elif self.shear:
            self.num_power_components = 10
        else:
            self.num_power_components = 6

        self.jn = jn
        self.threads = threads
        self.import_wisdom = import_wisdom
        self.wisdom_file = wisdom_file
        self.sph = SphericalBesselTransform(self.qint,
                                            L=self.jn,
                                            ncol=self.num_power_components,
                                            threads=self.threads,
                                            import_wisdom=self.import_wisdom,
                                            wisdom_file=self.wisdom_file)

    def update_power_spectrum(self, k, p):
        # Updates the power spectrum and various q functions. Can continually compute for new cosmologies without reloading FFTW
        self.k = k
        self.p = p
        self.pint = loginterp(k, p)(
            self.kint) * np.exp(-(self.kint / self.cutoff)**2)
        self.setup_powerspectrum()

    def setup_powerspectrum(self):
        '''
        Build the q-space functions for the current self.pint with QFuncFFT and
        cache them with the symmetry factors in which they enter P(k).
        '''

        self.qf = QFuncFFT(self.kint,
                           self.pint,
                           qv=self.qint,
                           oneloop=self.one_loop,
                           shear=self.shear,
                           third_order=self.third_order)

        # linear terms
        self.Xlin = self.qf.Xlin
        self.Ylin = self.qf.Ylin

        self.Ulin = self.qf.Ulin
        self.corlin = self.qf.corlin

        if self.one_loop:
            # one loop terms: here we add in all the symmetry factors
            self.Xloop = 2 * self.qf.Xloop13 + self.qf.Xloop22
            self.sigmaloop = self.Xloop[-1]
            self.Yloop = 2 * self.qf.Yloop13 + self.qf.Yloop22

            self.Vloop = 3 * (2 * self.qf.V1loop112 + self.qf.V3loop112
                              )  # this multiplies mu in the pk integral
            self.Tloop = 3 * self.qf.Tloop112  # and this multiplies mu^3

            self.X10 = 2 * self.qf.X10loop12
            self.Y10 = 2 * self.qf.Y10loop12
            self.sigma10 = (self.X10 + self.Y10)[-1]

            self.U3 = self.qf.U3
            self.U11 = self.qf.U11
            self.U20 = self.qf.U20
            self.Us2 = self.qf.Us2

        else:
            self.Xloop, self.Yloop, self.sigmaloop, self.Vloop, self.Tloop, self.X10, self.Y10, self.sigma10, self.U3, self.U11, self.U20, self.Us2 = (
                0, ) * 12

        # load shear functions
        if self.shear or self.third_order:
            self.Xs2 = self.qf.Xs2
            self.Ys2 = self.qf.Ys2
            self.sigmas2 = (self.Xs2 + self.Ys2)[-1]
            self.V = self.qf.V
            self.zeta = self.qf.zeta
            self.chi = self.qf.chi

        if self.third_order:
            self.Ub3 = self.qf.Ub3
            self.theta = self.qf.theta

    # The various contributions to P(k) are organized into
    # (1) Linear Theory
    # (2) Connected: this comes from terms that come from connected LPT cumulants that can be Fourier-transformed directly
    # (3) p_k#: this comes from disconnected contributions proportional to k^#
    # Once separated these Hankel transform individually at once for all k.
    #
    # In the p_k# integrands, muNfac is the coefficient of the order-l Legendre
    # term in mu^N (e.g. mu^2 = P_0/3 + 2 P_2/3); the alternating signs
    # presumably absorb the i^l of the plane-wave expansion.

    def _add_hankel_transform(self, target, ell, bias_integrands):
        '''
        Hankel-transform one multipole of the integrands and add it to `target`.

        Each row first has its value at the largest q subtracted (presumably to
        remove the constant large-q limit before transforming). The transform
        from self.sph.sph is interpolated onto self.kint (NaN outside its k
        range) and added in place with a factor 4*pi.
        '''
        bias_integrands = bias_integrands - bias_integrands[:, -1][:, None]
        ktemps, bias_ffts = self.sph.sph(ell, bias_integrands)
        target += 4 * np.pi * interp1d(
            ktemps, bias_ffts, bounds_error=False)(self.kint)

    def compute_p_linear(self):
        '''Linear theory: P_lin in the 1, b1, b1^2 columns with weights 1, 2, 1.'''
        self.p_linear = np.zeros((self.num_power_components, self.N))
        self.p_linear[0, :] = self.pint
        self.p_linear[1, :] = 2 * self.pint
        self.p_linear[2, :] = self.pint

    def compute_p_connected(self):
        '''One-loop connected terms, built directly from the QFuncFFT Q/R functions.'''

        self.p_connected = np.zeros((self.num_power_components, self.N))
        self.p_connected[
            0, :] = 9. / 98 * self.qf.Q1 + 10. / 21 * self.qf.R1 + 3. / 7 * (
                2 * self.qf.R2 + self.qf.Q2)
        self.p_connected[1, :] = 10. / 21 * self.qf.R1 + 1. / 7 * (
            6 * self.qf.R1 + 12 * self.qf.R2 + 6 * self.qf.Q5)
        self.p_connected[2, :] = 6. / 7 * (self.qf.R1 + self.qf.R2)
        self.p_connected[3, :] = 3. / 7 * self.qf.Q8

        if self.shear or self.third_order:
            self.p_connected[6, :] = 2. / 7 * self.qf.Qs2

        if self.third_order:
            self.p_connected[10, :] = 2 * self.qf.Rb3 * self.pint
            self.p_connected[11, :] = 2 * self.qf.Rb3 * self.pint

    def compute_p_k0(self):
        '''Disconnected k^0 terms: b2^2 and, with shear, b2bs and bs^2 (l = 0 only).'''
        self.p_k0 = np.zeros((self.num_power_components, self.N))
        bias_integrands = np.zeros((self.num_power_components, self.N))

        bias_integrands[5, :] = 0.5 * self.corlin**2

        if self.shear or self.third_order:
            bias_integrands[8, :] = self.chi
            bias_integrands[9, :] = self.zeta

        self._add_hankel_transform(self.p_k0, 0, bias_integrands)

    def compute_p_k1(self):
        '''Disconnected k^1 terms: b1b2 and, with shear, b1bs (l = 1 only).'''
        self.p_k1 = np.zeros((self.num_power_components, self.N))
        bias_integrands = np.zeros((self.num_power_components, self.N))

        bias_integrands[4, :] = -2 * self.Ulin * self.corlin

        if self.shear or self.third_order:
            bias_integrands[7, :] = -2 * self.V

        self._add_hankel_transform(self.p_k1, 1, bias_integrands)

    def compute_p_k2(self):
        '''Disconnected k^2 terms: b1^2, b2 and, with shear, bs (l = 0, 1, 2).'''
        self.p_k2 = np.zeros((self.num_power_components, self.N))
        bias_integrands = np.zeros((self.num_power_components, self.N))

        for l in range(3):
            mu0fac = (l == 0)
            mu2fac = 1. / 3 * (l == 0) - 2. / 3 * (l == 2)

            bias_integrands[2,:] = mu0fac * (- 0.5*self.Xlin*self.corlin)  + \
                                   mu2fac * (-0.5*self.Ylin*self.corlin - self.Ulin**2)
            bias_integrands[3, :] = mu2fac * (-self.Ulin**2)

            if self.shear or self.third_order:
                bias_integrands[
                    6, :] = mu0fac * (-self.Xs2) + mu2fac * (-self.Ys2)

            self._add_hankel_transform(self.p_k2, l, bias_integrands)

    def compute_p_k3(self):
        '''Disconnected k^3 term in the b1 column: U_lin (X_lin mu + Y_lin mu^3); l < jn.'''
        self.p_k3 = np.zeros((self.num_power_components, self.N))
        bias_integrands = np.zeros((self.num_power_components, self.N))

        for l in range(self.jn):
            mu1fac = (l == 1)
            mu3fac = 0.6 * (l == 1) - 0.4 * (l == 3)

            bias_integrands[1,:] = mu1fac * (self.Ulin*self.Xlin) + \
                                   mu3fac * (self.Ulin*self.Ylin)

            self._add_hankel_transform(self.p_k3, l, bias_integrands)

    def compute_p_k4(self):
        '''Disconnected k^4 term in the "1" column: (X_lin + Y_lin mu^2)^2 / 8; l < jn.'''
        self.p_k4 = np.zeros((self.num_power_components, self.N))
        bias_integrands = np.zeros((self.num_power_components, self.N))

        for l in range(self.jn):
            mu0fac = (l == 0)
            mu2fac = 1. / 3 * (l == 0) - 2. / 3 * (l == 2)
            mu4fac = 0.2 * (l == 0) - 4. / 7 * (l == 2) + 8. / 35 * (l == 4)

            bias_integrands[0,:] = mu0fac * (+1./8*self.Xlin**2 ) + \
                                   mu2fac * (1./4*self.Xlin*self.Ylin  ) + \
                                   mu4fac * (1./8*self.Ylin**2)

            self._add_hankel_transform(self.p_k4, l, bias_integrands)

    def make_ptable(self, kmin=1e-3, kmax=3, nk=100):
        '''
        Make a table of different terms of P(k) between a given
        'kmin', 'kmax' and for 'nk' equally spaced values in log10 of k
        This is the most time consuming part of the code.

        All contributions are recomputed here from the current power spectrum,
        so the table never mixes results from an earlier update_power_spectrum.
        Column 0 of self.pktable holds k; the remaining columns follow the bias
        ordering given in the class docstring.
        '''
        self.compute_p_linear()
        self.compute_p_connected()
        self.compute_p_k0()
        self.compute_p_k1()
        self.compute_p_k2()
        self.compute_p_k3()
        self.compute_p_k4()

        self.pktable = np.zeros([nk, self.num_power_components + 1
                                 ])  # one column for ks
        kv = np.logspace(np.log10(kmin), np.log10(kmax), nk)
        self.pktable[:, 0] = kv[:]

        ptotal = self.p_linear + self.p_connected \
                 + self.p_k0 + self.kint * self.p_k1 + self.kint**2 * self.p_k2 \
                 + self.kint**3 * self.p_k3 + self.kint**4 * self.p_k4
        # interp1d interpolates along the last axis: result is (ncol, nk)
        self.pktable[:, 1:] = interp1d(self.kint, ptotal)(kv).T

    def export_wisdom(self, wisdom_file='./wisdom.npy'):
        self.sph.export_wisdom(wisdom_file=wisdom_file)
# Example #2
# 0
class LPT_RSD:

    '''
        Class to evaluate the one-loop power spectrum in redshift space with the
        linear velocities resummed. See arXiv:XXXX
        
        Throughout this code we refer to mu_q as "mu" and mu = n.k as "nu."
    '''

    def __init__(self, k, p, third_order = True, shear=True, one_loop=True,\
                 kIR = None, cutoff=10, jn=5, N = 2000, threads=None, extrap_min = -5, extrap_max = 3):
        '''
        Set up the k/q grids, the damped input spectrum, the q-space functions
        and the spherical-Bessel transforms.

        Parameters
        ----------
        k, p : input linear power spectrum, interpolated onto self.kint.
        third_order : include b3 terms; setting it also turns on the shear terms.
        shear, one_loop : which groups of terms to include.
        kIR : forwarded to QFuncFFT.
        cutoff : the interpolated spectrum is multiplied by exp(-(k/cutoff)^2).
        jn : number of spherical-Bessel orders used in the transforms.
        N : number of points in the log-spaced k and q grids.
        threads : FFT threads; when None, read from $OMP_NUM_THREADS (default 1).
        extrap_min, extrap_max : log10 range of the k grid (q spans the reciprocal).
        '''
        # grids
        self.N, self.extrap_max, self.extrap_min = N, extrap_max, extrap_min
        self.kIR, self.cutoff = kIR, cutoff
        self.kint = np.logspace(extrap_min,extrap_max,self.N)
        self.qint = np.logspace(-extrap_max,-extrap_min,self.N)

        # term selection: b3 terms always need the shear functions
        self.third_order = third_order
        self.shear = shear or third_order
        self.one_loop = one_loop

        # input spectrum, damped above `cutoff`
        self.k, self.p = k, p
        self.pint = loginterp(k,p)(self.kint) * np.exp(-(self.kint/self.cutoff)**2)
        self.setup_powerspectrum()

        self.pktables = {}

        # one column per bias monomial plus a trailing "za" column
        if self.third_order:
            ncomponents = 13
        elif self.shear:
            ncomponents = 11
        else:
            ncomponents = 7
        self.num_power_components = ncomponents

        self.jn = jn

        self.threads = int( os.getenv("OMP_NUM_THREADS","1") ) if threads is None else threads

        self.sph = SphericalBesselTransform(self.qint, L=self.jn, ncol=self.num_power_components, threads=self.threads)
        self.sph1 = SphericalBesselTransform(self.qint, L=self.jn, ncol=1, threads=self.threads)
        self.sphr = SphericalBesselTransformNP(self.kint,L=5,fourier=True)

    
    def setup_powerspectrum(self):
        '''
        Build the q-space functions for the current self.pint with QFuncFFT and
        cache the ones used by setup_rsd_facs and p_integrals.

        When one_loop is False every one-loop quantity is set to 0, so the
        corresponding contributions in p_integrals vanish. The b3 functions
        (Ub3, theta) are read from QFuncFFT only when third_order is set;
        otherwise they are 0 (p_integrals only uses them under third_order).
        '''

        self.qf = QFuncFFT(self.kint, self.pint, kIR=self.kIR, qv=self.qint, oneloop=self.one_loop, shear=self.shear, third_order=self.third_order)

        # linear terms: the *_lt pieces feed the exponent/series (XYlin, Ylin),
        # the *_gt pieces appear in the expanded terms of p_integrals; the
        # split itself is done inside QFuncFFT using kIR
        self.Xlin = self.qf.Xlin_lt
        self.Ylin = self.qf.Ylin_lt
        self.XYlin = self.Xlin + self.Ylin
        self.sigma = self.XYlin[-1]
        self.yq = self.Ylin / self.qint

        self.Xlin_gt = self.qf.Xlin_gt
        self.Ylin_gt = self.qf.Ylin_gt

        self.Ulin = self.qf.Ulin
        self.corlin = self.qf.corlin

        # load shear functions
        if self.shear:
            self.Xs2 = self.qf.Xs2
            self.Ys2 = self.qf.Ys2
            self.sigmas2 = (self.Xs2 + self.Ys2)[-1]
            self.V = self.qf.V
            self.zeta = self.qf.zeta
            self.chi = self.qf.chi

        if self.one_loop:

            self.X13, self.Y13 = self.qf.Xloop13, self.qf.Yloop13
            self.X22, self.Y22 = self.qf.Xloop22, self.qf.Yloop22

            # These are the decomposition for W112, which we need independently
            self.V1, self.V3 = self.qf.V1loop112, self.qf.V3loop112
            self.T = self.qf.Tloop112

            self.X10 = 2 * self.qf.X10loop12
            self.Y10 = 2 * self.qf.Y10loop12
            self.sigma10 = (self.X10 + self.Y10)[-1]

            self.U3 = self.qf.U3
            self.U11 = self.qf.U11
            self.U20 = self.qf.U20
            self.Us2 = self.qf.Us2

            if self.third_order:
                self.Ub3 = self.qf.Ub3
                self.theta = self.qf.theta
            else:
                # QFuncFFT was built with third_order=False; the b3 columns do
                # not exist, so these are never used
                self.Ub3, self.theta = 0, 0

        else:
            self.X13, self.Y13, self.X22, self.Y22, self.sigmaloop, self.V1, self.V3, self.T, self.Tloop, self.X10, self.Y10, self.sigma10, self.U3, self.U11, self.U20, self.Us2, self.Ub3, self.theta = (0,)*18
            
    #### Define RSD Kernels #######
    
    def setup_rsd_facs(self,f,nu,D=1,nmax=10):
        '''
        Precompute everything in the angular integrals that depends on (f, nu).

        Sets the scalar factors (Kfac, s, c, ic2, Anu, Bnu, ...) and Bfac; the
        (jn+nmax, jn+nmax) tables hyp1, hyp2 (hypergeometric functions) and fnms
        (Gamma-function prefactors); the (jn, nmax) coefficient tables for G0
        and its A/C derivatives; and powerfacs[n] = (Bfac/ic2)**n.

        Parameters
        ----------
        f : growth rate.
        nu : cosine between k and the line of sight.
        D : factor whose square multiplies Ylin in Bfac.
        nmax : number of series terms tabulated per multipole.
        '''
        self.f = f
        self.nu = nu
        self.D = D

        kfac = np.sqrt(1+f*(2+f)*nu**2)
        self.Kfac = kfac
        self.Kfac2 = self.Kfac**2
        self.s = f*nu*np.sqrt(1-nu**2)/self.Kfac
        self.c = np.sqrt(1-self.s**2)
        self.c2 = self.c**2
        self.ic2 = 1/self.c2
        self.c3 = self.c**3
        # this times k is "B"
        self.Bfac = -0.5 * self.Kfac2 * self.Ylin * self.D**2

        # \hn \cdot \hq = Anu * mu + Bnu * sqrt(1-mu^2) cos(phi)
        self.Anu = self.nu * (1 + f) / self.Kfac
        self.Bnu = np.sqrt(1-nu**2) / self.Kfac

        # hypergeometric and Gamma-function tables indexed by (n, m)
        size = self.jn + nmax
        self.hyp1 = np.zeros((size, size))
        self.hyp2 = np.zeros((size, size))
        self.fnms = np.zeros((size, size))
        for n, m in np.ndindex(size, size):
            self.hyp1[n,m] = hyp2f1(0.5-n,-n,0.5-m-n,self.ic2)
            self.hyp2[n,m] = hyp2f1(1.5-n,-n,0.5-m-n,self.ic2)
            self.fnms[n,m] = gamma(m+n+0.5)/gamma(m+1)/gamma(n+0.5)/gamma(1-m+n)

        # series coefficients of G0 and its derivatives: entry [l, j] is the
        # (n = l + j, m = l) coefficient; the one-loop terms need the C and
        # higher A derivatives
        coefficient_tables = (
            ('G0_l_ns', self._G0_l_n),
            ('dG0dA_l_ns', self._dG0dA_l_n),
            ('d2G0dA2_l_ns', self._d2G0dA2_l_n),
            ('dG0dC_l_ns', self._dG0dC_l_n),
            ('d2G0dCdA_l_ns', self._d2G0dCdA_l_n),
            ('d2G0dC2_l_ns', self._d2G0dC2_l_n),
            ('d3G0dA3_l_ns', self._d3G0dA3_l_n),
            ('d3G0dCdA2_l_ns', self._d3G0dCdA2_l_n),
            ('d4G0dA4_l_ns', self._d4G0dA4_l_n),
        )
        for attr_name, coefficient in coefficient_tables:
            table = np.zeros((self.jn, nmax))
            for ell, j in np.ndindex(self.jn, nmax):
                table[ell, j] = coefficient(ell + j, ell)
            setattr(self, attr_name, table)

        # (B A^2 / rho^2)-type factor per series order; excludes the k^(2n)
        ratio = self.Bfac /self.ic2
        self.powerfacs = np.array([ratio**order for order in range(size)])
        

        
    
    def _G0_l_n(self,n,m):
        '''Series coefficient (n, m) of G0: the fnms prefactor times hyp1.'''
        prefactor = self.fnms[n,m]
        return prefactor * self.hyp1[n,m]
    
    
    def _dG0dA_l_n(self,n,m):
        '''
        Series coefficient (n, m) of dG0/dA.

        The 1/(k q) factor is omitted so the coefficient can be tabulated once
        per (f, nu); _dG0dA_l divides it back in.
        '''
        sine = self.s
        bracket = -self.hyp1[n,m] + (1-2*n)*self.hyp2[n,m]
        return self.fnms[n,m] * ((sine * bracket) * -sine)
    
    def _d2G0dA2_l_n(self,n,m):
        '''
        Series coefficient (n, m) of d^2 G0 / dA^2, without its 1/(k q)^2 factor.

        Combines the fnms, hyp1 and hyp2 tables built in setup_rsd_facs;
        _d2G0dA2_l applies the omitted 1/(k q)^2.
        '''
        x = self.ic2
        
        # with x = ic2 = 1/c**2, the prefactor (1 - 1/x) equals 1 - c**2
        ret = (1-1./x) * ( (2*m-1-4*n*(m+1))*self.hyp1[n,m] \
                                                +(1-4*n**2+m*(4*n-2))*self.hyp2[n,m] )
        return self.fnms[n,m] * ret #/(k*self.qint)**2
        
    def _dG0dC_l_n(self,n,m):
        '''
        Series coefficient (n, m) of dG0/dC.

        The 1/(k q) factor is omitted here; _dG0dC_l applies it.
        '''
        bracket = -self.hyp1[n,m] + (1-2*n)*self.hyp2[n,m]
        scaled = self.s * bracket
        return self.fnms[n,m] * scaled
        
    def _d2G0dCdA_l_n(self,n,m):
        '''
        Series coefficient (n, m) of d^2 G0 / dC dA, without its 1/(k q)^2 factor.

        Uses the fnms/hyp1/hyp2 tables and s, c from setup_rsd_facs;
        _d2G0dCdA_l applies the omitted 1/(k q)^2.
        '''
        x = self.ic2  # not used below
        
        ret  = - ( 2*(m - 2*n*(1+m))*self.c**2 + self.s**2 ) * self.hyp1[n,m]
        ret += (1-2*n) * ( 2*(m-n)*self.c**2 + self.s**2 ) * self.hyp2[n,m]
        
        ret *= self.s / self.c
        
        return self.fnms[n,m] * ret # /(k*self.qint)**2
        
    def _d2G0dC2_l_n(self,n,m):
        '''
        Series coefficient (n, m) of d^2 G0 / dC^2, without its 1/(k q)^2 factor.

        Uses the fnms/hyp1/hyp2 tables and s, c from setup_rsd_facs;
        _d2G0dC2_l applies the omitted 1/(k q)^2.
        '''
        x = self.ic2  # not used below

        ret  = ( (1+2*m-4*n*(1+m))*self.c**2 + 2*self.s**2 ) * self.hyp1[n,m]
        ret += -(1-2*n) * ( (1+2*m-2*n)*self.c**2 + 2*self.s**2 ) * self.hyp2[n,m]
                
        return self.fnms[n,m] * ret # / (k*self.qint)**2
        
    def _d3G0dA3_l_n(self,n,m):
        '''
        Series coefficient (n, m) of d^3 G0 / dA^3, without its 1/(k q)^3 factor.

        The hyp1 and hyp2 terms each carry separate c**2 and s**2 coefficients
        (coeff*A / coeff*C); the whole result is scaled by s**2 / c.
        _d3G0dA3_l applies the omitted 1/(k q)^3.
        '''
        x = self.ic2  # not used below
        
        # hyp1 coefficients (multiplying c**2 and s**2 respectively)
        coeff1A = 2*(1-m)*(1-2*m) + 8*(2-m)*(1+m)*n + 8*n**2*(1+m)
        coeff1C = - (1-2*m+4*n*(1+m))
        ret = (coeff1A * self.c**2 + coeff1C * self.s**2) * self.hyp1[n,m]
        
        # hyp2 coefficients (multiplying c**2 and s**2 respectively)
        coeff2A = -(1-2*n)*( 2*(1-2*m+2*n)*(1-m+n) )
        coeff2C = (1-2*n)*(1-2*m+4*n*(1+m))
        ret += (coeff2A * self.c**2 + coeff2C * self.s**2) * self.hyp2[n,m]

        ret *= (self.s**2/self.c)
        
        return self.fnms[n,m] * ret # / (k*self.qint)**3
        
    
    def _d4G0dA4_l_n(self,n,m):
        '''
        Series coefficient (n, m) of d^4 G0 / dA^4, without its 1/(k q)^4 factor.

        The hyp1 and hyp2 terms each carry separate c**2 and s**2 polynomial
        coefficients in (n, m); the result is scaled by fnms[n, m] * s**2.
        _d4G0dA4_l applies the omitted 1/(k q)^4.
        '''
        x = self.ic2  # not used below
        
        # hyp1 coefficients (multiplying c**2 and s**2 respectively)
        coeff1A = -6 + 22*m - 24*m**2 + 8*m**3 \
                 + n*(-76 - 28*m + 32*m**2 - 16*m**3) \
                 + n**2 * (-56 - 24*m + 32*m**2 ) + n**3 * ( -16 - 16*m )
        coeff1C = 9 - 24*m + 12*m**2 + n * (56 + 24*m - 32*m**2) +\
                   n**2 * (32 + 48*m + 16*m**2)
        ret = (coeff1A * self.c**2 + coeff1C * self.s**2) * self.hyp1[n,m]
        
        # hyp2 coefficients (multiplying c**2 and s**2 respectively)
        coeff2A = 2*(-3+2*m-2*n)*(1-2*m+2*n)*(1-m+n)
        coeff2C = 9 - 24*m + 12*m**2 + n*(44 + 8*m - 16*m**2) + n**2 * (20 + 16*m)
        ret += -(1-2*n)*(coeff2A * self.c**2 + coeff2C * self.s**2) * self.hyp2[n,m]

        ret *= self.fnms[n,m] * self.s**2 # / (k*self.qint)**4
        
        return ret

        
    # dG/dA^2dC
    def _d3G0dCdA2_l_n(self,n,m):
        '''
        Series coefficient (n, m) of d^3 G0 / dC dA^2, without its 1/(k q)^3 factor.

        coeff1 and coeff2 multiply hyp1 and hyp2; the result is scaled by s and
        fnms[n, m]. _d3G0dCdA2_l applies the omitted 1/(k q)^3.
        '''
        x = self.ic2  # not used below
        
        coeff1 =  2 * (m-2*m**2-4*n*(1-m**2)-4*n**2*(1+m)) * self.c**2
        coeff1 += 3 * (1-2*m+4*n*(1+m)) * self.s**2
        
        coeff2 = 2 * (1-2*m+2*n)*(m-n) * self.c**2
        coeff2 += (3-6*m+8*n+4*m*n) * self.s**2
        coeff2 *= - (1-2*n)
        
        ret  = coeff1 * self.hyp1[n,m]
        ret += coeff2 * self.hyp2[n,m]
        
        ret *= self.s
        
        return self.fnms[n,m] * ret # / (k*self.qint)**3
    

    def _G0_l(self,l,k, nmax=10):
        '''
        Sum the truncated series for G0 at multipole l.

        Term j (0 <= j < nmax) is k**(2*(l+j)) * G0_l_ns[l, j] * powerfacs[l+j];
        the sum is an array over the q grid of powerfacs.
        '''
        orders = l + np.arange(nmax)
        weights = k**(2*orders) * self.G0_l_ns[l,:nmax]
        terms = weights[:,None] * self.powerfacs[l:l+nmax,:]
        return terms.sum(axis=0)
        
        
    
    def _dG0dA_l(self,l,k,nmax=10):
        '''
        Sum the truncated series for dG0/dA at multipole l.

        Same series as _G0_l but with the dG0dA_l_ns coefficients, and
        restores the 1/(k q) factor omitted from _dG0dA_l_n.
        '''
        orders = l + np.arange(nmax)
        weights = k**(2*orders) * self.dG0dA_l_ns[l,:nmax]
        terms = weights[:,None] * self.powerfacs[l:l+nmax,:]
        return terms.sum(axis=0) / (k*self.qint)
    
    
    def _d2G0dA2_l(self,l,k,nmax=10):
        '''
        Sum the truncated series for d^2 G0 / dA^2 at multipole l, restoring
        the 1/(k q)^2 factor omitted from _d2G0dA2_l_n.
        '''
        orders = l + np.arange(nmax)
        weights = k**(2*orders) * self.d2G0dA2_l_ns[l,:nmax]
        terms = weights[:,None] * self.powerfacs[l:l+nmax,:]
        return terms.sum(axis=0) / (k*self.qint)**2
        
    def _dG0dC_l(self,l,k,nmax=10):
        '''
        Sum the truncated series for dG0/dC at multipole l, restoring the
        1/(k q) factor omitted from _dG0dC_l_n.
        '''
        orders = l + np.arange(nmax)
        weights = k**(2*orders) * self.dG0dC_l_ns[l,:nmax]
        terms = weights[:,None] * self.powerfacs[l:l+nmax,:]
        return terms.sum(axis=0) / (k*self.qint)
        
        
    def _d2G0dCdA_l(self,l,k,nmax=10):
        '''
        Sum the truncated series for d^2 G0 / dC dA at multipole l, restoring
        the 1/(k q)^2 factor omitted from _d2G0dCdA_l_n.
        '''
        orders = l + np.arange(nmax)
        weights = k**(2*orders) * self.d2G0dCdA_l_ns[l,:nmax]
        terms = weights[:,None] * self.powerfacs[l:l+nmax,:]
        return terms.sum(axis=0) / (k*self.qint)**2
        
    def _d2G0dC2_l(self,l,k,nmax=10):
        '''
        Sum the truncated series for d^2 G0 / dC^2 at multipole l, restoring
        the 1/(k q)^2 factor omitted from _d2G0dC2_l_n.
        '''
        orders = l + np.arange(nmax)
        weights = k**(2*orders) * self.d2G0dC2_l_ns[l,:nmax]
        terms = weights[:,None] * self.powerfacs[l:l+nmax,:]
        return terms.sum(axis=0) / (k*self.qint)**2
        
    def _d3G0dA3_l(self,l,k,nmax=10):
        '''
        Sum the truncated series for d^3 G0 / dA^3 at multipole l, restoring
        the 1/(k q)^3 factor omitted from _d3G0dA3_l_n.
        '''
        orders = l + np.arange(nmax)
        weights = k**(2*orders) * self.d3G0dA3_l_ns[l,:nmax]
        terms = weights[:,None] * self.powerfacs[l:l+nmax,:]
        return terms.sum(axis=0) / (k*self.qint)**3
    
    def _d3G0dCdA2_l(self,l,k,nmax=10):
        '''
        Sum the truncated series for d^3 G0 / dC dA^2 at multipole l,
        restoring the 1/(k q)^3 factor omitted from _d3G0dCdA2_l_n.
        '''
        orders = l + np.arange(nmax)
        weights = k**(2*orders) * self.d3G0dCdA2_l_ns[l,:nmax]
        terms = weights[:,None] * self.powerfacs[l:l+nmax,:]
        return terms.sum(axis=0) / (k*self.qint)**3
        
        
    def _d4G0dA4_l(self,l,k,nmax=10):
        '''
        Sum the truncated series for d^4 G0 / dA^4 at multipole l, restoring
        the 1/(k q)^4 factor omitted from _d4G0dA4_l_n.
        '''
        orders = l + np.arange(nmax)
        weights = k**(2*orders) * self.d4G0dA4_l_ns[l,:nmax]
        terms = weights[:,None] * self.powerfacs[l:l+nmax,:]
        return terms.sum(axis=0) / (k*self.qint)**4
        
    
    ### Now define the actual integrals!

    def p_integrals(self, k, nmax=8):
        '''
        Compute every bias component of P(k, nu) at a single wavenumber k.

        setup_rsd_facs(f, nu, ...) must have been called first; it sets the
        (f, nu)-dependent factors and the series tables used here. For each
        multipole l < jn the angular functions (the G's below) are combined
        into q-space integrands, multiplied by the IR exponent, Hankel
        transformed with self.sph and interpolated to k.

        Parameters
        ----------
        k : float
            Wavenumber at which to evaluate.
        nmax : int
            Number of series terms used for the G functions; must not exceed
            the nmax given to setup_rsd_facs.

        Returns
        -------
        Array of length num_power_components: the components in the order
        1, b1, b1^2, b2, b1b2, b2^2[, bs, b1bs, b2bs, bs^2[, b3, b1b3]],
        followed by the "za" column (make_pltable multiplies it by powers of
        k^2 to form counterterms).
        '''

        ksq = k**2
        Kfac = self.Kfac
        f = self.f
        nu = self.nu

        K = k*self.Kfac; Ksq = K**2
        Knfac = nu*(1+f)

        D2 = self.D**2

        # IR exponent; its q -> infinity value (sigma) is factored out into
        # `suppress`, which multiplies the final result
        expon = np.exp(-0.5*Ksq * D2* (self.XYlin - self.sigma))
        suppress = np.exp(-0.5*Ksq * D2* self.sigma)

        A = k*self.qint*self.c
        C = k*self.qint*self.s

        # The trailing zeros pad each list so that the negative indices
        # [ii-1] ... [ii-4] used below pick up 0 for the lowest multipoles.
        G0s =  [self._G0_l(ii,k,nmax=nmax)    for ii in range(self.jn)] + [0] * 4
        dGdAs =  [self._dG0dA_l(ii,k,nmax=nmax) for ii in range(self.jn)] + [0] * 3
        dGdCs = [self._dG0dC_l(ii,k,nmax=nmax) for ii in range(self.jn)] + [0] * 3
        d2GdA2s = [self._d2G0dA2_l(ii,k,nmax=nmax) for ii in range(self.jn)] + [0] * 2
        d2GdCdAs = [self._d2G0dCdA_l(ii,k,nmax=nmax) for ii in range(self.jn) ] + [0] * 2
        d2GdC2s = [self._d2G0dC2_l(ii,k,nmax=nmax) for ii in range(self.jn) ] + [0] * 2
        d3GdA3s = [self._d3G0dA3_l(ii,k,nmax=nmax) for ii in range(self.jn) ] + [0]
        d3GdCdA2s = [self._d3G0dCdA2_l(ii,k,nmax=nmax) for ii in range(self.jn) ] + [0]
        d4GdA4s = [self._d4G0dA4_l(ii,k,nmax=nmax) for ii in range(self.jn) ]

        # Gij: angular integral weighted by mu^j and (sqrt(1-mu^2) cos(phi))^i
        G01s = [-(dGdAs[ii] + 0.5*A*G0s[ii-1])   for ii in range(self.jn)]
        G02s = [-(d2GdA2s[ii] + A * dGdAs[ii-1] + 0.5*G0s[ii-1] + 0.25 * A**2 *G0s[ii-2]) for ii in range(self.jn)]
        G03s = [d3GdA3s[ii] + 1.5*A*d2GdA2s[ii-1] + 1.5*dGdAs[ii-1] \
                 + 0.75*A**2*dGdAs[ii-2] + 0.75*A*G0s[ii-2] + A**3/8.*G0s[ii-3] for ii in range(self.jn)]
        G04s = [d4GdA4s[ii] + 2*A*d3GdA3s[ii-1] + 3*d2GdA2s[ii-1] \
                + 1.5*A**2*d2GdA2s[ii-2] + 3*A*dGdAs[ii-2] + 0.75*G0s[ii-2]\
                + 0.5*A**3*dGdAs[ii-3] + 0.75*A**2*G0s[ii-3]\
                + A**4/16. * G0s[ii-4] for ii in range(self.jn)]

        G10s = [ dGdCs[ii] + 0.5*C*G0s[ii-1]  for ii in range(self.jn)]

        G11s = [ d2GdCdAs[ii] + 0.5*C*dGdAs[ii-1] + 0.5*A*dGdCs[ii-1] + 0.25*A*C*G0s[ii-2] for ii in range(self.jn)]
        G20s = [-(d2GdC2s[ii] + C * dGdCs[ii-1] + 0.5*G0s[ii-1] + 0.25 * C**2 *G0s[ii-2]) for ii in range(self.jn)]
        G12s = [-(d3GdCdA2s[ii] + 0.5*C*d2GdA2s[ii-1] + A*d2GdCdAs[ii-1] + 0.5*dGdCs[ii-1]\
                  + 0.5*A*C*dGdAs[ii-2] + 0.25*A**2*dGdCs[ii-2] + 0.25*C*G0s[ii-2] + A**2*C/8*G0s[ii-3])  for ii in range(self.jn)]

        ret = np.zeros(self.num_power_components)

        bias_integrands = np.zeros( (self.num_power_components,self.N)  )

        for l in range(self.jn):

            # n.q = Anu * mu + Bnu * sqrt(1-mu^2) cos(phi), so
            # (n.q)^2 = Anu^2 mu^2 + 2 Anu Bnu mu sqrt(1-mu^2) cos(phi) + Bnu^2 (1-mu^2) cos^2(phi)
            mu0 = G0s[l]
            nq1 = self.Anu * G01s[l] + self.Bnu * G10s[l]
            mu_nq1 = self.Anu * G02s[l] + self.Bnu * G11s[l]
            nq2 = self.Anu**2 * G02s[l] + 2 * self.Anu * self.Bnu * G11s[l] + self.Bnu**2 * G20s[l]
            mu1 = G01s[l]
            mu2 = G02s[l]
            mu3 = G03s[l]
            mu2_nq1 = self.Anu * G03s[l] + self.Bnu * G12s[l]
            mu4 = G04s[l]

            bias_integrands[0,:] = 1 * G0s[l] - 0.5 * Ksq * (self.Xlin_gt * G0s[l] + self.Ylin_gt * mu2) # za

            bias_integrands[0,:] += -0.5 * ksq * ( 2*(Kfac**2 + 2*f*(1+f)*nu**2) * G0s[l] * self.X13 +\
                                                  2*(Kfac**2*mu2 + 2*f*Kfac*nu*mu_nq1) * self.Y13 +\
                                                   (Kfac**2 + 2*f*(1+f)*nu**2 + f**2*nu**2) * G0s[l] * self.X22 +\
                                                   (Kfac**2*mu2 + 2*f*Kfac*nu*mu_nq1 + f**2*nu**2*nq2) * self.Y22)\
                                 + Ksq**2 / 8. * (self.Xlin_gt**2 * G0s[l] + 2*self.Xlin_gt*self.Ylin_gt*mu2 + self.Ylin_gt**2 * mu4)# Aloop

            bias_integrands[0,:] += 0.5*k**3 * ( 2*Kfac*(Kfac**2+f*(1+f)*nu**2) * G01s[l] * self.V1 +  \
                                                Kfac**2 * (Kfac*G01s[l] + f*nu*nq1) * self.V3 + \
                                                Kfac**2 * (Kfac*G03s[l] + f*nu*mu2_nq1) * self.T)

            bias_integrands[1,:] = -2 * K * (self.Ulin + self.U3) * mu1 - Ksq * (self.X10 * mu0 + self.Y10 * mu2 ) \
                                   -4*f*k*nu*self.U3*nq1 - f*ksq*nu*(self.X10 * Knfac * mu0 + Kfac * self.Y10 * mu_nq1)\
                                   -2 * K * self.Ulin * ( -0.5*Ksq*(self.Xlin_gt*mu1 + self.Ylin_gt*mu3) )

            # The U11 velocity factor matches the U20 (b2) and Us2 (bs) terms:
            # k*(Kfac*mu1 + f*nu*nq1), with no extra power of k.
            bias_integrands[2,:] = self.corlin * (mu0 - 0.5*Ksq*(self.Xlin_gt*mu0 + self.Ylin_gt*mu2) )\
                                   - Ksq*self.Ulin**2*mu2 - k*(Kfac*mu1 + f*nu*nq1)*self.U11

            bias_integrands[3,:] = - Ksq * self.Ulin**2 * mu2 - k*(Kfac*mu1 + f*nu*nq1)*self.U20 # b2
            bias_integrands[4,:] = -2 * K * self.Ulin * self.corlin * mu1 # b1b2
            bias_integrands[5,:] = 0.5 * self.corlin**2 * mu0 # b2sq

            if self.shear or self.third_order:
                bias_integrands[6,:] = - Ksq * (self.Xs2 * mu0 + self.Ys2 * mu2) - 2*k*(Kfac*mu1 + f*nu*nq1)*self.Us2 # bs should be both minus
                bias_integrands[7,:] = -2*K*self.V * mu1 # b1bs
                bias_integrands[8,:] = self.chi * mu0 # b2bs
                bias_integrands[9,:] = self.zeta * mu0 # bssq

            if self.third_order:
                bias_integrands[10,:] = -2 * K * self.Ub3 * mu1 #b3
                bias_integrands[11,:] = 2 * self.theta * mu0 #b1 b3

            bias_integrands[-1,:] = 1 * G0s[l] - 0.5 * Ksq * (self.Xlin_gt * G0s[l] + self.Ylin_gt * mu2) # za

            # multiply by IR exponent; for l = 0 also remove the large-q constant
            bias_integrands = bias_integrands * expon * (-2./k/self.qint)**l
            if l == 0:
                bias_integrands -= bias_integrands[:,-1][:,None]

            # do FFTLog
            ktemps, bias_ffts = self.sph.sph(l, bias_integrands)
            ret += interp1d(ktemps, bias_ffts)(k)

        return 4*suppress*np.pi*ret
        

    def make_ptable(self, f, nu, kv = None, kmin = 1e-2, kmax = 0.25, nk = 50,nmax=5):
        '''
        Tabulate all components of P(k, nu) at fixed f and nu.

        The wavenumbers are kv if given, otherwise nk log-spaced values between
        kmin and kmax. Column 0 of self.pktable holds k and the remaining
        columns hold the output of p_integrals. A copy of the table is stored
        in self.pktables[nu].
        '''
        self.setup_rsd_facs(f,nu,nmax=nmax)

        if kv is None:
            kv = np.logspace(np.log10(kmin), np.log10(kmax), nk)
        nk = len(kv)

        # one column for ks, then one per component
        self.pktable = table = np.zeros([nk, self.num_power_components+1])
        table[:, 0] = kv[:]
        for row, kval in enumerate(kv):
            table[row, 1:] = self.p_integrals(kval,nmax=nmax)

        # keep an independent copy keyed by nu
        self.pktables[nu] = np.array(table)
        



    def make_pltable(self,f, apar = 1, aperp = 1, ngauss = 3, kv = None, kmin = 1e-2, kmax = 0.25, nk = 50, nmax=8):
        '''
        Make a table of the monopole, quadrupole and hexadecapole in k space.
        Uses Gauss-Legendre integration over nu with 2*ngauss nodes.

        P(k, nu) is evaluated at the "true" coordinates implied by the
        Alcock-Paczynski parameters apar, aperp. The results are stored in
        self.p0ktable, self.p2ktable and self.p4ktable, each of shape
        (nk, num_power_components + 6): the bias components (minus the za
        column), four k^2 nu^{0,2,4,6} counterterms built from the za column,
        and three stochastic terms (1, k^2 nu^2, k^4 nu^4).

        nmax is used both when building the series tables (setup_rsd_facs) and
        when summing them (p_integrals), so any nmax is supported.

        Returns 0.
        '''

        # since we are always symmetric in nu, can ignore negative values
        nus, ws = np.polynomial.legendre.leggauss(2*ngauss)
        nus_calc = nus[0:ngauss]

        L0 = np.polynomial.legendre.Legendre((1))(nus)
        L2 = np.polynomial.legendre.Legendre((0,0,1))(nus)
        L4 = np.polynomial.legendre.Legendre((0,0,0,0,1))(nus)

        # counterterms + stoch terms have distinct nu structure and have to be added here
        # e.g. k^2 mu^2 is not the same as k_obs^2 mu_obs^2!
        if kv is None:
            kv = np.logspace(np.log10(kmin), np.log10(kmax), nk)
        else:
            nk = len(kv)
        self.pknutable = np.zeros((len(nus),nk,self.num_power_components+6))

        # To implement AP:
        # Calculate P(k,nu) at the true coordinates, given by
        # k_true = k_apfac * kobs
        # nu_true = nu * a_perp/a_par/fac
        # Note that the integration grid on the other hand is never observed
        # The volume factor does not depend on nu.
        vol_fac = apar * aperp**2

        for ii, nu in enumerate(nus_calc):

            fac = np.sqrt(1 + nu**2 * ((aperp/apar)**2-1))
            k_apfac = fac / aperp
            nu_true = nu * aperp/apar/fac

            # build the series tables with the same truncation p_integrals uses
            self.setup_rsd_facs(f,nu_true,nmax=nmax)

            for jj, k in enumerate(kv):
                ktrue = k_apfac * k
                pterms = self.p_integrals(ktrue,nmax=nmax)

                self.pknutable[ii,jj,:-7] = pterms[:-1]

                # counterterms
                self.pknutable[ii,jj,-7] = ktrue**2 * pterms[-1]
                self.pknutable[ii,jj,-6] = ktrue**2 * nu_true**2 * pterms[-1]
                self.pknutable[ii,jj,-5] = ktrue**2 * nu_true**4 * pterms[-1]
                self.pknutable[ii,jj,-4] = ktrue**2 * nu_true**6 * pterms[-1]

                # stochastic terms
                self.pknutable[ii,jj,-3] = 1
                self.pknutable[ii,jj,-2] = ktrue**2 * nu_true**2
                self.pknutable[ii,jj,-1] = ktrue**4 * nu_true**4

        # fill the positive-nu half by symmetry
        self.pknutable[ngauss:,:,:] = np.flip(self.pknutable[0:ngauss],axis=0)

        self.kv = kv
        self.p0ktable = 0.5 * np.sum((ws*L0)[:,None,None]*self.pknutable,axis=0) / vol_fac
        self.p2ktable = 2.5 * np.sum((ws*L2)[:,None,None]*self.pknutable,axis=0) / vol_fac
        self.p4ktable = 4.5 * np.sum((ws*L4)[:,None,None]*self.pknutable,axis=0) / vol_fac

        return 0

    def combine_bias_terms_pkmu(self,nu,bvec):
        '''
        Combine bias terms into P(k,nu) given the bias paramters and counterterms listed below.
        
        Returns k, pknu.
        '''

        b1,b2,bs,b3,alpha0,alpha2,alpha4,alpha6, sn,sn2,sn4 = bvec
        bias_monomials = np.array([1, b1, b1**2, b2, b1*b2, b2**2, bs, b1*bs, b2*bs, bs**2, b3, b1*b3])
        
        try:
            pknu = self.pktables[nu]
        except:
            print("ERROR: Use make_ptable first to compute power spectrum components at angle nu.")
            return np.nan, np.nan
    
        kv = pknu[:,0]; za = pknu[:,-1]
        pktemp = np.copy(pknu)[:,1:-1]
                    
        res = np.sum(pktemp * bias_monomials,axis=1)\
              + (alpha0 + alpha2*nu**2 + alpha4*nu**4 + alpha6*nu**6) * kv**2 * za\
            + sn + sn2 * kv**2*nu**2 + sn4 * kv**4 * nu**4
                    
        return kv, res
        
        
    def combine_bias_terms_pkell(self,bvec):
        '''
        Same as function above but for the multipoles.
        
        Returns k, p0, p2, p4, assuming AP parameters from input p{ell}ktable
        '''
    
    
        b1,b2,bs,b3,alpha0,alpha2,alpha4,alpha6,sn,sn2,sn4 = bvec
        #bias_monomials = np.array([1, b1, b1**2, b2, b1*b2, b2**2, bs, b1*bs, b2*bs, bs**2, b3, b1*b3, alpha0, alpha2, alpha4,alpha6])
        bias_monomials = np.array([1, b1, b1**2, b2, b1*b2, b2**2, bs, b1*bs, b2*bs, bs**2, b3, b1*b3, alpha0, alpha2, alpha4,alpha6,sn,sn2,sn4])

        try:
            kv = self.kv
            p0 = np.sum(self.p0ktable * bias_monomials,axis=1)# + sn + 1./3 * kv**2 * sn2 + 1./5 * kv**4 * sn4
            p2 = np.sum(self.p2ktable * bias_monomials,axis=1)# + 2 * kv**2 * sn2 / 3 + 4./7 * kv**4 * sn4
            p4 = np.sum(self.p4ktable * bias_monomials,axis=1)# + 8./35 * kv**4 * sn4
            return kv, p0, p2, p4
        except:
            print("First generate multipole table with make_pltable.")
            
            
    def combine_bias_terms_xiell(self,bvec,method='loginterp'):
        '''
        Same as above but further transform the pkells into xiells.
        
        Again, the paramters f, AP are assumed to be what was input into p{ell}ktable.
        
        '''
        
        kv, p0, p2, p4 = self.combine_bias_terms_pkell(bvec)
        
        if method == 'loginterp':
        
            damping = np.exp(-(self.kint/10)**2)
            p0int = loginterp(kv, p0)(self.kint) * damping
            p2int = loginterp(kv, p2)(self.kint) * damping
            p4int = loginterp(kv, p4)(self.kint) * damping
            
        elif method == 'gauss_poly':
            # Add a point at k = 0 to the spline in k taper nicely
            
            frac = 1
            
            p0int = gaussian_poly_extrap( self.kint,\
                                          np.concatenate(([0], kv)),\
                                          np.concatenate(([0], p0)), frac=frac)
            
            p2int = gaussian_poly_extrap( self.kint,\
                                          np.concatenate(([0], kv)),\
                                          np.concatenate(([0], p2)), frac=frac )
            
            p4int = gaussian_poly_extrap( self.kint,\
                                          np.concatenate(([0], kv)),\
                                          np.concatenate(([0], p4)), frac=frac )
            
        elif method == 'min_cut':
            # Start log extrapolating when p_ell is below a threshold value:
            ftol = 1e-4
            damping = np.exp(-(self.kint/10)**2)
            
            pints = [np.zeros_like(self.kint), np.zeros_like(self.kint), np.zeros_like(self.kint),]
            
            for ii, pp in enumerate([p0,p2,p4]):
                
                iis = np.arange(len(kv))
                pval = np.max(pp)
                
                try:
                    zero_crossing = np.where(np.diff(np.sign(pp)))[0][0]
                except:
                    zero_crossing = len(pp)
                    
                cross_min = pp > (ftol * pval)

                # union is where we interpolate
                where_int = (iis < zero_crossing) * cross_min
                ktemp, ptemp = kv[where_int], pp[where_int]

                pints[ii] += loginterp(ktemp, ptemp)(self.kint) * damping

            p0int, p2int, p4int = pints
            
            
        ss0, xi0 = self.sphr.sph(0,p0int)
        ss2, xi2 = self.sphr.sph(2,p2int); xi2 *= -1
        ss4, xi4 = self.sphr.sph(4,p4int)
        
        return (ss0, xi0), (ss2, xi2), (ss4, xi4)
        
        #except:
        #    print("First generate multipole table with make_pltable.")
            
            
            

    ### Alternative functions to first combine bias terms, then compute power spectrum
    ### This set of functions currently assumes nonzero bs and b3
    
    
    def p_integral_fixedbias(self, k, bvec, nmax=8):
        '''
        Compute P(k, nu) at a single wavenumber k for a fixed set of bias parameters.

        The bias terms are combined *before* the FFTLog, so only one spherical
        Bessel transform per multipole l is needed. The RSD quantities read here
        (self.f, self.nu, self.Kfac, self.Anu, self.Bnu, self.c, self.s, self.D)
        must already be set; the callers in this class call setup_rsd_facs first.

        Parameters
        ----------
        k : float
            Wavenumber at which to evaluate the spectrum.
        bvec : sequence of 11 floats
            b1, b2, bs, b3, alpha0, alpha2, alpha4, alpha6, sn, sn2, sn4.
        nmax : int
            Passed on to the _G0_l family of helpers.

        Returns
        -------
        P(k, nu) including counterterms (k^2 nu^{0,2,4,6} * ZA) and stochastic
        terms (sn + sn2 k^2 nu^2 + sn4 k^4 nu^4).
        '''

        b1,b2,bs,b3,alpha0,alpha2,alpha4,alpha6,sn,sn2,sn4 = bvec
        bias_monomials = np.array([1, b1, b1**2, b2, b1*b2, b2**2, bs, b1*bs, b2*bs, bs**2, b3, b1*b3])

        ksq = k**2
        Kfac = self.Kfac
        f = self.f
        nu = self.nu

        K = k*self.Kfac; Ksq = K**2
        Knfac = nu*(1+f)

        D2 = self.D**2

        # IR exponent: the q-dependent part multiplies the integrand, the
        # q-independent part (sigma) is applied to the final result.
        expon = np.exp(-0.5*Ksq * D2* (self.XYlin - self.sigma))
        suppress = np.exp(-0.5*Ksq * D2* self.sigma)

        A = k*self.qint*self.c
        C = k*self.qint*self.s

        # Derivatives of G_l with respect to A and C. Each list is padded with
        # trailing zeros so that the negative indices (ii-1, ii-2, ...) used in the
        # recursions below evaluate to 0 for ii < the offset.
        G0s = [self._G0_l(ii,k,nmax=nmax) for ii in range(self.jn)] + [0]*4
        dGdAs = [self._dG0dA_l(ii,k,nmax=nmax) for ii in range(self.jn)] + [0]*3
        dGdCs = [self._dG0dC_l(ii,k,nmax=nmax) for ii in range(self.jn)] + [0]*3
        d2GdA2s = [self._d2G0dA2_l(ii,k,nmax=nmax) for ii in range(self.jn)] + [0]*2
        d2GdCdAs = [self._d2G0dCdA_l(ii,k,nmax=nmax) for ii in range(self.jn)] + [0]*2
        d2GdC2s = [self._d2G0dC2_l(ii,k,nmax=nmax) for ii in range(self.jn)] + [0]*2
        d3GdA3s = [self._d3G0dA3_l(ii,k,nmax=nmax) for ii in range(self.jn)] + [0]
        d3GdCdA2s = [self._d3G0dCdA2_l(ii,k,nmax=nmax) for ii in range(self.jn)] + [0]
        d4GdA4s = [self._d4G0dA4_l(ii,k,nmax=nmax) for ii in range(self.jn)]

        G01s = [-(dGdAs[ii] + 0.5*A*G0s[ii-1])   for ii in range(self.jn)]
        G02s = [-(d2GdA2s[ii] + A * dGdAs[ii-1] + 0.5*G0s[ii-1] + 0.25 * A**2 *G0s[ii-2]) for ii in range(self.jn)]
        G03s = [d3GdA3s[ii] + 1.5*A*d2GdA2s[ii-1] + 1.5*dGdAs[ii-1] \
                 + 0.75*A**2*dGdAs[ii-2] + 0.75*A*G0s[ii-2] + A**3/8.*G0s[ii-3] for ii in range(self.jn)]
        G04s = [d4GdA4s[ii] + 2*A*d3GdA3s[ii-1] + 3*d2GdA2s[ii-1] \
                + 1.5*A**2*d2GdA2s[ii-2] + 3*A*dGdAs[ii-2] + 0.75*G0s[ii-2]\
                + 0.5*A**3*dGdAs[ii-3] + 0.75*A**2*G0s[ii-3]\
                + A**4/16. * G0s[ii-4] for ii in range(self.jn)]

        G10s = [ dGdCs[ii] + 0.5*C*G0s[ii-1]  for ii in range(self.jn)]

        G11s = [ d2GdCdAs[ii] + 0.5*C*dGdAs[ii-1] + 0.5*A*dGdCs[ii-1] + 0.25*A*C*G0s[ii-2] for ii in range(self.jn)]
        G20s = [-(d2GdC2s[ii] + C * dGdCs[ii-1] + 0.5*G0s[ii-1] + 0.25 * C**2 *G0s[ii-2]) for ii in range(self.jn)]
        G12s = [-(d3GdCdA2s[ii] + 0.5*C*d2GdA2s[ii-1] + A*d2GdCdAs[ii-1] + 0.5*dGdCs[ii-1]\
                  + 0.5*A*C*dGdAs[ii-2] + 0.25*A**2*dGdCs[ii-2] + 0.25*C*G0s[ii-2] + A**2*C/8*G0s[ii-3])  for ii in range(self.jn)]

        ret = 0
        bias_integrands = np.zeros( (self.num_power_components,self.N)  )

        for l in range(self.jn):

            mu0 = G0s[l]
            nq1 = self.Anu * G01s[l] + self.Bnu * G10s[l]
            mu_nq1 = self.Anu * G02s[l] + self.Bnu * G11s[l]
            # (Anu*a + Bnu*b)^2 expanded: the G20 term carries Bnu**2 (previously Bnu**4).
            nq2 = self.Anu**2 * G02s[l] + 2 * self.Anu * self.Bnu * G11s[l] + self.Bnu**2 * G20s[l]
            mu1 = G01s[l]
            mu2 = G02s[l]
            mu3 = G03s[l]
            mu2_nq1 = self.Anu * G03s[l] + self.Bnu * G12s[l]
            mu4 = G04s[l]

            bias_integrands[0,:] = 1 * G0s[l] - 0.5 * Ksq * (self.Xlin_gt * G0s[l] + self.Ylin_gt * mu2) # za

            bias_integrands[0,:] += -0.5 * ksq * ( 2*(Kfac**2 + 2*f*(1+f)*nu**2) * G0s[l] * self.X13 +\
                                                  2*(Kfac**2*mu2 + 2*f*Kfac*nu*mu_nq1) * self.Y13 +\
                                                   (Kfac**2 + 2*f*(1+f)*nu**2 + f**2*nu**2) * G0s[l] * self.X22 +\
                                                   (Kfac**2*mu2 + 2*f*Kfac*nu*mu_nq1 + f**2*nu**2*nq2) * self.Y22)\
                                 + Ksq**2 / 8. * (self.Xlin_gt**2 * G0s[l] + 2*self.Xlin_gt*self.Ylin_gt*mu2 + self.Ylin_gt**2 * mu4)# Aloop

            bias_integrands[0,:] += 0.5*k**3 * ( 2*Kfac*(Kfac**2+f*(1+f)*nu**2) * G01s[l] * self.V1 +  \
                                                Kfac**2 * (Kfac*G01s[l] + f*nu*nq1) * self.V3 + \
                                                Kfac**2 * (Kfac*G03s[l] + f*nu*mu2_nq1) * self.T)

            bias_integrands[1,:] = -2 * K * (self.Ulin + self.U3) * mu1 - Ksq * (self.X10 * mu0 + self.Y10 * mu2 ) \
                                   -4*f*k*nu*self.U3*nq1 - f*ksq*nu*(self.X10 * Knfac * mu0 + Kfac * self.Y10 * mu_nq1)\
                                   -2 * K * self.Ulin * ( -0.5*Ksq*(self.Xlin_gt*mu1 + self.Ylin_gt*mu3) )

            # The bracket multiplying U11 must be dimensionless, as for U20 and Us2 below
            # (previously it had a stray factor of k on the f*nu*nq1 term).
            bias_integrands[2,:] = self.corlin * (mu0 - 0.5*Ksq*(self.Xlin_gt*mu0 + self.Ylin_gt*mu2) )\
                                   - Ksq*self.Ulin**2*mu2 - k*(Kfac*mu1 + f*nu*nq1)*self.U11

            bias_integrands[3,:] = - Ksq * self.Ulin**2 * mu2 - k*(Kfac*mu1 + f*nu*nq1)*self.U20 # b2
            bias_integrands[4,:] = -2 * K * self.Ulin * self.corlin * mu1 # b1b2
            bias_integrands[5,:] = 0.5 * self.corlin**2 * mu0 # b2sq

            if self.shear or self.third_order:
                bias_integrands[6,:] = - Ksq * (self.Xs2 * mu0 + self.Ys2 * mu2) - 2*k*(Kfac*mu1 + f*nu*nq1)*self.Us2 # bs should be both minus
                bias_integrands[7,:] = -2*K*self.V * mu1 # b1bs
                bias_integrands[8,:] = self.chi * mu0 # b2bs
                bias_integrands[9,:] = self.zeta * mu0 # bssq

            if self.third_order:
                bias_integrands[10,:] = -2 * K * self.Ub3 * mu1 #b3
                bias_integrands[11,:] = 2 * self.theta * mu0 #b1 b3

            bias_integrands[-1,:] = 1 * G0s[l] - 0.5 * Ksq * (self.Xlin_gt * G0s[l] + self.Ylin_gt * mu2) # za

            # sum up bias terms, treating counterterms separately
            bias_integrand  = np.sum( bias_monomials[:,None]*bias_integrands[:-1,:],axis=0 )
            bias_integrand += k**2 * (alpha0 + alpha2*nu**2 + alpha4*nu**4 + alpha6*nu**6) * bias_integrands[-1,:]

            # multiply by IR exponent; for l = 0 also subtract the large-q value
            bias_integrand = bias_integrand * expon * (-2./k/self.qint)**l
            if l == 0:
                bias_integrand -= bias_integrand[-1]

            # do FFTLog
            ktemps, bias_fft = self.sph1.sph(l, bias_integrand)
            ret += interp1d(ktemps, bias_fft)(k)

        return 4*suppress*np.pi*ret + sn + k**2 * nu**2 * sn2 + k**4 * nu**4 * sn4
        
    def make_pknu_fixedbias(self, f, nu, bvec, kv = None, kmin = 1e-2, kmax = 0.25, nk = 50,nmax=5):
        '''
        Tabulate P(k, nu) at one angle nu for the fixed bias parameters bvec.

        setup_rsd_facs(f, nu, nmax=nmax) is called first. If kv is None, nk
        log-spaced wavenumbers in [kmin, kmax] are used.

        Returns (kv, pknu), where pknu[i] = p_integral_fixedbias(kv[i], bvec).
        '''
        self.setup_rsd_facs(f,nu,nmax=nmax)

        ks = np.logspace(np.log10(kmin), np.log10(kmax), nk) if kv is None else kv

        # one P(k, nu) value per wavenumber (no k column)
        pknu = np.zeros(len(ks))
        for idx, kk in enumerate(ks):
            pknu[idx] = self.p_integral_fixedbias(kk, bvec, nmax=nmax)

        return ks, pknu
        
    def make_pell_fixedbias(self, f, bvec, apar = 1, aperp = 1, ngauss=4, kv = None, kmin = 1e-2, kmax = 0.25, nk = 50,nmax=5):
        '''
        Compute the power-spectrum multipoles p0, p2, p4 for fixed bias parameters,
        including Alcock-Paczynski distortions.

        P(k, nu) is evaluated with p_integral_fixedbias at Gauss-Legendre nodes
        in nu and projected onto Legendre polynomials.

        Parameters
        ----------
        f : float
            Growth rate, passed to setup_rsd_facs.
        bvec : sequence of 11 floats
            b1, b2, bs, b3, alpha0, alpha2, alpha4, alpha6, sn, sn2, sn4.
        apar, aperp : float
            AP parameters parallel / perpendicular to the line of sight.
        ngauss : int
            Number of nodes evaluated; 2*ngauss nodes are used in total, the other
            half being filled by symmetry in nu.
        kv : array-like, optional
            Observed wavenumbers. If None, nk log-spaced points in [kmin, kmax].
        nmax : int
            Passed to both setup_rsd_facs and p_integral_fixedbias.

        Returns kv, p0k, p2k, p4k.
        '''

        nus, ws = np.polynomial.legendre.leggauss(2*ngauss)
        nus_calc = nus[0:ngauss]

        L0 = np.polynomial.legendre.Legendre((1))(nus)
        L2 = np.polynomial.legendre.Legendre((0,0,1))(nus)
        L4 = np.polynomial.legendre.Legendre((0,0,0,0,1))(nus)

        if kv is None:
            kv = np.logspace(np.log10(kmin), np.log10(kmax), nk)
        else:
            nk = len(kv)

        # one P(k, nu) value per (node, k); counterterms are already included
        pknutable = np.zeros((len(nus),nk))

        # To implement AP:
        # Calculate P(k,nu) at the true coordinates, given by
        # k_true = k_apfac * kobs
        # nu_true = nu * a_perp/a_par/fac
        # Note that the integration grid on the other hand is never observed.
        # The volume factor does not depend on nu, so compute it once.
        vol_fac = apar * aperp**2

        for ii, nu in enumerate(nus_calc):

            fac = np.sqrt(1 + nu**2 * ((aperp/apar)**2-1))
            k_apfac = fac / aperp
            nu_true = nu * aperp/apar/fac

            # Keep the RSD setup consistent with the nmax used in the integrals,
            # as make_pknu_fixedbias does.
            self.setup_rsd_facs(f,nu_true,nmax=nmax)

            for jj, k in enumerate(kv):
                pknutable[ii,jj] = self.p_integral_fixedbias(k_apfac*k, bvec, nmax=nmax)

        # Gauss-Legendre nodes are symmetric, so reversing the first half fills the second.
        pknutable[ngauss:,:] = np.flip(pknutable[0:ngauss],axis=0)

        p0k = 0.5 * np.sum((ws*L0)[:,None]*pknutable,axis=0) / vol_fac
        p2k = 2.5 * np.sum((ws*L2)[:,None]*pknutable,axis=0) / vol_fac
        p4k = 4.5 * np.sum((ws*L4)[:,None]*pknutable,axis=0) / vol_fac

        return kv, p0k, p2k, p4k
        
    def make_xiell_fixedbias(self, f, bvec, apar = 1, aperp = 1, ngauss=4, kmin = 1e-3, kmax = 0.8, nk = 100, nmax=5):
        '''
        Correlation-function multipoles for fixed bias parameters.

        The power-spectrum multipoles from make_pell_fixedbias are log-interpolated
        onto self.kint, damped by exp(-(k/10)^2) and transformed with self.sphr.
        The sign of xi2 is flipped after the transform.

        Returns (s0, xi0), (s2, xi2), (s4, xi4).
        '''
        kv, *pells = self.make_pell_fixedbias(f, bvec, apar=apar, aperp=aperp, ngauss=ngauss, kmin = kmin, kmax= kmax, nk = nk, nmax=nmax)

        damping = np.exp(-(self.kint/10)**2)

        results = []
        for ell, pell in zip((0, 2, 4), pells):
            pint = loginterp(kv, pell)(self.kint) * damping
            ss, xi = self.sphr.sph(ell, pint)
            if ell == 2:
                xi *= -1
            results.append((ss, xi))

        return tuple(results)
class KEVelocityMoments(KECLEFT):
    '''
    Class based on cleft_kexpanded_fftw to compute pairwise velocity moments, in expanded LPT.
    
    Structured in the same way as the inherited class but with functions for velocity moments.
    
    '''
    def __init__(self, *args, beyond_gauss=True, **kw):
        '''
        Same keywords as the cleft_kexpanded_fftw class. Go look there!

        beyond_gauss : bool
            If True, also set up the transforms and spectra for the third (gamma)
            and fourth (kappa) moments.
        '''

        # Set up the configuration space quantities
        KECLEFT.__init__(self, *args, **kw)

        self.setup_onedot()
        self.setup_twodots()
        self.beyond_gauss = beyond_gauss

        # v12 and sigma12 only have a subset of the bias contributions, so fewer FFT
        # columns are needed. The index lists refer to the bias basis
        # (1, b1, b1^2, b2, b1b2, b2^2, bs, b1bs, b2bs, bs^2, b3, b1b3) and are
        # offset by one (presumably for the k column of the output tables).
        if self.third_order:
            vel_bias = [0, 1, 2, 3, 4, 6, 7, 10]
            sig_bias = [0, 1, 2, 3, 6]
        elif self.shear:
            vel_bias = [0, 1, 2, 3, 4, 6, 7]
            sig_bias = [0, 1, 2, 3, 6]
        else:
            vel_bias = [0, 1, 2, 3, 4]
            sig_bias = [0, 1, 2, 3]

        self.num_vel_components = len(vel_bias)
        self.vii = np.array(vel_bias) + 1
        self.num_spar_components = len(sig_bias)
        self.sparii = np.array(sig_bias) + 1
        self.num_strace_components = len(sig_bias)
        self.straceii = np.array(sig_bias) + 1

        def make_transform(ncol):
            # Spherical Bessel transform on the q grid with ncol columns.
            return SphericalBesselTransform(self.qint,
                                            L=self.jn,
                                            ncol=ncol,
                                            threads=self.threads,
                                            import_wisdom=self.import_wisdom,
                                            wisdom_file=self.wisdom_file)

        self.sph_v = make_transform(self.num_vel_components)
        self.sph_spar = make_transform(self.num_spar_components)
        self.sph_strace = make_transform(self.num_strace_components)

        if self.beyond_gauss:
            # Third moment: matter (all loop terms lumped into 0) and b1.
            self.num_gamma_components = 2
            self.gii = np.array([0, 1]) + 1
            self.sph_gamma1 = make_transform(self.num_gamma_components)
            self.sph_gamma2 = make_transform(self.num_gamma_components)

            # Fourth moment; note that these indices are not the bias components.
            self.num_kappa_components = 3
            self.kii = np.array([0, 1, 2]) + 1
            self.sph_kappa = make_transform(self.num_kappa_components)

        self.compute_oneloop_spectra()

    def make_tables(self, kmin=1e-3, kmax=3, nk=100, linear_theory=False):
        '''
        Build the power-spectrum and velocity-moment tables on nk log-spaced
        wavenumbers in [kmin, kmax] (stored in self.kv).

        The gamma and kappa tables are only built when self.beyond_gauss is set.
        linear_theory is accepted but not used by this method.
        '''
        self.kv = np.logspace(np.log10(kmin), np.log10(kmax), nk)
        grid = dict(kmin=kmin, kmax=kmax, nk=nk)

        self.make_ptable(**grid)
        self.make_vtable(**grid)
        self.make_spartable(**grid)
        self.make_stracetable(**grid)
        self.convert_sigma_bases()

        if not self.beyond_gauss:
            return

        # make these even if not required since they're fast
        self.make_gamma1table(**grid)
        self.make_gamma2table(**grid)
        self.convert_gamma_bases()

        self.make_kappatable(**grid)
        self.convert_kappa_bases()

    def compute_oneloop_spectra(self):
        '''
        Compute all velocity spectra nonzero at one loop.

        Runs the compute_* methods for the power spectrum (p), velocity (v),
        spar and strace pieces in order. When self.beyond_gauss is set, it then
        also runs the gamma1, gamma2 and kappa pieces.
        '''
        gaussian_stages = (
            'p_linear', 'p_connected', 'p_k0', 'p_k1', 'p_k2', 'p_k3', 'p_k4',
            'v_linear', 'v_connected', 'v_k0', 'v_k1', 'v_k2', 'v_k3',
            'spar_linear', 'spar_connected', 'spar_k0', 'spar_k1', 'spar_k2',
            'strace_linear', 'strace_connected', 'strace_k0', 'strace_k1', 'strace_k2',
        )
        for stage in gaussian_stages:
            getattr(self, 'compute_' + stage)()

        if self.beyond_gauss:
            higher_stages = (
                'gamma1_connected', 'gamma1_k0', 'gamma1_k1',
                'gamma2_connected', 'gamma2_k0', 'gamma2_k1',
                'kappa_k0',
            )
            for stage in higher_stages:
                getattr(self, 'compute_' + stage)()

    def update_power_spectrum(self, k, p):
        '''
        Update the linear power spectrum as in the parent class, then rebuild the
        one-, two- and three-dot quantities used by the velocity moments.
        '''
        super().update_power_spectrum(k, p)
        for rebuild in (self.setup_onedot, self.setup_twodots, self.setup_threedots):
            rebuild()

    def setup_onedot(self):
        '''
        Create quantities linear in f. All quantities are with f = 1, since converting back is trivial.

        Each "dot" attribute is a fixed multiple of an already-computed correlator
        (e.g. Vdot = 4/3 Vloop, U11dot = 2 U11). Where the multiple is 1 the
        attribute is the same object as the original (e.g. Xdot is Xlin), not a copy.
        The sigma*dot attributes are the last entries of the corresponding X arrays.
        When one_loop is False, Xloopdot, Yloopdot, X10dot, Y10dot and their sigmas
        are set to scalar 0. Shear and third-order terms are only set when those
        options are enabled.
        '''
        self.Xdot = self.Xlin
        self.sigmadot = self.Xdot[-1]
        self.Ydot = self.Ylin

        self.Vdot = 4. / 3 * self.Vloop  #these are only the symmetrized version since all we need...
        self.Tdot = 4. / 3 * self.Tloop  # is k_i k_j k_k W_{ijk}

        self.Udot = self.Ulin
        self.Uloopdot = 3 * self.U3

        self.U11dot = 2 * self.U11
        self.U20dot = 2 * self.U20

        # some one loop terms have to be explicitly set to zero
        if self.one_loop:
            # NOTE(review): inside this branch the trailing `* self.one_loop` is a
            # factor of 1 for a boolean flag.
            self.Xloopdot = (4 * self.qf.Xloop13 +
                             2 * self.qf.Xloop22) * self.one_loop
            self.sigmaloopdot = self.Xloopdot[-1]
            self.Yloopdot = (4 * self.qf.Yloop13 +
                             2 * self.qf.Yloop22) * self.one_loop
            self.X10dot = 1.5 * self.X10
            self.sigma10dot = self.X10dot[-1]
            self.Y10dot = 1.5 * self.Y10
        else:
            self.Xloopdot = 0
            self.sigmaloopdot = 0
            self.Yloopdot = 0
            self.X10dot = 0
            self.sigma10dot = 0
            self.Y10dot = 0

        if self.shear or self.third_order:
            self.Us2dot = 2 * self.Us2
            self.V12dot = self.V
            self.Xs2dot = self.Xs2
            self.sigmas2dot = self.Xs2dot[-1]
            self.Ys2dot = self.Ys2

        if self.third_order:
            self.Ub3dot = self.Ub3

    def setup_twodots(self):
        '''
        Same as onedot but now for those quadratic in f.

        Each "ddot" attribute is a fixed multiple of an already-computed correlator
        (again with f = 1). Xddot/Yddot are the same objects as Xlin/Ylin, not
        copies. kdelta_Wddot combines the 112 loop pieces V1loop112, V3loop112 and
        Tloop112. When one_loop is False, the loop "ddot" terms are set to scalar 0.
        Shear terms are only set when shear or third_order is enabled.
        '''
        self.Xddot = self.Xlin
        self.sigmaddot = self.Xddot[-1]
        self.Yddot = self.Ylin

        # Here we will need two forms, one symmetrized:
        self.Vddot = 5. / 3 * self.Vloop  #these are only the symmetrized version since all we need...
        self.Tddot = 5. / 3 * self.Tloop  # is k_i k_j k_k W_{ijk}

        # Explicitly set certain terms to zero if not one loop
        if self.one_loop:
            # NOTE(review): inside this branch the trailing `* self.one_loop` is a
            # factor of 1 for a boolean flag.
            self.Xloopddot = (4 * self.qf.Xloop22 +
                              6 * self.qf.Xloop13) * self.one_loop
            self.sigmaloopddot = self.Xloopddot[-1]
            self.Yloopddot = (4 * self.qf.Yloop22 +
                              6 * self.qf.Yloop13) * self.one_loop

            self.X10ddot = 2 * self.X10
            self.sigma10ddot = self.X10ddot[-1]
            self.Y10ddot = 2 * self.Y10

            # and the other from k_i \delta_{jk} \ddot{W}_{ijk}
            self.kdelta_Wddot = (18 * self.qf.V1loop112 + 7 * self.qf.V3loop112
                                 + 5 * self.qf.Tloop112) * self.one_loop
        else:
            self.Xloopddot = 0
            self.sigmaloopddot = 0
            self.Yloopddot = 0
            self.X10ddot = 0
            self.sigma10ddot = 0
            self.Y10ddot = 0
            self.kdelta_Wddot = 0

        if self.shear or self.third_order:
            self.Xs2ddot = self.Xs2
            self.sigmas2ddot = self.Xs2ddot[-1]
            self.Ys2ddot = self.Ys2

    def setup_threedots(self):
        '''
        Set the V and T loop terms cubic in f (with f = 1): both are twice the
        corresponding Vloop / Tloop.
        '''
        cubic_factor = 2
        self.Vdddot = cubic_factor * self.Vloop
        self.Tdddot = cubic_factor * self.Tloop

    def compute_v_linear(self):
        '''
        Fill self.v_linear: the matter (row 0) and b1 (row 1) columns both equal
        -2 * pint / kint; all other columns are zero.
        '''
        linear_term = (-2 * self.pint) / self.kint
        table = np.zeros((self.num_vel_components, self.N))
        table[0, :] = linear_term
        table[1, :] = linear_term
        self.v_linear = table

    def compute_v_connected(self):
        '''
        Fill self.v_connected with the connected contributions to v12, built from
        the self.qf Q/R kernels and divided by kint.

        Rows follow the velocity bias ordering in self.vii: 0 -> 1, 1 -> b1,
        2 -> b1^2, 3 -> b2, 5 -> bs (shear / third order only),
        7 -> b3 (third order only). Other rows stay zero.
        '''

        self.v_connected = np.zeros((self.num_vel_components, self.N))
        self.v_connected[0, :] = (
            -2 *
            (2 * 9. / 98 * self.qf.Q1 + 4 * 5. / 21 * self.qf.R1) - 12. / 7 *
            (self.qf.Q2 + 2 * self.qf.R2)) / self.kint
        self.v_connected[1, :] = (
            -3 * (12 * self.qf.R2 + 6 * self.qf.Q5 + 6 * self.qf.R1) / 7 - 3 *
            (10. / 21 * self.qf.R1)) / self.kint
        self.v_connected[2, :] = -12. / 7 * (self.qf.R1 +
                                             self.qf.R2) / self.kint
        self.v_connected[3, :] = -6. / 7 * self.qf.Q8 / self.kint

        if self.shear or self.third_order:
            # bs column
            self.v_connected[5, :] = -2 * 2 * 1. / 7 * self.qf.Qs2 / self.kint
        if self.third_order:
            # b3 column
            self.v_connected[7, :] = -2 * self.qf.Rb3 * self.pint / self.kint

    def compute_v_k0(self):
        '''
        Fill self.v_k0 on the self.kint grid.

        Only the b1b2 (row 4) and, with shear / third order, b1bs (row 6) columns
        contribute, both with a mu^1 (l = 1) angular dependence. For each l the
        integrands are shifted by their last (largest-q) value, transformed with
        self.sph_v and interpolated onto self.kint (NaN outside the transform's k range).
        '''
        self.v_k0 = np.zeros((self.num_vel_components, self.N))
        bias_integrands = np.zeros((self.num_vel_components, self.N))

        for l in range(2):
            mu1fac = (l == 1)

            bias_integrands[4, :] = mu1fac * (2 * self.corlin * self.Udot)

            if self.shear or self.third_order:
                bias_integrands[6, :] = mu1fac * (2 * self.V12dot)

            # subtract the value at the last q point
            bias_integrands -= bias_integrands[:, -1][:, None]

            # do FFTLog
            ktemps, bias_ffts = self.sph_v.sph(l, bias_integrands)
            self.v_k0 += 4 * np.pi * interp1d(
                ktemps, bias_ffts, bounds_error=False)(self.kint)

    def compute_v_k1(self):
        '''
        Fill self.v_k1 on the self.kint grid.

        Contributions come from the b1^2 (row 2), b2 (row 3) and, with shear /
        third order, bs (row 5) columns, with mu^0 and mu^2 angular dependence
        (so only l = 0 and l = 2 are transformed). Each integrand is shifted by its
        last (largest-q) value before the transform with self.sph_v.
        '''
        self.v_k1 = np.zeros((self.num_vel_components, self.N))
        bias_integrands = np.zeros((self.num_vel_components, self.N))

        for l in [0, 2]:
            mu0fac = (l == 0)
            mu2fac = 1. / 3 * (l == 0) - 2. / 3 * (l == 2)

            bias_integrands[2,:] = mu0fac * ( self.corlin*self.Xdot ) + \
                                   mu2fac * (2*self.Ulin*self.Udot + self.corlin*self.Ydot)
            bias_integrands[3, :] = mu2fac * (2 * self.Ulin * self.Udot)

            if self.shear or self.third_order:
                bias_integrands[5, :] = mu0fac * (2 * self.Xs2dot) + mu2fac * (
                    2 * self.Ys2dot)

            # subtract the value at the last q point
            bias_integrands -= bias_integrands[:, -1][:, None]

            # do FFTLog
            ktemps, bias_ffts = self.sph_v.sph(l, bias_integrands)
            self.v_k1 += 4 * np.pi * interp1d(
                ktemps, bias_ffts, bounds_error=False)(self.kint)

    def compute_v_k2(self):
        '''
        Fill self.v_k2 on the self.kint grid.

        Only the b1 column (row 1) contributes, with mu^1 and mu^3 angular
        dependence. All l from 0 to 3 are transformed, as before; even l give zero
        integrands. Each integrand is shifted by its last (largest-q) value before
        the transform with self.sph_v.
        '''
        self.v_k2 = np.zeros((self.num_vel_components, self.N))
        bias_integrands = np.zeros((self.num_vel_components, self.N))

        for l in range(4):
            mu1fac = (l == 1)
            mu3fac = 0.6 * (l == 1) - 0.4 * (l == 3)

            bias_integrands[1,:] = mu1fac * ( -2*self.Ulin*self.Xdot - self.Udot*self.Xlin ) + \
                                   mu3fac * ( -2*self.Ulin*self.Ydot - self.Udot*self.Ylin )

            # subtract the value at the last q point
            bias_integrands -= bias_integrands[:, -1][:, None]

            # do FFTLog
            ktemps, bias_ffts = self.sph_v.sph(l, bias_integrands)
            self.v_k2 += 4 * np.pi * interp1d(
                ktemps, bias_ffts, bounds_error=False)(self.kint)

    def compute_v_k3(self):
        '''
        Fill self.v_k3 on the self.kint grid.

        Only the matter column (row 0) contributes, with mu^0, mu^2 and mu^4
        angular dependence; all l in range(self.jn) are transformed. Each integrand
        is shifted by its last (largest-q) value before the transform with self.sph_v.
        '''
        self.v_k3 = np.zeros((self.num_vel_components, self.N))
        bias_integrands = np.zeros((self.num_vel_components, self.N))

        for l in range(self.jn):
            mu0fac = (l == 0)
            mu2fac = 1. / 3 * (l == 0) - 2. / 3 * (l == 2)
            mu4fac = 0.2 * (l == 0) - 4. / 7 * (l == 2) + 8. / 35 * (l == 4)

            bias_integrands[0,:] = mu0fac * ( - 0.5*self.Xdot*self.Xlin ) + \
                                   mu2fac * ( - 0.5*(self.Xdot*self.Ylin+self.Ydot*self.Xlin) ) + \
                                   mu4fac * ( - 0.5*self.Ylin*self.Ydot )

            # subtract the value at the last q point
            bias_integrands -= bias_integrands[:, -1][:, None]

            # do FFTLog
            ktemps, bias_ffts = self.sph_v.sph(l, bias_integrands)
            self.v_k3 += 4 * np.pi * interp1d(
                ktemps, bias_ffts, bounds_error=False)(self.kint)

    def compute_spar_linear(self):
        '''
        Fill self.spar_linear: only the matter column (row 0) is nonzero, equal
        to -2 * pint / kint**2.
        '''
        table = np.zeros((self.num_spar_components, self.N))
        table[0, :] = (-2 * self.pint) / self.kint**2
        self.spar_linear = table

    def compute_spar_connected(self):
        '''
        Fill self.spar_connected with the connected contributions built from the
        self.qf Q/R kernels and divided by kint**2.

        Only the matter (row 0) and b1 (row 1) columns are nonzero.
        '''
        self.spar_connected = np.zeros((self.num_spar_components, self.N))
        self.spar_connected[0, :] = (
            -2 * (4 * 9. / 98 * self.qf.Q1 + 6 * 5. / 21 * self.qf.R1) - 6 *
            (5. / 3 * 3. / 7 * (self.qf.Q2 + 2 * self.qf.R2))) / self.kint**2
        self.spar_connected[
            1, :] = (-2 * 2 * (12. / 7 * self.qf.R2 + 6. / 7 * self.qf.Q5 +
                               6. / 7 * self.qf.R1)) / self.kint**2

    def compute_spar_k0(self):
        '''
        Fill self.spar_k0 on the self.kint grid.

        Contributions come from the b1^2 (row 2), b2 (row 3) and, with shear /
        third order, bs (row 4) columns, with mu^0 and mu^2 angular dependence.
        l = 0, 1, 2 are transformed, as before. Each integrand is shifted by its
        last (largest-q) value before the transform with self.sph_spar.
        '''
        self.spar_k0 = np.zeros((self.num_spar_components, self.N))
        bias_integrands = np.zeros((self.num_spar_components, self.N))

        for l in range(3):
            mu0fac = (l == 0)
            mu2fac = 1. / 3 * (l == 0) - 2. / 3 * (l == 2)

            bias_integrands[
                2, :] = mu0fac * (self.corlin * self.Xddot) + mu2fac * (
                    self.corlin * self.Yddot + 2 * self.Udot**2)
            bias_integrands[3, :] = mu2fac * (2 * self.Udot**2)

            if self.shear or self.third_order:
                bias_integrands[4, :] = mu0fac * (
                    2 * self.Xs2ddot) + mu2fac * (2 * self.Ys2ddot)

            # subtract the value at the last q point
            bias_integrands -= bias_integrands[:, -1][:, None]

            # do FFTLog
            ktemps, bias_ffts = self.sph_spar.sph(l, bias_integrands)
            self.spar_k0 += 4 * np.pi * interp1d(
                ktemps, bias_ffts, bounds_error=False)(self.kint)

    def compute_spar_k1(self):
        """Accumulate the part of sigma_parallel that ``make_table`` multiplies by k.

        Only row 1 is populated, with odd-ell (ell = 1, 3) weights; the loop
        runs over ell = 0..3, transforming with ``self.sph_spar`` and summing
        the 4*pi-scaled, interpolated result into ``self.spar_k1``.
        """
        ncomp, npts = self.num_spar_components, self.N
        acc = self.spar_k1 = np.zeros((ncomp, npts))
        work = np.zeros((ncomp, npts))

        odd1 = -2 * (self.Ulin * self.Xddot + 2 * self.Udot * self.Xdot)
        odd3 = -2 * (self.Ulin * self.Yddot + 2 * self.Udot * self.Ydot)

        for ell in range(4):
            w1 = (ell == 1)
            w3 = 0.6 * (ell == 1) - 0.4 * (ell == 3)

            work[1, :] = w1 * odd1 + w3 * odd3

            # remove each row's value at the last (largest) q before transforming
            work -= work[:, -1][:, None]

            kgrid, transformed = self.sph_spar.sph(ell, work)
            acc += 4 * np.pi * interp1d(
                kgrid, transformed, bounds_error=False)(self.kint)

    def compute_spar_k2(self):
        """Accumulate the part of sigma_parallel that ``make_table`` multiplies by k^2.

        Only row 0 (matter) is populated, from mu^0, mu^2 and mu^4 pieces;
        ell runs over 0..self.jn-1 and each transform (``self.sph_spar``) is
        4*pi-scaled, interpolated onto ``self.kint`` and summed into ``self.spar_k2``.
        """
        ncomp, npts = self.num_spar_components, self.N
        acc = self.spar_k2 = np.zeros((ncomp, npts))
        work = np.zeros((ncomp, npts))

        part0 = -self.Xdot**2 - 0.5*self.Xddot*self.Xlin
        part2 = - 2*self.Xdot*self.Ydot - 0.5*(self.Xddot*self.Ylin + self.Yddot*self.Xlin)
        part4 = -self.Ydot**2 - 0.5*self.Yddot*self.Ylin

        for ell in range(self.jn):
            w0 = (ell == 0)
            w2 = 1. / 3 * (ell == 0) - 2. / 3 * (ell == 2)
            w4 = 0.2 * (ell == 0) - 4. / 7 * (ell == 2) + 8. / 35 * (ell == 4)

            work[0, :] = w0 * part0 + w2 * part2 + w4 * part4

            # remove each row's value at the last (largest) q before transforming
            work -= work[:, -1][:, None]

            kgrid, transformed = self.sph_spar.sph(ell, work)
            acc += 4 * np.pi * interp1d(
                kgrid, transformed, bounds_error=False)(self.kint)

    def compute_strace_linear(self):
        """Build the linear-theory Tr(sigma) table on the ``self.kint`` grid.

        Only row 0 is non-zero and equals -2 P(k) / k^2 (P = ``self.pint``).
        The table is sized with ``num_strace_components`` so that it matches
        ``strace_k0``/``strace_k1``/``strace_k2`` and the ``straceii`` column map
        it is combined with in ``make_table`` (it previously used
        ``num_spar_components``).
        """
        self.strace_linear = np.zeros((self.num_strace_components, self.N))
        self.strace_linear[0, :] = (-2 * self.pint) / self.kint**2

    def compute_strace_connected(self):
        """Build the 'connected' one-loop part of Tr(sigma) from the Q/R kernels.

        Row 0 (matter) and row 1 (b1) are combinations of ``self.qf`` Q and R
        functions divided by k^2. The table is sized with
        ``num_strace_components`` so that it matches the other strace tables it
        is summed with in ``make_table`` (it previously used ``num_spar_components``).
        """
        self.strace_connected = np.zeros((self.num_strace_components, self.N))
        self.strace_connected[0,:] = (- 2*(4*9./98*self.qf.Q1 + 6*5./21*self.qf.R1) \
                                        + 6./7*(self.qf.Q1-5*self.qf.Q2+4*self.qf.R1-10*self.qf.R2))/self.kint**2
        self.strace_connected[1, :] = (-4 / self.kint**2 * 3. / 7 *
                                       (4 * self.qf.R2 + 2 * self.qf.Q5))

    def compute_strace_k0(self):
        """Accumulate the k-independent non-linear part of Tr(sigma).

        Every integrand carries only the ell = 0 weight, so for ell > 0 the
        rows are zero; the loop still runs over ell = 0..self.jn-1, transforming
        with ``self.sph_strace`` and summing the 4*pi-scaled, interpolated
        result (NaN outside the transform's k range) into ``self.strace_k0``.
        """
        ncomp, npts = self.num_strace_components, self.N
        acc = self.strace_k0 = np.zeros((ncomp, npts))
        work = np.zeros((ncomp, npts))

        twice_udot_sq = 2 * self.Udot**2
        row2 = self.corlin * (3 * self.Xddot + self.Yddot) + twice_udot_sq
        shear_row = None
        if self.shear or self.third_order:
            shear_row = 2 * (3 * self.Xs2ddot + self.Ys2ddot)

        for ell in range(self.jn):
            w0 = (ell == 0)

            work[2, :] = w0 * row2
            work[3, :] = w0 * twice_udot_sq
            if shear_row is not None:
                work[4, :] = w0 * shear_row

            # remove each row's value at the last (largest) q before transforming
            work -= work[:, -1][:, None]

            kgrid, transformed = self.sph_strace.sph(ell, work)
            acc += 4 * np.pi * interp1d(
                kgrid, transformed, bounds_error=False)(self.kint)

    def compute_strace_k1(self):
        """Accumulate the part of Tr(sigma) that ``make_table`` multiplies by k.

        Only row 1 is populated, with the ell = 1 weight; the loop runs over
        ell = 0..self.jn-1, transforming with ``self.sph_strace`` and summing
        the 4*pi-scaled, interpolated result into ``self.strace_k1``.
        """
        ncomp, npts = self.num_strace_components, self.N
        acc = self.strace_k1 = np.zeros((ncomp, npts))
        work = np.zeros((ncomp, npts))

        row1 = (-2 * self.Ulin * (3 * self.Xddot + self.Yddot) -
                4 * self.Udot * (self.Xdot + self.Ydot))

        for ell in range(self.jn):
            work[1, :] = (ell == 1) * row1

            # remove each row's value at the last (largest) q before transforming
            work -= work[:, -1][:, None]

            kgrid, transformed = self.sph_strace.sph(ell, work)
            acc += 4 * np.pi * interp1d(
                kgrid, transformed, bounds_error=False)(self.kint)

    def compute_strace_k2(self):
        """Accumulate the part of Tr(sigma) that ``make_table`` multiplies by k^2.

        Only row 0 (matter) is populated, from mu^0 and mu^2 pieces; ell runs
        over 0..self.jn-1 and each ``self.sph_strace`` transform is 4*pi-scaled,
        interpolated onto ``self.kint`` and summed into ``self.strace_k2``.
        """
        ncomp, npts = self.num_strace_components, self.N
        acc = self.strace_k2 = np.zeros((ncomp, npts))
        work = np.zeros((ncomp, npts))

        trace_ddot = 3*self.Xddot + self.Yddot
        part0 = trace_ddot*(- 0.5*self.Xlin) - self.Xdot**2
        part2 = trace_ddot*(- 0.5*self.Ylin) - (self.Ydot**2+2*self.Xdot*self.Ydot)

        for ell in range(self.jn):
            w0 = (ell == 0)
            w2 = 1. / 3 * (ell == 0) - 2. / 3 * (ell == 2)

            work[0, :] = w0 * part0 + w2 * part2

            # remove each row's value at the last (largest) q before transforming
            work -= work[:, -1][:, None]

            kgrid, transformed = self.sph_strace.sph(ell, work)
            acc += 4 * np.pi * interp1d(
                kgrid, transformed, bounds_error=False)(self.kint)

    def compute_gamma1_connected(self):
        """Build the 'connected' one-loop part of gamma1: row 0 is
        36/7 (Q2 + 2 R2) / k^3, stored in ``self.gamma1_connected``."""
        qf = self.qf
        table = np.zeros((self.num_gamma_components, self.N))
        table[0, :] = 36. / 7 * (qf.Q2 + 2 * qf.R2) / self.kint**3
        self.gamma1_connected = table

    def compute_gamma1_k0(self):
        """Accumulate the k-independent non-linear part of gamma1.

        Only row 1 is populated, from odd (mu^1, mu^3) pieces; ell runs over
        0..self.jn-1 and each ``self.sph_gamma1`` transform is 4*pi-scaled,
        interpolated onto ``self.kint`` and summed into ``self.gamma1_k0``.
        """
        ncomp, npts = self.num_gamma_components, self.N
        acc = self.gamma1_k0 = np.zeros((ncomp, npts))
        work = np.zeros((ncomp, npts))

        odd1 = 6 * self.Udot * self.Xddot
        odd3 = 6 * self.Udot * self.Yddot

        for ell in range(self.jn):
            w1 = (ell == 1)
            w3 = 0.6 * (ell == 1) - 0.4 * (ell == 3)

            work[1, :] = w1 * odd1 + w3 * odd3

            # remove each row's value at the last (largest) q before transforming
            work -= work[:, -1][:, None]

            kgrid, transformed = self.sph_gamma1.sph(ell, work)
            acc += 4 * np.pi * interp1d(
                kgrid, transformed, bounds_error=False)(self.kint)

    def compute_gamma1_k1(self):
        """Accumulate the part of gamma1 that ``make_table`` multiplies by k.

        Only row 0 (matter) is populated, from mu^0, mu^2 and mu^4 pieces;
        ell runs over 0..self.jn-1 and each ``self.sph_gamma1`` transform is
        4*pi-scaled, interpolated onto ``self.kint`` and summed into ``self.gamma1_k1``.
        """
        ncomp, npts = self.num_gamma_components, self.N
        acc = self.gamma1_k1 = np.zeros((ncomp, npts))
        work = np.zeros((ncomp, npts))

        part0 = 3*self.Xdot*self.Xddot
        part2 = 3*(self.Xdot*self.Yddot+self.Ydot*self.Xddot)
        part4 = 3*self.Ydot*self.Yddot

        for ell in range(self.jn):
            w0 = (ell == 0)
            w2 = 1. / 3 * (ell == 0) - 2. / 3 * (ell == 2)
            w4 = 0.2 * (ell == 0) - 4. / 7 * (ell == 2) + 8. / 35 * (ell == 4)

            work[0, :] = w0 * part0 + w2 * part2 + w4 * part4

            # remove each row's value at the last (largest) q before transforming
            work -= work[:, -1][:, None]

            kgrid, transformed = self.sph_gamma1.sph(ell, work)
            acc += 4 * np.pi * interp1d(
                kgrid, transformed, bounds_error=False)(self.kint)

    def compute_gamma2_connected(self):
        """Build the 'connected' one-loop part of gamma2: row 0 is
        -12/7 (2 R1 - 6 R2 + Q1 - 3 Q2) / k^3, stored in ``self.gamma2_connected``."""
        qf = self.qf
        kernel = 2 * qf.R1 - 6 * qf.R2 + qf.Q1 - 3 * qf.Q2
        table = np.zeros((self.num_gamma_components, self.N))
        table[0, :] = -12. / 7 * kernel / self.kint**3
        self.gamma2_connected = table

    def compute_gamma2_k0(self):
        """Accumulate the k-independent non-linear part of gamma2.

        Only row 1 is populated, with the ell = 1 weight; ell runs over
        0..self.jn-1 and each ``self.sph_gamma2`` transform is 4*pi-scaled,
        interpolated onto ``self.kint`` and summed into ``self.gamma2_k0``.
        """
        ncomp, npts = self.num_gamma_components, self.N
        acc = self.gamma2_k0 = np.zeros((ncomp, npts))
        work = np.zeros((ncomp, npts))

        row1 = 2 * self.Udot * (5 * self.Xddot + 3 * self.Yddot)

        for ell in range(self.jn):
            work[1, :] = (ell == 1) * row1

            # remove each row's value at the last (largest) q before transforming
            work -= work[:, -1][:, None]

            kgrid, transformed = self.sph_gamma2.sph(ell, work)
            acc += 4 * np.pi * interp1d(
                kgrid, transformed, bounds_error=False)(self.kint)

    def compute_gamma2_k1(self):
        """Accumulate the part of gamma2 that ``make_table`` multiplies by k.

        Only row 0 (matter) is populated, from mu^0 and mu^2 pieces; ell runs
        over 0..self.jn-1 and each ``self.sph_gamma2`` transform is 4*pi-scaled,
        interpolated onto ``self.kint`` and summed into ``self.gamma2_k1``.
        """
        ncomp, npts = self.num_gamma_components, self.N
        acc = self.gamma2_k1 = np.zeros((ncomp, npts))
        work = np.zeros((ncomp, npts))

        part0 = (5*self.Xdot*self.Xddot+self.Xdot*self.Yddot)
        part2 = (2*self.Xdot*self.Yddot+self.Ydot*(5*self.Xddot+3*self.Yddot))

        for ell in range(self.jn):
            w0 = (ell == 0)
            w2 = 1. / 3 * (ell == 0) - 2. / 3 * (ell == 2)

            work[0, :] = w0 * part0 + w2 * part2

            # remove each row's value at the last (largest) q before transforming
            work -= work[:, -1][:, None]

            kgrid, transformed = self.sph_gamma2.sph(ell, work)
            acc += 4 * np.pi * interp1d(
                kgrid, transformed, bounds_error=False)(self.kint)

    def compute_kappa_k0(self):
        """Accumulate the three kappa contractions (rows 0, 1, 2).

        The rows are built from Xddot/Yddot with mu^0, mu^2 and mu^4 pieces;
        ell runs over 0..self.jn-1 and each ``self.sph_kappa`` transform is
        4*pi-scaled, interpolated onto ``self.kint`` and summed into ``self.kappa_k0``.
        """
        ncomp, npts = self.num_kappa_components, self.N
        acc = self.kappa_k0 = np.zeros((ncomp, npts))
        work = np.zeros((ncomp, npts))

        xx = self.Xddot**2
        xy = self.Xddot * self.Yddot
        yy = self.Yddot**2
        row0 = 15 * xx + 10 * xy + 3 * yy
        row1_w0 = 5 * xx + xy
        row1_w2 = 7 * xy + 3 * yy
        row2_w0, row2_w2, row2_w4 = 3 * xx, 6 * xy, 3 * yy

        for ell in range(self.jn):
            w0 = (ell == 0)
            w2 = 1. / 3 * (ell == 0) - 2. / 3 * (ell == 2)
            w4 = 0.2 * (ell == 0) - 4. / 7 * (ell == 2) + 8. / 35 * (ell == 4)

            work[0, :] = w0 * row0
            work[1, :] = w0 * row1_w0 + w2 * row1_w2
            work[2, :] = w0 * row2_w0 + w2 * row2_w2 + w4 * row2_w4

            # remove each row's value at the last (largest) q before transforming
            work -= work[:, -1][:, None]

            kgrid, transformed = self.sph_kappa.sph(ell, work)
            acc += 4 * np.pi * interp1d(
                kgrid, transformed, bounds_error=False)(self.kint)

    def make_table(self,
                   kmin=1e-3,
                   kmax=3,
                   nk=100,
                   func_name='power',
                   linear_theory=False):
        '''
            Make a table of different terms of P(k), v(k), sigma(k), gamma(k)
            or kappa(k) between a given 'kmin', 'kmax' and for 'nk' equally
            spaced values in log10 of k.

            func_name : one of 'power', 'velocity', 'spar', 'strace',
                'gamma1', 'gamma2', 'kappa'.
            linear_theory : if True only the linear pieces are used. gamma1,
                gamma2 and kappa have no linear piece, so their table is
                returned with only the k column filled in.

            Returns an (nk, num_power_components + 1) array whose column 0 is k.
            The rows of the summed component table are interpolated from
            self.kint into the columns listed by the matching index array
            (all columns for 'power', else self.vii / sparii / straceii /
            gii / kii); columns not listed stay zero.

            Raises ValueError for an unknown func_name.
        '''
        known = ('power', 'velocity', 'spar', 'strace', 'gamma1', 'gamma2',
                 'kappa')
        if func_name not in known:
            # previously this fell through and failed with UnboundLocalError
            raise ValueError(f"Unknown func_name {func_name!r}; expected one "
                             f"of {', '.join(known)}")

        # one column for ks; for 'power' the last column is the counterterm
        pktable = np.zeros([nk, self.num_power_components + 1])
        kv = np.logspace(np.log10(kmin), np.log10(kmax), nk)
        pktable[:, 0] = kv[:]

        # Each component is (power of k it is multiplied by, table on self.kint).
        if not linear_theory:
            if func_name == 'power':
                components = [(1, self.p_linear + self.p_connected + self.p_k0),
                              (self.kint, self.p_k1),
                              (self.kint**2, self.p_k2),
                              (self.kint**3, self.p_k3),
                              (self.kint**4, self.p_k4)]
                iis = np.arange(1, 1 + self.num_power_components)
            elif func_name == 'velocity':
                components = [
                    (1, self.v_linear + self.v_connected + self.v_k0),
                    (self.kint, self.v_k1), (self.kint**2, self.v_k2),
                    (self.kint**3, self.v_k3)
                ]
                iis = self.vii
            elif func_name == 'spar':
                components = [
                    (1, self.spar_linear + self.spar_connected + self.spar_k0),
                    (self.kint, self.spar_k1), (self.kint**2, self.spar_k2)
                ]
                iis = self.sparii
            elif func_name == 'strace':
                components = [(1, self.strace_linear + self.strace_connected +
                               self.strace_k0), (self.kint, self.strace_k1),
                              (self.kint**2, self.strace_k2)]
                iis = self.straceii
            elif func_name == 'gamma1':
                components = [(1, self.gamma1_connected + self.gamma1_k0),
                              (self.kint, self.gamma1_k1)]
                iis = self.gii
            elif func_name == 'gamma2':
                components = [(1, self.gamma2_connected + self.gamma2_k0),
                              (self.kint, self.gamma2_k1)]
                iis = self.gii
            else:  # 'kappa'
                components = [(1, self.kappa_k0)]
                iis = self.kii
        else:
            if func_name == 'power':
                components = [(1, self.p_linear)]
                iis = np.arange(1, 1 + self.num_power_components)
            elif func_name == 'velocity':
                components = [(1, self.v_linear)]
                iis = self.vii
            elif func_name == 'spar':
                components = [(1, self.spar_linear)]
                iis = self.sparii
            elif func_name == 'strace':
                components = [(1, self.strace_linear)]
                iis = self.straceii
            else:  # 'gamma1', 'gamma2', 'kappa': no linear-theory piece
                return pktable

        # sum the components
        ptable = sum(kpow * comp for kpow, comp in components)

        # interpolate onto range of interest
        for jj, col in enumerate(iis):
            pktable[:, col] = interp1d(self.kint, ptable[jj, :])(kv)

        return pktable

    def make_ptable(self, kmin=1e-3, kmax=3, nk=100):
        """Build the linear-theory and full P(k) tables via ``make_table``.

        The results are stored in ``self.pktable_linear`` and ``self.pktable``
        (linear one first).
        """
        grid = dict(kmin=kmin, kmax=kmax, nk=nk, func_name='power')
        self.pktable_linear = self.make_table(linear_theory=True, **grid)
        self.pktable = self.make_table(**grid)

    def make_vtable(self, kmin=1e-3, kmax=3, nk=100):
        """Build the linear-theory and full pairwise-velocity tables via ``make_table``.

        The results are stored in ``self.vktable_linear`` and ``self.vktable``
        (linear one first).
        """
        grid = dict(kmin=kmin, kmax=kmax, nk=nk, func_name='velocity')
        self.vktable_linear = self.make_table(linear_theory=True, **grid)
        self.vktable = self.make_table(**grid)

    def make_spartable(self, kmin=1e-3, kmax=3, nk=100):
        """Build the linear-theory and full sigma_parallel tables via ``make_table``.

        The results are stored in ``self.sparktable_linear`` and
        ``self.sparktable`` (linear one first).
        """
        grid = dict(kmin=kmin, kmax=kmax, nk=nk, func_name='spar')
        self.sparktable_linear = self.make_table(linear_theory=True, **grid)
        self.sparktable = self.make_table(**grid)

    def make_stracetable(self, kmin=1e-3, kmax=3, nk=100):
        """Build the linear-theory and full Tr(sigma) tables via ``make_table``.

        The results are stored in ``self.stracektable_linear`` and
        ``self.stracektable`` (linear one first).
        """
        grid = dict(kmin=kmin, kmax=kmax, nk=nk, func_name='strace')
        self.stracektable_linear = self.make_table(linear_theory=True, **grid)
        self.stracektable = self.make_table(**grid)

    def make_gamma1table(self, kmin=1e-3, kmax=3, nk=100):
        """Build the linear-theory and full gamma1 tables via ``make_table``.

        The results are stored in ``self.gamma1ktable_linear`` and
        ``self.gamma1ktable`` (linear one first).
        """
        grid = dict(kmin=kmin, kmax=kmax, nk=nk, func_name='gamma1')
        self.gamma1ktable_linear = self.make_table(linear_theory=True, **grid)
        self.gamma1ktable = self.make_table(**grid)

    def make_gamma2table(self, kmin=1e-3, kmax=3, nk=100, linear_theory=True):
        """Build the linear-theory and full gamma2 tables via ``make_table``.

        The results are stored in ``self.gamma2ktable_linear`` and
        ``self.gamma2ktable`` (linear one first). The ``linear_theory``
        argument is accepted but ignored: both tables are always built.
        """
        grid = dict(kmin=kmin, kmax=kmax, nk=nk, func_name='gamma2')
        self.gamma2ktable_linear = self.make_table(linear_theory=True, **grid)
        self.gamma2ktable = self.make_table(**grid)

    def make_kappatable(self, kmin=1e-3, kmax=3, nk=100):
        """Build the linear-theory and full kappa tables via ``make_table``.

        The results are stored in ``self.kappaktable_linear`` and
        ``self.kappaktable`` (linear one first).
        """
        grid = dict(kmin=kmin, kmax=kmax, nk=nk, func_name='kappa')
        self.kappaktable_linear = self.make_table(linear_theory=True, **grid)
        self.kappaktable = self.make_table(**grid)

    def convert_sigma_bases(self, basis='Legendre'):
        r'''
        Function to convert Tr\sigma and \sigma_\par to the desired basis.

        The results are stored in self.s0, self.s2 (one loop) and
        self.s0_linear, self.s2_linear (linear theory); column 0 of each holds k.
        All four are new arrays, never views of the input tables.

        These are:
        - Legendre

        sigma = sigma_0 delta_ij + sigma_2 (3 k_i k_j - delta_ij)/2

        - Polynomial

        sigma = sigma_0 delta_ij + sigma_2 k_i k_j

        - los (line of sight, note that sigma_0 = kpar and sigma_2 = kperp in this case)

        sigma = sigma_0 k_i k_j + sigma_2 (delta_ij - k_i k_j)/2

        Prints an error and returns 0 if the sigma tables (make_spartable /
        make_stracetable) have not been computed. Raises ValueError for an
        unknown basis.
        '''
        names = ('sparktable', 'stracektable', 'sparktable_linear',
                 'stracektable_linear')
        tables = [getattr(self, name, None) for name in names]
        if any(table is None for table in tables):
            print("Error: Need to compute sigma before changing bases!")
            return 0
        spar, strace, spar_lin, strace_lin = tables

        # With spar = k_i k_j sigma_ij and strace = sigma_ii, each basis gives
        # (sigma_0, sigma_2) as a linear combination of the two tables.
        if basis == 'Legendre':
            def to_basis(sp, st):
                s0 = st / 3.
                return s0, sp - s0
        elif basis == 'Polynomial':
            def to_basis(sp, st):
                return 0.5 * (st - sp), 0.5 * (3 * sp - st)
        elif basis == 'los':
            def to_basis(sp, st):
                # copy so that editing s0 (e.g. its k column) never writes
                # into sparktable
                return sp.copy(), st - sp
        else:
            raise ValueError(f"Unknown basis {basis!r}; expected 'Legendre', "
                             f"'Polynomial' or 'los'")

        kv = spar[:, 0].copy()
        self.s0_linear, self.s2_linear = to_basis(spar_lin, strace_lin)
        self.s0, self.s2 = to_basis(spar, strace)

        # the combinations above also mix column 0, so restore k there
        for table in (self.s0_linear, self.s2_linear, self.s0, self.s2):
            table[:, 0] = kv

    def convert_gamma_bases(self, basis='Polynomial'):
        r'''
        Translates the contraction of gamma into the polynomial/legendre basis
        given by Im[gamma] = g3 \hat{k}_i \hat{k}_j \hat{k}_k
                             + g1 (\hat{k}_i \delta_{jk} + et cycl) / 3

        The inversion below assumes gamma1 = g3 + g1 and gamma2 = g3 + (5/3) g1
        (from self.gamma1ktable and self.gamma2ktable). The 'Legendre' basis
        keeps the same g3 and sets g1 = 0.6 * gamma2.

        Results are stored in self.g1 and self.g3 with k in column 0.
        Prints an error and returns 0 if the gamma tables have not been
        computed; raises ValueError for an unknown basis.
        '''
        gamma1 = getattr(self, 'gamma1ktable', None)
        gamma2 = getattr(self, 'gamma2ktable', None)
        if gamma1 is None or gamma2 is None:
            print("Error: Need to compute gamma before changing bases!")
            return 0

        if basis == 'Polynomial':
            g1 = 1.5 * gamma2 - 1.5 * gamma1
        elif basis == 'Legendre':
            g1 = 0.6 * gamma2
        else:
            # previously an unknown basis left stale g1/g3 and then edited them
            raise ValueError(f"Unknown basis {basis!r}; expected 'Polynomial' "
                             f"or 'Legendre'")
        g3 = 2.5 * gamma1 - 1.5 * gamma2

        # the combinations above also mix column 0, so restore k there
        kv = gamma1[:, 0]
        g1[:, 0] = kv
        g3[:, 0] = kv
        self.g1 = g1
        self.g3 = g3

    def convert_kappa_bases(self, basis='Polynomial'):
        '''
        Translates the contractions of kappa into the polynomial basis
        given by
        kappa = kappa0 / 3 * (delta_ij delta_kl + perms)
                + kappa2 / 6 * (k_i k_j delta_kl + perms)
                + kappa4 * k_i k_j k_k k_l

        Columns 1, 2, 3 of self.kappaktable are treated as the contractions of
        kappa with delta_ij delta_kl, k_i k_j delta_kl and k_i k_j k_k k_l. In
        the basis above these equal 5 k0 + 5/3 k2 + k4, 5/3 k0 + 4/3 k2 + k4
        and k0 + k2 + k4, and the formulas below invert those relations.

        Results are stored as self.k0, self.k2, self.k4. They are 1-D arrays
        with no k column. Only the polynomial basis is implemented; the
        ``basis`` argument is currently unused. Prints an error and returns 0
        if the kappa table has not been computed.
        '''
        kappa = getattr(self, 'kappaktable', None)
        if kappa is None:
            print("Error: Need to compute kappa before changing bases!")
            return 0

        c_dd = kappa[:, 1]    # delta_ij delta_kl contraction
        c_kkd = kappa[:, 2]   # k_i k_j delta_kl contraction
        c_kkkk = kappa[:, 3]  # k_i k_j k_k k_l contraction

        self.k0 = 3. / 8 * (c_dd - 2 * c_kkd + c_kkkk)
        self.k2 = 3. / 4 * (-c_dd + 6 * c_kkd - 5 * c_kkkk)
        self.k4 = 1. / 8 * (3 * c_dd - 30 * c_kkd + 35 * c_kkkk)
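# Example #4 (separator left over from the listing this file was copied from;
# originally the bare text "예제 #4" followed by "0")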
class CLEFT:
    '''
    Class to calculate power spectra up to one loop.
    
    Based on Chirag's code
    
    https://github.com/sfschen/velocileptors/blob/master/LPT/cleft_fftw.py
    
    The bias parameters are ordered in pktable as
    1, b1, b1^2, b2, b1b2, b2^2, bs, b1bs, b2bs, bs^2, b3, b1 b3
    where b3 is a catch-all for third order bias parameters degenerate at one-loop order.
    Column 0 of pktable holds k, and the last column holds the counterterm
    piece, which combine_bias_terms_pk multiplies by alpha * k^2.
    
    Can combine into a full one-loop real-space power spectrum using the function combine_bias_terms_pk.
    
    '''
    def __init__(self,
                 k,
                 p,
                 one_loop=True,
                 shear=True,
                 third_order=True,
                 cutoff=10,
                 jn=5,
                 N=2000,
                 threads=1,
                 extrap_min=-5,
                 extrap_max=3,
                 import_wisdom=False,
                 wisdom_file='wisdom.npy'):
        '''
        k, p : input linear power spectrum P(k) sampled at k.
        one_loop, shear, third_order : which sets of terms to compute; they
            also fix the number of pktable columns (num_power_components).
        cutoff : the interpolated P(k) is multiplied by exp(-(k/cutoff)^2).
        jn : number of spherical-Bessel orders (l = 0 .. jn-1) summed in p_integrals.
        N : number of points on the internal log-spaced k and q grids.
        extrap_min, extrap_max : log10 limits of the internal k grid; the q
            grid spans the reciprocal range.
        threads, import_wisdom, wisdom_file : passed to SphericalBesselTransform.
        '''

        self.N = N
        self.extrap_max = extrap_max
        self.extrap_min = extrap_min

        self.cutoff = cutoff
        self.kint = np.logspace(extrap_min, extrap_max, self.N)
        self.qint = np.logspace(-extrap_max, -extrap_min, self.N)

        self.one_loop = one_loop
        self.shear = shear
        self.third_order = third_order

        self.update_power_spectrum(k, p)

        self.pktable = None
        # bias monomials plus one counterterm column
        if self.third_order:
            self.num_power_components = 13
        elif self.shear:
            self.num_power_components = 11
        else:
            self.num_power_components = 7

        self.jn = jn
        self.threads = threads
        self.import_wisdom = import_wisdom
        self.wisdom_file = wisdom_file
        self.sph = SphericalBesselTransform(self.qint,
                                            L=self.jn,
                                            ncol=self.num_power_components,
                                            threads=self.threads,
                                            import_wisdom=self.import_wisdom,
                                            wisdom_file=self.wisdom_file)

    def update_power_spectrum(self, k, p):
        # Updates the power spectrum and various q functions. Can continually compute for new cosmologies without reloading FFTW
        self.k = k
        self.p = p
        self.pint = loginterp(k, p)(
            self.kint) * np.exp(-(self.kint / self.cutoff)**2)
        self.setup_powerspectrum()

    def setup_powerspectrum(self):

        # This sets up terms up to one loop in the combination (symmetry factors) they appear in pk

        self.qf = QFuncFFT(self.kint,
                           self.pint,
                           qv=self.qint,
                           oneloop=self.one_loop,
                           shear=self.shear,
                           third_order=self.third_order)

        # linear terms
        self.Xlin = self.qf.Xlin
        self.Ylin = self.qf.Ylin

        self.XYlin = self.Xlin + self.Ylin
        self.sigma = self.XYlin[-1]
        self.yq = self.Ylin / self.qint

        self.Ulin = self.qf.Ulin
        self.corlin = self.qf.corlin

        if self.one_loop:
            # one loop terms: here we add in all the symmetry factors
            self.Xloop = 2 * self.qf.Xloop13 + self.qf.Xloop22
            self.sigmaloop = self.Xloop[-1]
            self.Yloop = 2 * self.qf.Yloop13 + self.qf.Yloop22

            self.Vloop = 3 * (2 * self.qf.V1loop112 + self.qf.V3loop112
                              )  # this multiplies mu in the pk integral
            self.Tloop = 3 * self.qf.Tloop112  # and this multiplies mu^3

            self.X10 = 2 * self.qf.X10loop12
            self.Y10 = 2 * self.qf.Y10loop12
            self.sigma10 = (self.X10 + self.Y10)[-1]

            self.U3 = self.qf.U3
            self.U11 = self.qf.U11
            self.U20 = self.qf.U20
            self.Us2 = self.qf.Us2

        else:
            self.Xloop, self.Yloop, self.sigmaloop, self.Vloop, self.Tloop, self.X10, self.Y10, self.sigma10, self.U3, self.U11, self.U20, self.Us2 = (
                0, ) * 12

        # load shear functions
        if self.shear or self.third_order:
            self.Xs2 = self.qf.Xs2
            self.Ys2 = self.qf.Ys2
            self.sigmas2 = (self.Xs2 + self.Ys2)[-1]
            self.V = self.qf.V
            self.zeta = self.qf.zeta
            self.chi = self.qf.chi

        if self.third_order:
            self.Ub3 = self.qf.Ub3
            self.theta = self.qf.theta

    def p_integrals(self, k):
        '''
        Compute P(k) for a single k as a vector of all bias contributions.

        Returns an array of length num_power_components in pktable column
        order (without the k column): the bias monomials, then the
        counterterm piece.
        '''
        ksq = k**2
        kcu = k**3
        # Exponent relative to its last-q value self.sigma; the constant part
        # is applied once at the end through `suppress`.
        expon = np.exp(-0.5 * ksq * (self.XYlin - self.sigma))
        suppress = np.exp(-0.5 * ksq * self.sigma)

        ret = np.zeros(self.num_power_components)

        bias_integrands = np.zeros((self.num_power_components, self.N))

        for l in range(self.jn):
            # l-dep functions
            shiftfac = (l > 0) / (k * self.yq)
            mu2fac = 1. - 2. * l / ksq / self.Ylin
            mu3fac = 1. - 2. * (
                l - 1) / ksq / self.Ylin  # mu3 terms start at j1 so l -> l-1

            bias_integrands[0, :] = 1. - 0.5 * ksq * (
                self.Xloop + mu2fac * self.Yloop) + kcu * shiftfac * (
                    self.Vloop + self.Tloop * mu3fac) / 6.  # matter
            bias_integrands[
                1, :] = (-2 * k * (self.Ulin + self.U3)) * shiftfac - ksq * (
                    self.X10 + self.Y10 * mu2fac)  # b1
            bias_integrands[
                2, :] = self.corlin - ksq * mu2fac * self.Ulin**2 - shiftfac * k * self.U11  # b1sq
            bias_integrands[
                3, :] = -ksq * mu2fac * self.Ulin**2 - shiftfac * k * self.U20  # b2
            bias_integrands[4, :] = (-2 * k * self.Ulin *
                                     self.corlin) * shiftfac  # b1b2
            bias_integrands[5, :] = 0.5 * self.corlin**2  # b2sq

            if self.shear or self.third_order:
                bias_integrands[6, :] = -ksq * (
                    self.Xs2 + mu2fac * self.Ys2
                ) - 2 * k * self.Us2 * shiftfac  # bs should be both minus
                bias_integrands[7, :] = -2 * k * self.V * shiftfac  # b1bs
                bias_integrands[8, :] = self.chi  # b2bs
                bias_integrands[9, :] = self.zeta  # bssq

            if self.third_order:
                bias_integrands[10, :] = -2 * k * self.Ub3 * shiftfac  # b3
                bias_integrands[11, :] = 2 * self.theta  # b1 b3

            bias_integrands[
                -1, :] = 1  # this is the counterterm, minus a factor of k2

            # multiply by IR exponent
            if l == 0:
                bias_integrands = bias_integrands * expon
                # expon is 1 at the last q point, so this subtracts each
                # integrand's value at the largest q
                bias_integrands -= bias_integrands[:, -1][:, None]
            else:
                bias_integrands = bias_integrands * expon * self.yq**l

            # do FFTLog
            ktemps, bias_ffts = self.sph.sph(l, bias_integrands)
            ret += k**l * interp1d(ktemps, bias_ffts)(k)

        return 4 * suppress * np.pi * ret

    def make_ptable(self, kmin=1e-3, kmax=3, nk=100):
        '''
        Make a table of different terms of P(k) between a given
        'kmin', 'kmax' and for 'nk' equally spaced values in log10 of k
        This is the most time consuming part of the code.
        '''
        self.pktable = np.zeros([nk, self.num_power_components + 1
                                 ])  # one column for ks
        kv = np.logspace(np.log10(kmin), np.log10(kmax), nk)
        self.pktable[:, 0] = kv[:]
        for foo in range(nk):
            self.pktable[foo, 1:] = self.p_integrals(kv[foo])

    def combine_bias_terms_pk(self, b1, b2, bs, b3, alpha, sn):
        '''
        Combine all the bias terms into one power spectrum,
        where alpha is the counterterm and sn the shot noise/stochastic contribution.
        
        Three options, for
        
        (1) Full one-loop bias expansion (third order bias)
        (2) only quadratic bias, including shear
        (3) only density bias
        
        If (2) or (3), i.e. the class is set such that shear=False or third_order=False then the bs
        and b3 parameters are not used.

        Returns (k, P(k)) on the pktable k grid. Raises RuntimeError if
        make_ptable has not been called yet.
        '''
        arr = getattr(self, 'pktable', None)
        if arr is None:
            raise RuntimeError(
                "Need to compute pktable (call make_ptable) before combining bias terms!")

        if self.third_order:
            bias_monomials = np.array([
                1, b1, b1**2, b2, b1 * b2, b2**2, bs, b1 * bs, b2 * bs, bs**2,
                b3, b1 * b3
            ])
        elif self.shear:
            bias_monomials = np.array([
                1, b1, b1**2, b2, b1 * b2, b2**2, bs, b1 * bs, b2 * bs, bs**2
            ])
        else:
            bias_monomials = np.array([1, b1, b1**2, b2, b1 * b2, b2**2])

        kv = arr[:, 0]
        za = arr[:, -1]
        # the slice is only read, so no copy is needed
        pktemp = arr[:, 1:-1]

        res = np.sum(pktemp * bias_monomials, axis=1) + alpha * kv**2 * za + sn

        return kv, res

    def export_wisdom(self, wisdom_file='./wisdom.npy'):
        # Forward to the SphericalBesselTransform's FFTW wisdom export.
        self.sph.export_wisdom(wisdom_file=wisdom_file)
class VelocityMoments(CLEFT):
    '''
    Class based on cleft_fftw to compute pairwise velocity moments.
    '''
    def __init__(self, *args, beyond_gauss=False, **kw):
        '''
        If beyond_gauss = True computes the third and fourth moments, otherwise
        default is to enable calculation of P(k), v(k) and sigma(k).

        All other positional and keyword arguments are passed unchanged to
        CLEFT.__init__ (the cleft_fftw class). Go look there!
        '''

        # Set up the configuration space quantities
        CLEFT.__init__(self, *args, **kw)

        self.beyond_gauss = beyond_gauss

        self.setup_onedot()
        self.setup_twodots()

        # The make_*table methods fill these in. Starting them at None lets the
        # `is None` checks in the convert_*_bases methods detect a table that has
        # not been computed yet, instead of failing with AttributeError.
        self.vktable = None
        self.sparktable = None
        self.stracektable = None
        self.gamma1ktable = None
        self.gamma2ktable = None
        self.kappaktable = None

        # v12 and sigma12 only have a subset of the bias contributions so we don't
        # need to have as many FFTs. The index arrays list the bias columns each
        # moment writes to in make_table (column 0 of every table holds k, hence + 1).
        if self.third_order:
            self.num_vel_components = 8
            self.vii = np.array([0, 1, 2, 3, 4, 6, 7, 10]) + 1
            self.num_spar_components = 5
            self.sparii = np.array([0, 1, 2, 3, 6]) + 1
            self.num_strace_components = 5
            self.straceii = np.array([0, 1, 2, 3, 6]) + 1
        elif self.shear:
            self.num_vel_components = 7
            self.vii = np.array([0, 1, 2, 3, 4, 6, 7]) + 1
            self.num_spar_components = 5
            self.sparii = np.array([0, 1, 2, 3, 6]) + 1
            self.num_strace_components = 5
            self.straceii = np.array([0, 1, 2, 3, 6]) + 1
        else:
            self.num_vel_components = 5
            self.vii = np.array([0, 1, 2, 3, 4]) + 1
            self.num_spar_components = 4
            self.sparii = np.array([0, 1, 2, 3]) + 1
            self.num_strace_components = 4
            self.straceii = np.array([0, 1, 2, 3]) + 1

        self.sph_v = self._make_moment_sph(self.num_vel_components)
        self.sph_spar = self._make_moment_sph(self.num_spar_components)
        self.sph_strace = self._make_moment_sph(self.num_strace_components)

        if self.beyond_gauss:
            # Beyond the first two moments.
            # gamma has matter (all loop, so lump into 0) and b1
            self.num_gamma_components = 2
            self.gii = np.array([0, 1]) + 1
            self.sph_gamma1 = self._make_moment_sph(self.num_gamma_components)
            self.sph_gamma2 = self._make_moment_sph(self.num_gamma_components)

            # fourth moment; note that these columns are not the bias comps
            self.num_kappa_components = 3
            self.kii = np.array([0, 1, 2]) + 1
            self.sph_kappa = self._make_moment_sph(self.num_kappa_components)

    def _make_moment_sph(self, ncol):
        '''
        Build a SphericalBesselTransform on self.qint with `ncol` columns, using
        this object's jn, threads and wisdom settings.
        '''
        return SphericalBesselTransform(self.qint,
                                        L=self.jn,
                                        ncol=ncol,
                                        threads=self.threads,
                                        import_wisdom=self.import_wisdom,
                                        wisdom_file=self.wisdom_file)

    def update_power_spectrum(self, k, p):
        '''
        Same as the one in cleft_fftw but also do the velocities.
        '''
        super(VelocityMoments, self).update_power_spectrum(k, p)
        self.setup_onedot()
        self.setup_twodots()
        self.setup_threedots()

    def setup_onedot(self):
        '''
        Create quantities linear in f. All quantities are with f = 1, since converting back is trivial.
        '''
        self.Xdot = self.Xlin
        self.sigmadot = self.Xdot[-1]
        self.Ydot = self.Ylin

        self.Vdot = 4. / 3 * self.Vloop  # these are only the symmetrized version since all we need...
        self.Tdot = 4. / 3 * self.Tloop  # is k_i k_j k_k W_{ijk}

        self.Udot = self.Ulin
        self.Uloopdot = 3 * self.U3

        self.U11dot = 2 * self.U11
        self.U20dot = 2 * self.U20

        # some one loop terms have to be explicitly set to zero
        if self.one_loop:
            self.Xloopdot = (4 * self.qf.Xloop13 +
                             2 * self.qf.Xloop22) * self.one_loop
            self.sigmaloopdot = self.Xloopdot[-1]
            self.Yloopdot = (4 * self.qf.Yloop13 +
                             2 * self.qf.Yloop22) * self.one_loop
            self.X10dot = 1.5 * self.X10
            self.sigma10dot = self.X10dot[-1]
            self.Y10dot = 1.5 * self.Y10
        else:
            self.Xloopdot = 0
            self.sigmaloopdot = 0
            self.Yloopdot = 0
            self.X10dot = 0
            self.sigma10dot = 0
            self.Y10dot = 0

        if self.shear:
            self.Us2dot = 2 * self.Us2
            self.V12dot = self.V
            self.Xs2dot = self.Xs2
            self.sigmas2dot = self.Xs2dot[-1]
            self.Ys2dot = self.Ys2

        if self.third_order:
            self.Ub3dot = self.Ub3

    def setup_twodots(self):
        '''
        Same as onedot but now for those quadratic in f.
        '''
        self.Xddot = self.Xlin
        self.sigmaddot = self.Xddot[-1]
        self.Yddot = self.Ylin

        # Here we will need two forms, one symmetrized:
        self.Vddot = 5. / 3 * self.Vloop  #these are only the symmetrized version since all we need...
        self.Tddot = 5. / 3 * self.Tloop  # is k_i k_j k_k W_{ijk}

        # Explicitly set certain terms to zero if not one loop
        if self.one_loop:
            self.Xloopddot = (4 * self.qf.Xloop22 +
                              6 * self.qf.Xloop13) * self.one_loop
            self.sigmaloopddot = self.Xloopddot[-1]
            self.Yloopddot = (4 * self.qf.Yloop22 +
                              6 * self.qf.Yloop13) * self.one_loop

            self.X10ddot = 2 * self.X10
            self.sigma10ddot = self.X10ddot[-1]
            self.Y10ddot = 2 * self.Y10

            # and the other from k_i \delta_{jk} \ddot{W}_{ijk}
            self.kdelta_Wddot = (18 * self.qf.V1loop112 + 7 * self.qf.V3loop112
                                 + 5 * self.qf.Tloop112) * self.one_loop
        else:
            self.Xloopddot = 0
            self.sigmaloopddot = 0
            self.Yloopddot = 0
            self.X10ddot = 0
            self.sigma10ddot = 0
            self.Y10ddot = 0
            self.kdelta_Wddot = 0

        if self.shear:
            self.Xs2ddot = self.Xs2
            self.sigmas2ddot = self.Xs2ddot[-1]
            self.Ys2ddot = self.Ys2

    def setup_threedots(self):
        self.Vdddot = 2 * self.Vloop
        self.Tdddot = 2 * self.Tloop

    def v_integrals(self, k):
        '''
        Gives bias contributions to v(k) at a given k.
        '''

        ksq = k**2
        kcu = k**3
        expon = np.exp(-0.5 * ksq * (self.XYlin - self.sigma))
        suppress = np.exp(-0.5 * ksq * self.sigma)

        ret = np.zeros(self.num_vel_components)

        bias_integrands = np.zeros((self.num_vel_components, self.N))

        for l in range(self.jn):
            # l-dep functions
            mu1fac = (l > 0) / (k * self.yq)
            mu2fac = 1. - 2. * l / ksq / self.Ylin
            mu3fac = mu1fac * (1. - 2. * (l - 1) / ksq / self.Ylin
                               )  # mu3 terms start at j1 so l -> l-1

            bias_integrands[0, :] = k * (
                self.Xdot + self.Xloopdot + mu2fac *
                (self.Ydot + self.Yloopdot)) - 0.5 * ksq * (
                    mu1fac * self.Vdot + mu3fac * self.Tdot)  # matter
            bias_integrands[1, :] = 2 * (
                -ksq * self.Ulin *
                (mu1fac * self.Xdot + mu3fac * self.Ydot) + mu1fac *
                (self.Udot + self.Uloopdot) + k *
                (self.X10dot + mu2fac * self.Y10dot))  # b1
            bias_integrands[
                2, :] = 2 * k * mu2fac * self.Ulin * self.Udot + mu1fac * self.U11dot + k * self.corlin * (
                    self.Xdot + self.Ydot * mu2fac)  # b1sq
            bias_integrands[
                3, :] = 2 * k * self.Ulin * self.Udot * mu2fac + self.U20dot * mu1fac  # b2
            bias_integrands[
                4, :] = 2 * self.corlin * self.Udot * mu1fac  # b1b2

            if self.shear or self.third_order:
                bias_integrands[5, :] = 2 * self.Us2dot * mu1fac + 2 * k * (
                    self.Xs2dot + mu2fac * self.Ys2dot
                )  #bs: the second factor used to miss a factor of two
                bias_integrands[6, :] = 2 * self.V12dot * mu1fac  #b1 bs
            if self.third_order:
                bias_integrands[7, :] = 2 * self.Ub3dot * mu1fac

            # multiply by IR exponent
            if l == 0:
                bias_integrands = bias_integrands * expon
                bias_integrands -= bias_integrands[:,
                                                   -1][:,
                                                       None]  # note that expon(q = infinity) = 1
            else:
                bias_integrands = bias_integrands * expon * self.yq**l

            # do FFTLog
            ktemps, bias_ffts = self.sph_v.sph(l, bias_integrands)
            ret += k**l * interp1d(ktemps, bias_ffts)(k)

        return 4 * suppress * np.pi * ret

    def spar_integrals(self, k):
        '''
        Gives bias contributions to \sigma_\parallel at a given k.
        '''
        ksq = k**2
        kcu = k**3
        expon = np.exp(-0.5 * ksq * (self.XYlin - self.sigma))
        suppress = np.exp(-0.5 * ksq * self.sigma)

        ret = np.zeros(self.num_spar_components)

        bias_integrands = np.zeros((self.num_spar_components, self.N))

        for l in range(self.jn):
            # l-dep functions
            mu1fac = (l > 0) / (k * self.yq)
            mu2fac = 1. - 2. * l / ksq / self.Ylin
            mu3fac = mu1fac * (1. - 2. * (l - 1) / ksq / self.Ylin
                               )  # mu3 terms start at j1 so l -> l-1
            mu4fac = 1 - 4 * l / ksq / self.Ylin + 4 * l * (l - 1) / (
                ksq * self.Ylin)**2

            bias_integrands[
                0, :] = self.Xddot + self.Yddot * mu2fac + self.Xloopddot - ksq * self.Xdot**2 + (
                    self.Yloopddot - 2 * ksq * self.Xdot *
                    self.Ydot) * mu2fac - ksq * self.Ydot**2 * mu4fac - k * (
                        mu1fac * self.Vddot + mu3fac * self.Tddot)  # matter
            bias_integrands[1, :] = 2 * (
                self.X10ddot - k *
                (self.Ulin * self.Xddot + 2 * self.Udot * self.Xdot) * mu1fac +
                self.Y10ddot * mu2fac - k *
                (self.Ulin * self.Yddot + 2 * self.Udot * self.Ydot) * mu3fac
            )  # b1
            bias_integrands[2, :] = self.corlin * self.Xddot + (
                self.corlin * self.Yddot + 2 * self.Udot**2) * mu2fac  # b1sq
            bias_integrands[3, :] = 2 * self.Udot**2 * mu2fac  # b2

            if self.shear or self.third_order:
                bias_integrands[4, :] = 2 * (
                    self.Xs2ddot + self.Ys2ddot * mu2fac)  # bs

            # multiply by IR exponent
            if l == 0:
                bias_integrands = bias_integrands * expon
                bias_integrands -= bias_integrands[:,
                                                   -1][:,
                                                       None]  # note that expon(q = infinity) = 1
            else:
                bias_integrands = bias_integrands * expon * self.yq**l

            # do FFTLog
            ktemps, bias_ffts = self.sph_spar.sph(l, bias_integrands)
            ret += k**l * interp1d(ktemps, bias_ffts)(k)

        return 4 * suppress * np.pi * ret

    def strace_integrals(self, k):
        '''
            Gives bias contributions to \sigma_\parallel at a given k.
            '''
        ksq = k**2
        kcu = k**3
        expon = np.exp(-0.5 * ksq * (self.XYlin - self.sigma))
        suppress = np.exp(-0.5 * ksq * self.sigma)

        ret = np.zeros(self.num_strace_components)

        bias_integrands = np.zeros((self.num_strace_components, self.N))

        for l in range(self.jn):
            # l-dep functions
            mu1fac = (l > 0) / (k * self.yq)
            mu2fac = 1. - 2. * l / ksq / self.Ylin
            mu3fac = mu1fac * (1. - 2. * (l - 1) / ksq / self.Ylin
                               )  # mu3 terms start at j1 so l -> l-1

            bias_integrands[0, :] = (
                3 * self.Xddot + self.Yddot
            ) + 3 * self.Xloopddot + self.Yloopddot - ksq * self.Xdot**2 - ksq * (
                self.Ydot**2 + 2 * self.Xdot *
                self.Ydot) * mu2fac - k * self.kdelta_Wddot * mu1fac  # za
            bias_integrands[1, :] = 2 * (
                (3 * self.X10ddot + self.Y10ddot) - k * self.Ulin *
                (3 * self.Xddot + self.Yddot) * mu1fac - 2 * k * self.Udot *
                (self.Xdot + self.Ydot) * mu1fac)  # b1
            bias_integrands[2, :] = self.corlin * (
                3 * self.Xddot + self.Yddot) + 2 * self.Udot**2  # b1sq
            bias_integrands[3, :] = 2 * self.Udot**2  # b2

            if self.shear or self.third_order:
                bias_integrands[4, :] = 2 * (3 * self.Xs2ddot + self.Ys2ddot)

            if l == 0:
                bias_integrands = bias_integrands * expon
                bias_integrands -= bias_integrands[:,
                                                   -1][:,
                                                       None]  # note that expon(q = infinity) = 1
            else:
                bias_integrands = bias_integrands * expon * self.yq**l

            # do FFTLog
            ktemps, bias_ffts = self.sph_strace.sph(l, bias_integrands)
            ret += k**l * interp1d(ktemps, bias_ffts)(k)

        return 4 * suppress * np.pi * ret

    def gamma1_integrals(self, k):
        '''
        Gives bias contributions to Im[\hk_i \hk_j \hk_k \gamma_{ijk}]
        '''
        ksq = k**2
        kcu = k**3
        expon = np.exp(-0.5 * ksq * (self.XYlin - self.sigma))
        suppress = np.exp(-0.5 * ksq * self.sigma)

        ret = np.zeros(self.num_gamma_components)

        bias_integrands = np.zeros((self.num_gamma_components, self.N))

        #zero_lags = np.array([self.sigmadot*self.sigmaddot])

        for l in range(self.jn):
            # l-dep functions
            mu1fac = (l > 0) / (k * self.yq)
            mu2fac = 1. - 2. * l / ksq / self.Ylin
            mu3fac = mu1fac * (1. - 2. * (l - 1) / ksq / self.Ylin
                               )  # mu3 terms start at j1 so l -> l-1
            mu4fac = 1 - 4 * l / ksq / self.Ylin + 4 * l * (l - 1) / (
                ksq * self.Ylin)**2

            bias_integrands[
                0, :] = self.Vdddot * mu1fac + self.Tdddot * mu3fac + 3 * k * (
                    self.Xdot * self.Xddot +
                    (self.Xdot * self.Yddot + self.Ydot * self.Xddot) * mu2fac
                    + self.Ydot * self.Yddot * mu4fac)  # matter
            bias_integrands[1, :] = (
                6 * self.Udot * (self.Xddot * mu1fac + self.Yddot * mu3fac)
            )  # b1

            # multiply by IR exponent
            if l == 0:
                bias_integrands = bias_integrands * expon
                bias_integrands -= bias_integrands[:,
                                                   -1][:,
                                                       None]  # note that expon(q = infinity) = 1
            else:
                bias_integrands = bias_integrands * expon * self.yq**l

            # do FFTLog
            ktemps, bias_ffts = self.sph_gamma1.sph(l, bias_integrands)
            ret += k**l * interp1d(ktemps, bias_ffts)(k)

        return 4 * suppress * np.pi * ret

    def gamma2_integrals(self, k):
        '''
            Gives bias contributions to Im[ \hk_i \delta_{jk} \gamma_{ijk} ]
            '''
        ksq = k**2
        kcu = k**3
        expon = np.exp(-0.5 * ksq * (self.XYlin - self.sigma))
        suppress = np.exp(-0.5 * ksq * self.sigma)

        ret = np.zeros(self.num_gamma_components)

        bias_integrands = np.zeros((self.num_gamma_components, self.N))

        #zero_lags = np.array([5*self.sigmadot*self.sigmaddot])

        for l in range(self.jn):
            # l-dep functions
            mu1fac = (l > 0) / (k * self.yq)
            mu2fac = 1. - 2. * l / ksq / self.Ylin

            bias_integrands[
                0, :] = (5 * self.Vdddot / 3 + self.Tdddot) * mu1fac + k * (
                    5 * self.Xdot * self.Xddot + self.Xdot * self.Yddot +
                    (2 * self.Xdot * self.Yddot + self.Ydot *
                     (5 * self.Xddot + 3 * self.Yddot)) * mu2fac)  # matter
            bias_integrands[1, :] = (2 * self.Udot *
                                     (5 * self.Xddot + 3 * self.Yddot) * mu1fac
                                     )  # b1

            # multiply by IR exponent
            if l == 0:
                bias_integrands = bias_integrands * expon
                bias_integrands -= bias_integrands[:,
                                                   -1][:,
                                                       None]  # note that expon(q = infinity) = 1
            else:
                bias_integrands = bias_integrands * expon * self.yq**l

            # do FFTLog
            ktemps, bias_ffts = self.sph_gamma2.sph(l, bias_integrands)
            ret += k**l * interp1d(ktemps, bias_ffts)(k)

        return 4 * suppress * np.pi * ret

    def kappa_integrals(self, k):
        '''
        Since kappa_ijkl only involes one term we can just do them all in one go.
        
        The contractions are
        (1) \delta_{ij} \delta_{kl}
        (2) \hk_i \hk_j \delta_{kl}
        (3)\hk_i \hk_j \hk_k \hk_l
        where \hk = \hat{k} is the unit vector of k.
        
        '''
        ksq = k**2
        kf = k**4
        expon = np.exp(-0.5 * ksq * (self.XYlin - self.sigma))
        suppress = np.exp(-0.5 * ksq * self.sigma)

        ret = np.zeros(self.num_kappa_components)

        bias_integrands = np.zeros((self.num_kappa_components, self.N))

        for l in range(self.jn):
            # l-dep functions
            mu2fac = 1. - 2. * l / ksq / self.Ylin
            mu4fac = 1 - 4 * l / ksq / self.Ylin + 4 * l * (l - 1) / (
                ksq * self.Ylin)**2

            bias_integrands[
                0, :] = 15 * self.Xddot**2 + 10 * self.Xddot * self.Yddot + 3 * self.Yddot**2
            bias_integrands[
                1, :] = 5 * self.Xddot**2 + self.Xddot * self.Yddot + (
                    7 * self.Xddot * self.Yddot + 3 * self.Yddot**2) * mu2fac
            bias_integrands[
                2, :] = 3 * self.Xddot**2 + 6 * self.Xddot * self.Yddot * mu2fac + 3 * self.Yddot**2 * mu4fac

            if l == 0:
                bias_integrands = bias_integrands * expon
                bias_integrands -= bias_integrands[:,
                                                   -1][:,
                                                       None]  # note that expon(q = infinity) = 1
            else:
                bias_integrands = bias_integrands * expon * self.yq**l

            # do FFTLog
            ktemps, bias_ffts = self.sph_kappa.sph(l, bias_integrands)
            ret += k**l * interp1d(ktemps, bias_ffts)(k)

        return 4 * suppress * np.pi * ret

    def make_table(self, kmin=1e-3, kmax=3, nk=100, func_name='power'):
        '''
            Make a table of different terms of P(k), v(k), sigma(k) between a given
            'kmin', 'kmax' and for 'nk' equally spaced values in log10 of k
            This is the most time consuming part of the code.
        '''

        if func_name == 'power':
            func = self.p_integrals
            iis = np.arange(1 + self.num_power_components)
        elif func_name == 'velocity':
            func = self.v_integrals
            iis = self.vii
        elif func_name == 'spar':
            func = self.spar_integrals
            iis = self.sparii
        elif func_name == 'strace':
            func = self.strace_integrals
            iis = self.straceii
        elif func_name == 'gamma1':
            func = self.gamma1_integrals
            iis = self.gii
        elif func_name == 'gamma2':
            func = self.gamma2_integrals
            iis = self.gii
        elif func_name == 'kappa':
            func = self.kappa_integrals
            iis = self.kii

        pktable = np.zeros([
            nk, self.num_power_components + 1 - 1
        ])  # one column for ks, but last column in power now the counterterm
        kv = np.logspace(np.log10(kmin), np.log10(kmax), nk)
        pktable[:, 0] = kv[:]
        for foo in range(nk):
            pktable[foo, iis] = func(kv[foo])

        return pktable

    def make_vtable(self, kmin=1e-3, kmax=3, nk=100):
        self.vktable = self.make_table(kmin=kmin,
                                       kmax=kmax,
                                       nk=nk,
                                       func_name='velocity')

    def make_spartable(self, kmin=1e-3, kmax=3, nk=100):
        self.sparktable = self.make_table(kmin=kmin,
                                          kmax=kmax,
                                          nk=nk,
                                          func_name='spar')

    def make_stracetable(self, kmin=1e-3, kmax=3, nk=100):
        self.stracektable = self.make_table(kmin=kmin,
                                            kmax=kmax,
                                            nk=nk,
                                            func_name='strace')

    def make_gamma1table(self, kmin=1e-3, kmax=3, nk=100):
        self.gamma1ktable = self.make_table(kmin=kmin,
                                            kmax=kmax,
                                            nk=nk,
                                            func_name='gamma1')

    def make_gamma2table(self, kmin=1e-3, kmax=3, nk=100):
        self.gamma2ktable = self.make_table(kmin=kmin,
                                            kmax=kmax,
                                            nk=nk,
                                            func_name='gamma2')

    def make_kappatable(self, kmin=1e-3, kmax=3, nk=100):
        self.kappaktable = self.make_table(kmin=kmin,
                                           kmax=kmax,
                                           nk=nk,
                                           func_name='kappa')

    def convert_sigma_bases(self, basis='Legendre'):
        '''
        Function to convert Tr\sigma and \sigma_\par to the desired basis.
        
        These are:
        - Legendre
        
        sigma = sigma_0 delta_ij + sigma_2 (3 k_i k_j - delta_ij)/2
        
        - Polynomial
        
        sigma = sigma_0 delta_ij + sigma_2 k_i k_j
        
        - los (line of sight, note that sigma_0 = kpar and sigma_2 = kperp in this case)
        
        sigma = sigma_0 k_i k_j + sigma_2 (delta_ij - k_i k_j)/2
        
        '''
        if self.sparktable is None or self.stracektable is None:
            print("Error: Need to compute sigma before changing bases!")
            return 0

        kv = self.sparktable[:, 0]

        if basis == 'Legendre':
            self.s0 = self.stracektable / 3.
            self.s2 = self.sparktable - self.s0
            self.s0[:, 0] = kv
            self.s2[:, 0] = kv

        if basis == 'Polynomial':
            self.s0 = 0.5 * (self.stracektable - self.sparktable)
            self.s2 = 0.5 * (3 * self.sparktable - self.stracektable)
            self.s0[:, 0] = kv
            self.s2[:, 0] = kv

        if basis == 'los':
            self.s0 = self.sparktable
            self.s2 = self.stracektable - self.sparktable
            self.s0[:, 0] = kv
            self.s2[:, 0] = kv

    def convert_gamma_bases(self, basis='Polynomial'):
        '''
        Translates the contraction of gamma into the polynomial/legendre basis
        given by Im[gamma] = g3 \hk_i \hk_j \hk_k + g1 (\hk_i \delta{ij} + et cycl) / 3
        
        '''
        if self.gamma1ktable is None or self.gamma2ktable is None:
            print("Error: Need to compute sigma before changing bases!")
            return 0

        kv = self.gamma1ktable[:, 0]

        # Polynomial basis
        if basis == 'Polynomial':
            self.g1 = 1.5 * self.gamma2ktable - 1.5 * self.gamma1ktable
            self.g3 = 2.5 * self.gamma1ktable - 1.5 * self.gamma2ktable

        if basis == 'Legendre':
            self.g1 = 0.6 * self.gamma2ktable
            self.g3 = 2.5 * self.gamma1ktable - 1.5 * self.gamma2ktable
            #self.g1 = self.g1 + 0.6*self.g3
            #self.g3 = 0.4 * self.g3

        self.g1[:, 0] = kv
        self.g3[:, 0] = kv

    def convert_kappa_bases(self, basis='Polynomial'):
        '''
        Translates the contraction of gamma into the polynomial basis
        given by kappa = kappa0 / 3 * (delta_ij delta_kl + perms) + kappa2 / 6 * (k_i k_j delta_kl + perms) + kappa4 * k_i k_j k_k k_l.
        '''

        if self.kappaktable is None:
            print("Error: Need to compute kappa before changing bases!")
            return 0

        self.kv = self.kappaktable[:, 0]

        self.k0 = 3. / 8 * (self.kappaktable[:, 1] - 2 * self.kappaktable[:, 2]
                            + self.kappaktable[:, 3])
        self.k2 = 3. / 4 * (-self.kappaktable[:, 1] +
                            6 * self.kappaktable[:, 2] -
                            5 * self.kappaktable[:, 3])
        self.k4 = 1. / 8 * (3 * self.kappaktable[:, 1] -
                            30 * self.kappaktable[:, 2] +
                            35 * self.kappaktable[:, 3])

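    # The following functions combine the tabulated components into spectra for a given set
    # of bias parameters. The v(k), sigma(k) and gamma(k) combiners take the same Lagrangian
    # bias parameters as P(k) (b1, b2, bs, b3) plus their own counterterms (and, for v and sigma,
    # stochastic terms); the kappa combiner takes only a counterterm and a stochastic term.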

    def combine_bias_terms_vk(self, b1, b2, bs, b3, alpha_v, sv):
        '''
        Combine all the bias terms into one velocity spectrum.
        Assumes the P(k) table has already been computed.
            
        alpha_v, sv = counterterm and stochastic term.
            
        '''
        arr = self.vktable

        if self.third_order:
            bias_monomials = np.array([
                1, b1, b1**2, b2, b1 * b2, b2**2, bs, b1 * bs, b2 * bs, bs**2,
                b3, b1 * b3
            ])
        elif self.shear:
            bias_monomials = np.array([
                1, b1, b1**2, b2, b1 * b2, b2**2, bs, b1 * bs, b2 * bs, bs**2
            ])
        else:
            bias_monomials = np.array([1, b1, b1**2, b2, b1 * b2, b2**2])

        try:
            kv = arr[:, 0]
            za = self.pktable[:, -1]
        except:
            print("Compute the power spectrum table first!")

        pktemp = np.copy(arr)[:, 1:]

        res = np.sum(pktemp * bias_monomials,
                     axis=1) + alpha_v * kv * za + sv * kv

        return kv, res

    def combine_bias_terms_sk(self,
                              b1,
                              b2,
                              bs,
                              b3,
                              alpha_s0,
                              alpha_s2,
                              s0_stoch,
                              basis='Polynomial'):
        '''
        Combine all the bias terms into one velocity dispersion spectrum.
        Assumes the P(k) table has already been computed.
        
        alpha_s0, alpha_s2 = counterterm for s0 and s2, s0_stoch = stochastic term for s0.
        '''

        self.convert_sigma_bases(basis=basis)

        if self.third_order:
            bias_monomials = np.array([
                1, b1, b1**2, b2, b1 * b2, b2**2, bs, b1 * bs, b2 * bs, bs**2,
                b3, b1 * b3
            ])
        elif self.shear:
            bias_monomials = np.array([
                1, b1, b1**2, b2, b1 * b2, b2**2, bs, b1 * bs, b2 * bs, bs**2
            ])
        else:
            bias_monomials = np.array([1, b1, b1**2, b2, b1 * b2, b2**2])

        # Do the monopole
        try:
            arr = self.s0
            kv = arr[:, 0]
            za = self.pktable[:, -1]
        except:
            print("Compute the power spectrum table first!")

        pktemp = np.copy(arr)[:, 1:]
        s0 = np.sum(
            pktemp * bias_monomials, axis=1
        ) + alpha_s0 * za + s0_stoch  # here the counterterm is a zero lag and just gives P_Zel

        # and the quadratic
        arr = self.s2

        kv = arr[:, 0]
        pktemp = np.copy(arr)[:, 1:]

        s2 = np.sum(
            pktemp * bias_monomials,
            axis=1) + alpha_s2 * za  # there's now a counterterm here too!

        return kv, s0, s2

    def combine_bias_terms_gk(self,
                              b1,
                              b2,
                              bs,
                              b3,
                              alpha_g1,
                              alpha_g3,
                              basis='Polynomial'):

        self.convert_gamma_bases(basis=basis)

        if self.third_order:
            bias_monomials = np.array([
                1, b1, b1**2, b2, b1 * b2, b2**2, bs, b1 * bs, b2 * bs, bs**2,
                b3, b1 * b3
            ])
        elif self.shear:
            bias_monomials = np.array([
                1, b1, b1**2, b2, b1 * b2, b2**2, bs, b1 * bs, b2 * bs, bs**2
            ])
        else:
            bias_monomials = np.array([1, b1, b1**2, b2, b1 * b2, b2**2])

        # Do the monopole
        try:
            arr = self.g1
            kv = arr[:, 0]
            za = self.pktable[:, -1]
        except:
            print("Compute the power spectrum table first!")

        pktemp = np.copy(arr)[:, 1:]
        g1 = np.sum(
            pktemp * bias_monomials, axis=1
        ) + alpha_g1 * za / kv  # here the counterterm is a zero lag and just gives P_Zel

        # and the quadratic
        arr = self.g3

        kv = arr[:, 0]
        pktemp = np.copy(arr)[:, 1:]

        g3 = np.sum(
            pktemp * bias_monomials,
            axis=1) + alpha_g3 * za / za  # there's now a counterterm here too!

        return kv, g1, g3

    def combine_bias_terms_kk(self, alpha_k2, k0_stoch):

        try:
            kv = self.kv
            za = self.pktable[:, -1]
        except:
            print("Compute spectra first!")

        return self.kv, self.k0 + k0_stoch, self.k2 + alpha_k2 * za / kv**2, self.k4