def KL(nn_state, target_psi, bases=None):
    r"""Compute the KL divergence between a target state and the model.

    With ``bases=None`` the divergence is accumulated once over every state
    returned by ``nn_state.generate_Hilbert_space``. Otherwise the model and
    the target are rotated with ``rotate_psi`` for each visited basis, and
    the accumulated sum is divided by ``len(bases)``.

    Terms of the form ``p * log(p)`` are skipped when ``p`` is zero, in both
    branches, so that zero target weights contribute nothing instead of
    turning the whole result into ``nan``.

    :param nn_state: The neural network state; must provide ``num_visible``,
                     ``device``, ``psi``, ``Z``, ``compute_normalization``
                     and ``generate_Hilbert_space``.
    :param target_psi: The target wavefunction, indexed as
                       ``target_psi[:, i]`` for the i-th basis state.
    :type target_psi: torch.Tensor
    :param bases: Optional sequence of bases to rotate into.

    :returns: The accumulated divergence divided by the number of bases.
    """
    kl = 0.0
    unitary_dict = unitaries.create_dict()
    target_psi = target_psi.to(nn_state.device)
    space = nn_state.generate_Hilbert_space(nn_state.num_visible)
    # Refreshes nn_state.Z, which is read inside the loops below.
    nn_state.compute_normalization()
    if bases is None:
        num_bases = 1
        for i, basis_state in enumerate(space):
            # cplx.norm is used here as the weight of the state
            # (presumably |psi|^2 -- confirm against the cplx module).
            target_prob = cplx.norm(target_psi[:, i])
            # 0 * log(0) evaluates to nan in floating point; treat it as 0,
            # exactly as the rotated branch below does.
            if target_prob > 0.0:
                kl += target_prob * target_prob.log()
            kl -= target_prob * cplx.norm(nn_state.psi(basis_state)).log()
            kl += target_prob * nn_state.Z.log()
    else:
        num_bases = len(bases)
        # NOTE(review): the first entry of `bases` is skipped while the sum
        # is still divided by len(bases). Kept as-is because existing
        # reference values may depend on it -- confirm this is intended.
        for basis in bases[1:]:
            psi_r = rotate_psi(nn_state, basis, unitary_dict)
            target_psi_r = rotate_psi(nn_state, basis, unitary_dict,
                                      target_psi)
            for ii in range(len(space)):
                target_prob = cplx.norm(target_psi_r[:, ii])
                if target_prob > 0.0:
                    kl += target_prob * target_prob.log()
                kl -= target_prob * cplx.norm(psi_r[:, ii]).log().item()
                kl += target_prob * nn_state.Z.log()
    return kl / float(num_bases)
Example #2
0
def test_rotate_rho_probs(num_visible, state_type, precompute_rho):
    """Check ``rotate_rho_probs`` against the diagonal of ``rotate_rho``.

    Both helpers are called in the all-``X`` basis; when ``precompute_rho``
    is truthy the density matrix is computed once and handed to both.
    """
    state = state_type(num_visible, gpu=False)
    x_basis = "X" * num_visible
    rotations = create_dict()

    hilbert_space = state.generate_hilbert_space()

    if precompute_rho:
        rho = state.rho(hilbert_space, expand=True)
    else:
        rho = None

    # Reference: rotate the full matrix, then keep the real diagonal.
    rotated = rotate_rho(state, x_basis, hilbert_space, rotations, rho=rho)
    expected_probs = torch.diagonal(cplx.real(rotated))

    fast_probs = rotate_rho_probs(
        state, x_basis, hilbert_space, rotations, rho=rho
    )

    # A looser tolerance is used because this comparison sometimes just
    # barely exceeds the smaller TOL value from test_grads.py.
    loose_tol = TOL * 10
    assertAlmostEqual(
        expected_probs,
        fast_probs,
        tol=loose_tol,
        msg="Fast rho probs rotation failed!",
    )
