def contract_1e_soc(f1e, fcivec, norb, nelec):
    '''Apply a one-electron operator to a CI vector spanning all spin sectors.

    ``fcivec`` is the concatenation of one block per sector
    neleca = 0, 1, ..., nelec (with nelecb = nelec - neleca).  Each block is
    an (na, nb) coefficient array flattened in row-major order.  Only the
    same-spin excitations listed by ``cistring_soc.gen_linkstr_index_o0`` for
    the alpha and the beta strings are applied.  Each output block is
    therefore built from the input block of the same sector.

    Args:
        f1e : array with norb*norb elements; flattened and contracted with
            the (a, i) excitation indices of the link tables.
        fcivec : CI coefficients of all sectors, concatenated along axis 0.
        norb : number of orbitals.
        nelec : total number of electrons.

    Returns:
        Array with the shape of ``fcivec``.  Its dtype is complex and at least
        as precise as both ``fcivec`` and ``f1e``.
    '''
    # numpy.complex64 is single precision (two float32).  Promote instead of
    # hard-coding it so double-precision input is not silently truncated.
    dtype = numpy.result_type(fcivec, f1e, numpy.complex64)
    fcinew = numpy.zeros_like(fcivec, dtype=dtype)
    f1e_flat = f1e.reshape(-1)
    goffset = 0

    for neleca in range(nelec + 1):
        nelecb = nelec - neleca

        link_indexa = cistring_soc.gen_linkstr_index_o0(range(norb), neleca)
        link_indexb = cistring_soc.gen_linkstr_index_o0(range(norb), nelecb)

        na = cistring_soc.num_strings(norb, neleca)
        nb = cistring_soc.num_strings(norb, nelecb)
        size = na * nb
        ci0 = fcivec[goffset:goffset + size].reshape(na, nb)

        # t1[a, i] accumulates the (a, i) link-table excitations applied to
        # ci0: alpha excitations act on rows, beta excitations on columns.
        t1 = numpy.zeros((norb, norb, na, nb), dtype=dtype)
        for str0, tab in enumerate(link_indexa):
            for a, i, str1, sign in tab:
                t1[a, i, str1] += sign * ci0[str0]
        for str0, tab in enumerate(link_indexb):
            for a, i, str1, sign in tab:
                t1[a, i, :, str1] += sign * ci0[:, str0]

        # sum_{a,i} f1e[a, i] * t1[a, i]
        fcinew[goffset:goffset + size] = numpy.dot(f1e_flat,
                                                   t1.reshape(-1, size))
        goffset += size

    return fcinew.reshape(fcivec.shape)
def make_rdm1(fcivec, norb, nelec, opt=None):
    '''Spin-resolved one-particle density matrix of a multi-sector CI vector.

    ``fcivec`` concatenates one block per sector neleca = 0, 1, ..., nelec
    (nelecb = nelec - neleca).  Each block is an (na, nb) array flattened
    row-major.  The spin-flip contributions couple a sector with its
    neighbours: (neleca + 1, nelecb - 1) follows it in ``fcivec`` and
    (neleca - 1, nelecb + 1) precedes it.

    Args:
        fcivec : CI coefficients of all sectors, concatenated along axis 0.
        norb : number of orbitals.
        nelec : total number of electrons.
        opt : unused.

    Returns:
        rdm1 : array of shape (2, 2, norb, norb).  Spin index 0 is alpha and
            1 is beta.  rdm1[s, t, a, i] accumulates the (a, i) link-table
            excitations taking spin t to spin s.  Assuming the link entries
            encode a^+ i, this is <Psi| a^+_s i_t |Psi>.  The dtype is complex
            and at least as precise as ``fcivec``.
    '''
    # numpy.complex64 is single precision; promote instead of truncating.
    dtype = numpy.result_type(fcivec, numpy.complex64)
    rdm1 = numpy.zeros((2, 2, norb, norb), dtype=dtype)
    goffset = 0

    for neleca in range(nelec + 1):
        nelecb = nelec - neleca

        link_indexa = cistring_soc.gen_linkstr_index_o0(range(norb), neleca)
        link_indexb = cistring_soc.gen_linkstr_index_o0(range(norb), nelecb)

        na = cistring_soc.num_strings(norb, neleca)
        nb = cistring_soc.num_strings(norb, nelecb)
        size = na * nb
        ci0 = fcivec[goffset:goffset + size].reshape(na, nb)

        # alpha -> alpha: contract over all beta strings (rows of ci0)
        for str0, tab in enumerate(link_indexa):
            for a, i, str1, sign in tab:
                rdm1[0, 0, a, i] += sign * numpy.dot(ci0[str1].conjugate(),
                                                     ci0[str0])

        # beta -> beta: contract over all alpha strings (columns of ci0)
        for str0, tab in enumerate(link_indexb):
            for a, i, str1, sign in tab:
                rdm1[1, 1, a, i] += sign * numpy.dot(
                    ci0[:, str1].conjugate(), ci0[:, str0])

        # beta -> alpha (c+_a c_b): targets live in the following sector
        if nelecb > 0:
            link_soc, vv = cistring_soc.gen_linkstr_index_o0_soc(
                range(norb), neleca, nelecb, 0)
            nna = cistring_soc.num_strings(norb, neleca + 1)
            nnb = cistring_soc.num_strings(norb, nelecb - 1)
            start = goffset + size
            ciT = fcivec[start:start + nna * nnb].reshape(nna, nnb)
            for str0, tab in enumerate(link_soc):
                for loc, (a, i, target, sign) in enumerate(tab):
                    tstr = vv[str0][loc]
                    rdm1[0, 1, a, i] += (sign * ci0[tstr, str0] *
                                         ciT[target[0], target[1]].conjugate())

        # alpha -> beta (c+_b c_a): targets live in the preceding sector
        if neleca > 0:
            link_soc, vv = cistring_soc.gen_linkstr_index_o0_soc(
                range(norb), neleca, nelecb, 1)
            nna = cistring_soc.num_strings(norb, neleca - 1)
            nnb = cistring_soc.num_strings(norb, nelecb + 1)
            ciT = fcivec[goffset - nna * nnb:goffset].reshape(nna, nnb)
            for str0, tab in enumerate(link_soc):
                for loc, (a, i, target, sign) in enumerate(tab):
                    tstr = vv[str0][loc]
                    rdm1[1, 0, a, i] += (sign * ci0[str0, tstr] *
                                         ciT[target[0], target[1]].conjugate())

        goffset += size

    return rdm1
