def contract_1e_soc(f1e, fcivec, norb, nelec):
    """Apply a spin-conserving one-electron operator to a mixed-spin CI vector.

    ``fcivec`` is read as a flat concatenation of (neleca, nelecb) sectors,
    ordered by ``neleca = 0 .. nelec`` with ``nelecb = nelec - neleca``; each
    sector occupies ``na * nb`` consecutive entries viewed as an (na, nb)
    matrix.  Inside every sector the same integrals ``f1e`` are contracted
    with the alpha and with the beta single excitations, so the result never
    leaves the sector it came from.

    Args:
        f1e: one-electron integrals, flattened with ``reshape(-1)``; presumably
            (norb, norb) -- TODO confirm against callers.
        fcivec: CI coefficients indexed as a flat array.
        norb: number of spatial orbitals.
        nelec: total number of electrons.

    Returns:
        complex64 array with the same shape as ``fcivec``.
    """
    result = numpy.zeros_like(fcivec, dtype=numpy.complex64)
    offset = 0

    for n_alpha in range(nelec + 1):
        n_beta = nelec - n_alpha

        links_a = cistring_soc.gen_linkstr_index_o0(range(norb), n_alpha)
        links_b = cistring_soc.gen_linkstr_index_o0(range(norb), n_beta)

        dim_a = cistring_soc.num_strings(norb, n_alpha)
        dim_b = cistring_soc.num_strings(norb, n_beta)
        size = dim_a * dim_b
        block = fcivec[offset:offset + size].reshape(dim_a, dim_b)

        # excitation[p, q] accumulates the (p, q) single excitations of
        # ``block`` from both spin channels.
        excitation = numpy.zeros((norb, norb, dim_a, dim_b),
                                 dtype=numpy.complex64)
        for src, entries in enumerate(links_a):
            for p, q, dst, phase in entries:
                excitation[p, q, dst] += phase * block[src]
        for src, entries in enumerate(links_b):
            for p, q, dst, phase in entries:
                excitation[p, q, :, dst] += phase * block[:, src]

        sigma = numpy.dot(f1e.reshape(-1), excitation.reshape(-1, size))
        result[offset:offset + size] = sigma.reshape(-1)
        offset += size

    return result.reshape(fcivec.shape)
def kernel(h1e, hsoc, g2e, norb, nelec, ci0=None, tol=1.e-15, max_cycle=100,
           max_space=100):
    """Find the lowest eigenpair of the SOC FCI Hamiltonian with Davidson.

    The sigma vector is the sum of ``contract_2e_soc`` (acting with the
    integrals prepared by ``absorb_h1e_soc``) and ``contract_1esoc_soc``
    (acting with ``hsoc``).  The preconditioner is built from the diagonal
    returned by ``make_hdiag_soc``.

    Args:
        h1e: spin-free one-electron integrals.
        hsoc: spin-mixing one-electron integrals; ``contract_1esoc_soc`` uses
            ``hsoc[0]`` with c+_alpha c_beta and ``hsoc[1]`` with
            c+_beta c_alpha.
        g2e: two-electron integrals.
        norb: number of spatial orbitals.
        nelec: total electron count; every (neleca, nelecb) split is included.
        ci0: optional initial guess (flattened before use).  When omitted, a
            normalized random vector spanning all sectors is used.
        tol, max_cycle, max_space: forwarded to ``pyscf.lib.davidson``; the
            defaults reproduce the previously hard-coded values.

    Returns:
        ``(e, c)`` as returned by ``pyscf.lib.davidson``.
    """
    h2e = absorb_h1e_soc(h1e, g2e, norb, nelec, .5)

    def hop(c):
        hc = contract_2e_soc(h2e, c, norb, nelec) + contract_1esoc_soc(
            hsoc, c, norb, nelec)
        return hc.reshape(-1)

    hdiag = make_hdiag_soc(h1e, g2e, norb, nelec)

    def precond(x, e, *args):
        # Small shift keeps the denominator away from zero when e ~ hdiag.
        return x / (hdiag - e + 1e-4)

    if ci0 is None:
        # Total CI dimension summed over all (neleca, nelecb) sectors.
        ndim = sum(
            cistring_soc.num_strings(norb, neleca) *
            cistring_soc.num_strings(norb, nelec - neleca)
            for neleca in range(nelec + 1))
        # Random guess with weight in every sector.  Without SOC the sectors
        # do not couple, so a guess confined to one sector would have no
        # overlap with states living in the others.
        ci0 = numpy.random.random(ndim)
        ci0 /= numpy.linalg.norm(ci0)

    e, c = pyscf.lib.davidson(hop,
                              numpy.asarray(ci0).reshape(-1),
                              precond,
                              max_cycle=max_cycle,
                              max_space=max_space,
                              tol=tol)
    return e, c
