Beispiel #1
0
def complex_wavefunction_data(gpu, num_hidden):
    """Assemble the shared objects for tests built on ``test_grad_data.pkl``.

    Loads the ``"2qubits"`` entry of ``data/test_grad_data.pkl`` (found under
    ``__tests_location__``), seeds the random number generators, builds a
    ``ComplexWaveFunction`` with its ``ComplexGradsUtils`` helper, and moves
    the sample tensor and both dictionaries onto ``nn_state.device``.

    :param gpu: Forwarded to ``qucumber.set_random_seed`` and to the
                ``ComplexWaveFunction`` constructor.
    :param num_hidden: Forwarded to the ``ComplexWaveFunction`` constructor.
    :returns: A ``ComplexWaveFunctionFixture`` namedtuple with the fields
              ``data_samples``, ``data_bases``, ``grad_utils``, ``bases``,
              ``psi_dict``, ``vis``, ``nn_state`` and ``unitary_dict``.
    """
    pickle_path = os.path.join(__tests_location__, "data", "test_grad_data.pkl")
    with open(pickle_path, "rb") as handle:
        test_data = pickle.load(handle)

    # Seed before the wavefunction is constructed so that its initial
    # parameters are reproducible.
    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    two_qubits = test_data["2qubits"]

    train_bases = two_qubits["train_bases"]
    train_samples = torch.tensor(two_qubits["train_samples"], dtype=torch.double)

    raw_bases = two_qubits["bases"]
    target_psi = torch.tensor(two_qubits["target_psi"], dtype=torch.double)

    num_visible = train_samples.shape[-1]

    unitaries_cpu = unitaries.create_dict()
    nn_state = ComplexWaveFunction(
        num_visible, num_hidden, unitary_dict=unitaries_cpu, gpu=gpu
    )
    grad_utils = ComplexGradsUtils(nn_state)

    bases = grad_utils.transform_bases(raw_bases)
    psi_cpu = grad_utils.load_target_psi(bases, target_psi)
    hilbert_space = nn_state.generate_hilbert_space(num_visible)

    device = nn_state.device

    def on_device(mapping):
        # Copy a {key: tensor} mapping with every tensor moved to `device`.
        return {key: tensor.to(device=device) for key, tensor in mapping.items()}

    train_samples = train_samples.to(device=device)
    unitaries_dev = on_device(unitaries_cpu)
    psi_dev = on_device(psi_cpu)

    field_names = (
        "data_samples",
        "data_bases",
        "grad_utils",
        "bases",
        "psi_dict",
        "vis",
        "nn_state",
        "unitary_dict",
    )
    ComplexWaveFunctionFixture = namedtuple("ComplexWaveFunctionFixture", field_names)

    return ComplexWaveFunctionFixture(
        train_samples,
        train_bases,
        grad_utils,
        bases,
        psi_dev,
        hilbert_space,
        nn_state,
        unitaries_dev,
    )
