def compute_numerical_kl(self, psi_dict, vis, Z, unitary_dict, bases):
    """Compute the basis-averaged KL divergence between target and model by
    exact enumeration over the visible space.

    :param psi_dict: Target wavefunction amplitudes per basis; indexed as
        ``psi_dict[basis][:, i]`` — presumably a 2-row (real/imag) tensor per
        basis, one column per visible state. TODO confirm layout.
    :param vis: The visible-space states to sum over.
    :param Z: The model's normalization constant (a tensor; ``Z.log()`` is used).
    :param unitary_dict: Dictionary of basis-rotation unitaries, forwarded to
        ``self.rotate_psi``.
    :param bases: Bases to average over; ``bases[0]`` is treated as the
        reference basis.
    :returns: The KL divergence, averaged over ``len(bases)`` bases.
    """
    N = self.nn_state.num_visible
    # Buffer for the rotated model wavefunction; overwritten per basis below.
    psi_r = torch.zeros(2, 1 << N, dtype=torch.double, device=self.nn_state.device)
    KL = 0.0
    # Reference basis (bases[0]): target amplitudes are used directly, and the
    # model term comes from its normalized probability of each state.
    for i in range(len(vis)):
        KL += (
            cplx.norm_sqr(psi_dict[bases[0]][:, i])
            * cplx.norm_sqr(psi_dict[bases[0]][:, i]).log()
            / float(len(bases))
        )
        KL -= (
            cplx.norm_sqr(psi_dict[bases[0]][:, i])
            * self.nn_state.probability(vis[i], Z).log().item()
            / float(len(bases))
        )
    # Remaining bases: rotate the model wavefunction into each basis and
    # compare against the stored target amplitudes for that basis.
    for b in range(1, len(bases)):
        psi_r = self.rotate_psi(bases[b], unitary_dict, vis)
        for ii in range(len(vis)):
            # Skip the entropy term when the target probability is zero
            # (0 * log 0 -> 0); the cross-entropy terms below still accumulate.
            if cplx.norm_sqr(psi_dict[bases[b]][:, ii]) > 0.0:
                KL += (
                    cplx.norm_sqr(psi_dict[bases[b]][:, ii])
                    * cplx.norm_sqr(psi_dict[bases[b]][:, ii]).log()
                    / float(len(bases))
                )
            KL -= (
                cplx.norm_sqr(psi_dict[bases[b]][:, ii])
                * cplx.norm_sqr(psi_r[:, ii]).log()
                / float(len(bases))
            )
            # + log Z normalizes the (unnormalized) rotated model amplitudes.
            KL += (
                cplx.norm_sqr(psi_dict[bases[b]][:, ii])
                * Z.log()
                / float(len(bases))
            )
    return KL
def NLL(nn_state, samples, space, train_bases=None, **kwargs):
    r"""A function for calculating the negative log-likelihood.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param samples: Samples to compute the NLL on.
    :type samples: torch.Tensor
    :param space: The hilbert space of the system.
    :type space: torch.Tensor
    :param train_bases: An array of bases where measurements were taken.
    :type train_bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The Negative Log-Likelihood.
    :rtype: float
    """
    NLL = 0.0
    unitary_dict = unitaries.create_dict()
    Z = nn_state.compute_normalization(space)
    eps = 0.000001  # regularizer so log() never sees an exact zero
    if train_bases is None:
        for i in range(len(samples)):
            NLL -= (cplx.norm_sqr(nn_state.psi(samples[i])) + eps).log()
            NLL += Z.log()
    else:
        for i in range(len(samples)):
            # Check whether the sample was measured in the reference
            # (all-"Z") basis.
            is_reference_basis = True
            for j in range(nn_state.num_visible):
                if train_bases[i][j] != "Z":
                    is_reference_basis = False
                    break
            if is_reference_basis:
                NLL -= (cplx.norm_sqr(nn_state.psi(samples[i])) + eps).log()
                NLL += Z.log()
            else:
                psi_r = rotate_psi(nn_state, train_bases[i], space, unitary_dict)
                # Get the index value of the sample state (big-endian bits).
                ind = 0
                for j in range(nn_state.num_visible):
                    if samples[i, nn_state.num_visible - j - 1] == 1:
                        ind += pow(2, j)
                # BUG FIX: add eps here too, matching the reference-basis
                # branch, so a zero rotated amplitude cannot produce
                # log(0) = -inf.
                NLL -= (cplx.norm_sqr(psi_r[:, ind]) + eps).log().item()
                NLL += Z.log()
    return (NLL / float(len(samples))).item()
def test_norm_sqr(self):
    """Check that norm_sqr of the complex scalar 3 + 4i equals 25."""
    value = torch.tensor([3, 4], dtype=torch.double)  # [real, imag]
    expected = torch.tensor(25, dtype=torch.double)   # 3**2 + 4**2
    actual = cplx.norm_sqr(value)
    self.assertTensorsEqual(actual, expected, msg="Norm failed!")
