def _calc_Mscal(comm, wEBlm, lmax=None, verbose=True):
    '''Calculate the temperature/V polarization mode-mixing matrix.

    Parameters
    ----------
    comm : MPI communicator
        Object providing ``Get_rank``, ``Get_size`` and ``allreduce``. Rows
        of the matrix are distributed over ranks with stride ``size``.

    wEBlm : array
        Window-function alms; only row 0 (``wEBlm[0, :]``) is read here.

    lmax : int, optional
        Largest multipole of the matrix. When omitted it is derived from
        the length of ``wEBlm[0, :]``.

    verbose : bool
        Print the current ``l1`` while looping.

    Returns
    -------
    Mscal : array
        (lmax+1, lmax+1) matrix, summed over all ranks. Rows and columns
        with index below 2 are left at zero.
    '''

    my_rank = comm.Get_rank()
    n_ranks = comm.Get_size()

    if lmax is None:
        lmax = H.Alm.getlmax(len(wEBlm[0, :]))

    n_ell = lmax + 1
    Mscal = np.zeros([n_ell, n_ell])

    # Power spectrum of the window, indexed by l3 below
    window_cl = H.alm2cl(wEBlm[0, :])

    # Each rank handles l1 = 2+rank, 2+rank+size, ...
    for l1 in range(2 + my_rank, n_ell, n_ranks):
        if verbose:
            print("Scal l1 = ", l1)

        for l2 in range(2, n_ell):
            # l3 runs over |l1-l2| .. l1+l2
            all_l3 = np.arange(np.abs(l1-l2), np.abs(l1+l2) + 1)

            # Wigner 3j symbol for every l3 in that range (m's are all zero)
            threej = wc.wigner3j_vect(2*l1, 2*l2, 0, 0)

            # window_cl is only available up to lmax, so drop l3 > lmax
            keep = all_l3 <= lmax
            l3 = all_l3[keep]

            weights = (2*l3+1)*window_cl[l3]
            coupling = np.sum(weights * threej[keep]**2)

            prefactor = (2.0*l2+1.0) / (4.0*np.pi)
            Mscal[l1, l2] = coupling * prefactor

    # Combine the rows computed by the individual ranks
    return comm.allreduce(Mscal)
Exemple #2
0
def _correct_TE(apfunction, x, w):
    '''Amplitude correction for the TE cross-correlation caused by
    apodizing the correlation function.

    Note that the communicator ``comm`` is not a parameter: it is looked
    up from the enclosing (module) scope and must provide ``rank``,
    ``size`` and ``allreduce``.

    Parameters
    ----------
    apfunction : array
        Angle-dependent function used to apodize the correlation function.
        Its length sets the number of multipoles.

    x : array
        cos(angle), the Legendre-Gauss quadrature points

    w : array
        Weights of the Legendre-Gauss quadrature

    Returns
    -------
    fact : array
        TE correction function, one value per multipole (row sums of the
        coupling kernel).
    '''

    nell = len(apfunction)
    lmax = nell - 1

    # Legendre coefficients f_l of the real-space apodizing function,
    # obtained by quadrature over x with weights w
    fl = np.zeros(nell)
    for ell in range(nell):
        legendre = scipy.special.lpmv(0, ell, x)
        quad = np.sum(w * apfunction * legendre)
        fl[ell] = quad * 2 * np.pi * (2 * ell + 1) / (4 * np.pi)

    kcross = np.zeros([nell, nell])

    # Rows are dealt out over ranks; only l2 >= l1 is evaluated and the
    # transposed element is filled from the same sum
    first_row = 2 + comm.rank
    row_step = comm.size

    for l1 in range(first_row, nell, row_step):
        for l2 in range(l1, nell):
            l3min = np.abs(l1 - l2)
            l3max = np.abs(l1 + l2)

            # Both vectors run from l3min to l3max
            wigner00 = wc.wigner3j_vect(2 * l1, 2 * l2, 2 * 0, 2 * 0)
            wigner22 = wc.wigner3j_vect(2 * l1, 2 * l2, 2 * 2, -2 * 2)

            # fl stops at lmax, so cut the trailing l3 > lmax entries
            excess = l3max - lmax
            if excess > 0:
                wigner00 = wigner00[:-excess]
                wigner22 = wigner22[:-excess]

            coupling = np.sum(wigner00 * wigner22 * fl[l3min:l3max + 1])

            kcross[l1, l2] = coupling * (2 * l2 + 1)
            if l2 != l1:
                kcross[l2, l1] = coupling * (2 * l1 + 1)

    kcross = comm.allreduce(kcross)

    return np.sum(kcross, axis=1)
