Esempio n. 1
0
def KL(nn_state, target_psi, space, bases=None, **kwargs):
    r"""A function for calculating the total KL divergence.

    .. math:: KL(P_{target} \vert P_{RBM}) = \sum_{x \in \mathcal{H}} P_{target}(x)\log(\frac{P_{target}(x)}{P_{RBM}(x)})

    If ``bases`` is given (or ``target_psi`` is a dictionary, in which case its
    keys are used as the bases), the KL divergence is computed separately in
    each basis and the average over all bases is returned.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param target_psi: The true wavefunction of the system. Can be a dictionary
                       with each value being the wavefunction represented in a
                       different basis, and the key identifying the basis.
    :type target_psi: torch.Tensor or dict(str, torch.Tensor)
    :param space: The hilbert space of the system.
    :type space: torch.Tensor
    :param bases: An array of unique bases. When ``target_psi`` is a
                  dictionary, these must match its keys.
    :type bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    # Accumulator; it becomes a 0-dim tensor after the first ``+=`` below.
    KL = 0.0

    if isinstance(target_psi, dict):
        target_psi = {k: v.to(nn_state.device) for k, v in target_psi.items()}
        if bases is None:
            bases = list(target_psi.keys())
        else:
            assert set(bases) == set(target_psi.keys()), (
                "Given bases must match the keys of the target_psi dictionary.")
    else:
        target_psi = target_psi.to(nn_state.device)

    Z = nn_state.compute_normalization(space)
    if bases is None:
        # No rotation: compare |target|^2 with the model's normalized probabilities.
        target_probs = cplx.absolute_value(target_psi)**2
        nn_probs = nn_state.probability(space, Z)
        KL += torch.sum(target_probs * probs_to_logits(target_probs))
        KL -= torch.sum(target_probs * probs_to_logits(nn_probs))
    else:
        unitary_dict = nn_state.unitary_dict
        for basis in bases:
            # Model wavefunction rotated into ``basis``; the target is either
            # supplied already rotated (dict input) or rotated the same way.
            psi_r = rotate_psi(nn_state, basis, space, unitary_dict)
            if isinstance(target_psi, dict):
                target_psi_r = target_psi[basis]
            else:
                target_psi_r = rotate_psi(nn_state, basis, space, unitary_dict,
                                          target_psi)

            # Rotated amplitudes are unnormalized, hence the division by Z.
            probs_r = (cplx.absolute_value(psi_r)**2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r)**2

            KL += torch.sum(target_probs_r * probs_to_logits(target_probs_r))
            KL -= torch.sum(target_probs_r * probs_to_logits(probs_r))
        # Average over bases.
        KL /= float(len(bases))

    return KL.item()
Esempio n. 2
0
def fidelity(nn_state, target_psi, space, **kwargs):
    r"""Calculates the square of the overlap (fidelity) between the reconstructed
    wavefunction and the true wavefunction (both in the computational basis).

    .. math:: F = \vert \langle \psi_{RBM} \vert \psi_{target} \rangle \vert ^2

    The network amplitudes evaluated on ``space`` are divided by
    :math:`\sqrt{Z}`, with :math:`Z` the normalization computed over ``space``,
    before the overlap is taken.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: qucumber.nn_states.WaveFunctionBase
    :param target_psi: The true wavefunction of the system.
    :type target_psi: torch.Tensor
    :param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
                  The ordering of the basis elements must match with the ordering of the
                  coefficients given in `target_psi`.
    :type space: torch.Tensor
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The fidelity.
    :rtype: float
    """
    norm = nn_state.compute_normalization(space)
    target_on_device = target_psi.to(nn_state.device)
    normalized_psi = nn_state.psi(space) / norm.sqrt()
    overlap = cplx.inner_prod(target_on_device, normalized_psi)
    return cplx.absolute_value(overlap).pow_(2).item()
Esempio n. 3
0
    def test_absolute_value(self):
        """Check ``cplx.absolute_value`` element-wise on a 2x4 complex tensor."""
        # The first outer slice holds the real parts and the second the
        # imaginary parts, e.g. |5 + 2i| = sqrt(29). So the input encodes a
        # 2x4 complex matrix, and its absolute value is a real 2x4 matrix.
        tensor = torch.tensor(
            [[[5, 5, -5, -5], [3, 6, -9, 1]], [[2, -2, 2, -2], [-7, 8, 0, 4]]],
            dtype=torch.double,
        )

        # Shape (2, 4). The previous extra bracket level made this (1, 2, 4),
        # which only passed because broadcasting hid the shape mismatch.
        expect = torch.tensor(
            [[np.sqrt(29)] * 4, [np.sqrt(58), 10, 9, np.sqrt(17)]], dtype=torch.double
        )

        result = cplx.absolute_value(tensor)
        self.assertEqual(
            tuple(result.shape),
            tuple(expect.shape),
            msg="Absolute Value returned the wrong shape!",
        )
        self.assertTensorsAlmostEqual(
            result, expect, msg="Absolute Value failed!"
        )
Esempio n. 4
0
def fidelity(nn_state, target, space=None, **kwargs):
    r"""Calculates the square of the overlap (fidelity) between the reconstructed
    state and the true state (both in the computational basis).

    .. math::

        F = \vert \langle \psi_{RBM} \vert \psi_{target} \rangle \vert ^2
          = \left( \operatorname{tr} \lbrack \sqrt{ \sqrt{\rho_{RBM}} \rho_{target} \sqrt{\rho_{RBM}} } \rbrack \right) ^ 2

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param target: The true state of the system. A complex vector (2-dim
                   tensor) for wavefunction states, or a complex matrix
                   (3-dim tensor) for density-matrix states.
    :type target: torch.Tensor
    :param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
                  The ordering of the basis elements must match with the ordering of the
                  coefficients given in `target`. If `None`, will generate them using
                  the provided `nn_state`.
    :type space: torch.Tensor
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The fidelity.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)
    target = target.to(nn_state.device)

    if isinstance(nn_state, WaveFunctionBase):
        assert target.dim() == 2, "target must be a complex vector!"

        psi = nn_state.psi(space) / Z.sqrt()
        F = cplx.inner_prod(target, psi)
        return cplx.absolute_value(F).pow_(2).item()

    # Density-matrix state: Uhlmann fidelity computed with NumPy/SciPy.
    assert target.dim() == 3, "target must be a complex matrix!"

    rho = nn_state.rho(space, space) / Z
    rho_rbm_ = cplx.numpy(rho)
    target_ = cplx.numpy(target)

    sqrt_rho_rbm = sqrtm(rho_rbm_)
    prod = np.matmul(sqrt_rho_rbm, np.matmul(target_, sqrt_rho_rbm))

    # Instead of sqrt'ing then taking the trace, we compute the eigenvals,
    #  sqrt those, and then sum them up. This is a bit more efficient.
    eigvals = np.linalg.eigvals(prod).real  # imaginary parts should be zero
    # Round-off can produce tiny negative eigenvalues; take |.| before sqrt.
    eigvals = np.abs(eigvals)
    trace = np.sum(np.sqrt(eigvals))

    # Return a plain Python float, matching the ``.item()`` branch above.
    return float(trace**2)