Example #3
0
    def __init__(self,
                 num_visible,
                 num_hidden,
                 num_aux,
                 unitary_dict=None,
                 gpu=False):
        """Build the amplitude and phase ``PurificationRBM`` networks.

        :param num_visible: Number of visible units (converted with ``int``).
        :param num_hidden: Number of hidden units (converted with ``int``).
        :param num_aux: Number of auxiliary units (converted with ``int``).
        :param unitary_dict: Optional mapping from basis label to unitary
                             tensor; when falsy, ``unitaries.create_dict()``
                             is used instead.
        :param gpu: Forwarded to both ``PurificationRBM`` constructors.
        """
        self.rbm_am = PurificationRBM(int(num_visible),
                                      int(num_hidden),
                                      int(num_aux),
                                      gpu=gpu)

        self.rbm_ph = PurificationRBM(int(num_visible),
                                      int(num_hidden),
                                      int(num_aux),
                                      gpu=gpu)

        self.num_visible = int(num_visible)
        self.num_hidden = int(num_hidden)
        self.num_aux = int(num_aux)

        # The amplitude network decides which device everything lives on.
        self.device = self.rbm_am.device

        source_unitaries = (unitary_dict
                            if unitary_dict else unitaries.create_dict())
        # Store a private copy with every unitary moved to the model's
        # device, so rotations never mix CPU and GPU tensors.
        self.unitary_dict = {
            k: v.to(device=self.device)
            for k, v in source_unitaries.items()
        }
Example #4
0
def complex_wavefunction_data(gpu, num_hidden):
    """Assemble the ``"2qubits"`` test bundle for a ``ComplexWaveFunction``.

    Reads ``data/test_grad_data.pkl`` under the tests directory, calls
    ``qucumber.set_random_seed``, constructs the state together with its
    ``ComplexGradsUtils`` helper, and moves the samples, unitaries and target
    wavefunctions onto the state's device.

    :param gpu: Forwarded to the seeding call and to the state constructor.
    :param num_hidden: Forwarded to the state constructor.
    :returns: A ``ComplexWaveFunctionFixture`` namedtuple.
    """
    pickle_path = os.path.join(__tests_location__, "data", "test_grad_data.pkl")
    with open(pickle_path, "rb") as handle:
        test_data = pickle.load(handle)

    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    two_qubits = test_data["2qubits"]
    data_bases = two_qubits["train_bases"]
    data_samples = torch.tensor(two_qubits["train_samples"], dtype=torch.double)

    raw_bases = two_qubits["bases"]
    raw_target_psi = torch.tensor(two_qubits["target_psi"], dtype=torch.double)

    # The sample width fixes the number of visible units.
    num_visible = data_samples.shape[-1]

    unitary_dict = unitaries.create_dict()
    nn_state = ComplexWaveFunction(
        num_visible, num_hidden, unitary_dict=unitary_dict, gpu=gpu
    )
    grad_utils = ComplexGradsUtils(nn_state)

    bases = grad_utils.transform_bases(raw_bases)

    psi_dict = grad_utils.load_target_psi(bases, raw_target_psi)
    vis = nn_state.generate_hilbert_space(num_visible)

    # Everything tensor-valued is moved to wherever the state lives.
    device = nn_state.device
    data_samples = data_samples.to(device=device)
    unitary_dict = {
        label: unitary.to(device=device) for label, unitary in unitary_dict.items()
    }
    psi_dict = {label: psi.to(device=device) for label, psi in psi_dict.items()}

    field_names = [
        "data_samples",
        "data_bases",
        "grad_utils",
        "bases",
        "psi_dict",
        "vis",
        "nn_state",
        "unitary_dict",
    ]
    ComplexWaveFunctionFixture = namedtuple("ComplexWaveFunctionFixture", field_names)

    return ComplexWaveFunctionFixture(
        data_samples=data_samples,
        data_bases=data_bases,
        grad_utils=grad_utils,
        bases=bases,
        psi_dict=psi_dict,
        vis=vis,
        nn_state=nn_state,
        unitary_dict=unitary_dict,
    )
    def __init__(
        self, num_visible, num_hidden=None, unitary_dict=None, gpu=False, module=None
    ):
        """Set up the amplitude and phase networks and the unitary table.

        :param num_visible: Number of visible units for freshly built RBMs.
        :param num_hidden: Number of hidden units for freshly built RBMs.
        :param unitary_dict: Optional mapping from basis label to unitary;
                             when falsy, ``unitaries.create_dict()`` is used.
        :param gpu: Request CUDA; only honoured when CUDA is available.
        :param module: Optional pre-built network used in place of new
                       ``BinaryRBM`` instances.
        """
        cuda_requested_and_present = gpu and torch.cuda.is_available()
        if cuda_requested_and_present:
            warnings.warn(
                "Using ComplexWaveFunction on GPU is not recommended due to poor performance compared to CPU.",
                ResourceWarning,
                2,
            )
        device_name = "cuda" if cuda_requested_and_present else "cpu"
        self.device = torch.device(device_name)

        if module is not None:
            _warn_on_missing_gpu(gpu)
            # Amplitude network: the supplied module on the chosen device.
            self.rbm_am = module.to(self.device)
            self.rbm_am.device = self.device
            # Phase network: a clone of the supplied module.
            self.rbm_ph = module.to(self.device).clone()
            self.rbm_ph.device = self.device
        else:
            self.rbm_am = BinaryRBM(num_visible, num_hidden, gpu=gpu)
            self.rbm_ph = BinaryRBM(num_visible, num_hidden, gpu=gpu)

        # From here on the amplitude network is the source of truth.
        amplitude_net = self.rbm_am
        self.num_visible = amplitude_net.num_visible
        self.num_hidden = amplitude_net.num_hidden
        self.device = amplitude_net.device

        self.unitary_dict = unitary_dict or unitaries.create_dict()
        self.unitary_dict = {
            label: unitary.to(device=self.device)
            for label, unitary in self.unitary_dict.items()
        }
