def v2rdm_hubbard(sites=5, coulomb_values=(4,)):
    """Compare the DQG v2-RDM SDP with exact diagonalization for a Hubbard ring.

    For every on-site repulsion ``U`` in ``coulomb_values`` this:

    1. builds the spinful, periodic 1 x ``sites`` Fermi-Hubbard Hamiltonian
       (hopping 1, no chemical potential, no magnetic field);
    2. diagonalizes its number-preserving sparse operator for
       ``sites - 1`` electrons and prints the lowest eigenvalue;
    3. assembles the Sz-adapted ``ck``/``cckk``/``kkcc``/``ckck`` constraint
       system, asserts every constraint row is Hermitian block-by-block, and
       writes the SDP to ``new_hubbardN{sites}U{U}_DQG.sdp``;
    4. solves the SDP with ``solve_rrsdp`` and prints its objective next to
       the exact energy.

    :param sites: number of lattice sites (spatial orbitals). The default
        reproduces the previously hard-coded value.
    :param coulomb_values: iterable of on-site repulsions ``U`` to scan. The
        default reproduces the previously hard-coded ``[4]``.
    :return: ``(e_fci, e_rdm)``. Each is a list with one entry per ``U``:
        the exact lowest eigenvalue and ``sdp.primal.T @ sdp.cvec`` respectively.
    """
    import openfermion as of

    e_fci = []
    e_rdm = []
    for U in coulomb_values:
        hubbard = of.hamiltonians.fermi_hubbard(1,
                                                sites,
                                                tunneling=1,
                                                coulomb=U,
                                                chemical_potential=0,
                                                magnetic_field=0,
                                                periodic=True,
                                                spinless=False)
        hamiltonian = of.get_interaction_operator(hubbard)

        # Reference energy: lowest eigenvalue of the number-preserving
        # operator for n_electrons electrons over 2 * sites spin orbitals.
        n_electrons = sites - 1
        op_mat = of.get_number_preserving_sparse_operator(
            hubbard, 2 * sites, n_electrons, spin_preserving=False).toarray()
        w, _ = np.linalg.eigh(op_mat)
        gs_e = w[0]
        print(gs_e)

        one_body_ints = hamiltonian.one_body_tensor
        two_body_ints = np.einsum('ijkl->ijlk', hamiltonian.two_body_tensor)

        print('n_electrons', n_electrons)
        # Split the electrons between spins so that Na + Nb == n_electrons.
        # For an odd count the extra electron goes to alpha; for an even count
        # this is the same Na == Nb split as before.
        Nb = n_electrons // 2
        Na = n_electrons - Nb
        spatial_basis_rank = sites
        sdim = spatial_basis_rank
        bij_bas_aa, bij_bas_ab = geminal_spin_basis(spatial_basis_rank)

        opdm_a_interaction, opdm_b_interaction, v2aa, v2bb, v2ab = \
            spin_adapted_interaction_tensor_rdm_consistent(two_body_ints.real,
                                                           one_body_ints.real)

        # Replace the alpha-beta interaction with an explicit on-site term:
        # U on the diagonal entry of each (i, i) alpha-beta geminal.
        v2ab_mat = np.zeros_like(v2ab.data)
        for i in range(spatial_basis_rank):
            idx = bij_bas_ab.rev((i, i))
            v2ab_mat[idx, idx] = U
        v2ab = Tensor(v2ab_mat, basis=v2ab.basis, name=v2ab.name)

        dual_basis = sz_adapted_linear_constraints(
            spatial_basis_rank, Na, Nb, ['ck', 'cckk', 'kkcc', 'ckck'])
        print("constructed dual basis")

        # Cost tensors: only the 1-RDM and 2-RDM blocks carry the Hamiltonian;
        # the hole and particle-hole blocks get zero cost.
        copdm_a = opdm_a_interaction
        copdm_b = opdm_b_interaction
        coqdm_a = Tensor(np.zeros((spatial_basis_rank, spatial_basis_rank)),
                         name='kc_a')
        coqdm_b = Tensor(np.zeros((spatial_basis_rank, spatial_basis_rank)),
                         name='kc_b')
        ctpdm_aa = v2aa
        ctpdm_bb = v2bb
        ctpdm_ab = v2ab
        ctqdm_aa = Tensor(np.zeros_like(v2aa.data),
                          name='kkcc_aa',
                          basis=bij_bas_aa)
        ctqdm_bb = Tensor(np.zeros_like(v2bb.data),
                          name='kkcc_bb',
                          basis=bij_bas_aa)
        ctqdm_ab = Tensor(np.zeros_like(v2ab.data),
                          name='kkcc_ab',
                          basis=bij_bas_ab)
        cphdm_ab = Tensor(np.zeros((spatial_basis_rank * spatial_basis_rank,
                                    spatial_basis_rank * spatial_basis_rank)),
                          name='ckck_ab',
                          basis=bij_bas_ab)
        cphdm_ba = Tensor(np.zeros((spatial_basis_rank * spatial_basis_rank,
                                    spatial_basis_rank * spatial_basis_rank)),
                          name='ckck_ba',
                          basis=bij_bas_ab)
        cphdm_aabb = Tensor(np.zeros(
            (2 * spatial_basis_rank**2, 2 * spatial_basis_rank**2)),
                            name='ckck_aabb')

        ctensor = MultiTensor([
            copdm_a, copdm_b, coqdm_a, coqdm_b, ctpdm_aa, ctpdm_bb, ctpdm_ab,
            ctqdm_aa, ctqdm_bb, ctqdm_ab, cphdm_ab, cphdm_ba, cphdm_aabb
        ])

        ctensor.dual_basis = dual_basis
        A, _, b = ctensor.synthesize_dual_basis()
        print("size of dual basis", len(dual_basis.elements))

        nc, nv = A.shape
        nnz = A.nnz

        sdp = SDP()
        sdp.nc = nc
        sdp.nv = nv
        sdp.nnz = nnz
        # Each tensor is one square PSD block; its linear dimension is the
        # square root of its total number of elements.
        sdp.blockstruct = list(
            map(lambda x: int(np.sqrt(x.size)), ctensor.tensors))
        sdp.nb = len(sdp.blockstruct)
        sdp.Amat = A.real
        sdp.bvec = b.todense().real
        sdp.cvec = ctensor.vectorize_tensors().real

        sm = sdim * (sdim - 1) // 2
        print("AA Dim: ", sdim * (sdim - 1) / 2, sm * (sm + 1) / 2)

        # Sanity check: every constraint row, split back into blocks, must
        # consist of Hermitian matrices.
        Amat = A.toarray()
        for row in Amat:
            for aa in vec2block(sdp.blockstruct, row):
                assert of.is_hermitian(aa)

        sdp.Initialize()
        epsilon = 1.0E-6
        sdp.epsilon = float(epsilon)
        sdp.epsilon_inner = float(epsilon)
        sdp.disp = True
        sdp.iter_max = 50000
        sdp.inner_iter_max = 1
        sdp.inner_solve = 'CG'

        write_sdpfile("new_hubbardN{}U{}_DQG.sdp".format(sites, U), sdp.nc,
                      sdp.nv, sdp.nnz, sdp.nb, sdp.Amat, sdp.bvec, sdp.cvec,
                      sdp.blockstruct)
        solve_rrsdp(sdp)

        sdp_energy = sdp.primal.T @ sdp.cvec
        print(sdp_energy, gs_e)
        e_fci.append(gs_e)
        e_rdm.append(sdp_energy)

    return e_fci, e_rdm
