def rotated_gradient(self, basis, sample):
    r"""Compute the amplitude and phase gradients rotated into a measurement basis.

    The raw per-configuration gradients of both RBMs are contracted with the
    rotated density-matrix terms returned by ``unitaries.rotate_rho_probs``.
    Each result is then divided by the corresponding rotated probability and
    summed over the batch dimension.

    :param basis: The bases in which the measurement is made.
    :type basis: numpy.ndarray
    :param sample: The measured configuration(s). Presumably a batch of 0/1
                   samples, since the final contraction sums over a batch
                   index.
    :type sample: torch.Tensor

    :returns: A list of two tensors: the rotated gradients of the amplitude
              RBM and of the phase RBM, in that order.
    :rtype: list[torch.Tensor, torch.Tensor]
    """
    rotated_probs, rotated_rho_v, configs = unitaries.rotate_rho_probs(
        self, basis, sample, include_extras=True
    )
    # The tiny offset keeps the reciprocal finite when a rotated probability is 0.
    reciprocal = 1 / (rotated_probs + 1e-8)

    # Both raw gradients are evaluated up front (amplitude first, then phase).
    per_rbm = (self.am_grads(configs), self.ph_grads(configs))

    result = []
    for raw in per_rbm:
        contracted = -cplx.einsum(
            "ijb,ijbg->bg", rotated_rho_v, raw, imag_part=False
        )
        # Weight each batch entry by 1/prob and sum over the batch.
        result.append(torch.einsum("b,bg->g", reciprocal, contracted))
    return result
def test_rotate_rho_probs(num_visible, state_type, precompute_rho):
    """The fast diagonal-only rotation must match the diagonal of the full rotation."""
    nn_state = state_type(num_visible, gpu=False)
    basis = "X" * num_visible
    unitary_dict = create_dict()
    space = nn_state.generate_hilbert_space()

    rho = None
    if precompute_rho:
        rho = nn_state.rho(space, expand=True)

    full_rotation = rotate_rho(nn_state, basis, space, unitary_dict, rho=rho)
    expected_probs = torch.diagonal(cplx.real(full_rotation))
    fast_probs = rotate_rho_probs(nn_state, basis, space, unitary_dict, rho=rho)

    # A looser tolerance than test_grads.py is used on purpose: this
    # comparison occasionally lands just outside the smaller TOL value.
    assertAlmostEqual(
        expected_probs,
        fast_probs,
        tol=(TOL * 10),
        msg="Fast rho probs rotation failed!",
    )
