def KL(nn_state, target_psi, space, bases=None, **kwargs):
    r"""A function for calculating the total KL divergence.

    .. math:: KL(P_{target} \vert P_{RBM}) = \sum_{x} P_{target}(x)
              \log\left(\frac{P_{target}(x)}{P_{RBM}(x)}\right)

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param target_psi: The true wavefunction of the system. Can be a
                       dictionary with each value being the wavefunction
                       represented in a different basis.
    :type target_psi: torch.Tensor or dict(str, torch.Tensor)
    :param space: The hilbert space of the system.
    :type space: torch.Tensor
    :param bases: An array of unique bases. If given (or if `target_psi` is a
                  dictionary, in which case its keys are used), the KL
                  divergence is computed in each basis and the average is
                  returned. Otherwise only the computational basis is used.
    :type bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be
                      ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    # Move the target onto the network's device; a dict target also fixes
    # which bases are evaluated.
    if isinstance(target_psi, dict):
        target_psi = {k: v.to(nn_state.device) for k, v in target_psi.items()}
        if bases is None:
            bases = list(target_psi.keys())
        else:
            assert set(bases) == set(
                target_psi.keys()
            ), "Given bases must match the keys of the target_psi dictionary."
    else:
        target_psi = target_psi.to(nn_state.device)

    Z = nn_state.compute_normalization(space)
    kl_div = 0.0

    if bases is None:
        # Computational basis only: KL = sum t*log(t) - sum t*log(p).
        target_probs = cplx.absolute_value(target_psi) ** 2
        nn_probs = nn_state.probability(space, Z)
        kl_div += torch.sum(target_probs * probs_to_logits(target_probs))
        kl_div -= torch.sum(target_probs * probs_to_logits(nn_probs))
    else:
        unitary_dict = nn_state.unitary_dict
        for basis in bases:
            psi_r = rotate_psi(nn_state, basis, space, unitary_dict)
            if isinstance(target_psi, dict):
                # The dict already stores the target expressed in `basis`.
                target_psi_r = target_psi[basis]
            else:
                target_psi_r = rotate_psi(
                    nn_state, basis, space, unitary_dict, target_psi
                )
            # The rotated model amplitudes are unnormalized; divide by Z.
            probs_r = (cplx.absolute_value(psi_r) ** 2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r) ** 2
            kl_div += torch.sum(target_probs_r * probs_to_logits(target_probs_r))
            kl_div -= torch.sum(target_probs_r * probs_to_logits(probs_r))
        # Average over all evaluated bases.
        kl_div /= float(len(bases))

    return kl_div.item()
def fidelity(nn_state, target_psi, space, **kwargs):
    r"""Return the squared overlap of the model and target wavefunctions.

    Both wavefunctions are taken in the computational basis; the model
    wavefunction is normalized over `space` before the overlap is formed.

    .. math:: F = \vert \langle \psi_{RBM} \vert \psi_{target} \rangle \vert ^2

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: qucumber.nn_states.WaveFunctionBase
    :param target_psi: The true wavefunction of the system.
    :type target_psi: torch.Tensor
    :param space: The basis elements of the Hilbert space of the system
                  :math:`\mathcal{H}`. The ordering of the basis elements
                  must match with the ordering of the coefficients given in
                  `target_psi`.
    :type space: torch.Tensor
    :param \**kwargs: Extra keyword arguments that may be passed. Will be
                      ignored.

    :returns: The fidelity.
    :rtype: float
    """
    norm = nn_state.compute_normalization(space)
    target_on_device = target_psi.to(nn_state.device)

    # Dividing by sqrt(Z) normalizes the model amplitudes.
    scale = norm.sqrt()
    model_psi = nn_state.psi(space) / scale

    overlap = cplx.inner_prod(target_on_device, model_psi)
    magnitude = cplx.absolute_value(overlap)
    return magnitude.pow_(2).item()
def test_absolute_value(self):
    """Check cplx.absolute_value against hand-computed complex moduli."""
    # Complex values are stored as a stacked pair: entry [r][c] of tensor[0]
    # is combined with entry [r][c] of tensor[1] (presumably real/imaginary
    # parts), giving a 2x4 grid of complex numbers.
    tensor = torch.tensor(
        [[[5, 5, -5, -5], [3, 6, -9, 1]], [[2, -2, 2, -2], [-7, 8, 0, 4]]],
        dtype=torch.double,
    )

    # |+-5 +- 2i| = sqrt(29); |3 - 7i| = sqrt(58), |6 + 8i| = 10,
    # |-9 + 0i| = 9, |1 + 4i| = sqrt(17). One modulus per complex entry, so
    # the expected tensor is 2x4 (the previous version had a spurious
    # leading axis of size 1 and only matched via broadcasting).
    expect = torch.tensor(
        [[np.sqrt(29)] * 4, [np.sqrt(58), 10, 9, np.sqrt(17)]],
        dtype=torch.double,
    )

    self.assertTensorsAlmostEqual(
        cplx.absolute_value(tensor), expect, msg="Absolute Value failed!"
    )
def fidelity(nn_state, target, space=None, **kwargs):
    r"""Calculates the fidelity between the reconstructed state and the true
    state (both in the computational basis).

    For a wavefunction `nn_state` this is the squared overlap; otherwise the
    density-matrix form below is used.

    .. math::

        F = \vert \langle \psi_{RBM} \vert \psi_{target} \rangle \vert ^2
          = \left( \operatorname{tr} \left\lbrack \sqrt{ \sqrt{\rho_{RBM}}
            \rho_{target} \sqrt{\rho_{RBM}} } \right\rbrack \right) ^ 2

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param target: The true state of the system. Must be a complex vector
                   (``dim() == 2``) for wavefunctions and a complex matrix
                   (``dim() == 3``) otherwise.
    :type target: torch.Tensor
    :param space: The basis elements of the Hilbert space of the system
                  :math:`\mathcal{H}`. The ordering of the basis elements
                  must match with the ordering of the coefficients given in
                  `target`. If `None`, will generate them using the
                  provided `nn_state`.
    :type space: torch.Tensor
    :param \**kwargs: Extra keyword arguments that may be passed. Will be
                      ignored.

    :returns: The fidelity.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)
    target = target.to(nn_state.device)

    if isinstance(nn_state, WaveFunctionBase):
        assert target.dim() == 2, "target must be a complex vector!"

        psi = nn_state.psi(space) / Z.sqrt()
        F = cplx.inner_prod(target, psi)
        return cplx.absolute_value(F).pow_(2).item()

    assert target.dim() == 3, "target must be a complex matrix!"

    rho = nn_state.rho(space, space) / Z
    rho_rbm_ = cplx.numpy(rho)
    target_ = cplx.numpy(target)

    sqrt_rho_rbm = sqrtm(rho_rbm_)
    prod = np.matmul(sqrt_rho_rbm, np.matmul(target_, sqrt_rho_rbm))

    # Instead of sqrt'ing then taking the trace, we compute the eigenvals,
    # sqrt those, and then sum them up. This is a bit more efficient.
    eigvals = np.linalg.eigvals(prod).real  # imaginary parts should be zero
    # abs() guards against tiny negative eigenvalues from round-off before
    # the square root.
    eigvals = np.abs(eigvals)
    trace = np.sum(np.sqrt(eigvals))

    # float() so both branches return a plain Python float.
    return float(trace ** 2)
