def algorithmic_gradKL(nn_state, psi_dict, vis, unitary_dict, bases):
    """Accumulate the KL gradient by explicit summation over ``vis``.

    Every contribution is divided by ``len(bases)``.

    * For ``bases[0]`` the target weight
      ``cplx.norm(psi_dict[bases[0]][:, i])`` multiplies
      ``nn_state.rbm_am.effective_energy_gradient(vis[i])``.
    * For every later basis the target weight multiplies both components
      returned by ``nn_state.gradient(basis, vis[i])``.
    * In every basis the model term
      ``probability(nn_state, vis[i], Z) * effective_energy_gradient(vis[i])``
      is subtracted from the first gradient only.

    Args:
        nn_state: Model exposing ``rbm_am``, ``rbm_ph``, ``device`` and
            ``gradient``.
        psi_dict: Mapping from a basis to a target wavefunction indexed as
            ``psi[:, i]``, one column per entry of ``vis``.
        vis: Configurations summed over. They are also passed to
            ``partition`` to obtain ``Z``.
        unitary_dict: Not used by this function.
        bases: Sequence of bases. ``bases[0]`` is handled without
            ``nn_state.gradient`` (presumably the unrotated basis; confirm
            with callers).

    Returns:
        list: ``[grad_am, grad_ph]``, double tensors of length
        ``nn_state.rbm_am.num_pars`` and ``nn_state.rbm_ph.num_pars``.
    """
    n_bases = float(len(bases))
    device = nn_state.device
    grad_am = torch.zeros(nn_state.rbm_am.num_pars,
                          dtype=torch.double,
                          device=device)
    grad_ph = torch.zeros(nn_state.rbm_ph.num_pars,
                          dtype=torch.double,
                          device=device)
    Z = partition(nn_state, vis).to(device=device)
    energy_grad = nn_state.rbm_am.effective_energy_gradient

    # First basis: data term minus model term, first gradient only.
    reference = psi_dict[bases[0]]
    for idx in range(len(vis)):
        sample = vis[idx]
        grad_am += (cplx.norm(reference[:, idx]) * energy_grad(sample) /
                    n_bases)
        grad_am -= (probability(nn_state, sample, Z) * energy_grad(sample) /
                    n_bases)

    # Remaining bases: the data term feeds both gradients, while the model
    # term is still subtracted from the first gradient only.
    for basis in bases[1:]:
        target = psi_dict[basis]
        for idx in range(len(vis)):
            sample = vis[idx]
            rotated_grad = nn_state.gradient(basis, sample)
            weight = cplx.norm(target[:, idx])
            grad_am += weight * rotated_grad[0] / n_bases
            grad_ph += weight * rotated_grad[1] / n_bases
            grad_am -= (probability(nn_state, sample, Z) *
                        energy_grad(sample) / n_bases)

    return [grad_am, grad_ph]