Example #6
0
def test_adding_unitaries():
    """``create_dict`` called with an extra ``A`` entry yields X, Y, Z and A."""
    extra_unitary = 0.5 * torch.tensor(
        [[[1, 1], [1, 1]], [[1, -1], [-1, 1]]], dtype=torch.double
    )
    unitary_dict = unitaries.create_dict(A=extra_unitary)

    msg = "Unitary dictionary has the wrong keys!"
    expected_keys = {"X", "Y", "Z", "A"}
    actual_keys = set(unitary_dict.keys())
    assert expected_keys == actual_keys, msg
Example #7
0
def NLL(nn_state, samples, space, train_bases=None, **kwargs):
    r"""A function for calculating the negative log-likelihood.

    Samples whose basis is all ``"Z"`` (or every sample, when ``train_bases``
    is ``None``) are scored directly with ``nn_state.psi``; any other sample
    is scored on the wavefunction rotated into its measurement basis.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param samples: Samples to compute the NLL on.
    :type samples: torch.Tensor
    :param space: The hilbert space of the system.
    :type space: torch.Tensor
    :param train_bases: An array of bases where measurements were taken.
    :type train_bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The Negative Log-Likelihood.
    :rtype: float
    """
    nll = 0.0
    unitary_dict = unitaries.create_dict()
    Z = nn_state.compute_normalization(space)
    # The normalisation does not change inside the loop; take its log once.
    log_Z = Z.log()
    eps = 0.000001
    num_visible = nn_state.num_visible

    for i in range(len(samples)):
        # A sample is in the reference basis when no bases were supplied or
        # when every site was measured in "Z".
        in_reference_basis = train_bases is None or all(
            train_bases[i][j] == "Z" for j in range(num_visible))

        if in_reference_basis:
            nll -= (cplx.norm_sqr(nn_state.psi(samples[i])) + eps).log()
        else:
            psi_r = rotate_psi(nn_state, train_bases[i], space, unitary_dict)
            # Column index of the sampled state: the bit string read with
            # the last site as the least-significant bit.
            ind = sum(1 << j for j in range(num_visible)
                      if samples[i, num_visible - j - 1] == 1)
            nll -= cplx.norm_sqr(psi_r[:, ind]).log().item()
        nll += log_Z
    return (nll / float(len(samples))).item()