Beispiel #2
0
def complex_wavefunction_data(request, gpu, num_hidden):
    """Assemble the shared objects for tests built on ``test_grad_data.pkl``.

    Loads the ``"2qubits"`` entry of ``data/test_grad_data.pkl`` from the
    directory of the requesting test file, seeds the random number
    generators, builds a ``ComplexWaveFunction`` with its
    ``ComplexGradsUtils`` helper, and moves the sample tensor and the target
    dictionary onto ``nn_state.device``.

    :param request: Object whose ``fspath.dirname`` locates the ``data``
                    directory (presumably the pytest ``request`` fixture).
    :param gpu: Forwarded to ``qucumber.set_random_seed`` and to the
                ``ComplexWaveFunction`` constructor.
    :param num_hidden: Forwarded to the ``ComplexWaveFunction`` constructor.
    :returns: A ``ComplexWaveFunctionFixture`` namedtuple with the fields
              ``data_samples``, ``data_bases``, ``grad_utils``, ``all_bases``,
              ``target``, ``space``, ``nn_state`` and ``unitary_dict``.
    """
    pickle_path = os.path.join(
        request.fspath.dirname, "data", "test_grad_data.pkl"
    )
    with open(pickle_path, "rb") as handle:
        test_data = pickle.load(handle)

    # Seed before the wavefunction is constructed so that its initial
    # parameters are reproducible.
    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    two_qubits = test_data["2qubits"]

    train_bases = two_qubits["train_bases"]
    train_samples = torch.tensor(two_qubits["train_samples"], dtype=torch.double)

    raw_bases = two_qubits["bases"]
    # The stored target wavefunction is transposed before being handed to
    # the gradient utilities.
    target_psi = torch.tensor(two_qubits["target_psi"], dtype=torch.double).t()

    num_visible = train_samples.shape[-1]

    nn_state = ComplexWaveFunction(num_visible, num_hidden, gpu=gpu)
    unitary_dict = nn_state.unitary_dict

    grad_utils = ComplexGradsUtils(nn_state)

    all_bases = grad_utils.transform_bases(raw_bases)

    device = nn_state.device
    target = {
        basis: psi.to(device=device)
        for basis, psi in grad_utils.load_target_psi(all_bases, target_psi).items()
    }

    space = nn_state.generate_hilbert_space()
    train_samples = train_samples.to(device=device)

    field_names = (
        "data_samples",
        "data_bases",
        "grad_utils",
        "all_bases",
        "target",
        "space",
        "nn_state",
        "unitary_dict",
    )
    ComplexWaveFunctionFixture = namedtuple("ComplexWaveFunctionFixture", field_names)

    return ComplexWaveFunctionFixture(
        train_samples,
        train_bases,
        grad_utils,
        all_bases,
        target,
        space,
        nn_state,
        unitary_dict,
    )
Beispiel #3
0
def test_complex_training_without_bases_fail():
    """Fitting a ``ComplexWaveFunction`` with ``input_bases=None`` must raise.

    The test passes only if ``nn_state.fit`` raises ``ValueError``. If
    ``fit`` returns normally, the test fails with an explanatory message.
    """
    qucumber.set_random_seed(SEED, cpu=True, gpu=False, quiet=True)
    np.random.seed(SEED)

    nn_state = ComplexWaveFunction(10, gpu=False)

    data = torch.ones(100, 10)

    msg = "Training ComplexWaveFunction without providing bases should fail!"
    # The ``message=`` keyword of ``pytest.raises`` was deprecated in
    # pytest 4.1 and removed in 5.0. The documented replacement is to call
    # ``pytest.fail`` at the end of the block: it is reached only when no
    # exception was raised, and its failure exception is not a ValueError,
    # so it propagates out of the context manager.
    with pytest.raises(ValueError):
        nn_state.fit(data, epochs=1, pos_batch_size=10, input_bases=None)
        pytest.fail(msg)
Beispiel #4
0
def test_complex_wavefunction(gpu):
    """Check that one epoch of training changes the ``rbm_am`` parameters.

    :param gpu: Forwarded to ``qucumber.set_random_seed`` and to the
                ``ComplexWaveFunction`` constructor.
    """
    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)
    np.random.seed(SEED)

    nn_state = ComplexWaveFunction(10, gpu=gpu)

    # Snapshot of the parameters before training.
    params_before = parameters_to_vector(nn_state.rbm_am.parameters())

    samples = torch.ones(100, 10)

    # Random basis labels: each entry is 'Z' with probability 0.9, else 'X'.
    z_mask = np.random.binomial(1, 0.9, size=(100, 10))
    bases = np.where(z_mask, "Z", "X")

    nn_state.fit(samples, epochs=1, pos_batch_size=10, input_bases=bases)

    params_after = parameters_to_vector(nn_state.rbm_am.parameters())

    unchanged = torch.equal(params_before, params_after)
    assert not unchanged, "ComplexWaveFunction's parameters did not change!"
