def test_rotate_psi_inner_prod(num_visible, state_type, precompute_psi):
    """Check that ``rotate_psi_inner_prod`` agrees with ``rotate_psi``.

    Both helpers rotate the state into the all-"X" basis over the full
    Hilbert space generated by the state. Their outputs must match to within
    the tolerance of ``assertAlmostEqual``.

    When ``precompute_psi`` is truthy, the state's amplitudes are computed once
    and passed explicitly to both helpers. Otherwise ``psi=None`` is passed and
    each helper obtains the amplitudes itself.
    """
    state = state_type(num_visible, gpu=False)
    rotation = "X" * num_visible
    unitaries = create_dict()
    hilbert = state.generate_hilbert_space()

    amplitudes = None
    if precompute_psi:
        amplitudes = state.psi(hilbert)

    reference = rotate_psi(state, rotation, hilbert, unitaries, psi=amplitudes)
    candidate = rotate_psi_inner_prod(
        state, rotation, hilbert, unitaries, psi=amplitudes
    )

    assertAlmostEqual(
        reference, candidate, msg="Fast psi inner product rotation failed!"
    )
def test_rotate_psi(num_visible, wvfn_type):
    """Compare ``rotate_psi`` against an explicit dense-unitary rotation.

    The reference rotation builds the full unitary for the all-"X" basis by
    folding ``cplx.kronecker_prod`` over the per-site unitaries. It then
    applies that unitary to the state's amplitudes with ``cplx.matmul``.
    """
    state = wvfn_type(num_visible, gpu=False)
    rotation = "X" * num_visible
    unitaries = create_dict()
    hilbert = state.generate_hilbert_space()
    amplitudes = state.psi(hilbert)

    fast_result = rotate_psi(state, rotation, hilbert, unitaries, psi=amplitudes)

    # Reference: kron of one unitary per site, applied as a dense matrix product.
    per_site = [unitaries[letter] for letter in rotation]
    full_unitary = reduce(cplx.kronecker_prod, per_site)
    dense_result = cplx.matmul(full_unitary, amplitudes)

    assertAlmostEqual(fast_result, dense_result, msg="Fast psi rotation failed!")
def KL(nn_state, target, space=None, bases=None, **kwargs):
    r"""A function for calculating the KL divergence averaged over every given basis.

    .. math:: KL(P_{target} \vert P_{RBM}) = -\sum_{x \in \mathcal{H}} P_{target}(x)\log(\frac{P_{RBM}(x)}{P_{target}(x)})

    :param nn_state: The neural network state.
    :type nn_state: qucumber.nn_states.NeuralStateBase
    :param target: The true state (wavefunction or density matrix) of the system.
                   Can be a dictionary with each value being the state
                   represented in a different basis, and the key identifying
                   the basis. A complex vector (``dim() == 2``) is treated as
                   a wavefunction; a complex matrix (``dim() == 3``) is
                   treated as a density matrix.
    :type target: torch.Tensor or dict(str, torch.Tensor)
    :param space: The basis elements of the Hilbert space of the system
                  :math:`\mathcal{H}`. The ordering of the basis elements
                  must match with the ordering of the coefficients given in
                  `target`. If `None`, will generate them using the
                  provided `nn_state`.
    :type space: torch.Tensor
    :param bases: An array of unique bases. If given, the KL divergence will
                  be computed for each basis and the average will be
                  returned. If `None` and `target` is a dictionary, its keys
                  are used as the bases.
    :type bases: numpy.ndarray or list(str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    space = space if space is not None else nn_state.generate_hilbert_space()
    Z = nn_state.normalization(space)

    if isinstance(target, dict):
        target = {k: v.to(nn_state.device) for k, v in target.items()}
        if bases is None:
            bases = list(target.keys())
        else:
            assert set(bases) == set(
                target.keys()
            ), "Given bases must match the keys of the target_psi dictionary."
    else:
        target = target.to(nn_state.device)

    # Named ``kl`` rather than ``KL`` so it does not shadow this function.
    kl = 0.0

    if bases is None:
        # Computational basis only; here ``target`` is a single tensor.
        # A density matrix (complex matrix, dim 3) has its probabilities on the
        # real part of its diagonal, matching the per-basis branch below.
        # Squaring the modulus of the whole matrix would not give a
        # probability vector.
        if target.dim() == 3:
            target_probs = torch.diagonal(cplx.real(target))
        else:
            target_probs = cplx.absolute_value(target) ** 2
        nn_probs = nn_state.probability(space, Z)

        kl += _single_basis_KL(target_probs, nn_probs)

    elif isinstance(nn_state, WaveFunctionBase):
        for basis in bases:
            if isinstance(target, dict):
                # Target already supplied in this basis.
                target_psi_r = target[basis]
                assert target_psi_r.dim() == 2, "target must be a complex vector!"
            else:
                # Rotate the single (computational-basis) target into `basis`.
                assert target.dim() == 2, "target must be a complex vector!"
                target_psi_r = rotate_psi(nn_state, basis, space, psi=target)

            psi_r = rotate_psi(nn_state, basis, space)
            nn_probs_r = (cplx.absolute_value(psi_r) ** 2) / Z
            target_probs_r = cplx.absolute_value(target_psi_r) ** 2

            kl += _single_basis_KL(target_probs_r, nn_probs_r)

        kl /= float(len(bases))

    else:
        for basis in bases:
            if isinstance(target, dict):
                # Target density matrix already supplied in this basis.
                target_rho_r = target[basis]
                assert target_rho_r.dim() == 3, "target must be a complex matrix!"
                target_probs_r = torch.diagonal(cplx.real(target_rho_r))
            else:
                assert target.dim() == 3, "target must be a complex matrix!"
                target_probs_r = rotate_rho_probs(nn_state, basis, space, rho=target)

            rho_r = rotate_rho_probs(nn_state, basis, space)
            nn_probs_r = rho_r / Z

            kl += _single_basis_KL(target_probs_r, nn_probs_r)

        kl /= float(len(bases))

    return kl.item()