def test_rotate_psi_inner_prod(num_visible, state_type, precompute_psi):
    """Check that ``rotate_psi_inner_prod`` agrees with ``rotate_psi``.

    Both are called in the all-``X`` basis; when ``precompute_psi`` is truthy
    the wavefunction is evaluated once and passed to both.
    """
    state = state_type(num_visible, gpu=False)
    x_basis = "X" * num_visible
    rotations = create_dict()

    hilbert_space = state.generate_hilbert_space()

    if precompute_psi:
        psi = state.psi(hilbert_space)
    else:
        psi = None

    expected = rotate_psi(state, x_basis, hilbert_space, rotations, psi=psi)
    via_inner_prod = rotate_psi_inner_prod(
        state, x_basis, hilbert_space, rotations, psi=psi
    )

    assertAlmostEqual(
        expected, via_inner_prod, msg="Fast psi inner product rotation failed!"
    )
def test_rotate_psi(num_visible, wvfn_type):
    """Check ``rotate_psi`` against an explicit Kronecker-product rotation.

    The reference applies the Kronecker product of the per-site ``X``
    unitaries to the wavefunction with ``cplx.matmul``.
    """
    state = wvfn_type(num_visible, gpu=False)
    x_basis = "X" * num_visible
    rotations = create_dict()

    hilbert_space = state.generate_hilbert_space()
    psi = state.psi(hilbert_space)

    fast_result = rotate_psi(state, x_basis, hilbert_space, rotations, psi=psi)

    # Build the full rotation matrix one site at a time.
    site_unitaries = [rotations[label] for label in x_basis]
    full_unitary = reduce(cplx.kronecker_prod, site_unitaries)
    reference_result = cplx.matmul(full_unitary, psi)

    assertAlmostEqual(fast_result, reference_result, msg="Fast psi rotation failed!")
def test_rotate_rho(num_visible, state_type):
    """Check ``rotate_rho`` against an explicit ``U * rho * conj(U)`` product.

    ``U`` is the Kronecker product of the per-site ``X`` unitaries.
    """
    state = state_type(num_visible, gpu=False)
    x_basis = "X" * num_visible
    rotations = create_dict()

    hilbert_space = state.generate_hilbert_space()
    rho = state.rho(hilbert_space, hilbert_space)

    fast_result = rotate_rho(state, x_basis, hilbert_space, rotations, rho=rho)

    # Reference: left-multiply by U, then right-multiply by its conjugate.
    site_unitaries = [rotations[label] for label in x_basis]
    full_unitary = reduce(cplx.kronecker_prod, site_unitaries)
    left_rotated = cplx.matmul(full_unitary, rho)
    reference_result = cplx.matmul(left_rotated, cplx.conjugate(full_unitary))

    assertAlmostEqual(fast_result, reference_result, msg="Fast rho rotation failed!")
