def complex_wavefunction_data(gpu, num_hidden):
    """Assemble the two-qubit ComplexWaveFunction gradient-test fixture.

    Loads ``data/test_grad_data.pkl`` relative to ``__tests_location__``,
    seeds the RNGs with ``SEED``, builds a ComplexWaveFunction with
    ``num_hidden`` hidden units and returns a ``ComplexWaveFunctionFixture``
    namedtuple. The fixture bundles:

    * the training samples (moved to the state's device) and their bases,
    * the ``ComplexGradsUtils`` helper,
    * the transformed bases,
    * the per-basis target dict (moved to the device),
    * the Hilbert space,
    * the state itself,
    * the unitary dict (moved to the device).
    """
    pickle_path = os.path.join(__tests_location__, "data", "test_grad_data.pkl")
    # Repository-owned fixture file; pickle must only ever be used on trusted input.
    with open(pickle_path, "rb") as handle:
        test_data = pickle.load(handle)

    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    two_qubits = test_data["2qubits"]
    data_bases = two_qubits["train_bases"]
    data_samples = torch.tensor(two_qubits["train_samples"], dtype=torch.double)
    bases_data = two_qubits["bases"]
    target_psi_tmp = torch.tensor(two_qubits["target_psi"], dtype=torch.double)

    num_visible = data_samples.shape[-1]
    unitary_dict = unitaries.create_dict()
    nn_state = ComplexWaveFunction(
        num_visible, num_hidden, unitary_dict=unitary_dict, gpu=gpu
    )

    grad_utils = ComplexGradsUtils(nn_state)
    bases = grad_utils.transform_bases(bases_data)
    psi_dict = grad_utils.load_target_psi(bases, target_psi_tmp)
    vis = nn_state.generate_hilbert_space(num_visible)

    # Everything handed back to the tests lives on the same device as the state.
    device = nn_state.device
    data_samples = data_samples.to(device=device)
    unitary_dict = {basis: mat.to(device=device) for basis, mat in unitary_dict.items()}
    psi_dict = {basis: psi.to(device=device) for basis, psi in psi_dict.items()}

    fixture_fields = [
        "data_samples",
        "data_bases",
        "grad_utils",
        "bases",
        "psi_dict",
        "vis",
        "nn_state",
        "unitary_dict",
    ]
    ComplexWaveFunctionFixture = namedtuple("ComplexWaveFunctionFixture", fixture_fields)
    return ComplexWaveFunctionFixture(
        data_samples,
        data_bases,
        grad_utils,
        bases,
        psi_dict,
        vis,
        nn_state,
        unitary_dict,
    )
def complex_wavefunction_data(request, gpu, num_hidden):
    """Assemble the two-qubit ComplexWaveFunction gradient-test fixture.

    Loads ``data/test_grad_data.pkl`` from the directory of the requesting test
    file and seeds the RNGs with ``SEED``. It then builds a ComplexWaveFunction
    with ``num_hidden`` hidden units (reusing the state's own ``unitary_dict``)
    and returns a ``ComplexWaveFunctionFixture`` namedtuple. The fixture bundles:

    * the training samples (moved to the state's device) and their bases,
    * the ``ComplexGradsUtils`` helper,
    * the transformed bases,
    * the per-basis target dict (moved to the device),
    * the Hilbert space,
    * the state itself,
    * the unitary dict.
    """
    pickle_path = os.path.join(request.fspath.dirname, "data", "test_grad_data.pkl")
    # Repository-owned fixture file; pickle must only ever be used on trusted input.
    with open(pickle_path, "rb") as handle:
        test_data = pickle.load(handle)

    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    two_qubits = test_data["2qubits"]
    data_bases = two_qubits["train_bases"]
    data_samples = torch.tensor(two_qubits["train_samples"], dtype=torch.double)
    raw_bases = two_qubits["bases"]
    # The stored target is transposed before being handed to load_target_psi.
    target_psi_tmp = torch.tensor(two_qubits["target_psi"], dtype=torch.double).t()

    num_visible = data_samples.shape[-1]
    nn_state = ComplexWaveFunction(num_visible, num_hidden, gpu=gpu)
    unitary_dict = nn_state.unitary_dict

    grad_utils = ComplexGradsUtils(nn_state)
    all_bases = grad_utils.transform_bases(raw_bases)
    per_basis_target = grad_utils.load_target_psi(all_bases, target_psi_tmp)
    device = nn_state.device
    target = {basis: psi.to(device=device) for basis, psi in per_basis_target.items()}

    space = nn_state.generate_hilbert_space()
    data_samples = data_samples.to(device=device)

    fixture_fields = [
        "data_samples",
        "data_bases",
        "grad_utils",
        "all_bases",
        "target",
        "space",
        "nn_state",
        "unitary_dict",
    ]
    ComplexWaveFunctionFixture = namedtuple("ComplexWaveFunctionFixture", fixture_fields)
    return ComplexWaveFunctionFixture(
        data_samples,
        data_bases,
        grad_utils,
        all_bases,
        target,
        space,
        nn_state,
        unitary_dict,
    )
