def _compute_num_cnots(U):
    r"""Return how many CNOTs (0, 1, 2 or 3) are needed to implement ``U``.

    The classification inspects the trace and the eigenvalues of

    .. math::

        \gamma(U) = (E^\dagger U E) (E^\dagger U E)^T,

    following the arguments of https://arxiv.org/abs/quant-ph/0308045.

    Args:
        U: matrix of the two-qubit operation, expected to be in SU(4).

    Returns:
        int: the number of CNOTs required, between 0 and 3.
    """
    # An absolute tolerance of about 1e-7 is needed so that matrices which were
    # specified with only 8 decimal places are still classified correctly.
    tol = 1e-7

    conjugated = math.dot(Edag, math.dot(U, E))
    gamma = math.dot(conjugated, math.T(conjugated))
    gamma_trace = math.trace(gamma)

    # 0 CNOTs (U is a tensor product): the trace equals +4 or -4.
    if any(math.allclose(gamma_trace, target, atol=tol) for target in (4, -4)):
        return 0

    # A vanishing trace alone cannot separate the 1-CNOT case from some special
    # 2-CNOT cases, so the eigenvalues of gamma are inspected as well.
    imag_parts = math.sort(math.imag(math.linalg.eigvals(gamma)))

    # 1 CNOT: the trace is 0 and the eigenvalues of gamma are [-1j, -1j, 1j, 1j].
    trace_vanishes = math.allclose(gamma_trace, 0j, atol=tol)
    if trace_vanishes and math.allclose(imag_parts, [-1, -1, 1, 1]):
        return 1

    # 2 CNOTs: the trace is purely real (possibly 0).
    # 3 CNOTs: anything else, i.e. the trace has a non-zero imaginary part.
    trace_is_real = math.allclose(math.imag(gamma_trace), 0.0, atol=tol)
    return 2 if trace_is_real else 3
def _extract_su2su2_prefactors(U, V):
    r"""Find single-qubit factors A, B, C, D with (A \otimes B) V (C \otimes D) = U.

    This routine serves the 2-CNOT and 3-CNOT decompositions. It is similar in
    spirit to the 1-CNOT case, except that neither of the SO(4) operations has
    a special form.

    Given SU(4) matrices U and V for which A, B, C, D in SU(2) exist such that
    (A \otimes B) V (C \otimes D) = U, the goal is to recover A, B, C, D in an
    analytic and fully differentiable manner.

    Such factors exist when U and V lie in the same double coset of SU(4),
    i.e. when there are G, H in SO(4) with G (Edag V E) H = (Edag U E). The way
    V is built in both _decomposition_2_cnots and _decomposition_3_cnots
    guarantees this.

    Since E SO(4) Edag lands in SU(2) x SU(2), conjugating G and H with E
    yields A \otimes B and C \otimes D.

    Args:
        U: the target SU(4) matrix.
        V: an SU(4) matrix in the same double coset as ``U``.

    Returns:
        tuple: the four matrices ``(A, B, C, D)``.
    """

    def _so4_eigenbasis(gamma):
        # gamma is complex and symmetric, so its real and imaginary parts share
        # a set of real eigenvectors, which are also eigenvectors of gamma
        # itself. Running eigh on their sum returns those eigenvectors in a
        # fixed order, so the bases obtained for U and V are "in the same
        # order".
        _, vecs = math.linalg.eigh(math.real(gamma) + math.imag(gamma))

        # A negative determinant means the eigenvector matrix is in O(4) but
        # not in SO(4). Scaling the last column by sign(det) negates that
        # column in this case and brings the matrix into SO(4).
        column_fix = math.diag([1, 1, 1, math.sign(math.linalg.det(vecs))])
        return math.dot(vecs, column_fix)

    def _leave_magic_basis(rotation):
        # Conjugate with the entangler: E SO(4) E^\dagger is in SU(2) x SU(2).
        left = math.cast_like(E, rotation)
        right = math.cast_like(Edag, rotation)
        return math.dot(left, math.dot(rotation, right))

    # Most of the work is done in the magic basis. Rather than looking at some
    # U = G V H directly, we consider
    #     E^\dagger U E = G E^\dagger V E H
    # from which we recover
    #     U = (E G E^\dagger) V (E H E^\dagger) = (A \otimes B) V (C \otimes D).
    # The paper shows that with this definition, functions of U and V can be
    # diagonalized simultaneously, which both confirms they are in the same
    # coset and yields the decomposition.
    entangler = math.cast_like(E, V)
    entangler_dag = math.cast_like(Edag, V)
    u_magic = math.dot(entangler_dag, math.dot(U, entangler))
    v_magic = math.dot(entangler_dag, math.dot(V, entangler))

    # p and q in SO(4) diagonalize u u^T and v v^T respectively. A simultaneous
    # diagonalization exists because of how U and V were constructed.
    p_basis = _so4_eigenbasis(math.dot(u_magic, math.T(u_magic)))
    q_basis = _so4_eigenbasis(math.dot(v_magic, math.T(v_magic)))

    # With p, q in SO(4) such that p^T u u^T p = q^T v v^T q, we have
    #     (v^\dag q p^T u)(v^\dag q p^T u)^T = I,
    # so G = p q^T and H = v^\dag q p^T u satisfy G v H = u.
    left_rotation = math.dot(math.cast_like(p_basis, 1j), math.T(q_basis))
    v_dagger = math.conj(math.T(v_magic))
    right_rotation = math.dot(v_dagger, math.dot(math.T(left_rotation), u_magic))

    # G and H still live in SO(4). Conjugating them with the entangler gives
    #     U = (E G E^\dagger) V (E H E^\dagger) = (A \otimes B) V (C \otimes D),
    # and each conjugated matrix is then split into its tensor factors.
    A, B = _su2su2_to_tensor_products(_leave_magic_basis(left_rotation))
    C, D = _su2su2_to_tensor_products(_leave_magic_basis(right_rotation))

    return A, B, C, D