def compute_numerical_kl(nn_state, psi_dict, vis, Z, unitary_dict, bases):
    """Brute-force KL divergence between target and model, averaged over bases.

    For each basis and each configuration index ``i``, the target weight is
    ``p = cplx.norm(psi_dict[basis][:, i])``. The accumulated quantity is
    ``sum p * (log p - log q) / len(bases)``, where the model term is:

    * ``q = probability(nn_state, vis[i], Z)`` for ``bases[0]``;
    * ``log q = log cplx.norm(psi_r[:, i]) - log Z`` for every later basis,
      with ``psi_r = rotate_psi(nn_state, basis, unitary_dict, vis)``.

    Entries with ``p == 0`` contribute nothing (the ``0 * log 0 = 0``
    convention). Computing them literally would produce ``0 * -inf = NaN``.

    Args:
        nn_state: Model passed to ``probability`` and ``rotate_psi``.
        psi_dict: Mapping from a basis to a target wavefunction indexed as
            ``psi[:, i]``.
        vis: Configurations summed over.
        Z: Normalisation tensor used by ``probability`` and in ``log Z``.
        unitary_dict: Passed through to ``rotate_psi``.
        bases: Sequence of bases; ``bases[0]`` is evaluated without rotation.

    Returns:
        The accumulated KL value (a tensor once any term has been added,
        otherwise ``0.0``).
    """
    n_bases = float(len(bases))
    KL = 0.0

    # First basis: model probabilities come straight from probability().
    target = psi_dict[bases[0]]
    for i in range(len(vis)):
        p = cplx.norm(target[:, i])
        if p > 0.0:
            KL += p * p.log() / n_bases
            KL -= (p * probability(nn_state, vis[i], Z).log().item() /
                   n_bases)

    # Remaining bases: compare against the rotated model wavefunction.
    for b in range(1, len(bases)):
        psi_r = rotate_psi(nn_state, bases[b], unitary_dict, vis)
        target = psi_dict[bases[b]]
        for ii in range(len(vis)):
            p = cplx.norm(target[:, ii])
            if p > 0.0:
                KL += p * p.log() / n_bases
                KL -= p * cplx.norm(psi_r[:, ii]).log() / n_bases
                KL += p * Z.log() / n_bases

    return KL
def fidelity(nn_state, target_psi, bases=None):
    """Return ``cplx.norm`` of the overlap between ``target_psi`` and the model.

    The overlap is a two-component double tensor. Let ``t = target_psi[:, i]``
    and ``a = nn_state.psi(nn_state.space[i]) / sqrt(nn_state.Z)``. Then:

    * component 0 is ``sum_i t[0] * a[0] + t[1] * a[1]``;
    * component 1 is ``sum_i t[0] * a[1] - t[1] * a[0]``.

    Reading row 0 as the real part and row 1 as the imaginary part, this is
    ``sum_i conj(t_i) * a_i``.

    Args:
        nn_state: Model providing ``compute_normalization``, ``Z``, ``space``,
            ``psi`` and ``device``. ``compute_normalization`` is called first.
        target_psi: Target wavefunction indexed as ``target_psi[k, i]``. It is
            moved to ``nn_state.device``.
        bases: Unused here.

    Returns:
        ``cplx.norm`` of the accumulated overlap tensor.
    """
    nn_state.compute_normalization()
    overlap = torch.tensor([0., 0.],
                           dtype=torch.double,
                           device=nn_state.device)
    target_psi = target_psi.to(nn_state.device)
    space = nn_state.space
    norm_factor = (nn_state.Z).sqrt()

    for idx in range(len(space)):
        amp = nn_state.psi(space[idx]) / norm_factor
        t_re = target_psi[0, idx]
        t_im = target_psi[1, idx]
        overlap[0] += t_re * amp[0] + t_im * amp[1]
        overlap[1] += t_re * amp[1] - t_im * amp[0]

    return cplx.norm(overlap)
def compute_numerical_NLL(nn_state, data_samples, data_bases, Z, unitary_dict,
                          vis):
    """Average negative log-likelihood of ``data_samples`` under the model.

    For sample ``i``, measured in basis ``data_bases[i]``:

    * If each of the first ``nn_state.num_visible`` basis entries is ``'Z'``,
      the term is ``-log(probability(nn_state, data_samples[i], Z))``.
    * Otherwise the model wavefunction is rotated with ``rotate_psi``. The
      term is ``-(log(cplx.norm(psi_r[:, ind])) - log(Z))``, where ``ind`` is
      the integer whose binary digits, most significant first, are the
      sample's bits.

    Each term is divided by ``len(data_samples)``.

    Args:
        nn_state: Model passed to ``probability`` and ``rotate_psi``.
        data_samples: 2-D batch of 0/1 values indexed as ``[i, j]``.
        data_bases: Per-sample bases; ``data_bases[i][j]`` is compared
            with ``'Z'``.
        Z: Normalisation tensor.
        unitary_dict: Passed through to ``rotate_psi``.
        vis: Passed through to ``rotate_psi``.

    Returns:
        float: The averaged NLL (``0`` for an empty batch).
    """
    NLL = 0
    batch_size = len(data_samples)
    num_visible = nn_state.num_visible
    # rotate_psi's arguments do not involve the sample, only its basis, so
    # each distinct basis is rotated once.
    rotated_cache = {}

    for i in range(batch_size):
        basis = data_bases[i]
        # Decided per sample, so a sample measured purely in Z never takes
        # the rotated branch because of an earlier sample.
        is_rotated = any(basis[j] != 'Z' for j in range(num_visible))
        if not is_rotated:
            NLL -= (probability(nn_state, data_samples[i], Z).log().item() /
                    batch_size)
            continue

        # Big-endian index of the measured bitstring.
        ind = 0
        for j in range(num_visible):
            ind = (ind << 1) | int(data_samples[i, j].item())

        key = tuple(basis)
        if key not in rotated_cache:
            rotated_cache[key] = rotate_psi(nn_state, basis, unitary_dict, vis)
        psi_r = rotated_cache[key]
        NLL -= (cplx.norm(psi_r[:, ind]).log() - Z.log()).item() / batch_size

    return NLL