def contract_1esoc_soc(f1e, fcivec, norb, nelec):
    """Apply the spin-flipping one-electron operator to a mixed-spin CI vector.

    ``fcivec`` is a flat concatenation of (neleca, nelecb) sectors ordered by
    ``neleca = 0 .. nelec``.  For each source sector:

    * ``f1e[0]`` is contracted with the c+_alpha c_beta excitations, whose
      results are added into the following sector (neleca + 1, nelecb - 1);
    * ``f1e[1]`` is contracted with the c+_beta c_alpha excitations, whose
      results are added into the preceding sector (neleca - 1, nelecb + 1).

    Args:
        f1e: pair of integral arrays, each flattened with ``reshape(-1)``;
            presumably (norb, norb) each -- TODO confirm against callers.
        fcivec: CI coefficients indexed as a flat array.
        norb: number of spatial orbitals.
        nelec: total number of electrons.

    Returns:
        complex64 array with the same shape as ``fcivec``.
    """
    sigma = numpy.zeros_like(fcivec, dtype=numpy.complex64)
    start = 0

    for n_alpha in range(nelec + 1):
        n_beta = nelec - n_alpha

        dim_a = cistring_soc.num_strings(norb, n_alpha)
        dim_b = cistring_soc.num_strings(norb, n_beta)
        stop = start + dim_a * dim_b
        block = fcivec[start:stop].reshape(dim_a, dim_b)

        if n_beta > 0:
            # c+_alpha c_beta: outer index is the beta string, the partner
            # table gives the matching alpha string.
            links, partner = cistring_soc.gen_linkstr_index_o0_soc(
                range(norb), n_alpha, n_beta, 0)
            up_a = cistring_soc.num_strings(norb, n_alpha + 1)
            up_b = cistring_soc.num_strings(norb, n_beta - 1)
            excit = numpy.zeros((norb, norb, up_a, up_b),
                                dtype=numpy.complex64)
            for str_b, entries in enumerate(links):
                for pos, (p, q, dest, phase) in enumerate(entries):
                    str_a = partner[str_b][pos]
                    excit[p, q, dest[0], dest[1]] += phase * block[str_a,
                                                                   str_b]
            contrib = numpy.dot(f1e[0].reshape(-1),
                                excit.reshape(-1, up_a * up_b))
            sigma[stop:stop + up_a * up_b] += contrib.reshape(-1)

        if n_alpha > 0:
            # c+_beta c_alpha: outer index is the alpha string, the partner
            # table gives the matching beta string.
            links, partner = cistring_soc.gen_linkstr_index_o0_soc(
                range(norb), n_alpha, n_beta, 1)
            dn_a = cistring_soc.num_strings(norb, n_alpha - 1)
            dn_b = cistring_soc.num_strings(norb, n_beta + 1)
            excit = numpy.zeros((norb, norb, dn_a, dn_b),
                                dtype=numpy.complex64)
            for str_a, entries in enumerate(links):
                for pos, (p, q, dest, phase) in enumerate(entries):
                    str_b = partner[str_a][pos]
                    excit[p, q, dest[0], dest[1]] += phase * block[str_a,
                                                                   str_b]
            contrib = numpy.dot(f1e[1].reshape(-1),
                                excit.reshape(-1, dn_a * dn_b))
            sigma[start - dn_a * dn_b:start] += contrib.reshape(-1)

        start = stop

    return sigma.reshape(fcivec.shape)
