def test_rotate_psi_inner_prod(num_visible, state_type, precompute_psi):
    """Check that ``rotate_psi_inner_prod`` agrees with ``rotate_psi``.

    Both routines rotate the state into the all-``X`` basis. When
    ``precompute_psi`` is truthy the amplitudes are computed once and passed
    to both routines; otherwise ``psi=None`` is passed to both.
    """
    state = state_type(num_visible, gpu=False)
    all_x_basis = "X" * num_visible
    unitaries = create_dict()

    hilbert_space = state.generate_hilbert_space()

    amplitudes = None
    if precompute_psi:
        amplitudes = state.psi(hilbert_space)

    reference = rotate_psi(
        state, all_x_basis, hilbert_space, unitaries, psi=amplitudes
    )
    candidate = rotate_psi_inner_prod(
        state, all_x_basis, hilbert_space, unitaries, psi=amplitudes
    )

    assertAlmostEqual(
        reference, candidate, msg="Fast psi inner product rotation failed!"
    )


def test_rotate_psi(num_visible, wvfn_type):
    """Compare ``rotate_psi`` against an explicit full-unitary rotation.

    The reference result builds the complete rotation operator for the
    all-``X`` basis by folding ``cplx.kronecker_prod`` over the per-site
    unitaries, then applies it to the amplitudes with ``cplx.matmul``.
    """
    state = wvfn_type(num_visible, gpu=False)
    all_x_basis = "X" * num_visible
    unitaries = create_dict()

    hilbert_space = state.generate_hilbert_space()
    amplitudes = state.psi(hilbert_space)

    fast_result = rotate_psi(
        state, all_x_basis, hilbert_space, unitaries, psi=amplitudes
    )

    per_site = [unitaries[site_basis] for site_basis in all_x_basis]
    full_unitary = reduce(cplx.kronecker_prod, per_site)
    reference_result = cplx.matmul(full_unitary, amplitudes)

    assertAlmostEqual(
        fast_result, reference_result, msg="Fast psi rotation failed!"
    )


def KL(nn_state, target, space=None, bases=None, **kwargs):
    r"""A function for calculating the KL divergence averaged over every given
    basis.

    .. math:: KL(P_{target} \vert P_{RBM}) = -\sum_{x \in \mathcal{H}} P_{target}(x)\log(\frac{P_{RBM}(x)}{P_{target}(x)})

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param target: The true state (wavefunction or density matrix) of the system.
                   Can be a dictionary with each value being the state
                   represented in a different basis, and the key identifying the basis.
                   A wavefunction is expected as a complex vector (``dim() == 2``)
                   and a density matrix as a complex matrix (``dim() == 3``).
    :type target: torch.Tensor or dict(str, torch.Tensor)
    :param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
                  The ordering of the basis elements must match with the ordering of the
                  coefficients given in `target`. If `None`, will generate them using
                  the provided `nn_state`.
    :type space: torch.Tensor
    :param bases: An array of unique bases. If given, the KL divergence will be
                  computed for each basis and the average will be returned.
                  If `None` and `target` is a dictionary, its keys are used.
                  If `None` and `target` is a tensor, the KL divergence is
                  computed only in the basis `target` is expressed in.
    :type bases: numpy.ndarray
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    :raises ValueError: If there are no bases to average over (an empty
                        `bases`, or an empty dictionary `target`).
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)

    if isinstance(target, dict):
        target = {k: v.to(nn_state.device) for k, v in target.items()}
        if bases is None:
            bases = list(target.keys())
        else:
            assert set(bases) == set(
                target.keys()
            ), "Given bases must match the keys of the target_psi dictionary."
    else:
        target = target.to(nn_state.device)

    if bases is None:
        # Single comparison in the basis `target` is written in.
        # For a density matrix (complex matrix, dim 3) the measurement
        # probabilities sit on the diagonal of its real part, exactly as in
        # the per-basis density-matrix branch below. Squaring the magnitude of
        # every matrix element would not give a probability vector.
        if target.dim() == 3:
            target_probs = torch.diagonal(cplx.real(target))
        else:
            target_probs = cplx.absolute_value(target) ** 2
        nn_probs = nn_state.probability(space, Z)
        return _single_basis_KL(target_probs, nn_probs).item()

    if len(bases) == 0:
        # Averaging over zero bases would otherwise fail with a bare
        # ZeroDivisionError further down.
        raise ValueError(
            "At least one basis is required to compute the KL divergence."
        )

    # Accumulator is named `kl` so it does not shadow this function's name.
    kl = 0.0

    if isinstance(nn_state, WaveFunctionBase):
        for basis in bases:
            if isinstance(target, dict):
                target_psi_r = target[basis]
                assert target_psi_r.dim() == 2, "target must be a complex vector!"
            else:
                assert target.dim() == 2, "target must be a complex vector!"
                target_psi_r = rotate_psi(nn_state, basis, space, psi=target)

            psi_r = rotate_psi(nn_state, basis, space)
            nn_probs_r = (cplx.absolute_value(psi_r) ** 2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r) ** 2

            kl += _single_basis_KL(target_probs_r, nn_probs_r)
    else:
        for basis in bases:
            if isinstance(target, dict):
                target_rho_r = target[basis]
                assert target_rho_r.dim() == 3, "target must be a complex matrix!"
                # Diagonal of the real part = measurement probabilities.
                target_probs_r = torch.diagonal(cplx.real(target_rho_r))
            else:
                assert target.dim() == 3, "target must be a complex matrix!"
                target_probs_r = rotate_rho_probs(
                    nn_state, basis, space, rho=target
                )

            rho_r = rotate_rho_probs(nn_state, basis, space)
            nn_probs_r = rho_r / Z

            kl += _single_basis_KL(target_probs_r, nn_probs_r)

    kl /= float(len(bases))

    return kl.item()