def _calc_Mcross(comm, wEBlm, lmax=None, cl_type='pseudo', verbose=True):
    '''Calculate the temp-pol mode-mixing matrix.

    Parameters
    ----------
    comm : MPI communicator
        Object providing ``Get_rank``, ``Get_size`` and ``allreduce``. The
        l1 rows are distributed over ranks with stride ``size``.

    wEBlm : array
        Window-function alms indexed as ``wEBlm[row, lm]``. Row 0 is always
        used; rows 1, 2, -1 and -2 are additionally read for 'pure' and
        'hybrid'.
        NOTE(review): presumably rows 1/2 and -1/-2 hold the spin-1/spin-2
        window terms for E and B respectively -- confirm against caller.

    lmax : int, optional
        Largest multipole. When omitted it is derived from the length of
        ``wEBlm[0, :]``.

    cl_type : str
        One of 'pseudo', 'pure' or 'hybrid'; selects which blocks are
        filled and whether the m=1,2 terms are included.

    verbose : bool
        Print the current ``l1`` while looping.

    Returns
    -------
    Mcross : array
        (2*lmax+2, 2*lmax+2) matrix made of four (lmax+1, lmax+1) blocks,
        summed over all ranks. Following the in-line labels, the diagonal
        blocks are TE,TE and TB,TB and the off-diagonal blocks couple TE
        and TB.
    '''

    rank = comm.Get_rank()
    size = comm.Get_size()

    if lmax is None:
        lmax = H.Alm.getlmax(len(wEBlm[0, :]))

    # (l, m) of every alm coefficient up to lmax
    l, m = H.Alm.getlm(lmax)

    Mcross = np.zeros([2*lmax+2, 2*lmax+2])

    a0 = 1.0
    a1 = 2.0
    a2 = 1.0

    # Each rank handles l1 = 2+rank, 2+rank+size, ...
    l1_vals = range(2+rank, lmax+1, size)

    # Don't need to calculate this if we are looking at pseudo Cls, but it
    # is not the bottleneck in the calculation so I don't care
    Nl12 = H_ext._nfunc(l1_vals, 2)
    Nl11 = H_ext._nfunc(l1_vals, 1)
    Nl10 = H_ext._nfunc(l1_vals, 0)
    fact0s = Nl10 / Nl12
    fact1s = Nl11 / Nl12

    for l1, fact0, fact1 in zip(l1_vals, fact0s, fact1s):
        if verbose:
            print("Cross l1 = ", l1)

        for l2 in range(2, lmax+1):
            l3min = np.abs(l1-l2)
            l3max = np.abs(l1+l2)

            # Minimum l3 for case when m3=1,2
            l3min2 = np.max([l3min, 2])
            l3min1 = np.max([l3min, 1])

            wc_0 = wc.wigner3j_vect(2*l1, 2*l2, -2*2, 2*2)
            wc_1 = wc.wigner3j_vect(2*l1, 2*l2, 2*2, -2*2)
            Jp0 = wc_0 + wc_1
            JT = wc.wigner3j_vect(2*l1, 2*l2, 0, 0)

            # m=1,2 terms are only needed for pure modes
            if cl_type == 'pure' or cl_type == 'hybrid':
                wc_0 = wc.wigner3j_vect(2*l1, 2*l2, (-2+1)*2, 2*2)
                wc_1 = wc.wigner3j_vect(2*l1, 2*l2, (2-1)*2, -2*2)
                Jp1 = wc_0 + wc_1

                wc_0 = wc.wigner3j_vect(2*l1, 2*l2, (-2+2)*2, 2*2)
                wc_1 = wc.wigner3j_vect(2*l1, 2*l2, (2-2)*2, -2*2)
                Jp2 = wc_0 + wc_1

            # Select every (l, m) coefficient with l3min <= l <= l3max so
            # the sums over l3 and m become one vectorised sum
            idx = np.all([l >= l3min, l <= l3max], axis=0)
            l_tmp = l[idx]
            m_tmp = m[idx]
            m0 = m_tmp == 0

            # sub array of wEBlm that have valid l (for l = l3)
            wEBlm_tmp = wEBlm[:, idx]

            # Position of each selected l inside the Wigner vectors.
            # Builtin int is used as dtype: the np.int alias was removed
            # from NumPy.
            idx_l3_0 = np.array(l_tmp - l3min, dtype=int)  # Jp0,Jm0
            idx_l3_1 = np.array(l_tmp - l3min1, dtype=int)  # Jp1,Jm1
            idx_l3_2 = np.array(l_tmp - l3min2, dtype=int)  # Jp2,Jm2

            # Subsets of the subset. When l3min is 0 or 1 and take the whole
            # range, Jp1(Jm1) and/or Jp2(Jm2) get non-zero values when
            # they should be zero (when l3 = 0 or 1) because negative indices
            # in idx_l3_1 and idx_l3_2 wrap around to the end
            idx1 = l_tmp >= l3min1
            idx2 = l_tmp >= l3min2

            # TE,TE, and TB,TB (2/4)
            termE = a2 * JT[idx_l3_0]*Jp0[idx_l3_0] * np.real(wEBlm_tmp[0, :]*np.conj(wEBlm_tmp[0, :]))

            # "2*sum - sum over m=0": every m > 0 coefficient is counted
            # twice, the m = 0 ones once
            if cl_type == 'hybrid' or cl_type == 'pseudo':
                Mcross[l1, l2] = 2*np.sum(termE) - np.sum(termE[m0])

            if cl_type == 'pure' or cl_type == 'hybrid':
                termE[idx1] += a1*fact1 * (JT[idx_l3_0]*Jp1[idx_l3_1] * np.real(wEBlm_tmp[0, :]*np.conj(wEBlm_tmp[1, :])))[idx1]
                termE[idx2] += a0*fact0 * (JT[idx_l3_0]*Jp2[idx_l3_2] * np.real(wEBlm_tmp[0, :]*np.conj(wEBlm_tmp[2, :])))[idx2]

            Mcross[l1+lmax+1, l2+lmax+1] = 2*np.sum(termE) - np.sum(termE[m0])

            if cl_type == 'pure':
                Mcross[l1, l2] = Mcross[l1+lmax+1, l2+lmax+1]

            # TE,TB and TB,TE (4/4)
            if cl_type == 'pure' or cl_type == 'hybrid':
                termE = np.zeros_like(JT[idx_l3_0], dtype=np.float64)
                termE[idx1] += a1*fact1 * (JT[idx_l3_0]*Jp1[idx_l3_1] * np.real(wEBlm_tmp[0, :]*np.conj(wEBlm_tmp[-1, :])))[idx1]
                termE[idx2] += a0*fact0 * (JT[idx_l3_0]*Jp2[idx_l3_2] * np.real(wEBlm_tmp[0, :]*np.conj(wEBlm_tmp[-2, :])))[idx2]
                Mcross[l1+lmax+1, l2] = 2*np.sum(termE) - np.sum(termE[m0])

            if cl_type == 'pure':
                Mcross[l1, l2+lmax+1] = Mcross[l1+lmax+1, l2]

            # Common normalisation; blocks that were not set above are
            # still zero and stay zero
            normfact = (2.0*l2+1.0) / (4.0*np.pi)

            Mcross[l1, l2] *= normfact / 2.0
            Mcross[l1+lmax+1, l2+lmax+1] *= normfact / 2.0
            Mcross[l1, l2+lmax+1] *= normfact / 2.0
            Mcross[l1+lmax+1, l2] *= -normfact / 2.0

    # Combine the rows computed by the individual ranks
    Mcross = comm.allreduce(Mcross)

    return Mcross