def algorithmic_gradKL(self, psi_dict, vis, unitary_dict, bases, **kwargs):
    """Compute the gradients of the basis-averaged KL divergence with respect
    to the amplitude- and phase-RBM parameters by exact enumeration.

    :param psi_dict: Target amplitudes per basis, indexed
        ``psi_dict[basis][:, i]`` — presumably 2-row real/imag columns per
        visible state; TODO confirm layout.
    :param vis: The visible-space states to sum over.
    :param unitary_dict: Basis-rotation unitaries. NOTE(review): unused here —
        rotation appears to happen inside ``self.nn_state.gradient``; confirm.
    :param bases: Bases to average over; ``bases[0]`` is the reference basis.
    :returns: ``[grad_am, grad_ph]`` — one flat gradient tensor per RBM.
    """
    # One accumulator per RBM: index 0 = amplitude RBM, index 1 = phase RBM.
    grad_KL = [
        torch.zeros(
            self.nn_state.rbm_am.num_pars,
            dtype=torch.double,
            device=self.nn_state.device,
        ),
        torch.zeros(
            self.nn_state.rbm_ph.num_pars,
            dtype=torch.double,
            device=self.nn_state.device,
        ),
    ]
    Z = self.nn_state.compute_normalization(vis).to(device=self.nn_state.device)
    # Reference basis (bases[0]): only the amplitude RBM contributes —
    # target-weighted positive phase minus model-probability-weighted
    # negative phase.
    for i in range(len(vis)):
        grad_KL[0] += (
            cplx.norm_sqr(psi_dict[bases[0]][:, i])
            * self.nn_state.rbm_am.effective_energy_gradient(vis[i])
            / float(len(bases))
        )
        grad_KL[0] -= (
            self.nn_state.probability(vis[i], Z)
            * self.nn_state.rbm_am.effective_energy_gradient(vis[i])
            / float(len(bases))
        )
    # Rotated bases: the rotated gradient carries both amplitude ([0]) and
    # phase ([1]) components; the negative phase is basis-independent.
    for b in range(1, len(bases)):
        for i in range(len(vis)):
            rotated_grad = self.nn_state.gradient(bases[b], vis[i])
            grad_KL[0] += (
                cplx.norm_sqr(psi_dict[bases[b]][:, i])
                * rotated_grad[0]
                / float(len(bases))
            )
            grad_KL[1] += (
                cplx.norm_sqr(psi_dict[bases[b]][:, i])
                * rotated_grad[1]
                / float(len(bases))
            )
            grad_KL[0] -= (
                self.nn_state.probability(vis[i], Z)
                * self.nn_state.rbm_am.effective_energy_gradient(vis[i])
                / float(len(bases))
            )
    return grad_KL
def NLL(nn_state, samples, space, bases=None, **kwargs):
    r"""A function for calculating the negative log-likelihood (NLL).

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: qucumber.nn_states.WaveFunctionBase
    :param samples: Samples to compute the NLL on.
    :type samples: torch.Tensor
    :param space: The basis elements of the Hilbert space of the system
                  :math:`\mathcal{H}`.
    :type space: torch.Tensor
    :param bases: An array of bases where measurements were taken.
    :type bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The Negative Log-Likelihood.
    :rtype: float
    """
    NLL = 0.0
    Z = nn_state.compute_normalization(space)
    if bases is None:
        nn_probs = nn_state.probability(samples, Z)
        NLL = -torch.sum(probs_to_logits(nn_probs))
    else:
        unitary_dict = nn_state.unitary_dict
        for i in range(len(samples)):
            # Check whether the sample was measured in the reference
            # (all-"Z") basis.
            is_reference_basis = True
            for j in range(nn_state.num_visible):
                if bases[i][j] != "Z":
                    is_reference_basis = False
                    break
            if is_reference_basis:
                nn_probs = nn_state.probability(samples[i], Z)
                NLL -= torch.sum(probs_to_logits(nn_probs))
            else:
                psi_r = rotate_psi(nn_state, bases[i], space, unitary_dict)
                # Get the index value of the sample state (big-endian bits).
                ind = 0
                for j in range(nn_state.num_visible):
                    if samples[i, nn_state.num_visible - j - 1] == 1:
                        ind += pow(2, j)
                probs_r = cplx.norm_sqr(psi_r[:, ind]) / Z
                # BUG FIX: keep the accumulator a tensor (no .item() here).
                # Previously, if every sample was in a rotated basis, NLL
                # stayed a plain float and the final `.item()` call raised
                # AttributeError.
                NLL -= probs_to_logits(probs_r)
    return (NLL / float(len(samples))).item()