def KL(nn_state, target_psi, bases=None):
    """KL divergence from ``target_psi`` to the model, averaged over bases.

    After ``nn_state.compute_normalization()``, each entry ``ii`` adds
    ``p * (log p - log cplx.norm(psi_ii) + log nn_state.Z)``, where ``p`` is
    the target's ``cplx.norm``.

    * If ``bases`` is ``None``, ``psi_ii = nn_state.psi(space[ii])`` and the
      target is used unrotated.
    * Otherwise, for every basis in ``bases`` (including ``bases[0]``), both
      the model and the target are rotated with ``rotate_psi``. The model is
      rotated via the three-argument form; the target is presumably rotated
      via the four-argument form (confirm against ``rotate_psi``).

    Entries with ``p == 0`` contribute nothing (the ``0 * log 0 = 0``
    convention), which avoids ``NaN`` from ``0 * -inf``.

    Args:
        nn_state: Model providing ``num_visible``, ``device``, ``psi``, ``Z``,
            ``compute_normalization`` and ``generate_Hilbert_space``.
        target_psi: Target wavefunction indexed as ``target_psi[:, i]``. It is
            moved to ``nn_state.device``.
        bases: Optional sequence of bases to average over.

    Returns:
        The accumulated value divided by the number of bases (1 when
        ``bases`` is ``None``).
    """
    KL = 0.0
    unitary_dict = unitaries.create_dict()
    target_psi = target_psi.to(nn_state.device)
    space = nn_state.generate_Hilbert_space(nn_state.num_visible)
    nn_state.compute_normalization()
    log_Z = nn_state.Z.log()

    if bases is None:
        num_bases = 1
        for i in range(len(space)):
            p = cplx.norm(target_psi[:, i])
            if p > 0.0:
                KL += p * p.log()
                KL -= p * cplx.norm(nn_state.psi(space[i])).log()
                KL += p * log_Z
    else:
        num_bases = len(bases)
        # Every basis is included, so the divisor below matches the number
        # of bases actually summed.
        for basis in bases:
            psi_r = rotate_psi(nn_state, basis, unitary_dict)
            target_psi_r = rotate_psi(nn_state, basis, unitary_dict,
                                      target_psi)
            for ii in range(len(space)):
                p = cplx.norm(target_psi_r[:, ii])
                if p > 0.0:
                    KL += p * p.log()
                    KL -= p * cplx.norm(psi_r[:, ii]).log()
                    KL += p * log_Z

    return KL / float(num_bases)
Example #6
0
    def test_norm(self):
        """cplx.norm of the two-component tensor [3, 4] should equal 5."""
        expected = torch.tensor(5, dtype=torch.double)
        z = torch.tensor([3, 4], dtype=torch.double)
        result = cplx.norm(z)
        self.assertTensorsEqual(result, expected, msg="Norm failed!")