def test_complex_training_without_bases_fail():
    """Training a ComplexWaveFunction with ``input_bases=None`` must raise ValueError.

    The previous form, ``pytest.raises(ValueError, message=msg)``, relied on the
    ``message`` keyword. pytest deprecated it in 4.1 and removed it in 5.0;
    passing it now makes ``pytest.raises`` itself raise ``TypeError``. The
    custom failure message is therefore reported through ``pytest.fail`` when
    no ValueError occurs.
    """
    qucumber.set_random_seed(SEED, cpu=True, gpu=False, quiet=True)
    np.random.seed(SEED)

    nn_state = ComplexWaveFunction(10, gpu=False)
    data = torch.ones(100, 10)

    msg = "Training ComplexWaveFunction without providing bases should fail!"
    try:
        nn_state.fit(data, epochs=1, pos_batch_size=10, input_bases=None)
    except ValueError:
        pass  # expected: training without measurement bases is rejected
    else:
        pytest.fail(msg)
def test_complex_wavefunction(gpu):
    """One epoch of training must change the amplitude RBM's parameters.

    The training data is all-ones with 100 samples of 10 sites. Each site's
    measurement basis is drawn independently: 'Z' with probability 0.9,
    otherwise 'X'.
    """
    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)
    np.random.seed(SEED)

    nn_state = ComplexWaveFunction(10, gpu=gpu)
    params_before = parameters_to_vector(nn_state.rbm_am.parameters())

    samples = torch.ones(100, 10)
    draws_z = np.random.binomial(1, 0.9, size=(100, 10))
    measurement_bases = np.where(draws_z, "Z", "X")

    nn_state.fit(samples, epochs=1, pos_batch_size=10, input_bases=measurement_bases)
    params_after = parameters_to_vector(nn_state.rbm_am.parameters())

    failure = "ComplexWaveFunction's parameters did not change!"
    assert not torch.equal(params_before, params_after), failure
def test_complex_warn_on_gpu():
    """Constructing a ComplexWaveFunction with ``gpu=True`` must emit a ResourceWarning.

    NOTE(review): presumably the warning fires because no CUDA device is
    available in the test environment; confirm against ComplexWaveFunction.
    """
    expected_warning = ResourceWarning
    with pytest.warns(expected_warning):
        ComplexWaveFunction(num_visible=10, gpu=True)
def test_trainingcomplex(vectorized):
    """Train a ComplexWaveFunction ten times on the Tutorial 2 data.

    Each run trains for 5 epochs, with metrics logged every 5 epochs. The test
    then checks the average and standard error of the fidelity and KL
    divergence across runs. When ``vectorized`` is falsy, every state gets
    ``debug_gradient_rotation = True`` before training.
    """
    print("Complex WaveFunction")
    print("--------------------")

    tutorial_dir = os.path.join(
        __tests_location__, "..", "examples", "Tutorial2_TrainComplexWaveFunction"
    )
    train_samples, target_psi, train_bases, bases = data.load_data(
        os.path.join(tutorial_dir, "qubits_train.txt"),
        os.path.join(tutorial_dir, "qubits_psi.txt"),
        os.path.join(tutorial_dir, "qubits_train_bases.txt"),
        os.path.join(tutorial_dir, "qubits_bases.txt"),
    )

    unitary_dict = unitaries.create_dict()
    # Visible and hidden layers are both sized to the number of sites.
    num_units = train_samples.shape[-1]

    epochs, batch_size, num_chains = 5, 50, 10
    cd_steps, lr, log_every = 10, 0.1, 5
    num_trials = 10

    fidelity_runs = []
    kl_runs = []

    print("Training 10 times and checking fidelity and KL at 5 epochs...\n")
    for trial in range(1, num_trials + 1):
        print("Iteration: ", trial)
        nn_state = ComplexWaveFunction(
            num_visible=num_units,
            num_hidden=num_units,
            unitary_dict=unitary_dict,
            gpu=False,
        )
        if not vectorized:
            nn_state.debug_gradient_rotation = True

        space = nn_state.generate_hilbert_space(num_units)
        evaluator = MetricEvaluator(
            log_every,
            {"Fidelity": ts.fidelity, "KL": ts.KL},
            target_psi=target_psi,
            bases=bases,
            space=space,
            verbose=True,
        )

        initialize_complex_params(nn_state)
        nn_state.fit(
            data=train_samples,
            epochs=epochs,
            pos_batch_size=batch_size,
            neg_batch_size=num_chains,
            k=cd_steps,
            lr=lr,
            time=True,
            input_bases=train_bases,
            progbar=False,
            callbacks=[evaluator],
        )

        fidelity_runs.append(ts.fidelity(nn_state, target_psi, space))
        kl_runs.append(ts.KL(nn_state, target_psi, space, bases=bases))

    fid_mean = np.average(fidelity_runs)
    fid_sem = np.std(fidelity_runs) / np.sqrt(len(fidelity_runs))
    kl_mean = np.average(kl_runs)
    kl_sem = np.std(kl_runs) / np.sqrt(len(kl_runs))

    print("\nStatistics")
    print("----------")
    print("Fidelity: ", fid_mean, "+/-", fid_sem, "\n")
    print("KL: ", kl_mean, "+/-", kl_sem, "\n")

    assert abs(fid_mean - 0.38) < 0.05
    assert abs(kl_mean - 0.33) < 0.05
    assert fid_sem < 0.01
    assert kl_sem < 0.01