Esempio n. 5
0
def NLL(nn_state, samples, space=None, sample_bases=None, **kwargs):
    r"""A function for calculating the negative log-likelihood (NLL).

    The result is the NLL averaged over the samples. When ``sample_bases`` is
    given, samples are grouped by the basis they were measured in. The model
    probabilities for each group are computed in that basis, with a rotation
    applied only when some site is not measured in ``"Z"``.

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param samples: Samples to compute the NLL on.
    :type samples: torch.Tensor
    :param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
                  If `None`, will generate them using the provided `nn_state`.
    :type space: torch.Tensor
    :param sample_bases: An array of bases where measurements were taken, one
                         row per sample.
    :type sample_bases: numpy.ndarray
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The Negative Log-Likelihood.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)

    if sample_bases is None:
        nn_probs = nn_state.probability(samples, Z)
        return -torch.mean(probs_to_logits(nn_probs)).item()

    NLL_ = 0.0

    # ``indices[k]`` is the row of ``unique_bases`` that sample k was measured
    # in. Flatten it, because NumPy 2.0.0 returned the inverse with extra
    # dimensions when ``axis`` is given.
    unique_bases, indices = np.unique(sample_bases,
                                      axis=0,
                                      return_inverse=True)
    indices = torch.as_tensor(np.asarray(indices).reshape(-1),
                              device=samples.device)

    for i, basis in enumerate(unique_bases):
        basis_samples = samples[indices == i, :]

        if np.any(basis != "Z"):
            # At least one site was measured outside Z: rotate the model.
            if isinstance(nn_state, WaveFunctionBase):
                Upsi = rotate_psi_inner_prod(nn_state, basis, basis_samples)
                nn_probs = (cplx.absolute_value(Upsi)**2) / Z
            else:
                nn_probs = rotate_rho_probs(nn_state, basis, basis_samples) / Z
        else:
            nn_probs = nn_state.probability(basis_samples, Z)

        NLL_ -= torch.sum(probs_to_logits(nn_probs))

    # NLL_ is a 0-dim tensor here; return a Python float like the branch above.
    return float(NLL_) / float(len(samples))
Esempio n. 6
0
def fidelity(nn_state, target_psi, space, **kwargs):
    r"""Return the squared overlap (fidelity) between the reconstructed
    wavefunction and the true wavefunction, both in the computational basis.

    The model amplitudes on ``space`` are normalized by the square root of
    ``nn_state.compute_normalization(space)`` before the overlap is formed.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param target_psi: The true wavefunction of the system.
    :type target_psi: torch.Tensor
    :param space: The hilbert space of the system.
    :type space: torch.Tensor
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The fidelity.
    :rtype: float
    """
    normalization = nn_state.compute_normalization(space)
    target = target_psi.to(nn_state.device)
    amplitudes = nn_state.psi(space)
    amplitudes = amplitudes / normalization.sqrt()
    magnitude = cplx.absolute_value(cplx.inner_prod(target, amplitudes))
    return magnitude.pow_(2).item()
Esempio n. 7
0
    def algorithmic_gradKL(self, target, space, all_bases, **kwargs):
        """Compute the KL gradient by explicit summation over ``space``.

        For every basis, the per-state gradients from ``self.nn_state.gradient``
        are weighted by the target's probabilities in that basis and averaged
        over all bases. The amplitude component then has the model-probability
        weighted average of the amplitude RBM's effective-energy gradients
        subtracted.

        :returns: ``[amplitude_grad, phase_grad]`` as double tensors.
        """
        device = self.nn_state.device
        grad_KL = [
            torch.zeros(rbm.num_pars, dtype=torch.double, device=device)
            for rbm in (self.nn_state.rbm_am, self.nn_state.rbm_ph)
        ]
        Z = self.nn_state.normalization(space)
        n_bases = float(len(all_bases))

        for basis in all_bases:
            # Target amplitudes in this basis: given directly, or rotated.
            if isinstance(target, dict):
                rotated_target = target[basis]
            else:
                rotated_target = rotate_psi_inner_prod(
                    self.nn_state, basis, space, psi=target
                )
            weights = cplx.absolute_value(rotated_target) ** 2

            for i, state in enumerate(space):
                rotated_grad = self.nn_state.gradient(state, basis)
                for k in (0, 1):
                    grad_KL[k] += weights[i] * rotated_grad[k] / n_bases

        probs = self.nn_state.probability(space, Z)
        all_grads = self.nn_state.rbm_am.effective_energy_gradient(space, reduce=False)
        # average the gradients, weighted by probs
        grad_KL[0] -= torch.mv(all_grads.t(), probs)

        return grad_KL
Esempio n. 8
0
def KL(nn_state, target_psi, space, bases=None, **kwargs):
    r"""A function for calculating the total KL divergence.

    .. math:: KL(P_{target} \vert P_{RBM}) = \sum_{x \in \mathcal{H}} P_{target}(x)\log(\frac{P_{target}(x)}{P_{RBM}(x)})

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: qucumber.nn_states.WaveFunctionBase
    :param target_psi: The true wavefunction of the system. Can be a dictionary
                       with each value being the wavefunction represented in a
                       different basis, and the key identifying the basis.
    :type target_psi: torch.Tensor or dict(str, torch.Tensor)
    :param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
                  The ordering of the basis elements must match with the ordering of the
                  coefficients given in `target_psi`.
    :type space: torch.Tensor
    :param bases: An array of unique bases. If given, the KL divergence will be
                  computed for each basis and the average will be returned.
                  When `target_psi` is a dictionary and `bases` is `None`, the
                  dictionary keys are used as the bases.
    :type bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    # Accumulator; it becomes a 0-dim tensor after the first ``+=`` below.
    KL = 0.0

    if isinstance(target_psi, dict):
        target_psi = {k: v.to(nn_state.device) for k, v in target_psi.items()}
        if bases is None:
            bases = list(target_psi.keys())
        else:
            assert set(bases) == set(target_psi.keys()), (
                "Given bases must match the keys of the target_psi dictionary.")
    else:
        target_psi = target_psi.to(nn_state.device)

    Z = nn_state.compute_normalization(space)
    if bases is None:
        # No rotation: compare |target|^2 with the model's normalized probabilities.
        target_probs = cplx.absolute_value(target_psi)**2
        nn_probs = nn_state.probability(space, Z)
        KL += torch.sum(target_probs * probs_to_logits(target_probs))
        KL -= torch.sum(target_probs * probs_to_logits(nn_probs))
    else:
        unitary_dict = nn_state.unitary_dict
        for basis in bases:
            # Model wavefunction rotated into ``basis``; the target is either
            # supplied already rotated (dict input) or rotated the same way.
            psi_r = rotate_psi(nn_state, basis, space, unitary_dict)
            if isinstance(target_psi, dict):
                target_psi_r = target_psi[basis]
            else:
                target_psi_r = rotate_psi(nn_state, basis, space, unitary_dict,
                                          target_psi)

            # Rotated amplitudes are unnormalized, hence the division by Z.
            probs_r = (cplx.absolute_value(psi_r)**2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r)**2

            KL += torch.sum(target_probs_r * probs_to_logits(target_probs_r))
            KL -= torch.sum(target_probs_r * probs_to_logits(probs_r))
        # Average over bases.
        KL /= float(len(bases))

    return KL.item()