def sdp_nrep_sz_reconstruction(corrupted_tpdm_aa,
                               corrupted_tpdm_bb,
                               corrupted_tpdm_ab,
                               num_alpha,
                               num_beta,
                               disp=False,
                               inner_iter_type='EXACT',
                               epsilon=1.0E-8,
                               max_iter=5000):
    """Purify Sz-adapted 2-RDM blocks with a DQG-constrained SDP.

    The corrupted spin blocks are tied to the SDP variables through
    ``d2_e2_mapping`` and three error-matrix tensors
    (``cckk_me_aa/bb/ab``). Together with the Sz-adapted ``ck``, ``cckk``,
    ``kkcc`` and ``ckck`` linear constraints, the problem is solved with
    ``solve_bpsdp``.

    :param corrupted_tpdm_aa: alpha-alpha 2-RDM block, an
        (r(r-1)/2, r(r-1)/2) matrix over the p < q geminal basis, where r is
        the number of spatial orbitals.
    :param corrupted_tpdm_bb: beta-beta block, same shape as
        ``corrupted_tpdm_aa``.
    :param corrupted_tpdm_ab: alpha-beta block, an (r**2, r**2) matrix.
    :param num_alpha: number of alpha electrons.
    :param num_beta: number of beta electrons; must equal ``num_alpha``.
    :param disp: forwarded to ``sdp.disp``.
    :param inner_iter_type: forwarded to ``sdp.inner_solve``.
    :param epsilon: outer and inner convergence tolerance.
    :param max_iter: forwarded to ``sdp.iter_max``.
    :return: ``(tpdm_aa, tpdm_bb, tpdm_ab)``, the ``cckk_aa``, ``cckk_bb``
        and ``cckk_ab`` blocks of the SDP primal solution.
    :raises TypeError: if any block is not a 2-tensor.
    :raises ValueError: if ``num_alpha != num_beta``, or if the block shapes
        are not consistent with one common number of spatial orbitals.
    """
    for label, block in (('corrupted_tpdm_aa', corrupted_tpdm_aa),
                         ('corrupted_tpdm_bb', corrupted_tpdm_bb),
                         ('corrupted_tpdm_ab', corrupted_tpdm_ab)):
        if np.ndim(block) != 2:
            raise TypeError("{} must be a 2-tensor".format(label))

    if num_alpha != num_beta:
        raise ValueError(
            "right now we are not supporting differing spin numbers")

    ab_dim = corrupted_tpdm_ab.shape[0]
    # Round and then verify. A bare int(np.sqrt(...)) would truncate a
    # non-square dimension and silently build a basis of the wrong rank.
    spatial_basis_rank = int(round(np.sqrt(ab_dim)))
    if (corrupted_tpdm_ab.shape != (ab_dim, ab_dim)
            or spatial_basis_rank**2 != ab_dim):
        raise ValueError("corrupted_tpdm_ab must be an (r**2, r**2) matrix "
                         "for r spatial orbitals")
    aa_dim = spatial_basis_rank * (spatial_basis_rank - 1) // 2
    for label, block in (('corrupted_tpdm_aa', corrupted_tpdm_aa),
                         ('corrupted_tpdm_bb', corrupted_tpdm_bb)):
        if np.shape(block) != (aa_dim, aa_dim):
            raise ValueError(
                "{} must be an (r(r-1)/2, r(r-1)/2) matrix".format(label))

    # get basis bijection
    bij_bas_aa, bij_bas_ab = geminal_spin_basis(spatial_basis_rank)

    # Plain-dict geminal index tables consumed by d2_e2_mapping:
    # aa covers p < q pairs, ab covers all (p, q) pairs, both in row-major order.
    bas_aa = {}
    bas_ab = {}
    cnt_aa = 0
    cnt_ab = 0
    for p, q in product(range(spatial_basis_rank), repeat=2):
        if q > p:
            bas_aa[(p, q)] = cnt_aa
            cnt_aa += 1
        bas_ab[(p, q)] = cnt_ab
        cnt_ab += 1

    dual_basis = sz_adapted_linear_constraints(spatial_basis_rank, num_alpha,
                                               num_beta,
                                               ['ck', 'cckk', 'kkcc', 'ckck'])
    dual_basis += d2_e2_mapping(spatial_basis_rank, bas_aa, bas_ab,
                                corrupted_tpdm_aa, corrupted_tpdm_bb,
                                corrupted_tpdm_ab)

    # Error-matrix tensors; presumably these carry the norm-minimization cost
    # (all RDM cost tensors below are zero) -- confirm in
    # spin_orbital_marginal_norm_min.
    c_cckk_me_aa = spin_orbital_marginal_norm_min(corrupted_tpdm_aa.shape[0],
                                                  tensor_name='cckk_me_aa')
    c_cckk_me_bb = spin_orbital_marginal_norm_min(corrupted_tpdm_bb.shape[0],
                                                  tensor_name='cckk_me_bb')
    c_cckk_me_ab = spin_orbital_marginal_norm_min(corrupted_tpdm_ab.shape[0],
                                                  tensor_name='cckk_me_ab')
    copdm_a = Tensor(np.zeros((spatial_basis_rank, spatial_basis_rank)),
                     name='ck_a')
    copdm_b = Tensor(np.zeros((spatial_basis_rank, spatial_basis_rank)),
                     name='ck_b')
    coqdm_a = Tensor(np.zeros((spatial_basis_rank, spatial_basis_rank)),
                     name='kc_a')
    coqdm_b = Tensor(np.zeros((spatial_basis_rank, spatial_basis_rank)),
                     name='kc_b')
    ctpdm_aa = Tensor(np.zeros_like(corrupted_tpdm_aa),
                      name='cckk_aa',
                      basis=bij_bas_aa)
    ctpdm_bb = Tensor(np.zeros_like(corrupted_tpdm_bb),
                      name='cckk_bb',
                      basis=bij_bas_aa)
    ctpdm_ab = Tensor(np.zeros_like(corrupted_tpdm_ab),
                      name='cckk_ab',
                      basis=bij_bas_ab)
    ctqdm_aa = Tensor(np.zeros_like(corrupted_tpdm_aa),
                      name='kkcc_aa',
                      basis=bij_bas_aa)
    ctqdm_bb = Tensor(np.zeros_like(corrupted_tpdm_bb),
                      name='kkcc_bb',
                      basis=bij_bas_aa)
    ctqdm_ab = Tensor(np.zeros_like(corrupted_tpdm_ab),
                      name='kkcc_ab',
                      basis=bij_bas_ab)

    # NOTE(review): unlike the other ckck_ab/ckck_ba cost tensors in this
    # module, these are 4-index arrays with no basis; confirm that Tensor
    # indexes both layouts identically.
    cphdm_ab = Tensor(np.zeros((spatial_basis_rank, spatial_basis_rank,
                                spatial_basis_rank, spatial_basis_rank)),
                      name='ckck_ab')
    cphdm_ba = Tensor(np.zeros((spatial_basis_rank, spatial_basis_rank,
                                spatial_basis_rank, spatial_basis_rank)),
                      name='ckck_ba')
    cphdm_aabb = Tensor(np.zeros(
        (2 * spatial_basis_rank**2, 2 * spatial_basis_rank**2)),
                        name='ckck_aabb')

    # Tensor order fixes the block order of the primal solution: indices
    # 4, 5, 6 are cckk_aa, cckk_bb, cckk_ab.
    ctensor = MultiTensor([
        copdm_a, copdm_b, coqdm_a, coqdm_b, ctpdm_aa, ctpdm_bb, ctpdm_ab,
        ctqdm_aa, ctqdm_bb, ctqdm_ab, cphdm_ab, cphdm_ba, cphdm_aabb,
        c_cckk_me_aa, c_cckk_me_bb, c_cckk_me_ab
    ])

    ctensor.dual_basis = dual_basis
    A, _, b = ctensor.synthesize_dual_basis()

    nc, nv = A.shape
    nnz = A.nnz

    sdp = SDP()

    sdp.nc = nc
    sdp.nv = nv
    sdp.nnz = nnz
    sdp.blockstruct = list(map(lambda x: int(np.sqrt(x.size)),
                               ctensor.tensors))
    sdp.nb = len(sdp.blockstruct)
    sdp.Amat = A.real
    sdp.bvec = b.todense().real

    sdp.cvec = ctensor.vectorize_tensors().real

    sdp.Initialize()

    sdp.epsilon = float(epsilon)
    sdp.epsilon_inner = float(epsilon)
    sdp.inner_solve = inner_iter_type
    sdp.disp = disp
    sdp.iter_max = max_iter

    solve_bpsdp(sdp)

    rdms_solution = vec2block(sdp.blockstruct, sdp.primal)
    return rdms_solution[4], rdms_solution[5], rdms_solution[6]