def NLL(nn_state, samples, space=None, sample_bases=None, **kwargs):
    r"""A function for calculating the negative log-likelihood (NLL).

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param samples: Samples to compute the NLL on.
    :type samples: torch.Tensor
    :param space: The basis elements of the Hilbert space of the system
                  :math:`\mathcal{H}`. If `None`, will generate them using
                  the provided `nn_state`.
    :type space: torch.Tensor
    :param sample_bases: An array of bases where measurements were taken,
                         one row per row of `samples`. If `None`, all
                         samples are treated as unrotated.
    :type sample_bases: numpy.ndarray
    :param \**kwargs: Extra keyword arguments that may be passed. Will be
                      ignored.

    :returns: The Negative Log-Likelihood, averaged over `samples`.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)

    if sample_bases is None:
        nn_probs = nn_state.probability(samples, Z)
        return -torch.mean(probs_to_logits(nn_probs)).item()

    NLL_ = 0.0
    # Group samples by their measurement basis. ravel() keeps the inverse
    # indices 1-D (NumPy 2.0.0 returned a multi-dimensional inverse when
    # `axis` was given, which would break the row mask below).
    unique_bases, indices = np.unique(sample_bases, axis=0, return_inverse=True)
    indices = torch.Tensor(np.ravel(indices)).to(samples)

    for i in range(unique_bases.shape[0]):
        basis = unique_bases[i, :]
        basis_samples = samples[indices == i, :]
        rot_sites = np.where(basis != "Z")[0]

        if rot_sites.size != 0:
            # At least one site measured in a non-"Z" basis: rotate first.
            if isinstance(nn_state, WaveFunctionBase):
                Upsi = rotate_psi_inner_prod(nn_state, basis, basis_samples)
                nn_probs = (cplx.absolute_value(Upsi) ** 2) / Z
            else:
                nn_probs = rotate_rho_probs(nn_state, basis, basis_samples) / Z
        else:
            nn_probs = nn_state.probability(basis_samples, Z)

        NLL_ -= torch.sum(probs_to_logits(nn_probs))

    # float() so this branch returns a plain float like the branch above and
    # the documented rtype, rather than a 0-dim tensor.
    return float(NLL_) / float(len(samples))
def fidelity(nn_state, target_psi, space, **kwargs):
    r"""Compute the squared overlap between model and target wavefunctions.

    Both wavefunctions are expressed in the computational basis; the model
    wavefunction is normalized over `space` before the overlap is taken.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param target_psi: The true wavefunction of the system.
    :type target_psi: torch.Tensor
    :param space: The hilbert space of the system.
    :type space: torch.Tensor
    :param \**kwargs: Extra keyword arguments that may be passed. Will be
                      ignored.

    :returns: The fidelity.
    :rtype: float
    """
    partition = nn_state.compute_normalization(space)
    reference = target_psi.to(nn_state.device)
    normalized_psi = nn_state.psi(space) / partition.sqrt()
    amplitude = cplx.absolute_value(cplx.inner_prod(reference, normalized_psi))
    return amplitude.pow_(2).item()
def algorithmic_gradKL(self, target, space, all_bases, **kwargs):
    """Compute the KL gradient by explicitly enumerating every state in `space`.

    Returns a list ``[amplitude_grad, phase_grad]`` whose lengths match
    ``rbm_am.num_pars`` and ``rbm_ph.num_pars``. The data term is averaged
    over `all_bases`; the model (negative-phase) term is only subtracted
    from the amplitude gradient.
    """
    nn_state = self.nn_state
    amp_grad = torch.zeros(
        nn_state.rbm_am.num_pars, dtype=torch.double, device=nn_state.device
    )
    phase_grad = torch.zeros(
        nn_state.rbm_ph.num_pars, dtype=torch.double, device=nn_state.device
    )
    grad_KL = [amp_grad, phase_grad]

    Z = nn_state.normalization(space)
    n_bases = float(len(all_bases))

    for basis in all_bases:
        # Target probabilities expressed in the current basis.
        if isinstance(target, dict):
            target_r = target[basis]
        else:
            target_r = rotate_psi_inner_prod(nn_state, basis, space, psi=target)
        weights = cplx.absolute_value(target_r) ** 2

        for i, state in enumerate(space):
            rotated_grad = nn_state.gradient(state, basis)
            grad_KL[0] += weights[i] * rotated_grad[0] / n_bases
            grad_KL[1] += weights[i] * rotated_grad[1] / n_bases

    probs = nn_state.probability(space, Z)
    all_grads = nn_state.rbm_am.effective_energy_gradient(space, reduce=False)
    # Subtract the probability-weighted sum of per-state gradients
    # (a matrix-vector product over all states).
    grad_KL[0] -= torch.mv(all_grads.t(), probs)

    return grad_KL
def KL(nn_state, target_psi, space, bases=None, **kwargs):
    r"""A function for calculating the total KL divergence.

    .. math:: KL(P_{target} \vert P_{RBM}) = \sum_{x \in \mathcal{H}}
              P_{target}(x)\log\left(\frac{P_{target}(x)}{P_{RBM}(x)}\right)

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: qucumber.nn_states.WaveFunctionBase
    :param target_psi: The true wavefunction of the system. Can be a
                       dictionary with each value being the wavefunction
                       represented in a different basis, and the key
                       identifying the basis.
    :type target_psi: torch.Tensor or dict(str, torch.Tensor)
    :param space: The basis elements of the Hilbert space of the system
                  :math:`\mathcal{H}`. The ordering of the basis elements
                  must match with the ordering of the coefficients given in
                  `target_psi`.
    :type space: torch.Tensor
    :param bases: An array of unique bases. If given, the KL divergence will
                  be computed for each basis and the average will be
                  returned.
    :type bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be
                      ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    # Move the target onto the network's device; a dict target also fixes
    # which bases are evaluated.
    if isinstance(target_psi, dict):
        target_psi = {k: v.to(nn_state.device) for k, v in target_psi.items()}
        if bases is None:
            bases = list(target_psi.keys())
        else:
            assert set(bases) == set(
                target_psi.keys()
            ), "Given bases must match the keys of the target_psi dictionary."
    else:
        target_psi = target_psi.to(nn_state.device)

    Z = nn_state.compute_normalization(space)
    kl_div = 0.0

    if bases is None:
        # Computational basis only: KL = sum t*log(t) - sum t*log(p).
        target_probs = cplx.absolute_value(target_psi) ** 2
        nn_probs = nn_state.probability(space, Z)
        kl_div += torch.sum(target_probs * probs_to_logits(target_probs))
        kl_div -= torch.sum(target_probs * probs_to_logits(nn_probs))
    else:
        unitary_dict = nn_state.unitary_dict
        for basis in bases:
            psi_r = rotate_psi(nn_state, basis, space, unitary_dict)
            if isinstance(target_psi, dict):
                # The dict already stores the target expressed in `basis`.
                target_psi_r = target_psi[basis]
            else:
                target_psi_r = rotate_psi(
                    nn_state, basis, space, unitary_dict, target_psi
                )
            # The rotated model amplitudes are unnormalized; divide by Z.
            probs_r = (cplx.absolute_value(psi_r) ** 2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r) ** 2
            kl_div += torch.sum(target_probs_r * probs_to_logits(target_probs_r))
            kl_div -= torch.sum(target_probs_r * probs_to_logits(probs_r))
        # Average over all evaluated bases.
        kl_div /= float(len(bases))

    return kl_div.item()