Esempio n. 9
0
def KL(nn_state, target, space=None, bases=None, **kwargs):
    r"""A function for calculating the KL divergence averaged over every given
    basis.

    .. math:: KL(P_{target} \vert P_{RBM}) = -\sum_{x \in \mathcal{H}} P_{target}(x)\log(\frac{P_{RBM}(x)}{P_{target}(x)})

    For wavefunction states the target probabilities are :math:`\vert\psi(x)\vert^2`.
    For density-matrix states they are the real diagonal entries of the
    (possibly rotated) target density matrix.

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param target: The true state (wavefunction or density matrix) of the system.
                   Can be a dictionary with each value being the state
                   represented in a different basis, and the key identifying the basis.
    :type target: torch.Tensor or dict(str, torch.Tensor)
    :param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
                  The ordering of the basis elements must match with the ordering of the
                  coefficients given in `target`. If `None`, will generate them using
                  the provided `nn_state`.
    :type space: torch.Tensor
    :param bases: An array of unique bases. If given, the KL divergence will be
                  computed for each basis and the average will be returned.
    :type bases: numpy.ndarray
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)
    is_wavefunction = isinstance(nn_state, WaveFunctionBase)

    if isinstance(target, dict):
        target = {k: v.to(nn_state.device) for k, v in target.items()}
        if bases is None:
            bases = list(target.keys())
        else:
            assert set(bases) == set(target.keys()), (
                "Given bases must match the keys of the target_psi dictionary.")
    else:
        target = target.to(nn_state.device)

    KL = 0.0

    if bases is None:
        # Here ``target`` is a single tensor (a dict always sets ``bases``).
        if is_wavefunction:
            target_probs = cplx.absolute_value(target)**2
        else:
            # Density matrix: the probabilities are the diagonal of rho, not
            # |rho_ij|^2 over every entry, which would be a 2-D array.
            assert target.dim() == 3, "target must be a complex matrix!"
            target_probs = torch.diagonal(cplx.real(target))
        nn_probs = nn_state.probability(space, Z)

        KL += _single_basis_KL(target_probs, nn_probs)

    elif is_wavefunction:
        for basis in bases:
            if isinstance(target, dict):
                target_psi_r = target[basis]
                assert target_psi_r.dim() == 2, "target must be a complex vector!"
            else:
                assert target.dim() == 2, "target must be a complex vector!"
                target_psi_r = rotate_psi(nn_state, basis, space, psi=target)

            psi_r = rotate_psi(nn_state, basis, space)
            nn_probs_r = (cplx.absolute_value(psi_r)**2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r)**2

            KL += _single_basis_KL(target_probs_r, nn_probs_r)

        KL /= float(len(bases))
    else:
        for basis in bases:
            if isinstance(target, dict):
                target_rho_r = target[basis]
                assert target_rho_r.dim() == 3, "target must be a complex matrix!"
                target_probs_r = torch.diagonal(cplx.real(target_rho_r))
            else:
                assert target.dim() == 3, "target must be a complex matrix!"
                target_probs_r = rotate_rho_probs(nn_state,
                                                  basis,
                                                  space,
                                                  rho=target)

            rho_r = rotate_rho_probs(nn_state, basis, space)
            nn_probs_r = rho_r / Z

            KL += _single_basis_KL(target_probs_r, nn_probs_r)

        KL /= float(len(bases))

    return KL.item()