Example #11
0
def KL(nn_state, target_psi, space, bases=None, **kwargs):
    r"""A function for calculating the total KL divergence.

    With ``bases=None`` the divergence is accumulated once over ``space``.
    Otherwise the model and the target are rotated with ``rotate_psi`` for
    each visited basis and the accumulated sum is divided by ``len(bases)``.

    :param nn_state: The neural network state (i.e. complex wavefunction or
                     positive wavefunction).
    :type nn_state: WaveFunction
    :param target_psi: The true wavefunction of the system.
    :type target_psi: torch.Tensor
    :param space: The hilbert space of the system.
    :type space: torch.Tensor
    :param bases: An array of unique bases.
    :type bases: np.array(dtype=str)
    :param \**kwargs: Extra keyword arguments that may be passed. Will be ignored.

    :returns: The KL divergence.
    :rtype: float
    """
    kl = 0.0
    unitary_dict = unitaries.create_dict()
    target_psi = target_psi.to(nn_state.device)
    Z = nn_state.compute_normalization(space)
    # The normalisation does not change inside the loops; take its log once.
    log_Z = Z.log()
    eps = 0.000001
    if bases is None:
        num_bases = 1
        for i, basis_state in enumerate(space):
            target_prob = cplx.norm_sqr(target_psi[:, i])
            # eps keeps both logarithms finite when a probability is zero.
            kl += target_prob * (target_prob + eps).log()
            kl -= (target_prob *
                   (cplx.norm_sqr(nn_state.psi(basis_state)) + eps).log())
            kl += target_prob * log_Z
    else:
        num_bases = len(bases)
        # NOTE(review): the first entry of `bases` is skipped while the sum
        # is still divided by len(bases). Kept as-is because existing
        # reference values may depend on it -- confirm this is intended.
        for basis in bases[1:]:
            psi_r = rotate_psi(nn_state, basis, space, unitary_dict)
            target_psi_r = rotate_psi(nn_state, basis, space, unitary_dict,
                                      target_psi)
            for ii in range(len(space)):
                target_prob = cplx.norm_sqr(target_psi_r[:, ii])
                # Skip p * log(p) at p == 0 to avoid 0 * -inf = nan.
                if target_prob > 0.0:
                    kl += target_prob * target_prob.log()
                kl -= target_prob * cplx.norm_sqr(psi_r[:, ii]).log().item()
                kl += target_prob * log_Z

    mean_kl = kl / float(num_bases)
    # The accumulator is still a plain float when no term was added (for
    # example when `bases` holds a single entry), and floats have no .item().
    if torch.is_tensor(mean_kl):
        return mean_kl.item()
    return float(mean_kl)
    def __init__(self,
                 num_visible,
                 num_hidden=None,
                 unitary_dict=None,
                 gpu=True,
                 module=None):
        """Set up the amplitude and phase networks and the unitary table.

        :param num_visible: Number of visible units (converted with ``int``).
        :param num_hidden: Number of hidden units; when falsy, the number of
                           visible units is used instead.
        :param unitary_dict: Optional mapping from basis label to unitary;
                             when falsy, ``unitaries.create_dict()`` is used.
        :param gpu: Request CUDA; only honoured when CUDA is available.
        :param module: Optional pre-built network used in place of new
                       ``BinaryRBM`` instances.
        """
        cuda_requested_and_present = gpu and torch.cuda.is_available()
        if cuda_requested_and_present:
            message = (
                "Using ComplexWaveFunction on GPU is not recommended due to poor "
                "performance compared to CPU. In the future, ComplexWaveFunction "
                "will default to using CPU, even if a GPU is available.")
            warnings.warn(message, ResourceWarning, 2)
        device_name = "cuda" if cuda_requested_and_present else "cpu"
        self.device = torch.device(device_name)

        if module is not None:
            _warn_on_missing_gpu(gpu)
            # Amplitude network: the supplied module on the chosen device.
            self.rbm_am = module.to(self.device)
            self.rbm_am.device = self.device
            # Phase network: a clone of the supplied module.
            self.rbm_ph = module.to(self.device).clone()
            self.rbm_ph.device = self.device
        else:
            visible_units = int(num_visible)
            hidden_units = int(num_hidden) if num_hidden else int(num_visible)
            self.rbm_am = BinaryRBM(visible_units, hidden_units, gpu=gpu)
            self.rbm_ph = BinaryRBM(visible_units, hidden_units, gpu=gpu)

        self.num_visible = int(num_visible)
        if num_hidden:
            self.num_hidden = int(num_hidden)
        else:
            self.num_hidden = self.num_visible

        self.unitary_dict = unitary_dict or unitaries.create_dict()
        self.unitary_dict = {
            label: unitary.to(device=self.device)
            for label, unitary in self.unitary_dict.items()
        }
Example #13
0
    def __init__(self,
                 num_visible,
                 num_hidden=None,
                 unitary_dict=None,
                 gpu=True):
        """Create amplitude and phase ``BinaryRBM`` networks of equal size.

        :param num_visible: Number of visible units (converted with ``int``).
        :param num_hidden: Number of hidden units; when falsy, the number of
                           visible units is used instead.
        :param unitary_dict: Optional mapping from basis label to unitary;
                             when falsy, ``unitaries.create_dict()`` is used.
        :param gpu: Forwarded to both ``BinaryRBM`` constructors.
        """
        self.num_visible = int(num_visible)
        if num_hidden:
            self.num_hidden = int(num_hidden)
        else:
            self.num_hidden = self.num_visible

        layer_sizes = (self.num_visible, self.num_hidden)
        self.rbm_am = BinaryRBM(*layer_sizes, gpu=gpu)
        self.rbm_ph = BinaryRBM(*layer_sizes, gpu=gpu)

        # The amplitude network decides which device everything lives on.
        self.device = self.rbm_am.device

        self.unitary_dict = unitary_dict or unitaries.create_dict()
        self.unitary_dict = {
            label: unitary.to(device=self.device)
            for label, unitary in self.unitary_dict.items()
        }