def dqg_run_bpsdp():
    """Benchmark the DQG v2-RDM SDP against exact diagonalization for N2.

    Steps:

    1. Run PySCF (SCF + FCI) on N2 / STO-6G with a 1.1 bond length and print
       the reference energies.
    2. Build the active-space Hamiltonian (orbital 0 occupied, orbitals 1-4
       active) and move its constant term out. Exactly diagonalize the dense
       form of its full sparse operator; the lowest eigenvector is the
       reference state.
    3. Build the Sz-adapted reference RDMs from that state. Check the 2-RDM
       round trip through ``unspin_adapt``, then print the reference energy
       (cost . RDM) and the constraint residual ``A x - b``.
    4. Assemble the DQG SDP and solve it with the legacy BPSDP solver.

    :return: spin-orbital 2-RDM rebuilt from the SDP primal solution, with
        the last two indices swapped (``einsum('ijkl->ijlk')``).
    """
    from openfermion.hamiltonians import MolecularData
    from openfermionpyscf import run_pyscf
    import openfermion as of
    from representability.fermions.density.antisymm_sz_density import \
        AntiSymmOrbitalDensity
    from sdpsolve.sdp import SDP
    # The legacy BPSDP implementation is used deliberately. (Previously the
    # new solver was imported first and immediately shadowed by this one.)
    from sdpsolve.solvers.bpsdp.bpsdp_old import solve_bpsdp
    from sdpsolve.utils.matreshape import vec2block

    print('Running System Setup')
    basis = 'sto-6g'
    multiplicity = 1
    charge = 0
    geometry = [['N', [0, 0, 0]], ['N', [0, 0, 1.1]]]
    molecule = MolecularData(geometry, basis, multiplicity, charge)
    molecule = run_pyscf(molecule,
                         run_scf=True,
                         run_mp2=False,
                         run_cisd=False,
                         run_ccsd=False,
                         run_fci=True)

    print('nuclear_repulsion', molecule.nuclear_repulsion)
    print('gs energy ', molecule.fci_energy)
    print("hf energy ", molecule.hf_energy)

    hamiltonian = molecule.get_molecular_hamiltonian(
        occupied_indices=[0], active_indices=[1, 2, 3, 4])
    print(type(hamiltonian))
    print(hamiltonian)
    # Strip the constant so that every energy below is the active-space
    # electronic part only. NOTE(review): with occupied_indices this constant
    # presumably also absorbs the frozen-core energy, not only nuclear
    # repulsion -- confirm against the OpenFermion docs.
    nuclear_repulsion = hamiltonian.constant
    hamiltonian.constant = 0

    # Reference state: lowest eigenvector of the dense Hamiltonian matrix.
    ham = of.get_sparse_operator(hamiltonian).toarray()
    w, v = np.linalg.eigh(ham)
    idx = 0
    gs_energy = w[idx]
    n_density = v[:, [idx]] @ v[:, [idx]].conj().T

    one_body_ints, two_body_ints = hamiltonian.one_body_tensor, hamiltonian.two_body_tensor
    two_body_ints = np.einsum('ijkl->ijlk', two_body_ints)
    # Number of active spin orbitals, derived from the Hamiltonian rather than
    # hard-coded, so changing the active space needs no other edits.
    dim = one_body_ints.shape[0]

    density = AntiSymmOrbitalDensity(n_density, dim)
    opdm_a, opdm_b = density.construct_opdm()
    tpdm_aa, tpdm_bb, tpdm_ab, _ = density.construct_tpdm()

    # Consistency check: recombining the spin blocks must reproduce the
    # spin-orbital 2-RDM (with the last two indices swapped).
    true_tpdm = density.get_tpdm(density.rho, density.dim)
    true_tpdm = true_tpdm.transpose(0, 1, 3, 2)
    test_tpdm = unspin_adapt(tpdm_aa, tpdm_bb, tpdm_ab)
    assert np.allclose(true_tpdm, test_tpdm)

    tqdm_aa, tqdm_bb, tqdm_ab, _ = density.construct_thdm()
    phdm_ab, phdm_ba, phdm_aabb = density.construct_phdm()
    Na = np.round(opdm_a.trace()).real
    Nb = np.round(opdm_b.trace()).real

    n_electrons = Na + Nb
    print('n_electrons', n_electrons)
    spatial_basis_rank = dim // 2
    bij_bas_aa, bij_bas_ab = geminal_spin_basis(spatial_basis_rank)

    opdm_a_interaction, opdm_b_interaction, v2aa, v2bb, v2ab = \
        spin_adapted_interaction_tensor_rdm_consistent(two_body_ints,
                                                       one_body_ints)

    # NOTE(review): S=1, M=-1 is used although the molecule is set up with
    # multiplicity 1; confirm that this matches the reference state found above
    # (the residual printed below should be ~0).
    dual_basis = sz_adapted_linear_constraints(
        spatial_basis_rank,
        Na,
        Nb, ['ck', 'kc', 'cckk', 'ckck', 'kkcc'],
        S=1,
        M=-1)
    print("constructed dual basis")

    # Reference RDMs as a MultiTensor (same block order as the cost tensor).
    opdm_a = Tensor(opdm_a, name='ck_a')
    opdm_b = Tensor(opdm_b, name='ck_b')
    oqdm_a = Tensor(np.eye(dim // 2) - opdm_a.data, name='kc_a')
    oqdm_b = Tensor(np.eye(dim // 2) - opdm_b.data, name='kc_b')

    tpdm_aa = Tensor(tpdm_aa, name='cckk_aa', basis=bij_bas_aa)
    tpdm_bb = Tensor(tpdm_bb, name='cckk_bb', basis=bij_bas_aa)
    tpdm_ab = Tensor(tpdm_ab, name='cckk_ab', basis=bij_bas_ab)

    tqdm_aa = Tensor(tqdm_aa, name='kkcc_aa', basis=bij_bas_aa)
    tqdm_bb = Tensor(tqdm_bb, name='kkcc_bb', basis=bij_bas_aa)
    tqdm_ab = Tensor(tqdm_ab, name='kkcc_ab', basis=bij_bas_ab)

    phdm_ab = Tensor(phdm_ab, name='ckck_ab', basis=bij_bas_ab)
    phdm_ba = Tensor(phdm_ba, name='ckck_ba', basis=bij_bas_ab)
    phdm_aabb = Tensor(phdm_aabb, name='ckck_aabb')

    dtensor = MultiTensor([
        opdm_a, opdm_b, oqdm_a, oqdm_b, tpdm_aa, tpdm_bb, tpdm_ab, tqdm_aa,
        tqdm_bb, tqdm_ab, phdm_ab, phdm_ba, phdm_aabb
    ])

    # Cost tensors: only the 1-RDM and 2-RDM blocks carry the Hamiltonian.
    copdm_a = opdm_a_interaction
    copdm_b = opdm_b_interaction
    coqdm_a = Tensor(np.zeros((spatial_basis_rank, spatial_basis_rank)),
                     name='kc_a')
    coqdm_b = Tensor(np.zeros((spatial_basis_rank, spatial_basis_rank)),
                     name='kc_b')
    ctpdm_aa = v2aa
    ctpdm_bb = v2bb
    ctpdm_ab = v2ab
    ctqdm_aa = Tensor(np.zeros_like(v2aa.data),
                      name='kkcc_aa',
                      basis=bij_bas_aa)
    ctqdm_bb = Tensor(np.zeros_like(v2bb.data),
                      name='kkcc_bb',
                      basis=bij_bas_aa)
    ctqdm_ab = Tensor(np.zeros_like(v2ab.data),
                      name='kkcc_ab',
                      basis=bij_bas_ab)
    cphdm_ab = Tensor(np.zeros((spatial_basis_rank**2, spatial_basis_rank**2)),
                      name='ckck_ab',
                      basis=bij_bas_ab)
    cphdm_ba = Tensor(np.zeros((spatial_basis_rank**2, spatial_basis_rank**2)),
                      name='ckck_ba',
                      basis=bij_bas_ab)
    cphdm_aabb = Tensor(np.zeros(
        (2 * spatial_basis_rank**2, 2 * spatial_basis_rank**2)),
                        name='ckck_aabb')

    ctensor = MultiTensor([
        copdm_a, copdm_b, coqdm_a, coqdm_b, ctpdm_aa, ctpdm_bb, ctpdm_ab,
        ctqdm_aa, ctqdm_bb, ctqdm_ab, cphdm_ab, cphdm_ba, cphdm_aabb
    ])

    # Energy of the reference RDMs under the cost tensors; this should match
    # gs_energy.
    print(
        (ctensor.vectorize_tensors().T @ dtensor.vectorize_tensors())[0,
                                                                      0].real)
    print(gs_energy)

    ctensor.dual_basis = dual_basis
    A, _, b = ctensor.synthesize_dual_basis()
    print("size of dual basis", len(dual_basis.elements))

    # Constraint residual of the reference RDMs; this should be ~0.
    print(A @ dtensor.vectorize_tensors() - b)

    nc, nv = A.shape
    A.eliminate_zeros()
    nnz = A.nnz

    sdp = SDP()

    sdp.nc = nc
    sdp.nv = nv
    sdp.nnz = nnz
    sdp.blockstruct = list(map(lambda x: int(np.sqrt(x.size)),
                               ctensor.tensors))
    sdp.nb = len(sdp.blockstruct)
    sdp.Amat = A.real
    sdp.bvec = b.todense().real
    sdp.cvec = ctensor.vectorize_tensors().real

    sdp.Initialize()
    epsilon = 1.0E-7
    sdp.epsilon = float(epsilon)
    sdp.epsilon_inner = float(epsilon) / 100

    sdp.disp = True
    sdp.iter_max = 70000
    sdp.inner_solve = 'CG'
    sdp.inner_iter_max = 2

    solve_bpsdp(sdp)

    print(gs_energy)
    print(sdp.primal.T @ sdp.cvec)

    print(nuclear_repulsion)
    rdms = vec2block(sdp.blockstruct, sdp.primal)

    # Blocks 4-6 are cckk_aa, cckk_bb, cckk_ab (MultiTensor order above).
    tpdm = unspin_adapt(rdms[4], rdms[5], rdms[6])
    print(np.einsum('ijij', tpdm))
    return np.einsum('ijkl->ijlk', tpdm)
def sdp_nrep_reconstruction(corrupted_tpdm, num_alpha, num_beta):
    """
    Reconstruct a 2-RDM that looks like the input corrupted tpdm.

    This reconstruction scheme uses the spin-orbital reconstruction code,
    which is not the optimal-size SDP. The ``ck``, ``cckk``, ``kkcc`` and
    ``ckck`` spin-orbital constraints are combined with
    ``d2_e2_mapping_spinorbital``, which ties the SDP to ``corrupted_tpdm``.

    :param corrupted_tpdm: measured 2-RDM from the device, a 4-tensor of
        shape (n, n, n, n) over n spin orbitals
    :param num_alpha: number of alpha spin electrons
    :param num_beta: number of beta spin electrons
    :return: purified 2-RDM with shape (n, n, n, n). It keeps the input dtype,
        except that integer inputs are promoted to float so the solution is
        not truncated.
    :raises TypeError: if ``corrupted_tpdm`` is not a 4-tensor
    :raises ValueError: if its axes differ in length, or if
        ``num_alpha != num_beta``
    """
    if np.ndim(corrupted_tpdm) != 4:
        raise TypeError("corrupted_tpdm must be a 4-tensor")
    if len(set(np.shape(corrupted_tpdm))) != 1:
        raise ValueError("corrupted_tpdm must have shape (n, n, n, n)")

    if num_alpha != num_beta:
        raise ValueError(
            "right now we are not supporting differing spin numbers")

    sp_dim = corrupted_tpdm.shape[0]  # single-particle rank

    # Zero cost on every RDM block. The objective presumably comes from the
    # error-matrix tensor (confirm in spin_orbital_marginal_norm_min).
    opdm = Tensor(tensor=np.zeros((sp_dim, sp_dim)), name='ck')
    oqdm = Tensor(tensor=np.zeros((sp_dim, sp_dim)), name='kc')
    tpdm = Tensor(tensor=np.zeros_like(corrupted_tpdm), name='cckk')
    tqdm = Tensor(tensor=np.zeros_like(corrupted_tpdm), name='kkcc')
    tgdm = Tensor(tensor=np.zeros_like(corrupted_tpdm), name='ckck')
    error_matrix = spin_orbital_marginal_norm_min(sp_dim**2,
                                                  tensor_name='cckk_me')
    rdms = MultiTensor([opdm, oqdm, tpdm, tqdm, tgdm, error_matrix])

    db = spin_orbital_linear_constraints(sp_dim, num_alpha + num_beta,
                                         ['ck', 'cckk', 'kkcc', 'ckck'])
    db += d2_e2_mapping_spinorbital(sp_dim, corrupted_tpdm)

    rdms.dual_basis = db
    A, _, c = rdms.synthesize_dual_basis()
    nc, nv = A.shape
    nnz = A.nnz

    # Linear dimension of each PSD block, in MultiTensor order.
    blocklist = [
        sp_dim, sp_dim, sp_dim**2, sp_dim**2, sp_dim**2, 2 * sp_dim**2
    ]
    nb = len(blocklist)

    sdp = SDP()

    sdp.nc = nc
    sdp.nv = nv
    sdp.nnz = nnz
    sdp.blockstruct = blocklist
    sdp.nb = nb
    sdp.Amat = A.real
    sdp.bvec = c.todense().real

    sdp.cvec = rdms.vectorize_tensors().real

    sdp.Initialize()

    sdp.epsilon = float(1.0E-8)
    sdp.inner_solve = "EXACT"
    sdp.disp = True
    solve_bpsdp(sdp)

    solution_rdms = vec2block(blocklist, sdp.primal)
    # Block 2 is the (sp_dim**2, sp_dim**2) 'cckk' matrix, with row index
    # p * sp_dim + q and column index r * sp_dim + s. A C-order reshape
    # therefore gives tpdm[p, q, r, s] directly, replacing an O(n**4) Python
    # loop. result_type(..., 0.0) keeps float/complex input dtypes but
    # promotes integer inputs to float instead of truncating the solution.
    out_dtype = np.result_type(corrupted_tpdm, 0.0)
    tpdm_reconstructed = np.asarray(solution_rdms[2]).reshape(
        (sp_dim, sp_dim, sp_dim, sp_dim)).astype(out_dtype)

    return tpdm_reconstructed