Example #1
    def rotated_gradient(self, basis, sample):
        r"""Computes the gradients rotated into the measurement basis

        :param basis: The basis in which the measurement is made, given as an
                      array of single-site basis labels
        :type basis: numpy.ndarray
        :param sample: The measurement outcomes (0 or 1 at each site)
        :type sample: torch.Tensor

        :returns: A list of two tensors, representing the rotated gradients
                  of the amplitude and phase RBMs
        :rtype: list[torch.Tensor, torch.Tensor]
        """
        UrhoU, UrhoU_v, v = unitaries.rotate_rho_probs(self,
                                                       basis,
                                                       sample,
                                                       include_extras=True)
        inv_UrhoU = 1 / (UrhoU + 1e-8)  # avoid dividing by zero

        raw_grads = [self.am_grads(v), self.ph_grads(v)]

        rotated_grad = [
            -cplx.einsum("ijb,ijbg->bg", UrhoU_v, g, imag_part=False)
            for g in raw_grads
        ]

        return [torch.einsum("b,bg->g", inv_UrhoU, g) for g in rotated_grad]
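For context, here is a minimal usage sketch. It assumes the method above belongs to qucumber's DensityMatrix class (the enclosing class is not shown in the snippet), and the sizes and sample values are purely illustrative.

import numpy as np
import torch
from qucumber.nn_states import DensityMatrix

# Assumed host class for the method above; a small 2-site state for illustration.
nn_state = DensityMatrix(num_visible=2, gpu=False)

basis = np.array(list("XZ"))  # one measurement-basis label per site
samples = torch.tensor([[0, 1], [1, 1]], dtype=torch.double)  # outcomes in that basis

am_grad, ph_grad = nn_state.rotated_gradient(basis, samples)
# Each result is a flattened gradient vector over the corresponding RBM's parameters.
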
Example #2
def test_rotate_rho_probs(num_visible, state_type, precompute_rho):
    nn_state = state_type(num_visible, gpu=False)
    basis = "X" * num_visible
    unitary_dict = create_dict()

    space = nn_state.generate_hilbert_space()

    # Slow path: rotate the full density matrix, then read off its diagonal
    rho = nn_state.rho(space, expand=True) if precompute_rho else None
    rho_r = rotate_rho(nn_state, basis, space, unitary_dict, rho=rho)
    rho_r_probs = torch.diagonal(cplx.real(rho_r))

    # Fast path: compute only the rotated diagonal (the measurement probabilities)
    rho_r_probs_fast = rotate_rho_probs(nn_state,
                                        basis,
                                        space,
                                        unitary_dict,
                                        rho=rho)

    # use a looser tolerance here, as this comparison sometimes just barely
    # exceeds the smaller TOL value used in test_grads.py
    assertAlmostEqual(
        rho_r_probs,
        rho_r_probs_fast,
        tol=(TOL * 10),
        msg="Fast rho probs rotation failed!",
    )
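The equivalence this test checks can be reproduced standalone, as sketched below. The import locations are assumptions (the rotation helpers and cplx utilities are not imported in the snippet itself).

import torch
from qucumber.nn_states import DensityMatrix
from qucumber.utils import cplx
from qucumber.utils.unitaries import create_dict, rotate_rho, rotate_rho_probs

nn_state = DensityMatrix(num_visible=2, gpu=False)
space = nn_state.generate_hilbert_space()
basis = "X" * 2
unitary_dict = create_dict()

# Slow path: rotate the full density matrix, then take its diagonal.
slow = torch.diagonal(cplx.real(rotate_rho(nn_state, basis, space, unitary_dict)))
# Fast path: compute the rotated diagonal (measurement probabilities) directly.
fast = rotate_rho_probs(nn_state, basis, space, unitary_dict)

print(torch.allclose(slow, fast, atol=1e-6))
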
Example #3
def NLL(nn_state, samples, space=None, sample_bases=None, **kwargs):
    r"""A function for calculating the negative log-likelihood (NLL).

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param samples: Samples to compute the NLL on.
    :type samples: torch.Tensor
    :param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
                  If `None`, will generate them using the provided `nn_state`.
    :type space: torch.Tensor
    :param sample_bases: An array of bases where measurements were taken.
    :type sample_bases: numpy.ndarray
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The Negative Log-Likelihood.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)

    if sample_bases is None:
        nn_probs = nn_state.probability(samples, Z)
        NLL_ = -torch.mean(probs_to_logits(nn_probs)).item()
        return NLL_
    else:
        NLL_ = 0.0

        unique_bases, indices = np.unique(sample_bases,
                                          axis=0,
                                          return_inverse=True)
        indices = torch.Tensor(indices).to(samples)

        for i in range(unique_bases.shape[0]):
            basis = unique_bases[i, :]
            rot_sites = np.where(basis != "Z")[0]

            if rot_sites.size != 0:
                if isinstance(nn_state, WaveFunctionBase):
                    Upsi = rotate_psi_inner_prod(nn_state, basis,
                                                 samples[indices == i, :])
                    nn_probs = (cplx.absolute_value(Upsi)**2) / Z
                else:
                    nn_probs = (rotate_rho_probs(nn_state, basis,
                                                 samples[indices == i, :]) / Z)
            else:
                nn_probs = nn_state.probability(samples[indices == i, :], Z)

            NLL_ -= torch.sum(probs_to_logits(nn_probs))

        return NLL_ / float(len(samples))
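A minimal call sketch follows, assuming this is the NLL metric from qucumber.utils.training_statistics (the module path is not shown above); the state and data are placeholders.

import torch
from qucumber.nn_states import PositiveWaveFunction
from qucumber.utils.training_statistics import NLL

nn_state = PositiveWaveFunction(num_visible=3, gpu=False)

# Placeholder Z-basis data: 100 random bit strings of length 3.
samples = torch.bernoulli(torch.full((100, 3), 0.5)).double()

print(NLL(nn_state, samples))  # sample_bases is None: single-basis path

When measurements were taken in several bases, `sample_bases` should be an array with one row per sample giving the per-site basis labels, so that each sample is rotated with the matching unitary.
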
Example #4
    def algorithmic_gradKL(self, target, space, all_bases, **kwargs):
        # Gradient accumulators for the amplitude and phase RBM parameters
        grad_KL = [
            torch.zeros(
                self.nn_state.rbm_am.num_pars,
                dtype=torch.double,
                device=self.nn_state.device,
            ),
            torch.zeros(
                self.nn_state.rbm_ph.num_pars,
                dtype=torch.double,
                device=self.nn_state.device,
            ),
        ]
        Z = self.nn_state.normalization(space)

        # Positive phase: rotated gradients weighted by the target's rotated
        # probabilities, averaged over all measurement bases
        for b in range(len(all_bases)):
            if isinstance(target, dict):
                target_r = target[all_bases[b]]
                target_r = torch.diagonal(cplx.real(target_r))
            else:
                target_r = rotate_rho_probs(
                    self.nn_state, all_bases[b], space, rho=target
                )

            for i in range(len(space)):
                rotated_grad = self.nn_state.gradient(space[i], all_bases[b])
                grad_KL[0] += target_r[i] * rotated_grad[0] / float(len(all_bases))
                grad_KL[1] += target_r[i] * rotated_grad[1] / float(len(all_bases))

        # Negative phase: subtract the model average of the amplitude-RBM
        # gradients, weighted by the model's own probabilities
        probs = self.nn_state.probability(space, Z)
        all_grads = self.nn_state.rbm_am.effective_energy_gradient(space, reduce=False)
        grad_KL[0] -= torch.mv(all_grads.t(), probs)

        return grad_KL
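Read schematically, the loop above accumulates the data-side (positive-phase) term of the KL gradient, averaged over the measurement bases, and the closing lines subtract the model-side (negative-phase) term, which only involves the amplitude-RBM parameters. In rough notation of my own (not taken from the snippet):

.. math:: \nabla_{\lambda}\,\mathrm{KL} \approx
    \frac{1}{N_{b}} \sum_{b} \sum_{x \in \mathcal{H}} P^{\,b}_{target}(x)\, G_{\lambda}^{b}(x)
    \;-\; \sum_{x \in \mathcal{H}} p_{\lambda}(x)\, \nabla_{\lambda} \mathcal{E}_{am}(x)

where :math:`G^{b}` stands for the basis-rotated gradients returned by nn_state.gradient, and the second sum corresponds to the torch.mv line acting on grad_KL[0].
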
Example #5
def KL(nn_state, target, space=None, bases=None, **kwargs):
    r"""A function for calculating the KL divergence averaged over every given
    basis.

    .. math:: KL(P_{target} \Vert P_{RBM}) = -\sum_{x \in \mathcal{H}} P_{target}(x)\log\left(\frac{P_{RBM}(x)}{P_{target}(x)}\right)

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param target: The true state (wavefunction or density matrix) of the system.
                   Can be a dictionary with each value being the state
                   represented in a different basis, and the key identifying the basis.
    :type target: torch.Tensor or dict(str, torch.Tensor)
    :param space: The basis elements of the Hilbert space of the system :math:`\mathcal{H}`.
                  The ordering of the basis elements must match with the ordering of the
                  coefficients given in `target`. If `None`, will generate them using
                  the provided `nn_state`.
    :type space: torch.Tensor
    :param bases: An array of unique bases. If given, the KL divergence will be
                  computed for each basis and the average will be returned.
    :type bases: numpy.ndarray
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)

    if isinstance(target, dict):
        target = {k: v.to(nn_state.device) for k, v in target.items()}
        if bases is None:
            bases = list(target.keys())
        else:
            assert set(bases) == set(target.keys()), (
                "Given bases must match the keys of the target dictionary."
            )
    else:
        target = target.to(nn_state.device)

    KL = 0.0

    if bases is None:
        target_probs = cplx.absolute_value(target)**2
        nn_probs = nn_state.probability(space, Z)

        KL += _single_basis_KL(target_probs, nn_probs)

    elif isinstance(nn_state, WaveFunctionBase):
        for basis in bases:
            if isinstance(target, dict):
                target_psi_r = target[basis]
                assert target_psi_r.dim() == 2, "target must be a complex vector!"
            else:
                assert target.dim() == 2, "target must be a complex vector!"
                target_psi_r = rotate_psi(nn_state, basis, space, psi=target)

            psi_r = rotate_psi(nn_state, basis, space)
            nn_probs_r = (cplx.absolute_value(psi_r)**2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r)**2

            KL += _single_basis_KL(target_probs_r, nn_probs_r)

        KL /= float(len(bases))
    else:
        for basis in bases:
            if isinstance(target, dict):
                target_rho_r = target[basis]
                assert target_rho_r.dim() == 3, "target must be a complex matrix!"
                target_probs_r = torch.diagonal(cplx.real(target_rho_r))
            else:
                assert target.dim() == 3, "target must be a complex matrix!"
                target_probs_r = rotate_rho_probs(nn_state,
                                                  basis,
                                                  space,
                                                  rho=target)

            nn_probs_r = rotate_rho_probs(nn_state, basis, space) / Z

            KL += _single_basis_KL(target_probs_r, nn_probs_r)

        KL /= float(len(bases))

    return KL.item()
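A minimal usage sketch for the single-basis path, assuming KL is importable from qucumber.utils.training_statistics (not shown above); the target here is a placeholder uniform superposition stored in the library's (2, ...) real/imaginary layout.

import torch
from qucumber.nn_states import PositiveWaveFunction
from qucumber.utils.training_statistics import KL

nn_state = PositiveWaveFunction(num_visible=2, gpu=False)
space = nn_state.generate_hilbert_space()

# Placeholder target: uniform superposition over the 4 basis states,
# stored as a (2, 4) tensor of [real; imaginary] parts.
target_psi = torch.zeros(2, 4, dtype=torch.double)
target_psi[0, :] = 0.5

print(KL(nn_state, target_psi, space=space))  # bases=None: plain Z-basis divergence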