def make_rdm12(fcivec, norb, nelec, opt=None):
    '''Spin-resolved one- and two-particle density matrices of a CI vector.

    ``fcivec`` concatenates one block per sector neleca = 0, 1, ..., nelec
    (nelecb = nelec - neleca).  Each block is an (na, nb) array flattened
    row-major.

    First build t1[s, t, m, p, q], the m-th component of the (p, q) link-table
    excitation from spin t to spin s applied to ``fcivec`` (0 = alpha,
    1 = beta).  Then contract t1 into the density matrices.

    Assuming the link entries encode p^+ q:
        rdm1[s, t, p, q]          = <p^+_s q_t>
        rdm2[s, t, u, v, p, q, r, w] = <p^+_s r^+_u w_v q_t>
    Spin labels follow the orbital indices, so the normal-ordering
    correction only touches blocks where the spins of q and r agree.

    Args:
        fcivec : CI coefficients of all sectors, concatenated along axis 0.
        norb : number of orbitals.
        nelec : total number of electrons.
        opt : unused.

    Returns:
        (rdm1, rdm2) with shapes (2, 2, norb, norb) and
        (2, 2, 2, 2, norb, norb, norb, norb).  Their dtype is complex and at
        least as precise as ``fcivec``.
    '''
    # numpy.complex64 is single precision; promote instead of truncating.
    dtype = numpy.result_type(fcivec, numpy.complex64)
    t1 = numpy.zeros((2, 2, len(fcivec), norb, norb), dtype=dtype)
    goffset = 0

    for neleca in range(nelec + 1):
        nelecb = nelec - neleca

        link_indexa = cistring_soc.gen_linkstr_index_o0(range(norb), neleca)
        link_indexb = cistring_soc.gen_linkstr_index_o0(range(norb), nelecb)

        na = cistring_soc.num_strings(norb, neleca)
        nb = cistring_soc.num_strings(norb, nelecb)
        size = na * nb
        ci0 = fcivec[goffset:goffset + size].reshape(na, nb)

        # alpha -> alpha: alpha string str1 owns the contiguous rows
        # goffset + str1*nb ... goffset + str1*nb + nb - 1
        for str0, tab in enumerate(link_indexa):
            for a, i, str1, sign in tab:
                row = goffset + str1 * nb
                t1[0, 0, row:row + nb, a, i] += sign * ci0[str0]

        # beta -> beta: beta string str1 owns every nb-th row of the sector
        for str0, tab in enumerate(link_indexb):
            for a, i, str1, sign in tab:
                t1[1, 1, goffset + str1:goffset + size:nb, a, i] += \
                    sign * ci0[:, str0]

        # beta -> alpha (c+_a c_b): targets live in the following sector
        if nelecb > 0:
            link_soc, vv = cistring_soc.gen_linkstr_index_o0_soc(
                range(norb), neleca, nelecb, 0)
            nnb = cistring_soc.num_strings(norb, nelecb - 1)
            for str0b, tab in enumerate(link_soc):
                for loc, (a, i, target, sign) in enumerate(tab):
                    # index vv with this loop's string (the original read a
                    # stale ``str0`` left over from the beta loop above)
                    str0a = vv[str0b][loc]
                    targetloc = goffset + size + nnb * target[0] + target[1]
                    t1[0, 1, targetloc, a, i] += sign * ci0[str0a, str0b]

        # alpha -> beta (c+_b c_a): targets live in the preceding sector
        if neleca > 0:
            link_soc, vv = cistring_soc.gen_linkstr_index_o0_soc(
                range(norb), neleca, nelecb, 1)
            nna = cistring_soc.num_strings(norb, neleca - 1)
            nnb = cistring_soc.num_strings(norb, nelecb + 1)
            for str0a, tab in enumerate(link_soc):
                for loc, (a, i, target, sign) in enumerate(tab):
                    str0b = vv[str0a][loc]
                    targetloc = (goffset - nna * nnb + nnb * target[0] +
                                 target[1])
                    t1[1, 0, targetloc, a, i] += sign * ci0[str0a, str0b]

        goffset += size

    # Now construct the different sectors for rdm1 and rdm2
    rdm1 = numpy.zeros((2, 2, norb, norb), dtype=dtype)
    rdm2 = numpy.zeros((2, 2, 2, 2, norb, norb, norb, norb), dtype=dtype)

    bra_vec = fcivec.conjugate()
    for a in range(2):
        for b in range(2):
            rdm1[a, b] = numpy.einsum('m,mij->ij', bra_vec, t1[a, b])
            bra = t1[a, b].conjugate()
            for c in range(2):
                for d in range(2):
                    # 'jikl' swaps the first two orbital indices, so swap
                    # their spin labels too: orbital j carries spin b and
                    # orbital i carries spin a.
                    rdm2[b, a, c, d] = numpy.einsum('mij,mkl->jikl', bra,
                                                    t1[c, d])

    # Normal-order: p^+ q r^+ w = p^+ r^+ w q + delta(q, r) p^+ w, where the
    # delta also requires equal spins on q and r.
    for a in range(2):
        for b in range(2):
            for c in range(2):
                for k in range(norb):
                    rdm2[a, b, b, c, :, k, k, :] -= rdm1[a, c]

    return rdm1, rdm2
