def algorithmic_gradKL(nn_state, psi_dict, vis, unitary_dict, bases):
    grad_KL = [
        torch.zeros(nn_state.rbm_am.num_pars, dtype=torch.double, device=nn_state.device),
        torch.zeros(nn_state.rbm_ph.num_pars, dtype=torch.double, device=nn_state.device),
    ]
    Z = partition(nn_state, vis).to(device=nn_state.device)

    # Reference (computational) basis: only the amplitude RBM contributes.
    for i in range(len(vis)):
        grad_KL[0] += (
            cplx.norm(psi_dict[bases[0]][:, i])
            * nn_state.rbm_am.effective_energy_gradient(vis[i])
            / float(len(bases))
        )
        grad_KL[0] -= (
            probability(nn_state, vis[i], Z)
            * nn_state.rbm_am.effective_energy_gradient(vis[i])
            / float(len(bases))
        )

    # Rotated bases: the positive phase uses the basis-rotated gradient,
    # which fills both parameter blocks; the negative phase is
    # basis-independent and only touches the amplitude block.
    for b in range(1, len(bases)):
        for i in range(len(vis)):
            rotated_grad = nn_state.gradient(bases[b], vis[i])
            grad_KL[0] += (
                cplx.norm(psi_dict[bases[b]][:, i]) * rotated_grad[0] / float(len(bases))
            )
            grad_KL[1] += (
                cplx.norm(psi_dict[bases[b]][:, i]) * rotated_grad[1] / float(len(bases))
            )
            grad_KL[0] -= (
                probability(nn_state, vis[i], Z)
                * nn_state.rbm_am.effective_energy_gradient(vis[i])
                / float(len(bases))
            )
    return grad_KL
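# Reading of the loops above: in the reference basis, the per-configuration
# summand is
#     [ cplx.norm(psi_target(v)) - p_model(v) ] * dE(v)/d(lambda),
# the familiar positive (data) phase minus negative (model) phase of an RBM
# gradient. In each rotated basis, the positive phase is replaced by the
# basis-rotated gradient returned by nn_state.gradient (amplitude block
# grad_KL[0] and phase block grad_KL[1]), while the negative phase stays
# basis-independent; every term is averaged over len(bases).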
def compute_numerical_kl(nn_state, psi_dict, vis, Z, unitary_dict, bases):
    N = nn_state.num_visible
    psi_r = torch.zeros(2, 1 << N, dtype=torch.double)
    KL = 0.0

    # Reference (computational) basis.
    for i in range(len(vis)):
        KL += (
            cplx.norm(psi_dict[bases[0]][:, i])
            * cplx.norm(psi_dict[bases[0]][:, i]).log()
            / float(len(bases))
        )
        KL -= (
            cplx.norm(psi_dict[bases[0]][:, i])
            * probability(nn_state, vis[i], Z).log().item()
            / float(len(bases))
        )

    # Rotated bases: compare the rotated target against the rotated model state.
    for b in range(1, len(bases)):
        psi_r = rotate_psi(nn_state, bases[b], unitary_dict, vis)
        for ii in range(len(vis)):
            if cplx.norm(psi_dict[bases[b]][:, ii]) > 0.0:
                # Guard against the 0 * log(0) indeterminate form.
                KL += (
                    cplx.norm(psi_dict[bases[b]][:, ii])
                    * cplx.norm(psi_dict[bases[b]][:, ii]).log()
                    / float(len(bases))
                )
            KL -= (
                cplx.norm(psi_dict[bases[b]][:, ii])
                * cplx.norm(psi_r[:, ii]).log()
                / float(len(bases))
            )
            KL += cplx.norm(psi_dict[bases[b]][:, ii]) * Z.log() / float(len(bases))
    return KL
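# A minimal sketch (an assumption, not part of the original suite) of how the
# analytic gradient from algorithmic_gradKL could be validated against central
# finite differences of compute_numerical_kl. It assumes nn_state.rbm_am is a
# torch.nn.Module whose parameters can be perturbed in place; partition and
# compute_numerical_kl are the helpers defined above.
def finite_diff_gradKL_am(nn_state, psi_dict, vis, unitary_dict, bases, eps=1e-6):
    grads = []
    for par in nn_state.rbm_am.parameters():
        flat = par.data.view(-1)
        g = torch.zeros_like(flat)
        for k in range(flat.numel()):
            flat[k] += eps  # forward perturbation
            Z = partition(nn_state, vis)
            kl_p = compute_numerical_kl(nn_state, psi_dict, vis, Z, unitary_dict, bases)
            flat[k] -= 2 * eps  # backward perturbation
            Z = partition(nn_state, vis)
            kl_m = compute_numerical_kl(nn_state, psi_dict, vis, Z, unitary_dict, bases)
            flat[k] += eps  # restore the original parameter value
            g[k] = (kl_p - kl_m) / (2 * eps)
        grads.append(g)
    return torch.cat(grads)  # comparable to grad_KL[0] from algorithmic_gradKL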
def fidelity(nn_state, target_psi, bases=None):
    nn_state.compute_normalization()
    F = torch.tensor([0.0, 0.0], dtype=torch.double, device=nn_state.device)
    target_psi = target_psi.to(nn_state.device)
    for i in range(len(nn_state.space)):
        psi = nn_state.psi(nn_state.space[i]) / (nn_state.Z).sqrt()
        # Accumulate the complex overlap <target|psi>: real part ...
        F[0] += target_psi[0, i] * psi[0] + target_psi[1, i] * psi[1]
        # ... and imaginary part of conj(target) * psi.
        F[1] += target_psi[0, i] * psi[1] - target_psi[1, i] * psi[0]
    return cplx.norm(F)
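# Usage sketch for fidelity. The target below is an illustration only (a
# uniform superposition with real amplitudes), not data from this suite:
#
#   n = nn_state.num_visible
#   target = torch.zeros(2, 1 << n, dtype=torch.double)
#   target[0] = 2.0 ** (-n / 2.0)   # normalized: the amplitudes' squared moduli sum to 1
#   f = fidelity(nn_state, target)  # the modulus of <target|psi>; ~1 iff nn_state matches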
def compute_numerical_NLL(nn_state, data_samples, data_bases, Z, unitary_dict, vis):
    NLL = 0
    batch_size = len(data_samples)
    for i in range(batch_size):
        # Reset per sample; set to 1 if any qubit was measured outside the Z basis.
        b_flag = 0
        bitstate = []
        for j in range(nn_state.num_visible):
            if data_bases[i][j] != 'Z':
                b_flag = 1
            bitstate.append(int(data_samples[i, j].item()))
        # Convert the bitstring to its Hilbert-space index.
        ind = int("".join(str(b) for b in bitstate), 2)
        if b_flag == 0:
            NLL -= probability(nn_state, data_samples[i], Z).log().item() / batch_size
        else:
            psi_r = rotate_psi(nn_state, data_bases[i], unitary_dict, vis)
            NLL -= (cplx.norm(psi_r[:, ind]).log() - Z.log()).item() / batch_size
    return NLL
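# Worked example of the index arithmetic above: the spin configuration
# [1, 0, 1] is read as the bitstring "101", i.e. row 5 of the (rotated)
# wavefunction tensor:
#
#   int("".join(str(b) for b in [1, 0, 1]), 2) == 5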
def KL(nn_state, target_psi, bases=None):
    psi_r = torch.zeros(2, 1 << nn_state.num_visible, dtype=torch.double, device=nn_state.device)
    KL = 0.0
    unitary_dict = unitaries.create_dict()
    target_psi = target_psi.to(nn_state.device)
    space = nn_state.generate_Hilbert_space(nn_state.num_visible)
    nn_state.compute_normalization()

    if bases is None:
        num_bases = 1
        for i in range(len(space)):
            KL += cplx.norm(target_psi[:, i]) * cplx.norm(target_psi[:, i]).log()
            KL -= cplx.norm(target_psi[:, i]) * cplx.norm(nn_state.psi(space[i])).log()
            KL += cplx.norm(target_psi[:, i]) * nn_state.Z.log()
    else:
        num_bases = len(bases)
        for b in range(1, len(bases)):
            psi_r = rotate_psi(nn_state, bases[b], unitary_dict)
            target_psi_r = rotate_psi(nn_state, bases[b], unitary_dict, target_psi)
            for ii in range(len(space)):
                if cplx.norm(target_psi_r[:, ii]) > 0.0:
                    # Guard against the 0 * log(0) indeterminate form.
                    KL += cplx.norm(target_psi_r[:, ii]) * cplx.norm(target_psi_r[:, ii]).log()
                KL -= cplx.norm(target_psi_r[:, ii]) * cplx.norm(psi_r[:, ii]).log().item()
                KL += cplx.norm(target_psi_r[:, ii]) * nn_state.Z.log()
    return KL / float(num_bases)
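# Worked example of one summand in KL: with a target probability of 0.25 and
# a model probability of 0.20 for some configuration, the contribution is
# 0.25 * (log(0.25) - log(0.20)) ~= 0.0558. For genuine probability
# distributions, the sum over the space satisfies KL >= 0, with equality
# exactly when the two distributions coincide.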
def test_norm(self):
    scalar = torch.tensor([3, 4], dtype=torch.double)
    expect = torch.tensor(5, dtype=torch.double)
    self.assertTensorsEqual(cplx.norm(scalar), expect, msg="Norm failed!")
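# The test above pins down the convention: cplx.norm returns the complex
# modulus sqrt(re^2 + im^2) rather than the squared norm (norm([3, 4]) is 5,
# not 25). A reference implementation consistent with this test (a sketch,
# not necessarily the library's actual code):
def _reference_norm(z):
    return (z[0] ** 2 + z[1] ** 2).sqrt()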