def KL(nn_state, target, space=None, bases=None, **kwargs):
    r"""A function for calculating the KL divergence averaged over every given
    basis.

    .. math:: KL(P_{target} \vert P_{RBM}) = -\sum_{x \in \mathcal{H}}
              P_{target}(x)\log(\frac{P_{RBM}(x)}{P_{target}(x)})

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param target: The true state (wavefunction or density matrix) of the
                   system. Can be a dictionary with each value being the
                   state represented in a different basis, and the key
                   identifying the basis. Complex vectors have
                   ``dim() == 2``; complex (density) matrices have
                   ``dim() == 3``.
    :type target: torch.Tensor or dict(str, torch.Tensor)
    :param space: The basis elements of the Hilbert space of the system
                  :math:`\mathcal{H}`. The ordering of the basis elements
                  must match with the ordering of the coefficients given in
                  `target`. If `None`, will generate them using the
                  provided `nn_state`.
    :type space: torch.Tensor
    :param bases: An array of unique bases. If given, the KL divergence will
                  be computed for each basis and the average will be
                  returned.
    :type bases: numpy.ndarray
    :param \**kwargs: Extra keyword arguments that may be passed. Will be
                      ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)

    if isinstance(target, dict):
        target = {k: v.to(nn_state.device) for k, v in target.items()}
        if bases is None:
            bases = list(target.keys())
        else:
            assert set(bases) == set(
                target.keys()
            ), "Given bases must match the keys of the target dictionary."
    else:
        target = target.to(nn_state.device)

    kl_div = 0.0

    if bases is None:
        # Only reachable with a tensor target (a dict target sets `bases`).
        if target.dim() == 3:
            # Density matrix: computational-basis probabilities are its real
            # diagonal, matching the dict branch below; |rho_ij|^2 over the
            # whole matrix would be a 2-D array of the wrong quantity.
            target_probs = torch.diagonal(cplx.real(target))
        else:
            target_probs = cplx.absolute_value(target) ** 2
        nn_probs = nn_state.probability(space, Z)
        kl_div += _single_basis_KL(target_probs, nn_probs)

    elif isinstance(nn_state, WaveFunctionBase):
        for basis in bases:
            if isinstance(target, dict):
                target_psi_r = target[basis]
                assert target_psi_r.dim() == 2, "target must be a complex vector!"
            else:
                assert target.dim() == 2, "target must be a complex vector!"
                target_psi_r = rotate_psi(nn_state, basis, space, psi=target)

            psi_r = rotate_psi(nn_state, basis, space)
            # The rotated model amplitudes are unnormalized; divide by Z.
            nn_probs_r = (cplx.absolute_value(psi_r) ** 2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r) ** 2

            kl_div += _single_basis_KL(target_probs_r, nn_probs_r)

        kl_div /= float(len(bases))

    else:
        for basis in bases:
            if isinstance(target, dict):
                target_rho_r = target[basis]
                assert target_rho_r.dim() == 3, "target must be a complex matrix!"
                target_probs_r = torch.diagonal(cplx.real(target_rho_r))
            else:
                assert target.dim() == 3, "target must be a complex matrix!"
                target_probs_r = rotate_rho_probs(nn_state, basis, space, rho=target)

            nn_probs_r = rotate_rho_probs(nn_state, basis, space) / Z

            kl_div += _single_basis_KL(target_probs_r, nn_probs_r)

        kl_div /= float(len(bases))

    return kl_div.item()