Example #14
0
    def test_trainingcomplex(self):
        """Train a complex wavefunction ten times and check the statistics.

        Each run trains for five epochs on the Tutorial2 data set; the test
        then asserts that the mean fidelity and mean KL are close to 0.38 and
        0.33 and that both standard errors stay below 0.01.
        """
        print("Complex Wavefunction")
        print("--------------------")

        def tutorial_file(filename):
            # Every input lives in the same tutorial directory.
            return os.path.join(
                __tests_location__,
                "..",
                "examples",
                "Tutorial2_TrainComplexWavefunction",
                filename,
            )

        train_samples, target_psi, train_bases, bases = data.load_data(
            tutorial_file("qubits_train.txt"),
            tutorial_file("qubits_psi.txt"),
            tutorial_file("qubits_train_bases.txt"),
            tutorial_file("qubits_bases.txt"),
        )

        unitary_dict = unitaries.create_dict()
        nv = nh = train_samples.shape[-1]

        fidelities = []
        KLs = []

        log_every = 5
        # Hyper-parameters shared by every training run.
        fit_options = dict(
            epochs=5,
            pos_batch_size=50,
            neg_batch_size=10,
            k=10,
            lr=0.1,
            time=True,
            progbar=False,
        )
        metrics = {"Fidelity": ts.fidelity, "KL": ts.KL}

        print(
            "Training 10 times and checking fidelity and KL at 5 epochs...\n")
        for iteration in range(1, 11):
            print("Iteration: ", iteration)

            nn_state = ComplexWavefunction(unitary_dict=unitary_dict,
                                           num_visible=nv,
                                           num_hidden=nh,
                                           gpu=False)

            space = nn_state.generate_hilbert_space(nv)
            evaluator = MetricEvaluator(
                log_every,
                metrics,
                target_psi=target_psi,
                bases=bases,
                space=space,
                verbose=True,
            )

            self.initialize_complex_params(nn_state)

            nn_state.fit(
                data=train_samples,
                input_bases=train_bases,
                callbacks=[evaluator],
                **fit_options,
            )

            fidelities.append(ts.fidelity(nn_state, target_psi, space).item())
            KLs.append(ts.KL(nn_state, target_psi, space, bases=bases).item())

        def standard_error(values):
            return np.std(values) / np.sqrt(len(values))

        print("\nStatistics")
        print("----------")
        print(
            "Fidelity: ",
            np.average(fidelities),
            "+/-",
            standard_error(fidelities),
            "\n",
        )
        print("KL: ", np.average(KLs), "+/-", standard_error(KLs), "\n")

        self.assertTrue(abs(np.average(fidelities) - 0.38) < 0.05)
        self.assertTrue(abs(np.average(KLs) - 0.33) < 0.05)
        self.assertTrue(standard_error(fidelities) < 0.01)
        self.assertTrue(standard_error(KLs) < 0.01)