def NLL(nn_state, samples, space=None, sample_bases=None, **kwargs):
    r"""A function for calculating the negative log-likelihood (NLL).

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param samples: Samples to compute the NLL on.
    :type samples: torch.Tensor
    :param space: The basis elements of the Hilbert space of the system
                  :math:`\mathcal{H}`. If `None`, will generate them using
                  the provided `nn_state`.
    :type space: torch.Tensor
    :param sample_bases: An array of bases where measurements were taken,
                         one row per sample. If `None`, all samples are
                         treated as computational-basis measurements.
    :type sample_bases: numpy.ndarray
    :param \**kwargs: Extra keyword arguments that may be passed.
                      Will be ignored.

    :returns: The Negative Log-Likelihood, averaged over the samples.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)

    if sample_bases is None:
        nn_probs = nn_state.probability(samples, Z)
        return -torch.mean(probs_to_logits(nn_probs)).item()

    NLL_ = 0.0

    unique_bases, indices = np.unique(sample_bases, axis=0, return_inverse=True)
    # Keep the inverse as a flat integer tensor on the samples' device.
    # NumPy 2.0.0 returned it with an extra dimension when `axis` was given,
    # and torch.Tensor() would have converted it to float32.
    indices = torch.as_tensor(
        np.asarray(indices).reshape(-1), device=samples.device
    )

    for i, basis in enumerate(unique_bases):
        basis_samples = samples[indices == i, :]
        rot_sites = np.where(basis != "Z")[0]

        if rot_sites.size != 0:
            if isinstance(nn_state, WaveFunctionBase):
                Upsi = rotate_psi_inner_prod(nn_state, basis, basis_samples)
                nn_probs = (cplx.absolute_value(Upsi) ** 2) / Z
            else:
                nn_probs = rotate_rho_probs(nn_state, basis, basis_samples) / Z
        else:
            # Pure Z basis: no rotation needed.
            nn_probs = nn_state.probability(basis_samples, Z)

        NLL_ -= torch.sum(probs_to_logits(nn_probs))

    # Return a plain Python float, as documented and as the single-basis
    # branch does. This also avoids handing back a tensor that may keep an
    # autograd graph alive.
    return float(NLL_ / float(len(samples)))
def algorithmic_gradKL(self, target, space, all_bases, **kwargs):
    r"""Compute the KL-divergence gradient for the amplitude and phase RBMs.

    For every basis in ``all_bases``, the target's rotated probabilities are
    obtained. If ``target`` is a dict, they are the real diagonal of the
    stored matrix for that basis; otherwise they come from
    ``rotate_rho_probs(..., rho=target)``. The rotated gradient of each
    Hilbert-space element is weighted by that probability and divided by the
    number of bases. Finally, the ``probability``-weighted effective-energy
    gradient of the amplitude RBM is subtracted from the amplitude entry.

    :param target: Either a dict keyed by basis whose values are complex
                   matrices in that basis, or a single complex matrix passed
                   as ``rho`` to ``rotate_rho_probs``.
    :param space: The Hilbert-space basis elements.
    :param all_bases: The bases to average over.
    :param \**kwargs: Ignored.
    :returns: ``[amplitude_grad, phase_grad]`` as double-precision tensors
              on ``self.nn_state.device``.
    """
    grad_KL = [
        torch.zeros(rbm.num_pars, dtype=torch.double, device=self.nn_state.device)
        for rbm in (self.nn_state.rbm_am, self.nn_state.rbm_ph)
    ]
    Z = self.nn_state.normalization(space)

    for basis in all_bases:
        if isinstance(target, dict):
            target_r = torch.diagonal(cplx.real(target[basis]))
        else:
            target_r = rotate_rho_probs(self.nn_state, basis, space, rho=target)

        num_bases = float(len(all_bases))
        for i, config in enumerate(space):
            rotated_grad = self.nn_state.gradient(config, basis)
            grad_KL[0] += target_r[i] * rotated_grad[0] / num_bases
            grad_KL[1] += target_r[i] * rotated_grad[1] / num_bases

    probs = self.nn_state.probability(space, Z)
    all_grads = self.nn_state.rbm_am.effective_energy_gradient(space, reduce=False)
    # Average the per-configuration gradients, weighted by the model probabilities.
    grad_KL[0] -= torch.mv(all_grads.t(), probs)
    return grad_KL
def KL(nn_state, target, space=None, bases=None, **kwargs):
    r"""A function for calculating the KL divergence averaged over every given basis.

    .. math:: KL(P_{target} \vert P_{RBM}) = -\sum_{x \in \mathcal{H}}
                P_{target}(x)\log(\frac{P_{RBM}(x)}{P_{target}(x)})

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param target: The true state (wavefunction or density matrix) of the
                   system. Can be a dictionary with each value being the
                   state represented in a different basis, and the key
                   identifying the basis. A wavefunction is a complex vector
                   (``dim() == 2``); a density matrix is a complex matrix
                   (``dim() == 3``).
    :type target: torch.Tensor or dict(str, torch.Tensor)
    :param space: The basis elements of the Hilbert space of the system
                  :math:`\mathcal{H}`. The ordering of the basis elements
                  must match with the ordering of the coefficients given in
                  `target`. If `None`, will generate them using the provided
                  `nn_state`.
    :type space: torch.Tensor
    :param bases: An array of unique bases. If given, the KL divergence will
                  be computed for each basis and the average will be
                  returned.
    :type bases: numpy.ndarray
    :param \**kwargs: Extra keyword arguments that may be passed.
                      Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)

    if isinstance(target, dict):
        target = {k: v.to(nn_state.device) for k, v in target.items()}
        if bases is None:
            bases = list(target.keys())
        else:
            assert set(bases) == set(
                target.keys()
            ), "Given bases must match the keys of the target_psi dictionary."
    else:
        target = target.to(nn_state.device)

    KL = 0.0

    if bases is None:
        # Only reachable with a tensor target: a dict target always sets `bases`.
        if target.dim() == 3:
            # Complex density matrix: the measurement probabilities are the
            # real diagonal entries (same as the per-basis branch below).
            # |rho_ij|^2 would give a full matrix instead of a distribution.
            target_probs = torch.diagonal(cplx.real(target))
        else:
            # Complex state vector: probabilities are |psi|^2.
            target_probs = cplx.absolute_value(target) ** 2
        nn_probs = nn_state.probability(space, Z)
        KL += _single_basis_KL(target_probs, nn_probs)
        return KL.item()

    if isinstance(nn_state, WaveFunctionBase):
        for basis in bases:
            if isinstance(target, dict):
                target_psi_r = target[basis]
                assert target_psi_r.dim() == 2, "target must be a complex vector!"
            else:
                assert target.dim() == 2, "target must be a complex vector!"
                target_psi_r = rotate_psi(nn_state, basis, space, psi=target)

            psi_r = rotate_psi(nn_state, basis, space)
            nn_probs_r = (cplx.absolute_value(psi_r) ** 2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r) ** 2

            KL += _single_basis_KL(target_probs_r, nn_probs_r)
    else:
        for basis in bases:
            if isinstance(target, dict):
                target_rho_r = target[basis]
                assert target_rho_r.dim() == 3, "target must be a complex matrix!"
                target_probs_r = torch.diagonal(cplx.real(target_rho_r))
            else:
                assert target.dim() == 3, "target must be a complex matrix!"
                target_probs_r = rotate_rho_probs(nn_state, basis, space, rho=target)

            rho_r = rotate_rho_probs(nn_state, basis, space)
            nn_probs_r = rho_r / Z

            KL += _single_basis_KL(target_probs_r, nn_probs_r)

    # Average over the bases (shared by both state types).
    KL /= float(len(bases))
    return KL.item()