def compute_numerical_NLL(self, data_samples, data_bases, Z, unitary_dict, vis):
    """Compute the exact (numerical) negative log-likelihood of a batch of
    measured samples, averaged over the batch.

    :param data_samples: Measured visible configurations, one row per sample.
    :param data_bases: Measurement basis string per sample (one char per site).
    :param Z: The model's normalization constant (tensor; ``Z.log()`` is used).
    :param unitary_dict: Basis-rotation unitaries for ``self.rotate_psi``.
    :param vis: The visible-space states used when rotating the wavefunction.
    :returns: The batch-averaged NLL (a plain float).
    """
    NLL = 0
    batch_size = len(data_samples)
    for i in range(batch_size):
        # BUG FIX: the non-reference-basis flag must be reset for every
        # sample; previously it was set once for the whole batch, so every
        # sample after the first rotated one took the rotated branch even
        # when it was measured in the reference (all-"Z") basis.
        non_reference = False
        bitstate = []
        for j in range(self.nn_state.num_visible):
            if data_bases[i][j] != "Z":
                non_reference = True
            bitstate.append(int(data_samples[i, j].item()))
        # Integer index of the sampled configuration (big-endian bitstring).
        ind = int("".join(str(bit) for bit in bitstate), 2)
        if not non_reference:
            NLL -= (
                (self.nn_state.probability(data_samples[i], Z)).log().item()
                / batch_size
            )
        else:
            psi_r = self.rotate_psi(data_bases[i], unitary_dict, vis)
            NLL -= (cplx.norm_sqr(psi_r[:, ind]).log() - Z.log()).item() / batch_size
    return NLL
def fidelity(nn_state, target_psi, space, **kwargs):
    r"""Calculates the square of the overlap (fidelity) between the
    reconstructed wavefunction and the true wavefunction (both in the
    computational basis).

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param target_psi: The true wavefunction of the system.
    :type target_psi: torch.Tensor
    :param space: The hilbert space of the system.
    :type space: torch.Tensor
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The fidelity.
    :rtype: torch.Tensor
    """
    Z = nn_state.compute_normalization(space)
    target_psi = target_psi.to(nn_state.device)
    # Accumulate the complex inner product <target|psi> as [real, imag].
    overlap = torch.tensor([0.0, 0.0], dtype=torch.double, device=nn_state.device)
    norm = Z.sqrt()
    for idx, state in enumerate(space):
        amp = nn_state.psi(state) / norm
        tgt_re = target_psi[0, idx]
        tgt_im = target_psi[1, idx]
        # conj(target) * amp: (a - bi)(c + di) = (ac + bd) + (ad - bc)i
        overlap[0] += tgt_re * amp[0] + tgt_im * amp[1]
        overlap[1] += tgt_re * amp[1] - tgt_im * amp[0]
    # Fidelity is the squared magnitude of the overlap.
    return cplx.norm_sqr(overlap)
def KL(nn_state, target_psi, space, bases=None, **kwargs):
    r"""A function for calculating the total KL divergence.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param target_psi: The true wavefunction of the system.
    :type target_psi: torch.Tensor
    :param space: The hilbert space of the system.
    :type space: torch.Tensor
    :param bases: An array of unique bases.
    :type bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    KL = 0.0
    unitary_dict = unitaries.create_dict()
    target_psi = target_psi.to(nn_state.device)
    Z = nn_state.compute_normalization(space)
    eps = 0.000001  # regularizer so log() never sees an exact zero
    if bases is None:
        num_bases = 1
        for i in range(len(space)):
            KL += (
                cplx.norm_sqr(target_psi[:, i])
                * (cplx.norm_sqr(target_psi[:, i]) + eps).log()
            )
            KL -= (
                cplx.norm_sqr(target_psi[:, i])
                * (cplx.norm_sqr(nn_state.psi(space[i])) + eps).log()
            )
            # + log Z normalizes the (unnormalized) model amplitudes.
            KL += cplx.norm_sqr(target_psi[:, i]) * Z.log()
    else:
        num_bases = len(bases)
        # BUG FIX: start at b = 0 instead of b = 1 — previously the first
        # (reference) basis's contribution was silently dropped while still
        # being counted in num_bases. Rotating by the all-"Z" basis is the
        # identity, so b = 0 is handled uniformly by the same code path.
        for b in range(len(bases)):
            psi_r = rotate_psi(nn_state, bases[b], space, unitary_dict)
            target_psi_r = rotate_psi(
                nn_state, bases[b], space, unitary_dict, target_psi
            )
            for ii in range(len(space)):
                # 0 * log 0 -> 0: skip the entropy term for zero-probability
                # target states.
                if cplx.norm_sqr(target_psi_r[:, ii]) > 0.0:
                    KL += (
                        cplx.norm_sqr(target_psi_r[:, ii])
                        * cplx.norm_sqr(target_psi_r[:, ii]).log()
                    )
                # BUG FIX: regularize with eps (as in the bases-None branch)
                # so a zero rotated model amplitude cannot yield log(0) = -inf.
                KL -= (
                    cplx.norm_sqr(target_psi_r[:, ii])
                    * (cplx.norm_sqr(psi_r[:, ii]) + eps).log()
                )
                KL += cplx.norm_sqr(target_psi_r[:, ii]) * Z.log()
    return (KL / float(num_bases)).item()