Exemplo n.º 1
0
def complex_wavefunction_data(gpu, num_hidden):
    """Build the 2-qubit ``ComplexWaveFunction`` fixture used by the gradient tests.

    Loads ``data/test_grad_data.pkl`` (relative to ``__tests_location__``),
    seeds qucumber's RNGs, builds a ``ComplexWaveFunction`` plus a
    ``ComplexGradsUtils`` helper for it, and moves the sample tensor, the
    unitary dictionary and the target wavefunctions onto the network's device.

    Args:
        gpu: Forwarded to ``qucumber.set_random_seed`` and ``ComplexWaveFunction``.
        num_hidden: Number of hidden units of the wavefunction.

    Returns:
        A ``ComplexWaveFunctionFixture`` namedtuple with the fields
        ``data_samples``, ``data_bases``, ``grad_utils``, ``bases``,
        ``psi_dict``, ``vis``, ``nn_state`` and ``unitary_dict``.
    """
    pickle_path = os.path.join(__tests_location__, "data", "test_grad_data.pkl")
    # Repo-local, trusted fixture file; never unpickle externally supplied data.
    with open(pickle_path, "rb") as pkl_file:
        test_data = pickle.load(pkl_file)
    two_qubit = test_data["2qubits"]

    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    data_bases = two_qubit["train_bases"]
    data_samples = torch.tensor(two_qubit["train_samples"], dtype=torch.double)
    raw_bases = two_qubit["bases"]
    target_psi_tmp = torch.tensor(two_qubit["target_psi"], dtype=torch.double)

    num_visible = data_samples.shape[-1]

    # Built after seeding so any random initialisation is reproducible.
    unitary_dict = unitaries.create_dict()
    nn_state = ComplexWaveFunction(
        num_visible, num_hidden, unitary_dict=unitary_dict, gpu=gpu
    )
    grad_utils = ComplexGradsUtils(nn_state)

    bases = grad_utils.transform_bases(raw_bases)
    psi_dict = grad_utils.load_target_psi(bases, target_psi_tmp)
    vis = nn_state.generate_hilbert_space(num_visible)

    # Everything the tests compare against must live on the network's device.
    device = nn_state.device
    data_samples = data_samples.to(device=device)
    unitary_dict = {
        basis: matrix.to(device=device) for basis, matrix in unitary_dict.items()
    }
    psi_dict = {basis: psi.to(device=device) for basis, psi in psi_dict.items()}

    fixture_fields = [
        "data_samples",
        "data_bases",
        "grad_utils",
        "bases",
        "psi_dict",
        "vis",
        "nn_state",
        "unitary_dict",
    ]
    ComplexWaveFunctionFixture = namedtuple(
        "ComplexWaveFunctionFixture", fixture_fields
    )

    return ComplexWaveFunctionFixture(
        data_samples=data_samples,
        data_bases=data_bases,
        grad_utils=grad_utils,
        bases=bases,
        psi_dict=psi_dict,
        vis=vis,
        nn_state=nn_state,
        unitary_dict=unitary_dict,
    )
Exemplo n.º 2
0
def complex_wavefunction_data(request, gpu, num_hidden):
    """Build the 2-qubit ``ComplexWaveFunction`` fixture used by the gradient tests.

    Loads ``data/test_grad_data.pkl`` from the directory of the requesting
    test module, seeds qucumber's RNGs, builds a ``ComplexWaveFunction`` and
    a ``ComplexGradsUtils`` helper, and moves the samples and target
    wavefunctions onto the network's device.

    Args:
        request: pytest request object; ``request.fspath.dirname`` locates the data.
        gpu: Forwarded to ``qucumber.set_random_seed`` and ``ComplexWaveFunction``.
        num_hidden: Number of hidden units of the wavefunction.

    Returns:
        A ``ComplexWaveFunctionFixture`` namedtuple with the fields
        ``data_samples``, ``data_bases``, ``grad_utils``, ``all_bases``,
        ``target``, ``space``, ``nn_state`` and ``unitary_dict``.
    """
    pickle_path = os.path.join(request.fspath.dirname, "data", "test_grad_data.pkl")
    # Repo-local, trusted fixture file; never unpickle externally supplied data.
    with open(pickle_path, "rb") as pkl_file:
        test_data = pickle.load(pkl_file)
    two_qubit = test_data["2qubits"]

    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    data_bases = two_qubit["train_bases"]
    data_samples = torch.tensor(two_qubit["train_samples"], dtype=torch.double)
    raw_bases = two_qubit["bases"]
    # .t() transposes the stored target; presumably the layout that
    # load_target_psi expects — TODO confirm against ComplexGradsUtils.
    target_psi_tmp = torch.tensor(two_qubit["target_psi"], dtype=torch.double).t()

    num_visible = data_samples.shape[-1]

    # Built after seeding so any random initialisation is reproducible.
    nn_state = ComplexWaveFunction(num_visible, num_hidden, gpu=gpu)
    unitary_dict = nn_state.unitary_dict

    grad_utils = ComplexGradsUtils(nn_state)
    all_bases = grad_utils.transform_bases(raw_bases)

    target_on_host = grad_utils.load_target_psi(all_bases, target_psi_tmp)
    device = nn_state.device
    target = {basis: psi.to(device=device) for basis, psi in target_on_host.items()}

    space = nn_state.generate_hilbert_space()
    data_samples = data_samples.to(device=device)

    fixture_fields = [
        "data_samples",
        "data_bases",
        "grad_utils",
        "all_bases",
        "target",
        "space",
        "nn_state",
        "unitary_dict",
    ]
    ComplexWaveFunctionFixture = namedtuple(
        "ComplexWaveFunctionFixture", fixture_fields
    )

    return ComplexWaveFunctionFixture(
        data_samples=data_samples,
        data_bases=data_bases,
        grad_utils=grad_utils,
        all_bases=all_bases,
        target=target,
        space=space,
        nn_state=nn_state,
        unitary_dict=unitary_dict,
    )