Example #15
0
def density_matrix_data(request, gpu, num_hidden):
    """Assemble the ``"density_matrix"`` test bundle for a ``DensityMatrix``.

    Reads ``data/test_grad_data.pkl`` next to the requesting test file, calls
    ``qucumber.set_random_seed``, constructs the state together with its
    ``DensityGradsUtils`` helper, and moves the samples and unitaries onto
    the state's device.

    :param request: Object whose ``fspath.dirname`` locates the data folder.
    :param gpu: Forwarded to the seeding call and to the state constructor.
    :param num_hidden: Forwarded to the state constructor; the number of
                       auxiliary units is ``num_hidden + 1``.
    :returns: A ``DensityMatrixFixture`` namedtuple.
    """
    pickle_path = os.path.join(request.fspath.dirname, "data",
                               "test_grad_data.pkl")
    with open(pickle_path, "rb") as handle:
        test_data = pickle.load(handle)

    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    dm_data = test_data["density_matrix"]
    data_bases = dm_data["train_bases"]
    data_samples = torch.tensor(dm_data["train_samples"], dtype=torch.double)

    raw_bases = dm_data["bases"]
    target_matrix = torch.tensor(dm_data["density_matrix"],
                                 dtype=torch.double)

    # The sample width fixes the number of visible units.
    num_visible = data_samples.shape[-1]
    num_aux = num_hidden + 1  # this is not a rule, will change with data

    unitary_dict = unitaries.create_dict()
    nn_state = DensityMatrix(num_visible,
                             num_hidden,
                             num_aux,
                             unitary_dict=unitary_dict,
                             gpu=gpu)
    grad_utils = DensityGradsUtils(nn_state)

    bases = grad_utils.transform_bases(raw_bases)

    v_space = nn_state.generate_hilbert_space(num_visible)
    a_space = nn_state.generate_hilbert_space(num_aux)

    # Everything tensor-valued is moved to wherever the state lives.
    device = nn_state.device
    data_samples = data_samples.to(device=device)
    unitary_dict = {
        label: unitary.to(device=device)
        for label, unitary in unitary_dict.items()
    }

    field_names = [
        "data_samples",
        "data_bases",
        "grad_utils",
        "bases",
        "target",
        "v_space",
        "a_space",
        "nn_state",
        "unitary_dict",
    ]
    DensityMatrixFixture = namedtuple("DensityMatrixFixture", field_names)

    return DensityMatrixFixture(
        data_samples=data_samples,
        data_bases=data_bases,
        grad_utils=grad_utils,
        bases=bases,
        target=target_matrix,
        v_space=v_space,
        a_space=a_space,
        nn_state=nn_state,
        unitary_dict=unitary_dict,
    )
Example #16
0
def test_rotate_psi(num_visible, wvfn_type):
    """Run ``rotate_psi`` once in the all-``X`` basis; the result is discarded."""
    state = wvfn_type(num_visible, gpu=False)
    x_basis = "X" * num_visible
    rotations = create_dict()
    hilbert_space = state.generate_hilbert_space()
    rotate_psi(state, x_basis, hilbert_space, rotations)
    print('')


# Script entry point: load the pickled "2qubits" data set, build a
# ComplexWavefunction from it and hand everything to run().
if __name__ == '__main__':
    # NOTE(review): these are module-level names; `num_chains` is not
    # referenced again inside this block -- presumably read elsewhere, confirm.
    k = 2
    num_chains = 10
    seed = 1234
    # Relative path: resolved against the current working directory.
    # NOTE: pickle.load can execute arbitrary code; only use a trusted file.
    with open('test_data.pkl', 'rb') as fin:
        test_data = pickle.load(fin)

    # Seed before the network state is constructed below.
    qucumber.set_random_seed(seed)
    train_bases = test_data['2qubits']['train_bases']
    train_samples = torch.tensor(test_data['2qubits']['train_samples'],
                                 dtype=torch.double)
    bases_data = test_data['2qubits']['bases']
    target_psi_tmp = torch.tensor(test_data['2qubits']['target_psi'],
                                  dtype=torch.double)
    # The hidden layer gets the same width as a sample (its last dimension).
    nh = train_samples.shape[-1]
    bases = transform_bases(bases_data)
    unitary_dict = unitaries.create_dict()
    psi_dict = load_target_psi(bases, target_psi_tmp)
    vis = generate_visible_space(train_samples.shape[-1])
    nn_state = ComplexWavefunction(num_visible=train_samples.shape[-1],
                                   num_hidden=nh,
                                   unitary_dict=unitary_dict)
    qr = QuantumReconstruction(nn_state)
    eps = 1.e-6
    run(qr, psi_dict, train_samples, train_bases, unitary_dict, bases, vis,
        eps, k)
Example #18
0
def test_default_unitary_dict():
    """``create_dict`` called without arguments yields exactly X, Y and Z."""
    default_unitaries = unitaries.create_dict()

    msg = "Default Unitary dictionary has the wrong keys!"
    expected_keys = {"X", "Y", "Z"}
    actual_keys = set(default_unitaries.keys())
    assert expected_keys == actual_keys, msg