def make_rdm12(fcivec, norb, nelec, opt=None):
    """Build spin-resolved one- and two-particle density matrices.

    ``fcivec`` is a flat concatenation of (neleca, nelecb) sectors ordered by
    ``neleca = 0 .. nelec``, each laid out row-major as (na, nb).  First the
    intermediate ``t1[s, t, m, p, q] = <m| p_s^+ q_t |c>`` is built over every
    determinant m.  This assumes each link entry ``(a, i, str1, sign)`` means
    str1 = sign * a^+ i |str0>, as in pyscf.fci.cistring -- TODO confirm for
    cistring_soc.  Index 0 is alpha and 1 is beta.

    Args:
        fcivec: 1-D CI coefficient array.
        norb: number of spatial orbitals.
        nelec: total electron count.
        opt: unused; kept for interface compatibility.

    Returns:
        rdm1: complex64 (2, 2, norb, norb), rdm1[s, t, p, q] = <p_s^+ q_t>.
        rdm2: complex64 (2, 2, 2, 2, norb, norb, norb, norb).  The einsum gives
            rdm2[s, t, u, v, p, q, r, w] = <p_t^+ q_s r_u^+ w_v>.  Note that
            orbital p carries spin t and q carries spin s.  The final loop
            removes the delta(q, r) delta(s, u) <p_t^+ w_v> term, leaving
            <p_t^+ r_u^+ w_v q_s>.
    """
    ndim = len(fcivec)
    t1 = numpy.zeros((2, 2, ndim, norb, norb), dtype=numpy.complex64)

    goffset = 0
    for neleca in range(nelec + 1):
        nelecb = nelec - neleca

        link_indexa = cistring_soc.gen_linkstr_index_o0(range(norb), neleca)
        link_indexb = cistring_soc.gen_linkstr_index_o0(range(norb), nelecb)

        na = cistring_soc.num_strings(norb, neleca)
        nb = cistring_soc.num_strings(norb, nelecb)
        sector_end = goffset + na * nb
        ci0 = fcivec[goffset:sector_end].reshape(na, nb)

        # alpha -> alpha: target row str1 of this sector, all beta strings.
        for str0, tab in enumerate(link_indexa):
            for a, i, str1, sign in tab:
                row = goffset + str1 * nb
                t1[0, 0, row:row + nb, a, i] += sign * ci0[str0]

        # beta -> beta: target column str1 of this sector (stride nb).
        for str0, tab in enumerate(link_indexb):
            for a, i, str1, sign in tab:
                t1[1, 1, goffset + str1:sector_end:nb, a,
                   i] += sign * ci0[:, str0]

        if nelecb > 0:
            # c+_alpha c_beta: outer index is the beta string; results land
            # in the next sector (neleca + 1, nelecb - 1).
            link_soc, vv = cistring_soc.gen_linkstr_index_o0_soc(
                range(norb), neleca, nelecb, 0)
            nnb = cistring_soc.num_strings(norb, nelecb - 1)
            for str0b, tab in enumerate(link_soc):
                for loc, (a, i, target, sign) in enumerate(tab):
                    # Use the *current* beta string's partner table.
                    str0a = vv[str0b][loc]
                    targetloc = sector_end + nnb * target[0] + target[1]
                    t1[0, 1, targetloc, a, i] += sign * ci0[str0a, str0b]

        if neleca > 0:
            # c+_beta c_alpha: outer index is the alpha string; results land
            # in the previous sector (neleca - 1, nelecb + 1).
            link_soc, vv = cistring_soc.gen_linkstr_index_o0_soc(
                range(norb), neleca, nelecb, 1)
            nna = cistring_soc.num_strings(norb, neleca - 1)
            nnb = cistring_soc.num_strings(norb, nelecb + 1)
            prev_start = goffset - nna * nnb
            for str0a, tab in enumerate(link_soc):
                for loc, (a, i, target, sign) in enumerate(tab):
                    # Use the *current* alpha string's partner table.
                    str0b = vv[str0a][loc]
                    targetloc = prev_start + nnb * target[0] + target[1]
                    t1[1, 0, targetloc, a, i] += sign * ci0[str0a, str0b]

        goffset = sector_end

    rdm1 = numpy.zeros((2, 2, norb, norb), dtype=numpy.complex64)
    rdm2 = numpy.zeros((2, 2, 2, 2, norb, norb, norb, norb),
                       dtype=numpy.complex64)

    ket_conj = fcivec.conjugate()
    for a in range(2):
        for b in range(2):
            rdm1[a, b] = numpy.einsum('m,mij->ij', ket_conj, t1[a, b])
            bra = t1[a, b].conjugate()
            for c in range(2):
                for d in range(2):
                    rdm2[a, b, c, d] = numpy.einsum('mij,mkl->jikl', bra,
                                                    t1[c, d])

    # rdm2[a, b, c, d][j, i, k, l] = <j_b^+ i_a k_c^+ l_d>.  Anticommuting
    # i_a past k_c^+ yields delta(i, k) delta(a, c) <j_b^+ l_d>, i.e. a
    # contribution only when a == c, equal to rdm1[b, d].
    for a in range(2):
        for b in range(2):
            for d in range(2):
                for k in range(norb):
                    rdm2[a, b, a, d, :, k, k, :] -= rdm1[b, d]

    return rdm1, rdm2