Exemplo n.º 3
0
def test_trainingcomplex(vectorized):
    """Repeatedly train a ComplexWaveFunction on the Tutorial 2 data and check statistics.

    Trains a freshly constructed ``ComplexWaveFunction`` ``num_trials`` times
    for ``epochs`` epochs each. After each run it records the fidelity and KL
    divergence against the tutorial's target wavefunction. It then asserts
    that the mean fidelity is within 0.05 of 0.38, the mean KL is within
    0.05 of 0.33, and both standard errors are below 0.01.

    Args:
        vectorized: When falsy, ``nn_state.debug_gradient_rotation`` is set to
            ``True`` before training.
    """
    print("Complex WaveFunction")
    print("--------------------")

    # All four data files live in the same tutorial directory.
    tutorial_dir = os.path.join(
        __tests_location__, "..", "examples", "Tutorial2_TrainComplexWaveFunction"
    )
    train_samples_path = os.path.join(tutorial_dir, "qubits_train.txt")
    train_bases_path = os.path.join(tutorial_dir, "qubits_train_bases.txt")
    bases_path = os.path.join(tutorial_dir, "qubits_bases.txt")
    psi_path = os.path.join(tutorial_dir, "qubits_psi.txt")

    train_samples, target_psi, train_bases, bases = data.load_data(
        train_samples_path, psi_path, train_bases_path, bases_path
    )

    unitary_dict = unitaries.create_dict()
    nv = nh = train_samples.shape[-1]

    fidelities = []
    KLs = []

    num_trials = 10
    epochs = 5
    batch_size = 50
    num_chains = 10
    CD = 10
    lr = 0.1
    log_every = 5

    # Derived from the settings above so the message cannot drift from them.
    print(
        f"Training {num_trials} times and checking fidelity and KL "
        f"at {epochs} epochs...\n"
    )
    for i in range(num_trials):
        print("Iteration: ", i + 1)

        nn_state = ComplexWaveFunction(
            unitary_dict=unitary_dict,
            num_visible=nv,
            num_hidden=nh,
            gpu=False,
        )

        if not vectorized:
            nn_state.debug_gradient_rotation = True

        space = nn_state.generate_hilbert_space(nv)
        callbacks = [
            MetricEvaluator(
                log_every,
                {"Fidelity": ts.fidelity, "KL": ts.KL},
                target_psi=target_psi,
                bases=bases,
                space=space,
                verbose=True,
            )
        ]

        initialize_complex_params(nn_state)

        nn_state.fit(
            data=train_samples,
            epochs=epochs,
            pos_batch_size=batch_size,
            neg_batch_size=num_chains,
            k=CD,
            lr=lr,
            time=True,
            input_bases=train_bases,
            progbar=False,
            callbacks=callbacks,
        )

        fidelities.append(ts.fidelity(nn_state, target_psi, space))
        KLs.append(ts.KL(nn_state, target_psi, space, bases=bases))

    # Compute each mean and standard error once; reuse for reporting and asserting.
    fidelity_mean = np.average(fidelities)
    fidelity_err = np.std(fidelities) / np.sqrt(len(fidelities))
    kl_mean = np.average(KLs)
    kl_err = np.std(KLs) / np.sqrt(len(KLs))

    print("\nStatistics")
    print("----------")
    print("Fidelity: ", fidelity_mean, "+/-", fidelity_err, "\n")
    print("KL: ", kl_mean, "+/-", kl_err, "\n")

    assert abs(fidelity_mean - 0.38) < 0.05
    assert abs(kl_mean - 0.33) < 0.05
    assert fidelity_err < 0.01
    assert kl_err < 0.01