def quantum_state_training_data(request):
    """Build the training configuration for the state type in ``request.param``.

    ``request.param`` must be ``PositiveWaveFunction``, ``ComplexWaveFunction``
    or ``DensityMatrix``. Each type selects:

    * its tutorial dataset under ``../examples`` (relative to the test file),
    * a freshly built CPU state,
    * positive/negative batch sizes,
    * reference fidelity and KL values,
    * a parameter re-initialization callable.

    Returns:
        dict with keys ``nn_state``, ``data``, ``input_bases``, ``target``,
        ``bases``, ``epochs``, ``pos_batch_size``, ``neg_batch_size``, ``k``,
        ``lr``, ``space``, ``fid_target``, ``kl_target`` and
        ``reinit_params_fn``.

    Raises:
        ValueError: if ``request.param`` is not one of the three state types.
    """
    state_cls = request.param

    def tutorial_dir(name):
        # Tutorial datasets live in ../examples relative to this test file.
        return os.path.join(request.fspath.dirname, "..", "examples", name)

    if state_cls == PositiveWaveFunction:
        root = tutorial_dir("Tutorial1_TrainPosRealWaveFunction")
        train_samples, target = data.load_data(
            tr_samples_path=os.path.join(root, "tfim1d_data.txt"),
            tr_psi_path=os.path.join(root, "tfim1d_psi.txt"),
        )
        # Positive wavefunctions are trained from computational-basis data only.
        train_bases = bases = None
        nn_state = PositiveWaveFunction(num_visible=train_samples.shape[-1], gpu=False)
        pos_batch, neg_batch = 100, 200
        fid_target, kl_target = 0.85, 0.29
        reinit_params_fn = initialize_posreal_params
    elif state_cls == ComplexWaveFunction:
        root = tutorial_dir("Tutorial2_TrainComplexWaveFunction")
        train_samples, target, train_bases, bases = data.load_data(
            tr_samples_path=os.path.join(root, "qubits_train.txt"),
            tr_psi_path=os.path.join(root, "qubits_psi.txt"),
            tr_bases_path=os.path.join(root, "qubits_train_bases.txt"),
            bases_path=os.path.join(root, "qubits_bases.txt"),
        )
        nn_state = ComplexWaveFunction(num_visible=train_samples.shape[-1], gpu=False)
        pos_batch, neg_batch = 50, 10
        fid_target, kl_target = 0.38, 0.33
        reinit_params_fn = initialize_complex_params
    elif state_cls == DensityMatrix:
        root = tutorial_dir("Tutorial3_TrainDensityMatrix")
        train_samples, target, train_bases, bases = data.load_data_DM(
            tr_samples_path=os.path.join(root, "N2_W_state_100_samples_data.txt"),
            tr_mtx_real_path=os.path.join(root, "N2_W_state_target_real.txt"),
            tr_mtx_imag_path=os.path.join(root, "N2_W_state_target_imag.txt"),
            tr_bases_path=os.path.join(root, "N2_W_state_100_samples_bases.txt"),
            bases_path=os.path.join(root, "N2_IC_bases.txt"),
        )
        nn_state = DensityMatrix(num_visible=train_samples.shape[-1], gpu=False)
        pos_batch, neg_batch = 100, 10
        fid_target, kl_target = 0.45, 0.42

        def reinit_params_fn(request, nn_state):
            # ``request`` is accepted for signature parity but unused here.
            nn_state.reinitialize_parameters()

    else:
        raise ValueError(
            f"invalid test config: {state_cls} is not a valid quantum state type"
        )

    return {
        "nn_state": nn_state,
        "data": train_samples,
        "input_bases": train_bases,
        "target": target,
        "bases": bases,
        "epochs": 5,
        "pos_batch_size": pos_batch,
        "neg_batch_size": neg_batch,
        "k": 10,
        "lr": 0.1,
        "space": nn_state.generate_hilbert_space(),
        "fid_target": fid_target,
        "kl_target": kl_target,
        "reinit_params_fn": reinit_params_fn,
    }