def make_rdm1(fcivec, norb, nelec, opt=None):
    """Build the spin-resolved one-particle density matrix of a SOC CI vector.

    ``fcivec`` is a flat concatenation of (neleca, nelecb) sectors ordered by
    ``neleca = 0 .. nelec`` (``nelecb = nelec - neleca``), each laid out
    row-major as (na, nb).  Index 0 is alpha and 1 is beta.

    Args:
        fcivec: 1-D CI coefficient array.
        norb: number of spatial orbitals.
        nelec: total electron count.
        opt: unused; kept for interface compatibility.

    Returns:
        complex64 array of shape (2, 2, norb, norb):

        * rdm1[0, 0] and rdm1[1, 1] are accumulated from the spin-conserving
          link tables within each sector;
        * rdm1[0, 1] is accumulated from the c+_alpha c_beta links, pairing
          each sector with the next one;
        * rdm1[1, 0] is accumulated from the c+_beta c_alpha links, pairing
          each sector with the previous one.
    """
    rdm1 = numpy.zeros((2, 2, norb, norb), dtype=numpy.complex64)
    goffset = 0

    for neleca in range(nelec + 1):
        nelecb = nelec - neleca

        link_indexa = cistring_soc.gen_linkstr_index_o0(range(norb), neleca)
        link_indexb = cistring_soc.gen_linkstr_index_o0(range(norb), nelecb)

        na = cistring_soc.num_strings(norb, neleca)
        nb = cistring_soc.num_strings(norb, nelecb)
        ci0 = fcivec[goffset:goffset + na * nb].reshape(na, nb)

        # alpha -> alpha: contract over all beta strings at once.
        for str0, tab in enumerate(link_indexa):
            for a, i, str1, sign in tab:
                rdm1[0, 0, a,
                     i] += sign * numpy.dot(ci0[str1].conjugate(), ci0[str0])

        # beta -> beta: contract over all alpha strings at once, mirroring
        # the alpha branch instead of looping over alpha strings in Python.
        for str0, tab in enumerate(link_indexb):
            for a, i, str1, sign in tab:
                rdm1[1, 1, a, i] += sign * numpy.dot(
                    ci0[:, str1].conjugate(), ci0[:, str0])

        if nelecb > 0:
            # c+_alpha c_beta: pairs this sector with the next one
            # (neleca + 1, nelecb - 1); outer index is the beta string.
            link_soc, vv = cistring_soc.gen_linkstr_index_o0_soc(
                range(norb), neleca, nelecb, 0)
            nna = cistring_soc.num_strings(norb, neleca + 1)
            nnb = cistring_soc.num_strings(norb, nelecb - 1)
            ciT = fcivec[goffset + na * nb:goffset + na * nb +
                         nna * nnb].reshape(nna, nnb)
            for str0, tab in enumerate(link_soc):
                for loc, (a, i, target, sign) in enumerate(tab):
                    tstr = vv[str0][loc]
                    rdm1[0, 1, a, i] += (
                        sign * ci0[tstr, str0] *
                        ciT[target[0], target[1]].conjugate())

        if neleca > 0:
            # c+_beta c_alpha: pairs this sector with the previous one
            # (neleca - 1, nelecb + 1); outer index is the alpha string.
            link_soc, vv = cistring_soc.gen_linkstr_index_o0_soc(
                range(norb), neleca, nelecb, 1)
            nna = cistring_soc.num_strings(norb, neleca - 1)
            nnb = cistring_soc.num_strings(norb, nelecb + 1)
            ciT = fcivec[goffset - nna * nnb:goffset].reshape(nna, nnb)
            for str0, tab in enumerate(link_soc):
                for loc, (a, i, target, sign) in enumerate(tab):
                    tstr = vv[str0][loc]
                    rdm1[1, 0, a, i] += (
                        sign * ci0[str0, tstr] *
                        ciT[target[0], target[1]].conjugate())

        goffset += na * nb

    return rdm1