Beispiel #5
0
def test_complex_warn_on_gpu():
    """Constructing a ``ComplexWaveFunction`` with ``gpu=True`` must warn.

    The test passes only if a ``ResourceWarning`` is emitted during
    construction.
    """
    expect_resource_warning = pytest.warns(ResourceWarning)
    with expect_resource_warning:
        ComplexWaveFunction(10, gpu=True)
Beispiel #6
0
def test_trainingcomplex(vectorized):
    """Train a ``ComplexWaveFunction`` ten times and check the statistics.

    Each run trains a fresh model for 5 epochs on the Tutorial 2 example
    data, then records the fidelity and KL divergence against the target
    state. The test asserts that the averages are within 0.05 of 0.38
    (fidelity) and 0.33 (KL), and that both standard errors are below 0.01.

    :param vectorized: When falsy, ``debug_gradient_rotation`` is set to
                       ``True`` on every model before training.
    """
    print("Complex WaveFunction")
    print("--------------------")

    # All four input files live in the same tutorial directory.
    tutorial_dir = os.path.join(
        __tests_location__,
        "..",
        "examples",
        "Tutorial2_TrainComplexWaveFunction",
    )
    train_samples_path = os.path.join(tutorial_dir, "qubits_train.txt")
    train_bases_path = os.path.join(tutorial_dir, "qubits_train_bases.txt")
    bases_path = os.path.join(tutorial_dir, "qubits_bases.txt")
    psi_path = os.path.join(tutorial_dir, "qubits_psi.txt")

    train_samples, target_psi, train_bases, bases = data.load_data(
        train_samples_path, psi_path, train_bases_path, bases_path
    )

    unitary_dict = unitaries.create_dict()
    # The model uses as many hidden units as visible units.
    nv = nh = train_samples.shape[-1]

    # Training hyperparameters.
    epochs, batch_size, num_chains = 5, 50, 10
    CD, lr, log_every = 10, 0.1, 5

    fidelities, KLs = [], []

    print("Training 10 times and checking fidelity and KL at 5 epochs...\n")
    for run in range(1, 11):
        print("Iteration: ", run)

        nn_state = ComplexWaveFunction(
            unitary_dict=unitary_dict, num_visible=nv, num_hidden=nh, gpu=False
        )

        if not vectorized:
            nn_state.debug_gradient_rotation = True

        space = nn_state.generate_hilbert_space(nv)

        metrics = {"Fidelity": ts.fidelity, "KL": ts.KL}
        evaluator = MetricEvaluator(
            log_every,
            metrics,
            target_psi=target_psi,
            bases=bases,
            space=space,
            verbose=True,
        )

        initialize_complex_params(nn_state)

        nn_state.fit(
            data=train_samples,
            epochs=epochs,
            pos_batch_size=batch_size,
            neg_batch_size=num_chains,
            k=CD,
            lr=lr,
            time=True,
            input_bases=train_bases,
            progbar=False,
            callbacks=[evaluator],
        )

        fidelities.append(ts.fidelity(nn_state, target_psi, space))
        KLs.append(ts.KL(nn_state, target_psi, space, bases=bases))

    # Mean and standard error of the mean over the ten runs.
    fid_mean = np.average(fidelities)
    fid_err = np.std(fidelities) / np.sqrt(len(fidelities))
    kl_mean = np.average(KLs)
    kl_err = np.std(KLs) / np.sqrt(len(KLs))

    print("\nStatistics")
    print("----------")
    print("Fidelity: ", fid_mean, "+/-", fid_err, "\n")
    print("KL: ", kl_mean, "+/-", kl_err, "\n")

    assert abs(fid_mean - 0.38) < 0.05
    assert abs(kl_mean - 0.33) < 0.05
    assert fid_err < 0.01
    assert kl_err < 0.01