def _calc_Mpol(comm, wEBlm, lmax=None, cl_type='pseudo', verbose=True):
    '''Calculate the polarization mode mixing matrix

    Parameters
    ----------
    comm : MPI communicator
        Object providing ``Get_rank``, ``Get_size`` and ``allreduce``. The
        l1 rows are distributed over ranks with stride ``size``.

    wEBlm : array
        Window-function alms indexed as ``wEBlm[row, lm]``. Row 0 is always
        used; rows 1, 2, -1 and -2 are additionally read for 'pure' and
        'hybrid'.
        NOTE(review): presumably rows 1/2 and -1/-2 hold the spin-1/spin-2
        window terms for E and B respectively -- confirm against caller.

    lmax : int, optional
        Largest multipole. When omitted it is derived from the length of
        ``wEBlm[0, :]``.

    cl_type : str
        One of 'pseudo', 'pure' or 'hybrid'; selects which blocks are
        filled and whether the m=1,2 terms are included.

    verbose : bool
        Print the current ``l1`` while looping.

    Returns
    -------
    Mpol : array
        (3*lmax+3, 3*lmax+3) matrix made of nine (lmax+1, lmax+1) blocks,
        summed over all ranks. Following the in-line labels, block offsets
        0, lmax+1 and 2*lmax+2 correspond to EE, BB and EB.
    '''

    rank = comm.Get_rank()
    size = comm.Get_size()

    if lmax is None:
        lmax = H.Alm.getlmax(len(wEBlm[0, :]))

    # (l, m) of every alm coefficient up to lmax
    l, m = H.Alm.getlm(lmax)

    Mpol = np.zeros([3*lmax+3, 3*lmax+3])

    a0 = 1.0
    a1 = 2.0
    a2 = 1.0

    # Each rank handles l1 = 2+rank, 2+rank+size, ...
    l1_vals = range(2+rank, lmax+1, size)

    # Don't need to calculate this if we are looking at pseudo Cls, but it
    # is not the bottleneck in the calculation so I don't care
    Nl12 = H_ext._nfunc(l1_vals, 2)
    Nl11 = H_ext._nfunc(l1_vals, 1)
    Nl10 = H_ext._nfunc(l1_vals, 0)
    fact0s = Nl10 / Nl12
    fact1s = Nl11 / Nl12

    for l1, fact0, fact1 in zip(l1_vals, fact0s, fact1s):
        if verbose:
            print("Pol l1 = ", l1)

        for l2 in range(2, lmax+1):
            l3min = np.abs(l1-l2)
            l3max = np.abs(l1+l2)

            # Minimum l3 for when m3 is 1 or 2
            l3min2 = np.max([l3min, 2])
            l3min1 = np.max([l3min, 1])

            # Wigner Symbols that we need are
            # (l1   l2 l3 ) and (l1   l2 l3) for m=0 (pseudo)
            # (-2+m 2  -m )     (2-m  -2 m ) and m=1,2 (pure)

            # m=0 term needed for pseudo and pure
            wc_0 = wc.wigner3j_vect(2*l1, 2*l2, -2*2, 2*2)
            wc_1 = wc.wigner3j_vect(2*l1, 2*l2, 2*2, -2*2)
            Jp0 = wc_0 + wc_1
            Jm0 = wc_0 - wc_1

            # m=1,2 terms are only needed for pure modes
            if (cl_type == 'pure') or (cl_type == 'hybrid'):
                wc_0 = wc.wigner3j_vect(2*l1, 2*l2, (-2+1)*2, 2*2)
                wc_1 = wc.wigner3j_vect(2*l1, 2*l2, (2-1)*2, -2*2)
                Jp1 = wc_0 + wc_1
                Jm1 = wc_0 - wc_1

                wc_0 = wc.wigner3j_vect(2*l1, 2*l2, (-2+2)*2, 2*2)
                wc_1 = wc.wigner3j_vect(2*l1, 2*l2, (2-2)*2, -2*2)
                Jp2 = wc_0 + wc_1
                Jm2 = wc_0 - wc_1

            # This allows us to remove any sum over m or l3
            idx = np.all([l >= l3min, l <= l3max], axis=0)
            l_tmp = l[idx]
            m_tmp = m[idx]
            m0 = m_tmp == 0

            wEBlm_tmp = wEBlm[:, idx]

            # Calculate the correct index in the Jp and Jm terms.
            # Builtin int is used as dtype: the np.int alias was removed
            # from NumPy.
            idx_l3_0 = np.array(l_tmp - l3min, dtype=int)  # Jp0, Jm0
            idx_l3_1 = np.array(l_tmp - l3min1, dtype=int)  # Jp1, Jm1
            idx_l3_2 = np.array(l_tmp - l3min2, dtype=int)  # Jp2, Jm2

            # When l3min is 0 or 1 and take the whole range, Jp1(Jm1) and/or
            # Jp2(Jm2) get non-zero values when they should be zero (when l3 =
            # 0 or 1) because negative indices in idx_l3_1 and idx_l3_2 wrap
            # around to the end
            idx1 = l_tmp >= l3min1
            idx2 = l_tmp >= l3min2

            # In every "2*sum - sum over m=0" below, each m > 0 coefficient
            # is counted twice and the m = 0 ones once. Sums of
            # term*conj(term) may be complex-typed; .real is taken
            # explicitly instead of relying on the implicit complex->float
            # cast when storing into the real-valued Mpol.

            # EE,EE and BB,BB (2/9)
            termE = a2 * wEBlm_tmp[0, :] * Jp0[idx_l3_0]
            termB = np.zeros_like(termE)

            if cl_type == 'pseudo' or cl_type == 'hybrid':
                Mpol[l1, l2] = (2*np.sum(termE*np.conj(termE))).real
                Mpol[l1, l2] -= np.sum(termE[m0]*np.conj(termE[m0])).real

            if cl_type == 'pure' or cl_type == 'hybrid':
                termE[idx1] += a1*fact1*wEBlm_tmp[1, idx1]*Jp1[idx_l3_1][idx1]
                termE[idx2] += a0*fact0*wEBlm_tmp[2, idx2]*Jp2[idx_l3_2][idx2]
                termB[idx1] += a1*fact1*wEBlm_tmp[-1, idx1]*Jm1[idx_l3_1][idx1]
                termB[idx2] += a0*fact0*wEBlm_tmp[-2, idx2]*Jm2[idx_l3_2][idx2]

            Mpol[l1+lmax+1, l2+lmax+1] = (2*np.sum(termE*np.conj(termE) +
                                                   termB*np.conj(termB))).real
            Mpol[l1+lmax+1, l2+lmax+1] -= np.sum(termE[m0]*np.conj(termE[m0]) +
                                                 termB[m0]*np.conj(termB[m0])).real

            if cl_type == 'pure':
                Mpol[l1, l2] = Mpol[l1+lmax+1, l2+lmax+1]

            # EE,BB and BB,EE (4/9)
            termE = a2 * wEBlm_tmp[0, :] * Jm0[idx_l3_0]
            termB = np.zeros_like(termE)

            if cl_type == 'pseudo' or cl_type == 'hybrid':
                Mpol[l1, l2+lmax+1] = (2*np.sum(termE*np.conj(termE))).real
                Mpol[l1, l2+lmax+1] -= np.sum(termE[m0]*np.conj(termE[m0])).real

            if cl_type == 'pseudo':
                Mpol[l1+lmax+1, l2] = Mpol[l1, l2+lmax+1]

            if cl_type == 'pure' or cl_type == 'hybrid':
                termE[idx1] += a1*fact1*wEBlm_tmp[1, idx1]*Jm1[idx_l3_1][idx1]
                termE[idx2] += a0*fact0*wEBlm_tmp[2, idx2]*Jm2[idx_l3_2][idx2]
                termB[idx1] += a1*fact1*wEBlm_tmp[-1, idx1]*Jp1[idx_l3_1][idx1]
                termB[idx2] += a0*fact0*wEBlm_tmp[-2, idx2]*Jp2[idx_l3_2][idx2]

                Mpol[l1+lmax+1, l2] = (2*np.sum(termE*np.conj(termE) +
                                                termB*np.conj(termB))).real
                Mpol[l1+lmax+1, l2] -= np.sum(termE[m0]*np.conj(termE[m0]) +
                                              termB[m0]*np.conj(termB[m0])).real

            if cl_type == 'pure':
                Mpol[l1, l2+lmax+1] = Mpol[l1+lmax+1, l2]

            # EB,EB (5/9)
            if cl_type == 'pseudo' or cl_type == 'pure':
                # EE,EE - EE,BB
                Mpol[l1+2*lmax+2, l2+2*lmax+2] = Mpol[l1, l2] - Mpol[l1, l2+lmax+1]
            else:
                termE = np.zeros_like(Jp0[idx_l3_0], dtype=np.float64)
                termE[idx2] += a0*fact0*(Jp0[idx_l3_0]*Jp2[idx_l3_2]-Jm0[idx_l3_0]*Jm2[idx_l3_2])[idx2] * np.real(wEBlm_tmp[0, idx2]*np.conj(wEBlm_tmp[2, idx2]))
                termE[idx1] += a1*fact1*(Jp0[idx_l3_0]*Jp1[idx_l3_1]-Jm0[idx_l3_0]*Jm1[idx_l3_1])[idx1] * np.real(wEBlm_tmp[0, idx1]*np.conj(wEBlm_tmp[1, idx1]))
                termE += a2 * (Jp0[idx_l3_0]*Jp0[idx_l3_0] - Jm0[idx_l3_0]*Jm0[idx_l3_0]) * np.real(wEBlm_tmp[0, :]*np.conj(wEBlm_tmp[0, :]))
                Mpol[l1+2*lmax+2, l2+2*lmax+2] = 2*np.sum(termE) - np.sum(termE[m0])

            # EE,EB and BB,EB (7/9)
            if cl_type == 'pure' or cl_type == 'hybrid':
                termE = np.zeros_like(Jp0[idx_l3_0], dtype=np.float64)
                # only non-zero for pure (and B is pure in hybrid).
                termE[idx2] += a0*a0*fact0*fact0*(Jp2[idx_l3_2]*Jp2[idx_l3_2]*np.real(wEBlm_tmp[2, :]*np.conj(wEBlm_tmp[-2, :]))
                                                  - Jm2[idx_l3_2]*Jm2[idx_l3_2]*np.real(wEBlm_tmp[-2, :]*np.conj(wEBlm_tmp[2, :])))[idx2]
                termE[idx2] += a0*a1*fact0*fact1*(Jp2[idx_l3_2]*Jp1[idx_l3_1]*np.real(wEBlm_tmp[2, :]*np.conj(wEBlm_tmp[-1, :]))
                                                  - Jm2[idx_l3_2]*Jm1[idx_l3_1]*np.real(wEBlm_tmp[-2, :]*np.conj(wEBlm_tmp[1, :])))[idx2]
                termE[idx2] += a0*a2*fact0*(-Jm2[idx_l3_2]*Jm0[idx_l3_0]*np.real(wEBlm_tmp[-2, :]*np.conj(wEBlm_tmp[0, :])))[idx2]
                termE[idx2] += a1*a0*fact1*fact0*(Jp1[idx_l3_1]*Jp2[idx_l3_2]*np.real(wEBlm_tmp[1, :]*np.conj(wEBlm_tmp[-2, :]))
                                                  - Jm1[idx_l3_1]*Jm2[idx_l3_2]*np.real(wEBlm_tmp[-1, :]*np.conj(wEBlm_tmp[2, :])))[idx2]
                termE[idx1] += a1*a1*fact1*fact1*(Jp1[idx_l3_1]*Jp1[idx_l3_1]*np.real(wEBlm_tmp[1, :]*np.conj(wEBlm_tmp[-1, :]))
                                                  - Jm1[idx_l3_1]*Jm1[idx_l3_1]*np.real(wEBlm_tmp[-1, :]*np.conj(wEBlm_tmp[1, :])))[idx1]
                termE[idx1] += a1*a2*fact1*(-Jm1[idx_l3_1]*Jm0[idx_l3_0]*np.real(wEBlm_tmp[-1, :]*np.conj(wEBlm_tmp[0, :])))[idx1]
                termE[idx2] += a2*a0*fact0*(Jp0[idx_l3_0]*Jp2[idx_l3_2]*np.real(wEBlm_tmp[0, :]*np.conj(wEBlm_tmp[-2, :])))[idx2]
                termE[idx1] += a2*a1*fact1*(Jp0[idx_l3_0]*Jp1[idx_l3_1]*np.real(wEBlm_tmp[0, :]*np.conj(wEBlm_tmp[-1, :])))[idx1]

                Mpol[l1+lmax+1, l2+2*lmax+2] = 2*np.sum(termE) - np.sum(termE[m0])

            if cl_type == 'pure':
                Mpol[l1, l2+2*lmax+2] = Mpol[l1+lmax+1, l2+2*lmax+2]

            # EB,EE and EB,BB (9/9)
            if cl_type == 'pure':
                Mpol[l1+2*lmax+2, l2] = Mpol[l1, l2+2*lmax+2]
                Mpol[l1+2*lmax+2, l2+lmax+1] = Mpol[l1, l2+2*lmax+2]
            elif cl_type == 'hybrid':
                termE = np.zeros_like(Jp0[idx_l3_0], dtype=np.float64)
                termE[idx2] += a0*fact0*(Jp0[idx_l3_0]*Jp2[idx_l3_2]*np.real(wEBlm_tmp[0, :]*np.conj(wEBlm_tmp[-2, :])))[idx2]
                termE[idx1] += a1*fact1*(Jp0[idx_l3_0]*Jp1[idx_l3_1]*np.real(wEBlm_tmp[0, :]*np.conj(wEBlm_tmp[-1, :])))[idx1]

                termB = np.zeros_like(Jm0[idx_l3_0], dtype=np.float64)
                termB[idx2] += a0*fact0*(Jm0[idx_l3_0]*Jm2[idx_l3_2]*np.real(wEBlm_tmp[-2, :]*np.conj(wEBlm_tmp[0, :])))[idx2]
                termB[idx1] += a1*fact1*(Jm0[idx_l3_0]*Jm1[idx_l3_1]*np.real(wEBlm_tmp[-1, :]*np.conj(wEBlm_tmp[0, :])))[idx1]

                Mpol[l1+2*lmax+2, l2] = 2*np.sum(termE) - np.sum(termE[m0])
                Mpol[l1+2*lmax+2, l2+lmax+1] = 2*np.sum(termB) - np.sum(termB[m0])

            # Common normalisation; blocks that were not set above are
            # still zero and stay zero
            normfact = (2.0*l2+1.0) / (4.0*np.pi)
            Mpol[l1, l2] *= normfact / 4.0
            Mpol[l1+lmax+1, l2+lmax+1] *= normfact / 4.0
            Mpol[l1, l2+lmax+1] *= normfact / 4.0
            Mpol[l1+lmax+1, l2] *= normfact / 4.0
            Mpol[l1+2*lmax+2, l2+2*lmax+2] *= normfact / 4.0
            Mpol[l1, l2+2*lmax+2] *= normfact / 2.0
            Mpol[l1+lmax+1, l2+2*lmax+2] *= -normfact / 2.0
            Mpol[l1+2*lmax+2, l2] *= -normfact / 4.0
            Mpol[l1+2*lmax+2, l2+lmax+1] *= normfact / 4.0

    # Combine the rows computed by the individual ranks
    Mpol = comm.allreduce(Mpol)

    return Mpol