def make_hdiag_soc(h1e, g2e, norb, nelec, opt=None):
    '''Diagonal Hamiltonian elements for every determinant of all sectors.

    Sectors are visited in the order neleca = 0, 1, ..., nelec (with
    nelecb = nelec - neleca).  Within a sector, alpha strings run slowest.
    Each element is e1 + 0.5 * e2, with:
      - e1: the sum of h1e[p, p] over occupied orbitals;
      - e2: the Coulomb terms diagj[p, q] = g2e[p, p, q, q] over all occupied
        pairs, minus the exchange terms diagk[p, q] = g2e[p, q, q, p] over
        same-spin pairs.

    Args:
        h1e : one-electron integrals, indexed as h1e[p, p].
        g2e : two-electron integrals in any form accepted by
            ``pyscf.ao2mo.restore``.
        norb : number of orbitals.
        nelec : total number of electrons.
        opt : unused.

    Returns:
        1-D numpy array with one entry per determinant.
    '''
    g2e = pyscf.ao2mo.restore(1, g2e, norb)
    diagj = numpy.einsum('iijj->ij', g2e)
    diagk = numpy.einsum('ijji->ij', g2e)

    def occupations(nel):
        # Occupied-orbital list of every string: column 0 of the first nel
        # link-table entries (presumably the diagonal p^+ p entries).
        link_index = cistring_soc.gen_linkstr_index_o0(range(norb), nel)
        return [tab[:nel, 0] for tab in link_index]

    def single_spin(occ):
        # Energy of a determinant whose electrons all have the same spin.
        e1 = h1e[occ, occ].sum()
        e2 = diagj[occ][:, occ].sum() - diagk[occ][:, occ].sum()
        return e1 + e2 * .5

    hdiag = []
    for neleca in range(nelec + 1):
        nelecb = nelec - neleca

        if neleca == 0:
            hdiag.extend(single_spin(occb) for occb in occupations(nelecb))
        elif nelecb == 0:
            hdiag.extend(single_spin(occa) for occa in occupations(neleca))
        else:
            occslistb = occupations(nelecb)
            for occa in occupations(neleca):
                for occb in occslistb:
                    e1 = h1e[occa, occa].sum() + h1e[occb, occb].sum()
                    # Coulomb over all pairs (alpha-beta counted twice,
                    # halved below); exchange only within each spin.
                    e2 = diagj[occa][:, occa].sum() + diagj[occa][:, occb].sum() \
                         + diagj[occb][:, occa].sum() + diagj[occb][:, occb].sum() \
                         - diagk[occa][:, occa].sum() - diagk[occb][:, occb].sum()
                    hdiag.append(e1 + e2 * .5)

    return numpy.array(hdiag)