Beispiel #7
0
def quantum_state_training_data(request):
    """Build the training configuration for one quantum state type.

    ``request.param`` selects ``PositiveWaveFunction``,
    ``ComplexWaveFunction`` or ``DensityMatrix``. For the selected type the
    matching tutorial data is loaded from the ``examples`` directory next to
    the requesting test's directory, and a CPU model is constructed.

    :param request: Object providing ``param`` (the state class) and
                    ``fspath.dirname`` (presumably the pytest ``request``
                    fixture).
    :returns: A dict with the keys ``nn_state``, ``data``, ``input_bases``,
              ``target``, ``bases``, ``epochs``, ``pos_batch_size``,
              ``neg_batch_size``, ``k``, ``lr``, ``space``, ``fid_target``,
              ``kl_target`` and ``reinit_params_fn``.
    :raises ValueError: If ``request.param`` is none of the three supported
                        state types.
    """
    nn_state_type = request.param

    def tutorial_dir(name):
        # Directory of the named tutorial under the examples folder.
        return os.path.join(request.fspath.dirname, "..", "examples", name)

    if nn_state_type == PositiveWaveFunction:
        root = tutorial_dir("Tutorial1_TrainPosRealWaveFunction")

        train_samples, target = data.load_data(
            tr_samples_path=os.path.join(root, "tfim1d_data.txt"),
            tr_psi_path=os.path.join(root, "tfim1d_psi.txt"),
        )
        # No basis information is used for the positive wavefunction.
        train_bases = bases = None

        nn_state = PositiveWaveFunction(
            num_visible=train_samples.shape[-1], gpu=False
        )

        batch_size, num_chains = 100, 200
        fid_target, kl_target = 0.85, 0.29
        reinit_params_fn = initialize_posreal_params

    elif nn_state_type == ComplexWaveFunction:
        root = tutorial_dir("Tutorial2_TrainComplexWaveFunction")

        train_samples, target, train_bases, bases = data.load_data(
            tr_samples_path=os.path.join(root, "qubits_train.txt"),
            tr_psi_path=os.path.join(root, "qubits_psi.txt"),
            tr_bases_path=os.path.join(root, "qubits_train_bases.txt"),
            bases_path=os.path.join(root, "qubits_bases.txt"),
        )

        nn_state = ComplexWaveFunction(
            num_visible=train_samples.shape[-1], gpu=False
        )

        batch_size, num_chains = 50, 10
        fid_target, kl_target = 0.38, 0.33
        reinit_params_fn = initialize_complex_params

    elif nn_state_type == DensityMatrix:
        root = tutorial_dir("Tutorial3_TrainDensityMatrix")

        train_samples, target, train_bases, bases = data.load_data_DM(
            tr_samples_path=os.path.join(root, "N2_W_state_100_samples_data.txt"),
            tr_mtx_real_path=os.path.join(root, "N2_W_state_target_real.txt"),
            tr_mtx_imag_path=os.path.join(root, "N2_W_state_target_imag.txt"),
            tr_bases_path=os.path.join(root, "N2_W_state_100_samples_bases.txt"),
            bases_path=os.path.join(root, "N2_IC_bases.txt"),
        )

        nn_state = DensityMatrix(num_visible=train_samples.shape[-1], gpu=False)

        batch_size, num_chains = 100, 10
        fid_target, kl_target = 0.45, 0.42

        def reinit_params_fn(request, nn_state):
            # Delegates to the model's own reinitialization; `request` is unused.
            nn_state.reinitialize_parameters()

    else:
        raise ValueError(
            f"invalid test config: {nn_state_type} is not a valid quantum state type"
        )

    space = nn_state.generate_hilbert_space()

    return dict(
        nn_state=nn_state,
        data=train_samples,
        input_bases=train_bases,
        target=target,
        bases=bases,
        epochs=5,
        pos_batch_size=batch_size,
        neg_batch_size=num_chains,
        k=10,
        lr=0.1,
        space=space,
        fid_target=fid_target,
        kl_target=kl_target,
        reinit_params_fn=reinit_params_fn,
    )